From 132f88e8d57f768ff690f197db30b20de837feb6 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 26 Feb 2025 17:07:44 +0000 Subject: [PATCH 0001/1769] Fix ROCm builds not finding numa library --- build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 08b6bd3ff8d6..8afe8b17252c 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM # manylinux base image. However, adding this does fix an issue where Bazel isn't able # to find them. RUN --mount=type=cache,target=/var/cache/dnf \ - dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 + dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ From 2bb7dbaa32f2eb42b785edbe377e15b3f5e73f28 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 9 Mar 2025 01:26:40 +0000 Subject: [PATCH 0002/1769] add jax.input_saved_vjp to let user pass primal inputs to bwd pass Co-authored-by: Dougal Maclaurin --- jax/_src/api.py | 80 +++++++++++++++++++++++++++++++++++- jax/experimental/__init__.py | 4 ++ tests/api_test.py | 58 ++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4b14d809621d..9bbaf01d2c50 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -25,6 +25,7 @@ import atexit import collections from collections.abc import Callable, Hashable, Iterable, Sequence +import dataclasses from functools import partial, lru_cache import inspect import math @@ -41,7 +42,8 @@ from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, - prefix_errors, generate_key_paths, tree_flatten_with_path) + prefix_errors, generate_key_paths, tree_flatten_with_path, + equality_errors_pytreedef) from jax._src import config from jax._src import core from jax._src import dispatch @@ -2031,6 +2033,82 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) +def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, + allow_unused: bool = True, allow_opaque: bool = True): + if len(which) != len(primals): + raise ValueError( + "length of 'which' argument must equal the number of primal input values, " + f"but got {len(which)=} and {len(primals)=}") + + dbg = debug_info("saved_input_vjp", f, primals, {}) + fun = lu.wrap_init(f, debug_info=dbg) + primals_flat, in_tree = tree_flatten(primals) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + out_primals_flat, _, jaxpr, residuals = ad.linearize(fun, *primals_flat) + primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w)) + id_map = {id(x): i for i, x in enumerate(primals_filt)} + opaque_residuals = [] + res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else + RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore + for r in residuals] + f_vjp = Partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, out_tree(), + jaxpr, opaque_residuals) + + if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): + unused = [(i, 
core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) + if w and id(x) not in res_ids] + assert unused + if len(unused) == 1: + (i, a), = unused + start, was = "an input value", "was" + msg = f" {dbg.arg_names[i]} of type {a.str_short()}" + else: + start, was = "multiple input values", "were" + msg = "\n" + "\n".join(f" * {dbg.arg_names[i]} of type {a.str_short()}" + for i, a in unused) + raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} " + f"not used by the backward pass:{msg}") + + if not allow_opaque and opaque_residuals: + msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals) + raise Exception(f"with {allow_opaque=}, the backward pass requires opaque " + f"(non-input) residuals: {msg}") + + out_primals = tree_unflatten(out_tree(), out_primals_flat) + return out_primals, f_vjp + +def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, jaxpr, + opaque_residuals, ct, *saved_primals): + primals_filtered, filtered_tree_ = tree_flatten(saved_primals) + if filtered_tree != filtered_tree_: + raise ValueError( + "inputs passed to f_vjp must be a tuple of (pytrees of) " + "arrays with the same structure as\n" + " tuple(x for x, w in zip(inputs, which) if w)\n" + "given the original call\n" + " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n" + "but the structures differ:\n" + + "\n".join(f" * inputs{keystr(path)} was a {thing1} in the original " + f"call, but a {thing2} here, so {explanation}" + for path, thing1, thing2, explanation + in equality_errors_pytreedef(filtered_tree, filtered_tree_))) + + residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx] + for i in res_spec] + dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] + cts_flat, out_tree_ = tree_flatten(ct) + assert out_tree_ == out_tree + arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat) + return tree_unflatten(in_tree, arg_cts) + +@dataclasses.dataclass(frozen=True) +class RSpec: + idx: int + primal: bool + +si_vjp = saved_input_vjp + + def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: """Transpose a function that is promised to be linear. diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 375d058d0edc..6c37635df1b0 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -19,6 +19,10 @@ enable_x64 as enable_x64, disable_x64 as disable_x64, ) +from jax._src.api import ( + saved_input_vjp as saved_input_vjp, + si_vjp as si_vjp +) from jax._src.callback import ( io_callback as io_callback ) diff --git a/tests/api_test.py b/tests/api_test.py index c9cf28e0af28..39590f744dcc 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -11496,5 +11496,63 @@ def wsc_as_noop(ctx, operand, *args, **kwargs): self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir) +class InputSavedVJPTest(jtu.JaxTestCase): + + def test_basic(self): + def f(x, y): + return x * y + + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + arg_cts = f_vjp(1., *primals) + self.assertAllClose(y, 6.) 
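+    # Cotangents come back in the same order as the primal inputs:
+    # d(x*y)/dx = y = 3. and d(x*y)/dy = x = 2.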
+    self.assertAllClose(arg_cts, (3., 2.))
+
+  def test_basic_unused(self):
+    f = jnp.sin
+    primals = 3.,
+    y, f_vjp = api.si_vjp(f, [True], *primals)
+    x_ct, = f_vjp(1., *primals)
+    self.assertAllClose(y, jnp.sin(3.))
+    self.assertAllClose(x_ct, jnp.cos(3.))
+
+    with self.assertRaisesRegex(Exception, "not used by the backward pass: x"):
+      _ = api.si_vjp(f, [True], *primals, allow_unused=False)
+
+  def test_basic_opaque(self):
+    f = jnp.sin
+    primals = 3.,
+    with self.assertRaisesRegex(Exception, "the backward pass requires opaque"):
+      _ = api.si_vjp(f, [True], *primals, allow_opaque=False)
+
+  def test_basic_pytree_error(self):
+    def f(x):
+      return [x['hi'] * x['bye']]
+
+    y, f_vjp = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.})
+    arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.})
+    self.assertAllClose(y, [6.])
+    self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.})
+
+    with self.assertRaisesRegex(ValueError, "but the structures differ"):
+      f_vjp(1., {'hi': 2.})
+
+  def test_fsdp(self):
+    # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
+    def f2(x, w):
+      x = 1. * x
+      x = x @ w
+      x = 2. * x
+      return x
+
+    x = jnp.ones((3, 4))
+    w = jnp.ones((4, 4))
+    y, f2_sivjp = api.si_vjp(f2, [False, True], x, w)
+    y_grad = jnp.ones_like(y)
+    x_grad, w_grad = f2_sivjp(y_grad, w)
+    self.assertAllClose(x_grad, 2. * y_grad @ w.T)
+    self.assertAllClose(w_grad, 2. * x.T @ y_grad)
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())

From eb8c908d2daff31559ac556be8564ed09e001509 Mon Sep 17 00:00:00 2001
From: Yu-Hang Tang
Date: Mon, 23 Sep 2024 18:33:46 +0000
Subject: [PATCH 0003/1769] Add CI workflow for JAX distributed initialize in K8s jobsets

---
 .github/workflows/k8s.yaml       | 105 +++++++++++++++++++++++++++++++
 .pre-commit-config.yaml          |   1 +
 examples/k8s/example.yaml        |  40 ++++++++++++
 examples/k8s/svc-acct.yaml       |  31 +++++++++
 jax/_src/clusters/k8s_cluster.py |  20 +++---
 5 files changed, 188 insertions(+), 9 deletions(-)
 create mode 100644 .github/workflows/k8s.yaml
 create mode 100644 examples/k8s/example.yaml
 create mode 100644 examples/k8s/svc-acct.yaml

diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml
new file mode 100644
index 000000000000..5149e79f14b4
--- /dev/null
+++ b/.github/workflows/k8s.yaml
@@ -0,0 +1,105 @@
+name: Distributed run using K8s Jobset
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+permissions:
+  contents: read
+  pull-requests: read
+  actions: write # to cancel previous workflows
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+  cancel-in-progress: true
+
+defaults:
+  run:
+    shell: bash -ex -o pipefail {0}
+
+jobs:
+
+  distributed-initialize:
+    runs-on: ubuntu-22.04
+    outputs:
+      TAG: ${{ steps.metadata.outputs.tags }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # ratchet:actions/checkout@v4
+        with:
+          path: jax
+
+      - name: Start Minikube cluster
+        uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533 # ratchet:medyagh/setup-minikube@v0.0.18
+
+      - name: Install K8s Jobset
+        run: |
+          kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml
+
+      - name: Build image
+        run: |
+          cat > Dockerfile < 1
+          assert len(jax.devices()) > len(jax.local_devices())
diff --git a/examples/k8s/svc-acct.yaml b/examples/k8s/svc-acct.yaml
new file mode 100644
index 000000000000..d05fb9b0cd2a
--- /dev/null
+++ b/examples/k8s/svc-acct.yaml
@@ -0,0 +1,31 @@
+apiVersion:
v1
+kind: ServiceAccount
+metadata:
+  name: training-job-sa
+  namespace: default
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: Role
+metadata:
+  name: pod-reader
+rules:
+  - apiGroups: [""]
+    resources: ["pods"]
+    verbs: ["get", "list", "watch"]
+  - apiGroups: ["batch"]
+    resources: ["jobs"]
+    verbs: ["get", "list", "watch"]
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: RoleBinding
+metadata:
+  name: pod-reader-binding
+  namespace: default
+subjects:
+  - kind: ServiceAccount
+    name: training-job-sa
+    namespace: default
+roleRef:
+  kind: Role
+  name: pod-reader
+  apiGroup: rbac.authorization.k8s.io
diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py
index 1274724b8ebd..a3b415df580a 100644
--- a/jax/_src/clusters/k8s_cluster.py
+++ b/jax/_src/clusters/k8s_cluster.py
@@ -35,15 +35,17 @@ def is_env_present(cls) -> bool:
       try:
         import kubernetes as k8s  # pytype: disable=import-error
       except ImportError as e:
-        warnings.warn(textwrap.fill(
-          "Kubernetes environment detected, but the `kubernetes` package is "
-          "not installed to enable automatic bootstrapping in this "
-          "environment. To enable automatic bootstrapping, please install "
-          "jax with the [k8s] extra. For example:"
-          "    pip install jax[k8s]"
-          "  OR"
-          "    pip install jax[k8s,]"
-        ))
+        warnings.warn(
+            '\n'.join([
+                textwrap.fill(
+                    "Kubernetes environment detected, but the `kubernetes` package "
+                    "is not installed to enable automatic bootstrapping in this "
+                    "environment. To enable automatic bootstrapping, please install "
+                    "jax with the [k8s] extra. For example:"),
+                "    pip install jax[k8s]",
+                "    pip install jax[k8s,]",
+            ])
+        )
         return False
 
       k8s.config.load_incluster_config()

From c6ef01d1618510deafde9ee6b91ea658bd105e0d Mon Sep 17 00:00:00 2001
From: Yu-Hang Tang
Date: Mon, 17 Mar 2025 18:21:01 +0000
Subject: [PATCH 0004/1769] address review comments

---
 .github/workflows/k8s.yaml | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml
index 5149e79f14b4..4da6a69775c2 100644
--- a/.github/workflows/k8s.yaml
+++ b/.github/workflows/k8s.yaml
@@ -10,8 +10,6 @@ on:
 
 permissions:
   contents: read
-  pull-requests: read
-  actions: write # to cancel previous workflows
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
   cancel-in-progress: true
@@ -25,8 +23,6 @@ jobs:
 
   distributed-initialize:
     runs-on: ubuntu-22.04
-    outputs:
-      TAG: ${{ steps.metadata.outputs.tags }}
    steps:
       - name: Checkout
         uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # ratchet:actions/checkout@v4

From ec994986ca5e0bacf15c324e24e514fd1f4005b8 Mon Sep 17 00:00:00 2001
From: Yu-Hang Tang
Date: Mon, 17 Mar 2025 18:28:18 +0000
Subject: [PATCH 0005/1769] update ratchet action pin

---
 .github/workflows/k8s.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml
index 4da6a69775c2..31ee05a03482 100644
--- a/.github/workflows/k8s.yaml
+++ b/.github/workflows/k8s.yaml
@@ -25,7 +25,7 @@ jobs:
     runs-on: ubuntu-22.04
     steps:
       - name: Checkout
-        uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # ratchet:actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4
         with:
           path: jax

From ed43119a86069a777a4e0c045c90bbbbe7accccd Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 18 Mar 2025 21:38:14 -0400
Subject: [PATCH 0006/1769] JAX release v0.5.3

---
 CHANGELOG.md   | 2 +-
 jax/version.py | 2 +-
 setup.py       | 4 ++--
 3 files changed, 4 
insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c30877ecae14..9faff67cf305 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## Unreleased +## jax 0.5.3 * New Features diff --git a/jax/version.py b/jax/version.py index be20aca06358..13df5f00a11b 100644 --- a/jax/version.py +++ b/jax/version.py @@ -146,7 +146,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.5.1" +_minimum_jaxlib_version = "0.5.3" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 80f45285ba61..a5c8500dc1cf 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.5.1' +_current_jaxlib_version = '0.5.3' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.5.1' -_libtpu_version = '0.0.10.*' +_libtpu_version = '0.0.11.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From 8a493129e7dcf4c2c3a3187b4a6ea0ca780ceb04 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 02:40:10 -0700 Subject: [PATCH 0007/1769] [mosaic_gpu] Fix usage of `absl::Cleanup` in CUDA events timer. PiperOrigin-RevId: 738315605 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 4f804c9e2116..8f52dce3b021 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -100,10 +100,10 @@ static const auto* kEventElapsed = gpuStreamSynchronize(stream); auto start_event = std::make_unique(); auto end_event = std::make_unique(); - absl::MakeCleanup([&]() { + absl::Cleanup cleanup = [&]() { gpuEventDestroy(*start_event); gpuEventDestroy(*end_event); - }); + }; gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), From 00ce0bee56361cf88de49e11eaf61484895b047c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 03:06:24 -0700 Subject: [PATCH 0008/1769] [mosaic_gpu] Remove unnecessary allocations in CUDA events timer. 
PiperOrigin-RevId: 738321801 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 8f52dce3b021..a726acd4d662 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -98,19 +98,21 @@ static const auto* kEventElapsed = .Ret>() // elapsed_ms .To([](gpuStream_t stream, auto start, auto end, auto out) { gpuStreamSynchronize(stream); - auto start_event = std::make_unique(); - auto end_event = std::make_unique(); + gpuEvent_t start_event = nullptr; + gpuEvent_t end_event = nullptr; + absl::Cleanup cleanup = [&]() { - gpuEventDestroy(*start_event); - gpuEventDestroy(*end_event); + gpuEventDestroy(start_event); + gpuEventDestroy(end_event); }; - gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), + + gpuMemcpy(&start_event, start.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); - gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpy(&end_event, end.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); + float elapsed; - if (auto res = - gpuEventElapsedTime(&elapsed, *start_event, *end_event); + if (auto res = gpuEventElapsedTime(&elapsed, start_event, end_event); res) { return ffi::Error::Internal(absl::StrCat( "Failed to get elapsed time between events: ", ToString(res))); From b0865508a63089d9ce4e0ca4d372e3fd7f2d5cfd Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 19 Mar 2025 04:32:29 -0700 Subject: [PATCH 0009/1769] [pallas:mosaic_gpu] Dialect lowering can now handle `lax.cond` PiperOrigin-RevId: 738342517 --- jax/_src/pallas/mosaic_gpu/lowering.py | 12 +- .../mosaic/gpu/dialect_lowering.py | 192 +++++++++++++----- .../mosaic/gpu/fragmented_array.py | 62 ++++-- .../mosaic/gpu/layout_inference.py | 42 +++- tests/pallas/mosaic_gpu_test.py | 10 +- 5 files changed, 223 insertions(+), 95 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6b06e6b7dfc2..2ae51a8b22e8 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2036,16 +2036,16 @@ def _yielded_values(outs, avals): ret.append(_ensure_ir_value(out, aval.dtype)) return ret - # We need the branch return mlir types in order to construct the - # switch operation. To avoid leaking information about what kind of - # mlir types are internal to FragmentedArrays and other mgpu types, - # we run one of the branches in a dummy module that we throw away to - # extract the return types + # We need to know the result types ahead of time to construct the switch + # operation. Below we lower the first branch in a throw-away module to + # extract them. 
  with ir.InsertionPoint(ir.Module.create().body):
     outs = lower_jaxpr_to_mosaic_gpu(
         ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args
     )
-  yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))]
+  yielded_types = [
+      v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))
+  ]
   del outs
 
   switch_op = scf_dialect.IndexSwitchOp(
diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index fedde5a00887..6e7f4a981f9d 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -14,10 +14,11 @@
 """Lowering rules and pass for the MLIR Mosaic GPU dialect."""
 
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 import dataclasses
 import functools
 import itertools
+import math
 import operator
 from typing import Any, Sequence, Type, cast
 
@@ -34,6 +35,7 @@
 from jax._src.lib.mlir.dialects import nvvm
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
+from jax.experimental.mosaic.gpu import layouts as layouts_lib
 import numpy as np
 
 from . import fragmented_array as fa
@@ -872,6 +874,66 @@ def _slice_smem(result: ir.Type, offset: ir.Value):
     return memref.view(result, smem_base, offset, [])
 
 
+# The metadata needed to reconstruct a vector from its flattened representation.
+_VectorTemplate = tuple[Sequence[int], fa.FragmentedLayout, ir.VectorType]
+
+
+def _flatten_ir_values(
+    values: Sequence[ir.Value], fa_layouts: Iterable[ir.Attribute]
+) -> tuple[Sequence[ir.Value], Sequence[_VectorTemplate | None]]:
+  """Flattens a sequence of values.
+
+  Non-vector values are preserved as is. Vectors are mapped to fragmented
+  arrays and then flattened into per-register values.
+
+  Args:
+    values: The sequence of values to flatten.
+    fa_layouts: The layouts of vectors in ``values``.
+
+  Returns:
+    A tuple of (flattened values, templates). The templates are used to
+    reconstruct the vectors from the per-register values.
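+    ``_unflatten_ir_values`` below is the exact inverse of this function for
+    a given sequence of templates.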
+ """ + fa_layouts_it = iter(fa_layouts) + result = [] + templates = [] + for v in values: + if ir.VectorType.isinstance(v.type): + fa = _fragmented_array_from_ir(v, next(fa_layouts_it)) + result.extend(fa.registers.flat) + templates.append((fa.registers.shape, fa.layout, ir.VectorType(v.type))) + else: + result.append(v) + templates.append(None) + return result, templates + + +def _unflatten_ir_values( + flat_values: Sequence[ir.Value], templates: Sequence[_VectorTemplate | None] +) -> Sequence[ir.Value]: + """The inverse of ``_flatten_ir_values``.""" + result = [] + flat_values_it = iter(flat_values) + for template in templates: + if template is None: + result.append(next(flat_values_it)) + continue + registers_shape, layout, vec_type = template + value_registers = np.asarray( + [next(flat_values_it) for _ in range(math.prod(registers_shape))], + dtype=object, + ) + value = fa.FragmentedArray( + _registers=value_registers.reshape(registers_shape), + _layout=layout, + _is_signed=False + if ir.IntegerType.isinstance(vec_type.element_type) + else None, + ) + result.append(_fragmented_array_to_ir(value, vec_type)) + return result + + @_register_lowering(scf.ForOp) def _for_op_lowering_rule( ctx: LoweringContext, for_op: scf.ForOp @@ -884,60 +946,22 @@ def _for_op_lowering_rule( yield_layouts = inference_utils.in_layouts(yield_op) if in_layouts != out_layouts or in_layouts != yield_layouts: raise ValueError("Layout mismatch") - fa_layouts = in_layouts - - fa_layouts_it = iter(fa_layouts) - arg_template = [ - (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type) - if ir.VectorType.isinstance(arg.type) - else (arg, arg.type) - for arg in for_op.initArgs - ] - def lower_carry(carry): - fa_layouts_it = iter(fa_layouts) - carry_with_fas = [ - _fragmented_array_from_ir(arg, next(fa_layouts_it)) - if ir.VectorType.isinstance(arg.type) - else arg - for arg in carry - ] - lowered_carry = [] - for c in carry_with_fas: - if isinstance(c, fa.FragmentedArray): - lowered_carry.extend(c.registers.flat) - else: - lowered_carry.append(c) - return lowered_carry - - def recreate_carry(lowered_carry): - recreated_carry = [] - arg_it = iter(lowered_carry) - for arg_value, arg_type in arg_template: - if isinstance(arg_value, fa.FragmentedArray): - carry_registers = np.asarray( - [next(arg_it) for _ in arg_value.registers.flat], dtype=object - ) - carry_registers = carry_registers.reshape(arg_value.registers.shape) - carry = fa.FragmentedArray( - _registers=carry_registers, - _layout=arg_value.layout, - _is_signed=arg_value.is_signed, - ) - recreated_carry.append(_fragmented_array_to_ir(carry, arg_type)) - else: - recreated_carry.append(next(arg_it)) - return recreated_carry + flat_init_args, args_template = _flatten_ir_values( + for_op.initArgs, in_layouts + ) new_for_op = scf.ForOp( for_op.lowerBound, for_op.upperBound, for_op.step, - lower_carry(for_op.initArgs), + flat_init_args, ) with ir.InsertionPoint(new_for_op.body): - recreated_carry = recreate_carry(new_for_op.body.arguments[1:]) + recreated_carry = _unflatten_ir_values( + new_for_op.body.arguments[1:], args_template + ) ops_to_lower = [] - for op in for_op.body: + for op in [*for_op.body]: if op == yield_op: continue mgpu.private_operation_remove_from_parent(op) @@ -952,16 +976,80 @@ def recreate_carry(lowered_carry): ctx.lower_op(op) with ir.InsertionPoint(new_for_op.body): - new_yield_operands = lower_carry(yield_op.operands) + flat_operands, _ = _flatten_ir_values(yield_op.operands, in_layouts) yield_op.erase() - 
scf.yield_(new_yield_operands) - return recreate_carry(new_for_op.results) + scf.yield_(flat_operands) + + return _unflatten_ir_values(new_for_op.results, args_template) + + +def _infer_flat_result_types( + op: ir.OpView, out_layouts: Sequence[ir.Attribute] +) -> Sequence[ir.Type]: + result_types: list[ir.Type] = [] + out_layouts_it = iter(out_layouts) + for r in op.results: + if not ir.VectorType.isinstance(r.type): + result_types.append(r.type) + continue + vec_type = ir.VectorType(r.type) + layout = layouts_lib.from_layout_attr(next(out_layouts_it)) + result_types.extend( + [layout.registers_element_type(vec_type.element_type)] + * math.prod(layout.registers_shape(tuple(vec_type.shape))) + ) + return result_types + + +@_register_lowering(scf.IfOp) +def _if_op_lowering_rule( + ctx: LoweringContext, if_op: scf.IfOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(if_op): + return _traverse_op_lowering_rule(ctx, if_op) + + raise NotImplementedError + + +@_register_lowering(scf.IndexSwitchOp) +def _index_switch_op_lowering_rule( + ctx: LoweringContext, switch_op: scf.IndexSwitchOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(switch_op): + return _traverse_op_lowering_rule(ctx, switch_op) + + out_layouts = inference_utils.out_layouts(switch_op) + new_switch_op = scf.IndexSwitchOp( + _infer_flat_result_types(switch_op, out_layouts), + switch_op.arg, + switch_op.cases, + len(switch_op.regions) - 1, + ) + + results_template: Sequence[_VectorTemplate | None] = [] + for region, new_region in zip( + switch_op.regions, new_switch_op.regions, strict=True + ): + [block] = region.blocks + new_block = new_region.blocks.append() + with ir.InsertionPoint(new_block): + for op in [*block]: + if not isinstance(op, scf.YieldOp): + mgpu.private_operation_remove_from_parent(op) + mgpu.private_block_append_owned_operation(new_block, op) + ctx.lower_op(op) + continue + if inference_utils.in_layouts(op) != out_layouts: + raise ValueError("Layout mismatch") + flat_results, results_template = _flatten_ir_values( + op.operands, out_layouts + ) + scf.yield_(flat_results) + return _unflatten_ir_values(new_switch_op.results, results_template) @_register_lowering(func.FuncOp) @_register_lowering(gpu.LaunchOp) -@_register_lowering(scf.IfOp) # TODO(apaszke,bchetioui): Add a proper rule. -@_register_lowering(scf.IndexSwitchOp) # TODO(apaszke,bchetioui): Add a proper rule. 
def _traverse_op_lowering_rule( ctx: LoweringContext, op: ir.OpView ) -> MlirLoweringRuleResult: diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 5daed8416589..4bbfd0dd8afe 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -319,6 +319,9 @@ def tiled_tiling_rank(self) -> int: def vector_length(self) -> int: return self.tiled_tiling_shape[self.vector_dim] + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vector_length,), t) + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) @@ -386,6 +389,19 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): class WGMMARowFragLayout: """[m] matrix, where m % 64 == 0.""" + def registers_element_type(self, t: ir.Type) -> ir.Type: + return t + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + if len(shape) != 1: + raise ValueError("WGMMARowFragLayout requires a 1D shape") + if shape[0] % 64: + raise ValueError( + "WGMMARowFragLayout requires shape[0] to be a multiple of 64" + ) + return (shape[0] // 64, 2) + def thread_idxs(self, shape): index = ir.IndexType.get() assert len(shape) == 1 @@ -435,6 +451,14 @@ def can_broadcast_to(self, shape) -> bool: """ return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return t + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + del shape # Unused. 
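+    # A splat layout carries one value replicated across the warpgroup, so
+    # the register array is zero-dimensional.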
+ return () + def thread_idxs(self, shape): assert shape == self.shape raise NotImplementedError @@ -469,6 +493,15 @@ def from_shaped_type(cls, shaped_ty: ir.Type): shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size) ) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vec_size,), t) + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + if shape != self.shape: + raise ValueError(f"Shape {shape} is not compatible with {self}") + return (math.prod(self.shape) // (WARPGROUP_SIZE * self.vec_size),) + def thread_idxs(self, shape): assert shape == self.shape index = ir.IndexType.get() @@ -626,8 +659,8 @@ def __init__( != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size ): raise ValueError( - "Invalid register array shape: math.prod({_registers.shape}) *" - " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" + f"Invalid register array shape: math.prod({_registers.shape}) *" + f" {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" ) # Just a single register @@ -703,30 +736,15 @@ def load_wgmma_row( def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: - case WGMMARowFragLayout(): - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - reg_shape = (shape[0] // 64, 2) - case WGStridedFragLayout(vec_size=vec_size): - assert shape == layout.shape - elems = np.prod(shape) - reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) - value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) - case WGSplatFragLayout(): - assert shape == layout.shape - reg_shape = () - case TiledLayout(): - value = vector.splat(ir.VectorType.get((layout.vector_length,), value.type), value) - reg_shape = layout.registers_shape(shape) + case WGMMARowFragLayout() | WGSplatFragLayout(): + pass + case WGStridedFragLayout() | TiledLayout(): + value = vector.splat(layout.registers_element_type(value.type), value) case _: raise NotImplementedError(layout) return cls( - _registers=np.full(reg_shape, value, dtype=object), + _registers=np.full(layout.registers_shape(shape), value, dtype=object), _layout=layout, _is_signed=is_signed, ) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 0d2811bb5610..470b0d328d8e 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -306,23 +306,46 @@ def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: return (layouts, []) -@partial(_add_layout_inference_rule, scf.ForOp) -def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: - yield_op = op.body.operations[len(op.body.operations) - 1] - assert isinstance(yield_op, scf.YieldOp) - - if inference_utils.has_in_layouts_set(yield_op): - yield_layouts = list(inference_utils.in_layouts(yield_op)) +def _infer_from_yield_ops(op: ir.Operation) -> list[ir.Attribute] | None: + candidates = [] + for region in op.regions: + [block] = region.blocks + yield_op = block.operations[len(block.operations) - 1] + assert isinstance(yield_op, scf.YieldOp) + if not inference_utils.has_in_layouts_set(yield_op): + continue + yield_layouts = inference_utils.in_layouts(yield_op) if any( 
layouts_lib.is_splat_fragmented_layout(layout) for layout in yield_layouts ): - return None - return (yield_layouts, yield_layouts) + continue + candidates.append(yield_layouts) + if not candidates: + return None + return [_choose_representative_layout(set(c)) for c in zip(*candidates)] + +@partial(_add_layout_inference_rule, scf.ForOp) +def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: # TODO(bchetioui): we don't attempt to propagate from outside for the moment. # For the existing kernels, propagating from the YieldOp should be enough. + if layouts := _infer_from_yield_ops(op): + return layouts, layouts + return None + +@partial(_add_layout_inference_rule, scf.IfOp) +def _infer_if_op_layout(op: scf.IfOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts + return None + + +@partial(_add_layout_inference_rule, scf.IndexSwitchOp) +def _infer_index_switch_op_layout(op: scf.IndexSwitchOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts return None @@ -333,7 +356,6 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: shape=cast(ir.ShapedType, splat_op.result.type).shape ) ) - return [], [layout] diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b3c3ddb84e09..6792ddfaa9a8 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1121,13 +1121,13 @@ def test_cond_returning_array(self, thread_semantics): ), ) def kernel(x_ref, o_ref): - acc = _sum_same_dtype(x_ref[...]) + acc_sum = _sum_same_dtype(x_ref[...]) acc2, acc = jax.lax.cond( - acc % 2 == 0, - lambda: (acc * 2, acc), - lambda: (acc, acc * 2), + acc_sum % 2 == 0, + lambda: (acc_sum * 2, x_ref[...]), + lambda: (acc_sum, x_ref[...]), ) - o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape) + o_ref[...] = jnp.broadcast_to(_sum_same_dtype(acc) + acc2, o_ref.shape) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) From 30f770970404785016d3503ab1543540c8c88df0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 05:14:52 -0700 Subject: [PATCH 0010/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0d20d73f2c8f21c21b9f343c4363a76e980f032e. PiperOrigin-RevId: 738352930 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 73bf2eb3850d..f81e3931b1dc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4" -XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72" +XLA_COMMIT = "0d20d73f2c8f21c21b9f343c4363a76e980f032e" +XLA_SHA256 = "9df61c200b0a54b7a5c55155fa7a454e33d660e6a49239b6980f5a10305fecc5" def repo(): tf_http_archive( From c8032a9904eeb2410995425f817929f507fe22d5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 08:56:02 -0400 Subject: [PATCH 0011/1769] Fix line continuation character in Windows wheel build. 
--- .github/workflows/wheel_win_x64.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 444bc83f2889..912088428fd5 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -38,7 +38,7 @@ jobs: JAXLIB_RELEASE: true run: | python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt \ + python -m uv pip install -r build/test-requirements.txt ` --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH python.exe build\build.py build --wheels=jaxlib ` @@ -58,7 +58,7 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \ + python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib ` -e ${{ github.workspace }} echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" pytest -n auto --tb=short tests examples From 133a885e3b7a8347c121dce99eb3a920b6333a9e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Mar 2025 06:44:35 -0700 Subject: [PATCH 0012/1769] `use_mesh` and `use_concrete_mesh` should error when used under jit PiperOrigin-RevId: 738376533 --- jax/_src/array.py | 4 ++-- jax/_src/pjit.py | 8 +++++--- jax/_src/sharding_impls.py | 17 +++++++++-------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index b0793d2c3330..e49963ccda9c 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -43,7 +43,7 @@ from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, device_replica_id_map, hashed_index, num_addressable_indices, - local_to_global_shape, use_concrete_mesh) # pyformat: disable + local_to_global_shape, _internal_use_concrete_mesh) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np @@ -1149,7 +1149,7 @@ def shard_device_array(x, devices, indices, sharding): else: # TODO(yashkatariya): Maybe this should be set when we call the handler in # InputsHandler.__call__? 
-    with use_concrete_mesh(None):
+    with _internal_use_concrete_mesh(None):
       shards = x._multi_slice(start_indices, limit_indices, removed_dims)
   aval = core.shaped_abstractify(x)
   return pxla.batched_device_put(aval, sharding, shards, devices)
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index f7a4361ffee2..38ccb4513766 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -68,7 +68,7 @@
     NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
     prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding,
-    flatten_spec)
+    flatten_spec, _internal_use_concrete_mesh)
 from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
 from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
 from jax._src.traceback_util import api_boundary
@@ -689,8 +689,10 @@ def _infer_params_cached(
 def _infer_params(
     fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
 ) -> tuple[PjitParams, list[Any]]:
-  if ji.use_resource_env:
-    with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
+  if ji.use_resource_env:  # pjit
+    phys_mesh = mesh_lib.thread_resources.env.physical_mesh
+    with (_internal_use_concrete_mesh(phys_mesh),
+          mesh_lib.use_abstract_mesh(phys_mesh.abstract_mesh)):
       return _infer_params_internal(fun, ji, args, kwargs)
   return _infer_params_internal(fun, ji, args, kwargs)
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 2bbf913783e3..f3295a75cf7a 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -1382,10 +1382,8 @@ def use_mesh(mesh: mesh_lib.Mesh):
   if not isinstance(mesh, mesh_lib.Mesh):
     raise ValueError(
         f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
-
-  # TODO(yashkatariya): Enable this.
-  # if not core.trace_state_clean():
-  #   raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
+  if not core.trace_state_clean():
+    raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
 
   with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh):
     yield
@@ -1410,13 +1408,16 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
 
 @contextlib.contextmanager
 def use_concrete_mesh(mesh: mesh_lib.Mesh | None):
+  if not core.trace_state_clean():
+    raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')
+  with _internal_use_concrete_mesh(mesh):
+    yield
+
+@contextlib.contextmanager
+def _internal_use_concrete_mesh(mesh: mesh_lib.Mesh | None):
   if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
     raise ValueError(
         f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
-  # TODO(yashkatariya): Enable this.
-  # if not core.trace_state_clean():
-  #   raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')
-
   prev_val = config.device_context.swap_local(mesh)
   try:
     yield

From 1e25c44d67a024daa2333b652c578ac3535bb803 Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Wed, 19 Mar 2025 07:42:54 -0700
Subject: [PATCH 0013/1769] [mosaic_gpu] Only `jit` function to profile with
 cupti if it is not already `jit`ted.
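
A minimal sketch of the intended behavior (illustrative only; it relies on
the fact, used by the check in this change, that `jax.jit` returns a
`stages.Wrapped` and that `.lower(...).compile()` returns a
`stages.Compiled`):

    import jax
    from jax._src import stages

    f = jax.jit(lambda x: x + 1)
    # Already-wrapped or pre-compiled callables are profiled as-is...
    assert isinstance(f, (stages.Wrapped, stages.Compiled))
    # ...while a plain Python callable gets wrapped in jax.jit exactly once.
    assert not isinstance(lambda x: x + 1, (stages.Wrapped, stages.Compiled))
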
PiperOrigin-RevId: 738393973 --- jax/experimental/mosaic/gpu/profiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 0c128f88d169..99fefc1adc9c 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -21,6 +21,7 @@ import warnings import jax +from jax._src import stages from jax._src.lib import xla_client import jax.numpy as jnp from jaxlib.mlir import ir @@ -98,10 +99,13 @@ def run(*args, **kwargs): def _measure_cupti(f, aggregate): + if not isinstance(f, (stages.Wrapped, stages.Compiled)): + f = jax.jit(f) + def run(*args, **kwargs): mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() try: - results = jax.block_until_ready(jax.jit(f)(*args, **kwargs)) + results = jax.block_until_ready(f(*args, **kwargs)) finally: timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() return results, timings From d7d0aa943e825b89d9d696066f3a7389b1e9bb9e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 19 Mar 2025 07:56:34 -0700 Subject: [PATCH 0014/1769] Move PRNG GPU lowering from jaxlib into JAX. PiperOrigin-RevId: 738398099 --- jax/_src/prng.py | 36 ++++++++++----------- jaxlib/gpu_prng.py | 79 ++++++---------------------------------------- 2 files changed, 25 insertions(+), 90 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2fa9b2b37aa4..5fdd673b3454 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import ffi from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import tree_util as tree_util_internal @@ -64,6 +65,13 @@ UINT_DTYPES = { 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} +if hasattr(gpu_prng, "registrations"): + for platform, targets in gpu_prng.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + # -- PRNG implementation interface class PRNGImpl(NamedTuple): @@ -902,7 +910,7 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): multiple_results=True) -def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2): +def _threefry2x32_gpu_lowering_rule(ctx, k1, k2, x1, x2, *, target_name_prefix): if not config.threefry_gpu_kernel_lowering.value: # back to default lowering return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2) @@ -917,23 +925,11 @@ def _broadcast(x, aval): return mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=range(rank - len(aval.shape), rank)) - out_len = reduce(op.mul, aval_out.shape, 1) - if not core.is_constant_dim(out_len): - length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len]) - length = mlir.hlo.convert( - ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)), - length) - output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) - else: - length = int(out_len) # will be passed statically - output_shape = None - - return lowering_func( - (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), - (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, - output_shape, - False, # forward_compatibility_mode - ) + sub_ctx = ctx.replace(avals_in=(aval_out,) * 4) + rule = ffi.ffi_lowering( + f"{target_name_prefix}_threefry2x32_ffi") + return rule(sub_ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval), + _broadcast(x1, x1_aval), 
_broadcast(x2, x2_aval)) threefry2x32_p = core.Primitive("threefry2x32") @@ -947,11 +943,11 @@ def _broadcast(x, aval): threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='hip'), platform='rocm') diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 6f74d5813ce4..b112534c0575 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -12,79 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from functools import partial -import itertools +from typing import Any -import jaxlib.mlir.ir as ir - -from jaxlib import xla_client - -from .hlo_helpers import custom_call from .plugin_support import import_from_plugin _cuda_prng = import_from_plugin("cuda", "_prng") _hip_prng = import_from_plugin("rocm", "_prng") -if _cuda_prng: - for _name, _value in _cuda_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hip_prng: - for _name, _value in _hip_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - - -def _threefry2x32_lowering(prng, platform: str, keys, data, - length: int | ir.Value | None = None, - output_shape: ir.Value | None = None, - forward_compatibility_mode: bool = False): - """ThreeFry2x32 kernel for GPU. - - In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` - is a 1D tensor describing the shape of the two outputs. - """ - del forward_compatibility_mode - assert len(keys) == 2, keys - assert len(data) == 2, data - assert (ir.RankedTensorType(keys[0].type).element_type == - ir.IntegerType.get_unsigned(32)), keys[0].type - - typ = keys[0].type - dims = ir.RankedTensorType(typ).shape - - for x in itertools.chain(keys, data): - assert x.type == typ, (x.type, typ) - ndims = len(dims) - layout = tuple(range(ndims - 1, -1, -1)) - operand_layouts = [layout] * 4 - operands = [keys[0], keys[1], data[0], data[1]] - - opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4). - if isinstance(length, int): - result_shapes = None - else: - assert output_shape is not None - # We also need to pass separately the shapes of the outputs. 
- result_shapes = [output_shape, output_shape] - - custom_call_target = f"{platform}_threefry2x32_ffi" - return custom_call( - custom_call_target, - api_version=4, - result_types=[typ, typ], - operands=operands, - backend_config=opaque, - operand_layouts=operand_layouts, - result_layouts=[layout] * 2, - result_shapes=result_shapes).results - - -cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu") -rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items()) + return registrations From 1dcf872c64dbd6bf93dfeebdde869847d9ac5b53 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 08:15:10 -0700 Subject: [PATCH 0015/1769] Move //jaxlib:pass_boilerplate to //jaxlib/mosaic:pass_boilerplate. This code is Mosaic specific, move it to the Mosaic directory. PiperOrigin-RevId: 738404429 --- jaxlib/BUILD | 11 ----------- jaxlib/mosaic/BUILD | 15 +++++++++++++-- jaxlib/mosaic/dialect/tpu/transforms/serde.h | 2 +- jaxlib/mosaic/gpu/BUILD | 2 +- jaxlib/mosaic/gpu/passes.cc | 2 +- jaxlib/mosaic/gpu/serde.h | 2 +- jaxlib/{ => mosaic}/pass_boilerplate.h | 6 +++--- 7 files changed, 20 insertions(+), 20 deletions(-) rename jaxlib/{ => mosaic}/pass_boilerplate.h (94%) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a35eabc9a505..a5e8cee08cdc 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -171,17 +171,6 @@ cc_library( ], ) -cc_library( - name = "pass_boilerplate", - hdrs = ["pass_boilerplate.h"], - # compatible with libtpu - deps = [ - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - ], -) - cc_library( name = "handle_pool", hdrs = ["handle_pool.h"], diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 4cc2530dd7ca..775c34c8e7c7 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -60,9 +60,9 @@ cc_library( ]), # compatible with libtpu deps = [ + ":pass_boilerplate", + ":serde", ":tpu_inc_gen", - "//jaxlib:pass_boilerplate", - "//jaxlib/mosaic:serde", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", @@ -279,6 +279,17 @@ filegroup( # compatible with libtpu ) +cc_library( + name = "pass_boilerplate", + hdrs = ["pass_boilerplate.h"], + # compatible with libtpu + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "serde", srcs = ["serde.cc"], diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 8685918d3b39..64753a22e7be 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -9,7 +9,7 @@ #include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/include/mlir/Pass/Pass.h" #include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 9249ae256901..abe326474808 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -52,7 +52,7 @@ cc_library( "serde.h", ], deps = [ - "//jaxlib:pass_boilerplate", + "//jaxlib/mosaic:pass_boilerplate", 
"//jaxlib/mosaic:serde", "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index b8c3fbb74c81..1815e18ca927 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/include/mlir/Pass/PassRegistry.h" #include "mlir/include/mlir/Support/LLVM.h" #include "mlir/include/mlir/Transforms/DialectConversion.h" -#include "jaxlib/pass_boilerplate.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { namespace gpu { diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h index 6187d72b4cd5..d1e25e3f0912 100644 --- a/jaxlib/mosaic/gpu/serde.h +++ b/jaxlib/mosaic/gpu/serde.h @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/include/mlir/Pass/Pass.h" #include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic::gpu { diff --git a/jaxlib/pass_boilerplate.h b/jaxlib/mosaic/pass_boilerplate.h similarity index 94% rename from jaxlib/pass_boilerplate.h rename to jaxlib/mosaic/pass_boilerplate.h index b9754a8738ee..546981feeef7 100644 --- a/jaxlib/pass_boilerplate.h +++ b/jaxlib/mosaic/pass_boilerplate.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_PASS_BOILERPLATE_H_ -#define JAXLIB_PASS_BOILERPLATE_H_ +#ifndef JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ #include @@ -64,4 +64,4 @@ class Pass : public ::mlir::OperationPass { } // namespace mlir } // namespace jaxlib -#endif // JAXLIB_PASS_BOILERPLATE_H_ +#endif // JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ From 4893c08441231bd15c20dd76c4acb5d36890cd79 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Wed, 19 Mar 2025 08:32:49 -0700 Subject: [PATCH 0016/1769] Support bfloat16 and other scalar values in broadcast PiperOrigin-RevId: 738410122 --- jax/_src/pallas/mosaic/lowering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 10b9de7487eb..2bfd2f357510 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2346,13 +2346,13 @@ def _bcast(x, y, x_aval, y_aval, out_aval): y_dtype = x_aval.dtype elif x_aval.weak_type: x_dtype = y_aval.dtype - if isinstance(x, (np.ndarray, np.number, int, float)): + if not isinstance(x, ir.Value): if getattr(y, "type", None) == ir.IndexType.get(): mlir_type = y.type else: mlir_type = _dtype_to_ir_type(x_dtype) x = ir_constant(x, mlir_type) - if isinstance(y, (np.ndarray, np.number, int, float)): + if not isinstance(y, ir.Value): if getattr(x, "type", None) == ir.IndexType.get(): mlir_type = x.type else: From fd23fa8cf0f67e3bc82940b827536a61465b08b7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 19 Mar 2025 08:58:39 -0700 Subject: [PATCH 0017/1769] [Mosaic GPU] Remove `transpose_{a,b}` attributes from `mosaic_gpu.WGMMAOp`. Now that we have full control over strides in the lowering, these attributes are no longer necessary. 
PiperOrigin-RevId: 738418852 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 3 --- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 16 ++++------------ 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 6e7f4a981f9d..d605a2dea8f9 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -763,9 +763,6 @@ def _bitcast_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: - if wgmma_op.transpose_a or wgmma_op.transpose_b: - raise ValueError("Transpose arguments are to be deleted.") - fa_layouts = ( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 0882986fcf5e..108ff952b571 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -394,19 +394,14 @@ def MosaicGPU_WGMMAOp : Op { This operation supports larger inputs than the PTX-level WGMMA operation and will schedule as many PTX-level WGMMA operations as needed to - accomplish the calculation. The `b` matrix, and optionally `a`, needs to be - provided as a 2-dimensional memref. All memrefs may have transforms that - define swizzling, tiling, and transposition. + accomplish the calculation. The `b` matrix, and optionally `a`, need to be + provided as a 2-dimensional memref. The inputs should have the following shapes: - a: [groups_m * 64, groups_k * s] - b: [groups_k * s, groups_n * s] - accumulator: [groups_m * 64, groups_n * s] - Where: - - `s == swizzle/element_bytediwth` (for `kNoSwizzle`, `swizzle` is 16.) - and the tilings are [64, s] for `a` and [s, s] for `b`. - - `a` and/or `b` may be transposed if the corresponding attribute is set - to `true`. + where `s == swizzle / element_bytewidth`. The output has an identical shape and type as the input accumulator. @@ -429,10 +424,7 @@ def MosaicGPU_WGMMAOp : Op { AnyTypeOf<[ MemRefOf<[MosaicGPU_WGMMASupportedType]>, VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a, - MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b, - - DefaultValuedOptionalAttr:$transpose_a, - DefaultValuedOptionalAttr:$transpose_b + MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b ); let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>); From af5b2efd3e2fd7b071b95581f01d555451e95c32 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 18 Mar 2025 16:30:13 -0700 Subject: [PATCH 0018/1769] Fix input_output_aliases for non-HBM kernel args in TPU interpret mode. --- jax/_src/pallas/mosaic/interpret.py | 140 +++++++++++----------- tests/pallas/tpu_pallas_interpret_test.py | 42 +++++++ 2 files changed, 111 insertions(+), 71 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 1ad7be8154cd..3384026c1f5b 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -13,7 +13,6 @@ # limitations under the License. 
import collections -from collections.abc import Iterable, Sequence import dataclasses import enum import functools @@ -1283,23 +1282,6 @@ def f(*args, jaxpr): return jax.util.safe_map(read, jaxpr.outvars) -def _initialize_output_vals( - block_mappings_output: Iterable[BlockMapping], - input_args, input_output_aliases, - interpret_params: TPUInterpretParams, -) -> Sequence[jax.Array]: - oi_map = {v: k for k, v in input_output_aliases} - output_vals = [] - for i, bm in enumerate(block_mappings_output): - if i in oi_map: - output_vals.append(input_args[oi_map[i]]) - else: - output_vals.append(_uninitialized_value( - bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype, - interpret_params)) - return output_vals - def _compute_start_indices(block_mapping, loop_idx, *args): block_indices = ( jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) @@ -1423,30 +1405,52 @@ def interpret_pallas_call( for a, bs in zip(input_args, block_shapes[:num_inputs]) ] - # Allocate buffers in HBM for outputs. - output_buffer_ids = [] - output_buffer_shapes = [] - output_vals = _initialize_output_vals( - grid_mapping.block_mappings_output, - scalars + input_args, - input_output_aliases, - interpret_params) - num_outputs = grid_mapping.num_outputs - output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] - for out_val, bs in zip(output_vals, output_block_shapes): - padded_val = _pad_to_block_dimension(out_val, bs, interpret_params) - output_buffer_shapes.append(padded_val.shape) - output_buffer_ids.append(callback.io_callback( + # Allocate HBM buffers for pallas_call inputs. + # + # TODO(jburnim): As an optimization, skip allocating buffers for inputs that + # are neither aliased nor passed to the kernel in HBM? + input_buffer_ids = [] + for i, var in enumerate( + jaxpr.invars[grid_mapping.num_index_operands:][:grid_mapping.num_inputs]): + input_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - padded_val, + input_args[i], ordered=True)) - # Allocate buffers for all kernel arguments (e.g., scalars, inputs, - # outputs, scratch). - io_alias_map = dict(input_output_aliases) + + # Allocate buffers in HBM for pallas_call outputs. oi_alias_map = {v: k for k, v in input_output_aliases} + output_buffer_ids = [] + output_buffer_shapes = [] + output_vals = [] + num_outputs = grid_mapping.num_outputs + output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] + for i, bm in enumerate(grid_mapping.block_mappings_output): + if i in oi_alias_map: + # Re-use the HBM buffer for the aliased pallas_call input. + output_buffer_ids.append(input_buffer_ids[oi_alias_map[i]]) + output_buffer_shapes.append(input_args[oi_alias_map[i]].shape) + output_vals.append(input_args[oi_alias_map[i]]) + else: + out_val = _uninitialized_value(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype, + interpret_params) + padded_val = _pad_to_block_dimension( + out_val, output_block_shapes[i], interpret_params) + output_buffer_ids.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + padded_val, + ordered=True)) + output_buffer_shapes.append(padded_val.shape) + output_vals.append(out_val) + + # Allocate buffers for non-HBM kernel arguments (e.g., scalars, inputs, + # outputs, scratch). 
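+  # (HBM inputs and outputs were already given buffers above; an output that
+  # aliases an input re-uses the input's HBM buffer instead of getting a
+  # fresh allocation.)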
  kernel_buffer_ids = []
  for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars):
    kernel_buffer_ids.append(callback.io_callback(
@@ -1467,23 +1471,18 @@ def interpret_pallas_call(
          device_id,
          var.aval.shape,
          ordered=True))
-    elif is_output and _is_any(var.aval.memory_space):
-      # Use the already-allocated HBM output buffer.
+    elif _is_any(var.aval.memory_space):
+      # Use the already-allocated HBM input or output buffer.
       #
-      # TODO(jburnim): For kernel args in HBM, check that block shape is the
-      # same as for the corresponding pallas_call input, and that the index_map
+      # TODO(jburnim): For kernel args in HBM, check that block shape equals the
+      # shape of the corresponding pallas_call input, and that the index_map
       # is trivial.
-      kernel_buffer_ids.append(output_buffer_ids[output_idx])
-    elif is_output and (output_idx in oi_alias_map):
-      # Use the already-allocated (non-HBM) input buffer.
-      kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]])
-    elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space):
-      # Use the already-allocated HBM output buffer.
-      kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]])
+      assert is_input ^ is_output
+      if is_input:
+        kernel_buffer_ids.append(input_buffer_ids[i])
+      if is_output:
+        kernel_buffer_ids.append(output_buffer_ids[output_idx])
     else:
-      # TODO(jburnim): For kernel args in HBM, check that block shape is the
-      # same as for the corresponding pallas_call input, and that the index_map
-      # is trivial.
       kernel_buffer_ids.append(callback.io_callback(
          _allocate_buffer,
          jax.ShapeDtypeStruct((), jnp.int16),
@@ -1499,24 +1498,6 @@ def interpret_pallas_call(
  input_vars, output_vars = split_list(
      jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs])

-  # For kernel inputs that are in HBM, we populate the buffer once before
-  # any kernel invocations.
-  for buffer_id, var, val in zip(input_ids, input_vars, input_args):
-    if not _is_any(var.aval.memory_space):
-      continue
-    if (val.shape != var.aval.shape) or (val.dtype != var.aval.dtype):
-      # TODO(jburnim): Also check that the index_map is trivial.
-      raise ValueError()
-    callback.io_callback(
-        store,
-        (),
-        device_id,
-        TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
-        buffer_id,
-        (),
-        val,
-        ordered=True)
-
  if grid:
    num_iterations = functools.reduce(jnp.multiply, grid)  # type: ignore[arg-type]
  else:
@@ -1547,9 +1528,26 @@ def body(carry):
      for j, var in enumerate(input_vars):
        if _is_any(var.aval.memory_space):
          continue
-        sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j],
-                                          input_args[j], is_indexing_dim[j])
-        assert(sliced_val.shape == var.aval.shape)
+        # Copy from the HBM buffer for the pallas_call input to the kernel
+        # input buffer.
+        # TODO(jburnim): Just use input_args[j] when the input is not aliased?
+        transform = indexing.NDIndexer(
+            indices=tuple(indexing.ds(st, sz) if not iid else st
+                          for st, sz, iid in zip(start_indices[j],
+                                                 block_shapes[j],
+                                                 is_indexing_dim[j])),
+            shape=input_args[j].shape,
+            int_indexer_shape=())
+        sliced_val = callback.io_callback(
+            # TODO(jburnim): Pass source_info from the pallas_call, in case this
+            # read is involved in a data race.
+            get,
+            jax.ShapeDtypeStruct(var.aval.shape, var.aval.dtype),
+            device_id,
+            TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
+            input_buffer_ids[j],
+            (transform,),
+            ordered=True)
        callback.io_callback(
            # TODO(jburnim): Pass source_info from the pallas_call, in case this
            # store is involved in a data race.
diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index bc589855b836..9b8a5b46865d 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -69,6 +69,7 @@ def matmul(x: jax.Array, y: jax.Array): np.testing.assert_allclose(z, x @ y, atol=1e-4) def test_dynamic_grid_and_aliasing(self): + self.skipTest('Broken pending fix to extra reads/writes of inputs/outputs') def kernel(s_ref, x_ref, o_ref): o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype) @@ -91,8 +92,49 @@ def f(s, x): s = jnp.array([1], dtype=jnp.int32) x = jnp.arange(32 * 128.).reshape((32, 128)) y = f(s, x) + # NOTE: No matter how many times the kernel body is run, the kernel input + # buffer will only be written once by the pallas_call machinery, just + # before the first iteration. So the output will be x + 1 , despite the + # aliasing in HBM. np.testing.assert_allclose(y, x + 1.0) + def test_aliasing(self): + def kernel(x_ref, o_ref, s_ref): + @pl.when((pl.program_id(0) == 0) & (pl.program_id(1) == 0)) + def _(): + s_ref[0] = jnp.int32(0) + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + s.astype(x_ref.dtype) + + x = jnp.zeros((4 * 8, 4 * 128)) + y = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 4), + in_specs=[ + pl.BlockSpec(block_shape=(8, 128), index_map=lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (j, i)), + scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), + input_output_aliases={0: 0}, + interpret=mosaic_interpret.TPUInterpretParams(), + )(x) + + expected = np.zeros((4, 4)) + t = 0 + for i in range(4): + for j in range(4): + expected[j, i] = expected[i, j] + t + t += 1 + # NOTE: expected is + # [[0, 5, 10, 15], + # [1, 5, 15, 20], + # [2, 6, 10, 25], + # [3, 7, 11, 15]] + np.testing.assert_allclose(y[::8, ::128], expected) + @parameterized.parameters('eager', 'on_wait') def test_race_detection(self, dma_execution_mode): def kernel_without_race(x_ref, o_ref, t_ref, sem): From dde861af5fcf7d56863cce5afd671720df975cf4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Mar 2025 09:05:05 -0700 Subject: [PATCH 0019/1769] Remove the jax Array migration guide from the TOC tree but keep the doc around PiperOrigin-RevId: 738421256 --- docs/jax_array_migration.md | 3 +++ docs/notes.rst | 4 ---- jax/_src/pjit.py | 8 +++----- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 95d4a632a295..a557f4ae7efc 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- (jax-array-migration)= # jax.Array migration diff --git a/docs/notes.rst b/docs/notes.rst index 08265638000e..24a9dc8594cd 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -9,9 +9,6 @@ Dependencies and version compatibility: - :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases. - :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy. -Migrations and deprecations: - - :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1 - Memory and computation usage: - :doc:`async_dispatch` describes JAX's asynchronous dispatch model. - :doc:`concurrency` describes how JAX interacts with other Python concurrency. 
@@ -27,7 +24,6 @@ Programmer guardrails: api_compatibility deprecation - jax_array_migration async_dispatch concurrency gpu_memory_allocation diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 38ccb4513766..b6024dcdfedd 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1573,14 +1573,12 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'Passing non-trivial shardings for numpy ' 'inputs is not allowed. To fix this error, either specify a ' 'replicated sharding explicitly or use ' - '`jax.experimental.multihost_utils.host_local_array_to_global_array(...)` ' + '`jax.make_array_from_process_local_data(...)` ' 'to convert your host local numpy inputs to a jax.Array which you ' - 'can pass to pjit. ' + 'can pass to jit. ' 'If the numpy input is the same on each process, then you can use ' '`jax.make_array_from_callback(...) to create a `jax.Array` which ' - 'you can pass to pjit. ' - 'Please see the jax.Array migration guide for more information ' - 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' + 'you can pass to jit. ' f'Got arg shape: {arg.shape}, arg value: {arg}') if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete: # jax.jit does not allow resharding across different memory kinds even From dd93eeae2e603352f2a57afb5c8af432bacbbdcf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 09:21:12 -0700 Subject: [PATCH 0020/1769] [JAX] Move py_client_gpu into JAX. This callback functionality is only used by JAX and shipped as part of its CUDA and ROCM GPU plugins. Move it into JAX, as part of a wider move of xla/python pieces that belong to JAX into JAX. PiperOrigin-RevId: 738426489 --- jaxlib/cuda/BUILD | 44 ++++ jaxlib/cuda/cuda_plugin_extension.cc | 13 ++ jaxlib/gpu/BUILD | 3 +- jaxlib/gpu/gpu_plugin_extension.cc | 9 - jaxlib/gpu/py_client_gpu.cc | 295 +++++++++++++++++++++++++++ jaxlib/gpu/py_client_gpu.h | 37 ++++ jaxlib/gpu/vendor.h | 2 + jaxlib/rocm/BUILD | 44 ++++ jaxlib/rocm/rocm_plugin_extension.cc | 12 ++ 9 files changed, 449 insertions(+), 10 deletions(-) create mode 100644 jaxlib/gpu/py_client_gpu.cc create mode 100644 jaxlib/gpu/py_client_gpu.h diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index a9bd35b7768d..23ab64aa2d01 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -657,11 +657,55 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cuda_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:callback", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:platform_util", + 
"@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", + ], +) + nanobind_extension( name = "cuda_plugin_extension", srcs = ["cuda_plugin_extension.cc"], module_name = "cuda_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 8d8514bd2740..789227e273b6 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -21,12 +21,15 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" #include "xla/pjrt/status_casters.h" namespace nb = nanobind; namespace xla { namespace { + static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -38,10 +41,20 @@ static std::string ToString(CUresult result) { } return absl::StrCat(error_name, ": ", error_string); } + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(jax::cuda::XlaPythonGpuCallback); + return dict; +} + } // namespace NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("registrations", &Registrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index b5292746dd10..de55989bf73f 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -52,6 +52,8 @@ exports_files(srcs = [ "prng_kernels.cc", "prng_kernels.cu.cc", "prng_kernels.h", + "py_client_gpu.cc", + "py_client_gpu.h", "rnn.cc", "rnn_kernels.cc", "rnn_kernels.h", @@ -115,7 +117,6 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index b56cb8337f1b..5726e0929ee5 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -35,7 +35,6 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" @@ -202,13 +201,6 @@ absl::Status RegisterCustomTypeId(const PJRT_Api* c_api, return absl::OkStatus(); } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - } // namespace void BuildGpuPluginExtension(nanobind::module_& m) { @@ -264,7 +256,6 @@ void BuildGpuPluginExtension(nanobind::module_& m) { type_name_size, std::move(type_id))); }, nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id")); - m.def("registrations", &Registrations); } } // namespace xla diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc new file mode 100644 index 000000000000..d6faa1859eb8 --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.cc @@ -0,0 +1,295 @@ +/* Copyright 2022 The JAX Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/include/llvm/Support/Casting.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/callback.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/py_host_callback.h" +#include "xla/python/types.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + // Ignore `descriptor` arg to callback + buffers += 1; + uint64_t descriptor; + if (!absl::SimpleAtoi(opaque, &descriptor)) { + throw xla::XlaRuntimeError("Invalid callback descriptor"); + return; + } + xla::CpuCallback* callback = + absl::bit_cast(static_cast(descriptor)); + size_t arity = callback->num_args(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + const xla::CpuCallback::Arg& arg = callback->args()[i]; + if (arg.type == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + void* buf = new char[arg.size_in_bytes]; + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. 
+ auto gpu_res = gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes, + gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + xla::CpuCallback::Arg arg = callback->args()[i]; + if (arg.type == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + continue; + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto array = xla::nb_numpy_ndarray(arg.dtype, arg.dims, arg.strides, + const_cast(host_input_buffers[i]), + /*base=*/base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + xla::EnterHostCallback(); + absl::StatusOr maybe_result_tuple = + callback->Call(host_input_arrays); + xla::LeaveHostCallback(); + if (!maybe_result_tuple.ok()) { + absl::string_view msg = maybe_result_tuple.status().message(); + XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); + return; + } + nb::tuple result_tuple = maybe_result_tuple.value(); + std::vector temp_buffers; + for (size_t i = 0; i < callback->results().size(); ++i) { + xla::CpuCallback::Result result = callback->results()[i]; + if (result.type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + xla::nb_numpy_ndarray array = + xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == result.expected_strides) { + auto gpu_res = + gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes, + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } else { + void* temp = new char[result.size_in_bytes]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(result.type); + options.dims = dims; + options.permutation = result.reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + callback->transpose_cache().GetOrCreate(options); + if (!plan.ok()) { + throw xla::XlaRuntimeError(plan.status().ToString()); + } + plan.value()->Execute(array.data(), temp); + auto gpu_res = + gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes, + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } +} + +// TODO(danfm): When compiled as part of a jaxlib plugin, this will register +// the custom call target in the plugin's registry. This won't affect +// registration via the Python API, but we should remove this once we have +// fully migrated to the plugin interface. 
+XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( + "xla_python_gpu_callback", &XlaPythonGpuCallback, + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME)); + +absl::Status XlaFfiPythonGpuCallback( + gpuStream_t stream, + std::vector>* callbacks, + uint64_t index, xla::ffi::RemainingArgs args, + xla::ffi::RemainingRets rets) { + auto loaded_callback = llvm::dyn_cast_or_null( + callbacks->at(index).get()); + if (loaded_callback == nullptr) { + return absl::InternalError( + "Expected a PyCpuLoadedHostCallback, got something else."); + } + xla::CpuCallback* callback = loaded_callback->cpu_callback(); + size_t arity = args.size(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + if (arg->element_type() == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + void* buf = new char[arg->size_bytes()]; + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + auto gpu_res = + gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), + gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + xla::PrimitiveType ptype = arg->element_type(); + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + } else { + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + TF_ASSIGN_OR_RETURN(auto dtype, xla::PrimitiveTypeToNbDtype(ptype)); + auto array = xla::nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt, + host_input_buffers[i], base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + } + + xla::EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + absl::StatusOr maybe_result_tuple = + callback->FfiCall(host_input_arrays); + xla::LeaveHostCallback(); + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + std::vector temp_buffers; + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = ret->element_type(); + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + xla::nb_numpy_ndarray array = + xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. 
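+      // If the callback returned an array whose strides differ from that
+      // expected layout, the transpose plan built below rewrites it through a
+      // temporary buffer before the host-to-device copy.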
+ TF_ASSIGN_OR_RETURN(auto expected_shape, xla::ShapeUtil::MakeValidatedShape( + ptype, ret->dimensions())); + auto expected_strides = ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } else { + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + options.dims = dims; + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.rank()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + TF_ASSIGN_OR_RETURN(auto plan, + callback->transpose_cache().GetOrCreate(options)); + plan->Execute(array.data(), temp); + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, + xla::ffi::Ffi::Bind() + .Ctx>() + .Ctx>>>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaFfiPythonGpuCallback); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h new file mode 100644 index 000000000000..6be2d40823dc --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ +#define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ + +#include + +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/ffi.h" +#include "xla/service/custom_call_status.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 7334d4690b59..cadd5453107a 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -48,6 +48,7 @@ limitations under the License. 
#define JAX_GPU_NAMESPACE cuda #define JAX_GPU_PREFIX "cu" +#define JAX_GPU_PLUGIN_NAME "cuda" typedef cuComplex gpuComplex; typedef cuDoubleComplex gpuDoubleComplex; @@ -413,6 +414,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" +#define JAX_GPU_PLUGIN_NAME "rocm" #define JAX_GPU_HAVE_SPARSE 1 #define JAX_GPU_HAVE_64_BIT 0 diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9a25a795fd14..867048509afa 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -555,11 +555,55 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":hip_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:callback", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:platform_util", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", + ], +) + nanobind_extension( name = "rocm_plugin_extension", srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 1dd1f1943fc8..1e8013f2bc1b 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -21,11 +21,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace nb = nanobind; namespace xla { namespace { + std::string ToString(hipError_t result) { #define OSTREAM_ROCM_ERROR(__name) \ case hipError##__name: \ @@ -62,10 +65,19 @@ std::string ToString(hipError_t result) { return absl::StrCat("hipError_t(", static_cast(result), ")"); } } + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(jax::hip::XlaPythonGpuCallback); + return dict; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("registrations", &Registrations); m.def( "get_device_ordinal", [](std::intptr_t data_value) { From ee74c289ac02b0f0f07ce8819bef4b2f97207d4b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 10:03:57 -0700 Subject: [PATCH 0021/1769] Move //jaxlib:handle_pool to //jaxlib/gpu:handle_pool. This is a GPU-specific target. 
PiperOrigin-RevId: 738441625 --- jaxlib/BUILD | 15 --------------- jaxlib/cuda/BUILD | 8 ++++---- jaxlib/gpu/BUILD | 15 +++++++++++++++ jaxlib/gpu/blas_handle_pool.cc | 2 +- jaxlib/gpu/blas_handle_pool.h | 2 +- jaxlib/{ => gpu}/handle_pool.h | 6 +++--- jaxlib/gpu/rnn_kernels.cc | 2 +- jaxlib/gpu/solver_handle_pool.cc | 2 +- jaxlib/gpu/solver_handle_pool.h | 2 +- jaxlib/gpu/sparse_kernels.cc | 2 +- jaxlib/gpu/sparse_kernels.h | 2 +- jaxlib/rocm/BUILD | 8 ++++---- 12 files changed, 33 insertions(+), 33 deletions(-) rename jaxlib/{ => gpu}/handle_pool.h (96%) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a5e8cee08cdc..faf52a702386 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -171,21 +171,6 @@ cc_library( ], ) -cc_library( - name = "handle_pool", - hdrs = ["handle_pool.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - ], -) - # This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong # target architecture. nanobind_extension( diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 23ab64aa2d01..4e74cc2dcf5b 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -89,7 +89,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -155,8 +155,8 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -195,7 +195,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -308,8 +308,8 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index de55989bf73f..3613be567533 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -93,6 +93,21 @@ xla_py_proto_library( deps = [":triton_proto"], ) +cc_library( + name = "handle_pool", + hdrs = ["handle_pool.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_plugin_extension", srcs = ["gpu_plugin_extension.cc"], diff --git a/jaxlib/gpu/blas_handle_pool.cc b/jaxlib/gpu/blas_handle_pool.cc index 2ce204453039..ff381b802ab2 100644 --- a/jaxlib/gpu/blas_handle_pool.cc +++ b/jaxlib/gpu/blas_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/gpu/blas_handle_pool.h b/jaxlib/gpu/blas_handle_pool.h index b3cdbaa88867..43724baab45e 100644 --- a/jaxlib/gpu/blas_handle_pool.h +++ b/jaxlib/gpu/blas_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/handle_pool.h b/jaxlib/gpu/handle_pool.h similarity index 96% rename from jaxlib/handle_pool.h rename to jaxlib/gpu/handle_pool.h index 9201d8d579c5..9189bb174b06 100644 --- a/jaxlib/handle_pool.h +++ b/jaxlib/gpu/handle_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_HANDLE_POOL_H_ -#define JAXLIB_HANDLE_POOL_H_ +#ifndef JAXLIB_GPU_HANDLE_POOL_H_ +#define JAXLIB_GPU_HANDLE_POOL_H_ #include #include @@ -107,4 +107,4 @@ void HandlePool::Return(HandleType handle, } // namespace jax -#endif // JAXLIB_HANDLE_POOL_H_ +#endif // JAXLIB_GPU_HANDLE_POOL_H_ diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index e9820bc31f1e..45f8ba8187ba 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/solver_handle_pool.cc b/jaxlib/gpu/solver_handle_pool.cc index c55ea923b21b..416ccf9d1bbc 100644 --- a/jaxlib/gpu/solver_handle_pool.cc +++ b/jaxlib/gpu/solver_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_handle_pool.h b/jaxlib/gpu/solver_handle_pool.h index c46c062b3054..4e369ea85520 100644 --- a/jaxlib/gpu/solver_handle_pool.h +++ b/jaxlib/gpu/solver_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 5b620a05236d..c66e96b6b89e 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -28,7 +28,7 @@ limitations under the License. #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 323431812758..0d74ebc7d8e4 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -24,7 +24,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 867048509afa..1e54d82c4f71 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -79,7 +79,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipblas", @@ -143,8 +143,8 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -182,7 +182,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsolver", @@ -291,8 +291,8 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", From 85e78840103a1ded7fa7feb1243ca3f0a9c7c63b Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 19 Mar 2025 10:07:08 -0700 Subject: [PATCH 0022/1769] Support error checking in auto mode PiperOrigin-RevId: 738443014 --- jax/_src/error_check.py | 108 +++++++++++++++++++++++++------------- tests/error_check_test.py | 47 ++++++++++++++++- 2 files changed, 117 insertions(+), 38 deletions(-) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 60dc2f76a5b2..88dcec7063d9 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -16,6 +16,7 @@ from functools import partial import threading +import warnings import jax from jax._src import core @@ -58,11 +59,11 @@ def __init__(self): def _initialize_error_code_ref() -> None: - """Initialize error_code_ref in the current thread. + """Initialize the error code ref in the current thread. - The size of the error code array is determined by the mesh in the context. In - single-device environment, the array is a scalar. In multi-device - environment, the array has the same shape as the mesh. + The shape and size of the error code array depend on the mesh in the context. + In single-device environments, the array is a scalar. In multi-device + environments, its shape and size match those of the mesh. """ with core.eval_context(): # Get mesh from the context. @@ -83,13 +84,18 @@ def _initialize_error_code_ref() -> None: class error_checking_context: - """Redefine the error checking state based on the mesh in the context. + """Redefine the internal error state based on the mesh in the context. - This context manager should be used when starting a multi-device - computation, and whenever the mesh is changed. + When using JAX in multi-device environments in explicit mode, error tracking + needs to be properly aligned with the device mesh. This context manager + ensures that the internal error state is correctly initialized based on the + current mesh configuration. - When exiting the context, the error checking state will be reset to the - original state. 
+ This context manager should be used when starting a multi-device computation, + or when switching between different device meshes. + + On entering the context, it initializes a new error state based on the mesh in + the context. On exiting the context, it restores the previous error state. """ __slots__ = ("old_ref",) @@ -107,12 +113,28 @@ def __exit__(self, exc_type, exc_value, traceback): def set_error_if(pred: jax.Array, /, msg: str) -> None: - """Set error if any element of pred is true. - - If the error is already set, the new error will be ignored. It will not - override the existing error. - - In auto mode, this function does not work under jit. + """Set the internal error state if any element of `pred` is `True`. + + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. + + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context()` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. """ if _error_storage.ref is None: _initialize_error_code_ref() @@ -127,28 +149,34 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: out_sharding = core.typeof(_error_storage.ref).sharding in_sharding: NamedSharding = core.typeof(pred).sharding - if out_sharding.mesh.shape_tuple == (): # single-device case. + # Reduce `pred`. + if all(dim is None for dim in out_sharding.spec): # single-device case. pred = pred.any() else: # multi-device case. has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types - if has_auto_axes: - raise NotImplementedError( - "Error checking in auto mode is not supported yet. Please use" - " explicit mode." + if has_auto_axes: # auto mode. + warnings.warn( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode.", + RuntimeWarning, ) - if out_sharding.mesh != in_sharding.mesh: - raise ValueError( - "The error code state and the predicate must be on the same mesh, " - f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " - "Please use `with error_checking_context()` to redefine the error " - "code state based on the mesh." - ) - pred = shard_map( - partial(jnp.any, keepdims=True), - mesh=out_sharding.mesh, - in_specs=in_sharding.spec, - out_specs=out_sharding.spec, - )(pred) # perform per-device reduction + pred = pred.any() # reduce to a single scalar + else: # explicit mode. + if out_sharding.mesh != in_sharding.mesh: + raise ValueError( + "The error code state and the predicate must be on the same mesh, " + f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. 
" + "Please use `with error_checking_context()` to redefine the error " + "code state based on the mesh." + ) + pred = shard_map( + partial(jnp.any, keepdims=True), + mesh=out_sharding.mesh, + in_specs=in_sharding.spec, + out_specs=out_sharding.spec, + )(pred) # perform per-device reduction error_code = _error_storage.ref[...] should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) @@ -158,10 +186,18 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: def raise_if_error() -> None: - """Raise error if an error is set. + """Raise an exception if the internal error state is set. + + This function should be called after a computation completes to check for any + errors that were marked during execution via `set_error_if()`. If an error + exists, it raises a `JaxValueError` with the corresponding error message. + + This function should not be called inside a traced function (e.g., inside + :func:`jax.jit`). Doing so will raise a `ValueError`. - This function should be called after the computation is finished. It should - not be called within a traced context, such as within a jitted function." + Raises: + JaxValueError: If the internal error state is set. + ValueError: If called within a traced JAX function. """ if _error_storage.ref is None: # if not initialized, do nothing return diff --git a/tests/error_check_test.py b/tests/error_check_test.py index b96c6281411f..ad67cadfb074 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -18,6 +18,7 @@ import jax from jax._src import config from jax._src import error_check +from jax._src import mesh as mesh_lib from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P @@ -202,9 +203,51 @@ def f(x): if jit: f = jax.jit(f) - sharding = NamedSharding(mesh, P("x", "y")) - x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) with error_check.error_checking_context(): + x = jnp.full((4, 4), -1, dtype=jnp.int32) + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + sharding = NamedSharding(mesh, P("x", "y")) + with error_check.error_checking_context(): + y = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) + f(y) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + # The unsharded version of `f` should still be able to check errors after + # exiting the error checking context. + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + @jtu.with_user_mesh( + (2, 2), + ("x", "y"), + axis_types=(mesh_lib.AxisType.Auto, mesh_lib.AxisType.Auto), + ) + @jtu.ignore_warning( + message=( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode." 
+ ), + category=RuntimeWarning, + ) + def test_error_check_auto_mode(self, jit, mesh): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + with error_check.error_checking_context(): + sharding = NamedSharding(mesh, P("x", "y")) + x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) f(x) with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): error_check.raise_if_error() From b456855c40da6c7904638013d58488c3ff8304a8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 19 Mar 2025 10:19:51 -0700 Subject: [PATCH 0023/1769] [pallas:mosaic_gpu] Added support for accessing cluster ID via `lax.axis_index` PiperOrigin-RevId: 738448436 --- jax/_src/pallas/mosaic_gpu/core.py | 18 ++-- jax/_src/pallas/mosaic_gpu/lowering.py | 138 +++++++++++++++++-------- tests/pallas/mosaic_gpu_test.py | 29 +++++- 3 files changed, 134 insertions(+), 51 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 630c1b8f4bed..5e4566ddfc9c 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -506,17 +506,21 @@ class GPUMesh: axis_names: tuple[str, ...] = () def __post_init__(self): - if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): - raise ValueError("Need as many axis names as grid dimensions + warp groups") + if len(self.cluster) > 3: + raise ValueError(f"cluster= must be at most 3D, got {self}.") + num_axis_names = ( + len(self.grid) + len(self.cluster) + (self.num_threads is not None) + ) + if len(self.axis_names) != num_axis_names: + raise ValueError( + "Need an axis name for each grid and cluster dimension plus " + f" an additional axis name when num_threads= is given, got {self}." + ) if self.num_threads is not None and self.num_threads > 2048 // 128: raise ValueError( "Requested too many CUDA threads per block. Each Mosaic thread" " corresponds to 128 CUDA threads." ) - if self.cluster: - raise NotImplementedError( - "Pallas/MosaicGPU does not support clusters yet." 
- ) @property def backend(self) -> str: @@ -556,8 +560,6 @@ def _gpu_mesh_discharge_rule( ): if not isinstance(mesh, GPUMesh): raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}") - if mesh.cluster: - raise NotImplementedError if compiler_params and not isinstance(compiler_params, GPUCompilerParams): raise TypeError( "Compiler params must be a GPUCompilerParams, got" diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 2ae51a8b22e8..1fae91773178 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,11 +17,13 @@ from __future__ import annotations import collections -from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence +from collections.abc import Callable, Hashable, Iterable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools +import itertools import math +import operator from typing import Any, Protocol, cast import jax @@ -233,10 +235,33 @@ def _reduce_sum_resource_estimator( return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize) +@dataclasses.dataclass(frozen=True) +class _AxisNames: + grid: Sequence[Hashable] + cluster: Sequence[Hashable] = () + wg: Hashable | None = None + + def __iter__(self) -> Iterable[Hashable]: + return itertools.chain( + self.grid, self.cluster, [self.wg] if self.wg is not None else [] + ) + + @classmethod + def from_mesh( + cls, mesh: gpu_core.GPUMesh, axis_names: Sequence[str] + ) -> "_AxisNames": + wg_name = None + if mesh.num_threads is not None: + wg_name = axis_names[-1] + axis_names = axis_names[:-1] + grid_names, cluster_names = util.split_list(axis_names, [len(mesh.grid)]) + return cls(grid_names, cluster_names, wg_name) + + @dataclasses.dataclass class ModuleContext: name: str - grid_names: Sequence[Hashable] | None + axis_names: _AxisNames | None program_ids: Sequence[ir.Value] | None approx_math: bool single_wg_lane_predicate: ir.Value @@ -565,10 +590,15 @@ def body_fn(*refs): ) assert not new_consts + axis_names = ( + _AxisNames.from_mesh(mesh, grid_mapping.grid_names) + if mesh is not None + else _AxisNames(grid_mapping.grid_names) + ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( parallel_grid, - grid_mapping.grid_names, + axis_names, block, mesh.cluster if mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], @@ -581,7 +611,7 @@ def body_fn(*refs): def lower_jaxpr_to_module( grid: Sequence[int], - grid_names: Sequence[str], + axis_names: _AxisNames, block: Sequence[int], cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], @@ -597,6 +627,11 @@ def lower_jaxpr_to_module( "thread_semantics", mgpu_core.ThreadSemantics.Lane ) + if len(cluster) < 3: + cluster = cluster + (1,) * (3 - len(cluster)) + else: + assert len(cluster) == 3 + if len(grid) <= 3: squashed_dims = () parallel_grid = grid + (1,) * (3 - len(grid)) @@ -614,7 +649,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): grouped_barriers[barrier].append(barrier_ref) module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), - grid_names, + axis_names, [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, mgpu.single_thread_predicate(per_block=False), @@ -645,7 +680,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): module, out_structs_gmem, _, launch_ctx, scratch_arr = ( mgpu_core._lower_as_gpu_kernel( body, - grid=parallel_grid, + grid=tuple(map(operator.mul, parallel_grid, cluster)), 
cluster=cluster, block=block, in_shapes=in_shapes, @@ -1605,49 +1640,68 @@ def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result +def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: + result = gpu_dialect.block_id(dim) + cluster_size = ctx.launch_ctx.cluster_size + if math.prod(cluster_size) == 1 or cluster_size[dim.value] == 1: + return result + # We scale the grid in the presence of clusters, so we need to scale the + # block ID back here. + return arith_dialect.divui(result, _as_index(cluster_size[dim.value])) + + @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): - i32 = ir.IntegerType.get_signless(32) - grid_names = ctx.module_ctx.grid_names + axis_names = ctx.module_ctx.axis_names + if not axis_names or axis_name not in axis_names: + raise ValueError( + "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" + ) + + if axis_names.wg is not None and axis_name == axis_names.wg: + return mgpu.warpgroup_idx(sync=True) + + if axis_name in axis_names.cluster: + idx = axis_names.cluster.index(axis_name) + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.cluster_block_id(gpu_dialect.Dimension(idx)), + ) + squashed_dims = ctx.module_ctx.squashed_dims if squashed_dims: - unsquashed_names = grid_names[-3:] - squashed_names = grid_names[:-3] + unsquashed_names = axis_names.grid[-2:] + squashed_names = axis_names.grid[:-2] else: # These are unused but initialized for type checkers. - unsquashed_names = () - squashed_names = () - if grid_names and axis_name in grid_names: - if axis_name == grid_names[-1]: - return mgpu.warpgroup_idx(sync=True) + unsquashed_names = squashed_names = () + + if squashed_dims: + if axis_name in unsquashed_names: + # We add 1 to the index because the first dimension is the + # squashed dimension. + # e.g. for the grid (a, b, c, d, wg) + # squashed = (a, b) Mapped to Dimension.x (0) + # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) + idx = unsquashed_names.index(axis_name) + 1 + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) else: - if squashed_dims: - if axis_name in unsquashed_names: - # We add 1 to the index because the first dimension is the - # squashed dimension. - # e.g. for the grid (a, b, c, d, wg) - # squashed = (a, b) Mapped to Dimension.x (0) - # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) - idx = unsquashed_names.index(axis_name) + 1 - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - elif axis_name in squashed_names: - # All squashed dimensions are mapped to Dimension.x. - block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) - axis = squashed_names.index(axis_name) - return _unravel_program_id(block_id, axis, squashed_dims) - else: - if axis_name in grid_names: - idx = grid_names.index(axis_name) - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - raise ValueError( - "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" - ) + assert axis_name in squashed_names + # All squashed dimensions are mapped to Dimension.x. 
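+      # Example (grid sizes hypothetical): with squashed_dims == (4, 5), a
+      # single x block ID in [0, 4 * 5) encodes both squashed axes, and
+      # _unravel_program_id recovers the index along `axis` from that flat
+      # ID with integer div/mod arithmetic.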
+ axis = squashed_names.index(axis_name) + return _unravel_program_id( + _block_id(ctx, gpu_dialect.Dimension.x), axis, squashed_dims + ) + else: + assert axis_name in axis_names.grid + idx = axis_names.grid.index(axis_name) + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) @register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6792ddfaa9a8..38335925b44d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2079,7 +2079,6 @@ def _(): result.shape) np.testing.assert_array_equal(result, ref) - def test_cross_wg_barrier(self): mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) @@ -2100,6 +2099,34 @@ def scoped(barrier): return inner(y_init) np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + def test_cluster(self): + mesh = plgpu.GPUMesh(grid=(2,), cluster=(2,), axis_names=("x", "cluster")) + + @jax.jit + def f(): + @pl.run_state + def inner(ref): + @pl.core_map(mesh) + def kernel(): + block_idx = jax.lax.axis_index("x") + cluster_idx = jax.lax.axis_index("cluster") + pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) + + ref[...] = ref[...] + return inner(jnp.zeros(128, np.int32)) + + with self.capture_stdout() as output: + jax.block_until_ready(f()) + self.assertEqual( + set(output().splitlines()), + { + "block: 0 cluster: 0", + "block: 1 cluster: 0", + "block: 0 cluster: 1", + "block: 1 cluster: 1", + }, + ) + class ExamplesTest(PallasTest): From 918192fd45e74ba1793a6507be6587cf01b814e8 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 19 Mar 2025 10:35:48 -0700 Subject: [PATCH 0024/1769] Move sparse op GPU lowerings from jaxlib into JAX. 
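
Instead of emitting custom_call ops by hand in jaxlib's gpu_sparse.py, each
sparse lowering now builds its cusparse/hipsparse descriptor inside JAX and
dispatches through ffi.ffi_lowering, selecting the CUDA or ROCm module via a
'cu'/'hip' target-name prefix. The shared pattern, as a minimal illustrative
sketch (not part of the diff; `build_descriptor` is a stand-in for the per-op
builders such as build_csr_matvec_descriptor, and `_example_gpu_lowering` is a
hypothetical name):

    import numpy as np
    from jax._src import core, ffi

    def _example_gpu_lowering(ctx, *operands, target_name_prefix, build_descriptor):
        # The descriptor builder returns the workspace size in bytes plus an
        # opaque blob that the C++ FFI kernel decodes on the other side.
        buffer_size, opaque = build_descriptor()
        # Append an int8 workspace buffer to the expected outputs so the
        # custom call has scratch space to write into.
        buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
        sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval])
        # Lower to the FFI target, then drop the workspace from the results.
        rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matvec_ffi")
        return rule(sub_ctx, *operands, opaque=opaque)[:-1]
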
PiperOrigin-RevId: 738454875 --- jax/_src/lax/linalg.py | 32 ++- jax/experimental/sparse/_base.py | 10 - jax/experimental/sparse/_lowerings.py | 174 ++++++++++--- jax/experimental/sparse/bcsr.py | 30 +-- jax/experimental/sparse/coo.py | 62 ++--- jax/experimental/sparse/csr.py | 59 ++--- jaxlib/gpu_sparse.py | 357 +------------------------- 7 files changed, 233 insertions(+), 491 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index c674401fb80d..3e9077d0a51c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2511,16 +2511,30 @@ def _tridiagonal_solve_shape_rule(dl_shape, d_shape, du_shape, b_shape, **_): "equal the dimensions of the diagonal arguments.") return b_shape -def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b): +def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, target_name_prefix): _, _, _, b_aval = ctx.avals_in - if b_aval.dtype != np.float32 and b_aval.dtype != np.float64: + *batch_dims, m, n = b_aval.shape + batch_size = math.prod(batch_dims) + + mod = gpu_sparse._cusparse if target_name_prefix == "cu" else gpu_sparse._hipsparse + assert mod is not None + opaque = mod.build_gtsv2_descriptor(batch_size, m, n, m) + if b_aval.dtype == np.float32: + buffer_size = mod.gtsv2_f32_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f32_ffi" + elif b_aval.dtype == np.float64: + buffer_size = mod.gtsv2_f64_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f64_ffi" + else: raise NotImplementedError( "tridiagonal_solve is only implemented for float32 and float64 on GPU.") - m, n = b_aval.shape[-2:] - b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape) - return [lowering( - dl, d, du, b, m=m, n=n, ldb=m, t=b_aval.dtype, - b_shape_vals=b_shape_vals)] + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = _linalg_ffi_lowering( + f"{target_name_prefix}{target_name}", operand_output_aliases={3: 0}, + batch_partitionable=False) + return rule(sub_ctx, dl, d, du, b, opaque=opaque)[:1] def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs): del kwargs # unused @@ -2628,11 +2642,11 @@ def _tridiagonal_solve_jax(dl, d, du, b, **_): platform='cpu') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='hip'), platform='rocm') mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun( _tridiagonal_solve_jax, multiple_results=False)) diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 7739af0291f1..36d84cb0db62 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,18 +19,8 @@ import jax from jax._src import core -from jax._src import ffi from jax._src import util from jax._src.typing import Array -from jax._src.lib import gpu_sparse - - -if hasattr(gpu_sparse, "registrations"): - for platform, targets in gpu_sparse.registrations().items(): - for name, value, api_version in targets: - ffi.register_ffi_target( - name, value, platform=platform, api_version=api_version - ) class JAXSparse(util.StrictABC): diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index 6962ef78bcff..76e74d13ed69 100644 --- 
a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -18,13 +18,29 @@ """ from functools import partial +from typing import Any from jax._src import core from jax._src import dispatch +from jax._src import ffi from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse import numpy as np +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + +def _get_module(target_name_prefix: str) -> Any: + if target_name_prefix == "cu": + return gpu_sparse._cusparse + elif target_name_prefix == "hip": + return gpu_sparse._hipsparse + else: + raise ValueError(f"Unsupported target_name_prefix: {target_name_prefix}") SUPPORTED_DATA_DTYPES = [np.float32, np.float64, np.complex64, np.complex128] SUPPORTED_INDEX_DTYPES = [np.int32] @@ -54,27 +70,30 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmv_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matvec_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval) dispatch.simple_impl(coo_spmv_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_spmv_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec), + partial(_coo_spmv_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -103,27 +122,51 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmm_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + + batch_count = 1 + if len(shape) == 2: + rows, cols = shape + elif len(shape) == 3: + batch_count, rows, cols = shape + nnz = nnz // batch_count + else: + raise NotImplementedError(f"Unsupported shape: {shape}") + + # TODO(tianjianlu): use batch stride to trigger different mode of batch + # computation. 
Currently batch_stride = 0 is not allowed because of the issue + # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 + # Set batch stride to be the matrix size for now. + lhs_batch_stride = nnz + B_rows = rows if transpose else cols + rhs_batch_stride = B_rows * Ccols + + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, + rhs_batch_stride) + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matmat_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] + coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval) dispatch.simple_impl(coo_spmm_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat), + partial(_coo_spmm_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat), + partial(_coo_spmm_gpu_lowering, target_name_prefix='hip'), platform='rocm') # csr_spmv_p @@ -151,30 +194,33 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmv_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matvec_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval) dispatch.simple_impl(csr_spmv_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec), + partial(_csr_spmv_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec), + partial(_csr_spmv_gpu_lowering, target_name_prefix='hip'), platform='rocm') - # csr_spmm_p +# csr_spmm_p # This is an internal-only primitive that calls into cusparse CSR SpMM. # This is a raw lowering that does no validation of inputs; the indices are # assumed to be lexicographically sorted, deduplicated, and in-bounds. 
@@ -199,25 +245,71 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmm_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - B_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, Ccols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matmat_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval) dispatch.simple_impl(csr_spmm_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat), + partial(_csr_spmm_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat), + partial(_csr_spmm_gpu_lowering, target_name_prefix='hip'), platform='rocm') + +def coo_todense_gpu_lowering(ctx, data, row, col, *, shape, target_name_prefix): + data_aval, row_aval, _ = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_todense_descriptor( + data_aval.dtype, row_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_todense_ffi") + return rule(sub_ctx, data, row, col, opaque=opaque)[0] + +def coo_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] + +def csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): + data_aval, indices_aval, _, = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_todense_descriptor( + data_aval.dtype, indices_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_todense_ffi") + return rule(sub_ctx, data, indices, indptr, opaque=opaque)[0] + +def csr_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = 
_get_module(target_name_prefix).build_csr_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7fefd1572f45..dc8be2237544 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -27,6 +27,7 @@ import jax.numpy as jnp from jax import lax from jax import tree_util +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo from jax.experimental.sparse.util import ( @@ -620,9 +621,9 @@ def _bcsr_correct_out_of_bound_indices(data, indices, indptr, rhs, *, shape): _bcsr_correct_out_of_bound_indices, multiple_results=True) def _bcsr_dot_general_gpu_lowering( - csr_matvec_lowering, csr_matmat_lowering, + # csr_matvec_lowering, csr_matmat_lowering, ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers, - preferred_element_type, lhs_spinfo: SparseInfo): + preferred_element_type, lhs_spinfo: SparseInfo, target_name_prefix): if not config.bcoo_cusparse_lowering.value: return _bcsr_dot_general_default_lowering( @@ -674,22 +675,23 @@ def _bcsr_dot_general_gpu_lowering( lhs_data, lhs_indices = _bcsr_correct_out_of_bound_indices_lowered( ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape) + sub_ctx = ctx if rhs_aval.ndim == 1: - dot_general_fn = csr_matvec_lowering - x_dtype = 'x_dtype' + dot_general_fn = _lowerings._csr_spmv_gpu_lowering elif rhs_aval.ndim == 2: - dot_general_fn = csr_matmat_lowering - x_dtype = 'B_dtype' + dot_general_fn = _lowerings._csr_spmm_gpu_lowering if rhs_contract[0] == 1: rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0])) + *avals_in, rhs_aval = sub_ctx.avals_in + rhs_aval = core.ShapedArray( + shape=(rhs_aval.shape[1], rhs_aval.shape[0]), dtype=rhs_aval.dtype) + sub_ctx = sub_ctx.replace(avals_in=[*avals_in, rhs_aval]) else: raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.") - return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs, - shape=lhs_spinfo.shape, transpose=False, - data_dtype=lhs_data_aval.dtype, - index_dtype=lhs_indices_aval.dtype, - **{x_dtype: rhs_aval.dtype})] + return dot_general_fn(sub_ctx, lhs_data, lhs_indices, lhs_indptr, rhs, + shape=lhs_spinfo.shape, transpose=False, + target_name_prefix=target_name_prefix) _bcsr_dot_general_default_lowering = mlir.lower_fun( _bcsr_dot_general_impl, multiple_results=False) @@ -700,14 +702,12 @@ def _bcsr_dot_general_gpu_lowering( if gpu_sparse.cuda_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.cuda_csr_matvec, - gpu_sparse.cuda_csr_matmat), + target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.rocm_csr_matvec, - gpu_sparse.rocm_csr_matmat), + target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index c65bc87235d6..014fe9128c1b 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -26,6 +26,7 @@ import jax from jax import lax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings 
from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util @@ -205,7 +206,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo): _coo_todense_lowering = mlir.lower_fun( _coo_todense_impl, multiple_results=False) -def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): +def _coo_todense_gpu_lowering(ctx, data, row, col, *, spinfo, target_name_prefix): data_aval, row_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): @@ -226,8 +227,13 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) - result = coo_todense_hlo( - data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype) + sub_ctx = ctx + if transpose: + out_aval, = ctx.avals_out + out_aval = core.ShapedArray(shape=out_aval.shape[::-1], dtype=out_aval.dtype) + sub_ctx = sub_ctx.replace(avals_out=[out_aval]) + result = _lowerings.coo_todense_gpu_lowering( + sub_ctx, data, row, col, shape=shape, target_name_prefix=target_name_prefix) return ( [hlo.transpose(result, mlir.dense_int_array([1, 0]))] if transpose else [result]) @@ -255,12 +261,12 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense), + partial(_coo_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense), + partial(_coo_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -325,20 +331,15 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype): _coo_fromdense_lowering = mlir.lower_fun( _coo_fromdense_impl, multiple_results=True) -def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse, - index_dtype): +def _coo_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. 
" "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, row, col = coo_fromdense_hlo( - mat, nnz=nse, - data_dtype=dtype, - index_dtype=np.dtype(index_dtype), - index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, row, col] - + return _lowerings.coo_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): M, = primals @@ -373,12 +374,12 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -444,8 +445,8 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose): _coo_matvec_lowering = mlir.lower_fun( _coo_matvec_impl, multiple_results=False) -def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, - transpose): +def _coo_matvec_gpu_lowering(ctx, data, row, col, v, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -466,9 +467,9 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) - return [coo_matvec_hlo( - data, row, col, v, shape=shape, transpose=transpose, - index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)] + return _lowerings._coo_spmv_gpu_lowering( + ctx, data, row, col, v, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose): @@ -497,12 +498,12 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -567,8 +568,8 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose): _coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False) -def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, - transpose): +def _coo_matmat_gpu_lowering(ctx, data, row, col, B, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,10 +590,9 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) - return [coo_matmat_hlo(data, row, col, B, shape=shape, - transpose=transpose, x_dtype=B_aval.dtype, 
- data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype)] + return _lowerings._coo_spmm_gpu_lowering( + ctx, data, row, col, B, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose): @@ -618,10 +618,10 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 84171855b85e..cbc5bad1100b 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -23,6 +23,7 @@ import jax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning @@ -249,17 +250,16 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape): _csr_todense_lowering = mlir.lower_fun( _csr_todense_impl, multiple_results=False) -def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *, - shape): +def _csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): data_aval, indices_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape) - return [csr_todense_hlo( - data, indices, indptr, shape=shape, data_dtype=dtype, - index_dtype=indices_aval.dtype)] + return [_lowerings.csr_todense_gpu_lowering( + ctx, data, indices, indptr, shape=shape, + target_name_prefix=target_name_prefix)] def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape): @@ -284,12 +284,12 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -359,16 +359,16 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype): _csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl, multiple_results=True) -def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype): +def _csr_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, + target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. 
" "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, indices, indptr = csr_fromdense_hlo( - mat, nnz=nse, index_dtype=np.dtype(index_dtype), - data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, indices, indptr] + return _lowerings.csr_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): @@ -404,12 +404,12 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -470,8 +470,8 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose): _csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False) -def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, - shape, transpose): +def _csr_matvec_gpu_lowering(ctx, data, indices, indptr, v, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, v_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -479,10 +479,9 @@ def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape, transpose=transpose) - return [csr_matvec_hlo( - data, indices, indptr, v, shape=shape, transpose=transpose, - data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)] - + return _lowerings._csr_spmv_gpu_lowering( + ctx, data, indices, indptr, v, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose): return _csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose) @@ -511,12 +510,12 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -580,8 +579,8 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose): _csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False) -def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, - shape, transpose): +def _csr_matmat_gpu_lowering(ctx, data, indices, indptr, B, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,11 +588,9 @@ def 
_csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape, transpose=transpose) - return [csr_matmat_hlo( - data, indices, indptr, B, shape=shape, transpose=transpose, - index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype, - B_dtype=B_aval.dtype)] - + return _lowerings._csr_spmm_gpu_lowering( + ctx, data, indices, indptr, B, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose): return _csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose) @@ -621,10 +618,10 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.cuda_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d8645041c946..cc2b2ad08e55 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -11,25 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -cusparse wrappers for performing sparse matrix computations in JAX -""" -import math -from functools import partial from typing import Any -import jaxlib.mlir.ir as ir - -import numpy as np - -from .hlo_helpers import custom_call, mk_result_types_and_shapes - from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") +cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) +rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) + def registrations() -> dict[str, list[tuple[str, Any, int]]]: registrations = {"CUDA": [], "ROCM": []} for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: @@ -38,346 +30,3 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: (name, value, int(name.endswith("_ffi"))) for name, value in module.registrations().items()) return registrations # pytype: disable=bad-return-type - - -cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) -rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) - - -def _validate_csr_hlo(data, indices, indptr, shape): - data_type = ir.RankedTensorType(data.type) - indices_type = ir.RankedTensorType(indices.type) - indptr_type = ir.RankedTensorType(indptr.type) - - nnz, = data_type.shape - assert indices_type.shape == [nnz] - assert indptr_type.element_type == indices_type.element_type - assert indptr_type.shape == [shape[0] + 1] - return data_type.element_type, indices_type.element_type, nnz - -def _validate_coo_hlo(data, row, col): - data_type = ir.RankedTensorType(data.type) - row_type = ir.RankedTensorType(row.type) - col_type = ir.RankedTensorType(col.type) - - nnz, = data_type.shape - assert row_type.shape == [nnz] - assert col_type.element_type == row_type.element_type - assert col_type.shape == [nnz] - return data_type.element_type, row_type.element_type, nnz - - -def 
_csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape, - data_dtype, index_dtype): - """CSR to dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_csr_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse) -rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse) - - -def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype, - data_dtype, index_type): - """CSR from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_csr_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([rows + 1], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse) -rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse) - - -def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - data_dtype, index_dtype, x_dtype): - """CSR matrix/vector multiply.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse) -rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse) - - -def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, B_dtype): - """CSR from dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor( - data_dtype, B_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose) - 
out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matmat_ffi", - result_types=[ - ir.RankedTensorType.get([out_size, Ccols], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse) -rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse) - - -def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape, - data_dtype, index_dtype): - """COO to dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_coo_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse) -rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse) - - -def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype, - index_dtype, index_type): - """COO from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_coo_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse) -rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse) - - -def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, x_dtype): - """COO matrix/vector multiply.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_coo_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_coo_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse) -rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse) - - -def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape, - transpose=False, 
compute_dtype=None, compute_type=None, - x_dtype, data_dtype, index_dtype): - """COO from dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - is_batched_matmat = False - batch_count = 1 - if len(shape) == 2: - rows, cols = shape - elif len(shape) == 3: - is_batched_matmat = True - batch_count, rows, cols = shape - # Redefine nnz as nnz per batch. - nnz = nnz // batch_count - - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - # TODO(tianjianlu): use batch stride to trigger different mode of batch - # computation. Currently batch_stride = 0 is not allowed because of the issue - # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 - # Set batch stride to be the matrix size for now. - lhs_batch_stride = nnz - B_rows = rows if transpose else cols - rhs_batch_stride = B_rows * Ccols - - buffer_size, opaque = gpu_sparse.build_coo_matmat_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, - rhs_batch_stride) - out_size = cols if transpose else rows - - if is_batched_matmat: - out_shape = [batch_count, out_size, Ccols] - out_layout = [2, 1, 0] - else: - out_shape = [out_size, Ccols] - out_layout = [1, 0] - - out = custom_call( - f"{platform}sparse_coo_matmat_ffi", - result_types=[ - ir.RankedTensorType.get(out_shape, compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[out_layout, [0]]).results - return out[0] - -cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse) -rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse) - - -def _gtsv2_hlo( - platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None): - """Calls `cusparsegtsv2(dl, d, du, B, m, n, ldb)`.""" - assert len(b_shape_vals) >= 2 - batch_dim_vals = b_shape_vals[:-2] - batch_size = math.prod(batch_dim_vals) - num_bd = len(b_shape_vals) - 2 - f32 = (t == np.float32) - if f32: - buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb) - else: - buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb) - - b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1)) - b_type = ir.RankedTensorType(B.type) - - shape_type_pairs = [ - (batch_dim_vals + (ldb, n), b_type.element_type), - ((buffer_size,), ir.IntegerType.get_signless(8)) - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - opaque = gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb) - out = custom_call( - f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64") + "_ffi", - result_types=result_types, - operands=[dl, d, du, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[d_layout] * 3 + [b_layout], - result_layouts=[b_layout, [0]], - operand_output_aliases={3: 0}, - result_shapes=result_shapes).results - return out[0] - -cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse) -rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse) From 84ec21e03e88014a2cdaebee23075924f2054181 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 10:42:17 -0700 Subject: [PATCH 0025/1769] Add sliding window support to the ragged paged attention. 
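
With a window of size W, query position q may only attend to key positions kv
satisfying q - W < kv <= q; everything outside that band has mask_value added
to its logits. A self-contained sketch of the mask computation (illustrative
only; it mirrors the logic added to both the reference implementation and the
Pallas kernel below):

    import jax.numpy as jnp

    def sliding_window_mask(q_len, kv_len, sliding_window=None):
        # True marks entries that get mask_value added to the attention logits.
        q_span = (kv_len - q_len) + jnp.arange(q_len)[:, None]  # align queries to the last kv tokens
        kv_span = jnp.arange(kv_len)[None, :]
        mask = q_span < kv_span  # causal: no attending to future positions
        if sliding_window is not None:
            # additionally drop keys more than sliding_window positions back
            mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
        return mask

For example, sliding_window=1 restricts each token to attending only to
itself, while sliding_window=None recovers plain causal masking; non-positive
window sizes are rejected at validation time.
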
PiperOrigin-RevId: 738457532 --- .../pallas/ops/tpu/ragged_paged_attention.py | 23 +++++++++-- .../pallas/tpu_ragged_paged_attention_test.py | 40 ++++++++++++++++++- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 6600d765024c..a9b61da290f7 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -19,7 +19,6 @@ specifications. It supports mixed prefill and decoding, enhancing throughput during inference. """ - import functools import jax from jax import lax @@ -81,6 +80,7 @@ def ref_ragged_paged_attention( num_seqs: jax.Array, # i32[1], *, sm_scale: float = 1.0, + sliding_window: int | None = None, mask_value: float = DEFAULT_MASK_VALUE, ): _, _, num_kv_heads, head_dim = k_pages.shape @@ -105,7 +105,10 @@ def ref_ragged_paged_attention( jnp.int32, attn.shape, 1 ) kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) - attn += jnp.where(q_span < kv_span, mask_value, 0.0) + mask = q_span < kv_span + if sliding_window is not None: + mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) + attn += jnp.where(mask, mask_value, 0.0) attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) outputs.append(out) @@ -122,6 +125,7 @@ def validate_inputs_on_runtime( page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs, # i32[1] + sliding_window: int | None = None, ): check_inputs_shapes( q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs @@ -150,6 +154,8 @@ def validate_inputs_on_runtime( raise ValueError( f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." ) + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"{sliding_window=} must be positive.") # Expect to run these checks during compile time. 
@@ -221,7 +227,8 @@ def ragged_paged_attention_kernel( m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] *, sm_scale: float, - mask_value: float, + sliding_window: int | None = None, + mask_value: float = DEFAULT_MASK_VALUE, ): num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] @@ -373,7 +380,7 @@ def flash_attention( def masked_store(ref, val, start, end, group=1): iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group mask = jnp.logical_and(iota >= start, iota < end) - pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask) + pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) qk = ( jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) @@ -422,6 +429,9 @@ def init_scratch_ref(): 1, ) causal_mask = row_ids < col_ids + if sliding_window is not None: + causal_mask = jnp.logical_or(causal_mask, + row_ids - sliding_window>=col_ids) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) @@ -601,6 +611,7 @@ def can_be_xla_fully_tiled(x, packing): "num_kv_pages_per_block", "num_queries_per_block", "vmem_limit_bytes", + "sliding_window", ], ) def ragged_paged_attention( @@ -614,6 +625,7 @@ def ragged_paged_attention( num_seqs: jax.Array, # i32[1] *, sm_scale: float = 1.0, + sliding_window: int | None = None, mask_value: float = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, num_queries_per_block: int = 128, @@ -632,6 +644,7 @@ def ragged_paged_attention( kv_lens, only the first num_seqs+1 values are valid. num_seqs: the dynamic number of sequences. sm_scale: the softmax scale which will be applied to the Q@K^T. + sliding_window: the sliding window size for the attention. mask_value: mask value for causal mask. num_kv_pages_per_block: number of kv pages to be processed in one flash attention block in the pallas kernel. @@ -705,6 +718,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): functools.partial( ragged_paged_attention_kernel, sm_scale=sm_scale, + sliding_window=sliding_window, mask_value=mask_value, ), grid_spec=pltpu.PrefetchScalarGridSpec( @@ -724,6 +738,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), name="ragged_paged_attention_kernel", ) + # TODO(jevinjiang): Use f32 acc scratch for output! So we only need # to transfer output with desired dtype back to HBM. return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index bffcebc5254b..80d78ec32d07 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
import random + from absl.testing import absltest from absl.testing import parameterized import jax @@ -50,6 +51,7 @@ def _test_ragged_paged_attention( vmem_limit_bytes=32 * 1024 * 1024, max_num_batched_tokens=512, max_num_seq=8, + sliding_window: int | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -101,8 +103,10 @@ def _test_ragged_paged_attention( page_indices, cu_q_lens, num_seqs, + sliding_window=sliding_window, ) + actual_num_q_tokens = cu_q_lens[num_seqs[0]] output = ragged_paged_attention( q, k_pages, @@ -114,7 +118,8 @@ def _test_ragged_paged_attention( num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, - )[: cu_q_lens[num_seqs[0]]] + sliding_window=sliding_window, + )[: actual_num_q_tokens] expected = ref_ragged_paged_attention( q, @@ -124,6 +129,7 @@ def _test_ragged_paged_attention( page_indices, cu_q_lens, num_seqs=num_seqs, + sliding_window=sliding_window, ) tols = { "float32": 0.15, @@ -266,6 +272,7 @@ def test_ragged_paged_attention_mixed(self, dtype): dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], + sliding_window=[None, 5, 128], ) def test_ragged_paged_attention_complex( self, @@ -274,6 +281,7 @@ def test_ragged_paged_attention_complex( dtype, num_kv_pages_per_block, num_queries_per_block, + sliding_window: int | None, ): seq_lens = [] for _ in range(num_seqs): @@ -294,8 +302,38 @@ def test_ragged_paged_attention_complex( num_pages, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, + sliding_window=sliding_window, ) + def test_ragged_paged_attention_sliding_window_should_be_positive(self): + dtype=jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + sliding_window=0, + ) + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + sliding_window=-1, + ) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 5a5415bcda4944edb0009a94f793031a8708f0d5 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 19 Mar 2025 19:42:44 +0200 Subject: [PATCH 0026/1769] Rename arguments x, y of assertAllClose and friends to actual, expected. 
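
The new names follow the numpy.testing convention ("expected" in the subject
corresponds to numpy's `desired` keyword), where argument order matters for
failure reporting: the first argument is treated as the value computed by the
code under test and the second as the reference. This is also why the
_CompileAndCheck call sites in the diff below now pass the computed answer
first. A toy illustration of the convention (values are made up):

    import numpy as np

    computed = np.array([1.0, 2.0])   # value produced by the code under test
    reference = np.array([1.0, 2.0])  # known-good value
    # numpy's failure report treats the first argument as the actual value
    # and the second as the desired one, so swapping them is misleading.
    np.testing.assert_allclose(computed, reference, rtol=1e-6)
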
--- jax/_src/test_util.py | 81 +++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 18f7efa16223..0dace13821fc 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1343,15 +1343,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str): else: return self.assertWarnsRegex(DeprecationWarning, message) - def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', + def assertArraysEqual(self, actual, desired, *, check_dtypes=True, err_msg='', allow_object_dtype=False, verbose=True): """Assert that x and y arrays are exactly equal.""" if check_dtypes: - self.assertDtypesMatch(x, y) - x = np.asarray(x) - y = np.asarray(y) + self.assertDtypesMatch(actual, desired) + actual = np.asarray(actual) + desired = np.asarray(desired) - if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): + if (not allow_object_dtype) and (actual.dtype == object or desired.dtype == object): # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. " @@ -1361,57 +1361,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', # Work around https://github.com/numpy/numpy/issues/18992 with np.errstate(over='ignore'): - np.testing.assert_array_equal(x, y, err_msg=err_msg, + np.testing.assert_array_equal(actual, desired, err_msg=err_msg, verbose=verbose) - def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None, + def assertArraysAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, err_msg=''): - """Assert that x and y are close (up to numerical tolerances).""" - self.assertEqual(x.shape, y.shape) - atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) - rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) + """Assert that actual and desired are close (up to numerical tolerances).""" + self.assertEqual(actual.shape, desired.shape) + atol = max(tolerance(_dtype(actual), atol), tolerance(_dtype(desired), atol)) + rtol = max(tolerance(_dtype(actual), rtol), tolerance(_dtype(desired), rtol)) - _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg) + _assert_numpy_allclose(actual, desired, atol=atol, rtol=rtol, err_msg=err_msg) if check_dtypes: - self.assertDtypesMatch(x, y) + self.assertDtypesMatch(actual, desired) - def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True): + def assertDtypesMatch(self, actual, desired, *, canonicalize_dtypes=True): if not config.enable_x64.value and canonicalize_dtypes: - self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True), - _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True)) + self.assertEqual(_dtypes.canonicalize_dtype(_dtype(actual), allow_extended_dtype=True), + _dtypes.canonicalize_dtype(_dtype(desired), allow_extended_dtype=True)) else: - self.assertEqual(_dtype(x), _dtype(y)) + self.assertEqual(_dtype(actual), _dtype(desired)) - def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None, + def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, canonicalize_dtypes=True, err_msg=''): - """Assert that x and y, either arrays or nested tuples/lists, are close.""" - if isinstance(x, dict): - self.assertIsInstance(y, dict) - self.assertEqual(set(x.keys()), set(y.keys())) - for k in x.keys(): - self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, 
atol=atol, + """Assert that actual and desired, either arrays or nested tuples/lists, are close.""" + if isinstance(actual, dict): + self.assertIsInstance(desired, dict) + self.assertEqual(set(actual.keys()), set(desired.keys())) + for k in actual.keys(): + self.assertAllClose(actual[k], desired[k], check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif is_sequence(x) and not hasattr(x, '__array__'): - self.assertTrue(is_sequence(y) and not hasattr(y, '__array__')) - self.assertEqual(len(x), len(y)) - for x_elt, y_elt in zip(x, y): - self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol, + elif is_sequence(actual) and not hasattr(actual, '__array__'): + self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__')) + self.assertEqual(len(actual), len(desired)) + for actual_elt, desired_elt in zip(actual, desired): + self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif hasattr(x, '__array__') or np.isscalar(x): - self.assertTrue(hasattr(y, '__array__') or np.isscalar(y)) + elif hasattr(actual, '__array__') or np.isscalar(actual): + self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired)) if check_dtypes: - self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes) - x = np.asarray(x) - y = np.asarray(y) - self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol, + self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes) + actual = np.asarray(actual) + desired = np.asarray(desired) + self.assertArraysAllClose(actual, desired, check_dtypes=False, atol=atol, rtol=rtol, err_msg=err_msg) - elif x == y: + elif actual == desired: return else: - raise TypeError((type(x), type(y))) + raise TypeError((type(actual), type(desired))) def assertMultiLineStrippedEqual(self, expected, what): """Asserts two strings are equal, after dedenting and stripping each line.""" @@ -1426,7 +1426,6 @@ def assertMultiLineStrippedEqual(self, expected, what): self.assertMultiLineEqual(expected_clean, what_clean, msg=f"Found\n{what}\nExpecting\n{expected}") - @contextmanager def assertNoWarnings(self): with test_warning_util.raise_on_warnings(): @@ -1496,9 +1495,9 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes, + self.assertAllClose(monitored_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) args = args_maker() @@ -1509,7 +1508,7 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, From 7a67c9bd63bcf160c1edb884930df1bfe1108496 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 19 Mar 2025 11:30:41 -0700 Subject: [PATCH 0027/1769] Fix lint error on main --- jax/experimental/pallas/ops/tpu/ragged_paged_attention.py | 2 +- tests/pallas/tpu_ragged_paged_attention_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index a9b61da290f7..90b808282c22 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -431,7 +431,7 @@ def init_scratch_ref(): causal_mask = row_ids < col_ids if sliding_window is not None: causal_mask = jnp.logical_or(causal_mask, - row_ids - sliding_window>=col_ids) + row_ids - sliding_window >= col_ids) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index 80d78ec32d07..ba574a4ce98c 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -306,7 +306,7 @@ def test_ragged_paged_attention_complex( ) def test_ragged_paged_attention_sliding_window_should_be_positive(self): - dtype=jnp.float32 + dtype = jnp.float32 seq_lens = [(192, 328), (128, 180), (64, 255)] num_heads = (32, 8) head_dim = 128 From 9d534ad2cd40e11ce9d1f19c80300a35b9332c8d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 14:41:25 -0400 Subject: [PATCH 0028/1769] Update version numbers after JAX 0.5.3 release. --- CHANGELOG.md | 4 +++- jax/version.py | 2 +- setup.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9faff67cf305..9a817ce80937 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,9 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.5.3 +## Unreleased + +## jax 0.5.3 (Mar 19, 2025) * New Features diff --git a/jax/version.py b/jax/version.py index 13df5f00a11b..6ed6a5fda600 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.5.3" +_version = "0.5.4" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index a5c8500dc1cf..dbb7040d2d2b 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ _current_jaxlib_version = '0.5.3' # The following should be updated after each new jaxlib release. 
-_latest_jaxlib_version_on_pypi = '0.5.1' +_latest_jaxlib_version_on_pypi = '0.5.3' _libtpu_version = '0.0.11.*' From 4489303dfc6548878c81c4e5b3209e0aba002332 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Mar 2025 12:43:36 -0700 Subject: [PATCH 0029/1769] Delete `ParsedPartitionSpec` and `preprocess` function and do a couple more cleanups PiperOrigin-RevId: 738503430 --- jax/_src/interpreters/pxla.py | 41 +++++++++++--- jax/_src/named_sharding.py | 103 +++------------------------------- jax/_src/pjit.py | 38 ------------- jax/_src/sharding_impls.py | 5 +- 4 files changed, 41 insertions(+), 146 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c06eda5214ed..387f0661ae9d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2463,14 +2463,41 @@ def cost_analysis(self) -> dict[str, float]: return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) +def get_op_sharding_from_executable( + executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: + in_op_shardings: list[xc.OpSharding] = [] + parameter_shardings_from_xla = executable.get_parameter_shardings() + if parameter_shardings_from_xla is not None: + in_op_shardings = parameter_shardings_from_xla + + out_op_shardings: list[xc.OpSharding] = [] + output_shardings_from_xla = executable.get_output_shardings() + if output_shardings_from_xla is not None: + out_op_shardings = output_shardings_from_xla + + return in_op_shardings, out_op_shardings + + +def get_pspec_from_executable( + executable, mesh: Mesh +) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: + input_op_s, output_op_s = get_op_sharding_from_executable(executable) + in_pspec: list[PartitionSpec] = [] + for s in input_op_s: + in_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + + out_pspec: list[PartitionSpec] = [] + for s in output_op_s: + out_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + return tuple(in_pspec), tuple(out_pspec) + + def get_out_shardings_from_executable( xla_executable, device_assignment: Sequence[xc.Device], num_out_avals: int, num_ordered_effects: int, ) -> Sequence[sharding_impls.GSPMDSharding] | None: - from jax._src import pjit - try: omk = xla_executable.get_output_memory_kinds()[0] if num_ordered_effects > 0: @@ -2486,7 +2513,7 @@ def get_out_shardings_from_executable( return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk) for mk in omk] - _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) + _, out_op_shardings = get_op_sharding_from_executable(xla_executable) if not out_op_shardings: return None @@ -2517,14 +2544,12 @@ def _get_in_shardings_from_xla( num_ordered_effects: int ) -> Sequence[GSPMDSharding] | None: """Returns input shardings from XLA.""" - from jax._src import pjit - # When the device assignment only has 1 device, SPMD partitioner will not run. # Hence the op shardings will not be set on the `hlo_module`. 
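  # In that case every input is simply treated as fully replicated on the
  # sole device, since there is nothing to partition across.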
if len(device_assignment) == 1: return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals - in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable) + in_op_shardings, _ = get_op_sharding_from_executable(xla_executable) if not in_op_shardings: return None @@ -2543,9 +2568,7 @@ def _get_in_shardings_from_xla( def _get_mesh_pspec_shardings_from_executable( xla_executable, mesh: Mesh ) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]: - from jax._src import pjit - - in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh) + in_pspec, out_pspec = get_pspec_from_executable(xla_executable, mesh) return ([NamedSharding(mesh, i) for i in in_pspec], [NamedSharding(mesh, o) for o in out_pspec]) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 5accdd880a79..3d5b2e67f169 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -21,11 +21,11 @@ from typing import Any, Union from jax._src import config -from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert +from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib -from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton +from jax._src.partition_spec import PartitionSpec from jax._src import sharding as JSharding from jax._src import xla_bridge as xb import numpy as np @@ -198,7 +198,7 @@ def is_fully_addressable(self) -> bool: # Speed up `is_fully_addressable` since there is a high chance that the # mesh across multiple NamedSharding objects will be the same. if config.enable_empty_arrays.value: - client = self._internal_device_list[0].client + client = self._internal_device_list[0].client # type: ignore return (len(self.mesh._process_indices) == 1 and next(iter(self.mesh._process_indices)) == xb.process_index(client)) @@ -325,80 +325,6 @@ def __repr__(self): if self.replicated_axes else '') return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})" -# TODO(yashkatariya): Remove this after jax 0.5.2 release -class ParsedPartitionSpec: - __slots__ = ('_user_spec', 'partitions') - - _user_spec: PartitionSpec | None - partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...] 
- - def __init__(self, user_spec, partitions): - self._user_spec = user_spec - assert None not in partitions, partitions - self.partitions = tuple(partitions) - - def get_partition_spec(self) -> PartitionSpec: - if isinstance(self._user_spec, PartitionSpec): - return self._user_spec - else: - return get_single_pspec(self) - - def insert_axis_partitions(self, dim, val): - parts = self.partitions - too_short = dim - len(parts) - if too_short > 0: - parts += ((),) * too_short - new_partitions = tuple_insert(parts, dim, val) - return ParsedPartitionSpec(None, new_partitions) - - @classmethod - def from_user_input( - cls, - entry: PartitionSpec | None, - arg_name: str, - allow_unconstrained_dims: bool = False, - ) -> ParsedPartitionSpec: - if entry is None: - return cls(entry, ()) - if not isinstance(entry, PartitionSpec): - raise TypeError(f"{arg_name} are expected to be " - f"PartitionSpec instances or None, but got {entry}") - axis_specs = [] - for axis_spec in entry: - if axis_spec is None: - axis_spec = () - elif isinstance(axis_spec, (list, tuple)): - axis_spec = tuple(axis_spec) - elif axis_spec is PartitionSpec.UNCONSTRAINED: - if not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") - axis_spec = PartitionSpec.UNCONSTRAINED - else: - axis_spec = (axis_spec,) - axis_specs.append(axis_spec) - new_entry = PartitionSpec( - *[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry]) - return cls(new_entry, axis_specs) - - def __hash__(self): - return hash(self.partitions) - - def __eq__(self, other): - if not isinstance(other, ParsedPartitionSpec): - return False - return self.partitions == other.partitions - - def __len__(self): - return len(self.partitions) - - def __getitem__(self, i): - return self.partitions[i] - - def __iter__(self): - return iter(self.partitions) - - def __repr__(self): - return f"ParsedPartitionSpec(partitions={self.partitions})" @cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( @@ -491,18 +417,8 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): partitions.append(None) return PartitionSpec(*partitions) -get_single_pspec = lambda p: array_mapping_to_axis_resources(get_array_mapping(p)) # type: ignore - -# TODO(yashkatariya): Remove this after jax 0.5.2 release -def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): - if parsed_pspec is None: - spec = PartitionSpec() if spec is None else spec - parsed_pspec = ParsedPartitionSpec.from_user_input( - spec, "NamedSharding spec", allow_unconstrained_dims=True) - _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) - return parsed_pspec +@cache(max_size=128, trace_context_in_key=False) def check_pspec(mesh, spec, _manual_axes=frozenset()): _check_unique_resources(spec, "NamedSharding spec", mesh) _check_mesh_resource_axis(mesh, spec, _manual_axes) @@ -517,13 +433,10 @@ def __init__(self, message, mesh, pspec): def __str__(self): return f"{self.message}" -def _check_unique_resources( - pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None, -) -> None: +def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None + ) -> None: resource_counts: dict[MeshAxisName, int] = {} duplicate = False - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) for d in pspec: if d is PartitionSpec.UNCONSTRAINED or d is None: continue @@ -542,10 +455,8 @@ def _check_unique_resources( 
f' for {mesh_lib.show_axes(multiple_uses)}'), mesh=mesh, pspec=pspec) -@cache(max_size=128, trace_context_in_key=False) + def _check_mesh_resource_axis(mesh, pspec, _manual_axes): - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: continue diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b6024dcdfedd..d690cd6e9c67 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2905,41 +2905,3 @@ def get_unconstrained_dims(sharding: NamedSharding): assert sharding.spec is not None return {i for i, axes in enumerate(sharding.spec) if axes is PartitionSpec.UNCONSTRAINED} - - -def get_op_sharding_from_executable( - executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: - in_op_shardings: list[xc.OpSharding] = [] - parameter_shardings_from_xla = executable.get_parameter_shardings() - if parameter_shardings_from_xla is not None: - in_op_shardings = parameter_shardings_from_xla - - out_op_shardings: list[xc.OpSharding] = [] - output_shardings_from_xla = executable.get_output_shardings() - if output_shardings_from_xla is not None: - out_op_shardings = output_shardings_from_xla - - return in_op_shardings, out_op_shardings - - -def _get_ppspec_from_executable( - executable, mesh - ) -> tuple[Sequence[PartitionSpec], Sequence[PartitionSpec]]: - input_op_shardings, output_op_sharding = get_op_sharding_from_executable( - executable - ) - in_pspec: list[PartitionSpec] = [] - for s in input_op_shardings: - in_pspec.extend(parse_flatten_op_sharding(s, mesh)) - - out_pspec: list[PartitionSpec] = [] - for s in output_op_sharding: - out_pspec.extend(parse_flatten_op_sharding(s, mesh)) - return in_pspec, out_pspec - - -def get_pspec_from_executable( - executable, mesh: pxla.Mesh -) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: - in_pspec, out_pspec = _get_ppspec_from_executable(executable, mesh) - return tuple(in_pspec), tuple(out_pspec) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index f3295a75cf7a..0ed8568e4bcc 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -37,10 +37,9 @@ from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, - ParsedPartitionSpec, _check_unique_resources, NamedSharding, UNSPECIFIED, + _check_unique_resources, NamedSharding, UNSPECIFIED, ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping, - array_mapping_to_axis_resources, get_single_pspec, preprocess, - named_sharding_to_xla_hlo_sharding) + array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding) from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec From 362fb7ae9d1413e0eadd4ff7227b318c99700a8b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 13:39:18 -0700 Subject: [PATCH 0030/1769] Remove code to support jaxlib < 0.5.3. The new xla_extension_version is 320. 
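For reference, the pattern removed throughout this change is the runtime
jaxlib/XLA version gate. A minimal sketch of the shape of the deleted code
(illustrative, not a verbatim excerpt):

    from jax._src.lib import xla_extension_version

    if xla_extension_version >= 319:
        ...  # path for jaxlib >= 0.5.3
    else:
        ...  # fallback for older jaxlib, now unreachable

With jaxlib 0.5.3 as the minimum supported version, only the first branch
survives and the version check itself is deleted.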
PiperOrigin-RevId: 738522486 --- jax/_src/export/_export.py | 16 ++--- jax/_src/interpreters/mlir.py | 9 +-- jax/_src/lax/lax.py | 6 -- jax/_src/sharding_impls.py | 4 +- jax/_src/util.py | 62 +------------------ .../mosaic/gpu/dialect_lowering.py | 8 +-- .../mosaic/gpu/transform_inference.py | 6 +- jax/experimental/sparse/linalg.py | 6 -- tests/lax_test.py | 6 -- tests/linalg_test.py | 3 - 10 files changed, 16 insertions(+), 110 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index afae3d9bcdc2..9b6a0f80930f 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -43,7 +43,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib import xla_client -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import xla_extension from jax._src.lib.mlir import ir, passmanager from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect @@ -674,10 +674,8 @@ def _export_lowered( # Shardy was used during lowering if we can find the Shardy mesh in the # module. Note that the mesh should have been lifted by the # `sdy-lift-inlined-meshes` pass in mlir.py. - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(mlir_module)) + shardy_enabled = xla_extension.sdy.lowered_with_shardy( + mlir.module_to_bytecode(mlir_module)) mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled) @@ -784,7 +782,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: _get_vjp=_get_exported_vjp) def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: - if xla_extension_version >= 319 and shardy_enabled: + if shardy_enabled: mlir_str = xla_extension.sdy.sdy_round_trip_export_pipeline( mlir.module_to_bytecode(module)) else: @@ -1423,10 +1421,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(submodule)) + shardy_enabled = xla_extension.sdy.lowered_with_shardy( + mlir.module_to_bytecode(submodule)) if shardy_enabled: submodule = ir.Module.parse(xla_extension.sdy.sdy_round_trip_import_shardings( mlir.module_to_bytecode(submodule))) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1369f72ac74c..8f257b976dff 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -55,7 +55,7 @@ SdyArraySharding, SdyArrayShardingList) from jax._src.util import foreach from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import xla_extension from jax._src.lib.mlir import dialects, ir, passmanager from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir import register_jax_dialects @@ -3031,11 +3031,8 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: mlir_module=module_to_bytecode(module), enable_shape_assertions=True, validate_static_shapes=True) - if xla_extension_version >= 319: - refined_module_str = refine_polymorphic_shapes( - enable_shardy=config.use_shardy_partitioner.value) - else: - refined_module_str = refine_polymorphic_shapes() + refined_module_str = refine_polymorphic_shapes( + 
enable_shardy=config.use_shardy_partitioner.value) except Exception as e: raise ValueError( "Error refining shapes. " + diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 86a75ada63ad..388ad49ec83d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -66,7 +66,6 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.lib import xla_extension_version from jax._src.sharding_impls import (PmapSharding, NamedSharding, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape @@ -2267,11 +2266,6 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, case DotAlgorithmPreset.BF16_BF16_F32_X6: return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) case DotAlgorithmPreset.BF16_BF16_F32_X9: - if xla_extension_version < 320: - raise ValueError( - "The dot algorithm BF16_BF16_F32_X9 requires XLA extension " - "version >= 320." - ) return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False) case DotAlgorithmPreset.TF32_TF32_F32: return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 0ed8568e4bcc..efa1b4cfd5b6 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -33,7 +33,6 @@ from jax._src import xla_bridge as xb from jax._src import mesh_utils from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, @@ -881,8 +880,7 @@ def parse_flatten_op_sharding( return out elif hlo_sharding.is_replicated(): return [PartitionSpec()] - elif (xla_extension_version >= 319 and hlo_sharding.is_maximal() - and mesh.size == 1): + elif hlo_sharding.is_maximal() and mesh.size == 1: return [PartitionSpec()] elif hlo_sharding.is_tiled(): mesh_shape = mesh.shape diff --git a/jax/_src/util.py b/jax/_src/util.py index 0e28aea04b5a..d558954e881c 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -108,11 +108,7 @@ def foreach(f, *args): return None else: - # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum. - if hasattr(jaxlib_utils, 'foreach'): - foreach = jaxlib_utils.foreach - else: - foreach = safe_map + foreach = jaxlib_utils.foreach def unzip2(xys: Iterable[tuple[T1, T2]] @@ -244,61 +240,8 @@ def curry(f): """ return wraps(f)(partial(partial, f)) -# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum. 
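For context, `toposort` returns the nodes ordered so that every node appears
after its parents, which are reached through the attribute named by the first
argument ("parents" here); this matches the pure-Python fallback deleted
below. A minimal sketch, assuming a trivial node type:

    class Node:
      def __init__(self, parents):
        self.parents = parents

    a = Node([]); b = Node([a]); c = Node([a, b])
    assert toposort([c]) == [a, b, c]  # the only valid order: parents first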
toposort: Callable[[Iterable[Any]], list[Any]] -if hasattr(jaxlib_utils, "topological_sort"): - toposort = partial(jaxlib_utils.topological_sort, "parents") -else: - - def toposort(end_nodes): - if not end_nodes: - return [] - end_nodes = _remove_duplicates(end_nodes) - - child_counts = {} - stack = list(end_nodes) - while stack: - node = stack.pop() - if id(node) in child_counts: - child_counts[id(node)] += 1 - else: - child_counts[id(node)] = 1 - stack.extend(node.parents) - for node in end_nodes: - child_counts[id(node)] -= 1 - - sorted_nodes = [] - childless_nodes = [ - node for node in end_nodes if child_counts[id(node)] == 0 - ] - assert childless_nodes - while childless_nodes: - node = childless_nodes.pop() - sorted_nodes.append(node) - for parent in node.parents: - if child_counts[id(parent)] == 1: - childless_nodes.append(parent) - else: - child_counts[id(parent)] -= 1 - sorted_nodes = sorted_nodes[::-1] - - check_toposort(sorted_nodes) - return sorted_nodes - - def check_toposort(nodes): - visited = set() - for node in nodes: - assert all(id(parent) in visited for parent in node.parents) - visited.add(id(node)) - - def _remove_duplicates(node_list): - seen = set() - out = [] - for n in node_list: - if id(n) not in seen: - seen.add(id(n)) - out.append(n) - return out +toposort = partial(jaxlib_utils.topological_sort, "parents") def split_merge(predicate, xs): @@ -320,7 +263,6 @@ def merge(new_lhs, new_rhs): return lhs, rhs, merge - def _ignore(): return None diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index d605a2dea8f9..ae702d50ebb7 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -849,13 +849,9 @@ def _mgpu_wait_op_lowering_rule( return [] -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. -SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) - - -@_register_lowering(SliceSMEMOp) +@_register_lowering(mgpu.SliceSMEMOp) def _mgpu_slice_smem_op_lowering_rule( - ctx: LoweringContext, op: SliceSMEMOp + ctx: LoweringContext, op: mgpu.SliceSMEMOp ) -> Sequence[ir.Value]: del ctx return [_slice_smem(op.result.type, op.offset)] diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index ef2d3661674c..d285e5df188f 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -172,11 +172,9 @@ def _infer_vector_load_store_transforms( return None -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. 
-SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) -@partial(_add_transform_inference_rule, SliceSMEMOp) -def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: +@partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) +def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: transforms = None uses = cast(ir.OpResult, op.result).uses diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index a931b0a30dcf..b2e57caba9a6 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -29,7 +29,6 @@ from jax._src import core from jax._src import ffi from jax._src.interpreters import ad -from jax._src.lib import gpu_solver import numpy as np from scipy.sparse import csr_matrix, linalg @@ -534,11 +533,6 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder): def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder): - # TODO(danfm): remove after JAX 0.5.1 release. - if hasattr(gpu_solver, "cuda_csrlsvqr"): - data_aval, _, _, _, = ctx.avals_in - return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, - indptr, b, tol, reorder) return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")( ctx, data, indices, indptr, b, tol=np.float64(tol), reorder=np.int32(reorder)) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8764caeb2e49..f7cca2c9b48f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -49,7 +49,6 @@ from jax._src.lax import lax as lax_internal from jax._src.util import NumpyComplexWarning, safe_zip from jax._src.tree_util import tree_map -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -1128,11 +1127,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") if jtu.test_device_matches(["gpu"]): - if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and - xla_extension_version < 320): - raise SkipTest( - f"The dot algorithm ${algorithm} requires XLA extension version " - ">= 320.") # GPU algorithm support is a little spotty. It is checked in # xla/service/algorithm_util.cc and the logic is copied here. if algorithm in { diff --git a/tests/linalg_test.py b/tests/linalg_test.py index feab105ccbe2..60e507d84782 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -867,9 +867,6 @@ def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorith self.skipTest("Hermitian SVD doesn't support the algorithm parameter.") if not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("SVD algorithm selection only supported on CPU and GPU.") - # TODO(danfm): Remove this check after 0.5.2 is released. 
- if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1): - self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.") if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI: self.skipTest("Jacobi SVD not supported on GPU.") From 29e90a30cd7b4373ba3755fd1ddc9b2abc4b85d4 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 19 Mar 2025 13:48:01 -0700 Subject: [PATCH 0031/1769] Add a presubmit check to test against oldest supported numpy PiperOrigin-RevId: 738525650 --- .github/workflows/oldest_supported_numpy.yml | 60 ++++++++++++++++++++ ci/run_pytest_cpu.sh | 6 +- ci/utilities/install_wheels_locally.sh | 34 +++++------ 3 files changed, 81 insertions(+), 19 deletions(-) create mode 100644 .github/workflows/oldest_supported_numpy.yml diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml new file mode 100644 index 000000000000..80e0cb154ecd --- /dev/null +++ b/.github/workflows/oldest_supported_numpy.yml @@ -0,0 +1,60 @@ +# CI - Oldest Supported NumPy (presubmit) +# This workflow tests the oldest supported NumPy and jaxlib versions. + +name: CI - Oldest Supported NumPy (presubmit) + +on: + pull_request: + branches: + - main + push: + branches: + - main + - 'release/**' + +# This should also be set to read-only in the project settings, but it's nice to +# document and enforce the permissions here. +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} + +jobs: + test-oldest-supported-numpy: + if: github.event.repository.fork == false + defaults: + run: + shell: bash + runs-on: "linux-x86-n2-64" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" +# Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "CI - Oldest Supported NumPy (Python 3.10, x64=0)" +# End Presubmit Naming Check github-oldest-supported-numpy-presubmit + + env: + JAXCI_PYTHON: "python3.10" + JAXCI_ENABLE_X64: 0 + JAX_NUM_GENERATED_CASES: 5 + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Install Python dependencies + run: | + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + + # Install NumPy and SciPy with the oldest supported versions + $JAXCI_PYTHON -m uv pip install numpy==1.25.2 scipy==1.11.1 + + # Install JAX using the changes in the PR + $JAXCI_PYTHON -m uv pip install -e .[minimum-jaxlib] + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Pytest CPU tests + timeout-minutes: 30 + run: ./ci/run_pytest_cpu.sh \ No newline at end of file diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 43581ef2c96c..9de29691f753 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -26,13 +26,13 @@ set -exu -o history -o allexport # Source default JAXCI environment variables. source ci/envs/default.env +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + # Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh -# Set up the build environment. 
-source "ci/utilities/setup_build_environment.sh" - # Print all the installed packages echo "Installed packages:" "$JAXCI_PYTHON" -m uv pip list diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f98f7658ad18..64f88765bb75 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -30,23 +30,25 @@ for i in "${!WHEELS[@]}"; do fi done -if [[ -z "${WHEELS[@]}" ]]; then - echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" - exit 1 -fi +if [[ -n "${WHEELS[@]}" ]]; then + echo "Installing the following wheels:" + echo "${WHEELS[@]}" -echo "Installing the following wheels:" -echo "${WHEELS[@]}" - -# Install `uv` if it's not already installed. `uv` is much faster than pip for -# installing Python packages. -if ! command -v uv >/dev/null 2>&1; then - pip install uv~=0.5.30 -fi + # Install `uv` if it's not already installed. `uv` is much faster than pip for + # installing Python packages. + if ! command -v uv >/dev/null 2>&1; then + pip install uv~=0.5.30 + fi -# On Windows, convert MSYS Linux-like paths to Windows paths. -if [[ $(uname -s) =~ "MSYS_NT" ]]; then - "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + # On Windows, convert MSYS Linux-like paths to Windows paths. + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + else + "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + fi else - "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + # Note that we don't exit here because the wheels may have been installed + # earlier in a different step in the CI job. + echo "INFO: No wheels found under $JAXCI_OUTPUT_DIR" + echo "INFO: Skipping local wheel installation." fi \ No newline at end of file From 945582add83b33d4796b44a57935133232a33d3e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 19 Mar 2025 14:31:13 -0700 Subject: [PATCH 0032/1769] jax.numpy: add tests for __jax_array__ handling --- tests/BUILD | 9 + tests/array_extensibility_test.py | 516 ++++++++++++++++++++++++++++++ 2 files changed, 525 insertions(+) create mode 100644 tests/array_extensibility_test.py diff --git a/tests/BUILD b/tests/BUILD index 0ffa68ed8eb3..6706971dc7d1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -80,6 +80,15 @@ jax_py_test( ] + py_deps("absl/testing"), ) +jax_py_test( + name = "array_extensibility_test", + srcs = ["array_extensibility_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), +) + jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py new file mode 100644 index 000000000000..45c83f7967ce --- /dev/null +++ b/tests/array_extensibility_test.py @@ -0,0 +1,516 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from absl.testing import absltest +from absl.testing import parameterized +from typing import Any, Callable, NamedTuple + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax._src import config +from jax._src import test_util as jtu + + +config.parse_flags_with_absl() + + +class JaxArrayWrapper: + """Class that provides a __jax_array__ method.""" + x: ArrayLike + + def __init__(self, x: ArrayLike): + self.x = x + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.x) + + +class NumPyAPI(NamedTuple): + fun: Callable[..., Any] + args: list[jax.ShapeDtypeStruct] + kwargs: dict[str, Any] + + def name(self): + return self.fun.__name__ + + def make_args(self, rng): + rng = jtu.rand_default(rng) + return jax.tree.map(lambda arg: rng(arg.shape, arg.dtype), self.args) + + @classmethod + def sig(cls, fun: Callable[..., Any], *args: Any, **kwargs: Any) -> 'NumPyAPI': + return cls(fun, args, kwargs) + + +class ShapeDtype: + """Shortcut for specifying ShapeDtypeStruct.""" + def __init__(self, dtype): + self.dtype = jax.dtypes.canonicalize_dtype(dtype) + def __getitem__(self, shape) -> jax.ShapeDtypeStruct: + if isinstance(shape, int): + shape = (shape,) + return jax.ShapeDtypeStruct(shape, self.dtype) + +Bool = ShapeDtype(bool) +Int = ShapeDtype(int) +Uint8 = ShapeDtype('uint8') +Float = ShapeDtype(float) +Complex = ShapeDtype(complex) + + +# NumPy namespace objects skipped in the enumeration below, mainly because +# they are not functions or do not take arrays as positional arguments. +SKIPPED_APIS = [ + 'apply_along_axis', + 'apply_over_axes', + 'arange', + 'astype', + 'bartlett', + 'bfloat16', + 'blackman', + 'block', + 'bool', + 'bool_', + 'broadcast_shapes', + 'c_', + 'cdouble', + 'character', + 'complex128', + 'complex64', + 'complex_', + 'complexfloating', + 'csingle', + 'diag_indices', + 'double', + 'dtype', + 'e', + 'einsum', + 'einsum_path', + 'euler_gamma', + 'empty', + 'eye', + 'finfo', + 'flexible', + 'float_', + 'float16', + 'float32', + 'float4_e2m1fn', + 'float64', + 'float8_e3m4', + 'float8_e4m3', + 'float8_e4m3b11fnuz', + 'float8_e4m3fn', + 'float8_e4m3fnuz', + 'float8_e5m2', + 'float8_e5m2fnuz', + 'float8_e8m0fnu', + 'floating', + 'from_dlpack', + 'frombuffer', + 'fromfile', + 'fromfunction', + 'fromiter', + 'frompyfunc', + 'fromstring', + 'full', + 'generic', + 'geomspace', + 'get_printoptions', + 'gradient', + 'hamming', + 'hanning', + 'identity', + 'iinfo', + 'index_exp', + 'indices', + 'inexact', + 'inf', + 'int16', + 'int2', + 'int32', + 'int4', + 'int64', + 'int8', + 'int_', + 'integer', + 'isdtype', + 'issubdtype' + 'iterable' + 'kaiser' + 'kron' + 'ix_', + 'linalg', + 'linspace', + 'load', + 'logspace', + 'mask_indices', + 'mgrid', + 'nan', + 'ndarray', + 'newaxis', + 'number', + 'object_', + 'ogrid', + 'ones', + 'pi', + 'printoptions', + 'promote_types' + 'r_', + 'result_type', + 's_', + 'save', + 'savez', + 'set_printoptions', + 'signedinteger', + 'single', + 'tri', + 'tril_indices', + 'triu_indices', + 'ufunc', + 'uint', + 'uint16', + 'uint2', + 'uint32', + 'uint4', + 'uint64', + 'uint8', + 'unsignedinteger', + 'vectorize', + 'zeros', +] + +# TODO(jakevdp): commented APIs are ones which do not yet support +# __jax_array__ on inputs. We should fix these! 
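+# For example (illustrative only, not part of the test matrix below):
+#
+#   wrapped = JaxArrayWrapper(jnp.arange(4.0))
+#   jnp.sin(wrapped)           # dispatched via __jax_array__
+#   jnp.add(wrapped, wrapped)  # binary ufuncs likewise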
+NUMPY_APIS = [ + NumPyAPI.sig(jnp.abs, Float[5]), + NumPyAPI.sig(jnp.absolute, Float[5]), + NumPyAPI.sig(jnp.acos, Float[5]), + NumPyAPI.sig(jnp.acosh, Float[5]), + NumPyAPI.sig(jnp.add, Float[5], Float[5]), + NumPyAPI.sig(jnp.all, Bool[5]), + NumPyAPI.sig(jnp.allclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.amax, Float[5]), + NumPyAPI.sig(jnp.amin, Float[5]), + NumPyAPI.sig(jnp.angle, Float[5]), + NumPyAPI.sig(jnp.any, Float[5]), + NumPyAPI.sig(jnp.append, Float[10], Float[()]), + NumPyAPI.sig(jnp.arccos, Float[5]), + NumPyAPI.sig(jnp.arccosh, Float[5]), + NumPyAPI.sig(jnp.arcsin, Float[5]), + NumPyAPI.sig(jnp.arcsinh, Float[5]), + NumPyAPI.sig(jnp.arctan, Float[5]), + NumPyAPI.sig(jnp.arctan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.arctanh, Float[5]), + NumPyAPI.sig(jnp.argmax, Float[10]), + NumPyAPI.sig(jnp.argmin, Float[10]), + NumPyAPI.sig(jnp.argpartition, Float[10], kth=5), + NumPyAPI.sig(jnp.argsort, Float[10]), + # NumPyAPI.sig(jnp.argwhere, [float], [(10,)]), + NumPyAPI.sig(jnp.around, Float[5]), + NumPyAPI.sig(jnp.array, Float[5]), + NumPyAPI.sig(jnp.array_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.array_equiv, Float[5], Float[5]), + # NumPyAPI.sig(jnp.array_repr, Float[5]), + NumPyAPI.sig(jnp.array_split, Float[9], indices_or_sections=3), + # NumPyAPI.sig(jnp.array_str, Float[5]), + NumPyAPI.sig(jnp.asarray, Float[5]), + NumPyAPI.sig(jnp.asin, Float[5]), + NumPyAPI.sig(jnp.asinh, Float[5]), + NumPyAPI.sig(jnp.atan, Float[5]), + NumPyAPI.sig(jnp.atan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.atanh, Float[5]), + NumPyAPI.sig(jnp.atleast_1d, Float[5]), + NumPyAPI.sig(jnp.atleast_2d, Float[5]), + NumPyAPI.sig(jnp.atleast_3d, Float[5]), + NumPyAPI.sig(jnp.average, Float[10]), + # NumPyAPI.sig(jnp.bincount, int[10]), + NumPyAPI.sig(jnp.bitwise_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_count, Int[5]), + NumPyAPI.sig(jnp.bitwise_invert, Int[5]), + NumPyAPI.sig(jnp.bitwise_left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_not, Int[5]), + NumPyAPI.sig(jnp.bitwise_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.broadcast_arrays, Float[5]), + # NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), + # NumPyAPI.sig(jnp.can_cast, Float[()], to='int32'), + NumPyAPI.sig(jnp.cbrt, Float[5]), + NumPyAPI.sig(jnp.ceil, Float[5]), + # NumPyAPI.sig(jnp.choose, [int, float], [(3,), (10,)]), + NumPyAPI.sig(jnp.clip, Float[5]), + # NumPyAPI.sig(jnp.column_stack, [float], [(3, 10)]), + NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), + # NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), + # NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.conj, Float[5]), + NumPyAPI.sig(jnp.conjugate, Float[5]), + NumPyAPI.sig(jnp.convolve, Float[7], Float[3]), + NumPyAPI.sig(jnp.copy, Float[5]), + NumPyAPI.sig(jnp.copysign, Float[5], Float[5]), + NumPyAPI.sig(jnp.corrcoef, Float[7], Float[7]), + NumPyAPI.sig(jnp.correlate, Float[7], Float[3]), + NumPyAPI.sig(jnp.cos, Float[5]), + NumPyAPI.sig(jnp.cosh, Float[5]), + # NumPyAPI.sig(np.count_nonzero, [float], [(10,)]), + # NumPyAPI.sig(np.cov, [float], [(10,)]), + # NumPyAPI.sig(np.cross, [float, float], [(3,), (3,)]), + # NumPyAPI.sig(np.cumprod, [float], [(10,)]), + # NumPyAPI.sig(np.cumsum, [float], [(10,)]), + # NumPyAPI.sig(np.cumulative_prod, [float], [(10,)]), + # NumPyAPI.sig(np.cumulative_sum, [float], [(10,)]), + NumPyAPI.sig(jnp.deg2rad, Float[5]), + NumPyAPI.sig(jnp.degrees, Float[5]), + # NumPyAPI.sig(jnp.delete, 
Float[5], Int[()]), + NumPyAPI.sig(jnp.diag, Float[5]), + # NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), + NumPyAPI.sig(jnp.diagflat, Float[5]), + NumPyAPI.sig(jnp.diagonal, Float[5, 5]), + NumPyAPI.sig(jnp.diff, Float[5]), + NumPyAPI.sig(jnp.digitize, Float[5], Float[5]), + NumPyAPI.sig(jnp.divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.divmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.dot, Float[5], Float[5]), + NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), + # NumPyAPI.sig(jnp.dstack, Float[3, 5]), + NumPyAPI.sig(jnp.ediff1d, Float[5]), + # NumPyAPI.sig(jnp.empty_like, Float[5]), + NumPyAPI.sig(jnp.equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.exp, Float[5]), + NumPyAPI.sig(jnp.exp2, Float[5]), + # NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), + NumPyAPI.sig(jnp.expm1, Float[5]), + NumPyAPI.sig(jnp.extract, Bool[5], Float[5]), + NumPyAPI.sig(jnp.fabs, Float[5]), + NumPyAPI.sig(jnp.fft.fft, Float[5]), + NumPyAPI.sig(jnp.fft.fft2, Float[5, 5]), + NumPyAPI.sig(jnp.fft.ifft, Float[5]), + NumPyAPI.sig(jnp.fft.ifft2, Float[5, 5]), + NumPyAPI.sig(jnp.fill_diagonal, Float[5, 5], Float[()], inplace=False), + NumPyAPI.sig(jnp.fix, Float[5]), + NumPyAPI.sig(jnp.flatnonzero, Float[5]), + NumPyAPI.sig(jnp.flip, Float[5]), + NumPyAPI.sig(jnp.fliplr, Float[5, 5]), + NumPyAPI.sig(jnp.flipud, Float[5, 5]), + NumPyAPI.sig(jnp.float_power, Float[5], Float[5]), + NumPyAPI.sig(jnp.floor, Float[5]), + NumPyAPI.sig(jnp.floor_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmax, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmin, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.frexp, Float[5]), + # NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), + NumPyAPI.sig(jnp.gcd, Int[5], Int[5]), + NumPyAPI.sig(jnp.greater, Float[5], Float[5]), + NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.heaviside, Float[5], Float[5]), + # NumPyAPI.sig(jnp.histogram, Float[5]), + NumPyAPI.sig(jnp.histogram2d, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram_bin_edges, Float[5]), + # NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), + # NumPyAPI.sig(jnp.hsplit, Float[3, 5], Int[1]), + NumPyAPI.sig(jnp.hstack, (Float[5], Float[5])), + NumPyAPI.sig(jnp.hypot, Float[5], Float[5]), + NumPyAPI.sig(jnp.i0, Float[5]), + NumPyAPI.sig(jnp.imag, Complex[5]), + NumPyAPI.sig(jnp.inner, Float[5], Float[5]), + NumPyAPI.sig(jnp.insert, Float[5], Int[()], Float[2]), + NumPyAPI.sig(jnp.interp, Float[10], Float[5], Float[5]), + NumPyAPI.sig(jnp.intersect1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.invert, Int[5]), + NumPyAPI.sig(jnp.isclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.iscomplex, Float[5]), + NumPyAPI.sig(jnp.iscomplexobj, Complex[5]), + NumPyAPI.sig(jnp.isfinite, Float[5]), + NumPyAPI.sig(jnp.isin, Int[5], Int[10]), + NumPyAPI.sig(jnp.isinf, Float[5]), + NumPyAPI.sig(jnp.isnan, Float[5]), + # NumPyAPI.sig(jnp.isneginf, Float[5]), + # NumPyAPI.sig(jnp.isposinf, Float[5]), + NumPyAPI.sig(jnp.isreal, Float[5]), + NumPyAPI.sig(jnp.isrealobj, Float[5]), + NumPyAPI.sig(jnp.isscalar, Float[()]), + NumPyAPI.sig(jnp.lcm, Int[5], Int[5]), + NumPyAPI.sig(jnp.ldexp, Float[5], Int[5]), + NumPyAPI.sig(jnp.left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.less, Float[5], Float[5]), + NumPyAPI.sig(jnp.less_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.lexsort, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.log, Float[5]), + NumPyAPI.sig(jnp.log10, Float[5]), + NumPyAPI.sig(jnp.log1p, Float[5]), + NumPyAPI.sig(jnp.log2, Float[5]), + NumPyAPI.sig(jnp.logaddexp, Float[5], Float[5]), + 
NumPyAPI.sig(jnp.logaddexp2, Float[5], Float[5]), + NumPyAPI.sig(jnp.logical_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_not, Int[5]), + NumPyAPI.sig(jnp.logical_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.matmul, Float[5, 5], Float[5]), + # NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]), + NumPyAPI.sig(jnp.matvec, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.max, Float[5]), + NumPyAPI.sig(jnp.maximum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mean, Float[5]), + NumPyAPI.sig(jnp.median, Float[5]), + NumPyAPI.sig(jnp.meshgrid, Float[5], Float[5]), + NumPyAPI.sig(jnp.min, Float[5]), + NumPyAPI.sig(jnp.minimum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mod, Float[5], Float[5]), + NumPyAPI.sig(jnp.modf, Float[5]), + NumPyAPI.sig(jnp.moveaxis, Float[5, 3], source=0, destination=1), + NumPyAPI.sig(jnp.multiply, Float[5], Float[5]), + NumPyAPI.sig(jnp.nan_to_num, Float[5]), + NumPyAPI.sig(jnp.nanargmax, Float[5]), + NumPyAPI.sig(jnp.nanargmin, Float[5]), + NumPyAPI.sig(jnp.nancumprod, Float[5]), + NumPyAPI.sig(jnp.nancumsum, Float[5]), + NumPyAPI.sig(jnp.nanmax, Float[5]), + NumPyAPI.sig(jnp.nanmean, Float[5]), + NumPyAPI.sig(jnp.nanmedian, Float[5]), + NumPyAPI.sig(jnp.nanmin, Float[5]), + NumPyAPI.sig(jnp.nanpercentile, Float[5], q=75), + NumPyAPI.sig(jnp.nanprod, Float[5]), + NumPyAPI.sig(jnp.nanquantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.nanstd, Float[5]), + NumPyAPI.sig(jnp.nansum, Float[5]), + NumPyAPI.sig(jnp.nanvar, Float[5]), + # NumPyAPI.sig(jnp.ndim, Float[5]), + NumPyAPI.sig(jnp.negative, Float[5]), + NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), + NumPyAPI.sig(jnp.nonzero, Float[5]), + NumPyAPI.sig(jnp.not_equal, Float[5], Float[5]), + # NumPyAPI.sig(jnp.ones_like, Float[5]), + NumPyAPI.sig(jnp.outer, Float[5], Float[5]), + NumPyAPI.sig(jnp.packbits, Int[5]), + # NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), + NumPyAPI.sig(jnp.partition, Float[5], kth=3), + NumPyAPI.sig(jnp.percentile, Float[5], q=75), + NumPyAPI.sig(jnp.permute_dims, Float[3, 5], axes=(1, 0)), + NumPyAPI.sig(jnp.piecewise, Float[5], [Bool[5], Bool[5]], funclist=[jnp.sin, jnp.cos]), + NumPyAPI.sig(jnp.place, Float[5], Bool[5], Float[3], inplace=False), + NumPyAPI.sig(jnp.poly, Float[5]), + NumPyAPI.sig(jnp.polyadd, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyder, Float[5]), + NumPyAPI.sig(jnp.polydiv, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyfit, Float[5], Float[5], deg=2), + NumPyAPI.sig(jnp.polyint, Float[5]), + NumPyAPI.sig(jnp.polymul, Float[5], Float[5]), + NumPyAPI.sig(jnp.polysub, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyval, Float[5], Float[10]), + NumPyAPI.sig(jnp.positive, Float[5]), + # NumPyAPI.sig(jnp.pow, Float[5], Float[5]), + # NumPyAPI.sig(jnp.power, Float[5], Float[5]), + NumPyAPI.sig(jnp.prod, Float[5]), + NumPyAPI.sig(jnp.ptp, Float[5]), + NumPyAPI.sig(jnp.put, Float[5], Int[()], Float[()], inplace=False), + NumPyAPI.sig(jnp.put_along_axis, Float[5], Int[1], Float[1], axis=0, inplace=False), + NumPyAPI.sig(jnp.quantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.rad2deg, Float[5]), + NumPyAPI.sig(jnp.radians, Float[5]), + NumPyAPI.sig(jnp.ravel, Float[5]), + # NumPyAPI.sig(jnp.ravel_multi_index, Int[2, 5], dims=(2, 3)), + NumPyAPI.sig(jnp.real, Complex[5]), + NumPyAPI.sig(jnp.reciprocal, Float[5]), + NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), + # NumPyAPI.sig(jnp.repeat, Float[5], Int[5]), + # NumPyAPI.sig(jnp.reshape, Float[6], (2, 3)), + NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), + NumPyAPI.sig(jnp.right_shift, Int[5], 
Int[5]), + NumPyAPI.sig(jnp.rint, Float[5]), + NumPyAPI.sig(jnp.roll, Float[5], Int[1]), + NumPyAPI.sig(jnp.rollaxis, Float[5, 4], axis=1), + NumPyAPI.sig(jnp.roots, Float[5]), + NumPyAPI.sig(jnp.rot90, Float[5, 3]), + NumPyAPI.sig(jnp.round, Float[5]), + NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), + # NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[5]), + NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), + # NumPyAPI.sig(jnp.shape, Float[5, 3]), + NumPyAPI.sig(jnp.sign, Float[5]), + NumPyAPI.sig(jnp.signbit, Float[5]), + NumPyAPI.sig(jnp.sin, Float[5]), + NumPyAPI.sig(jnp.sinc, Float[5]), + NumPyAPI.sig(jnp.sinh, Float[5]), + # NumPyAPI.sig(jnp.size, Float[5]), + NumPyAPI.sig(jnp.sort, Float[5]), + NumPyAPI.sig(jnp.sort_complex, Complex[5]), + NumPyAPI.sig(jnp.spacing, Float[5]), + NumPyAPI.sig(jnp.split, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.sqrt, Float[5]), + NumPyAPI.sig(jnp.square, Float[5]), + NumPyAPI.sig(jnp.squeeze, Float[5]), + # NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), + NumPyAPI.sig(jnp.std, Float[5]), + NumPyAPI.sig(jnp.subtract, Float[5], Float[5]), + NumPyAPI.sig(jnp.sum, Float[5]), + NumPyAPI.sig(jnp.swapaxes, Float[3, 5], axis1=1, axis2=0), + NumPyAPI.sig(jnp.take, Float[5], Int[2]), + NumPyAPI.sig(jnp.take_along_axis, Float[5], Int[2], axis=0), + NumPyAPI.sig(jnp.tan, Float[5]), + NumPyAPI.sig(jnp.tanh, Float[5]), + NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), + # NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), + NumPyAPI.sig(jnp.trace, Float[5, 5]), + # NumPyAPI.sig(jnp.transpose, Float[5, 6]), + NumPyAPI.sig(jnp.trapezoid, Float[5]), + NumPyAPI.sig(jnp.tril, Float[5, 6]), + # NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.trim_zeros, Float[5]), + NumPyAPI.sig(jnp.triu, Float[5, 6]), + # NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.true_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.trunc, Float[5]), + NumPyAPI.sig(jnp.union1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.unique, Int[10]), + NumPyAPI.sig(jnp.unique_all, Int[10]), + NumPyAPI.sig(jnp.unique_counts, Int[10]), + NumPyAPI.sig(jnp.unique_inverse, Int[10]), + NumPyAPI.sig(jnp.unique_values, Int[10]), + NumPyAPI.sig(jnp.unpackbits, Uint8[8]), + NumPyAPI.sig(jnp.unravel_index, Int[5], shape=(2, 3)), + NumPyAPI.sig(jnp.unstack, Float[5]), + NumPyAPI.sig(jnp.unwrap, Float[5]), + NumPyAPI.sig(jnp.vander, Float[5]), + NumPyAPI.sig(jnp.var, Float[5]), + NumPyAPI.sig(jnp.vdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecmat, Float[5], Float[5, 3]), + NumPyAPI.sig(jnp.vsplit, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.vstack, [Float[5], Float[2, 5]]), + NumPyAPI.sig(jnp.where, Bool[5], Float[5], Float[5]), + # NumPyAPI.sig(jnp.zeros_like, Float[5]), +] + + +class JaxArrayTests(jtu.JaxTestCase): + @parameterized.named_parameters( + {'testcase_name': api.name(), 'api': api} for api in NUMPY_APIS) + def test_numpy_api_supports_jax_array(self, api): + fun = api.fun + args = api.make_args(self.rng()) + wrapped_args = jax.tree.map(JaxArrayWrapper, args) + kwargs = api.kwargs + + expected = fun(*args, **kwargs) + wrapped = fun(*wrapped_args, **kwargs) + + self.assertAllClose(wrapped, expected, atol=0, rtol=0) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 16dc0ad1dd475a5ea994f03d95127eb2a003d43b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 14:47:50 -0700 
Subject: [PATCH 0033/1769] Add `jax_source_package` macro and target to
 generate a source package `.tar.gz`.

Refactor the `jax_wheel` macro so that it outputs only a `.whl` file. When the
macro returns a single output object, all downstream dependencies can consume
it directly, without needing to filter the macro's outputs.

The previous implementation (in which `jax_wheel` returned both `.tar.gz` and
`.whl` files) required one of two options: either create a new target that
produces only the `.whl`, or implement filename filtering in the downstream
rules. With the new implementation we can simply depend on the `//:jax_wheel`
target, which produces the `.whl`.

PiperOrigin-RevId: 738547491
---
 BUILD.bazel                 |  18 ++++++-
 build/build.py              |   3 ++
 build_wheel.py              |  17 +++++-
 jaxlib/jax.bzl              | 100 +++++++++++++++++++++++++-----------
 jaxlib/tools/build_utils.py |   8 +--
 5 files changed, 110 insertions(+), 36 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 33cbefd29f0b..eb43d7ec0fd8 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -15,6 +15,7 @@ load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
 load(
     "//jaxlib:jax.bzl",
+    "jax_source_package",
     "jax_wheel",
 )

@@ -67,7 +68,6 @@ py_binary(

 jax_wheel(
     name = "jax_wheel",
-    build_wheel_only = False,
     platform_independent = True,
     source_files = [
         ":transitive_py_data",
@@ -82,3 +82,19 @@ jax_wheel(
     wheel_binary = ":build_wheel",
     wheel_name = "jax",
 )
+
+jax_source_package(
+    name = "jax_source_package",
+    source_files = [
+        ":transitive_py_data",
+        ":transitive_py_deps",
+        "//jax:py.typed",
+        "AUTHORS",
+        "LICENSE",
+        "README.md",
+        "pyproject.toml",
+        "setup.py",
+    ],
+    source_package_binary = ":build_wheel",
+    source_package_name = "jax",
+)
diff --git a/build/build.py b/build/build.py
index d38b911bb904..cdb568171b66 100755
--- a/build/build.py
+++ b/build/build.py
@@ -68,6 +68,7 @@
 # rule as the default.
 WHEEL_BUILD_TARGET_DICT_NEW = {
     "jax": "//:jax_wheel",
+    "jax_source_package": "//:jax_source_package",
     "jaxlib": "//jaxlib/tools:jaxlib_wheel",
     "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
     "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
@@ -661,6 +662,8 @@ async def main():
       # Append the build target to the Bazel command.
       build_target = wheel_build_targets[wheel]
       wheel_build_command.append(build_target)
+      if args.use_new_wheel_build_rule and wheel == "jax":
+        wheel_build_command.append(wheel_build_targets["jax_source_package"])

     if not args.use_new_wheel_build_rule:
       wheel_build_command.append("--")
diff --git a/build_wheel.py b/build_wheel.py
index f8e1595d3c3a..b4db96773527 100644
--- a/build_wheel.py
+++ b/build_wheel.py
@@ -47,6 +47,20 @@
 parser.add_argument(
     "--srcs", help="source files for the wheel", action="append"
 )
+parser.add_argument(
+    "--build-wheel-only",
+    default=False,
+    help=(
+        "Whether to build the wheel only. Optional."
+    ),
+)
+parser.add_argument(
+    "--build-source-package-only",
+    default=False,
+    help=(
+        "Whether to build the source package only. Optional."
+ ), +) args = parser.parse_args() @@ -94,7 +108,8 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: args.output_path, package_name="jax", git_hash=args.jaxlib_git_hash, - build_wheel_only=False, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, ) finally: if tmpdir: diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 89f1545995d5..02e6b10b1de1 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -362,7 +362,7 @@ def _get_full_wheel_name( free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "", ) -def _get_source_distribution_name(package_name, wheel_version): +def _get_source_package_name(package_name, wheel_version): return "{package_name}-{wheel_version}.tar.gz".format( package_name = package_name, wheel_version = wheel_version, @@ -394,37 +394,47 @@ def _jax_wheel_impl(ctx): no_abi = ctx.attr.no_abi platform_independent = ctx.attr.platform_independent build_wheel_only = ctx.attr.build_wheel_only + build_source_package_only = ctx.attr.build_source_package_only editable = ctx.attr.editable platform_name = ctx.attr.platform_name + + output_dir_path = "" + outputs = [] if editable: output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name) - wheel_dir = output_dir.path + output_dir_path = output_dir.path outputs = [output_dir] args.add("--editable") else: - wheel_name = _get_full_wheel_name( - package_name = ctx.attr.wheel_name, - no_abi = no_abi, - platform_independent = platform_independent, - platform_name = platform_name, - cpu_name = cpu, - wheel_version = full_wheel_version, - py_freethreaded = py_freethreaded, - ) - wheel_file = ctx.actions.declare_file(output_path + - "/" + wheel_name) - wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] - outputs = [wheel_file] - if not build_wheel_only: - source_distribution_name = _get_source_distribution_name( + if build_wheel_only: + wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, + no_abi = no_abi, + platform_independent = platform_independent, + platform_name = platform_name, + cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) - source_distribution_file = ctx.actions.declare_file(output_path + - "/" + source_distribution_name) - outputs.append(source_distribution_file) - - args.add("--output_path", wheel_dir) # required argument + wheel_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + output_dir_path = wheel_file.path[:wheel_file.path.rfind("/")] + outputs = [wheel_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-wheel-only", "True") + if build_source_package_only: + source_package_name = _get_source_package_name( + package_name = ctx.attr.wheel_name, + wheel_version = full_wheel_version, + ) + source_package_file = ctx.actions.declare_file(output_path + + "/" + source_package_name) + output_dir_path = source_package_file.path[:source_package_file.path.rfind("/")] + outputs = [source_package_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-source-package-only", "True") + + args.add("--output_path", output_dir_path) # required argument if not platform_independent: args.add("--cpu", cpu) args.add("--jaxlib_git_hash", git_hash) # required argument @@ -472,16 +482,17 @@ _jax_wheel = rule( "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), "platform_independent": attr.bool(default = False), - "build_wheel_only": attr.bool(default = True), + "build_wheel_only": attr.bool(mandatory = True, default = True), 
+ "build_source_package_only": attr.bool(mandatory = True, default = False), "editable": attr.bool(default = False), - "cpu": attr.string(mandatory = True), - "platform_name": attr.string(mandatory = True), + "cpu": attr.string(), + "platform_name": attr.string(), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), "source_files": attr.label_list(allow_files = True), "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. - "platform_version": attr.string(mandatory = True, default = ""), + "platform_version": attr.string(), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), @@ -498,7 +509,6 @@ def jax_wheel( wheel_name, no_abi = False, platform_independent = False, - build_wheel_only = True, editable = False, enable_cuda = False, enable_rocm = False, @@ -509,11 +519,10 @@ def jax_wheel( Common artifact attributes are grouped within a single macro. Args: - name: the name of the wheel + name: the target name wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI - build_wheel_only: whether to build a wheel without source distribution editable: whether to build an editable wheel platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel @@ -522,7 +531,7 @@ def jax_wheel( source_files: the source files to include in the wheel Returns: - A directory containing the wheel + A wheel file or a wheel directory. """ _jax_wheel( name = name, @@ -530,7 +539,8 @@ def jax_wheel( wheel_name = wheel_name, no_abi = no_abi, platform_independent = platform_independent, - build_wheel_only = build_wheel_only, + build_wheel_only = True, + build_source_package_only = False, editable = editable, enable_cuda = enable_cuda, enable_rocm = enable_rocm, @@ -554,6 +564,34 @@ def jax_wheel( source_files = source_files, ) +def jax_source_package( + name, + source_package_binary, + source_package_name, + source_files = []): + """Create jax source package. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the target name + source_package_binary: the binary to use to build the package + source_package_name: the name of the source package + source_files: the source files to include in the package + + Returns: + A jax source package file. 
+ """ + _jax_wheel( + name = name, + wheel_binary = source_package_binary, + wheel_name = source_package_name, + build_source_package_only = True, + build_wheel_only = False, + platform_independent = True, + source_files = source_files, + ) + jax_test_file_visibility = [] jax_export_file_visibility = [] diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 4c50cff16743..582a0c9f1d6f 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -65,6 +65,7 @@ def build_wheel( package_name: str, git_hash: str = "", build_wheel_only: bool = True, + build_source_package_only: bool = False, ) -> None: """Builds a wheel in `output_path` using the source tree in `sources_path`.""" env = dict(os.environ) @@ -78,7 +79,8 @@ def build_wheel( env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] - + (["-w"] if build_wheel_only else []), + + (["-w"] if build_wheel_only else []) + + (["-s"] if build_source_package_only else []), check=True, cwd=sources_path, env=env, @@ -97,10 +99,10 @@ def build_wheel( sys.stderr.write(" bazel run //build:requirements.update" + f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n") shutil.copy(wheel, output_path) - if not build_wheel_only: + if build_source_package_only: for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")): output_file = os.path.join(output_path, os.path.basename(dist)) - sys.stderr.write(f"Output source distribution: {output_file}\n\n") + sys.stderr.write(f"Output source package: {output_file}\n\n") shutil.copy(dist, output_path) From f74711254feae1e1d0ba532ac4b3b56e388d036d Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 15:40:40 -0700 Subject: [PATCH 0034/1769] Fix `lax_autodiff_test` on v5p PiperOrigin-RevId: 738565192 --- tests/lax_autodiff_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index a69f44f37754..aea9d2ad3dff 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -205,14 +205,16 @@ class LaxAutodiffTest(jtu.JaxTestCase): )) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) - if jtu.test_device_matches(["cpu"]): + if jtu.test_device_matches(["cpu", "tpu"]): if op is lax.cosh and dtype == np.complex64: - tol = 3e-1 # 2nd-order gradients are noisy on CPU + tol = 3e-1 # 2nd-order gradients are noisy on CPU and TPU if jtu.test_device_matches(["tpu"]): if op is lax.pow: raise SkipTest("pow grad imprecise on tpu") if op is lax.cos: order = 1 # 2nd-order gradient is imprecise on TPU. + if op is lax.sin: + order = 1 # 2nd-order gradient is imprecise on TPUv5p. if op is lax.log: order = 1 # 2nd-order gradient is imprecise on TPU. 
From 47dde87b9d734baf4b9f58f896305cdba0b9f484 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 19 Mar 2025 15:53:08 -0700 Subject: [PATCH 0035/1769] Use np.ones to avoid signed integer overflow at run time PiperOrigin-RevId: 738569856 --- tests/pjit_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 293b37a9fbc7..6cf11494988a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6360,8 +6360,8 @@ def f(x): def test_intermediate_einsum(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) s = NamedSharding(mesh, P('data')) arr1 = jax.device_put(np_inp1, s) @@ -6387,9 +6387,9 @@ def test_intermediate_einsum_auto_complete_spec(self, mesh): shape1 = (8, 32, 2*16) shape2 = (8, 32, 2, 8) shape3 = (8, 32, 2, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) - np_inp3 = np.arange(math.prod(shape3)).reshape(shape3) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) + np_inp3 = np.ones(math.prod(shape3)).reshape(shape3) arr1 = jax.device_put(np_inp1, s) arr2 = jax.device_put(np_inp2, s) @@ -6436,8 +6436,8 @@ def f(condition, x, y): def test_intermediate_einsum_conflict_error(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) arr1 = jax.device_put( np_inp1, NamedSharding(mesh, P(None, None, None, 'data'))) From ab42a3e6382a0e2eedcc63176f36ec0d48f617bf Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 12 Mar 2025 15:09:57 +0200 Subject: [PATCH 0036/1769] Fix betainc edge cases and inaccuracies when a is close to zero. 
--- jax/_src/lax/special.py | 45 ++++++++++++++---- jax/_src/test_util.py | 2 +- tests/lax_scipy_special_functions_test.py | 56 +++++++++++++++++++---- 3 files changed, 84 insertions(+), 19 deletions(-) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index b70513bc2d20..ba2687d4acd7 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -194,12 +194,18 @@ def nth_partial_betainc_numerator(iteration, a, b, x): iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1)) iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1) m = iteration_minus_one // full_like(iteration_minus_one, 2) + m_is_zero = eq(m, full_like(m, 0)) m = convert_element_type(m, dtype) one = full_like(a, 1) two = full_like(a, 2.0) # Partial numerator terms - even_numerator = -(a + m) * (a + b + m) * x / ( - (a + two * m) * (a + two * m + one)) + + # When a is close to zero and m == 0, using zero_numerator avoids + # inaccuracies when FTZ or DAZ is enabled: + zero_numerator = -(a + b) * x / (a + one) + even_numerator = select(m_is_zero, zero_numerator, + -(a + m) * (a + b + m) * x / ( + (a + two * m) * (a + two * m + one))) odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m)) one_numerator = full_like(x, 1.0) numerator = select(iteration_is_even, even_numerator, odd_numerator) @@ -210,12 +216,24 @@ def nth_partial_betainc_denominator(iteration, a, b, x): return select(eq(iteration_bcast, full_like(iteration_bcast, 0)), full_like(x, 0), full_like(x, 1)) + a_is_zero = bitwise_or(eq(a, full_like(a, 0)), eq(b, full_like(b, float('inf')))) + b_is_zero = bitwise_or(eq(b, full_like(b, 0)), eq(a, full_like(a, float('inf')))) + x_is_zero = eq(x, full_like(x, 0)) + x_is_one = eq(x, full_like(x, 1)) + x_is_not_zero = bitwise_not(x_is_zero) + x_is_not_one = bitwise_not(x_is_one) + is_nan = bitwise_or(bitwise_or(_isnan(a), _isnan(b)), _isnan(x)) + + result_is_zero = bitwise_or(bitwise_and(b_is_zero, x_is_not_one), bitwise_and(a_is_zero, x_is_zero)) + result_is_one = bitwise_or(bitwise_and(a_is_zero, x_is_not_zero), bitwise_and(b_is_zero, x_is_one)) + result_is_nan = bitwise_or(bitwise_or(bitwise_or( - le(a, full_like(a, 0)), le(b, full_like(b, 0))), + lt(a, full_like(a, 0)), lt(b, full_like(b, 0))), lt(x, full_like(x, 0))), gt(x, full_like(x, 1))) + result_is_nan = bitwise_or(result_is_nan, bitwise_or(bitwise_and(a_is_zero, b_is_zero), is_nan)) - # The continued fraction will converge rapidly when x < (a+1)/(a+b+2) - # as per: http://dlmf.nist.gov/8.17.E23 + # The continued fraction will converge rapidly when x < + # (a+1)/(a+b+2) as per: http://dlmf.nist.gov/8.17.E23. # # Otherwise, we can rewrite using the symmetry relation as per: # http://dlmf.nist.gov/8.17.E4 @@ -234,10 +252,21 @@ def nth_partial_betainc_denominator(iteration, a, b, x): inputs=[a, b, x] ) - lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b) - result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a + # For very small a and to avoid division by zero, we'll use + # a * gamma(a) = gamma(a + 1) -> 1 as a -> 0+. 
+ very_small = (dtypes.finfo(dtype).tiny * 2).astype(dtype) + lbeta_ab_small_a = lgamma(b) - lgamma(a + b) + lbeta_ab = lgamma(a) + lbeta_ab_small_a + factor = select(lt(a, full_like(a, very_small)), + exp(log1p(-x) * b - lbeta_ab_small_a), + exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a) + result = continued_fraction * factor + result = select(converges_rapidly, result, sub(full_like(result, 1), result)) + + result = select(result_is_zero, full_like(a, 0), result) + result = select(result_is_one, full_like(a, 1), result) result = select(result_is_nan, full_like(a, float('nan')), result) - return select(converges_rapidly, result, sub(full_like(result, 1), result)) + return result class IgammaMode(Enum): VALUE = 1 diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 0dc4fe641029..3a18d12e9d4b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1522,7 +1522,7 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, args = args_maker() lax_ans = lax_op(*args) numpy_ans = numpy_reference_op(*args) - self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, + self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol, canonicalize_dtypes=canonicalize_dtypes) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index f4e4e4f48213..4b3945a84453 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -288,35 +288,71 @@ def testExpiDisableJit(self): self.assertAllClose(result_jit, result_nojit) def testGammaIncBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol) def testGammaIncCBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. 
+ dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol) + def testBetaIncBoundaryValues(self): + dtype = jax.dtypes.canonicalize_dtype(float) + fi = jax.numpy.finfo(dtype) + nan = float('nan') + inf = float('inf') + tiny = fi.tiny + eps = fi.eps + if jtu.parse_version(scipy.__version__) >= (1, 16): + # TODO(pearu): enable tiny samples when a fix to scipy/scipy#22682 + # will be available + a_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + b_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + elif jtu.parse_version(scipy.__version__) >= (1, 12): + # disabled samples that contradict with scipy/scipy#22425 + a_samples = [nan, -0.5, 0.5] + b_samples = [nan, -0.5, 0.5] + else: + a_samples = [-0.5, 0.5] + b_samples = [-0.5, 0.5] + x_samples = [nan, -0.5, 0, 0.5, 1, 1.5] + + a_samples = np.array(a_samples, dtype=dtype) + b_samples = np.array(b_samples, dtype=dtype) + x_samples = np.array(x_samples, dtype=dtype) + + args_maker = lambda: np.meshgrid(a_samples, b_samples, x_samples) + + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 2562da7026ccd930e5f0972598c7d5479175b787 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 18:26:36 -0700 Subject: [PATCH 0037/1769] Expose profiler_data submodule from XLA to Jaxlib. 
PiperOrigin-RevId: 738613439
---
 jaxlib/setup.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/jaxlib/setup.py b/jaxlib/setup.py
index b3a37a25f1b2..60f17a987307 100644
--- a/jaxlib/setup.py
+++ b/jaxlib/setup.py
@@ -68,10 +68,10 @@ def has_ext_modules(self):
     url='https://github.com/jax-ml/jax',
     license='Apache-2.0',
     classifiers=[
-        "Programming Language :: Python :: 3.10",
-        "Programming Language :: Python :: 3.11",
-        "Programming Language :: Python :: 3.12",
-        "Programming Language :: Python :: 3.13",
+        'Programming Language :: Python :: 3.10',
+        'Programming Language :: Python :: 3.11',
+        'Programming Language :: Python :: 3.12',
+        'Programming Language :: Python :: 3.13',
     ],
     package_data={
         'jaxlib': [
@@ -105,7 +105,7 @@ def has_ext_modules(self):
             'triton/*.so',
             'include/xla/ffi/api/*.h',
         ],
-        'jaxlib.xla_extension': ['*.pyi'],
+        'jaxlib.xla_extension': ['*.pyi', 'profiler/*.pyi'],
     },
     zip_safe=False,
     distclass=BinaryDistribution,

From b5c467e6cf702160be29ee93084f3f9a0da2b888 Mon Sep 17 00:00:00 2001
From: carlosgmartin
Date: Wed, 19 Mar 2025 23:56:24 -0400
Subject: [PATCH 0038/1769] Fix doc for random.categorical replace argument.

---
 jax/_src/random.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/jax/_src/random.py b/jax/_src/random.py
index 094268c65825..c0663dc67f80 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -1568,8 +1568,8 @@ def categorical(
     shape: Optional, a tuple of nonnegative integers representing the result shape.
       Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
       The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
-    replace: If True, perform sampling without replacement. Default (False) is to
-      perform sampling with replacement.
+    replace: If True (default), perform sampling with replacement. If False, perform
+      sampling without replacement.

   Returns:
     A random array with int dtype and shape given by ``shape`` if ``shape``

From 258ed1b0a5bd56b797d2ca47627db539c1be81f8 Mon Sep 17 00:00:00 2001
From: Yunlong Liu
Date: Thu, 20 Mar 2025 04:03:11 +0000
Subject: [PATCH 0039/1769] Fixes stream annotation for compute-on boxes.

---
 jax/_src/interpreters/mlir.py |  5 ++++-
 tests/memories_test.py        | 22 +++++++++++++++++-----
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 1369f72ac74c..4a723ffe5227 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -2340,7 +2340,10 @@ def wrap_compute_type_in_place(ctx, op):
   if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None:
     if ctx.jaxpr_eqn_ctx.compute_type.startswith("gpu_stream:"):
       stream = ctx.jaxpr_eqn_ctx.compute_type.split(":")[1]
-      dict_attr = {"_xla_stream_annotation": ir.StringAttr.get(stream)}
+      dict_attr = {
+          "_xla_stream_annotation": ir.StringAttr.get(stream),
+          "inlineable": ir.StringAttr.get("false"),
+      }
       op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
     else:
       dict_attr = {"_xla_compute_type": ir.StringAttr.get(
diff --git a/tests/memories_test.py b/tests/memories_test.py
index 0ca973c4d221..bdb88b418697 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -1670,24 +1670,36 @@ def test_stream_annotation_inside_shmap(self):
     arr1 = jax.device_put(np_inp, s)
     arr2 = jax.device_put(np_inp, s)

+    # Makes sure the compute wrapped here is fusible.
+    # This is a workaround for limitations in XLA.
+    # 1) Compute-on boxes that contain a single instruction cannot work.
+    # 2) Compute-on boxes that contain a tiny matmul cannot work.
     @compute_on('gpu_stream:1')
     @jax.jit
     def g(x, y):
-      return x * y
+      return x * y + x

     @compute_on('gpu_stream:2')
     @jax.jit
     def h(x, y):
-      return x * y
+      return x * y + x

     def f(x, y):
       z = g(x, y)
       w = h(3 * x, 2 * y)
       return z + w

-    out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
-                            out_specs=P('x')))(arr1, arr2)
-    self.assertArraysEqual(out, arr1 * 7)
+    compiled_f = jax.jit(
+        shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
+                  out_specs=P('x'))).lower(arr1, arr2).compile(
+                      {"xla_gpu_experimental_stream_annotation": True}
+    )
+    compiled_text = compiled_f.as_text()
+    self.assertIn('call-start', compiled_text)
+    self.assertIn('_xla_stream_annotation="1"', compiled_text)
+    self.assertIn('call-start.1', compiled_f.as_text())
+    self.assertIn('_xla_stream_annotation="2"', compiled_text)
+    self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 11)


 class ActivationOffloadingTest(jtu.JaxTestCase):

From e0c093314d8d9a6f68953f0c340c1b01d50ce386 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Wed, 19 Mar 2025 21:24:30 -0700
Subject: [PATCH 0040/1769] Remove ; in code blocks of `thinking_in_jax.md`

PiperOrigin-RevId: 738656531
---
 docs/notebooks/thinking_in_jax.ipynb | 4 ++--
 docs/notebooks/thinking_in_jax.md    | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb
index 5ddcdd32e2b4..560e0500ad13 100644
--- a/docs/notebooks/thinking_in_jax.ipynb
+++ b/docs/notebooks/thinking_in_jax.ipynb
@@ -57,7 +57,7 @@
     "\n",
     "x_np = np.linspace(0, 10, 1000)\n",
     "y_np = 2 * np.sin(x_np) * np.cos(x_np)\n",
-    "plt.plot(x_np, y_np);"
+    "plt.plot(x_np, y_np)"
   ]
  },
 {
@@ -91,7 +91,7 @@
     "\n",
     "x_jnp = jnp.linspace(0, 10, 1000)\n",
     "y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)\n",
-    "plt.plot(x_jnp, y_jnp);"
+    "plt.plot(x_jnp, y_jnp)"
   ]
 },
 {
diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md
index 0693f6ba8579..b107f78635f6 100644
--- a/docs/notebooks/thinking_in_jax.md
+++ b/docs/notebooks/thinking_in_jax.md
@@ -42,7 +42,7 @@ import numpy as np

 x_np = np.linspace(0, 10, 1000)
 y_np = 2 * np.sin(x_np) * np.cos(x_np)
-plt.plot(x_np, y_np);
+plt.plot(x_np, y_np)
 ```

 ```{code-cell} ipython3
@@ -53,7 +53,7 @@ import jax.numpy as jnp

 x_jnp = jnp.linspace(0, 10, 1000)
 y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
-plt.plot(x_jnp, y_jnp);
+plt.plot(x_jnp, y_jnp)
 ```

 +++ {"id": "kTZcsCJiuPG8"}

From 4da751a97a2a7837e977ecde77cd4ba0a05cfda5 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Wed, 19 Mar 2025 21:50:36 -0700
Subject: [PATCH 0041/1769] Reverts e0c093314d8d9a6f68953f0c340c1b01d50ce386

PiperOrigin-RevId: 738662342
---
 docs/notebooks/thinking_in_jax.ipynb | 4 ++--
 docs/notebooks/thinking_in_jax.md    | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb
index 560e0500ad13..5ddcdd32e2b4 100644
--- a/docs/notebooks/thinking_in_jax.ipynb
+++ b/docs/notebooks/thinking_in_jax.ipynb
@@ -57,7 +57,7 @@
     "\n",
     "x_np = np.linspace(0, 10, 1000)\n",
     "y_np = 2 * np.sin(x_np) * np.cos(x_np)\n",
-    "plt.plot(x_np, y_np)"
+    "plt.plot(x_np, y_np);"
   ]
 },
 {
@@ -91,7 +91,7 @@
     "\n",
     "x_jnp = jnp.linspace(0, 10, 1000)\n",
     "y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)\n",
-    "plt.plot(x_jnp, y_jnp)"
+    "plt.plot(x_jnp, y_jnp);"
   ]
 },
 {
diff --git a/docs/notebooks/thinking_in_jax.md
b/docs/notebooks/thinking_in_jax.md index b107f78635f6..0693f6ba8579 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -42,7 +42,7 @@ import numpy as np x_np = np.linspace(0, 10, 1000) y_np = 2 * np.sin(x_np) * np.cos(x_np) -plt.plot(x_np, y_np) +plt.plot(x_np, y_np); ``` ```{code-cell} ipython3 @@ -53,7 +53,7 @@ import jax.numpy as jnp x_jnp = jnp.linspace(0, 10, 1000) y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp) -plt.plot(x_jnp, y_jnp) +plt.plot(x_jnp, y_jnp); ``` +++ {"id": "kTZcsCJiuPG8"} From 58ba4106c33752856738bbf5f22cd16854eb2b22 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 21:51:15 -0700 Subject: [PATCH 0042/1769] [mosaic_gpu] Check for dropped activity records in cupti profiler. PiperOrigin-RevId: 738662559 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index a726acd4d662..f91018cf7287 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -195,6 +195,12 @@ void callback_complete(CUcontext context, uint32_t streamId, THROW_IF_CUPTI_ERROR(status); } } + + size_t num_dropped; + THROW_IF_CUPTI_ERROR( + cuptiActivityGetNumDroppedRecords(context, streamId, &num_dropped), + "failed to get number of dropped activity records"); + THROW_IF(num_dropped > 0, "activity records were dropped"); } NB_MODULE(_mosaic_gpu_ext, m) { From 509c65895dd3e6011e08269f2b1c61ba620ce7ed Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 23:41:08 -0700 Subject: [PATCH 0043/1769] [mosaic_gpu] Make cupti finalization optional. cupti initialization / finalization is somewhat expensive. This gives us the option of avoiding repeated initialization when performing multiple cupti timings. Disable kernel activity to ensure we've restored cupti to its original state. 
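
A sketch of the intended call pattern (the import path and the names `f`,
`args`, and `num_runs` are illustrative assumptions, not part of this change):

    import jax
    from jax._src.lib import mosaic_gpu as mosaic_gpu_lib

    ext = mosaic_gpu_lib._mosaic_gpu_ext
    for i in range(num_runs):
        ext._cupti_init()
        jax.block_until_ready(f(*args))  # the computation under measurement
        # Keep CUPTI attached between runs; detach only on the last one.
        timings = ext._cupti_get_timings(finalize=(i == num_runs - 1))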
PiperOrigin-RevId: 738685851 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index f91018cf7287..2c7242b6e6c0 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -245,15 +245,23 @@ NB_MODULE(_mosaic_gpu_ext, m) { cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), "failed to enable tracking of kernel activity by CUPTI"); }); - m.def("_cupti_get_timings", []() { - THROW_IF_CUPTI_ERROR( - cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), - "failed to flush CUPTI activity buffers"); - THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); - THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), - "failed to unsubscribe from CUPTI"); - return profiler_state.timings; - }); + m.def( + "_cupti_get_timings", + [](bool finalize) { + THROW_IF_CUPTI_ERROR( + cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), + "failed to disable tracking of kernel activity by CUPTI"); + THROW_IF_CUPTI_ERROR( + cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), + "failed to flush CUPTI activity buffers"); + if (finalize) { + THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); + } + THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), + "failed to unsubscribe from CUPTI"); + return profiler_state.timings; + }, + nb::arg("finalize") = true); } } // namespace From 6e204171f53d08ead0238f25e647ad4a41367c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 19 Mar 2025 23:47:35 -0700 Subject: [PATCH 0044/1769] [Mosaic:TPU] Add overload to ComputeTileStrides that just takes a shape. PiperOrigin-RevId: 738687016 --- jaxlib/mosaic/dialect/tpu/util.cc | 12 ++++++------ jaxlib/mosaic/dialect/tpu/util.h | 9 ++++++++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 651cef85f740..44cc301d6f0d 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -45,18 +45,18 @@ std::ostream &operator<<(std::ostream &os, Print p) { return os; } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling) { - SmallVector tile_strides(memref_ty.getRank()); + SmallVector tile_strides(shape.size()); int64_t stride = 1; - for (int64_t i = 0; i < memref_ty.getRank(); ++i) { - int64_t idx = memref_ty.getRank() - 1 - i; + for (int64_t i = 0; i < shape.size(); ++i) { + int64_t idx = shape.size() - 1 - i; int64_t tiling_idx = tiling.size() - 1 - i; tile_strides[idx] = stride; if (tiling_idx >= 0) { - stride *= llvm::divideCeil(memref_ty.getShape()[idx], tiling[tiling_idx]); + stride *= llvm::divideCeil(shape[idx], tiling[tiling_idx]); } else { - stride *= memref_ty.getShape()[idx]; + stride *= shape[idx]; } } return tile_strides; diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 2e19cb820b5b..f9ab1b7e349d 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -192,8 +192,15 @@ std::string shapeToString(const T &shape) { return os.str(); } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling); + +inline SmallVector ComputeTileStrides( + MemRefType memref_ty, absl::Span tiling) { + absl::Span shape(memref_ty.getShape().data(), + 
memref_ty.getShape().size());
+  return ComputeTileStrides(shape, tiling);
+}
 // Assuming MKN matmul - This function must only be called after
 // canonicalization passes.
 //

From 2d43fb473001af5f3a779b7d9ce7f54bf0ae1fe6 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Thu, 20 Mar 2025 02:07:33 -0700
Subject: [PATCH 0045/1769] [Mosaic GPU] Introduce an optimization barrier op.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Also add layout inference and lowering rules for it. Its initial use case
will be to fence WGMMA accumulator registers. As a result, transform
inference is not immediately useful for this op, and we omit it here.

PiperOrigin-RevId: 738718000
---
 .../mosaic/gpu/dialect_lowering.py        | 33 ++++++++++++
 .../mosaic/gpu/layout_inference.py        | 51 ++++++++++++++++++-
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.td   | 29 +++++++++++
 tests/mosaic/gpu_layout_inference_test.py | 51 +++++++++++++++++++
 4 files changed, 162 insertions(+), 2 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index ae702d50ebb7..936bba73915b 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -35,6 +35,7 @@
 from jax._src.lib.mlir.dialects import nvvm
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
+from jax._src.util import safe_zip
 from jax.experimental.mosaic.gpu import layouts as layouts_lib
 import numpy as np

@@ -203,6 +204,38 @@ def _initialize_barrier_op_lowering_rule(
       barrier_base_ptr, initialize_barrier_op.barriers_ref.type),


+# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3.
+OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None)
+
+
+@_register_lowering(OptimizationBarrierOp)
+def _optimization_barrier_op_lowering_rule(
+    _: LoweringContext,
+    op: OptimizationBarrierOp,
+) -> Sequence[ir.Value]:
+  if not all(ir.VectorType.isinstance(operand.type) for operand in op.operands):
+    raise NotImplementedError(
+        f"Optimization barrier op {op} has non-vector operands."
+ ) + + fragmented_arrays = [] + for operand, layout in safe_zip(op.operands, inference_utils.in_layouts(op)): + ty = ir.VectorType(operand.type) + is_signed = False if ir.IntegerType.isinstance(ty.element_type) else None + fragmented_arrays.append( + _fragmented_array_from_ir(operand, layout, is_signed=is_signed) + ) + + lowered_fragmented_arrays = fa.optimization_barrier(*fragmented_arrays) + if isinstance(lowered_fragmented_arrays, fa.FragmentedArray): + lowered_fragmented_arrays = [lowered_fragmented_arrays] + + return [ + _fragmented_array_to_ir(arr, result.type) + for arr, result in safe_zip(lowered_fragmented_arrays, op.results) + ] + + @_register_lowering(arith.ConstantOp) def _arith_constant_op_lowering_rule( _: LoweringContext, op: arith.ConstantOp diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 470b0d328d8e..dec75e4db1a0 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -44,7 +44,9 @@ def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule): - _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + if op is not None: + _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + return rule def _set_layout_attributes( @@ -192,7 +194,7 @@ def is_array(v: ir.Value) -> bool: # This is left for a future change, and currently we only do "down # propagation". layout = _choose_representative_layout(layouts) - # It is unsafe to t conclude that this op produces a splat if not all inputs + # It is unsafe to conclude that this op produces a splat if not all inputs # have been inferred: some of them might turn out not to be splats! if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout: return None @@ -247,6 +249,51 @@ def is_array(v: ir.Value) -> bool: _add_layout_inference_rule(op, _infer_pointwise_op_layouts) +# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3. +OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None) + + +@partial(_add_layout_inference_rule, OptimizationBarrierOp) +def _infer_optimization_barrier_op_layout( + op: OptimizationBarrierOp, +) -> OptionalLayouts: + def is_array(v: ir.Value) -> bool: + return ir.VectorType.isinstance(v.type) + + if inference_utils.has_in_layouts_set(op): + op_in_layouts = list(inference_utils.in_layouts(op)) + return op_in_layouts, op_in_layouts + + if inference_utils.has_out_layouts_set(op): + op_out_layouts = list(inference_utils.out_layouts(op)) + return op_out_layouts, op_out_layouts + + layouts = [None] * len(op.operands) + for i, operand in enumerate(filter(is_array, op.operands)): + layouts[i] = inference_utils.value_layout(operand) + + for i, result in enumerate(filter(is_array, op.results)): + possible_layouts = set() + for op_operand_use in cast(ir.OpResult, result).uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + layout = inference_utils.in_layout_for_operand(consumer, op_user) + if layout is not None: + possible_layouts.add(layout) + if possible_layouts and layouts[i] is None: + # TODO(bchetioui): we could actually just pick any user layout here, + # and optimize later. This is fine for now. + layouts[i] = _choose_representative_layout(possible_layouts) + + # TODO(bchetioui): handle annotating layout for only certain operands. 
+ # Otherwise, layouts may not get propagated through optimization barriers, if + # a single branch does not carry any forcing layout, which is pretty bad. + if any(layout is None for layout in layouts): + return None + + return layouts, layouts + + @partial(_add_layout_inference_rule, arith.ConstantOp) def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: if not ir.VectorType.isinstance(constant_op.result.type): diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 108ff952b571..f0a37084b759 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -457,4 +457,33 @@ def MosaicGPU_WGMMAOp : Op { let hasVerifier = 1; } +def MosaicGPU_OptimizationBarrierOp : Op { + let summary = "Prevents MLIR from moving operations across the barrier."; + + let arguments = (ins + Variadic:$operands + ); + let results = (outs Variadic); + + let extraClassDeclaration = [{ + static llvm::LogicalResult inferReturnTypes( + mlir::MLIRContext *, + std::optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + if (operands.empty()) { + return ::mlir::emitOptionalError( + location, "expected non-empty operands"); + } + ::mlir::TypeRange operand_types = operands.getTypes(); + inferredReturnTypes.assign(operand_types.begin(), operand_types.end()); + return ::mlir::success(); + } + }]; +} + #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 36c8ff9cf47e..893e21efc6d0 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -442,6 +442,57 @@ def body(lhs, rhs): self.assertNotIn("in_layouts", f.attributes) self.assertNotIn("out_layouts", f.attributes) + def test_optimization_barrier_op_propagates_user_layouts(self): + add = optimization_barrier = None + + def body(lhs, rhs): + nonlocal add, optimization_barrier + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([lhs, rhs]) + lhs, rhs = optimization_barrier.results + add = arith.AddFOp(lhs, rhs) + + with ir.InsertionPoint(self.module.body): + shape = (32, 4) + ty = ir.VectorType.get(shape, ir.BF16Type.get()) + func.FuncOp.from_py_func(ty, ty)(body) + + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + optimization_barrier.attributes["in_layouts"], + [splat_layout, splat_layout], + ) + self.assertSequenceEqual( + optimization_barrier.attributes["out_layouts"], + [splat_layout, splat_layout], + ) + + def test_optimization_barrier_op_propagates_producer_layouts(self): + add = optimization_barrier = None + + def body(lhs, rhs): + nonlocal add, optimization_barrier + add = arith.AddFOp(lhs, rhs) + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([add]) + + with ir.InsertionPoint(self.module.body): + shape = (32, 4) + ty = ir.VectorType.get(shape, ir.BF16Type.get()) + func.FuncOp.from_py_func(ty, ty)(body) + + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + optimization_barrier.attributes["in_layouts"], [splat_layout] + ) + 
self.assertSequenceEqual( + optimization_barrier.attributes["out_layouts"], [splat_layout] + ) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 18326abea65df2efee72cf909d9f4ae910df4f76 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 20 Mar 2025 02:38:54 -0700 Subject: [PATCH 0046/1769] [mosaic_gpu] Don't time the warmup step in cupti profiler. Initializing and finalizing cupti has an overhead. PiperOrigin-RevId: 738725435 --- jax/experimental/mosaic/gpu/profiler.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 99fefc1adc9c..32b3edf7caf9 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -102,23 +102,21 @@ def _measure_cupti(f, aggregate): if not isinstance(f, (stages.Wrapped, stages.Compiled)): f = jax.jit(f) - def run(*args, **kwargs): + def wrapper(*args, **kwargs): + jax.block_until_ready(f(*args, **kwargs)) # Warmup. mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() try: results = jax.block_until_ready(f(*args, **kwargs)) finally: timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() - return results, timings - def wrapper(*args, **kwargs): - run(*args, **kwargs) # Warmup. - results, timings = run(*args, **kwargs) if not timings: return results, None elif aggregate: return results, sum(item[1] for item in timings) else: return results, timings + return wrapper From e2b6859e7d3e5c0c01be9013d6cb680ab647d9a4 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 20 Mar 2025 02:51:56 -0700 Subject: [PATCH 0047/1769] Deprecate the jaxlib.hlo_helpers submodule. jaxlib no longer includes any lowering logic, so we don't need this module anymore. Users would be better served by the APIs in JAX core like `jax.ffi` or `jax.interpreters.mlir`. This module isn't covered by JAX's compatibility policy, so no formal deprecation period is required, but there are enough users that we should keep this warning for at least one full release cycle. PiperOrigin-RevId: 738728721 --- jax/_src/lib/__init__.py | 1 - jaxlib/hlo_helpers.py | 11 +++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 7933bb769733..70dc914668cf 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -109,7 +109,6 @@ def _xla_gc_callback(*args): import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401 -import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401 # Jaxlib code is split between the Jax and the Tensorflow repositories. # Only for the internal usage of the JAX developers, we expose a version diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 0d57a04f1aa7..11ff844ae53f 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -19,11 +19,22 @@ from collections.abc import Callable, Sequence from functools import partial from typing import Union +import warnings import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np +# TODO(danfm): This module isn't covered by JAX's compatibility policy, so no +# formal deprecation period is required, but there are enough users that we +# should keep this warning for at least one full release cycle. 
+# should keep this warning for at least one full release cycle.
+# Deprecation added 2025-03-19 after the release of v0.5.3. Remove this whole
+# module after the release of v0.5.4 or later.
+warnings.warn(
+    "The jaxlib.hlo_helpers submodule is deprecated. Instead, use jax.ffi if "
+    "possible or, for lower-level operations, jax.interpreters.mlir.",
+    DeprecationWarning,
+)

 _dtype_to_ir_type_factory : dict[np.dtype, Callable[[], ir.Type]] = {
   np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),

From f1298ae7f11464e697ebec66e75046e90ae739e1 Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Thu, 20 Mar 2025 03:01:50 -0700
Subject: [PATCH 0048/1769] Remove XLA FFI GPU callback handler.

- In order to migrate the GPU FFI handler from the internal API intended for
  static linking to the external API intended for dynamic linking, we need to
  migrate both CPU and GPU FFI handlers at the same time.
- Builds break if we include both versions of the FFI APIs.
- Now that py_client_gpu sits in jaxlib, tests that run the new FFI API in
  jaxlib against the old FFI API in XLA (and vice versa) for GPU targets will
  fail.
- This change lets us update the CPU handler first in XLA and then update the
  GPU handler in jaxlib.
- Because the GPU handler depends on new symbols in xla, we need to land XLA
  changes first anyway (i.e., no point in deleting both CPU and GPU handlers
  to try to land jaxlib and xla in one go).

PiperOrigin-RevId: 738730955
---
 jaxlib/cuda/BUILD           |  12 ----
 jaxlib/gpu/py_client_gpu.cc | 136 ------------------------------------
 jaxlib/gpu/py_client_gpu.h  |   3 -
 jaxlib/rocm/BUILD           |  12 ----
 4 files changed, 163 deletions(-)

diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index 4e74cc2dcf5b..ee32888864dd 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -668,34 +668,22 @@ cc_library(
     features = ["-use_header_modules"],
     deps = [
         ":cuda_vendor",
-        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
-        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:Support",
         "@nanobind",
         "@xla//xla:comparison_util",
-        "@xla//xla:shape_util",
-        "@xla//xla/ffi",
-        "@xla//xla/ffi:ffi_api",
         "@xla//xla/pjrt:exceptions",
         "@xla//xla/pjrt:host_callback",
         "@xla//xla/pjrt:transpose",
         "@xla//xla/python:callback",
         "@xla//xla/python:nb_numpy",
-        "@xla//xla/python:py_host_callback",
-        "@xla//xla/python:types",
-        "@xla//xla/python/ifrt",
         "@xla//xla/service:custom_call_status",
         "@xla//xla/service:custom_call_target_registry",
         "@xla//xla/service:platform_util",
-        "@xla//xla/tsl/concurrency:ref_count",
-        "@xla//xla/tsl/platform:errors",
-        "@xla//xla/tsl/platform:statusor",
     ],
 )

diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc
index d6faa1859eb8..cf701574959b 100644
--- a/jaxlib/gpu/py_client_gpu.cc
+++ b/jaxlib/gpu/py_client_gpu.cc
@@ -21,9 +21,7 @@ limitations under the License.

 #include
 #include "nanobind/nanobind.h"
-#include "absl/algorithm/container.h"
 #include "absl/base/casts.h"
-#include "absl/container/inlined_vector.h"
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
#include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/include/llvm/Support/Casting.h" #include "jaxlib/gpu/vendor.h" -#include "xla/ffi/ffi.h" -#include "xla/ffi/ffi_api.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" -#include "xla/python/ifrt/host_callback.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_host_callback.h" -#include "xla/python/types.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/shape_util.h" -#include "xla/tsl/concurrency/ref_count.h" -#include "xla/tsl/platform/statusor.h" namespace nb = nanobind; @@ -166,130 +155,5 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "xla_python_gpu_callback", &XlaPythonGpuCallback, absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME)); - -absl::Status XlaFfiPythonGpuCallback( - gpuStream_t stream, - std::vector>* callbacks, - uint64_t index, xla::ffi::RemainingArgs args, - xla::ffi::RemainingRets rets) { - auto loaded_callback = llvm::dyn_cast_or_null( - callbacks->at(index).get()); - if (loaded_callback == nullptr) { - return absl::InternalError( - "Expected a PyCpuLoadedHostCallback, got something else."); - } - xla::CpuCallback* callback = loaded_callback->cpu_callback(); - size_t arity = args.size(); - std::vector host_input_buffers(arity); - // Copy input GPU buffers to host - for (size_t i = 0; i < arity; ++i) { - auto arg = args.get(i); - if (arg->element_type() == xla::TOKEN) { - host_input_buffers[i] = nullptr; - continue; - } - void* buf = new char[arg->size_bytes()]; - host_input_buffers[i] = buf; - // TODO(b/238441608): Use pinned memory here to speed up the transfer. - auto gpu_res = - gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), - gpuMemcpyDeviceToHost, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } - CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) - << "Failed to gpuStreamSynchronize"; - nb::gil_scoped_acquire gil; - nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); - for (size_t i = 0; i < arity; ++i) { - auto arg = args.get(i); - xla::PrimitiveType ptype = arg->element_type(); - if (ptype == xla::TOKEN) { - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); - } else { - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); - TF_ASSIGN_OR_RETURN(auto dtype, xla::PrimitiveTypeToNbDtype(ptype)); - auto array = xla::nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt, - host_input_buffers[i], base); - array.attr("flags").attr("writeable") = nb::bool_(false); - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); - } - } - - xla::EnterHostCallback(); - // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows - // you to avoid constructing a tuple for the arguments. 
- absl::StatusOr maybe_result_tuple = - callback->FfiCall(host_input_arrays); - xla::LeaveHostCallback(); - TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); - - std::vector temp_buffers; - for (size_t i = 0; i < rets.size(); ++i) { - auto ret = rets.get(i).value(); - auto ptype = ret->element_type(); - if (ptype == xla::TOKEN) continue; - nb::object output = - nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); - xla::nb_numpy_ndarray array = - xla::nb_numpy_ndarray::ensure(std::move(output)); - absl::Span strides( - reinterpret_cast(array.strides()), array.ndim()); - // We expect the output to be in default numpy layout. - TF_ASSIGN_OR_RETURN(auto expected_shape, xla::ShapeUtil::MakeValidatedShape( - ptype, ret->dimensions())); - auto expected_strides = ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } else { - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - absl::Span dims( - reinterpret_cast(array.shape()), array.ndim()); - options.dims = dims; - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.rank()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - TF_ASSIGN_OR_RETURN(auto plan, - callback->transpose_cache().GetOrCreate(options)); - plan->Execute(array.data(), temp); - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } - } - nb::gil_scoped_release release; - CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) - << "Failed to gpuStreamSynchronize"; - for (int i = 0; i < temp_buffers.size(); ++i) { - delete[] static_cast(temp_buffers[i]); - } - return absl::OkStatus(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, - xla::ffi::Ffi::Bind() - .Ctx>() - .Ctx>>>() - .Attr("index") - .RemainingArgs() - .RemainingRets()); -XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), - "xla_ffi_python_gpu_callback", - absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), - kXlaFfiPythonGpuCallback); - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index 6be2d40823dc..e9454504f5d9 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -19,7 +19,6 @@ limitations under the License. 
#include #include "jaxlib/gpu/vendor.h" -#include "xla/ffi/ffi.h" #include "xla/service/custom_call_status.h" namespace jax { @@ -29,8 +28,6 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); -XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 1e54d82c4f71..99df757018f3 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -566,34 +566,22 @@ cc_library( features = ["-use_header_modules"], deps = [ ":hip_vendor", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", "@nanobind", "@xla//xla:comparison_util", - "@xla//xla:shape_util", - "@xla//xla/ffi", - "@xla//xla/ffi:ffi_api", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_host_callback", - "@xla//xla/python:types", - "@xla//xla/python/ifrt", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:statusor", ], ) From 84cc397b4eab2825b9b3479995fd700a3e17f17f Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 20 Mar 2025 04:00:30 -0700 Subject: [PATCH 0049/1769] [XLA:GPU][Triton] Remove sparsity code. It's unused but causes significant burden during Triton integrates. PiperOrigin-RevId: 738744625 --- tests/BUILD | 14 --- tests/sparse_nm_test.py | 209 ---------------------------------------- 2 files changed, 223 deletions(-) delete mode 100644 tests/sparse_nm_test.py diff --git a/tests/BUILD b/tests/BUILD index 0ffa68ed8eb3..b126655b0a06 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1101,20 +1101,6 @@ jax_multiplatform_test( ] + py_deps("scipy"), ) -jax_multiplatform_test( - name = "sparse_nm_test", - srcs = ["sparse_nm_test.py"], - enable_backends = [], - enable_configs = [ - "gpu_a100", - "gpu_h100", - ], - deps = [ - "//jax:experimental_sparse", - "//jax:pallas_gpu", - ], -) - jax_multiplatform_test( name = "sparsify_test", srcs = ["sparsify_test.py"], diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py deleted file mode 100644 index 9ecf30eb6229..000000000000 --- a/tests/sparse_nm_test.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import numpy as np -from absl.testing import absltest -from absl.testing import parameterized - -import jax -import jax.numpy as jnp -from jax import dtypes -from jax._src import config -from jax._src import test_util as jtu -from jax.experimental.sparse import nm - -jax.config.parse_flags_with_absl() - - -class SpmmTest(jtu.JaxTestCase): - def setUp(self): - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") - super().setUp() - - # ----- Test different input shapes - @parameterized.product( - tile_m=(32, 128), - tile_n=(32, 128), - tile_k=(32, 128), - batch=(None, 5), - sparse_idx=(0, 1), - ) - @jtu.run_on_devices("gpu") - def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): - # Build keyword arguments - kwargs = { - "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), - "sparse_operand_idx": sparse_idx, - } - if batch: - kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,))) - - # Build input data - batch_dims = (batch,) if batch else tuple() - lhs = ( - (np.arange((batch or 1) * tile_m * tile_k) % 11) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_m, tile_k)) - ) - rhs = ( - (np.arange((batch or 1) * tile_n * tile_k) % 13) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_n, tile_k)) - ) - - # Build sparsity mask and metadata - sp = [lhs, rhs][sparse_idx] - mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape) - sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,)) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - if sparse_idx == 0: - dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs) - else: - dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask)) - - # Verify the result - jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16)) - - # ----- Test different input types - @parameterized.product( - lhs_type=[jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16], - rhs_type=[jnp.bfloat16], - output_type=[jnp.bfloat16, jnp.float32], - ) - @jtu.run_on_devices("gpu") - def test_types(self, lhs_type, rhs_type, output_type): - tile_m, tile_n, tile_k = 64, 32, 128 - - # Build input data - lhs = ( - (np.arange(tile_m * tile_k) % 17) - .astype(lhs_type) - .reshape((tile_m, tile_k)) - ) - rhs = ( - (np.arange(tile_k * tile_n) % 19) - .astype(rhs_type) - .reshape((tile_k, tile_n)) - ) - - # Build sparsity mask and metadata - mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape) - sparse = lhs[mask].reshape(tile_m, tile_k // 2) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type) - dot_dense = (lhs * mask) @ rhs - - # Verify the result - jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01) - - # ----- Test validation - @jtu.run_on_devices("gpu") - def test_validate_nm_pack(self): - with self.assertRaisesRegex(TypeError, "Mask should be bool"): - nm.nm_pack(jnp.zeros(16, jnp.int8)) - with self.assertRaisesRegex( - TypeError, "Inner dimension size should be divisible by 16" - ): - nm.nm_pack(jnp.array([False] * 8)) - - @jtu.run_on_devices("gpu") - def test_validate_nm_spmm(self): - batch, tile_m, tile_n, tile_k = 2, 64, 32, 128 - lhs = jnp.zeros((batch, tile_m, tile_k // 2), 
dtype=jnp.bfloat16) - rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16) - meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16) - - if config.enable_x64.value: - with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"): - nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta) - with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"): - nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta) - with self.assertRaisesRegex(TypeError, "Unsupported output type"): - nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64) - - # Check dimension numbers - nm_spmm_with_dnums = lambda c, b: nm.nm_spmm( - lhs, rhs, meta, dimension_numbers=(c, b) - ) - with self.assertRaisesRegex( - TypeError, "Only single contracting dimension is supported" - ): - nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for lhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,))) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for rhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,))) - with self.assertRaisesRegex( - TypeError, "Only single non-contracting dimension is supported" - ): - nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Batch dimension sizes do not match" - ): - nm.nm_spmm( - lhs, - rhs.reshape(1, tile_k, tile_n * batch), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - # Check metadata - nm_spmm_with_meta = lambda m: nm.nm_spmm( - lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,))) - ) - with self.assertRaisesRegex(TypeError, "Metadata must be uint16"): - nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8)) - with self.assertRaisesRegex( - TypeError, "Metadata shape must match the operand shape" - ): - nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16)) - with self.assertRaisesRegex( - TypeError, - "Metadata must be exactly 8 times less than the contracting dimension" - " for 2:4 structured sparsity", - ): - nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1)) - with self.assertRaisesRegex( - TypeError, "Contracting dimension must be the minor one" - ): - nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,)))) - with self.assertRaisesRegex( - TypeError, "Contracting dimension sizes should have 2:4 ratio" - ): - nm.nm_spmm( - lhs, - jnp.repeat(rhs, 2, axis=1), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) From 1c8e60e6c299bb6cc39f5d9a0d68df327c79da10 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 20 Mar 2025 04:30:49 -0700 Subject: [PATCH 0050/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d3145a119840723c16fd27ee342729d68fddb7ef. PiperOrigin-RevId: 738751933 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f81e3931b1dc..08c5af0c32b8 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "0d20d73f2c8f21c21b9f343c4363a76e980f032e" -XLA_SHA256 = "9df61c200b0a54b7a5c55155fa7a454e33d660e6a49239b6980f5a10305fecc5" +XLA_COMMIT = "d3145a119840723c16fd27ee342729d68fddb7ef" +XLA_SHA256 = "daf2a72e36a9358803a8156c48b32117c9699fd327fcbc37b465f1a0045bccae" def repo(): tf_http_archive( From 4d6f15f20c588fffd87ad1d610d92b636b194c5d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 20 Mar 2025 05:17:57 -0700 Subject: [PATCH 0051/1769] [Mosaic GPU] Add support for slicing tiled refs with (tile aligned) dynamic base offsets PiperOrigin-RevId: 738762062 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/core.py | 84 +++++++++++++++++++++++------- tests/pallas/mosaic_gpu_test.py | 28 ++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index e5b491aef330..ab35eebafc04 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -80,6 +80,7 @@ pytype_strict_library( "//jax:mosaic_gpu", "//jax:state_types", "//jax:tree_util", + "//jax/_src/lib", "//jax/_src/pallas", "//jaxlib/mlir:ir", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 5e4566ddfc9c..b1e0a683f64d 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -29,10 +29,11 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types -from jax._src.state import discharge as state_discharge import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp from jaxlib.mlir import ir @@ -135,6 +136,24 @@ def cmap_body(): return wrapper +def _is_known_divisible(value, divisor, fuel=10) -> bool: + """Returns True if the value is statically known to be divisible by the divisor.""" + if fuel < 0: + return False + if not isinstance(value.owner, ir.Operation): + return False + def_op = value.owner.opview + match def_op: + case arith_dialect.IndexCastOp(): + return _is_known_divisible(value.owner.operands[0], divisor, fuel - 1) + case arith_dialect.ConstantOp(): + return ir.IntegerAttr(def_op.value).value % divisor == 0 + case arith_dialect.MulIOp(): + return (_is_known_divisible(value.owner.operands[0], divisor, fuel // 2) or + _is_known_divisible(value.owner.operands[1], divisor, (fuel + 1)// 2)) + return False + + @dataclasses.dataclass(frozen=True) class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () @@ -171,7 +190,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: shape=self.to_gpu_transform().transform_shape(aval.shape) ) -Index = slice | int | ir.Value +Index = mgpu.DynamicSlice | slice | int | ir.Value @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -218,16 +237,37 @@ def untransform_index( ) -> tuple[tuple[Index, ...], state_types.Transform]: untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] - idxs_after_tiling = [] + idxs_after_tiling: list[Index] = [] for idx, tile in zip(tiled_idxs, self.tiling): - if not isinstance(idx, slice): - raise NotImplementedError("Non-slice indices are not supported") - assert isinstance(idx, slice) - if idx.step is not None and idx.step != 1: - raise 
NotImplementedError("Strided slices unsupported") - if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): - raise ValueError("Non-empty slices must be tile aligned") - idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + if isinstance(idx, slice): + if idx.step is not None and idx.step != 1: + raise NotImplementedError("Strided slices unsupported") + if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): + raise ValueError("Non-empty slices must be tile aligned") + idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + elif isinstance(idx, mgpu.DynamicSlice): + if idx.length % tile: + raise ValueError( + f"Dynamic slice length ({idx.length}) is not divisible by the" + f" tiling ({tile})" + ) + if isinstance(idx.base, ir.Value): + if not _is_known_divisible(idx.base, tile): + raise ValueError( + "Dynamic slice base index (which is a dynamic value) cannot be" + f" statically proven to be divisible by the tiling ({tile})" + ) + new_base = arith_dialect.divui(idx.base, mgpu.c(tile, idx.base.type)) + else: + if idx.base % tile: + raise ValueError( + f"Dynamic slice base ({idx.base}) is not divisible by the" + f" tiling ({tile})" + ) + new_base = idx.base // tile + idxs_after_tiling.append(mgpu.DynamicSlice(new_base, idx.length // tile)) + else: + raise TypeError(f"Unsupported index type: {type(idx)}") return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: @@ -285,7 +325,7 @@ def untransform_index( self, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: removed_dims = [ - i for i, idx in enumerate(idxs) if not isinstance(idx, slice) + i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) ] new_perm = tuple( p - sum(d < p for d in removed_dims) @@ -358,18 +398,22 @@ def untransform_index( ) -> tuple[tuple[Index, ...], state_types.Transform]: if not idxs: return idxs, self - if not all(isinstance(idx, slice) for idx in idxs[-2:]): + if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): raise NotImplementedError( "Non-slice indices are not supported in 2 minormost dims" ) last_idx = idxs[-1] - assert isinstance(last_idx, slice) - if last_idx.step is not None and last_idx.step != 1: - raise NotImplementedError("Swizzled dims cannot be sliced") - if (last_idx.start is not None and last_idx.start != 0) or ( - last_idx.stop is not None and last_idx.stop != self.swizzle - ): - raise ValueError("Swizzled dims cannot be sliced") + if isinstance(last_idx, mgpu.DynamicSlice): + if last_idx.base != 0 or last_idx.length != self.swizzle: + raise ValueError("Swizzled dims cannot be sliced") + else: + assert isinstance(last_idx, slice) + if ( + (last_idx.step is not None and last_idx.step != 1) + or (last_idx.start is not None and last_idx.start != 0) + or (last_idx.stop is not None and last_idx.stop != self.swizzle) + ): + raise ValueError("Swizzled dims cannot be sliced") return idxs, self diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 38335925b44d..40e98bf05ba9 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1132,6 +1132,34 @@ def kernel(x_ref, o_ref): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) + # Not testing with warpgroup semantics, because we want to enforce a layout. 
+ def test_tile_slicing(self): + shape = (256, 128) + block_spec = plgpu.GPUBlockSpec( + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ) + ) + @functools.partial( + pl.pallas_call, + in_specs=[block_spec], + out_specs=block_spec, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16), + ) + def kernel(x_ref, o_ref): + def sum_tiles(row, acc): + row_slice = pl.ds(row * 64, 64) + for col in range(128 // 64): + acc += x_ref[row_slice, pl.ds(col * 64, 64)] + return acc + acc = plgpu.layout_cast(jnp.zeros((64, 64), jnp.uint16), plgpu.Layout.WGMMA) + o_ref[...] = _fori_loop(False, 0, 256 // 64, sum_tiles, acc) + + x = jnp.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape) + y = x.reshape(256 // 64, 64, 128 // 64, 64).sum(axis=(0, 2), dtype=jnp.uint16) + np.testing.assert_array_equal(kernel(x), y) + def test_input_output_aliases(self): # Note that we're writing to the input pointer, which should alias b_ptr. def kernel(a_ref, b_ref): From 2c90fe2dea5a0ec5941d21973e33f4334d43ed0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 20 Mar 2025 06:12:22 -0700 Subject: [PATCH 0052/1769] Reorder C++ imports. PiperOrigin-RevId: 738774175 --- examples/jax_cpp/main.cc | 2 +- jaxlib/gpu/vendor.h | 2 +- jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc | 2 +- .../mlir/_mlir_libs/register_jax_dialects.cc | 3 +- jaxlib/mlir/_mlir_libs/triton_ext.cc | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 15 ++-- jaxlib/mosaic/dialect/gpu/mosaic_gpu.h | 4 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 38 ++++---- jaxlib/mosaic/dialect/tpu/array_util.cc | 4 +- jaxlib/mosaic/dialect/tpu/array_util.h | 2 +- jaxlib/mosaic/dialect/tpu/array_util_test.cc | 2 +- .../dialect/tpu/integrations/c/tpu_dialect.cc | 6 +- jaxlib/mosaic/dialect/tpu/layout.cc | 2 +- jaxlib/mosaic/dialect/tpu/layout.h | 2 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 8 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 7 +- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 18 ++-- .../tpu/transforms/apply_vector_layout.cc | 25 +++--- .../apply_vector_layout_extensions.h | 6 +- .../tpu/transforms/canonicalize_mosaic.cc | 33 ++++--- .../apply_vector_layout_extensions.cc | 4 +- .../infer_vector_layout_extensions.cc | 6 +- .../tpu/transforms/infer_memref_layout.cc | 2 +- .../tpu/transforms/infer_vector_layout.cc | 19 ++-- .../infer_vector_layout_extensions.h | 4 +- .../tpu/transforms/linalg_vectorization.cc | 52 +++++------ .../transforms/memory_space_specialization.cc | 10 +-- .../tpu/transforms/relayout_insertion.cc | 13 ++- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 10 +-- jaxlib/mosaic/dialect/tpu/transforms/serde.h | 10 +-- jaxlib/mosaic/dialect/tpu/util.cc | 16 ++-- jaxlib/mosaic/dialect/tpu/util.h | 10 +-- jaxlib/mosaic/dialect/tpu/vreg_util.cc | 20 ++--- jaxlib/mosaic/dialect/tpu/vreg_util.h | 12 +-- jaxlib/mosaic/dialect/tpu/vreg_util_test.cc | 28 +++--- jaxlib/mosaic/gpu/custom_call.cc | 86 +++++++++---------- jaxlib/mosaic/gpu/launch_lowering.cc | 46 +++++----- jaxlib/mosaic/gpu/passes.cc | 24 +++--- jaxlib/mosaic/gpu/serde.cc | 8 +- jaxlib/mosaic/gpu/serde.h | 12 +-- jaxlib/mosaic/gpu/target.cc | 4 +- jaxlib/mosaic/pass_boilerplate.h | 8 +- jaxlib/mosaic/serde.cc | 18 ++-- jaxlib/mosaic/serde.h | 10 +-- jaxlib/triton/triton_dialect_capi.cc | 12 +-- jaxlib/triton/triton_dialect_capi.h | 4 +- 46 files changed, 307 insertions(+), 324 deletions(-) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 0a1d3a63acfd..5d1190ff1f2c 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -41,7 
+41,7 @@ limitations under the License. #include #include -#include "third_party/absl/status/statusor.h" +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index cadd5453107a..58a02e7c568c 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -29,7 +29,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_fp8.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "cuda_runtime_api.h" #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolver_common.h" diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index c73084abc99d..7483d7ed1eea 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep -#include "nanobind/nanobind.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 9da841acc7de..64f84965b8e2 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -2,6 +2,7 @@ // This module is called by mlir/__init__.py during initialization. #include +#include "shardy/integrations/c/passes.h" #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" #include "mlir-c/Dialect/GPU.h" @@ -14,10 +15,8 @@ #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" - namespace nb = nanobind; #define REGISTER_DIALECT(name) \ diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 2a13c40d963f..e824d4058d7e 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "nanobind/nanobind.h" #include "jaxlib/triton/triton_dialect_capi.h" namespace nb = nanobind; diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index a1e7b571d20e..2358a97ba20d 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -18,6 +18,11 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" @@ -32,7 +37,9 @@ limitations under the License. 
#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -43,14 +50,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" -#include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" #include "tsl/platform/statusor.h" // Generated definitions. diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index b4f13c50bd8c..47b286aec302 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -28,8 +30,6 @@ limitations under the License. #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 527aa7c7ce25..c259da3e737c 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -25,25 +25,25 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "llvm/include/llvm/ADT/ArrayRef.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h" -#include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/Verifier.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" #include "tsl/platform/errors.h" namespace mosaic_gpu { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.cc b/jaxlib/mosaic/dialect/tpu/array_util.cc index 4c1e79667c0f..f7d559fb08bc 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::internal { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.h b/jaxlib/mosaic/dialect/tpu/array_util.h index 1b755dbf8495..ab8e98d17836 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.h +++ b/jaxlib/mosaic/dialect/tpu/array_util.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" diff --git a/jaxlib/mosaic/dialect/tpu/array_util_test.cc b/jaxlib/mosaic/dialect/tpu/array_util_test.cc index 18c2f94fa8b6..bcbf417a967b 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util_test.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include #include -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 772e87beff71..ce7e90d45fb9 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MemAlloc.h" #include "llvm/Support/raw_ostream.h" @@ -39,8 +41,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" @@ -410,7 +410,7 @@ MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() { mlir::tpu::registerMosaicSerdePass(); } -#include "mlir/CAPI/Pass.h" // IWYU pragma: keep +#include "mlir/CAPI/Pass.h" // IWYU pragma: keep #include "mlir/CAPI/Support.h" // IWYU pragma: keep extern "C" { diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 172f2e91b41f..c54c99fc9825 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" @@ -41,7 +42,6 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 2c45be62fa7d..bcfe205d58a9 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" @@ -38,7 +39,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 59ca5d7a3437..73c119b70e1a 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -23,7 +23,11 @@ limitations under the License. #include #include +#include "absl/hash/hash.h" +#include "absl/log/log.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -32,10 +36,6 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. 
#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/hash/hash.h" -#include "absl/log/log.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 0800a9e75087..cf74689dd3e6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -24,11 +24,10 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" #include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index c73accb09b26..b69a6ae06a7f 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -19,24 +19,24 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/strings/str_format.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/IRMapping.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 1997ffe34535..7755738a4fc7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -13,15 +13,23 @@ #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Compiler.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include 
"mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -33,9 +41,11 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" @@ -45,21 +55,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "llvm/include/llvm/ADT/APInt.h" -#include "llvm/include/llvm/Support/LogicalResult.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/array_util.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h index 33c9e7421004..fded0d1dbfd7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -3,9 +3,9 @@ #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 5efbdb9cb437..6f56489ab4b1 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -13,31 +13,30 @@ // NOLINTNEXTLINE(misc-include-cleaner) #include "mlir/Dialect/MemRef/IR/MemRef.h" // NOLINTNEXTLINE(misc-include-cleaner) +#include "absl/log/check.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include 
"mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h" -#include "mlir/include/mlir/IR/AffineExpr.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Block.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Region.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/vreg_util.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc index e7528533938f..067f8e592e30 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -1,7 +1,7 @@ #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc index c9c4a97e6222..9dbf89724fef 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -3,9 +3,9 @@ #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 0926f8a3c7b5..cdf48632784b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -6,6 +6,7 @@ #include #include +#include "absl/log/check.h" #include "llvm/ADT/bit.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -23,7 +24,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 0081feba985b..00e53314e588 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -23,6 +23,9 @@ limitations under the License. 
#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/raw_ostream.h" @@ -32,22 +35,16 @@ limitations under the License. #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h index d240f27fd42d..36fa2ce8113f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -4,8 +4,8 @@ #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 949a26a4f593..0d310ff45b30 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -19,32 +19,32 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/include/mlir/IR/AffineMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Matchers.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index b73ea0f1250f..f78df135a45a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -14,11 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "absl/log/check.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index b88504e35068..8aae7a10279a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -3,20 +3,19 @@ #include #include +#include "absl/log/check.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" -#include "absl/log/check.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/Support/MathExtras.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 0981c263d252..5f6c9bd712ff 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -18,19 +18,17 @@ limitations under the License. 
#include #include +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/serde.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 64753a22e7be..ccb32131e519 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -4,11 +4,11 @@ #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" #include "jaxlib/mosaic/pass_boilerplate.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 44cc301d6f0d..e61d9fa8d417 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -22,16 +22,16 @@ limitations under the License. 
#include #include -#include "llvm/Support/MathExtras.h" #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/Support/raw_ostream.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index f9ab1b7e349d..dadd71800f3e 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -10,19 +10,17 @@ #include #include +#include "absl/status/status.h" +#include "absl/types/span.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/Value.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "tsl/platform/statusor.h" diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc index 1f59ee13a311..72e0bf7f0caf 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.h b/jaxlib/mosaic/dialect/tpu/vreg_util.h index 86955e128f59..8c2967e776c7 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.h +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.h @@ -19,12 +19,12 @@ limitations under the License. 
#include #include -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc index ea3063361e1a..8a6d437ab73c 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc @@ -21,20 +21,20 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/TypeSwitch.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/DebugStringHelper.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 402e099c8d6b..d4f4d1732b2e 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -40,49 +40,49 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "llvm/include/llvm/Support/CodeGen.h" -#include "llvm/include/llvm/Support/TargetSelect.h" -#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/include/mlir/Conversion/Passes.h" -#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/include/mlir/ExecutionEngine/OptUtils.h" -#include "mlir/include/mlir/IR/AsmState.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/Parser/Parser.h" -#include "mlir/include/mlir/Pass/PassManager.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Transforms/Passes.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include 
"mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Transforms/Passes.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 0331d800ec50..f3f982f07481 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -31,29 +31,29 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Location.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/IR/TypeRange.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" 
+#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" namespace mosaic { namespace gpu { diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index 1815e18ca927..1705405d2f32 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -19,18 +19,18 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" #include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { diff --git a/jaxlib/mosaic/gpu/serde.cc b/jaxlib/mosaic/gpu/serde.cc index f4cf846acc11..5fca1d445774 100644 --- a/jaxlib/mosaic/gpu/serde.cc +++ b/jaxlib/mosaic/gpu/serde.cc @@ -15,10 +15,10 @@ limitations under the License. #include "jaxlib/mosaic/gpu/serde.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/serde.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h index d1e25e3f0912..29dda33d0c5a 100644 --- a/jaxlib/mosaic/gpu/serde.h +++ b/jaxlib/mosaic/gpu/serde.h @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" #include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc index a1a66a709cbe..a259b3dead7b 100644 --- a/jaxlib/mosaic/gpu/target.cc +++ b/jaxlib/mosaic/gpu/target.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "llvm/include/llvm/MC/MCSubtargetInfo.h" -#include "llvm/include/llvm/MC/TargetRegistry.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/MC/TargetRegistry.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/pass_boilerplate.h b/jaxlib/mosaic/pass_boilerplate.h index 546981feeef7..96d9e85a1d2d 100644 --- a/jaxlib/mosaic/pass_boilerplate.h +++ b/jaxlib/mosaic/pass_boilerplate.h @@ -18,10 +18,10 @@ limitations under the License. #include -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" namespace jaxlib { namespace mlir { diff --git a/jaxlib/mosaic/serde.cc b/jaxlib/mosaic/serde.cc index 88bca44bf181..307164d91dd9 100644 --- a/jaxlib/mosaic/serde.cc +++ b/jaxlib/mosaic/serde.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/mosaic/serde.h b/jaxlib/mosaic/serde.h index 762d9e5dad73..fdcaf58d4a8e 100644 --- a/jaxlib/mosaic/serde.h +++ b/jaxlib/mosaic/serde.h @@ -18,11 +18,11 @@ limitations under the License. #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/triton/triton_dialect_capi.cc b/jaxlib/triton/triton_dialect_capi.cc index 6a46d2914f57..8781fd16d76a 100644 --- a/jaxlib/triton/triton_dialect_capi.cc +++ b/jaxlib/triton/triton_dialect_capi.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "jaxlib/triton/triton_dialect_capi.h" -#include "llvm/include/llvm/Support/Casting.h" -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir/CAPI/IR.h" -#include "mlir/include/mlir/CAPI/Registration.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Dialect.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" diff --git a/jaxlib/triton/triton_dialect_capi.h b/jaxlib/triton/triton_dialect_capi.h index 8c27b5b82500..7d2a2f10404a 100644 --- a/jaxlib/triton/triton_dialect_capi.h +++ b/jaxlib/triton/triton_dialect_capi.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ #define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir-c/Support.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { From 7fa7db7a9fd874fd4561f66fdf1fdecb0611432e Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 20 Mar 2025 07:39:10 -0700 Subject: [PATCH 0053/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/6e2c9b024cec7dca4b2e1b07cc89373574c9c5af. PiperOrigin-RevId: 738795997 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 08c5af0c32b8..b20048193bd1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d3145a119840723c16fd27ee342729d68fddb7ef" -XLA_SHA256 = "daf2a72e36a9358803a8156c48b32117c9699fd327fcbc37b465f1a0045bccae" +XLA_COMMIT = "6e2c9b024cec7dca4b2e1b07cc89373574c9c5af" +XLA_SHA256 = "387917467d6f6e8358d54ba2b89f3fef14a00e62d8b0a096bf07acc8186444d4" def repo(): tf_http_archive( From c098b363fb032bbf812eceef679141e5261380bd Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 20 Mar 2025 08:01:08 -0700 Subject: [PATCH 0054/1769] [JAX Shardy] Unskip stream annotation test when shardy is enabled, since the underlying issue is now resolved. 
PiperOrigin-RevId: 738802372 --- tests/memories_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index fdd654e2186d..adc45dbdb0c1 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1664,8 +1664,6 @@ class StreamAnnotationTest(jtu.JaxTestCase): def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") - if config.use_shardy_partitioner.value: - self.skipTest("Doesn't work with shardy") mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P('x')) From 8bbd738df1d77b998241b36a110eb5545cf4d2f3 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 20 Mar 2025 08:36:38 -0700 Subject: [PATCH 0055/1769] [JAX Shardy] #sdy Unskip another test that is now passing PiperOrigin-RevId: 738814411 --- tests/memories_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index adc45dbdb0c1..570b0c375834 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -756,9 +756,6 @@ def init(): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") - if config.use_shardy_partitioner.value: - self.skipTest("XLA failure due to b/370786664 and b/366411266. " - "Enable when fixed.") mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) From dad1b41f7bbe9f4d4bc39c261358df3f0823c84d Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 20 Mar 2025 08:57:12 -0700 Subject: [PATCH 0056/1769] Reverts 2562da7026ccd930e5f0972598c7d5479175b787 PiperOrigin-RevId: 738820673 --- jaxlib/setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 60f17a987307..b3a37a25f1b2 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -68,10 +68,10 @@ def has_ext_modules(self): url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: 3.13', + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={ 'jaxlib': [ @@ -105,7 +105,7 @@ def has_ext_modules(self): 'triton/*.so', 'include/xla/ffi/api/*.h', ], - 'jaxlib.xla_extension': ['*.pyi', 'profiler/*.pyi'], + 'jaxlib.xla_extension': ['*.pyi'], }, zip_safe=False, distclass=BinaryDistribution, From 1ec0585361b9306d02cd01c07053f196750afe59 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 20 Mar 2025 09:06:27 -0700 Subject: [PATCH 0057/1769] Fix process_allgather of global jax.Arrays with shardy PiperOrigin-RevId: 738823617 --- jax/experimental/multihost_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 2bde1fbeadc4..7be349f0fc8f 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -99,8 +99,11 @@ def _identity_fn(x): def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: - reps = sharding_impls.GSPMDSharding.get_replicated( - inp.sharding._device_assignment) + if isinstance(inp.sharding, 
sharding_impls.NamedSharding): + reps = inp.sharding.with_spec(P()) + else: + reps = sharding_impls.GSPMDSharding.get_replicated( + inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind) out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. From 1b37613e4b9385ef1da473cbe3b12d1ac82dd833 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Thu, 20 Mar 2025 19:01:20 +0000 Subject: [PATCH 0058/1769] Introduce optional CUDA presubmit for additional hardware config --- .github/workflows/bazel_optional_cuda.yml | 65 +++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 .github/workflows/bazel_optional_cuda.yml diff --git a/.github/workflows/bazel_optional_cuda.yml b/.github/workflows/bazel_optional_cuda.yml new file mode 100644 index 000000000000..71936aeb9ae8 --- /dev/null +++ b/.github/workflows/bazel_optional_cuda.yml @@ -0,0 +1,65 @@ +name: CI - Bazel Optional CUDA tests +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + types: [ labeled, synchronize ] + schedule: + - cron: "0 */2 * * *" # Run once every 2 hours +permissions: + contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} +jobs: + run_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: ${{ matrix.runner }} + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' + strategy: + matrix: + # Optional gpus to run against + runner: ["linux-x86-a4-224-b200-1gpu"] + name: "Bazel single accelerator CUDA tests (${{ matrix.runner }})" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CUDA Tests + run: | + nvidia-smi + bazel test --config=ci_linux_x86_64_cuda \ + --config=resultstore \ + --config=rbe_cache \ + --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \ + --repo_env=HERMETIC_PYTHON_VERSION="3.13" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_env=JAX_ACCELERATOR_COUNT=1 \ + --test_env=JAX_TESTS_PER_ACCELERATOR=32 \ + --local_test_jobs=32 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file From 549c6694513d778f93ae63f282d8990c520d0a27 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 20 Mar 2025 19:28:32 +0000 Subject: [PATCH 0059/1769] Straight-through estimator for nvfp4 --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 27 
+++++++++++++++++------ tests/scaled_matmul_stablehlo_test.py | 24 ++++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 1a8dee293082..ddcde6a95b26 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -489,18 +489,24 @@ def quantize(x, config): assert config.scale_type == jnp.float8_e8m0fnu scales_q = cast_to_e8m0_with_rounding_up(scales) scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype) + clipped_x = jnp.clip(scaled_x, -MAX, MAX) + x_q = clipped_x.astype(config.data_type) elif config.mode == "nvfp4": assert config.scale_type == jnp.float8_e4m3fn assert config.global_scale.dtype == jnp.float32 - - scales = scales / config.global_scale - scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn)) - scaled_x = x / (scales_q.astype(jnp.float32) * - config.global_scale).astype(x.dtype) + SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) + + prev_amax = config.global_scale * (MAX * SCALE_MAX) + scales_q = jnp.clip( + (amax / prev_amax) * SCALE_MAX, 0, SCALE_MAX + ).astype(config.scale_type) + x_q = jnp.where( + amax <= prev_amax, + (x * MAX) / amax, + jnp.clip((x * MAX) / prev_amax, -MAX, MAX), + ).astype(config.data_type) else: raise ValueError(f"Unrecognized mode: {config.mode}.") - clipped_x = jnp.clip(scaled_x, -MAX, MAX) - x_q = clipped_x.astype(config.data_type) x_q = x_q.reshape(x_shape) # shape = (B, M, K) scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view( @@ -639,6 +645,13 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): } grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + + if configs[2].mode == "nvfp4": + assert rhs.dtype == lhs.dtype + MAX = jnp.finfo(configs[0].data_type).max.astype(lhs.dtype) + SCALE_MAX = jnp.finfo(configs[0].scale_type).max.astype(lhs.dtype) + grad_lhs = jnp.where(jnp.abs(lhs) <= configs[0].global_scale * MAX * SCALE_MAX, grad_lhs, 0) + grad_rhs = jnp.where(jnp.abs(rhs) <= configs[1].global_scale * MAX * SCALE_MAX, grad_rhs, 0) return (grad_lhs, grad_rhs) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 141839a19a08..224f6b6204e5 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -194,6 +194,11 @@ def generate_nvfp4_quantized_tensors(dot_config, output_type): amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32) + # To emulate calibrated amax + amax_sf = 0.9 + amax_a *= amax_sf + amax_b *= amax_sf + # Update global scales data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype( jnp.float32 @@ -567,8 +572,27 @@ def fwd(a, b, is_ref=False, use_normalized=False): out_ref, _ = j_train_fwd_ref(a_dq, b_dq) self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + def _grad_clip(amax, x, grad): + return jnp.where(jnp.abs(x) <= amax, grad, 0) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + x_grad_ref = _grad_clip(prev_amax_a, a_raw, x_grad_ref) + w_grad_ref = _grad_clip(prev_amax_b, b_raw, w_grad_ref) self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, 
rtol=1e-2, atol=1e1) + # Verify straight_through_estimator + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) > prev_amax_a, x_grad, 0), + jnp.zeros_like(x_grad) + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) > prev_amax_b, w_grad, 0), + jnp.zeros_like(w_grad) + ) else: j_inference = jax.jit(fwd) j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True)) From 59e480db99ea221c21efc566d4fe7f51ffebadf8 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 20 Mar 2025 12:51:04 -0700 Subject: [PATCH 0060/1769] [Mosaic GPU] Skip Mosaic GPU tests if jax_pallas_use_mosaic_gpu flag is not set. PiperOrigin-RevId: 738906466 --- jax/_src/pallas/pallas_call.py | 2 +- tests/pallas/BUILD | 1 - tests/pallas/mosaic_gpu_test.py | 4 ++++ tests/pallas/ops_test.py | 3 ++- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d0b74b2e5148..fbe3d23c6c27 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1206,7 +1206,7 @@ def _trace_kernel_to_jaxpr( return jaxpr, tuple(consts) -_PALLAS_USE_MOSAIC_GPU = config.bool_flag( +_PALLAS_USE_MOSAIC_GPU = config.bool_state( "jax_pallas_use_mosaic_gpu", default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), help=( diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 987a3aa9d50a..8ec4eea7aa1f 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -215,7 +215,6 @@ jax_multiplatform_test( "gpu_h100", ], env = { - "JAX_PALLAS_USE_MOSAIC_GPU": "1", "JAX_PALLAS_VERBOSE_ERRORS": "0", }, deps = [ diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 40e98bf05ba9..27017e4eb740 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -26,6 +26,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline +from jax._src.pallas import pallas_call from jax.experimental import pallas as pl from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp @@ -59,6 +60,9 @@ class PallasTest(jtu.JaxTestCase): def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Only works on a GPU with capability >= sm90") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) super().setUp() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0fc375bf64a1..38426747d85d 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -30,6 +30,7 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu +from jax._src.pallas import pallas_call from jax.experimental import pallas as pl from jax.interpreters import partial_eval as pe import jax.numpy as jnp @@ -61,7 +62,7 @@ jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=50) -use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu") +use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value intx = dtypes.canonicalize_dtype(jnp.int64) floatx = dtypes.canonicalize_dtype(jnp.float64) From 412f1d35dca4efa655a291f3d4b3f8f3d6b4546d Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Fri, 7 Feb 2025 07:12:19 +0000 Subject: [PATCH 0061/1769] Adding sharding support to dynamic masks --- .../splash_attention_kernel.py | 25 +- .../splash_attention_mask_info.py | 110 ++++++--- tests/pallas/BUILD | 16 ++ ...pu_splash_attention_kernel_sharded_test.py | 223 
++++++++++++++++++ .../tpu_splash_attention_kernel_test.py | 14 +- .../pallas/tpu_splash_attention_mask_test.py | 4 +- 6 files changed, 352 insertions(+), 40 deletions(-) create mode 100644 tests/pallas/tpu_splash_attention_kernel_sharded_test.py diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 4b6e4a41c43b..d0fb6f2f9670 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -2293,6 +2293,26 @@ def _splash_attention( mask_function: MaskFunctionType | None, interpret: bool, ) -> SplashCustomReturnType: + """ + For dynamic masks, `partial_mask_blocks` has shape (head_count, q_blocks, kv_blocks, block_q, block_kv). + This shape allows sharding across both head count and query sequence dimensions. + + Note: The leading dimensions (head_count, q_blocks, kv_blocks) must be + collapsed into a single dimension before being passed to the kernel. + """ + def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None): + if mask_info is None or mask_info.partial_mask_blocks is None: + return mask_info + + return mask_info._replace( + partial_mask_blocks=mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ) + ) + + fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info) + dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info) + dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info) return _splash_attention_custom( fwd_mask_info, dq_mask_info, @@ -2352,13 +2372,16 @@ def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding): spec = sharding.spec assert len(spec) == 2 replicated = jax.sharding.PartitionSpec() + partial_mask_blocks_spec = ( + spec if self.fwd_mask_info.is_dynamic_mask else replicated + ) # Shard q_sequence over the sequence dimension only. q_sequence_spec = jax.sharding.PartitionSpec(spec[1]) mask_info_specs = mask_info_lib.MaskInfo( # pytype: disable=wrong-arg-types data_next=spec if self.fwd_mask_info.data_next is not None else None, mask_next=spec if self.fwd_mask_info.mask_next is not None else None, block_mask=spec if self.fwd_mask_info.block_mask is not None else None, - partial_mask_blocks=replicated + partial_mask_blocks=partial_mask_blocks_spec if self.fwd_mask_info.partial_mask_blocks is not None else None, q_sequence=q_sequence_spec diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 65081e79c0cf..9c79fbbf7e09 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -67,6 +67,10 @@ class MaskInfo(NamedTuple): q_sequence: A i32[q_sequence_length] NumPy array. When using causal masking, this contains the list of indices that correspond to q tokens. For plain causal this is just np.arange(q_sequence_length). + is_dynamic_mask: A bool indicating whether the mask is dynamic or static. + When True, the leading dimensions of `partial_mask_blocks` (num_heads, + q_blocks, kv_blocks) are not collapsed, allowing us to shard it along + those dimensions. 
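+      For example, with two head shards each device holds a
+      (num_heads // 2, q_blocks, kv_blocks, block_q, block_kv) slice of
+      `partial_mask_blocks`.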
""" data_next: np.ndarray | jax.Array | None @@ -74,6 +78,7 @@ class MaskInfo(NamedTuple): block_mask: np.ndarray | jax.Array | None partial_mask_blocks: np.ndarray | jax.Array | None q_sequence: np.ndarray | None + is_dynamic_mask: bool = None def _downcast_to_small_type(array: np.ndarray) -> np.ndarray: @@ -168,7 +173,7 @@ def __eq__(self, other: object) -> bool: def _get_mask_info_for_shard( output_shape: tuple[int, int, int], has_mask_next: bool, - mask: mask_lib.MultiHeadMask, + mask: mask_lib.MultiHeadMask | jax.Array, block_shape: tuple[int, int], coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, @@ -338,7 +343,8 @@ def _process_dynamic_mask( launched. q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is launched. - shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored. + shrink_grid: Whether or not we should apply the grid shrinking optimization. + This is currently ignored. Returns: `MaskInfo`, a sparse representation of the dense mask. @@ -349,11 +355,6 @@ def _process_dynamic_mask( """ del shrink_grid - - # TODO(pobudzey): Properly support sharding. - if head_shards != 1 or q_seq_shards != 1: - raise ValueError('Dynamic mask processing does not support sharding.') - if len(mask.shape) != 3: raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.') @@ -370,6 +371,18 @@ def _process_dynamic_mask( if kv_mod != 0: raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.') + q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) + if mod != 0: + raise ValueError(f'{q_seq_shards=} should divide {q_seq_len=}.') + + q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) + if mod != 0: + raise ValueError(f'{q_block_size=} should divide {q_seq_len_per_shard=}.') + + heads_per_shard, mod = divmod(head_count, head_shards) + if mod != 0: + raise ValueError(f'{head_shards=} should divide {head_count=}.') + block_mask_shape = ( head_count, q_blocks_count, @@ -398,26 +411,66 @@ def _process_dynamic_mask( block_mask = jnp.where(is_full_mask, 2, block_mask) block_mask = jnp.where(is_empty_mask, 0, block_mask) - # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline. - mask_next = jnp.where( - jnp.logical_or(is_empty_mask, is_full_mask), - 0, - jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape( - block_mask_shape - ), - ) + q_sequence_axis = 1 + head_axis = 0 - # data_next stores the index of the next non-empty data block in the sequence. - # The indices of empty blocks are set to 0 to avoid copying extra data when - # pipeling. - if is_dkv: - data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None] - else: - data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :] - data_next = jnp.broadcast_to(data_next, block_mask_shape) - data_next = jnp.where(is_empty_mask, 0, data_next) + # Each iteration of the loop processes a slice of the mask info + # tensors of this shape: + mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count) + + # Collect mask_info shards along the head dimension, concatentate (or + # broadcast) them after the loop. + data_next_per_head_list, mask_next_per_head_list = [], [] + for head_shard in range(head_shards): + head_start = head_shard * heads_per_shard + mask_head_slice = slice(head_start, head_start + heads_per_shard) + + # Collect mask_info shards along the q_sequence dimension, concatenate them + # after the loop. 
+    data_next_sequence_slices, mask_next_sequence_slices = [], []
+    for q_seq_len_shard in range(q_seq_shards):
+      q_seq_len_start = q_seq_len_shard * q_blocks_per_shard
+      blocked_q_seq_len_slice = slice(
+          q_seq_len_start, q_seq_len_start + q_blocks_per_shard
+      )
+      local_block_mask = block_mask[mask_head_slice, blocked_q_seq_len_slice]
+
+      mask_next_slice = jnp.arange(
+          math.prod(mask_info_slice_shape), dtype=np.int32
+      ).reshape(mask_info_slice_shape)
+      mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0)
+
+      # data_next stores the index of the next non-empty data block in the sequence.
+      # The indices of empty blocks are set to 0 to avoid copying extra data when
+      # pipelining.
+      if is_dkv:
+        data_next_slice = jnp.arange(q_blocks_per_shard, dtype=np.int32)[
+            None, :, None
+        ]
+      else:
+        data_next_slice = jnp.arange(kv_blocks_count, dtype=np.int32)[
+            None, None, :
+        ]
+      data_next_slice = jnp.broadcast_to(data_next_slice, mask_info_slice_shape)
+      data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice)
+
+      data_next_sequence_slices.append(data_next_slice)
+      mask_next_sequence_slices.append(mask_next_slice)
+
+    # Concatenate the sequence shards.
+    data_next_per_head = jnp.concatenate(
+        data_next_sequence_slices, axis=q_sequence_axis
+    )
+    data_next_per_head_list.append(data_next_per_head)
+    mask_next_per_head = jnp.concatenate(
+        mask_next_sequence_slices, axis=q_sequence_axis
+    )
+    mask_next_per_head_list.append(mask_next_per_head)
+
+  # Concatenate (or broadcast) the head shards.
+  data_next = jnp.concatenate(data_next_per_head_list, axis=head_axis)
+  mask_next = jnp.concatenate(mask_next_per_head_list, axis=head_axis)
 
-  partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape)
   if is_dkv:
     partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2)
@@ -438,9 +491,11 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array:
   if downcast_smem_data:
     block_mask = block_mask.astype(np.int8)  # values are in the range [0, 1, 2]
     data_next = _downcast(
-        data_next, q_blocks_count if is_dkv else kv_blocks_count
+        data_next, q_blocks_per_shard if is_dkv else kv_blocks_count
+    )
+    mask_next = _downcast(
+        mask_next, heads_per_shard * q_blocks_per_shard * kv_blocks_count
     )
-    mask_next = _downcast(mask_next, math.prod(block_mask_shape))
 
   return (
       MaskInfo(
@@ -449,6 +504,7 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array:
           block_mask=block_mask,
          partial_mask_blocks=partial_mask_blocks,
           q_sequence=None,
+          is_dynamic_mask=True,
       ),
       None,
   )
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index 8ec4eea7aa1f..34af5e16a9b6 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -540,6 +540,22 @@ jax_multiplatform_test(
 ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
 )
 
+jax_multiplatform_test(
+    name = "tpu_splash_attention_kernel_sharded_test",
+    srcs = ["tpu_splash_attention_kernel_sharded_test.py"],
+    enable_configs = [
+        "tpu_v5e_4x2",
+        "tpu_v5p_2x2",
+    ],
+    shard_count = 5,
+    deps = [
+        "//jax:extend",
+        "//jax:pallas_tpu",
+        "//jax:pallas_tpu_ops",
+    ],
+)
+
+
 # This test doesn't need a TPU; it only tests numpy-using helpers.
 jax_py_test(
     name = "tpu_splash_attention_mask_test",
diff --git a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py
new file mode 100644
index 000000000000..db14b44938e9
--- /dev/null
+++ b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py
@@ -0,0 +1,223 @@
+# Copyright 2025 The JAX Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for partitioning splash_attention.""" + +import functools +import math +from absl.testing import absltest, parameterized +import jax +from jax import random +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +from jax.sharding import PartitionSpec +import numpy as np + +partial = functools.partial + +jax.config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("Test requires TPU.") + + if len(jax.devices()) < 4: + self.skipTest("This test requires at least 4 devices.") + + def _assert_allclose(self, x, y, **kwargs): + if x.dtype == np.dtype(jnp.bfloat16): + x = x.astype(np.float32) + if y.dtype == np.dtype(jnp.bfloat16): + y = y.astype(np.float32) + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + np.testing.assert_allclose(x, y, **kwargs) + + +def generate_mask(shape, num_heads, seed) -> np.ndarray: + assert num_heads >= 2 + assert shape > (64, 64) + + masks = [ + mask_lib.make_causal_mask(shape), + mask_lib.make_local_attention_mask(shape, window_size=(64, 64)), + ] + masks += [mask_lib.make_random_mask(shape, 0.8, seed)] * (num_heads - 2) + return np.stack(masks, axis=0) + + +class SplashAttentionShardingTest(PallasBaseTest): + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4, 16], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha( + self, topology, num_heads, dtype, is_dynamic_mask + ): + k1, k2, k3 = random.split(random.key(0), 3) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + if len(jax.devices()) < num_devices: + self.skipTest( + f"This test requires {num_devices} devices, but has only" + f" {len(jax.devices())} devices available." 
+ ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_rep=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + out = f(kernel, q, k, v) + out_ref = jax.vmap(splash.attention_reference)(mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha_bwd( + self, topology, num_heads, dtype, is_dynamic_mask + ): + assert num_heads % 2 == 0 + k1, k2, k3, k4 = random.split(random.key(0), 4) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." 
+ ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_rep=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + f_ref = jax.vmap(splash.attention_reference) + + out, out_vjp = jax.vjp(f, kernel, q, k, v) + out_ref, out_vjp_ref = jax.vjp(f_ref, mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + do = random.uniform(k4, out.shape, dtype=out.dtype) + _, dq, dk, dv = out_vjp(do) + _, dq_ref, dk_ref, dv_ref, _ = out_vjp_ref(do.astype(jnp.float32)) + + self.assertAllClose(dq, dq_ref, atol=5e-2) + self.assertAllClose(dk, dk_ref, atol=5e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index dfe0bcc0da3b..240a9c91c02d 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -303,14 +303,6 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0)) -def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array: - q_seq_len, kv_seq_len = mask.masks[0].shape - full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len)) - dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0) - - return dynamic_mask - - @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -384,7 +376,7 @@ def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -460,7 +452,7 @@ def test_splash_attention_fwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) @@ -628,7 +620,7 @@ def test_splash_attention_bwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if use_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes 
= data.draw( block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True, use_fused_bwd_kernel=use_fused_bwd_kernel) diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index f39b4d839340..5379eb10990f 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -2166,7 +2166,9 @@ def test_dynamic_mask(self, is_dkv: bool): self.assertArraysEqual(mask_info.block_mask, _expected_block_mask) self.assertArraysEqual( - mask_info.partial_mask_blocks, + mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ), _expected_partial_mask_blocks, ) self.assertArraysEqual(mask_info.mask_next, _expected_mask_next) From ea7fa29be73f322eed727a59f1dcbf8cb7ac7170 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 20 Mar 2025 13:23:09 -0700 Subject: [PATCH 0062/1769] Allow `tuple(arrays)` as an input to `make_array_from_single_device_arrays`. Fixes https://github.com/jax-ml/jax/issues/27303 PiperOrigin-RevId: 738917340 --- jax/_src/array.py | 7 ++++--- tests/array_test.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index e49963ccda9c..ee196026887d 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1024,7 +1024,7 @@ def make_array_from_single_device_arrays( shape : Shape of the output ``jax.Array``. This conveys information already included with ``sharding`` and ``arrays`` and serves as a double check. sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices. - arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` + arrays: `list` or `tuple` of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code, each process will call with a different ``arrays`` argument that corresponds to that processes' data. These arrays are commonly created via ``jax.device_put``. @@ -1071,14 +1071,15 @@ def make_array_from_single_device_arrays( if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) + arrays = list(arrays) if isinstance(arrays, tuple) else arrays # TODO(phawkins): ideally the cast() could be checked. 
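+  # The tuple -> list conversion above is needed because the underlying
+  # ArrayImpl constructor accepts a list, not a tuple, of single-device arrays.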
try: return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) except TypeError: - if not isinstance(arrays, Sequence): + if not isinstance(arrays, list): raise TypeError("jax.make_array_from_single_device_arrays `arrays` " - "argument must be a Sequence (list or tuple), but got " + "argument must be a list or tuple, but got " f"{type(arrays)}.") if any(isinstance(arr, core.Tracer) for arr in arrays): raise ValueError( diff --git a/tests/array_test.py b/tests/array_test.py index cc8990828ded..6100283cc032 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1301,6 +1301,18 @@ def f(x): with self.assertRaisesRegex(TypeError, msg): jax.jit(f)(x) + def test_make_array_from_single_device_arrays_tuple(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (8, 8) + s = jax.sharding.NamedSharding(mesh, P('x', 'y')) + inp_data = np.arange(math.prod(shape)).reshape(shape) + + arrays = tuple( + jax.device_put(inp_data[index], d) + for d, index in s.addressable_devices_indices_map(shape).items()) + + jax.make_array_from_single_device_arrays(shape, s, arrays) # doesn't crash + def test_make_array_from_single_device_arrays_bad_inputs(self): x = jnp.arange(10) mesh = jtu.create_mesh((2,), ('x',)) From d0b71fa1ceb11e9fbf89a8d0e4f6be47b80ab382 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 20 Mar 2025 14:04:48 -0700 Subject: [PATCH 0063/1769] [Mosaic GPU] Add preliminary TMEM allocation support for Pallas/Mosaic GPU. PiperOrigin-RevId: 738932990 --- jax/_src/pallas/mosaic_gpu/core.py | 3 + jax/_src/pallas/mosaic_gpu/lowering.py | 108 ++++++++++++++++++++++--- jax/experimental/mosaic/gpu/core.py | 2 + jax/experimental/pallas/mosaic_gpu.py | 2 + tests/pallas/mosaic_gpu_test.py | 29 +++++++ 5 files changed, 134 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index b1e0a683f64d..1e4a9de1830c 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -101,6 +101,8 @@ class GPUMemorySpace(enum.Enum): GMEM = "gmem" #: Shared memory. SMEM = "smem" + #: Tensor memory. + TMEM = "tmem" #: Registers. REGS = "regs" @@ -452,6 +454,7 @@ def to_block_mapping( GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM +TMEM = GPUMemorySpace.TMEM REGS = GPUMemorySpace.REGS diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 1fae91773178..ef4c80cb4649 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -59,6 +59,7 @@ from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import profiler as mgpu_profiler from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import tcgen05 import jax.numpy as jnp import numpy as np @@ -100,6 +101,7 @@ def arrival_multiplier(self) -> int: @dataclasses.dataclass(kw_only=True, frozen=True) class Resources: smem_scratch_bytes: int = 0 + tmem_scratch_cols: int = 0 barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( default_factory=collections.Counter ) @@ -110,6 +112,12 @@ def __post_init__(self): "smem_scratch_bytes", _align_to(self.smem_scratch_bytes, _SMEM_ALIGNMENT), ) + object.__setattr__( + self, + "tmem_scratch_cols", + # TMEM must be allocated in 128x8 chunks. 
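+        # e.g. a request for 100 columns is rounded up to 104 (13 chunks of
+        # 8 columns, each 128 rows tall).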
+ _align_to(self.tmem_scratch_cols, 8), + ) @property def barriers(self) -> Sequence[mgpu.Barrier]: @@ -122,6 +130,7 @@ def __add__(self, other: Resources) -> Resources: # we will allocate two barriers, even though one would be enough. return Resources( smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, + tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols, barrier_counts=self.barrier_counts + other.barrier_counts, ) @@ -130,6 +139,9 @@ def __or__(self, other: Resources) -> Resources: smem_scratch_bytes=max( self.smem_scratch_bytes, other.smem_scratch_bytes ), + tmem_scratch_cols=max( + self.tmem_scratch_cols, other.tmem_scratch_cols + ), barrier_counts=self.barrier_counts | other.barrier_counts, ) @@ -218,10 +230,26 @@ def _run_scoped_resource_estimator( ) ]) ) - else: + elif aval.memory_space == gpu_core.TMEM: + if aval.dtype.itemsize != 4: + raise ValueError("TMEM only supports 32-bit types.") + if len(aval.shape) != 2: + raise ValueError("TMEM allocations must be 2D.") + if aval.shape[0] % tcgen05.TMEM_ROWS != 0: + raise ValueError("TMEM shape[0] must be a multiple of 128.") + if aval.shape[1] % 8 != 0: + raise ValueError("TMEM shape[1] must be a multiple of 8.") + rs += Resources(tmem_scratch_cols=aval.shape[1]) + elif aval.memory_space == gpu_core.SMEM: rs += Resources( smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize ) + elif aval.memory_space == gpu_core.REGS: + # Don't need to allocate anything. + pass + else: + raise NotImplementedError( + f"Unsupported memory space: {aval.memory_space}") return rs + _estimate_resources(ctx, jaxpr) @@ -267,6 +295,9 @@ class ModuleContext: single_wg_lane_predicate: ir.Value smem_requested_bytes: int smem_used_bytes: int + tmem_requested_cols: int + tmem_used_cols: int + tmem_base_ptr: ir.Value runtime_barriers: MutableMapping[ mgpu.Barrier, MutableSequence[mgpu.BarrierRef] ] @@ -286,6 +317,27 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: raise RuntimeError(f"Barrier {barrier} is already reserved") return available.pop() + @contextlib.contextmanager + def alloc_tmem( + self, + struct: jax.ShapeDtypeStruct, + layout: tcgen05.TMEMLayout | None = None + ) -> ir.Value: + if self.tmem_used_cols > 0: + raise NotImplementedError( + "Multiple TMEM allocations are not implemented.") + if layout is None: + layout = tcgen05._infer_tmem_layout(struct.shape, collective=False) + cols_used = np.prod(struct.shape) // tcgen05.TMEM_ROWS + self.tmem_used_cols += cols_used + off = self.tmem_base_ptr + tmem_ref = tcgen05.TMEMRef(address=off, + shape=struct.shape, + dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), + layout=layout) + yield tmem_ref + self.tmem_used_cols -= cols_used + # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. 
@contextlib.contextmanager def scratch_view( @@ -642,11 +694,15 @@ def lower_jaxpr_to_module( parallel_grid = (math.prod(grid[:-2]), *grid[-2:]) def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): - *buffers_gmem, (runtime_smem, runtime_barriers) = buffers + *buffers_gmem, (runtime_smem, runtime_barriers, runtime_tmem) = buffers grouped_barriers = collections.defaultdict(list) for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): grouped_barriers[barrier].append(barrier_ref) + if runtime_tmem is not None: + tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS + else: + tmem_cols = 0 module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), axis_names, @@ -655,6 +711,9 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mgpu.single_thread_predicate(per_block=False), smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, + tmem_requested_cols=tmem_cols, + tmem_used_cols=0, + tmem_base_ptr=runtime_tmem.address if runtime_tmem else None, runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), @@ -671,6 +730,18 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): smem_scratch_bytes = params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = rs.smem_scratch_bytes + tmem_scratch_cols = rs.tmem_scratch_cols + + scratch_buffers = [ + jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), + rs.barriers, + ] + if tmem_scratch_cols > 0: + scratch_buffers.append( + mgpu.TMEM(shape=[tcgen05.TMEM_ROWS, tmem_scratch_cols], dtype=np.int32), + ) + else: + scratch_buffers.append(None) prof_ctx = prof_spec = None if prof_space := params.get("profile_space", 0): @@ -685,10 +756,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): block=block, in_shapes=in_shapes, out_shape=out_shapes, - smem_scratch_shape=( - jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), - rs.barriers, - ), + smem_scratch_shape=scratch_buffers, module_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, ) @@ -990,14 +1058,26 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... 
@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) -def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only load from references (got {x_smem}).") +def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): + if isinstance(x_ref, tcgen05.TMEMRef): + transforms = jax.tree.unflatten(tree, leaves) + if len(transforms) != 1 or not isinstance( + transforms[0], indexing.NDIndexer): + raise NotImplementedError( + "Only a single indexing transform is supported for TMEM refs.") + indexer = cast(indexing.NDIndexer, transforms[0]) + if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape): + raise NotImplementedError( + "Only trivial indexing is supported for TMEM refs.") + return x_ref[:] + + if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): + raise TypeError(f"Can only load from references (got {x_ref}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) + x_smem, transforms = _handle_reshaping(x_ref, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) match transforms: @@ -1784,6 +1864,14 @@ def _run_scoped_lowering_rule( ) input_refs.append(input_ref) should_discharge.append(False) + elif aval.memory_space == gpu_core.TMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.alloc_tmem( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) else: raise ValueError(f"Can't convert to ref: {aval}") diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index b255893e2e2e..f5331eb1b56a 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -307,6 +307,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int: raise NotImplementedError("Misaligned barrier allocation") size += num_barriers * utils.MBARRIER_BYTES case TMEM(_): + # TODO(justinfu): This can trigger misaligned barrier allocations + # if TMEM is requested before barriers b/c it's not divisible by 8. size += 4 # i32 takes up 4 bytes case _: size += _count_buffer_bytes(l) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 631b4f720984..aab58d092190 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -51,3 +51,5 @@ GMEM = GPUMemorySpace.GMEM #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. SMEM = GPUMemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.TMEM`. 
+TMEM = GPUMemorySpace.TMEM diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 27017e4eb740..94c2620f7ae6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -83,6 +83,13 @@ def setUp(self): super().setUp() +class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest): + + def setUp(self): + self.skip_unless_sm100a() + super().setUp() + + class PallasCallTest(PallasTest): @parameterized.product( @@ -1531,6 +1538,28 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) +class PallasCallSm100ATest(PallasSm100ATest): + + def test_tmem_alloc(self): + mesh = plgpu.GPUMesh(num_threads=1, axis_names=("x")) + @pl.run_state + def inner(y_ref): + @pl.core_map(mesh) + def _(): + def scope(tmem_ref, smem_ref): + # Issue a write so the TMEM load is not DCE'd. + smem_ref[...] = tmem_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + plgpu.TMEM((128, 128), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32)) + y_init = jnp.zeros((128, 128), np.float32) + # Test that this runs without errors. + jax.block_until_ready(inner(y_init)) + + class PipelineTest(PallasTest): def test_pipeline_mode(self): From 55b55e6b1b64c24c3dd87274427594fc56e8f6e0 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 10 Mar 2025 17:00:44 +0100 Subject: [PATCH 0064/1769] Enable multi-threading in Jax Context with shared thread pool --- jax/_src/interpreters/mlir.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 47d3fc52de26..f96f07be4149 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -593,6 +593,15 @@ def module_to_bytecode(module: ir.Module) -> bytes: # Translation rules +# Create one global thread pool that can be shared between multiple ir.Contexts +# and enabling multi-threading +# TODO: remove this check after jaxlib 0.5.4 +if hasattr(ir, "ThreadPool"): + global_thread_pool = ir.ThreadPool() +else: + global_thread_pool = None + + class JaxIrContext(ir.Context): def __init__(self, *args, **kwargs): # Note: we're very intentionally *not* calling the __init__() of our @@ -607,12 +616,16 @@ def make_ir_context() -> ir.Context: context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() - # If threading is enabled, each MLIR context will keep alive a thread pool. - # Since we cache MLIR modules (and hence contexts), this means we might keep - # several threads alive for each cache entry. This is a terrible idea. However - # we don't do any heavy computation on MLIR modules from Python anyway, so we - # just disable threading. - context.enable_multithreading(False) + # TODO: remove this check after v0.5.4 jaxlib + if global_thread_pool is not None: + context.set_thread_pool(global_thread_pool) + else: + # If threading is enabled, each MLIR context will keep alive a thread pool. + # Since we cache MLIR modules (and hence contexts), this means we might keep + # several threads alive for each cache entry. This is a terrible idea. However + # we don't do any heavy computation on MLIR modules from Python anyway, so we + # just disable threading. + context.enable_multithreading(False) # TODO(bartchr): Once JAX is released with SDY, remove the if. 
if dialects.sdy: dialects.sdy.register_dialect(context) From a8fb0e01f8d083fff337d3c26375bb1b77344a99 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 20 Mar 2025 15:22:31 -0700 Subject: [PATCH 0065/1769] [sharding_in_types] Fix a dynamic_slice bug where in the transpose, `DUS`'s operand was not sharded properly PiperOrigin-RevId: 738959282 --- jax/_src/lax/slicing.py | 3 ++- tests/pjit_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index c26de99c7374..b4a8817fbb8d 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1476,7 +1476,8 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes): if type(t) is ad_util.Zero: return [ad_util.Zero(operand.aval)] + [None] * len(start_indices) else: - zeros = lax.full(operand_shape, 0, operand_dtype) + zeros = lax.full(operand_shape, 0, operand_dtype, + sharding=operand.aval.sharding) return ([dynamic_update_slice_p.bind(zeros, t, *start_indices)] + [None] * len(start_indices)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6cf11494988a..6a1a73fe4301 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7090,6 +7090,30 @@ def f(x): self.assertEqual(out.shape, expected_shape) self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec)) + @jtu.with_user_mesh((2,), ('x',)) + def test_dynamic_slice(self, mesh): + np_inp = np.arange(16., dtype=np.float32) + s = NamedSharding(mesh, P('x')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = lax.dynamic_slice_in_dim(x, jnp.array(1, dtype=np.int32), 2) + self.assertEqual(y.aval.sharding.spec, P('x')) + return y + + out = f(arr) + self.assertEqual(out.sharding, s) + + def g(x): + return jnp.sum(f(x)) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + def test_auto_axes_computation_follows_data_error(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) s = NamedSharding(mesh, P('x')) From 1fe24ca7552576d59649f335cf5878c069d180f8 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 20 Mar 2025 23:26:21 +0000 Subject: [PATCH 0066/1769] Improve based on review 1 --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 15 +++-- tests/scaled_matmul_stablehlo_test.py | 75 +++++++++++++++++++---- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index ddcde6a95b26..b1d353e7bcd1 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -489,8 +489,6 @@ def quantize(x, config): assert config.scale_type == jnp.float8_e8m0fnu scales_q = cast_to_e8m0_with_rounding_up(scales) scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype) - clipped_x = jnp.clip(scaled_x, -MAX, MAX) - x_q = clipped_x.astype(config.data_type) elif config.mode == "nvfp4": assert config.scale_type == jnp.float8_e4m3fn assert config.global_scale.dtype == jnp.float32 @@ -499,15 +497,15 @@ def quantize(x, config): prev_amax = config.global_scale * (MAX * SCALE_MAX) scales_q = jnp.clip( (amax / prev_amax) * SCALE_MAX, 0, SCALE_MAX - ).astype(config.scale_type) - x_q = jnp.where( - amax <= prev_amax, - (x * MAX) / amax, - jnp.clip((x * MAX) / prev_amax, -MAX, MAX), - ).astype(config.data_type) + ) + scaled_x = x / scales_q + scales_q = scales_q.astype(config.scale_type) else: raise ValueError(f"Unrecognized mode: 
{config.mode}.") + clipped_x = jnp.clip(scaled_x, -MAX, MAX) + x_q = clipped_x.astype(config.data_type) + x_q = x_q.reshape(x_shape) # shape = (B, M, K) scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view( config.scale_type @@ -652,6 +650,7 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): SCALE_MAX = jnp.finfo(configs[0].scale_type).max.astype(lhs.dtype) grad_lhs = jnp.where(jnp.abs(lhs) <= configs[0].global_scale * MAX * SCALE_MAX, grad_lhs, 0) grad_rhs = jnp.where(jnp.abs(rhs) <= configs[1].global_scale * MAX * SCALE_MAX, grad_rhs, 0) + return (grad_lhs, grad_rhs) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 224f6b6204e5..b53ffcd5b977 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -174,7 +174,7 @@ def update_global_scale(config, new_global_scale): config.global_scale = new_global_scale return config -def generate_nvfp4_quantized_tensors(dot_config, output_type): +def generate_nvfp4_quantized_tensors(dot_config, output_type, enable_grad_clip=False): k1, k2 = jax.random.split(jax.random.key(0), 2) a_shape, b_shape, dimension_numbers = dot_config @@ -195,7 +195,7 @@ def generate_nvfp4_quantized_tensors(dot_config, output_type): amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32) # To emulate calibrated amax - amax_sf = 0.9 + amax_sf = 0.9 if enable_grad_clip else 1.0 amax_a *= amax_sf amax_b *= amax_sf @@ -513,6 +513,68 @@ def fn(a): self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5) + @jtu.sample_product( + enable_grad_clip=[True, False], + configs=[ + # a_shape, b_shape, dimension_numbers + ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0]))), + ((30, 64), (100, 64), (([1], [1]), ([], []))), + ] + ) + @jtu.run_on_devices("cuda") + def test_nvfp4_gradient_clip(self, enable_grad_clip, configs): + output_type = jnp.float32 + (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( + generate_nvfp4_quantized_tensors(configs, output_type, enable_grad_clip) + ) + a_gs = block_scale_configs[0].global_scale + b_gs = block_scale_configs[1].global_scale + dimension_numbers = configs[2] + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=block_scale_configs + ) + + def fwd(a, b, use_normalized=False): + y = scaled_dot_general( + a, b, dimension_numbers, + preferred_element_type=output_type + ) + return jnp.sum(y) + + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + _, (x_grad, w_grad) = j_train(a_raw, b_raw) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + # Use a large value to ensure no clipping + threshold_a = prev_amax_a if enable_grad_clip else 1e9 + threshold_b = prev_amax_b if enable_grad_clip else 1e9 + + # Verify gradients are clipped to 0 where |input| > global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) > threshold_a, x_grad, 0), + jnp.zeros_like(x_grad), + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) > threshold_b, w_grad, 0), + jnp.zeros_like(w_grad), + ) + if enable_grad_clip: + # Verify gradients are preserved where |input| <= global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) <= prev_amax_a, x_grad, 0), + x_grad, + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) <= 
          prev_amax_b, w_grad, 0),
+          w_grad,
+      )
+
   @jtu.sample_product(
       configs=[
           # a_shape, b_shape, dimension_numbers, is_training
@@ -584,15 +646,6 @@ def _grad_clip(amax, x, grad):
         w_grad_ref = _grad_clip(prev_amax_b, b_raw, w_grad_ref)
       self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
       self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
-      # Verify straight_through_estimator
-      self.assertArraysEqual(
-          jnp.where(jnp.abs(a_raw) > prev_amax_a, x_grad, 0),
-          jnp.zeros_like(x_grad)
-      )
-      self.assertArraysEqual(
-          jnp.where(jnp.abs(b_raw) > prev_amax_b, w_grad, 0),
-          jnp.zeros_like(w_grad)
-      )
     else:
       j_inference = jax.jit(fwd)
       j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True))

From 0eb430c128cfe9448971441ec2fef61e13548592 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 20 Mar 2025 23:26:37 +0100
Subject: [PATCH 0067/1769] Increased test timeout in TSAN CI

Description:
- Increased test timeout in TSAN CI
- Skip tests: testMishGrad and testSquareplusGrad
---
 .github/workflows/tsan.yaml | 2 +-
 tests/nn_test.py            | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml
index 7d93707e4e92..6c97b7347ceb 100644
--- a/.github/workflows/tsan.yaml
+++ b/.github/workflows/tsan.yaml
@@ -210,7 +210,7 @@ jobs:
             --test_env=JAX_TEST_NUM_THREADS=8 \
             --test_output=errors \
             --local_test_jobs=32 \
-            --test_timeout=600 \
+            --test_timeout=1800 \
             --config=resultstore \
             --config=rbe_cache \
             //tests:cpu_tests

diff --git a/tests/nn_test.py b/tests/nn_test.py
index ed016ec349ef..1a1670444ef8 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -422,6 +422,7 @@ def testSparseplusAndSparseSigmoid(self):
         jax.grad(nn.sparse_plus)(-2.),
         nn.sparse_sigmoid(-2.),
         check_dtypes=False)

+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSquareplusGrad(self):
     check_grads(nn.squareplus, (1e-8,), order=4,
                 rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
@@ -442,6 +443,7 @@ def testSquareplusGradNan(self):
   def testSquareplusZero(self, dtype):
     self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4)))

+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testMishGrad(self):
     check_grads(nn.mish, (1e-8,), order=4,
                 rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

From 4b7ead4d02f866077f11dcfcca7507533a441bcc Mon Sep 17 00:00:00 2001
From: shuw
Date: Wed, 19 Feb 2025 22:51:35 +0000
Subject: [PATCH 0068/1769] Bump ml_dtypes>=0.5.0

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index dbb7040d2d2b..bdaeb624bf38 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,7 @@ def load_version_module(pkg_path):
     python_requires='>=3.10',
     install_requires=[
         f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
-        'ml_dtypes>=0.4.0',
+        'ml_dtypes>=0.5.0',
         'numpy>=1.25',
         "numpy>=1.26.0; python_version>='3.12'",
         'opt_einsum',

From c7d6b653cea5c4144bed123d5f9ff1fa4b668e73 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 20 Mar 2025 22:18:18 -0700
Subject: [PATCH 0069/1769] [sharding_in_types] Add `core.ShardingTypeError` as
 a new Exception that sharding-in-types-specific errors should raise.

This is so that we can catch this exception in backward_pass/vmap and add an
extra message to inform users that this is a potential JAX bug. They should
file an issue on the repo.

Currently, we only raise `ShardingTypeError` in one place, but we can expand
to all other places in follow-up changes. This change sets the machinery up.
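Concretely, `backward_pass` wraps the transpose call in a catch-and-annotate
step. A condensed sketch of the pattern (the real hunk is in
`jax/_src/interpreters/ad.py` below; the `extra_msg` membership check keeps
nested backward passes from stacking the hint more than once):

```python
try:
  cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
except core.ShardingTypeError as e:
  extra_msg = ("This is a potential JAX bug. Please file an issue at"
               " https://github.com/jax-ml/jax/issues")
  if extra_msg in str(e):
    raise  # already annotated once; re-raise unchanged
  raise core.ShardingTypeError(f"{str(e)}\n{extra_msg}")
```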
Previous error:
```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
```

New error:
```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
This is a potential JAX bug. Please file an issue at https://github.com/jax-ml/jax/issues
```

The newly added `This is a potential JAX bug...` message is important because
this error is raised in the backward pass, which is 100% a JAX bug given that
the forward pass did not error.

PiperOrigin-RevId: 739053305
---
 jax/_src/core.py            | 4 ++++
 jax/_src/interpreters/ad.py | 6 ++++++
 jax/_src/lax/slicing.py     | 2 +-
 jax/_src/lax/utils.py       | 2 +-
 4 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/jax/_src/core.py b/jax/_src/core.py
index 36ce2f004ed4..243ffc871042 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1786,6 +1786,10 @@ def _invalid_shape_error(shape: Shape, context: str=""):
   return TypeError(msg)

+class ShardingTypeError(Exception):
+  pass
+
+
 # TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values
 # passed to primitives are always have avals, etc i.e. they are canonical.
 def canonicalize_value(val):

diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index ddf96af6a010..e47e518a11f2 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -407,6 +407,12 @@ def write_primal(v, val):
     try:
       cts_out = get_primitive_transpose(eqn.primitive)(
           cts_in, *invals, **eqn.params)
+    except core.ShardingTypeError as e:
+      extra_msg = ("This is a potential JAX bug. Please file an issue at"
+                   " https://github.com/jax-ml/jax/issues")
+      if extra_msg in str(e):
+        raise
+      raise core.ShardingTypeError(f"{str(e)}\n{extra_msg}")
     except (FloatingPointError, ZeroDivisionError) as e:
       msg = "When differentiating the code at the top of the callstack:"
       if msg not in e.args[0]:

diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index b4a8817fbb8d..b3a0a8e2d0c1 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1607,7 +1607,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices):

 def _dynamic_update_slice_sharding_rule(operand, update, *start_indices):
   if operand.sharding != update.sharding:
-    raise TypeError(
+    raise core.ShardingTypeError(
         "dynamic_update_slice update sharding must be equal to operand"
         " sharding, got update sharding"
         f" {update.str_short(mesh_axis_types=True)} for operand sharding"

diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index f39d925ac2ad..9fc9ba16a604 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -74,7 +74,7 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
       s = NamedSharding(aval_mesh, P())
       return s if num_out is None else [s] * num_out
   if rule is None:
-    raise ValueError(
+    raise core.ShardingTypeError(
         f'sharding rule for {prim.name} is not implemented. Please file a'
         ' bug at https://github.com/jax-ml/jax/issues. You can work around'
         ' this error by dropping that operation into full auto sharding'

From 7953e6d0f88c068e0af9f38c0dd0c0c3ce05688a Mon Sep 17 00:00:00 2001
From: Zac Mustin
Date: Fri, 21 Mar 2025 00:35:54 -0700
Subject: [PATCH 0070/1769] Add tests for varying `{batch, feature}_group_count`s
 for roofline `conv`.

We'll need to use batch/feature when calculating flops, so it'll help reduce the size of the "calculating-flops" change if we can include them in our tests now.

PiperOrigin-RevId: 739081930
---
 tests/roofline_test.py | 88 ++++++++++++++++++++++++++++--------------
 1 file changed, 58 insertions(+), 30 deletions(-)

diff --git a/tests/roofline_test.py b/tests/roofline_test.py
index 564b4a9a1f9e..f94f5a328c46 100644
--- a/tests/roofline_test.py
+++ b/tests/roofline_test.py
@@ -572,40 +572,63 @@ def test_dot_general(self):
         result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5)
     )

-  def get_conv_output_dim(self, i, k, pad_low, pad_high, stride):
+  def get_conv_output_dim(self, i, k, pad_low, pad_high, stride) -> int:
     return jnp.floor((i - k + pad_low + pad_high) / stride) + 1

-  @jtu.parameterized.named_parameters(
-      dict(
-          testcase_name="simple",
-          window_strides=(1, 1),
-          padding=((0, 0), (0, 0)),
-      ),
-      dict(
-          testcase_name="padding",
-          window_strides=(1, 1),
-          padding=((1, 2), (3, 4)),
-      ),
-      dict(
-          testcase_name="window_strides",
-          window_strides=(2, 2),
-          padding=((0, 0), (0, 0)),
-      ),
-      dict(
-          testcase_name="window_strides_and_padding",
-          window_strides=(3, 3),
-          padding=((1, 2), (3, 4)),
-      ),
+  def get_conv_num_output_channels(
+      self, batch_group_count: int, feature_group_count: int
+  ) -> int:
+    if batch_group_count > 1:
+      return batch_group_count
+    elif feature_group_count > 1:
+      return feature_group_count
+    else:
+      return 1
+
+  @jtu.parameterized.product(
+      window_strides=[(1, 1), (2, 2)],
+      padding=[((0, 0), (0, 0)), ((1, 2), (3, 4))],
+      # batch must be divisible by batch_group_count, so we only include
+      # multiples of batch_group_count.
+      batch=[6, 12],
+      batch_group_count=[1, 3],
+      # num_input_channels must be divisible by feature_group_count, so we only
+      # include multiples of feature_group_count.
+ num_input_channels=[6, 12], + feature_group_count=[1, 3], ) def test_conv_general_dilated_unfused_hbm_bytes( - self, window_strides: Sequence[int, int], padding: Sequence[int, int] + self, + window_strides: Sequence[int, int], + padding: Sequence[int, int], + batch: int, + batch_group_count: int, + num_input_channels: int, + feature_group_count: int, ): + if batch_group_count > 1 and feature_group_count > 1: + self.skipTest( + "batch_group_count and feature_group_count cannot both be > 1" + ) + + num_output_channels = self.get_conv_num_output_channels( + batch_group_count, feature_group_count + ) + + num_input_features = int(num_input_channels / feature_group_count) iw, ih = 100, 200 kw, kh = 7, 7 - input_data = jnp.zeros((1, 1, iw, ih), dtype=int) - kernel_data = jnp.ones((1, 1, kw, kh), dtype=int) + input_data = jnp.zeros((batch, num_input_channels, iw, ih), dtype=int) + kernel_data = jnp.ones( + (num_output_channels, num_input_features, kw, kh), dtype=int + ) conv = lambda a, b: lax.conv_general_dilated( - lhs=a, rhs=b, window_strides=window_strides, padding=padding + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + batch_group_count=batch_group_count, + feature_group_count=feature_group_count, ) _, result = roofline.roofline( @@ -615,8 +638,8 @@ def test_conv_general_dilated_unfused_hbm_bytes( out_specs=P(), )(input_data, kernel_data) - expected_input_size = 1 * 1 * iw * ih - expected_kernel_size = 1 * 1 * kw * kh + expected_input_size = batch * num_input_channels * iw * ih + expected_kernel_size = num_output_channels * num_input_features * kw * kh ow = self.get_conv_output_dim( iw, kw, padding[0][0], padding[0][1], window_strides[0] @@ -624,7 +647,10 @@ def test_conv_general_dilated_unfused_hbm_bytes( oh = self.get_conv_output_dim( ih, kh, padding[1][0], padding[1][1], window_strides[1] ) - expected_output_size = 1 * 1 * ow * oh + expected_output_shape = jnp.array( + (batch / batch_group_count, num_output_channels, ow, oh) + ) + expected_output_size = jnp.prod((expected_output_shape)) # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size @@ -642,7 +668,9 @@ def test_conv_general_dilated_unfused_hbm_bytes( padding="SAME_LOWER", ), ) - def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str): + def test_conv_general_dilated_padding_string_unfused_hbm_bytes( + self, padding: str + ): input_data = jnp.zeros((1, 1, 10, 20), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( From ad21b62bfec5560d4c612ed3c8412eb2d240468b Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 21 Mar 2025 02:41:42 -0700 Subject: [PATCH 0071/1769] [AutoPGLE] Prevent an AutoPGLE to run if user launched an external profiler. PiperOrigin-RevId: 739109278 --- tests/pgle_test.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7f9ea598d51b..7dabd809d95e 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -65,7 +65,11 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. 
+ 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -93,6 +97,8 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -321,7 +327,11 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y From 5fef4cff7a37d0bb4d7004741189880b357699a2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 21 Mar 2025 03:41:01 -0700 Subject: [PATCH 0072/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/469329ec36be093fd71d29e4518402300e04aeec. PiperOrigin-RevId: 739121877 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b20048193bd1..00f985cdf352 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "6e2c9b024cec7dca4b2e1b07cc89373574c9c5af" -XLA_SHA256 = "387917467d6f6e8358d54ba2b89f3fef14a00e62d8b0a096bf07acc8186444d4" +XLA_COMMIT = "469329ec36be093fd71d29e4518402300e04aeec" +XLA_SHA256 = "9de006d7b51c36057898c81111fa9723b59f024eced067572fe5f6b1df63abdd" def repo(): tf_http_archive( From be5713309521d5cf0d2252b9c8f1d38ab50952d1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 21 Mar 2025 05:18:03 -0700 Subject: [PATCH 0073/1769] Delay the unflattening in `jnp.array` PiperOrigin-RevId: 739143346 --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96efc48062e1..16355695792d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,15 +49,16 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.numpy.array_creation import (empty, empty_like, full, + ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize +from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape ) @@ -65,8 +66,7 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax._src.sharding_impls import SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_map +from 
jax.tree_util import tree_flatten, tree_map import numpy as np export = set_module('jax.numpy') @@ -5504,9 +5504,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = xc._xla.cuda_array_interface_to_buffer( cai=cai, gpu_backend=backend, device_id=device_id) - object = tree_map(lambda leaf: leaf.__jax_array__() - if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object, is_leaf=lambda x: x is None) + leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): @@ -5515,7 +5513,13 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, "None encountered in jnp.array(); this is currently treated as NaN. " "In the future this will result in an error.", FutureWarning, stacklevel=2) - leaves = tree_leaves(object) + leaves, treedef = tree_flatten(object) + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. @@ -5530,8 +5534,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if not weak_type: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + object = treedef.unflatten(leaves) out: ArrayLike - if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in From 7f0f185abd84b9b704d64f89c7fce0236b7c3403 Mon Sep 17 00:00:00 2001 From: Arno Eigenwillig Date: Fri, 21 Mar 2025 05:56:03 -0700 Subject: [PATCH 0074/1769] In JEP-12049, fix link to EAFP in the Python glossary: the anchor became mixed-case as of Python 3.10. PiperOrigin-RevId: 739150752 --- docs/jep/12049-type-annotations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 7a20958c5cab..5ed760dd6c5c 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -122,7 +122,7 @@ All told, the array-type-granularity challenge is less of an issue than the othe ### Challenge 5: imprecise APIs inherited from NumPy A large part of JAX’s user-facing API is inherited from NumPy within the {mod}`jax.numpy` submodule. -NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-eafp) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: +NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-EAFP) coding style, in which strict type-checking at runtime is discouraged. 
As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: ```python def tile(A, reps): From a93035f6250672230675290af82a829f0b0dd862 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 06:25:37 -0700 Subject: [PATCH 0075/1769] Migrate xla_client and its Python tests out of XLA into JAX. This change copies targets into jaxlib, and a subsequent change will delete them from XLA. We separate these into two phases because we cannot atomically change both JAX and XLA. Future changes will migrate more of the C++ pieces of XLA:Python. PiperOrigin-RevId: 739158120 --- jax/_src/lib/BUILD | 3 +- jax/_src/lib/__init__.py | 22 +- jaxlib/BUILD | 4 +- jaxlib/jax.bzl | 3 + jaxlib/xla/BUILD | 162 + jaxlib/xla/config_test.py | 71 + jaxlib/xla/custom_calls_testlib.cc | 128 + jaxlib/xla/jax_jit_test.py | 47 + jaxlib/xla/pytree_test.py | 144 + jaxlib/xla/weakref_lru_cache_test.py | 257 ++ jaxlib/xla/xla_client.py | 1044 +++++ jaxlib/xla/xla_client.pyi | 322 ++ .../xla_client_backend_independent_test.py | 195 + jaxlib/xla/xla_client_test.py | 3714 +++++++++++++++++ pyproject.toml | 5 +- 15 files changed, 6106 insertions(+), 15 deletions(-) create mode 100644 jaxlib/xla/BUILD create mode 100644 jaxlib/xla/config_test.py create mode 100644 jaxlib/xla/custom_calls_testlib.cc create mode 100644 jaxlib/xla/jax_jit_test.py create mode 100644 jaxlib/xla/pytree_test.py create mode 100644 jaxlib/xla/weakref_lru_cache_test.py create mode 100644 jaxlib/xla/xla_client.py create mode 100644 jaxlib/xla/xla_client.pyi create mode 100644 jaxlib/xla/xla_client_backend_independent_test.py create mode 100644 jaxlib/xla/xla_client_test.py diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 1fcbd4b6b7ef..1f4f41132e9e 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -44,6 +44,7 @@ py_library_providing_imports_info( "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib:cpu_feature_guard", "//jaxlib:utils", + "//jaxlib/xla:xla_client", "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", @@ -60,6 +61,6 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", - # xla_client + # xla_extension ]), ) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 70dc914668cf..be551449aa17 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -40,7 +40,7 @@ raise ImportError(msg) from err -# Checks the jaxlib version before importing anything else from jaxlib. +# Checks the jaxlib version before importing anything else. # Returns the jaxlib version string. def check_jaxlib_version(jax_version: str, jaxlib_version: str, minimum_jaxlib_version: str) -> tuple[int, ...]: @@ -77,20 +77,23 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_version=jaxlib.version.__version__, minimum_jaxlib_version=jax.version._minimum_jaxlib_version) -# Before importing any C compiled modules from jaxlib, first import the CPU +# Before importing any C compiled modules, first import the CPU # feature guard module to verify that jaxlib was compiled in a way that only # uses instructions that are present on this machine. 
import jaxlib.cpu_feature_guard as cpu_feature_guard cpu_feature_guard.check_cpu_features() -import jaxlib.utils as utils # noqa: F401 -import jaxlib.xla_client as xla_client import jaxlib.lapack as lapack # noqa: F401 +import jaxlib.utils as utils # noqa: F401 +import jaxlib.xla_extension as xla_extension # noqa: F401 +from jaxlib.xla_extension import guard_lib as guard_lib # noqa: F401 +from jaxlib.xla_extension import jax_jit as jax_jit # noqa: F401 +from jaxlib.xla_extension import pmap_lib as pmap_lib # noqa: F401 +from jaxlib.xla_extension import pytree as pytree # noqa: F401 +import jaxlib.xla_client as xla_client # noqa: F401 + +from jaxlib.xla_extension import Device as Device # noqa: F401 -xla_extension = xla_client._xla -pytree = xla_client._xla.pytree -jax_jit = xla_client._xla.jax_jit -pmap_lib = xla_client._xla.pmap_lib # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): @@ -167,6 +170,3 @@ def _try_cuda_nvcc_import() -> str | None: return None cuda_path = _cuda_path() - -guard_lib = xla_client._xla.guard_lib -Device = xla_client._xla.Device diff --git a/jaxlib/BUILD b/jaxlib/BUILD index faf52a702386..2397639fddf2 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -81,7 +81,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", - "@xla//xla/python:xla_extension", + "//jaxlib/xla:xla_client", ], ) @@ -94,7 +94,7 @@ symlink_files( symlink_files( name = "xla_client", - srcs = ["@xla//xla/python:xla_client"], + srcs = ["//jaxlib/xla:xla_client"], dst = ".", flatten = True, ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 02e6b10b1de1..4403915154bc 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -132,6 +132,9 @@ def pytype_strict_library(name, pytype_srcs = [], **kwargs): new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} native.py_library(name = name, data = data, **new_kwargs) +py_strict_library = native.py_library +py_strict_test = native.py_test + def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD new file mode 100644 index 000000000000..41152d642fc8 --- /dev/null +++ b/jaxlib/xla/BUILD @@ -0,0 +1,162 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "//jaxlib:jax.bzl", + "nanobind_extension", + "py_deps", + "py_strict_library", + "py_strict_test", + "pytype_strict_library", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:internal"], +) + +package_group( + name = "xla_python", + includes = [ + "//jax:internal", + ], +) + +pytype_strict_library( + name = "xla_client", + srcs = ["xla_client.py"], + pytype_srcs = ["xla_client.pyi"], + visibility = [":xla_python"], + deps = py_deps([ + "numpy", + "ml_dtypes", + ]) + ["@xla//xla/python:xla_extension"], +) + +py_strict_test( + name = "xla_client_backend_independent_test", + srcs = ["xla_client_backend_independent_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/testing", + "numpy", + "portpicker", + ]), +) + +py_strict_library( + name = "xla_client_test", + testonly = 1, + srcs = ["xla_client_test.py"], + visibility = [":xla_python"], + deps = [ + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +nanobind_extension( + name = "custom_calls_testlib", + testonly = 1, + srcs = ["custom_calls_testlib.cc"], + deps = [ + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + ], +) + +py_strict_test( + name = "xla_client_test_cpu", + srcs = ["xla_client_test.py"], + args = ["--backend=cpu"], + env = { + "XLA_FLAGS": "--xla_force_host_platform_device_count=4", + }, + main = "xla_client_test.py", + deps = [ + ":custom_calls_testlib", + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +py_strict_test( + name = "weakref_lru_cache_test", + srcs = ["weakref_lru_cache_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "pytree_test", + srcs = ["pytree_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "config_test", + srcs = ["config_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "jax_jit_test", + srcs = ["jax_jit_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "numpy", + ]), +) diff --git a/jaxlib/xla/config_test.py b/jaxlib/xla/config_test.py new file mode 100644 index 000000000000..8701a37acd1d --- /dev/null +++ b/jaxlib/xla/config_test.py @@ -0,0 +1,71 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import threading + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +config = xla_client._xla.config + + +class ConfigTest(absltest.TestCase): + + def testBasic(self): + c = config.Config(1) + self.assertEqual(c.value, 1) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.get_local(), config.unset) + + c.set_global(2) + self.assertEqual(c.value, 2) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), config.unset) + + c.set_local(3) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), 3) + + c.set_global(4) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), 3) + + c.set_local(config.unset) + self.assertEqual(c.value, 4) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), config.unset) + + def testThreading(self): + c = config.Config(1) + + def Body(): + for i in range(100): + c.set_local(i) + self.assertEqual(c.get_local(), i) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.value, i) + + threads = [threading.Thread(target=Body) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/custom_calls_testlib.cc b/jaxlib/xla/custom_calls_testlib.cc new file mode 100644 index 000000000000..d06105fce76f --- /dev/null +++ b/jaxlib/xla/custom_calls_testlib.cc @@ -0,0 +1,128 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "third_party/nanobind/include/nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace xla::ffi { +namespace nb = ::nanobind; + +// Implement custom calls as static functions with XLA FFI types in the function +// signature that gives access to the arguments and results buffers together +// with their types and dimensions. See `ffi/api/ffi_test.cc` for more XLA FFI +// examples and features (e.g. binding attributes, custom user-defined structs +// and arbitrary execution context). + +static Error AlwaysFail(Result) { + return Error(XLA_FFI_Error_Code_INTERNAL, "Failed intentionally"); +} + +static Error AlwaysSucceed(Result) { return Error::Success(); } + +static Error Subtract(BufferR0 a, BufferR0 b, + Result> out) { + *out->typed_data() = *a.typed_data() - *b.typed_data(); + return Error::Success(); +} + +static Error SubtractCst(BufferR0 a, + Result> out, float cst) { + *out->typed_data() = *a.typed_data() - cst; + return Error::Success(); +} + +// Define XLA FFI handlers from the implementations defined above using explicit +// XLA FFI binding API to describe type signatures of custom calls. 
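+// For example, kSubtract below binds two scalar f32 buffer arguments and one
+// scalar f32 result, so the runtime can check each call site against the
+// declared signature before dispatching to Subtract.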
+ +XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, Ffi::Bind().Ret()); + +XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed, + Ffi::Bind().Ret()); + +XLA_FFI_DEFINE_HANDLER(kSubtract, Subtract, + Ffi::Bind() + .Arg>() + .Arg>() + .Ret>()); + +XLA_FFI_DEFINE_HANDLER(kSubtractCst, SubtractCst, + Ffi::Bind() + .Arg>() + .Ret>() + .Attr("cst")); + +// XLA FFI calls can also be stateful. +struct TestFfiState { + static TypeId id; + explicit TestFfiState(int32_t value) : value(value) {} + int32_t value; +}; +TypeId TestFfiState::id = {}; + +static ErrorOr> StateInstantiate() { + return std::make_unique(42); +} + +static Error StateExecute(TestFfiState* state, + Result> out) { + *out->typed_data() = state->value; + return Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate, + Ffi::BindInstantiate()); +XLA_FFI_DEFINE_HANDLER( + kStateExecute, StateExecute, + Ffi::Bind().Ctx>().Ret>()); + +template +static auto BindFunction(T* fn) { + return nb::capsule(reinterpret_cast(fn)); +} + +template +static auto BindTypeId(T* typeId) { + return nb::capsule(reinterpret_cast(typeId)); +} + +// Custom calls registration library that exports function pointers to XLA FFI +// handlers to the python users. +NB_MODULE(custom_calls_testlib, m) { + m.def("registrations", []() { + nb::dict dict; + dict["always_fail"] = BindFunction(kAlwaysFail); + dict["always_succeed"] = BindFunction(kAlwaysSucceed); + dict["subtract_f32"] = BindFunction(kSubtract); + dict["subtract_f32_cst"] = BindFunction(kSubtractCst); + + nb::dict bundle; + bundle["instantiate"] = BindFunction(kStateInstantiate); + bundle["execute"] = BindFunction(kStateExecute); + dict["stateful"] = bundle; + + return dict; + }); + m.def("type_ids", []() { + nb::dict type_ids; + type_ids["test_ffi_state"] = BindTypeId(&TestFfiState::id); + return type_ids; + }); +} + +} // namespace xla::ffi diff --git a/jaxlib/xla/jax_jit_test.py b/jaxlib/xla/jax_jit_test.py new file mode 100644 index 000000000000..a090bc8dfadd --- /dev/null +++ b/jaxlib/xla/jax_jit_test.py @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for jax_jit helper functions.""" + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +jax_jit = xla_client._xla.jax_jit +pytree = xla_client._xla.pytree + +pytree_registry = pytree.default_registry() + + +class JaxJitTest(absltest.TestCase): + + def testParseArguments(self): + sig, args = jax_jit.parse_arguments( + positional_args=[1, 2, 3], + keyword_args=[4, 5], + kwnames=("a", "b"), + static_argnums=[0, 2], + static_argnames=["a"], + pytree_registry=pytree_registry, + ) + self.assertEqual(args, [2, 5]) + self.assertEqual(sig.static_args, [1, 3, 4]) + self.assertEqual(sig.static_arg_names, ["a"]) + _, leaf = pytree_registry.flatten(0) + self.assertEqual(sig.dynamic_arg_names, ["b"]) + self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf]) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/pytree_test.py b/jaxlib/xla/pytree_test.py new file mode 100644 index 000000000000..b5ac7dd5b4d2 --- /dev/null +++ b/jaxlib/xla/pytree_test.py @@ -0,0 +1,144 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import collections +import dataclasses +import gc + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +pytree = xla_client._xla.pytree + + +ExampleType = collections.namedtuple("ExampleType", "field0 field1") + +registry = pytree.PyTreeRegistry() + + +class ExampleType2: + + def __init__(self, field0, field1): + self.field0 = field0 + self.field1 = field1 + + def to_iterable(self): + return [self.field0, self.field1], (None,) + + +def from_iterable(state, values): + del state + return ExampleType2(field0=values[0], field1=values[1]) + + +registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable) + + +@dataclasses.dataclass +class Custom: + a: int + b: str + + +registry.register_dataclass_node(Custom, ["a"], ["b"]) + + +class PyTreeTest(absltest.TestCase): + + def roundtrip(self, example): + original = registry.flatten(example)[1] + self.assertEqual( + pytree.PyTreeDef.deserialize_using_proto( + registry, original.serialize_using_proto() + ), + original, + ) + + def testSerializeDeserializeNoPickle(self): + o = object() + self.roundtrip(({"a": o, "b": o}, [o, (o, o), None])) + + def testSerializeWithFallback(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip({"a": ExampleType(field0=o, field1=o)}) + + def testRegisteredType(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip({"a": ExampleType2(field0=o, field1=o)}) + + def roundtrip_node_data(self, example): + original = registry.flatten(example)[1] + restored = pytree.PyTreeDef.make_from_node_data_and_children( + registry, original.node_data(), original.children() + ) + self.assertEqual(restored, original) + + def testRoundtripNodeData(self): + o = object() + self.roundtrip_node_data([o, o, o]) + 
self.roundtrip_node_data((o, o, o)) + self.roundtrip_node_data({"a": o, "b": o}) + self.roundtrip_node_data({22: o, 88: o}) + self.roundtrip_node_data(None) + self.roundtrip_node_data(o) + self.roundtrip_node_data(ExampleType(field0=o, field1=o)) + self.roundtrip_node_data(ExampleType2(field0=o, field1=o)) + + def testCompose(self): + x = registry.flatten(0)[1] + y = registry.flatten((0, 0))[1] + self.assertEqual((x.compose(y)).num_leaves, 2) + + def testDataclassMakeFromNodeData(self): + c = Custom(1, "a") + c_leafs, c_tree = registry.flatten(c) + c_tree2 = c_tree.make_from_node_data_and_children( + registry, c_tree.node_data(), c_tree.children() + ) + self.assertEqual(c_tree2.unflatten(c_leafs), c) + self.assertEqual(str(c_tree2), str(c_tree)) + + def testTpTraverse(self): + self.assertContainsSubset( + [ + pytree.PyTreeRegistry, + ExampleType2, + ExampleType2.to_iterable, + from_iterable, + ], + gc.get_referents(registry), + ) + k1 = "k1" + k2 = "k2" + + t = ExampleType("a", "b") + _, treedef = registry.flatten([1, {k1: 2, k2: t}, 5, t]) + + self.assertContainsSubset( + [ + pytree.PyTreeDef, + registry, + k1, + k2, + ExampleType, + ], + gc.get_referents(treedef), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/weakref_lru_cache_test.py b/jaxlib/xla/weakref_lru_cache_test.py new file mode 100644 index 000000000000..6ac3bfd71075 --- /dev/null +++ b/jaxlib/xla/weakref_lru_cache_test.py @@ -0,0 +1,257 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import gc +import threading +import time +import weakref + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + + +class WeakrefLRUCacheTest(absltest.TestCase): + + def testMultiThreaded(self): + insert_evs = [threading.Event() for _ in range(2)] + insert_evs_i = 0 + + class WRKey: + pass + + class ClashingKey: + + def __eq__(self, other): + return False + + def __hash__(self): + return 333 # induce maximal caching problems. + + class GilReleasingCacheKey: + + def __eq__(self, other): + nonlocal insert_evs_i + if isinstance(other, GilReleasingCacheKey) and insert_evs_i < len( + insert_evs + ): + insert_evs[insert_evs_i].set() + insert_evs_i += 1 + time.sleep(0.01) + return False + + def __hash__(self): + return 333 # induce maximal caching problems. 
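+    # Both key types hash to 333, so every insert collides in one bucket and
+    # forces __eq__ calls; GilReleasingCacheKey.__eq__ releases the GIL
+    # mid-lookup while the other thread mutates the cache.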
+ + def CacheFn(obj, gil_releasing_cache_key): + del obj + del gil_releasing_cache_key + return None + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 2048) + + wrkey = WRKey() + + def Body(): + for insert_ev in insert_evs: + insert_ev.wait() + for _ in range(20): + cache(wrkey, ClashingKey()) + + t = threading.Thread(target=Body) + t.start() + for _ in range(3): + cache(wrkey, GilReleasingCacheKey()) + t.join() + + def testAnotherMultiThreaded(self): + num_workers = 5 + barrier = threading.Barrier(num_workers) + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + + class WRKey: + pass + + def WorkerAddToCache(): + barrier.wait() + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + + def WorkerCleanCache(): + barrier.wait() + for _ in range(10): + cache.cache_clear() + + workers = [ + threading.Thread(target=WorkerAddToCache) + for _ in range(num_workers - 1) + ] + [threading.Thread(target=WorkerCleanCache)] + + for t in workers: + t.start() + + for t in workers: + t.join() + + def testKwargsDictOrder(self): + miss_id = 0 + + class WRKey: + pass + + def CacheFn(obj, kwkey1, kwkey2): + del obj, kwkey1, kwkey2 + nonlocal miss_id + miss_id += 1 + return miss_id + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4) + + wrkey = WRKey() + + self.assertEqual(cache(wrkey, kwkey1="a", kwkey2="b"), 1) + self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2) + self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1) + + def testGetKeys(self): + def CacheFn(obj, arg): + del obj + return arg + "extra" + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4) + + class WRKey: + pass + + wrkey = WRKey() + + self.assertEmpty(cache.cache_keys()) + cache(wrkey, "arg1") + cache(wrkey, "arg2") + self.assertLen(cache.cache_keys(), 2) + + def testNonWeakreferenceableKey(self): + class NonWRKey: + __slots__ = () + + non_wr_key = NonWRKey() + with self.assertRaises(TypeError): + weakref.ref(non_wr_key) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x: 2048) + for _ in range(100): + with self.assertRaises(TypeError): + cache(non_wr_key) + + def testCrashingKey(self): + class WRKey: + pass + + class CrashingKey: + # A key that raises exceptions if eq or hash is called. + + def __eq__(self, other): + raise ValueError("eq") + + def __hash__(self): + raise ValueError("hash") + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + wrkey = WRKey() + with self.assertRaises(ValueError): + for _ in range(100): + cache(wrkey, CrashingKey()) + + def testPrintingStats(self): + class WRKey: + pass + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + for i in range(5): + cache(wrkey, i) + + self.assertEqual( + repr(cache.cache_info()), + "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", + ) + + def testGCKeys(self): + class WRKey: + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + keys = [WRKey(i) for i in range(10)] + for i in range(10): + cache(keys[i], i) + + # Delete some keys, to exercise the weakref callback behavior. 
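+    # (Once their only strong references are gone, the cache's weakref
+    # callbacks evict the dead entries; the surviving keys below must still
+    # work.)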
+ del keys[::2] + + for key in keys: + cache(key, 7) + + def testTpTraverse(self): + class WRKey: + pass + + def CacheContextFn(): + return None + + def CallFn(x, y, *args, **kwargs): + del x, args, kwargs + return y + + cache = xla_client.weakref_lru_cache(CacheContextFn, CallFn, 2048) + + keys = [WRKey() for _ in range(10)] + values = [str(i) for i in range(10)] + args = [str(i) for i in range(10)] + kwargs = {"a": "b"} + + for key, value in zip(keys, values): + cache(key, value, *args, **kwargs) + + expected_refs = ( + [ + CacheContextFn, + CallFn, + xla_client._xla.WeakrefLRUCache, + kwargs, + ] + + [weakref.getweakrefs(key)[0] for key in keys] + + values + + args + ) + + # Can't use assertContainsSubset because it doesn't support kwargs since + # dicts aren't hashable. + for ref in expected_refs: + self.assertIn(ref, gc.get_referents(cache)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py new file mode 100644 index 000000000000..b6c5707d05dd --- /dev/null +++ b/jaxlib/xla/xla_client.py @@ -0,0 +1,1044 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An XLA client in Python.""" + +from __future__ import annotations + +import atexit +from collections.abc import Mapping, Sequence +import contextlib +import enum # pylint: disable=g-bad-import-order +import gzip +import inspect +import logging +import os +import threading +from typing import Any, Protocol, Union + +import ml_dtypes +import numpy as np + +from jaxlib import xla_extension as _xla + +# Note this module does *not* depend on any Python protocol buffers. The XLA +# Python bindings are currently packaged both as part of jaxlib and as part +# of TensorFlow. If we use protocol buffers here, then importing both jaxlib +# and TensorFlow may fail with duplicate protocol buffer message definitions. + +# Most functions are snake_case for consistency with other modules, some +# method names are CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +# Pylint has false positives for type annotations. +# pylint: disable=invalid-sequence-index + +ifrt_programs = _xla.ifrt_programs +ops = _xla.ops +profiler = _xla.profiler + +# Just an internal arbitrary increasing number to help with backward-compatible +# changes. In JAX, reference this via jax._src.lib.xla_extension_version. +_version = 320 + +# Version number for MLIR:Python components. 
+mlir_api_version = 58 + +xla_platform_names = { + 'cpu': 'Host', + 'gpu': 'CUDA', +} + +logger = logging.getLogger(__name__) + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + + +def make_cpu_client( + asynchronous=True, + distributed_client=None, + node_id=0, + num_nodes=1, + collectives=None, + num_devices=None, +) -> ...: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + register_custom_type_id_handler('cpu', _xla.register_custom_type_id) + return _xla.get_tfrt_cpu_client( + asynchronous=asynchronous, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + collectives=collectives, + num_devices=num_devices, + ) + + +def make_gpu_client( + distributed_client=None, + node_id=0, + num_nodes=1, + platform_name=None, + allowed_devices=None, + mock=False, + mock_gpu_topology=None, +): + """Returns a GPU client. BFC allocator is used by default.""" + options = generate_pjrt_gpu_plugin_options() + allocator = options['allocator'] + config = _xla.GpuAllocatorConfig() + if allocator == 'default': + config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT + if allocator == 'platform': + config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM + if allocator == 'bfc': + config.kind = _xla.GpuAllocatorConfig.Kind.BFC + if allocator == 'cuda_async': + config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC + if 'memory_fraction' in options: + config.memory_fraction = options['memory_fraction'] + if 'preallocate' in options: + config.preallocate = options['preallocate'] + if 'collective_memory_size' in options: + config.collective_memory_size = options['collective_memory_size'] + register_custom_call_handler('CUDA', _xla.register_custom_call_target) + register_custom_call_handler('ROCM', _xla.register_custom_call_target) + register_custom_type_id_handler('CUDA', _xla.register_custom_type_id) + register_custom_type_id_handler('ROCM', _xla.register_custom_type_id) + + return _xla.get_gpu_client( + asynchronous=True, + allocator_config=config, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + platform_name=platform_name, + allowed_devices=allowed_devices, + mock=mock, + mock_gpu_topology=mock_gpu_topology, + ) + + +def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None): + assert pjrt_plugin_loaded('tpu') + if not pjrt_plugin_initialized('tpu'): + initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _xla.get_c_api_client('tpu', options) + + +DeviceTopology = _xla.DeviceTopology +get_topology_for_devices = _xla.get_topology_for_devices + + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) + + +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) + + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + return _xla.pjrt_plugin_loaded(plugin_name) + + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + return _xla.load_pjrt_plugin(plugin_name, None, c_api) + + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + return 
_xla.pjrt_plugin_initialized(plugin_name) + + +def initialize_pjrt_plugin(plugin_name: str) -> None: + """Initializes a PJRT plugin. + + The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or + static linking) before this method is called. + Args: + plugin_name: the name of the PJRT plugin. + """ + _xla.initialize_pjrt_plugin(plugin_name) + + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: _xla.DistributedRuntimeClient | None = None, +): + """Creates a PJRT C API client for a PJRT plugin. + + It is required that load_pjrt_plugin_dynamically is called once with the same + plugin_name before this method is called. + + Args: + plugin_name: the name of the PJRT plugin. + options: extra platform-specific options. + distributed_client: distributed client. + + Returns: + A PJRT C API client for plugin_name. + """ + if options is None: + options = {} + return _xla.get_c_api_client(plugin_name, options, distributed_client) + + +def make_tpu_client( + library_path: str | None = None, options: _NameValueMapping | None = None +): + """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" + if not pjrt_plugin_loaded('tpu'): + c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') + profiler.register_plugin_profiler(c_api) + return make_tfrt_tpu_c_api_client(options) + + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + """Generates the PjRt GPU plugin options. + + Returns: + A dictionary of plugin options. + """ + + options = {} + options['platform_name'] = 'cuda' + allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() + memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '') + deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') + if deprecated_memory_fraction: + if memory_fraction: + raise ValueError( + 'XLA_CLIENT_MEM_FRACTION is specified together ' + 'with XLA_PYTHON_CLIENT_MEM_FRACTION. ' + 'Remove the latter one, it is deprecated.' 
+ ) + else: + memory_fraction = deprecated_memory_fraction + preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) + if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): + raise ValueError( + 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' + '"bfc", or "cuda_async", got "%s"' % allocator + ) + options['allocator'] = allocator + if memory_fraction: + options['memory_fraction'] = float(memory_fraction) + if preallocate: + options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) + return options + + +class OpMetadata: + """Python representation of a xla.OpMetadata protobuf.""" + + __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') + + def __init__(self, op_type='', op_name='', source_file='', source_line=0): + self.op_type = op_type + self.op_name = op_name + self.source_file = source_file + self.source_line = source_line + + +def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): + """Helper for use in source mapping that returns an OpMetadata object.""" + full_filename, lineno = inspect.stack()[skip_frames][1:3] + filename = os.path.basename(full_filename) + return OpMetadata( + op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno + ) + + +PrimitiveType = _xla.PrimitiveType + +bfloat16 = ml_dtypes.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# Also, it would be better to conditionally import these based on whether they +# are in the current version of ml_dtypes. +# float4_e2m1fn = ml_dtypes.float4_e2m1fn +# float8_e3m4 = ml_dtypes.float8_e3m4 +# float8_e4m3 = ml_dtypes.float8_e4m3 +# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu +float8_e4m3fn = ml_dtypes.float8_e4m3fn +float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz +float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz +float8_e5m2 = ml_dtypes.float8_e5m2 +float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz + +XLA_ELEMENT_TYPE_TO_DTYPE = { + PrimitiveType.PRED: np.dtype('bool'), + PrimitiveType.S4: np.dtype('int4'), + PrimitiveType.S8: np.dtype('int8'), + PrimitiveType.S16: np.dtype('int16'), + PrimitiveType.S32: np.dtype('int32'), + PrimitiveType.S64: np.dtype('int64'), + PrimitiveType.U4: np.dtype('uint4'), + PrimitiveType.U8: np.dtype('uint8'), + PrimitiveType.U16: np.dtype('uint16'), + PrimitiveType.U32: np.dtype('uint32'), + PrimitiveType.U64: np.dtype('uint64'), + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), + # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), + # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), + PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), + PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), + PrimitiveType.F8E5M2: np.dtype(float8_e5m2), + PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz), + PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz), + PrimitiveType.BF16: np.dtype(bfloat16), + PrimitiveType.F16: np.dtype('float16'), + PrimitiveType.F32: np.dtype('float32'), + PrimitiveType.F64: np.dtype('float64'), + PrimitiveType.C64: np.dtype('complex64'), + PrimitiveType.C128: np.dtype('complex128'), + PrimitiveType.TUPLE: np.dtype(np.object_), + PrimitiveType.TOKEN: np.dtype(np.object_), +} + +# Note the conversion on the key. 
Numpy has a known issue wherein dtype hashing
+# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
+# when keying by dtype in this dict, we use the string form of dtypes.
+DTYPE_TO_XLA_ELEMENT_TYPE = {
+    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
+}
+
+
+def dtype_to_etype(dtype):
+  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
+  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
+
+
+Shape = _xla.Shape
+Shape.__doc__ = """
+A Shape is an object defined in C++ that duck types like the following class:
+
+class Shape:
+  '''Represents an XLA shape.
+
+  A shape is either an array shape, having rank-many integer
+  dimensions and an element type (represented by a Numpy dtype), or it
+  is a tuple shape, having a shape for every tuple component:
+
+  type shape =
+      TupleShape of shape list
+    | ArrayShape of { dimensions: int list; element_type: dtype }
+  '''
+
+  @staticmethod
+  def tuple_shape(tuple_shapes) -> Shape:
+    "Construct a tuple shape."
+
+  @staticmethod
+  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
+
+  @staticmethod
+  def from_pyval(pyval) -> Shape:
+    "Returns a Shape that describes a tuple-tree of Numpy arrays."
+
+  def __init__(self, str) -> Shape:
+    "Parses a shape string."
+  def __eq__(self, other: Shape) -> bool:
+  def __ne__(self, other: Shape) -> bool:
+  def __hash__(self):
+  def __repr__(self):
+  def is_tuple(self) -> bool:
+  def is_array(self) -> bool:
+  def tuple_shapes(self) -> [Shape]:
+  def numpy_dtype(self) -> np.dtype:
+    "Like element_type(), but returns dtype('O') for a tuple shape."
+  def xla_element_type(self) -> PrimitiveType:
+  def element_type(self) -> np.dtype:
+  def dimensions(self) -> (int, int, ...):
+  def rank(self) -> int:
+  def with_major_to_minor_layout_if_absent(self) -> Shape:
+    "Returns a copy with missing layouts set to major-to-minor."
+
+  def to_serialized_proto(self) -> bytes:
+    "Returns 'shape' as a serialized proto."
+"""
+
+ProgramShape = _xla.ProgramShape
+ProgramShape.__doc__ = """
+A ProgramShape is a C++ object that duck types like the following class.
+
+class ProgramShape:
+  def __init__(self, parameter_shapes, result_shape):
+  def parameter_shapes(self) -> [Shape]:
+  def result_shape(self) -> Shape:
+  def __repr__(self):
+"""
+
+ShapeIndex = _xla.ShapeIndex
+ShapeIndex.__doc__ = """
+A ShapeIndex is an object defined in C++ that duck types like the following class:
+
+class ShapeIndex:
+  '''Represents an XLA ShapeIndex.
+
+  An index for specifying a particular nested subshape within a shape. Used in
+  ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through
+  the Shape tree where each element of ShapeIndex indexes into a tuple (or
+  nested tuple) within the shape. For a non-nested tuple, an index has a single
+  element.
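+
+  For example (illustrative): given the tuple shape ((f32[2], s32[]), f32[3]),
+  the index [0, 1] selects the s32[] subshape, and the index [1] selects
+  f32[3].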
+ ''' + + def __init__(self, List[int]) -> ShapeIndex: + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): +""" + + +def shape_from_pyval(pyval, layout: Sequence[int] | None = None): + """Returns a Shape that describes a tuple-tree of Numpy arrays.""" + + def convert(pyval): + if isinstance(pyval, tuple): + if layout is not None: + raise NotImplementedError( + 'shape_from_pyval does not support layouts for tuple shapes' + ) + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) + + return convert(pyval) + + +DeviceAssignment = _xla.DeviceAssignment +DeviceAssignment.__doc__ = """ +A DeviceAssignment is a C++ object with the following signature. + +def create(assignment): + '''Builds a device assignment. + + Args: + assignment: a 2D numpy array of device ordinal integers, indexed by + [replica][computation_in_replica]. + Returns: + A device assignment. + ''' + +def replica_count(): + '''Returns the number of replicas.''' +def computation_count(): + '''Returns the number of computations per replica.''' +""" + +Device = _xla.Device +CompileOptions = _xla.CompileOptions + +HostBufferSemantics = _xla.HostBufferSemantics + +# An Executable is a C++ class that duck types with the following API: +# class Executable: +# def local_devices(self) -> [Device]: +# def execute(self, arguments : [Buffer]) -> Buffer: +# """Execute on one replica with Buffer arguments and return value.""" +# +# def size_of_generated_code_in_bytes(self) -> int: +# """Return generated binary size, or -1 if not known.""" +# +# def execute_sharded_on_local_devices(self, arguments: [[Buffer]]) +# -> [Buffer]: +# """Execute on many replicas with Buffer arguments and return value. +# +# Args: +# arguments: A sequence of sequences of Buffers. The i'th element of each +# sequence comprises the arguments for execution on the i'th local +# device. +# +# Returns: +# A list of the computation's outputs as a list of Buffers for each +# device. +# """ +# +# There are different implementations of Executable for different backends. + + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + + +def window_padding_type_to_pad_values( + padding_type, lhs_dims, rhs_dims, window_strides +): + """Maps PaddingType or string to pad values (list of pairs of ints).""" + if not isinstance(padding_type, (str, PaddingType)): + msg = 'padding_type must be str or PaddingType, got {}.' + raise TypeError(msg.format(type(padding_type))) + + if isinstance(padding_type, str): + if padding_type.upper() == 'VALID': + padding_type = PaddingType.VALID + elif padding_type.upper() == 'SAME': + padding_type = PaddingType.SAME + else: + msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' 
+ raise ValueError(msg.format(padding_type)) + + if padding_type == PaddingType.VALID: + return [(0, 0)] * len(window_strides) + elif padding_type == PaddingType.SAME: + out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) + pad_sizes = [ + max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size in zip( + out_shape, window_strides, rhs_dims, lhs_dims + ) + ] + return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] + else: + msg = 'Unexpected PaddingType value: {}' + raise ValueError(msg.format(padding_type)) + + +XlaBuilder = _xla.XlaBuilder +XlaComputation = _xla.XlaComputation +XlaOp = _xla.XlaOp +FftType = _xla.FftType +Client = _xla.Client +Memory = _xla.Memory +ArrayImpl = _xla.ArrayImpl +LoadedExecutable = _xla.LoadedExecutable +DeviceList = _xla.DeviceList +OpSharding = _xla.OpSharding +HloSharding = _xla.HloSharding +Sharding = _xla.Sharding +NamedSharding = _xla.NamedSharding +SingleDeviceSharding = _xla.SingleDeviceSharding +PmapSharding = _xla.PmapSharding +GSPMDSharding = _xla.GSPMDSharding +PjRtLayout = _xla.PjRtLayout +AutotuneCacheMode = _xla.AutotuneCacheMode +ResultAccuracyMode = _xla.ResultAccuracy_Mode + + +def LoadedExecutable_execute(self, arguments, device=None): + del device + results = self.execute_sharded(arguments) + return [x[0] for x in results.disassemble_into_single_device_arrays()] + + +def LoadedExecutable_execute_with_token(self, arguments, device=None): + del device + results = self.execute_sharded(arguments, with_tokens=True) + return ( + [x[0] for x in results.disassemble_into_single_device_arrays()], + results.consume_token().get_token(0), + ) + + +LoadedExecutable.execute = LoadedExecutable_execute +LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token + + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + # Calls to custom call are safe to trace into the command buffer. It means + # that calls to custom call always launch exactly the same device operations + # (can depend on attribute values) that can be captured and then replayed. + # + # Supported only for custom calls implemented with XLA FFI. + COMMAND_BUFFER_COMPATIBLE = 1 + + +class CustomCallHandler(Protocol): + + def __call__( + self, + name: str, + fn: Any, + platform: str, + /, + api_version: int = ..., + traits: CustomCallTargetTraits = ..., + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[ + str, list[tuple[str, Any, int, CustomCallTargetTraits]] +] = {} +_custom_callback_lock = threading.Lock() + + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = 'cpu', + api_version: int = 0, + traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT, +) -> None: + """Registers a custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. + platform: the target platform. + api_version: the XLA FFI version to use. Supported versions are: 0 for the + untyped FFI and 1 for the typed FFI. + traits: custom call traits corresponding to XLA FFI handler traits. + """ + # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" + # Since that is hardcoded to CUDA, we are using the following as workaround. 
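+  # The lookup below therefore only canonicalizes the "cpu" and "gpu" aliases;
+  # explicit names such as "CUDA" or "ROCM" pass through unchanged, which is
+  # why GPU callers register under both backend names directly.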
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_callback_lock:
+    if xla_platform_name in _custom_callback_handler:
+      _custom_callback_handler[xla_platform_name](
+          name, fn, xla_platform_name, api_version, traits
+      )
+    else:
+      _custom_callback.setdefault(xla_platform_name, []).append(
+          (name, fn, api_version, traits)
+      )
+
+
+def register_custom_call_handler(
+    platform: str, handler: CustomCallHandler
+) -> None:
+  """Registers a custom call handler and uses it to register existing custom calls.
+
+  If a custom call handler for the platform already exists, calling this method
+  is a no-op and it will not register a new handler.
+
+  Args:
+    platform: the target platform.
+    handler: the function to register a custom call.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_callback_lock:
+    if xla_platform_name in _custom_callback_handler:
+      logger.debug(
+          'Custom call handler for %s is already registered. Will not register'
+          ' a new one',
+          xla_platform_name,
+      )
+      return
+    _custom_callback_handler[xla_platform_name] = handler
+    if xla_platform_name in _custom_callback:
+      for name, fn, api_version, traits in _custom_callback[xla_platform_name]:
+        handler(name, fn, xla_platform_name, api_version, traits)
+      del _custom_callback[xla_platform_name]
+
+
+class CustomTypeIdHandler(Protocol):
+
+  def __call__(self, name: str, capsule: Any) -> None:
+    ...
+
+
+_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {}
+_custom_type_id: dict[str, Any] = {}
+_custom_type_id_lock = threading.Lock()
+
+
+def register_custom_type_id(
+    type_name: str,
+    type_id: Any,
+    platform: str = 'cpu',
+) -> None:
+  """Register a custom type id for use with the FFI.
+
+  Args:
+    type_name: a unique name for the type.
+    type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``.
+    platform: the target platform.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_type_id_lock:
+    if xla_platform_name in _custom_type_id_handler:
+      _custom_type_id_handler[xla_platform_name](type_name, type_id)
+    else:
+      _custom_type_id.setdefault(xla_platform_name, []).append(
+          (type_name, type_id)
+      )
+
+
+def register_custom_type_id_handler(
+    platform: str, handler: CustomTypeIdHandler
+) -> None:
+  """Register a custom type id handler and use it to register existing type ids.
+
+  If a custom type id handler for the platform already exists, calling this
+  method is a no-op and it will not register a new handler.
+
+  Args:
+    platform: the target platform.
+    handler: the function to register a custom type id.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  # Take the type-id lock here: these tables are guarded by
+  # _custom_type_id_lock in register_custom_type_id, so guarding them with
+  # _custom_callback_lock would not exclude concurrent registrations.
+  with _custom_type_id_lock:
+    if xla_platform_name in _custom_type_id_handler:
+      logger.debug(
+          'Custom type id handler for %s is already registered. Will not '
+          'register a new one',
+          xla_platform_name,
+      )
+      return
+    _custom_type_id_handler[xla_platform_name] = handler
+    if xla_platform_name in _custom_type_id:
+      for name, capsule in _custom_type_id[xla_platform_name]:
+        handler(name, capsule)
+      del _custom_type_id[xla_platform_name]
+
+
+register_custom_call_partitioner = _xla.register_custom_call_partitioner
+encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback
+hlo_sharding_util = _xla.hlo_sharding_util
+register_custom_call_as_batch_partitionable = (
+    _xla.register_custom_call_as_batch_partitionable
+)
+
+
+class PaddingConfigDimension:
+  """Python representation of a xla.PaddingConfigDimension protobuf."""
+
+  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')
+
+  edge_padding_low: int
+  edge_padding_high: int
+  interior_padding: int
+
+  def __init__(self):
+    self.edge_padding_low = 0
+    self.edge_padding_high = 0
+    self.interior_padding = 0
+
+
+class PaddingConfig:
+  """Python representation of a xla.PaddingConfig protobuf."""
+
+  __slots__ = ('dimensions',)
+
+  def __init__(self):
+    self.dimensions = []
+
+
+def make_padding_config(
+    padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]]
+) -> PaddingConfig:
+  """Create PaddingConfig proto from list of triples of integers.
+
+  Args:
+    padding_config: either a PaddingConfig or a list of integer triples
+      (edge_padding_low, edge_padding_high, interior_padding) representing the
+      configuration of the padding operation.
+
+  Returns:
+    A `PaddingConfig` object.
+  """
+  if not isinstance(padding_config, PaddingConfig):
+    triples = padding_config
+    padding_config = PaddingConfig()
+    for lo, hi, interior in triples:
+      dimension = PaddingConfigDimension()
+      dimension.edge_padding_low = lo
+      dimension.edge_padding_high = hi
+      dimension.interior_padding = interior
+      padding_config.dimensions.append(dimension)
+  return padding_config
+
+
+class DotDimensionNumbers:
+  """Python representation of a xla.DotDimensionNumbers protobuf."""
+
+  __slots__ = (
+      'lhs_contracting_dimensions',
+      'rhs_contracting_dimensions',
+      'lhs_batch_dimensions',
+      'rhs_batch_dimensions',
+  )
+
+  def __init__(self):
+    self.lhs_contracting_dimensions = []
+    self.rhs_contracting_dimensions = []
+    self.lhs_batch_dimensions = []
+    self.rhs_batch_dimensions = []
+
+
+def make_dot_dimension_numbers(
+    dimension_numbers: Union[
+        DotDimensionNumbers,
+        tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]],
+    ]
+) -> DotDimensionNumbers:
+  """Builds a DotDimensionNumbers object from a specification.
+
+  Args:
+    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
+      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
+      integers representing the dimensions to treat as contracting dimensions
+      and batch dimensions on each input operand.
+
+  Returns:
+    A `DotDimensionNumbers` object.
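+
+  Example (an illustrative sketch): a plain matrix multiply contracts
+  dimension 1 of the lhs with dimension 0 of the rhs and has no batch
+  dimensions, so it can be specified as:
+
+    make_dot_dimension_numbers((([1], [0]), ([], [])))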
+ """ + if isinstance(dimension_numbers, (list, tuple)): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto + else: + return dimension_numbers + + +class ConvolutionDimensionNumbers: + """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" + + __slots__ = ( + 'input_batch_dimension', + 'input_feature_dimension', + 'input_spatial_dimensions', + 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', + 'kernel_spatial_dimensions', + 'output_batch_dimension', + 'output_feature_dimension', + 'output_spatial_dimensions', + ) + + def __init__(self): + self.input_batch_dimension = 0 + self.input_feature_dimension = 0 + self.input_spatial_dimensions = [] + self.kernel_input_feature_dimension = 0 + self.kernel_output_feature_dimension = 0 + self.kernel_spatial_dimensions = [] + self.output_batch_dimension = 0 + self.output_feature_dimension = 0 + self.output_spatial_dimensions = [] + + +def make_convolution_dimension_numbers( + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: + """Builds a ConvolutionDimensionNumbers object from a specification. + + Args: + dimension_numbers: optional, either a ConvolutionDimensionNumbers object or + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length + N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the + output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions in + rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers consistent + with the Conv operation with two spatial dimensions, one could use + ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension + numbers consistent with the TensorFlow Conv2D operation, one could use + ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution + dimension specification, window strides are associated with spatial + dimension character labels according to the order in which the labels + appear in the rhs_spec string, so that window_strides[0] is matched with + the dimension corresponding to the first character appearing in rhs_spec + that is not 'I' or 'O'. By default, use the same dimension numbering as + Conv and ConvWithGeneralPadding. + num_spatial_dimensions: the number of spatial dimensions. + + Returns: + A `ConvolutionDimensionNumbers` object. 
+ """ + if dimension_numbers is None: + nd = num_spatial_dimensions + dimension_numbers = ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + elif isinstance(dimension_numbers, tuple): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} + ) + dimension_numbers.input_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]), + ) + ) + dimension_numbers.output_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]), + ) + ) + return dimension_numbers + + +class PrecisionConfig: + """Python representation of a xla.PrecisionConfig protobuf.""" + + __slots__ = ('operand_precision',) + + Precision = _xla.PrecisionConfig_Precision + + def __init__(self): + self.operand_precision = [] + + +class ResultAccuracy: + """Python representation of a xla.ResultAccuracy protobuf.""" + + __slots__ = ('mode', 'atol', 'rtol', 'ulps') + + def __init__(self): + self.mode = _xla.ResultAccuracy_Mode.DEFAULT + self.atol = 0.0 + self.rtol = 0.0 + self.ulps = 0 + + +class GatherDimensionNumbers: + """Python representation of a xla.GatherDimensionNumbers protobuf.""" + + __slots__ = ( + 'offset_dims', + 'collapsed_slice_dims', + 'start_index_map', + 'index_vector_dim', + ) + + def __init__(self): + self.offset_dims = [] + self.collapsed_slice_dims = [] + self.start_index_map = [] + self.index_vector_dim = 0 + + +class ScatterDimensionNumbers: + """Python representation of a xla.ScatterDimensionNumbers protobuf.""" + + __slots__ = ( + 'update_window_dims', + 'inserted_window_dims', + 'scatter_dims_to_operand_dims', + 'index_vector_dim', + ) + + def __init__(self): + self.update_window_dims = [] + self.inserted_window_dims = [] + self.scatter_dims_to_operand_dims = [] + self.index_vector_dim = 0 + + +class ReplicaGroup: + """Python representation of a xla.ReplicaGroup protobuf.""" + + __slots__ = ('replica_ids',) + + def __init__(self): + self.replica_ids = [] + + +def _make_replica_group_proto(replica_group): + replica_group_proto = ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto + + +def make_replica_groups(replica_groups): + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups + ] 
+ return replica_groups_protos + + +Traceback = _xla.Traceback +Frame = _xla.Frame + + +@contextlib.contextmanager +def tracebacks(enabled=True): + """Context manager that enables or disables traceback collection.""" + saved = Traceback.enabled + Traceback.enabled = enabled + try: + yield + finally: + Traceback.enabled = saved + + +def heap_profile(client: Client) -> bytes: + """Returns a gzipped pprof protocol buffer containing a heap profile.""" + return gzip.compress(client.heap_profile()) + + +XlaRuntimeError = _xla.XlaRuntimeError + +# Perform one last garbage collection of deferred Python references. This is +# mostly to keep ASAN happy. +atexit.register(_xla.collect_garbage) + +weakref_lru_cache = _xla.weakref_lru_cache +array_result_handler = _xla.array_result_handler +batched_copy_array_to_devices_with_sharding = ( + _xla.batched_copy_array_to_devices_with_sharding +) +batched_device_put = _xla.batched_device_put +reorder_shards = _xla.reorder_shards +batched_block_until_ready = _xla.batched_block_until_ready +check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind +Layout = _xla.Layout +custom_call_targets = _xla.custom_call_targets +ArrayCopySemantics = _xla.ArrayCopySemantics diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi new file mode 100644 index 000000000000..234af8f7b87d --- /dev/null +++ b/jaxlib/xla/xla_client.pyi @@ -0,0 +1,322 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import annotations + +from collections.abc import Callable, Mapping, Sequence +import enum +from typing import Any, Union + +import numpy + +from jaxlib import xla_extension as _xla +from jaxlib.xla_extension import ArrayImpl as ArrayImpl +from jaxlib.xla_extension import AutotuneCacheMode as AutotuneCacheMode +from jaxlib.xla_extension import Client as Client +from jaxlib.xla_extension import CompileOptions as CompileOptions +from jaxlib.xla_extension import Device as Device +from jaxlib.xla_extension import DeviceAssignment as DeviceAssignment +from jaxlib.xla_extension import DeviceList as DeviceList +from jaxlib.xla_extension import DeviceTopology as DeviceTopology +from jaxlib.xla_extension import DistributedRuntimeClient as DistributedRuntimeClient +from jaxlib.xla_extension import FftType as FftType +from jaxlib.xla_extension import Frame as Frame +from jaxlib.xla_extension import GSPMDSharding as GSPMDSharding +from jaxlib.xla_extension import HloSharding as HloSharding +from jaxlib.xla_extension import HostBufferSemantics as HostBufferSemantics +from jaxlib.xla_extension import ifrt_programs as ifrt_programs +from jaxlib.xla_extension import Layout as Layout +from jaxlib.xla_extension import LoadedExecutable as LoadedExecutable +from jaxlib.xla_extension import Memory as Memory +from jaxlib.xla_extension import NamedSharding as NamedSharding +from jaxlib.xla_extension import ops as ops +from jaxlib.xla_extension import OpSharding as OpSharding +from jaxlib.xla_extension import PjRtLayout as PjRtLayout +from jaxlib.xla_extension import PmapSharding as PmapSharding +from jaxlib.xla_extension import PrimitiveType as PrimitiveType +from jaxlib.xla_extension import ArrayCopySemantics as ArrayCopySemantics +from jaxlib.xla_extension import profiler as profiler +from jaxlib.xla_extension import Shape as Shape +from jaxlib.xla_extension import Sharding as Sharding +from jaxlib.xla_extension import SingleDeviceSharding as SingleDeviceSharding +from jaxlib.xla_extension import Traceback as Traceback +from jaxlib.xla_extension import XlaBuilder as XlaBuilder +from jaxlib.xla_extension import XlaComputation as XlaComputation +from jaxlib.xla_extension import XlaOp as XlaOp + +_version: int + +mlir_api_version: int + +bfloat16: type[numpy.generic] +# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn: type[numpy.generic] +# float8_e3m4: type[numpy.generic] +# float8_e4m3: type[numpy.generic] +# float8_e8m0fnu: type[numpy.generic] +float8_e4m3fn: type[numpy.generic] +float8_e4m3b11fnuz: type[numpy.generic] +float8_e4m3fnuz: type[numpy.generic] +float8_e5m2: type[numpy.generic] +float8_e5m2fnuz: type[numpy.generic] +XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + +def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType: + ... + +def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ... + +def heap_profile(client: Client) -> bytes: + ... + +XlaRuntimeError = _xla.XlaRuntimeError + +def make_cpu_client( + asynchronous: bool = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: _xla.CpuCollectives | None = ..., + num_devices: int | None = ..., +) -> Client: + ... 
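+
+# Illustrative usage sketch (for exposition only, not part of the stub
+# surface): a CPU client is typically constructed with the defaults, e.g.
+#
+#   client = make_cpu_client()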
+ +def make_gpu_client( + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + platform_name: str | None = ..., + allowed_devices: set[int] | None = ..., + mock: bool | None = ..., + mock_gpu_topology: str | None = ..., +) -> Client: + ... + +def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None) -> Client: + ... + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str | None = None, **kwargs +) -> DeviceTopology: + ... + +def make_c_api_device_topology(c_api: Any, topology_name: str = '', **kwargs) -> DeviceTopology: + ... + +def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: + ... + +def make_tpu_client( + library_path: str | None, options: _NameValueMapping | None = None +) -> Client: + ... + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: DistributedRuntimeClient | None = None, +) -> Client: + ... + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + ... + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + ... + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + ... + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + ... + +def initialize_pjrt_plugin(plugin_name: str) -> None: + ... + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + ... + +class OpMetadata: + + def __init__( + self, + op_type: str | None = ..., + op_name: str | None = ..., + source_file: str | None = ..., + source_line: int | None = ..., + ): + ... + op_type: str | None + op_name: str | None + source_file: str | None + source_line: int | None + +class PaddingConfigDimension: + edge_padding_low: int + edge_padding_high: int + interior_padding: int + +class PaddingConfig: + dimensions: list[PaddingConfigDimension] + +def make_padding_config( + padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]], +) -> PaddingConfig: + ... + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + +class DotDimensionNumbers: + lhs_contracting_dimensions: list[int] + rhs_contracting_dimensions: list[int] + lhs_batch_dimensions: list[int] + rhs_batch_dimensions: list[int] + +def make_dot_dimension_numbers( + dimension_numbers: Union[ + DotDimensionNumbers, + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], + ], +) -> DotDimensionNumbers: + ... + +class ConvolutionDimensionNumbers: + input_batch_dimension: int + input_feature_dimension: int + input_spatial_dimensions: list[int] + kernel_input_feature_dimension: int + kernel_output_feature_dimension: int + kernel_spatial_dimensions: list[int] + output_batch_dimension: int + output_feature_dimension: int + output_spatial_dimensions: list[int] + +def make_convolution_dimension_numbers( + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: + ... 
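+
+# Illustrative sketch for make_padding_config above: each triple is
+# (edge_padding_low, edge_padding_high, interior_padding) for one dimension,
+# so, for example,
+#
+#   make_padding_config([(0, 0, 0), (1, 1, 0)])
+#
+# pads only the second dimension by one element on each edge.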
+ +class PrecisionConfig: + Precision = _xla.PrecisionConfig_Precision + operand_precision: list[_xla.PrecisionConfig_Precision] + +class ResultAccuracy: + mode: _xla.ResultAccuracy_Mode + atol: float + rtol: float + ulps: int + +class GatherDimensionNumbers: + offset_dims: list[int] + collapsed_slice_dims: list[int] + start_index_map: list[int] + index_vector_dim: int + operand_batching_dims: list[int] + start_indices_batching_dims: list[int] + +class ScatterDimensionNumbers: + update_window_dims: list[int] + inserted_window_dims: list[int] + scatter_dims_to_operand_dims: list[int] + index_vector_dim: int + input_batching_dims: list[int] + scatter_indices_batching_dims: list[int] + +class ReplicaGroup: + replica_ids: list[int] + +def make_replica_groups( + replica_groups: Sequence[Sequence[int]] | None, +) -> list[ReplicaGroup]: + ... + +def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...) -> _xla.WeakrefLRUCache: + ... + +def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[list[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... + +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: list[Device], + committed: bool = ..., + force_copy: bool = ..., + host_buffer_semantics: Any = ..., +) -> ArrayImpl: ... + +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... + +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + +def check_and_canonicalize_memory_kind( + memory_kind: str | None, device_list: DeviceList +) -> str | None: ... + +def array_result_handler( + aval: Any, + sharding: Any, + committed: bool, + _skip_checks: bool = ...) -> Callable: + ... + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + COMMAND_BUFFER_COMPATIBLE = 1 + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = ..., + api_version: int = ..., + traits: CustomCallTargetTraits = ..., +) -> None: ... + +def register_custom_call_handler( + xla_platform_name: str, handler: Any +) -> None: ... + +def custom_call_targets(platform: str) -> dict[str, Any]: ... + +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = ..., +) -> None: ... + +def register_custom_type_id_handler(platform: str, handler: Any) -> None: ... + +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) diff --git a/jaxlib/xla/xla_client_backend_independent_test.py b/jaxlib/xla/xla_client_backend_independent_test.py new file mode 100644 index 000000000000..ee1c33feb40c --- /dev/null +++ b/jaxlib/xla/xla_client_backend_independent_test.py @@ -0,0 +1,195 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Backend-independent tests for the Python XLA client.""" + +import unittest + +from absl.testing import absltest +import numpy as np + +from jax.jaxlib.xla import xla_client + +# pylint: disable=g-import-not-at-top +try: + import portpicker +except ImportError: + portpicker = None +# pylint: enable=g-import-not-at-top + +ops = xla_client.ops + + +class ShapeTest(absltest.TestCase): + + def testInvalidShapes(self): + with self.assertRaisesRegex(xla_client.XlaRuntimeError, "invalid shape"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field contains 1 element.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field has out-of-bounds value.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], + [1, -1]) + + +class ComputationPrinting(absltest.TestCase): + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_text() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.as_hlo_dot_graph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + def testHloModuleToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_module().to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testHloModuleFromText(self): + hlo_module_text = """HloModule test + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + ENTRY entry { + p0 = f32[2,3] parameter(0) + start = f32[2,3] all-reduce-start(p0), to_apply=add + ROOT done = f32[2,3] all-reduce-done(start) + }""" + hlo_module = xla_client._xla.hlo_module_from_text(hlo_module_text) + hlo_text = hlo_module.to_string() + self.assertTrue(hlo_text.startswith("HloModule test")) + + def testHloModuleToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( + computation.as_hlo_module()) + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationHashTest(absltest.TestCase): + + def testHash(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation0 = builder0.build() + + builder1 = xla_client.XlaBuilder("computation1") + p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder1, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation1 = builder1.build() + + self.assertEqual(computation0.hash(), computation1.hash()) + + +class AliasTest(absltest.TestCase): + + def testSetUpAlias(self): + c = xla_client.XlaBuilder(self.id()) + p1 = ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 1.0, 
np.float32)).with_major_to_minor_layout_if_absent()) + p2 = ops.Parameter( + c, 1, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + out = ops.Add(p1, p2) + c.setup_alias([], 0, []) + c.build(out) + + +class ProfilerTest(absltest.TestCase): + + def testTraceMe(self): + # TODO(phawkins): These tests just check that the TraceMe context manager + # acts like a context manager and doesn't explode. Ideally we'd check that + # the profiler saw the traceme too. + with xla_client.profiler.TraceMe("test1"): + pass + with xla_client.profiler.TraceMe("test2", foo=123): + pass + with self.assertRaises(ValueError): + with xla_client.profiler.TraceMe("test3"): + raise ValueError("test") + + @unittest.skipIf(portpicker is None, "Test requires portpicker") + def testStartServer(self): + port = portpicker.pick_unused_port() + server = xla_client.profiler.start_server(port) + del server + + +class HloModuleGroupTest(absltest.TestCase): + + def testHloModuleGroup(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + root = ops.Mul(p0, p1) + computation0 = builder0.build(root) + + m = computation0.get_hlo_module() + mg_name = "test_module_group" + mg = xla_client._xla.HloModuleGroup(mg_name, [m]) + self.assertEqual(mg.name, mg_name) + + modules = mg.to_modules() + self.assertLen(modules, 1) + self.assertEqual(m.to_string(), modules[0].to_string()) + + +class RunHloPassTest(absltest.TestCase): + + def testHloDCE(self): + b = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(b, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + root = ops.Mul(p0, p1) + + # Dead instructions + p2 = ops.Parameter(b, 2, xla_client.shape_from_pyval(np.float32(0))) + ops.Add(p2, p2) + + hlo_module = b.build(root).get_hlo_module() + self.assertTrue(xla_client._xla.HloDCE().run(hlo_module)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py new file mode 100644 index 000000000000..e228905637cb --- /dev/null +++ b/jaxlib/xla/xla_client_test.py @@ -0,0 +1,3714 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Backend-dependent tests for the Python XLA client."""

+import collections
+import functools
+import itertools
+import re
+import threading
+import traceback
+from typing import Sequence
+import unittest
+
+from absl import flags
+from absl import logging
+from absl.testing import absltest
+from absl.testing import parameterized
+import ml_dtypes
+import numpy as np
+
+from jax.jaxlib.xla import xla_client
+import jax
+import jax._src.test_util
+
+# pylint: disable=g-import-not-at-top
+try:
+  from jax.jaxlib.xla import custom_calls_testlib
+except ImportError:
+  custom_calls_testlib = None
+
+xla_client._xla.jax_jit.set_thread_local_state_initialization_callback(
+    lambda: None
+)
+
+bfloat16 = ml_dtypes.bfloat16
+float4_e2m1fn = ml_dtypes.float4_e2m1fn
+float8_e3m4 = ml_dtypes.float8_e3m4
+float8_e4m3 = ml_dtypes.float8_e4m3
+float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
+float8_e4m3fn = ml_dtypes.float8_e4m3fn
+float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
+float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
+float8_e5m2 = ml_dtypes.float8_e5m2
+float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz
+ops = xla_client.ops
+xla_computation_to_mlir_module = (
+    xla_client._xla.mlir.xla_computation_to_mlir_module)
+
+
+def execute_with_python_values(executable, arguments, backend):  # pylint: disable=invalid-name
+  """Execute on one replica with Python values as arguments and output."""
+
+  def put(arg):  # pylint: disable=invalid-name
+    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
+
+  arguments = [put(arg) for arg in arguments]
+  outputs = executable.execute(arguments)
+  return [np.asarray(x) for x in outputs]
+
+
+# pylint: disable=invalid-name
+def jax_array_convert_to_array(self, dtype=None, copy=None):
+  del copy
+  out, _ = self._single_device_array_to_np_array_did_copy()
+  if dtype is not None:
+    out = out.astype(dtype)
+  return out
+
+
+def jax_array_device(self):
+  return self._sharding._device
+
+
+def jax_array_copy_to_host_async(self):
+  self._copy_single_device_array_to_host_async()
+
+
+Array = xla_client.ArrayImpl
+Array.__array__ = jax_array_convert_to_array
+Array.copy_to_host_async = jax_array_copy_to_host_async
+Array.device = jax_array_device
+xla_client.SingleDeviceSharding.device_set = property(
+    lambda self: {self._device}
+)
+# pylint: enable=invalid-name
+
+
+FLAGS = flags.FLAGS
+
+# We choose to ignore pylint's complaints about complex comprehensions, which we
+# use widely for parameterizing tests.
+# pylint: disable=g-complex-comprehension
+
+_CUSTOM_CALLS_REGISTERED = False
+
+
+# XLA's alignment is 16 bytes at the moment, but it should match what Eigen
+# supports, and that can go up to 128 bytes on hardware with HVX.
+_XLA_CPU_MAX_ALIGNMENT = 128
+
+
+# Minimum possible alignment for XLA.
+_XLA_CPU_MIN_ALIGNMENT = 16
+
+
+# Return a copy of `x` with the given alignment. Does nothing if `x` is already
+# aligned. We do this manually because numpy doesn't support custom alignment
+# values.
+def _Aligned(x, alignment=_XLA_CPU_MAX_ALIGNMENT):
+  if (x.ctypes.data % alignment) == 0:
+    return x
+
+  # Create a temporary buffer with extra space for alignment.
+  assert alignment % x.itemsize == 0
+  extra = alignment // x.itemsize
+  buf = np.empty(x.size + extra, dtype=x.dtype)
+
+  # Create a view of the temporary buffer at an offset chosen so that the
+  # resulting view is aligned.
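+  # (-addr % alignment) is the number of bytes from addr up to the next
+  # aligned address; dividing by the itemsize turns that byte count into an
+  # element offset into the temporary buffer.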
+  offset = (-buf.ctypes.data % alignment) // x.itemsize
+  result = buf[offset : offset + x.size].reshape(x.shape)
+
+  # Copy the data to the result buffer and return it.
+  np.copyto(result, x)
+  return result
+
+
+# Return an unaligned copy of `x`. The result buffer's memory address is
+# guaranteed to not be aligned to `alignment`. This function is useful for
+# testing failures.
+def _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT):
+  if (x.ctypes.data % alignment) != 0:
+    return x
+
+  # Create a temporary buffer with extra space.
+  assert (x.itemsize % alignment) != 0
+  offset = 1
+  buf = np.empty(x.size + offset, dtype=x.dtype)
+
+  if (buf.ctypes.data % alignment) != 0:
+    # If the temporary buffer is already unaligned, use a view of its first
+    # x.size elements; the view keeps buf's unaligned start address. (Using
+    # buf directly would have the wrong size for np.copyto below.)
+    result = buf[: x.size].reshape(x.shape)
+  else:
+    # Otherwise, create a view of the temporary buffer with an offset.
+    result = buf[offset : offset + x.size].reshape(x.shape)
+    assert (result.ctypes.data % alignment) != 0
+
+  # Copy the data to the result buffer and return it.
+  np.copyto(result, x)
+  return result
+
+
+def TestFactory(xla_backend,
+                cloud_tpu=False,
+                tfrt_tpu=False,
+                pjrt_c_api=False,
+                pathways=False,
+                pathways_ifrt=False):
+  tests = []
+
+  int_dtypes = [np.int32, np.int64, np.uint32, np.uint64]
+  # TODO(phawkins): test np.float16, where supported.
+  float_dtypes = [bfloat16, np.float32, np.float64]
+  complex_dtypes = [np.complex64, np.complex128]
+  standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
+  # TODO(zhangqiaorjc): test fp8 types when XLA support is complete.
+  # standard_dtypes is only used for BufferProtocolTest so we only test fp8
+  # round trip tests.
+  fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2]
+  standard_dtypes += fp8_dtypes
+  # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
+  # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu]
+  dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes
+
+  class ComputationTest(parameterized.TestCase):
+    """Base class for running an XLA Computation through the local client."""
+
+    def setUp(self):
+      super(ComputationTest, self).setUp()
+      self.backend = xla_backend()
+
+      global _CUSTOM_CALLS_REGISTERED
+      if self.backend.platform == "cpu" and not _CUSTOM_CALLS_REGISTERED:
+        for name, fn in custom_calls_testlib.registrations().items():
+          xla_client.register_custom_call_target(
+              name, fn, platform="cpu", api_version=1
+          )
+        for name, val in custom_calls_testlib.type_ids().items():
+          xla_client.register_custom_type_id(name, val, platform="cpu")
+        _CUSTOM_CALLS_REGISTERED = True
+
+    def _NewComputation(self, name=None):
+      if name is None:
+        name = self.id()
+      return xla_client.XlaBuilder(name)
+
+    def _Execute(self, c, arguments):
+      compiled_c = self.backend.compile(
+          xla_computation_to_mlir_module(c.build()))
+      return execute_with_python_values(
+          compiled_c, arguments, backend=self.backend)
+
+    def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
+      assert expected is not None
+      results = self._Execute(c, arguments)
+      self.assertLen(results, len(expected))
+      for result, e in zip(results, expected):
+        # Numpy's comparison methods are a bit too lenient by treating inputs
+        # as "array-like", meaning that scalar 4 will be happily compared
+        # equal to [[4]]. We'd like to be more strict, so we assert shapes as
+        # well.
+ self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) + assert_func(result, e) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, + expected) + + def _ExecuteAndCompareClose(self, + c, + arguments=(), + expected=None, + rtol=1e-4, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) + + def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + + def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool_ dtype.""" + return np.array(*args, dtype=np.bool_, **kwargs) + + class ComputationPrinting(absltest.TestCase): + + def setUp(self): + super(ComputationPrinting, self).setUp() + self.backend = xla_backend() + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter( + builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCompiledHloModuleToHloText(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + hlo_modules = executable.hlo_modules() + self.assertLen(hlo_modules, 1) + hlo_text = hlo_modules[0].to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + self.assertIn("fusion", hlo_text) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCompiledHloModuleAsSerializedProto(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + hlo_modules = executable.hlo_modules() + self.assertLen(hlo_modules, 1) + hlo_text = hlo_modules[0].to_string() + proto = hlo_modules[0].as_serialized_hlo_module_proto() + hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module() + hlo_text_roundtrip = hlo_module_roundtrip.to_string() + self.assertEqual(hlo_text, hlo_text_roundtrip) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testStableComputationSerialization(self): + # Ideally we would test identical computations produced in different + # processes. For now we have this limited smoke test. 
+ computation = self.ExampleComputation() + ref = computation.as_serialized_hlo_module_proto() + for _ in range(10): + self.assertEqual(computation.as_serialized_hlo_module_proto(), ref) + + # TODO(b/261771737): some version of this should work with pjrt_c_api=True + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testFlopEstimate(self): + computation = self.ExampleComputation() + properties = xla_client._xla.hlo_module_cost_analysis( + self.backend, computation.as_hlo_module()) + self.assertEqual(properties["flops"], 8.0) + + def testFingerprint(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + fingerprint = executable.fingerprint + if ( + self.backend.platform == "tpu" + or self.backend.platform == "gpu" + or self.backend.platform == "cpu" + ) and not (cloud_tpu or pathways or pathways_ifrt): + logging.info("fingerprint: %s", fingerprint) + self.assertNotEmpty(fingerprint) + else: + self.assertIsNone(fingerprint) + + tests.append(ComputationPrinting) + + class ComputationsWithConstantsTest(ComputationTest): + """Tests focusing on Constant ops.""" + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testConstantScalarSum(self, dtype): + c = self._NewComputation() + ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14))) + self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorMul(self, dtype): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarDiv(self, dtype): + c = self._NewComputation() + ops.Div( + ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)), + ops.Constant(c, dtype(2.0))) + self._ExecuteAndCompareClose( + c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarPow(self, dtype): + c = self._NewComputation() + ops.Pow( + ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)), + ops.Constant(c, dtype(2.))) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) + + def testIota(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + self._ExecuteAndCompareExact( + c, expected=[np.arange(10, dtype=np.float32)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testBroadcastedIota(self, dtype): + c = self._NewComputation() + shape = xla_client.Shape.array_shape( + xla_client.dtype_to_etype(dtype), (2, 3)) + ops.Iota(c, shape, 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype) + self._ExecuteAndCompareExact(c, expected=[expected]) + + def testBooleanAnd(self): + c = self._NewComputation() + ops.And( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, 
False]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) + + def testBooleanOr(self): + c = self._NewComputation() + ops.Or( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) + + def testBooleanXor(self): + c = self._NewComputation() + ops.Xor( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2D(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)), + ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype))) + self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) + + def testShiftLeft(self): + c = self._NewComputation() + ops.ShiftLeft( + ops.Constant(c, NumpyArrayS32([3])), + ops.Constant(c, NumpyArrayS32([2]))) + self._ExecuteAndCompareClose(c, expected=[[12]]) + + def testShiftRightArithmetic(self): + c = self._NewComputation() + ops.ShiftRightArithmetic( + ops.Constant(c, NumpyArrayS32([-2])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[-1]]) + + def testShiftRightLogical(self): + c = self._NewComputation() + ops.ShiftRightLogical( + ops.Constant(c, NumpyArrayS32([-1])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim0(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim1(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. 
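+ # For example, with broadcast_dimensions=(1,) the 1D operand [10, 20, 30]
+ # lines up with each row, so the first row becomes
+ # [1, 2, 3] + [10, 20, 30] = [11, 22, 33].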
+ c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantAxpy(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Mul( + ops.Constant(c, dtype(2)), + ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))), + ops.Constant(c, np.array([100, -100, 200, -200], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3) + + def testCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + ops.CustomCallWithLayout( + c, + b"subtract_f32", + operands=[ + ops.Constant(c, np.float32(1.25)), + ops.Constant(c, np.float32(0.5)) + ], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), ()), + operand_shapes_with_layout=[ + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + ], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_TYPED_FFI) + self._ExecuteAndCompareClose(c, expected=[0.75]) + + def testCustomCallWithUnifiedApiUnknownTarget(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"not_existing", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING_UNIFIED, + ) + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, expected_regex="NOT_FOUND" + ): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiUnknownTarget(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"not_existing", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + with self.assertRaises(xla_client.XlaRuntimeError): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiAlwaysFail(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"always_fail", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + + with self.assertRaisesRegex( + Exception, expected_regex="Failed intentionally" + ): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiAlwaysSucceed(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"always_succeed", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + + self._Execute(c, arguments=()) + + def 
testCustomCallTypedFfiSubtract(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"subtract_f32_cst", + operands=[ops.Constant(c, np.float32(1.25))], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[ + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + ], + opaque=b"{cst = 3.0 : f32}", + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + self._ExecuteAndCompareClose(c, expected=[-1.75]) + + def testStatefulCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + ops.CustomCallWithLayout( + c, + b"stateful", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.int32), (), ()), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_TYPED_FFI) + self._ExecuteAndCompareClose(c, expected=[42]) + + def testCustomCallLookup(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + if xla_client._version < 241: + self.skipTest("Test requires jaxlib version 241") + + self.assertTrue(_CUSTOM_CALLS_REGISTERED) + xla_client.make_cpu_client() + self.assertContainsSubset( + list(custom_calls_testlib.registrations().keys()), + xla_client.custom_call_targets("Host").keys(), + ) + + tests.append(ComputationsWithConstantsTest) + + class ComputationFromProtoTest(absltest.TestCase): + """Test computation execution from HLO proto.""" + + def setUp(self): + super(ComputationFromProtoTest, self).setUp() + self.backend = xla_backend() + + def testExecuteFromProto(self): + # Build the HLO proto + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + serialized_proto = b.build().as_serialized_hlo_module_proto() + + # Load and execute the proto + c = xla_client.XlaComputation(serialized_proto) + m = xla_computation_to_mlir_module(c) + ans, = execute_with_python_values( + self.backend.compile(m), (), backend=self.backend) + np.testing.assert_equal(ans, np.int32(3)) + + tests.append(ComputationFromProtoTest) + + class ParametersTest(ComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testScalarTimesVector(self, dtype): + c = self._NewComputation() + arg0 = np.array(3, dtype=dtype) + if np.issubdtype(dtype, np.unsignedinteger): + arg1 = np.array([10, 15, 2, 7], dtype=dtype) + else: + arg1 = np.array([10, 15, -2, 7], dtype=dtype) + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + ops.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, arguments=[arg0, arg1], expected=[arg0 * arg1]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testScalarMinusVectorExplicitNumbering(self, dtype): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. 
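+ # Parameters are bound by their explicit index, not by the order in which
+ # ops.Parameter is called, so building p1 before p0 must still compute
+ # arg1 - arg0.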
+ c = self._NewComputation()
+ arg0 = np.array(2.0, dtype=dtype)
+ arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype)
+ p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
+ p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
+ ops.Sub(p1, p0)
+ self._ExecuteAndCompareClose(
+ c, arguments=[arg0, arg1], expected=[arg1 - arg0])
+
+ tests.append(ParametersTest)
+
+ class LayoutsTest(ComputationTest):
+ """Tests related to getting and setting on-device memory layouts."""
+
+ def _minor_to_major(self, layout: xla_client.PjRtLayout):  # pylint: disable=invalid-name
+ m2m_str = re.search("{([0-9,]*)", str(layout)).group(1)
+ if not m2m_str:
+ return ()
+ return tuple(int(x) for x in m2m_str.split(","))
+
+ @unittest.skipIf(pathways, "not implemented")
+ def testGetArgumentLayouts(self):
+ # Create computation with a few parameters.
+ c = self._NewComputation()
+ param_count = 0
+
+ def MakeArg(shape, dtype):
+ nonlocal param_count
+ shape = xla_client.Shape.array_shape(np.dtype(dtype), shape)
+ param = ops.Parameter(c, param_count, shape)
+ param_count += 1
+ return param
+
+ p0 = MakeArg((2, 3, 4), np.float32)
+ MakeArg((3, 2), np.int32)
+ MakeArg((), np.float64)
+
+ ops.Add(p0, ops.Constant(c, np.ones((2, 3, 4), np.float32)))
+ executable = self.backend.compile(
+ xla_computation_to_mlir_module(c.build()))
+
+ # Test that compiled executable returns plausible layouts.
+ layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts()
+ self.assertLen(layouts, 3)
+ self.assertLen(self._minor_to_major(layouts[0]), 3)
+ self.assertLen(self._minor_to_major(layouts[1]), 2)
+ self.assertEmpty(self._minor_to_major(layouts[2]))
+
+ @unittest.skipIf(pathways, "not implemented")
+ def testGetArgumentLayoutsTupled(self):
+ # Generated with:
+ # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)),
+ # np.int32(42),
+ # np.ones(10))
+ module_str = """
+module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
+ mhlo.num_replicas = 1 : i32} {
+ func.func public @main(
+ %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"},
+ %arg1: tensor<i32> {mhlo.sharding = "{replicated}"},
+ %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"})
+ -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"},
+ tensor<i32> {jax.result_info = "[1]"},
+ tensor<10xf32> {jax.result_info = "[2]"}) {
+ return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor<i32>, tensor<10xf32>
+ }
+}
+"""
+ options = xla_client.CompileOptions()
+ # 'parameter_is_tupled_arguments' causes MLIR untupled arguments to get
+ # turned into HLO tupled arguments.
+ options.parameter_is_tupled_arguments = True
+ executable = self.backend.compile(module_str, compile_options=options)
+
+ # Test that compiled executable returns plausible layouts.
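+ # Even though 'parameter_is_tupled_arguments' packs the three MLIR
+ # arguments into a single tupled HLO parameter, get_parameter_layouts()
+ # is expected to report one layout per leaf argument.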
+ layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts()
+ self.assertLen(layouts, 3)
+ self.assertLen(self._minor_to_major(layouts[0]), 3)
+ self.assertEmpty(self._minor_to_major(layouts[1]))
+ self.assertLen(self._minor_to_major(layouts[2]), 1)
+
+ @unittest.skipIf(pathways, "not implemented")
+ def testGetOutputLayouts(self):
+ # Generated with jax.jit(lambda: (np.ones((1024, 128)), np.int32(42),
+ # np.ones(10)))()
+ module_str = """
+module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
+ mhlo.num_replicas = 1 : i32} {
+ func.func public @main() -> (tensor<1024x128xf32> {jax.result_info = "[0]"},
+ tensor<i32> {jax.result_info = "[1]"},
+ tensor<10xf32> {jax.result_info = "[2]"}) {
+ %0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x128xf32>
+ %1 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32>
+ %2 = stablehlo.constant dense<42> : tensor<i32>
+ return %0, %2, %1 : tensor<1024x128xf32>, tensor<i32>, tensor<10xf32>
+ }
+}
+"""
+ executable = self.backend.compile(module_str)
+
+ # Test that compiled executable returns plausible layouts.
+ layouts: Sequence[xla_client.Layout] = executable.get_output_layouts()
+ self.assertLen(layouts, 3)
+ self.assertLen(self._minor_to_major(layouts[0]), 2)
+ self.assertEmpty(self._minor_to_major(layouts[1]))
+ self.assertLen(self._minor_to_major(layouts[2]), 1)
+
+ @unittest.skipIf(pathways, "not implemented")
+ def testSetArgumentLayouts(self):
+ # TODO(b/309682374): implement on CPU and GPU
+ if self.backend.platform != "tpu":
+ raise self.skipTest("mhlo.layout_mode only implemented on TPU")
+
+ # Hand-edited version of:
+ # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)),
+ # np.int32(42),
+ # np.ones(10))
+ module_str = """
+module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
+ mhlo.num_replicas = 1 : i32} {
+ func.func public @main(
+ %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}",
+ mhlo.layout_mode = "{0,1,2}"},
+ %arg1: tensor<i32> {mhlo.sharding = "{replicated}",
+ mhlo.layout_mode = "{}"},
+ %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}",
+ mhlo.layout_mode = "{0}"})
+ -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"},
+ tensor<i32> {jax.result_info = "[1]"},
+ tensor<10xf32> {jax.result_info = "[2]"}) {
+ return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor<i32>, tensor<10xf32>
+ }
+}
+ """
+ executable = self.backend.compile(module_str)
+
+ # Check input layouts.
+ input_layouts = executable.get_parameter_layouts()
+ self.assertLen(input_layouts, 3)
+ self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1, 2))
+ self.assertEqual(self._minor_to_major(input_layouts[1]), ())
+ self.assertEqual(self._minor_to_major(input_layouts[2]), (0,))
+
+ # Compile a version with default arg0 layout so we can make sure we
+ # actually set it above.
+ default_executable = self.backend.compile(
+ module_str.replace('"{0,1,2}"', '"default"')
+ )
+ self.assertNotEqual(
+ self._minor_to_major(input_layouts[0]),
+ self._minor_to_major(default_executable.get_parameter_layouts()[0]),
+ )
+
+ @unittest.skipIf(pathways or pathways_ifrt, "not implemented")
+ def testSetArgumentLayoutsLegacy(self):
+ """Tests setting the arg layouts with compile_options (deprecated).
+
+ New code should use the mhlo.layout_mode string attr on parameters.
+ """
+ # Create computation with custom input layouts.
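+ # Layouts below are written in minor-to-major order: e.g. (1, 2, 0) means
+ # dimension 1 varies fastest in memory and dimension 0 varies slowest.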
+ c = self._NewComputation()
+ param_count = 0
+
+ def MakeArg(shape, dtype, layout):
+ nonlocal param_count
+ arr = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
+ param = ops.Parameter(c, param_count,
+ xla_client.shape_from_pyval(arr, layout))
+ param_count += 1
+ shape = xla_client.Shape.array_shape(np.dtype(dtype), shape, layout)
+ return arr, param, shape
+
+ arg0, p0, shape0 = MakeArg((2, 3, 4), np.float32, (1, 2, 0))
+ arg1, p1, shape1 = MakeArg((3, 2), np.int32, (0, 1))
+ arg2, p2, shape2 = MakeArg((), np.float64, ())
+
+ ops.Tuple(c, [
+ ops.Add(p0, ops.Constant(c, np.ones(arg0.shape, arg0.dtype))),
+ ops.Add(p1, ops.Constant(c, np.ones(arg1.shape, arg1.dtype))),
+ ops.Add(p2, ops.Constant(c, np.ones(arg2.shape, arg2.dtype))),
+ ])
+
+ # We also need to set the input layouts in the compile options.
+ options = xla_client.CompileOptions()
+ options.argument_layouts = [shape0, shape1, shape2]
+ executable = self.backend.compile(
+ xla_computation_to_mlir_module(c.build()), compile_options=options)
+
+ # Test that compiled executable has expected layouts.
+ expected_layouts: Sequence[xla_client.Shape] = [shape0, shape1, shape2]
+ actual_layouts: Sequence[xla_client.Layout] = (
+ executable.get_parameter_layouts())
+ self.assertEqual(len(actual_layouts), len(expected_layouts))
+ for actual, expected in zip(actual_layouts, expected_layouts):
+ self.assertEqual(
+ self._minor_to_major(actual),
+ expected.layout().minor_to_major(),
+ )
+
+ @unittest.skipIf(pathways, "not implemented")
+ def testSetOutputLayouts(self):
+ # TODO(b/309682374): implement on CPU and GPU
+ if self.backend.platform != "tpu":
+ raise self.skipTest("mhlo.layout_mode only implemented on TPU")
+
+ # Hand-edited version of:
+ # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)),
+ # np.int32(42),
+ # np.ones(10))
+ module_str = """
+module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
+ mhlo.num_replicas = 1 : i32} {
+ func.func public @main(
+ %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"},
+ %arg1: tensor<i32> {mhlo.sharding = "{replicated}"},
+ %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"})
+ -> (tensor<1024x8x128xf32> {jax.result_info = "[0]",
+ mhlo.layout_mode = "{0,1,2}"},
+ tensor<i32> {jax.result_info = "[1]",
+ mhlo.layout_mode = "{}"},
+ tensor<10xf32> {jax.result_info = "[2]",
+ mhlo.layout_mode = "{0}"}) {
+ return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor<i32>, tensor<10xf32>
+ }
+}
+ """
+ executable = self.backend.compile(module_str)
+
+ # Check output layouts.
+ output_layouts = executable.get_output_layouts()
+ self.assertLen(output_layouts, 3)
+ self.assertEqual(self._minor_to_major(output_layouts[0]), (0, 1, 2))
+ self.assertEqual(self._minor_to_major(output_layouts[1]), ())
+ self.assertEqual(self._minor_to_major(output_layouts[2]), (0,))
+
+ # Compile a version with default first output layout so we can make sure
+ # we actually set it above.
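+ # (Replacing the layout mode with 'default' lets the compiler pick its
+ # preferred layout, so the two executables should disagree on the first
+ # output's layout.)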
+ default_executable = self.backend.compile(
+ module_str.replace('"{0,1,2}"', '"default"')
+ )
+ self.assertNotEqual(
+ self._minor_to_major(output_layouts[0]),
+ self._minor_to_major(default_executable.get_output_layouts()[0]),
+ )
+
+ @unittest.skipIf(pathways, "not implemented")
+ def testSetLayoutsSharded(self):
+ # TODO(b/309682374): implement on CPU and GPU
+ if self.backend.platform != "tpu":
+ raise self.skipTest("mhlo.layout_mode only implemented on TPU")
+
+ # Hand-edited version of:
+ # sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
+ # x = jax.device_put(np.ones((1024, 128)), sharding.reshape(4, 2))
+ # jax.jit(lambda x, y: x + y, out_shardings=sharding)(x, 1.)
+ #
+ # This also lightly tests mixed default + user-specified input layouts.
+ module_str = """
+module @jit__lambda_ attributes {mhlo.num_partitions = 8 : i32,
+ mhlo.num_replicas = 1 : i32} {
+ func.func public @main(
+ %arg0: tensor<1024x128xf32> {mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}",
+ mhlo.layout_mode = "{0,1}"},
+ %arg1: tensor<f32> {mhlo.sharding = "{replicated}"})
+ -> (tensor<1024x128xf32> {jax.result_info = "",
+ mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}",
+ mhlo.layout_mode = "{0,1}"}) {
+ %0 = stablehlo.convert %arg1 : tensor<f32>
+ %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1024x128xf32>
+ %2 = stablehlo.add %arg0, %1 : tensor<1024x128xf32>
+ return %2 : tensor<1024x128xf32>
+ }
+}
+ """
+ executable = self.backend.compile(module_str)
+
+ # Check input layouts.
+ input_layouts = executable.get_parameter_layouts()
+ self.assertLen(input_layouts, 2)
+ self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1))
+ self.assertEqual(self._minor_to_major(input_layouts[1]), ())
+
+ # Check output layout.
+ output_layouts = executable.get_output_layouts()
+ self.assertLen(output_layouts, 1)
+ self.assertEqual(self._minor_to_major(output_layouts[0]), (0, 1))
+
+ # Compile a version with default layouts so we can make sure we actually
+ # set it above.
+ default_executable = self.backend.compile(
+ module_str.replace('"{0,1}"', '"default"')
+ )
+ self.assertNotEqual(
+ self._minor_to_major(input_layouts[0]),
+ self._minor_to_major(default_executable.get_parameter_layouts()[0]),
+ )
+ self.assertNotEqual(
+ self._minor_to_major(output_layouts[0]),
+ self._minor_to_major(default_executable.get_output_layouts()[0]),
+ )
+
+ @unittest.skipIf(pathways, "not implemented")
+ def testAutoArgumentLayouts(self):
+ # TODO(b/309682374): implement on CPU and GPU
+ if self.backend.platform != "tpu":
+ raise self.skipTest("mhlo.layout_mode only implemented on TPU")
+
+ # Hand-edited version of:
+ # jax.numpy.einsum("...a,ahd->...hd", ...)
+ module_str = """
+module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
+ mhlo.num_replicas = 1 : i32} {
+ func.func public @main(
+ %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}",
+ mhlo.layout_mode = "auto"},
+ %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}",
+ mhlo.layout_mode = "auto"})
+ -> (tensor<1024x8x128xf32> {jax.result_info = ""}) {
+ %0 = stablehlo.dot_general %arg0, %arg1,
+ contracting_dims = [1] x [0],
+ precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>,
+ tensor<1024x8x128xf32>)
+ -> tensor<1024x8x128xf32>
+ return %0 : tensor<1024x8x128xf32>
+ }
+}
+"""
+ executable = self.backend.compile(module_str)
+
+ # Check input layouts.
+ input_layouts = executable.get_parameter_layouts() + self.assertEqual(self._minor_to_major(input_layouts[0]), (1, 0)) + self.assertEqual(self._minor_to_major(input_layouts[1]), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default layout for the second + # (1024,8,128) argument. + self.assertNotEqual( + self._minor_to_major(input_layouts[1]), + self._minor_to_major(default_executable.get_parameter_layouts()[1]), + ) + + @unittest.skipIf(pathways, "not implemented") + def testAutoOutputLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Generated with jax.numpy.einsum("...a,ahd->...hd", ...) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "", + mhlo.layout_mode = "auto"}) { + %0 = stablehlo.dot_general %arg0, %arg1, + contracting_dims = [1] x [0], + precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, + tensor<1024x8x128xf32>) + -> tensor<1024x8x128xf32> + return %0 : tensor<1024x8x128xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Check output layout + output_layout, = executable.get_output_layouts() + self.assertEqual(self._minor_to_major(output_layout), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default output layout. 
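+ # (Here that choice is the (2, 0, 1) minor-to-major order asserted above.)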
+ self.assertNotEqual( + self._minor_to_major(output_layout), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) + + tests.append(LayoutsTest) + + class BufferTest(ComputationTest): + """Tests focusing on execution with Buffers.""" + + def testConstantSum(self): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose(c, expected=[4.25]) + + def testOneParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose( + c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) + + def testTwoParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.)))) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11), + NumpyArrayF32(3.14)], + expected=[4.25]) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCannotCallWithDeletedBuffers(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + arg = NumpyArrayF32(1.11) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + arg_buffer = self.backend.buffer_from_pyval(arg) + arg_buffer.delete() + with self.assertRaises(xla_client.XlaRuntimeError): + compiled_c.execute([arg_buffer]) + + def testXlaShapeIndex(self): + a = xla_client.ShapeIndex((1, 2)) + b = xla_client.ShapeIndex((1, 2)) + c = xla_client.ShapeIndex((2, 3)) + self.assertEqual(a, b) + self.assertNotEqual(b, c) + + def testLayout(self): + f32 = xla_client.PrimitiveType.F32 + a = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() + b = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() + c = xla_client.Shape.array_shape(f32, (2, 3), (1, 0)).layout() + self.assertEqual(a.minor_to_major(), (0, 1)) + self.assertEqual(b.minor_to_major(), (0, 1)) + self.assertEqual(c.minor_to_major(), (1, 0)) + self.assertEqual(a, b) + self.assertNotEqual(a, c) + self.assertNotEqual(b, c) + self.assertEqual(hash(a), hash(b)) + self.assertNotEqual(hash(a), hash(c)) + self.assertNotEqual(hash(b), hash(c)) + + def testBlockUntilReadyWorks(self): + arg = np.array([[1., 2.]], np.float32) + arg_buffer = self.backend.buffer_from_pyval(arg) + arg_buffer.block_until_ready() + # This test merely checks that nothing goes awry when we call + # block_until_ready(); it's difficult to test anything else. 
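+ #
+ # A typical usage sketch (illustrative only, not executed by this test):
+ # timing a computation needs such a barrier, since execute() returns
+ # before the result has been computed:
+ #
+ #   start = time.perf_counter()
+ #   out, = compiled_c.execute([arg_buffer])
+ #   out.block_until_ready()  # returns once the value is materialized
+ #   elapsed = time.perf_counter() - start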
+ + def testBlockUntilReadyRaisesOnDeletedBuffer(self): + arg = np.array([[1., 2.]], np.float32) + buffer = self.backend.buffer_from_pyval(arg) + buffer.delete() + with self.assertRaisesRegex( + RuntimeError, + re.escape( + "BlockHostUntilReady() called on deleted or donated buffer")): + buffer.block_until_ready() + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testOnDeviceSizeInBytes(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.") + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0) + # OnDeviceSizeInBytes varies depending on the platform. Confirm there's + # a reasonable value. + self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0) + self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0) + + def testLiveBuffers(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support LiveBuffers().") + self.assertEmpty(self.backend.live_buffers()) + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertLen(self.backend.live_buffers(), 3) + self.assertIs(self.backend.live_buffers()[0], arg2_buffer) + self.assertIs(self.backend.live_buffers()[1], arg1_buffer) + self.assertIs(self.backend.live_buffers()[2], arg0_buffer) + + arg1_buffer.delete() + self.assertLen(self.backend.live_buffers(), 2) + self.assertIs(self.backend.live_buffers()[0], arg2_buffer) + self.assertIs(self.backend.live_buffers()[1], arg0_buffer) + + arg0_buffer.delete() + arg2_buffer.delete() + self.assertEmpty(self.backend.live_buffers()) + + def testCopyToHost(self): + arg0 = np.array([[1., 2.]], np.float32) + arg1 = np.array([[3., 4.]], np.float32) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + # Prefetch two buffers using copy_to_host_async, and then retrieve their + # values using np.asarray(). + arg0_buffer.copy_to_host_async() + arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. + arg1_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) + np.testing.assert_equal(arg1, np.asarray(arg1_buffer)) + # copy_to_host_async does nothing after np.asarray() is called. + arg0_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) + + def testDevice(self): + x = np.arange(8, dtype=np.int32) + for device in self.backend.local_devices(): + buf = self.backend.buffer_from_pyval(x, device=device) + self.assertEqual(buf.device(), device) + np.testing.assert_equal(x, np.asarray(buf)) + + def testStandardTypes(self): + for dtype in standard_dtypes: + if dtype == np.complex128: + continue + # float8_e4m3b11fnuz not supported on some TPU backends. 
+ if (
+ dtype in [float8_e5m2fnuz, float8_e4m3fnuz, float8_e4m3b11fnuz]
+ and self.backend.platform == "tpu"
+ ):
+ if self.backend.platform_version.find("TPU") == -1:
+ continue
+ arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype))
+ arr = np.asarray(arr)
+ self.assertEqual(dtype, type(arr[0]))
+
+ @unittest.skipIf(pathways_ifrt, "not implemented")
+ def testUnsafeBufferPointer(self):
+ if not isinstance(self.backend, xla_client.Client):
+ self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().")
+ arg0 = np.array([])
+ arg1 = np.array([[0., 1., 2.]], np.float32)
+ arg2 = np.array([[3., 4., 5.]], bfloat16)
+ arg0_buffer = self.backend.buffer_from_pyval(arg0)
+ arg1_buffer = self.backend.buffer_from_pyval(arg1)
+ arg2_buffer = self.backend.buffer_from_pyval(arg2)
+ self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0)
+ self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0)
+ self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0)
+
+ @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt, "not implemented")
+ def testClone(self):
+ x = np.array([[3., 4., 5.]], np.float32)
+ y = self.backend.buffer_from_pyval(x)
+ z = y.clone()
+ # clone() returns a distinct Python object that aliases the same device
+ # memory, per the pointer check below.
+ self.assertNotEqual(id(y), id(z))
+ np.testing.assert_array_equal(np.asarray(y), np.asarray(z))
+ self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer())
+
+ tests.append(BufferTest)
+
+ class SingleOpTest(ComputationTest):
+ """Tests for single ops.
+
+ The goal here is smoke testing - to exercise the most basic functionality
+ of single XLA ops. As few additional ops as possible are added around the
+ op being tested.
+ """
+
+ @parameterized.named_parameters({
+ "testcase_name": "_{}".format(dtype.__name__),
+ "dtype": dtype,
+ } for dtype in float_dtypes)
+ def testConcatenate(self, dtype):
+ c = self._NewComputation()
+ args = (
+ ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)),
+ ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)),
+ )
+ ops.ConcatInDim(c, args, dimension=0)
+ self._ExecuteAndCompareExact(
+ c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)])
+
+ # pyformat: disable
+ @parameterized.named_parameters({
+ "testcase_name": "_{}_{}".format(src_dtype.__name__,
+ dst_dtype.__name__),
+ "src_dtype": src_dtype,
+ "dst_dtype": dst_dtype,
+ } for src_dtype, dst_dtype in itertools.permutations(
+ [np.bool_, np.int32, np.int64, np.float32, np.float64], 2))
+ # pyformat: enable
+ def testConvertElementType(self, src_dtype, dst_dtype):
+ if ((src_dtype in [np.int64, np.float64] or
+ dst_dtype in [np.int64, np.float64]) and
+ self.backend.platform == "tpu"):
+ self.skipTest("TPU doesn't support float64")
+ c = self._NewComputation()
+ x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
+ ops.ConvertElementType(
+ ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
+
+ result = execute_with_python_values(
+ self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
+ backend=self.backend)
+ self.assertLen(result, 1)
+ expected = np.array(x, dtype=dst_dtype)
+
+ self.assertEqual(result[0].shape, expected.shape)
+ self.assertEqual(result[0].dtype, expected.dtype)
+ np.testing.assert_equal(result[0], expected)
+
+ # pyformat: disable
+ @parameterized.named_parameters(
+ {
+ "testcase_name": "_{}_{}".format(src_dtype.__name__,
+ dst_dtype.__name__),
+ "src_dtype": src_dtype,
+ "dst_dtype": dst_dtype,
+ }
+ for dtypes in [[np.int32, np.float32], [np.int64, np.float64]]
+ for src_dtype, dst_dtype in itertools.permutations(dtypes, 2))
+ # pyformat:
enable + def testBitcastConvertType(self, src_dtype, dst_dtype): + if (np.float64 in (src_dtype, dst_dtype) and + self.backend.platform == "tpu"): + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) + ops.BitcastConvertType( + ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) + + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 1) + expected = x.view(dst_dtype) + + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) + + # TODO(b/123523486) implement AllToAll on CPU + def DISABLED_testAllToAllOneReplica(self): + samples = [ + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples[:1]: + c = self._NewComputation() + ops.AllToAll(ops.Constant(c, lhs), 0, 0) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + def testCrossReplicaSumOneReplica(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum(ops.Constant(c, lhs)) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + def testReplicaId(self): + c = self._NewComputation() + _ = ops.ReplicaId(c) + self._ExecuteAndCompareExact(c, expected=[0]) + + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum( + ops.Constant(c, lhs), xla_client.make_replica_groups([[0]])) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixVector(self, dtype): + c = self._NewComputation() + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0], [20.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) + + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixMatrix(self, dtype): + c = self._NewComputation() + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) + + def testDotGeneral(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), ([0], [0]))) + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testDotGeneralWithDotDimensionNumbersProto(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + 
rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + + dimension_numbers = xla_client.DotDimensionNumbers() + dimension_numbers.lhs_contracting_dimensions.append(2) + dimension_numbers.rhs_contracting_dimensions.append(1) + dimension_numbers.lhs_batch_dimensions.append(0) + dimension_numbers.rhs_batch_dimensions.append(0) + + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testDotGeneralWithPrecisionConfig(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), ([0], [0]))) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGH) + config.operand_precision.append(config.Precision.HIGHEST) + ops.DotGeneral( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + dimension_numbers, + precision_config=config) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedF32WithPrecisionConfig(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGHEST) + config.operand_precision.append(config.Precision.DEFAULT) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + strides, + pads, + lhs_dilation, + rhs_dilation, + dimension_numbers, + precision_config=config) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NHWC", "OIHW", "CWNH"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, np.transpose(lhs, + (0, 2, 3, 1))), ops.Constant(c, rhs), + strides, pads, lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose( + c, expected=[np.transpose(result, (1, 3, 0, 2))]) + + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = self._NewComputation() + a = lambda *dims: 
np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + feature_group_count = 2 + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ], [ + [0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedWindowReversalF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + window_reversal = [False, True] + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + strides, + pads, + lhs_dilation, + rhs_dilation, + dimension_numbers, + window_reversal=window_reversal) + result = np.array([[[ + [0., 0., 0.], + [0., 10., 20.], + [0., 0., 0.], + [30., 40., 50.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + ops.Not(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[~arr]) + + def testPopulationCount(self): + c = self._NewComputation() + arr = NumpyArrayS32([3, 0, 1]) + ops.PopulationCount(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) + + def testCountLeadingZeros(self): + c = self._NewComputation() + arr = NumpyArrayS32([0x7FFF, 0x12345678]) + ops.Clz(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[[17, 3]]) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Exp(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) + + def testExpWithResultAccuracy(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + accuracy = xla_client.ResultAccuracy() + accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT + ops.Exp(ops.Constant(c, arr), accuracy) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) + + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Expm1(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) + + def testExpm1WithResultAccuracy(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + accuracy = xla_client.ResultAccuracy() + accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT + ops.Expm1(ops.Constant(c, arr), accuracy) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) + + def testRound(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Round(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Log(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) + + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Log1p(ops.Constant(c, arr)) + 
self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Neg(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[-arr]) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Floor(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Ceil(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)]) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + ops.Abs(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) + + def testTanF32(self): + c = self._NewComputation() + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tan(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tan(arr)]) + + def testTanhF32(self): + c = self._NewComputation() + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) + + def testTanhF64(self): + if self.backend.platform == "tpu": + self.skipTest("TPU doesn't support 64bit tanh") + c = self._NewComputation() + arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + ops.Transpose(ops.Constant(c, array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=[expected]) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + ops.Eq( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + def testNe(self): + c = self._NewComputation() + ops.Ne( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) + + ops.Ne( + ops.Constant(c, NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0, + float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, + c, (), + expected=[[True, False, True, True]]) + + def testGt(self): + c = self._NewComputation() + ops.Gt( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, True, True, False, False]]) + + def testGe(self): + c = self._NewComputation() + ops.Ge( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, True, True, False, False]]) + + def testLt(self): + c = self._NewComputation() + ops.Lt( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 
9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, False, False, True, True]]) + + def testLe(self): + c = self._NewComputation() + ops.Le( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, False, False, True, True]]) + + def testMax(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) + + def testMin(self): + c = self._NewComputation() + ops.Min( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) + + def testPad(self): + c = self._NewComputation() + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), + xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)])) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testPadWithPaddingConfig(self): + c = self._NewComputation() + padding_config = xla_client.PaddingConfig() + for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: + dimension = xla_client.PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), padding_config) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testReshape(self): + c = self._NewComputation() + ops.Reshape( + ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) + + def testCollapse(self): + c = self._NewComputation() + ops.Collapse( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[1, 2]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) + + def testRev(self): + c = self._NewComputation() + ops.Rev( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[0, 2]) + self._ExecuteAndCompareExact( + c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) + + def testReducePrecision(self): + c = self._NewComputation() + ops.ReducePrecision( + ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), + exponent_bits=8, + mantissa_bits=7) + self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]]) + + def testClampF32(self): + c 
= self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayF32(-1)), + ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayF32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testClampS32(self): + c = self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayS32(-1)), + ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayS32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testSelect(self): + c = self._NewComputation() + ops.Select( + ops.Constant(c, NumpyArrayBool([True, False, False, True, False])), + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])), + ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) + + def testSlice(self): + c = self._NewComputation() + ops.Slice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + [1, 0], [3, 2], [1, 1]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) + + def testSliceInDim(self): + c = self._NewComputation() + ops.SliceInDim( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=1, + limit_index=2, + stride=1, + dimno=1) + self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]]) + ops.SliceInDim( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=0, + limit_index=3, + stride=2, + dimno=0) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]]) + + def testDynamicSlice(self): + c = self._NewComputation() + ops.DynamicSlice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [ + ops.Constant(c, NumpyArrayS32(1)), + ops.Constant(c, NumpyArrayS32(0)) + ], [2, 2]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) + + def testDynamicUpdateSlice(self): + c = self._NewComputation() + ops.DynamicUpdateSlice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])), [ + ops.Constant(c, NumpyArrayS32(1)), + ops.Constant(c, NumpyArrayS32(1)) + ]) + self._ExecuteAndCompareExact( + c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]]) + + def testTuple(self): + c = self._NewComputation() + ops.Tuple(c, [ + ops.Constant(c, np.int32(42)), + ops.Constant(c, NumpyArrayF32([1.0, 2.0])), + ops.Constant(c, NumpyArrayBool([True, False, False, True])) + ]) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 3) + np.testing.assert_equal(result[0], 42) + np.testing.assert_allclose(result[1], [1.0, 2.0]) + np.testing.assert_equal(result[2], [True, False, False, True]) + + def testGetTupleElement(self): + c = self._NewComputation() + ops.GetTupleElement( + ops.Tuple(c, [ + ops.Constant(c, np.int32(42)), + ops.Constant(c, NumpyArrayF32([1.0, 2.0])), + ops.Constant(c, NumpyArrayBool([True, False, False, True])) + ]), 1) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]]) + + def testBroadcast(self): + c = self._NewComputation() + ops.Broadcast( + ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) + self._ExecuteAndCompareExact( + c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]]) + + def testBroadcastInDim(self): + c = self._NewComputation() + ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0]) + self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]]) + ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 
2], [1]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]]) + + def testRngNormal(self): + shape = (2, 3) + c = self._NewComputation() + ops.RngNormal( + ops.Constant(c, NumpyArrayF32(0.)), + ops.Constant(c, NumpyArrayF32(1.)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape and uniqueness + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + + def testRngUniformF32(self): + lo, hi = 2., 4. + shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayF32(lo)), + ops.Constant(c, NumpyArrayF32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape, uniqueness, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testRngUniformS32(self): + lo, hi = 2, 4 + shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayS32(lo)), + ops.Constant(c, NumpyArrayS32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape, integrality, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertEqual(result[0].dtype, np.int32) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() + ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T)))) + self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) + + def testSort(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + c = self._NewComputation() + ops.Sort(c, [ops.Constant(c, keys)], is_stable=True) + self._ExecuteAndCompareClose( + c, + expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) + + def testSortKeyVal(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) + np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) + + def testSortCustomComparator(self): + b = self._NewComputation("comparator") + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) + q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) + p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) + q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), 
ops.Gt(p1, q1))) + comparator = b.build() + + keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort( + c, (ops.Constant(c, keys), ops.Constant(c, values)), + dimension=1, + comparator=comparator) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) + np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) + + def testQR(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True)) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testEigh(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + a = (a + a.T) / 2 + + c = self._NewComputation() + ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True)) + # TODO(b/129396575): Turn this test back on when it passes without + # fastmath. + # v, w = self._Execute(c, ()) + # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) + + def testSVD(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, ops.SVD(ops.Constant(c, a))) + u, d, v = self._Execute(c, ()) + self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + ops.TriangularSolve( + ops.Constant(c, a_vals), + ops.Constant(c, b_vals), + left_side=False, + lower=True, + transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE, + unit_diagonal=False) + self._ExecuteAndCompareClose( + c, + expected=[ + np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], + dtype=np.float32) + ], + rtol=1e-4) + + def testApproxTopK(self): + if self.backend.platform != "tpu": + self.skipTest("ApproxTopK is only supported on TPU") + k = 10 + qy_size = 256 + db_size = 3000 + feature = 128 + recall_target = 0.95 + b = self._NewComputation() + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) + q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) + ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Gt(p0, q0) + comparator = b.build() + qy_shape = [qy_size, feature] + db_shape = [feature, db_size] + rng = np.random.RandomState(0) + qy_arg = rng.randn(*qy_shape).astype(np.float32) + db_arg = rng.randn(*db_shape).astype(np.float32) + b = self._NewComputation() + qy = ops.Parameter(b, 0, xla_client.shape_from_pyval(qy_arg)) + db = ops.Parameter(b, 1, xla_client.shape_from_pyval(db_arg)) + scores = ops.Dot(qy, db) + iota = ops.Iota( + b, + xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, + (qy_size, db_size)), 1) + init_val = ops.Constant(b, np.float32(-1)) + init_arg = ops.Constant(b, np.int32(-1)) + ground_truth = ops.TopK(scores, k=k) + approx_topk = ops.ApproxTopK( + 
b, [scores, iota], [init_val, init_arg], + top_k=k, + reduction_dim=1, + comparator=comparator, + recall_target=recall_target) + ops.Tuple(b, [ + ops.GetTupleElement(ground_truth, 1), + ops.GetTupleElement(approx_topk, 1) + ]) + results = self._Execute(b, [qy_arg, db_arg]) + ground_truth_docids = [set(x) for x in results[0]] + hits = sum( + len([x for x in approx_topk_per_q if x in ground_truth_docids[q]]) + for q, approx_topk_per_q in enumerate(results[1]) + ) + self.assertGreater(hits / (qy_size * k), recall_target) + + def testIsConstant(self): + c = self._NewComputation() + a = ops.Constant(c, np.int32(3)) + b = ops.Constant(c, np.int32(1)) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) + const_expr = ops.Sub(b, a) + non_const_expr = ops.Mul(const_expr, x) + self.assertTrue(c.is_constant(const_expr)) + self.assertFalse(c.is_constant(non_const_expr)) + + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.GatherDimensionNumbers() + dnums.offset_dims.append(1) + dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + ops.Gather( + ops.Constant(c, a), + ops.Constant(c, indices), + dnums, + slice_sizes=[1, 1]) + g, = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) + + def testAllGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + c = self._NewComputation() + ops.AllGather( + operand=ops.Constant(c, a), + all_gather_dimension=0, + shard_count=1, + replica_groups=xla_client.make_replica_groups([[0]]), + use_global_device_ids=False) + [g] = self._Execute(c, ()) + np.testing.assert_equal(g, a) + + def testFft(self): + if self.backend.platform == "tpu": + self.skipTest("TPU only supports 1D FFT") + shape = [2, 3, 4, 5] + rng = np.random.RandomState(0) + a = rng.randn(*shape) + 1.0j * rng.randn(*shape) + a = a.astype(np.complex64) + # FFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) + # IFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) + # RFFT + b = rng.randn(*shape).astype(np.float32) + c = self._NewComputation() + ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) + # IRFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=2e-4 + ) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes + fp8_dtypes) + def testNextAfter(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + if dtype == bfloat16 and self.backend.platform == "tpu": + self.skipTest("b/371119032: Test fails on TPUs with bfloat16") + finfo = ml_dtypes.finfo(dtype) + eps = finfo.eps + c = self._NewComputation() + # Each row is (value, direction, expected), where + # 'nextafter(value, direction)' should be 'expected'. 
+ data = np.array( + [ + [1, 2, 1 + finfo.eps], + [2, 1, 2 - eps], + [-0., 1, finfo.smallest_subnormal], + [0., -1, -finfo.smallest_subnormal], + [-finfo.smallest_subnormal, 1, -0.], + [finfo.smallest_subnormal, 1, 2 * finfo.smallest_subnormal], + [finfo.smallest_subnormal, -1, 0], + ], + dtype=dtype, + ) + + ops.NextAfter(ops.Constant(c, data[:, 0]), ops.Constant(c, data[:, 1])) + out, = self._Execute(c, ()) + np.testing.assert_equal(out, data[:, 2]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testRegularizedIncompleteBeta(self, dtype): + x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538], + dtype=dtype) + a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606], + dtype=dtype) + b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677], + dtype=dtype) + c = self._NewComputation() + ops.RegularizedIncompleteBeta( + ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x)) + expected = np.array( + [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) + self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2) + + tests.append(SingleOpTest) + + class EmbeddedComputationsTest(ComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantComputation(self, in_dtype, out_dtype): + """Computation (A) -> B that returns a constant 1 for any input.""" + c = self._NewComputation("constant_{}_{}_one".format( + in_dtype.__name__, out_dtype.__name__)) + ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 0, dtype=in_dtype)).with_major_to_minor_layout_if_absent()) + ops.Constant(c, out_dtype(1)) + return c.build() + + def _CreateMulBy2Computation(self, dtype): + """Computation (dtype) -> dtype that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + ops.Mul( + ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 0, dtype=dtype)).with_major_to_minor_layout_if_absent()), + ops.Constant(c, dtype(2.0))) + return c.build() + + def _CreateMulF32ByParamComputation(self): + """Computation (f32) -> f32 that multiplies one parameter by the other.""" + c = self._NewComputation("mul_f32_by_param") + ops.Mul( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) + return c.build() + + def _CreateBinaryAddComputation(self, dtype): + """Computation (dtype, dtype) -> dtype that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _CreateBinaryGeComputation(self, dtype): + """Computation (dtype, dtype) -> bool that tests param0 >= param1.""" + c = self._NewComputation("param0_lt_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _MakeSample3DArray(self, dtype): + return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], + dtype=dtype) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testCall(self, dtype): + c = 
self._NewComputation() + ops.Call( + c, + self._CreateMulBy2Computation(dtype), + operands=(ops.Constant(c, dtype(5.0)),)) + self._ExecuteAndCompareClose(c, expected=[10.0]) + + @parameterized.named_parameters({ + "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__), + "in_dtype": in_dtype, + "out_dtype": out_dtype, + } for in_dtype, out_dtype in [[np.float32, np.int32]]) + def testMapEachElementToConstant(self, in_dtype, out_dtype): + c = self._NewComputation() + ops.Map(c, + [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))], + self._CreateConstantComputation(in_dtype, out_dtype), [0]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testMapMulBy2(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSimpleMapChain(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + # Chains a map of constant-out with a map of mul-by-2 + c = self._NewComputation() + const = ops.Map( + c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateConstantComputation(dtype, dtype), [0]) + ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) + + # TODO(b/154752816): bfloat16 crashes in evaluator. 
+ @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDivVectorsWithMap(self, dtype): + + def DivComputation(): + c = self._NewComputation("div_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + c = self._NewComputation() + ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)), + ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))), + DivComputation(), [0]) + self._ExecuteAndCompareClose( + c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSelectAndScatter(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + operand = ops.Constant( + c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype)) + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, + c.get_shape(operand).dimensions(), window_dimensions, window_strides) + ops.SelectAndScatterWithGeneralPadding( + operand, + select=self._CreateBinaryGeComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)), + init_value=ops.Constant(c, np.array(1, dtype=dtype)), + scatter=self._CreateBinaryAddComputation(dtype)) + self._ExecuteAndCompareClose( + c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduce1DtoScalar(self, dtype): + c = self._NewComputation() + ops.Reduce( + c, + operands=[ + ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)) + ], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[0]) + self._ExecuteAndCompareClose(c, expected=[10]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + @parameterized.named_parameters({ + "testcase_name": "_{}_dim{}".format(dtype.__name__, dim), + "dtype": dtype, + "dim": dim, + } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2)) + def testReduce2DTo1D(self, dtype, dim): + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[dim]) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)]) + + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + @parameterized.named_parameters({ + "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims), + "dtype": dtype, + "dims": tuple(dims) + } for dtype in float_dtypes for dims in itertools.permutations(range(3))) + def testReduce3DAllPossibleWaysF32(self, dtype, dims): + input_array = self._MakeSample3DArray(dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + 
dimensions_to_reduce=dims) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowSameUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.SAME, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidGeneralStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) + + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + def testReduceWindowVariadic(self): + c = self._NewComputation("reducer") + shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32)) + shape = shape.with_major_to_minor_layout_if_absent() + ps = [ops.Parameter(c, i, shape) for i in range(4)] + which = ops.Ge(ps[0], ps[2]) + ops.Tuple( + c, [ops.Select(which, ps[0], ps[2]), + ops.Select(which, ps[1], ps[3])]) + reducer = c.build() + + key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32) + val_array = 
np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, key_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operands=[ops.Constant(c, key_array), + ops.Constant(c, val_array)], + init_values=[ + ops.Constant(c, np.int32(0)), + ops.Constant(c, np.int32(0)) + ], + computation=reducer, + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testWhile(self, dtype): + + def LessThan10Cond(): + c = self._NewComputation("test_lt_10") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) + return c.build() + + cond = LessThan10Cond() + body = self._CreateMulBy2Computation(dtype) + c = self._NewComputation() + init = ops.Constant(c, dtype(1.)) + ops.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=[16.]) + + def testConditionalTrue(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(True)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[6.]) + + def testConditionalFalse(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(False)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[1.]) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedS32Values(self): + to_infeed = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + for item in to_infeed: + device.transfer_to_infeed(item) + + for item in to_infeed: + result, = execute_with_python_values( + compiled_c, (), backend=self.backend) + self.assertEqual(result, item) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedTuple(self): + to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) + c = self._NewComputation() + ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + device.transfer_to_infeed(to_infeed) + + result = execute_with_python_values( + 
compiled_c, (), backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_equal(result[0], to_infeed[0]) + np.testing.assert_equal(result[1], to_infeed[1]) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedThenOutfeedS32(self): + to_round_trip = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + x_and_token = ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent()) + x = ops.GetTupleElement(x_and_token, 0) + token = ops.GetTupleElement(x_and_token, 1) + outfeed_shape = xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent() + ops.OutfeedWithToken(x, token, outfeed_shape) + ops.Tuple(c, ()) + + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + + for want in to_round_trip: + execution = threading.Thread(target=lambda: compiled_c.execute([])) + execution.start() + device.transfer_to_infeed(want) + got = device.transfer_from_outfeed(outfeed_shape) + execution.join() + self.assertEqual(want, got) + + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + ops.Scatter( + ops.Constant(c, a), ops.Constant(c, scatter_indices), + ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32), + dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], + dtype=np.int32) + self._ExecuteAndCompareClose(c, expected=[expected]) + + class DeviceTest(ComputationTest): + + def testDevices(self): + self.assertNotEmpty(self.backend.devices()) + + def testLocalDevices(self): + self.assertNotEmpty(self.backend.local_devices()) + if self.backend.platform == "cpu": + self.assertLen(self.backend.local_devices(), 2) + + def testGetAllDevices(self): + # TODO(hyeontaek): Remove this method once we have a unified API for + # enumerating devices with different criteria. 
+ self.assertNotEmpty(self.backend._get_all_devices()) # pylint: disable=protected-access + + def testPlatform(self): + for device in self.backend.local_devices(): + self.assertEqual(device.platform, self.backend.platform) + + def testCoreCount(self): + if self.backend.platform != "gpu": + self.skipTest("core_count is only supported on GPU") + for device in self.backend.local_devices(): + self.assertGreater(device.core_count, 0) + + def testLocalHardwareId(self): + for device in self.backend.devices(): + local_hardware_id = device.local_hardware_id + if local_hardware_id is not None: + self.assertGreaterEqual(local_hardware_id, 0) + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testLocalDeviceFromLocalHardwareId(self): + for device in self.backend.local_devices(): + if device.local_hardware_id is not None: + lookup_device = self.backend.device_from_local_hardware_id( + device.local_hardware_id) + self.assertEqual(lookup_device, device) + + @unittest.skipIf(pathways, "not implemented") + @unittest.skipIf(pathways_ifrt, "not implemented") + def testMemoryStats(self): + for device in self.backend.local_devices(): + stats = device.memory_stats() + if ( + self.backend.platform != "tpu" or not tfrt_tpu + ) and self.backend.platform not in ("gpu", "cuda", "rocm"): + self.assertIsNone(stats) + else: + self.assertIsNotNone(stats) + # Spot check a few fields + self.assertEqual(type(stats["num_allocs"]), int) + self.assertGreaterEqual(stats["num_allocs"], 0) + self.assertEqual(type(stats["bytes_in_use"]), int) + self.assertGreaterEqual(stats["bytes_in_use"], 0) + self.assertEqual(type(stats["peak_bytes_in_use"]), int) + self.assertGreaterEqual(stats["peak_bytes_in_use"], 0) + self.assertEqual(type(stats["largest_alloc_size"]), int) + self.assertGreaterEqual(stats["largest_alloc_size"], 0) + + @unittest.skipIf(pathways, "not implemented") + def testMemory(self): + for device in self.backend.local_devices(): + for memory in device.addressable_memories(): + self.assertEqual(memory.process_index, device.process_index) + self.assertEqual(memory.platform, device.platform) + self.assertIn(device, memory.addressable_by_devices()) + self.assertEqual(memory, device.memory(memory.kind)) + + tests.append(DeviceTest) + + class ErrorTest(ComputationTest): + + def setUp(self): + super(ErrorTest, self).setUp() + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.s32_scalar_2 = NumpyArrayS32(2) + + def testCompileWithWrongElementTypeInLayout(self): + c = self._NewComputation() + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() + + options = xla_client.CompileOptions() + options.argument_layouts = [ + xla_client.Shape.array_shape(np.dtype(np.float32), []) + ] + + def TestFun(): + return self.backend.compile(c.build(), compile_options=options) + + self.assertRaisesRegex( + RuntimeError, r".*Invalid argument shape.*" + r"expected s32\[\], got f32\[\].*", TestFun) + + def testInvokeWithWrongElementType(self): + c = self._NewComputation() + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() + + def TestFun(): + return execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), + [self.f32_scalar_2], self.backend) + + self.assertRaisesRegex( + RuntimeError, r"Invalid argument: Argument does not match.*" + r"want s32\[\], got f32\[\].*", TestFun) + + 
tests.append(EmbeddedComputationsTest) + + class ComputationRootTest(ComputationTest): + """Tests related to setting the root of the computation.""" + + def testComputationRootDifferentFromLastOp(self): + c = self._NewComputation() + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) + + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build(result))) + ans, = execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) + + tests.append(ComputationRootTest) + + class SetShardingTest(ComputationTest): + """Tests related to set OpSharding.""" + + def testSetSharding(self): + c = self._NewComputation() + sharding = xla_client.OpSharding() + sharding.type = xla_client.OpSharding.Type.REPLICATED + sharding.tile_assignment_dimensions = [1] + sharding.tile_assignment_devices = [0] + c.set_sharding(sharding) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + c.clear_sharding() + + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build(result))) + ans, = execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) + + tests.append(SetShardingTest) + + testcase_shapes = [ + (), + (1,), + (2, 3), + (2, 0), + (0, 7), + (4, 1, 2), + (2, 1, 3), + (2, 4, 1), + (3, 1), + (1, 3), + ] + + def FormatShapeAndDtype(shape, dtype): + return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) + + class DLPackTest(parameterized.TestCase): + + def setUp(self): + super(DLPackTest, self).setUp() + self.backend = xla_backend() + if self.backend.platform not in ("cpu", "gpu", "cuda", "rocm"): + self.skipTest("DLPack requires CPU or GPU") + self.cpu_backend = ( + self.backend + if self.backend.platform == "cpu" else xla_client.make_cpu_client()) + self.gpu_backend = ( + self.backend + if self.backend.platform in ("gpu", "cuda", "rocm") + else None + ) + + def tearDown(self): + super().tearDown() + del self.backend + del self.cpu_backend + del self.gpu_backend + + @classmethod + def _GetStreamFromDevice(cls, device): + try: + return device.get_stream_for_external_ready_events() + except xla_client.XlaRuntimeError as err: # type: ignore + if "UNIMPLEMENTED" in str(err): + return None + else: + raise + + def _DLPackManagedTensorToBuffer( + self, tensor, use_legacy_api, backend=None + ): + if use_legacy_api: + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, self.cpu_backend, self.gpu_backend + ) + else: + if not backend: + backend = self.backend + device = backend.local_devices()[0] + stream = DLPackTest._GetStreamFromDevice(device) + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, device, stream + ) + + # pylint: disable=g-complex-comprehension + # pyformat: disable + @parameterized.named_parameters( + { + "testcase_name": "{}_gpu={}{}".format( + FormatShapeAndDtype(shape, dtype), + gpu, + "_legacy" if use_legacy_api else "", + ), + "dtype": dtype, + "shape": shape, + "gpu": gpu, + "use_legacy_api": use_legacy_api, + } + for dtype in dlpack_dtypes + for shape in testcase_shapes + for gpu in [False, True] + for use_legacy_api in [False, True] + ) + # pyformat: enable + def testRoundTrip(self, dtype, shape, gpu, 
use_legacy_api): + if gpu and self.gpu_backend is None: + raise unittest.SkipTest("Test not running with GPU support") + backend = self.gpu_backend if gpu else self.cpu_backend + if dtype == np.bool_: + x = np.random.randint(0, 2, size=shape).astype(np.bool_) + else: + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + buffer = backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + del buffer # Free "buffer" to make sure dlt retains ownership. + self.assertEqual(type(dlt).__name__, "PyCapsule") + y = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api, backend) + np.testing.assert_array_equal( + x.astype(np.uint8) if dtype == np.bool_ else x, np.asarray(y)) + + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testTensorsCanBeConsumedOnceOnly(self, use_legacy_api): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + + def ConsumeDLPackTensor(): + _ = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api) + + ConsumeDLPackTensor() + self.assertRaisesRegex( + RuntimeError, ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) + + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testNonOwnedDlpackCanBeViewedTwice(self, use_legacy_api): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + + y = self._DLPackManagedTensorToBuffer(d1, use_legacy_api) + z = self._DLPackManagedTensorToBuffer(d2, use_legacy_api) + del d1, d2 + np.testing.assert_array_equal(x, np.asarray(buffer)) + np.testing.assert_array_equal(x, np.asarray(y)) + np.testing.assert_array_equal(x, np.asarray(z)) + + @parameterized.parameters(False, True) + def testZeroCopyOnAlignedDlpackTensor(self, use_legacy_api): + # Using CPU only, since this test is about CPU memory alignment. + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + # Create a numpy array that is aligned to XLA requirements. + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + x = _Aligned(x) + + # Convert it to a DLPack tensor, and then to an XLA buffer. + dlpack_tensor = x.__dlpack__() + buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api) + y = np.array(buffer, copy=False) + + # The input was sufficiently aligned, so input and output should alias. + x_ptr = x.__array_interface__["data"][0] + y_ptr = y.__array_interface__["data"][0] + self.assertEqual( + x_ptr, + y_ptr, + msg=f"Buffers are not aliased ({hex(x_ptr)} != {hex(y_ptr)}).", + ) + + @parameterized.named_parameters( + { + "testcase_name": "{}{}".format( + "_legacy" if use_legacy_api else "", + "_transpose" if transpose else "", + ), + "use_legacy_api": use_legacy_api, + "transpose": transpose, + } + for use_legacy_api in [False, True] + for transpose in [False, True] + ) + def testReturnCopyOnUnalignedDlpackTensor(self, use_legacy_api, transpose): + # Using CPU only, since this test is about CPU memory alignment. 
+      if self.backend.platform != "cpu":
+        self.skipTest("Test requires CPU")
+
+      if transpose and use_legacy_api:
+        self.skipTest("Non-default layout is not supported in legacy API")
+
+      # Create a numpy array that is not aligned to XLA requirements. XLA's
+      # alignment requirements differ for different hardware, so we use the
+      # smallest possible value. If we make sure the buffer is not aligned to
+      # this value (16 bytes), then it is also not aligned to its multiples (32,
+      # 64 etc.)
+      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
+      x = _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT)
+
+      # Transpose the array to test non-default layout with trivial striding.
+      if transpose:
+        x = x.transpose((0, 2, 1, 3))
+
+      # Convert it to a DLPack tensor, and then to an XLA buffer.
+      dlpack_tensor = x.__dlpack__()
+      buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api)
+      y = np.array(buffer, copy=False)
+
+      # The input was not sufficiently aligned, so input and output should not
+      # alias (output should be a copy of input, and it should be aligned).
+      x_ptr = x.__array_interface__["data"][0]
+      y_ptr = y.__array_interface__["data"][0]
+      self.assertNotEqual(
+          x_ptr,
+          y_ptr,
+          msg=(
+              f"Buffers aliased, but should not be ({hex(x_ptr)} =="
+              f" {hex(y_ptr)})"
+          ),
+      )
+      self.assertEqual(
+          y_ptr % _XLA_CPU_MIN_ALIGNMENT,
+          0,
+          msg=f"Output buffer not aligned: {hex(y_ptr)}",
+      )
+      np.testing.assert_array_equal(y, x)
+
+  tests.append(DLPackTest)
+
+  class BufferProtocolTest(parameterized.TestCase):
+
+    def setUp(self):
+      super(BufferProtocolTest, self).setUp()
+      self.backend = xla_backend()
+      if self.backend.platform != "cpu":
+        self.skipTest("Test requires CPU")
+
+    # pylint: disable=g-complex-comprehension
+    @parameterized.named_parameters({
+        "testcase_name": FormatShapeAndDtype(shape, dtype),
+        "dtype": dtype,
+        "shape": shape
+    } for dtype in standard_dtypes if dtype != bfloat16
+      for shape in testcase_shapes)
+    def testRoundTrip(self, dtype, shape):
+      x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
+
+      x = _Aligned(x)
+      x_ptr = x.__array_interface__["data"][0]
+      buffer = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
+      y = np.array(buffer, copy=False)
+      y_ptr = y.__array_interface__["data"][0]
+      np.testing.assert_array_equal(x, y)
+
+      # The input was sufficiently aligned, so input and output should alias.
+      self.assertEqual(x_ptr, y_ptr)
+      self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
+
+      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
+      buffer2 = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=during_call)
+      z = np.array(buffer2, copy=False)
+      self.assertNotEqual(x.__array_interface__["data"][0],
+                          z.__array_interface__["data"][0])
+
+    def testDeleteWithActiveView(self):
+      x = np.random.randn(20, 10)
+      buffer = self.backend.buffer_from_pyval(x)
+      buffer_ptr = buffer.unsafe_buffer_pointer()
+      y = np.array(buffer, copy=False)
+      buffer.delete()
+      # It is still legal to access `y`; the array view must keep it alive.
+ np.testing.assert_array_equal(x, y) + self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) + + tests.append(BufferProtocolTest) + + class TracebackTest(absltest.TestCase): + + def setUp(self): + super(TracebackTest, self).setUp() + self.backend = xla_backend() + + def testNoTracebacksIfDisabled(self): + with xla_client.tracebacks(enabled=False): + self.assertEqual(None, xla_client.Traceback.get_traceback()) + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertEqual(None, buffer.traceback) + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(xla_computation_to_mlir_module(b.build())) + self.assertEqual(None, e.traceback) + + def assertIsTracebackContaining(self, tb, function): + self.assertIsInstance(tb, xla_client.Traceback) + self.assertIn(function, str(tb)) + self.assertTrue(any(f.function_name == function for f in tb.frames)) + + def testTracebacks(self): + with xla_client.tracebacks(enabled=True): + tb = xla_client.Traceback.get_traceback() + self.assertIsTracebackContaining(tb, "testTracebacks") + + # Tracebacks are not implemented on the TPU driver extension's variant + # of buffers and executables. + if not isinstance(self.backend, xla_client.Client): + return + + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertIsTracebackContaining(buffer.traceback, "testTracebacks") + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(xla_computation_to_mlir_module(b.build())) + self.assertIsTracebackContaining(e.traceback, "testTracebacks") + + def testNestedFunction(self): + + def AFunction(): + + def AnotherFunction(): + return xla_client.Traceback.get_traceback() + + return AnotherFunction() + + with xla_client.tracebacks(enabled=True): + tb = AFunction() + self.assertIsInstance(tb, xla_client.Traceback) + frames = tb.frames + i = next( + i for (i, f) in enumerate(frames) if f.function_name == "AFunction") + self.assertEqual(frames[i - 1].function_name, "AnotherFunction") + self.assertEqual(frames[i + 1].function_name, "testNestedFunction") + + def testPythonTracebackHasCorrectLineNumbers(self): + def B(): + return xla_client.Traceback.get_traceback() + + def A(): + return B() + + tb = A().as_python_traceback() + for frame, lineno in traceback.walk_tb(tb): + if frame.f_code.co_name == "A": + line = A.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + elif frame.f_code.co_name == "B": + line = B.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + + def testAccessingLocalsDoesNotCrash(self): + # https://github.com/google/jax/issues/16027 + tb = xla_client.Traceback.get_traceback() + python_tb = tb.as_python_traceback() + for frame, _ in traceback.walk_tb(python_tb): + _ = frame.f_locals # should not crash + + def testTracebackFromFrames(self): + def FooFn(x): + return x + 1 + + def BarFn(y): + y = y + 1 + y = y + 2 + return y * 2 + + frame_foo = xla_client.Frame( + __file__, + FooFn.__code__.co_name, + FooFn.__code__.co_firstlineno, + FooFn.__code__.co_firstlineno + 1, + ) + frame_bar = xla_client.Frame( + __file__, + BarFn.__code__.co_name, + BarFn.__code__.co_firstlineno, + BarFn.__code__.co_firstlineno + 2, + ) + frames = [frame_foo, frame_bar] + tb = xla_client.Traceback.traceback_from_frames(frames) + + with self.subTest("WalkDoesNotError"): + for frame, _ in traceback.walk_tb(tb): + _ = frame.f_locals # should 
not crash + + with self.subTest("TracebackCorrectness"): + tb_string = traceback.format_tb(tb) + # The traceback should have the format: + # File , line N in BarFn + # y = y + 2 + # File , line N in FooFn + # return x + 1 + self.assertLen(tb_string, len(frames)) + bar_frame = tb_string[0].split("\n") + self.assertEndsWith(bar_frame[0], "BarFn") + self.assertEqual(bar_frame[1].strip(), "y = y + 2") + foo_frame = tb_string[1].split("\n") + self.assertEndsWith(foo_frame[0], "FooFn") + self.assertEqual(foo_frame[1].strip(), "return x + 1") + + tests.append(TracebackTest) + + class ClientTest(ComputationTest): + + def setUp(self): + super(ClientTest, self).setUp() + self.backend = xla_backend() + + def testPlatformVersion(self): + version = self.backend.platform_version + logging.info("platform_version:\n%s", version) + if self.backend.platform == "cpu": + self.assertEqual(version, "cpu") + elif self.backend.platform in ("gpu", "cuda", "rocm"): + # Following is false if not built with --config=cuda + if version != "": + self.assertTrue( + re.match(r"^cuda \d{4,}$", version), + msg=f"Expected CUDA version string; got {repr(version)}") + elif self.backend.platform == "tpu" and not (pathways or pathways_ifrt): + self.assertIn("tpu", version.lower()) + self.assertIn("cl/", version) + self.assertIn("Built on ", version) + + @unittest.skipIf( + not cloud_tpu and not pjrt_c_api, "PJRT version only exist for plugins" + ) + def testPjRtCApiVersion(self): + self.assertGreaterEqual(self.backend.pjrt_c_api_major_version, 0) + self.assertGreaterEqual(self.backend.pjrt_c_api_minor_version, 0) + + @unittest.skipUnless( + not pjrt_c_api and tfrt_tpu, + "Test that attributes are zero for non-plugin tfrt_tpu", + ) + def testStaticTfrtTpuAttributes(self): + self.assertEqual(self.backend.pjrt_c_api_major_version, 0) + self.assertEqual(self.backend.pjrt_c_api_minor_version, 0) + # CL number is defined as -1 when running as test. 
+ self.assertEqual(self.backend.__getattr__("cl_number"), -1) + + @unittest.skipIf( + cloud_tpu or pjrt_c_api or (not pjrt_c_api and tfrt_tpu), + "PJRT version only exist for plugins", + ) + def testNotExistPjRtCApiVersion(self): + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_major_version # pylint: disable=pointless-statement + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_minor_version # pylint: disable=pointless-statement + + @unittest.skipIf(pathways or pathways_ifrt, "has different behavior") + def testPluginProgramDoesNotCompile(self): + program = xla_client.ifrt_programs.make_plugin_program("foobar") + options = xla_client.ifrt_programs.make_plugin_compile_options() + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, "PjRtCompiler requires an HloProgram" + ): + self.backend.compile_ifrt_program(program, options) + + @unittest.skipIf(pathways, "does not work with non-ifrt legacy pathways") + def testHloProgramViaIfrtProgram(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + program = xla_client.ifrt_programs.make_hlo_program( + xla_computation_to_mlir_module(c.build()) + ) + options = xla_client.ifrt_programs.make_xla_compile_options( + xla_client.CompileOptions(), [] + ) + + compiled_c = self.backend.compile_ifrt_program(program, options) + results = execute_with_python_values( + compiled_c, arguments=(), backend=self.backend + ) + + self.assertLen(results, 1) + np.testing.assert_equal(results[0], np.arange(10, dtype=np.float32)) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or tfrt_tpu, + "not implemented") + def testExecutableSerialization(self): + if self.backend.platform != "tpu": + self.skipTest("Test requires tpu platform") + + c = self._NewComputation() + ops.Add( + ops.Constant(c, NumpyArrayS32([1, 2])), + ops.Constant(c, NumpyArrayS32([3, 4]))) + + options = xla_client.CompileOptions() + executable = self.backend.compile( + xla_computation_to_mlir_module(c.build()), options) + self.assertLen(executable.hlo_modules(), 1) + + serialized = self.backend.serialize_executable(executable) + deserialized = self.backend.deserialize_executable(serialized, options) + + expected, = execute_with_python_values(executable, (), self.backend) + actual, = execute_with_python_values(deserialized, (), self.backend) + self.assertTrue(np.all(actual == expected)) + + def testCompileOptionsSerialization(self): + options = xla_client.CompileOptions() + executable_build_options = options.executable_build_options + options.num_replicas = 3 + options.num_partitions = 2 + options.profile_version = 1337 + options.compile_portable_executable = True + executable_build_options.num_replicas = 3 + executable_build_options.num_partitions = 2 + deb_opt = executable_build_options.debug_options + deb_opt.xla_cpu_enable_fast_math = True + deb_opt.xla_test_all_input_layouts = True + deb_opt.xla_gpu_kernel_cache_file = "/foo/bar" + deb_opt.xla_gpu_enable_llvm_module_compilation_parallelism = True + deb_opt.xla_gpu_per_fusion_autotune_cache_dir = "/bar/foo/" + deb_opt.xla_gpu_experimental_autotune_cache_mode = ( + xla_client.AutotuneCacheMode.READ + ) + + b = options.SerializeAsString() + restored = xla_client.CompileOptions.ParseFromString(b) + + for name in ("num_replicas", "num_partitions", "profile_version", + "compile_portable_executable"): + self.assertEqual(getattr(options, name), getattr(restored, name), + msg=name) + + for name in ("num_replicas", "num_partitions"): + 
self.assertEqual(getattr(options.executable_build_options, name),
+                       getattr(restored.executable_build_options, name),
+                       msg=name)
+
+      for name in (
+          "xla_cpu_enable_fast_math",
+          "xla_test_all_input_layouts",
+          "xla_gpu_kernel_cache_file",
+          "xla_gpu_enable_llvm_module_compilation_parallelism",
+          "xla_gpu_per_fusion_autotune_cache_dir",
+          "xla_gpu_experimental_autotune_cache_mode",
+      ):
+        self.assertEqual(
+            getattr(options.executable_build_options.debug_options, name),
+            getattr(restored.executable_build_options.debug_options, name),
+            msg=name)
+
+  tests.append(ClientTest)
+
+  # TODO(b/182461453): Add TFRT and cloud TPU implementation of
+  # ReadDynamicShapes
+  @unittest.skip("Test fails HLO -> MHLO conversion")
+  class DynamicReshapeTest(ComputationTest):
+    """Tests related to DynamicReshape."""
+
+    def _CompareToPyAndBufferProtocol(self, builder, args, expected_results,
+                                      test_fn):
+      compiled = self.backend.compile(
+          xla_computation_to_mlir_module(builder.build()))
+      output_buffers = compiled.execute([
+          self.backend.buffer_from_pyval(
+              arg, device=compiled.local_devices()[0]) for arg in args
+      ])
+      self.assertLen(output_buffers, len(expected_results))
+      for buf, expected in zip(output_buffers, expected_results):
+        to_py_result = np.asarray(buf)
+        self.assertEqual(expected.shape, to_py_result.shape)
+        test_fn(expected, to_py_result)
+        if self.backend.platform == "cpu" and buf.dtype != bfloat16:
+          mview = memoryview(buf)
+          self.assertEqual(expected.shape, mview.shape)
+          test_fn(expected, np.asarray(mview))
+        else:
+          # The buffer protocol is expected to fail on non-CPU platforms and
+          # with bfloat16. Note that np.asarray(buf) doesn't throw an
+          # exception, so to check that the error is raised properly we must
+          # use memoryview(buf).
+          with self.assertRaises(BufferError):
+            memoryview(buf)
+
+    # 1D reshape of full size, half size, and size of 0.
+    @unittest.skip("not implemented")
+    @parameterized.parameters((5), (3), (0))
+    def testReshape1D(self, reshape_size):
+      full_size = 5
+      c = self._NewComputation()
+      arg = np.array(reshape_size, dtype=np.int32)
+      expected = np.array(range(reshape_size), dtype=np.int32)
+      p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg))
+      ops.DynamicReshape(
+          ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size],
+          [True])
+      self._CompareToPyAndBufferProtocol(c, [arg], [expected],
+                                         np.testing.assert_equal)
+
+    # 2D reshape with a slice on the minor dimension. We test different types
+    # where the strides may differ between the host and devices. The reshaped
+    # physical memory layout is not consecutive, and we test whether the
+    # program can return the correct logical view of the data.
+ @unittest.skipIf( + cloud_tpu or pathways or tfrt_tpu or pjrt_c_api, + "not implemented") + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testReshape2D(self, dtype): + arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + arg1 = np.array(2, dtype=np.int32) + expected = np.array([[1, 2], [4, 5]], dtype=np.int32) + c = self._NewComputation() + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True]) + self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected], + np.testing.assert_equal) + + @unittest.skipIf(cloud_tpu or pathways or tfrt_tpu, "not implemented") + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testDynamicShapeArgs(self, dtype): + full_size = 10 + dynamic_shape_size = 4 + # subcomputation 1 + binary_add_builder = self._NewComputation() + scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype)) + ops.Add( + ops.Parameter(binary_add_builder, 0, scalar_shape), + ops.Parameter(binary_add_builder, 1, scalar_shape)) + # subcomputation 2 + reshape_reduce_builder = self._NewComputation() + dshape = xla_client.Shape.array_shape( + np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True]) + reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape) + ops.Reduce( + reshape_reduce_builder, + operands=[reshape_reduce_p], + init_values=[ops.Constant(reshape_reduce_builder, dtype(0))], + computation=binary_add_builder.build(), + dimensions_to_reduce=[0]) + # main computation: sum(range(full_size)[:dynamic_shape_size]) + c = self._NewComputation() + arg = np.array(dynamic_shape_size, dtype=np.int32) + p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + reshaped = ops.DynamicReshape( + ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p], + [full_size], [True]) + ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,)) + self._ExecuteAndCompareClose(c, [arg], [dtype(6)]) + + tests.append(DynamicReshapeTest) + + class DeviceAssignmentTest(ComputationTest): + + def testSerialize(self): + shape = (3, 4) + device_assignment = xla_client.DeviceAssignment.create( + np.arange(np.prod(shape)).reshape(*shape)) + self.assertEqual(device_assignment.replica_count(), shape[0]) + self.assertEqual(device_assignment.computation_count(), shape[1]) + serialized = device_assignment.serialize() + self.assertIsInstance(serialized, bytes) + self.assertNotEmpty(serialized) + + tests.append(DeviceAssignmentTest) + + class TokenTest(ComputationTest): + """Tests related to PyToken.""" + + def testExecuteWithToken(self): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + results, token = compiled_c.execute_with_token([]) + token.block_until_ready() + self.assertLen(results, 1) + np.testing.assert_allclose( + np.asarray(results[0]), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3) + + def testExecuteShardedOnLocalDevicesWithTokens(self): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) + num_replicas = 1 + options = xla_client.CompileOptions() 
+ options.num_replicas = num_replicas + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + results, sharded_token = ( + compiled_c.execute_sharded_on_local_devices_with_tokens([]) + ) + sharded_token.block_until_ready() + self.assertLen(results, 1) + self.assertLen(results[0], 1) + np.testing.assert_allclose( + np.asarray(results[0][0]), + np.float32([-3, 6.6, 2.4, -2.1]), + rtol=3e-3) + + tests.append(TokenTest) + + class ExecutePortableTest(ComputationTest): + + @unittest.skip("Test does not work under IFRT") + def testExecutePortable(self): + devices_by_kind = collections.defaultdict(list) + for device in self.backend.devices(): + devices_by_kind[device.device_kind].append(device) + multi_devices = [d for d in devices_by_kind.values() if len(d) > 1] + if not multi_devices: + raise unittest.SkipTest("Test needs multiple identical devices") + devices = multi_devices[0] + + c = self._NewComputation() + args = [ + np.array(3, dtype=np.int32), + np.array([10, 15, -2, 7], dtype=np.int32) + ] + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(args[0])) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(args[1])) + ops.Mul(p0, p1) + options = xla_client.CompileOptions() + options.compile_portable_executable = True + compiled_c = self.backend.compile(c.build(), compile_options=options) + for device in devices: + out, = compiled_c.execute( + [self.backend.buffer_from_pyval(a, device=device) for a in args], + device=device) + np.testing.assert_array_equal(np.asarray(out), args[0] * args[1]) + + tests.append(ExecutePortableTest) + + class ExecuteShardedOverloadTest(ComputationTest): + + def testExecuteShardedOverloadEmptyInput(self): + c = self._NewComputation() + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)) + options = xla_client.CompileOptions() + options.num_replicas = 1 + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + results = compiled_c.execute_sharded_on_local_devices([]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + results, _ = compiled_c.execute_sharded_on_local_devices_with_tokens([]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + def testExecuteShardedOverloadBufferInput(self): + arg = np.arange(12, dtype=np.int16).reshape(3, 4) + c = self._NewComputation() + ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + + options = xla_client.CompileOptions() + options.num_replicas = 1 + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + buffer = self.backend.buffer_from_pyval(arg) + + results = compiled_c.execute_sharded_on_local_devices([[buffer]]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + results, _ = compiled_c.execute_sharded_on_local_devices_with_tokens( + [[buffer]]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + tests.append(ExecuteShardedOverloadTest) + + return 
tests + + +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): + test = type(test_prefix + klass.__name__, (klass,), {}) + # Clean up the qualified names of the tests to not include the test factory. + test.__qualname__ = test.__name__ + globals_dict[test.__name__] = test + + +backends = { + "cpu": functools.partial(xla_client.make_cpu_client, num_devices=2), +} + +if __name__ == "__main__": + flags.DEFINE_string("backend", "cpu", "Target platform.") + jax.config.parse_flags_with_absl() + # pylint: disable=unnecessary-lambda + InstantiateTests(globals(), lambda: backends[FLAGS.backend]()) + # pylint: enable=unnecessary-lambda + absltest.main() diff --git a/pyproject.toml b/pyproject.toml index a1b9e7dd446a..be29e16beb9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,11 @@ module = [ "jax.experimental.jax2tf.tests.back_compat_testdata", "jax.experimental.jax2tf.tests.flax_models", "jax_cuda12_plugin.*", - "jaxlib.*", + "jaxlib.cpu_feature_guard", + "jaxlib.cuda.*", "jaxlib.mlir.*", + "jaxlib.utils", + "jaxlib.xla_extension.utils", "jraph.*", "libtpu.*", "matplotlib.*", From c2e7c3e72d7ef9f6f30ae155f2142089fe1d6e48 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Mar 2025 07:10:29 -0700 Subject: [PATCH 0076/1769] [Mosaic GPU] Add a transform inference rule for `memref.subview`. This will be used when lowering from Pallas, in order to handle calls to `memref_slice` in the `_handle_indexing` util. Currently, we only allow propagating a restricted set of transforms (tile and swizzle transforms), and only when they can be passed through the op bidirectionally without modification. PiperOrigin-RevId: 739168839 --- .../mosaic/gpu/transform_inference.py | 91 ++++++++++++++-- tests/mosaic/gpu_transform_inference_test.py | 101 ++++++++++++++++++ 2 files changed, 184 insertions(+), 8 deletions(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index d285e5df188f..80ab6077755a 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -30,6 +30,7 @@ from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector +from jax._src.util import safe_zip from . import fragmented_array as fa from . import inference_utils @@ -184,21 +185,20 @@ def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: + if transforms is None: transforms = out_transforms + elif out_transforms is not None and transforms != out_transforms: + raise NotImplementedError( + f"Conflicting transforms for {op_user} in {op}: " + f"{transforms} != {out_transforms}." + ) return None if transforms is None else ([], [transforms]) # TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use # the dialect in all cases. 
-# The rule is necessary in order to handle the lowering of `utils.memref_ptr`
+# The rule is necessary in order to handle the lowering of `utils.memref_ptr`
 # which is used in `_construct_smem_reftree`.
 @partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp)
 def _infer_unrealized_conversion_cast_transforms(
@@ -250,6 +250,81 @@ def _infer_dynamic_smem_transforms(
   return None
 
 
+# This is used by Pallas' "_handle_indexing" memory transform.
+@partial(_add_transform_inference_rule, memref.SubViewOp)
+def _infer_memref_subview_transforms(
+    op: memref.SubViewOp,
+) -> OptionalTransforms:
+  transforms = None
+
+  for result_use in cast(ir.OpResult, op.result).uses:
+    consumer = result_use.owner
+    op_user = consumer.operands[result_use.operand_number]
+    user_transforms = inference_utils.in_transforms_for_operand(
+        consumer, op_user
+    )
+    if transforms is None:
+      transforms = user_transforms
+    elif user_transforms is not None and transforms != user_transforms:
+      raise NotImplementedError(
+          f"Conflicting transforms for {op_user} in {op}: "
+          f"{transforms} != {user_transforms}."
+      )
+
+  in_transforms = inference_utils.value_transforms(op.source)
+  if transforms is None:
+    transforms = in_transforms
+  elif in_transforms is not None and transforms != in_transforms:
+    raise ValueError(
+        f"Conflicting transforms for {op.source} in {op}: "
+        f"{transforms} != {in_transforms}."
+    )
+
+  if transforms is None:
+    return None
+
+  # Here, we have some transforms to propagate one way or the other. For now,
+  # we implement only the following basic propagation rules:
+  #   - A tile transform can be propagated bidirectionally if the axes being
+  #     tiled are not sliced, and are the logical minor axes of the source.
+  #   - A swizzle transform can be propagated towards the input of a subview
+  #     if the physical minormost dimension is unchanged.
+  #   - We only propagate transforms if they consist of a single tile
+  #     transform and a single swizzle transform.
+  # TODO(bchetioui): implement more complex propagation rules.
+  if len(transforms) == 2:
+    tile_transform, swizzle_transform = transforms
+    if not (
+        mgpu.TileTransformAttr.isinstance(tile_transform)
+        and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform)
+    ):
+      raise NotImplementedError(f"Can't propagate transforms {transforms}.")
+  else:
+    raise NotImplementedError(f"Can't propagate transforms {transforms}.")
+
+  # Check swizzle transform propagation.
+  strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type)
+  minor_dim = strides.index(min(strides))
+  if op.source.type.shape[minor_dim] != op.static_sizes[minor_dim]:
+    raise NotImplementedError(
+        "Swizzle transforms can only be propagated if the minor dimension is "
+        "unchanged."
+    )
+
+  # Check tile transform propagation.
+  num_tiled_axes = len(mgpu.TileTransformAttr(tile_transform).tiling)
+  last_n_dims = op.source.type.shape[-num_tiled_axes:]
+  last_n_sizes = list(op.static_sizes)[-num_tiled_axes:]
+  for slice_size, dim_size in safe_zip(last_n_sizes, last_n_dims):
+    if slice_size != dim_size:
+      raise NotImplementedError(
+          "Tile transforms are only propagated if the tiled axes are not "
+          "sliced."
+ ) + + return [transforms], [transforms] + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index b7cd146dfdb6..983efebc4f86 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -24,7 +24,9 @@ from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import fragmented_array as fa @@ -418,6 +420,105 @@ def body(offset): with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): mgpu.infer_transforms(self.module) + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_propagates_undisturbed_tile_and_swizzle_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get(shape[2:], elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets=[1, 0, 0], + static_sizes=[1, 64, 64], + static_strides=[1, 1, 1], + ) + user_op = builtin.UnrealizedConversionCastOp( + [out_ref_ty], [subview_op.result] + ) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + mgpu.infer_transforms(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(subview_op), [transforms] + ) + self.assertSequenceEqual( + inference_utils.out_transforms(subview_op), [transforms] + ) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_raises_on_disturbed_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get((2, 64, 32), elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets = [1, 0, 0], + static_sizes = [2, 64, 32], + static_strides = [1, 1, 1] + ) + user_op = builtin.UnrealizedConversionCastOp( + [out_ref_ty], [subview_op.result] + ) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + if annotate_input: + 
f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + with self.assertRaises(NotImplementedError): + mgpu.infer_transforms(self.module) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From dac5247cca85df9a8bcac65b7a033038739ebb90 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 07:26:10 -0700 Subject: [PATCH 0077/1769] Ensure traceback correctness in error checking PiperOrigin-RevId: 739172653 --- tests/error_check_test.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index ad67cadfb074..5bf71a9eb592 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -13,6 +13,8 @@ # limitations under the License. +import traceback + from absl.testing import absltest from absl.testing import parameterized import jax @@ -108,6 +110,32 @@ def g(x): with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_error_includes_traceback(self, jit): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback. + x <= 0, "x must be greater than 0" + ) + return x + 1 + + if jit: + function_that_triggers_error_for_traceback_test = jax.jit( + function_that_triggers_error_for_traceback_test + ) + + x = jnp.zeros((4,), dtype=jnp.int32) + function_that_triggers_error_for_traceback_test(x) + + tb_string = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + + self.assertIn("function_that_triggers_error_for_traceback_test", tb_string) + self.assertIn("This line must be included in the traceback", tb_string) + @parameterized.product(jit=[True, False]) def test_error_check_works_with_cond(self, jit): def f(x): From be6585d0005340d1f6ef3830bdac64d7e7e52b8c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 08:07:51 -0700 Subject: [PATCH 0078/1769] [pallas] Add support for `DotAlgorithmPreset.BF16_BF16_F32_X3` in Triton lowering. 
PiperOrigin-RevId: 739183661
---
 jax/_src/pallas/triton/lowering.py | 28 ++++++++++++++++++++++++++++
 tests/pallas/pallas_test.py | 14 +++++++++++++-
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index f3a8dd175ec1..64bf635a34ed 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -1260,6 +1260,10 @@ def _cmp(
   )
 
 
+def _is_nan(x: ir.Value) -> ir.Value:
+  return arith_dialect.cmpf(arith_dialect.CmpFPredicate.UNO, x, x)
+
+
 _JAX_TO_TRITON_BINARY = {
     lax.add_p: _add,
     lax.sub_p: _sub,
@@ -2237,6 +2241,7 @@ def _dot_general_lowering(
       | lax.DotAlgorithmPreset.F16_F16_F32
       | lax.DotAlgorithmPreset.BF16_BF16_BF16
       | lax.DotAlgorithmPreset.BF16_BF16_F32
+      | lax.DotAlgorithmPreset.BF16_BF16_F32_X3
     ):
       input_precision = None
     case _:
@@ -2276,6 +2281,29 @@ def _dot_general_lowering(
   m, _ = a_type.shape
   _, n = b_type.shape
   acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0)
+
+  if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X3:
+    bf16 = _dtype_to_ir_type(jnp.bfloat16)
+    f32 = _dtype_to_ir_type(jnp.float32)
+    as_bf16 = lambda x: _ir_cast(x, bf16, signed=False)
+    as_f32 = lambda x: _ir_cast(x, f32, signed=False)
+
+    a_bf16 = as_bf16(a)
+    b_bf16 = as_bf16(b)
+    a_err0 = as_bf16(_sub(a, as_f32(a_bf16)))
+    b_err0 = as_bf16(_sub(b, as_f32(b_bf16)))
+    # Accumulate the smallest values first to reduce the numeric error.
+    acc = tt_dialect.dot(a_err0, b_bf16, acc)
+    acc = tt_dialect.dot(a_bf16, b_err0, acc)
+    # If `a_err0` is zero and `b` is infinite, then `acc` may contain
+    # `NaN`s (as `0 * inf = NaN`), and vice versa.
+    acc = arith_dialect.select(
+        _is_nan(acc),
+        _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0),
+        acc,
+    )
+    a, b = a_bf16, b_bf16
+
   acc = tt_dialect.dot(a, b, acc, input_precision=input_precision)
   return _cast(acc, acc_dtype, out_aval.dtype)
 
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 745c30ba98cb..0ce68a5c023c 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -702,6 +702,7 @@ def f(x):
       ("float32", jax.lax.DotAlgorithmPreset.DEFAULT),
       ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32),
       ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32),
+      ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3),
       ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32),
       ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3),
       ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32),
@@ -731,7 +732,18 @@ def dot_kernel(x_ref, y_ref, o_ref):
         precision=jax.lax.Precision.HIGHEST,
         preferred_element_type=jnp.float32,
     )
-    self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3)
+    if dtype == "bfloat16" or precision in (
+        jax.lax.Precision.HIGHEST, jax.lax.DotAlgorithmPreset.F32_F32_F32
+    ):
+      atol = 0
+    elif precision in (
+        jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
+        jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3,
+    ):
+      atol = 5e-4
+    else:
+      atol = 5e-2
+    self.assertAllClose(dot_kernel(x, y), expected, atol=atol, rtol=atol / 10)
 
   @parameterized.parameters(jnp.int8, jnp.uint8)
   def test_integer_dot(self, dtype):

From f1ff64f404c522210b9edd23a6e5f76cf77ef896 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Fri, 21 Mar 2025 08:08:31 -0700
Subject: [PATCH 0079/1769] [Mosaic GPU][NFC] Factor out transform resolution
 into a `_resolve_transforms` util.
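
Several inference rules (the vector load/store rule, `SliceSMEMOp`,
`memref.view` and `memref.subview`) open-coded the same merge-or-raise
logic for competing transform sets. They now share a single helper; a
minimal usage sketch, mirroring the call sites in the diff below:

    transforms = _resolve_transforms(transforms, out_transforms)
    return None if transforms is None else ([], [transforms])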
PiperOrigin-RevId: 739183876 --- .../mosaic/gpu/transform_inference.py | 81 +++++++++---------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 80ab6077755a..c76af4fb07e2 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -61,6 +61,37 @@ def _set_transform_attributes( op.attributes["out_transforms"] = ir.ArrayAttr.get(out_transforms) +def _resolve_transforms( + transforms: ir.ArrayAttr | None, + other_transforms: ir.ArrayAttr | None, +) -> ir.ArrayAttr | None: + """Resolves two sets of competing transforms to a single compatible set. + + Args: + transforms: one optional set of transforms. + other_transforms: another optional set of transforms. + + Returns: + A single set of transforms that is compatible with both `transforms` and + `other_transforms`, or `None` if both transforms are `None`. + Raises: + NotImplementedError: if the two sets of transforms can't be resolved to a + single set. + """ + if transforms is None: + return other_transforms + + if other_transforms is None: + return transforms + + if transforms != other_transforms: + raise NotImplementedError( + f"Conflicting transforms {transforms} != {other_transforms}." + ) + + return transforms + + def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: if len(ref_ty.shape) != 2: raise ValueError(f"Expected a 2D memref, got {ref_ty}") @@ -157,21 +188,8 @@ def _infer_vector_load_store_transforms( f"Got layout {layout} which is not yet supported" ) - if transforms is not None and layout_transforms is not None: - if transforms != layout_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op.base} in {op}: " - f"{transforms} != {layout_transforms}." - ) - return [transforms], [] - - if transforms is not None: - return [transforms], [] - - if layout_transforms is not None: - return [layout_transforms], [] - - return None + transforms = _resolve_transforms(transforms, layout_transforms) + return None if transforms is None else ([transforms], []) @partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) @@ -185,13 +203,7 @@ def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is None: - transforms = out_transforms - elif out_transforms is not None and transforms != out_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) + transforms = _resolve_transforms(transforms, out_transforms) return None if transforms is None else ([], [transforms]) @@ -227,14 +239,7 @@ def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise ValueError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms + transforms = _resolve_transforms(transforms, out_transforms) # TODO(bchetioui): do we actually need to assign a transform to the input of # the view op? Presumably, it'll only be used to access scratch memory. 
@@ -263,22 +268,10 @@ def _infer_memref_subview_transforms( user_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is None: - transforms = user_transforms - elif user_transforms is not None and transforms != user_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {user_transforms}." - ) + transforms = _resolve_transforms(transforms, user_transforms) in_transforms = inference_utils.value_transforms(op.source) - if transforms is None: - transforms = in_transforms - elif in_transforms is not None and transforms != in_transforms: - raise ValueError( - f"Conflicting transforms for {op.source} in {op}: " - f"{transforms} != {in_transforms}." - ) + transforms = _resolve_transforms(transforms, in_transforms) if transforms is None: return None From 59d25f4642e383f4236fd85dc753c642ef2307aa Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Mar 2025 08:21:42 -0700 Subject: [PATCH 0080/1769] [Mosaic GPU] Add transform inference rule for `memref.load`. PiperOrigin-RevId: 739187660 --- jax/experimental/mosaic/gpu/transform_inference.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index c76af4fb07e2..3438a654f90a 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -318,6 +318,16 @@ def _infer_memref_subview_transforms( return [transforms], [transforms] +# `memref.load` is used to load barrier phases---the rule needn't do anything +# interesting, but we need to have it in order to avoid crashing on it. +@partial(_add_transform_inference_rule, memref.LoadOp) +def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms: + if not ir.MemRefType(op.memref.type).shape: + # memref.load returns a scalar, so there is nothing interesting to do here. + return None + raise NotImplementedError("Non-scalar memref.load transforms") + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( From 27b30190be173d759fcab223242211a33ab6e3f3 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Mar 2025 08:36:52 -0700 Subject: [PATCH 0081/1769] [Pallas/Mosaic GPU] Add lowering for WGMMA using warpgroup semantics. When using warpgroup semantics, the transforms are inferred by the transform inference pass---except for transposition which will still get propagated down from Pallas. Also turn on `transform_inference` in the Pallas->Mosaic GPU lowering pipeline. PiperOrigin-RevId: 739191716 --- jax/_src/pallas/mosaic_gpu/lowering.py | 17 +++--- jax/_src/pallas/mosaic_gpu/primitives.py | 65 ++++++++++++++++++---- tests/pallas/mosaic_gpu_test.py | 68 ++++++++++++++---------- 3 files changed, 106 insertions(+), 44 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ef4c80cb4649..004a6e7f2760 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -766,6 +766,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): # Run Python lowering passes. 
The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc mgpu.infer_layout(module) # pytype: disable=attribute-error + mgpu.infer_transforms(module) # pytype: disable=attribute-error mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error mgpu_core._initialize_scratch(launch_ctx, scratch_arr) @@ -1837,13 +1838,15 @@ def _run_scoped_lowering_rule( for v in jaxpr.invars: aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: - # TODO(bchetioui): Fix this and remove the NotImplementedError. - raise NotImplementedError( - "WGMMA accumulators are not supported with Warpgroup semantics." - ) - mlir_dtype = mlir.dtype_to_ir_type(aval.dtype) - input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype)) + dtype = mlir.dtype_to_ir_type(aval.dtype) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + else: + zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) + acc = mgpu.dialect.optimization_barrier([acc]) + nvvm_dialect.wgmma_fence_aligned() + input_refs.append(acc) should_discharge.append(True) elif isinstance(aval.dtype, gpu_core.BarrierType): input_refs.append( diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 7f26f5d2b6a3..edfae55fb288 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -517,11 +517,7 @@ def commit_smem_to_gmem_group() -> None: wgmma_ref_p.multiple_results = True -def wgmma( - acc: gpu_core.WGMMAAbstractAccumulatorRef, - a, - b: pallas_core.TransformedRef, -) -> None: +def wgmma(acc: gpu_core.WGMMAAbstractAccumulatorRef, a, b) -> None: """Performs an asynchronous warp group matmul-accumulate on the given references. Conceptually, this is equivalent to doing ``acc[...] += a[...] @ b[...]``, @@ -555,12 +551,17 @@ def wgmma( a = a.ref else: a_transforms_leaves, a_transforms_tree = [], None - b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None wgmma_ref_p.bind( acc, a, - b.ref, + b, *a_transforms_leaves, *b_transforms_leaves, a_transforms_tree=a_transforms_tree, @@ -674,6 +675,40 @@ def _wgmma_lowering( return new_acc +@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Warpgroup) +def _wgmma_warpgroup_lowering( + ctx: lowering.LoweringRuleContext, + acc, + a, + b, + *transforms_leaves, + a_transforms_tree, + b_transforms_tree, +): + del ctx, transforms_leaves # Unused. + if a_transforms_tree is not None: + match a_transforms_tree: + case gpu_core.TransposeRef((1, 0)): + raise NotImplementedError("WGMMA lhs transpose not supported.") + case _: + raise ValueError( + f"WGMMA lhs has unsupported transforms: {a_transforms_tree}." + ) + + if b_transforms_tree is not None: + match b_transforms_tree: + case gpu_core.TransposeRef((1, 0)): + raise NotImplementedError("WGMMA rhs transpose not supported.") + case _: + raise ValueError( + f"WGMMA rhs has unsupported transforms: {b_transforms_tree}." 
+ ) + + new_acc = mgpu.dialect.wgmma(acc, a, b) + nvvm_dialect.wgmma_commit_group_sync_aligned() + return new_acc + + @wgmma_p.def_effectful_abstract_eval def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs): del args, kwargs @@ -698,6 +733,7 @@ def wgmma_wait_effectful_abstract_eval(_): @lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Warpgroup) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -728,11 +764,19 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): return (None,), wgmma_accumulator_deref_p.bind(acc) -@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane +) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Warpgroup +) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): - del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(0) - return acc.value + return ( + acc.value + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + else acc + ) class Layout(enum.Enum): @@ -835,6 +879,7 @@ def _commit_smem_abstract_eval(): @lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): + # TODO(bchetioui): add primitive for commit smem to mosaic_gpu dialect. mgpu.commit_shared() return () diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 94c2620f7ae6..408b8bdf5713 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1302,7 +1302,8 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - def test_realistic_matmul(self): + @parameterized.parameters([*plgpu.ThreadSemantics]) + def test_realistic_matmul(self, thread_semantics): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1326,34 +1327,46 @@ def _epilogue(): a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + lhs_spec = pl.BlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + ) + rhs_spec = pl.BlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + ) + out_spec = pl.BlockSpec( + (tile_m, tile_n), + lambda m, n, k: (m, n), + ) + + if thread_semantics == plgpu.ThreadSemantics.Lane: + lhs_spec = plgpu.GPUBlockSpec( + lhs_spec.block_shape, lhs_spec.index_map, + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ) + ) + rhs_spec = plgpu.GPUBlockSpec( + rhs_spec.block_shape, rhs_spec.index_map, + transforms=( + plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.SwizzleTransform(128), + ) + ) + out_spec = plgpu.GPUBlockSpec( + out_spec.block_shape, out_spec.index_map, + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ) + ) + res = pl.pallas_call( kernel, - in_specs=[ - plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda m, n, k: (m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda m, n, k: (k, n), - transforms=( - 
plgpu.TilingTransform((elems_128b, elems_128b)),
-                plgpu.SwizzleTransform(128),
-            ),
-        ),
-    ],
-    out_specs=plgpu.GPUBlockSpec(
-        (tile_m, tile_n),
-        lambda m, n, k: (m, n),
-        transforms=(
-            plgpu.TilingTransform((64, elems_128b)),
-            plgpu.SwizzleTransform(128),
-        ),
-    ),
+        in_specs=[lhs_spec, rhs_spec],
+        out_specs=out_spec,
         out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
         scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
         grid=(grid_m, grid_n, grid_k),
@@ -1361,6 +1374,7 @@ def _epilogue():
             dimension_semantics=["parallel", "parallel", "sequential"],
             max_concurrent_steps=2,
             delay_release=1,
+            thread_semantics=thread_semantics,
         ),
     )(a, b)
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)

From 92f5d9caa33f59b1c8511f4fc0676e1a155ab4a2 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Fri, 21 Mar 2025 08:53:50 -0700
Subject: [PATCH 0082/1769] Deprecated `jax.tree_util.build_tree`

We have no usages of it either in JAX or internally, but we still have
to go through the deprecation cycle, because `jax.tree_util` is public
API.

PiperOrigin-RevId: 739196514
---
 CHANGELOG.md | 5 +++++
 jax/_src/tree_util.py | 9 ++-------
 jax/tree_util.py | 24 ++++++++++++++++++++++--
 3 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a817ce80937..17fb421fcc06 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 
 ## Unreleased
 
+* Deprecations
+
+  * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten`
+    instead.
+
 ## jax 0.5.3 (Mar 19, 2025)
 
 * New Features
diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py
index 6c7e15a042e5..883937fcce6e 100644
--- a/jax/_src/tree_util.py
+++ b/jax/_src/tree_util.py
@@ -362,6 +362,8 @@ def tree_map(f: Callable[..., Any],
 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
   """Build a treedef from a nested iterable structure
 
+  DEPRECATED: Use :func:`jax.tree.unflatten` instead.
+
   Args:
     treedef: the PyTreeDef structure to build.
xs: nested iterables matching the arity as the treedef @@ -376,13 +378,6 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any: >>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree) - - Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct - the tree from new values, but ``build_tree`` takes these values in terms of - a nested rather than flat structure: - - >>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]]) - [(10, 11), {'a': 12, 'b': 13}] >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) [(10, 11), {'a': 12, 'b': 13}] """ diff --git a/jax/tree_util.py b/jax/tree_util.py index 956d79b9b4ef..3d24c457b3f8 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -48,13 +48,13 @@ PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, all_leaves as all_leaves, - build_tree as build_tree, + build_tree as _deprecated_build_tree, default_registry as default_registry, keystr as keystr, + register_dataclass as register_dataclass, register_pytree_node_class as register_pytree_node_class, register_pytree_node as register_pytree_node, register_pytree_with_keys_class as register_pytree_with_keys_class, - register_dataclass as register_dataclass, register_pytree_with_keys as register_pytree_with_keys, register_static as register_static, tree_all as tree_all, @@ -72,3 +72,23 @@ treedef_is_leaf as treedef_is_leaf, treedef_tuple as treedef_tuple, ) + +_deprecations = { + # Added March 21, 2025: + "build_tree": ( + ( + "jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten" + " instead." + ), + _deprecated_build_tree, + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + from jax._src.tree_util import build_tree as build_tree +else: + from jax._src.deprecations import deprecation_getattr + __getattr__ = deprecation_getattr(__name__, _deprecations) + del deprecation_getattr, _deprecations +del _typing From 027195489a5b9a1d253550c2954f8aa11fa03370 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 08:55:59 -0700 Subject: [PATCH 0083/1769] Reorder C++ imports (nanobind). PiperOrigin-RevId: 739197184 --- examples/ffi/src/jax_ffi_example/gpu_examples.cc | 2 +- jaxlib/cpu/lapack.cc | 2 +- jaxlib/cuda/cuda_plugin_extension.cc | 2 +- jaxlib/cuda/versions.cc | 3 +-- jaxlib/gpu/blas.cc | 4 ++-- jaxlib/gpu/gpu_plugin_extension.cc | 6 +++--- jaxlib/gpu/hybrid.cc | 2 +- jaxlib/gpu/py_client_gpu.cc | 2 +- jaxlib/gpu/solver.cc | 4 ++-- jaxlib/gpu/sparse.cc | 4 ++-- jaxlib/gpu/triton.cc | 2 +- jaxlib/kernel_nanobind_helpers.h | 2 +- jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc | 2 +- jaxlib/mlir/_mlir_libs/tpu_ext.cc | 10 +++++----- jaxlib/mlir/_mlir_libs/triton_ext.cc | 2 +- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 4 ++-- jaxlib/rocm/rocm_plugin_extension.cc | 2 +- jaxlib/utils.cc | 2 +- jaxlib/xla/custom_calls_testlib.cc | 2 +- 19 files changed, 29 insertions(+), 30 deletions(-) diff --git a/examples/ffi/src/jax_ffi_example/gpu_examples.cc b/examples/ffi/src/jax_ffi_example/gpu_examples.cc index 921039debe5d..79a4ee91e8c6 100644 --- a/examples/ffi/src/jax_ffi_example/gpu_examples.cc +++ b/examples/ffi/src/jax_ffi_example/gpu_examples.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "cuda_runtime_api.h" +#include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" namespace nb = nanobind; diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index c104019777e5..7cc4fa9e2dbd 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 789227e273b6..6655128b9842 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" #include "jaxlib/gpu/py_client_gpu.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/cuda/versions.cc b/jaxlib/cuda/versions.cc index 8d6577f46709..d9f9f4c86865 100644 --- a/jaxlib/cuda/versions.cc +++ b/jaxlib/cuda/versions.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/versions_helpers.h" - #include "nanobind/nanobind.h" +#include "jaxlib/cuda/versions_helpers.h" #include "jaxlib/gpu/vendor.h" namespace jax::cuda { diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index e8761bd32ac9..4a58859016f1 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index 5726e0929ee5..d026806e9479 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc index 94975a5b969f..71c320a60f02 100644 --- a/jaxlib/gpu/hybrid.cc +++ b/jaxlib/gpu/hybrid.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/gpu/hybrid_kernels.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index cf701574959b..3e140411770d 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/base/casts.h" #include "absl/log/check.h" #include "absl/status/status.h" @@ -29,6 +28,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/vendor.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 357a38eecfd5..1cf799bbb491 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 429c8018dc7a..a7f8dbebc2b3 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -19,13 +19,13 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" #include "absl/base/casts.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 500034af3ebb..135410568f6b 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -5,13 +5,13 @@ #include #include +#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "nanobind/stl/string.h" #include "nanobind/stl/string_view.h" #include "nanobind/stl/tuple.h" #include "nanobind/stl/vector.h" -#include "absl/status/statusor.h" #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" diff --git a/jaxlib/kernel_nanobind_helpers.h b/jaxlib/kernel_nanobind_helpers.h index fde37e695349..127d89f702c8 100644 --- a/jaxlib/kernel_nanobind_helpers.h +++ b/jaxlib/kernel_nanobind_helpers.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/base/casts.h" +#include "nanobind/nanobind.h" #include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index 7483d7ed1eea..c73084abc99d 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -16,9 +16,9 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "nanobind/nanobind.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 2b5ec898ad3e..7d616968b9aa 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -43,13 +43,13 @@ limitations under the License. // clang-format off #include "mlir-c/Bindings/Python/Interop.h" // clang-format on +#include "absl/log/check.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/variant.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "absl/log/check.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/python/lib/core/numpy.h" diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index e824d4058d7e..2a13c40d963f 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -15,9 +15,9 @@ limitations under the License. #include -#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" #include "jaxlib/triton/triton_dialect_capi.h" namespace nb = nanobind; diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 2c7242b6e6c0..ee11b22020dc 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -22,11 +22,11 @@ limitations under the License. #include #include +#include "absl/cleanup/cleanup.h" +#include "absl/strings/str_cat.h" #include "nanobind/nanobind.h" #include "nanobind/stl/tuple.h" #include "nanobind/stl/vector.h" -#include "absl/cleanup/cleanup.h" -#include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 1e8013f2bc1b..454f4741d667 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" #include "jaxlib/gpu/py_client_gpu.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index bf50b3a5254d..e5bb45e999da 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" namespace nb = nanobind; diff --git a/jaxlib/xla/custom_calls_testlib.cc b/jaxlib/xla/custom_calls_testlib.cc index d06105fce76f..58f4818a431e 100644 --- a/jaxlib/xla/custom_calls_testlib.cc +++ b/jaxlib/xla/custom_calls_testlib.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "third_party/nanobind/include/nanobind/nanobind.h" +#include "nanobind/nanobind.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" From 40ce44d143e160d7c44f5453fe3f49d413598301 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 21 Mar 2025 09:25:11 -0700 Subject: [PATCH 0084/1769] Add `ShardingTypeError` to all sharding rules in JAX PiperOrigin-RevId: 739205830 --- jax/_src/lax/lax.py | 25 ++++++++++--------- jax/_src/lax/linalg.py | 4 +-- jax/_src/lax/slicing.py | 7 ++---- jax/_src/lax/utils.py | 2 +- jax/_src/lax/windowed_reductions.py | 4 +-- tests/pjit_test.py | 38 +++++++++++++++-------------- 6 files changed, 40 insertions(+), 40 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 388ad49ec83d..f6ab848ccd5b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3805,7 +3805,7 @@ def broadcasting_sharding_rule(name, *avals): for a in avals: if a.sharding is not None and not a.sharding.mesh.empty: if mesh is not None and mesh != a.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') mesh = a.sharding.mesh @@ -3839,7 +3839,7 @@ def broadcasting_sharding_rule(name, *avals): result_specs[i] = s elif (result_specs[i] is not None and s is not None and result_specs[i] != s): - raise TypeError( + raise core.ShardingTypeError( f'{name} got incompatible shardings for broadcasting: ' f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) @@ -4990,13 +4990,13 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): def _check_specs_match(lhs_spec, rhs_spec, msg): for l, r in zip(lhs_spec, rhs_spec): if l is not None and r is not None and l != r: - raise TypeError(msg) + raise core.ShardingTypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, out_sharding): if lhs.sharding.mesh != rhs.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') @@ -5020,7 +5020,7 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, for l, r in zip(lhs_contracting_spec, rhs_contracting_spec): if l is not None and r is not None: - raise ValueError( + raise core.ShardingTypeError( 'Contracting dimensions are sharded and it is ambiguous how the' ' output should be sharded. Please specify the output sharding via' ' the `out_sharding` parameter of einsum. Or reshard your input via' @@ -6378,7 +6378,7 @@ def _concatenate_sharding_rule(*operands, **kwargs): return core.get_cur_mesh_sharding() if not all(s == non_empty_s[0] for s in non_empty_s): ss = ", ".join(str(o.sharding) for o in operands) - raise TypeError( + raise core.ShardingTypeError( f"All operands should have the same sharding. 
Got shardings {ss}") return non_empty_s[0] @@ -6697,7 +6697,7 @@ def _split_on_one_axis(op_shape, new_sizes, name): else: count += 1 if count > 1: - raise ValueError( + raise core.ShardingTypeError( f'{name} on more than 1 axis is not supported. Please specify' ' the sharding of the output via the `sharding` argument of' f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}') @@ -6738,7 +6738,7 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions) - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of' ' the output via the `out_sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6771,7 +6771,7 @@ def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions): elif dimensions is None and out[0] % _get_spec_size(sp, mesh) == 0: new_spec.extend([sp] + [None] * (len(out) - 1)) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6796,7 +6796,7 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): assert new_size % _get_spec_size(sp[0], mesh) == 0 new_spec.append(sp[0]) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6979,10 +6979,11 @@ def _select_sharding_rule(which, *cases): return core.get_cur_mesh_sharding() if any(s != non_empty_s[0] for s in non_empty_s[1:]): msg = "select cases must have the same shardings, got [{}]." - raise TypeError(msg.format(", ".join([str(c.sharding) for c in cases]))) + raise core.ShardingTypeError( + msg.format(", ".join([str(c.sharding) for c in cases]))) if (which.shape and not which.sharding.mesh.empty and which.sharding != non_empty_s[0]): - raise TypeError( + raise core.ShardingTypeError( 'select `which` must be scalar or have the same sharding as cases, got' f' `which` sharding {which.sharding} but case sharding' f' {cases[0].sharding}.') diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3e9077d0a51c..027ec8b801b9 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -717,14 +717,14 @@ def linalg_sharding_rule( spec = aval.sharding.spec batch_spec, rest_spec = spec[:len(spec) - rank], spec[len(spec) - rank:] if not all(s is None for s in rest_spec): - raise ValueError( + raise core.ShardingTypeError( f"Input {i} to {name} must be unsharded on non-batch dimensions, " f"but got {spec}." ) batch_specs.append(batch_spec) batch_spec = batch_specs[0] if any(b != batch_spec for b in batch_specs[1:]): - raise ValueError( + raise core.ShardingTypeError( f"All inputs to {name} must have the same batch sharding, but got " f"{batch_specs}." 
) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index b3a0a8e2d0c1..d3bcb6da2807 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1333,7 +1333,7 @@ def _get_sharding_for_varying_out_shape(out_shape, operand, name): operand.shape, out_shape, operand.sharding.spec): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): - raise NotImplementedError( + raise core.ShardingTypeError( f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" f" ({op_spec}) is not implemented.") @@ -1922,9 +1922,6 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): else next(indices_shape_gen) for i in range(output_shape_rank)) return ans -class GatherShardingError(Exception): - pass - def _gather_sharding_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1936,7 +1933,7 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers, all(s is None for s in operand.sharding.spec) and all(s is None for s in indices.sharding.spec)): return core.get_cur_mesh_sharding() - raise GatherShardingError( + raise core.ShardingTypeError( "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for" " the gather indexing.") diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 9fc9ba16a604..206a8312ba8c 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -96,7 +96,7 @@ def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec, short_dtypes=True) - raise TypeError( + raise core.ShardingTypeError( f'{prim} operation with inputs: {avals_str} produces an illegally' f' sharded result: {out_aval_str}') from e return out_shapes, out_dtypes, out_shardings diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 400646f6238f..42b2e9278889 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -525,7 +525,7 @@ def reduce_window_sharding_rule(operand, window_dimensions, window_strides, if spec is None: continue if not (wdim == 1 and ws == 1 and pd == 1 and bd == 1 and wdil == 1): - raise NotImplementedError( + raise core.ShardingTypeError( "Only trivial windowing is supported along non-replicated" f" dimensions. 
Got {operand.sharding.spec=}") return operand.sharding @@ -826,7 +826,7 @@ def _select_and_gather_add_sharding_rule( tangents, operand, *, select_prim, window_dimensions, window_strides, padding, base_dilation, window_dilation): if tangents.sharding != operand.sharding: - raise TypeError( + raise core.ShardingTypeError( "select_and_gather_add tangents and operand shardings must match, " f"got {tangents.sharding} and {operand.sharding}.") return reduce_window_sharding_rule( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6a1a73fe4301..6fdfa62887b9 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4995,11 +4995,13 @@ def g(x, y): return x * y with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x')))) with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y'))))) @parameterized.named_parameters( @@ -5098,14 +5100,14 @@ def f(x, y): @parameterized.named_parameters( ('fail1', P('x', None), P(None, 'x'), "dot_general operation.*produces an illegally sharded result", - TypeError), + core.ShardingTypeError), ('fail2', P('x', 'y'), P('x', 'y'), "dot_general requires contracting dimensions to have consistent sharding", - TypeError), + core.ShardingTypeError), ('contracting1', P('x', 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ('other_half_tp', P(None, 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ) @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh): @@ -5127,14 +5129,14 @@ def test_dot_general_batch_error(self, mesh): arr2 = jax.device_put(np.ones((8, 2, 4)), NamedSharding(mesh, P('y', 'z', 'x'))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jax.lax.dot_general( arr1, arr2, dimension_numbers=(([2], [1]), ([0], [0]))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) @@ -5569,7 +5571,7 @@ def f(x): return y if error_msg: - with self.assertRaisesRegex(ValueError, error_msg): + with self.assertRaisesRegex(core.ShardingTypeError, error_msg): f(arr) else: out = f(arr) @@ -5608,7 +5610,7 @@ def f(pred, on_true, on_false): arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x'))) with self.assertRaisesRegex( - TypeError, "select cases must have the same shardings"): + core.ShardingTypeError, "select cases must have the same shardings"): f(arr1 == arr2, arr1, arr3) def test_explicit_mode_no_context_mesh(self): @@ -5778,10 +5780,10 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) - with self.assertRaisesRegex(NotImplementedError, 
"slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) @jtu.with_user_mesh((2, 2), ('x', 'y')) @@ -5842,13 +5844,13 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((2, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((0, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) @@ -5879,7 +5881,7 @@ def f(x, y, method='jnp'): self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) with self.assertRaisesRegex( - TypeError, "All operands should have the same sharding"): + core.ShardingTypeError, "All operands should have the same sharding"): arr3 = jax.device_put(np.arange(4.).reshape(4, 1), NamedSharding(mesh, P('x'))) f(arr1, arr3) @@ -6147,7 +6149,7 @@ def f(x, sizes=(4, 4), axis=0): f(arr) self.check_wsc_in_lowered(f.lower(arr).as_text()) - with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "split on sharded dims"): f(arr, sizes=(1, 1), axis=1) def g(x): @@ -6452,7 +6454,7 @@ def f(x, y, z): # Errors out on the intermediate einsum: `bthj,bthD->bthjD` # because of a conflict with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general operation.*produces an illegally sharded result'): f(arr1, arr2, arr3) From 3bf2eea259107cfaadedd8ee59d0b586401eeb08 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 10:16:42 -0700 Subject: [PATCH 0085/1769] Add AOT support for error checking PiperOrigin-RevId: 739222389 --- jax/_src/error_check.py | 198 ++++++++++++++++++++++++++++++++------ tests/error_check_test.py | 59 +++++++++++- 2 files changed, 224 insertions(+), 33 deletions(-) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 88dcec7063d9..e78b9bc82115 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,8 +14,12 @@ from __future__ import annotations +import dataclasses from functools import partial +import json import threading +import traceback as tb_lib +from types import TracebackType import warnings import jax @@ -23,19 +27,17 @@ from jax._src import source_info_util from jax._src import traceback_util import jax._src.mesh as mesh_lib -from jax.experimental.shard_map import shard_map +from jax.experimental import shard_map +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P -Traceback = source_info_util.Traceback - - traceback_util.register_exclusion(__file__) class JaxValueError(ValueError): - """Exception raised for failed runtime error checks in JAX.""" + """Exception raised for runtime errors detected within JAX computations.""" #: The default error code for no error. 
@@ -45,8 +47,9 @@ class JaxValueError(ValueError): _NO_ERROR = jnp.iinfo(jnp.uint32).max -_error_list_lock = threading.Lock() -_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair +_error_list_lock = threading.RLock() +# (error_message, traceback) pairs. Traceback is `str` when imported from AOT. +_error_list: list[tuple[str, TracebackType | str]] = [] class _ErrorStorage(threading.local): @@ -65,22 +68,21 @@ def _initialize_error_code_ref() -> None: In single-device environments, the array is a scalar. In multi-device environments, its shape and size match those of the mesh. """ - with core.eval_context(): - # Get mesh from the context. - mesh = mesh_lib.get_concrete_mesh() - - if mesh is None: # single-device case. - error_code = jnp.uint32(_NO_ERROR) - - else: # multi-device case. - sharding = NamedSharding(mesh, P(*mesh.axis_names)) - error_code = jnp.full( - mesh.axis_sizes, - jnp.uint32(_NO_ERROR), - device=sharding, - ) + # Get mesh from the context. + mesh = mesh_lib.get_concrete_mesh() + + if mesh is None: # single-device case. + error_code = jnp.uint32(_NO_ERROR) + + else: # multi-device case. + sharding = NamedSharding(mesh, P(*mesh.axis_names)) + error_code = jnp.full( + mesh.axis_sizes, + jnp.uint32(_NO_ERROR), + device=sharding, + ) - _error_storage.ref = core.mutable_array(error_code) + _error_storage.ref = core.mutable_array(error_code) class error_checking_context: @@ -105,7 +107,8 @@ def __init__(self): def __enter__(self): self.old_ref = _error_storage.ref - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() return self def __exit__(self, exc_type, exc_value, traceback): @@ -126,22 +129,33 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: will not override the existing error. For multi-device environments, in explicit mode, users must call - :func:`error_checking_context()` to initialize a new error tracking state that + :func:`error_checking_context` to initialize a new error tracking state that matches the device mesh. In auto mode, implicit cross-device communication may occur inside this function, which could impact performance. A warning is issued in such cases. + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + Args: pred: A JAX boolean array. If any element of `pred` is `True`, the internal error state will be set. msg: The corresponding error message to be raised later. """ if _error_storage.ref is None: - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() assert _error_storage.ref is not None + # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None + traceback = traceback.as_python_traceback() + assert isinstance(traceback, TracebackType) + traceback = traceback_util.filter_traceback(traceback) + assert isinstance(traceback, TracebackType) + with _error_list_lock: new_error_code = jnp.uint32(len(_error_list)) _error_list.append((msg, traceback)) @@ -171,7 +185,7 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: "Please use `with error_checking_context()` to redefine the error " "code state based on the mesh." 
) - pred = shard_map( + pred = shard_map.shard_map( partial(jnp.any, keepdims=True), mesh=out_sharding.mesh, in_specs=in_sharding.spec, @@ -179,7 +193,7 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: )(pred) # perform per-device reduction error_code = _error_storage.ref[...] - should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) + should_update = jnp.logical_and(error_code == jnp.uint32(_NO_ERROR), pred) error_code = jnp.where(should_update, new_error_code, error_code) # TODO(ayx): support vmap and shard_map. _error_storage.ref[...] = error_code @@ -216,8 +230,128 @@ def raise_if_error() -> None: device=_error_storage.ref.sharding, ) # clear the error code - msg, traceback = _error_list[error_code] - exc = JaxValueError(msg) - traceback = traceback.as_python_traceback() - filtered_traceback = traceback_util.filter_traceback(traceback) - raise exc.with_traceback(filtered_traceback) + with _error_list_lock: + msg, traceback = _error_list[error_code] + if isinstance(traceback, str): # from imported AOT functions + exc = JaxValueError( + f"{msg}\nThe original traceback is shown below:\n{traceback}" + ) + raise exc + else: + exc = JaxValueError(msg) + raise exc.with_traceback(traceback) + + +@dataclasses.dataclass(frozen=True) +class _ErrorClass: + """A class to store error information for AOT compilation. + + This class is used internally by the wrapper functions `wrap_for_export` and + `unwrap_from_import` to encapsulate error-related data within an exported + function. + + Attributes: + error_code (jax.Array): A JAX array representing the final error state of + the function to be exported. This value is local to the wrapper function. + error_list (list[tuple[str, str]]): A list of `(error_message, traceback)` + pairs containing error messages and corresponding stack traces. This error + list is local to the wrapper function, and does not contain pairs of error + information from other functions. + """ + + error_code: jax.Array + error_list: list[tuple[str, str]] + + +jax.tree_util.register_dataclass( + _ErrorClass, data_fields=("error_code",), meta_fields=("error_list",) +) +jax.export.register_pytree_node_serialization( + _ErrorClass, + serialized_name=f"{_ErrorClass.__module__}.{_ErrorClass.__name__}", + serialize_auxdata=lambda x: json.dumps(x, ensure_ascii=False).encode( + "utf-8" + ), + deserialize_auxdata=lambda x: json.loads(x.decode("utf-8")), +) + + +def _traceback_to_str(traceback: TracebackType) -> str: + """Convert a traceback to a string for export.""" + return "".join(tb_lib.format_list(tb_lib.extract_tb(traceback))).rstrip("\n") + + +def wrap_for_export(f): + """Wrap a function with error checking to make it compatible with AOT mode. + + Error checking relies on global state, which cannot be serialized across + processes. This wrapper ensures that the error state remains within the + function scope, making it possible to export the function and later import it + in other processes. + + This function should only be applied once to a function; wrapping the same + function multiple times is unnecessary. + """ + + def inner(*args, **kwargs): + global _error_list + + # 1. Save the old state and initialize a new state. + with core.eval_context(): + old_ref = _error_storage.ref + _initialize_error_code_ref() + with _error_list_lock: + old_error_list, _error_list = _error_list, [] + + # 2. Trace the function. + out = f(*args, **kwargs) + error_code = _error_storage.ref[...].min() + + # 3. Restore the old state.
+ _error_list, new_error_list = old_error_list, _error_list + with core.eval_context(): + _error_storage.ref = old_ref + + new_error_list = [ + (msg, _traceback_to_str(traceback)) for msg, traceback in new_error_list + ] + return out, _ErrorClass(error_code, new_error_list) + + return inner + + +def unwrap_from_import(f): + """Unwrap a function after AOT import to restore error checking. + + When an AOT-exported function is imported in a new process, its error state is + separate from the global error state of the current process. This wrapper + ensures that errors detected during execution are correctly integrated into + the global error checking mechanism of the current process. + """ + if _error_storage.ref is None: + with core.eval_context(): + _initialize_error_code_ref() + assert _error_storage.ref is not None + + def inner(*args, **kwargs): + out, error_class = f(*args, **kwargs) + new_error_code, error_list = error_class.error_code, error_class.error_list + + # Update the global error list. + with _error_list_lock: + offset = len(_error_list) + _error_list.extend(error_list) + + # Update the global error code array. + error_code = _error_storage.ref[...] + should_update = jnp.logical_and( + error_code == jnp.uint32(_NO_ERROR), + new_error_code != jnp.uint32(_NO_ERROR), + ) + error_code = jnp.where(should_update, new_error_code + offset, error_code) + # TODO(ayx): support vmap and shard_map. + _error_storage.ref[...] = error_code + + return out + + return inner diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 5bf71a9eb592..69e292fb6704 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -22,6 +22,7 @@ from jax._src import error_check from jax._src import mesh as mesh_lib from jax._src import test_util as jtu +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P @@ -33,7 +34,9 @@ jtu.request_cpu_devices(4) -@jtu.with_config(jax_check_tracer_leaks=True) +# TODO: AOT tests fail with the tracer leak checker. +# Reenable once https://github.com/jax-ml/jax/issues/27315 is fixed. +# @jtu.with_config(jax_check_tracer_leaks=True) class ErrorCheckTests(jtu.JaxTestCase): @parameterized.product(jit=[True, False]) @@ -280,6 +283,60 @@ def f(x): with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): error_check.raise_if_error() + def test_error_check_aot(self): + def run_export(): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + f = jax.jit(error_check.wrap_for_export(jax.jit(f))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.)
+ _ = f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_should_not_override_existing_error(self): + def f1(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f1") + return x + 1 + + def run_export(): + def f2(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f2") + return x + 1 + + f2 = jax.jit(error_check.wrap_for_export(jax.jit(f2))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f2)(x).serialize() + return serialized + + def run_import(serialized): + f2 = jax.export.deserialize(serialized).call + f2 = jax.jit(error_check.unwrap_from_import(jax.jit(f2))) + return f2 + + x = jnp.float32(-3.) + _ = f1(x) # check fails. so it should set error + + serialized = run_export() + f2 = run_import(serialized) + _ = f2(x) # check fails, but should not override the error + + with self.assertRaisesRegex( + JaxValueError, "x must be greater than 0 in f1" + ): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 3163fbaac43c8d8187efbd58ee42b333560cf42f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 21 Mar 2025 10:25:38 -0700 Subject: [PATCH 0086/1769] Add varying manual axes rules to `mul_p` and `convert_element_type_p`. There are 2 things that need to be added: 1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has. 2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray. This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules. * pbroadcast_p: Union the existing aval.varying_manual_axes + axes (passed to pbroadcast) to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is empty. * psum2_p: Remove the named axes from aval.varying_manual_axes to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is NOT empty. 
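For illustration, the vma bookkeeping in the two rules above comes down to set algebra along the following lines (a minimal sketch for this message only: the helper names `pbroadcast_out_vma` and `psum2_out_vma` are hypothetical, and the real abstract_eval rules also thread avals, effects, and config checks):

    # Hypothetical sketch of the vma rules described above, not the code in this change.
    def pbroadcast_out_vma(vma: frozenset, axes: tuple) -> frozenset:
      # pbroadcast: the input must not already vary over `axes` (empty intersection).
      assert not (vma & set(axes))
      return vma | set(axes)  # the output varies over the union

    def psum2_out_vma(vma: frozenset, axes: tuple) -> frozenset:
      # psum2: the input must vary over the named `axes` (non-empty intersection).
      assert vma & set(axes)
      return vma - set(axes)  # the reduction removes the named axes

For example, with vma = {'x'}, a pbroadcast over ('y',) yields {'x', 'y'}, while a psum2 over ('x',) yields an empty vma.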
Majority of the primitives should use the standard_insert_pbroadcast and standard_vma_rule and I'll add those in the follow up CLs to other primitives PiperOrigin-RevId: 739225392 --- jax/_src/core.py | 44 +++++++++++++++++------- jax/_src/lax/lax.py | 19 +++++++++-- jax/_src/lax/linalg.py | 2 +- jax/_src/lax/utils.py | 14 +++++--- jax/experimental/shard_map.py | 64 +++++++++++++++++++++++++++++++---- tests/shard_map_test.py | 22 ++++++++++++ 6 files changed, 138 insertions(+), 27 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 243ffc871042..ef90341f5cf7 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1900,26 +1900,35 @@ def get_sharding(sharding, shape): _check_divisibility(out_s, shape) return out_s -def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False, - mesh_axis_types=False) -> str: +def str_short_aval(shape, dtype, mesh, spec, vma, + short_dtypes=False, mesh_axis_types=False) -> str: dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = _get_shape_sharding_str(shape, spec) mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else '' - return f'{dt_str}[{shapestr}]{mesh_axes}' + vma = f"{{{','.join(i for i in vma)}}}" if vma else '' + return f'{dt_str}[{shapestr}]{vma}{mesh_axes}' + +def get_vma(vma, mesh): + for i in vma: + if mesh._name_to_type[i] != AxisType.Manual: + raise ValueError( + "Axes mentioned in `vma` field of ShapedArray should" + f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}") + return vma class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent + __slots__ = ['shape', 'sharding', 'vma'] # inherits slots from parent array_abstraction_level = 2 def __init__(self, shape, dtype, weak_type=False, *, sharding=None, - varying_manual_axes: frozenset[AxisName] = frozenset()): + vma: frozenset[AxisName] = frozenset()): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) if config.varying_axes_in_types.value: - self.varying_manual_axes = varying_manual_axes + self.vma = get_vma(vma, self.sharding.mesh) def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1930,8 +1939,8 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding - if 'varying_manual_axes' not in kwargs: - kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', + if 'vma' not in kwargs: + kwargs['vma'] = getattr(self, 'vma', frozenset()) return ShapedArray(shape, dtype, weak_type, **kwargs) @@ -1950,25 +1959,26 @@ def __eq__(self, other): and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type and self.sharding == other.sharding - and (getattr(self, 'varying_manual_axes', frozenset()) == - getattr(other, 'varying_manual_axes', frozenset()))) + and (getattr(self, 'vma', frozenset()) == + getattr(other, 'vma', frozenset()))) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. 
`np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.shape, self.dtype, self.weak_type, self.sharding, - getattr(self, 'varying_manual_axes', frozenset()))) + getattr(self, 'vma', frozenset()))) def to_tangent_aval(self): return ShapedArray( self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type, sharding=self.sharding, - varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) + vma=getattr(self, 'vma', frozenset())) def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, + getattr(self, 'vma', frozenset()), short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): @@ -2000,6 +2010,16 @@ def primal_dtype_to_tangent_dtype(primal_dtype): return primal_dtype +def standard_insert_pbroadcast(*args): + if not config.varying_axes_in_types.value: + return args + # TODO(yashkatariya): Move pbroadcast out of shard_map + from jax.experimental.shard_map import pbroadcast # type: ignore + in_vma = [get_aval(a).vma for a in args] + out_vma = frozenset.union(*in_vma) + return [pbroadcast(arg, tuple(n for n in out_vma if n not in src)) + if out_vma - src else arg for arg, src in zip(args, in_vma)] + # Dynamic shape stuff below here! We keep the abstract values distinct just so # as not to interfere with any static shape machinery. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f6ab848ccd5b..a6a0924c9c0f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1140,6 +1140,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, y = core.standard_insert_pbroadcast(x, y) return mul_p.bind(x, y) @export @@ -1610,6 +1611,7 @@ def _convert_element_type( (sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))): return operand else: + operand, = core.standard_insert_pbroadcast(operand) return convert_element_type_p.bind( operand, new_dtype=new_dtype, weak_type=bool(weak_type), sharding=sharding) @@ -3844,6 +3846,15 @@ def broadcasting_sharding_rule(name, *avals): f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) +def standard_vma_rule(prim_name, *avals, **kwargs): + vma, *vmas = [a.vma for a in avals] + if not all(vma == vma_ for vma_ in vmas): + raise ValueError( + f'Primitive {prim_name} requires varying manual axes ' + f'to match, but got {[vma, *vmas]}.
Please open an issue at ' + 'https://github.com/jax-ml/jax/issues and as a temporary ' + 'workaround pass the check_rep=False argument to shard_map') + return vma def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=True): @@ -3852,8 +3863,9 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) sharding_rule = partial(broadcasting_sharding_rule, name) - prim = standard_primitive(shape_rule, dtype_rule, name, - sharding_rule=sharding_rule) + prim = standard_primitive( + shape_rule, dtype_rule, name, sharding_rule=sharding_rule, + vma_rule=partial(standard_vma_rule, name)) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -4704,7 +4716,8 @@ def _convert_element_type_bind_with_trace(trace, args, params): partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, - _convert_element_type_sharding_rule)) + _convert_element_type_sharding_rule, + partial(standard_vma_rule, convert_element_type_p.name))) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 027ec8b801b9..b22a4cf56062 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -765,7 +765,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, else: prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, - lax_utils._standard_weak_type_rule, sharding_rule)) + lax_utils._standard_weak_type_rule, sharding_rule, None)) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 206a8312ba8c..63088d665afd 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -19,6 +19,7 @@ from functools import partial from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src import mesh as mesh_lib @@ -37,13 +38,13 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None, sharding_rule=None): + weak_type_rule=None, sharding_rule=None, vma_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, sharding_rule)) + weak_type_rule, sharding_rule, vma_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level @@ -95,14 +96,14 @@ def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals) mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec, - short_dtypes=True) + frozenset(), short_dtypes=True) raise core.ShardingTypeError( f'{prim} operation with inputs: {avals_str} produces an illegally' f' sharded result: {out_aval_str}') from e return out_shapes, out_dtypes, out_shardings def 
standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - sharding_rule, *avals, **kwargs): + sharding_rule, vma_rule, *avals, **kwargs): assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) @@ -112,8 +113,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, False, *avals, **kwargs) + out_vma = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value + else frozenset()) out_aval = core.ShapedArray( - out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding) + out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, + vma=out_vma) core.check_avals_context_mesh([out_aval], prim.name) return out_aval elif least_specialized is core.DShapedArray: diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 66b70c6c2d34..c0306f0c5e91 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -189,7 +189,8 @@ def out_names_thunk(): raise e('shard_map out_specs') from None return tuple(map(_canonicalize_spec, out_specs_flat)) - if rewrite := check_rep: + rewrite = check_rep + if not config.varying_axes_in_types.value and rewrite: fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk) try: @@ -577,7 +578,8 @@ def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames, for i, sz in enumerate(aval.shape)) manual_mesh = _as_manual_mesh(mesh, auto) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - return aval.update(shape=new_shape, sharding=new_sharding) + vma = frozenset({n for ns in names.values() for n in ns}) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array def _unshard_shaped_array(mesh: Mesh, names: AxisNames, @@ -606,7 +608,7 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else get_abstract_mesh()) new_sharding = NamedSharding(new_mesh, out_spec) - return aval.update(shape=new_shape, sharding=new_sharding) + return aval.update(shape=new_shape, sharding=new_sharding, vma=frozenset()) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking @@ -1069,7 +1071,41 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): psum2_p = core.Primitive('psum2') psum2_p.multiple_results = True psum2_p.def_impl(lax_parallel.psum_p.impl) -psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) + +def _psum2_abstract_eval(*args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return lax_parallel.psum_p.abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + + assert isinstance(axes, tuple) + lax_parallel._check_axis_names(axes) + arg_vma = [a.vma for a in args] + if any(not set(axes) & a for a in arg_vma): + raise ValueError( + "Collective psum must be applied to a device-varying " + f"type, but got {arg_vma} for collective acting " + f"over axis name {axes}. 
Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) + if axis_index_groups is not None: + if len(pos_axes) != 0: + raise ValueError( + "axis_index_groups can only be used with reductions over " + f"named axes, but got: {axes}") + core.check_avals_context_mesh(args, 'all_reduce') + out_avals = [ + core.ShapedArray( + lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), + vma=frozenset(a for a in arg.vma if a not in named_axes)) + for arg in args + ] + return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +psum2_p.def_effectful_abstract_eval(_psum2_abstract_eval) + mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) batching.fancy_primitive_batchers[psum2_p] = \ partial(lax_parallel._batched_reduction_collective, psum2_p, @@ -1088,10 +1124,26 @@ def pbroadcast(x, axis_name): xs, treedef = tree_flatten(x) ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) return tree_unflatten(treedef, ys) + pbroadcast_p = core.Primitive('pbroadcast') pbroadcast_p.multiple_results = True pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) -pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) + +def _pbroadcast_abstract_eval(*args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return args + assert isinstance(axes, tuple) + arg_vma = [a.vma for a in args] + if any(set(axes) & a for a in arg_vma): + raise ValueError( + "Collective pbroadcast must be applied to a " + f"non-device-varying type, but got {arg_vma} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return [a.update(vma=a.vma.union(frozenset(axes))) for a in args] +pbroadcast_p.def_abstract_eval(_pbroadcast_abstract_eval) + mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): if any(type(axis) is int for axis in axes): raise NotImplementedError @@ -1140,7 +1192,7 @@ def _standard_check(prim, mesh, *in_rep, **__): # The standard check require args' and outputs' replications to be the same, # except for Nones which correspond to constants. in_rep_ = [r for r in in_rep if r is not None] - if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: + if in_rep_ and in_rep_[:-1] != in_rep_[1:]: raise Exception(f"Primitive {prim} requires argument replication types " f"to match, but got {in_rep}. Please open an issue at " "https://github.com/jax-ml/jax/issues and as a temporary " diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f8d5a11e842f..ce01b6e6e944 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2685,6 +2685,28 @@ def test_pmax(self): )(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + @config.varying_axes_in_types(True) + def test_mul_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) 
+ + def f(x): + self.assertEqual(x.aval.vma, frozenset({'x'})) + out = x * 2 + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pbroadcast[axes=('x',)", str(jaxpr)) + out = f(x) + self.assertArraysEqual(out, x * 2) + + # TODO(yashkatariya): Enable grad test which requires adding psum_p support. + # def g(x, y): + # return jnp.sum(f(x, y)) + # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr) + class FunSpec(NamedTuple): name: str From 37b5066d5bc2f0d3915050b7522f189bf61e125d Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Fri, 21 Mar 2025 10:45:57 -0700 Subject: [PATCH 0087/1769] [Pallas] Fixes scalar prefetch in TPU interpret mode. --- jax/_src/pallas/mosaic/interpret.py | 42 ++++++++++++++--------- tests/pallas/tpu_pallas_interpret_test.py | 33 ++++++++++++++++++ 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 3384026c1f5b..439ac98b2ac6 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1282,17 +1282,20 @@ def f(*args, jaxpr): return jax.util.safe_map(read, jaxpr.outvars) -def _compute_start_indices(block_mapping, loop_idx, *args): - block_indices = ( - jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) - if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = tuple(i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices)) - elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - ret = block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") - return ret +def _compute_start_indices( + block_mapping, loop_idx, *args, compiler_params, interpret_params): + jaxpr = block_mapping.index_map_jaxpr + block_indices = _interpret_jaxpr( + jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, + compiler_params=compiler_params, interpret_params=interpret_params) + if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): + ret = tuple(i if b is pallas_core.mapped else b * i + for b, i in zip(block_mapping.block_shape, block_indices)) + elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): + ret = block_indices + else: + raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") + return ret def _get_next_indices(grid, indices): next_indices = [] @@ -1412,6 +1415,7 @@ def interpret_pallas_call( input_buffer_ids = [] for i, var in enumerate( jaxpr.invars[grid_mapping.num_index_operands:][:grid_mapping.num_inputs]): + assert var.aval.dtype == input_args[i].dtype input_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), @@ -1451,15 +1455,18 @@ def interpret_pallas_call( # Allocate buffers for non-HBM kernel arguments (e.g., scalars, inputs, # outputs, scratch). 
- kernel_buffer_ids = [] - for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): - kernel_buffer_ids.append(callback.io_callback( + scalar_buffer_ids = [] + for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + assert var.aval.shape == val.shape + assert var.aval.dtype == val.dtype + scalar_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], val, ordered=True)) + kernel_buffer_ids = scalar_buffer_ids.copy() for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): output_idx = i - grid_mapping.num_inputs is_input = i < grid_mapping.num_inputs @@ -1520,11 +1527,14 @@ def body(carry): ) with pallas_core.grid_env(local_grid_env): + start_indices = [ + _compute_start_indices( + bm, loop_idx, *scalar_buffer_ids, compiler_params=compiler_params, + interpret_params=interpret_params) + for bm in grid_mapping.block_mappings] # Copy slices of the input to the kernel buffers. # # TODO(jburnim): Only copy slices when the index mapping has changed? - start_indices = [_compute_start_indices(bm, loop_idx, *scalars) - for bm in grid_mapping.block_mappings] for j, var in enumerate(input_vars): if _is_any(var.aval.memory_space): continue diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 9b8a5b46865d..5b729f0fe07e 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from absl.testing import parameterized +import functools import jax from jax._src import test_util as jtu @@ -68,6 +69,38 @@ def matmul(x: jax.Array, y: jax.Array): z = matmul(x, y) np.testing.assert_allclose(z, x @ y, atol=1e-4) + def test_scalar_prefetch_example(self): + def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] 
+ + @functools.partial(jax.jit, static_argnums=(2,)) + def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[pl.BlockSpec( + sizes, + lambda i, j, block_idx: (block_idx[0], block_idx[1]))], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + interpret=mosaic_interpret.TPUInterpretParams(), + ) + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + + shape = (512, 512) + x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) + result = block_dynamic_slice(x, starts=jnp.array([128, 256]), sizes=(128, 128)) + ref = jax.lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128)) + diff = jnp.max(jnp.abs(result - ref)) + np.testing.assert_allclose(result, ref) + def test_dynamic_grid_and_aliasing(self): self.skipTest('Broken pending fix to extra reads/writes of inputs/outputs') def kernel(s_ref, x_ref, o_ref): From 7dd78d97fad15f47295a25896833abafb92601e0 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 10:52:34 -0700 Subject: [PATCH 0088/1769] Add support for configurable error checking categories PiperOrigin-RevId: 739234594 --- jax/_src/config.py | 40 +++++++++- jax/_src/error_check.py | 149 ++++++++++++++++++++++++++++++++------ tests/error_check_test.py | 86 ++++++++++++++++++++++ 3 files changed, 250 insertions(+), 25 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index cf6a07834a10..5b8b87be2095 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -245,7 +245,10 @@ def trace_context(): pgle_profiling_runs.value, enable_pgle.value, use_shardy_partitioner.value, - use_high_dynamic_range_gumbel.value) + use_high_dynamic_range_gumbel.value, + error_checking_behavior_nan.value, + error_checking_behavior_divide.value, + error_checking_behavior_oob.value) config = Config() @@ -1317,6 +1320,41 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ), ) +# TODO(ayx): Move these 3 flags out of config once we have a user-level +# extension mechanism for adding contexts to which the jit cache is sensitive. +error_checking_behavior_nan = enum_state( + name='jax_error_checking_behavior_nan', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a NaN is encountered. Options are "ignore"' + ' or "raise".' + ), + include_in_jit_key=True, +) + +error_checking_behavior_divide = enum_state( + name='jax_error_checking_behavior_divide', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a divide by zero is encountered. Options are' + ' "ignore" or "raise".' + ), + include_in_jit_key=True, +) + +error_checking_behavior_oob = enum_state( + name='jax_error_checking_behavior_oob', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when an out of bounds access is encountered.' + ' Options are "ignore" or "raise".' 
+ ), + include_in_jit_key=True, +) + def _update_x64_global(val): jax_jit.global_state().enable_x64 = val diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index e78b9bc82115..9d493c1f351b 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,15 +14,18 @@ from __future__ import annotations +import contextlib import dataclasses from functools import partial +import json import threading import traceback as tb_lib from types import TracebackType +from typing import Literal import warnings import jax +from jax._src import config from jax._src import core from jax._src import source_info_util from jax._src import traceback_util @@ -115,39 +118,56 @@ def __exit__(self, exc_type, exc_value, traceback): _error_storage.ref = self.old_ref -def set_error_if(pred: jax.Array, /, msg: str) -> None: +# TODO(ayx): Move all category-related logic into the jax.numpy integration once +# it is ready. This logic is specific to how jax.numpy decides when to call +# `set_error_if`, and doesn't belong in the core error-checking library itself. +# The responsibility for deciding whether to predicate an error should lie with +# the user or the higher-level library (like jax.numpy), not with +# `set_error_if`. +Category = Literal["nan", "divide", "oob"] + + +def _is_category_disabled( + category: Category | None, +) -> bool: + """Check if the error checking behavior for the given category is disabled.""" + if category is None: + return False + if category == "nan": + return config.error_checking_behavior_nan.value == "ignore" + if category == "divide": + return config.error_checking_behavior_divide.value == "ignore" + if category == "oob": + return config.error_checking_behavior_oob.value == "ignore" + raise ValueError(f"Invalid category: {category}") + + +def _set_error_if_with_category( + pred: jax.Array, + /, + msg: str, + category: Category | None = None, +) -> None: """Set the internal error state if any element of `pred` is `True`. - This function is used inside JAX computations to detect runtime errors without - immediately halting execution. When this function is traced (e.g., inside - :func:`jax.jit`), the corresponding error message and its traceback are - recorded. At execution time, if `pred` contains any `True` values, the error - state is set, but execution continues without interruption. The recorded error - can later be raised using :func:`raise_if_error`. - - If the error state has already been set, subsequent errors are ignored and - will not override the existing error. - - For multi-device environments, in explicit mode, users must call - :func:`error_checking_context` to initialize a new error tracking state that - matches the device mesh. In auto mode, implicit cross-device communication may - occur inside this function, which could impact performance. A warning is - issued in such cases. + This function is similar to :func:`set_error_if`, but it also takes a category + argument. The category can be "nan", "divide", or "oob". The error checking + behavior for each category can be configured using + :func:`error_checking_behavior`. If not provided, there will be no + category. - When exporting a function with `jax.export`, error checking must be explicitly - wrapped using :func:`wrap_for_export` before export and - :func:`unwrap_from_import` after import. - - Args: - pred: A JAX boolean array. If any element of `pred` is `True`, the internal - error state will be set. - msg: The corresponding error message to be raised later.
+ This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) + to perform category-specific runtime checks tied to the operation being + performed. """ if _error_storage.ref is None: with core.eval_context(): _initialize_error_code_ref() assert _error_storage.ref is not None + if _is_category_disabled(category): + return + # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None @@ -199,6 +219,37 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: _error_storage.ref[...] = error_code +def set_error_if(pred: jax.Array, /, msg: str) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. + + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. + + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. + """ + _set_error_if_with_category(pred, msg) + + def raise_if_error() -> None: """Raise an exception if the internal error state is set. @@ -355,3 +406,53 @@ def inner(*args, **kwargs): return out return inner + + +Behavior = Literal["ignore", "raise"] + + +class error_checking_behavior: + """A context manager to set the error checking behavior. + + If both `all` and a category are provided, the category will override the + `all` setting. + + When the error checking behavior is set to "ignore", all errors will be + ignored. When set to "raise", errors will be detected and recorded, but an + exception will not be raised immediately. Users must call + :func:`raise_if_error` at the end of the computation to raise the + exception.
+ """ + + def __init__( + self, + *, + all: Behavior | None = None, + nan: Behavior | None = None, + divide: Behavior | None = None, + oob: Behavior | None = None, + ) -> None: + new_settings = {} + if all is not None: + new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all + if nan is not None: + new_settings["nan"] = nan + if divide is not None: + new_settings["divide"] = divide + if oob is not None: + new_settings["oob"] = oob + self.new_settings = new_settings + self.stack = contextlib.ExitStack() + + def __enter__(self): + config_flags = { + "nan": config.error_checking_behavior_nan, + "divide": config.error_checking_behavior_divide, + "oob": config.error_checking_behavior_oob, + } + for key, value in self.new_settings.items(): + self.stack.enter_context(config_flags[key](value)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stack.close() diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 69e292fb6704..7f75eeb629a0 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -337,6 +337,92 @@ def run_import(serialized): ): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_error_category_nan_check(self, jit): + def f(x): + error_check._set_error_if_with_category( + jnp.isnan(x), "x is NaN", category="nan" + ) + return x + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), jnp.nan, dtype=jnp.float32) + + with error_check.error_checking_behavior(nan="ignore"): + _ = f(x) + error_check.raise_if_error() # should not raise error + + with error_check.error_checking_behavior(nan="raise"): + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "x is NaN"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_divide_check(self, jit): + def f(x, y): + error_check._set_error_if_with_category( + y == 0.0, "division by zero", category="divide" + ) + return x / y + + if jit: + f = jax.jit(f) + + x = jnp.arange(4, dtype=jnp.float32) + 1 + y = jnp.arange(4, dtype=jnp.float32) + + with error_check.error_checking_behavior(divide="ignore"): + _ = f(x, y) + error_check.raise_if_error() # should not raise error + + with error_check.error_checking_behavior(divide="raise"): + _ = f(x, y) + with self.assertRaisesRegex(JaxValueError, "division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_oob_check(self, jit): + def f(x, start_indices, slice_sizes): + error_check._set_error_if_with_category( + jnp.logical_or( + start_indices < 0, + start_indices + jnp.array(slice_sizes, dtype=jnp.int32) + >= jnp.array(x.shape, dtype=jnp.int32), + ), + "Out of bounds in dynamic_slice", + category="oob", + ) + y = jax.lax.dynamic_slice( + x, start_indices, slice_sizes, allow_negative_indices=False + ) + return y + + if jit: + f = jax.jit(f, static_argnums=(2,)) + + x = jnp.arange(12).reshape(3, 4) + start_indices = jnp.array([0, -1], dtype=jnp.int32) + slice_sizes = (3, 4) + + with error_check.error_checking_behavior(oob="ignore"): + _ = f(x, start_indices, slice_sizes) + error_check.raise_if_error() # should not raise error + + with error_check.error_checking_behavior(oob="raise"): + _ = f(x, start_indices, slice_sizes) + with self.assertRaisesRegex( + JaxValueError, "Out of bounds in dynamic_slice", + ): + error_check.raise_if_error() + + def test_error_category_invalid_category(self): + with self.assertRaisesRegex(ValueError, "Invalid category"): + error_check._set_error_if_with_category( + 
jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 4fdce200300df181b3088dd0d114e5c759dbf63d Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 21 Mar 2025 11:16:46 -0700 Subject: [PATCH 0089/1769] Add logit soft-capping support to the ragged paged attention Pallas kernel. PiperOrigin-RevId: 739242412 --- .../pallas/ops/tpu/ragged_paged_attention.py | 12 +++ .../pallas/tpu_ragged_paged_attention_test.py | 94 ++++++++++++++++++- 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 90b808282c22..60ac2e34f610 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -81,6 +81,7 @@ def ref_ragged_paged_attention( *, sm_scale: float = 1.0, sliding_window: int | None = None, + soft_cap: float | None = None, mask_value: float = DEFAULT_MASK_VALUE, ): _, _, num_kv_heads, head_dim = k_pages.shape @@ -108,6 +109,8 @@ def ref_ragged_paged_attention( mask = q_span < kv_span if sliding_window is not None: mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) + if soft_cap is not None: + attn = soft_cap * jnp.tanh(attn / soft_cap) attn += jnp.where(mask, mask_value, 0.0) attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) @@ -126,6 +129,7 @@ def validate_inputs_on_runtime( cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs, # i32[1] sliding_window: int | None = None, + soft_cap: float | None = None, ): check_inputs_shapes( q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs @@ -156,6 +160,8 @@ def validate_inputs_on_runtime( ) if sliding_window is not None and sliding_window <= 0: raise ValueError(f"{sliding_window=} must be positive.") + if soft_cap is not None and soft_cap == 0.0: + raise ValueError(f"{soft_cap=} must not be 0.0.") # Expect to run these checks during compile time. 
@@ -228,6 +234,7 @@ def ragged_paged_attention_kernel( *, sm_scale: float, sliding_window: int | None = None, + soft_cap: float | None = None, mask_value: float = DEFAULT_MASK_VALUE, ): num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape @@ -432,6 +439,8 @@ def init_scratch_ref(): if sliding_window is not None: causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window >= col_ids) + if soft_cap is not None: + qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) @@ -612,6 +621,7 @@ def can_be_xla_fully_tiled(x, packing): "num_queries_per_block", "vmem_limit_bytes", "sliding_window", + "soft_cap", ], ) def ragged_paged_attention( @@ -626,6 +636,7 @@ def ragged_paged_attention( *, sm_scale: float = 1.0, sliding_window: int | None = None, + soft_cap: float | None = None, mask_value: float = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, num_queries_per_block: int = 128, @@ -719,6 +730,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ragged_paged_attention_kernel, sm_scale=sm_scale, sliding_window=sliding_window, + soft_cap=soft_cap, mask_value=mask_value, ), grid_spec=pltpu.PrefetchScalarGridSpec( diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index ba574a4ce98c..815c9dc6327f 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -52,6 +52,7 @@ def _test_ragged_paged_attention( max_num_batched_tokens=512, max_num_seq=8, sliding_window: int | None = None, + soft_cap: float | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -104,6 +105,7 @@ def _test_ragged_paged_attention( cu_q_lens, num_seqs, sliding_window=sliding_window, + soft_cap=soft_cap, ) actual_num_q_tokens = cu_q_lens[num_seqs[0]] @@ -119,6 +121,7 @@ def _test_ragged_paged_attention( num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, sliding_window=sliding_window, + soft_cap=soft_cap, )[: actual_num_q_tokens] expected = ref_ragged_paged_attention( @@ -130,6 +133,7 @@ def _test_ragged_paged_attention( cu_q_lens, num_seqs=num_seqs, sliding_window=sliding_window, + soft_cap=soft_cap, ) tols = { "float32": 0.15, @@ -272,7 +276,6 @@ def test_ragged_paged_attention_mixed(self, dtype): dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], - sliding_window=[None, 5, 128], ) def test_ragged_paged_attention_complex( self, @@ -281,8 +284,42 @@ def test_ragged_paged_attention_complex( dtype, num_kv_pages_per_block, num_queries_per_block, + ): + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + # TODO(jevinjiang): Support non-128 head_dim! 
+ head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + ) + + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + sliding_window=[None, 5, 128], + ) + def test_ragged_paged_attention_sliding_window( + self, + num_kv_pages_per_block, + num_queries_per_block, sliding_window: int | None, ): + num_seqs = 5 + num_heads = (4, 4) + dtype = jnp.float32 seq_lens = [] for _ in range(num_seqs): q_len = random.randint(1, 100) @@ -305,6 +342,41 @@ def test_ragged_paged_attention_complex( sliding_window=sliding_window, ) + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + soft_cap=[None, 50.0], + ) + def test_ragged_paged_attention_logit_soft_capping( + self, + num_kv_pages_per_block, + num_queries_per_block, + soft_cap: float | None, + ): + num_heads = (12, 2) + num_seqs = 2 + dtype = jnp.float32 + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + soft_cap=soft_cap, + ) + def test_ragged_paged_attention_sliding_window_should_be_positive(self): dtype = jnp.float32 seq_lens = [(192, 328), (128, 180), (64, 255)] @@ -335,5 +407,25 @@ def test_ragged_paged_attention_sliding_window_should_be_positive(self): sliding_window=-1, ) + def test_ragged_paged_attention_soft_cap_cannot_be_zero(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must not be 0.0"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + soft_cap=0.0, + ) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From e23069b39cd747563abd53132eaee15290ce0ce2 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Fri, 21 Mar 2025 11:42:14 -0700 Subject: [PATCH 0090/1769] Allow forcing pallas forward compatibility for some backends PiperOrigin-RevId: 739249745 --- jax/_src/interpreters/mlir.py | 13 ++++++++++--- jax/_src/xla_bridge.py | 3 +++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f96f07be4149..a707981f5403 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -847,11 +847,18 @@ def is_forward_compat(self) -> bool: """Returns true if the lowering parameters are in forward compatibility mode. 
""" lowering_parameters = self.module_context.lowering_parameters - return ( - lowering_parameters.for_export - and not lowering_parameters.export_ignore_forward_compatibility + + check_platforms: Sequence[str] = ( + self.platforms or self.module_context.platforms + ) + force_forward_compat = any( + p in xb.FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS for p in check_platforms ) + return ( + lowering_parameters.for_export or force_forward_compat + ) and not lowering_parameters.export_ignore_forward_compatibility + if not MYPY: class LoweringRule(Protocol): diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index be96deab81d8..72d88b9735b7 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -60,6 +60,9 @@ XlaBackend = xla_client.Client +# The platforms in this set will force forward compatibility for lowering. +FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS: set[str] = set() + MIN_COMPUTE_CAPABILITY = 52 _DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo' From 53e8eac7134a13c1d28de673e7e3a23b4a837aed Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Fri, 21 Mar 2025 12:12:05 -0700 Subject: [PATCH 0091/1769] Reverts be5713309521d5cf0d2252b9c8f1d38ab50952d1 PiperOrigin-RevId: 739258607 --- jax/_src/numpy/lax_numpy.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 16355695792d..96efc48062e1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,16 +49,15 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc +from jax._src.numpy.array_creation import (empty, empty_like, full, + ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize -from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape ) @@ -66,7 +65,8 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax.tree_util import tree_flatten, tree_map +from jax._src.sharding_impls import SingleDeviceSharding +from jax.tree_util import tree_leaves, tree_map import numpy as np export = set_module('jax.numpy') @@ -5504,7 +5504,9 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = xc._xla.cuda_array_interface_to_buffer( cai=cai, gpu_backend=backend, device_id=device_id) - leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) + object = tree_map(lambda leaf: leaf.__jax_array__() + if hasattr(leaf, "__jax_array__") else leaf, object) + leaves = tree_leaves(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): @@ -5513,13 +5515,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, "None encountered in jnp.array(); this is currently treated as NaN. 
" "In the future this will result in an error.", FutureWarning, stacklevel=2) - leaves, treedef = tree_flatten(object) - leaves = [ - leaf - if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None - else leaf_jax_array() - for leaf in leaves - ] + leaves = tree_leaves(object) if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. @@ -5534,8 +5530,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if not weak_type: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - object = treedef.unflatten(leaves) out: ArrayLike + if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in From 520b44fc5ca70d6bb5d70e539ac6e53b2e53072b Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 12:50:37 -0700 Subject: [PATCH 0092/1769] Ensure traceback correctness in error checking in AOT mode This PR is similar to https://github.com/jax-ml/jax/pull/27329. The difference is that in AOT mode, the original traceback is exported as a string and appended to the error message when an exception is raised. PiperOrigin-RevId: 739270141 --- tests/error_check_test.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 7f75eeb629a0..af3f35c7ab62 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -305,6 +305,41 @@ def run_import(serialized): serialized = run_export() run_import(serialized) + def test_error_check_aot_includes_traceback(self): + def run_export(): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback + x <= 0, "x must be greater than 0" + ) + return x + 1 + + f = jax.jit( + error_check.wrap_for_export( + jax.jit(function_that_triggers_error_for_traceback_test) + ) + ) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.0) + _ = f(x) + + msg = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + msg = str(e) + + self.assertIn("function_that_triggers_error_for_traceback_test", msg) + self.assertIn("This line must be included in the traceback", msg) + + serialized = run_export() + run_import(serialized) + def test_error_check_aot_should_not_override_existing_error(self): def f1(x): error_check.set_error_if(x <= 0, "x must be greater than 0 in f1") From e71bcde543ba4db23f382f0aca7d9d6fe4227f06 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 13:22:36 -0700 Subject: [PATCH 0093/1769] Remove some long-stale version guards. 
PiperOrigin-RevId: 739279729 --- jax/_src/tpu_custom_call.py | 10 +++------- jaxlib/xla/xla_client_test.py | 2 -- tests/array_test.py | 2 -- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 4089e047f8b0..e37d5e064a26 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -484,13 +484,9 @@ def _lower_mosaic_module_to_asm( module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True - # TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37. - if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37): - target_version = "" - else: - target_version = ( - f"target-version={ir_version}" if ir_version is not None else "" - ) + target_version = ( + f"target-version={ir_version}" if ir_version is not None else "" + ) try: pipeline = PassManager.parse( "builtin.module(mosaic-serde{serialize=true " + target_version + "})" diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index e228905637cb..5a2f3881f510 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -630,8 +630,6 @@ def testStatefulCustomCall(self): def testCustomCallLookup(self): if self.backend.platform != "cpu": self.skipTest("Test requires cpu platform") - if xla_client._version < 241: - self.skipTest("Test requires jaxlib version 241") self.assertTrue(_CUSTOM_CALLS_REGISTERED) xla_client.make_cpu_client() diff --git a/tests/array_test.py b/tests/array_test.py index 6100283cc032..5891db5a3e36 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -368,8 +368,6 @@ def test_different_devices_in_arrays_than_sharding(self): array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) def test_duplicated_devices_in_arrays(self): - if xc._version <= 274: - self.skipTest('Test requires jaxlib version 275') shape = (8, 2) mesh = jtu.create_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} From ba5be78cdd218136506d5a11b10a793a8692aae2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 13:26:09 -0700 Subject: [PATCH 0094/1769] Remove symlinking of xla_client.py. Use a stub instead. Symlinking led to confusing behaviors because Python may believe there are two copies of the module. 
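
Concretely, a symlinked source file can be imported under two module
names, and Python then executes it once per name, duplicating all
module-level state. An illustrative sketch of the failure mode (the
assertion is the point, not real test code):

    import jaxlib.xla_client as a      # via the symlink
    import jaxlib.xla.xla_client as b  # the real module
    assert a is not b                  # two module objects, two copies of state

The stub added below avoids this: jaxlib/xla_client.py only re-exports
names from jaxlib.xla.xla_client, so every attribute reachable through
either import path belongs to the single real module.
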
PiperOrigin-RevId: 739280690 --- jaxlib/BUILD | 9 +-------- jaxlib/tools/build_wheel.py | 2 +- jaxlib/xla_client.py | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 9 deletions(-) create mode 100644 jaxlib/xla_client.py diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 2397639fddf2..5f693c5384df 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -50,8 +50,8 @@ py_library_providing_imports_info( "init.py", "lapack.py", "plugin_support.py", + "xla_client.py", ":version", - ":xla_client", ":xla_extension_py", ], data = [":ffi_headers"], @@ -92,13 +92,6 @@ symlink_files( flatten = True, ) -symlink_files( - name = "xla_client", - srcs = ["//jaxlib/xla:xla_client"], - dst = ".", - flatten = True, -) - symlink_files( name = "ffi_headers", srcs = ["@xla//xla/ffi/api:all_headers"], diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 8632468acb97..0e0ce077cb23 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -197,7 +197,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): "__main__/jaxlib/gpu_sparse.py", "__main__/jaxlib/plugin_support.py", "__main__/jaxlib/version.py", - "__main__/jaxlib/xla_client.py", + "__main__/jaxlib/xla/xla_client.py", f"xla/xla/python/xla_extension.{pyext}", ], ) diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py new file mode 100644 index 000000000000..01b01ecf704e --- /dev/null +++ b/jaxlib/xla_client.py @@ -0,0 +1,18 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxlib.xla.xla_client import * # noqa: F403 +from jaxlib.xla.xla_client import _version # noqa: F401 +from jaxlib.xla.xla_client import _xla # noqa: F401 From 93f3e4aa19cd2b892ad2c788f2f1a2ebdb853ce6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 08:41:56 -0400 Subject: [PATCH 0095/1769] Increase the test timeout for tsan builds. Update the list of TSAN suppressions. Issue #27244 --- .github/workflows/tsan-suppressions.txt | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 7b713b2da194..296f4432e687 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -2,14 +2,11 @@ # are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. race:llvm::RuntimeDyldELF::registerEHFrames -# https://github.com/python/cpython/issues/128050 -race:partial_vectorcall_fallback - # https://github.com/openxla/xla/issues/20686 race:dnnl_sgemm -# https://github.com/python/cpython/issues/128130 -race_top:run_eval_code_obj +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback # Likely only happens when the process is crashing. race:dump_traceback @@ -18,19 +15,15 @@ race:dump_traceback # Fixed in Python 3.14, but not backported to 3.13. 
race:immortalize_interned race:_PyUnicode_InternMortal +race:_PyUnicode_InternImmortal # https://github.com/python/cpython/issues/128144 # Fixed in Python 3.14, but not backported to 3.13. race_top:PyMember_GetOne -# https://github.com/python/cpython/issues/129547 -race:type_get_annotations - - # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx - # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi race:gesdd_ffi @@ -65,3 +58,10 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130547 # race:split_keys_entry_added + +# https://github.com/python/cpython/issues/128130 +# race_top:run_eval_code_obj + +# https://github.com/python/cpython/issues/129547 +# Maybe fixed? +# race:type_get_annotations From 6b7744581b6810d1fab176994e591b8ccb4f6f5b Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Fri, 21 Feb 2025 16:51:36 -0600 Subject: [PATCH 0096/1769] [Pallas] [1/3] Move communication primitives from mosaic to core --- jax/_src/pallas/core.py | 4 + jax/_src/pallas/mosaic/core.py | 11 +- jax/_src/pallas/mosaic/helpers.py | 4 +- jax/_src/pallas/mosaic/interpret.py | 8 +- jax/_src/pallas/mosaic/lowering.py | 30 +-- jax/_src/pallas/mosaic/primitives.py | 277 +-------------------------- jax/_src/pallas/primitives.py | 264 +++++++++++++++++++++++++ jax/experimental/pallas/__init__.py | 5 + jax/experimental/pallas/tpu.py | 14 +- 9 files changed, 314 insertions(+), 303 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 466f6037a8ef..389bbd3b0733 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -65,6 +65,10 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max +class semaphore_dtype(dtypes.extended): pass +class semaphore(semaphore_dtype): pass +class dma_semaphore(semaphore_dtype): pass +class barrier_semaphore(semaphore_dtype): pass @runtime_checkable class CompilerParams(Protocol): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 3e60e471dfa2..5d503779f092 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -112,11 +112,6 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. 
return pallas_core.MemoryRef(shape, dtype, self) -class semaphore_dtype(dtypes.extended): pass -class semaphore(semaphore_dtype): pass -class dma_semaphore(semaphore_dtype): pass -class barrier_semaphore(semaphore_dtype): pass - class AbstractSemaphoreTyRules: @staticmethod def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: @@ -142,15 +137,15 @@ def __hash__(self) -> int: # TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy class SemaphoreTy(AbstractSemaphoreTy): - type = semaphore + type = pallas_core.semaphore name = "sem" class DmaSemaphoreTy(AbstractSemaphoreTy): - type = dma_semaphore + type = pallas_core.dma_semaphore name = "dma_sem" class BarrierSemaphoreTy(AbstractSemaphoreTy): - type = barrier_semaphore + type = pallas_core.barrier_semaphore name = "barrier_sem" class SemaphoreType(enum.Enum): diff --git a/jax/_src/pallas/mosaic/helpers.py b/jax/_src/pallas/mosaic/helpers.py index 76421cec3340..24cd7cad6086 100644 --- a/jax/_src/pallas/mosaic/helpers.py +++ b/jax/_src/pallas/mosaic/helpers.py @@ -88,8 +88,8 @@ def signal_core(i): # Don't signal ourself @pl_helpers.when(core_id != i) def _(): - plm_primitives.semaphore_signal(sem, 1, core_index=i) + pl_primitives.semaphore_signal(sem, 1, core_index=i) for i in range(num_cores): signal_core(i) - plm_primitives.semaphore_wait(sem, num_cores - 1) + pl_primitives.semaphore_wait(sem, num_cores - 1) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index a731bfdfdae1..ba1a7b0017c4 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -943,9 +943,9 @@ def _device_coords_to_logical_id(device_coords, axis_sizes): def _device_id_to_logical(device_id, device_id_type, axis_sizes): if device_id is None: return None - if device_id_type == mosaic_primitives.DeviceIdType.MESH: + if device_id_type == primitives.DeviceIdType.MESH: return _device_coords_to_logical_id(device_id, axis_sizes) - elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL: + elif device_id_type == primitives.DeviceIdType.LOGICAL: return device_id else: raise ValueError(f'Unsupported device ID type: {device_id_type}') @@ -1223,7 +1223,7 @@ def f(*args, jaxpr): compiler_params['mosaic']['collective_id'], ordered=True) - elif prim is mosaic_primitives.semaphore_signal_p: + elif prim is primitives.semaphore_signal_p: sem, sem_transforms, inc, target_device_id, core_index = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) target_device_id = _device_id_to_logical( @@ -1239,7 +1239,7 @@ def f(*args, jaxpr): ordered=True) out = [] - elif prim is mosaic_primitives.semaphore_wait_p: + elif prim is primitives.semaphore_wait_p: sem, sem_transforms, value = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) callback.io_callback( diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 4efb2b276f56..3469ef4de952 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -192,12 +192,12 @@ def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None def _dtype_to_ir_type(dtype: jnp.dtype, is_kernel_boundary: bool = False) -> ir.Type: - if jnp.issubdtype(dtype, tpu_core.semaphore_dtype): - if jnp.issubdtype(dtype, tpu_core.dma_semaphore): + if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): + if jnp.issubdtype(dtype, pallas_core.dma_semaphore): return ir.Type.parse("!tpu.dma_semaphore") - elif jnp.issubdtype(dtype, tpu_core.semaphore): + elif jnp.issubdtype(dtype, 
pallas_core.semaphore): return ir.Type.parse("!tpu.semaphore") - elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore): + elif jnp.issubdtype(dtype, pallas_core.barrier_semaphore): return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError @@ -3291,7 +3291,7 @@ def _alloc_value( ) -> ir.Value: if isinstance(aval, pallas_core.AbstractMemoryRef): memspace = _memory_space_to_mosaic_attribute(aval.memory_space) - if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): + if jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, @@ -3341,8 +3341,8 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): def _device_id_to_logical( ctx: LoweringRuleContext, device_id, - device_id_type: tpu_primitives.DeviceIdType): - if device_id_type is tpu_primitives.DeviceIdType.MESH: + device_id_type: primitives.DeviceIdType): + if device_id_type is primitives.DeviceIdType.MESH: # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) mesh_strides = ctx.lowering_context.mesh_context.mesh_strides @@ -3357,7 +3357,7 @@ def _device_id_to_logical( for a, b in zip(device_ids, mesh_strides) ), ) - elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: + elif device_id_type is primitives.DeviceIdType.LOGICAL: return device_id raise NotImplementedError(f"Unsupported device id type: {device_id_type}") @@ -3373,13 +3373,13 @@ def _semaphore_read_lowering_rule( return tpu.sem_read(sem) -lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule +lowering_rules[primitives.semaphore_read_p] = _semaphore_read_lowering_rule def _semaphore_signal_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, - device_id_type: tpu_primitives.DeviceIdType, + device_id_type: primitives.DeviceIdType, ): sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( @@ -3392,7 +3392,7 @@ def _semaphore_signal_lowering_rule( return [] -lowering_rules[tpu_primitives.semaphore_signal_p] = ( +lowering_rules[primitives.semaphore_signal_p] = ( _semaphore_signal_lowering_rule) @@ -3402,10 +3402,10 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) tpu.sem_wait(sem, value) return [] -lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule +lowering_rules[primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + device_id_type: primitives.DeviceIdType): ( src_ref, src_transforms, @@ -3445,7 +3445,7 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + device_id_type: primitives.DeviceIdType): del device_id_type (src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = ( tree_util.tree_unflatten(tree, args) @@ -3477,7 +3477,7 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, def _device_id_lowering_rule(ctx: LoweringRuleContext): return tpu.device_id() -lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule +lowering_rules[primitives.device_id_p] = _device_id_lowering_rule def 
_axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index fb0e0c2c55e3..106f342bace8 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -16,7 +16,6 @@ from __future__ import annotations import dataclasses -import enum from typing import Any import jax @@ -28,6 +27,7 @@ from jax._src import util from jax._src.interpreters import mlir from jax._src.pallas import core as pl_core +from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge @@ -160,255 +160,6 @@ def _roll(x, shift): mlir.register_lowering(roll_p, _roll_lowering_rule) -class DeviceIdType(enum.Enum): - MESH = "mesh" - LOGICAL = "logical" - - -def check_sem_avals( - sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None -): - if allowed_semaphore_types is None: - allowed_semaphore_types = { - tpu_core.semaphore, - tpu_core.barrier_semaphore, - # For interpret mode. - pl_core.SEMAPHORE_INTERPRET_DTYPE, - } - if not isinstance(sem_aval, state.AbstractRef): - raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") - sem_shape = sem_aval.shape - if sem_transforms_avals: - sem_shape = sem_transforms_avals[-1].get_indexer_shape() - if sem_shape: - raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") - sem_dtype = sem_aval.dtype - if not any( - jnp.issubdtype(sem_dtype, sem_type) - for sem_type in allowed_semaphore_types - ): - raise ValueError( - f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}. Got {sem_dtype}." 
- ) - - -def _transform_semaphore(ref_value, transforms, ref_aval): - """Helper function for indexing into a semaphore during state_discharge.""" - if ref_value.shape == ref_aval.shape: - return state_discharge.transform_array(ref_value, transforms) - elif len(ref_value.shape) == 0: - return ref_value - else: - raise ValueError( - f"Semaphore value shape {ref_value.shape} does not match aval shape" - f" {ref_aval.shape}" - ) - - -semaphore_read_p = jax_core.Primitive("semaphore_read") -semaphore_read_p.multiple_results = False - - -def semaphore_read(sem_or_view): - ref, transforms = _get_ref_and_transforms(sem_or_view) - args = [ref, transforms] - flat_args, args_tree = tree_util.tree_flatten(args) - return semaphore_read_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_read_p.def_abstract_eval -def _semaphore_read_abstract_eval( - *avals, - args_tree, -): - sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals( - sem_aval, - sem_transforms_avals, - "read", - allowed_semaphore_types={ - tpu_core.dma_semaphore, - tpu_core.semaphore, - tpu_core.barrier_semaphore, - pl_core.SEMAPHORE_INTERPRET_DTYPE, - }, - ) - return jax_core.ShapedArray((), jnp.dtype("int32")) - -def _semaphore_read_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - sem_value = sem_value.astype(jnp.int32) - return (None,) * len(in_avals), sem_value -state_discharge.register_discharge_rule(semaphore_read_p)( - _semaphore_read_discharge_rule -) - - -semaphore_signal_p = jax_core.Primitive('semaphore_signal') -semaphore_signal_p.multiple_results = True - - -def semaphore_signal( - sem_or_view, - inc: int | jax.Array = 1, - *, - device_id: int | jax.Array | None | tuple[int | jax.Array, ...] 
= None, - device_id_type: DeviceIdType = DeviceIdType.MESH, - core_index: int | jax.Array | None = None, -): - ref, transforms = _get_ref_and_transforms(sem_or_view) - inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, transforms, inc, device_id, core_index] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_signal_p.bind( - *flat_args, - args_tree=args_tree, - device_id_type=device_id_type, - ) - - -@semaphore_signal_p.def_abstract_eval -def _semaphore_signal_abstract_eval( - *avals, - args_tree, - device_id_type: DeviceIdType, -): - del device_id_type - ( - sem_aval, - sem_transforms_avals, - value_aval, - device_id_avals, - core_index_aval, - ) = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals(sem_aval, sem_transforms_avals, "signal") - if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must signal an int32 value.") - if device_id_avals is not None: - device_id_flat_avals = tree_util.tree_leaves(device_id_avals) - for aval in device_id_flat_avals: - if aval.dtype != jnp.dtype("int32"): - raise ValueError("`device_id`s must be an int32 value.") - return [] - - -def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - device_ids, - _, - ) = tree_util.tree_unflatten(tree, invars) - out = pp.concat([ - pp.text("semaphore_signal"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) - if device_ids is not None: - flat_device_ids = tree_util.tree_leaves(device_ids) - if not flat_device_ids: - return out - device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] - for device_id in flat_device_ids[1:]: - device_ids_pp.append(pp.text(" ")) - device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) - out = pp.concat([out, pp.concat(device_ids_pp)]) - return out -jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn - - -def _semaphore_signal_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree, - device_id_type): - del out_avals, device_id_type - [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) - if device_id is not None: - raise NotImplementedError("Remote signal not implemented.") - if core_index is not None: - raise NotImplementedError("Multiple core support not implemented.") - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value + inc - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_signal_p)( - _semaphore_signal_discharge_rule -) - - -semaphore_wait_p = jax_core.Primitive('semaphore_wait') -semaphore_wait_p.multiple_results = True - -def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): - ref, transforms = _get_ref_and_transforms(sem_or_view) - dec = jnp.asarray(dec, dtype=jnp.int32) - args = [ref, transforms, dec] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_wait_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_wait_p.def_abstract_eval -def _semaphore_wait_abstract_eval(*avals, args_tree): - sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( - args_tree, avals - ) - check_sem_avals(sem_aval, sem_transforms_avals, "wait") - if value_aval.dtype 
!= jnp.dtype("int32"): - raise ValueError("Must wait an int32 value.") - return [] - -def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - ) = tree_util.tree_unflatten(tree, invars) - return pp.concat([ - pp.text("semaphore_wait"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) -jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn - -def _semaphore_wait_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms, dec] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value - dec - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_wait_p)( - _semaphore_wait_discharge_rule -) - - @dataclasses.dataclass class AsyncCopyDescriptor: src_ref: Any @@ -420,7 +171,7 @@ class AsyncCopyDescriptor: src_sem: int | jax.Array | None src_sem_transforms: tuple[Transform, ...] | None device_id: int | jax.Array | None - device_id_type: DeviceIdType = DeviceIdType.MESH + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH def __post_init__(self): if (self.src_sem is None) ^ (self.device_id is None): @@ -610,14 +361,14 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, # TODO(justinfu): Verify that code only works in SPMD mode. axis_env = jax_core.get_axis_env() nonempty_axes = [name for name in axis_env.axis_sizes if name is not None] - if device_id_type == DeviceIdType.LOGICAL: + if device_id_type == primitives.DeviceIdType.LOGICAL: if len(nonempty_axes) > 1: raise NotImplementedError("Sharding with more than one named axis not " "implemented in dma_start_p for LOGICAL " "device_id_type.") shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) - elif device_id_type == DeviceIdType.MESH: + elif device_id_type == primitives.DeviceIdType.MESH: device_id_len = 1 if isinstance(device_id, jax.Array): device_id_len = device_id.size @@ -667,7 +418,7 @@ def do_discharge_dst(dst_ref=dst_ref): def do_discharge_dst_sem(dst_sem=dst_sem): recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - dst_sem_value = _transform_semaphore( + dst_sem_value = primitives._transform_semaphore( dst_sem, dst_sem_transforms, dst_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -678,7 +429,7 @@ def do_discharge_dst_sem(dst_sem=dst_sem): def do_discharge_src_sem(src_sem=src_sem): send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - src_sem_value = _transform_semaphore( + src_sem_value = primitives._transform_semaphore( src_sem, src_sem_transforms, src_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -778,7 +529,7 @@ def dma_wait_partial_discharge_rule(should_discharge, updates = state_discharge.transform_array(dst_ref, dst_ref_transforms) copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - sem_value = 
_transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) + sem_value = primitives._transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) _, new_sem = state_discharge.transform_swap_array( dst_sem, dst_sem_transforms, sem_value - copy_size ) @@ -814,7 +565,7 @@ def make_async_copy(src_ref, dst_ref, sem): None, None, None, - DeviceIdType.MESH, + primitives.DeviceIdType.MESH, ) def async_copy(src_ref, dst_ref, sem): @@ -824,7 +575,7 @@ def async_copy(src_ref, dst_ref, sem): return copy_descriptor def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): """Creates a description of a remote copy operation. Copies data from src_ref on the current device to dst_ref on the device @@ -861,20 +612,12 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, ) def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type) copy_descriptor.start() return copy_descriptor -device_id_p = jax_core.Primitive('device_id') - -@device_id_p.def_abstract_eval -def _device_id_abstract_eval(): - return jax_core.ShapedArray((), jnp.dtype("int32")) - -device_id = device_id_p.bind - get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore') @get_barrier_semaphore_p.def_abstract_eval diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 3306649f24f3..5d3444ef719f 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -993,3 +993,267 @@ def _lower_fun(*lower_fun_args): return out[:num_return_values] return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) + + +def _get_ref_and_transforms(ref): + if isinstance(ref, state.TransformedRef): + return ref.ref, ref.transforms + return ref, () + + +class DeviceIdType(enum.Enum): + MESH = "mesh" + LOGICAL = "logical" + + +def check_sem_avals( + sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None +): + if allowed_semaphore_types is None: + allowed_semaphore_types = { + pallas_core.semaphore, + pallas_core.barrier_semaphore, + # For interpret mode. + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + } + if not isinstance(sem_aval, state.AbstractRef): + raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") + sem_shape = sem_aval.shape + if sem_transforms_avals: + sem_shape = sem_transforms_avals[-1].get_indexer_shape() + if sem_shape: + raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") + # Uncomment when semaphore type works for Mosaic-GPU lowering + # sem_dtype = sem_aval.dtype + # if not any( + # jnp.issubdtype(sem_dtype, sem_type) + # for sem_type in allowed_semaphore_types + # ): + # raise ValueError( + # f"Must {name} semaphores of the following types:" + # f" {allowed_semaphore_types}." 
+ # ) + + +def _transform_semaphore(ref_value, transforms, ref_aval): + """Helper function for indexing into a semaphore during state_discharge.""" + if ref_value.shape == ref_aval.shape: + return state_discharge.transform_array(ref_value, transforms) + elif len(ref_value.shape) == 0: + return ref_value + else: + raise ValueError( + f"Semaphore value shape {ref_value.shape} does not match aval shape" + f" {ref_aval.shape}" + ) + + +semaphore_read_p = jax_core.Primitive("semaphore_read") +semaphore_read_p.multiple_results = False + + +def semaphore_read(sem_or_view): + ref, transforms = _get_ref_and_transforms(sem_or_view) + args = [ref, transforms] + flat_args, args_tree = tree_util.tree_flatten(args) + return semaphore_read_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_read_p.def_abstract_eval +def _semaphore_read_abstract_eval( + *avals, + args_tree, +): + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals( + sem_aval, + sem_transforms_avals, + "read", + allowed_semaphore_types={ + pallas_core.dma_semaphore, + pallas_core.semaphore, + pallas_core.barrier_semaphore, + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + }, + ) + return jax_core.ShapedArray((), jnp.dtype("int32")) + +def _semaphore_read_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + sem_value = sem_value.astype(jnp.int32) + return (None,) * len(in_avals), sem_value +state_discharge.register_discharge_rule(semaphore_read_p)( + _semaphore_read_discharge_rule +) + + +semaphore_signal_p = jax_core.Primitive('semaphore_signal') +semaphore_signal_p.multiple_results = True + + +def semaphore_signal( + sem_or_view, + inc: int | jax.Array = 1, + *, + device_id: int | jax.Array | None | tuple[int | jax.Array, ...] 
= None, + device_id_type: DeviceIdType = DeviceIdType.MESH, + core_index: int | jax.Array | None = None, +): + ref, transforms = _get_ref_and_transforms(sem_or_view) + inc = jnp.asarray(inc, dtype=jnp.int32) + args = [ref, transforms, inc, device_id, core_index] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_signal_p.bind( + *flat_args, + args_tree=args_tree, + device_id_type=device_id_type, + ) + + +@semaphore_signal_p.def_abstract_eval +def _semaphore_signal_abstract_eval( + *avals, + args_tree, + device_id_type: DeviceIdType, +): + del device_id_type + ( + sem_aval, + sem_transforms_avals, + value_aval, + device_id_avals, + core_index_aval, + ) = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_transforms_avals, "signal") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must signal an int32 value.") + if device_id_avals is not None: + device_id_flat_avals = tree_util.tree_leaves(device_id_avals) + for aval in device_id_flat_avals: + if aval.dtype != jnp.dtype("int32"): + raise ValueError("`device_id`s must be an int32 value.") + return [] + + +def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + device_ids, + _, + ) = tree_util.tree_unflatten(tree, invars) + out = pp.concat([ + pp.text("semaphore_signal"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) + if device_ids is not None: + flat_device_ids = tree_util.tree_leaves(device_ids) + if not flat_device_ids: + return out + device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] + for device_id in flat_device_ids[1:]: + device_ids_pp.append(pp.text(" ")) + device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) + out = pp.concat([out, pp.concat(device_ids_pp)]) + return out +jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn + + +def _semaphore_signal_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree, + device_id_type): + del out_avals, device_id_type + [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) + if device_id is not None: + raise NotImplementedError("Remote signal not implemented.") + if core_index is not None: + raise NotImplementedError("Multiple core support not implemented.") + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + inc = inc.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value + inc + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_signal_p)( + _semaphore_signal_discharge_rule +) + + +semaphore_wait_p = jax_core.Primitive('semaphore_wait') +semaphore_wait_p.multiple_results = True + +def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): + ref, transforms = _get_ref_and_transforms(sem_or_view) + dec = jnp.asarray(dec, dtype=jnp.int32) + args = [ref, transforms, dec] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_wait_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_wait_p.def_abstract_eval +def _semaphore_wait_abstract_eval(*avals, args_tree): + sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( + args_tree, avals + ) + check_sem_avals(sem_aval, sem_transforms_avals, "wait") + if 
value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must wait an int32 value.") + return [] + +def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + ) = tree_util.tree_unflatten(tree, invars) + return pp.concat([ + pp.text("semaphore_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) +jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn + +def _semaphore_wait_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms, dec] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + dec = dec.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value - dec + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_wait_p)( + _semaphore_wait_discharge_rule +) + +device_id_p = jax_core.Primitive('device_id') + +@device_id_p.def_abstract_eval +def _device_id_abstract_eval(): + return jax_core.ShapedArray((), jnp.dtype("int32")) + +device_id = device_id_p.bind diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 1e0abacfc25f..ea58fae3d283 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -47,6 +47,7 @@ from jax._src.pallas.primitives import atomic_xchg as atomic_xchg from jax._src.pallas.primitives import atomic_xor as atomic_xor from jax._src.pallas.primitives import debug_print as debug_print +from jax._src.pallas.primitives import device_id as device_id from jax._src.pallas.primitives import dot as dot from jax._src.pallas.primitives import load as load from jax._src.pallas.primitives import max_contiguous as max_contiguous @@ -55,8 +56,12 @@ from jax._src.pallas.primitives import program_id as program_id from jax._src.pallas.primitives import reciprocal as reciprocal from jax._src.pallas.primitives import run_scoped as run_scoped +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.primitives import store as store from jax._src.pallas.primitives import swap as swap +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.utils import cdiv as cdiv from jax._src.pallas.utils import next_power_of_2 as next_power_of_2 from jax._src.pallas.utils import strides_from_shape as strides_from_shape diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index ecc9d0d15120..c81edaf76fa3 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -17,11 +17,11 @@ from jax._src.pallas.mosaic import core as core from jax._src.pallas.mosaic.core import ARBITRARY as ARBITRARY from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh -from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore +from jax._src.pallas.core import dma_semaphore as dma_semaphore from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL from 
jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore as semaphore +from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams @@ -40,8 +40,8 @@ from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast as bitcast from jax._src.pallas.mosaic.primitives import delay as delay -from jax._src.pallas.mosaic.primitives import device_id as device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType +from jax._src.pallas.primitives import device_id as device_id +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy @@ -49,9 +49,9 @@ from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll -from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.mosaic.random import sample_block as sample_block from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key From 2692c5ff98c6dfbcf45b9f8d26db4cc9c2a67a79 Mon Sep 17 00:00:00 2001 From: Praveen Narayanan Date: Fri, 21 Mar 2025 17:35:37 -0700 Subject: [PATCH 0097/1769] Lower lax.ragged_dot_general to chlo.ragged_dot in some cases on tpu. 
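
For reference, a ragged dot contracts an [m, k] lhs against a
[num_groups, k, n] rhs, with group_sizes assigning consecutive rows of
lhs to one group's slice of rhs. A small usage sketch (shapes chosen
only for illustration):

    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((64, 8))                             # [m, k]
    rhs = jnp.ones((2, 8, 16))                          # [num_groups, k, n]
    group_sizes = jnp.array([40, 24], dtype=jnp.int32)  # sums to m
    out = lax.ragged_dot(lhs, rhs, group_sizes)         # [m, n]; rows 0..39
                                                        # hit rhs[0], 40..63 rhs[1]

On TPU the lowering now emits chlo.ragged_dot directly when no SPMD
sharding is involved; under SPMD it falls back to the decomposed
implementation (see use_default_lowering below).
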
PiperOrigin-RevId: 739348011 --- jax/_src/lax/lax.py | 131 ++++++++++++++++++++++++++++++++++++------- tests/export_test.py | 4 +- tests/lax_test.py | 92 ++++++++++++++++++++++++++++-- 3 files changed, 200 insertions(+), 27 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a6a0924c9c0f..80a469ab6a11 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -67,6 +67,7 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo from jax._src.sharding_impls import (PmapSharding, NamedSharding, + ShardingContext, SPMDAxisContext, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis, @@ -5378,15 +5379,26 @@ def _dot_general_batch_unpack_dims(batch_dims): core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule -def precision_attr(precision: Precision) -> ir.ArrayAttr: + +def _full_precision(precision: Precision) -> tuple[Precision, Precision]: if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): - full_precision = (Precision.DEFAULT, Precision.DEFAULT) + return (Precision.DEFAULT, Precision.DEFAULT) elif not isinstance(precision, tuple): - full_precision = (precision, precision) + return (precision, precision) else: - full_precision = precision + return precision + + +def precision_attr(precision: Precision) -> ir.ArrayAttr: return ir.ArrayAttr.get( - [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) + [hlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) + + +def chlo_precision_attr(precision: Precision) -> ir.ArrayAttr: + return ir.ArrayAttr.get( + [chlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, @@ -5424,9 +5436,7 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return lhs_dtype, rhs_dtype, out_type -def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, - precision, preferred_element_type: np.dtype | None, - out_sharding, platform: str = "default"): +def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) @@ -5437,19 +5447,12 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): if dtypes.float8_e8m0fnu is not None: fp8_dtypes += (dtypes.float8_e8m0fnu,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes - del preferred_element_type # Implied by the output aval - lhs_aval, rhs_aval = ctx.avals_in + + # The *_ lets us reuse this for ragged_dot_general, which has group_sizes. 
+ lhs_aval, rhs_aval, *_ = ctx.avals_in lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype aval_out, = ctx.avals_out accumulation_aval = aval_out - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=list(lhs_batch), - rhs_batching_dimensions=list(rhs_batch), - lhs_contracting_dimensions=list(lhs_contracting), - rhs_contracting_dimensions=list(rhs_contracting)) - algorithm_kwarg = {} if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): # The CPU backend silently ignores the algorithm spec, so we check here to @@ -5507,7 +5510,22 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + return lhs, rhs, accumulation_aval, algorithm_kwarg + +def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, + precision, preferred_element_type: np.dtype | None, + out_sharding, platform: str = "default"): + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, algorithm_kwarg = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + dot_dnums = hlo.DotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting)) result = hlo.dot_general( mlir.aval_to_ir_type(accumulation_aval), lhs, @@ -5516,7 +5534,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): precision_config=precision_attr(precision), **algorithm_kwarg, ) - + aval_out, = ctx.avals_out result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) @@ -6035,10 +6053,85 @@ def expand(x, dim, gs, *axes): ) +def _ragged_dot_general_lower( + ctx, + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: np.dtype | None, + group_offset: Array | None = None, + platform: str = 'default', +): + if group_offset is not None: + raise NotImplementedError('Unimplemented group_offset support.') + + # TODO(pravnar): Remove this once we have sharding support. 
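+  # use_default_lowering() returns True when lowering under SPMD
+  # partitioning (an SPMDAxisContext, or a ShardingContext spanning more
+  # than one device); in that case we keep the decomposed lowering rather
+  # than emitting chlo.ragged_dot.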
+ def use_default_lowering(): + axis_context = ctx.module_context.axis_context + return ( + isinstance(axis_context, SPMDAxisContext) + or isinstance(axis_context, ShardingContext) + and axis_context.num_devices > 1 + ) + if use_default_lowering(): + result = mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)( + ctx, lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset + ) + (aval_out,) = ctx.avals_out + return mlir.lower_with_sharding_in_types(ctx, result, aval_out) + + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, _ = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + ragged_dot_dnums = chlo.RaggedDotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting), + lhs_ragged_dimensions=list( + ragged_dot_dimension_numbers.lhs_ragged_dimensions + ), + rhs_group_dimensions=list( + ragged_dot_dimension_numbers.rhs_group_dimensions + ), + ) + result = chlo.ragged_dot( + mlir.aval_to_ir_type(accumulation_aval), + lhs, + rhs, + group_sizes, + ragged_dot_dnums, + precision_config=chlo_precision_attr(precision), + ) + (aval_out,) = ctx.avals_out + result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) + if accumulation_aval.dtype != aval_out.dtype: + result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) + return [result] + + mlir.register_lowering(ragged_dot_general_p, mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)) +for platform in ['tpu']: + mlir.register_lowering( + ragged_dot_general_p, + partial(_ragged_dot_general_lower, platform=platform), + platform=platform, + ) + def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, sharding): diff --git a/tests/export_test.py b/tests/export_test.py index 2b083f3121f4..0b78a29a8e6a 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1903,8 +1903,8 @@ def f_jax(x): @jtu.parameterized_filterable( kwargs=[ - {"m": 5, "k": 4, "n": 3, "group_sizes": [5]}, - {"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]}, + {"m": 64, "k": 4, "n": 3, "group_sizes": [5]}, + {"m": 64, "k": 9, "n": 8, "group_sizes": [3, 7]}, ]) def test_ragged_dot(self, m, k, n, group_sizes): def f_jax(x, y, gs): diff --git a/tests/lax_test.py b/tests/lax_test.py index f7cca2c9b48f..40f2eb8f3588 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4796,10 +4796,10 @@ class RaggedTest(jtu.JaxTestCase): @jtu.sample_product( [ - {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, - {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2}, + {'m': 64, 'k': 4, 'n': 3, 'num_groups': 1}, + {'m': 64, 'k': 9, 'n': 8, 'num_groups': 2}, ], - dtype=jtu.dtypes.numeric, + dtype=jtu.dtypes.all_floating, ) def test_ragged_dot(self, m, k, n, num_groups, dtype): """Tests ragged_dot. @@ -4810,6 +4810,8 @@ def test_ragged_dot(self, m, k, n, num_groups, dtype): Raises: SkipTest: in the case dtype is not supported. 
""" + if (dtype == np.float16): + raise SkipTest(f"unsupported dtype for ragged_dot: {dtype}") lhs_shape = (m, k) rhs_shape = (num_groups, k, n) @@ -4831,6 +4833,25 @@ def group_sizes(m, num_groups): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @parameterized.parameters( + { "m": 5, "k": 4, "n": 3, "num_groups": 1}, + { "m": 10, "k": 9, "n": 8, "num_groups": 2}, + ) + def test_ragged_dot_unsupported( + self, m, k, n, num_groups): + lhs_shape = (m, k) + rhs_shape = (num_groups, k, n) + group_sizes_shape = (num_groups,) + + args_maker = lambda: [ + jnp.ones(lhs_shape, dtype=jnp.float32), + jnp.ones(rhs_shape, dtype=jnp.float32), + jnp.ones(group_sizes_shape, dtype=jnp.int32), + ] + if jtu.test_device_matches(["tpu"]): + with self.assertRaises(jax.errors.JaxRuntimeError): + self._CompileAndCheck(lax.ragged_dot, args_maker) + @parameterized.parameters( { "lhs_shape": lhs_shape, @@ -5049,10 +5070,69 @@ def test_ragged_dot_general_shape_inference_success( lhs = jnp.ones(lhs_shape, dtype=jnp.float32) rhs = jnp.ones(rhs_shape, dtype=jnp.float32) group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) - self.assertEqual( - lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, - out_shape, + if jtu.test_device_matches(["tpu"]): + actual_shape = lax_internal._ragged_dot_general_shape_rule( + lhs, rhs, group_sizes, ragged_dot_dimension_numbers=ragged_dnums, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=jnp.float32, + ) + else: + actual_shape = lax.ragged_dot_general( + lhs, rhs, group_sizes, ragged_dnums + ).shape + self.assertEqual(actual_shape, out_shape) + + @parameterized.product( + batch_size=[3, 5], + m=[128, 1024], + k=[128, 1024], + n=[128, 1024], + num_groups=[2, 4], + ) + def test_ragged_dot_general_vmap( + self, batch_size: int, m: int, k: int, n: int, num_groups: int + ): + if (jtu.test_device_matches(["tpu"])): + raise SkipTest("batched ragged_dot not yet supported on TPU") + + lhs_shape = (batch_size, m, k) + rhs_shape = (batch_size, num_groups, k, n) + dtype = jnp.float32 + + def make_group_sizes(m, num_groups): + ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1)) + ends = jnp.concatenate( + [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)]) + starts = jnp.concatenate( + [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final]) + return ends - starts + + rng = jtu.rand_small(self.rng()) + args_maker = lambda: [ + rng(lhs_shape, dtype), + rng(rhs_shape, dtype), + jnp.array([make_group_sizes(m, num_groups) for _ in range(batch_size)]), + ] + lhs, rhs, group_sizes = args_maker() + + out_dtype = jnp.float32 + precision = jax.lax.Precision.HIGHEST + ragged_dot = partial( + jax.lax.ragged_dot, + preferred_element_type=out_dtype, + precision=precision, ) + tol = 1e-5 + + batch_res = jax.vmap(ragged_dot)(lhs, rhs, group_sizes) + for i in range(batch_size): + # The ragged_dot does not zero out the output in the case sum(group_sizes) + # < m, hence we need to compare only the valid part of the output. 
+ upper_bound = group_sizes[i].sum(axis=0) + ref_res = ragged_dot(lhs[i], rhs[i], group_sizes[i])[0:upper_bound, :] + self.assertArraysAllClose( + batch_res[i, 0:upper_bound, :], ref_res, rtol=tol, atol=tol + ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 55e408471ceaf5f0ed0e10053331d919fa2540ec Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 18:52:12 -0700 Subject: [PATCH 0098/1769] [JAX] [XLA:Python] Migrate xla_extension and its type stubs into jaxlib. Future changes will migrate many of its dependent modules. PiperOrigin-RevId: 739361786 --- jax/_src/lib/BUILD | 2 +- jaxlib/BUILD | 10 +- jaxlib/jax.bzl | 9 + jaxlib/tools/BUILD.bazel | 4 +- jaxlib/tools/build_wheel.py | 6 +- jaxlib/xla/BUILD | 111 +- jaxlib/xla/xla.cc | 965 +++++++++++++++ jaxlib/xla/xla_client.py | 2 +- jaxlib/xla/xla_extension/__init__.pyi | 1059 +++++++++++++++++ jaxlib/xla/xla_extension/config.pyi | 32 + jaxlib/xla/xla_extension/guard_lib.pyi | 46 + jaxlib/xla/xla_extension/ifrt_programs.pyi | 43 + jaxlib/xla/xla_extension/ifrt_proxy.pyi | 33 + jaxlib/xla/xla_extension/jax_jit.pyi | 76 ++ jaxlib/xla/xla_extension/mlir.pyi | 34 + jaxlib/xla/xla_extension/ops.pyi | 465 ++++++++ jaxlib/xla/xla_extension/pmap_lib.pyi | 83 ++ jaxlib/xla/xla_extension/profiler.pyi | 58 + jaxlib/xla/xla_extension/pytree.pyi | 158 +++ jaxlib/xla/xla_extension/sdy.pyi | 32 + .../xla/xla_extension/transfer_guard_lib.pyi | 39 + jaxlib/xla_extension.py | 17 + 22 files changed, 3268 insertions(+), 16 deletions(-) create mode 100644 jaxlib/xla/xla.cc create mode 100644 jaxlib/xla/xla_extension/__init__.pyi create mode 100644 jaxlib/xla/xla_extension/config.pyi create mode 100644 jaxlib/xla/xla_extension/guard_lib.pyi create mode 100644 jaxlib/xla/xla_extension/ifrt_programs.pyi create mode 100644 jaxlib/xla/xla_extension/ifrt_proxy.pyi create mode 100644 jaxlib/xla/xla_extension/jax_jit.pyi create mode 100644 jaxlib/xla/xla_extension/mlir.pyi create mode 100644 jaxlib/xla/xla_extension/ops.pyi create mode 100644 jaxlib/xla/xla_extension/pmap_lib.pyi create mode 100644 jaxlib/xla/xla_extension/profiler.pyi create mode 100644 jaxlib/xla/xla_extension/pytree.pyi create mode 100644 jaxlib/xla/xla_extension/sdy.pyi create mode 100644 jaxlib/xla/xla_extension/transfer_guard_lib.pyi create mode 100644 jaxlib/xla_extension.py diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 1f4f41132e9e..aa2d9cba4973 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -45,6 +45,7 @@ py_library_providing_imports_info( "//jaxlib:cpu_feature_guard", "//jaxlib:utils", "//jaxlib/xla:xla_client", + "//jaxlib/xla:xla_extension", "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", @@ -61,6 +62,5 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", - # xla_extension ]), ) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5f693c5384df..52c945482222 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -29,13 +29,6 @@ package( default_visibility = ["//jax:internal"], ) -# This makes xla_extension module accessible from jax._src.lib. 
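The generated one-line shim deleted below is superseded by a real
`jaxlib/xla_extension.py` created by this patch (its contents are not
reproduced in these hunks). A minimal sketch of what such a forwarding module
looks like, assuming it simply re-exports the relocated nanobind extension:

    # Hypothetical sketch of a forwarding shim like jaxlib/xla_extension.py;
    # the file actually added by this patch is not shown in these hunks.
    from jaxlib.xla.xla_extension import *  # noqa: F401,F403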
-genrule(
-    name = "xla_extension_py",
-    outs = ["xla_extension.py"],
-    cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
-)
-
 py_library_providing_imports_info(
     name = "jaxlib",
     srcs = [
@@ -51,8 +44,8 @@ py_library_providing_imports_info(
         "lapack.py",
         "plugin_support.py",
         "xla_client.py",
+        "xla_extension.py",
         ":version",
-        ":xla_extension_py",
     ],
     data = [":ffi_headers"],
     lib_rule = pytype_library,
@@ -82,6 +75,7 @@ py_library_providing_imports_info(
         "//jaxlib/mosaic",
         "//jaxlib/triton",
         "//jaxlib/xla:xla_client",
+        "//jaxlib/xla:xla_extension",
     ],
 )

diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index 4403915154bc..c6f55a86143f 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -610,3 +610,12 @@ def jax_py_test(
     if "PYTHONWARNINGS" not in env:
         env["PYTHONWARNINGS"] = "error"
     py_test(name = name, env = env, **kwargs)
+
+def if_oss(oss_value, google_value = []):
+    """Returns one of the arguments based on the non-configurable build env.
+
+    Specifically, it does not return a `select`, and can be used to e.g.
+    compute elements of list attributes.
+    """
+    _ = (google_value, oss_value)  # buildifier: disable=unused-variable
+    return oss_value
diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel
index afa5866e286d..2ddc9e90a702 100644
--- a/jaxlib/tools/BUILD.bazel
+++ b/jaxlib/tools/BUILD.bazel
@@ -62,11 +62,11 @@ py_binary(
         "//jaxlib",
         "//jaxlib:README.md",
         "//jaxlib:setup.py",
+        "//jaxlib/xla:xla_client.py",
+        "//jaxlib/xla:xla_extension",
         "@xla//xla/ffi/api:api.h",
         "@xla//xla/ffi/api:c_api.h",
         "@xla//xla/ffi/api:ffi.h",
-        "@xla//xla/python:xla_client.py",
-        "@xla//xla/python:xla_extension",
     ] + if_windows([
         "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
     ]),
diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py
index 0e0ce077cb23..9967fc14b9f9 100644
--- a/jaxlib/tools/build_wheel.py
+++ b/jaxlib/tools/build_wheel.py
@@ -110,7 +110,7 @@ def patch_copy_xla_extension_stubs(dst_dir):
   xla_extension_dir = os.path.join(dst_dir, "xla_extension")
   os.makedirs(xla_extension_dir)
   for stub_name in _XLA_EXTENSION_STUBS:
-    stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name)
+    stub_path = r.Rlocation("__main__/jaxlib/xla/xla_extension/" + stub_name)
     stub_path = str(stub_path)  # Make pytype accept os.path.exists(stub_path).
     if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
       continue
@@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack():
   if not _is_mac():
     return
   nm = subprocess.run(
-      ["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")],
+      ["nm", "-g", r.Rlocation("__main__/jaxlib/xla/xla_extension.so")],
       capture_output=True,
       text=True,
       check=False,
   )
@@ -198,7 +198,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
       "__main__/jaxlib/plugin_support.py",
       "__main__/jaxlib/version.py",
       "__main__/jaxlib/xla/xla_client.py",
-      f"xla/xla/python/xla_extension.{pyext}",
+      f"__main__/jaxlib/xla/xla_extension.{pyext}",
     ],
   )
   # This file is required by PEP-561.
It marks jaxlib as package containing diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 41152d642fc8..3239ba703937 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -14,6 +14,7 @@ load( "//jaxlib:jax.bzl", + "if_oss", "nanobind_extension", "py_deps", "py_strict_library", @@ -35,6 +36,114 @@ package_group( ], ) +nanobind_extension( + name = "xla_extension", + srcs = ["xla.cc"], + pytype_deps = py_deps(["numpy"]), + pytype_srcs = glob(["xla_extension/*.pyi"]), + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla/backends/cpu/collectives:cpu_collectives", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_api", + "@xla//xla/pjrt:pjrt_c_api_client", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:key_value_store_interface", + "@xla//xla/pjrt/distributed:protocol_proto_cc", + "@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/python:config", + "@xla//xla/python:custom_call_sharding", + "@xla//xla/python:dlpack", + "@xla//xla/python:guard_lib", + "@xla//xla/python:jax_jit", + "@xla//xla/python:logging", + "@xla//xla/python:mlir", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:ops", + "@xla//xla/python:pjit", + "@xla//xla/python:pmap_lib", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:profiler", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/python:sdy", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python:util", + "@xla//xla/python:weakref_lru_cache", + "@xla//xla/python:xla_compiler", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt_proxy/client:py_module", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/platform/cloud:gcs_file_system", + "@xla//xla/tsl/python/lib/core:numpy", + ] + select({ + # gloo tcp transport only builds on linux + "@xla//xla/tsl:macos": [ + "@gloo//:transport_uv", 
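+            # gloo's TCP transport only builds on Linux (per the comment
+            # above); macOS uses the libuv transport instead, mirroring the
+            # __APPLE__ branch in xla.cc below.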
+ "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + "@xla//xla/tsl:windows": [], + "//conditions:default": [ + "@gloo//:transport_tcp", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + "@xla//xla/python/transfer:py_socket_transfer", + ], + }) + select({ + # mpitrampoline does not build on windows + "@xla//xla/tsl:windows": [], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]), + }), +) + pytype_strict_library( name = "xla_client", srcs = ["xla_client.py"], @@ -43,7 +152,7 @@ pytype_strict_library( deps = py_deps([ "numpy", "ml_dtypes", - ]) + ["@xla//xla/python:xla_extension"], + ]) + [":xla_extension"], ) py_strict_test( diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc new file mode 100644 index 000000000000..5f39b9173b89 --- /dev/null +++ b/jaxlib/xla/xla.cc @@ -0,0 +1,965 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt_proxy/client/py_module.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/py_client.h" +#include "xla/python/py_program.h" +#include "xla/python/sdy.h" 
+#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT + +#if defined(__linux__) +#include "gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#include "xla/python/transfer/py_socket_transfer.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/backends/cpu/collectives/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/config.h" +#include "xla/python/custom_call_sharding.h" +#include "xla/python/dlpack.h" +#include "xla/python/guard_lib.h" +#include "xla/python/jax_jit.h" +#include "xla/python/logging.h" // IWYU pragma: keep +#include "xla/python/mlir.h" +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" +#include "xla/python/ops.h" +#include "xla/python/pjit.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pmap_lib.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/profiler.h" +#include "xla/python/py_array.h" +#include "xla/python/py_compile_only_client.h" +#include "xla/python/py_device.h" +#include "xla/python/py_device_list.h" +#include "xla/python/py_executable.h" +#include "xla/python/py_memory_space.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/python/traceback.h" +#include "xla/python/weakref_lru_cache.h" +#include "xla/python/xla_compiler.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/tsl/platform/status.h" +#include "tsl/platform/platform.h" + +// TODO(phawkins): remove host_id properties after JAX is update to avoid them. + +namespace xla { +namespace { + +namespace nb = nanobind; + +bool IsOptimizedBuild() { +#if NDEBUG + return true; +#else + return false; +#endif // NDEBUG +} + +// Is*san reports whether the build is under that particular sanitizer. +bool IsAsan() { +#if defined(ADDRESS_SANITIZER) + return true; +#else // defined(ADDRESS_SANITIZER) + return false; +#endif +} + +bool IsMsan() { +#if defined(MEMORY_SANITIZER) + return true; +#else // defined(MEMORY_SANITIZER) + return false; +#endif +} + +bool IsTsan() { +#if defined(THREAD_SANITIZER) + return true; +#else // defined(THREAD_SANITIZER) + return false; +#endif +} + +// IsSanitized reports whether the build is under any sanitizer. +bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } + +} // namespace + +NB_MODULE(xla_extension, m) { + // Initialize ABSL logging because code within XLA uses it. +#ifndef PLATFORM_GOOGLE + InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // We seem to get a fair number of leak warnings from nanobind. 
It's unclear + // whether these are false positives or not. + nb::set_leak_warnings(false); + + tsl::ImportNumpy(); + + // Exceptions + nb::exception xla_runtime_error(m, "XlaRuntimeError", + PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); + + // Types + nb::enum_(m, "PrimitiveType", nb::is_arithmetic()) + .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) + .value("PRED", PRED) + .value("S4", S4) + .value("S8", S8) + .value("S16", S16) + .value("S32", S32) + .value("S64", S64) + .value("U4", U4) + .value("U8", U8) + .value("U16", U16) + .value("U32", U32) + .value("U64", U64) + .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // .value("F8E3M4", F8E3M4) + // .value("F8E4M3", F8E4M3) + .value("F8E8M0FNU", F8E8M0FNU) + .value("F8E4M3FN", F8E4M3FN) + .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) + .value("F8E4M3FNUZ", F8E4M3FNUZ) + .value("F8E5M2", F8E5M2) + .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("BF16", BF16) + .value("F32", F32) + .value("F64", F64) + .value("C64", C64) + .value("C128", C128) + .value("TUPLE", TUPLE) + .value("OPAQUE_TYPE", OPAQUE_TYPE) + .value("TOKEN", TOKEN); + + // Must be before PyClient.compile. + BuildXlaCompilerSubmodule(m); + + PyDevice::RegisterPythonType(m); + PyMemorySpace::RegisterPythonType(m); + PyClient::RegisterPythonTypes(m); + + nb::enum_(m, "ArrayCopySemantics", + nb::is_arithmetic()) + .value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy) + .value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput) + .value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput); + + nb::class_(m, "PjRtLayout") + .def("__str__", &PjRtLayout::ToString) + .def("__eq__", [](const PjRtLayout& layout, + const PjRtLayout& other) { return layout == other; }) + .def("__hash__", + [](const PjRtLayout& layout) { return absl::HashOf(layout); }) + .def("_xla_layout", &PjRtLayout::xla_layout) + .def("__getstate__", + [](const PjRtLayout& layout) -> nb::tuple { + absl::StatusOr serialized = layout.Serialize(); + ThrowIfError(serialized.status()); + return nb::make_tuple( + nb::bytes(serialized->data(), serialized->size())); + }) + .def("__setstate__", [](PjRtLayout* self, nb::tuple t) { + nb::bytes serialized = nb::cast(t[0]); + absl::StatusOr> layout = + PjRtLayout::Deserialize( + absl::string_view(serialized.c_str(), serialized.size())); + ThrowIfError(layout.status()); + new (self) PjRtLayout((*layout)->xla_layout()); + }); + + jax::BuildWeakrefLRUCacheAPI(m); + + nb::class_ cpu_collectives(m, "CpuCollectives"); + + m.def( + "make_gloo_tcp_collectives", + [](std::shared_ptr distributed_client, + + std::optional hostname, + std::optional interface) + -> std::shared_ptr { +#if defined(__linux__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto tcp_attrs = gloo::transport::tcp::attr(); + if (hostname) { + tcp_attrs.hostname = *hostname; + } + if (interface) { + tcp_attrs.iface = *interface; + } + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(tcp_device)); +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != 
nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); +#else // defined(__linux__) + throw xla::XlaRuntimeError( + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) + }, + nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, + nb::arg("interface").none() = std::nullopt); + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m.def("make_mpi_collectives", []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + + m.def( + "get_tfrt_cpu_client", + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes, + std::shared_ptr collectives, + std::optional num_devices) -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + xla::CpuClientOptions options; + + options.asynchronous = asynchronous; + options.collectives = std::move(collectives); + options.process_id = node_id; + options.cpu_device_count = num_devices; + std::unique_ptr client = + xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options))); + ifrt::PjRtClient::CreateOptions ifrt_options; + ifrt_options.pjrt_client = + std::shared_ptr(std::move(client)); + if (distributed_client != nullptr) { + ifrt_options.kv_store = + GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + ifrt_options.process_id = node_id; + ifrt_options.num_processes = num_nodes; + } + ifrt_client = + ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options))); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, + nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, + nb::arg("collectives").none() = + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt); + m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); + return pjrt_api.ok(); + }); + m.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> nb::capsule { + if (library_path.has_value()) { + const PJRT_Api* api = xla::ValueOrThrow( + pjrt::LoadPjrtPlugin(platform_name, *library_path)); + return nb::capsule(absl::bit_cast(api), "pjrt_c_api"); + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw nb::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(c_api->data()))); + return *c_api; + }, + nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, + nb::arg("c_api").none() = std::nullopt); + m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + return 
xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); + }); + m.def("initialize_pjrt_plugin", [](std::string platform_name) { + return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); + }); + + m.def( + "get_c_api_client", + [](std::string platform_name, + const absl::flat_hash_map& options, + std::shared_ptr distributed_client) + -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore( + distributed_client, + /*key_prefix=*/absl::StrCat(platform_name, ":")); + } + std::unique_ptr c_api_client = xla::ValueOrThrow( + GetCApiClient(platform_name, options, kv_store)); + ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client)); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("platform_name"), + nb::arg("options") = absl::flat_hash_map(), + nb::arg("distributed_client").none() = nullptr); + // TODO(b/322357665): Delete this method after TPU plugin changes to use the + // standard registration. + m.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(platform_name, topology_name, options))); + }); + m.def("get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(static_cast(c_api.data()), + topology_name, options))); + }); + m.def("get_topology_for_devices", + [](const std::vector>& py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto& py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); + } + ifrt_devices.push_back(py_device->device()); + } + ifrt::DeviceListRef device_list = + client->ifrt_client()->MakeDeviceList(ifrt_devices); + return xla::ValueOrThrow( + client->ifrt_client()->GetTopologyForDevices(device_list)); + }); + + TF_CHECK_OK(PyArray::RegisterTypes(m)); + jax::PyDeviceList::Register(m); + jax::RegisterSharding(m); + + nb::class_(m, "CompiledMemoryStats") + .def_rw("generated_code_size_in_bytes", + &CompiledMemoryStats::generated_code_size_in_bytes) + .def_rw("argument_size_in_bytes", + &CompiledMemoryStats::argument_size_in_bytes) + .def_rw("output_size_in_bytes", + &CompiledMemoryStats::output_size_in_bytes) + .def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes) + .def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes) + .def_rw("host_generated_code_size_in_bytes", + &CompiledMemoryStats::host_generated_code_size_in_bytes) + .def_rw("host_argument_size_in_bytes", + &CompiledMemoryStats::host_argument_size_in_bytes) + .def_rw("host_output_size_in_bytes", + &CompiledMemoryStats::host_output_size_in_bytes) + .def_rw("host_alias_size_in_bytes", + &CompiledMemoryStats::host_alias_size_in_bytes) + .def_rw("host_temp_size_in_bytes", + 
&CompiledMemoryStats::host_temp_size_in_bytes) + .def_prop_ro("serialized_hlo_proto", + [](const CompiledMemoryStats& cms) -> nb::bytes { + return nb::bytes(cms.serialized_hlo_proto.data(), + cms.serialized_hlo_proto.size()); + }) + .def("__str__", &CompiledMemoryStats::DebugString); + + nb::class_(m, "ExecuteResults") + .def("__len__", [](PyExecuteResults& results) { return results.Size(); }) + .def("disassemble_into_single_device_arrays", + &PyExecuteResults::DisassembleIntoSingleDeviceArrays) + .def("disassemble_prefix_into_single_device_arrays", + &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) + .def("consume_token", &PyExecuteResults::ConsumeToken); + + nb::class_(m, "LoadedExecutable") + .def_prop_ro("client", &PyLoadedExecutable::client) + .def("local_devices", &PyLoadedExecutable::AddressableDevices) + .def("size_of_generated_code_in_bytes", + &PyLoadedExecutable::SizeOfGeneratedCodeInBytes) + .def( + "get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) + .def("delete", &PyLoadedExecutable::Delete) + .def("execute_sharded_on_local_devices", + xla::ValueOrThrowWrapper( + &PyLoadedExecutable::ExecuteShardedOnLocalDevices), + nb::arg("arguments")) + .def("execute_sharded_on_local_devices_with_tokens", + xla::ValueOrThrowWrapper( + &PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens), + nb::arg("arguments")) + // TODO(parkers): Switch execute_sharded_on_local_devices* to this. + .def("execute_sharded", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), + nb::arg("arguments"), nb::arg("with_tokens") = false) + .def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) + .def("get_parameter_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", + &PyLoadedExecutable::GetParameterShardings) + .def("keep_alive", &PyLoadedExecutable::KeepAlive) + .def("cost_analysis", + [](const PyLoadedExecutable& self) { + auto map = ValueOrThrow(self.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(map)); + }) + .def_prop_ro("traceback", &PyLoadedExecutable::traceback) + .def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object { + if (exec->fingerprint().has_value()) { + return nb::bytes(exec->fingerprint()->data(), + exec->fingerprint()->size()); + } else { + return nb::none(); + } + }); + nb::class_ token(m, "Token"); + token.def("block_until_ready", + [](PyToken& self) { xla::ThrowIfError(self.Await()); }); + + nb::class_ sharded_token(m, "ShardedToken"); + sharded_token.def("block_until_ready", [](PyShardedToken& self) { + xla::ThrowIfError(self.Await()); + }); + sharded_token.def("get_token", &PyShardedToken::GetPyToken); + + m.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, nb_class_ptr device, + std::optional stream) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, device->device(), device->client(), stream)); + }, + nb::arg("dlpack"), nb::arg("device"), 
nb::arg("stream").none()); + // Legacy overload + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, std::move(cpu_client), std::move(gpu_client))); + }, + nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(), + nb::arg("gpu_backend").none() = nb::none()); + m.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), + nb::arg("gpu_backend").none() = nb::none(), + nb::arg("device_id").none() = nb::none()); + + jax::BuildConfigSubmodule(m); + BuildIfrtProgramsSubmodule(m); + BuildProfilerSubmodule(m); + BuildOpsSubmodule(m); + BuildPytreeSubmodule(m); + jax::BuildGuardSubmodule(m); + jax::BuildJaxjitSubmodule(m); + jax::BuildPmapSubmodule(m); + jax::BuildPjitSubmodule(m); + BuildTracebackSubmodule(m); + BuildMlirSubmodule(m); + BuildSdySubmodule(m); + BuildCustomCallShardingPybindAPI(m); +#if defined(__linux__) + aux::RegisterTransferServerTypes(m); +#endif // defined(__linux__) + + // The following uses python bindings for PyClient defined above using + // pybind11, and hence needs pybind11::module_ (not just nanobind::module_). + xla::ifrt::proxy::BuildIfrtProxySubmodule(m); + + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](tsl::PreemptionSyncManager& manager, + DistributedRuntimeClient* client) { + tsl::CoordinationServiceAgent* agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](tsl::PreemptionSyncManager& manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }); + m.def("create_preemption_sync_manager", + []() { return tsl::CreatePreemptionSyncManager(); }); + + nb::class_ distributed_runtime_service( + m, "DistributedRuntimeService"); + distributed_runtime_service.def("shutdown", + &DistributedRuntimeService::Shutdown, + nb::call_guard()); + nb::class_ distributed_runtime_client( + m, "DistributedRuntimeClient"); + distributed_runtime_client + .def("connect", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Connect()); + }) + .def("shutdown", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Shutdown()); + }) + // This method assumes that the value is a Python string. Use + // `blocking_key_value_get_bytes()` if key_value_set() was called with a + // Python bytes object as its value. + .def( + "blocking_key_value_get", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + // Same as `blocking_key_value_get()`, but retrieves the raw Python byte + // values explicitly. 
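+      // For example, from Python (a sketch; 1000 is an arbitrary timeout in
+      // milliseconds):
+      //   client.key_value_set_bytes("payload", b"\x00\x01")
+      //   client.blocking_key_value_get_bytes("payload", 1000)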
+ .def( + "blocking_key_value_get_bytes", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](DistributedRuntimeClient& client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) + .def( + "wait_at_barrier", + [](DistributedRuntimeClient& client, std::string barrier_id, + int64_t timeout_in_ms, + std::optional> process_ids) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.WaitAtBarrier( + barrier_id, absl::Milliseconds(timeout_in_ms), process_ids)); + }, + nb::arg("barrier_id"), nb::arg("timeout_in_ms"), + nb::arg("process_ids") = std::nullopt) + .def( + "get_live_nodes", + [](DistributedRuntimeClient& client, + std::vector process_ids) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.GetLiveNodes(process_ids)); + }, + nb::arg("process_ids")) + // The key must be a string, but the value can either be a Python string + // or bytes object. + // With Python string values, use `key_value_set()` and + // `blocking_key_value_get()`. + // With Python byte object values, use `key_value_set()` and + // `blocking_key_value_get_bytes()`. + .def( + "key_value_set", + [](DistributedRuntimeClient& client, absl::string_view key, + absl::string_view value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // The key must be a string, but the value must a + // Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](DistributedRuntimeClient& client, absl::string_view key, + nb::bytes value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, absl::string_view(value.c_str(), value.size()), + allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // Assumes that all values in the directory are Python strings. + .def( + "key_value_dir_get", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueDirGet(key)); + }, + nb::arg("key")) + // Assumes that all values in the directory are Python byte objects. + // Same as `key_value_dir_get()`, but retrieves Python byte values + // explicitly. + .def( + "key_value_dir_get_bytes", + [](DistributedRuntimeClient& client, absl::string_view key) + -> std::vector> { + std::vector> result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueDirGet(key)); + } + // Convert std::string values to nb::bytes. 
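+          // (This runs after the scoped GIL release above has ended, since
+          // constructing nb::bytes objects requires holding the GIL.)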
+ std::vector> kvs; + kvs.reserve(result.size()); + for (auto& kv : result) { + kvs.push_back( + std::pair(std::move(kv.first), + nb::bytes(kv.second.data(), kv.second.size()))); + } + return kvs; + }, + nb::arg("key")) + .def( + "key_value_delete", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ThrowIfError(client.KeyValueDelete(key)); + }, + nb::arg("key")); + + m.def( + "get_distributed_runtime_service", + [](std::string address, int num_nodes, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional cluster_register_timeout, + std::optional shutdown_timeout) + -> std::unique_ptr { + CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (cluster_register_timeout.has_value()) { + options.cluster_register_timeout = + absl::Seconds(*cluster_register_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + std::unique_ptr service = + xla::ValueOrThrow(GetDistributedRuntimeService(address, options)); + return service; + }, + nb::arg("address"), nb::arg("num_nodes"), + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("cluster_register_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt); + + m.def( + "get_distributed_runtime_client", + [](std::string address, int node_id, std::optional rpc_timeout, + std::optional init_timeout, std::optional shutdown_timeout, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional> + missed_heartbeat_callback, + std::optional shutdown_on_destruction, + std::optional use_compression) + -> std::shared_ptr { + bool compression = use_compression.value_or(false); + DistributedRuntimeClient::Options options; + options.node_id = node_id; + if (rpc_timeout.has_value()) { + options.rpc_timeout = absl::Seconds(*rpc_timeout); + } + if (init_timeout.has_value()) { + options.init_timeout = absl::Seconds(*init_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (missed_heartbeat_callback.has_value()) { + options.missed_heartbeat_callback = + std::move(*missed_heartbeat_callback); + } + if (shutdown_on_destruction.has_value()) { + options.shutdown_on_destruction = *shutdown_on_destruction; + } + return GetDistributedRuntimeClient(address, options, compression); + }, + nb::arg("address"), nb::arg("node_id"), + nb::arg("rpc_timeout").none() = std::nullopt, + nb::arg("init_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt, + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("missed_heartbeat_callback").none() = std::nullopt, + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); + + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + + 
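+  // A minimal Python-side sketch of the distributed-runtime plumbing bound
+  // above (the address, port, and node counts are illustrative only):
+  //   svc = xla_extension.get_distributed_runtime_service("[::]:12345",
+  //                                                       num_nodes=1)
+  //   client = xla_extension.get_distributed_runtime_client("localhost:12345",
+  //                                                         node_id=0)
+  //   client.connect()
+  //   client.key_value_set("k", "v")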
m.def("is_optimized_build", &IsOptimizedBuild); + + m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); + + RegisterCompileOnlyClient(m); + nb::class_(m, "DeviceTopology") + .def("_make_compile_only_devices", + [](std::shared_ptr topology) { + if (!llvm::isa(*topology)) { + throw xla::XlaRuntimeError("Only PjRtTopologies are supported."); + } + return MakeCompileOnlyClient( + std::dynamic_pointer_cast(topology)) + ->Devices(); + }) + .def_prop_ro( + "platform", + [](ifrt::Topology& topology) { return topology.platform_name(); }) + .def_prop_ro( + "platform_version", + [](ifrt::Topology& topology) { return topology.platform_version(); }) + .def("serialize", + [](ifrt::Topology& topology) -> nb::bytes { + std::string serialized = ValueOrThrow(topology.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("__getattr__", + [](ifrt::Topology& topology, absl::string_view name) -> nb::object { + const auto& attrs = topology.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); + + nb::class_(m, "Executable") + .def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds)) + .def("get_output_shardings", &ifrt::Executable::GetOutputShardings) + .def("get_parameter_layouts", + ValueOrThrowWrapper(&ifrt::Executable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputLayouts)) + .def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings) + .def("get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats)) + .def("serialize", + [](const ifrt::Executable& exec) -> nb::bytes { + std::string serialized = ValueOrThrow(exec.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("cost_analysis", [](const ifrt::Executable& exec) { + auto attrs = ValueOrThrow(exec.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(attrs)); + }); + + m.def("is_asan", IsAsan); + m.def("is_msan", IsMsan); + m.def("is_tsan", IsTsan); + m.def("is_sanitized", IsSanitized); + + m.def( + "batched_device_put", + [](nb::object aval, nb::object sharding, std::vector xs, + std::vector dst_devices, bool committed, + bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object { + return ValueOrThrow(PyArray::BatchedDevicePut( + aval, sharding, std::move(xs), std::move(dst_devices), committed, + force_copy, host_buffer_semantics, jax::GetEnableX64())); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), + nb::arg("committed") = true, nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + m.def( + "reorder_shards", + [](PyArray x, nb::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + return ValueOrThrow(PyArray::ReorderShards( + std::move(x), std::move(dst_sharding), array_copy_semantics)); + }, + nb::arg("x"), 
nb::arg("dst_sharding"), nb::arg("array_copy_semantics")); + + m.def("batched_block_until_ready", [](std::vector xs) { + ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); + }); + + m.def("check_and_canonicalize_memory_kind", + &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), + nb::arg("device_list")); +} // NOLINT(readability/fn_size) + +} // namespace xla diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index b6c5707d05dd..a111c14232de 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -19,7 +19,7 @@ import atexit from collections.abc import Mapping, Sequence import contextlib -import enum # pylint: disable=g-bad-import-order +import enum import gzip import inspect import logging diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi new file mode 100644 index 000000000000..3a6435824b67 --- /dev/null +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -0,0 +1,1059 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import enum +import inspect +import types +import typing +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import numpy as np + +from . import config as config +from . import guard_lib as guard_lib +from . import ifrt_programs as ifrt_programs +from . import ifrt_proxy as ifrt_proxy +from . import jax_jit as jax_jit +from . import mlir as mlir +from . import ops as ops +from . import pmap_lib as pmap_lib +from . import profiler as profiler +from . import pytree as pytree +from . 
import transfer_guard_lib as transfer_guard_lib + +custom_call_targets = Any +hlo_sharding_util = Any + +_LiteralSlice = Any +_Status = Any +_Dtype = Any +_XlaOpMetadata = Any + +_T = TypeVar("_T") + +class XlaRuntimeError(RuntimeError): + pass + +class PrimitiveType(enum.IntEnum): + PRIMITIVE_TYPE_INVALID: PrimitiveType + PRED: PrimitiveType + S2: PrimitiveType + S4: PrimitiveType + S8: PrimitiveType + S16: PrimitiveType + S32: PrimitiveType + S64: PrimitiveType + U2: PrimitiveType + U4: PrimitiveType + U8: PrimitiveType + U16: PrimitiveType + U32: PrimitiveType + U64: PrimitiveType + F4E2M1FN: PrimitiveType + F8E3M4: PrimitiveType + F8E4M3: PrimitiveType + F8E4M3FN: PrimitiveType + F8E4M3B11FNUZ: PrimitiveType + F8E4M3FNUZ: PrimitiveType + F8E5M2: PrimitiveType + F8E5M2FNUZ: PrimitiveType + F8E8M0FNU: PrimitiveType + BF16: PrimitiveType + F16: PrimitiveType + F32: PrimitiveType + F64: PrimitiveType + C64: PrimitiveType + C128: PrimitiveType + TUPLE: PrimitiveType + OPAQUE_TYPE: PrimitiveType + TOKEN: PrimitiveType + +# === BEGIN xla_compiler.cc + +class ArrayCopySemantics(enum.IntEnum): + ALWAYS_COPY: ArrayCopySemantics + REUSE_INPUT: ArrayCopySemantics + DONATE_INPUT: ArrayCopySemantics + +class Layout: + @overload + def __init__(self, minor_to_major: Tuple[int, ...]): ... + @overload + def __init__(self, minor_to_major: Tuple[int, ...], + tiling: Tuple[Tuple[int, ...], ...], + element_size_in_bits: int): ... + def minor_to_major(self) -> Tuple[int, ...]: ... + def tiling(self) -> Sequence[Tuple[int, ...]]: ... + def element_size_in_bits(self) -> int: ... + def to_string(self) -> str: ... + def __eq__(self, other: Layout) -> bool: ... + def __ne__(self, other: Layout) -> bool: ... + def __hash__(self) -> int: ... + +class Shape: + def __init__(self, s: str): ... + @staticmethod + def tuple_shape(shapes: Sequence[Shape]) -> Shape: ... + @staticmethod + def array_shape( + type: Union[np.dtype, PrimitiveType], + dims_seq: Any = ..., + layout_seq: Any = ..., + dynamic_dimensions: Optional[List[bool]] = ..., + ) -> Shape: ... + @staticmethod + def token_shape() -> Shape: ... + @staticmethod + def scalar_shape(type: Union[np.dtype, PrimitiveType]) -> Shape: ... + def dimensions(self) -> Tuple[int, ...]: ... + def layout(self) -> Layout: ... + def xla_element_type(self) -> PrimitiveType: ... + def element_type(self) -> np.dtype: ... + def numpy_dtype(self) -> np.dtype: ... + def is_tuple(self) -> bool: ... + def is_array(self) -> bool: ... + def is_token(self) -> bool: ... + def is_static(self) -> bool: ... + def is_dynamic(self) -> bool: ... + def is_dynamic_dimension(self, dimension: int) -> bool: ... + def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ... + def rank(self) -> int: ... + def to_serialized_proto(self) -> bytes: ... + def tuple_shapes(self) -> List[Shape]: ... + def leaf_count(self) -> int: ... + def with_major_to_minor_layout_if_absent(self) -> Shape: ... + def __eq__(self, other: Shape) -> bool: ... + def __ne__(self, other: Shape) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + +class ProgramShape: + def __init__(self, params: Sequence[Shape], result: Shape) -> None: ... + def parameter_shapes(self) -> List[Shape]: ... + def result_shape(self) -> Shape: ... + def __repr__(self) -> str: ... + +class ShapeIndex: + def __init__(self, indices: List[int]) -> ShapeIndex: ... + def __eq__(self, other: Shape) -> bool: ... + def __ne__(self, other: Shape) -> bool: ... + def __hash__(self) -> int: ... 
+ def __repr__(self) -> str: ... + +class Literal: + def __init__(self, shape: Shape) -> Literal: ... + def __repr__(self) -> str: ... + def __array__( + self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None + ) -> np.ndarray: ... + def shape(self) -> Shape: ... + +class XlaComputation: + def __init__(self, serialized_hlo_module_proto: bytes) -> None: ... + def get_hlo_module(self) -> HloModule: ... + def program_shape(self) -> ProgramShape: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + def as_hlo_text(self, print_large_constants: bool = False) -> str: ... + def as_hlo_dot_graph(self) -> str: ... + def hash(self) -> int: ... + def as_hlo_module(self) -> HloModule: ... + +class HloPrintOptions: + def __init__(self) -> None: ... + @staticmethod + def short_parsable() -> HloPrintOptions: ... + @staticmethod + def canonical() -> HloPrintOptions: ... + @staticmethod + def fingerprint() -> HloPrintOptions: ... + print_large_constants: bool + print_metadata: bool + print_backend_config: bool + print_result_shape: bool + print_operand_shape: bool + print_operand_names: bool + print_ids: bool + print_extra_attributes: bool + print_program_shape: bool + print_percent: bool + print_control_dependencies: bool + compact_operands: bool + include_layout_in_shapes: bool + canonicalize_instruction_names: bool + canonicalize_computations: bool + indent_amount: int + is_in_nested_computation: bool + +class HloComputation: + def render_html(self) -> None: ... + +class HloModule: + spmd_output_sharding: Optional[OpSharding] + spmd_parameters_shardings: Optional[List[OpSharding]] + @property + def name(self) -> str: ... + def to_string(self, options: HloPrintOptions = ...) -> str: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + @staticmethod + def from_serialized_hlo_module_proto( + serialized_hlo_module_proto: bytes, + ) -> HloModule: ... + def computations(self) -> List[HloComputation]: ... + +class HloModuleGroup: + def __init__(self, name: str, modules: List[HloModule]) -> None: ... + @property + def name(self) -> str: ... + def to_string(self) -> str: ... + def to_modules(self) -> List[HloModule]: ... + +def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... +def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... +def hlo_module_cost_analysis( + client: Client, module: HloModule +) -> Dict[str, float]: ... + +class XlaOp: ... + +class XlaBuilder: + def __init__(self, name: str) -> None: ... + def Build(self, root: Optional[XlaOp] = ...) -> XlaComputation: ... + def GetShape(self, __op: XlaOp) -> Shape: ... + build = Build + def clear_op_metadata(self) -> None: ... + get_shape = GetShape + def get_program_shape(self, root: Optional[XlaOp] = ...) -> ProgramShape: ... + def is_constant(self, __op: XlaOp) -> bool: ... + def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ... + def set_sharding(self, sharding: OpSharding_Type) -> None: ... + def clear_sharding(self) -> None: ... + def setup_alias( + self, + __output_index: Sequence[int], + __param_number: int, + __param_index: Sequence[int], + ) -> None: ... + +class DeviceAssignment: + @staticmethod + def create(array: np.ndarray) -> DeviceAssignment: ... + def replica_count(self) -> int: ... + def computation_count(self) -> int: ... + def __repr__(self) -> str: ... + def serialize(self) -> bytes: ... + +class CompileOptions: + @staticmethod + def ParseFromString(s: bytes) -> CompileOptions: ... + def __init__(self) -> None: ... + def SerializeAsString(self) -> bytes: ... 
+ argument_layouts: Optional[List[Shape]] + parameter_is_tupled_arguments: bool + executable_build_options: ExecutableBuildOptions + tuple_arguments: bool + num_replicas: int + num_partitions: int + profile_version: int + device_assignment: Optional[DeviceAssignment] + compile_portable_executable: bool + env_option_overrides: List[Tuple[str, str]] + +def register_custom_call_target( + fn_name: str, capsule: Any, platform: str, api_version: int = ..., +) -> _Status: ... +def register_custom_call_partitioner( + name: str, + prop_user_sharding: Callable, + partition: Callable, + infer_sharding_from_operands: Callable, + can_side_effecting_have_replicated_sharding: bool = ..., + c_api: Optional[Any] = ..., +) -> None: ... +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... +def register_custom_call_as_batch_partitionable( + target_name: str, + c_api: Optional[Any] = ..., +) -> None: ... + +def register_custom_type_id(type_name: str, type_id: Any) -> None: ... + +class AutotuneCacheMode(enum.IntEnum): + UNSPECIFIED: AutotuneCacheMode + UPDATE: AutotuneCacheMode + READ: AutotuneCacheMode + +class DebugOptions: + def __repr__(self) -> str: ... + xla_cpu_enable_fast_math: bool + xla_cpu_fast_math_honor_infs: bool + xla_cpu_fast_math_honor_nans: bool + xla_cpu_fast_math_honor_division: bool + xla_cpu_fast_math_honor_functions: bool + xla_gpu_enable_fast_min_max: bool + xla_backend_optimization_level: int + xla_cpu_enable_xprof_traceme: bool + xla_llvm_disable_expensive_passes: bool + xla_test_all_input_layouts: bool + xla_disable_hlo_passes: str + xla_enable_hlo_passes_only: str + xla_force_host_platform_device_count: int + xla_dump_to: str + xla_dump_hlo_module_re: str + xla_dump_hlo_pass_re: str + xla_dump_hlo_as_text: bool + xla_dump_hlo_as_proto: bool + xla_dump_hlo_as_dot: bool + xla_dump_hlo_as_url: bool + xla_dump_hlo_as_html: bool + xla_dump_fusion_visualization: bool + xla_dump_hlo_snapshots: bool + xla_dump_max_hlo_modules: bool + xla_dump_module_metadata: bool + xla_dump_compress_protos: bool + xla_dump_hlo_as_long_text: bool + xla_dump_disable_metadata: bool + xla_dump_hlo_pipeline_re: str + xla_gpu_cuda_data_dir: str + xla_detailed_logging: bool + xla_enable_dumping: bool + xla_gpu_dump_autotune_results_to: str + xla_gpu_load_autotune_results_from: str + xla_gpu_dump_autotune_logs_to: str + xla_gpu_kernel_cache_file: str + xla_gpu_enable_llvm_module_compilation_parallelism: bool + xla_gpu_per_fusion_autotune_cache_dir: str + xla_gpu_experimental_autotune_cache_mode: AutotuneCacheMode + +class CompiledMemoryStats: + generated_code_size_in_bytes: int + argument_size_in_bytes: int + output_size_in_bytes: int + alias_size_in_bytes: int + temp_size_in_bytes: int + host_generated_code_size_in_bytes: int + host_argument_size_in_bytes: int + host_output_size_in_bytes: int + host_alias_size_in_bytes: int + host_temp_size_in_bytes: int + serialized_hlo_proto: bytes + def __str__(self) -> str: ... + +class ExecutableBuildOptions: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + result_layout: Optional[Shape] + fdo_profile: Optional[bytes] + num_replicas: int + num_partitions: int + debug_options: DebugOptions + device_assignment: Optional[DeviceAssignment] + use_spmd_partitioning: bool + use_auto_spmd_partitioning: bool + auto_spmd_partitioning_mesh_shape: List[int] + auto_spmd_partitioning_mesh_ids: List[int] + use_shardy_partitioner: bool + def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ... 
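To see how these option classes nest, here is a minimal usage sketch. It touches only attributes declared in the stubs above; the `xla_client` import path is an assumption about how the extension is typically reached from a JAX checkout.

    # Sketch: CompileOptions owns ExecutableBuildOptions, which owns DebugOptions.
    # Import path is an assumption; attribute names come from the stubs above.
    from jax._src.lib import xla_client

    options = xla_client.CompileOptions()
    options.num_replicas = 1
    options.num_partitions = 1

    build_opts = options.executable_build_options   # ExecutableBuildOptions
    build_opts.use_spmd_partitioning = False

    debug_opts = build_opts.debug_options           # DebugOptions
    debug_opts.xla_dump_to = "/tmp/xla_dump"        # dump HLO for inspection
    debug_opts.xla_dump_hlo_as_text = True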
+ +class PrecisionConfig_Precision(enum.IntEnum): + DEFAULT: int + HIGH: int + HIGHEST: int + + +class ResultAccuracy_Mode(enum.IntEnum): + DEFAULT: int + HIGHEST: int + TOLERANCE: int + +class ResultAccuracy: + mode: ResultAccuracy_Mode + atol: float + rtol: float + ulps: int + +class OpSharding_Type(enum.IntEnum): + REPLICATED: int + MAXIMAL: int + TUPLE: int + OTHER: int + MANUAL: int + UNKNOWN: int + +class OpSharding_ShardGroupType(enum.IntEnum): + AS: int + LIKE: int + +class OpSharding: + Type: typing.Type[OpSharding_Type] + type: OpSharding_Type + replicate_on_last_tile_dim: bool + last_tile_dims: Sequence[Type] + tile_assignment_dimensions: Sequence[int] + tile_assignment_devices: Sequence[int] + iota_reshape_dims: Sequence[int] + iota_transpose_perm: Sequence[int] + tuple_shardings: Sequence[OpSharding] + is_shard_group: bool + shard_group_id: int + ShardGroupType: typing.Type[OpSharding_ShardGroupType] + shard_group_type: OpSharding_ShardGroupType + def ParseFromString(self, s: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + def clone(self) -> OpSharding: ... + +class HloSharding: + @staticmethod + def from_proto(proto: OpSharding) -> HloSharding: ... + @staticmethod + def from_string(sharding: str) -> HloSharding: ... + @staticmethod + def tuple_sharding( + shape: Shape, shardings: Sequence[HloSharding] + ) -> HloSharding: ... + @staticmethod + def iota_tile( + dims: Sequence[int], + reshape_dims: Sequence[int], + transpose_perm: Sequence[int], + subgroup_types: Sequence[OpSharding.Type], + ) -> HloSharding: ... + @staticmethod + def replicate() -> HloSharding: ... + @staticmethod + def manual() -> HloSharding: ... + @staticmethod + def unknown() -> HloSharding: ... + @staticmethod + def subgroup_with_device_ordering( + tile_assignment: np.ndarray, + subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... + def __eq__(self, other: HloSharding) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + def tile(self, shape: Shape) -> Shape: ... + def is_replicated(self) -> bool: ... + def is_manual(self) -> bool: ... + def is_unknown(self) -> bool: ... + def is_tiled(self) -> bool: ... + def is_maximal(self) -> bool: ... + def tuple_elements(self) -> List[HloSharding]: ... + def num_devices(self) -> int: ... + def num_dimensions(self) -> int: ... + def tile_assignment_dimensions(self) -> Sequence[int]: ... + def tile_assignment_devices(self) -> Sequence[int]: ... + def subgroup_types(self) -> Sequence[OpSharding.Type]: ... + def replicate_on_last_tile_dim(self) -> bool: ... + def to_proto(self) -> OpSharding: ... + +class FftType(enum.IntEnum): + FFT: FftType + IFFT: FftType + RFFT: FftType + IRFFT: FftType + +# === END xla_compiler.cc + +class Device: + id: int + host_id: int + process_index: int + platform: str + device_kind: str + client: Client + local_hardware_id: int | None + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def transfer_to_infeed(self, literal: _LiteralSlice): ... + def transfer_from_outfeed(self, shape: Shape): ... + def memory(self, kind: str) -> Memory: ... + def default_memory(self) -> Memory: ... + def addressable_memories(self) -> List[Memory]: ... + def live_buffers(self) -> List[Any]: ... + def memory_stats(self) -> Optional[Dict[str, int]]: ... + def get_stream_for_external_ready_events(self) -> int: ... + def __getattr__(self, name: str) -> Any: ... + +class Memory: + process_index: int + platform: str + kind: str + def __repr__(self) -> str: ... + def __str__(self) -> str: ... 
+ def addressable_by_devices(self) -> List[Device]: ... + +class PjRtLayout: + def __str__(self) -> str: ... + def __eq__(self, other: PjRtLayout) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, _: Any): ... + def _xla_layout(self) -> Layout: ... + +class GpuAllocatorConfig: + class Kind(enum.IntEnum): + DEFAULT: int + PLATFORM: int + BFC: int + CUDA_ASYNC: int + + def __init__( + self, + kind: Kind = ..., + memory_fraction: float = ..., + preallocate: bool = ..., + collective_memory_size: int = ..., + ) -> None: ... + +class HostBufferSemantics(enum.IntEnum): + IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics + IMMUTABLE_UNTIL_TRANSFER_COMPLETES: HostBufferSemantics + ZERO_COPY: HostBufferSemantics + +class Client: + platform: str + _raw_platform: str + platform_version: str + runtime_type: str + def device_count(self) -> int: ... + def local_device_count(self) -> int: ... + def devices(self) -> List[Device]: ... + def local_devices(self) -> List[Device]: ... + def _get_all_devices(self) -> List[Device]: ... + def device_from_local_hardware_id(self, int) -> Device: ... + def live_buffers(self) -> List[Any]: ... + def live_arrays(self) -> List[ArrayImpl]: ... + def live_executables(self) -> List[LoadedExecutable]: ... + def host_id(self) -> int: ... + def process_index(self) -> int: ... + def buffer_from_pyval( + self, + argument: Any, + device: Optional[Device] = ..., + force_copy: bool = ..., + host_buffer_semantics: HostBufferSemantics = ..., + ) -> ArrayImpl: ... + def compile( + self, + computation: Union[str, bytes], + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def compile_ifrt_program( + self, + program: ifrt_programs.Program, + program_options: ifrt_programs.CompileOptions, + ) -> LoadedExecutable: ... + def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... + def deserialize_executable( + self, + serialized: bytes, + options: Optional[CompileOptions], + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def heap_profile(self) -> bytes: ... + def defragment(self) -> _Status: ... + def get_emit_python_callback_descriptor( + self, + callable: Callable, + operand_shapes: Sequence[Shape], + results_shapes: Sequence[Shape], + ) -> Tuple[Any, Any]: ... + def make_python_callback_from_host_send_and_recv( + self, + callable: Callable, + operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], + send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], + serializer: Optional[Callable] = ..., + ) -> Any: ... + def get_default_layout( + self, dtype: np.dtype, shard_shape: Sequence[int], device: Device + ) -> PjRtLayout: ... + def __getattr__(self, name: str) -> Any: ... + +class CpuCollectives: ... + +def make_gloo_tcp_collectives( + distributed_client: Optional[DistributedRuntimeClient] = ..., + hostname: Optional[str] = ..., + interface: Optional[str] = ..., +) -> CpuCollectives: ... + +class MpiCollectives(CpuCollectives): + def Init(self): ... + def Finalize(self): ... + +def make_mpi_collectives() -> MpiCollectives: ... + +def get_tfrt_cpu_client( + asynchronous: bool = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: Optional[CpuCollectives] = ..., + num_devices: int | None = ..., +) -> Client: ... 
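As a quick orientation for the Client surface, a sketch that stands up a local CPU client; the direct module import is an assumption (JAX itself goes through xla_client), and everything called here is declared in the stubs above.

    # Sketch: create a TFRT CPU client and inspect its devices.
    from jaxlib.xla import xla_extension as xe  # import path assumed

    client = xe.get_tfrt_cpu_client(asynchronous=True, num_devices=2)
    print(client.platform)         # "cpu"
    print(client.device_count())   # 2
    for d in client.devices():
        print(d.id, d.device_kind)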
+def get_gpu_client( + asynchronous: bool = ..., + allocator_config: GpuAllocatorConfig = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., + allowed_devices: Optional[Any] = ..., + platform_name: Optional[str] = ..., + mock: Optional[bool] = ..., + mock_gpu_topology: Optional[str] = ..., +) -> Client: ... +def get_mock_gpu_client( + asynchronous: bool = ..., + allocator_config: GpuAllocatorConfig = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + allowed_devices: Optional[Any] = ..., + platform_name: Optional[str] = ..., +) -> Client: ... +def get_c_api_client( + platform_name: str, + options: Dict[str, Union[str, int, List[int], float, bool]], + distributed_client: Optional[DistributedRuntimeClient] = ..., +) -> Client: ... +def get_default_c_api_topology( + platform_name: str, + topology_name: str, + options: Dict[str, Union[str, int, List[int], float]], +) -> DeviceTopology: ... +def get_c_api_topology( + c_api: Any, + topology_name: str, + options: Dict[str, Union[str, int, List[int], float]], +) -> DeviceTopology: ... +def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... +def load_pjrt_plugin(platform_name: str, library_path: Optional[str], c_api: Optional[Any]) -> _Status: ... +def pjrt_plugin_loaded(plugin_name: str) -> bool: ... +def pjrt_plugin_initialized(plugin_name: str) -> bool: ... +def initialize_pjrt_plugin(platform_name: str) -> _Status: ... + +ArrayImpl = Any + +# TODO(phawkins): this type is problematic because it is not a subtype of +# jax.Array, and pytype notices. +# class ArrayImpl: +# def __init__(self, +# aval: Any, +# sharding: Any, +# arrays: Sequence[ArrayImpl], +# committed: bool, +# _skip_checks: bool = ...): ... +# def block_until_ready(self) -> ArrayImpl: ... +# def is_deleted(self) -> bool: ... +# def is_ready(self) -> bool: ... +# def delete(self): ... +# def unsafe_buffer_pointer(self) -> Any: ... +# def clone(self) -> ArrayImpl: ... +# def _copy_single_device_array_to_host_async(self): ... +# def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: ... +# def on_device_size_in_bytes(self) -> int: ... +# def _fully_replicated_shard(self) -> ArrayImpl: ... +# __cuda_array_interface__: Dict[str, Any] +# dtype: np.dtype +# shape: Tuple[int, ...] +# _arrays: Any +# _npy_value: Any +# traceback: Traceback +# _HAS_DYNAMIC_ATTRIBUTES: bool = ... + +def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[List[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... + +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: List[Device], + committed: bool = True, +) -> ArrayImpl: ... + +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... + +def check_and_canonicalize_memory_kind( + memory_kind: Optional[str], device_list: DeviceList +) -> Optional[str]: ... +def array_result_handler( + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... + +class Token: + def block_until_ready(self): ... + +class ShardedToken: + def block_until_ready(self): ... + def get_token(self, device_id: int): ... + +class ExecuteResults: + def __len__(self) -> int: ... 
+ def disassemble_into_single_device_arrays(self) -> List[List[ArrayImpl]]: ... + def disassemble_prefix_into_single_device_arrays( + self, n: int + ) -> List[List[ArrayImpl]]: ... + def consume_with_handlers(self, handlers: List[Callable]) -> List[Any]: ... + def consume_token(self) -> ShardedToken: ... + +class LoadedExecutable: + client: Client + def local_devices(self) -> List[Device]: ... + def size_of_generated_code_in_bytes(self) -> int: ... + def delete(self) -> None: ... + def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... + def execute_with_token( + self, arguments: Sequence[ArrayImpl] + ) -> Tuple[List[ArrayImpl], Token]: ... + def execute_sharded_on_local_devices( + self, arguments: Sequence[List[ArrayImpl]] + ) -> List[List[ArrayImpl]]: ... + def execute_sharded_on_local_devices_with_tokens( + self, arguments: Sequence[List[ArrayImpl]] + ) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... + def execute_sharded( + self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... + ) -> ExecuteResults: ... + def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def get_output_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_layouts(self) -> List[Layout]: ... + def get_output_layouts(self) -> List[Layout]: ... + def keep_alive(self) -> None: ... + def cost_analysis(self) -> Dict[str, Any]: ... + traceback: Traceback + fingerprint: Optional[bytes] + +class Executable: + def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... + def get_output_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_layouts(self) -> List[Layout]: ... + def get_output_layouts(self) -> List[Layout]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def serialize(self) -> str: ... + def cost_analysis(self) -> Dict[str, Any]: ... + +class DeviceTopology: + platform: str + platform_version: str + def _make_compile_only_devices(self) -> List[Device]: ... + def serialize(self) -> bytes: ... + def __getattr__(self, name: str) -> Any: ... + +def buffer_to_dlpack_managed_tensor( + buffer: ArrayImpl, stream: int | None = None +) -> Any: ... +@overload +def dlpack_managed_tensor_to_buffer( + tensor: Any, device: Device, stream: int | None +) -> ArrayImpl: ... +@overload +def dlpack_managed_tensor_to_buffer( # Legacy overload + tensor: Any, + cpu_backend: Optional[Client] = ..., + gpu_backend: Optional[Client] = ..., +) -> ArrayImpl: ... + +def cuda_array_interface_to_buffer( + cai: Dict[str, Union[ + str, int, None, + Tuple[int, ...], Tuple[int, bool], + List[Tuple[str, str]], + List[Tuple[str, str, Tuple[int, ...]]]] + ], + gpu_backend: Optional[Client] = ..., + device_id: int | None = None, +) -> ArrayImpl: ... + +# === BEGIN py_traceback.cc + +class Frame: + file_name: str + function_name: str + function_line_start: int + line_num: int + def __init__(self, + file_name: str, + function_name: str, + function_line_start: int, + line_num: int): ... + def __repr__(self) -> str: ... + +class Traceback: + enabled: ClassVar[bool] + @staticmethod + def get_traceback() -> Traceback: ... + @staticmethod + def traceback_from_frames(frames: Sequence[Frame]) -> Any: ... + frames: Sequence[Frame] + def __str__(self) -> str: ... 
+ def as_python_traceback(self) -> Any: ... + def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... + @staticmethod + def code_addr2line(code: types.CodeType, lasti: int) -> int: ... + @staticmethod + def code_addr2location( + code: types.CodeType, lasti: int + ) -> Tuple[int, int, int, int]: ... + +def replace_thread_exc_traceback(traceback: Any): ... + +# === END py_traceback.cc + +class DistributedRuntimeService: + def shutdown(self) -> None: ... + +class DistributedRuntimeClient: + def connect(self) -> _Status: ... + def shutdown(self) -> _Status: ... + def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... + def blocking_key_value_get_bytes( + self, key: str, timeout_in_ms: int + ) -> _Status: ... + def key_value_try_get(self, key: str) -> _Status: ... + def key_value_try_get_bytes(self, key: str) -> _Status: ... + def key_value_dir_get(self, key: str) -> _Status: ... + def key_value_dir_get_bytes(self, key: str) -> _Status: ... + def key_value_set(self, key: str, value: str, + allow_overwrite: bool = False) -> _Status: ... + def key_value_set_bytes(self, key: str, value: bytes, + allow_overwrite: bool = False) -> _Status: ... + def key_value_delete(self, key: str) -> _Status: ... + def wait_at_barrier( + self, barrier_id: str, timeout_in_ms: int, process_ids: Optional[List[int]] + ) -> _Status: ... + def get_live_nodes(self, process_ids: List[int]) -> _Status: ... + +def get_distributed_runtime_service( + address: str, + num_nodes: int, + heartbeat_interval: Optional[int] = ..., + max_missing_heartbeats: Optional[int] = ..., + cluster_register_timeout: Optional[int] = ..., + shutdown_timeout: Optional[int] = ..., +) -> DistributedRuntimeService: ... +def get_distributed_runtime_client( + address: str, + node_id: int, + rpc_timeout: Optional[int] = ..., + init_timeout: Optional[int] = ..., + shutdown_timeout: Optional[int] = ..., + heartbeat_interval: Optional[int] = ..., + max_missing_heartbeats: Optional[int] = ..., + missed_heartbeat_callback: Optional[Any] = ..., + shutdown_on_destruction: Optional[bool] = ..., + use_compression: Optional[bool] = ..., +) -> DistributedRuntimeClient: ... + +class PreemptionSyncManager: + def initialize(self, client: DistributedRuntimeClient) -> _Status: ... + def reached_sync_point(self, step_counter: int) -> bool: ... + +def create_preemption_sync_manager() -> PreemptionSyncManager: ... +def collect_garbage() -> None: ... +def is_optimized_build() -> bool: ... +def json_to_pprof_profile(json: str) -> bytes: ... +def pprof_profile_to_json(proto: bytes) -> str: ... + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + +def weakref_lru_cache( + cache_context_fn: Callable, call: Callable, maxsize=... +) -> WeakrefLRUCache: ... + +class DeviceList: + def __init__(self, device_assignment: Tuple[Device, ...]): ... + def __hash__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __len__(self) -> int: ... + def __getitem__(self, index: Any) -> Any: ... + def __iter__(self) -> Iterator[Device]: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + @property + def is_fully_addressable(self) -> bool: ... 
+ @property + def addressable_device_list(self) -> DeviceList: ... + @property + def default_memory_kind(self) -> Optional[str]: ... + @property + def memory_kinds(self) -> Tuple[str, ...]: ... + +class Sharding: ... + +class NamedSharding(Sharding): + def __init__( + self, + mesh: Any, + spec: Any, + *, + memory_kind: Optional[str] = None, + _manual_axes: frozenset[Any] = frozenset(), + _logical_device_ids: tuple[int, ...] | None = None, + ): ... + mesh: Any + spec: Any + _memory_kind: Optional[str] + _internal_device_list: DeviceList + _manual_axes: frozenset[Any] + _logical_device_ids: tuple[int, ...] | None + +class SingleDeviceSharding(Sharding): + def __init__(self, device: Device, *, memory_kind: Optional[str] = None): ... + _device: Device + _memory_kind: Optional[str] + _internal_device_list: DeviceList + +class PmapSharding(Sharding): + def __init__( + self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec + ): ... + devices: List[Any] + sharding_spec: pmap_lib.ShardingSpec + _internal_device_list: DeviceList + +class GSPMDSharding(Sharding): + def __init__( + self, + devices: Sequence[Device], + op_sharding: Union[OpSharding, HloSharding], + *, + memory_kind: Optional[str] = None, + _device_list: Optional[DeviceList] = None, + ): ... + _devices: Tuple[Device, ...] + _hlo_sharding: HloSharding + _memory_kind: Optional[str] + _internal_device_list: DeviceList + +class PjitFunction: + def __call__(self, *args, **kwargs) -> Any: ... + +class PjitFunctionCache: + def __init__(self, capacity: int = ...): ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + def size(self) -> int: ... + def capacity(self) -> int: ... + def clear(self): ... + @staticmethod + def clear_all(): ... + +def pjit( + function_name: str, + fun: Optional[Callable], + cache_miss: Callable, + static_argnums: Sequence[int], + static_argnames: Sequence[str], + global_cache_key: Any, + pytree_registry: pytree.PyTreeRegistry, + shard_arg_fallback: Callable, + cache: Optional[PjitFunctionCache] = ..., +) -> PjitFunction: ... + +class HloPassInterface: + @property + def name(self) -> str: ... + def is_pass_pipeline(self) -> bool: ... + def run(self, module: HloModule) -> bool: ... + def run_on_module_group(self, module_group: HloModuleGroup) -> bool: ... + +class HloDCE(HloPassInterface): + def __init__(self) -> None: ... + +class CallInliner(HloPassInterface): + def __init__(self) -> None: ... + +class FlattenCallGraph(HloPassInterface): + def __init__(self) -> None: ... + +class TupleSimplifer(HloPassInterface): + def __init__(self) -> None: ... + +class WeakrefLRUCacheInfo: + @property + def hits(self) -> int: ... + @property + def misses(self) -> int: ... + @property + def maxsize(self) -> int: ... + @property + def currsize(self) -> int: ... + +class WeakrefLRUCache: + def __call__(self, weakref_key: Any, *args, **kwargs) -> Any: ... + def cache_keys(self) -> list[Any]: ... + def cache_info(self) -> WeakrefLRUCacheInfo: ... + def cache_clear(self): ... + +def is_asan() -> bool: ... +def is_msan() -> bool: ... +def is_tsan() -> bool: ... +def is_sanitized() -> bool: ... + +class TransferConnection: + + def address(self) -> str: ... + + def _pull_flat(self, uuid, backend, avals_flat) -> list[Any]: ... + +class TransferServer: + def _await_pull_flat(self, uuid, args: list[ArrayImpl]): ... + + def connect(self, address: str) -> TransferConnection: ... 
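The HloPassInterface wrappers above can be driven directly against an HloModule parsed from text; a sketch, assuming the direct import path works:

    # Sketch: parse HLO text and run dead-code elimination over it.
    from jaxlib.xla import xla_extension as xe  # import path assumed

    hlo_text = """
    HloModule dce_example
    ENTRY main {
      p0 = f32[4] parameter(0)
      dead = f32[4] add(p0, p0)
      ROOT out = f32[4] negate(p0)
    }
    """
    m = xe.hlo_module_from_text(hlo_text)
    changed = xe.HloDCE().run(m)  # True if the pass rewrote the module
    print(changed, m.to_string())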
+ +def start_transfer_server(client: Client, address: str = "", transport_addresses: list[str] = [], max_num_parallel_copies: int = 0, transfer_size: int = 0) -> TransferServer: ... diff --git a/jaxlib/xla/xla_extension/config.pyi b/jaxlib/xla/xla_extension/config.pyi new file mode 100644 index 000000000000..535554559180 --- /dev/null +++ b/jaxlib/xla/xla_extension/config.pyi @@ -0,0 +1,32 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Generic, TypeVar + +unset: object + +_T = TypeVar('_T') + +class Config(Generic[_T]): + def __init__(self, value: _T, include_in_jit_key: bool = False): ... + + @property + def value(self) -> _T: ... + + def get_local(self) -> Any: ... + def get_global(self) -> _T: ... + def set_local(self, value: Any) -> None: ... + def swap_local(self, value: Any) -> Any: ... + def set_global(self, value: _T) -> None: ... diff --git a/jaxlib/xla/xla_extension/guard_lib.pyi b/jaxlib/xla/xla_extension/guard_lib.pyi new file mode 100644 index 000000000000..cfa8b0c5fa5e --- /dev/null +++ b/jaxlib/xla/xla_extension/guard_lib.pyi @@ -0,0 +1,46 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, List, Optional + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class GarbageCollectionGuardLevel: + ALLOW: Any + LOG: Any + FATAL: Any + +class GuardState: + host_to_device: Optional[TransferGuardLevel] + device_to_device: Optional[TransferGuardLevel] + device_to_host: Optional[TransferGuardLevel] + + explicit_device_put: bool + explicit_device_get: bool + + garbage_collect_array: Optional[GarbageCollectionGuardLevel] + +def global_state() -> GuardState: ... +def thread_local_state() -> GuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> List[str]: ... 
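guard_lib is the state store behind JAX's transfer guards; user code normally drives it through the public API rather than mutating GuardState directly:

    # Sketch: log implicit device-to-host transfers instead of allowing them silently.
    import jax
    import jax.numpy as jnp

    with jax.transfer_guard("log"):
        x = jnp.arange(4)
        _ = float(x[0])  # implicit device-to-host transfer; reported under "log"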
diff --git a/jaxlib/xla/xla_extension/ifrt_programs.pyi b/jaxlib/xla/xla_extension/ifrt_programs.pyi new file mode 100644 index 000000000000..bcee365e5732 --- /dev/null +++ b/jaxlib/xla/xla_extension/ifrt_programs.pyi @@ -0,0 +1,43 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Sequence, Union + +from jax.jaxlib.xla import xla_extension + +class Program: ... + +class CompileOptions: ... + +def make_hlo_program(mlir_module: Union[str, bytes]) -> Program: ... + +def make_colocated_python_program( + name : str, + picked_function: bytes, + devices: Sequence[xla_extension.Device] | xla_extension.DeviceList, + input_avals: Sequence[Any], + output_avals: Sequence[Any], +) -> Program: ... + +def make_plugin_program(data: Union[str, bytes]) -> Program: ... + +def make_colocated_python_compile_options() -> CompileOptions: ... + +def make_xla_compile_options( + compile_options: xla_extension.CompileOptions, + host_callbacks: Sequence[Any] +) -> CompileOptions: ... + +def make_plugin_compile_options() -> CompileOptions: ... diff --git a/jaxlib/xla/xla_extension/ifrt_proxy.pyi b/jaxlib/xla/xla_extension/ifrt_proxy.pyi new file mode 100644 index 000000000000..3b5de7aa97c9 --- /dev/null +++ b/jaxlib/xla/xla_extension/ifrt_proxy.pyi @@ -0,0 +1,33 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Optional, Callable + +from jax.jaxlib.xla import xla_extension + +_Status = Any +Client = xla_extension.Client + + +class ClientConnectionOptions: + on_disconnect: Optional[Callable[[_Status], None]] = None + on_connection_update: Optional[Callable[[str], None]] = None + connection_timeout_in_seconds: Optional[int] = None + + +def get_client( + proxy_server_address: str, + options: ClientConnectionOptions +) -> Client: ... diff --git a/jaxlib/xla/xla_extension/jax_jit.pyi b/jaxlib/xla/xla_extension/jax_jit.pyi new file mode 100644 index 000000000000..1f78d283333c --- /dev/null +++ b/jaxlib/xla/xla_extension/jax_jit.pyi @@ -0,0 +1,76 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Callable, Optional, Sequence, Tuple + +import numpy as np +from jax.jaxlib.xla import xla_extension + +from . import pytree + +Client = xla_extension.Client +Device = xla_extension.Device + + +class JitState: + disable_jit: Optional[bool] + enable_x64: Optional[bool] + default_device: Optional[Any] + extra_jit_context: Optional[Any] + post_hook: Optional[Callable[..., Any]] + +def global_state() -> JitState: ... +def thread_local_state() -> JitState: ... + +def get_enable_x64() -> bool: ... +def set_thread_local_state_initialization_callback( + function: Callable[[], None]): ... + +def swap_thread_local_state_disable_jit( + value: Optional[bool]) -> Optional[bool]: ... + +class ArgSignature: + dtype: np.dtype + shape: Tuple[int, ...] + weak_type: bool + +def _ArgSignatureOfValue( + __arg: Any, + __jax_enable_x64: bool) -> ArgSignature: ... + +def _is_float0(__arg: Any) -> bool: ... + + +class ArgumentSignature: + static_args: Sequence[Any] + static_arg_names: Sequence[str] + dynamic_arg_names: Sequence[str] + dynamic_arg_treedefs: Sequence[pytree.PyTreeDef] + + def __eq__(self, value, /): ... + def __ne__(self, value, /): ... + def __hash__(self, /): ... + def __str__(self): ... + def __repr__(self): ... + + +def parse_arguments( + positional_args: Sequence[Any], + keyword_args: Sequence[Any], + kwnames: Tuple[str, ...], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + pytree_registry: pytree.PyTreeRegistry, +) -> tuple[ArgumentSignature, Sequence[Any]]: ... diff --git a/jaxlib/xla/xla_extension/mlir.pyi b/jaxlib/xla/xla_extension/mlir.pyi new file mode 100644 index 000000000000..95eeae660c0c --- /dev/null +++ b/jaxlib/xla/xla_extension/mlir.pyi @@ -0,0 +1,34 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Union +from . import XlaComputation + +def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... +def mlir_module_to_xla_computation( + mlir_module: Union[bytes, str], + use_tuple_args: bool = ..., + return_tuple: bool = ..., +) -> XlaComputation: ... +def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> bytes: ... +def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> bytes: ... +def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ... +def deserialize_portable_artifact(mlir_module: bytes) -> str: ... 
+def refine_polymorphic_shapes( + mlir_module: Union[bytes, str], + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ..., + enable_shardy: bool = ..., +) -> bytes: ... diff --git a/jaxlib/xla/xla_extension/ops.pyi b/jaxlib/xla/xla_extension/ops.pyi new file mode 100644 index 000000000000..ff55de3a5cdc --- /dev/null +++ b/jaxlib/xla/xla_extension/ops.pyi @@ -0,0 +1,465 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +from typing import Any, Optional, Sequence, overload + +from jax.jaxlib.xla import xla_extension + +FftType = xla_extension.FftType +XlaBuilder = xla_extension.XlaBuilder +XlaComputation = xla_extension.XlaComputation +XlaOp = xla_extension.XlaOp +PrecisionConfig_Precision = xla_extension.PrecisionConfig_Precision +PrimitiveType = xla_extension.PrimitiveType +Shape = xla_extension.Shape +ShapeIndex = xla_extension.ShapeIndex +ResultAccuracy = xla_extension.ResultAccuracy + +_ChannelHandle = Any +_ConvDimensionNumbers = Any +_DotDimensionNumbers = Any +_Layout = Any +_LiteralSlice = Any +_GatherDimensionNumbers = Any +_PaddingConfig = Any +_ReplicaGroup = Any +_ScatterDimensionNumbers = Any + +class TriangularSolveOptions_Transpose(enum.IntEnum): + TRANSPOSE_INVALID: int + NO_TRANSPOSE: int + TRANSPOSE: int + ADJOINT: int + +class RandomAlgorithm(enum.IntEnum): + RNG_DEFAULT: int + RNG_THREE_FRY: int + RNG_PHILOX: int + +class CustomCallSchedule(enum.IntEnum): + SCHEDULE_NONE: int + SCHEDULE_LATEST: int + SCHEDULE_EARLIEST: int + +# TODO(b/189822916): Remove this enum when all clients are migrated to the +# status-returning API. +class CustomCallApiVersion(enum.IntEnum): + API_VERSION_ORIGINAL: int + API_VERSION_STATUS_RETURNING: int + API_VERSION_STATUS_RETURNING_UNIFIED: int + API_VERSION_TYPED_FFI: int + +def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ... +def AllGather( + operand: XlaOp, + all_gather_dimension: int, + shard_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + shape_with_layout: Optional[_Layout] = ..., + use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... +def AllReduce( + operand: XlaOp, + computation: XlaComputation, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ... +def ApproxTopK( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + top_k: int, + reduction_dim: int, + comparator: XlaComputation, + recall_target: Optional[float], + aggregate_to_topk: Optional[bool], + reduction_input_size_override: Optional[int]) -> XlaOp: ... 
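ApproxTopK is the op behind jax.lax.approx_max_k / approx_min_k, where recall_target trades exactness for speed; at the JAX level:

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.key(0), (1024,))
    # Approximate top-k: ~95% expected recall of the true top-8 entries.
    vals, idxs = jax.lax.approx_max_k(x, k=8, recall_target=0.95)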
+def ApproxTopKFallback( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + top_k: int, + reduction_dim: int, + comparator: XlaComputation, + recall_target: Optional[float], + aggregate_to_topk: Optional[bool], + reduction_input_size_override: Optional[int]) -> XlaOp: ... +def ApproxTopKReductionOutputSize( + input_size: int, + rank: int, + top_k: int, + recall_target: float, + aggregate_to_topk: Optional[bool] = ..., + input_size_override: Optional[int] = ...) -> tuple[int, int]: ... +def ReduceScatter( + operand: XlaOp, + computation: XlaComputation, + scatter_dimension: int, + shard_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + layout: Optional[_Layout] = ..., + use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... +def AllToAll( + operand: XlaOp, + split_dimension: int, + concat_dimension: int, + split_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + layout: Optional[_Layout] = ..., + channel_id: Optional[_ChannelHandle] = ...) -> XlaOp: ... +def BitcastConvertType(operand: XlaOp, + new_element_type: PrimitiveType) -> XlaOp: ... +def Broadcast(operand: XlaOp, sizes: Sequence[int]) -> XlaOp: ... +def BroadcastInDim(operand: XlaOp, + shape: Sequence[int], + broadcast_dimensions: Sequence[int]) -> XlaOp: ... +def Call(builder: XlaBuilder, + computation: XlaComputation, + operands: Sequence[XlaOp]) -> XlaOp: ... +def Cholesky(a: XlaOp, lower: bool = ...) -> XlaOp: ... +def Clamp(min: XlaOp, operand: XlaOp, max: XlaOp) -> XlaOp: ... +def Collapse(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... +def CollectivePermute( + operand: XlaOp, + source_target_pairs: Sequence[tuple[int, int]], + channel_id: Optional[_ChannelHandle] = ..., + inplace: bool = ...) -> XlaOp: ... +def ConcatInDim(builder: XlaBuilder, + operands: Sequence[XlaOp], + dimension: int) -> XlaOp: ... +@overload +def Conditional(branch_index: XlaOp, + branch_computations: Sequence[XlaComputation], + branch_operands: Sequence[XlaOp]) -> XlaOp: ... +@overload +def Conditional( + predicate: XlaOp, + true_operand: XlaOp, + true_computation: XlaComputation, + false_operand: XlaOp, + false_computation: XlaComputation) -> XlaOp: ... + +def Constant(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... +def ConstantLiteral(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... +def ConvGeneralDilated( + lhs: XlaOp, + rhs: XlaOp, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: _ConvDimensionNumbers, + feature_group_count: int = ..., + batch_group_count: int = ..., + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ..., + window_reversal: Optional[Sequence[bool]] = ...) -> XlaOp: ... +def ConvertElementType( + operand: XlaOp, + new_element_type: PrimitiveType) -> XlaOp: ... +def CreateToken(builder: XlaBuilder) -> XlaOp: ... +def CrossReplicaSum( + operand: XlaOp, + replica_groups: Sequence[_ReplicaGroup] = ...) -> XlaOp: ... +def CustomCall( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape: Shape, + opaque: bytes = ..., + has_side_effect: bool = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... 
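CustomCall's API_VERSION_TYPED_FFI corresponds to the modern registration path, surfaced in JAX as jax.ffi. A sketch; `my_kernel` and its handler capsule are placeholders, not real symbols:

    import jax
    import numpy as np

    # Registration: `handler` must be a PyCapsule wrapping an XLA FFI handler
    # (built against the XLA FFI headers); it is a placeholder here.
    # jax.ffi.register_ffi_target("my_kernel", handler, platform="cpu")

    # Invocation: declare the result aval, then call like any traceable function.
    call = jax.ffi.ffi_call("my_kernel", jax.ShapeDtypeStruct((4,), np.float32))
    # y = call(x)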
+def CustomCallWithLayout( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape_with_layout: Shape, + operand_shapes_with_layout: Sequence[Shape], + opaque: bytes = ..., + has_side_effect: bool = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... +def CustomCallWithAliasing( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape_with_layout: Shape, + operand_shapes_with_layout: Sequence[Shape], + opaque: bytes = ..., + has_side_effect: bool = ..., + output_operand_aliasing: Sequence[tuple[ShapeIndex, tuple[int, ShapeIndex]]] = ..., + literal: _LiteralSlice = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... +def Dot( + lhs: XlaOp, + rhs: XlaOp, + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... +def DotGeneral( + lhs: XlaOp, + rhs: XlaOp, + dimensions_numbers: _DotDimensionNumbers, + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... +def DynamicReshape( + operand: XlaOp, + dim_sizes: Sequence[XlaOp], + new_size_bounds: Sequence[int], + dims_are_dynamic: Sequence[bool]) -> XlaOp: ... +def DynamicSlice( + operand: XlaOp, + start_indices: Sequence[XlaOp], + slice_sizes: Sequence[int]) -> XlaOp: ... +def DynamicUpdateSlice( + operand: XlaOp, + update: XlaOp, + start_indices: Sequence[XlaOp]) -> XlaOp: ... +def Eigh( + a: XlaOp, + lower: bool = ..., + max_iter: int = ..., + epsilon: float = ..., + sort_eigenvalues: bool = ...) -> tuple[XlaOp, XlaOp]: ... +def Fft( + operand: XlaOp, + fft_type: FftType, + fft_length: Sequence[int]) -> XlaOp: ... +def Gather( + a: XlaOp, + start_indices: XlaOp, + dimension_numbers: _GatherDimensionNumbers, + slice_sizes: Sequence[int], + indices_are_sorted: bool = ...) -> XlaOp: ... +def GetDimensionSize(operand: XlaOp, index: int) -> XlaOp: ... +def GetTupleElement(tuple_data: XlaOp, index: int) -> XlaOp: ... +def InfeedWithToken( + token: XlaOp, + shape: Shape, + config: Optional[str] = ...) -> XlaOp: ... +@overload +def Iota(builder: XlaBuilder, shape: Shape, iota_dimension: int) -> XlaOp: ... +@overload +def Iota(builder: XlaBuilder, type: PrimitiveType, size: int) -> XlaOp: ... +def LU(a: XlaOp) -> tuple[XlaOp, XlaOp, XlaOp]: ... +def Map( + builder: XlaBuilder, + operands: Sequence[XlaOp], + computation: XlaComputation, + dimensions: Sequence[int], + static_operands: Sequence[XlaOp] = ...) -> XlaOp: ... +def MultiCollectivePermute( + operands: Sequence[XlaOp], + source_target_pairs: Sequence[tuple[int, int]], + channel_id: Optional[_ChannelHandle] = ..., + inplace: bool = ...) -> XlaOp: ... +def NextAfter(__from: XlaOp, to: XlaOp) -> XlaOp: ... +def OutfeedWithToken( + operand: XlaOp, + token: XlaOp, + shape_with_layout: Shape, + outfeed_config: Optional[str] = ...) -> XlaOp: ... +def Pad( + operand: XlaOp, + padding_value: XlaOp, + padding_config: _PaddingConfig) -> XlaOp: ... +def Parameter( + builder: XlaBuilder, + parameter_number: int, + shape: Shape, + name: str = ..., + replicated_at_leaf_buffers: Sequence[bool] = ...) -> XlaOp: ... +def ProductOfElementaryHouseholderReflectors(a: XlaOp, taus: XlaOp) -> XlaOp: ... +def QR(a: XlaOp, full_matrices: bool) -> tuple[XlaOp, XlaOp]: ... +def QrDecomposition(a: XlaOp) -> tuple[XlaOp, XlaOp]: ... 
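These free functions all thread XlaOp values through an XlaBuilder; a minimal end-to-end sketch. The direct import paths are assumptions (JAX reaches these via xla_client), and Client.compile takes MLIR text per the Client stub, hence the round-trip through the mlir helpers:

    import numpy as np
    from jaxlib.xla import xla_extension as xe      # import path assumed
    from jaxlib.xla.xla_extension import mlir, ops  # import path assumed

    b = xe.XlaBuilder("dot_example")
    shape = xe.Shape.array_shape(np.dtype("float32"), (2, 2))
    x = ops.Parameter(b, 0, shape)
    y = ops.Dot(x, x)
    computation = b.build(y)

    module = mlir.xla_computation_to_mlir_module(computation)
    executable = xe.get_tfrt_cpu_client().compile(module)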
+def Reduce( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + computation: XlaComputation, + dimensions_to_reduce: Sequence[int]) -> XlaOp: ... +def ReducePrecision( + operand: XlaOp, + exponent_bits: int, + mantissa_bits: int) -> XlaOp: ... +@overload +def ReduceWindowWithGeneralPadding( + operand: XlaOp, + init_value: XlaOp, + computation: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + base_dilations: Sequence[int], + window_dilations: Sequence[int], + padding: Sequence[tuple[int, int]]) -> XlaOp: ... +@overload +def ReduceWindowWithGeneralPadding( + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + computation: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + base_dilations: Sequence[int], + window_dilations: Sequence[int], + padding: Sequence[tuple[int, int]]) -> XlaOp: ... +def ReplicaId(builder: XlaBuilder) -> XlaOp: ... +def Reshape(operand: XlaOp, new_sizes: Sequence[int]) -> XlaOp: ... +def Rev(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... +def RngBitGenerator( + algorithm: RandomAlgorithm, + initial_state: XlaOp, + shape: Shape) -> XlaOp: ... +def RngNormal(mu: XlaOp, sigma: XlaOp, shape: Shape) -> XlaOp: ... +def RngUniform(a: XlaOp, b: XlaOp, shape: Shape) -> XlaOp: ... +@overload +def Scatter( + input: XlaOp, + scatter_indices: XlaOp, + updates: XlaOp, + update_computation: XlaComputation, + dimension_numbers: _ScatterDimensionNumbers, + indices_are_sorted: bool = ..., + unique_indices: bool = ...) -> XlaOp: ... +@overload +def Scatter( + inputs: Sequence[XlaOp], + scatter_indices: XlaOp, + updates: Sequence[XlaOp], + update_computation: XlaComputation, + dimension_numbers: _ScatterDimensionNumbers, + indices_are_sorted: bool = ..., + unique_indices: bool = ...) -> XlaOp: ... +def Select(pred: XlaOp, on_true: XlaOp, on_false: XlaOp) -> XlaOp: ... +def SelectAndScatterWithGeneralPadding( + operand: XlaOp, + select: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + source: XlaOp, + init_value: XlaOp, + scatter: XlaComputation) -> XlaOp: ... +def Slice( + operand: XlaOp, + start_indices: Sequence[int], + limit_indices: Sequence[int], + strides: Sequence[int]) -> XlaOp: ... +def SliceInDim( + operand: XlaOp, + start_index: int, + limit_index: int, + stride: int, + dimno: int) -> XlaOp: ... +def Sort( + builder: XlaBuilder, + operands: Sequence[XlaOp], + comparator: Optional[XlaComputation] = ..., + dimension: int = ..., + is_stable: bool = ...) -> XlaOp: ... +def SVD( + a: XlaOp, + max_iter: int = ..., + epsilon: float = ...) -> tuple[XlaOp, XlaOp, XlaOp]: ... +def TopK(input: XlaOp, k: int) -> XlaOp: ... +def Transpose(operand: XlaOp, permutation: Sequence[int]) -> XlaOp: ... +def TriangularSolve( + a: XlaOp, + b: XlaOp, + left_side: bool, + lower: bool, + unit_diagonal: bool, + transpose_a: TriangularSolveOptions_Transpose) -> XlaOp: ... +def Tuple(builder: XlaBuilder, elements: Sequence[XlaOp]) -> XlaOp: ... +def While( + condition: XlaComputation, + body: XlaComputation, + init: XlaOp) -> XlaOp: ... + + +def Igamma(a: XlaOp, x: XlaOp) -> XlaOp: ... +def Igammac(a: XlaOp, x: XlaOp) -> XlaOp: ... +def IgammaGradA(a: XlaOp, x: XlaOp) -> XlaOp: ... +def RandomGammaGrad(a: XlaOp, x: XlaOp) -> XlaOp: ... +def RegularizedIncompleteBeta(a: XlaOp, b: XlaOp, x: XlaOp) -> XlaOp: ... +def Zeta(a: XlaOp, q: XlaOp) -> XlaOp: ... 
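While above (and Conditional earlier) are the structured-control ops that jax.lax.while_loop and jax.lax.cond lower to:

    import jax

    # lax.while_loop(cond_fn, body_fn, init) lowers to the While op.
    total = jax.lax.while_loop(lambda i: i < 10, lambda i: i + 2, 0)

    # lax.cond(pred, true_fn, false_fn, operand) lowers to Conditional.
    y = jax.lax.cond(total > 5, lambda t: t * 2, lambda t: t - 1, total)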
+ +def Eq(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Ne(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Ge(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Gt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Lt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Le(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Add(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Sub(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Mul(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Div(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Rem(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Max(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Min(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def And(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Or(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Xor(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftLeft(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftRightArithmetic(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftRightLogical(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Atan2(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Pow(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Complex(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... + +def Not(__arg: XlaOp) -> XlaOp: ... +def PopulationCount(__arg: XlaOp) -> XlaOp: ... +def Clz(__arg: XlaOp) -> XlaOp: ... +def Abs(__arg: XlaOp) -> XlaOp: ... +def Exp(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Expm1(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Floor(__arg: XlaOp) -> XlaOp: ... +def Ceil(__arg: XlaOp) -> XlaOp: ... +def Round(__arg: XlaOp) -> XlaOp: ... +def Log(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Log1p(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Sign(__arg: XlaOp) -> XlaOp: ... +def Cos(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def OptimizationBarrier(__arg: XlaOp) -> XlaOp: ... +def Sin(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Tan(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Tanh(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def IsFinite(__arg: XlaOp) -> XlaOp: ... +def Neg(__arg: XlaOp) -> XlaOp: ... +def Sqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Rsqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Cbrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Square(__arg: XlaOp) -> XlaOp: ... +def Reciprocal(__arg: XlaOp) -> XlaOp: ... +def Erfc(__arg: XlaOp) -> XlaOp: ... +def Erf(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... 
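The broadcast_dimensions argument on these binary ops follows XLA's explicit-broadcast rule: it maps each dimension of the lower-rank operand onto a dimension of the higher-rank one. Continuing the builder sketch from above (import paths again assumed):

    import numpy as np
    from jaxlib.xla import xla_extension as xe  # import path assumed
    from jaxlib.xla.xla_extension import ops    # import path assumed

    b = xe.XlaBuilder("broadcast_add")
    mat = ops.Parameter(b, 0, xe.Shape.array_shape(np.dtype("float32"), (2, 3)))
    vec = ops.Parameter(b, 1, xe.Shape.array_shape(np.dtype("float32"), (3,)))
    # Map the vector's one dimension onto dimension 1 of the matrix,
    # i.e. add the vector to every row.
    out = ops.Add(mat, vec, broadcast_dimensions=(1,))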
+def ErfInv(__arg: XlaOp) -> XlaOp: ... +def Lgamma(__arg: XlaOp) -> XlaOp: ... +def Digamma(__arg: XlaOp) -> XlaOp: ... +def BesselI0e(__arg: XlaOp) -> XlaOp: ... +def BesselI1e(__arg: XlaOp) -> XlaOp: ... +def Acos(__arg: XlaOp) -> XlaOp: ... +def Asin(__arg: XlaOp) -> XlaOp: ... +def Atan(__arg: XlaOp) -> XlaOp: ... +def Acosh(__arg: XlaOp) -> XlaOp: ... +def Asinh(__arg: XlaOp) -> XlaOp: ... +def Atanh(__arg: XlaOp) -> XlaOp: ... +def Cosh(__arg: XlaOp) -> XlaOp: ... +def Sinh(__arg: XlaOp) -> XlaOp: ... +def Real(__arg: XlaOp) -> XlaOp: ... +def Imag(__arg: XlaOp) -> XlaOp: ... +def Conj(__arg: XlaOp) -> XlaOp: ... diff --git a/jaxlib/xla/xla_extension/pmap_lib.pyi b/jaxlib/xla/xla_extension/pmap_lib.pyi new file mode 100644 index 000000000000..8733d6c27b21 --- /dev/null +++ b/jaxlib/xla/xla_extension/pmap_lib.pyi @@ -0,0 +1,83 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import inspect +from typing import Any, Callable, Sequence, Iterable, Tuple + +from . import pytree + +_AvalDimSharding = Any +_MeshDimAssignment = Any + +class NoSharding: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Chunked: + @property + def chunks(self) -> Sequence[int]: ... + def __init__(self, __chunks: Sequence[int]) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Unstacked: + @property + def size(self) -> int: ... + def __init__(self, __sz: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class ShardedAxis: + @property + def axis(self) -> int: ... + def __init__(self, __axis: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: ShardedAxis) -> bool: ... + +class Replicated: + @property + def replicas(self) -> int: ... + def __init__(self, __replicas: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Replicated) -> bool: ... + +class ShardingSpec: + def __init__(self, + sharding: Iterable[_AvalDimSharding], + mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ... + @property + def sharding(self) -> Tuple[_AvalDimSharding, ...]: ... + @property + def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ... + def __eq__(self, __other: ShardingSpec) -> bool: ... + def __hash__(self) -> int: ... + + _HAS_DYNAMIC_ATTRIBUTES = True + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + def _debug_cache_keys(self) -> str: ... + +def pmap(fun: Callable[..., Any], + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + shard_arg_fallback: Callable[..., Any], + pytree_registry: pytree.PyTreeRegistry) -> PmapFunction: ... 
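A ShardingSpec pairs a per-dimension sharding with a mesh mapping. For instance, splitting the leading axis of a rank-2 array into two chunks laid out along mesh axis 0 (values illustrative; the import path is assumed):

    from jaxlib.xla.xla_extension import pmap_lib  # import path assumed

    spec = pmap_lib.ShardingSpec(
        sharding=[pmap_lib.Chunked([2]), pmap_lib.NoSharding()],
        mesh_mapping=[pmap_lib.ShardedAxis(0)],
    )
    print(spec.sharding, spec.mesh_mapping)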
diff --git a/jaxlib/xla/xla_extension/profiler.pyi b/jaxlib/xla/xla_extension/profiler.pyi new file mode 100644 index 000000000000..7610ce1000bf --- /dev/null +++ b/jaxlib/xla/xla_extension/profiler.pyi @@ -0,0 +1,58 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from types import TracebackType +from typing import Any, Optional, Type, Union, List, Tuple + +_Status = Any + +class ProfilerServer: ... +def start_server(port: int) -> ProfilerServer: ... + +def register_plugin_profiler(c_api: Any) -> None: ... + +def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... +def get_instructins_profile(tensorboard_dir: str) -> List[Tuple[str, float]]: ... +def get_fdo_profile( + xspace: bytes, as_textproto: bool = ... +) -> Union[bytes, str]: ... + +class ProfilerSession: + def __init__(self, options: Optional[ProfileOptions] = ...) -> None: ... + def stop(self) -> bytes: ... + def export(self, xspace: bytes, tensorboard_dir: str) -> _Status:... + +class ProfileOptions: + include_dataset_ops: bool + host_tracer_level: int + python_tracer_level: int + enable_hlo_proto: bool + start_timestamp_ns: int + duration_ms: int + repository_path: str + +def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ... + +class TraceMe: + def __init__(self, name: str, **kwargs: Any) -> None: ... + def __enter__(self) -> TraceMe: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]:... + def set_metadata(self, **kwargs): ... + @staticmethod + def is_enabled() -> bool: ... diff --git a/jaxlib/xla/xla_extension/pytree.pyi b/jaxlib/xla/xla_extension/pytree.pyi new file mode 100644 index 000000000000..bfbad5de89d5 --- /dev/null +++ b/jaxlib/xla/xla_extension/pytree.pyi @@ -0,0 +1,158 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import ( + Any, + Callable, + Hashable, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, +) + +_T = TypeVar("_T") + +version: int + +class PyTreeRegistry: + def __init__( + self, + *, + enable_none: bool = ..., + enable_tuple: bool = ..., + enable_namedtuple: bool = ..., + enable_list: bool = ..., + enable_dict: bool = ... + ): ... 
+ def flatten( + self, + tree: Any, + leaf_predicate: Optional[Callable[[Any], bool]] = ..., + ) -> Tuple[List[Any], PyTreeDef]: ... + def flatten_one_level( + self, tree: Any + ) -> Optional[Tuple[Iterable[Any], Any]]: ... + def flatten_one_level_with_keys( + self, tree: Any + ) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ... + def flatten_with_path( + self, + tree: Any, + leaf_predicate: Optional[Callable[[Any], bool]] = ..., + ) -> Tuple[List[Tuple[_KeyPath, Any]], PyTreeDef]: ... + def register_node( + self, + __type: Type[_T], + to_iterable: Callable[[_T], Tuple[_Children, _AuxData]], + from_iterable: Callable[[_AuxData, _Children], _T], + to_iterable_with_keys: ( + Callable[[_T], Tuple[_KeyLeafPairs, _AuxData]] | None + ) = ..., + ) -> Any: ... + def register_dataclass_node( + self, __type: Type[_T], meta_fields: List[str], data_fields: List[str] + ) -> Any: ... + +def default_registry() -> PyTreeRegistry: ... +def tuple(registry: PyTreeRegistry, arg0: Sequence[PyTreeDef]) -> PyTreeDef: ... +def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ... + +class SequenceKey(Hashable): + idx: int + __match_args__: tuple = ... + def __init__(self, idx: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class DictKey(Hashable): + key: Hashable + __match_args__: tuple = ... + def __init__(self, key: Hashable): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class GetAttrKey(Hashable): + name: str + __match_args__: tuple = ... + def __init__(self, name: str): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class FlattenedIndexKey(Hashable): + key: int + __match_args__: tuple = ... + def __init__(self, key: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class PyTreeDef: + def unflatten(self, __leaves: Iterable[Any]) -> Any: ... + def flatten_up_to(self, __xs: Any) -> List[Any]: ... + def compose(self, __inner: PyTreeDef) -> PyTreeDef: ... + def walk( + self, + __f_node: Callable[[Any, Any], Any], + __f_leaf: Optional[Callable[[_T], Any]], + leaves: Iterable[Any], + ) -> Any: ... + def from_iterable_tree(self, __xs: Any): ... + def node_data(self) -> Optional[Tuple[Type, Any]]: ... + def children(self) -> List[PyTreeDef]: ... + @staticmethod + def make_from_node_data_and_children( + registry: PyTreeRegistry, + node_data: Optional[Tuple[Type, Any]], + children: Iterable[PyTreeDef], + ) -> PyTreeDef: ... + + num_leaves: int + num_nodes: int + def __repr__(self) -> str: ... + def __eq__(self, __other: PyTreeDef) -> bool: ... + def __ne__(self, __other: PyTreeDef) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def serialize_using_proto(self) -> bytes: ... + @staticmethod + def deserialize_using_proto( + registry: PyTreeRegistry, data: bytes + ) -> PyTreeDef: ... 
+ +_Children = TypeVar("_Children", bound=Iterable[Any]) +_KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any]) +_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[_KeyLeafPair]) +_KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...]) +_AuxData = TypeVar("_AuxData", bound=Hashable) diff --git a/jaxlib/xla/xla_extension/sdy.pyi b/jaxlib/xla/xla_extension/sdy.pyi new file mode 100644 index 000000000000..34714e5c0219 --- /dev/null +++ b/jaxlib/xla/xla_extension/sdy.pyi @@ -0,0 +1,32 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from mlir import ir + +def sdy_round_trip_export_pipeline( + module: ir.module +) -> str: ... + +def sdy_round_trip_import_shardings( + module: ir.module +) -> str: ... + +def get_mesh( + module: ir.module +) -> tuple[tuple[str, int], ...]: ... + +def lowered_with_shardy( + module: ir.module +) -> bool: ... diff --git a/jaxlib/xla/xla_extension/transfer_guard_lib.pyi b/jaxlib/xla/xla_extension/transfer_guard_lib.pyi new file mode 100644 index 000000000000..091e1e10a742 --- /dev/null +++ b/jaxlib/xla/xla_extension/transfer_guard_lib.pyi @@ -0,0 +1,39 @@ +# Copyright 2022 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, List, Optional + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class TransferGuardState: + host_to_device: Optional[TransferGuardLevel] + device_to_device: Optional[TransferGuardLevel] + device_to_host: Optional[TransferGuardLevel] + + explicit_device_put: bool + explicit_device_get: bool + +def global_state() -> TransferGuardState: ... +def thread_local_state() -> TransferGuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> List[str]: ... diff --git a/jaxlib/xla_extension.py b/jaxlib/xla_extension.py new file mode 100644 index 000000000000..e4fc7e96a1ab --- /dev/null +++ b/jaxlib/xla_extension.py @@ -0,0 +1,17 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxlib.xla.xla_extension import * # noqa: F403 +from jaxlib.xla.xla_extension import sdy # noqa: F401 From 396e389001ce3d6f6e3f1bc944245868968539f3 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 22:02:58 -0700 Subject: [PATCH 0099/1769] [pallas] Add `_zeros[_like]` and `_ones[_like]` utility functions in Triton lowering. PiperOrigin-RevId: 739395754 --- jax/_src/pallas/triton/lowering.py | 40 +++++++++++++++++++----------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 64bf635a34ed..bc7144f376b4 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1120,7 +1120,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: def _minus(x: ir.Value) -> ir.Value: if tt_dialect.PointerType.isinstance(_element_type(x.type)): raise NotImplementedError(f"unsupported type: {x.type}") - return _sub(_full(x.type, 0), x) + return _sub(_zeros_like(x), x) def _add(x: ir.Value, y: ir.Value): @@ -1377,7 +1377,7 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): @register_lowering(lax.integer_pow_p) def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): if y == 0: - return _full(x.type, 1) + return _ones_like(x) is_reciprocal = y < 0 if is_reciprocal: @@ -1397,7 +1397,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): acc = _cast(acc, x_aval.dtype, out_aval.dtype) if is_reciprocal: signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) - return _truediv(_full(acc.type, 1), acc, signed=signed) + return _truediv(_ones_like(acc), acc, signed=signed) else: return acc @@ -1518,6 +1518,22 @@ def _full(t: ir.Type, v: object) -> ir.Type: return result +def _zeros(t: ir.Type) -> ir.Value: + return _full(t, 0) + + +def _zeros_like(x: ir.Value) -> ir.Value: + return _full(x.type, 0) + + +def _ones(t: ir.Type) -> ir.Value: + return _full(t, 1) + + +def _ones_like(x: ir.Value) -> ir.Value: + return _full(x.type, 1) + + def _splat(x: ir.value, shape: Sequence[int]) -> ir.Value: if ir.RankedTensorType.isinstance(x.type): raise TypeError("cannot splat a tensor") @@ -1556,7 +1572,7 @@ def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value: dst_element_type = ir.IntegerType(_element_type(dst_type)) assert src_element_type != dst_element_type if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) if src_element_type.width == dst_element_type.width: return arith_dialect.bitcast(dst_type, src) @@ -1576,7 +1592,7 @@ def _float_int_cast( raise NotImplementedError(f"cannot cast {src} tp {dst_type}") dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) else: # We clamp the float value to the min/max integer destination value # in order to match JAX/XLA 
casting behavior. Note that this differs @@ -1679,7 +1695,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, return tt_dialect.ptr_to_int(dst_type, src) elif dst_element_type.width == 1: x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed) - zero = _full(x.type, 0) + zero = _zeros_like(x) return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed) if isinstance( src_element_type, ir.IntegerType @@ -1802,7 +1818,7 @@ def _compute_offsets_from_indices( # Use 64-bit indexing when offset might be >= 2**32 bytes. offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) if indexer_shape: - offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0) + offsets = _zeros(ir.RankedTensorType.get(indexer_shape, offset_eltype)) else: offsets = _ir_constant(0, offset_eltype) @@ -2074,7 +2090,7 @@ def _masked_load_lowering_rule( offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False) in_msb = _mod(offsets, _full(offsets.type, 2), signed=False) if jaxlib_version < (0, 5, 2): - in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1)) + in_msb = arith_dialect.xori(in_msb, _ones_like(in_msb)) shift = _mul(in_msb, _full(in_msb.type, 4)) shift = _ir_cast(shift, values.type, signed=False) values = arith_dialect.shrui(values, shift) @@ -2280,7 +2296,7 @@ def _dot_general_lowering( m, _ = a_type.shape _, n = b_type.shape - acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype))) if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X3: bf16 = _dtype_to_ir_type(jnp.bfloat16) @@ -2297,11 +2313,7 @@ def _dot_general_lowering( acc = tt_dialect.dot(a_bf16, b_err0, acc) # If `a_err0` will be zero and `b` is infinite, then `acc` may contain # `NaN`s (as `0 * inf = NaN`), and vice versa. - acc = arith_dialect.select( - _is_nan(acc), - _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0), - acc, - ) + acc = arith_dialect.select(_is_nan(acc), _zeros_like(acc), acc) a, b = a_bf16, b_bf16 acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) From fd0ac0229ff8006e3105615f0837d2f224ff1095 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 22:13:40 -0700 Subject: [PATCH 0100/1769] [mosaic_gpu] Add `cupti_no_finalize` profiler mode. PiperOrigin-RevId: 739397564 --- jax/experimental/mosaic/gpu/profiler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 32b3edf7caf9..011b921d728e 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -98,7 +98,7 @@ def run(*args, **kwargs): return outs, float(elapsed) -def _measure_cupti(f, aggregate): +def _measure_cupti(f, aggregate, *, finalize=True): if not isinstance(f, (stages.Wrapped, stages.Compiled)): f = jax.jit(f) @@ -108,7 +108,7 @@ def wrapper(*args, **kwargs): try: results = jax.block_until_ready(f(*args, **kwargs)) finally: - timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() + timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings(finalize) if not timings: return results, None @@ -133,6 +133,7 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True mode: The mode of operation. Possible values are: - "cupti", for CUPTI-based profiling. + - "cupti_no_finalize", as above, but CUPTI left attached to the process. - "events", for CUDA events-based profiling. 
The two modes use different measurement methodologies and should not be @@ -175,10 +176,12 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True In an attempt to minimize the second effect, internally the events-based implementation may execute ``f`` more than once to "warm up" and exclude compilation time from the measurement. - """ + """ # fmt: skip match mode: case "cupti": return _measure_cupti(f, aggregate) + case "cupti_no_finalize": + return _measure_cupti(f, aggregate, finalize=False) case "events": if not aggregate: raise ValueError(f"{aggregate=} is not supported with {mode=}") From 74977938d8355b41c389b146f4c71f205ceff3ec Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 22:32:43 -0700 Subject: [PATCH 0101/1769] [pallas] Add support for `DotAlgorithmPreset.BF16_BF16_F32_X{6,9}` in Triton lowering. PiperOrigin-RevId: 739400359 --- jax/_src/pallas/triton/lowering.py | 50 ++++++++++++++++++++++-------- tests/pallas/pallas_test.py | 9 ++++-- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index bc7144f376b4..0077ec55ace8 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2218,6 +2218,14 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) +def _as_bf16(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.bfloat16), signed=False) + + +def _as_f32(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.float32), signed=False) + + @register_lowering(lax.dot_general_p) def _dot_general_lowering( ctx: LoweringRuleContext, @@ -2258,6 +2266,8 @@ def _dot_general_lowering( | lax.DotAlgorithmPreset.BF16_BF16_BF16 | lax.DotAlgorithmPreset.BF16_BF16_F32 | lax.DotAlgorithmPreset.BF16_BF16_F32_X3 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X6 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X9 ): input_precision = None case _: @@ -2298,20 +2308,34 @@ def _dot_general_lowering( _, n = b_type.shape acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype))) - if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X3: - bf16 = _dtype_to_ir_type(jnp.bfloat16) - f32 = _dtype_to_ir_type(jnp.float32) - as_bf16 = lambda x: _ir_cast(x, bf16, signed=False) - as_f32 = lambda x: _ir_cast(x, f32, signed=False) - - a_bf16 = as_bf16(a) - b_bf16 = as_bf16(b) - a_err0 = as_bf16(_sub(a, as_f32(a_bf16))) - b_err0 = as_bf16(_sub(b, as_f32(b_bf16))) + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + a_bf16 = _as_bf16(a) + b_bf16 = _as_bf16(b) + a_err0 = _sub(a, _as_f32(a_bf16)) + b_err0 = _sub(b, _as_f32(b_bf16)) + a_err0_bf16 = _as_bf16(a_err0) + b_err0_bf16 = _as_bf16(b_err0) + a_err1_bf16 = _as_bf16(_sub(a_err0, _as_f32(a_err0_bf16))) + b_err1_bf16 = _as_bf16(_sub(b_err0, _as_f32(b_err0_bf16))) # Accumulate the smallest values first to reduce the numeric error. 
- acc = tt_dialect.dot(a_err0, b_bf16, acc) - acc = tt_dialect.dot(a_bf16, b_err0, acc) - # If `a_err0` will be zero and `b` is infinite, then `acc` may contain + if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X9: + acc = tt_dialect.dot(a_err1_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err1_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err1_bf16, acc) + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + acc = tt_dialect.dot(a_err1_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err0_bf16, acc) + # If `a` rounding error is zero and `b` is `inf` then `acc` may contain # `NaN`s (as `0 * inf = NaN`), and vice versa. acc = arith_dialect.select(_is_nan(acc), _zeros_like(acc), acc) a, b = a_bf16, b_bf16 diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 0ce68a5c023c..6f52a7afb1bf 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -703,6 +703,8 @@ def f(x): ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), @@ -733,9 +735,12 @@ def dot_kernel(x_ref, y_ref, o_ref): preferred_element_type=jnp.float32, ) if dtype == "bfloat16" or precision in ( - jax.lax.Precision.HIGHEST, jax.lax.DotAlgorithmPreset.F32_F32_F32 + jax.lax.Precision.HIGHEST, + jax.lax.DotAlgorithmPreset.F32_F32_F32, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9, ): - atol = 0 + atol = 5e-6 elif precision in ( jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3, jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3, From d4745b9bd81b49e2a7a8938ea98516296d54635f Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 22 Mar 2025 01:54:12 -0700 Subject: [PATCH 0102/1769] Reverts ad21b62bfec5560d4c612ed3c8412eb2d240468b PiperOrigin-RevId: 739431800 --- tests/pgle_test.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7dabd809d95e..7f9ea598d51b 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -65,11 +65,7 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # Make sure that matmul is not emitted as Triton GEMM. - 'xla_gpu_enable_triton_gemm': 'False', - }, + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y @@ -97,8 +93,6 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # Make sure that matmul is not emitted as Triton GEMM. - 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. 
compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -327,11 +321,7 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # Make sure that matmul is not emitted as Triton GEMM. - 'xla_gpu_enable_triton_gemm': 'False', - }, + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y From 34aa5e69477f74d5e1d5e2945c7fd23f72c6dd6e Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 22 Mar 2025 04:21:51 -0700 Subject: [PATCH 0103/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1b26f2c8502a7d180ce959d0e6546c91ef820b02. PiperOrigin-RevId: 739453338 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 00f985cdf352..305cf14c1045 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "469329ec36be093fd71d29e4518402300e04aeec" -XLA_SHA256 = "9de006d7b51c36057898c81111fa9723b59f024eced067572fe5f6b1df63abdd" +XLA_COMMIT = "1b26f2c8502a7d180ce959d0e6546c91ef820b02" +XLA_SHA256 = "9492831de7840a3977eb8fcad34f2673e1bd8871cb060f9b6ee93f622956b896" def repo(): tf_http_archive( From a092df90ba7868f86e71cdaed245bb1abd77f1d4 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 22 Mar 2025 20:40:14 +0000 Subject: [PATCH 0104/1769] fix a linearize-of-remat-of-while_loop-fixpoint bug We were using the original unknown-carries-in rather than the fixpoint-updated ones. 
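
The failing pattern (distilled from the regression test added below; the
function here is purely illustrative) is linearizing a rematerialized
while_loop whose body permutes its carries, so which carries are unknown
only settles after the partial-eval fixpoint:

  @jax.remat
  def f(x, y):
    body = lambda c: (c[1], c[0])  # swap the carries each iteration
    x, y = jax.lax.while_loop(lambda _: False, body, (x, y))
    return x + y

  jax.linearize(f, 1., 2.)  # previously crashed; now the known eqn typechecks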
--- jax/_src/lax/control_flow/loops.py | 41 +++++++++++++++++++++++++----- tests/lax_control_flow_test.py | 12 +++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 3084fa722977..33e2d2cbb0c8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1438,7 +1438,29 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, cond_nconsts): - del avals + cond_consts_avals, body_consts_avals, in_avals = \ + util.split_list(avals, [cond_nconsts, body_nconsts]) + + if len(cond_jaxpr.in_avals) != len(cond_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(cond_jaxpr.in_avals)=} but {len(cond_consts_avals) + len(in_avals)=}") + if len(body_jaxpr.in_avals) != len(body_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(body_jaxpr.in_avals)=} but {len(body_consts_avals) + len(in_avals)=}") + # TODO(mattjj): check body carry type + # TODO(mattjj): make these typecompat checks work with bints + # if not all(_map(core.typecompat, [*cond_consts_avals, *in_avals], cond_jaxpr.in_avals)): # type: ignore + # cond_avals = [*cond_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(cond_avals, cond_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop cond function input type error: {a1} != {a2}") + # if not all(_map(core.typecompat, [*body_consts_avals, *in_avals], body_jaxpr.in_avals)): # type: ignore + # body_avals = [*body_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(body_avals, body_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop body function input type error: {a1} != {a2}") + + joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) @@ -1679,7 +1701,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): assert False, "Fixpoint not reached" assert not num_res body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts) - del jaxpr_known_, carry_uk_out, num_res + del jaxpr_known_, carry_uk_out, num_res, unks_in # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -1701,6 +1723,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): del cond_uk # Build the known eqn. + unks_in = [*cond_consts_uk, *body_consts_uk, *carry_uk] # fixpoint carry_uk ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(carry_uk, eqn.outvars) params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known, @@ -1711,6 +1734,11 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p, params_known, effects_known, eqn.source_info, eqn.ctx) + # Typecheck known eqn. + _while_loop_abstract_eval( + *[v.aval for v in eqn_known.invars], cond_jaxpr=cond_jaxpr_known, + body_jaxpr=body_jaxpr_known, body_nconsts=params_known['body_nconsts'], + cond_nconsts=params_known['cond_nconsts']) # Staged eqn is same as input eqn. 
eqn_staged = eqn @@ -1798,8 +1826,7 @@ def fun(*args): cond_block.arguments[i] for i in range(len(flat_loop_carry_types)) ] cond_args = mlir.unflatten_ir_values_like_types(flat_cond_args, loop_carry_types) - # Remove tokens from cond args - cond_args = cond_args[num_tokens:] + cond_args = cond_args[num_tokens:] # Remove tokens from cond args x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) cond_consts = [ mlir.ir_constant(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts @@ -1861,8 +1888,9 @@ def fun(*args): partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) - hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), - *mlir.flatten_ir_values(new_z)]) + hlo.return_([*mlir.flatten_ir_values(out_tokens), + *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), + *mlir.flatten_ir_values(new_z)]) outputs = mlir.unflatten_ir_values_like_types(while_op.results, loop_carry_types) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) @@ -1976,7 +2004,6 @@ def new_cond(*consts_refs_carry): batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) -core.custom_typechecks[while_p] = _while_typecheck state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3871a87a7a3e..fcc7fd99ee13 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3066,6 +3066,18 @@ def g(): leak() self.assertEqual(base, nbufs()) + def test_grad_remat_while_fixpoint(self): + @jax.remat + def f(x, y): + def cond(_): + return False + def body(c): + x, y = c + return (y, x) + x, y = jax.lax.while_loop(cond, body, (x, y)) + return x + y + jax.linearize(f, 1., 2.) # don't crash + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 540541a3d32fe6678aa1c208a4f5a9a697b92e2c Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 23 Mar 2025 03:39:34 -0700 Subject: [PATCH 0105/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/5ed2ff5c07868d1a7486f4040f8b38936640268e. PiperOrigin-RevId: 739649983 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 305cf14c1045..3e3d636f0f43 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "1b26f2c8502a7d180ce959d0e6546c91ef820b02" -XLA_SHA256 = "9492831de7840a3977eb8fcad34f2673e1bd8871cb060f9b6ee93f622956b896" +XLA_COMMIT = "5ed2ff5c07868d1a7486f4040f8b38936640268e" +XLA_SHA256 = "08d175c57d0db599ad57b8fa820ca2f2a6d2808578d53dba421e3af4edb0bccf" def repo(): tf_http_archive( From 5d79df7e67cfc8b253c817f90af81393ea256763 Mon Sep 17 00:00:00 2001 From: Jesse Perla Date: Sun, 23 Mar 2025 15:03:49 -0700 Subject: [PATCH 0106/1769] Add identity activation Fix typo --- docs/jax.nn.rst | 1 + jax/_src/nn/functions.py | 19 +++++++++++++++++++ jax/nn/__init__.py | 1 + tests/nn_test.py | 8 +++++++- 4 files changed, 28 insertions(+), 1 deletion(-) diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index adb13f89903d..2e2e9644d50d 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -40,6 +40,7 @@ Activation functions glu squareplus mish + identity Other functions --------------- diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 7df0a638e566..ee0643e116f9 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -54,6 +54,25 @@ def __repr__(self): # activations +@jax.jit +def identity(x: ArrayLike) -> Array: + r"""Identity activation function. + + Returns the argument unmodified. + + Args: + x : input array + + Returns: + The argument `x` unmodified. + + Examples: + >>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) + Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32) + + """ + numpy_util.check_arraylike("identity", x) + return jnp.asarray(x) @custom_jvp @jax.jit diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 3f08e1c0fd12..10f11f829abe 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -35,6 +35,7 @@ standardize as standardize, one_hot as one_hot, relu as relu, + identity as identity, relu6 as relu6, dot_product_attention as dot_product_attention, scaled_dot_general as scaled_dot_general, diff --git a/tests/nn_test.py b/tests/nn_test.py index 1a1670444ef8..e46843186c02 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -543,7 +543,7 @@ def gelu_reference(x): (jnp.float32, jnp.bfloat16, jnp.float16), (partial(nn.gelu, approximate=False), partial(nn.gelu, approximate=True), - nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) + nn.relu, nn.identity, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) def testDtypeMatchesInput(self, dtype, fn): x = jnp.zeros((), dtype=dtype) out = fn(x) @@ -831,6 +831,12 @@ def testVarianceScalingError(self): ): initializer(rng, shape) + def testIdentity(self): + x = jnp.array([1., 2., 3.]) + self.assertAllClose(nn.identity(x), x, check_dtypes=False) + grad = jax.grad(nn.identity)(6.0) + self.assertEqual(grad, 1.) + def testAccidentalUpcasting(self): rng = random.PRNGKey(0) shape = (4, 4) From 5b0a767d83cf28b41dd1c2207eb56010bcb594d7 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Sun, 23 Mar 2025 21:33:01 -0700 Subject: [PATCH 0107/1769] [jax] Add `ndim` and `size` properties to `TransformedRef`. Without these implementations, `ndim` and `size` were retrieved from the underlying, non-transformed reference and were inconsistent with `TransformedRef.shape`. 
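
As an illustrative sketch (shapes hypothetical): for a ref transformed to
shape (8, 128), the properties now satisfy

  ref.ndim == len(ref.shape) == 2
  ref.size == math.prod(ref.shape) == 1024

instead of reporting the ndim/size of the underlying untransformed ref.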
PiperOrigin-RevId: 739802491
---
 jax/_src/state/types.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py
index 057242f4c1ac..fa9d0cb9fb16 100644
--- a/jax/_src/state/types.py
+++ b/jax/_src/state/types.py
@@ -266,6 +266,9 @@ def dtype(self):
     assert dtype is not None
     return dtype
 
+  ndim = property(lambda self: len(self.shape))
+  size = property(lambda self: math.prod(self.shape))
+
   @property
   def at(self) -> RefIndexer:
     return RefIndexer(self)

From a2475a66c50c148fbe4dcafd61b917c6435d1e4a Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Mon, 24 Mar 2025 02:06:10 -0700
Subject: [PATCH 0108/1769] [pallas] Add support for `split` (into two equal
 parts) in Triton lowering.

PiperOrigin-RevId: 739855323
---
 jax/_src/pallas/triton/lowering.py | 15 +++++++++++++++
 tests/pallas/pallas_test.py        | 14 ++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 0077ec55ace8..d7f6e4695229 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -1798,6 +1798,21 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension):
   )
 
 
+@register_lowering(lax.split_p)
+def _split_lowering_rule(ctx: LoweringRuleContext, x, *, sizes, axis):
+  # TODO(cjfj): Add support for larger powers of 2.
+  if len(sizes) != 2:
+    raise NotImplementedError("Only splitting into two parts is supported.")
+  if sizes[0] != sizes[1]:
+    raise NotImplementedError("Only equal-sized splits are supported.")
+  (x_aval,) = ctx.avals_in
+  shape = x_aval.shape
+  x = _reshape(x, shape[:axis] + (2, sizes[0]) + shape[axis + 1 :])
+  permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,)
+  return tt_dialect.split(tt_dialect.trans(x, permutation))
+
+
 def _compute_offsets_from_indices(
     block_info: BlockInfo, nd_indexer: NDIndexer
 ) -> ir.Value:
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 6f52a7afb1bf..0b16260ff25a 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -843,6 +843,20 @@ def dot_kernel(x_ref, y_ref, o_ref):
 
     self.assertAllClose(dot_kernel(x, y), expected)
 
+  @parameterized.parameters(
+      ((32,), 0), ((32, 64), 0), ((32, 16), 1), ((32, 16, 2), 1)
+  )
+  def test_split(self, shape, axis):
+    x = jax.random.normal(jax.random.key(0), shape)
+    expected = jnp.split(x, 2, axis)
+
+    @functools.partial(self.pallas_call, out_shape=expected)
+    def kernel(x_ref, o0_ref, o1_ref):
+      o0_ref[()], o1_ref[()] = jnp.split(x_ref[()], 2, axis)
+
+    self.assertAllClose(kernel(x), expected)
+
+
 class PallasCallInterpretTest(PallasCallTest):
   INTERPRET = True

From 4da1faf5b6cc8c1e99b3abf6de5f5889f0dc43dd Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 24 Mar 2025 02:49:43 -0700
Subject: [PATCH 0109/1769] Move PGLE documentation to JAX docs.

PiperOrigin-RevId: 739865595
---
 docs/gpu_performance_tips.md | 144 ++++++++++++++++++++++++++++++++++-
 1 file changed, 142 insertions(+), 2 deletions(-)

diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md
index bf032dccff88..737486485736 100644
--- a/docs/gpu_performance_tips.md
+++ b/docs/gpu_performance_tips.md
@@ -1,6 +1,6 @@
 # GPU performance tips
 
-
+
 
 This document focuses on performance tips for neural network workloads
@@ -58,7 +58,147 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
 * **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports.
  The default value is False.
 
-### Communication flags
+## Communication tips
+
+### Auto and manual PGLE
+
+The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time
+of compute and collectives, and the profile information is then fed back into the XLA
+compiler to make better scheduling decisions.
+
+The Profile Guided Latency Estimator can be used manually or automatically. In auto mode,
+JAX will collect profile information and recompile a module within a single run. In
+manual mode you need to run a task twice: the first time to collect and save profiles,
+and the second to compile and run with the provided data.
+
+### Auto PGLE
+Auto PGLE can be turned on by setting the following environment variables:
+
+Mandatory:
+```bash
+export JAX_ENABLE_PGLE=true
+
+# For JAX version <= 0.5.0 make sure to include:
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
+```
+
+Optional:
+```bash
+export JAX_PGLE_PROFILING_RUNS=3
+export JAX_PGLE_AGGREGATION_PERCENTILE=85
+
+# Right now the auto PGLE profile collection doesn't work with command buffer.
+# If the command buffer is enabled, Auto PGLE will disable it during profile
+# collection and enable it back after the recompilation. If you need to have
+# consistent command buffer logic with and without the PGLE profile, you can
+# disable it manually:
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"
+```
+
+Or this can be set from within JAX as follows:
+
+```
+import jax
+from jax._src import config
+
+with config.enable_pgle(True), config.pgle_profiling_runs(1):
+  # Run with the profiler collecting performance information.
+  train_step()
+  # Automatically re-compile with PGLE profile results
+  train_step()
+  ...
+```
+
+You can control the number of reruns used to collect profile data by changing `JAX_PGLE_PROFILING_RUNS`.
+Increasing this parameter leads to better profile information, but it also increases the
+number of non-optimized training steps.
+
+Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter might help when performance
+between steps is too noisy, by filtering out non-relevant measurements.
+
+**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX needs to
+recompile the module during execution, auto PGLE works neither for AoT compilation nor
+for the following case:
+
+```
+import jax
+from jax._src import config
+
+train_step_compiled = train_step().lower().compile()
+
+with config.enable_pgle(True), config.pgle_profiling_runs(1):
+  train_step_compiled()
+  # No effect since module was pre-compiled.
+  train_step_compiled()
+```
+
+### Manual PGLE
+
+If you still want to use the manual Profile Guided Latency Estimator, the workflow in XLA/GPU is:
+
+- 1. Run your workload once, with async collectives and latency hiding scheduler enabled.
+
+You could do so by setting:
+
+```bash
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
+```
+
+- 2. Collect and post-process a profile using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file (see the snippet below).
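+
+The snippet below is a sketch rather than a turnkey script: the
+`gs://my_bucket/profile` path is illustrative, and the `etils` package is
+assumed only for convenient path handling.
+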
+```python
+import logging
+import os
+from etils import epath
+import jax
+from jax.experimental import profiler as exp_profiler
+
+# Define your profile directory
+profile_dir = 'gs://my_bucket/profile'
+jax.profiler.start_trace(profile_dir)
+
+# run your workflow
+# for i in range(10):
+#   train_step()
+
+# Stop trace
+jax.profiler.stop_trace()
+profile_dir = epath.Path(profile_dir)
+directories = profile_dir.glob('plugins/profile/*/')
+directories = [d for d in directories if d.is_dir()]
+rundir = directories[-1]
+logging.info('rundir: %s', rundir)
+
+# Post-process the profile
+fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))
+
+# Save the profile proto to a file.
+dump_file = rundir / 'profile.pb'
+dump_file.parent.mkdir(parents=True, exist_ok=True)
+dump_file.write_bytes(fdo_profile)
+
+```
+
+After this step, you will get a `profile.pb` file under the `rundir` printed in the code.
+
+- 3. Run the workload again, feeding that file into the compilation.
+
+You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag.
+
+```bash
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
+```
+
+To enable logging in XLA and check whether the profile is good, set the logging level to include `INFO`:
+
+```bash
+export TF_CPP_MIN_LOG_LEVEL=0
+```
+
+Run the real workflow; if you find these lines in the running log, it means the profile is being used by the latency hiding scheduler:
+
+```
+2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
+2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
+```
+
+#### Flags
 
 * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
 schedulers to overlap asynchronous communication with computation efficiently.

From 0c38368bce53e5aab7a9ad3e1fc858668035874a Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Mon, 24 Mar 2025 03:27:15 -0700
Subject: [PATCH 0110/1769] [mosaic_gpu] Add `Cupti` profiler class.

PiperOrigin-RevId: 739874654
---
 jax/experimental/mosaic/gpu/profiler.py | 59 +++++++++++++++----------
 1 file changed, 35 insertions(+), 24 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py
index 011b921d728e..5b278468b98c 100644
--- a/jax/experimental/mosaic/gpu/profiler.py
+++ b/jax/experimental/mosaic/gpu/profiler.py
@@ -17,7 +17,7 @@
 import itertools
 import json
 import math
-from typing import Callable, ParamSpec, TypeVar
+from typing import Callable, ParamSpec, TypeAlias, TypeVar
 import warnings
 
 import jax
@@ -98,30 +98,44 @@ def run(*args, **kwargs):
   return outs, float(elapsed)
 
 
-def _measure_cupti(f, aggregate, *, finalize=True):
-  if not isinstance(f, (stages.Wrapped, stages.Compiled)):
-    f = jax.jit(f)
+Timings: TypeAlias = list[tuple[str, float]] | float | None
 
-  def wrapper(*args, **kwargs):
-    jax.block_until_ready(f(*args, **kwargs))  # Warmup.
- mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() - try: - results = jax.block_until_ready(f(*args, **kwargs)) - finally: - timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings(finalize) - if not timings: - return results, None - elif aggregate: - return results, sum(item[1] for item in timings) - else: - return results, timings +@dataclasses.dataclass(frozen=True, kw_only=True) +class Cupti: + """CUPTI-based profiler.""" - return wrapper + # If `True`, detach CUPTI from the process after measurement. + finalize: bool = True + def measure( + self, f: Callable[P, T], *, aggregate: bool = True + ) -> Callable[P, tuple[T, Timings]]: + if not isinstance(f, (stages.Wrapped, stages.Compiled)): + f = jax.jit(f) -def measure(f: Callable, *, mode: str = "events", aggregate: bool = True -) -> Callable: + def wrapper(*args: P.args, **kwargs: P.kwargs): + jax.block_until_ready(f(*args, **kwargs)) # Warmup. + ext = mosaic_gpu_lib._mosaic_gpu_ext + ext._cupti_init() + try: + results = jax.block_until_ready(f(*args, **kwargs)) + finally: + timings = ext._cupti_get_timings(self.finalize) + + if not timings: + return results, None + elif aggregate: + return results, sum(item[1] for item in timings) + else: + return results, timings + + return wrapper + + +def measure( + f: Callable[P, T], *, mode: str = "events", aggregate: bool = True +) -> Callable[P, tuple[T, Timings]]: """Sets up a function ``f`` for profiling on GPU. ``measure`` is a higher-order function that augments the argument ``f`` to @@ -133,7 +147,6 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True mode: The mode of operation. Possible values are: - "cupti", for CUPTI-based profiling. - - "cupti_no_finalize", as above, but CUPTI left attached to the process. - "events", for CUDA events-based profiling. The two modes use different measurement methodologies and should not be @@ -179,9 +192,7 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True """ # fmt: skip match mode: case "cupti": - return _measure_cupti(f, aggregate) - case "cupti_no_finalize": - return _measure_cupti(f, aggregate, finalize=False) + return Cupti().measure(f, aggregate=aggregate) case "events": if not aggregate: raise ValueError(f"{aggregate=} is not supported with {mode=}") From a3e6c6ef61bd0193ea5977e2dce7b6e861e48f52 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 24 Mar 2025 04:02:58 -0700 Subject: [PATCH 0111/1769] [Mosaic GPU] Add support for f16 Blackwell MMA accumulation Very importantly, this also includes support for loading the packed accumulator from TMEM. 
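
Concretely, with a 16-bit accumulator two elements now travel per 32-bit
TMEM word, so `tmem_load` switches the load shape from 16x256b to 16x128b
and appends the `.pack::16b` modifier to the PTX instruction, e.g.
(instantiating the inline-asm template in the diff with num=32):

  tcgen05.ld.sync.aligned.16x128b.x32.pack::16b.b32 {...}, [addr];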
PiperOrigin-RevId: 739883035 --- jax/experimental/mosaic/gpu/tcgen05.py | 59 +++++++++++++++++--------- tests/mosaic/gpu_test.py | 53 ++++++++++++++++------- 2 files changed, 77 insertions(+), 35 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 3330500cd6dc..ac3b80b93689 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -197,7 +197,7 @@ def mma( ), a_mk, b_nk, - d_type=ir.F32Type.get(), + d_type=d.dtype, m=m_group_elems, n=n_group_elems, collective=collective, @@ -327,7 +327,7 @@ def tmem_relinquish_alloc_permit(): has_side_effects=True, ) -def tmem_load(tmem_addr, shape, num): +def tmem_load(tmem_addr, shape, num, packing: int = 1): if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: @@ -345,12 +345,18 @@ def tmem_load(tmem_addr, shape, num): num_out_regs *= num i32 = ir.IntegerType.get_signless(32) out_regs = ",".join("$" + str(i) for i in range(num_out_regs)) + if packing == 1: + pack_mod = "" + elif packing == 2: + pack_mod = ".pack::16b" + else: + raise ValueError(f"Unsupported packing: {packing}") regs = llvm.inline_asm( ir.Type.parse( "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>" ), [tmem_addr], - f"tcgen05.ld.sync.aligned.{shape}.x{num}.b32 {{{out_regs}}}, [${num_out_regs}];", + f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {{{out_regs}}}, [${num_out_regs}];", "=r," * num_out_regs + "r", has_side_effects=True, ) @@ -521,9 +527,9 @@ def __getitem__(self, *idxs): raise NotImplementedError("Slicing of TMEM not impelmented yet") if self.shape[1] % 8: raise NotImplementedError - if self.dtype != ir.F32Type.get(): - raise NotImplementedError(self.dtype) - layout = _m128_256bit_32bit_layout(self.shape) + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + layout = _m128_layout(self.shape) regs_shape = layout.registers_shape(self.shape) if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): # load_32xcols returns a 4xN array, but the FA tiling we use here tiles @@ -556,20 +562,28 @@ def __getitem__(self, *idxs): ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) + def _load_32xcols(base_addr, cols, dtype): # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b i32 = ir.IntegerType.get_signless(32) - assert cols % 8 == 0 - cols_per_num_tile = 8 - load_shape = "16x256b" - num = cols // 8 + packing = 32 // utils.bitwidth(dtype) + if packing == 1: + load_shape = "16x256b" # 8 columns * 32 bits = 256 bits + cols_per_num_tile = 8 * packing + elif packing == 2: + load_shape = "16x128b" # 8 columns * 16 bits = 128 bits + cols_per_num_tile = 4 * packing + else: + raise NotImplementedError(packing) + assert cols % cols_per_num_tile == 0 + num = cols // cols_per_num_tile if num <= 32: num_tiling = num elif num == 64: num_tiling = 32 else: raise NotImplementedError(num) - vector_regs = np.ndarray((4, num), dtype=object) + vector_regs = np.ndarray((4, cols // 8), dtype=object) # We load 16 lanes at a time, but need 32 in total. 
for row_group in range(2): addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) @@ -579,17 +593,24 @@ def _load_32xcols(base_addr, cols, dtype): addr_row, arith.constant(i32, num_tiling * num_group * cols_per_num_tile), ) - regs += tmem_load(addr_row_col, load_shape, num_tiling) - regs = [llvm.bitcast(dtype, r) for r in regs] - undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + regs += tmem_load(addr_row_col, load_shape, num_tiling, packing) + if packing == 1: + regs = [llvm.bitcast(dtype, r) for r in regs] + undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) + for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(cols // 8, 2), strict=True): + high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) + vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) + vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + else: + assert packing == 2 + regs = [llvm.bitcast(ir.VectorType.get((2,), dtype), r) for r in regs] + for vreg, idx in zip(regs, np.ndindex(cols // 8, 2), strict=True): + vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg return vector_regs -def _m128_256bit_32bit_layout(shape: tuple[int, ...]): +# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. +def _m128_layout(shape: tuple[int, ...]): if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") if shape[0] % 128 != 0 or shape[1] % 8 != 0: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e7bd7fad3798..478064188750 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -912,15 +912,41 @@ def setUp(self): lhs_transpose=(False, True), rhs_transpose=(False, True), in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + out_jax_dtype=(jnp.float16, jnp.float32,), m=(128,), # TODO(apaszke): 64, 192, 256 n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 - k_steps=(1, 2), swizzle=(32, 64, 128,), - rhs_transpose_tiles=(False, True), + ) + def test_mma_basic(self, *args, **kwargs): + self._basic_mma_test( + *args, + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(128,), + n=(128, 512), + swizzle=(32, 64, 128,), lhs_transpose_tiles=(False, True), + rhs_transpose_tiles=(False, True), ) - def test_mma_basic( + def test_mma_transposed_tiles(self, *args, **kwargs): + if not kwargs["lhs_transpose_tiles"] and not kwargs["rhs_transpose_tiles"]: + self.skipTest("This is already tested in test_mma_basic") + self._basic_mma_test( + *args, + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. 
+ ) + + def _basic_mma_test( self, m, n, @@ -981,16 +1007,10 @@ def kernel(ctx, lhs, rhs, out, scratch): barriers[2].wait(for_tensor_core=True) acc[:].store_untiled(out) - in_finfo = jnp.finfo(in_jax_dtype) - exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant - def quantize(x): - # Quantize the input to avoid rounding when feeding the TensorCore - return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) - x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) if rhs_transpose_tiles: rhs_smem_shape = ( @@ -1015,14 +1035,15 @@ def quantize(x): )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) - atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 - np.testing.assert_allclose(z, ref, atol=atol) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), m=(256,), # TODO(apaszke): 64, 192, 256 n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), From 01904363e4a9d2f721c0e2193ef4b199a4ffe9ac Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 04:16:40 -0700 Subject: [PATCH 0112/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9a8dd0796bcfeb00e4e6d09d74726db5c7d4a003. PiperOrigin-RevId: 739886693 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3e3d636f0f43..996ee511f835 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "5ed2ff5c07868d1a7486f4040f8b38936640268e" -XLA_SHA256 = "08d175c57d0db599ad57b8fa820ca2f2a6d2808578d53dba421e3af4edb0bccf" +XLA_COMMIT = "9a8dd0796bcfeb00e4e6d09d74726db5c7d4a003" +XLA_SHA256 = "4e3248d37a1b0598de3e93e8e46ede060578bc45bfbdfaf24d91ab598543b770" def repo(): tf_http_archive( From 381f11090e702fa9403e178a6699017c62d24453 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 09:26:06 -0400 Subject: [PATCH 0113/1769] Reenable tsan suppression, mark some tests as thread-unsafe. 
--- .github/workflows/tsan-suppressions.txt | 6 +++--- .github/workflows/tsan.yaml | 1 + tests/cache_key_test.py | 2 ++ tests/pjit_test.py | 1 + 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 296f4432e687..bdffddc58ca0 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -24,6 +24,9 @@ race_top:PyMember_GetOne # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj + # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi race:gesdd_ffi @@ -59,9 +62,6 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130547 # race:split_keys_entry_added -# https://github.com/python/cpython/issues/128130 -# race_top:run_eval_code_obj - # https://github.com/python/cpython/issues/129547 # Maybe fixed? # race:type_get_annotations diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 6c97b7347ceb..cd59c0bf45e0 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -13,6 +13,7 @@ on: - main paths: - '**/workflows/tsan.yaml' + - '**/workflows/tsan-suppressions.txt' jobs: tsan: diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 2faa4dbaf9d4..ed80c7060e4c 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -163,6 +163,8 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + # TODO(phawkins): this test flakes if test concurrency is enabled. + @jtu.thread_unsafe_test() def test_custom_partitioning_ptr_removal(self): def _partition(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6fdfa62887b9..2033126759e4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3438,6 +3438,7 @@ def f(x): pjit(f)(inp) self.assertEqual(count(), 1) + @jtu.thread_unsafe_test() # count_pjit_cpp_cache_miss is not thread-safe def test_pjit_no_global_cache_hit_axis_resources(self): mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P('x')) From c6525bc58f0ec1507c83a4c3f149208a1b60368f Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 24 Mar 2025 08:10:33 -0700 Subject: [PATCH 0114/1769] [Mosaic GPU][NFC] Fix documentation of `WGMMA_LAYOUT`. `TiledLayout` has no notion of partitioning over warpgroups, and each warp holds `16 x 8` elements. PiperOrigin-RevId: 739942481 --- jax/experimental/mosaic/gpu/fragmented_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 4bbfd0dd8afe..c2b61c6d5bfe 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -537,7 +537,7 @@ def linear_thread_idxs(self): # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d -# In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles. +# In this layout, we partition the 64x8 tiles over 4 warps into 16x8 tiles. # Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit # of data that is split across a warp. Since 8*8 = 64, but a warp has only 32 # threads, we vectorize pairs of elements along columns. 
From b6b5d952392960ceb78401302cb9d620719407b1 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Thu, 20 Mar 2025 16:25:51 -0700 Subject: [PATCH 0115/1769] [Pallas] In TPU interpret mode, add initial barrier for kernels without one. --- jax/_src/pallas/mosaic/interpret.py | 33 ++++++++++++++++--- .../tpu_pallas_interpret_distributed_test.py | 16 --------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 439ac98b2ac6..2e31d0fba7cf 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -451,6 +451,7 @@ class SharedMemory: num_devices: int clocks: list[VectorClock] barrier: threading.Barrier + clean_up_barrier: threading.Barrier # (memory_space, buffer_id, device_id) -> NumPy array # TODO(jburnim): Handle Megacore. @@ -502,18 +503,35 @@ def _initialize_shared_memory(device_id, num_devices, *, interpret_params): interpret_params=interpret_params, num_devices=num_devices, clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], - barrier=threading.Barrier(num_devices)) + barrier=threading.Barrier( + num_devices, action=_update_clocks_for_global_barrier), + clean_up_barrier=threading.Barrier( + num_devices, action=_clear_shared_memory)) assert _shared_memory.num_devices == num_devices global races races = RaceDetectionState(num_devices=num_devices) +def _update_clocks_for_global_barrier(): + shared_memory = _get_shared_memory() + with shared_memory.lock: + # Set the vector clock for device 0 to the max over all device clocks. + for c in shared_memory.clocks[1:]: + update_vector_clock(shared_memory.clocks[0], c) + # Set all other device vector clocks to the max over all the clocks. + for c in shared_memory.clocks[1:]: + update_vector_clock(c, shared_memory.clocks[0]) + +def _barrier(device_id): + device_id = int(device_id) + shared_memory = _get_shared_memory() + if shared_memory.num_devices > 1: + shared_memory.barrier.wait() + def _clean_up_shared_memory(device_id): device_id = int(device_id) shared_memory = _get_shared_memory() - shared_memory.barrier.wait() - if device_id == 0: - _clear_shared_memory() + shared_memory.clean_up_barrier.wait() def _validate(device_id): device_id = int(device_id) @@ -1359,7 +1377,7 @@ def interpret_pallas_call( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: Any, + compiler_params: mosaic_core.TPUCompilerParams, cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], interpret_params: TPUInterpretParams, @@ -1499,6 +1517,11 @@ def interpret_pallas_call( var.aval.shape, var.aval.dtype, interpret_params), ordered=True)) + if compiler_params.get('mosaic', {}).get('collective_id', None) is None: + # The kernel doesn't specify its own barrier semaphore, so we do a global + # barrier before running the first iteration of the kernel. + callback.io_callback(_barrier, (), device_id, ordered=True) + _, input_ids, kernel_output_ids, _ = split_list( kernel_buffer_ids, [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs]) diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 518c16ed2109..1ed139e9e867 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -18,8 +18,6 @@ contains only tests that use shard_map. 
""" -import functools - from absl.testing import absltest from absl.testing import parameterized @@ -1017,19 +1015,6 @@ def test_race_detection(self): input_arr = jax.device_put(input_arr, sharding) def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem): - # Barrier with all devices before doing any DMAs. - barrier_sem = pltpu.get_barrier_semaphore() - @functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None) - def _(i, _): - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(jnp.int32(i),), - device_id_type=pltpu.DeviceIdType.MESH, - ) - return None - pltpu.semaphore_wait(barrier_sem, num_devices) - # Send the specified DMAs. my_id = lax.axis_index('x') src_dst_ids = src_dst_ids_ref[:] @@ -1076,7 +1061,6 @@ def run(src_dst_ids): ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], - compiler_params=pltpu.TPUCompilerParams(collective_id=0), interpret=mosaic_interpret.TPUInterpretParams( dma_execution_mode='eager', detect_races=True, From 788ad8c6a2ded930a2cf4379780a749b680c5ba0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 09:02:09 -0700 Subject: [PATCH 0116/1769] Change `python-tag` to `python_tag` to conform to the new setuptools version. PiperOrigin-RevId: 739958612 --- jaxlib/tools/build_gpu_plugin_wheel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 667807b51197..d52cc7da36e8 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -81,7 +81,7 @@ def write_setup_cfg(sources_path, cpu): [bdist_wheel] plat_name={tag} -python-tag=py3 +python_tag=py3 """ ) From c1f65c3e1f045106e090e41d74b6968fed824b1c Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 09:27:50 -0700 Subject: [PATCH 0117/1769] Update CUDA version in Bazel configs to 12.8, and CUDNN version to 9.8. PiperOrigin-RevId: 739967341 --- .bazelrc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index fb938169b3c0..642fb15ed541 100644 --- a/.bazelrc +++ b/.bazelrc @@ -141,8 +141,8 @@ build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda # Default hermetic CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" build:cuda --@local_config_cuda//cuda:include_cuda_libs=true # This config is used for building targets with CUDA libraries from stubs. From 198d7bb9c29bbf20dab893739cf546a6e78f4c18 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 24 Mar 2025 09:29:55 -0700 Subject: [PATCH 0118/1769] [pallas] Add support for `split` into any power-of-two equal parts in Triton lowering. 
PiperOrigin-RevId: 739968019 --- jax/_src/pallas/triton/lowering.py | 23 +++++++++++++++-------- tests/pallas/pallas_test.py | 15 ++++++++++----- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index d7f6e4695229..a0883ea589b0 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1802,15 +1802,22 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): def _split_lowering_rule(ctx: LoweringRuleContext, x, *, sizes, axis): pass # TODO(cjfj): Add support for larger powers of 2. - if len(sizes) != 2: - raise NotImplementedError("Only splitting into two parts is supported.") - if sizes[0] != sizes[1]: + num_parts = len(sizes) + if num_parts != pallas_utils.next_power_of_2(num_parts): + raise NotImplementedError("Only power-of-2 num parts supported.") + if any(size != sizes[0] for size in sizes): raise NotImplementedError("Only equal-sized splits are supported.") - (x_aval,) = ctx.avals_in - shape = x_aval.shape - x = _reshape(x, shape[:axis] + (2, sizes[0]) + shape[axis + 1 :]) - permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) - return tt_dialect.split(tt_dialect.trans(x, permutation)) + + def split_into_2(x): + shape = ir.RankedTensorType(x.type).shape + x = _reshape(x, shape[:axis] + [2, shape[axis] // 2] + shape[axis + 1 :]) + permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) + return tuple(tt_dialect.split(tt_dialect.trans(x, permutation))) + + x_parts = (x,) + while len(x_parts) < num_parts: + x_parts = sum(map(split_into_2, x_parts), ()) + return x_parts def _compute_offsets_from_indices( diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 0b16260ff25a..9e5130b8f449 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -844,15 +844,20 @@ def dot_kernel(x_ref, y_ref, o_ref): self.assertAllClose(dot_kernel(x, y), expected) @parameterized.parameters( - ((32,), 0), ((32, 64), 0), ((32, 16), 1), ((32, 16, 2), 1) + ((32,), 2, 0), ((32, 64), 4, 0), ((32, 16), 8, 1), ((32, 16, 2), 16, 1) ) - def test_split(self, shape, axis): + def test_split(self, shape, num_parts, axis): + if jtu.test_device_matches(["tpu"]) and shape[axis] == num_parts: + self.skipTest("TPU doesn't support fully split axis.") + x = jax.random.normal(jax.random.key(0), shape) - expected = jnp.split(x, 2, axis) + expected = jnp.split(x, num_parts, axis) @functools.partial(self.pallas_call, out_shape=expected) - def kernel(x_ref, o0_ref, o1_ref): - o0_ref[()], o1_ref[()] = jnp.split(x_ref[()], 2, axis) + def kernel(x_ref, *o_ref): + x_parts = jnp.split(x_ref[()], num_parts, axis) + for o_ref, x_part in zip(o_ref, x_parts): + o_ref[...] = x_part self.assertAllClose(kernel(x), expected) From a2f22cc1dec0c02d5d1f0213af4c731a008775bd Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 24 Mar 2025 10:46:30 -0700 Subject: [PATCH 0119/1769] [Mosaic GPU] Adding a primitive to load from memrefs *with* a specified layout. 
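A minimal usage sketch (mirroring the new tests below; `x_ref` is assumed to
be a 2D SMEM or GMEM reference):

    from jax.experimental.pallas import mosaic_gpu as plgpu

    def kernel(x_ref, o_ref):
      # Load row 0 straight into WGMMA row layout; omitting `layout`
      # keeps the previous strided-load behaviour.
      row = plgpu.load(x_ref, (0,), layout=plgpu.Layout.WGMMA_ROW)
      ...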
PiperOrigin-RevId: 739995908 --- jax/_src/pallas/mosaic_gpu/primitives.py | 96 +++++++++++++++++++++++- jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 66 ++++++++++++++++ 3 files changed, 162 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index edfae55fb288..a27137964349 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -21,7 +21,7 @@ import enum import itertools import math -from typing import Any, Literal +from typing import Any, Literal, Optional import jax from jax._src import core as jax_core @@ -31,6 +31,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import llvm as llvm_dialect +from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -62,6 +63,99 @@ def _check_ref( ) +load_p = jax_core.Primitive("load") + +@load_p.def_effectful_abstract_eval +def _load_abstract_eval(src, *avals_flat, args_tree, layout): + del layout # Unused. + + transforms = args_tree.unflatten(avals_flat) + return ( + jax_core.ShapedArray(transforms[-1].get_indexer_shape(), src.dtype), + {state.ReadEffect(0)}, + ) + +@lowering.register_lowering_rule(load_p, mgpu.ThreadSemantics.Lane) +def _load_p_lowering_rule( + ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout +): + if not isinstance(x_ref, ir.Value) or not ir.MemRefType.isinstance(x_ref.type): + raise TypeError(f"Can only load from references (got {x_ref}).") + + x_aval = ctx.avals_in[0] + + transforms = jax.tree.unflatten(args_tree, leaves) + x_ref, transforms = lowering._handle_reshaping(x_ref, transforms) + x_ref, transforms = lowering._handle_indexing(x_ref, transforms) + + if layout is not None: + layout = layout.to_mgpu() + + match transforms: + case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): + if tiling != (64, swizzle // x_aval.dtype.itemsize): + raise NotImplementedError("Tiling does not fit swizzle") + return mgpu.FragmentedArray.load_tiled( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle, + layout=layout + ) + case (): + # Handle scalar indexing. 
+ if not ctx.avals_out[0].shape: + is_signed = mgpu_utils.is_signed(x_aval.dtype) + val = memref_dialect.load(x_ref, []) + return mgpu.FragmentedArray.splat(val, shape=(), layout=layout, is_signed=is_signed) + match layout: + case mgpu.WGMMARowFragLayout(): + return mgpu.FragmentedArray.load_wgmma_row( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): + ref_ty = ir.MemRefType(x_ref.type) + if shape != tuple(ref_ty.shape): + raise ValueError( + f"Unsupported shape {shape}, (expected {tuple(ref_ty.shape)})" + ) + + return mgpu.FragmentedArray.load_strided( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), vec_size=vec_size, + ) + case None: + return mgpu.FragmentedArray.load_strided( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + case _: + raise NotImplementedError(f"Unsupported layout: {layout}") + case _: + raise NotImplementedError(f"Unsupported transforms: {transforms}") + + +def load( + src: _Ref, idx, *, layout: Optional[Layout | ParameterizedLayout] = None +) -> mgpu.FragmentedArray: + """ Loads a ref (SMEM or GMEM) into a FragmentedArray with the specified layout. + + Args: + src: The reference to copy from. + idx: The index to load from. + layout: The optional layout to use for the returned FragmentedArray. + + Returns: + A FragmentedArray containing the loaded data in the specified layout. + """ + src, src_transforms = state_primitives.get_ref_and_transforms( + src, idx, "load", force_trailing_indexer=True, + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + return load_p.bind( + src, + *flat_src_transforms, + args_tree=src_transforms_treedef, + layout=layout + ) + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index aab58d092190..d5acb9b131ad 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -40,6 +40,7 @@ from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast +from jax._src.pallas.mosaic_gpu.primitives import load as load from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 408b8bdf5713..d31d1c9d41b2 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -634,6 +634,72 @@ def kernel(x_ref, o_ref, barrier_ref): x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0)) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[ + plgpu.Layout.WGMMA_ROW, + plgpu.Layout.WG_STRIDED((128,), vec_size=1), + None, + ], + ) + def test_load_to_layout_with_indexing(self, src_memory_space, layout): + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,), layout=layout) + o_ref[i, ...] 
= x + + in_spec = pl.BlockSpec(memory_space=src_memory_space) + out_spec = plgpu.GPUBlockSpec( + (2, 128), lambda: (0, 0), memory_space=plgpu.SMEM, + ) + f = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=(in_spec,), + out_specs=out_spec, + ) + x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) + np.testing.assert_array_equal(f(x), x) + + @parameterized.product(src_memory_space=[plgpu.SMEM, plgpu.GMEM]) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space): + m, k, n = 64, 128, 192 + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m,), dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + def kernel(x_ref, y_ref, o_ref): + x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA_ROW) + x = lax.broadcast_in_dim(x, (m, k), [0]) + + def compute(acc_ref): + plgpu.wgmma(acc_ref, x, y_ref) + return acc_ref[...] + + out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) + o_ref[...] = out + + out_spec = plgpu.GPUBlockSpec( + (m, n), lambda: (0, 0), memory_space=plgpu.SMEM, + ) + f = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), + in_specs=( + pl.BlockSpec(memory_space=src_memory_space), + plgpu.GPUBlockSpec( + (k, n), + lambda: (0, 0), + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ) + )), + out_specs=out_spec, + ) + + out_ref = jnp.broadcast_to(a[:, None], (m, k)) @ b + np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) + def test_indexing_before_transpose(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): From 92f231e875118cb114e25c8517eb5aed53729066 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 24 Mar 2025 11:31:18 -0700 Subject: [PATCH 0120/1769] Delay the unflattening in `jnp.array` Reverts 53e8eac7134a13c1d28de673e7e3a23b4a837aed PiperOrigin-RevId: 740012608 --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96efc48062e1..16355695792d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,15 +49,16 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.numpy.array_creation import (empty, empty_like, full, + ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize +from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape ) @@ -65,8 +66,7 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax._src.sharding_impls import SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_map +from jax.tree_util import tree_flatten, tree_map import numpy as np export = set_module('jax.numpy') @@ -5504,9 +5504,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = xc._xla.cuda_array_interface_to_buffer( 
cai=cai, gpu_backend=backend, device_id=device_id) - object = tree_map(lambda leaf: leaf.__jax_array__() - if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object, is_leaf=lambda x: x is None) + leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): @@ -5515,7 +5513,13 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, "None encountered in jnp.array(); this is currently treated as NaN. " "In the future this will result in an error.", FutureWarning, stacklevel=2) - leaves = tree_leaves(object) + leaves, treedef = tree_flatten(object) + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. @@ -5530,8 +5534,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if not weak_type: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + object = treedef.unflatten(leaves) out: ArrayLike - if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in From b89cf0de91cdaec2388f6b9b2dc07d17a18d5b99 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Mon, 24 Mar 2025 11:40:41 -0700 Subject: [PATCH 0121/1769] Stop using mesh and `*_specs` in roofline tests. These args are optional, so not specifying them in our tests will make them simpler and easier to read. This change is a no-op. 
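With the optional arguments gone, the call pattern in these tests reduces to
the form below (taken from the diff):

    import jax.numpy as jnp
    from jax.experimental import roofline

    _, result = roofline.roofline(lambda a, b: a @ b)(
        jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int))
    assert result.unfused_flops == 2 * 3 * 7 * 5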
PiperOrigin-RevId: 740015584 --- tests/roofline_test.py | 93 +++++++++++------------------------------- 1 file changed, 24 insertions(+), 69 deletions(-) diff --git a/tests/roofline_test.py b/tests/roofline_test.py index f94f5a328c46..98f6176c22a0 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -18,7 +18,6 @@ from absl.testing import absltest import jax -from jax._src import mesh from jax._src import test_util as jtu from jax.experimental import roofline import jax.lax as lax @@ -465,11 +464,7 @@ def collective_matmul(a, b): ) def test_unary_ops(self, f, dtype): data = jnp.zeros((3, 8), dtype=dtype) - out, result = roofline.roofline( - f, - in_specs=(P()), - out_specs=P(), - )(data) + out, result = roofline.roofline(f)(data) with self.subTest("flops"): self.assertEqual(result.unfused_flops, 3 * 8) with self.subTest("hbm_bytes"): @@ -495,12 +490,9 @@ def test_binary_ops(self): lambda a, b: jnp.minimum(a, b), lambda a, b: jnp.maximum(a, b), ]: - out, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + out, result = roofline.roofline(f)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) self.assertEqual( result.unfused_hbm_bytes, @@ -515,12 +507,7 @@ def test_broadcast(self): (2.0, jnp.ones((3, 8))), (jnp.zeros((3, 8)), 2.0), ]: - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(left, right) + _, result = roofline.roofline(lambda a, b: a + b)(left, right) self.assertEqual(result.unfused_flops, 3 * 8) def test_nested(self): @@ -531,27 +518,21 @@ def g(x): return g(x) + g(y) - _, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int)) + _, result = roofline.roofline(f)( + jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * (11 * 4)) def test_no_mesh(self): - _, result = roofline.roofline( - lambda a, b: a + b, - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_specs(self): - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_mesh_and_no_specs(self): @@ -561,12 +542,9 @@ def test_no_mesh_and_no_specs(self): self.assertEqual(result.unfused_flops, 3 * 8) def test_dot_general(self): - _, result = roofline.roofline( - lambda a, b: a @ b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int)) + _, result = roofline.roofline(lambda a, b: a @ b)( + jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int) + ) self.assertEqual(result.unfused_flops, 2 * 3 * 7 * 5) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5) @@ -631,12 +609,7 @@ def test_conv_general_dilated_unfused_hbm_bytes( feature_group_count=feature_group_count, ) - _, result = 
roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = batch * num_input_channels * iw * ih expected_kernel_size = num_output_channels * num_input_features * kw * kh @@ -677,12 +650,7 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes( lhs=a, rhs=b, window_strides=(1, 1), padding=padding ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = 1 * 1 * 10 * 20 expected_kernel_size = 1 * 1 * 3 * 3 @@ -702,12 +670,7 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): lhs=a, rhs=b, window_strides=(1, 1), padding="VALID" ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = 1 * 1 * 10 * 20 expected_kernel_size = 1 * 1 * 3 * 3 @@ -725,12 +688,7 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) def test_reduce_sum_no_axis(self): - _, result = roofline.roofline( - lambda x: jnp.sum(x), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x))(jnp.zeros((11, 4))) self.assertEqual(result.unfused_flops, 11 * 4 - 1) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (11 * 4 + 1) @@ -743,12 +701,9 @@ def test_reduce_sum_with_axis(self): ([0, 1], 11 * 4 - 1, 11 * 4 + 1), ([], 0, 11 * 4 + 11 * 4), ]: - _, result = roofline.roofline( - lambda x: jnp.sum(x, axis=axis), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x, axis=axis))( + jnp.zeros((11, 4)) + ) self.assertEqual(result.unfused_flops, expected_flops) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * expected_memory From 94846941e30c239e29b1f67ef0652567653b9ed3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 14:56:10 -0400 Subject: [PATCH 0122/1769] Fix mac wheel build. The xla_extension move introduced an incorrect path. --- jaxlib/tools/build_wheel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 9967fc14b9f9..fcc811789c19 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): if not _is_mac(): return nm = subprocess.run( - ["nm", "-g", r.Rlocation("__main/jaxlib/xla/xla_extension.so")], + ["nm", "-g", r.Rlocation("__main__/jaxlib/xla/xla_extension.so")], capture_output=True, text=True, check=False, From 7e42539653d33ec995487b683794c0bc86f7199b Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Mon, 24 Mar 2025 11:55:42 -0700 Subject: [PATCH 0123/1769] Create `_FMA_FLOPS_FACTOR` to be used in roofline `dot` (and later `conv)`. This change is a no-op made for convenience for follow-up changes. 
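The constant just names the usual convention that one fused multiply-add
counts as two flops. For an (m, k) @ (k, n) dot this gives (shapes
hypothetical, matching the existing tests):

    _FMA_FLOPS_FACTOR = 2                  # one FMA = two flops

    m, k, n = 3, 7, 5
    flops = _FMA_FLOPS_FACTOR * m * k * n  # each of the m*n outputs needs k FMAs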
PiperOrigin-RevId: 740020625 --- jax/experimental/roofline/rooflines.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 1edd1e0649b1..bc8d65e966dd 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -36,6 +36,8 @@ from jax.experimental import shard_map +_FMA_FLOPS_FACTOR = 2 + for prim in it.chain( ad_util.__dict__.values(), ann.__dict__.values(), @@ -156,7 +158,7 @@ def _dot_general_roofline( (lhs_contract, _), (lhs_batch, _) = dimension_numbers flops = ( - 2 + _FMA_FLOPS_FACTOR * lhs.size * rhs.size / np.prod([lhs.shape[i] for i in lhs_contract]) From 13862ec10b0e3eaccb090822f103f1ae34b6e5b0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 09:13:49 -0400 Subject: [PATCH 0124/1769] Small cleanup to pretty-printer. Kidger's reimplementation of this code notes that the break mode and indent are unused in the _fits function (https://github.com/patrick-kidger/wadler_lindig/blob/851379b8f55e2bb98ea2c81905863f90f9606f0a/wadler_lindig/_wadler_lindig.py#L166). We can make the same optimization here. --- jax/_src/pretty_printer.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index e8fdff497445..d02b6d9962e0 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -201,26 +201,20 @@ def __init__(self, child: Doc, *, foreground: Color | None = None, # non-recursive formulation using an explicit stack, necessary because Python # doesn't have a tail recursion optimization. -def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]] - ) -> bool: +def _fits(doc: Doc, width: int) -> bool: + agenda = [doc] while width >= 0 and len(agenda) > 0: - i, m, doc = agenda.pop() + doc = agenda.pop() if isinstance(doc, _NilDoc): pass elif isinstance(doc, _TextDoc): width -= len(doc.text) elif isinstance(doc, _ConcatDoc): - agenda.extend((i, m, d) for d in reversed(doc.children)) + agenda.extend(reversed(doc.children)) elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - return True width -= len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append((i + doc.n, m, doc.child)) - elif isinstance(doc, _GroupDoc): - agenda.append((i, _BreakMode.FLAT, doc.child)) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append((i, m, doc.child)) + elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)): + agenda.append(doc.child) else: raise ValueError("Invalid document ", doc) @@ -372,8 +366,7 @@ def _format( elif isinstance(doc, _GroupDoc): # In Lindig's paper, _fits is passed the remainder of the document. # I'm pretty sure that's a bug and we care only if the current group fits! 
- if (_sparse(doc) - and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): + if (_sparse(doc) and _fits(doc, width - k)): agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) else: agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) From 7e235e3aee527d3a4c6f6cc0b633c175303e5c46 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 24 Mar 2025 12:43:28 -0700 Subject: [PATCH 0125/1769] jax.test_util: improve type annotations --- jax/_src/public_test_util.py | 13 ++++++++----- jax/_src/test_util.py | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 455a3b98cce2..59ddb73dc9e1 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -14,6 +14,7 @@ from functools import partial import operator +from typing import Any, TypeAlias from jax._src import api from jax._src import config @@ -32,7 +33,7 @@ EPS = 1e-4 -def _dtype(x): +def _dtype(x: Any) -> np.dtype: if hasattr(x, 'dtype'): return x.dtype elif type(x) in _dtypes.python_scalar_dtypes: @@ -40,8 +41,9 @@ def _dtype(x): else: return np.asarray(x).dtype +ToleranceDict: TypeAlias = dict[np.dtype, int | float] -_default_tolerance = { +_default_tolerance: ToleranceDict = { _dtypes.float0: 0, np.dtype(np.bool_): 0, np.dtype(_dtypes.int4): 0, @@ -76,7 +78,7 @@ def default_tolerance(): return _default_tolerance -default_gradient_tolerance = { +default_gradient_tolerance: ToleranceDict = { np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -104,7 +106,7 @@ def default_tolerance(): _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 -def is_python_scalar(val): +def is_python_scalar(val: Any) -> bool: return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): @@ -151,7 +153,8 @@ def maybe_upcast(x): # value errors. It should not do that. 
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg) -def tolerance(dtype, tol=None): + +def tolerance(dtype: np.dtype, tol: int | float | ToleranceDict | None = None) -> int | float: tol = {} if tol is None else tol if not isinstance(tol, dict): return tol diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 3a18d12e9d4b..c3c4a934dd0e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -51,6 +51,7 @@ from jax._src import lib as _jaxlib from jax._src import monitoring from jax._src import test_warning_util +from jax._src.typing import ArrayLike, DTypeLike from jax._src import xla_bridge from jax._src import util from jax._src import mesh as mesh_lib @@ -59,7 +60,7 @@ from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, - check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance) + check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict) from jax._src.util import unzip2 from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten import numpy as np @@ -131,10 +132,10 @@ def sanitize_test_name(s: str) -> str: return kSanitizeNameRE.sub("_", s) -def num_float_bits(dtype): +def num_float_bits(dtype: DTypeLike) -> int: return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits -def to_default_dtype(arr): +def to_default_dtype(arr: ArrayLike) -> np.ndarray: """Convert a value to an array with JAX's default dtype. This is generally used for type conversions of values returned by numpy functions, @@ -145,7 +146,7 @@ def to_default_dtype(arr): dtype = _dtypes._default_types.get(arr.dtype.kind) return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr -def with_jax_dtype_defaults(func, use_defaults=True): +def with_jax_dtype_defaults(func: Callable[..., Any], use_defaults: bool = True): """Return a version of a function with outputs that match JAX's default dtypes. This is generally used to wrap numpy functions within tests, in order to make @@ -168,7 +169,7 @@ def wrapped(*args, **kwargs): return tree_map(f, result, use_defaults) return wrapped -def is_sequence(x): +def is_sequence(x: Any) -> bool: try: iter(x) except TypeError: @@ -176,14 +177,16 @@ def is_sequence(x): else: return True -def _normalize_tolerance(tol): +def _normalize_tolerance(tol: int | float | ToleranceDict | None) -> ToleranceDict: tol = tol or 0 if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: return dict.fromkeys(_default_tolerance, tol) -def join_tolerance(tol1, tol2): +def join_tolerance( + tol1: int | float | ToleranceDict | None, + tol2: int | float | ToleranceDict | None) -> ToleranceDict: tol1 = _normalize_tolerance(tol1) tol2 = _normalize_tolerance(tol2) out = tol1 @@ -192,7 +195,7 @@ def join_tolerance(tol1, tol2): return out -def check_eq(xs, ys, err_msg=''): +def check_eq(xs: Any, ys: Any, err_msg: str = '') -> None: assert_close = partial(_assert_numpy_allclose, err_msg=err_msg) tree_all(tree_map(assert_close, xs, ys)) From f5a4d1a85c41a42ed8fb389259a241513970ff9a Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 12:47:34 -0700 Subject: [PATCH 0126/1769] Enable `jax` wheel testing via Bazel. Remove jax dependencies from the Bazel test targets for `:build_jaxlib=false` and `:build_jaxlib=wheel`. `internal_test_util` is removed from the `jax` wheel. 
To use this package in Bazel py_test, we need to copy it to the unpacked wheel folder. This is done by adding `wheel_deps` value to `py_import` Jax targets. This change concludes ML Wheels design implementation in JAX and enables testing of all wheels via Bazel command. PiperOrigin-RevId: 740037952 --- BUILD.bazel | 44 ++++++++++++++++++++++++++ WORKSPACE | 1 + jax/BUILD | 41 ++++++++++++++++--------- jaxlib/jax.bzl | 83 +++++++++++++++++++++++++++++++++----------------- 4 files changed, 126 insertions(+), 43 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index eb43d7ec0fd8..5700fcef2e77 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", "jax_source_package", "jax_wheel", + "py_deps", ) collect_data_files( @@ -98,3 +103,42 @@ jax_source_package( source_package_binary = ":build_wheel", source_package_name = "jax", ) + +genrule( + name = "internal_test_util_sources", + srcs = [ + "//jax:internal_export_back_compat_test_util", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", + "//jax:internal_export_back_compat_test_data", + ], + outs = ["internal_test_util_sources.zip"], + cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", + tools = ["@bazel_tools//tools/zip:zipper"], +) + +COMMON_DEPS = py_deps([ + "absl/testing", + "numpy", + "ml_dtypes", + "scipy", + "opt_einsum", + "hypothesis", + "cloudpickle", +]) + +py_import( + name = "jax_py_import", + wheel = ":jax_wheel", + wheel_deps = [":internal_test_util_sources"], + deps = COMMON_DEPS, +) + +# This target is used to add internal test util sources to the jax wheel. +# This is needed for the tests that depend on jax and use internal test util sources. 
+py_import( + name = "jax_wheel_with_internal_test_util", + wheel = "@pypi_jax//:whl", + wheel_deps = [":internal_test_util_sources"], + deps = COMMON_DEPS, +) diff --git a/WORKSPACE b/WORKSPACE index 129488281ea9..a6968446a1ec 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,6 +16,7 @@ python_init_repositories( "3.13-ft": "//build:requirements_lock_3_13_ft.txt", }, local_wheel_inclusion_list = [ + "jax-*", "jaxlib*", "jax_cuda*", "jax-cuda*", diff --git a/jax/BUILD b/jax/BUILD index 12eae4afdcf7..5d37a8987445 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -167,22 +167,30 @@ py_library( ], ), visibility = [":internal"], - deps = [ - ":jax", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( name = "internal_test_harnesses", srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, - deps = [ - ":ad_util", - ":config", - ":jax", - ":test_util", - "//jax/_src/lib", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":ad_util", + ":config", + ":jax", + ":test_util", + "//jax/_src/lib", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( @@ -191,15 +199,18 @@ py_library( visibility = [ ":internal", ] + jax_internal_export_back_compat_test_util_visibility, - deps = [ - ":jax", - ":test_util", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ":test_util", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( name = "internal_export_back_compat_test_data", - testonly = 1, srcs = glob([ "_src/internal_test_util/export_back_compat_test_data/*.py", "_src/internal_test_util/export_back_compat_test_data/pallas/*.py", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index c6f55a86143f..9b8c861404c2 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -31,7 +31,6 @@ load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_c cc_proto_library = _cc_proto_library cuda_library = _cuda_library rocm_library = _rocm_library -pytype_test = native.py_test nanobind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured @@ -64,6 +63,18 @@ PLATFORM_TAGS_DICT = { ("Windows", "AMD64"): ("win", "amd64"), } +_GPU_PYPI_WHEEL_DEPS = [ + "//:jax_wheel_with_internal_test_util", + "@pypi_jaxlib//:pkg", + "@pypi_jax_cuda12_plugin//:pkg", + "@pypi_jax_cuda12_pjrt//:pkg", +] + +_CPU_PYPI_WHEEL_DEPS = [ + "//:jax_wheel_with_internal_test_util", + "@pypi_jaxlib//:pkg", +] + # TODO(vam): remove this once zstandard builds against Python 3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION == "3.13" or HERMETIC_PYTHON_VERSION == "3.13-ft": @@ -223,39 +234,50 @@ ALL_BACKENDS = ["cpu", "gpu", "tpu"] def if_building_jaxlib( if_building, - if_not_building = [ - "@pypi_jaxlib//:pkg", - "@pypi_jax_cuda12_plugin//:pkg", - "@pypi_jax_cuda12_pjrt//:pkg", - ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], - if_py_import = [ - "//jaxlib/tools:jaxlib_py_import", - "//jaxlib/tools:jax_cuda_plugin_py_import", - "//jaxlib/tools:jax_cuda_pjrt_py_import", - ], - if_py_import_for_cpu = [ - "//jaxlib/tools:jaxlib_py_import", - ]): + if_not_building = _GPU_PYPI_WHEEL_DEPS, + if_not_building_for_cpu = _CPU_PYPI_WHEEL_DEPS): """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. 
This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. Args: if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels - if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of + if_not_building: the wheels to depend on including gpu-specific plugins in case of gpu-enabled builds - if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds - if_py_import: the py_import targets to depend on in case of gpu-enabled builds - if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds + if_not_building_for_cpu: the wheels to depend on in case of cpu-only builds """ return select({ "//jax:enable_jaxlib_build": if_building, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, - "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, - "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, + "//conditions:default": [], + }) + +def _get_test_deps(deps): + jaxlib_build_deps = [ + "//jaxlib/cuda:gpu_only_test_deps", + "//jaxlib/rocm:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ] + + gpu_py_imports = [ + "//:jax_py_import", + "//jaxlib/tools:jaxlib_py_import", + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ] + cpu_py_imports = [ + "//:jax_py_import", + "//jaxlib/tools:jaxlib_py_import", + ] + + return select({ + "//jax:enable_jaxlib_build": jaxlib_build_deps + deps, + "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": _CPU_PYPI_WHEEL_DEPS, + "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": _GPU_PYPI_WHEEL_DEPS, + "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports, + "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_imports, }) # buildifier: disable=function-docstring @@ -308,14 +330,10 @@ def jax_multiplatform_test( srcs = srcs, args = test_args, env = env, - deps = [ + deps = _get_test_deps([ "//jax", "//jax:test_util", - ] + deps + if_building_jaxlib([ - "//jaxlib/cuda:gpu_only_test_deps", - "//jaxlib/rocm:gpu_only_test_deps", - "//jax_plugins:gpu_plugin_only_test_deps", - ]), + ] + deps), data = data, shard_count = test_shards, tags = test_tags, @@ -609,7 +627,16 @@ def jax_py_test( env = dict(env) if "PYTHONWARNINGS" not in env: env["PYTHONWARNINGS"] = "error" - py_test(name = name, env = env, **kwargs) + deps = kwargs.get("deps", []) + kwargs.pop("deps") + test_deps = _get_test_deps(deps) + py_test(name = name, env = env, deps = test_deps, **kwargs) + +def pytype_test(name, **kwargs): + deps = kwargs.get("deps", []) + kwargs.pop("deps") + test_deps = _get_test_deps(deps) + native.py_test(name = name, deps = test_deps, **kwargs) def if_oss(oss_value, google_value = []): """Returns one of the arguments based on the non-configurable build env. From 13b6e01acf84f9ee3d314e77de819960c52e3faa Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 24 Mar 2025 13:00:57 -0700 Subject: [PATCH 0127/1769] Increased tolerance in failing xla client tests. 
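For readers tracking the numeric change: the affected assertions are float32
matmul comparisons, and the relative tolerance is doubled, e.g. (from the
diff below):

    self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6)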
PiperOrigin-RevId: 740041921 --- jaxlib/xla/xla_client_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index 5a2f3881f510..7de905d9ec41 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -1420,7 +1420,7 @@ def testDotGeneral(self): (([2], [1]), ([0], [0]))) ops.DotGeneral( ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) def testDotGeneralWithDotDimensionNumbersProto(self): c = self._NewComputation() @@ -1436,7 +1436,7 @@ def testDotGeneralWithDotDimensionNumbersProto(self): ops.DotGeneral( ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) def testDotGeneralWithPrecisionConfig(self): c = self._NewComputation() @@ -1453,7 +1453,7 @@ def testDotGeneralWithPrecisionConfig(self): ops.Constant(c, rhs), dimension_numbers, precision_config=config) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) def testConvGeneralDilatedF32(self): c = self._NewComputation() From 9f3eb3e232bdf9355f4cd02cf91592da9b065850 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 13:06:22 -0700 Subject: [PATCH 0128/1769] Migrate more modules of xla/python to jax. PiperOrigin-RevId: 740043785 --- .../mlir/_mlir_libs/register_jax_dialects.cc | 8 +- jaxlib/xla/BUILD | 409 +++- jaxlib/xla/config.cc | 343 ++++ jaxlib/xla/config.h | 34 + jaxlib/xla/custom_call_sharding.cc | 343 ++++ jaxlib/xla/custom_call_sharding.h | 28 + jaxlib/xla/dlpack.cc | 699 +++++++ jaxlib/xla/dlpack.h | 57 + jaxlib/xla/jax_jit.cc | 495 +++++ jaxlib/xla/jax_jit.h | 265 +++ jaxlib/xla/mlir.cc | 251 +++ jaxlib/xla/mlir.h | 28 + jaxlib/xla/pjit.cc | 1402 ++++++++++++++ jaxlib/xla/pjit.h | 27 + jaxlib/xla/pmap_lib.cc | 1180 ++++++++++++ jaxlib/xla/pmap_lib.h | 37 + jaxlib/xla/sdy.cc | 143 ++ jaxlib/xla/sdy.h | 28 + jaxlib/xla/weakref_lru_cache.cc | 400 ++++ jaxlib/xla/weakref_lru_cache.h | 28 + jaxlib/xla/xla.cc | 20 +- jaxlib/xla/xla_compiler.cc | 1639 +++++++++++++++++ jaxlib/xla/xla_compiler.h | 28 + 23 files changed, 7868 insertions(+), 24 deletions(-) create mode 100644 jaxlib/xla/config.cc create mode 100644 jaxlib/xla/config.h create mode 100644 jaxlib/xla/custom_call_sharding.cc create mode 100644 jaxlib/xla/custom_call_sharding.h create mode 100644 jaxlib/xla/dlpack.cc create mode 100644 jaxlib/xla/dlpack.h create mode 100644 jaxlib/xla/jax_jit.cc create mode 100644 jaxlib/xla/jax_jit.h create mode 100644 jaxlib/xla/mlir.cc create mode 100644 jaxlib/xla/mlir.h create mode 100644 jaxlib/xla/pjit.cc create mode 100644 jaxlib/xla/pjit.h create mode 100644 jaxlib/xla/pmap_lib.cc create mode 100644 jaxlib/xla/pmap_lib.h create mode 100644 jaxlib/xla/sdy.cc create mode 100644 jaxlib/xla/sdy.h create mode 100644 jaxlib/xla/weakref_lru_cache.cc create mode 100644 jaxlib/xla/weakref_lru_cache.h create mode 100644 jaxlib/xla/xla_compiler.cc create mode 100644 jaxlib/xla/xla_compiler.h diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 64f84965b8e2..1ba6fd9375df 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc 
+++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -2,7 +2,6 @@ // This module is called by mlir/__init__.py during initialization. #include -#include "shardy/integrations/c/passes.h" #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" #include "mlir-c/Dialect/GPU.h" @@ -15,13 +14,14 @@ #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" namespace nb = nanobind; -#define REGISTER_DIALECT(name) \ - MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ - mlirDialectHandleInsertDialect(name##_dialect, registry) +#define REGISTER_DIALECT(name) \ + MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ + mlirDialectHandleInsertDialect(name##_dialect, registry) NB_MODULE(register_jax_dialects, m) { m.doc() = "Registers upstream MLIR dialects used by JAX."; diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 3239ba703937..592d9d1c24f3 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -43,6 +43,16 @@ nanobind_extension( pytype_srcs = glob(["xla_extension/*.pyi"]), visibility = ["//visibility:public"], deps = [ + ":config", + ":custom_call_sharding", + ":dlpack", + ":jax_jit", + ":mlir", + ":pjit", + ":pmap_lib", + ":sdy", + ":weakref_lru_cache", + ":xla_compiler", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", @@ -83,31 +93,21 @@ nanobind_extension( "@xla//xla/pjrt/distributed:service", "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", - "@xla//xla/python:config", - "@xla//xla/python:custom_call_sharding", - "@xla//xla/python:dlpack", "@xla//xla/python:guard_lib", - "@xla//xla/python:jax_jit", "@xla//xla/python:logging", - "@xla//xla/python:mlir", "@xla//xla/python:nb_absl_flat_hash_map", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:ops", - "@xla//xla/python:pjit", - "@xla//xla/python:pmap_lib", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:profiler", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:pytree", "@xla//xla/python:refine_polymorphic_shapes", - "@xla//xla/python:sdy", "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python:util", - "@xla//xla/python:weakref_lru_cache", - "@xla//xla/python:xla_compiler", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", @@ -144,6 +144,395 @@ nanobind_extension( }), ) +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/python:python_ref_manager", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "custom_call_sharding", + srcs = ["custom_call_sharding.cc"], + hdrs = ["custom_call_sharding.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + 
"@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/utils:hlo_sharding_util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:custom_call_batch_partitioner", + "@xla//xla/python:custom_partition_callback", + "@xla//xla/python:inspect_sharding", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "dlpack", + srcs = ["dlpack.cc"], + hdrs = ["dlpack.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@dlpack", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python:util", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # build_cleaner: keep + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_inlined_vector", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "mlir", + srcs = ["mlir.cc"], + hdrs = ["mlir.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:Support", + 
"@nanobind", + "@stablehlo//:stablehlo_serialization", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/mlir_hlo:mhlo_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/service/llvm_ir:llvm_util", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "pjit", + srcs = ["pjit.cc"], + hdrs = ["pjit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":jax_jit", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/python:guard_lib", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:traceback", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "pmap_lib", + srcs = ["pmap_lib.cc"], + hdrs = ["pmap_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":jax_jit", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "sdy", + srcs = ["sdy.cc"], + hdrs = ["sdy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", 
+ "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/dialect/sdy/transforms/import:passes", + "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@xla//xla/mlir_hlo:all_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/service/spmd/shardy:constants", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", + "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + ], +) + +cc_library( + name = "weakref_lru_cache", + srcs = ["weakref_lru_cache.cc"], + hdrs = ["weakref_lru_cache.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "xla_compiler", + srcs = ["xla_compiler.cc"], + hdrs = ["xla_compiler.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:array", + "@xla//xla:debug_options_flags", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla:xla_proto_cc", + "@xla//xla/client:executable_build_options", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/ir:hlo_module_group", + "@xla//xla/hlo/parser:hlo_parser", + "@xla//xla/hlo/pass:hlo_pass", + "@xla//xla/hlo/transforms/simplifiers:flatten_call_graph", + "@xla//xla/hlo/transforms/simplifiers:hlo_dce", + "@xla//xla/hlo/transforms/simplifiers:tuple_simplifier", + "@xla//xla/pjrt:compile_options_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:types", + "@xla//xla/service:call_inliner", + "@xla//xla/service:computation_placer", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:hlo_graph_dumper", + "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/service:name_uniquer", + "@xla//xla/tsl/lib/strings:proto_serialization", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + pytype_strict_library( name = "xla_client", srcs = ["xla_client.py"], diff --git a/jaxlib/xla/config.cc 
b/jaxlib/xla/config.cc new file mode 100644 index 000000000000..b5bc5830acbf --- /dev/null +++ b/jaxlib/xla/config.cc @@ -0,0 +1,343 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/config.h" + +#include + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/python/python_ref_manager.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +// Singleton object used to represent "value not set" in thread-local configs. +nb::object UnsetObject() { + return nb::steal(PyObject_CallObject( + reinterpret_cast<PyObject*>(&PyBaseObject_Type), nullptr)); +} + +// Each configuration object has: +// * a global value, and +// * a thread-local value. +// When querying the state of a config, the thread-local value is used if it is +// set. Otherwise, the global value is used. + +// This class represents all of the thread-local configuration state for a +// thread. +class ThreadLocalConfigState { + public: + ThreadLocalConfigState(); + ~ThreadLocalConfigState(); + + static ThreadLocalConfigState& Instance() { + thread_local auto state = std::make_unique<ThreadLocalConfigState>(); + return *state; + } + + nb::object Get(int key) { + DCHECK_GE(key, 0); + return key >= entries_.size() ? nb::object() : entries_[key]; + } + + void Set(int key, nb::object value); + + private: + friend class GlobalConfigState; + + // These values are accessed in one of two ways: + // * The owning thread reads or writes them, while holding the GIL, or, under + // free-threading, while the owning thread is in ATTACHED gc state. + // * Other threads may read or clear values while performing a garbage + // collection. + // No locking is needed because a GC thread cannot run concurrently with other + // Python threads; even under free-threading Python uses a stop-the-world GC. + std::vector<nb::object> entries_; +}; + +// This class represents all of the global configuration state. +// TODO(phawkins): to support free-threading, we will need to add locking to +// this class. +class GlobalConfigState { + public: + static GlobalConfigState& Instance() { + static auto state = new GlobalConfigState(); + return *state; + } + + nb::object Get(int key) const; + void Set(int key, nb::object value); + + // Adds or removes a thread-local state from the set of thread-local states. + void AddThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.insert(state); + } + void RemoveThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.erase(state); + } + + // Python GC helpers. These are called from the tp_traverse and tp_clear + // methods of the Config class.
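+ // For GC purposes, the Python-visible Config object is treated as owning + // both the global value and every thread-local value stored under its key, + // so tp_traverse must visit all of them and tp_clear must drop all of them.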
+ int tp_traverse(int key, PyObject* self, visitproc visit, void* arg); + int tp_clear(int key, PyObject* self); + + // Returns the singleton object representing "value not set". + const nb::object& unset() const { return unset_; } + + // Returns the set of keys that should be included in the jit key. + absl::Span<const int> include_in_jit_key() const { + return include_in_jit_key_; + } + + private: + friend class Config; + + // The set of thread-local states. This is used during garbage collection to + // visit thread-local values. + absl::Mutex mu_; + absl::flat_hash_set<ThreadLocalConfigState*> thread_local_states_ + ABSL_GUARDED_BY(mu_); + std::vector<nb::object> entries_; + std::vector<int> include_in_jit_key_; + nb::object unset_ = UnsetObject(); +}; + +ThreadLocalConfigState::ThreadLocalConfigState() { + GlobalConfigState::Instance().AddThreadLocalState(this); +} + +ThreadLocalConfigState::~ThreadLocalConfigState() { + // It's important that we remove the thread-local state before we access + // entries_. This ensures that accesses to entries_ are ordered with respect + // to any garbage collection. + GlobalConfigState::Instance().RemoveThreadLocalState(this); + // We do not hold the GIL, so we must use deferred destruction. + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(entries_)); +} + +void ThreadLocalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + if (key >= entries_.size()) { + entries_.resize(key + 1); + } + std::swap(entries_[key], value); +} + +nb::object GlobalConfigState::Get(int key) const { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + return entries_[key]; +} + +void GlobalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + std::swap(entries_[key], value); +} + +int GlobalConfigState::tp_traverse(int key, PyObject* self, visitproc visit, + void* arg) { + DCHECK_GE(key, 0); + if (key < entries_.size()) { + PyObject* value = entries_[key].ptr(); + Py_VISIT(value); + } + absl::MutexLock lock(&mu_); + for (const auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + PyObject* value = state->entries_[key].ptr(); + Py_VISIT(value); + } + } + return 0; +} + +int GlobalConfigState::tp_clear(int key, PyObject* self) { + if (key < entries_.size()) { + nb::object tmp; + std::swap(entries_[key], tmp); + } + // We destroy the Python objects outside of the lock out of an abundance of + // caution. + std::vector<nb::object> to_destroy; + absl::MutexLock lock(&mu_); + to_destroy.reserve(thread_local_states_.size()); + for (auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + nb::object tmp; + std::swap(state->entries_[key], tmp); + to_destroy.push_back(std::move(tmp)); + } + } + return 0; +} + +// A Config object represents a configurable object with both global and +// thread-local state. This class is wrapped using nanobind and exposed to +// Python. +class Config { + public: + Config(nb::object value, bool include_in_jit_key); + + // Returns the thread-local value if it is set, otherwise the global value. + nb::object Get(); + + // Returns the global value. + nb::object GetGlobal(); + + // Sets the global value. + void SetGlobal(nb::object value); + + // Returns the thread-local value. + nb::object GetLocal(); + + // Sets the thread-local value. May be `unset`. + void SetLocal(nb::object value); + + // Swaps the thread-local value with `value`. Returns the previous value. + // Either may be `unset`.
+ nb::object SwapLocal(nb::object value); + + // This class doesn't actually hold any data, but it's the only type + // known to Python. We pretend that this object owns both the global and any + // thread-local values corresponding to this key. + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + private: + int key_; +}; + +Config::Config(nb::object value, bool include_in_jit_key) { + auto& instance = GlobalConfigState::Instance(); + key_ = instance.entries_.size(); + instance.entries_.push_back(std::move(value)); + if (include_in_jit_key) { + instance.include_in_jit_key_.push_back(key_); + } +} + +nb::object Config::GetLocal() { + nb::object result = ThreadLocalConfigState::Instance().Get(key_); + if (!result.is_valid()) { + return GlobalConfigState::Instance().unset(); + } + return result; +} + +nb::object Config::GetGlobal() { + return GlobalConfigState::Instance().Get(key_); +} + +nb::object Config::Get() { + nb::object local = ThreadLocalConfigState::Instance().Get(key_); + if (local.is_valid()) { + return local; + } + return GetGlobal(); +} + +void Config::SetLocal(nb::object value) { + const auto& instance = GlobalConfigState::Instance(); + if (value.ptr() == instance.unset().ptr()) { + value = nb::object(); + } + ThreadLocalConfigState::Instance().Set(key_, std::move(value)); +} + +nb::object Config::SwapLocal(nb::object value) { + const auto& global_instance = GlobalConfigState::Instance(); + auto& instance = ThreadLocalConfigState::Instance(); + auto result = instance.Get(key_); + if (value.ptr() == global_instance.unset().ptr()) { + value = nb::object(); + } + instance.Set(key_, std::move(value)); + if (!result.is_valid()) { + return global_instance.unset(); + } + return result; +} + +void Config::SetGlobal(nb::object value) { + GlobalConfigState::Instance().Set(key_, value); +} + +/* static */ int Config::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Config* c = nb::inst_ptr(self); + // For the purposes of GC, we pretend that this object owns both the global + // and any thread-local values corresponding to this key. 
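+ // Without this fiction, a reference cycle routed through a config value + // (for example, a hypothetical value that refers back to the Config object + // itself) would be invisible to Python's cycle collector and never freed.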
+ return GlobalConfigState::Instance().tp_traverse(c->key_, self, visit, arg); +} + +/* static */ int Config::tp_clear(PyObject* self) { + Config* c = nb::inst_ptr(self); + return GlobalConfigState::Instance().tp_clear(c->key_, self); +} + +PyType_Slot Config::slots_[] = { + {Py_tp_traverse, reinterpret_cast(Config::tp_traverse)}, + {Py_tp_clear, reinterpret_cast(Config::tp_clear)}, + {0, nullptr}, +}; + +void BuildConfigSubmodule(nanobind::module_& m) { + nb::module_ config_module = m.def_submodule("config", "Config library"); + + config_module.attr("unset") = GlobalConfigState::Instance().unset(); + + nb::class_ config(config_module, "Config", + nb::type_slots(Config::slots_), nb::is_generic()); + config.def(nb::init(), nb::arg("value").none(), + nb::arg("include_in_jit_key") = false); + config.def_prop_ro("value", &Config::Get); + config.def("get_local", &Config::GetLocal); + config.def("get_global", &Config::GetGlobal); + config.def("set_local", &Config::SetLocal, nb::arg("value").none()); + config.def("swap_local", &Config::SwapLocal, nb::arg("value").none()); + config.def("set_global", &Config::SetGlobal, nb::arg("value").none()); +} + +std::vector JitConfigs() { + auto& instance = GlobalConfigState::Instance(); + auto& thread_local_instance = ThreadLocalConfigState::Instance(); + std::vector result; + result.reserve(instance.include_in_jit_key().size()); + for (int i : instance.include_in_jit_key()) { + nb::object local = thread_local_instance.Get(i); + if (local.is_valid()) { + result.push_back(std::move(local)); + } else { + result.push_back(instance.Get(i)); + } + } + return result; +} + +} // namespace jax diff --git a/jaxlib/xla/config.h b/jaxlib/xla/config.h new file mode 100644 index 000000000000..40847bf4a370 --- /dev/null +++ b/jaxlib/xla/config.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +// Returns the set of configuration values that should be included in the JIT +// cache key. +std::vector JitConfigs(); + +void BuildConfigSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/xla/custom_call_sharding.cc new file mode 100644 index 000000000000..f88bc93e3af3 --- /dev/null +++ b/jaxlib/xla/custom_call_sharding.cc @@ -0,0 +1,343 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/xla/custom_call_sharding.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/custom_call_batch_partitioner.h" +#include "xla/python/custom_partition_callback.h" +#include "xla/python/inspect_sharding.h" +#include "xla/shape.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = ::nanobind; + +class PyCustomCallPartitionerCallbacks { + public: + PyCustomCallPartitionerCallbacks(nb::object prop_user_sharding, + nb::object partition, + nb::object infer_sharding_from_operands) + : prop_user_sharding_(prop_user_sharding), + partition_(partition), + infer_sharding_from_operands_(infer_sharding_from_operands) { + callbacks_.version = 0; + callbacks_.private_data = this; + callbacks_.dtor = +[](JAX_CustomCallPartitioner_Callbacks* self) { + delete GetSelfPtr(self); + }; + callbacks_.partition = +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_Partition_Args* args) { + jax::PopulateResults(GetSelfPtr(self)->CallPartition(args), args); + }; + callbacks_.infer_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallInferShardingFromOperands(args), args); + }; + callbacks_.propagate_user_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallPropagateUserSharding(args), args); + }; + } + + absl::StatusOr< + std::tuple, xla::HloSharding>> + CallPartition(JAX_CustomCallPartitioner_Partition_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector shapes = std::move(std::get<0>(args_tuple)); + std::vector> shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::optional result_sharding = + std::move(std::get<3>(args_tuple)); + absl::string_view backend_config = std::move(std::get<4>(args_tuple)); + + { + nb::gil_scoped_acquire gil; + try { + auto py_result = + partition_(shapes, shardings, result_shape, result_sharding, + 
nb::bytes(backend_config.data(), backend_config.size())); + try { + auto [ir, arg_shardings, result_sharding] = nb::cast< + std::tuple, HloSharding>>( + py_result); + if (arg_shardings.size() != args->num_args) { + return xla::Internal( + "Shardings returned from partitioning: lengths must match: %d " + "vs %d", + arg_shardings.size(), args->num_args); + } + return std::make_tuple(std::string(ir.c_str(), ir.size()), + std::move(arg_shardings), + std::move(result_sharding)); + } catch (const nb::cast_error& e) { + return xla::Internal( + "Shardings returned from partitioning: expected " + "Tuple[bytes, List[HloSharding], HloSharding] got: %s", + nb::cast(nb::repr(py_result))); + } + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + } + + absl::StatusOr> CallInferShardingFromOperands( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector arg_shapes = std::move(std::get<0>(args_tuple)); + std::vector> arg_shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + absl::string_view backend_config = std::move(std::get<3>(args_tuple)); + + std::optional result; + nb::gil_scoped_acquire gil; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, result_shape, + nb::bytes(backend_config.data(), backend_config.size())); + if (py_result.is_none()) { + return std::nullopt; + } + return nb::cast(py_result); + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + absl::StatusOr CallPropagateUserSharding( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); + xla::Shape result_shape = std::move(std::get<1>(args_tuple)); + absl::string_view backend_config = std::move(std::get<2>(args_tuple)); + + nb::gil_scoped_acquire gil; + try { + // TODO(parkers): expand this API to handle the `user` sharding. + // The user is used when the custom call returns a Tuple and + // the user is a get-tuple-element. In this case we must update only + // part of the sharding spec. 
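+ // The Python callback is invoked as prop_user_sharding(sharding, shape, + // backend_config) and must return something convertible to HloSharding; + // the nb::cast below enforces that.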
+ auto result = nb::cast(prop_user_sharding_( + result_sharding, result_shape, + nb::bytes(backend_config.data(), backend_config.size()))); + return result; + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + JAX_CustomCallPartitioner_Callbacks* callbacks() { return &callbacks_; } + + private: + static PyCustomCallPartitionerCallbacks* GetSelfPtr( + JAX_CustomCallPartitioner_Callbacks* callbacks) { + return reinterpret_cast( + callbacks->private_data); + } + + JAX_CustomCallPartitioner_Callbacks callbacks_; + nb::object prop_user_sharding_; + nb::object partition_; + nb::object infer_sharding_from_operands_; +}; + +namespace { + +void CallInspectSharding(void* obj, JAX_InspectSharding_Callback_Args* args) { + std::optional arg = jax::InspectShardingReadArgs(args); + if (!arg.has_value()) { + return; + } + try { + nb::gil_scoped_acquire gil; + nb::handle(reinterpret_cast(obj))(*std::move(arg)); + } catch (const nb::python_error& e) { + jax::InspectShardingSetError(args, std::string(e.what())); + } +} + +} // namespace + +void BuildCustomCallShardingPybindAPI(nb::module_& m) { + m.def( + "register_custom_call_partitioner", + [](std::string name, nb::object prop_user_sharding, nb::object partition, + nb::object infer_sharding_from_operands, + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { + auto* c_fns = + (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, + infer_sharding_from_operands)) + ->callbacks(); + c_fns->can_side_effecting_have_replicated_sharding = + can_side_effecting_have_replicated_sharding; + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error* error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a partitioner for a custom-call operation. + +Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. + Takes result sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns + a generated HLO and sharding specs. The spmd lowerer first reshards + to match the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. 
This is safe to + call on plugins that do not implement the custom partitioner extension +)", + nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), + nb::arg("infer_sharding_from_operands"), + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); + m.def("encode_inspect_sharding_callback", + [](nb::object handler) -> nb::bytes { + JAX_InspectSharding_Callback cb; + cb.call = &CallInspectSharding; + cb.data = handler.ptr(); + char bytes[sizeof(JAX_InspectSharding_Callback)]; + std::memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); + return nb::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); + }); + + nb::module_ hlo_sharding_util_m = m.def_submodule( + "hlo_sharding_util", "Utilities for manipulating HloSharding."); + hlo_sharding_util_m.def( + "PartiallyReplicateTiledShardingOnDims", + [](const HloSharding& sharding, std::vector dims) { + return hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, dims); + }); + + m.def( + "register_custom_call_as_batch_partitionable", + [](std::string target_name, std::optional c_api) { + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + target_name, std::make_unique()); + return; + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Batch_Partitionable_Args args; + args.struct_size = PJRT_Register_Batch_Partitionable_Args_STRUCT_SIZE; + args.name = target_name.c_str(); + args.name_size = target_name.size(); + PJRT_Error* error = extension->register_batch_partitionable(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a custom call as batch partitionable. + +If a custom call is "batch partitionable", it means that it can be trivially +partitioned on some number of (leading) dimensions, with the same call being +executed independently on each shard of data. If the data are sharded on +non-batch dimensions, partitioning will re-shard the data to be replicated on +the non-batch dimensions. + +Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. +)", + nb::arg("target_name"), nb::arg("c_api").none() = std::nullopt); +} + +} // namespace xla diff --git a/jaxlib/xla/custom_call_sharding.h b/jaxlib/xla/custom_call_sharding.h new file mode 100644 index 000000000000..c3470901f53e --- /dev/null +++ b/jaxlib/xla/custom_call_sharding.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildCustomCallShardingPybindAPI(nanobind::module_& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc new file mode 100644 index 000000000000..f6605a36f02b --- /dev/null +++ b/jaxlib/xla/dlpack.cc @@ -0,0 +1,699 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/dlpack.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/py_array.h" +#include "xla/python/py_client.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/traceback.h" +#include "xla/python/types.h" +#include "xla/python/util.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +const char* const kDlTensorCapsuleName = "dltensor"; + +struct DLPackTensor { + ~DLPackTensor(); + + // `buffer_reference` is populated if we have shared (read-only) access. + nb::object buffer_reference; + + // `external_reference` is always populated. 
+ std::unique_ptr external_reference; + + std::vector shape; + std::vector strides; + DLManagedTensor tensor; +}; + +DLPackTensor::~DLPackTensor() { + if (buffer_reference) { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(&buffer_reference, /*size=*/1)); + } +} + +void DLPackTensorDeleter(DLManagedTensor* t) { + if (t) { + delete static_cast(t->manager_ctx); + } +} + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { + switch (type) { + case S8: + return DLDataType{kDLInt, 8, 1}; + case S16: + return DLDataType{kDLInt, 16, 1}; + case S32: + return DLDataType{kDLInt, 32, 1}; + case S64: + return DLDataType{kDLInt, 64, 1}; + case U8: + return DLDataType{kDLUInt, 8, 1}; + case U16: + return DLDataType{kDLUInt, 16, 1}; + case U32: + return DLDataType{kDLUInt, 32, 1}; + case U64: + return DLDataType{kDLUInt, 64, 1}; + case F4E2M1FN: + return DLDataType{kDLFloat4_e2m1fn, 4, 1}; + case F8E3M4: + return DLDataType{kDLFloat8_e3m4, 8, 1}; + case F8E4M3: + return DLDataType{kDLFloat8_e4m3, 8, 1}; + case F8E4M3B11FNUZ: + return DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}; + case F8E4M3FN: + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + case F8E4M3FNUZ: + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + case F8E5M2: + return DLDataType{kDLFloat8_e5m2, 8, 1}; + case F8E5M2FNUZ: + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + case F8E8M0FNU: + return DLDataType{kDLFloat8_e8m0fnu, 8, 1}; + case BF16: + return DLDataType{kDLBfloat, 16, 1}; + case F16: + return DLDataType{kDLFloat, 16, 1}; + case F32: + return DLDataType{kDLFloat, 32, 1}; + case F64: + return DLDataType{kDLFloat, 64, 1}; + case PRED: + return DLDataType{kDLBool, 8, 1}; + case C64: + return DLDataType{kDLComplex, 64, 1}; + case C128: + return DLDataType{kDLComplex, 128, 1}; + default: + return Unimplemented("XLA type %s has no DLPack equivalent", + PrimitiveType_Name(type)); + } +} + +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", + type.lanes); + } + switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return PRED; + default: + return Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return S8; + case 16: + return S16; + case 32: + return S32; + case 64: + return S64; + default: + return Unimplemented( + "Invalid or unsupported DLPack integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return U8; + case 16: + return U16; + case 32: + return U32; + case 64: + return U64; + default: + return Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat4_e2m1fn: + if (type.bits == 4) { + return F4E2M1FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float4_e2m1fn width: %d bits", + type.bits); + case kDLFloat8_e3m4: + if (type.bits == 8) { + return F8E3M4; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e3m4 width: %d bits", + type.bits); + case kDLFloat8_e4m3: + if (type.bits == 8) { + return F8E4M3; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3 width: %d bits", + type.bits); + case kDLFloat8_e4m3b11fnuz: + if (type.bits == 8) { + return F8E4M3B11FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3b11fnuz width: %d bits", + type.bits); + case kDLFloat8_e4m3fn: + if (type.bits == 8) { + return F8E4M3FN; + } + return 
Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fn width: %d bits", + type.bits); + case kDLFloat8_e4m3fnuz: + if (type.bits == 8) { + return F8E4M3FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fnuz width: %d bits", + type.bits); + case kDLFloat8_e5m2: + if (type.bits == 8) { + return F8E5M2; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2 width: %d bits", + type.bits); + case kDLFloat8_e5m2fnuz: + if (type.bits == 8) { + return F8E5M2FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2fnuz width: %d bits", + type.bits); + case kDLFloat8_e8m0fnu: + if (type.bits == 8) { + return F8E8M0FNU; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e8m0fnu width: %d bits", + type.bits); + case kDLBfloat: + if (type.bits == 16) { + return BF16; + } + return Unimplemented( + "Invalid or unsupported DLPack bfloat width: %d bits", type.bits); + case kDLFloat: + switch (type.bits) { + case 16: + return F16; + case 32: + return F32; + case 64: + return F64; + default: + return Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return C64; + case 128: + return C128; + default: + return Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return Unimplemented("Unknown or invalid DLPack type code %d", type.code); + } +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. 
Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +absl::StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) { + if (device.client()->platform_id() == CpuId()) { + return kDLCPU; + } else if (device.client()->platform_id() == CudaId()) { + return kDLCUDA; + } else if (device.client()->platform_id() == RocmId()) { + return kDLROCM; + } + return InvalidArgument("Device %s cannot be used as a DLPack device.", + device.DebugString()); +} + +absl::StatusOr<DLDevice> DLDeviceForDevice(const PjRtDevice& device) { + DLDevice context; + TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); + context.device_id = device.local_hardware_id().value(); + return context; +} + +absl::StatusOr<PjRtDevice*> DeviceForDLDevice(const PjRtClient* cpu_client, + const PjRtClient* gpu_client, + const DLDevice& context) { + switch (context.device_type) { + case kDLCPU: + if (cpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on CPU, but no CPU backend was provided."); + } + TF_RET_CHECK(cpu_client->platform_id() == CpuId()); + return cpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLCUDA: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == CudaId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLROCM: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == RocmId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + default: + return InvalidArgument("Unknown/unsupported DLPack device type %d", + context.device_type); + } +} + +absl::Status VerifyDType(const DLTensor& dl_tensor) { + if (dl_tensor.dtype.bits % 8 != 0) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: bits should be a multiple of 8, got " + "%d", + dl_tensor.dtype.bits); + } + + if (dl_tensor.dtype.lanes != 1) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: lanes should be equal to 1, got %d", + dl_tensor.dtype.lanes); + } + + return absl::OkStatus(); +} + +absl::StatusOr<std::vector<int64_t>> GetByteStrides(const DLTensor& dl_tensor) { + TF_RETURN_IF_ERROR(VerifyDType(dl_tensor)); + + // Convert element strides from the number of elements to the number of bytes. + std::vector<int64_t> strides; + strides.reserve(dl_tensor.ndim); + for (int i = 0; i < dl_tensor.ndim; ++i) { + strides.push_back(dl_tensor.strides[i] * dl_tensor.dtype.bits / 8); + } + return strides; +} + +absl::StatusOr<std::unique_ptr<PjRtBuffer>> MakePjrtBuffer( + PjRtDevice& device, ::DLManagedTensor* dlmt, const Shape& shape, + PrimitiveType element_type, absl::Span<const int64_t> dimensions, + std::optional<std::intptr_t> stream = std::nullopt) { + std::function<void()> on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + // First try to create a view. + void* data = + static_cast<char*>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset; + auto result = device.client()->CreateViewOfDeviceBuffer( + data, shape, *device.default_memory_space(), on_delete_callback, stream); + + // If that fails with an invalid-argument error, the cause may be incorrect + // alignment. If we're on CPU, we can create a copy of the buffer.
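+ // (One plausible way to hit this on CPU: an array viewing a host byte + // buffer at an arbitrary offset need not satisfy the backend's minimum + // alignment. This is an illustrative example, not the only possible cause.)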
+ if (result.status().code() == absl::StatusCode::kInvalidArgument && + dlmt->dl_tensor.device.device_type == kDLCPU) { + LOG(WARNING) << "DLPack buffer is not aligned (data at: " << data + << "). Creating a copy."; + + // Convert tensor strides (expressed in number of elements) to byte strides. + std::optional> byte_strides; + if (dlmt->dl_tensor.strides) { + TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor)); + } + + TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space()); + + // Create a copy. + result = device.client()->BufferFromHostBuffer( + data, element_type, dimensions, byte_strides, + PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback, + memory_space, /*device_layout=*/nullptr); + } + return result; +} + +} // namespace + +absl::StatusOr BufferToDLPackManagedTensor( + nb::handle py_buffer, std::optional stream) { + ifrt::Array* ifrt_array = nb::cast(py_buffer).ifrt_array(); + if (ifrt_array == nullptr) { + return Unimplemented( + "BufferToDLPackManagedTensor called on deleted array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + PjRtBuffer* pjrt_buffer = arr->pjrt_buffers().front().get(); + + if (pjrt_buffer->IsTuple()) { + return Unimplemented( + "BufferToDLPackManagedTensor is not implemented for tuple " + "buffers."); + } + if (pjrt_buffer->has_dynamic_dimensions()) { + return Unimplemented("DynamicShape is not implemented in DLPack."); + } + + auto pack = std::make_unique(); + DLTensor& dt = pack->tensor.dl_tensor; + { + // AcquireExternalReference may block; there are no API guarantees. + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(pack->external_reference, + pjrt_buffer->AcquireExternalReference()); + if (stream) { + TF_RETURN_IF_ERROR( + pack->external_reference->WaitUntilBufferReadyOnStream(*stream)); + } else { + TF_RETURN_IF_ERROR( + AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1))); + } + } + pack->buffer_reference = nb::borrow(py_buffer); + + dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + TF_ASSIGN_OR_RETURN(dt.device, DLDeviceForDevice(*pjrt_buffer->device())); + dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value(); + dt.ndim = pjrt_buffer->dimensions().size(); + TF_ASSIGN_OR_RETURN(dt.dtype, + PrimitiveTypeToDLDataType(pjrt_buffer->element_type())); + + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), + pjrt_buffer->dimensions().end()); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); + + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = reinterpret_cast(pack->strides.data()); + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. 
+ nb::capsule capsule = nb::steal<nb::capsule>( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast<DLManagedTensor*>( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + return capsule; +} + +absl::StatusOr<PyArray> DLPackManagedTensorToBuffer( + const nb::capsule& tensor, std::optional<nb_class_ptr<PyClient>> cpu_client, + std::optional<nb_class_ptr<PyClient>> gpu_client) { + // TODO(hyeontaek): This is a potential target for an IFRT client to multiplex + // multiple PjRt clients. Devices from these PjRt clients could be expressed + // as a unified set of IFRT devices. + auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr; + auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr; + + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + TF_ASSIGN_OR_RETURN(PjRtDevice * device, + DeviceForDLDevice(cpu_client ? cpu_pjrt_client : nullptr, + gpu_client ? gpu_pjrt_client : nullptr, + dlmt->dl_tensor.device)); + absl::Span<const int64_t> dimensions( + reinterpret_cast<const int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector<int64_t> minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span<const int64_t> strides( + reinterpret_cast<const int64_t*>(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + // Raise an error if the resulting PjRtBuffer would have a non-default layout. + // TODO(skyewm): we do this because JAX doesn't currently have good support + // for non-default layouts, and will return wrong results if a non-default + // layout is passed to a computation expecting default layouts. Remove this + // special case when non-default layouts are better supported by JAX. + TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout( + element_type, dimensions)); + if (shape.layout() != default_layout) { + return Unimplemented( + "from_dlpack got array with non-default layout with minor-to-major " + "dimensions (%s), expected (%s)", + absl::StrJoin(shape.layout().minor_to_major(), ","), + absl::StrJoin(default_layout.minor_to_major(), ",")); + } + + std::function<void()> on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + MakePjrtBuffer(*device, dlmt, shape, element_type, dimensions)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule cannot be used again.
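+ // Per the DLPack Python spec, the consumer renames a consumed capsule to + // "used_dltensor" so that it cannot be imported a second time.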
+ PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + // TODO(phawkins): simplify the expression below once we know cpu_client is + // always non-null. + auto client = (cpu_client && device->client() == cpu_pjrt_client) + ? std::move(*cpu_client) + : std::move(*gpu_client); + auto* ifrt_client = + llvm::dyn_cast_or_null<ifrt::PjRtCompatibleClient>(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr<PyArray> DLPackManagedTensorToBuffer( + const nb::capsule& tensor, ifrt::Device* ifrt_device, + nb_class_ptr<PyClient> client, std::optional<std::intptr_t> stream) { + ifrt::PjRtDevice* device = + llvm::dyn_cast_or_null<ifrt::PjRtDevice>(ifrt_device); + if (device == nullptr) { + throw XlaRuntimeError( + "DLPack is supported for PjRt-compatible backends only."); + } + if (!device->IsAddressable()) { + throw XlaRuntimeError( + "DLPack is only supported for devices addressable by the current " + "process."); + } + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + absl::Span<const int64_t> dimensions( + reinterpret_cast<const int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector<int64_t> minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span<const int64_t> strides( + reinterpret_cast<const int64_t*>(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + TF_ASSIGN_OR_RETURN(auto pjrt_buffer, + MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, + element_type, dimensions, stream)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule cannot be used again.
+ PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type) { + TF_ASSIGN_OR_RETURN(DLDataType dl_type, PrimitiveTypeToDLDataType(type)); + + nanobind::dlpack::dtype nb_type; + nb_type.lanes = dl_type.lanes; + nb_type.bits = dl_type.bits; + nb_type.code = dl_type.code; + + return nb_type; +} + +} // namespace xla diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h new file mode 100644 index 000000000000..5d7fd7c10bf8 --- /dev/null +++ b/jaxlib/xla/dlpack.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_client.h" + +namespace xla { + +// If take_ownership is true, ownership of the buffer is handed to DLPack, and +// the receiver may mutate the buffer as they see fit. Otherwise PjRt retains +// ownership of the buffer and it should be immutable. +// +// stream, if set, is a GPU stream, e.g. cudaStream_t for CUDA GPUs, that should +// be synchronized to the buffer as per +// https://dmlc.github.io/dlpack/latest/python_spec.html#python-specification-for-dlpack. +absl::StatusOr BufferToDLPackManagedTensor( + nanobind::handle buffer, std::optional stream); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, ifrt::Device* device, + nb_class_ptr client, std::optional stream); + +// Converts a PrimitiveType to the nanobind specific implementation of +// DLDataType. +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/xla/jax_jit.cc new file mode 100644 index 000000000000..754272a078ed --- /dev/null +++ b/jaxlib/xla/jax_jit.cc @@ -0,0 +1,495 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements the `jax.jit` dispatch and just-in-time feature. +// +// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward +// based on the passed arguments' dtypes/shapes/identity) the execution to a +// just-in-time compiled XLA Executable. All of that is done in C++ for +// performance reasons. +// +// This file contains the utilities to: +// (a) inspect arguments and describe their structure, dtype/shapes, etc. +// (b) keep a mapping from function signatures to compiled XLA Executables. + +#include "jaxlib/xla/jax_jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/py_values.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/python/types.h" +#include "xla/tsl/platform/logging.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Use absl::Status. + +namespace { + +// `thread_local_state.extra_jit_context` is set from Python. It's done when +// loading the Python jax modules on the main thread. For other threads, we +// need to initialize the field the first time we access `thread_local_state`. +nb::object& initialize_local_state = *new nb::object(); + +} // namespace + +JitState& GlobalJitState() { + // Protected by the GIL. + static JitState& global_state = *new JitState(); + return global_state; +} + +JitState& ThreadLocalJitState() { + // TODO(phawkins): Google style guide forbids thread-local values with + // non-trivial destructors. + ABSL_CONST_INIT thread_local JitState thread_local_state; // NOLINT + DCHECK(PyGILState_Check()); + if (thread_local_state.extra_jit_context == std::nullopt) { + CHECK(initialize_local_state.ptr() != nullptr); + // Avoids reentrant calls to the initialization function.
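+ // Assigning a placeholder first means that if initialize_local_state() + // re-enters ThreadLocalJitState(), the std::nullopt check above fails and + // we do not recurse.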
+    thread_local_state.extra_jit_context = nb::none();
+    initialize_local_state();
+  }
+  return thread_local_state;
+}
+
+bool GetDisableJit() {
+  auto& global_state = GlobalJitState();
+  auto& thread_local_state = ThreadLocalJitState();
+  CHECK(global_state.disable_jit.has_value());
+  return thread_local_state.disable_jit.value_or(*global_state.disable_jit);
+}
+
+bool GetEnableX64() {
+  auto& global_state = GlobalJitState();
+  auto& thread_local_state = ThreadLocalJitState();
+  CHECK(global_state.enable_x64.has_value());
+  return thread_local_state.enable_x64.value_or(*global_state.enable_x64);
+}
+
+std::optional<nb::object> GetDefaultDevice() {
+  auto& global_state = GlobalJitState();
+  auto& thread_local_state = ThreadLocalJitState();
+  return thread_local_state.default_device.has_value()
+             ? thread_local_state.default_device
+             : global_state.default_device;
+}
+
+std::optional<nb::callable> GetPostHook() {
+  auto& global_state = GlobalJitState();
+  auto& thread_local_state = ThreadLocalJitState();
+  return thread_local_state.post_hook.has_value() ? thread_local_state.post_hook
+                                                  : global_state.post_hook;
+}
+
+static std::string OptionalDebugString(
+    const std::optional<nb::object> optional) {
+  if (optional.has_value()) {
+    return nb::cast<std::string>(nb::str(optional.value()));
+  } else {
+    return "None";
+  }
+}
+
+std::string ArgumentSignature::DebugString() const {
+  auto py_object_formatter = [](std::string* out, const nb::object& o) {
+    out->append(nb::cast<absl::string_view>(nb::str(o)));
+  };
+  auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) {
+    out->append(d.ToString());
+  };
+  return absl::StrFormat(
+      "static args (positional + keyword): [%s], "
+      "static arg keyword names: [%s], "
+      "dynamic arg keyword names: [%s], "
+      "dynamic arg treedefs: [%s]",
+      absl::StrJoin(static_args, ",", py_object_formatter),
+      absl::StrJoin(static_arg_names, ",", py_object_formatter),
+      absl::StrJoin(dynamic_arg_names, ",", py_object_formatter),
+      absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter));
+}
+
+bool ArgumentSignature::operator==(const ArgumentSignature& other) const {
+  if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) {
+    return false;
+  }
+  auto object_ptr_equality = [](nb::handle a, nb::handle b) {
+    return a.ptr() == b.ptr();
+  };
+  if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names,
+                     object_ptr_equality)) {
+    return false;
+  }
+  if (!absl::c_equal(static_arg_names, other.static_arg_names,
+                     object_ptr_equality)) {
+    return false;
+  }
+  return absl::c_equal(
+      static_args, other.static_args,
+      [](const nb::object& a, const nb::object& b) {
+        try {
+          return a.type().ptr() == b.type().ptr() && a.equal(b);
+        } catch (const nb::python_error& e) {
+          throw std::invalid_argument(absl::StrCat(
+              "static arguments should be comparable using __eq__. "
+              "The following error was raised when comparing two objects of "
+              "types ",
+              nb::cast<std::string>(nb::str(a.type())), " and ",
+              nb::cast<std::string>(nb::str(b.type())),
+              ". The error was:\n", e.what()));
+        }
+      });
+}
+
+std::string CallSignature::DebugString() const {
+  auto py_object_formatter = [](std::string* out, const nb::object& o) {
+    out->append(nb::cast<absl::string_view>(nb::str(o)));
+  };
+  auto signature_formatter = [](std::string* out,
+                                const xla::PyArgSignature& s) {
+    out->append(s.DebugString());
+  };
+  auto layout_formatter = [](std::string* out,
+                             const std::shared_ptr<const xla::PjRtLayout>& l) {
+    if (l != nullptr) {
+      out->append(l->ToString());
+    } else {
+      out->append("None");
+    }
+  };
+  auto bool_formatter = [](std::string* out, bool o) {
+    out->append(o ?
"true" : "false"); + }; + return absl::StrFormat( + "arg signature: %s\n" + "dynamic arg signatures (positional + keyword): %s\n" + "dynamic arg shardings: %s\n" + "dynamic arg layouts: %s\n" + "committed args: %s\n" + "device: %s\n" + "default_device: %s\n" + "jax_enable_x64: %d\n" + "global_extra_jit_context: %s\n" + "thread_local_extra_jit_context: %s\n" + "configs: %s\n", + arg_signature.DebugString(), + absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), + absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), + absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter), + absl::StrJoin(committed_args, ",", bool_formatter), + device != nullptr ? device->DebugString() : "nullptr", + OptionalDebugString(default_device), jax_enable_x64, + OptionalDebugString(global_extra_jit_context), + OptionalDebugString(thread_local_extra_jit_context), + absl::StrJoin(configs, ", ", py_object_formatter)); +} + +bool CallSignature::operator==(const CallSignature& other) const { + if (arg_signature != other.arg_signature) { + return false; + } + if (dynamic_arg_signatures != other.dynamic_arg_signatures) { + return false; + } + if (device != other.device) { + return false; + } + if (jax_enable_x64 != other.jax_enable_x64) { + return false; + } + if (committed_args != other.committed_args) { + return false; + } + return + // `==` on py:objects is the Python `is`. We need equal. + absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, + ShardingEqual) && + absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return (a && b) ? *a == *b : a == b; + }) && + (global_extra_jit_context.has_value() == + other.global_extra_jit_context.has_value()) && + (!global_extra_jit_context.has_value() || + global_extra_jit_context->equal(*other.global_extra_jit_context)) && + (default_device.has_value() == other.default_device.has_value()) && + (!default_device.has_value() || + default_device->equal(*other.default_device)) && + (thread_local_extra_jit_context.has_value() == + other.thread_local_extra_jit_context.has_value()) && + (!thread_local_extra_jit_context.has_value() || + thread_local_extra_jit_context->equal( + *other.thread_local_extra_jit_context)) && + configs.size() == other.configs.size() && + absl::c_equal( + configs, other.configs, + [](const nb::object& a, const nb::object& b) { return a.equal(b); }); +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args) { + tsl::profiler::TraceMe traceme("ParseArguments"); + + DCHECK(absl::c_all_of(static_argnames, [](const nb::str& name) { + return PyUnicode_CHECK_INTERNED(name.ptr()); + })); + + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); + if (static_argnums.empty()) { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. 
+ for (int i = 0; i < positional_args.size(); ++i) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); + } + } else { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + int num_positional_args = positional_args.size(); + for (int i = 0; i < positional_args.size(); ++i) { + if (std::find_if(static_argnums.begin(), static_argnums.end(), + [i, num_positional_args](int t) { + return t >= 0 ? i == t : i == t + num_positional_args; + }) == static_argnums.end()) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); + } else { + signature.static_args.emplace_back( + nb::borrow(positional_args[i])); + } + } + } + + // Keyword arguments. + if (!keyword_args.empty()) { + std::vector> kwargs(keyword_args.size()); + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also xla::PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. + for (int i = 0; i < keyword_args.size(); ++i) { + // Intern the key if not already interned. + PyObject* key = PyTuple_GET_ITEM(kwnames.ptr(), i); + Py_INCREF(key); + if (!PyUnicode_CHECK_INTERNED(key)) { + PyUnicode_InternInPlace(&key); + } + kwargs[i].first = key; + kwargs[i].second = keyword_args[i]; + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }); + auto kwarg_is_static = [&](nb::handle name) { + for (const auto& kw : static_argnames) { + if (kw.ptr() == name.ptr()) return true; + } + return false; + }; + + signature.dynamic_arg_names.reserve(keyword_args.size()); + for (int i = 0; i < keyword_args.size(); ++i) { + if (kwarg_is_static(kwargs[i].first)) { + signature.static_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.static_args.push_back( + nb::borrow(kwargs[i].second)); + } else { + signature.dynamic_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), + flat_dynamic_args); + } + } + } + return absl::OkStatus(); +} + +void BuildJaxjitSubmodule(nb::module_& m) { + nb::module_ jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + nb::class_ jit_state_(jitlib, "JitState"); + jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none()); + jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none()); + jit_state_.def_rw("default_device", &JitState::default_device, + nb::arg().none()); + jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context, + nb::arg().none()); + jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none()); + + jitlib.def( + "global_state", [&]() { return &GlobalJitState(); }, + nb::rv_policy::reference); + jitlib.def( + "thread_local_state", [&]() { return &ThreadLocalJitState(); }, + nb::rv_policy::reference); + + jitlib.def( + "swap_thread_local_state_disable_jit", + [&](std::optional value) -> std::optional { + auto tls = &ThreadLocalJitState(); + auto result = tls->disable_jit; + 
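+        // (Editor's note on the swap idiom here: the previous thread-local
+        // value is returned so callers can restore it, e.g.
+        //   old = jax_jit.swap_thread_local_state_disable_jit(True)
+        //   try: ... finally: jax_jit.swap_thread_local_state_disable_jit(old)
+        // where the module path is illustrative.)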
tls->disable_jit = value; + return result; + }, + nb::arg("value").none(), nb::rv_policy::reference); + + jitlib.def("get_enable_x64", &GetEnableX64); + jitlib.def("set_thread_local_state_initialization_callback", + [](nb::object f) { initialize_local_state = f; }); + + nb::class_ arg_signature(jitlib, "PyArgSignature"); + arg_signature + .def_prop_ro( + "dtype", + [](const xla::PyArgSignature& sig) { + return xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(sig.dtype)); + }) + .def_prop_ro("shape", + [](const xla::PyArgSignature& sig) { + return xla::SpanToNbTuple(absl::MakeConstSpan(sig.shape)); + }) + .def_ro("weak_type", &xla::PyArgSignature::weak_type); + jitlib.def("_ArgSignatureOfValue", + xla::ValueOrThrowWrapper(xla::PyArgSignatureOfValue)); + + jitlib.def("_is_float0", &xla::IsFloat0); + + nb::class_ argument_signature(jitlib, "ArgumentSignature"); + argument_signature.def_ro("static_args", &ArgumentSignature::static_args) + .def_ro("static_arg_names", &ArgumentSignature::static_arg_names) + .def_ro("dynamic_arg_names", &ArgumentSignature::dynamic_arg_names) + .def_ro("dynamic_arg_treedefs", &ArgumentSignature::dynamic_arg_treedefs) + .def("__repr__", &ArgumentSignature::DebugString) + .def("__str__", &ArgumentSignature::DebugString) + .def("__hash__", + [](const ArgumentSignature& s) { return absl::HashOf(s); }) + .def("__eq__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a == b; }) + .def("__ne__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a != b; }); + + jitlib.def( + "parse_arguments", + [](nb::sequence positional_args, nb::sequence keyword_args, + nb::tuple kwnames, absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry) { + ArgumentSignature signature; + absl::InlinedVector flat_dynamic_args; + nb::object positional_args_seq = nb::steal(PySequence_Fast( + positional_args.ptr(), "positional_args must be a list or tuple")); + if (!positional_args_seq.ptr()) { + throw nb::python_error(); + } + nb::object keyword_args_seq = nb::steal(PySequence_Fast( + keyword_args.ptr(), "keyword_args must be a list or tuple")); + if (!keyword_args_seq.ptr()) { + throw nb::python_error(); + } + absl::Span positional_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(positional_args_seq.ptr()), + PySequence_Fast_GET_SIZE(positional_args_seq.ptr())); + absl::Span keyword_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(keyword_args_seq.ptr()), + PySequence_Fast_GET_SIZE(keyword_args_seq.ptr())); + + // Intern the static argument names. + std::vector static_argnames_interned; + static_argnames_interned.reserve(static_argnames.size()); + for (const nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_interned.push_back(nb::steal(s)); + } + + xla::ThrowIfError( + ParseArguments(positional_args_span, keyword_args_span, kwnames, + static_argnums, static_argnames_interned, + pytree_registry, signature, flat_dynamic_args)); + return std::make_pair(std::move(signature), + std::move(flat_dynamic_args)); + }, + nb::arg("positional_args"), nb::arg("keyword_args"), nb::arg("kwnames"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("pytree_registry"), + R"doc(Parses the arguments to a function as jax.jit would. + +Returns a ArgumentSignature and the flattened dynamic arguments. + +Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. 
+ static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. +)doc"); +} + +} // namespace jax diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h new file mode 100644 index 000000000000..303d7e69414d --- /dev/null +++ b/jaxlib/xla/jax_jit.h @@ -0,0 +1,265 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/py_values.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +// Flags, such as JIT disable and the x64 mode, are controlled by: +// - a global flag value, e.g., associated to --jax_enable_x64 +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is +// used to implement context managers that locally override the global state. +struct JitState { + ~JitState() { + if (extra_jit_context) { + // We likely do not hold the GIL if this JitState is thread-local, so we + // hand the Python object to the global reference manager to destroy. + nanobind::object o = std::move(*extra_jit_context); + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); + extra_jit_context = std::nullopt; + } + } + + std::optional disable_jit; + std::optional enable_x64; + + // Used to manually set the default device jax should use. May be unset even + // in global state, indicating there is no manual override. + // TODO(skyewm): make this a C++ type when all JAX backends support a single + // C++ device interface + std::optional default_device; + + // Extra context that should be included in the JIT cache key. Must be + // hashable and have an equality defined. + std::optional extra_jit_context; + + // A callback that, if present, is called when a JITted function is executed + // from cache. May be unset even in global state. + std::optional post_hook; +}; + +JitState& GlobalJitState(); + +// Requires the GIL. +JitState& ThreadLocalJitState(); + +// Getters for JitState fields that first look in thread-local state, then +// fallback to global state. 
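+//
+// Editor's illustration (not part of the original header): given
+//   GlobalJitState().enable_x64 = false;      // process-wide default
+//   ThreadLocalJitState().enable_x64 = true;  // e.g. a context manager
+// GetEnableX64() returns true until the thread-local override is cleared.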
+bool GetDisableJit(); +bool GetEnableX64(); + +// TODO(skyewm): return a C++ type when all JAX backends support a single C++ +// device interface +std::optional GetDefaultDevice(); +std::optional GetPostHook(); + +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. Interned, and sorted by keyword name. + std::vector static_arg_names; + + bool operator==(const ArgumentSignature& other) const; + bool operator!=(const ArgumentSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature& s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto& name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto& static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = nanobind::hash(static_arg); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto& name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments, which must be interned. +// pytree_registry: the registry to use to convert the arguments to pytrees +// signature: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. 
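+//
+// Example (editor's sketch with hypothetical arguments): for a call
+// f(x, y, z=1, w=2) with static_argnums = {1} and static_argnames = {"w"},
+// the parsed signature records static_args = [y, 2], static_arg_names =
+// ["w"], and dynamic_arg_names = ["z"], while flat_dynamic_args receives the
+// flattened leaves of x followed by z.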
+absl::Status ParseArguments(
+    absl::Span<PyObject* const> positional_args,
+    absl::Span<PyObject* const> keyword_args, nanobind::handle kwnames,
+    absl::Span<int> static_argnums,
+    absl::Span<nanobind::str> static_argnames,
+    xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature,
+    absl::InlinedVector<nanobind::object, 2>& flat_dynamic_args);
+
+// The signature of a Python jitted function call, partitioned into:
+// - dynamic positional arguments (i.e. positional args which are not static)
+// - static positional arguments (i.e. the args associated to static_argnums)
+// - keyword arguments
+// The CallSignature should unambiguously identify a function call, thus,
+// equality is based on:
+// (a) Same PyTree for all dynamic positional arguments and keyword arguments
+// (b) equality of the arguments and keyword arguments ArgSignature
+// (c) equality (delegated to Python) of the static arguments.
+struct CallSignature {
+  // Not part of the signature, but we need it for error messages.
+  absl::string_view function_name;
+
+  ArgumentSignature arg_signature;
+
+  // Shape and dtype for both the dynamic positional arguments and the keyword
+  // arguments (sorted by keyword name).
+  absl::InlinedVector<xla::PyArgSignature, 2> dynamic_arg_signatures;
+
+  // The sharding of the jax.Array arguments.
+  std::vector<nanobind::object> dynamic_arg_shardings;
+
+  // The layout of the jax.Array arguments.
+  std::vector<std::shared_ptr<const xla::PjRtLayout>> dynamic_arg_layouts;
+
+  absl::InlinedVector<bool, 2> committed_args;
+
+  // For JIT, we need this in the key because computation follows the data, so
+  // we may have multiple executables depending on the devices the data is on.
+  // This is not the case for PMAP, and is set to `nullptr`.
+  xla::PjRtDevice* device = nullptr;
+  bool jax_enable_x64;
+
+  // For JIT on PJIT, we need to fall back to python whenever default_device
+  // changes.
+  std::optional<nanobind::object> default_device;
+
+  // Opaque additional context that should be included as part of the cache key.
+  std::optional<nanobind::object> global_extra_jit_context;
+  std::optional<nanobind::object> thread_local_extra_jit_context;
+
+  std::vector<nanobind::object> configs;
+
+  bool operator==(const CallSignature& other) const;
+  bool operator!=(const CallSignature& other) const {
+    return !(*this == other);
+  }
+
+  std::string DebugString() const;
+};
+
+template <typename H>
+H AbslHashValue(H h, const CallSignature& s) {
+  h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures);
+
+  DCHECK(s.dynamic_arg_shardings.empty() ||
+         s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size());
+
+  DCHECK(s.dynamic_arg_layouts.empty() ||
+         s.dynamic_arg_layouts.size() == s.dynamic_arg_signatures.size());
+
+  // TODO(chky): For now, we are only hashing the pointer of shardings to avoid
+  // slow python hashing function. Consider implementing hashing function and
+  // equality checks in C++ in jax::Sharding and use those here.
+  for (const auto& sharding : s.dynamic_arg_shardings) {
+    h = H::combine(std::move(h), ShardingHash(sharding));
+  }
+
+  for (const auto& layout : s.dynamic_arg_layouts) {
+    if (layout != nullptr) {
+      h = H::combine(std::move(h), *layout);
+    }
+  }
+
+  h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64);
+
+  // We do not hash the extra_jit_context fields since calling Python hash
+  // functions is expensive (~300ns) and we don't expect a large number of
+  // different contexts.
+  return h;
+}
+
+// The function to call in `xla.cc` to add the bindings for this module.
+void BuildJaxjitSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc new file mode 100644 index 000000000000..5905c6c6ec8d --- /dev/null +++ b/jaxlib/xla/mlir.cc @@ -0,0 +1,251 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/mlir.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "stablehlo/dialect/Serialization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/refine_polymorphic_shapes.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +std::string PrintModule(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); + module->print(os, flags); + return s; +} + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +void EnablePrintBeforeAndAfter(mlir::PassManager& pm) { + auto print_before = [](mlir::Pass*, mlir::Operation*) { return true; }; + auto print_after = [](mlir::Pass*, mlir::Operation*) { return true; }; + pm.enableIRPrinting(print_before, print_after); +} + +// Converts an XlaComputation to a StableHLO mlir::Module string. +// Exists for backwards compatibility. +// TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules +// instead and delete this function. 
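+// (Editor's note: this is exposed to Python below in BuildMlirSubmodule as
+// mlir.xla_computation_to_mlir_module(computation), returning the printed
+// StableHLO module as a str.)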
+absl::StatusOr PyXlaComputationToMlirModule( + const XlaComputation& computation) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &computation.proto())); + return PrintModule(*module); +} + +absl::StatusOr PyMlirModuleToXlaComputation( + absl::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + XlaComputation computation; + // SDY dialect may be part of the module which XLA doesn't know about. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, + return_tuple, + /*use_shardy=*/false)); + return computation; +} + +absl::StatusOr PyMhloToStablehlo(absl::string_view mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // JAX can be customized in a way that involves operations from custom + // dialects showing up in JAX IR. + // `ParseMlirModuleString` won't know about these dialects, but that's fine + // since we just want to convert MHLO ops to StableHLO ops here and leave + // everything else unchanged. + // In order to achieve that, we're allowing unregistered dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("MHLO => StableHLO failed"); + } + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // See PyMhloToStablehlo for an explanation of why we're allowing unregistered + // dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN( + mlir::OwningOpRef module, + ParseMlirModuleString( + absl::string_view(mlir_module.c_str(), mlir_module.size()), context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("StableHLO => MHLO failed"); + } + + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. 
+ TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PySerializePortableArtifact( + absl::string_view mlir_module, absl::string_view target) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + + // Serialize portable artifact + TF_ASSIGN_OR_RETURN( + std::string bytecode, + SerializeUsingVersionedStablehlo(*module, target, /*inplace=*/true)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyDeserializePortableArtifact( + const nb::bytes& bytecode_str) { + mlir::MLIRContext context; + mlir::OwningOpRef module = + mlir::stablehlo::deserializePortableArtifact( + absl::string_view(bytecode_str.c_str(), bytecode_str.size()), + &context); + if (!module) + return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); + return PrintModule(*module); +} + +} // namespace + +void BuildMlirSubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); + + mlir_module.def("xla_computation_to_mlir_module", + xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), + nb::arg("computation")); + mlir_module.def( + "mlir_module_to_xla_computation", + [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { + return xla::ValueOrThrow(PyMlirModuleToXlaComputation( + absl::string_view(bytecode.c_str(), bytecode.size()), + use_tuple_args, return_tuple)); + }, + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def("mlir_module_to_xla_computation", + xla::ValueOrThrowWrapper(PyMlirModuleToXlaComputation), + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def( + "mhlo_to_stablehlo", + [](const nb::bytes& bytecode) { + return xla::ValueOrThrow(PyMhloToStablehlo( + absl::string_view(bytecode.c_str(), bytecode.size()))); + }, + nb::arg("mlir_module")); + mlir_module.def("mhlo_to_stablehlo", + xla::ValueOrThrowWrapper(PyMhloToStablehlo), + nb::arg("mlir_module")); + mlir_module.def("stablehlo_to_mhlo", + xla::ValueOrThrowWrapper(PyStablehloToMhlo), + nb::arg("mlir_module")); + mlir_module.def( + "serialize_portable_artifact", + [](const nb::bytes& bytecode, absl::string_view target) { + return xla::ValueOrThrow(PySerializePortableArtifact( + absl::string_view(bytecode.c_str(), bytecode.size()), target)); + }, + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("serialize_portable_artifact", + xla::ValueOrThrowWrapper(PySerializePortableArtifact), + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("deserialize_portable_artifact", + xla::ValueOrThrowWrapper(PyDeserializePortableArtifact), + nb::arg("mlir_module")); + mlir_module.def( + "refine_polymorphic_shapes", + [](nb::bytes bytecode, bool enable_shape_assertions, + bool validate_static_shapes, bool enable_shardy) -> nb::bytes { + std::string buffer; + llvm::raw_string_ostream os(buffer); + xla::ThrowIfError(RefinePolymorphicShapes( + absl::string_view(bytecode.c_str(), bytecode.size()), os, + enable_shape_assertions, validate_static_shapes, enable_shardy)); + return nb::bytes(buffer.data(), buffer.size()); + }, + nb::arg("mlir_module"), nb::arg("enable_shape_assertions") = true, + nb::arg("validate_static_shapes") = true, + nb::arg("enable_shardy") = false, + R"(Refines the dynamic shapes for a module. 
+ The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + )"); +} + +} // namespace xla diff --git a/jaxlib/xla/mlir.h b/jaxlib/xla/mlir.h new file mode 100644 index 000000000000..f0bfd69bca6b --- /dev/null +++ b/jaxlib/xla/mlir.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildMlirSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc new file mode 100644 index 000000000000..96056708c2fb --- /dev/null +++ b/jaxlib/xla/pjit.cc @@ -0,0 +1,1402 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/pjit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/jax_jit.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/python/guard_lib.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/py_array.h" +#include "xla/python/py_executable.h" +#include "xla/python/py_values.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/python/traceback.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +struct PjitCacheEntry { + explicit PjitCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + std::vector in_shardings; + std::vector out_avals; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_weak_types; + std::vector out_shardings; + std::vector out_committed; + xla::PyTreeDef out_pytree_def; + // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` + // in PjitFunction::Call before calling into compiled computation. + std::vector kept_var_bitvec; + std::vector in_device_local_layouts; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + std::thread::id thread_id = std::this_thread::get_id(); + + bool fall_back_to_python = false; +}; + +// A PjitFunctionCache represents a cache of compiled functions that can be +// shared between one or more PjitFunction objects. It serves two goals: +// - reduce the number of lru caches (hash map) across multiple JITs. +// - make the cache global to increase cache hits (e.g. 
calling jit(f)(3) twice)
+// keeping entries alive as long as the underlying function f is alive.
+// Assume the cache is protected by the GIL.
+class PjitFunctionCache {
+ public:
+  static constexpr int kDefaultCapacity = 4096;
+  explicit PjitFunctionCache(int capacity);
+
+  // Cache entries are shared_ptr<>s because it's possible the cache entry
+  // might be evicted before we finish tracing/compiling.
+  typedef xla::LRUCache<CallSignature, std::shared_ptr<PjitCacheEntry>> Cache;
+
+  // We include as part of the cache key `global_cache_key` (and any other
+  // fields that aren't subsumed by the CallSignature we compute for each call).
+  static std::shared_ptr<Cache> Lookup(
+      xla::nb_class_ptr<PjitFunctionCache> self, nb::handle function,
+      nb::object global_cache_key);
+  std::shared_ptr<Cache> DefaultCache();
+
+  // These methods require the GIL or the object's lock in no-GIL mode.
+  int Size() const { return lru_list_.Size(); }
+  int Capacity() const { return lru_list_.Capacity(); }
+  void Clear() {
+    lru_list_.Clear();
+    functions_.clear();
+  }
+
+ private:
+  struct Key {
+    nb::handle function;  // Does not hold a reference.
+
+    // Other fields that are part of the arguments to `jit`, but are not
+    // otherwise part of CallSignature.
+    nb::object global_cache_key;
+
+    size_t cached_hash;
+
+    bool operator==(const Key& other) const {
+      bool global_cache_eq;
+      try {
+        global_cache_eq = global_cache_key.equal(other.global_cache_key);
+      } catch (const nanobind::python_error& e) {
+        throw std::invalid_argument(
+            absl::StrCat("Equality of global cache key led to an exception. "
+                         "The error was:\n",
+                         e.what(), "\n"));
+      }
+      return function.ptr() == other.function.ptr() && global_cache_eq;
+    }
+
+    struct Hash {
+      size_t operator()(const Key& key) const { return key.cached_hash; }
+    };
+  };
+
+  template <typename H>
+  friend H AbslHashValue(H h, const Key& key) {
+    h = H::combine(std::move(h), key.function.ptr());
+    Py_hash_t hash;
+    try {
+      hash = nb::hash(key.global_cache_key);
+    } catch (const nanobind::python_error& e) {
+      if (!e.matches(PyExc_TypeError)) throw;
+      throw std::invalid_argument(absl::StrCat(
+          "Hashing global cache key led to an exception. The error was:\n",
+          e.what(), "\n"));
+    }
+    h = H::combine(std::move(h), hash);
+    return h;
+  }
+
+  struct Value {
+    explicit Value(std::shared_ptr<Cache> cache) : cache(std::move(cache)) {}
+    std::shared_ptr<Cache> cache;
+
+    // A weak reference to the key function. We use the weak reference to
+    // register a callback that is triggered when the key function is
+    // destroyed. We use a weak pointer because we want to allow caching
+    // across multiple calls to `pjit(f)` if `f` remains alive, but we do not
+    // want the cache to keep `f` alive if all other references are dropped.
+    std::optional<nb::weakref> weakref;
+  };
+
+  // lru_list_ and functions_ are protected by the GIL in GIL mode, and by the
+  // self object lock in freethreading mode.
+  Cache::LRUList lru_list_;
+  // We use std::unordered_map because ABSL containers are not exception safe:
+  std::unordered_map<Key, std::unique_ptr<Value>, Key::Hash> functions_;
+  // mu_ prevents concurrent insertions into functions_ if the gil or critical
+  // section lock is released during insertion.
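+  // (Editor's note: per the comments in Lookup(), the resulting lock order
+  // is mu_ then the GIL/critical section; the GIL is released once before
+  // mu_ is taken so the two locks are never acquired in the reverse order.)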
+ absl::Mutex mu_; +}; + +PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} + +std::shared_ptr PjitFunctionCache::DefaultCache() { + return std::make_shared(&lru_list_); +} + +/*static*/ std::shared_ptr PjitFunctionCache::Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + // In no-GIL mode, a critical section on self plays the same role that + // the GIL plays in GIL mode. + nb::ft_object_guard lock(self); + { + // Because the gil (or the critical section lock) can be released during + // cache insertion, this forces the lock order to be mu_ then gil so we + // must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + self->mu_.Lock(); + } + absl::Cleanup unlock = [&self]() ABSL_UNLOCK_FUNCTION(self->mu_) { + self->mu_.Unlock(); + }; + Key key; + key.function = function; + key.global_cache_key = global_cache_key; + key.cached_hash = absl::HashOf(key); + auto insert = self->functions_.emplace(key, nullptr); + if (!insert.second) { + return insert.first->second->cache; + } + std::shared_ptr cache = std::make_shared(&self->lru_list_); + auto callback = + nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it != self->functions_.end()) { + self->functions_.erase(it); + } + }); + PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); + if (weakref) { + std::unique_ptr& entry = insert.first->second; + entry = std::make_unique(cache); + entry->weakref = nb::steal(weakref); + } else { + PyErr_Clear(); + // `function` is not weak-referenceable. Don't bother adding it to the + // shared cache in that case; the `jit` object will hold the only shared + // reference to the cache entry. + self->functions_.erase(insert.first); + } + return cache; +} + +class PjitFunction { + public: + PjitFunction(std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache); + ~PjitFunction(); + + PjitFunction(const PjitFunction&) = delete; + PjitFunction& operator=(const PjitFunction&) = delete; + PjitFunction(PjitFunction&&) = default; + PjitFunction& operator=(PjitFunction&&) = default; + + // nb::object typed subclass for PjitFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PjitFunction", + PjitFunction::IsPjitFunction); + pyobject() = default; + PjitFunction* func() const { + return PjitFunction::AsPjitFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PjitFunction. + static bool IsPjitFunction(nb::handle handle); + // Converts `handle` to a PjitFunction*. Does not do any checking. 
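+  // (Editor's note: the checked pattern is IsPjitFunction(h) followed by
+  // AsPjitFunctionUnchecked(h); calling the latter on any other handle type
+  // is undefined behavior.)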
+ static PjitFunction* AsPjitFunctionUnchecked(nb::handle handle); + + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + void InitExecutables(); + + void ClearPythonReferences(); + + const std::string& function_name() const { return function_name_; } + const std::optional& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& shard_arg_fallback() const { return shard_arg_fallback_; } + + const std::vector& static_argnums() const { return static_argnums_; } + const std::vector& static_argnames() const { + return static_argnames_; + } + const nb::object& global_cache_key() const { return global_cache_key_; } + const xla::nb_class_ptr& cache() const { return cache_; } + + int cache_capacity() const { + nb::ft_object_guard lock(cache_); + return executables_->Size(); + } + + void ClearCache() { + nb::ft_object_guard lock(cache_); + executables_->Clear(); + } + + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } + + nb::object PythonSignature() { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat( + "Calling __signature__ on PjitFunction(%s) not supported.", + function_name_) + .c_str()); + } + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(*fun_); + } + + private: + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& call_signature); + + void PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + std::string function_name_; + std::optional fun_; + nb::callable cache_miss_; + std::vector static_argnums_; + std::vector static_argnames_; + nb::object global_cache_key_; + + xla::nb_class_ptr pytree_registry_; + nb::callable shard_arg_fallback_; + xla::nb_class_ptr cache_; + + // In no-GIL mode executables_ is protected by the object lock on cache_, + // because it shared an LRU list with cache_. + std::shared_ptr executables_; +}; + +PjitFunction::PjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, xla::nb_class_ptr cache) + : function_name_(std::move(function_name)), + fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + global_cache_key_(std::move(global_cache_key)), + pytree_registry_(std::move(pytree_registry)), + shard_arg_fallback_(std::move(shard_arg_fallback)), + cache_(std::move(cache)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + static_argnames_.reserve(static_argnames.size()); + for (nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_.push_back(nb::steal(s)); + } +} + +void PjitFunction::InitExecutables() { + // Construction of the object hasn't completed yet, so we don't need to hold + // the cache lock to mutate executables_. 
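+  // (Editor's note: functions created without a `fun_` (the AOT path) get a
+  // private LRU cache from DefaultCache(); all others share the global cache
+  // keyed on (function, global_cache_key) via PjitFunctionCache::Lookup,
+  // which itself falls back to a private cache if `fun_` is not
+  // weak-referenceable.)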
+ if (!fun_.has_value()) { + executables_ = cache_->DefaultCache(); + } else { + executables_ = cache_->Lookup(cache_, fun_.value(), global_cache_key_); + } +} + +PjitFunction::~PjitFunction() { + nb::ft_object_guard lock(cache_); + executables_ = nullptr; +} + +void CallShardArgFallback( + nb::handle arg, nb::handle sharding, nb::handle layout, + const nb::callable& fallback, + std::vector>& num_args_arrays, + std::vector& keep_alive_objects) { + tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); + auto py_array_or_bufs = fallback(arg, sharding, layout); + auto py_array = nb::cast(py_array_or_bufs); + num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); +} + +// Prepares the input PjRtBuffers from the python arguments. This is equivalent +// to shard_args() in pxla.py but for only a few supported cases. +absl::StatusOr>> +PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, + absl::Span flat_dynamic_args, + bool enable_x64, const std::vector& kept_args, + const std::vector& in_shardings, + const std::vector& in_device_local_layouts, + const nb::callable& shard_arg_fallback, + std::vector& keep_alive_objects) { + const auto& addressable_devices = + executable.ifrt_loaded_executable()->addressable_devices(); + const auto& num_global_devices = + executable.ifrt_loaded_executable()->num_devices(); + int num_args = flat_dynamic_args.size(); + + std::vector> num_args_arrays; + num_args_arrays.reserve(num_args); + + struct CopyGroup { + std::vector indices; + std::vector> arrays; + }; + absl::flat_hash_map, + CopyGroup> + copy_groups; + + xla::DevicePutOptions options; + options.squash_64bit_types = !enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Device* data_device = nullptr; + if (executable.ifrt_loaded_executable()->num_devices() == 1) { + data_device = executable.ifrt_loaded_executable()->addressable_devices()[0]; + } + int dce_i = 0; + for (int i = 0; i < num_args; ++i) { + if (!kept_args[i]) { + continue; + } + int dce_index = dce_i; + ++dce_i; + + const nb::object& arg = flat_dynamic_args[i]; + const nb::object& in_device_local_layout = + in_device_local_layouts[dce_index]; + + auto transfer_guard_formatter = [] { return std::string(""); }; + + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + if (data_device != nullptr && in_device_local_layout.is_none()) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + TF_ASSIGN_OR_RETURN( + auto on_device_fn, + DevicePut(arg, executable.ifrt_loaded_executable()->client(), + data_device, options, xla::ifrt::MemoryKind())); + TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device, [&]() { + // Must release the GIL before calling IFRT because backends may + // decide to block/sleep for device buffer allocation. 
+ nb::gil_scoped_release gil_release; + return std::move(on_device_fn)(); + }()); + + num_args_arrays.push_back(std::move(on_device.ifrt_array)); + if (on_device.owning_pybuffer) { + keep_alive_objects.push_back(std::move(on_device.owning_pybuffer)); + } + continue; + } else { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + xla::PyArray py_array = nb::borrow(arg); + const auto& sharding = py_array.sharding(); + int sharding_num_devices = jax::Sharding::SafeNumDevices(sharding); + + // Currently only committed PyArray inputs or uncommitted PyArray on a + // single device inputs are allowed. This is checked previously in the entry + // point of PjitFunction::Call(). + DCHECK(py_array.committed() || + (!py_array.committed() && sharding_num_devices == 1)); + + if (!in_device_local_layout.is_none()) { + TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != arr_layout->xla_layout()) { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + if (sharding_num_devices != num_global_devices) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); + // PyArray inputs should have already been checked in + // `xla::PyArgSignatureOfValue()` called by + // `PjitFunction::ComputeCallSignature()`. + DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; + + const auto& ifrt_sharding = ifrt_array->sharding(); + if (sharding_num_devices == 1 && + ifrt_sharding.devices()->devices().front() != addressable_devices[0]) { + auto& copy_group = + copy_groups[std::make_pair(ifrt_sharding.devices()->devices().front(), + ifrt_sharding.memory_kind())]; + copy_group.indices.push_back(num_args_arrays.size()); + copy_group.arrays.push_back(tsl::FormRef(ifrt_array)); + num_args_arrays.push_back({}); + } else { + num_args_arrays.push_back(tsl::FormRef(ifrt_array)); + } + + keep_alive_objects.push_back(arg); + } + + if (!copy_groups.empty()) { + xla::ifrt::Client* const ifrt_client = + executable.ifrt_loaded_executable()->client(); + xla::ifrt::DeviceListRef ifrt_devices = + ifrt_client->MakeDeviceList({addressable_devices[0]}); + for (auto& [key, group] : copy_groups) { + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(group.arrays), ifrt_devices, + /*memory_kind=*/std::nullopt, + xla::ifrt::ArrayCopySemantics::kReuseInput)); + for (int i = 0; i < copied_ifrt_arrays.size(); ++i) { + num_args_arrays[group.indices[i]] = std::move(copied_ifrt_arrays[i]); + } + } + } + + return num_args_arrays; +} + +absl::StatusOr PjitFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + tsl::profiler::TraceMe traceme( + [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); + + // Make sure we trigger a garbage collection on JIT function calls. Otherwise + // code like + // f = jit(...) 
+ // while True: + // f(x) + // may never free temporary buffers for copies of arguments. + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + if (GetDisableJit()) { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat("Disable jit is not supported in the AOT path since " + "the function is not available for (%s)", + function_name_) + .c_str()); + } + return nb::steal( + PyObject_Vectorcall(fun_.value().ptr(), args, nargs, kwnames)); + } + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + // Perform a few checks for the arguments. Currently we are only allowing + // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it + // will fallback to python. For jit, numpy arrays and scalars are also + // allowed, which we will check later. + for (const auto& arg : flat_dynamic_args) { + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + continue; + } + + xla::PyArray py_array = nb::borrow(arg); + + // Only allow committed PyArray in cpp pjit for now as the logic on handling + // sharding for uncommitted PyArray is complicated and still under + // development. + // + // TODO(chky): Consider support uncommitted PyArray in cpp when the python + // side stablizes. + if (!py_array.committed() && + jax::Sharding::SafeNumDevices(py_array.sharding()) > 1) { + VLOG(2) << "PyArray argument is not committed and number of global " + "devices is more than 1; fallback to python."; + return fallback_to_cache_miss(); + } + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + VLOG(2) << "ComputeCallSignature failed: " << status; + return fallback_to_cache_miss(); + } + + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); + bool inserted = false; + std::shared_ptr cache_entry; + { + nb::ft_object_guard lock(cache_); + cache_entry = executables_->GetOrCreateIfAbsent( + call_signature, [this, &inserted](const CallSignature& unused) { + inserted = true; + return std::make_shared(pytree_registry_.get()); + }); + } + + if (!cache_entry->compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. 
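+    // (Editor's note: `inserted` is true only in the thread whose
+    // GetOrCreateIfAbsent call created this entry; other threads fall through
+    // to the else branch and block on compilation_complete, releasing the
+    // GIL while they wait.)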
+ if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + bool remove_cache = false; + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(*cache_entry, out_tuple); + + if (out_tuple.size() > 2 && out_tuple[2].is_valid()) { + remove_cache = nb::cast(out_tuple[2]); + } + } catch (const std::exception& e) { + VLOG(2) << "cache miss fail: " << e.what(); + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + throw; + } + cache_entry->compilation_complete.Notify(); + + if (remove_cache) { + nb::ft_object_guard lock(cache_); + executables_->Remove(call_signature); + } + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + if (cache_entry->thread_id == std::this_thread::get_id()) { + auto error_string = absl::StrCat("Recursively calling jit: ", + call_signature.DebugString()); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry->compilation_complete.WaitForNotification(); + } + } + + if (cache_entry->fall_back_to_python) { + VLOG(2) << "cpp pjit fallback to python."; + return fallback_to_cache_miss(); + } + + // A vector of [num_inputs]. + auto num_args_arrays = PrepareIfrtInputs( + *cache_entry->executable, flat_dynamic_args, + call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, + cache_entry->in_shardings, cache_entry->in_device_local_layouts, + shard_arg_fallback_, keep_alive_objects); + + if (!num_args_arrays.ok()) { + VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); + return fallback_to_cache_miss(); + } + + xla::ifrt::ExecuteOptions execute_options = + cache_entry->executable->options(); + execute_options.launch_id = cache_entry->executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. + std::vector> output_arrays; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(auto result, + cache_entry->executable->ifrt_executable()->Execute( + absl::MakeSpan(*num_args_arrays), execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + auto traceback = xla::Traceback::Get(); + + // Convert the ifrt::Array objects to PyArray. + int num_outputs = output_arrays.size(); + absl::InlinedVector outputs; + outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + // Creating the PyArray result. In addition to the IFRT arrays, the metadata + // like `aval` and `sharding` are retrieved from the cache for this + // function, which are produced by the python path in `cache_miss`. 
+ xla::PyArray py_array( + cache_entry->out_avals[i], cache_entry->out_weak_types[i], + cache_entry->out_dtypes[i], cache_entry->out_shapes[i], + cache_entry->out_shardings[i], cache_entry->executable->client(), + traceback, std::move(output_arrays[i]), + /*committed=*/cache_entry->out_committed.at(i), /*skip_checks=*/true); + + outputs.push_back(std::move(py_array)); + } + + nb::object out = nb::steal( + cache_entry->out_pytree_def.Unflatten(outputs).release().ptr()); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); + } + + return out; +} + +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + bool jax_enable_x64 = GetEnableX64(); + + signature.default_device = GetDefaultDevice(); + signature.jax_enable_x64 = jax_enable_x64; + + auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_layouts = signature.dynamic_arg_layouts; + dynamic_arg_layouts.reserve(flat_dynamic_args.size()); + + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, + xla::PyArgSignatureOfValue(arg, jax_enable_x64)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); + + // It should be already checked previously in the entry point of + // PjitFunction::Call(). 
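+    //
+    // Per argument, the signature ends up recording roughly (an illustrative
+    // summary of the branch below, not a definitive schema):
+    //
+    //   (dtype, shape, weak_type)  from xla::PyArgSignatureOfValue, plus
+    //   sharding    (nb::none() for non-PyArray values),
+    //   layout      (nullptr when the backend reports Unimplemented),
+    //   committed   (false for non-PyArray values).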
+ if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + auto layout = py_array.layout(); + if (absl::IsUnimplemented(layout.status())) { + signature.dynamic_arg_layouts.push_back(nullptr); + } else { + signature.dynamic_arg_layouts.push_back(*std::move(layout)); + } + signature.committed_args.push_back(py_array.committed()); + } else { + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.dynamic_arg_layouts.push_back(nullptr); + signature.committed_args.push_back(false); + } + } + + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + + return absl::OkStatus(); +} + +void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + DCHECK_GE(out_and_fastpath_data.size(), 2); + + if (out_and_fastpath_data[1].is_none()) { + VLOG(2) << "fastpath_data is none"; + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple fastpath_data = nb::cast(out_and_fastpath_data[1]); + + cache_entry.executable = nb::cast>( + fastpath_data.attr("xla_executable")); + + nb::sequence in_shardings = fastpath_data.attr("in_shardings"); + cache_entry.in_shardings.reserve(nb::len(in_shardings)); + for (nb::handle sharding : in_shardings) { + cache_entry.in_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_shardings = fastpath_data.attr("out_shardings"); + cache_entry.out_shardings.reserve(nb::len(out_shardings)); + for (nb::handle sharding : out_shardings) { + cache_entry.out_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_committed = fastpath_data.attr("out_committed"); + cache_entry.out_committed.reserve(nb::len(out_committed)); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } + + nb::sequence out_avals = fastpath_data.attr("out_avals"); + cache_entry.out_avals.reserve(nb::len(out_avals)); + cache_entry.out_dtypes.reserve(nb::len(out_avals)); + cache_entry.out_shapes.reserve(nb::len(out_avals)); + cache_entry.out_weak_types.reserve(nb::len(out_avals)); + for (nb::handle aval : out_avals) { + cache_entry.out_avals.push_back(nb::borrow(aval)); + cache_entry.out_dtypes.push_back(aval.attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(aval.attr("shape"))); + cache_entry.out_weak_types.push_back( + nb::cast(aval.attr("weak_type"))); + } + + cache_entry.out_pytree_def = nb::cast( + nb::handle(fastpath_data.attr("out_pytree_def").ptr())); + + nb::sequence kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); + cache_entry.kept_var_bitvec.reserve(nb::len(kept_var_bitvec)); + for (nb::handle k : kept_var_bitvec) { + cache_entry.kept_var_bitvec.push_back(nb::cast(k)); + } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } +} + +// Helper function used by the tp_clear GC method. 
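+// It uses the swap-to-local idiom recommended by the Python Py_CLEAR()
+// documentation: each member is swapped into a local, so the member is
+// already null if destroying the reference re-enters this object (e.g.
+// through GC or a weakref callback). An illustrative reduction, with
+// `member_` standing in for any of the fields below:
+//
+//   nb::callable tmp;
+//   std::swap(member_, tmp);  // member_ is now null
+//   // tmp is destroyed at end of scope; re-entrant code sees a cleared
+//   // member_ instead of a half-destroyed one.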
+void PjitFunction::ClearPythonReferences() { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to clear + nb::callable cache_miss; + std::optional fun; + nb::callable shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(cache_miss_, cache_miss); + std::swap(fun_, fun); + std::swap(shard_arg_fallback_, shard_arg_fallback); +} + +struct PjitFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject* next; + PjitFunctionObject* prev; +}; + +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto& [cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. 
+ nb::ft_mutex mu_; + PjitFunctionObject* compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + +PyObject* PjitFunction_Type = nullptr; + +bool PjitFunction::IsPjitFunction(nb::handle handle) { + return handle.type().ptr() == PjitFunction_Type; +} + +PjitFunction* PjitFunction::AsPjitFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +PjitFunction* AsPjitFunction(nb::handle handle) { + if (!PjitFunction::IsPjitFunction(handle)) { + throw xla::XlaRuntimeError(xla::InvalidArgument("Expected a PjitFunction")); + } + return PjitFunction::AsPjitFunctionUnchecked(handle); +} + +extern "C" { + +PyObject* PjitFunction_tp_vectorcall(PyObject* callable, PyObject* const* args, + size_t nargs, PyObject* kwnames) { + PjitFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("PjitFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::runtime_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* PjitFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + PjitFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = PjitFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void PjitFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + PjitFunctionObject* o = reinterpret_cast(self); + pjit_function_store.Remove(o); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PjitFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int PjitFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to visit + PjitFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.cache_miss().ptr()); + Py_VISIT(o->fun.shard_arg_fallback().ptr()); + if (o->fun.fun()) { + Py_VISIT(o->fun.fun()->ptr()); + } + return 0; +} + +int PjitFunction_tp_clear(PyObject* self) { + PjitFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 
+  o->fun.ClearPythonReferences();
+  return 0;
+}
+
+// Implements the Python descriptor protocol so JIT-compiled functions can be
+// used as bound methods. See:
+// https://docs.python.org/3/howto/descriptor.html#functions-and-methods
+PyObject* PjitFunction_tp_descr_get(PyObject* self, PyObject* obj,
+                                    PyObject* type) {
+  if (obj == nullptr || obj == Py_None) {
+    Py_INCREF(self);
+    return self;
+  }
+  return PyMethod_New(self, obj);
+}
+
+static PyGetSetDef PjitFunction_tp_getset[] = {
+    // Having a __dict__ seems necessary to allow functools.wraps to override
+    // __doc__.
+    {const_cast<char*>("__dict__"), PyObject_GenericGetDict,
+     PyObject_GenericSetDict, nullptr, nullptr},
+    {nullptr, nullptr, nullptr, nullptr, nullptr}};
+
+PyObject* PjitFunction_tp_repr(PyObject* self) {
+  try {
+    const std::string& repr = absl::StrFormat(
+        "<PjitFunction of %s>",
+        nb::cast<std::string_view>(
+            nb::repr(nb::getattr(self, "__wrapped__"))));
+    return PyUnicode_FromString(repr.c_str());
+  } catch (...) {
+    // Ignore all errors when accessing a repr.
+    return PyUnicode_FromString("<PjitFunction>");
+  }
+}
+
+}  // extern "C"
+
+void InitializePjitFunction(
+    PjitFunctionObject* fn_obj, std::string function_name,
+    std::optional<nb::callable> fun, nb::callable cache_miss,
+    std::vector<int> static_argnums, std::vector<nb::str> static_argnames,
+    nb::object global_cache_key,
+    xla::nb_class_ptr<xla::PyTreeRegistry> pytree_registry,
+    nb::callable shard_arg_fallback,
+    xla::nb_class_ptr<PjitFunctionCache> cache) {
+  fn_obj->next = fn_obj->prev = nullptr;
+  if (nb::isinstance<nb::list>(global_cache_key)) {
+    global_cache_key = nb::tuple(global_cache_key);
+  }
+  new (&fn_obj->fun) PjitFunction(
+      std::move(function_name), std::move(fun), std::move(cache_miss),
+      std::move(static_argnums), std::move(static_argnames),
+      std::move(global_cache_key), std::move(pytree_registry),
+      std::move(shard_arg_fallback), std::move(cache));
+  // Handled separately from the constructor because this call is not
+  // exception safe: if it throws, it would leave the object improperly
+  // constructed.
+  fn_obj->fun.InitExecutables();
+
+  // Only add the executable to the store after executables_ has been
+  // initialized. We want only fully constructed executables in the store.
+  pjit_function_store.Insert(fn_obj);
+}
+
+nb::object MakePjitFunction(
+    std::string function_name, std::optional<nb::callable> fun,
+    nb::callable cache_miss, std::vector<int> static_argnums,
+    std::vector<nb::str> static_argnames, nb::object global_cache_key,
+    xla::nb_class_ptr<xla::PyTreeRegistry> pytree_registry,
+    nb::callable shard_arg_fallback,
+    std::optional<xla::nb_class_ptr<PjitFunctionCache>> cache) {
+  nb::object obj = nb::steal(PjitFunction_tp_new(
+      reinterpret_cast<PyTypeObject*>(PjitFunction_Type), nullptr, nullptr));
+  PjitFunctionObject* fn_obj =
+      reinterpret_cast<PjitFunctionObject*>(obj.ptr());
+  if (!cache) {
+    cache = xla::make_nb_class<PjitFunctionCache>(
+        PjitFunctionCache::kDefaultCapacity);
+  }
+  InitializePjitFunction(
+      fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss),
+      std::move(static_argnums), std::move(static_argnames),
+      std::move(global_cache_key), std::move(pytree_registry),
+      std::move(shard_arg_fallback), std::move(*cache));
+  return obj;
+}
+
+// Version numbers for the pickled representations of
+// PjitFunction. Increment these if changing them.
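+//
+// __setstate__ (registered further below) refuses to restore state pickled
+// under a different version; schematically (illustrative only):
+//
+//   int version = nb::cast<int>(pickle["version"]);
+//   if (version != kPjitFunctionPickleVersion) {
+//     throw std::invalid_argument(...);  // no cross-version unpickling
+//   }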
+const int kPjitFunctionPickleVersion = 1; + +PyMemberDef PjitFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PjitFunction_slots[] = { + {Py_tp_new, reinterpret_cast(PjitFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PjitFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PjitFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PjitFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PjitFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(PjitFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_repr, reinterpret_cast(PjitFunction_tp_repr)}, + {Py_tp_members, reinterpret_cast(PjitFunction_members)}, + {0, nullptr}, +}; + +} // namespace + +void BuildPjitSubmodule(nb::module_& m) { + nb::class_ cache(m, "PjitFunctionCache"); + cache.def(nb::init(), + nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); + cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); + cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); + cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); + cache.def( + "__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache& cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }, + nb::lock_self()); + cache.def("__setstate__", + [](PjitFunctionCache* cache, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d", + version, kPjitFunctionPickleVersion)); + } + int capacity = nb::cast(pickle["capacity"]); + new (cache) PjitFunctionCache(capacity); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); + PyType_Spec PjitFunction_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX < 0x030C0000 + /*.slots=*/PjitFunction_slots, + }; + PjitFunction_Type = PyType_FromSpec(&PjitFunction_spec); + if (!PjitFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(PjitFunction_Type); + + // Add PjitFunction to the xla_extension module so it can be pickled. 
+ m.attr("PjitFunction") = cfun; + cfun.attr("__getstate__") = nb::cpp_function( + [](const PjitFunction::object& self) { + PjitFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["function_name"] = fn->function_name(); + if (fn->fun().has_value()) { + pickle["fun"] = *fn->fun(); + } + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["static_argnames"] = nb::cast(fn->static_argnames()); + pickle["global_cache_key"] = fn->global_cache_key(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); + pickle["cache"] = fn->cache(); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](nb::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPjitFunctionPickleVersion)); + } + std::string function_name = + nb::cast(pickle["function_name"]); + std::optional fun; + if (pickle.contains("fun")) { + fun = nb::cast(pickle["fun"]); + } + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + std::vector static_argnames = + nb::cast>(pickle["static_argnames"]); + nb::object global_cache_key = pickle["global_cache_key"]; + xla::nb_class_ptr pytree_registry = + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); + nb::callable shard_arg_fallback = + nb::cast(pickle["shard_arg_fallback"]); + xla::nb_class_ptr cache = + nb::cast>(pickle["cache"]); + InitializePjitFunction( + reinterpret_cast(self.ptr()), + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::is_method()); + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->PythonSignature(); + }); + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->cache_miss(); + }); + // All private members are only for testing/debugging purposes + cfun.attr("_cache_size") = nb::cpp_function( + [](nb::handle self) -> int { + return AsPjitFunction(self)->cache_capacity(); + }, + nb::is_method()); + cfun.attr("_clear_cache") = nb::cpp_function( + [](nb::handle self) { AsPjitFunction(self)->ClearCache(); }, + nb::is_method()); + + m.def( + "pjit", + [](std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb::object pytree_registry, nb::callable shard_arg_fallback, + std::optional> cache) { + xla::nb_class_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); + return MakePjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + 
nb::arg("global_cache_key"), nb::arg("pytree_registry"), + nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); +} + +} // namespace jax diff --git a/jaxlib/xla/pjit.h b/jaxlib/xla/pjit.h new file mode 100644 index 000000000000..545fb2307783 --- /dev/null +++ b/jaxlib/xla/pjit.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildPjitSubmodule(nanobind::module_& m); +} + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc new file mode 100644 index 000000000000..5582eccf4f8b --- /dev/null +++ b/jaxlib/xla/pmap_lib.cc @@ -0,0 +1,1180 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "jaxlib/xla/pmap_lib.h"
+
+#include <Python.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/hash/hash.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/synchronization/notification.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/shared_ptr.h"  // IWYU pragma: keep
+#include "nanobind/stl/string.h"  // IWYU pragma: keep
+#include "nanobind/stl/variant.h"  // IWYU pragma: keep
+#include "nanobind/stl/vector.h"  // IWYU pragma: keep
+#include "jaxlib/xla/config.h"
+#include "jaxlib/xla/jax_jit.h"
+#include "xla/pjrt/exceptions.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/device.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/ifrt/executable.h"
+#include "xla/python/ifrt/memory.h"
+#include "xla/python/ifrt/shape.h"
+#include "xla/python/ifrt/sharding.h"
+#include "xla/python/nb_class_ptr.h"
+#include "xla/python/nb_helpers.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/python/py_array.h"
+#include "xla/python/py_client.h"
+#include "xla/python/py_device.h"
+#include "xla/python/py_executable.h"
+#include "xla/python/py_values.h"
+#include "xla/python/python_ref_manager.h"
+#include "xla/python/pytree.h"
+#include "xla/python/sharded_device_array.h"
+#include "xla/python/sharding.h"
+#include "xla/python/to_ifrt_sharding.h"
+#include "xla/python/traceback.h"
+#include "xla/python/types.h"
+#include "xla/status_macros.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/platform/env.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/python/lib/core/numpy.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/profiler/lib/traceme.h"
+
+namespace jax {
+
+namespace nb = nanobind;
+
+namespace {
+
+// Specifies how to shard the inputs. Even though everything could be computed
+// from `sharding_specs` and the argument shape, we cache derived computations
+// for performance.
+struct InputSpec {
+  InputSpec(nb::object indices, nb::object array_sharding)
+      : indices(std::move(indices)),
+        array_sharding(std::move(array_sharding)) {}
+  nb::object indices;
+  nb::object array_sharding;
+};
+
+// An object containing the arguments needed to create an Array from the
+// output buffers.
+struct ResultSpec {
+ public:
+  explicit ResultSpec(nb::object aval)
+      : out_aval(std::move(aval)),
+        weak_type(nb::cast<bool>(out_aval.attr("weak_type"))) {}
+  nb::object out_aval;
+  bool weak_type;
+};
+
+// The result of `ShardArg`.
+struct ShardArgResult {
+  // Points to the on-device array.
+  // ifrt_array->sharding().num_shards() == `num_devices`.
+  tsl::RCReference<xla::ifrt::Array> ifrt_array;
+  // The Python argument will always be copied to `owning_sda`.
+  nb::object owning_sda;
+};
+
+// Shards a single argument over devices.
+//
+// Only fully-C++ `Array` inputs are supported on the fast path. For all other
+// inputs, we call a Python fallback that returns a C++ `Array`, which is then
+// cast back to the C++ objects.
+//
+// This function is not usable for JAX extensions that do not comply with the
+// PjRt interfaces.
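+//
+// Three paths, in order of preference (an illustrative summary of the body;
+// see the code for the authoritative behavior):
+//   1. `arg` is a PyArray whose PmapSharding matches the cached input spec:
+//      reuse its ifrt::Array, copying it across devices only if needed.
+//   2. `arg` is a NumPy ndarray with a supported dtype: DevicePut one shard
+//      per device, then assemble a single ifrt::Array from the pieces.
+//   3. Anything else: invoke the Python `python_fallback` and borrow the
+//      ifrt::Array of the PyArray it returns.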
+//
+// Arguments:
+// `arg`: The object to shard across `devices`. If it is an `Array`, a
+//   fast-path will be executed if it is already correctly sharded.
+//
+// Returns a failure absl::Status when an unrecoverable error occurred, so we
+// don't need to fall back to Python.
+//
+// `devices` and `sharding_spec` have the same length.
+absl::StatusOr<ShardArgResult> ShardArg(
+    nb::handle arg, absl::Span<xla::ifrt::Device* const> devices,
+    const InputSpec& input_spec, nb::handle py_devices,
+    const nb::callable& python_fallback) {
+  if (arg.type().ptr() == xla::PyArray::type().ptr()) {
+    auto py_array = nb::borrow<xla::PyArray>(arg);
+    if (py_array.sharding().type().ptr() ==
+        input_spec.array_sharding.type().ptr()) {
+      auto* pmap_sharding = nb::cast<jax::PmapSharding*>(py_array.sharding());
+      auto* cached_pmap_sharding =
+          nb::cast<jax::PmapSharding*>(input_spec.array_sharding);
+
+      if (pmap_sharding->sharding_spec() ==
+          cached_pmap_sharding->sharding_spec()) {
+        ShardArgResult result;
+        result.owning_sda = nb::borrow(arg);
+        result.ifrt_array = tsl::FormRef(py_array.ifrt_array());
+        if (result.ifrt_array == nullptr) {
+          return xla::InvalidArgument("Array has been deleted.");
+        }
+        if (result.ifrt_array->sharding().devices()->devices() != devices) {
+          absl::InlinedVector<xla::ifrt::Device*, 1> ifrt_devices;
+          ifrt_devices.reserve(devices.size());
+          ifrt_devices.insert(ifrt_devices.end(), devices.begin(),
+                              devices.end());
+          // pmap does not support memory_kind for now.
+          auto* ifrt_client = result.ifrt_array->client();
+          TF_ASSIGN_OR_RETURN(auto copied_ifrt_arrays,
+                              ifrt_client->CopyArrays(
+                                  absl::MakeSpan(&result.ifrt_array, 1),
+                                  ifrt_client->MakeDeviceList(ifrt_devices),
+                                  xla::ifrt::MemoryKind(),
+                                  xla::ifrt::ArrayCopySemantics::kReuseInput));
+          result.ifrt_array = std::move(copied_ifrt_arrays.front());
+        }
+        return result;
+      }
+    }
+  }
+
+  auto ndarray = xla::nb_numpy_ndarray::ensure(arg);
+  if (ndarray && PyArray_CheckExact(arg.ptr()) &&
+      xla::DtypeToPrimitiveType(ndarray.dtype()).status().ok()) {
+    tsl::profiler::TraceMe traceme("ndarray pmap ShardArg");
+    nb::list indices = nb::list(input_spec.indices);
+    nb::list py_devices_list = nb::cast<nb::list>(py_devices);
+    auto n_devices = py_devices_list.size();
+    if (indices.size() != n_devices) {
+      return xla::InvalidArgument("indices vs devices mismatch: %d vs %d",
+                                  indices.size(), n_devices);
+    }
+
+    std::vector<tsl::RCReference<xla::ifrt::Array>> per_device_arrays;
+    per_device_arrays.reserve(n_devices);
+    absl::InlinedVector<xla::ifrt::Device*, 1> devices;
+    devices.reserve(n_devices);
+    // TODO(hyeontaek): The created array will never be disassembled. We should
+    // omit collecting shapes and make the OpaqueSharding non-disassemblable?
+ std::vector shapes; + shapes.reserve(n_devices); + + nb::list owning_pylist; + ShardArgResult result; + result.owning_sda = owning_pylist; + const bool jax_enable_x64 = GetEnableX64(); + + std::vector device_put_fns; + device_put_fns.reserve(n_devices); + xla::DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = true; + for (size_t i = 0; i < n_devices; ++i) { + auto to_device = nb::cast(py_devices_list[i]); + if (to_device->client().get() == nullptr) { + return xla::InvalidArgument("Cannot copy to unattached devices."); + } + + TF_ASSIGN_OR_RETURN( + device_put_fns.emplace_back(), + DevicePut(arg[indices[i]], to_device->client()->ifrt_client(), + to_device->device(), options, xla::ifrt::MemoryKind())); + } + std::vector device_puts; + device_puts.reserve(n_devices); + { + nb::gil_scoped_release gil_release; + for (auto& device_put_fn : device_put_fns) { + TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)()); + device_puts.push_back(std::move(device_put)); + } + } + for (auto& device_put : device_puts) { + per_device_arrays.push_back(std::move(device_put.ifrt_array)); + devices.push_back( + per_device_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(per_device_arrays.back()->shape()); + if (device_put.owning_pybuffer) { + owning_pylist.append(device_put.owning_pybuffer); + } + } + + if (per_device_arrays.empty()) { + return xla::InvalidArgument("Per-device arrays must not be empty."); + } + // TODO(hyeontaek): The logical shape here is inaccurate. We + // may want to avoid creating a new Array or specialize Array + // to disallow access to the logical shape. + xla::ifrt::Shape shape = per_device_arrays.front()->shape(); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + xla::GetIfrtConcreteSharding(input_spec.array_sharding, shape, shapes)); + TF_ASSIGN_OR_RETURN( + result.ifrt_array, + per_device_arrays.front() + ->client() + ->AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(ifrt_sharding), + absl::MakeSpan(per_device_arrays), + xla::ifrt::ArrayCopySemantics::kReuseInput, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + return result; + } + tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); + auto py_array_or_bufs = python_fallback(arg, input_spec.array_sharding); + + auto py_array = nb::cast(py_array_or_bufs); + ShardArgResult result; + result.owning_sda = nb::borrow(py_array_or_bufs); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + return result; +} + +struct PmapCacheEntry { + explicit PmapCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + // The value `backend.local_devices()`. + nb::object py_devices; // To pass back to Python. + std::vector devices; + std::vector input_specs; + xla::PyTreeDef out_pytree_def; + // Objects necessary to build the out Array objects. + std::vector out_result_specs; + + std::vector out_array_shardings; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_committed; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. 
+ absl::Notification compilation_complete; + + bool fall_back_to_python = false; +}; + +} // namespace + +// A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyLoadedExecutable`. This class is thread-safe. +class PmapFunction { + public: + PmapFunction(nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) + : fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + pytree_registry_(std::move(pytree_registry)), + python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + + function_name_ = + nb::cast(nb::str(nb::getattr(fun_, "__name__", fun_))); + } + PmapFunction(const PmapFunction&) = delete; + PmapFunction& operator=(const PmapFunction& other) = delete; + PmapFunction(PmapFunction&&) = default; + PmapFunction& operator=(PmapFunction&&) = default; + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `Array` objects from the outputs + // (e) reconstruct the `PyTree`. + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + nb::object PythonSignature() { + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(fun_); + } + + int cache_size() { + nb::ft_lock_guard lock(mu_); + return executables_.size(); + } + void cache_clear() { + nb::ft_lock_guard lock(mu_); + return executables_.clear(); + } + const nb::callable& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const std::string& function_name() const { return function_name_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& python_shard_arg_fallback() const { + return python_shard_arg_fallback_; + } + const std::vector& static_argnums() const { return static_argnums_; } + + // nb::object typed subclass for PmapFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PmapFunction", + PmapFunction::IsPmapFunction); + pyobject() = default; + PmapFunction* func() const { + return PmapFunction::AsPmapFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PmapFunction. + static bool IsPmapFunction(nb::handle handle); + // Converts `handle` to a PmapFunction*. Does not do any checking. + static PmapFunction* AsPmapFunctionUnchecked(nb::handle handle); + + // Helper function used by the tp_clear GC method. + void ClearPythonReferences() { + nb::callable fun, cache_miss, python_shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(fun_, fun); + std::swap(cache_miss_, cache_miss); + std::swap(python_shard_arg_fallback_, python_shard_arg_fallback); + } + + // Updates the signature of arguments for a pmapped function. + // + // It deals with the arguments signatures and also of the global and + // thread-local jit context. 
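+  //
+  // The resulting cache key is, schematically (an illustrative summary of
+  // the body below, not a definitive schema):
+  //
+  //   signature = (function_name_,
+  //                [PyArgSignatureOfValue(arg, jax_enable_x64)
+  //                 for arg in flat_dynamic_args],
+  //                jax_enable_x64,
+  //                thread-local and global extra_jit_context,
+  //                JitConfigs())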
+ absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + const bool jax_enable_x64 = GetEnableX64(); + signature.jax_enable_x64 = jax_enable_x64; + for (nb::handle arg : flat_dynamic_args) { + auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64); + if (!signature_or_error.ok()) { + VLOG(2) << "PyArgSignatureOfValue failed: " + << signature_or_error.status(); + return signature_or_error.status(); + } + signature.dynamic_arg_signatures.push_back( + std::move(signature_or_error).value()); + } + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + return absl::Status(); + } + + // Returns, for debugging purposes (e.g. finding why some call misses the + // cache and recompiles), the list of the string representations of the keys. + // + // The format can change at any time. + std::string DebugCacheKeys() { + nb::ft_lock_guard lock(mu_); + std::vector key_strings = { + absl::StrCat("The cache contains ", executables_.size(), " elements:")}; + // We will be able to use auto& [key, _] when TF uses C++ 17. + for (auto& pair : executables_) { + key_strings.push_back(pair.first.DebugString()); + } + return absl::StrJoin(key_strings, "\n\n"); + } + + private: + // Mutates `cache_entry` in place. + void PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + bool always_fallback_to_python_ = false; + + nb::callable fun_; // The Python function to pmap. + std::string function_name_; + // See JAX _cpp_pmap in api.py for documentation. + nb::callable cache_miss_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyLoadedExecutable. In sorted order. + std::vector static_argnums_; + xla::nb_class_ptr pytree_registry_; + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map> + executables_; + + // The fallback function to use with `ShardArgs`. + // TODO(jblespiau): Add support for more types from C++. + nb::callable python_shard_arg_fallback_; + + // Protect methods in FT: + nb::ft_mutex mu_; +}; + +void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + CHECK_EQ(out_and_fastpath_data.size(), 2); + if (out_and_fastpath_data[1].is_none()) { + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple pmap_data = nb::cast(out_and_fastpath_data[1]); + if (nb::cast(pmap_data.attr("version")) != 1) { + throw xla::XlaRuntimeError(absl::StrCat( + "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 " + "expected, but got ", + nb::cast(pmap_data.attr("version")), + "Upgrade jaxlib and jax. Provided data was:", + nb::cast(nb::str(nb::repr(pmap_data))))); + } + // See api.nb::_PmapFastpathData in the JAX code base for the expected + // namedtuple. 
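+  //
+  // The fields consumed below are, in outline (names taken from the
+  // attribute accesses in this function; illustrative, not a complete
+  // definition):
+  //
+  //   _PmapFastpathData(
+  //       version=1,
+  //       xla_executable=...,          # shared PyLoadedExecutable
+  //       input_indices=..., input_devices=..., input_array_shardings=...,
+  //       out_pytree_def=..., out_avals=...,
+  //       out_array_shardings=..., out_committed=...)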
+ std::shared_ptr executable; + try { + executable = nb::cast>( + pmap_data.attr("xla_executable")); + } catch (const nb::cast_error& e) { + // Backends that don't implement the C++ PjRt APIs + cache_entry.fall_back_to_python = true; + always_fallback_to_python_ = true; + return; + } + cache_entry.executable = std::move(executable); + const std::vector>& devices = + cache_entry.executable->AddressableDevices(); + cache_entry.devices.reserve(devices.size()); + for (auto& device : devices) { + cache_entry.devices.push_back(device->device()); + } + + // Inputs shard args details. + nb::list input_indices = pmap_data.attr("input_indices"); + + cache_entry.py_devices = pmap_data.attr("input_devices"); + auto input_devices = nb::cast>>( + pmap_data.attr("input_devices")); + + nb::list input_array_shardings = pmap_data.attr("input_array_shardings"); + + cache_entry.input_specs.reserve(input_array_shardings.size()); + + for (int i = 0; i < input_array_shardings.size(); ++i) { + cache_entry.input_specs.emplace_back(input_indices[i], + input_array_shardings[i]); + } + + // Outputs specs. + auto out_tree = nb::cast(pmap_data.attr("out_pytree_def")); + cache_entry.out_pytree_def = std::move(out_tree); + nb::list out_avals = pmap_data.attr("out_avals"); + + cache_entry.out_result_specs.reserve(out_avals.size()); + cache_entry.out_dtypes.reserve(out_avals.size()); + cache_entry.out_shapes.reserve(out_avals.size()); + + for (int i = 0; i < out_avals.size(); ++i) { + cache_entry.out_dtypes.push_back(out_avals[i].attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(out_avals[i].attr("shape"))); + cache_entry.out_result_specs.emplace_back(out_avals[i]); + } + + nb::list out_array_shardings = pmap_data.attr("out_array_shardings"); + + DCHECK(out_array_shardings.size() == 0 || + out_avals.size() == out_array_shardings.size()); + + cache_entry.out_array_shardings.reserve(out_array_shardings.size()); + for (nb::handle out_array_sharding : out_array_shardings) { + cache_entry.out_array_shardings.push_back( + nb::borrow(out_array_sharding)); + } + + nb::list out_committed = pmap_data.attr("out_committed"); + + DCHECK(out_committed.size() == 0 || out_avals.size() == out_committed.size()); + + cache_entry.out_committed.reserve(out_committed.size()); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } +} + +absl::StatusOr PmapFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + if (always_fallback_to_python_) { + return fallback_to_cache_miss(); + } + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? 
PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; + absl::Status status = + ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + return fallback_to_cache_miss(); + } + + // Retrieve/Maybe add the executable to the cache. + bool inserted = false; + std::shared_ptr cache_entry_ptr; + { + nb::ft_lock_guard lock(mu_); + std::shared_ptr& entry_ref = executables_[call_signature]; + if (!entry_ref) { + inserted = true; + entry_ref = std::make_shared(pytree_registry_.get()); + } + cache_entry_ptr = entry_ref; + } + PmapCacheEntry& cache_entry = *cache_entry_ptr; + + if (!cache_entry.compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(cache_entry, out_tuple); + } catch (const std::exception& e) { + cache_entry.fall_back_to_python = true; + cache_entry.compilation_complete.Notify(); + throw; + } + cache_entry.compilation_complete.Notify(); + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry.compilation_complete.WaitForNotification(); + } + } + if (cache_entry.fall_back_to_python) { + return fallback_to_cache_miss(); + } + + // 1. Parse arguments. + std::vector& input_devices = cache_entry.devices; + std::vector& input_specs = cache_entry.input_specs; + const int num_args = flat_dynamic_args.size(); + + // We need [num_args] for the `Execute` call below. + std::vector> num_args_arrays(num_args); + for (int i = 0; i < num_args; ++i) { + TF_ASSIGN_OR_RETURN( + ShardArgResult sharded_arg, + ShardArg(flat_dynamic_args[i], input_devices, input_specs[i], + cache_entry.py_devices, python_shard_arg_fallback_)); + + num_args_arrays[i] = std::move(sharded_arg.ifrt_array); + if (sharded_arg.owning_sda) { + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + } + } + + xla::ifrt::ExecuteOptions execute_options = cache_entry.executable->options(); + execute_options.launch_id = cache_entry.executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. 
+ std::vector> output_arrays; + { + nb::gil_scoped_release gil_release; + auto ifrt_executable = cache_entry.executable->ifrt_executable(); + TF_ASSIGN_OR_RETURN( + auto result, ifrt_executable->Execute(absl::MakeSpan(num_args_arrays), + execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + // TODO(jblespiau): We don't need to create the PyBuffer objects. + // Having a C++ `Array`, keeping internally the PjRtBuffer + // objects is sufficient, and we can lazily create the `PyBuffer` only if + // we access them from Python. + auto traceback = xla::Traceback::Get(); + // TODO(jblespiau): Change the `client` function to return a reference. + xla::nb_class_ptr client = cache_entry.executable->client(); + + // Convert the PjRtBuffer objects to PyBuffer, and invert the order from + // [num_devices, num_args] to [num_args, num_devices]. + const int num_outputs = output_arrays.size(); + std::vector flat_sharded_device_arrays; + flat_sharded_device_arrays.reserve(num_outputs); + + const auto& output_specs = cache_entry.out_result_specs; + + TF_RET_CHECK(cache_entry.out_array_shardings.size() == num_outputs); + for (int i = 0; i < num_outputs; ++i) { + const ResultSpec& result_spec = output_specs[i]; + xla::PyArray py_array( + result_spec.out_aval, result_spec.weak_type, cache_entry.out_dtypes[i], + cache_entry.out_shapes[i], cache_entry.out_array_shardings[i], client, + traceback, std::move(output_arrays[i]), cache_entry.out_committed[i], + /*skip_checks=*/true); + + flat_sharded_device_arrays.push_back(std::move(py_array)); + } + + nb::object out = + cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + + (*post_hook)(callable, args_tuple, kwargs, out); + } + + return out; +} + +struct JaxPmapFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. 
+#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PmapFunction fun; +}; + +PyObject* JaxPmapFunction_Type = nullptr; + +bool PmapFunction::IsPmapFunction(nb::handle handle) { + return handle.type().ptr() == JaxPmapFunction_Type; +} + +PmapFunction* PmapFunction::AsPmapFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +absl::StatusOr AsPmapFunction(nb::handle handle) { + if (!PmapFunction::IsPmapFunction(handle)) { + return xla::InvalidArgument("Expected a PmapFunction"); + } + return PmapFunction::AsPmapFunctionUnchecked(handle); +} + +namespace { + +extern "C" { + +PyObject* JaxPmapFunction_tp_vectorcall(PyObject* callable, + PyObject* const* args, size_t nargs, + PyObject* kwnames) { + JaxPmapFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* JaxPmapFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + JaxPmapFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = JaxPmapFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void JaxPmapFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + JaxPmapFunctionObject* o = reinterpret_cast(self); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PmapFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int JaxPmapFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + JaxPmapFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.fun().ptr()); + Py_VISIT(o->fun.cache_miss().ptr()); + return 0; +} + +int JaxPmapFunction_tp_clear(PyObject* self) { + JaxPmapFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so PMAP-compiled functions can be +// used as bound methods. 
See:
+// https://docs.python.org/3/howto/descriptor.html#functions-and-methods
+PyObject* JaxPmapFunction_tp_descr_get(PyObject* self, PyObject* obj,
+                                       PyObject* type) {
+  if (obj == nullptr || obj == Py_None) {
+    Py_INCREF(self);
+    return self;
+  }
+  return PyMethod_New(self, obj);
+}
+
+static PyGetSetDef JaxPmapFunction_tp_getset[] = {
+    // Having a __dict__ seems necessary to allow functools.wraps to override
+    // __doc__.
+    {const_cast<char*>("__dict__"), PyObject_GenericGetDict,
+     PyObject_GenericSetDict, nullptr, nullptr},
+    {nullptr, nullptr, nullptr, nullptr, nullptr}};
+
+PyMemberDef JaxPmapFunction_members[] = {
+    {"__vectorcalloffset__", T_PYSSIZET,
+     static_cast<Py_ssize_t>(offsetof(JaxPmapFunctionObject, vectorcall)),
+     READONLY, nullptr},
+#if PY_VERSION_HEX < 0x030C0000
+    {"__dictoffset__", T_PYSSIZET,
+     static_cast<Py_ssize_t>(offsetof(JaxPmapFunctionObject, dict)), READONLY,
+     nullptr},
+    {"__weaklistoffset__", T_PYSSIZET,
+     static_cast<Py_ssize_t>(offsetof(JaxPmapFunctionObject, weakrefs)),
+     READONLY, nullptr},
+#endif  // PY_VERSION_HEX < 0x030C0000
+    {nullptr, 0, 0, 0, nullptr},
+};
+
+PyType_Slot JaxPmapFunction_slots[] = {
+    {Py_tp_new, reinterpret_cast<void*>(JaxPmapFunction_tp_new)},
+    {Py_tp_dealloc, reinterpret_cast<void*>(JaxPmapFunction_tp_dealloc)},
+    {Py_tp_traverse, reinterpret_cast<void*>(JaxPmapFunction_tp_traverse)},
+    {Py_tp_clear, reinterpret_cast<void*>(JaxPmapFunction_tp_clear)},
+    {Py_tp_getset, reinterpret_cast<void*>(JaxPmapFunction_tp_getset)},
+    {Py_tp_descr_get, reinterpret_cast<void*>(JaxPmapFunction_tp_descr_get)},
+    {Py_tp_call, reinterpret_cast<void*>(PyVectorcall_Call)},
+    {Py_tp_members, reinterpret_cast<void*>(JaxPmapFunction_members)},
+    {0, nullptr},
+};
+
+}  // extern "C"
+
+nb::object MakePmapFunction(
+    nb::callable fun, nb::callable cache_miss, std::vector<int> static_argnums,
+    nb::callable python_shard_arg_fallback,
+    xla::nb_class_ptr<xla::PyTreeRegistry> pytree_registry) {
+  nb::object obj = nb::steal(JaxPmapFunction_tp_new(
+      reinterpret_cast<PyTypeObject*>(JaxPmapFunction_Type), nullptr,
+      nullptr));
+  JaxPmapFunctionObject* buf =
+      reinterpret_cast<JaxPmapFunctionObject*>(obj.ptr());
+  new (&buf->fun) PmapFunction(
+      std::move(fun), std::move(cache_miss), std::move(static_argnums),
+      std::move(python_shard_arg_fallback), std::move(pytree_registry));
+  return obj;
+}
+
+// Version numbers for the pickled representations.
+// Increment these if changing them.
+const int kPmapFunctionPickleVersion = 1;
+
+}  // namespace
+
+void BuildPmapSubmodule(nb::module_& m) {
+  nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library");
+
+  nb::class_<NoSharding> no_sharding(pmap_lib, "NoSharding");
+  no_sharding.def(nb::init<>())
+      .def("__getstate__",
+           [](const NoSharding& self) { return nb::make_tuple(); })
+      .def("__setstate__",
+           [](NoSharding& self, nb::tuple t) { new (&self) NoSharding(); })
+      .def("__repr__",
+           [](const NoSharding& chunked) { return "NoSharding()"; })
+      .def("__eq__",
+           [](const NoSharding& self, nb::object obj) {
+             return nb::isinstance<NoSharding>(obj);
+           })
+      .def("__hash__", [](const NoSharding& self) {
+        const size_t hash = absl::HashOf(self);
+        return nb::int_(hash);
+      });
+
+  nb::class_<Chunked> chunked(pmap_lib, "Chunked");
+  chunked.def(nb::init<std::vector<int>>())
+      .def("__getstate__",
+           [](const Chunked& self) { return nb::make_tuple(self.chunks); })
+      .def("__setstate__",
+           [](Chunked& self, nb::tuple t) {
+             new (&self) Chunked{nb::cast<std::vector<int>>(t[0])};
+           })
+      .def_ro("chunks", &Chunked::chunks)
+      .def("__repr__",
+           [](const Chunked& chunked) {
+             return absl::StrCat("Chunked(",
+                                 absl::StrJoin(chunked.chunks, ","), ")");
+           })
+      .def("__eq__", [](const Chunked& self, nb::object other) {
+        if (!nb::isinstance<Chunked>(other)) {
+          return false;
+        }
+        return self == nb::cast<const Chunked&>(other);
+      });
+
+  nb::class_<Unstacked> unstacked(pmap_lib, "Unstacked");
+  unstacked.def(nb::init<int>())
+      .def("__getstate__",
+           [](const Unstacked& self) { return nb::make_tuple(self.size); })
+      .def("__setstate__",
+           [](Unstacked& self, nb::tuple t) {
+             new (&self) Unstacked{nb::cast<int>(t[0])};
+           })
+      .def_ro("size", &Unstacked::size)
+      .def("__repr__",
+           [](const Unstacked& x) {
+             return absl::StrCat("Unstacked(", x.size, ")");
+           })
+      .def("__eq__", [](const Unstacked& self, nb::object other) {
+        if (!nb::isinstance<Unstacked>(other)) {
+          return false;
+        }
+        return self == nb::cast<const Unstacked&>(other);
+      });
+
+  nb::class_<ShardedAxis> sharded_axis(pmap_lib, "ShardedAxis");
+  sharded_axis.def(nb::init<int>())
+      .def("__getstate__",
+           [](const ShardedAxis& self) { return nb::make_tuple(self.axis); })
+      .def("__setstate__",
+           [](ShardedAxis& self, nb::tuple t) {
+             new (&self) ShardedAxis{nb::cast<int>(t[0])};
+           })
+      .def_ro("axis", &ShardedAxis::axis)
+      .def("__repr__",
+           [](const ShardedAxis& x) {
+             return absl::StrCat("ShardedAxis(axis=", x.axis, ")");
+           })
+      .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) {
+        return self == other;
+      });
+
+  nb::class_<Replicated> replicated(pmap_lib, "Replicated");
+  replicated.def(nb::init<int>())
+      .def("__getstate__",
+           [](const Replicated& self) {
+             return nb::make_tuple(self.replicas);
+           })
+      .def("__setstate__",
+           [](Replicated& self, nb::tuple t) {
+             new (&self) Replicated{nb::cast<int>(t[0])};
+           })
+      .def_ro("replicas", &Replicated::replicas)
+      .def("__repr__",
+           [](const Replicated& x) {
+             return absl::StrCat("Replicated(replicas=", x.replicas, ")");
+           })
+      .def("__eq__", [](const Replicated& self, const Replicated& other) {
+        return self == other;
+      });
+
+  nb::class_<ShardingSpec> sharding_spec(pmap_lib, "ShardingSpec");
+  sharding_spec
+      .def(nb::init<nb::iterable, nb::iterable>(), nb::arg("sharding"),
+           nb::arg("mesh_mapping"))
+      .def("__getstate__",
+           [](const ShardingSpec& self) {
+             auto sharding =
+                 xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding()));
+             auto mesh_mapping =
+                 xla::SpanToNbTuple(absl::MakeConstSpan(self.GetMeshMapping()));
+             return nb::make_tuple(sharding, mesh_mapping);
+           })
+      .def("__setstate__",
+           [](ShardingSpec& self, nb::tuple t) {
+             new (&self)
+                 ShardingSpec{nb::cast<std::vector<AvalDimSharding>>(t[0]),
+                              nb::cast<std::vector<MeshDimAssignment>>(t[1])};
+           })
+      .def_prop_ro(
"sharding", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + }) + .def_prop_ro("mesh_mapping", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple( + absl::MakeConstSpan(self.GetMeshMapping())); + }) + .def("__eq__", [](const ShardingSpec& self, + const ShardingSpec& other) { return self == other; }) + .def("__hash__", [](const ShardingSpec& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PmapFunction"); + PyType_Spec pmap_function_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(JaxPmapFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/JaxPmapFunction_slots, + }; + + JaxPmapFunction_Type = PyType_FromSpec(&pmap_function_spec); + if (!JaxPmapFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(JaxPmapFunction_Type); + + // Add PmapFunction to the xla_extension module so it can be pickled. + m.attr("PmapFunction") = cfun; + + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->PythonSignature(); + }); + // Required by `post_hook`. + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->cache_miss(); + }); + cfun.attr("__getstate__") = nb::cpp_function( + [](const PmapFunction::object& self) { + PmapFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPmapFunctionPickleVersion; + pickle["fun"] = fn->fun(); + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](PmapFunction::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPmapFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PmapFunction pickle version, got %d, expected %d. 
" + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPmapFunctionPickleVersion)); + } + nb::callable fun = nb::cast(pickle["fun"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + nb::callable python_shard_arg_fallback = + nb::cast(pickle["python_shard_arg_fallback"]); + xla::nb_class_ptr pytree_registry = + nb::cast>( + pickle["pytree_registry"]); + new (&(reinterpret_cast(self.ptr())->fun)) + PmapFunction(std::move(fun), std::move(cache_miss), + std::move(static_argnums), + std::move(python_shard_arg_fallback), + std::move(pytree_registry)); + }, + nb::is_method()); + + // This is only for testing/debugging purposes. + cfun.attr("_cache_size") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return nb::cast(fun->cache_size()); + }); + + cfun.attr("_cache_clear") = nb::cpp_function( + [](nb::handle self) { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + fun->cache_clear(); + }, + nb::is_method()); + + cfun.attr("_debug_cache_keys") = nb::cpp_function( + [](nb::handle self) -> std::string { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->DebugCacheKeys(); + }, + nb::is_method()); + + pmap_lib.def( + "pmap", + [](nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, nb::callable shard_arg_fallback, + nb::object pytree_registry) -> nb::object { + xla::nb_class_ptr registry = + nb::cast>(pytree_registry); + return MakePmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(shard_arg_fallback), std::move(registry)); + }, + nb::arg("fun"), nb::arg("cache_miss"), nb::arg("static_argnums"), + nb::arg("shard_arg_fallback"), nb::arg("pytree_registry")); +} + +} // namespace jax diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h new file mode 100644 index 000000000000..9ad60a03daf6 --- /dev/null +++ b/jaxlib/xla/pmap_lib.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ + +#include +#include +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). 
diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h
new file mode 100644
index 000000000000..9ad60a03daf6
--- /dev/null
+++ b/jaxlib/xla/pmap_lib.h
@@ -0,0 +1,37 @@
+/* Copyright 2021 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
+
+#include
+#include
+#include
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+// TODO(jblespiau): The current implementation moves the Python logic to C++,
+// as a preliminary step to executing the `pmap` execution path from C++.
+// It implements the current Python behavior (thus, it may not be optimal, and
+// we will be able to modify it later).
+
+namespace jax {
+
+void BuildPmapSubmodule(nanobind::module_& m);
+
+}  // namespace jax
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
diff --git a/jaxlib/xla/sdy.cc b/jaxlib/xla/sdy.cc
new file mode 100644
index 000000000000..c6d1145517d8
--- /dev/null
+++ b/jaxlib/xla/sdy.cc
@@ -0,0 +1,143 @@
+/* Copyright 2024 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/sdy.h"
+
+#include
+#include
+
+#include "mhlo/transforms/passes.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/string.h"       // IWYU pragma: keep
+#include "nanobind/stl/string_view.h"  // IWYU pragma: keep
+#include "nanobind/stl/tuple.h"        // IWYU pragma: keep
+#include "nanobind/stl/vector.h"       // IWYU pragma: keep
+#include "shardy/dialect/sdy/ir/dialect.h"
+#include "shardy/dialect/sdy/ir/utils.h"
+#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/pjrt/mlir_to_hlo.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/service/spmd/shardy/constants.h"
+#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
+#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
+#include "xla/service/spmd/shardy/utils.h"
+#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h"
+
+namespace nb = nanobind;
+
+namespace xla {
+
+namespace {
+
+absl::StatusOr<std::string> SerializeUsingBytecode(mlir::ModuleOp module) {
+  std::string bytecode;
+  llvm::raw_string_ostream os(bytecode);
+  mlir::BytecodeWriterConfig config;
+  if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) {
+    return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed");
+  }
+  return bytecode;
+}
+
+}  // namespace
+
+void BuildSdySubmodule(nb::module_& m) {
+  nb::module_ mlir_module = m.def_submodule("sdy", "Shardy/XLA integration");
+
+  mlir_module
+      // TODO(b/707574930): define a C API for the XLA pipelines.
+      .def(
+          "sdy_round_trip_export_pipeline",
+          [](const nb::bytes& bytecode) -> nb::bytes {
+            mlir::MLIRContext context;
+            mlir::OwningOpRef<mlir::ModuleOp> module =
+                xla::ValueOrThrow(ParseMlirModuleString(
+                    absl::string_view(bytecode.c_str(), bytecode.size()),
+                    context));
+            mlir::PassManager pm(&context);
+            sdy::addSdyRoundTripExportPipeline(pm);
+            tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context);
+            ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get())));
+            std::string module_str =
+                xla::ValueOrThrow(SerializeUsingBytecode(module.get()));
+            return nb::bytes(module_str.data(), module_str.size());
+          },
+          nb::arg("module"))
+      .def(
+          "sdy_round_trip_import_shardings",
+          [](const nb::bytes& bytecode) -> nb::bytes {
+            mlir::MLIRContext context;
+            mlir::OwningOpRef<mlir::ModuleOp> module =
+                xla::ValueOrThrow(ParseMlirModuleString(
+                    absl::string_view(bytecode.c_str(), bytecode.size()),
+                    context));
+            mlir::PassManager pm(&context);
+            pm.addPass(xla::sdy::createSdyRoundTripImportShardyAttrsPass());
+            tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context);
+            ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get())));
+            std::string module_str =
+                xla::ValueOrThrow(SerializeUsingBytecode(module.get()));
+            return nb::bytes(module_str.data(), module_str.size());
+          },
+          nb::arg("module"))
+      .def("lowered_with_shardy",
+           [](const nb::bytes& bytecode) -> bool {
+             mlir::MLIRContext context;
+             mlir::OwningOpRef<mlir::ModuleOp> module =
+                 xla::ValueOrThrow(ParseMlirModuleString(
+                     absl::string_view(bytecode.c_str(), bytecode.size()),
+                     context));
+             return mlir::sdy::getMeshAttr(module.get(), "mesh") ||
+                    sdy::tryGetFrontendAttr<mlir::DictionaryAttr>(
+                        module.get(), sdy::kMeshesRoundTripAttr)
+                        .has_value();
+           })
+      // TODO(bartchr): delete this and all uses of it once JAX export
+      // supports multiple meshes.
+      .def("get_mesh", [](const nb::bytes& bytecode) -> nb::list {
+        mlir::MLIRContext context;
+        mlir::OwningOpRef<mlir::ModuleOp> module =
+            xla::ValueOrThrow(ParseMlirModuleString(
+                absl::string_view(bytecode.c_str(), bytecode.size()), context));
+        auto mesh_op =
+            mlir::SymbolTable::lookupNearestSymbolFrom<mlir::sdy::MeshOp>(
+                module.get(), mlir::StringAttr::get(&context, "mesh"));
+        if (!mesh_op) {
+          return {};
+        }
+        nb::list mesh_shape;
+        for (auto axis : mesh_op.getMeshAttr().getAxes()) {
+          mesh_shape.append(
+              nb::make_tuple(axis.getName().str(), axis.getSize()));
+        }
+        return mesh_shape;
+      });
+}
+
+}  // namespace xla
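A hedged sketch of how the `sdy` submodule registered above is meant to be
driven from Python (`module_bytecode` is MLIR bytecode such as JAX lowering
produces; `_xla` is a hypothetical handle to the extension module):

    sdy = _xla.sdy
    if sdy.lowered_with_shardy(module_bytecode):
        axes = sdy.get_mesh(module_bytecode)  # e.g. [('x', 4), ('y', 2)], or []
        exported = sdy.sdy_round_trip_export_pipeline(module_bytecode)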
diff --git a/jaxlib/xla/sdy.h b/jaxlib/xla/sdy.h
new file mode 100644
index 000000000000..5d8c8c2eb7dd
--- /dev/null
+++ b/jaxlib/xla/sdy.h
@@ -0,0 +1,28 @@
+/* Copyright 2024 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+namespace xla {
+
+void BuildSdySubmodule(nanobind::module_& m);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_
diff --git a/jaxlib/xla/weakref_lru_cache.cc b/jaxlib/xla/weakref_lru_cache.cc
new file mode 100644
index 000000000000..80498f30aaef
--- /dev/null
+++ b/jaxlib/xla/weakref_lru_cache.cc
@@ -0,0 +1,400 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/weakref_lru_cache.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include  // NOLINT
+#include
+#include
+#include
+
+#include "absl/base/thread_annotations.h"
+#include "absl/cleanup/cleanup.h"
+#include "absl/hash/hash.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/synchronization/notification.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/shared_ptr.h"  // IWYU pragma: keep
+#include "nanobind/stl/string.h"      // IWYU pragma: keep
+#include "nanobind/stl/vector.h"      // IWYU pragma: keep
+#include "xla/pjrt/lru_cache.h"
+#include "xla/tsl/platform/logging.h"
+
+namespace nb = nanobind;
+
+namespace jax {
+namespace {
+
+// Minimal wrapper to expose a nb::dict_iterator's value as something
+// hashable with Abseil.
+class HashablePyDictEntry {
+ public:
+  explicit HashablePyDictEntry(std::pair<nb::handle, nb::handle> entry)
+      : entry_(entry) {}
+
+  template <typename H>
+  friend H AbslHashValue(H h, const HashablePyDictEntry& v) {
+    return H::combine(std::move(h), nb::hash(v.entry_.first),
+                      nb::hash(v.entry_.second));
+  }
+
+  std::pair<nb::handle, nb::handle> entry_;
+};
+
+// Similarly, a minimalist adaptor around the nb::detail::dict_iterator
+// itself. Note that the iterator "is" also a Value. Does not meet the full
+// standard iterator requirements, only enough to support H::combine_unordered.
+class HashablePyDictIter {
+ public:
+  using iterator_category = std::input_iterator_tag;
+
+  explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {}
+
+  // Minimal set of iterator operations.
+  HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); }
+  bool operator!=(const HashablePyDictIter& rhs) const {
+    return iter_ != rhs.iter_;
+  }
+  void operator++() { ++iter_; }
+
+ private:
+  nb::detail::dict_iterator& iter_;
+};
+
+struct HashableKey {
+  nb::object context;
+  nb::args args;
+  nb::kwargs kwargs;
+
+  template <typename H>
+  friend H AbslHashValue(H h, const HashableKey& key) {
+    // Note: Despite the fact this is an ABSL hash function, it's safe to call
+    // functions that may throw exceptions such as nb::hash(), because it is
+    // used by an LRUCache, which uses a std::unordered_map, which is
+    // exception-safe.
+    h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
+    nb::detail::dict_iterator begin = key.kwargs.begin();
+    nb::detail::dict_iterator end = key.kwargs.end();
+    h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
+                             HashablePyDictIter(end));
+    h = H::combine(std::move(h), key.kwargs.size());
+    return h;
+  }
+};
+
+}  // namespace
+
+class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
+ public:
+  class Key {
+   public:
+    Key(nb::object context, nb::args args, nb::kwargs kwargs)
+        : context_(std::move(context)),
+          args_(std::move(args)),
+          kwargs_(std::move(kwargs)),
+          cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {}
+
+    bool operator==(const Key& other) const {
+      return context_.equal(other.context_) && args_.equal(other.args_) &&
+             kwargs_.equal(other.kwargs_);
+    }
+
+    template <typename H>
+    friend H AbslHashValue(H h, const Key& key) {
+      return H::combine(std::move(h), key.cached_hash_);
+    }
+
+    nb::object context() const { return context_; }
+    nb::args args() const { return args_; }
+    nb::kwargs kwargs() const { return kwargs_; }
+
+    int tp_traverse(visitproc visit, void* arg) const {
+      Py_VISIT(context_.ptr());
+      Py_VISIT(args_.ptr());
+      Py_VISIT(kwargs_.ptr());
+      return 0;
+    }
+
+   private:
+    nb::object context_;
+    nb::args args_;
+    nb::kwargs kwargs_;
+    size_t cached_hash_;
+  };
+
+  struct CacheEntry {
+    bool has_result = false;
+    nb::object result;
+    absl::Notification completed;
+    std::thread::id thread_id = std::this_thread::get_id();
+
+    int tp_traverse(visitproc visit, void* arg) const {
+      Py_VISIT(result.ptr());
+      return 0;
+    }
+  };
+
+  struct CacheInfo {
+    int64_t hits;
+    int64_t misses;
+    int64_t maxsize;
+    int64_t currsize;
+  };
+
+  struct WeakrefCacheKey {
+    nb::weakref ref;
+    size_t cached_hash;
+  };
+
+  using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
+
+  struct WeakrefCacheValue {
+    std::shared_ptr<Cache> cache;
+  };
+
+  struct WeakrefKeyHash {
+    size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; }
+  };
+
+  struct WeakrefKeyEq {
+    bool operator()(const WeakrefCacheKey& lhs,
+                    const WeakrefCacheKey& rhs) const {
+      return lhs.ref.equal(rhs.ref);
+    }
+  };
+
+  WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn,
+                  int64_t maxsize)
+      : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
+
+  std::shared_ptr<Cache> GetCache(WeakrefCacheKey key) {
+    WeakrefCacheValue& value = entries_[key];
+    if (!value.cache) {
+      value.cache = std::make_shared<Cache>(&lru_list_);
+    }
+    return value.cache;
+  }
+
+  nb::object Call(nb::object weakref_key, nb::args args,
+                  nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
+    nb::object context = cache_context_fn_();
+
+    // We precompute all of the hash values needed by the various maps rather
+    // than computing them during the std::unordered_map insertions. At the
+    // very least, MSVC's std::unordered_map has undefined behavior if the
+    // hash function throws an exception
+    // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace).
+    Key key(context, args, kwargs);
+    size_t wrcache_hash = static_cast<size_t>(nb::hash(weakref_key));
+
+    // No hash computations after this point.
+
+    auto weakref_gc_callback = nb::cpp_function(
+        [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) {
+          auto cache = this_weak.lock();
+          if (cache == nullptr) {
+            return;
+          }
+          // Set up a PyCriticalSection on the cache's associated Python
+          // object.
+          auto py_cache = nb::find(cache);
+          // This should never happen: the Python cache object should always
+          // be found.
+          CHECK(py_cache.ptr() != nullptr);
+          nb::ft_object_guard lock(py_cache);
+
+          // The object the reference referred to is now in the process of
+          // being destroyed, so we cannot refer to its contents. Python
+          // weakref objects compare based on identity if the object they
+          // refer to is gone, so the hash lookup will work fine.
+          auto it = cache->entries_.find(
+              WeakrefCacheKey{nb::borrow<nb::weakref>(weakref), wrcache_hash});
+          if (it == cache->entries_.end()) {
+            return;
+          }
+          // Create a temporary to avoid a re-entrant erase.
+          auto tmp = std::move(it->second);
+          cache->entries_.erase(it);
+        });
+    nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback);
+    WeakrefCacheKey wrcache_key{weakref, wrcache_hash};
+    std::shared_ptr<Cache> cache_ptr = GetCache(wrcache_key);
+    Cache& cache = *cache_ptr;
+    ++total_queries_;
+
+    bool inserted = false;
+    std::shared_ptr<CacheEntry> entry;
+    {
+      // Because the GIL can be released during cache insertion, this forces
+      // the lock order to be mu_ then GIL, so we must release the GIL first.
+      nb::gil_scoped_release release;
+      // Acquire a mutex to avoid problems where the GIL is released during
+      // cache insertion and then a second thread invalidates the cache order.
+      mu_.Lock();
+    }
+    {
+      // GetOrCreateIfAbsent calls into Python hash and equality functions,
+      // which may throw exceptions. The use of absl::Cleanup ensures mu_ is
+      // released if that happens.
+      absl::Cleanup unlock = [this]()
+                                 ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
+      entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) {
+        inserted = true;
+        return std::make_shared<CacheEntry>();
+      });
+    }
+    if (!entry->completed.HasBeenNotified()) {
+      if (inserted) {
+        ++misses_;
+        absl::Cleanup notify = [&] { entry->completed.Notify(); };
+        entry->result = fn_(weakref_key, *args, **kwargs);
+        entry->has_result = true;
+      } else {
+        if (entry->thread_id == std::this_thread::get_id()) {
+          auto error_string =
+              absl::StrCat("Recursively calling ",
+                           nb::cast<std::string>(nb::repr(weakref_key)),
+                           nb::cast<std::string>(nb::repr(args)));
+          PyErr_SetString(PyExc_RecursionError, error_string.c_str());
+          throw nb::python_error();
+        }
+        nb::gil_scoped_release release;
+        entry->completed.WaitForNotification();
+      }
+    }
+
+    if (entry->has_result) {
+      return entry->result;
+    } else {
+      ++misses_;
+      return fn_(weakref_key, *args, **kwargs);
+    }
+  }
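+  // Hedged usage sketch of the Python-facing API registered in
+  // BuildWeakrefLRUCacheAPI below (the callables are placeholders):
+  //
+  //   cache = weakref_lru_cache(lambda: ctx, fn, maxsize=2048)
+  //   v = cache(obj, 1, x=2)   # obj must be weak-referenceable
+  //   cache.cache_info()       # hits/misses/maxsize/currsize
+  //
+  // Entries keyed on `obj` are evicted automatically once `obj` is garbage
+  // collected, via the weakref callback installed in Call() above; kwargs
+  // take part in the key and are hashed order-insensitively (see HashableKey).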
+
+  std::vector<nb::tuple> GetKeys() {
+    std::vector<nb::tuple> results;
+    mu_.Lock();
+    for (const auto& wr_entry : entries_) {
+      for (const auto& rest : *wr_entry.second.cache) {
+        nb::tuple result =
+            nb::make_tuple(*wr_entry.first.ref, rest.first.context(),
+                           rest.first.args(), rest.first.kwargs());
+        results.push_back(std::move(result));
+      }
+    }
+    mu_.Unlock();
+    return results;
+  }
+
+  CacheInfo GetCacheInfo() const {
+    CacheInfo result;
+    result.hits = total_queries_ - misses_;
+    result.misses = misses_;
+    result.maxsize = lru_list_.Capacity();
+    result.currsize = lru_list_.Size();
+    return result;
+  }
+
+  void Clear() {
+    total_queries_ = misses_ = 0;
+    std::vector<std::shared_ptr<Cache>> deferred_deletes;
+    deferred_deletes.reserve(entries_.size());
+    for (auto& entry : entries_) {
+      deferred_deletes.push_back(std::move(entry.second.cache));
+    }
+    entries_.clear();
+    deferred_deletes.clear();
+  }
+
+  nb::callable cache_context_fn_;
+  nb::callable fn_;
+  Cache::LRUList lru_list_;
+  std::unordered_map<WeakrefCacheKey, WeakrefCacheValue, WeakrefKeyHash,
+                     WeakrefKeyEq>
+      entries_;
+  int64_t misses_ = 0;
+  int64_t total_queries_ = 0;
+  absl::Mutex mu_;
+
+  static int tp_traverse(PyObject* self, visitproc visit, void* arg) {
+    WeakrefLRUCache* cache = nb::inst_ptr<WeakrefLRUCache>(self);
+    Py_VISIT(Py_TYPE(self));
+    Py_VISIT(cache->cache_context_fn_.ptr());
+    Py_VISIT(cache->fn_.ptr());
+    for (const auto& [wr_key, wr_value] : cache->entries_) {
+      Py_VISIT(wr_key.ref.ptr());
+      for (const auto& [key, cache_value] : *wr_value.cache) {
+        int rval = key.tp_traverse(visit, arg);
+        if (rval != 0) {
+          return rval;
+        }
+        if (cache_value.value.has_value()) {
+          cache_value.value->get()->tp_traverse(visit, arg);
+        }
+      }
+    }
+    return 0;
+  }
+
+  static int tp_clear(PyObject* self) {
+    WeakrefLRUCache* cache = nb::inst_ptr<WeakrefLRUCache>(self);
+    cache->Clear();
+    cache->cache_context_fn_.reset();
+    cache->fn_.reset();
+    return 0;
+  }
+
+  static PyType_Slot slots_[];
+};
+
+/* static */ PyType_Slot WeakrefLRUCache::slots_[] = {
+    {Py_tp_traverse, (void*)WeakrefLRUCache::tp_traverse},
+    {Py_tp_clear, (void*)WeakrefLRUCache::tp_clear},
+    {0, nullptr},
+};
+
+void BuildWeakrefLRUCacheAPI(nb::module_& m) {
+  auto weakref_lru_cache =
+      nb::class_<WeakrefLRUCache>(m, "WeakrefLRUCache",
+                                  nb::is_weak_referenceable(),
+                                  nb::type_slots(WeakrefLRUCache::slots_))
+          .def("__call__", &WeakrefLRUCache::Call, nb::lock_self())
+          .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self())
+          .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self())
+          .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self());
+  nb::class_<WeakrefLRUCache::CacheInfo>(weakref_lru_cache,
+                                         "WeakrefLRUCacheInfo")
+      .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits)
+      .def_ro("misses", &WeakrefLRUCache::CacheInfo::misses)
+      .def_ro("maxsize", &WeakrefLRUCache::CacheInfo::maxsize)
+      .def_ro("currsize", &WeakrefLRUCache::CacheInfo::currsize)
+      .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) {
+        return absl::StrCat(
+            "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses,
+            ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")");
+      });
+  m.def(
+      "weakref_lru_cache",
+      [](nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) {
+        return std::make_shared<WeakrefLRUCache>(cache_context_fn, fn,
+                                                 maxsize);
+      },
+      nb::arg("cache_context_fn"), nb::arg("fn"), nb::arg("maxsize") = 2048);
+}
+
+}  // namespace jax
diff --git a/jaxlib/xla/weakref_lru_cache.h b/jaxlib/xla/weakref_lru_cache.h
new file mode 100644
index 000000000000..444e01cef575
--- /dev/null
+++ b/jaxlib/xla/weakref_lru_cache.h
@@ -0,0 +1,28 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildWeakrefLRUCacheAPI(nanobind::module_& m); + +} // namespace jax + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 5f39b9173b89..fdd4456b238c 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -46,6 +46,7 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/sdy.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/client.h" @@ -64,7 +65,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/py_client.h" #include "xla/python/py_program.h" -#include "xla/python/sdy.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT @@ -84,6 +84,15 @@ limitations under the License. #include "xla/backends/cpu/collectives/mpi_collectives.h" #endif // !_WIN32 && !PLATFORM_GOOGLE +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/custom_call_sharding.h" +#include "jaxlib/xla/dlpack.h" +#include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/mlir.h" +#include "jaxlib/xla/pjit.h" +#include "jaxlib/xla/pmap_lib.h" +#include "jaxlib/xla/weakref_lru_cache.h" +#include "jaxlib/xla/xla_compiler.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_api.h" @@ -92,22 +101,15 @@ limitations under the License. #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" -#include "xla/python/config.h" -#include "xla/python/custom_call_sharding.h" -#include "xla/python/dlpack.h" #include "xla/python/guard_lib.h" -#include "xla/python/jax_jit.h" #include "xla/python/logging.h" // IWYU pragma: keep -#include "xla/python/mlir.h" #include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_class_ptr.h" #include "xla/python/ops.h" -#include "xla/python/pjit.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" -#include "xla/python/pmap_lib.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" #include "xla/python/py_array.h" @@ -120,8 +122,6 @@ limitations under the License. #include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" -#include "xla/python/weakref_lru_cache.h" -#include "xla/python/xla_compiler.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/tsl/platform/status.h" #include "tsl/platform/platform.h" diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc new file mode 100644 index 000000000000..f4719b450988 --- /dev/null +++ b/jaxlib/xla/xla_compiler.cc @@ -0,0 +1,1639 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/xla_compiler.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/hash/hash.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/ndarray.h"
+#include "nanobind/stl/optional.h"     // IWYU pragma: keep
+#include "nanobind/stl/pair.h"         // IWYU pragma: keep
+#include "nanobind/stl/shared_ptr.h"   // IWYU pragma: keep
+#include "nanobind/stl/string.h"       // IWYU pragma: keep
+#include "nanobind/stl/string_view.h"  // IWYU pragma: keep
+#include "nanobind/stl/variant.h"      // IWYU pragma: keep
+#include "nanobind/stl/vector.h"       // IWYU pragma: keep
+#include "jaxlib/xla/dlpack.h"
+#include "xla/array.h"
+#include "xla/client/executable_build_options.h"
+#include "xla/debug_options_flags.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/ffi.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/builder/xla_builder.h"
+#include "xla/hlo/builder/xla_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_module_group.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/hlo/parser/hlo_parser.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h"
+#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
+#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/literal.h"
+#include "xla/pjrt/compile_options.pb.h"
+#include "xla/pjrt/exceptions.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/python/nb_absl_span.h"  // IWYU pragma: keep
+#include "xla/python/nb_helpers.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/python/py_client.h"
+#include "xla/python/types.h"
+#include "xla/service/call_inliner.h"
+#include "xla/service/computation_placer.h"
+#include "xla/service/custom_call_target_registry.h"
+#include "xla/service/hlo.pb.h"
+#include "xla/service/hlo_graph_dumper.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/name_uniquer.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/platform/env.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+
+namespace nanobind {
+namespace detail {
+
+template <>
+struct type_caster<xla::OpMetadata> {
+ public:
+  NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::OpMetadata,
+                                  const_name("xla::OpMetadata"));
+
+  bool from_python(handle h, uint8_t, cleanup_list*) noexcept {
+    handle op_type = getattr(h, "op_type");
+    if (!op_type.is_none()) {
+      value.set_op_type(cast<std::string>(op_type));
+    }
+    handle op_name = getattr(h, "op_name");
+    if (!op_name.is_none()) {
+      value.set_op_name(cast<std::string>(op_name));
+    }
+    handle source_file = getattr(h, "source_file");
+    if (!source_file.is_none()) {
+      value.set_source_file(cast<std::string>(source_file));
+    }
+    handle source_line = getattr(h, "source_line");
+    if (!source_line.is_none()) {
+      value.set_source_line(cast<int32_t>(source_line));
+    }
+    return true;
+  }
+};
+
+}  // namespace detail
+}  // namespace nanobind
+
+namespace xla {
+namespace {
+
+namespace nb = nanobind;
+
+struct Uniquer {
+  absl::Mutex mu;
+  NameUniquer name_uniquer ABSL_GUARDED_BY(mu);
+};
+
+Uniquer* GetUniquer() {
+  static Uniquer* uniquer = new Uniquer;
+  return uniquer;
+}
+
+static std::string UniquifyName(const std::string& name) {
+  Uniquer* uniquer = GetUniquer();
+  absl::MutexLock lock(&uniquer->mu);
+  return uniquer->name_uniquer.GetUniqueName(name);
+}
+
+// Converts a computation to a serialized HloModuleProto.
+absl::StatusOr<nb::bytes> GetComputationSerializedProto(
+    const XlaComputation& computation) {
+  std::string result;
+  if (!tsl::SerializeToStringDeterministic(computation.proto(), &result)) {
+    return Unknown("Failed to serialize the HloModuleProto.");
+  }
+  return nb::bytes(result.data(), result.size());
+}
+
+// Converts an HLO module to a serialized HloModuleProto.
+absl::StatusOr<nb::bytes> GetHloModuleSerializedProto(const HloModule& module) {
+  std::string result;
+  if (!tsl::SerializeToStringDeterministic(module.ToProto(), &result)) {
+    return Unknown("Failed to serialize the HloModuleProto.");
+  }
+  return nb::bytes(result.data(), result.size());
+}
+
+// Converts a serialized HloModuleProto into an HloModule.
+absl::StatusOr<std::shared_ptr<HloModule>> HloModuleFromSerializedProto(
+    const nb::bytes& bytes) {
+  HloModuleProto proto;
+  proto.ParseFromArray(bytes.c_str(), bytes.size());
+  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
+                      HloModule::CreateModuleConfigFromProto(
+                          proto, GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                      HloModule::CreateFromProto(proto, module_config));
+  return std::shared_ptr<HloModule>(std::move(module));
+}
+
+absl::StatusOr<std::shared_ptr<HloModule>> GetHloModule(
+    const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
+                      HloModule::CreateModuleConfigFromProto(
+                          computation.proto(), GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> module,
+      HloModule::CreateFromProto(computation.proto(), module_config));
+  return std::shared_ptr<HloModule>(std::move(module));
+}
+
+// Converts a computation to textual HLO form.
+absl::StatusOr<std::string> GetComputationHloText(
+    const XlaComputation& computation, bool print_large_constants = false) {
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
+                      GetHloModule(computation));
+  HloPrintOptions options;
+  options = HloPrintOptions::ShortParsable();
+  options.set_print_large_constants(print_large_constants);
+  return hlo_module->ToString(options);
+}
+
+// Converts a computation to HLO dot graph form.
+absl::StatusOr<std::string> GetComputationHloDotGraph(
+    const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
+                      GetHloModule(computation));
+  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
+                     hlo_module->config().debug_options(),
+                     RenderedGraphFormat::kDot);
+}
+
+// Hashes the HLO module.
+absl::StatusOr<uint64_t> HashComputation(const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
+                      GetHloModule(computation));
+  return absl::HashOf(*hlo_module);
+}
+
+// Safe version of ShapeUtil::MakeShapeWithDenseLayout that fails gracefully
+// on invalid input.
+absl::StatusOr<Shape> MakeShapeWithDenseLayout(
+    PrimitiveType element_type, absl::Span<const int64_t> dims,
+    std::optional<absl::Span<const int64_t>> minor_to_major,
+    std::optional<std::vector<bool>> dynamic_dimensions) {
+  Shape shape;
+  if (dynamic_dimensions) {
+    TF_ASSIGN_OR_RETURN(
+        shape, ShapeUtil::MakeValidatedShape(element_type, dims,
+                                             dynamic_dimensions.value()));
+  } else {
+    TF_ASSIGN_OR_RETURN(shape,
+                        ShapeUtil::MakeValidatedShape(element_type, dims));
+  }
+  if (minor_to_major) {
+    *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major);
+    TF_RETURN_IF_ERROR(
+        LayoutUtil::ValidateLayoutForShape(shape.layout(), shape));
+  }
+
+  return shape;
+}
+
+// Pybind function for HloSharding.iota_tile, which is a non-crashing factory
+// that produces a HloSharding instance backed by a tile assignment of a
+// transposed and reshaped iota array of device ids. More specifically, the
+// tile assignment array is as if it were produced by the following numpy
+// code:
+//   numpy.arange(math.prod(dims)).reshape(reshape_dims)
+//       .transpose(transpose_perm).reshape(dims)
+// where:
+// `dims`: the dimensions of the tile assignment array, which corresponds to
+//   OpSharding.tile_assignment_dimensions.
+// `reshape_dims`: the dimensions the 1D iota array is reshaped to.
+// `transpose_perm`: the dimension permutation to transpose `reshape_dims`.
+// `subgroup_types`: indicates the subgroups of the last
+//   `subgroup_types.size()` dimensions in `dims`.
+//
+// In practice, `reshape_dims` often maps to the axes of a user-defined device
+// mesh, and `transpose_perm` often maps to the user specification of how a
+// tensor is partitioned based on the axes defined in the mesh, e.g. for a
+// mesh of size 4x2x2 as AxBxC:
+// PartitionSpec('A', 'B', 'C') corresponds to reshape_dims=[4,2,2],
+//   transpose_perm=[0,1,2] (no transpose)
+// PartitionSpec('B', 'A', 'C') corresponds to reshape_dims=[4,2,2],
+//   transpose_perm=[1,0,2] (swap A and B)
+absl::StatusOr<HloSharding> IotaTileHelper(
+    absl::Span<const int64_t> dims, absl::Span<const int64_t> reshape_dims,
+    absl::Span<const int> transpose_perm,
+    absl::Span<const OpSharding::Type> subgroup_types) {
+  if (dims.empty()) {
+    return InvalidArgument("`dims` should not be empty.");
+  }
+  if (reshape_dims.size() != transpose_perm.size()) {
+    return InvalidArgument(
+        "`reshape_dims` and `transpose_perm` should have the same size, saw "
+        "[%s] v.s. [%s]",
+        absl::StrJoin(reshape_dims, ","), absl::StrJoin(transpose_perm, ","));
+  }
+  if (!reshape_dims.empty() && Product(dims) != Product(reshape_dims)) {
+    return InvalidArgument(
+        "Cannot reshape from `dims` [%s] to `reshape_dims` [%s].",
+        absl::StrJoin(dims, ","), absl::StrJoin(reshape_dims, ","));
+  }
+  if (subgroup_types.size() > dims.size()) {
+    return InvalidArgument(
+        "`subgroup_types`(%lld) should not have more dimensions than "
+        "`dims`(%lld).",
+        subgroup_types.size(), dims.size());
+  }
+  if (reshape_dims.empty()) {
+    return subgroup_types.empty()
+               ? HloSharding::IotaTile(dims)
+               : HloSharding::Subgroup(TileAssignment(dims), subgroup_types);
+  }
+  return subgroup_types.empty()
+             ? HloSharding::IotaTile(dims, reshape_dims, transpose_perm)
+             : HloSharding::Subgroup(
+                   TileAssignment(dims, reshape_dims, transpose_perm),
+                   subgroup_types);
+}
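+
+// Worked example (illustrative only): dims=[4,2], reshape_dims=[2,2,2],
+// transpose_perm=[1,0,2] yields the tile assignment
+//   numpy.arange(8).reshape(2,2,2).transpose(1,0,2).reshape(4,2)
+//   == [[0,1],[4,5],[2,3],[6,7]]
+// i.e. the two halves of the device-ordinal space are interleaved along the
+// first tiling dimension.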
+
+// Registers `fn` as a custom call target.
+//
+// `fn` must be a custom call implementation function pointer
+// (XLA_FFI_Handler* when implemented as an FFI handler) encapsulated in a
+// PyCapsule object, or a dictionary of function pointers (each also
+// encapsulated in a PyCapsule).
+//
+// See the XLA_FFI_ExecutionStage documentation for more details about the
+// custom execution stages.
+absl::Status PyRegisterCustomCallTarget(const std::string& fn_name,
+                                        nb::object fn,
+                                        const std::string& platform,
+                                        int api_version,
+                                        XLA_FFI_Handler_Traits traits) {
+  // Register a legacy custom call target (untyped void* API).
+  if (api_version == 0) {
+    if (traits != 0) {
+      return absl::InvalidArgumentError(
+          "Custom call target registration with traits is not supported for "
+          "api_version=0");
+    }
+
+    nb::capsule capsule;
+    if (!nb::try_cast(fn, capsule)) {
+      return absl::InvalidArgumentError(
+          "Custom call target registration with api_version=0 requires a "
+          "PyCapsule fn object");
+    }
+
+    CustomCallTargetRegistry::Global()->Register(
+        fn_name, static_cast<void*>(capsule.data()), platform);
+    return absl::OkStatus();
+  }
+
+  // Register an XLA FFI handler (typed API with explicit function
+  // signatures).
+  if (api_version == 1) {
+    nb::capsule capsule;
+    if (nb::try_cast(fn, capsule)) {
+      return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler(
+          xla::ffi::GetXlaFfiApi(), fn_name, platform,
+          reinterpret_cast<XLA_FFI_Handler*>(
+              static_cast<void*>(capsule.data()))));
+    }
+
+    nb::dict bundle;
+    if (nb::try_cast(fn, bundle)) {
+      auto handler =
+          [&](const char* name) -> absl::StatusOr<XLA_FFI_Handler*> {
+        if (!bundle.contains(name)) return nullptr;
+
+        nb::capsule capsule;
+        if (!nb::try_cast(bundle[name], capsule)) {
+          return absl::InvalidArgumentError(
+              "Custom call target registration with api_version=1 requires a "
+              "PyCapsule fn object for all dict keys");
+        }
+
+        return reinterpret_cast<XLA_FFI_Handler*>(capsule.data());
+      };
+
+      XLA_FFI_Handler_Bundle handler_bundle;
+      TF_ASSIGN_OR_RETURN(handler_bundle.instantiate, handler("instantiate"));
+      TF_ASSIGN_OR_RETURN(handler_bundle.prepare, handler("prepare"));
+      TF_ASSIGN_OR_RETURN(handler_bundle.initialize, handler("initialize"));
+      TF_ASSIGN_OR_RETURN(handler_bundle.execute, handler("execute"));
+
+      return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler(
+          xla::ffi::GetXlaFfiApi(), fn_name, platform, handler_bundle,
+          traits));
+    }
+
+    return absl::InvalidArgumentError(
+        "Unsupported custom call target type for api_version=1");
+  }
+
+  return absl::UnimplementedError(absl::StrFormat(
+      "API version %d is not supported by RegisterCustomCallTarget. "
" + "Supported versions are 0 and 1.", + api_version)); +} + +absl::Status PyRegisterCustomTypeId(absl::string_view type_name, + nb::object type_id) { + nb::capsule capsule; + if (!nb::try_cast(type_id, capsule)) { + return absl::InvalidArgumentError( + "The type_id argument to register_custom_call_type_id must be a " + "PyCapsule object holding a pointer to a XLA_FFI_TypeId."); + } + XLA_FFI_TypeId* type_id_ptr = + reinterpret_cast(static_cast(capsule.data())); + return ffi::TakeStatus(ffi::Ffi::RegisterTypeId(xla::ffi::GetXlaFfiApi(), + type_name, type_id_ptr)); +} + +template +void DefRepeatedProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, std::vector new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + elems->Reserve(new_elems.size()); + for (typename Container::value_type& e : new_elems) { + elems->Add(std::move(e)); + } + }); +} + +template +void DefRepeatedEnumProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, nb::sequence new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + for (nb::handle e : new_elems) { + elems->Add(nb::cast(e.attr("value"))); + } + }); +} + +template +Array NDArrayToArray(nb::ndarray ndarray) { + std::vector shapes; + shapes.reserve(ndarray.ndim()); + for (int i = 0; i < ndarray.ndim(); ++i) { + shapes.push_back(ndarray.shape(i)); + } + xla::Array array(shapes); + array.Each([&](absl::Span indices, int64_t* val) { + int64_t offset = indices.back(); + int64_t multiplier = 1; + for (int i = ndarray.ndim() - 1; i > 0; --i) { + multiplier *= ndarray.shape(i); + offset += indices[i - 1] * multiplier; + } + *val = *(ndarray.data() + offset); + }); + return array; +} + +absl::StatusOr SubgroupWithTileAssignmentHelper( + nb::ndarray tile_assignment, + absl::Span subgroup_types) { + return HloSharding::Subgroup(NDArrayToArray(tile_assignment), subgroup_types); +} + +nb::ndarray<> LiteralToNdarray(Literal& obj) { + const Shape& shape = obj.shape(); + + if (!shape.has_layout()) { + throw XlaRuntimeError( + "Creating an array is only supported for Literals with a layout."); + } + + const Layout& layout = shape.layout(); + + if (!layout.tiles().empty()) { + throw XlaRuntimeError( + "Creating an array from a tiled Literal is not supported."); + } + + if (!LayoutUtil::IsDenseArray(shape)) { + throw XlaRuntimeError( + "Creating an array is only supported for dense Literals."); + } + + xla::PrimitiveType primitive_type = shape.element_type(); + nb::dlpack::dtype dtype = + ValueOrThrow(PrimitiveTypeToNbDLDataType(primitive_type)); + + absl::Span dimensions = shape.dimensions(); + std::vector unsigned_dimensions(dimensions.begin(), dimensions.end()); + auto strides = StridesForShape(primitive_type, dimensions, layout); + + return nb::ndarray<>(obj.untyped_data(), unsigned_dimensions.size(), + unsigned_dimensions.data(), {}, strides.data(), dtype, + nb::device::cpu::value, 0); +} + +} // namespace + +void BuildXlaCompilerSubmodule(nb::module_& m) { + // Shapes + nb::class_ layout_class(m, "Layout"); 
+  layout_class.def(nb::init<absl::Span<const int64_t>>())
+      .def("__init__",
+           [](Layout* self, nb::sequence minor_to_major, nb::sequence tiling,
+              int64_t element_size_in_bits) {
+             std::vector<Tile> xla_tiles;
+             xla_tiles.reserve(nb::len(tiling.ptr()));
+             for (auto tile : tiling) {
+               xla_tiles.push_back(Tile(
+                   SequenceToVector<int64_t>(nb::cast<nb::sequence>(tile))));
+             }
+             std::vector<int64_t> xla_minor_to_major =
+                 SequenceToVector<int64_t>(minor_to_major);
+             new (self)
+                 Layout(xla_minor_to_major, xla_tiles, element_size_in_bits);
+           })
+      .def("minor_to_major",
+           [](Layout layout) { return SpanToNbTuple(layout.minor_to_major()); })
+      .def("element_size_in_bits", &Layout::element_size_in_bits)
+      .def("tiling",
+           [](Layout layout) {
+             std::vector<nb::tuple> result;
+             result.reserve(layout.tiles().size());
+             for (auto& t : layout.tiles()) {
+               result.push_back(SpanToNbTuple(t.dimensions()));
+             }
+             return result;
+           })
+      .def("__eq__", [](const Layout& layout,
+                        const Layout& other) { return layout == other; })
+      .def("__ne__", [](const Layout& layout,
+                        const Layout& other) { return layout != other; })
+      .def("__str__", &Layout::ToString)
+      .def("__hash__",
+           [](const Layout& layout) { return absl::HashOf(layout); })
+      .def("to_string", &Layout::ToString)
+      .def("__getstate__",
+           [](const Layout& self) -> nb::tuple {
+             auto proto = self.ToProto();
+             std::string result;
+             if (!tsl::SerializeToStringDeterministic(proto, &result)) {
+               // Throw; converted by the binding layer to a Python
+               // RuntimeError.
+               throw XlaRuntimeError(
+                   absl::StrCat("Layout.py_pickle: ",
+                                "SerializeToStringDeterministic failed"));
+             }
+             return nb::make_tuple(nb::bytes(result.data(), result.size()));
+           })
+      .def("__setstate__", [](Layout* self, nb::tuple t) {
+        LayoutProto result;
+        nb::bytes serialized = nb::cast<nb::bytes>(t[0]);
+        result.ParseFromArray(serialized.c_str(), serialized.size());
+        new (self) Layout(Layout::CreateFromProto(result));
+      });
+
+  nb::class_<Shape> shape_class(m, "Shape");
+  shape_class
+      .def("__init__",
+           [](Shape* self, const std::string& s) {
+             new (self) Shape(ValueOrThrow(ParseShape(s)));
+           })
+      .def_static(
+          "tuple_shape",
+          [](std::vector<Shape> shapes) -> Shape {
+            return ShapeUtil::MakeTupleShape(shapes);
+          },
+          "Constructs a tuple shape.")
+      .def_static("array_shape",
+                  xla::ValueOrThrowWrapper(
+                      [](PrimitiveType type, nb::sequence dims_seq,
+                         std::optional<nb::sequence> layout_seq,
+                         std::optional<std::vector<bool>> dynamic_dimensions)
+                          -> absl::StatusOr<Shape> {
+                        std::vector<int64_t> dims =
+                            SequenceToVector<int64_t>(dims_seq);
+                        if (layout_seq) {
+                          std::vector<int64_t> layout =
+                              SequenceToVector<int64_t>(*layout_seq);
+                          return MakeShapeWithDenseLayout(type, dims, layout,
+                                                          dynamic_dimensions);
+                        } else {
+                          return MakeShapeWithDenseLayout(
+                              type, dims, std::nullopt, dynamic_dimensions);
+                        }
+                      }),
+                  "Constructs an array shape.", nb::arg("type"),
+                  nb::arg("dims"), nb::arg("layout").none() = std::nullopt,
+                  nb::arg("dynamic_dimensions").none() = std::nullopt)
+      .def_static(
+          "array_shape",
+          xla::ValueOrThrowWrapper(
+              [](nb_dtype dtype, nb::sequence dims_seq,
+                 std::optional<nb::sequence> layout_seq,
+                 std::optional<std::vector<bool>> dynamic_dimensions)
+                  -> absl::StatusOr<Shape> {
+                PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
+                std::vector<int64_t> dims =
+                    SequenceToVector<int64_t>(dims_seq);
+                if (layout_seq) {
+                  std::vector<int64_t> layout =
+                      SequenceToVector<int64_t>(*layout_seq);
+                  return MakeShapeWithDenseLayout(type, dims, layout,
+                                                  dynamic_dimensions);
+                } else {
+                  return MakeShapeWithDenseLayout(type, dims, std::nullopt,
+                                                  dynamic_dimensions);
+                }
+              }),
+          "Constructs an array shape.", nb::arg("type"), nb::arg("dims"),
+          nb::arg("layout").none() = std::nullopt,
+          nb::arg("dynamic_dimensions").none() = std::nullopt)
+      .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); })
+      .def_static(
+          "scalar_shape",
+          [](PrimitiveType type) -> Shape {
+            return ShapeUtil::MakeScalarShape(type);
+          },
+          "Constructs a scalar shape.", nb::arg("type"))
+      .def_static(
+          "scalar_shape",
+          [](nb_dtype dtype) -> Shape {
+            PrimitiveType type = xla::ValueOrThrow(DtypeToPrimitiveType(dtype));
+            return ShapeUtil::MakeScalarShape(type);
+          },
+          "Constructs a scalar shape.", nb::arg("type"))
+      .def("dimensions",
+           [](const Shape& shape) -> nb::tuple {
+             return SpanToNbTuple(shape.dimensions());
+           })
+      .def("layout",
+           [](const Shape& shape) -> Layout { return shape.layout(); })
+      .def("xla_element_type", &Shape::element_type)
+      .def("element_type",
+           [](const Shape& shape) {
+             return xla::ValueOrThrow(
+                 PrimitiveTypeToNbDtype(shape.element_type()));
+           })
+      .def("numpy_dtype",
+           [](const Shape& shape) {
+             if (shape.IsTuple()) {
+               return nb_dtype("O");
+             }
+             return xla::ValueOrThrow(
+                 PrimitiveTypeToNbDtype(shape.element_type()));
+           })
+      .def("is_tuple", &Shape::IsTuple)
+      .def("is_array", &Shape::IsArray)
+      .def("is_token", &Shape::IsToken)
+      .def("is_static", &Shape::is_static)
+      .def("is_dynamic", &Shape::is_dynamic)
+      .def("is_dynamic_dimension", &Shape::is_dynamic_dimension,
+           nb::arg("dimension"))
+      .def("set_dynamic_dimension", &Shape::set_dynamic_dimension,
+           nb::arg("dimension"), nb::arg("is_dynamic"))
+      .def("rank", &Shape::rank)
+      .def("to_serialized_proto",
+           [](const Shape& shape) {
+             ShapeProto proto = shape.ToProto();
+             std::string s = proto.SerializeAsString();
+             return nb::bytes(s.data(), s.size());
+           })
+      .def("tuple_shapes",
+           [](const Shape& shape) {
+             return std::vector<Shape>(shape.tuple_shapes());
+           })
+      .def("leaf_count",
+           [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); })
+      .def(
+          "with_major_to_minor_layout_if_absent",
+          [](const Shape& shape) {
+            Shape out = shape;
+            ShapeUtil::ForEachMutableSubshape(
+                &out, [](Shape* subshape, const ShapeIndex&) {
+                  if (!subshape->has_layout()) {
+                    LayoutUtil::SetToDefaultLayout(subshape);
+                  }
+                });
+            return out;
+          },
+          "Returns a copy of a shape with missing layouts set to "
+          "major-to-minor.")
+      .def("__eq__", [](const Shape& shape,
+                        const Shape& other) { return shape == other; })
+      .def("__ne__", [](const Shape& shape,
+                        const Shape& other) { return shape != other; })
+      .def("__hash__", [](const Shape& shape) { return absl::HashOf(shape); })
+      .def("__repr__", [](const Shape& shape) {
+        return shape.ToString(/*print_layout=*/true);
+      });
+
+  nb::class_<ProgramShape>(m, "ProgramShape")
+      .def(
+          "__init__",
+          [](ProgramShape* self, absl::Span<const Shape> params, Shape result) {
+            new (self) ProgramShape();
+            for (const Shape& param : params) {
+              *self->add_parameters() = param;
+            }
+            *self->mutable_result() = result;
+          })
+      .def("parameter_shapes",
+           static_cast<const std::vector<Shape>& (ProgramShape::*)() const>(
+               &ProgramShape::parameters))
+      .def("result_shape", &ProgramShape::result)
+      .def("__repr__", &ProgramShape::ToString);
+
+  nb::class_<ShapeIndex>(m, "ShapeIndex")
+      .def("__init__",
+           [](ShapeIndex* self, const std::vector<int64_t>& v) {
+             new (self) ShapeIndex(v.begin(), v.end());
+           })
+      .def("__repr__", &ShapeIndex::ToString)
+      .def("__eq__", [](const ShapeIndex& shape_ind,
+                        const ShapeIndex& other) { return shape_ind == other; })
+      .def("__ne__", [](const ShapeIndex& shape_ind,
+                        const ShapeIndex& other) { return shape_ind != other; })
+      .def("__hash__",
+           [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); });
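+
+  // For orientation, a hedged sketch of the Python-level usage the Shape
+  // bindings above enable (values are illustrative):
+  //
+  //   s = Shape.array_shape(np.dtype('float32'), (2, 3), (1, 0))
+  //   s.rank() == 2 and s.dimensions() == (2, 3)
+  //   t = Shape.tuple_shape([s, s]); t.is_tuple()
+  //   idx = ShapeIndex([0]); str(idx)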
+
+  // Literals
+  nb::class_<Literal>(m, "Literal")
+      .def(nb::init<const Shape&>())
+      .def("__repr__", &Literal::ToString)
+      .def(
+          "__array__",
+          [](std::shared_ptr<Literal> obj, std::optional<nb::object> dtype,
+             std::optional<bool> copy) {
+            // Provides the interface required by numpy to create a
+            // np.ndarray. We currently don't support the __dlpack__
+            // interface, but it could be added with very little effort if
+            // needed.
+
+            nb::ndarray<> np_array(LiteralToNdarray(*obj));
+
+            if (dtype.has_value()) {
+              throw XlaRuntimeError(
+                  "Passing of dtype to __array__ not currently supported.");
+            }
+
+            if (copy.has_value() && *copy) {
+              // When a copy is requested we _must_ return a copy:
+              // https://numpy.org/doc/2.1/reference/generated/numpy.ndarray.__array__.html
+              return np_array.cast(nb::rv_policy::copy);
+            }
+
+            return np_array.cast(nb::rv_policy::reference_internal,
+                                 nb::cast(obj));
+          },
+          nb::arg("dtype").none() = nb::none(),
+          nb::arg("copy").none() = nb::none())
+      .def("shape", &Literal::shape);
+
+  nb::class_<XlaComputation>(m, "XlaComputation")
+      .def("__init__",
+           [](XlaComputation* self,
+              const nb::bytes& serialized_hlo_module_proto) {
+             HloModuleProto proto;
+             proto.ParseFromArray(serialized_hlo_module_proto.c_str(),
+                                  serialized_hlo_module_proto.size());
+             new (self) XlaComputation(proto);
+           })
+      .def("get_hlo_module", xla::ValueOrThrowWrapper(GetHloModule))
+      .def("program_shape",
+           xla::ValueOrThrowWrapper(&XlaComputation::GetProgramShape))
+      .def("name", &XlaComputation::name)
+      .def("as_serialized_hlo_module_proto",
+           xla::ValueOrThrowWrapper(GetComputationSerializedProto))
+      .def("as_hlo_text", xla::ValueOrThrowWrapper(GetComputationHloText),
+           nb::arg("print_large_constants") = false)
+      .def("as_hlo_dot_graph",
+           xla::ValueOrThrowWrapper(GetComputationHloDotGraph))
+      .def("hash", xla::ValueOrThrowWrapper(HashComputation))
+      .def("as_hlo_module", xla::ValueOrThrowWrapper(GetHloModule));
+
+  nb::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
+  hlo_print_options_class.def(nb::init<>())
+      .def_static("short_parsable", &HloPrintOptions::ShortParsable)
+      .def_static("canonical", &HloPrintOptions::Canonical)
+      .def_static("fingerprint", &HloPrintOptions::Fingerprint)
+      .def_prop_rw("print_large_constants",
+                   &HloPrintOptions::print_large_constants,
+                   &HloPrintOptions::set_print_large_constants)
+      .def_prop_rw("print_metadata", &HloPrintOptions::print_metadata,
+                   &HloPrintOptions::set_print_metadata)
+      .def_prop_rw("print_backend_config",
+                   &HloPrintOptions::print_backend_config,
+                   &HloPrintOptions::set_print_backend_config)
+      .def_prop_rw("print_result_shape", &HloPrintOptions::print_result_shape,
+                   &HloPrintOptions::set_print_result_shape)
+      .def_prop_rw("print_operand_shape",
+                   &HloPrintOptions::print_operand_shape,
+                   &HloPrintOptions::set_print_operand_shape)
+      .def_prop_rw("print_operand_names",
+                   &HloPrintOptions::print_operand_names,
+                   &HloPrintOptions::set_print_operand_names)
+      .def_prop_rw("print_ids", &HloPrintOptions::print_ids,
+                   &HloPrintOptions::set_print_ids)
+      .def_prop_rw("print_extra_attributes",
+                   &HloPrintOptions::print_extra_attributes,
+                   &HloPrintOptions::set_print_extra_attributes)
+      .def_prop_rw("print_program_shape",
+                   &HloPrintOptions::print_program_shape,
+                   &HloPrintOptions::set_print_program_shape)
+      .def_prop_rw("print_percent", &HloPrintOptions::print_percent,
+                   &HloPrintOptions::set_print_percent)
+      .def_prop_rw("print_control_dependencies",
+                   &HloPrintOptions::print_control_dependencies,
+                   &HloPrintOptions::set_print_control_dependencies)
+      .def_prop_rw("compact_operands", &HloPrintOptions::compact_operands,
+                   &HloPrintOptions::set_compact_operands)
.def_prop_rw("include_layout_in_shapes", + &HloPrintOptions::include_layout_in_shapes, + &HloPrintOptions::set_include_layout_in_shapes) + .def_prop_rw("canonicalize_instruction_names", + &HloPrintOptions::canonicalize_instruction_names, + &HloPrintOptions::set_canonicalize_instruction_names) + .def_prop_rw("canonicalize_computations", + &HloPrintOptions::canonicalize_computations, + &HloPrintOptions::set_canonicalize_computations) + .def_prop_rw("indent_amount", &HloPrintOptions::indent_amount, + &HloPrintOptions::set_indent_amount) + .def_prop_rw("is_in_nested_computation", + &HloPrintOptions::is_in_nested_computation, + &HloPrintOptions::set_is_in_nested_computation); + + // HloModule.computations() returns raw pointers. + // pybind seems to prefer smart pointers. + // We give pybind a smart pointer to a wrapper around a raw pointer to satisfy + // pybind and avoid double frees. + class ComputationWrapper { + public: + ComputationWrapper(const HloComputation* comp, + const std::shared_ptr module) + : comp_(comp), module_(module) {} + absl::string_view name() const { return comp_->name(); } + void render_html(const std::string& filename) { + std::string html = xla::ValueOrThrow(RenderGraph( + *comp_, /*label=*/"", comp_->parent()->config().debug_options(), + RenderedGraphFormat::kHtml, HloRenderOptions())); + xla::ThrowIfError(tsl::WriteStringToFile( + tsl::Env::Default(), absl::StrCat(filename, ".html"), html)); + } + + private: + const HloComputation* comp_; + // The module owns the computations: if its destructor is called, the + // computations are freed. To prevent that from happening in cases where the + // module Python object goes out of scope and gets garbage collected before + // the computations, we keep a shared_ptr to the module that originated the + // computation. 
+    const std::shared_ptr<HloModule> module_;
+  };
+
+  nb::class_<ComputationWrapper> hlo_computation_class(m, "HloComputation");
+
+  hlo_computation_class.def_prop_ro("name", &ComputationWrapper::name)
+      .def("render_html", &ComputationWrapper::render_html);
+
+  nb::class_<HloModule> hlo_module_class(m, "HloModule");
+  hlo_module_class.def_prop_ro("name", &HloModule::name)
+      .def(
+          "to_string",
+          static_cast<std::string (HloModule::*)(const HloPrintOptions&) const>(
+              &HloModule::ToString),
+          nb::arg("options") = HloPrintOptions())
+      .def("as_serialized_hlo_module_proto",
+           xla::ValueOrThrowWrapper(GetHloModuleSerializedProto))
+      .def("from_serialized_hlo_module_proto",
+           xla::ValueOrThrowWrapper(HloModuleFromSerializedProto))
+      .def("computations",
+           [](const std::shared_ptr<HloModule> m)
+               -> std::vector<std::shared_ptr<ComputationWrapper>> {
+             std::vector<std::shared_ptr<ComputationWrapper>> computations;
+             for (HloComputation* comp : m->computations())
+               computations.push_back(
+                   std::make_shared<ComputationWrapper>(comp, m));
+             return computations;
+           })
+      .def_prop_ro("spmd_output_sharding",
+                   [](const HloModule& m) -> std::optional<xla::OpSharding> {
+                     if (!m.has_spmd_output_sharding()) return std::nullopt;
+                     return m.spmd_output_sharding().ToProto();
+                   })
+      .def_prop_ro("spmd_parameters_shardings",
+                   [](const HloModule& m)
+                       -> std::optional<std::vector<xla::OpSharding>> {
+                     if (!m.has_spmd_parameters_shardings())
+                       return std::nullopt;
+                     std::vector<xla::OpSharding> param_shardings;
+                     for (const auto& parameter_sharding :
+                          m.spmd_parameters_shardings()) {
+                       param_shardings.push_back(parameter_sharding.ToProto());
+                     }
+                     return param_shardings;
+                   });
+
+  nb::class_<HloModuleGroup> hlo_module_group_class(m, "HloModuleGroup");
+  hlo_module_group_class
+      .def("__init__",
+           [](HloModuleGroup* self, const std::string& name,
+              const std::vector<std::shared_ptr<HloModule>>& hlo_modules) {
+             std::vector<std::unique_ptr<HloModule>> modules;
+             modules.reserve(hlo_modules.size());
+             for (const auto& m : hlo_modules) {
+               modules.push_back(m->Clone(/*suffix=*/""));
+             }
+             new (self) HloModuleGroup(name, std::move(modules));
+           })
+      .def_prop_ro("name", &HloModuleGroup::name)
+      .def("to_string", &HloModuleGroup::ToString)
+      .def("to_modules",
+           [](HloModuleGroup& m) -> std::vector<std::shared_ptr<HloModule>> {
+             std::vector<std::unique_ptr<HloModule>> modules =
+                 m.ConsumeModules();
+             std::vector<std::shared_ptr<HloModule>> shared_modules;
+             shared_modules.reserve(modules.size());
+             for (auto& module : modules) {
+               shared_modules.push_back(std::move(module));
+             }
+             return shared_modules;
+           });
+
+  m.def("hlo_module_to_dot_graph",
+        [](const HloModule& hlo_module) -> std::string {
+          return xla::ValueOrThrow(RenderGraph(
+              *hlo_module.entry_computation(), /*label=*/"",
+              hlo_module.config().debug_options(), RenderedGraphFormat::kDot));
+        });
+  m.def(
+      "hlo_module_cost_analysis",
+      xla::ValueOrThrowWrapper([](PyClient* client, const HloModule& module)
+                                   -> absl::StatusOr<nb::dict> {
+        TF_ASSIGN_OR_RETURN(auto analysis,
+                            client->pjrt_client()->GetHloCostAnalysis());
+        TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get()));
+
+        // Convert from HloCostAnalysis::Properties to a standard map.
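+        // (Each property key produced by the analysis, e.g. "flops" or
+        // "bytes accessed", is assumed here to map to a single float value.)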
+ nb::dict ret; + analysis->properties().ForEach([&](absl::string_view key, float val) { + ret[nb::str(key.data(), key.size())] = nb::cast(val); + }); + return ret; + })); + m.def("hlo_module_from_text", + xla::ValueOrThrowWrapper( + [](const std::string& hlo_module_text) + -> absl::StatusOr> { + auto hlo_module = + xla::ParseAndReturnUnverifiedModule(hlo_module_text); + TF_RETURN_IF_ERROR(hlo_module.status()); + std::shared_ptr result(std::move(*hlo_module)); + return result; + })); + + nb::class_ xla_op_class(m, "XlaOp"); + + nb::class_(m, "XlaBuilder") + .def("__init__", + [](XlaBuilder* self, const std::string& name) { + new (self) XlaBuilder(UniquifyName(name)); + }) + // TODO(phawkins): delete capitalized names after updating callers. + .def("Build", + xla::ValueOrThrowWrapper( + [](XlaBuilder& builder, std::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }), + "Builds a computation from the contents of the builder.", + nb::arg("root") = std::nullopt) + .def("GetShape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) + .def("build", + xla::ValueOrThrowWrapper( + [](XlaBuilder& builder, std::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }), + "Builds a computation from the contents of the builder.", + nb::arg("root") = std::nullopt) + .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) + .def("get_shape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) + .def( + "get_program_shape", + [](const XlaBuilder& builder, + std::optional root) -> absl::StatusOr { + return root ? builder.GetProgramShape(*root) + : builder.GetProgramShape(); + }, + nb::arg("root") = std::nullopt) + .def("is_constant", xla::ValueOrThrowWrapper(&XlaBuilder::IsConstant)) + .def("set_op_metadata", &XlaBuilder::SetOpMetadata) + .def("set_sharding", &XlaBuilder::SetSharding) + .def("clear_sharding", &XlaBuilder::ClearSharding) + .def("set_frontend_attributes", &XlaBuilder::SetFrontendAttributes) + .def("clear_frontend_attributes", &XlaBuilder::ClearFrontendAttributes) + .def("setup_alias", + [](XlaBuilder& builder, const std::vector& output_index, + int64_t param_number, const std::vector& param_index) { + builder.SetUpAlias( + ShapeIndex(output_index.begin(), output_index.end()), + param_number, + ShapeIndex(param_index.begin(), param_index.end())); + }); + + // Device assignments + nb::class_(m, "DeviceAssignment") + .def_static( + "create", + xla::ValueOrThrowWrapper([](nb::ndarray> array) + -> absl::StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array(i, j); + } + } + return result; + })) + .def("replica_count", &DeviceAssignment::replica_count) + .def("computation_count", &DeviceAssignment::computation_count) + .def("__repr__", &DeviceAssignment::ToString) + .def("serialize", + xla::ValueOrThrowWrapper( + [](const DeviceAssignment& da) -> absl::StatusOr { + DeviceAssignmentProto proto; + da.Serialize(&proto); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + return Unknown( + "Failed to serialize the DeviceAssignmentProto."); + } + return nb::bytes(result.data(), result.size()); + })); + + nb::class_ compile_options(m, "CompileOptions"); + compile_options + .def("__init__", + [](CompileOptions* self) { + new (self) 
CompileOptions(); + DebugOptions* debug_options = + self->executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. + debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + }) + .def("__getstate__", + [](const CompileOptions& self) -> nb::tuple { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", + [](CompileOptions* self, nb::tuple t) { + CompileOptionsProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) CompileOptions( + ValueOrThrow(CompileOptions::FromProto(result))); + }) + .def("SerializeAsString", + [](const CompileOptions& self) -> nb::bytes { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.SerializeAsString: ", + "SerializeToStringDeterministic failed")); + } + return nb::bytes(result.data(), result.size()); + }) + .def_static("ParseFromString", + [](nb::bytes s) { + CompileOptionsProto result; + result.ParseFromArray(s.c_str(), s.size()); + return ValueOrThrow(CompileOptions::FromProto(result)); + }) + .def_rw("argument_layouts", &CompileOptions::argument_layouts) + .def_rw("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_rw("compile_portable_executable", + &CompileOptions::compile_portable_executable) + .def_ro("executable_build_options", + &CompileOptions::executable_build_options) + .def_rw("env_option_overrides", &CompileOptions::env_option_overrides) + // TODO(phawkins): the following fields exist for backward compatibility. + // Remove them after JAX has been updated not to use them. + .def_rw("tuple_arguments", &CompileOptions::parameter_is_tupled_arguments) + .def_prop_rw( + "num_replicas", + [](const CompileOptions& options) { + return options.executable_build_options.num_replicas(); + }, + [](CompileOptions& options, int num_replicas) { + options.executable_build_options.set_num_replicas(num_replicas); + }) + .def_prop_rw( + "num_partitions", + [](const CompileOptions& options) { + return options.executable_build_options.num_partitions(); + }, + [](CompileOptions& options, int num_partitions) { + options.executable_build_options.set_num_partitions(num_partitions); + }) + .def_prop_rw( + "profile_version", + [](const CompileOptions& options) { return options.profile_version; }, + [](CompileOptions& options, int64_t profile_version) { + options.profile_version = profile_version; + }) + .def_prop_rw( + "device_assignment", + [](const CompileOptions& options) -> std::optional { + return options.executable_build_options.has_device_assignment() + ? std::optional( + options.executable_build_options + .device_assignment()) + : std::nullopt; + }, + [](CompileOptions& options, + const DeviceAssignment& device_assignment) { + options.executable_build_options.set_device_assignment( + device_assignment); + }); + + // Custom-call targets. 
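+  // A rough sketch of how the binding below is typically reached from Python
+  // (the module path and the capsule object are illustrative assumptions,
+  // not part of this change):
+  //
+  //   from jax._src.lib import xla_client
+  //   xla_client.register_custom_call_target(
+  //       "my_custom_op",      # target name, accepted as str or bytes
+  //       my_handler_capsule,  # hypothetical PyCapsule wrapping the handler
+  //       platform="CUDA", api_version=1)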
+ m.def( + "register_custom_call_target", + [](nb::object fn_name_py, nb::object fn, const std::string& platform, + int api_version, XLA_FFI_Handler_Traits traits) { + std::string fn_name; + if (!nb::try_cast(fn_name_py, fn_name)) { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name = std::string(bytes.c_str(), bytes.size()); + } + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(fn), platform, api_version, traits)); + }, + nb::arg("fn_name"), nb::arg("fn"), nb::arg("platform"), + nb::arg("api_version") = 0, nb::arg("traits") = 0); + + m.def( + "custom_call_targets", + [](const std::string& platform) -> nb::dict { + nb::dict targets; + for (const auto& [name, target] : + CustomCallTargetRegistry::Global()->registered_symbols(platform)) { + targets[nb::str(name.data(), name.size())] = nb::capsule(target); + } + + auto ffi_handlers = ffi::StaticRegisteredHandlers(platform); + if (!ffi_handlers.ok()) return targets; + + for (const auto& [name, registration] : *ffi_handlers) { + nb::dict bundle; + auto export_handler = [&](absl::string_view name, + XLA_FFI_Handler* h) { + if (h != nullptr) { + bundle[nb::str(name.data(), name.size())] = + nb::capsule(reinterpret_cast(h)); + } + }; + export_handler("prepare", registration.bundle.prepare); + export_handler("initialize", registration.bundle.initialize); + export_handler("execute", registration.bundle.execute); + targets[nb::str(name.data(), name.size())] = std::move(bundle); + } + return targets; + }, + nb::arg("platform")); + + nb::enum_(m, "AutotuneCacheMode") + .value("UNSPECIFIED", DebugOptions::AUTOTUNE_CACHE_MODE_UNSPECIFIED) + .value("UPDATE", DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE) + .value("READ", DebugOptions::AUTOTUNE_CACHE_MODE_READ); + + m.def( + "register_custom_type_id", + [](absl::string_view type_name, nb::object type_id) { + xla::ThrowIfError(PyRegisterCustomTypeId(type_name, type_id)); + }, + nb::arg("type_name"), nb::arg("type_id")); + + nb::class_(m, "DebugOptions") + .def("__repr__", &DebugOptions::DebugString) + .def_prop_rw("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_prop_rw("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_prop_rw("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_prop_rw("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_prop_rw("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans) + .def_prop_rw("xla_cpu_fast_math_honor_division", + &DebugOptions::xla_cpu_fast_math_honor_division, + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_prop_rw("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_prop_rw("xla_detailed_logging", &DebugOptions::xla_detailed_logging, + &DebugOptions::set_xla_detailed_logging) + .def_prop_rw("xla_enable_dumping", &DebugOptions::xla_enable_dumping, + &DebugOptions::set_xla_enable_dumping) + .def_prop_rw("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_prop_rw("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + 
[](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_prop_rw("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) + .def_prop_rw("xla_gpu_cuda_data_dir", + &DebugOptions::xla_gpu_cuda_data_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_cuda_data_dir(value); + }) + .def_prop_rw("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_prop_rw( + "xla_disable_hlo_passes", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_disable_hlo_passes(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_disable_hlo_passes(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_disable_hlo_passes(passname); + } + }) + .def_prop_rw( + "xla_enable_hlo_passes_only", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_enable_hlo_passes_only(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_enable_hlo_passes_only(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_enable_hlo_passes_only(passname); + } + }) + .def_prop_rw("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts) + .def_prop_rw("xla_force_host_platform_device_count", + &DebugOptions::xla_force_host_platform_device_count, + &DebugOptions::set_xla_force_host_platform_device_count) + .def_prop_rw("xla_dump_to", &DebugOptions::xla_dump_to, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_to(value); + }) + .def_prop_rw("xla_dump_hlo_module_re", + &DebugOptions::xla_dump_hlo_module_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_module_re(value); + }) + .def_prop_rw("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pass_re(value); + }) + .def_prop_rw("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, + &DebugOptions::set_xla_dump_hlo_as_text) + .def_prop_rw("xla_dump_hlo_as_proto", + &DebugOptions::xla_dump_hlo_as_proto, + &DebugOptions::set_xla_dump_hlo_as_proto) + .def_prop_rw("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, + &DebugOptions::set_xla_dump_hlo_as_dot) + .def_prop_rw("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, + &DebugOptions::set_xla_dump_hlo_as_url) + .def_prop_rw("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, + &DebugOptions::set_xla_dump_hlo_as_html) + .def_prop_rw("xla_dump_fusion_visualization", + &DebugOptions::xla_dump_fusion_visualization, + &DebugOptions::set_xla_dump_fusion_visualization) + .def_prop_rw("xla_dump_hlo_snapshots", + &DebugOptions::xla_dump_hlo_snapshots, + &DebugOptions::set_xla_dump_hlo_snapshots) + .def_prop_rw("xla_dump_max_hlo_modules", + &DebugOptions::xla_dump_max_hlo_modules, + &DebugOptions::set_xla_dump_max_hlo_modules) + .def_prop_rw("xla_dump_module_metadata", + &DebugOptions::xla_dump_module_metadata, + &DebugOptions::set_xla_dump_module_metadata) + .def_prop_rw("xla_dump_compress_protos", + &DebugOptions::xla_dump_compress_protos, + &DebugOptions::set_xla_dump_compress_protos) + .def_prop_rw("xla_dump_hlo_as_long_text", + &DebugOptions::xla_dump_hlo_as_long_text, + 
&DebugOptions::set_xla_dump_hlo_as_long_text) + .def_prop_rw("xla_dump_disable_metadata", + &DebugOptions::xla_dump_disable_metadata, + &DebugOptions::set_xla_dump_disable_metadata) + .def_prop_rw("xla_dump_hlo_pipeline_re", + &DebugOptions::xla_dump_hlo_pipeline_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pipeline_re(value); + }) + .def_prop_rw("xla_gpu_dump_autotune_logs_to", + &DebugOptions::xla_gpu_dump_autotune_logs_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_logs_to(value); + }) + .def_prop_rw("xla_gpu_kernel_cache_file", + &DebugOptions::xla_gpu_kernel_cache_file, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_kernel_cache_file(value); + }) + .def_prop_rw( + "xla_gpu_enable_llvm_module_compilation_parallelism", + &DebugOptions::xla_gpu_enable_llvm_module_compilation_parallelism, + &DebugOptions::set_xla_gpu_enable_llvm_module_compilation_parallelism) + .def_prop_rw("xla_gpu_per_fusion_autotune_cache_dir", + &DebugOptions::xla_gpu_per_fusion_autotune_cache_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_per_fusion_autotune_cache_dir(value); + }) + .def_prop_rw("xla_gpu_experimental_autotune_cache_mode", + &DebugOptions::xla_gpu_experimental_autotune_cache_mode, + &DebugOptions::set_xla_gpu_experimental_autotune_cache_mode); + + nb::class_(m, "ExecutableBuildOptions") + .def(nb::init<>()) + .def("__repr__", &ExecutableBuildOptions::ToString) + .def_prop_rw( + "fdo_profile", + [](const ExecutableBuildOptions& options) { + return nb::bytes(options.fdo_profile().data(), + options.fdo_profile().size()); + }, + [](ExecutableBuildOptions& options, nb::bytes fdo_profile) { + options.set_fdo_profile( + std::string(fdo_profile.c_str(), fdo_profile.size())); + }) + .def_prop_rw( + "result_layout", + [](const ExecutableBuildOptions& options) -> std::optional { + return options.result_layout() + ? std::optional(*options.result_layout()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_result_layout) + .def_prop_rw("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_prop_rw("num_partitions", &ExecutableBuildOptions::num_partitions, + &ExecutableBuildOptions::set_num_partitions) + .def_prop_ro("debug_options", + &ExecutableBuildOptions::mutable_debug_options, + nb::rv_policy::reference, nb::keep_alive<1, 0>()) + .def_prop_rw( + "device_assignment", + [](const ExecutableBuildOptions& options) + -> std::optional { + return options.has_device_assignment() + ? 
std::optional( + options.device_assignment()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_device_assignment) + .def("compilation_environments_from_serialized_proto", + [](ExecutableBuildOptions& options, + const nb::bytes& serialized_proto) { + xla::CompilationEnvironmentsProto env_proto; + env_proto.ParseFromArray(serialized_proto.c_str(), + serialized_proto.size()); + auto comp_envs = xla::ValueOrThrow( + xla::CompilationEnvironments::CreateFromProto(env_proto)); + *options.mutable_comp_envs() = std::move(*comp_envs); + }) + .def_prop_rw("exec_time_optimization_effort", + &ExecutableBuildOptions::exec_time_optimization_effort, + &ExecutableBuildOptions::set_exec_time_optimization_effort) + .def_prop_rw("memory_fitting_effort", + &ExecutableBuildOptions::memory_fitting_effort, + &ExecutableBuildOptions::set_memory_fitting_effort) + .def_prop_rw( + "optimization_level", &ExecutableBuildOptions::optimization_level, + [](ExecutableBuildOptions& options, int value) { + options.set_optimization_level( + static_cast(value)); + }) + .def_prop_rw( + "memory_fitting_level", &ExecutableBuildOptions::memory_fitting_level, + [](ExecutableBuildOptions& options, int value) { + options.set_memory_fitting_level( + static_cast(value)); + }) + .def_prop_rw("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning) + .def_prop_rw("use_auto_spmd_partitioning", + &ExecutableBuildOptions::use_auto_spmd_partitioning, + &ExecutableBuildOptions::set_use_auto_spmd_partitioning) + .def_prop_rw( + "auto_spmd_partitioning_mesh_shape", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape) + .def_prop_rw("auto_spmd_partitioning_mesh_ids", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_output", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_output().begin(), + options.allow_spmd_sharding_propagation_to_output().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_output(v); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); + + nb::enum_ op_sharding_type(m, "OpSharding_Type", + nb::is_arithmetic()); + op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) + .value("MAXIMAL", OpSharding::MAXIMAL) + .value("MANUAL", OpSharding::MANUAL) + .value("TUPLE", OpSharding::TUPLE) + .value("OTHER", OpSharding::OTHER) + .value("UNKNOWN", OpSharding::UNKNOWN); + + nb::enum_ op_sharding_shard_group_type( + m, "OpSharding_ShardGroupType"); + op_sharding_shard_group_type.value("AS", OpSharding::AS) + .value("LIKE", OpSharding::LIKE); + + nb::class_ op_sharding(m, "OpSharding"); + 
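+  // A rough Python-side usage sketch for the OpSharding bindings below (the
+  // module path is an assumption; the repeated fields are installed by the
+  // DefRepeatedProperty calls at the end of this chain):
+  //
+  //   op = xla_client.OpSharding()
+  //   op.type = xla_client.OpSharding.Type.OTHER
+  //   op.tile_assignment_dimensions = [2, 2]
+  //   op.tile_assignment_devices = [0, 1, 2, 3]
+  //   serialized = op.SerializeToString()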
op_sharding + .def_prop_ro_static( + "Type", + [op_sharding_type](const nb::object&) { return op_sharding_type; }) + .def_prop_ro_static("ShardGroupType", + [op_sharding_shard_group_type](const nb::object&) { + return op_sharding_shard_group_type; + }) + .def(nb::init<>()) + .def("__getstate__", + [](const OpSharding& self) { + std::string serialized = self.SerializeAsString(); + return nb::make_tuple( + nb::bytes(serialized.data(), serialized.size())); + }) + .def("__setstate__", + [](OpSharding* self, nb::tuple t) { + new (self) OpSharding(); + nb::bytes serialized = nb::cast(t[0]); + self->ParseFromArray(serialized.c_str(), serialized.size()); + }) + .def_prop_rw("type", &xla::OpSharding::type, &xla::OpSharding::set_type) + .def_prop_rw("replicate_on_last_tile_dim", + &xla::OpSharding::replicate_on_last_tile_dim, + &xla::OpSharding::set_replicate_on_last_tile_dim) + .def_prop_rw("is_shard_group", &xla::OpSharding::is_shard_group, + &xla::OpSharding::set_is_shard_group) + .def_prop_rw("shard_group_id", &xla::OpSharding::shard_group_id, + &xla::OpSharding::set_shard_group_id) + .def_prop_rw("shard_group_type", &xla::OpSharding::shard_group_type, + &xla::OpSharding::set_shard_group_type) + .def("__repr__", + [](const xla::OpSharding& self) { return self.DebugString(); }) + .def("ParseFromString", + [](OpSharding& sharding, const nb::bytes& s) { + sharding.ParseFromArray(s.c_str(), s.size()); + }) + .def("SerializeToString", + [](const OpSharding& sharding) { + std::string serialized = sharding.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("clone", + [](const OpSharding& sharding) { return OpSharding(sharding); }); + DefRepeatedProperty(op_sharding, "tile_assignment_dimensions", + &xla::OpSharding::mutable_tile_assignment_dimensions); + DefRepeatedProperty(op_sharding, "tile_assignment_devices", + &xla::OpSharding::mutable_tile_assignment_devices); + DefRepeatedProperty(op_sharding, "iota_reshape_dims", + &xla::OpSharding::mutable_iota_reshape_dims); + DefRepeatedProperty(op_sharding, "iota_transpose_perm", + &xla::OpSharding::mutable_iota_transpose_perm); + DefRepeatedProperty(op_sharding, "tuple_shardings", + &xla::OpSharding::mutable_tuple_shardings); + DefRepeatedEnumProperty(op_sharding, "last_tile_dims", + &xla::OpSharding::mutable_last_tile_dims); + + nb::class_ hlo_sharding(m, "HloSharding"); + hlo_sharding + .def_static("from_proto", + xla::ValueOrThrowWrapper(xla::HloSharding::FromProto)) + .def_static("from_string", xla::ValueOrThrowWrapper(xla::ParseSharding)) + .def_static( + "tuple_sharding", + [](xla::Shape shape, + std::vector shardings) -> xla::HloSharding { + return HloSharding::Tuple(shape, shardings); + }, + "Constructs a tuple sharding.") + .def_static( + "iota_tile", xla::ValueOrThrowWrapper(IotaTileHelper), + nb::arg("dims"), + nb::arg("reshape_dims") = absl::Span(), + nb::arg("transpose_perm") = absl::Span(), + nb::arg("subgroup_types") = absl::Span()) + .def_static("manual", [] { return HloSharding::Manual(); }) + .def_static("replicate", [] { return HloSharding::Replicate(); }) + .def_static("unknown", [] { return HloSharding::Unknown(); }) + .def_static( + "subgroup_with_device_ordering", + xla::ValueOrThrowWrapper(SubgroupWithTileAssignmentHelper), + nb::arg("tile_assignment"), + nb::arg("subgroup_types") = absl::Span()) + .def("__eq__", [](const xla::HloSharding& a, + const xla::HloSharding& b) { return a == b; }) + .def("__hash__", + [](const xla::HloSharding& self) { return absl::HashOf(self); }) + 
.def("is_replicated", &xla::HloSharding::IsReplicated) + .def("is_manual", &xla::HloSharding::IsManual) + .def("is_unknown", &xla::HloSharding::IsUnknown) + .def("is_tiled", &xla::HloSharding::IsTiled) + .def("is_maximal", &xla::HloSharding::IsTileMaximal) + .def("tile", [](const xla::HloSharding& self, + xla::Shape shape) { return self.TileShape(shape); }) + // tile_assignment.array() is computed using an internal cache, + // which is why nb::lock_self() is required. It may be preferable to move + // this locking into the TileAssignment class if we find it to race with + // non-Python users of that class. + .def( + "tuple_elements", + [](const xla::HloSharding& self) { return self.tuple_elements(); }, + nb::lock_self()) + .def( + "num_devices", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_elements(); + }, + nb::lock_self()) + .def( + "num_dimensions", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_dimensions(); + }, + nb::lock_self()) + .def( + "tile_assignment_dimensions", + [](const xla::HloSharding& self) { + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def( + "tile_assignment_devices", + [](const xla::HloSharding& self) { + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def("replicate_on_last_tile_dim", + &xla::HloSharding::ReplicateOnLastTileDim) + .def("subgroup_types", &xla::HloSharding::subgroup_types) + .def("__repr__", + [](const xla::HloSharding& self) { return self.ToString(); }) + .def("to_proto", &xla::HloSharding::ToProto); + + nb::class_ frontend_attributes(m, "FrontendAttributes"); + frontend_attributes.def(nb::init<>()) + .def("__setitem__", + [](FrontendAttributes* attr, std::string key, std::string value) { + (*attr->mutable_map())[key] = value; + }); + + nb::enum_(m, "PrecisionConfig_Precision") + .value("DEFAULT", PrecisionConfig::DEFAULT) + .value("HIGH", PrecisionConfig::HIGH) + .value("HIGHEST", PrecisionConfig::HIGHEST); + + nb::enum_(m, "ResultAccuracy_Mode") + .value("DEFAULT", ResultAccuracy::DEFAULT) + .value("HIGHEST", ResultAccuracy::HIGHEST); + + nb::enum_(m, "FftType") + .value("FFT", FftType::FFT) + .value("IFFT", FftType::IFFT) + .value("RFFT", FftType::RFFT) + .value("IRFFT", FftType::IRFFT); + + // Hlo Module Passes + nb::class_ hlo_pass_interface(m, "HloPassInterface"); + hlo_pass_interface.def_prop_ro("name", &HloPassInterface::name) + .def("is_pass_pipeline", &HloPassInterface::IsPassPipeline) + .def("run", + [](HloPassInterface& pass, HloModule* module) -> bool { + return xla::ValueOrThrow(pass.Run(module)); + }) + .def("run_on_module_group", + [](HloPassInterface& pass, HloModuleGroup* module_group) -> bool { + return xla::ValueOrThrow(pass.RunOnModuleGroup(module_group)); + }); + + nb::class_(m, "HloDCE").def(nb::init<>()); + nb::class_(m, "CallInliner").def(nb::init<>()); + nb::class_(m, "FlattenCallGraph") + .def(nb::init<>()); + nb::class_(m, "TupleSimplifier") + .def(nb::init<>()); +} // NOLINT(readability/fn_size) +} // namespace xla diff --git a/jaxlib/xla/xla_compiler.h b/jaxlib/xla/xla_compiler.h new file mode 100644 index 000000000000..f3ffe5fe9440 --- /dev/null +++ b/jaxlib/xla/xla_compiler.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+namespace xla {
+
+void BuildXlaCompilerSubmodule(nanobind::module_& m);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_

From ff718862f2b632a1f8e75c4b2c16f48e0f27724d Mon Sep 17 00:00:00 2001
From: Gleb Pobudzey
Date: Mon, 24 Mar 2025 13:49:18 -0700
Subject: [PATCH 0129/1769] [Mosaic GPU] Adding a new layout
 WGMMAColFragLayout to be able to load a 1d array and broadcast it along the
 leading dimension to a 2d shape as an input to a wgmma.

In this new layout the first 4 threads of a warp group hold 8 unique values.
These values are replicated in each (thread_idx % 4) group.

PiperOrigin-RevId: 740058172
---
 .../mosaic/gpu/fragmented_array.py            | 95 ++++++++++++++++++-
 tests/mosaic/gpu_test.py                      | 34 +++++++
 2 files changed, 127 insertions(+), 2 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index c2b61c6d5bfe..b730e34e2ed0 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -420,6 +420,23 @@ def thread_idxs(self, shape):
       yield (row,)
 
 
+@dataclasses.dataclass(frozen=True)
+class WGMMAColFragLayout:
+  """[n] matrix, where n % 8 == 0."""
+
+  def thread_idxs(self, shape):
+    index = ir.IndexType.get()
+    assert len(shape) == 1
+    assert shape[0] % 8 == 0
+
+    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
+    lane_id = arith.remui(tid, c(WARP_SIZE, index))
+    col_base = arith.muli(arith.remui(lane_id, c(4, index)), c(2, index))
+
+    for col_group in range(0, shape[0], 8):
+      col = arith.addi(col_base, c(col_group, index))
+      yield (col,)
+
 @dataclasses.dataclass(frozen=True)
 class WGSplatFragLayout:
   """A fragmented array where all the values are equal represented as a register per thread.
@@ -530,10 +547,11 @@ def linear_thread_idxs(self):
       yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))
 
 
-FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | TiledLayout
+FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | WGMMAColFragLayout | TiledLayout
 
 
 WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
+WGMMA_COL_LAYOUT = WGMMAColFragLayout()
 
 # The tiled layout is equivalent to one described here in PTX documentation:
 # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d
@@ -651,6 +669,12 @@ def __init__(
         if _registers.ndim != 2 or _registers.shape[-1] != 2:
           raise ValueError(f"Invalid register array shape: {_registers.shape}")
 
+      # Registers are [n_tiles] in WGMMA_COL layout
+      # Each element is a vector of size 2.
+ case WGMMAColFragLayout(): + if _registers.ndim != 1: + raise ValueError(f"Invalid register array shape: {_registers.shape}") + # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape @@ -731,6 +755,36 @@ def load_wgmma_row( registers = np.array(registers).reshape(-1, 2) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) + @classmethod + def load_wgmma_col( + cls, + ref: ir.Value, + *, + is_signed: bool | None = None, + ): + if not ir.MemRefType.isinstance(ref.type): + raise TypeError(ref.type) + + ref_ty = ir.MemRefType(ref.type) + shape = tuple(ref_ty.shape) + layout = WGMMAColFragLayout() + + if len(shape) != 1: + raise ValueError("WGMMAColFragLayout requires a 1D shape.") + + if shape[0] % 8: + raise ValueError( + f"WGMMAColFragLayout requires {shape[0]=} to be a multiple of 8." + ) + + vec_ty = ir.VectorType.get((2,), ref_ty.element_type) + new_regs = np.full((shape[0] // 8,), llvm.mlir_undef(vec_ty)) + + for col_tile, (idx,) in enumerate(layout.thread_idxs(shape)): + reg = vector.load(vec_ty, ref, [idx]) + new_regs[col_tile] = reg + + return cls(_registers=new_regs, _layout=layout, _is_signed=is_signed) @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @@ -755,6 +809,9 @@ def shape(self): case WGMMARowFragLayout(): row_tiles = self.registers.shape[0] return (row_tiles * 64,) + case WGMMAColFragLayout(): + col_tiles = self.registers.shape[0] + return (col_tiles * 8,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): @@ -768,7 +825,7 @@ def shape(self): def mlir_dtype(self): reg_ty = self.registers.flat[0].type match self.layout: - case WGStridedFragLayout() | TiledLayout(): + case WGStridedFragLayout() | WGMMAColFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type case WGMMARowFragLayout() | WGSplatFragLayout(): return reg_ty @@ -1745,6 +1802,23 @@ def broadcast_minor(self, n): _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed ) + def broadcast_major(self, m): + if not isinstance(self.layout, WGMMAColFragLayout): + raise NotImplementedError + + if m % 64: + raise ValueError("Number of rows must be divisible by 64") + + reg_shape = WGMMA_LAYOUT.registers_shape((m, self.shape[0])) + new_regs = np.empty(reg_shape, dtype=object) + for col_tile, reg in np.ndenumerate(self.registers): + tile = [slice(None)] * len(new_regs.shape) + tile[1] = col_tile + new_regs[tuple(tile)] = reg + return FragmentedArray( + _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + ) + def select(self, on_true, on_false): if ( not ir.IntegerType.isinstance(self.mlir_dtype) @@ -1802,6 +1876,8 @@ def vs_unsupported(): match self.layout: case WGMMARowFragLayout(): self._store_untiled_wgmma_row(ref) + case WGMMAColFragLayout(): + self._store_untiled_wgmma_col(ref) case WGSplatFragLayout(): vs_unsupported() self._store_untiled_splat(ref) @@ -1865,6 +1941,21 @@ def _store_untiled_wgmma_row(self, ref: ir.Value): ): memref.store(value, ref, [idx]) + def _store_untiled_wgmma_col(self, ref: ir.Value): + """Stores an array with a WGMMA col layout.""" + assert isinstance(self.layout, WGMMAColFragLayout) + index = ir.IndexType.get() + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) + + # Consecutive groups of 4 threads replicate the same data, so we only need to + # transfer data from one group. 
+ is_first = arith.cmpi(arith.CmpIPredicate.ult, tid_wg, c(4, index)) + + with utils.when(is_first): + for (idx,), reg in zip(self.layout.thread_idxs(self.shape), self.registers): + vector.store(reg, ref, [idx]) + def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 478064188750..d9f56ee1d454 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1982,6 +1982,40 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) + @parameterized.product( + in_shape=((128,), (64,)), dtype=[jnp.float16, jnp.float32] + ) + def test_wgmma_col_load_store_with_layout(self, in_shape, dtype): + def kernel(ctx, *args): + gmem_input, gmem_output, (smem_input, smem_output) = args + copy(gmem_input, smem_input) + t = mgpu.FragmentedArray.load_wgmma_col(smem_input) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(result, inp) + + @parameterized.parameters((128, 128), (128, 64), (64, 128)) + def test_broadcast_major(self, m, n): + def kernel(ctx, *args): + gmem_input, gmem_output, () = args + t = mgpu.FragmentedArray.load_wgmma_col(gmem_input) + t.broadcast_major(m).store_untiled(gmem_output) + + inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) + + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, () + )(inp) + + out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) + np.testing.assert_array_equal(result, out_ref) + def test_warp_tree_reduce(self): def kernel(ctx, out, *_): del ctx From e75226392384824fecbd8859b3f2a4da84f898ab Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 24 Mar 2025 13:49:48 -0700 Subject: [PATCH 0130/1769] [pallas] Index Pallas refs instead of using `pl.load` and `pl.store` Indexing is less verbose and is thus easier to read in most cases. The functional API is really only necessary for masked loads and stores. PiperOrigin-RevId: 740058341 --- .../pallas/ops/tpu/flash_attention.py | 116 ++++++++---------- .../splash_attention_kernel.py | 40 +++--- tests/pallas/ops_test.py | 2 +- tests/pallas/pallas_test.py | 26 ++-- tests/pallas/tpu_pallas_test.py | 17 ++- 5 files changed, 85 insertions(+), 116 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 0cb3d798d09e..ef8dd61abacb 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -391,9 +391,9 @@ def body(i, _): l_prev = l_scratch_ref[batch_idx] q = q_tile_ref[batch_idx] # [block_q, head_dim] start_k = i * block_k - k = pl.load( - k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) # [block_k, head_dim] + k = k_tile_ref[ + (*batch_idx, pl.dslice(start_k, block_k), slice(None)) + ] # [block_k, head_dim] s = jax.lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 @@ -403,10 +403,9 @@ def body(i, _): # TODO(tanburn) Should the attention bias be added before or after # multiplication by sm_scale? 
if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, + ab = ab_tile_ref[ (*batch_idx, pl.dslice(None), pl.dslice(start_k, block_k)) - ).astype(jnp.float32) + ].astype(jnp.float32) s += ab if sm_scale != 1.0: @@ -422,10 +421,9 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, - (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)), - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + batch_idx[0], :1, pl.dslice(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -471,9 +469,7 @@ def body(i, _): l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe) - v = pl.load( - v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) + v = v_tile_ref[(*batch_idx, pl.dslice(start_k, block_k), slice(None))] o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32 ) @@ -529,15 +525,13 @@ def _flash_attention_kernel_single_batch_single_step( raise NotImplementedError( f"kv block size must be a multiple of {NUM_LANES}" ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (batch_idx[0],) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + batch_idx[0] + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[batch_idx[0], :1] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -840,33 +834,27 @@ def q_body(j, _): start_q = j * block_q def k_body(i, _): start_k = i * block_k - k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - q = pl.load(q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, head_dim] - l = pl.load(l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - m = pl.load(m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - do = pl.load(do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - di = pl.load(di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ).astype(jnp.float32) # [block_q, 128] + k = k_tile_ref[0, 0, pl.ds(start_k, block_k), :] + v = v_tile_ref[0, 0, pl.ds(start_k, block_k), :] + q = q_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, head_dim] + l = l_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + m = m_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + do = do_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + di = di_tile_ref[0, 0, pl.ds(start_q, block_q), :].astype( + jnp.float32 + ) # [block_q, 128] capped_logits = lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 ) # [block_q_major, block_k] if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, - ( - 0, - 0, - pl.dslice(j * block_q, block_q), - pl.dslice(i * block_k, block_k), - ), - ).astype(jnp.float32) + ab = ab_tile_ref[ + 0, + 0, + pl.dslice(j * block_q, block_q), + pl.dslice(i * block_k, block_k), + ].astype(jnp.float32) capped_logits += ab if sm_scale != 1.0: @@ -878,15 +866,15 @@ def k_body(i, _): if rem: raise NotImplementedError( ) - q_segment_ids = pl.load( - 
q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + 0, pl.ds(start_q, block_q), : + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + :, 0, pl.ds(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -913,9 +901,9 @@ def k_body(i, _): 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 ) # [block_q_major, block_k_major] dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32) - pl.store(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dv.astype(dv_scratch_ref.dtype)) + dv_scratch_ref[pl.ds(start_k, block_k), :] += dv.astype( + dv_scratch_ref.dtype + ) # di: [block_q, 128] # do: [block_q, head_dim] @@ -931,9 +919,9 @@ def k_body(i, _): # ds: [block_q_major, block_k_major] # q: [block_q_major, head_dim] dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32) - pl.store(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dk.astype(dk_scratch_ref.dtype)) + dk_scratch_ref[pl.ds(start_k, block_k), :] += dk.astype( + dk_scratch_ref.dtype + ) lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True) if causal: @@ -1192,12 +1180,8 @@ def start_new_sequence(): def body(i, _): k_slice = pl.ds(i * block_k, block_k) q = q_tile_ref[0, 0, :, :] - k = pl.load( - k_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] - v = pl.load( - v_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] + k = k_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] + v = v_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] l = l_tile_ref[0, 0, :, :] # [block_q_major, 128] m = m_tile_ref[0, 0, :, :] # [block_q_major, 128] do = do_tile_ref[0, 0, :, :] # [block_q_major, head_dim] @@ -1208,9 +1192,9 @@ def body(i, _): ) if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)) - ).astype(jnp.float32) + ab = ab_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)].astype( + jnp.float32 + ) capped_logits += ab if sm_scale != 1.0: @@ -1226,9 +1210,7 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[0], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, k_slice) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[:, 0, k_slice] # [1, block_k]. 
mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -1269,10 +1251,8 @@ def body(i, _): ds = ds * sm_scale if ds_tile_ref is not None: - pl.store( - ds_tile_ref, - (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)), - ds.astype(ds_tile_ref.dtype), + ds_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)] = ds.astype( + ds_tile_ref.dtype ) # dp: [block_q_major, block_k] diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index d0fb6f2f9670..b69b0e36f177 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -599,9 +599,9 @@ def _apply_mask_and_soft_cap( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] masks.append( jnp.bitwise_or(mask, jnp.broadcast_to(should_not_mask, mask.shape)) @@ -630,7 +630,7 @@ def _apply_mask_and_soft_cap( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -644,7 +644,7 @@ def _apply_mask_and_soft_cap( if q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") @@ -655,9 +655,9 @@ def _apply_mask_and_soft_cap( if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_segment_ids_ref[k_slice, :], repeats, axis=1 ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) def cap_logits(logits): @@ -743,9 +743,9 @@ def body(kv_compute_index, _): q = q_ref[...] if q_layout == HEAD_DIM_MINOR else q_ref[...].T qk_dims = NT_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR: - k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] else: - k = pl.load(k_ref, (slice(None), slice_k)) + k = k_ref[:, slice_k] qk = lax.dot_general(q, k, qk_dims, preferred_element_type=float32) assert qk.shape == (bq, bkv_compute) @@ -794,9 +794,9 @@ def body(kv_compute_index, _): sv_dims = NN_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR: - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] else: - v = pl.load(v_ref, (slice(None), slice_k)) + v = v_ref[:, slice_k] v = v.astype(float32) o_curr = lax.dot_general(s_curr, v, sv_dims) @@ -1688,13 +1688,13 @@ def body(i, _): q = q_ref[...] 
# We keep q potentially transposed, since it's always RHS def _load_kv(ref, layout): if layout == HEAD_DIM_MINOR: - return pl.load(ref, (slice_k, slice(None))) - return pl.load(ref, (slice(None), slice_k)).T + return ref[slice_k, :] + return ref[:, slice_k].T k = _load_kv(k_ref, k_layout) v = _load_kv(v_ref, v_layout) - logsumexp = pl.load(logsumexp_ref, (pl.ds(1), slice(None))) + logsumexp = logsumexp_ref[:1, :] do = do_ref[...] - di = pl.load(di_ref, (pl.ds(1), slice(None))) + di = di_ref[:1, :] qk_dims = NT_DIM_NUMBERS if q_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS qk_uncapped = lax.dot_general( @@ -1718,10 +1718,8 @@ def _load_kv(ref, layout): ) p = jnp.exp(qk - logsumexp) dv = lax.dot(p.astype(do.dtype), do, preferred_element_type=jnp.float32) - dv = dv.astype(dv_scratch_ref.dtype) + pl.load( - dv_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dv_scratch_ref, (slice_k, slice(None)), dv) + dv = dv.astype(dv_scratch_ref.dtype) + dv_scratch_ref[slice_k, :] + dv_scratch_ref[slice_k, :] = dv dp = lax.dot_general( v, do, NT_DIM_NUMBERS, @@ -1737,10 +1735,8 @@ def _load_kv(ref, layout): dk = lax.dot_general( ds.astype(do.dtype), q, dk_dims, preferred_element_type=jnp.float32 ) - dk = dk.astype(dk_scratch_ref.dtype) + pl.load( - dk_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dk_scratch_ref, (slice_k, slice(None)), dk) + dk = dk.astype(dk_scratch_ref.dtype) + dk_scratch_ref[slice_k, :] + dk_scratch_ref[slice_k, :] = dk if dq_scratch_ref is not None or dq_ref is not None: dq = lax.dot_general( ds.T.astype(k.dtype), k, NN_DIM_NUMBERS, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 38426747d85d..8d5dc471e847 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1937,7 +1937,7 @@ def test_masked_oob_load_store_slice(self): def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)), mask=mask_ref[:], other=-1.) - pl.store(o_ref, (pl.dslice(None),), x) + o_ref[...] 
= x x = random.normal(random.key(0), (n,)) slice_start = random.randint(random.key(2), (), 1, n) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 9e5130b8f449..781934ecd682 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -128,8 +128,8 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): def matmul_kernel(x_ref, y_ref, o_ref): acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) def body(i, acc_ref): - x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) - y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) + x_block = x_ref[:, pl.ds(i * bk, bk)] + y_block = y_ref[pl.ds(i * bk, bk), :] acc_ref[:, :] += pl.dot(x_block, y_block) acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) o_ref[:, :] = acc @@ -624,8 +624,9 @@ def test_unused_ref(self): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), ) def dummy(_, o_ref): - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), - jnp.ones_like(o_ref)) + o_ref[jnp.arange(m)[:, None], jnp.arange(n)[None, :]] = jnp.ones_like( + o_ref + ) key = random.key(0) x = random.normal(key, (m, n)) @@ -667,8 +668,7 @@ def test_using_pallas_slice(self): out_shape=out_shape, ) def slice_kernel(x_ref, y_ref): - x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) - pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) + y_ref[:4, :4] = x_ref[:4, :4] x = random.normal(random.key(0), (m, n)) y = slice_kernel(x) y_ref = x[:4] @@ -1733,7 +1733,7 @@ def test_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(carry): i, j = carry @@ -1745,8 +1745,7 @@ def body(carry): sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) v = x_ref[0, sl, l] - s = pl.load(r_ref, (0, 0)) - pl.store(r_ref, (0, 0), s + v) + r_ref[0, 0] += v return io + 1, j i = 128 @@ -1798,7 +1797,7 @@ def test_non_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(state): i, s = state @@ -1808,14 +1807,11 @@ def body(state): i, s = state sl = jax.lax.div(i, jnp.astype(128, i.dtype)) l = jax.lax.rem(i, jnp.astype(128, i.dtype)) - v = pl.load(x_ref, (0, sl, l)) + v = x_ref[0, sl, l] return i + 1, s + v i = jnp.int32(0) - s = pl.load(r_ref, (0, 0)) - - i, s = jax.lax.while_loop(cond, body, (i, s)) - pl.store(r_ref, (0, 0), s) + _, r_ref[0, 0] = jax.lax.while_loop(cond, body, (i, r_ref[0, 0])) x = jnp.arange(4096) x = jnp.reshape(x, [4, 8, 128]) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 55831ff6af1d..128fe50687a0 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -145,8 +145,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) out = self.pallas_call( body, @@ -225,7 +224,7 @@ def kernel(s_refs, src, to_store, dst, *scratch_refs): assert s2.shape == (3,) assert s3 is None store_idx = s_ref[pl.program_id(0)] - pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...]) + dst[pl.dslice(store_idx, 1), :] = to_store[...] 
# Pass a pytree of scalar return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store) @@ -281,7 +280,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) def f(x): @@ -423,7 +422,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = s[None] @@ -457,7 +456,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = jnp.tile(s[None], [2, 1]) @@ -1139,8 +1138,7 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): - pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], - sem).wait() + pltpu.async_copy(x_hbm_ref.at[:8, :], y_hbm_ref.at[:, :128], sem).wait() pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( @@ -2570,8 +2568,7 @@ def body(scalar_ref, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) pallas_call = self.pallas_call( body, From 777d8f27408ba519579fbc8307cb4bef0572fb9f Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 24 Mar 2025 14:16:26 -0700 Subject: [PATCH 0131/1769] [Mosaic GPU] Adding pallas bindings to broadcast over the leading dimension and load a ref into WGMMAColFragLayout format. PiperOrigin-RevId: 740068368 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 ++++++ jax/_src/pallas/mosaic_gpu/primitives.py | 9 ++++++++ jax/experimental/mosaic/gpu/__init__.py | 2 ++ tests/pallas/mosaic_gpu_test.py | 27 ++++++++++++++++++------ 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 004a6e7f2760..607b3028f93b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1262,6 +1262,12 @@ def _broadcast_in_dim_lowering_rule( and x.layout == mgpu.WGMMA_ROW_LAYOUT ): return x.broadcast_minor(y_aval.shape[-1]) + if ( + broadcast_dimensions == (1,) + and y_aval.ndim == x_aval.ndim + 1 + and x.layout == mgpu.WGMMA_COL_LAYOUT + ): + return x.broadcast_major(y_aval.shape[-2]) if broadcast_dimensions: raise NotImplementedError return x.broadcast(shape) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index a27137964349..9665f14254f8 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -110,6 +110,10 @@ def _load_p_lowering_rule( return mgpu.FragmentedArray.load_wgmma_row( x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) ) + case mgpu.WGMMAColFragLayout(): + return mgpu.FragmentedArray.load_wgmma_col( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): ref_ty = ir.MemRefType(x_ref.type) if shape != tuple(ref_ty.shape): @@ -878,6 +882,8 @@ class Layout(enum.Enum): WGMMA = enum.auto() #: [m] matrix, where m % 64 == 0. WGMMA_ROW = enum.auto() + #: [n] matrix, where n % 8 == 0. 
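+  #: Loading produces a WGMMAColFragLayout value, which broadcast_in_dim
+  #: can then expand along the leading (major) dimension.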
+ WGMMA_COL = enum.auto() WG_SPLAT = enum.auto() WG_STRIDED = enum.auto() @@ -897,6 +903,9 @@ def check_no_args(): case Layout.WGMMA_ROW: check_no_args() return mgpu.WGMMA_ROW_LAYOUT + case Layout.WGMMA_COL: + check_no_args() + return mgpu.WGMMA_COL_LAYOUT case Layout.WG_SPLAT: return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter case Layout.WG_STRIDED: diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d004c7deb3df..867fd84b8b3c 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -54,7 +54,9 @@ FragmentedLayout as FragmentedLayout, WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, + WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, WGMMARowFragLayout as WGMMARowFragLayout, + WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, optimization_barrier as optimization_barrier, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d31d1c9d41b2..b33857df40b6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -638,6 +638,7 @@ def kernel(x_ref, o_ref, barrier_ref): src_memory_space=[plgpu.SMEM, plgpu.GMEM], layout=[ plgpu.Layout.WGMMA_ROW, + plgpu.Layout.WGMMA_COL, plgpu.Layout.WG_STRIDED((128,), vec_size=1), None, ], @@ -661,15 +662,27 @@ def kernel(x_ref, o_ref): x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) np.testing.assert_array_equal(f(x), x) - @parameterized.product(src_memory_space=[plgpu.SMEM, plgpu.GMEM]) - def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space): + @parameterized.product(src_memory_space=[plgpu.SMEM], + layout=[ + plgpu.Layout.WGMMA_ROW, + plgpu.Layout.WGMMA_COL, + ],) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): m, k, n = 64, 128, 192 key1, key2 = jax.random.split(jax.random.key(42), 2) - a = jax.random.uniform(key1, shape=(m,), dtype=jnp.float16) + if layout == plgpu.Layout.WGMMA_ROW: + input_shape = (m,) + broadcast_dim = 0 + expand_dim = 1 + else: + input_shape = (k,) + broadcast_dim = 1 + expand_dim = 0 + a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16) b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) def kernel(x_ref, y_ref, o_ref): - x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA_ROW) - x = lax.broadcast_in_dim(x, (m, k), [0]) + x = plgpu.load(x_ref, (), layout=layout) + x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim]) def compute(acc_ref): plgpu.wgmma(acc_ref, x, y_ref) @@ -697,7 +710,9 @@ def compute(acc_ref): out_specs=out_spec, ) - out_ref = jnp.broadcast_to(a[:, None], (m, k)) @ b + out_ref = ( + jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b + ) np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) def test_indexing_before_transpose(self): From 60b3e5156aed252132f54ab6ea337935dfaa2804 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 24 Mar 2025 15:02:44 -0700 Subject: [PATCH 0132/1769] Reduced sharding in various tests. 
PiperOrigin-RevId: 740084295 --- tests/BUILD | 10 +++++----- tests/pallas/BUILD | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 4fab173b4e15..0a8fb9459044 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -596,9 +596,9 @@ jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 10, + "gpu": 10, + "tpu": 10, }, deps = [ "//jax:internal_test_util", @@ -1432,8 +1432,8 @@ jax_multiplatform_test( "gpu_p100x2_shardy", ], shard_count = { - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, tags = [ "multiaccelerator", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 7581ff78802b..1ea05c700938 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -328,7 +328,7 @@ jax_multiplatform_test( "tpu_gmm_test.py", ], enable_backends = ["tpu"], - shard_count = 50, + shard_count = 5, tags = [ "noasan", # Times out. "nomsan", # Times out. From 5f1ab2ee6713934b37b5d272e67494a1816cbdba Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 15:39:22 -0700 Subject: [PATCH 0133/1769] Skip checking of manylinux compliance for `jax` wheel. If `auditwheel show` is executed on `jax` wheel, the following message is printed: ``` INFO:auditwheel.main_show:This does not look like a platform wheel, no ELF executable or shared library file (including compiled Python C extension) found in the wheel archive ``` PiperOrigin-RevId: 740096302 --- ci/utilities/run_auditwheel.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 30b6a3b51865..b8f80c3e6778 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -26,6 +26,10 @@ if [[ -z "$WHEELS" ]]; then fi for wheel in $WHEELS; do + # Skip checking manylinux compliance for jax wheel. + if [[ "$wheel" =~ 'jax-' ]]; then + continue + fi printf "\nRunning auditwheel on the following wheel:" ls $wheel OUTPUT_FULL=$(python -m auditwheel show $wheel) From c1904dc7eb6e74c85daae57b4d79dfaf353f850f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 24 Mar 2025 16:05:45 -0700 Subject: [PATCH 0134/1769] Update the docstring to mesh to use computation follows data and jax.jit APIs. Fixes https://github.com/jax-ml/jax/issues/27390 PiperOrigin-RevId: 740104692 --- jax/_src/mesh.py | 40 +++++++++------------------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index b490febf7b0c..a8003e693459 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -194,16 +194,9 @@ def _name_to_type(self): class Mesh(_BaseMesh, contextlib.ContextDecorator): """Declare the hardware resources available in the scope of this manager. - In particular, all ``axis_names`` become valid resource names inside the - managed block and can be used e.g. in the ``in_axis_resources`` argument of - :py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming - model (https://jax.readthedocs.io/en/latest/multi_process.html) - and the Distributed arrays and automatic parallelization tutorial + See the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) - - If you are compiling in multiple threads, make sure that the - ``with Mesh`` context manager is inside the function that the threads will - execute. 
+ and Explicit sharding tutorial (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) Args: devices: A NumPy ndarray object containing JAX device objects (as @@ -214,32 +207,17 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): Examples: - >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P + >>> from jax.sharding import PartitionSpec as P, NamedSharding >>> import numpy as np ... - >>> inp = np.arange(16).reshape((8, 2)) - >>> devices = np.array(jax.devices()).reshape(4, 2) - ... >>> # Declare a 2D mesh with axes `x` and `y`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> # Use the mesh object directly as a context manager. - >>> with global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Initialize the Mesh and use the mesh as the context manager. - >>> with Mesh(devices, ('x', 'y')) as global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Also you can use it as `with ... as ...`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> with global_mesh as m: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # You can also use it as `with Mesh(...)`. - >>> with Mesh(devices, ('x', 'y')): - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + >>> devices = np.array(jax.devices()).reshape(4, 2) + >>> mesh = Mesh(devices, ('x', 'y')) + >>> inp = np.arange(16).reshape(8, 2) + >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) + >>> out = jax.jit(lambda x: x * 2)(arr) + >>> assert out.sharding == NamedSharding(mesh, P('x', 'y')) """ devices: np.ndarray From 49aad1b97fa937583ff21df194fae4cd50be20eb Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 16:39:35 -0700 Subject: [PATCH 0135/1769] Add the missing `flatbuffers` dependency for the tests that run under `:build_jaxlib=false`. PiperOrigin-RevId: 740115575 --- BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/BUILD.bazel b/BUILD.bazel index 5700fcef2e77..ebf852a60924 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -125,6 +125,7 @@ COMMON_DEPS = py_deps([ "opt_einsum", "hypothesis", "cloudpickle", + "flatbuffers", ]) py_import( From 51560bf3f55e50469377f08151a94d36bef0b655 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 18:24:22 -0700 Subject: [PATCH 0136/1769] [JAX] [XLA:Python] Migrate pytree module to JAX. 
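The pytree module is the C++ machinery behind `jax.tree_util`'s
flatten/unflatten APIs. As an illustrative sketch of the behavior it backs
(example only, not part of this change):

```
import jax

leaves, treedef = jax.tree_util.tree_flatten({"a": 1, "b": (2, 3)})
# leaves == [1, 2, 3]; treedef records the dict/tuple structure.
assert jax.tree_util.tree_unflatten(treedef, leaves) == {"a": 1, "b": (2, 3)}
```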
PiperOrigin-RevId: 740142231 --- jaxlib/xla/BUILD | 49 +- jaxlib/xla/jax_jit.cc | 2 +- jaxlib/xla/jax_jit.h | 2 +- jaxlib/xla/pjit.cc | 2 +- jaxlib/xla/pmap_lib.cc | 2 +- jaxlib/xla/pytree.cc | 1825 +++++++++++++++++++++++++++++++++++++++ jaxlib/xla/pytree.h | 408 +++++++++ jaxlib/xla/pytree.proto | 32 + jaxlib/xla/xla.cc | 2 +- 9 files changed, 2315 insertions(+), 9 deletions(-) create mode 100644 jaxlib/xla/pytree.cc create mode 100644 jaxlib/xla/pytree.h create mode 100644 jaxlib/xla/pytree.proto diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 592d9d1c24f3..2edc183bc49b 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -14,13 +14,16 @@ load( "//jaxlib:jax.bzl", + "cc_proto_library", "if_oss", + "jax_visibility", "nanobind_extension", "py_deps", "py_strict_library", "py_strict_test", "pytype_strict_library", ) +# Placeholder: load proto_library licenses(["notice"]) @@ -50,6 +53,7 @@ nanobind_extension( ":mlir", ":pjit", ":pmap_lib", + ":pytree", ":sdy", ":weakref_lru_cache", ":xla_compiler", @@ -103,7 +107,6 @@ nanobind_extension( "@xla//xla/python:profiler", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:traceback", "@xla//xla/python:types", @@ -250,6 +253,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":pytree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -270,7 +274,6 @@ cc_library( "@xla//xla/python:nb_helpers", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", ], @@ -328,6 +331,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":pytree", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -352,7 +356,6 @@ cc_library( "@xla//xla/python:nb_numpy", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:traceback", "@xla//xla/python/ifrt", "@xla//xla/tsl/concurrency:ref_count", @@ -376,6 +379,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":pytree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", @@ -399,7 +403,6 @@ cc_library( "@xla//xla/python:nb_numpy", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", @@ -411,6 +414,44 @@ cc_library( ], ) +proto_library( + name = "pytree_proto", + srcs = ["pytree.proto"], +) + +cc_proto_library( + name = "pytree_cc_proto", + deps = [":pytree_proto"], +) + +cc_library( + name = "pytree", + srcs = ["pytree.cc"], + hdrs = ["pytree.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/pytree"), + deps = [ + ":pytree_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", 
+ "@xla//xla/python:nb_class_ptr", + "@xla//xla/tsl/platform:logging", + ], +) + cc_library( name = "sdy", srcs = ["sdy.cc"], diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/xla/jax_jit.cc index 754272a078ed..23abe9a8404a 100644 --- a/jaxlib/xla/jax_jit.cc +++ b/jaxlib/xla/jax_jit.cc @@ -53,13 +53,13 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/pytree.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/py_values.h" -#include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/types.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index 303d7e69414d..a000ef6773b2 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -34,11 +34,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/pytree.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/nb_helpers.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 96056708c2fb..a13d3f0b52e3 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -52,6 +52,7 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/pytree.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" @@ -69,7 +70,6 @@ limitations under the License. #include "xla/python/py_executable.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc index 5582eccf4f8b..c6849f8c25fd 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/xla/pmap_lib.cc @@ -46,6 +46,7 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/pytree.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -64,7 +65,6 @@ limitations under the License. #include "xla/python/py_executable.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/pytree.h" #include "xla/python/sharded_device_array.h" #include "xla/python/sharding.h" #include "xla/python/to_ifrt_sharding.h" diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc new file mode 100644 index 000000000000..dd5a0bd9cf69 --- /dev/null +++ b/jaxlib/xla/pytree.cc @@ -0,0 +1,1825 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. + +#include "jaxlib/xla/pytree.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/pytree.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/tsl/platform/logging.h" + +namespace xla { + +namespace nb = nanobind; + +constexpr int kSequenceKeyHashSalt = 1; +constexpr int kFlattenedIndexKeyHashSalt = 42; + +PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, + bool enable_namedtuple, bool enable_list, + bool enable_dict) { + auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) { + nb::object type = + nb::borrow(reinterpret_cast(type_obj)); + auto registration = std::make_unique(); + registration->kind = kind; + registration->type = type; + CHECK(registrations_.emplace(type, std::move(registration)).second); + }; + if (enable_none) { + add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone); + } + if (enable_tuple) { + add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple); + } + enable_namedtuple_ = enable_namedtuple; + if (enable_list) { + add_builtin_type(&PyList_Type, PyTreeKind::kList); + } + if (enable_dict) { + add_builtin_type(&PyDict_Type, PyTreeKind::kDict); + } +} + +void PyTreeRegistry::Register( + nb::object type, nb::callable to_iterable, nb::callable from_iterable, + std::optional to_iterable_with_keys) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kCustom; + registration->type = type; + registration->to_iterable = std::move(to_iterable); + registration->from_iterable = std::move(from_iterable); + registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument( + absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", + nb::cast(nb::repr(type)))); + } +} + +void PyTreeRegistry::RegisterDataclass(nb::object type, + std::vector data_fields, + std::vector meta_fields) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kDataclass; + 
registration->type = type; + registration->data_fields = std::move(data_fields); + registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument(absl::StrFormat( + "Duplicate custom dataclass PyTreeDef type registration for %s.", + nb::cast(nb::repr(std::move(type))))); + } +} + +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + +std::pair>, nb::object> +PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { + // Backwards compatibility case: return dummy FlattenedIndexKey for each leaf. + std::vector> result; + if (!to_iterable_with_keys.has_value()) { + auto [leaves, aux_data] = ToIterable(o); + for (nb::handle leaf : leaves) { + result.push_back(std::make_pair( + make_nb_class(result.size()), nb::borrow(leaf))); + } + return std::make_pair(std::move(result), std::move(aux_data)); + } + nb::object out = to_iterable_with_keys.value()(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree " + "node should return a (key_leaf_pairs, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable key_leaf_pairs; + if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " + "iterable, got ", + nb::cast(nb::repr(leaves_and_aux_data)))); + } + for (nb::handle key_leaf_pair : key_leaf_pairs) { + nb::tuple key_leaf_pair_tuple; + if (!nb::try_cast(key_leaf_pair, key_leaf_pair_tuple) || + key_leaf_pair_tuple.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'child", + nb::cast(nb::repr(key_leaf_pair)))); + } + result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), + nb::borrow(key_leaf_pair_tuple[1]))); + } + return std::make_pair(std::move(result), nb::object(leaves_and_aux_data[1])); +} + +int PyTreeRegistry::Registration::tp_traverse(visitproc visit, void* arg) { + Py_VISIT(type.ptr()); + Py_VISIT(to_iterable.ptr()); + Py_VISIT(from_iterable.ptr()); + for (const auto& field : data_fields) { + Py_VISIT(field.ptr()); + } + for (const auto& field : meta_fields) { + Py_VISIT(field.ptr()); + } + return 0; +} + +// Computes the node kind of a given Python object. 
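+// Dispatch order: an explicit registration for the object's type wins;
+// otherwise a tuple subclass exposing a _fields attribute is treated as a
+// namedtuple, and anything else is a leaf.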
+PyTreeKind PyTreeRegistry::KindOfObject( + nb::handle obj, PyTreeRegistry::Registration const** custom) const { + const PyTreeRegistry::Registration* registration = Lookup(obj.type()); + if (registration) { + if (registration->kind == PyTreeKind::kCustom || + registration->kind == PyTreeKind::kDataclass) { + *custom = registration; + } else { + *custom = nullptr; + } + return registration->kind; + } else if (nb::isinstance(obj) && nb::hasattr(obj, "_fields")) { + // We can only identify namedtuples heuristically, here by the presence of + // a _fields attribute. + return PyTreeKind::kNamedTuple; + } else { + return PyTreeKind::kLeaf; + } +} + +/*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( + nb::handle type) const { + nb::ft_lock_guard lock(mu_); + auto it = registrations_.find(type); + return it == registrations_.end() ? nullptr : it->second.get(); +} + +/*static*/ std::vector GetSortedPyDictKeys(PyObject* py_dict) { + std::vector keys; + keys.reserve(PyDict_Size(py_dict)); + PyObject* key; + Py_ssize_t pos = 0; + while (PyDict_Next(py_dict, &pos, &key, /*value=*/nullptr)) { + keys.push_back(nb::borrow(key)); + } + + try { + std::stable_sort( + keys.begin(), keys.end(), [](const nb::object& a, const nb::object& b) { + int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); + if (cmp == -1) { + throw nb::python_error(); + } + return cmp; + }); + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Comparator raised exception while sorting pytree " + "dictionary keys."); + } + return keys; +} + +/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, + absl::Span rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (int i = 0; i < lhs.size(); ++i) { + if (lhs[i].not_equal(rhs[i])) { + return false; + } + } + return true; +} + +bool PyTreeDef::operator==(const PyTreeDef& other) const { + if (traversal_.size() != other.traversal_.size()) { + return false; + } + for (size_t i = 0; i < traversal_.size(); ++i) { + const Node& a = traversal_[i]; + const Node& b = other.traversal_[i]; + if (a.kind != b.kind || a.arity != b.arity || + (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || + (a.sorted_dict_keys.size() != b.sorted_dict_keys.size()) || + a.custom != b.custom) { + return false; + } + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + if (!IsSortedPyDictKeysEqual(a.sorted_dict_keys, b.sorted_dict_keys)) { + return false; + } + // We don't need to test equality of num_leaves and num_nodes since they + // are derivable from the other node data. 
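+    // (SetNumLeavesAndNumNodes rederives both fields from the rest.)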
+ } + return true; +} + +nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { + PyTreeRegistry::Registration const* custom; + PyTreeKind kind = KindOfObject(x, &custom); + switch (kind) { + case PyTreeKind::kNone: + return nb::make_tuple(nb::make_tuple(), nb::none()); + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(x); + std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); + nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); + nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); + for (size_t i = 0; i < sorted_keys.size(); ++i) { + nb::object& key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); + PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); + } + return nb::make_tuple(std::move(values), std::move(keys)); + } + case PyTreeKind::kNamedTuple: { + nb::tuple in = nb::borrow(x); + nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. 
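+        // The namedtuple check in KindOfObject is heuristic, so verify that
+        // _fields matches the tuple's arity before pairing keys with values.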
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } + for (size_t i = 0; i < in.size(); ++i) { + out.append(in[i]); + } + return nb::make_tuple(std::move(out), x.type()); + } + case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + auto [leaves, aux_data] = custom->ToIterable(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + case PyTreeKind::kDataclass: { + auto data_size = custom->data_fields.size(); + nb::list leaves = nb::steal(PyList_New(data_size)); + for (int leaf = 0; leaf < data_size; ++leaf) { + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); + } + auto meta_size = custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(x, custom->meta_fields[meta_leaf]).release().ptr()); + } + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + default: + DCHECK(kind == PyTreeKind::kLeaf); + return nb::none(); + } +} + +/* static */ PyType_Slot PyTreeRegistry::slots_[] = { + {Py_tp_traverse, (void*)PyTreeRegistry::tp_traverse}, + {Py_tp_clear, (void*)PyTreeRegistry::tp_clear}, + {0, nullptr}, +}; + +/* static */ int PyTreeRegistry::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyTreeRegistry* registry = nb::inst_ptr(self); + Py_VISIT(Py_TYPE(self)); + nb::ft_lock_guard lock(registry->mu_); + for (const auto& [key, value] : registry->registrations_) { + Py_VISIT(key.ptr()); + int rval = value->tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + } + return 0; +} + +/* static */ int PyTreeRegistry::tp_clear(PyObject* self) { + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + registry->registrations_.clear(); + return 0; +} + +/* static */ PyType_Slot DictKey::slots_[] = { + {Py_tp_traverse, (void*)DictKey::tp_traverse}, + {Py_tp_clear, (void*)DictKey::tp_clear}, + {0, nullptr}, +}; + +/* static */ int DictKey::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + DictKey* key = nb::inst_ptr(self); + Py_VISIT(key->key_.ptr()); + return 0; +} + +/* static */ int DictKey::tp_clear(PyObject* self) { + DictKey* dictkey = nb::inst_ptr(self); + nb::object tmp; + std::swap(tmp, dictkey->key_); + return 0; +} + +std::string SequenceKey::ToString() const { + return absl::StrFormat("[%d]", idx_); +} + +std::string SequenceKey::ToReprString() const { + return absl::StrFormat("SequenceKey(idx=%d)", idx_); +} + +std::string DictKey::ToString() const { + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); +} + +std::string DictKey::ToReprString() const { + return absl::StrFormat("DictKey(key=%s)", + nb::cast(nb::repr(key_))); +} + +std::string GetAttrKey::ToString() const { + return absl::StrFormat(".%s", nb::cast(name_)); +} + +std::string GetAttrKey::ToReprString() 
const { + return absl::StrFormat("GetAttrKey(name='%s')", + nb::cast(name_)); +} + +std::string FlattenedIndexKey::ToString() const { + return absl::StrFormat("[]", key_); +} + +std::string FlattenedIndexKey::ToReprString() const { + return absl::StrFormat("FlattenedIndexKey(key=%d)", key_); +} + +bool SequenceKey::Equals(const nb::object& other) { + SequenceKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return idx_ == other_key.idx(); +} + +bool DictKey::Equals(const nb::object& other) { + DictKey other_key(nb::none()); + if (!nb::try_cast(other, other_key)) return false; + return key_.equal(other_key.key()); +} + +bool GetAttrKey::Equals(const nb::object& other) { + GetAttrKey other_key(nb::str("")); + if (!nb::try_cast(other, other_key)) return false; + return name_.equal(other_key.name()); +} + +bool FlattenedIndexKey::Equals(const nb::object& other) { + FlattenedIndexKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return key_ == other_key.key(); +} + +nanobind::tuple SequenceKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("idx"); +}; + +nanobind::tuple DictKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +nanobind::tuple GetAttrKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("name"); +}; + +nanobind::tuple FlattenedIndexKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +template +void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, + const std::optional& leaf_predicate, + std::optional>& keypath) { + Node node; + const int start_num_nodes = traversal_.size(); + const int start_num_leaves = leaves.size(); + bool is_known_leaf = false; + if (leaf_predicate) { + nb::object o = (*leaf_predicate)(handle); + // Historically we accepted "truthy" values from leaf predicates. Accept + // None here to keep existing clients happy. + if (o.is_none()) { + is_known_leaf = false; + } else if (!nb::try_cast(o, is_known_leaf)) { + throw std::invalid_argument(absl::StrCat( + "is_leaf predicate returned a non-boolean value ", + nb::cast(nb::repr(o)), "; expected a boolean")); + } + } + if (is_known_leaf) { + nb::object value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector& frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } else { + node.kind = registry_->KindOfObject(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves]( + nb::handle child, + std::optional>& keypath) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } + FlattenImpl(child, leaves, leaf_predicate, keypath); + Py_LeaveRecursiveCall(); + }; + switch (node.kind) { + case PyTreeKind::kNone: + // Nothing to do. 
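+        // None is recorded as an arity-0 node with no children or leaves.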
+ break; + case PyTreeKind::kTuple: { + node.arity = PyTuple_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyTuple_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kList: { + node.arity = PyList_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyList_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(handle); + + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + for (nb::object& key : keys) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(key)); + } + recurse(dict[key], keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + node.arity = dict.size(); + node.sorted_dict_keys = std::move(keys); + break; + } + case PyTreeKind::kCustom: { + if (keypath.has_value()) { + auto [leaves, aux_data] = node.custom->ToIterableWithKeys(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (auto& [key, leaf] : leaves) { + keypath->push_back(key); + ++node.arity; + recurse(leaf, keypath); + keypath->pop_back(); + } + } else { + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (nb::handle entry : leaves) { + ++node.arity; + recurse(entry, keypath); + } + } + break; + } + case PyTreeKind::kDataclass: { + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(handle, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + node.node_data = std::move(aux_data); + auto data_size = node.custom->data_fields.size(); + node.arity = data_size; + for (int leaf = 0; leaf < data_size; ++leaf) { + if (keypath.has_value()) { + keypath->push_back( + make_nb_class(node.custom->data_fields[leaf])); + } + recurse(nb::getattr(handle, node.custom->data_fields[leaf]), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kNamedTuple: { + nb::tuple tuple = nb::borrow(handle); + node.arity = tuple.size(); + node.node_data = nb::borrow(tuple.type()); + if (keypath.has_value()) { + // Get key names from NamedTuple fields. 
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(tuple, "_fields"), fields) || + tuple.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : tuple) { + keypath->push_back(make_nb_class(nb::str(*field_iter))); + field_iter++; + recurse(entry, keypath); + keypath->pop_back(); + } + } else { + for (nb::handle entry : tuple) { + recurse(entry, keypath); + } + } + break; + } + default: + DCHECK(node.kind == PyTreeKind::kLeaf); + auto value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector& frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } + } + node.num_nodes = traversal_.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal_.push_back(std::move(node)); +} + +void PyTreeDef::Flatten(nb::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, std::vector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, nb::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ std::pair, nb_class_ptr> +PyTreeDef::Flatten(nb::handle x, nb_class_ptr registry, + std::optional leaf_predicate) { + auto def = make_nb_class(registry); + std::vector leaves; + def->Flatten(x, leaves, leaf_predicate); + return std::make_pair(std::move(leaves), std::move(def)); +} + +void PyTreeDef::FlattenWithPath(nb::handle handle, nanobind::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::vector(); + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry* registry, + const nb::iterable& x) { + const PyTreeRegistry::Registration* custom; + for (const nb::handle& h : x) { + if (registry->KindOfObject(h, &custom) != PyTreeKind::kLeaf) return false; + } + return true; +} + +template +nb::object PyTreeDef::UnflattenImpl(T leaves) const { + absl::InlinedVector agenda; + auto it = leaves.begin(); + int leaf_count = 0; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for TreeDef node."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + if (it == leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), + leaf_count)); + } + agenda.push_back(nb::borrow(*it)); + ++it; + ++leaf_count; + break; + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + const int size = agenda.size(); + absl::Span span; + if (node.arity > 0) { + span = absl::Span(&agenda[size - 
node.arity], node.arity); + } + nb::object o = MakeNode(node, span); + agenda.resize(size - node.arity); + agenda.push_back(o); + break; + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too many leaves for PyTreeDef; expected %d.", num_leaves())); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::Unflatten(nb::iterable leaves) const { + return UnflattenImpl(leaves); +} + +nb::object PyTreeDef::Unflatten(absl::Span leaves) const { + return UnflattenImpl(leaves); +} + +/*static*/ nb::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, + absl::Span children) { + if (children.size() != node.arity) { + throw std::logic_error("Node arity mismatch."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + throw std::logic_error("MakeNode not implemented for leaves."); + + case PyTreeKind::kNone: + return nb::none(); + + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + if (node.kind == PyTreeKind::kNamedTuple) { + return node.node_data(*tuple); + } else { + return tuple; + } + } + + case PyTreeKind::kList: { + nb::object list = nb::steal(PyList_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyList_SET_ITEM(list.ptr(), i, children[i].release().ptr()); + } + return list; + } + + case PyTreeKind::kDict: { + nb::dict dict; + for (int i = 0; i < node.arity; ++i) { + dict[node.sorted_dict_keys[i]] = std::move(children[i]); + } + return std::move(dict); + break; + } + case PyTreeKind::kCustom: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + return node.custom->from_iterable(node.node_data, tuple); + } + + case PyTreeKind::kDataclass: { + nb::kwargs kwargs; + auto meta_size = node.custom->meta_fields.size(); + for (int i = 0; i < meta_size; ++i) { + kwargs[node.custom->meta_fields[i]] = + nb::borrow(nb::tuple(node.node_data)[i]); + } + auto data_size = node.custom->data_fields.size(); + for (int i = 0; i < data_size; ++i) { + kwargs[node.custom->data_fields[i]] = std::move(children[i]); + } + return node.custom->type(**kwargs); + } + } + throw std::logic_error("Unreachable code."); +} + +nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { + nb::list leaves = nb::steal(PyList_New(num_leaves())); + std::vector agenda; + agenda.push_back(nb::borrow(xs)); + auto it = traversal_.rbegin(); + int leaf = num_leaves() - 1; + while (!agenda.empty()) { + if (it == traversal_.rend()) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + const Node& node = *it; + nb::object object = agenda.back(); + agenda.pop_back(); + ++it; + + switch (node.kind) { + case PyTreeKind::kLeaf: + if (leaf < 0) { + throw std::logic_error("Leaf count mismatch."); + } + PyList_SET_ITEM(leaves.ptr(), leaf, object.release().ptr()); + --leaf; + break; + + case PyTreeKind::kNone: + if (!object.is_none()) { + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. 
To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); + } + break; + + case PyTreeKind::kTuple: { + if (!PyTuple_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kList: { + if (!PyList_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected list, got %s.", + nb::cast(nb::repr(object)))); + } + nb::list list = nb::borrow(object); + if (list.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "List arity mismatch: %d != %d; list: %s.", list.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : list) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kDict: { + if (!PyDict_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected dict, got %s.", + nb::cast(nb::repr(object)))); + } + nb::dict dict = nb::borrow(object); + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + if (!IsSortedPyDictKeysEqual(keys, node.sorted_dict_keys)) { + // Convert to a nb::list for nb::repr to avoid having to stringify a + // vector. This is error path so it is fine to pay conversion cost. + throw std::invalid_argument( + absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", + nb::cast( + nb::repr(nb::cast(node.sorted_dict_keys))), + nb::cast(nb::repr(object)))); + } + for (nb::handle key : keys) { + agenda.push_back(dict[key]); + } + break; + } + + case PyTreeKind::kNamedTuple: { + if (!nb::isinstance(object) || + !nb::hasattr(object, "_fields")) { + throw std::invalid_argument( + absl::StrFormat("Expected named tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + if (tuple.type().not_equal(node.node_data)) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple type mismatch: expected type: %s, tuple: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kCustom: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom node type mismatch: expected type: %s, value: %s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); + } + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + int arity = 0; + for (nb::handle entry : leaves) { + ++arity; + agenda.push_back(nb::borrow(entry)); + } + if (arity != node.arity) { + throw 
std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", arity, + node.arity, nb::cast(nb::repr(object)))); + } + break; + } + + case PyTreeKind::kDataclass: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom dataclasss node type mismatch: expected type: %s, value: " + "%s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); + } + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(object, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom dataclass node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + auto data_size = node.custom->data_fields.size(); + if (data_size != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", data_size, + node.arity, nb::cast(nb::repr(object)))); + } + for (int leaf = 0; leaf < data_size; ++leaf) { + agenda.push_back(nb::borrow( + nb::getattr(object, node.custom->data_fields[leaf]))); + } + break; + } + } + } + if (it != traversal_.rend() || leaf != -1) { + throw std::invalid_argument( + absl::StrFormat("Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + return leaves; +} + +nb::object PyTreeDef::Walk(const nb::callable& f_node, nb::handle f_leaf, + nb::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + for (const Node& node : traversal_) { + switch (node.kind) { + case PyTreeKind::kLeaf: { + if (it == leaves.end()) { + throw std::invalid_argument("Too few leaves for PyTreeDef"); + } + + nb::object leaf = nb::borrow(*it); + agenda.push_back(f_leaf.is_none() ? std::move(leaf) + : f_leaf(std::move(leaf))); + ++it; + break; + } + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for custom type."); + } + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = node.arity - 1; i >= 0; --i) { + PyTuple_SET_ITEM(tuple.ptr(), i, agenda.back().release().ptr()); + agenda.pop_back(); + } + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for f_node invocation. + node_data = nb::cast(node.sorted_dict_keys); + } + agenda.push_back(f_node(tuple, node_data ? 
node_data : nb::none())); + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument("Too many leaves for PyTreeDef"); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::FromIterableTreeHelper( + nb::handle xs, + absl::InlinedVector::const_reverse_iterator* it) const { + if (*it == traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + const Node& node = **it; + ++*it; + if (node.kind == PyTreeKind::kLeaf) { + return nb::borrow(xs); + } + nb::iterable iterable = nb::borrow(xs); + std::vector ys; + ys.reserve(node.arity); + for (nb::handle x : iterable) { + ys.push_back(nb::borrow(x)); + } + if (ys.size() != node.arity) { + throw std::invalid_argument("Arity mismatch between trees"); + } + for (int j = node.arity - 1; j >= 0; --j) { + ys[j] = FromIterableTreeHelper(ys[j], it); + } + + return MakeNode(node, absl::MakeSpan(ys)); +} + +nb::object PyTreeDef::FromIterableTree(nb::handle xs) const { + auto it = traversal_.rbegin(); + nb::object out = FromIterableTreeHelper(xs, &it); + if (it != traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + return out; +} + +nb_class_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { + if (inner.registry_ != registry_) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Compose() must match."); + } + auto out = make_nb_class(registry_ref_); + out->traversal_.reserve(static_cast(num_leaves()) * + inner.num_nodes() + + num_nodes() - num_leaves()); + for (const Node& n : traversal_) { + if (n.kind == PyTreeKind::kLeaf) { + absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); + } else { + out->traversal_.push_back(n); + } + } + out->SetNumLeavesAndNumNodes(); + return out; +} + +/*static*/ nb_class_ptr PyTreeDef::Tuple( + nb_class_ptr registry, nb::list defs) { + auto out = make_nb_class(std::move(registry)); + int num_leaves = 0; + for (nb::handle def_handle : defs) { + const PyTreeDef* def = nb::cast(def_handle); + if (def->registry() != out->registry()) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Tuple() must match."); + } + absl::c_copy(def->traversal_, std::back_inserter(out->traversal_)); + num_leaves += def->num_leaves(); + } + Node node; + node.kind = PyTreeKind::kTuple; + node.arity = defs.size(); + node.num_leaves = num_leaves; + node.num_nodes = out->traversal_.size() + 1; + out->traversal_.push_back(node); + return out; +} + +std::vector> PyTreeDef::Children() const { + std::vector> children; + if (traversal_.empty()) { + return children; + } + Node const& root = traversal_.back(); + children.resize(root.arity); + int pos = traversal_.size() - 1; + for (int i = root.arity - 1; i >= 0; --i) { + children[i] = make_nb_class(registry_ref_); + const Node& node = traversal_.at(pos - 1); + if (pos < node.num_nodes) { + throw std::logic_error("children() walked off start of array"); + } + std::copy(traversal_.begin() + pos - node.num_nodes, + traversal_.begin() + pos, + std::back_inserter(children[i]->traversal_)); + pos -= node.num_nodes; + } + if (pos != 0) { + throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); + } + return children; +} + +std::string PyTreeDef::ToString() const { + std::vector agenda; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for container."); + } + + 
std::string children = + absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", "); + std::string representation; + switch (node.kind) { + case PyTreeKind::kLeaf: + agenda.push_back("*"); + continue; + case PyTreeKind::kNone: + representation = "None"; + break; + case PyTreeKind::kTuple: + // Tuples with only one element must have a trailing comma. + if (node.arity == 1) children += ","; + representation = absl::StrCat("(", children, ")"); + break; + case PyTreeKind::kList: + representation = absl::StrCat("[", children, "]"); + break; + case PyTreeKind::kDict: { + if (node.sorted_dict_keys.size() != node.arity) { + throw std::logic_error("Number of keys and entries does not match."); + } + representation = "{"; + std::string separator; + auto child_iter = agenda.end() - node.arity; + for (const nb::handle& key : node.sorted_dict_keys) { + absl::StrAppendFormat(&representation, "%s%s: %s", separator, + nb::cast(nb::repr(key)), + *child_iter); + child_iter++; + separator = ", "; + } + representation += "}"; + break; + } + + case PyTreeKind::kNamedTuple: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + std::string kind; + std::string data; + if (node.kind == PyTreeKind::kNamedTuple) { + kind = "namedtuple"; + if (node.node_data) { + // Node data for named tuples is the type. + data = absl::StrFormat( + "[%s]", nb::cast( + nb::str(nb::getattr(node.node_data, "__name__")))); + } + } else { + kind = nb::cast( + nb::str(nb::getattr(node.custom->type, "__name__"))); + if (node.node_data) { + data = absl::StrFormat( + "[%s]", nb::cast(nb::str(node.node_data))); + } + } + + representation = + absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children); + break; + } + } + agenda.erase(agenda.end() - node.arity, agenda.end()); + agenda.push_back(std::move(representation)); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return absl::StrCat("PyTreeDef(", agenda.back(), ")"); +} + +nb::object PyTreeDef::ToPickle() const { + nb::list traversal; + for (const auto& node : traversal_) { + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for pickling to avoid having to pickle a vector. + // Pickle should be a rare operation so this conversion cost is hopefully + // on non-critical path. + node_data = nb::cast(node.sorted_dict_keys); + } + traversal.append( + nb::make_tuple(static_cast(node.kind), node.arity, + node_data ? node_data : nb::none(), + node.custom != nullptr ? node.custom->type : nb::none(), + node.num_leaves, node.num_nodes)); + } + return nb::make_tuple(nb::cast(registry_ref_), traversal); +} + +void PyTreeDef::FromPickle(nb::object pickle) { + for (const auto& item : nb::cast(pickle)) { + auto t = nb::cast(item); + if (t.size() != 6) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + Node& node = traversal_.emplace_back(); + node.kind = static_cast(nb::cast(t[0])); + node.arity = nb::cast(t[1]); + switch (node.kind) { + case PyTreeKind::kNamedTuple: + node.node_data = t[2]; + break; + case PyTreeKind::kDict: + node.sorted_dict_keys = nb::cast>(t[2]); + break; + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + node.node_data = t[2]; + break; + default: + if (!t[2].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + break; + } + if (node.kind == PyTreeKind::kCustom || + node.kind == PyTreeKind::kDataclass) { + node.custom = t[3].is_none() ? 
nullptr : registry()->Lookup(t[3]); + if (node.custom == nullptr) { + throw xla::XlaRuntimeError( + absl::StrCat("Unknown custom type in pickled PyTreeDef: ", + nb::cast(nb::repr(t[3])))); + } + } else { + if (!t[3].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + } + node.num_leaves = nb::cast(t[4]); + node.num_nodes = nb::cast(t[5]); + } +} + +void PyTreeDef::SetNumLeavesAndNumNodes() { + // num_leaves and num_nodes are fully determined by arity. + std::vector> starts; + int num_leaves = 0; + for (int i = 0; i < traversal_.size(); ++i) { + std::pair start = {num_leaves, i}; + if (traversal_[i].kind == PyTreeKind::kLeaf) { + num_leaves += 1; + } + if (traversal_[i].arity == 0) { + starts.push_back(start); + } else { + starts.resize(starts.size() - (traversal_[i].arity - 1)); + } + traversal_[i].num_leaves = num_leaves - starts.back().first; + traversal_[i].num_nodes = i + 1 - starts.back().second; + } +} + +void PyTreeDef::SerializeTo(jax::PyTreeDefProto& result) const { + absl::flat_hash_map interned_strings; + auto intern_str = [&](const std::string& key) { + auto [it, added] = + interned_strings.emplace(key, result.interned_strings_size()); + if (added) { + result.add_interned_strings(key); + } + return it->second; + }; + for (const auto& node : traversal_) { + auto* node_data = result.add_nodes(); + node_data->set_arity(node.arity); + switch (node.kind) { + case PyTreeKind::kLeaf: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LEAF); + break; + case PyTreeKind::kList: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LIST); + break; + case PyTreeKind::kNone: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_NONE); + break; + case PyTreeKind::kTuple: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_TUPLE); + break; + case PyTreeKind::kDict: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_DICT); + for (auto& key : node.sorted_dict_keys) { + if (!nb::isinstance(key)) { + throw std::invalid_argument( + "Only string keys are supported in proto pytree " + "serialization."); + } + node_data->mutable_dict_keys()->add_str_id( + intern_str(nb::cast(key))); + } + break; + default: + throw std::invalid_argument( + "User-defined nodes are not supported when serializing pytrees as " + "protocol buffers. 
You should either convert the user-defined " + "nodes to another type or use pickle instead."); + break; + } + } +} + +nb_class_ptr PyTreeDef::DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input) { + std::vector interned_strings; + interned_strings.reserve(input.interned_strings().size()); + for (auto& s : input.interned_strings()) { + interned_strings.push_back(nb::cast(s)); + } + nb_class_ptr result = + make_nb_class(std::move(registry)); + for (auto& node_proto : input.nodes()) { + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = node_proto.arity(); + node.custom = nullptr; + switch (node_proto.type()) { + case jax::PyTreeNodeType::PY_TREE_KIND_LEAF: + node.kind = PyTreeKind::kLeaf; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_LIST: + node.kind = PyTreeKind::kList; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_NONE: + node.kind = PyTreeKind::kNone; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_TUPLE: + node.kind = PyTreeKind::kTuple; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_DICT: + node.kind = PyTreeKind::kDict; + for (uint32_t str_id : node_proto.dict_keys().str_id()) { + if (str_id >= interned_strings.size()) { + throw std::invalid_argument( + "Malformed pytree proto (dict_key out of range)."); + } + node.sorted_dict_keys.push_back(interned_strings.at(str_id)); + } + break; + default: + throw std::invalid_argument( + "Malformed pytree proto (invalid node type)"); + break; + } + } + result->SetNumLeavesAndNumNodes(); + return result; +} + +std::optional> PyTreeDef::GetNodeData() + const { + if (traversal_.empty()) { + throw std::logic_error("empty PyTreeDef traversal."); + } + auto builtin_type = [](PyTypeObject* type_obj) { + return nb::borrow(reinterpret_cast(type_obj)); + }; + const auto& node = traversal_.back(); + switch (node.kind) { + case PyTreeKind::kLeaf: + return std::nullopt; + case PyTreeKind::kNone: + return std::make_pair(builtin_type(Py_TYPE(Py_None)), nb::none()); + case PyTreeKind::kTuple: + return std::make_pair(builtin_type(&PyTuple_Type), nb::none()); + case PyTreeKind::kList: + return std::make_pair(builtin_type(&PyList_Type), nb::none()); + case PyTreeKind::kDict: + return std::make_pair(builtin_type(&PyDict_Type), + nb::cast(node.sorted_dict_keys)); + case PyTreeKind::kNamedTuple: + return std::make_pair(node.node_data, nb::none()); + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + return std::make_pair(node.custom->type, node.node_data); + } +} + +nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nb::iterable children) { + nb_class_ptr result = + make_nb_class(std::move(registry)); + int num_leaves = 0; + int arity = 0; + for (nb::handle pchild : children) { + const PyTreeDef& child = nb::cast(pchild); + absl::c_copy(child.traversal_, std::back_inserter(result->traversal_)); + num_leaves += child.num_leaves(); + ++arity; + } + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = arity; + node.custom = nullptr; + node.num_leaves = num_leaves; + node.num_nodes = result->traversal_.size(); + if (node_data == std::nullopt) { + node.kind = PyTreeKind::kLeaf; + ++node.num_leaves; + return result; + } + int is_nt = PyObject_IsSubclass(node_data->first.ptr(), + reinterpret_cast(&PyTuple_Type)); + if (is_nt == -1) { + throw nb::python_error(); + } + if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) { + node.kind = PyTreeKind::kNamedTuple; + node.node_data = 
node_data->first; + return result; + } + auto* registration = result->registry()->Lookup(node_data->first); + if (registration == nullptr) { + throw std::logic_error(absl::StrFormat( + "Could not find type: %s.", + nb::cast(nb::repr(node_data->first)))); + } + node.kind = registration->kind; + if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { + node.custom = registration; + node.node_data = node_data->second; + } else if (node.kind == PyTreeKind::kNamedTuple) { + node.node_data = node_data->first; + } else if (node.kind == PyTreeKind::kDict) { + node.sorted_dict_keys = + nb::cast>(node_data->second); + } + return result; +} + +int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(node_data.ptr()); + for (const auto& key : sorted_dict_keys) { + Py_VISIT(key.ptr()); + } + return 0; +} + +/* static */ int PyTreeDef::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyTreeDef* treedef = nb::inst_ptr(self); + Py_VISIT(Py_TYPE(self)); + Py_VISIT(treedef->registry_ref_.ptr()); + for (const auto& node : treedef->traversal_) { + node.tp_traverse(visit, arg); + } + return 0; +} + +/* static */ int PyTreeDef::tp_clear(PyObject* self) { + PyTreeDef* treedef = nb::inst_ptr(self); + treedef->registry_ref_.reset(); + treedef->traversal_.clear(); + return 0; +} + +/* static */ PyType_Slot PyTreeDef::slots_[] = { + {Py_tp_traverse, (void*)PyTreeDef::tp_traverse}, + {Py_tp_clear, (void*)PyTreeDef::tp_clear}, + {0, nullptr}, +}; + +void BuildPytreeSubmodule(nb::module_& m) { + nb::module_ pytree = m.def_submodule("pytree", "Python tree library"); + pytree.attr("version") = nb::int_(3); + + nb::class_ treedef(pytree, "PyTreeDef", + nb::type_slots(PyTreeDef::slots_)); + + nb::class_ registry(m, "PyTreeRegistry", nb::dynamic_attr(), + nb::type_slots(PyTreeRegistry::slots_)); + + registry.def(nb::init(), + nb::arg("enable_none") = true, nb::arg("enable_tuple") = true, + nb::arg("enable_namedtuple") = true, + nb::arg("enable_list") = true, nb::arg("enable_dict") = true); + registry.def( + "flatten", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->Flatten(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, + nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); + registry.def( + "flatten_with_path", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->FlattenWithPath(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("register_node", &PyTreeRegistry::Register, + nb::arg("type").none(), nb::arg("to_iterable").none(), + nb::arg("from_iterable").none(), + nb::arg("to_iterable_with_keys").none() = std::nullopt); + registry.def("register_dataclass_node", &PyTreeRegistry::RegisterDataclass); + registry.def("__reduce__", + [](nb::object self) { return self.attr("__name__"); }); + + pytree.attr("_default_registry") = make_nb_class( + /*enable_none=*/true, /*enable_tuple=*/true, /*enable_namedtuple=*/true, + /*enable_list=*/true, 
/*enable_dict*/ true); + pytree.def("default_registry", + [registry = nb::cast>( + pytree.attr("_default_registry"))]() { return registry; }); + + pytree.attr("PyTreeRegistry") = m.attr("PyTreeRegistry"); + pytree.def("tuple", &PyTreeDef::Tuple); + pytree.def("all_leaves", &PyTreeDef::AllLeaves); + + treedef.def("unflatten", + static_cast( + &PyTreeDef::Unflatten)); + treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo, nb::arg("tree").none()); + treedef.def("compose", &PyTreeDef::Compose); + treedef.def( + "walk", &PyTreeDef::Walk, + "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf " + "at leaves", + nb::arg("f_node"), nb::arg("f_leaf"), nb::arg("leaves")); + treedef.def("from_iterable_tree", &PyTreeDef::FromIterableTree); + treedef.def("children", &PyTreeDef::Children); + treedef.def_prop_ro("num_leaves", &PyTreeDef::num_leaves); + treedef.def_prop_ro("num_nodes", &PyTreeDef::num_nodes); + treedef.def("__repr__", &PyTreeDef::ToString); + treedef.def("__eq__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; }); + treedef.def("__ne__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; }); + treedef.def("__hash__", [](const PyTreeDef& t) { return absl::HashOf(t); }); + treedef.def("serialize_using_proto", [](const PyTreeDef& a) { + jax::PyTreeDefProto result; + a.SerializeTo(result); + std::string serialized = result.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }); + treedef.def_static( + "deserialize_using_proto", + [](nb_class_ptr registry, nb::bytes data) { + jax::PyTreeDefProto input; + absl::string_view serialized(data.c_str(), data.size()); + if (serialized.size() > std::numeric_limits::max()) { + throw xla::XlaRuntimeError( + "Pytree serialization too large to deserialize."); + } + if (!input.ParseFromArray(serialized.data(), serialized.size())) { + throw xla::XlaRuntimeError("Could not deserialize PyTreeDefProto."); + } + return PyTreeDef::DeserializeFrom(std::move(registry), input); + }, + nb::arg("registry"), nb::arg("data")); + treedef.def("node_data", &PyTreeDef::GetNodeData, + "Returns None if a leaf-pytree, else (type, node_data)"); + treedef.def_static( + "make_from_node_data_and_children", + &PyTreeDef::MakeFromNodeDataAndChildren, nb::arg("registry"), + nb::arg("node_data").none(), nb::arg("children"), + "Reconstructs a pytree from `node_data()` and `children()`."); + treedef.def("__getstate__", &PyTreeDef::ToPickle); + treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) { + nb::tuple pickle = nb::cast(o); + if (pickle.size() != 2) { + throw xla::XlaRuntimeError( + "Malformed pickled PyTreeDef, expected 2-tuple"); + } + auto registry = nb::cast>(pickle[0]); + new (&t) PyTreeDef(registry); + t.FromPickle(pickle[1]); + }); + + nb::class_ sequence_key(pytree, "SequenceKey"); + sequence_key.def(nb::init(), nb::arg("idx")); + sequence_key.def("__str__", &SequenceKey::ToString); + sequence_key.def("__repr__", &SequenceKey::ToReprString); + sequence_key.def("__eq__", &SequenceKey::Equals); + sequence_key.def("__hash__", [](const SequenceKey& key) { + return key.idx() + kSequenceKeyHashSalt; + }); + sequence_key.def_prop_ro("idx", &SequenceKey::idx); + sequence_key.def_prop_ro_static("__match_args__", &SequenceKey::MatchArgs); + sequence_key.def("__getstate__", + [](SequenceKey& key) { return nb::make_tuple(key.idx()); }); + sequence_key.def("__setstate__", + [](SequenceKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled SequenceKey, 
expected 1-tuple"); + } + new (&key) SequenceKey(nb::cast(state[0])); + }); + + nb::class_ dict_key(pytree, "DictKey", + nb::type_slots(DictKey::slots_)); + dict_key.def(nb::init(), nb::arg("key")); + dict_key.def("__str__", &DictKey::ToString); + dict_key.def("__repr__", &DictKey::ToReprString); + dict_key.def("__eq__", &DictKey::Equals); + dict_key.def("__hash__", + [](const DictKey& key) { return nanobind::hash(key.key()); }); + dict_key.def_prop_ro("key", &DictKey::key); + dict_key.def_prop_ro_static("__match_args__", &DictKey::MatchArgs); + dict_key.def("__getstate__", + [](DictKey& key) { return nb::make_tuple(key.key()); }); + dict_key.def("__setstate__", [](DictKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError("Malformed pickled DictKey, expected 1-tuple"); + } + new (&key) DictKey(nb::cast(state[0])); + }); + + nb::class_ get_attr_key(pytree, "GetAttrKey"); + get_attr_key.def(nb::init(), nb::arg("name")); + get_attr_key.def("__str__", &GetAttrKey::ToString); + get_attr_key.def("__repr__", &GetAttrKey::ToReprString); + get_attr_key.def("__eq__", &GetAttrKey::Equals); + get_attr_key.def("__hash__", + [](const GetAttrKey& key) { return nb::hash(key.name()); }); + get_attr_key.def_prop_ro("name", &GetAttrKey::name); + get_attr_key.def_prop_ro_static("__match_args__", &GetAttrKey::MatchArgs); + get_attr_key.def("__getstate__", + [](GetAttrKey& key) { return nb::make_tuple(key.name()); }); + get_attr_key.def("__setstate__", [](GetAttrKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled GetAttrKey, expected 1-tuple"); + } + new (&key) GetAttrKey(nb::str(state[0])); + }); + + nb::class_ flattened_index_key(pytree, + "FlattenedIndexKey"); + flattened_index_key.def(nb::init(), nb::arg("key")); + flattened_index_key.def("__str__", &FlattenedIndexKey::ToString); + flattened_index_key.def("__repr__", &FlattenedIndexKey::ToReprString); + flattened_index_key.def("__eq__", &FlattenedIndexKey::Equals); + flattened_index_key.def("__hash__", [](const FlattenedIndexKey& key) { + return key.key() + kFlattenedIndexKeyHashSalt; + }); + flattened_index_key.def_prop_ro("key", &FlattenedIndexKey::key); + flattened_index_key.def_prop_ro_static("__match_args__", + &FlattenedIndexKey::MatchArgs); + flattened_index_key.def("__getstate__", [](FlattenedIndexKey& key) { + return nb::make_tuple(key.key()); + }); + flattened_index_key.def( + "__setstate__", [](FlattenedIndexKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled FlattenedIndexKey, expected 1-tuple"); + } + new (&key) FlattenedIndexKey(nb::cast(state[0])); + }); +} + +} // namespace xla diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h new file mode 100644 index 000000000000..722fe41169a0 --- /dev/null +++ b/jaxlib/xla/pytree.h @@ -0,0 +1,408 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PYTREE_H_ +#define JAXLIB_XLA_PYTREE_H_ + +// See https://jax.readthedocs.io/en/latest/pytrees.html for the documentation +// about pytree. + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/pytree.pb.h" +#include "xla/python/nb_class_ptr.h" + +namespace xla { + +enum class PyTreeKind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + kDataclass, // A dataclass. +}; + +// Registry of custom node types. +class PyTreeRegistry { + public: + PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, + bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry&) = delete; + PyTreeRegistry(PyTreeRegistry&&) = delete; + PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; + PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; + + struct Registration { + PyTreeKind kind; + + // The following values are populated for custom types. + // The Python type object, used to identify the type. + nanobind::object type; + // A function with signature: object -> (iterable, aux_data) + nanobind::callable to_iterable; + // A function with signature: (aux_data, iterable) -> object + nanobind::callable from_iterable; + // A function with signature: (aux_data, iterable(keypath, leaf)) -> object + std::optional to_iterable_with_keys; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; + // Helper that calls to_iterable_with_keys and validates that it returns a + // pair of an iterable of key-leaf pairs and an aux_data object. If + // to_iterable_with_keys is not available, return a dummy key for each leaf, + // similar to the current jax.tree_util.FlattenedIndexKey. + std::pair>, + nanobind::object> + ToIterableWithKeys(nanobind::handle o) const; + + // For dataclasses. + std::vector data_fields; + std::vector meta_fields; + + int tp_traverse(visitproc visit, void* arg); + }; + + // Registers a new custom type. Objects of `type` will be treated as container + // node types in PyTrees. + void Register( + nanobind::object type, nanobind::callable to_iterable, + nanobind::callable from_iterable, + std::optional to_iterable_with_keys = std::nullopt); + // Same, but for dataclasses. + void RegisterDataclass(nanobind::object type, + std::vector data_fields, + std::vector meta_fields); + + // Finds the custom type registration for `type`. Returns nullptr if none + // exists. + const Registration* Lookup(nanobind::handle type) const; + + PyTreeKind KindOfObject(nanobind::handle obj, + PyTreeRegistry::Registration const** custom) const; + + // Flattens a pytree one level, returning either a tuple of the leaves and + // the node data, or None, if the entry is a leaf. + nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. 
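+  // `with_keys` selects between the two variants declared above: false gives
+  // FlattenOneLevel behavior, true gives FlattenOneLevelWithKeys behavior.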
+ nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; + + static PyType_Slot slots_[]; + + private: + struct TypeHash { + using is_transparent = void; + size_t operator()(const nanobind::object& t) const { + return absl::HashOf(t.ptr()); + } + size_t operator()(const nanobind::handle& t) const { + return absl::HashOf(t.ptr()); + } + }; + struct TypeEq { + using is_transparent = void; + bool operator()(const nanobind::object& a, + const nanobind::object& b) const { + return a.ptr() == b.ptr(); + } + bool operator()(const nanobind::object& a, + const nanobind::handle& b) const { + return a.ptr() == b.ptr(); + } + }; + mutable nanobind::ft_mutex mu_; + absl::flat_hash_map, TypeHash, + TypeEq> + registrations_; // Guarded by mu_ + bool enable_namedtuple_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class SequenceKey { + public: + explicit SequenceKey(int idx) : idx_(idx) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int idx() const { return idx_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int idx_; +}; + +class DictKey { + public: + explicit DictKey(nanobind::object key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::object key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + static PyType_Slot slots_[]; + + private: + nanobind::object key_; + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class GetAttrKey { + public: + explicit GetAttrKey(nanobind::str name) : name_(name) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::str name() const { return name_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + nanobind::str name_; +}; + +class FlattenedIndexKey { + public: + explicit FlattenedIndexKey(int key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int key_; +}; + +// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of +// Python values, where the interior nodes are tuples, lists, dictionaries, or +// user-defined containers, and the leaves are other objects. +class PyTreeDef { + public: + // Unowned registry: the registry must remain live at least as long as the + // PyTreeDef. It is the caller's responsibility to enforce this. + explicit PyTreeDef(PyTreeRegistry* registry) : registry_(registry) {} + + explicit PyTreeDef(nb_class_ptr registry) + : registry_(registry.get()), registry_ref_(std::move(registry)) {} + + // Flattens a Pytree into a list of leaves and a PyTreeDef. + // Returns references to the flattened objects, which might be temporary + // objects in the case of custom pytype handlers. + static std::pair, nb_class_ptr> + Flatten(nanobind::handle x, nb_class_ptr registry, + std::optional leaf_predicate = std::nullopt); + + // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). + // `leaves` owns references to the flattened objects, which might be + // temporary objects in the case of custom pytype handlers. 
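+  // The three overloads below differ only in the container type used to
+  // collect the flattened leaves.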
+ void Flatten(nanobind::handle handle, std::vector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + void FlattenWithPath( + nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + // Tests whether the given list is a flat list of leaves. + static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x); + + // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of + // the tree-structure of 'x'. For example, if we flatten a value + // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the + // list of leaves [1, (2, 3), {"foo": 4}]. + nanobind::list FlattenUpTo(nanobind::handle x) const; + + // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. + nanobind::object Unflatten(nanobind::iterable leaves) const; + nanobind::object Unflatten(absl::Span leaves) const; + + // Composes two PyTreeDefs, replacing the leaves of this tree with copies of + // `inner`. The returned PyTreeDef holds a reference to its registry. + nb_class_ptr Compose(const PyTreeDef& inner) const; + + // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. + static nb_class_ptr Tuple(nb_class_ptr registry, + nanobind::list defs); + + // The returned PyTreeDefs hold a reference to the registry. + std::vector> Children() const; + + // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // f_node(node, node_data) to each container node. + nanobind::object Walk(const nanobind::callable& f_node, + nanobind::handle f_leaf, + nanobind::iterable leaves) const; + + // Given a tree of iterables with the same node/leaf structure as this PyTree, + // build the corresponding PyTree. + // TODO(phawkins): use flattening everywhere instead and delete this method. + nanobind::object FromIterableTree(nanobind::handle xs) const; + + int num_leaves() const { + if (traversal_.empty()) { + return 0; + } + return traversal_.back().num_leaves; + } + + int num_nodes() const { return traversal_.size(); } + + PyTreeRegistry* registry() const { return registry_; } + + size_t Hash() const; + + bool operator==(const PyTreeDef& other) const; + bool operator!=(const PyTreeDef& other) const { return !(*this == other); } + + std::string ToString() const; + + // Transforms the PyTreeDef into a pickleable object. Used to implement + // `PyTreeDef.__getstate__`. + nanobind::object ToPickle() const; + + // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used + // to implement `PyTreeDef.__setstate__`. + void FromPickle(nanobind::object pickleable); + + void SerializeTo(jax::PyTreeDefProto& result) const; + + static nb_class_ptr DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input); + + std::optional> GetNodeData() + const; + + static nb_class_ptr MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nanobind::iterable children); + + static PyType_Slot slots_[]; + + private: + void SetNumLeavesAndNumNodes(); + + struct Node { + PyTreeKind kind = PyTreeKind::kLeaf; + + // Arity for non-kLeaf types. + int arity = 0; + + // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type + // object. For a kDict, use `sorted_dict_keys` field below. 
For a kCustom + // type, contains the auxiliary data returned by the `to_iterable` function. + nanobind::object node_data; + + // Kind-specific auxiliary data specialized for kDict. Use a c++ vector + // to hold the sorted dict keys instead of a py::list to avoid creating + // a new python list object when flattening kDict. For deeply nested dict, + // using c++ vector instead of py::list avoids creating too many python + // objects that make python gc sweep slow. + std::vector sorted_dict_keys; + + // Custom type registration. Must be null for non-custom types. + const PyTreeRegistry::Registration* custom = nullptr; + + // Number of leaf nodes in the subtree rooted at this node. + int num_leaves = 0; + + // Number of leaf and interior nodes in the subtree rooted at this node. + int num_nodes = 0; + + int tp_traverse(visitproc visit, void* arg) const; + }; + template + friend H AbslHashValue(H h, const Node& n); + + template + friend H AbslHashValue(H h, const PyTreeDef& t); + + // Helper that manufactures an instance of a node given its children. + static nanobind::object MakeNode(const Node& node, + absl::Span children); + + // Recursive helper used to implement FromIterableTree() + nanobind::object FromIterableTreeHelper( + nanobind::handle xs, + absl::InlinedVector::const_reverse_iterator* it) + const; + + template + void FlattenImpl(nanobind::handle handle, T& leaves, + const std::optional& leaf_predicate, + std::optional>& keypath); + + template + nanobind::object UnflattenImpl(T leaves) const; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + + // Pytree registry. Not owned. + PyTreeRegistry* registry_; + // If this class holds a reference to `registry`, it is held by + // `registry_ref_`. + nb_class_ptr registry_ref_; + + // Nodes, in a post-order traversal. We use an ordered traversal to minimize + // allocations, and post-order corresponds to the order we need to rebuild the + // tree structure. + absl::InlinedVector traversal_; +}; + +template +H AbslHashValue(H h, const PyTreeDef::Node& n) { + h = H::combine(std::move(h), n.kind, n.arity, n.custom); + return h; +} + +template +H AbslHashValue(H h, const PyTreeDef& t) { + h = H::combine(std::move(h), t.traversal_); + return h; +} + +void BuildPytreeSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PYTREE_H_ diff --git a/jaxlib/xla/pytree.proto b/jaxlib/xla/pytree.proto new file mode 100644 index 000000000000..73c087ef55ab --- /dev/null +++ b/jaxlib/xla/pytree.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package jax; + +enum PyTreeNodeType { + PY_TREE_KIND_INVALID = 0; + PY_TREE_KIND_LEAF = 1; + PY_TREE_KIND_LIST = 2; + PY_TREE_KIND_NONE = 3; + PY_TREE_KIND_TUPLE = 4; + PY_TREE_KIND_DICT = 5; +} + +message DictKeysProto { + repeated uint32 str_id = 1; +} + +message PyTreeNodeDefProto { + // Recovers the tree structure. + uint32 arity = 1; + // Node type. + PyTreeNodeType type = 2; + // Only set when type == DICT. + DictKeysProto dict_keys = 3; +} + +// A Pytree. +message PyTreeDefProto { + repeated PyTreeNodeDefProto nodes = 1; + // Extra strings. + repeated string interned_strings = 2; +} diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index fdd4456b238c..bd3ed3205fb2 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -91,6 +91,7 @@ limitations under the License. 
#include "jaxlib/xla/mlir.h"
 #include "jaxlib/xla/pjit.h"
 #include "jaxlib/xla/pmap_lib.h"
+#include "jaxlib/xla/pytree.h"
 #include "jaxlib/xla/weakref_lru_cache.h"
 #include "jaxlib/xla/xla_compiler.h"
 #include "xla/pjrt/distributed/key_value_store_interface.h"
@@ -119,7 +120,6 @@ limitations under the License.
 #include "xla/python/py_executable.h"
 #include "xla/python/py_memory_space.h"
 #include "xla/python/python_ref_manager.h"
-#include "xla/python/pytree.h"
 #include "xla/python/sharding.h"
 #include "xla/python/traceback.h"
 #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h"

From b4922df2206707cdf94aec466708e4de2bc52d7c Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Tue, 25 Mar 2025 01:23:47 +0000
Subject: [PATCH 0137/1769] [attrs] allow setattr on a previously non-existent attr

Before this change, we handled attrs for initial-style primitives like
jit/scan like this:
 1. the traceable would form a jaxpr and see what attrs were touched (by
    jax_getattr or jax_setattr),
 2. for each such attr, the traceable would do jax_getattr to get the current
    value, tree-flatten, pass the flat values into the (pure) bind, get the
    new values out, tree-unflatten, then jax_setattr the result.

That approach would error if the function called `jax_setattr` to set a
previously non-existent attr. That is, this would work:

```python
from jax.experimental.attrs import jax_setattr

class Thing: ...
thing = Thing()

jax_setattr(thing, 'x', 1.0)
```

but it wouldn't work under a `jax.jit`.

This commit makes the same code work under a jit. We just
 1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr
    formation are deleted, using a special sentinel value `dne_sentinel` to
    indicate the attribute initially did not exist before tracing;
 2. in pjit.py's `_get_states`, when reading initial attr values before the
    pjit_p bind, if the attribute does not exist we don't try to read it and
    instead just use `dne_sentinel` as the value, which is a convenient empty
    pytree;
 3. in pjit.py's `_attr_token` for jit caching, when forming the cache key
    based on the current attr states, we map attrs that don't exist to
    `dne_sentinel` (rather than just erroring when the attr doesn't exist, as
    before).

In short, we use a special value to indicate "does not exist".

If `jax_getattr` supported the 'default' argument, the code would be a little
cleaner since we could avoid the `if hasattr` stuff. And that's probably a
useful feature to have anyway. We can add that in a follow-up.

This PR only makes setattr-to-nonexistent-attr work with jit. We'll add scan
etc. in follow-ups.
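After this change, the same pattern works under `jax.jit`. A minimal sketch
distilled from the new `test_setattr_doesnt_exist` test added below (the
`Thing`/`thing` names mirror the test):

```python
import jax
from jax.experimental.attrs import jax_setattr

class Thing: ...
thing = Thing()

@jax.jit
def f(x):
  jax_setattr(thing, 'x', x)  # 'x' need not exist before tracing

f(1.0)
assert thing.x == 1.0
del thing.x   # deleting and re-setting also works, per the test
f(2.0)
assert thing.x == 2.0
```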
--- jax/_src/interpreters/partial_eval.py | 12 +++++--- jax/_src/pjit.py | 12 ++++---- jax/experimental/attrs.py | 3 +- tests/attrs_test.py | 43 +++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 11 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 07c516fd95c7..58b97ce2f3da 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -42,8 +42,8 @@ mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef, ReadEffect -from jax._src.tree_util import (PyTreeDef, treedef_tuple, - tree_flatten, tree_structure) +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, + tree_structure, register_static) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list, @@ -1699,7 +1699,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] - set_states(self.attrs_tracked, self.attrs_inits) + set_states(self.attrs_tracked, self.attrs_inits) # reset to initial values return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], @@ -2246,11 +2246,15 @@ def trace_to_jaxpr_dynamic2( AttrStates = list def set_states(attrs_tracked: AttrsTracked, vals: AttrStates): for ((obj, attr), val) in zip(attrs_tracked, vals): - setattr(obj, attr, val) + setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr) def get_states(attrs_tracked: AttrsTracked): return [getattr(obj, attr) for (obj, attr) in attrs_tracked] +@register_static +class DoesNotExist: ... +dne_sentinel = DoesNotExist() + def infer_lambda_input_type( axes_specs: Sequence[AbstractedAxesSpec] | None, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d690cd6e9c67..bcdbe6b1bdb7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -240,10 +240,10 @@ def _set_states(attrs_tracked, vals): jax_setattr(obj, attr, val) def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr + from jax.experimental.attrs import jax_getattr, dne_sentinel vals = [] for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) + tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel leaves, treedef_ = tree_flatten(tree) assert treedef == treedef_ vals.extend(leaves) @@ -1354,11 +1354,11 @@ def _attr_token( fun: lu.WrappedFun, in_type: core.InputType | tuple[core.AbstractValue, ...] 
) -> int: - from jax.experimental.attrs import jax_getattr + from jax.experimental.attrs import jax_getattr, dne_sentinel cases = seen_attrs_get(fun, in_type) for i, records in enumerate(cases): for obj, attr, treedef, avals in records: - val = jax_getattr(obj, attr) + val = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel vals, treedef_ = tree_flatten(val) avals_ = map(core.shaped_abstractify, vals) if treedef != treedef_ or avals != avals_: break @@ -1367,8 +1367,8 @@ def _attr_token( return len(cases) def _attr_update(fun, in_type, i, attrs_tracked): - from jax.experimental.attrs import jax_getattr - leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr)) + from jax.experimental.attrs import jax_getattr, dne_sentinel + leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel) records = [(obj, attr, init_tree, map(core.shaped_abstractify, leaves(obj, attr))) for init_tree, _, (obj, attr) in attrs_tracked] cases = seen_attrs_get(fun, in_type) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 4e1dc4b8f493..bb4c7bf83b3f 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -36,6 +36,7 @@ Pytree = Any register = api_util.register_class_with_attrs +dne_sentinel = pe.dne_sentinel def jax_getattr(obj: Any, attr: str): with core.take_current_trace() as t: @@ -65,7 +66,7 @@ def new_tracer(x): return tracer if (obj, attr) not in frame.attrs_tracked: - init_val = getattr(obj, attr) + init_val = getattr(obj, attr, dne_sentinel) frame.attrs_inits.append(init_val) init_vals, init_tree = tree_flatten(init_val) tracers = map(new_tracer, init_vals) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 2334a7b98f91..169df3712899 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -360,6 +360,49 @@ def body(i, _): return i + 1, None _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash + @parameterized.parameters([True, False]) + def test_setattr_doesnt_exist(self, jit): + class Thing: + ... + thing = Thing() + + def f(x): + assert (not jit) or tracing_is_ok + jax_setattr(thing, 'x', x) + + if jit: + f = jax.jit(f) + + tracing_is_ok = True + self.assertFalse(hasattr(thing, 'x')) + f(1.0) + self.assertEqual(thing.x, 1.0) + f(2.0) + self.assertEqual(thing.x, 2.0) + + tracing_is_ok = False + f(3.0) + self.assertEqual(thing.x, 3.0) + + del thing.x + f(4.0) + self.assertEqual(thing.x, 4.0) + + tracing_is_ok = True + f(5) + self.assertEqual(thing.x, 5) + + def test_setattr_doesnt_exist_doesnt_leave_sentinel_around(self): + class Thing: + ... + thing = Thing() + + def f(x): + jax_setattr(thing, 'x', x) + + jax.make_jaxpr(f)(3.) + self.assertFalse(hasattr(thing, 'x')) + class AttrsJVPTest(jtu.JaxTestCase): From ca30ce69197f90a12003b654b7697272a7a44c88 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 25 Mar 2025 03:40:28 -0700 Subject: [PATCH 0138/1769] [Mosaic GPU] Add warpgroup lowering for `AxisIndex` in Pallas. 
PiperOrigin-RevId: 740280136
---
 jax/_src/pallas/mosaic_gpu/lowering.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 607b3028f93b..677f63c6674a 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1738,6 +1738,7 @@ def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value:
 
 
 @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Warpgroup)
 def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
   axis_names = ctx.module_ctx.axis_names
   if not axis_names or axis_name not in axis_names:

From fce11d0e472c2479cba6869262bc117cc20b95e7 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Tue, 25 Mar 2025 03:53:27 -0700
Subject: [PATCH 0139/1769] [Mosaic GPU] Use `math.inf` instead of `None` when
 short-cutting default layout inference.

default_vector_size is initialized with `math.inf` and is never `None`.

PiperOrigin-RevId: 740283678
---
 jax/experimental/mosaic/gpu/layout_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index dec75e4db1a0..402a8c08a4ef 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -564,7 +564,7 @@ def update_default_vector_size(op: ir.OpView):
   for op in module.body:
     traverse_op(op, update_default_vector_size)
 
-  if default_vector_size is None:  # Nothing to annotate.
+  if default_vector_size == math.inf:  # Nothing to annotate.
     return
 
   def to_default_layout(ty: ir.Type) -> ir.Attribute | None:

From 9bbff1e4469bc0078edc6176e006d861411c8c00 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 25 Mar 2025 04:27:39 -0700
Subject: [PATCH 0140/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/d505fef9c5eb6cc1bf282fdf62139783d7fe4ec5.

PiperOrigin-RevId: 740293121
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 996ee511f835..8fcda2281ea7 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "9a8dd0796bcfeb00e4e6d09d74726db5c7d4a003"
-XLA_SHA256 = "4e3248d37a1b0598de3e93e8e46ede060578bc45bfbdfaf24d91ab598543b770"
 
+XLA_COMMIT = "d505fef9c5eb6cc1bf282fdf62139783d7fe4ec5"
+XLA_SHA256 = "4fe51bd389428ce65415b08693f966b142fe8218ced771becab9033503a70a3d"
 
 def repo():
     tf_http_archive(

From 4ed257065a5d4de6e24826cf6546bedada78985f Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 25 Mar 2025 05:45:27 -0700
Subject: [PATCH 0141/1769] Fix ODR problem in jax_jit.h.

We need to include the type caster for std::string_view if we use
`nb::cast<std::string_view>`.

PiperOrigin-RevId: 740311318
---
 jaxlib/xla/jax_jit.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h
index a000ef6773b2..254ed11ba78c 100644
--- a/jaxlib/xla/jax_jit.h
+++ b/jaxlib/xla/jax_jit.h
@@ -34,6 +34,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "nanobind/nanobind.h"
+#include "nanobind/stl/string_view.h"  // IWYU pragma: keep
 #include "jaxlib/xla/pytree.h"
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/python/nb_helpers.h"

From ad7550de6de003f88d38417be466411c620ce5c4 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Tue, 25 Mar 2025 06:08:15 -0700
Subject: [PATCH 0142/1769] [Mosaic GPU] Add warpgroup lowering for
 `SetMaxRegisters` in Pallas.

PiperOrigin-RevId: 740318556
---
 jax/_src/pallas/mosaic_gpu/primitives.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 9665f14254f8..fe28766cfb96 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -952,6 +952,7 @@ def _set_max_registers_abstract_eval(n, *, action):
 
 
 @lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane)
+@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Warpgroup)
 def _set_max_registers_lowering(
     ctx: lowering.LoweringRuleContext, n, *, action
 ):

From 411450b8b896f374758127ba109f93db5b75e742 Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Tue, 25 Mar 2025 06:41:04 -0700
Subject: [PATCH 0143/1769] Fix Jax XLA FFI callback handlers for OSS GPU.

OSS Jax builds for GPU backends split `jaxlib` into three wheels and since we
cannot expect a stable C++ ABI among the shared libraries, we refactor to
ensure:

1. C++ objects are not created/consumed by different shared libraries.
2. Static objects are declared and defined appropriately.

This PR:

1. Migrates Jax XLA FFI callback handlers from XLA's Internal FFI API to the
   [External FFI API](https://github.com/openxla/xla/tree/main/xla/ffi#xla-ffi-external-vs-internal-apis).
   Note that we update both CPU and GPU handlers because we cannot mix
   Internal and External APIs.
2. Updates how FFI GPU handlers are registered, now analogous to how the
   original GPU custom call was registered.
3. Adds an `xla::ffi::ExecutionContext` member to
   `ifrt::PjRtLoadedExecutable` holding opaque pointers to callbacks.
4. Updates Jax `callback.py` to call the new FFI callback handlers.

PiperOrigin-RevId: 740327296
---
 jax/_src/callback.py                 | 146 +++++++++++++++--------
 jax_plugins/cuda/__init__.py         |   4 +
 jax_plugins/rocm/__init__.py         |   4 +
 jaxlib/cuda/BUILD                    |   9 ++
 jaxlib/cuda/cuda_plugin_extension.cc |  11 ++
 jaxlib/gpu/py_client_gpu.cc          | 168 +++++++++++++++++++++++++++
 jaxlib/gpu/py_client_gpu.h           |   3 +
 jaxlib/rocm/BUILD                    |   9 ++
 jaxlib/rocm/rocm_plugin_extension.cc |  12 ++
 jaxlib/xla/xla_client.py             |   2 +-
 10 files changed, 318 insertions(+), 50 deletions(-)

diff --git a/jax/_src/callback.py b/jax/_src/callback.py
index bdceb98d92b7..92c275e7e924 100644
--- a/jax/_src/callback.py
+++ b/jax/_src/callback.py
@@ -33,6 +33,7 @@
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge as xb
+from jax._src.lib import xla_extension_version
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -200,7 +201,11 @@ def _callback_op_sharding(
   # program has bulk array semantics, so we run the callback with a MAXIMAL
   # sharding and hence execute it only once on the full logical value).
   if config.use_shardy_partitioner.value:
-    op_sharding = sharding_impls.SdyArrayShardingList([
+    # For shardy, we need to have the same number of shardy annotations as the
+    # number of result ops.
If there are no result ops, we need 1 shardy + # annotation. + num_sdy_shardings = max(1, len(avals_out)) + op_sharding = sharding_impls.SdyArrayShardingList(num_sdy_shardings * [ sharding_impls.SdyArraySharding( mesh_shape=(), dimension_shardings=[], @@ -592,7 +597,6 @@ def io_callback( return tree_util.tree_unflatten(out_tree, out_flat) - def is_empty_shape(s: core.Shape) -> bool: return any(d == 0 for d in s) @@ -822,55 +826,99 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None - result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) - if token: + if xla_extension_version <= 320: + result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) + if token: + + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + + operand_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes + ] + result_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes + ] + operands = [token, *operands] + result_types = [mlir.token_type(), *result_types] + operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] + result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] + callback_descriptor, ifrt_callback = ( + backend.get_emit_python_callback_descriptor(_wrapped_callback, + operand_shapes, + result_shapes)) + ctx.module_context.add_host_callback(ifrt_callback) + descriptor_operand = mlir.ir_constant(callback_descriptor) + callback_operands = [descriptor_operand, *operands] + if operand_mlir_layouts is not None: + operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] + result_type = ir.TupleType.get_tuple(result_types) + call_target_name = ("xla_python_gpu_callback" + if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") + result = hlo.CustomCallOp( + [result_type], + callback_operands, + call_target_name=ir.StringAttr.get(call_target_name), + has_side_effect=ir.BoolAttr.get(has_side_effect), + api_version=mlir.i32_attr(2), + called_computations=ir.ArrayAttr.get([]), + backend_config=ir.StringAttr.get(str(callback_descriptor)), + operand_layouts=( + None if operand_mlir_layouts is None + else ir.ArrayAttr.get(operand_mlir_layouts)), + result_layouts=( + None if result_mlir_layouts is None + else ir.ArrayAttr.get(result_mlir_layouts))) + if sharding is not None: + mlir.set_sharding(result, sharding) + results = [ + hlo.get_tuple_element(result, mlir.i32_attr(i)) + for i in range(len(result_types)) + ] + else: + call_target_name = ( + "xla_ffi_python_gpu_callback" + if platform in {"cuda", "rocm"} + else "xla_ffi_python_cpu_callback" + ) + if token: + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + operands = [token, *operands] + if ( + config.use_shardy_partitioner.value + and sharding is not None + and len(ctx.avals_out) > 0 + and isinstance(sharding, sharding_impls.SdyArrayShardingList) + ): + # Add a sharding annotation for the token if we have at least one + # output. Otherwise, the single shardy annotation required of all ops + # (even those without any results) can annotate the token. 
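+      # (The token is prepended to the avals below while its sharding entry
+      # is appended here; every entry built by _callback_op_sharding is the
+      # same maximal sharding, so the order does not matter.)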
+ sharding = sharding_impls.SdyArrayShardingList( + [*sharding.shardings, sharding.shardings[-1]] + ) + ctx = dataclasses.replace( + ctx, + avals_in=[core.abstract_token, *ctx.avals_in], + avals_out=[core.abstract_token, *ctx.avals_out], + ) - callback_without_token = _wrapped_callback - def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined - return (token, *callback_without_token(*args)) + # TODO(dsuo): Remove this line once we deprecate the XLA custom call + # handler. + ifrt_callback = _wrapped_callback + ctx.module_context.add_host_callback(ifrt_callback) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + result = ffi.build_ffi_lowering_function( # type: ignore + call_target_name, + has_side_effect=has_side_effect, + )(ctx, *operands, index=np.uint64(index)) - operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes - ] - result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes - ] - operands = [token, *operands] - result_types = [mlir.token_type(), *result_types] - operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] - result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, ifrt_callback = ( - backend.get_emit_python_callback_descriptor(_wrapped_callback, - operand_shapes, - result_shapes)) - ctx.module_context.add_host_callback(ifrt_callback) - descriptor_operand = mlir.ir_constant(callback_descriptor) - callback_operands = [descriptor_operand, *operands] - if operand_mlir_layouts is not None: - operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] - result_type = ir.TupleType.get_tuple(result_types) - call_target_name = ("xla_python_gpu_callback" - if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = hlo.CustomCallOp( - [result_type], - callback_operands, - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - api_version=mlir.i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(str(callback_descriptor)), - operand_layouts=( - None if operand_mlir_layouts is None - else ir.ArrayAttr.get(operand_mlir_layouts)), - result_layouts=( - None if result_mlir_layouts is None - else ir.ArrayAttr.get(result_mlir_layouts))) - if sharding is not None: - mlir.set_sharding(result, sharding) - results = [ - hlo.get_tuple_element(result, mlir.i32_attr(i)) - for i in range(len(result_types)) - ] + if sharding is not None: + mlir.set_sharding(result, sharding) + + results = result.results # type: ignore if token: token, *results = results return results, token, ifrt_callback diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index f6540e986024..2b02621c89f5 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -94,6 +94,10 @@ def initialize(): ) for _name, _value in cuda_plugin_extension.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") + for _name, _value in cuda_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='CUDA', api_version=1 + ) xla_client.register_custom_type_id_handler( "CUDA", functools.partial( diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index c48a681bf337..0699ae1e34a1 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -94,6 +94,10 @@ def initialize(): ) for _name, _value in 
rocm_plugin_extension.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") + for _name, _value in rocm_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='ROCM', api_version=1 + ) xla_client.register_custom_type_id_handler( "ROCM", functools.partial( diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index ee32888864dd..5cd7283ea3fc 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -668,19 +668,28 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 6655128b9842..63375921e3be 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -48,12 +48,23 @@ nb::dict Registrations() { jax::EncapsulateFunction(jax::cuda::XlaPythonGpuCallback); return dict; } +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::cuda::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::cuda::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + return dict; +} } // namespace NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); m.def("registrations", &Registrations); + m.def("ffi_registrations", &FfiRegistrations); m.def( "get_device_ordinal", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 3e140411770d..71d327ffdb28 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "jaxlib/gpu/py_client_gpu.h" + +#include + #include #include #include @@ -20,24 +24,31 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" #include "xla/python/nb_numpy.h" +#include "xla/python/types.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/shape_util.h" namespace nb = nanobind; @@ -155,5 +166,162 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "xla_python_gpu_callback", &XlaPythonGpuCallback, absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME)); + +struct GpuTransposePlanCache { + static xla::ffi::TypeId id; + explicit GpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; +xla::ffi::TypeId GpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(xla::ffi::GetXlaFfiApi(), "GpuTransposePlanCache", + &GpuTransposePlanCache::id); + +static xla::ffi::ErrorOr> +GpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kGpuTransposePlanCacheInstantiate, GpuTransposePlanCacheInstantiate, + xla::ffi::Ffi::BindInstantiate().Attr("index")); +xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, + xla::FfiLoadedHostCallbacks* callbacks, + GpuTransposePlanCache* transpose_cache, + uint64_t index, + xla::ffi::RemainingArgs args, + xla::ffi::RemainingRets rets) { + size_t arity = args.size(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + void* buf = new char[arg->size_bytes()]; + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. 
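+    // A hedged sketch of the pinned-memory idea in the TODO above (hypothetical,
+    // not part of this change): allocating the staging buffer with a pinned
+    // allocator such as cudaMallocHost (hipHostMalloc on ROCm) instead of
+    // new char[] would let gpuMemcpyAsync run as a true asynchronous DMA, e.g.:
+    //   void* buf = nullptr;
+    //   cudaMallocHost(&buf, arg->size_bytes());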
+ auto gpu_res = + gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), + gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + continue; + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, + host_input_buffers[i], base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + + xla::EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(host_input_arrays)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return xla::ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + xla::LeaveHostCallback(); + + std::vector temp_buffers; + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + auto array = xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. 
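+    // For instance, a C-contiguous (default numpy) float32 array of shape
+    // (2, 3) has byte strides (12, 4); outputs with any other strides fall
+    // through to the transpose-plan path below before the host-to-device copy.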
+ auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = xla::ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return xla::ffi::Error::Internal( + maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = xla::ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + continue; + } + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.rank()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), temp); + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } + return xla::ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, + xla::ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index e9454504f5d9..06a955365c0b 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" namespace jax { @@ -28,6 +29,8 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 99df757018f3..522aa8da0145 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -566,19 +566,28 @@ cc_library( features = ["-use_header_modules"], deps = [ ":hip_vendor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 454f4741d667..642467a9afef 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -72,12 +72,24 @@ nb::dict Registrations() { jax::EncapsulateFunction(jax::hip::XlaPythonGpuCallback); return dict; } +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::hip::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::hip::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + return dict; +} } // namespace NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); m.def("registrations", &Registrations); + m.def("ffi_registrations", &FfiRegistrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index a111c14232de..0e4eebdfb26f 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 320 +_version = 321 # Version number for MLIR:Python components. mlir_api_version = 58 From a58592ebb0c217cee4279c2afcec0f2d73688f0b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Mar 2025 06:46:19 -0700 Subject: [PATCH 0144/1769] Finalize some deprecations from jax.lib.xla_client --- CHANGELOG.md | 3 +++ jax/lib/xla_client.py | 32 +++++++++++++------------------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17fb421fcc06..1acb2b48eab6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. 
* {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` instead. + * Several previously-deprecated APIs have been removed, including: + * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, + and `shape_from_pyval`. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 86e7307c804b..07c6914a1f59 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.lax.fft import FftType as _FftType from jax._src.lib import xla_client as _xc get_topology_for_devices = _xc.get_topology_for_devices @@ -48,23 +47,27 @@ ), None, ), - # Added Oct 10 2024 + # Finalized 2025-03-25; remove after 2025-06-25 "FftType": ( - "jax.lib.xla_client.FftType is deprecated; use jax.lax.FftType.", - _FftType, + "jax.lib.xla_client.FftType was removed in JAX v0.6.0; use jax.lax.FftType.", + None, ), "PaddingType": ( ( - "jax.lib.xla_client.PaddingType is deprecated; this type is unused" - " by JAX so there is no replacement." + "jax.lib.xla_client.PaddingType was removed in JAX v0.6.0;" + " this type is unused by JAX so there is no replacement." ), - _xc.PaddingType, + None, ), - # Added Oct 11 2024 "dtype_to_etype": ( - "dtype_to_etype is deprecated; use StableHLO instead.", - _xc.dtype_to_etype, + "dtype_to_etype was removed in JAX v0.6.0; use StableHLO instead.", + None, + ), + "shape_from_pyval": ( + "shape_from_pyval was removed in JAX v0.6.0; use StableHLO instead.", + None, ), + # Added Oct 11 2024 "ops": ( "ops is deprecated; use StableHLO instead.", _xc.ops, @@ -74,10 +77,6 @@ "(https://jax.readthedocs.io/en/latest/ffi.html)", _xc.register_custom_call_target, ), - "shape_from_pyval": ( - "shape_from_pyval is deprecated; use StableHLO instead.", - _xc.shape_from_pyval, - ), "PrimitiveType": ( "PrimitiveType is deprecated; use StableHLO instead.", _xc.PrimitiveType, @@ -104,14 +103,10 @@ import typing as _typing if _typing.TYPE_CHECKING: - dtype_to_etype = _xc.dtype_to_etype ops = _xc.ops register_custom_call_target = _xc.register_custom_call_target - shape_from_pyval = _xc.shape_from_pyval ArrayImpl = _xc.ArrayImpl Device = _xc.Device - FftType = _FftType - PaddingType = _xc.PaddingType PrimitiveType = _xc.PrimitiveType Shape = _xc.Shape XlaBuilder = _xc.XlaBuilder @@ -123,5 +118,4 @@ __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing -del _FftType del _xc From 3c63f600000423df181f114345d1fe56821dcd94 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 25 Mar 2025 07:17:05 -0700 Subject: [PATCH 0145/1769] [JAX] [XLA:Python] Migrate py_socket_transfer to JAX. Also in passing fix up some header guards and authorship comments. 
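For example (illustrative only, inferred from the include changes in this
patch), downstream C++ that previously wrote

  #include "xla/python/transfer/py_socket_transfer.h"

now picks the header up from its new home:

  #include "jaxlib/xla/py_socket_transfer.h"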
PiperOrigin-RevId: 740337166 --- jaxlib/xla/BUILD | 39 ++- jaxlib/xla/config.h | 6 +- jaxlib/xla/custom_call_sharding.h | 6 +- jaxlib/xla/dlpack.h | 6 +- jaxlib/xla/jax_jit.h | 6 +- jaxlib/xla/mlir.h | 6 +- jaxlib/xla/pjit.h | 6 +- jaxlib/xla/pmap_lib.h | 6 +- jaxlib/xla/py_socket_transfer.cc | 409 ++++++++++++++++++++++++++++++ jaxlib/xla/py_socket_transfer.h | 26 ++ jaxlib/xla/pytree.cc | 2 +- jaxlib/xla/pytree.h | 2 +- jaxlib/xla/sdy.h | 6 +- jaxlib/xla/weakref_lru_cache.h | 6 +- jaxlib/xla/xla.cc | 2 +- jaxlib/xla/xla_compiler.h | 6 +- 16 files changed, 506 insertions(+), 34 deletions(-) create mode 100644 jaxlib/xla/py_socket_transfer.cc create mode 100644 jaxlib/xla/py_socket_transfer.h diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 2edc183bc49b..e562cb7e84ea 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -134,10 +134,10 @@ nanobind_extension( ], "@xla//xla/tsl:windows": [], "//conditions:default": [ + ":py_socket_transfer", "@gloo//:transport_tcp", "@xla//xla/backends/cpu/collectives:gloo_collectives", "@xla//xla/backends/cpu/collectives:gloo_kv_store", - "@xla//xla/python/transfer:py_socket_transfer", ], }) + select({ # mpitrampoline does not build on windows @@ -414,6 +414,43 @@ cc_library( ], ) +cc_library( + name = "py_socket_transfer", + srcs = ["py_socket_transfer.cc"], + hdrs = ["py_socket_transfer.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/transfer:event_loop", + "@xla//xla/python/transfer:socket-server", + "@xla//xla/python/transfer:socket_bulk_transport", + "@xla//xla/python/transfer:streaming", + "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + proto_library( name = "pytree_proto", srcs = ["pytree.proto"], diff --git a/jaxlib/xla/config.h b/jaxlib/xla/config.h index 40847bf4a370..2a9281f498b4 100644 --- a/jaxlib/xla/config.h +++ b/jaxlib/xla/config.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ +#ifndef JAXLIB_XLA_CONFIG_H_ +#define JAXLIB_XLA_CONFIG_H_ #include @@ -31,4 +31,4 @@ void BuildConfigSubmodule(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ +#endif // JAXLIB_XLA_CONFIG_H_ diff --git a/jaxlib/xla/custom_call_sharding.h b/jaxlib/xla/custom_call_sharding.h index c3470901f53e..5a5f3776cc30 100644 --- a/jaxlib/xla/custom_call_sharding.h +++ b/jaxlib/xla/custom_call_sharding.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ +#ifndef JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ +#define JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildCustomCallShardingPybindAPI(nanobind::module_& m); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ +#endif // JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index 5d7fd7c10bf8..d0079b1d4914 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ +#ifndef JAXLIB_XLA_DLPACK_H_ +#define JAXLIB_XLA_DLPACK_H_ #include #include @@ -54,4 +54,4 @@ absl::StatusOr PrimitiveTypeToNbDLDataType( } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ +#endif // JAXLIB_XLA_DLPACK_H_ diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index 254ed11ba78c..9e6f8e34f1e9 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#ifndef JAXLIB_XLA_JAX_JIT_H_ +#define JAXLIB_XLA_JAX_JIT_H_ #include @@ -263,4 +263,4 @@ void BuildJaxjitSubmodule(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#endif // JAXLIB_XLA_JAX_JIT_H_ diff --git a/jaxlib/xla/mlir.h b/jaxlib/xla/mlir.h index f0bfd69bca6b..ee95f5f95921 100644 --- a/jaxlib/xla/mlir.h +++ b/jaxlib/xla/mlir.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ +#ifndef JAXLIB_XLA_MLIR_H_ +#define JAXLIB_XLA_MLIR_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildMlirSubmodule(nanobind::module_& m); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ +#endif // JAXLIB_XLA_MLIR_H_ diff --git a/jaxlib/xla/pjit.h b/jaxlib/xla/pjit.h index 545fb2307783..8d47347ab9a2 100644 --- a/jaxlib/xla/pjit.h +++ b/jaxlib/xla/pjit.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ +#ifndef JAXLIB_XLA_PJIT_H_ +#define JAXLIB_XLA_PJIT_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -24,4 +24,4 @@ namespace jax { void BuildPjitSubmodule(nanobind::module_& m); } -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ +#endif // JAXLIB_XLA_PJIT_H_ diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h index 9ad60a03daf6..e02311e03c73 100644 --- a/jaxlib/xla/pmap_lib.h +++ b/jaxlib/xla/pmap_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ +#ifndef JAXLIB_XLA_PMAP_LIB_H_ +#define JAXLIB_XLA_PMAP_LIB_H_ #include #include @@ -34,4 +34,4 @@ void BuildPmapSubmodule(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ +#endif // JAXLIB_XLA_PMAP_LIB_H_ diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc new file mode 100644 index 000000000000..dd2c02898e18 --- /dev/null +++ b/jaxlib/xla/py_socket_transfer.cc @@ -0,0 +1,409 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "jaxlib/xla/py_socket_transfer.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/array.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/py_array.h" +#include "xla/python/py_client.h" +#include "xla/python/to_ifrt_sharding.h" +#include "xla/python/traceback.h" +#include "xla/python/transfer/event_loop.h" +#include "xla/python/transfer/socket-server.h" +#include "xla/python/transfer/socket_bulk_transport.h" +#include "xla/python/transfer/streaming.h" +#include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace aux { + +namespace nb = nanobind; + +absl::StatusOr MemorySpaceFromSharding( + const xla::ifrt::Sharding& sharding) { + if (sharding.devices()->devices().size() != 1) { + return xla::InvalidArgument( + "Can only convert SingleDeviceSharding to MemorySpace not %s", + sharding.DebugString()); + } + auto* device = sharding.devices()->devices()[0]; + if (sharding.memory_kind().memory_kind().has_value()) { + // Find `PjRtMemorySpace` that is associated with the sharding's device + // and matches the sharding's memory_kind. 
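+    // (A device generally exposes several memory spaces, e.g. default device
+    // memory plus a pinned-host space, so the kind string is what picks one
+    // out; the scan over device->Memories() below is a linear search.)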
+ xla::ifrt::Memory* memory = nullptr; + for (xla::ifrt::Memory* ms : device->Memories()) { + if (ms->Kind() == sharding.memory_kind()) { + memory = ms; + break; + } + } + if (memory == nullptr) { + return xla::InvalidArgument( + "Invalid memory kind: %s; available memory kinds: %s", + *sharding.memory_kind().memory_kind(), + absl::StrJoin(sharding.devices()->devices().front()->Memories(), ", ", + [](std::string* out, xla::ifrt::Memory* ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + })); + } + return tensorflow::down_cast(memory)->pjrt_memory(); + } else { + if (!device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device %s", + device->DebugString()); + } + return tensorflow::down_cast(device) + ->pjrt_device() + ->default_memory_space(); + } +} + +class IfrtArrayEntry : public PullTable::Entry { + public: + struct BufferRef { + tsl::RCReference arr; + xla::PjRtBuffer* buffer; + size_t buf_size; + }; + explicit IfrtArrayEntry(std::vector arrs, + std::shared_ptr state, + size_t xfer_size) + : arrs_(std::move(arrs)), state_(state), xfer_size_(xfer_size) {} + bool Handle(tsl::RCReference state, + const SocketTransferPullRequest& req, + size_t base_req_id) override { + for (uint64_t bid : req.buffer_ids()) { + auto req_id = base_req_id; + ++base_req_id; + for (size_t i = 0; i * xfer_size_ < arrs_[bid].buf_size; ++i) { + DmaCopyChunk blob; + blob.arr = std::move(arrs_[bid].arr); + blob.buffer = arrs_[bid].buffer; + blob.buffer_id = bid; + blob.offset = i * xfer_size_; + blob.size = std::min(xfer_size_, arrs_[bid].buf_size - blob.offset); + bool is_largest = blob.size + blob.offset == arrs_[bid].buf_size; + state_->ScheduleCopy( + blob, [req_id, state, copier_state = state_, is_largest]( + PremappedCopierState* copier_state_ptr, void* buf, + const DmaCopyChunk& chunk) { + state->Send( + req_id, buf, chunk.offset, chunk.size, is_largest, + [copier_state, buf]() { copier_state->ReturnBuffer(buf); }); + }); + } + } + + num_consumed_bufs_ += req.buffer_ids().size(); + return num_consumed_bufs_ == arrs_.size(); + } + + private: + absl::Mutex mu_; + size_t num_consumed_bufs_ = 0; + std::vector arrs_; + std::shared_ptr state_; + size_t xfer_size_; +}; + +absl::StatusOr> CreatePullEntry( + const std::vector>& arrs, + std::shared_ptr state, size_t xfer_size) { + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); + refs.push_back({arr, pjrt_buf.get(), buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); +} + +class PyTransferServerConnection { + public: + explicit PyTransferServerConnection( + tsl::RCReference conn) + : conn_(std::move(conn)) {} + + void Pull(uint64_t uuid, std::vector buffer_ids, + std::vector> pull_dests) { + for (size_t i = 0; i < buffer_ids.size(); ++i) { + conn_->Pull(uuid, buffer_ids[i], std::move(pull_dests[i])); + } + } + + private: + tsl::RCReference conn_; +}; + +class PyTransferServer { + public: + PyTransferServer() = default; + absl::Status Start(xla::ifrt::Client* client, size_t max_num_parallel_copies, + size_t xfer_size, const SocketAddress& addr, + const std::vector& transport_addresses) { + std::shared_ptr factory; + if (transport_addresses.empty()) { + factory = BulkTransportFactory::CreateLocal(); + } 
else { + auto tmp = xla::ValueOrThrow( + AllocateAlignedMemory(xfer_size * max_num_parallel_copies)); + SlabAllocator uallocator(xla::ValueOrThrow(MapPjrtMemory( + client, tmp->data(), tmp->size(), tmp)), + xfer_size); + factory = xla::ValueOrThrow(CreateSocketBulkTransportFactory( + transport_addresses, std::nullopt, uallocator)); + } + + server_ = std::make_shared(); + + TF_ASSIGN_OR_RETURN(auto mem, + AllocateAndMapPjrtMemory( + client, max_num_parallel_copies * xfer_size * 2)); + premapped_copier_ = std::make_shared( + mem, max_num_parallel_copies, xfer_size); + xfer_size_ = xfer_size; + return server_->Start(addr, factory); + } + std::string address() { return server_->addr().ToString(); } + + PyTransferServerConnection Connect(const std::string& saddr) { + return PyTransferServerConnection( + server_->Connect(xla::ValueOrThrow(SocketAddress::Parse(saddr)))); + } + + void AwaitPull(uint64_t uuid, + const std::vector>& arrs) { + server_->AwaitPull(uuid, xla::ValueOrThrow(CreatePullEntry( + arrs, premapped_copier_, xfer_size_))); + } + + size_t xfer_size() { return xfer_size_; } + + std::shared_ptr premapped_copier() { + return premapped_copier_; + } + + private: + std::shared_ptr server_; + std::shared_ptr premapped_copier_; + size_t xfer_size_; +}; + +absl::StatusOr ArraySpecFromShapeDtypeStruct( + nb::handle aval) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DType dtype, + xla::DtypeToIfRtDType( + nb::borrow(aval.attr("dtype").ptr()))); + auto shape_dims = nb::cast>(aval.attr("shape")); + auto shape = xla::ifrt::Shape( + xla::ifrt::Shape::Dimensions(shape_dims.begin(), shape_dims.end())); + TF_ASSIGN_OR_RETURN(auto sharding, + xla::GetIfrtHloSharding(aval.attr("sharding"), shape)); + return xla::ifrt::ArraySpec{dtype, std::move(shape), std::move(sharding)}; +} + +struct BufferSource { + tsl::RCReference arr; + xla::PjRtBuffer* buffer; +}; + +struct CopyDests { + std::vector shape_specs; + xla::PjRtMemorySpace* memory_space; +}; + +void RegisterTransferServerTypes(nanobind::module_& m) { + nb::class_(m, "TransferConnection") + .def("_pull_flat", [](PyTransferServerConnection& self, uint64_t uuid, + xla::nb_class_ptr py_client, + std::vector py_avals) { + auto* ifrt_client = llvm::dyn_cast_or_null( + py_client->ifrt_client()); + if (ifrt_client == nullptr) { + xla::ThrowIfError(absl::InvalidArgumentError( + "_pull_flat only supported on pjrt-ifrt clients.")); + } + + std::vector avals; + std::vector shardings; + shardings.reserve(py_avals.size()); + avals.reserve(py_avals.size()); + for (const auto& py_aval : py_avals) { + avals.push_back( + xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval))); + shardings.push_back(py_aval.attr("sharding")); + } + + std::vector dests; + std::vector> fetch_idxs; + absl::flat_hash_map mapping; + std::vector>> buffer_list; + + for (auto& aval : avals) { + std::vector> buf_list; + auto prim_type = + xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype)); + auto shards = xla::ValueOrThrow(aval.sharding->Disassemble( + aval.shape, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + buf_list.reserve(shards.size()); + for (auto& shard : shards) { + auto* mem_space = + xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second)); + int dest_idx = + mapping.emplace(mem_space, static_cast(dests.size())) + .first->second; + if (dest_idx == dests.size()) { + dests.emplace_back(); + dests.back().memory_space = mem_space; + } + fetch_idxs.push_back( + {dest_idx, + static_cast(dests[dest_idx].shape_specs.size())}); + buf_list.push_back(fetch_idxs.back()); + 
dests[dest_idx].shape_specs.push_back( + {prim_type, xla::DimensionVector(shard.first.dims().begin(), + shard.first.dims().end())}); + } + buffer_list.push_back(std::move(buf_list)); + } + + std::vector< + std::shared_ptr> + atms; + atms.reserve(dests.size()); + + for (auto& dest : dests) { + atms.push_back(xla::ValueOrThrow( + py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice( + dest.shape_specs, std::nullopt, dest.memory_space))); + } + + std::vector> pull_dests; + std::vector buffer_ids; + pull_dests.reserve(fetch_idxs.size()); + buffer_ids.reserve(fetch_idxs.size()); + for (auto& fetch_idx : fetch_idxs) { + auto& atm = atms[fetch_idx.first]; + pull_dests.push_back(MakeDmaDestination( + atm, fetch_idx.second, atm->buffer_size(fetch_idx.second))); + buffer_ids.push_back(static_cast(buffer_ids.size())); + } + + self.Pull(uuid, buffer_ids, std::move(pull_dests)); + + std::vector out; + auto traceback = xla::Traceback::Get(); + for (size_t i = 0; i < buffer_list.size(); ++i) { + xla::ifrt::PjRtArray::PjRtBuffers buffers; + buffers.reserve(buffer_list[i].size()); + for (auto& v : buffer_list[i]) { + buffers.push_back(atms[v.first]->RetrieveBuffer(v.second)); + } + auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( + ifrt_client, avals[i].dtype, avals[i].shape, avals[i].sharding, + std::move(buffers), avals[i].layout)); + out.push_back(xla::PyArray::MakeFromIfrtArrayAndSharding( + py_client, traceback, std::move(arr), shardings[i], false, true, + /*skip_checks=*/false)); + } + + return out; + }); + + nb::class_(m, "TransferServer") + .def("address", [](PyTransferServer& self) { return self.address(); }) + .def("_await_pull_flat", + [](PyTransferServer& self, uint64_t uuid, + std::vector inputs) { + std::vector> arrs; + arrs.reserve(inputs.size()); + for (const xla::PyArray& input : inputs) { + arrs.push_back(tsl::FormRef(input.ifrt_array())); + } + self.AwaitPull(uuid, arrs); + }) + .def("connect", [](PyTransferServer& self, const std::string& address) { + return self.Connect(address); + }); + + m.def( + "start_transfer_server", + [](xla::nb_class_ptr py_client, std::string address, + std::vector transport_addresses_str, + size_t max_num_parallel_copies, + size_t transfer_size) -> PyTransferServer { + PyTransferServer result; + std::vector transport_addresses; + transport_addresses.reserve(transport_addresses_str.size()); + for (const std::string& addr : transport_addresses_str) { + transport_addresses.push_back( + xla::ValueOrThrow(SocketAddress::Parse(addr))); + } + xla::ThrowIfError(result.Start( + py_client->ifrt_client(), max_num_parallel_copies, transfer_size, + xla::ValueOrThrow(SocketAddress::Parse(address)), + transport_addresses)); + return result; + }, + nb::arg("client"), nb::arg("address") = SocketAddress().ToString(), + nb::arg("transport_addresses") = std::vector(), + nb::arg("max_num_parallel_copies") = 8, + nb::arg("transfer_size") = 256 * 1024 * 1024); +} + +} // namespace aux diff --git a/jaxlib/xla/py_socket_transfer.h b/jaxlib/xla/py_socket_transfer.h new file mode 100644 index 000000000000..fa477f24e3e5 --- /dev/null +++ b/jaxlib/xla/py_socket_transfer.h @@ -0,0 +1,26 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ +#define JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ + +#include "nanobind/nanobind.h" + +namespace aux { + +void RegisterTransferServerTypes(nanobind::module_& m); + +} // namespace aux + +#endif // JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc index dd5a0bd9cf69..7d1f7676bada 100644 --- a/jaxlib/xla/pytree.cc +++ b/jaxlib/xla/pytree.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The OpenXLA Authors. +/* Copyright 2019 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h index 722fe41169a0..471d25af89bc 100644 --- a/jaxlib/xla/pytree.h +++ b/jaxlib/xla/pytree.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The OpenXLA Authors. +/* Copyright 2019 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/jaxlib/xla/sdy.h b/jaxlib/xla/sdy.h index 5d8c8c2eb7dd..ef075855decd 100644 --- a/jaxlib/xla/sdy.h +++ b/jaxlib/xla/sdy.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ +#ifndef JAXLIB_XLA_SDY_H_ +#define JAXLIB_XLA_SDY_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildSdySubmodule(nanobind::module_& m); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ +#endif // JAXLIB_XLA_SDY_H_ diff --git a/jaxlib/xla/weakref_lru_cache.h b/jaxlib/xla/weakref_lru_cache.h index 444e01cef575..7c75974d3d23 100644 --- a/jaxlib/xla/weakref_lru_cache.h +++ b/jaxlib/xla/weakref_lru_cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ +#ifndef JAXLIB_XLA_WEAKREF_LRU_CACHE_H_ +#define JAXLIB_XLA_WEAKREF_LRU_CACHE_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildWeakrefLRUCacheAPI(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ +#endif // JAXLIB_XLA_WEAKREF_LRU_CACHE_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index bd3ed3205fb2..54c94c57a734 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -71,9 +71,9 @@ limitations under the License. 
#if defined(__linux__)
#include "gloo/transport/tcp/attr.h"
#include "gloo/transport/tcp/device.h"
+#include "jaxlib/xla/py_socket_transfer.h"
#include "xla/backends/cpu/collectives/gloo_collectives.h"
#include "xla/backends/cpu/collectives/gloo_kv_store.h"
-#include "xla/python/transfer/py_socket_transfer.h"
#elif defined(__APPLE__)
#include "gloo/transport/uv/device.h"
#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT
diff --git a/jaxlib/xla/xla_compiler.h b/jaxlib/xla/xla_compiler.h
index f3ffe5fe9440..ca5bc762a7d8 100644
--- a/jaxlib/xla/xla_compiler.h
+++ b/jaxlib/xla/xla_compiler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_
-#define TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_
+#ifndef JAXLIB_XLA_XLA_COMPILER_H_
+#define JAXLIB_XLA_XLA_COMPILER_H_

 // placeholder for index annotation headers
 #include "nanobind/nanobind.h"

@@ -25,4 +25,4 @@ void BuildXlaCompilerSubmodule(nanobind::module_& m);

 } // namespace xla

-#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_
+#endif // JAXLIB_XLA_XLA_COMPILER_H_

From 4f9571eb2bd72ab893e0ec3df1bf08777a0cc7c1 Mon Sep 17 00:00:00 2001
From: Charles Hofer
Date: Mon, 24 Mar 2025 19:32:47 +0000
Subject: [PATCH 0146/1769] Fix auditwheel

---
 build/rocm/tools/build_wheels.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py
index fd98bbb8ec04..a7ebdf86f916 100644
--- a/build/rocm/tools/build_wheels.py
+++ b/build/rocm/tools/build_wheels.py
@@ -226,7 +226,10 @@ def fix_wheel(path, jax_path):
     py_bin = "/opt/python/cp310-cp310/bin"
     env["PATH"] = "%s:%s" % (py_bin, env["PATH"])

-    cmd = ["pip", "install", "auditwheel>=6"]
+    # NOTE(mrodden): auditwheel 6.0 added the lddtree module, but 6.3.0 changed
+    # the function to ldd and also changed its behavior, so constrain the
+    # range to 6.0 through 6.2.x
+    cmd = ["pip", "install", "auditwheel>=6,<6.3"]
     subprocess.run(cmd, check=True, env=env)

     fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")

From a7d46e6acc4aecee92dcfe68a8d3d86d21b3db3c Mon Sep 17 00:00:00 2001
From: Tori Baker
Date: Tue, 25 Mar 2025 07:48:50 -0700
Subject: [PATCH 0147/1769] Integrate Triton up to
 [cdb53266](https://github.com/openai/triton/commits/cdb53266e6c251d91a2c321d64e8466caff129a9)

PiperOrigin-RevId: 740345806
---
 jax/_src/pallas/triton/lowering.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index a0883ea589b0..c85c5f0a39c0 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -1779,6 +1779,12 @@ def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value:
   )


+def get_join_type(old_type: ir.RankedTensorType):
+  shape = old_type.shape
+  shape.append(2)
+  return ir.RankedTensorType.get(shape, old_type.element_type, old_type.encoding)
+
+
 @register_lowering(lax.concatenate_p)
 def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension):
   if len(args) != 2:
@@ -1793,9 +1799,10 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension):
     raise NotImplementedError(
         "Only arguments with shape [..., 1] are supported."
) - return tt_dialect.join( - _reshape(x, x_aval.shape[:-1]), _reshape(y, y_aval.shape[:-1]) - ) + lhs = _reshape(x, x_aval.shape[:-1]) + rhs = _reshape(y, y_aval.shape[:-1]) + ret_type = get_join_type(ir.RankedTensorType(rhs.type)) + return tt_dialect.join(ret_type, lhs, rhs) @register_lowering(lax.split_p) @@ -2102,10 +2109,11 @@ def _masked_load_lowering_rule( # most significant. Before jaxlib 0.5.2, the order was reversed. if is_contiguous_int4: msb_values = arith_dialect.shrui(values, _full(values.type, 4)) + join_type = get_join_type(ir.RankedTensorType(values.type)) if jaxlib_version < (0, 5, 2): - values = tt_dialect.join(msb_values, values) + values = tt_dialect.join(join_type, msb_values, values) else: - values = tt_dialect.join(values, msb_values) + values = tt_dialect.join(join_type, values, msb_values) shape = ir.RankedTensorType(values.type).shape values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1])) else: From 8260ab329145155da697a276591e73100993635f Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 25 Mar 2025 09:53:29 -0500 Subject: [PATCH 0148/1769] Address review comments --- jax/_src/pallas/core.py | 1 - jax/_src/pallas/mosaic/core.py | 4 +++- jax/_src/pallas/mosaic/lowering.py | 15 ++++++++++++-- jax/_src/pallas/primitives.py | 32 +++++++++-------------------- jax/experimental/pallas/__init__.py | 1 + jax/experimental/pallas/tpu.py | 15 ++++++++------ 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 389bbd3b0733..78b815820609 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -67,7 +67,6 @@ def __repr__(self): class semaphore_dtype(dtypes.extended): pass class semaphore(semaphore_dtype): pass -class dma_semaphore(semaphore_dtype): pass class barrier_semaphore(semaphore_dtype): pass @runtime_checkable diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 5d503779f092..fc4ecbedaca5 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -112,6 +112,8 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. 
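    # Illustrative usage (a sketch, not part of this change): calling a memory
    # space, e.g. TPUMemorySpace.SMEM((8,), jnp.int32), yields a MemoryRef that
    # can be passed to pallas_call (say, via scratch_shapes), mirroring the
    # plgpu.SMEM((256,), jnp.float32) pattern seen later in this series.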
return pallas_core.MemoryRef(shape, dtype, self) +class dma_semaphore(pallas_core.semaphore_dtype): pass + class AbstractSemaphoreTyRules: @staticmethod def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: @@ -141,7 +143,7 @@ class SemaphoreTy(AbstractSemaphoreTy): name = "sem" class DmaSemaphoreTy(AbstractSemaphoreTy): - type = pallas_core.dma_semaphore + type = dma_semaphore name = "dma_sem" class BarrierSemaphoreTy(AbstractSemaphoreTy): diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 3469ef4de952..00302494f67d 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -193,7 +193,7 @@ def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None def _dtype_to_ir_type(dtype: jnp.dtype, is_kernel_boundary: bool = False) -> ir.Type: if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): - if jnp.issubdtype(dtype, pallas_core.dma_semaphore): + if jnp.issubdtype(dtype, tpu_core.dma_semaphore): return ir.Type.parse("!tpu.dma_semaphore") elif jnp.issubdtype(dtype, pallas_core.semaphore): return ir.Type.parse("!tpu.semaphore") @@ -3367,7 +3367,18 @@ def _semaphore_read_lowering_rule( *args, args_tree, ): - sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, ctx.avals_in) + primitives.check_sem_avals( + sem_aval, + sem_transforms_avals, + "read", + allowed_semaphore_types={ + tpu_core.dma_semaphore, + pallas_core.semaphore, + pallas_core.barrier_semaphore, + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + }, + ) sem, transforms = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.sem_read(sem) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 5d3444ef719f..4971b83a9ba2 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1023,16 +1023,15 @@ def check_sem_avals( sem_shape = sem_transforms_avals[-1].get_indexer_shape() if sem_shape: raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") - # Uncomment when semaphore type works for Mosaic-GPU lowering - # sem_dtype = sem_aval.dtype - # if not any( - # jnp.issubdtype(sem_dtype, sem_type) - # for sem_type in allowed_semaphore_types - # ): - # raise ValueError( - # f"Must {name} semaphores of the following types:" - # f" {allowed_semaphore_types}." - # ) + sem_dtype = sem_aval.dtype + if not any( + jnp.issubdtype(sem_dtype, sem_type) + for sem_type in allowed_semaphore_types + ): + raise ValueError( + f"Must {name} semaphores of the following types:" + f" {allowed_semaphore_types}." 
+ ) def _transform_semaphore(ref_value, transforms, ref_aval): @@ -1063,18 +1062,7 @@ def _semaphore_read_abstract_eval( *avals, args_tree, ): - sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals( - sem_aval, - sem_transforms_avals, - "read", - allowed_semaphore_types={ - pallas_core.dma_semaphore, - pallas_core.semaphore, - pallas_core.barrier_semaphore, - pallas_core.SEMAPHORE_INTERPRET_DTYPE, - }, - ) + del avals, args_tree return jax_core.ShapedArray((), jnp.dtype("int32")) def _semaphore_read_discharge_rule(in_avals, diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index ea58fae3d283..fd523712fa9c 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -30,6 +30,7 @@ from jax._src.pallas.core import MemorySpace as MemorySpace from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import no_block_spec as no_block_spec +from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.core import Unblocked as Unblocked from jax._src.pallas.core import unblocked as unblocked from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index c81edaf76fa3..da054bf18309 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -17,11 +17,10 @@ from jax._src.pallas.mosaic import core as core from jax._src.pallas.mosaic.core import ARBITRARY as ARBITRARY from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh -from jax._src.pallas.core import dma_semaphore as dma_semaphore +from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec -from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams @@ -40,8 +39,6 @@ from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast as bitcast from jax._src.pallas.mosaic.primitives import delay as delay -from jax._src.pallas.primitives import device_id as device_id -from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy @@ -49,11 +46,17 @@ from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll +from jax._src.pallas.mosaic.random import sample_block as sample_block +from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key + +# Those primitives got moved to Pallas core. Keeping the updated imports +# here for backward compatibility. 
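+# For example (illustrative), user code that already calls
+#   pltpu.semaphore_signal(sem) / pltpu.semaphore_wait(sem)
+# keeps working unchanged; only the implementation now lives in
+# jax._src.pallas.primitives.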
+from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.primitives import device_id as device_id +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import semaphore_read as semaphore_read from jax._src.pallas.primitives import semaphore_signal as semaphore_signal from jax._src.pallas.primitives import semaphore_wait as semaphore_wait -from jax._src.pallas.mosaic.random import sample_block as sample_block -from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key import types from jax._src.pallas.mosaic.verification import assume From a9266a1521ade250f99114227f281b30235cef9c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 25 Mar 2025 09:09:53 -0700 Subject: [PATCH 0149/1769] [pallas:mosaic_gpu] `PallasCallTest` now runs all tests under both Lane and WG thread semantics PiperOrigin-RevId: 740371195 --- jax/_src/pallas/mosaic_gpu/lowering.py | 4 +- .../mosaic_gpu/pallas_call_registration.py | 10 +- tests/pallas/mosaic_gpu_test.py | 414 +++++++++--------- 3 files changed, 208 insertions(+), 220 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 677f63c6674a..493d8c07b941 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1404,7 +1404,7 @@ def convert(ty, x): mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( - lambda x: jnp.bitwise_xor(x, -1), multiple_results=False + lambda x: jnp.astype(jnp.bitwise_xor(jnp.astype(x, int), -1), jnp.dtype(x)), multiple_results=False, ), }) @@ -1821,7 +1821,7 @@ def _debug_print_lowering_rule( return () @register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) -def _debug_print_lowering_rule( +def _debug_print_lowering_rule_wg( ctx: LoweringRuleContext, *args, fmt, diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index d506349fe101..5399727878a6 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -27,7 +27,7 @@ from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering -import jax.experimental.mosaic.gpu.core as mosaic_core +from jax.experimental.mosaic import gpu as mgpu def pallas_call_lowering( @@ -57,10 +57,10 @@ def pallas_call_lowering( print(grid_mapping) thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mosaic_core.ThreadSemantics.Lane + "thread_semantics", mgpu.ThreadSemantics.Warpgroup ) - if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup: - mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error lowering_result = lowering.lower_pipelined_jaxpr_to_module( grid_mapping, @@ -77,7 +77,7 @@ def pallas_call_lowering( new_avals_out = [ jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs ] - outs = mosaic_core._mosaic_gpu_lowering_rule( + outs = mgpu.core._mosaic_gpu_lowering_rule( ctx.replace(avals_out=new_avals_out), *args, module=module, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 
b33857df40b6..b39288252e08 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -13,24 +13,27 @@ # limitations under the License. import contextlib +import dataclasses import functools import math import operator import os import re import tempfile +from typing import ClassVar from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax from jax._src import test_util as jtu -from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas import pallas_call +from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax.experimental import pallas as pl from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np + try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: @@ -55,7 +58,16 @@ def _sum_same_dtype(x): return jnp.sum(x, dtype=x.dtype) -class PallasTest(jtu.JaxTestCase): +class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): + + def __new__(mcs, *args, thread_semantics=plgpu.ThreadSemantics.Lane): + cls = super().__new__(mcs, *args) + cls.THREAD_SEMANTICS = thread_semantics + return cls + + +class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass): + THREAD_SEMANTICS: ClassVar[plgpu.ThreadSemantics] def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): @@ -66,6 +78,17 @@ def setUp(self): super().setUp() + def skip_if_wg_semantics(self): + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Warpgroup: + self.skipTest("Not supported under WG semantics") + + def pallas_call(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + thread_semantics=self.THREAD_SEMANTICS, + ) + return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) + @contextlib.contextmanager def capture_stdout(self): if mosaic_gpu_lib is None: @@ -104,17 +127,14 @@ class PallasCallTest(PallasTest): lax.log, ], approx_math=[True, False], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_unary_op(self, op, approx_math, thread_semantics): + def test_unary_op(self, op, approx_math): dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - approx_math=approx_math, thread_semantics=thread_semantics - ), + compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), ) def kernel(x_ref, o_ref): o_ref[...] = op(x_ref[...]) @@ -135,16 +155,10 @@ def kernel(x_ref, o_ref): jnp.maximum, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_binary_op(self, op, dtype, thread_semantics): - + def test_binary_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(x_ref, y_ref, o_ref): o_ref[...] = op(x_ref[...], y_ref[...]) @@ -165,16 +179,10 @@ def kernel(x_ref, y_ref, o_ref): ], # TODO(slebedev): Support integral types. 
dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_comparison_op(self, op, dtype, thread_semantics): - + def test_comparison_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(o_ref): o_ref[...] = jnp.broadcast_to( @@ -184,8 +192,9 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], op(42, 24), dtype)) def test_add_first(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, y_ref, o_ref): @@ -195,16 +204,10 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) - @parameterized.product( - shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_reduce_sum(self, shape, thread_semantics): + @parameterized.product(shape=[(128,), (128, 128)]) + def test_reduce_sum(self, shape): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) ) def kernel(x_ref, o_ref): o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape) @@ -213,11 +216,12 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), jnp.sum(x)) def test_reshape(self): + self.skip_if_wg_semantics() + shape1, shape2 = (128,), (2, 16, 4) @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32) ) def kernel(x_ref, out_ref): x_ref_reshaped = x_ref.reshape(shape2) @@ -228,14 +232,9 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_add_xy_indexed(self, thread_semantics): + def test_add_xy_indexed(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) ) def kernel(x_ref, y_ref, o_ref): idx = _sum_same_dtype(y_ref[...]) @@ -246,8 +245,9 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)]) def test_add_one_grid(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), @@ -260,9 +260,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_with_scratch(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), @@ -278,9 +277,8 @@ def kernel(x_ref, o_ref, scratch_ref): @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16]) def test_add_one_grid_pipelined(self, max_concurrent_steps): - @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, 
j))], out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), @@ -297,9 +295,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_pipelined_program_id(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32), compiler_params=plgpu.GPUCompilerParams( @@ -317,8 +314,9 @@ def kernel(o_ref): ) def test_add_one_grid_pipelined_sequential_invariant_output(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)), out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32), @@ -345,30 +343,29 @@ def kernel(x_ref, o_ref): @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) def test_iota(self, dtype): + self.skip_if_wg_semantics() + dimension = 1 + @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128, 128), dtype) ) def kernel(o_ref): - o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA) - - np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)) + o_ref[...] = plgpu.broadcasted_iota( + dtype, o_ref.shape, dimension, layout=plgpu.Layout.WGMMA + ) - @parameterized.product( - indexer=[..., slice(128), slice(None, 128)], - thread_semantics=[*plgpu.ThreadSemantics], - ) - def test_copy_smem_to_gmem(self, indexer, thread_semantics): + np.testing.assert_array_equal( + kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension) + ) + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_copy_smem_to_gmem(self, indexer): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM((256,), jnp.float32)], - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] 
+ 1 @@ -388,8 +385,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): "shape": (64, 64), "indexers": (4, slice(0, 64))}, ) def test_copy_smem_to_gmem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM(shape, jnp.float32)], @@ -413,8 +411,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_gmem_to_smem(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -458,13 +457,15 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): }, ) def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[plgpu.SMEM(shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], + scratch_shapes=[ + plgpu.SMEM(shape, jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], grid=(1,), ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -489,7 +490,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_gmem_to_smem_with_multiple_smem_indexers(self): x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32) @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -506,21 +507,31 @@ def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): np.testing.assert_array_equal(extract_x0(x), x[0]) def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): + self.skip_if_wg_semantics() + x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512) @functools.partial( - pl.pallas_call, + self.pallas_call, grid=(4, 4), out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32), - in_specs=(plgpu.GPUBlockSpec( - block_shape=(128, 128), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM, - transforms=(plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128))),), - out_specs=(plgpu.GPUBlockSpec( - block_shape=(64, 32), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM,)), + in_specs=( + plgpu.GPUBlockSpec( + block_shape=(128, 128), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + transforms=( + plgpu.TilingTransform((64, 32)), + plgpu.SwizzleTransform(128), + ), + ), + ), + out_specs=( + plgpu.GPUBlockSpec( + block_shape=(64, 32), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + ) + ), ) def kernel(x_ref, o_ref): x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32] # get x_ref[0:64, 32:64] @@ -532,8 +543,9 @@ def kernel(x_ref, o_ref): @parameterized.product(indexer=[0, 1, 2, 3]) def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -553,6 +565,8 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @parameterized.named_parameters(("_g2s", False), ("_s2g", True)) def test_copy_with_transforms(self, to_smem): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): if to_smem: plgpu.copy_gmem_to_smem(x_ref, 
o_ref, barrier_ref) @@ -574,7 +588,7 @@ def kernel(x_ref, o_ref, barrier_ref): ) if not to_smem: in_spec, out_spec = out_spec, in_spec - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), @@ -585,6 +599,8 @@ def kernel(x_ref, o_ref, barrier_ref): np.testing.assert_array_equal(f(x), x) def test_scoped_copy_with_transforms(self): + self.skip_if_wg_semantics() + ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): @@ -597,7 +613,7 @@ def body(tmp_ref): out_spec = plgpu.GPUBlockSpec( (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), @@ -608,6 +624,8 @@ def body(tmp_ref): np.testing.assert_array_equal(f(x), x * 2) def test_copy_with_transforms_and_indexing(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) @@ -624,7 +642,7 @@ def kernel(x_ref, o_ref, barrier_ref): ), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32), in_specs=(in_spec,), @@ -644,30 +662,33 @@ def kernel(x_ref, o_ref, barrier_ref): ], ) def test_load_to_layout_with_indexing(self, src_memory_space, layout): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.GPUBlockSpec( + (2, 128), + lambda: (0, 0), + memory_space=plgpu.SMEM, + ), + ) def kernel(x_ref, o_ref): for i in range(2): x = plgpu.load(x_ref, (i,), layout=layout) o_ref[i, ...] 
= x - in_spec = pl.BlockSpec(memory_space=src_memory_space) - out_spec = plgpu.GPUBlockSpec( - (2, 128), lambda: (0, 0), memory_space=plgpu.SMEM, - ) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), - in_specs=(in_spec,), - out_specs=out_spec, - ) x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) - np.testing.assert_array_equal(f(x), x) + np.testing.assert_array_equal(kernel(x), x) - @parameterized.product(src_memory_space=[plgpu.SMEM], - layout=[ - plgpu.Layout.WGMMA_ROW, - plgpu.Layout.WGMMA_COL, - ],) + @parameterized.product( + src_memory_space=[plgpu.SMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + ) def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): + self.skip_if_wg_semantics() + m, k, n = 64, 128, 192 key1, key2 = jax.random.split(jax.random.key(42), 2) if layout == plgpu.Layout.WGMMA_ROW: @@ -694,7 +715,7 @@ def compute(acc_ref): out_spec = plgpu.GPUBlockSpec( (m, n), lambda: (0, 0), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), in_specs=( @@ -705,8 +726,9 @@ def compute(acc_ref): transforms=( plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128), - ) - )), + ), + ), + ), out_specs=out_spec, ) @@ -716,6 +738,8 @@ def compute(acc_ref): np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) def test_indexing_before_transpose(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem( @@ -727,7 +751,7 @@ def kernel(x_ref, o_ref, barrier_ref): out_spec = plgpu.GPUBlockSpec( (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), in_specs=(in_spec,), @@ -739,8 +763,9 @@ def kernel(x_ref, o_ref, barrier_ref): np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0)) def test_copy_gmem_to_smem_in_run_scoped(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), ) @@ -757,8 +782,9 @@ def inner_body(scratch_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_doubled_sum(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), ) def kernel(x_ref, o_ref): @@ -767,26 +793,6 @@ def kernel(x_ref, o_ref): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) - @parameterized.named_parameters( - ("rsqrt", jax.lax.rsqrt, ), - ("log", jax.lax.log, 5e-7), - ("exp", jax.lax.exp, ), - ("exp2", jax.lax.exp2, 5e-7), - ("logistic", jax.lax.logistic, ), - ("tanh", jax.lax.tanh, 5e-7), - ) - def test_approx_math_unary_op(self, unary_op, rtol=1e-7): - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) - def kernel(x_ref, o_ref): - o_ref[...] 
= unary_op(x_ref[...]) - - x = jnp.arange(128).astype(jnp.float32) / 128 - np.testing.assert_allclose(kernel(x), unary_op(x), rtol=rtol, atol=1e-5) - @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): eps = 1e-5 @@ -794,7 +800,7 @@ def test_layer_norm(self, input_factor): beta = 1.0 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def layer_norm(x_ref, o_ref): @@ -822,8 +828,9 @@ def layer_norm_np(x): np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5) def test_print(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): @@ -836,16 +843,32 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): + self.skip_if_wg_semantics() + shape = (128, 64) size = math.prod(shape) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + in_specs=[ + plgpu.GPUBlockSpec( + shape, + lambda: (0, 0), + transforms=( + plgpu.TilingTransform((64, 32)), + plgpu.SwizzleTransform(128), + ), + ) + ], + ) def kernel(x_ref, o_ref): + del o_ref # Unused. pl.debug_print("prefix {}", x_ref[...]) - spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) - x = jnp.arange(size, dtype=jnp.float32).reshape(shape) - f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) + x = jnp.arange(size, dtype=jnp.float32).reshape(shape) with self.capture_stdout() as get_output: - jax.block_until_ready(f(x)) + jax.block_until_ready(kernel(x)) output = get_output() results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output) @@ -855,8 +878,10 @@ def kernel(x_ref, o_ref): self.assertEqual(v, i * shape[1] + j) def test_print_scalar(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -870,8 +895,10 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum()}", output()) def test_print_scalar_array(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -885,10 +912,12 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum() + 1}", output()) def test_print_array(self): + self.skip_if_wg_semantics() + in_shape = [2, 1, 64, 64] @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32), ) def kernel(x_ref, o_ref): @@ -913,9 +942,11 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)), jnp.full((128,), 10, dtype=jnp.int32)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_run_scoped(self, thread_semantics): - + def test_run_scoped(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) def kernel(x_ref, o_ref): def body(tmp_ref): self.assertEqual(tmp_ref.shape, (8, 128)) @@ -926,16 +957,8 @@ def body(tmp_ref): self.assertEqual(tmp.shape, (8, 128)) o_ref[...] 
= tmp - inp = np.ones((8, 128), jnp.float32) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) - o = f(inp) - np.testing.assert_array_equal(o, inp + 1.0) + x = np.ones((8, 128), jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) def test_program_id(self): @functools.partial( @@ -1031,14 +1054,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) np.testing.assert_array_equal(kernel(x), x) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_array(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_array(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. @@ -1047,14 +1066,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), x + 2 + 3) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_scalar(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_scalar(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): # Equivalent to 2 + 3. 
@@ -1066,7 +1081,6 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, jnp.int32)) def test_fori_loop_dynamic_bounds(self): - @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), @@ -1081,16 +1095,10 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_tuple(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_tuple(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): def body(step, xs): @@ -1109,16 +1117,11 @@ def body(step, xs): kernel(), jnp.full([256], 3 * (0 + 1), jnp.int32) ) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_indexed_store(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_indexed_store(self, force_while): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, y_ref, o_ref): def body(idx, _): @@ -1131,17 +1134,11 @@ def body(idx, _): y = x + 1 np.testing.assert_array_equal(kernel(x, y), x + y) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_while_loop(self, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("WG lowering does not support reduce_sum_p needed for this test") + def test_while_loop(self): + self.skip_if_wg_semantics() @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(x_ref, o_ref): o_ref[...] 
= jnp.zeros(o_ref.shape, dtype=jnp.int32) @@ -1182,12 +1179,9 @@ def body(acc): with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"): kernel() - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond(self, thread_semantics): + def test_cond(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): jax.lax.cond( @@ -1203,14 +1197,9 @@ def kernel(x_ref, o_ref): self.assertIn("acc % 2", output()) - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond_returning_array(self, thread_semantics): + def test_cond_returning_array(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): acc_sum = _sum_same_dtype(x_ref[...]) @@ -1341,20 +1330,13 @@ def kernel(x_ref, o_ref): (jnp.uint32, jnp.int32), (jnp.int32, jnp.uint32), ], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_bitcast_convert_type(self, dtypes, thread_semantics): + def test_bitcast_convert_type(self, dtypes): in_dtype, out_dtype = dtypes m, n = 16, 8 out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - @functools.partial( - pl.pallas_call, - out_shape=out_shape, - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) + @functools.partial(self.pallas_call, out_shape=out_shape) def convert(x_ref, y_ref): y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) @@ -1364,6 +1346,12 @@ def convert(x_ref, y_ref): ) +class PallasCallWGTest( + PallasCallTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) From 336852c57bcccef7cc22db0ac6b499f5c277e78b Mon Sep 17 00:00:00 2001 From: Seunghoon Park Date: Tue, 25 Mar 2025 09:11:35 -0700 Subject: [PATCH 0150/1769] Expose jax.lax.shape_as_value(). PiperOrigin-RevId: 740371651 --- jax/lax/__init__.py | 1 + tests/lax_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 4e376fb666d1..6f2163c424a6 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -198,6 +198,7 @@ select as select, select_n as select_n, select_n_p as select_n_p, + shape_as_value as shape_as_value, shift_left as shift_left, shift_left_p as shift_left_p, shift_right_arithmetic as shift_right_arithmetic, diff --git a/tests/lax_test.py b/tests/lax_test.py index 40f2eb8f3588..14b6c852c61f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -29,6 +29,7 @@ import jax from jax._src import core +from jax import export from jax import jvp, grad from jax import lax import jax.numpy as jnp @@ -3621,6 +3622,37 @@ def f(x): g = jax.grad(f)(5.) 
# doesn't crash self.assertAllClose(g, 3., check_dtypes=False) + def test_shape_as_value_handles_static_shapes(self): + result = lax.shape_as_value(()) + self.assertArraysEqual(result, lax.full((0,), np.array(0, np.int64))) + + result = lax.shape_as_value((2,)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + result = lax.shape_as_value((2, 3)) + self.assertArraysEqual(result, np.asarray((2, 3), np.int64)) + + def test_shape_as_value_handles_polymorphic_shapes(self): + @jax.jit + def f(x): + return lax.shape_as_value(x.shape) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a"), jnp.float32) + ) + result = exported.call(np.ones((1), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1,), np.int64)) + result = exported.call(np.ones((2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), jnp.float32) + ) + result = exported.call(np.ones((1, 2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1, 2), np.int64)) + result = exported.call(np.ones((3, 4), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((3, 4), np.int64)) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): From b088b3aef83d1254c82bc0d1780a305588868e59 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 25 Mar 2025 09:37:52 -0700 Subject: [PATCH 0151/1769] Fixed broken JAX distributed tests. PiperOrigin-RevId: 740379562 --- tests/BUILD | 11 +++++++ tests/distributed_initialize_test.py | 44 ++++++++++++++++++++++++++++ tests/distributed_test.py | 18 ------------ 3 files changed, 55 insertions(+), 18 deletions(-) create mode 100644 tests/distributed_initialize_test.py diff --git a/tests/BUILD b/tests/BUILD index 0a8fb9459044..d706f08b8092 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -137,9 +137,20 @@ jax_multiplatform_test( srcs = ["debug_nans_test.py"], ) +jax_py_test( + name = "distributed_initialize_test", + srcs = ["distributed_initialize_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("portpicker"), +) + jax_multiplatform_test( name = "distributed_test", srcs = ["distributed_test.py"], + enable_backends = ["gpu"], + deps = py_deps("portpicker"), ) jax_py_test( diff --git a/tests/distributed_initialize_test.py b/tests/distributed_initialize_test.py new file mode 100644 index 000000000000..33242a41a68e --- /dev/null +++ b/tests/distributed_initialize_test.py @@ -0,0 +1,44 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu + +try: + import portpicker +except ImportError: + portpicker = None + +jax.config.parse_flags_with_absl() + + +@unittest.skipIf(not portpicker, "Test requires portpicker") +class DistributedInitializeTest(jtu.JaxTestCase): + + @jtu.skip_under_pytest( + """Side effects from jax.distributed.initialize conflict with other tests + in the same process. pytest runs multiple tests in the same process.""" + ) + def test_is_distributed_initialized(self): + port = portpicker.pick_unused_port() # type: ignore + self.assertFalse(jax.distributed.is_initialized()) + jax.distributed.initialize(f"localhost:{port}", 1, 0) + self.assertTrue(jax.distributed.is_initialized()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/distributed_test.py b/tests/distributed_test.py index 3961932dfad0..5e47228c1719 100644 --- a/tests/distributed_test.py +++ b/tests/distributed_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import sys import threading import unittest @@ -67,22 +65,6 @@ def task(i): for thread in threads: thread.join() - def test_is_distributed_initialized(self): - # Run in subprocess to isolate side effects from jax.distributed.initialize which conflict with other - # tests. Unfortunately this can't be avoided by calling jax.distributed.shutdown, as the XLA backend - # will be warmed up, which yields a RuntimeError on subsequent calls to initialize. - port = portpicker.pick_unused_port() # type: ignore - cmd = f"""import jax; - assert not jax.distributed.is_initialized(); - jax.distributed.initialize('localhost:{port}', 1, 0); - assert jax.distributed.is_initialized(); - """.replace("\n", ' ') - - result = subprocess.run([sys.executable, "-c", cmd], capture_output=True) - self.assertEqual( - result.returncode, 0, msg=f"Test failed with:\n{result.stdout}\n{result.stderr}" - ) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From d8f38ff857726e4fb57bccb169cfcfdb1eb68656 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 11:42:24 -0700 Subject: [PATCH 0152/1769] [jaxlib:gpu] Clean up custom call GPU callback handling code. 
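
The legacy untyped custom-call path is removed: the "xla_python_gpu_callback"
target (implemented by XlaPythonGpuCallback and exported via `registrations()`)
is deleted from both the CUDA and ROCm plugin extensions, leaving only the
typed XLA FFI handlers exported via `ffi_registrations()`. A minimal sketch of
the registration path that remains, mirroring the updated
jax_plugins/cuda/__init__.py in the diff below (illustrative recap, not new
code in this change):

  # Sketch: FFI handlers are registered with api_version=1; the
  # api_version=0 loop over `registrations()` is what this patch deletes.
  for _name, _value in cuda_plugin_extension.ffi_registrations().items():
    xla_client.register_custom_call_target(
        _name, _value, platform='CUDA', api_version=1
    )
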
PiperOrigin-RevId: 740428623 --- jax_plugins/cuda/__init__.py | 2 - jax_plugins/rocm/__init__.py | 2 - jaxlib/cuda/BUILD | 4 - jaxlib/cuda/cuda_plugin_extension.cc | 7 -- jaxlib/gpu/py_client_gpu.cc | 118 --------------------------- jaxlib/gpu/py_client_gpu.h | 6 -- jaxlib/rocm/BUILD | 4 - jaxlib/rocm/rocm_plugin_extension.cc | 7 -- 8 files changed, 150 deletions(-) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 2b02621c89f5..13293de7181d 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -92,8 +92,6 @@ def initialize(): cuda_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in cuda_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") for _name, _value in cuda_plugin_extension.ffi_registrations().items(): xla_client.register_custom_call_target( _name, _value, platform='CUDA', api_version=1 diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 0699ae1e34a1..0b1b077acfcd 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -92,8 +92,6 @@ def initialize(): rocm_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in rocm_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") for _name, _value in rocm_plugin_extension.ffi_registrations().items(): xla_client.register_custom_call_target( _name, _value, platform='ROCM', api_version=1 diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 5cd7283ea3fc..48441632fba9 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -684,14 +684,10 @@ cc_library( "@xla//xla:shape_util", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", - "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", "@xla//xla/python:types", - "@xla//xla/service:custom_call_status", - "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", ], ) diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 63375921e3be..68230a332d95 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -42,12 +42,6 @@ static std::string ToString(CUresult result) { return absl::StrCat(error_name, ": ", error_string); } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(jax::cuda::XlaPythonGpuCallback); - return dict; -} nb::dict FfiRegistrations() { nb::dict dict; nb::dict gpu_callback_dict; @@ -63,7 +57,6 @@ nb::dict FfiRegistrations() { NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); - m.def("registrations", &Registrations); m.def("ffi_registrations", &FfiRegistrations); m.def( diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 71d327ffdb28..c39d5201f223 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -25,13 +25,11 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" -#include "absl/strings/numbers.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -39,15 +37,11 @@ limitations under the License. 
#include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" #include "xla/ffi/ffi_api.h" -#include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" -#include "xla/python/callback.h" #include "xla/python/nb_numpy.h" #include "xla/python/types.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/shape_util.h" namespace nb = nanobind; @@ -55,118 +49,6 @@ namespace nb = nanobind; namespace jax { namespace JAX_GPU_NAMESPACE { -void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status) { - // Ignore `descriptor` arg to callback - buffers += 1; - uint64_t descriptor; - if (!absl::SimpleAtoi(opaque, &descriptor)) { - throw xla::XlaRuntimeError("Invalid callback descriptor"); - return; - } - xla::CpuCallback* callback = - absl::bit_cast(static_cast(descriptor)); - size_t arity = callback->num_args(); - std::vector host_input_buffers(arity); - // Copy input GPU buffers to host - for (size_t i = 0; i < arity; ++i) { - const xla::CpuCallback::Arg& arg = callback->args()[i]; - if (arg.type == xla::TOKEN) { - host_input_buffers[i] = nullptr; - continue; - } - void* buf = new char[arg.size_in_bytes]; - host_input_buffers[i] = buf; - // TODO(b/238441608): Use pinned memory here to speed up the transfer. - auto gpu_res = gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes, - gpuMemcpyDeviceToHost, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } - CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) - << "Failed to gpuStreamSynchronize"; - nb::gil_scoped_acquire gil; - nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); - for (size_t i = 0; i < arity; ++i) { - xla::CpuCallback::Arg arg = callback->args()[i]; - if (arg.type == xla::TOKEN) { - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); - continue; - } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); - auto array = xla::nb_numpy_ndarray(arg.dtype, arg.dims, arg.strides, - const_cast(host_input_buffers[i]), - /*base=*/base); - array.attr("flags").attr("writeable") = nb::bool_(false); - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); - } - xla::EnterHostCallback(); - absl::StatusOr maybe_result_tuple = - callback->Call(host_input_arrays); - xla::LeaveHostCallback(); - if (!maybe_result_tuple.ok()) { - absl::string_view msg = maybe_result_tuple.status().message(); - XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); - return; - } - nb::tuple result_tuple = maybe_result_tuple.value(); - std::vector temp_buffers; - for (size_t i = 0; i < callback->results().size(); ++i) { - xla::CpuCallback::Result result = callback->results()[i]; - if (result.type == xla::TOKEN) { - continue; - } - nb::object output = - nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); - xla::nb_numpy_ndarray array = - xla::nb_numpy_ndarray::ensure(std::move(output)); - absl::Span dims( - reinterpret_cast(array.shape()), array.ndim()); - absl::Span strides( - reinterpret_cast(array.strides()), array.ndim()); - if (strides == result.expected_strides) { - auto gpu_res = - gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes, - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } else { - void* temp = new char[result.size_in_bytes]; - temp_buffers.push_back(temp); - 
-      xla::TransposePlan::Options options;
-      options.elem_size_in_bytes = xla::primitive_util::ByteWidth(result.type);
-      options.dims = dims;
-      options.permutation = result.reversed_layout;
-      options.input_layout = xla::TransposePlan::Striding{strides};
-      absl::StatusOr<std::shared_ptr<xla::TransposePlan>> plan =
-          callback->transpose_cache().GetOrCreate(options);
-      if (!plan.ok()) {
-        throw xla::XlaRuntimeError(plan.status().ToString());
-      }
-      plan.value()->Execute(array.data(), temp);
-      auto gpu_res =
-          gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes,
-                         gpuMemcpyHostToDevice, stream);
-      CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
-    }
-  }
-  nb::gil_scoped_release release;
-  CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess)
-      << "Failed to gpuStreamSynchronize";
-  for (int i = 0; i < temp_buffers.size(); ++i) {
-    delete[] static_cast<char*>(temp_buffers[i]);
-  }
-}
-
-// TODO(danfm): When compiled as part of a jaxlib plugin, this will register
-// the custom call target in the plugin's registry. This won't affect
-// registration via the Python API, but we should remove this once we have
-// fully migrated to the plugin interface.
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-    "xla_python_gpu_callback", &XlaPythonGpuCallback,
-    absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME));
-
 struct GpuTransposePlanCache {
   static xla::ffi::TypeId id;
   explicit GpuTransposePlanCache(int capacity) : cache(capacity) {}
diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h
index 06a955365c0b..8c5404570919 100644
--- a/jaxlib/gpu/py_client_gpu.h
+++ b/jaxlib/gpu/py_client_gpu.h
@@ -20,15 +20,9 @@ limitations under the License.
 
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/ffi.h"
-#include "xla/service/custom_call_status.h"
 
 namespace jax {
 namespace JAX_GPU_NAMESPACE {
-
-void XlaPythonGpuCallback(gpuStream_t stream, void** buffers,
-                          const char* opaque, size_t opaque_len,
-                          XlaCustomCallStatus* status);
-
 XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback);
 }  // namespace JAX_GPU_NAMESPACE
diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD
index 522aa8da0145..258556be8b1e 100644
--- a/jaxlib/rocm/BUILD
+++ b/jaxlib/rocm/BUILD
@@ -582,14 +582,10 @@ cc_library(
         "@xla//xla:shape_util",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",
-        "@xla//xla/pjrt:exceptions",
        "@xla//xla/pjrt:host_callback",
        "@xla//xla/pjrt:transpose",
-        "@xla//xla/python:callback",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/python:types",
-        "@xla//xla/service:custom_call_status",
-        "@xla//xla/service:custom_call_target_registry",
        "@xla//xla/service:platform_util",
    ],
 )
diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc
index 642467a9afef..2ba5d98ae668 100644
--- a/jaxlib/rocm/rocm_plugin_extension.cc
+++ b/jaxlib/rocm/rocm_plugin_extension.cc
@@ -66,12 +66,6 @@ std::string ToString(hipError_t result) {
   }
 }
 
-nb::dict Registrations() {
-  nb::dict dict;
-  dict["xla_python_gpu_callback"] =
-      jax::EncapsulateFunction(jax::hip::XlaPythonGpuCallback);
-  return dict;
-}
 nb::dict FfiRegistrations() {
   nb::dict dict;
   nb::dict gpu_callback_dict;
@@ -87,7 +81,6 @@ nb::dict FfiRegistrations() {
 
 NB_MODULE(rocm_plugin_extension, m) {
   BuildGpuPluginExtension(m);
-  m.def("registrations", &Registrations);
   m.def("ffi_registrations", &FfiRegistrations);
 
   m.def(

From c8ccd7570aa50dd67c80350f29477c4a44992897 Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Tue, 25 Mar 2025 11:47:35 -0700
Subject: [PATCH 0153/1769] Add functionality that lets us do a "jax"
only release Introduces a new `download-jax-only-from-gcs` variable to the workflow configs. When set to 1, the test workflows will only download and install the `jax` wheel. Other artifacts such as the latest releases of `jaxlib` and the CUDA plugin dependencies will be downloaded and installed from PyPI. PiperOrigin-RevId: 740430538 --- .github/workflows/pytest_cpu.yml | 19 ++++++++++++-- .github/workflows/pytest_cuda.yml | 26 +++++++++++++++---- .github/workflows/pytest_tpu.yml | 11 +++++++- .../workflows/wheel_tests_nightly_release.yml | 8 ++++++ ci/utilities/install_wheels_locally.sh | 5 ++++ 5 files changed, 61 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 137f49c6d8c7..c952ef9ee1a6 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -29,6 +29,11 @@ on: type: string required: true default: "0" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -92,7 +97,12 @@ jobs: run: | mkdir -p $(pwd)/dist gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Download wheels from GCS (Windows runs) id: download-wheel-artifacts-w # Set continue-on-error to true to prevent actions from failing the workflow if this step @@ -106,7 +116,12 @@ jobs: @REM Use `call` so that we can run sequential gsutil commands on Windows @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ + + if "${{ inputs.download-jax-only-from-gcs }}"=="1" ( + echo "JAX only release. Only downloading the jax wheel from the release bucket." + ) else ( + call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ + ) - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index ae74da53edcb..b3d1b15a0052 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -34,6 +34,11 @@ on: type: string required: true default: "0" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -88,11 +93,22 @@ jobs: # informative error message. 
continue-on-error: true run: | - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + mkdir -p $(pwd)/dist + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + + # Do not download the jaxlib and CUDA plugin artifacts if we are testing a jax only + # release. + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + + # Set the env var to install the CUDA plugin and PJRT packages from PyPI. jaxlib is + # required dependency of jax so that gets installed automatically. + echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=jax_cuda_pypi">> $GITHUB_ENV + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index a105a2feb347..0b56635a8aac 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -54,6 +54,11 @@ on: # - "pypi_latest": Use the latest libtpu wheel from PyPI. # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. default: "nightly" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -110,7 +115,11 @@ jobs: run: | mkdir -p $(pwd)/dist gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." 
+ else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index adb678be9d9d..9cd48c925cf3 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -17,6 +17,11 @@ on: required: true default: 'gs://jax-nightly-release-transient/nightly/latest' type: string + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -41,6 +46,7 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-cuda: @@ -60,6 +66,7 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-tpu: @@ -98,4 +105,5 @@ jobs: python: ${{ matrix.python }} run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 64f88765bb75..53f070d1e0e6 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -26,6 +26,11 @@ for i in "${!WHEELS[@]}"; do # Append [tpu] to the jax wheel name to download the latest libtpu wheel # from PyPI. WHEELS[$i]="${WHEELS[$i]}[tpu]" + elif [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "jax_cuda_pypi" ]]; then + # Append [cuda12-local] to the jax wheel name to download the latest + # release of JAX's CUDA plugin and PJRT packages from PyPI. This is used + # when running CUDA tests for a "jax" only release. + WHEELS[$i]="${WHEELS[$i]}[cuda12-local]" fi fi done From bda37e322289c4a903e2aa490d4f376a2d370fcf Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 25 Mar 2025 13:28:12 -0700 Subject: [PATCH 0154/1769] Increased sharding for `lax_scipy_spectral_dac_test_cpu_shardy`. PiperOrigin-RevId: 740464973 --- tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index d706f08b8092..2e03f331744c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -607,7 +607,7 @@ jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { - "cpu": 10, + "cpu": 20, "gpu": 10, "tpu": 10, }, From 8c44b277bebf1cf801e8cc3d91890ac0f270a880 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 25 Mar 2025 13:33:13 -0700 Subject: [PATCH 0155/1769] [Mosaic GPU] Add warpgroup lowering for `BarrierArrive` in Pallas. 
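
With this rule registered, `plgpu.barrier_arrive` lowers under both
`plgpu.ThreadSemantics.Lane` and `plgpu.ThreadSemantics.Warpgroup`. A minimal
sketch of a kernel exercising the new path (illustrative only: the shapes and
the arrive/wait pairing are assumptions, not part of this change):

  import functools
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl
  from jax.experimental.pallas import mosaic_gpu as plgpu

  @functools.partial(
      pl.pallas_call,
      out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
      scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
      compiler_params=plgpu.GPUCompilerParams(
          thread_semantics=plgpu.ThreadSemantics.Warpgroup
      ),
  )
  def kernel(x_ref, o_ref, barrier_ref):
    # Arrive on the barrier, then wait for the current phase to complete.
    plgpu.barrier_arrive(barrier_ref)
    plgpu.barrier_wait(barrier_ref)
    o_ref[...] = x_ref[...] + 1.0
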
PiperOrigin-RevId: 740466565
---
 jax/_src/pallas/mosaic_gpu/primitives.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index fe28766cfb96..8eafa0ac8e6d 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -489,6 +489,7 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params):
 
 
 @lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane)
+@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Warpgroup)
 def _barrier_arrive_lowering(
     ctx: lowering.LoweringRuleContext,
     barrier,

From 85150471e283cd9b5f167aa2345a1796ec1ae0d5 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 25 Mar 2025 13:45:54 -0700
Subject: [PATCH 0156/1769] Support __jax_array__ in jnp.full_like & co

---
 jax/_src/numpy/array_creation.py  | 16 +++++++++---
 tests/array_extensibility_test.py | 41 +++++++++++++++++++++++++++----
 2 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py
index 67418e7322c9..a0495986fcd1 100644
--- a/jax/_src/numpy/array_creation.py
+++ b/jax/_src/numpy/array_creation.py
@@ -244,6 +244,8 @@ def zeros_like(a: ArrayLike | DuckTypedArray,
        [0, 0, 0]], dtype=int32)
   """
   if not (hasattr(a, 'dtype') and hasattr(a, 'shape')):  # support duck typing
+    if hasattr(a, '__jax_array__'):
+      a = a.__jax_array__()
     util.check_arraylike("zeros_like", a)
   dtypes.check_user_dtype_supported(dtype, "zeros_like")
   if shape is not None:
@@ -287,6 +289,8 @@ def ones_like(a: ArrayLike | DuckTypedArray,
        [1, 1, 1]], dtype=int32)
   """
   if not (hasattr(a, 'dtype') and hasattr(a, 'shape')):  # support duck typing
+    if hasattr(a, '__jax_array__'):
+      a = a.__jax_array__()
     util.check_arraylike("ones_like", a)
   dtypes.check_user_dtype_supported(dtype, "ones_like")
   if shape is not None:
@@ -332,9 +336,13 @@ def empty_like(prototype: ArrayLike | DuckTypedArray,
        [0, 0, 0]], dtype=int32)
   """
   if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')):  # support duck typing
-    util.check_arraylike("empty_like", prototype)
-    dtypes.check_user_dtype_supported(dtype, "empty_like")
-    return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
+    if hasattr(prototype, '__jax_array__'):
+      prototype = prototype.__jax_array__()
+    util.check_arraylike("empty_like", prototype)
+    dtypes.check_user_dtype_supported(dtype, "empty_like")
+  if shape is not None:
+    shape = canonicalize_shape(shape)
+  return lax.full_like(prototype, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device))
@@ -382,6 +390,8 @@ def full_like(a: ArrayLike | DuckTypedArray,
     util.check_arraylike("full_like", 0, fill_value)
   else:
     util.check_arraylike("full_like", a, fill_value)
+  if hasattr(a, '__jax_array__'):
+    a = a.__jax_array__()
   dtypes.check_user_dtype_supported(dtype, "full_like")
   if shape is not None:
     shape = canonicalize_shape(shape)

diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py
index 45c83f7967ce..3e84f6668b8d 100644
--- a/tests/array_extensibility_test.py
+++ b/tests/array_extensibility_test.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools +from typing import Any, Callable, NamedTuple from absl.testing import absltest from absl.testing import parameterized -from typing import Any, Callable, NamedTuple +import numpy as np import jax import jax.numpy as jnp @@ -38,6 +40,15 @@ def __jax_array__(self) -> jax.Array: return jnp.asarray(self.x) +class DuckTypedArrayWithErroringJaxArray: + """Duck-typed array that provides a __jax_array__ method which fails.""" + shape = (2, 3) + dtype = np.dtype('float32') + + def __jax_array__(self): + raise ValueError("jax array was called.") + + class NumPyAPI(NamedTuple): fun: Callable[..., Any] args: list[jax.ShapeDtypeStruct] @@ -287,7 +298,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), # NumPyAPI.sig(jnp.dstack, Float[3, 5]), NumPyAPI.sig(jnp.ediff1d, Float[5]), - # NumPyAPI.sig(jnp.empty_like, Float[5]), + NumPyAPI.sig(jnp.empty_like, Float[5]), NumPyAPI.sig(jnp.equal, Float[5], Float[5]), NumPyAPI.sig(jnp.exp, Float[5]), NumPyAPI.sig(jnp.exp2, Float[5]), @@ -312,7 +323,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.fmin, Float[5], Float[5]), NumPyAPI.sig(jnp.fmod, Float[5], Float[5]), NumPyAPI.sig(jnp.frexp, Float[5]), - # NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), + NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), NumPyAPI.sig(jnp.gcd, Int[5], Int[5]), NumPyAPI.sig(jnp.greater, Float[5], Float[5]), NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), @@ -393,7 +404,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), NumPyAPI.sig(jnp.nonzero, Float[5]), NumPyAPI.sig(jnp.not_equal, Float[5], Float[5]), - # NumPyAPI.sig(jnp.ones_like, Float[5]), + NumPyAPI.sig(jnp.ones_like, Float[5]), NumPyAPI.sig(jnp.outer, Float[5], Float[5]), NumPyAPI.sig(jnp.packbits, Int[5]), # NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), @@ -493,7 +504,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.vsplit, Float[6], indices_or_sections=2), NumPyAPI.sig(jnp.vstack, [Float[5], Float[2, 5]]), NumPyAPI.sig(jnp.where, Bool[5], Float[5], Float[5]), - # NumPyAPI.sig(jnp.zeros_like, Float[5]), + NumPyAPI.sig(jnp.zeros_like, Float[5]), ] @@ -511,6 +522,26 @@ def test_numpy_api_supports_jax_array(self, api): self.assertAllClose(wrapped, expected, atol=0, rtol=0) + @parameterized.named_parameters( + {'testcase_name': func.__name__, 'func': func} + for func in [jnp.zeros_like, jnp.ones_like, jnp.empty_like, jnp.full_like] + ) + def test_array_creation_from_duck_typed_array(self, func): + # Ensure that jnp.*_like prefers shape/dtype over __jax_array__ when + # both methods are available. + if func is jnp.full_like: + func = functools.partial(func, fill_value=2.0) + obj = DuckTypedArrayWithErroringJaxArray() + + # The test relies on this failing + with self.assertRaises(ValueError): + jnp.asarray(obj) + + result = func(obj) + self.assertIsInstance(result, jax.Array) + self.assertEqual(result.shape, obj.shape) + self.assertEqual(result.dtype, obj.dtype) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 679ea6370b54818cb1ec3449924addc02a413a0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 25 Mar 2025 14:07:30 -0700 Subject: [PATCH 0157/1769] [JAX] [XLA:Python] Migrate py_client to JAX. 
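
This is a mechanical move of the py_client sources (py_array, py_client,
py_device, py_device_list, py_executable, py_memory_space, py_program,
py_values, sharding, and friends) out of @xla//xla/python and into jaxlib/xla.
Bazel deps switch from "@xla//xla/python:py_client" to the local ":py_client"
target, and includes switch accordingly, e.g. #include "xla/python/py_client.h"
becomes #include "jaxlib/xla/py_client.h".
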
PiperOrigin-RevId: 740478728 --- jaxlib/xla/BUILD | 190 ++- jaxlib/xla/dlpack.cc | 4 +- jaxlib/xla/dlpack.h | 2 +- jaxlib/xla/ifrt_proxy.cc | 162 ++ jaxlib/xla/ifrt_proxy.h | 31 + jaxlib/xla/jax_jit.cc | 4 +- jaxlib/xla/jax_jit.h | 4 +- jaxlib/xla/pjit.cc | 8 +- jaxlib/xla/pmap_lib.cc | 16 +- jaxlib/xla/py_array.cc | 2063 ++++++++++++++++++++++++++ jaxlib/xla/py_array.h | 360 +++++ jaxlib/xla/py_client.cc | 851 +++++++++++ jaxlib/xla/py_client.h | 270 ++++ jaxlib/xla/py_compile_only_client.cc | 131 ++ jaxlib/xla/py_compile_only_client.h | 45 + jaxlib/xla/py_device.cc | 350 +++++ jaxlib/xla/py_device.h | 82 + jaxlib/xla/py_device_list.cc | 472 ++++++ jaxlib/xla/py_device_list.h | 137 ++ jaxlib/xla/py_executable.cc | 463 ++++++ jaxlib/xla/py_executable.h | 263 ++++ jaxlib/xla/py_memory_space.cc | 102 ++ jaxlib/xla/py_memory_space.h | 64 + jaxlib/xla/py_program.cc | 291 ++++ jaxlib/xla/py_program.h | 27 + jaxlib/xla/py_socket_transfer.cc | 6 +- jaxlib/xla/py_values.cc | 745 ++++++++++ jaxlib/xla/py_values.h | 127 ++ jaxlib/xla/sharded_device_array.h | 217 +++ jaxlib/xla/sharding.cc | 346 +++++ jaxlib/xla/sharding.h | 242 +++ jaxlib/xla/to_ifrt_sharding.cc | 141 ++ jaxlib/xla/to_ifrt_sharding.h | 56 + jaxlib/xla/xla.cc | 20 +- jaxlib/xla/xla_compiler.cc | 2 +- 35 files changed, 8253 insertions(+), 41 deletions(-) create mode 100644 jaxlib/xla/ifrt_proxy.cc create mode 100644 jaxlib/xla/ifrt_proxy.h create mode 100644 jaxlib/xla/py_array.cc create mode 100644 jaxlib/xla/py_array.h create mode 100644 jaxlib/xla/py_client.cc create mode 100644 jaxlib/xla/py_client.h create mode 100644 jaxlib/xla/py_compile_only_client.cc create mode 100644 jaxlib/xla/py_compile_only_client.h create mode 100644 jaxlib/xla/py_device.cc create mode 100644 jaxlib/xla/py_device.h create mode 100644 jaxlib/xla/py_device_list.cc create mode 100644 jaxlib/xla/py_device_list.h create mode 100644 jaxlib/xla/py_executable.cc create mode 100644 jaxlib/xla/py_executable.h create mode 100644 jaxlib/xla/py_memory_space.cc create mode 100644 jaxlib/xla/py_memory_space.h create mode 100644 jaxlib/xla/py_program.cc create mode 100644 jaxlib/xla/py_program.h create mode 100644 jaxlib/xla/py_values.cc create mode 100644 jaxlib/xla/py_values.h create mode 100644 jaxlib/xla/sharded_device_array.h create mode 100644 jaxlib/xla/sharding.cc create mode 100644 jaxlib/xla/sharding.h create mode 100644 jaxlib/xla/to_ifrt_sharding.cc create mode 100644 jaxlib/xla/to_ifrt_sharding.h diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index e562cb7e84ea..979e659a309f 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -49,10 +49,12 @@ nanobind_extension( ":config", ":custom_call_sharding", ":dlpack", + ":ifrt_proxy", ":jax_jit", ":mlir", ":pjit", ":pmap_lib", + ":py_client", ":pytree", ":sdy", ":weakref_lru_cache", @@ -105,7 +107,6 @@ nanobind_extension( "@xla//xla/python:ops", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:profiler", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:traceback", @@ -114,7 +115,6 @@ nanobind_extension( "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", - "@xla//xla/python/ifrt_proxy/client:py_module", "@xla//xla/python/pjrt_ifrt", "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "@xla//xla/python/pjrt_ifrt:xla_ifrt", @@ -211,6 +211,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":py_client", 
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -229,7 +230,6 @@ cc_library( "@xla//xla/pjrt:pjrt_compiler", "@xla//xla/pjrt:pjrt_layout", "@xla//xla/python:nb_class_ptr", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", @@ -242,6 +242,37 @@ cc_library( ], ) +cc_library( + name = "ifrt_proxy", + srcs = ["ifrt_proxy.cc"], + hdrs = ["ifrt_proxy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":py_client", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@nanobind", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:statusor", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt_proxy/client:grpc_client", + "@xla//xla/python/ifrt_proxy/client:registry", + ], +) + cc_library( name = "jax_jit", srcs = ["jax_jit.cc"], @@ -253,6 +284,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":py_client", ":pytree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -272,7 +304,6 @@ cc_library( "@xla//xla/python:nb_absl_inlined_vector", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", @@ -331,6 +362,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":py_client", ":pytree", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", @@ -354,7 +386,6 @@ cc_library( "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python/ifrt", @@ -379,6 +410,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":py_client", ":pytree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -401,7 +433,6 @@ cc_library( "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", @@ -414,6 +445,149 @@ cc_library( ], ) +cc_library( + name = "py_client", + srcs = [ + "py_array.cc", + "py_client.cc", + "py_compile_only_client.cc", + "py_device.cc", + "py_device_list.cc", + "py_executable.cc", + "py_memory_space.cc", + "py_program.cc", + "py_values.cc", + "sharding.cc", + "to_ifrt_sharding.cc", + ], + hdrs = [ + "py_array.h", + "py_client.h", + "py_compile_only_client.h", + "py_device.h", + "py_device_list.h", + "py_executable.h", + "py_memory_space.h", + "py_program.h", + "py_values.h", + "sharded_device_array.h", + "sharding.h", + "to_ifrt_sharding.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/py_client"), + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/cleanup", + 
"@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/profiler/lib:profiler_session", + "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/builder/lib:arithmetic", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:host_memory_spaces", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_device_description", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt:transpose", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/python:aggregate_profile", + "@xla//xla/python:callback", + "@xla//xla/python:guard_lib", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:py_client_cpu", + "@xla//xla/python:py_host_callback", + "@xla//xla/python:py_host_callback_proto_cc", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python:util", + "@xla//xla/python:xplane_to_profile_instructions", + "@xla//xla/python/compile_only_ifrt:client", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:custom_call_program", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/python/ifrt/hlo:hlo_program", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/service:computation_placer_hdr", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:platform_util", + "@xla//xla/service/spmd/shardy:constants", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "@xla//xla/tsl/concurrency:ref_count", + 
"@xla//xla/tsl/framework:allocator", + "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + cc_library( name = "py_socket_transfer", srcs = ["py_socket_transfer.cc"], @@ -424,6 +598,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":py_client", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -435,7 +610,6 @@ cc_library( "@xla//xla/pjrt:status_casters", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", @@ -558,6 +732,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":dlpack", + ":py_client", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", @@ -594,7 +769,6 @@ cc_library( "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:types", "@xla//xla/service:call_inliner", "@xla//xla/service:computation_placer", diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index f6605a36f02b..8b29e136f296 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -34,6 +34,8 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" @@ -46,8 +48,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/py_array.h" -#include "xla/python/py_client.h" #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/types.h" diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index d0079b1d4914..e73c477b1495 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -22,9 +22,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" -#include "xla/python/py_client.h" namespace xla { diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/xla/ifrt_proxy.cc new file mode 100644 index 000000000000..e03fde194d49 --- /dev/null +++ b/jaxlib/xla/ifrt_proxy.cc @@ -0,0 +1,162 @@ +// Copyright 2023 The JAX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "jaxlib/xla/ifrt_proxy.h"
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <utility>
+#include <variant>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/log/log_entry.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/function.h"  // IWYU pragma: keep
+#include "nanobind/stl/optional.h"  // IWYU pragma: keep
+#include "nanobind/stl/string.h"  // IWYU pragma: keep
+#include "nanobind/stl/unordered_map.h"  // IWYU pragma: keep
+#include "nanobind/stl/variant.h"  // IWYU pragma: keep
+#include "jaxlib/xla/py_client.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/python/ifrt/attribute_map.h"
+#include "xla/python/ifrt/client.h"
+#include "xla/python/ifrt_proxy/client/registry.h"
+#include "xla/python/nb_class_ptr.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/statusor.h"
+
+namespace nb = ::nanobind;
+
+namespace xla {
+namespace ifrt {
+namespace proxy {
+namespace {
+
+struct PyClientConnectionOptions {
+  std::optional<std::function<void(std::string)>> on_disconnect;
+  std::optional<std::function<void(std::string)>> on_connection_update;
+  std::optional<int64_t> connection_timeout_in_seconds;
+  std::optional<
+      std::unordered_map<std::string, std::variant<nb::bytes, bool, int64_t>>>
+      initialization_data;
+};
+
+absl::StatusOr<nb_class_ptr<PyClient>> GetClient(
+    std::string proxy_server_address,
+    const PyClientConnectionOptions& py_options) {
+  DCHECK(PyGILState_Check());
+  std::unique_ptr<xla::ifrt::Client> client;
+
+  ClientConnectionOptions options;
+  if (py_options.on_disconnect) {
+    // While it is possible to pass around `py_options.on_disconnect` without
+    // wrapping it via a shared_ptr, copying the `py_options.on_disconnect`
+    // object can internally attempt to acquire the GIL [1], and can thus block
+    // or even deadlock. A unique_ptr or `absl::AnyInvocable` is not sufficient
+    // because downstream code can make copies. Reference:
+    // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors
+    auto py_on_disconnect = std::make_shared<std::function<void(std::string)>>(
+        std::move(*py_options.on_disconnect));
+
+    options.on_disconnect =
+        [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable {
+          LOG(WARNING) << "Connection to server failed, calling supplied "
+                       << "`on_disconnect` function: " << s;
+          tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable {
+            nb::gil_scoped_acquire gil_acquire;
+            (*on_disconnect)(s.ToString());
+            on_disconnect = nullptr;
+          });
+        };
+  }
+
+  if (py_options.on_connection_update) {
+    auto fn = std::make_shared<std::function<void(std::string)>>(
+        std::move(*py_options.on_connection_update));
+    options.on_connection_update = [fn](absl::string_view log_line) -> void {
+      tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] {
+        nb::gil_scoped_acquire gil_acquire;
+        (*fn)(std::string(str));
+      });
+    };
+  }
+
+  if (py_options.connection_timeout_in_seconds.has_value()) {
+    options.connection_timeout =
+        absl::Seconds(*py_options.connection_timeout_in_seconds);
+  }
+
+  if (py_options.initialization_data.has_value()) {
+    AttributeMap::Map attribute_map;
+    for (const auto& [key, py_value] : *py_options.initialization_data) {
+      if (std::holds_alternative<nb::bytes>(py_value)) {
+        nb::bytes value = std::get<nb::bytes>(py_value);
+        attribute_map.insert({key, AttributeMap::StringValue(std::string(
+                                       value.c_str(), value.size()))});
+      } else if (std::holds_alternative<bool>(py_value)) {
+        attribute_map.insert(
+            {key, AttributeMap::BoolValue(std::get<bool>(py_value))});
+      } else {
+        CHECK(std::holds_alternative<int64_t>(py_value));
+        attribute_map.insert(
+            {key, AttributeMap::Int64Value(std::get<int64_t>(py_value))});
+      }
+    }
+    options.initialization_data = AttributeMap(std::move(attribute_map));
+  }
+
+  {
+    nb::gil_scoped_release gil_release;
+    TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options));
+  }
+
+  // Constructing `xla::PyClient` requires GIL as it may dec-ref Python objects.
+  return xla::PyClient::Make(std::move(client));
+}
+
+}  // namespace
+
+void BuildIfrtProxySubmodule(nb::module_& m) {
+  nb::module_ sub_module = m.def_submodule("ifrt_proxy", "IFRT proxy");
+
+  nb::class_<PyClientConnectionOptions>(sub_module, "ClientConnectionOptions")
+      .def(nb::init<>())
+      .def_rw("on_disconnect", &PyClientConnectionOptions::on_disconnect,
+              nb::arg().none())
+      .def_rw("on_connection_update",
+              &PyClientConnectionOptions::on_connection_update,
+              nb::arg().none())
+      .def_rw("connection_timeout_in_seconds",
+              &PyClientConnectionOptions::connection_timeout_in_seconds,
+              nb::arg().none())
+      .def_rw("initialization_data",
+              &PyClientConnectionOptions::initialization_data,
+              nb::arg().none());
+
+  sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient),
+                 nb::arg("proxy_server_address"), nb::arg("options"));
+}
+
+}  // namespace proxy
+}  // namespace ifrt
+}  // namespace xla
diff --git a/jaxlib/xla/ifrt_proxy.h b/jaxlib/xla/ifrt_proxy.h
new file mode 100644
index 000000000000..a8fcb9e676ff
--- /dev/null
+++ b/jaxlib/xla/ifrt_proxy.h
@@ -0,0 +1,31 @@
+/* Copyright 2024 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#define JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ + +#include "nanobind/nanobind.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +void BuildIfrtProxySubmodule(nanobind::module_& m); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/xla/jax_jit.cc index 23abe9a8404a..4645c59c7147 100644 --- a/jaxlib/xla/jax_jit.cc +++ b/jaxlib/xla/jax_jit.cc @@ -53,14 +53,14 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/py_values.h" -#include "xla/python/sharding.h" #include "xla/python/types.h" #include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index 9e6f8e34f1e9..e2c186c5d3ff 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -35,12 +35,12 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/nb_helpers.h" -#include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharding.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index a13d3f0b52e3..6681f72b7b49 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -52,7 +52,11 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" @@ -66,11 +70,7 @@ limitations under the License. #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_array.h" -#include "xla/python/py_executable.h" -#include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharding.h" #include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/env.h" diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc index c6849f8c25fd..3dbd736076db 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/xla/pmap_lib.cc @@ -46,7 +46,15 @@ limitations under the License. 
#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharded_device_array.h" +#include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/to_ifrt_sharding.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -59,15 +67,7 @@ limitations under the License. #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_array.h" -#include "xla/python/py_client.h" -#include "xla/python/py_device.h" -#include "xla/python/py_executable.h" -#include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharded_device_array.h" -#include "xla/python/sharding.h" -#include "xla/python/to_ifrt_sharding.h" #include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc new file mode 100644 index 000000000000..305582a987f7 --- /dev/null +++ b/jaxlib/xla/py_array.cc @@ -0,0 +1,2063 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "jaxlib/xla/py_array.h"
+
+#include <Python.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <new>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <thread>  // NOLINT
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/base/casts.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/Support/Casting.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/optional.h"  // IWYU pragma: keep
+#include "nanobind/stl/pair.h"  // IWYU pragma: keep
+#include "nanobind/stl/shared_ptr.h"  // IWYU pragma: keep
+#include "nanobind/stl/string.h"  // IWYU pragma: keep
+#include "nanobind/stl/string_view.h"  // IWYU pragma: keep
+#include "nanobind/stl/unique_ptr.h"  // IWYU pragma: keep
+#include "nanobind/stl/vector.h"  // IWYU pragma: keep
+#include "jaxlib/xla/py_client.h"
+#include "jaxlib/xla/py_device.h"
+#include "jaxlib/xla/py_device_list.h"
+#include "jaxlib/xla/py_values.h"
+#include "jaxlib/xla/sharding.h"
+#include "jaxlib/xla/to_ifrt_sharding.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/pjrt/exceptions.h"
+#include "xla/pjrt/lru_cache.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_compiler.h"
+#include "xla/pjrt/pjrt_future.h"
+#include "xla/pjrt/pjrt_layout.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/primitive_util.h"
+#include "xla/python/guard_lib.h"
+#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/array_spec.h"
+#include "xla/python/ifrt/device.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/ifrt/dtype.h"
+#include "xla/python/ifrt/future.h"
+#include "xla/python/ifrt/memory.h"
+#include "xla/python/ifrt/remap_plan.h"
+#include "xla/python/ifrt/shape.h"
+#include "xla/python/ifrt/sharding.h"
+#include "xla/python/nb_absl_span.h"  // IWYU pragma: keep
+#include "xla/python/nb_class_ptr.h"
+#include "xla/python/nb_helpers.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/python/pjrt_ifrt/pjrt_array.h"
+#include "xla/python/pjrt_ifrt/pjrt_client.h"
+#include "xla/python/pjrt_ifrt/pjrt_device.h"
+#include "xla/python/pjrt_ifrt/pjrt_dtype.h"
+#include "xla/python/python_ref_manager.h"
+#include "xla/python/traceback.h"
+#include "xla/python/types.h"
+#include "xla/python/util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/python/lib/core/numpy.h"  // IWYU pragma: keep
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+namespace nb = nanobind;
+
+PjRtBuffer* GetPjrtBuffer(ifrt::Array* ifrt_array) {
+  auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
+  if (arr == nullptr) {
+    throw XlaRuntimeError(
+        "This operation is implemented for a PjRt-compatible backend only.");
+  }
+  return arr->pjrt_buffers().front().get();
+}
+
+absl::StatusOr<const Shape*> XlaDynamicShape(ifrt::Array* ifrt_array,
+                                             std::optional<Shape>& scratch) {
+  auto* pjrt_buffer = GetPjrtBuffer(ifrt_array);
+
+  if (!scratch) {
+    absl::Span<const int64_t> dims;
+    std::optional<std::vector<int64_t>> logical_dims_storage;
+    if (pjrt_buffer->has_dynamic_dimensions()) {
+      {
+        nb::gil_scoped_release gil_release;
+        TF_ASSIGN_OR_RETURN(std::vector<int64_t> logical_dims,
+                            pjrt_buffer->logical_dimensions());
+        logical_dims_storage.emplace(std::move(logical_dims));
+      }
+      dims = *logical_dims_storage;
+    } else {
+      dims = pjrt_buffer->dimensions();
+    }
+    Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims);
+    // TODO(b/327524065): fix this
+    *shape.mutable_layout() = pjrt_buffer->layout()->xla_layout();
+    scratch = std::move(shape);
+  }
+  return &scratch.value();
+}
+
+tsl::RCReference<ifrt::Array> CreateIfRtArrayFromSingleDeviceShardedPyArrays(
+    nb_dtype dtype, absl::Span<const int64_t> shape,
+    absl::Span<const PyArray> py_arrays, const nb::object& sharding) {
+  const ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding);
+
+  std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;
+  ifrt_arrays.reserve(py_arrays.size());
+  absl::InlinedVector<ifrt::Device*, 1> devices;
+  devices.reserve(py_arrays.size());
+  absl::flat_hash_set<ifrt::Device*> device_set;
+  device_set.reserve(py_arrays.size());
+  std::vector<ifrt::Shape> shapes;
+  shapes.reserve(py_arrays.size());
+
+  auto sharding_device_list = xla::GetIfrtDeviceList(sharding);
+  if (!sharding_device_list.ok()) {
+    // TODO(hyeontaek): Return a absl::Status.
+    throw nb::value_error(sharding_device_list.status().ToString().c_str());
+  }
+  ifrt::Device* device = sharding_device_list.value()->devices().front();
+
+  // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to
+  // skip canonicalization here once JAX begins to do it for JAX shardings.
+  const ifrt::MemoryKind canonical_dst_memory_kind =
+      ifrt::CanonicalizeMemoryKind(dst_memory_kind, device);
+  for (const auto& py_array : py_arrays) {
+    if (py_array.num_shards() != 1) {
+      throw nb::value_error(
+          absl::StrFormat(
+              "When making an array from single-device arrays the input arrays "
+              "must have one shard each. An argument array had %d shard(s).",
+              py_array.num_shards())
+              .c_str());
+    }
+    ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array()));
+    ifrt::Device* const device =
+        ifrt_arrays.back()->sharding().devices()->devices().front();
+    devices.push_back(device);
+    device_set.insert(device);
+    shapes.push_back(ifrt_arrays.back()->shape());
+    if (canonical_dst_memory_kind !=
+        ifrt::CanonicalizeMemoryKind(
+            ifrt_arrays.back()->sharding().memory_kind(), device)) {
+      throw nb::value_error(
+          absl::StrFormat(
+              "Memory kind mismatch with PjRtBuffers. Got sharding with "
+              "memory kind '%v' and a buffer with memory_kind '%v'",
+              dst_memory_kind, ifrt_arrays.back()->sharding().memory_kind())
+              .c_str());
+    }
+  }
+  ifrt::DeviceListRef device_list = device->client()->MakeDeviceList(devices);
+  if (device_set.size() != device_list->size()) {
+    throw nb::value_error(
+        absl::StrFormat(
+            "When making an array from single-device arrays, the input arrays "
+            "must be from distinct devices, but got %v",
+            *device_list)
+            .c_str());
+  }
+
+  auto ifrt_dtype = DtypeToIfRtDType(dtype);
+  if (!ifrt_dtype.ok()) {
+    // TODO(hyeontaek): Return a absl::Status.
+    throw nb::value_error(ifrt_dtype.status().ToString().c_str());
+  }
+
+  absl::StatusOr<std::shared_ptr<const ifrt::Sharding>> ifrt_sharding =
+      sharding.type().is(jax::PmapSharding::type())
+          ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape),
+                                         std::move(shapes))
+          : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape));
+  if (!ifrt_sharding.ok()) {
+    // TODO(hyeontaek): Return a absl::Status.
+ throw nb::value_error(ifrt_sharding.status().ToString().c_str()); + } + // TODO(emilyaf): Always use `ifrt_dtype` once tokens are handled correctly. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype.value() : ifrt_arrays[0]->dtype(); + absl::StatusOr> ifrt_array = + device->client()->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), *std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_array.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_array.status().ToString().c_str()); + } + return *std::move(ifrt_array); +} + +struct PyArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; + PyObject* dict; +#endif // PY_VERSION_HEX < 0x030B0000 + bool initialized; + alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; +}; +static_assert(std::is_standard_layout::value); + +PyArray::Storage* GetPyArrayStorageFromObject(PyArrayObject* py_array_object) { + return std::launder( + reinterpret_cast(py_array_object->array_storage)); +} + +extern "C" PyObject* PyArray_tp_new(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* self = type->tp_alloc(type, 0); + auto* obj = reinterpret_cast(self); + obj->initialized = false; + return self; +} + +extern "C" void PyArray_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + auto* obj = reinterpret_cast(self); + + if (obj->initialized) { + GetPyArrayStorageFromObject(obj)->~PyArray_Storage(); + } + + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + + tp->tp_free(self); + Py_DECREF(tp); +} + +// dynamic_attr: Allow the garbage collector to traverse the internal instance +// `__dict__`. +extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) { +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); + return 0; +} + +// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" int PyArray_tp_clear(PyObject* self) { + switch (auto guard_level = jax::GetGarbageCollectArrayGuard(); guard_level) { + case jax::GarbageCollectionGuardLevel::kAllow: + break; + case jax::GarbageCollectionGuardLevel::kLog: + case jax::GarbageCollectionGuardLevel::kFatal: { + auto* obj = reinterpret_cast(self); + std::string traceback_str; + if (obj->initialized) { + auto traceback = GetPyArrayStorageFromObject(obj)->traceback; + if (traceback.has_value()) { + traceback_str = traceback.value()->ToString(); + } + } + auto error_msg = absl::StrCat( + "`jax.Array` was deleted by the Python garbage collector " + "instead of reference counting. Break the reference cycle " + "that delays the deletion of this `jax.Array` to avoid hogging " + "memory. Traceback: \n", + traceback_str.empty() ? 
"not available" : traceback_str); + if (guard_level == jax::GarbageCollectionGuardLevel::kFatal) { + Py_FatalError(error_msg.c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, error_msg.c_str()); + PyErr_Print(); + PyErr_Clear(); + } + break; + } + } +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + return 0; +} + +template +PyArray::Storage* Construct(PyArrayObject* self, Args&&... args) { + PyArray::Storage* out = + new (self->array_storage) PyArray::Storage(std::forward(args)...); + self->initialized = true; + return out; +} + +struct ShapedArrayCacheKey { + std::vector dims; + ifrt::DType dtype{ifrt::DType::kInvalid}; + bool weak_type; + + template + friend H AbslHashValue(H h, const ShapedArrayCacheKey& value) { + return H::combine(std::move(h), value.dims, value.dtype, value.weak_type); + } + bool operator==(const ShapedArrayCacheKey& other) const { + return dims == other.dims && dtype == other.dtype && + weak_type == other.weak_type; + } +}; + +// Constructing ShapedArrays has gotten slow. Cache it. +nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { + using CacheT = + LRUCache>>; + static nb::ft_mutex mu; + static auto* lru_list = new CacheT::LRUList(4096); + static auto* cache = new CacheT(lru_list); + + static const nb::object* shaped_array = []() -> nb::object* { + nb::object jax_core; + try { + jax_core = nb::module_::import_("jax.core"); + } catch (nb::python_error& e) { + return nullptr; + } + return new nb::object(jax_core.attr("ShapedArray")); + }(); + if (!shaped_array) { + return nb::none(); + } + + nb::ft_lock_guard lock(mu); + auto value = + cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey& key) { + return std::make_shared>(); + }); + + if (!value->has_value()) { + nb_dtype dtype = + IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + nb::object aval = (*shaped_array)( + SpanToNbTuple(absl::Span( + key.dtype.kind() == ifrt::DType::kToken ? std::vector{0} + : key.dims)), + dtype, key.weak_type); + *value = aval; + return aval; + } + return **value; +} + +// Grouping key used by BatchedCopyToDeviceWithSharding. +// Defined outside of the function as required by templatized function +// `AbslHashValue`. 
+struct BatchedCopyToDeviceWithShardingKey { + ifrt::DeviceListRef src_devices; + ifrt::MemoryKind src_memory_kind; + ifrt::DeviceListRef dst_devices; + ifrt::MemoryKind dst_memory_kind; + ifrt::ArrayCopySemantics array_copy_semantics; + + bool operator==(const BatchedCopyToDeviceWithShardingKey& other) const { + return *src_devices == *other.src_devices && + src_memory_kind == other.src_memory_kind && + *dst_devices == *other.dst_devices && + dst_memory_kind == other.dst_memory_kind && + array_copy_semantics == other.array_copy_semantics; + } + + template + friend H AbslHashValue(H h, const BatchedCopyToDeviceWithShardingKey& key) { + return H::combine(std::move(h), key.src_devices, key.src_memory_kind, + key.dst_devices, key.dst_memory_kind, + key.array_copy_semantics); + } +}; + +} // namespace + +PyArray_Storage::PyArray_Storage( + nb::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nb::object sharding, bool committed, + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, xla::PjRtFuture<> result_status) + : aval(std::move(aval)), + weak_type(weak_type), + dtype(std::move(dtype)), + shape(std::move(shape)), + sharding(std::move(sharding)), + committed(committed), + py_client(std::move(py_client)), + traceback(std::move(traceback)), + ifrt_array(std::move(ifrt_array)), + result_status(std::move(result_status)) { + static_assert(PyClient::kNumArraysShards < + std::numeric_limits::max()); + thread_id_bucket = std::hash()(std::this_thread::get_id()) % + PyClient::kNumArraysShards; + + PyClient::ArraysShard& shard = this->py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + next = shard.arrays; + shard.arrays = this; + if (next) { + next->prev = this; + } + prev = nullptr; +} + +void PyInit_helper(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed) { + auto dtype = nb::cast(aval.attr("dtype")); + auto shape = nb::cast>(aval.attr("shape")); + auto py_device_list = nb::cast( + sharding.attr("_internal_device_list")); + nb_class_ptr py_client = py_device_list->py_client(); + auto ifrt_array = CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype, shape, py_arrays, sharding); + Construct(reinterpret_cast(self.ptr()), aval, + nb::cast(aval.attr("weak_type")), std::move(dtype), + std::move(shape), std::move(sharding), committed, py_client, + Traceback::Get(), std::move(ifrt_array), xla::PjRtFuture<>()); +} + +void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks) { + if (skip_checks) { + PyInit_helper(self, aval, sharding, py_arrays, committed); + } else { + nb::object rearranged_arrays = + self.CheckAndRearrange(py_arrays, sharding, aval); + auto rearranged_py_arrays = + nb::cast>(rearranged_arrays); + PyInit_helper(self, aval, sharding, rearranged_py_arrays, committed); + } +} + +PyArray PyArray::MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status) { + if (!llvm::isa(ifrt_array->sharding())) { + throw XlaRuntimeError( + InvalidArgument("Constructing single device jax.Array from non-single " + "device ifrt array.")); + } + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? 
std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); + nb::object py_memory_kind = + (memory_kind.memory_kind().has_value()) + ? nb::object(nb::str(memory_kind.memory_kind()->data(), + memory_kind.memory_kind()->size())) + : nb::none(); + nb::object sharding = make_nb_class( + py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind)); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + /*skip_checks=*/true, std::move(result_status)); +} + +PyArray PyArray::MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nb::object sharding, + bool weak_type, bool committed, bool skip_checks) { + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + skip_checks); +} + +PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding, + bool committed, bool skip_checks) + : aval_(std::move(aval)), + sharding_(std::move(sharding)), + committed_(committed), + skip_checks_(skip_checks) { + weak_type_ = nb::cast(aval_.attr("weak_type")); + dtype_ = nb::cast(aval_.attr("dtype")); + shape_ = nb::cast>(aval_.attr("shape")); +} + +PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { + auto py_device_list = jax::GetPyDeviceList(sharding_); + if (!py_device_list.ok()) { + throw nb::value_error( + absl::StrCat("Failed to get py device list from sharding: ", + py_device_list.status().ToString()) + .c_str()); + } + return Call(py_device_list.value()->py_client(), + CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype_, shape_, py_arrays, sharding_), + xla::PjRtFuture<>()); +} + +PyArray PyArrayResultHandler::Call(nb_class_ptr py_client, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status) const { + return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, + std::move(py_client), Traceback::Get(), std::move(ifrt_array), + committed_, skip_checks_, std::move(result_status)); +} + +PyArray PyArrayResultHandler::Call(PyArray py_array) const { + return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), + xla::PjRtFuture<>()); +} + +PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nb::object sharding, + nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, bool committed, + bool skip_checks, xla::PjRtFuture<> result_status) { + auto* self = + PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); + m_ptr = self; + Construct(reinterpret_cast(self), std::move(aval), weak_type, + std::move(dtype), std::move(shape), std::move(sharding), committed, + std::move(py_client), std::move(traceback), std::move(ifrt_array), + std::move(result_status)); + + if (!skip_checks) { + 
this->attr("_arrays") = this->attr("_check_and_rearrange")( + this->attr("_arrays"), this->attr("_sharding"), this->attr("aval")); + } +} + +PyArray::Storage& PyArray::GetStorage() { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +const PyArray::Storage& PyArray::GetStorage() const { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +nb::object PyArray::CheckAndRearrange(const absl::Span py_arrays, + const nb::object sharding, + const nb::object aval) { + return this->attr("_check_and_rearrange")(py_arrays, sharding, aval); +} + +void PyArray::SetIfrtArray(tsl::RCReference ifrt_array) { + GetStorage().ifrt_array = std::move(ifrt_array); +} + +const std::vector& PyArray::py_arrays_cached() { + auto& py_arrays = this->py_arrays(); + + if (py_arrays.empty()) { + auto ifrt_arrays = ifrt_array()->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_arrays.ok()) { + throw nb::value_error( + absl::StrCat("Failed to disassemble into single-device arrays: ", + ifrt_arrays.status().ToString()) + .c_str()); + } + py_arrays.reserve(ifrt_arrays->size()); + for (auto& ifrt_array : *ifrt_arrays) { + py_arrays.push_back(PyArray::MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(ifrt_array), weak_type(), + committed(), result_status())); + } + } + + return py_arrays; +} + +nb::object PyArray::arrays() { + // For performance, we only keep pjrt buffers by default. But on python side + // "_arrays" returns PyArrays instead, and subsequent calls to "_arrays" + // should return the same PyArrays (to avoid duplicate device to host + // transfers). So we create PyArrays the first time it is called and reuse + // them later. + if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return nb::none(); + + if (llvm::isa(&ifrt_array()->sharding())) { + std::vector py_arrays; + py_arrays.push_back(*this); + return nb::cast(py_arrays); + } + + return nb::cast(py_arrays_cached()); +} + +absl::Status PyArray::set_arrays(nb::object obj) { + if (obj.is_none()) { + SetIfrtArray(tsl::RCReference()); + py_arrays().clear(); + return absl::OkStatus(); + } + + if (!nb::isinstance(obj)) { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + + nb::list list(obj); + + if (list.size() == 0) return absl::OkStatus(); + + SetIfrtArray(tsl::RCReference()); + py_arrays().clear(); + std::vector> ifrt_arrays; + ifrt_arrays.reserve(list.size()); + absl::InlinedVector devices; + devices.reserve(list.size()); + std::vector shapes; + shapes.reserve(list.size()); + for (nb::handle obj : list) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + if (py_array.py_client().get() != py_client().get()) { + return InvalidArgument("Client mismatch when assigning to _arrays."); + } + if (py_array.num_shards() != 1) { + return InvalidArgument("Wrong number of shards: %d", + py_array.num_shards()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + } else { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + } + const ifrt::MemoryKind first_memory_kind = + ifrt_arrays.front()->sharding().memory_kind(); + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here 
once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_first_memory_kind = + ifrt::CanonicalizeMemoryKind( + first_memory_kind, + ifrt_arrays.front()->sharding().devices()->devices().front()); + for (const auto& ifrt_array : ifrt_arrays) { + if (canonical_first_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_array->sharding().memory_kind(), + ifrt_array->sharding().devices()->devices().front())) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between single-device arrays. Got one " + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) + .c_str()); + } + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding().type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding(), ifrt::Shape(shape()), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding(), ifrt::Shape(shape()))); + TF_ASSIGN_OR_RETURN( + auto array, + py_client()->ifrt_client()->AssembleArrayFromSingleDeviceArrays( + ifrt::Shape(shape()), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards)); + SetIfrtArray(std::move(array)); + return absl::OkStatus(); +} + +absl::StatusOr PyArray::FullyReplicatedShard() { + auto& cached = GetStorage().fully_replicated_array; + if (!cached.is_none()) { + return nb::cast(cached); + } + + if (ifrt_array() == nullptr) { + return InvalidArgument( + "FullyReplicatedShard() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(auto fully_replicated_ifrt_shard, + ifrt_array()->FullyReplicatedShard( + ifrt::ArrayCopySemantics::kReuseInput)); + auto array = MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(fully_replicated_ifrt_shard), + weak_type(), committed(), result_status()); + cached = array; + return nb::cast(cached); +} + +absl::Status PyArray::BlockUntilReady() const { + nb::gil_scoped_release gil_release; + if (ifrt_array() == nullptr) { + return InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt::Array* ifrt_array = this->ifrt_array(); + return AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)); +} + +absl::StatusOr PyArray::GetOnDeviceSizeInBytes() { + if (ifrt_array() == nullptr) { + return InvalidArgument( + "GetOnDeviceSizeInBytes() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(size_t shard_size, + GetPjrtBuffer(ifrt_array())->GetOnDeviceSizeInBytes()); + return shard_size * nb::len(nb::object(sharding().attr("device_set"))); +} + +absl::Status PyArray::BlockUntilResultStatusIsReady() { + auto& result_status = GetStorage().result_status; + // If the result_status future is not valid, this result did not come directly + // from a computation that returns tokens, so we don't wait for the status. + if (!result_status.IsValid()) { + return absl::OkStatus(); + } + if (!result_status.IsReady()) { + // Only release the gil if we need to Await(). 
+ nb::gil_scoped_release release_gil; + return result_status.Await(); + } + return result_status.Await(); +} + +absl::StatusOr> +PyArray::SingleDeviceArrayToNumpyArrayDidCopy() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + auto result = arr.GetStorage().host_value.AsNumPyArray( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); + TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady()); + return result; +} + +absl::StatusOr PyArray::SingleDeviceArrayToNumpyArray() { + TF_ASSIGN_OR_RETURN(auto result, SingleDeviceArrayToNumpyArrayDidCopy()); + return result.first; +} + +absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + return arr.GetStorage().host_value.CopyToHostAsync( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); +} + +absl::StatusOr PyArray::AssertUnsharded(absl::string_view api) { + if (ifrt_array() == nullptr) { + return InvalidArgument("%s( called on deleted or donated buffer", api); + } + + if (llvm::isa(&ifrt_array()->sharding())) { + return *this; + } + + auto& py_arrays = py_arrays_cached(); + if (py_arrays.size() != 1) { + return InvalidArgument("%s() is supported only for unsharded arrays.", api); + } + return py_arrays[0]; +} + +absl::StatusOr PyArray::UnsafeBufferPointer() { + TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); + + return py_client()->pjrt_client()->UnsafeBufferPointer( + GetPjrtBuffer(arr.ifrt_array())); +} + +nb::dict PyArray::CudaArrayInterface() { + auto arr_or_error = AssertUnsharded("UnsafeBufferPointer"); + if (!arr_or_error.ok()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only supported for unsharded arrays."); + } + auto arr = *arr_or_error; + + ifrt::Array* ifrt_array = arr.ifrt_array(); + std::optional& scratch = arr.GetStorage().dynamic_shape; + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + if (pjrt_buffer->client()->platform_id() != CudaId()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for NVidia GPU buffers."); + } + if (pjrt_buffer->IsTuple()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for array buffers."); + } + + switch (pjrt_buffer->element_type()) { + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::C64: + case PrimitiveType::C128: + break; + + default: + throw nb::attribute_error( + absl::StrFormat( + "__cuda_array_interface__ is not supported for %s buffers.", + PrimitiveType_Name(pjrt_buffer->element_type())) + .c_str()); + } + + nb::str typestr = + ValueOrThrow(TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + if (!LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + throw nb::attribute_error( + "__cuda_array_interface__ is only currently supported for " + "buffers in row-major order."); + } + + nb::dict result; + const auto* dynamic_shape = + ValueOrThrow(XlaDynamicShape(ifrt_array, scratch)); + result["shape"] = SpanToNbTuple(dynamic_shape->dimensions()); + result["typestr"] = std::move(typestr); + std::unique_ptr external_reference_hold = + ValueOrThrow(pjrt_buffer->AcquireExternalReference()); + const void* 
root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::tuple data = + nb::make_tuple(nb::int_(absl::bit_cast(root_ptr)), + nb::bool_(true) /* read-only */ + ); + result["data"] = std::move(data); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nb::dict& cai, nb_class_ptr client, + std::optional device_id) { + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = nb::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = nb::cast(cai["data"]); + auto data_value = nb::cast(data[0]); + void* data_ptr = reinterpret_cast(data_value); + auto dimensions = nb::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + PrimitiveType element_type, + DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); + + if (!device_id.has_value()) { + throw XlaRuntimeError( + "This operation requires CUDA support from jaxlib or jax cuda plugin."); + } + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(*device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> absl::StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = nb::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = nb::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. 
Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + std::function on_delete_callback = []() {}; + auto* pjrt_device = + llvm::dyn_cast_or_null(device->device()); + if (pjrt_device == nullptr) { + return InvalidArgument( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_RET_CHECK(pjrt_device->IsAddressable()); + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->pjrt_client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, + *pjrt_device->pjrt_device()->default_memory_space(), + on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::Status PyArray::Delete() { + for (auto& arr : py_arrays()) { + TF_RETURN_IF_ERROR(arr.Delete()); + } + py_arrays().clear(); + if (ifrt_array() != nullptr) { + // We do not wait for the deletion to complete here. + // + // (1) Skipping blocking does not affect the correctness of deletion as long + // as the runtime preserves dispatch ordering of deletion w.r.t. other + // operations. + // + // (2) Synchronously waiting for the deletion to complete is very expensive + // when the deletion can return a status only after the underlying physical + // buffer has been deleted or a request must be processed via RPC, + // especially as this deletion is done per array. + ifrt_array()->Delete(); + SetIfrtArray(tsl::RCReference()); + } + return absl::OkStatus(); +} + +bool PyArray::IsDeleted() const { + if (ifrt_array() == nullptr) { + return true; + } + + return ifrt_array()->IsDeleted(); +} + +PyArray PyArray::Clone() const { + auto array = tsl::FormRef(ifrt_array()); + auto* ifrt_client = py_client()->ifrt_client(); + tsl::RCReference out = + ifrt_client + ->CopyArrays(absl::MakeSpan(&array, 1), /*devices=*/std::nullopt, + /*memory_kind=*/std::nullopt, + ifrt::ArrayCopySemantics::kReuseInput) + .value() + .front(); + return PyArray(aval(), weak_type(), dtype(), + std::vector(shape().begin(), shape().end()), + sharding(), py_client(), traceback(), std::move(out), + committed(), /*skip_checks=*/true, result_status()); +} + +nb::handle PyArray::Storage::AsHandle() { + return reinterpret_cast(reinterpret_cast(this) - + offsetof(PyArrayObject, array_storage)); +} + +PyArray::Storage::~PyArray_Storage() { + CHECK(PyGILState_Check()); + if (py_client) { + PyClient::ArraysShard& shard = py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + if (shard.arrays == this) { + shard.arrays = next; + } + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + } + // Release GIL and then explicitly destroy `ifrt_array` to prevent deadlock on + // CPU backend caused by interactions between argument donations and host + // callbacks. 
+ nb::gil_scoped_release gil_release; + ifrt_array.reset(); +} + +absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics) { + if (py_arrays.empty()) { + return std::vector(); + } + + TF_RET_CHECK(py_arrays.size() == dst_device_lists.size()); + TF_RET_CHECK(py_arrays.size() == dst_shardings.size()); + + ifrt::Client* const client = py_arrays.front().ifrt_array()->client(); + std::vector results(py_arrays.size()); + + // Arrays to be copied, grouped by source/destination devices and memory + // kinds. The grouping is enforced by `ifrt::Client::CopyArrays()`. + struct Batch { + std::vector indexes; + std::vector> ifrt_arrays; + }; + absl::flat_hash_map batches; + + for (int i = 0; i < py_arrays.size(); ++i) { + const auto& py_array = py_arrays[i]; + const auto& dst_sharding = dst_shardings[i]; + const auto& array_cs = array_copy_semantics[i]; + + auto* ifrt_array_ptr = py_array.ifrt_array(); + const ifrt::DeviceListRef& src_devices = + ifrt_array_ptr->sharding().devices(); + const ifrt::DeviceListRef& dst_devices = dst_device_lists[i]; + + ifrt::MemoryKind src_memory_kind = + ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), + src_devices->devices().front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + xla::GetMemoryKind(dst_sharding), dst_devices->devices().front()); + + if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && + array_cs == ifrt::ArrayCopySemantics::kReuseInput) { + results[i] = py_arrays[i]; + continue; + } + + auto transfer_guard_formatter = [&py_array, &dst_sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(py_array.aval())), + ", sharding=", + nb::cast(nb::repr(py_array.sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + + Batch& batch = batches[BatchedCopyToDeviceWithShardingKey{ + src_devices, src_memory_kind, dst_devices, dst_memory_kind, array_cs}]; + batch.indexes.push_back(i); + batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + } + + std::vector>> ifrt_arrays; + { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + + for (auto& [key, batch] : batches) { + TF_ASSIGN_OR_RETURN( + auto copied, + client->CopyArrays( + absl::MakeSpan(batch.ifrt_arrays), + // All arrays in `batch` have the same `key.dst_devices` and + // `key.dst_memory_kind` due to the grouping above. 
+              key.dst_devices, key.dst_memory_kind, key.array_copy_semantics));
+      for (int i = 0; i < batch.indexes.size(); ++i) {
+        ifrt_arrays.push_back(
+            std::make_pair(batch.indexes[i], std::move(copied[i])));
+      }
+    }
+  }
+
+  auto traceback = Traceback::Get();
+  for (auto& [i, ifrt_array] : ifrt_arrays) {
+    const auto& py_array = py_arrays[i];
+    absl::Span<const int64_t> shape_span = py_array.shape();
+    results[i] =
+        PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(),
+                std::vector<int64_t>(shape_span.begin(), shape_span.end()),
+                dst_shardings[i], py_array.py_client(), traceback,
+                std::move(ifrt_array), py_array.committed(),
+                /*skip_checks=*/true, py_array.result_status());
+  }
+  return results;
+}
+
+absl::StatusOr<nb::object> PyArray::BatchedDevicePut(
+    nb::object aval, nb::object sharding, std::vector<nb::object> xs,
+    absl::Span<const PyDevice* const> dst_devices, bool committed,
+    bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics,
+    bool jax_enable_x64) {
+  if (dst_devices.size() != xs.size()) {
+    throw nb::value_error(
+        absl::StrCat("Argument sizes (xs and devices) must match: ",
+                     dst_devices.size(), " vs ", xs.size())
+            .c_str());
+  }
+  for (const PyDevice* device : dst_devices) {
+    if (device->client().get() == nullptr) {
+      return InvalidArgument("Cannot copy to unattached devices.");
+    }
+  }
+  auto transfer_guard_formatter = [&aval, &sharding] {
+    return absl::StrCat(
+        "aval=", nb::cast<std::string>(nb::repr(aval)),
+        ", dst_sharding=", nb::cast<std::string>(nb::repr(sharding)));
+  };
+
+  GlobalPyRefManager()->CollectGarbage();
+
+  auto n_devices = dst_devices.size();
+
+  DevicePutOptions options;
+  options.squash_64bit_types = !jax_enable_x64;
+  options.allow_zero_copy =
+      (!force_copy && (host_buffer_semantics ==
+                       ifrt::Client::HostBufferSemantics::kImmutableZeroCopy));
+  if (!dst_devices.empty()) {
+    options.ifrt_user_context =
+        dst_devices.front()->client()->ifrt_client()->CreateUserContext();
+  }
+
+  nb::list owning_pylist;
+  std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;
+
+  absl::InlinedVector<ifrt::Device*, 1> devices;
+  devices.reserve(n_devices);
+  std::vector<ifrt::Shape> shapes;
+  shapes.reserve(n_devices);
+
+  ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding);
+
+  std::vector<DevicePutResultFn> device_put_fns;
+  device_put_fns.reserve(xs.size());
+  size_t i = 0;
+  for (auto& x : xs) {
+    if (PyArray::IsPyArray(x)) {
+      TF_RETURN_IF_ERROR(
+          jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter));
+    } else {
+      TF_RETURN_IF_ERROR(
+          jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter));
+    }
+    TF_ASSIGN_OR_RETURN(
+        device_put_fns.emplace_back(),
+        DevicePut(x, dst_devices[i]->client()->ifrt_client(),
+                  dst_devices[i]->device(), options, dst_memory_kind));
+    ++i;
+  }
+  std::vector<DevicePutResult> device_puts;
+  device_puts.reserve(device_put_fns.size());
+  {
+    nb::gil_scoped_release gil_release;
+    for (auto& device_put_fn : device_put_fns) {
+      TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)());
+      device_puts.push_back(std::move(device_put));
+    }
+  }
+  for (auto& device_put : device_puts) {
+    ifrt_arrays.push_back(std::move(device_put.ifrt_array));
+    devices.push_back(
+        ifrt_arrays.back()->sharding().devices()->devices().front());
+    shapes.push_back(ifrt_arrays.back()->shape());
+    if (device_put.owning_pybuffer) {
+      owning_pylist.append(device_put.owning_pybuffer);
+    }
+  }
+
+  // TODO(phawkins): it's highly suspicious to me that owning_pylist isn't
+  // consumed here. Look into this.
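+
+  // Sketch of the Python-side call that reaches this entry point; the
+  // binding name is illustrative rather than established by this file:
+  //
+  //   arrs = xla_client.batched_device_put(
+  //       aval, sharding, [np.ones(4)] * len(devs), devs, committed=True)
+  //
+  // Each element of `xs` lands on the corresponding device in `dst_devices`,
+  // and the resulting shards are assembled into a single ifrt::Array below.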
+ + auto weak_type = nb::cast(aval.attr("weak_type")); + auto dtype = aval.attr("dtype"); + auto shape = nb::cast>(aval.attr("shape")); + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding.type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape))); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, DtypeToIfRtDType(dtype)); + // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are + // supported. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype : ifrt_arrays.front()->dtype(); + TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding)); + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + py_device_list->py_client() + ->ifrt_client() + ->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), + xla::ifrt::ArrayCopySemantics::kReuseInput, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + + return PyArray(aval, weak_type, dtype, std::move(shape), sharding, + py_device_list->py_client(), Traceback::Get(), + std::move(ifrt_array), committed, /*skip_checks=*/true); +} + +absl::StatusOr PyArray::ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + xla::ifrt::Array* ifrt_array_ptr = x.ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return absl::InvalidArgumentError( + "Reorder() called on deleted or donated buffer"); + } + + ifrt::Client* const client = ifrt_array_ptr->client(); + + const auto& device_list = ifrt_array_ptr->sharding().devices(); + TF_ASSIGN_OR_RETURN(auto dst_device_list, GetIfrtDeviceList(dst_sharding)); + if (device_list->AddressableDeviceList()->size() != + dst_device_list->AddressableDeviceList()->size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array is expected to have ", + dst_device_list->AddressableDeviceList()->size(), + " addressable shards, but has ", + device_list->AddressableDeviceList()->size(), " addressable shards")); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr dst_ifrt_sharding, + GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(), + ifrt_array_ptr->shape())); + + tsl::RCReference new_ifrt_array; + { + nb::gil_scoped_release gil_release; + + const absl::Span addressable_devices = + device_list->AddressableDeviceList()->devices(); + const absl::Span dst_addressable_devices = + dst_device_list->AddressableDeviceList()->devices(); + + absl::flat_hash_map device_id_to_array_shard_index; + device_id_to_array_shard_index.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + const int device_id = dst_addressable_devices[i]->Id().value(); + const bool inserted = + device_id_to_array_shard_index.insert({device_id, i}).second; + if (!inserted) { + return absl::InvalidArgumentError( + absl::StrCat("Sharding contains duplicate device id=", device_id)); + } + } + + std::vector from_shard_indices; + from_shard_indices.reserve(addressable_devices.size()); + std::vector to_shard_indices; + to_shard_indices.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + from_shard_indices.push_back(i); + const int shard_device_id = addressable_devices[i]->Id().value(); + const auto it = device_id_to_array_shard_index.find(shard_device_id); + if (it == device_id_to_array_shard_index.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array shard ", 
i, " is on device id=", shard_device_id, + ", but sharding does not have a shard on that device.")); + } + to_shard_indices.push_back(it->second); + } + + auto mappings = + std::make_shared>(); + { + auto& mapping = mappings->emplace_back(); + mapping.in_array = 0; + mapping.out_array = 0; + mapping.from.reserve(dst_addressable_devices.size()); + mapping.to.reserve(dst_addressable_devices.size()); + for (int64_t i = 0; i < dst_addressable_devices.size(); ++i) { + mapping.from.push_back(xla::ifrt::RemapPlan::Interval{ + from_shard_indices[i], from_shard_indices[i] + 1, 1}); + mapping.to.push_back(xla::ifrt::RemapPlan::Interval{ + to_shard_indices[i], to_shard_indices[i] + 1, 1}); + } + } + + xla::ifrt::RemapPlan plan = { + /*input_specs=*/{xla::ifrt::ArraySpec{ + /*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/ifrt_array_ptr->shared_ptr_sharding()}}, + /*output_specs=*/ + {xla::ifrt::ArraySpec{/*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/std::move(dst_ifrt_sharding)}}, + /*mappings=*/std::move(mappings), + }; + DCHECK_OK(plan.Validate()); + std::vector> input; + input.push_back(tsl::FormRef(ifrt_array_ptr)); + TF_ASSIGN_OR_RETURN( + auto remapped, + client->RemapArrays(plan, absl::MakeSpan(input), array_copy_semantics)); + + TF_RET_CHECK(remapped.size() == 1); + new_ifrt_array = std::move(remapped.front()); + } + + return xla::PyArray(nb::borrow(x.aval().ptr()), x.weak_type(), + nb::borrow(x.dtype().ptr()), + std::vector(x.shape().begin(), x.shape().end()), + std::move(dst_sharding), x.py_client(), x.traceback(), + std::move(new_ifrt_array), + /*committed=*/true, + /*skip_checks=*/true); +} + +absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { + // Create ready futures for all arrays before blocking on their readiness. + // This helps reduce the latency in some backend implementations where + // querying readiness of an array is not free. + + std::vector ifrt_arrays; + ifrt_arrays.reserve(objs.size()); + for (nb::handle obj : objs) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + ifrt::Array* const ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return absl::InvalidArgumentError( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt_arrays.push_back(ifrt_array); + } else { + return absl::InvalidArgumentError( + "PyArray::BatchedBlockUntilReady can take PyArray only"); + } + } + + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); +} + +std::vector PyClient::LiveArrays() const { + std::vector result; + for (auto& shard : arrays_) { + nb::ft_lock_guard lock(shard.mutex); + for (PyArray::Storage* array = shard.arrays; array; array = array->next) { + bool all_deleted = + (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); + if (!all_deleted) { + result.push_back(nb::borrow(array->AsHandle())); + } + } + } + return result; +} + +// PEP 3118 buffer protocol implementation. + +namespace { + +// Extra data to be kept alive by the consumer of the buffer protocol. +struct ExtraBufferInfo { + explicit ExtraBufferInfo( + std::shared_ptr buffer, + std::unique_ptr external_reference_hold) + : buffer(std::move(buffer)), + external_reference_hold(std::move(external_reference_hold)) {} + + std::vector strides; + // We keep an external reference hold to the PjRtBuffer. 
+  // This prevents a use-after-free in the event that Delete() is called on a
+  // buffer with a live buffer protocol view. It does, however, mean that
+  // Delete() sometimes won't actually delete immediately.
+  std::shared_ptr<PjRtBuffer> buffer;
+  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
+};
+
+// The default layout of a non-tuple array should have major-to-minor layout
+// and no tiles.
+bool HasDefaultLayout(const Layout& layout) {
+  return LayoutUtil::IsMonotonicWithDim0Major(layout) &&
+         layout.tiles().empty();
+}
+
+int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) {
+  absl::Status status = [&]() -> absl::Status {
+    PyArray py_array = nb::borrow<PyArray>(exporter);
+    if (py_array.ifrt_array() == nullptr) {
+      // TODO(phawkins): why is this happening?
+      return InvalidArgument("Array is null");
+    }
+    if (!llvm::isa<ifrt::PjRtCompatibleArray>(py_array.ifrt_array())) {
+      return InvalidArgument("Only local arrays are supported, got %s",
+                             py_array.ifrt_array()->DebugString());
+    }
+    auto* array =
+        static_cast<ifrt::PjRtCompatibleArray*>(py_array.ifrt_array());
+    absl::Span<const std::shared_ptr<PjRtBuffer>> buffers =
+        array->pjrt_buffers();
+
+    PjRtBuffer& buffer = *buffers.front();
+    if (!buffer.IsOnCpu()) {
+      return InvalidArgument(
+          "Python buffer protocol is only defined for CPU buffers.");
+    }
+
+    if (buffers.size() != 1) {
+      return InvalidArgument(
+          "Python buffer protocol is only defined for buffers with a single "
+          "shard.");
+    }
+    if (!py_array.sharding().type().is(jax::SingleDeviceSharding::type())) {
+      return InvalidArgument(
+          "Python buffer protocol is only defined for single-device sharded "
+          "buffers.");
+    }
+
+    const char* format =
+        PEP3118FormatDescriptorForPrimitiveType(buffer.element_type());
+    // It isn't an option for us to export unknown types as, say, bytes. When
+    // converting an object to an ndarray, NumPy tries the buffer protocol
+    // first. We very much want NumPy to fail and fall back to using
+    // __array__, which allows us to handle custom dtypes correctly.
+    if (!format) {
+      return InvalidArgument(
+          "Buffers of type %s are not supported by the Python buffer "
+          "protocol.",
+          PrimitiveType_Name(buffer.element_type()));
+    }
+
+    std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
+    {
+      // We call BlockHostUntilReady() below, which may block.
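+      // (Both acquiring the external reference and awaiting readiness can
+      // contact the device runtime, hence the GIL-released scope below.)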
+ nb::gil_scoped_release gil_release; + + if (buffer.IsTuple()) { + return InvalidArgument( + "Python buffer protocol is only defined for array buffers."); + } + if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { + return InvalidArgument("XLA buffers are read-only."); + } + TF_ASSIGN_OR_RETURN(external_reference_hold, + buffer.AcquireExternalReference()); + if (buffer.IsDeleted()) { + return InvalidArgument("Deleted buffer used in buffer protocol."); + } + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = buffer.layout()->xla_layout(); + + if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || + (flags & PyBUF_STRIDES) == PyBUF_ND) && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + return InvalidArgument("Buffer is not in C-contiguous layout."); + } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in F-contiguous layout."); + } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout) && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in contiguous layout."); + } else if (!HasDefaultLayout(xla_layout)) { + // Fail and fall back to using __array__ if the CPU buffer has a device + // specific layout. For instance, this happens for host buffers in + // pinned memories of the TPU device. + return InvalidArgument( + "Buffer is potentially a device buffer with non default layout."); + } + TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); + } + + // We must hold the GIL (or at least prevent Python GC) while writing to the + // view object, see https://github.com/python/cpython/issues/130409. + std::memset(view, 0, sizeof(Py_buffer)); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + view->buf = const_cast(root_ptr); + auto extra = std::make_unique( + buffers.front(), std::move(external_reference_hold)); + view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(buffer.element_type()); + TF_ASSIGN_OR_RETURN(view->len, buffer.GetOnDeviceSizeInBytes()); + view->readonly = 1; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(format); + } + if ((flags & PyBUF_ND) == PyBUF_ND) { + view->ndim = buffer.dimensions().size(); + static_assert(sizeof(int64_t) == sizeof(Py_ssize_t), + "Py_ssize_t must be 64 bits"); + if (view->ndim != 0) { + view->shape = reinterpret_cast( + const_cast(buffer.dimensions().data())); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + extra->strides = + ByteStridesForShape(buffer.element_type(), buffer.dimensions(), + buffer.layout()->xla_layout()); + view->strides = reinterpret_cast( + const_cast(extra->strides.data())); + } + } + } + view->internal = extra.release(); + return absl::OkStatus(); + }(); + if (!status.ok()) { + // numpy.asarray(...) eats the PyExc_BufferError. Adding a log here helps + // debugging when the error really occurs. + VLOG(1) << "Buffer Protocol Error: " << status; + PyErr_SetString(PyExc_BufferError, status.ToString().c_str()); + return -1; + } + view->obj = exporter; + Py_INCREF(view->obj); + return 0; +} + +void PyArray_bf_releasebuffer(PyObject*, Py_buffer* buffer) { + auto extra = static_cast(buffer->internal); + delete extra; +} + +// Returns if shape has a major-to-minor layout. 
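+// ("Major-to-minor" means minor_to_major is the descending sequence
+// {rank-1, ..., 1, 0}, i.e. the dense row-major order that NumPy uses by
+// default.)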
+bool HasMajorToMinorLayout(const xla::Shape& shape) { + if (shape.has_layout()) { + for (int i = 0; i < shape.layout().minor_to_major_size(); ++i) { + if (shape.layout().minor_to_major(i) != + shape.layout().minor_to_major_size() - 1 - i) { + return false; + } + } + } + return true; +} + +// Returns byte_strides if shape has a non-major-to-minor layout. +std::optional> ByteStridesOrDefaultForShapeInt64( + const Shape& shape) { + if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { + return std::nullopt; + } + return ByteStridesForShape(shape); +} + +bool IsZeroCopyableCpuBuffer(const PjRtBuffer* buf) { + // For CPU buffers with device-specific layouts, we must delinearize + // to unpack the array. This could happen for the host buffer + // pre-mapped to the TPU device, a.k.a., pinned host buffers for the + // device. + bool has_default_layout = + buf->layout() == nullptr || HasDefaultLayout(buf->layout()->xla_layout()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + return buf->IsOnCpu() && + !primitive_util::IsSubByteNonPredType(buf->element_type()) && + has_default_layout; +} +} // namespace + +PyHostValue::PyHostValue() = default; +PyHostValue::~PyHostValue() = default; + +absl::StatusOr> PyHostValue::AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ifrt_array->IsDeleted()) { + return InvalidArgument("DeviceArray has been deleted."); + } + // The only `jax.Array` with token-shape buffer is the one wrapped by + // `jax.core.Token`. Since it is an internal implementation detail, we + // don't support converting it to a numpy array. + if (ifrt_array->dtype().kind() == ifrt::DType::kToken) { + return InvalidArgument( + "Cannot convert a token-shape buffer to a numpy array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr) { + auto* pjrt_buffer = arr->pjrt_buffers().front().get(); + TF_RET_CHECK(!pjrt_buffer->IsTuple()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + if (IsZeroCopyableCpuBuffer(pjrt_buffer)) { + TF_ASSIGN_OR_RETURN(const auto* shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(shape->element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + tsl::RCReference buffer; + std::unique_ptr external_reference_hold; + }; + auto hold = std::make_unique(); + hold->buffer = tsl::FormRef(ifrt_array); + auto* hold_ptr = hold.release(); + nb::capsule hold_capsule( + hold_ptr, [](void* h) noexcept { delete static_cast(h); }); + { + // Release the GIL as `AcquireExternalReference` may block. 
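+        // (GetReadyFuture().Await() below can also block until the device
+        // buffer is ready, so it shares the same GIL-released scope.)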
+        nb::gil_scoped_release gil;
+        TF_ASSIGN_OR_RETURN(hold_ptr->external_reference_hold,
+                            pjrt_buffer->AcquireExternalReference());
+        TF_RETURN_IF_ERROR(ifrt_array->GetReadyFuture().Await());
+      }
+      void* data =
+          hold_ptr->external_reference_hold->OpaqueDeviceMemoryDataPointer();
+      nb_numpy_ndarray array(dtype, shape->dimensions(),
+                             ByteStridesForShape(*shape), data, hold_capsule);
+      array.attr("flags").attr("writeable") = nb::bool_(false);
+      return std::make_pair(array, false);
+    }
+  }
+
+  TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array));
+  if (!ready_.IsReady()) {
+    nb::gil_scoped_release gil;
+    TF_RETURN_IF_ERROR(ready_.Await());
+  } else {
+    TF_RETURN_IF_ERROR(ready_.Await());
+  }
+  if (string_array_contents_ != nullptr) {
+    TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array));
+  }
+  return std::make_pair(value_, true);
+}
+
+absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray(
+    ifrt::Array* ifrt_array) {
+#ifdef NPY_2_0_API_VERSION
+  if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) {
+    return absl::FailedPreconditionError(
+        absl::StrCat("String arrays are not supported in NumPy version: ",
+                     PyArray_RUNTIME_VERSION));
+  }
+  auto numpy_dtype = nb::steal<nb_dtype>(
+      reinterpret_cast<PyObject*>(PyArray_DescrFromType(NPY_VSTRING)));
+  value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(),
+                            /*strides=*/std::nullopt);
+
+  auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr());
+  auto iter = nb::steal(
+      PyArray_IterNew(reinterpret_cast<PyObject*>(dst_py_array_obj)));
+  for (auto& cord : *string_array_contents_) {
+    absl::string_view input_str_view = cord.Flatten();
+    auto py_unicode = nb::steal(PyUnicode_FromStringAndSize(
+        input_str_view.data(), input_str_view.size()));
+    if (py_unicode.ptr() == nullptr) {
+      return absl::InternalError("PyUnicode_FromStringAndSize failed");
+    }
+    if (PyArray_SETITEM(dst_py_array_obj,
+                        static_cast<char*>(PyArray_ITER_DATA(iter.ptr())),
+                        py_unicode.ptr()) != 0) {
+      return absl::InternalError("PyArray_SETITEM failed");
+    }
+    PyArray_ITER_NEXT(iter.ptr());
+  }
+
+  value_.attr("flags").attr("writeable") = nb::bool_(false);
+
+  string_array_contents_.reset();
+
+  return absl::OkStatus();
+#else
+  return absl::FailedPreconditionError(
+      "String arrays are not supported in this NumPy version.");
+#endif
+}
+
+absl::Status PyHostValue::CopyStringArrayToHostAsync(
+    std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
+  auto transfer_guard_formatter = [ifrt_array] {
+    return absl::StrCat(
+        "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","),
+        "), dtype=", ifrt_array->dtype().DebugString(), ", device=",
+        ifrt_array->sharding().devices()->devices().front()->DebugString());
+  };
+  TF_RETURN_IF_ERROR(
+      jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter));
+
+  TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype()));
+  auto shape = ifrt_array->shape();
+
+  // Allocate a vector of cords to hold the contents of the array until they
+  // are ultimately converted to a numpy array as part of the `AsNumPyArray`
+  // call.
+  string_array_contents_ =
+      std::make_shared<std::vector<absl::Cord>>(shape.num_elements());
+  ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(),
+                                        /*byte_strides=*/std::nullopt,
+                                        ifrt::ArrayCopySemantics::kAlwaysCopy);
+
+  // Keep the cords alive until the copy is done.
+  ready_.OnReady(
+      [string_array_contents = string_array_contents_](absl::Status) {});
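+  // The lambda body is deliberately empty: the callback exists only so its
+  // capture holds a reference to the cord buffer until the runtime signals
+  // completion.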
+ + return absl::OkStatus(); +} + +absl::Status PyHostValue::CopyToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ready_.IsValid()) { + // The array value has been populated, so CopyToHostAsync has been called. + return absl::OkStatus(); + } + + // Copying in Arrays of type kString requires some special handling + if (ifrt_array->dtype().kind() == ifrt::DType::kString) { + return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array); + } + + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && + IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { + return absl::OkStatus(); + } + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + // TODO(b/182461453): This is a blocking call. If we further implemented + // populating dynamic shape metadata while fetching the literal, we wouldn't + // need this static approach. + const xla::Shape* dynamic_shape; + std::optional shape_holder; + if (llvm::isa(ifrt_array)) { + TF_ASSIGN_OR_RETURN(dynamic_shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + } else { + // Skip querying the dynamic shape for a non-PjRt Array. + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + shape_holder = ShapeUtil::MakeShapeWithDescendingLayout( + type, ifrt_array->shape().dims()); + dynamic_shape = &*shape_holder; + } + + xla::Shape host_shape = ShapeUtil::DeviceShapeToHostShape(*dynamic_shape); + + auto strides = ByteStridesOrDefaultForShapeInt64(host_shape); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(host_shape.element_type())); + value_ = nb_numpy_ndarray(dtype, host_shape.dimensions(), strides); + // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses + // the same transposition as the device buffer. This is different from + // PjRtBuffer::ToLiteral()'s semantics that the runtime respects the layout + // of the host buffer literal. On the other hand, the runtime often knows + // better about an efficient layout for the host buffer. It will be useful + // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is + // desirable for the runtime to choose the layout. + ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kReuseInput); + // Make sure the destination of the copy remains alive until the copy is done. 
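+  // The inc_ref below pairs with the nb::steal inside OnReady: when the copy
+  // completes, the extra reference is handed to GlobalPyRefManager, which
+  // releases it on a thread that holds the GIL.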
+ value_.inc_ref(); + ready_.OnReady([array{value_.ptr()}](absl::Status status) { + GlobalPyRefManager()->AddGarbage(nb::steal(array)); + }); + value_.attr("flags").attr("writeable") = nb::bool_(false); + return absl::OkStatus(); +} + +namespace { +PyGetSetDef PyArray_tp_getset[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}, +}; + +PyMemberDef PyArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, weakrefs)), READONLY, + nullptr}, + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, dict)), READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; // namespace xla + +PyType_Slot PyArray_slots[] = { + {Py_tp_new, reinterpret_cast(PyArray_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PyArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyArray_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PyArray_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PyArray_tp_getset)}, + {Py_bf_getbuffer, reinterpret_cast(PyArray_bf_getbuffer)}, + {Py_bf_releasebuffer, reinterpret_cast(PyArray_bf_releasebuffer)}, + {0, nullptr}, +}; + +} // namespace + +absl::Status PyArray::RegisterTypes(nb::module_& m) { + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); + + PyType_Spec PyArray_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PyArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_DICT | Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyArray_slots, + }; + + type_ = PyType_FromSpec(&PyArray_spec); + if (!type_) { + throw nb::python_error(); + } + auto type = nb::borrow(type_); + m.attr("ArrayImpl") = type; + + type.attr("__init__") = nb::cpp_function( + [](PyArray self, nb::object aval, nb::object sharding, nb::list arrays, + bool committed, bool skip_checks) { + if (!(arrays.size() == 0 || arrays[0].type().is(PyArray::type()))) { + throw nb::type_error( + absl::StrCat( + "Unsupported type for elements in `arrays`: ", + nb::cast(nb::str(arrays[0].type()))) + .c_str()); + } + auto py_arrays = nb::cast>(arrays); + PyArray::PyInit(self, std::move(aval), std::move(sharding), py_arrays, + committed, skip_checks); + }, + nb::is_method(), nb::arg("aval"), nb::arg("sharding"), nb::arg("arrays"), + nb::arg("committed"), nb::arg("_skip_checks") = false); + type.attr("delete") = nb::cpp_function( + [](PyArray& self) { xla::ThrowIfError(self.Delete()); }, nb::is_method()); + type.attr("_sharding") = nb_property_readonly(&PyArray::sharding); + type.attr("aval") = nb_property(&PyArray::aval, &PyArray::set_aval); + type.attr("_arrays") = + nb_property(&PyArray::arrays, [](PyArray& self, nb::object obj) { + xla::ThrowIfError(self.set_arrays(obj)); + }); + type.attr("_fully_replicated_shard") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.FullyReplicatedShard()); + }, + nb::is_method()); + 
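+  // Seen from Python, these registrations become methods and properties of
+  // jaxlib's ArrayImpl, e.g. (illustrative):
+  //
+  //   x = jax.numpy.arange(4)   # x is an ArrayImpl
+  //   x.block_until_ready()
+  //   x.is_deleted()            # False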
type.attr("_npy_value") = + nb_property(&PyArray::npy_value, &PyArray::set_npy_value); + type.attr("_committed") = nb_property_readonly(&PyArray::committed); + type.attr("unsafe_buffer_pointer") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.UnsafeBufferPointer()); + }, + nb::is_method()); + type.attr("__cuda_array_interface__") = nb_property_readonly( + [](PyArray self) { return self.CudaArrayInterface(); }); + type.attr("_pjrt_layout") = + nb_property_readonly(xla::ValueOrThrowWrapper(&PyArray::layout)); + type.attr("on_device_size_in_bytes") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes), + nb::is_method()); + type.attr("_single_device_array_to_np_array_did_copy") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArrayDidCopy), + nb::is_method()); + type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function( + [](PyArray& self) { + xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); + }, + nb::is_method()); + type.attr("block_until_ready") = nb::cpp_function( + [](PyArray self) -> nb::object { + xla::ThrowIfError(self.BlockUntilReady()); + return self; + }, + nb::is_method()); + type.attr("platform") = nb::cpp_function( + [](PyArray self) { + if (self.ifrt_array()->client()->platform_name() == "cuda" || + self.ifrt_array()->client()->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return self.ifrt_array()->client()->platform_name(); + } + }, + nb::is_method()); + type.attr("is_ready") = nb::cpp_function( + [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, + nb::is_method()); + type.attr("is_deleted") = + nb::cpp_function(&PyArray::IsDeleted, nb::is_method()); + type.attr("traceback") = nb_property_readonly(&PyArray::traceback); + type.attr("clone") = nb::cpp_function(&PyArray::Clone, nb::is_method()); + type.attr("__module__") = m.attr("__name__"); + + m.attr("batched_copy_array_to_devices_with_sharding") = nb::cpp_function( + [](absl::Span arrays, + absl::Span> dst_device_lists, + absl::Span shardings, + absl::Span array_copy_semantics) { + if (arrays.empty()) { + return std::vector(); + } + auto* client = arrays[0].ifrt_array()->client(); + std::vector device_lists; + device_lists.reserve(dst_device_lists.size()); + for (const auto& dst_devices : dst_device_lists) { + absl::InlinedVector devices; + devices.reserve(dst_devices.size()); + for (auto& d : dst_devices) { + devices.push_back(d->device()); + } + device_lists.push_back(client->MakeDeviceList(devices)); + } + return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding( + arrays, device_lists, shardings, array_copy_semantics)); + }); + m.attr("array_result_handler") = nb::cpp_function( + [](nb::object aval, nb::object sharding, bool committed, + bool skip_checks) -> nb_class_ptr { + return make_nb_class( + std::move(aval), std::move(sharding), committed, skip_checks); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("committed"), + nb::arg("_skip_checks") = false); + + nb::class_(m, "ResultHandler") + .def("__call__", [](const PyArrayResultHandler& self, + PyArray arg) { return self.Call(arg); }) + .def("__call__", + [](const PyArrayResultHandler& self, + std::vector py_arrays) { return self.Call(py_arrays); }); + + return absl::OkStatus(); +} + +} // namespace xla diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h new file mode 100644 index 000000000000..f914639e383f --- /dev/null +++ b/jaxlib/xla/py_array.h @@ -0,0 +1,360 @@ +/* Copyright 2022 The JAX 
Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_PY_ARRAY_H_
+#define JAXLIB_XLA_PY_ARRAY_H_
+
+#include <Python.h>
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// placeholder for index annotation headers
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/types/span.h"
+#include "llvm/Support/Casting.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/xla/py_client.h"
+#include "xla/pjrt/exceptions.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_future.h"
+#include "xla/pjrt/pjrt_layout.h"
+#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/ifrt/future.h"
+#include "xla/python/nb_class_ptr.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/python/pjrt_ifrt/pjrt_array.h"
+#include "xla/python/traceback.h"
+#include "xla/shape.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/util.h"
+
+namespace xla {
+
+// Private to PyArray, but you cannot forward declare member classes.
+// Not thread safe; assumes the GIL is held.
+class PyHostValue {
+ public:
+  PyHostValue();
+  ~PyHostValue();
+
+  PyHostValue(const PyHostValue&) = delete;
+  PyHostValue(PyHostValue&&) = delete;
+  PyHostValue& operator=(const PyHostValue&) = delete;
+  PyHostValue& operator=(PyHostValue&&) = delete;
+
+  absl::Status CopyToHostAsync(std::optional<Shape>& dynamic_shape_holder,
+                               ifrt::Array* ifrt_array);
+
+  absl::StatusOr<std::pair<nb_numpy_ndarray, bool>> AsNumPyArray(
+      std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);
+
+ private:
+  absl::Status CopyStringArrayToHostAsync(
+      std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);
+
+  absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array);
+
+  ifrt::Future<> ready_;
+  nb_numpy_ndarray value_;
+
+  // Optional field, only used for arrays of type kString. This vector of
+  // cords serves as the input buffer for the CopyToHostBuffer call and holds
+  // the contents until they are lazily converted to a numpy array when the
+  // user calls `AsNumPyArray`.
+  std::shared_ptr<std::vector<absl::Cord>> string_array_contents_;
+};
+
+// Private to PyArray, but you cannot forward declare member classes.
+struct PyArray_Storage { + PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + bool committed, nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status); + + ~PyArray_Storage(); + nanobind::handle AsHandle(); + + nanobind::object aval; + bool weak_type = false; + nb_dtype dtype; + std::vector shape; + + nanobind::object sharding; + nanobind::object npy_value = nanobind::none(); + bool committed = false; + + nb_class_ptr py_client; + std::optional traceback; + tsl::RCReference ifrt_array; + nanobind::object fully_replicated_array = nanobind::none(); + + // optional field, used only in python + std::vector py_arrays; + PyHostValue host_value; // Protected by the GIL. + std::optional dynamic_shape = std::nullopt; + // Only set if this Array was generated by a computation that has effects. + // This is the result status of the XLA computation that generated this + // array. + xla::PjRtFuture<> result_status; + + // Doubly-linked list of all PyArrays known to the client. Protected by the + // GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be + // duplicate PjRtBuffers in this list. + PyArray_Storage* next; + PyArray_Storage* prev; + + uint8_t thread_id_bucket; +}; + +// The C++ implementation of jax.Array. A few key methods and data members are +// implemented in C++ for performance, while most of the functionalities are +// still implemented in python. +class PyArray : public nanobind::object { + public: + NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); + PyArray() = default; + + // "__init__" methods. Only used in python + static void PyInit(PyArray self, nanobind::object aval, + nanobind::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks); + + // Only used in C++. `skip_checks` should only be set for Arrays created by + // jax that cannot possibly have consistency issues (e.g. `sharding` devices + // different than `ifrt_array` devices). Arrays created by users should be + // checked. 
+ PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, bool committed, + bool skip_checks, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nanobind::object sharding, + bool weak_type, bool committed, bool skip_checks); + + static absl::Status RegisterTypes(nanobind::module_& m); + + static PyArray borrow(PyObject* ptr) { + return nanobind::borrow(ptr); + } + + using Storage = PyArray_Storage; + + const nanobind::object& aval() const { return GetStorage().aval; } + void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } + + bool weak_type() const { return GetStorage().weak_type; } + + const nb_dtype& dtype() const { return GetStorage().dtype; } + absl::Span shape() const { return GetStorage().shape; } + + const nanobind::object& sharding() const { return GetStorage().sharding; } + + absl::StatusOr> layout() { + return ifrt_array()->layout(); + } + + bool committed() const { return GetStorage().committed; } + + const nanobind::object& npy_value() const { return GetStorage().npy_value; } + void set_npy_value(nanobind::object v) { + GetStorage().npy_value = std::move(v); + } + + const nb_class_ptr& py_client() const { + return GetStorage().py_client; + } + + const std::optional& traceback() const { + return GetStorage().traceback; + } + + // Returns xla::InvalidArgument if the buffer has been deleted. + // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. + absl::StatusOr IsReady() { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr->IsDeleted()) { + return InvalidArgument("Array has been deleted."); + } + return ifrt_array_ptr->GetReadyFuture().IsReady(); + } + + const xla::PjRtFuture<>& result_status() const { + return GetStorage().result_status; + } + + ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); } + + // Short-term escape hatch to get PjRtBuffers from PyArray. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + absl::Span> pjrt_buffers() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return {}; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers(); + } + + int num_addressable_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + // TODO(hyeontaek): Add num_addressable_shards to ifrt. 
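+    // Falling back to num_shards() counts all shards, so this
+    // over-approximates whenever some devices are non-addressable.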
+ return num_shards(); + } + return arr->pjrt_buffers().size(); + } + + std::vector& py_arrays() { return GetStorage().py_arrays; } + const std::vector& py_arrays() const { + return GetStorage().py_arrays; + } + const std::vector& py_arrays_cached(); + + nanobind::object arrays(); + absl::Status set_arrays(nanobind::object obj); + absl::StatusOr FullyReplicatedShard(); + + int num_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + return ifrt_array_ptr->sharding().devices()->size(); + } + + static nanobind::handle type() { + DCHECK(type_); + return nanobind::handle(type_); + } + + static bool IsPyArray(nanobind::handle arg) { + return arg.type().is(PyArray::type()); + } + + absl::Status BlockUntilReady() const; + + absl::Status BlockUntilResultStatusIsReady(); + + absl::StatusOr GetOnDeviceSizeInBytes(); + absl::StatusOr> + SingleDeviceArrayToNumpyArrayDidCopy(); + absl::StatusOr SingleDeviceArrayToNumpyArray(); + absl::Status CopySingleDeviceArrayToHostAsync(); + nanobind::dict CudaArrayInterface(); + absl::StatusOr UnsafeBufferPointer(); + + absl::Status Delete(); + + bool IsDeleted() const; + + PyArray Clone() const; + + static absl::StatusOr> BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics); + + static absl::StatusOr BatchedDevicePut( + nanobind::object aval, nanobind::object sharding, + std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64); + + static absl::StatusOr ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics); + + static absl::Status BatchedBlockUntilReady( + std::vector objs); + + private: + absl::StatusOr AssertUnsharded(absl::string_view api); + + nanobind::object CheckAndRearrange(absl::Span py_arrays, + nanobind::object sharding, + nanobind::object aval); + + void SetIfrtArray(tsl::RCReference ifrt_array); + + Storage& GetStorage(); + const Storage& GetStorage() const; + + inline static PyObject* type_ = nullptr; +}; + +class PyArrayResultHandler { + public: + PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, + bool committed, bool skip_checks); + + PyArray Call(absl::Span py_arrays) const; + PyArray Call(PyArray py_array) const; + + PyArray Call(nb_class_ptr py_client, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const; + + private: + nanobind::object aval_; + nanobind::object sharding_; + bool weak_type_; + bool committed_; + bool skip_checks_; + + nb_dtype dtype_; + std::vector shape_; +}; + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nanobind::dict& cai, nb_class_ptr cuda_client, + std::optional device_id); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_ARRAY_H_ diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc new file mode 100644 index 000000000000..434077b0824f --- /dev/null +++ b/jaxlib/xla/py_client.cc @@ -0,0 +1,851 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_client.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/py_values.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/callback.h" +#include "xla/python/guard_lib.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/py_host_callback.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/traceback.h" +#include "xla/python/types.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = nanobind; + +/*static*/ nb_class_ptr PyClient::Make( + std::shared_ptr ifrt_client) { + auto client = make_nb_class(std::move(ifrt_client)); + Initialize(client); + return client; +} + +PyClient::PyClient(std::shared_ptr 
ifrt_client) + : ifrt_client_(std::move(ifrt_client)), + client_attributes_(ifrt_client_->Attributes()) { + CHECK(ifrt_client_); +} + +/* static */ void PyClient::Initialize(nb_class_ptr client) { + for (ifrt::Device* device : client->ifrt_client()->devices()) { + client->devices_[device] = make_nb_class(client, device); + + for (ifrt::Memory* memory : device->Memories()) { + auto& py_memory = client->memory_spaces_[memory]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class(client, memory); + } + } + } +} + +PyClient::~PyClient() { + nb::gil_scoped_release gil; + ifrt_client_ = nullptr; +} + +nb_class_ptr PyClient::GetPyDevice(ifrt::Device* device) { + auto& py_device = devices_[device]; + if (py_device.get() == nullptr) { + py_device = make_nb_class( + nb::borrow>(nb::find(this)), device); + } + return py_device; +} + +nb_class_ptr PyClient::GetPyMemorySpace( + ifrt::Memory* memory_space) { + auto& py_memory = memory_spaces_[memory_space]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class( + nb::borrow>(nb::find(this)), memory_space); + } + return py_memory; +} + +std::vector> PyClient::Devices() { + std::vector> devices; + auto span = ifrt_client_->devices(); + devices.reserve(span.size()); + for (ifrt::Device* device : span) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::LocalDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_client_->addressable_devices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +absl::StatusOr> PyClient::DeviceFromLocalHardwareId( + int local_hardware_id) { + TF_ASSIGN_OR_RETURN(ifrt::Device * device, + ifrt_client_->LookupAddressableDevice(local_hardware_id)); + return GetPyDevice(device); +} + +nb::list PyClient::LiveExecutables() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(executables_mutex_); + nb::list executables; + for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { + if (!exec->is_deleted()) { + executables.append(nb::find(exec)); + } + } + return executables; +} + +absl::Status PyClient::Defragment() { + CHECK(PyGILState_Check()); + if (!llvm::isa(ifrt_client_.get())) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + ifrt::PlatformId platform_id = ifrt_client_->platform_id(); + bool is_gpu_client = platform_id == CudaId() || platform_id == RocmId() || + platform_id == SyclId(); + + if (!is_gpu_client) { + return pjrt_client()->Defragment(); + } + + struct TmpBuffer { + // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays + // can reference the same PjRtBuffer. + std::vector*> pjrt_buffer_ptrs; + // TODO(skyewm): maybe use py_buffer's HostValue + std::shared_ptr host_copy; + }; + + // Synchronously copy all buffers to host + absl::flat_hash_map pjrt_buf_to_tmp_buffer; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + // TODO(hyeontaek): Support non-PjRt Arrays. + // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that + // std::shared_ptr does not need to be updated in-place. 
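+    // Overall scheme: snapshot each live PjRtBuffer to host (below), release
+    // the device memory, then upload fresh copies and patch every shared_ptr
+    // that referenced an old buffer.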
+    if (array.ifrt_array() == nullptr) {
+      continue;
+    }
+    auto* arr =
+        llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(array.ifrt_array());
+    if (arr == nullptr) {
+      throw XlaRuntimeError(
+          "This operation is implemented for a PjRt-compatible backend "
+          "only.");
+    }
+    TF_ASSIGN_OR_RETURN(absl::Span<std::shared_ptr<PjRtBuffer>> pjrt_buffers,
+                        arr->mutable_pjrt_buffers());
+    for (int i = 0; i < pjrt_buffers.size(); ++i) {
+      std::shared_ptr<PjRtBuffer>& pjrt_buf_ptr = pjrt_buffers[i];
+      if (pjrt_buf_ptr->IsDeleted()) {
+        continue;
+      }
+      auto [iter, inserted] =
+          pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()});
+      if (inserted) {
+        TF_ASSIGN_OR_RETURN(iter->second.host_copy,
+                            pjrt_buf_ptr->ToLiteralSync());
+      }
+      iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr);
+    }
+  }
+
+  // All buffers successfully copied to host, delete on-device copies.
+  //
+  // Use blocking delete operation to ensure all memory is actually cleared
+  // before we start rewriting buffers.
+  //
+  // Die instead of returning a bad status because the program presumably
+  // can't continue if we fail to reconstitute device buffers.
+  for (const auto& it : pjrt_buf_to_tmp_buffer) {
+    PjRtBuffer* pjrt_buf = it.first;
+    TF_CHECK_OK(pjrt_buf
+                    ->ReleaseDeviceMemoryOwnership(
+                        /*wait_for_operations_to_complete=*/true)
+                    .status());
+  }
+
+  // Copy host copies back to device and update PyArrays in-place.
+  for (auto& it : pjrt_buf_to_tmp_buffer) {
+    PjRtBuffer* pjrt_buf = it.first;
+    TmpBuffer& tmp_buffer = it.second;
+    std::unique_ptr<PjRtBuffer> new_copy =
+        pjrt_client()
+            ->BufferFromHostLiteral(*tmp_buffer.host_copy,
+                                    pjrt_buf->memory_space())
+            .value();
+    TF_CHECK_OK(new_copy->GetReadyFuture().Await());
+
+    std::shared_ptr<PjRtBuffer> new_pjrt_buf_ptr(new_copy.release());
+    for (std::shared_ptr<PjRtBuffer>* pjrt_buffer_ptr :
+         tmp_buffer.pjrt_buffer_ptrs) {
+      *pjrt_buffer_ptr = new_pjrt_buf_ptr;
+    }
+  }
+
+  // TODO(skyewm): delete executables?
+  return absl::OkStatus();
+}
+
+/* static */ absl::StatusOr<nb::object> PyClient::BufferFromPyval(
+    nb_class_ptr<PyClient> client, nb::handle argument, ifrt::Device* device,
+    bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics) {
+  if (device == nullptr) {
+    TF_RET_CHECK(!client->ifrt_client_->addressable_devices().empty());
+    device = client->ifrt_client_->addressable_devices().front();
+  }
+  CHECK(device != nullptr);
+
+  auto transfer_guard_formatter = [&argument, dst_device = device] {
+    auto type = nb::cast<std::string>(nb::str(argument.type()));
+    // Catch exceptions because shape and dtype properties convertible to str
+    // are not guaranteed to be present in an arbitrary argument.
+    std::string shape;
+    std::string dtype;
+    try {
+      shape =
+          nb::cast<std::string>(nb::str(nb::object(argument.attr("shape"))));
+    } catch (const std::exception& e) {
+      shape = "<unknown>";
+    }
+    try {
+      dtype =
+          nb::cast<std::string>(nb::str(nb::object(argument.attr("dtype"))));
+    } catch (const std::exception& e) {
+      dtype = "<unknown>";
+    }
+    return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype,
+                        ", dst_device=", dst_device->DebugString());
+  };
+  TF_RETURN_IF_ERROR(
+      jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter));
+
+  TF_ASSIGN_OR_RETURN(ifrt::Device * found_device,
+                      client->ifrt_client_->LookupDevice(device->Id()));
+  if (found_device != device) {
+    return InvalidArgument(
+        "Cannot copy value to device '%s' with '%s' backend",
+        device->DebugString(), client->ifrt_client_->platform_name());
+  }
+  GlobalPyRefManager()->CollectGarbage();
+
+  DevicePutOptions options;
+  options.squash_64bit_types = false;
+  options.allow_zero_copy =
+      (!force_copy && (host_buffer_semantics ==
+                       ifrt::Client::HostBufferSemantics::kImmutableZeroCopy));
+  TF_ASSIGN_OR_RETURN(auto put_fn,
+                      DevicePut(argument, client->ifrt_client_.get(), device,
+                                options, ifrt::MemoryKind()));
+  TF_ASSIGN_OR_RETURN(auto put, [&]() {
+    // Must release the GIL before calling IFRT because backends may
+    // decide to block/sleep for device buffer allocation.
+    nb::gil_scoped_release gil_release;
+    return std::move(put_fn)();
+  }());
+
+  if (put.ifrt_array) {
+    auto traceback = Traceback::Get();
+    return PyArray::MakeFromSingleDeviceArray(
+        std::move(client), std::move(traceback), std::move(put.ifrt_array),
+        /*weak_type=*/false,
+        /*committed=*/false);
+  } else {
+    return put.owning_pybuffer;
+  }
+}
+
+namespace {
+
+// Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host
+// callbacks.
+std::unique_ptr<ifrt::XlaCompileOptions> MakeIfrtCompileOptions(
+    CompileOptions options, std::vector<nb::capsule> host_callbacks) {
+  std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>
+      ifrt_loaded_host_callbacks;
+  ifrt_loaded_host_callbacks.reserve(host_callbacks.size());
+  // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were
+  // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or
+  // `PyClient::GetEmitPythonCallbackDescriptor()`.
+  for (auto& host_callback : host_callbacks) {
+    ifrt_loaded_host_callbacks.push_back(tsl::FormRef(
+        static_cast<ifrt::LoadedHostCallback*>(host_callback.data())));
+  }
+  return std::make_unique<ifrt::XlaCompileOptions>(
+      std::move(options), std::move(ifrt_loaded_host_callbacks));
+}
+
+// Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and
+// optional host callbacks.
+std::unique_ptr<ifrt::XlaDeserializeExecutableOptions>
+MakeIfrtDeserializeExecutableOptions(std::optional<CompileOptions> options,
+                                     std::vector<nb::capsule> host_callbacks) {
+  std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>
+      ifrt_loaded_host_callbacks;
+  ifrt_loaded_host_callbacks.reserve(host_callbacks.size());
+  // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were
+  // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or
+  // `PyClient::GetEmitPythonCallbackDescriptor()`.
+ for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +} // namespace + +/* static */ absl::StatusOr> +PyClient::CompileIfrtProgram( + nb_class_ptr client, std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options) { + auto* pjrt_compatible_client = + llvm::dyn_cast_or_null( + client->ifrt_client_.get()); + auto* ifrt_xla_options = + llvm::dyn_cast_or_null(ifrt_options.get()); + // For XLA programs, pass allocated device memory size to compile options for + // pjrt compatible backends. + if (pjrt_compatible_client != nullptr && ifrt_xla_options != nullptr) { + xla::CompileOptions& options = ifrt_xla_options->compile_options; + auto addressable_devices = + pjrt_compatible_client->pjrt_client()->addressable_devices(); + if (!addressable_devices.empty()) { + int device_ordinal = options.executable_build_options.device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } + CHECK_LT(device_ordinal, addressable_devices.size()); + auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); + if (stats.ok() && stats->bytes_limit) { + options.executable_build_options.set_device_memory_size( + *stats->bytes_limit); + } + } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } + } + + std::unique_ptr ifrt_loaded_executable; + std::optional fingerprint; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->Compile( + std::move(ifrt_program), std::move(ifrt_options))); + TF_RETURN_IF_ERROR(ifrt_loaded_executable->GetReadyFuture().Await()); + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + } + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + return CompileIfrtProgram( + client, std::make_unique(module.get()), + MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks))); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. 
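+  // (Without this round-trip export, the Shardy sharding annotations would
+  // not survive the conversion to HLO later in the pipeline.)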
+ TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. + for (auto& host_callback : host_callbacks) { + auto callback = tsl::MakeRef( + client->ifrt_client(), std::move(host_callback)); + ifrt_loaded_host_callbacks.push_back(callback); + } + auto compile_options = std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); + return CompileIfrtProgram( + client, std::make_unique(module.get()), + std::move(compile_options)); +} + +absl::StatusOr PyClient::SerializeExecutable( + const PyLoadedExecutable& executable) const { + TF_ASSIGN_OR_RETURN(auto serialized, + executable.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); +} + +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { + std::unique_ptr ifrt_loaded_executable; + std::optional fingerprint; + auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( + std::move(options), std::move(host_callbacks)); + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + absl::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); + } + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +namespace { + +struct HeapProfileKey { + Traceback* traceback; + int64_t size; + xla::PjRtDevice* device; + bool operator==(const HeapProfileKey& other) const; +}; + +bool HeapProfileKey::operator==(const HeapProfileKey& other) const { + if (size != other.size || device != other.device) { + return false; + } + if ((traceback == nullptr) != (other.traceback == nullptr)) { + return false; + } + if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) { + return false; + } + return true; +} + +template +H AbslHashValue(H h, const HeapProfileKey& key) { + if (key.traceback) { + h = H::combine(std::move(h), key.traceback->raw_frames()); + } + h = H::combine(std::move(h), key.size, key.device); + return h; +} + +} // namespace + +absl::StatusOr PyClient::HeapProfile() { + CHECK(PyGILState_Check()); + absl::flat_hash_set buffer_set; + absl::flat_hash_map entries; + + auto add_buffer_to_profile = [&](PjRtBuffer* buffer, Traceback* traceback) { + // We only wish to count each PjRtBuffer once, even though they may be + // shared by multiple PyArrays. + if (!buffer->IsDeleted() && buffer_set.insert(buffer).second) { + TF_ASSIGN_OR_RETURN(size_t size, buffer->GetOnDeviceSizeInBytes()); + HeapProfileKey key{traceback, static_cast(size), + buffer->device()}; + ++entries[key]; + } + return absl::OkStatus(); + }; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + // TODO(hyeontaek): Support non-PjRt Arrays. 
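+    // Until then, a single non-PjRt-compatible array aborts the whole profile
+    // with an XlaRuntimeError rather than being skipped. Callers typically
+    // just persist the serialized pprof bytes; a rough consumer sketch
+    // (illustrative, not part of this patch):
+    //
+    //   absl::StatusOr<nb::bytes> profile = client->HeapProfile();
+    //   if (profile.ok()) {
+    //     std::ofstream("heap.pprof", std::ios::binary)
+    //         .write(profile->c_str(), profile->size());
+    //   }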
+ if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + for (const auto& buffer : arr->pjrt_buffers()) { + TF_RETURN_IF_ERROR(add_buffer_to_profile( + buffer.get(), + array.traceback() ? array.traceback()->get() : nullptr)); + } + } + + for (PyLoadedExecutable* executable = executables_; executable; + executable = executable->next_) { + if (!executable->is_deleted()) { + HeapProfileKey key{ + executable->traceback() ? executable->traceback()->get() : nullptr, + executable->SizeOfGeneratedCodeInBytes(), nullptr}; + ++entries[key]; + } + } + + PprofProfileBuilder builder; + auto* allocations = builder.profile().add_sample_type(); + allocations->set_type(builder.StringId("allocations")); + allocations->set_unit(builder.StringId("count")); + auto* space = builder.profile().add_sample_type(); + space->set_type(builder.StringId("space")); + space->set_unit(builder.StringId("bytes")); + + const int kind_string_id = builder.StringId("kind"); + const int buffer_string_id = builder.StringId("buffer"); + const int executable_string_id = builder.StringId("executable"); + const int device_string_id = builder.StringId("device"); + for (const auto& entry : entries) { + auto* sample = builder.profile().add_sample(); + if (entry.first.traceback) { + for (const auto& frame : entry.first.traceback->raw_frames()) { + sample->add_location_id(builder.LocationId(frame.first, frame.second)); + } + } + sample->add_value(entry.second); + sample->add_value(entry.first.size * entry.second); + + auto* kind_label = sample->add_label(); + kind_label->set_key(kind_string_id); + if (entry.first.device) { + kind_label->set_str(buffer_string_id); + auto* device_label = sample->add_label(); + device_label->set_key(device_string_id); + std::string device_label_str(entry.first.device->DebugString()); + device_label->set_str(builder.StringId(device_label_str)); + } else { + kind_label->set_str(executable_string_id); + } + } + std::string serialized = builder.profile().SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); +} + +absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyHostSendAndRecvLoadedHostCallback::Create( + ifrt_client(), std::move(callable), operand_shapes, result_shapes, + send_channel_ids, recv_channel_ids, std::move(serializer))); + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return callback_capsule; +} + +// TODO(b/394595987): Remove this API method once we remove the call from +// mlir.py's get_emit_python_callback. 
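+// Until it is removed, the intended usage pattern is roughly as follows
+// (sketch only; the real lowering lives in mlir.py's
+// get_emit_python_callback):
+//
+//   auto [descriptor, keepalive] = xla::ValueOrThrow(
+//       client->GetEmitPythonCallbackDescriptor(callable, operand_shapes,
+//                                               result_shapes));
+//   // Pass `descriptor` as the first operand of an
+//   // "xla_python_cpu_callback" CustomCall, and attach `keepalive` to the
+//   // compiled executable so the Python callable outlives tracing.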
+absl::StatusOr> +PyClient::GetEmitPythonCallbackDescriptor( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyCpuLoadedHostCallback::Create(ifrt_client(), std::move(callable), + operand_shapes, result_shapes)); + const uint64_t descriptor = loaded_host_callback->descriptor(); + + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return std::make_pair(descriptor, nb::object(std::move(callback_capsule))); +} + +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback", + &XlaPythonCpuCallback); + +/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyClient* c = nb::inst_ptr(self); + for (const auto& [ifrt_device, py_device] : c->devices_) { + Py_VISIT(py_device.ptr()); + } + for (const auto& [ifrt_memory, py_memory] : c->memory_spaces_) { + Py_VISIT(py_memory.ptr()); + } + return 0; +} + +/* static */ int PyClient::tp_clear(PyObject* self) { + PyClient* c = nb::inst_ptr(self); + absl::flat_hash_map> devices; + std::swap(devices, c->devices_); + absl::flat_hash_map> memory_spaces; + std::swap(memory_spaces, c->memory_spaces_); + return 0; +} + +PyType_Slot PyClient::slots_[] = { + {Py_tp_traverse, (void*)PyClient::tp_traverse}, + {Py_tp_clear, (void*)PyClient::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyClient::RegisterPythonTypes(nb::module_& m) { + nb::enum_(m, "HostBufferSemantics") + .value("IMMUTABLE_ONLY_DURING_CALL", + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) + .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + + nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), + nb::type_slots(PyClient::slots_)); + py_local_client.def_prop_ro("platform", &PyClient::platform_name) + .def_prop_ro("_raw_platform", &PyClient::raw_platform_name) + .def_prop_ro("platform_version", &PyClient::platform_version) + .def_prop_ro("runtime_type", &PyClient::runtime_type) + .def("device_count", &PyClient::device_count) + .def("local_device_count", &PyClient::addressable_device_count) + .def("devices", &PyClient::Devices) + .def("local_devices", &PyClient::LocalDevices) + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + .def("_get_all_devices", &PyClient::GetAllDevices) + .def("device_from_local_hardware_id", + xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) + .def("live_executables", &PyClient::LiveExecutables) + .def("live_arrays", &PyClient::LiveArrays) + .def("live_buffers", &PyClient::LiveArrays) + .def("process_index", &PyClient::process_index) + .def("host_id", &PyClient::process_index) + .def("task_id", &PyClient::process_index) + .def( + "buffer_from_pyval", + [](nb_class_ptr client, nb::handle argument, + PyDevice* device, bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) { + return ValueOrThrow( + PyClient::BufferFromPyval(std::move(client), argument, + device ? 
device->device() : nullptr, + force_copy, host_buffer_semantics)); + }, + nb::arg("argument"), nb::arg("device").none() = nullptr, + nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def("compile_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileIfrtProgram)) + .def("serialize_executable", + xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) + // TODO(zhangqiaorjc): Experimental. 
+ .def("defragment", + [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) + .def("get_emit_python_callback_descriptor", + xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes").none() = nb::none()) + .def("make_python_callback_from_host_send_and_recv", + xla::ValueOrThrowWrapper( + &PyClient::MakePythonCallbackUsingHostSendAndRecv), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes"), nb::arg("send_channel_ids"), + nb::arg("recv_channel_ids"), + nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient& self, nb_dtype dtype, nb::sequence shard_shape, + nb_class_ptr device) + -> std::shared_ptr { + ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); + std::vector dims = SequenceToVector(shard_shape); + return xla::ValueOrThrow(self.ifrt_client()->GetDefaultLayout( + ifrt_type, dims, device->device(), xla::ifrt::MemoryKind())); + }, + nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) + .def("__getattr__", + [](PyClient& client, absl::string_view name) -> nb::object { + const auto& attrs = client.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); +} + +} // namespace xla diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h new file mode 100644 index 000000000000..9b9d43d90228 --- /dev/null +++ b/jaxlib/xla/py_client.h @@ -0,0 +1,270 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_CLIENT_H_ +#define JAXLIB_XLA_PY_CLIENT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/shape.h" + +namespace xla { + +class PyClient; +class PyLoadedExecutable; +class PyArray; +class PyDevice; +class PyMemorySpace; +struct PyArray_Storage; + +// Python wrapper around PjRtClient. +// We use a wrapper class to add Python-specific functionality. +class PyClient { + public: + static nb_class_ptr Make(std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. 
+  explicit PyClient(std::shared_ptr<ifrt::Client> ifrt_client);
+  virtual ~PyClient();
+
+  ifrt::Client* ifrt_client() const { return ifrt_client_.get(); }
+  const std::shared_ptr<ifrt::Client>& shared_ptr_ifrt_client() const {
+    return ifrt_client_;
+  }
+
+  // Short-term escape hatch to get PjRtClient from PyClient.
+  // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
+  xla::PjRtClient* pjrt_client() const {
+    auto* pjrt_client =
+        llvm::dyn_cast_or_null<ifrt::PjRtClient>(ifrt_client_.get());
+    if (pjrt_client == nullptr) {
+      throw XlaRuntimeError(
+          "This operation is implemented for a PjRt-compatible backend only.");
+    }
+    return pjrt_client->pjrt_client();
+  }
+  std::shared_ptr<PjRtClient> shared_ptr_pjrt_client() {
+    auto* pjrt_client =
+        llvm::dyn_cast_or_null<ifrt::PjRtClient>(ifrt_client_.get());
+    if (pjrt_client == nullptr) {
+      throw XlaRuntimeError(
+          "This operation is implemented for a PjRt-compatible backend only.");
+    }
+    return pjrt_client->shared_ptr_pjrt_client();
+  }
+
+  // Legacy aliases.
+  std::shared_ptr<PjRtClient> shared_pjrt_client() {
+    return shared_ptr_pjrt_client();
+  }
+
+  absl::string_view platform_name() const {
+    // TODO(phawkins): this is a temporary backwards compatibility shim. We
+    // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
+    // we haven't yet updated JAX clients that expect "gpu". Migrate users and
+    // remove this code.
+    if (ifrt_client_->platform_name() == "cuda" ||
+        ifrt_client_->platform_name() == "rocm") {
+      return "gpu";
+    } else {
+      return ifrt_client_->platform_name();
+    }
+  }
+  absl::string_view raw_platform_name() const {
+    // TODO(parkers): Once platform_name() is the same, remove this.
+    return ifrt_client_->platform_name();
+  }
+  absl::string_view platform_version() const {
+    return ifrt_client_->platform_version();
+  }
+  absl::string_view runtime_type() const {
+    return ifrt_client_->runtime_type();
+  }
+
+  // Returns implementation-specific attributes about this client, e.g. the
+  // PJRT C API version if applicable.
+  const xla::ifrt::AttributeMap& Attributes() const {
+    return client_attributes_;
+  }
+
+  int addressable_device_count() const {
+    return ifrt_client_->addressable_device_count();
+  }
+  int device_count() const { return ifrt_client_->device_count(); }
+  int process_index() const { return ifrt_client_->process_index(); }
+
+  std::vector<nb_class_ptr<PyDevice>> Devices();
+  std::vector<nb_class_ptr<PyDevice>> LocalDevices();
+  // Returns all devices in the client. Private API; only use this method for
+  // implementing backend._get_all_devices().
+  // TODO(hyeontaek): Remove this method once we have a unified API for
+  // enumerating devices with different criteria.
+  std::vector<nb_class_ptr<PyDevice>> GetAllDevices();
+  absl::StatusOr<nb_class_ptr<PyDevice>> DeviceFromLocalHardwareId(
+      int local_hardware_id);
+
+  // Returns the PyDevice associated with the given ifrt::Device.
+  nb_class_ptr<PyDevice> GetPyDevice(ifrt::Device* device);
+
+  // Returns the PyMemorySpace associated with the given ifrt::Memory.
+  nb_class_ptr<PyMemorySpace> GetPyMemorySpace(ifrt::Memory* memory_space);
+
+  // Returns a vector of live PyArray objects. PyArray objects may share
+  // PjRtBuffers, so there may be duplicates of the same underlying device
+  // buffer.
+  std::vector<PyArray> LiveBuffersOnDevice(ifrt::Device* device);
+
+  nanobind::list LiveExecutables();
+
+  // TODO(zhangqiaorjc): Remove when we have transparent defragmentation.
+  absl::Status Defragment();
+
+  static absl::StatusOr<PyArray> BufferFromPyval(
+      nb_class_ptr<PyClient> client, nanobind::handle argument,
+      ifrt::Device* device, bool force_copy,
+      ifrt::Client::HostBufferSemantics host_buffer_semantics);
+
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> CompileIfrtProgram(
+      nb_class_ptr<PyClient> client,
+      std::unique_ptr<ifrt::Program> ifrt_program,
+      std::unique_ptr<ifrt::CompileOptions> ifrt_options);
+
+  // The two overloads below differ in how host callbacks are passed: the
+  // first takes capsules wrapping already-loaded callbacks, the second takes
+  // plain Python callables.
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> Compile(
+      nb_class_ptr<PyClient> client, std::string mlir_module,
+      CompileOptions options, std::vector<nanobind::capsule> host_callbacks);
+
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> Compile(
+      nb_class_ptr<PyClient> client, std::string mlir_module,
+      CompileOptions options, std::vector<nanobind::callable> host_callbacks);
+
+  absl::StatusOr<nanobind::bytes> SerializeExecutable(
+      const PyLoadedExecutable& executable) const;
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>>
+  DeserializeExecutable(nb_class_ptr<PyClient> client,
+                        nanobind::bytes serialized,
+                        std::optional<CompileOptions> options,
+                        std::vector<nanobind::capsule> host_callbacks);
+
+  absl::StatusOr<nanobind::bytes> HeapProfile();
+
+  // `GetEmitPythonCallbackDescriptor` takes in an input Python callable that
+  // takes in arguments of shapes `operand_shapes` and returns values of shapes
+  // `result_shapes`. It returns a pair of a `uint64_t` descriptor and a Python
+  // object whose reference will keep the Python callback alive. The descriptor
+  // should be passed into a 'xla_python_cpu_callback' or
+  // 'xla_python_gpu_callback' CustomCall as its first argument. Typically the
+  // callback may be kept alive by attaching the keep-alive object to the
+  // executable built from this computation.
+  //
+  // The callable receives as arguments NumPy arrays for arguments with array
+  // types, and None for Token argument. The callable must return a tuple of
+  // either arrays or None values.
+  absl::StatusOr<std::pair<uint64_t, nanobind::object>>
+  GetEmitPythonCallbackDescriptor(nanobind::callable callable,
+                                  absl::Span<const Shape> operand_shapes,
+                                  absl::Span<const Shape> result_shapes);
+
+  // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable
+  // that takes in arguments of shapes `operand_shapes` and returns results of
+  // shapes `result_shapes`. The arguments correspond to Send ops in the HLO
+  // program through `send_channel_ids` and the results correspond to Recv ops
+  // through `recv_channel_ids`. It returns the host callback as an opaque
+  // object whose reference will keep the Python callback alive. The host
+  // callback can be passed to `PyClient::Compile` or
+  // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the
+  // XLA computation can trigger the execution of this host callback.
+  // `serializer` is a function that takes `callable` as an argument and
+  // returns a serialized callable as a string.
+  //
+  // The callable receives as arguments NumPy arrays for arguments with array
+  // types, and None for Token argument. The callable must return a tuple of
+  // either arrays or None values.
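+  //
+  // For instance, with one array operand, one array result, and matching
+  // send/recv channels, a conforming callable has the rough shape
+  // (illustrative pseudocode, not an API):
+  //
+  //   def callback(x):      # x: np.ndarray matching operand_shapes[0]
+  //       return (2 * x,)   # tuple matching result_shapes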
+ absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + std::vector LiveArrays() const; + + static void RegisterPythonTypes(nanobind::module_& m); + + protected: + static void Initialize(nb_class_ptr client); + + private: + friend class PyLoadedExecutable; + friend class PyArray; + friend struct PyArray_Storage; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + std::shared_ptr ifrt_client_; + xla::ifrt::AttributeMap client_attributes_; + // Pointers to intrusive doubly-linked lists of arrays and executables, used + // to iterate over all known objects when heap profiling. The list structure + // is protected by the GIL. + + nanobind::ft_mutex executables_mutex_; + // List guarded by executables_mutex_. + PyLoadedExecutable* executables_ = nullptr; + +#ifdef NB_FREE_THREADING + static constexpr size_t kNumArraysShards = 16; +#else + static constexpr size_t kNumArraysShards = 1; +#endif + struct ArraysShard { + mutable nanobind::ft_mutex mutex; + PyArray_Storage* arrays; + }; + std::array arrays_; + + absl::flat_hash_map> devices_; + absl::flat_hash_map> + memory_spaces_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_CLIENT_H_ diff --git a/jaxlib/xla/py_compile_only_client.cc b/jaxlib/xla/py_compile_only_client.cc new file mode 100644 index 000000000000..6319c70f91b0 --- /dev/null +++ b/jaxlib/xla/py_compile_only_client.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/py_compile_only_client.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/compile_only_ifrt/client.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +class CompileOnlyPyClient : public PyClient { + public: + using PyClient::PyClient; + + static nb_class_ptr Make( + std::shared_ptr topology) { + auto client = + nb::borrow>(make_nb_class( + std::make_unique(std::move(topology)))); + CompileOnlyPyClient::Initialize(client); + return client; + } + + absl::StatusOr> CompileUnloaded( + absl::string_view mlir_module, CompileOptions options, + std::vector host_callbacks) { + if (!host_callbacks.empty()) { + return Unimplemented( + "Compiling with host_callbacks not available with compile-only " + "client."); + } + nb::gil_scoped_release gil_release; + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. 
+ TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + auto* ifrt_client = + llvm::dyn_cast_or_null(this->ifrt_client()); + CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " + "CompileOnlyIfRtClient"; + auto xla_options = std::make_unique(options); + TF_ASSIGN_OR_RETURN(auto executable, + PjRtCompile(std::move(options), module.get(), + *ifrt_client->topology().description())); + TF_ASSIGN_OR_RETURN(auto ifrt_executable, + ifrt::PjRtExecutable::Create(std::move(executable))); + return std::shared_ptr(std::move(ifrt_executable)); + } + + private: + static void Initialize(nb_class_ptr client) { + PyClient::Initialize(client); + } +}; + +} // namespace + +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr topology) { + return CompileOnlyPyClient::Make(std::move(topology)); +} + +void RegisterCompileOnlyClient(nb::module_& m) { + nb::class_(m, "CompileOnlyPyClient") + .def( + "compile", + [](CompileOnlyPyClient& self, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(self.CompileUnloaded( + absl::string_view(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()); +} + +} // namespace xla diff --git a/jaxlib/xla/py_compile_only_client.h b/jaxlib/xla/py_compile_only_client.h new file mode 100644 index 000000000000..721830d6f52e --- /dev/null +++ b/jaxlib/xla/py_compile_only_client.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ +#define JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" + +namespace xla { + +// This is a workaround for AOT compilation until topologies and device +// descriptions are better integrated into jax's Python code. It returns a +// PyClient that will return errors for all non-AOT methods. It also exposes a +// different compile method that returns an unloaded executable (vs. PyClient +// usually returns a loaded executable). RegisterCompileOnlyClient() overloads +// the Python "compile" method to return the unloaded executable, and we rely on +// Python duck typing to treat the unloaded executable like a loaded executable +// (except it will raise errors if you try to run it, which is what we want for +// AOT environments). 
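+//
+// Rough AOT usage (illustrative sketch; assumes `topology` is an existing
+// std::shared_ptr<ifrt::PjRtTopology> built from a PjRtTopologyDescription):
+//
+//   auto client = MakeCompileOnlyClient(std::move(topology));
+//   // From Python, client.compile(mlir_module) now returns an unloaded
+//   // executable that can be serialized but raises if executed.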
+nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr); + +void RegisterCompileOnlyClient(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ diff --git a/jaxlib/xla/py_device.cc b/jaxlib/xla/py_device.cc new file mode 100644 index 000000000000..20c257bb7d1a --- /dev/null +++ b/jaxlib/xla/py_device.cc @@ -0,0 +1,350 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_device.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_memory_space.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyDevice::PyDevice(nb_class_ptr client, ifrt::Device* device) + : client_(std::move(client)), device_(device) {} + +int PyDevice::id() const { return device_->Id().value(); } + +int PyDevice::process_index() const { return device_->ProcessIndex(); } + +absl::string_view PyDevice::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyDevice::device_kind() const { return device_->Kind(); } + +std::optional PyDevice::local_hardware_id() const { + // TODO(phawkins): consider supporting this for non-PJRT devices. 
+ ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return std::nullopt; + } + int local_hardware_id = device->pjrt_device()->local_hardware_id().value(); + if (local_hardware_id == -1) { + return std::nullopt; + } + return local_hardware_id; +} + +absl::string_view PyDevice::Str() const { return device_->DebugString(); } + +absl::string_view PyDevice::Repr() const { return device_->ToString(); } + +absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferToInfeed is only supported for PjRt devices."); + } + return client->TransferToInfeed(device, literal); +} + +absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal; + { + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferFromOutfeed is only supported for PjRt devices."); + } + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + literal = std::make_shared(shape); + TF_RETURN_IF_ERROR(client->TransferFromOutfeed(device, literal.get())); + } + return LiteralToPython(std::move(literal)); +} + +absl::StatusOr> PyDevice::Memory( + absl::string_view kind) const { + ifrt::Memory* result_memory_space = nullptr; + for (auto* memory_space : device_->Memories()) { + if (memory_space->Kind().memory_kind() == kind) { + if (result_memory_space != nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Found more than one addressable memory for " + "kind %s which is not allowed. There can only " + "be one memory for each " + "kind. Device %s can address the following " + "memory kinds: %s", + kind, device_kind, memories); + } + result_memory_space = memory_space; + } + } + if (result_memory_space == nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Could not find memory addressable by device %s. Device %s " + "can address the following memory kinds: %s. 
" + "Got memory kind: %s", + device_kind, device_kind, memories, kind); + } + return client_->GetPyMemorySpace(result_memory_space); +} + +absl::StatusOr> PyDevice::DefaultMemory() const { + TF_ASSIGN_OR_RETURN(auto* memory_space, device_->DefaultMemory()); + return client_->GetPyMemorySpace(memory_space); +} + +nb::list PyDevice::AddressableMemories() const { + nb::list memory_spaces; + for (auto* memory_space : device_->Memories()) { + memory_spaces.append(client_->GetPyMemorySpace(memory_space)); + } + return memory_spaces; +} + +absl::StatusOr> PyDevice::MemoryStats() const { + GlobalPyRefManager()->CollectGarbage(); + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "MemoryStats is only supported for addressable PjRt devices."); + } + absl::StatusOr maybe_stats = + device->pjrt_device()->GetAllocatorStats(); + if (absl::IsUnimplemented(maybe_stats.status())) { + return std::nullopt; + } + // Raise error if any status other than Unimplemented is returned. + ThrowIfError(maybe_stats.status()); + + nb::dict result; + result["num_allocs"] = maybe_stats->num_allocs; + result["bytes_in_use"] = maybe_stats->bytes_in_use; + result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; + result["largest_alloc_size"] = maybe_stats->largest_alloc_size; + if (maybe_stats->bytes_limit) { + result["bytes_limit"] = *maybe_stats->bytes_limit; + } + result["bytes_reserved"] = maybe_stats->bytes_reserved; + result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; + if (maybe_stats->bytes_reservable_limit) { + result["bytes_reservable_limit"] = *maybe_stats->bytes_reservable_limit; + } + result["largest_free_block_bytes"] = maybe_stats->largest_free_block_bytes; + if (maybe_stats->pool_bytes) { + result["pool_bytes"] = *maybe_stats->pool_bytes; + } + if (maybe_stats->peak_pool_bytes) { + result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; + } + return result; +} + +absl::StatusOr PyDevice::GetStreamForExternalReadyEvents() + const { + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "GetStreamForExternalReadyEvents is only supported for addressable " + "PjRt devices."); + } + return device->pjrt_device()->GetStreamForExternalReadyEvents(); +} + +/* static */ int PyDevice::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyDevice* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyDevice::tp_clear(PyObject* self) { + PyDevice* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyDevice::slots_[] = { + {Py_tp_traverse, (void*)PyDevice::tp_traverse}, + {Py_tp_clear, (void*)PyDevice::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyDevice::RegisterPythonType(nb::module_& m) { + nb::class_ device( + m, "Device", nb::type_slots(PyDevice::slots_), + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. 
Subclasses may " + "have additional properties specific to that device type."); + device + .def_prop_ro( + "id", &PyDevice::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_prop_ro("process_index", &PyDevice::process_index, + "Integer index of this device's process.\n\n" + "This is always 0 except on multi-process platforms.") + .def_prop_ro("host_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("task_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("platform", &PyDevice::platform) + .def_prop_ro("device_kind", &PyDevice::device_kind) + .def_prop_ro("client", &PyDevice::client) + .def_prop_ro( + "local_hardware_id", &PyDevice::local_hardware_id, + "Opaque hardware ID, e.g., the CUDA device number. In general, not " + "guaranteed to be dense, and not guaranteed to be defined on all " + "platforms.") + .def("__str__", &PyDevice::Str) + .def("__repr__", &PyDevice::Repr) + .def("transfer_to_infeed", + ThrowIfErrorWrapper(&PyDevice::TransferToInfeed)) + .def("transfer_from_outfeed", + ValueOrThrowWrapper(&PyDevice::TransferFromOutfeed)) + .def("memory", ValueOrThrowWrapper(&PyDevice::Memory), nb::arg("kind")) + .def("default_memory", ValueOrThrowWrapper(&PyDevice::DefaultMemory), + "Returns the default memory of a device.") + .def("addressable_memories", &PyDevice::AddressableMemories, + "Returns all the memories that a device can address.") + + .def("live_buffers", + [](nb::handle device) { + PythonDeprecationWarning( + /*stacklevel=*/1, + "Per device live_buffers() is deprecated. Please " + "use the jax.live_arrays() for jax.Arrays instead."); + return nb::list(); + }) + .def( + "memory_stats", ValueOrThrowWrapper(&PyDevice::MemoryStats), + "Returns memory statistics for this device keyed by name. May not " + "be implemented on all platforms, and different platforms may return " + "different stats, or -1 for unavailable stats. 'bytes_in_use' is " + "usually available. Intended for diagnostic use.") + .def( + "get_stream_for_external_ready_events", + xla::ValueOrThrowWrapper(&PyDevice::GetStreamForExternalReadyEvents)); + static PyMethodDef get_attr_method = { + "__getattr__", + +[](PyObject* self, PyObject* args) -> PyObject* { + PyObject* key; + if (!PyArg_ParseTuple(args, "O", &key)) { + PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); + return nullptr; + } + try { + auto device = nb::cast(nb::handle(self)); + auto name = nb::cast(nb::handle(key)); + const auto& attrs = device->device_->Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + auto result = std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + return result.release().ptr(); + } + PyErr_SetNone(PyExc_AttributeError); + return nullptr; + } catch (std::exception& e) { + PyErr_Format(PyExc_SystemError, "Unhandled nanobind exception: %s", + e.what()); + return nullptr; + } catch (...) 
{ + PyErr_SetString(PyExc_SystemError, "Unhandled nanobind exception."); + return nullptr; + } + }, + METH_VARARGS, + nullptr, + }; + device.attr("__getattr__") = nb::steal(PyDescr_NewMethod( + reinterpret_cast(device.ptr()), &get_attr_method)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h new file mode 100644 index 000000000000..6d2b3893dea8 --- /dev/null +++ b/jaxlib/xla/py_device.h @@ -0,0 +1,82 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_DEVICE_H_ +#define JAXLIB_XLA_PY_DEVICE_H_ + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/literal.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/shape.h" + +namespace xla { + +class PyDevice { + public: + PyDevice(nb_class_ptr client, ifrt::Device* device); + + // Devices are compared using Python object identity, so we don't allow them + // to be copied or moved. + PyDevice(const PyDevice&) = delete; + PyDevice(PyDevice&&) = delete; + PyDevice& operator=(const PyDevice&) = delete; + PyDevice& operator=(PyDevice&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Device* device() const { return device_; } + + int id() const; + int process_index() const; + absl::string_view platform() const; + absl::string_view device_kind() const; + std::optional local_hardware_id() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + absl::Status TransferToInfeed(LiteralSlice literal); + absl::StatusOr TransferFromOutfeed(Shape shape); + + absl::StatusOr> Memory( + absl::string_view kind) const; + absl::StatusOr> DefaultMemory() const; + nanobind::list AddressableMemories() const; + absl::StatusOr> MemoryStats() const; + + absl::StatusOr GetStreamForExternalReadyEvents() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Device* device_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_DEVICE_H_ diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/xla/py_device_list.cc new file mode 100644 index 000000000000..593a86ccbe42 --- /dev/null +++ b/jaxlib/xla/py_device_list.cc @@ -0,0 +1,472 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_device_list.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/make_iterator.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +PyDeviceList::PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list) + : py_client_(std::move(py_client)), device_list_(std::move(device_list)) {} + +PyDeviceList::PyDeviceList(nb::tuple py_device_assignment) + : device_list_(py_device_assignment) { + // Attempt to convert to Python devices into `ifrt::DeviceList`. + if (py_device_assignment.size() == 0) { + return; + } + absl::InlinedVector devices; + devices.reserve(py_device_assignment.size()); + for (nb::handle obj : py_device_assignment) { + if (!nb::isinstance(obj.ptr())) { + // Non-`xla::PyDevice` is used on an alternative JAX backend with device + // duck typing. Use Python device objects already set in `device_list_`. + return; + } + auto py_device = nb::cast(obj); + if (py_client_.get() == nullptr) { + py_client_ = py_device->client(); + } else if (py_device->client().get() != py_client_.get()) { + // If the list contains multiple clients, fall back to device duck typing. + return; + } + devices.push_back(py_device->device()); + } + device_list_ = py_client_->ifrt_client()->MakeDeviceList(devices); +} + +PyDeviceList::~PyDeviceList() { + if (device_list_.index() == 1) { + xla::GlobalPyRefManager()->AddGarbage( + std::move(std::get<1>(std::move(device_list_)))); + } +} + +absl::StatusOr PyDeviceList::ifrt_device_list() + const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_); + case 1: + return xla::InvalidArgument("DeviceList contains non-IFRT devices"); + default: + return xla::InvalidArgument("Unrecognized DeviceList type"); + } +} + +int64_t PyDeviceList::Hash() { + if (!hash_.has_value()) { + switch (device_list_.index()) { + case 0: + hash_ = absl::HashOf(std::get<0>(device_list_)); + break; + case 1: + hash_ = nb::hash(std::get<1>(device_list_)); + break; + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *hash_; +} + +/*static*/ bool PyDeviceList::Equal(xla::nb_class_ptr self, + nb::handle other) { + if (!nb::isinstance(other)) { + return false; + } + auto o = nb::cast(other); + // Fast-path using a pointer equality check. 
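+  // The comparison is tiered: pointer identity first, then cached hashes
+  // (computed under each object's own lock, taken one at a time so no two
+  // locks are ever held together), and only then a deep comparison.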
+ if (self.get() == o) { + return true; + } + int64_t h1, h2; + { + nb::ft_object_guard lock(self); + h1 = self->Hash(); + } + { + nb::ft_object_guard lock(other); + h2 = o->Hash(); + } + if (h1 != h2) { + return false; + } + if (self->device_list_.index() == 0 && o->device_list_.index() == 0) { + nb::gil_scoped_release gil_release; + return *std::get<0>(self->device_list_) == *std::get<0>(o->device_list_); + } else { + return self->AsTuple().equal(o->AsTuple()); + } +} + +/*static*/ bool PyDeviceList::NotEqual(xla::nb_class_ptr self, + nb::handle other) { + return !Equal(std::move(self), other); +} + +int PyDeviceList::Len() const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_)->size(); + case 1: + return nb::len(std::get<1>(device_list_)); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetItem(int index) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + if (index < -device_list->size() || index >= device_list->size()) { + throw nb::index_error(); + } else if (index < 0) { + index += device_list->size(); + } + return py_client_->GetPyDevice(device_list->devices()[index]); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(index); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetSlice(nb::slice slice) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + const absl::Span devices = + device_list->devices(); + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx(slice.ptr(), devices.size(), &start, &stop, + &step, &slicelength) != 0) { + throw nb::python_error(); + } + nb::tuple out = nb::steal(PyTuple_New(slicelength)); + for (size_t i = 0; i < slicelength; ++i) { + nb::object d = py_client_->GetPyDevice(devices[start]); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + start += step; + } + return std::move(out); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(slice); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::tuple PyDeviceList::AsTuple() const { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + nb::tuple out = nb::steal(PyTuple_New(device_list->size())); + int i = 0; + for (xla::ifrt::Device* device : device_list->devices()) { + nb::object d = py_client_->GetPyDevice(device); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + ++i; + } + return out; + } + case 1: + return std::get<1>(device_list_); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::iterator PyDeviceList::Iter() { + switch (device_list_.index()) { + case 0: { + // Iterator whose deference converts `xla::ifrt::Device*` into JAX + // `PjRtDevice`. 
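+      // nb::make_iterator only requires the three members defined here:
+      // pre-increment, equality, and dereference; the Python-side value is
+      // produced by the nanobind caster applied to the result of operator*().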
+ struct Iterator { + void operator++() { ++it; } + bool operator==(const Iterator& other) const { return it == other.it; } + xla::nb_class_ptr operator*() const { + return py_client->GetPyDevice(*it); + } + xla::nb_class_ptr py_client; + absl::Span::const_iterator it; + }; + return nb::make_iterator( + nb::type(), "ifrt_device_iterator", + Iterator{py_client_, std::get<0>(device_list_)->devices().cbegin()}, + Iterator{py_client_, std::get<0>(device_list_)->devices().cend()}); + } + case 1: + return nb::make_iterator( + nb::type(), "python_device_iterator", + std::get<1>(device_list_).begin(), std::get<1>(device_list_).end()); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +std::string PyDeviceList::Str() { + return nb::cast(nb::str(AsTuple())); +} + +nb::tuple PyDeviceList::Dump() const { return AsTuple(); } + +bool PyDeviceList::IsFullyAddressable() { + if (!is_fully_addressable_.has_value()) { + is_fully_addressable_ = true; + switch (device_list_.index()) { + case 0: { + const int process_index = py_client_ ? py_client_->process_index() : 0; + for (const xla::ifrt::Device* device : + std::get<0>(device_list_)->devices()) { + if (device->ProcessIndex() != process_index) { + is_fully_addressable_ = false; + break; + } + } + break; + } + case 1: { + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) != + nb::cast(device.attr("client").attr("process_index")())) { + is_fully_addressable_ = false; + break; + } + } + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *is_fully_addressable_; +} + +/*static*/ xla::nb_class_ptr PyDeviceList::AddressableDeviceList( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (self->IsFullyAddressable()) { + // Do not cache this result in `addressable_device_list_`. Otherwise, it + // will create a cycle that prevents deletion of this object. + return self; + } + if (!self->addressable_device_list_.has_value()) { + switch (self->device_list_.index()) { + case 0: { + absl::InlinedVector addressable_devices; + const int process_index = + self->py_client_ ? self->py_client_->process_index() : 0; + for (xla::ifrt::Device* device : + std::get<0>(self->device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_devices.push_back(device); + } + } + self->addressable_device_list_ = xla::make_nb_class( + self->py_client_, self->py_client_->ifrt_client()->MakeDeviceList( + addressable_devices)); + break; + } + case 1: { + auto device_list = std::get<1>(self->device_list_); + std::vector addressable_devices; + for (size_t i = 0; i < device_list.size(); ++i) { + nb::object device = device_list[i]; + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_devices.push_back(std::move(device)); + } + } + self->addressable_device_list_ = xla::make_nb_class( + xla::MutableSpanToNbTuple(absl::MakeSpan(addressable_devices))); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *self->addressable_device_list_; +} + +void PyDeviceList::PopulateMemoryKindInfo() { + if (device_list_.index() == 1) { + // Handle Python duck-type devices in a separate function for readability. 
+ PopulateMemoryKindInfoForDuckTypedDevices(); + return; + } + if (device_list_.index() != 0) { + throw nb::value_error("Unrecognized DeviceList type"); + } + MemoryKindInfo info; + xla::ifrt::Device* addressable_device = nullptr; + const int process_index = py_client_ ? py_client_->process_index() : 0; + for (xla::ifrt::Device* device : std::get<0>(device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_device = device; + break; + } + } + if (addressable_device == nullptr) { + info.default_memory_kind = nb::none(); + memory_kind_info_ = std::move(info); + return; + } + + auto default_memory = addressable_device->DefaultMemory(); + if (!default_memory.ok()) { + // Cache the error. + memory_kind_info_ = default_memory.status(); + return; + } + info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind()); + nb::tuple memory_kinds = + nb::steal(PyTuple_New(addressable_device->Memories().size())); + for (size_t i = 0; i < addressable_device->Memories().size(); ++i) { + auto* memory = addressable_device->Memories()[i]; + nb::str s = nb::str(memory->Kind().memory_kind()->data(), + memory->Kind().memory_kind()->size()); + PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); + } + info.memory_kinds = std::move(memory_kinds); + memory_kind_info_ = std::move(info); +} + +void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { + MemoryKindInfo info; + try { + nb::handle addressable_device; + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_device = device; + break; + } + } + if (!addressable_device) { + info.default_memory_kind = nb::none(); + // info.memory_kinds is default-initialized to an empty tuple. + memory_kind_info_ = std::move(info); + return; + } + auto default_memory = addressable_device.attr("default_memory")(); + info.default_memory_kind = default_memory.attr("kind"); + info.memory_kinds = nb::tuple( + nb::object(addressable_device.attr("addressable_memories")())); + memory_kind_info_ = std::move(info); + } catch (nb::python_error& e) { + // Cache the error. 
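+    // Caching the failed status (instead of rethrowing here) makes later
+    // DefaultMemoryKind()/MemoryKinds() calls report the same error without
+    // repeating the Python attribute lookups.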
+    memory_kind_info_ = xla::InvalidArgument("%s", e.what());
+  }
+}
+
+/*static*/ absl::StatusOr<nb::tuple> PyDeviceList::MemoryKinds(
+    xla::nb_class_ptr<PyDeviceList> self) {
+  nb::ft_object_guard lock(self);
+  if (!self->memory_kind_info_.has_value()) {
+    self->PopulateMemoryKindInfo();
+  }
+  if (!self->memory_kind_info_->ok()) {
+    return self->memory_kind_info_->status();
+  }
+  return (*self->memory_kind_info_)->memory_kinds;
+}
+
+/*static*/ absl::StatusOr<nb::object> PyDeviceList::DefaultMemoryKind(
+    xla::nb_class_ptr<PyDeviceList> self) {
+  nb::ft_object_guard lock(self);
+  if (!self->memory_kind_info_.has_value()) {
+    self->PopulateMemoryKindInfo();
+  }
+  if (!self->memory_kind_info_->ok()) {
+    return self->memory_kind_info_->status();
+  }
+  return (*self->memory_kind_info_)->default_memory_kind;
+}
+
+/*static*/ void PyDeviceList::Register(nb::module_& m) {
+  nb::class_<PyDeviceList>(m, "DeviceList")
+      .def(nb::init<nb::tuple>())
+      .def("__hash__", &PyDeviceList::Hash, nb::lock_self())
+      .def("__eq__", &PyDeviceList::Equal)
+      .def("__ne__", &PyDeviceList::NotEqual)
+      .def("__len__", &PyDeviceList::Len)
+      .def("__getitem__", &PyDeviceList::GetItem)
+      .def("__getitem__", &PyDeviceList::GetSlice)
+      .def("__iter__", &PyDeviceList::Iter, nb::keep_alive<0, 1>())
+      .def("__str__", &PyDeviceList::Str)
+      .def("__repr__", &PyDeviceList::Str)
+      .def("__getstate__", [](const PyDeviceList& l) { return l.Dump(); })
+      .def("__setstate__",
+           [](PyDeviceList& self, nb::tuple t) {
+             new (&self) PyDeviceList(std::move(t));
+           })
+      .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable,
+                   nb::lock_self())
+      .def_prop_ro("addressable_device_list",
+                   &PyDeviceList::AddressableDeviceList)
+      // `xla::ValueOrThrowWrapper` does not work with
+      // `def_prop_ro()`. Manually convert an error into an exception.
+      .def_prop_ro("default_memory_kind",
+                   [](xla::nb_class_ptr<PyDeviceList> l) {
+                     auto kind = DefaultMemoryKind(l);
+                     if (!kind.ok()) {
+                       throw nb::value_error(kind.status().ToString().c_str());
+                     }
+                     return *kind;
+                   })
+      .def_prop_ro("memory_kinds", [](xla::nb_class_ptr<PyDeviceList> l) {
+        auto kinds = MemoryKinds(l);
+        if (!kinds.ok()) {
+          throw nb::value_error(kinds.status().ToString().c_str());
+        }
+        return *kinds;
+      });
+}
+
+}  // namespace jax
diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h
new file mode 100644
index 000000000000..ea574c5dc5a2
--- /dev/null
+++ b/jaxlib/xla/py_device_list.h
@@ -0,0 +1,137 @@
+/* Copyright 2023 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_PY_DEVICE_LIST_H_
+#define JAXLIB_XLA_PY_DEVICE_LIST_H_
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <variant>
+
+#include "absl/status/statusor.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/xla/py_client.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/nb_class_ptr.h"
+#include "xla/tsl/concurrency/ref_count.h"
+
+namespace jax {
+
+// Device list with various caching and direct access to IFRT DeviceList.
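+//
+// A rough C++ usage sketch (illustrative only; `client` and `devices` are
+// placeholders for a valid `nb_class_ptr<PyClient>` and
+// `ifrt::DeviceListRef` obtained elsewhere):
+//
+//   auto dl = xla::make_nb_class<PyDeviceList>(client, devices);
+//   auto addressable = PyDeviceList::AddressableDeviceList(dl);
+//   auto kind = PyDeviceList::DefaultMemoryKind(dl);  // absl::StatusOr<...>
+//
+// The cached accessors are static and take an `nb_class_ptr<PyDeviceList>`
+// rather than `this` so that they can take the object's self lock in
+// free-threading (non-GIL) builds.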
+class PyDeviceList {
+ public:
+  PyDeviceList(xla::nb_class_ptr<PyClient> py_client,
+               xla::ifrt::DeviceListRef device_list);
+  explicit PyDeviceList(nanobind::tuple py_device_assignment);
+  ~PyDeviceList();
+
+  PyDeviceList(const PyDeviceList&) = delete;
+  PyDeviceList(PyDeviceList&&) = delete;
+  PyDeviceList& operator=(const PyDeviceList&) = delete;
+  PyDeviceList& operator=(PyDeviceList&&) = delete;
+
+  static nanobind::handle type() {
+    static auto type = nanobind::type<PyDeviceList>();
+    return type;
+  }
+
+  // These two methods are safe to call from C++ without GIL.
+  xla::nb_class_ptr<PyClient> py_client() const { return py_client_; }
+  absl::StatusOr<xla::ifrt::DeviceListRef> ifrt_device_list() const;
+
+  int Len() const;  // Requires the GIL in GIL mode.
+  nanobind::object GetItem(int index);  // Requires the GIL in GIL mode.
+
+  // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode.
+  static xla::nb_class_ptr<PyDeviceList> AddressableDeviceList(
+      xla::nb_class_ptr<PyDeviceList> self);
+
+  // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode.
+  static absl::StatusOr<nanobind::object> DefaultMemoryKind(
+      xla::nb_class_ptr<PyDeviceList> self);
+
+  // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode.
+  static absl::StatusOr<nanobind::tuple> MemoryKinds(
+      xla::nb_class_ptr<PyDeviceList> self);
+
+  // go/pywald-pybind-annotation BEGIN
+  // refs {
+  //   module_path: "third_party/py/jax/jaxlib/xla/xla.cc"
+  //   module_arg {}
+  // }
+  // go/pywald-pybind-annotation END
+  static void Register(nanobind::module_& m);
+
+ private:
+  nanobind::tuple AsTuple() const;
+
+  // Methods below require GIL.
+  nanobind::object GetSlice(nanobind::slice slice);
+  nanobind::iterator Iter();
+
+  std::string Str();
+
+  nanobind::tuple Dump() const;
+
+  int64_t Hash();  // Mutates hash_, needs self lock.
+
+  static bool Equal(xla::nb_class_ptr<PyDeviceList> self,
+                    nanobind::handle other);
+  static bool NotEqual(xla::nb_class_ptr<PyDeviceList> self,
+                       nanobind::handle other);
+
+  // Finds the memory kind info from an addressable device. Requires the GIL
+  // or self lock.
+  void PopulateMemoryKindInfo();
+  // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_`
+  // instead of `ifrt_device_list_` to support duck-typed device objects.
+  // Requires the GIL or self lock.
+  void PopulateMemoryKindInfoForDuckTypedDevices();
+
+  // Requires the GIL or the self lock.
+  bool IsFullyAddressable();
+
+  // Valid only if `device_list_` contains a non-empty
+  // `xla::ifrt::DeviceList`.
+  xla::nb_class_ptr<PyClient> py_client_;
+
+  // Either C++ `ifrt::DeviceList` or Python duck-type devices.
+  // TODO(hyeontaek): Remove support for Python duck-type devices once all
+  // JAX backends and tests are migrated to use an `xla::ifrt::Device` type
+  // for JAX devices.
+  // Immutable after constructor; no locking needed.
+  std::variant<xla::ifrt::DeviceListRef, nanobind::tuple> device_list_;
+
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<int64_t> hash_;
+  // TODO(hyeontaek): Make the following property cached within
+  // `xla::ifrt::DeviceList`.
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<bool> is_fully_addressable_;
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<xla::nb_class_ptr<PyDeviceList>> addressable_device_list_;
+
+  struct MemoryKindInfo {
+    nanobind::object default_memory_kind;
+    nanobind::tuple memory_kinds;
+  };
+  // Populated on demand. Guarded by the object's self lock.
+ std::optional> memory_kind_info_; +}; + +} // namespace jax + +#endif // JAXLIB_XLA_PY_DEVICE_LIST_H_ diff --git a/jaxlib/xla/py_executable.cc b/jaxlib/xla/py_executable.cc new file mode 100644 index 000000000000..5a02a8f6dd20 --- /dev/null +++ b/jaxlib/xla/py_executable.cc @@ -0,0 +1,463 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_executable.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/traceback.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { + +namespace nb = nanobind; + +absl::Status PyToken::Await() { + CHECK(future_.IsValid()); + nb::gil_scoped_release gil_release; + return future_.Await(); +} + +absl::Status PyShardedToken::Await() { + nb::gil_scoped_release gil_release; + absl::Status status = absl::OkStatus(); + for (auto& future : futures_) { + auto s = future.Await(); + if (!s.ok()) status = std::move(s); + } + return status; +} + +PyLoadedExecutable::PyLoadedExecutable( + nb_class_ptr client, + std::shared_ptr ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint) + : client_(std::move(client)), + ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), + traceback_(std::move(traceback)), + fingerprint_(std::move(fingerprint)), + next_launch_id_( + fingerprint_.has_value() ? 
tsl::Fingerprint32(*fingerprint_) : 1) { + CHECK(PyGILState_Check()); + if (fingerprint_) { + VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() + << ": " << *fingerprint_; + } + nb::ft_lock_guard lock(client_->executables_mutex_); + next_ = client_->executables_; + client_->executables_ = this; + prev_ = nullptr; + if (next_) { + next_->prev_ = this; + } +} + +PyLoadedExecutable::~PyLoadedExecutable() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(client_->executables_mutex_); + if (client_->executables_ == this) { + client_->executables_ = next_; + } + if (prev_) { + prev_->next_ = next_; + } + if (next_) { + next_->prev_ = prev_; + } +} + +std::vector> PyLoadedExecutable::AddressableDevices() + const { + std::vector> devices; + devices.reserve(ifrt_loaded_executable_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_loaded_executable_->addressable_devices()) { + devices.push_back(client_->GetPyDevice(device)); + } + return devices; +} + +namespace { + +// Traits classes of common methods for std::vector. +template +struct ShardedBufferAdapter; + +template <> +struct ShardedBufferAdapter { + static int num_devices(const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return std::get(arg).num_addressable_shards(); + } else { + return std::get>(arg).size(); + } + } + static tsl::RCReference GetIfRtArray( + const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return tsl::FormRef(std::get(arg).ifrt_array()); + } + auto& arg_vector = std::get>(arg); + + // TODO(hyeontaek): This on-demand Array creation is not efficient and has + // insufficient information about the shape (a dummy shape is used). This + // should be removed if possible and only be used in the context where the + // shape information is unused. + std::vector> ifrt_arrays; + ifrt_arrays.reserve(arg_vector.size()); + absl::InlinedVector devices; + devices.reserve(arg_vector.size()); + for (auto& arr : arg_vector) { + CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1) + << arr.ifrt_array()->sharding().DebugString(); + ifrt_arrays.push_back(tsl::FormRef(arr.ifrt_array())); + devices.push_back( + arr.ifrt_array()->sharding().devices()->devices().front()); + } + CHECK(!ifrt_arrays.empty()); + // Use a dummy shape. + // TODO(hyeontaek): Find a way to compute a correct shape. + // TODO(yashkatariya): Plumb sharding or memory_kind here. 
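+    // The single-device shards are reassembled below into one IFRT array
+    // using an `OpaqueSharding`; per the TODOs above, the shape is a dummy
+    // placeholder, so consumers of this array must not rely on it.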
+ ifrt::Client* client = ifrt_arrays.front()->client(); + auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( + ifrt_arrays.front()->shape(), + ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices), + ifrt::MemoryKind()), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(ifrt_array.status()); + return *ifrt_array; + } +}; + +void PopulateExecuteShardedResults( + const nb_class_ptr& client, + std::vector> ifrt_arrays, + const PjRtFuture<>& result_status, int num_computations, + std::vector>& outputs) { + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.resize(num_output_buffers); + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + outputs[buffer_id].reserve(num_computations); + auto exploded_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(exploded_arrays.status()); + for (auto& exploded_array : *exploded_arrays) { + outputs[buffer_id].push_back(PyArray::MakeFromSingleDeviceArray( + client, traceback, std::move(exploded_array), false, true, + result_status)); + } + } +} + +template > +absl::StatusOr ExecuteShardedOnLocalDevicesInternal( + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, + ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, + std::optional>>& returned_futures) { + std::vector> output_arrays; + std::unique_ptr> returned_future; + int num_computations = ifrt_loaded_executable->addressable_devices().size(); + PjRtFuture<> result_status; + { + nb::gil_scoped_release gil_release; + for (const auto& arg : args) { + if (ArgAdapter::num_devices(arg) != num_computations) { + return InvalidArgument( + "Expected args to execute_sharded_on_local_devices to have %d " + "shards, got: [%s]", + num_computations, + absl::StrJoin(args, ", ", [](std::string* out, const ArgT& arg) { + out->append(std::to_string(ArgAdapter::num_devices(arg))); + })); + } + } + std::vector> arg_arrays(args.size()); + absl::c_transform(args, arg_arrays.begin(), [&](const ArgT& arg) mutable { + return ArgAdapter::GetIfRtArray(arg); + }); + TF_ASSIGN_OR_RETURN(auto result, ifrt_loaded_executable->Execute( + absl::MakeSpan(arg_arrays), options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + // options.fill_status is only supposed to be true when the computation has + // tokens. + if (options.fill_status) { + result_status = result.status; + if (returned_futures.has_value()) { + returned_futures->resize(num_computations, std::move(result.status)); + } + } + } + + // TODO(b/240696624): Although the PjRt interface require `returned_futures` + // to be resized correctly if it is not nullopt, some implementation does not + // implement this. So we have to check whether returned_futures is empty. + // Remove this check once the implementation is fixed. + auto py_sharded_token = returned_futures.has_value() + ? 
PyShardedToken(std::move(*returned_futures)) + : PyShardedToken(); + + return PyExecuteResults(client, std::move(output_arrays), num_computations, + std::move(py_sharded_token), result_status); +} + +} // namespace + +PyExecuteResults::PyExecuteResults( + const nb_class_ptr& client, + std::vector> ifrt_arrays, + int num_computations, PyShardedToken token, PjRtFuture<> result_status) + : client_(client), + ifrt_arrays_(std::move(ifrt_arrays)), + num_computations_(num_computations), + token_(std::move(token)), + result_status_(std::move(result_status)) {} + +void PyExecuteResults::CheckNotDisassembled() const { + if (is_exploded_) { + throw nb::value_error("ExecuteResults already exploded."); + } +} + +std::vector> PyExecuteResults::Consume() { + CheckNotDisassembled(); + is_exploded_ = true; + return std::move(ifrt_arrays_); +} + +PyShardedToken PyExecuteResults::ConsumeToken() { + if (token_consumed_) { + throw nb::value_error("ExecuteResults token already consumed."); + } + token_consumed_ = true; + return std::move(token_); +} + +std::vector> +PyExecuteResults::DisassembleIntoSingleDeviceArrays() { + std::vector> outputs; + PopulateExecuteShardedResults( + client_, Consume(), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector> +PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { + CheckNotDisassembled(); + if (n > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("In DisassemblePrefixIntoSingleDeviceArrays: ", n, " > ", + ifrt_arrays_.size()) + .c_str()); + } + std::vector> ifrt_arrays; + ifrt_arrays.reserve(ifrt_arrays_.size() - n); + for (size_t i = n; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + n, ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); + std::vector> outputs; + PopulateExecuteShardedResults( + client_, std::move(ifrt_arrays), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector PyExecuteResults::ConsumeWithHandlers( + std::vector> + out_handlers) { + std::vector outputs; + auto ifrt_arrays = Consume(); + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations_, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.reserve(num_output_buffers); + if (out_handlers.size() != num_output_buffers) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", num_output_buffers) + .c_str()); + } + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + auto& handler = out_handlers[buffer_id]; + if (std::holds_alternative(handler)) { + outputs.push_back(std::get(handler)->Call( + client_, std::move(ifrt_arrays[buffer_id]), + result_status_.IsValid() ? result_status_ : PjRtFuture<>())); + } else { + tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback."); + auto disassembled_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(disassembled_arrays.status()); + nb::list bufs = + nb::steal(PyList_New(disassembled_arrays->size())); + int i = 0; + for (auto& disassembled_array : *disassembled_arrays) { + nb::object array = PyArray::MakeFromSingleDeviceArray( + client_, traceback, std::move(disassembled_array), false, true, + result_status_.IsValid() ? 
result_status_ : PjRtFuture<>()); + PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr()); + ++i; + } + outputs.push_back(std::get(handler)(std::move(bufs))); + } + } + return outputs; +} + +absl::StatusOr>> +PyLoadedExecutable::ExecuteShardedOnLocalDevices( + absl::Span args) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = false; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options, client_, ifrt_loaded_executable_.get(), args, + returned_futures)); + return outputs_and_tokens.DisassembleIntoSingleDeviceArrays(); +} + +absl::StatusOr>, PyShardedToken>> +PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens( + absl::Span args) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = true; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + returned_futures.emplace(); + TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options, client_, ifrt_loaded_executable_.get(), args, + returned_futures)); + return std::make_pair(outputs_and_tokens.DisassembleIntoSingleDeviceArrays(), + outputs_and_tokens.ConsumeToken()); +} + +absl::StatusOr PyLoadedExecutable::ExecuteSharded( + std::vector args, bool with_tokens) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = with_tokens; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + if (with_tokens) { + returned_futures.emplace(); + } + absl::Span span_args = args; + return ExecuteShardedOnLocalDevicesInternal(options, client_, + ifrt_loaded_executable_.get(), + span_args, returned_futures); +} + +absl::StatusOr>> +PyLoadedExecutable::HloModules() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetHloModules(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputMemoryKinds() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputMemoryKinds(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetParameterLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterLayouts(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputLayouts(); +} + +std::optional> +PyLoadedExecutable::GetParameterShardings() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterShardings(); +} + +std::optional> PyLoadedExecutable::GetOutputShardings() + const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputShardings(); +} + +int64_t PyLoadedExecutable::GetNextLaunchId() { + return next_launch_id_.fetch_add(1, std::memory_order_relaxed); +} + +void PyLoadedExecutable::KeepAlive(nb::object obj) { + keepalives_.push_back(std::move(obj)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h new file mode 100644 index 000000000000..214431f9472e --- /dev/null +++ b/jaxlib/xla/py_executable.h @@ -0,0 +1,263 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_PY_EXECUTABLE_H_
+#define JAXLIB_XLA_PY_EXECUTABLE_H_
+
+#include <Python.h>
+
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "llvm/Support/Casting.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/xla/py_array.h"
+#include "jaxlib/xla/py_client.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/layout.h"
+#include "xla/pjrt/exceptions.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_common.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/pjrt_future.h"
+#include "xla/pjrt/pjrt_layout.h"
+#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/attribute_map.h"
+#include "xla/python/ifrt/executable.h"
+#include "xla/python/nb_class_ptr.h"
+#include "xla/python/pjrt_ifrt/pjrt_executable.h"
+#include "xla/python/traceback.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/platform/status.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+
+class PyToken {
+ public:
+  PyToken() = default;
+  explicit PyToken(PjRtFuture<> future) : future_(std::move(future)) {}
+
+  static PyToken ReadyPyToken() {
+    return PyToken(PjRtFuture<>(absl::OkStatus()));
+  }
+
+  absl::Status Await();
+
+ private:
+  PjRtFuture<> future_;
+};
+
+// PyShardedToken contains a PyToken for each device's execution.
+class PyShardedToken {
+ public:
+  // Default construction creates an always-ready token.
+  PyShardedToken() = default;
+  explicit PyShardedToken(std::vector<PjRtFuture<>> futures)
+      : futures_(std::move(futures)) {}
+
+  PyToken GetPyToken(int device_id) const {
+    if (futures_.empty()) return PyToken::ReadyPyToken();
+    return PyToken(futures_.at(device_id));
+  }
+
+  absl::Status Await();
+
+ private:
+  std::vector<PjRtFuture<>> futures_;
+};
+
+class PyExecuteResults {
+ public:
+  PyExecuteResults(const nb_class_ptr<PyClient>& client,
+                   std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays,
+                   int num_computations, PyShardedToken token,
+                   PjRtFuture<> result_status = PjRtFuture<>());
+
+  std::vector<std::vector<PyArray>> DisassembleIntoSingleDeviceArrays();
+
+  std::vector<std::vector<PyArray>> DisassemblePrefixIntoSingleDeviceArrays(
+      size_t n);
+
+  std::vector<nanobind::object> ConsumeWithHandlers(
+      std::vector<std::variant<const PyArrayResultHandler*, nanobind::object>>
+          out_handlers);
+
+  std::vector<tsl::RCReference<ifrt::Array>> Consume();
+
+  PyShardedToken ConsumeToken();
+
+  size_t Size() const {
+    CheckNotDisassembled();
+    return ifrt_arrays_.size();
+  }
+
+  void CheckNotDisassembled() const;
+
+ private:
+  bool is_exploded_ = false;
+  bool token_consumed_ = false;
+  nb_class_ptr<PyClient> client_;
+  std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays_;
+  int num_computations_;
+  PyShardedToken token_;
+  // Only set if the computation has tokens.
+  PjRtFuture<> result_status_;
+};
+
+using ExecuteShardedArg = std::variant<PyArray, std::vector<PyArray>>;
+
+// Python wrapper around PjRtExecutable. We use a wrapper class:
+// a) to keep the PyClient alive via a std::shared_ptr<>
+// b) to add Python-specific functionality.
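+//
+// Typical flow (illustrative only; `exe` and `args` are placeholders):
+//
+//   auto result = exe->ExecuteShardedOnLocalDevicesWithTokens(args);
+//   if (result.ok()) {
+//     auto& [outputs, token] = *result;   // outputs: per-output shard lists
+//     absl::Status done = token.Await();  // blocks until all devices finish
+//   }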
+class PyLoadedExecutable { + public: + PyLoadedExecutable( + nb_class_ptr client, + std::shared_ptr ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint); + ~PyLoadedExecutable(); + + nb_class_ptr client() const { return client_; } + ifrt::LoadedExecutable* ifrt_loaded_executable() const { + return ifrt_loaded_executable_.get(); + } + + std::shared_ptr shared_ifrt_loaded_executable() { + return ifrt_loaded_executable_; + } + + std::vector> AddressableDevices() const; + + int64_t SizeOfGeneratedCodeInBytes() const { + return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); + } + + absl::StatusOr GetCompiledMemoryStats() const { + nanobind::gil_scoped_release scope; + return ifrt_loaded_executable_->GetCompiledMemoryStats(); + } + + absl::StatusOr GetCostAnalysis() const { + return ifrt_loaded_executable_->GetCostAnalysis(); + } + + void Delete() { + // TODO(hyeontaek): Return absl::Status. + TF_CHECK_OK(ifrt_loaded_executable_->Delete().Await()); + } + + bool is_deleted() { return ifrt_loaded_executable_->IsDeleted(); } + + // Takes args indexed by argid then deviceid, transposes them, and passes to + // PjRtExecutable::Execute. The result is similarly transposed back into the + // argid,deviceid format. + // args is [num_args x num_devices]. + absl::StatusOr>> + ExecuteShardedOnLocalDevices(absl::Span args); + + absl::StatusOr>, PyShardedToken>> + ExecuteShardedOnLocalDevicesWithTokens( + absl::Span args); + + absl::StatusOr ExecuteSharded( + std::vector args, bool with_tokens); + + absl::StatusOr>> HloModules() const; + + absl::StatusOr>> + GetOutputMemoryKinds() const; + + absl::StatusOr>> + GetParameterLayouts() const; + + absl::StatusOr>> + GetOutputLayouts() const; + + std::optional> GetParameterShardings() const; + + std::optional> GetOutputShardings() const; + + const std::optional& traceback() { return traceback_; } + + ifrt::LoadedExecutable* ifrt_executable() const { + return ifrt_loaded_executable_.get(); + } + + // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + std::shared_ptr shared_ptr_pjrt_executable() { + auto* exec = llvm::dyn_cast_or_null( + ifrt_loaded_executable_.get()); + if (exec == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return exec->shared_ptr_pjrt_loaded_executable(); + } + + // Returns a template of execute options to pass to + // `ifrt_executable()->Execute()`. Note that the caller may need to override + // some options such as `launch_id` that change at each execution. + const ifrt::ExecuteOptions& options() const { return options_; } + + // Returns a unique launch ID to use for the next execution. + int64_t GetNextLaunchId(); + + const std::optional& fingerprint() const { return fingerprint_; } + + // Keep `obj` alive as long as PyLoadedExecutable. + void KeepAlive(nanobind::object obj); + + private: + friend class PyClient; + + nb_class_ptr client_; + std::shared_ptr ifrt_loaded_executable_; + std::optional traceback_; + + // Identical executables (i.e. representing the same program) will have the + // same fingerprint. nullopt on platforms or executables where fingerprints + // aren't implemented. + std::optional fingerprint_; + + // Launch ID to use for the next execution. + std::atomic next_launch_id_; + + // The options to pass to `executable_.Execute`. + ifrt::ExecuteOptions options_; + + // Python objects to keep alive as requested by user. 
+ std::vector keepalives_; + + // Doubly-linked list of all executables known to the client. Protected by the + // GIL. + PyLoadedExecutable* next_; + PyLoadedExecutable* prev_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_EXECUTABLE_H_ diff --git a/jaxlib/xla/py_memory_space.cc b/jaxlib/xla/py_memory_space.cc new file mode 100644 index 000000000000..f365dd25dfb6 --- /dev/null +++ b/jaxlib/xla/py_memory_space.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_memory_space.h" + +#include + +#include + +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyMemorySpace::PyMemorySpace(nb_class_ptr client, + ifrt::Memory* memory) + : client_(std::move(client)), memory_(memory) {} + +int PyMemorySpace::process_index() const { return client_->process_index(); } + +absl::string_view PyMemorySpace::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. 
+ if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyMemorySpace::kind() const { + return *memory_->Kind().memory_kind(); +} + +absl::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } + +absl::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } + +nb::list PyMemorySpace::AddressableByDevices() const { + nb::list devices; + for (ifrt::Device* device : memory_->Devices()) { + devices.append(client_->GetPyDevice(device)); + } + return devices; +} + +/* static */ int PyMemorySpace::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyMemorySpace* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyMemorySpace::tp_clear(PyObject* self) { + PyMemorySpace* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyMemorySpace::slots_[] = { + {Py_tp_traverse, (void*)PyMemorySpace::tp_traverse}, + {Py_tp_clear, (void*)PyMemorySpace::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyMemorySpace::RegisterPythonType(nb::module_& m) { + nb::class_ device(m, "Memory", + nb::type_slots(PyMemorySpace::slots_)); + device.def_prop_ro("process_index", &PyMemorySpace::process_index) + .def_prop_ro("platform", &PyMemorySpace::platform) + .def_prop_ro("kind", &PyMemorySpace::kind) + .def("__str__", &PyMemorySpace::Str) + .def("__repr__", &PyMemorySpace::Repr) + .def("addressable_by_devices", &PyMemorySpace::AddressableByDevices, + "Returns devices that can address this memory."); +} + +} // namespace xla diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/xla/py_memory_space.h new file mode 100644 index 000000000000..4ad7b852f416 --- /dev/null +++ b/jaxlib/xla/py_memory_space.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_MEMORY_SPACE_H_ +#define JAXLIB_XLA_PY_MEMORY_SPACE_H_ + +#include + +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/nb_class_ptr.h" + +namespace xla { + +class PyMemorySpace { + public: + PyMemorySpace(nb_class_ptr client, ifrt::Memory* memory_space); + + // Memory spaces are compared using Python object identity, so we don't allow + // them to be copied or moved. 
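+  // Equality on the Python side is therefore `is`-identity; the binding in
+  // RegisterPythonType() intentionally defines no value-based `__eq__`.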
+ PyMemorySpace(const PyMemorySpace&) = delete; + PyMemorySpace(PyMemorySpace&&) = delete; + PyMemorySpace& operator=(const PyMemorySpace&) = delete; + PyMemorySpace& operator=(PyMemorySpace&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Memory* memory_space() const { return memory_; } + + int process_index() const; + absl::string_view platform() const; + absl::string_view kind() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + nanobind::list AddressableByDevices() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Memory* memory_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_MEMORY_SPACE_H_ diff --git a/jaxlib/xla/py_program.cc b/jaxlib/xla/py_program.cc new file mode 100644 index 000000000000..ec82292a50cd --- /dev/null +++ b/jaxlib/xla/py_program.cc @@ -0,0 +1,291 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_program.h" + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/custom_call_program.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +namespace { + +// Gets `ifrt::DeviceList` from a sequence of JAX devices. 
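+// Accepts either a `jax::PyDeviceList` (whose IFRT device list is reused
+// directly) or a plain sequence of `PyDevice` objects; the sequence form
+// must be non-empty so that a client can be taken from its first element.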
+absl::StatusOr GetDeviceList(nb::sequence devices) { + ifrt::DeviceListRef ifrt_device_list; + if (devices.type().is(jax::PyDeviceList::type())) { + return nb::cast(devices)->ifrt_device_list(); + } else { + auto py_devices = nb::cast>>(devices); + if (py_devices.empty()) { + return absl::InvalidArgumentError( + "Colocated Python program requires at least one device"); + } + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const nb_class_ptr& py_device : py_devices) { + ifrt_devices.push_back(py_device->device()); + } + return py_devices.front()->client()->ifrt_client()->MakeDeviceList( + ifrt_devices); + } +} + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding)->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList(nb::handle sharding) { + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list->ifrt_device_list(); + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else { + return nb::cast( + sharding.attr("_internal_device_list")) + ->ifrt_device_list(); + } +} + +// Gets `ifrt::MemoryKind` from a JAX Sharding. +ifrt::MemoryKind GetIfrtMemoryKind(nb::handle sharding) { + auto memory_kind = sharding.attr("memory_kind"); + if (memory_kind.is_none()) { + return ifrt::MemoryKind(); + } else { + return ifrt::MemoryKind(nb::cast(memory_kind)); + } +} + +// Makes `ifrt::Sharding` from a JAX Sharding. It requires the number of shape +// dimensions, which may become necessary when building an HLO sharding. +absl::StatusOr> GetIfrtSharding( + nb::handle sharding, int64_t num_dimensions) { + auto ifrt_memory_kind = GetIfrtMemoryKind(sharding); + std::shared_ptr ifrt_sharding; + if (sharding.type().is(jax::SingleDeviceSharding::type())) { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, + nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list()); + return ifrt::SingleDeviceSharding::Create( + ifrt_device_list->devices().front(), ifrt_memory_kind); + } else { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetIfrtDeviceList(sharding)); + auto xla_hlo_sharding = GetXlaHloSharding(sharding, num_dimensions); + return ifrt::HloSharding::Create(std::move(ifrt_device_list), + ifrt_memory_kind, + std::move(xla_hlo_sharding)); + } +} + +// Gets `ifrt::ArraySpec`s from a sequence of JAX avals (e.g., +// `jax.ShapeDtypeStruct`). 
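+// Each aval only needs to duck-type the three attributes read below:
+// `shape` (a sequence of ints), `dtype` (a numpy dtype), and `sharding`
+// (a JAX sharding accepted by `GetIfrtSharding()`).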
+absl::StatusOr> GetIfrtArraySpecs( + nb::sequence avals) { + std::vector ifrt_array_specs; + ifrt_array_specs.reserve(nb::len(avals)); + for (nb::handle aval : avals) { + ifrt::Shape ifrt_shape(nb::cast>(aval.attr("shape"))); + TF_ASSIGN_OR_RETURN( + auto ifrt_dtype, + DtypeToIfRtDType(nb::cast(aval.attr("dtype")))); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + GetIfrtSharding(aval.attr("sharding"), ifrt_shape.dims().size())); + ifrt_array_specs.push_back(ifrt::ArraySpec{ + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding)}); + } + return ifrt_array_specs; +} + +absl::StatusOr> MakePluginProgramFromString( + std::string data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::move(data); + return plugin_program; +} + +absl::StatusOr> MakePluginProgramFromBytes( + nb::bytes data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::string(data.c_str(), data.size()); + return plugin_program; +} + +absl::StatusOr> +MakeColocatedPythonCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> +MakePluginCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> MakeHloProgram( + absl::string_view mlir_module) { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, *context)); + return std::make_unique(std::move(context), + std::move(module)); +} + +absl::StatusOr> MakeHloProgramFromString( + std::string mlir_module) { + return MakeHloProgram(mlir_module); +} + +absl::StatusOr> MakeHloProgramFromBytes( + nb::bytes mlir_module) { + return MakeHloProgram( + absl::string_view(mlir_module.c_str(), mlir_module.size())); +} + +absl::StatusOr> MakeXlaCompileOptions( + CompileOptions options, std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. 
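+  // Each capsule carries a raw `ifrt::LoadedHostCallback*`; `tsl::FormRef`
+  // takes a new reference so the callback stays alive independently of the
+  // Python capsule object.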
+ for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +constexpr absl::string_view kColocatedPythonProgramType = + "jax_colocated_python_v0.0.1"; + +absl::StatusOr> MakeColocatedPythonProgram( + std::string name, nb::bytes picked_function, nb::sequence devices, + nb::sequence input_avals, nb::sequence output_avals) { + auto ifrt_serialized_program_text = absl::MakeCordFromExternal( + absl::string_view(reinterpret_cast(picked_function.data()), + picked_function.size()), + /*releaser=*/[picked_function](absl::string_view) mutable { + GlobalPyRefManager()->AddGarbage(std::move(picked_function)); + }); + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetDeviceList(devices)); + TF_ASSIGN_OR_RETURN(auto ifrt_input_specs, GetIfrtArraySpecs(input_avals)); + TF_ASSIGN_OR_RETURN(auto ifrt_output_specs, GetIfrtArraySpecs(output_avals)); + return std::make_unique( + std::string(kColocatedPythonProgramType), std::move(name), + std::move(ifrt_serialized_program_text), std::move(ifrt_device_list), + std::move(ifrt_input_specs), std::move(ifrt_output_specs)); +} + +} // namespace + +void BuildIfrtProgramsSubmodule(nanobind::module_& m) { + auto sub_module = m.def_submodule("ifrt_programs"); + nb::class_ ifrt_program_base_class(sub_module, "Program"); + nb::class_ ifrt_compile_options_base_class( + sub_module, "CompileOptions"); + sub_module + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromString), + nb::arg("mlir_module")) + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromBytes), + nb::arg("mlir_module")) + .def("make_colocated_python_program", + ValueOrThrowWrapper(MakeColocatedPythonProgram), nb::arg("name"), + nb::arg("pickled_function"), nb::arg("devices"), + nb::arg("input_avals"), nb::arg("output_avals")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromString), nb::arg("data")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromBytes), nb::arg("data")) + .def("make_xla_compile_options", + ValueOrThrowWrapper(MakeXlaCompileOptions), nb::arg("options"), + nb::arg("host_callbacks")) + .def("make_colocated_python_compile_options", + ValueOrThrowWrapper(MakeColocatedPythonCompileOptions)) + .def("make_plugin_compile_options", + ValueOrThrowWrapper(MakePluginCompileOptions)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_program.h b/jaxlib/xla/py_program.h new file mode 100644 index 000000000000..9fd30eeeed2f --- /dev/null +++ b/jaxlib/xla/py_program.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_PROGRAM_H_ +#define JAXLIB_XLA_PY_PROGRAM_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildIfrtProgramsSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_PROGRAM_H_ diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index dd2c02898e18..b1c4fbcc541f 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -34,6 +34,9 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/to_ifrt_sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -48,9 +51,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" -#include "xla/python/py_array.h" -#include "xla/python/py_client.h" -#include "xla/python/to_ifrt_sharding.h" #include "xla/python/traceback.h" #include "xla/python/transfer/event_loop.h" #include "xla/python/transfer/socket-server.h" diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc new file mode 100644 index 000000000000..9375dd5440c6 --- /dev/null +++ b/jaxlib/xla/py_values.cc @@ -0,0 +1,745 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/py_values.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/complex.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/sharding.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/profiler/lib/traceme.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject* py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(py_array_obj))); + while (PyArray_ITER_NOTDONE(iter.ptr())) { + auto* iter_data = PyArray_ITER_DATA(iter.ptr()); + auto* item = PyArray_GETITEM(py_array_obj, static_cast(iter_data)); + if (!item) { + return absl::InternalError( + "Failed to get elements out of the ndarray iter."); + } + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(item, &len); + cords.push_back(absl::Cord(absl::string_view(str, len))); + PyArray_ITER_NEXT(iter.ptr()); + } + return cords; +} + +using DevicePutFunc = std::function( + nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind)>; + +template +absl::StatusOr HandlePythonScalar( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + T value; + try { + value = nb::cast(obj); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. 
This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + + std::variant data; + Shape shape; + PrimitiveType type; + if (std::is_same() || !options.squash_64bit_types) { + data.template emplace<0>(value); + type = primitive_util::NativeToPrimitiveType(); + } else { + // TODO(phawkins): we should check for overflow here, e.g., because of bugs + // like https://github.com/google/jax/issues/2006 + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + + return [client, data, type, to_device, to_memory_kind, + options]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/ifrt::Shape({}), /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/{}, options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + }; +} + +absl::StatusOr HandlePythonInt( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + PrimitiveType type; + std::variant data; + + if (options.squash_64bit_types) { + try { + data.emplace<1>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S32; + } else { + try { + data.emplace<0>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S64; + } + return [client, data, type, to_device, to_memory_kind, + options]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}), + /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr, options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + }; +} + +template +absl::StatusOr HandleNumpyScalar( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + std::variant data; + PrimitiveType type; + // For extension types, ScalarAsCtype returns a pointer to the data. 
+ if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3B11FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; + } else if (std::is_same() || !options.squash_64bit_types) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); + type = primitive_util::NativeToPrimitiveType(); + } else { + T value; + PyArray_ScalarAsCtype(h.ptr(), &value); + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + std::shared_ptr py_buffer_ref; + if (data.index() == 2) { + py_buffer_ref = + GlobalPyRefManager()->ManageReference(nb::cast(h)); + } + return [client, data, py_buffer_ref, type, to_device, options, + to_memory_kind]() mutable -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) -> const void* { + if constexpr (std::is_same_v, void*>) { + return v; + } else { + return static_cast(&v); + } + }, + data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. 
+    TF_ASSIGN_OR_RETURN(
+        auto ifrt_array,
+        client->MakeArrayFromHostBuffer(
+            ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}),
+            /*byte_strides=*/{},
+            ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind),
+            ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall,
+            /*on_done_with_host_buffer=*/
+            [py_buffer_ref = std::move(
+                 py_buffer_ref)]() { /* keeps py_buffer_ref alive */ },
+            options.ifrt_user_context));
+    return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
+  };
+}
+
+absl::StatusOr<DevicePutResultFn> HandleStringNumpyArray(
+    nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
+    const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
+  xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);
+  auto py_array_obj = reinterpret_cast<PyArrayObject*>(array.ptr());
+  TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj));
+
+  // Assemble all the parameters of MakeArrayFromHostBuffer
+  void* data = cords.data();
+
+  // Make an explicit copy of the shape elements so we won't run into
+  // endianness and precision issues that might arise if we reinterpret-cast
+  // from npy_intp (which can be only 32 bits wide in some environments, such
+  // as macos_arm64) to const int64_t*, which must be 64 bits wide.
+  ifrt::Shape::Dimensions dims;
+  dims.reserve(array.ndim());
+  for (int i = 0; i < array.ndim(); ++i) {
+    dims.push_back(array.shape(i));
+  }
+  ifrt::Shape shape(std::move(dims));
+
+  std::shared_ptr<const ifrt::Sharding> sharding =
+      xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind);
+
+  auto on_done_with_host_buffer = [cords = std::move(cords)] {};
+
+  return [client, data = data, shape = std::move(shape),
+          sharding = std::move(sharding),
+          on_done_with_host_buffer = std::move(on_done_with_host_buffer),
+          options]() mutable -> absl::StatusOr<DevicePutResult> {
+    TF_ASSIGN_OR_RETURN(
+        auto ifrt_array,
+        client->MakeArrayFromHostBuffer(
+            data, ifrt::DType(ifrt::DType::kString), std::move(shape),
+            /*byte_strides=*/std::nullopt, std::move(sharding),
+            ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes,
+            std::move(on_done_with_host_buffer), options.ifrt_user_context));
+
+    return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
+  };
+}
+
+absl::StatusOr<DevicePutResultFn> HandleNumpyArray(
+    nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
+    const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
+  xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);
+
+  // String numpy arrays require substantially different processing.
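+  // For example, an np.array(["ab", "c"], dtype=StringDType()) stores
+  // variable-length strings out of line, so there are no fixed-width elements
+  // or byte strides to hand to MakeArrayFromHostBuffer; HandleStringNumpyArray
+  // above instead gathers the data into cords first.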
+  if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') {
+    return HandleStringNumpyArray(h, client, to_device, options,
+                                  to_memory_kind);
+  }
+
+  TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));
+
+  PrimitiveType squashed_type;
+  if (options.squash_64bit_types) {
+    squashed_type = Squash64BitTypes(type);
+    if (squashed_type != type) {
+      TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype,
+                          PrimitiveTypeToNbDtype(squashed_type));
+      array = nb::steal<xla::nb_numpy_ndarray>(PyArray_CastToType(
+          reinterpret_cast<PyArrayObject*>(array.ptr()),
+          reinterpret_cast<PyArray_Descr*>(squashed_dtype.release().ptr()),
+          /*fortran=*/0));
+    }
+  } else {
+    squashed_type = type;
+  }
+
+  absl::InlinedVector<int64_t, 4> dims(array.ndim());
+  absl::InlinedVector<int64_t, 4> byte_strides(array.ndim());
+  for (int i = 0; i < array.ndim(); ++i) {
+    dims[i] = array.shape(i);
+    byte_strides[i] = array.strides(i);
+  }
+  const void* data = array.data();
+  std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
+      GlobalPyRefManager()->ManageReference(std::move(array));
+  return [client, data, squashed_type, dims = std::move(dims),
+          byte_strides = std::move(byte_strides),
+          py_buffer_ref = std::move(py_buffer_ref), options, to_device,
+          to_memory_kind]() mutable -> absl::StatusOr<DevicePutResult> {
+    TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(squashed_type));
+
+    ifrt::Client::HostBufferSemantics host_buffer_semantics =
+        ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall;
+    std::function<void()> on_done_with_host_buffer;
+    if (options.allow_zero_copy) {
+      on_done_with_host_buffer =
+          [py_buffer_ref{
+              std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
+      host_buffer_semantics =
+          ifrt::Client::HostBufferSemantics::kImmutableZeroCopy;
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        auto ifrt_array,
+        client->MakeArrayFromHostBuffer(
+            data, ifrt_dtype, ifrt::Shape(dims), byte_strides,
+            xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind),
+            host_buffer_semantics, std::move(on_done_with_host_buffer),
+            options.ifrt_user_context));
+    return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
+  };
+}
+
+absl::StatusOr<DevicePutResultFn> HandlePyArray(
+    nb::handle obj, ifrt::Client* client, ifrt::Device* to_device,
+    const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
+  auto py_array = nb::borrow<PyArray>(obj);
+
+  // We only allow single device case for PyArray in device put.
+  if (py_array.num_shards() != 1) {
+    return InvalidArgument(
+        "device_put expects an array with exactly one shard, got an array "
+        "with %d shards.",
+        py_array.num_shards());
+  }
+
+  ifrt::Array* ifrt_array = py_array.ifrt_array();
+  if (ifrt_array == nullptr) {
+    return InvalidArgument("Array has been deleted.");
+  }
+
+  // Fallback to python for non-matching clients or pmap sharding.
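+  // E.g. device_put of an array owned by a CPU client onto a device of a
+  // different (say, GPU) client cannot reuse the existing ifrt::Array;
+  // materializing `obj._value` as a host numpy array and re-uploading it is
+  // the safe path.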
+  if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() ||
+      ifrt_array->sharding().devices()->devices().front()->client() !=
+          to_device->client()) {
+    return HandleNumpyArray(obj.attr("_value"), client, to_device, options,
+                            to_memory_kind);
+  }
+
+  if (ifrt_array->sharding().devices()->devices().front() == to_device &&
+      (!to_memory_kind.memory_kind().has_value() ||
+       !ifrt_array->sharding().memory_kind().memory_kind().has_value() ||
+       ifrt_array->sharding().memory_kind() == to_memory_kind)) {
+    DevicePutResult result(tsl::FormRef(ifrt_array), py_array.weak_type(),
+                           /*owning_pybuffer=*/nb::borrow(obj));
+    return [result = std::move(result)]() mutable { return std::move(result); };
+  } else {
+    return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind,
+            owning_pybuffer = py_array.weak_type()]() mutable
+               -> absl::StatusOr<DevicePutResult> {
+      auto* ifrt_client = ifrt_array->client();
+      TF_ASSIGN_OR_RETURN(
+          auto copied_ifrt_arrays,
+          ifrt_client->CopyArrays(absl::MakeSpan(&ifrt_array, 1),
+                                  ifrt_client->MakeDeviceList({to_device}),
+                                  to_memory_kind,
+                                  ifrt::ArrayCopySemantics::kReuseInput));
+      return DevicePutResult(std::move(copied_ifrt_arrays[0]),
+                             std::move(owning_pybuffer));
+    };
+  }
+}
+
+}  // namespace
+
+absl::StatusOr<DevicePutResultFn> DevicePut(nb::handle arg,
+                                            ifrt::Client* client,
+                                            ifrt::Device* to_device,
+                                            const DevicePutOptions& options,
+                                            ifrt::MemoryKind to_memory_kind) {
+  tsl::profiler::TraceMe traceme("DevicePut");
+  static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
+      [] {
+        auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
+        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
+        // Python scalar types.
+        static_assert(sizeof(bool) == 1,
+                      "Conversion code assumes bool is 1 byte");
+        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] =
+            HandlePythonScalar<bool, bool>;
+        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandlePythonInt;
+        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] =
+            HandlePythonScalar<double, float>;
+        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
+            HandlePythonScalar<complex128, complex64>;
+
+        (*p)[reinterpret_cast<PyObject*>(&PyArray_Type)] = HandleNumpyArray;
+
+        // Numpy scalar types. For some of them, we share the handler with
+        // Python types (np_int64, np_float64, np_complex128).
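+        // For instance, np_int64 below maps to HandleNumpyScalar<int64_t,
+        // int32_t> (template arguments assumed), so under squash_64bit_types
+        // an np.int64 scalar is squashed to S32 exactly like a Python int.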
+        (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar<bool>;
+        (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar<s4>;
+        if (dtypes.np_int2.has_value()) {
+          (*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar<s2>;
+        }
+        (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar<int8_t>;
+        (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar<int16_t>;
+        (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar<int32_t>;
+        (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
+        if (dtypes.np_uint2.has_value()) {
+          (*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar<u2>;
+        }
+        (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar<u4>;
+        (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar<uint8_t>;
+        (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
+        (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
+        (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
+        if (dtypes.np_float4_e2m1fn.has_value()) {
+          (*p)[dtypes.np_float4_e2m1fn->ptr()] =
+              HandleNumpyScalar<float4_e2m1fn>;
+        }
+        if (dtypes.np_float8_e3m4.has_value()) {
+          (*p)[dtypes.np_float8_e3m4->ptr()] =
+              HandleNumpyScalar<float8_e3m4>;
+        }
+        if (dtypes.np_float8_e4m3.has_value()) {
+          (*p)[dtypes.np_float8_e4m3->ptr()] =
+              HandleNumpyScalar<float8_e4m3>;
+        }
+        (*p)[dtypes.np_float8_e4m3fn.ptr()] =
+            HandleNumpyScalar<float8_e4m3fn>;
+        (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] =
+            HandleNumpyScalar<float8_e4m3b11fnuz>;
+        (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar<float8_e5m2>;
+        (*p)[dtypes.np_float8_e4m3fnuz.ptr()] =
+            HandleNumpyScalar<float8_e4m3fnuz>;
+        (*p)[dtypes.np_float8_e5m2fnuz.ptr()] =
+            HandleNumpyScalar<float8_e5m2fnuz>;
+        if (dtypes.np_float8_e8m0fnu.has_value()) {
+          (*p)[dtypes.np_float8_e8m0fnu->ptr()] =
+              HandleNumpyScalar<float8_e8m0fnu>;
+        }
+        (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar<bfloat16>;
+        (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar<half>;
+        (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar<float>;
+        (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar<double, float>;
+        (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar<complex64>;
+        (*p)[dtypes.np_complex128.ptr()] =
+            HandleNumpyScalar<complex128, complex64>;
+        static_assert(sizeof(long long) == sizeof(int64_t),  // NOLINT
+                      "long long must be the same size as int64_t");
+        (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
+        static_assert(sizeof(int) == sizeof(int32_t),
+                      "int must be the same size as int32_t");
+        (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar<int32_t>;
+
+        return p;
+      }();
+
+  if (arg.type().ptr() == PyArray::type().ptr()) {
+    auto array = nb::borrow<PyArray>(arg);
+    return HandlePyArray(arg, client, to_device, options, to_memory_kind);
+  }
+
+  auto res = handlers->find(arg.type().ptr());
+  if (res == handlers->end()) {
+    for (auto base_class : arg.type().attr("__mro__")) {
+      res = handlers->find(base_class.ptr());
+      if (res != handlers->end()) {
+        return res->second(arg, client, to_device, options, to_memory_kind);
+      }
+    }
+    return InvalidArgument(
+        "%s", absl::StrCat(
+                  "Not supported: The C++ jax jit execution path only accepts "
+                  "DeviceArray, Numpy arrays, scalars of supported types "
+                  "(see implementation), or Python scalars. 
Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, client, to_device, options, to_memory_kind); +} + +bool IsFloat0(xla::nb_numpy_ndarray arg) { + static const auto* dtypes_module = + new nb::module_(nb::module_::import_("jax.dtypes")); + static const auto* float0_dtype = + new nb::handle(dtypes_module->attr("float0")); + return float0_dtype->is(arg.attr("dtype")); +} + +std::string PyArgSignature::DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; +} + +using ToPyArgSignatureHandler = + std::function(nb::handle, bool)>; + +absl::StatusOr PyArgSignatureOfValue(nb::handle arg, + bool jax_enable_x64) { + static const absl::flat_hash_map* const + handlers = [] { + auto p = new absl::flat_hash_map(); + + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + + // The 4 Python native types. + ToPyArgSignatureHandler bool_handler = + [](nb::handle, bool) -> absl::StatusOr { + return PyArgSignature(PrimitiveType::PRED, {}, true); + }; + ToPyArgSignatureHandler int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // TODO(phawkins): we should consider checking for integer overflow. + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, true); + } else { + return PyArgSignature(PrimitiveType::S32, {}, true); + } + }; + ToPyArgSignatureHandler float_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Only Python native types has a True weak_type. + bool weak_type = !nb::isinstance(h, dtypes.np_float64); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::F64, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::F32, {}, weak_type); + } + }; + ToPyArgSignatureHandler complex_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Note that this branch is also taken for np.complex128: + // isinstance(np.complex128(3), complex) returns True + // isinstance(np.complex64(3), complex) returns False + bool weak_type = !nb::isinstance(h, dtypes.np_complex128); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::C128, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::C64, {}, weak_type); + } + }; + + (*p)[reinterpret_cast(&PyBool_Type)] = bool_handler; + (*p)[reinterpret_cast(&PyLong_Type)] = int_handler; + (*p)[reinterpret_cast(&PyFloat_Type)] = float_handler; + (*p)[reinterpret_cast(&PyComplex_Type)] = complex_handler; + + ToPyArgSignatureHandler numpy_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + xla::nb_numpy_ndarray numpy_array = + nb::cast(h); + TF_ASSIGN_OR_RETURN(PrimitiveType dtype, + DtypeToPrimitiveType(numpy_array.dtype())); + if (!jax_enable_x64) { + dtype = Squash64BitTypes(dtype); + } + // We use reinterpret_cast<> to defend against environments where + // ssize_t may not be precisely the same type as int64_t, even if it + // is the same size (long vs long long). 
+          static_assert(sizeof(int64_t) == sizeof(ssize_t),
+                        "Code assumes ssize_t is the same as int64_t");
+          return PyArgSignature(
+              dtype,
+              absl::MakeConstSpan(
+                  reinterpret_cast<const int64_t*>(numpy_array.shape()),
+                  numpy_array.ndim()),
+              /*weak_type=*/false);
+        };
+        (*p)[reinterpret_cast<PyObject*>(&PyArray_Type)] = numpy_handler;
+
+        ToPyArgSignatureHandler np_uint64_handler =
+            [](nb::handle h,
+               bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
+          if (jax_enable_x64) {
+            return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
+          } else {
+            return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
+          }
+        };
+        ToPyArgSignatureHandler np_int_handler =
+            [](nb::handle h,
+               bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
+          if (jax_enable_x64) {
+            return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
+          } else {
+            return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
+          }
+        };
+        ToPyArgSignatureHandler numpy_array_handler =
+            [](nb::handle h,
+               bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
+          // This block deals with all numpy scalar types, except for int64_dt,
+          // float64_dt and complex128_dt which are taken care of in previous
+          // if blocks.
+          TF_ASSIGN_OR_RETURN(auto dtype,
+                              DtypeToPrimitiveType(h.attr("dtype")));
+          return PyArgSignature(dtype, {}, /*weak_type=*/false);
+        };
+
+        // This block deals with all numpy scalar types, except for int64_dt,
+        // float64_dt and complex128_dt which are taken care of in previous if
+        // blocks.
+        (*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_int64.ptr()] = np_int_handler;
+        (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
+        // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
+        // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler;
+        // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler;
+        // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler;
+        // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_float64.ptr()] = float_handler;
+        (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler;
+        (*p)[dtypes.np_complex128.ptr()] = complex_handler;
+        (*p)[dtypes.np_longlong.ptr()] = np_int_handler;
+        (*p)[dtypes.np_intc.ptr()] = numpy_array_handler;
+
+        return p;
+      }();
+
+  if (arg.type().ptr() == PyArray::type().ptr()) {
+    auto array = nb::borrow<PyArray>(arg);
+    ifrt::Array* ifrt_array = array.ifrt_array();
+    if (ifrt_array == nullptr) {
+      return xla::InvalidArgument("Array has been deleted.");
+    }
+    TF_ASSIGN_OR_RETURN(auto primitive_type,
+                        ifrt::ToPrimitiveType(ifrt_array->dtype()));
+    return PyArgSignature(primitive_type, array.shape(), array.weak_type());
+  }
+
+  auto res = handlers->find(arg.type().ptr());
+  if (res == handlers->end()) {
+    // We attempt to look at the MRO classes
+    for (auto base_class : arg.type().attr("__mro__")) {
+      res = handlers->find(base_class.ptr());
+      if (res != handlers->end()) {
+        return res->second(arg, jax_enable_x64);
+      }
+    }
+    return InvalidArgument(
+        "%s",
+        absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts "
+                     "Buffer/DeviceArray, Numpy arrays, scalars of supported "
+                     "types (see implementation), or Python scalars. Got type ",
+                     nb::cast<absl::string_view>(nb::str(arg.type()))));
+  }
+  return res->second(arg, jax_enable_x64);
+}
+
+}  // namespace xla
diff --git a/jaxlib/xla/py_values.h b/jaxlib/xla/py_values.h
new file mode 100644
index 000000000000..b64895100d8c
--- /dev/null
+++ b/jaxlib/xla/py_values.h
@@ -0,0 +1,127 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helpers for converting Python values into buffers.
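+//
+// Rough usage sketch (assumed; see DevicePut below): obtain the deferred
+// transfer function with the GIL held, then invoke it with the GIL released:
+//
+//   TF_ASSIGN_OR_RETURN(auto fn, DevicePut(obj, client, device, options,
+//                                          memory_kind));
+//   nb::gil_scoped_release release;
+//   TF_ASSIGN_OR_RETURN(DevicePutResult result, std::move(fn)());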
+
+#ifndef JAXLIB_XLA_PY_VALUES_H_
+#define JAXLIB_XLA_PY_VALUES_H_
+
+#include <cstdint>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/functional/any_invocable.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/device.h"
+#include "xla/python/ifrt/memory.h"
+#include "xla/python/ifrt/user_context.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+
+struct DevicePutResult {
+  explicit DevicePutResult(
+      tsl::RCReference<ifrt::Array> ifrt_array, bool weak_type,
+      nanobind::object owning_pybuffer = nanobind::object())
+      : ifrt_array(std::move(ifrt_array)),
+        weak_type(weak_type),
+        owning_pybuffer(owning_pybuffer) {}
+
+  // Disallow copy since copying `DevicePutResult` without holding GIL may be
+  // dangerous due to `owning_pybuffer`.
+  DevicePutResult(const DevicePutResult&) = delete;
+  DevicePutResult& operator=(const DevicePutResult&) = delete;
+  DevicePutResult(DevicePutResult&&) noexcept = default;
+  DevicePutResult& operator=(DevicePutResult&&) noexcept = default;
+
+  // Points to the on-device array.
+  tsl::RCReference<ifrt::Array> ifrt_array;
+  bool weak_type;
+
+  nanobind::object owning_pybuffer;
+};
+
+// Copies a buffer-like object to be on device.
+//
+// If `arg` is not convertible to a `PjRtBuffer` from C++, an error will be
+// returned; float0s are not supported yet.
+// If the value is known to be a PyBuffer object, py_buffer can be passed as
+// an optimization to avoid a Python->C++ cast.
+//
+// This function performs Python work inline but postpones C++ work until the
+// returned function is called. The returned function must be called after
+// releasing GIL. Useful for batching GIL release when there are many
+// device_put calls to execute.
+//
+// May throw exceptions from nanobind in addition to failing via an error
+// absl::Status. (We could catch these if needed, but there seems little
+// point.)
+struct DevicePutOptions {
+  bool squash_64bit_types = false;
+  bool allow_zero_copy = true;
+  tsl::RCReference<ifrt::UserContext> ifrt_user_context;
+};
+using DevicePutResultFn =
+    absl::AnyInvocable<absl::StatusOr<DevicePutResult>() &&>;
+absl::StatusOr<DevicePutResultFn> DevicePut(nanobind::handle arg,
+                                            ifrt::Client* client,
+                                            ifrt::Device* to_device,
+                                            const DevicePutOptions& options,
+                                            ifrt::MemoryKind to_memory_kind);
+
+// Returns `true` if `arg` is a JAX float0 array.
+bool IsFloat0(xla::nb_numpy_ndarray arg);
+
+// Describes the abstract shape and dtype of an argument.
+struct PyArgSignature {
+  PyArgSignature(PrimitiveType dtype, absl::Span<const int64_t> shape,
+                 bool weak_type)
+      : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
+  // This is the XLA dtype of the object.
+  const PrimitiveType dtype;
+  const absl::InlinedVector<int64_t, 4> shape;
+  // JAX arguments can be of weak type, if and only if they are Python scalars
+  // or `DeviceArray` values such that `aval.weak_type` is true.
+  const bool weak_type;
+  bool operator==(const PyArgSignature& other) const {
+    return std::tie(dtype, weak_type, shape) ==
+           std::tie(other.dtype, other.weak_type, other.shape);
+  }
+  bool operator!=(const PyArgSignature& other) const {
+    return !(*this == other);
+  }
+  std::string DebugString() const;
+};
+
+// Returns the PyArgSignature associated with an argument. Returns an error if
+// the argument is not supported.
+absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nanobind::handle arg,
+                                                     bool jax_enable_x64);
+
+template <typename H>
+H AbslHashValue(H h, const xla::PyArgSignature& s) {
+  h = H::combine(std::move(h), s.dtype);
+  h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size());
+  return h;
+}
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_PY_VALUES_H_
diff --git a/jaxlib/xla/sharded_device_array.h b/jaxlib/xla/sharded_device_array.h
new file mode 100644
index 000000000000..1b0ca20aa1fc
--- /dev/null
+++ b/jaxlib/xla/sharded_device_array.h
@@ -0,0 +1,217 @@
+/* Copyright 2021 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_
+#define JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_
+
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/types/variant.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/variant.h"  // IWYU pragma: keep
+#include "xla/python/types.h"
+
+// TODO(jblespiau): The current implementation moves the Python logic to C++,
+// as a preliminary step to executing the `pmap` execution path from C++.
+// It implements the current Python behavior (thus, it may not be optimal, and
+// we will be able to modify it later).
+
+namespace jax {
+
+// High level introduction.
+//
+// pmap and other parallel computation functions distribute some computation on
+// several devices. As of December 2020, the device mesh (i.e. the
+// N-dimensional array of devices on which we map the computation) is defined
+// by the user.
+//
+// We describe how to shard the inputs, and how to map them to the mesh of
+// devices using `ShardingSpec`. It's mainly based on 2 components:
+// - `sharding`, which specifies how to shard the inputs.
+// - `mesh_mapping`, which specifies how to map shards to devices.
+//
+// The 3 following structs define how to shard one dimension of an ndarray.
+//
+// `NoSharding` (`None` in Python) means no sharding.
+struct NoSharding {
+  bool operator==(const NoSharding& other) const { return true; }
+  bool operator!=(const NoSharding& other) const { return false; }
+};
+
+template <typename H>
+H AbslHashValue(H h, const NoSharding& key) {
+  return h;
+}
+
+// `Chunked` means that the dimension is split into np.prod(chunks) chunks
+// and the split dimension itself is preserved inside the map.
+// Those chunks are distributed over `len(chunks)` ShardedAxes axes
+// (major-to-minor).
+// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])]
+// (with p dividing N, let S = N // p) the tensor will be split into p chunks
+// of shape [S], such that sharded_t[k] = t[k * S: (k+1)*S] (left included,
+// right excluded) for k in {0, ..., p-1}.
+struct Chunked {
+ public:
+  explicit Chunked(std::vector<int> chunks_) : chunks(std::move(chunks_)) {}
+  // The number of chunks per axis.
+  std::vector<int> chunks;
+
+  bool operator==(const Chunked& other) const { return chunks == other.chunks; }
+  bool operator!=(const Chunked& other) const { return chunks != other.chunks; }
+};
+
+template <typename H>
+H AbslHashValue(H h, const Chunked& key) {
+  h = H::combine(std::move(h), key.chunks);
+  return h;
+}
+
+// `Unstacked` means that the dimension is split into chunks of size 1, and
+// doesn't appear inside the map. `size` is always the dimension size.
+// For example, a Tensor t of shape [N] will be sharded into N tensors of
+// shape [], when using `Unstacked(N)`.
+struct Unstacked {
+ public:
+  explicit Unstacked(int sz) : size(sz) {}
+  int size;
+
+  bool operator==(const Unstacked& other) const { return size == other.size; }
+  bool operator!=(const Unstacked& other) const { return size != other.size; }
+};
+
+template <typename H>
+H AbslHashValue(H h, const Unstacked& key) {
+  h = H::combine(std::move(h), key.size);
+  return h;
+}
+
+using AvalDimSharding = std::variant<NoSharding, Chunked, Unstacked>;
+
+// Assigns sharded axes to mesh dimensions.
+//
+// Each mesh dimension is paired with one sharded axis of the input (via
+// `ShardedAxis`); when no axis is assigned, the data is replicated along that
+// mesh dimension.
+// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually
+// sharded axis (i.e. counting as if the None dimensions of the sharding were
+// filtered out).
+// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry
+// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`.
+
+struct ShardedAxis {
+  int axis;
+  bool operator==(const ShardedAxis& other) const { return axis == other.axis; }
+  bool operator!=(const ShardedAxis& other) const { return axis != other.axis; }
+};
+
+template <typename H>
+H AbslHashValue(H h, const ShardedAxis& key) {
+  h = H::combine(std::move(h), key.axis);
+  return h;
+}
+
+struct Replicated {
+  int replicas;
+  bool operator==(const Replicated& other) const {
+    return replicas == other.replicas;
+  }
+  bool operator!=(const Replicated& other) const {
+    return replicas != other.replicas;
+  }
+};
+
+template <typename H>
+H AbslHashValue(H h, const Replicated& key) {
+  h = H::combine(std::move(h), key.replicas);
+  return h;
+}
+
+using MeshDimAssignment = std::variant<ShardedAxis, Replicated>;
+
+// Describes how each axis is sharded (if it is), and how it's mapped to the
+// devices mesh. See Jax pxla.py for the documentation.
+//
+// ShardingSpec is shared across pmap, pjit and xmap. For pmap, an input
+// `sharding` is composed of `NoSharding` and at most one `Unstacked`.
+// If `axis_size=None`, at least one of the inputs has a dimension associated
+// to `Unstacked`.
+//
+// Examples:
+//
+// 1. For pmap, with a tensor of shape [8, 2, 2], to unstack along the first
+//    dimension into [8] devices:
+//
+//    sharding = [Unstacked(8), NoSharding, NoSharding]
+//    mesh_mapping = [ShardedAxis(0)]
+//
+// 2. With an input array of shape [6], that we want to chunk into [2, 3].
+//    Assuming a device mesh [3, 4, 2] of devices, we will have:
+//
+//    sharding = [Chunked([2, 3])]
+//    mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)]
+//
+//    In particular, in the above example, the ShardedAxis refers to indices
+//    of the sharded shape [2, 3]. (Only the `Chunked` sharding can produce
+//    more than one dimension.)
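+//
+// For example, case 1 above could be built from C++ roughly as:
+//
+//   ShardingSpec spec(
+//       /*sharding=*/{Unstacked(8), NoSharding(), NoSharding()},
+//       /*mesh_mapping=*/{ShardedAxis{0}});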
+class ShardingSpec {
+ public:
+  ShardingSpec(std::vector<AvalDimSharding> sharding,
+               std::vector<MeshDimAssignment> mesh_mapping)
+      : sharding_(std::move(sharding)),
+        mesh_mapping_(std::move(mesh_mapping)) {}
+  ShardingSpec(nanobind::iterable py_sharding,
+               nanobind::iterable py_mesh_mapping)
+      : sharding_(xla::IterableToVector<AvalDimSharding>(py_sharding)),
+        mesh_mapping_(
+            xla::IterableToVector<MeshDimAssignment>(py_mesh_mapping)) {}
+
+  const std::vector<AvalDimSharding>& GetSharding() const { return sharding_; }
+  const std::vector<MeshDimAssignment>& GetMeshMapping() const {
+    return mesh_mapping_;
+  }
+
+  bool operator==(const ShardingSpec& other) const {
+    return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_;
+  }
+
+  bool operator!=(const ShardingSpec& other) const { return !(*this == other); }
+
+  template <typename H>
+  friend H AbslHashValue(H h, const ShardingSpec& key);
+
+ private:
+  // `sharding` specifies how the array is supposed to get partitioned into
+  // chunks. Its length matches the rank of the array. See the docstring
+  // of `AvalDimSharding` for the supported partitioning schemes.
+  std::vector<AvalDimSharding> sharding_;
+  // `mesh_mapping` describes an assignment of the array chunks created by
+  // `sharding` to a logical device mesh. The length of the tuple is equal to
+  // the rank of the mesh. Each mesh dimension can either get partitions of
+  // data varying along one of the sharded dimensions, or the data can be
+  // replicated.
+  std::vector<MeshDimAssignment> mesh_mapping_;
+};
+
+template <typename H>
+H AbslHashValue(H h, const ShardingSpec& key) {
+  h = H::combine(std::move(h), key.sharding_);
+  h = H::combine(std::move(h), key.mesh_mapping_);
+  return h;
+}
+
+}  // namespace jax
+
+#endif  // JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_
diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc
new file mode 100644
index 000000000000..9952c31bd393
--- /dev/null
+++ b/jaxlib/xla/sharding.cc
@@ -0,0 +1,346 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/sharding.h"
+
+#include <Python.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <utility>
+
+#include "absl/hash/hash.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/string.h"  // IWYU pragma: keep
+#include "nanobind/stl/string_view.h"  // IWYU pragma: keep
+#include "jaxlib/xla/py_client.h"
+#include "jaxlib/xla/py_device_list.h"
+#include "jaxlib/xla/sharded_device_array.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/nb_class_ptr.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/xla_data.pb.h"
+
+namespace jax {
+
+namespace nb = nanobind;
+
+// Gets `jax::PyDeviceList` from a JAX Sharding.
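+// Known sharding types are unwrapped directly in C++; anything else (for
+// example a user-defined Sharding subclass) falls back to reading the Python
+// `_internal_device_list` attribute, as the body below shows.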
+absl::StatusOr<xla::nb_class_ptr<jax::PyDeviceList>> GetPyDeviceList(
+    nb::handle sharding_py) {
+  nb::handle sharding(sharding_py.ptr());
+  if (sharding.type().is(jax::NamedSharding::type())) {
+    TF_ASSIGN_OR_RETURN(
+        auto ns_device_list,
+        nb::cast<const jax::NamedSharding*>(sharding)->internal_device_list());
+    return ns_device_list;
+  } else if (sharding.type().is(jax::SingleDeviceSharding::type())) {
+    return nb::cast<const jax::SingleDeviceSharding*>(sharding)
+        ->internal_device_list();
+  } else if (sharding.type().is(jax::PmapSharding::type())) {
+    return nb::cast<const jax::PmapSharding*>(sharding)
+        ->internal_device_list();
+  } else if (sharding.type().is(jax::GSPMDSharding::type())) {
+    return nb::cast<const jax::GSPMDSharding*>(sharding)
+        ->internal_device_list();
+  } else {
+    return nb::cast<xla::nb_class_ptr<jax::PyDeviceList>>(
+        sharding.attr("_internal_device_list"));
+  }
+}
+
+nb::object CheckAndCanonicalizeMemoryKind(
+    nb::object memory_kind,
+    const xla::nb_class_ptr<PyDeviceList>& device_list) {
+  if (!memory_kind.is_none()) {
+    // If memory kind is not None, check if it's supported by the devices
+    // mentioned in the Sharding.
+    auto supported_memory_kinds = PyDeviceList::MemoryKinds(device_list);
+    if (!supported_memory_kinds.ok()) {
+      supported_memory_kinds = nb::tuple();
+    }
+    for (nb::handle supported_memory_kind : *supported_memory_kinds) {
+      if (supported_memory_kind.equal(memory_kind)) {
+        return memory_kind;
+      }
+    }
+    auto addressable_device_list =
+        PyDeviceList::AddressableDeviceList(device_list);
+    if (addressable_device_list->Len() == 0) {
+      // If the device list is not addressable, we can't check if the memory
+      // kind is supported, so we assume it is.
+      return memory_kind;
+    }
+    nb::object device_kind =
+        addressable_device_list->GetItem(0).attr("device_kind");
+    absl::string_view device_kind_str =
+        nb::cast<absl::string_view>(device_kind);
+    auto py_str_formatter = [](std::string* out, nb::handle h) {
+      *out += nb::cast<absl::string_view>(nb::str(h));
+    };
+    throw nb::value_error(
+        absl::StrCat(
+            "Could not find memory addressable by device ", device_kind_str,
+            ". Device ", device_kind_str,
+            " can address the following memory kinds: ",
+            absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter),
+            ". Got memory kind: ", nb::cast<absl::string_view>(memory_kind))
+            .c_str());
+  }
+  // If memory kind is None, canonicalize to default memory.
+  absl::StatusOr<nb::object> default_memory_kind =
+      PyDeviceList::DefaultMemoryKind(device_list);
+  if (!default_memory_kind.ok()) {
+    return nb::none();
+  }
+  return *std::move(default_memory_kind);
+}
+
+int Sharding::SafeNumDevices(nb::handle sharding) {
+  const jax::Sharding* cpp_sharding;
+  if (nb::try_cast(sharding, cpp_sharding)) {
+    if (cpp_sharding->num_devices_.has_value()) {
+      return (*cpp_sharding->num_devices_);
+    }
+  }
+  nb::set device_set = sharding.attr("device_set");
+  return device_set.size();
+}
+
+size_t ShardingHash(nb::handle sharding) {
+  auto type = sharding.type();
+
+  if (type.is(NamedSharding::type())) {
+    const auto* named_sharding = nb::inst_ptr<NamedSharding>(sharding);
+    return absl::Hash<void*>()(named_sharding->mesh().ptr());
+  }
+
+  if (type.is(GSPMDSharding::type())) {
+    auto* gspmd_sharding = nb::inst_ptr<GSPMDSharding>(sharding);
+    return gspmd_sharding->Hash();
+  }
+
+  if (type.is(SingleDeviceSharding::type())) {
+    auto* single_device_sharding =
+        nb::inst_ptr<SingleDeviceSharding>(sharding);
+    return absl::Hash<void*>()(single_device_sharding->device().ptr());
+  }
+
+  return nb::hash(sharding);
+}
+
+bool ShardingEqual(nb::handle a, nb::handle b) {
+  if (a.ptr() == b.ptr()) return true;
+
+  auto a_type = a.type();
+  auto b_type = b.type();
+
+  if (!a_type.is(b_type)) return false;
+
+  if (a_type.is(NamedSharding::type())) {
+    auto* a_named_sharding = nb::inst_ptr<NamedSharding>(a);
+    auto* b_named_sharding = nb::inst_ptr<NamedSharding>(b);
+
+    return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() &&
+           a_named_sharding->spec().equal(b_named_sharding->spec()) &&
+           a_named_sharding->memory_kind().equal(
+               b_named_sharding->memory_kind()) &&
+           a_named_sharding->manual_axes().equal(
+               b_named_sharding->manual_axes()) &&
+           a_named_sharding->logical_device_ids().equal(
+               b_named_sharding->logical_device_ids());
+  }
+
+  if (a_type.is(GSPMDSharding::type())) {
+    auto* a_gspmd_sharding = nb::inst_ptr<GSPMDSharding>(a);
+    auto* b_gspmd_sharding = nb::inst_ptr<GSPMDSharding>(b);
+
+    return *a_gspmd_sharding == *b_gspmd_sharding;
+  }
+
+  if (a_type.is(SingleDeviceSharding::type())) {
+    auto* a_single_device_sharding =
+        nb::inst_ptr<SingleDeviceSharding>(a);
+    auto* b_single_device_sharding =
+        nb::inst_ptr<SingleDeviceSharding>(b);
+
+    return a_single_device_sharding->device().ptr() ==
+               b_single_device_sharding->device().ptr() &&
+           a_single_device_sharding->memory_kind().equal(
+               b_single_device_sharding->memory_kind());
+  }
+
+  return a.equal(b);
+}
+
+NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
+                             nb::object memory_kind, nb::object manual_axes,
+                             nb::object logical_device_ids)
+    : Sharding(/*num_devices=*/[&mesh]() {
+        return nb::cast<int>(mesh.attr("size"));
+      }()),
+      mesh_(std::move(mesh)),
+      spec_(std::move(spec)),
+      memory_kind_(std::move(memory_kind)),
+      manual_axes_(std::move(manual_axes)),
+      logical_device_ids_(std::move(logical_device_ids)) {
+  if (spec_.is_none()) {
+    throw nb::type_error(
+        "Unexpected None passed as spec for NamedSharding. Did you mean P()?");
+  }
+  nb::object idl = nb::object(mesh_.attr("_internal_device_list"));
+  if (idl.is_none()) {
+    internal_device_list_ = std::nullopt;
+  } else {
+    internal_device_list_ =
+        nb::cast<xla::nb_class_ptr<jax::PyDeviceList>>(idl);
+  }
+  if (internal_device_list_) {
+    memory_kind_ =
+        CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_);
+  } else {
+    memory_kind_ = nb::none();
+  }
+
+  // TODO(phawkins): this leaks a reference to the check_pspec function.
+  // A better way to fix this would be to move PartitionSpec and this check
+  // into C++.
+  static nb::object* check_pspec = []() {
+    nb::module_ si = nb::module_::import_("jax._src.named_sharding");
+    return new nb::object(si.attr("check_pspec"));
+  }();
+  (*check_pspec)(mesh_, spec_, manual_axes_);
+}
+
+SingleDeviceSharding::SingleDeviceSharding(nb::object device,
+                                           nb::object memory_kind)
+    : Sharding(/*num_devices=*/1),
+      device_(device),
+      memory_kind_(std::move(memory_kind)),
+      internal_device_list_(xla::make_nb_class<PyDeviceList>(
+          nb::make_tuple(std::move(device)))) {
+  memory_kind_ =
+      CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
+}
+
+SingleDeviceSharding::SingleDeviceSharding(
+    xla::nb_class_ptr<xla::PyClient> client,
+    xla::ifrt::DeviceListRef device_list, nb::object memory_kind)
+    : Sharding(/*num_devices=*/1),
+      device_(client->GetPyDevice(device_list->devices().front())),
+      memory_kind_(std::move(memory_kind)),
+      internal_device_list_(xla::make_nb_class<PyDeviceList>(
+          std::move(client), std::move(device_list))) {
+  memory_kind_ =
+      CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
+}
+
+PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
+                           ShardingSpec sharding_spec)
+    : Sharding(/*num_devices=*/devices.size()),
+      devices_(std::move(devices)),
+      sharding_spec_(std::move(sharding_spec)) {
+  nb::object flat_devices = devices_.attr("flat");
+  internal_device_list_ =
+      xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices));
+}
+
+GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
+                             nb::object memory_kind, nb::object device_list)
+    : Sharding(/*num_devices=*/nb::len(devices.ptr())),
+      devices_(nb::tuple(devices)),
+      hlo_sharding_(std::move(op_sharding)),
+      memory_kind_(std::move(memory_kind)) {
+  if (device_list.is_none()) {
+    internal_device_list_ = xla::make_nb_class<PyDeviceList>(devices_);
+  } else {
+    internal_device_list_ =
+        nb::cast<xla::nb_class_ptr<jax::PyDeviceList>>(std::move(device_list));
+  }
+  // This checks in python if the memory kind is correct for the given
+  // devices. Currently in python this check is optimized but we want to
+  // move that check to C++, after which we can remove this call.
+  CHECK(devices_.size() != 0)
+      << "Devices given to GSPMDSharding must not be empty";
+  memory_kind_ =
+      CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
+}
+
+void RegisterSharding(nb::module_& m) {
+  nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
+
+  nb::class_<NamedSharding, Sharding>(m, "NamedSharding", nb::dynamic_attr())
+      .def(nb::init<nb::object, nb::object, nb::object, nb::object,
+                    nb::object>(),
+           nb::arg("mesh"), nb::arg("spec").none(),
+           nb::arg("memory_kind").none() = nb::none(),
+           nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)),
+           nb::arg("_logical_device_ids").none() = nb::none())
+      .def_prop_ro("mesh", &NamedSharding::mesh)
+      .def_prop_ro("spec", &NamedSharding::spec)
+      .def_prop_ro("_memory_kind", &NamedSharding::memory_kind)
+      .def_prop_ro("_manual_axes", &NamedSharding::manual_axes)
+      .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids)
+      .def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
+        return xla::ValueOrThrow(s.internal_device_list());
+      });
+
+  nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
+                                             nb::dynamic_attr())
+      .def(nb::init<nb::object, nb::object>(), nb::arg("device"),
+           nb::arg("memory_kind").none() = nb::none())
+      .def_prop_ro("_device", &SingleDeviceSharding::device)
+      .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
+      .def_prop_ro("_internal_device_list",
+                   &SingleDeviceSharding::internal_device_list);
+
+  nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
+      .def(
+          "__init__",
+          [](PmapSharding* self, nb::object devices,
+             ShardingSpec sharding_spec) {
+            new (self) PmapSharding(xla::nb_numpy_ndarray::ensure(devices),
+                                    std::move(sharding_spec));
+          },
+          nb::arg("devices"), nb::arg("sharding_spec"))
+      .def_prop_ro("devices", &PmapSharding::devices)
+      .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
+      .def_prop_ro("_internal_device_list",
+                   &PmapSharding::internal_device_list);
+
+  nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
+      .def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
+           nb::arg("devices"), nb::arg("op_sharding"),
+           nb::arg("memory_kind").none() = nb::none(),
+           nb::arg("_device_list").none() = nb::none())
+      .def(nb::init<nb::sequence, xla::HloSharding, nb::object, nb::object>(),
+           nb::arg("devices"), nb::arg("op_sharding"),
+           nb::arg("memory_kind").none() = nb::none(),
+           nb::arg("_device_list").none() = nb::none())
+      .def_prop_ro("_devices", &GSPMDSharding::devices)
+      .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding)
+      .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
+      .def_prop_ro("_internal_device_list",
+                   &GSPMDSharding::internal_device_list);
+}
+
+}  // namespace jax
diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h
new file mode 100644
index 000000000000..572a6cd3c86e
--- /dev/null
+++ b/jaxlib/xla/sharding.h
@@ -0,0 +1,242 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_SHARDING_H_
+#define JAXLIB_XLA_SHARDING_H_
+
+#include <cstddef>
+#include <optional>
+#include <utility>
+
+// placeholder for index annotation headers
+#include "absl/hash/hash.h"
+#include "absl/status/statusor.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/xla/py_client.h"
+#include "jaxlib/xla/py_device_list.h"
+#include "jaxlib/xla/sharded_device_array.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/python/ifrt/device.h"
+#include "xla/python/nb_class_ptr.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/xla_data.pb.h"
+
+namespace jax {
+
+class Sharding {
+ public:
+  Sharding() = default;
+
+  // This constructor is used in the fast path to retrieve the number of
+  // devices without falling back to python. This is only used in the cpp
+  // path.
+  explicit Sharding(int num_devices) : num_devices_(num_devices) {}
+
+  virtual ~Sharding() = default;
+
+  static int SafeNumDevices(nanobind::handle sharding);
+
+ private:
+  std::optional<int> num_devices_;
+};
+
+// Gets `jax::PyDeviceList` from a JAX Sharding.
+absl::StatusOr<xla::nb_class_ptr<jax::PyDeviceList>> GetPyDeviceList(
+    nanobind::handle sharding_py);
+
+// Checks if the memory kind is valid, and canonicalizes the
+// memory kind to default memory on backends that support memories.
+nanobind::object CheckAndCanonicalizeMemoryKind(
+    nanobind::object memory_kind,
+    const xla::nb_class_ptr<PyDeviceList>& device_list);
+
+// Returns a hash that may sometimes return different hashes for equal values.
+// It is not a correct implementation of `__hash__` in python, but it's fine
+// for jit/pjit dispatch since it only causes spurious cache misses.
+size_t ShardingHash(nanobind::handle sharding);
+
+bool ShardingEqual(nanobind::handle a, nanobind::handle b);
+
+class NamedSharding : public Sharding {
+ public:
+  NamedSharding(nanobind::object mesh, nanobind::object spec,
+                nanobind::object memory_kind, nanobind::object manual_axes,
+                nanobind::object logical_device_ids);
+
+  const nanobind::object& mesh() const { return mesh_; }
+  const nanobind::object& spec() const { return spec_; }
+  const nanobind::object& memory_kind() const { return memory_kind_; }
+  const nanobind::object& manual_axes() const { return manual_axes_; }
+  const nanobind::object& logical_device_ids() const {
+    return logical_device_ids_;
+  }
+
+  static nanobind::handle type() {
+    static auto type = nanobind::type<NamedSharding>();
+    return type;
+  }
+
+  absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list()
+      const {
+    if (internal_device_list_) {
+      return *internal_device_list_;
+    }
+    return xla::InvalidArgument(
+        "internal_device_list is not implemented for "
+        "`jax.sharding.AbstractMesh`");
+  }
+
+ private:
+  nanobind::object mesh_;
+  nanobind::object spec_;
+  nanobind::object memory_kind_;
+  nanobind::object manual_axes_;
+  nanobind::object logical_device_ids_;
+  std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
+};
+
+class SingleDeviceSharding : public Sharding {
+ public:
+  explicit SingleDeviceSharding(
+      nanobind::object device, nanobind::object memory_kind = nanobind::none());
+
+  // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`.
+  SingleDeviceSharding(xla::nb_class_ptr<xla::PyClient> client,
+                       xla::ifrt::DeviceListRef device_list,
+                       nanobind::object memory_kind);
+
+  const nanobind::object& device() const { return device_; }
+  const nanobind::object& memory_kind() const { return memory_kind_; }
+
+  static nanobind::handle type() {
+    static auto type = nanobind::type<SingleDeviceSharding>();
+    return type;
+  }
+
+  xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
+    return internal_device_list_;
+  }
+
+ private:
+  nanobind::object device_;
+  nanobind::object memory_kind_;
+  xla::nb_class_ptr<PyDeviceList> internal_device_list_;
+};
+
+// The C++ implementation of jax.PmapSharding in python. It contains a few
+// key data members and methods that are performance-critical.
+class PmapSharding : public Sharding {
+ public:
+  PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec);
+
+  ~PmapSharding() override = default;
+
+  xla::nb_numpy_ndarray devices() const { return devices_; }
+
+  const ShardingSpec& sharding_spec() const { return sharding_spec_; }
+
+  static nanobind::handle type() {
+    static auto type = nanobind::type<PmapSharding>();
+    return type;
+  }
+
+  xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
+    return internal_device_list_;
+  }
+
+ private:
+  xla::nb_numpy_ndarray devices_;
+  ShardingSpec sharding_spec_;
+  xla::nb_class_ptr<PyDeviceList> internal_device_list_;
+};
+
+class GSPMDSharding : public Sharding {
+ public:
+  GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding,
+                nanobind::object memory_kind, nanobind::object device_list)
+      : GSPMDSharding(
+            std::move(devices),
+            xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)),
+            std::move(memory_kind), std::move(device_list)) {}
+
+  GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding,
+                nanobind::object memory_kind, nanobind::object device_list);
+
+  const nanobind::tuple& devices() const { return devices_; }
+  const nanobind::object& memory_kind() const { return memory_kind_; }
+
+  size_t Hash() {
+    if (!hash_.has_value()) {
+      hash_ = CalculateHash();
+    }
+    return *hash_;
+  }
+
+  static nanobind::handle type() {
+    static auto type = nanobind::type<GSPMDSharding>();
+    return type;
+  }
+
+  const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
+
+  bool operator==(const GSPMDSharding& other) const {
+    return AreOpShardingsEqual(*this, other) &&
+           this->devices().equal(other.devices()) &&
+           this->memory_kind().equal(other.memory_kind());
+  }
+
+  xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
+    return internal_device_list_;
+  }
+
+ private:
+  size_t CalculateHash() const {
+    // We only hash `hlo_sharding_` here for performance.
+    return absl::Hash<xla::HloSharding>()(hlo_sharding_);
+  }
+
+  static bool AreOpShardingsEqual(const GSPMDSharding& a,
+                                  const GSPMDSharding& b) {
+    // If the OpSharding object is the same, return true
+    if (&a.hlo_sharding() == &b.hlo_sharding()) {
+      return true;
+    }
+    // If both OpShardings are replicated, return true
+    if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) {
+      return true;
+    }
+    return a.hlo_sharding() == b.hlo_sharding();
+  }
+
+  bool IsOpShardingReplicated() const {
+    // For JAX, shardings with 1 device are considered as replicated in its
+    // semantics so that downstream things continue to work.
+    if (hlo_sharding_.tile_assignment().num_elements() == 1) {
+      return true;
+    }
+    return hlo_sharding().IsReplicated();
+  }
+
+  nanobind::tuple devices_;
+  xla::HloSharding hlo_sharding_;
+  nanobind::object memory_kind_;
+  std::optional<size_t> hash_;
+  xla::nb_class_ptr<PyDeviceList> internal_device_list_;
+};
+
+void RegisterSharding(nanobind::module_& m);
+
+}  // namespace jax
+
+#endif  // JAXLIB_XLA_SHARDING_H_
diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc
new file mode 100644
index 000000000000..96ec9c77071d
--- /dev/null
+++ b/jaxlib/xla/to_ifrt_sharding.cc
@@ -0,0 +1,141 @@
+/* Copyright 2025 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/to_ifrt_sharding.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/string.h"  // IWYU pragma: keep
+#include "jaxlib/xla/py_device_list.h"
+#include "jaxlib/xla/sharding.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/ifrt/dtype.h"
+#include "xla/python/ifrt/memory.h"
+#include "xla/python/ifrt/shape.h"
+#include "xla/python/ifrt/sharding.h"
+#include "xla/python/nb_class_ptr.h"
+#include "xla/python/pjrt_ifrt/pjrt_dtype.h"
+#include "xla/python/pjrt_ifrt/xla_sharding.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+
+namespace nb = ::nanobind;
+
+// Gets `xla::HloSharding` from a JAX Sharding.
+xla::HloSharding GetXlaHloSharding(nb::handle sharding,
+                                   int64_t num_dimensions) {
+  if (sharding.type().is(nb::handle(jax::GSPMDSharding::type().ptr()))) {
+    return nb::cast<jax::GSPMDSharding*>(nb::handle(sharding.ptr()))
+        ->hlo_sharding();
+  } else {
+    return nb::cast<xla::HloSharding>(
+        sharding.attr("_to_xla_hlo_sharding")(num_dimensions));
+  }
+}
+
+// Gets `xla::ifrt::DeviceList` from a JAX Sharding.
+absl::StatusOr<xla::ifrt::DeviceListRef> GetIfrtDeviceList(
+    nb::handle sharding_py) {
+  TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding_py));
+  return py_device_list->ifrt_device_list();
+}
+
+// Gets `xla::ifrt::MemoryKind` from a JAX Sharding.
+xla::ifrt::MemoryKind GetMemoryKind(nb::handle sharding) {
+  nb::object py_memory_kind = nb::none();
+
+  // sharding.attr("memory_kind") can crash if sharding was originally created
+  // from C++ and casted into a Python Sharding object. Thus, we cast sharding
+  // to a C++ type and use the C++ `memory_kind()` method, which bypasses any
+  // Python attribute access.
+  nb::handle type = sharding.type();
+  if (type.is(jax::NamedSharding::type())) {
+    py_memory_kind =
+        nb::cast<const jax::NamedSharding*>(sharding)->memory_kind();
+  } else if (type.is(jax::SingleDeviceSharding::type())) {
+    py_memory_kind =
+        nb::cast<const jax::SingleDeviceSharding*>(sharding)->memory_kind();
+  } else if (type.is(jax::GSPMDSharding::type())) {
+    py_memory_kind =
+        nb::cast<const jax::GSPMDSharding*>(sharding)->memory_kind();
+  } else {
+    py_memory_kind = sharding.attr("memory_kind");
+  }
+
+  if (py_memory_kind.is_none()) {
+    return xla::ifrt::MemoryKind();
+  }
+  return xla::ifrt::MemoryKind(nb::cast<std::string>(py_memory_kind));
+}
+
+// Converts a JAX Sharding into `xla::ifrt::HloSharding`.
+absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>> GetIfrtHloSharding(
+    nb::handle sharding, const xla::ifrt::Shape& shape) {
+  TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list,
+                      GetIfrtDeviceList(sharding));
+  xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr());
+  xla::HloSharding hlo_sharding =
+      GetXlaHloSharding(sharding, shape.dims().size());
+  return xla::ifrt::HloSharding::Create(
+      std::move(device_list), std::move(memory_kind), std::move(hlo_sharding));
+}
+
+// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`.
+absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
+GetIfrtConcreteEvenSharding(nb::handle sharding, xla::ifrt::DType dtype,
+                            const xla::ifrt::Shape& shape) {
+  TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list,
+                      GetIfrtDeviceList(sharding));
+  xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr());
+  TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_primitive_type,
+                      xla::ifrt::ToPrimitiveType(dtype));
+  // The XLA shape's layout is irrelevant because we only need to know the
+  // tile shape, which is independent from the layout.
+  xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
+      xla_primitive_type, shape.dims());
+  xla::HloSharding hlo_sharding =
+      GetXlaHloSharding(sharding, shape.dims().size());
+  xla::Shape tile_shape = hlo_sharding.TileShape(xla_shape);
+  xla::ifrt::Shape shard_shape(xla::ifrt::Shape::Dimensions(
+      tile_shape.dimensions().begin(), tile_shape.dimensions().end()));
+  return xla::ifrt::ConcreteEvenSharding::Create(
+      std::move(device_list), std::move(memory_kind), shape,
+      /*shard_shape=*/std::move(shard_shape));
+}
+
+// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`.
+absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
+GetIfrtConcreteSharding(nb::handle sharding, const xla::ifrt::Shape& shape,
+                        std::vector<xla::ifrt::Shape> shard_shapes) {
+  TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list,
+                      GetIfrtDeviceList(sharding));
+  xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr());
+  return xla::ifrt::ConcreteSharding::Create(
+      std::move(device_list), std::move(memory_kind), shape,
+      /*shard_shapes=*/std::move(shard_shapes));
+}
+
+}  // namespace xla
diff --git a/jaxlib/xla/to_ifrt_sharding.h b/jaxlib/xla/to_ifrt_sharding.h
new file mode 100644
index 000000000000..0fa7f17c4563
--- /dev/null
+++ b/jaxlib/xla/to_ifrt_sharding.h
@@ -0,0 +1,56 @@
+/* Copyright 2025 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_TO_IFRT_SHARDING_H_
+#define JAXLIB_XLA_TO_IFRT_SHARDING_H_
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "nanobind/nanobind.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/ifrt/dtype.h"
+#include "xla/python/ifrt/sharding.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+
+// Gets `xla::HloSharding` from a JAX Sharding.
+xla::HloSharding GetXlaHloSharding(nanobind::handle sharding,
+                                   int64_t num_dimensions);
+
+// Gets `xla::ifrt::DeviceList` from a JAX Sharding.
+absl::StatusOr<xla::ifrt::DeviceListRef> GetIfrtDeviceList(
+    nanobind::handle sharding_py);
+
+// Gets `xla::ifrt::MemoryKind` from a JAX Sharding.
+xla::ifrt::MemoryKind GetMemoryKind(nanobind::handle sharding);
+
+// Converts a JAX Sharding into `xla::ifrt::HloSharding`.
+absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>> GetIfrtHloSharding(
+    nanobind::handle sharding, const xla::ifrt::Shape& shape);
+
+// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`.
+absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
+GetIfrtConcreteEvenSharding(nanobind::handle sharding, xla::ifrt::DType dtype,
+                            const xla::ifrt::Shape& shape);
+
+// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`.
+absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
+GetIfrtConcreteSharding(nanobind::handle sharding,
+                        const xla::ifrt::Shape& shape,
+                        std::vector<xla::ifrt::Shape> shard_shapes);
+
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_TO_IFRT_SHARDING_H_
diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc
index 54c94c57a734..0e1ba031670f 100644
--- a/jaxlib/xla/xla.cc
+++ b/jaxlib/xla/xla.cc
@@ -46,6 +46,9 @@ limitations under the License.
 #include "nanobind/stl/unique_ptr.h"  // IWYU pragma: keep
 #include "nanobind/stl/variant.h"  // IWYU pragma: keep
 #include "nanobind/stl/vector.h"  // IWYU pragma: keep
+#include "jaxlib/xla/ifrt_proxy.h"
+#include "jaxlib/xla/py_client.h"
+#include "jaxlib/xla/py_program.h"
 #include "jaxlib/xla/sdy.h"
 #include "xla/backends/cpu/collectives/cpu_collectives.h"
 #include "xla/pjrt/c/pjrt_c_api.h"
@@ -61,10 +64,7 @@ limitations under the License.
 #include "xla/python/ifrt/device_list.h"
 #include "xla/python/ifrt/executable.h"
 #include "xla/python/ifrt/topology.h"
-#include "xla/python/ifrt_proxy/client/py_module.h"
 #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h"
-#include "xla/python/py_client.h"
-#include "xla/python/py_program.h"
 #include "xla/tsl/concurrency/ref_count.h"
 #include "xla/tsl/python/lib/core/numpy.h"  // NOLINT
@@ -91,7 +91,14 @@ limitations under the License.
 #include "jaxlib/xla/mlir.h"
 #include "jaxlib/xla/pjit.h"
 #include "jaxlib/xla/pmap_lib.h"
+#include "jaxlib/xla/py_array.h"
+#include "jaxlib/xla/py_compile_only_client.h"
+#include "jaxlib/xla/py_device.h"
+#include "jaxlib/xla/py_device_list.h"
+#include "jaxlib/xla/py_executable.h"
+#include "jaxlib/xla/py_memory_space.h"
 #include "jaxlib/xla/pytree.h"
+#include "jaxlib/xla/sharding.h"
 #include "jaxlib/xla/weakref_lru_cache.h"
 #include "jaxlib/xla/xla_compiler.h"
 #include "xla/pjrt/distributed/key_value_store_interface.h"
@@ -113,14 +120,7 @@ limitations under the License.
#include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" -#include "xla/python/py_array.h" -#include "xla/python/py_compile_only_client.h" -#include "xla/python/py_device.h" -#include "xla/python/py_device_list.h" -#include "xla/python/py_executable.h" -#include "xla/python/py_memory_space.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharding.h" #include "xla/python/traceback.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/tsl/platform/status.h" diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc index f4719b450988..00f8b4c295a7 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla/xla_compiler.cc @@ -44,6 +44,7 @@ limitations under the License. #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/dlpack.h" +#include "jaxlib/xla/py_client.h" #include "xla/array.h" #include "xla/client/executable_build_options.h" #include "xla/debug_options_flags.h" @@ -71,7 +72,6 @@ limitations under the License. #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_client.h" #include "xla/python/types.h" #include "xla/service/call_inliner.h" #include "xla/service/computation_placer.h" From 0a53c9aad23e4b63843c64f6b1af3652a22a16e4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 25 Mar 2025 14:10:26 -0700 Subject: [PATCH 0158/1769] [pallas:mosaic_gpu] Updated the tests to use `plgpu.kernel` It leads to much more compact kernel definitions, just look at the diff! The combination of `pl.core_map` and `pl.run_state` is too noisy to easily follow the kernel logic. 
PiperOrigin-RevId: 740479934 --- jax/_src/pallas/mosaic_gpu/core.py | 2 +- jax/_src/pallas/mosaic_gpu/pipeline.py | 6 +- tests/pallas/mosaic_gpu_test.py | 336 ++++++++++++------------- 3 files changed, 161 insertions(+), 183 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 1e4a9de1830c..99a84962ae50 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -120,7 +120,7 @@ def __call__( return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) -def kernel(body, out_shape, compiler_params=None, **mesh_kwargs): +def kernel(body, out_shape, *, compiler_params=None, **mesh_kwargs): if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) def wrapper(*operands): diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index a48fec61b7af..d85ba4ae2a03 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -360,8 +360,8 @@ def emit_pipeline_warp_specialized( *, grid: pallas_core.StaticGrid, memory_registers: int, - in_specs: Sequence[gpu_core.GPUBlockSpec] = (), - out_specs: Sequence[gpu_core.GPUBlockSpec] = (), + in_specs: Sequence[pl.BlockSpec] = (), + out_specs: Sequence[pl.BlockSpec] = (), max_concurrent_steps: int = 2, wg_axis: str, num_compute_wgs: int, @@ -458,7 +458,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): gpu_core.SMEM( (slots, *spec.block_shape), # type: ignore gmem_ref.dtype, - transforms=spec.transforms, + transforms=getattr(spec, "transforms", ()), ) ) in_smem_refs, out_smem_refs = util.split_list( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b39288252e08..7f8cfa21e980 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1624,23 +1624,28 @@ def scope(acc_ref): class PallasCallSm100ATest(PallasSm100ATest): def test_tmem_alloc(self): - mesh = plgpu.GPUMesh(num_threads=1, axis_names=("x")) - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def _(): - def scope(tmem_ref, smem_ref): - # Issue a write so the TMEM load is not DCE'd. - smem_ref[...] = tmem_ref[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(smem_ref, y_ref) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped(scope, + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + num_threads=1, + axis_names=("x",), + ) + def kernel(y_ref): + def scope(tmem_ref, smem_ref): + # Issue a write so the TMEM load is not DCE'd. + smem_ref[...] = tmem_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + pl.run_scoped( + scope, plgpu.TMEM((128, 128), jnp.float32), - plgpu.SMEM((128, 128), jnp.float32)) - y_init = jnp.zeros((128, 128), np.float32) + plgpu.SMEM((128, 128), jnp.float32), + ) + # Test that this runs without errors. 
- jax.block_until_ready(inner(y_init)) + jax.block_until_ready(kernel()) class PipelineTest(PallasTest): @@ -1979,9 +1984,7 @@ class WarpSpecializedPipelineTest(PallasTest): manual_consumed_barriers=[False, True]) def test_pipelined_copy(self, m, n, manual_consumed_barriers): x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) - o = jnp.zeros((m, n), dtype=jnp.float16) blk_m = blk_n = 64 - o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16) def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): # TODO(justinfu): Have each wg compute a separate slice @@ -1992,11 +1995,10 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): if manual_consumed_barriers: [x_barrier] = consumed_barriers plgpu.barrier_arrive(x_barrier) - block_spec = plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[], - ) + + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( copy_kernel, grid=(m // blk_m, n // blk_n), @@ -2005,33 +2007,35 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): num_compute_wgs=2, wg_axis="wg", manual_consumed_barriers=manual_consumed_barriers, - in_specs=[block_spec], - out_specs=[block_spec, - # Create an index-invariant output. - plgpu.GPUBlockSpec(block_shape=(blk_m, blk_n), - index_map=lambda i, j: (0, 0)) - ], - ) - mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg")) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, o, o_last_block): - _, out, out_last = pl.run_state(run)((x, o, o_last_block)) - return (out, out_last) - out, out_last_block = run_function(x, o, o_last_block) + in_specs=[spec], + out_specs=[ + spec, + # Create an index-invariant output. 
+ pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (0, 0) + ), + ], + ) + kernel = plgpu.kernel( + pipeline, + out_shape=( + jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float16), + ), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + num_threads=3, + axis_names=("_", "wg"), + ) + out, out_last_block = kernel(x) np.testing.assert_array_equal(out, x) np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) - o = jnp.zeros((m, n), dtype=jnp.float32) + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) def tiled_add_kernel(x_smem, y_smem, o_smem): # TODO(justinfu): Have each wg compute a separate slice @@ -2046,43 +2050,23 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): num_compute_wgs=num_compute_wgs, memory_registers=40, wg_axis="wg", - in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - ], - out_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[])], + in_specs=[spec, spec], + out_specs=[spec], ) - mesh = plgpu.GPUMesh( - grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg") + kernel = plgpu.kernel( + pipeline, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + num_threads=num_compute_wgs + 1, + axis_names=("_", "wg"), ) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, y, o): - _, _, out = pl.run_state(run)((x, y, o)) - return out - out = run_function(x, y, o) - reference = x + y - np.testing.assert_allclose(out, reference, atol=1e-4) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) + np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - acc_init = jnp.zeros((blk_m, blk_n), dtype=jnp.float32) def _scoped(acc_smem, x_gmem, acc_gmem): def _compute_thread(): @@ -2116,77 +2100,70 @@ def tiled_acc_kernel(x_smem, carry): wg_axis="wg", carry_coroutine=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) ], out_specs=[], ) pipeline(x_gmem) - mesh = plgpu.GPUMesh( + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), num_threads=num_compute_wgs + 1, - axis_names=("_", "wg",), - ) - def run(refs): - x_ref, acc_ref = refs - @pl.core_map(mesh) - def _kernel_entry(): - pl.run_scoped( - functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), - plgpu.SMEM((blk_m, blk_n), jnp.float32) - ) - @jax.jit - def run_function(x, acc): - _, out_acc = pl.run_state(run)((x, acc)) - return out_acc - 
out_acc = run_function(x, acc_init) + axis_names=("_", "wg"), + ) + def kernel(x_ref, acc_ref): + pl.run_scoped( + functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), + plgpu.SMEM((blk_m, blk_n), jnp.float32), + ) + + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) - np.testing.assert_allclose(out_acc, ref, atol=1e-4) + np.testing.assert_allclose(kernel(x), ref, atol=1e-4) class CoreMapTest(PallasTest): def test_multiple_wg(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",)) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + num_threads=2, + axis_names=("wg",), + ) + def kernel(o_ref): + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - wg_idx = jax.lax.axis_index("y") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) np.testing.assert_array_equal( - f(), np.repeat(np.arange(2), 128).reshape(2, 128) + kernel(), np.repeat(np.arange(2), 128).reshape(2, 128) ) def test_multiple_wg_with_grid(self): - mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg")) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((4, 2, 128), np.int32), + grid=(2, 2), + num_threads=2, + axis_names=("x", "y", "wg"), + ) + def kernel(o_ref): + xy_idx = jax.lax.axis_index(("x", "y")) + yx_idx = jax.lax.axis_index(("y", "x")) + wg_idx = jax.lax.axis_index("wg") + num_wgs = jax.lax.psum(1, "wg") + o_ref[xy_idx, wg_idx] = jnp.broadcast_to( + yx_idx * num_wgs + wg_idx, (128,) + ) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - xy_idx = jax.lax.axis_index(("x", "y")) - yx_idx = jax.lax.axis_index(("y", "x")) - wg_idx = jax.lax.axis_index("wg") - num_wgs = jax.lax.psum(1, "wg") - y_ref[xy_idx, wg_idx] = jnp.broadcast_to( - yx_idx * num_wgs + wg_idx, (128,) - ) - y_init = jnp.zeros((4, 2, 128), np.int32) - return inner(y_init) np.testing.assert_array_equal( - f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) + kernel(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) def test_multiple_wg_with_squashed_grid(self): @@ -2197,70 +2174,71 @@ def test_multiple_wg_with_squashed_grid(self): y_dim = 5 z_dim = 7 num_threads = 2 - mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim), - num_threads=num_threads, - axis_names=("b", "x", "y", "z", "wg")) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def _(): - b_idx = jax.lax.axis_index("b") - x_idx = jax.lax.axis_index("x") - y_idx = jax.lax.axis_index("y") - z_idx = jax.lax.axis_index("z") - wg_idx = jax.lax.axis_index("wg") - bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) - y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( - bxyzw_idx, (128,) - ) - y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32) - return inner(y_init) - result = f()[:, :, :, :, :, 0] + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros( + (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 + ), + grid=(b, x_dim, y_dim, z_dim), + num_threads=num_threads, + axis_names=("b", "x", "y", "z", "wg"), + ) + def kernel(o_ref): + b_idx = jax.lax.axis_index("b") + x_idx = jax.lax.axis_index("x") + y_idx = jax.lax.axis_index("y") + z_idx = 
jax.lax.axis_index("z") + wg_idx = jax.lax.axis_index("wg") + bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) + o_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( + bxyzw_idx, (128,) + ) + + result = kernel()[:, :, :, :, :, 0] ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( - result.shape) + result.shape + ) np.testing.assert_array_equal(result, ref) def test_cross_wg_barrier(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + num_threads=2, + axis_names=("wg",), + ) + def kernel(y_ref): + def scoped(barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - def scoped(barrier): - plgpu.barrier_arrive(barrier) - plgpu.barrier_wait(barrier) - wg_idx = jax.lax.axis_index("wg") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - # Each warpgroup is a single logical thread! - pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) - np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + # Each warpgroup is a single logical thread! + pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) + + np.testing.assert_array_equal( + kernel(), np.repeat([0, 1], 128).reshape(2, 128) + ) def test_cluster(self): - mesh = plgpu.GPUMesh(grid=(2,), cluster=(2,), axis_names=("x", "cluster")) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros(128, np.int32), + grid=(2,), + cluster=(2,), + axis_names=("x", "cluster"), + ) + def kernel(ref): + block_idx = jax.lax.axis_index("x") + cluster_idx = jax.lax.axis_index("cluster") + pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) - @jax.jit - def f(): - @pl.run_state - def inner(ref): - @pl.core_map(mesh) - def kernel(): - block_idx = jax.lax.axis_index("x") - cluster_idx = jax.lax.axis_index("cluster") - pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) - - ref[...] = ref[...] - return inner(jnp.zeros(128, np.int32)) + ref[...] = ref[...] with self.capture_stdout() as output: - jax.block_until_ready(f()) + jax.block_until_ready(kernel()) self.assertEqual( set(output().splitlines()), { From e9fdf67ecc243c4fcf2344e7047367b6e79c9035 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 14:45:33 -0700 Subject: [PATCH 0159/1769] [jaxlib:cpu] Cleaning up after callback FFI refactor. PiperOrigin-RevId: 740492139 --- jaxlib/xla/py_client.cc | 35 +++-------------------------------- jaxlib/xla/py_client.h | 17 ----------------- 2 files changed, 3 insertions(+), 49 deletions(-) diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 434077b0824f..5fe6bc648e07 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -367,8 +367,7 @@ std::unique_ptr MakeIfrtCompileOptions( ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were - // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or - // `PyClient::GetEmitPythonCallbackDescriptor()`. + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. 
   for (auto& host_callback : host_callbacks) {
     ifrt_loaded_host_callbacks.push_back(tsl::FormRef(
         static_cast<ifrt::LoadedHostCallback*>(host_callback.data())));
@@ -386,8 +385,7 @@ MakeIfrtDeserializeExecutableOptions(std::optional<CompileOptions> options,
       ifrt_loaded_host_callbacks;
   ifrt_loaded_host_callbacks.reserve(host_callbacks.size());
   // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were
-  // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or
-  // `PyClient::GetEmitPythonCallbackDescriptor()`.
+  // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`.
   for (auto& host_callback : host_callbacks) {
     ifrt_loaded_host_callbacks.push_back(tsl::FormRef(
         static_cast<ifrt::LoadedHostCallback*>(host_callback.data())));
@@ -480,8 +478,7 @@ PyClient::CompileIfrtProgram(
       ifrt_loaded_host_callbacks;
   ifrt_loaded_host_callbacks.reserve(host_callbacks.size());
   // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were
-  // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or
-  // `PyClient::GetEmitPythonCallbackDescriptor()`.
+  // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`.
   for (auto& host_callback : host_callbacks) {
     auto callback = tsl::MakeRef<PyFfiLoadedHostCallback>(
         client->ifrt_client(), std::move(host_callback));
@@ -660,28 +657,6 @@ absl::StatusOr<nb::capsule> PyClient::MakePythonCallbackUsingHostSendAndRecv(
   return callback_capsule;
 }
 
-// TODO(b/394595987): Remove this API method once we remove the call from
-// mlir.py's get_emit_python_callback.
-absl::StatusOr<std::pair<uint64_t, nb::object>>
-PyClient::GetEmitPythonCallbackDescriptor(
-    nb::callable callable, absl::Span<Shape const> operand_shapes,
-    absl::Span<Shape const> result_shapes) {
-  TF_ASSIGN_OR_RETURN(
-      auto loaded_host_callback,
-      PyCpuLoadedHostCallback::Create(ifrt_client(), std::move(callable),
-                                      operand_shapes, result_shapes));
-  const uint64_t descriptor = loaded_host_callback->descriptor();
-
-  nb::capsule callback_capsule(
-      loaded_host_callback.release(), [](void* ptr) noexcept {
-        static_cast<PyCpuLoadedHostCallback*>(ptr)->DropRef();
-      });
-  return std::make_pair(descriptor, nb::object(std::move(callback_capsule)));
-}
-
-XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
-                                             &XlaPythonCpuCallback);
-
 /* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit,
                                        void* arg) {
   PyClient* c = nb::inst_ptr<PyClient>(self);
@@ -813,10 +788,6 @@ PyType_Slot PyClient::slots_[] = {
       // TODO(zhangqiaorjc): Experimental.
       .def("defragment",
            [](PyClient& self) { xla::ThrowIfError(self.Defragment()); })
-      .def("get_emit_python_callback_descriptor",
-           xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor),
-           nb::arg("callable"), nb::arg("operand_shapes"),
-           nb::arg("result_shapes").none() = nb::none())
       .def("make_python_callback_from_host_send_and_recv",
           xla::ValueOrThrowWrapper(
              &PyClient::MakePythonCallbackUsingHostSendAndRecv),
diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h
index 9b9d43d90228..8f50c6451627 100644
--- a/jaxlib/xla/py_client.h
+++ b/jaxlib/xla/py_client.h
@@ -184,23 +184,6 @@ class PyClient {
 
   absl::StatusOr<nanobind::bytes> HeapProfile();
 
-  // `GetEmitPythonCallbackDescriptor` takes in an input Python callable that
-  // takes in arguments of shapes `operand_shapes` and returns values of shapes
-  // `result_shapes`. It returns a pair of a `uint64_t` descriptor and a Python
-  // object whose reference will keep the Python callback alive. The descriptor
-  // should be passed into a 'xla_python_cpu_callback' or
-  // 'xla_python_gpu_callback' CustomCall as its first argument.
Typically the - // callback may be kept alive by attaching the keep-alive object to the - // executable built from this computation. - // - // The callable receives as arguments NumPy arrays for arguments with array - // types, and None for Token argument. The callable must return a tuple of - // either arrays or None values. - absl::StatusOr> - GetEmitPythonCallbackDescriptor(nanobind::callable callable, - absl::Span operand_shapes, - absl::Span result_shapes); - // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable // that takes in arguments of shapes `operand_shapes` and returns results of // shapes `result_shapes`. The arguments correspond to Send ops in the HLO From ec061566558ac9f3ca9c7966fa59f7e07e1b8d74 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 25 Mar 2025 14:47:33 -0700 Subject: [PATCH 0160/1769] [Pallas] A few fixes for TPU interpret mode: - Actually de-allocate buffers after a pl.run_scoped. - Periodically run an explicit garbage collection after de-allocating buffers. - Add no-op implementations for a few internal/testing mosaic primitives (prng_seed_p, prng_random_bits_p, assume_p, random_p). --- jax/_src/pallas/mosaic/BUILD | 1 + jax/_src/pallas/mosaic/interpret.py | 55 ++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 24e8341046b0..fdd3a56ac7c8 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -158,6 +158,7 @@ py_library( deps = [ ":core", ":primitives", + ":verification", "//jax", "//jax:core", "//jax:source_info_util", diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 5acbabc673aa..13e71a1f5c56 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -16,6 +16,7 @@ import dataclasses import enum import functools +import gc import itertools import math import threading @@ -28,8 +29,9 @@ from jax._src.lax.control_flow import for_loop from jax._src import linear_util as lu from jax._src import source_info_util -from jax._src.pallas.mosaic import primitives as mosaic_primitives from jax._src.pallas.mosaic import core as mosaic_core +from jax._src.pallas.mosaic import primitives as mosaic_primitives +from jax._src.pallas.mosaic import verification from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src import pjit @@ -477,6 +479,8 @@ class SharedMemory: next_dma_id: int = 100 + deallocated_bytes: int = 0 + # TODO(jburnim): Do we want to support multiple instances of SharedMemory? # Maybe for running multiple distinct interpreted computations in parallel? @@ -570,8 +574,18 @@ def _deallocate_buffer(device_id, memory_space, buffer_id): shared_memory = _get_shared_memory() with shared_memory.lock: - # TODO(jburnim): Error if buffer doesn't exist? - shared_memory.mem.pop((memory_space, buffer_id, device_id), None) + buff = shared_memory.mem.pop((memory_space, buffer_id, device_id)) + shared_memory.deallocated_bytes += buff.size * buff.itemsize + del buff + + should_collect = shared_memory.deallocated_bytes > 100_000_000 + if should_collect: + shared_memory.deallocated_bytes = 0 + + if should_collect: + # Periodic garbage collection here prevents OOMs -- although it's not clear + # why arrays are not getting freed without this. 
+ gc.collect() def _allocate_semaphores(device_id, shape): device_id = int(device_id) @@ -1067,6 +1081,21 @@ def write(var, value): ordered=True) elif prim is mosaic_primitives.delay_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_seed_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_random_bits_p: + # TODO(jburnim): Implement this properly? + out = jnp.zeros(eqn.params['shape'], jnp.int32) + + elif prim is verification.assume_p: + out = read(eqn.invars[0]) + + elif prim is verification.pretend_p: out = [] elif prim is lax.cond_p: @@ -1142,16 +1171,8 @@ def f(*args, jaxpr): out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) - for a in allocs: - if isinstance(a, tuple): - callback.io_callback( - _deallocate_buffer, - None, - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - a, - ordered=True) - else: + for a, v in zip(allocs, eqn.params['jaxpr'].invars): + if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: # TODO(jburnim): De-allocate semaphores. # callback.io_callback( # _deallocate_semaphores, @@ -1160,6 +1181,14 @@ def f(*args, jaxpr): # a, # ordered=True) pass + else: + callback.io_callback( + _deallocate_buffer, + None, + device_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True) elif prim is state_primitives.get_p: invals = deferred_invals() From ed75189c921a66e1d5232923ff5b5cdbc23e766f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 14:47:39 -0700 Subject: [PATCH 0161/1769] [sharding_in_types] Add support for rng_bit_generator PiperOrigin-RevId: 740492876 --- jax/_src/internal_test_util/test_harnesses.py | 5 ++-- jax/_src/lax/control_flow/loops.py | 15 +++++++---- jax/_src/lax/lax.py | 26 ++++++++++++------- jax/experimental/jax2tf/jax2tf.py | 3 ++- tests/pjit_test.py | 21 +++++++++++++++ 5 files changed, 53 insertions(+), 17 deletions(-) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 48c645c4d033..02779c85977e 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -3375,8 +3375,9 @@ def _make_conv_harness(name, define( lax.rng_bit_generator_p, f"{key_dtype=}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{algorithm=}", - lambda key, shape, dtype, algorithm: lax.rng_bit_generator(key, shape, dtype=dtype, - algorithm=algorithm), + lambda key, shape, dtype, algorithm, out_sharding=None: lax.rng_bit_generator( + key, shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), [RandArg(key_shape, key_dtype), StaticArg(shape), StaticArg(dtype), StaticArg(algorithm)], shape=shape, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 33e2d2cbb0c8..88af7c24e5b8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2264,17 +2264,22 @@ def map(f, xs): _, ys = scan(g, (), xs) return ys -def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): +def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, + algorithm, out_sharding): keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, - algorithm=algorithm), (None, None) + return lax.rng_bit_generator_p.bind( + keys, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), (None, None) keys = 
batching.moveaxis(keys, bd, 0) batch_size = keys.shape[0] + out_s = (out_sharding.with_spec((keys.aval.sharding.spec[0], *out_sharding.spec)) + if out_sharding is not None else None) key = keys[0] - new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), - dtype=dtype, algorithm=algorithm) + new_key, bits = lax.rng_bit_generator_p.bind( + key, shape=(batch_size, *shape), dtype=dtype, algorithm=algorithm, + out_sharding=out_s) new_keys = slicing.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 80a469ab6a11..dd6e7399321b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8053,15 +8053,20 @@ def _rng_uniform_lowering(ctx, a, b, *, shape): mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) -def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm, out_sharding): del dtype, algorithm return (key.shape, tuple(shape)) -def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm, + out_sharding): + return (key.sharding, out_sharding) + +def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding): del shape, algorithm return (key.dtype, dtype) -def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm, + out_sharding): del shape, dtype, algorithm return (key.weak_type, False) @@ -8092,7 +8097,7 @@ def _rng_algorithm(algorithm: RandomAlgorithm): assert False def _rng_bit_generator_lowering( - ctx, key, *, shape, dtype, algorithm): + ctx, key, *, shape, dtype, algorithm, out_sharding): key_type = ir.RankedTensorType(key.type) key_shape, key_etype = key_type.shape, key_type.element_type # While the RngBitGenerator HLO accepts a u64[2] key on all backends, we @@ -8121,7 +8126,7 @@ def _rng_bit_generator_lowering( ir.RankedTensorType.get([2], u64_type), hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key)) algorithm_attr = _rng_algorithm(algorithm) - _, out_vals_aval = ctx.avals_out + out_key_aval, out_vals_aval = ctx.avals_out if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out): output_shape = mlir.shape_tensor( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) @@ -8145,7 +8150,8 @@ def _rng_bit_generator_lowering( out_vals = hlo.convert( ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype), out_vals) - return [out_key, out_vals] + return [mlir.lower_with_sharding_in_types(ctx, out_key, out_key_aval), + mlir.lower_with_sharding_in_types(ctx, out_vals, out_vals_aval)] rng_bit_generator_p = Primitive("rng_bit_generator") @@ -8155,7 +8161,7 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, None)) + _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -8219,7 +8225,7 @@ def _propagate_mem_kind_copy(in_mem_kind): pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT): + algorithm=RandomAlgorithm.RNG_DEFAULT, out_sharding=None): """Stateless PRNG bit generator. 
Experimental and its use is discouraged. Returns uniformly distributed random bits with the specified shape and dtype @@ -8235,12 +8241,14 @@ def rng_bit_generator(key, shape, dtype=np.uint32, """ shape = core.canonicalize_shape(shape) dtype = dtypes.canonicalize_dtype(dtype) + out_sharding = canonicalize_sharding(out_sharding, 'rng_bit_generator') if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') return tuple( rng_bit_generator_p.bind( - key, shape=shape, dtype=dtype, algorithm=algorithm)) + key, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding)) def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 7f98ce433815..3d71af38388b 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2822,7 +2822,8 @@ def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): multiple_results=False, extra_name_stack="random_gamma") -def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]: +def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm, + out_sharding) -> Sequence[TfVal]: is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32) if is_uint32_key: key = tf.reshape(key, (2, 2)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2033126759e4..608c54994b5d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7191,6 +7191,27 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) + @jtu.with_user_mesh((2,), ('x',)) + def test_rng_bit_generator(self, mesh): + def f(key): + out = lax.rng_bit_generator(key, shape=(4, 8), out_sharding=P('x')) + self.assertEqual(out[0].aval.sharding.spec, P(None)) + self.assertEqual(out[1].aval.sharding.spec, P('x', None)) + return out + + key = np.array((1, 2, 3, 4)).astype(np.uint32) + out1 = f(key) + jit_f = jax.jit(f) + out2 = jit_f(key) + self.assertEqual(out1[0].shape, (4,)) + self.assertEqual(out1[1].shape, (4, 8)) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P())) + self.assertEqual(out2[1].sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out1[0].sharding, out2[0].sharding) + self.assertEqual(out1[1].sharding, out2[1].sharding) + self.assertArraysEqual(out1[0], out2[0]) + self.assertArraysEqual(out1[1], out2[1]) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 289fa625e562c96ffaf466368a5d620e14d2659c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 15:28:32 -0700 Subject: [PATCH 0162/1769] [sharding_in_types] Add fold_in support PiperOrigin-RevId: 740505750 --- jax/_src/prng.py | 4 +++- tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 5fdd673b3454..ead939d74351 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -621,7 +621,9 @@ def random_fold_in(keys, msgs): def random_fold_in_abstract_eval(keys_aval, msgs_aval): shape = lax_internal.broadcasting_shape_rule( 'random_fold_in', keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype) + sharding = lax_internal.broadcasting_sharding_rule( + 'random_fold_in', keys_aval, msgs_aval) + return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): diff --git 
a/tests/pjit_test.py b/tests/pjit_test.py index 608c54994b5d..ed1f9e9b62d8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7212,6 +7212,20 @@ def f(key): self.assertArraysEqual(out1[0], out2[0]) self.assertArraysEqual(out1[1], out2[1]) + @jtu.with_user_mesh((2,), ('x',)) + def test_fold_in(self, mesh): + key = jax.random.key(72) + key = jax.device_put(key, NamedSharding(mesh, P())) + + @jax.jit + def f(key): + f1 = jax.random.fold_in(key, 1) + self.assertEqual(jax.random.key_data(f1).aval.sharding.spec, P(None)) + return f1 + + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 588b6932d69aad36d312cbec336effda9735da43 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 25 Mar 2025 15:35:20 -0700 Subject: [PATCH 0163/1769] [JAX] [XLA:Python] Migrate more Python modules to JAX. PiperOrigin-RevId: 740507886 --- jaxlib/gpu/BUILD | 1 + jaxlib/jax.bzl | 1 + jaxlib/xla/BUILD | 176 ++++++++++++++++-- jaxlib/xla/callback.cc | 184 +++++++++++++++++++ jaxlib/xla/callback.h | 92 ++++++++++ jaxlib/xla/dlpack.cc | 2 +- jaxlib/xla/guard_lib.cc | 197 ++++++++++++++++++++ jaxlib/xla/guard_lib.h | 115 ++++++++++++ jaxlib/xla/pjit.cc | 2 +- jaxlib/xla/py_array.cc | 4 +- jaxlib/xla/py_client.cc | 6 +- jaxlib/xla/py_client_cpu.cc | 166 +++++++++++++++++ jaxlib/xla/py_client_cpu.h | 28 +++ jaxlib/xla/py_host_callback.cc | 290 ++++++++++++++++++++++++++++++ jaxlib/xla/py_host_callback.h | 170 ++++++++++++++++++ jaxlib/xla/py_host_callback.proto | 25 +++ jaxlib/xla/util.cc | 60 +++++++ jaxlib/xla/util.h | 31 ++++ jaxlib/xla/xla.cc | 2 +- 19 files changed, 1533 insertions(+), 19 deletions(-) create mode 100644 jaxlib/xla/callback.cc create mode 100644 jaxlib/xla/callback.h create mode 100644 jaxlib/xla/guard_lib.cc create mode 100644 jaxlib/xla/guard_lib.h create mode 100644 jaxlib/xla/py_client_cpu.cc create mode 100644 jaxlib/xla/py_client_cpu.h create mode 100644 jaxlib/xla/py_host_callback.cc create mode 100644 jaxlib/xla/py_host_callback.h create mode 100644 jaxlib/xla/py_host_callback.proto create mode 100644 jaxlib/xla/util.cc create mode 100644 jaxlib/xla/util.h diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 3613be567533..59c0ab8dc164 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -84,6 +84,7 @@ proto_library( cc_proto_library( name = "triton_cc_proto", + compatible_with = None, deps = [":triton_proto"], ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 9b8c861404c2..560db85d6a1e 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -31,6 +31,7 @@ load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_c cc_proto_library = _cc_proto_library cuda_library = _cuda_library rocm_library = _rocm_library +proto_library = native.proto_library nanobind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 979e659a309f..e4db73d6c86d 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -18,12 +18,12 @@ load( "if_oss", "jax_visibility", "nanobind_extension", + "proto_library", "py_deps", "py_strict_library", "py_strict_test", "pytype_strict_library", ) -# Placeholder: load proto_library licenses(["notice"]) @@ -49,6 +49,7 @@ nanobind_extension( ":config", ":custom_call_sharding", ":dlpack", + ":guard_lib", ":ifrt_proxy", ":jax_jit", ":mlir", @@ -57,6 +58,7 @@ nanobind_extension( ":py_client", ":pytree", ":sdy", + ":util", 
":weakref_lru_cache", ":xla_compiler", "@com_google_absl//absl/base", @@ -99,7 +101,6 @@ nanobind_extension( "@xla//xla/pjrt/distributed:service", "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", - "@xla//xla/python:guard_lib", "@xla//xla/python:logging", "@xla//xla/python:nb_absl_flat_hash_map", "@xla//xla/python:nb_absl_span", @@ -111,7 +112,6 @@ nanobind_extension( "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:traceback", "@xla//xla/python:types", - "@xla//xla/python:util", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", @@ -147,6 +147,41 @@ nanobind_extension( }), ) +cc_library( + name = "callback", + srcs = [ + "callback.cc", + ], + hdrs = [ + "callback.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:python_ref_manager", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/platform:statusor", + ], +) + cc_library( name = "config", srcs = ["config.cc"], @@ -212,6 +247,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":py_client", + ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -233,7 +269,6 @@ cc_library( "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", - "@xla//xla/python:util", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", "@xla//xla/tsl/platform:errors", @@ -242,6 +277,26 @@ cc_library( ], ) +cc_library( + name = "guard_lib", + srcs = ["guard_lib.cc"], + hdrs = ["guard_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla:util", + ], +) + cc_library( name = "ifrt_proxy", srcs = ["ifrt_proxy.cc"], @@ -361,6 +416,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":config", + ":guard_lib", ":jax_jit", ":py_client", ":pytree", @@ -382,7 +438,6 @@ cc_library( "@xla//xla:util", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:lru_cache", - "@xla//xla/python:guard_lib", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", @@ -482,6 +537,12 @@ cc_library( features = ["-use_header_modules"], visibility = jax_visibility("jaxlib/xla/py_client"), deps = [ + ":callback", + ":guard_lib", + ":py_client_cpu", + ":py_host_callback", + ":py_host_callback_cc_proto", + ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/cleanup", @@ -541,20 +602,14 @@ cc_library( "@xla//xla/pjrt/distributed", "@xla//xla/pjrt/distributed:client", "@xla//xla/python:aggregate_profile", - "@xla//xla/python:callback", - 
"@xla//xla/python:guard_lib", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", "@xla//xla/python:pprof_profile_builder", - "@xla//xla/python:py_client_cpu", - "@xla//xla/python:py_host_callback", - "@xla//xla/python:py_host_callback_proto_cc", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", - "@xla//xla/python:util", "@xla//xla/python:xplane_to_profile_instructions", "@xla//xla/python/compile_only_ifrt:client", "@xla//xla/python/ifrt", @@ -588,6 +643,86 @@ cc_library( ], ) +cc_library( + name = "py_client_cpu", + srcs = ["py_client_cpu.cc"], + hdrs = ["py_client_cpu.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + ], + alwayslink = 1, +) + +cc_library( + name = "py_host_callback", + srcs = ["py_host_callback.cc"], + hdrs = ["py_host_callback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":py_host_callback_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +proto_library( + name = "py_host_callback_proto", + srcs = ["py_host_callback.proto"], +) + +cc_proto_library( + name = "py_host_callback_cc_proto", + visibility = jax_visibility("jaxlib/xla/py_host_callback_cc_proto"), + deps = [":py_host_callback_proto"], +) + cc_library( name = "py_socket_transfer", srcs = ["py_socket_transfer.cc"], @@ -697,6 +832,25 @@ cc_library( ], ) +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@xla//xla:util", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + ], +) + cc_library( name = "weakref_lru_cache", srcs = ["weakref_lru_cache.cc"], diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc new file mode 100644 index 000000000000..4eab8290c7bb --- /dev/null +++ b/jaxlib/xla/callback.cc @@ -0,0 +1,184 @@ +/* Copyright 2022 The 
JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/callback.h"
+
+#include <Python.h>
+#include <sys/types.h>
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "absl/base/casts.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/string_view.h"  // IWYU pragma: keep
+#include "xla/pjrt/host_callback.h"
+#include "xla/pjrt/transpose.h"
+#include "xla/primitive_util.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/python/python_ref_manager.h"
+#include "xla/service/custom_call_status.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace nb = nanobind;
+
+namespace xla {
+
+CpuCallback::~CpuCallback() {
+  // The destructor may be called without GIL held. In that case, we defer it
+  // to GlobalPyRefManager.
+  std::vector<nb::object> objects;
+  objects.push_back(std::move(callable_));
+  for (auto& arg : args_) {
+    objects.push_back(std::move(arg.dtype));
+  }
+
+  GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects));
+}
+
+absl::Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) {
+  absl::Span<void* const> inputs(arg_ptrs, args_.size());
+  absl::Span<void* const> outputs(reinterpret_cast<void**>(result),
+                                  results_.size());
+
+  nb::gil_scoped_acquire gil;
+  nb::tuple args = nb::steal<nb::tuple>(PyTuple_New(inputs.size()));
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (args_[i].type == xla::TOKEN) {
+      PyTuple_SET_ITEM(args.ptr(), i, nb::none().release().ptr());
+    } else {
+      nb_numpy_ndarray array =
+          nb_numpy_ndarray(args_[i].dtype, args_[i].dims, args_[i].strides,
+                           const_cast<void*>(inputs[i]));
+      array.attr("flags").attr("writeable") = nb::bool_(false);
+      PyTuple_SET_ITEM(args.ptr(), i, array.release().ptr());
+    }
+  }
+
+  EnterHostCallback();
+  absl::StatusOr<nb::tuple> maybe_result_tuple = Call(std::move(args));
+  LeaveHostCallback();
+  TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple);
+
+  for (size_t i = 0; i < results_.size(); ++i) {
+    if (results_[i].type == xla::TOKEN) {
+      continue;
+    }
+    nb::object output =
+        nb::borrow<nb::object>(PyTuple_GetItem(result_tuple.ptr(), i));
+    nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output));
+    absl::Span<int64_t const> dims(
+        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+    absl::Span<int64_t const> strides(
+        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
+    if (strides == results_[i].expected_strides) {
+      std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes);
+    } else {
+      xla::TransposePlan::Options options;
+      options.elem_size_in_bytes =
+          xla::primitive_util::ByteWidth(results_[i].type);
+      options.dims = dims;
+      options.permutation = results_[i].reversed_layout;
+      options.input_layout = xla::TransposePlan::Striding{strides};
+      absl::StatusOr<std::shared_ptr<TransposePlan>> plan =
+          transpose_cache_.GetOrCreate(options);
+      if (!plan.ok()) {
+        return std::move(plan).status();
+      }
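+      // The callback returned an array whose byte strides do not match the
+      // layout expected for this result; execute the cached transpose plan to
+      // write a relaid-out copy directly into the XLA output buffer.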
+      plan.value()->Execute(array.data(), outputs[i]);
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+absl::StatusOr<nb::tuple> CpuCallback::Call(nb::tuple args) {
+  auto py_error_to_status = [](nb::python_error& e) {
+    std::string error_message = e.what();
+    return absl::InternalError(
+        absl::StrFormat("CpuCallback error: %s", error_message));
+  };
+  nb::object result_object;
+  try {
+    result_object = callable_(*nb::borrow<nb::tuple>(args));
+  } catch (nb::python_error& e) {
+    return py_error_to_status(e);
+  }
+  if (!PyTuple_Check(result_object.ptr())) {
+    return absl::InternalError(
+        absl::StrFormat("CPU callback expected a tuple result, got %s",
+                        nb::cast<std::string_view>(nb::repr(result_object))));
+  }
+  if (PyTuple_Size(result_object.ptr()) != results_.size()) {
+    return absl::InternalError(
+        absl::StrFormat("CPU callback expected a tuple with %d results, got %d",
+                        results_.size(), PyTuple_Size(result_object.ptr())));
+  }
+  nb::tuple result_tuple = nb::cast<nb::tuple>(result_object);
+  for (size_t i = 0; i < results_.size(); ++i) {
+    nb::object output =
+        nb::borrow<nb::object>(PyTuple_GetItem(result_tuple.ptr(), i));
+    if (results_[i].type == xla::TOKEN) {
+      if (!output.is_none()) {
+        return absl::InternalError(absl::StrFormat(
+            "Token output from Python callback should be None, got %s",
+            nb::cast<std::string_view>(nb::repr(output))));
+      }
+      continue;
+    }
+    nb_numpy_ndarray array;
+    try {
+      array = nb_numpy_ndarray::from_any(output, NPY_ARRAY_ENSUREARRAY);
+    } catch (nb::python_error& e) {
+      return py_error_to_status(e);
+    }
+    static_assert(sizeof(ssize_t) == sizeof(int64_t),
+                  "Expected ssize_t to be of equal size to int64_t");
+    absl::Span<int64_t const> dims(
+        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+    if (dims != results_[i].expected_dims) {
+      return absl::InternalError(absl::StrFormat(
+          "Mismatched result shape for %d-th return value from CPU callback; "
+          "expected array with dimensions %s, got %s",
+          i, absl::StrJoin(results_[i].expected_dims, ","),
+          absl::StrJoin(dims, ",")));
+    }
+  }
+  return result_tuple;
+}
+
+void XlaPythonCpuCallback(void* output, void** inputs,
+                          XlaCustomCallStatus* status) {
+  CpuCallback* callback =
+      absl::bit_cast<CpuCallback*>(*static_cast<uintptr_t*>(inputs[0]));
+  auto s = callback->PrepareAndCall(output, inputs + 1);
+  if (!s.ok()) {
+    auto msg = s.message();
+    XlaCustomCallStatusSetFailure(status, msg.data(), msg.length());
+  }
+}
+
+}  // namespace xla
diff --git a/jaxlib/xla/callback.h b/jaxlib/xla/callback.h
new file mode 100644
index 000000000000..b63025efe120
--- /dev/null
+++ b/jaxlib/xla/callback.h
@@ -0,0 +1,92 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_CALLBACK_H_
+#define JAXLIB_XLA_CALLBACK_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "nanobind/nanobind.h"
+#include "xla/pjrt/transpose.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/service/custom_call_status.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+
+class CpuCallback {
+ public:
+  struct Arg {
+    xla::PrimitiveType type;               // XLA type
+    nb_dtype dtype;                        // NumPy type, for array types.
+    absl::InlinedVector<int64_t, 4> dims;  // Dimensions, for array types.
+    std::vector<int64_t> strides;  // Byte strides, for array types.
+    size_t size_in_bytes;          // Size of the array in bytes.
+  };
+  struct Result {
+    xla::PrimitiveType type;  // XLA type
+    // Expected output shape, for array types
+    absl::InlinedVector<int64_t, 4> expected_dims;
+    // Expected output byte strides, for array types. If the strides do not
+    // match the output will be transposed into the expected layout.
+    std::vector<int64_t> expected_strides;
+    // The desired order of output dimensions in major-to-minor order.
+    absl::InlinedVector<int64_t, 4> reversed_layout;
+    // Size of the array in bytes.
+    size_t size_in_bytes;
+  };
+
+  explicit CpuCallback(nanobind::callable callable, std::vector<Arg> args,
+                       std::vector<Result> results)
+      : callable_(std::move(callable)),
+        args_(std::move(args)),
+        results_(std::move(results)),
+        transpose_cache_(/*capacity=*/16) {}
+
+  ~CpuCallback();
+
+  const std::vector<Arg>& args() const { return args_; }
+  size_t num_args() const { return args_.size(); }
+
+  const std::vector<Result>& results() const { return results_; }
+  size_t num_results() const { return results_.size(); }
+  void* callback() const { return callable_.ptr(); }
+
+  xla::TransposePlanCache& transpose_cache() { return transpose_cache_; }
+
+  absl::Status PrepareAndCall(void* result, void** arg_ptrs);
+
+  absl::StatusOr<nanobind::tuple> Call(nanobind::tuple args);
+
+ private:
+  nanobind::callable callable_;
+  std::vector<Arg> args_;
+  std::vector<Result> results_;
+  xla::TransposePlanCache transpose_cache_;
+};
+
+void XlaPythonCpuCallback(void* output, void** inputs,
+                          XlaCustomCallStatus* status);
+
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_CALLBACK_H_
diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc
index 8b29e136f296..94d57e07c34a 100644
--- a/jaxlib/xla/dlpack.cc
+++ b/jaxlib/xla/dlpack.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "nanobind/ndarray.h"
 #include "jaxlib/xla/py_array.h"
 #include "jaxlib/xla/py_client.h"
+#include "jaxlib/xla/util.h"
 #include "xla/layout.h"
 #include "xla/pjrt/exceptions.h"
 #include "xla/pjrt/pjrt_client.h"
@@ -51,7 +52,6 @@ limitations under the License.
 #include "xla/python/python_ref_manager.h"
 #include "xla/python/traceback.h"
 #include "xla/python/types.h"
-#include "xla/python/util.h"
 #include "xla/shape_util.h"
 #include "xla/status_macros.h"
 #include "xla/tsl/platform/errors.h"
diff --git a/jaxlib/xla/guard_lib.cc b/jaxlib/xla/guard_lib.cc
new file mode 100644
index 000000000000..77866741819c
--- /dev/null
+++ b/jaxlib/xla/guard_lib.cc
@@ -0,0 +1,197 @@
+/* Copyright 2024 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements the configuration management for different types of
+// guards.
+// C++ backends are responsible for enforcing transfer guard levels.
+
+#include "jaxlib/xla/guard_lib.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "absl/functional/function_ref.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/optional.h"  // IWYU pragma: keep
+#include "xla/util.h"
+
+namespace jax {
+
+namespace nb = ::nanobind;
+
+namespace {
+
+// Protected by the GIL.
+GuardState& global_state = *new GuardState();
+
+ABSL_CONST_INIT thread_local GuardState thread_local_state;
+
+// The default transfer guard level.
+constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow;
+
+// The default garbage collection guard level.
+constexpr GarbageCollectionGuardLevel kDefaultGarbageCollectionGuardLevel =
+    GarbageCollectionGuardLevel::kAllow;
+
+// Returns the transfer guard action for a transfer.
+TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level,
+                                           bool explicit_transfer) {
+  switch (guard_level) {
+    case TransferGuardLevel::kAllow:
+      return TransferGuardAction::kAllow;
+    case TransferGuardLevel::kLog:
+      if (explicit_transfer) {
+        return TransferGuardAction::kAllow;
+      } else {
+        return TransferGuardAction::kLog;
+      }
+    case TransferGuardLevel::kDisallow:
+      if (explicit_transfer) {
+        return TransferGuardAction::kAllow;
+      } else {
+        return TransferGuardAction::kDisallow;
+      }
+    case TransferGuardLevel::kLogExplicit:
+      return TransferGuardAction::kLog;
+    case TransferGuardLevel::kDisallowExplicit:
+      return TransferGuardAction::kDisallow;
+    default:
+      // Unreachable; gracefully handle the unexpected guard level and prevent
+      // a compiler warning.
+      return TransferGuardAction::kDisallow;
+  }
+}
+
+// Returns the transfer guard action for a host-to-device transfer.
+// REQUIRES: Python GIL.
+TransferGuardAction GetTransferGuardActionForHostToDevice() {
+  return GetTransferGuardAction(
+      thread_local_state.host_to_device.value_or(
+          global_state.host_to_device.value_or(kDefaultGuardLevel)),
+      thread_local_state.explicit_device_put);
+}
+
+// Returns the transfer guard action for a device-to-device transfer.
+// REQUIRES: Python GIL.
+TransferGuardAction GetTransferGuardActionForDeviceToDevice() {
+  return GetTransferGuardAction(
+      thread_local_state.device_to_device.value_or(
+          global_state.device_to_device.value_or(kDefaultGuardLevel)),
+      thread_local_state.explicit_device_put);
+}
+
+// Returns the transfer guard action for a device-to-host transfer.
+// REQUIRES: Python GIL.
+TransferGuardAction GetTransferGuardActionForDeviceToHost() {
+  return GetTransferGuardAction(
+      thread_local_state.device_to_host.value_or(
+          global_state.device_to_host.value_or(kDefaultGuardLevel)),
+      thread_local_state.explicit_device_get);
+}
+
+}  // namespace
+
+absl::Status ApplyTransferGuardToHostToDevice(
+    absl::FunctionRef<std::string()> formatter) {
+  switch (GetTransferGuardActionForHostToDevice()) {
+    case TransferGuardAction::kAllow:
+      break;
+    case TransferGuardAction::kLog:
+      LOG(WARNING) << "host-to-device transfer: " << formatter();
+      break;
+    case TransferGuardAction::kDisallow:
+      return xla::InvalidArgument("Disallowed host-to-device transfer: %s",
+                                  formatter());
+  }
+  return absl::OkStatus();
+}
+
+absl::Status ApplyTransferGuardToDeviceToDevice(
+    absl::FunctionRef<std::string()> formatter) {
+  switch (GetTransferGuardActionForDeviceToDevice()) {
+    case TransferGuardAction::kAllow:
+      break;
+    case TransferGuardAction::kLog:
+      LOG(WARNING) << "device-to-device transfer: " << formatter();
+      break;
+    case TransferGuardAction::kDisallow:
+      return xla::InvalidArgument("Disallowed device-to-device transfer: %s",
+                                  formatter());
+  }
+  return absl::OkStatus();
+}
+
+absl::Status ApplyTransferGuardToDeviceToHost(
+    absl::FunctionRef<std::string()> formatter) {
+  switch (GetTransferGuardActionForDeviceToHost()) {
+    case TransferGuardAction::kAllow:
+      break;
+    case TransferGuardAction::kLog:
+      LOG(WARNING) << "device-to-host transfer: " << formatter();
+      break;
+    case TransferGuardAction::kDisallow:
+      return xla::InvalidArgument("Disallowed device-to-host transfer: %s",
+                                  formatter());
+  }
+  return absl::OkStatus();
+}
+
+GarbageCollectionGuardLevel GetGarbageCollectArrayGuard() {
+  return thread_local_state.garbage_collect_array.value_or(
+      global_state.garbage_collect_array.value_or(
+          kDefaultGarbageCollectionGuardLevel));
+}
+
+void BuildGuardSubmodule(nb::module_& m) {
+  nb::module_ glib =
+      m.def_submodule("guard_lib", "Jax support library for guards");
+
+  nb::enum_<TransferGuardLevel> tglevel(glib, "TransferGuardLevel");
+  tglevel.value("ALLOW", TransferGuardLevel::kAllow);
+  tglevel.value("LOG", TransferGuardLevel::kLog);
+  tglevel.value("DISALLOW", TransferGuardLevel::kDisallow);
+  tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit);
+  tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit);
+
+  nb::enum_<GarbageCollectionGuardLevel> gcglevel(
+      glib, "GarbageCollectionGuardLevel");
+  gcglevel.value("ALLOW", GarbageCollectionGuardLevel::kAllow);
+  gcglevel.value("LOG", GarbageCollectionGuardLevel::kLog);
+  gcglevel.value("FATAL", GarbageCollectionGuardLevel::kFatal);
+
+  nb::class_<GuardState> tgstate(glib, "GuardState");
+  tgstate.def_rw("host_to_device", &GuardState::host_to_device,
+                 nb::arg().none());
+  tgstate.def_rw("device_to_device", &GuardState::device_to_device,
+                 nb::arg().none());
+  tgstate.def_rw("device_to_host", &GuardState::device_to_host,
+                 nb::arg().none());
+  tgstate.def_rw("explicit_device_put", &GuardState::explicit_device_put);
+  tgstate.def_rw("explicit_device_get", &GuardState::explicit_device_get);
+  tgstate.def_rw("garbage_collect_array", &GuardState::garbage_collect_array,
+                 nb::arg().none());
+
+  glib.def(
+      "global_state", [&]() { return &global_state; },
+      nb::rv_policy::reference);
+  glib.def(
+      "thread_local_state", [&]() { return &thread_local_state; },
+      nb::rv_policy::reference);
+}
+
+}  // namespace jax
diff --git a/jaxlib/xla/guard_lib.h b/jaxlib/xla/guard_lib.h
new file mode 100644
index 000000000000..8ddf6e8e892e
--- /dev/null
+++ b/jaxlib/xla/guard_lib.h
@@ -0,0 +1,115 @@
+/* Copyright 2024 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_GUARD_LIB_H_
+#define JAXLIB_XLA_GUARD_LIB_H_
+
+#include <optional>
+#include <string>
+
+// placeholder for index annotation headers
+#include "absl/functional/function_ref.h"
+#include "absl/status/status.h"
+#include "nanobind/nanobind.h"
+
+namespace jax {
+
+// Transfer guard level chosen by the user code.
+enum class TransferGuardLevel {
+  // Explicit transfers: allow
+  // Implicit transfers: allow
+  kAllow,
+  // Explicit transfers: allow
+  // Implicit transfers: log
+  kLog,
+  // Explicit transfers: allow
+  // Implicit transfers: disallow
+  kDisallow,
+  // Explicit transfers: log
+  // Implicit transfers: log
+  kLogExplicit,
+  // Explicit transfers: disallow
+  // Implicit transfers: disallow
+  kDisallowExplicit,
+};
+
+// Garbage collection guard level chosen by the user code.
+enum class GarbageCollectionGuardLevel {
+  // Silently allow the object to be garbage collected.
+  kAllow,
+  // Log and allow the object to be garbage collected.
+  kLog,
+  // Fatal crash on object garbage collection.
+  kFatal,
+};
+
+// Flags for guard levels are controlled by:
+// - a global flag value,
+//   e.g., associated with --jax_transfer_guard_device_to_host,
+//   which defaults to TransferGuardLevel::kAllow.
+// - possibly a thread-local value, which initially is std::nullopt and
+//   overrides the global value if set. The thread-local state is used to
+//   implement context managers that locally override the global state.
+//
+// Explicit device_put/device_get contexts are tracked by context managers.
+struct GuardState {
+  std::optional<TransferGuardLevel> host_to_device;
+  std::optional<TransferGuardLevel> device_to_device;
+  std::optional<TransferGuardLevel> device_to_host;
+  bool explicit_device_put = false;
+  bool explicit_device_get = false;
+
+  std::optional<GarbageCollectionGuardLevel> garbage_collect_array;
+};
+
+// Resulting action for a transfer given the transfer guard level and the
+// transfer type.
+enum class TransferGuardAction {
+  // Silently allow the transfer.
+  kAllow,
+  // Log and allow the transfer.
+  kLog,
+  // Disallow the transfer.
+  kDisallow,
+};
+
+// Guards a host-to-device transfer. formatter is called to describe the
+// transfer in a log message or error status.
+// REQUIRES: Python GIL.
+absl::Status ApplyTransferGuardToHostToDevice(
+    absl::FunctionRef<std::string()> formatter);
+
+// Guards a device-to-device transfer. formatter is called to describe the
+// transfer in a log message or error status.
+// REQUIRES: Python GIL.
+absl::Status ApplyTransferGuardToDeviceToDevice(
+    absl::FunctionRef<std::string()> formatter);
+
+// Guards a device-to-host transfer. formatter is called to describe the
+// transfer in a log message or error status.
+// REQUIRES: Python GIL.
+absl::Status ApplyTransferGuardToDeviceToHost(
+    absl::FunctionRef<std::string()> formatter);
+
+// Returns the garbage collection guard level for "jax.Array" objects.
+// REQUIRES: Python GIL.
+GarbageCollectionGuardLevel GetGarbageCollectArrayGuard(); + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildGuardSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_GUARD_LIB_H_ diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 6681f72b7b49..0409397c82de 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -51,6 +51,7 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" +#include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/jax_jit.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_executable.h" @@ -60,7 +61,6 @@ limitations under the License. #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" -#include "xla/python/guard_lib.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index 305582a987f7..a348b47454e7 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -57,12 +57,14 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/py_values.h" #include "jaxlib/xla/sharding.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/util.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/pjrt/exceptions.h" @@ -73,7 +75,6 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/primitive_util.h" -#include "xla/python/guard_lib.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/device.h" @@ -95,7 +96,6 @@ limitations under the License. #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/types.h" -#include "xla/python/util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 5fe6bc648e07..b74c37f28863 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -48,9 +48,12 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/callback.h" +#include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_host_callback.h" #include "jaxlib/xla/py_memory_space.h" #include "jaxlib/xla/py_values.h" #include "xla/literal.h" @@ -61,8 +64,6 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/callback.h" -#include "xla/python/guard_lib.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -79,7 +80,6 @@ limitations under the License. 
 #include "xla/python/pjrt_ifrt/pjrt_client.h"
 #include "xla/python/pjrt_ifrt/xla_compiler.h"
 #include "xla/python/pprof_profile_builder.h"
-#include "xla/python/py_host_callback.h"
 #include "xla/python/python_ref_manager.h"
 #include "xla/python/traceback.h"
 #include "xla/python/types.h"
diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc
new file mode 100644
index 000000000000..936a89aa3b42
--- /dev/null
+++ b/jaxlib/xla/py_client_cpu.cc
@@ -0,0 +1,166 @@
+/* Copyright 2025 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/py_client_cpu.h"
+
+#include <Python.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+#include "xla/ffi/api/ffi.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/pjrt/host_callback.h"
+#include "xla/pjrt/transpose.h"
+#include "xla/primitive_util.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/python/types.h"
+#include "xla/shape_util.h"
+
+namespace nb = nanobind;
+
+namespace xla {
+
+struct CpuTransposePlanCache {
+  static ffi::TypeId id;
+  explicit CpuTransposePlanCache(int capacity) : cache(capacity) {}
+  xla::TransposePlanCache cache;
+};
+
+ffi::TypeId CpuTransposePlanCache::id = {};
+
+XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "CpuTransposePlanCache",
+                      &CpuTransposePlanCache::id);
+
+static ffi::ErrorOr<std::unique_ptr<CpuTransposePlanCache>>
+CpuTransposePlanCacheInstantiate(uint64_t index) {
+  return std::make_unique<CpuTransposePlanCache>(/*capacity=*/16);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    kCpuTransposePlanCacheInstantiate, CpuTransposePlanCacheInstantiate,
+    ffi::Ffi::BindInstantiate().Attr<uint64_t>("index"));
+
+ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
+                                   CpuTransposePlanCache* transpose_cache,
+                                   uint64_t index, ffi::RemainingArgs args,
+                                   ffi::RemainingRets rets) {
+  nb::gil_scoped_acquire gil;
+  auto callback = nb::borrow<nb::callable>(
+      static_cast<PyObject*>(callbacks->callbacks[index]));
+  auto nb_args = nb::steal<nb::tuple>(PyTuple_New(args.size()));
+  for (size_t i = 0; i < args.size(); ++i) {
+    auto arg = args.get<ffi::AnyBuffer>(i);
+    auto ptype = static_cast<PrimitiveType>(arg->element_type());
+    if (ptype == TOKEN) {
+      PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr());
+      continue;
+    }
+    auto maybe_dtype = PrimitiveTypeToNbDtype(ptype);
+    if (!maybe_dtype.ok()) {
+      return ffi::Error::Internal(maybe_dtype.status().ToString());
+    }
+    auto dtype = maybe_dtype.value();
+    auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
+                                          arg->dimensions().size());
+    // We pass in data using default numpy layout i.e., std::nullopt.
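+    // (Default layout here means C-contiguous byte strides: e.g. a float32
+    // array of shape (2, 3) is exposed to Python with byte strides (12, 4).)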
+    auto array = nb_numpy_ndarray(dtype, dims, std::nullopt,
+                                  arg.value().untyped_data());
+    array.attr("flags").attr("writeable") = nb::bool_(false);
+    PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
+  }
+
+  EnterHostCallback();
+  // TODO(dsuo): Change this to use the Python vectorcall protocol, which
+  // allows you to avoid constructing a tuple for the arguments.
+  nb::tuple result_tuple;
+  try {
+    auto result_object = callback(*nb::borrow<nb::args>(nb_args));
+    result_tuple = nb::cast<nb::tuple>(result_object);
+  } catch (nb::python_error& e) {
+    return ffi::Error::Internal(
+        absl::StrFormat("CpuCallback error calling callback: %s", e.what()));
+  }
+  LeaveHostCallback();
+
+  for (size_t i = 0; i < rets.size(); ++i) {
+    auto ret = rets.get<ffi::AnyBuffer>(i).value();
+    auto ptype = static_cast<PrimitiveType>(ret->element_type());
+    if (ptype == TOKEN) continue;
+    nb::object output =
+        nb::borrow<nb::object>(PyTuple_GetItem(result_tuple.ptr(), i));
+    nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output));
+    absl::Span<const int64_t> strides(
+        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
+    // We expect the output to be in default numpy layout.
+    auto dims = absl::Span<const int64_t>(ret->dimensions().begin(),
+                                          ret->dimensions().size());
+    auto maybe_expected_shape = ShapeUtil::MakeValidatedShape(ptype, dims);
+    if (!maybe_expected_shape.ok()) {
+      return ffi::Error::Internal(maybe_expected_shape.status().ToString());
+    }
+    auto expected_shape = maybe_expected_shape.value();
+    auto expected_strides = ByteStridesForShape(expected_shape);
+    if (strides == expected_strides) {
+      std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes());
+      continue;
+    }
+    xla::TransposePlan::Options options;
+    options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
+    options.dims = absl::Span<const int64_t>(
+        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+    absl::InlinedVector<int64_t, 4> reversed_layout;
+    reversed_layout.resize(expected_shape.dimensions_size());
+    absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
+                         reversed_layout.begin());
+    options.permutation = reversed_layout;
+    options.input_layout = xla::TransposePlan::Striding{strides};
+    auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
+    if (!maybe_plan.ok()) {
+      return ffi::Error::Internal(maybe_plan.status().ToString());
+    }
+    auto plan = maybe_plan.value();
+    plan->Execute(array.data(), ret->untyped_data());
+  }
+
+  return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback,
+                              XlaFfiPythonCpuCallback,
+                              ffi::Ffi::Bind()
+                                  .Ctx<ffi::UserData<FfiLoadedHostCallbacks>>()
+                                  .Ctx<ffi::State<CpuTransposePlanCache>>()
+                                  .Attr<uint64_t>("index")
+                                  .RemainingArgs()
+                                  .RemainingRets());
+
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback",
+                         "HOST",
+                         {kCpuTransposePlanCacheInstantiate, nullptr, nullptr,
+                          kXlaFfiPythonCpuCallback});
+}  // namespace xla
diff --git a/jaxlib/xla/py_client_cpu.h b/jaxlib/xla/py_client_cpu.h
new file mode 100644
index 000000000000..0035b0a361fa
--- /dev/null
+++ b/jaxlib/xla/py_client_cpu.h
@@ -0,0 +1,28 @@
+/* Copyright 2025 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_PY_CLIENT_CPU_H_
+#define JAXLIB_XLA_PY_CLIENT_CPU_H_
+
+#include "xla/ffi/api/ffi.h"
+
+namespace xla {
+
+XLA_FFI_DECLARE_HANDLER_SYMBOL(kCpuTransposePlanCacheInstantiate);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback);
+
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_PY_CLIENT_CPU_H_
diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/xla/py_host_callback.cc
new file mode 100644
index 000000000000..9d759cc6b77c
--- /dev/null
+++ b/jaxlib/xla/py_host_callback.cc
@@ -0,0 +1,290 @@
+/* Copyright 2023 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/py_host_callback.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "llvm/Support/ExtensibleRTTI.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/xla/callback.h"
+#include "jaxlib/xla/py_host_callback.pb.h"
+#include "xla/layout_util.h"
+#include "xla/pjrt/host_callback.h"
+#include "xla/pjrt/pjrt_compiler.h"
+#include "xla/python/ifrt/client.h"
+#include "xla/python/ifrt/host_callback.h"
+#include "xla/python/pjrt_ifrt/pjrt_host_callback.h"
+#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h"
+#include "xla/python/python_ref_manager.h"
+#include "xla/python/types.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace nb = nanobind;
+
+namespace xla {
+
+char PyFfiLoadedHostCallback::ID = 0;
+char PyCpuLoadedHostCallback::ID = 0;
+char PyHostSendAndRecvLoadedHostCallback::ID = 0;
+
+namespace {
+
+absl::StatusOr<std::vector<CpuCallback::Arg>> CreateCallbackArgs(
+    absl::Span<const Shape> operand_shapes) {
+  std::vector<CpuCallback::Arg> callback_args(operand_shapes.size());
+  for (int i = 0; i < operand_shapes.size(); ++i) {
+    Shape shape = operand_shapes[i];
+
+    if (shape.IsArray()) {
+      Shape layout =
+          (shape.has_layout() ? shape
+                              : LayoutUtil::GetWithDefaultLayout(shape));
+      callback_args[i].dims.resize(shape.dimensions_size());
+      absl::c_copy(shape.dimensions(), callback_args[i].dims.begin());
+      callback_args[i].strides = ByteStridesForShape(layout);
+      callback_args[i].type = shape.element_type();
+      callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout);
+      TF_ASSIGN_OR_RETURN(callback_args[i].dtype,
+                          PrimitiveTypeToNbDtype(shape.element_type()));
+    } else if (shape.IsToken()) {
+      callback_args[i].type = TOKEN;
+    } else {
+      return InvalidArgument(
+          "Only array and token arguments to Python callbacks are supported, "
+          "got %s",
+          shape.ToString());
+    }
+  }
+  return callback_args;
+}
+
+absl::StatusOr<std::vector<CpuCallback::Result>> CreateCallbackResults(
+    absl::Span<const Shape> result_shapes) {
+  std::vector<CpuCallback::Result> callback_results(result_shapes.size());
+  for (int i = 0; i < result_shapes.size(); ++i) {
+    if (result_shapes[i].IsArray()) {
+      const Shape& shape =
+          result_shapes[i].has_layout()
+              ? result_shapes[i]
+              : LayoutUtil::GetWithDefaultLayout(result_shapes[i]);
+      callback_results[i].expected_dims.resize(shape.dimensions_size());
+      absl::c_copy(shape.dimensions(),
+                   callback_results[i].expected_dims.begin());
+      callback_results[i].expected_strides = ByteStridesForShape(shape);
+      callback_results[i].type = shape.element_type();
+      callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape);
+      callback_results[i].reversed_layout.resize(shape.dimensions_size());
+      absl::c_reverse_copy(shape.layout().minor_to_major(),
+                           callback_results[i].reversed_layout.begin());
+    } else if (result_shapes[i].IsToken()) {
+      callback_results[i].type = TOKEN;
+    } else {
+      return InvalidArgument(
+          "Only array and token return values from Python callbacks are "
+          "supported, got %s",
+          result_shapes[i].ToString());
+    }
+  }
+  return callback_results;
+}
+
+}  // namespace
+
+PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() {
+  // The destructor may be called without GIL held. In that case, we defer it
+  // to GlobalPyRefManager.
+  std::vector<nb::object> objects;
+  objects.push_back(std::move(callable_));
+  GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects));
+}
+
+absl::StatusOr<tsl::RCReference<ifrt::LoadedHostCallback>>
+PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client,
+                                nb::callable callable,
+                                absl::Span<const Shape> operand_shapes,
+                                absl::Span<const Shape> result_shapes) {
+  ifrt::PlatformId platform_id = ifrt_client->platform_id();
+  if (platform_id != CpuId() && platform_id != CudaId() &&
+      platform_id != RocmId() && platform_id != SyclId()) {
+    return Unimplemented("CpuCallback supports CPU and GPU only");
+  }
+
+  TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes));
+  TF_ASSIGN_OR_RETURN(auto callback_results,
+                      CreateCallbackResults(result_shapes));
+
+  // `callable` will be destroyed safely with `PythonRefManager` when
+  // `CpuCallback` is destroyed.
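+  // (The release may be requested from a thread that does not hold the GIL;
+  // PythonRefManager defers dropping the Python reference until it is safe.)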
+  auto cpu_callback = std::make_unique<CpuCallback>(
+      std::move(callable), callback_args, callback_results);
+  return tsl::RCReference<ifrt::LoadedHostCallback>(
+      tsl::MakeRef<PyCpuLoadedHostCallback>(ifrt_client,
+                                            std::move(cpu_callback)));
+}
+
+absl::StatusOr<std::string> PyCpuLoadedHostCallback::Serialize() const {
+  return Unimplemented(
+      "PyCpuLoadedHostCallback serialization is not supported");
+}
+
+absl::StatusOr<tsl::RCReference<ifrt::LoadedHostCallback>>
+PyHostSendAndRecvLoadedHostCallback::Create(
+    ifrt::Client* ifrt_client, nb::callable callable,
+    absl::Span<const Shape> operand_shapes,
+    absl::Span<const Shape> result_shapes,
+    absl::Span<const uint16_t> send_channel_ids,
+    absl::Span<const uint16_t> recv_channel_ids, nb::callable serializer) {
+  TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes));
+  TF_ASSIGN_OR_RETURN(auto callback_results,
+                      CreateCallbackResults(result_shapes));
+
+  // `callable` will be destroyed safely with `PythonRefManager` when
+  // `CpuCallback` is destroyed.
+  auto cpu_callback =
+      std::make_shared<CpuCallback>(callable, callback_args, callback_results);
+
+  auto host_callback = std::make_unique<HostCallback>();
+
+  auto assign_arg_info = [](absl::Span<const Shape> shapes,
+                            absl::Span<const uint16_t> channel_ids,
+                            std::vector<HostCallbackArgInfo>& arg_infos) {
+    DCHECK_EQ(shapes.size(), channel_ids.size());
+    arg_infos.reserve(shapes.size());
+    for (int i = 0; i < shapes.size(); ++i) {
+      HostCallbackArgInfo host_callback_arg_info;
+      host_callback_arg_info.channel_id = channel_ids[i];
+      const auto& shape = shapes[i];
+      Shape layout =
+          (shape.has_layout() ? shape
+                              : LayoutUtil::GetWithDefaultLayout(shape));
+      host_callback_arg_info.shape = layout;
+      arg_infos.push_back(std::move(host_callback_arg_info));
+    }
+  };
+
+  assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands);
+  assign_arg_info(result_shapes, recv_channel_ids, host_callback->results);
+
+  host_callback->callback = [cpu_callback = std::move(cpu_callback)](
+                                void** outputs, void** inputs) {
+    return cpu_callback->PrepareAndCall(outputs, inputs);
+  };
+  return tsl::RCReference<ifrt::LoadedHostCallback>(
+      tsl::MakeRef<PyHostSendAndRecvLoadedHostCallback>(
+          ifrt_client, std::move(host_callback), callable, operand_shapes,
+          result_shapes, send_channel_ids, recv_channel_ids,
+          std::move(serializer)));
+}
+
+PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback(
+    ifrt::Client* ifrt_client,
+    std::unique_ptr<HostCallback> xla_host_callback, nb::callable callable,
+    absl::Span<const Shape> operand_shapes,
+    absl::Span<const Shape> result_shapes,
+    absl::Span<const uint16_t> send_channel_ids,
+    absl::Span<const uint16_t> recv_channel_ids, nb::callable serializer)
+    : llvm::RTTIExtends<PyHostSendAndRecvLoadedHostCallback,
+                        ifrt::PjRtHostSendAndRecvLoadedHostCallback>(
+          ifrt_client, std::move(xla_host_callback)),
+      callable_(std::move(callable)),
+      operand_shapes_(operand_shapes.begin(), operand_shapes.end()),
+      result_shapes_(result_shapes.begin(), result_shapes.end()),
+      send_channel_ids_(send_channel_ids.begin(), send_channel_ids.end()),
+      recv_channel_ids_(recv_channel_ids.begin(), recv_channel_ids.end()),
+      serializer_(serializer) {}
+
+PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() {
+  GlobalPyRefManager()->AddGarbage(
+      absl::MakeSpan(static_cast<nb::object*>(&callable_), 1));
+  GlobalPyRefManager()->AddGarbage(
+      absl::MakeSpan(static_cast<nb::object*>(&serializer_), 1));
+}
+
+absl::StatusOr<std::string> PyHostSendAndRecvLoadedHostCallback::Serialize()
+    const {
+  if (serializer_.is_none()) {
+    return InvalidArgument(
+        "Host callback cannot be serialized because serializer was not "
+        "provided by JAX");
+  }
+  ifrt::XlaHostCallbackProto xla_host_callback_proto;
+
+  TF_RET_CHECK(operand_shapes_.size() == send_channel_ids_.size());
+  for (int i = 0; i < operand_shapes_.size(); ++i) {
+    ifrt::XlaHostCallbackProto::ArgInfo* const operand =
+        xla_host_callback_proto.add_operands();
+    operand->set_channel_id(send_channel_ids_[i]);
+    *operand->mutable_shape() = operand_shapes_[i].ToProto();
+  }
+
+  TF_RET_CHECK(result_shapes_.size() == recv_channel_ids_.size());
+  for (int i = 0; i < result_shapes_.size(); ++i) {
+    ifrt::XlaHostCallbackProto::ArgInfo* const result =
+        xla_host_callback_proto.add_results();
+    result->set_channel_id(recv_channel_ids_[i]);
+    *result->mutable_shape() = result_shapes_[i].ToProto();
+  }
+
+  std::string callable;
+  {
+    nb::gil_scoped_acquire gil_acquire;
+    try {
+      nb::bytes bytes = nb::cast<nb::bytes>(serializer_(callable_));
+      callable = std::string(bytes.c_str(), bytes.size());
+    } catch (const nb::python_error& e) {
+      return absl::InternalError(absl::StrCat(
+          "Unable to pickle the host_callback callable: ", e.what()));
+    } catch (const std::exception& e) {
+      std::exception_ptr p = std::current_exception();
+      return absl::InternalError(absl::StrCat(
+          "Exception while pickling the host_callback callable: ", e.what()));
+    } catch (...) {
+      // Take care not to leak any exception, because this method may be
+      // called outside of a Python context where C++ exceptions are not
+      // necessarily enabled.
+      return absl::InternalError(
+          "Unknown exception while pickling the host_callback callable.");
+    }
+  }
+  PyHostCallbackProto py_host_callback_proto;
+  py_host_callback_proto.set_callable(std::move(callable));
+  if (!xla_host_callback_proto.mutable_serialized_callback()->PackFrom(
+          py_host_callback_proto)) {
+    return absl::InternalError("Could not serialize a Python host callback");
+  }
+  xla_host_callback_proto.set_use_major_to_minor_data_layout_for_callbacks(
+      true);
+  return xla_host_callback_proto.SerializeAsString();
+}
+
+}  // namespace xla
diff --git a/jaxlib/xla/py_host_callback.h b/jaxlib/xla/py_host_callback.h
new file mode 100644
index 000000000000..da504d0c12ca
--- /dev/null
+++ b/jaxlib/xla/py_host_callback.h
@@ -0,0 +1,170 @@
+/* Copyright 2023 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_PY_HOST_CALLBACK_H_
+#define JAXLIB_XLA_PY_HOST_CALLBACK_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/base/casts.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "llvm/Support/ExtensibleRTTI.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/xla/callback.h"
+#include "xla/pjrt/host_callback.h"
+#include "xla/python/ifrt/client.h"
+#include "xla/python/ifrt/host_callback.h"
+#include "xla/python/pjrt_ifrt/pjrt_host_callback.h"
+#include "xla/shape.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/util.h"
+
+namespace xla {
+
+using PyLoadedHostCallback = ::xla::ifrt::LoadedHostCallback;
+
+class PyFfiLoadedHostCallback final
+    : public llvm::RTTIExtends<PyFfiLoadedHostCallback, PyLoadedHostCallback> {
+ public:
+  PyFfiLoadedHostCallback(ifrt::Client* ifrt_client,
+                          nanobind::callable callable)
+      : llvm::RTTIExtends<PyFfiLoadedHostCallback, PyLoadedHostCallback>(
+            ifrt_client, callable.ptr()),
+        callable_(std::move(callable)) {}
+  ~PyFfiLoadedHostCallback() override;
+
+  ifrt::Client* client() const override { return ifrt_client_; }
+  absl::StatusOr<std::string> Serialize() const override {
+    return Unimplemented(
+        "PyCpuLoadedHostCallback::callback_data() is not supported");
+  };
+
+  static char ID;  // NOLINT
+
+ private:
+  ifrt::Client* ifrt_client_;
+  nanobind::callable callable_;
+};
+
+// `PyCpuLoadedHostCallback` implements a Python host callback that uses a
+// descriptor (a raw pointer to JAX `CpuCallback`). The descriptor should be
+// passed into a 'xla_python_cpu_callback' or 'xla_python_gpu_callback'
+// CustomCall as its first argument.
+//
+// Serialization is not supported. Once the descriptor is embedded in
+// CustomCall in an XLA computation, the computation will not be serializable.
+class PyCpuLoadedHostCallback final
+    : public llvm::RTTIExtends<PyCpuLoadedHostCallback, PyLoadedHostCallback> {
+ public:
+  static absl::StatusOr<tsl::RCReference<ifrt::LoadedHostCallback>> Create(
+      ifrt::Client* ifrt_client, nanobind::callable callable,
+      absl::Span<const Shape> operand_shapes,
+      absl::Span<const Shape> result_shapes);
+
+  // Returns the descriptor of `CpuCallback`.
+  uint64_t descriptor() const {
+    return absl::bit_cast<uint64_t>(cpu_callback_.get());
+  }
+
+  CpuCallback* cpu_callback() { return cpu_callback_.get(); }
+
+  // LoadedHostCallback implementation.
+
+  ~PyCpuLoadedHostCallback() override = default;
+
+  ifrt::Client* client() const override { return ifrt_client_; }
+
+  absl::StatusOr<std::string> Serialize() const override;
+
+  static char ID;  // NOLINT
+
+ private:
+  PyCpuLoadedHostCallback(ifrt::Client* ifrt_client,
+                          std::unique_ptr<CpuCallback> cpu_callback)
+      : llvm::RTTIExtends<PyCpuLoadedHostCallback, PyLoadedHostCallback>(
+            ifrt_client, cpu_callback->callback()),
+        cpu_callback_(std::move(cpu_callback)) {}
+
+  template <typename T, typename... Args>
+  friend tsl::RCReference<T> tsl::MakeRef(Args&&... args);
+
+  ifrt::Client* ifrt_client_;
+  std::unique_ptr<CpuCallback> cpu_callback_;
+};
+
+// `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that
+// uses XLA host send and recv. This object should be passed to the compiler
+// when creating `xla::ifrt::LoadedExecutable`.
+//
+// Serialization is supported if the Python host callback is serializable using
+// the `cloudpickle` third-party library.
+//
+// TODO(hyeontaek): Update the comment ("compiler" to "client") after splitting
+// compilation and loading.
+class PyHostSendAndRecvLoadedHostCallback final
+    : public llvm::RTTIExtends<PyHostSendAndRecvLoadedHostCallback,
+                               ifrt::PjRtHostSendAndRecvLoadedHostCallback> {
+ public:
+  static absl::StatusOr<tsl::RCReference<ifrt::LoadedHostCallback>>
+  Create(ifrt::Client* ifrt_client, nanobind::callable callable,
+         absl::Span<const Shape> operand_shapes,
+         absl::Span<const Shape> result_shapes,
+         absl::Span<const uint16_t> send_channel_ids,
+         absl::Span<const uint16_t> recv_channel_ids,
+         nanobind::callable serializer);
+
+  // PjRtLoadedHostCallback implementation.
+
+  ~PyHostSendAndRecvLoadedHostCallback() override;
+
+  absl::StatusOr<std::string> Serialize() const override;
+
+  static char ID;  // NOLINT
+
+ private:
+  PyHostSendAndRecvLoadedHostCallback(
+      ifrt::Client* ifrt_client,
+      std::unique_ptr<xla::HostCallback> xla_host_callback,
+      nanobind::callable callable, absl::Span<const Shape> operand_shapes,
+      absl::Span<const Shape> result_shapes,
+      absl::Span<const uint16_t> send_channel_ids,
+      absl::Span<const uint16_t> recv_channel_ids,
+      nanobind::callable serializer);
+
+  template <typename T, typename... Args>
+  friend tsl::RCReference<T> tsl::MakeRef(Args&&... args);
+
+  // Retained arguments for host callback serialization.
+  nanobind::callable callable_;
+  std::vector<Shape> operand_shapes_;
+  std::vector<Shape> result_shapes_;
+  std::vector<uint16_t> send_channel_ids_;
+  std::vector<uint16_t> recv_channel_ids_;
+  nanobind::callable serializer_;
+};
+
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_PY_HOST_CALLBACK_H_
diff --git a/jaxlib/xla/py_host_callback.proto b/jaxlib/xla/py_host_callback.proto
new file mode 100644
index 000000000000..997fc7fe450c
--- /dev/null
+++ b/jaxlib/xla/py_host_callback.proto
@@ -0,0 +1,25 @@
+/* Copyright 2023 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package xla;
+
+// Represents a JAX host callback that is serialized using the 'cloudpickle'
+// Python library. Typically used for
+// `xla.ifrt.XlaHostCallbackProto.serialized_callback`.
+message PyHostCallbackProto {
+  bytes callable = 1;
+}
diff --git a/jaxlib/xla/util.cc b/jaxlib/xla/util.cc
new file mode 100644
index 000000000000..ef0fb2ac3afd
--- /dev/null
+++ b/jaxlib/xla/util.cc
@@ -0,0 +1,60 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/util.h"
+
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/client.h"
+#include "xla/python/ifrt/future.h"
+#include "xla/python/ifrt/value.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/util.h"
+
+namespace xla {
+
+absl::Status AwaitBuffersReady(absl::Span<ifrt::Array* const> ifrt_arrays) {
+  if (ifrt_arrays.empty()) {
+    return absl::OkStatus();
+  }
+
+  ifrt::Future<> future;
+  if (ifrt_arrays.size() == 1) {
+    future = ifrt_arrays[0]->GetReadyFuture();
+  } else {
+    std::vector<tsl::RCReference<ifrt::Value>> values;
+    values.reserve(ifrt_arrays.size());
+    for (ifrt::Array* const ifrt_array : ifrt_arrays) {
+      values.push_back(tsl::FormRef(ifrt_array));
+    }
+    ifrt::Client* const client = ifrt_arrays.front()->client();
+    future = client->GetReadyFuture(values);
+  }
+
+  absl::Status s = future.Await();
+  if (!s.ok()) {
+    // Fix up error string because some clients rely on it.
+    if (s.message() == "GetReadyFuture() called on deleted or donated buffer") {
+      s = InvalidArgument(
+          "BlockHostUntilReady() called on deleted or donated buffer");
+    }
+  }
+  return s;
+}
+
+}  // namespace xla
diff --git a/jaxlib/xla/util.h b/jaxlib/xla/util.h
new file mode 100644
index 000000000000..ef5fc735fc33
--- /dev/null
+++ b/jaxlib/xla/util.h
@@ -0,0 +1,31 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_UTIL_H_
+#define JAXLIB_XLA_UTIL_H_
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/python/ifrt/array.h"
+
+namespace xla {
+
+// Awaits the readiness of the given buffers and returns OK if all of them
+// become ready, otherwise the last non-OK status.
+absl::Status AwaitBuffersReady(absl::Span<ifrt::Array* const> ifrt_arrays);
+
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_UTIL_H_
diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc
index 0e1ba031670f..a0508013910b 100644
--- a/jaxlib/xla/xla.cc
+++ b/jaxlib/xla/xla.cc
@@ -87,6 +87,7 @@ limitations under the License.
 #include "jaxlib/xla/config.h"
 #include "jaxlib/xla/custom_call_sharding.h"
 #include "jaxlib/xla/dlpack.h"
+#include "jaxlib/xla/guard_lib.h"
 #include "jaxlib/xla/jax_jit.h"
 #include "jaxlib/xla/mlir.h"
 #include "jaxlib/xla/pjit.h"
@@ -109,7 +110,6 @@ limitations under the License.
 #include "xla/pjrt/pjrt_common.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/pjrt_layout.h"
-#include "xla/python/guard_lib.h"
 #include "xla/python/logging.h"  // IWYU pragma: keep
 #include "xla/python/nb_absl_flat_hash_map.h"  // IWYU pragma: keep
 #include "xla/python/nb_absl_span.h"  // IWYU pragma: keep

From 087a38988c5608c300119ff16dce01132a931951 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 25 Mar 2025 16:41:37 -0700
Subject: [PATCH 0164/1769] [sharding_in_types] Add `out_sharding` to
 `jax.random.uniform`. Drop into `Auto` mode inside for implementation.
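
A minimal usage sketch (illustrative only; assumes a mesh with axes 'x' and
'y' is installed, as in the accompanying test):

    from jax.sharding import PartitionSpec as P

    key = jax.random.key(1)
    u = jax.random.uniform(key, shape=(8, 12), out_sharding=P('x', 'y'))
    # u carries the requested sharding without a separate device_put or
    # sharding constraint; internally the implementation drops into Auto
    # mode via auto_axes.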
Co-authored-by: Roy Frostig PiperOrigin-RevId: 740529498 --- jax/_src/random.py | 16 +++++++++++++--- tests/pjit_test.py | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index c0663dc67f80..2d315ed0cc8b 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -38,6 +38,8 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.sharding_impls import canonicalize_sharding +from jax._src.pjit import auto_axes from jax._src.lax import lax as lax_internal from jax._src.numpy.lax_numpy import _convert_and_clip_integer from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact @@ -379,7 +381,8 @@ def uniform(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., - maxval: RealArray = 1.) -> Array: + maxval: RealArray = 1., + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -397,14 +400,21 @@ def uniform(key: ArrayLike, key, _ = _check_prng_key("uniform", key) dtypes.check_user_dtype_supported(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "uniform") if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _uniform(key, shape, dtype, minval, maxval) + return _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) + +@partial(jit, static_argnums=(1, 2, 5)) +def _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) -> Array: + if out_sharding is None: + return _uniform(key, shape, dtype, minval, maxval) + def f(key, minval, maxval): return _uniform(key, shape, dtype, minval, maxval) + return auto_axes(f, out_shardings=out_sharding)(key, minval, maxval) -@partial(jit, static_argnums=(1, 2)) def _uniform(key, shape, dtype, minval, maxval) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ed1f9e9b62d8..528384358351 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7226,6 +7226,25 @@ def f(key): out = f(key) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_uniform(self, mesh): + @jax.jit + def f(key): + out = jax.random.uniform(key, shape=(8, 12), out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From f1a92411872dcb43c1f709701f1163bbf23299be Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 17:02:45 -0700 Subject: [PATCH 0165/1769] Add standard_insert_broadcasts to all traceables in lax.py and checks in abstract_eval rules of those primitives. 
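
Schematically, the pattern applied to each traceable is (sketch based on the
diff below; roughly, `standard_insert_pbroadcast` inserts pbroadcasts so that
operands agree on which mesh axes they vary over before the primitive is
bound, which the abstract_eval rules can then check):

    def add(x: ArrayLike, y: ArrayLike) -> Array:
      x, y = core.standard_insert_pbroadcast(x, y)
      return add_p.bind(x, y)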
PiperOrigin-RevId: 740536031 --- jax/_src/lax/lax.py | 182 ++++++++++++++++++++++++++++++++++------- jax/_src/lax/linalg.py | 14 +++- jax/_src/lax/utils.py | 10 ++- 3 files changed, 172 insertions(+), 34 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index dd6e7399321b..655ef763f1ef 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -296,6 +296,7 @@ def neg(x: ArrayLike) -> Array: .. _stablehlo.negate: https://openxla.org/stablehlo/spec#negate """ + x, = core.standard_insert_pbroadcast(x) return neg_p.bind(x) @export @@ -339,6 +340,7 @@ def sign(x: ArrayLike) -> Array: .. _stablehlo.sign: https://openxla.org/stablehlo/spec#sign """ + x, = core.standard_insert_pbroadcast(x) return sign_p.bind(x) @export @@ -369,6 +371,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array: For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``. """ + x1, x2 = core.standard_insert_pbroadcast(x1, x2) return nextafter_p.bind(x1, x2) @export @@ -390,6 +393,7 @@ def floor(x: ArrayLike) -> Array: .. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor """ + x, = core.standard_insert_pbroadcast(x) return floor_p.bind(x) @export @@ -411,6 +415,7 @@ def ceil(x: ArrayLike) -> Array: .. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil """ + x, = core.standard_insert_pbroadcast(x) return ceil_p.bind(x) class RoundingMethod(enum.IntEnum): @@ -460,6 +465,7 @@ def round(x: ArrayLike, .. _stablehlo.round: https://openxla.org/stablehlo/spec#round """ rounding_method = RoundingMethod(rounding_method) + x, = core.standard_insert_pbroadcast(x) return round_p.bind(x, rounding_method=rounding_method) @export @@ -481,6 +487,7 @@ def is_finite(x: ArrayLike) -> Array: .. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite """ + x, = core.standard_insert_pbroadcast(x) return is_finite_p.bind(x) @export @@ -502,6 +509,7 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ + x, = core.standard_insert_pbroadcast(x) return exp_p.bind(x) @export @@ -525,6 +533,7 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, = core.standard_insert_pbroadcast(x) return exp2_p.bind(x) @export @@ -548,6 +557,7 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ + x, = core.standard_insert_pbroadcast(x) return expm1_p.bind(x) @export @@ -568,6 +578,7 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ + x, = core.standard_insert_pbroadcast(x) return log_p.bind(x) @export @@ -591,6 +602,7 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ + x, = core.standard_insert_pbroadcast(x) return log1p_p.bind(x) @export @@ -613,6 +625,7 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ + x, = core.standard_insert_pbroadcast(x) return tanh_p.bind(x) @export @@ -632,6 +645,7 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ + x, = core.standard_insert_pbroadcast(x) return logistic_p.bind(x) @export @@ -656,6 +670,7 @@ def sin(x: ArrayLike) -> Array: .. 
_stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ + x, = core.standard_insert_pbroadcast(x) return sin_p.bind(x) @export @@ -680,6 +695,7 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ + x, = core.standard_insert_pbroadcast(x) return cos_p.bind(x) @export @@ -704,6 +720,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2 """ + x, y = core.standard_insert_pbroadcast(x, y) return atan2_p.bind(x, y) @export @@ -726,6 +743,7 @@ def real(x: ArrayLike) -> Array: .. _stablehlo.real: https://openxla.org/stablehlo/spec#real """ + x, = core.standard_insert_pbroadcast(x) return real_p.bind(x) @export @@ -748,6 +766,7 @@ def imag(x: ArrayLike) -> Array: .. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag """ + x, = core.standard_insert_pbroadcast(x) return imag_p.bind(x) @export @@ -773,6 +792,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ + x, y = core.standard_insert_pbroadcast(x, y) return complex_p.bind(x, y) @export @@ -799,6 +819,7 @@ def conj(x: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ # TODO(mattjj): remove input_dtype, not needed anymore + x, = core.standard_insert_pbroadcast(x) return conj_p.bind(x, input_dtype=_dtype(x)) @export @@ -819,6 +840,7 @@ def abs(x: ArrayLike) -> Array: .. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs """ + x, = core.standard_insert_pbroadcast(x) return abs_p.bind(x) @export @@ -844,6 +866,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow """ + x, y = core.standard_insert_pbroadcast(x, y) return pow_p.bind(x, y) @export @@ -865,6 +888,7 @@ def integer_pow(x: ArrayLike, y: int) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, = core.standard_insert_pbroadcast(x) return integer_pow_p.bind(x, y=y) @export @@ -886,6 +910,7 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ + x, = core.standard_insert_pbroadcast(x) return sqrt_p.bind(x) @export @@ -908,6 +933,7 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ + x, = core.standard_insert_pbroadcast(x) return rsqrt_p.bind(x) @export @@ -929,6 +955,7 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ + x, = core.standard_insert_pbroadcast(x) return cbrt_p.bind(x) @export @@ -953,6 +980,7 @@ def bitwise_not(x: ArrayLike) -> Array: .. _stablehlo.not: https://openxla.org/stablehlo/spec#not """ + x, = core.standard_insert_pbroadcast(x) return not_p.bind(x) @export @@ -979,6 +1007,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.and: https://openxla.org/stablehlo/spec#and """ + x, y = core.standard_insert_pbroadcast(x, y) return and_p.bind(x, y) @export @@ -1005,6 +1034,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.or: https://openxla.org/stablehlo/spec#or """ + x, y = core.standard_insert_pbroadcast(x, y) return or_p.bind(x, y) @export @@ -1031,6 +1061,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: .. 
_stablehlo.xor: https://openxla.org/stablehlo/spec#xor """ + x, y = core.standard_insert_pbroadcast(x, y) return xor_p.bind(x, y) @export @@ -1052,6 +1083,7 @@ def population_count(x: ArrayLike) -> Array: .. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt """ + x, = core.standard_insert_pbroadcast(x) return population_count_p.bind(x) @export @@ -1072,6 +1104,7 @@ def clz(x: ArrayLike) -> Array: .. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros """ + x, = core.standard_insert_pbroadcast(x) return clz_p.bind(x) @export @@ -1095,6 +1128,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.add: https://openxla.org/stablehlo/spec#add """ + x, y = core.standard_insert_pbroadcast(x, y) return add_p.bind(x, y) @export @@ -1118,6 +1152,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract """ + x, y = core.standard_insert_pbroadcast(x, y) return sub_p.bind(x, y) @export @@ -1171,6 +1206,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide """ + x, y = core.standard_insert_pbroadcast(x, y) return div_p.bind(x, y) @export @@ -1198,6 +1234,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ + x, y = core.standard_insert_pbroadcast(x, y) return rem_p.bind(x, y) @export @@ -1223,6 +1260,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum """ + x, y = core.standard_insert_pbroadcast(x, y) return max_p.bind(x, y) @export @@ -1248,6 +1286,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum """ + x, y = core.standard_insert_pbroadcast(x, y) return min_p.bind(x, y) @export @@ -1273,6 +1312,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left """ + x, y = core.standard_insert_pbroadcast(x, y) return shift_left_p.bind(x, y) @export @@ -1299,6 +1339,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic """ + x, y = core.standard_insert_pbroadcast(x, y) return shift_right_arithmetic_p.bind(x, y) @export @@ -1325,6 +1366,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical """ + x, y = core.standard_insert_pbroadcast(x, y) return shift_right_logical_p.bind(x, y) @export @@ -1355,6 +1397,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return eq_p.bind(x, y) @export @@ -1385,6 +1428,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return ne_p.bind(x, y) @export @@ -1415,6 +1459,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return ge_p.bind(x, y) @export @@ -1445,6 +1490,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array: .. 
_stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return gt_p.bind(x, y) @export @@ -1475,6 +1521,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return le_p.bind(x, y) @export @@ -1505,6 +1552,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return lt_p.bind(x, y) @export @@ -1574,6 +1622,8 @@ def _convert_element_type( "Instead, convert to and from their representation dtypes, e.g.:\n" f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") + + operand, = core.standard_insert_pbroadcast(operand) if isinstance(new_dtype, dtypes.ExtendedDType): return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) @@ -1649,6 +1699,7 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert """ new_dtype = dtypes.canonicalize_dtype(new_dtype) + operand, = core.standard_insert_pbroadcast(operand) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: @@ -1660,6 +1711,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: x & \text{otherwise} \end{cases}`. """ + min, x, max = core.standard_insert_pbroadcast(min, x, max) return clamp_p.bind(min, x, max) @@ -1766,6 +1818,7 @@ def _decorator(*args, **kwargs): closed_jaxpr, out_tree = _trace_composite_to_jaxpr( partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) + flat_args = core.standard_insert_pbroadcast(*flat_args) out_flat = composite_p.bind( *flat_args, name=name, @@ -1883,6 +1936,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: op, = operands if isinstance(op, Array): return op + operands = core.standard_insert_pbroadcast(*operands) return concatenate_p.bind(*operands, dimension=dimension) @@ -1902,6 +1956,7 @@ def split(operand: ArrayLike, sizes: Sequence[int], taken along ``axis``. """ operand = asarray(operand) + operand, = core.standard_insert_pbroadcast(operand) return split_p.bind(operand, sizes=tuple(sizes), axis=canonicalize_axis(axis, operand.ndim)) @@ -2408,6 +2463,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs) return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), @@ -2543,6 +2599,7 @@ def ragged_dot_general( extra leading dimension of size `g` in the case where the lhs ragged dimension is a contracting dimension. 
""" + lhs, rhs, group_sizes = core.standard_insert_pbroadcast(lhs, rhs, group_sizes) return ragged_dot_general_p.bind( lhs, rhs, @@ -2605,6 +2662,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) else: dyn_shape, static_shape = [], shape # type: ignore + operand, = core.standard_insert_pbroadcast(operand) return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), broadcast_dimensions=tuple(broadcast_dimensions), @@ -2671,6 +2729,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape, else: dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) out_sharding = canonicalize_sharding(out_sharding, 'reshape') + operand, = core.standard_insert_pbroadcast(operand) return reshape_p.bind( operand, *dyn_shape, new_sizes=tuple(static_new_sizes), dimensions=None if dims is None or same_dims else dims, @@ -2726,6 +2785,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ + operand, padding_value = core.standard_insert_pbroadcast(operand, padding_value) return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -2733,6 +2793,7 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: `_ operator. """ + operand, = core.standard_insert_pbroadcast(operand) return rev_p.bind(operand, dimensions=tuple(dimensions)) def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: @@ -2758,6 +2819,8 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: """ # Caution! The select_n_p primitive has the *opposite* order of arguments to # select(). This is because it implements `select_n`. + pred, on_false, on_true = core.standard_insert_pbroadcast( + pred, on_false, on_true) return select_n_p.bind(pred, on_false, on_true) def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: @@ -2783,6 +2846,7 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: """ if len(cases) == 0: raise ValueError("select_n() must have at least one case") + which, *cases = core.standard_insert_pbroadcast(which, *cases) return select_n_p.bind(which, *cases) @@ -2796,17 +2860,20 @@ def transpose(operand: ArrayLike, if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array): return operand else: + operand, = core.standard_insert_pbroadcast(operand) return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the minimum element along ``axis``.""" + operand, = core.standard_insert_pbroadcast(operand) return argmin_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) def argmax(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the maximum element along ``axis``.""" + operand, = core.standard_insert_pbroadcast(operand) return argmax_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) @@ -2972,6 +3039,7 @@ def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. 
""" + operand, = core.standard_insert_pbroadcast(operand) return reduce_sum_p.bind(operand, axes=tuple(axes)) def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -2998,6 +3066,7 @@ def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_prod_p.bind(operand, axes=tuple(axes)) def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3019,6 +3088,7 @@ def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_max_p.bind(operand, axes=tuple(axes)) def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3040,6 +3110,7 @@ def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_min_p.bind(operand, axes=tuple(axes)) def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3062,6 +3133,7 @@ def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_or_p.bind(operand, axes=tuple(axes)) def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3084,6 +3156,7 @@ def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_and_p.bind(operand, axes=tuple(axes)) def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3106,6 +3179,7 @@ def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`. 
""" + operand, = core.standard_insert_pbroadcast(operand) return reduce_xor_p.bind(operand, axes=tuple(axes)) @overload @@ -3143,6 +3217,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1, if not (1 <= num_keys <= len(operand)): raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}") dimension = canonicalize_axis(dimension, len(operand[0].shape)) + operand = core.standard_insert_pbroadcast(*operand) return tuple(sort_p.bind(*operand, dimension=dimension, is_stable=is_stable, num_keys=num_keys)) @@ -3190,6 +3265,7 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k = int(k) if k < 0: raise ValueError(f"k argument to top_k must be nonnegative, got {k}") + operand, = core.standard_insert_pbroadcast(operand) return top_k_p.bind(operand, k=k) def tie_in(x: Any, y: T) -> T: @@ -3375,7 +3451,9 @@ def reduce_precision(operand: float | ArrayLike, operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision") mantissa_bits = core.concrete_or_error( operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision") - return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) + operand, = core.standard_insert_pbroadcast(operand) + return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, + mantissa_bits=mantissa_bits) def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: """Squeeze any number of size 1 dimensions from an array.""" @@ -3383,6 +3461,7 @@ def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions)) if not dimensions and isinstance(array, Array): return array + array, = core.standard_insert_pbroadcast(array) return squeeze_p.bind(array, dimensions=dimensions) def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -3503,6 +3582,7 @@ def batch_matmul(lhs: Array, rhs: Array, def square(x: ArrayLike) -> Array: r"""Elementwise square: :math:`x^2`.""" + x, = core.standard_insert_pbroadcast(x) return square_p.bind(x) def reciprocal(x: ArrayLike) -> Array: @@ -3530,6 +3610,7 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ + x, = core.standard_insert_pbroadcast(x) return tan_p.bind(x) @export @@ -3550,6 +3631,7 @@ def asin(x: ArrayLike) -> Array: - :func:`jax.lax.acos`: elementwise arc cosine. - :func:`jax.lax.atan`: elementwise arc tangent. """ + x, = core.standard_insert_pbroadcast(x) return asin_p.bind(x) @export @@ -3570,6 +3652,7 @@ def acos(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan`: elementwise arc tangent. """ + x, = core.standard_insert_pbroadcast(x) return acos_p.bind(x) @export @@ -3591,6 +3674,7 @@ def atan(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan2`: elementwise 2-term arc tangent. """ + x, = core.standard_insert_pbroadcast(x) return atan_p.bind(x) @export @@ -3611,6 +3695,7 @@ def sinh(x: ArrayLike) -> Array: - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ + x, = core.standard_insert_pbroadcast(x) return sinh_p.bind(x) @export @@ -3631,6 +3716,7 @@ def cosh(x: ArrayLike) -> Array: - :func:`jax.lax.sinh`: elementwise hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. 
""" + x, = core.standard_insert_pbroadcast(x) return cosh_p.bind(x) @export @@ -3651,6 +3737,7 @@ def asinh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.sinh`: elementwise hyperbolic sine. """ + x, = core.standard_insert_pbroadcast(x) return asinh_p.bind(x) @export @@ -3671,6 +3758,7 @@ def acosh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. """ + x, = core.standard_insert_pbroadcast(x) return acosh_p.bind(x) @export @@ -3691,6 +3779,7 @@ def atanh(x: ArrayLike) -> Array: - :func:`jax.lax.asinh`: elementwise inverse hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ + x, = core.standard_insert_pbroadcast(x) return atanh_p.bind(x) @@ -3759,7 +3848,8 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs): def unop(result_dtype, accepted_dtypes, name): dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name) prim = standard_primitive(_attrgetter('shape'), dtype_rule, name, - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), + vma_rule=_attrgetter('vma')) batching.defvectorized(prim) pe.def_trivial_padding(prim) return prim @@ -4314,7 +4404,7 @@ def _integer_pow_jvp(g, x, *, y): integer_pow_p = standard_primitive( _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow', - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), vma_rule=_attrgetter('vma')) batching.defvectorized(integer_pow_p) ad.defjvp(integer_pow_p, _integer_pow_jvp) pe.def_trivial_padding(integer_pow_p) @@ -4883,7 +4973,8 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): bitcast_convert_type_p = standard_primitive( _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule, 'bitcast_convert_type', weak_type_rule=_strip_weak_type, - sharding_rule=_bitcast_convert_type_sharding_rule) + sharding_rule=_bitcast_convert_type_sharding_rule, + vma_rule=partial(standard_vma_rule, 'bitcast_convert_type')) ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) @@ -5352,6 +5443,7 @@ def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): _dot_general_dtype_rule, 'dot_general', sharding_rule=_dot_general_sharding_rule, + vma_rule=partial(standard_vma_rule, 'dot_general') ) @@ -6352,7 +6444,10 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, new_sharding = _broadcast_in_dim_sharding_rule( x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) + new_vma = (standard_vma_rule('broadcast_in_dim', x) + if config.varying_axes_in_types.value else frozenset()) + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, + vma=new_vma) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code @@ -6436,7 +6531,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): return clamp_p.bind(min, x, max), 0 clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp', - sharding_rule=_clamp_sharding_rule) + sharding_rule=_clamp_sharding_rule, + vma_rule=partial(standard_vma_rule, 'clamp')) ad.defjvp(clamp_p, lambda g, min, operand, max: select(bitwise_and(gt(min, operand), lt(min, max)), @@ -6523,7 +6619,8 @@ def 
_concatenate_pad_rule(in_avals, out_avals, *operands, dimension): concatenate_p = standard_primitive( _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', - sharding_rule=_concatenate_sharding_rule) + sharding_rule=_concatenate_sharding_rule, + vma_rule=partial(standard_vma_rule, 'concatenate')) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule @@ -6595,11 +6692,17 @@ def _split_sharding_rule(operand, *, sizes, axis): return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split') for out_sh in out_shapes] +def _split_vma_rule(operand, *, sizes, axis): + out_vma = standard_vma_rule('split', operand) + out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) + return [out_vma] * len(out_shapes) + split_p = core.Primitive('split') split_p.multiple_results = True split_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, - _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule)) + _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule, + _split_vma_rule)) split_p.def_impl(partial(dispatch.apply_primitive, split_p)) ad.deflinear2(split_p, _split_transpose_rule) batching.primitive_batchers[split_p] = _split_batch_rule @@ -6681,7 +6784,8 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): return select(mask, x, broadcasted_padding), operand_bdim pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', - sharding_rule=_pad_sharding_rule) + sharding_rule=_pad_sharding_rule, + vma_rule=partial(standard_vma_rule, 'pad')) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule @@ -6745,7 +6849,8 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze', sharding_rule=_squeeze_sharding_rule) + 'squeeze', sharding_rule=_squeeze_sharding_rule, + vma_rule=partial(standard_vma_rule, 'squeeze')) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -6979,7 +7084,8 @@ def _reshape_staging_rule( return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape', sharding_rule=_reshape_sharding_rule) + 'reshape', sharding_rule=_reshape_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reshape')) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule batching.skippable_batchers[reshape_p] = lambda _: () @@ -7011,7 +7117,8 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): return rev(operand, new_dimensions), bdim rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev', - sharding_rule=_rev_sharding_rule) + sharding_rule=_rev_sharding_rule, + vma_rule=partial(standard_vma_rule, 'rev')) ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule @@ -7059,7 +7166,8 @@ def _transpose_lower(ctx, x, *, permutation): transpose_p = standard_primitive( _transpose_shape_rule, _input_dtype, 'transpose', - sharding_rule=_transpose_sharding_rule) + sharding_rule=_transpose_sharding_rule, + vma_rule=partial(standard_vma_rule, 'transpose')) 
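# For orientation, a hedged sketch of what the `vma_rule`s wired up above are
# assumed to compute: to a first approximation, the standard rule unions the
# varying manual axes of all operand avals (this is a simplification by this
# note, not a quote of `standard_vma_rule` itself):
#
#   def sketch_vma_rule(name, *avals, **kwargs):
#     out = frozenset()
#     for aval in avals:
#       out |= aval.vma
#     return out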
ad.deflinear2(transpose_p, lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule @@ -7235,7 +7343,8 @@ def _select(offset, cases): select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', - weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule) + weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule, + vma_rule=partial(standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule @@ -7341,7 +7450,8 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule, + None)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -7415,7 +7525,8 @@ def _reduce_op_sharding_rule(operand, *, axes): reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), - 'reduce_sum', sharding_rule=_reduce_op_sharding_rule) + 'reduce_sum', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_sum')) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, @@ -7430,7 +7541,8 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes): reduce_prod_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'), - 'reduce_prod', sharding_rule=_reduce_op_sharding_rule) + 'reduce_prod', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, @@ -7450,7 +7562,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_max_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_max', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, @@ -7460,7 +7573,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_min_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_min', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, @@ -7527,13 +7641,15 @@ def _compute_argminmax(value_comparator, get_identity, argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(standard_vma_rule, 'argmin')) 
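# Usage sketch for the argmin/argmax primitives being registered here (the
# input values are assumptions of this note, with results worked out by hand):
#
#   x = jnp.array([[3., 1., 2.],
#                  [0., 5., 4.]])
#   lax.argmin(x, axis=1, index_dtype=jnp.int32)  # -> Array([1, 0], dtype=int32)
#   lax.argmax(x, axis=0, index_dtype=jnp.int32)  # -> Array([0, 1, 1], dtype=int32)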
batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmax', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(standard_vma_rule, 'argmax')) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) @@ -7556,20 +7672,23 @@ def _reduce_logical_sharding_rule(operand, *, axes): reduce_or_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_or', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_or')) batching.defreducer(reduce_or_p, _get_bitwise_or_identity) reduce_and_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_and', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule reduce_xor_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_xor', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_xor')) batching.defreducer(reduce_xor_p, _get_bitwise_or_identity) @@ -7616,7 +7735,8 @@ def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits): reduce_precision_p = standard_primitive( _reduce_precision_shape_rule, partial(unop_dtype_rule, _identity, _float, 'reduce_precision'), - name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule) + name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_precision')) ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)]) batching.defvectorized(reduce_precision_p) @@ -7893,6 +8013,7 @@ def after_all(*operands): """Merges one or more XLA token values. Experimental. Wraps the XLA AfterAll operator.""" + operands = core.standard_insert_pbroadcast(*operands) return after_all_p.bind(*operands) def _after_all_abstract_eval(*operands): @@ -8027,6 +8148,7 @@ def rng_uniform(a, b, shape): This API may be removed at any time. 
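  For illustration only (a sketch: outputs are backend-dependent and not
  reproducible across runs, so only the shape is checked here):

  >>> out = jax.jit(lambda: jax.lax.rng_uniform(0.0, 1.0, (2, 3)))()
  >>> out.shape
  (2, 3)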
""" + a, b = core.standard_insert_pbroadcast(a, b) return rng_uniform_p.bind(a, b, shape=tuple(shape)) def _rng_uniform_abstract_eval(a, b, *, shape): @@ -8161,7 +8283,8 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule)) + _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule, + None)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -8245,6 +8368,7 @@ def rng_bit_generator(key, shape, dtype=np.uint32, if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') + key, = core.standard_insert_pbroadcast(key) return tuple( rng_bit_generator_p.bind( key, shape=shape, dtype=dtype, algorithm=algorithm, @@ -8703,8 +8827,10 @@ def optimization_barrier(operand, /): Array(0., dtype=float32, weak_type=True) """ flat_args, treedef = tree_util.tree_flatten(operand) - return tree_util.tree_unflatten( - treedef, optimization_barrier_p.bind(*flat_args)) + # TODO(yashkatariya): Enable this + # flat_args = core.standard_insert_pbroadcast(flat_args) + out = optimization_barrier_p.bind(*flat_args) + return tree_util.tree_unflatten(treedef, out) def _optimization_barrier_abstract_eval(*args): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index b22a4cf56062..b455257e107c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -740,6 +740,14 @@ def linalg_sharding_rule( ndim = len(output_shapes) - len(batch_spec) return sharding.with_spec(P(*(tuple(batch_spec) + (None,) * ndim))) +def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): + output_shapes = shape_rule(*avals, **kwargs) + out_vma = lax_internal.standard_vma_rule(name, *avals) + if multiple_results: + return [out_vma] * len(output_shapes) + else: + return out_vma + def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, multiple_results=False, supports_batching=True, require_same=True): @@ -754,6 +762,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, linalg_sharding_rule, multiple_results, shape_rule, ranks, name) else: sharding_rule = None + vma_rule = partial(linalg_vma_rule, multiple_results, shape_rule, name) prim = core.Primitive(name) prim.multiple_results = multiple_results prim.def_impl(partial(dispatch.apply_primitive, prim)) @@ -761,11 +770,12 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_multi_result_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, - sharding_rule)) + sharding_rule, vma_rule)) else: prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, - lax_utils._standard_weak_type_rule, sharding_rule, None)) + lax_utils._standard_weak_type_rule, sharding_rule, + partial(lax_internal.standard_vma_rule, name))) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 63088d665afd..0a641c122064 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -131,7 +131,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, raise TypeError(avals, least_specialized) def 
standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule, *avals, **kwargs): assert prim.multiple_results assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals @@ -141,11 +141,13 @@ def standard_multi_result_abstract_eval( core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) + out_vmas = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value + else [frozenset()] * len(out_shapes)) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) - out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh) - for s, d, weak_type, sh in zip(out_shapes, out_dtypes, - weak_types, out_shardings)] + out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, vma=vma) + for s, d, weak_type, sh, vma in zip( + out_shapes, out_dtypes, weak_types, out_shardings, out_vmas)] core.check_avals_context_mesh(out_avals, prim.name) return out_avals elif least_specialized is core.UnshapedArray: From cc5141201976cbc1ce823cf24a5b6dd26412d888 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 17:11:44 -0700 Subject: [PATCH 0166/1769] [sharding_in_types] Add out_sharding to `jax.random.normal`. Drop into `Auto` mode inside for implementation. Co-authored-by: Roy Frostig PiperOrigin-RevId: 740538785 --- jax/_src/random.py | 9 +++++++-- tests/pjit_test.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 2d315ed0cc8b..7277ed5aa966 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -690,7 +690,8 @@ def choice(key: ArrayLike, def normal(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + out_sharding=None) -> Array: r"""Sample standard normal random values with given shape and float dtype. 
The values are returned according to the probability density function: @@ -712,12 +713,16 @@ def normal(key: ArrayLike, """ key, _ = _check_prng_key("normal", key) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, 'normal') dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _normal(key, shape, dtype) + if out_sharding is None: + return _normal(key, shape, dtype) + return auto_axes(partial(_normal, shape=shape, dtype=dtype), + out_shardings=out_sharding)(key) @partial(jit, static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 528384358351..d6673c6b6d5a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7245,6 +7245,25 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_normal(self, mesh): + @jax.jit + def f(key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 3a593219d413a081247cae309872617bf5d2819f Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 17:40:47 -0700 Subject: [PATCH 0167/1769] [jaxlib:cpu] Cleaning up after callback FFI refactor. PiperOrigin-RevId: 740547947 --- CHANGELOG.md | 2 ++ jax/_src/callback.py | 1 + jaxlib/xla/callback.cc | 13 --------- jaxlib/xla/callback.h | 4 --- jaxlib/xla/py_host_callback.cc | 31 -------------------- jaxlib/xla/py_host_callback.h | 53 +--------------------------------- 6 files changed, 4 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1acb2b48eab6..93bbe81b5e63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` instead. + * Implemented host callback handlers for CPU and GPU devices using XLA's FFI + and removed existing CPU/GPU handlers using XLA's custom call. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 92c275e7e924..dc60bfb94356 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -826,6 +826,7 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None + # TODO(dsuo): Remove this once we bump minimum_jaxlib_version to "0.5.4". 
if xla_extension_version <= 320: result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) if token: diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index 4eab8290c7bb..2df1715d099f 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include -#include "absl/base/casts.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -40,7 +39,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" #include "xla/python/python_ref_manager.h" -#include "xla/service/custom_call_status.h" #include "xla/tsl/platform/statusor.h" namespace nb = nanobind; @@ -170,15 +168,4 @@ absl::StatusOr CpuCallback::Call(nb::tuple args) { return result_tuple; } -void XlaPythonCpuCallback(void* output, void** inputs, - XlaCustomCallStatus* status) { - CpuCallback* callback = - absl::bit_cast(*static_cast(inputs[0])); - auto s = callback->PrepareAndCall(output, inputs + 1); - if (!s.ok()) { - auto msg = s.message(); - XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); - } -} - } // namespace xla diff --git a/jaxlib/xla/callback.h b/jaxlib/xla/callback.h index b63025efe120..ebd0aaca4e6d 100644 --- a/jaxlib/xla/callback.h +++ b/jaxlib/xla/callback.h @@ -28,7 +28,6 @@ limitations under the License. #include "nanobind/nanobind.h" #include "xla/pjrt/transpose.h" #include "xla/python/nb_numpy.h" -#include "xla/service/custom_call_status.h" #include "xla/xla_data.pb.h" namespace xla { @@ -84,9 +83,6 @@ class CpuCallback { xla::TransposePlanCache transpose_cache_; }; -void XlaPythonCpuCallback(void* output, void** inputs, - XlaCustomCallStatus* status); - } // namespace xla #endif // JAXLIB_XLA_CALLBACK_H_ diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/xla/py_host_callback.cc index 9d759cc6b77c..833079335a36 100644 --- a/jaxlib/xla/py_host_callback.cc +++ b/jaxlib/xla/py_host_callback.cc @@ -34,7 +34,6 @@ limitations under the License. #include "jaxlib/xla/py_host_callback.pb.h" #include "xla/layout_util.h" #include "xla/pjrt/host_callback.h" -#include "xla/pjrt/pjrt_compiler.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" @@ -54,7 +53,6 @@ namespace nb = nanobind; namespace xla { char PyFfiLoadedHostCallback::ID = 0; -char PyCpuLoadedHostCallback::ID = 0; char PyHostSendAndRecvLoadedHostCallback::ID = 0; namespace { @@ -128,35 +126,6 @@ PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() { GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); } -absl::StatusOr> -PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client, - nb::callable callable, - absl::Span operand_shapes, - absl::Span result_shapes) { - ifrt::PlatformId platform_id = ifrt_client->platform_id(); - if (platform_id != CpuId() && platform_id != CudaId() && - platform_id != RocmId() && platform_id != SyclId()) { - return Unimplemented("CpuCallback supports CPU and GPU only"); - } - - TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); - TF_ASSIGN_OR_RETURN(auto callback_results, - CreateCallbackResults(result_shapes)); - - // `callable` will be destroyed safely with `PythonRefManager` when - // `CpuCallback` is destroyed. 
-  auto cpu_callback = std::make_unique(
-      std::move(callable), callback_args, callback_results);
-  return tsl::RCReference(
-      tsl::MakeRef(ifrt_client,
-                                             std::move(cpu_callback)));
-}
-
-absl::StatusOr PyCpuLoadedHostCallback::Serialize() const {
-  return Unimplemented(
-      "PyCpuLoadedHostCallback serialization is not supported");
-}
-
 absl::StatusOr>
 PyHostSendAndRecvLoadedHostCallback::Create(
     ifrt::Client* ifrt_client, nb::callable callable,
diff --git a/jaxlib/xla/py_host_callback.h b/jaxlib/xla/py_host_callback.h
index da504d0c12ca..1a1402a4eee2 100644
--- a/jaxlib/xla/py_host_callback.h
+++ b/jaxlib/xla/py_host_callback.h
@@ -22,13 +22,10 @@ limitations under the License.
 #include
 #include

-#include "absl/base/casts.h"
-#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 #include "nanobind/nanobind.h"
-#include "jaxlib/xla/callback.h"
 #include "xla/pjrt/host_callback.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/host_callback.h"
@@ -56,7 +53,7 @@ class PyFfiLoadedHostCallback final
   ifrt::Client* client() const override { return ifrt_client_; }
   absl::StatusOr Serialize() const override {
     return Unimplemented(
-        "PyCpuLoadedHostCallback::callback_data() is not supported");
+        "PyFfiLoadedHostCallback::Serialize() is not supported");
   };
   static char ID;  // NOLINT
@@ -66,54 +63,6 @@ class PyFfiLoadedHostCallback final
   nanobind::callable callable_;
 };

-// `PyCpuLoadedHostCallback` implements a Python host callback that uses a
-// descriptor (a raw pointer to JAX `CpuCallback`). The descriptor should be
-// passed into a 'xla_python_cpu_callback' or 'xla_python_gpu_callback'
-// CustomCall as its first argument.
-//
-// Serialization is not supported. Once the descriptor is embedded in
-// CustomCall in an XLA computation, the computation will not be serializable.
-class PyCpuLoadedHostCallback final
-    : public llvm::RTTIExtends {
- public:
-  static absl::StatusOr> Create(
-      ifrt::Client* ifrt_client, nanobind::callable callable,
-      absl::Span operand_shapes,
-      absl::Span result_shapes);
-
-  // Returns the descriptor of `CpuCallback`.
-  uint64_t descriptor() const {
-    return absl::bit_cast(cpu_callback_.get());
-  }
-
-  CpuCallback* cpu_callback() { return cpu_callback_.get(); }
-
-  // LoadedHostCallback implementation.
-
-  ~PyCpuLoadedHostCallback() override = default;
-
-  ifrt::Client* client() const override { return ifrt_client_; }
-
-  absl::StatusOr Serialize() const override;
-
-  static char ID;  // NOLINT
-
- private:
-  PyCpuLoadedHostCallback(ifrt::Client* ifrt_client,
-                          std::unique_ptr cpu_callback)
-      : llvm::RTTIExtends(
-            ifrt_client, cpu_callback->callback()),
-        cpu_callback_(std::move(cpu_callback)) {}
-
-  template
-  friend tsl::RCReference tsl::MakeRef(Args&&... args);
-
-  ifrt::Client* ifrt_client_;
-  std::unique_ptr cpu_callback_;
-};
-
 // `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that
 // uses XLA host send and recv. This object should be passed to the compiler
 // when creating `xla::ifrt::LoadedExecutable`.

From fd5c1dc8a6855eed485df81a03c75e96f569399b Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Tue, 25 Mar 2025 19:04:48 -0700
Subject: [PATCH 0168/1769] [jaxlib:cpu] Return an error if we try to use
 subbyte types in CPU callbacks instead of failing silently.

We will be adding subbyte type support in subsequent changes.
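A minimal sketch of what this change surfaces from JAX on CPU, mirroring the
test added later in this series (the repro below is a hypothetical
illustration, not code from this patch):

    import jax
    import jax.numpy as jnp

    def f(x):
      # Identity callback; the host transfer is what trips the new check.
      return jax.pure_callback(
          lambda v: v, jax.ShapeDtypeStruct((8,), jnp.int4), x)

    jax.jit(f)(jnp.arange(8, dtype=jnp.int8).astype(jnp.int4))
    # Now fails with "Unsupported primitive type: S4" instead of silently
    # mis-handling the subbyte buffer.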
PiperOrigin-RevId: 740569954
---
 jaxlib/xla/BUILD            |  1 +
 jaxlib/xla/py_client_cpu.cc | 15 +++++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD
index e4db73d6c86d..e10977d526ed 100644
--- a/jaxlib/xla/BUILD
+++ b/jaxlib/xla/BUILD
@@ -664,6 +664,7 @@ cc_library(
         "@nanobind",
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:shape_util",
+        "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",
         "@xla//xla/pjrt:host_callback",
diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc
index 936a89aa3b42..ac4e7bee5680 100644
--- a/jaxlib/xla/py_client_cpu.cc
+++ b/jaxlib/xla/py_client_cpu.cc
@@ -41,6 +41,7 @@ limitations under the License.
 #include "xla/python/nb_numpy.h"
 #include "xla/python/types.h"
 #include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"

 namespace nb = nanobind;

@@ -77,6 +78,13 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
   for (size_t i = 0; i < args.size(); ++i) {
     auto arg = args.get(i);
     auto ptype = static_cast(arg->element_type());
+    // TODO(b/395428868): Remove this check once we support subbyte types.
+    if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 ||
+        ptype == U2 || ptype == U4) {
+      return ffi::Error(ffi::ErrorCode::kUnimplemented,
+                        absl::StrFormat("Unsupported primitive type: %s",
+                                        PrimitiveType_Name(ptype)));
+    }
     if (ptype == TOKEN) {
       PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr());
       continue;
@@ -111,6 +119,13 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
   for (size_t i = 0; i < rets.size(); ++i) {
     auto ret = rets.get(i).value();
     auto ptype = static_cast(ret->element_type());
+    // TODO(b/395428868): Remove this check once we support subbyte types.
+    if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 ||
+        ptype == U2 || ptype == U4) {
+      return ffi::Error(ffi::ErrorCode::kUnimplemented,
+                        absl::StrFormat("Unsupported primitive type: %s",
+                                        PrimitiveType_Name(ptype)));
+    }
     if (ptype == TOKEN) continue;
     nb::object output =
         nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i));

From 81abbac53675f9d7fa71af822ef394413f2f86d5 Mon Sep 17 00:00:00 2001
From: Matt Bahr
Date: Tue, 25 Mar 2025 06:36:28 +0000
Subject: [PATCH 0169/1769] add pascal matrix

---
 docs/jax.scipy.rst       |  1 +
 jax/_src/scipy/linalg.py | 61 ++++++++++++++++++++++++++++++++++++++++
 jax/scipy/linalg.py      |  1 +
 tests/linalg_test.py     | 16 +++++++++++
 4 files changed, 79 insertions(+)

diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst
index dcbb673997ad..3c436697e1be 100644
--- a/docs/jax.scipy.rst
+++ b/docs/jax.scipy.rst
@@ -69,6 +69,7 @@ jax.scipy.linalg
    lu
    lu_factor
    lu_solve
+   pascal
    polar
    qr
    rsf2csf
diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py
index 9917cbaa0b12..55961607b252 100644
--- a/jax/_src/scipy/linalg.py
+++ b/jax/_src/scipy/linalg.py
@@ -2182,3 +2182,64 @@ def hilbert(n: int) -> Array:
   """
   a = lax.broadcasted_iota(jnp.float64, (n, 1), 0)
   return 1/(a + a.T + 1)
+
+@partial(jit, static_argnames=("n", "kind",))
+def pascal(n: int, kind: str | None = None) -> Array:
+  r"""Create a Pascal matrix approximation of order n.
+
+  JAX implementation of :func:`scipy.linalg.pascal`.
+
+  The elements of the Pascal matrix approximate the binomial coefficients. This
+  implementation is not exact, as JAX does not support exact factorials.
+
+  Args:
+    n: the size of the matrix to create.
+    kind: (optional) must be one of ``lower``, ``upper``, or ``symmetric`` (default).
+
+  Returns:
+    A Pascal matrix of shape ``(n, n)``
+
+  Examples:
+    >>> with jnp.printoptions(precision=3):
+    ...   print(jax.scipy.linalg.pascal(3, kind="lower"))
+    ...   print(jax.scipy.linalg.pascal(4, kind="upper"))
+    ...   print(jax.scipy.linalg.pascal(5))
+    [[1. 0. 0.]
+     [1. 1. 0.]
+     [1. 2. 1.]]
+    [[1. 1. 1. 1.]
+     [0. 1. 2. 3.]
+     [0. 0. 1. 3.]
+     [0. 0. 0. 1.]]
+    [[ 1.  1.  1.  1.  1.]
+     [ 1.  2.  3.  4.  5.]
+     [ 1.  3.  6. 10. 15.]
+     [ 1.  4. 10. 20. 35.]
+     [ 1.  5. 15. 35. 70.]]
+  """
+  if kind is None:
+    kind = "symmetric"
+
+  valid_kind = ["symmetric", "lower", "upper"]
+
+  if kind not in valid_kind:
+    raise ValueError(f"Expected kind to be one of: {valid_kind}; got {kind}")
+
+  a = jnp.arange(n, dtype=jnp.float32)
+
+  L_n = _binom(a[:, None], a[None, :])
+
+  if kind == "lower":
+    return L_n
+
+  if kind == "upper":
+    return L_n.T
+
+  return jnp.dot(L_n, L_n.T)
+
+@jit
+def _binom(n, k):
+  a = lax.lgamma(n + 1.0)
+  b = lax.lgamma(n - k + 1.0)
+  c = lax.lgamma(k + 1.0)
+  return lax.exp(a - b - c)
diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py
index 64bc0544000b..c8a2d5f81957 100644
--- a/jax/scipy/linalg.py
+++ b/jax/scipy/linalg.py
@@ -31,6 +31,7 @@
   lu as lu,
   lu_factor as lu_factor,
   lu_solve as lu_solve,
+  pascal as pascal,
   polar as polar,
   qr as qr,
   rsf2csf as rsf2csf,
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 60e507d84782..20c998d6a685 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -2329,6 +2329,22 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output):
     self.assertAllClose(
         new_product_with_batching, old_product, atol=atol)

+  @jtu.sample_product(
+      n=[0, 1, 5, 10, 20],
+      kind=["symmetric", "lower", "upper"],
+  )
+  @jax.default_matmul_precision("float32")
+  def testPascal(self, n, kind):
+    args_maker = lambda: []
+    osp_fun = partial(osp.linalg.pascal, n=n, kind=kind, exact=False)
+    jsp_fun = partial(jsp.linalg.pascal, n=n, kind=kind)
+    self._CheckAgainstNumpy(osp_fun,
+                            jsp_fun, args_maker,
+                            atol=1e-3,
+                            rtol=1e-2 if jtu.test_device_matches(['tpu']) else 1e-3,
+                            check_dtypes=False)
+    self._CompileAndCheck(jsp_fun, args_maker)
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())

From fd7775856e1de1ee161c821e123f0fc2cd21fc9d Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Tue, 25 Mar 2025 19:35:13 -0700
Subject: [PATCH 0170/1769] [jaxlib:gpu] Return an error if we try to use
 subbyte types in GPU callbacks instead of failing silently.

We will be adding subbyte type support in subsequent changes.

PiperOrigin-RevId: 740577676
---
 jaxlib/cuda/BUILD           |  1 +
 jaxlib/gpu/py_client_gpu.cc | 17 ++++++++++++++++-
 jaxlib/rocm/BUILD           |  1 +
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index 48441632fba9..c47bc3c8126f 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -682,6 +682,7 @@ cc_library(
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:comparison_util",
         "@xla//xla:shape_util",
+        "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",
         "@xla//xla/pjrt:host_callback",
diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc
index c39d5201f223..59cc385825a0 100644
--- a/jaxlib/gpu/py_client_gpu.cc
+++ b/jaxlib/gpu/py_client_gpu.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -79,6 +80,13 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < arity; ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } if (ptype == xla::TOKEN) { host_input_buffers[i] = nullptr; continue; @@ -115,7 +123,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, - host_input_buffers[i], base); + host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); } @@ -137,6 +145,13 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } if (ptype == xla::TOKEN) continue; nb::object output = nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 258556be8b1e..2c13228d3c51 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -580,6 +580,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:host_callback", From 83989f6fc674a35a599a5d3cbfbdb5aa8a23fd2a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 01:29:10 -0700 Subject: [PATCH 0171/1769] [Pallas/Mosaic GPU] Add a test tracking primitives warpgroup lowering rules. The goal is to use this to figure out when we can enable warpgroup lowering by default. PiperOrigin-RevId: 740670338 --- tests/pallas/mosaic_gpu_test.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7f8cfa21e980..32851797fff5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -26,9 +26,13 @@ from absl.testing import parameterized import jax from jax import lax +from jax._src import pjit from jax._src import test_util as jtu from jax._src.pallas import pallas_call +from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline +from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives +from jax._src.state import discharge from jax.experimental import pallas as pl from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp @@ -1351,6 +1355,29 @@ class PallasCallWGTest( ): ... 
+ def test_missing_primitive_lowerings_are_tracked(self): + # This test is a way to keep track of which primitives need to be adapted + # to using warpgroup semantics. Once the set is empty, we should be able to + # enable warpgroup semantics by default (assuming we haven't overspecialized + # lowerings). + rules = mgpu_lowering.mosaic_lowering_rules + wg_lowered_primitives = set(rules[plgpu.ThreadSemantics.Warpgroup]) + lane_lowered_primitives = set(rules[plgpu.ThreadSemantics.Lane]) + + actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives + expected_missing_primitives = { + lax.optimization_barrier_p, + mgpu_primitives.broadcasted_iota_p, + lax.exp2_p, + mgpu_primitives.layout_cast_p, + mgpu_primitives.load_p, + pjit.mesh_cast_p, + lax.slice_p, + discharge.run_state_p, + } + + self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) + class PallasCallSm90ATest(PallasSm90ATest): From 660f536300e23680c14b26373d111c07477f669a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 02:22:01 -0700 Subject: [PATCH 0172/1769] [Pallas/Mosaic GPU] Add a lowering rule for `lax.optimization_barrier_p` with warpgroup semantics. PiperOrigin-RevId: 740684030 --- jax/_src/pallas/mosaic_gpu/lowering.py | 12 ++++++++++ tests/pallas/mosaic_gpu_test.py | 31 +++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 493d8c07b941..c5436c818e1d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2272,6 +2272,18 @@ def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): return mgpu.optimization_barrier(*args) +@register_lowering_rule( + lax.optimization_barrier_p, mgpu.ThreadSemantics.Warpgroup +) +def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): + args = [ + _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) + ] + result = mgpu.dialect.optimization_barrier(args) + + return (result,) if len(args) == 1 else result + + def _bcast( x: ir.Value, y: ir.Value, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 32851797fff5..73440ebf5fa5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1349,6 +1349,36 @@ def convert(x_ref, y_ref): convert(x), jax.lax.bitcast_convert_type(x, out_dtype) ) + def test_optimization_barrier(self): + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + self.skipTest("This test crashes with lane semantics") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.optimization_barrier(x_ref[...]) + + x = jax.lax.iota(jnp.float32, 128) + np.testing.assert_array_equal(kernel(x), x) + + def test_optimization_barrier_multiple_inputs(self): + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + self.skipTest("This test crashes with lane semantics") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + x, y = lax.optimization_barrier([x_ref[...], y_ref[...]]) + o_ref[...] 
= x + y + + x = jax.lax.iota(jnp.float32, 128) + y = jax.lax.iota(jnp.float32, 128) * 3 + np.testing.assert_array_equal(kernel(x, y), x + y) + class PallasCallWGTest( PallasCallTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup @@ -1366,7 +1396,6 @@ def test_missing_primitive_lowerings_are_tracked(self): actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives expected_missing_primitives = { - lax.optimization_barrier_p, mgpu_primitives.broadcasted_iota_p, lax.exp2_p, mgpu_primitives.layout_cast_p, From 3f3081d46ed77f8fd37a7497013c26df8abaa62c Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 04:45:34 -0700 Subject: [PATCH 0173/1769] [Pallas/Mosaic GPU] Add a lowering rule for `pjit.mesh_cast_p` for warpgroup semantics. PiperOrigin-RevId: 740719326 --- jax/_src/pallas/mosaic_gpu/lowering.py | 3 +++ tests/pallas/mosaic_gpu_test.py | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c5436c818e1d..c67633125fc0 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1202,9 +1202,12 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ) @register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Warpgroup) def _mesh_cast_lowering_rule(ctx, x, dst_sharding): + del ctx, dst_sharding # Unused. return x + @register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 73440ebf5fa5..4531bd568913 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -26,7 +26,6 @@ from absl.testing import parameterized import jax from jax import lax -from jax._src import pjit from jax._src import test_util as jtu from jax._src.pallas import pallas_call from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering @@ -1400,7 +1399,6 @@ def test_missing_primitive_lowerings_are_tracked(self): lax.exp2_p, mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, - pjit.mesh_cast_p, lax.slice_p, discharge.run_state_p, } From 9ff08909557c1a322740f15cada3e6514152a321 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 26 Mar 2025 04:49:46 -0700 Subject: [PATCH 0174/1769] [jax:callbacks] Add a test for callbacks with subbyte types. Today, we have TPU support for subbyte types, but not on CPU/GPU. Explicitly raise an error for now with a TODO for when we implement CPU/GPU support. PiperOrigin-RevId: 740720316 --- jaxlib/xla/xla_client.py | 2 +- tests/python_callback_test.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 0e4eebdfb26f..a9b1109c3bd3 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 321 +_version = 322 # Version number for MLIR:Python components. 
mlir_api_version = 58 diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 05b4c8d7c0ff..5650a2d4f48b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -28,6 +28,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util +from jax._src.lib import xla_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -585,6 +586,56 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) + @parameterized.parameters("int2", "int4", "uint2", "uint4") + def test_subbyte_operands(self, dtype: str): + if xla_extension_version <= 321: + self.skipTest("Requires xla_extension_version >= 322.") + def get(x): + return x + def f(x): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype=dtype), + x, + ) + return y + x = np.arange(8, dtype=dtype) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)(x) + + @parameterized.parameters("int2", "int4", "uint2", "uint4") + def test_subbyte_results(self, dtype: str): + if xla_extension_version <= 321: + self.skipTest("Requires xla_extension_version >= 322.") + def get(): + return np.arange(8, dtype=dtype) + + def f(): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype) + ) + return y + + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)() + class PureCallbackTest(jtu.JaxTestCase): From 07ebcb2d63c3fa38d84c8c7eceef23c9d980bab8 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 26 Mar 2025 05:16:44 -0700 Subject: [PATCH 0175/1769] [Mosaic] Use large 2nd minor tiling for x2. To avoid relayout from (16, 128) to (128, 128) because we always use native tiling for ext/trunc. PiperOrigin-RevId: 740726621 --- .../mosaic/dialect/tpu/transforms/infer_memref_layout.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index cdf48632784b..fdfd04949bce 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -62,18 +62,21 @@ int getTilingFactor(const int src_sublane, const int hardware_generation, const int max_normal_tiling = tiling_sublane; int large_tiling = [&] { + if (bitwidth == 2) { + return target_sublane_count * 16; + } if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) { - return tiling_sublane * 8; + return target_sublane_count * 8; } if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { - return tiling_sublane * 4; + return target_sublane_count * 4; } // 16-bit values are generally always possible to relayout on the fly in v6, // so we allow large 2nd minor tiling whenever possible. 
We can't do this // for kernel arguments, because the layout of those is controlled by XLA. if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor || (!is_kernel_argument && hardware_generation >= 6))) { - return tiling_sublane * 2; + return target_sublane_count * 2; } return tiling_sublane; }(); From 5e3330cf8cf93744f4a6ce512443ca1ee936bc3b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 05:16:55 -0700 Subject: [PATCH 0176/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d0b25f9cd8222a348c9728f88e909c4e2c30991b. PiperOrigin-RevId: 740726667 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8fcda2281ea7..359048ffacbb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d505fef9c5eb6cc1bf282fdf62139783d7fe4ec5" -XLA_SHA256 = "4fe51bd389428ce65415b08693f966b142fe8218ced771becab9033503a70a3d" +XLA_COMMIT = "d0b25f9cd8222a348c9728f88e909c4e2c30991b" +XLA_SHA256 = "8cd70a67a56a8b18087fc4849908f52c95c6413eb7edc9f800fdff6304804fa4" def repo(): tf_http_archive( From 7a42e3d39d9beff823469ba0c87722248e6ace29 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 26 Mar 2025 07:06:29 -0700 Subject: [PATCH 0177/1769] [pallas:mosaic_gpu] `thread_semantics=` should still default to lane-level PiperOrigin-RevId: 740753009 --- jax/_src/pallas/mosaic_gpu/pallas_call_registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 5399727878a6..40b12215c003 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -57,7 +57,7 @@ def pallas_call_lowering( print(grid_mapping) thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mgpu.ThreadSemantics.Warpgroup + "thread_semantics", mgpu.ThreadSemantics.Lane ) if thread_semantics == mgpu.ThreadSemantics.Warpgroup: mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error From c15921243936d8027de28678b3ad199f9ac498d5 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 25 Mar 2025 22:33:21 +0000 Subject: [PATCH 0178/1769] Some codebase fixes required for python 3.14 - Fix for "SyntaxWarning: 'return' in a 'finally' block" - Fix for "AttributeError: 'typing.Union' object attribute '__doc__' is read-only" --- jax/_src/basearray.py | 13 +++++++++++-- jax/_src/util.py | 5 +++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index a89d4a2949be..fbd14d157e78 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -17,6 +17,7 @@ from __future__ import annotations import abc +import sys import numpy as np from typing import Any, Union from collections.abc import Sequence @@ -175,7 +176,11 @@ def copy_to_host_async(self): np.bool_, np.number, # NumPy scalar types bool, int, float, complex, # Python scalar types ] -StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." 
+ +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." # ArrayLike is a Union of all objects that can be implicitly converted to a @@ -187,4 +192,8 @@ def copy_to_host_async(self): np.ndarray, # NumPy array type StaticScalar, # valid scalars ] -ArrayLike.__doc__ = "Type annotation for JAX array-like objects." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + ArrayLike.__doc__ = "Type annotation for JAX array-like objects." diff --git a/jax/_src/util.py b/jax/_src/util.py index d558954e881c..b3f7becee7eb 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -417,8 +417,9 @@ def wrapper(fun: T) -> T: else docstr.format(fun=name, doc=doc, **kwargs)) fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__) fun.__wrapped__ = wrapped - finally: - return fun + except Exception: + pass + return fun return wrapper From 9f40440d476aee980b519fd82911e2ec5102466c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 07:56:27 -0700 Subject: [PATCH 0179/1769] Add missing `jax` wheel dependencies. PiperOrigin-RevId: 740767116 --- BUILD.bazel | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index ebf852a60924..e7cf6de66cad 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -47,6 +47,7 @@ transitive_py_deps( "//jax:sparse_test_util", "//jax:test_util", "//jax/_src/lib", + "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", "//jax/experimental/array_serialization:serialization", "//jax/experimental/jax2tf", @@ -105,14 +106,19 @@ jax_source_package( ) genrule( - name = "internal_test_util_sources", + name = "wheel_additives", srcs = [ "//jax:internal_export_back_compat_test_util", "//jax:internal_test_harnesses", "//jax:internal_test_util", "//jax:internal_export_back_compat_test_data", + "//jax:experimental/pallas/ops/tpu/random/philox.py", + "//jax:experimental/pallas/ops/tpu/random/prng_utils.py", + "//jax:experimental/pallas/ops/tpu/random/threefry.py", + "//jax/experimental/mosaic/gpu/examples:flash_attention.py", + "//jax/experimental/mosaic/gpu/examples:matmul.py", ], - outs = ["internal_test_util_sources.zip"], + outs = ["wheel_additives.zip"], cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", tools = ["@bazel_tools//tools/zip:zipper"], ) @@ -131,15 +137,16 @@ COMMON_DEPS = py_deps([ py_import( name = "jax_py_import", wheel = ":jax_wheel", - wheel_deps = [":internal_test_util_sources"], + wheel_deps = [":wheel_additives"], deps = COMMON_DEPS, ) -# This target is used to add internal test util sources to the jax wheel. -# This is needed for the tests that depend on jax and use internal test util sources. +# This target is used to add more sources to the jax wheel. +# This is needed for the tests that depend on jax and use modules that are not part of +# the jax wheel, but share the same package paths as the modules in the jax wheel. py_import( name = "jax_wheel_with_internal_test_util", wheel = "@pypi_jax//:whl", - wheel_deps = [":internal_test_util_sources"], + wheel_deps = [":wheel_additives"], deps = COMMON_DEPS, ) From dfa2f46968f07797ba9c21b2570f651f2c123c69 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 08:51:37 -0700 Subject: [PATCH 0180/1769] [Pallas/Mosaic GPU] Delete `mesh_cast_p` lowering rules. They don't seem to be used. 
PiperOrigin-RevId: 740785108 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c67633125fc0..fc3bdaac7aed 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1201,12 +1201,6 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args, ) -@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Warpgroup) -def _mesh_cast_lowering_rule(ctx, x, dst_sharding): - del ctx, dst_sharding # Unused. - return x - @register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) def _slice_lowering_rule( From 9d768c475454f484769d15ebe177b8c63f3620bb Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 26 Mar 2025 09:09:20 -0700 Subject: [PATCH 0181/1769] [pallas:mgpu] Use the ExitStack context to manage smem allocations. PiperOrigin-RevId: 740790684 --- jax/_src/pallas/mosaic_gpu/lowering.py | 134 ++++++++++++------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index fc3bdaac7aed..e2c4ce322b1a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1838,76 +1838,76 @@ def _run_scoped_lowering_rule( ): input_refs = [] should_discharge = [] - alloc_stack = contextlib.ExitStack() - for v in jaxpr.invars: - aval = v.aval - if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - dtype = mlir.dtype_to_ir_type(aval.dtype) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: - input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + with contextlib.ExitStack() as alloc_stack: + for v in jaxpr.invars: + aval = v.aval + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + dtype = mlir.dtype_to_ir_type(aval.dtype) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + else: + zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) + acc = mgpu.dialect.optimization_barrier([acc]) + nvvm_dialect.wgmma_fence_aligned() + input_refs.append(acc) + should_discharge.append(True) + elif isinstance(aval.dtype, gpu_core.BarrierType): + input_refs.append( + ctx.module_ctx.reserve_barrier( + mgpu.Barrier( + aval.dtype.num_arrivals + * ctx.estimator_ctx.arrival_multiplier, + *aval.shape, + ) + ) + ) + should_discharge.append(False) + elif aval.memory_space == gpu_core.SMEM: + [input_ref] = alloc_stack.enter_context( + ctx.module_ctx.scratch_view( + [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.TMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.alloc_tmem( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) else: - zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) - acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) - acc = mgpu.dialect.optimization_barrier([acc]) - nvvm_dialect.wgmma_fence_aligned() - input_refs.append(acc) - should_discharge.append(True) - elif isinstance(aval.dtype, 
gpu_core.BarrierType): - input_refs.append( - ctx.module_ctx.reserve_barrier( - mgpu.Barrier( - aval.dtype.num_arrivals - * ctx.estimator_ctx.arrival_multiplier, - *aval.shape, - ) - ) - ) - should_discharge.append(False) - elif aval.memory_space == gpu_core.SMEM: - [input_ref] = alloc_stack.enter_context( - ctx.module_ctx.scratch_view( - [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] - ) - ) - input_refs.append(input_ref) - should_discharge.append(False) - elif aval.memory_space == gpu_core.TMEM: - input_ref = alloc_stack.enter_context( - ctx.module_ctx.alloc_tmem( - jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), - ) + raise ValueError(f"Can't convert to ref: {aval}") + + if any(should_discharge): + # We convert consts to args, because we only have ir.Values and + # not JAX values during lowering. discharge_state() produces JAX + # valiues for the aguments but expects them to be provided for the + # consts. We also don't want to wrap the values in refs. + no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) + should_discharge = [False] * len(consts) + should_discharge + discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) + new_input_vals = consts + tuple(input_refs) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + discharged_jaxpr, + new_input_vals, + (), ) - input_refs.append(input_ref) - should_discharge.append(False) + # Discharge appends to the output the refs that got discharged. + outs = outs[:-sum(should_discharge)] else: - raise ValueError(f"Can't convert to ref: {aval}") - - if any(should_discharge): - # We convert consts to args, because we only have ir.Values and - # not JAX values during lowering. discharge_state() produces JAX - # valiues for the aguments but expects them to be provided for the - # consts. We also don't want to wrap the values in refs. - no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) - should_discharge = [False] * len(consts) + should_discharge - discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) - new_input_vals = consts + tuple(input_refs) - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - discharged_jaxpr, - new_input_vals, - (), - ) - # Discharge appends to the output the refs that got discharged. - outs = outs[:-sum(should_discharge)] - else: - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - jaxpr, - input_refs, - consts, - ) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + jaxpr, + input_refs, + consts, + ) assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) return outs From 6851d6a1c81dab8437c3d7a7bc94c4df66ac9af6 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Wed, 26 Mar 2025 09:11:36 -0700 Subject: [PATCH 0182/1769] Skip some array_extensibility_tests on TPU. 
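The skip is driven by a new `skip_on_devices` field on the `NumPyAPI`
NamedTuple, set through a small builder method. Usage, exactly as
exercised in the hunk below:

    # Mark an API entry as unsupported on TPU:
    NumPyAPI.sig(jnp.roots, Float[5]).with_skip_on_devices(['tpu'])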
PiperOrigin-RevId: 740791514 --- tests/array_extensibility_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 3e84f6668b8d..7c0ec07e6a05 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -53,6 +53,7 @@ class NumPyAPI(NamedTuple): fun: Callable[..., Any] args: list[jax.ShapeDtypeStruct] kwargs: dict[str, Any] + skip_on_devices: list[str] | None def name(self): return self.fun.__name__ @@ -61,9 +62,12 @@ def make_args(self, rng): rng = jtu.rand_default(rng) return jax.tree.map(lambda arg: rng(arg.shape, arg.dtype), self.args) + def with_skip_on_devices(self, disabled_devices: list[str]) -> 'NumPyAPI': + return self._replace(skip_on_devices=disabled_devices) + @classmethod def sig(cls, fun: Callable[..., Any], *args: Any, **kwargs: Any) -> 'NumPyAPI': - return cls(fun, args, kwargs) + return cls(fun, args, kwargs, None) class ShapeDtype: @@ -444,7 +448,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.rint, Float[5]), NumPyAPI.sig(jnp.roll, Float[5], Int[1]), NumPyAPI.sig(jnp.rollaxis, Float[5, 4], axis=1), - NumPyAPI.sig(jnp.roots, Float[5]), + NumPyAPI.sig(jnp.roots, Float[5]).with_skip_on_devices(['tpu']), NumPyAPI.sig(jnp.rot90, Float[5, 3]), NumPyAPI.sig(jnp.round, Float[5]), NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), @@ -512,6 +516,8 @@ class JaxArrayTests(jtu.JaxTestCase): @parameterized.named_parameters( {'testcase_name': api.name(), 'api': api} for api in NUMPY_APIS) def test_numpy_api_supports_jax_array(self, api): + if api.skip_on_devices and jtu.test_device_matches(api.skip_on_devices): + self.skipTest(f'{api.name()} not supported on {api.skip_on_devices}') fun = api.fun args = api.make_args(self.rng()) wrapped_args = jax.tree.map(JaxArrayWrapper, args) From 6386efe369dfa5c234c36c205dda6b270a1a91eb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 26 Mar 2025 09:46:26 -0700 Subject: [PATCH 0183/1769] [pallas:mosaic_gpu] `plgpu.kernel` now accepts scratch shapes This frees the caller from another level of indirection via `pl.run_scoped`. PiperOrigin-RevId: 740802977 --- jax/_src/pallas/mosaic_gpu/core.py | 18 +++-- tests/pallas/mosaic_gpu_test.py | 102 ++++++++++++++--------------- 2 files changed, 63 insertions(+), 57 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 99a84962ae50..19007b6850fd 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,7 +18,7 @@ import abc import collections -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses import enum import itertools as it @@ -31,6 +31,7 @@ from jax._src import tree_util from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types @@ -114,20 +115,29 @@ def __call__( shape: tuple[int, ...], dtype: jnp.dtype, transforms: Sequence[MemoryRefTransform] = (), - ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. 
return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) -def kernel(body, out_shape, *, compiler_params=None, **mesh_kwargs): +def kernel( + body: Callable[..., None], + out_shape: object, + *, + scratch_shapes: Sequence[pallas_core.ScratchShape] = (), + compiler_params: object | None = None, + **mesh_kwargs: object, +): if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) def wrapper(*operands): def stateful(operand_and_out_refs): operand_refs, out_refs = operand_and_out_refs def cmap_body(): - body(*operand_refs, *out_refs) + pallas_primitives.run_scoped( + lambda *scratch_refs: body(*operand_refs, *out_refs, *scratch_refs), + *scratch_shapes, + ) pallas_core.core_map( GPUMesh(**mesh_kwargs), compiler_params=compiler_params )(cmap_body) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 4531bd568913..c10f06f8bb5d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1681,22 +1681,19 @@ def test_tmem_alloc(self): @functools.partial( plgpu.kernel, out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 128), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32), + ], num_threads=1, axis_names=("x",), ) - def kernel(y_ref): - def scope(tmem_ref, smem_ref): - # Issue a write so the TMEM load is not DCE'd. - smem_ref[...] = tmem_ref[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(smem_ref, y_ref) - plgpu.wait_smem_to_gmem(0) - - pl.run_scoped( - scope, - plgpu.TMEM((128, 128), jnp.float32), - plgpu.SMEM((128, 128), jnp.float32), - ) + def kernel(y_ref, tmem_ref, smem_ref): + # Issue a write so the TMEM load is not DCE'd. + smem_ref[...] = tmem_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) # Test that this runs without errors. jax.block_until_ready(kernel()) @@ -2122,7 +2119,18 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): blk_m = blk_n = 64 - def _scoped(acc_smem, x_gmem, acc_gmem): + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((blk_m, blk_n), jnp.float32), + ], + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + num_threads=num_compute_wgs + 1, + axis_names=("_", "wg"), + ) + def kernel(x_gmem, acc_gmem, acc_smem): def _compute_thread(): # Cast the init value to the same layout as x_smem, so the pipeline loop # carry has a constant signature. @@ -2162,20 +2170,6 @@ def tiled_acc_kernel(x_smem, carry): ) pipeline(x_gmem) - @functools.partial( - plgpu.kernel, - out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - grid=(1,), - num_threads=num_compute_wgs + 1, - axis_names=("_", "wg"), - ) - def kernel(x_ref, acc_ref): - pl.run_scoped( - functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), - plgpu.SMEM((blk_m, blk_n), jnp.float32), - ) - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) @@ -2259,18 +2253,16 @@ def test_cross_wg_barrier(self): @functools.partial( plgpu.kernel, out_shape=jnp.zeros((2, 128), np.int32), + # Each warpgroup is a single logical thread! 
+ scratch_shapes=[plgpu.Barrier(num_arrivals=2)], num_threads=2, axis_names=("wg",), ) - def kernel(y_ref): - def scoped(barrier): - plgpu.barrier_arrive(barrier) - plgpu.barrier_wait(barrier) - wg_idx = jax.lax.axis_index("wg") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - - # Each warpgroup is a single logical thread! - pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) + def kernel(o_ref, barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) np.testing.assert_array_equal( kernel(), np.repeat([0, 1], 128).reshape(2, 128) @@ -2329,25 +2321,29 @@ def body(l_ref, r_ref, o_ref): # Async copies def test_stage3(self): row_block, col_block = 64, 128 - def body(l_ref, r_ref, o_ref): + + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), + scratch_shapes=[ + *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), + plgpu.Barrier(num_arrivals=2), + ], + grid=(2,), + axis_names=("rows",), + ) + def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) - def scoped(l_smem, r_smem, o_smem, barrier): - plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) - plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) - plgpu.barrier_wait(barrier) - o_smem[...] = l_smem[...] + r_smem[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped( - scoped, - *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), - plgpu.Barrier(num_arrivals=2), - ) + plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) + plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) + plgpu.barrier_wait(barrier) + o_smem[...] = l_smem[...] + r_smem[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) + plgpu.wait_smem_to_gmem(0) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Pipelining def test_stage4(self): From 2057df13ba70996324b3617436ffd04639f89dd7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 09:49:37 -0700 Subject: [PATCH 0184/1769] [Pallas/Mosaic GPU] Fix `copy_smem_to_gmem` lowering to not use a `single_thread_predicate` when using warpgroup semantics. Also avoid generating the predicate at all when using warpgroup semantics. 
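For reference, a condensed sketch of the resulting predicate selection
(names taken from the hunks below; this is a summary of the new logic,
not the verbatim lowering code):

    if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
      # Lane semantics: AND any user predicate into the single-lane one.
      predicate = ctx.module_ctx.single_wg_lane_predicate
      if user_predicate is not None:
        predicate = arith_dialect.andi(user_predicate, predicate)
    else:
      # Warpgroup semantics: no single-lane predicate is generated at all.
      predicate = user_predicate  # may be None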
PiperOrigin-RevId: 740803927 --- jax/_src/pallas/mosaic_gpu/lowering.py | 10 ++++++++-- jax/_src/pallas/mosaic_gpu/primitives.py | 17 +++++++++++++---- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 12 +++++------- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e2c4ce322b1a..a41a657ba738 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -292,7 +292,7 @@ class ModuleContext: axis_names: _AxisNames | None program_ids: Sequence[ir.Value] | None approx_math: bool - single_wg_lane_predicate: ir.Value + single_wg_lane_predicate: ir.Value | None smem_requested_bytes: int smem_used_bytes: int tmem_requested_cols: int @@ -703,12 +703,18 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS else: tmem_cols = 0 + + if thread_semantics == mgpu.ThreadSemantics.Lane: + single_lane_predicate = mgpu.single_thread_predicate(per_block=False) + else: # Warpgroup semantics do not have a single lane predicate. + single_lane_predicate = None + module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), axis_names, [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, - mgpu.single_thread_predicate(per_block=False), + single_lane_predicate, smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, tmem_requested_cols=tmem_cols, diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 8eafa0ac8e6d..9dc65c1bef88 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -186,12 +186,21 @@ def _copy_smem_to_gmem_lowering( has_user_predicate, commit_group, ): - predicate = ctx.module_ctx.single_wg_lane_predicate if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] - predicate = arith_dialect.andi( - predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) - ) + predicate = lowering._ensure_ir_value(user_predicate, jnp.bool) + else: + predicate = None + + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if predicate is not None: + assert ctx.module_ctx.single_wg_lane_predicate is not None + predicate = arith_dialect.andi( + predicate, ctx.module_ctx.single_wg_lane_predicate + ) + else: + predicate = ctx.module_ctx.single_wg_lane_predicate + flat_src_transforms, flat_dst_transforms = util.split_list( flat_args, [src_transforms_treedef.num_leaves], diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index f0a37084b759..85929080faec 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -281,12 +281,11 @@ def MosaicGPU_AsyncLoadOp : Op Date: Wed, 26 Mar 2025 09:51:13 -0700 Subject: [PATCH 0185/1769] [AutoPGLE] Prevent an AutoPGLE to run if user launched an external profiler. 
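The PGLE profiler session now fails fast when another profiler is already
running, by enabling `raise_error_on_start_failure` (only available in
jaxlib > 0.5.3). A condensed sketch mirroring the hunk below:

    options = xla_client.profiler.ProfileOptions()
    options.enable_hlo_proto = True
    if jaxlib_version > (0, 5, 3):  # option does not exist in older jaxlib
      options.raise_error_on_start_failure = True
    runner.current_session = xla_client.profiler.ProfilerSession(options)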
Reverts d4745b9bd81b49e2a7a8938ea98516296d54635f PiperOrigin-RevId: 740804528 --- jax/_src/profiler.py | 5 +++++ jaxlib/xla/xla_extension/profiler.pyi | 1 + tests/pgle_test.py | 14 ++++++++++++-- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f06933f57e22..96e742f33904 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -33,6 +33,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version _profiler_server: xla_client.profiler.ProfilerServer | None = None @@ -426,6 +427,10 @@ def trace(cls, runner: PGLEProfiler | None): else: options = xla_client.profiler.ProfileOptions() options.enable_hlo_proto = True + + # ToDo(patrios): Remove when jaxlib version is updated to 0.5.4. + if jaxlib_version > (0, 5, 3): + options.raise_error_on_start_failure = True runner.current_session = xla_client.profiler.ProfilerSession(options) try: diff --git a/jaxlib/xla/xla_extension/profiler.pyi b/jaxlib/xla/xla_extension/profiler.pyi index 7610ce1000bf..95749f61978a 100644 --- a/jaxlib/xla/xla_extension/profiler.pyi +++ b/jaxlib/xla/xla_extension/profiler.pyi @@ -42,6 +42,7 @@ class ProfileOptions: start_timestamp_ns: int duration_ms: int repository_path: str + raise_error_on_start_failure: bool def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ... diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7f9ea598d51b..7dabd809d95e 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -65,7 +65,11 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -93,6 +97,8 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -321,7 +327,11 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. 
+ 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y From 91a07ea2e8911a5b6fab7b989d28402cc0176352 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Mar 2025 08:40:43 -0700 Subject: [PATCH 0186/1769] Clean up a number of finalized deprecations --- jax/__init__.py | 5 ----- jax/_src/numpy/lax_numpy.py | 13 +------------ jax/core.py | 23 ----------------------- jax/interpreters/xla.py | 13 ------------- jax/lib/xla_bridge.py | 9 --------- jax/lib/xla_client.py | 23 ----------------------- jax/numpy/__init__.py | 16 ---------------- jax/numpy/__init__.pyi | 3 +-- jax/sharding.py | 15 --------------- tests/lax_numpy_test.py | 5 ----- 10 files changed, 2 insertions(+), 123 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index ae3bac4ad3fa..988c224e4772 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -220,11 +220,6 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), - # Finalized Nov 12 2024; remove after Feb 12 2025 - "clear_backends": ( - "jax.clear_backends was removed in JAX v0.4.36", - None - ), } import typing as _typing diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 16355695792d..fd6209ab22c4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1944,8 +1944,7 @@ def isrealobj(x: Any) -> bool: @export def reshape( - a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, - newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), + a: ArrayLike, shape: DimSize | Shape, order: str = "C", *, copy: bool | None = None) -> Array: """Return a reshaped copy of an array. @@ -1962,8 +1961,6 @@ def reshape( JAX does not support ``order="A"``. copy: unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away. - newshape: deprecated alias of the ``shape`` argument. Will result in a - :class:`DeprecationWarning` if used. Returns: reshaped copy of input array with the specified shape. @@ -2021,14 +2018,6 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. - if not isinstance(newshape, DeprecatedArg): - raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." - " Use shape instead.") - if shape is None: - raise TypeError( - "jnp.shape requires passing a `shape` argument, but none was given." - ) try: # forward to method for ndarrays return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] diff --git a/jax/core.py b/jax/core.py index 3fd7af440d4a..b404e66c2691 100644 --- a/jax/core.py +++ b/jax/core.py @@ -160,29 +160,6 @@ _src_core.lattice_join), "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", _src_core.raise_to_shaped), - # Finalized 2024-12-11; remove after 2025-3-11 - "check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None), - "check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None), - "check_valid_jaxtype": ( - ("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually" - " raise an error if core.valid_jaxtype() returns False."), - None), - "non_negative_dim": ( - "jax.core.non_negative_dim was removed in JAX v0.4.38. 
Use max_dim(..., 0).", None, - ), - # Finalized 2024-09-25; remove after 2024-12-25 - "pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None), - "pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None), - "pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None), - "pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None), - "pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None), - "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None), - "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None), - "pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None), - "pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None), - "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), - "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), - "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), } import typing diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bd3b83e37d24..2f8417ade1f8 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -38,19 +38,6 @@ "jax.interpreters.xla.pytype_aval_mappings is deprecated.", _src_core.pytype_aval_mappings ), - # Finalized 2024-10-24; remove after 2025-01-24 - "xb": ( - ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " - "Use jax.lib.xla_bridge instead."), None - ), - "xc": ( - ("jax.interpreters.xla.xc was removed in JAX v0.4.36. " - "Use jax.lib.xla_client instead."), None - ), - "xe": ( - ("jax.interpreters.xla.xe was removed in JAX v0.4.36. " - "Use jax.lib.xla_extension instead."), None - ), } import typing as _typing diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index b158d9b1ff51..95598c447262 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -27,15 +27,6 @@ "jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.", _deprecated_get_backend ), - # Finalized 2024-12-11; remove after 2025-3-11 - "xla_client": ( - "jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.", - None - ), - "default_backend": ( - "jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.", - None - ), } import typing as _typing diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 07c6914a1f59..314788bfa5e7 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -26,27 +26,6 @@ Traceback = _xc.Traceback _deprecations = { - # Finalized 2024-12-11; remove after 2025-3-11 - "_xla": ( - "jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.", - None, - ), - "bfloat16": ( - "jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.", - None, - ), - # Finalized 2024-12-23; remove after 2024-03-23 - "Device": ( - "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", - None, - ), - "XlaRuntimeError": ( - ( - "jax.lib.xla_client.XlaRuntimeError is deprecated; use" - " jax.errors.JaxRuntimeError." 
- ), - None, - ), # Finalized 2025-03-25; remove after 2025-06-25 "FftType": ( "jax.lib.xla_client.FftType was removed in JAX v0.6.0; use jax.lax.FftType.", @@ -106,12 +85,10 @@ ops = _xc.ops register_custom_call_target = _xc.register_custom_call_target ArrayImpl = _xc.ArrayImpl - Device = _xc.Device PrimitiveType = _xc.PrimitiveType Shape = _xc.Shape XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation - XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index cb291bdca79a..31cca3578916 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -506,19 +506,3 @@ from jax._src.numpy.array_methods import register_jax_array_methods register_jax_array_methods() del register_jax_array_methods - - -_deprecations = { - # Finalized 2024-12-13; remove after 2024-3-13 - "round_": ( - "jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.", - None - ), -} - -import typing -if not typing.TYPE_CHECKING: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b73a3b95b9a5..640e9de7eac3 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -808,8 +808,7 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, total_repeat_length: int | None = ...) -> Array: ... def reshape( - a: ArrayLike, shape: DimSize | Shape = ..., - newshape: DimSize | Shape | None = ..., order: str = ... + a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ..., ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... diff --git a/jax/sharding.py b/jax/sharding.py index 55ff0f6aea0b..bacf848f07ed 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -34,18 +34,3 @@ AxisType as AxisType, get_abstract_mesh as get_abstract_mesh, ) - -_deprecations = { - # Finalized 2024-10-01; remove after 2025-01-01. - "XLACompatibleSharding": ( - ( - "jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. " - "Use jax.sharding.Sharding instead." - ), - None, - ) -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 98f10d9c02b3..c0650441edd7 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3496,11 +3496,6 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def testReshapeDeprecatedArgs(self): - msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." 
- with self.assertRaisesRegex(TypeError, msg): - jnp.reshape(jnp.arange(4), newshape=(2, 2)) - @jtu.sample_product( [dict(arg_shape=arg_shape, out_shape=out_shape) for arg_shape, out_shape in [ From b1b281a427c7182d5345f5f1a88b83feb13a46f2 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 10:36:35 -0700 Subject: [PATCH 0187/1769] Prototype of adding error checking to jax.numpy functions PiperOrigin-RevId: 740822504 --- jax/_src/error_check.py | 149 ++++++---------------------------- jax/_src/numpy/error.py | 131 ++++++++++++++++++++++++++++++ jax/_src/numpy/ufuncs.py | 9 +- tests/BUILD | 5 ++ tests/error_check_test.py | 86 -------------------- tests/jax_numpy_error_test.py | 130 +++++++++++++++++++++++++++++ 6 files changed, 296 insertions(+), 214 deletions(-) create mode 100644 jax/_src/numpy/error.py create mode 100644 tests/jax_numpy_error_test.py diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 9d493c1f351b..e78b9bc82115 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,18 +14,15 @@ from __future__ import annotations -import contextlib import dataclasses from functools import partial import json import threading import traceback as tb_lib from types import TracebackType -from typing import Literal import warnings import jax -from jax._src import config from jax._src import core from jax._src import source_info_util from jax._src import traceback_util @@ -118,56 +115,39 @@ def __exit__(self, exc_type, exc_value, traceback): _error_storage.ref = self.old_ref -# TODO(ayx): Move all category-related logic into the jax.numpy integration once -# it is ready. This logic is specific to how jax.numpy decides when to call -# `set_error_if`, and doesn't belong in the core error-checking library itself. -# The responsibility for deciding whether to predicate an error should lie with -# the user or the higher-level library (like jax.numpy), not with -# `set_error_if`. -Category = Literal["nan", "divide", "oob"] - - -def _is_category_disabled( - category: Category | None, -) -> bool: - """Check if the error checking behavior for the given category is disabled.""" - if category is None: - return False - if category == "nan": - return config.error_checking_behavior_nan.value == "ignore" - if category == "divide": - return config.error_checking_behavior_divide.value == "ignore" - if category == "oob": - return config.error_checking_behavior_oob.value == "ignore" - raise ValueError(f"Invalid category: {category}") - - -def _set_error_if_with_category( - pred: jax.Array, - /, - msg: str, - category: Category | None = None, -) -> None: +def set_error_if(pred: jax.Array, /, msg: str) -> None: """Set the internal error state if any element of `pred` is `True`. - This function is similar to :func:`set_error_if`, but it also takes a category - argument. The category can be "nan", "divide", or "oob". The error checking - behavior for each category can be configured using - :func:`set_error_checking_behavior`. If not provided, there will be no - category. + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. 
+ + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. - This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) - to perform category-specific runtime checks tied to the operation being - performed. + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. """ if _error_storage.ref is None: with core.eval_context(): _initialize_error_code_ref() assert _error_storage.ref is not None - if _is_category_disabled(category): - return - # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None @@ -219,37 +199,6 @@ def _set_error_if_with_category( _error_storage.ref[...] = error_code -def set_error_if(pred: jax.Array, /, msg: str) -> None: - """Set the internal error state if any element of `pred` is `True`. - - This function is used inside JAX computations to detect runtime errors without - immediately halting execution. When this function is traced (e.g., inside - :func:`jax.jit`), the corresponding error message and its traceback are - recorded. At execution time, if `pred` contains any `True` values, the error - state is set, but execution continues without interruption. The recorded error - can later be raised using :func:`raise_if_error`. - - If the error state has already been set, subsequent errors are ignored and - will not override the existing error. - - For multi-device environments, in explicit mode, users must call - :func:`error_checking_context` to initialize a new error tracking state that - matches the device mesh. In auto mode, implicit cross-device communication may - occur inside this function, which could impact performance. A warning is - issued in such cases. - - When exporting a function with `jax.export`, error checking must be explicitly - wrapped using :func:`wrap_for_export` before export and - :func:`unwrap_from_import` after import. - - Args: - pred: A JAX boolean array. If any element of `pred` is `True`, the internal - error state will be set. - msg: The corresponding error message to be raised later. - """ - _set_error_if_with_category(pred, msg) - - def raise_if_error() -> None: """Raise an exception if the internal error state is set. @@ -406,53 +355,3 @@ def inner(*args, **kwargs): return out return inner - - -Behavior = Literal["ignore", "raise"] - - -class error_checking_behavior: - """A context manager to set the error checking behavior. - - If both `all` and a category are provided, the category will override the - `all` setting. - - When the error checking behavior is set to "ignore", all errors will be - ignored. When set to "raise", errors will be detected and recorded, but an - exception will not be raised immediately. Users must call - :func:`raise_if_error` to at the end of the computation to raise the - exception. 
- """ - - def __init__( - self, - *, - all: Behavior | None = None, - nan: Behavior | None = None, - divide: Behavior | None = None, - oob: Behavior | None = None, - ) -> None: - new_settings = {} - if all is not None: - new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all - if nan is not None: - new_settings["nan"] = nan - if divide is not None: - new_settings["divide"] = divide - if oob is not None: - new_settings["oob"] = oob - self.new_settings = new_settings - self.stack = contextlib.ExitStack() - - def __enter__(self): - config_flags = { - "nan": config.error_checking_behavior_nan, - "divide": config.error_checking_behavior_divide, - "oob": config.error_checking_behavior_oob, - } - for key, value in self.new_settings.items(): - self.stack.enter_context(config_flags[key](value)) - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.stack.close() diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py new file mode 100644 index 000000000000..52b996a0b050 --- /dev/null +++ b/jax/_src/numpy/error.py @@ -0,0 +1,131 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +from typing import Literal + +import jax +from jax._src import config + +Category = Literal["nan", "divide", "oob"] + + +def _is_category_disabled( + category: Category | None, +) -> bool: + """Check if the error checking behavior for the given category is disabled.""" + if category is None: + return False + if category == "nan": + raise ValueError("nan is deprecated. Use `_set_error_if_nan` instead.") + if category == "divide": + return config.error_checking_behavior_divide.value == "ignore" + if category == "oob": + return config.error_checking_behavior_oob.value == "ignore" + raise ValueError(f"Invalid category: {category}") + + +def _set_error_if_with_category( + pred: jax.Array, + /, + msg: str, + category: Category | None = None, +) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is similar to :func:`set_error_if`, but it also takes a category + argument. The category can be "nan", "divide", or "oob". The error checking + behavior for each category can be configured using + :func:`set_error_checking_behavior`. If not provided, there will be no + category. + + This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) + to perform category-specific runtime checks tied to the operation being + performed. + """ + if _is_category_disabled(category): + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + error_check_lib.set_error_if(pred, msg) + + +def _set_error_if_nan(pred: jax.Array, /): + """Set the internal error state if any element of `pred` is `NaN`. + + This function is disabled if the `jax_error_checking_behavior_nan` flag is + set to "ignore". + """ + if config.error_checking_behavior_nan.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. 
+ import jax.numpy as jnp + if not jnp.issubdtype(pred.dtype, jnp.floating): # only check floats + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered") + + +Behavior = Literal["ignore", "raise"] + + +class error_checking_behavior: + """A context manager to set the error checking behavior. + + If both `all` and a category are provided, the category will override the + `all` setting. + + When the error checking behavior is set to "ignore", all errors will be + ignored. When set to "raise", errors will be detected and recorded, but an + exception will not be raised immediately. Users must call + :func:`raise_if_error` to at the end of the computation to raise the + exception. + """ + + def __init__( + self, + *, + all: Behavior | None = None, + nan: Behavior | None = None, + divide: Behavior | None = None, + oob: Behavior | None = None, + ) -> None: + new_settings = {} + if all is not None: + new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all + if nan is not None: + new_settings["nan"] = nan + if divide is not None: + new_settings["divide"] = divide + if oob is not None: + new_settings["oob"] = oob + self.new_settings = new_settings + self.stack = contextlib.ExitStack() + + def __enter__(self): + config_flags = { + "nan": config.error_checking_behavior_nan, + "divide": config.error_checking_behavior_divide, + "oob": config.error_checking_behavior_oob, + } + for key, value in self.new_settings.items(): + self.stack.enter_context(config_flags[key](value)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stack.close() diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 91191d24a12e..1df973039213 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -32,12 +32,13 @@ from jax._src.lax import lax from jax._src.lax import other as lax_other from jax._src.typing import Array, ArrayLike +from jax._src.numpy import error as jnp_error +from jax._src.numpy import reductions +from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, check_no_float0s) -from jax._src.numpy.ufunc_api import ufunc -from jax._src.numpy import reductions from jax._src.util import set_module @@ -486,7 +487,9 @@ def log(x: ArrayLike, /) -> Array: >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) Array(True, dtype=bool) """ - return lax.log(*promote_args_inexact('log', x)) + out = lax.log(*promote_args_inexact('log', x)) + jnp_error._set_error_if_nan(out) + return out @export diff --git a/tests/BUILD b/tests/BUILD index 2e03f331744c..1baeb4f83af7 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1165,6 +1165,11 @@ jax_multiplatform_test( srcs = ["error_check_test.py"], ) +jax_multiplatform_test( + name = "jax_numpy_error_test", + srcs = ["jax_numpy_error_test.py"], +) + jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], diff --git a/tests/error_check_test.py b/tests/error_check_test.py index af3f35c7ab62..0c77989b8a43 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -372,92 +372,6 @@ def run_import(serialized): ): error_check.raise_if_error() - @parameterized.product(jit=[True, False]) - def test_error_category_nan_check(self, jit): - def f(x): - error_check._set_error_if_with_category( - jnp.isnan(x), "x is 
NaN", category="nan" - ) - return x - - if jit: - f = jax.jit(f) - - x = jnp.full((4,), jnp.nan, dtype=jnp.float32) - - with error_check.error_checking_behavior(nan="ignore"): - _ = f(x) - error_check.raise_if_error() # should not raise error - - with error_check.error_checking_behavior(nan="raise"): - _ = f(x) - with self.assertRaisesRegex(JaxValueError, "x is NaN"): - error_check.raise_if_error() - - @parameterized.product(jit=[True, False]) - def test_error_category_divide_check(self, jit): - def f(x, y): - error_check._set_error_if_with_category( - y == 0.0, "division by zero", category="divide" - ) - return x / y - - if jit: - f = jax.jit(f) - - x = jnp.arange(4, dtype=jnp.float32) + 1 - y = jnp.arange(4, dtype=jnp.float32) - - with error_check.error_checking_behavior(divide="ignore"): - _ = f(x, y) - error_check.raise_if_error() # should not raise error - - with error_check.error_checking_behavior(divide="raise"): - _ = f(x, y) - with self.assertRaisesRegex(JaxValueError, "division by zero"): - error_check.raise_if_error() - - @parameterized.product(jit=[True, False]) - def test_error_category_oob_check(self, jit): - def f(x, start_indices, slice_sizes): - error_check._set_error_if_with_category( - jnp.logical_or( - start_indices < 0, - start_indices + jnp.array(slice_sizes, dtype=jnp.int32) - >= jnp.array(x.shape, dtype=jnp.int32), - ), - "Out of bounds in dynamic_slice", - category="oob", - ) - y = jax.lax.dynamic_slice( - x, start_indices, slice_sizes, allow_negative_indices=False - ) - return y - - if jit: - f = jax.jit(f, static_argnums=(2,)) - - x = jnp.arange(12).reshape(3, 4) - start_indices = jnp.array([0, -1], dtype=jnp.int32) - slice_sizes = (3, 4) - - with error_check.error_checking_behavior(oob="ignore"): - _ = f(x, start_indices, slice_sizes) - error_check.raise_if_error() # should not raise error - - with error_check.error_checking_behavior(oob="raise"): - _ = f(x, start_indices, slice_sizes) - with self.assertRaisesRegex( - JaxValueError, "Out of bounds in dynamic_slice", - ): - error_check.raise_if_error() - - def test_error_category_invalid_category(self): - with self.assertRaisesRegex(ValueError, "Invalid category"): - error_check._set_error_if_with_category( - jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" - ) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py new file mode 100644 index 000000000000..c2883f2005e0 --- /dev/null +++ b/tests/jax_numpy_error_test.py @@ -0,0 +1,130 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import error_check +from jax._src import test_util as jtu +from jax._src.numpy import error as jnp_error +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +JaxValueError = error_check.JaxValueError + + +class JaxNumpyErrorTests(jtu.JaxTestCase): + @parameterized.product(jit=[True, False]) + def test_set_error_if_nan(self, jit): + def f(x): + jnp_error._set_error_if_nan(x) + return x + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), jnp.nan, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(nan="ignore"): + _ = f(x) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(nan="raise"): + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_divide_check(self, jit): + def f(x, y): + jnp_error._set_error_if_with_category( + y == 0.0, "division by zero", category="divide" + ) + return x / y + + if jit: + f = jax.jit(f) + + x = jnp.arange(4, dtype=jnp.float32) + 1 + y = jnp.arange(4, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(divide="ignore"): + _ = f(x, y) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(divide="raise"): + _ = f(x, y) + with self.assertRaisesRegex(JaxValueError, "division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_oob_check(self, jit): + def f(x, start_indices, slice_sizes): + jnp_error._set_error_if_with_category( + jnp.logical_or( + start_indices < 0, + start_indices + jnp.array(slice_sizes, dtype=jnp.int32) + >= jnp.array(x.shape, dtype=jnp.int32), + ), + "Out of bounds in dynamic_slice", + category="oob", + ) + y = jax.lax.dynamic_slice( + x, start_indices, slice_sizes, allow_negative_indices=False + ) + return y + + if jit: + f = jax.jit(f, static_argnums=(2,)) + + x = jnp.arange(12).reshape(3, 4) + start_indices = jnp.array([0, -1], dtype=jnp.int32) + slice_sizes = (3, 4) + + with jnp_error.error_checking_behavior(oob="ignore"): + _ = f(x, start_indices, slice_sizes) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + _ = f(x, start_indices, slice_sizes) + with self.assertRaisesRegex( + JaxValueError, "Out of bounds in dynamic_slice", + ): + error_check.raise_if_error() + + def test_error_category_invalid_category(self): + with self.assertRaisesRegex(ValueError, "Invalid category"): + jnp_error._set_error_if_with_category( + jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" + ) + + @parameterized.product(jit=[True, False]) + def test_can_raise_nan_error(self, jit): + x = jnp.arange(4, dtype=jnp.float32) - 1 + + f = jnp.log + if jit: + f = jax.jit(f) + + with jnp_error.error_checking_behavior(nan="raise"): + f(x) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 55318d582424ba78dbbf7c0a7c9a33a60a43ada8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 10:55:14 -0700 Subject: [PATCH 0188/1769] `build/build.py` changes: copy the wheels created by the new build wheel targets into the path specified by `--output_path`. 
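A condensed sketch of the new copy step (helper names as introduced in
build/tools/utils.py below; the concrete paths are illustrative):

    # After a successful Bazel run, collect the built wheels into the
    # directory given via --output_path; editable builds copy a tree.
    if args.editable:
      src_dir = os.path.join(bazel_dir, wheel_dir)
      dst_dir = os.path.join(output_path, wheel_dir)
      utils.copy_dir_recursively(src_dir, dst_dir)
    else:
      utils.copy_individual_files(bazel_dir, output_path, f"{wheel_dir}*.whl")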
PiperOrigin-RevId: 740829299 --- build/build.py | 31 +++++++++++++++++++++++++++++++ build/tools/utils.py | 26 ++++++++++++++++++++++++++ ci/build_artifacts.sh | 11 ++--------- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/build/build.py b/build/build.py index cdb568171b66..7f70f7b2ffef 100755 --- a/build/build.py +++ b/build/build.py @@ -76,6 +76,8 @@ "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", } +_JAX_CUDA_VERSION = "12" + def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" parser.add_argument( @@ -695,6 +697,35 @@ async def main(): if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") + if args.use_new_wheel_build_rule: + output_path = args.output_path + jax_bazel_dir = os.path.join("bazel-bin", "dist") + jaxlib_and_plugins_bazel_dir = os.path.join( + "bazel-bin", "jaxlib", "tools", "dist" + ) + for wheel in args.wheels.split(","): + if wheel == "jax": + bazel_dir = jax_bazel_dir + else: + bazel_dir = jaxlib_and_plugins_bazel_dir + if "cuda" in wheel: + wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace( + "-", "_" + ) + else: + wheel_dir = wheel + + if args.editable: + src_dir = os.path.join(bazel_dir, wheel_dir) + dst_dir = os.path.join(output_path, wheel_dir) + utils.copy_dir_recursively(src_dir, dst_dir) + else: + utils.copy_individual_files(bazel_dir, output_path, f"{wheel_dir}*.whl") + if wheel == "jax": + utils.copy_individual_files( + bazel_dir, output_path, f"{wheel_dir}*.tar.gz" + ) + # Exit with success if all wheels in the list were built successfully. sys.exit(0) diff --git a/build/tools/utils.py b/build/tools/utils.py index 7e375169827b..8b8dc80d1e0f 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -14,6 +14,7 @@ # ============================================================================== # Helper script for tools/utilities used by the JAX build CLI. 
import collections +import glob import hashlib import logging import os @@ -256,3 +257,28 @@ def _parse_string_as_bool(s): return False else: raise ValueError(f"Expected either 'true' or 'false'; got {s}") + + +def copy_dir_recursively(src, dst): + if os.path.exists(dst): + shutil.rmtree(dst) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + relative_path = os.path.relpath(root, src) + dst_dir = os.path.join(dst, relative_path) + os.makedirs(dst_dir, exist_ok=True) + for f in files: + src_file = os.path.join(root, f) + dst_file = os.path.join(dst_dir, f) + shutil.copy2(src_file, dst_file) + logging.info("Editable wheel path: %s" % dst) + + +def copy_individual_files(src, dst, regex): + os.makedirs(dst, exist_ok=True) + for f in glob.glob(os.path.join(src, regex)): + dst_file = os.path.join(dst, os.path.basename(f)) + if os.path.exists(dst_file): + os.remove(dst_file) + shutil.copy2(f, dst_file) + logging.info("Distribution path: %s" % dst_file) diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 84b8d35a2a50..d7ffe82eb699 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -96,6 +96,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags # If building release artifacts, we also build a release candidate ("rc") @@ -105,18 +106,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" fi - # Move the built artifacts from the Bazel cache directory to the output - # directory. - if [[ "$artifact" == "jax" ]]; then - mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" - mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" - else - mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" - fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then From 2518e187f3fa63f0bc2e116b0592b18e6584f2c0 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Wed, 26 Mar 2025 11:10:42 -0700 Subject: [PATCH 0189/1769] [Mosaic GPU] Support more layouts in the `swap` lowering. 
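
Previously the no-transform case of the `swap` lowering always round-tripped
through `load_strided`/`store_untiled`, so storing a value that carries a WGMMA
row or column fragmented layout had no lowering. The rule now dispatches on
`value.layout`. A minimal sketch of a kernel this enables, mirroring the new
test below (`x_ref`/`o_ref` are SMEM refs of shape `(2, m)`):

```python
from jax.experimental.pallas import mosaic_gpu as plgpu  # as in the tests

def kernel(x_ref, o_ref):
  for i in range(2):
    # The loaded value carries a WGMMA_ROW fragmented layout; the store into
    # o_ref is a `swap` at lowering time, which previously only handled the
    # strided layout. WGMMA_COL works the same way.
    x = plgpu.load(x_ref, (i,), layout=plgpu.Layout.WGMMA_ROW)
    o_ref[i, ...] = x
```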
PiperOrigin-RevId: 740835345 --- jax/_src/pallas/mosaic_gpu/lowering.py | 24 +++++++++++++---- tests/pallas/mosaic_gpu_test.py | 36 +++++++++++++++++++++----- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a41a657ba738..286fedfa44d5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1156,11 +1156,25 @@ def _swap_lowering_rule( value.store_tiled(x_smem, swizzle=swizzle) return old_value case (): - old_value = mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value + match value.layout: + case mgpu.WGMMARowFragLayout(): + old_value = mgpu.FragmentedArray.load_wgmma_row( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value + case mgpu.WGMMAColFragLayout(): + old_value = mgpu.FragmentedArray.load_wgmma_col( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value + case _: + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c10f06f8bb5d..aea49b645ec6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -657,14 +657,10 @@ def kernel(x_ref, o_ref, barrier_ref): @parameterized.product( src_memory_space=[plgpu.SMEM, plgpu.GMEM], - layout=[ - plgpu.Layout.WGMMA_ROW, - plgpu.Layout.WGMMA_COL, - plgpu.Layout.WG_STRIDED((128,), vec_size=1), - None, - ], + layout=[plgpu.Layout.WG_STRIDED((128,), vec_size=1), None, + ] ) - def test_load_to_layout_with_indexing(self, src_memory_space, layout): + def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): self.skip_if_wg_semantics() @functools.partial( @@ -685,6 +681,32 @@ def kernel(x_ref, o_ref): x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) np.testing.assert_array_equal(kernel(x), x) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + m=[64, 128, 192], + ) + def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layout, m): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.GPUBlockSpec( + (2, m), + lambda: (0, 0), + memory_space=plgpu.SMEM, + ), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,), layout=layout) + o_ref[i, ...] 
= x + + x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) + np.testing.assert_array_equal(kernel(x), x) + @parameterized.product( src_memory_space=[plgpu.SMEM], layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], From feed69c56192ae5883082ecf4155bb2f69d1658b Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 11:18:08 -0700 Subject: [PATCH 0190/1769] Add nan checking to jax.numpy functions PiperOrigin-RevId: 740838221 --- jax/_src/numpy/ufuncs.py | 71 ++++++++++++++++++++++++++--------- tests/jax_numpy_error_test.py | 64 ++++++++++++++++++++++++++++--- 2 files changed, 113 insertions(+), 22 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 1df973039213..0ea2992c9955 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -575,7 +575,9 @@ def log1p(x: ArrayLike, /) -> Array: >>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32) """ - return lax.log1p(*promote_args_inexact('log1p', x)) + out = lax.log1p(*promote_args_inexact('log1p', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -607,7 +609,9 @@ def sin(x: ArrayLike, /) -> Array: ... print(jnp.sin(x)) [ 0.707 1. 0.707 -0. ] """ - return lax.sin(*promote_args_inexact('sin', x)) + out = lax.sin(*promote_args_inexact('sin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -638,7 +642,9 @@ def cos(x: ArrayLike, /) -> Array: ... print(jnp.cos(x)) [ 0.707 -0. -0.707 -0.866] """ - return lax.cos(*promote_args_inexact('cos', x)) + out = lax.cos(*promote_args_inexact('cos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -669,7 +675,9 @@ def tan(x: ArrayLike, /) -> Array: ... print(jnp.tan(x)) [ 0. 0.577 1. -1. -0.577] """ - return lax.tan(*promote_args_inexact('tan', x)) + out = lax.tan(*promote_args_inexact('tan', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -711,7 +719,9 @@ def arcsin(x: ArrayLike, /) -> Array: ... jnp.arcsin(3+4j) Array(0.634+2.306j, dtype=complex64, weak_type=True) """ - return lax.asin(*promote_args_inexact('arcsin', x)) + out = lax.asin(*promote_args_inexact('arcsin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -754,7 +764,9 @@ def arccos(x: ArrayLike, /) -> Array: ... jnp.arccos(4-1j) Array(0.252+2.097j, dtype=complex64, weak_type=True) """ - return lax.acos(*promote_args_inexact('arccos', x)) + out = lax.acos(*promote_args_inexact('arccos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1008,6 +1020,7 @@ def arccosh(x: ArrayLike, /) -> Array: # Note: arccosh is multi-valued for complex input, and lax.acosh # uses a different convention than np.arccosh. result = lax.acosh(*promote_args_inexact("arccosh", x)) + jnp_error._set_error_if_nan(result) if dtypes.issubdtype(result.dtype, np.complexfloating): result = _where(real(result) < 0, lax.neg(result), result) return result @@ -1113,7 +1126,9 @@ def arctanh(x: ArrayLike, /) -> Array: ... 
jnp.arctanh(x1) Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64) """ - return lax.atanh(*promote_args_inexact('arctanh', x)) + out = lax.atanh(*promote_args_inexact('arctanh', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1146,7 +1161,9 @@ def sqrt(x: ArrayLike, /) -> Array: >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True) """ - return lax.sqrt(*promote_args_inexact('sqrt', x)) + out = lax.sqrt(*promote_args_inexact('sqrt', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1215,7 +1232,11 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: Array([10, 11, 12, 13], dtype=int32) """ x, y = promote_args("add", x, y) - return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + if x.dtype == bool: + return lax.bitwise_or(x, y) + out = lax.add(x, y) + jnp_error._set_error_if_nan(out) + return out def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: @@ -1544,7 +1565,9 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: >>> x - 10 Array([-10, -9, -8, -7], dtype=int32) """ - return lax.sub(*promote_args("subtract", x, y)) + out = lax.sub(*promote_args("subtract", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1768,7 +1791,9 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: >>> jnp.float_power(-3, 1.7) Array(nan, dtype=float32, weak_type=True) """ - return lax.pow(*promote_args_inexact("float_power", x, y)) + out = lax.pow(*promote_args_inexact("float_power", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -2446,7 +2471,9 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.numpy.floor_divide` for integer division """ x1, x2 = promote_args_inexact("true_divide", x1, x2) - return lax.div(x1, x2) + out = lax.div(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export @@ -2648,7 +2675,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.integer_pow(x1, x2) # Handle cases #2 and #3 under a jit: - return _power(x1, x2) + out = _power(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: @@ -2774,7 +2803,9 @@ def log2(x: ArrayLike, /) -> Array: im = lax.imag(r) ln2 = lax.log(_constant_like(re, 2)) return lax.complex(lax.div(re, ln2), lax.div(im, ln2)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + jnp_error._set_error_if_nan(out) + return out @export @@ -2804,7 +2835,9 @@ def log10(x: ArrayLike, /) -> Array: im = lax.imag(r) ln10 = lax.log(_constant_like(re, 10)) return lax.complex(lax.div(re, ln10), lax.div(im, ln10)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + jnp_error._set_error_if_nan(out) + return out @export @@ -3064,7 +3097,9 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: trunc_mod_not_zero = lax.ne(trunc_mod, zero) do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) - return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + out = lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + jnp_error._set_error_if_nan(out) + return out @export @@ -3112,7 +3147,9 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) - return lax.rem(*promote_args_numeric("fmod", x1, x2)) + out = lax.rem(*promote_args_numeric("fmod", x1, 
x2)) + jnp_error._set_error_if_nan(out) + return out @export diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index c2883f2005e0..08917aeed364 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import operator + from absl.testing import absltest from absl.testing import parameterized import jax @@ -112,16 +114,68 @@ def test_error_category_invalid_category(self): jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" ) - @parameterized.product(jit=[True, False]) - def test_can_raise_nan_error(self, jit): - x = jnp.arange(4, dtype=jnp.float32) - 1 + @staticmethod + def op_cases(cases): + for jit in (True, False): + for func, operands in cases: + if not isinstance(operands, tuple): + operands = (operands,) + + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + name = f"_{jit_str}_{func_str}" + + yield name, jit, func, operands + + @parameterized.named_parameters( + op_cases(( + # list of all NaN-producing jax.numpy functions + # go/keep-sorted start + (jnp.acos, 2.0), + (jnp.acosh, 0.5), + (jnp.add, (jnp.inf, -jnp.inf)), + (jnp.arccos, 2.0), + (jnp.arccosh, 0.5), + (jnp.arcsin, -2.0), + (jnp.arctanh, -2.0), + (jnp.asin, -2.0), + (jnp.atanh, -2.0), + (jnp.cos, jnp.inf), + (jnp.divide, (0.0, 0.0)), + (jnp.divmod, (1.0, 0.0)), + (jnp.float_power, (-1.0, 0.5)), + (jnp.fmod, (1.0, 0.0)), + (jnp.log, -1.0), + (jnp.log10, -1.0), + (jnp.log1p, -1.5), + (jnp.log2, -1.0), + (jnp.mod, (1.0, 0.0)), + (jnp.pow, (-1.0, 0.5)), + (jnp.power, (-1.0, 0.5)), + (jnp.remainder, (1.0, 0.0)), + (jnp.sin, jnp.inf), + # TODO(https://github.com/jax-ml/jax/issues/27470): Not yet supported. + # (jnp.sinc, jnp.inf), + (jnp.sqrt, -4.0), + (jnp.subtract, (jnp.inf, jnp.inf)), + (jnp.tan, jnp.inf), + (jnp.true_divide, (0.0, 0.0)), + (operator.add, (jnp.inf, -jnp.inf)), + (operator.mod, (1.0, 0.0)), + (operator.pow, (-1.0, 0.5)), + (operator.sub, (jnp.inf, jnp.inf)), + (operator.truediv, (0.0, 0.0)), + # go/keep-sorted end + )) + ) + def test_can_raise_nan_error(self, jit, f, operands): + operands = [jnp.float32(x) for x in operands] - f = jnp.log if jit: f = jax.jit(f) with jnp_error.error_checking_behavior(nan="raise"): - f(x) + f(*operands) with self.assertRaisesRegex(JaxValueError, "NaN"): error_check.raise_if_error() From 1b7c8e8d08d1308c438d47334387c6339f0456f8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 11:25:04 -0700 Subject: [PATCH 0191/1769] Add editable `jax` wheel target. The set of editable wheels (`jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt`) was used as dependencies in `requirements.in` file together with `:build_jaxlib=false` flag. After [adding `jax` wheel dependencies](https://github.com/jax-ml/jax/commit/f5a4d1a85c41a42ed8fb389259a241513970ff9a) to the tests when `:build_jaxlib=false` is used, we need an editable `jax` wheel target as well to get the tests passing. 
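
With the new build rule, passing `--editable` makes the CLI resolve the
`_editable` variant of each wheel's Bazel target (and skip the `jax` source
package, which only accompanies regular wheel builds). The selection added in
`build/build.py` boils down to the following; the invocation in the comment is
illustrative, not a verified command line:

```python
# e.g. python build/build.py build --wheels=jax,jaxlib --editable \
#          --use_new_wheel_build_rule --output_path=dist
if args.use_new_wheel_build_rule and args.editable:
  build_target = wheel_build_targets[wheel + "_editable"]  # //:jax_wheel_editable, ...
else:
  build_target = wheel_build_targets[wheel]                # //:jax_wheel, ...
```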
PiperOrigin-RevId: 740840736 --- BUILD.bazel | 42 ++++++++++++++++++++++-------------------- build/build.py | 11 +++++++++-- build_wheel.py | 31 ++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index e7cf6de66cad..2c10f0d9a748 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -72,35 +72,37 @@ py_binary( ], ) +WHEEL_SOURCE_FILES = [ + ":transitive_py_data", + ":transitive_py_deps", + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", +] + jax_wheel( name = "jax_wheel", platform_independent = True, - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = WHEEL_SOURCE_FILES, + wheel_binary = ":build_wheel", + wheel_name = "jax", +) + +jax_wheel( + name = "jax_wheel_editable", + editable = True, + platform_independent = True, + source_files = WHEEL_SOURCE_FILES, wheel_binary = ":build_wheel", wheel_name = "jax", ) jax_source_package( name = "jax_source_package", - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = WHEEL_SOURCE_FILES, source_package_binary = ":build_wheel", source_package_name = "jax", ) diff --git a/build/build.py b/build/build.py index 7f70f7b2ffef..4d16851f837c 100755 --- a/build/build.py +++ b/build/build.py @@ -68,10 +68,14 @@ # rule as the default. WHEEL_BUILD_TARGET_DICT_NEW = { "jax": "//:jax_wheel", + "jax_editable": "//:jax_wheel_editable", "jax_source_package": "//:jax_source_package", "jaxlib": "//jaxlib/tools:jaxlib_wheel", + "jaxlib_editable": "//jaxlib/tools:jaxlib_wheel_editable", "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", + "jax-cuda-plugin_editable": "//jaxlib/tools:jax_cuda_plugin_wheel_editable", "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jax-cuda-pjrt_editable": "//jaxlib/tools:jax_cuda_pjrt_wheel_editable", "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", } @@ -662,9 +666,12 @@ async def main(): ) # Append the build target to the Bazel command. - build_target = wheel_build_targets[wheel] + if args.use_new_wheel_build_rule and args.editable: + build_target = wheel_build_targets[wheel + "_editable"] + else: + build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) - if args.use_new_wheel_build_rule and wheel == "jax": + if args.use_new_wheel_build_rule and wheel == "jax" and not args.editable: wheel_build_command.append(wheel_build_targets["jax_source_package"]) if not args.use_new_wheel_build_rule: diff --git a/build_wheel.py b/build_wheel.py index b4db96773527..793523e8e3b2 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -61,6 +61,11 @@ "Whether to build the source package only. Optional." 
), ) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax build instead of a wheel.", +) args = parser.parse_args() @@ -90,7 +95,11 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: """ for file in deps: - if not (file.startswith("bazel-out") or file.startswith("external")): + if not ( + file.startswith("bazel-out") + or file.startswith("external") + or file.startswith("jaxlib") + ): copy_file(file, srcs_dir) @@ -103,14 +112,18 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: try: os.makedirs(args.output_path, exist_ok=True) prepare_srcs(args.srcs, pathlib.Path(sources_path)) - build_utils.build_wheel( - sources_path, - args.output_path, - package_name="jax", - git_hash=args.jaxlib_git_hash, - build_wheel_only=args.build_wheel_only, - build_source_package_only=args.build_source_package_only, - ) + package_name = "jax" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, + ) finally: if tmpdir: tmpdir.cleanup() From e364abe961ed251915b1a1c7374a0ebd9974c201 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 12 Mar 2025 05:15:58 +0000 Subject: [PATCH 0192/1769] Prune passthrough outputs in lax.switch. --- jax/_src/lax/control_flow/conditionals.py | 18 +++++++++++++-- tests/lax_control_flow_test.py | 28 +++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 63896cc2a0bf..1e9372254ca1 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -151,12 +151,27 @@ def switch(index, branches, *operands): out_trees[0], jaxprs[0].out_avals, f"branch {i + 1} output", out_tree, jaxpr.out_avals) + # prune passthrough outputs + fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs] + in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)] + keep = [f is None for f in in_fwd] + jaxprs = [pe.prune_closed_jaxpr_outputs(jaxpr, keep) for jaxpr in jaxprs] + joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') + jaxprs = [replace_jaxpr_effects(jaxpr, joined_effects) for jaxpr in jaxprs] out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) + out_ = iter(out) + + all_inputs = [*consts, *ops] + out = [ + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) + for fwd in in_fwd + ] + assert next(out_, None) is None return tree_unflatten(out_trees[0], out) @@ -259,7 +274,7 @@ def cond(pred, true_fun, false_fun, *operands): out_tree, true_jaxpr.out_avals, "false_fun output", false_out_tree, false_jaxpr.out_avals) - # prune passhtrough outputs + # prune passthrough outputs true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] @@ -278,7 +293,6 @@ def cond(pred, true_fun, false_fun, *operands): true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr)) - num_consts = len(consts) 
out_ = iter(out) all_inputs = [*consts, *ops] diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3871a87a7a3e..9ac4e8c6da80 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -1309,6 +1309,34 @@ def f(x): self.assertAllClose(ans, expected, check_dtypes=False) jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"]) + @parameterized.parameters(itertools.product(range(4), repeat=3)) + @jtu.run_on_devices("cpu") + def testSwitchGradWithForwarding(self, seed, num_input_fwd, num_output_fwd): + num_args = 3 + num_branches = 4 + rng = np.random.RandomState(seed) + in_perm = rng.permutation(num_args) + out_perm = rng.permutation(num_args) + + def branch(s, inputs): + inputs = [inputs[i] for i in in_perm] + outputs = inputs[:num_input_fwd] + [ + s * jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) + for i in range(num_args - num_input_fwd)] + return [outputs[i] for i in out_perm] + + branches = [partial(branch, i) for i in range(num_branches)] + + @jax.jit + def f_(idx, inputs): + idx = lax.convert_element_type(idx // 1, np.int32) + return lax.switch(idx, branches, inputs) + + for idx in range(num_branches): + f = partial(f_, idx) + jtu.check_grads(f, (jnp.arange(float(num_args)),), + order=1, modes=['fwd', 'rev'], atol=1e-2, rtol=1e-2) + def testSwitchGradWithWeakTypeMismatch(self): # issue #4696, PR #4896 dtype = dtypes.canonicalize_dtype(np.float64) dtype = jnp.float32 if dtype == jnp.float32 else jnp.float64 From ec2f0f5913a3376bb940e17cc0151090f5d07d2d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 26 Mar 2025 11:56:05 -0700 Subject: [PATCH 0193/1769] [sharding_in_types] Enable auto_axes to work without any mesh context manager. We extract the mesh from `out_shardings` given. This allows APIs like `random.uniform` to accept NamedSharding in `out_sharding` argument and continue to work without a mesh context. PiperOrigin-RevId: 740852542 --- jax/_src/pjit.py | 46 ++++++++++++++++++++++++++++++++++------------ tests/pjit_test.py | 42 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index bcdbe6b1bdb7..054b55e32918 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2821,29 +2821,50 @@ def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding): # -------------------- auto and user mode ------------------------- def _get_new_mesh(axes: str | tuple[str, ...] | None, - axis_type: mesh_lib.AxisType, name: str, - error_on_manual_to_auto_explict=False): + axis_type: mesh_lib.AxisType, name: str, shardings=None, + error_on_manual_to_auto_explicit=False): cur_mesh = mesh_lib.get_abstract_mesh() - # TODO(yashkatariya): Maybe allow fetching mesh from the args to enable - # computation follows data? - if cur_mesh.empty: + flat_shardings, _ = tree_flatten(shardings) + sharding_mesh = mesh_lib.empty_abstract_mesh + for i in flat_shardings: + if isinstance(i, NamedSharding): + if not sharding_mesh.empty and sharding_mesh != i.mesh: + raise ValueError( + f'Shardings passed to {name} should have the same mesh. Got one' + f' mesh {sharding_mesh} and another {i.mesh}') + sharding_mesh = i.mesh.abstract_mesh + + if sharding_mesh.empty and cur_mesh.empty: raise ValueError( f'Context mesh {cur_mesh} cannot be empty. 
Please use' ' `jax.sharding.use_mesh` API to enter into a mesh context when using' f' `{name}` API.') + if not sharding_mesh.empty and not cur_mesh.empty: + if sharding_mesh != cur_mesh: + raise ValueError( + f'Context mesh {cur_mesh} must match the mesh passed to shardings' + f' {sharding_mesh}. Recommended approach is to use' + ' `jax.sharding.use_mesh` context manager.') + mesh_to_use = cur_mesh + elif sharding_mesh.empty and not cur_mesh.empty: + mesh_to_use = cur_mesh + else: + assert not sharding_mesh.empty and cur_mesh.empty + mesh_to_use = sharding_mesh + if axes is None: - axes = cur_mesh.axis_names + axes = mesh_to_use.axis_names if not isinstance(axes, tuple): axes = (axes,) for a in axes: - if (error_on_manual_to_auto_explict and - cur_mesh._name_to_type[a] == mesh_lib.AxisType.Manual and + if (error_on_manual_to_auto_explicit and + mesh_to_use._name_to_type[a] == mesh_lib.AxisType.Manual and axis_type in {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}): raise NotImplementedError( 'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not' ' allowed. Please file a bug at https://github.com/jax-ml/jax/issues' ' with your use case') - return cur_mesh.update_axis_types({a: axis_type for a in axes}) + return mesh_to_use.update_axis_types({a: axis_type for a in axes}) def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, out_shardings=None): @@ -2855,8 +2876,9 @@ def decorator(*args, **kwargs): raise TypeError("Missing required keyword argument: 'out_shardings'") else: _out_shardings = out_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'auto_axes', - error_on_manual_to_auto_explict=True) + new_mesh = _get_new_mesh( + axes, mesh_lib.AxisType.Auto, 'auto_axes', shardings=_out_shardings, + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( core.get_aval(a).sharding.spec, new_mesh), args) @@ -2883,7 +2905,7 @@ def decorator(*args, **kwargs): else: _in_shardings = in_shardings new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes', - error_on_manual_to_auto_explict=True) + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): args = mesh_cast(args, _in_shardings) out = fun(*args, **kwargs) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d6673c6b6d5a..5cd16e1e6925 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7117,7 +7117,7 @@ def g(x): out = jax.grad(g)(arr) self.assertEqual(out.sharding, arr.sharding) - def test_auto_axes_computation_follows_data_error(self): + def test_auto_axes_computation_follows_data(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np.arange(8), s) @@ -7126,8 +7126,9 @@ def test_auto_axes_computation_follows_data_error(self): def f(x): return x * 2 - with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"): - auto_axes(f, out_shardings=s)(arr) + out = auto_axes(f, out_shardings=s)(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, arr * 2) def test_divisbility_aval_error(self): abstract_mesh = mesh_lib.AbstractMesh( @@ -7264,6 +7265,41 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + def test_random_normal_wo_mesh_context(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(arr, key): + out = 
jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return arr + out + + key = jax.random.key(1) + out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_auto_axes_no_context_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @partial(auto_axes, axes='x', + out_shardings=NamedSharding(mesh, P('x', 'y'))) + def h(y): + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) + z = jnp.sin(y) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) + return z + + out = jax.jit(h)(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + out = h(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From aa160937cf3a7aa4dd953c18c1bc1ef83ddc0546 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 26 Mar 2025 12:05:37 -0700 Subject: [PATCH 0194/1769] [JAX] [XLA:Python] Migrate more modules to JAX. PiperOrigin-RevId: 740855958 --- jaxlib/xla/BUILD | 102 ++++++-- jaxlib/xla/callback.cc | 2 +- jaxlib/xla/config.cc | 2 +- jaxlib/xla/dlpack.cc | 6 +- jaxlib/xla/dlpack.h | 2 +- jaxlib/xla/ifrt_proxy.cc | 2 +- jaxlib/xla/jax_jit.h | 2 +- jaxlib/xla/nb_class_ptr.h | 59 +++++ jaxlib/xla/pjit.cc | 6 +- jaxlib/xla/pmap_lib.cc | 6 +- jaxlib/xla/py_array.cc | 6 +- jaxlib/xla/py_array.h | 4 +- jaxlib/xla/py_client.cc | 6 +- jaxlib/xla/py_client.h | 2 +- jaxlib/xla/py_compile_only_client.cc | 2 +- jaxlib/xla/py_compile_only_client.h | 2 +- jaxlib/xla/py_device.cc | 4 +- jaxlib/xla/py_device.h | 2 +- jaxlib/xla/py_device_list.cc | 4 +- jaxlib/xla/py_device_list.h | 2 +- jaxlib/xla/py_executable.cc | 4 +- jaxlib/xla/py_executable.h | 4 +- jaxlib/xla/py_host_callback.cc | 2 +- jaxlib/xla/py_memory_space.cc | 2 +- jaxlib/xla/py_memory_space.h | 2 +- jaxlib/xla/py_program.cc | 4 +- jaxlib/xla/py_socket_transfer.cc | 4 +- jaxlib/xla/py_values.cc | 2 +- jaxlib/xla/python_ref_manager.cc | 104 ++++++++ jaxlib/xla/python_ref_manager.h | 108 ++++++++ jaxlib/xla/pytree.cc | 2 +- jaxlib/xla/pytree.h | 2 +- jaxlib/xla/sharding.cc | 2 +- jaxlib/xla/sharding.h | 2 +- jaxlib/xla/to_ifrt_sharding.cc | 2 +- jaxlib/xla/traceback.cc | 357 +++++++++++++++++++++++++++ jaxlib/xla/traceback.h | 108 ++++++++ jaxlib/xla/xla.cc | 6 +- 38 files changed, 866 insertions(+), 74 deletions(-) create mode 100644 jaxlib/xla/nb_class_ptr.h create mode 100644 jaxlib/xla/python_ref_manager.cc create mode 100644 jaxlib/xla/python_ref_manager.h create mode 100644 jaxlib/xla/traceback.cc create mode 100644 jaxlib/xla/traceback.h diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index e10977d526ed..512eeb867618 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -53,11 +53,14 @@ nanobind_extension( ":ifrt_proxy", ":jax_jit", ":mlir", + ":nb_class_ptr", ":pjit", ":pmap_lib", ":py_client", + ":python_ref_manager", ":pytree", ":sdy", + ":traceback", ":util", ":weakref_lru_cache", ":xla_compiler", @@ -104,13 +107,10 @@ nanobind_extension( "@xla//xla/python:logging", "@xla//xla/python:nb_absl_flat_hash_map", "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:ops", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:profiler", - "@xla//xla/python:python_ref_manager", 
"@xla//xla/python:refine_polymorphic_shapes", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", @@ -162,6 +162,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":python_ref_manager", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", @@ -176,7 +177,6 @@ cc_library( "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:nb_numpy", - "@xla//xla/python:python_ref_manager", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/platform:statusor", ], @@ -193,13 +193,13 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":python_ref_manager", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla/python:python_ref_manager", "@xla//xla/tsl/platform:logging", ], ) @@ -246,7 +246,10 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":py_client", + ":python_ref_manager", + ":traceback", ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", @@ -265,9 +268,6 @@ cc_library( "@xla//xla/pjrt:pjrt_common", "@xla//xla/pjrt:pjrt_compiler", "@xla//xla/pjrt:pjrt_layout", - "@xla//xla/python:nb_class_ptr", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -308,6 +308,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":py_client", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -320,7 +321,6 @@ cc_library( "@tsl//tsl/platform:env", "@tsl//tsl/platform:statusor", "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:attribute_map", "@xla//xla/python/ifrt_proxy/client:grpc_client", @@ -340,6 +340,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":py_client", + ":python_ref_manager", ":pytree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -359,7 +360,6 @@ cc_library( "@xla//xla/python:nb_absl_inlined_vector", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", - "@xla//xla/python:python_ref_manager", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", ], @@ -404,6 +404,15 @@ cc_library( ], ) +cc_library( + name = "nb_class_ptr", + hdrs = ["nb_class_ptr.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/nb_class_ptr"), + deps = ["@nanobind"], +) + cc_library( name = "pjit", srcs = ["pjit.cc"], @@ -418,8 +427,11 @@ cc_library( ":config", ":guard_lib", ":jax_jit", + ":nb_class_ptr", ":py_client", + ":python_ref_manager", ":pytree", + ":traceback", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -438,11 +450,8 @@ cc_library( "@xla//xla:util", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:lru_cache", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python/ifrt", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/platform:env", @@ -465,8 +474,11 @@ cc_library( deps = [ ":config", 
":jax_jit", + ":nb_class_ptr", ":py_client", + ":python_ref_manager", ":pytree", + ":traceback", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", @@ -485,11 +497,8 @@ cc_library( "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/tsl/concurrency:ref_count", @@ -539,9 +548,12 @@ cc_library( deps = [ ":callback", ":guard_lib", + ":nb_class_ptr", ":py_client_cpu", ":py_host_callback", ":py_host_callback_cc_proto", + ":python_ref_manager", + ":traceback", ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -603,12 +615,9 @@ cc_library( "@xla//xla/pjrt/distributed:client", "@xla//xla/python:aggregate_profile", "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", "@xla//xla/python:pprof_profile_builder", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python:xplane_to_profile_instructions", "@xla//xla/python/compile_only_ifrt:client", @@ -688,6 +697,7 @@ cc_library( deps = [ ":callback", ":py_host_callback_cc_proto", + ":python_ref_manager", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/log:check", @@ -703,7 +713,6 @@ cc_library( "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/python:python_ref_manager", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -734,7 +743,9 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":py_client", + ":traceback", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -744,9 +755,7 @@ cc_library( "@xla//xla:util", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_numpy", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -761,6 +770,26 @@ cc_library( ], ) +cc_library( + name = "python_ref_manager", + srcs = ["python_ref_manager.cc"], + hdrs = ["python_ref_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/python_ref_manager"), + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + ], +) + proto_library( name = "pytree_proto", srcs = ["pytree.proto"], @@ -783,6 +812,7 @@ cc_library( features = ["-use_header_modules"], visibility = jax_visibility("jaxlib/xla/pytree"), deps = [ + ":nb_class_ptr", ":pytree_cc_proto", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -794,7 +824,6 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla/pjrt:exceptions", - "@xla//xla/python:nb_class_ptr", "@xla//xla/tsl/platform:logging", ], ) @@ -833,6 +862,33 @@ 
cc_library( ], ) +cc_library( + name = "traceback", + srcs = ["traceback.cc"], + hdrs = ["traceback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/traceback"), + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + cc_library( name = "util", srcs = ["util.cc"], diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index 2df1715d099f..bb238e6991ec 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -34,11 +34,11 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/python_ref_manager.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" -#include "xla/python/python_ref_manager.h" #include "xla/tsl/platform/statusor.h" namespace nb = nanobind; diff --git a/jaxlib/xla/config.cc b/jaxlib/xla/config.cc index b5bc5830acbf..82f0bd0b0f5a 100644 --- a/jaxlib/xla/config.cc +++ b/jaxlib/xla/config.cc @@ -26,7 +26,7 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "xla/python/python_ref_manager.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index 94d57e07c34a..6c4c24bfe10e 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -34,8 +34,11 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/traceback.h" #include "jaxlib/xla/util.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" @@ -45,12 +48,9 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index e73c477b1495..46b0954105f7 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -22,9 +22,9 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" namespace xla { diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/xla/ifrt_proxy.cc index e03fde194d49..eda57be86ba5 100644 --- a/jaxlib/xla/ifrt_proxy.cc +++ b/jaxlib/xla/ifrt_proxy.cc @@ -36,12 +36,12 @@ #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/unordered_map.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt_proxy/client/registry.h" -#include "xla/python/nb_class_ptr.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index e2c186c5d3ff..a2e6d725f3b0 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -36,11 +36,11 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/nb_helpers.h" -#include "xla/python/python_ref_manager.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/nb_class_ptr.h b/jaxlib/xla/nb_class_ptr.h new file mode 100644 index 000000000000..e468860dc661 --- /dev/null +++ b/jaxlib/xla/nb_class_ptr.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_NB_CLASS_PTR_H_ +#define JAXLIB_XLA_NB_CLASS_PTR_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +// A reference-counting smart pointer to a nanobind-wrapped class on the Python +// heap. Type T must be a class known to nanobind via a nanobind::class_ +// declaration. nb_class_ptr is useful for managing C++ classes that may be +// allocated inline in Python objects on the Python heap. +template +class nb_class_ptr : public nanobind::object { + public: + inline nb_class_ptr() : nanobind::object() {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) + : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) + : nanobind::object(h, ::nanobind::detail::steal_t{}) {} + inline static bool check_(nanobind::handle h) { + nanobind::handle type = nanobind::type(); + return h.type().is(type); + }; + + T* operator->() const { return nanobind::inst_ptr(ptr()); } + T& operator*() const { return *nanobind::inst_ptr(ptr()); } + T* get() const { return ptr() ? 
nanobind::inst_ptr(ptr()) : nullptr; } +}; + +// This function is analogous to std::make_unique(...), but instead it +// allocates the object on the Python heap +template +nb_class_ptr make_nb_class(Args&&... args) { + nanobind::handle type = nanobind::type(); + nanobind::object instance = nanobind::inst_alloc(type); + T* ptr = nanobind::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nanobind::inst_mark_ready(instance); + return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); +} + +} // namespace xla + +#endif // JAXLIB_XLA_NB_CLASS_PTR_H_ diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 0409397c82de..508bf79f9ec0 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -53,11 +53,14 @@ limitations under the License. #include "jaxlib/xla/config.h" #include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/traceback.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" @@ -67,11 +70,8 @@ limitations under the License. #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc index 3dbd736076db..295ac8bfccfb 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/xla/pmap_lib.cc @@ -46,15 +46,18 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharded_device_array.h" #include "jaxlib/xla/sharding.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -64,11 +67,8 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index a348b47454e7..1325f0cbd2bc 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -58,12 +58,15 @@ limitations under the License. 
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/sharding.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" #include "jaxlib/xla/util.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -86,15 +89,12 @@ limitations under the License. #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h index f914639e383f..645f51096c1d 100644 --- a/jaxlib/xla/py_array.h +++ b/jaxlib/xla/py_array.h @@ -33,7 +33,9 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/traceback.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" @@ -41,10 +43,8 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/future.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" -#include "xla/python/traceback.h" #include "xla/shape.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index b74c37f28863..795a4fee29fa 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -50,12 +50,15 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/callback.h" #include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_host_callback.h" #include "jaxlib/xla/py_memory_space.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/traceback.h" #include "xla/literal.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -74,14 +77,11 @@ limitations under the License. 
#include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/program.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pprof_profile_builder.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h index 8f50c6451627..898a40141307 100644 --- a/jaxlib/xla/py_client.h +++ b/jaxlib/xla/py_client.h @@ -34,6 +34,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -42,7 +43,6 @@ limitations under the License. #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/program.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/shape.h" diff --git a/jaxlib/xla/py_compile_only_client.cc b/jaxlib/xla/py_compile_only_client.cc index 6319c70f91b0..673dfc214346 100644 --- a/jaxlib/xla/py_compile_only_client.cc +++ b/jaxlib/xla/py_compile_only_client.cc @@ -30,6 +30,7 @@ limitations under the License. #include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_compiler.h" @@ -37,7 +38,6 @@ limitations under the License. #include "xla/pjrt/status_casters.h" #include "xla/python/compile_only_ifrt/client.h" #include "xla/python/ifrt/executable.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" diff --git a/jaxlib/xla/py_compile_only_client.h b/jaxlib/xla/py_compile_only_client.h index 721830d6f52e..6cc700e1d3a9 100644 --- a/jaxlib/xla/py_compile_only_client.h +++ b/jaxlib/xla/py_compile_only_client.h @@ -20,8 +20,8 @@ limitations under the License. // placeholder for index annotation headers #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" namespace xla { diff --git a/jaxlib/xla/py_device.cc b/jaxlib/xla/py_device.cc index 20c257bb7d1a..253bfd439a9b 100644 --- a/jaxlib/xla/py_device.cc +++ b/jaxlib/xla/py_device.cc @@ -36,17 +36,17 @@ limitations under the License. 
#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h index 6d2b3893dea8..6071ede52305 100644 --- a/jaxlib/xla/py_device.h +++ b/jaxlib/xla/py_device.h @@ -24,10 +24,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/literal.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/shape.h" namespace xla { diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/xla/py_device_list.cc index 593a86ccbe42..300e477dbbd0 100644 --- a/jaxlib/xla/py_device_list.cc +++ b/jaxlib/xla/py_device_list.cc @@ -32,13 +32,13 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h index ea574c5dc5a2..1d0f64003f8c 100644 --- a/jaxlib/xla/py_device_list.h +++ b/jaxlib/xla/py_device_list.h @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_class_ptr.h" #include "xla/tsl/concurrency/ref_count.h" namespace jax { diff --git a/jaxlib/xla/py_executable.cc b/jaxlib/xla/py_executable.cc index 5a02a8f6dd20..71e6cfbdba7f 100644 --- a/jaxlib/xla/py_executable.cc +++ b/jaxlib/xla/py_executable.cc @@ -35,9 +35,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/traceback.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" @@ -48,8 +50,6 @@ limitations under the License. 
#include "xla/python/ifrt/future.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" -#include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h index 214431f9472e..688eb779df8d 100644 --- a/jaxlib/xla/py_executable.h +++ b/jaxlib/xla/py_executable.h @@ -32,8 +32,10 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/traceback.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" @@ -45,9 +47,7 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/executable.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" -#include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/status.h" #include "xla/xla_data.pb.h" diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/xla/py_host_callback.cc index 833079335a36..fdb40c04b517 100644 --- a/jaxlib/xla/py_host_callback.cc +++ b/jaxlib/xla/py_host_callback.cc @@ -32,13 +32,13 @@ limitations under the License. #include "nanobind/nanobind.h" #include "jaxlib/xla/callback.h" #include "jaxlib/xla/py_host_callback.pb.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/layout_util.h" #include "xla/pjrt/host_callback.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" #include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/jaxlib/xla/py_memory_space.cc b/jaxlib/xla/py_memory_space.cc index f365dd25dfb6..0409861dd3b9 100644 --- a/jaxlib/xla/py_memory_space.cc +++ b/jaxlib/xla/py_memory_space.cc @@ -22,9 +22,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" namespace nb = ::nanobind; diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/xla/py_memory_space.h index 4ad7b852f416..f111263497fb 100644 --- a/jaxlib/xla/py_memory_space.h +++ b/jaxlib/xla/py_memory_space.h @@ -19,9 +19,9 @@ limitations under the License. #include #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/memory.h" -#include "xla/python/nb_class_ptr.h" namespace xla { diff --git a/jaxlib/xla/py_program.cc b/jaxlib/xla/py_program.cc index ec82292a50cd..b3828f5372d9 100644 --- a/jaxlib/xla/py_program.cc +++ b/jaxlib/xla/py_program.cc @@ -34,8 +34,10 @@ limitations under the License. 
#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/sharding.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -53,11 +55,9 @@ limitations under the License. #include "xla/python/ifrt/program.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index b1c4fbcc541f..05397cdf116f 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -34,9 +34,11 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -45,13 +47,11 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" -#include "xla/python/traceback.h" #include "xla/python/transfer/event_loop.h" #include "xla/python/transfer/socket-server.h" #include "xla/python/transfer/socket_bulk_transport.h" diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index 9375dd5440c6..1c7db0bec13a 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -42,6 +42,7 @@ limitations under the License. #include "nanobind/stl/complex.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/sharding.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" @@ -53,7 +54,6 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/jaxlib/xla/python_ref_manager.cc b/jaxlib/xla/python_ref_manager.cc new file mode 100644 index 000000000000..a19622d94244 --- /dev/null +++ b/jaxlib/xla/python_ref_manager.cc @@ -0,0 +1,104 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/xla/python_ref_manager.h"
+
+#include <atomic>
+#include <deque>
+#include <memory>
+#include <utility>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+
+namespace xla {
+
+namespace nb = nanobind;
+
+PythonRefManager::ManagedPyObjects::ManagedPyObjects(
+    PythonRefManager* manager, absl::Span<nb::object> objects)
+    : manager_(manager) {
+  objects_.reserve(objects.size());
+  for (nb::object& object : objects) {
+    objects_.push_back(std::move(object));
+  }
+}
+
+PythonRefManager::ManagedPyObjects::~ManagedPyObjects() {
+  if (manager_ && !objects_.empty()) {
+    manager_->AddGarbage(absl::MakeSpan(objects_));
+  }
+}
+
+std::shared_ptr<PythonRefManager::ManagedPyObjects>
+PythonRefManager::ManageReference(nb::object object) {
+  return std::make_shared<ManagedPyObjects>(
+      this, absl::Span<nb::object>(&object, 1));
+}
+
+std::shared_ptr<PythonRefManager::ManagedPyObjects>
+PythonRefManager::ManageReferences(absl::Span<nb::object> objects) {
+  return std::make_shared<ManagedPyObjects>(this, objects);
+}
+
+void PythonRefManager::AddGarbage(nb::object garbage) {
+  absl::MutexLock lock(&mu_);
+  // We want to collect arbitrary python garbage (e.g., buffers) aggressively.
+  garbage_count_.fetch_add(100, std::memory_order_relaxed);
+  python_garbage_.push_back(std::move(garbage));
+}
+
+void PythonRefManager::AddGarbage(absl::Span<nb::object> garbage) {
+  absl::MutexLock lock(&mu_);
+  // We want to collect arbitrary python garbage (e.g., buffers) aggressively.
+  garbage_count_.fetch_add(100, std::memory_order_relaxed);
+  for (nb::object& o : garbage) {
+    python_garbage_.push_back(std::move(o));
+  }
+}
+
+void PythonRefManager::AddGarbage(
+    absl::Span<std::pair<PyCodeObject*, int> const> garbage) {
+  absl::MutexLock lock(&mu_);
+  // We don't care about collecting stack frame objects often. We grab a lot of
+  // tracebacks and the code objects are most likely live for the entire
+  // process.
+  garbage_count_.fetch_add(1, std::memory_order_relaxed);
+  for (const auto& o : garbage) {
+    python_garbage_.push_back(nb::steal(reinterpret_cast<PyObject*>(o.first)));
+  }
+}
+
+void PythonRefManager::CollectGarbage() {
+  // TODO(phawkins): we should CHECK(PyGILState_Check());
+  std::deque<nb::object> garbage;
+  {
+    absl::MutexLock lock(&mu_);
+    garbage_count_ = 0;
+    garbage.swap(python_garbage_);
+  }
+  // We defer deleting garbage until the lock is released. It's possible that
+  // deleting garbage will lead to more Python garbage being added; if we held
+  // the lock we would deadlock because absl::Mutex is not reentrant.
+}
+
+PythonRefManager* GlobalPyRefManager() {
+  static PythonRefManager* static_ref_manager = new PythonRefManager();
+  return static_ref_manager;
+}
+
+}  // namespace xla
diff --git a/jaxlib/xla/python_ref_manager.h b/jaxlib/xla/python_ref_manager.h
new file mode 100644
index 000000000000..c0630da2ebd5
--- /dev/null
+++ b/jaxlib/xla/python_ref_manager.h
@@ -0,0 +1,108 @@
+/* Copyright 2019 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_PYTHON_REF_MANAGER_H_
+#define JAXLIB_XLA_PYTHON_REF_MANAGER_H_
+
+#include <Python.h>
+
+#include <atomic>
+#include <deque>
+#include <memory>
+#include <utility>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+
+namespace xla {
+
+// Class that manages destruction of Python objects.
+//
+// We must not destroy Python objects without holding the GIL. However, we
+// frequently want to hold references to Python objects for the duration of
+// an asynchronous transfer on a Stream, and release our reference when the
+// transfer completes.
+//
+// This class holds references to Python objects outside a GIL scope, that can
+// be collected later when the GIL is held by calling CollectGarbage().
+class PythonRefManager {
+ public:
+  PythonRefManager() = default;
+
+  // Holds references to a set of nanobind::objects, adding the references to
+  // the PythonRefManager on destruction.
+  class ManagedPyObjects {
+   public:
+    ManagedPyObjects() = default;
+    ManagedPyObjects(PythonRefManager* manager,
+                     absl::Span<nanobind::object> objects);
+
+    ~ManagedPyObjects();
+
+    ManagedPyObjects(const ManagedPyObjects& other) = delete;
+    ManagedPyObjects(ManagedPyObjects&& other) = default;
+    ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete;
+    ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default;
+
+   private:
+    PythonRefManager* manager_ = nullptr;
+    absl::InlinedVector<nanobind::object, 1> objects_;
+  };
+
+  // Creates a managed std::shared_ptr to an object. When the shared_ptr is
+  // destroyed, the reference to 'object' will be added to python_garbage_,
+  // and collected next time CollectGarbage() is called.
+  std::shared_ptr<ManagedPyObjects> ManageReference(nanobind::object object);
+  std::shared_ptr<ManagedPyObjects> ManageReferences(
+      absl::Span<nanobind::object> objects);
+
+  // Adds garbage objects to the manager.
+  void AddGarbage(nanobind::object garbage);
+  void AddGarbage(absl::Span<nanobind::object> garbage);
+  void AddGarbage(absl::Span<std::pair<PyCodeObject*, int> const> garbage);
+
+  // Releases the contents of python_garbage_. Requires that the GIL is held.
+  // The client calls this method during API entry points where the GIL is held
+  // to free any garbage that has accumulated.
+  void CollectGarbage();
+
+  // Cheaper version of CollectGarbage() with relaxed consistency and frequency.
+  // The purpose of this function is to amortize lock acquisition costs over
+  // a larger number of API calls.
+  void MaybeCollectGarbage() {
+    if (garbage_count_.load(std::memory_order_relaxed) >= 100) {
+      CollectGarbage();
+    }
+  }
+
+ private:
+  absl::Mutex mu_;
+  std::deque<nanobind::object> python_garbage_ ABSL_GUARDED_BY(mu_);
+
+  // Writes to garbage_count_ are protected by mu_, reads are not protected.
+  std::atomic<int> garbage_count_{0};
+};
+
+// A global PythonRefManager. Unless `CollectGarbage()` is called before
+// shutdown, this container will hold on to Python objects and thus cause a
+// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`.
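+//
+// A minimal usage sketch follows (illustrative only; the object and the
+// transfer thread in this example are hypothetical, everything else is
+// declared in this header):
+//
+//   nanobind::object obj = ...;  // reference acquired while holding the GIL
+//   std::shared_ptr<PythonRefManager::ManagedPyObjects> keepalive =
+//       GlobalPyRefManager()->ManageReference(std::move(obj));
+//   // `keepalive` may now be destroyed on a non-Python thread; the
+//   // underlying reference is queued as garbage and actually released by a
+//   // later CollectGarbage()/MaybeCollectGarbage() call made under the GIL.
+//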
+PythonRefManager* GlobalPyRefManager(); + +} // namespace xla + +#endif // JAXLIB_XLA_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc index 7d1f7676bada..175e753515d0 100644 --- a/jaxlib/xla/pytree.cc +++ b/jaxlib/xla/pytree.cc @@ -49,9 +49,9 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/tuple.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/pytree.pb.h" #include "xla/pjrt/exceptions.h" -#include "xla/python/nb_class_ptr.h" #include "xla/tsl/platform/logging.h" namespace xla { diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h index 471d25af89bc..9c4aaff0bfae 100644 --- a/jaxlib/xla/pytree.h +++ b/jaxlib/xla/pytree.h @@ -34,8 +34,8 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/pytree.pb.h" -#include "xla/python/nb_class_ptr.h" namespace xla { diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 9952c31bd393..409dddb62268 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -30,13 +30,13 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/sharded_device_array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index 572a6cd3c86e..dac18a4160b5 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -24,13 +24,13 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/sharded_device_array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/xla_data.pb.h" diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc index 96ec9c77071d..116ead49ad23 100644 --- a/jaxlib/xla/to_ifrt_sharding.cc +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/sharding.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/shape.h" diff --git a/jaxlib/xla/traceback.cc b/jaxlib/xla/traceback.cc new file mode 100644 index 000000000000..35085b3e32fa --- /dev/null +++ b/jaxlib/xla/traceback.cc @@ -0,0 +1,357 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/traceback.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "xla/pjrt/exceptions.h" +#include "tsl/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE + +namespace xla { + +namespace nb = nanobind; + +bool Traceback::enabled_ = true; + +Traceback::Traceback() { + DCHECK(PyGILState_Check()); + PyThreadState* thread_state = PyThreadState_GET(); + +#if PY_VERSION_HEX < 0x030b0000 + // The representation of frame->f_lasti changed from bytes to words in Python + // 3.10, see https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api + // This should match sizeof(_Py_CODEUNIT) which is unfortunately private. + constexpr int kLastiWordBytes = 2; + + for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr; + py_frame = py_frame->f_back) { + Py_INCREF(py_frame->f_code); + frames_.emplace_back(py_frame->f_code, py_frame->f_lasti * kLastiWordBytes); + } +#else // PY_VERSION_HEX < 0x030b0000 + +#ifdef PLATFORM_GOOGLE + // This code is equivalent to the version using public APIs, but it saves us + // an allocation of one object per stack frame. However, this is definitely + // violating the API contract of CPython, so we only use this where we can be + // confident we know exactly which CPython we are using (internal to Google). + // Feel free to turn this on if you like, but it might break at any time! 
+  for (_PyInterpreterFrame* f = thread_state->cframe->current_frame;
+       f != nullptr; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f)) continue;
+    Py_INCREF(f->f_code);
+    frames_.emplace_back(f->f_code,
+                         _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT));
+  }
+#else  // PLATFORM_GOOGLE
+  PyFrameObject* next;
+  for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state);
+       py_frame != nullptr; py_frame = next) {
+    frames_.emplace_back(PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame));
+    next = PyFrame_GetBack(py_frame);
+    Py_XDECREF(py_frame);
+  }
+#endif  // PLATFORM_GOOGLE
+
+#endif  // PY_VERSION_HEX < 0x030b0000
+}
+
+Traceback::~Traceback() {
+  for (auto& frame : frames_) {
+    DCHECK(PyGILState_Check());
+    Py_DECREF(frame.first);
+  }
+}
+
+Traceback::Traceback(Traceback&& other) noexcept
+    : frames_(std::move(other.frames_)) {
+  // absl::InlinedVector does not always clear itself if moved. Since we rely on
+  // its empty() method to destroy Traceback differently, we explicitly clear
+  // here.
+  other.frames_.clear();
+}
+
+std::string Traceback::Frame::ToString() const {
+  return absl::StrFormat("%s:%d (%s)", nb::cast<std::string_view>(file_name),
+                         line_num, nb::cast<std::string_view>(function_name));
+}
+
+std::string Traceback::ToString() const {
+  std::vector<std::string> frame_strs;
+  frame_strs.reserve(frames_.size());
+  for (const Frame& frame : Frames()) {
+    frame_strs.push_back(frame.ToString());
+  }
+  return absl::StrJoin(frame_strs, "\n");
+}
+
+std::vector<Traceback::Frame> Traceback::Frames() const {
+  // We require the GIL because we manipulate Python strings.
+  CHECK(PyGILState_Check());
+  std::vector<Traceback::Frame> frames;
+  frames.reserve(frames_.size());
+  for (const auto& frame : frames_) {
+    frames.push_back(Frame{nb::borrow<nb::str>(frame.first->co_filename),
+                           nb::borrow<nb::str>(frame.first->co_name),
+                           frame.first->co_firstlineno,
+                           PyCode_Addr2Line(frame.first, frame.second)});
+  }
+  return frames;
+}
+
+std::optional<nb_class_ptr<Traceback>> Traceback::Get() {
+  DCHECK(PyGILState_Check());
+  if (!enabled_) {
+    return std::nullopt;
+  }
+  return make_nb_class<Traceback>();
+}
+
+void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; }
+
+nb::object Traceback::AsPythonTraceback() const {
+  nb::object traceback = nb::none();
+  nb::dict globals;
+  nb::handle traceback_type(reinterpret_cast<PyObject*>(&PyTraceBack_Type));
+  for (const std::pair<PyCodeObject*, int>& frame : frames_) {
+    int lineno = PyCode_Addr2Line(frame.first, frame.second);
+    // Under Python 3.11 we observed crashes when using a fake PyFrameObject
+    // with a real PyCodeObject (https://github.com/google/jax/issues/16027),
+    // because the frame does not have fields necessary to compute the locals,
+    // notably the closure object, leading to crashes in CPython in
+    // _PyFrame_FastToLocalsWithError
+    // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2
+    // We therefore always build a fake code object to go along with our fake
+    // frame.
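+    // PyCode_NewEmpty() builds a minimal code object that carries only a
+    // filename, a function name, and a first line number; that is all the
+    // fake frame below needs for display purposes.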
+    PyCodeObject* py_code =
+        PyCode_NewEmpty(PyUnicode_AsUTF8(frame.first->co_filename),
+                        PyUnicode_AsUTF8(frame.first->co_name), lineno);
+    PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                          globals.ptr(), /*locals=*/nullptr);
+    Py_DECREF(py_code);
+
+    traceback = traceback_type(
+        /*tb_next=*/std::move(traceback),
+        /*tb_frame=*/
+        nb::steal(reinterpret_cast<PyObject*>(py_frame)),
+        /*tb_lasti=*/0,
+        /*tb_lineno=*/
+        PyCode_Addr2Line(frame.first, frame.second));
+  }
+  return traceback;
+}
+
+namespace {
+
+Py_hash_t traceback_tp_hash(PyObject* o) {
+  Traceback* tb;
+  if (!nb::try_cast(nb::handle(o), tb)) {
+    PyErr_SetString(PyExc_TypeError, "Expected a Traceback object");
+    return -1;
+  }
+  size_t h = absl::HashOf(*tb);
+  Py_hash_t s = absl::bit_cast<Py_hash_t>(h);  // Python hashes are signed.
+  return s == -1 ? -2 : s;  // -1 must not be used as a Python hash value.
+}
+
+PyObject* traceback_tp_richcompare(PyObject* self, PyObject* other, int op) {
+  if (op != Py_EQ && op != Py_NE) {
+    return Py_NewRef(Py_NotImplemented);
+  }
+
+  Traceback* x;
+  if (!nb::try_cast(nb::handle(self), x)) {
+    PyErr_SetString(PyExc_TypeError, "Expected a Traceback object");
+    return nullptr;
+  }
+
+  bool result;
+  Traceback* y;
+  if (nb::try_cast(nb::handle(other), y)) {
+    result = ((*x == *y) == (op == Py_EQ));
+  } else {
+    result = (op == Py_NE);
+  }
+  return Py_NewRef(result ? Py_True : Py_False);
+}
+
+// It turns out to be slightly faster to define a tp_hash slot rather than
+// defining __hash__ and __eq__ on the class.
+PyType_Slot traceback_slots_[] = {
+    {Py_tp_hash, (void*)traceback_tp_hash},
+    {Py_tp_richcompare, (void*)traceback_tp_richcompare},
+    {0, nullptr},
+};
+
+}  // namespace
+
+void BuildTracebackSubmodule(nb::module_& m) {
+  nb::class_<Traceback::Frame>(m, "Frame")
+      .def(nb::init<const nb::str&, const nb::str&, int, int>())
+      .def_ro("file_name", &Traceback::Frame::file_name)
+      .def_ro("function_name", &Traceback::Frame::function_name)
+      .def_ro("function_start_line", &Traceback::Frame::function_start_line)
+      .def_ro("line_num", &Traceback::Frame::line_num)
+      .def("__repr__", [](const Traceback::Frame& frame) {
+        return absl::StrFormat(
+            "%s;%s:%d", nb::cast<std::string_view>(frame.function_name),
+            nb::cast<std::string_view>(frame.file_name), frame.line_num);
+      });
+
+  nb::class_<Traceback> traceback(m, "Traceback",
+                                  nb::type_slots(traceback_slots_),
+                                  "Represents a Python stack trace.");
+  traceback.def_prop_rw_static(
+      "enabled", [](nb::object /* cls */) { return Traceback::enabled(); },
+      [](nb::object /* cls */, bool enabled) {
+        return Traceback::SetEnabled(enabled);
+      });
+  traceback.def_static(
+      "get_traceback", []() { return Traceback::Get(); },
+      R"doc(
+      Returns a :class:`Traceback` for the current thread.
+
+      If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object
+      that describes the Python stack of the calling thread. Stack trace
+      collection has a small overhead, so it is disabled by default. If traceback
+      collection is disabled, returns ``None``.
+      )doc");
+  traceback.def_prop_ro("frames", &Traceback::Frames);
+  traceback.def("raw_frames", [](const Traceback& tb) -> nb::tuple {
+    // We return a tuple of lists, rather than a list of tuples, because it
+    // is cheaper to allocate only three Python objects for everything rather
+    // than one per frame.
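+    // PyList_New() returns a list whose slots are uninitialized, and
+    // PyList_SET_ITEM() steals the reference stored into each slot; the code
+    // object is therefore Py_INCREF'd below before it is stored.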
+    nb::list out_code = nb::steal<nb::list>(PyList_New(tb.raw_frames().size()));
+    nb::list out_lasti =
+        nb::steal<nb::list>(PyList_New(tb.raw_frames().size()));
+    for (size_t i = 0; i < tb.raw_frames().size(); ++i) {
+      const auto& frame = tb.raw_frames()[i];
+      PyObject* code = reinterpret_cast<PyObject*>(frame.first);
+      Py_INCREF(code);
+      PyList_SET_ITEM(out_code.ptr(), i, code);
+      PyList_SET_ITEM(out_lasti.ptr(), i,
+                      nb::int_(frame.second).release().ptr());
+    }
+    return nb::make_tuple(out_code, out_lasti);
+  });
+  traceback.def("__str__", &Traceback::ToString);
+  traceback.def("as_python_traceback", &Traceback::AsPythonTraceback);
+
+  traceback.def_static(
+      "traceback_from_frames",
+      [](std::vector<Traceback::Frame> frames) {
+        nb::object traceback = nb::none();
+        nb::dict globals;
+        nb::handle traceback_type(
+            reinterpret_cast<PyObject*>(&PyTraceBack_Type));
+        for (const Traceback::Frame& frame : frames) {
+          PyCodeObject* py_code =
+              PyCode_NewEmpty(frame.file_name.c_str(),
+                              frame.function_name.c_str(), frame.line_num);
+          PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                                globals.ptr(), /*locals=*/
+                                                nullptr);
+          Py_DECREF(py_code);
+          traceback = traceback_type(
+              /*tb_next=*/std::move(traceback),
+              /*tb_frame=*/
+              nb::steal(reinterpret_cast<PyObject*>(py_frame)),
+              /*tb_lasti=*/0,
+              /*tb_lineno=*/
+              frame.line_num);
+        }
+        return traceback;
+      },
+      "Creates a traceback from a list of frames.");
+
+  traceback.def_static(
+      "code_addr2line",
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw xla::XlaRuntimeError("code argument must be a code object");
+        }
+        return PyCode_Addr2Line(reinterpret_cast<PyCodeObject*>(code.ptr()),
+                                lasti);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Line");
+
+#if PY_VERSION_HEX >= 0x030b0000
+  traceback.def_static(
+      "code_addr2location",
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw xla::XlaRuntimeError("code argument must be a code object");
+        }
+        int start_line, start_column, end_line, end_column;
+        if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject*>(code.ptr()),
+                                  lasti, &start_line, &start_column, &end_line,
+                                  &end_column)) {
+          throw nb::python_error();
+        }
+        return nb::make_tuple(start_line, start_column, end_line, end_column);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Location");
+#endif  // PY_VERSION_HEX >= 0x030b0000
+
+#if PY_VERSION_HEX < 0x030b0000
+  // This function replaces the exception traceback associated with the current
+  // Python thread.
+  m.def(
+      "replace_thread_exc_traceback",
+      [](nb::object tb) {
+        if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) {
+          throw xla::XlaRuntimeError(
+              "argument must be a traceback object or None");
+        }
+        PyThreadState* thread_state = PyThreadState_Get();
+        if (!thread_state->exc_info->exc_traceback) {
+          throw xla::XlaRuntimeError(
+              "Current thread does not have an active "
+              "exception traceback");
+        }
+        PyObject* old_exc_traceback = thread_state->exc_info->exc_traceback;
+        PyObject* new_tb = tb.is_none() ? nullptr : tb.release().ptr();
+        thread_state->exc_info->exc_traceback = new_tb;
+        Py_XDECREF(old_exc_traceback);
+      },
+      nb::arg("traceback").none());
+#endif  // PY_VERSION_HEX < 0x030b0000
+}
+}  // namespace xla
diff --git a/jaxlib/xla/traceback.h b/jaxlib/xla/traceback.h
new file mode 100644
index 000000000000..953d626439c4
--- /dev/null
+++ b/jaxlib/xla/traceback.h
@@ -0,0 +1,108 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_TRACEBACK_H_
+#define JAXLIB_XLA_TRACEBACK_H_
+
+#include <Python.h>
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// placeholder for index annotation headers
+#include "absl/container/inlined_vector.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/xla/nb_class_ptr.h"
+
+namespace xla {
+
+// Represents a Python traceback. This object is designed to be allocated on
+// the Python heap; creating or destroying a traceback requires the GIL.
+class Traceback {
+ public:
+  // Requires GIL. Creates a Traceback object that requires destructor to be
+  // invoked with GIL held as well.
+  static std::optional<nb_class_ptr<Traceback>> Get();
+
+  // Requires GIL.
+  static bool enabled() { return enabled_; }
+  // Requires GIL.
+  static void SetEnabled(bool enabled);
+
+  // Requires GIL. Don't call this directly, you're looking for Get().
+  Traceback();
+  // Requires GIL.
+  ~Traceback();
+
+  Traceback(const Traceback&) = delete;
+  Traceback(Traceback&& other) noexcept;
+  Traceback& operator=(const Traceback&) = delete;
+  Traceback& operator=(Traceback&&) = delete;
+
+  // Requires the GIL be held.
+  std::string ToString() const;
+
+  struct Frame {
+    nanobind::str file_name;
+    nanobind::str function_name;
+    int function_start_line;
+    int line_num;
+
+    std::string ToString() const;
+  };
+  std::vector<Frame> Frames() const;
+
+  const absl::InlinedVector<std::pair<PyCodeObject*, int>, 32>& raw_frames()
+      const {
+    return frames_;
+  }
+
+  // Returns the traceback as a fake Python Traceback object, suitable for
+  // using as an exception traceback.
+  nanobind::object AsPythonTraceback() const;
+
+  bool operator==(const Traceback& other) const {
+    return frames_ == other.frames_;
+  }
+  bool operator!=(const Traceback& other) const {
+    return frames_ != other.frames_;
+  }
+
+ private:
+  // Each frame is a pair of a code object and a "lasti" instruction location
+  // in bytes. The size of _Py_CODEUNIT has changed across different Python
+  // versions; the lasti value here has already been multiplied by
+  // sizeof(_Py_CODEUNIT) if needed and is suitable for passing to functions
+  // like PyCode_Addr2Line().
+  absl::InlinedVector<std::pair<PyCodeObject*, int>, 32> frames_;
+
+  // Protected by GIL.
+  static bool enabled_;
+};
+
+using nb_traceback = nb_class_ptr<Traceback>;
+
+template <typename H>
+H AbslHashValue(H h, const Traceback& traceback) {
+  h = H::combine(std::move(h), traceback.raw_frames());
+  return h;
+}
+
+void BuildTracebackSubmodule(nanobind::module_& m);
+
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_TRACEBACK_H_
diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc
index a0508013910b..6e47be15fc68 100644
--- a/jaxlib/xla/xla.cc
+++ b/jaxlib/xla/xla.cc
@@ -90,6 +90,7 @@ limitations under the License.
 #include "jaxlib/xla/guard_lib.h"
 #include "jaxlib/xla/jax_jit.h"
 #include "jaxlib/xla/mlir.h"
+#include "jaxlib/xla/nb_class_ptr.h"
 #include "jaxlib/xla/pjit.h"
 #include "jaxlib/xla/pmap_lib.h"
 #include "jaxlib/xla/py_array.h"
@@ -98,8 +99,10 @@ limitations under the License.
#include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/traceback.h" #include "jaxlib/xla/weakref_lru_cache.h" #include "jaxlib/xla/xla_compiler.h" #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -113,15 +116,12 @@ limitations under the License. #include "xla/python/logging.h" // IWYU pragma: keep #include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/nb_class_ptr.h" #include "xla/python/ops.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/tsl/platform/status.h" #include "tsl/platform/platform.h" From 096810a72150df391d4986004795baeb19e7e1db Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Mar 2025 12:11:47 -0700 Subject: [PATCH 0195/1769] [array API] make capabilities more accurate --- jax/_src/numpy/array_api_metadata.py | 5 +++-- tests/array_api_test.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index 4a01f579a67e..d8d2c2d1a2a4 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -51,8 +51,9 @@ class ArrayNamespaceInfo: .. _Python array API: https://data-apis.org/array-api/ """ _capabilities = { - "boolean indexing": True, - "data-dependent shapes": False, + "boolean indexing": False, # within transformations + "data-dependent shapes": False, # within transformations + "max dimensions": 64, # XLA limitation } def _build_dtype_dict(self): diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 250eeb810872..d509fe78c35f 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -275,8 +275,9 @@ def build_dtype_dict(self, dtypes): def test_capabilities_info(self): capabilities = self.info.capabilities() - assert capabilities["boolean indexing"] + assert not capabilities["boolean indexing"] assert not capabilities["data-dependent shapes"] + assert capabilities["max dimensions"] == 64 def test_default_device_info(self): assert self.info.default_device() is None From 4644b2ba67fd8a8144602df014595a564104b8ae Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 12:55:20 -0700 Subject: [PATCH 0196/1769] Add tests to ensure nan checks do not produce false positives in jax.numpy functions PiperOrigin-RevId: 740872313 --- tests/jax_numpy_error_test.py | 90 +++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index 08917aeed364..566e0b1ba209 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -117,65 +117,73 @@ def test_error_category_invalid_category(self): @staticmethod def op_cases(cases): for jit in (True, False): - for func, operands in cases: - if not isinstance(operands, tuple): - operands = (operands,) + for func, ops_error, ops_no_err in cases: + if not isinstance(ops_error, tuple): + ops_error = (ops_error,) + if not 
isinstance(ops_no_err, tuple): + ops_no_err = (ops_no_err,) jit_str = "jit" if jit else "nojit" func_str = f"{func.__module__}.{func.__name__}" name = f"_{jit_str}_{func_str}" - yield name, jit, func, operands + yield name, jit, func, ops_error, ops_no_err @parameterized.named_parameters( op_cases(( - # list of all NaN-producing jax.numpy functions + # List of all NaN-producing jax.numpy functions. + # The first group of numbers is the input that will produce a NaN, and + # the second group is the input that will not produce a NaN. # go/keep-sorted start - (jnp.acos, 2.0), - (jnp.acosh, 0.5), - (jnp.add, (jnp.inf, -jnp.inf)), - (jnp.arccos, 2.0), - (jnp.arccosh, 0.5), - (jnp.arcsin, -2.0), - (jnp.arctanh, -2.0), - (jnp.asin, -2.0), - (jnp.atanh, -2.0), - (jnp.cos, jnp.inf), - (jnp.divide, (0.0, 0.0)), - (jnp.divmod, (1.0, 0.0)), - (jnp.float_power, (-1.0, 0.5)), - (jnp.fmod, (1.0, 0.0)), - (jnp.log, -1.0), - (jnp.log10, -1.0), - (jnp.log1p, -1.5), - (jnp.log2, -1.0), - (jnp.mod, (1.0, 0.0)), - (jnp.pow, (-1.0, 0.5)), - (jnp.power, (-1.0, 0.5)), - (jnp.remainder, (1.0, 0.0)), - (jnp.sin, jnp.inf), + (jnp.acos, 2.0, 0.5), + (jnp.acosh, 0.5, 2.0), + (jnp.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (jnp.arccos, 2.0, 0.5), + (jnp.arccosh, 0.5, 2.0), + (jnp.arcsin, -2.0, 0.5), + (jnp.arctanh, -2.0, 0.5), + (jnp.asin, -2.0, 0.5), + (jnp.atanh, -2.0, 0.5), + (jnp.cos, jnp.inf, 1.0), + (jnp.divide, (0.0, 0.0), (1.0, 1.0)), + (jnp.divmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.float_power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.fmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.log, -1.0, 1.0), + (jnp.log10, -1.0, 1.0), + (jnp.log1p, -1.5, 1.0), + (jnp.log2, -1.0, 1.0), + (jnp.mod, (1.0, 0.0), (1.0, 1.0)), + (jnp.pow, (-1.0, 0.5), (1.0, 1.0)), + (jnp.power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.remainder, (1.0, 0.0), (1.0, 1.0)), + (jnp.sin, jnp.inf, 1.0), # TODO(https://github.com/jax-ml/jax/issues/27470): Not yet supported. 
- # (jnp.sinc, jnp.inf), - (jnp.sqrt, -4.0), - (jnp.subtract, (jnp.inf, jnp.inf)), - (jnp.tan, jnp.inf), - (jnp.true_divide, (0.0, 0.0)), - (operator.add, (jnp.inf, -jnp.inf)), - (operator.mod, (1.0, 0.0)), - (operator.pow, (-1.0, 0.5)), - (operator.sub, (jnp.inf, jnp.inf)), - (operator.truediv, (0.0, 0.0)), + # (jnp.sinc, jnp.inf, 1.0), + (jnp.sqrt, -4.0, 4.0), + (jnp.subtract, (jnp.inf, jnp.inf), (0.0, 0.0)), + (jnp.tan, jnp.inf, 1.0), + (jnp.true_divide, (0.0, 0.0), (1.0, 1.0)), + (operator.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (operator.mod, (1.0, 0.0), (1.0, 1.0)), + (operator.pow, (-1.0, 0.5), (1.0, 1.0)), + (operator.sub, (jnp.inf, jnp.inf), (0.0, 0.0)), + (operator.truediv, (0.0, 0.0), (1.0, 1.0)), # go/keep-sorted end )) ) - def test_can_raise_nan_error(self, jit, f, operands): - operands = [jnp.float32(x) for x in operands] + def test_can_raise_nan_error(self, jit, f, ops_err, ops_no_err): + ops_err = [jnp.float32(x) for x in ops_err] + ops_no_err = [jnp.float32(x) for x in ops_no_err] if jit: f = jax.jit(f) with jnp_error.error_checking_behavior(nan="raise"): - f(*operands) + f(*ops_no_err) + error_check.raise_if_error() # should not raise error + + f(*ops_err) with self.assertRaisesRegex(JaxValueError, "NaN"): error_check.raise_if_error() From c9bc5f094d7b839d458c7191bbfc2c9defc02235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 26 Mar 2025 13:22:41 -0700 Subject: [PATCH 0197/1769] [Mosaic:TPU] 32-bit sublane broadcast for non-native tilings PiperOrigin-RevId: 740881404 --- .../tpu/transforms/apply_vector_layout.cc | 25 +++++++++++-------- .../tpu/transforms/infer_vector_layout.cc | 6 ++--- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 7755738a4fc7..71924739595c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3416,20 +3416,25 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, if (tiling[1] != ctx.target_shape[1]) { return op.emitOpError("Not implemented: unsupported tiling"); } - int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t sublanes_per_tile = + layout_in.sublanesPerTile(ctx.target_shape); if (needs_physical_broadcast == std::array{true, false}) { // Sublane broadcast const int packing = layout_in.packing(); - if (num_tiles != 1) { - return op.emitOpError( - "Not implemented: Only native tiling supported"); - } TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1); TPU_ASSERT_OP(offsets_in[0].has_value()); const int64_t sublane_offset = *offsets_in[0] / packing; const int64_t subelement_offset = *offsets_in[0] % packing; - const DenseI32ArrayAttr indices = builder.getDenseI32ArrayAttr( - SmallVector(ctx.target_shape[0], sublane_offset)); + SmallVector pattern; + pattern.reserve(ctx.target_shape[0]); + for (int32_t t = 0; t < num_tiles; ++t) { + for (int32_t i = 0; i < sublanes_per_tile; ++i) { + pattern.push_back(sublanes_per_tile * t + sublane_offset); + } + } + const DenseI32ArrayAttr sublane_pattern = + builder.getDenseI32ArrayAttr(pattern); const absl::Status status = src_tiles.EachStatus([&](const absl::Span src_idx, Value *const src_vreg) { @@ -3446,8 +3451,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, return absl::InternalError(""); } } - 
dst_vreg = builder.create(dst_vreg.getType(), - dst_vreg, indices, 0); + dst_vreg = builder.create( + dst_vreg.getType(), dst_vreg, sublane_pattern, 0); SmallVector dst_starts(dst_tiles_implicit_shape.size()); SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { @@ -3469,8 +3474,6 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, std::array{false, true}) { // Lane broadcast TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1); TPU_ASSERT_OP(offsets_in[1].has_value()); - const int64_t sublanes_per_tile = - layout_in.sublanesPerTile(ctx.target_shape); const int64_t offset = *offsets_in[1]; const int64_t lane_offset = offset % ctx.target_shape[1]; const int64_t tile_offset = offset / ctx.target_shape[1]; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 00e53314e588..c1a642b48f04 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1092,13 +1092,11 @@ class VectorLayoutInferer { } auto src_tiled_ishape = layout.getImplicitTiledDims(src_ty.getShape(), 1); auto dst_tiled_ishape = layout.getImplicitTiledDims(res_ty.getShape(), 1); - // Since we can only do sublane broadcasts in the (8, 128) tiling, we - // should always use that when sublane broadcasting is required. if (src_tiled_ishape[0] != dst_tiled_ishape[0] && layout.offsets()[0] != std::nullopt) { + // TODO(tlongeri): Remove this. We support non-native tiling now, but + // things may still break downstream due to missing relayouts. LayoutOffsets offsets = layout.offsets(); - // At the moment relayout can only produce replicated sublanes when - // converting to (8, 128) if the input was in (1, 128) tiling if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) { offsets[0] = std::nullopt; } From b92b9b0e26700202cea26ebbbc8e5f9ab42997d1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 26 Mar 2025 13:36:39 -0700 Subject: [PATCH 0198/1769] Raise an informative error when the length of device_assignment doesn't match the mesh.size of out_avals. This happens when (1) we can't extract the device_assignment from the arguments and (2) there is no concrete mesh in context. 
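In that case the computation is compiled against a default device assignment, whose length can disagree with the size of the explicit mesh attached to the output avals; the new check reports this mismatch up front and points users at `jax.sharding.use_mesh`.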
For example: ``` def test_random_normal_wo_mesh_context_error(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) s = NamedSharding(mesh, P('x', 'y')) @jax.jit def f(key): out = jax.random.normal(key, shape=(8, 12), out_sharding=s) self.assertEqual(out.aval.sharding.spec, P('x', 'y')) self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) return out key = jax.random.key(1) with self.assertRaisesRegex( ValueError, 'Length of device assignment.*is not equal to the size of the mesh'): f(key) ``` PiperOrigin-RevId: 740886114 --- jax/_src/interpreters/pxla.py | 10 ++++++++++ tests/pjit_test.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 387f0661ae9d..6f95b1b72281 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2281,6 +2281,16 @@ def lower_sharding_computation( devices_from_context) unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings] + for a in global_out_avals: + if (a is not core.abstract_token and not a.sharding.mesh.empty and + a.sharding.mesh._are_all_axes_explicit and + len(device_assignment) != a.sharding.mesh.size): + raise ValueError( + f"Length of device assignment {len(device_assignment)} is not equal" + f" to the size of the mesh {a.sharding.mesh.size} of aval" + f" {a.str_short(True, True)}. Please enter your `jit` into a mesh" + " context via `jax.sharding.use_mesh`.") + # TODO(parkers): One _raw_platform has been unified with platform, # change this back to just read platform. platforms = lowering_platforms or ( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5cd16e1e6925..277f24bd703f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7265,6 +7265,24 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + def test_random_normal_wo_mesh_context_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + return out + + key = jax.random.key(1) + with self.assertRaisesRegex( + ValueError, + 'Length of device assignment.*is not equal to the size of the mesh'): + f(key) + def test_random_normal_wo_mesh_context(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) From d9a6cd1a5ec1aeba2aa479f8f79f37c2504e4b77 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 26 Mar 2025 13:40:50 -0700 Subject: [PATCH 0199/1769] Remove xla_client.make_gpu_client. Cleanup; this code is not used any more because we use C API plugins instead. PiperOrigin-RevId: 740887556 --- jax/_src/xla_bridge.py | 75 ++++------------------------------------ jaxlib/xla/xla_client.py | 45 ------------------------ 2 files changed, 6 insertions(+), 114 deletions(-) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 72d88b9735b7..227359dc4676 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -89,13 +89,13 @@ 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_MOCK_NUM_GPU_PROCESSES = config.int_flag( +MOCK_NUM_GPU_PROCESSES = config.int_flag( name="mock_num_gpu_processes", default=0, help="Mock number of JAX processes in GPU client. 
Value zero turns " "off mocking.", ) -_MOCK_GPU_TOPOLOGY = config.string_flag( +MOCK_GPU_TOPOLOGY = config.string_flag( name="jax_mock_gpu_topology", default="", help='Mock multi-host GPU topology in GPU client. The value should ' @@ -432,7 +432,7 @@ def _version_check(name: str, f'following issues with CUDA components:\n' f'{join_str.join(errors)}') -def _get_num_nodes_from_gpu_topology(topology: str) -> int: +def get_num_nodes_from_gpu_topology(topology: str) -> int: try: slices_str, hosts_per_slice_str, _ = topology.split("x", 2) return int(slices_str) * int(hosts_per_slice_str) @@ -441,69 +441,6 @@ def _get_num_nodes_from_gpu_topology(topology: str) -> int: '" x x ' '".') -def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.Flag[str] -) -> xla_client.Client: - visible_devices = visible_devices_flag.value - allowed_devices = None - if visible_devices != "all": - allowed_devices = {int(x) for x in visible_devices.split(",")} - - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) - - use_mock_gpu_client = mock_num_gpu_processes > 0 - num_nodes = (mock_num_gpu_processes if use_mock_gpu_client - else distributed.global_state.num_processes) - - if platform_name == "cuda": - if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): - _check_cuda_versions() - else: - print('Skipped CUDA versions constraints check due to the ' - 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') - - devices_to_check = ( - allowed_devices - if allowed_devices - else range(cuda_versions.cuda_device_count()) - ) - _check_cuda_compute_capability(devices_to_check) - - return xla_client.make_gpu_client( - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=num_nodes, - platform_name=platform_name, - allowed_devices=allowed_devices, - mock=use_mock_gpu_client, - ) - - -if hasattr(xla_client, "make_gpu_client"): - register_backend_factory( - "cuda", - partial( - make_gpu_client, - platform_name="cuda", - visible_devices_flag=CUDA_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - register_backend_factory( - "rocm", - partial( - make_gpu_client, - platform_name="rocm", - visible_devices_flag=_ROCM_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - - if hasattr(xla_client, "make_tpu_client"): # TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, # and then fail loudly on initialization failure. 
@@ -652,9 +589,9 @@ def _options_from_jax_configs(plugin_name): else _ROCM_VISIBLE_DEVICES.value) if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + mock_gpu_topology = MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else MOCK_NUM_GPU_PROCESSES.value) options['enable_mock_nccl'] = mock_num_processes > 0 if mock_num_processes > 0: options['num_nodes'] = mock_num_processes diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index a9b1109c3bd3..ce881bee17c0 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -85,51 +85,6 @@ def make_cpu_client( ) -def make_gpu_client( - distributed_client=None, - node_id=0, - num_nodes=1, - platform_name=None, - allowed_devices=None, - mock=False, - mock_gpu_topology=None, -): - """Returns a GPU client. BFC allocator is used by default.""" - options = generate_pjrt_gpu_plugin_options() - allocator = options['allocator'] - config = _xla.GpuAllocatorConfig() - if allocator == 'default': - config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT - if allocator == 'platform': - config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM - if allocator == 'bfc': - config.kind = _xla.GpuAllocatorConfig.Kind.BFC - if allocator == 'cuda_async': - config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC - if 'memory_fraction' in options: - config.memory_fraction = options['memory_fraction'] - if 'preallocate' in options: - config.preallocate = options['preallocate'] - if 'collective_memory_size' in options: - config.collective_memory_size = options['collective_memory_size'] - register_custom_call_handler('CUDA', _xla.register_custom_call_target) - register_custom_call_handler('ROCM', _xla.register_custom_call_target) - register_custom_type_id_handler('CUDA', _xla.register_custom_type_id) - register_custom_type_id_handler('ROCM', _xla.register_custom_type_id) - - return _xla.get_gpu_client( - asynchronous=True, - allocator_config=config, - distributed_client=distributed_client, - node_id=node_id, - num_nodes=num_nodes, - platform_name=platform_name, - allowed_devices=allowed_devices, - mock=mock, - mock_gpu_topology=mock_gpu_topology, - ) - - def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None): assert pjrt_plugin_loaded('tpu') if not pjrt_plugin_initialized('tpu'): From 66908372af2a21832eb44fb9c652dda317397b4c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Mar 2025 14:06:26 -0700 Subject: [PATCH 0200/1769] jnp.tri*_indices: support __jax_array__ inputs --- jax/_src/numpy/lax_numpy.py | 18 +++++++++++++----- jax/numpy/__init__.pyi | 4 ++-- tests/array_extensibility_test.py | 4 ++-- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index fd6209ab22c4..83d5e9ee3e80 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -60,7 +60,7 @@ from jax._src.numpy.vectorize import vectorize from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( - Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, @@ 
-7557,7 +7557,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array @export -def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. JAX implementation of :func:`numpy.triu_indices_from`. @@ -7608,14 +7608,18 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("triu_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) @export -def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. JAX implementation of :func:`numpy.tril_indices_from`. @@ -7666,7 +7670,11 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("tril_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 640e9de7eac3..fb679969fe31 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -929,14 +929,14 @@ def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def trunc(x: ArrayLike, /) -> Array: ... 
uint: Any diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 7c0ec07e6a05..551f6d45dc41 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -484,10 +484,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: # NumPyAPI.sig(jnp.transpose, Float[5, 6]), NumPyAPI.sig(jnp.trapezoid, Float[5]), NumPyAPI.sig(jnp.tril, Float[5, 6]), - # NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), NumPyAPI.sig(jnp.trim_zeros, Float[5]), NumPyAPI.sig(jnp.triu, Float[5, 6]), - # NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), NumPyAPI.sig(jnp.true_divide, Float[5], Float[5]), NumPyAPI.sig(jnp.trunc, Float[5]), NumPyAPI.sig(jnp.union1d, Int[5], Int[5]), From c450b69dd7cb3f4ddc1700866ce7ab5dc9c4c459 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 14:26:54 -0700 Subject: [PATCH 0201/1769] Add missing `__len__` to MutableArray Fixes https://github.com/jax-ml/jax/issues/27476 PiperOrigin-RevId: 740903637 --- jax/_src/core.py | 4 ++-- jax/_src/state/types.py | 6 ++++++ tests/mutable_array_test.py | 12 ++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ef90341f5cf7..14f3d9cc18e2 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1940,8 +1940,7 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding if 'vma' not in kwargs: - kwargs['vma'] = getattr(self, 'vma', - frozenset()) + kwargs['vma'] = getattr(self, 'vma', frozenset()) return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -2170,6 +2169,7 @@ def __init__(self, aval, buf): def __getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) + def __len__(self) -> int: return self._aval._len(self) pytype_aval_mappings[MutableArray] = lambda x: x._aval def mutable_array(init_val): diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index fa9d0cb9fb16..e926e3a35f80 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -333,6 +333,12 @@ def update(self, inner_aval=None): ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) + def _len(self, ignored_tracer) -> int: + try: + return self.shape[0] + except IndexError as err: + raise TypeError("len() of unsized object") from err # same as numpy error + @property def shape(self): try: diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index e962653ed32d..950bddf544d7 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -116,6 +116,18 @@ def f(y_mut, z): check_dtypes=False) self.assertAllClose(w, 10, check_dtypes=False) + @parameterized.parameters([True, False]) + def test_len_mutable_array(self, jit): + x_mut = core.mutable_array(jnp.zeros(3)) + + def f(): + return jnp.int32(len(x_mut)) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(), 3) + @parameterized.parameters([True, False]) def test_internal_mutarray_basic(self, jit): def f(): From ce3941c635b9994b4e27ee3cd377d2bd568d5ea7 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 14:35:10 -0700 Subject: [PATCH 0202/1769] Add division-by-zero checks to jax.numpy functions PiperOrigin-RevId: 740906595 --- 
jax/_src/numpy/error.py | 23 +++++++++- jax/_src/numpy/ufuncs.py | 4 ++ tests/jax_numpy_error_test.py | 80 +++++++++++++++++++++++++++-------- 3 files changed, 88 insertions(+), 19 deletions(-) diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py index 52b996a0b050..20dab289d779 100644 --- a/jax/_src/numpy/error.py +++ b/jax/_src/numpy/error.py @@ -30,7 +30,9 @@ def _is_category_disabled( if category == "nan": raise ValueError("nan is deprecated. Use `_set_error_if_nan` instead.") if category == "divide": - return config.error_checking_behavior_divide.value == "ignore" + raise ValueError( + "divide is deprecated. Use `_set_error_if_divide_by_zero` instead." + ) if category == "oob": return config.error_checking_behavior_oob.value == "ignore" raise ValueError(f"Invalid category: {category}") @@ -81,6 +83,25 @@ def _set_error_if_nan(pred: jax.Array, /): error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered") +def _set_error_if_divide_by_zero(pred: jax.Array, /): + """Set the internal error state if any element of `pred` is zero. + + This function is intended for checking if the denominator of a division is + zero. + + This function is disabled if the `jax_error_checking_behavior_divide` flag is + set to "ignore". + """ + if config.error_checking_behavior_divide.value == "ignore": + return + + # TODO(ayx): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + zero = jnp.zeros_like(pred, shape=()) + error_check_lib.set_error_if(pred == zero, "Division by zero encountered") + + Behavior = Literal["ignore", "raise"] diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 0ea2992c9955..e561b7ae71b6 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2471,6 +2471,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.numpy.floor_divide` for integer division """ x1, x2 = promote_args_inexact("true_divide", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) out = lax.div(x1, x2) jnp_error._set_error_if_nan(out) return out @@ -2523,6 +2524,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([3., 2., 2.], dtype=float32) """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) dtype = dtypes.dtype(x1) if dtypes.issubdtype(dtype, np.unsignedinteger): return lax.div(x1, x2) @@ -2577,6 +2579,7 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: if dtypes.issubdtype(dtypes.dtype(x1), np.integer): return floor_divide(x1, x2), remainder(x1, x2) else: + jnp_error._set_error_if_divide_by_zero(x2) return _float_divmod(x1, x2) @@ -3090,6 +3093,7 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: [ 0., 2., -2.]], dtype=float32) """ x1, x2 = promote_args_numeric("remainder", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) zero = _constant_like(x1, 0) if dtypes.issubdtype(x2.dtype, np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index 566e0b1ba209..f2262d8b5dc0 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -51,11 +51,9 @@ def f(x): error_check.raise_if_error() @parameterized.product(jit=[True, False]) - def test_error_category_divide_check(self, jit): + def test_set_error_if_divide_by_zero(self, jit): def f(x, y): - jnp_error._set_error_if_with_category( - y == 0.0, "division by zero", category="divide" - ) + jnp_error._set_error_if_divide_by_zero(y) return x / y if jit: 
@@ -70,7 +68,7 @@ def f(x, y): with jnp_error.error_checking_behavior(divide="raise"): _ = f(x, y) - with self.assertRaisesRegex(JaxValueError, "division by zero"): + with self.assertRaisesRegex(JaxValueError, "Division by zero"): error_check.raise_if_error() @parameterized.product(jit=[True, False]) @@ -115,22 +113,22 @@ def test_error_category_invalid_category(self): ) @staticmethod - def op_cases(cases): + def nan_cases(cases): for jit in (True, False): - for func, ops_error, ops_no_err in cases: - if not isinstance(ops_error, tuple): - ops_error = (ops_error,) - if not isinstance(ops_no_err, tuple): - ops_no_err = (ops_no_err,) + for func, args_error, args_no_err in cases: + if not isinstance(args_error, tuple): + args_error = (args_error,) + if not isinstance(args_no_err, tuple): + args_no_err = (args_no_err,) jit_str = "jit" if jit else "nojit" func_str = f"{func.__module__}.{func.__name__}" name = f"_{jit_str}_{func_str}" - yield name, jit, func, ops_error, ops_no_err + yield name, jit, func, args_error, args_no_err @parameterized.named_parameters( - op_cases(( + nan_cases(( # List of all NaN-producing jax.numpy functions. # The first group of numbers is the input that will produce a NaN, and # the second group is the input that will not produce a NaN. @@ -172,21 +170,67 @@ def op_cases(cases): # go/keep-sorted end )) ) - def test_can_raise_nan_error(self, jit, f, ops_err, ops_no_err): - ops_err = [jnp.float32(x) for x in ops_err] - ops_no_err = [jnp.float32(x) for x in ops_no_err] + def test_can_raise_nan_error(self, jit, f, args_err, args_no_err): + args_err = [jnp.float32(x) for x in args_err] + args_no_err = [jnp.float32(x) for x in args_no_err] if jit: f = jax.jit(f) with jnp_error.error_checking_behavior(nan="raise"): - f(*ops_no_err) + f(*args_no_err) error_check.raise_if_error() # should not raise error - f(*ops_err) + f(*args_err) with self.assertRaisesRegex(JaxValueError, "NaN"): error_check.raise_if_error() + INT_TYPES = (jnp.int32, jnp.uint32, jnp.int64, jnp.uint64, jnp.int16, + jnp.uint16, jnp.int8, jnp.uint8) + FLOAT_TYPES = (jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16) + + @staticmethod + def divide_cases(cases): + for jit in (True, False): + for func, dtypes in cases: + for dtype in dtypes: + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + dtype_str = dtype.__name__ + name = f"_{jit_str}_{func_str}_{dtype_str}" + yield name, jit, func, dtype + + @parameterized.named_parameters( + divide_cases(( + # go/keep-sorted start + (jnp.divmod, FLOAT_TYPES + INT_TYPES), + (jnp.floor_divide, INT_TYPES), + (jnp.mod, FLOAT_TYPES + INT_TYPES), + (jnp.remainder, FLOAT_TYPES + INT_TYPES), + (jnp.true_divide, FLOAT_TYPES), + (operator.mod, FLOAT_TYPES + INT_TYPES), + (operator.truediv, FLOAT_TYPES), + # go/keep-sorted end + )) + ) + def test_can_raise_divide_by_zero_error(self, jit, div_func, dtype): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + args_err = (dtype(1), dtype(0)) + args_no_err = (dtype(1), dtype(1)) + + if jit: + div_func = jax.jit(div_func) + + with jnp_error.error_checking_behavior(divide="raise"): + div_func(*args_no_err) + error_check.raise_if_error() # should not raise error + + div_func(*args_err) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From af25dc47196e1237b6e92d8bb399b3439f22eebd Mon Sep 17 00:00:00 2001 From: 
Nitin Srinivasan Date: Wed, 26 Mar 2025 15:19:24 -0700 Subject: [PATCH 0203/1769] Update the Windows docker image to ltsc2022 PiperOrigin-RevId: 740921613 --- ci/envs/docker.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 82a76d33350c..a0f558520d45 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -41,5 +41,5 @@ fi # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel # tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows@sha256:6e2b299f12418d70ea522646b3dd618042a102f2ac2e4f8b1e423638549ea801" + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest" fi \ No newline at end of file From 667c4a0ee0da4cb96795624f7b91a9deacdeca14 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Mar 2025 15:27:25 -0700 Subject: [PATCH 0204/1769] Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim --- jax/_src/numpy/lax_numpy.py | 6 ++--- jax/_src/numpy/util.py | 38 +++++++++++++++++++++++-------- jax/_src/typing.py | 14 +++++++++++- jax/numpy/__init__.pyi | 12 +++++----- tests/array_extensibility_test.py | 6 ++--- 5 files changed, 53 insertions(+), 23 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 83d5e9ee3e80..63edaed0adeb 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -60,7 +60,7 @@ from jax._src.numpy.vectorize import vectorize from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( - Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, SupportsShape ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, @@ -7557,7 +7557,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array @export -def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. JAX implementation of :func:`numpy.triu_indices_from`. @@ -7619,7 +7619,7 @@ def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Arra @export -def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. JAX implementation of :func:`numpy.tril_indices_from`. diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e281c63ae654..e0e20d443e02 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -27,7 +27,8 @@ from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding from jax._src.util import safe_zip, safe_map, set_module -from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape +from jax._src.typing import (Array, ArrayLike, DimSize, DType, DTypeLike, + Shape, SupportsNdim, SupportsShape, SupportsSize) from jax.sharding import Sharding import numpy as np @@ -313,7 +314,7 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin @export -def ndim(a: ArrayLike) -> int: +def ndim(a: ArrayLike | SupportsNdim) -> int: """Return the number of dimensions of an array. 
JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function @@ -321,7 +322,7 @@ def ndim(a: ArrayLike) -> int: tuple. Args: - a: array-like object. + a: array-like object, or any object with an ``ndim`` attribute. Returns: An integer specifying the number of dimensions of ``a``. @@ -346,13 +347,18 @@ def ndim(a: ArrayLike) -> int: >>> x.ndim 1 """ + if hasattr(a, "ndim"): + return a.ndim # Deprecation warning added 2025-2-20. check_arraylike("ndim", a, emit_warning=True) - return np.ndim(a) # NumPy dispatches to a.ndim if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.ndim if available. + return np.ndim(a) # type: ignore[arg-type] @export -def shape(a: ArrayLike) -> tuple[int, ...]: +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: """Return the shape an array. JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function @@ -360,7 +366,7 @@ def shape(a: ArrayLike) -> tuple[int, ...]: tuple. Args: - a: array-like object. + a: array-like object, or any object with a ``shape`` attribute. Returns: An tuple of integers representing the shape of ``a``. @@ -385,13 +391,18 @@ def shape(a: ArrayLike) -> tuple[int, ...]: >>> x.shape (10,) """ + if hasattr(a, "shape"): + return a.shape # Deprecation warning added 2025-2-20. check_arraylike("shape", a, emit_warning=True) - return np.shape(a) # NumPy dispatches to a.shape if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.shape if available. + return np.shape(a) # type: ignore[arg-type] @export -def size(a: ArrayLike, axis: int | None = None) -> int: +def size(a: ArrayLike | SupportsSize | SupportsShape, axis: int | None = None) -> int: """Return number of elements along a given axis. JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function @@ -399,7 +410,8 @@ def size(a: ArrayLike, axis: int | None = None) -> int: tuple. Args: - a: array-like object + a: array-like object, or any object with a ``size`` attribute when ``axis`` is not + specified, or with a ``shape`` attribute when ``axis`` is specified. axis: optional integer along which to count elements. By default, return the total number of elements. @@ -428,6 +440,12 @@ def size(a: ArrayLike, axis: int | None = None) -> int: >>> y.size 6 """ + if (axis is None and hasattr(a, "size")) or (axis is not None and hasattr(a, "shape")): + # NumPy dispatches to a.size/a.shape if available. + return np.size(a, axis=axis) # type: ignore[arg-type] # Deprecation warning added 2025-2-20. check_arraylike("size", a, emit_warning=True) - return np.size(a, axis=axis) # NumPy dispatches to a.size if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.size/a.shape if available. + return np.size(a, axis=axis) # type: ignore[arg-type] diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 010841b45dd2..ee2422dd2d73 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -47,7 +47,19 @@ @typing.runtime_checkable class SupportsDType(Protocol): @property - def dtype(self) -> DType: ... + def dtype(self, /) -> DType: ... + +class SupportsShape(Protocol): + @property + def shape(self, /) -> tuple[int, ...]: ... + +class SupportsSize(Protocol): + @property + def size(self, /) -> int: ... + +class SupportsNdim(Protocol): + @property + def ndim(self, /) -> int: ... # DTypeLike is meant to annotate inputs to np.dtype that return # a valid JAX dtype. 
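Note for reviewers: a minimal sketch of the duck typing enabled by the
util.py/typing.py changes above. `Duck` is a hypothetical class, not part of
this patch; the point is that jnp.shape/jnp.ndim/jnp.size now read the
corresponding attribute directly instead of requiring array conversion:

    import jax.numpy as jnp

    class Duck:
      shape = (5, 3)
      ndim = 2
      size = 15

    jnp.shape(Duck())         # -> (5, 3), read from the attribute
    jnp.ndim(Duck())          # -> 2
    jnp.size(Duck())          # -> 15
    jnp.size(Duck(), axis=1)  # -> 3, via np.size, which reads Duck.shape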
It's different than numpy.typing.DTypeLike diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index fb679969fe31..259f6e3ed2ee 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -15,7 +15,7 @@ from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClas from jax._src.numpy.array_api_metadata import ArrayNamespaceInfo from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, - DimSize, DuckTypedArray, Shape, StaticScalar, + DimSize, DuckTypedArray, Shape, StaticScalar, SupportsNdim, SupportsShape, SupportsSize, ) from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax.numpy import fft as fft, linalg as linalg @@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... ndarray = Array -def ndim(a: ArrayLike) -> int: ... +def ndim(a: ArrayLike | SupportsNdim) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -841,7 +841,7 @@ def setdiff1d( fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -def shape(a: ArrayLike) -> tuple[int, ...]: ... +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -849,7 +849,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -def size(a: ArrayLike, axis: int | None = None) -> int: ... +def size(a: ArrayLike | SupportsSize, axis: int | None = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., @@ -929,14 +929,14 @@ def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def trunc(x: ArrayLike, /) -> Array: ... 
uint: Any diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 551f6d45dc41..fae9129dd99a 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -403,7 +403,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.nanstd, Float[5]), NumPyAPI.sig(jnp.nansum, Float[5]), NumPyAPI.sig(jnp.nanvar, Float[5]), - # NumPyAPI.sig(jnp.ndim, Float[5]), + NumPyAPI.sig(jnp.ndim, Float[5]), NumPyAPI.sig(jnp.negative, Float[5]), NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), NumPyAPI.sig(jnp.nonzero, Float[5]), @@ -455,13 +455,13 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: # NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[5]), NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), - # NumPyAPI.sig(jnp.shape, Float[5, 3]), + NumPyAPI.sig(jnp.shape, Float[5, 3]), NumPyAPI.sig(jnp.sign, Float[5]), NumPyAPI.sig(jnp.signbit, Float[5]), NumPyAPI.sig(jnp.sin, Float[5]), NumPyAPI.sig(jnp.sinc, Float[5]), NumPyAPI.sig(jnp.sinh, Float[5]), - # NumPyAPI.sig(jnp.size, Float[5]), + NumPyAPI.sig(jnp.size, Float[5]), NumPyAPI.sig(jnp.sort, Float[5]), NumPyAPI.sig(jnp.sort_complex, Complex[5]), NumPyAPI.sig(jnp.spacing, Float[5]), From 5bc4c57f09c778043b4932dd52cb4ea45c5d7069 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 26 Mar 2025 15:27:20 -0700 Subject: [PATCH 0205/1769] Inline make_tfrt_tpu_c_api_client into its only caller. PiperOrigin-RevId: 740923936 --- jaxlib/xla/xla_client.py | 16 ++++++---------- jaxlib/xla/xla_client.pyi | 3 --- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index ce881bee17c0..eb7a3b5759e5 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -85,15 +85,6 @@ def make_cpu_client( ) -def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None): - assert pjrt_plugin_loaded('tpu') - if not pjrt_plugin_initialized('tpu'): - initialize_pjrt_plugin('tpu') - if options is None: - options = {} - return _xla.get_c_api_client('tpu', options) - - DeviceTopology = _xla.DeviceTopology get_topology_for_devices = _xla.get_topology_for_devices @@ -169,7 +160,12 @@ def make_tpu_client( if not pjrt_plugin_loaded('tpu'): c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') profiler.register_plugin_profiler(c_api) - return make_tfrt_tpu_c_api_client(options) + assert pjrt_plugin_loaded('tpu') + if not pjrt_plugin_initialized('tpu'): + initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _xla.get_c_api_client('tpu', options) def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index 234af8f7b87d..5ac837ef1d85 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -106,9 +106,6 @@ def make_gpu_client( ) -> Client: ... -def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None) -> Client: - ... - def make_tfrt_tpu_c_api_device_topology( topology_name: str | None = None, **kwargs ) -> DeviceTopology: From c88ea23035454a95d7e20e4976d1e595114c8a66 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 26 Mar 2025 15:47:55 -0700 Subject: [PATCH 0206/1769] [JAX] Add caching to `colocated_python.colocated_cpu_devices()` For a deployment with many devices, `colocated_python.colocated_cpu_devices()` can take some time to find colocated devices as it needs to find matching devices one by one in Python. 
This change adds caching as an optimization to reduce the overall cost of API
calls.

PiperOrigin-RevId: 740930124
---
 jax/experimental/colocated_python/api.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py
index b855bba48abb..e72e04c6ded9 100644
--- a/jax/experimental/colocated_python/api.py
+++ b/jax/experimental/colocated_python/api.py
@@ -28,6 +28,15 @@ def colocated_cpu_devices(
     devices: Sequence[jax.Device],
 ) -> Sequence[jax.Device]:
   """Finds CPU devices colocated with the given devices."""
+  if not isinstance(devices, tuple):
+    devices = tuple(devices)
+  return _colocated_cpu_devices_cached(devices)
+
+
+@jax.util.cache(max_size=1024, trace_context_in_key=False)
+def _colocated_cpu_devices_cached(
+    devices: tuple[jax.Device, ...],
+) -> Sequence[jax.Device]:
   cpu_devices_by_colocation_id = collections.defaultdict(list)
   for device in devices[0].client._get_all_devices():  # pylint: disable=protected-access
     if device.device_kind == "cpu":

From e8038501d0ee0bef99a8a772d2ad9b0f38b018bb Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 26 Mar 2025 16:30:23 -0700
Subject: [PATCH 0207/1769] Fix a bug where jit was forwarding inputs to
 outputs even when donation was True for those inputs.

This caused the output to be marked as deleted since the input was being
forwarded to the output. Since this functionality was added for a dynamic
shapes experiment, only enable it when dynamic_shapes config is True.

Co-authored-by: Matthew Johnson
PiperOrigin-RevId: 740942785
---
 jax/_src/checkify.py     |  6 ++---
 jax/_src/core.py         |  3 +--
 jax/_src/pjit.py         | 25 ++++++++++++---------
 tests/api_test.py        |  4 ++++
 tests/checkify_test.py   | 15 ++++++++++++-
 tests/debug_info_test.py | 25 +++++++++------------
 tests/memories_test.py   |  1 +
 tests/pjit_test.py       | 48 ++++++++++++++++++++++++++++------------
 8 files changed, 81 insertions(+), 46 deletions(-)

diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 1ec8ad50b456..f80a0cbd1d75 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -913,14 +913,14 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
 
   # Update pjit params to account for extra error values.
num_error_vals = len(err_vals) num_out_error_vals = out_tree.num_leaves - len(out_shardings) - sharding = sharding_impls.UNSPECIFIED new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) - new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) new_in_layouts = (*[None] * num_error_vals, *in_layouts) - new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) new_donated_invars = (*[False] * num_error_vals, *donated_invars) + new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) + new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) + err_and_out = pjit.pjit_p.bind( *new_vals_in, jaxpr=checked_jaxpr, diff --git a/jax/_src/core.py b/jax/_src/core.py index 14f3d9cc18e2..3a1558802682 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1977,8 +1977,7 @@ def to_tangent_aval(self): def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, - getattr(self, 'varying_manual_axes', frozenset()), - short_dtypes, mesh_axis_types) + getattr(self, 'vma', frozenset()), short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): try: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 054b55e32918..03eb6835cb06 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1777,11 +1777,12 @@ def pjit_staging_rule(trace, *args, **params): return pe.inline_jaxpr_into_trace( trace, jaxpr.jaxpr, jaxpr.consts, *args) - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - params['jaxpr'], params['out_shardings'], params['out_layouts']) - params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts) + jaxpr = params['jaxpr'] if config.dynamic_shapes.value: + jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( + jaxpr, params['out_shardings'], params['out_layouts']) + params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts) source_info = source_info_util.current() out_tracers = [] for aval in _out_type(jaxpr): @@ -1795,6 +1796,10 @@ def pjit_staging_rule(trace, *args, **params): map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params, jaxpr.effects, source_info) trace.frame.add_eqn(eqn) + out_tracers_ = iter(out_tracers) + out_tracers = [args[f] if type(f) is int else next(out_tracers_) + for f in in_fwd] + assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) consts = map(trace.new_const, consts) @@ -1807,19 +1812,14 @@ def pjit_staging_rule(trace, *args, **params): pjit_p, (*args, *consts), new_params) else: out_tracers = trace.default_process_primitive(pjit_p, args, params) - - out_tracers_ = iter(out_tracers) - out_tracers = [args[f] if type(f) is int else next(out_tracers_) - for f in in_fwd] - assert next(out_tracers_, None) is None return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule def _pjit_forwarding(jaxpr, out_shardings, out_layouts): in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol - in zip(in_fwd, out_shardings, out_layouts)] + in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) out_shardings = [o for o, k in zip(out_shardings, keep) if k] @@ -1827,6 +1827,8 
@@ def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
   return jaxpr, in_fwd, out_shardings, out_layouts
 
 def pjit_forwarding_rule(eqn):
+  if not config.dynamic_shapes.value:
+    return [None] * len(eqn.outvars), eqn
   jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
       eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts'])
   new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None]
@@ -1835,6 +1837,7 @@ def pjit_forwarding_rule(eqn):
   new_eqn = eqn.replace(params=new_params, outvars=new_outvars)
   fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd]
   return fwd_vars, new_eqn
+# TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule.
 pe.forwarding_rules[pjit_p] = pjit_forwarding_rule
 
 
diff --git a/tests/api_test.py b/tests/api_test.py
index 82b673fe4b1e..9d80b5fbed74 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -4687,6 +4687,8 @@ def f(inputs):
 
   @jtu.run_on_devices("cpu")
   def test_inner_jit_forwarding_happens(self):
+    if not config.dynamic_shapes.value:
+      self.skipTest("Only works for dynamic shapes")
     jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))()
     self.assertLen(jaxpr.jaxpr.outvars, 1)
     self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal)
@@ -4695,6 +4697,8 @@ def test_inner_jit_forwarding_happens(self):
   @parameterized.parameters(range(8))
   @jtu.run_on_devices("cpu")
   def test_inner_jit_forwarding_correctness(self, num_input_fwd):
+    if not config.dynamic_shapes.value:
+      self.skipTest("Only works for dynamic shapes")
     num_args = 8
     rng = np.random.RandomState(0)
 
diff --git a/tests/checkify_test.py b/tests/checkify_test.py
index 6a1660b28578..5ea99d20a2ab 100644
--- a/tests/checkify_test.py
+++ b/tests/checkify_test.py
@@ -24,7 +24,7 @@
 from jax.experimental import checkify
 from jax.experimental import pjit
 from jax.experimental import shard_map
-from jax.sharding import NamedSharding
+from jax.sharding import NamedSharding, PartitionSpec as P
 from jax._src import array
 from jax._src import config
 from jax._src import core
@@ -475,6 +475,19 @@ def f(init_val):
     self.assertIsNotNone(err.get())
     self.assertStartsWith(err.get(), "division by zero")
 
+  def test_checkify_donation_no_forwarding(self):
+    mesh = jtu.create_mesh((2,), ('x',))
+
+    @checkify.checkify
+    @partial(jax.jit, donate_argnums=(0,))
+    def f(x: jax.Array) -> jax.Array:
+      checkify.check(jnp.all(x > 0), "a")
+      return x
+
+    x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P()))
+    err, y = f(x)
+    err, z = f(y)  # doesn't crash
+
   @jtu.skip_on_devices("tpu")
   def test_while_loop_body_and_cond_error(self):
     def while_cond(val):
diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py
index a39b53c3ad16..1d2935ea34d7 100644
--- a/tests/debug_info_test.py
+++ b/tests/debug_info_test.py
@@ -671,7 +671,7 @@ def my_g(b, d=1):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             # TODO(necula): result_paths?
- "traced_for=jit, fun=my_f, arg_names=a, result_paths=", + "traced_for=jit, fun=my_f, arg_names=a, result_paths=result", "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], expected_tracer_debug_infos=[ @@ -794,7 +794,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, from x", @@ -1318,17 +1318,15 @@ def the_grad(c, as_): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", - # TODO(necula): arg names, bad result paths "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", - "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", @@ -1467,7 +1465,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info @@ -1611,11 +1609,8 @@ def my_f(x): x, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", - # TODO(necula): arg_names and result_paths? 
- "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", - "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ diff --git a/tests/memories_test.py b/tests/memories_test.py index 570b0c375834..64ee2829873d 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1090,6 +1090,7 @@ def f_bwd(res, tx): self.assertArraysEqual(g(arr), all_true) def test_scan_offload(self): + self.skipTest('b/406586554') np_inp = jnp.arange(4096).reshape(16, 16, 16) @jax.jit diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 277f24bd703f..d72ecc98e771 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1240,9 +1240,12 @@ def test_pretty_print_pjit_id(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - pjit[name= jaxpr={ lambda ; a:f32[1] b:f32[1]. let in () }] a a - c:f32[1] = add a a - in (c,) } + b:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1] c:f32[1]. let in (a,) } + ] a a + d:f32[1] = add a b + in (d,) } """).strip(), ) @@ -1289,8 +1292,11 @@ def test_pretty_print_with_literal_outvar(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - b:i32[] = pjit[name= jaxpr={ lambda ; a:f32[1]. let in (2,) }] a - in (b, a) } + b:i32[] c:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1]. let in (2, a) } + ] a + in (b, c) } """).strip(), ) @@ -1336,19 +1342,19 @@ def f(x): self.assertEqual( jaxpr.pretty_print(use_color=False), textwrap.dedent(""" - let f = { lambda ; a:f32[1]. let in () } in - let f1 = { lambda ; b:f32[2]. let in () } in + let f = { lambda ; a:f32[1]. let in (a,) } in + let f1 = { lambda ; b:f32[2]. let in (b,) } in { lambda ; c:f32[1] d:f32[2]. let e:f32[2] = pjit[ name=g jaxpr={ lambda ; c:f32[1] d:f32[2]. let - pjit[name=f jaxpr=f] c - pjit[name=f jaxpr=f] c - g:f32[1] = mul c c - pjit[name=f jaxpr=f1] d - pjit[name=f jaxpr=f1] d - h:f32[2] = mul d d - e:f32[2] = add g h + g:f32[1] = pjit[name=f jaxpr=f] c + h:f32[1] = pjit[name=f jaxpr=f] c + i:f32[1] = mul g h + j:f32[2] = pjit[name=f jaxpr=f1] d + k:f32[2] = pjit[name=f jaxpr=f1] d + l:f32[2] = mul j k + e:f32[2] = add i l in (e,) } ] c d in (e,) } @@ -2477,6 +2483,20 @@ def test_pjit_committed_array_different_devices_variadic_args(self): r"\[1\].*"): pjit(lambda *x: x)(a, b) + def test_jit_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @partial(jax.jit, donate_argnums=(0,)) + def f(x): + return x, x * 2 + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + jaxpr = jax.make_jaxpr(f)(x) + y = core.jaxpr_as_fun(jaxpr)(x) + self.assertTrue(x.is_deleted()) + self.assertFalse(y[0].is_deleted()) + self.assertFalse(y[1].is_deleted()) + def test_pjit_pytree_inp_device_assignment_mismatch(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0]) From 6033592a9544f8c440df871aa6502a5ffeae6641 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 26 Mar 2025 16:35:30 -0700 Subject: [PATCH 0208/1769] Rename xla_extension_version to jaxlib_extension_version to reflect its new scope. 
PiperOrigin-RevId: 740944270 --- docs/jep/9419-jax-versioning.md | 6 +++--- jax/_src/callback.py | 4 ++-- jax/_src/lib/__init__.py | 2 +- jaxlib/xla/xla_client.py | 2 +- tests/python_callback_test.py | 10 +++++----- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index b964aa2af45d..85b95257ebae 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -167,16 +167,16 @@ We maintain an additional version number (`_version`) in [`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py). The idea is that this version number, is defined in `xla/python` together with the C++ parts of JAX, is also accessible to JAX Python as -`jax._src.lib.xla_extension_version`, and must +`jax._src.lib.jaxlib_extension_version`, and must be incremented every time that a change is made to the XLA/Python code that has backwards compatibility implications for `jax`. The JAX Python code can then use this version number to maintain backwards compatibility, e.g.: ``` -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version # 123 is the new version number for _version in xla_client.py -if xla_extension_version >= 123: +if jaxlib_extension_version >= 123: # Use new code path ... else: diff --git a/jax/_src/callback.py b/jax/_src/callback.py index dc60bfb94356..683da66638e6 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -33,7 +33,7 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -827,7 +827,7 @@ def _wrapped_callback(*args): return outputs, token, None # TODO(dsuo): Remove this once we bump minimum_jaxlib_version to "0.5.4". - if xla_extension_version <= 320: + if jaxlib_extension_version <= 320: result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) if token: diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index be551449aa17..fef5d2c26038 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -117,7 +117,7 @@ def _xla_gc_callback(*args): # Only for the internal usage of the JAX developers, we expose a version # number that can be used to perform changes without breaking the main # branch on the Jax github. -xla_extension_version: int = getattr(xla_client, '_version', 0) +jaxlib_extension_version: int = getattr(xla_client, '_version', 0) import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index eb7a3b5759e5..80cdeef47387 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -49,7 +49,7 @@ profiler = _xla.profiler # Just an internal arbitrary increasing number to help with backward-compatible -# changes. In JAX, reference this via jax._src.lib.xla_extension_version. +# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. _version = 322 # Version number for MLIR:Python components. 
diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 5650a2d4f48b..a8442b4a1356 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -28,7 +28,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -588,8 +588,8 @@ def fun(x): @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_operands(self, dtype: str): - if xla_extension_version <= 321: - self.skipTest("Requires xla_extension_version >= 322.") + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(x): return x def f(x): @@ -613,8 +613,8 @@ def f(x): @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): - if xla_extension_version <= 321: - self.skipTest("Requires xla_extension_version >= 322.") + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(): return np.arange(8, dtype=dtype) From f949b8b8f62c986849fb2a59d8cac61467dc6eff Mon Sep 17 00:00:00 2001 From: kaixih Date: Wed, 26 Mar 2025 20:57:30 +0000 Subject: [PATCH 0209/1769] Enable public doc for scaled dot --- docs/jax.nn.rst | 3 + jax/_src/nn/functions.py | 222 ++++++++++++++++++++++++++++----------- jax/nn/__init__.py | 1 + tests/nn_test.py | 23 +--- 4 files changed, 168 insertions(+), 81 deletions(-) diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index 2e2e9644d50d..339f07f4cdcc 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -54,3 +54,6 @@ Other functions standardize one_hot dot_product_attention + scaled_matmul + get_scaled_dot_general_config + scaled_dot_general diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index ee0643e116f9..cc4a345641dd 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1210,81 +1210,184 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], return jnp.reshape(out, output_shape) def scaled_matmul( - lhs: Array, - rhs: Array, - lhs_scales: Array, - rhs_scales: Array, + a: Array, + b: Array, + a_scales: Array, + b_scales: Array, preferred_element_type: DTypeLike = jnp.float32, ) -> Array: - r""" - Performs scaled matrix multiplication between two 3D arrays, with scaling - factors applied to the matrices. - .. math:: - \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs) + r"""Scaled matrix multiplication function. + + Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. + The last dim is the contracting dim, and block size is inferred. + + Mathematically, this operation is equivalent to:: + + a_block_size = a.shape[-1] // a_scales.shape[-1] + b_block_size = b.shape[-1] // b_scales.shape[-1] + a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) + b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1) + jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled) + Args: - lhs (Array): A 3D array of shape (B, M, K). - rhs (Array): A 3D array of shape (B, N, K). - lhs_scales (Array): A 3D array of shape (B, M, K_block). - rhs_scales (Array): A 3D array of shape (B, N, K_block). - preferred_element_type (DTypeLike, optional): The preferred data type - for the computation. Defaults to `jnp.float32`. 
+ a (Array): Shape (B, M, K). + b (Array): Shape (B, N, K). + a_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`. + b_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`. + preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`. + Returns: - Array: A 3D array of shape (B, M, N) representing the scaled matrix - multiplication result. - Raises: - AssertionError: If the number of columns in `lhs` (`lhs_K`) does not - match the number of columns in `rhs` (`rhs_K`). + Array of shape (B, M, N). + Notes: - - The function ensures that the `preferred_element_type` is - danonicalized before passing it to the underlying computation. - - Scaling is applied to the matrices based on the `lhs_scales` and - `rhs_scales` arrays, enabling efficient computations in blocks. + - We currently do not support user-defined `precision` for customizing the + compute data type. It is fixed to `jnp.float32`. + - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`. + - To use cuDNN with Nvidia Blackwell GPUs, inputs must match:: + + # mxfp8 + a, b: jnp.float8_e4m3fn | jnp.float8_e5m2 + a_scales, b_scales: jnp.float8_e8m0fnu + block_size: 32 + # nvfp4 + a, b: jnp.float4_e2m1fn + a_scales, b_scales: jnp.float8_e4m3fn + block_size: 16 + + Examples: + + Basic case: + + >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) + >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) + >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> scaled_matmul(a, b, a_scales, b_scales) + Array([[[8.]]], dtype=float32) + + Using fused cuDNN call on Blackwell GPUs: + + >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> scaled_matmul(a, b, a_scales, b_scales) """ - B, M, lhs_K = lhs.shape - _, N, rhs_K = rhs.shape - assert lhs_K == rhs_K - _, _, K_block = lhs_scales.shape + assert all(x.ndim == 3 for x in (a, b, a_scales, b_scales)) + B_a, M_a, K_a = a.shape + B_b, N_b, K_b = b.shape + assert K_a == K_b and B_a == B_b + B_as, M_as, K_as = a_scales.shape + B_bs, N_bs, K_bs = b_scales.shape + assert K_as == K_bs and B_as == B_bs + assert M_as == M_a and N_bs == N_b preferred_element_type = dtypes.canonicalize_dtype( np.dtype(preferred_element_type) ) out = cudnn_scaled_matmul( - lhs, - rhs, - lhs_scales, - rhs_scales, + a, + b, + a_scales, + b_scales, preferred_element_type=preferred_element_type, ) return out +def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'], + global_scale: Array | None = None): + r"""Get quantization configs for scaled_dot_general. + + Create quantization configs for the `jax.nn.scaled_dot_general`. + + See Also: + - :func:`jax.nn.scaled_dot_general`: Scaled dot general function. 
+ """ + + if mode == 'nvfp4': + one = jnp.ones((1,), dtype=jnp.float32) + return BlockScaleConfig( + mode='nvfp4', + block_size=16, + data_type=jnp.float4_e2m1fn, + scale_type=jnp.float8_e4m3fn, + global_scale=one if global_scale is None else global_scale, + infer_only=False + ) + elif mode == 'mxfp8': + return BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + def scaled_dot_general( lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32, configs: List[BlockScaleConfig] | None = None, - implementation: Literal['cudnn'] | None = None, ): r"""Scaled dot general operation. - Computes the scaled dot general on lhs, rhs with quanitzation specified by configs: - .. math:: - \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\ - \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\ - \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs}) + + Performs a generalized dot product with block-scaled quantization on the + lhs and rhs inputs. This operation extends `lax.dot_general` to support + user-defined scaling configurations. + + Essentially, the operation follows:: + + a, a_scales = quantize(lhs, configs[0]) + b, b_scales = quantize(rhs, configs[1]) + c = jax.nn.scaled_matmul(a, b, a_scales, b_scales) + Args: - lhs: Left-hand side input tensor. - rhs: Right-hand side input tensor. - dimension_numbers: A tuple specifying the contraction and batch dimensions - for the dot general operation. Must follow the format: - `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. - preferred_element_type: The preferred output data type. Supported types are - `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`. - configs: A list of `BlockScaleConfig` specifying the scaling - configurations for the operation. Defaults to `mxfp8`. - implementation: A string to control which implementation backend to use. - Supported strings are `cudnn` (cuDNN block scaled dot). It defaults - to `None`, which will automatically select the best available backend. + lhs (ArrayLike): Input array. + rhs (ArrayLike): Input array. + dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying + the contraction and batch dimensions: + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + preferred_element_type (DTypeLike, optional): Output data type of the dot + product. Defaults to `jnp.float32`. Other valid types include + `jnp.bfloat16` and `jnp.float16`. + configs (list of BlockScaleConfig, optional): Scaling configurations for + lhs, rhs, and gradients. Users can obtain valid configurations via + `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` + are supported. If `None`, falls back to `lax.dot_general`. + Returns: - The result of the scaled dot general operation. + Array: The resulting tensor, with batch dimensions first, followed by + non-contracting/non-batch dimensions of lhs, and then those of rhs. + + See Also: + - :func:`jax.nn.scaled_matmul`: Scaled matmul function. + - :func:`jax.lax.dot_general`: General dot product operator. + + Notes: + - Unlike `nn.scaled_matmul`, which assumes quantized low-precision + inputs with explicit scaling factors, this operator takes high-precision + inputs, applies quantization internally, and handles the backward pass. 
+ + Examples: + + Creating config for mxfp8: + + >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3 + + Creating config for nvfp4: + + >>> global_scale = jnp.array([0.5], jnp.float32) + >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3 + + Using scaled_dot_general with the configs: + + >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) + >>> lhs = random.normal(keys[0], (3, 128, 64)) + >>> rhs = random.normal(keys[1], (3, 128, 64)) + >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,)))) """ # Create configs if not provided if configs is None: @@ -1300,17 +1403,10 @@ def scaled_dot_general( ) configs = [mxfp8_config for _ in range(3)] - if implementation is None: - implementation = 'cudnn' - - match implementation: - case 'cudnn': - out = cudnn_scaled_dot_general( - lhs, rhs, dimension_numbers, - preferred_element_type=preferred_element_type, - configs=configs - ) - case _: - raise ValueError(f"Unsupported implementation option: {implementation}") + out = cudnn_scaled_dot_general( + lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type, + configs=configs + ) return out diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 10f11f829abe..651d9cf4e47f 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -38,6 +38,7 @@ identity as identity, relu6 as relu6, dot_product_attention as dot_product_attention, + get_scaled_dot_general_config as get_scaled_dot_general_config, scaled_dot_general as scaled_dot_general, scaled_matmul as scaled_matmul, selu as selu, diff --git a/tests/nn_test.py b/tests/nn_test.py index e46843186c02..385b216aeb57 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -31,7 +31,6 @@ from jax._src.cudnn.scaled_matmul_stablehlo import ( quantize, shape_normalization, - BlockScaleConfig, ) from jax.test_util import check_grads from jax import nn @@ -110,17 +109,7 @@ def create_mxfp8_configs_if_available(): if _dtypes.float8_e8m0fnu is None: raise unittest.SkipTest("float8_e8m0fnu is not available.") - def _create_mxfp8_config(): - return BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - - return [_create_mxfp8_config() for _ in range(3)] + return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)] @jtu.with_config(jax_legacy_prng_key="allow", @@ -130,10 +119,9 @@ class NNFunctionsTest(jtu.JaxTestCase): contract=[160, 96], lhs_non_contract=[240, 100], dtype=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) - def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + def testScaledMatmul(self, contract, lhs_non_contract, dtype): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") # Check if float8_e8m0fnu is available configs = create_mxfp8_configs_if_available() @@ -153,11 +141,10 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): @parameterized.product( is_training=[True, False], output_type=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) def testScaledDotGeneral( - self, is_training, output_type, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + self, is_training, output_type): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise 
unittest.SkipTest("CUDA or cuDNN versions are not compatible") configs = create_mxfp8_configs_if_available() From be1f649b510048e30b8c07bd7e1964987c6e2907 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 26 Mar 2025 17:30:22 -0700 Subject: [PATCH 0210/1769] Expose jax._src.lib.ifrt_version which tracks the version of third_party/tensorflow code inside jax. PiperOrigin-RevId: 740957982 --- jax/_src/lib/__init__.py | 1 + jaxlib/xla/BUILD | 1 + jaxlib/xla/xla.cc | 3 +++ jaxlib/xla/xla_client.py | 6 ++++++ jaxlib/xla/xla_client.pyi | 2 ++ 5 files changed, 13 insertions(+) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index fef5d2c26038..bb542aa5d61d 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -118,6 +118,7 @@ def _xla_gc_callback(*args): # number that can be used to perform changes without breaking the main # branch on the Jax github. jaxlib_extension_version: int = getattr(xla_client, '_version', 0) +ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 512eeb867618..347da6998b57 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -112,6 +112,7 @@ nanobind_extension( "@xla//xla/python:profiler", "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:types", + "@xla//xla/python:version", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 6e47be15fc68..668c96869479 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -64,6 +64,7 @@ limitations under the License. #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/topology.h" +#include "xla/python/version.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT @@ -960,6 +961,8 @@ NB_MODULE(xla_extension, m) { m.def("check_and_canonicalize_memory_kind", &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), nb::arg("device_list")); + + m.attr("ifrt_version_number") = JAX_IFRT_VERSION_NUMBER; } // NOLINT(readability/fn_size) } // namespace xla diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 80cdeef47387..30e8443276c8 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -52,6 +52,12 @@ # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. _version = 322 +# An internal increasing version number for protecting jaxlib code against +# ifrt changes. +# lives in xla/python/version.h. +# In JAX, reference this via jax._src.lib.ifrt_version. +_ifrt_version = _xla.ifrt_version_number + # Version number for MLIR:Python components. mlir_api_version = 58 diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index 5ac837ef1d85..b182eb65ba60 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -58,6 +58,8 @@ from jaxlib.xla_extension import XlaOp as XlaOp _version: int +_ifrt_version: int + mlir_api_version: int bfloat16: type[numpy.generic] From 8f25337a9fb7ef7a86452a2ca3a2ccfc6d1aee20 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 26 Mar 2025 18:32:39 -0700 Subject: [PATCH 0211/1769] [ragged-paged-attn] Combine k_pages and v_pages into kv_pages and zip on num_kv_heads. 
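Note for reviewers: the intended downstream usage mirrors the existing
jaxlib_extension_version guard pattern; the threshold below is illustrative,
not a real IFRT version number:

    from jax._src.lib import ifrt_version

    if ifrt_version >= 123:
      ...  # rely on newer xla/python/ifrt behavior
    else:
      ...  # fall back for older IFRT versions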
Now we should be able to support sharding num_kv_heads to 1 even dtype is bfloat16 while still having good ragged KV scatter because ragged dim still remains in non-tiling dim. PiperOrigin-RevId: 740971413 --- .../pallas/ops/tpu/ragged_paged_attention.py | 174 +++++++++--------- .../pallas/tpu_ragged_paged_attention_test.py | 24 +-- 2 files changed, 99 insertions(+), 99 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 60ac2e34f610..255670c22e90 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -34,8 +34,8 @@ class MultiPageAsyncCopyDescriptor: def __init__( self, - pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] - vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sem, page_indices_ref, # i32[max_num_seqs, pages_per_seq] offset, # [seq_idx, kv_pages_start] @@ -72,8 +72,7 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -82,9 +81,16 @@ def ref_ragged_paged_attention( sm_scale: float = 1.0, sliding_window: int | None = None, soft_cap: float | None = None, - mask_value: float = DEFAULT_MASK_VALUE, + mask_value: float | None = DEFAULT_MASK_VALUE, ): - _, _, num_kv_heads, head_dim = k_pages.shape + check_inputs_shapes( + queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + _, _, num_combined_kv_heads, head_dim = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_heads = queries.shape[1] assert num_q_heads % num_kv_heads == 0 num_query_per_kv = num_q_heads // num_kv_heads @@ -96,8 +102,12 @@ def ref_ragged_paged_attention( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] - v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) @@ -122,8 +132,7 @@ def ref_ragged_paged_attention( # Expect to run these checkes during runtime. 
def validate_inputs_on_runtime( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -131,16 +140,14 @@ def validate_inputs_on_runtime( sliding_window: int | None = None, soft_cap: float | None = None, ): - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs - ) + check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) max_num_batched_tokens = q.shape[0] - page_size = k_pages.shape[1] + page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape if num_seqs[0] > max_num_seqs: raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") max_kv_len = jnp.max(kv_lens) - min_pages_per_seq = ceil_div(max_kv_len, page_size) + min_pages_per_seq = cdiv(max_kv_len, page_size) if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" @@ -167,22 +174,19 @@ def validate_inputs_on_runtime( # Expect to run these checks during compile time. def check_inputs_shapes( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs, # i32[1] ): _, num_q_heads, head_dim = q.shape - _, _, num_kv_heads, head_dim_k = k_pages.shape + _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 max_num_seqs, _ = page_indices.shape if num_seqs.shape != (1,): raise ValueError(f"{num_seqs.shape=} must be (1,)") - if k_pages.shape != v_pages.shape: - raise ValueError( - f"{k_pages.shape=} and {v_pages.shape=} must have the same shape." - ) if head_dim_k != head_dim: raise ValueError( f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." 
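As an aside, the combined layout interleaves K and V along the head axis: head h's keys sit at index 2*h and its values at 2*h + 1, which is what the 0::2 / 1::2 slices above recover. A minimal sketch of packing and unpacking under that assumption (combine_kv and split_kv are illustrative helpers, not part of this patch):

    import jax.numpy as jnp

    def combine_kv(k, v):
      # k, v: [total_num_pages, page_size, num_kv_heads, head_dim].
      # Stack on a new axis after the head axis, then fold it in, giving
      # [total_num_pages, page_size, 2 * num_kv_heads, head_dim] in the
      # order k0, v0, k1, v1, ...
      p, s, h, d = k.shape
      return jnp.stack([k, v], axis=3).reshape(p, s, 2 * h, d)

    def split_kv(kv):
      # Inverse: even head slots hold K, odd slots hold V.
      return kv[:, :, 0::2, :], kv[:, :, 1::2, :]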
@@ -221,13 +225,11 @@ def ragged_paged_attention_kernel( num_seqs_ref, # Input q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] # Output o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] # Scratch - k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] - v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sems, # [2, 2] l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] @@ -235,11 +237,16 @@ def ragged_paged_attention_kernel( sm_scale: float, sliding_window: int | None = None, soft_cap: float | None = None, - mask_value: float = DEFAULT_MASK_VALUE, + mask_value: float | None = DEFAULT_MASK_VALUE, ): + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] - _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( + kv_bufs.shape + ) + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk heads_blk_idx, q_blk_idx = ( @@ -256,22 +263,17 @@ def create_kv_async_copy_descriptors( heads_blk_idx, seq_idx, kv_blk_idx, buf_idx ): offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) - heads_start = heads_blk_idx * num_kv_heads_per_blk - async_copy_k = MultiPageAsyncCopyDescriptor( - k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - k_bufs.at[buf_idx], - sems.at[buf_idx, 0], - page_indices_ref, - offset, - ) - async_copy_v = MultiPageAsyncCopyDescriptor( - v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - v_bufs.at[buf_idx], - sems.at[buf_idx, 1], + heads_start = heads_blk_idx * num_combined_kv_heads_per_blk + async_copy_kv = MultiPageAsyncCopyDescriptor( + kv_pages_hbm_ref.at[ + :, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), : + ], + kv_bufs.at[buf_idx], + sems.at[buf_idx], page_indices_ref, offset, ) - return async_copy_k, async_copy_v + return async_copy_kv # TODO(jevinjiang): Add these to Mosaic: # 1. Support arbitrary strided load/store for any dtype. @@ -303,11 +305,10 @@ def fold_on_2nd_minor(vec): @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): - async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, init_seq_idx, 0, init_buf_idx ) - async_copy_k.start() - async_copy_v.start() + async_copy_kv.start() def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states @@ -512,21 +513,18 @@ def prefetch_next_kv_blk(): # TODO(jevinjiang): reuse the same buffer if it is already prefetched! # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and # DMA to fixed size buffer! 
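# Aside (not part of the patch): the surrounding code is the classic
# double-buffering pattern. While the kernel computes on the buffer at
# cur_buf_idx, the DMA engine fills the buffer at next_buf_idx, and the
# roles swap on every KV block. A minimal sketch of the idea, with
# hypothetical start_copy / wait_copy / compute helpers:
#
#   start_copy(bufs[0], kv_block(0))
#   for i in range(num_blocks):
#     cur, nxt = i % 2, (i + 1) % 2
#     if i + 1 < num_blocks:
#       start_copy(bufs[nxt], kv_block(i + 1))  # overlap DMA with compute
#     wait_copy(bufs[cur])
#     compute(bufs[cur])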
- next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_async_copy_kv = create_kv_async_copy_descriptors( next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx ) - next_async_copy_k.start() - next_async_copy_v.start() + next_async_copy_kv.start() - cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + cur_async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx ) - kv_to_load_shape = ( - num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + kv_ref = cur_async_copy_kv.wait().reshape( + num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, head_dim, ) - k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) - v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) for kv_head_idx in range(num_kv_heads_per_blk): q_head_idx = kv_head_idx * num_q_heads_per_kv_head # TODO(jevinjiang): extra handlig for packed type that can start at @@ -534,8 +532,12 @@ def prefetch_next_kv_blk(): q = fold_on_2nd_minor( q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] ) - k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) - v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) + k = strided_load_kv( + kv_ref, kv_head_idx * 2, num_combined_kv_heads_per_blk + ) + v = strided_load_kv( + kv_ref, kv_head_idx * 2 + 1, num_combined_kv_heads_per_blk + ) flash_attention( q, k, @@ -566,7 +568,7 @@ def prefetch_next_kv_blk(): seq_buf_idx_ref[1] = buf_idx -def ceil_div(a, b): +def cdiv(a, b): assert b != 0 return (a + b - 1) // b @@ -583,7 +585,9 @@ def get_dtype_packing(dtype): raise ValueError(f"Not implemented: unsupported {dtype=}") -def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): +def get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype +): q_packing = get_dtype_packing(q_dtype) kv_packing = get_dtype_packing(kv_dtype) @@ -594,22 +598,26 @@ def can_be_xla_fully_tiled(x, packing): return x in (1, 2, 4, 8) or x % 8 == 0 # TODO(jevinjiang): support unaligned number of heads! - if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing): raise ValueError( - f"Not implemented: {num_kv_heads=} can not be XLA fully tiled." + f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled." ) + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 assert num_q_heads % num_kv_heads == 0 ratio = num_q_heads // num_kv_heads # TODO(jevinjiang): we can choose smaller tiling for packed type if large # second minor tiling is not on. - max_kv_tiling = 8 * kv_packing - min_kv_heads = ( - max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads + max_combined_kv_tiling = 8 * kv_packing + min_combined_kv_heads = ( + max_combined_kv_tiling + if num_combined_kv_heads % max_combined_kv_tiling == 0 + else num_combined_kv_heads ) - min_q_heads = min_kv_heads * ratio + min_q_heads = min_combined_kv_heads // 2 * ratio if can_be_xla_fully_tiled(min_q_heads, q_packing): - return min_q_heads, min_kv_heads - return num_q_heads, num_kv_heads + return min_q_heads, min_combined_kv_heads + return num_q_heads, num_combined_kv_heads @functools.partial( @@ -627,8 +635,7 @@ def can_be_xla_fully_tiled(x, packing): def ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] # TODO(jevinjiang): create a write_to_kv_cache kernel! 
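# Aside (illustrative, not part of the patch): cdiv above is just
# ceil_div renamed; both compute ceiling division, e.g.
#
#   cdiv(a, b) == (a + b - 1) // b   # cdiv(10, 4) == 3, cdiv(8, 4) == 2
#
# and get_dtype_packing reports how many elements share a 32-bit lane
# (assuming the usual 32 // bitwidth rule: float32 -> 1, bfloat16 -> 2,
# 8-bit types -> 4), so the maximum second-minor tiling above is
# 8 * packing rows.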
- k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -637,7 +644,7 @@ def ragged_paged_attention( sm_scale: float = 1.0, sliding_window: int | None = None, soft_cap: float | None = None, - mask_value: float = DEFAULT_MASK_VALUE, + mask_value: float | None = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, num_queries_per_block: int = 128, vmem_limit_bytes: int | None = None, @@ -646,8 +653,7 @@ def ragged_paged_attention( Args: q: concatenated all sequences' queries. - k_pages: paged K cache. Normally in HBM. - v_pages: paged V cache. Normally in HBM. + kv_pages: paged K cache. Normally in HBM. kv_lens: padded kv lengths. Only the first num_seqs values are valid. page_indices: the first index indicates which page to use in the kv cache for each sequence. Only the first num_seqs values are valid. @@ -666,18 +672,22 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs - ) + check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE _, num_q_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_pages.shape + _, page_size, num_combined_kv_heads, _ = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk) - num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_kv_heads, q.dtype, k_pages.dtype + num_q_blks = cdiv(cu_q_lens[num_seqs[0]], num_q_per_blk) + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype ) + assert num_combined_kv_heads_per_blk % 2 == 0 + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 num_heads_blks = num_q_heads // num_q_heads_per_blk grid = (num_heads_blks, num_q_blks) @@ -692,7 +702,6 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): in_specs = [ q_block_spec, pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), ] out_specs = q_block_spec lm_scratch = pltpu.VMEM( @@ -706,15 +715,14 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): 2, # For double buffering during DMA copies. num_kv_pages_per_blk, page_size, - num_kv_heads_per_blk, + num_combined_kv_heads_per_blk, head_dim, ), - k_pages.dtype, + kv_pages.dtype, ) scratch_shapes = [ - double_buf_scratch, # k_bufs - double_buf_scratch, # v_bufs - pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] + double_buf_scratch, # kv_bufs + pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. lm_scratch, # l_ref lm_scratch, # m_ref ] @@ -753,4 +761,4 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): # TODO(jevinjiang): Use f32 acc scratch for output! So we only need # to transfer output with desired dtype back to HBM. 
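# Aside (not part of the patch): after this change the kernel keeps a
# single double-buffered scratch of shape
#   [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
# instead of separate K and V buffers, and one DMA semaphore per buffer
# ((2,) rather than (2, 2)), since each prefetch now moves K and V in a
# single copy.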
- return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) + return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index 815c9dc6327f..b76d30bd1dcf 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -74,32 +74,26 @@ def _test_ragged_paged_attention( cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) prng_key = jax.random.key(1234) - k0, k1, k2, k3 = jax.random.split(prng_key, 4) + k0, k1, k2 = jax.random.split(prng_key, 3) q = jax.random.normal( k0, (max_num_batched_tokens, num_q_heads, head_dim), dtype=dtype, ) - k_pages = jax.random.normal( + kv_pages = jax.random.normal( k1, - (num_pages, page_size, num_kv_heads, head_dim), - dtype=dtype, - ) - v_pages = jax.random.normal( - k2, - (num_pages, page_size, num_kv_heads, head_dim), + (num_pages, page_size, num_kv_heads * 2, head_dim), dtype=dtype, ) page_indices = jax.random.randint( - k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + k2, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 ) num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) validate_inputs_on_runtime( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, @@ -111,8 +105,7 @@ def _test_ragged_paged_attention( actual_num_q_tokens = cu_q_lens[num_seqs[0]] output = ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, @@ -126,8 +119,7 @@ def _test_ragged_paged_attention( expected = ref_ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, @@ -272,7 +264,7 @@ def test_ragged_paged_attention_mixed(self, dtype): @parameterized.product( num_seqs=[1, 5, 16], # TODO(jevinjiang): Support more num_heads! 
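# Aside (not part of the patch): the test now draws a single array of
# shape (num_pages, page_size, num_kv_heads * 2, head_dim) instead of
# separate k_pages / v_pages, so only three PRNG keys are needed. The
# (8, 1) entry added below exercises num_kv_heads == 1, the case the
# commit message calls out; with the combined layout the kernel's KV
# head axis then has size 2 rather than 1.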
- num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)], + num_heads=[(32, 8), (32, 16), (12, 2), (4, 4), (8, 1)], dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], From c7d04cc75a3aac39a677d318e4b82204a2f096b2 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 27 Mar 2025 05:09:25 +0000 Subject: [PATCH 0212/1769] Improve based on review 2 --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index b1d353e7bcd1..60cdbee7fa20 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -494,12 +494,9 @@ def quantize(x, config): assert config.global_scale.dtype == jnp.float32 SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) - prev_amax = config.global_scale * (MAX * SCALE_MAX) - scales_q = jnp.clip( - (amax / prev_amax) * SCALE_MAX, 0, SCALE_MAX - ) - scaled_x = x / scales_q + scales_q = jnp.clip(scales / config.global_scale, 0, SCALE_MAX) scales_q = scales_q.astype(config.scale_type) + scaled_x = x / scales_q.astype(jnp.float32) else: raise ValueError(f"Unrecognized mode: {config.mode}.") @@ -644,6 +641,9 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + # We apply a Straight-Through Estimator (STE) with zero-out behavior: if + # inputs are clipped during quantization in fprop, their corresponding gradients + # are zeroed out; otherwise, they pass through unchanged. if configs[2].mode == "nvfp4": assert rhs.dtype == lhs.dtype MAX = jnp.finfo(configs[0].data_type).max.astype(lhs.dtype) From 0c1f4c155ec49ccd5cde85d3dacd8e0b7c7afb47 Mon Sep 17 00:00:00 2001 From: Mudit Gokhale Date: Wed, 26 Mar 2025 23:12:32 -0700 Subject: [PATCH 0213/1769] Remove backward compatibility logic for tool naming. PiperOrigin-RevId: 741030788 --- jax/collect_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/collect_profile.py b/jax/collect_profile.py index d1309e0c5bca..b355816772a1 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -91,7 +91,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str, in root_trace_folder.iterdir()] latest_folder = max(trace_folders, key=os.path.getmtime) xplane = next(latest_folder.glob("*.xplane.pb")) - result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer^", {}) + result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer", {}) with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp: fp.write(result.encode("utf-8")) From e1762b0af6c5199d53141557bdf81eaef55bd4c5 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 27 Mar 2025 00:46:21 -0700 Subject: [PATCH 0214/1769] Assert unused variable in lax.all_to_all batching rule P.S. 
minor improvement to code readability PiperOrigin-RevId: 741051082 --- jax/_src/lax/parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 221fe2a9e87a..28e6dbef4a2c 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1109,15 +1109,15 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): - axis_size, frame_name = axis_data.size, axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") + axis_size, frame_name = axis_data.size, axis_data.name if isinstance(axis_name, (list, tuple)): axes_names = axis_name else: axes_names = [axis_name] - if axis_data.name not in axes_names: + if frame_name not in axes_names: return _all_to_all_batcher( vals_in, dims_in, axis_name=axis_name, split_axis=split_axis, concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled) @@ -1157,6 +1157,7 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_index_groups=axis_index_groups, tiled=tiled) # Split out the local part into axis new_d (NOTE: d is already in axis 1) + assert d == 1 x = _splitaxis(split_axis, axis_size, x) new_d = split_axis concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis From 8bd956d96a6979bcead917d7d0f8593203888cfe Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 03:02:36 -0700 Subject: [PATCH 0215/1769] [Pallas] Skip reads/writes from/to slices of kernel input/output buffers when the slices do not change between iterations of the grid loop that interprets kernels on CPU. 
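The guard is the standard "recompute only on change" trick: each grid
iteration compares a block's current start indices with the previous
iteration's, and the read (or write) callback runs under a jax.lax.cond
only when they differ. A minimal sketch of the pattern, assuming a
no-argument copy_block callback (illustrative names, not the actual
helpers):

    import jax

    def maybe_copy(iteration_idx, cur_start, prev_start, copy_block):
      # Copy on the first iteration, or whenever any start index moved.
      changed = jax.lax.reduce_or(cur_start != prev_start, axes=(0,))
      jax.lax.cond((iteration_idx == 0) | changed, copy_block, lambda: None)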
PiperOrigin-RevId: 741082349 --- jax/_src/pallas/mosaic/interpret.py | 192 ++++++++++++++++------ tests/pallas/tpu_pallas_interpret_test.py | 127 ++++++++++---- 2 files changed, 237 insertions(+), 82 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 5acbabc673aa..9d7b03ad5589 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1307,8 +1307,13 @@ def _compute_start_indices( jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, compiler_params=compiler_params, interpret_params=interpret_params) if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = tuple(i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices)) + ret = jnp.array( + tuple( + i if b is pallas_core.mapped else b * i + for b, i in zip(block_mapping.block_shape, block_indices) + ), + dtype=jnp.int32, + ) elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): ret = block_indices else: @@ -1534,64 +1539,114 @@ def interpret_pallas_call( # Base case is always one iteration when grid is () num_iterations = 1 - def body(carry): - # The loop carry: (i, loop_idx) -- - # - i:int32 is the interation index - # - loop_idx: tuple[int32] are the program ids for each grid axis - i, loop_idx = carry - + def _get_local_grid_env(loop_idx): if grid_mapping.local_grid_env is not None: - local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + return grid_mapping.local_grid_env(loop_idx, grid) else: - local_grid_env = tuple( + return tuple( pallas_core.GridAxis(idx, b) for dim, (idx, b) in enumerate(zip(loop_idx, grid)) if dim not in grid_mapping.vmapped_dims ) - with pallas_core.grid_env(local_grid_env): - start_indices = [ + def body( + carry: tuple[ + jnp.int32, tuple[jnp.int32, ...], list[jnp.ndarray], list[jnp.ndarray] + ], + ): + """Performs a single iteration of `jaxpr` in the device grid. + + Execution of `jaxpr` is preceded by reading kernel input buffers and + followed by writing kernel output buffers. + + Args: + carry: (iteration_idx, loop_idx, prev_start_indices, cur_start_indices). + - iteration_idx is the interation index. + - loop_idx are the program ids for each grid axis. + - prev_start_indices is a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the previous loop + iteration. + - cur_start_indices is a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the current loop + iteration. + + Note that by carrying the previous *and* current start indices between + loop iterations, it suffices to compute only one list of start indices, + i.e. `next_start_indices` (see below), per iteration. + + Returns: + The carry for the next iteration. + """ + iteration_idx, loop_idx, prev_start_indices, cur_start_indices = carry + + with pallas_core.grid_env(_get_local_grid_env(loop_idx)): + next_loop_idx = _get_next_indices(grid, loop_idx) + next_start_indices = [ _compute_start_indices( - bm, loop_idx, *scalar_buffer_ids, compiler_params=compiler_params, - interpret_params=interpret_params) - for bm in grid_mapping.block_mappings] + bm, + next_loop_idx, + *scalar_buffer_ids, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] + # Copy slices of the input to the kernel buffers. - # - # TODO(jburnim): Only copy slices when the index mapping has changed? 
- for j, var in enumerate(input_vars): - if _is_any(var.aval.memory_space): - continue + + def _store_slice_to_kernel_input(index, input_var): # Copy from the HBM buffer for the pallas_call input to the kernel # input buffer. # TODO(jburnim): Just use input_args[j] when the input is not aliased? transform = indexing.NDIndexer( - indices=tuple(indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip(start_indices[j], - block_shapes[j], - is_indexing_dim[j])), - shape=input_args[j].shape, - int_indexer_shape=()) + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[index], + block_shapes[index], + is_indexing_dim[index], + ) + ), + shape=input_args[index].shape, + int_indexer_shape=(), + ) sliced_val = callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # read is involved in a data race. get, - jax.ShapeDtypeStruct(var.aval.shape, var.aval.dtype), + jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - input_buffer_ids[j], + input_buffer_ids[index], (transform,), - ordered=True) + ordered=True, + ) callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # store is involved in a data race. store, (), device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - input_ids[j], + TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], + input_ids[index], (), sliced_val, - ordered=True) + ordered=True, + ) + + for j, var in enumerate(input_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[j].shape) == 1 + assert len(prev_start_indices[j].shape) == 1 + jax.lax.cond( + (iteration_idx == 0) + | jax.lax.reduce_or( + cur_start_indices[j] != prev_start_indices[j], axes=(0,) + ), + functools.partial(_store_slice_to_kernel_input, j, var), + lambda: None, + ) # Invoke the kernel. _interpret_jaxpr(jaxpr, *kernel_buffer_ids, @@ -1599,29 +1654,30 @@ def body(carry): interpret_params=interpret_params) # Copy from the kernel buffers to slices of the output in HBM. - # - # TODO(jburnim): Only copy if the index mapping will change in the - # next iteration (or if this is the last iteration)? - for j, var in enumerate(output_vars): - if _is_any(var.aval.memory_space): - continue + def _store_to_output_buffer(index, output_var): kernel_output_val = callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # get is involved in a data race. get, - var.aval, + output_var.aval, device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], kernel_output_ids[j], (), - ordered=True) + ordered=True, + ) transform = indexing.NDIndexer( - indices=tuple(indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip(start_indices[num_inputs + j], - block_shapes[num_inputs + j], - is_indexing_dim[num_inputs + j])), - shape=output_vals[j].shape, - int_indexer_shape=()) + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[num_inputs + index], + block_shapes[num_inputs + index], + is_indexing_dim[num_inputs + index], + ) + ), + shape=output_vals[index].shape, + int_indexer_shape=(index), + ) callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # store is involved in a data race. 
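For orientation: the widened loop carry means each iteration computes
only the next iteration's start indices, reusing the pair it already
carries. A minimal sketch of that carry discipline (compute_starts and
run_iteration are hypothetical stand-ins, not the actual helpers):

    def body(carry):
      i, loop_idx, prev_starts, cur_starts = carry
      next_loop_idx = get_next_indices(grid, loop_idx)
      next_starts = compute_starts(next_loop_idx)  # one computation per step
      run_iteration(i, prev_starts, cur_starts, next_starts)
      return i + 1, next_loop_idx, cur_starts, next_starts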
@@ -1629,18 +1685,52 @@ def body(carry): (), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - output_buffer_ids[j], + output_buffer_ids[index], (transform,), kernel_output_val, - ordered=True) + ordered=True, + ) - return i + 1, _get_next_indices(grid, loop_idx) + for j, var in enumerate(output_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[num_inputs + j].shape) == 1 + assert len(next_start_indices[num_inputs + j].shape) == 1 + jax.lax.cond( + (iteration_idx + 1 == num_iterations) + | jax.lax.reduce_or( + cur_start_indices[num_inputs + j] + != next_start_indices[num_inputs + j], + axes=(0,), + ), + functools.partial(_store_to_output_buffer, j, var), + lambda: None, + ) + return iteration_idx + 1, next_loop_idx, cur_start_indices, next_start_indices + + initial_loop_idx = (jnp.int32(0),) * len(grid) + with pallas_core.grid_env(_get_local_grid_env(initial_loop_idx)): + initial_start_indices = [ + _compute_start_indices( + bm, + initial_loop_idx, + *scalar_buffer_ids, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] # TODO(jburnim): Handle parallel grid dimensions + megacore. _ = lax.while_loop( lambda carry: carry[0] < num_iterations, body, - (jnp.int32(0), (jnp.int32(0),) * len(grid)) + ( + jnp.int32(0), + initial_loop_idx, + initial_start_indices, # Previous start indices are ignored on the first iteration. + initial_start_indices, + ), ) # Read the output from the allocated output buffers. diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 5b729f0fe07e..afb573f8cf44 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -18,24 +18,48 @@ contains only tests that do not use shard_map. 
""" -from absl.testing import absltest -from absl.testing import parameterized import functools +from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp - import numpy as np jax.config.parse_flags_with_absl() +class CountStoreCallbacksContext(object): + """Wraps the I/O callback `store` into a callback that counts the number of calls to `store`.""" + + def __init__(self): + self._num_stores = 0 + self._saved = mosaic_interpret.store + + def __enter__(self): + def _store_callback(self, *args, **kwargs): + self._num_stores += 1 + return self._saved(*args, **kwargs) + + mosaic_interpret.store = functools.partial(_store_callback, self) + return self + + def __exit__(self, ty, value, traceback): + del ty, value, traceback + mosaic_interpret.store = self._saved + + @property + def num_stores(self): + return self._num_stores + + class InterpretTest(jtu.JaxTestCase): + def setUp(self): super().setUp() self.num_devices = jax.device_count() @@ -50,17 +74,18 @@ def matmul_kernel(x_ref, y_ref, z_ref): @jax.jit def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - grid=(2, 2), - in_specs=[ - pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), - pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) - ], - out_specs=pl.BlockSpec( - (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), - ), - interpret=mosaic_interpret.TPUInterpretParams(), + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)), + ], + out_specs=pl.BlockSpec( + (x.shape[0] // 2, y.shape[1] // 2), + lambda i, j: (i, j), + ), + interpret=mosaic_interpret.TPUInterpretParams(), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) @@ -79,9 +104,11 @@ def block_dynamic_slice(x, starts, sizes): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, grid=(1, 1), - in_specs=[pl.BlockSpec( - sizes, - lambda i, j, block_idx: (block_idx[0], block_idx[1]))], + in_specs=[ + pl.BlockSpec( + sizes, lambda i, j, block_idx: (block_idx[0], block_idx[1]) + ) + ], out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), ) @@ -96,17 +123,21 @@ def block_dynamic_slice(x, starts, sizes): shape = (512, 512) x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) - result = block_dynamic_slice(x, starts=jnp.array([128, 256]), sizes=(128, 128)) - ref = jax.lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128)) + result = block_dynamic_slice( + x, starts=jnp.array([128, 256]), sizes=(128, 128) + ) + ref = jax.lax.dynamic_slice( + x, start_indices=(128, 256), slice_sizes=(128, 128) + ) diff = jnp.max(jnp.abs(result - ref)) np.testing.assert_allclose(result, ref) def test_dynamic_grid_and_aliasing(self): - self.skipTest('Broken pending fix to extra reads/writes of inputs/outputs') def kernel(s_ref, x_ref, o_ref): o_ref[...] = x_ref[...] 
+ s_ref[0].astype(x_ref.dtype) iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) + @jax.jit def f(s, x): return pl.pallas_call( @@ -119,11 +150,11 @@ def f(s, x): ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), input_output_aliases={1: 0}, - interpret=mosaic_interpret.TPUInterpretParams() + interpret=mosaic_interpret.TPUInterpretParams(), )(s, x) s = jnp.array([1], dtype=jnp.int32) - x = jnp.arange(32 * 128.).reshape((32, 128)) + x = jnp.arange(32 * 128.0).reshape((32, 128)) y = f(s, x) # NOTE: No matter how many times the kernel body is run, the kernel input # buffer will only be written once by the pallas_call machinery, just @@ -136,6 +167,7 @@ def kernel(x_ref, o_ref, s_ref): @pl.when((pl.program_id(0) == 0) & (pl.program_id(1) == 0)) def _(): s_ref[0] = jnp.int32(0) + s = s_ref[0] s_ref[0] = s + 1 o_ref[:] = x_ref[:] + s.astype(x_ref.dtype) @@ -149,7 +181,8 @@ def _(): pl.BlockSpec(block_shape=(8, 128), index_map=lambda i, j: (i, j)), ], out_specs=pl.BlockSpec( - block_shape=(8, 128), index_map=lambda i, j: (j, i)), + block_shape=(8, 128), index_map=lambda i, j: (j, i) + ), scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), input_output_aliases={0: 0}, interpret=mosaic_interpret.TPUInterpretParams(), @@ -184,7 +217,8 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): copy.wait() x = jnp.zeros((8, 128), jnp.float32) - y = pl.pallas_call(kernel_without_race, + y = pl.pallas_call( + kernel_without_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], scratch_shapes=[ @@ -192,12 +226,14 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.SemaphoreType.DMA, ], interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertFalse(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, x + 1.0) - pl.pallas_call(kernel_with_race, + pl.pallas_call( + kernel_with_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], scratch_shapes=[ @@ -205,7 +241,8 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.SemaphoreType.DMA, ], interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) @@ -228,8 +265,8 @@ def matmul(x: jax.Array, y: jax.Array): z = jax.jit(matmul)(x, y) np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) - lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") - self.assertNotIn("dot_general", lowered) + lowered = jax.jit(matmul).lower(x, y).as_text(dialect='stablehlo') + self.assertNotIn('dot_general', lowered) @parameterized.parameters('nan', 'zero') def test_uninitialized_memory(self, uninitialized_memory): @@ -250,7 +287,8 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): pltpu.VMEM((8, 128), jnp.int16), ], interpret=mosaic_interpret.TPUInterpretParams( - uninitialized_memory=uninitialized_memory), + uninitialized_memory=uninitialized_memory + ), )() if uninitialized_memory == 'nan': self.assertTrue(jnp.isnan(x).all()) @@ -261,6 +299,33 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): np.testing.assert_equal(np.array(y), 0) np.testing.assert_equal(np.array(z), 0) + def test_correct_number_of_stores(self): + def kernel(x_ref, s_ref, o_ref): + s = 
s_ref[0] + x_ref[:] += jax.lax.full_like(x_ref, s) + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + + def kernel_call(x, s): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.float32), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + interpret=mosaic_interpret.TPUInterpretParams(), + )(x, s) + + with CountStoreCallbacksContext() as store_callbacks_counter: + result = jax.jit(kernel_call)( + jnp.zeros((16, 256), jnp.float32), jnp.zeros((1,), jnp.int32) + ) + np.testing.assert_allclose(result[::8, ::256], [[1.0], [5.0]]) + self.assertEqual(store_callbacks_counter.num_stores, 5) + -if __name__ == "__main__": +if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 8689550376089e78f22f87305422ffc6aaf5ddb8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 05:17:17 -0700 Subject: [PATCH 0216/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/95abd7942747bd5d1884b309baecdf5a93ff928a. PiperOrigin-RevId: 741114363 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 359048ffacbb..625f33a072f5 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d0b25f9cd8222a348c9728f88e909c4e2c30991b" -XLA_SHA256 = "8cd70a67a56a8b18087fc4849908f52c95c6413eb7edc9f800fdff6304804fa4" +XLA_COMMIT = "95abd7942747bd5d1884b309baecdf5a93ff928a" +XLA_SHA256 = "f8472323ffe621ade5317091fdf9acd66aaf67660fedd3143a96d9a347e88bac" def repo(): tf_http_archive( From 875e4795c444071604afe441c0d0fe965ccb0d50 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 27 Mar 2025 07:02:22 -0700 Subject: [PATCH 0217/1769] Update `test_util.get_tpu_version()` PiperOrigin-RevId: 741139032 --- jax/_src/test_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c3c4a934dd0e..1cd9546a1655 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -435,10 +435,10 @@ def get_tpu_version() -> int: if device_under_test() != "tpu": raise ValueError("Device is not TPU") kind = jax.devices()[0].device_kind - if kind.endswith(' lite'): - kind = kind[:-len(' lite')] - assert kind[:-1] == "TPU v", kind - return int(kind[-1]) + match = re.match(r"TPU[^\d]*(\d+)", kind) + if match is None: + raise ValueError(f"Device kind {kind} is not supported") + return int(match.group(1)) def is_device_tpu_at_least(version: int) -> bool: if device_under_test() != "tpu": From 9932ff1f79e3488a6660b44c9390bf81dc6389f5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 07:27:43 -0700 Subject: [PATCH 0218/1769] Deprecate the contents of jax.lib.xla_extension. PiperOrigin-RevId: 741145943 --- CHANGELOG.md | 2 + jax/lib/xla_extension.py | 132 +++++++++++++++++++++++++++++++-------- 2 files changed, 108 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93bbe81b5e63..c8805599364d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,9 +22,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. instead. 
* Implemented host callback handlers for CPU and GPU devices using XLA's FFI and removed existing CPU/GPU handlers using XLA's custom call. + * All APIs in `jax.lib.xla_extension` are now deprecated. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. + * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 52fe94e231d1..8f1b27070e98 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -14,42 +14,122 @@ from jax._src.lib import xla_extension as _xe -get_distributed_runtime_client = _xe.get_distributed_runtime_client -get_distributed_runtime_service = _xe.get_distributed_runtime_service -hlo_module_cost_analysis = _xe.hlo_module_cost_analysis -hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph -ifrt_proxy = _xe.ifrt_proxy -jax_jit = _xe.jax_jit -mlir = _xe.mlir -pmap_lib = _xe.pmap_lib -profiler = _xe.profiler -pytree = _xe.pytree -Device = _xe.Device -DistributedRuntimeClient = _xe.DistributedRuntimeClient -HloModule = _xe.HloModule -HloPrintOptions = _xe.HloPrintOptions -OpSharding = _xe.OpSharding -PjitFunctionCache = _xe.PjitFunctionCache -PjitFunction = _xe.PjitFunction -PmapFunction = _xe.PmapFunction - _deprecations = { - # Added Nov 20 2024 "ArrayImpl": ( - "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", - _xe.ArrayImpl, + ( + "jax.lib.xla_extension.ArrayImpl has been removed; use jax.Array" + " instead." + ), + None, ), "XlaRuntimeError": ( - "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", - _xe.XlaRuntimeError, + ( + "jax.lib.xla_extension.XlaRuntimeError has been removed; use" + " jax.errors.JaxRuntimeError instead." + ), + None, + ), + # Deprecated March 26 2025. + "DistributedRuntimeClient": ( + ( + "jax.lib.xla_extension.DistributedRuntimeClient is" + " deprecated; use jax.distributed instead." + ), + _xe.DistributedRuntimeClient, + ), + "get_distributed_runtime_client": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_client is" + " deprecated; use jax.distributed instead." + ), + _xe.get_distributed_runtime_client, + ), + "get_distributed_runtime_service": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_service is" + " deprecated; use jax.distributed instead." 
+ ), + _xe.get_distributed_runtime_service, + ), + "Device": ( + "jax.lib.xla_extension.Device is deprecated; use jax.Device instead.", + _xe.Device, + ), + "PjitFunctionCache": ( + "jax.lib.xla_extension.PjitFunctionCache is deprecated.", + _xe.PjitFunctionCache, + ), + "ifrt_proxy": ( + "jax.lib.xla_extension.ifrt_proxy is deprecated.", + _xe.ifrt_proxy, + ), + "jax_jit": ( + "jax.lib.xla_extension.jax_jit is deprecated.", + _xe.jax_jit, + ), + "mlir": ("jax.lib.xla_extension.mlir is deprecated.", _xe.mlir), + "pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _xe.pmap_lib), + "profiler": ( + "jax.lib.xla_extension.profiler is deprecated.", + _xe.profiler, + ), + "pytree": ( + "jax.lib.xla_extension.pytree is deprecated.", + _xe.pytree, + ), + "hlo_module_cost_analysis": ( + "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.", + _xe.hlo_module_cost_analysis, + ), + "hlo_module_to_dot_graph": ( + "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.", + _xe.hlo_module_to_dot_graph, + ), + "HloModule": ( + "jax.lib.xla_extension.HloModule is deprecated.", + _xe.HloModule, + ), + "HloPrintOptions": ( + "jax.lib.xla_extension.HloPrintOptions is deprecated.", + _xe.HloPrintOptions, + ), + "OpSharding": ( + "jax.lib.xla_extension.OpSharding is deprecated.", + _xe.OpSharding, + ), + "PjitFunction": ( + "jax.lib.xla_extension.PjitFunction is deprecated.", + _xe.PjitFunction, + ), + "PmapFunction": ( + "jax.lib.xla_extension.PmapFunction is deprecated.", + _xe.PmapFunction, ), } import typing as _typing if _typing.TYPE_CHECKING: - ArrayImpl = _xe.ArrayImpl - XlaRuntimeError = _xe.XlaRuntimeError + Device = _xe.Device + DistributedRuntimeClient = _xe.DistributedRuntimeClient + HloModule = _xe.HloModule + HloPrintOptions = _xe.HloPrintOptions + OpSharding = _xe.OpSharding + PjitFunction = _xe.PjitFunction + PjitFunctionCache = _xe.PjitFunctionCache + PmapFunction = _xe.PmapFunction + + get_distributed_runtime_client = _xe.get_distributed_runtime_client + get_distributed_runtime_service = _xe.get_distributed_runtime_service + hlo_module_cost_analysis = _xe.hlo_module_cost_analysis + hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph + ifrt_proxy = _xe.ifrt_proxy + jax_jit = _xe.jax_jit + mlir = _xe.mlir + pmap_lib = _xe.pmap_lib + profiler = _xe.profiler + pytree = _xe.pytree + else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr From 108c590b2f11ffd4ed7a75d884f907bb945ef05b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 08:10:00 -0700 Subject: [PATCH 0219/1769] Replace uses of deprecated `Shape::rank()` with: - `dimensions().size()` if it's OK for the result to be changed to an unsigned number, - `dimensions_size()` if it's important that the result is a signed number. This should be a pure refactoring that doesn't affect the code's behavior. Note that `rank()` returns `int64_t` and `dimensions().size()` returns `size_t`. Sometimes the change of the signedness is not desirable, and we use `dimensions_size()`, which returns `int`, in such cases. 
PiperOrigin-RevId: 741157851 --- jaxlib/xla/xla_compiler.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc index 00f8b4c295a7..0098cc28160d 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla/xla_compiler.cc @@ -648,7 +648,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { nb::arg("dimension")) .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, nb::arg("dimension"), nb::arg("is_dynamic")) - .def("rank", &Shape::rank) + .def("rank", &Shape::dimensions_size) .def("to_serialized_proto", [](const Shape& shape) { ShapeProto proto = shape.ToProto(); From 99d92f26a6f2c01659a6afe3ad2744f86fa521fa Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 08:10:06 -0700 Subject: [PATCH 0220/1769] Explicitly export mgpu runtime symbols. PiperOrigin-RevId: 741157879 --- jaxlib/mosaic/gpu/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index abe326474808..80a8f0e51080 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -111,9 +111,12 @@ cc_library( cc_library( name = "runtime", srcs = ["runtime.cc"], + # Linker may prune these symbols if they are not explicitly exported. + linkopts = ["-Wl,--export-dynamic-symbol='mosaic_gpu_*'"], deps = [ "@local_config_cuda//cuda:cuda_headers", ], + alwayslink = True, ) cc_library( From 083bdfc9cc2613086ee2273395f127b20598dc6d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 08:44:31 -0700 Subject: [PATCH 0221/1769] Add license headers to files that were missing them. PiperOrigin-RevId: 741167870 --- jax/_src/mesh_utils.py | 2 +- jax/experimental/mesh_utils.py | 2 +- jaxlib/ffi_helpers.h | 15 +++++++++++++++ jaxlib/gpu/triton.cc | 15 +++++++++++++++ jaxlib/gpu/triton_kernels.cc | 15 +++++++++++++++ jaxlib/gpu/triton_kernels.h | 15 +++++++++++++++ jaxlib/gpu/triton_utils.cc | 15 +++++++++++++++ jaxlib/gpu/triton_utils.h | 15 +++++++++++++++ jaxlib/mlir/_mlir_libs/register_jax_dialects.cc | 15 +++++++++++++++ .../dialect/gpu/integrations/c/attributes.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/apply_vector_layout.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/apply_vector_layout.h | 15 +++++++++++++++ .../transforms/apply_vector_layout_extensions.h | 15 +++++++++++++++ .../dialect/tpu/transforms/canonicalize_mosaic.cc | 15 +++++++++++++++ .../extensions/apply_vector_layout_extensions.cc | 15 +++++++++++++++ .../extensions/infer_vector_layout_extensions.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/infer_memref_layout.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/infer_memref_layout.h | 15 +++++++++++++++ .../transforms/infer_vector_layout_extensions.h | 15 +++++++++++++++ .../dialect/tpu/transforms/relayout_insertion.cc | 15 +++++++++++++++ jaxlib/mosaic/dialect/tpu/transforms/serde.h | 15 +++++++++++++++ jaxlib/mosaic/dialect/tpu/util.h | 15 +++++++++++++++ tests/mesh_utils_test.py | 2 +- 23 files changed, 303 insertions(+), 3 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index ccc75af8c84f..c135919b14c5 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 075e4e6eed48..58d20c331d5f 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 5c6d80093df5..634a48fcffc7 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 135410568f6b..d0c48eef492f 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -1,3 +1,18 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 22397ff908bc..6565b5b87be2 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_kernels.h" #include diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index c3457093c4f8..d23a9a7395e0 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_H_ #define JAXLIB_GPU_TRITON_H_ diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc index b3a0779118de..f6bbe46c846d 100644 --- a/jaxlib/gpu/triton_utils.cc +++ b/jaxlib/gpu/triton_utils.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_utils.h" #include diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h index 0c286391e296..19c64a88c216 100644 --- a/jaxlib/gpu/triton_utils.h +++ b/jaxlib/gpu/triton_utils.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_UTILS_H_ #define JAXLIB_GPU_TRITON_UTILS_H_ diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 1ba6fd9375df..0eb4a57a2f4b 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,3 +1,18 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. 
#include diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc index eac1d104f07f..259c37fe5d07 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc @@ -1,3 +1,18 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 71924739595c..c9c8d22a1363 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index ed72a21028eb..bbf23a9f3844 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h index fded0d1dbfd7..72bd8ca370c8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 6f56489ab4b1..a15947f48f78 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc index 067f8e592e30..d2c149a47150 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" #include "llvm/ADT/StringMap.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc index 9dbf89724fef..e34ef7fcb261 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index fdfd04949bce..e2196088728f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index f2ab7c624eb1..a6dd8ad1dbd3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h index 36fa2ce8113f..a81e982f8e1a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index 8aae7a10279a..6ddf8bd5ce66 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index ccb32131e519..5da8a9e316e0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index dadd71800f3e..e2cf27811f09 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 136b507942e7..28efb266b281 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From e342f2dd602ea33cc395dbcd71e38191ebf593d3 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 09:53:09 -0700 Subject: [PATCH 0222/1769] Update the minimum supported CuDNN version to 9.8 (previously 9.1). Announce maximum supported CUDA version 12.8 (previously 12.3). PiperOrigin-RevId: 741188737 --- CHANGELOG.md | 5 +++++ build/gpu-test-requirements.txt | 2 +- build/requirements_lock_3_10.txt | 8 ++++---- build/requirements_lock_3_11.txt | 8 ++++---- build/requirements_lock_3_12.txt | 8 ++++---- build/requirements_lock_3_13.txt | 8 ++++---- build/requirements_lock_3_13_ft.txt | 8 ++++---- jax_plugins/cuda/plugin_setup.py | 2 +- 8 files changed, 27 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8805599364d..cfd8c2eb340b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Changes + * The minimum CuDNN version is v9.8. + * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain + supported. + * Deprecations * {func}`jax.tree_util.build_tree` is deprecated. 
Use {func}`jax.tree.unflatten` diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt index ff43f91ba90f..d0dda5cf526c 100644 --- a/build/gpu-test-requirements.txt +++ b/build/gpu-test-requirements.txt @@ -5,7 +5,7 @@ nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" +nvidia-cudnn-cu12>=9.8,<10.0 ; sys_platform == "linux" nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 6ed6b59aa584..8bf5293bd948 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -410,10 +410,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 8446e8361505..487346ab6d12 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 0436ab6dd486..e2f76cab8abc 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform 
== "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index e74d40b798f4..403d0ad8a061 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -460,10 +460,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7a2968e981e..5157706c00e8 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -413,10 +413,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git 
a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py
index ce31684de46f..db9928f6cf61 100644
--- a/jax_plugins/cuda/plugin_setup.py
+++ b/jax_plugins/cuda/plugin_setup.py
@@ -57,7 +57,7 @@ def has_ext_modules(self):
         "nvidia-cuda-cupti-cu12>=12.1.105",
         "nvidia-cuda-nvcc-cu12>=12.6.85",
         "nvidia-cuda-runtime-cu12>=12.1.105",
-        "nvidia-cudnn-cu12>=9.1,<10.0",
+        "nvidia-cudnn-cu12>=9.8,<10.0",
         "nvidia-cufft-cu12>=11.0.2.54",
         "nvidia-cusolver-cu12>=11.4.5.107",
         "nvidia-cusparse-cu12>=12.1.0.106",

From 3c81b184a7b169827069451373f671fc42543c51 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 27 Mar 2025 13:09:22 -0400
Subject: [PATCH 0223/1769] Add sm_100 and sm_120 to the list of CUDA GPU
 architectures for which we compile.

---
 .bazelrc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.bazelrc b/.bazelrc
index 642fb15ed541..76f72b0848a9 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -136,7 +136,7 @@ build:cuda --repo_env TF_NEED_CUDA=1
 build:cuda --repo_env TF_NCCL_USE_STUB=1
 # "sm" means we emit only cubin, which is forward compatible within a GPU generation.
 # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
-build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
+build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"
 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
 build:cuda --@local_config_cuda//:enable_cuda

From 0dbc1222657e318c31edad50e1f567b835574c0f Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Thu, 27 Mar 2025 10:11:35 -0700
Subject: [PATCH 0224/1769] Add the `jax` wheel as a required dependency for
 running the Bazel CUDA non-RBE tests

Since https://github.com/jax-ml/jax/pull/27113, the wheel is tested when
`--//jax:build_jaxlib=false`. Previously, these tests could depend on the
source repository instead.

Fixes https://github.com/jax-ml/jax/actions/runs/14108610313/job/39521951667

PiperOrigin-RevId: 741195252
---
 .github/workflows/bazel_cuda_non_rbe.yml     | 1 +
 .github/workflows/wheel_tests_continuous.yml | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml
index 0b0e1cb62497..3d15f4211a3f 100644
--- a/.github/workflows/bazel_cuda_non_rbe.yml
+++ b/.github/workflows/bazel_cuda_non_rbe.yml
@@ -79,6 +79,7 @@ jobs:
         continue-on-error: true
         run: >-
           mkdir -p $(pwd)/dist &&
+          gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
           gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
           gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
           gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml
index ecdf43b133cc..4b6e1e0a8712 100644
--- a/.github/workflows/wheel_tests_continuous.yml
+++ b/.github/workflows/wheel_tests_continuous.yml
@@ -148,7 +148,7 @@ jobs:
     # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
     # still want to run the tests for other platforms.
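The `HERMETIC_CUDA_COMPUTE_CAPABILITIES` hunk in `.bazelrc` above also documents the size/portability trade-off: "sm" entries emit only cubin, while "compute" entries additionally embed PTX. A minimal sketch of narrowing the list for a single-architecture developer build, assuming a hypothetical local override on an sm_90 (Hopper) machine:

```
# Sketch: cubin for Hopper only; "compute_90" also embeds PTX so the build
# remains forward compatible with future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_90,compute_90"
```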
if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] uses: ./.github/workflows/bazel_cuda_non_rbe.yml strategy: fail-fast: false # don't cancel all jobs on failure From 18521fef08f0d42f6001141abb793998323f72b3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Mar 2025 14:45:41 -0700 Subject: [PATCH 0225/1769] Deprecate jax.tree_* aliases --- CHANGELOG.md | 4 ++++ jax/__init__.py | 50 ++++++++++++++++--------------------------------- 2 files changed, 20 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cfd8c2eb340b..5785f6193065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. + * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, + `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and + `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or + {mod}`jax.tree_util`. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/__init__.py b/jax/__init__.py index 988c224e4772..32ae955ae5b8 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -141,16 +141,6 @@ make_array_from_process_local_data as make_array_from_process_local_data, ) -from jax._src.tree_util import ( - tree_map as _deprecated_tree_map, - treedef_is_leaf as _deprecated_treedef_is_leaf, - tree_flatten as _deprecated_tree_flatten, - tree_leaves as _deprecated_tree_leaves, - tree_structure as _deprecated_tree_structure, - tree_transpose as _deprecated_tree_transpose, - tree_unflatten as _deprecated_tree_unflatten, -) - # These submodules are separate because they are in an import cycle with # jax and rely on the names imported above. 
from jax import custom_derivatives as custom_derivatives @@ -184,54 +174,46 @@ del _ccache _deprecations = { - # Added July 2022 + # Finalized 2025-03-25; remove after 2025-06-25 "treedef_is_leaf": ( - "jax.treedef_is_leaf is deprecated: use jax.tree_util.treedef_is_leaf.", - _deprecated_treedef_is_leaf + "jax.treedef_is_leaf was removed in JAX v0.6.0: use jax.tree_util.treedef_is_leaf.", + None ), "tree_flatten": ( - "jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) " + "jax.tree_flatten was removed in JAX v0.6.0: use jax.tree.flatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_flatten (any JAX version).", - _deprecated_tree_flatten + None ), "tree_leaves": ( - "jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) " + "jax.tree_leaves was removed in JAX v0.6.0: use jax.tree.leaves (jax v0.4.25 or newer) " "or jax.tree_util.tree_leaves (any JAX version).", - _deprecated_tree_leaves + None ), "tree_structure": ( - "jax.tree_structure is deprecated: use jax.tree.structure (jax v0.4.25 or newer) " + "jax.tree_structure was removed in JAX v0.6.0: use jax.tree.structure (jax v0.4.25 or newer) " "or jax.tree_util.tree_structure (any JAX version).", - _deprecated_tree_structure + None ), "tree_transpose": ( - "jax.tree_transpose is deprecated: use jax.tree.transpose (jax v0.4.25 or newer) " + "jax.tree_transpose was removed in JAX v0.6.0: use jax.tree.transpose (jax v0.4.25 or newer) " "or jax.tree_util.tree_transpose (any JAX version).", - _deprecated_tree_transpose + None ), "tree_unflatten": ( - "jax.tree_unflatten is deprecated: use jax.tree.unflatten (jax v0.4.25 or newer) " + "jax.tree_unflatten was removed in JAX v0.6.0: use jax.tree.unflatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_unflatten (any JAX version).", - _deprecated_tree_unflatten + None ), - # Added Feb 28, 2024 "tree_map": ( - "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) " + "jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) " "or jax.tree_util.tree_map (any JAX version).", - _deprecated_tree_map + None ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf - from jax._src.tree_util import tree_flatten as tree_flatten - from jax._src.tree_util import tree_leaves as tree_leaves - from jax._src.tree_util import tree_map as tree_map - from jax._src.tree_util import tree_structure as tree_structure - from jax._src.tree_util import tree_transpose as tree_transpose - from jax._src.tree_util import tree_unflatten as tree_unflatten - + pass else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) From 289221af8be2979d9a1e25c7e61e1eee18274948 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 10:23:08 -0700 Subject: [PATCH 0226/1769] Use h100x2 for tests rather than p100x2. PiperOrigin-RevId: 741199510 --- tests/BUILD | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 1baeb4f83af7..1d021b0c7110 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -97,7 +97,7 @@ jax_multiplatform_test( "gpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], env = { "PYTHONWARNINGS": "default", # TODO(b/394123878): protobuf, via TensorFlow, issues a Python warning under Python 3.12+ sometimes. 
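The removals in patch 0225 above are enforced by the module-level `__getattr__` hook (PEP 562) installed at the end of `jax/__init__.py`, which now raises with the listed messages instead of resolving the old names. A minimal migration sketch, assuming jax v0.4.25 or newer for the `jax.tree` spellings:

```python
import jax

pytree = {"a": 1, "b": (2, 3)}

# Removed in JAX v0.6.0: jax.tree_leaves(pytree)
leaves = jax.tree.leaves(pytree)            # jax v0.4.25 or newer
leaves = jax.tree_util.tree_leaves(pytree)  # works on any JAX version

# The same pattern applies to tree_map, tree_flatten, tree_unflatten, etc.
doubled = jax.tree.map(lambda x: x * 2, pytree)
```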
@@ -190,7 +190,7 @@ jax_multiplatform_test(
     name = "ffi_test",
     srcs = ["ffi_test.py"],
     enable_configs = [
-        "gpu_p100x2",
+        "gpu_h100x2",
     ],
     # TODO(dfm): Remove after removal of jex.ffi imports.
     deps = ["//jax:extend"],
@@ -274,7 +274,7 @@ jax_multiplatform_test(
     srcs = ["memories_test.py"],
     enable_configs = [
         "cpu",
-        "gpu_p100x2",
+        "gpu_h100x2",
         "tpu_v3_2x2",
         "tpu_v4_2x2",
         "tpu_v5p_2x2",
@@ -301,7 +301,7 @@ jax_multiplatform_test(
         "gpu_p100x2_shardy",
         "tpu_v3_2x2_shardy",
         "tpu_v3_2x2",
-        "gpu_p100x2",
+        "gpu_h100x2",
     ],
     shard_count = {
         "cpu": 5,
@@ -725,7 +725,7 @@ jax_multiplatform_test(
         "cpu",
     ],
     enable_configs = [
-        "gpu_p100x2",
+        "gpu_h100x2",
         "gpu_p100x2_shardy",
         "gpu_p100x2_pjrt_c_api",
     ],
@@ -766,7 +766,7 @@ jax_multiplatform_test(
     srcs = ["multibackend_test.py"],
     enable_configs = [
         "tpu_v3_2x2",
-        "gpu_p100x2",
+        "gpu_h100x2",
     ],
 )

From 1719fa0d5bd0d95be54a9327ae6dfdff142dafbe Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Thu, 27 Mar 2025 10:26:35 -0700
Subject: [PATCH 0227/1769] Make sure the array is copied in the following
 situation:

```
x = np.arange(1000)
y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False)
z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False)
```

After this change, the following condition will be true:
`z.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()`

Also lift the restriction that CopyToMemorySpace sometimes doesn't work
when the source and destination memory spaces match. We can always bounce
through the host if there is no more efficient copy.

PiperOrigin-RevId: 741200853
---
 jax/_src/dispatch.py          |  3 +++
 jax/_src/interpreters/pxla.py |  2 +-
 jaxlib/xla/py_values.cc       | 13 ++++++++-----
 jaxlib/xla/xla_client.py      |  2 +-
 tests/pjit_test.py            | 13 +++++++++++++
 5 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 2330f7628966..d205f860b214 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -495,6 +495,9 @@ def _device_put_sharding_impl(x, aval, device, copy):
     return _DeferredShardArg(x, x.sharding, aval, x.committed, copy)
   elif is_single_device_sharding(x.sharding):
     device = x.sharding._device_assignment[0] if device is None else device
+    if copy == CopySemantics.COPY:
+      return xc.batched_device_put(aval, SingleDeviceSharding(device), [x],
+                                   [device], True, True)
     return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
                                    [device])

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 6f95b1b72281..51854b457b37 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -257,7 +257,7 @@ def _shard_abstract_array(size, axis: int, x):
       raise ValueError(f"Axis size {size} does not match dimension {axis} of "
                        f"shape {x.shape}")
   except IndexError:
-    raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None
+    raise ValueError(f"Cannot split a {x.dim}D value along axis {axis}") from None
   if config.pmap_no_rank_reduction.value:
     return x.update(shape=tuple_update(x.shape, axis, 1))
   else:
diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc
index 1c7db0bec13a..e13a38197c0a 100644
--- a/jaxlib/xla/py_values.cc
+++ b/jaxlib/xla/py_values.cc
@@ -418,6 +418,7 @@ absl::StatusOr HandlePyArray(
   }

   if (ifrt_array->sharding().devices()->devices().front() == to_device &&
+      options.allow_zero_copy &&
       (!to_memory_kind.memory_kind().has_value() ||
        !ifrt_array->sharding().memory_kind().memory_kind().has_value() ||
        ifrt_array->sharding().memory_kind() == to_memory_kind)) {
@@ -426,15 +427,17 @@
absl::StatusOr HandlePyArray(
     return [result = std::move(result)]() mutable { return std::move(result); };
   } else {
     return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind,
-            owning_pybuffer = py_array.weak_type()]() mutable
+            owning_pybuffer = py_array.weak_type(),
+            allow_zero_copy = options.allow_zero_copy]() mutable
                -> absl::StatusOr {
       auto* ifrt_client = ifrt_array->client();
       TF_ASSIGN_OR_RETURN(
           auto copied_ifrt_arrays,
-          ifrt_client->CopyArrays(absl::MakeSpan(&ifrt_array, 1),
-                                  ifrt_client->MakeDeviceList({to_device}),
-                                  to_memory_kind,
-                                  ifrt::ArrayCopySemantics::kReuseInput));
+          ifrt_client->CopyArrays(
+              absl::MakeSpan(&ifrt_array, 1),
+              ifrt_client->MakeDeviceList({to_device}), to_memory_kind,
+              allow_zero_copy ? ifrt::ArrayCopySemantics::kReuseInput
+                              : ifrt::ArrayCopySemantics::kAlwaysCopy));
       return DevicePutResult(std::move(copied_ifrt_arrays[0]),
                              std::move(owning_pybuffer));
       };
diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py
index 30e8443276c8..776a22444208 100644
--- a/jaxlib/xla/xla_client.py
+++ b/jaxlib/xla/xla_client.py
@@ -50,7 +50,7 @@

 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
-_version = 322
+_version = 323

 # An internal increasing version number for protecting jaxlib code against
 # ifrt changes.
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index d72ecc98e771..aa5afccb38d4 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -63,6 +63,7 @@
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
+from jax._src.lib import jaxlib_extension_version
 from jax._src.util import curry, unzip2

 config.parse_flags_with_absl()
@@ -1400,6 +1401,18 @@ def test_zero_literal_equality(self):
     self.assertIn("stablehlo.constant dense<0.000000e+00>", ir)
     self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir)

+  def test_device_put_copy_donate(self):
+    if jaxlib_extension_version < 323:
+      raise unittest.SkipTest("Copy not supported in device put.")
+    x = np.arange(1000)
+    y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False)
+    z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False)
+    a = jax.jit(lambda y: y * 2, donate_argnums=0)(y)
+    self.assertDeleted(y)
+    self.assertNotDeleted(z)
+    self.assertArraysEqual(a, x * 2)
+
+
 @jtu.pytest_mark_if_available('multiaccelerator')
 class CustomPartitionerTest(jtu.JaxTestCase):

From 3f8e1925f7f47a9aac176feb6c57028f594a5e17 Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Thu, 27 Mar 2025 10:30:49 -0700
Subject: [PATCH 0228/1769] Remove CUDA 12.3 from the CUDA test matrix

Also, update the Docker images to ones with cuDNN 9.8.

PiperOrigin-RevId: 741202254
---
 .github/workflows/pytest_cuda.yml                 | 7 +++----
 .github/workflows/wheel_tests_continuous.yml      | 4 ++--
 .github/workflows/wheel_tests_nightly_release.yml | 2 +-
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml
index b3d1b15a0052..671af873b48d 100644
--- a/.github/workflows/pytest_cuda.yml
+++ b/.github/workflows/pytest_cuda.yml
@@ -54,12 +54,11 @@ jobs:
   run-tests:
     defaults:
       run:
-        # Explicitly set the shell to bash
+        # Set the shell to bash as GitHub actions run with /bin/sh by default
         shell: bash
     runs-on: ${{ inputs.runner }}
-    # TODO: Update to the generic ML ecosystem test containers when they are ready.
- container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') || - (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') || + # Test the oldest and newest supported CUDA versions. + container: ${{ (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') || (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }} name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 4b6e1e0a8712..f48c39bf4721 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -111,9 +111,9 @@ jobs: matrix: # Python values need to match the matrix stategy in the artifact build jobs above # See exlusions for what is fully tested - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"] + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] python: ["3.10",] - cuda: ["12.1","12.3","12.8"] + cuda: ["12.1", "12.8"] enable-x64: [1, 0] exclude: # L4 does not run on cuda 12.8 but tests other configs diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 9cd48c925cf3..fd4a52d296e0 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -58,7 +58,7 @@ jobs: # that build the wheels. runner: ["linux-x86-g2-48-l4-4gpu"] python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] - cuda: ["12.3", "12.1"] + cuda: ["12.1", "12.8"] enable-x64: [0] name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: From a61785d2b6fbee58736ff5f570234894bc6d17d9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 10:49:41 -0700 Subject: [PATCH 0229/1769] Run include_cleaner over JAX C++ code. 
PiperOrigin-RevId: 741208842 --- jaxlib/BUILD | 1 + jaxlib/cuda/BUILD | 9 +++++++- jaxlib/cuda/versions_helpers.cc | 1 + jaxlib/gpu/BUILD | 1 + jaxlib/gpu/blas.cc | 2 +- jaxlib/gpu/gpu_kernel_helpers.cc | 5 +++- jaxlib/gpu/gpu_kernel_helpers.h | 3 +-- jaxlib/gpu/gpu_plugin_extension.cc | 1 + jaxlib/gpu/make_batch_pointers.cu.cc | 1 + jaxlib/gpu/prng.cc | 1 + jaxlib/gpu/prng_kernels.cc | 4 ---- jaxlib/gpu/prng_kernels.cu.cc | 3 +-- jaxlib/gpu/prng_kernels.h | 2 -- jaxlib/gpu/py_client_gpu.h | 1 - jaxlib/gpu/rnn.cc | 2 +- jaxlib/gpu/rnn_kernels.cc | 5 ++++ jaxlib/gpu/rnn_kernels.h | 1 + jaxlib/gpu/solver.cc | 2 +- jaxlib/gpu/sparse.cc | 13 +++++------ jaxlib/gpu/sparse_kernels.cc | 10 ++++---- jaxlib/gpu/sparse_kernels.h | 7 ++---- jaxlib/gpu/triton.cc | 15 +++++++----- jaxlib/gpu/triton_kernels.cc | 4 +++- jaxlib/gpu/triton_kernels.h | 3 +-- jaxlib/gpu/triton_utils.cc | 1 + jaxlib/gpu/triton_utils.h | 1 - jaxlib/gpu/vendor.h | 1 + jaxlib/kernel_helpers.h | 2 +- .../mlir/_mlir_libs/register_jax_dialects.cc | 23 ++++++++++--------- jaxlib/mlir/_mlir_libs/tpu_ext.cc | 5 +--- jaxlib/mlir/_mlir_libs/triton_ext.cc | 1 + jaxlib/mosaic/BUILD | 1 + jaxlib/mosaic/dialect/gpu/BUILD | 4 ++-- .../dialect/gpu/integrations/c/gpu_dialect.h | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.h | 4 +--- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 4 ++-- .../dialect/tpu/integrations/c/tpu_dialect.cc | 5 +++- jaxlib/mosaic/dialect/tpu/layout.cc | 1 - jaxlib/mosaic/dialect/tpu/layout.h | 4 ++-- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 8 ++----- jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 11 ++++----- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 1 + .../tpu/transforms/apply_vector_layout.cc | 1 - .../tpu/transforms/canonicalize_mosaic.cc | 13 ++++------- .../dialect/tpu/transforms/communication.cc | 7 ++++-- .../tpu/transforms/infer_memref_layout.cc | 1 - .../tpu/transforms/infer_vector_layout.cc | 3 --- .../transforms/memory_space_specialization.cc | 2 ++ jaxlib/mosaic/dialect/tpu/transforms/serde.h | 1 + jaxlib/mosaic/dialect/tpu/util.cc | 1 + jaxlib/mosaic/dialect/tpu/util.h | 3 ++- jaxlib/mosaic/gpu/custom_call.cc | 1 + jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 4 ++-- jaxlib/mosaic/gpu/passes.cc | 4 ++++ jaxlib/rocm/BUILD | 10 +++++++- jaxlib/xla/BUILD | 10 ++++---- jaxlib/xla/callback.cc | 1 + jaxlib/xla/callback.h | 1 - jaxlib/xla/config.cc | 1 + jaxlib/xla/custom_call_sharding.cc | 1 + jaxlib/xla/dlpack.cc | 2 +- jaxlib/xla/ifrt_proxy.cc | 4 ++-- jaxlib/xla/jax_jit.h | 2 +- jaxlib/xla/mlir.cc | 4 ---- jaxlib/xla/pmap_lib.h | 3 --- jaxlib/xla/py_array.cc | 1 - jaxlib/xla/py_array.h | 1 + jaxlib/xla/py_client.cc | 2 -- jaxlib/xla/py_client.h | 1 - jaxlib/xla/py_device.h | 1 + jaxlib/xla/py_device_list.cc | 2 -- jaxlib/xla/py_device_list.h | 1 - jaxlib/xla/py_executable.h | 4 +--- jaxlib/xla/py_memory_space.h | 1 + jaxlib/xla/py_socket_transfer.cc | 2 ++ jaxlib/xla/python_ref_manager.cc | 2 ++ jaxlib/xla/sharded_device_array.h | 1 - jaxlib/xla/sharding.cc | 1 - jaxlib/xla/sharding.h | 3 ++- jaxlib/xla/to_ifrt_sharding.cc | 1 - jaxlib/xla/to_ifrt_sharding.h | 8 ++++++- jaxlib/xla/traceback.h | 3 ++- jaxlib/xla/xla.cc | 6 ++--- jaxlib/xla/xla_compiler.cc | 1 + 85 files changed, 161 insertions(+), 139 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 52c945482222..c8114b48835f 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -154,6 +154,7 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base", + "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", ], ) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index c47bc3c8126f..fac62c81dee7 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -160,6 +160,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", @@ -336,6 +337,7 @@ nanobind_extension( ":cuda_vendor", ":cusparse_kernels", "//jaxlib:absl_status_casters", + "//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -343,11 +345,13 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", "@xla//xla/tsl/python/lib/core:numpy", @@ -455,6 +459,7 @@ nanobind_extension( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", + ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@local_config_cuda//cuda:cuda_headers", "@nanobind", @@ -545,8 +550,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", "@xla//xla/stream_executor/cuda:cuda_asm_compiler", "@xla//xla/tsl/cuda:cudart", @@ -586,7 +591,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index d42199d37467..508a92c326cb 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/cuda/versions_helpers.h" #include +#include #include #include "absl/base/dynamic_annotations.h" diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 59c0ab8dc164..1fd2775ecf9a 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -133,6 +133,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", + "@xla//xla/tsl/platform:statusor", "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index 4a58859016f1..cf391e07e31e 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index 5a434f4b6ad5..5b509ad9912d 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -15,12 +15,15 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include +#include + #include "absl/base/optimization.h" #include "absl/log/check.h" -#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index aecb8a4fdcf1..0326d7f44620 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -16,11 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ #define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ -#include +#include #include "absl/base/optimization.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #define JAX_AS_STATUS(expr) \ diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index d026806e9479..cca615cfb260 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc index 3a24e355ead0..1d05fa8adcac 100644 --- a/jaxlib/gpu/make_batch_pointers.cu.cc +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/gpu/make_batch_pointers.h" #include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng.cc b/jaxlib/gpu/prng.cc index 1ce428d7f9dc..007e51b76de7 100644 --- a/jaxlib/gpu/prng.cc +++ b/jaxlib/gpu/prng.cc @@ -15,6 +15,7 @@ limitations under the License. #include "nanobind/nanobind.h" #include "jaxlib/gpu/prng_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" namespace jax { diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc index f5d6abef83f8..1dac1e47bd44 100644 --- a/jaxlib/gpu/prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -17,16 +17,12 @@ limitations under the License. #include #include -#include #include "absl/algorithm/container.h" -#include "absl/status/status.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/ffi_helpers.h" -#include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/prng_kernels.cu.cc b/jaxlib/gpu/prng_kernels.cu.cc index d4aaec62320d..e42165f95d15 100644 --- a/jaxlib/gpu/prng_kernels.cu.cc +++ b/jaxlib/gpu/prng_kernels.cu.cc @@ -15,8 +15,7 @@ limitations under the License. 
#include "jaxlib/gpu/prng_kernels.h" -#include -#include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng_kernels.h b/jaxlib/gpu/prng_kernels.h index c98fd485700d..4d64d2b4a4e4 100644 --- a/jaxlib/gpu/prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -16,12 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_PRNG_KERNELS_H_ #define JAXLIB_GPU_PRNG_KERNELS_H_ -#include #include #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index 8c5404570919..4d48858ad278 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ #define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ -#include #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index eaa815d33e68..32e0842e3038 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/rnn_kernels.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 45f8ba8187ba..d06535a668ac 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -16,14 +16,19 @@ limitations under the License. #include "jaxlib/gpu/rnn_kernels.h" #include +#include +#include #include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index e95b7788382a..36d8c25c6a9f 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -17,6 +17,7 @@ limitations under the License. #define JAXLIB_GPU_RNN_KERNELS_H_ #include +#include #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 1cf799bbb491..20fc308100c4 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index a7f8dbebc2b3..592c0f454a55 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -13,24 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include +#include #include +#include #include -#include -#include "absl/base/casts.h" -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_helpers.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/service/custom_call_status.h" #include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index c66e96b6b89e..a9c08317e066 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -15,11 +15,9 @@ limitations under the License. #include "jaxlib/gpu/sparse_kernels.h" -#include -#include -#include -#include -#include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -27,8 +25,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/vendor.h" #include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 0d74ebc7d8e4..d735c320307c 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -16,15 +16,12 @@ limitations under the License. #ifndef JAXLIB_GPU_SPARSE_KERNELS_H_ #define JAXLIB_GPU_SPARSE_KERNELS_H_ -#include +#include #include -#include -#include -#include #include "absl/status/statusor.h" -#include "jaxlib/gpu/vendor.h" #include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index d0c48eef492f..b3f313e4f7ea 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -13,20 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include -#include #include #include #include +#include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "nanobind/stl/string.h" -#include "nanobind/stl/string_view.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 6565b5b87be2..9e0dc6c855ac 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" @@ -52,7 +53,8 @@ limitations under the License. #endif // JAX_GPU_CUDA #ifdef JAX_GPU_HIP -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index d23a9a7395e0..3ab3e9143fb8 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef JAXLIB_GPU_TRITON_H_ #define JAXLIB_GPU_TRITON_H_ +#include #include -#include #include #include #include @@ -25,7 +25,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc index f6bbe46c846d..fd63435da177 100644 --- a/jaxlib/gpu/triton_utils.cc +++ b/jaxlib/gpu/triton_utils.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" +#include "jaxlib/gpu/vendor.h" namespace jax::JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h index 19c64a88c216..a79c098373d1 100644 --- a/jaxlib/gpu/triton_utils.h +++ b/jaxlib/gpu/triton_utils.h @@ -18,7 +18,6 @@ limitations under the License. #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 58a02e7c568c..5deb8d4c650a 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef JAXLIB_GPU_VENDOR_H_ #define JAXLIB_GPU_VENDOR_H_ +#include #if defined(JAX_GPU_CUDA) // IWYU pragma: begin_exports diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index dac0355fbde6..5a053f833ce4 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -17,10 +17,10 @@ limitations under the License. #define JAXLIB_KERNEL_HELPERS_H_ #include -#include #include #include "absl/base/casts.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" namespace jax { diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 0eb4a57a2f4b..b8432bf615c9 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -17,18 +17,19 @@ limitations under the License. // This module is called by mlir/__init__.py during initialization. 
#include -#include "mlir-c/Dialect/Arith.h" -#include "mlir-c/Dialect/Func.h" -#include "mlir-c/Dialect/GPU.h" -#include "mlir-c/Dialect/LLVM.h" -#include "mlir-c/Dialect/Math.h" -#include "mlir-c/Dialect/MemRef.h" -#include "mlir-c/Dialect/NVGPU.h" -#include "mlir-c/Dialect/NVVM.h" -#include "mlir-c/Dialect/SCF.h" -#include "mlir-c/Dialect/Vector.h" +#include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Func.h" // IWYU pragma: keep +#include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Math.h" // IWYU pragma: keep +#include "mlir-c/Dialect/MemRef.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVGPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/SCF.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Vector.h" // IWYU pragma: keep +#include "mlir-c/IR.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep #include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 7d616968b9aa..8f751693e451 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,9 +26,8 @@ limitations under the License. #include #include -#include "llvm/ADT/ArrayRef.h" +#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir-c/AffineMap.h" @@ -41,7 +39,6 @@ limitations under the License. #include "mlir-c/Support.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep // clang-format off -#include "mlir-c/Bindings/Python/Interop.h" // clang-format on #include "absl/log/check.h" #include "nanobind/nanobind.h" diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 2a13c40d963f..7fba7e1dfe80 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "mlir-c/IR.h" diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 775c34c8e7c7..41584d7692aa 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -95,6 +95,7 @@ cc_library( "@xla//xla:shape_util", "@xla//xla:util", "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", ] + mosaic_extension_deps, ) diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index e21c8756a4e2..f0e399da0575 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -127,7 +127,7 @@ cc_library( "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:statusor", + "@xla//xla/tsl/platform:statusor", ], ) @@ -151,7 +151,7 @@ cc_test( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:errors", + "@xla//xla/tsl/platform:errors", ], ) diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h index bb6cf6e3af4a..5fd0ce7a4f7a 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/CAPI/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 2358a97ba20d..1b3d08f91fb0 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -50,7 +50,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index 47b286aec302..474ed93806a1 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -21,14 +21,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" // Generated definitions. diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index c259da3e737c..5458ba7fac88 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -25,7 +26,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/StructBuilder.h" @@ -44,7 +44,7 @@ limitations under the License. 
#include "mlir/IR/Verifier.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LLVM.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace mosaic_gpu { namespace { diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index ce7e90d45fb9..dee4f5de43d8 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -21,10 +21,13 @@ limitations under the License. #include #include #include +#include #include "absl/log/check.h" #include "absl/log/log.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemAlloc.h" #include "llvm/Support/raw_ostream.h" #include "mlir-c/IR.h" @@ -33,11 +36,11 @@ limitations under the License. #include "mlir/CAPI/Registration.h" #include "mlir/CAPI/Utils.h" #include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index c54c99fc9825..7ae8681e6980 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index bcfe205d58a9..12bf66cfcec0 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -18,16 +18,16 @@ limitations under the License. #include #include +#include #include -#include #include #include #include #include "absl/log/check.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 73c119b70e1a..e0e061fbd6dd 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -15,27 +15,23 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include -#include #include -#include #include -#include -#include #include "absl/hash/hash.h" #include "absl/log/log.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. 
#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index cf74689dd3e6..2afaf08f29ed 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -23,15 +23,14 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" -#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" -#include "xla/layout.h" +#include "xla/layout.h" // IWYU pragma: keep namespace mlir::tpu { class TPUDialect; @@ -63,11 +62,11 @@ struct ApplyVectorLayoutContext { // mxu_shape = {contracting_size, non_contracting_size} std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; - int64_t vmem_banks = -1; // -1 means "unspecified". + int64_t vmem_banks = -1; // -1 means "unspecified". int32_t max_shuffle_sublane_offset = -1; // -1 means "unspecified". }; -std::pair mightCommunicateBetweenChips(Operation* op); +std::pair mightCommunicateBetweenChips(Operation *op); std::unique_ptr> createInferMemRefLayoutPass( int hardware_generation = -1, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index b69a6ae06a7f..41342efeb1b4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" +#include "xla/layout.h" namespace mlir { namespace tpu { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index c9c8d22a1363..e68d5da466eb 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -39,7 +39,6 @@ limitations under the License. #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Compiler.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index a15947f48f78..373a5db6b4f6 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -22,26 +22,23 @@ limitations under the License. #include #include -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -// It requires these headers, but does not include them. 
-// NOLINTNEXTLINE(misc-include-cleaner) -#include "mlir/Dialect/MemRef/IR/MemRef.h" -// NOLINTNEXTLINE(misc-include-cleaner) #include "absl/log/check.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" // IWYU pragma: keep +#include "mlir/Dialect/SCF/IR/SCF.h" // IWYU pragma: keep #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc index 89e3a8bb9f70..7e99dd15611b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc @@ -17,13 +17,16 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index e2196088728f..bfb9be87dfd0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -30,7 +30,6 @@ limitations under the License. #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index c1a642b48f04..54ac777fc205 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -19,8 +19,6 @@ limitations under the License. #include #include #include -#include -#include #include #include "absl/log/check.h" @@ -28,7 +26,6 @@ limitations under the License. 
#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index f78df135a45a..1cfb797c5478 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/log/check.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 5da8a9e316e0..e5617ef151f7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index e61d9fa8d417..141f52ec125b 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index e2cf27811f09..eed0df14f707 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" +#include "llvm/Support/Compiler.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -38,7 +39,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index d4f4d1732b2e..d9a69c57e142 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index ee11b22020dc..decdbaef28e1 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -25,8 +25,8 @@ limitations under the License. 
#include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index 1705405d2f32..9fa6f8df78a8 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "jaxlib/mosaic/gpu/passes.h" + #include #include #include @@ -24,10 +25,13 @@ limitations under the License. #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 2c13228d3c51..d0c0c798abb8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -148,6 +148,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:miopen", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", @@ -318,6 +319,7 @@ nanobind_extension( ":hip_vendor", ":hipsparse_kernels", "//jaxlib:absl_status_casters", + "//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -325,12 +327,14 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@nanobind", + "@xla//xla/service:custom_call_status", "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -496,9 +500,11 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/util:env_var", ], ) @@ -536,7 +542,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 347da6998b57..2c2a76f29f9b 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -66,6 +66,7 @@ nanobind_extension( ":xla_compiler", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", 
"@com_google_absl//absl/log:initialize", "@com_google_absl//absl/status", @@ -164,7 +165,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":python_ref_manager", - "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -178,8 +178,8 @@ cc_library( "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:nb_numpy", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -319,13 +319,13 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@nanobind", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:statusor", "@xla//xla/pjrt:status_casters", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:attribute_map", "@xla//xla/python/ifrt_proxy/client:grpc_client", "@xla//xla/python/ifrt_proxy/client:registry", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:statusor", ], ) @@ -752,7 +752,9 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", "@nanobind", + "@tsl//tsl/platform:casts", "@xla//xla:util", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:status_casters", diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index bb238e6991ec..6f5644c3b0c7 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/jaxlib/xla/callback.h b/jaxlib/xla/callback.h index ebd0aaca4e6d..ee1f35ce34a3 100644 --- a/jaxlib/xla/callback.h +++ b/jaxlib/xla/callback.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/jaxlib/xla/config.cc b/jaxlib/xla/config.cc index 82f0bd0b0f5a..c68ff7f4ac54 100644 --- a/jaxlib/xla/config.cc +++ b/jaxlib/xla/config.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/xla/custom_call_sharding.cc index f88bc93e3af3..3cb53b438e09 100644 --- a/jaxlib/xla/custom_call_sharding.cc +++ b/jaxlib/xla/custom_call_sharding.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "jaxlib/xla/custom_call_sharding.h" #include diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index 6c4c24bfe10e..d1cb91114b05 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "include/dlpack/dlpack.h" #include "llvm/Support/Casting.h" @@ -45,7 +46,6 @@ limitations under the License. 
#include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" -#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/xla/ifrt_proxy.cc index eda57be86ba5..a89941f8581c 100644 --- a/jaxlib/xla/ifrt_proxy.cc +++ b/jaxlib/xla/ifrt_proxy.cc @@ -42,8 +42,8 @@ #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt_proxy/client/registry.h" -#include "tsl/platform/env.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" namespace nb = ::nanobind; diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index a2e6d725f3b0..9eba2e9d3228 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -40,7 +40,7 @@ limitations under the License. #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/python/nb_helpers.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc index 5905c6c6ec8d..987856daa983 100644 --- a/jaxlib/xla/mlir.cc +++ b/jaxlib/xla/mlir.cc @@ -24,10 +24,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/OwningOpRef.h" @@ -44,7 +41,6 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/status_casters.h" #include "xla/python/refine_polymorphic_shapes.h" -#include "xla/service/llvm_ir/llvm_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h index e02311e03c73..2bad85e59671 100644 --- a/jaxlib/xla/pmap_lib.h +++ b/jaxlib/xla/pmap_lib.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_PMAP_LIB_H_ #define JAXLIB_XLA_PMAP_LIB_H_ -#include -#include -#include // placeholder for index annotation headers #include "nanobind/nanobind.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index 1325f0cbd2bc..a1937bc80327 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -75,7 +75,6 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_future.h" -#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h index 645f51096c1d..7fa2434c7c9f 100644 --- a/jaxlib/xla/py_array.h +++ b/jaxlib/xla/py_array.h @@ -30,6 +30,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 795a4fee29fa..1e41d9cf8a0d 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -48,7 +48,6 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/callback.h" #include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" @@ -83,7 +82,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/types.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep #include "xla/shape.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h index 898a40141307..29a506d48864 100644 --- a/jaxlib/xla/py_client.h +++ b/jaxlib/xla/py_client.h @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h index 6071ede52305..4e74992fb2ee 100644 --- a/jaxlib/xla/py_device.h +++ b/jaxlib/xla/py_device.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/xla/py_device_list.cc index 300e477dbbd0..205c971b9317 100644 --- a/jaxlib/xla/py_device_list.cc +++ b/jaxlib/xla/py_device_list.cc @@ -38,9 +38,7 @@ limitations under the License. #include "jaxlib/xla/python_ref_manager.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_helpers.h" #include "xla/python/types.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" namespace jax { diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h index 1d0f64003f8c..0fa9b3965dfe 100644 --- a/jaxlib/xla/py_device_list.h +++ b/jaxlib/xla/py_device_list.h @@ -26,7 +26,6 @@ limitations under the License. #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device_list.h" -#include "xla/tsl/concurrency/ref_count.h" namespace jax { diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h index 688eb779df8d..804682db717e 100644 --- a/jaxlib/xla/py_executable.h +++ b/jaxlib/xla/py_executable.h @@ -26,9 +26,9 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" @@ -37,10 +37,8 @@ limitations under the License. 
#include "jaxlib/xla/py_client.h" #include "jaxlib/xla/traceback.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/xla/py_memory_space.h index f111263497fb..f38038af4870 100644 --- a/jaxlib/xla/py_memory_space.h +++ b/jaxlib/xla/py_memory_space.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index 05397cdf116f..55d84fd71bb7 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" +#include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/array.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep @@ -61,6 +62,7 @@ limitations under the License. #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" +#include "tsl/platform/casts.h" namespace aux { diff --git a/jaxlib/xla/python_ref_manager.cc b/jaxlib/xla/python_ref_manager.cc index a19622d94244..5b85d2ab84cb 100644 --- a/jaxlib/xla/python_ref_manager.cc +++ b/jaxlib/xla/python_ref_manager.cc @@ -15,6 +15,8 @@ limitations under the License. #include "jaxlib/xla/python_ref_manager.h" +#include + #include #include #include diff --git a/jaxlib/xla/sharded_device_array.h b/jaxlib/xla/sharded_device_array.h index 1b0ca20aa1fc..6e014789a289 100644 --- a/jaxlib/xla/sharded_device_array.h +++ b/jaxlib/xla/sharded_device_array.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/types/variant.h" #include "nanobind/nanobind.h" #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "xla/python/types.h" diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 409dddb62268..5a80c03e01da 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -38,7 +38,6 @@ limitations under the License. #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/nb_numpy.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index dac18a4160b5..698ff2ca9ca8 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -30,8 +30,9 @@ limitations under the License. #include "jaxlib/xla/sharded_device_array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/nb_numpy.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace jax { diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc index 116ead49ad23..2a7c6707e766 100644 --- a/jaxlib/xla/to_ifrt_sharding.cc +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" namespace xla { diff --git a/jaxlib/xla/to_ifrt_sharding.h b/jaxlib/xla/to_ifrt_sharding.h index 0fa7f17c4563..ebc999888297 100644 --- a/jaxlib/xla/to_ifrt_sharding.h +++ b/jaxlib/xla/to_ifrt_sharding.h @@ -16,12 +16,18 @@ limitations under the License. #ifndef JAXLIB_XLA_TO_IFRT_SHARDING_H_ #define JAXLIB_XLA_TO_IFRT_SHARDING_H_ +#include +#include +#include + +#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/tsl/platform/statusor.h" namespace xla { diff --git a/jaxlib/xla/traceback.h b/jaxlib/xla/traceback.h index 953d626439c4..685ecc5f8793 100644 --- a/jaxlib/xla/traceback.h +++ b/jaxlib/xla/traceback.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef JAXLIB_XLA_TRACEBACK_H_ #define JAXLIB_XLA_TRACEBACK_H_ -#include +#include + #include #include #include diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 668c96869479..e460a1773e94 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -61,12 +62,12 @@ limitations under the License. #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/topology.h" -#include "xla/python/version.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" -#include "xla/tsl/concurrency/ref_count.h" +#include "xla/python/version.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT #if defined(__linux__) @@ -119,7 +120,6 @@ limitations under the License. #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/ops.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc index 0098cc28160d..bea3062c64e4 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla/xla_compiler.cc @@ -56,6 +56,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" From d8fc40f121d59019a31e867b8b1a97a837c15414 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 27 Mar 2025 18:51:48 +0000 Subject: [PATCH 0230/1769] allow saved_input_vjp functions to be jit inputs/outputs --- jax/_src/api.py | 4 ++-- tests/api_test.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4626714b5399..692f049b5f0c 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2053,8 +2053,8 @@ def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore for r in residuals] - f_vjp = Partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, out_tree(), - jaxpr, opaque_residuals) + f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, + out_tree(), jaxpr), opaque_residuals) if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) diff --git a/tests/api_test.py b/tests/api_test.py index 9d80b5fbed74..6a970051d56e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -11520,6 +11520,25 @@ def f(x, y): self.assertAllClose(y, 6.) self.assertAllClose(arg_cts, (3., 2.)) + def test_basic_pass_through_jit(self): + def f(x, y): + return x * y + + @jax.jit + def g(): + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + return y, f_vjp + + @jax.jit + def h(f_vjp): + return f_vjp(1., 2., 3.) + + y, f_vjp = g() + arg_cts = h(f_vjp) + self.assertAllClose(y, 6.) 
+ self.assertAllClose(arg_cts, (3., 2.)) + def test_basic_unused(self): f = jnp.sin primals = 3., From b02b1fe09267a5d4e7819763d6f9303d9d3a35e8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 27 Mar 2025 12:49:31 -0700 Subject: [PATCH 0231/1769] Update Windows bazelrc configs to ltsc2022 PiperOrigin-RevId: 741249289 --- .bazelrc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index 76f72b0848a9..2d38dcc87044 100644 --- a/.bazelrc +++ b/.bazelrc @@ -260,8 +260,8 @@ build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true -build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" -build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain" +build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE build:ci_windows_amd64 --color=yes @@ -329,9 +329,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst build:rbe_windows_amd64 --config=rbe # Set the host, execution, and target platform -build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe build:rbe_windows_amd64 --enable_runfiles From 358c55d06650fc9cea39943f6e91f78219d52eb8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 12:50:44 -0700 Subject: [PATCH 0232/1769] Update instructions for usage of `:build_jaxlib=false` flag. By adding [jax wheel testing](https://github.com/jax-ml/jax/pull/27113) functionality, we need to have pre-built jax and jaxlib wheels. PiperOrigin-RevId: 741249718 --- build/build.py | 5 ++++- docs/developer.md | 16 ++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/build/build.py b/build/build.py index 4d16851f837c..1900073fc132 100755 --- a/build/build.py +++ b/build/build.py @@ -414,7 +414,10 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - if not args.use_new_wheel_build_rule or args.command == "requirements_update": + if ( + not hasattr(args,"use_new_wheel_build_rule") + or args.command == "requirements_update" + ): bazel_command_base.append("run") else: bazel_command_base.append("build") diff --git a/docs/developer.md b/docs/developer.md index 0affbba9ed36..b1a978ffd0d6 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,7 +1,7 @@ (building-from-source)= # Building from source - + First, obtain the JAX source code: @@ -526,23 +526,27 @@ bazel test //tests:cpu_tests //tests:backend_independent_tests `//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. 
-To use a preinstalled `jaxlib` instead of building it you first need to -make it available in the hermetic Python. To install a specific version of -`jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): +To use the preinstalled `jax` and `jaxlib` instead of building them you first +need to make them available in the hermetic Python. To install the specific +versions of `jax` and `jaxlib` within hermetic Python run (using `jax >= 0.4.26` +and `jaxlib >= 0.4.26` as an example): ``` +echo -e "\njax >= 0.4.26" >> build/requirements.in echo -e "\njaxlib >= 0.4.26" >> build/requirements.in python build/build.py requirements_update ``` -Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): +Alternatively, to install `jax` and `jaxlib` from the local wheels +(assuming Python 3.12): ``` +echo -e "\n$(realpath jax-0.4.26-py3-none-any.whl)" >> build/requirements.in echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in python build/build.py requirements_update --python_version=3.12 ``` -Once you have `jaxlib` installed hermetically, run: +Once you have `jax` and `jaxlib` installed hermetically, run: ``` bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests From b290c132dd5b28e11e4d17f495b91a9bc8e88eac Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 27 Mar 2025 13:03:42 -0700 Subject: [PATCH 0233/1769] [jax:custom_partitioning] Raise an error when Shardy is used but the old sharding propagation callbacks instead of sharding rule are provided. PiperOrigin-RevId: 741253832 --- jax/_src/custom_partitioning.py | 6 +++++ tests/cache_key_test.py | 3 ++- tests/pjit_test.py | 39 +++++++++++++++++++++++++++++++++ tests/shard_map_test.py | 1 + 4 files changed, 48 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 5374071517f1..feb1e0c39cc6 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -500,6 +500,12 @@ def __call__(self, *args, **kwargs): infer_sharding_from_operands = None sharding_rule = None if config.use_shardy_partitioner.value: + if (self.sharding_rule is None and + (self.propagate_user_sharding is not None or + self.infer_sharding_from_operands is not None)): + raise ValueError("Shardy is used, but sharding propagation callbacks " + "instead of sharding_rule are provided. 
Need to " + "provide sharding_rule to migrate to Shardy.") sharding_rule = self.sharding_rule else: propagate_user_sharding = self.propagate_user_sharding diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index ed80c7060e4c..a908d260d560 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -180,7 +180,8 @@ def _cp_add(x, y): _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, - partition=_partition) + partition=_partition, + sharding_rule='i i -> i') devices = np.asarray(jax.devices()) with Mesh(devices, ('x',)) as m: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index aa5afccb38d4..2cfe61cdf1fe 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -8207,5 +8207,44 @@ def f(x, y, static_arg0=1, static_arg1=2): self.assertArraysEqual(result, expected_result) self.assertEqual(result.sharding, NamedSharding(mesh, P(None, None, 'x'))) + def test_custom_partition_shardy_migration(self): + if jtu.is_cloud_tpu(): + raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + return ( + mesh, + lower_fn, + arg_shapes[0].sharding, + (arg_shapes[0].sharding,), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + ) + + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + x = jax.device_put(np.arange(32 * 16).reshape(32, 16), + NamedSharding(mesh, P(None, 'x'))) + with self.assertRaisesRegex(ValueError, "provide sharding_rule to migrate " + "to Shardy"): + jax.jit(f)(x) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ce01b6e6e944..1ffb3e1d137a 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3105,6 +3105,7 @@ def f(x): infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, propagate_user_sharding=propagate_user_sharding, + sharding_rule='i -> i', ) @jax.jit From 591c327e613507d1d4bb9706d4f353c2d4835eba Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 14:09:15 -0700 Subject: [PATCH 0234/1769] Remove unused build dependencies in jaxlib/xla/... 
PiperOrigin-RevId: 741276224 --- jaxlib/xla/BUILD | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 2c2a76f29f9b..2ca18afda13d 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -268,7 +268,6 @@ cc_library( "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:pjrt_common", "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/pjrt:pjrt_layout", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -360,7 +359,6 @@ cc_library( "@xla//xla/pjrt:status_casters", "@xla//xla/python:nb_absl_inlined_vector", "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_helpers", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", ], @@ -383,8 +381,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:BytecodeWriter", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", @@ -398,7 +394,6 @@ cc_library( "@xla//xla/pjrt:mlir_to_hlo", "@xla//xla/pjrt:status_casters", "@xla//xla/python:refine_polymorphic_shapes", - "@xla//xla/service/llvm_ir:llvm_util", "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/platform:logging", "@xla//xla/tsl/platform:statusor", @@ -547,18 +542,15 @@ cc_library( features = ["-use_header_modules"], visibility = jax_visibility("jaxlib/xla/py_client"), deps = [ - ":callback", ":guard_lib", ":nb_class_ptr", ":py_client_cpu", ":py_host_callback", - ":py_host_callback_cc_proto", ":python_ref_manager", ":traceback", ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", - "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -571,56 +563,36 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@nanobind", - "@shardy//shardy/dialect/sdy/ir:dialect", - "@tsl//tsl/platform:casts", "@tsl//tsl/platform:fingerprint", "@tsl//tsl/platform:ml_dtypes", - "@tsl//tsl/profiler/lib:profiler_session", "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:comparison_util", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla:status_macros", "@xla//xla:types", "@xla//xla:util", "@xla//xla:xla_data_proto_cc", - "@xla//xla/hlo/builder:xla_builder", - "@xla//xla/hlo/builder:xla_computation", - "@xla//xla/hlo/builder/lib:arithmetic", "@xla//xla/hlo/ir:hlo", "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:host_callback", - "@xla//xla/pjrt:host_memory_spaces", "@xla//xla/pjrt:lru_cache", "@xla//xla/pjrt:mlir_to_hlo", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:pjrt_common", "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/pjrt:pjrt_device_description", "@xla//xla/pjrt:pjrt_executable", "@xla//xla/pjrt:pjrt_future", "@xla//xla/pjrt:pjrt_layout", "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt:transpose", - "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/distributed:client", - "@xla//xla/python:aggregate_profile", 
"@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:types", - "@xla//xla/python:xplane_to_profile_instructions", "@xla//xla/python/compile_only_ifrt:client", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:attribute_map", @@ -630,20 +602,11 @@ cc_library( "@xla//xla/python/ifrt:user_context", "@xla//xla/python/ifrt/hlo:hlo_program", "@xla//xla/python/pjrt_ifrt", - "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "@xla//xla/python/pjrt_ifrt:pjrt_dtype", - "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", "@xla//xla/python/pjrt_ifrt:xla_ifrt", - "@xla//xla/service:computation_placer_hdr", - "@xla//xla/service:custom_call_status", - "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", - "@xla//xla/service/spmd/shardy:constants", - "@xla//xla/service/spmd/shardy:utils", - "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/framework:allocator", - "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "@xla//xla/tsl/platform:env", "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/platform:logging", @@ -700,7 +663,6 @@ cc_library( ":py_host_callback_cc_proto", ":python_ref_manager", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -713,7 +675,6 @@ cc_library( "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:host_callback", - "@xla//xla/pjrt:pjrt_compiler", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -852,7 +813,6 @@ cc_library( "@llvm-project//mlir:Support", "@nanobind", "@shardy//shardy/dialect/sdy/ir:dialect", - "@shardy//shardy/dialect/sdy/transforms/import:passes", "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@xla//xla/mlir_hlo:all_passes", "@xla//xla/pjrt:mlir_to_hlo", From 71b36dca8406538898df5e61b21ba29f6ef79ad5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 27 Mar 2025 14:42:40 -0700 Subject: [PATCH 0235/1769] Sort the replicated_axes wrt mesh names in Shardy PiperOrigin-RevId: 741287495 --- jax/_src/sharding_impls.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index efa1b4cfd5b6..d95b12f244ba 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -105,6 +105,9 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): if not d.axes and d.is_closed else d) used_axes.extend(d.axes) remaining_axes = set(mesh.axis_names) - set(used_axes) + # Sort wrt mesh axis names so order is deterministic and doesn't hang in + # McJAX. + remaining_axes = [n for n in mesh.axis_names if n in remaining_axes] replicated_axes = tuple(r for r in remaining_axes if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings, From d08676e927e4917a629d609ac80770d208187f9a Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Thu, 27 Mar 2025 16:48:28 -0700 Subject: [PATCH 0236/1769] Disable `lax_numpy_test` tsan tests. 
PiperOrigin-RevId: 741325580 --- tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/BUILD b/tests/BUILD index 1d021b0c7110..2526be066635 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -496,6 +496,7 @@ jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { + "tpu": ["notsan"], # Test times out. "cpu": ["notsan"], # Test times out. }, shard_count = { From 25c106d132d01856ac3e1ad40b7ff52c65fafc4c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 27 Mar 2025 16:55:45 -0700 Subject: [PATCH 0237/1769] Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add `standard_insert_broadcast` for unary ops though) * slicing.py * windowed_reductions.py * special.py * convolution.py * fft.py * linalg.py * ann.py PiperOrigin-RevId: 741327361 --- jax/_src/ad_util.py | 1 + jax/_src/core.py | 15 +++- jax/_src/lax/ann.py | 6 +- jax/_src/lax/control_flow/loops.py | 3 +- jax/_src/lax/convolution.py | 4 +- jax/_src/lax/fft.py | 4 +- jax/_src/lax/lax.py | 116 ++++++---------------------- jax/_src/lax/linalg.py | 10 ++- jax/_src/lax/slicing.py | 42 +++++++--- jax/_src/lax/special.py | 8 ++ jax/_src/lax/windowed_reductions.py | 29 +++++-- 11 files changed, 123 insertions(+), 115 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index c729a57cfb11..c8e64ce5c2ef 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,6 +31,7 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + x, y = core.standard_insert_pbroadcast(x, y) return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') diff --git a/jax/_src/core.py b/jax/_src/core.py index 3a1558802682..ca353486afd5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2018,6 +2018,16 @@ def standard_insert_pbroadcast(*args): return [pbroadcast(arg, tuple(n for n in out_vma if n not in src)) if out_vma - src else arg for arg, src in zip(args, in_vma)] +def standard_vma_rule(prim_name, *avals, **kwargs): + vma, *vmas = [a.vma for a in avals] + if not all(vma == vma_ for vma_ in vmas): + raise ValueError( + f'Primitive {prim_name} requires varying manual axes ' + f'to match, but got {[vma, *vmas]}. Please open an issue at ' + 'https://github.com/jax-ml/jax/issues and as a temporary ' + 'workaround pass the check_rep=False argument to shard_map') + return vma + # Dynamic shape stuff below here! We keep the abstract values distinct just so # as not to interfere with any static shape machinery. @@ -2697,7 +2707,10 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: # could try normalizing first and then doing simple equality. # TODO(yashkatariya): Also check `sharding` here. 
# See https://github.com/jax-ml/jax/issues/26474 - return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + sh_dt = t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + if config.varying_axes_in_types.value: + return sh_dt and t1.vma == t2.vma # type: ignore + return sh_dt else: return False diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 0e037ec774b5..0d2eb338da22 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -77,6 +77,7 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src import ad_util from jax._src import core from jax._src import dispatch +from jax._src import config from jax._src import dtypes from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -239,9 +240,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, "approx_top_k with aggregate_to_topk=False not yet implemented when " f"either the `k` ({k}) or the " f" reduction dimension size ({reduction_input_size}) are symbolic") + out_vma = operand.vma if config.varying_axes_in_types.value else frozenset() return (operand.update(shape=dims, dtype=operand.dtype, - weak_type=operand.weak_type), - operand.update(shape=dims, dtype=np.dtype(np.int32))) + weak_type=operand.weak_type, vma=out_vma), + operand.update(shape=dims, dtype=np.dtype(np.int32), vma=out_vma)) def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 88af7c24e5b8..c7bcb1cf6b09 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2520,7 +2520,8 @@ def _cumred_dtype_rule(name, operand, *args, **kw): def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn): reducer_p = lax.standard_primitive( _cumred_shape_rule, partial(_cumred_dtype_rule, name), - name, sharding_rule=_cumred_sharding_rule) + name, sharding_rule=_cumred_sharding_rule, + vma_rule=partial(core.standard_vma_rule, name)) batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 290d027cc6bc..32294bbd72cf 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -158,6 +158,7 @@ def conv_general_dilated( preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs) return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), @@ -633,7 +634,8 @@ def _conv_general_dilated_batch_rule( conv_general_dilated_p = lax.standard_primitive( _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule, - 'conv_general_dilated') + 'conv_general_dilated', + vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated')) ad.defbilinear(conv_general_dilated_p, _conv_general_dilated_transpose_lhs, diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 6ca1a4abd193..9044f48f278c 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -23,6 +23,7 @@ from jax import lax +from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct @@ -124,7 +125,8 @@ def fft_abstract_eval(x, fft_type, fft_lengths): f"be equal to fft_lengths {fft_lengths}") 
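The `typematch` change above and the abstract-eval rules elsewhere in this patch (ann.py, fft.py, reduce_window) share one guarded pattern: vma participates in types only when the `varying_axes_in_types` config flag is on. A condensed standalone sketch of that guard (stand-ins only, not the real config or aval types):

    class _Flag:
        value = True   # stand-in for jax._src.config.varying_axes_in_types

    varying_axes_in_types = _Flag()

    def typematch(t1, t2):
        # Types here are (dtype, shape, vma) triples; vma only participates
        # in equality when the flag is on, mirroring the guarded check above.
        sh_dt = t1[0] == t2[0] and t1[1] == t2[1]
        if varying_axes_in_types.value:
            return sh_dt and t1[2] == t2[2]
        return sh_dt

    a = ("f32", (4,), frozenset({"i"}))
    b = ("f32", (4,), frozenset())
    assert not typematch(a, b)           # vma mismatch is now a type mismatch
    varying_axes_in_types.value = False  # legacy behavior: vma ignored
    assert typematch(a, b)
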
shape = x.shape dtype = x.dtype - return x.update(shape=shape, dtype=dtype) + out_vma = x.vma if config.varying_axes_in_types.value else frozenset() + return x.update(shape=shape, dtype=dtype, vma=out_vma) def _fft_lowering(ctx, x, *, fft_type, fft_lengths): if not is_constant_shape(fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 655ef763f1ef..fcd7aba380bb 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -296,7 +296,6 @@ def neg(x: ArrayLike) -> Array: .. _stablehlo.negate: https://openxla.org/stablehlo/spec#negate """ - x, = core.standard_insert_pbroadcast(x) return neg_p.bind(x) @export @@ -340,7 +339,6 @@ def sign(x: ArrayLike) -> Array: .. _stablehlo.sign: https://openxla.org/stablehlo/spec#sign """ - x, = core.standard_insert_pbroadcast(x) return sign_p.bind(x) @export @@ -393,7 +391,6 @@ def floor(x: ArrayLike) -> Array: .. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor """ - x, = core.standard_insert_pbroadcast(x) return floor_p.bind(x) @export @@ -415,7 +412,6 @@ def ceil(x: ArrayLike) -> Array: .. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil """ - x, = core.standard_insert_pbroadcast(x) return ceil_p.bind(x) class RoundingMethod(enum.IntEnum): @@ -465,7 +461,6 @@ def round(x: ArrayLike, .. _stablehlo.round: https://openxla.org/stablehlo/spec#round """ rounding_method = RoundingMethod(rounding_method) - x, = core.standard_insert_pbroadcast(x) return round_p.bind(x, rounding_method=rounding_method) @export @@ -487,7 +482,6 @@ def is_finite(x: ArrayLike) -> Array: .. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite """ - x, = core.standard_insert_pbroadcast(x) return is_finite_p.bind(x) @export @@ -509,7 +503,6 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ - x, = core.standard_insert_pbroadcast(x) return exp_p.bind(x) @export @@ -533,7 +526,6 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - x, = core.standard_insert_pbroadcast(x) return exp2_p.bind(x) @export @@ -557,7 +549,6 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - x, = core.standard_insert_pbroadcast(x) return expm1_p.bind(x) @export @@ -578,7 +569,6 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - x, = core.standard_insert_pbroadcast(x) return log_p.bind(x) @export @@ -602,7 +592,6 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - x, = core.standard_insert_pbroadcast(x) return log1p_p.bind(x) @export @@ -625,7 +614,6 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ - x, = core.standard_insert_pbroadcast(x) return tanh_p.bind(x) @export @@ -645,7 +633,6 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ - x, = core.standard_insert_pbroadcast(x) return logistic_p.bind(x) @export @@ -670,7 +657,6 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ - x, = core.standard_insert_pbroadcast(x) return sin_p.bind(x) @export @@ -695,7 +681,6 @@ def cos(x: ArrayLike) -> Array: .. 
_stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ - x, = core.standard_insert_pbroadcast(x) return cos_p.bind(x) @export @@ -743,7 +728,6 @@ def real(x: ArrayLike) -> Array: .. _stablehlo.real: https://openxla.org/stablehlo/spec#real """ - x, = core.standard_insert_pbroadcast(x) return real_p.bind(x) @export @@ -766,7 +750,6 @@ def imag(x: ArrayLike) -> Array: .. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag """ - x, = core.standard_insert_pbroadcast(x) return imag_p.bind(x) @export @@ -819,7 +802,6 @@ def conj(x: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ # TODO(mattjj): remove input_dtype, not needed anymore - x, = core.standard_insert_pbroadcast(x) return conj_p.bind(x, input_dtype=_dtype(x)) @export @@ -840,7 +822,6 @@ def abs(x: ArrayLike) -> Array: .. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs """ - x, = core.standard_insert_pbroadcast(x) return abs_p.bind(x) @export @@ -888,7 +869,6 @@ def integer_pow(x: ArrayLike, y: int) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - x, = core.standard_insert_pbroadcast(x) return integer_pow_p.bind(x, y=y) @export @@ -910,7 +890,6 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ - x, = core.standard_insert_pbroadcast(x) return sqrt_p.bind(x) @export @@ -933,7 +912,6 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ - x, = core.standard_insert_pbroadcast(x) return rsqrt_p.bind(x) @export @@ -955,7 +933,6 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ - x, = core.standard_insert_pbroadcast(x) return cbrt_p.bind(x) @export @@ -980,7 +957,6 @@ def bitwise_not(x: ArrayLike) -> Array: .. _stablehlo.not: https://openxla.org/stablehlo/spec#not """ - x, = core.standard_insert_pbroadcast(x) return not_p.bind(x) @export @@ -1083,7 +1059,6 @@ def population_count(x: ArrayLike) -> Array: .. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt """ - x, = core.standard_insert_pbroadcast(x) return population_count_p.bind(x) @export @@ -1104,7 +1079,6 @@ def clz(x: ArrayLike) -> Array: .. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros """ - x, = core.standard_insert_pbroadcast(x) return clz_p.bind(x) @export @@ -1623,7 +1597,6 @@ def _convert_element_type( f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") - operand, = core.standard_insert_pbroadcast(operand) if isinstance(new_dtype, dtypes.ExtendedDType): return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) @@ -1662,7 +1635,6 @@ def _convert_element_type( (sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))): return operand else: - operand, = core.standard_insert_pbroadcast(operand) return convert_element_type_p.bind( operand, new_dtype=new_dtype, weak_type=bool(weak_type), sharding=sharding) @@ -1699,7 +1671,6 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: .. 
_stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert """ new_dtype = dtypes.canonicalize_dtype(new_dtype) - operand, = core.standard_insert_pbroadcast(operand) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: @@ -1956,7 +1927,6 @@ def split(operand: ArrayLike, sizes: Sequence[int], taken along ``axis``. """ operand = asarray(operand) - operand, = core.standard_insert_pbroadcast(operand) return split_p.bind(operand, sizes=tuple(sizes), axis=canonicalize_axis(axis, operand.ndim)) @@ -2662,7 +2632,6 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) else: dyn_shape, static_shape = [], shape # type: ignore - operand, = core.standard_insert_pbroadcast(operand) return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), broadcast_dimensions=tuple(broadcast_dimensions), @@ -2729,7 +2698,6 @@ def reshape(operand: ArrayLike, new_sizes: Shape, else: dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) out_sharding = canonicalize_sharding(out_sharding, 'reshape') - operand, = core.standard_insert_pbroadcast(operand) return reshape_p.bind( operand, *dyn_shape, new_sizes=tuple(static_new_sizes), dimensions=None if dims is None or same_dims else dims, @@ -2793,7 +2761,6 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: `_ operator. """ - operand, = core.standard_insert_pbroadcast(operand) return rev_p.bind(operand, dimensions=tuple(dimensions)) def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: @@ -2860,20 +2827,18 @@ def transpose(operand: ArrayLike, if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array): return operand else: - operand, = core.standard_insert_pbroadcast(operand) + return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the minimum element along ``axis``.""" - operand, = core.standard_insert_pbroadcast(operand) return argmin_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) def argmax(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the maximum element along ``axis``.""" - operand, = core.standard_insert_pbroadcast(operand) return argmax_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) @@ -3039,7 +3004,6 @@ def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_sum_p.bind(operand, axes=tuple(axes)) def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3066,7 +3030,6 @@ def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. 
""" - operand, = core.standard_insert_pbroadcast(operand) return reduce_prod_p.bind(operand, axes=tuple(axes)) def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3088,7 +3051,6 @@ def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_max_p.bind(operand, axes=tuple(axes)) def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3110,7 +3072,6 @@ def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_min_p.bind(operand, axes=tuple(axes)) def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3133,7 +3094,6 @@ def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_or_p.bind(operand, axes=tuple(axes)) def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3156,7 +3116,6 @@ def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_and_p.bind(operand, axes=tuple(axes)) def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3179,7 +3138,6 @@ def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`. 
""" - operand, = core.standard_insert_pbroadcast(operand) return reduce_xor_p.bind(operand, axes=tuple(axes)) @overload @@ -3265,7 +3223,6 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k = int(k) if k < 0: raise ValueError(f"k argument to top_k must be nonnegative, got {k}") - operand, = core.standard_insert_pbroadcast(operand) return top_k_p.bind(operand, k=k) def tie_in(x: Any, y: T) -> T: @@ -3451,7 +3408,6 @@ def reduce_precision(operand: float | ArrayLike, operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision") mantissa_bits = core.concrete_or_error( operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision") - operand, = core.standard_insert_pbroadcast(operand) return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) @@ -3461,7 +3417,6 @@ def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions)) if not dimensions and isinstance(array, Array): return array - array, = core.standard_insert_pbroadcast(array) return squeeze_p.bind(array, dimensions=dimensions) def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -3582,7 +3537,6 @@ def batch_matmul(lhs: Array, rhs: Array, def square(x: ArrayLike) -> Array: r"""Elementwise square: :math:`x^2`.""" - x, = core.standard_insert_pbroadcast(x) return square_p.bind(x) def reciprocal(x: ArrayLike) -> Array: @@ -3610,7 +3564,6 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ - x, = core.standard_insert_pbroadcast(x) return tan_p.bind(x) @export @@ -3631,7 +3584,6 @@ def asin(x: ArrayLike) -> Array: - :func:`jax.lax.acos`: elementwise arc cosine. - :func:`jax.lax.atan`: elementwise arc tangent. """ - x, = core.standard_insert_pbroadcast(x) return asin_p.bind(x) @export @@ -3652,7 +3604,6 @@ def acos(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan`: elementwise arc tangent. """ - x, = core.standard_insert_pbroadcast(x) return acos_p.bind(x) @export @@ -3674,7 +3625,6 @@ def atan(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan2`: elementwise 2-term arc tangent. """ - x, = core.standard_insert_pbroadcast(x) return atan_p.bind(x) @export @@ -3695,7 +3645,6 @@ def sinh(x: ArrayLike) -> Array: - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ - x, = core.standard_insert_pbroadcast(x) return sinh_p.bind(x) @export @@ -3716,7 +3665,6 @@ def cosh(x: ArrayLike) -> Array: - :func:`jax.lax.sinh`: elementwise hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ - x, = core.standard_insert_pbroadcast(x) return cosh_p.bind(x) @export @@ -3737,7 +3685,6 @@ def asinh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.sinh`: elementwise hyperbolic sine. """ - x, = core.standard_insert_pbroadcast(x) return asinh_p.bind(x) @export @@ -3758,7 +3705,6 @@ def acosh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. """ - x, = core.standard_insert_pbroadcast(x) return acosh_p.bind(x) @export @@ -3779,7 +3725,6 @@ def atanh(x: ArrayLike) -> Array: - :func:`jax.lax.asinh`: elementwise inverse hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. 
""" - x, = core.standard_insert_pbroadcast(x) return atanh_p.bind(x) @@ -3937,16 +3882,6 @@ def broadcasting_sharding_rule(name, *avals): f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) -def standard_vma_rule(prim_name, *avals, **kwargs): - vma, *vmas = [a.vma for a in avals] - if not all(vma == vma_ for vma_ in vmas): - raise ValueError( - f'Primitive {prim_name} requires varying manual axes ' - f'to match, but got {[vma, *vmas]}. Please open an issue at ' - 'https://github.com/jax-ml/jax/issues and as a temporary ' - 'workaround pass the check_rep=False argument to shard_map') - return vma - def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=True): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, @@ -3956,7 +3891,7 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, sharding_rule = partial(broadcasting_sharding_rule, name) prim = standard_primitive( shape_rule, dtype_rule, name, sharding_rule=sharding_rule, - vma_rule=partial(standard_vma_rule, name)) + vma_rule=partial(core.standard_vma_rule, name)) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -4808,7 +4743,7 @@ def _convert_element_type_bind_with_trace(trace, args, params): _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, _convert_element_type_sharding_rule, - partial(standard_vma_rule, convert_element_type_p.name))) + partial(core.standard_vma_rule, convert_element_type_p.name))) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) @@ -4974,7 +4909,7 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule, 'bitcast_convert_type', weak_type_rule=_strip_weak_type, sharding_rule=_bitcast_convert_type_sharding_rule, - vma_rule=partial(standard_vma_rule, 'bitcast_convert_type')) + vma_rule=partial(core.standard_vma_rule, 'bitcast_convert_type')) ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) @@ -5443,7 +5378,7 @@ def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): _dot_general_dtype_rule, 'dot_general', sharding_rule=_dot_general_sharding_rule, - vma_rule=partial(standard_vma_rule, 'dot_general') + vma_rule=partial(core.standard_vma_rule, 'dot_general') ) @@ -6444,7 +6379,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, new_sharding = _broadcast_in_dim_sharding_rule( x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - new_vma = (standard_vma_rule('broadcast_in_dim', x) + new_vma = (core.standard_vma_rule('broadcast_in_dim', x) if config.varying_axes_in_types.value else frozenset()) return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, vma=new_vma) @@ -6532,7 +6467,7 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp', sharding_rule=_clamp_sharding_rule, - vma_rule=partial(standard_vma_rule, 'clamp')) + vma_rule=partial(core.standard_vma_rule, 'clamp')) ad.defjvp(clamp_p, lambda g, min, operand, max: select(bitwise_and(gt(min, operand), lt(min, max)), @@ -6620,7 +6555,7 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): concatenate_p = 
standard_primitive( _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', sharding_rule=_concatenate_sharding_rule, - vma_rule=partial(standard_vma_rule, 'concatenate')) + vma_rule=partial(core.standard_vma_rule, 'concatenate')) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule @@ -6693,7 +6628,7 @@ def _split_sharding_rule(operand, *, sizes, axis): for out_sh in out_shapes] def _split_vma_rule(operand, *, sizes, axis): - out_vma = standard_vma_rule('split', operand) + out_vma = core.standard_vma_rule('split', operand) out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) return [out_vma] * len(out_shapes) @@ -6785,7 +6720,7 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', sharding_rule=_pad_sharding_rule, - vma_rule=partial(standard_vma_rule, 'pad')) + vma_rule=partial(core.standard_vma_rule, 'pad')) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule @@ -6850,7 +6785,7 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, 'squeeze', sharding_rule=_squeeze_sharding_rule, - vma_rule=partial(standard_vma_rule, 'squeeze')) + vma_rule=partial(core.standard_vma_rule, 'squeeze')) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -7085,7 +7020,7 @@ def _reshape_staging_rule( reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, 'reshape', sharding_rule=_reshape_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reshape')) + vma_rule=partial(core.standard_vma_rule, 'reshape')) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule batching.skippable_batchers[reshape_p] = lambda _: () @@ -7118,7 +7053,7 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev', sharding_rule=_rev_sharding_rule, - vma_rule=partial(standard_vma_rule, 'rev')) + vma_rule=partial(core.standard_vma_rule, 'rev')) ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule @@ -7167,7 +7102,7 @@ def _transpose_lower(ctx, x, *, permutation): transpose_p = standard_primitive( _transpose_shape_rule, _input_dtype, 'transpose', sharding_rule=_transpose_sharding_rule, - vma_rule=partial(standard_vma_rule, 'transpose')) + vma_rule=partial(core.standard_vma_rule, 'transpose')) ad.deflinear2(transpose_p, lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule @@ -7344,7 +7279,7 @@ def _select(offset, cases): select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule, - vma_rule=partial(standard_vma_rule, 'select_n')) + vma_rule=partial(core.standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule @@ -7526,7 +7461,7 @@ def _reduce_op_sharding_rule(operand, *, axes): reduce_sum_p = standard_primitive( _reduce_op_shape_rule, 
partial(_reduce_number_dtype_rule, 'reduce_sum'), 'reduce_sum', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_sum')) + vma_rule=partial(core.standard_vma_rule, 'reduce_sum')) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, @@ -7542,7 +7477,7 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes): reduce_prod_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'), 'reduce_prod', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_prod')) + vma_rule=partial(core.standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, @@ -7563,7 +7498,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_max_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_max', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_max')) + vma_rule=partial(core.standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, @@ -7574,7 +7509,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_min_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_min', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_min')) + vma_rule=partial(core.standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, @@ -7642,14 +7577,14 @@ def _compute_argminmax(value_comparator, get_identity, argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', weak_type_rule=_strip_weak_type, sharding_rule=_argminmax_sharding_rule, - vma_rule=partial(standard_vma_rule, 'argmin')) + vma_rule=partial(core.standard_vma_rule, 'argmin')) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmax', weak_type_rule=_strip_weak_type, sharding_rule=_argminmax_sharding_rule, - vma_rule=partial(standard_vma_rule, 'argmax')) + vma_rule=partial(core.standard_vma_rule, 'argmax')) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) @@ -7673,14 +7608,14 @@ def _reduce_logical_sharding_rule(operand, *, axes): reduce_or_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_or', weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_or')) + vma_rule=partial(core.standard_vma_rule, 'reduce_or')) batching.defreducer(reduce_or_p, _get_bitwise_or_identity) reduce_and_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_and', weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_and')) + vma_rule=partial(core.standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule @@ -7688,7 +7623,7 @@ def 
_reduce_logical_sharding_rule(operand, *, axes): reduce_xor_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_xor', weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_xor')) + vma_rule=partial(core.standard_vma_rule, 'reduce_xor')) batching.defreducer(reduce_xor_p, _get_bitwise_or_identity) @@ -7736,7 +7671,7 @@ def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits): _reduce_precision_shape_rule, partial(unop_dtype_rule, _identity, _float, 'reduce_precision'), name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_precision')) + vma_rule=partial(core.standard_vma_rule, 'reduce_precision')) ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)]) batching.defvectorized(reduce_precision_p) @@ -8368,7 +8303,6 @@ def rng_bit_generator(key, shape, dtype=np.uint32, if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') - key, = core.standard_insert_pbroadcast(key) return tuple( rng_bit_generator_p.bind( key, shape=shape, dtype=dtype, algorithm=algorithm, diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index b455257e107c..d53d54599527 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -121,6 +121,7 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: A new upper-triangular matrix :math:`R` defining the Cholesky decomposition of :math:`A + w \, w^T`. """ + r_matrix, w_vector = core.standard_insert_pbroadcast(r_matrix, w_vector) return cholesky_update_p.bind(r_matrix, w_vector) @@ -268,6 +269,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array: A batch of orthogonal (unitary) matrices with the same shape as ``a``, containing the products of the elementary Householder reflectors. """ + a, taus = core.standard_insert_pbroadcast(a, taus) return householder_product_p.bind(a, taus) @@ -545,6 +547,7 @@ def symmetric_product( ``symmetrize_output`` is ``True``, the upper triangle is filled with the transpose of the lower triangle, and the whole matrix is valid. """ + a_matrix, c_matrix = core.standard_insert_pbroadcast(a_matrix, c_matrix) result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) if symmetrize_output: upper_half = lax.transpose( @@ -602,6 +605,7 @@ def triangular_solve( singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: b = lax.expand_dims(b, (-1 if left_side else -2,)) + a, b = core.standard_insert_pbroadcast(a, b) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -661,6 +665,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: Returns: Solution ``X`` of tridiagonal system. 
""" + dl, d, du, b = core.standard_insert_pbroadcast(dl, d, du, b) return tridiagonal_solve_p.bind(dl, d, du, b) @@ -742,7 +747,7 @@ def linalg_sharding_rule( def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): output_shapes = shape_rule(*avals, **kwargs) - out_vma = lax_internal.standard_vma_rule(name, *avals) + out_vma = core.standard_vma_rule(name, *avals) if multiple_results: return [out_vma] * len(output_shapes) else: @@ -775,7 +780,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, sharding_rule, - partial(lax_internal.standard_vma_rule, name))) + partial(core.standard_vma_rule, name))) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) @@ -1768,6 +1773,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *, elementary Householder reflectors, and ``jpvt`` is the column-pivot indices such that ``a[:, jpvt] = q @ r``. """ + a, jpvt = core.standard_insert_pbroadcast(a, jpvt) a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma) return a_out, jpvt_out, taus diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index d3bcb6da2807..3f4d1b6d576f 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -234,6 +234,7 @@ def dynamic_update_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) + operand, update = core.standard_insert_pbroadcast(operand, update) return dynamic_update_slice_p.bind(operand, update, *start_indices) @@ -416,6 +417,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, raise ValueError(f"Unsupported dtype for gather fill_value {dtype}") else: fill_value = None + operand, start_indices = core.standard_insert_pbroadcast(operand, start_indices) return gather_p.bind( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=core.canonicalize_shape(slice_sizes), @@ -505,6 +507,8 @@ def scatter_add( """ jaxpr, consts = lax._reduction_jaxpr(lax.add, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_add_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -559,6 +563,8 @@ def scatter_sub( jaxpr, consts = lax._reduction_jaxpr( lax.sub, core.get_aval(lax._const(operand, 0)) ) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_sub_p.bind( operand, scatter_indices, @@ -613,6 +619,8 @@ def scatter_mul( """ jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(lax._const(operand, 1))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_mul_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -660,6 +668,8 @@ def scatter_min( """ jaxpr, consts = lax._reduction_jaxpr(lax.min, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_min_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -707,6 +717,8 @@ def scatter_max( """ jaxpr, consts = lax._reduction_jaxpr(lax.max, 
core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_max_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -854,6 +866,8 @@ def scatter( ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) Array([0., 2., 3., 0., 4.], dtype=float32) """ + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=None, update_consts=(), dimension_numbers=dimension_numbers, @@ -1393,7 +1407,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, return out, bdim slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', - sharding_rule=_slice_sharding_rule) + sharding_rule=_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'slice')) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries @@ -1559,7 +1574,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', weak_type_rule=_argnum_weak_type(0), - sharding_rule=_dynamic_slice_sharding_rule) + sharding_rule=_dynamic_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_slice')) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule @@ -1679,7 +1695,8 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_update_slice')) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -2117,7 +2134,8 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *, gather_p = standard_primitive( _gather_shape_rule, _gather_dtype_rule, 'gather', - weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'gather')) ad.defjvp(gather_p, _gather_jvp_rule, None) ad.primitive_transposes[gather_p] = _gather_transpose_rule batching.primitive_batchers[gather_p] = _gather_batching_rule @@ -2599,7 +2617,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_add')) ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) batching.primitive_batchers[scatter_add_p] = ( @@ -2610,6 +2629,7 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, _scatter_dtype_rule, "scatter-sub", 
weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_sub') ) ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) @@ -2619,7 +2639,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_mul')) def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, indices_are_sorted, unique_indices, mode, **kw): @@ -2748,14 +2769,16 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_min_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_min')) batching.primitive_batchers[scatter_min_p] = ( partial(_scatter_batching_rule, scatter_min_p)) ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p) scatter_max_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_max')) batching.primitive_batchers[scatter_max_p] = ( partial(_scatter_batching_rule, scatter_max_p)) ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p) @@ -2913,7 +2936,8 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, scatter_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter')) ad.primitive_jvps[scatter_p] = _scatter_jvp ad.primitive_transposes[scatter_p] = _scatter_transpose_rule batching.primitive_batchers[scatter_p] = ( diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index ba2687d4acd7..041205156d58 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -21,6 +21,7 @@ import numpy as np from functools import partial +from jax._src import core from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or, broadcast_in_dim, broadcast_shapes, convert_element_type, div, eq, exp, full_like, ge, @@ -39,6 +40,7 @@ def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" + a, b, x = core.standard_insert_pbroadcast(a, b, x) return regularized_incomplete_beta_p.bind(a, b, x) def lgamma(x: ArrayLike) -> Array: @@ -51,26 +53,32 @@ def digamma(x: ArrayLike) -> Array: def polygamma(m: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`.""" + m, x = core.standard_insert_pbroadcast(m, x) return polygamma_p.bind(m, x) def igamma(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete gamma function.""" + a, x = core.standard_insert_pbroadcast(a, x) return igamma_p.bind(a, x) def igammac(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise complementary regularized incomplete gamma function.""" + a, x = core.standard_insert_pbroadcast(a, x) return igammac_p.bind(a, x) def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of the regularized incomplete gamma function.""" + a, x = core.standard_insert_pbroadcast(a, x) return 
igamma_grad_a_p.bind(a, x)

 def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array:
   r"""Elementwise derivative of samples from `Gamma(a, 1)`."""
+  a, x = core.standard_insert_pbroadcast(a, x)
   return random_gamma_grad_p.bind(a, x)

 def zeta(x: ArrayLike, q: ArrayLike) -> Array:
   r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`"""
+  x, q = core.standard_insert_pbroadcast(x, q)
   return zeta_p.bind(x, q)

 def bessel_i0e(x: ArrayLike) -> Array:
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 42b2e9278889..00bdfe75f3e7 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -21,6 +21,7 @@
 from jax import tree_util
 from jax._src import api_util
 from jax._src import core
+from jax._src import config
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import util
@@ -97,6 +98,7 @@ def _reduce_window(
     raise ValueError(
         'reduce_window output must have the same tree structure as the operands'
         f' {operand_tree} vs. {out_tree}')
+  flat_operands = core.standard_insert_pbroadcast(*flat_operands)
   out_flat = reduce_window_p.bind(
       *flat_operands,
       *flat_init_values,
@@ -250,6 +252,8 @@ def _select_and_scatter(operand: Array, select: Callable,
       select, core.get_aval(init_value))
   scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
       scatter, core.get_aval(init_value))
+  operand, source, init_value = core.standard_insert_pbroadcast(
+      operand, source, init_value)
   return select_and_scatter_p.bind(
       operand, source, init_value, select_jaxpr=select_jaxpr,
       select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
@@ -261,6 +265,7 @@ def _select_and_scatter_add(source: Array, operand: Array,
                             window_dimensions: core.Shape,
                             window_strides: Sequence[int],
                             padding: Sequence[tuple[int, int]]) -> Array:
+  source, operand = core.standard_insert_pbroadcast(source, operand)
   return select_and_scatter_add_p.bind(
       source, operand, select_prim=select_prim,
       window_dimensions=tuple(window_dimensions),
@@ -296,6 +301,7 @@ def _select_and_gather_add(tangents: Array, operand: Array,
     An array containing the elements in `tangents` corresponding to the output
     of the reduction of `operand` in each window.
""" + tangents, operand = core.standard_insert_pbroadcast(tangents, operand) return select_and_gather_add_p.bind( tangents, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -332,7 +338,10 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding) + out_vma = (core.standard_vma_rule('reduce_window', operand_avals) + if config.varying_axes_in_types.value else frozenset()) + return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, + vma=out_vma) for op in operand_avals) @@ -532,7 +541,8 @@ def reduce_window_sharding_rule(operand, window_dimensions, window_strides, reduce_window_sum_p = lax.standard_primitive( _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_sum')) ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule) batching.primitive_batchers[reduce_window_sum_p] = partial( _reduce_window_batch_rule, _reduce_window_sum) @@ -598,7 +608,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_max_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_max')) ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, lax.max_p)) batching.primitive_batchers[reduce_window_max_p] = partial( @@ -606,7 +617,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_min_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_min')) ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, lax.min_p)) @@ -671,7 +683,8 @@ def _select_and_scatter_shape_rule( return operand.shape select_and_scatter_p = lax.standard_primitive( - _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter') + _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter', + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter')) def _select_and_scatter_lower( ctx, operand, source, init_value, *, select_jaxpr, @@ -766,7 +779,8 @@ def _select_and_scatter_add_batch_rule( select_and_scatter_add_p = lax.standard_primitive( _select_and_scatter_add_shape_rule, lax._input_dtype, - 'select_and_scatter_add') + 'select_and_scatter_add', + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter_add')) ad.primitive_transposes[select_and_scatter_add_p] = \ _select_and_scatter_add_transpose @@ -1039,7 +1053,8 @@ def _select_and_gather_add_batching_rule( select_and_gather_add_p = lax.standard_primitive( _select_and_gather_add_shape_rule, lax._input_dtype, - 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule) + 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_gather_add')) ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp ad.primitive_transposes[select_and_gather_add_p] = \ 
_select_and_gather_add_transpose

From a52f7b26e7d5b2696a73a150518441204a2d9565 Mon Sep 17 00:00:00 2001
From: Rachel Han
Date: Thu, 27 Mar 2025 17:12:08 -0700
Subject: [PATCH 0238/1769] Add accuracy field to unary ops

* Cbrt
* Cos
* Exp, Exp2
* Expm1
* Log
* Logistic
* Log1p
* Rsqrt
* Sin
* Sqrt
* Tan
* Tanh

which allows users to select an implementation that will satisfy the
requested accuracy.

PiperOrigin-RevId: 741331787
---
 jax/_src/api.py                               |  13 +-
 jax/_src/internal_test_util/test_harnesses.py |  26 +-
 jax/_src/lax/lax.py                           | 248 +++++++++---
 jax/_src/pallas/mosaic/lowering.py            |  44 ++-
 jax/_src/pallas/mosaic_gpu/lowering.py        |  24 +-
 jax/_src/pallas/triton/lowering.py            |   6 +-
 jax/experimental/jax2tf/jax2tf.py             |  27 +-
 jax/experimental/jet.py                       |  18 +-
 tests/BUILD                                   |  14 +
 tests/api_test.py                             | 154 ++++----
 tests/core_test.py                            |   8 +-
 tests/pallas/ops_test.py                      |  29 +-
 tests/pmap_test.py                            |  16 +-
 tests/unary_ops_accuracy_test.py              | 373 ++++++++++++++++++
 14 files changed, 782 insertions(+), 218 deletions(-)
 create mode 100644 tests/unary_ops_accuracy_test.py

diff --git a/jax/_src/api.py b/jax/_src/api.py
index 692f049b5f0c..e01bdd4a9d81 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2269,13 +2269,16 @@ def make_jaxpr(
   >>> print(f(3.0))
   -0.83602
   >>> jax.make_jaxpr(f)(3.0)
-  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
+  { lambda ; a:f32[]. let
+      b:f32[] = cos[accuracy=None] a
+      c:f32[] = sin[accuracy=None] b
+    in (c,) }
   >>> jax.make_jaxpr(jax.grad(f))(3.0)
   { lambda ; a:f32[]. let
-      b:f32[] = cos a
-      c:f32[] = sin a
-      _:f32[] = sin b
-      d:f32[] = cos b
+      b:f32[] = cos[accuracy=None] a
+      c:f32[] = sin[accuracy=None] a
+      _:f32[] = sin[accuracy=None] b
+      d:f32[] = cos[accuracy=None] b
       e:f32[] = mul 1.0 d
       f:f32[] = neg e
       g:f32[] = mul f c
diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py
index 02779c85977e..b557434ac7f3 100644
--- a/jax/_src/internal_test_util/test_harnesses.py
+++ b/jax/_src/internal_test_util/test_harnesses.py
@@ -408,11 +408,11 @@ def parameterized(harnesses: Iterable[Harness],
 ###############################################################################

-def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
+def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs):
   define(
       str(prim),
       f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
-      prim.bind, [RandArg(shape, dtype)],
+      lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)],
       prim=prim,
       dtype=dtype,
       shape=shape)
@@ -429,19 +429,19 @@ def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
   _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
   _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
   _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
+  _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None)
   _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
+  _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype,
accuracy=None) + _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None) for dtype in jtu.dtypes.all_floating: _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index fcd7aba380bb..b79c81e19195 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -484,14 +484,41 @@ def is_finite(x: ArrayLike) -> Array: """ return is_finite_p.bind(x) +class Tolerance: + """Specify the tolerances used for computing unary functions. + + Maximum two tolerances can be specified: (atol and rtol) or (atol and ulps). + """ + + def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0): + if atol < 0.0 or rtol < 0.0 or ulps < 0.0: + raise ValueError('Tolerances must be non-negative.') + if atol == 0.0 and rtol == 0.0 and ulps == 0: + raise ValueError('At least one of atol, rtol, or ulps must be set.') + + self.atol = atol + self.rtol = rtol + self.ulps = ulps + + +class AccuracyMode(enum.Enum): + HIGHEST = 1 + DEFAULT = 2 + @export -def exp(x: ArrayLike) -> Array: +def exp(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise exponential: :math:`e^x`. This function lowers directly to the `stablehlo.exponential`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -503,10 +530,10 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ - return exp_p.bind(x) + return exp_p.bind(x, accuracy=accuracy) -@export -def exp2(x: ArrayLike) -> Array: + +def exp2(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise base-2 exponential: :math:`2^x`. This function is implemented in terms of the `stablehlo.exponential`_ @@ -514,6 +541,12 @@ def exp2(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. 
Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -526,10 +559,10 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - return exp2_p.bind(x) + return exp2_p.bind(x, accuracy=accuracy) @export -def expm1(x: ArrayLike) -> Array: +def expm1(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`e^{x} - 1`. This function lowers directly to the `stablehlo.exponential_minus_one`_ @@ -538,6 +571,12 @@ def expm1(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -549,16 +588,22 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - return expm1_p.bind(x) + return expm1_p.bind(x, accuracy=accuracy) @export -def log(x: ArrayLike) -> Array: +def log(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`. This function lowers directly to the `stablehlo.log`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -569,10 +614,10 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - return log_p.bind(x) + return log_p.bind(x, accuracy=accuracy) @export -def log1p(x: ArrayLike) -> Array: +def log1p(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`\mathrm{log}(1 + x)`. This function lowers directly to the `stablehlo.log_plus_one`_ operation. @@ -581,6 +626,12 @@ def log1p(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -592,16 +643,22 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - return log1p_p.bind(x) + return log1p_p.bind(x, accuracy=accuracy) @export -def tanh(x: ArrayLike) -> Array: +def tanh(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`. This function lowers directly to the `stablehlo.tanh`_ operation. Args: x: input array. Must have floating-point or complex type. 
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+ selects the implementation of the op based on the requested accuracy. If
+ the implementation cannot satisfy the requested tolerance, the
+ compiler will return an error. If a mode is specified and no alternative
+ implementation is available, the default implementation will be
+ used.

 Returns:
 Array of the same shape and dtype as ``x`` containing the element-wise
@@ -614,10 +671,10 @@ def tanh(x: ArrayLike) -> Array:

 .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh
 """
- return tanh_p.bind(x)
+ return tanh_p.bind(x, accuracy=accuracy)

 @export
-def logistic(x: ArrayLike) -> Array:
+def logistic(x: ArrayLike, accuracy=None) -> Array:
 r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.

 There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
@@ -633,10 +691,10 @@ def logistic(x: ArrayLike) -> Array:

 See also:
 - :func:`jax.nn.sigmoid`: an alternative API for this functionality.
 """
- return logistic_p.bind(x)
+ return logistic_p.bind(x, accuracy=accuracy)

 @export
-def sin(x: ArrayLike) -> Array:
+def sin(x: ArrayLike, accuracy=None) -> Array:
 r"""Elementwise sine: :math:`\mathrm{sin}(x)`.

 For floating-point inputs, this function lowers directly to the
@@ -645,6 +703,12 @@

 Args:
 x: input array. Must have floating-point or complex type.
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+ selects the implementation of the op based on the requested accuracy. If
+ the implementation cannot satisfy the requested tolerance, the
+ compiler will return an error. If a mode is specified and no alternative
+ implementation is available, the default implementation will be
+ used.

 Returns:
 Array of the same shape and dtype as ``x`` containing the element-wise
@@ -657,10 +721,10 @@

 .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
 """
- return sin_p.bind(x)
+ return sin_p.bind(x, accuracy=accuracy)

 @export
-def cos(x: ArrayLike) -> Array:
+def cos(x: ArrayLike, accuracy=None) -> Array:
 r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.

 For floating-point inputs, this function lowers directly to the
@@ -669,6 +733,12 @@

 Args:
 x: input array. Must have floating-point or complex type.
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+ selects the implementation of the op based on the requested accuracy. If
+ the implementation cannot satisfy the requested tolerance, the
+ compiler will return an error. If a mode is specified and no alternative
+ implementation is available, the default implementation will be
+ used.

 Returns:
 Array of the same shape and dtype as ``x`` containing the element-wise
@@ -681,7 +751,7 @@

 .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
 """
- return cos_p.bind(x)
+ return cos_p.bind(x, accuracy=accuracy)

 @export
 def atan2(x: ArrayLike, y: ArrayLike) -> Array:
@@ -871,14 +941,21 @@ def integer_pow(x: ArrayLike, y: int) -> Array:
 """
 return integer_pow_p.bind(x, y=y)

+
 @export
-def sqrt(x: ArrayLike) -> Array:
+def sqrt(x: ArrayLike, accuracy=None) -> Array:
 r"""Elementwise square root: :math:`\sqrt{x}`.

 This function lowers directly to the `stablehlo.sqrt`_ operation.

 Args:
 x: Input array. Must have floating or complex dtype.
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+ selects the implementation of the op based on the requested accuracy. If
+ the implementation cannot satisfy the requested tolerance, the
+ compiler will return an error. If a mode is specified and no alternative
+ implementation is available, the default implementation will be
+ used.

 Returns:
 An array of the same shape and dtype as ``x`` containing the square root.
@@ -890,16 +967,22 @@ def sqrt(x: ArrayLike) -> Array:

 .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt
 """
- return sqrt_p.bind(x)
+ return sqrt_p.bind(x, accuracy=accuracy)

 @export
-def rsqrt(x: ArrayLike) -> Array:
+def rsqrt(x: ArrayLike, accuracy=None) -> Array:
 r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.

 This function lowers directly to the `stablehlo.rsqrt`_ operation.

 Args:
 x: Input array. Must have floating or complex dtype.
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+ selects the implementation of the op based on the requested accuracy. If
+ the implementation cannot satisfy the requested tolerance, the
+ compiler will return an error. If a mode is specified and no alternative
+ implementation is available, the default implementation will be
+ used.

 Returns:
 An array of the same shape and dtype as ``x`` containing the
@@ -912,16 +995,22 @@ def rsqrt(x: ArrayLike) -> Array:

 .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt
 """
- return rsqrt_p.bind(x)
+ return rsqrt_p.bind(x, accuracy=accuracy)

 @export
-def cbrt(x: ArrayLike) -> Array:
+def cbrt(x: ArrayLike, accuracy=None) -> Array:
 r"""Elementwise cube root: :math:`\sqrt[3]{x}`.

 This function lowers directly to the `stablehlo.cbrt`_ operation.

 Args:
 x: Input array. Must have floating or complex dtype.
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+ selects the implementation of the op based on the requested accuracy. If
+ the implementation cannot satisfy the requested tolerance, the
+ compiler will return an error. If a mode is specified and no alternative
+ implementation is available, the default implementation will be
+ used.

 Returns:
 An array of the same shape and dtype as ``x`` containing the cube root.
@@ -933,7 +1022,7 @@ def cbrt(x: ArrayLike) -> Array:

 .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt
 """
- return cbrt_p.bind(x)
+ return cbrt_p.bind(x, accuracy=accuracy)

 @export
 def bitwise_not(x: ArrayLike) -> Array:
@@ -3544,13 +3633,19 @@ def reciprocal(x: ArrayLike) -> Array:
 return integer_pow(x, -1)

 @export
-def tan(x: ArrayLike) -> Array:
+def tan(x: ArrayLike, accuracy=None) -> Array:
 r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.

 This function lowers directly to the `stablehlo.tangent`_ operation.

 Args:
 x: input array. Must have floating-point or complex type.
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+ selects the implementation of the op based on the requested accuracy. If
+ the implementation cannot satisfy the requested tolerance, the
+ compiler will return an error. If a mode is specified and no alternative
+ implementation is available, the default implementation will be
+ used.

 Returns:
 Array of the same shape and dtype as ``x`` containing the element-wise
@@ -3564,7 +3659,7 @@ def tan(x: ArrayLike) -> Array:

 .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
 """
- return tan_p.bind(x)
+ return tan_p.bind(x, accuracy=accuracy)

 @export
 def asin(x: ArrayLike) -> Array:
@@ -3958,8 +4053,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
 return out

-def _nary_lower_hlo(op: Callable, ctx,
- *args: ir.Value, **params) -> Sequence[ir.Value]:
+def _nary_lower_hlo(
+ op: Callable, ctx, *args: ir.Value, accuracy=None, **params
+) -> Sequence[ir.Value]:
 """Lowers an elementwise operator to its MLIR equivalent.
 """
 del params
@@ -3968,6 +4064,9 @@ def _nary_lower_hlo(op: Callable, ctx,
 args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)

- out = op(*args)
+ if accuracy:
+ out = op(*args, result_accuracy=accuracy_attr(accuracy))
+ else:
+ out = op(*args)
 return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]


@@ -4029,43 +4127,57 @@ def _round_lower(ctx, x, *, rounding_method):
 mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))

 exp_p = standard_unop(_float | _complex, 'exp')
-ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
+ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans))
 mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
 batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule

 exp2_p = standard_unop(_float | _complex, 'exp2')
-ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
-def _exp2_lower(ctx, x):
+ad.defjvp2(
+ exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans))
+)
+
+def _exp2_lower(ctx, x, accuracy):
 x_aval, = ctx.avals_in
 log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
 log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
- return [hlo.exponential(hlo.multiply(log2, x))]
+ return [
+ hlo.exponential(
+ hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy)
+ )
+ ]
+
 mlir.register_lowering(exp2_p, _exp2_lower)

 log_p = standard_unop(_float | _complex, 'log')
-ad.defjvp(log_p, lambda g, x: div(g, x))
+ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x))
 mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log))

 expm1_p = standard_unop(_float | _complex, 'expm1')
-ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
+ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans))))
 mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one))

 log1p_p = standard_unop(_float | _complex, 'log1p')
-ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
+ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x))))
 mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))

 tanh_p = standard_unop(_float | _complex, 'tanh')
-ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
- sub(_one(x), ans)))
+ad.defjvp2(
+ tanh_p,
+ lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)),
+)
 mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh))

 logistic_p = standard_unop(_float | _complex, 'logistic')
-ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
+ad.defjvp2(
+ logistic_p,
+ lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))),
+)
 # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) -def logistic_impl(x): + +def logistic_impl(x, accuracy): one = _const(x, 1) return div(one, add(one, exp(neg(x)))) @@ -4088,20 +4200,26 @@ def _sin_complex(x): # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) -def _sin_lowering(ctx, x): +def _sin_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): sine = mlir.lower_fun(_sin_complex, multiple_results=False) return sine(ctx, x) - return _nary_lower_hlo(hlo.sine, ctx, x) + return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy) + -def _sin_lin(nzs, x): +def _sin_p_lin(nzs, x, accuracy): nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) - return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + return ( + sin_p.bind(x, accuracy=accuracy), + nz, + cos_x, + lambda cos_x_, t: mul(t, cos_x_), + ) sin_p = standard_unop(_float | _complex, 'sin') -ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -ad.primitive_linearizations[sin_p] = _sin_lin +ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule @@ -4117,18 +4235,20 @@ def _cos_complex(x): re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) -def _cos_lowering(ctx, x): +def _cos_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): cosine = mlir.lower_fun(_cos_complex, multiple_results=False) return cosine(ctx, x) - return _nary_lower_hlo(hlo.cosine, ctx, x) + return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) cos_p = standard_unop(_float | _complex, 'cos') -ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) +ad.defjvp( + cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) +) mlir.register_lowering(cos_p, _cos_lowering) tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) +ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) asin_p = standard_unop(_float | _complex, 'asin') @@ -4245,18 +4365,23 @@ def _abs_jvp_rule(g, ans, x): _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') -ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) +ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) rsqrt_p = standard_unop(_float | _complex, 'rsqrt') -ad.defjvp2(rsqrt_p, - lambda g, ans, x: - mul(g, mul(_const(x, -0.5), div(ans, x)))) +ad.defjvp2( + rsqrt_p, + lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), +) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) cbrt_p = standard_unop(_float, 'cbrt') -ad.defjvp2(cbrt_p, - lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) +ad.defjvp2( + cbrt_p, + lambda g, ans, x, **kwargs: mul( + g, mul(_const(x, 1 / 3), integer_pow(ans, -2)) + ), +) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) square_p = standard_unop(_int | _float | _complex, 'square') @@ -5463,6 
+5588,17 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return lhs_dtype, rhs_dtype, out_type +def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr: + if isinstance(accuracy, AccuracyMode): + return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name)) + elif isinstance(accuracy, Tolerance): + return hlo.ResultAccuracyAttr.get( + atol=accuracy.atol, + rtol=accuracy.rtol, + ulps=accuracy.ulps, + mode='TOLERANCE', + ) + def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 537a2cc07575..617324d43bf9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2549,14 +2549,18 @@ def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.rsqrt(x) lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule -def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sqrt(x) @@ -2572,7 +2576,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.square_p] = _square_lowering_rule -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.exp(x) @@ -2605,9 +2611,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here. 
+ if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return lower_fun( lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x), multiple_results=False, @@ -2618,7 +2626,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): skip_mlir_conversions.add(lax.exp2_p) -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): +def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") neg_x = arith.negf(x) exp_neg_x = math.exp(neg_x) aval_out = ctx.avals_out[0] @@ -2636,42 +2646,54 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.logistic_p] = _logistic_lowering_rule -def _sin_lowering_rule(ctx: LoweringRuleContext, x): +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sin(x) lowering_rules[lax.sin_p] = _sin_lowering_rule -def _cos_lowering_rule(ctx: LoweringRuleContext, x): +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.cos(x) lowering_rules[lax.cos_p] = _cos_lowering_rule -def _tan_lowering_rule(ctx: LoweringRuleContext, x): +def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tan(x) lowering_rules[lax.tan_p] = _tan_lowering_rule -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tanh(x) lowering_rules[lax.tanh_p] = _tanh_lowering_rule -def _log_lowering_rule(ctx: LoweringRuleContext, x): +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log(x) lowering_rules[lax.log_p] = _log_lowering_rule -def _log1p_lowering_rule(ctx: LoweringRuleContext, x): +def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log1p(x) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 286fedfa44d5..0c9f70937873 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1584,7 +1584,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) @@ -1598,7 +1600,9 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if 
ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) @@ -1608,7 +1612,9 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -def _logistic(x): +def _logistic(x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return 1.0 / (1 + lax.exp(-x)) @@ -1622,7 +1628,9 @@ def _logistic(x): @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) @@ -1633,7 +1641,9 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) @@ -1645,7 +1655,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) -def _log_lowering_rule(ctx: LoweringRuleContext, x): +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index c85c5f0a39c0..150ae9b8b2d7 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -654,7 +654,9 @@ def _make_dispatch_table( name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: - def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: + def inner( + ctx: LoweringRuleContext, *args: ir.Value, **_ + ) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: @@ -1404,7 +1406,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), + lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)), } for prim, fn in _JAX_FN_MAPPING.items(): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 3d71af38388b..492e070de1af 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1666,17 +1666,18 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], tf_impl_with_avals[lax.integer_pow_p] = _integer_pow -tf_impl[lax.exp_p] = tf.math.exp -tf_impl[lax_internal.exp2_p] = lambda x: \ - tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)) -tf_impl[lax.expm1_p] = tf.math.expm1 -tf_impl[lax.log_p] = 
tf.math.log -tf_impl[lax.log1p_p] = tf.math.log1p -tf_impl[lax.tan_p] = tf.math.tan -tf_impl[lax.tanh_p] = tf.math.tanh -tf_impl[lax.sin_p] = tf.math.sin +tf_impl[lax.exp_p] = lambda x, accuracy: tf.math.exp(x) +tf_impl[lax_internal.exp2_p] = lambda x, accuracy: tf.math.exp( + tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x) +) +tf_impl[lax.expm1_p] = lambda x, accuracy: tf.math.expm1(x) +tf_impl[lax.log_p] = lambda x, accuracy: tf.math.log(x) +tf_impl[lax.log1p_p] = lambda x, accuracy: tf.math.log1p(x) +tf_impl[lax.tan_p] = lambda x, accuracy: tf.math.tan(x) +tf_impl[lax.tanh_p] = lambda x, accuracy: tf.math.tanh(x) +tf_impl[lax.sin_p] = lambda x, accuracy: tf.math.sin(x) tf_impl[lax.sinh_p] = tf.math.sinh -tf_impl[lax.cos_p] = tf.math.cos +tf_impl[lax.cos_p] = lambda x, accuracy: tf.math.cos(x) tf_impl[lax.cosh_p] = tf.math.cosh tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) @@ -1706,11 +1707,11 @@ def _atan2(y, x, **kwargs): tf_impl[lax.asin_p] = tf.math.asin tf_impl[lax.acos_p] = tf.math.acos -tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.sqrt_p] = lambda x, accuracy: tf.math.sqrt(x) tf_impl[lax.square_p] = tf.math.square -tf_impl[lax.rsqrt_p] = tf.math.rsqrt +tf_impl[lax.rsqrt_p] = lambda x, accuracy: tf.math.rsqrt(x) -def _cbrt(x): +def _cbrt(x, accuracy): return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3) tf_impl[lax.cbrt_p] = _cbrt diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 15273f0fd02a..acf8885b0f98 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -76,7 +76,7 @@ from jax._src.util import unzip2, weakref_lru_cache, safe_zip -def jet(fun, primals, series): +def jet(fun, primals, series, **_): r"""Taylor-mode higher-order automatic differentiation. Args: @@ -405,11 +405,11 @@ def deriv_prop(prim, deriv, primals_in, series_in): lax.exp(lax.neg(lax.square(x))))) -def def_comp(prim, comp): +def def_comp(prim, comp, **kwargs): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ - jet_rules[prim] = partial(jet, comp) + jet_rules[prim] = partial(jet, comp, **kwargs) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) @@ -478,7 +478,7 @@ def _scale(k, j): def _scale2(k, j): return 1. / (fact(k - j) * fact(j)) -def _exp_taylor(primals_in, series_in): +def _exp_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -522,7 +522,7 @@ def _integer_pow_taylor(primals_in, series_in, *, y): jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _logistic_taylor(primals_in, series_in): +def _logistic_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -538,7 +538,7 @@ def _logistic_taylor(primals_in, series_in): jet_rules[lax.logistic_p] = _logistic_taylor -def _tanh_taylor(primals_in, series_in): +def _tanh_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] @@ -548,7 +548,7 @@ def _tanh_taylor(primals_in, series_in): return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor -def _log_taylor(primals_in, series_in): +def _log_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -590,7 +590,7 @@ def scale(k, j): return 1. 
/ (fact(k - j) * fact(j)) return primal_out, series_out jet_rules[lax.div_p] = _div_taylor_rule -def _sinusoidal_rule(sign, prims, primals_in, series_in): +def _sinusoidal_rule(sign, prims, primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -603,7 +603,7 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in): return (s[0], s[1:]), (c[0], c[1:]) def _get_ind(f, ind): - return lambda *args: f(*args)[ind] + return lambda *args, **kwargs: f(*args, **kwargs)[ind] jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0) jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1) diff --git a/tests/BUILD b/tests/BUILD index 2526be066635..b501a614da39 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1640,6 +1640,20 @@ jax_multiplatform_test( deps = ["//jax:experimental"], ) +jax_multiplatform_test( + name = "unary_ops_accuracy_test", + srcs = ["unary_ops_accuracy_test.py"], + disable_configs = [ + "tpu_pjrt_c_api", + ], + enable_backends = [ + "tpu", + ], + deps = [ + "//jax:experimental", + ], +) + jax_py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index 6a970051d56e..9710131a92fa 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4780,7 +4780,7 @@ def sin_of_sin(x): def test_deferred_primal_with_direct_linearize(self): def my_sin_lin(nzs, x): nz, = nzs - return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) my_sin_p = core.Primitive("my_sin_p") my_sin_p.def_impl(lax.sin) @@ -4827,8 +4827,8 @@ def f(x): sin_impl = lax.sin_p.impl cos_impl = lax.cos_p.impl try: - lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x)) - lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs)) + lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs)) f_lin(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5092,11 +5092,11 @@ def f_yesremat(x): jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5121,7 +5121,7 @@ def f(x, y): called = [] sin_impl = lax.sin_p.impl try: - lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: called.append(1) or sin_impl(x, **kwargs)) api.grad(g)(3.) 
finally:
 lax.sin_p.def_impl(sin_impl)
@@ -5449,9 +5449,9 @@ def f(x):
 ('new_remat', new_checkpoint),
 ]
 for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [
- ('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']),
- ('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []),
- ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']),
+ ('save_anything', lambda *_, **__: True, [], [' sin[', ' cos[']),
+ ('save_nothing', lambda *_, **__: False, [' sin[', ' cos['], []),
+ ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos['], [' sin[']),
 ])
 def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2):
 for square in [lambda x: x * x, api.jit(lambda x: x * x)]:
@@ -5481,8 +5481,8 @@ def test_remat_custom_policy_save_cos(self, remat):
 policy=save_cos)
 _, f_lin = api.linearize(f, 1.)
 jaxpr_text = str(f_lin.func.args[0])
- self.assertNotIn(' sin ', jaxpr_text)
- self.assertNotIn(' cos ', jaxpr_text)
+ self.assertNotIn(' sin[', jaxpr_text)
+ self.assertNotIn(' cos[', jaxpr_text)
 jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])

 @parameterized.named_parameters(
@@ -5504,7 +5504,7 @@ def f(x):
 _, f_lin = api.linearize(f, jnp.ones((2, 2)))
 jaxpr_text = str(f_lin.func.args[0])
- self.assertEqual(jaxpr_text.count(' sin '), 2)
+ self.assertEqual(jaxpr_text.count(' sin['), 2)
 self.assertEqual(jaxpr_text.count(' dot_'), 6)
 jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5527,7 +5527,7 @@ def f(x):
 _, f_lin = api.linearize(f, jnp.ones((2, 2)))
 jaxpr_text = str(f_lin.func.args[0])
- self.assertEqual(jaxpr_text.count(' sin '), 2)
+ self.assertEqual(jaxpr_text.count(' sin['), 2)
 self.assertEqual(jaxpr_text.count(' dot_general'), 6)
 jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5550,7 +5550,7 @@ def f(x):
 _, f_lin = api.linearize(f, jnp.ones((3, 2, 2)))
 jaxpr_text = str(f_lin.func.args[0])
- self.assertEqual(jaxpr_text.count(' sin '), 2)
+ self.assertEqual(jaxpr_text.count(' sin['), 2)
 self.assertEqual(jaxpr_text.count(' dot_general'), 9)
 jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5574,7 +5574,7 @@ def f(x):
 _, f_lin = api.linearize(f, jnp.ones((2, 2)))
 jaxpr_text = str(f_lin.func.args[0])
- self.assertEqual(jaxpr_text.count(' sin '), 2)
+ self.assertEqual(jaxpr_text.count(' sin['), 2)
 self.assertEqual(jaxpr_text.count(' dot_'), 6)
 jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5598,8 +5598,8 @@ def body(x, _): return f(x), None
 # Two sine calls in the backward pass because while we don't save sines
 # within the (rematted) body function, we can save the scan carry, which
 # effectively saves one sine. Three cosines for the Jacobian coefficients.
- self.assertEqual(jaxpr_text.count(' sin '), 2)
- self.assertEqual(jaxpr_text.count(' cos '), 3)
+ self.assertEqual(jaxpr_text.count(' sin['), 2)
+ self.assertEqual(jaxpr_text.count(' cos['), 3)
 # Six calls to dot_general in the backward pass because we save the primal
 # matmuls and only compute the backward pass ones (two for each primal one).
 self.assertEqual(jaxpr_text.count(' dot_'), 6)
@@ -5905,8 +5905,8 @@ def test_remat_of_scan(self, remat):
 jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])

 jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
- self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5951,8 +5951,8 @@ def body(x, _): return f(x), None jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 3) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 3) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compute the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5969,8 +5969,8 @@ def test_remat_of_scan_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) def test_remat_of_scan_funky_custom_jvp(self): def scan_apply(f, x): @@ -5993,40 +5993,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 2) def test_remat_of_scan_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use scan. 
@@ -6051,40 +6051,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6099,8 +6099,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertNotIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertNotIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) true_fn = lambda c: jnp.sin(jnp.sin(c)) false_fn = lambda c: c @@ -6108,8 +6108,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) 
- self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6149,8 +6149,8 @@ def f(x): _, f_vjp = api.vjp(f, jnp.ones((5, 5))) jaxpr_text = str(f_vjp.args[0].func.args[1]) - self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Five calls to dot_general in the backward pass because we have two for # each forward-pass dot, except for the first which only has one (as we are # differentiating with respect to only W and not x). @@ -6180,8 +6180,8 @@ def f(x): jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) self.assertEqual(jaxpr_text.count(' dot_'), 5) jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, @@ -6195,8 +6195,8 @@ def test_remat_of_cond_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) def test_remat_of_cond_funky_custom_jvp(self): def cond_apply(f, x): @@ -6218,40 +6218,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) 
jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 1) + self.assertEqual(jaxpr_text.count(' cos['), 2) def test_remat_of_cond_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use cond. @@ -6275,40 +6275,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 1) + self.assertEqual(jaxpr_text.count(' cos['), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6333,8 +6333,8 @@ def f(x): self.assertArraysAllClose(y_dot, expected, check_dtypes=False) jaxpr = api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) def test_remat_of_while_loop_policy(self): def cond_fn(carry): @@ -6351,8 +6351,8 @@ def f(x): save_cos = lambda prim, *_, **__: str(prim) == 'cos' g = new_checkpoint(f, policy=save_cos) jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @jtu.thread_unsafe_test() # logging isn't thread-safe def test_remat_residual_logging(self): diff --git a/tests/core_test.py b/tests/core_test.py index c46d493bda54..03d6355cb257 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -474,8 +474,8 @@ def new_jaxpr(): # jaxpr is: # # { lambda ; a. 
- # let b = sin a - # c = cos a + # let b = sin[accuracy=None] a + # c = cos[accuracy=None] a # d = add b c # in (d,) } # @@ -487,7 +487,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a", + r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\[accuracy=None] a", lambda: core.check_jaxpr(jaxpr)) jaxpr = new_jaxpr() @@ -496,7 +496,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a", + r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\[accuracy=None] a", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 8d5dc471e847..f5b70878533d 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -204,7 +204,7 @@ def select_n_strategy( # TODO(sharadmv,apaszke): enable zero dim sizes # TODO(sharadmv,apaszke): enable one dim sizes ( - lax.neg_p, + lax.neg_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -214,7 +214,7 @@ def select_n_strategy( ), ), ( - lax.not_p, + lax.not_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -226,6 +226,7 @@ def select_n_strategy( *[ ( prim, + params, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -234,23 +235,23 @@ def select_n_strategy( valid_dtypes=[jnp.dtype("float32")], ), ) - for prim in [ - lax.exp_p, - lax.tanh_p, - lax.logistic_p, - lax.rsqrt_p, - lax.log_p, - lax.exp2_p, - lax.abs_p, - lax.log1p_p, - lax.sin_p, - lax.sqrt_p, + for prim, params in [ + (lax.abs_p, {}), + (lax.exp_p, {"accuracy": None}), + (lax.tanh_p, {"accuracy": None}), + (lax.logistic_p, {"accuracy": None}), + (lax.rsqrt_p, {"accuracy": None}), + (lax.log_p, {"accuracy": None}), + (lax.exp2_p, {"accuracy": None}), + (lax.log1p_p, {"accuracy": None}), + (lax.sin_p, {"accuracy": None}), + (lax.sqrt_p, {"accuracy": None}), ] ], ] UNARY_FUNCTIONS = [ - (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES + (prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES ] + [ ( name, diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e2945d..d40293501edf 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2082,8 +2082,8 @@ def test_remat_of_pmap(self, remat): x = jnp.arange(1.) 
jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -2100,24 +2100,24 @@ def test_remat_of_pmap_policy(self, remat): _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = remat(g, policy=save_sin) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 2) save_nothing = lambda prim, *_, **__: False f = remat(g, policy=save_nothing) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 1) + self.assertEqual(jaxpr_text.count(' cos['), 2) def test_axis_name_shadowing_with_vmap(self): # vmap-of-pmap with mismatched axis sizes diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py new file mode 100644 index 000000000000..fb370ab96923 --- /dev/null +++ b/tests/unary_ops_accuracy_test.py @@ -0,0 +1,373 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit test for result accuracy for unary ops.""" + +from typing import Any, Callable, NamedTuple, Union +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax._src.lib import xla_extension +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +class TolerancePair(NamedTuple): + high: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + low: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + + +def make_unary_test_cases( + testcase_name: str, + op: Callable[..., Any], + x: np.ndarray, + tp: TolerancePair = None, + min_error_val: float = 0.0, +): + """Creates a single test case.""" + return [{ + "testcase_name": testcase_name, + "op": op, + "x": x, + "tp": tp, + "min_error_val": min_error_val, + }] + + +UNARY_OPS = { + "exp": make_unary_test_cases( + "exp", + lax.exp, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "exp2": make_unary_test_cases( + "exp2", + lax.exp2, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "expm1": make_unary_test_cases( + "expm1", + lax.expm1, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "log": make_unary_test_cases( + "log", + lax.log, + np.linspace(1e28, 2e28, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0), + ), + 1.0, + ), + "log1p": make_unary_test_cases( + "log1p", + lax.log1p, + np.linspace(-9e-8, -8e-8, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0), + ), + 1.0, + ), + "tanh": make_unary_test_cases( + "tanh", + lax.tanh, + np.linspace(5.83, 5.86, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0), + ), + ), + "cos": make_unary_test_cases( + "cos", + lax.cos, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sin": make_unary_test_cases( + "sin", + lax.sin, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "tan": make_unary_test_cases( + "tan", + lax.tan, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sqrt": make_unary_test_cases( + "sqrt", + lax.sqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "rsqrt": make_unary_test_cases( + "rsqrt", + lax.rsqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, 
rtol=2**-10, ulps=0),
+ low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
+ ),
+ ),
+}
+
+
+def generate_test_cases(op_names):
+ test_cases = []
+ for op in op_names:
+ op_group = UNARY_OPS[op]
+ if op_group is None:
+ raise ValueError(f"No test cases found for op: {op}")
+ test_cases.extend(op_group)
+ return test_cases
+
+
+@unittest.skipIf(not jtu.is_device_tpu(), "Skipping test on non-TPU devices.")
+class UnaryOpsAccuracyTest(jtu.JaxTestCase):
+
+ def test_result_accuracy_mode_attr(self):
+ with ir.Context() as context:
+ hlo.register_dialect(context)
+ attr = hlo.ResultAccuracyModeAttr.get("DEFAULT")
+ assert attr is not None
+ assert attr.value == "DEFAULT"
+
+ def test_result_accuracy_attr(self):
+ with ir.Context() as context:
+ hlo.register_dialect(context)
+ attr = hlo.ResultAccuracyAttr.get(
+ atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE"
+ )
+ assert attr is not None
+ assert attr.mode == "TOLERANCE"
+ assert attr.atol == 1e-5
+ assert attr.rtol == 0.0
+ assert attr.ulps == 1
+
+ @parameterized.named_parameters(
+ *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
+ )
+ def test_unary_ops_choose_impl(self, op, x, tp, **kwargs):
+ @jax.jit
+ def f_default(x):
+ y = op(x, accuracy=tp.high)
+ return y
+
+ @jax.jit
+ def f_accurate(x):
+ y = op(x, accuracy=tp.low)
+ return y
+
+ # Input values that would cause large differences between the two
+ # implementations.
+ diff = abs(f_default(x) - f_accurate(x))
+ if jtu.get_tpu_version() >= 5 and op in [
+ lax.tanh,
+ jnp.tanh,
+ lax.log,
+ jnp.log,
+ ]:
+ # From TPU version 5 onwards, even with a tighter tolerance, the
+ # high-performance implementation for tanh and log is chosen because
+ # the chip implementation has improved accuracy.
+ self.assertTrue(jnp.all(diff == 0))
+ else:
+ self.assertTrue(jnp.any(diff > 0))
+
+ @parameterized.named_parameters(
+ *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
+ )
+ def test_unary_vmap(self, op, x, tp, min_error_val):
+ @jax.jit
+ def f(x, y):
+ diff = lambda val: abs(
+ op(val, accuracy=tp.high) - op(val, accuracy=tp.low)
+ )
+ return diff(x), diff(y)
+
+ diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)(
+ min_error_val, x
+ )
+ # diff(min_error_val) should be 0
+ self.assertTrue(jnp.all(diff_x == 0))
+ # diff(x) should be > 0
+ if jtu.get_tpu_version() >= 5 and op in [
+ lax.tanh,
+ jnp.tanh,
+ lax.log,
+ jnp.log,
+ ]:
+ # From TPU version 5 onwards, even with a tighter tolerance, the
+ # high-performance implementation for tanh and log is chosen because
+ # the chip implementation has improved accuracy.
+ self.assertTrue(jnp.all(diff_y == 0))
+ else:
+ self.assertTrue(jnp.any(diff_y > 0))
+
+ @parameterized.named_parameters(
+ *generate_test_cases(["exp", "expm1", "exp2"])
+ )
+ def test_diff_grad(self, op, x, tp, **kwargs):
+ @jax.jit
+ def f_default(x):
+ default_op = op(x, accuracy=tp.low)
+ return jnp.sum(default_op)
+
+ f_default_grad = jax.grad(f_default)
+
+ @jax.jit
+ def f_accurate(x):
+ high_op = op(x, accuracy=tp.high)
+ return jnp.sum(high_op)
+
+ f_accurate_grad = jax.grad(f_accurate)
+ # Accuracy should be carried through to the gradient causing
+ # a large diff.
+ diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.any(diff > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["log", "log1p", "tanh"]) + ) + def test_grad_unchanged(self, op, x, tp, **kwargs): + @jax.jit + def f(x): + return jnp.sum(op(x)) + + f_grad = jax.grad(f) + + @jax.jit + def f_default(x): + default_op = op(x, accuracy=tp.low) + return jnp.sum(default_op) + + f_default_grad = jax.grad(f_default) + + @jax.jit + def f_accurate(x): + high_op = op(x, accuracy=tp.high) + return jnp.sum(high_op) + + f_accurate_grad = jax.grad(f_accurate) + # Accuracy should be carried through to the gradient causing a large diff. + # Diff between f_default and f_accurate should follow diff(f_grad,f_default_grad). + expected_diff = abs(f_grad(x) - f_default_grad(x)) + if jnp.all(expected_diff > 0): + # Don't expect f_accurate_grad and f_default_grad to be equal. + self.assertFalse( + jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0) + ) + elif jnp.all(expected_diff == 0): + # f_accurate_grad and f_default_grad should be equal. + diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.all(diff == 0)) + else: + raise ValueError("Unexpected diff: ", expected_diff) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_single_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return op(x, accuracy=tp.high) + + @jax.jit + def f(x): + return op(x) + + diff = abs(f_tol(x) - f(x)) + self.assertTrue(jnp.all(diff == 0)) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_default_grad(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return jnp.sum(op(x, accuracy=tp.high)) + + @jax.jit + def f(x): + return jnp.sum(op(x)) + + self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0)) + + def test_invalid_accuracy(self): + with self.assertRaisesRegex( + ValueError, "At least one of atol, rtol, or ulps must be set." + ): + lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0)) + with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."): + lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0)) + + @parameterized.named_parameters( + *generate_test_cases([ + "exp", + "expm1", + "exp2", + "log", + "log1p", + "tanh", + "cos", + "sin", + "tan", + "sqrt", + "rsqrt", + ]) + ) + def test_low_tol(self, op, x, **kwargs): + with self.assertRaisesRegex( + xla_extension.XlaRuntimeError, "impl_type.ok()" + ): + op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From c5aa86a41a8a6ec1d66b072080377c26c09512a8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 22:25:12 -0700 Subject: [PATCH 0239/1769] Remove redundant filtering in the paged flash attention kernel Reason: `l_next >= 1.0` so the `jnp.where(l_next == 0.0, 1.0, l_next)` clause is not needed. 
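For reference, here is a quick NumPy check of the invariant (an illustrative sketch, not part of this change): `l_curr` sums `exp(s - m_curr)`, one of whose terms is `exp(0) = 1`, and since `m_next = max(m_prev, m_curr)`, either `alpha == 1` or `beta == 1`, so `l_next >= 1.0` by induction.

```
import numpy as np

rng = np.random.default_rng(0)
m_prev, l_prev = -np.inf, 0.0  # initial online-softmax state
for _ in range(10):
  s = rng.normal(size=8)             # attention scores for one block
  m_curr = s.max()
  l_curr = np.exp(s - m_curr).sum()  # >= exp(0) == 1
  m_next = max(m_prev, m_curr)
  alpha = np.exp(m_prev - m_next)
  beta = np.exp(m_curr - m_next)
  l_next = alpha * l_prev + beta * l_curr
  assert l_next >= 1.0               # the l_next == 0.0 guard can never fire
  m_prev, l_prev = m_next, l_next
```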
PiperOrigin-RevId: 741400472 --- .../pallas/ops/tpu/paged_attention/paged_attention_kernel.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index eb1e11df17da..99cb2c9c94c1 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -274,14 +274,13 @@ def prefetch_next_block(): # pylint: disable=unused-variable alpha = jnp.exp(m_prev - m_next) beta = jnp.exp(m_curr - m_next) l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + m_ref[...], l_ref[...] = m_next, l_next v = async_copy_v.wait_and_get_loaded() o_curr_times_l_curr = jnp.dot(s_curr, v) - m_ref[...], l_ref[...] = m_next, l_next_safe o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next ).astype(o_ref.dtype) step_ref[0] = step + 1 From efa5ae8e9831dc0e510ff1b59d61615a2898925a Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 04:26:18 -0700 Subject: [PATCH 0240/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/edfd919df316d687b2d3b08bbc8d9c32f4bcc1c4. PiperOrigin-RevId: 741478215 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 625f33a072f5..43bba2fcc903 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "95abd7942747bd5d1884b309baecdf5a93ff928a" -XLA_SHA256 = "f8472323ffe621ade5317091fdf9acd66aaf67660fedd3143a96d9a347e88bac" +XLA_COMMIT = "edfd919df316d687b2d3b08bbc8d9c32f4bcc1c4" +XLA_SHA256 = "d82a7174a8a129180b180b08f5eedfa5fe6ff19fbd46dc11dae8cf64d87dfbf9" def repo(): tf_http_archive( From 063654000c148699383bad1656e23d808f76a97e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 28 Mar 2025 11:19:06 +0000 Subject: [PATCH 0241/1769] Marked as thread_unsafe_test: - ShardingInTypesTest.test_set_mesh - APITest.test_cache_clear_pmap This helps to prevent errors like: 1) in pjit_test.py: ``` ValueError: For primitive mul, context mesh AbstractMesh('x': 2, axis_types=(Explicit,)) should match the aval mesh AbstractMesh('x': 2, 'y': 1, axis_types=(Auto, Auto)) for shape float32[8,2] ``` raised for example by ArrayPjitTest.test_pjit_array_multi_input_multi_output_mesh3 and also by ArrayPjitTest.test_convert_element_type_sharding, when pjit tests are run concurrently with `--local_test_jobs=32` and `--test_env=JAX_TEST_NUM_THREADS=8` 2) in api_test.py ``` AssertionError: Expected exactly 1 XLA compilations, but executed 2 ``` raised by APITest.test_pmap_global_cache. 
--- tests/api_test.py | 1 + tests/pjit_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/api_test.py b/tests/api_test.py index 82b673fe4b1e..e99189c671d0 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4424,6 +4424,7 @@ def test_grad_conj_symbolic_zeros(self): out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) + @jtu.thread_unsafe_test() def test_cache_clear_pmap(self): @jax.pmap def f(i): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d6673c6b6d5a..0da97a6f0c14 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7162,6 +7162,7 @@ def f(x): out = f(np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + @jtu.thread_unsafe_test() def test_set_mesh(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) try: From 30451478c05e9bae9caaba09cf6b5a15805b3808 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 28 Mar 2025 05:09:30 -0700 Subject: [PATCH 0242/1769] [Pallas][NFC] Move the remainder of Semaphore-related extended dtypes to Pallas core This completes the move started in https://github.com/jax-ml/jax/pull/26673. PiperOrigin-RevId: 741487331 --- jax/_src/pallas/core.py | 52 ++++++++++++++++++++++++++-- jax/_src/pallas/mosaic/core.py | 41 +++------------------- jax/_src/pallas/mosaic/primitives.py | 2 +- 3 files changed, 54 insertions(+), 41 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 26101405fdeb..8602205eef22 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -67,9 +67,55 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max -class semaphore_dtype(dtypes.extended): pass -class semaphore(semaphore_dtype): pass -class barrier_semaphore(semaphore_dtype): pass +class AbstractSemaphoreTyRules: + @staticmethod + def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), SEMAPHORE_INTERPRET_DTYPE) + + @staticmethod + def physical_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.int32) + +# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy +class AbstractSemaphoreTy(dtypes.ExtendedDType): + name: str + _rules = AbstractSemaphoreTyRules + + def __repr__(self) -> str: + return self.name + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __hash__(self) -> int: + return hash(self.__class__) + +class semaphore_dtype(dtypes.extended): + """Common dtype for all kinds of semaphore dtypes. + + This is an abstract class that should never be instantiated, but rather + exists for the sake of `jnp.issubdtype`. + """ + +class semaphore(semaphore_dtype): + """Regular semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class Semaphore(AbstractSemaphoreTy): + name = "semaphore" + type = semaphore + +class barrier_semaphore(semaphore_dtype): + """Barrier semaphore dtype. + + Like its superclass, this class should never be instantiated. 
+ """ + +class BarrierSemaphore(AbstractSemaphoreTy): + name = "barrier_semaphore" + type = barrier_semaphore @runtime_checkable class CompilerParams(Protocol): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 5df3c01a1934..37b6e51892c7 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -25,7 +25,6 @@ import jax from jax._src import config from jax._src import core as jax_core -from jax._src import dtypes from jax._src import util from jax._src.pallas import core as pallas_core import jax.numpy as jnp @@ -114,42 +113,10 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): class dma_semaphore(pallas_core.semaphore_dtype): pass -class AbstractSemaphoreTyRules: - @staticmethod - def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) - - @staticmethod - def physical_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), jnp.int32) - -class AbstractSemaphoreTy(dtypes.ExtendedDType): - name: str - _rules = AbstractSemaphoreTyRules - - def __repr__(self) -> str: - return self.name - - def __eq__(self, other): - return self.__class__ == other.__class__ - - def __hash__(self) -> int: - return hash(self.__class__) - -# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy - -class SemaphoreTy(AbstractSemaphoreTy): - type = pallas_core.semaphore - name = "sem" - -class DmaSemaphoreTy(AbstractSemaphoreTy): +class DMASemaphore(pallas_core.AbstractSemaphoreTy): type = dma_semaphore name = "dma_sem" -class BarrierSemaphoreTy(AbstractSemaphoreTy): - type = pallas_core.barrier_semaphore - name = "barrier_sem" - class SemaphoreType(enum.Enum): REGULAR = "regular" DMA = "dma" @@ -158,11 +125,11 @@ class SemaphoreType(enum.Enum): def __call__(self, shape: tuple[int, ...]): dtype: Any if self == SemaphoreType.DMA: - dtype = DmaSemaphoreTy() + dtype = DMASemaphore() elif self == SemaphoreType.BARRIER: - dtype = BarrierSemaphoreTy() + dtype = pallas_core.BarrierSemaphore() else: - dtype = SemaphoreTy() + dtype = pallas_core.Semaphore() return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 106f342bace8..c50a21218117 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -623,7 +623,7 @@ def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, @get_barrier_semaphore_p.def_abstract_eval def _get_barrier_semaphore_abstract_eval(): return pl_core.AbstractMemoryRef( - jax_core.ShapedArray((), tpu_core.BarrierSemaphoreTy()), + jax_core.ShapedArray((), pl_core.BarrierSemaphore()), tpu_core.TPUMemorySpace.SEMAPHORE, ) From 1c1e2e6dc0d521fa01750cbff7d61c1c130897f1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 28 Mar 2025 05:11:12 -0700 Subject: [PATCH 0243/1769] [Mosaic GPU] Add support for stores to TMEM We can support reading and writing of both 32- and 16-bit types now. 
PiperOrigin-RevId: 741487690
---
 jax/experimental/mosaic/gpu/tcgen05.py | 221 ++++++++++++++++++++-----
 tests/mosaic/gpu_test.py               |  37 +++++
 2 files changed, 212 insertions(+), 46 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index ac3b80b93689..53056ce594b2 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -327,24 +327,30 @@ def tmem_relinquish_alloc_permit():
       has_side_effects=True,
   )
 
-def tmem_load(tmem_addr, shape, num, packing: int = 1):
+def _tmem_access_helper(shape, num, packing: int = 1):
   if num.bit_count() != 1 or num > 128:
     raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
   match shape:
     case "16x128b":
-      num_out_regs = 2
+      num_regs = 2
     case "16x256b":
-      num_out_regs = 4
+      num_regs = 4
     case _:
       raise NotImplementedError(f"{shape=} is unsupported")
-  if num * num_out_regs >= 256:
+  num_regs *= num
+  if num_regs > 255:
     raise ValueError(
-        f"Loading too much TMEM at once: {num=} and each load requires"
-        f" {num_out_regs} registers, which exceeds the limit of 256"
+        f"TMEM transaction too big: {shape=} and {num=} involve"
+        f" {num_regs} registers per thread, which exceeds the limit of 255"
     )
-  num_out_regs *= num
+  regs_vector = ",".join(f"${i}" for i in range(num_regs))
+  regs_vector = "{" + regs_vector + "}"
+  return num_regs, regs_vector
+
+
+def tmem_load(tmem_addr, shape, num, packing: int = 1):
   i32 = ir.IntegerType.get_signless(32)
-  out_regs = ",".join("$" + str(i) for i in range(num_out_regs))
+  num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing)
   if packing == 1:
     pack_mod = ""
   elif packing == 2:
@@ -356,13 +362,30 @@ def tmem_load(tmem_addr, shape, num, packing: int = 1):
         "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>"
     ),
     [tmem_addr],
-    f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {{{out_regs}}}, [${num_out_regs}];",
+    f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];",
     "=r," * num_out_regs + "r",
     has_side_effects=True,
   )
   return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
 
+def tmem_store(tmem_addr, shape, num, regs, packing: int = 1):
+  num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing)
+  if packing == 1:
+    pack_mod = ""
+  elif packing == 2:
+    pack_mod = ".unpack::16b"
+  else:
+    raise ValueError(f"Unsupported packing: {packing}")
+  llvm.inline_asm(
+      ir.Type.parse("!llvm.void"),
+      [*regs, tmem_addr],
+      f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};",
+      "r," * num_out_regs + "r",
+      has_side_effects=True,
+  )
+
+
 @dataclasses.dataclass(frozen=True)
 class TMEMLayout:
   """Represents the way a shape is laid out in TMEM.
@@ -562,62 +585,168 @@ def __getitem__(self, *idxs): ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) + def __setitem__(self, idxs, value): + if not isinstance(idxs, tuple): + idxs = (idxs,) + base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) + if any(is_squeezed): + raise ValueError( + "TMEM stores don't support integer indexing (only slices allowed)" + ) + if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: + raise NotImplementedError("Slicing parts of TMEM not implemented yet") + if self.shape[1] % 8: + raise NotImplementedError + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + if not isinstance(value, fa.FragmentedArray): + raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}") + if value.shape != self.shape: + raise ValueError( + f"Stored array has shape {value.shape}, but TMEM has shape" + f" {self.shape}" + ) + if value.mlir_dtype != self.dtype: + raise ValueError( + f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype" + f" {self.dtype}" + ) + if value.layout != LAYOUT: + raise ValueError( + f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT is" + " supported" + ) + if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): + # store_32xcols needs a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + _store_32xcols( + self.address, value.registers.T.reshape((4, -1)) + ) + else: # TODO(apaszke): Collective MMA layout + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) + + +def _transfer_32xcols(base_addr, cols): + i32 = ir.IntegerType.get_signless(32) + cols_per_num = 8 # Here we generate a plan compatible with tcgen05.LAYOUT. + assert cols % cols_per_num == 0 + total_num = cols // cols_per_num + if total_num <= 32: + instr_num = total_num + elif total_num == 64: + instr_num = 32 + else: + raise NotImplementedError(total_num) + # We transfer 16 lanes at a time, but have 32 to deal with. + for lane_step in range(2): + addr_row = arith.addi(base_addr, utils.c((lane_step * 16) << 16, i32)) + cols_per_instr = instr_num * cols_per_num + for num_step in range(total_num // instr_num): + num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num) + addr_row_col = arith.addi(addr_row, utils.c(num_step * cols_per_instr, i32)) + yield addr_row_col, instr_num, lane_step, num_slice + + +def _store_32xcols(base_addr, vector_regs): + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4 + cols = vector_regs.shape[1] * 8 + + packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + if packing == 1: + store_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + regs = np.empty((4, vector_regs.shape[1], 2), dtype=object) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in np.ndenumerate(vector_regs): + regs[(*idx, 0)] = llvm.extractelement(vreg, c0) + regs[(*idx, 1)] = llvm.extractelement(vreg, c1) + regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2) + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. 
+ # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + assert regs.shape[-2:] == (2, 2) + elif packing == 2: + store_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2) + else: + raise NotImplementedError(packing) + + it = _transfer_32xcols(base_addr, cols) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs_slice = regs[lane_step, num_slice].flat + tmem_store(addr_row_col, store_shape, instr_num, regs_slice, packing) + def _load_32xcols(base_addr, cols, dtype): - # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b i32 = ir.IntegerType.get_signless(32) + vec_ty = ir.VectorType.get((2,), dtype) packing = 32 // utils.bitwidth(dtype) if packing == 1: - load_shape = "16x256b" # 8 columns * 32 bits = 256 bits - cols_per_num_tile = 8 * packing + load_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits elif packing == 2: - load_shape = "16x128b" # 8 columns * 16 bits = 128 bits - cols_per_num_tile = 4 * packing + load_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits else: raise NotImplementedError(packing) - assert cols % cols_per_num_tile == 0 - num = cols // cols_per_num_tile - if num <= 32: - num_tiling = num - elif num == 64: - num_tiling = 32 - else: - raise NotImplementedError(num) + vector_regs = np.ndarray((4, cols // 8), dtype=object) - # We load 16 lanes at a time, but need 32 in total. - for row_group in range(2): - addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) - regs = [] - for num_group in range(num // num_tiling): - addr_row_col = arith.addi( - addr_row, - arith.constant(i32, num_tiling * num_group * cols_per_num_tile), - ) - regs += tmem_load(addr_row_col, load_shape, num_tiling, packing) + + it = _transfer_32xcols(base_addr, cols) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs = tmem_load(addr_row_col, load_shape, instr_num, packing) + row_slice = slice(lane_step * 2, (lane_step + 1) * 2) + # This aliases the original array, so updates will be reflected there. + vector_regs_update = vector_regs[row_slice, num_slice] + assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num) if packing == 1: regs = [llvm.bitcast(dtype, r) for r in regs] - undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(cols // 8, 2), strict=True): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. 
+ # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1) + undef = llvm.mlir_undef(vec_ty) + assert regs.shape == (*vector_regs_update.shape, 2) + for idx in np.ndindex(vector_regs_update.shape): + high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0) + vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1) + vector_regs_update[idx] = vreg else: assert packing == 2 - regs = [llvm.bitcast(ir.VectorType.get((2,), dtype), r) for r in regs] - for vreg, idx in zip(regs, np.ndindex(cols // 8, 2), strict=True): - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + regs = [llvm.bitcast(vec_ty, r) for r in regs] + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1) + vector_regs_update[...] = regs + return vector_regs -# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. def _m128_layout(shape: tuple[int, ...]): if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") if shape[0] % 128 != 0 or shape[1] % 8 != 0: raise ValueError(f"Shape {shape} is not a multiple of 64x8") - return fa.TiledLayout( - fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), - vector_dim=-1, + return LAYOUT + +# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. +# The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT) +LAYOUT = fa.TiledLayout( + fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), + warp_dim=-8, + lane_dims=(-4, -3), + vector_dim=-1, +) + + +def commit_tmem(): + void = ir.Type.parse("!llvm.void") + llvm.inline_asm( + void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True, ) + utils.warpgroup_barrier() diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index d9f56ee1d454..6c63e3ce40e1 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -908,6 +908,43 @@ def setUp(self): if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") + @parameterized.parameters([jnp.float32, jnp.float16]) + def test_load_store_tmem(self, jax_dtype): + swizzle = 128 + in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + tiling = (8, swizzle_elems) + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy( + src_ref=input, + dst_ref=smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barrier, + ) + barrier.wait() + tmem[:] = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT) + tcgen05.commit_tmem() + tmem[:].store_tiled(smem, swizzle) + mgpu.commit_shared() + ctx.async_copy( + src_ref=smem, dst_ref=output, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(x, y) + @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), From 
39fb2a00a6b4313e836266dfa4e6a6c73b65ca42 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 28 Mar 2025 05:43:05 -0700 Subject: [PATCH 0244/1769] [Mosaic GPU] Add support for allocation and lowering of scratch semaphores The semaphore arrays are allocated in GMEM and zeroed by XLA before the kernel begins. PiperOrigin-RevId: 741494241 --- jax/_src/pallas/mosaic_gpu/BUILD | 2 +- jax/_src/pallas/mosaic_gpu/core.py | 19 ++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 60 ++++++++++++++----- .../mosaic_gpu/pallas_call_registration.py | 32 ++++++++-- jax/experimental/mosaic/gpu/core.py | 1 + jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 24 ++++++++ 7 files changed, 117 insertions(+), 22 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index ab35eebafc04..33883326e58c 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -48,7 +48,7 @@ pytype_strict_library( "//jax:mlir", "//jax:mosaic_gpu", "//jax/_src/pallas", - ], + ] + py_deps("numpy"), ) pytype_strict_library( diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 19007b6850fd..857daaefe38f 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -120,6 +120,25 @@ def __call__( return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) +class SemaphoreType(enum.Enum): + REGULAR = "regular" + BARRIER = "barrier" + + def __call__(self, shape: tuple[int, ...]): + dtype: Any + if self == SemaphoreType.BARRIER: + dtype = pallas_core.BarrierSemaphore() + else: + dtype = pallas_core.Semaphore() + return pallas_core.MemoryRef(shape, dtype, GPUMemorySpace.GMEM) + + def get_array_aval(self) -> jax_core.ShapedArray: + return self(()).get_array_aval() + + def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: + return self(()).get_ref_aval() + + def kernel( body: Callable[..., None], out_shape: object, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0c9f70937873..1b4aa33dc909 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -418,8 +418,9 @@ class LoweringResult: module: ir.Module grid: tuple[int, ...] block: tuple[int, ...] - out_structs: tuple[jax.ShapeDtypeStruct, ...] + new_out_shapes: tuple[jax.ShapeDtypeStruct, ...] # Does not include gmem scratch! profiler_context: ProfilerContext | None + gmem_scratch_shapes: tuple[jax.ShapeDtypeStruct, ...] @dataclasses.dataclass(frozen=True) @@ -588,16 +589,41 @@ def ref_for_aval(aval: jax_core.AbstractValue): else: return gpu_core.SMEM(aval.shape, aval.dtype) + sem_placeholder = None + semaphore_ref_avals = [] + scratch_avals = [] + # Need to unzip semaphores + for v in jaxpr.invars[grid_mapping.slice_scratch_ops]: + aval = v.aval + if (isinstance(aval, pallas_core.AbstractMemoryRef) and + jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype)): + if aval.memory_space != gpu_core.GPUMemorySpace.GMEM: + raise ValueError( + "Only GMEM memory space is supported for semaphores in Mosaic GPU." 
+ ) + semaphore_ref_avals.append(aval) + scratch_avals.append(sem_placeholder) + else: + scratch_avals.append(aval) + def pipeline_fn(*refs): - return primitives.run_scoped( - functools.partial(scoped_pipeline_fn, *refs), + sem_refs = [] + if semaphore_ref_avals: + refs, sem_refs = util.split_list(refs, [-len(semaphore_ref_avals)]) + primitives.run_scoped( + functools.partial(scoped_pipeline_fn, *refs, sem_refs=sem_refs), scratch_refs=[ - ref_for_aval(v.aval) - for v in jaxpr.invars[grid_mapping.slice_scratch_ops] + ref_for_aval(aval) if aval is not sem_placeholder else aval + for aval in scratch_avals ], ) + return () # ``wrap_init`` does not support functions returning None. - def scoped_pipeline_fn(*refs, scratch_refs): + def scoped_pipeline_fn(*refs, sem_refs, scratch_refs): + sem_refs_it = iter(sem_refs) + scratch_refs = [ + next(sem_refs_it) if r is sem_placeholder else r for r in scratch_refs + ] def body_fn(*refs): grid_env = pallas_core.current_grid_env() assert grid_env is not None # Set by ``emit_pipeline``. @@ -628,17 +654,13 @@ def body_fn(*refs): with grid_mapping.trace_env(): new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init( - # ``wrap_init`` does not support functions returning None. - lambda *args: pipeline_fn(*args) or (), - debug_info=jaxpr.debug_info, - ), + lu.wrap_init(pipeline_fn, debug_info=jaxpr.debug_info), [ gpu_core.GMEM( bm.array_shape_dtype.shape, bm.array_shape_dtype.dtype ).get_ref_aval() for bm in block_mappings - ], + ] + semaphore_ref_avals, ) assert not new_consts @@ -655,6 +677,10 @@ def body_fn(*refs): mesh.cluster if mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], + [ + jax.ShapeDtypeStruct(r.shape, np.dtype(np.int32)) + for r in semaphore_ref_avals + ], new_jaxpr, compiler_params, new_consts, @@ -668,6 +694,7 @@ def lower_jaxpr_to_module( cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], out_shapes: Sequence[jax.ShapeDtypeStruct], + gmem_scratch_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, compiler_params: dict[str, Any], consts=(), @@ -754,14 +781,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): # Each range is 2 events, each event is 4 bytes. 
prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) - module, out_structs_gmem, _, launch_ctx, scratch_arr = ( + module, new_out_shapes, _, launch_ctx, scratch_arr = ( mgpu_core._lower_as_gpu_kernel( body, grid=tuple(map(operator.mul, parallel_grid, cluster)), cluster=cluster, block=block, in_shapes=in_shapes, - out_shape=out_shapes, + out_shape=(*out_shapes, *gmem_scratch_shapes), smem_scratch_shape=scratch_buffers, module_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, @@ -777,8 +804,11 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mgpu_core._initialize_scratch(launch_ctx, scratch_arr) + if gmem_scratch_shapes: + new_out_shapes = new_out_shapes[:-len(gmem_scratch_shapes)] + return LoweringResult( - module, parallel_grid, block, out_structs_gmem, prof_ctx + module, parallel_grid, block, new_out_shapes, prof_ctx, tuple(gmem_scratch_shapes) ) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 40b12215c003..6dc958edbc53 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,11 +23,13 @@ import warnings import jax +from jax import lax from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering from jax.experimental.mosaic import gpu as mgpu +import numpy as np def pallas_call_lowering( @@ -74,16 +76,30 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - new_avals_out = [ - jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs - ] + new_avals_in = list(ctx.avals_in) + new_avals_out = list(map(_as_shaped_array, lowering_result.new_out_shapes)) + scratch_args = () + if lowering_result.gmem_scratch_shapes: + input_output_aliases += tuple( + (len(new_avals_in) + i, len(new_avals_out) + i) + for i in range(len(lowering_result.gmem_scratch_shapes)) + ) + new_avals_in.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + new_avals_out.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + def zero_init_gmem_scratch(): + return [lax.zeros_like_array(s) for s in lowering_result.gmem_scratch_shapes] + scratch_args = mlir.lower_fun( + zero_init_gmem_scratch, multiple_results=True + )(ctx.replace(avals_in=())) outs = mgpu.core._mosaic_gpu_lowering_rule( - ctx.replace(avals_out=new_avals_out), - *args, + ctx.replace(avals_in=new_avals_in, avals_out=new_avals_out), + *args, *scratch_args, module=module, - out_types=lowering_result.out_structs, + out_types=(*lowering_result.new_out_shapes, *lowering_result.gmem_scratch_shapes), input_output_aliases=input_output_aliases, ) + if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. 
+ outs = outs[:-len(lowering_result.gmem_scratch_shapes)] if (prof_ctx := lowering_result.profiler_context) is not None: *outs, prof_buffer = outs if (dump_path := prof_ctx.dump_path) == "sponge": @@ -112,3 +128,7 @@ def do_callback(prof_buffer): ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer ) return outs + + +def _as_shaped_array(t: jax.ShapeDtypeStruct) -> jax_core.ShapedArray: + return jax_core.ShapedArray(t.shape, np.dtype(t.dtype)) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index f5331eb1b56a..fcc5d3db6d60 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -104,6 +104,7 @@ def _mosaic_gpu_lowering_rule( out_types, input_output_aliases: tuple[tuple[int, int], ...] = (), ): + assert len(args) == len(ctx.avals_in) assert len(out_types) == len(ctx.avals_out) module = _run_serde_pass( module, diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index d5acb9b131ad..b791fbb8b573 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -23,6 +23,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh from jax._src.pallas.mosaic_gpu.core import kernel as kernel +from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index aea49b645ec6..c5da44d7b6fa 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2408,6 +2408,30 @@ def compute(l_smem, r_smem, o_smem): out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) + def test_semaphore_lowering(self): + # This is a smoke test until we add support for lowering of semaphore ops. + def body(i_ref1, i_ref2, o_ref, sem_ref): + del i_ref2 # Only here to have a different number of inputs and outputs. + assert sem_ref.shape == (4,) + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + o_ref[...] = i_ref1[...] + x = jnp.arange(128, dtype=jnp.float32).reshape((128,)) + kernel = pl.pallas_call( + body, out_shape=x, scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], + ) + text = jax.jit(kernel).lower(x, x).as_text() + self.assertIn( + r"output_operand_aliases =" + r" [#stablehlo.output_operand_alias]", + text, + ) + self.assertIn( + r"(tensor<128xf32>, tensor<128xf32>, tensor<4xi32>) ->" + r" (tensor<128xf32>, tensor<4xi32>)", + text, + ) + class ExamplesSm90ATest(PallasSm90ATest): From 5c61a69fd6dc46657c952e6f235c35cb5f3bcbd6 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 28 Mar 2025 09:17:20 -0400 Subject: [PATCH 0245/1769] Fixes failing FFI example builds. 
Breaking CI: https://github.com/jax-ml/jax/actions/runs/14126719325/job/39577362075?pr=27557 See breaking nanobind PR: https://github.com/wjakob/nanobind/pull/978 See fixing nanobind PR (not landed) https://github.com/wjakob/nanobind/pull/980 --- examples/ffi/pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml index 130dd91bbc70..6f188ee037da 100644 --- a/examples/ffi/pyproject.toml +++ b/examples/ffi/pyproject.toml @@ -1,5 +1,7 @@ [build-system] -requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] +# TODO(dsuo): Remove nanobind pin after +# https://github.com/wjacob/nanobind/pull/980 lands. +requires = ["scikit-build-core", "nanobind==2.5.0", "jax>=0.4.31"] build-backend = "scikit_build_core.build" [project] From 4024897372cac19d2c17004babf1063d4975a38b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Mar 2025 06:51:48 -0700 Subject: [PATCH 0246/1769] Update CUDA tests matrix in the continuous jobs We now test only CUDA 12.1 and CUDA 12.8 PiperOrigin-RevId: 741509853 --- .github/workflows/wheel_tests_continuous.yml | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index f48c39bf4721..530e0a9b0768 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -116,23 +116,16 @@ jobs: cuda: ["12.1", "12.8"] enable-x64: [1, 0] exclude: - # L4 does not run on cuda 12.8 but tests other configs - - runner: "linux-x86-g2-48-l4-4gpu" - cuda: "12.8" - # H100 runs only a single config, CUDA 12.3 Enable x64 1 - - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.8" + # H100 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a3-8g-h100-8gpu" cuda: "12.1" - runner: "linux-x86-a3-8g-h100-8gpu" enable-x64: "0" # B200 runs only a single config, CUDA 12.8 Enable x64 1 - - runner: "linux-x86-a4-224-b200-1gpu" - enable-x64: "0" - runner: "linux-x86-a4-224-b200-1gpu" cuda: "12.1" - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.3" + enable-x64: "0" name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})" with: From 28f63ee27e751f0d4033131c881f157615acbfd9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Mar 2025 06:53:19 -0700 Subject: [PATCH 0247/1769] Use the Docker image with CUDA 12.8 and cudnn 9.8 in the Bazel CUDA non RBE job PiperOrigin-RevId: 741510217 --- .github/workflows/bazel_cuda_non_rbe.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 3d15f4211a3f..ff1cf9900ce3 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -47,7 +47,7 @@ jobs: # Explicitly set the shell to bash shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} From f1ebb1e1e13e0ee57feb3916f6cb4919cfc0e62c Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 28 Mar 2025 07:16:57 -0700 Subject: [PATCH 0248/1769] Skip failing tests on TPU v6+ PiperOrigin-RevId: 741515935 --- tests/lax_numpy_reducers_test.py | 4 ++-- tests/lax_scipy_test.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff 
--git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 0c3f1d1471fb..aa5e08e96a3e 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -905,8 +905,8 @@ def testCumulativeSumBool(self): @jtu.ignore_warning(category=NumpyComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): - if jtu.is_device_tpu(6): - raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6+") rng = jtu.rand_some_zero(self.rng()) # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 388d053d9608..bc80ed4e1cc2 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -339,8 +339,8 @@ def scipy_fun(z): ) @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") def testLpmn(self, l_max, shape, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -461,8 +461,8 @@ def testSphHarmOrderOneDegreeOne(self): @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -508,8 +508,8 @@ def testSphHarmCornerCaseWithWrongNmax(self): ) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmY(self, l_max, num_z, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) From 563c3e224425d0fe3bf8016105cae7769eb0474b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 07:19:07 -0700 Subject: [PATCH 0249/1769] Add standard pbroadcast rules to more primitives. 
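The recurring recipe, sketched for a hypothetical primitive (illustrative only; `my_prim` and `my_prim_p` are placeholder names, not from this change): the binding site inserts pbroadcasts so all operands agree on their varying-axis set (vma), and the abstract-eval rule threads the joint vma through to the output aval.

```
from jax._src import config, core

def my_prim_bind(*args):
  # Align the varying-axis sets (vma) across operands before binding.
  args = core.standard_insert_pbroadcast(*args)
  return my_prim_p.bind(*args)  # my_prim_p: a placeholder core.Primitive

def _my_prim_abstract_eval(*avals):
  # Compute the joint vma and attach it to the output aval.
  out_vma = (core.standard_vma_rule('my_prim', *avals)
             if config.varying_axes_in_types.value else frozenset())
  return avals[0].update(vma=out_vma)
```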
This should cover all primitives from which shard_map registered standard_rewrite rules PiperOrigin-RevId: 741516445 --- jax/_src/core.py | 8 ++++++-- jax/_src/ffi.py | 9 ++++++--- jax/_src/lax/control_flow/solves.py | 14 ++++++++------ jax/_src/lax/windowed_reductions.py | 2 +- jax/_src/prng.py | 29 ++++++++++++++++++++--------- 5 files changed, 41 insertions(+), 21 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ca353486afd5..1be60336f1a9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1528,7 +1528,7 @@ def check_valid_jaxtype(x): def update_aval_with_sharding(aval, sharding): if isinstance(sharding, NamedSharding): - aval = aval.update(sharding=NamedSharding( + return aval.update(sharding=NamedSharding( sharding.mesh.abstract_mesh, sharding.spec._normalized_spec_for_aval(aval.ndim))) return aval @@ -1659,8 +1659,10 @@ def physical_aval(aval): elt_aval = physical_element_aval(aval.dtype) if isinstance(aval, ShapedArray): from jax._src.sharding_impls import physical_sharding # type: ignore + vma = aval.vma if config.varying_axes_in_types.value else frozenset() return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, - sharding=physical_sharding(aval, aval.sharding)) + sharding=physical_sharding(aval, aval.sharding), + vma=vma) return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) return aval @@ -2019,6 +2021,8 @@ def standard_insert_pbroadcast(*args): if out_vma - src else arg for arg, src in zip(args, in_vma)] def standard_vma_rule(prim_name, *avals, **kwargs): + if not avals: + return avals vma, *vmas = [a.vma for a in avals] if not all(vma == vma_ for vma_ in vmas): raise ValueError( diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 05697f00e945..c867ec16b9b3 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -24,6 +24,7 @@ import jax from jax._src import core +from jax._src import config from jax._src import deprecations from jax._src import dispatch from jax._src import effects @@ -515,7 +516,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any): "and an output with a different layout " f"{static_output_layouts[o_idx]}.") static_input_output_aliases += ((i_idx, o_idx),) - + args = core.standard_insert_pbroadcast(*args) results = ffi_call_p.bind( *args, result_avals=result_avals, @@ -638,9 +639,11 @@ def ffi_call_abstract_eval( has_side_effect: bool, **_, ): - del avals_in # unused + out_vma = (core.standard_vma_rule('ffi_call', *avals_in) + if config.varying_axes_in_types.value else frozenset()) effects = {_FfiEffect} if has_side_effect else core.no_effects - return result_avals, effects + return tuple(r if r is core.abstract_token else r.update(vma=out_vma) + for r in result_avals), effects def ffi_call_jvp(*args, target_name, **_): diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index acfcfd7ff3d3..4a0872bef4b2 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -23,6 +23,7 @@ from jax._src import api from jax._src import api_util from jax._src import core +from jax._src import config from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src.interpreters import ad @@ -309,24 +310,25 @@ def f_aux(x): jaxprs = _LinearSolveTuple( matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) - out_flat = linear_solve_p.bind( - *(_flatten(all_consts) + b_flat), - const_lengths=const_lengths, jaxprs=jaxprs) + args = _flatten(all_consts) + b_flat + args = core.standard_insert_pbroadcast(*args) + out_flat = 
linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs) return tree_unflatten(out_tree, out_flat) def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): args_to_raise = args[sum(const_lengths):] - # raise aux_args to shaped arrays as well if present # number of aux args is the difference in out_avals # of solve and matvec (since they map to the same vector space) - num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return args_to_raise, jaxprs.solve.effects + out_vma = (core.standard_vma_rule('linear_solve', *args_to_raise) + if config.varying_axes_in_types.value else frozenset()) + return (tuple(a.update(vma=out_vma) for a in args_to_raise), + jaxprs.solve.effects) def _custom_linear_solve_impl(*args, const_lengths, jaxprs): diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 00bdfe75f3e7..73fae7df40e1 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -338,7 +338,7 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - out_vma = (core.standard_vma_rule('reduce_window', operand_avals) + out_vma = (core.standard_vma_rule('reduce_window', *operand_avals) if config.varying_axes_in_types.value else frozenset()) return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=out_vma) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index ead939d74351..17d16527bb71 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -178,7 +178,9 @@ def copy_to_host_async(self): def aval(self): logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') else None) - return keys_shaped_array(self._impl, self.shape, logical_sharding) + vma = (self._base_array.aval.vma if config.varying_axes_in_types.value else frozenset() + if hasattr(self._base_array, 'aval') else frozenset()) + return keys_shaped_array(self._impl, self.shape, logical_sharding, vma) @property def shape(self): @@ -329,8 +331,8 @@ def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray return random_seed(seed, impl=impl) -def keys_shaped_array(impl, shape, sharding): - aval = core.ShapedArray(shape, KeyTy(impl)) +def keys_shaped_array(impl, shape, sharding, vma): + aval = core.ShapedArray(shape, KeyTy(impl), vma=vma) return core.update_aval_with_sharding(aval, sharding) def base_arr_shape_to_keys_shape(impl, base_arr_shape): @@ -550,7 +552,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray: @random_seed_p.def_abstract_eval def random_seed_abstract_eval(seeds_aval, *, impl): - return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding) + out_vma = seeds_aval.vma if config.varying_axes_in_types.value else frozenset() + return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, out_vma) @random_seed_p.def_impl def random_seed_impl(seeds, *, impl): @@ -584,8 +587,9 @@ def random_split_abstract_eval(keys_aval, *, shape): # TODO(yashkatariya): random_split should take sharding as an arg too so we # don't choose None here? 
new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) + out_vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec)) + keys_aval.sharding.with_spec(new_spec), out_vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): @@ -611,7 +615,9 @@ def random_split_lowering(ctx, keys, *, shape): def random_fold_in(keys, msgs): - return random_fold_in_p.bind(keys, jnp.asarray(msgs)) + msgs = jnp.asarray(msgs) + keys, msgs = core.standard_insert_pbroadcast(keys, msgs) + return random_fold_in_p.bind(keys, msgs) random_fold_in_p = core.Primitive('random_fold_in') ad.defjvp_zero(random_fold_in_p) @@ -623,7 +629,9 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval): 'random_fold_in', keys_aval, msgs_aval) sharding = lax_internal.broadcasting_sharding_rule( 'random_fold_in', keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding) + vma = (core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) + if config.varying_axes_in_types.value else frozenset()) + return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): @@ -661,7 +669,8 @@ def random_bits(keys, bit_width, shape): def random_bits_abstract_eval(keys_aval, *, bit_width, shape): out_shape = (*keys_aval.shape, *shape) out_dtype = dtypes.dtype(f'uint{bit_width}') - return core.ShapedArray(out_shape, out_dtype) + vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() + return core.ShapedArray(out_shape, out_dtype, vma=vma) @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): @@ -718,7 +727,9 @@ def random_wrap(base_arr, *, impl): def random_wrap_abstract_eval(base_arr_aval, *, impl): shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape) sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding) - return keys_shaped_array(impl, shape, sharding) + out_vma = (base_arr_aval.vma if config.varying_axes_in_types.value else + frozenset()) + return keys_shaped_array(impl, shape, sharding, out_vma) @random_wrap_p.def_impl def random_wrap_impl(base_arr, *, impl): From e679811c4ae245bcc48e9f18a51361cabbdf5561 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 28 Mar 2025 07:20:57 -0700 Subject: [PATCH 0250/1769] [Mosaic GPU] Add warpgroup lowering for `Exp2` in Pallas. This change also enables tests for supported elementwise ops. 
PiperOrigin-RevId: 741516852 --- jax/_src/pallas/mosaic_gpu/lowering.py | 1 + jax/experimental/mosaic/gpu/dialect_lowering.py | 12 ++++++++++++ tests/pallas/mosaic_gpu_test.py | 1 - tests/pallas/ops_test.py | 9 ++++++++- 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 1b4aa33dc909..baac1e6eb316 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1671,6 +1671,7 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): @register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Warpgroup) def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 936bba73915b..f00cff9a500c 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -560,6 +560,12 @@ def _mgpu_async_load_op_lowering_rule( v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=load_op.source, @@ -596,6 +602,12 @@ def _mgpu_async_store_op_lowering_rule( v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=transform_memref(store_op.source, transforms), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c5da44d7b6fa..874ecae93f3f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1418,7 +1418,6 @@ def test_missing_primitive_lowerings_are_tracked(self): actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives expected_missing_primitives = { mgpu_primitives.broadcasted_iota_p, - lax.exp2_p, mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, lax.slice_p, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index f5b70878533d..aeb0ba1cca1a 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -562,7 +562,8 @@ def kernel(*refs): ) @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): - self.skip_if_mosaic_gpu() + if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]: + self.skip_if_mosaic_gpu() if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") @@ -579,6 +580,12 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): def kernel(x_ref, y_ref): y_ref[...] 
= func(x_ref[...]) x_shape_dtype = data.draw(shape_dtype_strategy) + + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu + if sut_is_mosaic_gpu: + hp.assume(math.prod(x_shape_dtype.shape) % 128 == 0) + hp.assume(x_shape_dtype.shape[-1] >= 16) + key = random.key(0) x = _random_value(key, x_shape_dtype) out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) From 431c2c080728a1c880f1facab0bc431631658fe9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Mar 2025 07:44:38 -0700 Subject: [PATCH 0251/1769] cleanup now that we depend on ml_dtypes>=0.5 --- jax/_src/dtypes.py | 119 +++++++++++-------------------- jax/_src/export/serialization.py | 12 ++-- jax/_src/interpreters/mlir.py | 22 ++---- jax/_src/lax/lax.py | 19 ++--- jax/_src/numpy/scalar_types.py | 18 ++--- jax/_src/public_test_util.py | 41 ++++------- jax/_src/test_util.py | 12 ++-- jax/numpy/__init__.py | 26 ++----- jax/tools/jax_to_ir.py | 8 +-- jaxlib/xla/xla_client.py | 11 ++- tests/dtypes_test.py | 25 ++----- tests/jax_to_ir_test.py | 6 +- 12 files changed, 104 insertions(+), 215 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 01500c008405..d1e5b7bf430b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,19 +90,18 @@ def type(self) -> type: ... # fp8 support -# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float8_e3m4: type[np.generic] | None = None -float8_e4m3: type[np.generic] | None = None -float8_e8m0fnu: type[np.generic] | None = None +float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4 +float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3 +float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz -_float8_e3m4_dtype: np.dtype | None = None -_float8_e4m3_dtype: np.dtype | None = None -_float8_e8m0fnu_dtype: np.dtype | None = None +_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4) +_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3) +_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu) _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -111,9 +110,9 @@ def type(self) -> type: ... 
 # fp4 support
-# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
-float4_e2m1fn: type[np.generic] | None = None
+float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn

-_float4_e2m1fn_dtype: np.dtype | None = None
+_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn)

 def supports_inf(dtype: DTypeLike) -> bool:
   """Return true if the dtype supports infinity, else return False."""
@@ -127,6 +126,10 @@ def supports_inf(dtype: DTypeLike) -> bool:
 _bfloat16_dtype: np.dtype = np.dtype(bfloat16)

 _custom_float_scalar_types = [
+    float4_e2m1fn,
+    float8_e3m4,
+    float8_e4m3,
+    float8_e8m0fnu,
     float8_e4m3b11fnuz,
     float8_e4m3fn,
     float8_e4m3fnuz,
@@ -135,6 +138,10 @@ def supports_inf(dtype: DTypeLike) -> bool:
 _custom_float_dtypes = [
+    _float4_e2m1fn_dtype,
+    _float8_e3m4_dtype,
+    _float8_e4m3_dtype,
+    _float8_e8m0fnu_dtype,
     _float8_e4m3b11fnuz_dtype,
     _float8_e4m3fn_dtype,
     _float8_e4m3fnuz_dtype,
@@ -143,6 +150,9 @@ def supports_inf(dtype: DTypeLike) -> bool:
 _float8_dtypes = [
+    _float8_e3m4_dtype,
+    _float8_e4m3_dtype,
+    _float8_e8m0fnu_dtype,
     _float8_e4m3b11fnuz_dtype,
     _float8_e4m3fn_dtype,
     _float8_e4m3fnuz_dtype,
@@ -150,58 +160,28 @@ def supports_inf(dtype: DTypeLike) -> bool:
     _float8_e5m2fnuz_dtype,
 ]

-_float4_dtypes: list[np.dtype] = []
-
-# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
-if hasattr(ml_dtypes, "float8_e4m3"):
-  float8_e4m3 = ml_dtypes.float8_e4m3
-  _float8_e4m3_dtype = np.dtype(float8_e4m3)
-  _custom_float_scalar_types.insert(0, float8_e4m3)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float8_e4m3_dtype)
-  _float8_dtypes.insert(0, _float8_e4m3_dtype)
-if hasattr(ml_dtypes, "float8_e3m4"):
-  float8_e3m4 = ml_dtypes.float8_e3m4
-  _float8_e3m4_dtype = np.dtype(float8_e3m4)
-  _custom_float_scalar_types.insert(0, float8_e3m4)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float8_e3m4_dtype)
-  _float8_dtypes.insert(0, _float8_e3m4_dtype)
-if hasattr(ml_dtypes, "float8_e8m0fnu"):
-  float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
-  _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)
-  _custom_float_scalar_types.insert(0, float8_e8m0fnu)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
-  _float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
-if hasattr(ml_dtypes, "float4_e2m1fn"):
-  float4_e2m1fn = ml_dtypes.float4_e2m1fn
-  _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
-  _custom_float_scalar_types.insert(0, float4_e2m1fn)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
-  _float4_dtypes.insert(0, _float4_e2m1fn_dtype)
-
-# 2-bit integer support
-int2: type[np.generic] | None = None
-uint2: type[np.generic] | None = None
-
-_int2_dtype: np.dtype | None = None
-_uint2_dtype: np.dtype | None = None
-
-_intn_dtypes = []
-
-# Remove the condition once the minimum ml_dtypes version required by JAX
-# contains https://github.com/jax-ml/ml_dtypes/pull/154.
-if hasattr(ml_dtypes, 'int2'):
-  int2 = ml_dtypes.int2
-  uint2 = ml_dtypes.uint2
-  _int2_dtype = np.dtype(int2)
-  _uint2_dtype = np.dtype(uint2)
-  _intn_dtypes.extend([_int2_dtype, _uint2_dtype])
+_float4_dtypes: list[np.dtype] = [
+    _float4_e2m1fn_dtype,
+]
+
+int2: type[np.generic] = ml_dtypes.int2
+uint2: type[np.generic] = ml_dtypes.uint2
+
+_int2_dtype: np.dtype = np.dtype(int2)
+_uint2_dtype: np.dtype = np.dtype(uint2)

 # 4-bit integer support
 int4: type[np.generic] = ml_dtypes.int4
 uint4: type[np.generic] = ml_dtypes.uint4
 _int4_dtype = np.dtype(int4)
 _uint4_dtype = np.dtype(uint4)
-_intn_dtypes.extend([_int4_dtype, _uint4_dtype])
+
+_intn_dtypes = [
+    _int2_dtype,
+    _uint2_dtype,
+    _int4_dtype,
+    _uint4_dtype,
+]

 # Default types.
 bool_ = np.bool_
@@ -472,9 +452,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
   # to the normal scalar type hierarchy.
   if a_sctype in _custom_float_scalar_types:
     return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic}
-  if (int2 is not None and a_sctype == int2) or a_sctype == int4:
+  if a_sctype in [int2, int4]:
     return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic}
-  if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4:
+  if a_sctype in [uint2, uint4]:
     return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}

   # Otherwise, fall back to numpy.issubdtype
@@ -491,6 +471,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
 _unsigned_types: list[JAXType]
 _int_types: list[JAXType]
 _unsigned_types = [
+    np.dtype(uint2),
     np.dtype(uint4),
     np.dtype('uint8'),
     np.dtype('uint16'),
@@ -498,6 +479,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
     np.dtype('uint64'),
 ]
 _signed_types = [
+    np.dtype(int2),
     np.dtype(int4),
     np.dtype('int8'),
     np.dtype('int16'),
@@ -505,11 +487,6 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
     np.dtype('int64'),
 ]

-if _int2_dtype is not None:
-  _signed_types.insert(0, _int2_dtype)
-if _uint2_dtype is not None:
-  _unsigned_types.insert(0, _uint2_dtype)
-
 _int_types = _unsigned_types + _signed_types

 _float_types: list[JAXType] = [
@@ -622,11 +599,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
   This DAG maps each type to its immediately higher type on the lattice.
""" b1, = _bool_types - if _int2_dtype is not None: - assert _uint2_dtype is not None - _uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types - else: - uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types + uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types *f1_types, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types @@ -634,19 +607,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis out: dict[JAXType, list[JAXType]] out = { b1: [i_], - i_: [u1, uint4, i1, int4], - uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], + i_: [u1, uint2, uint4, i1, int2, int4], + uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], + int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], f_: [*f1_types, bf, f2, c_], **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } - if _int2_dtype is not None: - out[i_].append(_int2_dtype) - out[_int2_dtype] = [] - if _uint2_dtype is not None: - out[i_].append(_uint2_dtype) - out[_uint2_dtype] = [] return out elif jax_numpy_dtype_promotion == 'strict': return { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index ac97c11d1177..94c0baf642b6 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -357,16 +357,12 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz, dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2, dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, + dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4, + dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, + dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, + dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, } -if dtypes._float8_e3m4_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 -if dtypes._float8_e4m3_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 -if dtypes._float8_e8m0fnu_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu -if dtypes._float4_e2m1fn_dtype is not None: - _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a707981f5403..23d1b5dd9d89 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -185,24 +185,14 @@ def _is_ir_values(x: IrValues) -> bool: np.dtype(np.float64): ir.F64Type.get, np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), + np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2), + np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2), + np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get, + np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get, + np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get, + np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get, } - -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) - 
-if dtypes.float8_e3m4 is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
-if dtypes.float8_e4m3 is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
-if dtypes.float8_e8m0fnu is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
-
-if dtypes.float4_e2m1fn is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
-
 def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
   if isinstance(dtype, core.bint):
     # TODO Support different-size underlying dtypes to take advantage of the

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index b79c81e19195..a4fb04698365 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -2346,13 +2346,10 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
        np.dtype(dtypes.float8_e4m3fnuz),
        np.dtype(dtypes.float8_e5m2),
        np.dtype(dtypes.float8_e5m2fnuz),
+        np.dtype(dtypes.float8_e3m4),
+        np.dtype(dtypes.float8_e4m3),
+        np.dtype(dtypes.float8_e8m0fnu),
    ]
-    if dtypes.float8_e3m4 is not None:
-      fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
-    if dtypes.float8_e4m3 is not None:
-      fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
-    if dtypes.float8_e8m0fnu is not None:
-      fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
    if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
      raise ValueError(
          f"The dot algorithm '{self}' requires both inputs to have float8 "
@@ -5602,13 +5599,9 @@ def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr:
 def _handle_dot_precision(ctx, lhs, rhs, precision, platform):
  def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
    fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
-                  dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
-    if dtypes.float8_e3m4 is not None:
-      fp8_dtypes += (dtypes.float8_e3m4,)
-    if dtypes.float8_e4m3 is not None:
-      fp8_dtypes += (dtypes.float8_e4m3,)
-    if dtypes.float8_e8m0fnu is not None:
-      fp8_dtypes += (dtypes.float8_e8m0fnu,)
+                  dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz,
+                  dtypes.float8_e3m4, dtypes.float8_e4m3,
+                  dtypes.float8_e8m0fnu)
    return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
  # The *_ lets us reuse this for ragged_dot_general, which has group_sizes.
diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py
index 2f9954488b41..2b0e04adc997 100644
--- a/jax/_src/numpy/scalar_types.py
+++ b/jax/_src/numpy/scalar_types.py
@@ -68,33 +68,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
   return meta

 bool_ = _make_scalar_type(np.bool_)
-if dtypes.uint2 is not None:
-  uint2 = _make_scalar_type(dtypes.uint2)
+uint2 = _make_scalar_type(dtypes.uint2)
 uint4 = _make_scalar_type(dtypes.uint4)
 uint8 = _make_scalar_type(np.uint8)
 uint16 = _make_scalar_type(np.uint16)
 uint32 = _make_scalar_type(np.uint32)
 uint64 = _make_scalar_type(np.uint64)
-if dtypes.int2 is not None:
-  int2 = _make_scalar_type(dtypes.int2)
+int2 = _make_scalar_type(dtypes.int2)
 int4 = _make_scalar_type(dtypes.int4)
 int8 = _make_scalar_type(np.int8)
 int16 = _make_scalar_type(np.int16)
 int32 = _make_scalar_type(np.int32)
 int64 = _make_scalar_type(np.int64)
-if dtypes.float8_e3m4 is not None:
-  float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
-if dtypes.float8_e4m3 is not None:
-  float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
-if dtypes.float8_e8m0fnu is not None:
-  float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
+float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
+float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
+float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
+float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
 float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
 float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
 float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
 float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
 float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
-if dtypes.float4_e2m1fn is not None:
-  float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
 bfloat16 = _make_scalar_type(dtypes.bfloat16)
 float16 = _make_scalar_type(np.float16)
 float32 = single = _make_scalar_type(np.float32)

diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py
index 59ddb73dc9e1..3b1e24bc9c50 100644
--- a/jax/_src/public_test_util.py
+++ b/jax/_src/public_test_util.py
@@ -46,16 +46,22 @@ def _dtype(x: Any) -> np.dtype:
 _default_tolerance: ToleranceDict = {
   _dtypes.float0: 0,
   np.dtype(np.bool_): 0,
+  np.dtype(_dtypes.int2): 0,
   np.dtype(_dtypes.int4): 0,
   np.dtype(np.int8): 0,
   np.dtype(np.int16): 0,
   np.dtype(np.int32): 0,
   np.dtype(np.int64): 0,
+  np.dtype(_dtypes.uint2): 0,
   np.dtype(_dtypes.uint4): 0,
   np.dtype(np.uint8): 0,
   np.dtype(np.uint16): 0,
   np.dtype(np.uint32): 0,
   np.dtype(np.uint64): 0,
+  np.dtype(_dtypes.float4_e2m1fn): 1e0,
+  np.dtype(_dtypes.float8_e3m4): 1e-1,
+  np.dtype(_dtypes.float8_e4m3): 1e-1,
+  np.dtype(_dtypes.float8_e8m0fnu): 1e0,
   np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
   np.dtype(_dtypes.float8_e4m3fn): 1e-1,
   np.dtype(_dtypes.float8_e4m3fnuz): 1e-1,
@@ -69,16 +75,15 @@ def _dtype(x: Any) -> np.dtype:
   np.dtype(np.complex128): 1e-15,
 }

-if _dtypes.int2 is not None:
-  assert _dtypes.uint2 is not None
-  _default_tolerance[np.dtype(_dtypes.int2)] = 0
-  _default_tolerance[np.dtype(_dtypes.uint2)] = 0
-
 def default_tolerance():
   return _default_tolerance

 default_gradient_tolerance: ToleranceDict = {
+  np.dtype(_dtypes.float4_e2m1fn): 1e0,
+  np.dtype(_dtypes.float8_e3m4): 1e-1,
+  np.dtype(_dtypes.float8_e4m3): 1e-1,
+  np.dtype(_dtypes.float8_e8m0fnu): 1e0,
   np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
   np.dtype(_dtypes.float8_e4m3fn): 1e-1,
   np.dtype(_dtypes.float8_e4m3fnuz): 1e-1,
@@ -92,19 +97,6 @@ def default_tolerance():
   np.dtype(np.complex128): 1e-5,
 }

-# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
-if _dtypes.float8_e3m4 is not None:
-  _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
-  default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
-if _dtypes.float8_e4m3 is not None:
-  _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
-  default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
-if _dtypes.float8_e8m0fnu is not None:
-  _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
-  default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
-if _dtypes.float4_e2m1fn is not None:
-  _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
-  default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0

 def is_python_scalar(val: Any) -> bool:
   return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@@ -115,6 +107,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
     return

   custom_float_dtypes = [
+      _dtypes.float4_e2m1fn,
+      _dtypes.float8_e8m0fnu,
+      _dtypes.float8_e3m4,
+      _dtypes.float8_e4m3,
       _dtypes.float8_e4m3b11fnuz,
       _dtypes.float8_e4m3fn,
       _dtypes.float8_e4m3fnuz,
@@ -123,15 +119,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
       _dtypes.bfloat16,
   ]

-  if _dtypes.float8_e4m3 is not None:
-    custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
-  if _dtypes.float8_e3m4 is not None:
-    custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
-  if _dtypes.float8_e8m0fnu is not None:
-    custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
-  if _dtypes.float4_e2m1fn is not None:
-    custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
-
   def maybe_upcast(x):
     if x.dtype in custom_float_dtypes:
       return x.astype(np.float32)

diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 1cd9546a1655..c3f7fb4c4139 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1632,15 +1632,11 @@ def custom_floats(self):
       _dtypes.float8_e4m3fnuz,
       _dtypes.float8_e5m2,
       _dtypes.float8_e5m2fnuz,
+      _dtypes.float8_e3m4,
+      _dtypes.float8_e4m3,
+      _dtypes.float8_e8m0fnu,
+      _dtypes.float4_e2m1fn,
     ]
-    if _dtypes.float8_e3m4 is not None:
-      float_dtypes += [_dtypes.float8_e3m4]
-    if _dtypes.float8_e4m3 is not None:
-      float_dtypes += [_dtypes.float8_e4m3]
-    if _dtypes.float8_e8m0fnu is not None:
-      float_dtypes += [_dtypes.float8_e8m0fnu]
-    if _dtypes.float4_e2m1fn is not None:
-      float_dtypes += [_dtypes.float4_e2m1fn]
     return self.supported(float_dtypes)

   @_cached_property

diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py
index 31cca3578916..b6cfb1ff06ac 100644
--- a/jax/numpy/__init__.py
+++ b/jax/numpy/__init__.py
@@ -211,13 +211,18 @@
     double as double,
     float16 as float16,
     float32 as float32,
+    float4_e2m1fn as float4_e2m1fn,
     float64 as float64,
+    float8_e3m4 as float8_e3m4,
+    float8_e4m3 as float8_e4m3,
     float8_e4m3b11fnuz as float8_e4m3b11fnuz,
     float8_e4m3fn as float8_e4m3fn,
     float8_e4m3fnuz as float8_e4m3fnuz,
     float8_e5m2 as float8_e5m2,
     float8_e5m2fnuz as float8_e5m2fnuz,
+    float8_e8m0fnu as float8_e8m0fnu,
     float_ as float_,
+    int2 as int2,
     int4 as int4,
     int8 as int8,
     int16 as int16,
@@ -226,6 +231,7 @@
     int_ as int_,
     single as single,
     uint as uint,
+    uint2 as uint2,
     uint4 as uint4,
     uint8 as uint8,
     uint16 as uint16,
@@ -295,26 +301,6 @@
     unsignedinteger as unsignedinteger,
 )

-# TODO(slebedev): Remove the try-except once we upgrade to ml_dtypes 0.4.1.
-try:
-  from jax._src.numpy.scalar_types import (
-    int2 as int2,
-    uint2 as uint2,
-  )
-except ImportError:
-  pass
-
-# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
-try:
-  from jax._src.numpy.scalar_types import (
-    float8_e3m4 as float8_e3m4,
-    float8_e4m3 as float8_e4m3,
-    float8_e8m0fnu as float8_e8m0fnu,
-    float4_e2m1fn as float4_e2m1fn,
-  )
-except ImportError:
-  pass
-
 from jax._src.numpy.array_api_metadata import (
   __array_api_version__ as __array_api_version__,
   __array_namespace_info__ as __array_namespace_info__,

diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py
index 904ce509a87e..47b85382f8bf 100644
--- a/jax/tools/jax_to_ir.py
+++ b/jax/tools/jax_to_ir.py
@@ -240,16 +240,12 @@ def parse_shape_str(s):

 _DT = {
     'pred': jnp.bool_,
-    'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
-    's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
+    'u2': jnp.uint2, 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
+    's2': jnp.int2, 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
     'bf16': jnp.bfloat16,
     'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64,
     'c64': jnp.complex64, 'c128': jnp.complex128
 }
-if hasattr(jnp, 'int2'):
-  _DT['s2'] = jnp.int2
-if hasattr(jnp, 'uint2'):
-  _DT['u2'] = jnp.uint2
 _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")

diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py
index 776a22444208..21ea81ac6efa 100644
--- a/jaxlib/xla/xla_client.py
+++ b/jaxlib/xla/xla_client.py
@@ -238,13 +238,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
 PrimitiveType = _xla.PrimitiveType

 bfloat16 = ml_dtypes.bfloat16
-# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-# Also, it would be better to conditionally import these based on whether they
-# are in the current version of ml_dtypes.
-# float4_e2m1fn = ml_dtypes.float4_e2m1fn
-# float8_e3m4 = ml_dtypes.float8_e3m4
-# float8_e4m3 = ml_dtypes.float8_e4m3
-# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
+float4_e2m1fn = ml_dtypes.float4_e2m1fn
+float8_e3m4 = ml_dtypes.float8_e3m4
+float8_e4m3 = ml_dtypes.float8_e4m3
+float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
 float8_e4m3fn = ml_dtypes.float8_e4m3fn
 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
 float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz

diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py
index 87380443f4cb..d8fb30397b27 100644
--- a/tests/dtypes_test.py
+++ b/tests/dtypes_test.py
@@ -46,30 +46,19 @@
                  np.dtype('uint64')]
 unsigned_dtypes = list(np_unsigned_dtypes)

-intn_dtypes = [np.dtype('int4'), np.dtype('uint4')]
-signed_dtypes += [np.dtype('int4')]
-unsigned_dtypes += [np.dtype('uint4')]
-if dtypes.int2 is not None:
-  assert dtypes.uint2 is not None
-  intn_dtypes[:0] = [np.dtype('int2'), np.dtype('uint2')]
-  signed_dtypes[:0] = [np.dtype('int2')]
-  unsigned_dtypes[:0] = [np.dtype('uint2')]
-
-np_float_dtypes = [np.dtype('float16'), np.dtype('float32'),
-                   np.dtype('float64')]
+intn_dtypes = [np.dtype('int2'), np.dtype('uint2'), np.dtype('int4'), np.dtype('uint4')]
+signed_dtypes += [np.dtype('int2'), np.dtype('int4')]
+unsigned_dtypes += [np.dtype('uint2'), np.dtype('uint4')]
+
+np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), np.dtype('float64')]

 float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes
 custom_float_dtypes = [np.dtype(dtypes.bfloat16)]
 fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
               np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
-              np.dtype(dtypes.float8_e5m2fnuz)]
-if dtypes.float8_e3m4 is not None:
-  fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
-if dtypes.float8_e4m3 is not None:
-  fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
-if dtypes.float8_e8m0fnu is not None:
-  fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
+              np.dtype(dtypes.float8_e5m2fnuz), np.dtype(dtypes.float8_e3m4),
+              np.dtype(dtypes.float8_e4m3), np.dtype(dtypes.float8_e8m0fnu)]
 float_dtypes += fp8_dtypes
 custom_float_dtypes += fp8_dtypes

diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py
index f600a08f5dc4..4eb8190b712f 100644
--- a/tests/jax_to_ir_test.py
+++ b/tests/jax_to_ir_test.py
@@ -114,15 +114,13 @@ def test_parse_shape_str(self):
     self.assertParsedShape('f32[]', [], jnp.float32)
     self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32)
     self.assertParsedShape('pred[1]', [1], jnp.bool_)
-    if hasattr(jnp, 'int2'):
-      self.assertParsedShape('s2[1]', [1], jnp.int2)
+    self.assertParsedShape('s2[1]', [1], jnp.int2)
     self.assertParsedShape('s4[1]', [1], jnp.int4)
     self.assertParsedShape('s8[1]', [1], jnp.int8)
     self.assertParsedShape('s16[1]', [1], jnp.int16)
     self.assertParsedShape('s32[1]', [1], jnp.int32)
     self.assertParsedShape('s64[1]', [1], jnp.int64)
-    if hasattr(jnp, 'uint2'):
-      self.assertParsedShape('u2[1]', [1], jnp.uint2)
+    self.assertParsedShape('u2[1]', [1], jnp.uint2)
     self.assertParsedShape('u4[1]', [1], jnp.uint4)
     self.assertParsedShape('u8[1]', [1], jnp.uint8)
     self.assertParsedShape('u16[1]', [1], jnp.uint16)

From 4a8f520a8747207cb7d5b7f08cc4a7d6418aa539 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 28 Mar 2025 07:53:06 -0700
Subject: [PATCH 0252/1769] Replace uses of deprecated `Shape::rank()` with:
 - `dimensions().size()` if it's OK for the result to be changed to an
   unsigned number,
 - `dimensions_size()` if it's important that the result is a signed number.
This should be a pure refactoring that doesn't affect the code's behavior.

Note that `rank()` returns `int64_t` and `dimensions().size()` returns
`size_t`. Sometimes the change of the signedness is not desirable, and we use
`dimensions_size()`, which returns `int`, in such cases.

PiperOrigin-RevId: 741524661
---
 jaxlib/gpu/py_client_gpu.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc
index 59cc385825a0..861ffce3e749 100644
--- a/jaxlib/gpu/py_client_gpu.cc
+++ b/jaxlib/gpu/py_client_gpu.cc
@@ -182,7 +182,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     options.dims = absl::Span<const int64_t>(
         reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
     absl::InlinedVector reversed_layout;
-    reversed_layout.resize(expected_shape.rank());
+    reversed_layout.resize(expected_shape.dimensions().size());
     absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
                          reversed_layout.begin());
     options.permutation = reversed_layout;

From cf12cc5fc5cd9b76e3a09da99084fc9a1e943b09 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Fri, 28 Mar 2025 08:05:04 -0700
Subject: [PATCH 0253/1769] [Mosaic GPU] Ignore layouts that are already set
 when computing default vector size in layout inference.

PiperOrigin-RevId: 741528085
---
 .../mosaic/gpu/inference_utils.py  | 20 +++++++++--
 .../mosaic/gpu/layout_inference.py | 34 ++++++++++++-------
 2 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py
index 6362626404c5..73ce23c427cd 100644
--- a/jax/experimental/mosaic/gpu/inference_utils.py
+++ b/jax/experimental/mosaic/gpu/inference_utils.py
@@ -95,6 +95,22 @@ def has_out_transforms_set(op: MlirOperation) -> bool:
   return "out_transforms" in op.attributes

+def attr_element(
+    attr_name: str, op: MlirOperation, index: int
+) -> ir.Attribute | None:
+  """Returns `op.attributes[attr_name][index]` if it exists, otherwise None.
+
+  If `op.attributes[attr_name]` exists, then `index` must be a valid index into
+  the attribute array.
+  """
+  if attr_name not in op.attributes:
+    return None
+  attr = op.attributes[attr_name]
+  if not attr:
+    return None
+  return op.attributes[attr_name][index]  # type: ignore
+
+
 def _in_attr_for_operand(
     op: MlirOperation,
     operand: ir.Value,
@@ -109,9 +125,7 @@ def _in_attr_for_operand(

   operand_number = [o for o in op.operands if predicate(o)].index(operand)

-  if attr_name not in op.attributes:
-    return None
-  return op.attributes[attr_name][operand_number]  # type: ignore
+  return attr_element(attr_name, op, operand_number)


 in_layout_for_operand = partial(

diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index 402a8c08a4ef..e49b3677b2ad 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -548,21 +548,31 @@ def inference_step(op: ir.Operation):
  # make sure to derive a single vector size in order to avoid relayouts at
  # lowering time.
  default_vector_size = math.inf
-
-  def update_default_vector_size(op: ir.OpView):
+  def update_default_vector_size_from_vector(v: ir.Value):
    nonlocal default_vector_size
-    for v in list(op.operands) + list(op.results):
-      if ir.VectorType.isinstance(v.type):
-        max_vec_size_for_v = (
-            np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE
-        )
-        desired_vec_size = 8 // utils.bytewidth(v.type.element_type)
-        default_vector_size = min(
-            default_vector_size, max_vec_size_for_v, desired_vec_size
-        )
+    max_vec_size_for_v = (
+        np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE
+    )
+    desired_vec_size = 8 // utils.bytewidth(v.type.element_type)
+    default_vector_size = min(
+        default_vector_size, max_vec_size_for_v, desired_vec_size
+    )
+
+  def update_default_vector_size_from_op(op: ir.OpView):
+    for i, v in enumerate(
+        filter(lambda v: ir.VectorType.isinstance(v.type), op.operands)
+    ):
+      if inference_utils.attr_element("in_layouts", op, i) is None:
+        update_default_vector_size_from_vector(v)
+
+    for i, v in enumerate(
+        filter(lambda v: ir.VectorType.isinstance(v.type), op.results)
+    ):
+      if inference_utils.attr_element("out_layouts", op, i) is None:
+        update_default_vector_size_from_vector(v)

  for op in module.body:
-    traverse_op(op, update_default_vector_size)
+    traverse_op(op, update_default_vector_size_from_op)

  if default_vector_size == math.inf:  # Nothing to annotate.
    return

From 968bbd2bf25e3ace63a4e6938adc70d5e4540caa Mon Sep 17 00:00:00 2001
From: Ayaka
Date: Fri, 28 Mar 2025 08:09:01 -0700
Subject: [PATCH 0254/1769] Add a small atol bump to `betainc` test in
 `LaxVmapOpTest`

PiperOrigin-RevId: 741529177
---
 jax/_src/internal_test_util/lax_test_util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py
index 4e28791e9cee..767b41dc8ba0 100644
--- a/jax/_src/internal_test_util/lax_test_util.py
+++ b/jax/_src/internal_test_util/lax_test_util.py
@@ -304,7 +304,7 @@ def lax_ops():
          float_dtypes,
          test_util.rand_uniform,
          {
-              np.float32: 1e-5,
+              np.float32: 2e-5,
              np.float64: 1e-12,
          },
      ),

From d974b090565022ef7139c4c407a047a7f2e406ea Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 28 Mar 2025 08:29:13 -0700
Subject: [PATCH 0255/1769] Fix error in build.py when trying to build aarch64
 jaxlib wheel.

PiperOrigin-RevId: 741534342
---
 build/build.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/build/build.py b/build/build.py
index 1900073fc132..f8c0ccbfa6a4 100755
--- a/build/build.py
+++ b/build/build.py
@@ -414,10 +414,7 @@ async def main():
  for option in args.bazel_startup_options:
    bazel_command_base.append(option)

-  if (
-      not hasattr(args,"use_new_wheel_build_rule")
-      or args.command == "requirements_update"
-  ):
+  if args.command == "requirements_update" or not args.use_new_wheel_build_rule:
    bazel_command_base.append("run")
  else:
    bazel_command_base.append("build")

From 98b763cfe48a14749252e29ceb862f9ca228ccbe Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Fri, 28 Mar 2025 08:45:45 -0700
Subject: [PATCH 0256/1769] Use a 16 core Windows runner when building
 artifacts

Also, switch the Linux aarch64 runner type to t2a as we run the tests on t2a.
PiperOrigin-RevId: 741538543
---
 .github/workflows/build_artifacts.yml        | 12 ++++++------
 .github/workflows/wheel_tests_continuous.yml |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml
index c2e7acb91f7a..37a791784506 100644
--- a/.github/workflows/build_artifacts.yml
+++ b/.github/workflows/build_artifacts.yml
@@ -16,8 +16,8 @@ on:
        default: "linux-x86-n2-16"
        options:
          - "linux-x86-n2-16"
-          - "linux-arm64-c4a-64"
-          - "windows-x86-n2-64"
+          - "linux-arm64-t2a-48"
+          - "windows-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: choice
@@ -119,11 +119,11 @@ jobs:
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
-      - name: Enable RBE if building on Linux x86
-        if: contains(inputs.runner, 'linux-x86')
+      - name: Enable RBE if building on Linux x86 or Windows x86
+        if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
        run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
-      - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86
-        if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86')
+      - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
+        if: contains(inputs.runner, 'linux-arm64')
        run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
      # Halt for testing
      - name: Wait For Connection

diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml
index 530e0a9b0768..3739c9267730 100644
--- a/.github/workflows/wheel_tests_continuous.yml
+++ b/.github/workflows/wheel_tests_continuous.yml
@@ -44,7 +44,7 @@ jobs:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix stategy in the CPU tests job
-        runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"]
+        runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
        artifact: ["jaxlib"]
        python: ["3.10"]
    # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the

From 4bfe0d10e95cc6d14eec74c46dcaf897322044ee Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 28 Mar 2025 09:13:50 -0700
Subject: [PATCH 0257/1769] Remove get_emit_python_callback_descriptor from
 the type stubs.

The function itself was already deleted.

PiperOrigin-RevId: 741546212
---
 jaxlib/xla/xla_extension/__init__.pyi | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi
index 3a6435824b67..d002080b17bc 100644
--- a/jaxlib/xla/xla_extension/__init__.pyi
+++ b/jaxlib/xla/xla_extension/__init__.pyi
@@ -580,12 +580,6 @@ class Client:
  ) -> LoadedExecutable: ...
  def heap_profile(self) -> bytes: ...
  def defragment(self) -> _Status: ...
-  def get_emit_python_callback_descriptor(
-      self,
-      callable: Callable,
-      operand_shapes: Sequence[Shape],
-      results_shapes: Sequence[Shape],
-  ) -> Tuple[Any, Any]: ...
  def make_python_callback_from_host_send_and_recv(
      self,
      callable: Callable,

From 5495c56990956c92f2c671a47c99cc85c018df05 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 28 Mar 2025 09:21:29 -0700
Subject: [PATCH 0258/1769] Remove a use of XlaComputation from call_tf.

call_tf is the only remaining user of the XlaComputation type in JAX.
Change it to use a new helper function that converts an HLO proto to stablehlo
bytecode without using the XlaComputation Python bindings. Also port the code
to parse types from the stablehlo rather than the HLO.

Remove jax.interpreters.mlir.xla_computation_to_mlir_module.

PiperOrigin-RevId: 741548298
---
 jax/experimental/jax2tf/call_tf.py | 99 +++++++++++++++++++++---------
 jax/interpreters/mlir.py           |  1 -
 jaxlib/xla/mlir.cc                 | 14 +++++
 jaxlib/xla/xla_client.py           |  2 +-
 jaxlib/xla/xla_extension/mlir.pyi  |  1 +
 5 files changed, 85 insertions(+), 32 deletions(-)

diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py
index 98c1c20cd6e5..3b175cd64c4c 100644
--- a/jax/experimental/jax2tf/call_tf.py
+++ b/jax/experimental/jax2tf/call_tf.py
@@ -41,11 +41,14 @@
 from jax._src import effects
 from jax._src import util
 from jax._src.lib import xla_client
+from jax._src.lib import xla_extension as _xla
+from jax._src.lib import jaxlib_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib.mlir.dialects import hlo
 from jax.experimental.jax2tf import jax2tf as jax2tf_internal
 from jax._src.interpreters import mlir
+import ml_dtypes
 import numpy as np
 import tensorflow as tf
@@ -468,6 +471,47 @@ def is_fully_known_shape(s):
 call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)

+def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype:
+  """Converts an MLIR scalar type to a NumPy dtype."""
+
+  if ir.IntegerType.isinstance(type):
+    type = ir.IntegerType(type)
+    width = type.width
+    if width == 1:
+      return np.dtype(np.bool_)
+    elif width == 8:
+      return np.dtype(np.uint8 if type.is_unsigned else np.int8)
+    elif width == 16:
+      return np.dtype(np.uint16 if type.is_unsigned else np.int16)
+    elif width == 32:
+      return np.dtype(np.uint32 if type.is_unsigned else np.int32)
+    elif width == 64:
+      return np.dtype(np.uint64 if type.is_unsigned else np.int64)
+    else:
+      raise ValueError(f"Unsupported integer width: {width}")
+
+  elif ir.F16Type.isinstance(type):
+    return np.dtype(np.float16)
+  elif ir.F32Type.isinstance(type):
+    return np.dtype(np.float32)
+  elif ir.F64Type.isinstance(type):
+    return np.dtype(np.float64)
+  elif ir.BF16Type.isinstance(type):
+    return np.dtype(ml_dtypes.bfloat16)
+
+  elif ir.ComplexType.isinstance(type):
+    element_type = ir.ComplexType(type).element_type
+    if ir.F32Type.isinstance(element_type):
+      return np.dtype(np.complex64)
+    elif ir.F64Type.isinstance(element_type):
+      return np.dtype(np.complex128)
+    else:
+      raise ValueError(f"Unsupported complex element type: {element_type}")
+
+  else:
+    raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}")
+
+
 def _call_tf_lowering(
     ctx: mlir.LoweringRuleContext,
     *args_op,
@@ -555,33 +599,12 @@ def convert_to_spec(x):
           "\n\nCaught TensorFlow exception: " + str(e))
       raise ValueError(msg) from e

-  xla_comp = xla_client.XlaComputation(func_tf_hlo)
-
-  # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
-  def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
-    if not res_shape.is_static():
-      msg = ("Compiled TensorFlow function has dynamic output shape " +
-             f"{res_shape}. call_tf can used " +
-             "in a staged context (under jax.jit, lax.scan, etc.) only with " +
-             "compilable functions with static output shapes. " +
-             "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
-      raise ValueError(msg)
-
-    res_dtype = res_shape.numpy_dtype()
-    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
-    return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)
-
-  result_shape = xla_comp.program_shape().result_shape()
-  if not result_shape.is_tuple():
-    # TF does not wrap singletons as tuples, but JAX expects tuples because
-    # call_tf is a multiple_results primitive.
-    result_shapes = (result_shape,)
+  if jaxlib_extension_version >= 324:
+    stablehlo = _xla.mlir.hlo_to_stablehlo(func_tf_hlo)
   else:
-    result_shapes = result_shape.tuple_shapes()  # type: ignore
-
-  result_avals = tuple(map(canonical_res_aval, result_shapes))
-
-  submodule = mlir.xla_computation_to_mlir_module(xla_comp)
+    xla_comp = xla_client.XlaComputation(func_tf_hlo)
+    stablehlo = _xla.mlir.xla_computation_to_mlir_module(xla_comp)
+  submodule = ir.Module.parse(stablehlo)
   symtab = ir.SymbolTable(submodule.operation)
   callee_result_types = symtab["main"].type.results
   fn = mlir.merge_mlir_modules(ctx.module_context.module,
@@ -600,10 +623,26 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
   )

   outputs = []
-  for op, res_aval, res_shape in zip(flat_results, result_avals,
-                                     result_shapes):
-    if res_aval.dtype != res_shape.numpy_dtype():
-      op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
+  for op, res_type in zip(flat_results, callee_result_types):
+    if not res_type.has_static_shape:
+      msg = (
+          "Compiled TensorFlow function has dynamic output shape "
+          + f"{res_type}. call_tf can be used in a staged context (under jax.jit,"
+          " lax.scan, etc.) only with compilable functions with static"
+          " output shapes. See"
+          " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
+          " for a discussion."
+      )
+      raise ValueError(msg)
+
+    res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type)
+    # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
+    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
+    if res_dtype != jax_res_dtype:
+      op = hlo.ConvertOp(
+          mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)),
+          op,
+      ).result
     outputs.append(op)

   return outputs

diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 0f32799f7ea9..8a615be968a6 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -63,7 +63,6 @@
     register_lowering as register_lowering,
     shape_tensor as shape_tensor,
     token_type as token_type,
-    xla_computation_to_mlir_module as xla_computation_to_mlir_module,
 )

 from jax._src.mesh import Mesh as Mesh

diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc
index 987856daa983..e045f7284ec6 100644
--- a/jaxlib/xla/mlir.cc
+++ b/jaxlib/xla/mlir.cc
@@ -75,6 +75,17 @@ void EnablePrintBeforeAndAfter(mlir::PassManager& pm) {
   pm.enableIRPrinting(print_before, print_after);
 }

+absl::StatusOr<nb::bytes> HloToStableHlo(const nb::bytes& hlo_module_proto) {
+  mlir::MLIRContext context;
+  if (VLOG_IS_ON(3)) context.disableMultithreading();
+  HloModuleProto proto;
+  proto.ParseFromArray(hlo_module_proto.c_str(), hlo_module_proto.size());
+  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
+                      ConvertHloToStablehlo(context, &proto));
+  TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module));
+  return nb::bytes(bytecode.data(), bytecode.size());
+}
+
 // Converts an XlaComputation to a StableHLO mlir::Module string.
 // Exists for backwards compatibility.
 // TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules
@@ -180,6 +191,9 @@ absl::StatusOr PyDeserializePortableArtifact(
 void BuildMlirSubmodule(nb::module_& m) {
   nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration");

+  mlir_module.def("hlo_to_stablehlo", xla::ValueOrThrowWrapper(HloToStableHlo),
+                  nb::arg("computation"));
+
   mlir_module.def("xla_computation_to_mlir_module",
                   xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule),
                   nb::arg("computation"));

diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py
index 776a22444208..7c4e2ccb427f 100644
--- a/jaxlib/xla/xla_client.py
+++ b/jaxlib/xla/xla_client.py
@@ -50,7 +50,7 @@

 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
-_version = 323
+_version = 324

 # An internal increasing version number for protecting jaxlib code against
 # ifrt changes.

diff --git a/jaxlib/xla/xla_extension/mlir.pyi b/jaxlib/xla/xla_extension/mlir.pyi
index 95eeae660c0c..961f01a0352c 100644
--- a/jaxlib/xla/xla_extension/mlir.pyi
+++ b/jaxlib/xla/xla_extension/mlir.pyi
@@ -16,6 +16,7 @@ from typing import Union

 from . import XlaComputation

+def hlo_to_stablehlo(computation: bytes) -> bytes: ...
 def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ...
 def mlir_module_to_xla_computation(
     mlir_module: Union[bytes, str],

From 8c737993e94d8106e0641f565bc83a8632a03ec1 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 28 Mar 2025 09:32:42 -0700
Subject: [PATCH 0259/1769] Change the `step counter` to an `init flag`

It is clearer to use a flag to indicate the first step than to use a step
counter == 0, since in theory the step counter (a 32 bit integer in the code)
can wrap around back to zero, even though this is unlikely to happen since
there are far fewer than 2**32 blocks.

PiperOrigin-RevId: 741551623
---
 .../paged_attention/paged_attention_kernel.py | 33 +++++++++----------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
index 99cb2c9c94c1..62f3101bef6e 100644
--- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
@@ -114,7 +114,7 @@ def paged_flash_attention_kernel(
     lengths_ref,
     page_indices_ref,
     buffer_index_ref,
-    step_ref,
+    init_flag_ref,
     q_ref,
     k_pages_hbm_ref,
     k_scales_pages_hbm_ref,
@@ -223,16 +223,12 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index):
   @pl.when(i * bk < length)
   def flash_attention():  # pylint: disable=unused-variable
-    step = step_ref[0]
+    init_flag = init_flag_ref[0]
+    init_flag_ref[0] = 0
     buffer_index = buffer_index_ref[0]
+    next_b, next_h, next_i = compute_block_indices(b, h, i + 1)

-    @pl.when(i == 0)
-    def init():  # pylint: disable=unused-variable
-      m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
-      l_ref[...] = jnp.zeros_like(l_ref)
-      o_ref[...] = jnp.zeros_like(o_ref)
-
-    @pl.when(step == 0)
+    @pl.when(init_flag)
     def prefetch_first_block():  # pylint: disable=unused-variable
       async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
           b, h, i, buffer_index
       )
       async_copy_k.start()
       async_copy_v.start()

-    next_b, next_h, next_i = compute_block_indices(b, h, i + 1)
+    @pl.when(i == 0)
+    def init():  # pylint: disable=unused-variable
+      m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
+      l_ref[...] = jnp.zeros_like(l_ref)
+      o_ref[...] = jnp.zeros_like(o_ref)

     @pl.when(next_b < batch_size)
     def prefetch_next_block():  # pylint: disable=unused-variable
@@ -283,14 +283,12 @@ def prefetch_next_block():  # pylint: disable=unused-variable
         (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next
     ).astype(o_ref.dtype)

-    step_ref[0] = step + 1
-

 def paged_flash_attention_kernel_inline_seq_dim(
     lengths_ref,
     page_indices_ref,
     buffer_index_ref,
-    step_ref,
+    init_flag_ref,
     q_ref,
     k_pages_hbm_ref,
     k_scales_pages_hbm_ref,
@@ -325,7 +323,7 @@ def body(i, _):
         lengths_ref,
         page_indices_ref,
         buffer_index_ref,
-        step_ref,
+        init_flag_ref,
         q_ref,
         k_pages_hbm_ref,
         k_scales_pages_hbm_ref,
@@ -631,7 +629,7 @@ def paged_attention(
       ),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           # There are 4 scalars prefetched per kernel call: `lengths_ref`,
-          # `page_indices_ref`, `buffer_index_ref`, `step_ref`
+          # `page_indices_ref`, `buffer_index_ref`, `init_flag_ref`
           num_scalar_prefetch=4,
           in_specs=in_specs,
           out_specs=[
@@ -643,7 +641,8 @@ def paged_attention(
           scratch_shapes=scratch_shapes,
       ),
       compiler_params=pltpu.TPUCompilerParams(
-          dimension_semantics=dimension_semantics),
+          dimension_semantics=dimension_semantics
+      ),
       out_shape=[
           jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
@@ -653,7 +652,7 @@ def paged_attention(
       lengths,
       page_indices.reshape(-1),
       jnp.zeros((1,), jnp.int32),  # buffer index
-      jnp.zeros((1,), jnp.int32),  # step
+      jnp.ones((1,), jnp.int32),  # init flag
       q.astype(q_dtype_for_kernel_launch),
       k_pages,
       k_scales_pages,

From 5950e722e292063f920f5be1d23296b10ce36074 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 28 Mar 2025 09:43:05 -0700
Subject: [PATCH 0260/1769] Make sure `vma` on ShapedArray exists by default to
 make development easier.

The field is populated inside shard_map, guarded by the varying_axes_in_types
config, though.
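As a quick illustration of the invariant this commit establishes, here is a
minimal hypothetical snippet (not part of the patch; it assumes the internal
jax._src.core module is importable as below): after this change every
ShapedArray always carries a `vma` frozenset, empty unless populated under
shard_map, so downstream code no longer needs getattr fallbacks.

    # Hypothetical sketch of the new default, for illustration only.
    import numpy as np
    from jax._src import core

    aval = core.ShapedArray((4, 4), np.dtype('float32'))
    assert aval.vma == frozenset()                       # present by default
    assert aval.update(shape=(2, 8)).vma == frozenset()  # preserved by update()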
PiperOrigin-RevId: 741554623
---
 jax/_src/core.py                    | 39 +++++++++++++++-------------
 jax/_src/ffi.py                     |  4 +--
 jax/_src/lax/ann.py                 |  6 ++---
 jax/_src/lax/control_flow/solves.py |  4 +--
 jax/_src/lax/fft.py                 |  4 +--
 jax/_src/lax/lax.py                 | 17 ++++++++++---
 jax/_src/lax/utils.py               |  7 ++----
 jax/_src/lax/windowed_reductions.py |  7 ++----
 jax/_src/prng.py                    | 21 ++++++----------
 jax/experimental/shard_map.py       |  3 ++-
 10 files changed, 53 insertions(+), 59 deletions(-)

diff --git a/jax/_src/core.py b/jax/_src/core.py
index 1be60336f1a9..ee6537650f20 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1659,10 +1659,9 @@ def physical_aval(aval):
   elt_aval = physical_element_aval(aval.dtype)
   if isinstance(aval, ShapedArray):
     from jax._src.sharding_impls import physical_sharding  # type: ignore
-    vma = aval.vma if config.varying_axes_in_types.value else frozenset()
     return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype,
                        sharding=physical_sharding(aval, aval.sharding),
-                       vma=vma)
+                       vma=aval.vma)
   return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype)
 return aval
@@ -1917,6 +1916,7 @@ def get_vma(vma, mesh):
       raise ValueError(
           "Axes mentioned in `vma` field of ShapedArray should"
           f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}")
+  assert isinstance(vma, frozenset)
   return vma

 class ShapedArray(UnshapedArray):
@@ -1929,8 +1929,7 @@ def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
     self.dtype = _dtype_object(dtype)
     self.weak_type = weak_type
     self.sharding = get_sharding(sharding, self.shape)
-    if config.varying_axes_in_types.value:
-      self.vma = get_vma(vma, self.sharding.mesh)
+    self.vma = get_vma(vma, self.sharding.mesh)

   def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
     if shape is None:
@@ -1942,7 +1941,7 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
     if 'sharding' not in kwargs:
       kwargs['sharding'] = self.sharding
     if 'vma' not in kwargs:
-      kwargs['vma'] = getattr(self, 'vma', frozenset())
+      kwargs['vma'] = self.vma
     return ShapedArray(shape, dtype, weak_type, **kwargs)

   ndim = property(lambda self: len(self.shape))
@@ -1960,26 +1959,24 @@ def __eq__(self, other):
             and self.dtype == other.dtype and self.shape == other.shape
             and self.weak_type == other.weak_type
             and self.sharding == other.sharding
-            and (getattr(self, 'vma', frozenset()) ==
-                 getattr(other, 'vma', frozenset())))
+            and self.vma == other.vma)

   def __hash__(self):
     # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
     # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
     # the unique character code via hash(self.dtype.char)
     return hash((self.shape, self.dtype, self.weak_type, self.sharding,
-                 getattr(self, 'vma', frozenset())))
+                 self.vma))

   def to_tangent_aval(self):
     return ShapedArray(
         self.shape, primal_dtype_to_tangent_dtype(self.dtype),
-        self.weak_type, sharding=self.sharding,
-        vma=getattr(self, 'vma', frozenset()))
+        self.weak_type, sharding=self.sharding, vma=self.vma)

   def str_short(self, short_dtypes=False, mesh_axis_types=False):
     return str_short_aval(
         self.shape, self.dtype, self.sharding.mesh, self.sharding.spec,
-        getattr(self, 'vma', frozenset()), short_dtypes, mesh_axis_types)
+        self.vma, short_dtypes, mesh_axis_types)

   def _len(self, ignored_tracer):
     try:
@@ -2013,16 +2010,20 @@ def primal_dtype_to_tangent_dtype(primal_dtype):

 def standard_insert_pbroadcast(*args):
   if not config.varying_axes_in_types.value:
     return args
+  if not args:
+    return args
   # TODO(yashkatariya): Move pbroadcast out of shard_map
   from jax.experimental.shard_map import pbroadcast  # type: ignore
-  in_vma = [get_aval(a).vma for a in args]
+  in_vma = [frozenset() if (aval := get_aval(a)) is abstract_token
+            else aval.vma for a in args]
   out_vma = frozenset.union(*in_vma)
   return [pbroadcast(arg, tuple(n for n in out_vma if n not in src))
           if out_vma - src else arg for arg, src in zip(args, in_vma)]

-def standard_vma_rule(prim_name, *avals, **kwargs):
+def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]:
+  avals = tuple(a for a in avals if a is not abstract_token)
   if not avals:
-    return avals
+    return frozenset()
   vma, *vmas = [a.vma for a in avals]
   if not all(vma == vma_ for vma_ in vmas):
     raise ValueError(
@@ -2078,6 +2079,10 @@ def update(self, shape=None, dtype=None, weak_type=None):
   def sharding(self):
     return NamedSharding(mesh_lib.empty_abstract_mesh, P())

+  @property
+  def vma(self):
+    return frozenset()
+
   def _len(self, tracer):
     return self.shape[0]

@@ -2711,10 +2716,8 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool:
     # could try normalizing first and then doing simple equality.
     # TODO(yashkatariya): Also check `sharding` here.
    # See https://github.com/jax-ml/jax/issues/26474
-    sh_dt = t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape)
-    if config.varying_axes_in_types.value:
-      return sh_dt and t1.vma == t2.vma  # type: ignore
-    return sh_dt
+    return (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape)
+            and t1.vma == t2.vma)  # type: ignore
  else:
    return False

diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py
index c867ec16b9b3..eb3e9aaa10fb 100644
--- a/jax/_src/ffi.py
+++ b/jax/_src/ffi.py
@@ -24,7 +24,6 @@
 import jax
 from jax._src import core
-from jax._src import config
 from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import effects
@@ -639,8 +638,7 @@ def ffi_call_abstract_eval(
    has_side_effect: bool,
    **_,
 ):
-  out_vma = (core.standard_vma_rule('ffi_call', *avals_in)
-             if config.varying_axes_in_types.value else frozenset())
+  out_vma = core.standard_vma_rule('ffi_call', *avals_in)
  effects = {_FfiEffect} if has_side_effect else core.no_effects
  return tuple(r if r is core.abstract_token else r.update(vma=out_vma)
               for r in result_avals), effects

diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py
index 0d2eb338da22..c9a68d84b024 100644
--- a/jax/_src/lax/ann.py
+++ b/jax/_src/lax/ann.py
@@ -77,7 +77,6 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
 from jax._src import ad_util
 from jax._src import core
 from jax._src import dispatch
-from jax._src import config
 from jax._src import dtypes
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -240,10 +239,9 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
        "approx_top_k with aggregate_to_topk=False not yet implemented when "
        f"either the `k` ({k}) or the "
        f" reduction dimension size ({reduction_input_size}) are symbolic")
-  out_vma = operand.vma if config.varying_axes_in_types.value else frozenset()
  return (operand.update(shape=dims, dtype=operand.dtype,
-                         weak_type=operand.weak_type, vma=out_vma),
-          operand.update(shape=dims, dtype=np.dtype(np.int32), vma=out_vma))
+                         weak_type=operand.weak_type, vma=operand.vma),
+          operand.update(shape=dims, dtype=np.dtype(np.int32), vma=operand.vma))

 def _get_init_val_literal(op_type, is_max_k):
  return np.array(-np.inf if is_max_k else np.inf, dtype=op_type)

diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py
index 4a0872bef4b2..2c736f403044 100644
--- a/jax/_src/lax/control_flow/solves.py
+++ b/jax/_src/lax/control_flow/solves.py
@@ -23,7 +23,6 @@
 from jax._src import api
 from jax._src import api_util
 from jax._src import core
-from jax._src import config
 from jax._src import custom_derivatives
 from jax._src import linear_util as lu
 from jax._src.interpreters import ad
@@ -325,8 +324,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
  num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
  if num_aux > 0:
    args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
-  out_vma = (core.standard_vma_rule('linear_solve', *args_to_raise)
-             if config.varying_axes_in_types.value else frozenset())
+  out_vma = core.standard_vma_rule('linear_solve', *args_to_raise)
  return (tuple(a.update(vma=out_vma) for a in args_to_raise),
          jaxprs.solve.effects)

diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py
index 9044f48f278c..2eebe6d91f22 100644
--- a/jax/_src/lax/fft.py
+++ b/jax/_src/lax/fft.py
@@ -23,7 +23,6 @@
 from jax import lax

-from jax._src import config
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
@@ -125,8 +124,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
          f"be equal to fft_lengths {fft_lengths}")
    shape = x.shape
    dtype = x.dtype
-  out_vma = x.vma if config.varying_axes_in_types.value else frozenset()
-  return x.update(shape=shape, dtype=dtype, vma=out_vma)
+  return x.update(shape=shape, dtype=dtype, vma=x.vma)

 def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
  if not is_constant_shape(fft_lengths):

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index b79c81e19195..53bf9a0c7ebf 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -6106,6 +6106,7 @@ def _ragged_dot_general_batch_rule(
    _ragged_dot_general_shape_rule,
    _ragged_dot_general_dtype_rule,
    'ragged_dot_general',
+    vma_rule=partial(core.standard_vma_rule, 'ragged_dot')
 )
 ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule
 ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule
@@ -6515,8 +6516,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
    new_sharding = _broadcast_in_dim_sharding_rule(
        x, shape=shape, broadcast_dimensions=broadcast_dimensions,
        sharding=sharding)
-    new_vma = (core.standard_vma_rule('broadcast_in_dim', x)
-               if config.varying_axes_in_types.value else frozenset())
+    new_vma = core.standard_vma_rule('broadcast_in_dim', x)
    return core.ShapedArray(shape, x.dtype, x.weak_type,
                            sharding=new_sharding, vma=new_vma)
  # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
@@ -7435,6 +7435,11 @@ def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions):
  return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions))
          for op in operand_avals]

+def _reduce_vma_rule(*avals, computation, jaxpr, dimensions):
+  operand_avals, _ = split_list(avals, [len(avals) // 2])
+  out_vma = core.standard_vma_rule('reduce', *operand_avals)
+  return [out_vma] * len(operand_avals)
+
 def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions):
  operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
  operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals]
@@ -7522,7 +7527,7 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions):
 reduce_p.def_abstract_eval(
    partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
            _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule,
-            None))
+            _reduce_vma_rule))
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
 ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
@@ -8254,6 +8259,10 @@ def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm,
                                     out_sharding):
  return (key.sharding, out_sharding)

+def _rng_bit_generator_vma_rule(key, *, shape, dtype, algorithm, out_sharding):
+  assert key.vma == frozenset()
+  return (key.vma, frozenset())
+
 def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding):
  del shape, algorithm
  return (key.dtype, dtype)
@@ -8355,7 +8364,7 @@ def _rng_bit_generator_lowering(
    partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
            _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
            _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule,
-            None))
+            _rng_bit_generator_vma_rule))
 mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering)

diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index 0a641c122064..8e97621912f1 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -19,7 +19,6 @@
 from functools import partial

 from jax._src import core
-from jax._src import config
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
@@ -113,8 +112,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
  out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
      prim, shape_rule, dtype_rule, sharding_rule, False, *avals, **kwargs)
-  out_vma = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value
-             else frozenset())
+  out_vma = vma_rule(*avals, **kwargs)
  out_aval = core.ShapedArray(
      out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding,
      vma=out_vma)
@@ -141,8 +139,7 @@ def standard_multi_result_abstract_eval(
  core.check_avals_context_mesh(avals, prim.name)
  out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule(
      prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs)
-  out_vmas = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value
-              else [frozenset()] * len(out_shapes))
+  out_vmas = vma_rule(*avals, **kwargs)
  if isinstance(weak_types, bool):
    weak_types = (weak_types,) * len(out_shapes)
  out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, vma=vma)

diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 73fae7df40e1..472b92d858f9 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -21,7 +21,6 @@
 from jax import tree_util
 from jax._src import api_util
 from jax._src import core
-from jax._src import config
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import util
@@ -338,10 +337,8 @@ def _reduce_window_abstract_eval_rule(
  out_sharding = reduce_window_sharding_rule(
      operand_avals[0], window_dimensions, window_strides, padding,
      base_dilation, window_dilation)
-  out_vma = (core.standard_vma_rule('reduce_window', *operand_avals)
-             if config.varying_axes_in_types.value else frozenset())
-  return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding,
-                           vma=out_vma)
+  vma = core.standard_vma_rule('reduce_window', *operand_avals)
+  return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=vma)
               for op in operand_avals)

diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 17d16527bb71..926a57446f5b 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -178,8 +178,8 @@ def copy_to_host_async(self):
  def aval(self):
    logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding')
                        else None)
-    vma = (self._base_array.aval.vma if config.varying_axes_in_types.value else frozenset()
-           if hasattr(self._base_array, 'aval') else frozenset())
+    vma = (self._base_array.aval.vma if hasattr(self._base_array, 'aval')
+           else frozenset())
    return keys_shaped_array(self._impl, self.shape, logical_sharding, vma)

  @property
@@ -552,8 +552,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray:

 @random_seed_p.def_abstract_eval
 def random_seed_abstract_eval(seeds_aval, *, impl):
-  out_vma = seeds_aval.vma if config.varying_axes_in_types.value else frozenset()
-  return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, out_vma)
+  return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding,
+                           seeds_aval.vma)

 @random_seed_p.def_impl
 def random_seed_impl(seeds, *, impl):
@@ -587,9 +587,8 @@ def random_split_abstract_eval(keys_aval, *, shape):
  # TODO(yashkatariya): random_split should take sharding as an arg too so we
  # don't choose None here?
jax._src import core -from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src import mesh as mesh_lib @@ -113,8 +112,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, False, *avals, **kwargs) - out_vma = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value - else frozenset()) + out_vma = vma_rule(*avals, **kwargs) out_aval = core.ShapedArray( out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, vma=out_vma) @@ -141,8 +139,7 @@ def standard_multi_result_abstract_eval( core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) - out_vmas = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value - else [frozenset()] * len(out_shapes)) + out_vmas = vma_rule(*avals, **kwargs) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, vma=vma) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 73fae7df40e1..472b92d858f9 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -21,7 +21,6 @@ from jax import tree_util from jax._src import api_util from jax._src import core -from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src import util @@ -338,10 +337,8 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - out_vma = (core.standard_vma_rule('reduce_window', *operand_avals) - if config.varying_axes_in_types.value else frozenset()) - return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, - vma=out_vma) + vma = core.standard_vma_rule('reduce_window', *operand_avals) + return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=vma) for op in operand_avals) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 17d16527bb71..926a57446f5b 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -178,8 +178,8 @@ def copy_to_host_async(self): def aval(self): logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') else None) - vma = (self._base_array.aval.vma if config.varying_axes_in_types.value else frozenset() - if hasattr(self._base_array, 'aval') else frozenset()) + vma = (self._base_array.aval.vma if hasattr(self._base_array, 'aval') + else frozenset()) return keys_shaped_array(self._impl, self.shape, logical_sharding, vma) @property @@ -552,8 +552,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray: @random_seed_p.def_abstract_eval def random_seed_abstract_eval(seeds_aval, *, impl): - out_vma = seeds_aval.vma if config.varying_axes_in_types.value else frozenset() - return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, out_vma) + return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, + seeds_aval.vma) @random_seed_p.def_impl def random_seed_impl(seeds, *, impl): @@ -587,9 +587,8 @@ def random_split_abstract_eval(keys_aval, *, shape): # TODO(yashkatariya): random_split should take sharding as an arg too so we # don't choose None here? 
new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) - out_vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec), out_vma) + keys_aval.sharding.with_spec(new_spec), keys_aval.vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): @@ -629,8 +628,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval): 'random_fold_in', keys_aval, msgs_aval) sharding = lax_internal.broadcasting_sharding_rule( 'random_fold_in', keys_aval, msgs_aval) - vma = (core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) - if config.varying_axes_in_types.value else frozenset()) + vma = core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma) @random_fold_in_p.def_impl @@ -669,8 +667,7 @@ def random_bits(keys, bit_width, shape): def random_bits_abstract_eval(keys_aval, *, bit_width, shape): out_shape = (*keys_aval.shape, *shape) out_dtype = dtypes.dtype(f'uint{bit_width}') - vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() - return core.ShapedArray(out_shape, out_dtype, vma=vma) + return core.ShapedArray(out_shape, out_dtype, vma=keys_aval.vma) @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): @@ -727,9 +724,7 @@ def random_wrap(base_arr, *, impl): def random_wrap_abstract_eval(base_arr_aval, *, impl): shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape) sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding) - out_vma = (base_arr_aval.vma if config.varying_axes_in_types.value else - frozenset()) - return keys_shaped_array(impl, shape, sharding, out_vma) + return keys_shaped_array(impl, shape, sharding, base_arr_aval.vma) @random_wrap_p.def_impl def random_wrap_impl(base_arr, *, impl): diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c0306f0c5e91..44c2b569f947 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -578,7 +578,8 @@ def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames, for i, sz in enumerate(aval.shape)) manual_mesh = _as_manual_mesh(mesh, auto) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - vma = frozenset({n for ns in names.values() for n in ns}) + vma = (frozenset({n for ns in names.values() for n in ns}) + if config.varying_axes_in_types.value else frozenset()) return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array From e1c866cd0af657240620683cdc230e031f504998 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Fri, 28 Mar 2025 09:55:21 -0700 Subject: [PATCH 0261/1769] Fixed failing `ExcessPrecisionTest.test_matmul_f32_out_simple` test. PiperOrigin-RevId: 741558343 --- tests/pallas/tpu_fusable_matmul_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusable_matmul_test.py index df7c1221bb0c..5ee372ce92ab 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusable_matmul_test.py @@ -924,7 +924,7 @@ def matmul(impl, x, y): atol = 0 if jtu.is_device_tpu_at_least(6): # 256 MXU changes some tols. 
- atol = 1e-6 + atol = 1e-5 self.assertAllClose(out, out_ref, atol=atol) def test_matmul_f32_out_fused_downcast(self): From fde7d16c6086981e0f4bfd62e0a4a0618ded9b25 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 10:18:32 -0700 Subject: [PATCH 0262/1769] Clean up: num_groups = num_q_heads // num_kv_heads No code functionality change in this commit. PiperOrigin-RevId: 741566312 --- .../paged_attention/paged_attention_kernel.py | 35 ++++++++++--------- .../pallas/tpu_paged_attention_kernel_test.py | 25 +++++++------ 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 62f3101bef6e..4c03fb01be2b 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -257,7 +257,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable ) q = q_ref[...].astype(jnp.float32) k = async_copy_k.wait_and_get_loaded() - qk = jnp.einsum('hd,td->ht', q, k, preferred_element_type=jnp.float32) + qk = jnp.einsum("gd,td->gt", q, k, preferred_element_type=jnp.float32) if attn_logits_soft_cap is not None: capped_qk = jnp.tanh(qk / attn_logits_soft_cap) qk = capped_qk * attn_logits_soft_cap @@ -277,10 +277,10 @@ def prefetch_next_block(): # pylint: disable=unused-variable m_ref[...], l_ref[...] = m_next, l_next v = async_copy_v.wait_and_get_loaded() - o_curr_times_l_curr = jnp.dot(s_curr, v) + o_curr = jnp.einsum("gt,td->gd", s_curr, v) o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next + (l_prev * alpha * o_ref[...] + beta * o_curr) / l_next ).astype(o_ref.dtype) @@ -384,7 +384,7 @@ def paged_attention( """Paged grouped query attention. Args: - q: A [batch_size, num_heads, head_dim] jax.Array. + q: A [batch_size, num_q_heads, head_dim] jax.Array. k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. lengths: A i32[batch_size] jax.Array the length of each example. @@ -409,7 +409,7 @@ def paged_attention( one kernel. Returns: - The output of attention([batch_size, num_heads, head_dim]). + The output of attention([batch_size, num_q_heads, head_dim]). """ if isinstance(k_pages, quantization_utils.QuantizedTensor): k_pages, k_scales_pages = k_pages.weight, k_pages.scales @@ -428,7 +428,7 @@ def paged_attention( else: v_scales_pages = None - batch_size, num_heads, head_dim = q.shape + batch_size, num_q_heads, head_dim = q.shape num_kv_heads, _, page_size, head_dim_k = k_pages.shape batch_size_paged_indices, pages_per_sequence = page_indices.shape @@ -437,10 +437,10 @@ def paged_attention( f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" f" {v_pages.shape}" # pytype: disable=attribute-error ) - if num_heads % num_kv_heads != 0: + if num_q_heads % num_kv_heads != 0: raise ValueError( "Number of Q heads must be divisible by number of KV heads. Got" - f" {num_heads} and {num_kv_heads}." + f" {num_q_heads} and {num_kv_heads}." 
) if head_dim_k != head_dim: raise ValueError( @@ -477,40 +477,41 @@ def paged_attention( else: raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]") - if (num_heads // num_kv_heads) % 8 != 0: + num_groups = num_q_heads // num_kv_heads + if (num_groups) % 8 != 0: # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a # <8x128> layout for a <1x128> memref inside the kernel and error out. - q = q.reshape(batch_size, num_heads, 1, head_dim) + q = q.reshape(batch_size, num_q_heads, 1, head_dim) if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h, 0, 0), ) q_dtype_for_kernel_launch = jnp.float32 else: if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h, 0), ) q_dtype_for_kernel_launch = q.dtype @@ -659,4 +660,4 @@ def paged_attention( v_pages, v_scales_pages, ) - return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype) + return out.reshape(batch_size, num_q_heads, head_dim).astype(q.dtype) diff --git a/tests/pallas/tpu_paged_attention_kernel_test.py b/tests/pallas/tpu_paged_attention_kernel_test.py index 7fbccdb338d4..e778c72a8278 100644 --- a/tests/pallas/tpu_paged_attention_kernel_test.py +++ b/tests/pallas/tpu_paged_attention_kernel_test.py @@ -22,15 +22,12 @@ import numpy as np -jax.config.parse_flags_with_absl() - - -def _generate_qkv( +def _generate_random_qkv( seq_lens, page_size, max_seq_len, num_kv_heads, - num_heads, + num_q_heads, head_dim, prng_key, dtype=jnp.float32, @@ -55,7 +52,7 @@ def _generate_qkv( page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32) page_indices = jax.random.permutation(k3, page_indices, independent=True) page_indices = page_indices.reshape(batch_size, pages_per_sequence) - q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = jax.random.normal(k4, (batch_size, num_q_heads, head_dim), dtype=dtype) return q, k_pages, v_pages, page_indices @@ -64,7 +61,7 @@ def _reconstruct_kv(page_indices, pages): pages = quantization_utils.unquantize_from_int8(pages, dtype=jnp.float32) batch_size = page_indices.shape[0] - num_heads, _, _, head_dim = pages.shape + num_kv_heads, _, _, head_dim = pages.shape def per_sequence_page_gather(pages, page_indices): return jnp.take(pages, page_indices, 1) @@ -72,15 +69,16 @@ def per_sequence_page_gather(pages, page_indices): gathered = jax.vmap(per_sequence_page_gather, in_axes=(None, 0))( pages, page_indices ) - return gathered.reshape(batch_size, num_heads, -1, head_dim) + return 
gathered.reshape(batch_size, num_kv_heads, -1, head_dim) def _grouped_query_attention_reference(q, k, v, lengths, attn_logits_soft_cap): - batch_size, num_heads, head_dim = q.shape + batch_size, num_q_heads, head_dim = q.shape _, num_kv_heads, max_seq_len, _ = k.shape assert k.shape == v.shape - assert num_heads % num_kv_heads == 0 - q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) + assert num_q_heads % num_kv_heads == 0 + num_groups = num_q_heads // num_kv_heads + q = q.reshape(batch_size, num_kv_heads, num_groups, head_dim) if isinstance(k, quantization_utils.QuantizedTensor): k = quantization_utils.unquantize_from_int8(k, dtype=jnp.float32) @@ -97,7 +95,7 @@ def _grouped_query_attention_reference(q, k, v, lengths, attn_logits_soft_cap): logits = logits + jnp.where(mask, 0.0, mask_value)[:, None, None, :] weights = jax.nn.softmax(logits, axis=-1) o = jnp.einsum("bhgt,bhtd->bhgd", weights.astype(v.dtype), v) - return o.reshape(batch_size, num_heads, head_dim) + return o.reshape(batch_size, num_q_heads, head_dim) def _megacore_enabled(): @@ -149,7 +147,7 @@ def test_paged_attention( max_kv_len = 2048 block_size = 512 seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048]) - q, k_pages, v_pages, page_indices = _generate_qkv( + q, k_pages, v_pages, page_indices = _generate_random_qkv( seq_lens, page_size, max_kv_len, @@ -188,4 +186,5 @@ def test_paged_attention( if __name__ == "__main__": + jax.config.config_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) From 829deb68f62a9c5e51fac12f9f824d21a8f379be Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 28 Mar 2025 17:19:34 +0000 Subject: [PATCH 0263/1769] Set NB_DOMAIN=jax This is a precautionary measure to prevent conflicts with other packages using nanobind and registering the same types. We don't want JAX's nanobind registrations to conflict on, say, XLA types with other projects. --- .bazelrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.bazelrc b/.bazelrc index 2d38dcc87044..422363644578 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,6 +31,7 @@ build -c opt build --output_filter=DONT_MATCH_ANYTHING build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. +build --copt=-DNB_DOMAIN=jax # ############################################################################# # Platform Specific configs below. These are automatically picked up by Bazel From ecd9f5ded81eede59986d90c10b52ca852b4325e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 28 Mar 2025 10:27:01 -0700 Subject: [PATCH 0264/1769] Move aval_to_xla_shape into callback.py, which is its only user. Specialize it to one shape per aval, since that's the only case that exists. Remove some pointless assertions using this code. 
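For illustration, a hedged sketch of the new helper's contract
(_aval_to_xla_shape is the private function this change adds to
jax/_src/callback.py; the example aval is made up):

    import numpy as np
    from jax._src import core
    from jax._src.callback import _aval_to_xla_shape  # private helper

    aval = core.ShapedArray((4, 8), np.dtype(np.float32))
    shape = _aval_to_xla_shape(aval)  # one xc.Shape, not a length-1 sequence

Call sites that previously wrote xla.aval_to_xla_shapes(aval)[0], sometimes
asserting that the sequence had length one, can now drop that indirection.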
PiperOrigin-RevId: 741569024 --- jax/_src/api.py | 2 -- jax/_src/callback.py | 32 ++++++++++++++++++++++++-------- jax/_src/interpreters/xla.py | 22 +--------------------- jax/_src/prng.py | 1 - 4 files changed, 25 insertions(+), 32 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index e01bdd4a9d81..55e2b2126a68 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -82,7 +82,6 @@ from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax._src.interpreters import xla traceback_util.register_exclusion(__file__) @@ -2591,7 +2590,6 @@ def _device_put_replicated(x): sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices) - assert len(xla.aval_to_xla_shapes(aval)) == 1 return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) with config.explicit_device_put_scope(): diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 683da66638e6..20334b6cd269 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -670,6 +670,25 @@ def receive_from_host( return token, result + +def _aval_to_xla_shape(aval: core.AbstractValue) -> xc.Shape: + try: + return _xla_shape_handlers[type(aval)](aval) + except KeyError as err: + raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err + +_xla_shape_handlers: dict[type[core.AbstractValue], + Callable[[Any], xc.Shape]] = {} + +def _make_array_shape(aval: core.ShapedArray) -> xc.Shape: + aval = core.physical_aval(aval) + dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype + return xc.Shape.array_shape(dtype, aval.shape) +_xla_shape_handlers[core.ShapedArray] = _make_array_shape + +_xla_shape_handlers[core.AbstractToken] = lambda _: xc.Shape.token_shape() + + def _emit_tpu_python_callback( backend: xb.XlaBackend, ctx: mlir.LoweringRuleContext, @@ -699,8 +718,7 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined send_channel = ctx.module_context.new_channel() dummy_send_aval = core.ShapedArray((1,), np.float32) dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32)) - operand_shapes = [*operand_shapes, - xla.aval_to_xla_shapes(dummy_send_aval)[0]] + operand_shapes = [*operand_shapes, _aval_to_xla_shape(dummy_send_aval)] token = send_to_host(send_channel, token, dummy_send_val, callback.__name__, sharding=sharding) send_channels.append(send_channel) @@ -763,10 +781,8 @@ def emit_python_callback( raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") backend = ctx.module_context.get_backend() - result_shapes = util.flatten( - [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) - operand_shapes = util.flatten( - [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) + result_shapes = [_aval_to_xla_shape(aval) for aval in result_avals] + operand_shapes = [_aval_to_xla_shape(aval) for aval in operand_avals] # Handling layouts if operand_layouts is None: operand_layouts = util.concatenate( @@ -836,10 +852,10 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function return (token, *callback_without_token(*args)) operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes + _aval_to_xla_shape(core.abstract_token), *operand_shapes ] result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes + _aval_to_xla_shape(core.abstract_token), 
*result_shapes ] operands = [token, *operands] result_types = [mlir.token_type(), *result_types] diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 33a8992a8be4..7fbb22923e0f 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable from functools import partial from typing import Any, Union @@ -25,7 +25,6 @@ from jax._src import core from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ShapedArray from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -41,11 +40,6 @@ def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() -def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]: - aval = core.physical_aval(aval) - dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype - return (xc.Shape.array_shape(dtype, aval.shape),) - # Utilities # HLO instructions optionally can be annotated to say how the output should be @@ -90,20 +84,6 @@ def tuple_sharding_proto(elems): ### handlers -# JAX abstract values -> XLA shapes - -def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: - try: - return _xla_shape_handlers[type(aval)](aval) - except KeyError as err: - raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err - -_xla_shape_handlers: dict[type[core.AbstractValue], - Callable[[Any], Sequence[xc.Shape]]] = { - ShapedArray: _make_array_shape, -} -_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) - # IR constants diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 926a57446f5b..0106aa310383 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -425,7 +425,6 @@ def device_put_sharded(vals, aval, sharding, devices): @staticmethod def device_put_replicated(val, aval, sharding, devices): physical_aval = core.physical_aval(aval) - assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) phys_sharding = physical_sharding(aval, sharding) physical_result = pxla.batched_device_put( From d4c42d7199f39a0b4639a32350abf3e8fb8a6043 Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Fri, 28 Mar 2025 10:54:48 -0700 Subject: [PATCH 0265/1769] implement nbytes for PRNGKeyArray --- jax/_src/prng.py | 4 ++++ tests/random_test.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2fa9b2b37aa4..ad96d9409083 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -188,6 +188,10 @@ def ndim(self): def dtype(self): return KeyTy(self._impl) + @property + def nbytes(self): + return self.itemsize * self.size + @property def itemsize(self): return self.dtype.itemsize diff --git a/tests/random_test.py b/tests/random_test.py index a51e387dca76..22df8b0b0649 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -657,6 +657,11 @@ def test_non_integer_seed(self): with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): random.key(seed) + def test_nbytes_property(self): + key = self.make_keys() + self.assertEqual(key.nbytes, key._base_array.nbytes) + self.assertEqual(key.nbytes, key.itemsize * key.size) + def test_dtype_property(self): k1, k2 = self.make_keys(), self.make_keys() self.assertEqual(k1.dtype, k2.dtype) From fbff338a8ef92f99b896ee8d1f0ac65d830edcfa Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 28 Mar 2025 11:00:30 -0700 Subject: 
[PATCH 0266/1769] [pallas:mosaic_gpu] `GPUMesh` now accepts axis names in a
 more structured way

This is hopefully less confusing than bunching them together in a single
argument.

PiperOrigin-RevId: 741580827
---
 jax/_src/pallas/mosaic_gpu/core.py     | 36 +++++++++++--------
 jax/_src/pallas/mosaic_gpu/lowering.py | 23 +++---------
 .../pallas/ops/gpu/attention_mgpu.py   |  6 ++--
 tests/pallas/mosaic_gpu_test.py        | 34 ++++++++++--------
 4 files changed, 50 insertions(+), 49 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 857daaefe38f..f8c1ebf442b0 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -575,22 +575,29 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class GPUMesh:
-  grid: tuple[int, ...] = ()
-  cluster: tuple[int, ...] = ()
+  grid: Sequence[int] = ()
+  grid_names: Sequence[str] = ()
+  cluster: Sequence[int] = ()
+  cluster_names: Sequence[str] = ()
   # Those are NOT CUDA threads. On Hopper they correspond to warpgroups.
   num_threads: int | None = None
-  axis_names: tuple[str, ...] = ()
+  thread_name: str | None = None

   def __post_init__(self):
     if len(self.cluster) > 3:
       raise ValueError(f"cluster= must be at most 3D, got {self}.")
-    num_axis_names = (
-        len(self.grid) + len(self.cluster) + (self.num_threads is not None)
-    )
-    if len(self.axis_names) != num_axis_names:
+    if len(self.grid_names) != len(self.grid):
+      raise ValueError(
+          f"grid_names must have the same length as grid, got {self}."
+      )
+    if len(self.cluster_names) != len(self.cluster):
       raise ValueError(
-          "Need an axis name for each grid and cluster dimension plus "
-          f" an additional axis name when num_threads= is given, got {self}."
+ ) + if (self.thread_name is None) != (self.num_threads is None): + raise ValueError( + "num_threads and thread_name must be either both set or both None," + f" got {self}" ) if self.num_threads is not None and self.num_threads > 2048 // 128: raise ValueError( @@ -607,14 +614,13 @@ def shape(self) -> collections.OrderedDict[object, int]: pairs: Iterable[tuple[object, int]] if self.num_threads is not None: pairs = zip( - self.axis_names, (*self.grid, *self.cluster, self.num_threads) + (*self.grid_names, *self.cluster_names, self.thread_name), + (*self.grid, *self.cluster, self.num_threads), ) else: - pairs = tuple( - zip( - (*self.axis_names, _WARPGROUP_AXIS_NAME), - (*self.grid, *self.cluster, 1), - ) + pairs = zip( + (*self.grid_names, *self.cluster_names), + (*self.grid, *self.cluster), ) return collections.OrderedDict(pairs) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index baac1e6eb316..42914c95085a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -274,17 +274,6 @@ def __iter__(self) -> Iterable[Hashable]: self.grid, self.cluster, [self.wg] if self.wg is not None else [] ) - @classmethod - def from_mesh( - cls, mesh: gpu_core.GPUMesh, axis_names: Sequence[str] - ) -> "_AxisNames": - wg_name = None - if mesh.num_threads is not None: - wg_name = axis_names[-1] - axis_names = axis_names[:-1] - grid_names, cluster_names = util.split_list(axis_names, [len(mesh.grid)]) - return cls(grid_names, cluster_names, wg_name) - @dataclasses.dataclass class ModuleContext: @@ -552,12 +541,10 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if mesh is not None: + if mesh: assert isinstance(mesh, gpu_core.GPUMesh) - if mesh and mesh.num_threads is not None: - # Last dim corresponds to the warpgroup count. 
- block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid_mapping.grid[:-1] + block = (128 * (mesh.num_threads or 1), 1, 1) + grid = mesh.grid else: block = (128, 1, 1) grid = grid_mapping.grid @@ -665,9 +652,9 @@ def body_fn(*refs): assert not new_consts axis_names = ( - _AxisNames.from_mesh(mesh, grid_mapping.grid_names) + _AxisNames(mesh.grid_names, mesh.cluster_names, mesh.thread_name) if mesh is not None - else _AxisNames(grid_mapping.grid_names) + else _AxisNames(grid_mapping.grid_names or ()) ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 8883878f5f0e..534da419ed3b 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -227,8 +227,9 @@ def entry(q_ref, k_ref, v_ref, out_ref): entry, out_shape=q, grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), + thread_name="wg", compiler_params=plgpu.GPUCompilerParams(approx_math=True), )(q, k, v) @@ -366,8 +367,9 @@ def compute_pv(acc_ref): pipeline(k_ref, v_ref) mesh = plgpu.GPUMesh( grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), + thread_name="wg", ) def run(refs): q_ref, k_ref, v_ref, out_ref = refs diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 874ecae93f3f..b6c0652c13fe 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1707,7 +1707,7 @@ def test_tmem_alloc(self): plgpu.SMEM((128, 128), jnp.float32), ], num_threads=1, - axis_names=("x",), + thread_name="x", ) def kernel(y_ref, tmem_ref, smem_ref): # Issue a write so the TMEM load is not DCE'd. 
@@ -2096,8 +2096,9 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): ), compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), + grid_names=("_",), num_threads=3, - axis_names=("_", "wg"), + thread_name="wg", ) out, out_last_block = kernel(x) np.testing.assert_array_equal(out, x) @@ -2130,8 +2131,9 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), + grid_names=("_",), num_threads=num_compute_wgs + 1, - axis_names=("_", "wg"), + thread_name="wg", ) x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) @@ -2148,8 +2150,9 @@ def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): ], compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), + grid_names=("_",), num_threads=num_compute_wgs + 1, - axis_names=("_", "wg"), + thread_name="wg", ) def kernel(x_gmem, acc_gmem, acc_smem): def _compute_thread(): @@ -2204,7 +2207,7 @@ def test_multiple_wg(self): plgpu.kernel, out_shape=jnp.zeros((2, 128), np.int32), num_threads=2, - axis_names=("wg",), + thread_name="wg", ) def kernel(o_ref): wg_idx = jax.lax.axis_index("wg") @@ -2219,8 +2222,9 @@ def test_multiple_wg_with_grid(self): plgpu.kernel, out_shape=jnp.zeros((4, 2, 128), np.int32), grid=(2, 2), + grid_names=("x", "y"), num_threads=2, - axis_names=("x", "y", "wg"), + thread_name="wg", ) def kernel(o_ref): xy_idx = jax.lax.axis_index(("x", "y")) @@ -2250,8 +2254,9 @@ def test_multiple_wg_with_squashed_grid(self): (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 ), grid=(b, x_dim, y_dim, z_dim), + grid_names=("b", "x", "y", "z"), num_threads=num_threads, - axis_names=("b", "x", "y", "z", "wg"), + thread_name="wg", ) def kernel(o_ref): b_idx = jax.lax.axis_index("b") @@ -2277,7 +2282,7 @@ def test_cross_wg_barrier(self): # Each warpgroup is a single logical thread! 
scratch_shapes=[plgpu.Barrier(num_arrivals=2)], num_threads=2, - axis_names=("wg",), + thread_name="wg", ) def kernel(o_ref, barrier): plgpu.barrier_arrive(barrier) @@ -2294,8 +2299,9 @@ def test_cluster(self): plgpu.kernel, out_shape=jnp.zeros(128, np.int32), grid=(2,), + grid_names=("x",), cluster=(2,), - axis_names=("x", "cluster"), + cluster_names=("cluster",), ) def kernel(ref): block_idx = jax.lax.axis_index("x") @@ -2336,7 +2342,7 @@ def body(l_ref, r_ref, o_ref): o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice] x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) # Async copies @@ -2351,7 +2357,7 @@ def test_stage3(self): plgpu.Barrier(num_arrivals=2), ], grid=(2,), - axis_names=("rows",), + grid_names=("rows",), ) def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) @@ -2382,7 +2388,7 @@ def compute(l_smem, r_smem, o_smem): )(l_ref, r_ref, o_ref) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) # Transforms @@ -2404,7 +2410,7 @@ def compute(l_smem, r_smem, o_smem): )(l_ref, r_ref, o_ref) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) def test_semaphore_lowering(self): @@ -2456,7 +2462,7 @@ def do_wgmma(acc_ref): )(l_ref, r_ref, o_ref) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2, 2), axis_names=("m", "n"))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2, 2), grid_names=("m", "n"))(x, x) np.testing.assert_allclose(out, x @ x) # TODO(apaszke): Clusters and multicast From e838fe19d3b2a7b41c5ba0a8b7d98d7b9ea9e477 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 28 Mar 2025 13:00:24 -0700 Subject: [PATCH 0267/1769] [pallas:mosaic_gpu] Added support for collective GMEM->SMEM copies to lane-level lowering More work is needed to support these in the WG lowering. 
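A minimal usage sketch, mirroring the new test below (the surrounding
plgpu.kernel setup, which declares cluster_names=("x", "y", "z"), is
elided):

    from jax.experimental.pallas import mosaic_gpu as plgpu

    def fetch(a_gmem, a_smem, barrier):
      plgpu.copy_gmem_to_smem(
          a_gmem,                      # GMEM source ref
          a_smem,                      # SMEM destination ref
          barrier,                     # barrier signaled when the copy lands
          collective_axes=("x", "z"),  # cluster axes to copy collectively over
      )

The named axes must be cluster axes declared on the GPUMesh; the lane-level
lowering resolves them to mesh dimensions via _resolve_cluster_axis.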
PiperOrigin-RevId: 741622096 --- jax/_src/pallas/core.py | 20 ++-- jax/_src/pallas/mosaic_gpu/core.py | 24 +++++ jax/_src/pallas/mosaic_gpu/lowering.py | 80 ++++++++++++--- jax/_src/pallas/mosaic_gpu/primitives.py | 19 +++- jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 121 +++++++++++++++++++++++ 6 files changed, 243 insertions(+), 22 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 8602205eef22..a74206c46ce7 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1089,7 +1089,10 @@ def wrapped(f): debug_info=api_util.debug_info("pallas_core_map", f, (), {})), in_tree) - with jax_core.extend_axis_env_nd(mesh.shape.items()): + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, compiler_params=compiler_params, @@ -1144,6 +1147,7 @@ def default_mesh_discharge_rule( interpret, cost_estimate, name, + memory_space=MemorySpace.ANY, ): """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" del out_avals # Unused. @@ -1160,13 +1164,9 @@ def body(*args): for eff in jaxpr.effects if isinstance(eff, state_types.WriteEffect) ) - any_spec = BlockSpec(memory_space=MemorySpace.ANY) - grid_spec = GridSpec( - grid=tuple(mesh.shape.items()), - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(modified_idxs), - ) + spec = BlockSpec(memory_space=memory_space) from jax._src.pallas import pallas_call # Avoid circular dependency. + outs = pallas_call._pallas_call( body, name=name, @@ -1174,7 +1174,11 @@ def body(*args): input_output_aliases={ in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) }, - grid_spec=grid_spec, + grid_spec=GridSpec( + grid=tuple(mesh.shape.items()), + in_specs=[spec] * len(in_avals), + out_specs=[spec] * len(modified_idxs), + ), mesh=mesh, compiler_params=compiler_params, interpret=interpret, diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index f8c1ebf442b0..8522bdf651f4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -502,6 +502,17 @@ def __str__(self): return self.name +@dataclasses.dataclass(frozen=True) +class ClusterBarrierType(dtypes.ExtendedDType): + type: ClassVar[Any] = barrier_dtype + name: ClassVar[str] = "cluster_barrier" + + collective_axes: tuple[str | tuple[str, ...], ...] + + def __str__(self): + return self.name + + @dataclasses.dataclass(frozen=True) class Barrier: num_arrivals: int @@ -514,6 +525,18 @@ def get_ref_aval(self) -> AbstractMemoryRef: return AbstractMemoryRef(aval, SMEM) +@dataclasses.dataclass(frozen=True) +class ClusterBarrier: + collective_axes: tuple[str | tuple[str, ...], ...] 
+ num_barriers: int = 1 + + def get_ref_aval(self) -> AbstractMemoryRef: + aval = jax_core.ShapedArray( + [self.num_barriers], ClusterBarrierType(self.collective_axes) + ) + return AbstractMemoryRef(aval, SMEM) + + @dataclasses.dataclass(frozen=True) class WGMMAAccumulatorRef: shape: tuple[int, int] @@ -660,6 +683,7 @@ def _gpu_mesh_discharge_rule( interpret=interpret, cost_estimate=cost_estimate, name=name, + memory_space=GMEM, ) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 42914c95085a..e99feb4dc144 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -85,8 +85,9 @@ def _align_to(x: int, alignment: int): return x -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: + axis_names: _AxisNames thread_semantics: mgpu.ThreadSemantics @property @@ -98,11 +99,14 @@ def arrival_multiplier(self) -> int: ) +AnyBarrier = mgpu.Barrier | mgpu.ClusterBarrier + + @dataclasses.dataclass(kw_only=True, frozen=True) class Resources: smem_scratch_bytes: int = 0 tmem_scratch_cols: int = 0 - barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( + barrier_counts: collections.Counter[AnyBarrier] = dataclasses.field( default_factory=collections.Counter ) @@ -120,7 +124,7 @@ def __post_init__(self): ) @property - def barriers(self) -> Sequence[mgpu.Barrier]: + def barriers(self) -> Sequence[AnyBarrier]: return list(self.barrier_counts.elements()) def __add__(self, other: Resources) -> Resources: @@ -230,6 +234,16 @@ def _run_scoped_resource_estimator( ) ]) ) + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + rs += Resources( + barrier_counts=collections.Counter( + [mgpu.ClusterBarrier(collective_dims, *aval.shape)] + ) + ) elif aval.memory_space == gpu_core.TMEM: if aval.dtype.itemsize != 4: raise ValueError("TMEM only supports 32-bit types.") @@ -275,6 +289,9 @@ def __iter__(self) -> Iterable[Hashable]: ) +AnyBarrierRef = mgpu.BarrierRef | mgpu.CollectiveBarrierRef + + @dataclasses.dataclass class ModuleContext: name: str @@ -287,9 +304,7 @@ class ModuleContext: tmem_requested_cols: int tmem_used_cols: int tmem_base_ptr: ir.Value - runtime_barriers: MutableMapping[ - mgpu.Barrier, MutableSequence[mgpu.BarrierRef] - ] + runtime_barriers: MutableMapping[AnyBarrier, MutableSequence[AnyBarrierRef]] name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches squashed_dims: tuple[int, ...] 
@@ -399,7 +414,10 @@ class LoweringRuleContext: @property def estimator_ctx(self) -> ResourceEstimatorContext: - return ResourceEstimatorContext(thread_semantics=self.module_ctx.thread_semantics) + return ResourceEstimatorContext( + axis_names=self.module_ctx.axis_names, + thread_semantics=self.module_ctx.thread_semantics, + ) @dataclasses.dataclass(frozen=True) @@ -746,7 +764,12 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): module_ctx, launch_ctx, jaxpr, buffers_gmem, consts ) - rs = _estimate_resources(ResourceEstimatorContext(thread_semantics), jaxpr) + rs = _estimate_resources( + ResourceEstimatorContext( + axis_names=axis_names, thread_semantics=thread_semantics + ), + jaxpr, + ) smem_scratch_bytes = params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = rs.smem_scratch_bytes @@ -1784,23 +1807,43 @@ def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: return arith_dialect.divui(result, _as_index(cluster_size[dim.value])) +def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.GPUMesh`." + ) + if not axis_names or axis_name not in axis_names.cluster: + raise LookupError( + f"Unknown cluster axis {axis_name}, available axes:" + f" {[*axis_names.cluster]}" + ) + return gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + + @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): axis_names = ctx.module_ctx.axis_names - if not axis_names or axis_name not in axis_names: - raise ValueError( - "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.GPUMesh`." 
+ ) + if axis_name not in axis_names: + raise LookupError( + f"Unknown axis {axis_name}, available axes: {[*axis_names]}" ) if axis_names.wg is not None and axis_name == axis_names.wg: return mgpu.warpgroup_idx(sync=True) if axis_name in axis_names.cluster: - idx = axis_names.cluster.index(axis_name) return arith_dialect.index_cast( ir.IntegerType.get_signless(32), - gpu_dialect.cluster_block_id(gpu_dialect.Dimension(idx)), + gpu_dialect.cluster_block_id( + gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + ), ) squashed_dims = ctx.module_ctx.squashed_dims @@ -1913,6 +1956,17 @@ def _run_scoped_lowering_rule( ) ) should_discharge.append(False) + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.module_ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + input_refs.append( + ctx.module_ctx.reserve_barrier( + mgpu.ClusterBarrier(collective_dims, *aval.shape) + ) + ) + should_discharge.append(False) elif aval.memory_space == gpu_core.SMEM: [input_ref] = alloc_stack.enter_context( ctx.module_ctx.scratch_view( diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 9dc65c1bef88..a9bd91b26622 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -361,6 +361,7 @@ def _copy_gmem_to_smem_lowering( src_transforms_treedef, dst_transforms_treedef, barrier_transforms_treedef, + collective_axes, ): flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( util.split_list( @@ -382,6 +383,12 @@ def _copy_gmem_to_smem_lowering( barrier = barrier.__getitem__( *map(lowering._as_index, barrier_indexer.indices) ) + collective = None + if collective_axes is not None: + collective = tuple( + lowering._resolve_cluster_axis(ctx.module_ctx.axis_names, axis) + for axis in collective_axes + ) dst_ty = ir.MemRefType(dst.type) bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: @@ -400,6 +407,7 @@ def _copy_gmem_to_smem_lowering( barrier=barrier, arrive=False, predicate=ctx.module_ctx.single_wg_lane_predicate, + collective=collective, **copy_params, ) return () @@ -425,7 +433,13 @@ def _copy_gmem_to_smem_lowering( return () -def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: +def copy_gmem_to_smem( + src: _Ref, + dst: _Ref, + barrier: _Ref, + *, + collective_axes: str | tuple[str, ...] | None = None, +) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. 
See also: @@ -450,6 +464,8 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten( barrier_transforms ) + if isinstance(collective_axes, str): + collective_axes = (collective_axes,) copy_gmem_to_smem_p.bind( src, dst, @@ -460,6 +476,7 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, barrier_transforms_treedef=barrier_transforms_treedef, + collective_axes=collective_axes, ) return None diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index b791fbb8b573..e4c5ffe04093 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -18,6 +18,7 @@ """ from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier +from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b6c0652c13fe..c8013f634c67 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2322,6 +2322,127 @@ def kernel(ref): }, ) + def test_realistic_matmul_with_cluster(self): + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 32 + # TODO(slebedev): Remove ``grid_tile_n`` to simplify the test. + grid_tile_n = 4 + assert grid_n % grid_tile_n == 0 + cluster_m = 2 + cluster_n = 2 + cluster_tile_n = min(cluster_n, grid_tile_n) + tile_m = tile_n = 128 + assert tile_m % elems_128b == 0 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + + transforms = ( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ) + + max_concurrent_steps = 2 + delay_release = 1 + + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + scratch_shapes=[ + plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), + dtype, + transforms=transforms, + ), + plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), + dtype, + transforms=transforms, + ), + plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + plgpu.ACC((tile_m, tile_n), jnp.float32), + plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + plgpu.ClusterBarrier( + collective_axes=(("x", "z"), "y"), + num_barriers=max_concurrent_steps, + ), + ], + grid=(grid_tile_n, grid_m, grid_n // grid_tile_n), + grid_names=("tile_n", "m", "n"), + cluster=(cluster_tile_n, cluster_m, cluster_n // cluster_tile_n), + cluster_names=("x", "y", "z"), + ) + def kernel( + a_gmem, + b_gmem, + o_gmem, + a_smem, + b_smem, + o_smem, + acc, + barrier, + cluster_barrier, + ): + m_slice = pl.ds(lax.axis_index("m") * tile_m, tile_m) + n_slice = pl.ds( + (lax.axis_index("tile_n") + lax.axis_index("n") * grid_tile_n) + * tile_n, + tile_n, + ) + + def fetch(step, slot): + if not isinstance(slot, int): # Skip in initialization. 
+ plgpu.barrier_arrive(cluster_barrier.at[slot]) + plgpu.barrier_wait(cluster_barrier.at[slot]) + + k_slice = pl.ds(step * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], + a_smem.at[slot], + barrier.at[slot], + collective_axes=("x", "z"), + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], + b_smem.at[slot], + barrier.at[slot], + collective_axes="y", + ) + + # Initialize the pipeline. + for slot in range(min(max_concurrent_steps, grid_k)): + fetch(slot, slot) + + def body(step, _): + slot = step % max_concurrent_steps + plgpu.barrier_wait(barrier.at[slot]) + + plgpu.wgmma(acc, a_smem.at[slot], b_smem.at[slot]) + plgpu.wgmma_wait(delay_release) + + fetch_step = step + (max_concurrent_steps - delay_release) + fetch_slot = lax.rem(fetch_step, max_concurrent_steps) + jax.lax.cond( + lax.bitwise_and(step >= delay_release, fetch_step < grid_k), + lambda: fetch(fetch_step, fetch_slot), + lambda: None, + ) + return () + + jax.lax.fori_loop(0, grid_k, body, ()) + + # Finalize the pipeline. + o_smem[...] = acc[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + np.testing.assert_array_equal(kernel(a, b), a @ b) + class ExamplesTest(PallasTest): From b3a2c5341db9ad04c464b878ec4e59ffe9498918 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 14:14:07 -0700 Subject: [PATCH 0268/1769] [NFC] Fix linter errors in pipeline file PiperOrigin-RevId: 741644574 --- jax/_src/pallas/mosaic/pipeline.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 184b1497adf9..9b0a9322c94d 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -213,8 +213,8 @@ class BufferedRef: spec: pl.BlockSpec # static metadata dtype: Any # static metadata buffer_type: BufferType # static metadata - window_ref: REF | None - accum_ref: REF | None + window_ref: ArrayRef | None + accum_ref: ArrayRef | None current_slot: ArrayRef | None # TODO(ramiroleal): Unused by class. Remove argument from # BufferedRef instantiations. 
@@ -337,6 +337,7 @@ def memory_space(self): def current_ref(self): buffer_slice = tuple( 0 if x is None else slice(None) for x in self.block_shape) + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) if self.memory_space == VMEM: return self.window_ref.at[buffer_slice] else: @@ -368,10 +369,12 @@ def is_input_output(self): @property def current_slot_index(self): + """Index in double buffer corresponding to the current slot.""" return self.current_slot[0] @property def next_slot_index(self): + """Index in double buffer corresponding to the next slot.""" return lax.rem(self.current_slot_index + 1, 2) def bind_existing_ref(self, window_ref, indices): @@ -463,6 +466,8 @@ def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None if self.swap is not None: self.swap[0] = True next_slot = self.next_slot_index @@ -470,7 +475,7 @@ def copy_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( src_ref.at[src_slice], - self.window_ref.at[next_slot].at[dst_slice], + self.window_ref.at[(next_slot, *dst_slice)], self.sem_recvs.at[next_slot], ).start() @@ -478,13 +483,15 @@ def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None if self.swap is not None: self.swap[0] = True slot = self.current_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.window_ref.at[slot].at[src_slice], + self.window_ref.at[(slot, *src_slice)], dst_ref.at[dst_slice], self.sem_sends.at[slot], ).start() @@ -493,13 +500,15 @@ def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) current_slot = self.current_slot_index tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter - self.window_ref.at[current_slot].at[ - dst_slice + self.window_ref.at[ + (current_slot, *dst_slice) ], # only dst shape is important self.sem_recvs.at[current_slot], ).wait() @@ -508,12 +517,14 @@ def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None # In a double buffer, previous slot is the same as next slot. 
     prev_slot = self.next_slot_index
     dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
     src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
     tpu_primitives.make_async_copy(
-        self.window_ref.at[prev_slot].at[src_slice],  # nb: doesn't matter
+        self.window_ref.at[(prev_slot, *src_slice)],  # nb: doesn't matter
         dst_ref.at[dst_slice],  # only dst shape is important
         self.sem_sends.at[prev_slot],
     ).wait()
@@ -533,16 +544,18 @@ def set_accumulator(self, init=False):
     """Set accumulator or zero it out to initialize."""
     assert self.is_accumulator
     if self.accum_ref is not None:
+      accum_dtype = self.accum_ref.dtype
       def _init():
         self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...])
       def _set():
-        self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype)
+        self.accum_ref[...] = self.current_ref[...].astype(accum_dtype)
       lax.cond(init, _init, _set)

   def accumulate(self):
     """Add into the current slot."""
     assert self.is_accumulator
     if self.accum_ref is not None:
+      assert self.window_ref is not None
       accum_dtype = jnp.float32
       if self.window_ref.dtype == jnp.int32:
         accum_dtype = jnp.int32

From 91dac631fb79297a947d4742fb79e5898ece31c5 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Fri, 28 Mar 2025 14:15:25 -0700
Subject: [PATCH 0269/1769] scan: improve docs & errors around dynamic length

---
 jax/_src/lax/control_flow/loops.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index c7bcb1cf6b09..0362c139570a 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -178,6 +178,11 @@ def scan(f, init, xs, length=None):
   :py:func:`scan` compiles ``f``, so while it can be combined with
   :py:func:`jit`, it's usually unnecessary.

+  .. note::
+    :func:`scan` is designed for iterating with a static number of iterations.
+    For iteration with a dynamic number of iterations, use :func:`fori_loop`
+    or :func:`while_loop`.
+
   Args:
     f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
       that ``f`` accepts two arguments where the first is a value of the loop
@@ -239,7 +244,9 @@ def scan(f, init, xs, length=None):
   try:
     length = int(length)
   except core.ConcretizationTypeError as err:
-    msg = 'The `length` argument to `scan` expects a concrete `int` value.'
+    msg = ('The `length` argument to `scan` expects a concrete `int` value.'
+           ' For scan-like iteration with a dynamic length, use `while_loop`'
+           ' or `fori_loop`.')
     raise core.ConcretizationTypeError(length, msg) from None  # type: ignore[arg-type]
   if not all(length == l for l in lengths):
     msg = ("scan got `length` argument of {} which disagrees with "

From b719ac00c63ebb74766e7be3b142c046213a18ec Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 28 Mar 2025 15:12:41 -0700
Subject: [PATCH 0270/1769] Use f32 scratch for output so we only need to
 transfer the output, in the desired dtype, back to HBM.

We use f32 as the accumulation dtype inside the kernel and convert to the
desired dtype (e.g. bf16) only when writing the result from VMEM back to
HBM, which saves memory bandwidth.

Also made a minor change: the sliding-window and logit soft-cap checks now
live in the function that validates the static inputs.
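In sketch form, the new output path is (names as in the kernel change below):

    # acc_ref is a float32 VMEM scratch; o_ref carries the caller's dtype.
    # All flash-attention updates accumulate into acc_ref in f32, and the
    # single downcast happens once, when the block is written out:
    o_ref[...] = acc_ref[...].astype(q_ref.dtype)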
PiperOrigin-RevId: 741660728 --- .../pallas/ops/tpu/ragged_paged_attention.py | 54 +++++++++++-------- .../pallas/tpu_ragged_paged_attention_test.py | 8 +-- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 255670c22e90..e1eacee550a7 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -83,8 +83,8 @@ def ref_ragged_paged_attention( soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, ): - check_inputs_shapes( - queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs + validate_static_inputs( + queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap ) if mask_value is None: mask_value = DEFAULT_MASK_VALUE @@ -130,7 +130,7 @@ def ref_ragged_paged_attention( # Expect to run these checkes during runtime. -def validate_inputs_on_runtime( +def validate_dynamic_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] @@ -140,7 +140,7 @@ def validate_inputs_on_runtime( sliding_window: int | None = None, soft_cap: float | None = None, ): - check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) + validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap) max_num_batched_tokens = q.shape[0] page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape @@ -165,20 +165,18 @@ def validate_inputs_on_runtime( raise ValueError( f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." ) - if sliding_window is not None and sliding_window <= 0: - raise ValueError(f"{sliding_window=} must be positive.") - if soft_cap is not None and soft_cap == 0.0: - raise ValueError(f"{soft_cap=} must not be 0.0.") # Expect to run these checks during compile time. 
-def check_inputs_shapes(
+def validate_static_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     num_seqs,  # i32[1]
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ):
   _, num_q_heads, head_dim = q.shape
   _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
@@ -213,6 +211,10 @@ def validate_static_inputs(
     )
   if num_q_heads % num_kv_heads != 0:
     raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
+  if sliding_window is not None and sliding_window <= 0:
+    raise ValueError(f"{sliding_window=} must be positive.")
+  if soft_cap is not None and soft_cap == 0.0:
+    raise ValueError(f"{soft_cap=} must not be 0.0.")

 def ragged_paged_attention_kernel(
@@ -233,6 +235,7 @@ def ragged_paged_attention_kernel(
     sems,  # [2, 2]
     l_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
     m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
+    acc_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
     *,
     sm_scale: float,
     sliding_window: int | None = None,
@@ -357,7 +360,7 @@ def flash_attention(
       v,  # [num_kv_per_blk, head_dim]
       head_l_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
       head_m_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
-      head_o_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
+      head_acc_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
       *,
       kv_blk_idx,
   ):
@@ -378,7 +381,7 @@ def flash_attention(
         num_q_per_blk * num_q_heads_per_kv_head,
         128,
     )
-    assert head_o_ref.shape == (
+    assert head_acc_ref.shape == (
         num_q_per_blk,
         num_q_heads_per_kv_head,
         head_dim,
@@ -414,8 +417,8 @@ def init_scratch_ref():
           num_q_heads_per_kv_head,
       )
       masked_store(
-          head_o_ref,
-          jnp.zeros_like(head_o_ref),
+          head_acc_ref,
+          jnp.zeros_like(head_acc_ref),
           store_start,
           store_end,
       )
@@ -481,17 +484,17 @@ def broadcast_to_shape(arr, shape):
          [arr for _ in range(shape[1] // arr.shape[1])], axis=1
      )

-    o_curr = head_o_ref[...].reshape(-1, head_dim)
+    o_curr = head_acc_ref[...].reshape(-1, head_dim)
     l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
     beta = broadcast_to_shape(beta, qkv.shape)
     l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
     out = lax.div(
         l_alpha * o_curr + beta * qkv,
         l_next_safe,
-    ).astype(head_o_ref.dtype)
+    )
     masked_store(
-        head_o_ref,
-        out.reshape(head_o_ref.shape),
+        head_acc_ref,
+        out.reshape(head_acc_ref.shape),
         store_start,
         store_end,
     )
@@ -544,7 +547,7 @@ def prefetch_next_kv_blk():
             v,
             l_ref.at[kv_head_idx],
             m_ref.at[kv_head_idx],
-            o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
+            acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
             kv_blk_idx=kv_blk_idx,
         )
         return kv_blk_idx + 1, next_buf_idx
@@ -566,6 +569,7 @@ def prefetch_next_kv_blk():
   # Reset seq_idx for next kv_heads_blk if run out of seqs!
   seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
   seq_buf_idx_ref[1] = buf_idx
+  o_ref[...] = acc_ref[...].astype(q_ref.dtype)

 def cdiv(a, b):
@@ -662,6 +666,7 @@ def ragged_paged_attention(
     num_seqs: the dynamic number of sequences.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
@@ -672,7 +677,7 @@ def ragged_paged_attention(
   Returns:
     The output of the attention.
   """
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   _, num_q_heads, head_dim = q.shape
@@ -710,6 +715,10 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
       jnp.float32,
   )
+  acc_scratch = pltpu.VMEM(
+      (num_q_per_blk, num_q_heads_per_blk, head_dim),
+      jnp.float32,
+  )
   double_buf_scratch = pltpu.VMEM(
       (
           2,  # For double buffering during DMA copies.
@@ -725,6 +734,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       pltpu.SemaphoreType.DMA((2,)),  # Semaphores for double buffers.
       lm_scratch,  # l_ref
       lm_scratch,  # m_ref
+      acc_scratch,
  ]
   scalar_prefetches = (
       kv_lens,
@@ -755,10 +765,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
          ),
          vmem_limit_bytes=vmem_limit_bytes,
      ),
-      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
+      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
      name="ragged_paged_attention_kernel",
  )

-  # TODO(jevinjiang): Use f32 acc scratch for output! So we only need
-  # to transfer output with desired dtype back to HBM.
-  return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype)
+  return kernel(*scalar_prefetches, q, kv_pages)
diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py
index b76d30bd1dcf..8d48bc281400 100644
--- a/tests/pallas/tpu_ragged_paged_attention_test.py
+++ b/tests/pallas/tpu_ragged_paged_attention_test.py
@@ -21,7 +21,7 @@
 from jax.experimental.pallas.ops.tpu.ragged_paged_attention import (
     ragged_paged_attention,
     ref_ragged_paged_attention,
-    validate_inputs_on_runtime,
+    validate_dynamic_inputs,
 )
 import jax.numpy as jnp
@@ -91,15 +91,15 @@ def _test_ragged_paged_attention(

     num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32)

-    validate_inputs_on_runtime(
+    validate_dynamic_inputs(
         q,
         kv_pages,
         kv_lens,
         page_indices,
         cu_q_lens,
         num_seqs,
-        sliding_window=sliding_window,
-        soft_cap=soft_cap,
+        sliding_window,
+        soft_cap,
     )

     actual_num_q_tokens = cu_q_lens[num_seqs[0]]

From 177193662cba6a228fc26cc5a08efb073ec775ab Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 28 Mar 2025 15:15:22 -0700
Subject: [PATCH 0271/1769] Add vma rules for all_gather, all_to_all, ppermute
 and reduce_scatter primitives

PiperOrigin-RevId: 741661360
---
 jax/_src/lax/parallel.py | 60 +++++++++++++++++++++++++++++++---------
 tests/shard_map_test.py  | 17 ++++++++++++
 2 files changed, 64 insertions(+), 13 deletions(-)

diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 28e6dbef4a2c..3ef0a2520378 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -25,6 +25,7 @@
 import jax
 from jax import tree_util
 from jax._src import core
+from jax._src import config
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
@@ -325,9 +326,10 @@ def ppermute(x, axis_name, perm):
   """
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
-  return tree_util.tree_map(
-      partial(ppermute_p.bind, axis_name=axis_name,
-              perm=tuple(map(tuple, perm))), x)
+  def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
+    return ppermute_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm)))
+  return tree_util.tree_map(bind, x)

 def pshuffle(x, axis_name, perm):
   """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding
@@ -447,6 +449,7 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):
       else:  # concat_axis < split_axis
         x = lax.expand_dims(x, (concat_axis,))  # insert the new axis
         split_axis += 1   # we have a new axis before split_axis now
+    x = insert_collective_pbroadcast(axis_name, x)
     result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                                axis_name=axis_name,
                                axis_index_groups=axis_index_groups,
@@ -975,6 +978,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):

 def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
   _check_axis_names(axis_name)
+  collective_vma_rule('ppermute', axis_name, x)
   return x

 ppermute_p = core.Primitive('ppermute')
@@ -1189,7 +1193,8 @@ def _all_to_all_effectful_abstract_eval(
     assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
     shape[split_axis] //= axis_size
     shape[concat_axis] *= axis_size
-    out_aval = input_aval.update(shape=tuple(shape), weak_type=False)
+    vma = collective_vma_rule('all_to_all', axis_name, input_aval)
+    out_aval = input_aval.update(shape=tuple(shape), weak_type=False, vma=vma)
   effects = {*map(core.NamedAxisEffect, axis_name)}
   return out_aval, effects
@@ -1313,6 +1318,19 @@ def _ragged_all_to_all_transpose(
 mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
 batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')

+def insert_collective_pbroadcast(axis_name, x):
+  if not config.varying_axes_in_types.value:
+    return x
+
+  from jax.experimental import shard_map
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
+  aval = core.get_aval(x)
+  names_union = set(axis_name) | aval.vma
+  pbroadcast_axis_name = tuple(n for n in names_union if n not in aval.vma)
+  if pbroadcast_axis_name:
+    x = shard_map.pbroadcast(x, pbroadcast_axis_name)
+  return x
+
 def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
   """Gather values of x across all replicas.
@@ -1382,6 +1400,7 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
   axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
   def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
     return all_gather_p.bind(
         leaf,
         all_gather_dimension=canonicalize_axis(
@@ -1434,6 +1453,19 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
       **other_args).results

+def collective_vma_rule(prim_name, axis_name, x_aval):
+  if not config.varying_axes_in_types.value:
+    return frozenset()
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
+  if any(a not in x_aval.vma for a in axis_name):
+    raise ValueError(
+        f"Collective {prim_name} must be applied to a device-varying "
+        f"type, but got {x_aval.vma} for collective acting "
+        f"over axis name {axis_name}. Please open an issue at "
+        "https://github.com/jax-ml/jax/issues, and as a temporary "
+        "workaround pass the check_rep=False argument to shard_map")
+  return x_aval.vma
+
 def _all_gather_effectful_abstract_eval(
     x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
 ):
@@ -1445,7 +1477,9 @@ def _all_gather_effectful_abstract_eval(
       new_shape[all_gather_dimension] *= axis_size
     else:
       new_shape.insert(all_gather_dimension, axis_size)
-  return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}
+  out_vma = collective_vma_rule('all_gather', axis_name, x_aval)
+  return (x_aval.update(shape=new_shape, vma=out_vma),
+          {*map(core.NamedAxisEffect, axis_name)})

 def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name,
                                axis_index_groups, axis_size, tiled):
   return (psum_scatter(cts, axis_name=axis_name,
@@ -1582,7 +1616,9 @@ def _reduce_scatter_effectful_abstract_eval(
           f"{scatter_dim_input_size} must match shard count "
           f"{axis_size}")
     del new_shape[scatter_dimension]
-  return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}
+  vma = collective_vma_rule('reduce_scatter', axis_name, x_aval)
+  return (x_aval.update(shape=new_shape, vma=vma),
+          {*map(core.NamedAxisEffect, axis_name)})

 def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,
@@ -1726,13 +1762,11 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
     axis_name = axis_name,
   axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
-  bind = partial(
-      reduce_scatter_p.bind,
-      axis_name=axis_name,
-      scatter_dimension=scatter_dimension,
-      axis_index_groups=axis_index_groups,
-      axis_size=axis_size,
-      tiled=tiled)
+  def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
+    return reduce_scatter_p.bind(
+        leaf, axis_name=axis_name, scatter_dimension=scatter_dimension,
+        axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled)
   return tree_util.tree_map(bind, x)

diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 1ffb3e1d137a..c1923f5b0ae3 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -2707,6 +2707,23 @@ def f(x):
     #   return jnp.sum(f(x, y))
     # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr)

+  @config.varying_axes_in_types(True)
+  def test_all_gather_with_vma_in_types(self):
+    mesh = jtu.create_mesh((2,), ('x',))
+    x = np.arange(8.)
+
+    def f(x):
+      self.assertEqual(x.aval.vma, frozenset())
+      out = jax.lax.all_gather(x, 'x')
+      self.assertEqual(out.aval.vma, frozenset({'x'}))
+      return out
+
+    f = jax.jit(shard_map(f, mesh=mesh, in_specs=P(), out_specs=P('x')))
+    jaxpr = f.trace(x).jaxpr
+    self.assertIn("pbroadcast[axes=('x',)", str(jaxpr))
+
+    f(x)  # doesn't crash
+

 class FunSpec(NamedTuple):
   name: str

From dafebd0d7f2c79dafb3fa2e6f358bdb67d0dfaa9 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Fri, 28 Mar 2025 15:20:58 -0700
Subject: [PATCH 0272/1769] DOC: add documentation note about default dtypes

---
 docs/default_dtypes.md           | 82 ++++++++++++++++++++++++++++++++
 docs/notes.rst                   |  8 +++-
 docs/user_guides.rst             |  1 -
 jax/_src/numpy/array_creation.py |  9 ++--
 4 files changed, 95 insertions(+), 5 deletions(-)
 create mode 100644 docs/default_dtypes.md

diff --git a/docs/default_dtypes.md b/docs/default_dtypes.md
new file mode 100644
index 000000000000..629f7fb5c314
--- /dev/null
+++ b/docs/default_dtypes.md
@@ -0,0 +1,82 @@
+(default-dtypes)=
+# Default dtypes and the X64 flag
+JAX strives to meet the needs of a range of numerical computing practitioners, who
+sometimes have conflicting preferences. When it comes to default dtypes, there are
+two different camps:
+
+- Classic scientific computing practitioners (i.e. users of tools like {mod}`numpy` or
+  {mod}`scipy`) tend to value accuracy of computations foremost: such users would
+  prefer that computations default to the **widest available representation**: e.g.
+  floating point values should default to `float64`, integers to `int64`, etc.
+- AI researchers (i.e. folks implementing and training neural networks) tend to value
+  speed over accuracy, to the point where they have developed special data types like
+  [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) and others
+  which deliberately discard the least significant bits in order to speed up computation.
+  For these users, the mere presence of a float64 value in their computation can lead
+  to programs that are slow at best, and incompatible with their hardware at worst!
+  These users would prefer that computations default to `float32` or `int32`.
+
+The main mechanism JAX offers for this is the `jax_enable_x64` flag, which controls
+whether 64-bit values can be created at all. By default this flag is set to `False`
+(serving the needs of AI researchers and practitioners), but can be set to `True`
+by users who value accuracy over computational speed.
+
+## Default setting: 32-bits everywhere
+By default `jax_enable_x64` is set to False, and so {mod}`jax.numpy` array creation
+functions will default to returning 32-bit values.
+
+For example:
+```python
+>>> import jax.numpy as jnp
+
+>>> jnp.arange(5)
+Array([0, 1, 2, 3, 4], dtype=int32)
+
+>>> jnp.zeros(5)
+Array([0., 0., 0., 0., 0.], dtype=float32)
+
+>>> jnp.ones(5, dtype=int)
+Array([1, 1, 1, 1, 1], dtype=int32)
+
+```
+
+Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having
+this flag set to False prevents you from creating 64-bit arrays at all! For example:
+```
+>>> jnp.arange(5, dtype='float64')  # doctest: +SKIP
+UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be
+truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the
+JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
+Array([0., 1., 2., 3., 4.], dtype=float32)
+```
+
+## The X64 flag: enabling 64-bit values
+To work in the "other mode" where functions default to producing 64-bit values, you can set the
+`jax_enable_x64` flag to `True`:
+```python
+import jax
+import jax.numpy as jnp
+
+jax.config.update('jax_enable_x64', True)
+
+print(repr(jnp.arange(5)))
+print(repr(jnp.zeros(5)))
+print(repr(jnp.ones(5, dtype=int)))
+```
+```
+Array([0, 1, 2, 3, 4], dtype=int64)
+Array([0., 0., 0., 0., 0.], dtype=float64)
+Array([1, 1, 1, 1, 1], dtype=int64)
+```
+
+The X64 configuration can also be set via the `JAX_ENABLE_X64` shell environment variable,
+for example:
+```bash
+$ JAX_ENABLE_X64=1 python main.py
+```
+The X64 flag is intended as a **global setting** that should have one value for your whole
+program, set at the top of your main file. A common feature request is for the flag to
+be contextually configurable (e.g. enabling X64 just for one section of a long program):
+this turns out to be difficult to implement within JAX's programming model, where code
+execution may happen in a different context than code compilation. There is ongoing work
+exploring the feasibility of relaxing this constraint, so stay tuned!
diff --git a/docs/notes.rst b/docs/notes.rst
index 24a9dc8594cd..502385142b16 100644
--- a/docs/notes.rst
+++ b/docs/notes.rst
@@ -17,6 +17,10 @@ Memory and computation usage:
 Programmer guardrails:
   - :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion.

+Arrays and data types:
+  - :doc:`type_promotion` describes JAX's implicit type promotion for functions of two or more values.
+  - :doc:`default_dtypes` describes how JAX determines the default dtype for array creation functions.
+
 .. toctree::
    :hidden:

@@ -27,4 +31,6 @@ Programmer guardrails:
    async_dispatch
    concurrency
    gpu_memory_allocation
-   rank_promotion_warning
\ No newline at end of file
+   rank_promotion_warning
+   type_promotion
+   default_dtypes
diff --git a/docs/user_guides.rst b/docs/user_guides.rst
index 6481da7a31dd..47984fc493f4 100644
--- a/docs/user_guides.rst
+++ b/docs/user_guides.rst
@@ -26,7 +26,6 @@ or deployed codebases.
    errors
    aot
    export/index
-   type_promotion
    transfer_guard

 .. toctree::
diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py
index a0495986fcd1..4f07f94fe8b4 100644
--- a/jax/_src/numpy/array_creation.py
+++ b/jax/_src/numpy/array_creation.py
@@ -50,7 +50,8 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *,

   Args:
     shape: int or sequence of ints specifying the shape of the created array.
-    dtype: optional dtype for the created array; defaults to floating point.
+    dtype: optional dtype for the created array; defaults to float32 or float64
+      depending on the X64 configuration (see :ref:`default-dtypes`).
     device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
       to which the created array will be committed.
@@ -87,7 +88,8 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *,

   Args:
     shape: int or sequence of ints specifying the shape of the created array.
-    dtype: optional dtype for the created array; defaults to floating point.
+    dtype: optional dtype for the created array; defaults to float32 or float64
+      depending on the X64 configuration (see :ref:`default-dtypes`).
     device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
       to which the created array will be committed.
@@ -126,7 +128,8 @@ def empty(shape: Any, dtype: DTypeLike | None = None, *,

   Args:
     shape: int or sequence of ints specifying the shape of the created array.
-    dtype: optional dtype for the created array; defaults to floating point.
+    dtype: optional dtype for the created array; defaults to float32 or float64
+      depending on the X64 configuration (see :ref:`default-dtypes`).
     device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
       to which the created array will be committed.

From 6fba4ecc58b21a478f223aeba3b8dfff6cef39c7 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 28 Mar 2025 15:20:27 -0700
Subject: [PATCH 0273/1769] PR #27576: [attrs] experimental appendattr

Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576

This is an experimental extension to attrs. Attrs should be considered both
experimental and deprecated.

This PR also includes some fixes for getattr/setattr.

Copybara import of the project:

--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson :

    [attrs] experimental appendattr

Merging this change closes #27576

COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
---
 jax/_src/interpreters/partial_eval.py |  68 +++++---
 jax/_src/lax/control_flow/loops.py    |  92 +++++++----
 jax/_src/pjit.py                      |  82 +++++-----
 jax/experimental/attrs.py             |  63 +++++++-
 tests/attrs_test.py                   | 215 +++++++++++++++++++++++++-
 5 files changed, 430 insertions(+), 90 deletions(-)

diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 58b97ce2f3da..0a8e3b7824ff 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -58,6 +58,13 @@ def identity(x): return x
 AvalId = int
 ConstId = int

+AttrKind = Any
+PyTree = Any
+
+# Attrs flavors, see jax/experimental/attrs.py
+ReadWrite = type('ReadWrite', (), {})()
+Append = type('Append', (), {})()
+
 def _update_annotation_known(
     f: lu.WrappedFun,
     orig_type: InputType | None,
@@ -1553,6 +1560,17 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
   """Reorder `invars` by moving those indicated in `to_move` to the back."""
   return move_binders_to_front(closed_jaxpr, map(op.not_, to_move))

+def move_outvars_to_back(jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
+  return _move_outvars_to_back(jaxpr, tuple(to_move))
+
+@weakref_lru_cache
+def _move_outvars_to_back(jaxpr, to_move):
+  new_outvars = ([e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if not m] +
+                 [e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if m])
+  return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars))
+
+
 class DynamicJaxprTracer(core.Tracer):
   __slots__ = ['aval', '_debug_info']
@@ -1657,7 +1675,7 @@ class JaxprStackFrame:
   eqns: list[JaxprEqn]
   invars: list[Var]
   effects: core.Effects
-  attrs_tracked: list[tuple[Any, str]]
+  attrs_tracked: list[tuple[Any, str, AttrKind]]
   attrs_inits: list
   attrs_vars: list[Var]
   debug_info: core.DebugInfo
@@ -1679,10 +1697,14 @@ def __init__(self, debug_info: core.DebugInfo):
   def add_eqn(self, eqn: core.JaxprEqn):
     self.eqns.append(eqn)

-  def to_jaxpr(self, trace: DynamicJaxprTrace,
-               out_tracers: Sequence[Tracer],
-               debug_info: core.DebugInfo,
-               ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
+  def reset_states(self):
+    reset_states(self.attrs_tracked, self.attrs_inits)
+
+  def to_jaxpr(
+      self, trace: DynamicJaxprTrace,
+      out_tracers: Sequence[Tracer],
+      debug_info: core.DebugInfo,
+  ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]:
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
     invars = self.attrs_vars + self.invars
@@ -1699,7 +1721,6 @@ def to_jaxpr(self, trace: DynamicJaxprTrace,
     jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
     jaxpr, constvals = _inline_literals(jaxpr, constvals)
     init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
-    set_states(self.attrs_tracked, self.attrs_inits)  # reset to initial values
     return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)

   def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
@@ -1840,10 +1861,9 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
     outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
     new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
   new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
-  jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
-                                     new_eqns)
-  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
-                    jaxpr_effects, jaxpr.debug_info)
+  effs = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns)
+  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, effs,
+                    jaxpr.debug_info)
   return new_jaxpr, new_constvals

@@ -2172,19 +2192,23 @@ def trace_to_jaxpr_dynamic(
     *,
     keep_inputs: list[bool] | None = None,
 ) -> tuple[Jaxpr, list[AbstractValue], list[Any],
-           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
+           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]:
   keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
   trace = DynamicJaxprTrace(fun.debug_info)
   with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
     in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
     in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-    with core.set_current_trace(trace):
-      ans = fun.call_wrapped(*in_tracers)
+    try:
+      with core.set_current_trace(trace):
+        ans = fun.call_wrapped(*in_tracers)

-    out_tracers = map(trace.to_jaxpr_tracer, ans)
-    _check_no_returned_refs(fun.debug_info, out_tracers)
-    jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
-    del trace, fun, in_tracers, out_tracers, ans
+      out_tracers = map(trace.to_jaxpr_tracer, ans)
+      _check_no_returned_refs(fun.debug_info, out_tracers)
+      jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
+      del fun, in_tracers, out_tracers, ans
+    finally:
+      trace.frame.reset_states()
+      del trace

   config.enable_checks.value and core.check_jaxpr(jaxpr)
   return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
@@ -2242,14 +2266,14 @@ def trace_to_jaxpr_dynamic2(
     tuple[AbstractedAxisName, ...],
 ]

-AttrsTracked = list[tuple[Any, str]]
+AttrsTracked = list[tuple[Any, str, AttrKind]]
 AttrStates = list

-def set_states(attrs_tracked: AttrsTracked, vals: AttrStates):
-  for ((obj, attr), val) in zip(attrs_tracked, vals):
+def reset_states(attrs_tracked: AttrsTracked, init_vals: AttrStates) -> None:
+  for ((obj, attr, _), val) in zip(attrs_tracked, init_vals):
     setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr)

-def get_states(attrs_tracked: AttrsTracked):
-  return [getattr(obj, attr) for (obj, attr) in attrs_tracked]
+def get_states(attrs_tracked: AttrsTracked) -> list[PyTree]:
+  return [getattr(obj, attr) for (obj, attr, kind) in attrs_tracked]

 @register_static
 class DoesNotExist: ...
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 0362c139570a..56323949a607 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -298,8 +298,17 @@ def _create_jaxpr(init):
   if len(out_tree_children) != 2:
     msg = "scan body output must be a pair, got {}."
     raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
-  _, carry_avals_out, _ = split_list(
-      jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves])
+
+  if attrs_tracked:
+    appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked]
+    jaxpr = pe.move_outvars_to_back(
+        jaxpr, appends_out + [False] * (len(jaxpr.out_avals) - len(appends_out)))
+    num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind)
+                         in attrs_tracked if kind is pe.ReadWrite)
+    _, carry_avals_out, _ = split_list(
+        jaxpr.out_avals, [num_attr_carry, out_tree_children[0].num_leaves])
+  else:
+    carry_avals_out, _ = split_list(jaxpr.out_avals, [out_tree_children[0].num_leaves])
   return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr,
           consts, out_tree, out_tree_children, attrs_tracked)
@@ -332,9 +341,8 @@ def _create_jaxpr(init):
     raise ValueError("`unroll` must be a `bool` or a positive `int`.")
   if attrs_tracked:
     in_state = _get_states(attrs_tracked)
-    in_carry, in_ext = split_list(in_flat, [num_carry])
-    in_flat = [*in_state, *in_carry, *in_ext]
-    num_carry += len(attrs_tracked)
+    in_flat = [*in_state, *in_flat]
+    num_carry += len(in_state)
   out = scan_p.bind(*consts, *in_flat,
                     reverse=reverse, length=length, jaxpr=jaxpr,
                     num_consts=len(consts), num_carry=num_carry,
@@ -342,27 +350,50 @@ def _create_jaxpr(init):
                     unroll=unroll, _split_transpose=_split_transpose)
   if attrs_tracked:
-    out_state, out = split_list(out, [len(attrs_tracked)])
-    _set_states(attrs_tracked, out_state)
+    num_ext = (len(out) - len(in_state)
+               - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked))
+    out_state, out, out_append = split_list(out, [len(in_state), num_ext])
+    out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append)
+    _set_states(attrs_tracked, out_attrs)
   return tree_unflatten(out_tree, out)

 def _set_states(attrs_tracked, vals):
-  from jax.experimental.attrs import jax_setattr
+  from jax.experimental.attrs import jax_setattr, jax_extendattr
   valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked])
-  for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss):
-    val = tree_unflatten(treedef, leaves)
-    jax_setattr(obj, attr, val)
+  for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss):
+    if kind is pe.ReadWrite:
+      val = tree_unflatten(treedef, leaves)
+      jax_setattr(obj, attr, val)
+    elif kind is pe.Append:
+      val, = leaves
+      jax_extendattr(obj, attr, val.reshape(-1, *val.shape[2:]))
+    else:
+      assert False

 def _get_states(attrs_tracked):
   from jax.experimental.attrs import jax_getattr
   vals = []
-  for treedef, _, (obj, attr) in attrs_tracked:
-    tree = jax_getattr(obj, attr)
-    leaves, treedef_ = tree_flatten(tree)
-    assert treedef == treedef_
-    vals.extend(leaves)
+  for treedef, _, (obj, attr, kind) in attrs_tracked:
+    if kind is pe.ReadWrite:
+      tree = jax_getattr(obj, attr)
+      leaves, treedef_ = tree_flatten(tree)
+      assert treedef == treedef_
+      vals.extend(leaves)
+    elif kind is pe.Append:
+      pass
+    else:
+      assert False
   return vals

+def _merge_attrs_out(attrs_tracked, out_state, out_append):
+  out_state_, out_append_ = iter(out_state), iter(out_append)
+  out_attrs = [item for _, out_tree, (_, _, k) in attrs_tracked for item in
+               (itertools.islice(out_state_, out_tree.num_leaves)
+                if k is pe.ReadWrite else [next(out_append_)])]
+  assert next(out_state_, None) is next(out_append_, None) is None
+  return out_attrs
+
+
 def _capitalize(s):
   # s.capitalize() converts s[1:] to lowercase which we don't want.
   return s[0].capitalize() + s[1:]
@@ -662,7 +693,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
   # The above trace_to_jaxpr_nounits call computed loop-invariant residuals
   # (known values in invar_pvals_out) and also computed loop-invariant values
   # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
-  # previous consts). We need to collect the computed inteisive residuals, and
+  # previous consts). We need to collect the computed intensive residuals, and
   # move corresponding intensive residual binders in jaxpr_unknown to the front.
   res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
   intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
@@ -785,16 +816,21 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
   ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])

   #       jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
-  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
+  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a, e])
   jaxpr_trans, attrs_tracked = _transpose_scan_jaxpr(
       jaxpr, num_ires, num_consts - num_ires, num_eres, ct_ys_is_zeros)
-  linear_trans = ([False] * num_ires + [False] * len(attrs_tracked) +
+  appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked]
+  jaxpr_trans = pe.move_outvars_to_back(
+      jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out)))
+  num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind)
+                       in attrs_tracked if kind is pe.ReadWrite)
+  linear_trans = ([False] * num_ires + [False] * num_attr_carry +
                   [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
                   [False] * num_eres)
   in_state = _get_states(attrs_tracked)
   transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres
-  transpose_num_out_carry = num_consts-num_ires+num_carry+len(attrs_tracked)
+  transpose_num_out_carry = num_consts-num_ires+num_carry+num_attr_carry

   if not _split_transpose:
     outs = scan_p.bind(
@@ -889,8 +925,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
         for mask in outs_mask
     ]

-  out_state, outs = split_list(outs, [len(attrs_tracked)])
-  _set_states(attrs_tracked, out_state)
+  num_outs = len(outs) - num_attr_carry - sum(appends_out)
+  out_state, outs, out_append = split_list(outs, [num_attr_carry, num_outs])
+  out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append)
+  _set_states(attrs_tracked, out_attrs)
   ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
   return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
@@ -935,12 +973,10 @@ def transposed(*res1_cbar_bbar_res2):
     return c_bar + a_bar

   # TODO(necula): fix arg names and results for transposed
-  transposed_wrapped = lu.wrap_init(transposed,
-                                    debug_info=jaxpr.jaxpr.debug_info)
-  return _make_closed_jaxpr_attrs(
-      transposed_wrapped,
-      tuple(res1_avals + c_avals + b_carry_avals +
-            b_ys_avals_stripped + res2_avals))
+  transposed_wrapped = lu.wrap_init(transposed, debug_info=jaxpr.jaxpr.debug_info)
+  trans_avals = (*res1_avals, *c_avals, *b_carry_avals, *b_ys_avals_stripped, *res2_avals)
+  trans_jaxpr, attrs_tracked = _make_closed_jaxpr_attrs(transposed_wrapped, trans_avals)
+  return trans_jaxpr, attrs_tracked

 def _scan_batching_rule(axis_data, args,
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 03eb6835cb06..5727c36a646b 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -233,20 +233,30 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):

 def _set_states(attrs_tracked, vals):
-  from jax.experimental.attrs import jax_setattr
+  from jax.experimental.attrs import jax_setattr, jax_extendattr
   valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]])
-  for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss):
-    val = tree_unflatten(treedef, leaves)
-    jax_setattr(obj, attr, val)
+  for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss):
+    if kind is pe.ReadWrite:
+      val = tree_unflatten(treedef, leaves)
+      jax_setattr(obj, attr, val)
+    elif kind is pe.Append:
+      del treedef
+      val, = leaves
+      jax_extendattr(obj, attr, val)

 def _get_states(attrs_tracked):
   from jax.experimental.attrs import jax_getattr, dne_sentinel
   vals = []
-  for treedef, _, (obj, attr) in attrs_tracked:
-    tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel
-    leaves, treedef_ = tree_flatten(tree)
-    assert treedef == treedef_
-    vals.extend(leaves)
+  for treedef, _, (obj, attr, kind) in attrs_tracked:
+    if kind is pe.ReadWrite:
+      tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel
+      leaves, treedef_ = tree_flatten(tree)
+      assert treedef == treedef_
+      vals.extend(leaves)
+    elif kind is pe.Append:
+      pass
+    else:
+      assert False
   return vals

 def _need_to_rebuild_with_fdo(pgle_profiler):
@@ -537,7 +547,7 @@ class PjitParams(NamedTuple):
   donated_invars: tuple[bool, ...]
   arg_names: tuple[str, ...]
   num_consts: int
-  attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
+  attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]]

 def _infer_params_impl(
@@ -613,14 +623,14 @@ def _infer_params_impl(
       ji.in_layouts_treedef, ji.in_layouts_leaves,
       in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs)

-  attr_token = _attr_token(flat_fun, in_type)
+  attr_token = _attr_cache_index(flat_fun, in_type)

   jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
       flat_fun, in_type, attr_token, IgnoreKey(ji.inline))

   if config.mutable_array_checks.value:
     _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
-  _attr_update(flat_fun, in_type, attr_token, attrs_tracked)
+  _attr_cachedata_update(flat_fun, in_type, attr_token, attrs_tracked)

   out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
       out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
@@ -636,13 +646,14 @@ def _infer_params_impl(
     implicit_args = []
   args_flat = [*implicit_args, *explicit_args]

-  num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked)
-  num_extra_args = len(implicit_args) + num_states_in + len(consts)
+  num_attrs_in = sum(init_tree.num_leaves for init_tree, _, (_, _, kind)
+                     in attrs_tracked if kind is pe.ReadWrite)
+  num_extra_args = len(implicit_args) + num_attrs_in + len(consts)
   in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
   in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
   donated_invars = (False,) * num_extra_args + donated_invars
   assert (len(in_shardings_flat) == len(in_layouts_flat) ==
-          len(donated_invars) == num_states_in + len(consts) + len(args_flat))
+          len(donated_invars) == num_attrs_in + len(consts) + len(args_flat))

   params = dict(
       jaxpr=jaxpr,
@@ -1274,7 +1285,7 @@ def _create_pjit_jaxpr(
     attr_data: int,
     ignored_inline: IgnoreKey
 ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
-           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
+           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]]]:
   util.test_event("create_pjit_jaxpr")
   del ignored_inline  # just for explain_cache_miss
   if config.no_tracing.value:
@@ -1350,32 +1361,31 @@ def seen_attrs_get(
   assert fun.in_type is None or fun.in_type == in_type
   return cache[(fun.transforms, fun.params, in_type)]

-def _attr_token(
+def _attr_cache_index(
     fun: lu.WrappedFun,
     in_type: core.InputType | tuple[core.AbstractValue, ...]
 ) -> int:
-  from jax.experimental.attrs import jax_getattr, dne_sentinel
+  from jax.experimental.attrs import dne_sentinel
   cases = seen_attrs_get(fun, in_type)
   for i, records in enumerate(cases):
-    for obj, attr, treedef, avals in records:
-      val = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel
-      vals, treedef_ = tree_flatten(val)
-      avals_ = map(core.shaped_abstractify, vals)
-      if treedef != treedef_ or avals != avals_: break
+    for obj, attr, kind, treedef, avals in records:
+      if kind is pe.ReadWrite:
+        val = getattr(obj, attr, dne_sentinel)
+        vals, treedef_ = tree_flatten(val)
+        avals_ = map(core.shaped_abstractify, vals)
+        if treedef != treedef_ or avals != avals_: break
     else:
       return i
   return len(cases)

-def _attr_update(fun, in_type, i, attrs_tracked):
-  from jax.experimental.attrs import jax_getattr, dne_sentinel
-  leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel)
-  records = [(obj, attr, init_tree, map(core.shaped_abstractify, leaves(obj, attr)))
-             for init_tree, _, (obj, attr) in attrs_tracked]
+def _attr_cachedata_update(fun, in_type, i, attrs_tracked):
+  from jax.experimental.attrs import dne_sentinel
+  leaves = lambda obj, attr: tree_leaves(getattr(obj, attr, dne_sentinel))
+  records = [(obj, attr, kind, init_tree, map(core.typeof, leaves(obj, attr)))
+             for init_tree, _, (obj, attr, kind) in attrs_tracked]
   cases = seen_attrs_get(fun, in_type)
   if i == len(cases):
     cases.append(records)
-  else:
-    assert i < len(cases) and cases[i] == records

 @dataclasses.dataclass(frozen=True)
@@ -1540,6 +1550,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
         committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))

   resolved_in_shardings: list[PjitSharding] = []
+  assert len(args) == len(pjit_in_shardings)
   for arg, pjit_in_s in zip(args, pjit_in_shardings):
     # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
     # not allow None as the sharding.
@@ -2337,11 +2348,12 @@ def prune_type(ty, xs, maybe_zeros):

   if attrs_tracked:
     init_states = _get_states(attrs_tracked)
+    num_attr_outs = sum(final_tree.num_leaves for _, final_tree, _ in attrs_tracked)
     primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in]
-    transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings
-    transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings
-    transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts
-    transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts
+    transpose_in_shardings = (UNSPECIFIED,) * len(init_states) + transpose_in_shardings
+    transpose_out_shardings = (UNSPECIFIED,) * num_attr_outs + transpose_out_shardings
+    transpose_in_layouts = (None,) * len(init_states) + transpose_in_layouts
+    transpose_out_layouts = (None,) * num_attr_outs + transpose_out_layouts

   try:
     nz_cts_out = pjit_p.bind(
@@ -2370,7 +2382,7 @@ def prune_type(ty, xs, maybe_zeros):
       dispatch._raise_no_nan_in_deoptimized(e)

   if attrs_tracked:
-    final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
+    final_states, nz_cts_out = split_list(nz_cts_out, [num_attr_outs])
     _set_states(attrs_tracked, final_states)

   return tree_unflatten(cts_out_treedef, nz_cts_out)
diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py
index bb4c7bf83b3f..0d40938a85c4 100644
--- a/jax/experimental/attrs.py
+++ b/jax/experimental/attrs.py
@@ -16,6 +16,7 @@

 from typing import Any, Callable

+import jax
 from jax._src import core
 from jax._src import source_info_util
 from jax._src import api_util
@@ -32,20 +33,31 @@
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip

+Array = Any
 JaxVal = Any
 Pytree = Any

+ReadWrite = pe.ReadWrite
+Append = pe.Append
+
 register = api_util.register_class_with_attrs
 dne_sentinel = pe.dne_sentinel

-def jax_getattr(obj: Any, attr: str):
+def jax_getattr(obj: Any, attr: str) -> Pytree:
   with core.take_current_trace() as t:
     return t.process_getattr(obj, attr)

-def jax_setattr(obj: Any, attr: str, val: Pytree):
+def jax_setattr(obj: Any, attr: str, val: Pytree) -> None:
   with core.take_current_trace() as t:
     return t.process_setattr(obj, attr, val)

+def jax_appendattr(obj: Any, attr: str, val: Array) -> None:
+  return jax_extendattr(obj, attr, jax.numpy.expand_dims(val, 0))
+
+def jax_extendattr(obj: Any, attr: str, val: Array) -> None:
+  with core.take_current_trace() as t:
+    return t.process_extendattr(obj, attr, val)
+
 def _getattr_impl(_, obj, attr):
   return getattr(obj, attr)
 core.EvalTrace.process_getattr = _getattr_impl
@@ -54,6 +66,25 @@ def _setattr_impl(_, obj, attr, val):
   setattr(obj, attr, val)
 core.EvalTrace.process_setattr = _setattr_impl

+def _extendattr_impl(_, obj, attr, val):
+  cur = getattr(obj, attr, dne_sentinel)
+  if cur is dne_sentinel:
+    new = val
+  else:
+    _check_append_type_agreement(obj, attr, core.typeof(cur), core.typeof(val))
+    new = jax.numpy.concatenate([cur, val])
+  setattr(obj, attr, new)
+core.EvalTrace.process_extendattr = _extendattr_impl
+
+def _check_append_type_agreement(_, attr, curtype, valtype):
+  expected = core.mapped_aval(curtype.shape[0], 0, curtype)
+  got = core.mapped_aval(valtype.shape[0], 0, valtype)
+  if not core.typematch(expected, got):
+    raise TypeError(
+        f"can only append to attr {attr} with values of trailing shape "
+        f"{expected.str_short()}, but appendattr got value of type "
+        f"{valtype.str_short()} which has trailing shape {got.str_short()}.")
+
 def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
   frame = trace.frame
@@ -65,13 +96,16 @@ def new_tracer(x):
     frame.tracers.append(tracer)
     return tracer

-  if (obj, attr) not in frame.attrs_tracked:
+  if (obj, attr, Append) in frame.attrs_tracked:
+    raise TypeError(f"can't read/write to append-only attr {attr}")
+
+  if (obj, attr, ReadWrite) not in frame.attrs_tracked:
     init_val = getattr(obj, attr, dne_sentinel)
     frame.attrs_inits.append(init_val)
     init_vals, init_tree = tree_flatten(init_val)
     tracers = map(new_tracer, init_vals)
     setattr(obj, attr, tree_unflatten(init_tree, tracers))
-    frame.attrs_tracked.append((obj, attr))
+    frame.attrs_tracked.append((obj, attr, ReadWrite))
 pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked

 def _getattr_staging(trace, obj, attr):
@@ -84,6 +118,27 @@ def _setattr_staging(trace, obj, attr, val):
   setattr(obj, attr, val)
 pe.DynamicJaxprTrace.process_setattr = _setattr_staging

+def _extendattr_staging(trace, obj, attr, val):
+  frame = trace.frame
+
+  if (obj, attr, ReadWrite) in frame.attrs_tracked:
+    raise TypeError(f"can't append to read/write-only attr {attr}")
+
+  first_write = (obj, attr, Append) not in frame.attrs_tracked
+  init_val = getattr(obj, attr, dne_sentinel)
+  if init_val is not dne_sentinel:
+    _check_append_type_agreement(obj, attr, core.typeof(init_val), core.typeof(val))
+  if first_write:
+    frame.attrs_inits.append(init_val)
+    frame.attrs_tracked.append((obj, attr, Append))
+    tracer = val
+  else:
+    assert init_val is not dne_sentinel
+    with core.set_current_trace(trace):
+      tracer = jax.numpy.concatenate([init_val, val])
+  setattr(obj, attr, tracer)
+pe.DynamicJaxprTrace.process_extendattr = _extendattr_staging
+
 def jvp(f, primals, tangents, attr_tangents):
   attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents)
diff --git a/tests/attrs_test.py b/tests/attrs_test.py
index 169df3712899..8cf64790311b 100644
--- a/tests/attrs_test.py
+++ b/tests/attrs_test.py
@@ -15,6 +15,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
+import itertools as it

 from absl.testing import absltest
 from absl.testing import parameterized
@@ -28,7 +29,7 @@
 from jax._src.util import safe_zip, safe_map

 from jax.experimental import attrs
-from jax.experimental.attrs import jax_setattr, jax_getattr
+from jax.experimental.attrs import jax_setattr, jax_getattr, jax_appendattr

 config.parse_flags_with_absl()
@@ -66,6 +67,19 @@ def double_it() -> None:
     double_it()
     self.assertEqual(thing.x, 16.0)

+  def test_setattr_doesnt_leak(self):
+    thing = Thing(1.0)
+
+    @jax.jit
+    def f(x):
+      jax_setattr(thing, 'x', x)
+      raise Exception
+
+    try: f(1.)
+    except: pass
+    self.assertNotIsInstance(thing.x, jax.core.Tracer)
+
+
   @parameterized.parameters([True, False])
   def test_jit_basic_tree(self, jit: bool):
     thing = Thing((1.0, 2.0))
@@ -260,6 +274,26 @@ def body(_, __):
     double_it_10()
     self.assertAllClose(thing.x, 1024., check_dtypes=False)

+  @parameterized.parameters([True, False])
+  def test_scan_basic_pytree(self, jit):
+    class Thing: ...
+    thing = Thing()
+    thing.x = (1.0, 1.0)
+
+    def double_it_10():
+      def body(_, __):
+        cur_x, _ = jax_getattr(thing, "x")
+        jax_setattr(thing, "x", (cur_x * 2.0, 3.0))
+        return None, None
+      _, _ = jax.lax.scan(body, None, None, length=10)
+
+    if jit:
+      double_it_10 = jax.jit(double_it_10)
+
+    double_it_10()
+    self.assertAllClose(thing.x[0], 1024., check_dtypes=False)
+    self.assertAllClose(thing.x[1], 3., check_dtypes=False)
+
   def test_scan_basic_consts_and_args(self):
     thing = Thing(1.0)
@@ -402,6 +436,184 @@ def f(x):
     jax.make_jaxpr(f)(3.)
     self.assertFalse(hasattr(thing, 'x'))

+    tracing_ok = True
+    f(0.0)
+    self.assertAllClose(thing.x, 0.)
+    tracing_ok = False
+    f(1.0)
+    self.assertAllClose(thing.x, 1.)
+
+  @parameterized.parameters(it.product([False, True], repeat=2))
+  def test_appendattr_basic(self, jit, initialized):
+    class Thing:
+      ...
+    thing = Thing()
+
+    if initialized:
+      thing.x = jnp.arange(0.)
+
+    def f(x):
+      assert (not jit) or tracing_ok
+      jax_appendattr(thing, 'x', x)
+      jax_appendattr(thing, 'x', x + 1)
+
+    if jit:
+      f = jax.jit(f)
+
+    tracing_ok = True
+    f(0.0)
+    self.assertAllClose(thing.x, jnp.array([0., 1.]))
+    tracing_ok = False
+    f(2.0)
+    self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3.]))
+    f(4.0)
+    self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.]))
+
+  @parameterized.parameters(it.product([False, True], repeat=2))
+  def test_appendattr_constant(self, jit, initialized):
+    class Thing: ...
+    thing = Thing()
+
+    if initialized:
+      thing.x = jnp.arange(0.)
+
+    def f():
+      assert (not jit) or tracing_ok
+      jax_appendattr(thing, 'x', 0.0)
+      jax_appendattr(thing, 'x', 1.0)
+
+    if jit:
+      f = jax.jit(f)
+
+    tracing_ok = True
+    f()
+    self.assertAllClose(thing.x, jnp.array([0., 1.]))
+    tracing_ok = False
+    f()
+    self.assertAllClose(thing.x, jnp.array([0., 1., 0., 1.]))
+
+  @parameterized.parameters([True, False])
+  def test_appendattr_getattr_errors(self, initialized):
+    class Thing: ...
+    thing = Thing()
+
+    if initialized:
+      thing.x = jnp.arange(0.)
+
+    @jax.jit
+    def f(x):
+      jax_appendattr(thing, 'x', x)
+      jax_getattr(thing, 'x')
+
+    with self.assertRaisesRegex(TypeError, "can't read/write"):
+      f(1.0)
+
+    @jax.jit
+    def g(x):
+      jax_setattr(thing, 'x', x)
+      jax_appendattr(thing, 'x', x)
+
+    with self.assertRaisesRegex(TypeError, "can't append"):
+      g(1.0)
+
+    if initialized:
+      self.assertNotIsInstance(thing.x, jax.core.Tracer)
+    else:
+      self.assertFalse(hasattr(thing, 'x'))
+
+  @parameterized.parameters(it.product([False, True], repeat=2))
+  def test_appendattr_dtype_disagreement(self, jit, initialized):
+    class Thing: ...
+    thing = Thing()
+
+    if initialized:
+      thing.x = jnp.array([], 'float32')
+
+    def f(x):
+      jax_appendattr(thing, 'x', x)
+      jax_appendattr(thing, 'x', x.astype('complex64'))
+
+    if jit:
+      f = jax.jit(f)
+
+    msg = "can only append to attr x with values of trailing shape "
+    msg += "float32" if initialized else "int32"
+    with self.assertRaisesRegex(TypeError, msg):
+      f(jnp.array(1, 'int32'))
+
+  @parameterized.parameters(it.product([False, True], repeat=2))
+  def test_appendattr_shape_disagreement(self, jit, initialized):
+    class Thing: ...
+    thing = Thing()
+
+    if initialized:
+      thing.x = jnp.array([])
+
+    def f(x):
+      jax_appendattr(thing, 'x', x)
+      jax_appendattr(thing, 'x', jnp.stack([x, x]))
+
+    if jit:
+      f = jax.jit(f)
+
+    msg = "can only append to attr x with values of trailing shape"
+    with self.assertRaisesRegex(TypeError, msg):
+      f(1)
+
+  @parameterized.parameters(it.product([False, True], repeat=2))
+  def test_appendattr_scan(self, jit, initialized):
+    class Thing: ...
+    thing = Thing()
+
+    if initialized:
+      thing.x = jnp.array([])
+
+    def f():
+      def body(c, x):
+        jax_appendattr(thing, 'x', 2 * x)
+        jax_appendattr(thing, 'x', 2 * x + 1)
+        return c, ()
+      _, () = jax.lax.scan(body, 0, jnp.arange(3.))
+
+    if jit:
+      f = jax.jit(f)
+
+    f()
+
+    self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.]))
+
+  @parameterized.parameters(it.product([False, True], repeat=2))
+  def test_appendattr_scan_vjp(self, jit, initialized):
+    class Thing: ...
+    thing = Thing()
+
+    if initialized:
+      thing.y_bar = jnp.array([])
+
+    def f(x):
+      def body(c, _):
+        return 0.5 * g(2 * c), ()
+      y, _ = jax.lax.scan(body, x, (), length=5)
+      return y
+
+    if jit:
+      f = jax.jit(f)
+
+    @jax.custom_vjp
+    def g(x):
+      return x
+
+    def g_fwd(x):
+      return g(x), None
+
+    def g_bwd(_, y_bar):
+      jax_appendattr(thing, 'y_bar', y_bar)
+      return y_bar,
+
+    g.defvjp(g_fwd, g_bwd)
+    jax.grad(f)(3.)
+
+    self.assertAllClose(thing.y_bar, jnp.array([0.5] * 5))

 class AttrsJVPTest(jtu.JaxTestCase):
@@ -543,6 +755,7 @@ def g_ref(x, x_dot, y, y_dot):
     self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False)
     self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False)

+
 class AttrsLinTest(jtu.JaxTestCase):

   @parameterized.parameters([True, False])

From eb54cd2c6109fafb52894cc1f2d687cb3a25fb4d Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 28 Mar 2025 15:22:48 -0700
Subject: [PATCH 0274/1769] Remove GPU-specific dependencies from
 backend-independent tests.

The GPU-specific deps were added to the backend-independent tests by mistake
[here](https://github.com/jax-ml/jax/pull/27113). These tests should pass
using `jax` and `jaxlib` wheels only.
PiperOrigin-RevId: 741663266
---
 jaxlib/jax.bzl | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index 560db85d6a1e..1cc4fab12591 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -255,8 +255,8 @@ def if_building_jaxlib(
         "//conditions:default": [],
     })

-def _get_test_deps(deps):
-    jaxlib_build_deps = [
+def _get_test_deps(deps, backend_independent):
+    gpu_build_deps = [
         "//jaxlib/cuda:gpu_only_test_deps",
         "//jaxlib/rocm:gpu_only_test_deps",
         "//jax_plugins:gpu_plugin_only_test_deps",
@@ -273,12 +273,21 @@ def _get_test_deps(deps):
         "//jaxlib/tools:jaxlib_py_import",
     ]

+    if backend_independent:
+        jaxlib_build_deps = deps
+        gpu_pypi_wheel_deps = _CPU_PYPI_WHEEL_DEPS
+        gpu_py_import_deps = cpu_py_imports
+    else:
+        jaxlib_build_deps = gpu_build_deps + deps
+        gpu_pypi_wheel_deps = _GPU_PYPI_WHEEL_DEPS
+        gpu_py_import_deps = gpu_py_imports
+
     return select({
-        "//jax:enable_jaxlib_build": jaxlib_build_deps + deps,
+        "//jax:enable_jaxlib_build": jaxlib_build_deps,
        "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": _CPU_PYPI_WHEEL_DEPS,
-        "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": _GPU_PYPI_WHEEL_DEPS,
+        "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps,
        "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports,
-        "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_imports,
+        "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_import_deps,
    })

 # buildifier: disable=function-docstring
@@ -334,7 +343,7 @@ def jax_multiplatform_test(
        deps = _get_test_deps([
            "//jax",
            "//jax:test_util",
-        ] + deps),
+        ] + deps, backend_independent = False),
        data = data,
        shard_count = test_shards,
        tags = test_tags,
@@ -629,15 +638,15 @@ def jax_py_test(
    if "PYTHONWARNINGS" not in env:
        env["PYTHONWARNINGS"] = "error"
    deps = kwargs.get("deps", [])
-    kwargs.pop("deps")
-    test_deps = _get_test_deps(deps)
-    py_test(name = name, env = env, deps = test_deps, **kwargs)
+    test_deps = _get_test_deps(deps, backend_independent = True)
+    kwargs["deps"] = test_deps
+    py_test(name = name, env = env, **kwargs)

 def pytype_test(name, **kwargs):
    deps = kwargs.get("deps", [])
-    kwargs.pop("deps")
-    test_deps = _get_test_deps(deps)
-    native.py_test(name = name, deps = test_deps, **kwargs)
+    test_deps = _get_test_deps(deps, backend_independent = True)
+    kwargs["deps"] = test_deps
+    native.py_test(name = name, **kwargs)

 def if_oss(oss_value, google_value = []):
    """Returns one of the arguments based on the non-configurable build env.
From 93c6bb72d3f550991969bb7bd13c4d5e0fbc46ae Mon Sep 17 00:00:00 2001
From: Zac Cranko
Date: Thu, 20 Mar 2025 15:31:16 -0700
Subject: [PATCH 0275/1769] add discord release action

Update community_release_actions.yml
---
 .github/workflows/community_release_actions.yml | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 .github/workflows/community_release_actions.yml

diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml
new file mode 100644
index 000000000000..d61bea3d7e4d
--- /dev/null
+++ b/.github/workflows/community_release_actions.yml
@@ -0,0 +1,31 @@
+name: Release Actions
+
+on:
+  release:
+    types: [published]
+
+jobs:
+  discord_release:
+    if: github.repository_owner == 'jax-ml'
+    runs-on: ubuntu-latest
+    steps:
+      - name: Get release URL
+        id: get-release-url
+        run: |
+          URL="https://docs.jax.dev/en/latest/changelog.html"
+          echo "::set-output name=URL::$URL"
+      - name: Get content
+        uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1
+        id: get-content
+        with:
+          stringToTruncate: |
+            JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!
+
+            ${{ github.event.release.body }}
+          maxLength: 2000
+          truncationSymbol: "..."
+      - name: Discord Webhook Action
+        uses: tsickert/discord-webhook@c840d45a03a323fbc3f7507ac7769dbd91bfb164 # v5.3.0
+        with:
+          webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
+          content: ${{ steps.get-content.outputs.string }}

From 123ce5221b298d551e168029e12e4b53147b206c Mon Sep 17 00:00:00 2001
From: jeffcarp
Date: Thu, 27 Feb 2025 15:51:00 -0800
Subject: [PATCH 0276/1769] Add scalar event logging function

---
 jax/_src/monitoring.py   | 43 ++++++++++++++++++++++++++++++++++++++
 jax/monitoring.py        |  2 ++
 tests/monitoring_test.py | 28 ++++++++++++++++++++++++--
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py
index 99e957733ba2..de706ccbaef5 100644
--- a/jax/_src/monitoring.py
+++ b/jax/_src/monitoring.py
@@ -46,10 +46,18 @@ def __call__(
   ) -> None:
     ...

+class ScalarListenerWithMetadata(Protocol):
+
+  def __call__(
+      self, event: str, value: float | int, **kwargs: str | int,
+  ) -> None:
+    ...
+
 _event_listeners: list[EventListenerWithMetadata] = []
 _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = []
 _event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = []
+_scalar_listeners: list[ScalarListenerWithMetadata] = []

 def record_event(event: str, **kwargs: str | int) -> None:
@@ -81,6 +89,14 @@ def record_event_time_span(
     callback(event, start_time, end_time, **kwargs)

+def record_scalar(
+    event: str, value: float | int, **kwargs: str | int
+) -> None:
+  """Record a scalar summary value."""
+  for callback in _scalar_listeners:
+    callback(event, value, **kwargs)
+
+
 def register_event_listener(
     callback: EventListenerWithMetadata,
 ) -> None:
@@ -100,6 +116,14 @@ def register_event_duration_secs_listener(
   """Register a callback to be invoked during record_event_duration_secs()."""
   _event_duration_secs_listeners.append(callback)

+
+def register_scalar_listener(
+    callback : ScalarListenerWithMetadata,
+) -> None:
+  """Register a callback to be invoked during record_scalar()."""
+  _scalar_listeners.append(callback)
+
+
 def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]:
   """Get event duration listeners."""
   return list(_event_duration_secs_listeners)
@@ -114,12 +138,20 @@ def get_event_listeners() -> list[EventListenerWithMetadata]:
   """Get event listeners."""
   return list(_event_listeners)

+
+def get_scalar_listeners() -> list[ScalarListenerWithMetadata]:
+  """Get scalar event listeners."""
+  return list(_scalar_listeners)
+
+
 def clear_event_listeners():
   """Clear event listeners."""
   global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners
   _event_listeners = []
   _event_duration_secs_listeners = []
   _event_time_span_listeners = []
+  _scalar_listeners = []
+

 def _unregister_event_duration_listener_by_callback(
     callback: EventDurationListenerWithMetadata) -> None:
@@ -159,3 +191,14 @@ def _unregister_event_listener_by_callback(
   """
   assert callback in _event_listeners
   _event_listeners.remove(callback)
+
+
+def _unregister_scalar_listener_by_callback(
+    callback: ScalarListenerWithMetadata,
+) -> None:
+  """Unregister a scalar event listener by callback.
+
+  This function is supposed to be called for testing only.
+  """
+  assert callback in _scalar_listeners
+  _scalar_listeners.remove(callback)
diff --git a/jax/monitoring.py b/jax/monitoring.py
index 4c9996da582c..f4ab8124f219 100644
--- a/jax/monitoring.py
+++ b/jax/monitoring.py
@@ -26,7 +26,9 @@
     record_event_duration_secs as record_event_duration_secs,
     record_event_time_span as record_event_time_span,
     record_event as record_event,
+    record_scalar as record_scalar,
     register_event_duration_secs_listener as register_event_duration_secs_listener,
     register_event_listener as register_event_listener,
     register_event_time_span_listener as register_event_time_span_listener,
+    register_scalar_listener as register_scalar_listener,
 )
diff --git a/tests/monitoring_test.py b/tests/monitoring_test.py
index 52b53895c2cc..89c7148a2a42 100644
--- a/tests/monitoring_test.py
+++ b/tests/monitoring_test.py
@@ -29,7 +29,7 @@ def tearDown(self):

   def test_record_event(self):
     events = []
-    counters = {} # Map event names to frequency.
+    counters = {}  # Map event names to frequency.
     def increment_event_counter(event):
       if event not in counters:
         counters[event] = 0
@@ -48,7 +48,7 @@ def increment_event_counter(event):
                                     "test_common_event": 2})

   def test_record_event_durations(self):
-    durations = {} # Map event names to frequency.
+    durations = {}  # Map event names to frequency.
def increment_event_duration(event, duration): if event not in durations: durations[event] = 0. @@ -62,6 +62,30 @@ def increment_event_duration(event, duration): self.assertDictEqual(durations, {"test_short_event": 3, "test_long_event": 10}) + def test_record_scalar(self): + observed_keys = [] + observed_values = [] + + monitoring.register_scalar_listener( + lambda key, _: observed_keys.append(key), + ) + monitoring.register_scalar_listener( + lambda _, value: observed_values.append(value), + ) + + monitoring.record_scalar("test_unique_event", 1) + monitoring.record_scalar("test_common_event", 2.5) + monitoring.record_scalar("test_common_event", 5e5) + + self.assertListEqual( + observed_keys, + ["test_unique_event", "test_common_event", "test_common_event"], + ) + self.assertListEqual( + observed_values, + [1, 2.5, 5e5], + ) + def test_unregister_exist_callback_success(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() callback = lambda event, durations: None From 80061ad4c433e419961c7c6d40e3d0e5bc4d24b4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 16:54:23 -0700 Subject: [PATCH 0277/1769] Add vma rules for pmin and pmax PiperOrigin-RevId: 741685454 --- jax/_src/lax/parallel.py | 48 +++++++++++++++++++++++++++++++++-- jax/experimental/shard_map.py | 37 +++------------------------ tests/shard_map_test.py | 10 ++++++++ 3 files changed, 59 insertions(+), 36 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 3ef0a2520378..8fc8c336d61a 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -203,6 +203,7 @@ def pmax(x, axis_name, *, axis_index_groups=None): _validate_reduce_axis_index_groups(axis_index_groups) leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + leaves = map(partial(insert_collective_pbroadcast, axis_name), leaves) out_flat = pmax_p.bind(*leaves, axes=axis_name, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -233,6 +234,7 @@ def pmin(x, axis_name, *, axis_index_groups=None): _validate_reduce_axis_index_groups(axis_index_groups) leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + leaves = map(partial(insert_collective_pbroadcast, axis_name), leaves) out_flat = pmin_p.bind(*leaves, axes=axis_name, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -803,6 +805,48 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): ] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +def _psum2_abstract_eval(name, *args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return psum_p.abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + + assert isinstance(axes, tuple) + _check_axis_names(axes) + arg_vma = [a.vma for a in args] + # If intersection between arg_vma and axes is empty, error + if any(not set(axes) & a for a in arg_vma): + raise ValueError( + f"Collective {name} must be applied to a device-varying " + f"type, but got {arg_vma} for collective acting " + f"over axis name {axes}. 
Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) + if axis_index_groups is not None: + if len(pos_axes) != 0: + raise ValueError( + "axis_index_groups can only be used with reductions over " + f"named axes, but got: {axes}") + core.check_avals_context_mesh(args, 'all_reduce') + out_avals = [ + core.ShapedArray( + lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), + vma=frozenset(a for a in arg.vma if a not in named_axes)) + for arg in args + ] + return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} + +# TODO(yashkatariya): Replace this with _psum2_abstract_eval +def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return _allreduce_effectful_abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + return _psum2_abstract_eval(name, *args, axes=axes, + axis_index_groups=axis_index_groups) + def _check_axis_names(axes): named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) axis_env = core.get_axis_env() @@ -902,7 +946,7 @@ def broadcast_positional(ct, arg): pmax_p = core.Primitive('pmax') pmax_p.multiple_results = True pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max)) -pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmax_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmax')) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max)) batching.fancy_primitive_batchers[pmax_p] = \ @@ -913,7 +957,7 @@ def broadcast_positional(ct, arg): pmin_p = core.Primitive('pmin') pmin_p.multiple_results = True pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min)) -pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmin_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmin')) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min)) batching.fancy_primitive_batchers[pmin_p] = \ diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 44c2b569f947..8e2d93af2639 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1072,40 +1072,8 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): psum2_p = core.Primitive('psum2') psum2_p.multiple_results = True psum2_p.def_impl(lax_parallel.psum_p.impl) - -def _psum2_abstract_eval(*args, axes, axis_index_groups): - if not config.varying_axes_in_types.value: - return lax_parallel.psum_p.abstract_eval( - *args, axes=axes, axis_index_groups=axis_index_groups) - - assert isinstance(axes, tuple) - lax_parallel._check_axis_names(axes) - arg_vma = [a.vma for a in args] - if any(not set(axes) & a for a in arg_vma): - raise ValueError( - "Collective psum must be applied to a device-varying " - f"type, but got {arg_vma} for collective acting " - f"over axis name {axes}. 
Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - - named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) - pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) - if axis_index_groups is not None: - if len(pos_axes) != 0: - raise ValueError( - "axis_index_groups can only be used with reductions over " - f"named axes, but got: {axes}") - core.check_avals_context_mesh(args, 'all_reduce') - out_avals = [ - core.ShapedArray( - lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, - sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), - vma=frozenset(a for a in arg.vma if a not in named_axes)) - for arg in args - ] - return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} -psum2_p.def_effectful_abstract_eval(_psum2_abstract_eval) +psum2_p.def_effectful_abstract_eval( + partial(lax_parallel._psum2_abstract_eval, psum2_p.name)) mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) batching.fancy_primitive_batchers[psum2_p] = \ @@ -1135,6 +1103,7 @@ def _pbroadcast_abstract_eval(*args, axes, axis_index_groups): return args assert isinstance(axes, tuple) arg_vma = [a.vma for a in args] + # If there is intersection between arg_vma and axes, error if any(set(axes) & a for a in arg_vma): raise ValueError( "Collective pbroadcast must be applied to a " diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index c1923f5b0ae3..36966fde2a90 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2685,6 +2685,16 @@ def test_pmax(self): )(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + @config.varying_axes_in_types(True) + def test_pmax_vma_in_types(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + f = jax.jit(shard_map(lambda x: jax.lax.pmax(x, 'i'), mesh=mesh, + in_specs=P(), out_specs=P())) + jaxpr = f.trace(x).jaxpr + self.assertIn("pbroadcast[axes=('i',)", str(jaxpr)) + f(x) # doesn't crash + @config.varying_axes_in_types(True) def test_mul_with_vma_in_types(self): mesh = jtu.create_mesh((2,), ('x',)) From 7ca50844f3d66ab2b158e22b76fcc62e4406f867 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 21:42:42 -0700 Subject: [PATCH 0278/1769] Fix an edge-case in reshape sharding rule where the last splitting/merging dim was `1`. PiperOrigin-RevId: 741740811 --- jax/_src/lax/lax.py | 7 ++++++- tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c2f0876ce932..fd956136ccd3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6977,11 +6977,16 @@ def _split_on_one_axis(op_shape, new_sizes, name): ' the sharding of the output via the `sharding` argument of' f' jax.lax.reshape. 
Got operand.shape={op_shape} and {new_sizes=}') temp = [new_sizes[j]] - while math.prod(temp) != op_shape[i]: + next_j = j + 1 + while (math.prod(temp) != op_shape[i] or + (next_j < len(new_sizes) and new_sizes[next_j] == 1)): if math.prod(temp) > op_shape[i]: return False, [] j += 1 + if j >= len(new_sizes): + return False, [] temp.append(new_sizes[j]) + next_j += 1 out.append(temp) i += 1 j += 1 diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b49ba19c72dc..0b2daee8ccff 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5508,6 +5508,18 @@ def h2(x, y): ('4', (1, 4, 1, 6, 1), (1, 4, 6), P(None, 'x', None, None, None), P(None, 'x', None), False), ('5', (4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + ('6', (1024, 4096), (1024, 2048, 2, 1, 1, 1, 1), + P('x', None), P('x', None, None, None, None, None, None), False), + ('7', (1024, 4096, 32), (1024, 2048, 2, 1, 1, 32), + P('x', None, None), P('x', None, None, None, None, None), False), + ('8', (1024, 4096), (1024, 1, 1, 4096), + P('x', None), P('x', None, None, None), False), + ('9', (1024, 4096), (1024, 1, 1, 4096), + P(None, 'x'), P(None, None, None, 'x'), False), + ('10', (1024, 2048, 2, 1, 1, 1), (1024, 4096), + P('x', None, None, None, None, None), P('x', None), False), + ('11', (1024, 2048, 2, 1, 1, 1), (1024, 4096), + P(None, 'x', None, None, None, None), P(None, 'x'), False), ) @jtu.with_user_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, @@ -5519,6 +5531,8 @@ def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, @partial(jax.jit, static_argnums=1) def f(x, new_sharding): y = lax.reshape(x, dst_shape, out_sharding=new_sharding) + self.assertEqual(y.aval.sharding.spec, dst_spec) + self.assertEqual(y.shape, dst_shape) y = y * 2 self.assertEqual(y.aval.sharding.spec, dst_spec) return y From e7ec418eba9ada336f755613948cbdf4a9e97d59 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 29 Mar 2025 05:19:04 -0700 Subject: [PATCH 0279/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f50746ab3144d0bf59c8e5c2dcfb2e09e56338d0. PiperOrigin-RevId: 741809075 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 43bba2fcc903..f0a33d4c5e55 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "edfd919df316d687b2d3b08bbc8d9c32f4bcc1c4" -XLA_SHA256 = "d82a7174a8a129180b180b08f5eedfa5fe6ff19fbd46dc11dae8cf64d87dfbf9" +XLA_COMMIT = "f50746ab3144d0bf59c8e5c2dcfb2e09e56338d0" +XLA_SHA256 = "e4935a201c105a705d2a26c718663f9a7073f8a1d337c0e7eb885e2e2480797d" def repo(): tf_http_archive( From 5fda4c1b0e49761d71ce4addec80cdc6b479d2e7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 30 Mar 2025 04:43:56 -0700 Subject: [PATCH 0280/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8df9390dc9444d900c7c7f2c123f23b549adf8e3. 
PiperOrigin-RevId: 741998725 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f0a33d4c5e55..8b3ddfde019b 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f50746ab3144d0bf59c8e5c2dcfb2e09e56338d0" -XLA_SHA256 = "e4935a201c105a705d2a26c718663f9a7073f8a1d337c0e7eb885e2e2480797d" +XLA_COMMIT = "8df9390dc9444d900c7c7f2c123f23b549adf8e3" +XLA_SHA256 = "8e97c395d1e50a49fab386ccc7da1f78dc86bf670b20a892656e2e75bbf64f0e" def repo(): tf_http_archive( From a865b4e4370d1301325db64005b92aacbf4c8c7a Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Sun, 30 Mar 2025 10:50:05 -0700 Subject: [PATCH 0281/1769] [mgpu] Register the mosaic_gpu dialect regardless of warpgroup/lane lowering. In `mgpu.bitwidth()` mosaic_gpu types are being checked even in Lane lowering which fails. PiperOrigin-RevId: 742044332 --- jax/_src/pallas/mosaic_gpu/pallas_call_registration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 6dc958edbc53..1d4be26187ce 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -61,8 +61,7 @@ def pallas_call_lowering( thread_semantics = compiler_params.get("mosaic_gpu", {}).get( "thread_semantics", mgpu.ThreadSemantics.Lane ) - if thread_semantics == mgpu.ThreadSemantics.Warpgroup: - mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error lowering_result = lowering.lower_pipelined_jaxpr_to_module( grid_mapping, From 0edd715e96850f0b2fd2fc13685fde1e426b603a Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Sun, 30 Mar 2025 16:11:49 -0700 Subject: [PATCH 0282/1769] [mgpu/pallas] Expose WGMMA_TRANSPOSED layout PiperOrigin-RevId: 742084936 --- jax/_src/pallas/mosaic_gpu/primitives.py | 4 ++++ jax/experimental/mosaic/gpu/__init__.py | 1 + 2 files changed, 5 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index a9bd91b26622..48a4cae62824 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -911,6 +911,7 @@ class Layout(enum.Enum): WGMMA_ROW = enum.auto() #: [n] matrix, where n % 8 == 0. 
WGMMA_COL = enum.auto() + WGMMA_TRANSPOSED = enum.auto() WG_SPLAT = enum.auto() WG_STRIDED = enum.auto() @@ -924,6 +925,9 @@ def check_no_args(): raise ValueError(f"Can't instantiate {self} with arguments.") match self: + case Layout.WGMMA_TRANSPOSED: + check_no_args() + return mgpu.WGMMA_TRANSPOSED_LAYOUT case Layout.WGMMA: check_no_args() return mgpu.WGMMA_LAYOUT diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 867fd84b8b3c..afc87b5d96fa 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -55,6 +55,7 @@ WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, + WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, WGMMARowFragLayout as WGMMARowFragLayout, WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, From 29bd01f8307449205a4894048927e6669dba2929 Mon Sep 17 00:00:00 2001 From: Amir Samani Date: Tue, 18 Mar 2025 17:21:40 -0700 Subject: [PATCH 0283/1769] add reduction support in copy_smem_to_gmem --- jax/_src/pallas/mosaic_gpu/primitives.py | 7 ++ jax/experimental/mosaic/gpu/launch_context.py | 77 ++++++++++++++++--- jaxlib/mosaic/gpu/runtime.cc | 48 ++++++++---- tests/pallas/mosaic_gpu_test.py | 23 ++++++ 4 files changed, 132 insertions(+), 23 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 48a4cae62824..e3f8e4c03f75 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -185,6 +185,7 @@ def _copy_smem_to_gmem_lowering( dst_transforms_treedef, has_user_predicate, commit_group, + reduction_op: Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] | None, ): if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] @@ -215,6 +216,7 @@ def _copy_smem_to_gmem_lowering( dst_ref=dst, predicate=predicate, arrive=commit_group, + reduction_op=reduction_op, **copy_params, ) return () @@ -293,6 +295,9 @@ def copy_smem_to_gmem( predicate: jax.Array | None = None, *, commit_group: bool = True, + reduction_op: Literal[ + "add","min","max","inc","dec","and","or","xor" + ] | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -304,6 +309,7 @@ def copy_smem_to_gmem( commit_group: If ``True``, this and any previously uncommitted copies are committed to a group and can be awaited jointly via :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. 
+    reduction_op: If set, perform the specified reduction op when copying to GMEM.
 
   See also:
     :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
@@ -331,6 +337,7 @@ def copy_smem_to_gmem(
       dst_transforms_treedef=dst_transforms_treedef,
       has_user_predicate=predicate is not None,
       commit_group=commit_group,
+      reduction_op=reduction_op,
   )
   return None
diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py
index ce432f26dac2..41c15bc5492e 100644
--- a/jax/experimental/mosaic/gpu/launch_context.py
+++ b/jax/experimental/mosaic/gpu/launch_context.py
@@ -19,9 +19,10 @@
 import enum
 import functools
 import math
-from typing import Any
+from typing import Any, Literal
 
 from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
+from jax._src import lib as jaxlib
 from jaxlib.mlir import ir
 from jaxlib.mlir.dialects import arith
 from jaxlib.mlir.dialects import func
@@ -309,6 +310,9 @@ def _get_tma_desc(
       gmem_transform: tuple[MemRefTransform, ...],
       transformed_slice_shape: tuple[int, ...],
       swizzle: int | None,
+      reduction_op: Literal[
+          "add","min","max","inc","dec","and","or","xor"
+      ] | None,
   ):
     tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform)
     if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
@@ -337,10 +341,38 @@ def init_tma_desc(host_ptr):
         )
         # TODO(apaszke): Better verification (e.g. slice is non-zero)
         # TODO(apaszke): We always know strides statically.
+        if jaxlib.version < (0, 5, 4):
+          dtype_or_bitwidth = c(utils.bitwidth(ref_ty.element_type), i64)
+        else:
+          if isinstance(ref_ty.element_type, ir.IntegerType):
+            if reduction_op is not None:
+              raise ValueError(
+                  f"TMA with reduction_op={reduction_op} is not supported with Integers"
+              )
+            bitwidth = utils.bitwidth_impl(ref_ty.element_type)
+            if bitwidth == 4:
+              tma_dtype = 0
+            elif bitwidth == 8:
+              tma_dtype = 1
+            elif bitwidth == 16:
+              tma_dtype = 2
+            elif bitwidth == 32:
+              tma_dtype = 3
+            elif bitwidth == 64:
+              tma_dtype = 4
+          elif ir.F16Type.isinstance(ref_ty.element_type):
+            tma_dtype = 5
+          elif ir.F32Type.isinstance(ref_ty.element_type):
+            tma_dtype = 6
+          elif ir.BF16Type.isinstance(ref_ty.element_type):
+            tma_dtype = 7
+          else:
+            raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}")
+          dtype_or_bitwidth = c(tma_dtype, i64)
         args = [
             host_ptr,
             base_ptr,
-            c(utils.bitwidth(ref_ty.element_type), i64),
+            dtype_or_bitwidth,
             c(rank, i64),
             utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
             utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
@@ -375,6 +407,9 @@ def async_copy(
       collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
       partitioned: int | None = None,
       predicate: ir.Value | None = None,  # Should select 0 or 1 threads from the WG.
+      reduction_op: Literal[
+          "add","min","max","inc","dec","and","or","xor"
+      ] | None = None,
   ):
     """Initiates an async copy between GMEM and SMEM.
@@ -453,6 +488,13 @@
           " multiple of 16 bytes"
       )
 
+    if reduction_op is not None and jaxlib.version < (0, 5, 4):
+      raise ValueError("TMA with reduction is only supported with jaxlib >= 0.5.4")
+    if reduction_op is not None and not isinstance(gmem_ref_ty.element_type, ir.FloatType):
+      raise ValueError("TMA with reduction is only supported with float dtype")
+    if reduction_op is not None and reduction_op != "add":
+      raise ValueError("TMA with reduction is only supported with add operation")
+
    # NOTE: TMA supports OOB indices, so we skip the check.
    base_indices, slice_shape, is_squeezed = utils.parse_indices(
        gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False
@@ -597,7 +639,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
      multicast_mask = None
 
    tma_desc = self._get_tma_desc(
-        gmem_ref, gmem_transform, tuple(slice_shape), swizzle,
+        gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op,
    )
 
    # We construct TMA descriptors in column-major order.
@@ -641,6 +683,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
        )
        barrier_ptr = barrier.get_ptr()
        with uniform_ctx():
+          assert reduction_op is None
          if collective_size > 1 and partitioned is not None:
            if predicate is None:
              predicate = c(1, ir.IntegerType.get_signless(1))
@@ -679,12 +722,28 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
      )
    else:
      assert multicast_mask is None
-      with uniform_ctx():
-        nvvm.cp_async_bulk_tensor_global_shared_cta(
-            tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
-        )
-      if arrive:
-        nvvm.cp_async_bulk_commit_group()
+      if reduction_op is not None:
+        with uniform_ctx():
+          if predicate is None:
+            predicate = c(1, ir.IntegerType.get_signless(1))
+          rank = len(slice_shape)
+          idx_operands = ",".join(f"${i}" for i in range(3, 3 + rank))
+          llvm.inline_asm(
+              ir.Type.parse("!llvm.void"),
+              [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices],
+              f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];",
+              "b,r,l" + ",r" * rank,
+              has_side_effects=True,
+          )
+          if arrive:
+            nvvm.cp_async_bulk_commit_group()
+      else:
+        with uniform_ctx():
+          nvvm.cp_async_bulk_tensor_global_shared_cta(
+              tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
+          )
+        if arrive:
+          nvvm.cp_async_bulk_commit_group()
 
  def await_async_copy(
      self, allow_groups: int, await_read_only: bool = False
diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc
index ad3cd0e19644..fd452e781c72 100644
--- a/jaxlib/mosaic/gpu/runtime.cc
+++ b/jaxlib/mosaic/gpu/runtime.cc
@@ -22,7 +22,7 @@ limitations under the License.
 extern "C" {
 
 void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
-                              int64_t elem_bitwidth, int64_t rank,
+                              int64_t elem_type, int64_t rank,
                               int64_t *sizes, int64_t *strides,
                               int64_t swizzle_bytes, int64_t *window_shape) {
   if (((uintptr_t)tma_desc) % 64 != 0) {
@@ -32,6 +32,39 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
     abort();
   }
 
+  CUtensorMapDataType data_type;
+  int64_t elem_bitwidth;
+  // Types are defined in LaunchContext._get_tma_desc().
+  if (elem_type == 0){
+    // this is for int4s
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+    elem_bitwidth = 4;
+  } else if (elem_type == 1){
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+    elem_bitwidth = 8;
+  } else if (elem_type == 2){
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+    elem_bitwidth = 16;
+  } else if (elem_type == 3){
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+    elem_bitwidth = 32;
+  } else if (elem_type == 4){
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+    elem_bitwidth = 64;
+  } else if (elem_type == 5){
+    data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+    elem_bitwidth = 16;
+  } else if (elem_type == 6){
+    data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
+    elem_bitwidth = 32;
+  } else if (elem_type == 7){
+    data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+    elem_bitwidth = 16;
+  } else {
+    fprintf(stderr, "Unsupported element type: %ld\n", elem_type);
+    abort();
+  }
+
+  // Pack 4 bit types in 8 bit pairs.
int64_t elem_bytewidth; if (elem_bitwidth < 8) { @@ -54,19 +87,6 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, elem_bytewidth = elem_bitwidth / 8; } - CUtensorMapDataType data_type; - if (elem_bytewidth == 1) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if (elem_bytewidth == 2) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if (elem_bytewidth == 4) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if (elem_bytewidth == 8) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else { - fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth); - abort(); - } if (rank < 1 || rank > 5) { fprintf(stderr, "Rank must be in [1, 5], but got %ld\n", rank); abort(); diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c8013f634c67..040f994d0b2b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -379,6 +379,29 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32) + def test_copy_smem_to_gmem_reduction(self, dtype): + @functools.partial( + pl.pallas_call, + grid=(200,), + in_specs=[pl.BlockSpec((128,), lambda *i: i), pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct([128], dtype), + scratch_shapes=[plgpu.SMEM((128,), dtype)], + input_output_aliases={1:0} + ) + def kernel(x_ref, o_ref_gmem, o_ref_gmem_alias, scratch_ref): + del o_ref_gmem_alias + scratch_ref[...] = x_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_ref.at[...], o_ref_gmem.at[...], reduction_op="add") + plgpu.wait_smem_to_gmem(0) + x = jnp.ones(200 * 128).astype(dtype) # 200 blocks + output = jnp.zeros(128).astype(dtype) + output = kernel(x, output) + output_val = x.reshape(-1, 128).sum(axis=0) + np.testing.assert_array_equal(output, output_val) + @parameterized.named_parameters( {"testcase_name": "1d_none", "shape": (256,), "indexers": (slice(0, 128), slice(None, 32))}, From 10425ae6a9ebd77ecd0de775f3f758b8978e18bd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Mar 2025 14:32:59 -0700 Subject: [PATCH 0284/1769] jax.core: finalize a number of deprecations for JAX v0.6.0 --- CHANGELOG.md | 8 ++++ jax/core.py | 125 ++++++++++++++++----------------------------------- 2 files changed, 47 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5785f6193065..68450dca4057 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or {mod}`jax.tree_util`. + * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, + `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, + `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `get_referent`, + `join_effects`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, + `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most + have no public replacement, though a few are available at {mod}`jax.extend.core`. 
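  To make the migration concrete, a small sketch of the replacement this entry points to (the primitive name is made up; only the import path changes):

```
# Before (removed in JAX v0.6.0):
#   from jax.core import Primitive
# After:
from jax.extend.core import Primitive

my_op_p = Primitive("my_op")
```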
## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/core.py b/jax/core.py index b404e66c2691..688fa14d9ccf 100644 --- a/jax/core.py +++ b/jax/core.py @@ -81,75 +81,21 @@ from jax._src import core as _src_core _deprecations = { - # Added 2024-12-16 - "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.ClosedJaxpr), - "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Jaxpr), - "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.JaxprEqn), - "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Literal), - "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Primitive), - "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Token), - "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Var), # Added 2024-12-11 "axis_frame": ("jax.core.axis_frame is deprecated.", _src_core.axis_frame), "AxisName": ("jax.core.AxisName is deprecated.", _src_core.AxisName), - "AxisSize": ("jax.core.AxisSize is deprecated.", _src_core.AxisSize), "ConcretizationTypeError": ("jax.core.ConcretizationTypeError is deprecated; " "use jax.errors.ConcretizationTypeError.", _src_core.ConcretizationTypeError), - "EvalTrace": ("jax.core.EvalTrace is deprecated.", _src_core.EvalTrace), - "InDBIdx": ("jax.core.InDBIdx is deprecated.", _src_core.InDBIdx), - "InputType": ("jax.core.InputType is deprecated.", _src_core.InputType), - "MapPrimitive": ("jax.core.MapPrimitive is deprecated.", _src_core.MapPrimitive), - "OpaqueTraceState": ("jax.core.OpaqueTraceState is deprecated.", _src_core.OpaqueTraceState), - "OutDBIdx": ("jax.core.OutDBIdx is deprecated.", _src_core.OutDBIdx), - "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING is deprecated.", - _src_core.TRACER_LEAK_DEBUGGER_WARNING), "call_p": ("jax.core.call_p is deprecated. Use jax.extend.core.primitives.call_p", _src_core.call_p), "closed_call_p": ("jax.core.closed_call_p is deprecated. 
Use jax.extend.core.primitives.closed_call_p", _src_core.closed_call_p), - "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify), - "dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents), - "escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.", - _src_core.escaped_tracer_error), - "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.", - _src_core.extend_axis_env_nd), "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval), - "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent), - "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects), - "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.", - _src_core.leaked_tracer_error), - "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers is deprecated.", - _src_core.maybe_find_leaked_tracers), - "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings is deprecated." - " It is unused as of jax v0.4.36.", - _src_core.raise_to_shaped_mappings), - "reset_trace_state": ("jax.core.reset_trace_state is deprecated.", - _src_core.reset_trace_state), - "str_eqn_compact": ("jax.core.str_eqn_compact is deprecated.", _src_core.str_eqn_compact), - "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty is deprecated.", - _src_core.substitute_vars_in_output_ty), "trace_state_clean": ("jax.core.trace_state_clean is deprecated.", _src_core.trace_state_clean), "typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck), - "typecompat": ("jax.core.typecompat is deprecated.", _src_core.typecompat), "typematch": ("jax.core.typematch is deprecated.", _src_core.typematch), - "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr is deprecated.", - _src_core.used_axis_names_jaxpr), # Added 2024-12-10 "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", _src_core.full_lower), @@ -158,54 +104,61 @@ _src_core.jaxpr_as_fun), "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", _src_core.lattice_join), - "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.raise_to_shaped), + # Finalized 2025-03-25 for JAX v0.6.0; remove after 2025-06-25 + "AxisSize": ("jax.core.AxisSize was removed in JAX v0.6.0.", None), + "ClosedJaxpr": ("jax.core.ClosedJaxpr was removed in JAX v0.6.0. Use jax.extend.core.ClosedJaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "EvalTrace": ("jax.core.EvalTrace was removed in JAX v0.6.0.", None), + "InDBIdx": ("jax.core.InDBIdx was removed in JAX v0.6.0.", None), + "InputType": ("jax.core.InputType was removed in JAX v0.6.0.", None), + "Jaxpr": ("jax.core.Jaxpr was removed in JAX v0.6.0. Use jax.extend.core.Jaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "JaxprEqn": ("jax.core.JaxprEqn was removed in JAX v0.6.0. Use jax.extend.core.JaxprEqn instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "Literal": ("jax.core.Literal was removed in JAX v0.6.0. 
Use jax.extend.core.Literal instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "MapPrimitive": ("jax.core.MapPrimitive was removed in JAX v0.6.0.", None), + "OpaqueTraceState": ("jax.core.OpaqueTraceState was removed in JAX v0.6.0.", None), + "OutDBIdx": ("jax.core.OutDBIdx was removed in JAX v0.6.0.", None), + "Primitive": ("jax.core.Primitive was removed in JAX v0.6.0. Use jax.extend.core.Primitive instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "Token": ("jax.core.Token was removed in JAX v0.6.0. Use jax.extend.core.Token instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING was removed in JAX v0.6.0.", None), + "Var": ("jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "concrete_aval": ("jax.core.concrete_aval was removed in JAX v0.6.0.", None), + "dedup_referents": ("jax.core.dedup_referents was removed in JAX v0.6.0.", None), + "escaped_tracer_error": ("jax.core.escaped_tracer_error was removed in JAX v0.6.0.", None), + "extend_axis_env_nd": ("jax.core.extend_axis_env_nd was removed in JAX v0.6.0.", None), + "get_referent": ("jax.core.get_referent was removed in JAX v0.6.0.", None), + "join_effects": ("jax.core.join_effects was removed in JAX v0.6.0.", None), + "leaked_tracer_error": ("jax.core.leaked_tracer_error was removed in JAX v0.6.0.", None), + "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers was removed in JAX v0.6.0.", None), + "raise_to_shaped": ("jax.core.raise_to_shaped was removed in JAX v0.6.0. It is a no-op as of JAX v0.4.36.", None), + "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings was removed in JAX v0.6.0." 
+ " It is unused as of jax v0.4.36.", None), + "reset_trace_state": ("jax.core.reset_trace_state was removed in JAX v0.6.0.", None), + "str_eqn_compact": ("jax.core.str_eqn_compact was removed in JAX v0.6.0.", None), + "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty was removed in JAX v0.6.0.", None), + "typecompat": ("jax.core.typecompat was removed in JAX v0.6.0.", None), + "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr was removed in JAX v0.6.0.", None), } import typing if typing.TYPE_CHECKING: AxisName = _src_core.AxisName - AxisSize = _src_core.AxisSize - ClosedJaxpr = _src_core.ClosedJaxpr ConcretizationTypeError = _src_core.ConcretizationTypeError - EvalTrace = _src_core.EvalTrace - InDBIdx = _src_core.InDBIdx - InputType = _src_core.InputType - Jaxpr = _src_core.Jaxpr - JaxprEqn = _src_core.JaxprEqn - Literal = _src_core.Literal - MapPrimitive = _src_core.MapPrimitive - OpaqueTraceState = _src_core.OpaqueTraceState - OutDBIdx = _src_core.OutDBIdx - Primitive = _src_core.Primitive - Token = _src_core.Token - TRACER_LEAK_DEBUGGER_WARNING = _src_core.TRACER_LEAK_DEBUGGER_WARNING - Var = _src_core.Var axis_frame = _src_core.axis_frame call_p = _src_core.call_p closed_call_p = _src_core.closed_call_p - concrete_aval = _src_core.abstractify - dedup_referents = _src_core.dedup_referents - escaped_tracer_error = _src_core.escaped_tracer_error - extend_axis_env_nd = _src_core.extend_axis_env_nd full_lower = _src_core.full_lower get_type = _src_core.get_aval - get_referent = _src_core.get_referent jaxpr_as_fun = _src_core.jaxpr_as_fun - join_effects = _src_core.join_effects lattice_join = _src_core.lattice_join - leaked_tracer_error = _src_core.leaked_tracer_error - maybe_find_leaked_tracers = _src_core.maybe_find_leaked_tracers - raise_to_shaped = _src_core.raise_to_shaped - raise_to_shaped_mappings = _src_core.raise_to_shaped_mappings - reset_trace_state = _src_core.reset_trace_state - str_eqn_compact = _src_core.str_eqn_compact - substitute_vars_in_output_ty = _src_core.substitute_vars_in_output_ty trace_state_clean = _src_core.trace_state_clean typecheck = _src_core.typecheck - typecompat = _src_core.typecompat typematch = _src_core.typematch - used_axis_names_jaxpr = _src_core.used_axis_names_jaxpr else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) From aee27854f056bdabdb25d2af0eac5a2e5b35f63b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 31 Mar 2025 00:53:30 -0700 Subject: [PATCH 0285/1769] [Pallas:MGPU] Only allow small tiling in Pallas programs This is part of the removal of support for large MMA tiling in Mosaic GPU. It should also let us simplify some of the transpose transforms that are no longer necessary, but I decided to separate this. 
PiperOrigin-RevId: 742168801 --- jax/_src/pallas/mosaic_gpu/lowering.py | 4 +- jax/_src/pallas/mosaic_gpu/primitives.py | 6 +- .../pallas/ops/gpu/attention_mgpu.py | 4 +- tests/pallas/mosaic_gpu_test.py | 62 +++++++++---------- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e99feb4dc144..daa718ff1ff2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1129,7 +1129,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): + if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle @@ -1188,7 +1188,7 @@ def _swap_lowering_rule( x_smem, transforms = _handle_indexing(x_smem, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): + if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") old_value = mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 48a4cae62824..d0632728f2b6 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -93,7 +93,7 @@ def _load_p_lowering_rule( match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): + if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle, @@ -739,7 +739,7 @@ def _wgmma_lowering( match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize - if tiling != (64, swizzle_elems): + if tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") @@ -790,7 +790,7 @@ def _wgmma_lowering( swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize if rhs_swizzle != lhs_swizzle: raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle") - if rhs_tiling != (swizzle_elems, swizzle_elems): + if rhs_tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") if rhs_transpose: diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 534da419ed3b..48a0d18459cb 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -193,7 +193,7 @@ def kv_loop(kv_step, _): def entry(q_ref, k_ref, v_ref, out_ref): compute_wgs = 2 - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) qo_scratch = plgpu.SMEM( (compute_wgs, block_q, head_dim), jnp.float16, @@ -263,7 +263,7 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): if rem: raise NotImplementedError(f"{q_seq_len=} must be a 
multiple of {block_q * 2=}") - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4)) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c8013f634c67..9c5795e49da7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -523,7 +523,7 @@ def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): index_map=lambda i, j: (i, j), memory_space=plgpu.SMEM, transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), ), @@ -584,7 +584,7 @@ def kernel(x_ref, o_ref, barrier_ref): (128, 128), lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), memory_space=plgpu.SMEM, @@ -604,7 +604,7 @@ def kernel(x_ref, o_ref, barrier_ref): def test_scoped_copy_with_transforms(self): self.skip_if_wg_semantics() - ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) + ts = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) @@ -639,7 +639,7 @@ def kernel(x_ref, o_ref, barrier_ref): (2, 128, 128), lambda: (0, 0, 0), transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), plgpu.SwizzleTransform(128), ), @@ -749,8 +749,7 @@ def compute(acc_ref): (k, n), lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ), ), @@ -881,8 +880,7 @@ def test_print_wgmma_tiled_layout(self): shape, lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), ) ], @@ -1061,8 +1059,7 @@ def test_swizzled_blockspec_shapes(self): (128, 64), lambda *i: i, transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ) @functools.partial( @@ -1243,8 +1240,7 @@ def test_tile_slicing(self): shape = (256, 128) block_spec = plgpu.GPUBlockSpec( transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ) ) @functools.partial( @@ -1297,7 +1293,7 @@ def rotate(src, dst): (128, 128), lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ) @@ -1431,7 +1427,7 @@ class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) @functools.partial( pl.pallas_call, in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)], @@ -1488,21 +1484,21 @@ def _epilogue(): lhs_spec = plgpu.GPUBlockSpec( lhs_spec.block_shape, lhs_spec.index_map, transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ) ) rhs_spec = plgpu.GPUBlockSpec( rhs_spec.block_shape, rhs_spec.index_map, transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), 
plgpu.SwizzleTransform(128), ) ) out_spec = plgpu.GPUBlockSpec( out_spec.block_shape, out_spec.index_map, transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ) ) @@ -1546,7 +1542,7 @@ def scope(acc_ref): b_shape = b_shape[::-1] b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),) + rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) if rhs_transpose: rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) res = pl.pallas_call( @@ -1556,7 +1552,7 @@ def scope(acc_ref): (64, 128), lambda i, j: (i, j), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -1585,7 +1581,7 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = pl.pallas_call( kernel, in_specs=[ @@ -1608,7 +1604,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = pl.pallas_call( kernel, in_specs=[ @@ -1639,14 +1635,14 @@ def scope(acc_ref): plgpu.GPUBlockSpec( (2, 64, 128), lambda: (0, 0, 0), transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ), plgpu.GPUBlockSpec( (2, 128, 192), lambda: (0, 0, 0), transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ), @@ -1676,7 +1672,7 @@ def scope(acc_ref): (64, 128), lambda i, j: (i, j), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -1684,7 +1680,7 @@ def scope(acc_ref): (128, 128), lambda *i: i, transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -1820,7 +1816,7 @@ def body(step, _): @parameterized.parameters( ((),), - ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),), + ((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),), ) def test_emit(self, transforms): num_steps = 4 @@ -2005,7 +2001,7 @@ def kernel_body(a_smem, b_smem): (tile_m, tile_k), lambda k: (pid_m, k), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -2013,7 +2009,7 @@ def kernel_body(a_smem, b_smem): (tile_k, tile_n), lambda k: (k, pid_n), transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -2039,7 +2035,7 @@ def kernel_body(a_smem, b_smem): (tile_m, tile_n), lambda m, n: (m, n), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -2339,7 +2335,7 @@ def test_realistic_matmul_with_cluster(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n transforms = ( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ) 
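A quick way to read the test updates in this hunk and the ones around it: under the new rule, the only tiling the lowering accepts is `(8, swizzle_bytes // dtype.itemsize)`. A self-checking sketch (the helper is illustrative, not a JAX API):

```
def expected_tile(swizzle_bytes, dtype):
  # The tile shape the Pallas Mosaic GPU lowering now insists on.
  return (8, swizzle_bytes // jnp.dtype(dtype).itemsize)

assert expected_tile(128, jnp.float16) == (8, 64)
assert expected_tile(64, jnp.float16) == (8, 32)
assert expected_tile(128, jnp.float32) == (8, 32)
```

This matches the `(8, 32)` tiles paired with `SwizzleTransform(64)` for float16 in the tests above.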
@@ -2521,7 +2517,7 @@ def compute(l_smem, r_smem, o_smem):
     r = lax.axis_index("rows")
     block = plgpu.GPUBlockSpec(
         (row_block, col_block), lambda c: (r, c),
-        transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)),
+        transforms=(plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)),
     )
     plgpu.emit_pipeline(
         compute,
@@ -2572,8 +2568,8 @@ def do_wgmma(acc_ref):
         return acc_ref[...]
       o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16))
     m, n = lax.axis_index("m"), lax.axis_index("n")
-    lo_transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64))
-    r_transforms = (plgpu.TilingTransform((32, 32)), plgpu.SwizzleTransform(64))
+    lo_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64))
+    r_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64))
     plgpu.emit_pipeline(
         compute,
         grid=(l_ref.shape[1] // k_block,),

From 05e15ba032841b13bd95f684a2f4f0b57bd75ada Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Mon, 31 Mar 2025 02:49:03 -0700
Subject: [PATCH 0286/1769] [pallas:mgpu] Allow more freedom for the user to
 transform references.

Implemented untile_ref and unswizzle_ref in order to allow patterns where we
need different transform stacks over the same memref. For example, we may want
the reg->smem store to be transposed, then slice the smem->gmem copy, and maybe
load strided or print in between for sanity checking:

```
# Store registers transposed
o_smem_swizzled = plgpu.unswizzle_ref(o_smem_raw, swizzle_out)
o_smem_t = o_smem_swizzled.reshape(1, 1, config.block_n, config.block_m)
o_smem_t = plgpu.untile_ref(o_smem_t, (n, m))
o_smem_t = plgpu.transpose_ref(o_smem_t, (1, 0))
o_smem_t[...] = plgpu.layout_cast(regs, plgpu.Layout.WGMMA_TRANSPOSED)
plgpu.commit_smem()
del o_smem_t

# Now we need different transforms on the same smem to slice and async-store to gmem
o_smem = o_smem_raw.reshape(n, m // swizzle_elems, swizzle_elems)
o_smem = plgpu.unswizzle_ref(o_smem, swizzle_out)
o_smem = plgpu.tile_ref(o_smem, swizzle_out)
o_smem = o_smem.at[...]
plgpu.copy_smem_to_gmem(o_smem, o_ref.at[...])
```

Which in turn lets us write kernels like the sketch above.

PiperOrigin-RevId: 742194519
---
 jax/_src/pallas/mosaic_gpu/core.py    | 22 +++++++++++++++++-----
 jax/experimental/pallas/mosaic_gpu.py |  3 +++
 tests/pallas/mosaic_gpu_test.py       | 20 ++++++++++++++++++++
 3 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 8522bdf651f4..b0d4f23c792e 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -342,7 +342,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:
 @tree_util.register_dataclass
 @dataclasses.dataclass(frozen=True)
 class TransposeRef(state_types.Transform):
-  permutation: tuple[int, ...]
+  permutation: tuple[int, ...]
= dataclasses.field(metadata=dict(static=True)) def transform_shape(self, shape): if shape is None: @@ -370,18 +370,30 @@ def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) -def transpose_ref( - ref: pallas_core.TransformedRef | Any, - permutation: tuple[int, ...], +def transform_ref( + ref: pallas_core.TransformedRef, + transform: state_types.Transform ) -> pallas_core.TransformedRef: if not isinstance(ref, pallas_core.TransformedRef): if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): raise TypeError("ref must be a reference") ref = pallas_core.TransformedRef(ref, transforms=()) return pallas_core.TransformedRef( - ref.ref, (*ref.transforms, TransposeRef(permutation)), + ref.ref, (*ref.transforms, transform), ) +def transpose_ref( + ref: pallas_core.TransformedRef | Any, + permutation: tuple[int, ...], +) -> pallas_core.TransformedRef: + return transform_ref(ref, TransposeRef(permutation)) + +def untile_ref(ref, tiling: tuple[int, ...]) -> pallas_core.TransformedRef: + return transform_ref(ref, UntileRef(tiling)) + +def unswizzle_ref(ref, swizzle: int) -> pallas_core.TransformedRef: + return transform_ref(ref, UnswizzleRef(swizzle)) + @dataclasses.dataclass(frozen=True) class SwizzleTransform(MemoryRefTransform): diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index e4c5ffe04093..b44c86ea7a4c 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -27,7 +27,10 @@ from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform +from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref +from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref +from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 9c5795e49da7..96532488a648 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -626,6 +626,26 @@ def body(tmp_ref): x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x * 2) + def test_scoped_copy_with_user_transforms(self): + def kernel(x_ref, o_ref, barrier_ref): + def body(tmp_ref): + tmp_ref = plgpu.unswizzle_ref(tmp_ref, 128) + tmp_ref = plgpu.untile_ref(tmp_ref, (8, 32)) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + o_ref[...] = tmp_ref[...] 
* 2 + pl.run_scoped(body, plgpu.SMEM((16, 4, 8, 32), jnp.float32)) + + in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) + f = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), + in_specs=(in_spec,), + scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + ) + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(f(x), x * 2) + def test_copy_with_transforms_and_indexing(self): self.skip_if_wg_semantics() From d3ed327572e4075ee5b7b0ba3b4b9633a737a39d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 31 Mar 2025 03:47:57 -0700 Subject: [PATCH 0287/1769] [Pallas:MGPU] Remove (now) unnecessary TransposeTransforms Now that we always use small tiles, we can lay out the tiled dimension in arbitrary order so there's no need to swap them during the TMA. PiperOrigin-RevId: 742206980 --- jax/_src/pallas/mosaic_gpu/primitives.py | 5 ++--- jax/experimental/pallas/ops/gpu/attention_mgpu.py | 5 ++--- tests/pallas/mosaic_gpu_test.py | 2 -- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index d0632728f2b6..07235c2fc830 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -759,9 +759,8 @@ def _wgmma_lowering( rhs_transpose = False case ( gpu_core.UnswizzleRef(rhs_swizzle), - gpu_core.TransposeRef((1, 0, 2, 3)), # Only transpose between tiles gpu_core.UntileRef(rhs_tiling), - gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims + gpu_core.TransposeRef((1, 0)), ): rhs_transpose = True case ( @@ -794,7 +793,7 @@ def _wgmma_lowering( raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") if rhs_transpose: - b = mgpu.memref_transpose(b, (0, 1, 3, 2)) + b = mgpu.memref_transpose(b, (1, 0, 3, 2)) new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle) nvvm_dialect.wgmma_commit_group_sync_aligned() return new_acc diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 48a0d18459cb..b19e371a1eb8 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -201,7 +201,7 @@ def entry(q_ref, k_ref, v_ref, out_ref): ) k_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, - transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle), + transforms=(tiling, swizzle), ) v_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, @@ -265,7 +265,6 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) - transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4)) def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped): batch = lax.axis_index("batch") @@ -354,7 +353,7 @@ def compute_pv(acc_ref): plgpu.GPUBlockSpec( # k block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), - transforms=[tiling, transpose, swizzle]), + transforms=[tiling, swizzle]), plgpu.GPUBlockSpec( # v block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 96532488a648..965539af52ab 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1563,8 +1563,6 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) - if rhs_transpose: - rhs_transforms 
+= (plgpu.TransposeTransform((1, 0, 2, 3)),) res = pl.pallas_call( kernel, in_specs=[ From fc01058ee42cefc2502a5674eebfef79f2749ebe Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 31 Mar 2025 05:15:12 -0700 Subject: [PATCH 0288/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f4a53456b04acf9b63b3b30bd828cec29c4aa7de. PiperOrigin-RevId: 742228024 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8b3ddfde019b..d078359af86a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "8df9390dc9444d900c7c7f2c123f23b549adf8e3" -XLA_SHA256 = "8e97c395d1e50a49fab386ccc7da1f78dc86bf670b20a892656e2e75bbf64f0e" +XLA_COMMIT = "f4a53456b04acf9b63b3b30bd828cec29c4aa7de" +XLA_SHA256 = "2ee32b70af547fd13ce404d75c3fa9834bc8be46a488cd8f0caa10e9a6ec7ede" def repo(): tf_http_archive( From cb5168269119ca098d678869c74303982ec84b17 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 31 Mar 2025 07:07:12 -0700 Subject: [PATCH 0289/1769] [pallas:mosaic_gpu] Run all Mosaic GPU-specific tests under WG semantics We do skip quite a few due to missing features. I tried to make the reason for skipping clear in each case. PiperOrigin-RevId: 742252858 --- tests/pallas/mosaic_gpu_test.py | 367 +++++++++++++++++++++----------- 1 file changed, 241 insertions(+), 126 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 965539af52ab..6b1839a64580 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -85,6 +85,13 @@ def skip_if_wg_semantics(self): if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Warpgroup: self.skipTest("Not supported under WG semantics") + def kernel(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + thread_semantics=self.THREAD_SEMANTICS, + ) + return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) + def pallas_call(self, *args, **kwargs): compiler_params = dataclasses.replace( kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), @@ -975,7 +982,7 @@ def kernel(x_ref, o_ref): def test_load_scalar(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], ) @@ -987,7 +994,7 @@ def kernel(x_ref, o_ref): def test_run_scoped(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), ) def kernel(x_ref, o_ref): @@ -1005,7 +1012,7 @@ def body(tmp_ref): def test_program_id(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -1024,7 +1031,7 @@ def test_program_id_in_squashed_grid(self): # 3 CUDA grid dimensions. 
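    # Only three hardware grid axes exist, so the squashed launch packs the
    # extra logical dimensions together; `pl.program_id(i)` must still
    # recover the logical index along each of the four dimensions.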
grid = (2, 3, 4, 5) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)), out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32), @@ -1045,7 +1052,7 @@ def kernel(o_ref): def test_program_id_in_block_spec(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),), out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)), out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32), @@ -1059,7 +1066,7 @@ def kernel(x_ref, o_ref): def test_num_programs(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -1074,6 +1081,7 @@ def kernel(o_ref): ) def test_swizzled_blockspec_shapes(self): + self.skip_if_wg_semantics() spec = plgpu.GPUBlockSpec( (128, 64), @@ -1083,7 +1091,7 @@ def test_swizzled_blockspec_shapes(self): ), ) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[spec], out_specs=spec, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), @@ -1124,7 +1132,7 @@ def kernel(o_ref): def test_fori_loop_dynamic_bounds(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), grid=(1,) ) @@ -1201,8 +1209,10 @@ def body(acc): ) def test_while_loop_layout_mismatch(self): + self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. + @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(o_ref): def cond(acc): @@ -1255,8 +1265,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) - # Not testing with warpgroup semantics, because we want to enforce a layout. def test_tile_slicing(self): + # Not testing with warpgroup semantics, because we want to enforce a layout. + self.skip_if_wg_semantics() + shape = (256, 128) block_spec = plgpu.GPUBlockSpec( transforms=( @@ -1264,7 +1276,7 @@ def test_tile_slicing(self): ) ) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[block_spec], out_specs=block_spec, out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16), @@ -1289,7 +1301,7 @@ def kernel(a_ref, b_ref): a_ref[...] = jnp.ones_like(a_ref) a = np.zeros((64, 64), dtype=jnp.float32) - b = pl.pallas_call( + b = self.pallas_call( kernel, in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), @@ -1299,6 +1311,8 @@ def kernel(a_ref, b_ref): np.testing.assert_array_equal(b, np.ones_like(a)) def test_slicing(self): + self.skip_if_wg_semantics() + left = upper = slice(None, 64) right = lower = slice(64, None) # We rotate the four quadrants of the input clockwise. @@ -1317,14 +1331,16 @@ def rotate(src, dst): plgpu.SwizzleTransform(128), ), ) - f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) + f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) rotate(x, expected) np.testing.assert_array_equal(f(x), expected) def test_layout_cast(self, shape=(256, 64)): + self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. 
+ @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), ) def kernel(o_ref): @@ -1334,6 +1350,8 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), x) def test_profiler(self): + self.skip_if_wg_semantics() # Transform inference fails. + def kernel(x_ref, o_ref): with jax.named_scope("add"): with jax.named_scope("load"): @@ -1343,7 +1361,7 @@ def kernel(x_ref, o_ref): o_ref[...] = o with tempfile.TemporaryDirectory() as tmpdir: x = jnp.arange(256).astype(jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), compiler_params=plgpu.GPUCompilerParams( @@ -1447,10 +1465,15 @@ class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): + # ``pl.run_state`` is not supported in WG semantics. + self.skip_if_wg_semantics() + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) @functools.partial( - pl.pallas_call, - in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)], + self.pallas_call, + in_specs=[ + plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms) + ], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), ) @@ -1462,8 +1485,7 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_realistic_matmul(self, thread_semantics): + def test_realistic_matmul(self): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1500,7 +1522,7 @@ def _epilogue(): lambda m, n, k: (m, n), ) - if thread_semantics == plgpu.ThreadSemantics.Lane: + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: lhs_spec = plgpu.GPUBlockSpec( lhs_spec.block_shape, lhs_spec.index_map, transforms=( @@ -1523,7 +1545,7 @@ def _epilogue(): ) ) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[lhs_spec, rhs_spec], out_specs=out_spec, @@ -1534,13 +1556,14 @@ def _epilogue(): dimension_semantics=["parallel", "parallel", "sequential"], max_concurrent_steps=2, delay_release=1, - thread_semantics=thread_semantics, ), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): + self.skip_if_wg_semantics() + # TensorCores can only fuse transposes of 16-bit values, and RHS # is expected to be column major by default. 
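    # Hence `rhs_transpose` is requested only for non-16-bit dtypes, for
    # which the fused transpose is unavailable.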
rhs_transpose = jnp.dtype(dtype).itemsize != 2 @@ -1563,7 +1586,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( @@ -1599,12 +1622,18 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) - transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec( + (64, 128), lambda: (0, 0), transforms=transforms + ), + plgpu.GPUBlockSpec( + (128, 192), lambda: (0, 0), transforms=transforms + ), ], out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), @@ -1612,6 +1641,9 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) def test_wgmma_registers_init(self): + # ``pl.run_state`` is not supported in WG semantics. + self.skip_if_wg_semantics() + def kernel(a_ref, b_ref, i_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -1623,12 +1655,18 @@ def scope(acc_ref): i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((64, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec( + (64, 128), lambda: (0, 0), transforms=transforms + ), + plgpu.GPUBlockSpec( + (128, 192), lambda: (0, 0), transforms=transforms + ), + plgpu.GPUBlockSpec( + (64, 192), lambda: (0, 0), transforms=transforms + ), ], out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), @@ -1636,6 +1674,8 @@ def scope(acc_ref): np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) def test_wgmma_sliced_ref(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. 
+ def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0]) @@ -1647,22 +1687,18 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), + (2, 64, 128), lambda: (0, 0, 0), transforms=transforms ), plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), + (2, 128, 192), lambda: (0, 0, 0), transforms=transforms ), ], out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), @@ -1671,6 +1707,8 @@ def scope(acc_ref): np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) def test_wgmma_sliced_acc(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + swizzle = 128 elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize def kernel(a_ref, b_ref, o_ref): @@ -1683,38 +1721,41 @@ def scope(acc_ref): key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( - (64, 128), - lambda i, j: (i, j), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (64, 128), lambda *ij: ij, transforms=transforms ), plgpu.GPUBlockSpec( - (128, 128), - lambda *i: i, - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (128, 128), lambda *ij: ij, transforms=transforms ), ], - out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_specs=plgpu.GPUBlockSpec((64, 128), lambda *ij: ij), out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), grid=(1, 1), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) +class PallasCallSm90AWGTest( + PallasCallSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class PallasCallSm100ATest(PallasSm100ATest): def test_tmem_alloc(self): + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((128, 128), jnp.float32), scratch_shapes=[ plgpu.TMEM((128, 128), jnp.float32), @@ -1734,6 +1775,12 @@ def kernel(y_ref, tmem_ref, smem_ref): jax.block_until_ready(kernel()) +class PallasCallSm100AWGTest( + PallasCallSm100ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... 
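+# Each *WGTest class in this file is intentionally empty: it inherits the
+# full suite of its parent and flips only the `thread_semantics` class
+# keyword, so every test re-runs under warpgroup semantics (modulo the
+# explicit skips above).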
+ + class PipelineTest(PallasTest): def test_pipeline_mode(self): @@ -1755,13 +1802,13 @@ def body(x_ref, y_ref, o_ref): @jax.jit def vadd(x, y): - return pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), - in_specs=in_specs, - out_specs=out_specs, - grid=data_size // block_size, - )(x, y) + return self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + in_specs=in_specs, + out_specs=out_specs, + grid=data_size // block_size, + )(x, y) with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"): vadd(x, y) @@ -1823,7 +1870,7 @@ def body(step, _): plgpu.wait_smem_to_gmem(0) x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1837,6 +1884,9 @@ def body(step, _): ((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),), ) def test_emit(self, transforms): + if transforms: + self.skip_if_wg_semantics() + num_steps = 4 def kernel(x_gmem, o_gmem): @@ -1863,7 +1913,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(64 * num_steps * 64) x = x.reshape(-1, num_steps * 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1896,7 +1946,7 @@ def nested_kernel_body(x_smem, o_smem): x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1921,7 +1971,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1954,7 +2004,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1982,7 +2032,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1991,9 +2041,17 @@ def kernel_body(x_smem, o_smem): np.testing.assert_array_equal(kernel_fn(x), x + 1.0) +class PipelineWGTest( + PipelineTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class PipelineSm90ATest(PallasSm90ATest): def test_realistic_matmul(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. 
+ dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -2003,6 +2061,13 @@ def test_realistic_matmul(self): tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + def kernel(a_gmem, b_gmem, o_smem, acc): def kernel_body(a_smem, b_smem): assert a_smem.shape == (tile_m, tile_k) @@ -2016,21 +2081,11 @@ def kernel_body(a_smem, b_smem): kernel_body, in_specs=[ plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda k: (pid_m, k), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda k: (k, pid_n), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.GPUBlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms + ), ], grid=(grid_k,), max_concurrent_steps=2, @@ -2043,19 +2098,14 @@ def kernel_body(a_smem, b_smem): a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=plgpu.GMEM), - pl.BlockSpec(memory_space=plgpu.GMEM) + pl.BlockSpec(memory_space=plgpu.GMEM), ], out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n: (m, n), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (tile_m, tile_n), lambda m, n: (m, n), transforms=transforms ), out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], @@ -2064,11 +2114,19 @@ def kernel_body(a_smem, b_smem): np.testing.assert_array_equal(res, a @ b) +class PipelineSm90AWGTest( + PipelineSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class WarpSpecializedPipelineTest(PallasTest): @parameterized.product(m=[512], n=[512], manual_consumed_barriers=[False, True]) def test_pipelined_copy(self, m, n, manual_consumed_barriers): + self.skip_if_wg_semantics() # Times out! + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) blk_m = blk_n = 64 @@ -2102,7 +2160,7 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): ), ], ) - kernel = plgpu.kernel( + kernel = self.kernel( pipeline, out_shape=( jax.ShapeDtypeStruct((m, n), jnp.float16), @@ -2119,6 +2177,8 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): + self.skip_if_wg_semantics() # Crashes! + blk_m = blk_n = 64 spec = pl.BlockSpec( block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) @@ -2140,7 +2200,7 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): in_specs=[spec, spec], out_specs=[spec], ) - kernel = plgpu.kernel( + kernel = self.kernel( pipeline, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), compiler_params=plgpu.GPUCompilerParams(approx_math=True), @@ -2154,10 +2214,12 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): + self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. 
+ blk_m = blk_n = 64 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), scratch_shapes=[ plgpu.SMEM((blk_m, blk_n), jnp.float32), @@ -2214,11 +2276,19 @@ def tiled_acc_kernel(x_smem, carry): np.testing.assert_allclose(kernel(x), ref, atol=1e-4) +class WarpSpecializedPipelineWGTest( + WarpSpecializedPipelineTest, + thread_semantics=plgpu.ThreadSemantics.Warpgroup, +): + ... + + class CoreMapTest(PallasTest): def test_multiple_wg(self): + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((2, 128), np.int32), num_threads=2, thread_name="wg", @@ -2232,8 +2302,9 @@ def kernel(o_ref): ) def test_multiple_wg_with_grid(self): + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((4, 2, 128), np.int32), grid=(2, 2), grid_names=("x", "y"), @@ -2263,7 +2334,7 @@ def test_multiple_wg_with_squashed_grid(self): num_threads = 2 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros( (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 ), @@ -2290,8 +2361,10 @@ def kernel(o_ref): np.testing.assert_array_equal(result, ref) def test_cross_wg_barrier(self): + self.skip_if_wg_semantics() # Times out! + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((2, 128), np.int32), # Each warpgroup is a single logical thread! scratch_shapes=[plgpu.Barrier(num_arrivals=2)], @@ -2309,8 +2382,10 @@ def kernel(o_ref, barrier): ) def test_cluster(self): + self.skip_if_wg_semantics() # Needs debug_print in the MGPU dialect. + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros(128, np.int32), grid=(2,), grid_names=("x",), @@ -2337,6 +2412,8 @@ def kernel(ref): ) def test_realistic_matmul_with_cluster(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -2361,7 +2438,7 @@ def test_realistic_matmul_with_cluster(self): delay_release = 1 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jax.ShapeDtypeStruct((m, n), dtype), scratch_shapes=[ plgpu.SMEM( @@ -2458,34 +2535,44 @@ def body(step, _): np.testing.assert_array_equal(kernel(a, b), a @ b) +class CoreMapWGTest( + CoreMapTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class ExamplesTest(PallasTest): # Basic def test_stage0(self): - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial(self.kernel, out_shape=x) + def kernel(l_ref, r_ref, o_ref): o_ref[...] = l_ref[...] + r_ref[...] 
- x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x)(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Multi-block kernels def test_stage1(self): row_block = 64 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Async copies def test_stage3(self): row_block, col_block = 64, 128 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), scratch_shapes=[ *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), @@ -2510,7 +2597,12 @@ def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): # Pipelining def test_stage4(self): row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): def compute(l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") @@ -2522,14 +2614,19 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Transforms def test_stage5(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): def compute(l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") @@ -2544,9 +2641,7 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) def test_semaphore_lowering(self): # This is a smoke test until we add support for lowering of semaphore ops. @@ -2556,8 +2651,10 @@ def body(i_ref1, i_ref2, o_ref, sem_ref): assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) o_ref[...] = i_ref1[...] x = jnp.arange(128, dtype=jnp.float32).reshape((128,)) - kernel = pl.pallas_call( - body, out_shape=x, scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], + kernel = self.pallas_call( + body, + out_shape=x, + scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], ) text = jax.jit(kernel).lower(x, x).as_text() self.assertIn( @@ -2573,19 +2670,33 @@ def body(i_ref1, i_ref2, o_ref, sem_ref): ) +class ExamplesWGTest( + ExamplesTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... 
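+# (The Hopper-specific examples follow; WGMMA requires an sm90a device.)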
+ + class ExamplesSm90ATest(PallasSm90ATest): # WGMMA def test_stage6(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + m_block = n_block = 64 k_block = 32 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n") + ) + def kernel(l_ref, r_ref, o_ref): def compute(l_smem, r_smem, o_smem): def do_wgmma(acc_ref): plgpu.wgmma(acc_ref, l_smem, r_smem) return acc_ref[...] o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)) - m, n = lax.axis_index("m"), lax.axis_index("n") + m = lax.axis_index("m") + n = lax.axis_index("n") lo_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) r_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) plgpu.emit_pipeline( @@ -2596,12 +2707,16 @@ def do_wgmma(acc_ref): out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2, 2), grid_names=("m", "n"))(x, x) - np.testing.assert_allclose(out, x @ x) + np.testing.assert_allclose(kernel(x, x), x @ x) # TODO(apaszke): Clusters and multicast +class ExamplesSm90AWGTest( + ExamplesSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + if __name__ == "__main__": absltest.main() From 12526ea11646a75fac201e26c1a2e901f94a4c76 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 31 Mar 2025 07:08:48 -0700 Subject: [PATCH 0290/1769] [jaxlib] Pack/unpack subbyte types to/from numpy arrays to support int2, uint2, int4, uint4, float4_e2m1fn subbyte types in CPU/GPU callbacks. PiperOrigin-RevId: 742253272 --- jaxlib/cuda/BUILD | 1 + jaxlib/gpu/py_client_gpu.cc | 89 +++++++++++++++++++++------------ jaxlib/rocm/BUILD | 1 + jaxlib/xla/BUILD | 1 + jaxlib/xla/py_client_cpu.cc | 81 ++++++++++++++++++++---------- tests/python_callback_test.py | 94 ++++++++++++++++++++--------------- 6 files changed, 168 insertions(+), 99 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index fac62c81dee7..d35e421ef904 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -689,6 +689,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 861ffce3e749..38f2ac1896e7 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -80,13 +81,14 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < arity; ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. 
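+    // Only the 1-bit types are still rejected; the 2- and 4-bit types are
+    // now unpacked to one byte per element below.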
+ if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); } + if (ptype == xla::TOKEN) { host_input_buffers[i] = nullptr; continue; @@ -112,9 +114,6 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -122,8 +121,22 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + // We pass in data using default numpy layout i.e., std::nullopt. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + auto buffer = xla::UnpackIntN( + bits_per_element, static_cast(host_input_buffers[i]), + arg->size_bytes()); + delete[] static_cast(host_input_buffers[i]); + host_input_buffers[i] = buffer.release(); + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, - host_input_buffers[i], base); + host_input_buffers[i], /*base=*/base); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); } @@ -146,8 +159,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. 
- if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -168,32 +180,43 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - continue; + + const void* data = array.data(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + plan->Execute(data, temp); + data = temp; } - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. 
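+          // e.g. an 8-element int4 result is described as 8 bytes here,
+          // twice the size of its packed representation.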
+ buffer = xla::PackIntN(bits_per_element, static_cast(data), + ret->size_bytes()); + data = buffer.get(); } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), temp); - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, ret->size_bytes(), gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index d0c0c798abb8..358a6d1cc9aa 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -588,6 +588,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 2ca18afda13d..5b532c1dc501 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -637,6 +637,7 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index ac4e7bee5680..fc4f895af6aa 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -78,9 +79,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < args.size(); ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -96,9 +97,18 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + const void* data = arg->untyped_data(); + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + buffer = xla::UnpackIntN(bits_per_element, static_cast(data), + arg->size_bytes()); + data = buffer.get(); + } // We pass in data using default numpy layout i.e., std::nullopt. - auto array = - nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); + auto array = nb_numpy_ndarray(dtype, dims, std::nullopt, data); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -119,9 +129,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. 
- if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -141,26 +151,45 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); - continue; + + const void* data = array.data(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + plan->Execute(data, ret->untyped_data()); + data = ret->untyped_data(); } - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions_size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + ret->size_bytes()); + data = buffer.get(); + } + + // Copy data to output buffer if haven't already or modified the data to + // write back. 
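+    // At this point `data` aliases the output only when the transpose above
+    // already wrote in place and no re-packing occurred, in which case the
+    // copy below is skipped.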
+ if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, ret->size_bytes()); } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index a8442b4a1356..34ab20c05644 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,10 +586,15 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(x): return x def f(x): @@ -600,21 +605,17 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)(x) + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(): return np.arange(8, dtype=dtype) @@ -625,16 +626,43 @@ def f(): ) return y - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)() + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_non_default_stride_subbyte_results(self, dtype: str): + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." 
+ ) + x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) class PureCallbackTest(jtu.JaxTestCase): @@ -1108,20 +1136,6 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - class IOCallbackTest(jtu.JaxTestCase): From b3d851d722ea5efb893d96b3c03a739ba9763bd0 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 31 Mar 2025 07:34:32 -0700 Subject: [PATCH 0291/1769] Add Jax tracing micro benchmarks. Add a first benchmark for tracing/lowering pallas splash attention. Sample results below taken on a GCP n2d-standard-128 instance with 512GB Ram and 128 vCPU AMD EPYC Milan. --------------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------------- test_pallas_mqa_splash_attention_trace 39.8 ms 39.8 ms 19 test_pallas_mqa_splash_attention_lower 42.1 ms 41.9 ms 18 PiperOrigin-RevId: 742259409 --- benchmarks/tracing_benchmark.py | 76 +++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 benchmarks/tracing_benchmark.py diff --git a/benchmarks/tracing_benchmark.py b/benchmarks/tracing_benchmark.py new file mode 100644 index 000000000000..e06ad538d476 --- /dev/null +++ b/benchmarks/tracing_benchmark.py @@ -0,0 +1,76 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Benchmarks for Jax tracing.""" + +import google_benchmark +import jax +from jax import random +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +import numpy as np + + +def make_mqa_splash_attention_fn_and_args(): + seed = 0 + key = random.key(seed) + k1, k2, k3 = random.split(key, 3) + + q_seq_len = 1024 + kv_seq_len = 1024 + num_q_heads = 2 + head_dim_qk = 128 + head_dim_v = 128 + dtype = np.dtype("float32") + + q = random.uniform(k1, (num_q_heads, q_seq_len, head_dim_qk), dtype=dtype) + k = random.uniform(k2, (kv_seq_len, head_dim_qk), dtype=dtype) + v = random.uniform(k3, (kv_seq_len, head_dim_v), dtype=dtype) + + mask = mask_lib.NumpyMask( + mask_lib.make_random_mask((q_seq_len, kv_seq_len), sparsity=0.5, seed=0) + ) + mask = mask_lib.MultiHeadMask(tuple(mask for _ in range(num_q_heads))) + block_sizes = splash.BlockSizes.get_default() + + return ( + jax.jit( + splash.make_splash_mqa_single_device(mask, block_sizes=block_sizes) + ) + ), (q, k, v) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_trace(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + + while state: + _ = attn.trace(q, k, v) + jax.clear_caches() + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_lower(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + traced = attn.trace(q, k, v) + + while state: + _ = traced.lower(lowering_platforms=("tpu",)) + jax.clear_caches() + + +if __name__ == "__main__": + google_benchmark.main() From 95497ca2f0d41af0ca97af408932982fa3fa7160 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 31 Mar 2025 07:42:16 -0700 Subject: [PATCH 0292/1769] Remove legacy GPU kernel for LU decomposition. Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility It has been 6 months since the release of 0.4.33 which is the relevant release for this kernel. PiperOrigin-RevId: 742261532 --- jaxlib/gpu/blas.cc | 10 ----- jaxlib/gpu/blas_kernels.cc | 60 -------------------------- jaxlib/gpu/blas_kernels.h | 10 ----- jaxlib/gpu/gpu_kernels.cc | 3 -- jaxlib/gpu/solver.cc | 41 ------------------ jaxlib/gpu/solver_kernels.cc | 83 ------------------------------------ jaxlib/gpu/solver_kernels.h | 10 ----- 7 files changed, 217 deletions(-) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index cf391e07e31e..59bf2c4603f6 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -49,14 +49,6 @@ BlasType DtypeToBlasType(const dtype& np_type) { return it->second; } -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGetrfBatchedDescriptor(const dtype& dtype, - int b, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})}; -} - // Returns the descriptor for a GetrfBatched operation. 
std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, int b, int m, int n) { @@ -67,7 +59,6 @@ std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); return dict; } @@ -76,7 +67,6 @@ NB_MODULE(_blas, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); } diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index ac30aa9cc520..cdcc154d026d 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -52,66 +52,6 @@ int SizeOfBlasType(BlasType type) { } // namespace -// Batched LU decomposition: getrfbatched - -static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::F64: { - double** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - } - return absl::OkStatus(); -} - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GetrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // Batched QR decomposition: geqrfbatched static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h index 724565ea73d1..8ca7b4db4668 100644 --- a/jaxlib/gpu/blas_kernels.h +++ b/jaxlib/gpu/blas_kernels.h @@ -32,16 +32,6 @@ enum class BlasType { C128, }; -// Batched LU decomposition: getrfbatched - -struct GetrfBatchedDescriptor { - BlasType type; - int batch, n; -}; - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // Batched QR decomposition: geqrfbatched struct 
GeqrfBatchedDescriptor { diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 242078357254..840c313f2fa3 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -33,13 +33,10 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, - "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 20fc308100c4..8013d9877ed5 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -54,45 +54,6 @@ SolverType DtypeToSolverType(const dtype& np_type) { return it->second; } -// getrf: LU decomposition - -// Returns the workspace size and a descriptor for a getrf operation. -std::pair BuildGetrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})}; -} - // geqrf: QR decomposition // Returns the workspace size and a descriptor for a geqrf operation. 
@@ -462,7 +423,6 @@ std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf); dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); @@ -496,7 +456,6 @@ nb::dict Registrations() { NB_MODULE(_solver, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_getrf_descriptor", &BuildGetrfDescriptor); m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); m.def("build_syevd_descriptor", &BuildSyevdDescriptor); diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8c22dfcdbca7..8971619d7f34 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -50,89 +50,6 @@ static int SizeOfSolverType(SolverType type) { } } -// getrf: LU decomposition - -static absl::Status Getrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgetrf( - handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Getrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // geqrf: QR decomposition static absl::Status Geqrf_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index 51082f2fe812..6372e55b930d 100644 
--- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -33,16 +33,6 @@ enum class SolverType { C128, }; -// getrf: LU decomposition - -struct GetrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // geqrf: QR decomposition struct GeqrfDescriptor { From 6b719496ed83f3ca18e0e42f32892eb63102af3b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 31 Mar 2025 07:59:05 -0700 Subject: [PATCH 0293/1769] [pallas:mosaic_gpu] Fixed lane-level lowering of `lax.optimization_barrier` PiperOrigin-RevId: 742265860 --- jax/_src/pallas/mosaic_gpu/lowering.py | 14 +++++++------- tests/pallas/mosaic_gpu_test.py | 6 ------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index daa718ff1ff2..f027d5bcb76d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2369,20 +2369,20 @@ def _bitcast_convert_type_lowering_rule( @register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): - args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) - return mgpu.optimization_barrier(*args) + result = mgpu.optimization_barrier( + *(_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) + ) + return (result,) if len(ctx.avals_in) == 1 else result @register_lowering_rule( lax.optimization_barrier_p, mgpu.ThreadSemantics.Warpgroup ) def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): - args = [ + result = mgpu.dialect.optimization_barrier([ _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) - ] - result = mgpu.dialect.optimization_barrier(args) - - return (result,) if len(args) == 1 else result + ]) + return (result,) if len(ctx.avals_in) == 1 else result def _bcast( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6b1839a64580..77e934f656e7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1405,9 +1405,6 @@ def convert(x_ref, y_ref): ) def test_optimization_barrier(self): - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: - self.skipTest("This test crashes with lane semantics") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), @@ -1419,9 +1416,6 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x) def test_optimization_barrier_multiple_inputs(self): - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: - self.skipTest("This test crashes with lane semantics") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), From 200f8263980bb1346c15f4616e28f129cf0b4f85 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 31 Mar 2025 08:50:39 -0700 Subject: [PATCH 0294/1769] [array api] return all devices in devices() --- jax/_src/numpy/array_api_metadata.py | 9 +++++++-- tests/array_api_test.py | 6 +++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index d8d2c2d1a2a4..5267e51215ee 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -24,7 +24,9 @@ import jax from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc -from jax._src import 
dtypes as _dtypes, config +from jax._src import config +from jax._src import dtypes as _dtypes +from jax._src import xla_bridge as xb __array_api_version__ = '2023.12' @@ -73,7 +75,10 @@ def default_device(self): return None def devices(self): - return jax.devices() + out = [None] # None indicates "uncommitted" + for backend in xb.backends(): + out.extend(jax.devices(backend)) + return out def capabilities(self): return self._capabilities diff --git a/tests/array_api_test.py b/tests/array_api_test.py index d509fe78c35f..8e4ba275fdd3 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -26,6 +26,7 @@ import jax.numpy as jnp from jax._src import config, test_util as jtu from jax._src.dtypes import _default_types, canonicalize_dtype +from jax._src import xla_bridge as xb ARRAY_API_NAMESPACE = jnp @@ -283,7 +284,10 @@ def test_default_device_info(self): assert self.info.default_device() is None def test_devices_info(self): - assert self.info.devices() == jax.devices() + devices = set(self.info.devices()) + assert None in devices + for backend in xb.backends(): + assert devices.issuperset(jax.devices(backend)) def test_default_dtypes_info(self): _default_dtypes = { From aaa3ebfb8a135e4c82c08e551fd756ad5db85716 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Mon, 31 Mar 2025 12:05:30 -0500 Subject: [PATCH 0295/1769] Add optimization barrier. --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 60cdbee7fa20..6766e3992202 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -495,7 +495,7 @@ def quantize(x, config): SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) scales_q = jnp.clip(scales / config.global_scale, 0, SCALE_MAX) - scales_q = scales_q.astype(config.scale_type) + scales_q = jax.lax.optimization_barrier(scales_q.astype(config.scale_type)) scaled_x = x / scales_q.astype(jnp.float32) else: raise ValueError(f"Unrecognized mode: {config.mode}.") From 05039fe520906a2cd9562593406dd4544828515e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 17:49:43 +0000 Subject: [PATCH 0296/1769] Bump tsickert/discord-webhook from 5.3.0 to 7.0.0 Bumps [tsickert/discord-webhook](https://github.com/tsickert/discord-webhook) from 5.3.0 to 7.0.0. - [Release notes](https://github.com/tsickert/discord-webhook/releases) - [Commits](https://github.com/tsickert/discord-webhook/compare/c840d45a03a323fbc3f7507ac7769dbd91bfb164...b217a69502f52803de774ded2b1ab7c282e99645) --- updated-dependencies: - dependency-name: tsickert/discord-webhook dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/community_release_actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml index d61bea3d7e4d..1980e803ba9b 100644 --- a/.github/workflows/community_release_actions.yml +++ b/.github/workflows/community_release_actions.yml @@ -25,7 +25,7 @@ jobs: maxLength: 2000 truncationSymbol: "..." 
- name: Discord Webhook Action - uses: tsickert/discord-webhook@c840d45a03a323fbc3f7507ac7769dbd91bfb164 # v5.3.0 + uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0 with: webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} content: ${{ steps.get-content.outputs.string }} From 5d69e6b64dca1ef4590a942e8de173fc40e46a34 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 17:49:47 +0000 Subject: [PATCH 0297/1769] Bump actions/setup-python from 5.4.0 to 5.5.0 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.4.0 to 5.5.0. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/42375524e23c412d93fb67b49958b491fce71c38...8d9ed9ac5c53483de85588cdf95a591a75ab9f55) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 12 ++++++------ .github/workflows/jax-array-api.yml | 2 +- .github/workflows/upstream-nightly.yml | 2 +- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index f43407af2ed9..c575c84cd422 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -31,7 +31,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.11 - run: python -m pip install pre-commit @@ -70,7 +70,7 @@ jobs: apt update apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -108,7 +108,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -142,7 +142,7 @@ jobs: apt update apt install -y libssl-dev libsqlite3-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -168,7 +168,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -201,7 +201,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: 
actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.12 - name: Install JAX diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 2b97c5a05c1c..c91ab6b8b7da 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -32,7 +32,7 @@ jobs: submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 5132a12cf16f..ba2c750f8a8a 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -33,7 +33,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 912088428fd5..a2b3aeddc24a 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index fc2b63396f56..5a435023ffda 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' From 1355e7c65003428c5922df306cef77cef48412ed Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 5 Mar 2025 15:32:25 +0000 Subject: [PATCH 0298/1769] AutoPGLE: force-disable graphs less Previously, XLA's command buffers (CUDA graphs) would be disabled both for PGLE profile collection and when re-compiling using the profile data. With this change, they are only disabled when collecting the profile data. --- jax/_src/compiler.py | 239 ++++++++++++++++++++++--------------------- tests/pgle_test.py | 52 +++++++++- 2 files changed, 174 insertions(+), 117 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index dea532d13031..9ac47aa4f0ea 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -17,6 +17,8 @@ from __future__ import annotations from collections.abc import Sequence +import copy +from functools import partial import logging import time from typing import Any, Callable @@ -197,15 +199,6 @@ def get_compile_options( config.memory_fitting_level.value ).value - # This is a temporary workaround to simplify the AutoPGLE usage. - # TODO(b/376647494): Remove once the bug is fixed. 
- if ((config.enable_pgle.value and config.pgle_profiling_runs.value > 0) - or config.compilation_cache_expect_pgle.value): - logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.") - if env_options_overrides is None: - env_options_overrides = {} - env_options_overrides['xla_gpu_enable_command_buffer'] = '' - if env_options_overrides is not None: # Some overrides are passed directly on build_options. overrides_on_build_options = [ @@ -298,6 +291,8 @@ def backend_compile( options: xc.CompileOptions, host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value # Convert ir.Module to a string representation, unless the backend # explicitly flags the ability to handle a module directly (avoiding the # overhead of back and forth conversions). @@ -308,6 +303,14 @@ def backend_compile( else: built_c = module + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) + try: # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results @@ -362,72 +365,31 @@ def compile_or_get_cached( if dumped_to := mlir.dump_module_to_file(computation, "compile"): logging.info("Dumped the module to %s.", dumped_to) - use_compilation_cache = compilation_cache.is_cache_used(backend) - is_multi_process = ( len({device.process_index for device in devices.flatten()}) > 1 ) min_device_process_id = min( devices.flatten(), key=lambda device: device.id ).process_index - is_auto_pgle_used = ( - config.enable_pgle.value and config.pgle_profiling_runs.value > 0 - ) - if not use_compilation_cache: - if ( - is_multi_process - and is_auto_pgle_used - and distributed.global_state.client is not None - ): - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) + # cache_key: may be None if compilation caching is disabled + cache_key, compile_options = _resolve_compilation_strategy( + computation, + devices, + compile_options, + backend, + pgle_profiler, + is_multi_process, + module_name, + min_device_process_id, + ) + if cache_key is None: return backend_compile(backend, computation, compile_options, host_callbacks) monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') - try: - if config.remove_custom_partitioning_ptr_from_cache_key.value: - ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING - else: - ignore_callbacks = cache_key_type.IgnoreCallbacks.NO - - cache_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - ignore_callbacks=ignore_callbacks, - ) - except xc._xla.XlaRuntimeError as ex: - logger.error("compile_or_get_cached: unable to generate cache key, " - "skipping the cache: %s", ex) - return backend_compile(backend, computation, compile_options, - host_callbacks) - - if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: - cache_key = _resolve_pgle_module_cache_key( - computation, - devices, - compile_options, - backend, - pgle_profiler, - is_multi_process, - cache_key, - module_name, - min_device_process_id, - ) - cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = 
_cache_read( module_name, cache_key, compile_options, backend) @@ -481,85 +443,130 @@ def compile_or_get_cached( # 1. PGLE optimized module (the one which was recompiled with FDO profile) is # in the persistent cache. In this case the module should be returned from # cache and PGLE should be disabled for this module. Is module is stored in -# the persistent cache under the "pgle_profiled_module_key" which calculated -# with replacing FDO profile with flag which identify that module were PGLE -# profiled. +# the persistent cache under the "pgle_optimized_cache_key", which is +# calculated by replacing the FDO profile with a sentinel value that identifies +# that the module was optimized with PGLE. # 2. PGLE profiled module is not in the persistent cache and the module is -# getting built with an FDO profile. In this case we need to share FDO profile -# with other processes and store the result under the -# "pgle_profiled_module_key" so later in case 1 we will be able to find the +# getting built with an FDO profile. In this case we need to share the FDO +# profile with any other processes and store the result under the +# "pgle_optimized_cache_key" so later in case 1 we will be able to find the # module. # 3. PGLE profiled module is not in the persistent cache and the module is # getting compiled to be PGLEd (FDO profile is empty). In this case we need to -# simply return the non-PGLE profiled module from the persistent cache. +# simply return the non-PGLE profiled module from the persistent cache if it +# exists, and otherwise compile it. # # If the compilation_cache_expect_pgle option is set then in case 1 the PGLE # optimized module will be loaded even if PGLE is not enabled in the current # process. This is useful if we want to combine the use of PGLE with other # profiling tools (e.g. Nsight Systems) that cannot co-exist with PGLE due to # contention for CUPTI resources. -def _resolve_pgle_module_cache_key( +def _resolve_compilation_strategy( computation: ir.Module, devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, pgle_profiler: profiler.PGLEProfiler | None, is_multi_process: bool, - cache_key: str, module_name: str, min_device_process_id: int, -) -> str: - fdo_profile = compile_options.executable_build_options.fdo_profile - compile_options.executable_build_options.fdo_profile = b"pgle profiled" - - pgle_profiled_module_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - cache_key_type.IgnoreCallbacks.ALL, +) -> tuple[str | None, xc.CompileOptions]: + is_auto_pgle_used = ( + config.enable_pgle.value and config.pgle_profiling_runs.value > 0 ) - compile_options.executable_build_options.fdo_profile = fdo_profile - - result_key = cache_key - if _is_executable_in_cache(backend, pgle_profiled_module_key): - # Load PGLE profiled module from the persistent cache. - result_key = pgle_profiled_module_key - if config.compilation_cache_expect_pgle.value: - logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") - if pgle_profiler is not None: - pgle_profiler.disable() + + get_cache_key = partial(_get_cache_key, backend=backend, + computation=computation, devices=devices) + + if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: + # This can be None if cache key generation fails. 
+ pgle_optimized_cache_key = get_cache_key(compile_options, + override_fdo_profile=b"pgle profiled") + # TODO(b/376647494): remove the workaround when the bug is fixed; the JAX + # profiler cannot collect sufficiently detailed profile data for PGLE if + # command buffers / CUDA graphs are enabled. Therefore disable command + # buffers when compiling for PGLE data collection, but not if AutoPGLE is + # not enabled, and not when re-compiling using PGLE data. This condition + # includes `compilation_cache_expect_pgle` so that slow-to-compile modules + # that are not executed often enough to trigger re-compilation will still + # be cached between an "enable_pgle" run and an "expect_pgle" run. + first_pass_compile_options = copy.deepcopy(compile_options) + first_pass_compile_options.env_option_overrides += [ + ("xla_gpu_enable_command_buffer", ""), + ] else: - # No PGLE-optimised module found in the persistent cache. - if (config.compilation_cache_expect_pgle.value - and _is_executable_in_cache(backend, cache_key)): - # The user asserted this miss was unexpected; emit a warning + pgle_optimized_cache_key = None + first_pass_compile_options = compile_options + + # This can be None if cache key generation fails or caching is disabled + cache_key = get_cache_key(first_pass_compile_options) + + if cache_key is not None and pgle_optimized_cache_key is not None: + # The compilation cache is enabled and AutoPGLE is enabled/expected + if _is_executable_in_cache(backend, pgle_optimized_cache_key): + if config.compilation_cache_expect_pgle.value: + logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") + # No need to record N profiles in this case + if pgle_profiler is not None: + pgle_profiler.disable() + return pgle_optimized_cache_key, compile_options + elif (config.compilation_cache_expect_pgle.value + and _is_executable_in_cache(backend, cache_key)): + # No PGLE-optimized module found in the persistent cache, and the user + # asserted (expect_pgle) that this miss was unexpected warnings.warn(f"PERSISTENT CACHE MISS for PGLE-optimized {module_name} " "despite non-PGLE hit; it may not have been executed " "enough times when the cache was populated") - if fdo_profile is not None and len(fdo_profile) > 0: - # Store module under PGLE profiled module cache key. 
- result_key = pgle_profiled_module_key - if is_multi_process and distributed.global_state.client is not None: - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) - else: - compile_options.executable_build_options.fdo_profile = fdo_profile - logger.debug( - "Compiling module %s with FDO profile of length %d", - module_name, - len(compile_options.executable_build_options.fdo_profile), + + if (is_auto_pgle_used + and compile_options.executable_build_options.fdo_profile is not None + and len(compile_options.executable_build_options.fdo_profile)): + # Profile data are available to trigger a PGLE-optimized recompilation; + # store under `pgle_optimized_cache_key` if the cache is enabled + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, ) - return result_key + ) + return pgle_optimized_cache_key, compile_options + else: + # Compile for PGLE collection, store under `cache_key` if the cache is + # enabled. This is also the AutoPGLE-disabled path. + return cache_key, first_pass_compile_options +def _get_cache_key( + options: xc.CompileOptions, + backend: xc.Client, + computation: ir.Module, + devices: np.ndarray, + override_fdo_profile: bytes | None = None) -> str | None: + if not compilation_cache.is_cache_used(backend): + return None + if config.remove_custom_partitioning_ptr_from_cache_key.value: + ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING + else: + ignore_callbacks = cache_key_type.IgnoreCallbacks.NO + if override_fdo_profile is not None: + options = copy.deepcopy(options) + options.executable_build_options.fdo_profile = override_fdo_profile + try: + return compilation_cache.get_cache_key( + computation, + devices, + options, + backend, + ignore_callbacks, + ) + except xc._xla.XlaRuntimeError as ex: + logger.error("compile_or_get_cached: unable to generate cache key, " + "skipping the cache: %s", ex) + return None # The process that has the lowest device ID should share FDO profile before # compilation with other processes. 
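To make the three cache cases described in the comment block above concrete, here is a minimal Python sketch of the dispatch order. This is illustrative only: `resolve_strategy`, `in_cache`, `fdo_profile`, and `pgle_enabled` are hypothetical stand-ins, not JAX's actual internals; only the key names mirror the diff above.

    def resolve_strategy(pgle_enabled, fdo_profile, in_cache):
        """Returns (cache key to use, whether command buffers stay enabled)."""
        if pgle_enabled and in_cache("pgle_optimized_cache_key"):
            # Case 1: a PGLE-optimized executable is already cached; load it
            # and skip any further profiling runs.
            return "pgle_optimized_cache_key", True
        if pgle_enabled and fdo_profile:
            # Case 2: profile data are in hand; recompile with command
            # buffers enabled and store under the PGLE-optimized key.
            return "pgle_optimized_cache_key", True
        # Case 3 (also the AutoPGLE-disabled path): compile under the plain
        # key; command buffers are disabled only for this profile-collection
        # pass (the b/376647494 workaround).
        return "cache_key", not pgle_enabled

For example, resolve_strategy(True, b"", lambda key: False) returns ("cache_key", False) for the first profiling run, while the later recompilation with profile data returns the PGLE-optimized key with command buffers re-enabled -- mirroring the two runs asserted in testAutoPgleWithCommandBuffers below.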
diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7dabd809d95e..2787de4c6e17 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,7 +21,7 @@ import tempfile import warnings -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax._src import api from jax._src import compilation_cache as cc @@ -478,5 +478,55 @@ def check_if_cache_hit(event): self.assertLen(w, 1) self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message)) + @parameterized.parameters([True, False]) + @jtu.thread_unsafe_test() + def testAutoPgleWithCommandBuffers(self, enable_compilation_cache): + with (config.pgle_profiling_runs(1), + config.enable_compilation_cache(enable_compilation_cache), + config.enable_pgle(True), + tempfile.TemporaryDirectory() as dump_dir, + tempfile.TemporaryDirectory() as cache_dir): + if enable_compilation_cache: + cc.reset_cache() + cc.set_cache_dir(cache_dir) + compiler_options = { + 'xla_dump_to': dump_dir, + # FUSION, see https://github.com/openxla/xla/issues/22459 + 'xla_gpu_enable_command_buffer': 1, + 'xla_gpu_graph_min_graph_size': 1, + } + @partial( + jax.jit, + compiler_options=compiler_options, + ) + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + # This is ugly, but it does not seem possible to get the AutoPGLE-recompiled + # executable text (.lower(x).compile().as_text() or similar). + def get_new_hlo(): + additions = set(os.listdir(dump_dir)) - get_new_hlo.seen_files + get_new_hlo.seen_files |= additions + new_hlos = list(filter(lambda f: f.endswith("_gpu_after_optimizations.txt"), additions)) + assert len(new_hlos) == 1 + with open(os.path.join(dump_dir, new_hlos[0]), "r") as ifile: + return ifile.read() + + get_new_hlo.seen_files = set() + + # Run 1 + self.assertArraysEqual(f(x), expected) + self.assertNotIn("command_buffer", get_new_hlo()) # b/376647494 workaround + # Run 2 + self.assertArraysEqual(f(x), expected) + self.assertIn("command_buffer", get_new_hlo()) # workaround disabled + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From d6b4fed5ed25432fd5298fe108e797dac734465d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 31 Mar 2025 11:33:05 -0700 Subject: [PATCH 0299/1769] Propagate sharding and vma rule for axis_index_p. 
There's no need for pbroadcast insertion for axis_index_p in the traceable PiperOrigin-RevId: 742334213 --- jax/_src/lax/parallel.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 8fc8c336d61a..ebc6255cb66b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -35,6 +35,7 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla +from jax._src.mesh import get_abstract_mesh from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir @@ -1860,8 +1861,14 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - _check_axis_names([axis_name]) - return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} + effect = {core.NamedAxisEffect(axis_name)} + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + _check_axis_names(axis_name) + mesh = get_abstract_mesh() + sharding = NamedSharding(mesh, P()) + vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset()) + if config.varying_axes_in_types.value else frozenset()) + return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): return lax.iota(np.int32, axis_data.size), 0 From 8cda2a23dda416bd150b39f1c3580602fe2aa4f5 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Thu, 20 Mar 2025 15:20:14 -0500 Subject: [PATCH 0300/1769] [Mosaic-GPU] [2/3] Add NVSHMEM support to Mosaic-GPU custom call --- .../mosaic_gpu/pallas_call_registration.py | 1 + jax/experimental/mosaic/gpu/core.py | 48 ++++- jaxlib/mosaic/gpu/BUILD | 15 ++ jaxlib/mosaic/gpu/custom_call.cc | 165 +++++++++++++++--- jaxlib/mosaic/gpu/mosaic_gpu_comm.h | 86 +++++++++ jaxlib/mosaic/gpu/runtime.cc | 13 ++ 6 files changed, 293 insertions(+), 35 deletions(-) create mode 100644 jaxlib/mosaic/gpu/mosaic_gpu_comm.h diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 1d4be26187ce..ff3c4f89d30c 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -96,6 +96,7 @@ def zero_init_gmem_scratch(): module=module, out_types=(*lowering_result.new_out_shapes, *lowering_result.gmem_scratch_shapes), input_output_aliases=input_output_aliases, + use_custom_barrier=False, # False until we add get_barrier_semaphore() feature ) if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. outs = outs[:-len(lowering_result.gmem_scratch_shapes)] diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index fcc5d3db6d60..43b93e7da023 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -83,6 +83,15 @@ # Set this so that the custom call can find it os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) +if os.environ.get("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH") is None: + try: + from nvidia import nvshmem + except ImportError: + pass + else: + os.environ["MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"] = ( + os.path.join(nvshmem.__path__[0], 'lib/libnvshmem_device.bc') + ) mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True @@ -103,6 +112,7 @@ def _mosaic_gpu_lowering_rule( module, out_types, input_output_aliases: tuple[tuple[int, int], ...] 
= (),
+    use_custom_barrier: bool = False,
 ):
   assert len(args) == len(ctx.avals_in)
   assert len(out_types) == len(ctx.avals_out)
@@ -121,15 +131,35 @@ def _mosaic_gpu_lowering_rule(
       raise RuntimeError("Hash collision!")
     else:
       KNOWN_KERNELS[kernel_id] = module_asm
-  op = mlir.custom_call(
-      "mosaic_gpu",
-      result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-      operands=args,
-      operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
-      result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
-      backend_config=kernel_id + module_asm,
-      operand_output_aliases=dict(input_output_aliases),
-  )
+
+  if ctx.is_forward_compat():
+    if use_custom_barrier:
+      raise ValueError("Barrier semaphore is not supported in forward compatibility mode. "
+                       "Please use 'export_ignore_forward_compatibility=True'.")
+    op = mlir.custom_call(
+        "mosaic_gpu",
+        result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+        operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
+        operands=args,
+        result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
+        backend_config=kernel_id + module_asm,
+        operand_output_aliases=dict(input_output_aliases),
+    )
+  else:
+    op = mlir.custom_call(
+        "mosaic_gpu_v2",
+        result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+        operands=args,
+        operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
+        result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
+        backend_config=dict(
+            kernel_hash=ir.StringAttr.get(kernel_id),
+            module=ir.StringAttr.get(module_asm),
+            use_custom_barrier=ir.BoolAttr.get(use_custom_barrier),
+        ),
+        operand_output_aliases=dict(input_output_aliases),
+        api_version=4,
+    )
   return op.results
diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index 80a8f0e51080..6f9a729688ff 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -114,11 +114,21 @@ cc_library(
     # Linker may prune these symbols if they are not explicitly exported.
     linkopts = ["-Wl,--export-dynamic-symbol='mosaic_gpu_*'"],
     deps = [
+        ":mosaic_gpu_comm",
        "@local_config_cuda//cuda:cuda_headers",
     ],
     alwayslink = True,
 )
 
+cc_library(
+    name = "mosaic_gpu_comm",
+    hdrs = ["mosaic_gpu_comm.h"],
+    deps = [
+        "@local_config_cuda//cuda:cuda_headers",
+        "@xla//xla/tsl/cuda:cudart",
+    ],
+)
+
 cc_library(
     name = "custom_call",
     srcs = ["custom_call.cc"],
@@ -127,9 +137,11 @@ cc_library(
         ":target",
         "//jaxlib/cuda:cuda_vendor",
         "//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
+        "//jaxlib/mosaic/gpu:mosaic_gpu_comm",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
@@ -175,6 +187,8 @@ cc_library(
         "@llvm-project//mlir:VectorToLLVM",
         "@xla//xla/service:custom_call_status",
         "@xla//xla/service:custom_call_target_registry",
+        "@xla//xla/ffi",
+        "@xla//xla/ffi:ffi_api",
     ],
     alwayslink = True,
 )
@@ -210,5 +224,6 @@ cc_binary(
     deps = [
         "@local_config_cuda//cuda:cuda_headers",
         "@xla//xla/tsl/cuda:cudart",
+        "//jaxlib/mosaic/gpu:mosaic_gpu_comm",
     ],
 )
diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index d9a69c57e142..465551e2903b 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "absl/base/optimization.h"
 #include "absl/cleanup/cleanup.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -87,14 +88,19 @@ limitations under the License.
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
 #include "jaxlib/mosaic/gpu/launch_lowering.h"
+#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h"
 #include "jaxlib/mosaic/gpu/passes.h"
 #include "jaxlib/mosaic/gpu/serde.h"
 #include "jaxlib/mosaic/gpu/target.h"
+#include "xla/ffi/ffi.h"
+#include "xla/ffi/ffi_api.h"
 #include "xla/service/custom_call_status.h"
 #include "xla/service/custom_call_target_registry.h"
 
 namespace {
 
+namespace ffi = xla::ffi;
+
 using MosaicInitFunc = void(void****);
 using MosaicHostFunc = void(void**);
 
@@ -121,7 +127,7 @@ absl::StatusOr<std::pair<std::string, std::string>> GetSmAndPtxIsaVersion() {
 
 mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
     mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target,
-    const std::string& sm, const std::string& ptx_isa) {
+    const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) {
   static bool register_once = []() {
     llvm::InitializeNativeTarget();
     llvm::InitializeNativeTarget();
@@ -179,8 +185,8 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
         gpu.module(mosaic-byval-insertion),
         gpu.module(reconcile-unrealized-casts),
         mosaic-convert-gpu-to-llvm,
-        gpu-module-to-binary{format=)", mlir::gpu::stringifyCompilationTarget(target).str(), R"(},
+        gpu-module-to-binary{format=)" + mlir::gpu::stringifyCompilationTarget(target).str() + (!nvshmem_path.empty() ? R"( l=)" + nvshmem_path : "") + R"(},
         convert-math-to-llvm{approximate-log1p=true},
         canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},
         cse,
@@ -289,7 +295,7 @@ class TemporaryDirectory {
 };
 
 void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm,
-                           const std::string& ptx_isa) {
+                           const std::string& ptx_isa, const std::string& nvshmem_path) {
   bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr;
   bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr;
   bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr;
@@ -300,7 +306,8 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm,
   module = module.clone();  // Prevent accidental modification.
   absl::Cleanup module_destroyer = [module] { module->erase(); };
   auto passes = GetPassPipeline(
-      module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa);
+      module.getContext(), mlir::gpu::CompilationTarget::Assembly,
+      sm, ptx_isa, nvshmem_path);
   if (mlir::failed(passes) ||
       mlir::failed(RunPasses(std::move(*passes), module))) {
     return;
@@ -358,7 +365,29 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm,
   }
 }
 
-absl::StatusOr<std::unique_ptr<mlir::ExecutionEngine>> Compile(
+bool is_nvshmem_used(mlir::ModuleOp module) {
+  constexpr std::string_view prefix1 = "nvshmem_";
+  constexpr std::string_view prefix2 = "nvshmemx_";
+  for (mlir::LLVM::LLVMFuncOp llvm_func : module.getOps<mlir::LLVM::LLVMFuncOp>()) {
+    const auto& func_name = llvm_func.getName();
+    if (!func_name.starts_with(prefix1) && !func_name.starts_with(prefix2)) {
+      continue;
+    }
+    auto uses = mlir::SymbolTable::getSymbolUses(llvm_func, module.getOperation());
+    if (uses && !uses->empty()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+absl::StatusOr<std::string> get_nvshmem_llvm_lib_path() {
+  const char * nvshmem_path_ptr = getenv("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH");
+  if (!nvshmem_path_ptr) return absl::InternalError("Failed to get MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH");
+  return nvshmem_path_ptr;
+}
+
+absl::StatusOr<std::pair<std::unique_ptr<mlir::ExecutionEngine>, bool>> Compile(
     mlir::ModuleOp module) {
   auto sm_and_ptx_isa = GetSmAndPtxIsaVersion();
   if (!sm_and_ptx_isa.ok()) {
@@ -366,9 +395,16 @@ absl::StatusOr<std::unique_ptr<mlir::ExecutionEngine>> Compile(
   }
   const std::string sm = sm_and_ptx_isa.value().first;
   const std::string ptx_isa = sm_and_ptx_isa.value().second;
-  DumpCompilationOutput(module, sm, ptx_isa);
+  bool is_comm_used = is_nvshmem_used(module);
+  std::string nvshmem_path = "";
+  if (is_comm_used) {
+    TF_ASSIGN_OR_RETURN(nvshmem_path, get_nvshmem_llvm_lib_path());
+  }
+  DumpCompilationOutput(module, sm, ptx_isa, nvshmem_path);
   auto passes = GetPassPipeline(
-      module.getContext(), mlir::gpu::CompilationTarget::Binary, sm, ptx_isa);
+      module.getContext(),
+      mlir::gpu::CompilationTarget::Binary,
+      sm, ptx_isa, nvshmem_path);
   if (mlir::failed(passes)) {
     return absl::InternalError("Failed to construct pass pipeline");
   }
@@ -392,23 +428,25 @@ absl::StatusOr<std::unique_ptr<mlir::ExecutionEngine>> Compile(
   if (!maybe_execution_engine) {
     return absl::InternalError("Failed to compile kernel");
  }
-  return std::move(*maybe_execution_engine);
+  return std::make_pair(std::move(*maybe_execution_engine), is_comm_used);
 }
 
 class CompiledKernel {
  public:
   CompiledKernel(std::unique_ptr<mlir::ExecutionEngine> engine, void* ctx,
-                 MosaicHostFunc* host_launch)
-      : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {}
+                 MosaicHostFunc* host_launch, bool is_comm_used)
+      : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch),
+        is_comm_used_(is_comm_used) {}
 
-  std::tuple<void*, MosaicHostFunc*> GetHostLaunch() {
-    return std::make_tuple(ctx_, host_launch_);
+  std::tuple<void*, MosaicHostFunc*, bool> GetHostLaunch() {
+    return std::make_tuple(ctx_, host_launch_, is_comm_used_);
   }
 
 private:
   std::unique_ptr<mlir::ExecutionEngine> engine_;
   void* ctx_;  // TODO(apaszke): Destroy this properly
   MosaicHostFunc* host_launch_;
+  bool is_comm_used_;
 };
 
 using KernelHash = std::array<uint64_t, 4>;
@@ -477,7 +515,8 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
   if (!maybe_engine.ok()) {
     return maybe_engine.status();
   }
-  mlir::ExecutionEngine* execution_engine = maybe_engine->get();
+  mlir::ExecutionEngine* execution_engine = maybe_engine.value().first.get();
+  bool is_comm_used = maybe_engine.value().second;
 
   auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op);
   if (!host_and_init_func_names.ok()) {
@@ -496,14 +535,15 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
   void** kernel_ptr_ptr = &kernel_ptr;
   void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
   reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
-  return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
-                        reinterpret_cast<MosaicHostFunc*>(*host));
+  return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr,
+                        reinterpret_cast<MosaicHostFunc*>(*host),
+                        is_comm_used);
 }
 
 // Each compiled kernel has a unique init func, and each kernel is used from
 // a single HLO module. So it should be safe to not include the CUDA context
 // in the key.
-absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CachedCompileAndInit(
+absl::StatusOr<CompiledKernel*> CachedCompileAndInit(
     CacheKey key, const char* module) {
   auto cache_and_mutex = GetKernelCache();
   auto* cache = cache_and_mutex.first;
@@ -514,7 +554,7 @@ absl::StatusOr<CompiledKernel*> CachedCompileAndInit(
     absl::ReaderMutexLock lock(mutex);
     auto it = cache->find(key);
     if (ABSL_PREDICT_TRUE(it != cache->end()))
-      return it->second.GetHostLaunch();
+      return &it->second;
   }
 
   absl::MutexLock lock(mutex);
@@ -526,11 +566,12 @@ absl::StatusOr<CompiledKernel*> CachedCompileAndInit(
     }
     cache->insert_or_assign(key, std::move(*compiled));
   }
-  return cache->at(key).GetHostLaunch();
+  return &cache->at(key);
 }
 
 void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
                          size_t opaque_len, XlaCustomCallStatus* status) {
+  // Forward-compatible version using the legacy FFI API
   if (reinterpret_cast<uintptr_t>(opaque) % alignof(KernelHash)) {
     fprintf(stderr, "Misaligned opaque pointer\n");
     abort();
@@ -542,20 +583,92 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
     abort();
   }
   CacheKey key(hash, reinterpret_cast<uintptr_t>(ctx));
-  auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
-  if (!ctx_and_kernel.ok()) {
+  auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
+  if (!compiled_kernel.ok()) {
     XlaCustomCallStatusSetFailure(status,
-                                  ctx_and_kernel.status().message().data(),
-                                  ctx_and_kernel.status().message().size());
+                                  compiled_kernel.status().message().data(),
+                                  compiled_kernel.status().message().size());
     return;
   }
-  void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers};
-  std::get<1>(*ctx_and_kernel)(args);
+  auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch();
+  bool is_comm_used = std::get<2>(ctx_kernel_comm);
+  void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers};
+  if (is_comm_used) {
+    mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(
+        reinterpret_cast<cudaStream_t>(stream));
+  }
+  std::get<1>(ctx_kernel_comm)(args);
 }
 
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
                                          "CUDA");
 
+absl::Status MosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs,
+                              ffi::RemainingRets results,
+                              absl::string_view kernel_hash,
+                              absl::string_view module,
+                              bool use_custom_barrier,
+                              xla::RunId run_id) {
+  // Updated version using the new FFI API supporting custom barrier
+  // for distributed kernels
+  if (use_custom_barrier) {
+    fprintf(stderr, "Custom barrier is not supported on GPUs.\n");
+    abort();
+  }
+  if (reinterpret_cast<uintptr_t>(kernel_hash.data()) %
+          alignof(KernelHash) ||
+      kernel_hash.size() != sizeof(KernelHash)) {
+    fprintf(stderr, "Misaligned opaque pointer\n");
+    abort();
+  }
+  auto hash = *reinterpret_cast<const KernelHash*>(kernel_hash.data());
+  CUcontext ctx;
+  if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) {
+    fprintf(stderr, "Failed to get current CUDA context\n");
+    abort();
+  }
+  CacheKey key(hash, reinterpret_cast<uintptr_t>(ctx));
+  TF_ASSIGN_OR_RETURN(auto compiled_kernel, CachedCompileAndInit(key, module.data()));
+  auto ctx_kernel_comm = compiled_kernel->GetHostLaunch();
+  bool is_comm_used = std::get<2>(ctx_kernel_comm);
+
+  std::vector<void*> buffers;
+  buffers.reserve(inputs.size() + results.size());
+  for (int i = 0; i < inputs.size(); ++i) {
+    buffers.push_back(inputs.get<ffi::AnyBuffer>(i)->untyped_data());
+  }
+  for (int i = 0; i < results.size(); ++i) {
+    buffers.push_back((*results.get<ffi::AnyBuffer>(i))->untyped_data());
+  }
+  void **buffers_ptr = buffers.data();
+  void *args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr};
+
+  if (is_comm_used) {
+    mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(
+        reinterpret_cast<cudaStream_t>(stream));
+  }
+  std::get<1>(ctx_kernel_comm)(args);
+  return absl::OkStatus();
+}
+
+XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute,
+                       ffi::Ffi::Bind()
+                           .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                           .RemainingArgs()
+                           .RemainingRets()
+                           .Attr<absl::string_view>("kernel_hash")
+                           .Attr<absl::string_view>("module")
+                           .Attr<bool>("use_custom_barrier")
+                           .Ctx<ffi::RunId>());
+
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA",
+                         {
+                             /*instantiate=*/nullptr,
+                             /*prepare=*/nullptr,
+                             /*initialize=*/nullptr,
+                             /*execute=*/kMosaicGpuExecute,
+                         });
+
 }  // namespace
 
 extern "C" {
@@ -566,7 +679,7 @@ void** MosaicGpuCompile(const char* module) {
   if (!compiled.ok()) {
     return nullptr;
   }
-  auto [ctx, launch] = compiled->GetHostLaunch();
+  auto [ctx, launch, is_comm_used] = compiled->GetHostLaunch();
   auto tuple_ptr = std::unique_ptr<void*[]>(new void*[3]);
   if (!tuple_ptr) {
     return nullptr;
diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_comm.h b/jaxlib/mosaic/gpu/mosaic_gpu_comm.h
new file mode 100644
index 000000000000..b0bd94883e43
--- /dev/null
+++ b/jaxlib/mosaic/gpu/mosaic_gpu_comm.h
@@ -0,0 +1,86 @@
+/* Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_MOSAIC_GPU_COMM_H_
+#define JAXLIB_MOSAIC_GPU_COMM_H_
+
+#include <dlfcn.h>
+#include <cstdio>
+#include <mutex>
+
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+
+#define NVSHMEM_SUCCESS 0
+#define NVSHMEM_LIB_SONAME "libnvshmem_host.so.3"
+
+namespace mosaic {
+namespace gpu {
+
+#define NVSHMEM_SET_FN(FnName)                                          \
+  FnName = reinterpret_cast<decltype(FnName)>(dlsym(library, #FnName)); \
+  if (!FnName) {                                                        \
+    fprintf(stderr, #FnName " not available in this library.");         \
+    abort();                                                            \
+  }
+
+class NvshmemApi {
+ public:
+  // Returns the default NvshmemApi for the current process.
+  // NvshmemApi follows the Singleton design pattern.
+  static NvshmemApi& Default() {
+    static NvshmemApi instance;
+    return instance;
+  }
+
+  int cumodule_init(CUmodule module) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    return nvshmemx_cumodule_init(module);
+  }
+
+  void barrier_all_on_stream(cudaStream_t stream) {
+    nvshmemx_barrier_all_on_stream(stream);
+  }
+
+  NvshmemApi(NvshmemApi const&) = delete;
+  void operator=(NvshmemApi const&) = delete;
+
+ private:
+  NvshmemApi() {
+    const char* env_value = getenv("NVSHMEM_LIBRARY_PATH");
+    const char* libnvshmem_path =
+        env_value && *env_value != 0 ? env_value : NVSHMEM_LIB_SONAME;
+    void* library = dlopen(libnvshmem_path, RTLD_LAZY);
+    if (library == nullptr) {
+      fprintf(stderr, "Failed to open %s library: %s", libnvshmem_path, dlerror());
+      abort();
+    }
+
+    // Initialize supported NVSHMEM host API
+    NVSHMEM_SET_FN(nvshmemx_cumodule_init)
+    NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream)
+  }
+
+  // Dlopened NVSHMEM API
+  int (*nvshmemx_cumodule_init)(CUmodule);
+  int (*nvshmemx_barrier_all_on_stream)(cudaStream_t);
+
+  std::mutex mutex_;
+};
+
+}  // namespace gpu
+}  // namespace mosaic
+
+#endif  // JAXLIB_MOSAIC_GPU_COMM_H_
diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc
index ad3cd0e19644..6897bcf350df 100644
--- a/jaxlib/mosaic/gpu/runtime.cc
+++ b/jaxlib/mosaic/gpu/runtime.cc
@@ -17,8 +17,10 @@ limitations under the License.
 #include <stdio.h>
 #include <stdlib.h>
+#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h"
 #include "third_party/gpus/cuda/include/cuda.h"
+
 
 extern "C" {
 
 void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
@@ -154,6 +156,17 @@ void* mosaic_gpu_module_load(void *data) {
     fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr);
     abort();
   }
+
+  CUdeviceptr ptr = 0;
+  size_t size = 0;
+  // Check if the module contains NVSHMEM globals, implying NVSHMEM state needs to be set
+  if (cuModuleGetGlobal(&ptr, &size, module, "nvshmemi_device_lib_version_d") == CUDA_SUCCESS) {
+    if (mosaic::gpu::NvshmemApi::Default().cumodule_init(module) != NVSHMEM_SUCCESS) {
+      fprintf(stderr, "nvshmemx_cumodule_init failed.\n");
+      abort();
+    }
+  }
+
   return module;
 }

From ca36047ac91b4e1b5107cfa55da7a7cd6301716d Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 31 Mar 2025 15:14:47 -0700
Subject: [PATCH 0301/1769] __jax_array__: add support in jnp.reshape,
 jnp.transpose, jnp.matrix_transpose

---
 jax/_src/numpy/lax_numpy.py       | 8 ++++----
 tests/array_extensibility_test.py | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 63edaed0adeb..7b900e09068e 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -1203,8 +1203,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
     Array([[1, 3],
            [2, 4]], dtype=int32)
   """
-  util.check_arraylike("transpose", a)
-  axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes
+  a = util.ensure_arraylike("transpose", a)
+  axes_ = list(range(a.ndim)[::-1]) if axes is None else axes
   axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_]
   return lax.transpose(a, axes_)
 
@@ -1285,8 +1285,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
            [[5, 7],
             [6, 8]]], dtype=int32)
   """
-  util.check_arraylike("matrix_transpose", x)
-  ndim = np.ndim(x)
+  x = util.ensure_arraylike("matrix_transpose", x)
+  ndim = x.ndim
   if ndim < 2:
     raise ValueError(f"x must be at least two-dimensional for matrix_transpose; got {ndim=}")
   axes = (*range(ndim - 2), ndim - 1, ndim - 2)
diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py
index fae9129dd99a..63a8762cd0b0 100644
--- a/tests/array_extensibility_test.py
+++ b/tests/array_extensibility_test.py
@@ -375,7 +375,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct:
     NumPyAPI.sig(jnp.logical_or, Int[5], Int[5]),
     NumPyAPI.sig(jnp.logical_xor, Int[5], Int[5]),
     NumPyAPI.sig(jnp.matmul, Float[5, 5], Float[5]),
-    # NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]),
+    NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]),
     NumPyAPI.sig(jnp.matvec, Float[5, 5], Float[5]),
     NumPyAPI.sig(jnp.max, Float[5]),
NumPyAPI.sig(jnp.maximum, Float[5], Float[5]), @@ -442,7 +442,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.reciprocal, Float[5]), NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), # NumPyAPI.sig(jnp.repeat, Float[5], Int[5]), - # NumPyAPI.sig(jnp.reshape, Float[6], (2, 3)), + NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)), NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]), NumPyAPI.sig(jnp.rint, Float[5]), @@ -481,7 +481,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), # NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), NumPyAPI.sig(jnp.trace, Float[5, 5]), - # NumPyAPI.sig(jnp.transpose, Float[5, 6]), + NumPyAPI.sig(jnp.transpose, Float[5, 6]), NumPyAPI.sig(jnp.trapezoid, Float[5]), NumPyAPI.sig(jnp.tril, Float[5, 6]), NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), From f59f615f6f45c4524c9326daa800c95bead6cbf0 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 31 Mar 2025 15:54:01 -0700 Subject: [PATCH 0302/1769] Minor docstring updates for AOT wrappers in error checking PiperOrigin-RevId: 742431349 --- jax/_src/error_check.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index e78b9bc82115..b80def4fd2db 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -289,6 +289,11 @@ def wrap_for_export(f): function scope, making it possible to export the function and later import in other processes. + When the function is later imported, it must be wrapped with + :func:`unwrap_from_import` to integrate the error checking mechanism of the + imported function into the global error checking mechanism of the current + process. + This function should only be applied once to a function; wrapping the same function multiple times is unnecessary. """ @@ -327,6 +332,9 @@ def unwrap_from_import(f): separate from the global error state of the current process. This wrapper ensures that errors detected during execution are correctly integrated into the global error checking mechanism of the current process. + + This function should only be applied to functions that were previously wrapped + with :func:`wrap_for_export` before export. """ if _error_storage.ref is None: with core.eval_context(): From 4003e2d0eec70ab0b5ed4e5c8bad8a1148a2efd8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 27 Mar 2025 13:29:34 -0700 Subject: [PATCH 0303/1769] jnp.power: support __jax_array__ on inputs --- jax/_src/numpy/ufuncs.py | 5 +++++ tests/array_extensibility_test.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index e561b7ae71b6..3902b24b35ac 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2652,6 +2652,11 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: [nan, 27., 1.]], dtype=float32) """ check_arraylike("power", x1, x2) + + # Must do __jax_array__ conversion prior to dtype check. 
+ x1 = x1.__jax_array__() if hasattr(x1, "__jax_array__") else x1 + x2 = x2.__jax_array__() if hasattr(x2, "__jax_array__") else x2 + check_no_float0s("power", x1, x2) # We apply special cases, both for algorithmic and autodiff reasons: diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 63a8762cd0b0..45847b6f0f29 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -427,8 +427,8 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.polysub, Float[5], Float[5]), NumPyAPI.sig(jnp.polyval, Float[5], Float[10]), NumPyAPI.sig(jnp.positive, Float[5]), - # NumPyAPI.sig(jnp.pow, Float[5], Float[5]), - # NumPyAPI.sig(jnp.power, Float[5], Float[5]), + NumPyAPI.sig(jnp.pow, Float[5], Float[5]), + NumPyAPI.sig(jnp.power, Float[5], Float[5]), NumPyAPI.sig(jnp.prod, Float[5]), NumPyAPI.sig(jnp.ptp, Float[5]), NumPyAPI.sig(jnp.put, Float[5], Int[()], Float[()], inplace=False), From 994af3efb85339b69112b9e75c9975d24d90d8b3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 31 Mar 2025 17:37:13 -0700 Subject: [PATCH 0304/1769] [Pallas TPU] Remove forward compatibility code for float -> signed conversions This will be submitted automatically once the compatibility window has passed PiperOrigin-RevId: 742464046 --- jax/_src/pallas/mosaic/lowering.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 617324d43bf9..1139630ae602 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2096,16 +2096,6 @@ def _convert_helper(x, *, to_dtype): # unsigned -> float is unsupported. We fall through and raise at the bottom. if not jnp.issubdtype(to_dtype, jnp.floating): return x.astype(to_dtype) - if jnp.issubdtype(from_dtype, jnp.floating) and jnp.issubdtype( - to_dtype, jnp.signedinteger - ): - if from_dtype.itemsize < 4: - x = x.astype(jnp.float32) - if to_dtype.itemsize < 4: - # Need to clip values to match XLA - minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max - x = jnp.clip(x, minval, maxval) - return x.astype(jnp.int32).astype(to_dtype) raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") def _convert_element_type_lowering_rule( @@ -2149,10 +2139,7 @@ def _convert_element_type_lowering_rule( return x # TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer. elif _from(floating) and _to(signed): - # TODO(apaszke): Remove once a month has passed, along with the - # _convert_helper float -> signed conversion above. - if not ctx.forward_compatible or both_32bit: - return arith.fptosi(out_type, x) + return arith.fptosi(out_type, x) elif _from(signed) and _to(floating) and both_32bit: return arith.sitofp(out_type, x) elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4: From 006a6a63feb64bf9984526030ba008186d69d2b4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 31 Mar 2025 22:01:48 -0700 Subject: [PATCH 0305/1769] [Easy] Make pallas mesh grid handling more resilient to tuple names. 
PiperOrigin-RevId: 742531956 --- jax/_src/lax/parallel.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index ebc6255cb66b..39b6c68679ca 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1861,8 +1861,8 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - effect = {core.NamedAxisEffect(axis_name)} axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + effect = {core.NamedAxisEffect(axis_name)} _check_axis_names(axis_name) mesh = get_abstract_mesh() sharding = NamedSharding(mesh, P()) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 1139630ae602..f8f49f3d7aea 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -517,11 +517,17 @@ def has_communication(self) -> bool: nonlocal_axis_names = set() def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr): return { - e.name - for e in jaxpr.effects - if isinstance(e, jax_core.NamedAxisEffect) - and (not self.grid_names or e.name not in self.grid_names) - } + e.name + for e in jaxpr.effects + if isinstance(e, jax_core.NamedAxisEffect) + and ( + not self.grid_names + or all( + name not in self.grid_names + for name in tree_util.tree_leaves(e.name) + ) + ) + } nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr)) for bm in self.block_mappings: if bm is not None: From 6adb7289754edff320a670630dec12fc697100ab Mon Sep 17 00:00:00 2001 From: Louis-Justin TALLOT <72044417+LouisJustinTALLOT@users.noreply.github.com> Date: Tue, 1 Apr 2025 02:46:30 -0400 Subject: [PATCH 0306/1769] Clarify documentation of jnp.heaviside --- jax/_src/numpy/ufuncs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 3902b24b35ac..60e10b3be048 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -3640,9 +3640,9 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: .. math:: \mathrm{heaviside}(x1, x2) = \begin{cases} - 0., & x < 0\\ - x2, & x = 0\\ - 1., & x > 0. + 0, & x1 < 0\\ + x2, & x1 = 0\\ + 1, & x1 > 0. \end{cases} Args: From 5d1bc005a00546ece0172d01bd3434f5026c80c6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 1 Apr 2025 05:02:28 -0700 Subject: [PATCH 0307/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b1971cc2b3407e87fada2674a057d72897b79acc. PiperOrigin-RevId: 742646393 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d078359af86a..13223c4a4b88 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "f4a53456b04acf9b63b3b30bd828cec29c4aa7de" -XLA_SHA256 = "2ee32b70af547fd13ce404d75c3fa9834bc8be46a488cd8f0caa10e9a6ec7ede" +XLA_COMMIT = "b1971cc2b3407e87fada2674a057d72897b79acc" +XLA_SHA256 = "3b2feabbcd6adc5721533edfbe3dc2ad6517cb1b059cf41dea63f62874bff12d" def repo(): tf_http_archive( From 40a3d0c78dad1d539180a4f830a3e1c17460ced0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 1 Apr 2025 09:11:08 -0700 Subject: [PATCH 0308/1769] Create the test targets for the wheel size verification. Add the tests to the Bazel presubmit RBE jobs (except `arm64`/`aarch64` jobs that use RBE cross-compilation). PiperOrigin-RevId: 742724458 --- BUILD.bazel | 16 ++++++++++ ci/run_bazel_test_cpu_rbe.sh | 4 ++- ci/run_bazel_test_cuda_rbe.sh | 8 ++++- jaxlib/tools/BUILD.bazel | 48 ++++++++++++++++++++++++++++ jaxlib/tools/wheel_size_test.py | 56 +++++++++++++++++++++++++++++++++ 5 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 jaxlib/tools/wheel_size_test.py diff --git a/BUILD.bazel b/BUILD.bazel index 2c10f0d9a748..8dbf2bed0902 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -22,6 +22,7 @@ load( "jax_source_package", "jax_wheel", "py_deps", + "pytype_test", ) collect_data_files( @@ -152,3 +153,18 @@ py_import( wheel_deps = [":wheel_additives"], deps = COMMON_DEPS, ) + +pytype_test( + name = "jax_wheel_size_test", + srcs = ["//jaxlib/tools:wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_wheel)", + "--max-size-mib=5", + ], + data = [":jax_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 248111e0247a..d8cb190079e0 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -64,5 +64,7 @@ else --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + //tests:cpu_tests //tests:backend_independent_tests \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test fi \ No newline at end of file diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 17bd8d9db4f8..94c6a89fdb8c 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -48,4 +48,10 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ - //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file + --@local_config_cuda//cuda:override_include_cuda_libs=true \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + //jaxlib/tools:jax_cuda_plugin_wheel_size_test \ + //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test \ No newline at end of file diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 2ddc9e90a702..79a1f7e7089d 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -33,12 +33,15 @@ load( "jax_py_test", "jax_wheel", "pytype_strict_library", + "pytype_test", ) licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +exports_files(["wheel_size_test.py"]) + genrule( name = "platform_tags_py", srcs = [], @@ -389,3 +392,48 @@ verify_manylinux_compliance_test( wheel = ":jax_cuda_pjrt_wheel", x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) + 
+pytype_test(
+    name = "jaxlib_wheel_size_test",
+    srcs = [":wheel_size_test.py"],
+    args = [
+        "--wheel-path=$(location :jaxlib_wheel)",
+        "--max-size-mib=110",
+    ],
+    data = [":jaxlib_wheel"],
+    main = "wheel_size_test.py",
+    tags = [
+        "manual",
+        "notap",
+    ],
+)
+
+pytype_test(
+    name = "jax_cuda_plugin_wheel_size_test",
+    srcs = [":wheel_size_test.py"],
+    args = [
+        "--wheel-path=$(location :jax_cuda_plugin_wheel)",
+        "--max-size-mib=20",
+    ],
+    data = [":jax_cuda_plugin_wheel"],
+    main = "wheel_size_test.py",
+    tags = [
+        "manual",
+        "notap",
+    ],
+)
+
+pytype_test(
+    name = "jax_cuda_pjrt_wheel_size_test",
+    srcs = [":wheel_size_test.py"],
+    args = [
+        "--wheel-path=$(location :jax_cuda_pjrt_wheel)",
+        "--max-size-mib=120",
+    ],
+    data = [":jax_cuda_pjrt_wheel"],
+    main = "wheel_size_test.py",
+    tags = [
+        "manual",
+        "notap",
+    ],
+)
diff --git a/jaxlib/tools/wheel_size_test.py b/jaxlib/tools/wheel_size_test.py
new file mode 100644
index 000000000000..7e9c08ff9797
--- /dev/null
+++ b/jaxlib/tools/wheel_size_test.py
@@ -0,0 +1,56 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+
+
+def parse_args():
+  """Arguments parser."""
+  parser = argparse.ArgumentParser(
+      description="Helper for the wheel size verification",
+      fromfile_prefix_chars="@",
+  )
+  parser.add_argument(
+      "--wheel-path", required=True, help="Path of the wheel, mandatory"
+  )
+  parser.add_argument(
+      "--max-size-mib",
+      required=True,
+      help="Maximum size of the wheel in MiB",
+  )
+  return parser.parse_args()
+
+
+def verify_wheel_size(args):
+  wheel_size_mib = os.path.getsize(args.wheel_path) >> 20
+  wheel_name = os.path.basename(args.wheel_path)
+  if wheel_size_mib > int(args.max_size_mib):
+    raise RuntimeError(
+        "The {name} size is {size} MiB, which is larger than the maximum size"
+        " {max_size} MiB".format(
+            name=wheel_name,
+            size=wheel_size_mib,
+            max_size=args.max_size_mib,
+        )
+    )
+  else:
+    logging.info(
+        "The %s size is %s MiB, which is less than the maximum size"
+        " %s MiB", wheel_name, wheel_size_mib, args.max_size_mib)
+
+
+if __name__ == "__main__":
+  verify_wheel_size(parse_args())
From 76271d638ad94f3df854054640f3b35161ee5be4 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 1 Apr 2025 09:50:00 -0700
Subject: [PATCH 0309/1769] Add scan_p and cond_p vma rule.
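
For context, a sketch of the carry mismatch the new scan rule reports
(hypothetical repro, not taken from this change; assumes two devices): the
carry enters the scan unvarying, but the body mixes in an 'x'-varying value,
so the carry's varying manual axes differ between input and output:

```
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((2,), ('x',))

def body(carry, x):
  return carry + x, carry  # carry becomes 'x'-varying after one step

def f(xs):
  return jax.lax.scan(body, jnp.float32(0.), xs)[1]

g = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
# g(jnp.arange(4, dtype=jnp.float32))  # expected to raise the mismatch error
```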
PiperOrigin-RevId: 742737384 --- jax/_src/core.py | 7 ++++--- jax/_src/lax/control_flow/conditionals.py | 9 +++++++++ jax/_src/lax/control_flow/loops.py | 12 ++++++++++-- jax/_src/state/types.py | 9 +++++++++ jax/experimental/shard_map.py | 4 ++-- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ee6537650f20..ae94782ce98a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -574,7 +574,7 @@ def read(v: Atom) -> Any: def write(v: Var, val: Any) -> None: if config.enable_checks.value and not config.dynamic_shapes.value: - assert typecheck(v.aval, val), (v.aval, val) + assert typecheck(v.aval, val), (v.aval, get_aval(val)) env[v] = val env: dict[Var, Any] = {} @@ -2594,7 +2594,7 @@ def _map_shaped_array( if axis is None: return aval sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, vma=aval.vma) def _unmap_shaped_array( size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray @@ -2604,7 +2604,8 @@ def _unmap_shaped_array( sharding = aval.sharding.with_spec(tuple_insert( aval.sharding.spec, axis, explicit_mesh_axis)) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, + vma=aval.vma) else: raise TypeError(axis) def _map_dshaped_array( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 63896cc2a0bf..b0e1221752bd 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -347,6 +347,15 @@ def _cond_abstract_eval(*avals: core.AbstractValue, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') + b0_vma = [o.vma for o in branches[0].out_avals] + for branch in branches[1:]: + b_vma = [o.vma for o in branch.out_avals] + if b0_vma != b_vma: + raise Exception("The branches of cond produced mismatched varying manual " + f"axes. Got {b0_vma} and {b_vma}. Please open an issue " + "at https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_rep=False argument " + "to shard_map") return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 56323949a607..9a66dd037d3a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -570,9 +570,17 @@ def _prepend_dim_to_aval(sz, aval): def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + out_carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + _, in_carry_avals, _ = split_list(args, [num_consts, num_carry]) + if [i.vma for i in in_carry_avals] != [o.vma for o in out_carry_avals]: + raise ValueError( + 'Scan carry input and output got mismatched varying manual axes ' + f'{in_carry_avals} and {out_carry_avals}. 
Please open an ' + 'issue at https://github.com/jax-ml/jax/issues, and as a ' + 'temporary workaround pass the check_rep=False argument to ' + 'shard_map') ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) - return carry_avals + ys_avals, jaxpr.effects + return out_carry_avals + ys_avals, jaxpr.effects def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e926e3a35f80..b9dbaf35c5d2 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -366,6 +366,15 @@ def sharding(self): f"`Ref{{{self.inner_aval.str_short()}}} has no `sharding`." ) from None + @property + def vma(self): + try: + return self.inner_aval.vma # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `vma`." + ) from None + @core.aval_property def at(self): return RefIndexer(self) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 8e2d93af2639..4b9daf170dce 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1365,7 +1365,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry]) out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep) carry_rep_out, _ = split_list(out_rep, [num_carry]) - if not carry_rep_in == carry_rep_out: + if carry_rep_in != carry_rep_out: raise Exception("Scan carry input and output got mismatched replication " f"types {carry_rep_in} and {carry_rep_out}. Please open an " "issue at https://github.com/jax-ml/jax/issues, and as a " @@ -1403,7 +1403,7 @@ def _cond_rule(mesh, *in_rep, branches): out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) for branch in branches[1:]: out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) - if not out_rep_ == out_rep: + if out_rep_ != out_rep: raise Exception("The branches of cond produced mismatched replication " "types. Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a " From a34c4628755877f431d075cde50d73ee33158b34 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 09:53:29 -0700 Subject: [PATCH 0310/1769] jnp.select: support __jax_array__ for inputs --- jax/_src/numpy/lax_numpy.py | 6 ++++++ tests/array_extensibility_test.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7b900e09068e..a47f66e5f621 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2907,6 +2907,12 @@ def select( raise ValueError(msg.format(len(condlist), len(choicelist))) if len(condlist) == 0: raise ValueError("condlist must be non-empty") + + util.check_arraylike("select", *condlist, *choicelist, default) + condlist = [asarray(cond) for cond in condlist] + choicelist = [asarray(choice) for choice in choicelist] + default = asarray(default) + # Put the default at front with condition False because # argmax returns zero for an array of False values. 
choicelist = util.promote_dtypes(default, *choicelist) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 45847b6f0f29..f62491a608c7 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -452,7 +452,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.rot90, Float[5, 3]), NumPyAPI.sig(jnp.round, Float[5]), NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), - # NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[5]), + NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[()]), NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), NumPyAPI.sig(jnp.shape, Float[5, 3]), From a80f6279e9eba6ec0aa1fc2b37e979f883768c31 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 1 Apr 2025 00:17:30 +0000 Subject: [PATCH 0311/1769] make random_gamma_grad not a primitive anymore Fixes #16076 Co-authored-by: Roy Frostig --- jax/_src/checkify.py | 2 +- jax/_src/lax/special.py | 47 ++++++++++++++++------------------- jax/extend/core/primitives.py | 1 - jax/lax/__init__.py | 1 - 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index f80a0cbd1d75..f0abf53b0717 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -600,7 +600,7 @@ def isnan(x): lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, - lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_p, lax.reduce_prod_p, lax.reduce_sum_p, lax.reduce_window_p, lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index 041205156d58..a59d62523c9f 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -38,6 +38,25 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.typing import Array, ArrayLike +# TODO(mattjj): this function sucks, delete it +def _up_and_broadcast(doit): + def up_and_broadcast(*args): + broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) + args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] + + a_dtype = args[0].dtype + needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 + if needs_upcast: + args = [convert_element_type(a, np.float32) for a in args] + a_x_type = np.float32 + else: + a_x_type = a_dtype + result = doit(*args, dtype=a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + return up_and_broadcast + def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" a, b, x = core.standard_insert_pbroadcast(a, b, x) @@ -71,10 +90,11 @@ def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: a, x = core.standard_insert_pbroadcast(a, x) return igamma_grad_a_p.bind(a, x) -def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: +@_up_and_broadcast +def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" a, x = core.standard_insert_pbroadcast(a, x) - return random_gamma_grad_p.bind(a, x) + return random_gamma_grad_impl(a, x, dtype=dtype) def zeta(x: ArrayLike, q: ArrayLike) -> Array: r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" @@ -531,24 +551,6 @@ def random_gamma_grad_impl(a, x, *, 
dtype): full_like(a, float('nan')), output) return output -def _up_and_broadcast(doit): - def up_and_broadcast(*args): - broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) - args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] - - a_dtype = args[0].dtype - needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 - if needs_upcast: - args = [convert_element_type(a, np.float32) for a in args] - a_x_type = np.float32 - else: - a_x_type = a_dtype - result = doit(*args, dtype=a_x_type) - if needs_upcast: - result = convert_element_type(result, a_dtype) - return result - return up_and_broadcast - def evaluate_chebyshev_polynomial(x, coefficients): b0 = full_like(x,0) @@ -694,11 +696,6 @@ def bessel_i0e_impl(x): ad.defjvp(igammac_p, igammac_grada, igammac_gradx) -random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') -mlir.register_lowering(random_gamma_grad_p, - mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), - multiple_results=False)) - zeta_p = standard_naryop([_float, _float], 'zeta') mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta)) diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index d8a10154cf4a..60d8cd24a949 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -149,7 +149,6 @@ igamma_p as igamma_p, lgamma_p as lgamma_p, polygamma_p as polygamma_p, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta_p as zeta_p, ) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 6f2163c424a6..43c4cf17e559 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -261,7 +261,6 @@ polygamma as polygamma, polygamma_p as polygamma_p, random_gamma_grad as random_gamma_grad, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta as zeta, zeta_p as zeta_p, From 2d2be0bbb922c2571d433eeaeb5209c043334e97 Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Tue, 1 Apr 2025 10:44:52 -0700 Subject: [PATCH 0312/1769] Update permisisons community_release_actions.yml --- .github/workflows/community_release_actions.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml index 1980e803ba9b..1110cbad9475 100644 --- a/.github/workflows/community_release_actions.yml +++ b/.github/workflows/community_release_actions.yml @@ -4,6 +4,9 @@ on: release: types: [published] +permissions: + contents: read + jobs: discord_release: if: github.repository_owner == 'jax-ml' From 5370ac2ec59c1acb347eb68771beec2487c8de64 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Tue, 1 Apr 2025 11:33:00 -0700 Subject: [PATCH 0313/1769] Remove the try/except for Shardy imports. Shardy has been been included in JAX for a while now. PiperOrigin-RevId: 742778405 --- jax/_src/interpreters/mlir.py | 4 +--- jax/_src/lib/mlir/dialects/__init__.py | 6 +----- jax/extend/mlir/dialects/sdy.py | 6 +----- tests/pjit_test.py | 7 ------- 4 files changed, 3 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 23d1b5dd9d89..a1b37876f87e 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -616,9 +616,7 @@ def make_ir_context() -> ir.Context: # we don't do any heavy computation on MLIR modules from Python anyway, so we # just disable threading. 
context.enable_multithreading(False) - # TODO(bartchr): Once JAX is released with SDY, remove the if. - if dialects.sdy: - dialects.sdy.register_dialect(context) + dialects.sdy.register_dialect(context) dialects.mhlo.register_mhlo_dialect(context) dialects.chlo.register_dialect(context) dialects.hlo.register_dialect(context) diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index a9bae8821db5..be5317824c36 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -51,11 +51,7 @@ ]) del _lazy -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects import sdy as sdy -except ImportError: - sdy: Any = None # type: ignore[no-redef] +from jaxlib.mlir.dialects import sdy # Alias that is set up to abstract away the transition from MHLO to StableHLO. from jaxlib.mlir.dialects import stablehlo as hlo diff --git a/jax/extend/mlir/dialects/sdy.py b/jax/extend/mlir/dialects/sdy.py index 48586cc26760..d83fd90ecdf4 100644 --- a/jax/extend/mlir/dialects/sdy.py +++ b/jax/extend/mlir/dialects/sdy.py @@ -14,8 +14,4 @@ # ruff: noqa: F403 -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects.sdy import * -except ImportError: - pass +from jaxlib.mlir.dialects.sdy import * diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0b2daee8ccff..ee4a8cd3e15e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -59,7 +59,6 @@ from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType from jax._src.interpreters import pxla -from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension @@ -8067,12 +8066,6 @@ def f(x, y): @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyTest(jtu.JaxTestCase): - # TODO(bartchr): Once JAX is released with SDY, remove setUp. - def setUp(self): - if not dialects.sdy: - raise unittest.SkipTest('Shardy is not available.') - super().setUp() - def test_lowering_input_output_sharding(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) From 0b199f48c7e0d4e5837cee34ced7f3fc7065732f Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 1 Apr 2025 11:49:03 -0700 Subject: [PATCH 0314/1769] [jaxlib] Roll back subbyte types due to failing asan tests. Reverts 12526ea11646a75fac201e26c1a2e901f94a4c76 PiperOrigin-RevId: 742784183 --- jaxlib/cuda/BUILD | 1 - jaxlib/gpu/py_client_gpu.cc | 89 ++++++++++++--------------------- jaxlib/rocm/BUILD | 1 - jaxlib/xla/BUILD | 1 - jaxlib/xla/py_client_cpu.cc | 81 ++++++++++-------------------- tests/python_callback_test.py | 94 +++++++++++++++-------------------- 6 files changed, 99 insertions(+), 168 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index d35e421ef904..fac62c81dee7 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -689,7 +689,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 38f2ac1896e7..861ffce3e749 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -43,7 +43,6 @@ limitations under the License. 
#include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -81,14 +80,13 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < arity; ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - if (ptype == xla::S1 || ptype == xla::U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); } - if (ptype == xla::TOKEN) { host_input_buffers[i] = nullptr; continue; @@ -114,6 +112,9 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -121,22 +122,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - // We pass in data using default numpy layout i.e., std::nullopt. - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - auto buffer = xla::UnpackIntN( - bits_per_element, static_cast(host_input_buffers[i]), - arg->size_bytes()); - delete[] static_cast(host_input_buffers[i]); - host_input_buffers[i] = buffer.release(); - } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, - host_input_buffers[i], /*base=*/base); + host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); } @@ -159,7 +146,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. 
- if (ptype == xla::S1 || ptype == xla::U1) { + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -180,43 +168,32 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - - const void* data = array.data(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - plan->Execute(data, temp); - data = temp; + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + continue; } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. 
- buffer = xla::PackIntN(bits_per_element, static_cast(data), - ret->size_bytes()); - data = buffer.get(); + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); } - - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, ret->size_bytes(), + auto plan = maybe_plan.value(); + plan->Execute(array.data(), temp); + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 358a6d1cc9aa..d0c0c798abb8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -588,7 +588,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 5b532c1dc501..2ca18afda13d 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -637,7 +637,6 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index fc4f895af6aa..ac4e7bee5680 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -79,9 +78,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < args.size(); ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - if (ptype == S1 || ptype == U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -97,18 +96,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. 
- std::unique_ptr buffer; - const void* data = arg->untyped_data(); - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - buffer = xla::UnpackIntN(bits_per_element, static_cast(data), - arg->size_bytes()); - data = buffer.get(); - } // We pass in data using default numpy layout i.e., std::nullopt. - auto array = nb_numpy_ndarray(dtype, dims, std::nullopt, data); + auto array = + nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -129,9 +119,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - if (ptype == S1 || ptype == U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -151,45 +141,26 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - - const void* data = array.data(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - plan->Execute(data, ret->untyped_data()); - data = ret->untyped_data(); - } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. - buffer = xla::PackIntN(bits_per_element, static_cast(data), - ret->size_bytes()); - data = buffer.get(); + if (strides == expected_strides) { + std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); + continue; } - - // Copy data to output buffer if haven't already or modified the data to - // write back. 
- if (data != ret->untyped_data()) { - std::memcpy(ret->untyped_data(), data, ret->size_bytes()); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions_size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 34ab20c05644..a8442b4a1356 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,15 +586,10 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(x): return x def f(x): @@ -605,17 +600,21 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)(x) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(): return np.arange(8, dtype=dtype) @@ -626,43 +625,16 @@ def f(): ) return y - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") - def test_non_default_stride_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." 
- ) - x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)() class PureCallbackTest(jtu.JaxTestCase): @@ -1136,6 +1108,20 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class IOCallbackTest(jtu.JaxTestCase): From 7b04a79fbdc0fe7b75e44a77cae8ed7a003a6821 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 11:25:32 -0700 Subject: [PATCH 0315/1769] jnp.einsum: add support for __jax_array__ --- jax/_src/numpy/einsum.py | 4 ++++ tests/array_extensibility_test.py | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 9d745643b596..21333a9e7a0d 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -288,6 +288,10 @@ def einsum( spec = operands[0] if isinstance(operands[0], str) else None path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize + # Extract __jax_array__ before passing to contract_path() + operands = tuple(op.__jax_array__() if hasattr(op, "__jax_array__") else op + for op in operands) + # Allow handling of shape polymorphism non_constant_dim_types = { type(d) for op in operands if not isinstance(op, str) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index f62491a608c7..69e9e1609f86 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -548,6 +548,29 @@ def test_array_creation_from_duck_typed_array(self, func): self.assertEqual(result.shape, obj.shape) self.assertEqual(result.dtype, obj.dtype) + @parameterized.named_parameters( + {"testcase_name": "subscript-form", "args": ("jk,k->j", Float[5, 3], Float[3])}, + {"testcase_name": "index-form", "args": (Float[5, 3], (0, 1), Float[3], (1,), (0,))}, + ) + def test_einsum(self, args): + rng = jtu.rand_default(self.rng()) + def make_arg(arg): + if isinstance(arg, jax.ShapeDtypeStruct): + return rng(arg.shape, arg.dtype) + return arg + args = jax.tree.map(make_arg, args) + + def wrap_array(arg): + if isinstance(arg, (jax.Array, np.ndarray)): + return JaxArrayWrapper(arg) + return arg + wrapped_args = jax.tree.map(wrap_array, args) + + expected = jnp.einsum(*args) + actual = jnp.einsum(*wrapped_args) + + self.assertAllClose(actual, expected, 
atol=0, rtol=0) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 4908b2f167a78783e95c1d677a849d0a98a97dc4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 10:05:42 -0700 Subject: [PATCH 0316/1769] cumulative reductions: support __jax_array__ on inputs --- jax/_src/numpy/reductions.py | 10 ++++------ tests/array_extensibility_test.py | 8 ++++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 985b296bc06f..96b2782edc13 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -30,7 +30,7 @@ from jax._src import deprecations from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, check_arraylike, _complex_elem_type, + _broadcast_to, check_arraylike, _complex_elem_type, ensure_arraylike, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg @@ -1992,7 +1992,7 @@ def _cumulative_reduction( fill_nan: bool = False, fill_value: ArrayLike = 0, promote_integers: bool = False) -> Array: """Helper function for implementing cumulative reductions.""" - check_arraylike(name, a) + a = ensure_arraylike(name, a) if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported") dtypes.check_user_dtype_supported(dtype, name) @@ -2242,8 +2242,7 @@ def cumulative_sum( Array([[ 0, 1, 3, 6], [ 0, 4, 9, 15]], dtype=int32) """ - check_arraylike("cumulative_sum", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_sum", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative sum, however a " @@ -2304,8 +2303,7 @@ def cumulative_prod( Array([[ 1, 1, 2, 6], [ 1, 4, 20, 120]], dtype=int32) """ - check_arraylike("cumulative_prod", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_prod", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative product, however a " diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 69e9e1609f86..8f5ea33b5894 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -283,10 +283,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: # NumPyAPI.sig(np.count_nonzero, [float], [(10,)]), # NumPyAPI.sig(np.cov, [float], [(10,)]), # NumPyAPI.sig(np.cross, [float, float], [(3,), (3,)]), - # NumPyAPI.sig(np.cumprod, [float], [(10,)]), - # NumPyAPI.sig(np.cumsum, [float], [(10,)]), - # NumPyAPI.sig(np.cumulative_prod, [float], [(10,)]), - # NumPyAPI.sig(np.cumulative_sum, [float], [(10,)]), + NumPyAPI.sig(jnp.cumprod, Float[5]), + NumPyAPI.sig(jnp.cumsum, Float[5]), + NumPyAPI.sig(jnp.cumulative_prod, Float[5]), + NumPyAPI.sig(jnp.cumulative_sum, Float[5]), NumPyAPI.sig(jnp.deg2rad, Float[5]), NumPyAPI.sig(jnp.degrees, Float[5]), # NumPyAPI.sig(jnp.delete, Float[5], Int[()]), From 05269a8ec90a1e14f89d514a0f4b228525bf906c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 1 Apr 2025 20:18:32 +0000 Subject: [PATCH 0317/1769] [mutable-arrays] add vmap rule for mutable_array_p, very basic test --- jax/_src/interpreters/batching.py | 5 +++++ tests/mutable_array_test.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 03c9a95105d7..a187d42511ac 100644 --- a/jax/_src/interpreters/batching.py +++ 
b/jax/_src/interpreters/batching.py @@ -1169,3 +1169,8 @@ def add_batched(batched_args, batch_dims): x = moveaxis(x, bdx, bdy) return add_jaxvals(x, y), bdy primitive_batchers[add_jaxvals_p] = add_batched + + +### mutable arrays + +defvectorized(core.mutable_array_p) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 950bddf544d7..a51e1d7841ce 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -239,6 +239,16 @@ def f(x_ref): x_ref = core.mutable_array(x) y = f(x_ref) + def test_vmap_basic(self): + @jax.vmap + def f(x): + x_ref = core.mutable_array(x) + x_ref[...] = x_ref[...] * x_ref[...] + return x_ref[...] + xs = jnp.arange(4.) + ys = f(xs) + self.assertAllClose(ys, xs ** 2, check_dtypes=False) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From ff5a2e8c91c3e32db6a547326d1356023226f83c Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 1 Apr 2025 14:25:40 -0700 Subject: [PATCH 0318/1769] Enable test_scan_offload in memories_test. PiperOrigin-RevId: 742840628 --- tests/memories_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 64ee2829873d..570b0c375834 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1090,7 +1090,6 @@ def f_bwd(res, tx): self.assertArraysEqual(g(arr), all_true) def test_scan_offload(self): - self.skipTest('b/406586554') np_inp = jnp.arange(4096).reshape(16, 16, 16) @jax.jit From f13919220118aacd7eea9d7794a6217eea83066d Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 1 Apr 2025 19:17:02 -0700 Subject: [PATCH 0319/1769] Add OOB checks to jax.numpy array indexing PiperOrigin-RevId: 742927160 --- jax/_src/numpy/error.py | 53 ++++++++++++++++++++++++++++++++++- jax/_src/numpy/indexing.py | 17 +++++++---- tests/jax_numpy_error_test.py | 36 ++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py index 20dab289d779..e2c23b43bdf8 100644 --- a/jax/_src/numpy/error.py +++ b/jax/_src/numpy/error.py @@ -13,10 +13,11 @@ # limitations under the License. import contextlib -from typing import Literal +from typing import Literal, Sequence import jax from jax._src import config +from jax._src.typing import ArrayLike Category = Literal["nan", "divide", "oob"] @@ -102,6 +103,56 @@ def _set_error_if_divide_by_zero(pred: jax.Array, /): error_check_lib.set_error_if(pred == zero, "Division by zero encountered") +def _check_precondition_oob_gather( + shape: tuple[int, ...], gather_indices: ArrayLike +) -> None: + """Check for out of bounds errors before calling `lax.gather`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.min(gather_indices) < -shape, + jnp.max(gather_indices) >= shape, + ), + "Out of bounds encountered before calling `lax.gather`", + ) + + +def _check_precondition_oob_dynamic_slice( + shape: tuple[int, ...], + start_indices: Sequence[ArrayLike], + slice_sizes: list[int], + allow_negative_indices: list[bool], +) -> None: + """Check for out of bounds errors before calling `lax.dynamic_slice`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. 
+ from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + start_indices = jnp.array(start_indices, dtype=jnp.int32) + slice_sizes = jnp.array(slice_sizes, dtype=jnp.int32) + allow_negative_indices = jnp.array(allow_negative_indices, dtype=jnp.bool_) + + lower_bound = jnp.where(allow_negative_indices, -shape, 0) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.minimum(start_indices, start_indices + slice_sizes) < lower_bound, + jnp.maximum(start_indices, start_indices + slice_sizes) >= shape, + ), + "Out of bounds encountered before calling `lax.dynamic_slice`", + ) + + Behavior = Literal["ignore", "raise"] diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 5d59bb53b457..863f0c775ec6 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -20,8 +20,6 @@ import string from typing import Any, NamedTuple, Sequence -import numpy as np - import jax from jax import lax from jax._src import array @@ -30,17 +28,19 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import errors +from jax._src import mesh as mesh_lib from jax._src.api import jit from jax._src.lax import lax as lax_internal from jax._src.numpy import einsum -from jax._src import mesh as mesh_lib -from jax._src.pjit import auto_axes +from jax._src.numpy import error as jnp_error from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.pjit import auto_axes from jax._src.tree_util import tree_flatten from jax._src.typing import Array, ArrayLike, StaticScalar -from jax._src.util import canonicalize_axis, set_module, tuple_replace, safe_zip +from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_replace +import numpy as np export = set_module('jax.numpy') @@ -570,7 +570,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> idx += (arr.ndim - len(idx)) * (slice(None),) start_indices: Sequence[ArrayLike] = [] - slice_sizes: Sequence[int] = [] + slice_sizes: list[int] = [] allow_negative_indices: list[bool] = [] for ind, size in safe_zip(idx, arr.shape): @@ -587,6 +587,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> slice_sizes.append(1) allow_negative_indices.append( not isinstance(ind, (int, np.integer)) or bool(ind < 0)) + # Try to use static slicing when possible. if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices): int_start_indices = [int(i) for i in start_indices] # type: ignore @@ -598,6 +599,9 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> # start indices to have matching types. 
if len(start_indices) > 1: start_indices = util.promote_dtypes(*start_indices) + jnp_error._check_precondition_oob_dynamic_slice( + arr.shape, start_indices, slice_sizes, allow_negative_indices + ) arr = lax.dynamic_slice( arr, start_indices=start_indices, slice_sizes=slice_sizes, allow_negative_indices=allow_negative_indices) @@ -640,6 +644,7 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding): idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update + jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr if fill_value is not None: diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index f2262d8b5dc0..a38e7d5509f9 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -231,6 +231,42 @@ def test_can_raise_divide_by_zero_error(self, jit, div_func, dtype): with self.assertRaisesRegex(JaxValueError, "Division by zero"): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_can_raise_oob_error_take(self, jit): + def f(x, a): + return x[a] + + if jit: + f = jax.jit(f) + + x = jnp.arange(10) + a = jnp.int32(10) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + def test_can_raise_oob_error_dynamic_slice(self): + def f(x, a): + return x[:, a:a+4] # dynamic indices are non-jittable + + x = jnp.arange(10).reshape(2, 5) + a = jnp.array(3, dtype=jnp.int32) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 1875c76bd2944f64967e9c9b7989233502d8da95 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 1 Apr 2025 19:10:39 -0700 Subject: [PATCH 0320/1769] let XLA metadata be unset in nested dynamic scopes Treat `None` metadata values as a special instruction not to set (or to unset, if nested) the corresponding entry. In particular, this makes it possible to unset metadata within the sub-computations of higher-order operations (e.g. branches in conditionals, loop bodies, etc.). This can be used, for example, to annotate a conditional but not all the operations in its branches. That is, the HLO for the following function `f` on a scalar float argument: ``` def cos(x): with set_xla_metadata(a=None): return jnp.cos(x) @jax.jit def f(x): with set_xla_metadata(a="b"): return jax.lax.cond(x < 0., jnp.sin, cos, x) ``` produces an attribute `a` on the conditional and on the sine, but not on the cosine. 
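
For reference, the merge performed by `update_metadata` now behaves like this
simplified dict-level sketch (illustrative only; the real code wraps the
result in `XlaMetadata` and also handles the unset sentinel):

```
def merge(outer, inner):
  base = {} if outer is None else dict(outer)
  base.update(inner)
  # None values act as "unset": the key is dropped entirely.
  return {k: v for k, v in base.items() if v is not None}

assert merge({"a": "b"}, {"a": None}) == {}
assert merge(None, {"a": "b", "c": None}) == {"a": "b"}
```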
--- jax/_src/core.py | 2 +- jax/_src/xla_metadata.py | 13 ++++++++++--- tests/xla_metadata_test.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ee6537650f20..bb23a540e526 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -320,7 +320,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): config.compute_on_context_manager.set_local(self.prev_compute_type) config.threefry_partitionable.set_local(self.prev_threefry_partitionable) - if self.context.xla_metadata is not None: + if self.context.xla_metadata: config.xla_metadata_context_manager.set_local(self.prev_xla_metadata) config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh) diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 91895b4e7851..77c0e2ff9910 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -24,6 +24,8 @@ class XlaMetadata: __slots__ = ['val', 'hash'] + val: dict[str, Any] + def __init__(self, val): self.val = val self.hash = hash(tuple(sorted(self.val.items()))) @@ -35,14 +37,19 @@ def __eq__(self, other): return other is not None and self.val == other.val +def filter_nones(d: dict) -> dict: + return {k: v for k, v in d.items() if v is not None} + + def update_metadata(a, b: dict[str, Any]): if not b: return a if a is None or a is config_ext.unset: - return XlaMetadata(b) - val = a.val.copy() + val = {} + else: + val = a.val.copy() val.update(b) - return XlaMetadata(val) + return XlaMetadata(filter_nones(val)) def current_xla_metadata(): diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py index d141bc15c249..33fd7a08b1de 100644 --- a/tests/xla_metadata_test.py +++ b/tests/xla_metadata_test.py @@ -190,6 +190,39 @@ def while_fn(a): if "stablehlo.add" in line: self.assertIn('mhlo.frontend_attributes = {a = "c"}', line) + def test_cond_annotates_branches(self): + sin = jnp.sin + cos = jnp.cos + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line] + cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line] + self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + + def test_cond_annotates_branches_and_none_unsets(self): + sin = jnp.sin + + def cos(x): + with set_xla_metadata(a=None): + return jnp.cos(x) + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line] + cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line] + self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + def test_nested_jit(self): @jax.jit def f(x, y): From 6fe6d8050663358a3a4447e4022efe012285f840 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 1 Apr 2025 22:17:41 -0700 Subject: [PATCH 0321/1769] upgrade docs from `jax.core` to `jax.extend.core` where needed to fix doc build --- docs/jax-primitives.md | 4 ++-- docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb | 2 +- docs/notebooks/Writing_custom_interpreters_in_Jax.md | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index 
abdc8be6d0a8..38a45ef4823e 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -21,7 +21,7 @@ kernelspec: A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). -For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below. +For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below. And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either: @@ -171,7 +171,7 @@ The JAX traceability property is satisfied as long as the function is written in The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} -from jax import core +from jax.extend import core multiply_add_p = core.Primitive("multiply_add") # Create the primitive diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 00ba9186eeec..56b2d80fc58e 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -215,8 +215,8 @@ "# Importing Jax functions useful for tracing/interpreting.\n", "from functools import wraps\n", "\n", - "from jax import core\n", "from jax import lax\n", + "from jax.extend import core\n", "from jax._src.util import safe_map" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 10c4e7cb6e3b..6b993a630e93 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. # Importing Jax functions useful for tracing/interpreting. 
from functools import wraps
 
-from jax import core
 from jax import lax
+from jax.extend import core
 from jax._src.util import safe_map
 ```
 

From 8e2c1a18c7676a2b481c5a41128ddf191793831b Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 16 Jan 2025 21:52:37 +0000
Subject: [PATCH 0322/1769] Updates for 3.14

Added a TSAN CI job for CPython 3.14.
---
 ...essions.txt => tsan-suppressions_3.13.txt} |   4 +
 .github/workflows/tsan-suppressions_3.14.txt  |  26 +++
 .github/workflows/tsan.yaml                   | 166 +++++++++++++++---
 WORKSPACE                                     |   1 +
 build/build.py                                |   2 +
 build/requirements_lock_3_13_ft.txt           |   2 +-
 build/requirements_lock_3_14_ft.txt           | 107 +++++++++++
 build/tools/utils.py                          |  14 ++
 jaxlib/jax.bzl                                |   4 +-
 9 files changed, 296 insertions(+), 30 deletions(-)
 rename .github/workflows/{tsan-suppressions.txt => tsan-suppressions_3.13.txt} (93%)
 create mode 100644 .github/workflows/tsan-suppressions_3.14.txt
 create mode 100644 build/requirements_lock_3_14_ft.txt

diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions_3.13.txt
similarity index 93%
rename from .github/workflows/tsan-suppressions.txt
rename to .github/workflows/tsan-suppressions_3.13.txt
index bdffddc58ca0..833fa856a7d6 100644
--- a/.github/workflows/tsan-suppressions.txt
+++ b/.github/workflows/tsan-suppressions_3.13.txt
@@ -21,6 +21,10 @@ race:_PyUnicode_InternImmortal
 # Fixed in Python 3.14, but not backported to 3.13.
 race_top:PyMember_GetOne
 
+# https://github.com/python/cpython/issues/131680
+# Fixed in Python 3.14, but not backported to 3.13.
+race_top:new_reference
+
 # https://github.com/python/cpython/issues/129748
 race:mi_block_set_nextx
 
diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt
new file mode 100644
index 000000000000..9cfc68e1ae36
--- /dev/null
+++ b/.github/workflows/tsan-suppressions_3.14.txt
@@ -0,0 +1,26 @@
+# false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads
+# are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally.
+race:llvm::RuntimeDyldELF::registerEHFrames
+
+# https://github.com/openxla/xla/issues/20686
+race:dnnl_sgemm
+
+# https://github.com/python/cpython/issues/128050
+race:partial_vectorcall_fallback
+
+# Likely only happens when the process is crashing.
+race:dump_traceback
+
+# https://github.com/python/cpython/issues/129748
+race:mi_block_set_nextx
+
+# https://github.com/python/cpython/issues/128130
+race_top:run_eval_code_obj
+
+# Races because the LAPACK and BLAS in our scipy aren't TSAN-instrumented.
+race:heevd_ffi +race:gesdd_ffi +race:dscal_k_ +race:scal_k_ +race:gemm_beta +race:gemm_oncopy diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index cd59c0bf45e0..4c28608a8257 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -22,6 +22,16 @@ jobs: image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: fail-fast: false + matrix: + include: + - name-prefix: "with 3.13" + python-version: "3.13" + github_branch: "3.13" + requirements_lock_name: "requirements_lock_3_13_ft" + - name-prefix: "with 3.14" + python-version: "3.14" + github_branch: "main" + requirements_lock_name: "requirements_lock_3_14_ft" defaults: run: shell: bash -l {0} @@ -44,22 +54,33 @@ jobs: with: repository: python/cpython path: cpython - ref: "3.13" + ref: ${{ matrix.github_branch }} - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: numpy/numpy path: numpy submodules: true + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ matrix.python-version == '3.14' }} + with: + repository: scipy/scipy + path: scipy + submodules: true - - name: Restore cached TSAN CPython + - name: Get year & week number + id: get-date + run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT + shell: bash -l {0} + + - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - - name: Build CPython with enabled TSAN + - name: Build TSAN CPython ${{ matrix.python-version }} if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' run: | cd cpython @@ -73,19 +94,14 @@ jobs: # Create archive to be used with bazel as hermetic python: cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan - - name: Save TSAN CPython + - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} - - - name: Get year & week number - id: get-date - run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT - shell: bash -l {0} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore @@ -93,7 +109,7 @@ jobs: with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build TSAN Numpy wheel if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' @@ -114,7 +130,8 @@ jobs: python3 -m pip install uv~=0.5.30 # Make sure to install a compatible Cython version (master branch is best for now) - python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython + NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython + python3 -m uv pip install -r 
requirements/build_requirements.txt
 
           CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized
 
@@ -147,7 +164,83 @@ jobs:
         with:
           path: |
             ./wheelhouse
-          key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }}
+          key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }}
+
+      - name: Restore cached Scipy
+        if: ${{ matrix.python-version == '3.14' }}
+        id: cache-scipy-restore
+        uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57  # v4.2.0
+        with:
+          path: |
+            ./wheelhouse
+          key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }}
+
+      - name: Build Scipy wheel
+        if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }}
+        run: |
+          # Install scipy dependencies:
+          apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends
+
+          cd scipy
+
+          # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz
+          if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then
+            echo "Extract cpython from python-tsan.tgz"
+            pushd .
+            ls ${GITHUB_WORKSPACE}/python-tsan.tgz
+            cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz
+            ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/
+            popd
+          fi
+
+          export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH
+
+          python3 -m pip install uv~=0.5.30
+          # Make sure to install a compatible Cython version (master branch is best for now)
+          NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython
+          python3 -m uv pip install -U --pre numpy --extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/
+          python3 -m uv pip install pythran pybind11 meson-python ninja
+
+          python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)"
+
+          export CC=clang-18
+          export CXX=clang++-18
+          python3 -m pip wheel --wheel-dir dist -vvv . --no-build-isolation --no-deps -Csetup-args=-Dbuildtype=debugoptimized
+
+          python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)"
+
+          # Create simple index and copy the wheel
+          mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/scipy
+
+          scipy_whl_name=($(cd dist && ls scipy*.whl))
+          if [ -z "${scipy_whl_name}" ]; then exit 1; fi
+
+          echo "Built TSAN Scipy wheel: ${scipy_whl_name}"
+
+          cp dist/${scipy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/scipy
+
+          # Recreate wheelhouse index with Numpy and Scipy
+          cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html
+          <html><body>
+          <a href="numpy/">numpy></a><br/>
+          <a href="scipy/">scipy></a><br/>
+          </body></html>
+          EOF
+
+          cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/scipy/index.html
+          <html><body>
+          <a href="${scipy_whl_name}">${scipy_whl_name}</a><br/>
+          </body></html>
+ + EOF + + - name: Save Scipy wheel + id: cache-scipy-save + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build Jax and run tests timeout-minutes: 120 @@ -164,7 +257,7 @@ jobs: python3 -VV python3 build/build.py build --configure_only \ - --python_version=3.13-ft \ + --python_version=${{ matrix.python-version }}-ft \ --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ --bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \ --bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \ @@ -174,18 +267,32 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch - cat .github/workflows/requirements_lock_3_13_ft.patch - git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1 + if [ "${{ matrix.python-version }}" == "3.13" ]; then + # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - # Display the content for debugging in logs - cat build/requirements_lock_3_13_ft.txt | head -15 - # Check the patch - cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" - if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi - cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)" - if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi + sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/${{ matrix.requirements_lock_name }}.patch + cat .github/workflows/${{ matrix.requirements_lock_name }}.patch + git apply .github/workflows/${{ matrix.requirements_lock_name }}.patch || exit 1 + + # Display the content for debugging in logs + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 + # Check the patch + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" + if [ "$?" == "1" ]; then echo "Could not find the patch in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + cat build/${{ matrix.requirements_lock_name }}.txt | grep -E "(numpy==)" + if [ "$?" 
== "0" ]; then "Found original numpy dependency in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + + else + # Patch build/requirements_lock_3_14_ft.txt to use TSAN instrumented NumPy and Scipy + + sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt + + # We should install jpeg dev package to be able to build Pillow from source: + apt-get install -y libjpeg-dev --no-install-recommends + + # Install scipy runtime dependencies (in case we restore scipy wheel from cache): + apt-get install -y libopenblas-dev liblapack-dev --no-install-recommends + fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" @@ -201,13 +308,18 @@ jobs: # Check numpy version ./bazel cquery @pypi_numpy//:* | grep whl + if [ "${{ matrix.python-version }}" == "3.14" ]; then + # Check scipy version + ./bazel cquery @pypi_scipy//:* | grep whl + fi + # Build JAX and run tests ./bazel test \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ --test_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \ --test_env=PYTHON_GIL=0 \ - --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions.txt \ + --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions_${{ matrix.python-version }}.txt \ --test_env=JAX_TEST_NUM_THREADS=8 \ --test_output=errors \ --local_test_jobs=32 \ diff --git a/WORKSPACE b/WORKSPACE index a6968446a1ec..5c093ec2228f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,7 @@ python_init_repositories( "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", "3.13-ft": "//build:requirements_lock_3_13_ft.txt", + "3.14-ft": "//build:requirements_lock_3_14_ft.txt", }, local_wheel_inclusion_list = [ "jax-*", diff --git a/build/build.py b/build/build.py index f8c0ccbfa6a4..226d984b3d89 100755 --- a/build/build.py +++ b/build/build.py @@ -496,6 +496,7 @@ async def main(): if args.use_clang: clang_path = args.clang_path or utils.get_clang_path_or_exit() clang_major_version = utils.get_clang_major_version(clang_path) + clangpp_path = utils.get_clangpp_path(clang_path) logging.debug( "Using Clang as the compiler, clang path: %s, clang version: %s", clang_path, @@ -505,6 +506,7 @@ async def main(): # Use double quotes around clang path to avoid path issues on Windows. 
wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CXX=\"{clangpp_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") if clang_major_version >= 16: diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 5157706c00e8..a96a3e6e489b 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -658,7 +658,7 @@ zipp==3.21.0 \ --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 # via etils -# python 3.13t can compile 0.23.0 +# python 3.13t can't compile 0.23.0 # due to https://github.com/indygreg/python-zstandard/issues/231 # zstandard==0.23.0 \ # --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt new file mode 100644 index 000000000000..18e4ef6d576a --- /dev/null +++ b/build/requirements_lock_3_14_ft.txt @@ -0,0 +1,107 @@ +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy + +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +scipy + +absl-py==2.1.0 + +attrs==24.3.0 + +auditwheel==6.2.0 + +build==1.2.2.post1 + +cloudpickle==3.1.1 # version 3.1.0 leads to recursion error + +colorama==0.4.6 + +contourpy==1.3.1 + +cycler==0.12.1 + +etils[epath,epy]==1.11.0 + +execnet==2.1.1 + +filelock==3.16.1 + +flatbuffers==24.12.23 + +fonttools==4.56.0 + +fsspec==2024.12.0 + +hypothesis==6.123.9 + +importlib-resources==6.5.2 + +iniconfig==2.0.0 + +kiwisolver==1.4.8 + +markdown-it-py==3.0.0 + +matplotlib==3.10.1 + +mdurl==0.1.2 + +ml-dtypes==0.5.1 + +mpmath==1.3.0 + +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" + +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" + +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" +opt-einsum==3.4.0 + +packaging==24.2 + +pillow==11.1.0 +pluggy==1.5.0 + +portpicker==1.6.0 + +psutil==6.1.1 +pyelftools==0.31 + +pygments==2.19.1 + +pyparsing==3.2.2 # version 3.2.1 fails with SyntaxError(originally SyntaxWarning): 'return' in a 'finally' block in pyparsing/core.py", line 5716 + +pyproject-hooks==1.2.0 + +pytest==8.3.4 + +pytest-xdist==3.6.1 + +python-dateutil==2.9.0.post0 + +rich==13.9.4 + +six==1.17.0 + +sortedcontainers==2.4.0 + +typing-extensions==4.12.2 + +wheel==0.45.1 + +zipp==3.21.0 + +# python 3.14t can't compile 0.23.0 +# due to https://github.com/indygreg/python-zstandard/issues/231 +# zstandard==0.23.0 + +setuptools==70.3.0 diff --git a/build/tools/utils.py b/build/tools/utils.py index 8b8dc80d1e0f..ccce8aff09cc 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -202,6 +202,20 @@ def get_clang_major_version(clang_path): return major_version +def get_clangpp_path(clang_path): + clang_path = pathlib.Path(clang_path) + clang_exec_name = clang_path.stem + clangpp_exec_name = 
clang_exec_name + if "clang++" not in clang_exec_name: + clangpp_exec_name = clang_exec_name.replace("clang", "clang++") + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + raise FileNotFoundError( + f"Failed to get clang++ path from clang path: '{clang_path!s}'. " + f"Tried the path: '{clangpp_path!s}'." + ) + return str(clangpp_path) + def get_gcc_major_version(gcc_path: str): gcc_version_proc = subprocess.run( [gcc_path, "-dumpversion"], diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 1cc4fab12591..93e9ebacfa6f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -76,9 +76,9 @@ _CPU_PYPI_WHEEL_DEPS = [ "@pypi_jaxlib//:pkg", ] -# TODO(vam): remove this once zstandard builds against Python 3.13 +# TODO(vam): remove this once zstandard builds against Python >3.13 def get_zstandard(): - if HERMETIC_PYTHON_VERSION == "3.13" or HERMETIC_PYTHON_VERSION == "3.13-ft": + if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"): return [] return ["@pypi_zstandard//:pkg"] From 076d021057722aa58d0621d79630ddfab4a64bce Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 26 Mar 2025 12:39:23 +0200 Subject: [PATCH 0323/1769] [better_errors] Fix the handling of kwargs for debug_info. kwargs are passed sorted by the actual kwarg keyword. This order must be accounted for when we construct the `debug_info.arg_names`. Extended the tests to be more precise about not mixing up kwargs, e.g., use different shapes and look for the shape in the HLO. --- jax/_src/api.py | 2 +- jax/_src/api_util.py | 57 ++++++++---- jax/_src/pjit.py | 4 +- tests/debug_info_test.py | 191 +++++++++++++++++++++++++++++++++------ 4 files changed, 207 insertions(+), 47 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 55e2b2126a68..fb10245c30e9 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -192,7 +192,7 @@ def jit( constant). Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary + ``__eq__`` are implemented, and immutable. Otherwise, they can be arbitrary Python objects. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static. diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index a42141b96fbd..451d2e490a15 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -28,12 +28,12 @@ from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, tree_map, treedef_children, generate_key_paths, broadcast_prefix, - prefix_errors) -from jax._src.tree_util import _replace_nones + prefix_errors, _replace_nones) from jax._src import linear_util as lu from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, - Unhashable, safe_zip) + Unhashable, safe_zip as zip) from jax._src import traceback_util + traceback_util.register_exclusion(__file__) map = safe_map @@ -201,9 +201,11 @@ def _validate_argnames( f"in {argnames_name}. 
Function does not take these args.")
 
 
-def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
+def argnums_partial(f: lu.WrappedFun, dyn_argnums: int | Sequence[int],
+                    args: Sequence, require_static_args_hashable=True):
   dyn_argnums = _ensure_index_tuple(dyn_argnums)
   dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums)
+  fixed_args: list
   if require_static_args_hashable:
     fixed_args = []
     for i, arg in enumerate(args):
@@ -273,7 +275,9 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
   return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
 
 @lu.transformation2
-def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs):
+def _argnums_partial(_fun: Callable,
+                     _dyn_argnums: Sequence[int],
+                     _fixed_args: Sequence, *dyn_args, **kwargs):
   sentinel = object()
   args = [sentinel] * (len(_fixed_args) + len(dyn_args))
   for i, arg in zip(_dyn_argnums, dyn_args):
@@ -334,7 +338,7 @@ def donation_vector(donate_argnums, donate_argnames, in_tree,
     donate = bool(i in donate_argnums)
     res.extend((donate,) * arg.num_leaves)
   if kwargs_tree is not None:
-    for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()):  # type: ignore
+    for key, val in zip(kwargs_tree.node_data()[1], kwargs_tree.children()):  # type: ignore
       donate = key in donate_argnames
       res.extend((donate,) * val.num_leaves)
   return tuple(res)
@@ -673,28 +677,45 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
   top-level arguments. In other cases, including when the `args` and `kwargs`
   do not match the signature, we use names like `args[0]`, `args[1]`, etc.
   """
+  # Use the same argument parsing as jit: positional followed by kwargs
+  # sorted by keys.
   static = object()
   static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
   static_argnames_ = set(static_argnames)
   args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
-  kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
+  kwargs_ = {k: static if k in static_argnames_ else x for k, x in kwargs.items()}
+  ordered_args: Sequence[tuple[str, Any]] | None = None
   if fn_signature is not None:
     try:
       ba = fn_signature.bind(*args_, **kwargs_)
     except (ValueError, TypeError):
       pass
     else:
-      return tuple(f'{name}{lu._clean_keystr_arg_names(path)}'
-                   for name, x in ba.arguments.items()
-                   for path, l in generate_key_paths(x) if l is not static)
-  args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}'
-                         for path, l in generate_key_paths(args_)
-                         if l is not static)
-  kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}'
-                           for path, l in generate_key_paths(kwargs_)
-                           if l is not static)
-  arg_names = args_arg_names + kwargs_arg_names
-  return arg_names
+      # Do we have a **kwargs parameter?
+      kwargs_name = next((name for name, p in fn_signature.parameters.items()
+                          if p.kind == inspect.Parameter.VAR_KEYWORD), None)
+      # Positional arguments are those not passed by keyword and not passed
+      # by **kwargs.
+ positional = [(name, x) for name, x in ba.arguments.items() + if name not in kwargs and name != kwargs_name] + # Keyword arguments are passed sorted by actual kwarg keyword + sorted_kwargs = sorted(((name, x) for name, x in kwargs_.items()), + key=lambda name_x: name_x[0]) + sorted_kwargs = [(name if name in ba.arguments else f"{kwargs_name}['{name}']", + x) + for name, x in sorted_kwargs] + ordered_args = positional + sorted_kwargs + + if ordered_args is None: + positional = [("args", args_)] + keyword = sorted([(f"kwargs['{name}']", x) for name, x in kwargs_.items() if x is not static], + key=lambda name_x: name_x[0]) + ordered_args = positional + keyword + + return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' + for name, x in ordered_args + for path, l in generate_key_paths(x) if l is not static) + def hoist_obj_attrs(f, flat_args): idxs, objs, flat_args_ = [], [], [] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 5727c36a646b..af744ae5db96 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -683,7 +683,7 @@ def __init__(self): # We use an outer cache that is keyed on the signature of the arguments, but # when populating a cache entry using _infer_params_impl, we need to provide -# actual arguments. In principle we could refactor _infer_params_impl to look +# actual arguments. In principle, we could refactor _infer_params_impl to look # only at an argument signature instead of args/kwargs in those cases that we # cache, but this was a more minimal change. @util.weakref_lru_cache @@ -730,7 +730,7 @@ def _infer_params_internal( if entry.pjit_params is None: p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) - if p.attrs_tracked: # if attrs, don't popoulate the cache + if p.attrs_tracked: # if attrs, don't populate the cache return p, p.consts + args_flat entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 1d2935ea34d7..8ec5c42ef24c 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -46,6 +46,7 @@ from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lax.control_flow import for_loop from jax._src.interpreters import mlir +from jax._src import util as util import numpy as np @@ -241,7 +242,7 @@ def my_f(x, y, z, w): dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertEqual(dbg.func_name, "my_f") - self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) + self.assertEqual(dbg.arg_names, ("x", "y", "w", "z")) self.assertIsNone(dbg.result_paths) def test_debug_info_arg_passed_as_kwarg(self): @@ -261,23 +262,29 @@ def my_f(x_tree, *, y_tree): "y_tree['w']", "y_tree['z']")) def test_debug_info_with_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, z, *, w, y): pass - dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + dbg = api_util.debug_info("jit", my_f, (1,), dict(y=2, z=3, w=4), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x", "z")) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) def test_debug_info_with_pytrees_and_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, y, *, z, w, t): pass dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)), - dict(z=(3, 4), w=(5, 6)), + dict(z=(3, 4), w=(5, 6), t=7), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + 
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "z[0]", "z[1]"))
+
+    dbg = api_util.debug_info("jit", my_f, ((1, 2),),
+                              dict(z=(3, 4), w=(5, 6), t=7, y=3),
+                              static_argnums=(1,),
+                              static_argnames=("w",))
+    self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "y", "z[0]", "z[1]"))
 
   def test_debug_info_too_many_args(self):
     def my_f(x):
@@ -287,7 +294,7 @@ def my_f(x):
     self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))
 
   def test_debug_info_no_source_info_built_in(self):
-    # built-in function "int" does not have an inspect.Signature
+    # built-in function "max" does not have an inspect.Signature
    dbg = api_util.debug_info("jit", max, (1,), {})
    self.assertEqual(dbg.func_src_info, "max")
    self.assertEqual(dbg.arg_names, ("args[0]",))
@@ -761,6 +768,122 @@ def f(x, y, *args, **kwargs):
         re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
     ])
 
+  def test_jit_arg_names_with_out_of_order_kwargs(self):
+    tracer_spy = TracerSpy()
+
+    # The shapes are different, to differentiate them easily
+    a1 = (np.float32(0),)  # a hashable tuple, can be static
+    b2 = np.arange(2, dtype=np.float32)  # b2
+    z3 = np.arange(3, dtype=np.float32)
+    y4 = (np.float32(0.), np.float32(1.), np.float32(2.), np.float32(3.))
+    x5 = np.arange(5, dtype=np.float32)
+    u6 = np.arange(6, dtype=np.float32)
+    t7 = np.arange(7, dtype=np.float32)
+
+    def my_f(a1, b2, z3, y4, x5, *, u6, t7):
+      assert np.shape(a1[0]) == ()
+      assert np.shape(b2) == (2,)
+      assert np.shape(z3) == (3,)
+      assert np.shape(y4) == (4,)
+      assert np.shape(x5) == (5,)
+      assert np.shape(u6) == (6,)
+      assert np.shape(t7) == (7,)
+      tracer_spy.append(b2)
+      tracer_spy.append(x5)
+      return a1[0] + b2[0] + z3[0] + y4[0] + x5[0] + u6[0] + t7[0]
+
+    self._check_tracers_and_jaxprs(
+        jax.jit(my_f, static_argnums=(0,), static_argnames=("y4",)),
+        # Some positional args passed as keyword
+        a1, b2, x5=x5, y4=y4, z3=z3, t7=t7, u6=u6,
+        expected_jaxpr_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, result_paths=result",
+        ],
+        tracer_spy=tracer_spy,
+        expected_tracer_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from b2",
+            "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from x5",
+        ],
+        expected_lowering_lines=[
+            re.compile(r".*func.func public @main\(%arg0: tensor<2xf..> loc\(\"b2\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg1: tensor<7xf..> loc\(\"t7\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg2: tensor<6xf..> loc\(\"u6\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg3: tensor<5xf..> loc\(\"x5\"\)"),
+        ]
+    )
+
+    tracer_spy.tracers = []
+    util.clear_all_caches()
+    self._check_tracers_and_jaxprs(
+        jax.jit(my_f, static_argnames=("y4",)),
+        # Positional argument y4 is static and passed by kwarg
+        a1, b2, z3, x5=x5, y4=y4, t7=t7, u6=u6,
+        expected_jaxpr_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result",
+        ],
+        tracer_spy=tracer_spy,
+        expected_tracer_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2",
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5",
+        ],
+        expected_lowering_lines=[
+            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"a1\[0\]\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"),
+        ]
+    )
+
+    tracer_spy.tracers = []
+    util.clear_all_caches()
+    self._check_tracers_and_jaxprs(
+        jax.jit(my_f, static_argnames=("y4",)),
+        # Positional argument y4 is static (declared as static_argnames)
+        a1, b2, z3, y4, x5=x5, t7=t7, u6=u6,
+        expected_jaxpr_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result",
+        ],
+        tracer_spy=tracer_spy,
+        expected_tracer_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2",
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5",
+        ],
+        expected_lowering_lines=[
+            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"a1\[0\]\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"),
+        ]
+    )
+
+    tracer_spy.tracers = []
+    util.clear_all_caches()
+    self._check_tracers_and_jaxprs(
+        jax.jit(my_f, static_argnums=(3,)),
+        # Positional argument y4 is static (declared as static_argnums)
+        a1, b2, z3, y4, x5=x5, t7=t7, u6=u6,
+        expected_jaxpr_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result",
+        ],
+        tracer_spy=tracer_spy,
+        expected_tracer_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2",
+            "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5",
+        ],
+        expected_lowering_lines=[
+            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"a1\[0\]\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"),
+            re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"),
+        ]
+    )
+
   def test_jit_result_info(self):
     def f(x, y, z):
       return {'a': x, 'b': [y]}
@@ -1493,34 +1616,50 @@ def my_f(x):
 
   def test_pmap_with_arg_and_result_names(self):
     tracer_spy = TracerSpy()
-    x = np.ones((jax.device_count(),), dtype=np.float32)
-    def my_f(x, y, *args, a, **kwargs):
-      # y and kwargs[c] is dead
+
+    # Use arguments of different shapes to distinguish them in the HLO
+    def my_f(x0, y1, *args, b4, **kwargs):
+      assert np.shape(x0) == ()
+      assert np.shape(y1) == (1,)
+      assert np.shape(args[0]) == (2,)
+      assert np.shape(args[1]) == (3,)
+      assert np.shape(b4) == (4,)
+      assert np.shape(kwargs["a5"]) == (5,)
+      assert np.shape(kwargs["c6"]) == (6,)
+      # args[0] and kwargs['a5'] are dead
       tracer_spy.append(args[1])
-      s = x + a + args[1] + kwargs["d"]
-      return dict(u=s, v=x)
+      tracer_spy.append(b4)
+      tracer_spy.append(kwargs["c6"])
+      s0 = x0 + y1[0] + b4[0] + args[1][0] + kwargs["c6"][0]
+      return dict(v1=jnp.broadcast_to(s0, (1,)), u0=s0)
 
     self._check_tracers_and_jaxprs(
        jax.pmap(my_f, static_broadcasted_argnums=(0,)),
-        1., x, x, x,  # x, y, args[0], args[1]
-        d=x, a=x, b=x,  # kwargs
+        1.,  # x0
+        np.ones((jax.device_count(), 1), dtype=np.float32),  # y1
+        np.ones((jax.device_count(), 2), dtype=np.float32),  # args[0]
+        np.ones((jax.device_count(), 3), dtype=np.float32),  # args[1]
+        
b4=np.ones((jax.device_count(), 4), dtype=np.float32), + a5=np.ones((jax.device_count(), 5), dtype=np.float32), + c6=np.ones((jax.device_count(), 6), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], result_paths=result['u0'],result['v1']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from b4", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from kwargs['c6']", ], expected_lowering_lines=[ - # TODO(necula): we did not DCE y? - re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"), - re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), - re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), + re.compile(r".*func.func public @main\(.*%arg0: tensor<1x1xf..> loc\(\"y1\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<1x2xf..> loc\(\"args\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<1x3xf..> loc\(\"args\[1\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<1x5xf..> loc\(\"kwargs\['a5'\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<1x4xf..> loc\(\"b4\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<1x6xf..> loc\(\"kwargs\['c6'\]\"\)"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u0'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v1'\]\"\}"), ] ) From 82ec5737ff3e2466a8fa5615e5a49ebfbdbcd99e Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 2 Apr 2025 03:33:23 -0700 Subject: [PATCH 0324/1769] Remove nanobind pin now that nanobind fix landed. Reverts 33d306ab4090c17b427908853b314e17cb449661 PiperOrigin-RevId: 743062185 --- examples/ffi/pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml index 6f188ee037da..130dd91bbc70 100644 --- a/examples/ffi/pyproject.toml +++ b/examples/ffi/pyproject.toml @@ -1,7 +1,5 @@ [build-system] -# TODO(dsuo): Remove nanobind pin after -# https://github.com/wjacob/nanobind/pull/980 lands. -requires = ["scikit-build-core", "nanobind==2.5.0", "jax>=0.4.31"] +requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] build-backend = "scikit_build_core.build" [project] From 735cec18cb2f8dff2aea5e503fd886a37aee094e Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 2 Apr 2025 03:39:25 -0700 Subject: [PATCH 0325/1769] [jaxlib] Fix asan tests for subbyte types in CPU/GPU callbacks. 
Reverts 0b199f48c7e0d4e5837cee34ced7f3fc7065732f

PiperOrigin-RevId: 743063615
---
 jaxlib/cuda/BUILD             |  1 +
 jaxlib/gpu/py_client_gpu.cc   | 88 ++++++++++++++++++++------------
 jaxlib/rocm/BUILD             |  1 +
 jaxlib/xla/BUILD              |  1 +
 jaxlib/xla/py_client_cpu.cc   | 87 +++++++++++++++++++++---------
 tests/python_callback_test.py | 94 ++++++++++++++++++++---------------
 6 files changed, 177 insertions(+), 95 deletions(-)

diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index fac62c81dee7..d35e421ef904 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -689,6 +689,7 @@ cc_library(
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:comparison_util",
         "@xla//xla:shape_util",
+        "@xla//xla:util",
         "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",
diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc
index 861ffce3e749..e3aec51d8d25 100644
--- a/jaxlib/gpu/py_client_gpu.cc
+++ b/jaxlib/gpu/py_client_gpu.cc
@@ -44,6 +44,7 @@ limitations under the License.
 #include "xla/python/types.h"
 #include "xla/shape_util.h"
 #include "xla/xla_data.pb.h"
+#include "xla/util.h"
 
 namespace nb = nanobind;
 
@@ -81,8 +82,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     auto arg = args.get(i);
     auto ptype = static_cast<xla::PrimitiveType>(arg->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 ||
-        ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) {
+    if (ptype == xla::S1 || ptype == xla::U1) {
       return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented,
                              absl::StrFormat("Unsupported primitive type: %s",
                                              PrimitiveType_Name(ptype)));
@@ -112,9 +112,6 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
       PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr());
       continue;
     }
-    nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept {
-      delete[] static_cast<char*>(ptr);
-    });
     auto maybe_dtype = PrimitiveTypeToNbDtype(ptype);
     if (!maybe_dtype.ok()) {
       return xla::ffi::Error::Internal(maybe_dtype.status().ToString());
@@ -122,6 +119,23 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     auto dtype = maybe_dtype.value();
     auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
                                           arg->dimensions().size());
+    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
+    // of packing/unpacking to/from numpy arrays.
+    // We pass in data using default numpy layout i.e., std::nullopt.
+    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
+    if (bits_per_element == 2 || bits_per_element == 4) {
+      // NOTE(dsuo): FFI argument and return buffers are sized assuming
+      // minimum 1-byte element sizes, even if the data itself is packed.
+      size_t packed_size = arg->size_bytes() * bits_per_element / 8;
+      auto buffer = xla::UnpackIntN(
+          bits_per_element, static_cast<const char*>(host_input_buffers[i]),
+          packed_size);
+      delete[] static_cast<char*>(host_input_buffers[i]);
+      host_input_buffers[i] = buffer.release();
+    }
+    nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept {
+      delete[] static_cast<char*>(ptr);
+    });
     auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt,
                                        host_input_buffers[i], base);
     array.attr("flags").attr("writeable") = nb::bool_(false);
@@ -146,8 +160,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     auto ret = rets.get(i).value();
     auto ptype = static_cast<xla::PrimitiveType>(ret->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 ||
-        ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) {
+    if (ptype == xla::S1 || ptype == xla::U1) {
       return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented,
                              absl::StrFormat("Unsupported primitive type: %s",
                                              PrimitiveType_Name(ptype)));
@@ -168,32 +181,45 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     }
     auto expected_shape = maybe_expected_shape.value();
     auto expected_strides = xla::ByteStridesForShape(expected_shape);
-    if (strides == expected_strides) {
-      auto gpu_res =
-          gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(),
-                         gpuMemcpyHostToDevice, stream);
-      CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
-      continue;
+
+    const void* data = array.data();
+    size_t size_bytes = array.size() * array.itemsize();
+    if (strides != expected_strides) {
+      xla::TransposePlan::Options options;
+      options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
+      options.dims = absl::Span<int64_t const>(
+          reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+      absl::InlinedVector<int64_t, 4> reversed_layout;
+      reversed_layout.resize(expected_shape.dimensions().size());
+      absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
+                           reversed_layout.begin());
+      options.permutation = reversed_layout;
+      options.input_layout = xla::TransposePlan::Striding{strides};
+      auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
+      if (!maybe_plan.ok()) {
+        return xla::ffi::Error::Internal(maybe_plan.status().ToString());
+      }
+      auto plan = maybe_plan.value();
+      void* temp = new char[size_bytes];
+      temp_buffers.push_back(temp);
+      plan->Execute(data, temp);
+      data = temp;
     }
-    void* temp = new char[ret->size_bytes()];
-    temp_buffers.push_back(temp);
-    xla::TransposePlan::Options options;
-    options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
-    options.dims = absl::Span<int64_t const>(
-        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
-    absl::InlinedVector<int64_t, 4> reversed_layout;
-    reversed_layout.resize(expected_shape.dimensions().size());
-    absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
-                         reversed_layout.begin());
-    options.permutation = reversed_layout;
-    options.input_layout = xla::TransposePlan::Striding{strides};
-    auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
-    if (!maybe_plan.ok()) {
-      return xla::ffi::Error::Internal(maybe_plan.status().ToString());
+
+    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
+    // of packing/unpacking to/from numpy arrays.
+    std::unique_ptr<char[]> buffer;
+    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
+    if (bits_per_element == 2 || bits_per_element == 4) {
+      // NOTE(dsuo): FFI arguments and return buffers are sized assuming
+      // minimum 1-byte element sizes, even if the data itself is packed.
+      buffer = xla::PackIntN(bits_per_element, static_cast<const char*>(data),
+                             size_bytes);
+      data = buffer.get();
+      size_bytes = (size_bytes * bits_per_element) / 8;
+    }
-    auto plan = maybe_plan.value();
-    plan->Execute(array.data(), temp);
-    auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(),
+
+    auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes,
                                   gpuMemcpyHostToDevice, stream);
     CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
   }
diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD
index d0c0c798abb8..358a6d1cc9aa 100644
--- a/jaxlib/rocm/BUILD
+++ b/jaxlib/rocm/BUILD
@@ -588,6 +588,7 @@ cc_library(
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:comparison_util",
         "@xla//xla:shape_util",
+        "@xla//xla:util",
         "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",
diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD
index 2ca18afda13d..5b532c1dc501 100644
--- a/jaxlib/xla/BUILD
+++ b/jaxlib/xla/BUILD
@@ -637,6 +637,7 @@ cc_library(
         "@nanobind",
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:shape_util",
+        "@xla//xla:util",
         "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",
diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc
index ac4e7bee5680..fef6a54aab2d 100644
--- a/jaxlib/xla/py_client_cpu.cc
+++ b/jaxlib/xla/py_client_cpu.cc
@@ -41,6 +41,7 @@ limitations under the License.
 #include "xla/python/nb_numpy.h"
 #include "xla/python/types.h"
 #include "xla/shape_util.h"
+#include "xla/util.h"
 #include "xla/xla_data.pb.h"
 
 namespace nb = nanobind;
 
@@ -79,8 +80,7 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
     auto arg = args.get(i);
     auto ptype = static_cast<PrimitiveType>(arg->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 ||
-        ptype == U2 || ptype == U4) {
+    if (ptype == S1 || ptype == U1) {
       return ffi::Error(ffi::ErrorCode::kUnimplemented,
                         absl::StrFormat("Unsupported primitive type: %s",
                                         PrimitiveType_Name(ptype)));
@@ -96,9 +96,20 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
     auto dtype = maybe_dtype.value();
     auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
                                           arg->dimensions().size());
+    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
+    // of packing/unpacking to/from numpy arrays.
+    std::unique_ptr<char[]> buffer;
+    const void* data = arg->untyped_data();
+    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
+    if (bits_per_element == 2 || bits_per_element == 4) {
+      // NOTE(dsuo): FFI argument and return buffers are sized assuming
+      // minimum 1-byte element sizes, even if the data itself is packed.
+      size_t packed_size = arg->size_bytes() * bits_per_element / 8;
+      buffer = xla::UnpackIntN(bits_per_element, static_cast<const char*>(data),
+                               packed_size);
+      data = buffer.get();
+    }
     // We pass in data using default numpy layout i.e., std::nullopt.
     auto array =
-        nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data());
+        nb_numpy_ndarray(dtype, dims, std::nullopt, data);
     array.attr("flags").attr("writeable") = nb::bool_(false);
     PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
   }
@@ -119,9 +130,8 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
   for (size_t i = 0; i < rets.size(); ++i) {
     auto ret = rets.get(i).value();
     auto ptype = static_cast<PrimitiveType>(ret->element_type());
-    // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 ||
-        ptype == U2 || ptype == U4) {
+    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
+    // of packing/unpacking to/from numpy arrays.
+    if (ptype == S1 || ptype == U1) {
       return ffi::Error(ffi::ErrorCode::kUnimplemented,
                         absl::StrFormat("Unsupported primitive type: %s",
                                         PrimitiveType_Name(ptype)));
@@ -141,26 +151,55 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
     }
     auto expected_shape = maybe_expected_shape.value();
     auto expected_strides = ByteStridesForShape(expected_shape);
-    if (strides == expected_strides) {
-      std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes());
-      continue;
+
+    const void* data = array.data();
+    std::unique_ptr<char[]> buffer;
+    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
+    size_t size_bytes = array.size() * array.itemsize();
+    if (strides != expected_strides) {
+      xla::TransposePlan::Options options;
+      options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
+      options.dims = absl::Span<int64_t const>(
+          reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+      absl::InlinedVector<int64_t, 4> reversed_layout;
+      reversed_layout.resize(expected_shape.dimensions().size());
+      absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
+                           reversed_layout.begin());
+      options.permutation = reversed_layout;
+      options.input_layout = xla::TransposePlan::Striding{strides};
+      auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
+      if (!maybe_plan.ok()) {
+        return ffi::Error::Internal(maybe_plan.status().ToString());
+      }
+      auto plan = maybe_plan.value();
+      if (bits_per_element == 2 || bits_per_element == 4) {
+        // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer
+        // supplied by FFI directly.
+        buffer = std::make_unique<char[]>(size_bytes);
+        plan->Execute(data, buffer.get());
+        data = buffer.get();
+      } else {
+        plan->Execute(data, ret->untyped_data());
+        data = ret->untyped_data();
+      }
     }
-    xla::TransposePlan::Options options;
-    options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
-    options.dims = absl::Span<int64_t const>(
-        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
-    absl::InlinedVector<int64_t, 4> reversed_layout;
-    reversed_layout.resize(expected_shape.dimensions_size());
-    absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
-                         reversed_layout.begin());
-    options.permutation = reversed_layout;
-    options.input_layout = xla::TransposePlan::Striding{strides};
-    auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
-    if (!maybe_plan.ok()) {
-      return ffi::Error::Internal(maybe_plan.status().ToString());
+
+    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
+    // of packing/unpacking to/from numpy arrays.
+    if (bits_per_element == 2 || bits_per_element == 4) {
+      // NOTE(dsuo): FFI arguments and return buffers are sized assuming
+      // minimum 1-byte element sizes, even if the data itself is packed.
+      buffer = xla::PackIntN(bits_per_element, static_cast<const char*>(data),
+                             size_bytes);
+      data = buffer.get();
+      size_bytes = (size_bytes * bits_per_element) / 8;
+    }
+
+    // Copy the data to the output buffer if we haven't already, or if it
+    // was modified and needs to be written back.
+ if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, size_bytes); } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index a8442b4a1356..34ab20c05644 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,10 +586,15 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(x): return x def f(x): @@ -600,21 +605,17 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)(x) + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(): return np.arange(8, dtype=dtype) @@ -625,16 +626,43 @@ def f(): ) return y - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)() + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_non_default_stride_subbyte_results(self, dtype: str): + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." 
+ ) + x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) class PureCallbackTest(jtu.JaxTestCase): @@ -1108,20 +1136,6 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - class IOCallbackTest(jtu.JaxTestCase): From 45d577d3dc12f894416ab1eefb1ef48a15b1f3da Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 2 Apr 2025 04:16:48 -0700 Subject: [PATCH 0326/1769] Prepare for disallowing `jnp.array(None)` PiperOrigin-RevId: 743074472 --- jax/_src/numpy/lax_numpy.py | 2 +- tests/lax_numpy_test.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a47f66e5f621..ae32703e7113 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5503,7 +5503,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): - raise TypeError("None is not a valid value for jnp.array") + raise ValueError("None is not a valid value for jnp.array") warnings.warn( "None encountered in jnp.array(); this is currently treated as NaN. " "In the future this will result in an error.", diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c0650441edd7..f94f42f027ce 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -47,6 +47,7 @@ from jax._src import array from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal @@ -3796,8 +3797,15 @@ def testArrayFromList(self): jnp.array([0, val]) def testArrayNoneWarning(self): - # TODO(jakevdp): make this an error after the deprecation period. - with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"): + if deprecations.is_accelerated('jax-numpy-array-none'): + ctx = self.assertRaisesRegex( + ValueError, 'None is not a valid value for jnp.array' + ) + else: + ctx = self.assertWarnsRegex( + FutureWarning, r'None encountered in jnp.array\(\)' + ) + with ctx: jnp.array([0.0, None]) def testIssue121(self): From 0bee42b6cebe844d151ee4047406af0142756998 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 2 Apr 2025 05:26:24 -0700 Subject: [PATCH 0327/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c3087e022f3c07f7ed1dd4e47024c437a504341b. 
PiperOrigin-RevId: 743093178
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 13223c4a4b88..90a19ac95e51 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "b1971cc2b3407e87fada2674a057d72897b79acc"
-XLA_SHA256 = "3b2feabbcd6adc5721533edfbe3dc2ad6517cb1b059cf41dea63f62874bff12d"
+XLA_COMMIT = "c3087e022f3c07f7ed1dd4e47024c437a504341b"
+XLA_SHA256 = "66457303ddec4dbbe43accf38a8b6b635d55808938cf2495443b09ee9c95a147"

 def repo():
     tf_http_archive(

From 10b2cda90e9066dc9c02ae5a068dc47a1e745a2a Mon Sep 17 00:00:00 2001
From: Yash Katariya 
Date: Wed, 2 Apr 2025 06:09:36 -0700
Subject: [PATCH 0328/1769] Relax the aval check in
 `select_hlo_lowering_opaque` to only check for shardings if they are not
 empty. The same thing happens in select_p's sharding rule.

PiperOrigin-RevId: 743105350
---
 jax/_src/lax/lax.py |  6 +++++-
 tests/pjit_test.py  | 20 ++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index fd956136ccd3..ac6054328f73 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -7357,7 +7357,11 @@ def _select_jvp(primals, tangents):
 def _select_hlo_lowering_opaque(ctx, which, *cases):
   avals_in = ctx.avals_in
   aval_out, = ctx.avals_out
-  assert all(aval_case == aval_out for aval_case in avals_in[1:])
+  assert all((aval_case.shape, aval_case.dtype) == (aval_out.shape, aval_out.dtype)
+             for aval_case in avals_in[1:])
+  assert all(
+      aval_case == aval_out for aval_case in avals_in[1:]
+      if not aval_case.sharding.mesh.empty and not aval_out.sharding.mesh.empty)

   select_lower = _select_hlo_lowering
   physical_aval_out = core.physical_aval(aval_out)
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index ee4a8cd3e15e..38f191302ea1 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -7365,6 +7365,26 @@ def h(y):
     out = h(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))

+  def test_scan_with_random_key_inside_jit(self):
+    mesh = jtu.create_mesh((2,), ('x',))
+    sharding = NamedSharding(mesh, P(None, 'x'))
+
+    @jax.jit
+    def scan(xs):
+      def step(carry, x):
+        next_carry = jax.vmap(jax.random.fold_in)(carry, x)
+        next_carry = jnp.where(x % 2 == 0, carry, next_carry)
+        return next_carry, None
+      rng = jnp.broadcast_to(jax.random.key(0), xs.shape[1:])
+      rng, _ = jax.lax.scan(step, rng, xs)
+      return rng
+
+    xs = jnp.arange(8).reshape(2, 4)
+    scan(xs)
+
+    xs = jax.device_put(xs, sharding)
+    scan(xs)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):

From 6242ffb1ca207e783150176cbca6d97db6fc3325 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey 
Date: Wed, 2 Apr 2025 07:40:02 -0700
Subject: [PATCH 0329/1769] Remove unused Attrs from `lu_pivots_to_permutation`
 FFI kernel.

It has been more than 6 months since the release of 0.4.32, which was the
first release to stop including `permutation_size` as an attribute when
lowering, so it is now safe (via our compatibility policy) to remove this
argument.
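For reference, the semantics of the kernel itself are unchanged. A rough
pure-NumPy sketch of what it computes (a reference model for illustration,
not the CUDA implementation; the helper name here is made up):

```python
import numpy as np

def lu_pivots_to_permutation_reference(pivots, permutation_size):
  # LAPACK-style pivots record, for each row i, the row it was swapped with
  # during factorization; applying those swaps in order yields an explicit
  # permutation of the rows.
  permutation = list(range(permutation_size))
  for i, p in enumerate(pivots):
    permutation[i], permutation[p] = permutation[p], permutation[i]
  return np.asarray(permutation, dtype=np.int32)

# Matches a row of the expected_outputs in the updated test data below:
print(lu_pivots_to_permutation_reference([4, 5, 6, 7], 8))
# -> [4 5 6 7 0 1 2 3]
```

Since the kernel can already read the permutation size off the shape of its
result buffer, the attribute carried no extra information.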
PiperOrigin-RevId: 743132169 --- .../cuda_lu_pivots_to_permutation.py | 25 ++++++++----------- jaxlib/gpu/linalg_kernels.cc | 7 +----- tests/export_back_compat_test.py | 4 +-- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py index 12285a45b77a..8063d9f44722 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py @@ -16,11 +16,11 @@ from numpy import array, int32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_08 = dict( +data_2025_04_01 = dict( testdata_version=1, platform='cuda', custom_call_targets=['cu_lu_pivots_to_permutation'], - serialized_date=datetime.date(2024, 8, 8), + serialized_date=datetime.date(2025, 4, 1), inputs=(), expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3], @@ -31,25 +31,22 @@ [0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),), mlir_module_text=r""" module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "result"}) { %0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5) - %c = stablehlo.constant dense<2> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc6) - %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) + %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "2"}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) return %2 : tensor<2x3x8xi32> loc(#loc) } loc(#loc) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11) -#loc4 = loc("jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2)) -#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":409:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation"(#loc3)) """, - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00", + 
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03yQ\x15\x01+\x07\x0b\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x17\x0f\x0b\x17\x1b\x0b\x0b\x0f\x0b\x17\x03'\x0b\x0f\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0bo\x01\x05\x0b\x0f\x03\x11\x07\x1b\x13\x13\x07\x1b\x13\x07\x02\x9e\x02\x1f\x05\x11\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x13\x11\x01\x00\x05\x15\x05\x17\x05\x19\x1d\x15\x17\x05\x1b\x17\x03b\x065\x1d\x1b\x1d\x05\x1d\x17\x03b\x06\x1d\x03\x05!?#A\x05\x1f\x05!\x1d')\x05#\x17\x03f\x06\x17\x03\x01\x03\x03O#\t\x03\x033\r\x0357\x1d%\x1d'\x1d)\x1d+\x13\r\x01\r\x01\r\x03CE\x1d-\x1d/\x0b\x03\x1d1\x1d3\x05\x01\x1f\x111\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02\x1b)\x07\t\r!\x05\x11\x01\x03\x07)\x03a\x05\x1d)\x07\t\r\x11\x05)\x03\r\x13\x13\x04c\x05\x01Q\x01\x07\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x07\x11\x05B\x13\x05\x03\x0b\x07\x06\x19\x03\x0f\x03\x01\tG%\x1f\x07\x03\x07\x03\x03\x0b\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00J\x0759\x03\x05\x1f\x0f\x0b\x0f!c3)A;\x1b%)9i\x15\x1f\x17\x11\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/lu_pivots_to_permutation\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x002\x00\x00cu_lu_pivots_to_permutation\x00\x08+\t\x05#\x01\x0b+/19;\x03=\x11GIK+M-+-", xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 2293bef89b7d..b48e64f2181d 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -90,8 +90,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl, namespace { ffi::Error LuPivotsToPermutationImpl( - gpuStream_t stream, ffi::Dictionary /* unused */, - ffi::Buffer pivots, + gpuStream_t stream, ffi::Buffer pivots, ffi::Result> permutation) { FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]), SplitBatch1D(pivots.dimensions())); @@ -119,10 +118,6 @@ ffi::Error LuPivotsToPermutationImpl( XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl, ffi::Ffi::Bind() .Ctx>() - // TODO(b/358275922): remove Attrs (and the - // unused Dictionary above) 12 weeks after - // release of jaxlib v0.4.32. 
-                                  .Attrs()
                                   .Arg<ffi::Buffer<ffi::S32>>()
                                   .Ret<ffi::Buffer<ffi::S32>>());

diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py
index 9b457b8f27a5..6a6c8c213a64 100644
--- a/tests/export_back_compat_test.py
+++ b/tests/export_back_compat_test.py
@@ -140,7 +140,7 @@ def test_custom_call_coverage(self):
       cpu_qr_lapack_geqrf.data_2023_03_17,
       cuda_threefry2x32.data_2024_07_30,
       cpu_lu_lapack_getrf.data_2023_06_14,
-      cuda_lu_pivots_to_permutation.data_2024_08_08,
+      cuda_lu_pivots_to_permutation.data_2025_04_01,
       cuda_lu_cusolver_getrf.data_2024_08_19,
       cuda_qr_cusolver_geqrf.data_2024_09_26,
       cuda_eigh_cusolver_syev.data_2024_09_30,
@@ -411,7 +411,7 @@ def lu_pivots_to_permutation_harness(shape):
   def test_cuda_lu_pivots_to_permutation(self):
     shape = (2, 3, 4)
     func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape)
-    data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08)
+    data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2025_04_01)
     self.run_one_test(func, data)

   @parameterized.named_parameters(

From 297a4f42dec42e0db08270cbdf436c993586445c Mon Sep 17 00:00:00 2001
From: Olli Lupton 
Date: Tue, 1 Apr 2025 09:59:38 +0000
Subject: [PATCH 0330/1769] docs: compilation_cache_expect_pgle option

---
 docs/gpu_performance_tips.md | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md
index 737486485736..b3643cb8e292 100644
--- a/docs/gpu_performance_tips.md
+++ b/docs/gpu_performance_tips.md
@@ -71,6 +71,10 @@ JAX will collect profile information and recompile a module in a single run. Whi
 in manual mode you need to run a task twice, the first time to collect and save
 profiles and the second to compile and run with provided data.

+**Important**: the JAX profiler, which is used by both of the PGLE workflows documented
+below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be
+avoided by using the JAX compilation cache, as described below.
+
 ### Auto PGLE

 The auto PGLE can be turned on by setting the following environment variables:
@@ -129,6 +133,28 @@ with config.enable_pgle(True), config.pgle_profiling_runs(1):
   train_step_compiled()
 ```

+#### Collecting NVIDIA Nsight Systems profiles when using AutoPGLE
+[jax#24910](https://github.com/jax-ml/jax/pull/24910) (JAX v0.5.1 and newer) added a
+new JAX configuration option, `JAX_COMPILATION_CACHE_EXPECT_PGLE`, which tells JAX to
+attempt to load PGLE-optimized compiled functions from the persistent compilation
+cache.
+
+This allows a two-step process, where the first step writes a PGLE-optimized function
+to the cache:
+```bash
+export JAX_ENABLE_COMPILATION_CACHE=yes # not strictly needed, on by default
+export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
+JAX_ENABLE_PGLE=yes python my-model.py
+```
+And the second step uses Nsight Systems and loads the PGLE-optimized function from the
+cache:
+```bash
+JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py
+```
+See also [this page](
+https://docs.jax.dev/en/latest/persistent_compilation_cache.html#pitfalls) for more
+information about the persistent compilation cache and possible pitfalls.
+
 ### Manual PGLE

 If you still want to use a manual Profile Guided Latency Estimator the workflow
 in XLA/GPU is:

From c18139ba7b511aff24c83c44aae5d9e1e0a5e014 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey 
Date: Wed, 2 Apr 2025 08:09:13 -0700
Subject: [PATCH 0331/1769] Remove legacy GPU kernels for QR decomposition.
Following the compatibility timeline described here:
https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

On Apr 2, it will have been 6 months since the release of 0.4.34, which is the
relevant release for these kernels.

PiperOrigin-RevId: 743142261
---
 .../cuda_qr_cusolver_geqrf.py                 | 141 +-------------
 .../rocm_qr_hipsolver_geqrf.py                | 176 ------------------
 jaxlib/cuda/BUILD                             |  52 ------
 jaxlib/gpu/BUILD                              |   3 -
 jaxlib/gpu/blas.cc                            |  75 --------
 jaxlib/gpu/blas_kernels.cc                    | 138 --------------
 jaxlib/gpu/blas_kernels.h                     |  48 -----
 jaxlib/gpu/gpu_kernels.cc                     |   5 -
 jaxlib/gpu/solver.cc                          |  86 ---------
 jaxlib/gpu/solver_kernels.cc                  | 172 -----------------
 jaxlib/gpu/solver_kernels.h                   |  20 --
 jaxlib/gpu_solver.py                          |   6 -
 jaxlib/rocm/BUILD                             |  49 -----
 jaxlib/tools/build_gpu_kernels_wheel.py       |   2 -
 tests/export_back_compat_test.py              |  25 ---
 15 files changed, 1 insertion(+), 997 deletions(-)
 delete mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py
 delete mode 100644 jaxlib/gpu/blas.cc
 delete mode 100644 jaxlib/gpu/blas_kernels.cc
 delete mode 100644 jaxlib/gpu/blas_kernels.h

diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py
index be5c6e01f8d8..00ced41a0492 100644
--- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py
+++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py
@@ -15,149 +15,10 @@
 # ruff: noqa

 import datetime
-from numpy import array, float32, float64, complex64, complex128
+from numpy import array, float32, complex64

-data_2023_03_18 = {}
-
-# Pasted from the test output (see back_compat_test.py module docstring)
-data_2023_03_18["unbatched"] = dict(
-    testdata_version=1,
-    platform='cuda',
-    custom_call_targets=['cusolver_geqrf', 'cusolver_orgqr'],
-    serialized_date=datetime.date(2023, 3, 18),
-    inputs=(),
-    expected_outputs=(array([[ 0.
, 0.9128705 , 0.40824863], - [-0.44721356, 0.36514878, -0.8164964 ], - [-0.8944271 , -0.18257457, 0.40824813]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914843e+00], - [ 0.0000000e+00, 1.0954436e+00, 2.1908882e+00], - [ 0.0000000e+00, 0.0000000e+00, 5.6703755e-08]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.custom_call @cusolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\00\03\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<196608xf32>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<196608xf32> - %7 = stablehlo.constant dense<0> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor - %9 = stablehlo.compare EQ, %5, %8, SIGNED : (tensor, tensor) -> tensor - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> - %11 = stablehlo.constant dense<0x7FC00000> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<3x3xf32> - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %14 = stablehlo.select %13, %3, %12 : tensor<3x3xi1>, tensor<3x3xf32> - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> - %16 = stablehlo.constant dense<0x7FC00000> : tensor - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<3xf32> - %18 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %19 = stablehlo.select %18, %4, %17 : tensor<3xi1>, tensor<3xf32> - %20 = stablehlo.constant dense<0.000000e+00> : tensor - %21 = stablehlo.pad %14, %20, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %22 = stablehlo.custom_call @cusolver_orgqr(%21, %19) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<33056xf32>> - %23 = stablehlo.get_tuple_element %22[0] : (tuple, tensor, tensor<33056xf32>>) -> tensor<3x3xf32> - %24 = stablehlo.get_tuple_element %22[1] : (tuple, tensor, tensor<33056xf32>>) -> tensor - %25 = stablehlo.get_tuple_element %22[2] : (tuple, tensor, tensor<33056xf32>>) -> tensor<33056xf32> - %26 = stablehlo.constant dense<0> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor - %28 = stablehlo.compare EQ, %24, %27, SIGNED 
: (tensor, tensor) -> tensor - %29 = stablehlo.broadcast_in_dim %28, dims = [] : (tensor) -> tensor<1x1xi1> - %30 = stablehlo.constant dense<0x7FC00000> : tensor - %31 = stablehlo.broadcast_in_dim %30, dims = [] : (tensor) -> tensor<3x3xf32> - %32 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %33 = stablehlo.select %32, %23, %31 : tensor<3x3xi1>, tensor<3x3xf32> - %34 = call @triu(%14) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %33, %34 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03~\x02\xf79\x01\x99\x0f\x0f\x17\x13\x0f\x07\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03_O/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x039\x17\x0f\x0f\x07\x07\x07\x07\x17\x13\x17\x07\x1b\x0f\x17\x13\x1b\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x06\t\x1d{\x05\x1d\x93\x05\x17\x1f\n\x06\x01\x03\x03\x13\xcb\x1dS\x05\x1f\x05!\x05#\x05%\x05'\x03\x03\r\xe9\x05)\x05+\x05-\x05/\x051\x03\x03#\xc7\x053\x1d[\x05\x055\x057\x03\x03\r\xd1\x17\x1f\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x0f\xdd\x03\x03\x0f\xdf\x03\x03\x0f\xe1\x03\x03\r\xe5\x03\x05'\xa7)\xe7\x03\x03\x13\xeb\x03\x03\x11M\x05I\x03\x0b\x17\x9d\x19\xb1\x1b\xb3\x11\xbd\x1d\xbf\x03\x0b\x17\xa3\x19\xc3\x1b\xa3\x11\xa5\x1d\xc5\x05K\x1dW\x05\x05M\x03\x03\r\xc9\x05O\x03\x03#\xcd\x1da\x05\x05Q\x03\x05'\xa7)\xcf\x1dg\x05\x05S\x1dk\x05\x05U\x1do\x05\x05W\x1ds-\x05Y\x1dw-\x05[\x03\x11/\xa91\xd33\xd55\x9d7\xab9\xd7;\xad=\xdb\x05]\x03\x03\x0f\xe3\x03\x03\x13\xed\x1d\x83\x05\x05_\x03\x07\x87\x9f\x89\x9f\x8b\x9f\x05a\x05c\x05e\x1d\x8f\x05\x05g\x03\x11/\xa91\xef3\xf15\x9d7\xab9\xf3;\xad=\xf5\x05i\x03\x03\x97\xa5\x05k\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dm\x03\x03\xc1\x1do\t\x07\x0b\x05\x05\x01\x03\x03\xd9\x1f/\x01#!\x03\x05\xb5\xb9\r\x03\xa1\xb7\x1dq\r\x03\xa1\xbb\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x03\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1d{\x1d}\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xaf\x9b\x13\t\x01\x13\t\x05\x13\t\t\x13\t\r\x1f\x03\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x7f\x1d\x81\x03\x05\x99\x9b\x03\x07\x99\xaf\x9b)\x05\r\r\x07)\x01\t)\x01\x07
\t\x1b\x1d\x01)\x05\r\r\t)\x03\r\x07)\x05\r\r\r\x13)\x03\x04\x000\x07)\x01\r)\x05\x05\x05\r)\x03\t\x0b)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x0b)\x03%\x07/\t\x01\x11\x03\x17)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x01\x03\x1f\x04\xe6\x05\x05\x01\x11\x0bK\x07\x03\x01\t\x0f\x11\x0bO\x05\x03G\x91\x0b\x03q!\x03'\x17\x06u\x03\x01\x03\x01\x13\x07\x01y\x03)\x03\x03\x07\x07\x01?\x03\x01\x03\x05\x07\x07\x01A\x03\x11\x03\x05\x07\x07\x01C\x03\x03\x03\x05\x07\x07\x01}\x03\x17\x03\x05\x05\x03\x01E\x03\x03\x03\x07\x01\x07\x03\x03\x03\x0f\r\x07\x01G\x03\x19\x05\x0b\x11\x03\x07\x01\x07\x03\x1b\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x01\x03\x17\x03\x07\x01I\x03\x13\x03\x15\t\x06\x01\x03\x01\x07\x1b\x07\x19\x03\x07\x01\x07\x031\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x11\x03!\x03\x07\x01\x7f\x033\x03\x1f\t\x06\x01\x03\x11\x07%\t#\x05\x03\x81+\x03\x05\x19\x07\x8d\x85\x03\x01\x05\x1d)\x13\x07\x03\x91\x037\x05+'\x07\x07\x03?\x03\x01\x03-\x07\x07\x03A\x03\x03\x03-\x07\x07\x03C\x03\x1f\x03-\x05\x03\x03E\x03\x03\x03\x07\x03\x07\x03\x03\x035\r\x07\x03G\x03\x19\x0517\x03\x07\x03\x07\x03\x1b\x039\x05\x03\x03\x15\x03\x05\x03\x07\x03\x07\x03\x01\x03=\x03\x07\x03I\x03\x13\x03;\t\x06\x03\x03\x01\x07A/?\x1b\x07\t\x95\x03\x01\x03\x1d\x11\x04\x0b\x05CE\x0f\x11\tQ\x05\x03\x15+\x03\x01\x0b\x0b\x03U!\x03\x0f\x05\x03\tY\x03\x03\x03\x07%\x07\x03\x0f\x03\x05\x15\x06%\x03\x0f\x05\x03\x07\x0b\x03_]\x03\x0f\r\x07ec\x03\x13\x05\t\x0b\x05\x03\t+\x03\x05\x03\x07i\x07\x03\x01\x03\x0f\t\x06m\x03\x01\x07\r\x11\x01\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x86\x19\x83\x1f3\x1f+\x11\x0f\x0b\t\t\x0b!\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 
0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x00\x03\x00\x00cusolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["batched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cublas_geqrf_batched', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[[ 0. , 0.91287094, 0.40824836], - [-0.4472136 , 0.36514843, -0.81649655], - [-0.8944272 , -0.18257417, 0.4082483 ]], - - [[-0.42426407, 0.80828977, 0.40824953], - [-0.5656854 , 0.11547142, -0.8164964 ], - [-0.7071068 , -0.5773508 , 0.4082474 ]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 4.8374091e-08]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01], - [ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01], - [ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]"}, tensor<2x3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> - %2 = stablehlo.custom_call @cublas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.pad %3, %7, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.custom_call @cusolver_orgqr(%8, %4) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> tuple, tensor<2xi32>, tensor<33056xf32>> - %10 = stablehlo.get_tuple_element %9[0] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2x3x3xf32> - %11 = stablehlo.get_tuple_element %9[1] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2xi32> - %12 = stablehlo.get_tuple_element %9[2] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> 
tensor<33056xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<2xi32> - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %16 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<2x3x3xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> - %20 = stablehlo.select %19, %10, %18 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - %21 = call @triu(%3) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> - return %20, %21 : tensor<2x3x3xf32>, tensor<2x3x3xf32> - } - func.func private @triu(%arg0: tensor<2x3x3xf32>) -> tensor<2x3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.broadcast_in_dim %5, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.select %6, %8, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - return %9 : tensor<2x3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff=\x01\x9f\x17\x0f\x0f\x0f\x07\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03ao/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x03=\x1b\x07\x07\x07\x0f\x17\x0f\x07\x13\x07\x13\x1b\x17\x13\x1b\x17\x17\x13\x17\x13\x13\x1b\x07\x13\x13\x13\x17\x13\x1b\x13\x02\x1a\n\x17\x1d\n\x06\x01\x1d\x8f\x01\x1dK\x01\x1dy\x01\x1f\x05!\x03\x03\x0f\xd1\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x03\x03!\xcd\x053\x1dS\x01\x055\x057\x03\x03\x0b\xd9\x17\x1d\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\x13E\x05I\x03\x0b\x15\xa3\x17\xb7\x19\xb9\x13\xc3\x1b\xc5\x03\x0b\x15\xa9\x17\xc9\x19\xa9\x13\xab\x1b\xcb\x05K\x1dO\x01\x05M\x03\x03\x0b\xcf\x05O\x03\x03!\xd3\x1dY\x01\x05Q\x03\x05%\xad'\xd5\x1d_\x01\x05S\x03\x03\x0f\xd7\x1de\x01\x05U\x1di\x01\x05W\x1dm\x01\x05Y\x1dq+\x05[\x1du+\x05]\x03\x11-\xaf/\xdb1\xdd3\xa35\xb17\xdf9\xb3;\xe3\x05_\x03\x03\x11\xeb\x1d\x7f\x01\x05a\x03\x07\x83\xa5\x85\xa5\x87\xa5\x05c\x05e\x05g\x1d\x8b\x01\x05i\x03\x11-\xaf/\xed1\xef3\xa35\xb17\xf19\xb3;\xf3\x05k\x03\x03\x0b\xf5\x03\x05%\xad'\xf7\x03\x03\x0f\xf9\x03\x03\x0b\xfb\x03\x03\x0f\xfd\x03\x03\x9d\xab\x05m\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1do\x03\x03\xc7\x1dq\t\x07\x0b\x05\x05\x01\x03\x03\xe1\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbb\xbf\r\x03\xa7\xbd
\x1ds\r\x03\xa7\xc1\x1du\x1dw\x1dy\r\x01#!\x1d{\x13\x05\x01\x1f\r\t\xff\xff\xff\xff\x1f#\x01\x13\x05\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\x00\x00\x1d}\x1d\x7f\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb5\xa1\xa1\x13\x03\x01\x13\x03\x05\x13\x03\t\x13\x03\r\x1d\x81\x1d\x83\x03\x05\x9f\xb5\x03\x07\x9f\xa1\xa1\x1f\r\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00)\x07\t\r\r\x07\x1b\x1d\t)\x01\x07)\x05\r\r\x03)\x01\x03\x01)\x03A-\x13)\x03\t\x03)\x07\t\r\r\x0f)\x05\t\r\x07)\x03\r\x05)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x05)\x05\r\r\x0f)\x03\t\x05)\x03I\x07/\t\x01\x19\x11\x11\x17)\x03\r\x13)\x03\t\x13)\x03\x05\x13/\x07\x01\x15\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f)\x03\x05\x05\x04r\x04\x05\x01\x11\tC\x07\x03\x01\t\x0b\x11\tG\x05\x03-]\t\x03o\x1f\x03)\x17\x06s\x03\x01\x03\x01\x13\x07\x07w\x03+\x03\x03\x05\x07\x07=\x03\x01\x03\x05\x05\x07\x07?\x03\x19\x03\x05\x05\x07\x07A\x03\x11\x03\x05\x05\x07\x07{\x03\x11\x03\x05\x07\x03})\x03\t\x19\x07\x89\x81\x03\x01\x05\x07\x0f\x13\x07\x03\x8d\x035\x05\x11\t\x05\x07\x03=\x03\x01\x03\x13\x05\x07\x03?\x03\x15\x03\x13\x05\x07\x03A\x03\x1d\x03\x13\x07\x03\x03\x91\x03\r\x03\x07\x03\r\x03\x15\x03\x1b\r\x07\x03\x93\x037\x05\x17\x1d\x03\x07\x03\x95\x039\x03\x1f\x07\x03\x03\x97\x03\t\x03\x07\x03\r\x03\x01\x03#\x03\x07\x03\x99\x03\x17\x03!\x0f\x06\x03\x03\x01\x07'\x15%\x1b\x07\x05\x9b\x03\x01\x03\x07\x11\x04\t\x05)+\x0b\x11\x05I\x05\x03\x17/\x03\x01\t\t\x03M\x1f\x03\x0b\x07\x03\x05Q\x03\r\x03\x07#\r\x03\x0b\x03\x05\x15\x06#\x03\x0b\x05\x03\x07\t\x03WU\x03\x0b\r\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x07\x03\x05)\x03\t\x03\x07g\r\x03\x01\x03\x11\x0f\x06k\x03\x01\x07\x0f\x13\x01\x11\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00Z\x1b\x85\x1f3+#\x11\x0f\x0b\t\t\x0b!\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15\x13\r+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19+)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00index\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) 
dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00cublas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste data_2024_09_26 = {} - data_2024_09_26["f32"] = dict( testdata_version=1, platform='cuda', diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py deleted file mode 100644 index bd5fa628741e..000000000000 --- a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from numpy import array, float32 - -data_2024_08_05 = {} - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["unbatched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipsolver_geqrf', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], dtype=float32), array([[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipsolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\01\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>, tensor, tensor<256xf32>) loc(#loc5) - %c = stablehlo.constant dense<0> : tensor loc(#loc5) - %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc5) - %4 = stablehlo.compare EQ, %2#2, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %7 = stablehlo.broadcast_in_dim %5, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %8 = stablehlo.select %7, %2#0, %6 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<3xf32> loc(#loc5) - %11 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> loc(#loc5) - %12 = stablehlo.select %11, %2#1, %10 : tensor<3xi1>, tensor<3xf32> loc(#loc5) - %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %13 = stablehlo.pad %8, %cst_1, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) - %14:3 = stablehlo.custom_call @hipsolver_orgqr(%13, %12) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<128xf32>) loc(#loc8) - %c_2 = stablehlo.constant dense<0> : tensor loc(#loc8) - %15 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor loc(#loc8) - %16 = stablehlo.compare EQ, %14#1, %15, SIGNED 
: (tensor, tensor) -> tensor loc(#loc8) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc8) - %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %18 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc8) - %19 = stablehlo.broadcast_in_dim %17, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc8) - %20 = stablehlo.select %19, %14#0, %18 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc8) - %21 = call @triu(%8) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc9) - return %20, %21 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc14) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc15) - return %6 : tensor<3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03~\x02\xf39\x01\x99\x0f\x17\x13\x0f\x0f\x0b\x0b\x07\x0b\x13\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03[O/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x01\x05\x0b\x0f\x035\x17\x0f\x0f\x07\x07\x07\x17\x17\x13\x07\x07\x0f\x17\x13\x17\x17\x13\x13\x17\x13\x13\x13\x13\x13\x13\x17\x02\xde\x08\x1d}\x03\x17\x1fj\x05\x01\x03\x03\x11\xcf\x1d\x93\x03\x1dU\x03\x05\x1f\x05!\x1f\x05#\x03\x03\x0b\xe5\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03#\xcb\x05/\x1d]\x03\x051\x053\x03\x03\x0b\xd5\x17\x1ff\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03\x0b\xe1\x03\x05'\xab)\xe3\x03\x03\x11\xe7\x03\tGIK\x15M\x15\rO\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x17\x9d\x19\xb5\x1b\xb7\r\xc1\x1d\xc3\x03\x0b\x17\xa7\x19\xc7\x1b\xa7\r\xa9\x1d\xc9\x05M\x1dY\x03\x05O\x03\x03\x0b\xcd\x05Q\x03\x03#\xd1\x1dc\x03\x05S\x03\x05'\xab)\xd3\x1di\x03\x05U\x1dm\x03\x05W\x1dq\x03\x05Y\x1du-\x05[\x1dy-\x05]\x03\x11/\xad1\xd73\xd95\x9d7\xaf9\xdb;\xb1=\xdf\x05_\x03\x03\x11\xe9\x1d\x83\x03\x05a\x03\x07\x87\xa3\x89\xa3\x8b\xa3\x05c\x05e\x05g\x1d\x8f\x03\x05i\x03\x11/\xad1\xeb3\xed5\x9d7\xaf9\xef;\xb1=\xf1\x05k\x03\x03\x97\xa9\x05m\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc5\x1du\t\x07\x0b\x05\x05\x01\x03\x03\xdd\x1f/\x01#!\x03\x05\xb9\xbd\r\x05\xa5\xbb\x9f\xa1\x1dw\r\x05\xa5\xbf\x9f\xa1\x1dy\x1d{\x1d}\r\x03\x9f\xa1##\x1d\x7f\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f%\x01\x13\r\x05\x07\x05\x1f\t\t\x00\x00\x00\x00\x1d\x81\x1d\x83\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xb3\x9b\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\t\t\x00\x00\xc0\x7f\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x85\x1d\x87\x03\x05\x99\x9b\x03\x07\x99\xb3\x9b\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x19)\x01\x0b\t\x1d\x01)\x05\r\r\x19)\x05\r\r\x0f)\x03\r\x0b\x13\x1b)\x01\x0f)\x05\x05\x05\x0f)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\x02\x08\x0b)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x05\x0f)\x03\r\x0f)\x03\x05\r)\x03\x02\x04\x0b\x04\x1a\x05\x05\x01\x11\x0fE\x07\x03\x01\t\r\x11\x0fQ\x07\x03Cu\t\x03s!\x03'\x15\x06w\x03\x05\x03\x01\x11\x07\x01{\t\x05\x15\x07)\x03\x03\x05\x03\x01?\x03\x07\x03\x07\x01\x05\x03\x07\x03\r\x0b\x07\x01A\x03\x1b\x05\t\x0f\x03\x07\x01\x05\x03\x1d\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x05\x03\x15\x03\x07\x01C\x03\x13\x03\x13\x07\x06\x01\x03\x05\x07\x19\x05\x17\x03\x07\x01\x05\x031\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x15\x03\x1f\x03\x07\x01\x7f\x033\x03\x1d\x07\x06\x01\x03\x15\x07#\x07!\x05\x03\x81+\x03\t\x17\x07\x8d\x85\x03\x05\x05\x1b'\x11\x07\x07\x91\x07\x05\x077\x05)%\x05\x03\x07?\x03\x07\x03\x07\x07\x05\x03\x07\x031\x0b\x07\x07A\x03\x1b\x05-3\x03\x07\x07\x05\x03\x1d\x035\x05\x03\x07\x13\x03\t\x03\x07\x07\x05\x03\x05\x039\x03\x07\x07C\x03\x13\x037\x07\x06\x07\x03\x05\x07=+;\x19\x07\t\x95\x03\x05\x03\x1b\x0f\x04\x0f\x05?A\r\x11\tS\x07\x03\x15+\x03\x05\t\t\x03W!\x03\x11\x05\x03\t[\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x
13\x06%\x03\x11\x05\x03\x07\t\x03a_\x03\x11\x0b\x07ge\x03\x13\x05\t\x0b\x05\x03\t+\x03\t\x03\x07k\x05\x03\x05\x03\x0f\x07\x06o\x03\x05\x07\r\x11\x01\x0f\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xea\x1a\x89!3!+\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15+\x13\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x01\x00\x00\x00hipsolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["batched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipblas_geqrf_batched', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[[ 0. 
, 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], - - [[-0.42426407, 0.8082888 , 0.4082513 ], - [-0.5656854 , 0.11547317, -0.81649613], - [-0.7071068 , -0.5773518 , 0.40824607]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607313e+01], - [ 0.0000000e+00, 3.4641036e-01, 6.9281983e-01], - [ 0.0000000e+00, 0.0000000e+00, 8.3555670e-07]]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipblas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> (tensor<2x3x3xf32>, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>) loc(#loc5) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> loc(#loc7) - %4:3 = stablehlo.custom_call @hipsolver_orgqr(%3, %2#1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> (tensor<2x3x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc8) - %c = stablehlo.constant dense<0> : tensor loc(#loc8) - %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc8) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc8) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc8) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc8) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> loc(#loc8) - %10 = stablehlo.select %9, %4#0, %8 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc8) - %11 = call @triu(%2#0) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> loc(#loc9) - return %10, %11 : tensor<2x3x3xf32>, tensor<2x3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: 
tensor<2x3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<2x3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %5 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc15) - %7 = stablehlo.select %5, %6, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc16) - return %7 : tensor<2x3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\x96\x02\xfb=\x01\x9f\x17\x0f\x0f\x0b\x13\x0b\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03]o/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x01\x05\x0b\x0f\x039\x1b\x07\x07\x0f\x17\x0f\x07\x07\x07\x1b\x13\x13\x13\x17\x17\x13\x17\x13\x13\x17\x07\x13\x13\x13\x17\x13\x1b\x13\x02\xf6\t\x17\x1bj\x05\x01\x1d\x8f\x01\x1dK\x01\x05\x1f\x03\x03\x0b\xd5\x05!\x05#\x1f\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03\x1f\xd1\x05/\x1dS\x01\x051\x053\x03\x03\x07\xdd\x17\x1bf\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\t=?A\x11C\x11\rE\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x13\xa3\x15\xbb\x17\xbd\r\xc7\x19\xc9\x03\x0b\x13\xad\x15\xcd\x17\xad\r\xaf\x19\xcf\x05M\x1dO\x01\x05O\x03\x03\x07\xd3\x05Q\x03\x03\x1f\xd7\x1dY\x01\x05S\x03\x05#\xb1%\xd9\x1d_\x01\x05U\x03\x03\x0b\xdb\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq)\x05]\x1du)\x05_\x03\x11+\xb3-\xdf/\xe11\xa33\xb55\xe37\xb79\xe7\x1d{\x01\x05a\x1d\x7f\x01\x05c\x03\x07\x83\xa9\x85\xa9\x87\xa9\x05e\x05g\x05i\x1d\x8b\x01\x05k\x03\x11+\xb3-\xe9/\xeb1\xa33\xb55\xed7\xb79\xef\x05m\x03\x03\x07\xf1\x03\x05#\xb1%\xf3\x03\x03\x0b\xf5\x03\x03\x07\xf7\x03\x03\x0b\xf9\x03\x03\x9d\xaf\x05o\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dq\x1ds\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1du\x03\x03\xcb\x1dw\t\x07\x0b\x05\x05\x01\x03\x03\xe5\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbf\xc3\r\x05\xab\xc1\xa5\xa7\x1dy\r\x05\xab\xc5\xa5\xa7\x1d{\x1d}\x1d\x7f\r\x03\xa5\xa7#!\x1d\x81\x13\x07\x01\x1f\x0f\t\xff\xff\xff\xff\x1f#\x01\x13\x07\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\x00\x00\x1d\x83\x1d\x85\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb9\xa1\xa1\x1d\x87\x1d\x89\x03\x05\x9f\xb9\x03\x07\x9f\xa1\xa1\x1f\x0f\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\t)\x01\t)\x05\r\r\x13)\x01\x13\x01\x1b\x13)\x07\t\r\r\x11)\x03A-)\x03\r\x07)\x03\t\x13\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\x07)\x05\r\r\x11)\x03\t\x07)\x03I\t)\x05\t\r\t\x17)\x03\r\x15)\x03\t\x15)\x03\x05\x15)\x03\x02\x04\t)\x03\t\x11)\x07\t\x05\x05\x11)\x03\x05\x07\x04\xa6\x03\x05\x01\x11\x0f;\x07\x03\x01\t\t\x11\x0fG\x07\x03)A\x07\x03o\x1d\x03)\x15\x06s\x03\x05\x03\x01\x11\x07yw\t\x05+\x19\x19\x03\x03\x05\x03}'\x03\x0b\x17\x07\x89\x81\x03\x05\x05\x05\r\x11\x07\x03\x8d\x07\x05\x1d5\x05\x0f\x07\x05\x03\x03\x91\x03\x0f\x03\x07\x03\t\x03\x1d\x03\x17\x0b\x07\x03\x93\x037\x05\x13\x19\x03\x07\x03\x95\x039\x03\x1b\x05\x03\x03\x97\x03\x0b\x03\x07\x03\t\x03\x05\x03\x1f\x03\x07\x03\x99\x03\x17\x03\x1d\r\x06\x03\x03\x05\x07#\x11!\x19\x07\x05\x9b\x03\x05\x03\x05\x0f\x04\x0f\x05%'\t\x11\x05I\x07\x03\x17/\x03\x05\x05\x07\x03M\x1d\x03\r\x05\x03\x05Q\x03\x0f\x03\x07!\t\x03\r\x03\x05\x13\x06!\x03\r\x05\x03\x07\x07\x03WU\x03\r\x0b\x07][\x03%\x05\t\x0b\
x03\x07ca\x03\x17\x03\r\x05\x03\x05'\x03\x0b\x03\x07g\t\x03\x05\x03\x11\r\x06k\x03\x05\x07\x0f\x13\x01\x0f\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00\xbe\x1c\x8b!3-#\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15\x13+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00hipblas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index d35e421ef904..d7035c92b24a 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -64,7 +64,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", @@ -98,55 +97,6 @@ cc_library( ], ) -cc_library( - name = "cublas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":cuda_blas_handle_pool", - ":cuda_gpu_kernel_helpers", - ":cuda_make_batch_pointers", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", - 
"@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":cublas_kernels", - ":cuda_vendor", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@nanobind", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "cudnn_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -516,7 +466,6 @@ cc_library( srcs = ["//jaxlib/gpu:gpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ - ":cublas_kernels", ":cuda_linalg_kernels", ":cuda_prng_kernels", ":cuda_vendor", @@ -651,7 +600,6 @@ nanobind_extension( py_library( name = "cuda_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 1fd2775ecf9a..e153e0588cf6 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -30,11 +30,8 @@ package( ) exports_files(srcs = [ - "blas.cc", "blas_handle_pool.cc", "blas_handle_pool.h", - "blas_kernels.cc", - "blas_kernels.h", "ffi_wrapper.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc deleted file mode 100644 index 59bf2c4603f6..000000000000 --- a/jaxlib/gpu/blas.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { -namespace { - -namespace nb = nanobind; - -// Converts a NumPy dtype to a Type. -BlasType DtypeToBlasType(const dtype& np_type) { - static auto* types = new absl::flat_hash_map, BlasType>({ - {{'f', 4}, BlasType::F32}, - {{'f', 8}, BlasType::F64}, - {{'c', 8}, BlasType::C64}, - {{'c', 16}, BlasType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -// Returns the descriptor for a GetrfBatched operation. 
-std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, - int b, int m, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})}; -} - -nb::dict Registrations() { - nb::dict dict; - dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - return dict; -} - -NB_MODULE(_blas, m) { - tsl::ImportNumpy(); - - m.def("registrations", &Registrations); - m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); -} - -} // namespace -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc deleted file mode 100644 index cdcc154d026d..000000000000 --- a/jaxlib/gpu/blas_kernels.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/make_batch_pointers.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -namespace { - -int SizeOfBlasType(BlasType type) { - switch (type) { - case BlasType::F32: - return sizeof(float); - case BlasType::F64: - return sizeof(double); - case BlasType::C64: - return sizeof(gpublasComplex); - case BlasType::C128: - return sizeof(gpublasDoubleComplex); - } -} - -} // namespace - -// Batched QR decomposition: geqrfbatched - -static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - std::vector info(d.batch); - MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** a_batch_ptrs = static_cast(buffers[3]); - float** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::F64: { - 
double** a_batch_ptrs = static_cast(buffers[3]); - double** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** a_batch_ptrs = static_cast(buffers[3]); - gpublasComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - gpublasDoubleComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - } - auto it = - std::find_if(info.begin(), info.end(), [](int i) { return i != 0; }); - - if (it != info.end()) { - return absl::InvalidArgumentError( - absl::StrFormat("QR decomposition failed with status %d for batch " - "element %d", - *it, std::distance(info.begin(), it))); - } - - return absl::OkStatus(); -} - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h deleted file mode 100644 index 8ca7b4db4668..000000000000 --- a/jaxlib/gpu/blas_kernels.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_GPU_BLAS_KERNELS_H_ -#define JAXLIB_GPU_BLAS_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. -enum class BlasType { - F32, - F64, - C64, - C128, -}; - -// Batched QR decomposition: geqrfbatched - -struct GeqrfBatchedDescriptor { - BlasType type; - int batch, m, n; -}; - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_GPU_BLAS_KERNELS_H_ diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 840c313f2fa3..620f9cf45199 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -16,7 +16,6 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. 
-#include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" @@ -33,21 +32,17 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, - "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", SyrkFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", CsrlsvqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 8013d9877ed5..3c76598e5285 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -54,45 +54,6 @@ SolverType DtypeToSolverType(const dtype& np_type) { return it->second; } -// geqrf: QR decomposition - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildGeqrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})}; -} - #ifdef JAX_GPU_CUDA // csrlsvqr: Linear system solve via Sparse QR @@ -106,49 +67,6 @@ nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -// Returns the workspace size and a descriptor for a geqrf operation. 
-std::pair BuildOrgqrDescriptor(const dtype& dtype, int b, int m, - int n, int k) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})}; -} - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd // Returns the workspace size and a descriptor for a syevd operation. @@ -423,8 +341,6 @@ std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); - dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj); dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd); @@ -456,8 +372,6 @@ nb::dict Registrations() { NB_MODULE(_solver, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); - m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); m.def("build_syevd_descriptor", &BuildSyevdDescriptor); m.def("build_syevj_descriptor", &BuildSyevjDescriptor); m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8971619d7f34..040b5a137bc6 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -50,92 +50,6 @@ static int SizeOfSolverType(SolverType type) { } } -// geqrf: QR decomposition - -static absl::Status Geqrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - 
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - gpuComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - gpuDoubleComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Geqrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - #ifdef JAX_GPU_CUDA // csrlsvqr: Linear system solve via Sparse QR @@ -237,92 +151,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -static absl::Status Orgqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const OrgqrDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[2] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[2], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[2]); - float* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[2]); - double* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[2]); - gpuComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[2]); - gpuDoubleComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a 
+= d.m * d.n; - tau += d.k; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Orgqr_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd static absl::Status Syevd_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index 6372e55b930d..a68aaf1ca233 100644 --- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -33,16 +33,6 @@ enum class SolverType { C128, }; -// geqrf: QR decomposition - -struct GeqrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - #ifdef JAX_GPU_CUDA // csrlsvpr: Linear system solve via Sparse QR @@ -58,16 +48,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -struct OrgqrDescriptor { - SolverType type; - int batch, m, n, k, lwork; -}; - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd struct SyevdDescriptor { diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index efb58f9a4164..c846c63e2ff8 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -16,21 +16,15 @@ from .plugin_support import import_from_plugin -_cublas = import_from_plugin("cuda", "_blas") _cusolver = import_from_plugin("cuda", "_solver") _cuhybrid = import_from_plugin("cuda", "_hybrid") -_hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") def registrations() -> dict[str, list[tuple[str, Any, int]]]: registrations = {"CUDA": [], "ROCM": []} - for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: - if module: - registrations[platform].extend( - (*i, 0) for i in module.registrations().items()) for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: if module: registrations[platform].extend( diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 358a6d1cc9aa..5893af26de85 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -87,54 +87,6 @@ cc_library( ], ) -cc_library( - name = "hipblas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":hip_blas_handle_pool", - ":hip_gpu_kernel_helpers", - ":hip_make_batch_pointers", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = 
["-use_header_modules"], - module_name = "_blas", - deps = [ - ":hip_vendor", - ":hipblas_kernels", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "miopen_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -552,7 +504,6 @@ nanobind_extension( py_library( name = "rocm_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 2f81eacbdde4..e9684108caf0 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -102,7 +102,6 @@ def prepare_wheel_cuda( dst_dir=plugin_dir, src_files=[ f"__main__/jaxlib/cuda/_solver.{pyext}", - f"__main__/jaxlib/cuda/_blas.{pyext}", f"__main__/jaxlib/cuda/_linalg.{pyext}", f"__main__/jaxlib/cuda/_prng.{pyext}", f"__main__/jaxlib/cuda/_rnn.{pyext}", @@ -140,7 +139,6 @@ def prepare_wheel_rocm( copy_runfiles( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/rocm/_blas.{pyext}", f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_solver.{pyext}", diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 6a6c8c213a64..789838f99d14 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -38,7 +38,6 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf -from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd @@ -147,7 +146,6 @@ def test_custom_call_coverage(self): cuda_svd_cusolver_gesvd.data_2024_10_08, cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09, cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, - rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, cpu_svd_lapack_gesdd.data_2023_06_19, @@ -454,29 +452,6 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=info["custom_call_targets"]) - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{batched}", - dtype_name=dtype_name, batched=batched) - for dtype_name in ("f32",) - # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched. 
- for batched in ("batched", "unbatched")) - def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched): - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - dtype = dict(f32=np.float32)[dtype_name] - rtol = dict(f32=1e-3)[dtype_name] - shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched] - func = lambda: CompatTest.qr_harness(shape, dtype) - self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=[ - f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) From 576843283bf0e1a1be5fce7fe25dd1f8db94cde7 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 2 Apr 2025 14:57:30 +0000 Subject: [PATCH 0332/1769] Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1 Description: - Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1 - When run inside TSAN CI job with 3.14t cpython and under multi-threading the test code from main leads to `RecursionError: maximum recursion depth exceeded` error: ``` ERROR: testLobpcgMonotonicityF32cluster_k_1__n100 (__main__.F32LobpcgTest) F32LobpcgTest.testLobpcgMonotonicityF32cluster_k_1__n100 testLobpcgMonotonicityF32cluster_k_1__n100(matrix_name='cluster(k-1)', n=100, k=10, m=20, tol=2e-06) ---------------------------------------------------------------------- Traceback (most recent call last): File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_absl_py/site-packages/absl/testing/parameterized.py", line 319, in bound_param_test return test_method(self, **testcase_params) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 408, in testLobpcgMonotonicityF32 self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 272, in checkLobpcgMonotonicity self._possibly_plot(A, eigs, X, m, matrix_name) ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 290, in _possibly_plot self._debug_plots(X, eigs, info, matrix_name, plot_dir) ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 318, in _debug_plots ax0.legend() ~~~~~~~~~~^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/axes/_axes.py", line 337, in legend self.legend_ = mlegend.Legend(self, handles, labels, **kwargs) ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
"/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend.py", line 549, in __init__ self._init_legend_box(handles, labels, markerfirst) ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend.py", line 896, in _init_legend_box handle_list.append(handler.legend_artist(self, orig_handle, ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^ fontsize, handlebox)) ^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 129, in legend_artist artists = self.create_artists(legend, orig_handle, xdescent, ydescent, width, height, fontsize, handlebox.get_transform()) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 303, in create_artists self.update_prop(legline, orig_handle, legend) ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 88, in update_prop self._update_prop(legend_handle, orig_handle) ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 79, in _update_prop self._default_update_prop(legend_handle, orig_handle) ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 84, in _default_update_prop legend_handle.update_from(orig_handle) ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/lines.py", line 1358, in update_from self._marker = MarkerStyle(marker=other._marker) ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/markers.py", line 248, in __init__ self._set_marker(marker) ~~~~~~~~~~~~~~~~^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/markers.py", line 323, in _set_marker self.__dict__ = copy.deepcopy(marker.__dict__) ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 131, in deepcopy y = copier(x, memo) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 202, in _deepcopy_dict 
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ~~~~~~~~^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 138, in deepcopy
    y = copier(memo)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/path.py", line 285, in __deepcopy__
    p = copy.deepcopy(super(), memo)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 157, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 234, in _reconstruct
    y = func(*args)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 233, in <genexpr>
    args = (deepcopy(arg, memo) for arg in args)
            ~~~~~~~~^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 138, in deepcopy
    y = copier(memo)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/path.py", line 285, in __deepcopy__
    p = copy.deepcopy(super(), memo)
```
---
 tests/BUILD          | 5 ++++-
 tests/lobpcg_test.py | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/BUILD b/tests/BUILD
index b501a614da39..23d59e8d549a 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -239,7 +239,10 @@ jax_multiplatform_test(
 jax_multiplatform_test(
     name = "lobpcg_test",
     srcs = ["lobpcg_test.py"],
-    env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"},
+    # Set LOBPCG_EMIT_DEBUG_PLOTS=1 to debug
+    # checkLobpcgMonotonicity and checkApproxEigs tests
+    # using matplotlib plots
+    # env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"},
     shard_count = {
         "cpu": 48,
         "gpu": 48,
diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py
index fc2b0df849d1..76d6006432f4 100644
--- a/tests/lobpcg_test.py
+++ b/tests/lobpcg_test.py
@@ -272,7 +272,7 @@ def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype):
     self._possibly_plot(A, eigs, X, m, matrix_name)

   def _possibly_plot(self, A, eigs, X, m, matrix_name):
-    if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'):
+    if os.getenv('LOBPCG_EMIT_DEBUG_PLOTS', '0') != '1':
       return

     if isinstance(A, (np.ndarray, jax.Array)):

From 2e16367991bd72f98381d40a77835c4b03c2c3e1 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 2 Apr 2025 08:29:12 -0700
Subject: [PATCH 0333/1769] Remove the extra stack frame that was introduced in
 uniform due to dropping the entire function in auto axes.
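
For context, a minimal sketch of the pattern applied here (the names
`sample_before`, `sample_after`, `_wrapper`, and `_impl` are hypothetical
stand-ins, not the real JAX internals): jitting a thin wrapper that
immediately calls the implementation adds one traced call frame to every
traceback, while handling the optional argument first and jitting the
implementation itself does not.

```
from functools import partial
import jax

# Before (sketch): the public entry point always routes through a jitted
# wrapper, which shows up as one extra frame in every traceback.
@partial(jax.jit, static_argnums=(1,))
def _wrapper(key, shape, out_sharding):
  del out_sharding  # the real code branches on this inside the wrapper
  return jax.random.uniform(key, shape)

def sample_before(key, shape, out_sharding=None):
  return _wrapper(key, shape, out_sharding)

# After (sketch): branch on the optional argument up front and jit the
# implementation itself, so the wrapper frame disappears.
@partial(jax.jit, static_argnums=(1,))
def _impl(key, shape):
  return jax.random.uniform(key, shape)

def sample_after(key, shape, out_sharding=None):
  if out_sharding is None:
    return _impl(key, shape)
  raise NotImplementedError("sharding branch elided in this sketch")

key = jax.random.key(0)
assert sample_before(key, (2,), None).shape == sample_after(key, (2,)).shape
```

In `uniform` this removes `_uniform_auto` from user-visible stack traces
while leaving the `out_sharding is None` fast path unchanged.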
PiperOrigin-RevId: 743148311 --- jax/_src/random.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 7277ed5aa966..0dcbda7bb717 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -406,15 +406,13 @@ def uniform(key: ArrayLike, raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) - -@partial(jit, static_argnums=(1, 2, 5)) -def _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) -> Array: if out_sharding is None: return _uniform(key, shape, dtype, minval, maxval) - def f(key, minval, maxval): return _uniform(key, shape, dtype, minval, maxval) + def f(k, minv, maxv): + return _uniform(k, shape, dtype, minv, maxv) return auto_axes(f, out_shardings=out_sharding)(key, minval, maxval) +@partial(jit, static_argnums=(1, 2)) def _uniform(key, shape, dtype, minval, maxval) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): From 3aeabaedea957293ba6a2f777d8cb30a9bf0aed4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 13:13:25 -0700 Subject: [PATCH 0334/1769] jnp.isinf & friends: support __jax_array__ --- jax/_src/numpy/ufuncs.py | 10 ++++++---- tests/array_extensibility_test.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 60e10b3be048..3fe63545e6df 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -36,7 +36,7 @@ from jax._src.numpy import reductions from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy.util import ( - check_arraylike, promote_args, promote_args_inexact, + check_arraylike, ensure_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, check_no_float0s) from jax._src.util import set_module @@ -3500,7 +3500,7 @@ def isinf(x: ArrayLike, /) -> Array: >>> jnp.isinf(x) Array([False, True, False, True, False], dtype=bool) """ - check_arraylike("isinf", x) + x = ensure_arraylike("isinf", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): return lax.eq(lax.abs(x), _constant_like(x, np.inf)) @@ -3513,7 +3513,7 @@ def isinf(x: ArrayLike, /) -> Array: return lax.full_like(x, False, dtype=np.bool_) -def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: +def _isposneginf(infinity: float, x: Array, out) -> Array: if out is not None: raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") dtype = dtypes.dtype(x) @@ -3556,6 +3556,7 @@ def isposinf(x, /, out=None): >>> jnp.isposinf(x) Array([False, False, True, False, False], dtype=bool) """ + x = ensure_arraylike("isposinf", x) return _isposneginf(np.inf, x, out) @@ -3590,6 +3591,7 @@ def isneginf(x, /, out=None): >>> jnp.isneginf(x) Array([ True, False, False, False, False], dtype=bool) """ + x = ensure_arraylike("isneginf", x) return _isposneginf(-np.inf, x, out) @@ -3624,7 +3626,7 @@ def isnan(x: ArrayLike, /) -> Array: >>> jnp.isnan(x) Array([False, False, False, True], dtype=bool) """ - check_arraylike("isnan", x) + x = ensure_arraylike("isnan", x) return lax.ne(x, x) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 8f5ea33b5894..55089720f520 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -353,8 +353,8 @@ def __getitem__(self, 
shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.isin, Int[5], Int[10]), NumPyAPI.sig(jnp.isinf, Float[5]), NumPyAPI.sig(jnp.isnan, Float[5]), - # NumPyAPI.sig(jnp.isneginf, Float[5]), - # NumPyAPI.sig(jnp.isposinf, Float[5]), + NumPyAPI.sig(jnp.isneginf, Float[5]), + NumPyAPI.sig(jnp.isposinf, Float[5]), NumPyAPI.sig(jnp.isreal, Float[5]), NumPyAPI.sig(jnp.isrealobj, Float[5]), NumPyAPI.sig(jnp.isscalar, Float[()]), From 2a24b407368b392b384885c727eb0323be01c802 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 2 Apr 2025 16:42:25 +0000 Subject: [PATCH 0335/1769] Bump actions/cache from 4.2.0 to 4.2.3 Bumps [actions/cache](https://github.com/actions/cache) from 4.2.0 to 4.2.3. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/1bd1e32a3bdc45362d1e726936510720a7c30a57...5a3ec84eff668545956fd18022155c47e93e2684) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 2 +- .github/workflows/tsan.yaml | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index c575c84cd422..5576ccd6e745 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -35,7 +35,7 @@ jobs: with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + - uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 4c28608a8257..1bdb36b2cd03 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -74,7 +74,7 @@ jobs: - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz @@ -97,7 +97,7 @@ jobs: - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz @@ -105,7 +105,7 @@ jobs: - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse @@ -160,7 +160,7 @@ jobs: - name: Save TSAN Numpy wheel id: cache-numpy-tsan-save if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse @@ -169,7 +169,7 @@ jobs: - name: Restore cached Scipy if: ${{ matrix.python-version == '3.14' }} id: cache-scipy-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: 
actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse @@ -236,7 +236,7 @@ jobs: - name: Save Scipy wheel id: cache-scipy-save if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse From 3d70fc819748b2ac78025653e9625660b2664886 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 2 Apr 2025 10:20:32 -0700 Subject: [PATCH 0336/1769] Add pbroadcast insertion for `psum_p` in the traceable. This effectively replaces `psum_p` with `psum2_p` if `varying_axes_in_types` is on. psum_p can be replaced with psum2_p in follow up CLs Also populate the aval of `ShardMapTracer` with `vma` PiperOrigin-RevId: 743188081 --- jax/_src/lax/parallel.py | 22 ++++++++++++++++++++-- jax/experimental/shard_map.py | 9 +++++++-- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 39b6c68679ca..5e02318da441 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -141,10 +141,28 @@ def pos_reduce(x): size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) else: - out_flat = psum_p.bind( - *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + if config.varying_axes_in_types.value: + out_flat = bind_psum2_p(leaves, axes=tuple(axis_name), + axis_index_groups=axis_index_groups) + else: + out_flat = psum_p.bind( + *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) +def bind_psum2_p(leaves, *, axes, axis_index_groups): + if axis_index_groups is not None: + raise NotImplementedError + + from jax.experimental.shard_map import psum2_p, pbroadcast + axes_ = frozenset(axes) + args_ = [] + for x in leaves: + in_vma = core.get_aval(x).vma + args_.append(pbroadcast(x, tuple(pbroadcast_names)) + if (pbroadcast_names := axes_ - in_vma) else x) + return psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups) + + def pmean(x, axis_name, *, axis_index_groups=None): """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``. diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 4b9daf170dce..ef3751c96901 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1018,7 +1018,10 @@ def aval(self): new_sharding = NamedSharding( _as_manual_mesh(self._trace.mesh, self._trace.auto), out.sharding.spec) # pytype: disable=attribute-error - return out.update(sharding=new_sharding) + manual_axes = set(self._trace.mesh.axis_names) - self._trace.auto + vma = (frozenset(manual_axes - self.rep) + if config.varying_axes_in_types.value else frozenset()) + return out.update(sharding=new_sharding, vma=vma) def to_concrete_value(self): if self.rep == set(self._trace.mesh.axis_names): @@ -1111,7 +1114,9 @@ def _pbroadcast_abstract_eval(*args, axes, axis_index_groups): f"over axis name {axes}. 
Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") - return [a.update(vma=a.vma.union(frozenset(axes))) for a in args] + sharding = NamedSharding(get_abstract_mesh(), P()) + return [a.update(sharding=sharding, vma=a.vma.union(frozenset(axes))) + for a in args] pbroadcast_p.def_abstract_eval(_pbroadcast_abstract_eval) mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) From 92f7aeab48f144ba059cac29406b267e4030fe31 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 2 Apr 2025 12:09:48 -0700 Subject: [PATCH 0337/1769] Add simple vmap support for lax.ragged_all_to_all. PiperOrigin-RevId: 743230485 --- jax/_src/lax/parallel.py | 27 +++++ tests/ragged_collective_test.py | 194 ++++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 5e02318da441..e533672a1d9b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1374,11 +1374,38 @@ def _ragged_all_to_all_transpose( output_t = jax.numpy.where(mask, 0, t) return [operand_t, output_t] + [None] * 4 +def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, + axis_name, axis_index_groups): + del axis_data + if axis_index_groups: + raise NotImplementedError("Please open a feature request!") + + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes = vals_in + operand_dim, output_dim, input_offsets_dim, send_sizes_dim, output_offsets_dim, recv_sizes_dim = dims_in + if not (operand.shape[operand_dim] == output.shape[output_dim] == input_offsets.shape[input_offsets_dim] == send_sizes.shape[send_sizes_dim] == output_offsets.shape[output_offsets_dim] == recv_sizes.shape[recv_sizes_dim]): + raise ValueError("all operands must have the same batch sizes") + + sliced_results = [] + for i in range(operand.shape[operand_dim]): + sliced_operand = slicing.slice_in_dim(operand, start_index=i, limit_index=i+1, axis=operand_dim).flatten() + sliced_output = slicing.slice_in_dim(output, start_index=i, limit_index=i+1, axis=output_dim).flatten() + sliced_input_offsets = slicing.slice_in_dim(input_offsets, start_index=i, limit_index=i+1, axis=input_offsets_dim).flatten() + sliced_send_sizes = slicing.slice_in_dim(send_sizes, start_index=i, limit_index=i+1, axis=send_sizes_dim).flatten() + sliced_output_offsets = slicing.slice_in_dim(output_offsets, start_index=i, limit_index=i+1, axis=output_offsets_dim).flatten() + sliced_recv_sizes = slicing.slice_in_dim(recv_sizes, start_index=i, limit_index=i+1, axis=recv_sizes_dim).flatten() + sliced_result = ragged_all_to_all(sliced_operand, sliced_output, sliced_input_offsets, sliced_send_sizes, sliced_output_offsets, sliced_recv_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) + sliced_result = lax.expand_dims(sliced_result, dimensions=(output_dim,)) + sliced_results.append(sliced_result) + + concat_result = lax.concatenate(sliced_results, dimension=output_dim) + return concat_result, operand_dim + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) +batching.fancy_primitive_batchers[ragged_all_to_all_p] = _ragged_all_to_all_batched_collective 
batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') def insert_collective_pbroadcast(axis_name, x): diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 844892adc052..1dd6ef657561 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -21,6 +21,7 @@ import jax import jax.ad_checkpoint from jax import lax +from jax import vmap from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu @@ -381,6 +382,199 @@ def fwd( c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32) ) + @parameterized.named_parameters( + dict( + testcase_name='_batch_0_data_shard_axis_0_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=0, + input_config=0, + ), + dict( + testcase_name='_batch_0_data_shard_axis_1_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=1, + input_config=0, + ), + dict( + testcase_name='_batch_1_data_shard_axis_0_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=0, + input_config=1, + ), + dict( + testcase_name='_batch_1_data_shard_axis_1_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=1, + input_config=1, + ), + ) + def test_ragged_all_to_all_vmap( + self, + axis_name, + vmap_axis_name, + mesh_axes, + vmap_batch_axis, + data_shard_axis, + input_config, + ): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + + def get_data_sharding(axis): + if axis == 0: + return P(axis_name, None, None) + elif axis == 1: + return P(None, axis_name, None) + else: + raise ValueError("Invalid data_shard_axis") + + data_sharding = get_data_sharding(data_shard_axis) + + if input_config == 0: + operand_data = jnp.array([[[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 1]], + [[1, 2], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [1, 2]], + [[0, 0], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [2, 1]], + [[1, 1], [2, 1]]], dtype=jnp.int32) + elif input_config == 1: + operand_data = jnp.array([[[1, 2, 3], [1, 2, 3]], + [[4, 5, 6], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 2]], + [[1, 1], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [0, 0]], + [[1, 2], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [1, 1]], + [[2, 1], [2, 1]]], dtype=jnp.int32) + else: + raise ValueError("Invalid input config") + + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.array([[[0, 1], [0, 1]], + [[0, 1], [0, 1]]], dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, 
jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap( + fwd, in_axes=vmap_batch_axis, out_axes=0, axis_name=vmap_axis_name + )( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ).reshape( + (2, 2, 4) + ) + expected_res = jnp.array([[[1, 4, 0, 0], [2, 3, 5, 0]], + [[1, 4, 0, 0], [2, 3, 5, 0]]], dtype=jnp.int32) + self.assertAllClose(res, expected_res) + + def test_ragged_all_to_all_vmap_unsupported_axis_index_groups(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh_axes = dict(x=2) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + axis_index_groups=[[0, 1]], + ) + + with self.assertRaisesWithLiteralMatch( + NotImplementedError, 'Please open a feature request!'): + vmap(fwd, in_axes=0, out_axes=0, axis_name='b')(operand, output, input_offsets, send_sizes, 
output_offsets, recv_sizes) + def test_ragged_all_to_all_errors(self): operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32) output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32) From a442fecca8b75f0803a27601046ca66d5cba134c Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 2 Apr 2025 15:32:15 -0400 Subject: [PATCH 0338/1769] Fix custom_transpose when composed with custom_jvp and use_direct_linearize=True. --- jax/_src/custom_transpose.py | 22 +++++++++++++--------- tests/api_test.py | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 5e87fdb203c9..21e607b5bff2 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -177,15 +177,19 @@ def bind_with_trace(self, trace, call_args, params): # TODO(frostig,mattjj): consider keeping `call` as a named parameter # instead of following this "call primitive" convention. def get_bind_params(self, params): - assert 'call_jaxpr' in params - assert 'transpose_jaxpr_thunk' in params - new_params: dict[str, Any] = dict(params) - new_params['transpose'] = make_transpose_from_thunk( - new_params.pop('transpose_jaxpr_thunk'), - new_params['lin_tree']) - call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') - call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), - debug_info=call_jaxpr.jaxpr.debug_info) + if 'call_jaxpr' in params: + assert 'transpose_jaxpr_thunk' in params + new_params: dict[str, Any] = dict(params) + new_params['transpose'] = make_transpose_from_thunk( + new_params.pop('transpose_jaxpr_thunk'), + new_params['lin_tree']) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + else: + assert 'transpose' in params + new_params: dict[str, Any] = dict(params) + call = new_params.pop("call") return [call], new_params diff --git a/tests/api_test.py b/tests/api_test.py index 032c09910fd9..83264f10e033 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10387,6 +10387,28 @@ def cond_wrap(f): self.assertAllClose(f_(x), g_(x)) self.assertAllClose(f_t(x), g_t(x)) + def test_compose_custom_jvp(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return f(x), g(x, dx) + + @custom_transpose + def g(x, dx): + return jnp.cos(x) * dx + + @g.def_transpose + def gt(x, t): + return jnp.cos(x) * t + + with config.use_direct_linearize(True): + self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) + class CustomDceTest(jtu.JaxTestCase): From 7f4e8c56fe0b47778ad3795545df2e946a4b4a57 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 2 Apr 2025 13:17:59 -0700 Subject: [PATCH 0339/1769] jnp.concat and friends: support __jax_array__ --- jax/_src/numpy/lax_numpy.py | 7 ++++--- jax/_src/numpy/util.py | 2 +- tests/array_extensibility_test.py | 12 ++++++------ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ae32703e7113..43b8923246d4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4466,7 +4466,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: - util.check_arraylike("stack", *arrays) + arrays = util.ensure_arraylike_tuple("stack", arrays) shape0 = 
np.shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] @@ -4555,7 +4555,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: [1, 2], [3, 4]], dtype=int32) """ - util.check_arraylike("tile", A) + A = util.ensure_arraylike("tile", A) try: iter(reps) # type: ignore[arg-type] except TypeError: @@ -4628,7 +4628,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], """ if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) - util.check_arraylike("concatenate", *arrays) + arrays = util.ensure_arraylike_tuple("concatenate", arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") if axis is None: @@ -4870,6 +4870,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("dstack", *tup, emit_warning=True) + tup = util.ensure_arraylike_tuple("dstack", tup) arrs = [atleast_3d(m) for m in tup] return concatenate(arrs, axis=2, dtype=dtype) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e0e20d443e02..9d56267c4b61 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -159,7 +159,7 @@ def ensure_arraylike(fun_name: str, /, *args: Any) -> Array | tuple[Array, ...]: return tuple(_arraylike_asarray(arg) for arg in args) # pytype: disable=bad-return-type -def ensure_arraylike_tuple(fun_name: str, tup: tuple[Any, ...]) -> tuple[Array, ...]: +def ensure_arraylike_tuple(fun_name: str, tup: Sequence[Any]) -> tuple[Array, ...]: """Check that argument elements are arraylike and convert to a tuple of arrays. This is useful because ensure_arraylike with a single argument returns a single array. diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 55089720f520..730001abef76 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -267,10 +267,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.ceil, Float[5]), # NumPyAPI.sig(jnp.choose, [int, float], [(3,), (10,)]), NumPyAPI.sig(jnp.clip, Float[5]), - # NumPyAPI.sig(jnp.column_stack, [float], [(3, 10)]), + NumPyAPI.sig(jnp.column_stack, [Float[5], Float[5], Float[5]]), NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), - # NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), - # NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), NumPyAPI.sig(jnp.conj, Float[5]), NumPyAPI.sig(jnp.conjugate, Float[5]), NumPyAPI.sig(jnp.convolve, Float[7], Float[3]), @@ -300,7 +300,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.divmod, Float[5], Float[5]), NumPyAPI.sig(jnp.dot, Float[5], Float[5]), NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), - # NumPyAPI.sig(jnp.dstack, Float[3, 5]), + NumPyAPI.sig(jnp.dstack, [Float[3, 5, 1], Float[3, 5, 3]]), NumPyAPI.sig(jnp.ediff1d, Float[5]), NumPyAPI.sig(jnp.empty_like, Float[5]), NumPyAPI.sig(jnp.equal, Float[5], Float[5]), @@ -469,7 +469,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.sqrt, Float[5]), NumPyAPI.sig(jnp.square, Float[5]), NumPyAPI.sig(jnp.squeeze, Float[5]), - # NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), + NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), NumPyAPI.sig(jnp.std, Float[5]), NumPyAPI.sig(jnp.subtract, Float[5], Float[5]), NumPyAPI.sig(jnp.sum, 
Float[5]), @@ -479,7 +479,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.tan, Float[5]), NumPyAPI.sig(jnp.tanh, Float[5]), NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), - # NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), + NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), NumPyAPI.sig(jnp.trace, Float[5, 5]), NumPyAPI.sig(jnp.transpose, Float[5, 6]), NumPyAPI.sig(jnp.trapezoid, Float[5]), From a2d62e2d3a332b1d67e0f4ef7a23375182f1646e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 2 Apr 2025 13:46:07 -0700 Subject: [PATCH 0340/1769] [array_api] update array_api_version to 2024.12 --- .github/workflows/jax-array-api.yml | 2 +- jax/_src/numpy/array_api_metadata.py | 2 +- tests/array_api_skips.txt | 55 ++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index c91ab6b8b7da..7df4228dd2a3 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 + ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index 5267e51215ee..d634a2856a1b 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -29,7 +29,7 @@ from jax._src import xla_bridge as xb -__array_api_version__ = '2023.12' +__array_api_version__ = '2024.12' def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2f8d4d1c666f..7534cf6f8acd 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -10,6 +10,24 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] 
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] + # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted @@ -19,3 +37,40 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_clip # JAX raises a ValueError rather than the expected IndexError for out-of-bound axis array_api_tests/test_manipulation_functions.py::test_expand_dims + +# Doesn't promote to uint64 +array_api_tests/test_statistical_functions.py::test_cumulative_prod + +# TODO(jakevdp): fix the following failures: + +# Returns NaN rather than inf +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is +0) -> -infinity] + +# Returns -1.0 rather than 0.0 +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] \ No newline at end of file From bff0fa18adaf0544d6465b62f4e57e8b83b4e614 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Wed, 2 Apr 2025 14:07:29 -0700 Subject: [PATCH 0341/1769] Support `conv` `unfused_flops` in roofline. Since calculating flops is non-trivial, we don't test all the cases currently tested by `test_conv_general_dilated_unfused_hbm_bytes`. Instead, we test behaviors more directly. PiperOrigin-RevId: 743272840 --- jax/experimental/roofline/rooflines.py | 203 ++++++++++++++++++++++++- tests/roofline_test.py | 106 ++++++++++++- 2 files changed, 298 insertions(+), 11 deletions(-) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index bc8d65e966dd..63a2d7cc4698 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -14,6 +14,7 @@ from collections import defaultdict from dataclasses import replace import itertools as it +from typing import Sequence import numpy as np from jax._src import ad_util @@ -35,7 +36,7 @@ from jax.experimental import roofline from jax.experimental import shard_map - +# One FMA (Fused Multiply Add) takes 2 flops to compute. _FMA_FLOPS_FACTOR = 2 for prim in it.chain( @@ -179,16 +180,208 @@ def _dot_general_roofline( unfused_hbm_bytes=hbm_bytes, ) + +def _get_spatial_valid_position_count_for_one_dim( + window_dim_stride: int, + base_dilation: int, + window_dilation: int, + kernel_limit: int, + input_limit: int, + output_limit: int, + padding: tuple[int, int], +) -> int: + """Gets the valid position count for conv for a single spatial dimension. + + Args: + window_dim_stride: The stride of the window along this dimension. + base_dilation: The base dilation factor along this dimension. + window_dilation: The window dilation factor along this dimension. + kernel_limit: The size of the kernel along this dimension. + input_limit: The size of the input along this dimension. + output_limit: The size of the output along this dimension. + padding: The padding applied to the input along this dimension. + """ + padding_low = padding[0] + padding_high = padding[1] + + # These two conditions will create an N^2 iteration pattern with only N + # valid elements. This is a performance optimization and produces the same + # result as the whole loop. 
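+  # (These early returns appear to mirror the special cases in XLA's
+  # HloCostAnalysis convolution handling. The general rule, implemented by
+  # the loops below: a kernel tap k and output position o form a valid
+  # position when o * stride - padding_low + k * window_dilation is an exact
+  # base-dilation multiple of an in-bounds input index.)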
+ if ( + input_limit == output_limit + and kernel_limit == output_limit + and input_limit == base_dilation + and window_dilation == 1 + and max(1, input_limit - 1) == window_dim_stride + and padding_low == 0 + and padding_high == 0 + ): + return input_limit + + if ( + input_limit == 1 + and kernel_limit == output_limit + and window_dilation == 1 + and base_dilation == 1 + and window_dim_stride == 1 + and padding_low == output_limit - 1 + and padding_high == output_limit - 1 + ): + return output_limit + + valid_position_count = 0 + # Loop over each point in the kernel + for kernel_idx in range(kernel_limit): + + # Skip loop for trivial stride and base_dilation + if window_dim_stride == 1 and base_dilation == 1: + undilated_index_base = padding_low - kernel_idx * window_dilation + upper_limit = min( + input_limit + undilated_index_base, + output_limit, + ) + lower_limit = max(0, undilated_index_base) + + valid_position_count += max(upper_limit - lower_limit, 0) + continue + + # Loop over each point in the output + for output_idx in range(output_limit): + # Calculate lhs (input) index without taking base dilation into account + undilated_index = ( + output_idx * window_dim_stride + - padding_low + + kernel_idx * window_dilation + ) + # Calculate the actual lhs (input) index after dilation + lhs_spatial_index = int(undilated_index / base_dilation) + + # Skip if the lhs (input) index is to be dilated. + if undilated_index != lhs_spatial_index * base_dilation: + continue + # Skip if input index is not in bound. + if lhs_spatial_index < 0 or lhs_spatial_index >= input_limit: + continue + + valid_position_count += 1 + return valid_position_count + + +def _get_spatial_valid_position_count( + dnums: convolution.ConvDimensionNumbers, + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], +) -> int: + """Gets the number of valid spatial positions for conv_general_dilated. + + Args: + dnums: The dimension numbers for the convolution. + lhs: The shape of the left-hand side of the convolution. + rhs: The shape of the right-hand side of the convolution. + out: The shape of the output of the convolution. + window_strides: The stride of the window along each spatial dimension. + padding: The padding applied to the input along each spatial dimension. + lhs_dilation: The dilation factor for the left-hand side along each spatial + dimension. + rhs_dilation: The dilation factor for the right-hand side along each spatial + dimension. + """ + input_spatial_dims, kernel_spatial_dims, out_spatial_dims = ( + dnums.lhs_spec[2:], + dnums.rhs_spec[2:], + dnums.out_spec[2:], + ) + + valid_position_counts = 1 + # Loop over each spatial dimension and determine how many valid positions + # there are for each dimension. 
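+  # Worked example (matching the SAME-padding test below): a 3x3 input and
+  # 3x3 kernel with stride 1 and padding (1, 1) give 2 + 3 + 2 = 7 valid
+  # positions per dimension, so 7 * 7 = 49 overall -- the sum of the
+  # per-output FMA map 4 6 4 / 6 9 6 / 4 6 4.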
+ for d in range(len(input_spatial_dims)): + valid_position_counts *= _get_spatial_valid_position_count_for_one_dim( + window_dim_stride=window_strides[d], + base_dilation=lhs_dilation[d], + window_dilation=rhs_dilation[d], + kernel_limit=rhs.shape[kernel_spatial_dims[d]], + input_limit=lhs.shape[input_spatial_dims[d]], + output_limit=out.shape[out_spatial_dims[d]], + padding=padding[d], + ) + + return valid_position_counts + + +def _calculate_conv_flops( + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, +) -> int: + """Calculates roofline unfused flops for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ + dnums = convolution.conv_dimension_numbers( + lhs.shape, rhs.shape, dimension_numbers + ) + + spatial_valid_position_counts = _get_spatial_valid_position_count( + dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation + ) + + batch = lhs.shape[dnums.lhs_spec[0]] + num_output_features = out.shape[dnums.out_spec[1]] + num_input_features = rhs.shape[dnums.rhs_spec[1]] + num_output_batch = batch / batch_group_count + + non_spatial_dims_factor = ( + num_input_features * num_output_features * num_output_batch + ) + + fma_count = non_spatial_dims_factor * spatial_valid_position_counts + flops = fma_count * _FMA_FLOPS_FACTOR + return int(flops) + + @roofline.register_roofline(convolution.conv_general_dilated_p) def _conv_general_dilated_roofline( - ctx: roofline.RooflineRuleContext, - *args, - **kw, + ctx: roofline.RooflineRuleContext, + *args, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, + **kw, ) -> roofline.RooflineResult: + """Roofline for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) - # TODO(b/394648206): support computing unfused_flops for conv. + return roofline.RooflineResult( + unfused_flops=_calculate_conv_flops( + lhs, + rhs, + out, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + batch_group_count, + ), unfused_hbm_bytes=( lhs.dtype.itemsize * lhs.size + rhs.dtype.itemsize * rhs.size diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 98f6176c22a0..140beb3c6e71 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -28,6 +28,8 @@ jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_VERY_LARGE_NUMBER = 512 * 1024 + def create_inputs( *shardings: P, @@ -628,7 +630,6 @@ def test_conv_general_dilated_unfused_hbm_bytes( expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) - # TODO(b/394648206): add subtest for unfused_flops once they are supported. 
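+    # unfused_flops for conv are now supported; dedicated assertions live in
+    # the flops-specific tests below.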
self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) @jtu.parameterized.named_parameters( @@ -641,10 +642,10 @@ def test_conv_general_dilated_unfused_hbm_bytes( padding="SAME_LOWER", ), ) - def test_conv_general_dilated_padding_string_unfused_hbm_bytes( + def test_conv_general_dilated_padding_string( self, padding: str ): - input_data = jnp.zeros((1, 1, 10, 20), dtype=int) + input_data = jnp.zeros((1, 1, 3, 3), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding=padding @@ -652,10 +653,11 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes( _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * 10 * 20 + # Test hbm bytes. + expected_input_size = 1 * 1 * 3 * 3 expected_kernel_size = 1 * 1 * 3 * 3 # Because of same{_lower} padding, output shape should equal to input shape. - # This may not be true for other `{feature, batch}`_group_count`s.c + # This may not be true for other `{feature, batch}`_group_count`s. expected_output_size = expected_input_size # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( @@ -663,7 +665,21 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes( ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) - def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): + # Test flops. + # For spatial_valid_position_counts, we have 3x3 output with the following + # flops for each element: + # 4 6 4 + # 6 9 6 + # 4 6 4 + # Non_spatial_dims_factor = 1 because `{batch, feature}_group_count` are + # both equal to 1. + # Each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, + 2 * (4 + 6 + 4 + 6 + 9 + 6 + 4 + 6 + 4), + ) + + def test_conv_general_dilated_padding_string_valid(self): input_data = jnp.zeros((1, 1, 10, 20), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( @@ -681,12 +697,90 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): * self.get_conv_output_dim(10, 3, 0, 0, 1) * self.get_conv_output_dim(20, 3, 0, 0, 1) ) + # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) + # Output shape is [1x1x8x18] and each output element requires (3x3) FMAs, + # and each FMA is 2 flops. 
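+    # With VALID padding every kernel tap lands on a real input element, so
+    # the per-output count is uniform (3 * 3 = 9 FMAs), unlike the ragged
+    # SAME-padding map above.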
+ self.assertEqual( + result.unfused_flops, 2 * expected_output_size * 3 * 3 + ) + + + @jtu.parameterized.named_parameters( + dict( + testcase_name="padding", + input_spatial_dim=1, + window_strides=[1], + padding=[(_VERY_LARGE_NUMBER - 1, _VERY_LARGE_NUMBER - 1)], + lhs_dilation=[1], + ), + dict( + testcase_name="input", + input_spatial_dim=_VERY_LARGE_NUMBER, + window_strides=[_VERY_LARGE_NUMBER - 1], + padding=[(0, 0)], + lhs_dilation=[_VERY_LARGE_NUMBER], + ), + ) + def test_conv_general_dilated_flops_very_large( + self, input_spatial_dim, window_strides, padding, lhs_dilation + ): + input_data = jnp.zeros((1, 1, input_spatial_dim), dtype=int) + kernel_data = jnp.ones((1, 1, _VERY_LARGE_NUMBER), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + self.assertEqual(result.unfused_flops, 2 * _VERY_LARGE_NUMBER) + + def test_conv_general_dilated_flops_feature_group_count(self): + feature_group_count = 120 + input_data = jnp.zeros((1, feature_group_count, 10, 20), dtype=int) + kernel_data = jnp.ones((feature_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + feature_group_count=feature_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [1x120x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + + def test_conv_general_dilated_flops_batch_group_count(self): + batch_group_count = 120 + input_data = jnp.zeros((batch_group_count, 1, 10, 20), dtype=int) + kernel_data = jnp.ones((batch_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + batch_group_count=batch_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [120x1x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. 
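+    # batch_group_count moves the batch into output features
+    # (num_output_batch = 120 / 120 = 1, num_output_features = 120), so the
+    # flop count matches the feature_group_count case above.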
+ self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + def test_reduce_sum_no_axis(self): _, result = roofline.roofline(lambda x: jnp.sum(x))(jnp.zeros((11, 4))) self.assertEqual(result.unfused_flops, 11 * 4 - 1) From 96780f19b0b8775f02dc5d57dda11597a2f9c97e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 2 Apr 2025 14:28:54 -0700 Subject: [PATCH 0342/1769] jax.numpy: support __jax_array__ in several more functions --- jax/_src/numpy/lax_numpy.py | 5 +++-- jax/_src/numpy/util.py | 2 +- tests/array_extensibility_test.py | 14 +++++++------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 43b8923246d4..dba208327adc 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2984,7 +2984,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32) """ - util.check_arraylike("bincount", x) + x = util.ensure_arraylike("bincount", x) if _dtype(x) == bool: x = lax.convert_element_type(x, 'int32') if not issubdtype(_dtype(x), np.integer): @@ -5018,7 +5018,7 @@ def choose(a, choices): """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") - util.check_arraylike('choose', a, *choices) + a, *choices = util.ensure_arraylike_tuple('choose', (a, *choices)) if not issubdtype(_dtype(a), np.integer): raise ValueError("`a` array must be integer typed") N = len(choices) @@ -8781,6 +8781,7 @@ def argwhere( >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32) """ + a = util.ensure_arraylike("argwhere", a) result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) if np.ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 9d56267c4b61..49605ffc3b0c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -259,7 +259,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]: def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None ) -> Array: - check_arraylike("broadcast_to", arr) + arr = ensure_arraylike("broadcast_to", arr) arr = arr if isinstance(arr, Array) else lax.asarray(arr) if not isinstance(shape, tuple) and np.ndim(shape) == 0: shape = (shape,) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 730001abef76..14fcc18ca7a5 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -92,6 +92,8 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: 'apply_along_axis', 'apply_over_axes', 'arange', + 'array_str', + 'array_repr', 'astype', 'bartlett', 'bfloat16', @@ -101,6 +103,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: 'bool_', 'broadcast_shapes', 'c_', + 'can_cast', 'cdouble', 'character', 'complex128', @@ -233,14 +236,12 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.argmin, Float[10]), NumPyAPI.sig(jnp.argpartition, Float[10], kth=5), NumPyAPI.sig(jnp.argsort, Float[10]), - # NumPyAPI.sig(jnp.argwhere, [float], [(10,)]), + NumPyAPI.sig(jnp.argwhere, Float[10]), NumPyAPI.sig(jnp.around, Float[5]), NumPyAPI.sig(jnp.array, Float[5]), NumPyAPI.sig(jnp.array_equal, Float[5], Float[5]), NumPyAPI.sig(jnp.array_equiv, Float[5], Float[5]), - # NumPyAPI.sig(jnp.array_repr, Float[5]), NumPyAPI.sig(jnp.array_split, Float[9], indices_or_sections=3), - # NumPyAPI.sig(jnp.array_str, Float[5]), NumPyAPI.sig(jnp.asarray, Float[5]), 
NumPyAPI.sig(jnp.asin, Float[5]), NumPyAPI.sig(jnp.asinh, Float[5]), @@ -251,7 +252,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.atleast_2d, Float[5]), NumPyAPI.sig(jnp.atleast_3d, Float[5]), NumPyAPI.sig(jnp.average, Float[10]), - # NumPyAPI.sig(jnp.bincount, int[10]), + NumPyAPI.sig(jnp.bincount, Int[10]), NumPyAPI.sig(jnp.bitwise_and, Int[5], Int[5]), NumPyAPI.sig(jnp.bitwise_count, Int[5]), NumPyAPI.sig(jnp.bitwise_invert, Int[5]), @@ -261,11 +262,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.bitwise_right_shift, Int[5], Int[5]), NumPyAPI.sig(jnp.bitwise_xor, Int[5], Int[5]), NumPyAPI.sig(jnp.broadcast_arrays, Float[5]), - # NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), - # NumPyAPI.sig(jnp.can_cast, Float[()], to='int32'), + NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), NumPyAPI.sig(jnp.cbrt, Float[5]), NumPyAPI.sig(jnp.ceil, Float[5]), - # NumPyAPI.sig(jnp.choose, [int, float], [(3,), (10,)]), + NumPyAPI.sig(jnp.choose, Int[3], [Float[3], Float[3], Float[3]], mode='clip'), NumPyAPI.sig(jnp.clip, Float[5]), NumPyAPI.sig(jnp.column_stack, [Float[5], Float[5], Float[5]]), NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), From 9c58a112b3e3ccf5a4eb8bdbddfb2760a9b2161a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 2 Apr 2025 14:58:06 -0700 Subject: [PATCH 0343/1769] `jnp.array` no longer accepts None PiperOrigin-RevId: 743291099 --- CHANGELOG.md | 5 +++++ jax/_src/numpy/lax_numpy.py | 9 +-------- tests/lax_numpy_test.py | 15 ++++----------- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68450dca4057..ffd197b390d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Breaking changes + + * {func}`jax.numpy.array` no longer accepts `None`. This behavior was + deprecated since November 2023 and is now removed. + * Changes * The minimum CuDNN version is v9.8. * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 43b8923246d4..1b59363d14c7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5502,14 +5502,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): - # Added Nov 16 2023 - if deprecations.is_accelerated("jax-numpy-array-none"): - raise ValueError("None is not a valid value for jnp.array") - warnings.warn( - "None encountered in jnp.array(); this is currently treated as NaN. 
" - "In the future this will result in an error.", - FutureWarning, stacklevel=2) - leaves, treedef = tree_flatten(object) + raise ValueError("None is not a valid value for jnp.array") leaves = [ leaf if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f94f42f027ce..2c305af6e8f5 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -47,7 +47,6 @@ from jax._src import array from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal @@ -3796,16 +3795,10 @@ def testArrayFromList(self): with self.assertRaisesRegex(OverflowError, "Python int too large.*"): jnp.array([0, val]) - def testArrayNoneWarning(self): - if deprecations.is_accelerated('jax-numpy-array-none'): - ctx = self.assertRaisesRegex( - ValueError, 'None is not a valid value for jnp.array' - ) - else: - ctx = self.assertWarnsRegex( - FutureWarning, r'None encountered in jnp.array\(\)' - ) - with ctx: + def testArrayNone(self): + with self.assertRaisesRegex( + ValueError, 'None is not a valid value for jnp.array' + ): jnp.array([0.0, None]) def testIssue121(self): From 9fa5de7b0584f74fce9d0eea89817e8fa9b96b8f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 2 Apr 2025 15:45:15 -0700 Subject: [PATCH 0344/1769] [pallas] Removed `pl.device_id`. Use `lax.axis_index` instead. PiperOrigin-RevId: 743307670 --- jax/_src/pallas/mosaic/lowering.py | 4 ---- jax/_src/pallas/primitives.py | 8 -------- jax/experimental/pallas/__init__.py | 1 - jax/experimental/pallas/tpu.py | 1 - tests/pallas/tpu_pallas_distributed_test.py | 3 +-- 5 files changed, 1 insertion(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f8f49f3d7aea..6c8b3c646a0d 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3602,10 +3602,6 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule -def _device_id_lowering_rule(ctx: LoweringRuleContext): - return tpu.device_id() -lowering_rules[primitives.device_id_p] = _device_id_lowering_rule - def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names if grid_names and axis_name in grid_names: diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 4971b83a9ba2..986a62571010 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1237,11 +1237,3 @@ def _semaphore_wait_discharge_rule(in_avals, state_discharge.register_discharge_rule(semaphore_wait_p)( _semaphore_wait_discharge_rule ) - -device_id_p = jax_core.Primitive('device_id') - -@device_id_p.def_abstract_eval -def _device_id_abstract_eval(): - return jax_core.ShapedArray((), jnp.dtype("int32")) - -device_id = device_id_p.bind diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index fd523712fa9c..b6d2ac69d2c6 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -48,7 +48,6 @@ from jax._src.pallas.primitives import atomic_xchg as atomic_xchg from jax._src.pallas.primitives import atomic_xor as atomic_xor from jax._src.pallas.primitives import debug_print as debug_print -from jax._src.pallas.primitives import device_id as device_id from jax._src.pallas.primitives import dot 
as dot from jax._src.pallas.primitives import load as load from jax._src.pallas.primitives import max_contiguous as max_contiguous diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index da054bf18309..21976c47166b 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -52,7 +52,6 @@ # Those primitives got moved to Pallas core. Keeping the updated imports # here for backward compatibility. from jax._src.pallas.core import semaphore as semaphore -from jax._src.pallas.primitives import device_id as device_id from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import semaphore_read as semaphore_read from jax._src.pallas.primitives import semaphore_signal as semaphore_signal diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index f7d7daf1874f..737ab5137e99 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -51,8 +51,7 @@ def test_basic_remote_vmem_dma(self, mem): # Implements very simple collective permute def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): - dev_id = pltpu.device_id() - other_dev_id = 1 - dev_id + other_dev_id = 1 - lax.axis_index('x') pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, device_id_type=pltpu.DeviceIdType.LOGICAL) pltpu.semaphore_wait(ready_sem) From 5e0ccb40d6f39a976aefebbb1c46664547195d2d Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Wed, 2 Apr 2025 22:55:58 +0000 Subject: [PATCH 0345/1769] add option to expose attention residual to user --- jax/_src/cudnn/fused_attention_stablehlo.py | 32 +++++++++++++-------- tests/fused_attention_stablehlo_test.py | 20 +++++++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index c7e7c83f30f8..61c9aea122a2 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -396,7 +396,7 @@ def is_cuda_compute_capability_equal(capability): def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, False, @@ -405,14 +405,16 @@ def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, - sliding_window_length=sliding_window_length, is_training=False) - output = outputs[0] - return output + sliding_window_length=sliding_window_length, is_training=False or return_residual) + if return_residual: + return outputs + else: + return outputs[0] def _dot_product_attention_fwd_rule( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, True, @@ -424,11 +426,14 @@ def _dot_product_attention_fwd_rule( sliding_window_length=sliding_window_length, 
is_training=True) res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, outputs[1], outputs[0]) - return outputs[0], res + if return_residual: + return outputs, res + else: + return outputs[0], res def _dot_product_attention_bwd_rule( scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training, res, grad_output): + sliding_window_length, is_training, return_residual, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, activation, fwd_output) = res grads = _dot_product_attention_bwd_p_wrapper.bind( @@ -1098,7 +1103,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p_wrapper ) -@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) +@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16)) def _dot_product_attention(query: Array, key: Array, value: Array, @@ -1114,13 +1119,14 @@ def _dot_product_attention(query: Array, mask_type: bool, layout: int, sliding_window_length: int | None, - cudnn_version: int): + cudnn_version: int, + return_residual: bool): output = _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, - cudnn_version=cudnn_version) + cudnn_version=cudnn_version, return_residual=return_residual) return output _dot_product_attention.defvjp( @@ -1720,7 +1726,8 @@ def dot_product_attention( dropout_rate: float = 0., qkv_layout: str = "BTNH", sliding_window_length: int | None = None, - use_fp8: bool = False + use_fp8: bool = False, + return_residual: bool = False ): """Computes dot-product attention given query (Q), key (K), and value (V). @@ -1776,6 +1783,7 @@ def dot_product_attention( is the index of each token. E.g., if sliding_window_length == 3 and the sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. use_fp8: Whether to use FP8 attention mechanism. + return_residual: Whether to return softmax stat tensor to users. Returns: Output of the same shape as the query. amax_s: amax of state. 
(fp8 only) @@ -1851,5 +1859,5 @@ def dot_product_attention( output = _dot_product_attention( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, layout.value, - sliding_window_length, cudnn_version) + sliding_window_length, cudnn_version, return_residual) return output diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index af0b18b02f37..4cb7e11735d6 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -737,6 +737,26 @@ def generate_segment_mask(segment_ids, dtype): self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-2, atol=1e-2) self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2) + @jtu.run_on_devices("cuda") + def test_sdpa_residual(self): + k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + + jitted_sdpa_inference = jax.jit( + partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0, return_residual=True), + ) + outs = jitted_sdpa_inference(query, key, value) + assert len(outs) == 2 + @jtu.run_on_devices("cuda") def test_layouts(self): if jax.device_count() < 4: From 5ddec650868df2bee004e062c5664f9b69c762ee Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 3 Apr 2025 00:00:25 +0000 Subject: [PATCH 0346/1769] Remove asserts --- jax/_src/nn/functions.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index cc4a345641dd..d0f5f770e196 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1273,14 +1273,35 @@ def scaled_matmul( >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) >>> scaled_matmul(a, b, a_scales, b_scales) """ - assert all(x.ndim == 3 for x in (a, b, a_scales, b_scales)) + if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)): + raise ValueError( + "scaled_matmul requires all inputs to be 3-dimensional arrays" + ) + B_a, M_a, K_a = a.shape B_b, N_b, K_b = b.shape - assert K_a == K_b and B_a == B_b + if K_a != K_b or B_a != B_b: + raise ValueError( + "scaled_matmul requires inputs a and b to have matching batch (B) " + f"and contract (K) dimensions, but got shapes {a.shape} and " + f"{b.shape}" + ) + B_as, M_as, K_as = a_scales.shape B_bs, N_bs, K_bs = b_scales.shape - assert K_as == K_bs and B_as == B_bs - assert M_as == M_a and N_bs == N_b + if K_as != K_bs or B_as != B_bs: + raise ValueError( + "scaled_matmul requires scales to have matching batch (B) and " + f"contract (K) dimensions, but got shapes {a_scales.shape} and " + f"{b_scales.shape}" + ) + + if M_as != M_a or N_bs != N_b: + raise ValueError( + "scaled_matmul requires scales to match non-contract dimensions of " + f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: " + f"{a_scales.shape}, b_scales: {b_scales.shape}" + ) preferred_element_type = dtypes.canonicalize_dtype( np.dtype(preferred_element_type) From 2540fcde11e3531267b96e9ad40a80749984bace Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 13:13:13 -0700 Subject: [PATCH 0347/1769] add an `out_sharding` option to `jax.random.bits` Drop into `Auto` mode in the implementation. 
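For example (a sketch rather than part of this change: it assumes a 2x2 mesh
with axis names ('x', 'y') is already set as the current mesh, as in the test
added below):

    import jax
    from jax.sharding import PartitionSpec as P

    key = jax.random.key(1)
    bits = jax.random.bits(key, shape=(8, 12), out_sharding=P('x', 'y'))
    # bits carries NamedSharding(mesh, P('x', 'y')); dtype defaults to uint32.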
--- jax/_src/random.py | 26 +++++++++++++++++--------- tests/pjit_test.py | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 0dcbda7bb717..fc571be9493a 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -38,11 +38,11 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.sharding_impls import canonicalize_sharding -from jax._src.pjit import auto_axes from jax._src.lax import lax as lax_internal from jax._src.numpy.lax_numpy import _convert_and_clip_integer from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import canonicalize_sharding from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import canonicalize_axis @@ -348,9 +348,18 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) +def maybe_auto_axes(f, out_shardings, **hoist_kwargs): + f_ = partial(f, **hoist_kwargs) + if out_shardings is None: + return f_ + else: + return auto_axes(f_, out_shardings=out_shardings) + + def bits(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeUInt | None = None) -> Array: + dtype: DTypeLikeUInt | None = None, + out_sharding=None) -> Array: """Sample uniform bits in the form of unsigned integers. Args: @@ -373,8 +382,10 @@ def bits(key: ArrayLike, f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "bits") bit_width = dtype.itemsize * 8 - return _random_bits(key, bit_width, shape) + return maybe_auto_axes(_random_bits, out_sharding, + bit_width=bit_width, shape=shape)(key) def uniform(key: ArrayLike, @@ -711,16 +722,13 @@ def normal(key: ArrayLike, """ key, _ = _check_prng_key("normal", key) shape = core.canonicalize_shape(shape) - out_sharding = canonicalize_sharding(out_sharding, 'normal') + out_sharding = canonicalize_sharding(out_sharding, "normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - if out_sharding is None: - return _normal(key, shape, dtype) - return auto_axes(partial(_normal, shape=shape, dtype=dtype), - out_shardings=out_sharding)(key) + return maybe_auto_axes(_normal, out_sharding, shape=shape, dtype=dtype)(key) @partial(jit, static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 38f191302ea1..ebfdd7fa0b20 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7274,6 +7274,25 @@ def f(key): out = f(key) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_bits(self, mesh): + @jax.jit + def f(key): + out = jax.random.bits(key, shape=(8, 12), out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + 
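+    # Shardy lowers the constraint to an explicit sdy.sharding_constraint op;
+    # the legacy GSPMD path records the same 2x2 layout in an mhlo.sharding
+    # attribute.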
@jtu.with_user_mesh((2, 2), ('x', 'y')) def test_random_uniform(self, mesh): @jax.jit From 2f617631fbceb56b33fb0312b228a49bc3bee608 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 17:31:23 -0700 Subject: [PATCH 0348/1769] use common `maybe_auto_axes` helper in `random.uniform` --- jax/_src/random.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index fc571be9493a..e519c284a567 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -417,11 +417,10 @@ def uniform(key: ArrayLike, raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - if out_sharding is None: + def f(key, minval, maxval, shape, dtype): # reorder args return _uniform(key, shape, dtype, minval, maxval) - def f(k, minv, maxv): - return _uniform(k, shape, dtype, minv, maxv) - return auto_axes(f, out_shardings=out_sharding)(key, minval, maxval) + return maybe_auto_axes(f, out_sharding, shape=shape, dtype=dtype)( + key, minval, maxval) @partial(jit, static_argnums=(1, 2)) def _uniform(key, shape, dtype, minval, maxval) -> Array: From ab816ed8c4d787d5a6760e32b6b34db1fe55e1d1 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 17:38:05 -0700 Subject: [PATCH 0349/1769] add an `out_sharding` option to `jax.random.randint` Drop into `Auto` mode in the implementation. --- jax/_src/random.py | 27 +++++++++++++++------------ tests/pjit_test.py | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index e519c284a567..1d044ec111ff 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -417,13 +417,11 @@ def uniform(key: ArrayLike, raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - def f(key, minval, maxval, shape, dtype): # reorder args - return _uniform(key, shape, dtype, minval, maxval) - return maybe_auto_axes(f, out_sharding, shape=shape, dtype=dtype)( - key, minval, maxval) + return maybe_auto_axes(_uniform, out_sharding, + shape=shape,dtype=dtype)(key, minval, maxval) -@partial(jit, static_argnums=(1, 2)) -def _uniform(key, shape, dtype, minval, maxval) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _uniform(key, minval, maxval, shape, dtype) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): raise TypeError("uniform only accepts floating point dtypes.") @@ -467,7 +465,8 @@ def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt = int, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. 
Args: @@ -487,10 +486,12 @@ def randint(key: ArrayLike, dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) - return _randint(key, shape, minval, maxval, dtype) + out_sharding = canonicalize_sharding(out_sharding, "randint") + return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype)( + key, minval, maxval) -@partial(jit, static_argnums=(1, 4)) -def _randint(key, shape, minval, maxval, dtype) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _randint(key, minval, maxval, shape, dtype) -> Array: _check_shape("randint", shape, np.shape(minval), np.shape(maxval)) if not jnp.issubdtype(dtype, np.integer): raise TypeError(f"randint only accepts integer dtypes, got {dtype}") @@ -1557,7 +1558,8 @@ def gumbel(key: ArrayLike, def _gumbel(key, shape, dtype, mode) -> Array: _check_shape("gumbel", shape) if mode == "high": - high, low = _uniform(key, (2,) + shape, dtype, minval=0., maxval=1.) + high, low = _uniform(key, minval=0., maxval=1., + shape=(2,) + shape, dtype=dtype) # TODO(parkers): The condition is to protect against rounding up but # we should be able to add safely with the right addition operation. x = jnp.where(high >= 0.5, high, @@ -1565,7 +1567,8 @@ def _gumbel(key, shape, dtype, mode) -> Array: return -jnp.log(-jnp.log1p(-x)) else: return -jnp.log(-jnp.log( - _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) + _uniform(key, minval=jnp.finfo(dtype).tiny, maxval=1., + shape=shape, dtype=dtype))) def categorical( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ebfdd7fa0b20..d3d9cab7a5ba 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7312,6 +7312,26 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_randint(self, mesh): + @jax.jit + def f(key): + out = jax.random.randint(key, shape=(8, 12), minval=0, maxval=10, + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_random_normal(self, mesh): @jax.jit From f1adec35641553fd38aceff4266fbd5986c11ded Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 3 Apr 2025 00:26:46 -0700 Subject: [PATCH 0350/1769] [Mosaic GPU] Define the `mosaic_gpu.custom_primitive` dialect op. 
PiperOrigin-RevId: 743441718
---
 jaxlib/mosaic/dialect/gpu/BUILD         |  1 +
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 37 +++++++++++++++++++++
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 26 +++++++++++++++
 tests/mosaic/gpu_dialect_test.py        | 44 +++++++++++++++++++++++++
 4 files changed, 108 insertions(+)

diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD
index f0e399da0575..592d22b699a3 100644
--- a/jaxlib/mosaic/dialect/gpu/BUILD
+++ b/jaxlib/mosaic/dialect/gpu/BUILD
@@ -119,6 +119,7 @@ cc_library(
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:LLVMCommonConversion",
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
index 1b3d08f91fb0..0a36aab6fbcd 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -31,10 +31,12 @@ limitations under the License.
 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -370,6 +372,41 @@ llvm::LogicalResult WGMMAOp::verify() {
   return llvm::success();
 }
 
+llvm::LogicalResult CustomPrimitiveOp::verify() {
+  int num_vector_operands = 0;
+  int num_smem_ref_operands = 0;
+  mlir::Attribute smem = mlir::gpu::AddressSpaceAttr::get(
+      getContext(), mlir::gpu::AddressSpace::Workgroup);
+  for (auto operand : getOperands()) {
+    if (mlir::isa<mlir::VectorType>(operand.getType())) {
+      ++num_vector_operands;
+    }
+
+    if (auto ref_ty = mlir::dyn_cast<mlir::MemRefType>(operand.getType())) {
+      if (ref_ty.getMemorySpace() == smem) {
+        ++num_smem_ref_operands;
+      }
+    }
+  }
+
+  if (num_vector_operands != getInLayouts().size()) {
+    return emitOpError(
+        "Custom primitive must have a layout for each vector operand.");
+  }
+
+  if (num_smem_ref_operands != getInTransforms().size()) {
+    return emitOpError(
+        "Custom primitive must have transforms for each memref operand in "
+        "smem.");
+  }
+
+  if (getResults().size() != getOutLayouts().size()) {
+    return emitOpError("Custom primitive must have a layout for each result.");
+  }
+
+  return llvm::success();
+}
+
 mlir::AffineMap LayoutAttr::getAffineMap() const {
   // This always returns an identity map. It's technically not correct, but we
   // don't actually use it anywhere. It's only called during verification of the
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 85929080faec..0d954716b179 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -484,4 +484,30 @@ def MosaicGPU_OptimizationBarrierOp : Op<
+def MosaicGPU_CustomPrimitiveOp : Op<MosaicGPU_Dialect, "custom_primitive"> {
+  let summary = "Allows defining a custom Mosaic GPU primitive.";
+  let description = [{
+    Allows defining a custom Mosaic GPU primitive.
+
+    Custom primitives should carry input and output layouts for each of their
+    vector operands and outputs, and input transforms for each of their memref
+    operands that live in SMEM.
+
+    Custom primitives can only return vectors.
+  }];
+
+  let arguments = (
+    ins Variadic<AnyType>:$operands,
+    // Attributes
+    ArrayAttr:$in_layouts,
+    ArrayAttr:$in_transforms,
+    ArrayAttr:$out_layouts
+  );
+
+  let results = (outs Variadic<VectorOfAnyRankOf<[AnyType]>>);
+  let regions = (region AnyRegion:$body);
+
+  let hasVerifier = 1;
+}
+
 #endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
index ba9d23fa5b4f..bc94d72dc0d8 100644
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -862,6 +862,50 @@ def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype):
     self.assertLen(conversion_ops, 1)
     self.assertEqual(conversion_ops[0].result.type, scalar_out_ty)
 
+  @parameterized.parameters(
+      (True, False, False),
+      (False, True, False),
+      (False, False, True),
+  )
+  def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_and_results(
+      self, omit_in_layouts, omit_in_transforms, omit_out_layouts
+  ):
+    vec_ty = ir.VectorType.get((4, 32), ir.BF16Type.get())
+    out_layouts = [
+        layouts.to_layout_attr(
+            mgpu.WGStridedFragLayout.from_shaped_type(vec_ty)
+        )
+    ]
+    in_layouts = out_layouts * 2
+    in_transforms = [
+        ir.ArrayAttr.get([mgpu.dialect.SwizzleTransformAttr.get(128)])
+    ]
+
+    in_layouts = [] if omit_in_layouts else in_layouts
+    in_transforms = [] if omit_in_transforms else in_transforms
+    out_layouts = [] if omit_out_layouts else out_layouts
+
+    def body(vec1, vec2, ref):
+      mgpu.dialect.custom_primitive(
+          [vec_ty], [vec1, vec2, ref], in_layouts, in_transforms, out_layouts
+      )
+
+    with ir.InsertionPoint(self.module.body):
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      ref_ty = ir.MemRefType.get((4, 32), ir.BF16Type.get(), memory_space=smem)
+      func.FuncOp.from_py_func(vec_ty, vec_ty, ref_ty)(body)
+
+    if omit_in_layouts:
+      error = "layout for each vector operand"
+    elif omit_in_transforms:
+      error = "transforms for each memref operand in smem"
+    else:
+      assert omit_out_layouts
+      error = "layout for each result"
+
+    with self.assertRaisesRegex(ir.MLIRError, error):
+      self.module.operation.verify()
+

 if __name__ == "__main__":
   parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

From 6243ac80fca6ba718b01facc52c4cde7277838bc Mon Sep 17 00:00:00 2001
From: Michael Hudgins
Date: Thu, 3 Apr 2025 02:44:18 -0700
Subject: [PATCH 0351/1769] [CI] Enable nightly TPU CI tests for v6e.

PiperOrigin-RevId: 743478967
---
 .github/workflows/cloud-tpu-ci-nightly.yml        | 12 ++++++++++--
 .github/workflows/wheel_tests_nightly_release.yml | 11 +++++++++--
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml
index 099f4ad5c520..fd799a3f70b5 100644
--- a/.github/workflows/cloud-tpu-ci-nightly.yml
+++ b/.github/workflows/cloud-tpu-ci-nightly.yml
@@ -26,11 +26,19 @@ jobs:
       matrix:
         jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
         tpu: [
-          # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
           {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
-          {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
+          {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
+          {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}
         ]
         python-version: ["3.10"]
+        # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints.
+ exclude: + - tpu: + type: "v6e-8" + jaxlib-version: "nightly+oldest_supported_libtpu" + - tpu: + type: "v6e-8" + jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20241205 diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index fd4a52d296e0..6fd48d016bd0 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -80,19 +80,26 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8. + # Run a single Python version for v4-8 and v6e-8 - tpu-specs: type: "v4-8" python: "3.10" - tpu-specs: type: "v4-8" python: "3.11" + - tpu-specs: + type: "v6e-8" + python: "3.10" + - tpu-specs: + type: "v6e-8" + python: "3.11" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" From ea196dac12d53d011e01db724156bd3c7f9952f5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 03:34:50 -0700 Subject: [PATCH 0352/1769] [pallas:mosaic_gpu] Slightly reworded the docstrings for a few recently added primitives PiperOrigin-RevId: 743492343 --- jax/_src/pallas/mosaic_gpu/primitives.py | 29 ++++++++++--------- jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/launch_context.py | 9 +++--- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index ff2678454b42..46ec8a87082e 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -21,7 +21,7 @@ import enum import itertools import math -from typing import Any, Literal, Optional +from typing import Any, Literal import jax from jax._src import core as jax_core @@ -135,17 +135,17 @@ def _load_p_lowering_rule( def load( - src: _Ref, idx, *, layout: Optional[Layout | ParameterizedLayout] = None -) -> mgpu.FragmentedArray: - """ Loads a ref (SMEM or GMEM) into a FragmentedArray with the specified layout. + src: _Ref, idx, *, layout: Layout | ParameterizedLayout | None = None +) -> jax.Array: + """Loads from a reference into an array with the specified layout. Args: - src: The reference to copy from. + src: The reference to load from. Can be either in SMEM or GMEM. idx: The index to load from. - layout: The optional layout to use for the returned FragmentedArray. + layout: The optional layout to use for the resulting array. Returns: - A FragmentedArray containing the loaded data in the specified layout. + The loaded array. 
""" src, src_transforms = state_primitives.get_ref_and_transforms( src, idx, "load", force_trailing_indexer=True, @@ -160,6 +160,7 @@ def load( layout=layout ) + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True @@ -185,7 +186,7 @@ def _copy_smem_to_gmem_lowering( dst_transforms_treedef, has_user_predicate, commit_group, - reduction_op: Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] | None, + reduction_op, ): if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] @@ -295,9 +296,7 @@ def copy_smem_to_gmem( predicate: jax.Array | None = None, *, commit_group: bool = True, - reduction_op: Literal[ - "add","min","max","inc","dec","and","or","xor" - ] | None = None, + reduction_op: mgpu.ReductionOp | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -306,10 +305,12 @@ def copy_smem_to_gmem( dst: The GMEM reference to copy to. predicate: A boolean indicating whether the copy should be performed. If ``None``, the copy is always performed. - commit_group: If ``True``, this and any previously uncommitted copies - are committed to a group and can be awaited jointly via + commit_group: If ``True``, this and any previously uncommitted copies are + committed to a group and can be awaited jointly via :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. - reduction_op: if set, perform the specified reduction op when copy to gmem + reduction_op: If set, perform the specified reduction operation when storing + to GMEM. For example, using ``"add"`` is conceptually equivalent to + doing ``src += dst``. See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index afc87b5d96fa..e645115940e4 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -32,6 +32,7 @@ from .launch_context import ( LaunchContext as LaunchContext, MemRefTransform as MemRefTransform, + ReductionOp as ReductionOp, Rounding as Rounding, TileTransform as TileTransform, TransposeTransform as TransposeTransform, diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 41c15bc5492e..aca3fc723882 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -229,6 +229,7 @@ def batch(self, leading_rank: int) -> MemRefTransform: OnDeviceProfiler = profiler.OnDeviceProfiler +ReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] @dataclasses.dataclass() class LaunchContext: @@ -406,10 +407,10 @@ def async_copy( uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, partitioned: int | None = None, - predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. - reduction_op: Literal[ - "add","min","max","inc","dec","and","or","xor" - ] | None = None, + predicate: ( + ir.Value | None + ) = None, # Should select 0 or 1 threads from the WG. + reduction_op: ReductionOp | None = None, ): """Initiates an async copy between GMEM and SMEM. 
From 552eea8ebddccd7f9605f0f62e7ca685621bb0db Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 03:57:38 -0700 Subject: [PATCH 0353/1769] [pallas:mosaic_gpu] `emit_pipeline*` now passes the loop indices into the body This replaces the old behavior where `emit_pipeline*` would replace the current parallel grid with the sequential grid, changing the output of `pl.program_id`. With this change, `pl.program_id` always works wrt the parallel grid. PiperOrigin-RevId: 743498194 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 +-- jax/_src/pallas/mosaic_gpu/pipeline.py | 49 ++++++++++--------- .../pallas/ops/gpu/attention_mgpu.py | 2 +- tests/pallas/mosaic_gpu_test.py | 26 +++++----- 4 files changed, 43 insertions(+), 42 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f027d5bcb76d..aafab927d4c2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -629,13 +629,9 @@ def scoped_pipeline_fn(*refs, sem_refs, scratch_refs): scratch_refs = [ next(sem_refs_it) if r is sem_placeholder else r for r in scratch_refs ] - def body_fn(*refs): - grid_env = pallas_core.current_grid_env() - assert grid_env is not None # Set by ``emit_pipeline``. + def body_fn(indices, *refs): program_ids_template = util.merge_lists( - which_parallel, - [grid_axis.index for grid_axis in grid_env], - [None] * sum(which_parallel), + which_parallel, indices, [None] * sum(which_parallel) ) assert len(refs) + len(scratch_refs) == len(jaxpr.invars) return gpu_primitives.jaxpr_call( diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index d85ba4ae2a03..ec088f43c4b2 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -33,7 +33,6 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives -from jax._src.util import foreach from jax.experimental import pallas as pl import jax.numpy as jnp @@ -171,7 +170,8 @@ def emit_pipeline( """Creates a function to emit a manual pipeline within a Pallas kernel. Args: - body: The pipeline body. + body: The pipeline body, called with the indices for the current step, the + input refs, followed by the output refs. grid: The grid to use for the pipeline. in_specs: The block specs for the inputs. out_specs: The block specs for the outputs. @@ -248,7 +248,8 @@ def scoped_pipeline( it.islice(it.product(*map(range, grid)), max_concurrent_steps) ): indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) - foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + for bref in in_brefs: + bref.copy_in(step, indices, barrier_ref) # This is true if any of the outputs need to be transferred inside the loop. 
copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) @@ -266,11 +267,13 @@ def loop_body(step, carry): max_concurrent_steps - (1 + delay_release), wait_read_only=True ) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body(*( - bref.get_ref_for_slot(slot) - for bref in it.chain(in_brefs, out_brefs) - )) + body( + indices, + *( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + ), + ) if copies_out_in_loop: gpu_primitives.commit_smem() @@ -355,6 +358,7 @@ def do_fetch(): return pipeline + def emit_pipeline_warp_specialized( body: Callable[..., None], *, @@ -376,14 +380,16 @@ def emit_pipeline_warp_specialized( ``manual_consumed_barriers`` argument is True. ``` - def body(*input_refs, *output_refs, [consumed_barriers]) -> None: + def body(indices, *input_refs, *output_refs, [consumed_barriers]) -> None: ``` or with a carries enabled (enabled via the ``carry_coroutine`` argument), where the body returns the next carry: ``` - def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: + def body( + indices, *input_refs, *output_refs, [consumed_barriers], carry + ) -> Carry: ``` Args: @@ -545,18 +551,17 @@ def compute_loop_body(step, carry): if copies_out_in_loop: gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body_refs = [] - for bref in it.chain(in_brefs, out_brefs): - buf_slot = _get_slot(slot, ~bref.is_index_invariant) - body_refs.append(bref.get_ref_for_slot(buf_slot)) - - body_args = body_refs - if manual_consumed_barriers: - body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] - if has_carry: - body_args += [prev_body_carry] - next_body_carry = body(*body_args) + body_refs = [] + for bref in it.chain(in_brefs, out_brefs): + buf_slot = _get_slot(slot, ~bref.is_index_invariant) + body_refs.append(bref.get_ref_for_slot(buf_slot)) + + body_args = body_refs + if manual_consumed_barriers: + body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] + if has_carry: + body_args += [prev_body_carry] + next_body_carry = body(indices, *body_args) if not manual_consumed_barriers: [consumed_barrier_ref] = consumed_barrier_refs diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index b19e371a1eb8..d06d3b39cb7a 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -310,7 +310,7 @@ def _compute_thread(): ) plgpu.wait_smem_to_gmem(0) - def kv_pipeline(k_smem, v_smem, + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): acc, m_i, l_i = carry diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index acf82ce23eba..06dfd453fb19 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1923,7 +1923,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): # +1 for the indexing done by ``emit_pipeline`. self.assertLen(x_smem.transforms, len(transforms) + 1) o_smem[...] = x_smem[...] 
+ 1.0 @@ -1949,7 +1949,7 @@ def kernel(x_gmem, o_gmem): grid=(), )(x_gmem, o_gmem) - def nested_kernel(x_gmem, o_gmem): + def nested_kernel(_, x_gmem, o_gmem): plgpu.emit_pipeline( nested_kernel_body, in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], @@ -1958,7 +1958,7 @@ def nested_kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def nested_kernel_body(x_smem, o_smem): + def nested_kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps * 16) @@ -1983,7 +1983,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps * 16) @@ -2016,7 +2016,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) @@ -2044,7 +2044,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) @@ -2086,7 +2086,7 @@ def test_realistic_matmul(self): ) def kernel(a_gmem, b_gmem, o_smem, acc): - def kernel_body(a_smem, b_smem): + def kernel_body(_, a_smem, b_smem): assert a_smem.shape == (tile_m, tile_k) assert b_smem.shape == (tile_k, tile_n) plgpu.wgmma(acc, a_smem, b_smem) @@ -2147,7 +2147,7 @@ def test_pipelined_copy(self, m, n, manual_consumed_barriers): x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) blk_m = blk_n = 64 - def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): + def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. @@ -2201,7 +2201,7 @@ def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) ) - def tiled_add_kernel(x_smem, y_smem, o_smem): + def tiled_add_kernel(_, x_smem, y_smem, o_smem): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. @@ -2265,7 +2265,7 @@ def _compute_thread(): plgpu.copy_smem_to_gmem(acc_smem, acc_gmem) plgpu.wait_smem_to_gmem(0) - def tiled_acc_kernel(x_smem, carry): + def tiled_acc_kernel(_, x_smem, carry): o_carry, = carry new_carry = x_smem[...] + o_carry return (new_carry,) @@ -2620,7 +2620,7 @@ def test_stage4(self): self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) ) def kernel(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") block = pl.BlockSpec((row_block, col_block), lambda c: (r, c)) @@ -2644,7 +2644,7 @@ def test_stage5(self): self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) ) def kernel(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] 
      r = lax.axis_index("rows")
       block = plgpu.GPUBlockSpec(
@@ -2707,7 +2707,7 @@ def test_stage6(self):
         self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n")
     )
     def kernel(l_ref, r_ref, o_ref):
-      def compute(l_smem, r_smem, o_smem):
+      def compute(_, l_smem, r_smem, o_smem):
         def do_wgmma(acc_ref):
           plgpu.wgmma(acc_ref, l_smem, r_smem)
           return acc_ref[...]

From 0ec1251d9ec6f8995ff50e35cecebc5a11afc71c Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Thu, 3 Apr 2025 05:06:06 -0700
Subject: [PATCH 0354/1769] [Mosaic GPU] Get rid of `LayoutAttr` and related comments.

This is no longer used, since we elected to refine the IR by annotating it
with `{in,out}_transforms` in the lowering pipeline instead.

PiperOrigin-RevId: 743516621
---
 jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc      | 21 ---------
 .../dialect/gpu/integrations/c/attributes.cc  | 34 --------------
 .../dialect/gpu/integrations/c/attributes.h   | 16 -------
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc       |  8 ----
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.td       | 45 +++----------------
 5 files changed, 6 insertions(+), 118 deletions(-)

diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
index c73084abc99d..2751719fc61d 100644
--- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
+++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
@@ -138,25 +138,4 @@ NB_MODULE(_mosaic_gpu_ext, m) {
       .def_property_readonly("swizzle", [](MlirAttribute self) {
         return mlirMosaicGpuSwizzleTransformAttrGetSwizzle(self);
       });
-
-  mlir::python::nanobind_adaptors::mlir_attribute_subclass(
-      m, "LayoutAttr", mlirMosaicGpuIsALayoutAttr)
-      .def_classmethod(
-          "get",
-          [](nb::object cls, int32_t num_dimensions,
-             std::vector<MlirAttribute>& transforms, MlirContext ctx) {
-            return cls(mlirMosaicGpuLayoutAttrGet(
-                ctx, num_dimensions, transforms.data(), transforms.size()));
-          },
-          nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"),
-          nb::arg("context").none() = nb::none(),
-          "Creates a LayoutAttr with the given transforms.")
-      .def_property_readonly("transforms", [](MlirAttribute self) {
-        std::vector<MlirAttribute> result;
-        for (int i = 0; i < mlirMosaicGpuLayoutAttrGetTransformsSize(self);
-             ++i) {
-          result.push_back(mlirMosaicGpuLayoutAttrGetTransform(self, i));
-        }
-        return result;
-      });
 }
diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc
index 259c37fe5d07..523b14e425c9 100644
--- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc
+++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include -#include #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" @@ -97,36 +96,3 @@ int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) { .getSwizzle() .getValue()); } - -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr) { - return mlir::isa(unwrap(attr)); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGet(MlirContext ctx, - int32_t num_dimensions, - MlirAttribute* transforms, - int32_t transforms_size) { - std::vector unwrapped_transforms; - unwrapped_transforms.reserve(transforms_size); - for (int i = 0; i < transforms_size; ++i) { - unwrapped_transforms.push_back(unwrap(transforms[i])); - } - return wrap(mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions, - unwrapped_transforms)); -} - -int32_t mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) { - return mlir::cast(unwrap(attr)) - .getTransforms() - .size(); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, - int32_t index) { - return wrap( - mlir::cast(unwrap(attr)).getTransforms()[index]); -} \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h index 3b8425b6b142..3221b9220e5d 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h @@ -69,22 +69,6 @@ mlirMosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle); MLIR_CAPI_EXPORTED int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr); -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -MLIR_CAPI_EXPORTED bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions, - MlirAttribute* transforms, int32_t transforms_size); - -MLIR_CAPI_EXPORTED int32_t -mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index); - #ifdef __cplusplus } #endif diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 0a36aab6fbcd..073697df58ef 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -407,14 +407,6 @@ llvm::LogicalResult CustomPrimitiveOp::verify() { return llvm::success(); } -mlir::AffineMap LayoutAttr::getAffineMap() const { - // This always returns an identity map. It's technically not correct, but we - // don't actually use it anywhere. It's only called during verification of the - // layout attribute and needs to be semi-valid. 
-  return mlir::AffineMap::getMultiDimIdentityMap(getNumDimensions(),
-                                                 getContext());
-}
-
 void MosaicGPUDialect::initialize() {
   addTypes<
 #define GET_TYPEDEF_LIST
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 0d954716b179..cda521855250 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -225,27 +225,6 @@ def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> {
   let assemblyFormat = "`<` $swizzle `>`";
 }
 
-def LayoutAttr : MosaicGPU_Attr<"Layout", "layout",
-    [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
-  let parameters = (ins
-    TypeParameter<"int32_t", "number of dimensions">:$num_dimensions,
-    ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms
-  );
-
-  let summary = "Specifies a layout of a memref in SMEM.";
-  let description = [{
-    This layout attribute is used to specify the layout of a memref in SMEM.
-    It is composed of a number of transforms, which are applied in the order
-    they are provided. The transforms can be any combination of:
-      - TileTransformAttr
-      - TransposeTransformAttr
-      - SwizzleTransformAttr
-
-    The num_dimensions parameter must match the rank of the memref shape.
-  }];
-  let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`";
-}
-
 def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load"> {
   let summary = "Schedules an async load of a MemRef from GMEM to SMEM";
@@ -265,16 +244,9 @@ def MosaicGPU_AsyncLoadOp : Op<

From: jax authors
Date: Thu, 3 Apr 2025 06:18:14 -0700
Subject: [PATCH 0355/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/921c164a67e8ac4cf052aab26e849f29b719f802.

PiperOrigin-RevId: 743535272
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 90a19ac95e51..c30648a2b3a1 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "c3087e022f3c07f7ed1dd4e47024c437a504341b"
-XLA_SHA256 = "66457303ddec4dbbe43accf38a8b6b635d55808938cf2495443b09ee9c95a147"
+XLA_COMMIT = "921c164a67e8ac4cf052aab26e849f29b719f802"
+XLA_SHA256 = "9e734da4a0211ac09a00cc07969645e31f107cfee19bbc5d2d1e21ddbb19090d"
 
 def repo():
     tf_http_archive(

From 8d59902e735dbf17dcc7c70bb4c76f858eb93dde Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 3 Apr 2025 15:00:29 +0100
Subject: [PATCH 0356/1769] Fix problem finding clang++ when building JAX via
 build.py on windows.

It's important we use the un-stemmed name because on Windows there is an
.exe suffix.
---
 build/tools/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/build/tools/utils.py b/build/tools/utils.py
index ccce8aff09cc..c52b89a1e6d2 100644
--- a/build/tools/utils.py
+++ b/build/tools/utils.py
@@ -204,7 +204,7 @@ def get_clang_major_version(clang_path):
 
 def get_clangpp_path(clang_path):
   clang_path = pathlib.Path(clang_path)
-  clang_exec_name = clang_path.stem
+  clang_exec_name = clang_path.name
   clangpp_exec_name = clang_exec_name
   if "clang++" not in clang_exec_name:
     clangpp_exec_name = clang_exec_name.replace("clang", "clang++")

From 91b0884ad131ebddd69951927533b3ab12ec4113 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 3 Apr 2025 08:05:22 -0700
Subject: [PATCH 0357/1769] Restrict the regex for copying the wheels.
The change is made to address the case when bazel dir has multiple wheels with different version suffixes. We need to copy only those wheels that were created by the current execution of build.py script. PiperOrigin-RevId: 743566122 --- build/build.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/build/build.py b/build/build.py index 226d984b3d89..87aa36aeba8b 100755 --- a/build/build.py +++ b/build/build.py @@ -389,6 +389,11 @@ async def main(): arch = platform.machine() os_name = platform.system().lower() + custom_wheel_version_suffix = "" + wheel_build_date = "" + wheel_git_hash = "" + wheel_type = "snapshot" + args = parser.parse_args() logger.info("%s", BANNER) @@ -621,6 +626,17 @@ async def main(): ) for option in args.bazel_options: wheel_build_command_base.append(option) + + # Parse the build options for the wheel version suffix. + if "ML_WHEEL_TYPE" in option: + wheel_type = option.split("=")[-1] + if "ML_WHEEL_VERSION_SUFFIX" in option: + custom_wheel_version_suffix = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_BUILD_DATE" in option: + wheel_build_date = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_GIT_HASH" in option: + wheel_git_hash = option.split("=")[-1][:9] + if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda_libraries_from_stubs") @@ -729,10 +745,29 @@ async def main(): dst_dir = os.path.join(output_path, wheel_dir) utils.copy_dir_recursively(src_dir, dst_dir) else: - utils.copy_individual_files(bazel_dir, output_path, f"{wheel_dir}*.whl") + wheel_version_suffix = "dev0+selfbuilt" + if wheel_type == "release": + wheel_version_suffix = custom_wheel_version_suffix + elif wheel_type in ["nightly", "custom"]: + wheel_version_suffix = f".dev{wheel_build_date}" + if wheel_type == "custom": + wheel_version_suffix += ( + f"+{wheel_git_hash}{custom_wheel_version_suffix}" + ) + if wheel in ["jax", "jax-cuda-pjrt"]: + python_tag = "py" + else: + python_tag = "cp" + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl", + ) if wheel == "jax": utils.copy_individual_files( - bazel_dir, output_path, f"{wheel_dir}*.tar.gz" + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}.tar.gz", ) # Exit with success if all wheels in the list were built successfully. From 1941714d261daffd3f164d87a3bf8dd89d996211 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 3 Apr 2025 10:25:02 +0100 Subject: [PATCH 0358/1769] [export] Add support for override_lowering_rules to jax.export. This parameter is already part of the internal API for the AOT lowering function, here we just expose it to `jax.export`. --- jax/_src/config.py | 2 +- jax/_src/export/_export.py | 15 +++++++++++++-- jax/_src/interpreters/mlir.py | 2 +- tests/export_test.py | 13 ++++++++++++- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 5b8b87be2095..b4a12dcc1762 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -998,7 +998,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' - 'tracing cache), log an explanation.. Logging is performed with ' + 'tracing cache), log an explanation. Logging is performed with ' '`logging`. 
When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 9b6a0f80930f..90cc0c186ad1 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -529,6 +529,7 @@ def export( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), + _override_lowering_rules: Sequence[tuple[Any, Any]] | None = None ) -> Callable[..., Exported]: """Exports a JAX function for persistent serialization. @@ -541,6 +542,13 @@ def export( If None, then use the default JAX backend. The calling convention for multiple platforms is explained at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + _override_lowering_rules: an optional sequence of custom lowering rules + for some JAX primitives. Each element of the sequence is a pair + of a JAX primitive and a lowering function. Defining lowering rules + is an advanced feature using JAX internal APIs, which are subject + to change. Furthermore, the responsibility for the stability of the + MLIR emitted through these custom lowering rules, rests with the user + of these rules. disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. @@ -568,7 +576,8 @@ def export( Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) """ return _export_internal(fun_jit, platforms=platforms, - disabled_checks=disabled_checks) + disabled_checks=disabled_checks, + override_lowering_rules=_override_lowering_rules) # TODO(necula): remove this once we improve the integration with jax2tf. @@ -577,7 +586,8 @@ def _export_internal( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, + _device_assignment_for_internal_jax2tf_use_only=None, + override_lowering_rules=None, ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. @@ -604,6 +614,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: lowered = traced.lower( lowering_platforms=actual_lowering_platforms, _private_parameters=mlir.LoweringParameters( + override_lowering_rules=override_lowering_rules, for_export=True, export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) return _export_lowered( diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a1b37876f87e..a112063ce3ae 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -663,7 +663,7 @@ def __init__(self, @dataclasses.dataclass(frozen=True) class LoweringParameters: # A mapping between primitives and user-defined LoweringRules. - # When lowering a primitive, give priorioty to the rule in this map over + # When lowering a primitive, give priority to the rule in this map over # existing Jax rules. 
override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None diff --git a/tests/export_test.py b/tests/export_test.py index 0b78a29a8e6a..2264fbdd997b 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -19,7 +19,6 @@ import dataclasses import functools import logging -import json import math import re import unittest @@ -281,6 +280,18 @@ def test_unused_args(self): self.assertAllClose(f(x, y), exp_f.call(x, y)) + def test_override_lowering_rules(self): + @jax.jit + def f(x): + return jnp.sin(x) + + def my_lowering_rule(ctx, arg, **_): + return mlir.hlo.CosineOp(arg).results + + exp = get_exported(f, _override_lowering_rules=( + (lax.sin_p, my_lowering_rule),))(42.) + self.assertIn("stablehlo.cosine", exp.mlir_module()) + def test_pytree(self): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) From f2f9152d573fb6f09ce2a500d5602b4aea14075b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 08:27:10 -0700 Subject: [PATCH 0359/1769] Moved the `jax.Array` baseclass to C++ This allows `ArrayImpl` to directly subclass `jax.Array` without relying on the expensive virtual subclasses machinery from `abc`. PiperOrigin-RevId: 743573028 --- jax/BUILD | 1 + jax/_src/array.py | 6 +-- jax/_src/basearray.py | 49 ++++++++++++------- jax/_src/basearray.pyi | 8 +-- jaxlib/xla/py_array.cc | 70 ++++++++++++++++++++++++++- jaxlib/xla/xla_client.py | 3 +- jaxlib/xla/xla_extension/__init__.pyi | 1 + 7 files changed, 111 insertions(+), 27 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 5d37a8987445..f5745df0e5bf 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -400,6 +400,7 @@ pytype_strict_library( deps = [ ":partition_spec", ":sharding", + ":util", "//jax/_src/lib", ] + py_deps("numpy"), ) diff --git a/jax/_src/array.py b/jax/_src/array.py index ee196026887d..760593da9fa9 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -39,6 +39,7 @@ from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe +from jax._src.lib import jaxlib_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, @@ -1093,9 +1094,8 @@ def _get_aval_array(self): return core.update_aval_with_sharding(self.aval, self.sharding) core.pytype_aval_mappings[ArrayImpl] = _get_aval_array -# TODO(jakevdp) replace this with true inheritance at the C++ level. -basearray.Array.register(ArrayImpl) - +if jaxlib_extension_version < 325: + basearray.Array.register(ArrayImpl) def _array_mlir_constant_handler(val): try: diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index fbd14d157e78..3aabba4440ec 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -17,10 +17,15 @@ from __future__ import annotations import abc +from collections.abc import Sequence import sys -import numpy as np from typing import Any, Union -from collections.abc import Sequence + +from jax._src.lib import jaxlib_extension_version +from jax._src.lib import xla_client as xc +from jax._src.util import use_cpp_class +import numpy as np + # TODO(jakevdp): fix import cycles and define these. Device = Any @@ -30,7 +35,9 @@ # Array is a type annotation for standard JAX arrays and tracers produced by # core functions in jax.lax and jax.numpy; it is not meant to include # future non-standard array types like KeyArray and BInt. 
-class Array(abc.ABC): + + +class Array: """Array base class for JAX ``jax.Array`` is the public interface for instance checks and type annotation @@ -48,8 +55,6 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace :func:`jax.numpy.array`, :func:`jax.numpy.zeros`, :func:`jax.numpy.ones`, :func:`jax.numpy.full`, :func:`jax.numpy.arange`, etc. """ - # Note: abstract methods for this class are defined dynamically in - # lax_numpy.py # For the sake of static type analysis, these definitions are mirrored in the # associated basearray.pyi file. @@ -57,42 +62,41 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace __hash__ = None @property - @abc.abstractmethod def dtype(self) -> np.dtype: """The data type (:class:`numpy.dtype`) of the array.""" + raise NotImplementedError @property - @abc.abstractmethod def ndim(self) -> int: """The number of dimensions in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def size(self) -> int: """The total number of elements in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def shape(self) -> tuple[int, ...]: """The shape of the array.""" + raise NotImplementedError # Documentation for sharding-related methods and properties defined on ArrayImpl: - @abc.abstractmethod def addressable_data(self, index: int) -> Array: """Return an array of the addressable data at a particular index.""" + raise NotImplementedError @property - @abc.abstractmethod def addressable_shards(self) -> Sequence[Shard]: """List of addressable shards.""" + raise NotImplementedError @property - @abc.abstractmethod def global_shards(self) -> Sequence[Shard]: """List of global shards.""" + raise NotImplementedError @property - @abc.abstractmethod def is_fully_addressable(self) -> bool: """Is this Array fully addressable? @@ -104,19 +108,19 @@ def is_fully_addressable(self) -> bool: a jax.Array which is fully replicated can span across multiple hosts and is not fully addressable. """ + raise NotImplementedError @property - @abc.abstractmethod def is_fully_replicated(self) -> bool: """Is this Array fully replicated?""" + raise NotImplementedError @property - @abc.abstractmethod def sharding(self) -> Sharding: """The sharding for the array.""" + raise NotImplementedError @property - @abc.abstractmethod def committed(self) -> bool: """Whether the array is committed or not. @@ -141,17 +145,17 @@ def committed(self) -> bool: See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices for more information. """ + raise NotImplementedError @property - @abc.abstractmethod def device(self) -> Device | Sharding: """Array API-compatible device attribute. For single-device arrays, this returns a Device. For sharded arrays, this returns a Sharding. """ + raise NotImplementedError - @abc.abstractmethod def copy_to_host_async(self): """Copies an ``Array`` to the host asynchronously. @@ -166,10 +170,19 @@ def copy_to_host_async(self): array, but does not wait for the copy to complete. This may speed up a future on-host access to the array's contents. """ + raise NotImplementedError + + +if jaxlib_extension_version >= 325: + Array = use_cpp_class(xc.Array)(Array) +else: + class Array(Array, metaclass=abc.ABCMeta): + ... Array.__module__ = "jax" + # StaticScalar is the Union of all scalar types that can be converted to # JAX arrays, and are possible to mark as static arguments. 
StaticScalar = Union[
diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi
index a368b593332d..8bf68f622051 100644
--- a/jax/_src/basearray.pyi
+++ b/jax/_src/basearray.pyi
@@ -14,11 +14,12 @@
 import abc
 from collections.abc import Callable, Sequence
 from types import ModuleType
-from typing import Any, Protocol, Union, runtime_checkable
+from typing import Any, Protocol, runtime_checkable, Union
 
 import numpy as np
 
-from jax._src.sharding import Sharding
 from jax._src.partition_spec import PartitionSpec
+from jax._src.sharding import Sharding
+
 
 # TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py.
 # We redefine these here to prevent circular imports.
@@ -39,7 +40,8 @@
 Traceback = Any
 PrecisionLike = Any
 
-class Array(abc.ABC):
+# TODO(slebedev): Remove the metaclass once ``jax_extension_version >= 325``.
+class Array(metaclass=abc.ABCMeta):
   aval: Any
 
   @property
diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc
index a1937bc80327..ce5ceacbad99 100644
--- a/jaxlib/xla/py_array.cc
+++ b/jaxlib/xla/py_array.cc
@@ -237,12 +237,33 @@ tsl::RCReference<ifrt::Array> CreateIfRtArrayFromSingleDeviceShardedPyArrays(
   return *std::move(ifrt_array);
 }
 
+struct PyBaseArrayObject {
+  PyObject_HEAD;
+#if PY_VERSION_HEX < 0x030C0000
+  PyObject* weakrefs;
+#endif  // PY_VERSION_HEX < 0x030C0000
+};
+
+extern "C" void PyBaseArray_tp_dealloc(PyBaseArrayObject* self) {
+  PyObject_GC_UnTrack(self);
+  PyObject_ClearWeakRefs((PyObject*)self);
+  PyTypeObject* tp = Py_TYPE(self);
+  tp->tp_free((PyObject*)self);
+  Py_DECREF(tp);
+}
+
+extern "C" int PyBaseArray_tp_traverse(PyObject* self, visitproc visit,
+                                       void* arg) {
+  Py_VISIT(Py_TYPE(self));
+  return 0;
+}
+
 struct PyArrayObject {
   PyObject_HEAD;
 #if PY_VERSION_HEX < 0x030C0000
   PyObject* weakrefs;
   PyObject* dict;
-#endif  // PY_VERSION_HEX < 0x030B0000
+#endif  // PY_VERSION_HEX < 0x030C0000
   bool initialized;
   alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)];
 };
@@ -1879,6 +1900,23 @@ absl::Status PyHostValue::CopyToHostAsync(
 }
 
 namespace {
+PyMemberDef PyBaseArray_members[] = {
+#if PY_VERSION_HEX < 0x030C0000
+    {"__weaklistoffset__", T_PYSSIZET,
+     static_cast<Py_ssize_t>(offsetof(PyBaseArrayObject, weakrefs)), READONLY,
+     nullptr},
+#endif  // PY_VERSION_HEX < 0x030C0000
+    {nullptr, 0, 0, 0, nullptr},
+};
+
+PyType_Slot PyBaseArray_slots[] = {
+    {Py_tp_dealloc, reinterpret_cast<void*>(PyBaseArray_tp_dealloc)},
+    {Py_tp_members, reinterpret_cast<void*>(PyBaseArray_members)},
+    {Py_tp_traverse, reinterpret_cast<void*>(PyBaseArray_tp_traverse)},
+    {Py_tp_hash, reinterpret_cast<void*>(PyObject_HashNotImplemented)},
+    {0, nullptr},
+};
+
 PyGetSetDef PyArray_tp_getset[] = {
     {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr,
      nullptr},
@@ -1911,6 +1949,34 @@ PyType_Slot PyArray_slots[] = {
 }  // namespace
 
 absl::Status PyArray::RegisterTypes(nb::module_& m) {
+  // We are not using nanobind to avoid having a non-standard metaclass, which
+  // would make Array incompatible with abc.ABCMeta.
+  std::string base_name =
+      absl::StrCat(nb::cast<std::string>(m.attr("__name__")), ".Array");
+  PyType_Spec PyBaseArray_spec = {
+#if PY_VERSION_HEX < 0x030B0000
+      // Work around for https://github.com/python/cpython/issues/89478
+      // CPython 3.10 and earlier assume that the .name value remains alive
+      // forever.
+      /*.name=*/strdup(base_name.c_str()),
+#else
+      /*.name=*/base_name.c_str(),
+#endif  // PY_VERSION_HEX < 0x030B0000
+      /*.basicsize=*/static_cast<int>(sizeof(PyBaseArrayObject)),
+      /*.itemsize=*/0,
+#if PY_VERSION_HEX < 0x030C0000
+      /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
+#else   // PY_VERSION_HEX >= 0x030C0000
+      /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC |
+          Py_TPFLAGS_MANAGED_WEAKREF,
+#endif  // PY_VERSION_HEX >= 0x030C0000
+      /*.slots=*/PyBaseArray_slots};
+  auto* base_type = PyType_FromSpec(&PyBaseArray_spec);
+  if (!base_type) {
+    throw nb::python_error();
+  }
+  m.attr("Array") = nb::borrow<nb::object>(base_type);
+
   std::string name =
       absl::StrCat(nb::cast<std::string>(m.attr("__name__")), ".ArrayImpl");
 
@@ -1934,7 +2000,7 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) {
       /*.slots=*/PyArray_slots,
   };
 
-  type_ = PyType_FromSpec(&PyArray_spec);
+  type_ = PyType_FromSpecWithBases(&PyArray_spec, base_type);
   if (!type_) {
     throw nb::python_error();
   }
diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py
index af751a00ab25..523f8bb57b90 100644
--- a/jaxlib/xla/xla_client.py
+++ b/jaxlib/xla/xla_client.py
@@ -50,7 +50,7 @@
 
 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
-_version = 324
+_version = 325
 
 # An internal increasing version number for protecting jaxlib code against
 # ifrt changes.
@@ -486,6 +486,7 @@ def window_padding_type_to_pad_values(
 FftType = _xla.FftType
 Client = _xla.Client
 Memory = _xla.Memory
+Array = _xla.Array
 ArrayImpl = _xla.ArrayImpl
 LoadedExecutable = _xla.LoadedExecutable
 DeviceList = _xla.DeviceList
diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi
index d002080b17bc..2d759236b8c5 100644
--- a/jaxlib/xla/xla_extension/__init__.pyi
+++ b/jaxlib/xla/xla_extension/__init__.pyi
@@ -656,6 +656,7 @@ def pjrt_plugin_loaded(plugin_name: str) -> bool: ...
 def pjrt_plugin_initialized(plugin_name: str) -> bool: ...
 def initialize_pjrt_plugin(platform_name: str) -> _Status: ...
 
+Array = Any
 ArrayImpl = Any
 
 # TODO(phawkins): this type is problematic because it is not a subtype of

From e7a5147638ba7ef2ede25ee7e2b7b29ea355a495 Mon Sep 17 00:00:00 2001
From: Vladimir Belitskiy
Date: Thu, 3 Apr 2025 08:43:29 -0700
Subject: [PATCH 0360/1769] Bump up tolerance in
 ShardMapSystematicTest.test_vmap_closure for GPUs.

There's a mismatch between the resulting and the desired matrices on H100,
but not the older GPUs.
PiperOrigin-RevId: 743578025 --- tests/shard_map_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 36966fde2a90..520cc02638df 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3091,7 +3091,7 @@ def g(*args): else: slices = map(jnp.stack, zip(*expected_slices)) expected = jax.tree.unflatten(treedef, slices) - tol = 1e-2 if jtu.test_device_matches(['tpu']) else None + tol = 1e-2 if jtu.test_device_matches(['gpu', 'tpu']) else None self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol) @jtu.pytest_mark_if_available('multiaccelerator') From d1009a3bcda3aeb4667f9822a2379c1bf7718b56 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Thu, 3 Apr 2025 16:35:11 +0000 Subject: [PATCH 0361/1769] Only trigger K8s CI on changes to cluster config and distributed initialize --- .github/workflows/k8s.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index 31ee05a03482..a96ce1ead26c 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -4,9 +4,17 @@ on: push: branches: - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' pull_request: branches: - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' permissions: contents: read From 42735d04f1249026ce2fd223e20a02d78c18ff7b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 3 Apr 2025 09:44:12 -0700 Subject: [PATCH 0362/1769] Not to use dynamic grid in the ragged paged attention Pallas kernel. We found a hanging issue when we use dynamic grid. We'll disable it for now. PiperOrigin-RevId: 743597352 --- jax/experimental/pallas/ops/tpu/ragged_paged_attention.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index e1eacee550a7..d775e1331bcb 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -315,7 +315,9 @@ def prefetch_first_kv_blk(): def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states - return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], + cur_seq_idx < num_seqs) + return jnp.logical_and(done == 0, should_run) def compute_with_cur_q_blk(q_states): done, cur_seq_idx, cur_buf_idx = q_states @@ -680,14 +682,14 @@ def ragged_paged_attention( validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap) if mask_value is None: mask_value = DEFAULT_MASK_VALUE - _, num_q_heads, head_dim = q.shape + num_q_tokens, num_q_heads, head_dim = q.shape _, page_size, num_combined_kv_heads, _ = kv_pages.shape assert num_combined_kv_heads % 2 == 0 num_kv_heads = num_combined_kv_heads // 2 num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = cdiv(cu_q_lens[num_seqs[0]], num_q_per_blk) + num_q_blks = cdiv(num_q_tokens, num_q_per_blk) num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype ) From 780c8827f296fa692cb08bb9a1abd198e8cf8efe Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 3 Apr 2025 11:19:11 -0700 Subject: [PATCH 0363/1769] [Mosaic GPU] 
Fix index_invariant slot in warp-specialized pipeline. PiperOrigin-RevId: 743633331 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 3 ++- tests/pallas/mosaic_gpu_test.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index ec088f43c4b2..257ecbf5da4a 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -629,7 +629,8 @@ def compute_loop_body(step, carry): last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: - bref.copy_out(last_slot, last_indices, predicate=None) + bref.copy_out(_get_slot(last_slot, has_seq_dim=False), + last_indices, predicate=None) gpu_primitives.commit_smem_to_gmem_group() diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 06dfd453fb19..8fd98f62eab1 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2142,6 +2142,7 @@ class WarpSpecializedPipelineTest(PallasTest): @parameterized.product(m=[512], n=[512], manual_consumed_barriers=[False, True]) def test_pipelined_copy(self, m, n, manual_consumed_barriers): + self.skipTest("TODO(justinfu): Temporary skip for 3.12 update.") self.skip_if_wg_semantics() # Times out! x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) From c2eb9c1d9eff42eb05cac697759bcf8a5aeaf805 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 3 Apr 2025 13:01:35 -0700 Subject: [PATCH 0364/1769] Eliminate DeprecationWarning in python3.12+ in jax pallas for ~. The code was using ~ with a boolean, which leads to a new DeprecationWarning. That should only be used with ints. PiperOrigin-RevId: 743668386 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 257ecbf5da4a..21efbbec6630 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -429,9 +429,9 @@ def body( # Trace the index maps to determine if they depend on the grid. # Grid-independent values will not be multiple-buffered. 
in_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in in_specs] + not _is_index_invariant(spec, grid) for spec in in_specs] out_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in out_specs] + not _is_index_invariant(spec, grid) for spec in out_specs] spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis] num_pipeline_steps = math.prod(grid) @@ -516,13 +516,13 @@ def scoped_pipeline( consumed_barrier_refs, ): in_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs ) ] out_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs ) @@ -553,7 +553,7 @@ def compute_loop_body(step, carry): body_refs = [] for bref in it.chain(in_brefs, out_brefs): - buf_slot = _get_slot(slot, ~bref.is_index_invariant) + buf_slot = _get_slot(slot, not bref.is_index_invariant) body_refs.append(bref.get_ref_for_slot(buf_slot)) body_args = body_refs @@ -586,7 +586,7 @@ def compute_loop_body(step, carry): new_store_slices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) - bref.copy_out(_get_slot(slot, ~bref.is_index_invariant), + bref.copy_out(_get_slot(slot, not bref.is_index_invariant), indices, predicate=slices_changed) gpu_primitives.commit_smem_to_gmem_group() @@ -645,7 +645,7 @@ def memory_block(): # Begin initial copies. for step in range(max_concurrent_steps): for bref, barrier in zip(in_brefs, in_smem_barrier_refs): - buf_slot = _get_slot(step, ~bref.is_index_invariant) + buf_slot = _get_slot(step, not bref.is_index_invariant) bref.copy_in(buf_slot, indices, barrier) indices = _inc_grid_by_1(indices, grid) @@ -668,7 +668,7 @@ def memory_loop_body(step, carry): if manual_consumed_barriers: gpu_primitives.barrier_wait(consumed_barrier.at[slot]) # pytype: disable=attribute-error bref.copy_in( - _get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier) + _get_slot(fetch_slot, not bref.is_index_invariant), indices, barrier) next_indices = _inc_grid_by_1(indices, grid) return (next_indices,) lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps, From 41868ef06dd7e7da88f800071da040c6819b5707 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 3 Apr 2025 21:46:10 +0000 Subject: [PATCH 0365/1769] format --- jax/_src/nn/functions.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index d0f5f770e196..27436b01216a 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1218,11 +1218,11 @@ def scaled_matmul( ) -> Array: r"""Scaled matrix multiplication function. - Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. + Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. The last dim is the contracting dim, and block size is inferred. 
Mathematically, this operation is equivalent to:: - + a_block_size = a.shape[-1] // a_scales.shape[-1] b_block_size = b.shape[-1] // b_scales.shape[-1] a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) @@ -1258,26 +1258,26 @@ def scaled_matmul( Basic case: - >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) - >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) - >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) - >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) - >>> scaled_matmul(a, b, a_scales, b_scales) - Array([[[8.]]], dtype=float32) - + >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) + >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) + >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> scaled_matmul(a, b, a_scales, b_scales) + Array([[[8.]]], dtype=float32) + Using fused cuDNN call on Blackwell GPUs: - >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn) - >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn) - >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) - >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) - >>> scaled_matmul(a, b, a_scales, b_scales) + >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> scaled_matmul(a, b, a_scales, b_scales) """ if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)): raise ValueError( "scaled_matmul requires all inputs to be 3-dimensional arrays" ) - + B_a, M_a, K_a = a.shape B_b, N_b, K_b = b.shape if K_a != K_b or B_a != B_b: @@ -1286,7 +1286,7 @@ def scaled_matmul( f"and contract (K) dimensions, but got shapes {a.shape} and " f"{b.shape}" ) - + B_as, M_as, K_as = a_scales.shape B_bs, N_bs, K_bs = b_scales.shape if K_as != K_bs or B_as != B_bs: @@ -1295,7 +1295,7 @@ def scaled_matmul( f"contract (K) dimensions, but got shapes {a_scales.shape} and " f"{b_scales.shape}" ) - + if M_as != M_a or N_bs != N_b: raise ValueError( "scaled_matmul requires scales to match non-contract dimensions of " @@ -1378,7 +1378,7 @@ def scaled_dot_general( lhs, rhs, and gradients. Users can obtain valid configurations via `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` are supported. If `None`, falls back to `lax.dot_general`. - + Returns: Array: The resulting tensor, with batch dimensions first, followed by non-contracting/non-batch dimensions of lhs, and then those of rhs. @@ -1405,6 +1405,7 @@ def scaled_dot_general( Using scaled_dot_general with the configs: + >>> import functools >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) >>> lhs = random.normal(keys[0], (3, 128, 64)) >>> rhs = random.normal(keys[1], (3, 128, 64)) From cb67d5646f94918b9b4dfb2bc742aec698cc62f7 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 3 Apr 2025 15:54:42 -0700 Subject: [PATCH 0366/1769] [Mosaic GPU] Re-enable WS pipelined copy test. 
PiperOrigin-RevId: 743727350
---
 tests/pallas/mosaic_gpu_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 8fd98f62eab1..06dfd453fb19 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -2142,7 +2142,6 @@ class WarpSpecializedPipelineTest(PallasTest):
   @parameterized.product(m=[512], n=[512],
                          manual_consumed_barriers=[False, True])
   def test_pipelined_copy(self, m, n, manual_consumed_barriers):
-    self.skipTest("TODO(justinfu): Temporary skip for 3.12 update.")
     self.skip_if_wg_semantics()  # Times out!

     x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16)

From 3901014f9ab8451ec50d04f56844b6d56c6a1fd8 Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Thu, 3 Apr 2025 16:06:54 -0700
Subject: [PATCH 0367/1769] [pallas:mgpu] General ref transform handling at
 lowering time.

Replace `_handle_reshaping()` and `_handle_indexing()` with a general
`_handle_transforms()` that applies all transforms except tiling and
(optionally) transposes. The implementation is based on
`untransform_{transpose,reshape,index}()` methods on the transforms,
each of which finds the conjugate of the transpose/reshape/index with
respect to that transform.

PiperOrigin-RevId: 743731515
---
 jax/_src/pallas/mosaic_gpu/core.py       | 58 ++++++++++++++
 jax/_src/pallas/mosaic_gpu/lowering.py   | 99 ++++++++++--------------
 jax/_src/pallas/mosaic_gpu/primitives.py | 20 +++--
 3 files changed, 112 insertions(+), 65 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index b0d4f23c792e..0a949840ab62 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -36,6 +36,7 @@
 from jax._src.state import indexing
 from jax._src.state import types as state_types
 import jax.experimental.mosaic.gpu as mgpu
+from jax.experimental.mosaic.gpu import utils as mgpu_utils
 import jax.numpy as jnp
 from jaxlib.mlir import ir

@@ -263,6 +264,29 @@ def transform_shape(self, shape):
   def transform_dtype(self, dtype):
     return dtype

+  def untransform_transpose(
+      self, perm: tuple[int, ...]
+  ) -> tuple[tuple[int, ...], state_types.Transform]:
+    # The transpose in question is applied to the untiled ref, so we
+    # need to translate it by duplicating and offsetting the last part.
+    off = len(perm)
+    new_suffix = [i + off for i in perm[-len(self.tiling) :]]
+    if set(new_suffix) != set(range(off, off + len(self.tiling))):
+      raise ValueError(
+          "Transpose cannot be moved before a tiling transform when it changes"
+          f" the set of tiled dimensions. (permutation: {perm}, tiling:"
+          f" {self.tiling})"
+      )
+
+    new_tiling = tuple(self.tiling[i - off] for i in new_suffix)
+    return (*perm, *new_suffix), dataclasses.replace(self, tiling=new_tiling)
+
+  def untransform_reshape(
+      self, dtype: jnp.dtype, shape: tuple[int, ...]
+  ) -> tuple[tuple[int, ...], state_types.Transform]:
+    del dtype
+    raise NotImplementedError("Reshapes don't commute with transposes.")
+
   def untransform_index(
       self, idxs: tuple[Index, ...]
   ) -> tuple[tuple[Index, ...], state_types.Transform]:
@@ -352,6 +376,19 @@ def transform_shape(self, shape):
   def transform_dtype(self, dtype):
     return dtype

+  def untransform_transpose(
+      self, perm
+  ) -> tuple[tuple[int, ...], state_types.Transform]:
+    raise NotImplementedError(
+        "Commuting of transpose over transpose is not supported."
+    )
+
+  def untransform_reshape(
+      self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...]
+  ) -> tuple[tuple[int, ...], state_types.Transform]:
+    del shape, dtype
+    raise NotImplementedError("Can't reshape a transposed memref.")
+
   def untransform_index(
       self, idxs: tuple[Index, ...]
   ) -> tuple[tuple[Index, ...], state_types.Transform]:
@@ -436,6 +473,27 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
 class UnswizzleRef(state_types.Transform):
   swizzle: int = dataclasses.field(metadata=dict(static=True))

+  def swizzle_elems(self, dtype: jnp.dtype | ir.Type) -> int:
+    if not isinstance(dtype, ir.Type):
+      dtype = mgpu_utils.dtype_to_ir_type(dtype)
+    return (self.swizzle * 8) // mgpu.bitwidth(dtype)
+
+  def untransform_transpose(self, perm) -> tuple[tuple[int, ...], state_types.Transform]:
+    if perm[-1] != len(perm) - 1:
+      raise ValueError("Can't transpose the swizzled dimension.")
+
+    return perm, self
+
+  def untransform_reshape(
+      self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...]
+  ) -> tuple[tuple[int, ...], state_types.Transform]:
+    if shape[-1] % self.swizzle_elems(dtype) != 0:
+      raise ValueError(
+          f"Reshape shape {shape} is not divisible by swizzle elements"
+          f" {self.swizzle_elems(dtype)}"
+      )
+    return shape, self
+
   def untransform_index(
       self, idxs: tuple[Index, ...]
   ) -> tuple[tuple[Index, ...], state_types.Transform]:
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index aafab927d4c2..d8e2083a845c 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1024,61 +1024,49 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
       gpu_dialect.block_dim(gpu_dialect.Dimension(axis)),
   )

-
-def _handle_reshaping(
-    ref: ir.Value, transforms: Sequence[gpu_core.Transform]
-) -> tuple[ir.Value, Sequence[gpu_core.Transform]]:
-  is_trivial_indexer = lambda t: isinstance(
-      t, indexing.NDIndexer
-  ) and gpu_core.is_trivial_index(t.indices, t.shape)
-
-  last_reshaper_idx = next(
-      reversed([i for i, t in enumerate(transforms) if isinstance(t, RefReshaper)]),
-      None,
-  )
-  if last_reshaper_idx is None:
-    return ref, transforms
-  # Check that before the reshape are only trivial indexes and or
-  # other reshapes.
-  # TODO(cperivol): Reshapes should bubble up rather than being
-  # expected to effectively be the first ref transform.
-  if not all(isinstance(t, RefReshaper) or is_trivial_indexer(t) for t in transforms[:last_reshaper_idx]):
-    raise NotImplementedError(
-        "Reshapes do not compose with other transforms and indexers must be"
-        f" trivial (transforms: {transforms})"
-    )
-  reshaper = cast(RefReshaper, transforms[last_reshaper_idx])
-  # Skip all the reshapes and trivial indexes.
- return mgpu.memref_reshape(ref, reshaper.shape), transforms[last_reshaper_idx + 1:] - - -def _handle_indexing( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] +def _handle_transforms( + ref: ir.Value, + transforms: Sequence[gpu_core.Transform], + *, + handle_transposes=True, + handle_reshapes=True, ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - if not transforms: - pass - indexer_idxs = [ - i for i, t in enumerate(transforms) if isinstance(t, indexing.NDIndexer) - ] - if not indexer_idxs: - return ref, transforms - sliced_ref = ref + transformed_ref = ref new_transforms = [] - for t in transforms: - if not isinstance(t, indexing.NDIndexer): - new_transforms.append(t) - continue - indexer = cast(indexing.NDIndexer, t) - if indexer.int_indexer_shape: - raise NotImplementedError("int_indexer_shape non-empty") - indices = _ndindexer_indices(indexer) + def _bubble_up(untransform_fn, data): + nonlocal new_transforms new_transforms_rev = [] for t in reversed(new_transforms): - indices, new_t = t.untransform_index(indices) + data, new_t = untransform_fn(t, data) new_transforms_rev.append(new_t) - sliced_ref = mgpu.memref_slice(sliced_ref, indices) + new_transforms = list(reversed(new_transforms_rev)) - return sliced_ref, new_transforms + return data + + for t in transforms: + match t: + case indexing.NDIndexer(): + indexer = cast(indexing.NDIndexer, t) + if indexer.int_indexer_shape: + raise NotImplementedError("int_indexer_shape non-empty") + indices = _ndindexer_indices(indexer) + indices = _bubble_up( + lambda t, idxs: t.untransform_index(idxs), indices + ) + transformed_ref = mgpu.memref_slice(transformed_ref, indices) + case gpu_core.TransposeRef(perm) if handle_transposes: + perm = _bubble_up(lambda t, p: t.untransform_transpose(p), + perm) + transformed_ref = mgpu.memref_transpose(transformed_ref, perm) + case RefReshaper(dtype=dtype, shape=shape) if handle_reshapes: + shape = _bubble_up( + lambda t, p: t.untransform_reshape(dtype, p), # pylint: disable=cell-var-from-loop + shape) + transformed_ref = mgpu.memref_reshape(transformed_ref, shape) + case _: + new_transforms.append(t) + + return transformed_ref, new_transforms def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]: @@ -1120,8 +1108,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_ref, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_ref, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): @@ -1152,8 +1139,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_smem, transforms) if transforms: raise NotImplementedError( @@ -1180,8 +1166,7 @@ def _swap_lowering_rule( raise TypeError(f"Can only store to references (got {x_smem}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_smem, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != (8, 
swizzle // x_aval.dtype.itemsize): @@ -1227,9 +1212,7 @@ def _swap_lowering_rule_wg( x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) - + x_smem, transforms = _handle_transforms(x_smem, transforms) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 46ec8a87082e..fe5319113a03 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -85,8 +85,7 @@ def _load_p_lowering_rule( x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(args_tree, leaves) - x_ref, transforms = lowering._handle_reshaping(x_ref, transforms) - x_ref, transforms = lowering._handle_indexing(x_ref, transforms) + x_ref, transforms = lowering._handle_transforms(x_ref, transforms) if layout is not None: layout = layout.to_mgpu() @@ -209,7 +208,7 @@ def _copy_smem_to_gmem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - src, src_transforms = lowering._handle_indexing(src, src_transforms) + src, src_transforms = lowering._handle_transforms(src, src_transforms, handle_transposes=False) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: ctx.launch_ctx.async_copy( @@ -382,7 +381,7 @@ def _copy_gmem_to_smem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - dst, dst_transforms = lowering._handle_indexing(dst, dst_transforms) + dst, dst_transforms = lowering._handle_transforms(dst, dst_transforms, handle_transposes=False) copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms) barrier_indexer = _extract_barrier_indexer( barrier_transforms_treedef.unflatten(flat_barrier_transforms) @@ -743,7 +742,7 @@ def _wgmma_lowering( transforms_leaves, [a_transforms_tree.num_leaves] ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_indexing(a, a_transforms) + a, a_transforms = lowering._handle_transforms(a, a_transforms) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize @@ -760,7 +759,9 @@ def _wgmma_lowering( ) b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) - b, b_transforms = lowering._handle_indexing(b, b_transforms) + b, b_transforms = lowering._handle_transforms( + b, b_transforms, handle_transposes=False, handle_reshapes=False + ) match b_transforms: case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): @@ -787,6 +788,8 @@ def _wgmma_lowering( f" {rhs_tiling}." ) + # TODO(cperivol): Find a generic way to move this reshape into + # _handle_transforms. 
high_dims = [d // t for d, t in util.safe_zip(new_shape, rhs_tiling)] b = mgpu.memref_reshape(b, (*high_dims, *rhs_tiling)) rhs_transpose = False @@ -1107,9 +1110,12 @@ def _jaxpr_call_lowering_rule( for treedef, flat_ref in zip(ref_treedefs, flat_refs): ref = treedef.unflatten(flat_ref) if isinstance(ref, tuple): + ref, transforms = ref # We ignore other transforms here, because they are already embedded # in the jaxpr. - ref, _ = lowering._handle_indexing(*ref) + ref, _ = lowering._handle_transforms( + ref, transforms, handle_reshapes=False, handle_transposes=False + ) args.append(ref) program_ids = program_ids_treedef.unflatten(flat_program_ids) for axis, pid in enumerate(program_ids): From bbdea54ccb1b9338b8aa6932043393551474050e Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 18:08:39 -0700 Subject: [PATCH 0368/1769] add an `out_sharding` option to `jax.random.permutation` Drop into `Auto` mode in the implementation. --- jax/_src/random.py | 13 ++++++++++--- tests/pjit_test.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 1d044ec111ff..a21cdf89a61f 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -556,7 +556,8 @@ def _randint(key, minval, maxval, shape, dtype) -> Array: def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, - independent: bool = False) -> Array: + independent: bool = False, + out_sharding=None) -> Array: """Returns a randomly permuted array or range. Args: @@ -573,11 +574,17 @@ def permutation(key: ArrayLike, key, _ = _check_prng_key("permutation", key) check_arraylike("permutation", x) axis = canonicalize_axis(axis, np.ndim(x) or 1) + out_sharding = canonicalize_sharding(out_sharding, "permutation") if not np.ndim(x): if not np.issubdtype(lax.dtype(x), np.integer): raise TypeError("x must be an integer or at least 1-dimensional") - r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()') - return _shuffle(key, jnp.arange(r), axis) + r = core.concrete_or_error(int, x, "argument x of jax.random.permutation()") + return maybe_auto_axes(lambda key: _shuffle(key, jnp.arange(r), axis), + out_sharding)(key) + return maybe_auto_axes( + _permutation, out_sharding, axis=axis, independent=independent)(key, x) + +def _permutation(key, x, axis, independent): if independent or np.ndim(x) == 1: return _shuffle(key, x, axis) ind = _shuffle(key, jnp.arange(x.shape[axis]), 0) # type: ignore[union-attr] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d3d9cab7a5ba..580cfcd7ad8d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7332,6 +7332,45 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((4,), ('x',)) + def test_random_permutation_1d(self, mesh): + @jax.jit + def f(key): + out = jax.random.permutation(key, 8, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[4]<=[4]}"}', lowered_text) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_permutation_2d(self, mesh): + @jax.jit + def f(key): + out = jax.random.permutation(key, jnp.arange(8 * 
12).reshape(8, 12),
+                                    out_sharding=P('x', 'y'))
+      self.assertEqual(out.aval.sharding.spec, P('x', 'y'))
+      return out
+
+    key = jax.random.key(1)
+    out = f(key)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
+
+    lowered_text = f.lower(key).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertIn('sdy.sharding_constraint', lowered_text)
+      self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
+    else:
+      self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text)
+
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_random_normal(self, mesh):
     @jax.jit

From 7583814e35c85b9df55eb6ed65c4559207262f33 Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Thu, 3 Apr 2025 16:28:22 -0700
Subject: [PATCH 0369/1769] [mgpu:pallas] Changes to allow the use of
 WGMMA_TRANSPOSED_LAYOUT.

It is up to _handle_transposes() to check that the swizzle dimension is
not transposed, rather than `UnswizzleRef.untransform_transpose()`. This
allows us to disable the check in certain situations where mgpu can
handle it, such as wgmma, and swap_p when storing a value with the
WGMMA_TRANSPOSED layout. If this check is completely skipped it can
cause the kernel to crash at runtime.

Furthermore, this CL adds a test to check this behavior.

PiperOrigin-RevId: 743738166
---
 jax/_src/pallas/mosaic_gpu/lowering.py | 30 ++++++++++++++++++--
 tests/pallas/mosaic_gpu_test.py        | 38 ++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index d8e2083a845c..80757ef69e64 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1166,13 +1166,37 @@ def _swap_lowering_rule(
     raise TypeError(f"Can only store to references (got {x_smem}).")
   x_aval = ctx.avals_in[0]
   transforms = jax.tree.unflatten(tree, leaves)
-  x_smem, transforms = _handle_transforms(x_smem, transforms)
+  transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT
+  x_smem, transforms = _handle_transforms(
+      x_smem, transforms, handle_transposes=not transposed_value
+  )
   match transforms:
-    case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
+    case (
+        gpu_core.UnswizzleRef(swizzle),
+        gpu_core.UntileRef(tiling),
+        *maybe_transpose,
+    ):
       if tiling != (8, swizzle // x_aval.dtype.itemsize):
         raise NotImplementedError("Tiling does not fit swizzle")
+
+      if transposed_value != bool(maybe_transpose):
+        raise ValueError(
+            "Either both the ref and the value are transposed or neither is."
+        )
+
+      if maybe_transpose:
+        if maybe_transpose != [gpu_core.TransposeRef((1, 0))]:
+          raise NotImplementedError(
+              f"Unsupported transforms: {transforms} ({maybe_transpose})"
+          )
+
+        x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2))
+
       old_value = mgpu.FragmentedArray.load_tiled(
-          x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle
+          x_smem,
+          is_signed=mgpu_utils.is_signed(x_aval.dtype),
+          swizzle=swizzle,
+          layout=value.layout,
       )
       value.store_tiled(x_smem, swizzle=swizzle)
       return old_value
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 06dfd453fb19..754e53255438 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -1372,6 +1372,44 @@ def kernel(o_ref):
     x = jnp.full(shape, 42.0, jnp.float32)
     np.testing.assert_array_equal(kernel(), x)

+  def test_wgmma_transposed_layout(self):
+    """Tests that the result of wgmma can be stored transposed using
+    the WGMMA_TRANSPOSED layout.
+ """ + + dtype = jnp.dtype(jnp.float16) + swizzle_elems = 128 // dtype.itemsize + shape = (128, 128) + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.SMEM( + shape, dtype, + transforms=( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ), + ) + ] + ) + def kernel(o_ref, smem): + iota = plgpu.broadcasted_iota( + dtype, o_ref.shape, 0, layout=plgpu.Layout.WGMMA + ) * o_ref.shape[0] + iota += plgpu.broadcasted_iota( + dtype, o_ref.shape, 1, layout=plgpu.Layout.WGMMA + ) + + smem_trns = plgpu.transpose_ref(smem, (1, 0)) + smem_trns[...] = plgpu.layout_cast(iota, plgpu.Layout.WGMMA_TRANSPOSED) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem, o_ref) + + x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)).T + np.testing.assert_array_equal(kernel(), x) + def test_profiler(self): self.skip_if_wg_semantics() # Transform inference fails. From 26fc1cde4cdb593239f796a37834645184ac10fb Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 3 Apr 2025 17:18:18 -0700 Subject: [PATCH 0370/1769] [pallas:mgpu] Initial version of inline_mgpu op PiperOrigin-RevId: 743751560 --- jax/_src/pallas/mosaic_gpu/primitives.py | 94 +++++++++++++++++++++++- jax/experimental/pallas/mosaic_gpu.py | 2 + tests/pallas/mosaic_gpu_test.py | 34 +++++++++ 3 files changed, 129 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index fe5319113a03..37d71cd6d1c6 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Sequence, Callable import dataclasses import enum import itertools @@ -1218,3 +1218,95 @@ def jaxpr_call( ref_treedefs=ref_treedefs, program_ids_treedef=program_ids_treedef, ) + +inline_mgpu_p = jax_core.Primitive("inline_mgpu_p") +inline_mgpu_p.multiple_results = True + + +@dataclasses.dataclass(frozen=True) +class RefType: + ... + + +def inline_mgpu(*args, arg_types): + flat_args, treedef = jax.tree.flatten(tuple(args)) + flat_types, treedef_ty = jax.tree.flatten(tuple(arg_types)) + if treedef != treedef_ty: + raise ValueError(f"Mismatched type shape: {treedef} != {treedef_ty}") + + # Strip the transforms from the refs since they will be recorded in + # the types. + raw_refs_flat_args = [] + for a, t in zip(flat_args, flat_types): + def traced_ty(ty): + return isinstance(a, jax_core.Tracer) and isinstance(a.aval, ty) + + if isinstance(t, (ParameterizedLayout, Layout)) and traced_ty(jax_core.ShapedArray): + raw_refs_flat_args.append(a) + elif isinstance(t, RefType) and traced_ty(_Ref): + ref, transforms = a, () + if isinstance(a, state_types.TransformedRef): + ref, transforms = ref.ref, ref.transforms + + raw_refs_flat_args.append(ref) + if transforms: + raise NotImplementedError("Transformed refs (or types) are not supported.") + else: + raise ValueError(f"Mismatched type: {a, t}") + + def inner(f): + return inline_mgpu_p.bind( + *raw_refs_flat_args, + args_treedef=treedef, + flat_types=flat_types, + mgpu_fn=f, + ) + return inner + + +@inline_mgpu_p.def_effectful_abstract_eval +def _inline_mgpu_abstract_eval( + *flat_args, + args_treedef, + flat_types, + mgpu_fn, +): + del args_treedef, flat_types, mgpu_fn # Unused. + # TODO(cperivol): Let the user set the effects. 
+ return (), { + gpu_core._wgmma_pipeline_effect, + gpu_core._memory_effect, + *itertools.chain.from_iterable( + (state.ReadEffect(i), state.WriteEffect(i)) + for i, r in enumerate(flat_args) + if isinstance(r, pallas_core.AbstractMemoryRef) + ), + } + + +@discharge.register_partial_discharge_rule(inline_mgpu_p) +def _inline_mgpu_discharge(*args, **kwargs): + raise NotImplementedError("inline_mgpu_p does not support discharge.") + +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.ThreadSemantics.Lane) +def _inline_mgpu_lowering_rule( + ctx: lowering.LoweringRuleContext, + *flat_args, + mgpu_fn: Callable[..., Any], + flat_types, + args_treedef, +): + for a, t in zip(flat_args, flat_types, strict=True): + match a: + case ir.Value() if ir.MemRefType.isinstance(a.type): + # We checked the memory spaces at tracing time. + pass + case mgpu.FragmentedArray(): + if a.layout != t.to_mgpu(): + raise ValueError(f"Unexpected layout for {a} (expected: {t})") + case _: + raise ValueError(f"Unexpected argument {a}") + + args = jax.tree.unflatten(args_treedef, flat_args) + mgpu_fn(ctx.launch_ctx, *args) + return () diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index b44c86ea7a4c..1d3bebbc3757 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -43,9 +43,11 @@ from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast from jax._src.pallas.mosaic_gpu.primitives import load as load +from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 754e53255438..d35446359756 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -33,6 +33,7 @@ from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives from jax._src.state import discharge from jax.experimental import pallas as pl +import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np @@ -369,6 +370,38 @@ def kernel(o_ref): kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension) ) + def test_inline_mgpu(self): + dtype = jnp.bfloat16 + self.skip_if_wg_semantics() + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128, 128), dtype), + plgpu.Barrier(num_arrivals=1), + ], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_ref, o_ref, smem_ref, barrier): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier) + plgpu.barrier_wait(barrier) + arr = jnp.ones_like(x_ref) + @plgpu.inline_mgpu( + smem_ref, + o_ref, + arr, + arg_types=[plgpu.RefType(), plgpu.RefType(), plgpu.Layout.WG_SPLAT(x_ref.shape)], + ) + def _(ctx, smem_ref, o_ref, 
y): + del ctx + x = mgpu.FragmentedArray.load_strided(smem_ref) + (x + y).store_untiled(o_ref) + + key = jax.random.key(0) + x = (jax.random.uniform(key, (128, 128)) * 42).astype(dtype) + np.testing.assert_array_equal(kernel(x), x + 1) + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_smem_to_gmem(self, indexer): @functools.partial( @@ -1506,6 +1539,7 @@ def test_missing_primitive_lowerings_are_tracked(self): actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives expected_missing_primitives = { + mgpu_primitives.inline_mgpu_p, mgpu_primitives.broadcasted_iota_p, mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, From d645172765886a0df06a3a7f58393b313681572f Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Thu, 3 Apr 2025 17:18:20 -0700 Subject: [PATCH 0371/1769] Delete `PjRtClient.Defragment`. The `Defragment` implementation for GPU is in `py_client.cc`, so this should be a no-op. PiperOrigin-RevId: 743751570 --- jaxlib/xla/py_client.cc | 7 ++++++- jaxlib/xla/xla_extension/__init__.pyi | 1 - tests/array_test.py | 13 ++++++++----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 1e41d9cf8a0d..2ce11e7e76c7 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -204,9 +204,14 @@ absl::Status PyClient::Defragment() { platform_id == SyclId(); if (!is_gpu_client) { - return pjrt_client()->Defragment(); + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); } + // TODO(b/399879011): This is a GPU-specific implementation of `Defragment`. + // Ideally, this would be replaced with some kind of auto-defrag-on-OOM, or at + // least would not live in this file. + struct TmpBuffer { // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays // can reference the same PjRtBuffer. diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi index 2d759236b8c5..3fe2de1e30a3 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -579,7 +579,6 @@ class Client: host_callbacks: Sequence[Any] = ..., ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... - def defragment(self) -> _Status: ... def make_python_callback_from_host_send_and_recv( self, callable: Callable, diff --git a/tests/array_test.py b/tests/array_test.py index 5891db5a3e36..901ce9521da1 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -655,12 +655,15 @@ def f(x): output_shardings._to_xla_hlo_sharding(x_dummy.ndim), s._to_xla_hlo_sharding(x_dummy.ndim))) - # TODO(skyewm): remove this test when we can remove the workaround manual - # defragment API - @jtu.skip_on_devices('cpu') # defragment not implemented for TFRT CPU + # TODO(b/399879011): GPU is the only platform that has an implementation for + # this, which exists in py_client.cc. Ideally, this would be replaced with + # some kind of auto-defrag-on-OOM. + @jtu.run_on_devices('gpu') def test_defragment(self): + # Since the GPU implementation is in py_client.cc, it cannot be exposed via + # the PjRt C API. if xb.using_pjrt_c_api(): - self.skipTest("Manual defragment not exposed via PJRT C API") + self.skipTest('Manual defragment not exposed via PJRT C API') # Create a few arrays global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',)) @@ -673,7 +676,7 @@ def test_defragment(self): # Delete one of them arr2.delete() - # Defragment + # Defragment. 
xb.get_backend().defragment() # Sanity check remaining arrays From f8bbe98a860acd0d16ea0288f10839f7a0ed2d1d Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 3 Apr 2025 17:25:22 -0700 Subject: [PATCH 0372/1769] require `out_shardings` as a keyword-only argument on public functions PiperOrigin-RevId: 743753215 --- jax/_src/lax/lax.py | 12 +++++++----- jax/_src/random.py | 5 +++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ac6054328f73..13511641558c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2460,6 +2460,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, + *, out_sharding=None) -> Array: """General dot product/contraction operator. @@ -2667,7 +2668,7 @@ def ragged_dot_general( ) -def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None +def broadcast(operand: ArrayLike, sizes: Sequence[int], *, out_sharding=None ) -> Array: """Broadcasts an array, adding new leading dimensions @@ -2689,7 +2690,7 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None out_sharding=out_sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int], out_sharding=None + broadcast_dimensions: Sequence[int], *, out_sharding=None ) -> Array: """Wraps XLA's `BroadcastInDim `_ @@ -2732,7 +2733,7 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: def reshape(operand: ArrayLike, new_sizes: Shape, dimensions: Sequence[int] | None = None, - out_sharding: NamedSharding | P | None = None) -> Array: + *, out_sharding: NamedSharding | P | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -3378,7 +3379,7 @@ def iota(dtype: DTypeLike, size: int) -> Array: return broadcasted_iota(dtype, (size,), 0) def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, - out_sharding=None) -> Array: + *, out_sharding=None) -> Array: """Convenience wrapper around ``iota``.""" dtype = dtypes.canonicalize_dtype(dtype) shape = canonicalize_shape(shape) @@ -8430,7 +8431,8 @@ def _propagate_mem_kind_copy(in_mem_kind): pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT, out_sharding=None): + algorithm=RandomAlgorithm.RNG_DEFAULT, + *, out_sharding=None): """Stateless PRNG bit generator. Experimental and its use is discouraged. Returns uniformly distributed random bits with the specified shape and dtype diff --git a/jax/_src/random.py b/jax/_src/random.py index a21cdf89a61f..5cbd966e7a7b 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -359,6 +359,7 @@ def maybe_auto_axes(f, out_shardings, **hoist_kwargs): def bits(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeUInt | None = None, + *, out_sharding=None) -> Array: """Sample uniform bits in the form of unsigned integers. @@ -393,6 +394,7 @@ def uniform(key: ArrayLike, dtype: DTypeLikeFloat = float, minval: RealArray = 0., maxval: RealArray = 1., + *, out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. @@ -466,6 +468,7 @@ def randint(key: ArrayLike, minval: IntegerArray, maxval: IntegerArray, dtype: DTypeLikeInt = int, + *, out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. 
@@ -557,6 +560,7 @@ def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, independent: bool = False, + *, out_sharding=None) -> Array: """Returns a randomly permuted array or range. @@ -707,6 +711,7 @@ def choice(key: ArrayLike, def normal(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, + *, out_sharding=None) -> Array: r"""Sample standard normal random values with given shape and float dtype. From 5b3e419515404de8650c169cb27ff00b1fb53340 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 3 Apr 2025 18:34:35 -0700 Subject: [PATCH 0373/1769] Add `auto_axes`, `explicit_axes` and `manual_axes` properties to Mesh and AbstractMesh PiperOrigin-RevId: 743767895 --- jax/_src/mesh.py | 15 +++++++++++++++ tests/array_test.py | 3 +++ 2 files changed, 18 insertions(+) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index a8003e693459..00859f9b3d74 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -174,6 +174,21 @@ def _any_axis_auto(self) -> bool: def _any_axis_explicit(self) -> bool: return any_axis_types_match(self._axis_types, AxisType.Explicit) + @functools.cached_property + def auto_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Auto) + + @functools.cached_property + def explicit_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Explicit) + + @functools.cached_property + def manual_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Manual) + @functools.cached_property def _axis_types_dict(self): if not self.axis_names: diff --git a/tests/array_test.py b/tests/array_test.py index 901ce9521da1..2bdc54607473 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1391,6 +1391,9 @@ def test_make_mesh_axis_types(self): self.assertDictEqual( mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',), AxisType.Manual: ('z',)}) + self.assertEqual(mesh.explicit_axes, ('x',)) + self.assertEqual(mesh.auto_axes, ('y',)) + self.assertEqual(mesh.manual_axes, ('z',)) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Explicit, Manual)) From c1bdd1a234ae6fa6c650426e1a4a3c04851e82da Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 3 Apr 2025 19:39:03 -0700 Subject: [PATCH 0374/1769] [Mosaic TPU] Allow specify priority in enqueueDMA. For now we only support priority 0 (on-demand thread) and priority 1 (background thread) on local DMA. 
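For example, once this attribute is plumbed through the Pallas-level API
(done in the follow-up patch), a kernel can mix copy priorities. A minimal
sketch, assuming `x_ref`/`y_ref`, the scratch refs, and the DMA semaphores
are set up in the kernel as usual:

```
# Sketch: start a background-priority local copy (priority 1) alongside a
# default on-demand copy (priority 0), then wait for both to complete.
copy1 = pltpu.async_copy(x_ref, scratch1_ref, sem1, priority=1)
copy2 = pltpu.async_copy(y_ref, scratch2_ref, sem2, priority=0)
copy1.wait()
copy2.wait()
```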
PiperOrigin-RevId: 743780185
---
 jax/BUILD                                     |  1 +
 jax/_src/tpu_custom_call.py                   | 14 ++++++++++++--
 jaxlib/mosaic/dialect/tpu/tpu.td              |  4 +++-
 jaxlib/mosaic/dialect/tpu/tpu_ops.cc          | 13 ++++++++++++-
 jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 10 +++++++++-
 5 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/jax/BUILD b/jax/BUILD
index f5745df0e5bf..fe2e6b8d7df1 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -1073,6 +1073,7 @@ pytype_strict_library(
     srcs = ["_src/tpu_custom_call.py"],
     visibility = [":internal"],
     deps = [
+        ":cloud_tpu_init",
        ":config",
        ":core",
        ":jax",
diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index e37d5e064a26..f84db206f4d1 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -32,6 +32,7 @@
 from jax._src import config
 from jax._src import core
 from jax._src import sharding_impls
+from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
 from jax._src.interpreters import mlir
 from jax._src.lib import tpu
 from jax._src.lib import xla_client
@@ -64,7 +65,14 @@

 # This tracks the latest Mosaic IR version with a monthly delay.
-FWD_COMPAT_IR_VERSION = 3
+FWD_COMPAT_IR_VERSION = 4
+DEFAULT_IR_VERSION = None
+# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date.
+if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple(
+    jax.lib.__version__
+) < (0, 5, 4):
+  FWD_COMPAT_IR_VERSION = 3
+  DEFAULT_IR_VERSION = 3

 tpu_custom_call_p = core.Primitive("tpu_custom_call")

@@ -671,7 +679,9 @@ def lower_module_to_custom_call(
       serialization_format=serialization_format,
       output_memory_spaces=output_memory_spaces,
       kernel_name=kernel_name,
-      ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None,
+      ir_version=FWD_COMPAT_IR_VERSION
+      if ctx.is_forward_compat()
+      else DEFAULT_IR_VERSION,
   )
   return _tpu_custom_call_lowering(
       ctx,
diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 4b5ed34934d7..0cd045621413 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -752,7 +752,9 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
     AnyMemRef:$target,
     MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore,
     Optional<I32>:$device_id, // For remote DMAs
-    Optional<I32>:$core_id // For megacore
+    Optional<I32>:$core_id, // For megacore
+    // Smaller number means higher priority. 0 is the highest and the default.
+    DefaultValuedAttr<I32Attr, "0">:$priority
   );
   let hasVerifier = 1;
 }
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index 41342efeb1b4..5ed5e94b13c0 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -955,13 +955,24 @@ LogicalResult EnqueueDMAOp::verify() {
                        "device_id or core_id is specified");
     }
   }
+  bool is_remote = getDeviceId() || getCoreId();
   if (getSourceSemaphore()) {
-    if (!getDeviceId() && !getCoreId()) {
+    if (!is_remote) {
       return emitOpError(
           "DMA destination device_id or core_id must be specified when source "
           "semaphore is specified");
     }
   }
+  int priority = getPriority();
+  if (priority < 0 || priority > 1) {
+    return emitOpError(
+               "Not implemented: only support priority 0 or 1, but got ")
+           << priority;
+  }
+  if (priority != 0 && is_remote) {
+    return emitOpError(
+        "Not implemented: non-zero priority is not supported for remote DMA");
+  }
   return success();
 }
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc
index 5f6c9bd712ff..e08149fe44fc 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc
@@ -40,7 +40,7 @@ constexpr StringRef kMangledDialect = "stable_mosaic.";
 constexpr StringRef kVersionAttrName = "stable_mosaic.version";
 // When this is bumped, we should file a TODO to update the forward-compatible
 // version in tpu_custom_call.py in a month!
-constexpr int kVersion = 3;
+constexpr int kVersion = 4;

 using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;

@@ -62,6 +62,11 @@ LogicalResult enqueue_dma_upgrade(Operation* op, int version) {
           << op->getNumOperands();
     }
   }
+  if (version < 4) {
+    op->setAttr("priority",
+                mlir::IntegerAttr::get(
+                    mlir::IntegerType::get(op->getContext(), 32), 0));
+  }
   return success();
 }

@@ -69,6 +74,9 @@ LogicalResult enqueue_dma_downgrade(Operation* op, int version) {
   if (version < 2) {
     return op->emitError("Downgrade to version ") << version << " unsupported";
   }
+  if (version < 4) {
+    op->removeAttr("priority");
+  }
   return success();
 }

From a9bd1e3f9df474e769210d78f86ae829544c0e7b Mon Sep 17 00:00:00 2001
From: Jevin Jiang
Date: Thu, 3 Apr 2025 22:25:46 -0700
Subject: [PATCH 0375/1769] [Pallas TPU] Support DMA priority in async copy
 start

For now, we can only specify priority 0 (on-demand) or priority 1
(background), and only for local DMA. Also added the priority to the
pretty print by changing `dma_start` to `dma_start(px)`, where `x` is
the priority.

Full example:
```
{ lambda ; a:MemRef{int32[8,128]} b:MemRef{int32[8,128]} c:MemRef{int32[8,128]}
    d:MemRef{int32[8,128]} e:MemRef{int32[8,128]} f:MemRef{int32[8,128]}
    g:MemRef{dma_sem[]} h:MemRef{dma_sem[]}. let
    dma_start(p1) a[...] -> e[...] g[...]
    dma_start(p0) b[...] -> f[...] h[...]
    dma_wait e[...] g[...]
    dma_wait f[...] h[...]
    dma_start(p0) e[...] -> c[...] g[...]
    dma_start(p1) f[...] -> d[...] h[...]
    dma_wait c[...] g[...]
    dma_wait d[...] h[...]
in () } ``` PiperOrigin-RevId: 743815050 --- jax/_src/pallas/mosaic/lowering.py | 23 ++++++++++++---- jax/_src/pallas/mosaic/primitives.py | 32 ++++++++++++++++------ tests/pallas/tpu_pallas_test.py | 41 +++++++++++++++++++++++++--- 3 files changed, 79 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 6c8b3c646a0d..0ca298c88dc5 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3531,8 +3531,14 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): return [] lowering_rules[primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule -def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: primitives.DeviceIdType): + +def _dma_start_lowering_rule( + ctx: LoweringRuleContext, + *args, + tree, + device_id_type: primitives.DeviceIdType, + priority: int, +): ( src_ref, src_transforms, @@ -3564,10 +3570,17 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem, - device_id=device_id) - + tpu.enqueue_dma( + src_ref, + dst_ref, + sem, + source_semaphore=src_sem, + device_id=device_id, + priority=priority, + ) return [] + + lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index c50a21218117..59856c0ca7b2 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -208,9 +208,14 @@ def _get_args_and_tree(self, swap_src_and_dst: bool = False): self.device_id, )) - def start(self): + def start(self, priority: int = 0): flat_args, tree = self._get_args_and_tree() - dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) + dma_start_p.bind( + *flat_args, + tree=tree, + device_id_type=self.device_id_type, + priority=priority, + ) def wait(self): if self.is_remote: @@ -239,7 +244,9 @@ def wait_send(self): dma_start_p.multiple_results = True @dma_start_p.def_effectful_abstract_eval -def _dma_start_abstract_eval(*args, tree, device_id_type): +def _dma_start_abstract_eval(*args, tree, device_id_type, priority): + if priority < 0: + raise ValueError(f"DMA start priority must be non-negative: {priority}") ( src_ref_aval, src_transforms_avals, @@ -274,6 +281,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, settings: jax_core.JaxprPpSettings): invars = eqn.invars tree = eqn.params["tree"] + priority = eqn.params["priority"] ( src_ref, src_transforms, @@ -290,7 +298,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, if src_sem or device_id: return jax_core._pp_eqn(eqn, context, settings) return pp.concat([ - pp.text("dma_start"), + pp.text(f"dma_start(p{priority})"), pp.text(" "), sp.pp_ref_transforms(context, src_ref, src_transforms), pp.text(" -> "), @@ -301,8 +309,12 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn -def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, - *args, tree, device_id_type): + +def dma_start_partial_discharge_rule( + should_discharge, in_avals, out_avals, *args, tree, device_id_type, priority +): + # Note: we ignore the DMA priority in discharge rules. 
+ del priority ( src_ref, src_transforms, @@ -461,6 +473,7 @@ def do_discharge_src_sem(src_sem=src_sem): return new_vals, [] + state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule) @@ -550,6 +563,7 @@ def _get_ref_and_transforms(ref): return ref.ref, ref.transforms return ref, () + def make_async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" src_ref, src_transforms = _get_ref_and_transforms(src_ref) @@ -568,12 +582,14 @@ def make_async_copy(src_ref, dst_ref, sem): primitives.DeviceIdType.MESH, ) -def async_copy(src_ref, dst_ref, sem): + +def async_copy(src_ref, dst_ref, sem, *, priority: int = 0): """Issues a DMA copying from src_ref to dst_ref.""" copy_descriptor = make_async_copy(src_ref, dst_ref, sem) - copy_descriptor.start() + copy_descriptor.start(priority=priority) return copy_descriptor + def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): """Creates a description of a remote copy operation. diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 128fe50687a0..2e773b88fbad 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1135,6 +1135,39 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): np.testing.assert_array_equal(y, x) np.testing.assert_array_equal(sem_val, 0) + def test_set_dma_priority(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 5): + self.skipTest('Needs a newer libTPU') + if jtu.get_tpu_version() < 5: + self.skipTest('Target does not support DMA prefetch between HBM and VMEM') + def kernel(x1, x2, y1, y2, scratch1, scratch2, sem1, sem2): + copy1 = pltpu.async_copy(x1, scratch1, sem1, priority=1) + copy2 = pltpu.async_copy(x2, scratch2, sem2, priority=0) + copy1.wait() + copy2.wait() + copy1 = pltpu.async_copy(scratch1, y1, sem1, priority=0) + copy2 = pltpu.async_copy(scratch2, y2, sem2, priority=1) + copy1.wait() + copy2.wait() + + shape = (8, 128) + dtype = jnp.int32 + x1 = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + x2 = x1 + 1 + y1, y2 = self.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + scratch_shapes=[pltpu.VMEM(shape, dtype)] * 2 + + [pltpu.SemaphoreType.DMA] * 2, + out_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + ), + out_shape=[jax.ShapeDtypeStruct(shape, dtype)] * 2, + )(x1, x2) + np.testing.assert_array_equal(y1, x1) + np.testing.assert_array_equal(y2, x2) + def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): @@ -2665,19 +2698,19 @@ class PrettyPrintingTest(PallasBaseTest): @parameterized.parameters( ( lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), - 'dma_start c[d,:,:] -> e[...] f', + 'dma_start(p0) c[d,:,:] -> e[...] f', ), ( lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), - 'dma_start c[0,d:d+8,:] -> e[...] f', + 'dma_start(p0) c[0,d:d+8,:] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), - 'dma_start c[d,2:6,:100] -> e[...] f', + 'dma_start(p0) c[d,2:6,:100] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), - 'dma_start c[d,2:,4:104] -> e[...] f', + 'dma_start(p0) c[d,2:,4:104] -> e[...] 
f', ), ) def test_dma_custom_pretty_print(self, indexer, expected): From 4f00249aa8bff45b379f76304e79e293273f9ad6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 22:30:45 -0700 Subject: [PATCH 0376/1769] [pallas:mosaic_gpu] Do not specify the default `index_map` in tests PiperOrigin-RevId: 743816110 --- tests/pallas/mosaic_gpu_test.py | 108 ++++++++++---------------------- 1 file changed, 32 insertions(+), 76 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d35446359756..0cfe9197db36 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -644,8 +644,6 @@ def kernel(x_ref, o_ref, barrier_ref): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), transforms=( plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), @@ -676,9 +674,7 @@ def body(tmp_ref): pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, - ) + out_spec = plgpu.GPUBlockSpec(transforms=ts, memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), @@ -719,8 +715,6 @@ def kernel(x_ref, o_ref, barrier_ref): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (2, 128, 128), - lambda: (0, 0, 0), transforms=( plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), @@ -750,11 +744,7 @@ def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec( - (2, 128), - lambda: (0, 0), - memory_space=plgpu.SMEM, - ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -776,11 +766,7 @@ def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layo self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec( - (2, m), - lambda: (0, 0), - memory_space=plgpu.SMEM, - ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -819,24 +805,19 @@ def compute(acc_ref): out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) o_ref[...] 
= out - - out_spec = plgpu.GPUBlockSpec( - (m, n), lambda: (0, 0), memory_space=plgpu.SMEM, - ) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), in_specs=( pl.BlockSpec(memory_space=src_memory_space), plgpu.GPUBlockSpec( - (k, n), - lambda: (0, 0), transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), ), ), ), - out_specs=out_spec, + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) out_ref = ( @@ -855,9 +836,7 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, - ) + out_spec = plgpu.GPUBlockSpec(memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), @@ -960,11 +939,10 @@ def test_print_wgmma_tiled_layout(self): out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=[ plgpu.GPUBlockSpec( - shape, - lambda: (0, 0), transforms=( - plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), - ), + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(128), + ) ) ], ) @@ -1143,7 +1121,8 @@ def test_swizzled_blockspec_shapes(self): (128, 64), lambda *i: i, transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), ), ) @functools.partial( @@ -1327,9 +1306,7 @@ def test_tile_slicing(self): shape = (256, 128) block_spec = plgpu.GPUBlockSpec( - transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), - ) + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) @functools.partial( self.pallas_call, @@ -1380,12 +1357,7 @@ def rotate(src, dst): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) @@ -1560,11 +1532,9 @@ def test_fori_loop_accumulator(self, force_while): transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) @functools.partial( self.pallas_call, - in_specs=[ - plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms) - ], + in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), - out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), + out_specs=plgpu.GPUBlockSpec((64, 64)), ) def kernel(i_ref, o_ref): def scope(acc_ref): @@ -1613,25 +1583,28 @@ def _epilogue(): if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: lhs_spec = plgpu.GPUBlockSpec( - lhs_spec.block_shape, lhs_spec.index_map, + lhs_spec.block_shape, + lhs_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) rhs_spec = plgpu.GPUBlockSpec( - rhs_spec.block_shape, rhs_spec.index_map, + rhs_spec.block_shape, + rhs_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) out_spec = plgpu.GPUBlockSpec( - out_spec.block_shape, out_spec.index_map, + out_spec.block_shape, + out_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) res = self.pallas_call( @@ -1717,14 
+1690,9 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (128, 192), lambda: (0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) @@ -1747,17 +1715,10 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (128, 192), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (64, 192), lambda: (0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), )(a, b, i) np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) @@ -1783,14 +1744,9 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) From 97cecdf862690e30da2296c90337176492f08e9e Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 3 Apr 2025 21:40:40 -0700 Subject: [PATCH 0377/1769] add an `out_sharding` option to `jax.random.truncated_normal` Drop into `Auto` mode in the implementation. --- jax/_src/random.py | 7 +++++-- tests/pjit_test.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 5cbd966e7a7b..e632d4a9a2fa 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -851,7 +851,8 @@ def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + *, out_sharding=None) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. 
The values are returned according to the probability density function: @@ -882,12 +883,14 @@ def truncated_normal(key: ArrayLike, if shape is not None: shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("truncated_normal", key) + out_sharding = canonicalize_sharding(out_sharding, "truncated_normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " f"dtype, got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _truncated_normal(key, lower, upper, shape, dtype) + return maybe_auto_axes(_truncated_normal, out_sharding, + shape=shape, dtype=dtype)(key, lower, upper) @partial(jit, static_argnums=(3, 4)) def _truncated_normal(key, lower, upper, shape, dtype) -> Array: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 580cfcd7ad8d..f2db913af736 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7423,6 +7423,26 @@ def f(arr, key): out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_truncated_normal(self, mesh): + @jax.jit + def f(key, lower): + out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key, -1.) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key, -1.).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + def test_auto_axes_no_context_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) np_inp = np.arange(16.).reshape(8, 2) From 5eb4e7b2dc4d1ab9677cb7a22feac3f7933142ae Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Thu, 3 Apr 2025 23:17:25 -0700 Subject: [PATCH 0378/1769] [Mosaic GPU] Return the combined softmax residuals. They are scaled so that they can be used directly as inputs to exp2 in the backward pass.
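For intuition, a minimal sketch of how a backward pass can consume this base-2 log-sum-exp residual (illustrative only, not part of this change: `softmax_from_lse` is a hypothetical helper, and only the `m * log2(e) + log2(l)` convention comes from the kernel):

```python
import jax
import jax.numpy as jnp

RCP_LN2 = 1.4426950408889634  # 1/ln(2) == log2(e), the same constant the kernel uses

def softmax_from_lse(logits, lse):
  # lse = max(logits) * log2(e) + log2(sum(exp(logits - max))), so the
  # softmax probabilities fall out of a single exp2 with no extra rescaling.
  return jnp.exp2(logits * RCP_LN2 - lse[..., None])

logits = jnp.array([[1.0, 2.0, 3.0]])
m = logits.max(axis=-1)
lse = m * RCP_LN2 + jnp.log2(jnp.exp(logits - m[..., None]).sum(axis=-1))
assert jnp.allclose(softmax_from_lse(logits, lse), jax.nn.softmax(logits, axis=-1))
```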
PiperOrigin-RevId: 743825330 --- .../pallas/ops/gpu/attention_mgpu.py | 112 +++++++++++++----- tests/pallas/mgpu_attention_test.py | 11 +- 2 files changed, 94 insertions(+), 29 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index d06d3b39cb7a..6a20b448ca54 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -43,8 +43,8 @@ def __post_init__(self): raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") -@functools.partial(jax.jit, static_argnames=["config"]) -def attention(q, k, v, config: TuningConfig): +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention(q, k, v, config: TuningConfig, save_residuals: bool = False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -69,12 +69,12 @@ def attention(q, k, v, config: TuningConfig): ) block_q, block_kv = config.block_q, config.block_kv - def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") q_head = lax.axis_index("heads") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") - qo_smem2, k_smem, v_smem = smem_buffers + qo_smem2, k_smem, v_smem, lse_smem2 = smem_buffers k_barriers, v_barriers, q_barriers = buffer_barriers k_consumed_barriers, v_consumed_barriers = consumed_barriers def perform_schedule_barrier(): @@ -85,6 +85,7 @@ def perform_schedule_barrier(): def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q plgpu.copy_gmem_to_smem( @@ -162,15 +163,23 @@ def _wait(): 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) ) pl.when(wg_idx == 0)(perform_schedule_barrier) - del m_i # Not needed anymore # TODO(apaszke): Invert and multiply to avoid expensive divisions. acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] 
= m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) def _memory_wg(): @@ -191,7 +200,7 @@ def kv_loop(kv_step, _): plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) - def entry(q_ref, k_ref, v_ref, out_ref): + def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): compute_wgs = 2 tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) @@ -207,9 +216,12 @@ def entry(q_ref, k_ref, v_ref, out_ref): (max_concurrent_steps, block_kv, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, k_scratch, v_scratch, None] + if save_residuals: + scratch[3] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args), - (qo_scratch, k_scratch, v_scratch), + lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, ( plgpu.Barrier(1, num_barriers=max_concurrent_steps), plgpu.Barrier(1, num_barriers=max_concurrent_steps), @@ -223,9 +235,17 @@ def entry(q_ref, k_ref, v_ref, out_ref): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - return plgpu.kernel( + out_shape = [q, None] + if save_residuals: + # Note that we keep seq_len in the minor-most dimension so that we can do + # 1D TMAs on chunks of `block_q`. + out_shape[1] = jax.ShapeDtypeStruct( + (batch_size, num_q_heads, q_seq_len), jnp.float32 + ) + + out, lse = plgpu.kernel( entry, - out_shape=q, + out_shape=out_shape, grid=(batch_size, num_q_tiles, num_q_heads), grid_names=("batch", "q_seq", "heads"), num_threads=3, @@ -233,8 +253,14 @@ def entry(q_ref, k_ref, v_ref, out_ref): compiler_params=plgpu.GPUCompilerParams(approx_math=True), )(q, k, v) -@functools.partial(jax.jit, static_argnames=["config"]) -def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): + if save_residuals: + assert lse is not None + return out, (lse,) + + return out + +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -266,10 +292,11 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) - def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") wg_idx = lax.axis_index("wg") - qo_smem2, q_barriers, schedule_barrier = scoped + smem_buffers, q_barriers, schedule_barrier = scoped + qo_smem2, lse_smem2 = smem_buffers q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) @@ -281,6 +308,7 @@ def perform_schedule_barrier(): def _compute_thread(): qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None m_i = plgpu.layout_cast( jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, ) @@ 
-299,15 +327,23 @@ def _compute_thread(): plgpu.barrier_wait(q_barriers.at[wg_idx]) pl.when(wg_idx == 1)(perform_schedule_barrier) final_carry = (yield (acc, m_i, l_i)) - del m_i # Unused pl.when(wg_idx == 0)(perform_schedule_barrier) - acc, _, l_i = final_carry + acc, m_i, l_i = final_carry acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] = m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) def kv_pipeline(_, k_smem, v_smem, @@ -371,7 +407,7 @@ def compute_pv(acc_ref): thread_name="wg", ) def run(refs): - q_ref, k_ref, v_ref, out_ref = refs + q_ref, k_ref, v_ref, out_ref, lse_ref = refs @pl.core_map(mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True), ) @@ -380,22 +416,36 @@ def _kernel_entry(): (compute_wgs, block_q, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, None] + if save_residuals: + scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, args), - qo_scratch, + lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, plgpu.Barrier(1, num_barriers=compute_wgs), plgpu.Barrier(num_arrivals=compute_wgs), ) @jax.jit - def run_function(q, k, v, o): - _, _, _, out = pl.run_state(run)((q, k, v, o)) - return out - out = run_function(q, k, v, jnp.full_like(q, jnp.inf)) + def run_function(q, k, v, o, lse): + _, _, _, out, lse = pl.run_state(run)((q, k, v, o, lse)) + return out, lse + + lse = ( + jnp.full((batch_size, num_q_heads, q_seq_len), -jnp.inf, dtype=jnp.float32) + if save_residuals + else None + ) + out, lse = run_function(q, k, v, jnp.full_like(q, jnp.inf), lse) + + if save_residuals: + assert lse is not None + return out, (lse,) + return out -@jax.jit -def attention_reference(q, k, v): +@functools.partial(jax.jit, static_argnames=["save_residuals"]) +def attention_reference(q, k, v, save_residuals=False): batch_size, q_seq_len, num_q_heads, head_dim = q.shape num_kv_heads = k.shape[2] q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) @@ -407,8 +457,16 @@ def attention_reference(q, k, v): unnormalized = jnp.exp(logits - m) l = unnormalized.sum(axis=-1, keepdims=True) weights = unnormalized / l - return jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) - + out = jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) + + if save_residuals: + log2e = math.log2(math.e) + l = l.reshape(*q.shape[:-1]) + m = m.reshape(*q.shape[:-1]) + lse = m * log2e + jnp.log2(l) + return out, (lse.swapaxes(-1, -2),) + else: + return out def main(unused_argv): num_q_heads = 16 diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index cf8ed30925bf..27588683d0e9 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -62,6 +62,7 @@ def setUp(self): attention_mgpu.attention, attention_mgpu.attention_with_pipeline_emitter, ), + save_residuals=(True,), ) def test_flash_attention( self, @@ -71,22 +72,28 @@ def test_flash_attention( num_q_and_kv_heads, head_dim, attention_impl, + save_residuals, ): num_q_heads, num_kv_heads = num_q_and_kv_heads k1, k2, k3 = 
jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) - out = attention_impl( + out, *res = attention_impl( q, k, v, attention_mgpu.TuningConfig( block_q=64, block_kv=64, max_concurrent_steps=2 ), + save_residuals=save_residuals, ) - out_ref = attention_mgpu.attention_reference(q, k, v) + out_ref, *res_ref = attention_mgpu.attention_reference(q, k, v, save_residuals=save_residuals) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + if save_residuals: + (lse,) = res[0] + (lse_ref,) = res_ref[0] + np.testing.assert_allclose(lse, lse_ref, atol=2e-3, rtol=1e-3) if __name__ == "__main__": From 12b1a99ad943de56f776d9e18bcdfb351908927c Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 4 Apr 2025 11:35:46 +0500 Subject: [PATCH 0379/1769] fix(docs): corrected the name of the function call in the document --- docs/gradient-checkpointing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 0938a5da944f..e4e842df49f0 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -341,7 +341,7 @@ def predict(params, x): return x ``` -By itself, {func}`jax.ad_checkpoint import.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint import.checkpoint_name` are considered saveable: +By itself, {func}`jax.ad_checkpoint.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint.checkpoint_name` are considered saveable: ```{code-cell} print_saved_residuals(loss, params, x, y) From e619fc0b72570cc4b8fe305ccc70c75f27c1a52f Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 00:02:47 -0700 Subject: [PATCH 0380/1769] Avoid double buffering when no windowing info is present. PiperOrigin-RevId: 743834475 --- jax/_src/pallas/mosaic/lowering.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 0ca298c88dc5..67cacbc8dcf9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -743,6 +743,15 @@ def dynamic_shape_replacement_fn( block_shape = [ 1 if b is pallas_core.mapped else b for b in bm.block_shape ] + + # No sense in double-buffering without any windowing pattern. + buffer_count = 0 + if ( + tpu_memory_space == tpu_core.TPUMemorySpace.VMEM + and bm.has_trivial_window() + ): + buffer_count = 1 + # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval)) @@ -765,7 +774,8 @@ def dynamic_shape_replacement_fn( raise LoweringException( f"Unsupported pipeline mode: {bm.pipeline_mode}." ) - buffer_count = bm.pipeline_mode.buffer_count + if buffer_count == 0: + buffer_count = bm.pipeline_mode.buffer_count if buffer_count < 1 or buffer_count > 2: raise LoweringException( "Only single (1) and double (2) buffering are supported. 
Got" From 1b63d5e26f72cee57c72e8abee842b7a5ee35405 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 3 Apr 2025 16:26:53 +0000 Subject: [PATCH 0381/1769] Fixed deadlock in NamedSharding ctor Description: - Test timeouts were seen in the ColocatedPythonTest test case - GDB report: https://gist.github.com/vfdev-5/d64183f7b5dde3e666eea6cd61670128 --- jaxlib/xla/sharding.cc | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 5a80c03e01da..858c025745e4 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -223,9 +223,22 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, // TODO(phawkins): this leaks a reference to the check_pspec function. // A better way to fix this would be to move PartitionSpec and this check into // C++. - static nb::object* check_pspec = []() { + nb::object* check_pspec = [](){ + static absl::Mutex mu; + static nb::object* output = nullptr; + { + absl::MutexLock lock(&mu); + if (output) { + return output; + } + } nb::module_ si = nb::module_::import_("jax._src.named_sharding"); - return new nb::object(si.attr("check_pspec")); + nb::object attr = si.attr("check_pspec"); + absl::MutexLock lock(&mu); + if (!output) { + output = new nb::object(attr); + } + return output; }(); (*check_pspec)(mesh_, spec_, manual_axes_); } From 206dec859d30e10970e60664802295ec1737131c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 4 Apr 2025 02:33:24 -0700 Subject: [PATCH 0382/1769] [pallas:mosaic_gpu] Added pretty printing to primitives consuming refs I also changed existing pretty printers for transforms to use {} instead of [], so that transforms are visually distinct from slicing. PiperOrigin-RevId: 743869470 --- jax/_src/pallas/mosaic_gpu/BUILD | 3 +- jax/_src/pallas/mosaic_gpu/core.py | 10 ++ jax/_src/pallas/mosaic_gpu/primitives.py | 181 +++++++++++++++++++++++ jax/_src/state/indexing.py | 34 +++++ jax/_src/state/primitives.py | 68 +-------- jax/_src/state/types.py | 11 ++ 6 files changed, 239 insertions(+), 68 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 33883326e58c..554b9db878f6 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -78,6 +78,7 @@ pytype_strict_library( "//jax:dtypes", "//jax:effects", "//jax:mosaic_gpu", + "//jax:pretty_printer", "//jax:state_types", "//jax:tree_util", "//jax/_src/lib", @@ -94,8 +95,8 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", - "//jax:mlir", "//jax:mosaic_gpu", + "//jax:pretty_printer", "//jax:tree_util", "//jax:util", "//jax/_src/lib", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 0a949840ab62..2150b48b5108 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -29,6 +29,7 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util +from jax._src import pretty_printer as pp from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as pallas_primitives @@ -328,6 +329,9 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{untile({list(self.tiling)})}}") + def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: inverse = [-1] * len(permutation) @@ -406,6
+410,9 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{transpose({list(self.permutation)})}}") + def transform_ref( ref: pallas_core.TransformedRef, @@ -517,6 +524,9 @@ def untransform_index( raise ValueError("Swizzled dims cannot be sliced") return idxs, self + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{unswizzle({self.swizzle})}}") + @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 37d71cd6d1c6..b909a31496bf 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -25,6 +25,7 @@ import jax from jax._src import core as jax_core +from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util from jax._src import util @@ -172,6 +173,40 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} +def _copy_smem_to_gmem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + pp_params = {} + if not (commit_group := eqn.params["commit_group"]): + pp_params["commit_group"] = commit_group + if has_user_predicate := eqn.params["has_user_predicate"]: + pp_params["has_user_predicate"] = has_user_predicate + if reduction_op := eqn.params["reduction_op"]: + pp_params["reduction_op"] = reduction_op + flat_src_transforms, flat_dst_transforms = util.split_list( + flat_args, + [src_transforms_treedef.num_leaves], + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + return pp.concat([ + pp.text("copy_smem_to_gmem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_smem_to_gmem_p] = _copy_smem_to_gmem_pp_eqn + + @lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup @@ -355,6 +390,47 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} +def _copy_gmem_to_smem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, barrier, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + barrier_transforms_treedef = eqn.params["barrier_transforms_treedef"] + pp_params = {} + if collective_axes := eqn.params["collective_axes"]: + pp_params["collective_axes"] = collective_axes + flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( + util.split_list( + flat_args, + [ + src_transforms_treedef.num_leaves, + dst_transforms_treedef.num_leaves, + ], + ) + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + 
dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + barrier_transforms = barrier_transforms_treedef.unflatten( + flat_barrier_transforms + ) + return pp.concat([ + pp.text("copy_gmem_to_smem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + pp.text(" using "), + state_primitives.pp_ref_transforms(context, barrier, barrier_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_gmem_to_smem_p] = _copy_gmem_to_smem_pp_eqn + + @lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup @@ -521,6 +597,25 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} +def _barrier_arrive_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_tree"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_arrive"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_arrive_p] = _barrier_arrive_pp_eqn + + @lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Warpgroup) def _barrier_arrive_lowering( @@ -560,6 +655,25 @@ def _barrier_wait_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} +def _barrier_wait_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_wait"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_wait_p] = _barrier_wait_pp_eqn + + @lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup) def _barrier_wait_lowering( @@ -715,6 +829,39 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): } +def _wgmma_ref_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + acc, a, b, *leaves = eqn.invars + a_transforms_treedef = eqn.params["a_transforms_tree"] + b_transforms_treedef = eqn.params["b_transforms_tree"] + a_transforms = ( + a_transforms_treedef.unflatten(leaves[: a_transforms_treedef.num_leaves]) + if a_transforms_treedef is not None + else [] + ) + b_transforms = ( + b_transforms_treedef.unflatten(leaves[a_transforms_treedef.num_leaves :]) + if b_transforms_treedef is not None + else [] + ) + return pp.concat([ + pp.text("wgmma_ref"), + pp.text(" "), + pp.text(jax_core.pp_var(acc, context)), + pp.text(" <- "), + state_primitives.pp_ref_transforms(context, a, a_transforms), + pp.text(" @ "), + state_primitives.pp_ref_transforms(context, b, b_transforms), + ]) + + +jax_core.pp_eqn_rules[wgmma_ref_p] = _wgmma_ref_pp_eqn + + @discharge.register_discharge_rule(wgmma_ref_p) def _wgmma_ref_discharge(in_avals, 
out_avals, *args, **kwargs): del in_avals, out_avals @@ -1090,6 +1237,40 @@ def _jaxpr_call_abstract_eval(*args, jaxpr: jax_core.Jaxpr, **params): return [v.aval for v in jaxpr.outvars] +def _jaxpr_call_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + flat_args = eqn.invars + ref_treedefs = eqn.params["ref_treedefs"] + flat_refs, _ = util.split_list( + flat_args, [sum(treedef.num_leaves for treedef in ref_treedefs)] + ) + flat_refs = util.split_list( + flat_refs, + [treedef.num_leaves for treedef in ref_treedefs[: len(ref_treedefs) - 1]], + ) + trailer = [] + for treedef, flat_ref in zip(ref_treedefs, flat_refs): + ref = treedef.unflatten(flat_ref) + transforms = [] + if isinstance(ref, tuple): + ref, transforms = ref + trailer.append(pp.text(" ")) + trailer.append(state_primitives.pp_ref_transforms(context, ref, transforms)) + return pp.concat([ + pp.text("jaxpr_call"), + pp.text("["), + jax_core.pp_kv_pair("jaxpr", eqn.params["jaxpr"], context, settings), + pp.text("]"), + pp.concat(trailer), + ]) + + +jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn + + @lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Warpgroup) def _jaxpr_call_lowering_rule( diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 4b627c1cd581..e7b581680efe 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -20,6 +20,7 @@ from typing import Any, Sequence, Union from jax._src import core +from jax._src import pretty_printer as pp from jax._src import tree_util from jax._src.typing import Array from jax._src.util import merge_lists @@ -78,6 +79,30 @@ def from_slice(cls, slc: slice, size: int) -> Slice: return cls(start, size, step) +def _pp_slice(context: core.JaxprPpContext, dim, slc: Slice) -> str: + start, size = slc.start, slc.size + if isinstance(start, core.Var): + start_str = core.pp_var(start, context) + size_str = ( + core.pp_var(size, context) if isinstance(size, core.Var) else str(size) + ) + return f"{start_str}:{start_str}+{size_str}" + else: + start_str = str(start) + if start == 0: + start_str = "" + if isinstance(size, core.Var): + size_str = core.pp_var(size, context) + if start_str: + return f"{start_str}:{start_str}+{size_str}" + else: + return f":{size_str}" + else: + end = start + size + end_str = "" if end == dim else str(end) + return f"{start_str}:{end_str}" + + def dslice( start: int | Array | None, size: int | Array | None = None, @@ -282,3 +307,12 @@ def transform_sharding(self, sharding): f"along unsharded axes, but ref of shape {self.shape} " f"was sliced on axis {i}, which is sharded like {s}") return sharding + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + indices = [] + for idx, dim in zip(self.indices, self.shape): + if isinstance(idx, Slice): + indices.append(_pp_slice(context, dim, idx)) + else: + indices.append(core.pp_var(idx, context)) # type: ignore + return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")]) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 6f7570a5f3cd..f992f96992da 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -34,8 +34,6 @@ AbstractRef, AccumEffect, ReadEffect, - RefBitcaster, - RefReshaper, Transform, TransformedRef, WriteEffect, @@ -297,70 +295,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, pp_ref_var = partial(pp.color, 
intensity=pp.Intensity.NORMAL, foreground=pp.Color.GREEN) -def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice - ) -> str: - start, size = slc.start, slc.size - if isinstance(start, core.Var): - start_str = core.pp_var(start, context) - size_str = ( - core.pp_var(size, context) - if isinstance(size, core.Var) - else str(size) - ) - return f'{start_str}:{start_str}+{size_str}' - else: - start_str = str(start) - if start == 0: - start_str = '' - if isinstance(size, core.Var): - size_str = core.pp_var(size, context) - if start_str: - return f'{start_str}:{start_str}+{size_str}' - else: - return f':{size_str}' - else: - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' - -def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer - ) -> pp.Doc: - indices = [] - for idx, dim in zip(indexer.indices, indexer.shape): - if isinstance(idx, indexing.Slice): - indices.append(_pp_slice(context, dim, idx)) - else: - indices.append(core.pp_var(idx, context)) # type: ignore - return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) - - -def pp_bitcaster( - context: core.JaxprPpContext, bitcaster: RefBitcaster -) -> pp.Doc: - del context - return pp.text( - f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]" - ) - - -def pp_reshaper(context: core.JaxprPpContext, reshaper: RefReshaper) -> pp.Doc: - del context - return pp.text( - f"[reshape({reshaper.dtype}[{','.join(str(d) for d in reshaper.shape)}])]" - ) - - -def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: - match transform: - case indexing.NDIndexer(): - return pp_indexer(context, transform) - case RefBitcaster(): - return pp_bitcaster(context, transform) - case RefReshaper(): - return pp_reshaper(context, transform) - case _: - return pp.text(f"[{transform}]") - def _pp_transforms( context: core.JaxprPpContext, @@ -369,7 +303,7 @@ if not transforms: return pp.text("[...]") return pp.concat( - [pp_transform(context, transform) for transform in transforms] + [transform.pretty_print(context) for transform in transforms] ) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index b9dbaf35c5d2..1acb856fd1ba 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -125,6 +125,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{bitcast({self.dtype}{list(self.shape)})}}") + + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -178,6 +182,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused.
+ return pp.text(f"{{reshape({self.dtype}{list(self.shape)})}}") + class Transform(Protocol): @@ -205,6 +213,9 @@ def transform_sharding(self, sharding): if all(p is None for p in sharding.spec): return sharding # no explicit axes raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{{self}}}") + @dataclasses.dataclass class RefIndexer: From b0a920dd92480962ecfb1fa55232fa2c0e584038 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 4 Apr 2025 05:11:10 -0700 Subject: [PATCH 0383/1769] [Mosaic GPU] Don't force TiledLayout.lane_dims to partition data This allows us to replicate elements across a warp and replace the special WGMMARowFragLayout with a TiledLayout. PiperOrigin-RevId: 743903003 --- jax/_src/pallas/mosaic_gpu/lowering.py | 9 +- jax/_src/pallas/mosaic_gpu/primitives.py | 49 ++-- jax/experimental/mosaic/gpu/__init__.py | 2 +- .../mosaic/gpu/fragmented_array.py | 231 +++++++++--------- jax/experimental/mosaic/gpu/layouts.py | 4 - jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 12 - tests/mosaic/gpu_test.py | 16 +- tests/pallas/mosaic_gpu_test.py | 4 +- 8 files changed, 163 insertions(+), 164 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 80757ef69e64..d26c71cecc31 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1202,9 +1202,12 @@ def _swap_lowering_rule( return old_value case (): match value.layout: - case mgpu.WGMMARowFragLayout(): - old_value = mgpu.FragmentedArray.load_wgmma_row( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + case mgpu.TiledLayout(): + old_value = mgpu.FragmentedArray.load_untiled( + x_smem, + layout=value.layout, + is_signed=mgpu_utils.is_signed(x_aval.dtype), + optimized=False, ) value.store_untiled(x_smem) return old_value diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b909a31496bf..76759d6bcb83 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -67,9 +67,8 @@ def _check_ref( load_p = jax_core.Primitive("load") @load_p.def_effectful_abstract_eval -def _load_abstract_eval(src, *avals_flat, args_tree, layout): - del layout # Unused. - +def _load_abstract_eval(src, *avals_flat, args_tree, layout, optimized): + del layout, optimized # Unused.
transforms = args_tree.unflatten(avals_flat) return ( jax_core.ShapedArray(transforms[-1].get_indexer_shape(), src.dtype), @@ -78,7 +77,7 @@ def _load_abstract_eval(src, *avals_flat, args_tree, layout): @lowering.register_lowering_rule(load_p, mgpu.ThreadSemantics.Lane) def _load_p_lowering_rule( - ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout + ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout, optimized ): if not isinstance(x_ref, ir.Value) or not ir.MemRefType.isinstance(x_ref.type): raise TypeError(f"Can only load from references (got {x_ref}).") @@ -91,29 +90,36 @@ def _load_p_lowering_rule( if layout is not None: layout = layout.to_mgpu() + is_signed = mgpu_utils.is_signed(x_aval.dtype) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle, - layout=layout + x_ref, + is_signed=is_signed, + swizzle=swizzle, + layout=layout, ) case (): # Handle scalar indexing. if not ctx.avals_out[0].shape: is_signed = mgpu_utils.is_signed(x_aval.dtype) val = memref_dialect.load(x_ref, []) - return mgpu.FragmentedArray.splat(val, shape=(), layout=layout, is_signed=is_signed) + return mgpu.FragmentedArray.splat( + val, shape=(), layout=layout, is_signed=is_signed + ) match layout: - case mgpu.WGMMARowFragLayout(): - return mgpu.FragmentedArray.load_wgmma_row( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + case mgpu.WGMMA_ROW_LAYOUT: + return mgpu.FragmentedArray.load_untiled( + x_ref, + is_signed=is_signed, + layout=layout, + swizzle=16, + optimized=optimized, ) case mgpu.WGMMAColFragLayout(): - return mgpu.FragmentedArray.load_wgmma_col( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) + return mgpu.FragmentedArray.load_wgmma_col(x_ref, is_signed=is_signed) case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): ref_ty = ir.MemRefType(x_ref.type) if shape != tuple(ref_ty.shape): @@ -122,12 +128,10 @@ def _load_p_lowering_rule( ) return mgpu.FragmentedArray.load_strided( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), vec_size=vec_size, + x_ref, is_signed=is_signed, vec_size=vec_size, ) case None: - return mgpu.FragmentedArray.load_strided( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) + return mgpu.FragmentedArray.load_strided(x_ref, is_signed=is_signed) case _: raise NotImplementedError(f"Unsupported layout: {layout}") case _: @@ -135,7 +139,11 @@ def _load_p_lowering_rule( def load( - src: _Ref, idx, *, layout: Layout | ParameterizedLayout | None = None + src: _Ref, + idx, + *, + layout: Layout | ParameterizedLayout | None = None, + optimized: bool = True, ) -> jax.Array: """Loads from a reference into an array with the specified layout. @@ -143,6 +151,8 @@ def load( src: The reference to load from. Can be either in SMEM or GMEM. idx: The index to load from. layout: The optional layout to use for the resulting array. + optimized: If True, a compilation error will be raised if no optimized + implementation for the load is available. Returns: The loaded array. 
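(Editorial aside, not part of the diff: a sketch of how a kernel might use the new `optimized` flag; the kernel body is illustrative, while `plgpu.load`, `plgpu.Layout.WGMMA_ROW`, and `optimized` are the names this patch introduces or touches.)

```python
def kernel(x_ref, o_ref):
  # optimized=False opts into the trivial transfer plan instead of
  # requiring a bank-conflict-free one, at the cost of slower copies.
  row = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA_ROW, optimized=False)
  o_ref[...] = row
```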
@@ -157,7 +167,8 @@ def load( src, *flat_src_transforms, args_tree=src_transforms_treedef, - layout=layout + layout=layout, + optimized=optimized, ) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index e645115940e4..d5b3bea6d36b 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -53,11 +53,11 @@ from .fragmented_array import ( FragmentedArray as FragmentedArray, FragmentedLayout as FragmentedLayout, + TiledLayout as TiledLayout, WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, - WGMMARowFragLayout as WGMMARowFragLayout, WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index b730e34e2ed0..f7ce36c62c9e 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -202,6 +202,11 @@ def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: yield i - offset, e +@dataclasses.dataclass(frozen=True) +class Replicated: + times: int + + @dataclasses.dataclass(frozen=True) class TiledLayout: """A FragmentedArray layout derived from a tiling expression. @@ -248,7 +253,7 @@ class TiledLayout: """ tiling: Tiling warp_dim: int - lane_dims: tuple[int, ...] # major-to-minor + lane_dims: tuple[int | Replicated, ...] # major-to-minor vector_dim: int def __post_init__(self): @@ -256,8 +261,8 @@ def __post_init__(self): raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) - dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} - if len(dims_set) != len(self.lane_dims) + 2: + dims_set = {self.warp_dim, *self.partitioned_lane_dims, self.vector_dim} + if len(dims_set) != len(self.partitioned_lane_dims) + 2: raise ValueError for d in dims_set: if d >= 0: @@ -266,9 +271,19 @@ def __post_init__(self): raise ValueError("Dimension out of range") if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: raise ValueError - if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: + lane_dims_prod = math.prod( + d.times if isinstance(d, Replicated) else min_tiled_shape[d] + for d in self.lane_dims + ) + if lane_dims_prod != WARP_SIZE: raise ValueError + @functools.cached_property + def partitioned_lane_dims(self) -> tuple[int, ...]: + return tuple( + d for d in self.lane_dims if not isinstance(d, Replicated) + ) + def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]: # We first find the linear index and then divide by the shape to # get the index. 
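(Editorial aside, not part of the diff: the lane arithmetic behind the new `WGMMA_ROW_LAYOUT`, checked in plain Python under the assumption that tiling `((64,), (16,), (8,), (1,))` over a 64-row tile yields the tiled shape `(1, 4, 2, 8, 1)`.)

```python
WARP_SIZE, WARPS_IN_WARPGROUP = 32, 4

tiled_shape = (1, 4, 2, 8, 1)  # 64 rows tiled by 64 -> 16 -> 8 -> 1
# warp_dim=-4 picks the 4 warps of a warpgroup; lane_dims=(-2, Replicated(4))
# must multiply out to a full warp, with Replicated(4) partitioning no data.
assert tiled_shape[-4] == WARPS_IN_WARPGROUP
assert tiled_shape[-2] * 4 == WARP_SIZE  # 8 partitioned rows x 4 replicated lanes
```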
@@ -326,7 +341,7 @@ def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) tiled_shape[self.warp_dim] = 1 - for d in self.lane_dims: + for d in self.partitioned_lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 return tuple(tiled_shape) @@ -339,15 +354,18 @@ def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: tiled_tiling = self.tiled_tiling_shape shape = list(shape) shape[self.warp_dim] = WARPS_IN_WARPGROUP - for d in self.lane_dims: + for d in self.partitioned_lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] return self.tiling.untile_shape(tuple(shape)) - def lane_indices(self) -> tuple[ir.Value, ...]: + def _full_lane_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape = self.tiled_tiling_shape - lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims) + lanes_shape = tuple( + d.times if isinstance(d, Replicated) else tiled_shape[d] + for d in self.lane_dims + ) assert math.prod(lanes_shape) == WARP_SIZE lane_strides = utils.get_contiguous_strides(lanes_shape) lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) @@ -355,8 +373,16 @@ def lane_indices(self) -> tuple[ir.Value, ...]: arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) for stride, size in zip(lane_strides, lanes_shape) ) + return lane_indices + + def lane_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = self.tiled_tiling_shape + lane_indices = self._full_lane_indices() full_indices = [arith.constant(i32, 0)] * len(tiled_shape) for d, i in zip(self.lane_dims, lane_indices): + if isinstance(d, Replicated): + continue full_indices[d] = i return tuple(full_indices) @@ -385,41 +411,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -@dataclasses.dataclass(frozen=True) -class WGMMARowFragLayout: - """[m] matrix, where m % 64 == 0.""" - - def registers_element_type(self, t: ir.Type) -> ir.Type: - return t - - def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: - """Returns the shape of the register array needed to represent an array of the given logical shape.""" - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - return (shape[0] // 64, 2) - - def thread_idxs(self, shape): - index = ir.IndexType.get() - assert len(shape) == 1 - assert shape[0] % 64 == 0 - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) - warp_idx = arith.divui(tid_wg, c(32, index)) - lane_id = arith.remui(tid_wg, c(32, index)) - row_base = arith.addi( - arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index)) - ) - - for row_group in range(0, shape[0], 64): - for row_subgroup in (0, 8): - row = arith.addi(row_base, c(row_group + row_subgroup, index)) - yield (row,) - - @dataclasses.dataclass(frozen=True) class WGMMAColFragLayout: """[n] matrix, where n % 8 == 0.""" @@ -547,11 +538,16 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | WGMMAColFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | 
WGStridedFragLayout | WGMMAColFragLayout | TiledLayout -WGMMA_ROW_LAYOUT = WGMMARowFragLayout() WGMMA_COL_LAYOUT = WGMMAColFragLayout() +WGMMA_ROW_LAYOUT = TiledLayout( + Tiling(((64,), (16,), (8,), (1,))), + warp_dim=-4, + lane_dims=(-2, Replicated(4)), + vector_dim=-1, +) # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d @@ -663,12 +659,6 @@ def __init__( ) match self.layout: - # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout - # Each element is a dtype scalar - case WGMMARowFragLayout(): - if _registers.ndim != 2 or _registers.shape[-1] != 2: - raise ValueError(f"Invalid register array shape: {_registers.shape}") - # Registers are [n_tiles] in WGMMA_COL layout # Each element is a vector of size 2. case WGMMAColFragLayout(): @@ -731,30 +721,6 @@ def load_strided( vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) - @classmethod - def load_wgmma_row( - cls, - ref: ir.Value, - *, - is_signed: bool | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): - raise TypeError(ref.type) - - ref_ty = ir.MemRefType(ref.type) - shape = tuple(ref_ty.shape) - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - - layout = WGMMARowFragLayout() - registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)] - registers = np.array(registers).reshape(-1, 2) - return cls(_registers=registers, _layout=layout, _is_signed=is_signed) - @classmethod def load_wgmma_col( cls, @@ -790,7 +756,7 @@ def load_wgmma_col( def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): pass case WGStridedFragLayout() | TiledLayout(): value = vector.splat(layout.registers_element_type(value.type), value) @@ -806,9 +772,6 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @property def shape(self): match self.layout: - case WGMMARowFragLayout(): - row_tiles = self.registers.shape[0] - return (row_tiles * 64,) case WGMMAColFragLayout(): col_tiles = self.registers.shape[0] return (col_tiles * 8,) @@ -827,7 +790,7 @@ def mlir_dtype(self): match self.layout: case WGStridedFragLayout() | WGMMAColFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): return reg_ty case _: raise NotImplementedError @@ -1589,7 +1552,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape upcast_ty = ir.VectorType.get(shape, larger_ty) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): upcast_ty = larger_ty case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1614,7 +1577,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape new_reg_ty = ir.VectorType.get(shape, new_dtype) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): new_reg_ty = new_dtype case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1713,9 +1676,9 @@ def reduce(self, op: str | 
Callable[[ir.Value, ir.Value], ir.Value], axis): i32 = ir.IntegerType.get_signless(32) row_tile_dim = self.registers.shape[0] row_subtile_dim = self.registers.shape[4] - new_regs = np.empty((row_tile_dim, row_subtile_dim), dtype=object) + new_regs = np.empty((row_tile_dim, 1, row_subtile_dim, 1, 1), dtype=object) assert self.registers.shape[-1] == 1 - for row_tile, row_subtile in np.ndindex(new_regs.shape): + for row_tile, row_subtile in np.ndindex(row_tile_dim, row_subtile_dim): # Reduce the registers owned by the current thread over n tiles reg_index = [0] * self.registers.ndim reg_index[0] = row_tile @@ -1746,7 +1709,9 @@ def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): nvvm.ShflKind.bfly, ) result = op(result, other_result) - new_regs[row_tile, row_subtile] = result + new_regs[row_tile, :, row_subtile] = vector.splat( + ir.VectorType.get((1,), self.mlir_dtype), result + ) return FragmentedArray( _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed ) @@ -1791,12 +1756,14 @@ def broadcast_minor(self, n): reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n)) new_regs = np.empty(reg_shape, dtype=object) dtype = self.mlir_dtype - for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): + i0 = arith.constant(ir.IndexType.get(), 0) + for (row_tile, _, row_subtile, *__), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[0] = row_tile tile[4] = row_subtile new_regs[tuple(tile)] = vector.splat( - ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), reg + ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), + vector.extractelement(reg, position=i0), ) return FragmentedArray( _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed @@ -1874,8 +1841,6 @@ def vs_unsupported(): ) match self.layout: - case WGMMARowFragLayout(): - self._store_untiled_wgmma_row(ref) case WGMMAColFragLayout(): self._store_untiled_wgmma_col(ref) case WGSplatFragLayout(): @@ -1889,6 +1854,22 @@ def vs_unsupported(): case _: raise NotImplementedError(self.layout) + @classmethod + def load_untiled( + cls, + ref: ir.Value, + *, + layout: TiledLayout, + swizzle: int = 16, + is_signed: bool | None = None, + optimized: bool = True, + ): + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + return cls.load_tiled( + ref, swizzle=swizzle, is_signed=is_signed, layout=layout, optimized=optimized + ) + def _store_untiled_splat(self, ref: ir.Value): vec_size = 64 // mgpu.bitwidth(self.mlir_dtype) if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: @@ -1924,23 +1905,6 @@ def _store_untiled_wg_strided(self, ref: ir.Value): for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_wgmma_row(self, ref: ir.Value): - """Stores an array with a WGMMA row layout.""" - assert self.layout == WGMMA_ROW_LAYOUT - index = ir.IndexType.get() - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - - is_first = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index) - ) - # Consecutive groups of 4 threads hold the same value in this layout, - # therefore we only need to transfer data from one of them. 
- with utils.when(is_first): - for (idx,), value in zip( - self.layout.thread_idxs(self.shape), self.registers.flatten() - ): - memref.store(value, ref, [idx]) - def _store_untiled_wgmma_col(self, ref: ir.Value): """Stores an array with a WGMMA col layout.""" assert isinstance(self.layout, WGMMAColFragLayout) @@ -2007,6 +1971,9 @@ def store_tiled(self, ref, swizzle: int | None): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape + # Note that the loop below will "race" for layouts that replicate data. + # However, in that case all of the racing writes store the same data, which + # is ok in the CUDA memory model. for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): llvm.store(get(self.registers), ptr) @@ -2018,6 +1985,7 @@ def load_tiled( *, is_signed: bool | None = None, layout: FragmentedLayout = WGMMA_LAYOUT, + optimized: bool = True, ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type @@ -2036,7 +2004,8 @@ def load_tiled( ) registers = np.full(layout.registers_shape(shape), zero, dtype=object) reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) - for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): + loads = cls.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for _, update, ptr in loads: update(registers, llvm.load(reg_ty, ptr)) case _: raise NotImplementedError(layout) @@ -2132,6 +2101,7 @@ def transfer_tiled2( swizzle: int | None, layout: TiledLayout, shape: tuple[int, ...], + optimized: bool = True, ): """Generate a transfer schedule for a tiled layout. @@ -2183,11 +2153,15 @@ def transfer_tiled2( raise NotImplementedError("Memory and register tiling incompatible") tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) - elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] - lane_shape = [tiled_shape[d] for d in layout.lane_dims] + lane_shape = [ + d.times if isinstance(d, Replicated) else tiled_shape[d] for d in layout.lane_dims + ] + lane_strides = [ + 0 if isinstance(d, Replicated) else elem_tiled_strides[d] for d in layout.lane_dims + ] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") - for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): + for d in (layout.warp_dim, *layout.partitioned_lane_dims, layout.vector_dim): tiled_shape[d] = 1 element_bits = mgpu.bitwidth(dtype) @@ -2223,10 +2197,22 @@ def transfer_tiled2( transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides] transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) - plan = plan_tiled_transfer( - tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout, - element_bits, swizzle - ) + if ref_ty.memory_space is None: + llvm_memory_space = None + elif ref_ty.memory_space == ir.Attribute.parse("#gpu.address_space"): + llvm_memory_space = 3 + else: + raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}") + + if optimized: + if llvm_memory_space != 3: + raise NotImplementedError("Only optimized transfers to SMEM supported") + plan = plan_tiled_transfer( + tiled_shape, elem_tiled_strides, lane_shape, lane_strides, + layout, element_bits, swizzle + ) + else: + plan = TrivialTransferPlan() # All offsets are in units of transfer_dtype. 
dyn_tiled_strides = [ @@ -2235,9 +2221,7 @@ def transfer_tiled2( lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) - if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError("Tiled stores can be performed into SMEM") - ptr = utils.memref_ptr(ref, memory_space=3) + ptr = utils.memref_ptr(ref, memory_space=llvm_memory_space) _as_consts = lambda consts: [c(const) for const in consts.tolist()] # This has bits set only for the offset bits that influence swizzling. swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers @@ -2416,9 +2400,18 @@ def plan_tiled_transfer( num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) wavefront_lanes = WARP_SIZE // num_wavefronts + lane_mask = np.full(lane_shape, False) + lane_mask[tuple(slice(0, 1) if s == 0 else slice(None) for s in lane_strides)] = True + wavefront_mask = lane_mask.reshape(num_wavefronts, wavefront_lanes) + wavefront_active_lanes = wavefront_mask.sum(-1) + # We make a simplifying assumption: wavefronts have the same number of lanes + if any(act != wavefront_active_lanes[0] for act in wavefront_active_lanes): + raise NotImplementedError + lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): - tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) + num_tiles = math.prod(tiled_shape) + tile_idxs = np.unravel_index(np.arange(num_tiles), tiled_shape) tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] assert lane_tile_idx.shape[1] in {1, WARP_SIZE} @@ -2429,6 +2422,8 @@ def has_bank_conflicts(tile_idx_transform): swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) + # Mask out the inactive lanes in each wavefront + wavefront_banks = wavefront_banks[:, wavefront_mask].reshape(num_tiles, num_wavefronts, -1) # Order of threads within the wavefront is unimportant. wavefront_banks = np.sort(wavefront_banks, axis=-1) # There are no conflicts if each wavefront only contains unique banks. 
diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 5c3b23119779..cb94c3eaf749 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -155,7 +155,6 @@ def to_layout_attr( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ), ) -> ir.Attribute: """Constructs an MLIR attribute that corresponds to the given layout.""" @@ -166,8 +165,6 @@ def to_layout_attr( return to_strided_fragmented_layout_attr(layout) case fa.TiledLayout(): return to_tiled_layout_attr(layout) - case fa.WGMMARowFragLayout(): - return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout") case _: raise NotImplementedError( f"Unsupported layout for conversion to MLIR attribute: {layout}" @@ -189,7 +186,6 @@ def from_layout_attr( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ): """Constructs a layout from an MLIR attribute.""" if is_splat_fragmented_layout(attr): diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index cda521855250..6b934b951d93 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -142,18 +142,6 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { - let summary = "1D array that is a row that can be tiled by supported WGMMA shapes."; - let description = [{ - This layout is used to handle rows that are fragmented across all threads - in a warpgroup that is executing a WGMMA operation. The length of the array - must be divisible by 64. - }]; - - let mnemonic = "WGMMARowFragLayout"; - let assemblyFormat = ""; -} - def MosaicGPU_TiledLayout : AttrDef { let summary = "A layout derived from a tiling expression."; let description = [{ diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 6c63e3ce40e1..cc0e6a04bdc3 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2004,12 +2004,16 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) - @parameterized.product(in_shape=((128,), (64,))) - def test_wgmma_row_load_store_with_layout(self, in_shape): - def kernel(ctx, *args): - gmem_input, gmem_output, (smem_input, smem_output) = args - copy(gmem_input, smem_input) - t = mgpu.FragmentedArray.load_wgmma_row(smem_input) + @parameterized.product( + in_shape=((1024,), (256,), (128,), (64,)), swizzle=(16, 32, 64, 128) + ) + def test_wgmma_row_load_store_with_layout(self, in_shape, swizzle): + def kernel(ctx, gmem_input, gmem_output, smem): + smem_input, smem_output = smem + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, layout=mgpu.WGMMA_ROW_LAYOUT, swizzle=swizzle + ) t.store_untiled(smem_output) copy(smem_output, gmem_output) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0cfe9197db36..67472dbdd9e5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -770,7 +770,9 @@ def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layo ) def kernel(x_ref, o_ref): for i in range(2): - x = plgpu.load(x_ref, (i,), layout=layout) + x = plgpu.load( + x_ref, (i,), layout=layout, optimized=src_memory_space != plgpu.GMEM + ) o_ref[i, ...] 
= x x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) From 635805e9b02a9b400b54efe0e766964278e178dd Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 4 Apr 2025 05:43:03 -0700 Subject: [PATCH 0384/1769] [Mosaic GPU] Allow replicating data over warps This extends the tiled layouts further and allows us to replace WGMMA_COL_LAYOUT implementation with a TiledLayout. PiperOrigin-RevId: 743909503 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 - jax/_src/pallas/mosaic_gpu/primitives.py | 5 +- jax/experimental/mosaic/gpu/__init__.py | 1 - .../mosaic/gpu/fragmented_array.py | 146 +++++------------- tests/mosaic/gpu_test.py | 31 ++-- 5 files changed, 61 insertions(+), 128 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index d26c71cecc31..827794d37e2b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1211,12 +1211,6 @@ def _swap_lowering_rule( ) value.store_untiled(x_smem) return old_value - case mgpu.WGMMAColFragLayout(): - old_value = mgpu.FragmentedArray.load_wgmma_col( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value case _: old_value = mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 76759d6bcb83..b2beec700fad 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -110,7 +110,7 @@ def _load_p_lowering_rule( val, shape=(), layout=layout, is_signed=is_signed ) match layout: - case mgpu.WGMMA_ROW_LAYOUT: + case mgpu.WGMMA_ROW_LAYOUT | mgpu.WGMMA_COL_LAYOUT: return mgpu.FragmentedArray.load_untiled( x_ref, is_signed=is_signed, @@ -118,15 +118,12 @@ def _load_p_lowering_rule( swizzle=16, optimized=optimized, ) - case mgpu.WGMMAColFragLayout(): - return mgpu.FragmentedArray.load_wgmma_col(x_ref, is_signed=is_signed) case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): ref_ty = ir.MemRefType(x_ref.type) if shape != tuple(ref_ty.shape): raise ValueError( f"Unsupported shape {shape}, (expected {tuple(ref_ty.shape)})" ) - return mgpu.FragmentedArray.load_strided( x_ref, is_signed=is_signed, vec_size=vec_size, ) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d5b3bea6d36b..a4f1e0a9cfe0 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -58,7 +58,6 @@ WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, - WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, optimization_barrier as optimization_barrier, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index f7ce36c62c9e..f6c5e7d1ed19 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -252,7 +252,7 @@ class TiledLayout: by a single (logical) register. """ tiling: Tiling - warp_dim: int + warp_dim: int | Replicated lane_dims: tuple[int | Replicated, ...] 
# major-to-minor vector_dim: int @@ -261,15 +261,20 @@ def __post_init__(self): raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) - dims_set = {self.warp_dim, *self.partitioned_lane_dims, self.vector_dim} - if len(dims_set) != len(self.partitioned_lane_dims) + 2: + dims_set = {*self.partitioned_lane_dims, self.vector_dim} + if partitions_warp_dim := not isinstance(self.warp_dim, Replicated): + dims_set.add(self.warp_dim) + if len(dims_set) != len(self.partitioned_lane_dims) + 1 + partitions_warp_dim: raise ValueError for d in dims_set: if d >= 0: raise ValueError("All dimensions must be negative") if d < -(len(min_tiled_shape) - len(min_shape)): raise ValueError("Dimension out of range") - if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: + if isinstance(self.warp_dim, Replicated): + if self.warp_dim.times != WARPS_IN_WARPGROUP: + raise ValueError + elif min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: raise ValueError lane_dims_prod = math.prod( d.times if isinstance(d, Replicated) else min_tiled_shape[d] @@ -340,7 +345,8 @@ def registers_element_type(self, t: ir.Type) -> ir.Type: def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) - tiled_shape[self.warp_dim] = 1 + if not isinstance(self.warp_dim, Replicated): + tiled_shape[self.warp_dim] = 1 for d in self.partitioned_lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 @@ -353,7 +359,8 @@ def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """ tiled_tiling = self.tiled_tiling_shape shape = list(shape) - shape[self.warp_dim] = WARPS_IN_WARPGROUP + if not isinstance(self.warp_dim, Replicated): + shape[self.warp_dim] = WARPS_IN_WARPGROUP for d in self.partitioned_lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] @@ -389,12 +396,13 @@ def lane_indices(self) -> tuple[ir.Value, ...]: def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape_rank = len(self.tiled_tiling_shape) - warp_idx = arith.remui( - arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), - c(WARPS_IN_WARPGROUP, i32), - ) indices = [arith.constant(i32, 0)] * tiled_shape_rank - indices[self.warp_dim] = warp_idx + if not isinstance(self.warp_dim, Replicated): + warp_idx = arith.remui( + arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), + c(WARPS_IN_WARPGROUP, i32), + ) + indices[self.warp_dim] = warp_idx return tuple(indices) @@ -411,23 +419,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -@dataclasses.dataclass(frozen=True) -class WGMMAColFragLayout: - """[n] matrix, where n % 8 == 0.""" - - def thread_idxs(self, shape): - index = ir.IndexType.get() - assert len(shape) == 1 - assert shape[0] % 8 == 0 - - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - lane_id = arith.remui(tid, c(WARP_SIZE, index)) - col_base = arith.muli(arith.remui(lane_id, c(4, index)), c(2, index)) - - for col_group in range(0, shape[0], 8): - col = arith.addi(col_base, c(col_group, index)) - yield (col,) - @dataclasses.dataclass(frozen=True) class WGSplatFragLayout: """A fragmented array where all the values are equal represented as a register per thread. 
@@ -538,10 +529,15 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAColFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | TiledLayout -WGMMA_COL_LAYOUT = WGMMAColFragLayout() +WGMMA_COL_LAYOUT = TiledLayout( + Tiling(((8,), (2,))), + warp_dim=Replicated(4), + lane_dims=(Replicated(8), -2), + vector_dim=-1, +) WGMMA_ROW_LAYOUT = TiledLayout( Tiling(((64,), (16,), (8,), (1,))), warp_dim=-4, @@ -659,12 +655,6 @@ def __init__( ) match self.layout: - # Registers are [n_tiles] in WGMMA_COL layout - # Each element is a vector of size 2. - case WGMMAColFragLayout(): - if _registers.ndim != 1: - raise ValueError(f"Invalid register array shape: {_registers.shape}") - # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape @@ -721,37 +711,6 @@ def load_strided( vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) - @classmethod - def load_wgmma_col( - cls, - ref: ir.Value, - *, - is_signed: bool | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): - raise TypeError(ref.type) - - ref_ty = ir.MemRefType(ref.type) - shape = tuple(ref_ty.shape) - layout = WGMMAColFragLayout() - - if len(shape) != 1: - raise ValueError("WGMMAColFragLayout requires a 1D shape.") - - if shape[0] % 8: - raise ValueError( - f"WGMMAColFragLayout requires {shape[0]=} to be a multiple of 8." - ) - - vec_ty = ir.VectorType.get((2,), ref_ty.element_type) - new_regs = np.full((shape[0] // 8,), llvm.mlir_undef(vec_ty)) - - for col_tile, (idx,) in enumerate(layout.thread_idxs(shape)): - reg = vector.load(vec_ty, ref, [idx]) - new_regs[col_tile] = reg - - return cls(_registers=new_regs, _layout=layout, _is_signed=is_signed) - @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) @@ -772,9 +731,6 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @property def shape(self): match self.layout: - case WGMMAColFragLayout(): - col_tiles = self.registers.shape[0] - return (col_tiles * 8,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): @@ -788,7 +744,7 @@ def shape(self): def mlir_dtype(self): reg_ty = self.registers.flat[0].type match self.layout: - case WGStridedFragLayout() | WGMMAColFragLayout() | TiledLayout(): + case WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type case WGSplatFragLayout(): return reg_ty @@ -1770,15 +1726,11 @@ def broadcast_minor(self, n): ) def broadcast_major(self, m): - if not isinstance(self.layout, WGMMAColFragLayout): - raise NotImplementedError - if m % 64: raise ValueError("Number of rows must be divisible by 64") - reg_shape = WGMMA_LAYOUT.registers_shape((m, self.shape[0])) new_regs = np.empty(reg_shape, dtype=object) - for col_tile, reg in np.ndenumerate(self.registers): + for (col_tile, *_), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[1] = col_tile new_regs[tuple(tile)] = reg @@ -1841,8 +1793,6 @@ def vs_unsupported(): ) match self.layout: - case WGMMAColFragLayout(): - self._store_untiled_wgmma_col(ref) case WGSplatFragLayout(): vs_unsupported() self._store_untiled_splat(ref) @@ -1905,21 +1855,6 @@ def _store_untiled_wg_strided(self, ref: ir.Value): for idx, reg 
in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_wgmma_col(self, ref: ir.Value): - """Stores an array with a WGMMA col layout.""" - assert isinstance(self.layout, WGMMAColFragLayout) - index = ir.IndexType.get() - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) - - # Consecutive groups of 4 threads replicate the same data, so we only need to - # transfer data from one group. - is_first = arith.cmpi(arith.CmpIPredicate.ult, tid_wg, c(4, index)) - - with utils.when(is_first): - for (idx,), reg in zip(self.layout.thread_idxs(self.shape), self.registers): - vector.store(reg, ref, [idx]) - def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: @@ -2161,8 +2096,10 @@ def transfer_tiled2( ] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") - for d in (layout.warp_dim, *layout.partitioned_lane_dims, layout.vector_dim): + for d in (*layout.partitioned_lane_dims, layout.vector_dim): tiled_shape[d] = 1 + if not isinstance(layout.warp_dim, Replicated): + tiled_shape[layout.warp_dim] = 1 element_bits = mgpu.bitwidth(dtype) if (layout.vector_length * element_bits) % 8 != 0: @@ -2403,10 +2340,6 @@ def plan_tiled_transfer( lane_mask = np.full(lane_shape, False) lane_mask[tuple(slice(0, 1) if s == 0 else slice(None) for s in lane_strides)] = True wavefront_mask = lane_mask.reshape(num_wavefronts, wavefront_lanes) - wavefront_active_lanes = wavefront_mask.sum(-1) - # We make a simplifying assumption: wavefronts have the same number of lanes - if any(act != wavefront_active_lanes[0] for act in wavefront_active_lanes): - raise NotImplementedError lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): @@ -2422,12 +2355,17 @@ def has_bank_conflicts(tile_idx_transform): swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) - # Mask out the inactive lanes in each wavefront - wavefront_banks = wavefront_banks[:, wavefront_mask].reshape(num_tiles, num_wavefronts, -1) - # Order of threads within the wavefront is unimportant. - wavefront_banks = np.sort(wavefront_banks, axis=-1) - # There are no conflicts if each wavefront only contains unique banks. - return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) + # We step over wavefronts since they might have a different number of lanes. + wavefront_banks = wavefront_banks.swapaxes(0, 1) + for banks, mask in zip(wavefront_banks, wavefront_mask): + banks = banks[:, mask] + # Order of threads within the wavefront is unimportant. + banks = np.sort(banks, axis=-1) + # There are no conflicts if each wavefront only contains unique banks. + repeats = np.any(banks[..., 1:] == banks[..., :-1]) + if repeats: + return True + return False # We don't need any special treatment if there are no conflicts when each lane # transfers the same tile at a time. 
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index cc0e6a04bdc3..9d9d3fa8979c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2005,9 +2005,11 @@ def kernel(ctx, *args): np.testing.assert_array_equal(inp, result) @parameterized.product( - in_shape=((1024,), (256,), (128,), (64,)), swizzle=(16, 32, 64, 128) + in_shape=((1024,), (256,), (128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128) ) - def test_wgmma_row_load_store_with_layout(self, in_shape, swizzle): + def test_wgmma_row_load_store_with_layout(self, in_shape, dtype, swizzle): def kernel(ctx, gmem_input, gmem_output, smem): smem_input, smem_output = smem copy(gmem_input, smem_input, swizzle=swizzle) @@ -2017,20 +2019,24 @@ def kernel(ctx, gmem_input, gmem_output, smem): t.store_untiled(smem_output) copy(smem_output, gmem_output) - inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) np.testing.assert_array_equal(inp, result) @parameterized.product( - in_shape=((128,), (64,)), dtype=[jnp.float16, jnp.float32] + in_shape=((128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128), ) - def test_wgmma_col_load_store_with_layout(self, in_shape, dtype): + def test_wgmma_col_load_store_with_layout(self, in_shape, dtype, swizzle): def kernel(ctx, *args): gmem_input, gmem_output, (smem_input, smem_output) = args - copy(gmem_input, smem_input) - t = mgpu.FragmentedArray.load_wgmma_col(smem_input) + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, swizzle=swizzle, layout=mgpu.WGMMA_COL_LAYOUT + ) t.store_untiled(smem_output) copy(smem_output, gmem_output) @@ -2042,18 +2048,17 @@ def kernel(ctx, *args): @parameterized.parameters((128, 128), (128, 64), (64, 128)) def test_broadcast_major(self, m, n): - def kernel(ctx, *args): - gmem_input, gmem_output, () = args - t = mgpu.FragmentedArray.load_wgmma_col(gmem_input) + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_untiled( + gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False + ) t.broadcast_major(m).store_untiled(gmem_output) inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) - result = mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, () + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp )(inp) - out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) np.testing.assert_array_equal(result, out_ref) From e4a381c12e42c62d149f694122472d237dba333b Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 05:47:54 -0700 Subject: [PATCH 0385/1769] [pallas:mgpu] Check that swizzle dim is not transposed in copy_smem_to_gmem() PiperOrigin-RevId: 743910324 --- jax/experimental/mosaic/gpu/launch_context.py | 8 ++++++++ tests/pallas/mosaic_gpu_test.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index aca3fc723882..243c5e5df15c 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -325,6 +325,14 @@ def init_tma_desc(host_ptr): ref = t.apply(ref) ref_ty = ir.MemRefType(ref.type) # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + strides, _ = 
ref_ty.get_strides_and_offset() + if strides[-1] != 1: + raise ValueError( + "TMA requires the stride of the last dimension after" + " transforming the GMEM reference to be 1, but it is" + f" {strides[-1]}." + ) + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) as_i64 = lambda i: arith.index_cast(i64, i) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 67472dbdd9e5..0a3087f5902d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1379,7 +1379,8 @@ def kernel(o_ref): x = jnp.full(shape, 42.0, jnp.float32) np.testing.assert_array_equal(kernel(), x) - def test_wgmma_transposed_layout(self): + @parameterized.parameters(False, True) + def test_wgmma_transposed_layout(self, store_transposed): """Tests that the result of wgmma can be store transposed using the WGMMA_TRNASPOSED layout. """ @@ -1412,10 +1413,14 @@ def kernel(o_ref, smem): smem_trns = plgpu.transpose_ref(smem, (1, 0)) smem_trns[...] = plgpu.layout_cast(iota, plgpu.Layout.WGMMA_TRANSPOSED) plgpu.commit_smem() - plgpu.copy_smem_to_gmem(smem, o_ref) + plgpu.copy_smem_to_gmem(smem_trns if store_transposed else smem, o_ref) x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)).T - np.testing.assert_array_equal(kernel(), x) + if store_transposed: + with self.assertRaises(ValueError): + kernel() + else: + np.testing.assert_array_equal(kernel(), x) def test_profiler(self): self.skip_if_wg_semantics() # Transform inference fails. From cbae2539d4724e490aefd2aa3e8e661223c57e35 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 06:32:27 -0700 Subject: [PATCH 0386/1769] [mgpu:pallas] Typo in `UnswizzleRef.untransform_reshape()` check. PiperOrigin-RevId: 743920665 --- jax/_src/pallas/mosaic_gpu/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 2150b48b5108..c0bf602e0962 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -494,7 +494,7 @@ def untransform_transpose(self, perm) -> tuple[tuple[int, ...], state_types.Tran def untransform_reshape( self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] ) -> tuple[tuple[int, ...], state_types.Transform]: - if shape[-1] == self.swizzle_elems(dtype): + if shape[-1] != self.swizzle_elems(dtype): raise ValueError( f"Reshape shape {shape} is not divisible by swizzle elements" f" {self.swizzle_elems(dtype)}" From 5a29311c8b97922b4fc6f1d942a87d5a784d86c7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 06:36:33 -0700 Subject: [PATCH 0387/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ffa8cf08e295cec70a27a6b27bfaa19c5d0daeec. PiperOrigin-RevId: 743921735 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c30648a2b3a1..7abf3da775d2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "921c164a67e8ac4cf052aab26e849f29b719f802" -XLA_SHA256 = "9e734da4a0211ac09a00cc07969645e31f107cfee19bbc5d2d1e21ddbb19090d" +XLA_COMMIT = "ffa8cf08e295cec70a27a6b27bfaa19c5d0daeec" +XLA_SHA256 = "d5de319756b6a32748d2821f5319f831b062d0f5f22b7f0bde1d9564dc6b6f5e" def repo(): tf_http_archive( From da7b1577e24784c6e1edcc8167407e85bb85195e Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 07:20:58 -0700 Subject: [PATCH 0388/1769] [mgpu:pallas] Swizzle elements computed using bitwidth rather than bytewidth. PiperOrigin-RevId: 743933866 --- jax/_src/pallas/mosaic_gpu/core.py | 3 ++- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 8 +++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index c0bf602e0962..d3d7f89812c3 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -33,6 +33,7 @@ from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as pallas_primitives +import jax._src.pallas.utils as pallas_utils from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types @@ -466,7 +467,7 @@ def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: raise NotImplementedError def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: - swizzle_elems = self.swizzle // aval.dtype.itemsize + swizzle_elems = (self.swizzle * 8) // pallas_utils.dtype_bitwidth(aval.dtype) if swizzle_elems != aval.shape[-1]: raise ValueError( f"Swizzle {self.swizzle} requires the trailing dimension to be of" diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 827794d37e2b..f8932b08b90a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1112,7 +1112,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (8, swizzle // x_aval.dtype.itemsize): + if tiling != (8, (swizzle * 8) // pallas_utils.dtype_bitwidth(x_aval.dtype)): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b2beec700fad..bfe1bf3fe5bc 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -481,7 +481,13 @@ def _copy_gmem_to_smem_lowering( for axis in collective_axes ) dst_ty = ir.MemRefType(dst.type) - bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) + bits = math.prod(dst_ty.shape) * mgpu.bitwidth(dst_ty.element_type) + if bits % 8: + raise ValueError( + f"Can only transfer integer bytes (shape={dst_ty.shape}," + f" dtype={dst_ty.element_type})" + ) + bytes = bits // 8 if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: if bytes % WARPGROUP_SIZE: raise NotImplementedError("Only aligned copies are supported") From 53abbd5606495633fbe2eb0ea720d9d1f4e4f937 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 08:30:54 -0700 Subject: [PATCH 0389/1769] [mgpu] Foreach to handle scalar registers in fragmented arrays. 
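Layouts that replicate values can leave a fragmented array holding plain scalar registers rather than length-n vector registers (for example the result of a row reduction), and `foreach` previously assumed every register was a vector it could `vector.extractelement` from. It now applies the callback to scalar registers directly. A minimal sketch of the supported pattern, adapted from the test added below (`iota_tensor` is a test helper and the kernel scaffolding, including `out`, is omitted, so treat the details as illustrative):

    x = iota_tensor(128, 128, jnp.float32)
    row = x.reduce("add", 1)  # row-reduced array in WGMMA_ROW_LAYOUT
    # The callback can build and return a new fragmented array...
    row = row.foreach(
        lambda v, _: arith.addf(v, c(1, row.mlir_dtype)), create_array=True
    )
    # ...or visit each element purely for its side effects.
    @row.foreach
    def _(v, idx):
      memref.store(v, out, idx)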
PiperOrigin-RevId: 743953606 --- .../mosaic/gpu/fragmented_array.py | 16 ++++++++---- tests/mosaic/gpu_test.py | 25 +++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index f6c5e7d1ed19..ecd51f79eab0 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1765,12 +1765,18 @@ def foreach( for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): reg = self.registers[reg_idx] assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) - [elems] = ir.VectorType(reg.type).shape - for i in range(elems): - i = c(i, index) - val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if ir.VectorType.isinstance(reg.type): + [elems] = ir.VectorType(reg.type).shape + for i in range(elems): + i = c(i, index) + val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if create_array: + new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + else: + val = fn(reg, mlir_idx) if create_array: - new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + new_regs[reg_idx] = val + if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 9d9d3fa8979c..80e5380d165b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1756,6 +1756,31 @@ def kernel(ctx, dst, _): rhs = rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8)) + def test_foreach_wgmma_row_array(self): + def kernel(ctx, out, smem): + del ctx, smem + x = iota_tensor(128, 128, jnp.float32) + row = x.reduce("add", 1) + # Test returning an array + row = row.foreach( + lambda x, _: arith.addf(x, c(1, row.mlir_dtype)), create_array=True + ) + # Test no array return + @row.foreach + def _(v, idx): + memref.store(v, out, idx) + + result = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape=(128,), dtype=jnp.float32), + smem_scratch_shape=(), + )() + iota = np.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(result, iota.sum(axis=1) + 1) + def test_foreach(self): dtype = jnp.int32 swizzle = 128 From b9007145d7c4f6f44c41c7111edc56b61be921d7 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 08:31:49 -0700 Subject: [PATCH 0390/1769] [mgpu:pallas] Fix swizzling check bug where it was comparing w/ #bytes rather than #elems. PiperOrigin-RevId: 743953910 --- jax/_src/pallas/mosaic_gpu/core.py | 13 ++++++++----- jax/_src/pallas/mosaic_gpu/lowering.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index d3d7f89812c3..444fe6e50f88 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -290,8 +290,9 @@ def untransform_reshape( raise NotImplementedError("Reshapes don't commute with transposes.") def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] 
) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] idxs_after_tiling: list[Index] = [] @@ -395,8 +396,9 @@ def untransform_reshape( raise NotImplementedError("Can't reshape a transposed memref.") def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype removed_dims = [ i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) ] @@ -503,8 +505,9 @@ def untransform_reshape( return shape, self def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + swizzle_elems = self.swizzle_elems(dtype) if not idxs: return idxs, self if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): @@ -513,14 +516,14 @@ def untransform_index( ) last_idx = idxs[-1] if isinstance(last_idx, mgpu.DynamicSlice): - if last_idx.base != 0 or last_idx.length != self.swizzle: + if last_idx.base != 0 or last_idx.length != swizzle_elems: raise ValueError("Swizzled dims cannot be sliced") else: assert isinstance(last_idx, slice) if ( (last_idx.step is not None and last_idx.step != 1) or (last_idx.start is not None and last_idx.start != 0) - or (last_idx.stop is not None and last_idx.stop != self.swizzle) + or (last_idx.stop is not None and last_idx.stop != swizzle_elems) ): raise ValueError("Swizzled dims cannot be sliced") return idxs, self diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f8932b08b90a..2423e1c1a2a7 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1032,6 +1032,7 @@ def _handle_transforms( handle_reshapes=True, ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: transformed_ref = ref + mlir_dtype = ir.MemRefType(ref.type).element_type new_transforms = [] def _bubble_up(untransform_fn, data): nonlocal new_transforms @@ -1051,7 +1052,7 @@ def _bubble_up(untransform_fn, data): raise NotImplementedError("int_indexer_shape non-empty") indices = _ndindexer_indices(indexer) indices = _bubble_up( - lambda t, idxs: t.untransform_index(idxs), indices + lambda t, idxs: t.untransform_index(mlir_dtype, idxs), indices ) transformed_ref = mgpu.memref_slice(transformed_ref, indices) case gpu_core.TransposeRef(perm) if handle_transposes: From 35d75183c70e5f83d5df2065956547b0845c29a6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 10:09:44 -0700 Subject: [PATCH 0391/1769] `_attempt_rewriting_take_via_slice()`: canonicalize the slice index before checking it's not too long, so that e.g. `my_1d_array[:, ...]` can be treated as a slice rather than generating a gather operation. 
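A quick way to see the effect (mirroring the regression test added below) is to count the equations in the jaxpr; with the reordered check, an ellipsis that expands to zero dimensions no longer forces a gather:

    import jax
    import jax.numpy as jnp

    jaxpr = jax.make_jaxpr(lambda x: x[0:1, ...])(jnp.ones((3,)))
    print(jaxpr)  # expected after this change: a single lax.slice_p equation, no gather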
PiperOrigin-RevId: 743986126 --- jax/_src/numpy/indexing.py | 9 +++++---- tests/lax_numpy_indexing_test.py | 10 +++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 863f0c775ec6..05169dd541ce 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -526,8 +526,6 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> if not all(isinstance(i, int) for i in arr.shape): return None - if len(idx) > arr.ndim: - return None if any(i is None for i in idx): return None # TODO(jakevdp): handle newaxis case # For symbolic dimensions fallback to gather @@ -535,10 +533,13 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> for i in idx if isinstance(i, slice) for elt in (i.start, i.stop, i.step)): return None - if any(i is Ellipsis for i in idx): - # Remove ellipses and add trailing `slice(None)`. + # Remove ellipses and pad with trailing `slice(None)` if necessary. + # Do this before checking against rank of `arr` so that `...` can + # count as no dimensions at all (e.g. `my_1d_array[:, ...]` succeeds) idx = _canonicalize_tuple_index(arr.ndim, idx=idx) + if len(idx) > arr.ndim: + return None simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)} int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape)) @@ -926,12 +926,20 @@ def testSimpleIndexingUsesSlice(self): self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) - # Indexing with `Ellipsis` is not lowered to `gather`. + # Indexing with `Ellipsis` is not lowered to `gather` ... jaxpr = jax.make_jaxpr(lambda x: x[..., 0])(jnp.ones((3, 4, 5))) self.assertLen((jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + # ... even when the ellipsis expands to no dimensions. + jaxpr = jax.make_jaxpr(lambda x: x[..., 0:1])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + jaxpr = jax.make_jaxpr(lambda x: x[0:1, ...])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + # Simple reverses lower to lax.rev_p jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4))) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) From e2f67e0ef1af19f4d32e02f6fb927502469b32c6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 10:32:47 -0700 Subject: [PATCH 0392/1769] Always force synchronous pipelining when we have vmem storage and trivial windowing. PiperOrigin-RevId: 743993611 --- jax/_src/pallas/mosaic/lowering.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 67cacbc8dcf9..87e06f486366 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -744,13 +744,13 @@ def dynamic_shape_replacement_fn( 1 if b is pallas_core.mapped else b for b in bm.block_shape ] - # No sense in double-buffering without any windowing pattern. buffer_count = 0 + # Force single-buffering pipelining for trivial windowing in VMEM. 
+ pipeline_mode = bm.pipeline_mode if ( tpu_memory_space == tpu_core.TPUMemorySpace.VMEM and bm.has_trivial_window() ): - buffer_count = 1 + pipeline_mode = pallas_core.Buffered(1) # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. @@ -769,21 +769,20 @@ def dynamic_shape_replacement_fn( block_params["window_kind"] = ir.Attribute.parse( f"#tpu.element_window<{pad_low},{pad_high}>" ) - if bm.pipeline_mode is not None: - if not isinstance(bm.pipeline_mode, pallas_core.Buffered): + if pipeline_mode is not None: + if not isinstance(pipeline_mode, pallas_core.Buffered): raise LoweringException( - f"Unsupported pipeline mode: {bm.pipeline_mode}." + f"Unsupported pipeline mode: {pipeline_mode}." ) - if buffer_count == 0: - buffer_count = bm.pipeline_mode.buffer_count + buffer_count = pipeline_mode.buffer_count if buffer_count < 1 or buffer_count > 2: raise LoweringException( "Only single (1) and double (2) buffering are supported. Got" f" {buffer_count}." ) - pipeline_mode = "synchronous" if buffer_count == 1 else "double_buffered" + pipeline_mode_str = "synchronous" if buffer_count == 1 else "double_buffered" block_params["pipeline_mode"] = ir.Attribute.parse( - f"#tpu.pipeline_mode<{pipeline_mode}>" + f"#tpu.pipeline_mode<{pipeline_mode_str}>" ) window_params.append(ir.DictAttr.get(block_params)) m.body.append(mlir_func) From e6b01bd1ed18dcbe92041c4bc7254470a11bd0b1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 4 Apr 2025 10:52:20 -0700 Subject: [PATCH 0393/1769] Parameterize the random tests taking out_sharding argument in pjit_test.py PiperOrigin-RevId: 744000229 --- tests/pjit_test.py | 142 +++++++++++---------------------------------- 1 file changed, 33 insertions(+), 109 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index f2db913af736..2570c6090351 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7274,116 +7274,60 @@ def f(key): out = f(key) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @parameterized.named_parameters( + ("bits", partial(jax.random.bits, shape=(8, 12)), P('x', 'y')), + ("uniform", partial(jax.random.uniform, shape=(8, 12)), P('x', 'y')), + ("normal", partial(jax.random.normal, shape=(8, 12)), P('x', 'y')), + ("randint", partial(jax.random.randint, shape=(8, 12), minval=0, maxval=10), + P('x', 'y')), + ("permutation_1d", partial(jax.random.permutation, x=8), P('x')), + ("permutation_2d", partial(jax.random.permutation, + x=np.arange(8 * 12).reshape(8, 12)), + P('x', 'y')), + ) @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_bits(self, mesh): - @jax.jit - def f(key): - out = jax.random.bits(key, shape=(8, 12), out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - - lowered_text = f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_uniform(self, mesh): - @jax.jit - def f(key): - out = jax.random.uniform(key, shape=(8, 12), out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - - lowered_text = 
f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_randint(self, mesh): - @jax.jit - def f(key): - out = jax.random.randint(key, shape=(8, 12), minval=0, maxval=10, - out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - - lowered_text = f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((4,), ('x',)) - def test_random_permutation_1d(self, mesh): - @jax.jit - def f(key): - out = jax.random.permutation(key, 8, out_sharding=P('x')) - self.assertEqual(out.aval.sharding.spec, P('x')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - - lowered_text = f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[4]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_permutation_2d(self, mesh): + def test_random_functions(self, fun, out_spec, mesh): @jax.jit def f(key): - out = jax.random.permutation(key, jnp.arange(8 * 12).reshape(8, 12), - out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + out = fun(key, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) return out key = jax.random.key(1) out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) lowered_text = f.lower(key).as_text() if config.use_shardy_partitioner.value: self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + if out_spec == P('x', 'y'): + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + assert out_spec == P('x') + self.assertIn('<@mesh, [{"x"}]>', lowered_text) else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + if out_spec == P('x', 'y'): + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + else: + assert out_spec == P('x') + self.assertIn( + 'mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}', + lowered_text) @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_normal(self, mesh): + def test_random_truncated_normal(self, mesh): @jax.jit - def f(key): - out = jax.random.normal(key, shape=(8, 12), out_sharding=P('x', 'y')) + def f(key, lower): + out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), + out_sharding=P('x', 'y')) self.assertEqual(out.aval.sharding.spec, P('x', 'y')) return out key = jax.random.key(1) - out = f(key) + out = f(key, -1.) 
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - lowered_text = f.lower(key, -1.).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - def test_auto_axes_no_context_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) np_inp = np.arange(16.).reshape(8, 2) From be1a554d0bbce75a7fbc3e66fee81435f184676d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 4 Apr 2025 11:18:27 -0700 Subject: [PATCH 0394/1769] Fix a possible race in pjit.cc. We need to be careful not to destroy Python objects while using a Python 3.13 critical section to protect C++ state. The critical section might be released when calling back into Python code (much as the GIL may be released in GIL mode). In this code Key is kept alive by the function already, but the Value may be deleted before the hash table updates are done. PiperOrigin-RevId: 744008939 --- jaxlib/xla/pjit.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 508bf79f9ec0..503e8ef23f4b 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -245,9 +245,14 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { std::shared_ptr cache = std::make_shared(&self->lru_list_); auto callback = nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { - nb::ft_object_guard lock(self); - auto it = self->functions_.find(key); - if (it != self->functions_.end()) { + std::unique_ptr value; + { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; + } + value = std::move(it->second); self->functions_.erase(it); } }); From 5d4ac775dd210d2e5deca46c5006b07d7e08d6e9 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Fri, 4 Apr 2025 11:28:13 -0700 Subject: [PATCH 0395/1769] PR #26906: [jax.distributed] Allow explicitly setting slice_index Imported from GitHub PR https://github.com/jax-ml/jax/pull/26906 Allows overriding the slice index used by XLA. More explicit control over which slice a device ends up in is desirable: - Various parts of the ecosystem equate slices with "devices communicating via fast interconnect". With the arrival of NVL72 we want devices managed by multiple hosts to form a single slice. - For debugging purposes it can be useful to allow devices on the same host (managed in separate processes) to be treated as different slices. 
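As a sketch of the second use case (the address and counts here are made up for illustration; `slice_index` is the keyword added by this patch, and the `JAX_SLICE_INDEX` environment variable is the equivalent override), two processes on one host could claim distinct slices like this:

    import jax

    # In process 0; process 1 would pass process_id=1, slice_index=1.
    jax.distributed.initialize(
        coordinator_address="localhost:1234",
        num_processes=2,
        process_id=0,
        slice_index=0,
    )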
For example, [Orbax](https://github.com/google/orbax)'s local checkpointing presumes the existence of at least two slices, so overriding the boot id will allow us to test local checkpointing on a single host. (Companion PR in XLA: https://github.com/openxla/xla/pull/23347) Copybara import of the project: -- 45aa7ce316bb05ebcc3f3ed2d888385923285e58 by Georg Stefan Schmid : [jax.distributed] Allow overriding XLA slice_index Merging this change closes #26906 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/26906 from gspschmid:gschmid/jax-override-boot-id 45aa7ce316bb05ebcc3f3ed2d888385923285e58 PiperOrigin-RevId: 744012253 --- jax/_src/distributed.py | 16 +++++++++++++--- jax/_src/xla_bridge.py | 2 ++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index af50e2e9e31a..a7551465c425 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -41,6 +41,7 @@ class State: client: Any | None = None preemption_sync_manager: Any | None = None coordinator_address: str | None = None + slice_index: int | None = None def initialize(self, coordinator_address: str | None = None, @@ -53,7 +54,8 @@ def initialize(self, service_heartbeat_interval_seconds: int = 10, service_max_missing_heartbeats: int = 10, client_heartbeat_interval_seconds: int = 10, - client_max_missing_heartbeats: int = 10): + client_max_missing_heartbeats: int = 10, + slice_index: int | None = None): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): @@ -149,6 +151,10 @@ def initialize(self, self.initialize_preemption_sync_manager() + if slice_index is None and 'JAX_SLICE_INDEX' in os.environ: + slice_index = int(os.environ.get('JAX_SLICE_INDEX')) + self.slice_index = slice_index + def shutdown(self): if self.client: self.client.shutdown() @@ -175,7 +181,8 @@ def initialize(coordinator_address: str | None = None, local_device_ids: int | Sequence[int] | None = None, cluster_detection_method: str | None = None, initialization_timeout: int = 300, - coordinator_bind_address: str | None = None): + coordinator_bind_address: str | None = None, + slice_index: int | None = None): """Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on @@ -236,6 +243,8 @@ def initialize(coordinator_address: str | None = None, all available addresses on the same port as ``coordinator_address``. On systems that have multiple network interfaces per node it may be insufficient to only have the coordinator service listen on one address/interface. + slice_index: The slice index assigned to this process' local devices. If any process sets ``slice_index``, + then all processes must do so. If ``None`` the slice indices will be chosen automatically. 
Raises: RuntimeError: If :func:`~jax.distributed.initialize` is called more than once @@ -261,7 +270,8 @@ def initialize(coordinator_address: str | None = None, "This includes any computation, but also calls to jax.devices, jax.device_put, and others.") global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids, cluster_detection_method, - initialization_timeout, coordinator_bind_address) + initialization_timeout, coordinator_bind_address, + slice_index=slice_index) def is_initialized() -> bool: diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 227359dc4676..178ac5e6fc01 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -636,6 +636,8 @@ def factory(): 'node_id': distributed.global_state.process_id, 'num_nodes': distributed.global_state.num_processes, } + if (slice_index := distributed.global_state.slice_index) is not None: + distribute_options['slice_index'] = slice_index if options is not None: distribute_options.update(updated_options) return xla_client.make_c_api_client( From 549f1cd856bbed820c23c3cec56a084ab5c31d9e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 4 Apr 2025 14:50:19 -0700 Subject: [PATCH 0396/1769] Don't set `memory_kind` to `None` if the mesh is AbstractMesh and the memory kind is valid. PiperOrigin-RevId: 744077517 --- jax/_src/distributed.py | 2 +- jaxlib/xla/sharding.cc | 21 ++++++++++++++++++++- jaxlib/xla/xla_client.py | 2 +- tests/array_test.py | 18 +++++++++++++++++- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index a7551465c425..fb0aebb0e642 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -152,7 +152,7 @@ def initialize(self, self.initialize_preemption_sync_manager() if slice_index is None and 'JAX_SLICE_INDEX' in os.environ: - slice_index = int(os.environ.get('JAX_SLICE_INDEX')) + slice_index = int(os.environ.get('JAX_SLICE_INDEX')) # type: ignore self.slice_index = slice_index def shutdown(self): diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 858c025745e4..b6b58b0600ad 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include #include @@ -192,6 +193,14 @@ bool ShardingEqual(nb::handle a, nb::handle b) { return a.equal(b); } +// This list is to check for valid memory kinds when an AbstractMesh is passed +// to NamedSharding. +static const std::array valid_memory_kinds = { + "device", + "pinned_host", + "unpinned_host", +}; + NamedSharding::NamedSharding(nb::object mesh, nb::object spec, nb::object memory_kind, nb::object manual_axes, nb::object logical_device_ids) @@ -217,7 +226,17 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, memory_kind_ = CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); } else { - memory_kind_ = nb::none(); + if (!memory_kind_.is_none() && + (std::find(valid_memory_kinds.begin(), valid_memory_kinds.end(), + nb::cast(memory_kind_)) == + valid_memory_kinds.end())) { + throw nb::value_error( + absl::StrCat("Got invalid memory kind: ", + nb::cast(memory_kind_), + ". Valid memory kinds are: ", + absl::StrJoin(valid_memory_kinds, ", ")) + .c_str()); + } } // TODO(phawkins): this leaks a reference to the check_pspec function. 
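The user-visible effect, mirrored by the `array_test.py` case added below (the `AbstractMesh` constructor call follows that test, so treat it as illustrative):

    from jax._src.mesh import AbstractMesh
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = AbstractMesh((2,), ('x',))
    s = NamedSharding(mesh, P(), memory_kind='pinned_host')
    assert s.memory_kind == 'pinned_host'  # previously reset to None
    # Unknown kinds now raise eagerly instead of being dropped:
    # NamedSharding(mesh, P(), memory_kind='weird_device')  -> ValueError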
diff --git a/jaxlib/xla/xla_client.py index 523f8bb57b90..58e0cb070e29 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 325 +_version = 326 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/array_test.py b/tests/array_test.py index 2bdc54607473..76aa1093ede3 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -29,9 +29,10 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, AbstractMesh from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, @@ -1418,6 +1419,21 @@ def test_make_mesh_axis_types(self): self.assertNotEqual(mesh1, mesh2) self.assertNotEqual(hash(mesh1), hash(mesh2)) + def test_memory_kind_with_abstract_mesh(self): + if jaxlib_extension_version < 326: + self.skipTest('Requires jaxlib_extension_version >= 326') + + abstract_mesh = AbstractMesh((2,), ('x',)) + ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host') + self.assertEqual(ns.memory_kind, 'pinned_host') + + ns = NamedSharding(abstract_mesh, P()) + self.assertIsNone(ns.memory_kind) + + with self.assertRaisesRegex( + ValueError, 'Got invalid memory kind'): + NamedSharding(abstract_mesh, P(), memory_kind='weird_device') + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): From d81c0ffeb744d51e9e428fed0922f2df06dadfd5 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Fri, 4 Apr 2025 15:10:08 -0700 Subject: [PATCH 0397/1769] [Mosaic GPU] Limit the maximum number of registers per thread to 255. PiperOrigin-RevId: 744083257 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 14 +++++++++++--- tests/pallas/mosaic_gpu_test.py | 3 ++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 21efbbec6630..df9c6668a51d 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -686,8 +686,16 @@ def _compute_registers( memory_registers: int, num_compute_wgs: int, ) -> int: - """Returns the number of registers to use for the compute thread.""" - # TODO(justinfu): Configure this per-platform. - n_registers = (512 - memory_registers) / num_compute_wgs + """Returns the max number of registers to use in compute threads. + + We start with the theoretical max registers per thread if one warpgroup + (128 threads) used the entire SM's 64k register file (64k / 128 = 512). + Then reserve `memory_registers` for the producer warpgroup and distribute + the remaining registers evenly among the compute warpgroups. + + Note: The maximum number of registers per thread is 255, so we clamp + the value. + """ + n_registers = min(256, (512 - memory_registers) / num_compute_wgs) # Round down to the nearest multiple of 8.
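  # For example, with memory_registers=40 and num_compute_wgs=2 this is
  # min(256, (512 - 40) / 2) = 236, which rounds down to 232 registers
  # per compute thread.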
return int((n_registers // 8) * 8) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0a3087f5902d..809c9c8fcaeb 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2228,7 +2228,8 @@ def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): np.testing.assert_array_equal(out, x) np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) - def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): + @parameterized.product(m=[256], n=[256], num_compute_wgs=[1, 2]) + def test_elementwise_add(self, m, n, num_compute_wgs): self.skip_if_wg_semantics() # Crashes! blk_m = blk_n = 64 From aab6613944857c76c19e7e4732870885a1fa0a27 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 4 Apr 2025 15:33:20 -0700 Subject: [PATCH 0398/1769] [pallas:mosaic_gpu] Fixed a typo in `_barrier_arrive_pp_eqn` PiperOrigin-RevId: 744089477 --- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index bfe1bf3fe5bc..8c04bd23cf22 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -618,7 +618,7 @@ def _barrier_arrive_pp_eqn( ): del settings barrier, *flat_transforms = eqn.invars - transforms_treedef = eqn.params["transforms_tree"] + transforms_treedef = eqn.params["transforms_treedef"] transforms = transforms_treedef.unflatten(flat_transforms) return pp.concat([ pp.text("barrier_arrive"), From fc5d9a4fcee2a4606f36d2d2bd517458afde24a3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 4 Apr 2025 19:22:31 -0700 Subject: [PATCH 0399/1769] Check that memory_kind of an aval is always None PiperOrigin-RevId: 744136969 --- jax/_src/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index 1c35f5406543..9a5a6061cc5e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1899,6 +1899,7 @@ def get_sharding(sharding, shape): raise ValueError("Mesh of an aval must be an AbstractMesh. " f"Got {out_s.mesh} of type {type(out_s.mesh)}") _check_divisibility(out_s, shape) + assert out_s.memory_kind is None return out_s def str_short_aval(shape, dtype, mesh, spec, vma, From 2e62693f72b4ce217dbd798830313439c8fbc1a6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 5 Apr 2025 06:36:53 -0700 Subject: [PATCH 0400/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8118a02a2d8af30563d2942818ddb7c07c373877. PiperOrigin-RevId: 744248817 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 7abf3da775d2..e632798e3132 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "ffa8cf08e295cec70a27a6b27bfaa19c5d0daeec" -XLA_SHA256 = "d5de319756b6a32748d2821f5319f831b062d0f5f22b7f0bde1d9564dc6b6f5e" +XLA_COMMIT = "8118a02a2d8af30563d2942818ddb7c07c373877" +XLA_SHA256 = "080edaa896d1537bb838428c164cab88532ab5b9609cb6b58ddaf19bad37f88b" def repo(): tf_http_archive( From 6bae8c75c87ed18f13a3eaf41a44b1cca9e38303 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 6 Apr 2025 00:29:23 +0000 Subject: [PATCH 0401/1769] [vmappable] fix trace context bugs to_elt must run in the parent context, while from_elt must run in the batching context. We previously had it precisely backward! Tests didn't catch it because our tests are extremely minimal, and in particular didn't check a to_elt that binds primitives. --- jax/_src/interpreters/batching.py | 12 +++--- tests/batching_test.py | 67 ++++++++++++++++++++++++------- 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index a187d42511ac..fd094f408567 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -616,17 +616,15 @@ def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals): trace = BatchTrace(parent_trace, tag, axis_data) idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, source_info_util.current())) - in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) + with core.set_current_trace(parent_trace): + in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): outs = f(*in_tracers) - - out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests - out_vals = map(partial(from_elt, trace, axis_data.size, - axis_data.explicit_mesh_axis), - range(len(outs)), outs, out_dim_dests) - + out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests + out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis), + range(len(outs)), outs, out_dim_dests) return out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. 
diff --git a/tests/batching_test.py b/tests/batching_test.py index f2a4e8c34fe3..393317bcbe77 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -1328,33 +1328,70 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]: @jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe class VmappableTest(jtu.JaxTestCase): - def test_basic(self): + @parameterized.parameters([False, True]) + def test_basic(self, jit): with temporarily_register_named_array_vmappable(): def f(x): return named_mul(x, x) + if jit: + f = jax.jit(f) x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) g = jax.vmap(f, - in_axes=NamedMapSpec('i', 0), - out_axes=NamedMapSpec('i', 1), - axis_size=3) + in_axes=NamedMapSpec('i', 0), + out_axes=NamedMapSpec('i', 1), + axis_size=3) ans = g(x) expected = NamedArray(['j', 'i'], jnp.arange(12.).reshape(3, 4).T ** 2) self.assertEqual(ans.names, expected.names) self.assertAllClose(ans.data, expected.data) - def test_basic_jit(self): - with temporarily_register_named_array_vmappable(): - def f(x): - return named_mul(x, x) - - x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) - ans = jax.jit(f)(x) - expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2) - - self.assertEqual(ans.names, expected.names) - self.assertAllClose(ans.data, expected.data) + def test_to_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return cont(val.data + 1, spec) + def from_elt(cont, size, elt, spec): + assert False + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x - 1, axis_size=3)(a) + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.)) + + def test_from_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return A(cont(val.data, spec)) + def from_elt(cont, size, elt, spec): + return A(cont(size, elt.data + 1, spec)) + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x, axis_size=3)(a).data + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.) + 1) def test_types_with_same_spec(self): # We register NamedArray. From b1b54f9b5ecbcd12bd09f91e7c36cfd40dd0ce15 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 05:43:14 -0700 Subject: [PATCH 0402/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3889bec6b7f48e304953a485b713e9982dff0441. PiperOrigin-RevId: 744444688 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e632798e3132..1a7522bda0fa 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "8118a02a2d8af30563d2942818ddb7c07c373877" -XLA_SHA256 = "080edaa896d1537bb838428c164cab88532ab5b9609cb6b58ddaf19bad37f88b" +XLA_COMMIT = "3889bec6b7f48e304953a485b713e9982dff0441" +XLA_SHA256 = "f23bb226d334f933cd5e6ebc4b20dec9ad879137763975546120ddf582a472b8" def repo(): tf_http_archive( From ad36f7f2532528c661ed27fc7c71dbc0e2e11c9d Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 09:57:16 -0700 Subject: [PATCH 0403/1769] Automated Code Change PiperOrigin-RevId: 744478350 --- jaxlib/xla/BUILD | 2 ++ jaxlib/xla/custom_call_sharding.cc | 2 ++ jaxlib/xla/dlpack.cc | 1 + jaxlib/xla/dlpack.h | 1 + 4 files changed, 6 insertions(+) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 5b532c1dc501..c861ed06e5be 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -220,6 +220,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@nanobind", + "@xla//third_party/python_runtime:headers", "@xla//xla:shape_util", "@xla//xla:util", "@xla//xla/hlo/ir:hlo", @@ -264,6 +265,7 @@ cc_library( "@xla//xla:shape_util", "@xla//xla:status_macros", "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:pjrt_common", diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/xla/custom_call_sharding.cc index 3cb53b438e09..00accd85aefd 100644 --- a/jaxlib/xla/custom_call_sharding.cc +++ b/jaxlib/xla/custom_call_sharding.cc @@ -15,6 +15,8 @@ limitations under the License. #include "jaxlib/xla/custom_call_sharding.h" +#include + #include #include #include diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index d1cb91114b05..c8d02e679036 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -58,6 +58,7 @@ limitations under the License. #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace nb = nanobind; diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index 46b0954105f7..7fffdc345d79 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -25,6 +25,7 @@ limitations under the License. #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" +#include "xla/xla_data.pb.h" namespace xla { From 477b10825ad88c25ff3730cc068ea265d5f54dfa Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:10:14 -0700 Subject: [PATCH 0404/1769] Automated Code Change PiperOrigin-RevId: 744480338 --- examples/jax_cpp/BUILD | 2 ++ examples/jax_cpp/main.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index b3cb995aae21..86f3129c9876 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -21,6 +21,7 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", @@ -33,6 +34,7 @@ cc_binary( "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 5d1190ff1f2c..8deea5448fec 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include #include +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" @@ -50,6 +51,7 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" From 3f083caef59b224809e10808c4c306383307c2dc Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:10:22 -0700 Subject: [PATCH 0405/1769] Automated Code Change PiperOrigin-RevId: 744480358 --- examples/ffi/src/jax_ffi_example/rms_norm.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 819f3b9f868d..bcfc1eb67aa4 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include #include -#include -#include #include #include From d1c7ba4335d09a5523e42fc8280ec94413dae31a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:10:59 -0700 Subject: [PATCH 0406/1769] Automated Code Change PiperOrigin-RevId: 744480452 --- jaxlib/mosaic/gpu/custom_call.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 465551e2903b..38a388224d65 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/CodeGen.h" From 7874d79f56f99f9366039883d304c659c84f1c47 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:32:55 -0700 Subject: [PATCH 0407/1769] Automated Code Change PiperOrigin-RevId: 744483310 --- jaxlib/xla/BUILD | 1 + jaxlib/xla/callback.cc | 1 + jaxlib/xla/py_socket_transfer.cc | 1 + jaxlib/xla/to_ifrt_sharding.cc | 1 + 4 files changed, 4 insertions(+) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index c861ed06e5be..35f344046828 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -732,6 +732,7 @@ cc_library( "@xla//xla/python/transfer:socket_bulk_transport", "@xla//xla/python/transfer:streaming", "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/python/transfer:transfer_socket_proto_cc", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/platform:statusor", ], diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index 6f5644c3b0c7..b5519ed3bee3 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" +#include "xla/xla_data.pb.h" namespace nb = nanobind; diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index 55d84fd71bb7..4aa40cf66087 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -58,6 +58,7 @@ limitations under the License. 
#include "xla/python/transfer/socket_bulk_transport.h" #include "xla/python/transfer/streaming.h" #include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/transfer/transfer_socket.pb.h" #include "xla/python/types.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc index 2a7c6707e766..52879cfa9fbe 100644 --- a/jaxlib/xla/to_ifrt_sharding.cc +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" namespace xla { From 8a6efa317d2c104ca7905a6a4d6e521a9b9ebe4c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 6 Apr 2025 13:35:12 -0700 Subject: [PATCH 0408/1769] Fix deadlock when computing cached Sharding::type() values. C++ static initialization acquires an internal mutex. It is unsafe to call into Python code while holding that mutex, e.g., see the deadlock in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a However, in this case, there's a simpler thing we can do: eagerly initialize the ::type() values during module initialization, rather than on-demand. PiperOrigin-RevId: 744508279 --- jaxlib/xla/BUILD | 1 + jaxlib/xla/sharding.cc | 38 +++++++++++++++++++++++++++++++++++++- jaxlib/xla/sharding.h | 32 ++++++++++++++++---------------- 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 35f344046828..8602652cbd8a 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -565,6 +565,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index b6b58b0600ad..ff1539764864 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep @@ -242,7 +244,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, // TODO(phawkins): this leaks a reference to the check_pspec function. // A better way to fix this would be to move PartitionSpec and this check into // C++. - nb::object* check_pspec = [](){ + nb::object* check_pspec = []() { static absl::Mutex mu; static nb::object* output = nullptr; { @@ -262,6 +264,13 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, (*check_pspec)(mesh_, spec_, manual_axes_); } +/*static*/ PyObject* NamedSharding::type_ = nullptr; + +/*static*/ void NamedSharding::InitializeType() { + // Intentionally leaks a reference. 
+ type_ = nanobind::type().inc_ref().ptr(); +} + SingleDeviceSharding::SingleDeviceSharding(nb::object device, nb::object memory_kind) : Sharding(/*num_devices=*/1), @@ -273,6 +282,13 @@ SingleDeviceSharding::SingleDeviceSharding(nb::object device, CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); } +/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr; + +/*static*/ void SingleDeviceSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + SingleDeviceSharding::SingleDeviceSharding( xla::nb_class_ptr client, xla::ifrt::DeviceListRef device_list, nb::object memory_kind) @@ -295,6 +311,15 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices, xla::make_nb_class(nb::tuple(flat_devices)); } +/*static*/ PyObject* PmapSharding::type_ = nullptr; + +// /*static*/ nanobind::handle PmapSharding::type() { return type_; } + +/*static*/ void PmapSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, nb::object memory_kind, nb::object device_list) : Sharding(/*num_devices=*/nb::len(devices.ptr())), @@ -316,6 +341,13 @@ GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); } +/*static*/ PyObject* GSPMDSharding::type_ = nullptr; + +/*static*/ void GSPMDSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + void RegisterSharding(nb::module_& m) { nb::class_(m, "Sharding").def(nb::init<>()); @@ -334,6 +366,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { return xla::ValueOrThrow(s.internal_device_list()); }); + NamedSharding::InitializeType(); nb::class_(m, "SingleDeviceSharding", nb::dynamic_attr()) @@ -343,6 +376,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) .def_prop_ro("_internal_device_list", &SingleDeviceSharding::internal_device_list); + SingleDeviceSharding::InitializeType(); nb::class_(m, "PmapSharding", nb::dynamic_attr()) .def( @@ -357,6 +391,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) .def_prop_ro("_internal_device_list", &PmapSharding::internal_device_list); + PmapSharding::InitializeType(); nb::class_(m, "GSPMDSharding", nb::dynamic_attr()) .def(nb::init(), @@ -372,6 +407,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) .def_prop_ro("_internal_device_list", &GSPMDSharding::internal_device_list); + GSPMDSharding::InitializeType(); } } // namespace jax diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index 698ff2ca9ca8..4b602bd14324 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef JAXLIB_XLA_SHARDING_H_ #define JAXLIB_XLA_SHARDING_H_ +#include + #include #include #include @@ -84,10 +86,8 @@ class NamedSharding : public Sharding { return logical_device_ids_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); absl::StatusOr> internal_device_list() const { if (internal_device_list_) { @@ -105,6 +105,7 @@ class NamedSharding : public Sharding { nanobind::object manual_axes_; nanobind::object logical_device_ids_; std::optional> internal_device_list_; + static PyObject* type_; }; class SingleDeviceSharding : public Sharding { @@ -120,10 +121,8 @@ class SingleDeviceSharding : public Sharding { const nanobind::object& device() const { return device_; } const nanobind::object& memory_kind() const { return memory_kind_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); xla::nb_class_ptr internal_device_list() const { return internal_device_list_; @@ -133,6 +132,8 @@ class SingleDeviceSharding : public Sharding { nanobind::object device_; nanobind::object memory_kind_; xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; }; // The C++ implementation of jax.PmapSharding in python. It contains a few key @@ -147,10 +148,8 @@ class PmapSharding : public Sharding { const ShardingSpec& sharding_spec() const { return sharding_spec_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); xla::nb_class_ptr internal_device_list() const { return internal_device_list_; @@ -160,6 +159,7 @@ class PmapSharding : public Sharding { xla::nb_numpy_ndarray devices_; ShardingSpec sharding_spec_; xla::nb_class_ptr internal_device_list_; + static PyObject* type_; }; class GSPMDSharding : public Sharding { @@ -184,10 +184,8 @@ class GSPMDSharding : public Sharding { return *hash_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } @@ -234,6 +232,8 @@ class GSPMDSharding : public Sharding { nanobind::object memory_kind_; std::optional hash_; xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; }; void RegisterSharding(nanobind::module_& m); From cccc34dc2334040e58eeb6131f2ac7f1470a8f62 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sun, 6 Apr 2025 23:37:20 -0700 Subject: [PATCH 0409/1769] Raise an error if the type passed to `axis_types` argument of `Mesh` and `AbstractMesh` is not `jax.sharding.AxisType`. 
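A minimal sketch of the new behavior (illustrative only, not part of the diff below): enum members are accepted as before, while strings now fail fast.

    from jax.sharding import AbstractMesh, AxisType

    AbstractMesh((2, 2), ('x', 'y'),
                 axis_types=(AxisType.Explicit, AxisType.Auto))   # ok
    AbstractMesh((2,), ('x',), axis_types=("explicit",))          # now: TypeError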
PiperOrigin-RevId: 744602037 --- jax/_src/mesh.py | 13 +++++++++---- tests/array_test.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 00859f9b3d74..8db4445542d0 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -111,12 +111,16 @@ class AxisType(enum.Enum): def __repr__(self): return self.name -def _normalize_axis_types(axis_names, axis_types): +def _normalize_axis_types(axis_names, axis_types, name): axis_types = ((AxisType.Auto,) * len(axis_names) if axis_types is None else axis_types) if not isinstance(axis_types, tuple): - assert isinstance(axis_types, AxisType), axis_types axis_types = (axis_types,) + + if not all(isinstance(a, AxisType) for a in axis_types): + raise TypeError( + f"axis_types passed to {name} must be of type `jax.sharding.AxisType`." + f" Got {axis_types} of type {tuple(type(a) for a in axis_types)}") if len(axis_names) != len(axis_types): raise ValueError( "Number of axis names should match the number of axis_types. Got" @@ -256,7 +260,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - axis_types = _normalize_axis_types(axis_names, axis_types) + axis_types = _normalize_axis_types(axis_names, axis_types, 'Mesh') key = (axis_names, devices.shape, tuple(devices.flat), axis_types) val = _mesh_object_dict.get(key, None) @@ -440,7 +444,8 @@ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], self.axis_sizes = axis_sizes self.axis_names = axis_names self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0 - self._axis_types = _normalize_axis_types(self.axis_names, axis_types) + self._axis_types = _normalize_axis_types( + self.axis_names, axis_types, 'AbstractMesh') self._hash = hash((self.axis_sizes, self.axis_names, self._axis_types)) def __hash__(self): diff --git a/tests/array_test.py b/tests/array_test.py index 76aa1093ede3..f097497cef51 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1377,6 +1377,16 @@ def test_mesh_axis_types_mismatch(self): jax.sharding.AbstractMesh((2, 1), ('x', 'y'), axis_types=jax.sharding.AxisType.Auto) + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types=("explicit",)) + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types="explicit") + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2, 2), ('x', 'y'), + axis_types=("explicit", AxisType.Explicit)) + def test_make_mesh_axis_types(self): Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual From 90cfa99a6868df89ea96923aae4338c123bfd242 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 7 Apr 2025 00:51:13 -0700 Subject: [PATCH 0410/1769] [Mosaic GPU] Support Slice and Transpose in the Pallas WGMMA lowering This change also fixes the transpose handling in the lowering and completely removes the use of the TransposeTransform. Instead we rely on strides. If we don't discover any issues with this, we will remove the transpose transform also from the mlir dialect. 
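The stride test this relies on, paraphrased in plain Python from the `_is_memref_transposed` helper added in the diff below:

    def is_transposed(strides):
      # A contiguous (64, 128) ref has strides (128, 1); its (1, 0)
      # transpose has strides (1, 64), so strides that ever increase
      # along the list signal a transposed ref.
      return any(b > a for a, b in zip(strides, strides[1:]))

    assert not is_transposed([128, 1])
    assert is_transposed([1, 64])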
PiperOrigin-RevId: 744618241 --- jax/_src/pallas/mosaic_gpu/primitives.py | 54 ++++++++--- .../mosaic/gpu/dialect_lowering.py | 94 ++++++++++++++----- .../mosaic/gpu/transform_inference.py | 81 ++++++++++++---- tests/mosaic/gpu_dialect_test.py | 22 +++++ tests/pallas/mosaic_gpu_test.py | 47 +++++++--- 5 files changed, 229 insertions(+), 69 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 8c04bd23cf22..8bd67e705cf0 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -903,15 +903,25 @@ def _wgmma_lowering( transforms_leaves, [a_transforms_tree.num_leaves] ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_transforms(a, a_transforms) + a, a_transforms = lowering._handle_transforms( + a, a_transforms, handle_transposes=False, handle_reshapes=False + ) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): - swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize - if tiling != (8, swizzle_elems): - raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") + swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize + if tiling != (8, swizzle_elems): + raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") else: + lhs_transpose = False b_transforms_leaves = transforms_leaves # type: ignore if not isinstance(a, mgpu.FragmentedArray): raise ValueError( @@ -949,8 +959,6 @@ def _wgmma_lowering( f" {rhs_tiling}." ) - # TODO(cperivol): Find a generic way to move this reshape into - # _handle_transforms. high_dims = [d // t for d, t in util.safe_zip(new_shape, rhs_tiling)] b = mgpu.memref_reshape(b, (*high_dims, *rhs_tiling)) rhs_transpose = False @@ -964,6 +972,8 @@ def _wgmma_lowering( if rhs_tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") + if lhs_transpose: + a = mgpu.memref_transpose(a, (1, 0, 3, 2)) if rhs_transpose: b = mgpu.memref_transpose(b, (1, 0, 3, 2)) new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle) @@ -981,23 +991,37 @@ def _wgmma_warpgroup_lowering( a_transforms_tree, b_transforms_tree, ): - del ctx, transforms_leaves # Unused. + del ctx # Unused. + if a_transforms_tree is not None: - match a_transforms_tree: - case gpu_core.TransposeRef((1, 0)): - raise NotImplementedError("WGMMA lhs transpose not supported.") + a_transforms_leaves, b_transforms_leaves = util.split_list( + transforms_leaves, [a_transforms_tree.num_leaves] + ) + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a, a_transforms = lowering._handle_transforms(a, a_transforms) + match a_transforms: + case (gpu_core.TransposeRef((1, 0)),): + a = mgpu.memref_transpose(a, (1, 0)) + case (): + pass case _: raise ValueError( - f"WGMMA lhs has unsupported transforms: {a_transforms_tree}." + f"WGMMA lhs has unsupported transforms: {a_transforms}." 
) + else: + b_transforms_leaves = transforms_leaves # type: ignore if b_transforms_tree is not None: - match b_transforms_tree: - case gpu_core.TransposeRef((1, 0)): - raise NotImplementedError("WGMMA rhs transpose not supported.") + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b, b_transforms = lowering._handle_transforms(b, b_transforms) + match b_transforms: + case (gpu_core.TransposeRef((1, 0)),): + b = mgpu.memref_transpose(b, (1, 0)) + case (): + pass case _: raise ValueError( - f"WGMMA rhs has unsupported transforms: {b_transforms_tree}." + f"WGMMA rhs has unsupported transforms: {b_transforms}." ) new_acc = mgpu.dialect.wgmma(acc, a, b) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index f00cff9a500c..3deb53646ce4 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -355,7 +355,7 @@ def _vector_load_op_lowering_rule( ) ref_ty = ir.MemRefType(vector_load_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) - transformed_ref = transform_memref(vector_load_op.base, transforms) + transformed_ref = reinterpret_smem_ref(vector_load_op.base, transforms) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, swizzle=swizzle, @@ -397,7 +397,7 @@ def _vector_store_op_lowering_rule( ref_ty = ir.MemRefType(vector_store_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) fragmented_array.store_tiled( - transform_memref(vector_store_op.base, transforms), swizzle + reinterpret_smem_ref(vector_store_op.base, transforms), swizzle ) elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): @@ -510,32 +510,78 @@ def swizzle_and_transforms_from_transforms_attr( return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) -def transform_memref( - mem_ref: ir.Value, transforms: tuple[launch_context.MemRefTransform, ...] +def _is_memref_transposed(mem_ref_type: ir.MemRefType) -> bool: + strides, _ = mem_ref_type.get_strides_and_offset() + prev_stride = math.inf + for stride in strides: + if stride > prev_stride: + return True + prev_stride = stride + return False + + +def reinterpret_smem_ref( + ref: ir.Value, + transforms: tuple[launch_context.MemRefTransform, ...], ) -> ir.Value: - """Reinterprets the memref to one where the shape is transformed as given.""" - if not transforms: - return mem_ref + """Applies transforms on the ref, and makes sure that their effect is + propagated appropriately on the strides. - mem_ref_type = ir.MemRefType(mem_ref.type) - if mem_ref_type.memory_space != ir.Attribute.parse( - "#gpu.address_space" - ): - raise ValueError(f"Only workgroup memory is supported but got {mem_ref}.") + This function is used any time we lower from a dialect SMEM ref (2D for wgmma) + with given transforms to a "physical" SMEM ref (4D for wgmma) that is fully + transformed and transposed as needed. 
+ """ + ref_ty = ir.MemRefType(ref.type) + transposed = _is_memref_transposed(ref_ty) + if not transforms and not transposed: + return ref + + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError(f"Only workgroup memory is supported but got {ref}.") + + shape = ref_ty.shape + if transposed: + if len(shape) != 2: + raise NotImplementedError( + f"Only 2D shapes can be transposed, but got {shape}" + ) + strides, _ = ref_ty.get_strides_and_offset() + if strides[0] != 1 or strides[1] != shape[0]: + raise NotImplementedError( + f"Only contiguous 2D memrefs can be transposed, but got {ref_ty}" + ) - shape = mem_ref_type.shape for t in transforms: - shape = t.transform_shape(shape) + shape = list(t.transform_shape(shape)) + + if transposed: + # The expected output is a transposed ref and `shape` is already transposed. + # We need to compute the correct strides to match the shape. + if len(shape) == 2: + minor_to_major_stride_order = (1, 0) + elif len(shape) == 4: + minor_to_major_stride_order = (2, 3, 0, 1) + else: + raise NotImplementedError( + f"Expected a 2D or 4D shape after transforms, but got {shape}" + ) + strides = [1]*len(shape) + for i in minor_to_major_stride_order[1:]: + strides[i] = strides[i-1] * shape[i-1] + layout = ir.StridedLayoutAttr.get(0, strides) + else: + layout = None - memref_new_type = ir.MemRefType.get( + new_ref_ty = ir.MemRefType.get( shape, - mem_ref_type.element_type, - memory_space=mem_ref_type.memory_space, + ref_ty.element_type, + memory_space=ref_ty.memory_space, + layout=layout, ) - ms = utils.WORKGROUP_NVPTX_ADDRESS_SPACE - ptr = utils.memref_ptr(mem_ref, memory_space=ms) - return utils.ptr_as_memref(ptr, memref_new_type, ptr_memory_space=ms) + ptr = utils.memref_ptr(ref, memory_space=ms) + ref = utils.ptr_as_memref(ptr, new_ref_ty, ptr_memory_space=ms) + return ref @_register_lowering(mgpu.AsyncLoadOp) @@ -569,7 +615,7 @@ def _mgpu_async_load_op_lowering_rule( # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=load_op.source, - dst_ref=transform_memref(load_op.destination, transforms), + dst_ref=reinterpret_smem_ref(load_op.destination, transforms), gmem_slice=tuple(gmem_slice), barrier=barrier, arrive=False, @@ -610,7 +656,7 @@ def _mgpu_async_store_op_lowering_rule( # TODO(dasenov): Add support for the remaining op properties. 
ctx.launch_context.async_copy( - src_ref=transform_memref(store_op.source, transforms), + src_ref=reinterpret_smem_ref(store_op.source, transforms), dst_ref=store_op.destination, gmem_slice=tuple(gmem_slice), swizzle=swizzle, @@ -840,7 +886,7 @@ def _mgpu_wgmma_op_lowering_rule( _check_transforms_and_swizzle_are_supported( ref_ty, b_transforms, b_swizzle, minimum_swizzle ) - b_operand = transform_memref(wgmma_op.b, b_transforms) + b_operand = reinterpret_smem_ref(wgmma_op.b, b_transforms) if ir.VectorType.isinstance(wgmma_op.a.type): a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) @@ -857,7 +903,7 @@ def _mgpu_wgmma_op_lowering_rule( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) - a_operand = transform_memref(wgmma_op.a, a_transforms) + a_operand = reinterpret_smem_ref(wgmma_op.a, a_transforms) new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 3438a654f90a..6026cb216166 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -98,11 +98,8 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: element_bytewidth = utils.bytewidth(ref_ty.element_type) strides, _ = ref_ty.get_strides_and_offset() - - if strides[0] < strides[1]: - raise NotImplementedError("Transpositions aren't handled yet.") - - minor_dim = ref_ty.shape[1] + transposed = strides[0] < strides[1] + minor_dim = ref_ty.shape[0 if transposed else 1] major_tiling = 8 # Try tiling with all swizzling modes starting from the largest one. @@ -118,12 +115,14 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: break else: # No valid tile transform can be inferred. - raise ValueError( - f"{ref_ty.shape} is not a valid WGMMA shape" - ) + raise ValueError(f"{ref_ty.shape} is not a valid WGMMA shape") + if transposed: + tiling = (minor_tiling, major_tiling) + else: + tiling = (major_tiling, minor_tiling) return ir.ArrayAttr.get([ - mgpu.TileTransformAttr.get((major_tiling, minor_tiling)), + mgpu.TileTransformAttr.get(tiling), mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), ]) @@ -255,6 +254,24 @@ def _infer_dynamic_smem_transforms( return None +def _get_tile_and_swizzle_transforms( + transforms: ir.ArrayAttr | None, +) -> tuple[ir.Attribute, ir.Attribute]: + if transforms is None: + return + + if len(transforms) == 2: + tile_transform, swizzle_transform = transforms + if not ( + mgpu.TileTransformAttr.isinstance(tile_transform) + and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform) + ): + raise NotImplementedError(f"Unsupported transforms {transforms}.") + return tile_transform, swizzle_transform + else: + raise NotImplementedError(f"Unsupported transforms {transforms}.") + + # This is used by Pallas' "_handle_indexing" memory transform. @partial(_add_transform_inference_rule, memref.SubViewOp) def _infer_memref_subview_transforms( @@ -285,15 +302,7 @@ def _infer_memref_subview_transforms( # - We only propagate transforms if they consist of a single tile transform # and a single swizzle transform. # TODO(bchetioui): implement more complex propagation rules. 
- if len(transforms) == 2: - tile_transform, swizzle_transform = transforms - if not ( - mgpu.TileTransformAttr.isinstance(tile_transform) - and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform) - ): - raise NotImplementedError(f"Can't propagate transforms {transforms}.") - else: - raise NotImplementedError(f"Can't propagate transforms {transforms}.") + tile_transform, _ = _get_tile_and_swizzle_transforms(transforms) # Check swizzle transform propagation. strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type) @@ -318,6 +327,42 @@ def _infer_memref_subview_transforms( return [transforms], [transforms] +@partial(_add_transform_inference_rule, memref.TransposeOp) +def _infer_memref_transpose_transforms( + op: memref.TransposeOp, +) -> OptionalTransforms: + in_ty = ir.MemRefType(op.in_.type) + if len(in_ty.shape) != 2: + raise NotImplementedError(f"Only 2D memrefs are supported, got {in_ty}") + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, _ = ir.MemRefType(op.result.type).get_strides_and_offset() + transpose = in_strides != out_strides + + users = list(op.result.uses) + if len(users) != 1: + raise NotImplementedError( + f"Only memref.transpose with a single use are supported, got {op}" + ) + + op_operand_use = users[0] + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + out_transforms = inference_utils.in_transforms_for_operand(consumer, op_user) + + in_transforms = [] + if not transpose: + in_transforms = out_transforms + else: + tile_transform, swizzle_transform = _get_tile_and_swizzle_transforms( + out_transforms + ) + transposed_tiling = mgpu.TileTransformAttr(tile_transform).tiling[::-1] + in_transforms.append(mgpu.TileTransformAttr.get(transposed_tiling)) + in_transforms.append(swizzle_transform) + + return [ir.ArrayAttr.get(in_transforms)], [out_transforms] + + # `memref.load` is used to load barrier phases---the rule needn't do anything # interesting, but we need to have it in order to avoid crashing on it. @partial(_add_transform_inference_rule, memref.LoadOp) diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index bc94d72dc0d8..7e211abb955a 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -34,6 +34,7 @@ from jax.experimental.mosaic import gpu as mgpu from jax.experimental.mosaic.gpu import layouts from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import dialect_lowering as lowering _cext = mgpu.dialect._cext if mgpu.dialect is not None else None @@ -906,6 +907,27 @@ def body(vec1, vec2, ref): with self.assertRaisesRegex(ir.MLIRError, error): self.module.operation.verify() + def test_memref_transforms_with_transpose(self): + with ir.InsertionPoint(self.module.body): + ty_in = ir.MemRefType.get( + (64, 128), + ir.BF16Type.get(), + memory_space=ir.Attribute.parse("#gpu.address_space"), + ) + ref = memref.alloc(ty_in, [], []) + + ref = mgpu_utils.memref_transpose(ref, (1, 0)) + # This tiling is applied to the transposed memref. 
+ transforms = [mgpu.TileTransform(tiling=(16, 32))] + + ref_transformed = lowering.reinterpret_smem_ref(ref, transforms) + ty_transformed = ir.MemRefType(ref_transformed.type) + self.assertEqual(ty_transformed.shape, [8, 2, 16, 32]) + strides, _ = ty_transformed.get_strides_and_offset() + self.assertEqual(strides, [512, 4096, 1, 16]) + + + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 809c9c8fcaeb..10b4f3de60ad 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1551,7 +1551,8 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - def test_realistic_matmul(self): + @parameterized.product(lhs_transpose=[False, True], rhs_transpose=[False, True]) + def test_realistic_matmul(self, lhs_transpose, rhs_transpose): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1561,7 +1562,11 @@ def test_realistic_matmul(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref): # Make sure tiling does not alter the shape of references + if lhs_transpose: + a_ref = plgpu.transpose_ref(a_ref, (1, 0)) assert a_ref.shape == (tile_m, tile_k) + if rhs_transpose: + b_ref = plgpu.transpose_ref(b_ref, (1, 0)) assert b_ref.shape == (tile_k, tile_n) assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) plgpu.wgmma(acc_ref, a_ref, b_ref) @@ -1572,17 +1577,31 @@ def _epilogue(): plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1 key1, key2 = jax.random.split(jax.random.key(42), 2) - a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) - b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + a_shape = (k, m) if lhs_transpose else (m, k) + a = jax.random.uniform(key1, shape=a_shape, dtype=dtype) + b_shape = (n, k) if rhs_transpose else (k, n) + b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - lhs_spec = pl.BlockSpec( - (tile_m, tile_k), - lambda m, n, k: (m, k), - ) - rhs_spec = pl.BlockSpec( - (tile_k, tile_n), - lambda m, n, k: (k, n), - ) + if lhs_transpose: + lhs_spec = pl.BlockSpec( + (tile_k, tile_m), + lambda m, n, k: (k, m), + ) + else: + lhs_spec = pl.BlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + ) + if rhs_transpose: + rhs_spec = pl.BlockSpec( + (tile_n, tile_k), + lambda m, n, k: (n, k), + ) + else: + rhs_spec = pl.BlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + ) out_spec = pl.BlockSpec( (tile_m, tile_n), lambda m, n, k: (m, n), @@ -1627,7 +1646,11 @@ def _epilogue(): delay_release=1, ), )(a, b) - np.testing.assert_allclose(res, a @ b, rtol=1e-3) + np.testing.assert_allclose( + res, + (a.T if lhs_transpose else a) @ (b.T if rhs_transpose else b), + rtol=1e-3, + ) @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): From 245194ffa13f4a7f38e7ab1a30aa46a2d29af3f5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 02:40:01 -0700 Subject: [PATCH 0411/1769] Use `contextlib.nullcontext` instead of `trivial_ctx` I removed `trivial_ctx` from the public `jax.interpreters.partial_eval` submodule without going through a deprecation cycle, because it is highly unlikely anyone is using it. 
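The replacement is behavior-preserving: `trivial_ctx` ignored its argument and yielded nothing, and `contextlib.nullcontext` likewise does no work on enter or exit (it merely returns its argument from `__enter__`). A quick sketch of the equivalence:

    import contextlib

    @contextlib.contextmanager
    def trivial_ctx(_):  # the removed helper, reproduced for comparison
      yield

    params = {"example": True}  # stand-in for eqn.params
    with trivial_ctx(params):
      pass
    with contextlib.nullcontext(params):  # drop-in replacement
      pass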
PiperOrigin-RevId: 744645764 --- jax/_src/interpreters/partial_eval.py | 6 ++---- jax/interpreters/partial_eval.py | 1 - 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0a8e3b7824ff..532eb0f80029 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -15,7 +15,7 @@ from collections import namedtuple from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager +import contextlib from functools import partial import itertools as it import operator as op @@ -1236,14 +1236,12 @@ def _default_res_aval_updater( params: dict[str, Any], aval: AbstractValue) -> AbstractValue: return aval -@contextmanager -def trivial_ctx(_): yield def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx = trivial_ctx, + ctx = contextlib.nullcontext, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b546d774a2e9..a2d988f6bea3 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -81,7 +81,6 @@ trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, - trivial_ctx as trivial_ctx, ) From 695ee8f3d1cc47cb1da286e900b1434d1f0951a2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Apr 2025 05:48:16 -0400 Subject: [PATCH 0412/1769] Fix a race in pjit under free threading. Fixes https://github.com/jax-ml/jax/issues/27767 --- .github/workflows/tsan.yaml | 2 +- jaxlib/xla/pjit.cc | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 1bdb36b2cd03..4c28eab528e4 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -13,7 +13,7 @@ on: - main paths: - '**/workflows/tsan.yaml' - - '**/workflows/tsan-suppressions.txt' + - '**/workflows/tsan-suppressions*.txt' jobs: tsan: diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 503e8ef23f4b..50bdc750d3a4 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -245,16 +245,17 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { std::shared_ptr cache = std::make_shared(&self->lru_list_); auto callback = nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { - std::unique_ptr value; - { - nb::ft_object_guard lock(self); - auto it = self->functions_.find(key); - if (it == self->functions_.end()) { - return; - } - value = std::move(it->second); - self->functions_.erase(it); + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; } + // Remove the value from the map before destroying it. Destroying + // the value may release `lock` since it may call arbitrary Python + // code. 
+ std::unique_ptr value = std::move(it->second); + self->functions_.erase(it); + value.reset(); }); PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); if (weakref) { From ce7dc85104813f153e42f546d698f4147a00795d Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 7 Apr 2025 11:53:20 +0200 Subject: [PATCH 0413/1769] [export] Add support for serializing functions with PRNG keys as inputs/outputs This introduces version 4 of serialization, fully backwards compatible with versions 2 and 3. Fixes: #24143 --- jax/_src/export/serialization.fbs | 6 +++++- jax/_src/export/serialization.py | 7 +++++++ jax/_src/export/serialization_generated.py | 7 +++++-- jax/_src/prng.py | 2 +- tests/export_test.py | 12 ++++++++++++ 5 files changed, 30 insertions(+), 4 deletions(-) diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 7d3e342f1879..01cfa9944dfd 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -45,7 +45,7 @@ enum AbstractValueKind: byte { } enum DType: byte { - // Last used id: 22 + // Last used id: 29 bool = 0, i8 = 1, i16 = 2, @@ -76,6 +76,10 @@ enum DType: byte { f8_e5m2fnuz = 21, f8_e8m0fnu = 25, f4_e2m1fn = 26, + + key_fry = 27, + key_rbg = 28, + key_unsafe_rbg = 29, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 94c0baf642b6..3d878cccc701 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dtypes from jax._src import effects +from jax._src import prng from jax._src import tree_util from jax._src.export import serialization_generated as ser_flatbuf from jax._src.export import _export @@ -48,6 +49,8 @@ # Version 2, Dec 16th, 2023, adds the f0 dtype. # Version 3, October 16th, 2024, adds serialization for namedtuple and custom types # This version is backwards compatible with Version 2. +# Version 4, April 7th, 2025, adds serialization for PRNGs key types. +# This version is backwards compatible with Version 2 and 3. 
_SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: @@ -361,6 +364,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, + + prng.KeyTy(prng.prngs["threefry2x32"]): ser_flatbuf.DType.key_fry, + prng.KeyTy(prng.prngs["rbg"]): ser_flatbuf.DType.key_rbg, + prng.KeyTy(prng.prngs["unsafe_rbg"]): ser_flatbuf.DType.key_unsafe_rbg, } _dtype_kind_to_dtype = { diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index b1fc13333777..34211c1ebe54 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,16 +53,19 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 - f8_e3m4 = 24 - f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 f8_e5m2 = 20 f8_e5m2fnuz = 21 f0 = 22 + f8_e4m3 = 23 + f8_e3m4 = 24 f8_e8m0fnu = 25 f4_e2m1fn = 26 + key_fry = 27 + key_rbg = 28 + key_unsafe_rbg = 29 class ShardingKind(object): diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 35282cb716cb..1dc7e9c0df0e 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -113,7 +113,7 @@ def pprint(self): ])))) -prngs = {} +prngs: dict[str, PRNGImpl] = {} def register_prng(impl: PRNGImpl): if impl.name in prngs: diff --git a/tests/export_test.py b/tests/export_test.py index 2264fbdd997b..26157f2f6a79 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -421,6 +421,18 @@ def f(x1, x2): self.assertEqual(tree_util.tree_structure(res2), tree_util.tree_structure(res)) + @jtu.parameterized_filterable( + kwargs=[dict(impl=p) + for p in ("rbg", "unsafe_rbg", "threefry2x32")]) + def test_prng_keys(self, *, impl): + + key = jax.random.key(42, impl=impl) + @jax.jit + def f(key): + return key + exp_f = get_exported(jax.jit(f))(key) + self.assertEqual(f(key), exp_f.call(key)) + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c From 075d88febc58913ead503c452e0eeafd317fee6f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 7 Apr 2025 03:08:09 -0700 Subject: [PATCH 0414/1769] Fix some test timeouts PiperOrigin-RevId: 744652508 --- tests/BUILD | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 23d59e8d549a..63969bc935da 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -947,7 +947,7 @@ jax_multiplatform_test( }, shard_count = { "cpu": 40, - "gpu": 40, + "gpu": 50, "tpu": 40, }, tags = ["noasan"], # Times out @@ -1132,10 +1132,11 @@ jax_multiplatform_test( backend_tags = { "cpu": [ "noasan", # Times out under asan - "notsan", # Times out under asan + "notsan", # Times out under tsan ], "tpu": [ - "noasan", # Times out under asan. 
+ "noasan", # Times out under asan + "notsan", # Times out under tsan ], }, shard_count = { From 6e93fa34f32c2e57bc3b65948f26eb27d180b9bd Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 03:38:43 -0700 Subject: [PATCH 0415/1769] Removed unused deprecations PiperOrigin-RevId: 744659794 --- jax/_src/deprecations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 37f2f0264782..6c39c893a111 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -127,12 +127,10 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register('jax-nn-one-hot-float-input') register("jax-numpy-astype-complex-to-real") -register("jax-numpy-array-none") register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') -register('pallas-gpu-triton') register('jax-scipy-special-sph-harm') From 9b850a9e9413db077ef74ef6672b9eb36c388fb4 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 7 Apr 2025 03:39:29 -0700 Subject: [PATCH 0416/1769] [Mosaic GPU] Delete mentions of `WGMMARowFragLayout` in `layouts.py`. PiperOrigin-RevId: 744659986 --- jax/experimental/mosaic/gpu/layouts.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index cb94c3eaf749..78c9f670881b 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -171,15 +171,6 @@ def to_layout_attr( ) -_wgmma_row_fragmented_layout_attr_pattern = re.compile( - r"^#mosaic_gpu.WGMMARowFragLayout$" -) - - -def is_wgmma_row_fragmented_layout(attr: ir.Attribute) -> bool: - return bool(_wgmma_row_fragmented_layout_attr_pattern.search(str(attr))) - - def from_layout_attr( attr: ir.Attribute, ) -> ( @@ -194,8 +185,6 @@ def from_layout_attr( return from_strided_fragmented_layout_attr(attr) elif is_tiled_layout(attr): return from_tiled_layout_attr(attr) - elif is_wgmma_row_fragmented_layout(attr): - return fa.WGMMARowFragLayout() else: raise NotImplementedError( f"Unsupported layout for conversion from MLIR attribute: {attr}" From 4596ee3cc5e5970c9f250ad0136a55c6caa3ded0 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 7 Apr 2025 04:12:51 -0700 Subject: [PATCH 0417/1769] Add a missing jaxlib version check in Pallas TPU lowering PiperOrigin-RevId: 744668747 --- jax/_src/pallas/mosaic/lowering.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 87e06f486366..2669d73691c1 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -45,6 +45,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -3579,13 +3580,16 @@ def _dma_start_lowering_rule( sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) + priority_kwarg = {"priority": priority} + if jaxlib_version < (0, 5, 4): + priority_kwarg = 
{}
tpu.enqueue_dma(
src_ref,
dst_ref,
sem,
source_semaphore=src_sem,
device_id=device_id,
- priority=priority,
+ **priority_kwarg,
)
return []

From 153fa228943bedb2420c3e7961a5781e5eef6319 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 7 Apr 2025 04:14:17 -0700
Subject: [PATCH 0418/1769] Add more TSAN skips to avoid timeouts

PiperOrigin-RevId: 744669093
---
 tests/BUILD | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/BUILD b/tests/BUILD
index 63969bc935da..eb6ff81f5d68 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -597,7 +597,10 @@ jax_multiplatform_test(
 srcs = ["lax_scipy_special_functions_test.py"],
 backend_tags = {
 "gpu": ["noasan"], # Times out.
- "cpu": ["noasan"], # Times out.
+ "cpu": [
+ "noasan",
+ "notsan",
+ ], # Times out.
 },
 shard_count = {
 "cpu": 20,

From c2aa811cd6e196a64c3572194e5aa86e4b65f7da Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Mon, 7 Apr 2025 05:49:47 -0700
Subject: [PATCH 0419/1769] `jex.core.Var` is no longer ordered

This behavior was only needed for kfac_jax which has been updated *not*
to rely on variable ordering.

PiperOrigin-RevId: 744691114
---
 jax/_src/core.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/jax/_src/core.py b/jax/_src/core.py
index 9a5a6061cc5e..14ed19d4d441 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -412,7 +412,6 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,

 _var_counter = it.count()

-@total_ordering
 class Var:
 __slots__ = ["count", "suffix", "aval"]

@@ -425,11 +424,6 @@ def __init__(self, suffix: str, aval: AbstractValue):
 self.suffix = suffix
 self.aval = aval

- # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not
- # care about variable ordering, but the downstream package kfac_jax does.
- def __lt__(self, other):
- return self.count < other.count
-
 def __repr__(self):
 return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}'

From 5c0f8858466d31e3678a9bbd16b6c305a38b0aec Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 7 Apr 2025 06:28:35 -0700
Subject: [PATCH 0420/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/145f836bd5175dc5dd262f716a0c59af2b0297a0.

PiperOrigin-RevId: 744700775
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 1a7522bda0fa..a8a93026378f 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "3889bec6b7f48e304953a485b713e9982dff0441"
-XLA_SHA256 = "f23bb226d334f933cd5e6ebc4b20dec9ad879137763975546120ddf582a472b8"
+XLA_COMMIT = "145f836bd5175dc5dd262f716a0c59af2b0297a0"
+XLA_SHA256 = "bd19d8a1d25468696809a69ef3984bb00ef432e3fe9c05116b9c114dc7c83fa2"

 def repo():
 tf_http_archive(

From 83572e17bd7ac833d1e346d91bad27dc4572aad8 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Mon, 7 Apr 2025 07:03:04 -0700
Subject: [PATCH 0421/1769] [Mosaic GPU] Add missing to/from tiled layout
 attributes with replicated lane dimensions.
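For illustration only (the tiling and dimension values below are made up,
only the syntax is the point), a layout with a replicated lane dimension now
prints and parses as:

```
#mosaic_gpu.TiledLayout<[[64, 8]], warp_dim=-4,
 lane_dims=[#mosaic_gpu.Replicated<times = 2>,-1], vector_dim=-1>
```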
PiperOrigin-RevId: 744708476
---
 jax/experimental/mosaic/gpu/layouts.py | 24 +++++++++++++++++++++---
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 11 +++++++++++
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py
index 78c9f670881b..d9b1a01a24b5 100644
--- a/jax/experimental/mosaic/gpu/layouts.py
+++ b/jax/experimental/mosaic/gpu/layouts.py
@@ -107,15 +107,25 @@ def to_tiled_layout_attr(
 ) -> ir.Attribute:
 """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""

+ def _lane_dim_str(d: int | fa.Replicated) -> str:
+ if isinstance(d, fa.Replicated):
+ return f"#mosaic_gpu.Replicated<times = {d.times}>"
+ return str(d)
+
 tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]"
 tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]"
+ lane_dims = "[" + ",".join(_lane_dim_str(d) for d in layout.lane_dims) + "]"
+
 return ir.Attribute.parse(
 f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim},"
- f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>"
+ f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>"
 )


 _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[")
+_replicated_pattern = re.compile(
+ r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P<times>\d+)\s*>\s*$"
+)


 def from_tiled_layout_attr(
@@ -133,6 +143,12 @@ def from_tiled_layout_attr(
 f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}"
 )

+ def _lane_dim(lane_dim_str: str) -> int | fa.Replicated:
+ match = _replicated_pattern.fullmatch(lane_dim_str)
+ if match:
+ return fa.Replicated(int(match.group("times")))
+ return int(lane_dim_str)
+
 tiling_str = match.group("tiling")
 tile_strings = []
 if len(tiling_str) > 2:
@@ -141,8 +157,10 @@
 return fa.TiledLayout(
 tiling=fa.Tiling(tiles),
 warp_dim=int(match.group("warp_dim")),
- lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")),
- vector_dim=int(match.group("vector_dim"))
+ lane_dims=tuple(
+ _lane_dim(s) for s in match.group("lane_dims").split(",")
+ ),
+ vector_dim=int(match.group("vector_dim")),
 )


diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 6b934b951d93..36f9f6f374e5 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -142,6 +142,17 @@ def MosaicGPU_WGSplatFragLayout : AttrDef<MosaicGPU_Dialect, "WGSplatFragLayout"> {
+def MosaicGPU_Replicated : AttrDef<MosaicGPU_Dialect, "Replicated"> {
+ let summary = "Indicates a replicated dimension in a tiled layout.";
+ let description = [{
+ See mosaic/gpu/fragmented_array.py -> Replicated for more details.
+ }];
+
+ let parameters = (ins "int":$times);
+ let mnemonic = "Replicated";
+ let assemblyFormat = "`<` `times` `=` $times `>`";
+}
+
 def MosaicGPU_TiledLayout : AttrDef<MosaicGPU_Dialect, "TiledLayout"> {
 let summary = "A layout derived from a tiling expression.";
 let description = [{

From 70485e31b96a395a58de765b9b6a9260feb9d775 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 7 Apr 2025 07:19:23 -0700
Subject: [PATCH 0422/1769] Remove accidental exports
 jax.interpreters.mlir.{hlo,func_dialect}.

These are available via jax.extend.mlir.dialects.

No deprecation period because jax.interpreters.mlir is not a stable API.
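A sketch of the migration for code that relied on the accidental exports
(a minimal example, assuming the jax.extend.mlir.dialects modules available
in this release):

```python
# Before (worked only by accident):
#   from jax.interpreters.mlir import hlo, func_dialect

# After:
from jax.extend.mlir.dialects import hlo
from jax.extend.mlir.dialects import func as func_dialect
```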
PiperOrigin-RevId: 744712537 --- CHANGELOG.md | 3 +++ jax/_src/cudnn/fused_attention_stablehlo.py | 4 ++-- jax/_src/cudnn/fusion.py | 4 ++-- jax/interpreters/mlir.py | 2 -- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ffd197b390d0..7f19fcb189ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Implemented host callback handlers for CPU and GPU devices using XLA's FFI and removed existing CPU/GPU handlers using XLA's custom call. * All APIs in `jax.lib.xla_extension` are now deprecated. + * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, + which were accidental exports, have been removed. If needed, they are + available from `jax.extend.mlir`. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index c7e7c83f30f8..d901ed875ceb 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -28,8 +28,8 @@ from jax._src import xla_bridge from jax.interpreters import mlir from jax.interpreters import xla -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index f320672463cb..355b33e1509c 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -16,8 +16,8 @@ import jax from jax._src import core as jax_core from jax.interpreters import mlir -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 8a615be968a6..a0505c74f883 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -43,8 +43,6 @@ flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401 flatten_ir_values as flatten_ir_values, unflatten_ir_values_like_types as unflatten_ir_values_like_types, - func_dialect as func_dialect, - hlo as hlo, i32_attr as i32_attr, i64_attr as i64_attr, ir as ir, From 412f88e2234c4b82f18e43bdad7bf64a32ff94a5 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 7 Apr 2025 07:20:00 -0700 Subject: [PATCH 0423/1769] Temporarily skip JaxNumpyErrorTests in multi-thread environments PiperOrigin-RevId: 744712701 --- tests/jax_numpy_error_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index a38e7d5509f9..dba277289f42 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -30,6 +30,12 @@ class JaxNumpyErrorTests(jtu.JaxTestCase): + def setUp(self): + # TODO(b/408148001): Fix thread safety issue. 
+ if jtu.TEST_NUM_THREADS.value > 1: + self.skipTest("Test does not work with multiple threads") + super().setUp() + @parameterized.product(jit=[True, False]) def test_set_error_if_nan(self, jit): def f(x): From a099b285307508efad12a015d6f6d9d13ae49077 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 7 Apr 2025 07:39:14 -0700 Subject: [PATCH 0424/1769] Reverts 735cec18cb2f8dff2aea5e503fd886a37aee094e PiperOrigin-RevId: 744717457 --- jaxlib/cuda/BUILD | 1 - jaxlib/gpu/py_client_gpu.cc | 88 ++++++++++++-------------------- jaxlib/rocm/BUILD | 1 - jaxlib/xla/BUILD | 1 - jaxlib/xla/py_client_cpu.cc | 87 +++++++++----------------------- tests/python_callback_test.py | 94 +++++++++++++++-------------------- 6 files changed, 95 insertions(+), 177 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index d7035c92b24a..be7ac6116d2f 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -637,7 +637,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index e3aec51d8d25..861ffce3e749 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -44,7 +44,6 @@ limitations under the License. #include "xla/python/types.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" -#include "xla/util.h" namespace nb = nanobind; @@ -82,7 +81,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::U1) { + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -112,6 +112,9 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -119,23 +122,6 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - // We pass in data using default numpy layout i.e., std::nullopt. - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI argument and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. 
- size_t packed_size = arg->size_bytes() * bits_per_element / 8; - auto buffer = xla::UnpackIntN( - bits_per_element, static_cast(host_input_buffers[i]), - packed_size); - delete[] static_cast(host_input_buffers[i]); - host_input_buffers[i] = buffer.release(); - } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); @@ -160,7 +146,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::U1) { + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -181,45 +168,32 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - - const void* data = array.data(); - size_t size_bytes = array.size() * array.itemsize(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - void* temp = new char[size_bytes]; - temp_buffers.push_back(temp); - plan->Execute(data, temp); - data = temp; + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + continue; } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. 
- buffer = xla::PackIntN(bits_per_element, static_cast(data), - size_bytes); - data = buffer.get(); - size_bytes = (size_bytes * bits_per_element) / 8; + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); } - - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes, + auto plan = maybe_plan.value(); + plan->Execute(array.data(), temp); + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 5893af26de85..94d75d9c19ae 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -539,7 +539,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 8602652cbd8a..c5d151b2cd3b 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -640,7 +640,6 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index fef6a54aab2d..ac4e7bee5680 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -80,7 +79,8 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. 
- if (ptype == S1 || ptype == U1) { + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -96,20 +96,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - std::unique_ptr buffer; - const void* data = arg->untyped_data(); - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI argument and return buffers are sized assuming - size_t packed_size = arg->size_bytes() * bits_per_element / 8; - buffer = xla::UnpackIntN(bits_per_element, static_cast(data), - packed_size); - data = buffer.get(); - } // We pass in data using default numpy layout i.e., std::nullopt. auto array = - nb_numpy_ndarray(dtype, dims, std::nullopt, data); + nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -130,8 +119,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - if (ptype == S1 || ptype == U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -151,55 +141,26 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - - const void* data = array.data(); - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - size_t size_bytes = array.size() * array.itemsize(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer - // supplied by FFI directly. - buffer = std::make_unique(size_bytes); - plan->Execute(data, buffer.get()); - data = buffer.get(); - } else { - plan->Execute(data, ret->untyped_data()); - data = ret->untyped_data(); - } - } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. 
- if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. - buffer = xla::PackIntN(bits_per_element, static_cast(data), - size_bytes); - data = buffer.get(); - size_bytes = (size_bytes * bits_per_element) / 8; + if (strides == expected_strides) { + std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); + continue; } - - // Copy data to output buffer if haven't already or modified the data to - // write back. - if (data != ret->untyped_data()) { - std::memcpy(ret->untyped_data(), data, size_bytes); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions_size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 34ab20c05644..a8442b4a1356 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,15 +586,10 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(x): return x def f(x): @@ -605,17 +600,21 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)(x) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." 
- ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(): return np.arange(8, dtype=dtype) @@ -626,43 +625,16 @@ def f(): ) return y - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") - def test_non_default_stride_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) - x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)() class PureCallbackTest(jtu.JaxTestCase): @@ -1136,6 +1108,20 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class IOCallbackTest(jtu.JaxTestCase): From 5a3fc606d47148cf47a96172ff67b1535182968b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 7 Apr 2025 07:57:17 -0700 Subject: [PATCH 0425/1769] Deprecate public export of mlir.custom_call. PiperOrigin-RevId: 744722183 --- CHANGELOG.md | 2 ++ jax/_src/cudnn/fused_attention_stablehlo.py | 11 +++++------ jax/interpreters/mlir.py | 22 ++++++++++++++++++++- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f19fcb189ed..21398b31cafb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, which were accidental exports, have been removed. If needed, they are available from `jax.extend.mlir`. + * `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by + {mod}`jax.ffi` should be used instead. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. 
diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index d901ed875ceb..818bc018cdf5 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -24,10 +24,9 @@ from jax._src import dispatch from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lib import cuda_versions from jax._src import xla_bridge -from jax.interpreters import mlir -from jax.interpreters import xla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp @@ -1018,7 +1017,7 @@ def sharded_impl(*args): _dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") _dot_product_attention_fwd_p.multiple_results = True _dot_product_attention_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fwd_p) ) _dot_product_attention_fwd_p.def_abstract_eval( _dot_product_attention_fwd_abstract @@ -1043,7 +1042,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") _dot_product_attention_bwd_p.multiple_results = True _dot_product_attention_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_bwd_p) ) _dot_product_attention_bwd_p.def_abstract_eval( _dot_product_attention_bwd_abstract @@ -1604,7 +1603,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd") _dot_product_attention_fp8_fwd_p.multiple_results = True _dot_product_attention_fp8_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_fwd_p) ) _dot_product_attention_fp8_fwd_p.def_abstract_eval( _dot_product_attention_fp8_fwd_abstract @@ -1629,7 +1628,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd") _dot_product_attention_fp8_bwd_p.multiple_results = True _dot_product_attention_fp8_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_bwd_p) ) _dot_product_attention_fp8_bwd_p.def_abstract_eval( _dot_product_attention_fp8_bwd_abstract diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index a0505c74f883..10c8d1e9e671 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -33,7 +33,7 @@ aval_to_ir_type as aval_to_ir_type, aval_to_ir_types as aval_to_ir_types, core_call_lowering as core_call_lowering, - custom_call as custom_call, + custom_call as _custom_call, dense_bool_elements as dense_bool_elements, dense_bool_array as dense_bool_array, dense_int_array as dense_int_array, @@ -77,3 +77,23 @@ from jax._src.callback import ( emit_python_callback as emit_python_callback, ) + +_deprecations = { + # Added Apr 7 2025 + "custom_call": ( + "mlir.custom_call is deprecated; use the APIs provided by jax.ffi instead.", + _custom_call, + ) +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + custom_call = _custom_call +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del 
_deprecation_getattr +del _typing +del _custom_call From dbc3bcd3cebdacc3e0ef8ef717807cac170635eb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 4 Apr 2025 08:47:59 -0400 Subject: [PATCH 0426/1769] Apply forwarding in pjit linearization rule to avoid intermediate copies. --- jax/_src/interpreters/ad.py | 9 +++++++++ jax/_src/pjit.py | 18 +++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index e47e518a11f2..1824c39f03fe 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -194,6 +194,11 @@ def new_arg(trace, primal_aval, nz): tangent_trace.invalidate() if attrs_tracked: raise NotImplementedError("TODO: attrs") + tangent_jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + tangent_jaxpr, [True] * len(tangent_jaxpr.outvars), + [False] * len(tangent_jaxpr.constvars) + [True] * len(tangent_jaxpr.invars)) + tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used] + residuals_and_primals = (*tangent_consts, *out_primals) residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) @@ -871,6 +876,10 @@ def make_zero(aval): for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + out_consts = [c for used, c in zip(used_consts, out_consts) if used] def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index af744ae5db96..8c3c5101eb51 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2076,14 +2076,22 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) - # constvars will become residuals. Move them to the end of the ordinary args. 
res_shardings = (UNSPECIFIED,) * num_residuals res_layouts = (None,) * num_residuals res_donated = (False,) * num_residuals + + in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) + in_fwd, _ = split_list(in_fwd, [num_residuals]) + keep = tuple(f is None for f in in_fwd) + (True,) * len(out_shardings) + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + num_residuals = sum(f is None for f in in_fwd) + def tangent_fun(consts_, *tangents): + consts_it = iter(consts_) + res = [next(consts_it) if f is None else primals_in[f] for f in in_fwd] + assert next(consts_it, None) is None tangents_nz = _filter_zeros(nzs, tangents) - assert len(consts_) == num_residuals - nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_), + nz_tangents_out = pjit_p.bind(*(*tangents_nz, *res), jaxpr=tangent_jaxpr, in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, out_shardings=_filter_zeros(nzs_out, out_shardings), @@ -2106,9 +2114,9 @@ def _filter_zeros(is_nz_l, l): ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, in_shardings=in_shardings, - out_shardings=(*res_shardings, *out_shardings), + out_shardings=(*res_shardings[:num_residuals], *out_shardings), in_layouts=in_layouts, - out_layouts=(*res_layouts, *out_layouts), + out_layouts=(*res_layouts[:num_residuals], *out_layouts), donated_invars=donated_invars, ctx_mesh=ctx_mesh, name=name, From ff00fa91cecd2b21f866559c5fd07061a335899a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 09:48:06 -0700 Subject: [PATCH 0427/1769] Removed unused `jax_remat_opt_barrier` config option It defaults to True and is not flipped to False by any internal JAX users. PiperOrigin-RevId: 744754343 --- jax/_src/ad_checkpoint.py | 86 ++++---------------- jax/_src/config.py | 7 -- jax/experimental/jax2tf/jax2tf.py | 11 ++- jax/experimental/jax2tf/tests/jax2tf_test.py | 6 +- 4 files changed, 21 insertions(+), 89 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index c2868cf7c078..e5390be4cfe0 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -757,89 +757,34 @@ def _has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) -def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, is_gpu_platform: bool = False, - **_): +def remat_expansion( + *args, jaxpr: core.Jaxpr, prevent_cse: bool, differentiated: bool, **_ +): assert not jaxpr.constvars if differentiated and prevent_cse: - if config.remat_opt_barrier.value: - translation_rule = _remat_translation_using_opt_barrier - elif is_gpu_platform: - translation_rule = _remat_translation_using_while - else: - translation_rule = _remat_translation_using_cond + translation_rule = _remat_translation_using_opt_barrier else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) + def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): args = lax_internal.optimization_barrier(args) return core.eval_jaxpr(jaxpr, (), *args) -# TODO(mattjj): add core utility for 'create dummy value for this type'? 
-def _dummy_like(aval: core.AbstractValue) -> Any: - if aval is core.abstract_token: - return lax_internal.create_token() - elif isinstance(aval, (core.ShapedArray, core.DShapedArray)): - return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore - else: - raise ValueError(aval) - -def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): - # Implements: - # for(counter=0, result=0; counter < rng(1, 2); counter ++) { - # result = eval_jaxpr(*args) - # } - # The loop carry is a tuple: (counter, result, args) - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args) - def cond(carry): - counter, _, _ = carry - unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=()) - return counter < unif - - def body(carry): - counter, _, args = carry - results = core.eval_jaxpr(jaxpr, (), *args) - return (counter + 1, tuple(results), args) - - carry_res = lax_control_flow.while_loop(cond, body, carry_init) - return carry_res[1] - -def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): - # Implements: - # if(rng(0, 1) < 2) - # return eval_jaxpr(*args) - # else: - # return 0 - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - - def remat_comp(*args): - return tuple(core.eval_jaxpr(jaxpr, (), *args)) - def dummy_comp(*args): - return tuple(map(_dummy_like, avals_out)) - - unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=()) - return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) - -def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, policy, is_gpu_platform=False): + +def _remat_lowering( + ctx, + *args, + jaxpr: core.Jaxpr, + prevent_cse: bool, + differentiated: bool, + policy, +): jaxpr_args: Sequence[mlir.IrValues] if differentiated and prevent_cse: - # If we're using the loop or cond lowerings, use the slower lower_fun - # based path. - if not config.remat_opt_barrier.value: - return mlir.lower_fun(remat_expansion, multiple_results=True)( - ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse, - differentiated=differentiated, policy=policy, - is_gpu_platform=is_gpu_platform) - arg_types = map(mlir.aval_to_ir_type, ctx.avals_in) flat_args = mlir.flatten_ir_values(args) barrier_op = hlo.OptimizationBarrierOp(flat_args) @@ -853,9 +798,8 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, ctx.set_tokens_out(tokens_out) return outs + mlir.register_lowering(remat_p, _remat_lowering) -mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), - platform="gpu") def checkpoint_name(x, name): diff --git a/jax/_src/config.py b/jax/_src/config.py index b4a12dcc1762..aca6d8e2c938 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1512,13 +1512,6 @@ def _update_disable_jit_thread_local(val): help=('Attempt constant folding during staging.'), include_in_jit_key=True) -# This flag is temporary during rollout of the remat barrier. -# TODO(parkers): Remove if there are no complaints. 
-remat_opt_barrier = bool_state( - name='jax_remat_opt_barrier', - default=True, - help=('Enables using optimization-barrier op for lowering remat.')) - enable_remat_opt_pass = bool_state( name='jax_compiler_enable_remat_pass', default=True, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 492e070de1af..ce57bdad5311 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3173,12 +3173,11 @@ def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal: lax_control_flow._scan_impl, extra_name_stack="scan") -tf_impl_with_avals[ad_checkpoint.remat_p] = \ - _convert_jax_impl(partial(ad_checkpoint.remat_expansion, - # TODO: jax2tf cannot discriminate by platform - is_gpu_platform=False), - multiple_results=True, - extra_name_stack="checkpoint") +tf_impl_with_avals[ad_checkpoint.remat_p] = _convert_jax_impl( + ad_checkpoint.remat_expansion, + multiple_results=True, + extra_name_stack="checkpoint", +) tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index bea2b76cb7cf..b40b1a6d5571 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -832,11 +832,7 @@ def f(x1): arg = np.array(3.) f_tf = jax2tf.convert(jax.grad(remat_f)) f_tf_hlo = self.TfToHlo(f_tf, arg) - if config.remat_opt_barrier.value: - self.assertRegex(f_tf_hlo, r"opt-barrier") - else: - self.assertRegex(f_tf_hlo, - r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') + self.assertRegex(f_tf_hlo, r"opt-barrier") def test_remat_free_var(self): def f(x): From 51c224c446df943b565058144b23eeaf8966009d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 09:49:28 -0700 Subject: [PATCH 0428/1769] Removed deprecated `jax.core.{full_lower,jaxpr_as_fun,lattice_join}` PiperOrigin-RevId: 744754730 --- CHANGELOG.md | 4 ++-- jax/_src/core.py | 5 ----- jax/core.py | 13 ++++--------- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21398b31cafb..beacd477390f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,8 +49,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, - `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `get_referent`, - `join_effects`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, + `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most have no public replacement, though a few are available at {mod}`jax.extend.core`. diff --git a/jax/_src/core.py b/jax/_src/core.py index 14ed19d4d441..9f80842a38ff 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1496,11 +1496,6 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -# TODO(dougalm): Deprecate. This is here for backwards compat. 
-def lattice_join(x, y): - assert typematch(x, y) - return x - # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any diff --git a/jax/core.py b/jax/core.py index 688fa14d9ccf..9702798d9af9 100644 --- a/jax/core.py +++ b/jax/core.py @@ -97,13 +97,11 @@ "typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck), "typematch": ("jax.core.typematch is deprecated.", _src_core.typematch), # Added 2024-12-10 - "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.full_lower), - "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " + "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", None), + "jaxpr_as_fun": ("jax.core.jaxpr_as_fun was removed in JAX v0.6.0. Use jax.extend.core.jaxpr_as_fun instead, " "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.jaxpr_as_fun), - "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.lattice_join), + None), + "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", None), # Finalized 2025-03-25 for JAX v0.6.0; remove after 2025-06-25 "AxisSize": ("jax.core.AxisSize was removed in JAX v0.6.0.", None), "ClosedJaxpr": ("jax.core.ClosedJaxpr was removed in JAX v0.6.0. Use jax.extend.core.ClosedJaxpr instead, " @@ -152,10 +150,7 @@ axis_frame = _src_core.axis_frame call_p = _src_core.call_p closed_call_p = _src_core.closed_call_p - full_lower = _src_core.full_lower get_type = _src_core.get_aval - jaxpr_as_fun = _src_core.jaxpr_as_fun - lattice_join = _src_core.lattice_join trace_state_clean = _src_core.trace_state_clean typecheck = _src_core.typecheck typematch = _src_core.typematch From 855829e1bcf2fbdbe183469350108de50b4cf872 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 7 Apr 2025 10:51:46 -0700 Subject: [PATCH 0429/1769] Add int4, uint4 to test_util.suppported_types To increase test coverage for these types. 
PiperOrigin-RevId: 744777880 --- jax/_src/test_util.py | 14 ++++++++++---- jaxlib/xla/py_values.cc | 2 ++ jaxlib/xla/xla_client.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c3f7fb4c4139..5a2eaabd0f02 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -57,6 +57,7 @@ from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir +from jax._src.lib import jaxlib_extension_version from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, @@ -376,10 +377,13 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": - types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, + _dtypes.bfloat16, np.float16, np.float32, np.complex64, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e5m2} + if jaxlib_extension_version < 327: + types -= {_dtypes.int4, _dtypes.uint4} elif device_under_test() == "gpu": types = {np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -389,10 +393,12 @@ def supported_dtypes(): elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: - types = {np.bool_, np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, np.int64, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64, _dtypes.bfloat16, np.float16, np.float32, np.float64, np.complex64, np.complex128} + if jaxlib_extension_version < 327: + types -= {_dtypes.int4, _dtypes.uint4} if not config.enable_x64.value: types -= {np.uint64, np.int64, np.float64, np.complex128} return types diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index e13a38197c0a..709f3cb3b2ef 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -684,10 +684,12 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, // float64_dt and complex128_dt which are taken care of in previous if // blocks. (*p)[dtypes.np_bool.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int4.ptr()] = numpy_array_handler; (*p)[dtypes.np_int8.ptr()] = numpy_array_handler; (*p)[dtypes.np_int16.ptr()] = numpy_array_handler; (*p)[dtypes.np_int32.ptr()] = numpy_array_handler; (*p)[dtypes.np_int64.ptr()] = np_int_handler; + (*p)[dtypes.np_uint4.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 58e0cb070e29..fa31d1764de2 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 326 +_version = 327 # An internal increasing version number for protecting jaxlib code against # ifrt changes. 
From 7239487ccc635e2374073c167b773d0627b070a9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Apr 2025 18:06:56 +0000 Subject: [PATCH 0430/1769] Bump medyagh/setup-minikube from 0.0.18 to 0.0.19 Bumps [medyagh/setup-minikube](https://github.com/medyagh/setup-minikube) from 0.0.18 to 0.0.19. - [Release notes](https://github.com/medyagh/setup-minikube/releases) - [Commits](https://github.com/medyagh/setup-minikube/compare/d8c0eb871f6f455542491d86a574477bd3894533...cea33675329b799adccc9526aa5daccc26cd5052) --- updated-dependencies: - dependency-name: medyagh/setup-minikube dependency-version: 0.0.19 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/k8s.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index a96ce1ead26c..1042388fe9c6 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -38,7 +38,7 @@ jobs: path: jax - name: Start Minikube cluster - uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533 # ratchet:medyagh/setup-minikube@v0.0.18 + uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 - name: Install K8s Jobset run: | From fcf5115fdbd216a44c2daf4860fc241cb8ac4f8b Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 7 Apr 2025 11:10:52 -0700 Subject: [PATCH 0431/1769] [Pallas Fuser] Add output_fusion_mask support Currently, the fusion API assumes by default that all of the outputs of a @fuse-decorated function are computed jointly in one big output fusion. For example, in the following snippet ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return g(z1, z2) ``` it assumes that `g` is a single function that operates on z1 and z2 jointly. However, in practice, the fusable may want two separate output fusions: ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return g1(z1), g2(z2) ``` This is a special case of the general function but the fusable may not be materializing z1 and z2 at the same time so may not be able to compute this efficiently with a single function g. By decorating a fusable with an output fusion prefix (in the above example `(True, True)`), the fusable will now be given a pair of functions `g1` and `g2` if the output fusion is "separable". For example, we'd error for the following example: ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return z1 + z2 ``` because z1 and z2 interact with each other in the output fusion. The rationale for providing a PyTree prefix (as opposed to a more general mechanism) is that the fusable can group its outputs into subtrees that it can identify with the output prefix. This does restrict the types of output groups that are possible (outputs must be part of the same shared subtree, as opposed to arbitrarily scattered throughput the output pytree), but this is an okay restriction because the fusable author is responsible for the grouping and can always construct it that way. 
PiperOrigin-RevId: 744784770 --- jax/_src/pallas/fuser/BUILD | 1 + jax/_src/pallas/fuser/fusable.py | 59 +++--- jax/_src/pallas/fuser/jaxpr_fusion.py | 203 ++++++++++++++++++--- tests/pallas/BUILD | 21 +++ tests/pallas/fusion_test.py | 232 ++++++++++++++++++++++++ tests/pallas/tpu_fusable_matmul_test.py | 9 +- 6 files changed, 469 insertions(+), 56 deletions(-) create mode 100644 tests/pallas/fusion_test.py diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 66bbac33aabb..8339ad6705ff 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -99,6 +99,7 @@ pytype_strict_library( "//jax:core", "//jax:partial_eval", "//jax:tree_util", + "//jax:util", ], ) diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusable.py index b075c6d136c9..aa2ea0843c0a 100644 --- a/jax/_src/pallas/fuser/fusable.py +++ b/jax/_src/pallas/fuser/fusable.py @@ -13,6 +13,7 @@ # limitations under the License. """Fusable primitive.""" +from typing import Any import jax from jax._src import api_util @@ -40,32 +41,38 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: ) -def fusable(f): - def wrapper(*args): - def wrapped(*args): - in_fusions = tree_util.tree_map(_make_trivial_fusion, args) - return f(*in_fusions, None) - - flat_args, in_tree = tree_util.tree_flatten(args) - debug_info = api_util.debug_info('fusable', wrapped, args, {}) - flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(wrapped, debug_info=debug_info), in_tree - ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - out = fusable_p.bind( - *consts, - *flat_args, - jaxpr=jaxpr, - num_consts=len(consts), - in_tree=in_tree, - out_tree=out_tree, - func=f, - ) - return tree_util.tree_unflatten(out_tree, out) - - return wrapper +def fusable(f=None, *, output_fusion_prefix: Any = True): + def decorator(f): + def wrapper(*args): + def wrapped(*args): + in_fusions = tree_util.tree_map(_make_trivial_fusion, args) + return f(*in_fusions, None) + + flat_args, in_tree = tree_util.tree_flatten(args) + debug_info = api_util.debug_info('fusable', wrapped, args, {}) + flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(wrapped, debug_info=debug_info), in_tree + ) + flat_avals = [_get_aval(x) for x in flat_args] + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree_thunk() + out = fusable_p.bind( + *consts, + *flat_args, + jaxpr=jaxpr, + num_consts=len(consts), + in_tree=in_tree, + out_tree=out_tree, + func=f, + output_fusion_prefix=output_fusion_prefix, + ) + return tree_util.tree_unflatten(out_tree, out) + + return wrapper + + if f is not None: + return decorator(f) + return decorator @fusable_p.def_impl diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 3d36b8f3e2fd..649037e18092 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -14,15 +14,15 @@ """Fuses a function.""" +from collections.abc import Sequence +import functools from typing import Any - import jax from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu from jax._src import tree_util from jax._src.interpreters import partial_eval as pe - from jax._src.pallas.fuser import fusable_dtype from jax._src.pallas.fuser import fusion as fusion_lib from jax._src.pallas.fuser.fusable import fusable_p @@ -73,9 +73,9 @@ def 
wrapper(*args, **kwargs): _fusable: dict[jax_core.Primitive, Any] = {} -def construct_fusion( +def _construct_fusion_jaxpr( candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs -) -> fusion_lib.Fusion: +): flat_outvars, out_tree = tree_util.tree_flatten(outvars) flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs)) new_jaxpr_no_dce = jaxpr.replace( @@ -94,12 +94,6 @@ def construct_fusion( c for used, c in zip(used_consts, candidate_values, strict=True) if used ) kernel_in_tree = tree_util.tree_structure((invars, kwargs)) - - def _fn(*args, **kwargs): - flat_args, _ = tree_util.tree_flatten((args, kwargs)) - out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) - return tree_util.tree_unflatten(out_tree, out_flat) - flat_in_type = [ jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars ] @@ -108,9 +102,158 @@ def _fn(*args, **kwargs): out_tree, [jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars], ) + return new_jaxpr, new_values, in_type, out_type, out_tree + + +def construct_fusion( + candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs +) -> fusion_lib.Fusion: + new_jaxpr, new_values, in_type, out_type, out_tree = _construct_fusion_jaxpr( + candidate_values, jaxpr, outvars, *invars, **kwargs + ) + + def _fn(*args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + return fusion_lib.Fusion(_fn, in_type, out_type) +def _find_downstream( + jaxpr: jax_core.Jaxpr, in_used: Sequence[bool] +) -> tuple[bool, ...]: + # TODO(sharadmv): We use partial_eval to query downstream dependencies which + # is not an officially sanctioned way to do so, since PE is really used for + # AD. In the future, we should have a special Jaxpr API that queries this. + _, _, out_used, *_ = pe.partial_eval_jaxpr_custom( + jaxpr, + in_unknowns=in_used, + in_inst=in_used, + ensure_out_unknowns=False, + ensure_out_inst=False, + saveable=lambda *_, **__: False, + ) + return tuple(out_used) + + +def _construct_output_permutation( + used: list[tuple[bool, ...]], +) -> list[int]: + order = [] + for u in used: + true_vals = [i for i in range(len(u)) if u[i]] + order.extend(true_vals) + return [order.index(i) for i in range(len(order))] + + +def _construct_output_fusions( + candidate_values, + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn_outvars, # Flat list of vars output by the fusable eqn + fusion_eqn_out_tree, # Tree structure of the fusable eqn outputs + output_fusion_prefix, # Pytree defining output groups +): + # 1. Create jaxpr_out: represents computation *after* the fusable + # Inputs: fusion_eqn_outvars + # Outputs: jaxpr.outvars + jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr( + candidate_values, + jaxpr.replace( + eqns=jaxpr.eqns[:fusion_eqn_index] + + jaxpr.eqns[fusion_eqn_index + 1 :] + ), + tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs + tree_util.tree_unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ), # Fusable outputs as inputs + ) + + # 2. Group fusable outputs based on the mask + unflat_fusable_outvars = jax.tree.unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ) + partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to( + unflat_fusable_outvars + ) + + # 3. 
Calculate dependencies and check disjointness + downstream_outputs_used_masks = [] # List of bool tuples, one per group + already_used_final_outputs = set() # Indices of final outputs already claimed + for outvars_group in partial_flat: + # Identify vars in this group + used_fusable_outvars = set(jax.tree.leaves(outvars_group)) + # Create mask for jaxpr_out inputs corresponding to this group + in_used_mask = [ + True if v in used_fusable_outvars else False for v in jaxpr_out.invars + ] + # Trace dependencies through jaxpr_out to find which final outputs are affected + downstream_used_mask = _find_downstream( + jaxpr_out, in_used_mask + ) # Mask for jaxpr_out.outvars (== jaxpr.outvars) + + # Check for overlap in final output usage across groups + for i, used in enumerate(downstream_used_mask): + if used: + if i in already_used_final_outputs: + raise ValueError( + "Outputs must be disjoint in order to use separate output fusions" + ) + already_used_final_outputs.add(i) + downstream_outputs_used_masks.append(downstream_used_mask) + + # 4. Construct output permutation needed to restore original output order + output_permutation = _construct_output_permutation( + downstream_outputs_used_masks + ) + + # Construct fusions for each group by DCEing the jaxpr_out + output_fusions = [] + for i, outvars_group in enumerate(partial_flat): + flat_group_vars, _ = tree_util.tree_flatten(outvars_group) + downstream_used_mask = downstream_outputs_used_masks[i] + + used_jaxpr_invars = [False] * len(all_values) + [ + v in flat_group_vars for v in jaxpr_out.invars + ] + jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars + ) + values_for_jaxpr = tuple( + c for used, c in zip(used_consts, all_values, strict=True) if used + ) + + def _fn(jaxpr, vals, *args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args) + return tuple(out_flat) + + fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr) + in_type = jax.tree.map( + lambda v: jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype), # pytype: disable=attribute-error + outvars_group, + ) + out_type = tuple( + jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype) # pytype: disable=attribute-error + for v in jaxpr_out_for_group.outvars + ) + fusion = fusion_lib.Fusion( + fn, + (in_type, {}), + out_type, + ) + output_fusions.append(fusion) + + return ( + tree_util.tree_unflatten( + tree_util.tree_structure(output_fusion_prefix), output_fusions + ), + output_permutation, + ) + + def fuse_jaxpr( jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args ): @@ -125,6 +268,15 @@ def fuse_jaxpr( raise ValueError("No fusable eqn found") fusion_eqn = jaxpr.eqns[fusion_eqn_index] + # Now let's check if we need to do any fusion at all, e.g. do the outputs of + # the jaxpr have any dependence on the fusion at all? We can DCE the jaxpr + # with all the inputs and outputs to check if there is a dependence. + dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), + instantiate=True) + if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns): + # Short circuit if there is nothing to fuse. + return jax_core.eval_jaxpr(dced_jaxpr, consts, *args) + candidate_values = [*consts, *args] # Construct fusions for non-constant inputs to the fusable. 
@@ -141,21 +293,20 @@ def fuse_jaxpr( in_fusions = tree_util.tree_unflatten( fusion_eqn.params["in_tree"], in_fusions_flat ) - out_fusion = construct_fusion( + output_fusions, output_permutation = _construct_output_fusions( candidate_values, - jaxpr.replace( - eqns=jaxpr.eqns[:fusion_eqn_index] - + jaxpr.eqns[fusion_eqn_index + 1 :] - ), - tree_util.tree_unflatten(out_tree, jaxpr.outvars), - tree_util.tree_unflatten( - fusion_eqn.params["out_tree"], fusion_eqn.outvars - ), + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn.outvars, + fusion_eqn.params["out_tree"], + fusion_eqn.params["output_fusion_prefix"], ) - # Run the fusable. - out = fusion_eqn.params["func"](*in_fusions, out_fusion) - - # Now return the flattened output (the fuse_jaxpr caller should unflatten). - out_flat = tree_util.tree_leaves(out) - assert len(out_flat) == len(jaxpr.outvars) - return out_flat + out = fusion_eqn.params["func"](*in_fusions, output_fusions) + flat_out = jax.tree.leaves(out) + permuted_out = [flat_out[i] for i in output_permutation] + assert len(permuted_out) == len(jaxpr.outvars), ( + len(permuted_out), + len(jaxpr.outvars), + ) + return permuted_out diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 1ea05c700938..ba5d9d5f4ae7 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -680,6 +680,27 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "fusion_test", + srcs = [ + "fusion_test.py", + ], + disable_configs = [ + "cpu", + "cpu_shardy", + ], + enable_backends = ["cpu"], + tags = [ + "noasan", + "nomsan", + "notsan", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_fuser", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "tpu_fusable_matmul_test", srcs = ["tpu_fusable_matmul_test.py"], diff --git a/tests/pallas/fusion_test.py b/tests/pallas/fusion_test.py new file mode 100644 index 000000000000..2edcf78f1aba --- /dev/null +++ b/tests/pallas/fusion_test.py @@ -0,0 +1,232 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas import fuser +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + + +class FusionTest(jtu.JaxTestCase): + + def test_basic_fusion(self): + + @jax.jit + @fuser.fuse + @fuser.fusable + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + np.testing.assert_array_equal(f(x), x) + + def test_separate_output_fusions_trivial(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x, y = f(x, y) + return x, y * 2 + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + x_out, y_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_should_error_if_not_disjoint(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return x_res + y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + + with self.assertRaisesRegex( + ValueError, + "Outputs must be disjoint in order to use separate output fusions", + ): + g(x, y) + + def test_separate_output_fusions_allows_permute(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res * 2, x_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, x_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_with_nesting(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return (x_res * 2, x_res + x_res), y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + (x1_out, x2_out), y_out = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_nesting_and_permutation(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res, (x_res * 2, x_res + x_res) + + x = 
jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_deep_output_mask(self): + + @fuser.fusable(output_fusion_prefix=(True, (True, True))) + def f(x_fn, y_fn, z_fn, o_fns): + x = x_fn() + y = y_fn() + z = z_fn() + if o_fns is None: + o_fns = lambda x: x, (lambda x: x, lambda x: x) + o_fn1, (o_fn2, o_fn3) = o_fns + return o_fn1(x), (o_fn2(y), o_fn3(z)) + + @jax.jit + @fuser.fuse + def g(x, y, z): + x_res, (y_res, z_res) = f(x, y, z) + return (x_res * 2, (y_res, z_res + z_res)) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + z = jax.random.normal(jax.random.key(1), (128, 1), dtype=jnp.float32) + x_out, (y_out, z_out) = g(x, y, z) + np.testing.assert_array_equal(x_out, x * 2) + np.testing.assert_array_equal(y_out, y) + np.testing.assert_array_equal(z_out, z + z) + + def test_separate_output_fusions_with_reused_value(self): + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y, a): + x_res, y_res = f(x, y) + return y_res + a, (x_res * 2, x_res + x_res + a) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y, a) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x + a) + np.testing.assert_array_equal(y_out, y + a) + + def test_empty_fusion(self): + @fuser.fusable + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + @jax.jit + @fuser.fuse + def g(x, a): + _ = f(x) + return a + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + y_out = g(x, a) + np.testing.assert_array_equal(y_out, a) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusable_matmul_test.py index 5ee372ce92ab..93523b174774 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusable_matmul_test.py @@ -71,7 +71,8 @@ def _(): def _(): acc = acc_ref[...].astype(out_dtype) z_values = jax.tree.map(lambda ref: ref.get(), z_value_refs) - o_ref[...] 
= z_fn(pids, scalar_prefetch, z_values, acc)
+      out = z_fn(pids, scalar_prefetch, z_values, acc)
+      jax.tree.map(lambda ref, x: ref.set(x), o_ref, out)


 def _fusable_matmul(
@@ -174,12 +175,12 @@ def z_index_map(i, j, k, *_):
             y_value_block_specs,
             z_value_block_specs,
         ],
-        out_specs=z_out_block_spec,
+        out_specs=[z_out_block_spec],
     ),
     compiler_params=pltpu.TPUCompilerParams(
         dimension_semantics=dimension_semantics,
     ),
-    out_shape=z_out_type,
+    out_shape=[z_out_type],
     interpret=interpret,
     debug=debug,
 )(
@@ -187,7 +188,7 @@ def z_index_map(i, j, k, *_):
     x_values,
     y_values,
     z_values,
- )
+ )[0]


 def fusable_matmul(

From e1e37f8d5e80597dbe7a3b447f8c29ba1575ee55 Mon Sep 17 00:00:00 2001
From: Jevin Jiang
Date: Mon, 7 Apr 2025 11:43:26 -0700
Subject: [PATCH 0432/1769] [Mosaic TPU] FWD compatibility needs to keep
 previous version at least one month.

PiperOrigin-RevId: 744796256
---
 jax/_src/tpu_custom_call.py | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index f84db206f4d1..cbec7f873156 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -64,15 +64,22 @@
 )

-# This tracks the latest Mosaic IR version with a monthly delay.
-FWD_COMPAT_IR_VERSION = 4
-DEFAULT_IR_VERSION = None
-# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date.
-if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple(
-    jax.lib.__version__
-) < (0, 5, 4):
-  FWD_COMPAT_IR_VERSION = 3
-  DEFAULT_IR_VERSION = 3
+# Controls the IR serialization version. Upon incrementing the
+# default version in jaxlib/mosaic/dialect/tpu/transforms/serde.cc we must
+# continue to use the old serialization version when in forward compatibility
+# mode: for 1 month when exporting, or when using old cloud TPU.
+#
+# This can be achieved by adding:
+#   if ctx.is_forward_compat() or is_cloud_tpu_older_than(<date>):
+#     return <old version>
+#   return None
+#
+# We should also add a TODO to remove the conditional one month later.
+def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None:
+  # TODO(jevinjiang): remove the forward compatibility check after 2025-05-05.
+  if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 4, 5):
+    return 3
+  return None

 tpu_custom_call_p = core.Primitive("tpu_custom_call")
@@ -679,9 +686,7 @@ def lower_module_to_custom_call(
       serialization_format=serialization_format,
       output_memory_spaces=output_memory_spaces,
       kernel_name=kernel_name,
-      ir_version=FWD_COMPAT_IR_VERSION
-      if ctx.is_forward_compat()
-      else DEFAULT_IR_VERSION,
+      ir_version=get_ir_version(ctx),
   )
   return _tpu_custom_call_lowering(
       ctx,

From 05ca0233914b429381aeb1758e96ed73c6c434c1 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Mon, 7 Apr 2025 19:12:32 +0000
Subject: [PATCH 0433/1769] [shard-map] in eager shmap, handle all rep rule
 output cases

By convention, rep_rules can return three kinds of thing:
1. a sequence (tuple or list),
2. a single set, or
3. a single None.

Even rules for primitives with multiple results can return single objects
rather than sequences; the reason is that it's convenient not to have to
infer the number of outputs for higher-order primitives. In the latter two
cases we rely on the caller (in this case, ShardMapTrace.process_primitive)
to 'broadcast' the singleton result to a list of results equal to the number
of outputs.

Previously, the code was checking `if type(out_rep) is set`, which doesn't
handle case 3.
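Concretely, the convention and the caller-side broadcast behave as in this
minimal, self-contained sketch (the rule names and broadcast helper are
illustrative only, not JAX's actual internals):

  # Three allowed return forms for a rule on a primitive with two outputs:
  def rule_seq(mesh, *in_rep):   # case 1: one entry per output
    return [{'x'}, {'x', 'y'}]

  def rule_set(mesh, *in_rep):   # case 2: a single set, applies to every output
    return {'x'}

  def rule_none(mesh, *in_rep):  # case 3: a single None, applies to every output
    return None

  def broadcast_out_rep(out_rep, num_outputs):
    # What the caller does for multiple-result primitives: sequences pass
    # through unchanged; a singleton set or None is repeated per output.
    return (out_rep if isinstance(out_rep, (list, tuple))
            else [out_rep] * num_outputs)

  assert broadcast_out_rep(rule_seq(None), 2) == [{'x'}, {'x', 'y'}]
  assert broadcast_out_rep(rule_set(None), 2) == [{'x'}, {'x'}]
  assert broadcast_out_rep(rule_none(None), 2) == [None, None]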
(We briefly tried another fix direction where we don't allow case 3, because we don't have case 3 in the upcoming VMA type system which replaces this stuff. But until that lands the easiest fix is just to handle all cases correctly.) fixes #26148, fixes #27673 Co-authored-by: Justin Fu --- jax/experimental/shard_map.py | 3 ++- tests/shard_map_test.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index ef3751c96901..3a46f444fb1b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -958,7 +958,8 @@ def process_primitive(self, prim, tracers, params): rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() if prim.multiple_results: - out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep + out_rep = (out_rep if isinstance(out_rep, (list, tuple)) + else [out_rep] * len(out_vals)) return map(partial(ShardMapTracer, self), out_rep, out_vals) return ShardMapTracer(self, out_rep, out_vals) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 520cc02638df..3a4c3ea9779c 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -685,6 +685,19 @@ def f3(): f3() jax.jit(f3)() + def test_multiple_result_primitive_with_none_sharding(self): + # https://github.com/jax-ml/jax/issues/27673 + xs = jnp.arange(20).reshape(2, 10) + mesh = jtu.create_mesh((2,), ("i",)) + y = shard_map( + lambda x: jnp.split(x.squeeze(), 2), + mesh=mesh, + in_specs=(None,), + out_specs=P("i"), + )(xs) + expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10) + self.assertArraysEqual(y, expected) + def test_vmap_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) From dc00f9bdaea77094cb6cdf959d99e61efbd87268 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 7 Apr 2025 13:56:07 -0400 Subject: [PATCH 0434/1769] Apply output forwarding in lin rule for pjit. --- jax/_src/interpreters/partial_eval.py | 1 + jax/_src/pjit.py | 45 +++++++++++++++++++++------ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 532eb0f80029..21be93ee485e 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1876,6 +1876,7 @@ def invalidate(self): # avoid cyclic refs self.frame.tracers = [] self.frame.constid_to_tracer = {} + self.frame.constvar_to_val = {} def to_jaxpr_tracer(self, x): as_local_var = self.frame.tracer_to_var.get(id(x)) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 8c3c5101eb51..641456eca15b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2079,19 +2079,44 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, res_shardings = (UNSPECIFIED,) * num_residuals res_layouts = (None,) * num_residuals res_donated = (False,) * num_residuals + primal_out_shardings = res_shardings + tuple(out_shardings) + primal_out_layouts = res_layouts + tuple(out_layouts) + def keep_where(l, should_keep): + return tuple(x for x, keep in zip(l, should_keep) if keep) + + # Input-to-output forwarding. 
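+  # (Forwarding here means: a residual that is just a passed-through primal
+  # input, or a primal output with unspecified sharding and layout that
+  # forwards an input, is pruned from the primal jaxpr and later recovered
+  # from primals_in via subs_list instead of being stored or recomputed.)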
in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr)
-  in_fwd, _ = split_list(in_fwd, [num_residuals])
-  keep = tuple(f is None for f in in_fwd) + (True,) * len(out_shardings)
+  in_fwd_res, in_fwd_primal = split_list(in_fwd, [num_residuals])
+  in_fwd = in_fwd_res + [
+      fwd if isinstance(os, UnspecifiedValue) and ol is None else None
+      for os, ol, fwd in zip(out_shardings, out_layouts, in_fwd_primal)
+  ]
+  del in_fwd_res, in_fwd_primal
+  keep = [f is None for f in in_fwd]
   primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep)
-  num_residuals = sum(f is None for f in in_fwd)
+  primal_out_shardings = keep_where(primal_out_shardings, keep)
+  primal_out_layouts = keep_where(primal_out_layouts, keep)
+  kept_res, _ = split_list(keep, [num_residuals])
+  num_kept_residuals = sum(kept_res)
+  del keep, kept_res
+
+  # Output-to-output forwarding.
+  num_out_primals = len(primal_jaxpr.jaxpr.outvars) - num_kept_residuals
+  res_vars, out_vars = split_list(primal_jaxpr.jaxpr.outvars, [num_kept_residuals])
+  idx_map = {id(v): i for i, v in enumerate(out_vars)}
+  offset = sum(id(v) not in idx_map for v in res_vars)
+  idx_map = {k: v + offset for k, v in idx_map.items()}
+  out_fwd = [idx_map.get(id(v)) for v in res_vars] + [None] * num_out_primals
+  keep = [f is None for f in out_fwd]
+  primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep)
+  primal_out_shardings = keep_where(primal_out_shardings, keep)
+  primal_out_layouts = keep_where(primal_out_layouts, keep)
+  del keep

   def tangent_fun(consts_, *tangents):
-    consts_it = iter(consts_)
-    res = [next(consts_it) if f is None else primals_in[f] for f in in_fwd]
-    assert next(consts_it, None) is None
     tangents_nz = _filter_zeros(nzs, tangents)
-    nz_tangents_out = pjit_p.bind(*(*tangents_nz, *res),
+    nz_tangents_out = pjit_p.bind(*tangents_nz, *consts_,
         jaxpr=tangent_jaxpr,
         in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings,
         out_shardings=_filter_zeros(nzs_out, out_shardings),
@@ -2114,15 +2139,17 @@ def _filter_zeros(is_nz_l, l):
   ans = pjit_p.bind(*primals_in,
                     jaxpr=primal_jaxpr,
                     in_shardings=in_shardings,
-                    out_shardings=(*res_shardings[:num_residuals], *out_shardings),
+                    out_shardings=primal_out_shardings,
                     in_layouts=in_layouts,
-                    out_layouts=(*res_layouts[:num_residuals], *out_layouts),
+                    out_layouts=primal_out_layouts,
                     donated_invars=donated_invars,
                     ctx_mesh=ctx_mesh,
                     name=name,
                     keep_unused=keep_unused,
                     inline=inline,
                     compiler_options_kvs=compiler_options_kvs)
+  ans = subs_list(out_fwd, ans, ans)
+  ans = subs_list(in_fwd, primals_in, ans)
   residuals_ans, primal_ans = split_list(ans, [num_residuals])
   return primal_ans, nzs_out, residuals_ans, tangent_fun

From 23b63cd5e0c8f7ab337443ff18d7069d0b8b1afb Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 7 Apr 2025 12:50:03 -0700
Subject: [PATCH 0435/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/77635006f6a898f71f19db360e9b4485aa5106da.

PiperOrigin-RevId: 744819336
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index a8a93026378f..d4df9ee38034 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 #    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 #    and update XLA_SHA256 with the result.
-XLA_COMMIT = "145f836bd5175dc5dd262f716a0c59af2b0297a0" -XLA_SHA256 = "bd19d8a1d25468696809a69ef3984bb00ef432e3fe9c05116b9c114dc7c83fa2" +XLA_COMMIT = "77635006f6a898f71f19db360e9b4485aa5106da" +XLA_SHA256 = "d2a63a3cd2f354cd07699f30e7b5c16c7513e686e498b8ad712fb577ab677121" def repo(): tf_http_archive( From 522add2cccf1dc17cb0bb874468b7d87aebf32ef Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Mon, 7 Apr 2025 13:10:44 -0700 Subject: [PATCH 0436/1769] [CI] Temporarily disable TPU v6 due to runner issues PiperOrigin-RevId: 744825924 --- .github/workflows/cloud-tpu-ci-nightly.yml | 11 +---------- .github/workflows/wheel_tests_nightly_release.yml | 11 ++--------- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index fd799a3f70b5..b50b07d5cc4a 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -27,18 +27,9 @@ jobs: jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] python-version: ["3.10"] - # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints. - exclude: - - tpu: - type: "v6e-8" - jaxlib-version: "nightly+oldest_supported_libtpu" - - tpu: - type: "v6e-8" - jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20241205 diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 6fd48d016bd0..132aad577d50 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -80,26 +80,19 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8 and v6e-8 + # Run a single Python version for v4-8 - tpu-specs: type: "v4-8" python: "3.10" - tpu-specs: type: "v4-8" python: "3.11" - - tpu-specs: - type: "v6e-8" - python: "3.10" - - tpu-specs: - type: "v6e-8" - python: "3.11" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" From e1b057287967c477834fb9b62006f1e03dc763b1 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 7 Apr 2025 13:23:45 -0700 Subject: [PATCH 0437/1769] [mgpu] Allow bf16 printing PiperOrigin-RevId: 744830111 --- jax/experimental/mosaic/gpu/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 28534cf4025b..47401440fac2 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -135,7 +135,7 @@ def 
_debug_scalar_ty_format(arg):
     return "%llu", arg
   if ir.F32Type.isinstance(arg.type):
     return "%f", arg
-  if ir.F16Type.isinstance(arg.type):
+  if ir.BF16Type.isinstance(arg.type) or ir.F16Type.isinstance(arg.type):
     arg = arith.extf(ir.F32Type.get(), arg)
     return "%f", arg
   raise NotImplementedError(f"Can't print the type {arg.type}")

From b6e4b93851c75b0eea375cbf43771ddb094c547b Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Mon, 7 Apr 2025 13:49:28 -0700
Subject: [PATCH 0438/1769] Add jaxlib_extension_version guard against
 explicit copying in jax.device_put.

PiperOrigin-RevId: 744838237
---
 jax/_src/dispatch.py | 3 ++-
 tests/pjit_test.py   | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index d205f860b214..baab6d519291 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -44,6 +44,7 @@
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, Layout
+from jax._src.lib import jaxlib_extension_version
 from jax._src.lib import xla_client as xc
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.monitoring import record_event_duration_secs, record_event_time_span
@@ -495,7 +496,7 @@ def _device_put_sharding_impl(x, aval, device, copy):
     return _DeferredShardArg(x, x.sharding, aval, x.committed, copy)
   elif is_single_device_sharding(x.sharding):
     device = x.sharding._device_assignment[0] if device is None else device
-    if copy == CopySemantics.COPY:
+    if copy == CopySemantics.COPY and jaxlib_extension_version >= 327:
       return xc.batched_device_put(aval, SingleDeviceSharding(device), [x],
                                    [device], True, True)
     return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 2570c6090351..2db75be18475 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1401,7 +1401,7 @@ def test_zero_literal_equality(self):
     self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir)

   def test_device_put_copy_donate(self):
-    if jaxlib_extension_version < 323:
+    if jaxlib_extension_version < 327:
       raise unittest.SkipTest("Copy not supported in device put.")
     x = np.arange(1000)
     y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False)

From 9a3e94dec519c4a5dbf4549be9cee983d9b63cb8 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Mon, 7 Apr 2025 21:00:59 +0000
Subject: [PATCH 0439/1769] [shard-map] add while_loop rep rule

fixes #27664
---
 jax/_src/util.py              |  9 ++-----
 jax/experimental/shard_map.py | 38 ++++++++++++++++++++++++++
 tests/shard_map_test.py       | 51 +++++++++++++++++++++++++++++++++++
 3 files changed, 91 insertions(+), 7 deletions(-)

diff --git a/jax/_src/util.py b/jax/_src/util.py
index b3f7becee7eb..30da28522840 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -497,13 +497,8 @@ def __eq__(self, other):
             self.args == other.args and self.kwargs == other.kwargs)

   def __hash__(self):
-    return hash(
-        (
-            self.f.__code__,
-            self.args,
-            tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])),
-        ),
-    )
+    kwargs = tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0]))
+    return hash((self.f.__code__, self.args, kwargs))

   def __call__(self, *args, **kwargs):
     return self.f(*self.args, *args, **self.kwargs, **kwargs)
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index ef3751c96901..7c4a8c6e2542 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -1403,6 +1403,44 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr,
                   num_consts, num_carry, **params):
       *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params)
   return out_vals, out_rep

+@register_check(control_flow.loops.while_p)
+def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_):
+  _, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts])
+  carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in])
+  if tuple(carry_rep_in) != tuple(carry_rep_out):
+    raise Exception("while_loop carry input and output got mismatched "
+                    f"replication types {carry_rep_in} and {carry_rep_out}. "
+                    "Please open an issue at "
+                    "https://github.com/jax-ml/jax/issues, and as a temporary "
+                    "workaround pass the check_rep=False argument to shard_map")
+  return carry_rep_out
+
+@register_rewrite(control_flow.loops.while_p)
+def _while_rewrite(mesh, in_rep, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
+                   body_nconsts):
+  # while `while_loop` isn't transposable, we insert pbroadcasts for a
+  # consistent carry
+  cconst_rep, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts])
+  num_carry = len(args) - cond_nconsts - body_nconsts
+  for _ in range(1 + num_carry):
+    in_rep_ = [*bconst_rep, *carry_rep_in]
+    _, carry_rep_out = _replication_rewrite_nomatch(mesh, body_jaxpr, in_rep_)
+    if tuple(carry_rep_in) == tuple(carry_rep_out):
+      break
+    carry_rep_in = map(op.and_, carry_rep_in, carry_rep_out)
+  else:
+    assert False, "Fixpoint not reached"
+
+  cond_jaxpr_, _ = _replication_rewrite_nomatch(
+      mesh, cond_jaxpr, (*cconst_rep, *carry_rep_in))
+  body_jaxpr_ = _replication_rewrite_match(
+      mesh, body_jaxpr, (*bconst_rep, *carry_rep_in), carry_rep_out)
+  args_ = [pbroadcast(x, tuple(n for n in src if n not in dst))
+           if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)]
+  out_vals = control_flow.loops.while_p.bind(
+      *args_, cond_jaxpr=cond_jaxpr_, body_jaxpr=body_jaxpr_,
+      cond_nconsts=cond_nconsts, body_nconsts=body_nconsts)
+  return out_vals, carry_rep_out
+
 @register_check(control_flow.conditionals.cond_p)
 def _cond_rule(mesh, *in_rep, branches):
   _, *args_rep = in_rep
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 520cc02638df..3a4c3ea9779c 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1016,6 +1016,57 @@ def body(c, _):
     shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
               out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)

+  def test_while_rep_rule(self):
+    mesh = jtu.create_mesh((2, 2,), ('x', 'y'))
+
+    def f(x, y, z):
+      x, y, z =
x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 1 + def body(c): + i, *cs = c + return (i + 1, *cs) + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + # doesn't crash, because everything matches + shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) + + # does crash, because the second guy is wrong + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_cond_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) x = jnp.arange(4) From d3cfff057fadfe173bc9410300ef37de09031d3a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 7 Apr 2025 14:08:35 -0700 Subject: [PATCH 0440/1769] jax.numpy: support __jax_array__ in remaining APIs --- jax/_src/numpy/lax_numpy.py | 25 +++++++++++++++---------- tests/array_extensibility_test.py | 25 +++++++++++++------------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d4226617030b..503dca1784c4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -911,11 +911,11 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogram", a, bins) + a, _ = util.ensure_arraylike("histogram", a, bins) a, = util.promote_dtypes_inexact(a) weights = ones_like(a) else: - util.check_arraylike("histogram", a, bins, weights) + a, _, weights = util.ensure_arraylike("histogram", a, bins, weights) if np.shape(a) != np.shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = util.promote_dtypes_inexact(a, weights) @@ -1005,7 +1005,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool) """ - util.check_arraylike("histogram2d", x, y) + x, y = util.ensure_arraylike("histogram2d", x, y) try: N = len(bins) # type: ignore[arg-type] except TypeError: @@ -1077,10 +1077,10 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogramdd", sample) + sample = util.ensure_arraylike("histogramdd", sample) sample, = util.promote_dtypes_inexact(sample) else: - util.check_arraylike("histogramdd", sample, weights) + sample, weights = util.ensure_arraylike("histogramdd", sample, weights) if np.shape(weights) != np.shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = util.promote_dtypes_inexact(sample, weights) @@ -2424,7 +2424,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: [2], [3]]]], dtype=int32) """ - util.check_arraylike("expand_dims", a) + a = util.ensure_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) return lax.expand_dims(a, axis) @@ -4371,7 +4371,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) """ - util.check_arraylike("pad", array) + array = util.ensure_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) for p in pad_width): @@ -6988,8 +6988,10 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], 
dtype=int32) """ - arr = util.ensure_arraylike("repeat", a) - core.is_dim(repeats) or util.check_arraylike("repeat", repeats) + if core.is_dim(repeats): + arr = util.ensure_arraylike("repeat", a) + else: + arr, repeats = util.ensure_arraylike("repeat", a, repeats) if axis is None: arr = arr.ravel() @@ -7828,7 +7830,7 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: Array([0, 1], dtype=int32), Array([0, 1], dtype=int32)) """ - util.check_arraylike("diag_indices_from", arr) + arr = util.ensure_arraylike("diag_indices_from", arr) nd = np.ndim(arr) if not np.ndim(arr) >= 2: raise ValueError("input array must be at least 2-d") @@ -8244,6 +8246,9 @@ def delete( # Case 3: obj is an array # NB: pass both arrays to check for appropriate error message. util.check_arraylike("delete", a, obj) + # Can't use ensure_arraylike here because obj may be static. + if hasattr(obj, "__jax_array__"): + obj = obj.__jax_array__() # Case 3a: unique integer indices; delete in a JIT-compatible way if issubdtype(_dtype(obj), np.integer) and assume_unique_indices: diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 14fcc18ca7a5..7c2681a2100e 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -81,6 +81,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: Bool = ShapeDtype(bool) Int = ShapeDtype(int) +UInt = ShapeDtype('uint32') Uint8 = ShapeDtype('uint8') Float = ShapeDtype(float) Complex = ShapeDtype(complex) @@ -280,18 +281,18 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.correlate, Float[7], Float[3]), NumPyAPI.sig(jnp.cos, Float[5]), NumPyAPI.sig(jnp.cosh, Float[5]), - # NumPyAPI.sig(np.count_nonzero, [float], [(10,)]), - # NumPyAPI.sig(np.cov, [float], [(10,)]), - # NumPyAPI.sig(np.cross, [float, float], [(3,), (3,)]), + NumPyAPI.sig(jnp.count_nonzero, Float[10]), + NumPyAPI.sig(jnp.cov, Float[10]), + NumPyAPI.sig(jnp.cross, Float[3], Float[3]), NumPyAPI.sig(jnp.cumprod, Float[5]), NumPyAPI.sig(jnp.cumsum, Float[5]), NumPyAPI.sig(jnp.cumulative_prod, Float[5]), NumPyAPI.sig(jnp.cumulative_sum, Float[5]), NumPyAPI.sig(jnp.deg2rad, Float[5]), NumPyAPI.sig(jnp.degrees, Float[5]), - # NumPyAPI.sig(jnp.delete, Float[5], Int[()]), + NumPyAPI.sig(jnp.delete, Float[5], Int[()]), NumPyAPI.sig(jnp.diag, Float[5]), - # NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), + NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), NumPyAPI.sig(jnp.diagflat, Float[5]), NumPyAPI.sig(jnp.diagonal, Float[5, 5]), NumPyAPI.sig(jnp.diff, Float[5]), @@ -306,7 +307,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.equal, Float[5], Float[5]), NumPyAPI.sig(jnp.exp, Float[5]), NumPyAPI.sig(jnp.exp2, Float[5]), - # NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), + NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), NumPyAPI.sig(jnp.expm1, Float[5]), NumPyAPI.sig(jnp.extract, Bool[5], Float[5]), NumPyAPI.sig(jnp.fabs, Float[5]), @@ -332,11 +333,11 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.greater, Float[5], Float[5]), NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), NumPyAPI.sig(jnp.heaviside, Float[5], Float[5]), - # NumPyAPI.sig(jnp.histogram, Float[5]), + NumPyAPI.sig(jnp.histogram, Float[5]), NumPyAPI.sig(jnp.histogram2d, Float[5], Float[5]), NumPyAPI.sig(jnp.histogram_bin_edges, Float[5]), - # NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), - # NumPyAPI.sig(jnp.hsplit, Float[3, 5], Int[1]), + NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), + NumPyAPI.sig(jnp.hsplit, 
Float[3, 6], indices_or_sections=2),
     NumPyAPI.sig(jnp.hstack, (Float[5], Float[5])),
     NumPyAPI.sig(jnp.hypot, Float[5], Float[5]),
     NumPyAPI.sig(jnp.i0, Float[5]),
@@ -411,7 +412,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct:
     NumPyAPI.sig(jnp.ones_like, Float[5]),
     NumPyAPI.sig(jnp.outer, Float[5], Float[5]),
     NumPyAPI.sig(jnp.packbits, Int[5]),
-    # NumPyAPI.sig(jnp.pad, Float[5], pad_width=2),
+    NumPyAPI.sig(jnp.pad, Float[5], pad_width=2),
     NumPyAPI.sig(jnp.partition, Float[5], kth=3),
     NumPyAPI.sig(jnp.percentile, Float[5], q=75),
     NumPyAPI.sig(jnp.permute_dims, Float[3, 5], axes=(1, 0)),
@@ -437,11 +438,11 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct:
     NumPyAPI.sig(jnp.rad2deg, Float[5]),
     NumPyAPI.sig(jnp.radians, Float[5]),
     NumPyAPI.sig(jnp.ravel, Float[5]),
-    # NumPyAPI.sig(jnp.ravel_multi_index, Int[2, 5], dims=(2, 3)),
+    NumPyAPI.sig(jnp.ravel_multi_index, [Uint8[5], Uint8[5]], dims=(8, 9)),
     NumPyAPI.sig(jnp.real, Complex[5]),
     NumPyAPI.sig(jnp.reciprocal, Float[5]),
     NumPyAPI.sig(jnp.remainder, Float[5], Float[5]),
-    # NumPyAPI.sig(jnp.repeat, Float[5], Int[5]),
+    NumPyAPI.sig(jnp.repeat, Float[5], Int[5]),
     NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)),
     NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)),
     NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]),

From db11efab3be59105b2ac2ccce7281fda30438f1d Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 7 Apr 2025 14:43:34 -0700
Subject: [PATCH 0441/1769] Migrate jaxlib to use a single common .so file for
 all C++ dependencies.

The idea is to move all of the jaxlib contents into a single .so file, and
have all of the other Python extensions be tiny stubs that reexport part of
the larger .so file.

This has two main benefits:
* it reduces the size of the jaxlib wheel, by about 70-80MB when installed.
  The benefit of the change is that it avoids duplication between the MLIR
  CAPI code and the copy of MLIR in XLA.
* it gives us flexibility to split and merge Python extensions as we see fit.
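Mechanically (the full details appear in jaxlib/pyinit_stub.c and
jaxlib/pywrap.bzl later in this patch): each extension's real code is compiled
into the common library with its PyInit_<module> entry point renamed to
Wrapped_PyInit_<module>, the linker version scripts export only the
Wrapped_PyInit_* symbols, and the .so/.pyd that Python actually imports is a
tiny generated stub whose PyInit_<module> simply forwards to the wrapped
symbol.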
Issue https://github.com/jax-ml/jax/issues/11225 PiperOrigin-RevId: 744855997 --- .bazelrc | 1 + jaxlib/BUILD | 48 +++++++++- jaxlib/jax_common.json | 8 ++ jaxlib/libjax_common.lds | 7 ++ jaxlib/libjax_common_darwin.lds | 1 + jaxlib/mlir/_mlir_libs/BUILD.bazel | 136 ++++++++++++--------------- jaxlib/mlir/_mlir_libs/triton_ext.cc | 10 ++ jaxlib/pyinit_stub.c | 28 ++++++ jaxlib/pywrap.bzl | 89 ++++++++++++++++++ jaxlib/setup.py | 2 + jaxlib/tools/BUILD.bazel | 3 +- jaxlib/tools/build_wheel.py | 44 ++++----- jaxlib/triton/BUILD | 5 +- jaxlib/xla/BUILD | 3 +- 14 files changed, 281 insertions(+), 104 deletions(-) create mode 100644 jaxlib/jax_common.json create mode 100644 jaxlib/libjax_common.lds create mode 100644 jaxlib/libjax_common_darwin.lds create mode 100644 jaxlib/pyinit_stub.c create mode 100644 jaxlib/pywrap.bzl diff --git a/.bazelrc b/.bazelrc index 422363644578..0c359e039c89 100644 --- a/.bazelrc +++ b/.bazelrc @@ -98,6 +98,7 @@ build:windows --incompatible_strict_action_env=true # ############################################################################# build:nonccl --define=no_nccl_support=true +build --repo_env USE_PYWRAP_RULES=1 build:posix --copt=-fvisibility=hidden build:posix --copt=-Wno-sign-compare build:posix --cxxopt=-std=c++17 diff --git a/jaxlib/BUILD b/jaxlib/BUILD index c8114b48835f..d195bda41f32 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -20,6 +20,12 @@ load( "py_library_providing_imports_info", "pytype_library", ) +load( + "//jaxlib:pywrap.bzl", + "nanobind_pywrap_extension", + "pywrap_binaries", + "pywrap_library", +) load("//jaxlib:symlink_files.bzl", "symlink_files") licenses(["notice"]) @@ -51,6 +57,7 @@ py_library_providing_imports_info( lib_rule = pytype_library, deps = [ ":cpu_feature_guard", + ":jax", ":utils", "//jaxlib/cpu:_lapack", "//jaxlib/mlir", @@ -98,6 +105,44 @@ exports_files([ "setup.py", ]) +pywrap_library( + name = "jax", + common_lib_def_files_or_filters = { + "jaxlib/jax_common": "jax_common.json", + }, + common_lib_version_scripts = { + "jaxlib/jax_common": select({ + "@bazel_tools//src/conditions:windows": None, + "@bazel_tools//src/conditions:darwin": "libjax_common_darwin.lds", + "//conditions:default": "libjax_common.lds", + }), + }, + deps = [ + ":utils", + "//jaxlib/mlir/_mlir_libs:_chlo", + "//jaxlib/mlir/_mlir_libs:_mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor", + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", + "//jaxlib/mlir/_mlir_libs:_mlirHlo", + "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses", + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", + "//jaxlib/mlir/_mlir_libs:_sdy", + "//jaxlib/mlir/_mlir_libs:_stablehlo", + "//jaxlib/mlir/_mlir_libs:_tpu_ext", + "//jaxlib/mlir/_mlir_libs:_triton_ext", + "//jaxlib/mlir/_mlir_libs:register_jax_dialects", + "//jaxlib/xla:xla_extension", + ], +) + +pywrap_binaries( + name = "jaxlib_binaries", + dep = ":jax", +) + cc_library( name = "absl_status_casters", hdrs = ["absl_status_casters.h"], @@ -170,10 +215,9 @@ nanobind_extension( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "utils", srcs = ["utils.cc"], - module_name = "utils", deps = [ "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", diff --git a/jaxlib/jax_common.json b/jaxlib/jax_common.json new file mode 100644 index 000000000000..61a2c9313897 --- /dev/null +++ b/jaxlib/jax_common.json @@ -0,0 +1,8 @@ +{ + "global": [ + 
"Wrapped_PyInit_*" + ], + "local": [ + "*" + ] +} diff --git a/jaxlib/libjax_common.lds b/jaxlib/libjax_common.lds new file mode 100644 index 000000000000..6130415a8d26 --- /dev/null +++ b/jaxlib/libjax_common.lds @@ -0,0 +1,7 @@ +{ + global: + Wrapped_PyInit_*; + + local: + *; +}; diff --git a/jaxlib/libjax_common_darwin.lds b/jaxlib/libjax_common_darwin.lds new file mode 100644 index 000000000000..aed9a1d7512a --- /dev/null +++ b/jaxlib/libjax_common_darwin.lds @@ -0,0 +1 @@ +*Wrapped_PyInit_* diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index fb94837cff37..6599e50695d4 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -15,10 +15,9 @@ load( "//jaxlib:jax.bzl", "if_windows", - "nanobind_extension", - "py_extension", "windows_cc_shared_mlir_library", ) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") load("//jaxlib:symlink_files.bzl", "symlink_inputs") package( @@ -44,7 +43,7 @@ LINKOPTS = select({ ], }) -py_extension( +nanobind_pywrap_extension( name = "_mlir", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", @@ -52,14 +51,13 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI", + "@llvm-project//mlir:MLIRBindingsPythonCore", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp", @@ -67,15 +65,14 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirGPUPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", @@ -83,14 +80,13 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", + "@llvm-project//mlir:CAPIGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsNVGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", @@ -98,15 +94,14 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPINVGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsLLVM", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", @@ -114,15 +109,14 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsSparseTensor", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp", @@ -130,14 +124,13 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - 
"@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirSparseTensorPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp", @@ -145,22 +138,20 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mosaic_gpu_ext", srcs = ["mosaic_gpu_ext.cc"], copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", ], @@ -171,17 +162,16 @@ py_extension( # :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly # across platforms. It's not clear if Windows supports RPATH-like functionality # across different directories at all. -py_extension( +nanobind_pywrap_extension( name = "_tpu_ext", srcs = ["tpu_ext.cc"], copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic:tpu_dialect_capi_headers", + "//jaxlib/mosaic:tpu_dialect_capi", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", "@xla//xla/python:nb_numpy", @@ -190,7 +180,7 @@ py_extension( ) # This target contains the extension and it's Python dependencies, which are not -# supported by the `py_extension`/`nanobind_extension` macros. +# supported by the `nanobind_pywrap_extension`/`nanobind_extension` macros. py_library( name = "_tpu_ext_lib", deps = [ @@ -200,19 +190,22 @@ py_library( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "_triton_ext", srcs = ["triton_ext.cc"], copts = COPTS, linkopts = LINKOPTS, pytype_srcs = ["_triton_ext.pyi"], deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/triton:triton_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", - ], + ] + if_windows( + [], + [ + "//jaxlib/triton:triton_dialect_capi", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + ], + ), ) symlink_inputs( @@ -235,7 +228,7 @@ cc_library( hdrs = ["jaxlib_mlir_capi_shims.h"], deps = [ "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:GPUPipelines", "@llvm-project//mlir:GPUToLLVMIRTranslation", "@llvm-project//mlir:LLVMToLLVMIRTranslation", @@ -250,34 +243,33 @@ cc_library( name = "jaxlib_mlir_capi_shims_hdrs", hdrs = ["jaxlib_mlir_capi_shims.h"], deps = [ - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", ], ) # JAX-specific registrations. 
-py_extension( +nanobind_pywrap_extension( name = "register_jax_dialects", srcs = ["register_jax_dialects.cc"], copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/gpu:mlir_capi_headers", - "@llvm-project//mlir:CAPIArithHeaders", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", - "@llvm-project//mlir:CAPIMathHeaders", - "@llvm-project//mlir:CAPIMemRefHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", - "@llvm-project//mlir:CAPINVVMHeaders", - "@llvm-project//mlir:CAPISCFHeaders", - "@llvm-project//mlir:CAPITransformsHeaders", - "@llvm-project//mlir:CAPIVectorHeaders", + "//jaxlib/mosaic/gpu:mlir_capi", + "@llvm-project//mlir:CAPIArith", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", + "@llvm-project//mlir:CAPIMath", + "@llvm-project//mlir:CAPIMemRef", + "@llvm-project//mlir:CAPINVGPU", + "@llvm-project//mlir:CAPINVVM", + "@llvm-project//mlir:CAPISCF", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:CAPIVector", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -285,7 +277,7 @@ py_extension( # MHLO Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_mlirHlo", srcs = [ "@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc", @@ -293,12 +285,11 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@xla//xla/mlir_hlo:CAPIHeaders", + "@xla//xla/mlir_hlo:CAPI", ], ) @@ -306,7 +297,7 @@ py_extension( # Shardy Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_sdy", srcs = [ "@shardy//shardy/integrations/python/ir:sdy_module.cc", @@ -314,13 +305,12 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -328,7 +318,7 @@ py_extension( # Stablehlo Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_chlo", srcs = [ "@stablehlo//:chlo_py_api_files", @@ -336,16 +326,15 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@stablehlo//:chlo_capi_headers", + "@stablehlo//:chlo_capi", ], ) -py_extension( +nanobind_pywrap_extension( name = "_stablehlo", srcs = [ "@stablehlo//:stablehlo_py_api_files", @@ -353,13 +342,12 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + 
"@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@stablehlo//:stablehlo_capi_headers", + "@stablehlo//:stablehlo_capi", ], ) diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 7fba7e1dfe80..687ceec4cd33 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef _WIN32 + #include #include @@ -74,3 +76,11 @@ NB_MODULE(_triton_ext, m) { return encoding; }); } + +#else // _WIN32 + +#include "nanobind/nanobind.h" + +NB_MODULE(_triton_ext, m) {} + +#endif // _WIN32 diff --git a/jaxlib/pyinit_stub.c b/jaxlib/pyinit_stub.c new file mode 100644 index 000000000000..7fc873d9ae0e --- /dev/null +++ b/jaxlib/pyinit_stub.c @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Stub that reexports Wrapped_PyInit_module as PyInit_module. + +extern void* Wrapped_PyInit_@MODULE_NAME@(); + +#if defined(WIN32) || defined(_WIN32) +#define EXPORT_SYMBOL __declspec(dllexport) +#else +#define EXPORT_SYMBOL __attribute__ ((visibility("default"))) +#endif + +EXPORT_SYMBOL void* PyInit_@MODULE_NAME@() { + return Wrapped_PyInit_@MODULE_NAME@(); +} diff --git a/jaxlib/pywrap.bzl b/jaxlib/pywrap.bzl new file mode 100644 index 000000000000..75324e01907a --- /dev/null +++ b/jaxlib/pywrap.bzl @@ -0,0 +1,89 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrappers around pywrap rules for JAX.""" + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load( + "@xla//third_party/py/rules_pywrap:pywrap.impl.bzl", + "pybind_extension", + _pywrap_binaries = "pywrap_binaries", + _pywrap_library = "pywrap_library", +) + +pywrap_library = _pywrap_library +pywrap_binaries = _pywrap_binaries + +def nanobind_pywrap_extension( + name, + srcs = [], + deps = [], + pytype_srcs = [], + pytype_deps = [], + copts = [], + linkopts = [], + visibility = None): + # buildifier: disable=function-docstring-args + "Python extension rule using nanobind and the pywrap rules." 
+ module_name = name + lib_name = name + "_pywrap_library" + src_cc_name = name + "_pywrap_stub.c" + + # We put the entire contents of the extension in a single cc_library, which will become part of + # the common pywrap library. All the contents of all extensions will end up in the common + # library. + native.cc_library( + name = lib_name, + srcs = srcs, + copts = copts, + deps = deps, + local_defines = [ + "PyInit_{}=Wrapped_PyInit_{}".format(module_name, module_name), + ], + visibility = ["//visibility:private"], + ) + + # We build a small stub library as the extension that forwards to the PyInit_... symbol from the + # common pywrap library. + expand_template( + name = name + "_pywrap_stub", + testonly = True, + out = src_cc_name, + substitutions = { + "@MODULE_NAME@": module_name, + }, + template = "//jaxlib:pyinit_stub.c", + visibility = ["//visibility:private"], + ) + + # Despite its name "pybind_extension" has nothing to do with pybind. It is the Python extension + # rule from the pywrap rules. + pybind_extension( + name = name, + srcs = [src_cc_name], + deps = [":" + lib_name], + linkopts = linkopts, + visibility = visibility, + default_deps = [], + common_lib_packages = [ + "jaxlib", + ], + ) + + # Create a py_library with the type stubs as data, on which wheel builds can depend. + native.py_library( + name = name + "_type_stubs", + data = pytype_srcs, + deps = pytype_deps, + ) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index b3a37a25f1b2..5bd010525c96 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -76,6 +76,8 @@ def has_ext_modules(self): package_data={ 'jaxlib': [ '*.so', + '*.dylib', + '*.dll', '*.pyd*', 'py.typed', 'cpu/*', diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 79a1f7e7089d..3ea09802dfaf 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -64,9 +64,10 @@ py_binary( "LICENSE.txt", "//jaxlib", "//jaxlib:README.md", + "//jaxlib:jaxlib_binaries", "//jaxlib:setup.py", "//jaxlib/xla:xla_client.py", - "//jaxlib/xla:xla_extension", + "//jaxlib/xla:xla_extension_type_stubs", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index fcc811789c19..bab1c6014ff4 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -60,11 +60,11 @@ r = runfiles.Create() - def _is_mac(): return platform.system() == "Darwin" +soext = "dll" if build_utils.is_windows() else ("dylib" if _is_mac() else "so") pyext = "pyd" if build_utils.is_windows() else "so" @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): if not _is_mac(): return nm = subprocess.run( - ["nm", "-g", r.Rlocation("__main__/jaxlib/xla/xla_extension.so")], + ["nm", "-g", r.Rlocation(f"__main__/jaxlib/xla_extension.{pyext}")], capture_output=True, text=True, check=False, @@ -186,6 +186,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): src_files=[ f"__main__/jaxlib/cpu_feature_guard.{pyext}", f"__main__/jaxlib/utils.{pyext}", + "__main__/jaxlib/jax_common.dll" if build_utils.is_windows() else f"__main__/jaxlib/libjax_common.{soext}", "__main__/jaxlib/lapack.py", "__main__/jaxlib/hlo_helpers.py", "__main__/jaxlib/gpu_prng.py", @@ -198,7 +199,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): "__main__/jaxlib/plugin_support.py", "__main__/jaxlib/version.py", "__main__/jaxlib/xla/xla_client.py", - f"__main__/jaxlib/xla/xla_extension.{pyext}", + f"__main__/jaxlib/xla_extension.{pyext}", ], ) # This file is required by 
PEP-561. It marks jaxlib as package containing @@ -311,38 +312,31 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): ) - if build_utils.is_windows(): - capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll" - else: - so_ext = "dylib" if _is_mac() else "so" - capi_so = f"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.{so_ext}" - mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs" copy_runfiles( dst_dir=mlir_libs_dir, src_files=[ - capi_so, "__main__/jaxlib/mlir/_mlir_libs/__init__.py", - f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNVGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", + f"__main__/jaxlib/_mlir.{pyext}", + f"__main__/jaxlib/_chlo.{pyext}", + f"__main__/jaxlib/_mlirHlo.{pyext}", + f"__main__/jaxlib/_mlirDialectsSparseTensor.{pyext}", + f"__main__/jaxlib/_mlirSparseTensorPasses.{pyext}", + f"__main__/jaxlib/_mosaic_gpu_ext.{pyext}", + f"__main__/jaxlib/_tpu_ext.{pyext}", + f"__main__/jaxlib/_sdy.{pyext}", + f"__main__/jaxlib/_stablehlo.{pyext}", + f"__main__/jaxlib/register_jax_dialects.{pyext}", + f"__main__/jaxlib/_mlirDialectsGPU.{pyext}", + f"__main__/jaxlib/_mlirDialectsLLVM.{pyext}", + f"__main__/jaxlib/_mlirDialectsNVGPU.{pyext}", + f"__main__/jaxlib/_mlirGPUPasses.{pyext}", ] + ( [] if build_utils.is_windows() else [ - f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", + f"__main__/jaxlib/_triton_ext.{pyext}", "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] ), diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 99cddd9e6381..64410fdfeb00 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -35,7 +35,10 @@ pytype_strict_library( "//jaxlib/mlir:ir", ] + if_windows( [], - ["//jaxlib/mlir/_mlir_libs:_triton_ext"], + [ + "//jaxlib/mlir/_mlir_libs:_triton_ext", + "//jaxlib/mlir/_mlir_libs:_triton_ext_type_stubs", + ], ), ) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index c5d151b2cd3b..a299629c3ba5 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -24,6 +24,7 @@ load( "py_strict_test", "pytype_strict_library", ) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") licenses(["notice"]) @@ -39,7 +40,7 @@ package_group( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "xla_extension", srcs = ["xla.cc"], pytype_deps = py_deps(["numpy"]), From 96e63eaee8a4f741eca6e30ebbed805df825e6bf Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 7 Apr 2025 14:46:38 -0700 Subject: [PATCH 0442/1769] jnp.linalg: add symmetrize_input argument & docs --- jax/_src/numpy/linalg.py | 25 ++++++++++++++++++------- tests/linalg_test.py | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 23f2a58b09f6..146bbbda0213 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -72,8 +72,8 @@ def 
_symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 @export -@partial(jit, static_argnames=['upper']) -def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: +@partial(jit, static_argnames=['upper', 'symmetrize_input']) +def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array: """Compute the Cholesky decomposition of a matrix. JAX implementation of :func:`numpy.linalg.cholesky`. @@ -98,6 +98,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: Must have shape ``(..., N, N)``. upper: if True, compute the upper Cholesky decomposition `U`. if False (default), compute the lower Cholesky decomposition `L`. + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition @@ -135,7 +139,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """ a = ensure_arraylike("jnp.linalg.cholesky", a) a, = promote_dtypes_inexact(a) - L = lax_linalg.cholesky(a) + L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input) return L.mT.conj() if upper else L @@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads - to better behavior under automatic differentiation. + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: A namedtuple ``(eigenvalues, eigenvectors)`` where @@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, @export -@partial(jit, static_argnames=('UPLO',)) -def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: +@partial(jit, static_argnames=('UPLO', 'symmetrize_input')) +def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *, + symmetrize_input: bool = True) -> Array: """ Compute the eigenvalues of a Hermitian matrix. @@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: or symmetric (if real) matrix. UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. 
Returns: An array of shape ``(..., M)`` containing the eigenvalues, sorted in @@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ a = ensure_arraylike("jnp.linalg.eigvalsh", a) a, = promote_dtypes_inexact(a) - w, _ = eigh(a, UPLO) + w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input) return w diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 20c998d6a685..1670f1ee4abd 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -96,7 +96,7 @@ def args_maker(): a = rng(factor_shape, dtype) return [np.matmul(a, jnp.conj(T(a)))] - jnp_fun = partial(jnp.linalg.cholesky, upper=upper) + jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True) def np_fun(x, upper=upper): # Upper argument added in NumPy 2.0.0 From b18dc1dfd7668859e07c5c823d14154564647708 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 7 Apr 2025 14:46:53 -0700 Subject: [PATCH 0443/1769] [Mosaic GPU] Add scaffolding for a new lowering "axis" (UserThreadSemantics), in addition to the existing ThreadSemantics (renamed to LoweringSemantics). UserThreadSemantics controls the thread semantics of the Pallas user's code, whereas LoweringSemantics controls the level at which Mosaic GPU emits code. PiperOrigin-RevId: 744857085 --- jax/_src/pallas/mosaic_gpu/core.py | 16 +- jax/_src/pallas/mosaic_gpu/lowering.py | 199 ++++++++++-------- .../mosaic_gpu/pallas_call_registration.py | 4 +- jax/_src/pallas/mosaic_gpu/primitives.py | 72 ++++--- jax/experimental/mosaic/gpu/__init__.py | 2 +- jax/experimental/mosaic/gpu/core.py | 10 +- jax/experimental/pallas/mosaic_gpu.py | 2 +- tests/mosaic/gpu_test.py | 8 +- tests/pallas/mosaic_gpu_test.py | 53 ++--- tests/pallas/ops_test.py | 2 +- 10 files changed, 205 insertions(+), 163 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 444fe6e50f88..d964a8a90144 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -91,7 +91,7 @@ class GPUCompilerParams(pallas_core.CompilerParams): delay_release: int = 0 profile_space: int = 0 profile_dir: str = "" - thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane + lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane def __post_init__(self): if bool(self.profile_space) ^ bool(self.profile_dir): @@ -142,6 +142,20 @@ def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: return self(()).get_ref_aval() +class PrimitiveSemantics(enum.Enum): + """Thread semantics for a primitives at the Pallas user-level.""" + + Warp = enum.auto() + Warpgroup = enum.auto() + + +# Convenience constants for (lowering, primitive) thread semantics pairs. 
+LANExWG_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warpgroup) +WGxWG_SEMANTICS = ( + mgpu.LoweringSemantics.Warpgroup, PrimitiveSemantics.Warpgroup) + + def kernel( body: Callable[..., None], out_shape: object, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 2423e1c1a2a7..f7bdbccc1ad6 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -88,13 +88,13 @@ def _align_to(x: int, alignment: int): @dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: axis_names: _AxisNames - thread_semantics: mgpu.ThreadSemantics + lowering_semantics: mgpu.LoweringSemantics @property def arrival_multiplier(self) -> int: return ( WARPGROUP_SIZE - if self.thread_semantics == mgpu.ThreadSemantics.Lane + if self.lowering_semantics == mgpu.LoweringSemantics.Lane else 1 ) @@ -308,7 +308,8 @@ class ModuleContext: name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches squashed_dims: tuple[int, ...] - thread_semantics: mgpu.ThreadSemantics + lowering_semantics: mgpu.LoweringSemantics + primitive_semantics: gpu_core.PrimitiveSemantics def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: """Reserves a barrier. @@ -367,7 +368,7 @@ def scratch_view( smem = ir.Attribute.parse("#gpu.address_space") i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: smem_base = gpu_dialect.dynamic_shared_memory( ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem) ) @@ -383,7 +384,7 @@ def scratch_view( # The below code emission relies on the assumption that the first scratch # operand provided by Mosaic GPU always begins at the beginning of # dynamic SMEM. Mosaic GPU is expected to uphold that invariant. - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: view = memref_dialect.view( scratch_ty, smem_base, _as_index(off), [] ) @@ -416,7 +417,7 @@ class LoweringRuleContext: def estimator_ctx(self) -> ResourceEstimatorContext: return ResourceEstimatorContext( axis_names=self.module_ctx.axis_names, - thread_semantics=self.module_ctx.thread_semantics, + lowering_semantics=self.module_ctx.lowering_semantics, ) @@ -703,8 +704,8 @@ def lower_jaxpr_to_module( debug_info = jaxpr.debug_info params = compiler_params.get("mosaic_gpu", {}) approx_math = params.get("approx_math", False) - thread_semantics = params.get( - "thread_semantics", mgpu_core.ThreadSemantics.Lane + lowering_semantics = params.get( + "lowering_semantics", mgpu_core.LoweringSemantics.Lane ) if len(cluster) < 3: @@ -732,7 +733,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): else: tmem_cols = 0 - if thread_semantics == mgpu.ThreadSemantics.Lane: + if lowering_semantics == mgpu.LoweringSemantics.Lane: single_lane_predicate = mgpu.single_thread_predicate(per_block=False) else: # Warpgroup semantics do not have a single lane predicate. 
single_lane_predicate = None @@ -752,7 +753,8 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), squashed_dims=squashed_dims, - thread_semantics=thread_semantics, + lowering_semantics=lowering_semantics, + primitive_semantics=gpu_core.PrimitiveSemantics.Warpgroup, ) del runtime_smem, grouped_barriers, runtime_barriers @@ -762,7 +764,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): rs = _estimate_resources( ResourceEstimatorContext( - axis_names=axis_names, thread_semantics=thread_semantics + axis_names=axis_names, lowering_semantics=lowering_semantics ), jaxpr, ) @@ -801,7 +803,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ) ) - if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc mgpu.infer_layout(module) # pytype: disable=attribute-error @@ -820,17 +822,21 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mosaic_lowering_rules = { # Lowering rules when using Mosaic GPU lane semantics. - mgpu.ThreadSemantics.Lane: {} , + (mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup): {} , # Lowering rules when using Mosaic GPU warpgroup semantics. - mgpu.ThreadSemantics.Warpgroup: {}, + (mgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup): {}, } def register_lowering_rule( - primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics + primitive: jax_core.Primitive, + lowering_semantics: mgpu.LoweringSemantics, + primitive_semantics: gpu_core.PrimitiveSemantics = gpu_core.PrimitiveSemantics.Warpgroup, ): def deco(fn): - mosaic_lowering_rules[thread_semantics][primitive] = fn + mosaic_lowering_rules[ + (lowering_semantics, primitive_semantics)][primitive] = fn return fn return deco @@ -866,7 +872,7 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): # TODO(apaszke): Handle other avals (refs, etc.). if isinstance(aval := var.aval, jax_core.ShapedArray): # TODO(apaszke): Clarify the type invariants for lane semantics? - if module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Shaped arrays must be vectors if and only if their shape is non-empty. # Those with empty shapes should be represented by their scalar type. mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype) @@ -903,10 +909,13 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): ) loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: - if eqn.primitive not in mosaic_lowering_rules[module_ctx.thread_semantics]: + if eqn.primitive not in mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics)]: raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " - f"{eqn.primitive.name}. " + f"{eqn.primitive.name} for lowering semantics " + f"{module_ctx.lowering_semantics} and user thread semantics " + f"{module_ctx.primitive_semantics}. " "Please file an issue on https://github.com/jax-ml/jax/issues." 
) new_local_name_stack = [scope.name for scope in eqn.source_info.name_stack.stack] @@ -918,7 +927,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): wrapper_stack = contextlib.ExitStack() wrapper_stack.enter_context(launch_ctx.named_region(name)) named_regions.append(wrapper_stack) - rule = mosaic_lowering_rules[module_ctx.thread_semantics][eqn.primitive] + rule = mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics) + ][eqn.primitive] rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, @@ -947,8 +958,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): return map(read_env, jaxpr.outvars) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.program_id_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.program_id_p, mgpu.LoweringSemantics.Warpgroup) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): if ctx.module_ctx.program_ids is None: raise NotImplementedError("pl.program_id() is not supported in this context") @@ -1015,8 +1027,9 @@ def lowering_rule(ctx: LoweringRuleContext, *args, **params): return lowering_rule -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.num_programs_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.num_programs_p, mgpu.LoweringSemantics.Warpgroup) def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): del ctx # Unused. return arith_dialect.index_cast( @@ -1089,7 +1102,7 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... 
return tuple(indices) -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Lane) def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): if isinstance(x_ref, tcgen05.TMEMRef): transforms = jax.tree.unflatten(tree, leaves) @@ -1132,7 +1145,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Warpgroup) def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") @@ -1157,7 +1170,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): return memref_dialect.load(x_smem, []) -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Lane) def _swap_lowering_rule( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): @@ -1222,7 +1235,7 @@ def _swap_lowering_rule( raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Warpgroup) def _swap_lowering_rule_wg( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): @@ -1253,8 +1266,8 @@ def _swap_lowering_rule_wg( return old_value -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Warpgroup) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): if jaxpr.consts: raise NotImplementedError @@ -1263,7 +1276,7 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ) -@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.slice_p, mgpu.LoweringSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -1273,8 +1286,8 @@ def _slice_lowering_rule( return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))] -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Warpgroup) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: raise NotImplementedError( @@ -1283,7 +1296,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): ) pred_aval, *cases_avals = ctx.avals_in [out_aval] = ctx.avals_out - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: pred = _ensure_fa(pred, pred_aval.dtype) cases = _bcast(*cases, *cases_avals, out_aval) # ``select`` expects the first case to be the true branch, but ``select_n`` @@ -1301,7 +1314,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): return arith_dialect.select(pred, *reversed(cases)) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.broadcast_in_dim_p, 
mgpu.LoweringSemantics.Lane) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, x: mgpu.FragmentedArray, @@ -1331,7 +1344,8 @@ def _broadcast_in_dim_lowering_rule( return x.broadcast(shape) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Warpgroup) def _broadcast_in_dim_lowering_rule_wg( ctx: LoweringRuleContext, x: ir.Value, @@ -1351,7 +1365,7 @@ def _broadcast_in_dim_lowering_rule_wg( ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, mgpu.LoweringSemantics.Lane) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1362,7 +1376,8 @@ def _convert_element_type_lowering_rule( ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.convert_element_type_p, mgpu.LoweringSemantics.Warpgroup) def _convert_element_type_lowering_rule_wg( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1454,12 +1469,12 @@ def convert(ty, x): return convert(ty, x) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ lax.neg_p: lambda ctx, x: -x, lax.not_p: lambda ctx, x: ~x, }) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( lambda x: jnp.astype(jnp.bitwise_xor(jnp.astype(x, int), -1), jnp.dtype(x)), multiple_results=False, @@ -1472,7 +1487,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): return impl(x, y) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), @@ -1533,7 +1548,7 @@ def _binary_op_lowering_rule_wg( arith_dialect.minimumf, ), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_op_lowering_rule_wg, si_impl=si_impl, ui_impl=ui_impl, @@ -1552,7 +1567,7 @@ def _binary_boolean_op_lowering_rule_wg( (lax.or_p, arith_dialect.ori), (lax.xor_p, arith_dialect.xori), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_boolean_op_lowering_rule_wg, impl=impl, ) @@ -1585,7 +1600,7 @@ def _comparison_lowering_rule_wg( (lax.gt_p, CmpIPred.sgt, CmpIPred.ugt, CmpFPred.OGT), (lax.ge_p, CmpIPred.sge, CmpIPred.uge, CmpFPred.OGE), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _comparison_lowering_rule_wg, si_pred=si_pred, ui_pred=ui_pred, @@ -1593,7 +1608,7 @@ def _comparison_lowering_rule_wg( ) -@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.div_p, mgpu.LoweringSemantics.Lane) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) if ir.FloatType.isinstance(x.mlir_dtype): @@ -1601,19 +1616,19 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): return x // y 
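A minimal sketch of the registration pattern these hunks migrate to: lowering rules are now keyed by a (LoweringSemantics, PrimitiveSemantics) pair, and register_lowering_rule defaults primitive_semantics to PrimitiveSemantics.Warpgroup. The primitive my_prim_p below is purely illustrative and not part of this change:

import jax._src.core as jax_core
from jax._src.pallas.mosaic_gpu import lowering
import jax.experimental.mosaic.gpu as mgpu

my_prim_p = jax_core.Primitive("my_prim")  # hypothetical primitive

# Stacking the decorators fills both lowering-semantics tables, each under
# the default PrimitiveSemantics.Warpgroup key.
@lowering.register_lowering_rule(my_prim_p, mgpu.LoweringSemantics.Lane)
@lowering.register_lowering_rule(my_prim_p, mgpu.LoweringSemantics.Warpgroup)
def _my_prim_lowering_rule(ctx: lowering.LoweringRuleContext, x):
  del ctx  # A real rule would emit Mosaic GPU IR here.
  return x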
-@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Warpgroup) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): if y != 2: raise NotImplementedError return _square_lowering_rule(ctx, x) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: x = _ensure_fa(x, x_aval.dtype) return x * x if jnp.issubdtype(x_aval.dtype, jnp.integer): @@ -1623,13 +1638,13 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Warpgroup) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1639,13 +1654,13 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): ) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Warpgroup) def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1659,21 +1674,21 @@ def _logistic(x, accuracy): return 1.0 / (1 + lax.exp(-x)) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun( +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS][lax.logistic_p] = _lower_fun( _logistic, multiple_results=False ) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][lax.logistic_p] = ( _lower_fun(_logistic, multiple_results=False) ) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Warpgroup) def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not 
None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1681,13 +1696,13 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Warpgroup) def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1695,13 +1710,13 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup) def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1709,7 +1724,7 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Lane) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1729,7 +1744,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Lane) def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1770,7 +1785,7 @@ def _reduce_lowering_rule_wg( return vector_dialect.MultiDimReductionOp(kind, x, acc, axes) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): op = _reduce_lowering_rule_wg( vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes @@ -1781,7 +1796,7 @@ def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return op.result 
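At the user level, the semantics a kernel is lowered under is selected through GPUCompilerParams. A small end-to-end sketch patterned on the test helpers later in this patch; the copy kernel and shapes are illustrative only:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.arange(128, dtype=jnp.float32)
# The same kernel body, lowered once per instruction-stream semantics.
for semantics in (plgpu.LoweringSemantics.Lane,
                  plgpu.LoweringSemantics.Warpgroup):
  y = pl.pallas_call(
      copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      compiler_params=plgpu.GPUCompilerParams(lowering_semantics=semantics),
  )(x)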
-@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in if jnp.issubdtype(x_aval.dtype, jnp.floating): @@ -1822,8 +1837,8 @@ def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): return gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) -@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): axis_names = ctx.module_ctx.axis_names if not axis_names: @@ -1883,7 +1898,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): ) -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane) def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1911,7 +1926,7 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Warpgroup) def _debug_print_lowering_rule_wg( ctx: LoweringRuleContext, *args, @@ -1925,8 +1940,8 @@ def _debug_print_lowering_rule_wg( return () -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Warpgroup) def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): @@ -1937,7 +1952,7 @@ def _run_scoped_lowering_rule( aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): dtype = mlir.dtype_to_ir_type(aval.dtype) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) else: zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) @@ -2018,7 +2033,7 @@ def _run_scoped_lowering_rule( return outs -@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2090,7 +2105,7 @@ def as_values(vals, avals): _ensure = ( _ensure_fa - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane else _ensure_ir_value ) return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] @@ -2110,8 +2125,8 @@ def loop(loop_index, body_args): return loop.results -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2199,8 +2214,8 @@ def _lower_while_via_fori( return ub, ub, *for_out -@register_lowering_rule(lax.while_p, 
mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Warpgroup) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2224,7 +2239,7 @@ def _while_lowering_rule( _is_acc = lambda x: isinstance(x, mgpu.WGMMAAccumulator) _ensure = _ensure_ir_value - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: _ensure = lambda v, aval: v if _is_acc(v) else _ensure_fa(v, aval.dtype) # If we fail conversion to fori, fallback to an ordinary while loop. @@ -2276,8 +2291,8 @@ def _while_lowering_rule( return carry_treedef.unflatten(list(while_op.results)) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in @@ -2334,9 +2349,9 @@ def _yielded_values(outs, avals): return treedef.unflatten(list(switch_op.results)) -@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule( - lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup + lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Warpgroup ) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype @@ -2352,7 +2367,7 @@ def _bitcast_convert_type_lowering_rule( " have different widths" ) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: x = _ensure_ir_value(x, x_aval.dtype) return arith_dialect.bitcast( ir.VectorType.get(x_aval.shape, dst_elem_type), x @@ -2368,7 +2383,7 @@ def _bitcast_convert_type_lowering_rule( ) -@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.optimization_barrier_p, mgpu.LoweringSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): result = mgpu.optimization_barrier( *(_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) @@ -2377,7 +2392,7 @@ def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): @register_lowering_rule( - lax.optimization_barrier_p, mgpu.ThreadSemantics.Warpgroup + lax.optimization_barrier_p, mgpu.LoweringSemantics.Warpgroup ) def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): result = mgpu.dialect.optimization_barrier([ diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index ff3c4f89d30c..eb15aff21235 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -58,8 +58,8 @@ def pallas_call_lowering( print(f"The grid mapping for pallas_call {debug_info.func_src_info}:") print(grid_mapping) - thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mgpu.ThreadSemantics.Lane + lowering_semantics = compiler_params.get("mosaic_gpu", {}).get( + "lowering_semantics", mgpu.LoweringSemantics.Lane ) mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: 
disable=attribute-error diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 8bd67e705cf0..a37b018860d7 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -75,7 +75,7 @@ def _load_abstract_eval(src, *avals_flat, args_tree, layout, optimized): {state.ReadEffect(0)}, ) -@lowering.register_lowering_rule(load_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Lane) def _load_p_lowering_rule( ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout, optimized ): @@ -215,9 +215,10 @@ def _copy_smem_to_gmem_pp_eqn( jax_core.pp_eqn_rules[copy_smem_to_gmem_p] = _copy_smem_to_gmem_pp_eqn -@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, @@ -236,7 +237,7 @@ def _copy_smem_to_gmem_lowering( else: predicate = None - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if predicate is not None: assert ctx.module_ctx.single_wg_lane_predicate is not None predicate = arith_dialect.andi( @@ -253,7 +254,7 @@ def _copy_smem_to_gmem_lowering( dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) src, src_transforms = lowering._handle_transforms(src, src_transforms, handle_transposes=False) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, @@ -439,9 +440,10 @@ def _copy_gmem_to_smem_pp_eqn( jax_core.pp_eqn_rules[copy_gmem_to_smem_p] = _copy_gmem_to_smem_pp_eqn -@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_gmem_to_smem_lowering( ctx: lowering.LoweringRuleContext, @@ -488,7 +490,7 @@ def _copy_gmem_to_smem_lowering( f" dtype={dst_ty.element_type})" ) bytes = bits // 8 - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if bytes % WARPGROUP_SIZE: raise NotImplementedError("Only aligned copies are supported") # We arrive uniformly from each thread in the WG, so we need to divide the @@ -630,8 +632,8 @@ def _barrier_arrive_pp_eqn( jax_core.pp_eqn_rules[barrier_arrive_p] = _barrier_arrive_pp_eqn -@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_arrive_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -688,8 +690,8 @@ def _barrier_wait_pp_eqn( jax_core.pp_eqn_rules[barrier_wait_p] = _barrier_wait_pp_eqn 
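A note on the arrival accounting behind these barrier and copy rules: under Lane semantics every one of the WARPGROUP_SIZE (128) threads in a warpgroup arrives at a barrier individually, whereas under Warpgroup semantics a warpgroup arrives once. That is all the resource estimator's arrival_multiplier property encodes; restated as a standalone sketch, assuming WARPGROUP_SIZE == 128:

import jax.experimental.mosaic.gpu as mgpu

WARPGROUP_SIZE = 128  # threads per warpgroup

def arrival_multiplier(semantics: mgpu.LoweringSemantics) -> int:
  # Lane semantics: 128 per-thread arrivals per logical arrival, which is
  # also why copy_gmem_to_smem requires byte counts divisible by 128.
  # Warpgroup semantics: a single logical arrival.
  return WARPGROUP_SIZE if semantics == mgpu.LoweringSemantics.Lane else 1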
-@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -726,9 +728,10 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - wait_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _wait_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, n, *, wait_read_only @@ -759,8 +762,9 @@ def _commit_group_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_group_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_group_p, mgpu.LoweringSemantics.Warpgroup) def _commit_group_lowering(ctx: lowering.LoweringRuleContext): del ctx # Unused. nvvm_dialect.cp_async_bulk_commit_group() @@ -886,7 +890,7 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): wgmma_p = jax_core.Primitive("wgmma") -@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Lane) def _wgmma_lowering( ctx: lowering.LoweringRuleContext, acc, @@ -981,7 +985,7 @@ def _wgmma_lowering( return new_acc -@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Warpgroup) def _wgmma_warpgroup_lowering( ctx: lowering.LoweringRuleContext, acc, @@ -1052,8 +1056,8 @@ def wgmma_wait_effectful_abstract_eval(_): return [], {gpu_core._wgmma_pipeline_effect} -@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Warpgroup) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -1085,16 +1089,16 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): @lowering.register_lowering_rule( - wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Lane ) @lowering.register_lowering_rule( - wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Warpgroup + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Warpgroup ) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): nvvm_dialect.wgmma_wait_group_sync_aligned(0) return ( acc.value - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane else acc ) @@ -1156,7 +1160,7 @@ def _layout_cast_abstract_eval(x, new_layout): return x -@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(layout_cast_p, mgpu.LoweringSemantics.Lane) def 
_layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout): del ctx # Unused. return x.to_layout(new_layout.to_mgpu()) @@ -1177,8 +1181,10 @@ def _set_max_registers_abstract_eval(n, *, action): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Warpgroup) def _set_max_registers_lowering( ctx: lowering.LoweringRuleContext, n, *, action ): @@ -1206,8 +1212,9 @@ def _commit_smem_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_smem_p, mgpu.LoweringSemantics.Warpgroup) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): # TODO(bchetioui): add primitive for commit smem to mosaic_gpu dialect. mgpu.commit_shared() @@ -1227,7 +1234,8 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): return jax_core.ShapedArray(shape, dtype) -@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + broadcasted_iota_p, mgpu.LoweringSemantics.Lane) def _broadcasted_iota_lowering( ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout ): @@ -1309,8 +1317,8 @@ def _jaxpr_call_pp_eqn( jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Warpgroup) def _jaxpr_call_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, @@ -1507,7 +1515,7 @@ def _inline_mgpu_abstract_eval( def _inline_mgpu_discharge(*args, **kwargs): raise NotImplementedError("inline_mgpu_p does not support discharge.") -@lowering.register_lowering_rule(inline_mgpu_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Lane) def _inline_mgpu_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index a4f1e0a9cfe0..8ecc5b9fd8da 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -23,7 +23,7 @@ Barrier as Barrier, ClusterBarrier as ClusterBarrier, TMABarrier as TMABarrier, - ThreadSemantics as ThreadSemantics, + LoweringSemantics as LoweringSemantics, TMEM as TMEM, Union as Union, as_gpu_kernel as as_gpu_kernel, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 43b93e7da023..e822ea5f3ebf 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -209,7 +209,7 @@ def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize -class ThreadSemantics(enum.Enum): +class LoweringSemantics(enum.Enum): """Semantics for the kernel's instruction stream.""" Lane = enum.auto() @@ -595,7 +595,7 @@ 
def as_gpu_kernel( module_name: str = "unknown", kernel_name: str | None = None, ir_version: int | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + thread_semantics: LoweringSemantics = LoweringSemantics.Lane, ): if isinstance(in_shape, list): in_shape = tuple(in_shape) @@ -609,7 +609,7 @@ def as_gpu_kernel( ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if thread_semantics == LoweringSemantics.Warpgroup and dialect is not None: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error @@ -669,7 +669,7 @@ def as_torch_gpu_kernel( cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", kernel_name: str | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + lowering_semantics: LoweringSemantics = LoweringSemantics.Lane, ): try: import torch @@ -692,7 +692,7 @@ def as_torch_gpu_kernel( ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if lowering_semantics == LoweringSemantics.Warpgroup and dialect is not None: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 1d3bebbc3757..85e512d03290 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -52,7 +52,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait -from jax.experimental.mosaic.gpu.core import ThreadSemantics as ThreadSemantics +from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. 
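Note the call-site asymmetry in this hunk: as_gpu_kernel keeps its thread_semantics keyword (only the enum type changes), while as_torch_gpu_kernel renames the keyword itself to lowering_semantics. A call-site sketch for the former, patterned on the gpu_test.py call sites below; the empty kernel body and shapes are illustrative only:

import jax
import jax.numpy as jnp
import jax.experimental.mosaic.gpu as mgpu

def body(ctx, inp, out, smem):
  pass  # kernel body elided; only the call-site change is of interest here

shape = jax.ShapeDtypeStruct((128,), jnp.float32)
kernel = mgpu.as_gpu_kernel(
    body, (1, 1, 1), (128, 1, 1),
    in_shape=shape,
    out_shape=shape,
    smem_scratch_shape=[],
    thread_semantics=mgpu.LoweringSemantics.Warpgroup,  # enum was ThreadSemantics
)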
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 80e5380d165b..f0930f5de8cc 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2602,7 +2602,7 @@ def add(ctx, a, b, result, smem): in_shape=(jax_shape, jax_shape), out_shape=jax_shape, smem_scratch_shape=[], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, shape).astype(dtype) @@ -2747,7 +2747,7 @@ def add( jax_shape_sliced, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) @@ -2846,7 +2846,7 @@ def add( spec, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) @@ -2994,7 +2994,7 @@ def matmul( result_jax_shape, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) prng_key = jax.random.key(1234) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 10b4f3de60ad..f0f3bdf41c32 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -28,6 +28,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src.pallas import pallas_call +from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives @@ -64,14 +65,14 @@ def _sum_same_dtype(x): class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): - def __new__(mcs, *args, thread_semantics=plgpu.ThreadSemantics.Lane): + def __new__(mcs, *args, lowering_semantics=plgpu.LoweringSemantics.Lane): cls = super().__new__(mcs, *args) - cls.THREAD_SEMANTICS = thread_semantics + cls.LOWERING_SEMANTICS = lowering_semantics return cls class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass): - THREAD_SEMANTICS: ClassVar[plgpu.ThreadSemantics] + LOWERING_SEMANTICS: ClassVar[plgpu.LoweringSemantics] def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): @@ -83,20 +84,20 @@ def setUp(self): super().setUp() def skip_if_wg_semantics(self): - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Warpgroup: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: self.skipTest("Not supported under WG semantics") def kernel(self, *args, **kwargs): compiler_params = dataclasses.replace( kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), - thread_semantics=self.THREAD_SEMANTICS, + lowering_semantics=self.LOWERING_SEMANTICS, ) return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) def pallas_call(self, *args, **kwargs): compiler_params = dataclasses.replace( kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), - thread_semantics=self.THREAD_SEMANTICS, + lowering_semantics=self.LOWERING_SEMANTICS, ) return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) @@ -1503,7 +1504,7 @@ def kernel(x_ref, y_ref, o_ref): class PallasCallWGTest( - PallasCallTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... 
@@ -1513,10 +1514,14 @@ def test_missing_primitive_lowerings_are_tracked(self): # enable warpgroup semantics by default (assuming we haven't overspecialized # lowerings). rules = mgpu_lowering.mosaic_lowering_rules - wg_lowered_primitives = set(rules[plgpu.ThreadSemantics.Warpgroup]) - lane_lowered_primitives = set(rules[plgpu.ThreadSemantics.Lane]) - - actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives + wg_wg_lowered_primitives = set( + rules[(plgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup)]) + lane_wg_lowered_primitives = set(rules[ + (plgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup)]) + + actual_missing_primitives = (lane_wg_lowered_primitives - + wg_wg_lowered_primitives) expected_missing_primitives = { mgpu_primitives.inline_mgpu_p, mgpu_primitives.broadcasted_iota_p, @@ -1607,7 +1612,7 @@ def _epilogue(): lambda m, n, k: (m, n), ) - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: lhs_spec = plgpu.GPUBlockSpec( lhs_spec.block_shape, lhs_spec.index_map, @@ -1715,7 +1720,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = self.pallas_call( kernel, @@ -1768,7 +1773,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = self.pallas_call( @@ -1797,7 +1802,7 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = ( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), @@ -1820,7 +1825,7 @@ def scope(acc_ref): class PallasCallSm90AWGTest( - PallasCallSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PallasCallSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -1851,7 +1856,7 @@ def kernel(y_ref, tmem_ref, smem_ref): class PallasCallSm100AWGTest( - PallasCallSm100ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -2117,7 +2122,7 @@ def kernel_body(_, x_smem, o_smem): class PipelineWGTest( - PipelineTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -2137,7 +2142,7 @@ def test_realistic_matmul(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = ( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), @@ -2190,7 +2195,7 @@ def kernel_body(_, a_smem, b_smem): class PipelineSm90AWGTest( - PipelineSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PipelineSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... 
@@ -2354,7 +2359,7 @@ def tiled_acc_kernel(_, x_smem, carry):

 class WarpSpecializedPipelineWGTest(
     WarpSpecializedPipelineTest,
-    thread_semantics=plgpu.ThreadSemantics.Warpgroup,
+    lowering_semantics=plgpu.LoweringSemantics.Warpgroup,
 ):
   ...

@@ -2612,7 +2617,7 @@ def body(step, _):

 class CoreMapWGTest(
-    CoreMapTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
+    CoreMapTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
 ):
   ...

@@ -2747,7 +2752,7 @@ def body(i_ref1, i_ref2, o_ref, sem_ref):

 class ExamplesWGTest(
-    ExamplesTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
+    ExamplesTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
 ):
   ...

@@ -2789,7 +2794,7 @@ def do_wgmma(acc_ref):

 class ExamplesSm90AWGTest(
-    ExamplesSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
+    ExamplesSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
 ):
   ...

diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index aeb0ba1cca1a..ff02c334f45c 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -296,7 +296,7 @@ def pallas_call(cls, *args, **kwargs):
     if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu:
       assert plgpu_mgpu is not None
       compiler_params = plgpu_mgpu.GPUCompilerParams(
-          thread_semantics=plgpu_mgpu.ThreadSemantics.Warpgroup
+          lowering_semantics=plgpu_mgpu.LoweringSemantics.Warpgroup
       )
       kwargs["compiler_params"] = compiler_params

From 64e4bf26324ddfeb957233f6370b74acdbb80e5f Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Mon, 7 Apr 2025 14:49:09 -0700
Subject: [PATCH 0444/1769] Relax jax dependency constraints to be able to
 install RC wheels

Also, add a job to the release test workflow that verifies that the
release wheels can be installed.

TESTED:
1. Full release: https://github.com/jax-ml/jax/actions/runs/14315832784
2. jax only release: https://github.com/jax-ml/jax/actions/runs/14316157252

PiperOrigin-RevId: 744857804
---
 .../workflows/wheel_tests_nightly_release.yml | 98 +++++++++++++++++--
 jax/version.py                                |  6 ++
 setup.py                                      | 19 ++--
 tests/version_test.py                         | 16 +++
 4 files changed, 127 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml
index 132aad577d50..f6d2aa9b97c6 100644
--- a/.github/workflows/wheel_tests_nightly_release.yml
+++ b/.github/workflows/wheel_tests_nightly_release.yml
@@ -1,12 +1,14 @@
 # CI - Wheel Tests (Nightly/Release)
 #
-# This workflow builds JAX artifacts and runs CPU/CUDA tests.
+# This workflow is used to test the JAX wheels that were built by internal CI jobs.
 #
-# It orchestrates the following:
-# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was
+# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that were
 #    built by internal CI jobs and runs CPU tests.
-# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
-#    artifacts that were built by internal CI jobs and runs the CUDA tests.
+# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that were
+#    built by internal CI jobs and runs CUDA tests.
+# 3. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that were
+#    built by internal CI jobs and runs TPU tests.
+# 4. verify-release-wheels-install: Verifies that JAX's release wheels can be installed.
name: CI - Wheel Tests (Nightly/Release) on: @@ -106,4 +108,88 @@ jobs: run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} - gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file + gcs_download_uri: ${{inputs.gcs_download_uri}} + + verify-release-wheels-install: + if: ${{ startsWith(github.ref_name, 'release/')}} + defaults: + run: + # Set the shell to bash as GitHub actions runs with /bin/sh by default + shell: bash + runs-on: linux-x86-n2-16 + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.10", "3.13", "3.13-nogil"] + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + + # Verifies that JAX's release wheels can be installed + name: "Verify release wheels install (Python ${{ matrix.python }})" + + env: + PYTHON: "python${{ matrix.python }}" + + steps: + - name: Download release wheels from GCS + run: | + mkdir -p $(pwd)/dist + final_gcs_download_uri=${{ inputs.gcs_download_uri }} + + # Get the major and minor version of Python. + # E.g if python=3.10, then python_major_minor=310 + # E.g if python=3.13-nogil, then python_major_minor=313t + python_major_minor=${{ matrix.python }} + python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') + python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-" + + gsutil -m cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ + + jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null) + echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV + + if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then + gsutil -m cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/ + + jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_pjrt_wheel=$(ls dist/jax*cuda*pjrt*linux*x86_64*.whl 2>/dev/null) + + echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV + fi + - name: Verify JAX CPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL + fi + - name: Verify JAX TPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate + + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[tpu] + else + uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL + fi + - name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages) + run: | + $PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda] + else + uv pip install $JAX_WHEEL[cuda] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda] + fi + - name: Verify JAX CUDA packages can be installed (CUDA local) + run: | + $PYTHON -m uv venv ~/test_cuda_local && source 
~/test_cuda_local/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda12-local] + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL + fi \ No newline at end of file diff --git a/jax/version.py b/jax/version.py index 6ed6a5fda600..21662d078f7f 100644 --- a/jax/version.py +++ b/jax/version.py @@ -93,6 +93,12 @@ def _get_version_for_build() -> str: return _version_from_git_tree(_version) or _version_from_todays_date(_version) +def _is_prerelease() -> bool: + """Determine if this is a pre-release ("rc" wheels) build.""" + rc_version = os.getenv("WHEEL_VERSION_SUFFIX", "") + return True if rc_version.startswith("rc") else False + + def _write_version(fname: str) -> None: """Used by setup.py to write the specified version info into the source tree.""" release_version = _get_version_for_build() diff --git a/setup.py b/setup.py index bdaeb624bf38..629836b30862 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,13 @@ def load_version_module(pkg_path): _cmdclass = _version_module._get_cmdclass(project_name) _minimum_jaxlib_version = _version_module._minimum_jaxlib_version +# If this is a pre-release ("rc" wheels), append "rc0" to +# _minimum_jaxlib_version and _current_jaxlib_version so that we are able to +# install the rc wheels. +if _version_module._is_prerelease(): + _minimum_jaxlib_version += "rc0" + _current_jaxlib_version += "rc0" + with open('README.md', encoding='utf-8') as f: _long_description = f.read() @@ -81,32 +88,32 @@ def load_version_module(pkg_path): ], 'cuda': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], 'cuda12': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Deprecated alias for cuda12, kept to avoid breaking users who wrote # cuda12_pip in their CI. 'cuda12_pip': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. 'cuda12_local': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}", ], # ROCm support for ROCm 6.0 and above. 
'rocm': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}", ], diff --git a/tests/version_test.py b/tests/version_test.py index b78e61ae024c..14da82df2e3e 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -143,6 +143,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -150,6 +151,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -183,6 +185,20 @@ def testBuildVersionFromEnvironment(self): ): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) + self.assertEqual(version, f"{base_version}rc0") + self.assertValidVersion(version) + + with jtu.set_env( + JAX_RELEASE=None, + JAXLIB_RELEASE="1", + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY=None, + WHEEL_VERSION_SUFFIX="rc0", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) self.assertEqual(version, f"{base_version}rc0") self.assertValidVersion(version) From 3a3c145039c8d1b41946b0683cdcf601b29bd3f9 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 7 Apr 2025 21:30:25 +0000 Subject: [PATCH 0445/1769] [shard-map] canonicalize rep=None to be rep={all possible axes} None is meant to represent the same thing as {replicated over all possible axes}. But without this canonicalization, we could compare None as not equal to {all possible axes}. fixes #26621 Unrelated: in several places, including the _check_rep path, we don't handle partial auto correctly, since we treat {all possible axes} as {all mesh axes}, but actually it should be more like {all mesh axes} - auto. We'll leave that fix for a follow-up... 
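As a sketch, the canonicalization in the diff below boils down to
materializing the `None` case as an explicit axis set before any equality
checks run (names follow the diff; `auto` is the set of automatically
managed axes):

    def canonicalize_rep(mesh, auto, rep):
      # None means "replicated over every axis this trace can see"; turn it
      # into that explicit set so it compares equal to {all possible axes}.
      return set(mesh.axis_names) - set(auto) if rep is None else rep

With reps canonicalized this way, two fully replicated values always compare
equal instead of tripping over `None != {all axes}`.
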
--- jax/experimental/shard_map.py | 7 ++++--- tests/shard_map_test.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3a46f444fb1b..17d909b5629c 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -646,7 +646,7 @@ def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType] env: dict[core.Var, RepType] = {} def read(x: core.Atom) -> RepType: - return env[x] if type(x) is core.Var else None + return env[x] if type(x) is core.Var else set(mesh.axis_names) def write(v: core.Var, val: RepType) -> None: env[v] = val @@ -942,7 +942,7 @@ def to_val_rep_pair(self, val): raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) - return val_, None + return val_, set(self.mesh.axis_names) - set(self.auto) def process_primitive(self, prim, tracers, params): in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) @@ -1008,6 +1008,7 @@ class ShardMapTracer(core.Tracer): val: JaxType def __init__(self, trace, rep, val): + rep = set(trace.mesh.axis_names) - set(trace.auto) if rep is None else rep self._trace = trace self.rep = rep self.val = val @@ -2151,7 +2152,7 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): with core.take_current_trace() as parent: tag = core.TraceTag() - t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) + t = RewriteTrace(parent_trace=parent, tag=tag, mesh=mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, args) with core.set_current_trace(t): ans = f(*in_tracers) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 3a4c3ea9779c..62395a8750ab 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2747,6 +2747,28 @@ def f(x): f(x) # doesn't crash + def test_rep_none_canonicalization(self): + # https://github.com/jax-ml/jax/issues/26621 + N = 8 + xs = jnp.ones((8, N), dtype=jnp.int32) + variables = jax.random.normal(jax.random.key(1), (N, N), jnp.complex64) + mesh = jtu.create_mesh((2,), ('i',)) + in_specs = (P(), P("i"),) + out_specs = P("i") + + variables = jax.lax.with_sharding_constraint(variables, NamedSharding(mesh, P())) + xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i'))) + + def fun(v, xs): + # Commenting this single line below makes everything work + v = jax.scipy.linalg.expm(v) + v = v.sum() + return v * xs.sum(axis=-1).astype(v.dtype) + + res = fun(variables, xs) + fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) + res = fun_shard_map(variables, xs) # don't crash + class FunSpec(NamedTuple): name: str From 48a9ad07968795357814c0b02e7abb52cae10786 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 7 Apr 2025 15:07:30 -0700 Subject: [PATCH 0446/1769] Reverts 006a6a63feb64bf9984526030ba008186d69d2b4 PiperOrigin-RevId: 744864022 --- jax/_src/lax/parallel.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index e533672a1d9b..6ed4dddfcc21 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1906,8 +1906,8 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - axis_name = (axis_name,) if not isinstance(axis_name, tuple) 
else axis_name effect = {core.NamedAxisEffect(axis_name)} + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name _check_axis_names(axis_name) mesh = get_abstract_mesh() sharding = NamedSharding(mesh, P()) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 2669d73691c1..98ff98759c8c 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -518,17 +518,11 @@ def has_communication(self) -> bool: nonlocal_axis_names = set() def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr): return { - e.name - for e in jaxpr.effects - if isinstance(e, jax_core.NamedAxisEffect) - and ( - not self.grid_names - or all( - name not in self.grid_names - for name in tree_util.tree_leaves(e.name) - ) - ) - } + e.name + for e in jaxpr.effects + if isinstance(e, jax_core.NamedAxisEffect) + and (not self.grid_names or e.name not in self.grid_names) + } nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr)) for bm in self.block_mappings: if bm is not None: From 2944e3b2a64d26f3cadbda4694486c21979a7229 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 15:27:10 -0700 Subject: [PATCH 0447/1769] Removed `data_dependent_tracing_fallback` config option No internal code needs it any more. PiperOrigin-RevId: 744870756 --- CHANGELOG.md | 3 +++ jax/_src/config.py | 6 ------ jax/_src/core.py | 8 +------- jax/_src/pjit.py | 4 +--- 4 files changed, 5 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index beacd477390f..3aae0f432121 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.numpy.array` no longer accepts `None`. This behavior was deprecated since November 2023 and is now removed. + * Removed the `config.jax_data_dependent_tracing_fallback` config option, + which was added temporarily in v0.4.36 to allow users to opt out of the + new "stackless" tracing machinery. * Changes * The minimum CuDNN version is v9.8. diff --git a/jax/_src/config.py b/jax/_src/config.py index aca6d8e2c938..8aa4ee343664 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1099,12 +1099,6 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ' transpose rewrite machinery in shard_map'), include_in_jit_key=True) -data_dependent_tracing_fallback = bool_state( - name='jax_data_dependent_tracing_fallback', - default=False, - help=('When True, falls back to trace dispatch based on data dependence ' - 'instead of throwing an escaped tracer error.')) - softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index 9f80842a38ff..8d32b9370091 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -497,9 +497,7 @@ def bind(self, *args, **params): def _true_bind(self, *args, **params): for arg in args: - if (isinstance(arg, Tracer) - and not arg._trace.is_valid() - and not config.data_dependent_tracing_fallback.value): + if isinstance(arg, Tracer) and not arg._trace.is_valid(): raise escaped_tracer_error(arg) # TODO: figure out how to handle function arguments # assert (not config.enable_checks.value or @@ -1015,10 +1013,6 @@ def process_primitive(self, primitive, args, params): else: # TODO(dougalm): delete. 
this shouldn't be necessary
     args = map(full_lower, args)
-    if config.data_dependent_tracing_fallback.value:
-      for arg in args:
-        if isinstance(arg, Tracer):
-          return primitive.bind_with_trace(arg._trace, args, params)
     check_eval_args(args)
     return primitive.impl(*args, **params)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 8c3c5101eb51..cf4d13530b74 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -186,9 +186,7 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
     args_flat = [*init_states, *args_flat]

   try:
-    if (core.trace_state_clean() and
-        not config.debug_key_reuse.value and
-        not config.data_dependent_tracing_fallback.value):
+    if core.trace_state_clean() and not config.debug_key_reuse.value:
       args_flat = map(core.full_lower, args_flat)
       core.check_eval_args(args_flat)
       out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)

From 0a72e856cfb8984ba4883d9449bc8928adebe535 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 7 Apr 2025 16:20:58 -0700
Subject: [PATCH 0448/1769] Add **experimental** `with_dll_constraint` API.

This is for cases when the user wants to let SPMD decide the sharding.
But this is a contradiction, since layouts apply to the device-local
shape, and without knowing the sharding, you can't decide the layout.
But there are cases where you don't care what the sharding is, you just
want to force a row-major layout (for example). **This API should only
be used for those cases**.

PiperOrigin-RevId: 744888557
---
 jax/_src/layout.py | 4 ++
 jax/_src/pjit.py | 44 +++++++++++++++++++
 .../jax2tf/tests/primitives_test.py | 2 +
 jax/experimental/layout.py | 5 ++-
 tests/BUILD | 1 +
 tests/layout_test.py | 33 +++++++++++++-
 6 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/jax/_src/layout.py b/jax/_src/layout.py
index 5309f0b1fd9c..8d4f8acd5327 100644
--- a/jax/_src/layout.py
+++ b/jax/_src/layout.py
@@ -127,6 +127,10 @@ def __init__(self, device_local_layout: LayoutOptions = None,
     self.device_local_layout = device_local_layout
     self.sharding = sharding

+  @property
+  def dll(self):
+    return self.device_local_layout
+
   def __repr__(self):
     return (f'Layout(device_local_layout={self.device_local_layout},'
             f' sharding={self.sharding})')
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index cf4d13530b74..0c8d7393b98c 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2942,6 +2942,50 @@ def use_explicit_axes(*axes):
   with mesh_lib.use_abstract_mesh(new_mesh):
     yield

+# -------------------- with_dll_constraint --------------------
+
+def with_dll_constraint(x, layouts):
+  x_flat, tree = tree_flatten(x)
+  layouts_flat = tuple(flatten_axes("with_dll_constraint layouts", tree,
+                                    layouts))
+  if any(not isinstance(l, DeviceLocalLayout) for l in layouts_flat):
+    raise ValueError(
+        'layouts passed to `with_dll_constraint` must be of type'
+        f' `DeviceLocalLayout`. 
Got {[type(l) for l in layouts_flat]}') + check_aval_layout_compatibility( + layouts_flat, x_flat, ("",) * len(layouts_flat), + "with_dll_constraint arguments") + outs = [dll_constraint_p.bind(xf, layout=l) + for xf, l in zip(x_flat, layouts_flat)] + return tree_unflatten(tree, outs) + +dll_constraint_p = core.Primitive('dll_constraint') +dll_constraint_p.def_abstract_eval(lambda x, **_: x) +ad.deflinear2(dll_constraint_p, + lambda ct, _, **params: (dll_constraint_p.bind(ct, **params),)) + +def _dll_constraint_impl(x, *, layout): + if not isinstance(x, xc.ArrayImpl): + raise ValueError( + 'with_dll_constraint in eager mode can only be applied to' + f' jax.Arrays. Got {type(x)}') + if x.layout.device_local_layout == layout: # type: ignore + return x + return api.jit(_identity_fn, out_shardings=Layout(layout, x.sharding))(x) +dll_constraint_p.def_impl(_dll_constraint_impl) + +def _dll_constraint_hlo_lowering(ctx, x_node, *, layout): + aval, = ctx.avals_in + out_aval, = ctx.avals_out + return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, aval)] +mlir.register_lowering(dll_constraint_p, + _dll_constraint_hlo_lowering) + +def _dll_constraint_batcher(axis_data, vals_in, dims_in, layout): + raise NotImplementedError +batching.fancy_primitive_batchers[dll_constraint_p] = _dll_constraint_batcher +batching.skippable_batchers[dll_constraint_p] = lambda _: () + # -------------------- helpers -------------------- def get_unconstrained_dims(sharding: NamedSharding): diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 1ccd009f157c..0156465e339a 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -174,6 +174,8 @@ def test_primitive_coverage(self): continue if p.name == "sharding_constraint": continue + if p.name == "dll_constraint": + continue if p.name == "mesh_cast": continue if p.name == "reshard": diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index ed9f8931938e..aa114a2803e8 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -14,5 +14,8 @@ from jax._src.layout import ( DeviceLocalLayout as DeviceLocalLayout, - Layout as Layout + Layout as Layout, +) +from jax._src.pjit import ( + with_dll_constraint as with_dll_constraint, ) diff --git a/tests/BUILD b/tests/BUILD index eb6ff81f5d68..0a58ee52d88c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -325,6 +325,7 @@ jax_multiplatform_test( }, enable_configs = [ "tpu_v3_2x2_shardy", + "tpu_v3_2x2", ], tags = ["multiaccelerator"], deps = [ diff --git a/tests/layout_test.py b/tests/layout_test.py index b9062b8d21dc..ae10013a5f60 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -21,9 +21,10 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config -from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from jax.experimental.layout import (with_dll_constraint, Layout, + DeviceLocalLayout as DLL) from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() @@ -744,6 +745,36 @@ def f(x): self.assertArraysEqual(out, np_inp * 2) self.assertEqual(out.layout, out_layout) + def test_with_dll_constraint(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (16, 128) + s = NamedSharding(mesh, P('x')) + 
np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=arr.layout.dll.major_to_minor[::-1]) + + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + y = with_dll_constraint(y, custom_dll) + return y * 2 + + f(arr) # doesn't crash + + f = jax.jit(f) + out = f(arr) + self.assertEqual(out.layout.device_local_layout.major_to_minor, + custom_dll.major_to_minor) + self.assertArraysEqual(out, np_inp.T * 2) + + lowered_text = f.lower(arr).as_text() + self.assertIn('LayoutConstraint', lowered_text) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 84e04fe60838db9deb4bd040f8bb678424fd439d Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Mon, 7 Apr 2025 16:24:17 -0700 Subject: [PATCH 0449/1769] Add custom pretty print rule for the unary ops with accuracy s.t. accuracy is not printed if it's None. PiperOrigin-RevId: 744889524 --- jax/_src/api.py | 13 ++- jax/_src/lax/lax.py | 19 +++- tests/api_test.py | 147 ++++++++++++++++--------------- tests/core_test.py | 8 +- tests/pmap_test.py | 16 ++-- tests/unary_ops_accuracy_test.py | 24 ++++- 6 files changed, 132 insertions(+), 95 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index fb10245c30e9..d338e2d70700 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2268,16 +2268,13 @@ def make_jaxpr( >>> print(f(3.0)) -0.83602 >>> jax.make_jaxpr(f)(3.0) - { lambda ; a:f32[]. let - b:f32[] = cos[accuracy=None] a - c:f32[] = sin[accuracy=None] b - in (c,) } + { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) } >>> jax.make_jaxpr(jax.grad(f))(3.0) { lambda ; a:f32[]. 
let - b:f32[] = cos[accuracy=None] a - c:f32[] = sin[accuracy=None] a - _:f32[] = sin[accuracy=None] b - d:f32[] = cos[accuracy=None] b + b:f32[] = cos a + c:f32[] = sin a + _:f32[] = sin b + d:f32[] = cos b e:f32[] = mul 1.0 d f:f32[] = neg e g:f32[] = mul f c diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 13511641558c..7ca73603c14b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4066,6 +4066,11 @@ def _nary_lower_hlo( out = op(*args, result_accuracy=accuracy_attr(accuracy)) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] +def _unary_with_accuracy_pp_rule(eqn, context, settings): + params = dict(eqn.params) + if 'accuracy' in params and params['accuracy'] is None: + del params['accuracy'] + return core._pp_eqn(eqn.replace(params=params), context, settings) _float = {np.floating} _complex = {np.complexfloating} @@ -4128,6 +4133,7 @@ def _round_lower(ctx, x, *, rounding_method): ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule +core.pp_eqn_rules[exp_p] = _unary_with_accuracy_pp_rule exp2_p = standard_unop(_float | _complex, 'exp2') ad.defjvp2( @@ -4145,19 +4151,23 @@ def _exp2_lower(ctx, x, accuracy): ] mlir.register_lowering(exp2_p, _exp2_lower) +core.pp_eqn_rules[exp2_p] = _unary_with_accuracy_pp_rule log_p = standard_unop(_float | _complex, 'log') ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x)) mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log)) +core.pp_eqn_rules[log_p] = _unary_with_accuracy_pp_rule expm1_p = standard_unop(_float | _complex, 'expm1') ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans)))) mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one)) +core.pp_eqn_rules[expm1_p] = _unary_with_accuracy_pp_rule log1p_p = standard_unop(_float | _complex, 'log1p') ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x)))) mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one)) +core.pp_eqn_rules[log1p_p] = _unary_with_accuracy_pp_rule tanh_p = standard_unop(_float | _complex, 'tanh') ad.defjvp2( @@ -4165,6 +4175,7 @@ def _exp2_lower(ctx, x, accuracy): lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)), ) mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh)) +core.pp_eqn_rules[tanh_p] = _unary_with_accuracy_pp_rule logistic_p = standard_unop(_float | _complex, 'logistic') ad.defjvp2( @@ -4174,13 +4185,13 @@ def _exp2_lower(ctx, x, accuracy): # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. 
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) - def logistic_impl(x, accuracy): one = _const(x, 1) return div(one, add(one, exp(neg(x)))) mlir.register_lowering(logistic_p, mlir.lower_fun(logistic_impl, multiple_results=False)) +core.pp_eqn_rules[logistic_p] = _unary_with_accuracy_pp_rule def _sin_complex(x): # use expm1 instead of exp to avoid cancellation when abs(x) is small @@ -4219,6 +4230,7 @@ def _sin_p_lin(nzs, x, accuracy): ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) +core.pp_eqn_rules[sin_p] = _unary_with_accuracy_pp_rule batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule def _cos_complex(x): @@ -4244,10 +4256,12 @@ def _cos_lowering(ctx, x, accuracy): cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) ) mlir.register_lowering(cos_p, _cos_lowering) +core.pp_eqn_rules[cos_p] = _unary_with_accuracy_pp_rule tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +core.pp_eqn_rules[tan_p] = _unary_with_accuracy_pp_rule asin_p = standard_unop(_float | _complex, 'asin') ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) @@ -4365,6 +4379,7 @@ def _abs_jvp_rule(g, ans, x): sqrt_p = standard_unop(_float | _complex, 'sqrt') ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) +core.pp_eqn_rules[sqrt_p] = _unary_with_accuracy_pp_rule rsqrt_p = standard_unop(_float | _complex, 'rsqrt') ad.defjvp2( @@ -4372,6 +4387,7 @@ def _abs_jvp_rule(g, ans, x): lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), ) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) +core.pp_eqn_rules[rsqrt_p] = _unary_with_accuracy_pp_rule cbrt_p = standard_unop(_float, 'cbrt') ad.defjvp2( @@ -4381,6 +4397,7 @@ def _abs_jvp_rule(g, ans, x): ), ) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +core.pp_eqn_rules[cbrt_p] = _unary_with_accuracy_pp_rule square_p = standard_unop(_int | _float | _complex, 'square') diff --git a/tests/api_test.py b/tests/api_test.py index 83264f10e033..440fea1b059c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5093,11 +5093,11 @@ def f_yesremat(x): jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.) 
scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5450,9 +5450,9 @@ def f(x): ('new_remat', new_checkpoint), ] for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [ - ('save_anything', lambda *_, **__: True, [], [' sin[', ' cos[[ ']), - ('save_nothing', lambda *_, **__: False, [' sin[', ' cos['], []), - ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos['], [' sin[']), + ('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']), + ('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []), + ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']), ]) def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2): for square in [lambda x: x * x, api.jit(lambda x: x * x)]: @@ -5482,8 +5482,8 @@ def test_remat_custom_policy_save_cos(self, remat): policy=save_cos) _, f_lin = api.linearize(f, 1.) jaxpr_text = str(f_lin.func.args[0]) - self.assertNotIn(' sin[', jaxpr_text) - self.assertNotIn(' cos[', jaxpr_text) + self.assertNotIn(' sin ', jaxpr_text) + self.assertNotIn(' cos ', jaxpr_text) jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev']) @parameterized.named_parameters( @@ -5505,7 +5505,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5528,7 +5528,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_general'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5551,7 +5551,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((3, 2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_general'), 9) jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5575,7 +5575,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5599,8 +5599,8 @@ def body(x, _): return f(x), None # Two sine calls in the backward pass because while we don't save sines # within the (rematted) body function, we can save the scan carry, which # effectively saves one sine. Three cosines for the Jacobian coefficients. - self.assertEqual(jaxpr_text.count(' sin['), 2) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' cos '), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compure the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5906,8 +5906,9 @@ def test_remat_of_scan(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) 
- self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + print("debug jaxpr: ", str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5952,8 +5953,8 @@ def body(x, _): return f(x), None jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 3) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 3) + self.assertEqual(jaxpr_text.count(' cos '), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compute the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5970,8 +5971,8 @@ def test_remat_of_scan_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) def test_remat_of_scan_funky_custom_jvp(self): def scan_apply(f, x): @@ -5994,40 +5995,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 2) def test_remat_of_scan_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use scan. 
@@ -6052,40 +6053,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6100,8 +6101,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertNotIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertNotIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) true_fn = lambda c: jnp.sin(jnp.sin(c)) false_fn = lambda c: c @@ -6109,8 +6110,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) 
- self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6150,8 +6151,8 @@ def f(x): _, f_vjp = api.vjp(f, jnp.ones((5, 5))) jaxpr_text = str(f_vjp.args[0].func.args[1]) - self.assertEqual(jaxpr_text.count(' sin['), 2) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' cos '), 3) # Five calls to dot_general in the backward pass because we have two for # each forward-pass dot, except for the first which only has one (as we are # differentiating with respect to only W and not x). @@ -6181,8 +6182,8 @@ def f(x): jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 2) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' cos '), 3) self.assertEqual(jaxpr_text.count(' dot_'), 5) jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, @@ -6196,8 +6197,8 @@ def test_remat_of_cond_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) def test_remat_of_cond_funky_custom_jvp(self): def cond_apply(f, x): @@ -6219,40 +6220,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) 
jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 1) + self.assertEqual(jaxpr_text.count(' cos '), 2) def test_remat_of_cond_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use cond. @@ -6276,40 +6277,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 1) + self.assertEqual(jaxpr_text.count(' cos '), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6334,8 +6335,8 @@ def f(x): self.assertArraysAllClose(y_dot, expected, check_dtypes=False) jaxpr = api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) def test_remat_of_while_loop_policy(self): def cond_fn(carry): @@ -6352,8 +6353,8 @@ def f(x): save_cos = lambda prim, *_, **__: str(prim) == 'cos' g = new_checkpoint(f, policy=save_cos) jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @jtu.thread_unsafe_test() # logging isn't thread-safe def test_remat_residual_logging(self): diff --git a/tests/core_test.py b/tests/core_test.py index 03d6355cb257..8ab24dbe51f6 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -474,8 +474,8 @@ def new_jaxpr(): # jaxpr is: # # { lambda ; a. 
- # let b = sin[accuracy=None] a - # c = cos[accuracy=None] a + # let b = sin a + # c = cos a # d = add b c # in (d,) } # @@ -487,7 +487,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\[accuracy=None] a", + r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) jaxpr = new_jaxpr() @@ -496,7 +496,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\[accuracy=None] a", + r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index d40293501edf..af2d03e2945d 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2082,8 +2082,8 @@ def test_remat_of_pmap(self, remat): x = jnp.arange(1.) jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -2100,24 +2100,24 @@ def test_remat_of_pmap_policy(self, remat): _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = remat(g, policy=save_sin) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 2) save_nothing = lambda prim, *_, **__: False f = remat(g, policy=save_nothing) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 1) + self.assertEqual(jaxpr_text.count(' cos '), 2) def test_axis_name_shadowing_with_vmap(self): # vmap-of-pmap with mismatched axis sizes diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py index fb370ab96923..289e33a404f2 100644 --- a/tests/unary_ops_accuracy_test.py +++ b/tests/unary_ops_accuracy_test.py @@ -253,7 +253,7 @@ def f(x, y): @parameterized.named_parameters( *generate_test_cases(["exp", "expm1", "exp2"]) ) - def test_diff_grad(self, op, x, tp, **kwargs): + def test_diff_grad(self, op, x, tp, **kwargs): @jax.jit def f_default(x): default_op = op(x, accuracy=tp.low) @@ -368,6 +368,28 @@ def test_low_tol(self, op, x, **kwargs): ): op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0)) + def test_accuracy_jaxpr(self): + # Since accuracy is not set, the jaxpr should not contain "accuracy". + self.assertNotIn( + "accuracy", + str( + jax.make_jaxpr(lambda x: lax.exp(x, accuracy=None))( + np.arange(4.0, dtype=np.float32) + ) + ), + ) + # Set accuracy. 
+ self.assertIn( + "accuracy", + str( + jax.make_jaxpr( + lambda x: lax.exp( + x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0) + ) + )(np.arange(4.0, dtype=np.float32)) + ), + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From ca6e470d2f4c9e583532a1c413277e79ea2b7852 Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Mon, 7 Apr 2025 23:30:31 +0000 Subject: [PATCH 0450/1769] harden cache against jaxlib ver --- jax/_src/cache_key.py | 11 +++++++---- tests/cache_key_test.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index e4b6e7a2669c..6fe3d8819d3c 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -110,6 +110,10 @@ def get( bytes(jaxlib_version_str.encode("utf-8")) ), ), + ( + "backend version", + lambda hash_obj: _hash_platform(hash_obj, backend) + ), ( "XLA flags", lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()), @@ -126,7 +130,7 @@ def get( ), ( "accelerator_config", - lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend), + lambda hash_obj: _hash_accelerator_config(hash_obj, devices), ), ( "compression", @@ -220,7 +224,7 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None: _hash_string(hash_obj, device.device_kind) -def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): +def _hash_accelerator_config(hash_obj, accelerators: np.ndarray): accelerator_devices = [] for accelerator in accelerators.flat: accelerator_devices.append(accelerator) @@ -233,9 +237,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): # PjRtTopologyDescription as yet. logger.info("get (_hash_accelerator_config): unable to hash " "accelerator config, falling back to hashing " - "devices + platform: %s (type %s)", ex, type(ex)) + "devices %s (type %s)", ex, type(ex)) _hash_devices(hash_obj, accelerators) - _hash_platform(hash_obj, backend) # LINT.IfChange(xla_flags) xla_flags_to_exclude_from_cache_key = [ diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index a908d260d560..fd3e7706260a 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -83,9 +83,9 @@ def test_hash_accelerator_devices(self): self.assertEqual(dev_hash1, dev_hash2) acc_hash1 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) acc_hash2 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) self.assertEqual(acc_hash1, acc_hash2) def test_hash_platform(self): From 31589960ff30816e57977f7aaa7c04f97ba9cac6 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 7 Apr 2025 17:31:04 -0700 Subject: [PATCH 0451/1769] Migrate custom_call filecheck to use internal custom_call since the external one is deprecated. 
PiperOrigin-RevId: 744908555 --- tests/filecheck/custom_call.filecheck.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/filecheck/custom_call.filecheck.py b/tests/filecheck/custom_call.filecheck.py index c6af4235ebb4..27cc904e59d8 100644 --- a/tests/filecheck/custom_call.filecheck.py +++ b/tests/filecheck/custom_call.filecheck.py @@ -19,7 +19,7 @@ from absl import app import jax -from jax.interpreters import mlir +from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect import numpy as np From 86de4783bb472af6a2ef17e61bd926097aa525eb Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Apr 2025 19:25:34 -0700 Subject: [PATCH 0452/1769] Remove unused function jax._src.interpreters.mlir.xla_computation_to_mlir_module. PiperOrigin-RevId: 744934776 --- jax/_src/interpreters/mlir.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a112063ce3ae..65d9dbe5791f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2757,11 +2757,6 @@ def cached_lowering(ctx, *args, **params): return cached_lowering -def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation - ) -> ir.Module: - module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation) - return ir.Module.parse(module_str) - def merge_mlir_modules(dst_module: ir.Module, sym_name: str, src_module: ir.Module, From bb515aa74f24b7688b5a8f612990893d7da3654e Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Mon, 7 Apr 2025 20:00:46 -0700 Subject: [PATCH 0453/1769] Address previous FP8-related TODOs in jaxlib/XLA. The ml_dtype requirement in JAX was updated to version 0.5.0+ (on Mar 20, 2025) - commit 4b7ead4 This update allows us to address previous FP8-related TODOs in jaxlib/XLA. PiperOrigin-RevId: 744943824 --- jaxlib/xla/py_values.cc | 21 +++++++++++++++------ jaxlib/xla/xla.cc | 7 +++---- jaxlib/xla/xla_client.py | 11 +++++------ jaxlib/xla/xla_client.pyi | 9 ++++----- jaxlib/xla/xla_client_test.py | 32 +++++++++++++++++++++++++++----- 5 files changed, 54 insertions(+), 26 deletions(-) diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index 709f3cb3b2ef..90dd77209694 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -694,16 +694,25 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
- // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; + // TODO(upwind): Explore if we can remove std::optional for these types + // in xla/python/types.h and xla/python/types.cc + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler; + } (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; - (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler; + } (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index e460a1773e94..660e62bd8019 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -208,15 +208,14 @@ NB_MODULE(xla_extension, m) { .value("U64", U64) .value("F16", F16) .value("F4E2M1FN", F4E2M1FN) - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // .value("F8E3M4", F8E3M4) - // .value("F8E4M3", F8E4M3) - .value("F8E8M0FNU", F8E8M0FNU) + .value("F8E3M4", F8E3M4) + .value("F8E4M3", F8E4M3) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) .value("F8E5M2", F8E5M2) .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("F8E8M0FNU", F8E8M0FNU) .value("BF16", BF16) .value("F32", F32) .value("F64", F64) diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index fa31d1764de2..637d7d060aa2 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -260,16 +260,15 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), - # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
- # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), - # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), - # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), - # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), + PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), + PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + PrimitiveType.F8E4M3: np.dtype(float8_e4m3), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), - PrimitiveType.F8E5M2: np.dtype(float8_e5m2), PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz), + PrimitiveType.F8E5M2: np.dtype(float8_e5m2), PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz), + PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.BF16: np.dtype(bfloat16), PrimitiveType.F16: np.dtype('float16'), PrimitiveType.F32: np.dtype('float32'), diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index b182eb65ba60..382858d2a6d0 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -63,16 +63,15 @@ _ifrt_version: int mlir_api_version: int bfloat16: type[numpy.generic] -# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. -# float4_e2m1fn: type[numpy.generic] -# float8_e3m4: type[numpy.generic] -# float8_e4m3: type[numpy.generic] -# float8_e8m0fnu: type[numpy.generic] +float4_e2m1fn: type[numpy.generic] +float8_e3m4: type[numpy.generic] +float8_e4m3: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] float8_e5m2: type[numpy.generic] float8_e5m2fnuz: type[numpy.generic] +float8_e8m0fnu: type[numpy.generic] XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index 7de905d9ec41..9c6625610ca6 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -48,12 +48,12 @@ float4_e2m1fn = ml_dtypes.float4_e2m1fn float8_e3m4 = ml_dtypes.float8_e3m4 float8_e4m3 = ml_dtypes.float8_e4m3 -float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e5m2 = ml_dtypes.float8_e5m2 float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +float8_e8m0fnu = ml_dtypes.float8_e8m0fnu ops = xla_client.ops xla_computation_to_mlir_module = ( xla_client._xla.mlir.xla_computation_to_mlir_module) @@ -178,10 +178,17 @@ def TestFactory(xla_backend, # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. - fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + fp8_dtypes = [ + float8_e3m4, + float8_e4m3, + float8_e4m3fn, + float8_e4m3b11fnuz, + float8_e5m2, + float8_e8m0fnu, + ] standard_dtypes += fp8_dtypes - # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] + # TODO(upwind): testRoundTrip and testLiveBuffers fail for float4_e2m1fn type + # standard_dtypes += [float4_e2m1fn] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): @@ -1228,9 +1235,19 @@ def testStandardTypes(self): for dtype in standard_dtypes: if dtype == np.complex128: continue + # float8_e8m0fnu is not supported on TPU. 
+ if dtype == float8_e8m0fnu and self.backend.platform == "tpu": + continue # float8_e4m3b11fnuz not supported on some TPU backends. if ( - dtype in [float8_e5m2fnuz, float8_e4m3fnuz, float8_e4m3b11fnuz] + dtype + in [ + float8_e3m4, + float8_e4m3, + float8_e4m3fnuz, + float8_e4m3b11fnuz, + float8_e5m2fnuz, + ] and self.backend.platform == "tpu" ): if self.backend.platform_version.find("TPU") == -1: @@ -2253,6 +2270,11 @@ def testFft(self): "dtype": dtype, } for dtype in float_dtypes + fp8_dtypes) def testNextAfter(self, dtype): + if dtype == float8_e8m0fnu: + # TODO(b/409114865): Test fails with Mismatched elements error. + self.skipTest("b/409114865: Test fails with Mismatched elements error") + if dtype in [float8_e3m4, float8_e4m3] and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float8_e3m4 or float8_e4m3") if dtype == np.float64 and self.backend.platform == "tpu": self.skipTest("TPU doesn't support float64") if dtype == bfloat16 and self.backend.platform == "tpu": From 51dbcd4dad3acf6f83943d1febbb7d5c773c7f59 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 8 Apr 2025 00:09:27 -0700 Subject: [PATCH 0454/1769] [export] Add backwards compatibility test for annotate_device_placement. This enables exporting functions that use memory kinds to place data in different memories. jax-fixit PiperOrigin-RevId: 745008959 --- jax/_src/export/_export.py | 1 + .../annotate_data_placement.py | 73 +++++++++++++++++++ .../export_back_compat_test_util.py | 21 ++++-- tests/export_back_compat_test.py | 32 +++++++- 4 files changed, 118 insertions(+), 9 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 90cc0c186ad1..4315c948bb5c 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1082,6 +1082,7 @@ def _check_lowering(lowering) -> None: *_CPU_FFI_KERNELS, *_GPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "annotate_device_placement", "cu_threefry2x32_ffi", # Triton IR does not guarantee stability. # "__gpu$xla.gpu.triton", diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py new file mode 100644 index 000000000000..bf70df2cdb3a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py @@ -0,0 +1,73 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa + +import datetime +from numpy import array, float32, int32 + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_tpu = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes 
{jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5d5e95b5cb9a..b86b24e2b4fc 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -90,6 +90,7 @@ def func(...): ... from jax.experimental import pjit from jax._src import core +from jax._src import stages from jax._src import test_util as jtu from jax._src import xla_bridge as xb @@ -165,7 +166,8 @@ def load_testdata_nested(self, testdata_nest) -> Iterable[CompatTestData]: else: assert False, testdata_nest - def run_one_test(self, func: Callable[..., jax.Array], + def run_one_test(self, + func: Callable[..., jax.Array] | stages.Wrapped, data: CompatTestData, polymorphic_shapes: Sequence[str] | None = None, rtol: float | None = None, @@ -176,7 +178,8 @@ def run_one_test(self, func: Callable[..., jax.Array], """Run one compatibility test. 
Args: - func: the JAX function to serialize and run + func: the JAX function to serialize and run, either as a Python Callable + or as a `jax.jit(callable)`. data: the test data polymorphic_shapes: when using shape polymorphism, the specification for each argument of `func`. @@ -269,19 +272,22 @@ def run_one_test(self, func: Callable[..., jax.Array], expect_current_custom_calls = data.custom_call_targets self.assertItemsEqual(expect_current_custom_calls, current_custom_call_targets) - def run_current(self, func: Callable, data: CompatTestData): + def run_current(self, + func: Callable | stages.Wrapped, + data: CompatTestData): """Lowers and runs the test function at the current JAX version.""" - return jax.jit(func)(*data.inputs) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) + return jit_func(*data.inputs) def serialize(self, - func: Callable, data: CompatTestData, *, + func: Callable | stages.Wrapped, data: CompatTestData, *, polymorphic_shapes: Sequence[str] | None = None, allow_unstable_custom_call_targets: Sequence[str] = () ) -> tuple[bytes, str, int, int]: """Serializes the test function. Args: - func: the function to serialize + func: the function to serialize. polymorphic_shapes: the polymorphic_shapes to use for serialization allow_unstable_custom_call_targets: whether to allow additional custom call targets besides those known as stable. @@ -292,8 +298,9 @@ def serialize(self, """ # Use the native exporter, to make sure we get the proper serialization. args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) exported = export.export( - jax.jit(func), + jit_func, platforms=(self.default_jax_backend(),), disabled_checks=tuple( export.DisabledSafetyCheck.custom_call(target) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 789838f99d14..fd2b349f6c95 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -31,6 +31,7 @@ from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data import annotate_data_placement from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev @@ -161,6 +162,8 @@ def test_custom_call_coverage(self): stablehlo_dynamic_top_k.data_2023_07_16, stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion stablehlo_dynamic_approx_top_k.data_2024_05_30, + annotate_data_placement.data_2025_04_07_tpu, + annotate_data_placement.data_2025_04_07_cuda, ] # Some of the above are nested structures. 
covering_testdatas = itertools.chain( @@ -817,7 +820,7 @@ def func(x): ) self.run_one_test(func, data, rtol=rtol, atol=atol) - def test_approx_top_k(self): + def test_tpu_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) y = lax.approx_max_k(x, 3) @@ -834,7 +837,7 @@ def func(x): data = self.load_testdata(cuda_threefry2x32.data_2024_07_30) self.run_one_test(func, data) - def test_sharding(self): + def test_tpu_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: self.skipTest("Test runs only on TPU with at least 2 devices") @@ -856,6 +859,31 @@ def func(x): # b: f32[2, 4] with mesh: self.run_one_test(func, data) + @parameterized.named_parameters( + dict(testcase_name=f"_platform={platform}", platform=platform) + for platform in ("tpu", "gpu")) + def test_annotate_device_placement(self, platform): + if not jtu.test_device_matches([platform]): + self.skipTest(f"Test enabled only for {platform}") + + mesh = Mesh(jax.local_devices()[0:1], axis_names=("a")) + + dev_sharding = NS(mesh, P("a")) + host_sharding = NS(mesh, P("a"), memory_kind="pinned_host") + + @partial(jax.jit, + in_shardings=(dev_sharding, host_sharding), + out_shardings=host_sharding) + def func(x, y): + return x + y + + if platform == "tpu": + data = self.load_testdata(annotate_data_placement.data_2025_04_07_tpu) + else: + data = self.load_testdata(annotate_data_placement.data_2025_04_07_cuda) + + self.run_one_test(func, data) + def test_tpu_stablehlo_dynamic_reduce_window_unary(self): # stablehlo.dynamic_reduce_window is used temporarily on TPU for a # reduce window with dynamic shapes. From 19fcae12078b7a8823524203bd3e48cee9c254f5 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 8 Apr 2025 00:33:16 -0700 Subject: [PATCH 0455/1769] [Mosaic GPU] Add support for replicated warp_dim parsing and a dedicated test for parsing all canonical layouts. 
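
For illustration, the round trip exercised by the new test now works for all
canonical layouts, including ones with replicated dimensions (sketch only;
`mgpu` and `layouts` refer to the modules touched in the diff below):

```python
layout = mgpu.WGMMA_ROW_LAYOUT  # one of the canonical tiled layouts
attr = layouts.to_tiled_layout_attr(layout)
# Replicated dimensions print as e.g. `#mosaic_gpu.Replicated<times=2>`
# instead of a plain integer; the parser now accepts both forms for
# warp_dim as well as lane_dims.
assert layouts.from_tiled_layout_attr(attr) == layout
```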
PiperOrigin-RevId: 745015431 --- jax/experimental/mosaic/gpu/layouts.py | 26 ++++++++++++++++--------- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 2 +- tests/mosaic/gpu_dialect_test.py | 12 ++++++++++++ 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index d9b1a01a24b5..0a4f3ed09116 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -96,7 +96,7 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: _tiled_layout_attr_pattern = re.compile( r"^#mosaic_gpu.TiledLayout<\[(?P.*)\]," - r" warp_dim\s*=\s*(?P[-\d]+)," + r" warp_dim\s*=\s*(?P.+)," r" lane_dims\s*=\s*\[(?P.*)\]," r" vector_dim\s*=\s*(?P[-\d]+)>$" ) @@ -107,22 +107,26 @@ def to_tiled_layout_attr( ) -> ir.Attribute: """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" - def _lane_dim_str(d: int | fa.Replicated) -> str: + def _int_or_replicated(d: int | fa.Replicated) -> str: if isinstance(d, fa.Replicated): return f"#mosaic_gpu.Replicated" return str(d) tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" - lane_dims = "[" + ",".join(_lane_dim_str(d) for d in layout.lane_dims) + "]" + lane_dims = ( + "[" + ",".join(_int_or_replicated(d) for d in layout.lane_dims) + "]" + ) return ir.Attribute.parse( - f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," + f"#mosaic_gpu.TiledLayout<{tiling}," + f" warp_dim={_int_or_replicated(layout.warp_dim)}," f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>" ) _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") +_int_pattern = re.compile(r"^(?P[-\d]+)(\s*:\s*\w+)?$") _replicated_pattern = re.compile( r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P\d+)\s*>\s*$" ) @@ -143,11 +147,14 @@ def from_tiled_layout_attr( f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" ) - def _lane_dim(lane_dim_str: str) -> int | fa.Replicated: - match = _replicated_pattern.fullmatch(lane_dim_str) + def _int_or_replicated(replicated_dim: str) -> int | fa.Replicated: + match = _replicated_pattern.fullmatch(replicated_dim) if match: return fa.Replicated(int(match.group("times"))) - return int(lane_dim_str) + match = _int_pattern.fullmatch(replicated_dim) + if match: + return int(match.group("num")) + raise ValueError(f"Unexpected format for replicated dim {replicated_dim}") tiling_str = match.group("tiling") tile_strings = [] @@ -156,9 +163,10 @@ def _lane_dim(lane_dim_str: str) -> int | fa.Replicated: tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) return fa.TiledLayout( tiling=fa.Tiling(tiles), - warp_dim=int(match.group("warp_dim")), + warp_dim=_int_or_replicated(match.group("warp_dim")), lane_dims=tuple( - _lane_dim(s) for s in match.group("lane_dims").split(",") + _int_or_replicated(s.strip()) + for s in match.group("lane_dims").split(",") ), vector_dim=int(match.group("vector_dim")), ) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 36f9f6f374e5..86219dbc87ac 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -161,7 +161,7 @@ def MosaicGPU_TiledLayout : AttrDef { let parameters = (ins "::mlir::ArrayAttr":$tiling, - "int":$warp_dim, + "::mlir::Attribute":$warp_dim, "::mlir::ArrayAttr":$lane_dims, "int":$vector_dim ); diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 
7e211abb955a..2d75c42424ef 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -593,6 +593,18 @@ def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self): ): self.module.operation.verify() + def test_tiled_layout_attr_parsing(self): + with ir.InsertionPoint(self.module.body): + for layout in ( + mgpu.WGMMA_LAYOUT, + mgpu.WGMMA_ROW_LAYOUT, + mgpu.WGMMA_COL_LAYOUT, + mgpu.WGMMA_TRANSPOSED_LAYOUT, + ): + attr = layouts.to_tiled_layout_attr(layout) + parsed_layout = layouts.from_tiled_layout_attr(attr) + self.assertEqual(layout, parsed_layout) + class DialectLoweringTest(MosaicGpuTest): From bc11a63113f543779c1ed8b794b00c3e747d17f7 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Tue, 8 Apr 2025 09:50:31 +0200 Subject: [PATCH 0456/1769] Clarify jax.make_jaxpr docstring --- jax/_src/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 0055f6466dae..89dfe74acd5b 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2142,7 +2142,7 @@ def make_jaxpr( return_shape: bool = False, abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]: - """Creates a function that produces its jaxpr given example args. + """Create a function that returns the jaxpr of ``fun`` given example args. Args: fun: The function whose ``jaxpr`` is to be computed. Its positional From 8ed59d8b5d99619e35b3a7ab595e11fe1668ada2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 02:05:52 -0700 Subject: [PATCH 0457/1769] Removed `jax._src.raise_to_shaped` It is just an identity after the "stackless" rewrite. PiperOrigin-RevId: 745042532 --- jax/_src/core.py | 5 ----- jax/_src/pallas/fuser/fusable.py | 6 +----- jax/_src/pallas/fuser/jaxpr_fusion.py | 6 +----- tests/api_test.py | 2 +- tests/state_test.py | 4 ++-- 5 files changed, 5 insertions(+), 18 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 8d32b9370091..6c5b7a08a0e9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2229,11 +2229,6 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -# TODO(dougalm): Deprecate these. They're just here for backwards compat. -def raise_to_shaped(aval): - return aval -raise_to_shaped_mappings: dict[type, Callable] = {} - ### Operations on shapes and dimension sizes. 
class InconclusiveDimensionOperation(Exception): diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusable.py index aa2ea0843c0a..d9d0ee0b4682 100644 --- a/jax/_src/pallas/fuser/fusable.py +++ b/jax/_src/pallas/fuser/fusable.py @@ -29,10 +29,6 @@ fusable_p.multiple_results = True -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) - - def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: return fusion_lib.Fusion( func=lambda: x, @@ -53,7 +49,7 @@ def wrapped(*args): flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(wrapped, debug_info=debug_info), in_tree ) - flat_avals = [_get_aval(x) for x in flat_args] + flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) out_tree = out_tree_thunk() out = fusable_p.bind( diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 649037e18092..3c3c2a3d7b66 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -28,10 +28,6 @@ from jax._src.pallas.fuser.fusable import fusable_p -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) - - def fuse(f=None, *, physicalize: bool = False, debug: bool = False): """Fuses a function into a single fusable. @@ -52,7 +48,7 @@ def wrapper(*args, **kwargs): flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, debug_info=debug_info), in_tree ) - flat_avals = [_get_aval(x) for x in flat_args] + flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) if debug: print("Jaxpr before fusion:") diff --git a/tests/api_test.py b/tests/api_test.py index 440fea1b059c..0e8cf2502540 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5024,7 +5024,7 @@ def g(x): # Make sure that introducing constants in vmap works. constant_introducing_p = core.Primitive('introduce_constant') - constant_introducing_p.def_abstract_eval(core.raise_to_shaped) + constant_introducing_p.def_abstract_eval(lambda x: x) def _constant_introducing_batcher(xs, ds): (x,), (d,) = xs, ds return (x + np.arange(x.size, dtype=x.dtype).reshape(x.shape)), d diff --git a/tests/state_test.py b/tests/state_test.py index 60a7d8bc9f8a..03902687c40e 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -792,7 +792,7 @@ def body(i, st): lax.fori_loop(0, 5, body, init_val=()) return a_ref[...], b_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -1139,7 +1139,7 @@ def false_fun(): y_ref[...] = 2. lax.cond(pred, true_fun, false_fun) return x_ref[...], y_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) # Effects on y_ref were discharged away but not the effects on x_ref From af072feb5a02bc75d4f9fec487ae59be60b0c01b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 02:37:36 -0700 Subject: [PATCH 0458/1769] Removed redundant `pass`es If a function or class has a docstring, it does not need a `pass`. 
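
As a schematic example (the class name here is hypothetical, not from the diff):

```python
# Before: the trailing `pass` is redundant.
class FrobnicationError(Exception):
    """Raised when frobnication fails."""
    pass

# After: the docstring alone is a complete class body.
class FrobnicationError(Exception):
    """Raised when frobnication fails."""
```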
PiperOrigin-RevId: 745052107 --- jax/_src/core.py | 1 - jax/_src/errors.py | 1 - jax/_src/pallas/fuser/fusable_dtype.py | 2 -- jax/_src/profiler.py | 1 - tests/api_test.py | 1 - 5 files changed, 6 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 6c5b7a08a0e9..236781c16d27 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2233,7 +2233,6 @@ def block_until_ready(self): class InconclusiveDimensionOperation(Exception): """Raised when we cannot conclusively compute with symbolic dimensions.""" - pass def is_symbolic_dim(v: Any) -> bool: """Checks if a value is a symbolic dimension used for shape polymorphism. diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 6540fd1f5d41..b9831bfe3b1a 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -680,4 +680,3 @@ class KeyReuseError(JAXTypeError): must be manually split; For more information on this see `the Pseudorandom Numbers tutorial `_. """ - pass diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusable_dtype.py index e5bc9ab683ab..99c80e652791 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusable_dtype.py @@ -83,8 +83,6 @@ def unpack(x): class FusableElementDType(dtypes.extended): """Scalar dtype for fusable dtypes.""" - pass - class FusableTyRules: allow_conversion: bool = False diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 96e742f33904..0e9949f27f55 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -272,7 +272,6 @@ class TraceAnnotation(xla_client.profiler.TraceMe): This will cause a "my_label" event to show up on the trace timeline if the event occurs while the process is being traced. """ - pass class StepTraceAnnotation(TraceAnnotation): diff --git a/tests/api_test.py b/tests/api_test.py index 0e8cf2502540..2d9fcd1ff554 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3120,7 +3120,6 @@ def test_error_for_invalid_dtype(self): def test_vmap_preserves_docstr(self): def superfun(a): """Does things with stuff.""" - pass self.assertRegex(api.vmap(superfun).__doc__, "\n".join([ "Vectorized version of superfun.*", From d12cbffd4912980f290d676ee1b606cb9d1c9ad2 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 8 Apr 2025 03:04:47 -0700 Subject: [PATCH 0459/1769] [Mosaic GPU] Refactor and generalize code in `optimization_barrier`. The change in `utils.py` is to enable the use of `bitwidth` when the mlir dialect is not registered. 
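
The key generalization is that a vector register may now be backed by more
than one i32 register in the inline-asm constraints. A sketch of the
arithmetic (helper name hypothetical; mirrors `bitwidth(reg_ty) // 32` in the
diff below):

```python
def num_i32_regs(elem_bits: int, vec_len: int) -> int:
    # Registers are exchanged with the asm block in 32-bit pieces, so the
    # vector register's total bit width must be a multiple of 32.
    total_bits = elem_bits * vec_len
    assert total_bits % 32 == 0
    return total_bits // 32

assert num_i32_regs(16, 2) == 1  # two f16/bf16 lanes: one i32 register
assert num_i32_regs(16, 4) == 2  # four lanes: two i32 registers
```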
PiperOrigin-RevId: 745060221 --- jax/experimental/mosaic/gpu/core.py | 1 + .../mosaic/gpu/fragmented_array.py | 40 ++++++++++--------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index e822ea5f3ebf..860b41e7e8e3 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -479,6 +479,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) module = ir.Module.create() + dialect.register_dialect(module.context) attrs = module.operation.attributes attrs["sym_name"] = ir.StringAttr.get(module_name) if kernel_name is None: diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ecd51f79eab0..9ab27927791a 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2434,10 +2434,23 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) + def _repack(regs_it, reg_ty): + if not ir.VectorType.isinstance(reg_ty): + result_reg = next(regs_it) + assert result_reg.type == reg_ty + return result_reg + + num_i32_regs = utils.bitwidth(reg_ty) // 32 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) + reg = llvm.mlir_undef(i32_reg_ty) + for i_elem in range(num_i32_regs): + val = llvm.bitcast(i32, next(regs_it)) + reg = llvm.insertelement(reg, val, arith.constant(i32, i_elem)) + return vector.bitcast(reg_ty, reg) + regs = [] reg_dtypes = [] reg_constraints = [] - repack_fns = [] # We unpack each array into a flat list of registers, and prepare the # functions that invert the transform in repack_fns. 
for array in arrays: @@ -2451,36 +2464,25 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): for reg in array.registers.flat for pos in range(vec_len) ] - def _repack(regs, reg_ty=reg_ty): - reg = llvm.mlir_undef(reg_ty) - [vec_len] = ir.VectorType(reg_ty).shape - for i_elem in range(vec_len): - reg = llvm.insertelement( - reg, next(regs), arith.constant(i32, i_elem) - ) - return reg - repack_fns.append(_repack) else: array_regs = list(array.registers.flat) - repack_fns.append(lambda regs: next(regs)) reg_constraint = "f" elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): if not ir.VectorType.isinstance(reg_ty): raise NotImplementedError(array.mlir_dtype) [vec_len] = ir.VectorType(reg_ty).shape - if vec_len != 2: + if vec_len % 2: raise NotImplementedError(vec_len) - i32_reg_ty = ir.VectorType.get((1,), i32) + num_i32_regs = vec_len // 2 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) array_regs = [ vector.extractelement( - vector.bitcast(i32_reg_ty, reg), position=c(0, index) + vector.bitcast(i32_reg_ty, reg), position=c(i, index) ) + for i in range(num_i32_regs) for reg in array.registers.flat ] reg_constraint = "r" - def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): - return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) - repack_fns.append(_repack) else: raise NotImplementedError(array.mlir_dtype) regs += array_regs @@ -2508,14 +2510,14 @@ def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): i32 = ir.IntegerType.get_signless(32) results = [] regs_it = iter(regs) - for array, repack_fn in zip(arrays, repack_fns, strict=True): + for array in arrays: num_regs = array.registers.size reg_ty = array.registers.flat[0].type if ir.VectorType.isinstance(reg_ty): reg_ty = ir.VectorType(reg_ty) new_registers = np.empty((num_regs,), dtype=object) for i_vreg in range(num_regs): - reg = repack_fn(regs_it) + reg = _repack(regs_it, reg_ty) assert reg.type == reg_ty, (reg.type, reg_ty) new_registers[i_vreg] = reg results.append( From c4cc94a10cde3e480b7a4b6c76d304d782292895 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 8 Apr 2025 03:22:30 -0700 Subject: [PATCH 0460/1769] [Mosaic GPU] Add warpgroup lowering for `RunState` in Pallas. After this change we no longer skip tests that required 'RunState`. This necessitated a small fix in the pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering. 
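
The pattern this unblocks looks roughly as follows (a sketch based on the
`test_wgmma_registers_init` test below; `plgpu.ACC.init` seeds the
accumulator from an array of registers):

```python
def kernel(a_ref, b_ref, i_ref, o_ref):
  def scope(acc_ref):
    plgpu.wgmma(acc_ref, a_ref[...], b_ref)

  # Under warpgroup semantics the accumulator input is now routed through
  # an optimization barrier and a WGMMA fence instead of being rejected.
  o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))
```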
PiperOrigin-RevId: 745065173 --- jax/_src/pallas/mosaic_gpu/lowering.py | 31 ++++++++++++++++---------- tests/pallas/mosaic_gpu_test.py | 22 +++++++++--------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f7bdbccc1ad6..3fc2362decfc 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2034,6 +2034,7 @@ def _run_scoped_lowering_rule( @register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2051,7 +2052,12 @@ def _run_state_lowering_rule( for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out): aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + arg = mgpu.dialect.optimization_barrier([arg]) + nvvm_dialect.wgmma_fence_aligned() + new_input_vals.append(arg) + else: + new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) should_discharge.append(True) assert isinstance(out_aval, jax_core.ShapedArray) else: @@ -2273,18 +2279,19 @@ def _while_lowering_rule( ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args ) loop_out = [*map(_ensure, loop_out, carry_avals)] - for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): - if _is_acc(carry_fa) != _is_acc(out_fa): - raise ValueError( - f"The loop body output has unexpected accumulator type: output[{idx}]" - f" is {out_fa}, when it should be {carry_fa}." - ) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): + if _is_acc(carry_fa) != _is_acc(out_fa): + raise ValueError( + f"The loop body output has unexpected accumulator type:" + f" output[{idx}] is {out_fa}, when it should be {carry_fa}." + ) - if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: - raise ValueError( - f"The loop body output has unexpected layout: output[{idx}] has" - f" layout {out_fa.layout}, when it should be {carry_fa.layout}." - ) + if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: + raise ValueError( + f"The loop body output has unexpected layout: output[{idx}] has" + f" layout {out_fa.layout}, when it should be {carry_fa.layout}." 
+ ) scf_dialect.yield_( carry_treedef.flatten_up_to(loop_out) if loop_out else [] ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f0f3bdf41c32..a73c4f82c31d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -32,7 +32,6 @@ from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives -from jax._src.state import discharge from jax.experimental import pallas as pl import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu @@ -1528,7 +1527,6 @@ def test_missing_primitive_lowerings_are_tracked(self): mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, lax.slice_p, - discharge.run_state_p, } self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) @@ -1538,10 +1536,14 @@ class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): - # ``pl.run_state`` is not supported in WG semantics. - self.skip_if_wg_semantics() - - transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + if force_while: + # Layout inference and lowering for 'while' are not yet implemented for + # warpgroup semantics. + self.skip_if_wg_semantics() + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () @functools.partial( self.pallas_call, in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)], @@ -1733,9 +1735,6 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) def test_wgmma_registers_init(self): - # ``pl.run_state`` is not supported in WG semantics. - self.skip_if_wg_semantics() - def kernel(a_ref, b_ref, i_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -1746,7 +1745,10 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 - transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () res = self.pallas_call( kernel, in_specs=[ From 12811f08a8fc5fec7c39d17e3e48d14a8e339f06 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 03:29:50 -0700 Subject: [PATCH 0461/1769] Removed `eager_pmap` config option It defaults to True and is not flipped to False by any internal JAX users. PiperOrigin-RevId: 745067361 --- CHANGELOG.md | 1 + jax/_src/config.py | 7 ------- jax/_src/interpreters/pxla.py | 4 ++-- tests/pmap_test.py | 2 +- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aae0f432121..e744cad902de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Removed the `config.jax_data_dependent_tracing_fallback` config option, which was added temporarily in v0.4.36 to allow users to opt out of the new "stackless" tracing machinery. + * Removed the `config.jax_eager_pmap` config option. * Changes * The minimum CuDNN version is v9.8. 
diff --git a/jax/_src/config.py b/jax/_src/config.py index 8aa4ee343664..1fbb401afb61 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1514,13 +1514,6 @@ def _update_disable_jit_thread_local(val): 'compute when encountering OOM errors. However, you are ' 'likely to get better results manually with jax.checkpoint')) -# TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = bool_state( - name='jax_eager_pmap', - default=True, - upgrade=True, - help='Enable eager-mode pmap when jax_disable_jit is activated.') - no_tracing = bool_state( name='jax_no_tracing', default=False, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 51854b457b37..45bdd4e17e8e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -338,8 +338,8 @@ def xla_pmap_impl_lazy( donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, ) -> Callable: - if (config.disable_jit.value and config.eager_pmap.value and - not is_explicit_global_axis_size and not any(d for d in donated_invars)): + if (config.disable_jit.value and + not is_explicit_global_axis_size and not any(donated_invars)): def _emap_apply_fn(*args): return _emap_impl(fun, *args, backend=backend, axis_name=axis_name, axis_size=axis_size, global_axis_size=global_axis_size, diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e2945d..a07a9e271907 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3189,7 +3189,7 @@ class EagerPmapMixin: def setUp(self): super().setUp() stack = contextlib.ExitStack() - stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True)) + stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True)) stack.enter_context(jtu.ignore_warning( message="Some donated buffers were not usable", category=UserWarning)) self.addCleanup(stack.close) From 5f33280dedb50e72abc3613461bbbe8a67b97f70 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 04:54:30 -0700 Subject: [PATCH 0462/1769] [pallas:mosaic_gpu] `emit_pipeline*` now allows the grid to be dynamic PiperOrigin-RevId: 745091128 --- jax/_src/pallas/mosaic_gpu/lowering.py | 7 ++-- jax/_src/pallas/mosaic_gpu/pipeline.py | 42 +++++++++++++----------- jax/_src/pallas/mosaic_gpu/primitives.py | 5 +-- tests/pallas/mosaic_gpu_test.py | 38 +++++++++++++-------- 4 files changed, 54 insertions(+), 38 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 3fc2362decfc..b7aa01dbbfcf 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -897,8 +897,11 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): if val.type != mlir_dtype: raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}") - foreach(write_env, jaxpr.constvars, consts) - foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) + foreach( + functools.partial(write_env, require_value=False), jaxpr.constvars, consts + ) + foreach(functools.partial(write_env, require_value=False), jaxpr.invars, args) + # TODO(justinfu): Handle transform scopes. 
last_local_name_stack: list[str] = [] named_regions = [] diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index df9c6668a51d..ecd7e792afbe 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -114,7 +114,7 @@ def _uses_arguments( def _is_index_invariant( - spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid + spec: pallas_core.BlockSpec, grid: pallas_core.TupleGrid ) -> bool: if (index_map := spec.index_map) is None: return True @@ -122,7 +122,7 @@ def _is_index_invariant( def _inc_grid_by_1( - indices: tuple[jax.Array, ...], grid: Sequence[int] + indices: tuple[jax.Array, ...], grid: pallas_core.TupleGrid ) -> tuple[jax.Array, ...]: next_indices = [] carry: bool | jax.Array = True @@ -161,7 +161,7 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore def emit_pipeline( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, @@ -182,19 +182,19 @@ def emit_pipeline( ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. """ - num_steps = math.prod(grid) - if max_concurrent_steps <= delay_release: raise ValueError( "max_concurrent_steps must be greater than delay_release, but" f" {max_concurrent_steps=}, {delay_release=}" ) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_steps: + if not has_dynamic_grid and max_concurrent_steps > num_steps: max_concurrent_steps = num_steps - delay_release = 0 # No need to delay anything. def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -244,12 +244,14 @@ def scoped_pipeline( ) ] - for step, indices in enumerate( - it.islice(it.product(*map(range, grid)), max_concurrent_steps) - ): - indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) + # Initialize the pipeline. + indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + fetch_indices = indices + for step in range(max_concurrent_steps): for bref in in_brefs: - bref.copy_in(step, indices, barrier_ref) + bref.copy_in(step, fetch_indices, barrier_ref) + fetch_indices = _inc_grid_by_1(fetch_indices, grid) + del fetch_indices # This is true if any of the outputs need to be transferred inside the loop. copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) @@ -327,7 +329,6 @@ def do_fetch(): # Invariant: ``indices`` and ``fetch_indices`` are always # ``max_concurrent_steps-delay_release`` apart. 
- indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) fetch_indices = indices for _ in range(max_concurrent_steps-delay_release): fetch_indices = _inc_grid_by_1(fetch_indices, grid) @@ -362,7 +363,7 @@ def do_fetch(): def emit_pipeline_warp_specialized( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, memory_registers: int, in_specs: Sequence[pl.BlockSpec] = (), out_specs: Sequence[pl.BlockSpec] = (), @@ -434,7 +435,8 @@ def body( not _is_index_invariant(spec, grid) for spec in out_specs] spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis] - num_pipeline_steps = math.prod(grid) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) def _get_slot(step, has_seq_dim): """Returns the buffer slot given the pipeline step.""" @@ -445,8 +447,8 @@ def _get_slot(step, has_seq_dim): # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_pipeline_steps: - max_concurrent_steps = num_pipeline_steps + if not has_dynamic_grid and max_concurrent_steps > num_steps: + max_concurrent_steps = num_steps def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -612,7 +614,7 @@ def compute_loop_body(step, carry): carry_init = None init_loop_carry = (init_indices, last_store_slices, carry_init) last_indices, _, final_body_carry = lax.fori_loop(0, - num_pipeline_steps, + num_steps, compute_loop_body, init_loop_carry) if has_carry: @@ -626,7 +628,7 @@ def compute_loop_body(step, carry): # written in the main pipeline loop. if not copies_out_in_loop: gpu_primitives.commit_smem() - last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps) + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: bref.copy_out(_get_slot(last_slot, has_seq_dim=False), @@ -671,7 +673,7 @@ def memory_loop_body(step, carry): _get_slot(fetch_slot, not bref.is_index_invariant), indices, barrier) next_indices = _inc_grid_by_1(indices, grid) return (next_indices,) - lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps, + lax.fori_loop(0, num_steps - max_concurrent_steps, memory_loop_body, (indices,)) wg_idx = lax.axis_index(wg_axis) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index a37b018860d7..c41a36da94e8 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -192,8 +192,9 @@ def _copy_smem_to_gmem_pp_eqn( pp_params = {} if not (commit_group := eqn.params["commit_group"]): pp_params["commit_group"] = commit_group - if has_user_predicate := eqn.params["has_user_predicate"]: - pp_params["has_user_predicate"] = has_user_predicate + if eqn.params["has_user_predicate"]: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + pp_params["user_predicate"] = jax_core.pp_var(user_predicate, context) if reduction_op := eqn.params["reduction_op"]: pp_params["reduction_op"] = reduction_op flat_src_transforms, flat_dst_transforms = util.split_list( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index a73c4f82c31d..e32222775f94 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2096,16 +2096,21 @@ def kernel_body(_, x_smem, o_smem): y = x + 1.0 np.testing.assert_array_equal(kernel_fn(x), y) - def test_emit_with_2d_grid(self): + 
@parameterized.product(static=[False, True])
+  def test_emit_with_2d_grid(self, static):
     num_steps1 = 4
     num_steps2 = 5

     def kernel(x_gmem, o_gmem):
+      grid = (num_steps1, num_steps2)
+      if not static:
+        grid = jax.tree.map(jnp.asarray, grid)
+
       plgpu.emit_pipeline(
           kernel_body,
           in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))],
           out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))],
-          grid=(num_steps1, num_steps2),
+          grid=grid,
           max_concurrent_steps=2,
       )(x_gmem, o_gmem)

@@ -2258,8 +2263,8 @@ def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers):
     np.testing.assert_array_equal(out, x)
     np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:])

-  @parameterized.product(m=[256], n=[256], num_compute_wgs=[1, 2])
-  def test_elementwise_add(self, m, n, num_compute_wgs):
+  @parameterized.product(m=[256], n=[256], num_compute_wgs=[1, 2], static=[False, True])
+  def test_elementwise_add(self, m, n, num_compute_wgs, static):
     self.skip_if_wg_semantics()  # Crashes!

     blk_m = blk_n = 64
@@ -2273,16 +2278,21 @@ def tiled_add_kernel(_, x_smem, y_smem, o_smem):
       # This is currently a race, but the values written are the same.
       o_smem[...] = x_smem[...] + y_smem[...]

-    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
-        tiled_add_kernel,
-        grid=(m // blk_m, n // blk_n),
-        max_concurrent_steps=2,
-        num_compute_wgs=num_compute_wgs,
-        memory_registers=40,
-        wg_axis="wg",
-        in_specs=[spec, spec],
-        out_specs=[spec],
-    )
+    def pipeline(*gmem_refs):
+      grid = (m // blk_m, n // blk_n)
+      if not static:
+        grid = jax.tree.map(jnp.asarray, grid)
+      return mgpu_pipeline.emit_pipeline_warp_specialized(
+          tiled_add_kernel,
+          grid=grid,
+          max_concurrent_steps=2,
+          num_compute_wgs=num_compute_wgs,
+          memory_registers=40,
+          wg_axis="wg",
+          in_specs=[spec, spec],
+          out_specs=[spec],
+      )(*gmem_refs)
+
     kernel = self.kernel(
         pipeline,
         out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),

From 73ecf0bb483eb8239670c1c7a07349519bcf70ac Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 8 Apr 2025 05:24:34 -0700
Subject: [PATCH 0463/1769] Remove unused `return wrapper` in annotate_function
 that creates a self-reference cycle in Python.
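To make the cycle concrete, here is a minimal standalone sketch (a hypothetical
decorator, not the JAX code): referencing `wrapper` from inside its own body turns
`wrapper` into a free variable of itself, so the function object holds a closure cell
that points back at the function, and the pair can only be reclaimed by the cycle
collector rather than by plain reference counting.

```python
def annotate(func):
  def wrapper(*args, **kwargs):
    return func(*args, **kwargs)
    return wrapper  # Unreachable, but makes `wrapper` close over itself.
  return wrapper

f = annotate(lambda: None)
# `f` closes over a cell that holds `f` itself: a reference cycle.
assert "wrapper" in f.__code__.co_freevars
```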
PiperOrigin-RevId: 745099538
---
 jax/_src/profiler.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py
index 0e9949f27f55..912c90182977 100644
--- a/jax/_src/profiler.py
+++ b/jax/_src/profiler.py
@@ -332,7 +332,6 @@ def annotate_function(func: Callable, name: str | None = None,
   def wrapper(*args, **kwargs):
     with TraceAnnotation(name, **decorator_kwargs):
       return func(*args, **kwargs)
-    return wrapper
   return wrapper

From 511f78202ff94c7fb88eb2f2ea7427a043c52962 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 8 Apr 2025 11:52:53 +0000
Subject: [PATCH 0464/1769] Add a skeleton for Pallas:Mosaic GPU documentation

---
 docs/_static/pallas/gpu/nvidia_sm.svg |  99 +++++++++++++++++
 docs/pallas/gpu/index.rst             |  14 +++
 docs/pallas/gpu/reference.md          | 150 ++++++++++++++++++++++++++
 docs/pallas/index.rst                 |   8 +-
 4 files changed, 270 insertions(+), 1 deletion(-)
 create mode 100644 docs/_static/pallas/gpu/nvidia_sm.svg
 create mode 100644 docs/pallas/gpu/index.rst
 create mode 100644 docs/pallas/gpu/reference.md

diff --git a/docs/_static/pallas/gpu/nvidia_sm.svg b/docs/_static/pallas/gpu/nvidia_sm.svg
new file mode 100644
index 000000000000..76b4edb2afad
--- /dev/null
+++ b/docs/_static/pallas/gpu/nvidia_sm.svg
@@ -0,0 +1,99 @@
[SVG figure: a "Streaming Multiprocessor" box containing four identical subdivisions,
each with a Warp Scheduler, TensorCore, ALU (Float/Int), Load/Store and Special
Functions unit, sitting on top of a shared "Shared Memory / L1 Cache" block]

diff --git a/docs/pallas/gpu/index.rst b/docs/pallas/gpu/index.rst
new file mode 100644
index 000000000000..2d95d5c928c4
--- /dev/null
+++ b/docs/pallas/gpu/index.rst
@@ -0,0 +1,14 @@
Pallas:Mosaic GPU
=================
Backend-specific documentation for the Mosaic GPU backend.

.. toctree::
   :caption: Reference documentation
   :maxdepth: 2

   reference

.. toctree::
   :caption: Guides
   :maxdepth: 2

diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md
new file mode 100644
index 000000000000..416679d9654c
--- /dev/null
+++ b/docs/pallas/gpu/reference.md
@@ -0,0 +1,150 @@
# Writing Mosaic GPU kernels with Pallas

This page is a reference for the most important features of the Pallas:MGPU backend.
It's not a tutorial, and as such we do not expect everyone to read it top to bottom.
Still, it is worth going over just to familiarise yourself with some patterns you can
find in other tutorials.

In the following examples, we're going to assume the following imports are in scope:
```python
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
```

## What is a GPU?

Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into
_streaming multiprocessors_ (SMs). The way this manifests in the CUDA programming model
is that each _CUDA thread block_ (or CTA) is scheduled on exactly one SM, but multiple
blocks can be scheduled onto a single SM at a time.

Each SM contains a chunk of fast memory called _shared memory_ (SMEM) and 4 subdivisions,
each containing a _warp scheduler_ and compute units (ALU, TensorCore, ...).
This is also reflected in the CUDA programs: each _warp_ (a group of 32 consecutive CUDA
threads in a block) is assigned to one of those subdivisions in a round-robin fashion.
Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates),
but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the
warp scheduler from each subdivision tries to select one of its resident warps to execute
the next instruction.

![A diagram of one SM](../../_static/pallas/gpu/nvidia_sm.svg)

Going further, recent CUDA versions also outline the concept of a _warpgroup_: a group of
4 consecutive warps. Knowing what the hardware looks like, we can see where this is coming
from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions
that utilize the whole SM.

> A GPU can be viewed in many different ways, and here we want to focus on a slightly
  simplified model that is very TensorCore-centric. This should help you navigate the
  complexities of writing kernels involving the TensorCore, but keep in mind that the
  real picture is more complicated.

For our purposes, TensorCore operations have grown so big that it no longer makes much
sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores
(SMs) with one thread of Pallas:MGPU corresponding to a CUDA warpgroup. In this model, each
operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent
warps always run in lockstep (modulo the jitter from hardware scheduling) and never take
different paths through control flow (with the small exception of `core_map`, which we will
discuss later). One notable addition here is that we still allow you to co-schedule multiple
of those Pallas-level threads on the same SM so that they can cooperate and communicate
through shared memory (we realize that by putting them in the same CUDA block).

> This is very similar to a programming model popularized by [Triton](https://triton-lang.org/),
  but as you will see there are a few differences. Mosaic GPU tends to be more low level,
  which usually means you will have to put in more work, but it also puts you more in control.
  In our view both approaches have their merits, and we encourage you to pick the backend that
  suits your needs best! Pallas supports and will continue to support Triton as an alternative
  GPU backend.

### In-order execution & using multiple hardware units

Unlike more complicated CPU architectures, GPUs only support in-order execution. That, however,
does not mean that at any given time only a single instruction is running! Each SM quarter
has multiple independent functional units: TensorCore, arithmetic logic unit (ALU),
Load/Store (LSU), special function unit (SFU). If the first instruction targets one of the
units and is followed by another one (that does not use the result of the first), then the
warp scheduler can issue the second one before the first one completes. This is often referred
to as instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels:
TensorCore operations are so big and take so many cycles to complete that it is a waste not
to try to use other units in the meantime.
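As a toy illustration of the dependence constraint behind ILP, consider the kernel below
(a sketch of ours, not part of the official reference; whether the two independent
operations actually overlap is ultimately up to the compiler and the hardware):

```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def body(x_ref, y_ref, o_ref):
  a = x_ref[...] * 2.0     # ALU work.
  b = jnp.exp(y_ref[...])  # Independent of `a`, so it can be issued before
                           # the multiply completes (SFU work).
  o_ref[...] = a + b       # Depends on both, so it has to wait for them.

x = y = jnp.ones((128,), jnp.float32)
out = pl.pallas_call(body, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)
```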
To extend this even further, we can take advantage of this hardware-unit-level parallelism by
allowing multiple Pallas threads (warpgroups) to run concurrently. If one of the threads
primarily occupies the ALU while another one primarily issues TensorCore-related instructions,
we can take advantage of the efficient context switching built into the warp schedulers to keep
both units busy. This is one of the core ideas behind algorithms such as
[FlashAttention 3](https://arxiv.org/abs/2407.08608)
or [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/).

For more information on how warp scheduling and instruction issue works, we recommend reading
[Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481).

## Array layouts and reference transforms

TODO

## MMA (TensorCore)

In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit.
NVIDIA continues to change the programming interface of the TensorCore significantly
between different hardware generations, which is why the lowest-level interfaces
differ in Pallas:MGPU as well.

### Hopper (`wgmma`)

TODO

### Blackwell (`tcgen05`)

TODO

## Using `core_map`

TODO

## Synchronization structures and primitives

### `commit_smem`

TODO

### `Barrier`

This is essentially a thin wrapper around an array of PTX `mbarrier` types and is
passed in as a reference. All functions involving barriers expect to only get a single
barrier argument, and so if the reference contains multiple, you have to extract one
of them explicitly using `barriers.at[index]`.

`Barrier`s are always allocated in SMEM and as such have relatively low overheads.
There are three primary use cases that require the use of `Barrier`s:

1. Awaiting asynchronous GMEM-to-SMEM copies

TODO

2. Cross-warpgroup synchronization

TODO

3. Awaiting `tcgen05` TensorCore instructions

TODO

### `ClusterBarrier`

TODO

### `Semaphore`

TODO

## Asynchronous copies

TODO

## Inline Mosaic GPU

TODO

## Compiler parameters

TODO

diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst
index b2e2fca6c82e..6c1a048298c1 100644
--- a/docs/pallas/index.rst
+++ b/docs/pallas/index.rst
@@ -26,11 +26,17 @@ See also the :class:`jax.experimental.pallas` module API documentation.

 .. toctree::
-   :caption: Platform Features
+   :caption: TPU backend guide
    :maxdepth: 2

    tpu/index

+.. toctree::
+   :caption: Mosaic GPU backend guide
+   :maxdepth: 2
+
+   gpu/index
+
 .. toctree::
    :caption: Design Notes
    :maxdepth: 2

From d6524dc4616409808d8b1b0b9cd477d09fb0d818 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 8 Apr 2025 07:10:59 -0700
Subject: [PATCH 0465/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/3764aee831189bd32a9c7dea56926b8f31ae86bf.

PiperOrigin-RevId: 745130406
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index d4df9ee38034..0b9751ead471 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
-XLA_COMMIT = "77635006f6a898f71f19db360e9b4485aa5106da" -XLA_SHA256 = "d2a63a3cd2f354cd07699f30e7b5c16c7513e686e498b8ad712fb577ab677121" +XLA_COMMIT = "3764aee831189bd32a9c7dea56926b8f31ae86bf" +XLA_SHA256 = "845ce079537b7c25ca236d9910e460803b4148564f5c9c5440b6dab479919e68" def repo(): tf_http_archive( From b926fac66e7f80d4869a4da35a3630c00e050c54 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 8 Apr 2025 07:39:09 -0700 Subject: [PATCH 0466/1769] [Mosaic GPU] Simplify load/store methods now that we have fewer layouts PiperOrigin-RevId: 745139008 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- .../mosaic/gpu/fragmented_array.py | 72 +++-------------- tests/mosaic/gpu_test.py | 78 +++++++++---------- 3 files changed, 50 insertions(+), 102 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b7aa01dbbfcf..9b44a1165cdf 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1226,7 +1226,7 @@ def _swap_lowering_rule( is_signed=mgpu_utils.is_signed(x_aval.dtype), optimized=False, ) - value.store_untiled(x_smem) + value.store_untiled(x_smem, optimized=False) return old_value case _: old_value = mgpu.FragmentedArray.load_strided( diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 9ab27927791a..df1e03627f94 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1788,25 +1788,23 @@ def _(val, idx): fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") utils.debug_print(fmt_str, *idx, val, uniform=False) - def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): + def store_untiled( + self, ref: ir.Value, *, swizzle: int = 16, optimized: bool = True + ): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) - - def vs_unsupported(): - if not vector_store: - raise NotImplementedError( - f"Can't use non-vector stores with layout {self.layout}" - ) - match self.layout: case WGSplatFragLayout(): - vs_unsupported() + # All values are the same so swizzle does not affect anything here. self._store_untiled_splat(ref) case WGStridedFragLayout(): - vs_unsupported() + if swizzle != 16: + raise NotImplementedError self._store_untiled_wg_strided(ref) case TiledLayout(): - self._store_untiled_tiled(ref, vector_store=vector_store) + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + self.store_tiled(ref, swizzle=swizzle, optimized=optimized) case _: raise NotImplementedError(self.layout) @@ -1861,61 +1859,15 @@ def _store_untiled_wg_strided(self, ref: ir.Value): for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): - """Stores an array with a tiled layout. 
Not optimized at the moment.""" - if utils.bitwidth(self.mlir_dtype) < 8: - raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") - i32 = ir.IntegerType.get_signless(32) - layout = self.layout - assert isinstance(layout, TiledLayout) - ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() - if vector_store and ref_strides[layout.vector_dim] != 1: - raise NotImplementedError( - "Can't use vector stores with non-unit minormost stride" - ) - strides = layout.tiling.tile_strides(ref_strides) - smem_space = ir.Attribute.parse("#gpu.address_space") - ref_space = ir.MemRefType(ref.type).memory_space - memory_space = None - if str(ref_space) == str(smem_space): - memory_space = 3 - elif ref_space: - raise NotImplementedError(f"Unexpected ref space {ref_space}") - ptr = utils.memref_ptr(ref, memory_space=memory_space) - # Fold warp and lane offsets into the pointer once, since they are dynamic. - dyn_strides = [ - arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] - ] - warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) - lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) - dyn_offset = arith.addi(warp_offset, lane_offset) - ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) - # All warp tile offsets are static and can be fused into the store. - for tile_idx, reg in np.ndenumerate(self.registers): - if vector_store: - elems = [reg] - else: - index = ir.IndexType.get() - elems = [ - vector.extractelement(reg, position=c(i, index)) - for i in range(ir.VectorType(reg.type).shape[0]) - ] - for i, e in enumerate(elems): - tile_idx_local = list(tile_idx) - tile_idx_local[layout.vector_dim] += i - tile_idx_local = list(tile_idx_local) - lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) - reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) - llvm.store(e, reg_ptr) - - def store_tiled(self, ref, swizzle: int | None): + def store_tiled(self, ref, swizzle: int | None, optimized: bool = True): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape # Note that the loop below will "race" for layouts that replicate data. # However, in that case all of the racing writes store the same data, which # is ok in the CUDA memory model. 
- for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): + stores = self.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for get, _, ptr in stores: llvm.store(get(self.registers), ptr) @classmethod diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f0930f5de8cc..b19dee9c065c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -489,19 +489,12 @@ def get_packed_shape(strides, shape): class WGMMALayoutTest(TestCase): - @parameterized.product(dtype=[jnp.float16, jnp.float32], - transposed_smem=[False, True]) - def test_store_untiled(self, dtype, transposed_smem): + @parameterized.product(dtype=[jnp.float16, jnp.float32]) + def test_store_untiled(self, dtype): def kernel(ctx, out, _): del ctx - if transposed_smem: - out = memref_transpose(out, (1, 0)) - iota_tensor(64, 64, dtype).store_untiled( - out, vector_store=not transposed_smem - ) + iota_tensor(64, 64, dtype).store_untiled(out, optimized=False) expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) - if transposed_smem: - expected = expected.T iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() @@ -749,7 +742,7 @@ def kernel(ctx, lhs, rhs, out, scratch): acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) def quantize(x): # Quantize the input to avoid rounding when feeding the WGMMA @@ -821,7 +814,7 @@ def kernel(ctx, rhs, out, rhs_smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) y_shape = (n, k) if rhs_transpose else (k, n) y = self.prng.uniform(-1, 1, y_shape).astype(dtype) @@ -881,7 +874,7 @@ def kernel(ctx, rhs, out, smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) jax_dtype = jnp.float16 y_shape = (n, k) if rhs_transpose else (k, n) @@ -1042,7 +1035,7 @@ def kernel(ctx, lhs, rhs, out, scratch): ) tcgen05.commit_arrive(barriers[2]) barriers[2].wait(for_tensor_core=True) - acc[:].store_untiled(out) + acc[:].store_untiled(out, optimized=False) x_shape = (k, m) if lhs_transpose else (m, k) x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) @@ -1145,7 +1138,7 @@ def kernel(ctx, lhs, rhs, out, scratch): tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) barriers[2].wait(for_tensor_core=True) m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) - acc[:].store_untiled(memref_slice(out, m_slice)) + acc[:].store_untiled(memref_slice(out, m_slice), optimized=False) in_finfo = jnp.finfo(in_jax_dtype) exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant @@ -1198,7 +1191,7 @@ def kernel(ctx, dst, scratch): final_arr = arr + mgpu.FragmentedArray.load_strided( tmp, is_signed=False ) - final_arr.store_untiled(memref_slice(dst, 0)) + final_arr.store_untiled(memref_slice(dst, 0), optimized=False) scf.yield_([]) with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): barriers[0].wait() @@ -1209,7 +1202,7 @@ def kernel(ctx, dst, scratch): barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. arr.store_untiled(tmp) barriers[1].arrive() # Signal that tmp is ready. 
- final_arr.store_untiled(memref_slice(dst, 1)) + final_arr.store_untiled(memref_slice(dst, 1), optimized=False) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) y = mgpu.as_gpu_kernel( @@ -1670,7 +1663,7 @@ def kernel(ctx, dst, _): mlir_dtype = utils.dtype_to_ir_type(dtype) iota = iota_tensor(m, n, dtype) rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype) - op(iota, rhs).store_untiled(dst) + op(iota, rhs).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1716,7 +1709,7 @@ def test_division(self, op, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst) + op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1746,14 +1739,14 @@ def kernel(ctx, dst, _): rhs = 0 if rhs_is_literal else iota + 1 res = op(iota, rhs) assert not res.is_signed - res.astype(i8, is_signed=False).store_untiled(dst) + res.astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) - rhs = rhs = 0 if rhs_is_literal else iota + 1 + rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8)) def test_foreach_wgmma_row_array(self): @@ -1784,9 +1777,8 @@ def _(v, idx): def test_foreach(self): dtype = jnp.int32 swizzle = 128 - tile = 64, swizzle // jnp.dtype(dtype).itemsize + tiling = (8, swizzle // jnp.dtype(dtype).itemsize) shape = 128, 192 - tiled_shape = mgpu.tile_shape(shape, tile) mlir_dtype = utils.dtype_to_ir_type(dtype) cst = 9999 def causal(val, idx): @@ -1794,12 +1786,16 @@ def causal(val, idx): mask = arith.cmpi(arith.CmpIPredicate.uge, row, col) return arith.select(mask, val, c(cst, mlir_dtype)) - tiling = mgpu.TileTransform(tile) def kernel(ctx, dst, smem): x = iota_tensor(shape[0], shape[1], dtype) - x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem) + x.foreach(causal, create_array=True, is_signed=False).store_tiled(smem, swizzle=128) mgpu.commit_shared() - ctx.async_copy(src_ref=smem, dst_ref=dst) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=128, + ) ctx.await_async_copy(0) iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape) @@ -1809,7 +1805,7 @@ def kernel(ctx, dst, smem): (128, 1, 1), (), jax.ShapeDtypeStruct(shape=shape, dtype=dtype), - jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + jax.ShapeDtypeStruct(shape=mgpu.tile_shape(shape, tiling), dtype=dtype), )() expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst np.testing.assert_array_equal(result, expected) @@ -1821,7 +1817,7 @@ def kernel(ctx, dst, smem): def test_bitwise(self, op, dtype, m=64, n=8): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + op(iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1845,7 +1841,7 @@ def test_unary(self, ops, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = 
mgpu.as_gpu_kernel( @@ -1858,7 +1854,7 @@ def test_select(self, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.int32) - (iota < 16).select(iota * 2, iota * 3).store_untiled(dst) + (iota < 16).select(iota * 2, iota * 3).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32) result = mgpu.as_gpu_kernel( @@ -1881,7 +1877,7 @@ def test_math(self, ops, approx, m=64, n=32): op, np_op = ops def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1902,7 +1898,7 @@ def kernel(ctx, src, dst, scratch): src, is_signed=utils.is_signed(dtype) ) acc = src.reduce_sum(scratch).broadcast((m,)) - acc.store_untiled(dst) + acc.store_untiled(dst, optimized=False) in_shape = jax.ShapeDtypeStruct((m, n), dtype) out_shape = jax.ShapeDtypeStruct((m,), dtype) @@ -1930,7 +1926,7 @@ def kernel(ctx, dst, _): is_signed=utils.is_signed(dtype), ) acc = src.reduce_sum().broadcast((m,)) - acc.store_untiled(dst) + acc.store_untiled(dst, optimized=False) kernel_fn = mgpu.as_gpu_kernel( kernel, @@ -1950,7 +1946,7 @@ def kernel(ctx, dst, _): def test_reduce(self, op, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) + iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1971,7 +1967,7 @@ def kernel(ctx, dst, _): cte = c(1, iota.mlir_dtype) cte_arr = mgpu.FragmentedArray.splat(cte, ()) cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) - (iota + cte_arr).store_untiled(dst) + (iota + cte_arr).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1986,7 +1982,7 @@ def kernel(ctx, dst, _): t = mgpu.FragmentedArray.splat( v, (128,), mgpu.WGMMA_ROW_LAYOUT ) - t.broadcast_minor(32).store_untiled(dst) + t.broadcast_minor(32).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -2005,7 +2001,7 @@ def kernel(ctx, src, dst, _): assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) - (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst) + (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) inp = jnp.ones_like(out_shape) * 3.14 @@ -2077,7 +2073,7 @@ def kernel(ctx, gmem_input, gmem_output, _): t = mgpu.FragmentedArray.load_untiled( gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False ) - t.broadcast_major(m).store_untiled(gmem_output) + t.broadcast_major(m).store_untiled(gmem_output, optimized=False) inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) @@ -2114,7 +2110,7 @@ def kernel(ctx, inp, out, smem): del ctx, smem arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length] - 
arr.astype(mlir_dtype_to).store_untiled(out) + arr.astype(mlir_dtype_to).store_untiled(out, optimized=False) x = jnp.arange(-128, 128, dtype=jax_dtype_from) x = jnp.tile(x, reg_length // 2) @@ -2190,7 +2186,7 @@ def test_convert_bool_to_u8(self): def kernel(ctx, dst, _): i8 = ir.IntegerType.get_signless(8) iota = iota_tensor(m, n, jnp.uint8) - (iota > 10).astype(i8, is_signed=False).store_untiled(dst) + (iota > 10).astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( @@ -2318,7 +2314,7 @@ def kernel(ctx, dst, _): ) self.assertEqual(tiled.shape, shape) self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) - tiled.store_untiled(dst) + tiled.store_untiled(dst, optimized=False) ty = jax.ShapeDtypeStruct(shape, dtype) f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ()) expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) From f5d73b89ca8dc2a2d862154dff3f56362d33fc82 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 07:58:52 -0700 Subject: [PATCH 0467/1769] [pallas:mosaic_gpu] Added test for custom pretty-printing rules PiperOrigin-RevId: 745145207 --- jax/_src/pallas/mosaic_gpu/primitives.py | 5 +- tests/pallas/mosaic_gpu_test.py | 74 ++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index c41a36da94e8..f996d620af8f 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -857,13 +857,14 @@ def _wgmma_ref_pp_eqn( acc, a, b, *leaves = eqn.invars a_transforms_treedef = eqn.params["a_transforms_tree"] b_transforms_treedef = eqn.params["b_transforms_tree"] + split = getattr(a_transforms_treedef, "num_leaves", 0) a_transforms = ( - a_transforms_treedef.unflatten(leaves[: a_transforms_treedef.num_leaves]) + a_transforms_treedef.unflatten(leaves[:split]) if a_transforms_treedef is not None else [] ) b_transforms = ( - b_transforms_treedef.unflatten(leaves[a_transforms_treedef.num_leaves :]) + b_transforms_treedef.unflatten(leaves[split:]) if b_transforms_treedef is not None else [] ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e32222775f94..cd4f2f8ab602 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2634,6 +2634,80 @@ class CoreMapWGTest( ... +class PrettyPrintingTest(PallasTest): + + def test_load(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,)) + o_ref[i, ...] = x + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((2, 128), jnp.float32))) + + def test_copy_primitives(self): + num_steps = 4 + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_gmem, o_gmem): + # ``plgpu.emit_pipeline`` is implemented in terms of async copy and + # synchronization primitives. 
+ plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + out_specs=[ + pl.BlockSpec( + (64, 64), + lambda i: (0, i), + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((64, 64), jnp.float32))) + + def test_wgmma(self): + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), + in_specs=[ + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + ], + ) + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref[...], b_ref) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32)) + + _ = str( + jax.make_jaxpr(kernel)( + jax.ShapeDtypeStruct((64, 128), jnp.float16), + jax.ShapeDtypeStruct((128, 192), jnp.float16), + ) + ) + + class ExamplesTest(PallasTest): # Basic From b8353d1b903b57e3a86e666847c126b6d4bb8f7e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 8 Apr 2025 08:15:39 -0700 Subject: [PATCH 0468/1769] [Mosaic TPU] Add support for non-32bit types in vector.extract At least for as long as the extracted value is not a scalar. PiperOrigin-RevId: 745151577 --- .../mosaic/dialect/tpu/transforms/apply_vector_layout.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index e68d5da466eb..25aebefa4506 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3740,10 +3740,6 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 1); TPU_ASSERT_OP(layouts_in.front().has_value()); const VectorLayout &layout_in = *layouts_in.front(); - if (layout_in.bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 32-bit vector.extract supported"); - } const VectorType res_vty = dyn_cast(extract_op.getResult().getType()); if (res_vty != nullptr) { @@ -3772,6 +3768,10 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, op.erase(); return success(); } else { + if (layout_in.bitwidth() != 32) { + return op.emitOpError( + "Not implemented: Only 32-bit vector.extract supported"); + } // TODO(b/367459476): Support non-zero offsets. if (layout_in.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError("Not implemented: Unsupported layout"); From e02faabfb2ed7eacd82b7c438a119fde9e362739 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 8 Apr 2025 08:32:59 -0700 Subject: [PATCH 0469/1769] Replace references to jax.readthedocs.io with docs.jax.dev. 
PiperOrigin-RevId: 745156931 --- CHANGELOG.md | 68 +++++++++---------- CONTRIBUTING.md | 2 +- README.md | 60 ++++++++-------- cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb | 2 +- cloud_tpu_colabs/JAX_demo.ipynb | 2 +- cloud_tpu_colabs/Pmap_Cookbook.ipynb | 4 +- cloud_tpu_colabs/README.md | 2 +- docs/README.md | 2 +- docs/about.md | 16 ++--- docs/advanced-autodiff.md | 4 +- docs/aot.md | 2 +- docs/api_compatibility.md | 2 +- docs/autodidax.ipynb | 4 +- docs/autodidax.md | 4 +- docs/autodidax.py | 4 +- docs/building_on_jax.md | 4 +- docs/contributing.md | 2 +- docs/control-flow.md | 10 +-- docs/developer.md | 6 +- docs/export/export.md | 2 +- docs/export/shape_poly.md | 4 +- docs/faq.rst | 18 ++--- docs/ffi.ipynb | 4 +- docs/ffi.md | 4 +- docs/gpu_memory_allocation.rst | 2 +- docs/installation.md | 2 +- docs/jax-primitives.md | 2 +- docs/jax_array_migration.md | 2 +- docs/jep/10657-sequencing-effects.md | 2 +- docs/jep/12049-type-annotations.md | 2 +- docs/jep/14273-shard-map.md | 4 +- docs/jep/15856-jex.md | 14 ++-- docs/jep/17111-shmap-transpose.md | 2 +- docs/jep/2026-custom-derivatives.md | 2 +- docs/jep/4008-custom-vjp-update.md | 2 +- docs/jep/4410-omnistaging.md | 2 +- docs/jep/9407-type-promotion.ipynb | 8 +-- docs/jep/9407-type-promotion.md | 8 +-- docs/jit-compilation.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 12 ++-- docs/notebooks/Common_Gotchas_in_JAX.md | 8 +-- ...tom_derivative_rules_for_Python_code.ipynb | 6 +- ...Custom_derivative_rules_for_Python_code.md | 6 +- ...arrays_and_automatic_parallelization.ipynb | 6 +- ...ed_arrays_and_automatic_parallelization.md | 6 +- docs/notebooks/README.md | 2 +- .../Writing_custom_interpreters_in_Jax.ipynb | 2 +- .../Writing_custom_interpreters_in_Jax.md | 2 +- docs/notebooks/autodiff_remat.ipynb | 2 +- docs/notebooks/autodiff_remat.md | 2 +- .../neural_network_with_tfds_data.ipynb | 2 +- .../neural_network_with_tfds_data.md | 2 +- docs/notebooks/shard_map.ipynb | 10 +-- docs/notebooks/shard_map.md | 10 +-- docs/notebooks/thinking_in_jax.ipynb | 12 ++-- docs/notebooks/thinking_in_jax.md | 8 +-- docs/pallas/CHANGELOG.md | 2 +- docs/quickstart.md | 2 +- docs/stateful-computations.md | 2 +- docs/type_promotion.rst | 2 +- docs/xla_flags.md | 2 +- examples/ffi/README.md | 2 +- examples/ffi/src/jax_ffi_example/rms_norm.py | 2 +- jax/BUILD | 2 +- jax/_src/ad_checkpoint.py | 4 +- jax/_src/api.py | 8 +-- jax/_src/basearray.py | 2 +- jax/_src/callback.py | 4 +- jax/_src/compilation_cache.py | 2 +- jax/_src/config.py | 8 +-- jax/_src/custom_derivatives.py | 4 +- jax/_src/debugging.py | 2 +- jax/_src/effects.py | 2 +- jax/_src/errors.py | 8 +-- jax/_src/export/_export.py | 20 +++--- jax/_src/export/shape_poly.py | 28 ++++---- jax/_src/flatten_util.py | 2 +- jax/_src/interpreters/mlir.py | 4 +- jax/_src/lax/lax.py | 4 +- jax/_src/mesh.py | 2 +- jax/_src/named_sharding.py | 2 +- jax/_src/numpy/array_methods.py | 2 +- jax/_src/numpy/lax_numpy.py | 12 ++-- jax/_src/numpy/util.py | 4 +- jax/_src/numpy/vectorize.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/pallas_call.py | 4 +- jax/_src/pjit.py | 6 +- jax/_src/random.py | 2 +- jax/_src/xla_bridge.py | 6 +- jax/core.py | 16 ++--- jax/experimental/host_callback.py | 2 +- jax/experimental/jax2tf/README.md | 10 +-- .../jax2tf/g3doc/no_xla_limitations.md | 8 +-- jax/experimental/jax2tf/jax2tf.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 8 +-- jax/experimental/pallas/__init__.py | 2 +- jax/experimental/pallas/ops/tpu/matmul.py | 2 
+- jax/experimental/shard_map.py | 2 +- jax/extend/__init__.py | 10 +-- jax/lib/xla_client.py | 2 +- jax/random.py | 4 +- jax/stages.py | 2 +- jax/typing.py | 4 +- jaxlib/xla/pytree.h | 2 +- tests/api_test.py | 2 +- tests/debug_info_test.py | 2 +- tests/errors_test.py | 2 +- tests/export_test.py | 8 +-- tests/lax_test.py | 4 +- tests/random_test.py | 2 +- 112 files changed, 323 insertions(+), 323 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e744cad902de..86d1c82f6401 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change log -Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). +Best viewed [here](https://docs.jax.dev/en/latest/changelog.html). For the changes specific to the experimental Pallas APIs, see {ref}`pallas-changelog`. @@ -126,7 +126,7 @@ Patch release of 0.5.1 ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses -[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html). +[effort-based versioning](https://docs.jax.dev/en/latest/jep/25516-effver.html). Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this. @@ -217,7 +217,7 @@ to signify this. * New Features * {func}`jax.export.export` can be used for device-polymorphic export with shardings constructed with {func}`jax.sharding.AbstractMesh`. - See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export). + See the [jax.export documentation](https://docs.jax.dev/en/latest/export/export.html#device-polymorphic-export). * Added {func}`jax.lax.split`. This is a primitive version of {func}`jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation. @@ -259,7 +259,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. * The deprecated module `jax.experimental.export` has been removed. It was replaced - by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. @@ -297,7 +297,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. call that we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the `disabled_checks` - parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). + parameter. See more details in the [documentation](https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for @@ -577,7 +577,7 @@ See the 0.4.33 release notes for more details. * Added an API for exporting and serializing JAX functions. 
This used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. - See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + See the [documentation](https://docs.jax.dev/en/latest/export/index.html). * Deprecations * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed @@ -586,7 +586,7 @@ See the 0.4.33 release notes for more details. release. This previously was the case, but there was an inadvertent regression in the last several JAX releases. * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. - See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export). * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. * `jax.xla_computation` is deprecated and will be removed in a future release. @@ -798,7 +798,7 @@ See the 0.4.33 release notes for more details. deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. * The `jax.experimental.host_callback` module is deprecated. - Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Use instead the [new JAX external callbacks](https://docs.jax.dev/en/latest/notebooks/external_callbacks.html). Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` @@ -1270,9 +1270,9 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * JAX now requires NumPy 1.22 or newer as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` @@ -1317,7 +1317,7 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html ## jax 0.4.13 (June 22, 2023) @@ -1496,7 +1496,7 @@ See the 0.4.33 release notes for more details. ## jax 0.4.7 (March 27, 2023) * Changes - * As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration + * As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore. * `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore. * {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` @@ -1580,7 +1580,7 @@ Changes: on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. 
For more information see - [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). + [this section in autodidax](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. You can disable it only via the environment variable i.e. @@ -1665,9 +1665,9 @@ Changes: simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking change to the `pjit` API. The [jax.Array migration - guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can + guide](https://docs.jax.dev/en/latest/jax_array_migration.html) can help you migrate your codebase to `jax.Array`. You can also look at the - [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial to understand the new concepts. * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. @@ -1696,7 +1696,7 @@ Changes: * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to - [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for + [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html) for more details. * The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead. @@ -1810,7 +1810,7 @@ Changes: * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the - overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs + overview](https://docs.jax.dev/en/latest/aot.html) and the API docs for {mod}`jax.stages`. * Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Notice that this included some subtle @@ -1831,7 +1831,7 @@ Changes: * Breaking changes * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see - [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). * Changes * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`). * Deprecations: @@ -1843,7 +1843,7 @@ Changes: * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to NumPy 1.20 or newer. 
* Changes * Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`. @@ -1861,7 +1861,7 @@ Changes: {mod}`jax.example_libraries.optimizers`. * {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is - deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + deprecated; see [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). @@ -1993,7 +1993,7 @@ Changes: * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input. * {func}`jax.scipy.cluster.vq.vq` has been added. * `jax.experimental.maps.mesh` has been deleted. - Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh + Please use `jax.experimental.maps.Mesh`. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. * {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`) @@ -2109,7 +2109,7 @@ Changes: * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. * Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`. @@ -2155,13 +2155,13 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes - * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jax version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jaxlib 0.3.0 (Feb 10, 2022) * Changes * Bazel 5.0.0 is now required to build jaxlib. - * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jax 0.2.28 (Feb 1, 2022) @@ -2183,7 +2183,7 @@ Changes: by default. * Breaking changes * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * Bug fixes * Fixed a bug where apparently identical pytreedef objects constructed by different routes @@ -2195,7 +2195,7 @@ Changes: * Breaking changes: * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. 
* The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. @@ -2322,7 +2322,7 @@ Changes: * Deprecations * The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`. * New features: @@ -2386,7 +2386,7 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common @@ -2407,10 +2407,10 @@ Changes: ## jaxlib 0.1.70 (Aug 9, 2021) * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback mechanism now uses one thread per local device for @@ -2424,7 +2424,7 @@ Changes: * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * The minimum jaxlib version is now 0.1.69. * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been @@ -2473,7 +2473,7 @@ Changes: * Breaking changes: * Support for NumPy 1.16 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). * Bug fixes: * Fixed bug that prevented round-tripping from JAX to TF and back: @@ -3013,7 +3013,7 @@ Changes: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. * Experimental support for printing and calling host-side Python function from - compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) + compiled code. See [id_print and id_tap](https://docs.jax.dev/en/latest/jax.experimental.host_callback.html) ({jax-issue}`#3006`). * Notable changes: * The visibility of names exported from {mod}`jax.numpy` has been @@ -3085,7 +3085,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). -* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). 
+* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. * Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 314d4387a044..046d3df3195c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ # Contributing to JAX For information on how to contribute to JAX, see -[Contributing to JAX](https://jax.readthedocs.io/en/latest/contributing.html) +[Contributing to JAX](https://docs.jax.dev/en/latest/contributing.html) diff --git a/README.md b/README.md index 0aca7cf58e6e..00391f314044 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,8 @@ | [**Transformations**](#transformations) | [**Install guide**](#installation) | [**Neural net libraries**](#neural-network-libraries) -| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) -| [**Reference docs**](https://jax.readthedocs.io/en/latest/) +| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html) +| [**Reference docs**](https://docs.jax.dev/en/latest/) ## What is JAX? @@ -48,7 +48,7 @@ are instances of such transformations. Others are parallel programming of multiple accelerators, with more to come. This is a research project, not an official Google product. Expect -[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues), and letting us know what you think! @@ -83,15 +83,15 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra ## Quickstart: Colab in the Cloud Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://docs.jax.dev/en/latest/quickstart.html) - [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). For a deeper dive into JAX: -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) -- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) - See the [full list of notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). 
@@ -105,7 +105,7 @@ Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). The most popular function is -[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) +[`grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) for reverse-mode gradients: ```python @@ -129,13 +129,13 @@ print(grad(grad(grad(tanh)))(1.0)) ``` For more advanced autodiff, you can use -[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for +[`jax.vjp`](https://docs.jax.dev/en/latest/jax.html#jax.vjp) for reverse-mode vector-Jacobian products and -[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for +[`jax.jvp`](https://docs.jax.dev/en/latest/jax.html#jax.jvp) for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes [full Hessian -matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): +matrices](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html#jax.hessian): ```python from jax import jit, jacfwd, jacrev @@ -160,15 +160,15 @@ print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) ``` See the [reference docs on automatic -differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) +differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) and the [JAX Autodiff -Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) for more. ### Compilation with `jit` You can use XLA to compile your functions end-to-end with -[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), used either as an `@jit` decorator or as a higher-order function. ```python @@ -189,12 +189,12 @@ You can mix `jit` and `grad` and any other JAX transformation however you like. Using `jit` puts constraints on the kind of Python control flow the function can use; see -the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) +the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` -[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is +[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a @@ -259,7 +259,7 @@ differentiation for fast Jacobian and Hessian matrix calculations in ### SPMD programming with `pmap` For parallel programming of multiple accelerators, like multiple GPUs, use -[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). +[`pmap`](https://docs.jax.dev/en/latest/jax.html#parallelization-pmap). With `pmap` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. 
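A minimal sketch of the SPMD pattern described above; the array shapes are illustrative, and the shard count simply follows whatever devices are attached (on a single-device CPU runtime it degenerates to one shard, so the snippet stays runnable anywhere):

```python
import jax
import jax.numpy as jnp

# Shard the leading axis across the attached devices.
n = jax.device_count()
xs = jnp.arange(n * 4.0).reshape(n, 4)
print(jax.pmap(jnp.mean)(xs))  # one per-shard mean, computed device-parallel
```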
Applying `pmap` will mean that the function you write is compiled by XLA (similarly to `jit`), then @@ -284,7 +284,7 @@ print(pmap(jnp.mean)(result)) ``` In addition to expressing pure maps, you can use fast [collective communication -operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) +operations](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) between devices: ```python @@ -341,20 +341,20 @@ for more. For a more thorough survey of current gotchas, with examples and explanations, we highly recommend reading the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Some standouts: 1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. 1. [In-place mutating updates of - arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. + arrays](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://docs.jax.dev/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://docs.jax.dev/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution - operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), + operators](https://docs.jax.dev/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. 1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and [to enable - double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + double-precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at startup (or set the environment variable `JAX_ENABLE_X64=True`). On TPU, JAX uses 32-bit values by default for everything _except_ internal @@ -368,14 +368,14 @@ Some standouts: and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. 1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/control-flow.html). + flow](https://docs.jax.dev/en/latest/control-flow.html). You'll always get loud errors if something goes wrong. 
You might have to use [`jit`'s `static_argnums` - parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), + parameter](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), [structured control flow - primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) + primitives](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators) like - [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), + [`lax.scan`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), or just use `jit` on smaller subfunctions. ## Installation @@ -403,7 +403,7 @@ Some standouts: | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | | Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | -See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) +See [the documentation](https://docs.jax.dev/en/latest/installation.html) for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions. @@ -417,7 +417,7 @@ for training neural networks in JAX. If you want a fully featured library for ne training with examples and how-to guides, try [Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). -Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) +Check out the [JAX Ecosystem section](https://docs.jax.dev/en/latest/#ecosystem) on the JAX documentation site for a list of JAX-based network libraries, which includes [Optax](https://github.com/deepmind/optax) for gradient processing and optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and @@ -452,7 +452,7 @@ paper. ## Reference documentation For details about the JAX API, see the -[reference documentation](https://jax.readthedocs.io/). +[reference documentation](https://docs.jax.dev/). For getting started as a JAX developer, see the -[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). +[developer documentation](https://docs.jax.dev/en/latest/developer.html). diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index edaa71b93e85..5bc045d0f606 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -225,7 +225,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index d7ba5ed334f4..b69246c57e0b 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -315,7 +315,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." 
+ "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index ea126ac4f1e7..8b16cd7694eb 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -59,7 +59,7 @@ "id": "2e_06-OAJNyi" }, "source": [ - "A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):" + "A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):" ] }, { @@ -407,7 +407,7 @@ "source": [ "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n", "\n", - "Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", + "Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", "\n", "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:" ] diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index db3dc5f30814..6e5501584da0 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -4,7 +4,7 @@ The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in [Colab](https://research.google.com/colaboratory/). All of the example notebooks here use -[`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) to run JAX +[`jax.pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap) to run JAX computation across multiple TPU cores from Colab. You can also run the same code directly on a [Cloud TPU VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). diff --git a/docs/README.md b/docs/README.md index 12e00425592f..54b8a67477b0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,2 +1,2 @@ To rebuild the documentation, -see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +see [Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/about.md b/docs/about.md index 58e1703842b9..baeed941c8c3 100644 --- a/docs/about.md +++ b/docs/about.md @@ -19,7 +19,7 @@ technology stack](#components). First, we design the `jax` module to be [composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) and -[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +[extensible](https://docs.jax.dev/en/latest/jax.extend.html), so that a wide variety of domain-specific libraries can thrive outside of it in a decentralized manner. Second, we lean heavily on a modular backend stack (compiler and runtime) to target different @@ -42,10 +42,10 @@ scale. JAX's day-to-day development takes place in the open on GitHub, using pull requests, the issue tracker, discussions, and [JAX Enhancement Proposals -(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +(JEPs)](https://docs.jax.dev/en/latest/jep/index.html). Reading and participating in these is a good way to get involved. 
We also maintain [developer -notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +notes](https://docs.jax.dev/en/latest/contributor_guide.html) that cover JAX's internal design. The JAX core team determines whether to accept changes and @@ -56,7 +56,7 @@ intricate decision structure over time (e.g. with designated area owners) if/when it becomes useful to do so. For more see [contributing to -JAX](https://jax.readthedocs.io/en/latest/contributing.html). +JAX](https://docs.jax.dev/en/latest/contributing.html). (components)= ## A modular stack @@ -71,7 +71,7 @@ and (b) an advancing hardware landscape, we lean heavily on While the JAX core library focuses on the fundamentals, we want to encourage domain-specific libraries and tools to be built on top of JAX. Indeed, [many -libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +libraries](https://docs.jax.dev/en/latest/#ecosystem) have emerged around JAX to offer higher-level features and extensions. How do we encourage such decentralized development? We guide it with @@ -80,11 +80,11 @@ building blocks (e.g. numerical primitives, NumPy operations, arrays, and transformations), encouraging auxiliary libraries to develop utilities as needed for their domain. In addition, JAX exposes a handful of more advanced APIs for -[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +[customization](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) and -[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +[extensibility](https://docs.jax.dev/en/latest/jax.extend.html). Libraries can [lean on these -APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +APIs](https://docs.jax.dev/en/latest/building_on_jax.html) in order to use JAX as an internal means of implementation, to integrate more with its transformations like autodiff, and more. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index eaa3bc7317c8..bef2fd088a3a 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -876,7 +876,7 @@ There are two ways to define differentiation rules in JAX: 1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). ### TL;DR: Custom JVPs with {func}`jax.custom_jvp` @@ -1608,7 +1608,7 @@ Array(-0.91113025, dtype=float32) #### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. 
+You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with {func}`jax.custom_jvp`: diff --git a/docs/aot.md b/docs/aot.md index 1fcf11ab945d..8f68c2758148 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -26,7 +26,7 @@ are arrays, JAX does the following in order: carries out this specialization by a process that we call _tracing_. During tracing, JAX stages the specialization of `F` to a jaxpr, which is a function in the [Jaxpr intermediate - language](https://jax.readthedocs.io/en/latest/jaxpr.html). + language](https://docs.jax.dev/en/latest/jaxpr.html). 2. **Lower** this specialized, staged-out computation to the XLA compiler's input language, StableHLO. diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 749c5907bc6b..9dca1fc08f50 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -91,7 +91,7 @@ guarantees of the main JAX package. If you have code that uses `jax.extend`, we would strongly recommend CI tests against JAX's nightly releases, so as to catch potential changes before they are released. -For details on `jax.extend`, see the [`jax.extend` module docuementation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. +For details on `jax.extend`, see the [`jax.extend` module documentation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. ## Numerics and randomness diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 7ec91affa05d..b6f12b624f8b 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -72,7 +72,7 @@ "outputs, we want to override primitive application and let different values\n", "flow through our program. For example, we might want to replace the\n", "application of every primitive with an application of [its JVP\n", - "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n", + "rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html),\n", "and let primal-tangent pairs flow through our program. Moreover, we want to be\n", "able to compose multiple transformations, leading to stacks of interpreters." ] @@ -3620,7 +3620,7 @@ "source": [ "Notice that we're not currently supporting the case where the predicate value\n", "itself is batched. In mainline JAX, we handle this case by transforming the\n", - "conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n", + "conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html).\n", "That transformation is semantically correct so long as `true_fun` and\n", "`false_fun` do not involve any side-effecting primitives.\n", "\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 2d4d6cd528af..1c375e21227c 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -72,7 +72,7 @@ where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program.
For example, we might want to replace the application of every primitive with an application of [its JVP -rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters. @@ -2843,7 +2843,7 @@ print(out) Notice that we're not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the -conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). That transformation is semantically correct so long as `true_fun` and `false_fun` do not involve any side-effecting primitives. diff --git a/docs/autodidax.py b/docs/autodidax.py index f8c6372fe30d..6329234224cb 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -62,7 +62,7 @@ # outputs, we want to override primitive application and let different values # flow through our program. For example, we might want to replace the # application of every primitive with an application of [its JVP -# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +# rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), # and let primal-tangent pairs flow through our program. Moreover, we want to be # able to compose multiple transformations, leading to stacks of interpreters. @@ -2837,7 +2837,7 @@ def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr): # Notice that we're not currently supporting the case where the predicate value # itself is batched. In mainline JAX, we handle this case by transforming the -# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +# conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). # That transformation is semantically correct so long as `true_fun` and # `false_fun` do not involve any side-effecting primitives. # diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index 9416b16cde10..6d13f517f50b 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -45,8 +45,8 @@ Here are more specific examples of each pattern. ### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, -for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) -or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). +for example in [JAX Tutorials](https://docs.jax.dev/en/latest/tutorials.html) +or [Neural Network with JAX](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. diff --git a/docs/contributing.md b/docs/contributing.md index 99d78453c436..53a863fdcd8c 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,7 +6,7 @@ Everyone can contribute to JAX, and we value everyone's contributions. 
There are several ways to contribute, including: - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) -- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) +- Improving or expanding JAX's [documentation](http://docs.jax.dev/) - Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) diff --git a/docs/control-flow.md b/docs/control-flow.md index 7cb959f3e434..8f59bd92add7 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -244,19 +244,19 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand) `jax.lax` provides two other functions that allow branching on dynamic predicates: -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is +- [`lax.select`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html) is like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is +- [`lax.switch`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.switch.html) is like `lax.cond`, but allows switching between any number of callable choices. In addition, `jax.numpy` provides several numpy-style interfaces to these functions: -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with +- [`jnp.where`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html) with three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) +- [`jnp.piecewise`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.piecewise.html) is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has +- [`jnp.select`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.select.html) has an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to `lax.select`. diff --git a/docs/developer.md b/docs/developer.md index b1a978ffd0d6..9edeaeac83f8 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -789,7 +789,7 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked #### Notebooks within the Sphinx build Some of the notebooks are built automatically as part of the pre-submit checks and -as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. +as part of the [Read the docs](https://docs.jax.dev/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else @@ -800,7 +800,7 @@ See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs ### Documentation building on `readthedocs.io` -JAX's auto-generated documentation is at <https://jax.readthedocs.io/en/latest/>. +JAX's auto-generated documentation is at <https://docs.jax.dev/en/latest/>. The documentation building is controlled for the entire project by the [readthedocs JAX settings](https://readthedocs.org/dashboard/jax).
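Returning to the branching helpers listed in the `control-flow.md` hunk above, here is a small sketch of how they relate to one another; the absolute-value computation is an illustrative choice, not an example from the docs:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(-3.0, 4.0)

# lax.select: batched branching on a dynamic predicate, with the choices
# expressed as precomputed arrays.
a = lax.select(x > 0, x, -x)                 # elementwise absolute value
# jnp.where with three arguments wraps the same operation numpy-style.
b = jnp.where(x > 0, x, -x)
# jnp.piecewise switches on a list of boolean conditions with callable choices.
c = jnp.piecewise(x, [x <= 0, x > 0], [lambda v: -v, lambda v: v])

assert bool(jnp.all(a == b)) and bool(jnp.all(a == c))
```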
The current settings @@ -813,7 +813,7 @@ For each automated documentation build you can see the If you want to test the documentation generation on Readthedocs, you can push code to the `test-docs` branch. That branch is also built automatically, and you can -see the generated documentation [here](https://jax.readthedocs.io/en/test-docs/). If the documentation build +see the generated documentation [here](https://docs.jax.dev/en/test-docs/). If the documentation build fails you may want to [wipe the build environment for test-docs](https://docs.readthedocs.io/en/stable/guides/wipe-environment.html). For a local test, I was able to do it in a fresh directory by replaying the commands diff --git a/docs/export/export.md b/docs/export/export.md index 18cdcc6c51d0..63c0db14f905 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -161,7 +161,7 @@ e.g., the inference system.) What **matters is when the exporting and consuming components were built**, not the time when the exporting and the compilation happen. For external JAX users, it is -[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +[possible to run JAX and jaxlib at different versions](https://docs.jax.dev/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); what matters is when the jaxlib release was built. To reduce chances of incompatibility, internal JAX users should: diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 9254030a4e1c..6b63a536ab48 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -86,7 +86,7 @@ matching the structure of the arguments passed to it. The polymorphic shapes specification can be a pytree prefix in cases where one specification should apply to multiple arguments, as in the above example. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A few examples of shape specifications: @@ -609,7 +609,7 @@ Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details. +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details. ``` diff --git a/docs/faq.rst b/docs/faq.rst index 44267f6f5f7d..f5d43d25afb6 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,7 +4,7 @@ Frequently asked questions (FAQ) .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference -.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html +.. _JAX - The Sharp Bits: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html We are collecting answers to frequently asked questions here. Contributions welcome! @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. 
For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). @@ -454,8 +454,8 @@ performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs). -.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit -.. _Double (64 bit) precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision +.. _To JIT or not to JIT: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit +.. _Double (64 bit) precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision .. _`%time and %timeit magics`: https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time .. _Colab: https://colab.research.google.com/ @@ -841,12 +841,12 @@ reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`, or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please see the page on `JAX GPU memory allocation`_. -.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables -.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html -.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp -.. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback +.. _JIT mechanics: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables +.. _External callbacks in JAX: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html +.. _Pure callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp +.. _IO callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 -.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html +.. _JAX GPU memory allocation: https://docs.jax.dev/en/latest/gpu_memory_allocation.html .. 
_dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index b622fba9d5bc..f74ae9d58a78 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -439,7 +439,7 @@ "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", "Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", "\n", - "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", "In this case, we actually define two new FFI calls:\n", "\n", "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", @@ -785,7 +785,7 @@ "{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n", "We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n", "\n", - "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", + "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", "2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n", "\n", "All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:" diff --git a/docs/ffi.md b/docs/ffi.md index 4aa03c217855..97648c78e118 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -353,7 +353,7 @@ Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default supp As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule. 
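The fwd/bwd split described above can be sketched with a pure-JAX function standing in for the foreign call; the function names mirror the tutorial's, but the bodies below are an illustrative stand-in, not the tutorial's actual FFI implementation:

```python
import jax
import jax.numpy as jnp

# A plain-JAX stand-in for a foreign rms_norm kernel; eps is hardcoded here
# for simplicity.
@jax.custom_vjp
def rms_norm(x):
    return x * jax.lax.rsqrt(jnp.mean(x * x) + 1e-5)

def rms_norm_fwd(x):
    scale = jax.lax.rsqrt(jnp.mean(x * x) + 1e-5)
    return x * scale, (x, scale)          # (a) primal result, (b) residuals

def rms_norm_bwd(res, ct):
    x, scale = res
    # VJP of x * rsqrt(mean(x**2) + eps): Jacobian is scale*I - scale**3 x x^T / n.
    g = scale * ct - (scale**3 / x.size) * x * jnp.sum(x * ct)
    return (g,)

rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
print(jax.grad(lambda x: rms_norm(x).sum())(jnp.arange(1.0, 4.0)))
```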
-More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. @@ -591,7 +591,7 @@ If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative a {func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges. We won't go into too much detail on the caveats here, but the main issues that you should be aware of are: -1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. +1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. 2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there. All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`: diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 6667589e7b72..be40dfc8004c 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -69,7 +69,7 @@ Common causes of OOM failures disabling the automatic remat pass produces different trade-offs between compute and memory. Note however, that the algorithm is basic and you can often get better trade-off between compute and memory by disabling the automatic remat pass and doing - it manually with `the jax.remat API `_ + it manually with `the jax.remat API `_ Experimental features diff --git a/docs/installation.md b/docs/installation.md index ee675dd1e586..34274d7596aa 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -229,7 +229,7 @@ refer to JAX has experimental ROCm support. There are two ways to install JAX: * Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or -* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +* Build from source. 
Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://docs.jax.dev/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). (install-intel-gpu)= ## Intel GPU diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index 38a45ef4823e..819d0418e894 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -300,7 +300,7 @@ def multiply_add_lowering(ctx, xc, yc, zc): return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] # Now, register the lowering rule with JAX. -# For GPU, refer to the https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html +# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index a557f4ae7efc..3cc1629b2068 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -27,7 +27,7 @@ the unified jax.Array After the migration is complete `jax.Array` will be the only type of array in JAX. -This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. +This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. ### How to enable jax.Array? diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index 5f7eb0da4c04..ac3024519101 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -47,7 +47,7 @@ g() In many cases, JAX will execute `f` and `g` *in parallel*, dispatching the computations onto different threads -- `g` might actually be executed before `f`. Parallel execution is a nice performance optimization, especially if copying -to and from a device is expensive (see the [asynchronous dispatch note](https://jax.readthedocs.io/en/latest/async_dispatch.html) for more details). +to and from a device is expensive (see the [asynchronous dispatch note](https://docs.jax.dev/en/latest/async_dispatch.html) for more details). In practice, however, we often don't need to think about asynchronous dispatch because we're writing pure functions and only care about the inputs and outputs of functions -- we'll naturally block on future diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 5ed760dd6c5c..bf6123b2bc7f 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -35,7 +35,7 @@ def slice(operand: Array, start_indices: Sequence[int], For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer. 
-For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). +For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). A benefit of this level of type annotation is that it is never wrong to annotate a value with `Any`, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker. diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 63742bc852c6..fa6681551d17 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -4,7 +4,7 @@ *January 2023* **This was the design doc proposing `shard_map`. You may instead want -[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** +[the up-to-date user docs](https://docs.jax.dev/en/latest/notebooks/shard_map.html).** ## Motivation @@ -18,7 +18,7 @@ We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other. With `pjit` (now just `jit`) we have [a next-gen -API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +API](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) for the first school. But we haven't quite leveled-up the second school. `pmap` follows the second school, but over time we found it has [fatal flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws, diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index a5625abf8930..a821405c399e 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -14,13 +14,13 @@ import jax.extend as jex Several projects depend on JAX's codebase internals, often to use its core machinery (e.g. to write a -[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) +[transformation over its IR](https://docs.jax.dev/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) or to extend it (e.g. to [define new primitives](https://github.com/dfm/extending-jax)). Two challenges for these dependencies are (a) that our internals aren't all solidly designed for external use, and (b) that circumventing JAX's public API is -[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html). +[unsupported](https://docs.jax.dev/en/latest/api_compatibility.html). In other words, our internals are often used like a library, but are neither structured nor updated like one. @@ -50,12 +50,12 @@ removed altogether. 
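As a rough illustration of the primitive-definition flow this JEP targets, here is a minimal sketch assuming the `jax.extend.core` layout the JEP proposes; the `mul_add` primitive and its rule registrations are made up for illustration, not a stable recipe:

```python
import jax.numpy as jnp
from jax.extend.core import Primitive

# Hypothetical primitive: x * y + z.
mul_add_p = Primitive("mul_add")
mul_add_p.def_impl(lambda x, y, z: x * y + z)   # concrete (eager) evaluation rule
mul_add_p.def_abstract_eval(lambda x, y, z: x)  # output aval matches x's aval

def mul_add(x, y, z):
    return mul_add_p.bind(x, y, z)

print(mul_add(jnp.float32(2), jnp.float32(3), jnp.float32(4)))  # 10.0
```

Transformation rules (lowering for `jax.jit`, JVP, batching, and so on) would be registered separately, which is exactly the surface area the JEP proposes to expose.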
To keep development overhead low, `jax.extend` would not follow the public -[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html) +[API compatibility](https://docs.jax.dev/en/latest/api_compatibility.html) policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the -[changelog](https://jax.readthedocs.io/en/latest/changelog.html) +[changelog](https://docs.jax.dev/en/latest/changelog.html) to call out such changes. Callers of `jax.extend` that need to upgrade their code regularly @@ -108,7 +108,7 @@ to process the Jaxpr IR (the output of At initialization, this module will contain many more symbols than what's needed to define primitives and rules, including various names used in setting up -["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), +["final-style transformations"](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), such as the current `jax._src.core.Trace` and `Tracer` classes. We can revisit whether `jex.core` should also support final-style extensions alongside initial style approaches, and whether it can do so by a more @@ -137,7 +137,7 @@ tracer types from `jex`. This module plus `jex.core` ought to suffice for replicating today's custom primitive tutorials (e.g. -[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) +[ours](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html) and [dfm's](https://github.com/dfm/extending-jax)). For instance, defining a primitive and its behavior under `jax.jit` @@ -184,6 +184,6 @@ arrays. We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than [those provided by -JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could +JAX](https://docs.jax.dev/en/latest/jax.sharding.html). We could provide this as `jex.sharding.XlaOpShardingProto`, corresponding to today's `jax._src.lib.xla_client.OpSharding` internally. diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md index 2fdf5f822835..00d8a3f383fd 100644 --- a/docs/jep/17111-shmap-transpose.md +++ b/docs/jep/17111-shmap-transpose.md @@ -497,7 +497,7 @@ of every function instance along which the outputs are mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used. See [the `shmap` -JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples +JEP](https://docs.jax.dev/en/latest/jep/14273-shard-map.html) for examples of unmapped inputs and outputs. For comparison, in `vmap` unmapped inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather than an `int`). diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index ce149fa6fb35..b09926425667 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -2,7 +2,7 @@ This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented -documentation, see [the tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). 
+documentation, see [the tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 1e2270e052a6..c3f2be151ef7 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -4,7 +4,7 @@ _Oct 14 2020_ This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom derivative rules for JAX-transformable Python -functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +functions](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) notebook. ## What to update diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index f95c15f404b6..5b4536864ac2 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -266,7 +266,7 @@ While tracing the function ex1 at ex1.py:4, this value became a tracer due to JA You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions. -See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. +See https://docs.jax.dev/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. Encountered tracer value: Tracedwith ``` diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index a1ede3177a3a..5f12877c97a9 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -12,7 +12,7 @@ "\n", "*Jake VanderPlas, December 2021*\n", "\n", - "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html)." + "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html)." ] }, { @@ -1335,7 +1335,7 @@ "However, these advantages comes with a few tradeoffs:\n", "\n", "- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n", - "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", + "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. 
In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", "\n", "Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`." ] @@ -1413,7 +1413,7 @@ "id": "o0-E2KWjYEXO" }, "source": [ - "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", + "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", "\n", "For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX." ] @@ -2883,7 +2883,7 @@ "source": [ "### JAX Type Promotion: `jax.numpy`\n", "\n", - "`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." + "`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." ] }, { diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index ff67a8c21399..c047d76c1b18 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -20,7 +20,7 @@ kernelspec: *Jake VanderPlas, December 2021* -One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). +One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). +++ {"id": "Rod6OOyUVbQ8"} @@ -680,7 +680,7 @@ This is important because `f16` and `bf16` are not comparable because they utili However, these advantages comes with a few tradeoffs: - mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \times 10^4$), meaning most representable values will become `inf`. -- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. 
In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.
+- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.
 
 Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`.
 
@@ -730,7 +730,7 @@ nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos
 
 +++ {"id": "o0-E2KWjYEXO"}
 
-The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.
+The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.
 
 For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX.
 
@@ -900,7 +900,7 @@ display.HTML(table.to_html())
 
 ### JAX Type Promotion: `jax.numpy`
 
-`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays.
+`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays.
 
 ```{code-cell}
 :cellView: form
diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md
index 5e5be308068a..093f5ec4ab72 100644
--- a/docs/jit-compilation.md
+++ b/docs/jit-compilation.md
@@ -55,7 +55,7 @@ The {ref}`jax-internals-jaxpr` section of the documentation provides more inform
 Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
 
-If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
+If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
 
 Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers.
Moreover, JAX often can't detect when side effects are present. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index a1435c4e557e..de6da98b7d62 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -346,7 +346,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], @@ -365,7 +365,7 @@ "source": [ "Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.\n", "\n", - "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -521,7 +521,7 @@ "id": "sTjJ3WuaDyqU" }, "source": [ - "For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -604,7 +604,7 @@ "id": "NAcXJNAcDi_v" }, "source": [ - "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" + "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" ] }, { @@ -971,7 +971,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -1296,7 +1296,7 @@ "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", "\n", - "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. 
See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n", + "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details.\n", "- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n", "\n", " Here is an example of an unsafe cast with differing results between NumPy and JAX:\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 80ab69be1ed8..9fbc26a46c8f 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -201,7 +201,7 @@ jax_array[1, :] = 1.0 Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions. -Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "hfloZ1QXCS_J"} @@ -261,7 +261,7 @@ print(new_jax_array) +++ {"id": "sTjJ3WuaDyqU"} -For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "oZ_jE2WAypdL"} @@ -292,7 +292,7 @@ jnp.arange(10)[11] +++ {"id": "NAcXJNAcDi_v"} -If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: +If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: ```{code-cell} ipython3 :id: -0-MaFddO-xy @@ -664,7 +664,7 @@ x.dtype # --> dtype('float64') While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. -- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details. +- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details. 
- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype). Here is an example of an unsafe cast with differing results between NumPy and JAX: diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index e550cbf36da3..e80c7ae94687 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -17,9 +17,9 @@ "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", - "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", + "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", - "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." + "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." ] }, { @@ -2035,7 +2035,7 @@ "source": [ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n", "\n", - "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", + "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. 
\n", "\n", "Here's a contrived example with `jax.custom_jvp`:" ] diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 8a63f142693e..82b97e195bd9 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -24,9 +24,9 @@ There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). -For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +++ {"id": "9Fg3NFNY-2RY"} @@ -1048,7 +1048,7 @@ Array(-0.91113025, dtype=float32) ### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with `jax.custom_jvp`: diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8abee469d552..90d92c4ea241 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -1276,7 +1276,7 @@ "id": "3qfPjJdhgerc" }, "source": [ - "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." 
+ "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." ] }, { @@ -1382,7 +1382,7 @@ "id": "6ZYcK8eXrn0p" }, "source": [ - "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", + "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", "\n", "When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device.\n", "Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.\n", @@ -2339,7 +2339,7 @@ "source": [ "### Generating random numbers\n", "\n", - "JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.\n", + "JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`.\n", "\n", "JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.\n", "\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index c207f0ae4a00..79990fefb95d 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -427,7 +427,7 @@ jax.debug.visualize_array_sharding(w_copy) +++ {"id": "3qfPjJdhgerc"} -So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. 
This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +++ {"id": "QRB95LaWuT80"} @@ -484,7 +484,7 @@ except ValueError as e: print_exception(e) +++ {"id": "6ZYcK8eXrn0p"} -We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. +We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices. @@ -854,7 +854,7 @@ outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 ### Generating random numbers -JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`. +JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`. JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices. diff --git a/docs/notebooks/README.md b/docs/notebooks/README.md index 07be4441ade8..c945c197ad19 100644 --- a/docs/notebooks/README.md +++ b/docs/notebooks/README.md @@ -1,2 +1,2 @@ For instructions on how to change and test notebooks, see -[Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +[Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 56b2d80fc58e..d22457c5d718 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -24,7 +24,7 @@ "\n", "Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free.\n", "\n", - "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.**" + "**This example uses internal JAX APIs, which may break at any time. 
Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.**" ] }, { diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 6b993a630e93..ad707a9746fc 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -27,7 +27,7 @@ etc.) that enable writing concise, accelerated code. Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free. -**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.** +**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.** ```{code-cell} ipython3 :id: s27RDKvKXFL8 diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index feb906546341..d8a74e4b15fd 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -348,7 +348,7 @@ "source": [ "### Let's think step by step\n", "\n", - "You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 8ba87dcfee18..12564bd91f30 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -156,7 +156,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). +You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). +++ {"id": "VMfwm_yinvoZ"} diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index c31a99746866..a909d9329e24 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -46,7 +46,7 @@ "\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", + "Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. 
Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 53b7d47358c2..9c153d704763 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb` ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). +Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index d73b0d4c0f3e..ecfa199c6b52 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -13,9 +13,9 @@ "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", - "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", + "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. 
The two approaches can be mixed, matched, and composed as needed.\n", "\n", - "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", + "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", "\n", "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.\n", "\n", @@ -499,7 +499,7 @@ "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n", "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", - "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n", + "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", "\n", "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", @@ -1520,7 +1520,7 @@ "source": [ "Compare these examples with the purely [automatic partitioning examples in the\n", "\"Distributed arrays and automatic partitioning\"\n", - "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + "doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "While in those automatic partitioning examples we don't need to edit the model\n", "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", @@ -1626,7 +1626,7 @@ "parameters from the forward pass for use on the backward pass. Instead, we want\n", "to gather them again on the backward pass. 
We can express that by using\n", "`jax.remat` with a [custom\n", - "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", + "policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", "(or a `custom_vjp`), though XLA typically does that rematerialization\n", "automatically.\n", "\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index c52cf0e6d22b..095f37d0dde1 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -22,9 +22,9 @@ kernelspec: `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. -`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. +`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. -If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) +If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. 
@@ -346,7 +346,7 @@ where: * `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; * `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; * `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually; -* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)). +* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed @@ -1061,7 +1061,7 @@ params, batch = init(jax.random.key(0), layer_sizes, batch_size) Compare these examples with the purely [automatic partitioning examples in the "Distributed arrays and automatic partitioning" -doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. @@ -1137,7 +1137,7 @@ There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using `jax.remat` with a [custom -policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) +policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) (or a `custom_vjp`), though XLA typically does that rematerialization automatically. 
diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 5ddcdd32e2b4..d6cbf6e02198 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -248,7 +248,7 @@ "id": "yRYF0YgO3F4H" }, "source": [ - "For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" + "For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" ] }, { @@ -423,7 +423,7 @@ "id": "0GPqgT7S0q8r" }, "source": [ - "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" + "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" ] }, { @@ -461,7 +461,7 @@ "id": "7mdo6ycczlbd" }, "source": [ - "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", + "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", "\n", "At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).\n", "Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation." 
@@ -562,7 +562,7 @@ "id": "3GvisB-CA9M8" }, "source": [ - "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" + "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)):" ] }, { @@ -650,7 +650,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -835,7 +835,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" + "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. 
This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" ] } ], diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 0693f6ba8579..7b0bb0d9b8ce 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -117,7 +117,7 @@ x[0] = 10 +++ {"id": "yRYF0YgO3F4H"} -For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: +For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: ```{code-cell} ipython3 :id: 8zqPEAeP3UK5 @@ -189,7 +189,7 @@ jnp.convolve(x, y) +++ {"id": "0GPqgT7S0q8r"} -Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html): +Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html): ```{code-cell} ipython3 :id: pi4f6ikjzc3l @@ -206,7 +206,7 @@ result[0, 0] +++ {"id": "7mdo6ycczlbd"} -This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). +This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. 
@@ -261,7 +261,7 @@ np.allclose(norm(X), norm_compiled(X), atol=1E-6)
 
 +++ {"id": "3GvisB-CA9M8"}
 
-But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):
+But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)):
 
 ```{code-cell} ipython3
 :id: 6mUB6VdDAEIY
diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md
index 2b1cad7c9a66..7533e6eda053 100644
--- a/docs/pallas/CHANGELOG.md
+++ b/docs/pallas/CHANGELOG.md
@@ -5,7 +5,7 @@
 
 This is the list of changes specific to {class}`jax.experimental.pallas`.
 
-For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/changelog.html).
+For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changelog.html).
+
+5. **Promote RC to Final and Publish to PyPI:** If the RC wheels pass all
+   testing, then we are ready to promote the RC to the final version and
+   publish it to PyPI. This entire flow is internal and is run in our internal
+   CI system. Final versions of the packages are published to PyPI and JAX's
+   release artifact registry. JAX's release artifacts (RC and final versions)
+   can be found here:
+   [jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax),
+   [jaxlib](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jaxlib),
+   [jax-cuda-plugin](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-plugin),
+   [jax-cuda-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-pjrt).
+
+### JAX's Official CI and Build/Test Scripts
+
+JAX's CI jobs (both internal and those on GitHub Actions) run the scripts in
+this folder. An overview of the different folders and their purpose is given
+below:
+
+- **ci/**: Contains all build scripts, environment files, and utility scripts.
+- **ci/utilities/**: Contains helper scripts used throughout the build/test
+  process. See
+  [README.md](https://github.com/jax-ml/jax/blob/main/ci/utilities/README.md)
+  for a brief overview of these utility scripts and their behavior.
+- **ci/envs/**: Holds environment files that set `JAXCI` environment variables
+  that control build and test configurations. See
+  [README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md) for
+  the complete list of these variables and their behavior.
+
+Every build script in this folder first sources the `JAXCI` envs in
+[default.env](https://github.com/jax-ml/jax/blob/main/ci/envs/default.env) and
+then runs the
+[setup_build_environment.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/setup_build_environment.sh)
+script to set up the build environment.
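+
+In other words, each script begins with a common preamble roughly like the
+following (a minimal sketch; whether each file is sourced or executed, and in
+what order, may differ in the actual scripts):
+
+```bash
+# Load the default JAXCI_* values...
+source ci/envs/default.env
+# ...then prepare the build environment (e.g. clone XLA if requested).
+source ci/utilities/setup_build_environment.sh
+```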
+
+A brief overview of each build script in this folder is given below:
+
+> [!NOTE]
+> Both internal and GitHub Actions jobs run under the
+> [ml-build](https://github.com/tensorflow/tensorflow/tree/master/ci/official/containers)
+> Docker image which contains build tools such as Python, Bazelisk, LLVM/Clang,
+> manylinux compliant libraries (in Linux images), etc.
+
+- **build_artifacts.sh:** These build the various JAX artifacts. We build
+  three different types of artifacts based on the type of job: Nightly,
+  RC/Release, or at HEAD.
+- **run_bazel_test_cpu_rbe.sh/run_bazel_test_cuda_rbe.sh**: These run Bazel
+  tests with RBE on every GitHub PR. We test compatibility with both CPU and
+  CUDA. On platforms where RBE is not natively supported (e.g. Linux Arm64),
+  we cross-compile the test targets for Linux Aarch64 on Linux x86. As the
+  tests still need to be run on the host machines and because running the
+  tests on a single machine can take a long time, we skip running them on
+  these platforms.
+- **run_bazel_test_cuda_non_rbe.sh**: These run the following Bazel CUDA
+  tests: single-accelerator tests with one GPU apiece and multi-accelerator
+  tests with all GPUs. These jobs depend on local JAX wheels and therefore
+  require the following wheels to be present in the `../dist` folder: `jax`,
+  `jaxlib`, `jax-cuda-plugin`, and `jax-cuda-pjrt`. In CI builds, we first
+  build these wheels from source and then run the `bazel test` command.
+- **run_pytest_*.sh**: These run tests with Pytest and use the JAX wheel
+  packages installed on the system. In CI builds, we build the wheels first
+  from source and then run the `pytest` commands. We test compatibility with
+  CPU, CUDA, and TPU. These are primarily run as part of the continuous and
+  nightly/release test jobs, except for TPU, which is also run as a presubmit
+  on a subset of the tests.
+
+## Different Test Configurations
+
+JAX's CI test jobs run under different test configurations. These
+configurations are described briefly in the sections below.
+
+### XLA Versions
+
+JAX's CI builds rely on XLA, but use different versions depending on the type
+of build. To ensure stability and reproducibility, nightly and release builds
+use a pinned XLA version specified in the JAX
+[workspace](https://github.com/jax-ml/jax/blob/34a2f0ca4a8f8a26d9a056f8785f412bd156dc23/third_party/xla/workspace.bzl#L24-L25).
+
+However, to keep JAX compatible with the latest XLA developments, presubmit and
+postsubmit builds use the most recent XLA version. This is done by overriding
+the default XLA dependency with a local copy of the XLA repository. We do this
+by passing `--override_repository=xla=/path/to/local/xla`, which instructs
+Bazel to depend on the local copy of XLA instead of the version in the
+workspace.
+
+The CI system uses the `JAXCI` environment variables to manage this process.
+When running jobs that need to use XLA at head, we set `JAXCI_CLONE_MAIN_XLA=1`.
+This clones the XLA repository at head and sets `JAXCI_XLA_GIT_DIR` to its path.
+The [JAX build CLI](https://github.com/jax-ml/jax/blob/main/build/build.py)
+automatically adds the necessary Bazel flag (`--override_repository`) to point
+to this local XLA version during the build process if `JAXCI_XLA_GIT_DIR` is
+set. In jobs where the build CLI is not used, such as the RBE presubmits, we
+explicitly include `--override_repository=xla="${JAXCI_XLA_GIT_DIR}"` as part
+of the test command.
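+
+For example, a local run against XLA at head might look roughly like this
+(the clone path and test target below are illustrative, not the exact ones CI
+uses):
+
+```bash
+# Clone XLA at head next to the JAX checkout (hypothetical path).
+git clone https://github.com/openxla/xla.git /tmp/xla
+
+# Point Bazel at the local XLA checkout instead of the pinned version.
+bazel test //tests:core_test --override_repository=xla=/tmp/xla
+```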
+
+### Enabling/Disabling 64-bit Data Types
+
+By default, JAX enforces single-precision numbers to mitigate the NumPy API’s
+tendency to aggressively promote operands to `double`. In order to use
+double-precision numbers, we need to set the `JAX_ENABLE_X64` environment
+variable. In CI, we test both configurations in presubmits and postsubmits by
+using the `JAXCI_ENABLE_X64` environment variable.
+
+## [Googlers Only] Connecting to CI Runners for Debugging
+
+If you are a Googler, you can connect to one of the self-hosted runners we have
+on GitHub to debug your workflow. For more information, see
+go/ml-github-actions:connect.
+
+## Running These Scripts Locally on Your Machine
+
+> [!IMPORTANT]
+> If you are a Linux / Windows user, you need to have Docker installed as a
+> prerequisite. Additionally, if running on Windows, please run these commands
+> in a bash environment as all the scripts are written in Shell.
+
+Follow the steps below to run a CI script locally on your machine.
+
+1. [Optional] Set `JAXCI` variables in your shell environment. See
+   [ci/envs/README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md)
+   for the list of `JAXCI` variables and their behavior.
+
+2. [Linux/Windows]
+
+   Start the Docker container by running:
+
+   ```bash
+   ./ci/utilities/run_docker_container.sh
+   ```
+
+   This will start a Docker container named "jax". Note that if you set any
+   `JAXCI` variables in step 1, they will also be set in the container.
+
+   Run the script under the Docker container.
+
+   ```bash
+   # docker exec jax 
+   docker exec jax ./ci/build_artifacts.sh jaxlib
+   ```
+
+3. [Mac] Execute the build script directly.
+
+   ```bash
+   # ./
+   ./ci/build_artifacts.sh jaxlib
+   ```
diff --git a/ci/envs/README.md b/ci/envs/README.md
new file mode 100644
index 000000000000..6b5dc554d824
--- /dev/null
+++ b/ci/envs/README.md
@@ -0,0 +1,41 @@
+# JAXCI Environment Variables
+
+This docpage describes the various `JAXCI` environment variables that are used
+in the CI scripts and their behaviors. These variables control the behavior of
+the CI scripts, such as the Python version used, the path to the JAX/XLA
+repositories, whether to clone the XLA repo, etc.
+
+Name | Default Value | Behavior | Usage
+---- | ------------- | -------- | -----
+`JAXCI_JAX_GIT_DIR` | Present working directory: `$(pwd)` | Path to JAX's Git directory. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_GIT_DIR&type=code)
+`JAXCI_HERMETIC_PYTHON_VERSION` | System default | Controls the version of hermetic Python to use. This affects the Bazel commands only, such as when building artifacts or when running the Bazel test scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_HERMETIC_PYTHON_VERSION&type=code)
+`JAXCI_XLA_GIT_DIR` | Unset | When using a local copy of XLA, this points to the root of the XLA git repository. 
| [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_GIT_DIR&type=code)
+`JAXCI_CLONE_MAIN_XLA` | 0 | If set to 1, the XLA repository is cloned at HEAD and its path is set in `JAXCI_XLA_GIT_DIR`. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_CLONE_MAIN_XLA&type=code)
+`JAXCI_XLA_COMMIT` | Unset | Allows overriding the XLA commit that is used when using a local copy of XLA. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_COMMIT&type=code)
+`JAXCI_OUTPUT_DIR` | `$(pwd)/dist` | Controls the location where the artifacts are written to. The directory will be automatically created if it does not exist. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_OUTPUT_DIR&type=code)
+`JAXCI_BUILD_ARTIFACT_WITH_RBE` | 0 | When set to 1, Bazel will use RBE to build the artifacts. This requires gcloud authentication, and only certain platforms support RBE, so it is typically set only in CI builds. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BUILD_ARTIFACT_WITH_RBE&type=code)
+`JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` | 0 | When set to 1, Bazel will also try to push new cache entries to the cache bucket. Since writes to the bucket require authentication, this flag is enabled only for CI builds. Note that builds using RBE use the RBE cache and not Bazel's remote cache; therefore, this variable is a no-op if `JAXCI_BUILD_ARTIFACT_WITH_RBE` is set to 1. When `JAXCI_BUILD_ARTIFACT_WITH_RBE` and `JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` are both not set, Bazel will still read from the public cache bucket to try to speed up the build. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE&type=code)
+`JAXCI_ARTIFACT_TYPE` | "default" | Controls the type of artifacts to build. Valid values are "default", "release", and "nightly". This affects the wheel tag and metadata; see [ci/build_artifacts.sh](https://github.com/jax-ml/jax/blob/main/ci/build_artifacts.sh) to understand how. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ARTIFACT_TYPE&type=code)
+`JAXCI_WHEEL_RC_VERSION` | Unset | During the release process, we build a Release Candidate (RC) wheel in addition to the release wheel. This environment variable sets the version of the RC wheel to build. Values are set internally. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WHEEL_RC_VERSION&type=code)
+`JAXCI_PYTHON` | `python${JAXCI_HERMETIC_PYTHON_VERSION}` | Points to the system Python binary to use. It is used by scripts that make use of the system Python, such as the Pytest scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_PYTHON&type=code)
+`JAXCI_ENABLE_X64` | 0 | By default, JAX enforces single-precision numbers to mitigate the Numpy API’s tendency to aggressively promote operands to `double`. When set to 1, the tests will use double-precision numbers. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ENABLE_X64&type=code)
+`JAXCI_TPU_CORES` | Unset | Sets the number of TPU cores for the TPU machine type. Values are set in the workflow files. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_TPU_CORES&type=code)
+`JAXCI_RUN_FULL_TPU_TEST_SUITE` | 0 | When set to 1, the full TPU test suite is run. Otherwise, a subset of tests is run. 
| [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_RUN_FULL_TPU_TEST_SUITE&type=code)
+`JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI` | Unset | Used to control the installation of JAX [extras](https://github.com/jax-ml/jax/blob/7e42539653d33ec995487b683794c0bc86f7199b/setup.py#L64) from PyPI. See [ci/utilities/install_wheels_locally.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/install_wheels_locally.sh) for the list of valid values and their behavior. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI&type=code)
+
+## Docker Specific Environment Variables
+
+> [!NOTE]
+> The following environment variables only affect the build if the
+> [run_docker_container.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/run_docker_container.sh)
+> script was invoked to start a Docker container and the build is running inside
+> that container. Typically, this would be the internal CI builds and local
+> builds. Note that while GitHub actions use the same Docker images, they do not
+> invoke "run_docker_container.sh" as they leverage built-in containerization
+> features to run jobs within a container.
+
+Name | Default Value | Behavior | Usage
+----------------------- | ------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------- | -----
+`JAXCI_DOCKER_WORK_DIR` | "/jax" | The path on the container where the JAX Git repository is mounted. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_WORK_DIR&type=code)
+`JAXCI_DOCKER_ARGS` | Empty string | Space-separated string of additional arguments that will be passed when starting the Docker container. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_ARGS&type=code)
+`JAXCI_DOCKER_IMAGE` | Depends on the system (see [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env)) | Docker image to pull. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_IMAGE&type=code)
diff --git a/ci/envs/default.env b/ci/envs/default.env
index a5a5d56eb8b3..774464724646 100644
--- a/ci/envs/default.env
+++ b/ci/envs/default.env
@@ -13,9 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 # This file contains all the default values for the "JAXCI_" environment
-# variables used in the CI scripts. These variables are used to control the
-# behavior of the CI scripts such as the Python version used, path to JAX/XLA
-# repo, if to clone XLA repo, etc.
+# variables used in the CI scripts. See ci/envs/README.md for more details on
+# the behavior of these variables and their usage in the CI scripts.
 
 # The path to the JAX git repository.
 export JAXCI_JAX_GIT_DIR=$(pwd)
@@ -25,12 +24,10 @@ export JAXCI_JAX_GIT_DIR=$(pwd)
 export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')}
 
 # Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local
-# copy of XLA instead of the pinned version in the WORKSPACE. When
-# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically.
+# copy of XLA instead of the pinned version in the WORKSPACE.
 export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-}
 
-# If set to 1, the builds will clone the XLA repository at HEAD and set its
-# path in JAXCI_XLA_GIT_DIR.
+# If set to 1, the builds will clone the XLA repository at HEAD. export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} # Allows overriding the XLA commit that is used. @@ -39,49 +36,35 @@ export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} # Controls the location where the artifacts are written to. export JAXCI_OUTPUT_DIR="$(pwd)/dist" -# When enabled, artifacts will be built with RBE. Requires gcloud authentication -# and only certain platforms support RBE. Therefore, this flag is enabled only -# for CI builds where RBE is supported. +# Whether to use RBE to build the artifacts. export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} -# On platforms where RBE is not supported, we use Bazel remote cache to speed up -# builds. When this flag is enabled, Bazel will also try to push new cache -# entries to the bucket. Since writes to the bucket require authentication, this -# flag is enabled only for CI builds. +# Whether to write new cache entries to the remote cache bucket. export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} -# Type of artifacts to build. Valid values are "default", "release", "nightly". -# This affects the wheel naming/tag. +# Controls the type of artifacts to build. Valid values are "default", "release", "nightly". export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"} -# When building release artifacts, we build a release candidate wheel ("rc" -# tagged wheel) in addition to the release wheel. This environment variable -# sets the version of the release candidate ("RC") artifact to build. +# Controls the version of the Release Candidate wheel to build during the +# release process. export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-} # ############################################################################# # Test script specific environment variables. # ############################################################################# -# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override -# this value in the Github action workflow files. +# Whether to use double-precision numbers in the tests. export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} -# Pytest specific environment variables below. Used in run_pytest_*.sh scripts. -# Sets the number of TPU cores for the TPU machine type. These values are -# defined in the TPU GitHub Actions workflow. +# Sets the number of TPU cores for the TPU machine type. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} -# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels -# on the system. By default, it is set to match the version of the hermetic -# Python used by Bazel for building the wheels. +# JAXCI_PYTHON points to the Python binary on the system that should be used +# for installing the JAX wheels on the system and running Pytest scripts. export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} # When set to 1, the full TPU test suite is run. Otherwise, a subset of tests # is run. export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0} -# We use this environment variable to control which additional wheels to install -# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest -# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the -# list of valid values and their behavior. +# Controls which additional extras to install from PyPI. 
 export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""}
\ No newline at end of file
diff --git a/ci/jax_ci_system.png b/ci/jax_ci_system.png
new file mode 100644
index 0000000000000000000000000000000000000000..19efe62ae59e4a2d3de99c42a18bdff5768c747d
GIT binary patch
literal 535880
[base85-encoded binary data for ci/jax_ci_system.png (CI system diagram) omitted]
z37G3!;P<#mbKz!LN3yIFPqj0%<)FE;#P2MplxGKkHRgjqX`rOtR^9Wz&ZSyuZ>xw_0E|koTDFOadme{5C`SpMzrA;J@MS6TlAyA$ z+0$wNIW!{_u=G{c;=u~-ndN>(W#pF?^CHW&}ljxEcQOcM(U5fRH|aix0L({tXq4fSoP2 zy!oy}EW$^r;LU&L17&PFlY8_^BJ>lnz6MbRUh%#fw{1*K^ys&n+?3*P%>N~vf0LcX zq6OojJpLOpsXs=1@`Z`pMgJLU`wGXQUOZDh@@^u`nnom_npG~BfGh^%Db;EnJ`Wmk zSS&tstMs_$olKc_o1eds?W+&Wb=K0thK=t6O*DVUzO%P-4>Kl9Ij=Xr_#ZEOsur1ZejoEsLAfPM%kt4?#E+PAMPaBq3(UJO0G&F3wmgcx&(j}BkhI4UWe@E3LL7M}P;z0?A|rrP>2U3G!0IDN zsD#V$n|AfwMTEB_&2mP_mPf@-t>f1yeQ{<-rz*rg@LH||8)Ut0>DCMnRy9G=G~1*4 zms12kqb~Lqp7IMSD?29^Iq>a&9T;@a!GaaUvlL((*ew8~k4UTXNz6jVthTO%L?dQB zrO)4bw71yUUxGxtD6tjbl&0H8reL$H_zT=BTt-{XVqn*}j(D9Hp%g{&;n|6FSLn;Q$!gQxsawEW?~-US9u z!jfB+Lm8A z<*M(#&TD@m)uq~GG#wkZ)mxyT@nqc_44pl_SR={{`G(uzebwYBDkLT(CoBJy+WR4Fw5LcJFqQV~YYn-f-R~zAzxQ zp8qWsdjE)-2coSvF69p10=|BNK{S8l@Y8mF1+9mNdI2cZ9YappsJiq*D)yPrp0#uF zev8lOGBt`{IR?yTi75GnejKZ{WvkF++xMs8r4+}Jj5E7q1YA(7sYV$UI+1^_6}J^s zF|({v1>w~&VG6Vy;bZyXyAAlM0X+p=V3b~tk#8ofo|!8zJ6HJ0 zp0%m>39O0qi@<^jV1+44Ux5{xm38&clk*QEQz_xXc{zh^W*x^BB=`NjoknPo)AZ)* z($(fiGyN)4`7)LohY2>rVXju-XC{Ce-p>9(?@a;M^D3p&87$~Pr0174$3tdcz7&b% z-iAC-77x_aP`tF6G~FaMb%EsoloD+4NAb6)7TKffTIS8vVlk!z2Lof*WDiFy z2+6>?f_@3jxiqaEYyV!``Lylxhb)|qmGhr^_n;fWa#!Y5xx;gkxu{guO#;``E0z02 zD3q8o1W_MZqW${*lv_f*Oe!jKbj?B&{~6f*`t>(jBqSoUz>7T=I8Wma9M?Sr?IG-! znlGit_q4@cnjH1@WOF#|TA_B>1Eeh?J(PK#)G0o3&juH}7vdD6SFUYvc?sagDf$ue zEGZ@RPZb}yA))q?r&3*t3>DTL7|aI^5VKxcpa;H#`7v+5FU}G6&@#cSo3mnCWpOgC z$_`K#$5@(Z z2RGA$mO?b(xz98$q##h_S}{Mx~Fge&SS4V_%Xh6rv1Lm#<+{36Piktr_cbC z7K`2Vp=?dxiP!Qa+4zIJGX#ZwxPXQZe8LcfC{~eV=I3U@G!|scp0RPN>fr~t@1rH! zEGUXj-&xy~TbWJD3nuYSm%-`sPm21C76Igy{yL8fx$mnTgx{{B;`8U^8#V3Fypk{_ z1GYg=p_Q)abC(IWHU$_c1)QG@by9fT*Av2YwhM~_G zO=NzLlG_<~+Mn?Wv1es14aAlnFH)1+CN+41&OT6dKC9Pmk6MbsJg58yZ9Tszn>>N#j&-gt2TX8*!zu~^7g+bsioK9Fz! zq%PKpwC$StCY%W!Ig!50jq(2^Qb5S+Ymq9SQX~H5y}pzC2;wrpp#Z@jVKk-50vGum zJ|ov0{QHYjzJ1oH=8Irh#fk1D|6-aA?k`#0Z10nX_Oqe>UhY=|yGbr2P|UAGm$4I= zyvF2HJyKH=-gGY)&BWTd!7aLpZ0?Z_hNBF0Ltt|Q1i`CiLBfa3(9QsavPu6@lBWsU8#*unyeSm_e3i>G4iJpr3t1rOq6AVYwVFrU zx{CP0v@?KDmYtnl^hUy)*yB7!cO8xA78n04FQ4=l zmkh|#8sM8-H@F8w%~xl#R-XA?w%B9BhM{}v)Z&nnz=YAaz#I40U3wUO$;iB_h(gr_ zSkhm~j>Ce@T84{#MJR(ILU{M`(e8#^RJ9vLnu)8-QW$)HqFedT;ufaKF0!2DKjfw? 
z*x3+=RE3qJo~0BGv|D9wk;JQ4VPM+DBT52`@ltP?V`T+{qr`?`;@a6*9`i@`%C4#J z!r6(iat?=Syf{<7%Dj+^tkx+?+m?SrYt*>#Xz=CdJ};5Lxg1EUPOB$)K`5sh)d~~8 z5kS;KxbE~Y<&q9$t$C`(ZFU&a@*{!9h{bKtR1n9W@Q(3Q-h(}rOCX`~-&7RETg)@v zXkuikx+&6*!dfI`n_zV76EK~u3k>QcT7X|Bi+CjHD!!WdR(dx(v3&mV20+n z6_1>|{@&hSL|Qsgjo^?71(hCmh8*IC$~+c8hS$ zej^MJLzP(8CYUlneQvA35GJmqxP4}kWQyZD9wKP&{OXGkO}%PV&Hnn8GL{{}5#i@n zu}A^4ai6BvP7}B!)+Y{nJLoXXBI+!FK;9JjWEOoCaDcoBYSMw;V@nJvpA)5J{ z$%s|zXI?d%=tdhLW_Mqh zF}dQiecVR)wgt*rx#+R@coFTpTGVal2n~JrVOLjk+j4DIt#!AmnC!WwHJ7TGl^s{xk`?B9Teq!I$Ia2GQil%7 z5ZNy&6papVRwg)K+iziX(0yGBx2!ohD0_;twK~ytN#kbdoXPc#E&Ea)Y;k#4hvjsJ zU!FC~{8)c(SwbFFPkQMq;`@(fKE_%;*(zyw^(6iKODwDwKPE#mtT>y}EVGJV6^Gfh z!EEHyM%JsrrVp7Ib9`Mi9qn_?nB-VU1Cpz_w24LpzesGH#>Jk0>FPJXxT~C`Z5vr) zXizAWG#gYs`eh;xE2YU()b5dJ)?RVJiAkM9c|#x*+SWJ?&CI;xM&9%$>)UUjZx z0`9Yx@);X+C62GV<2u|N50-Bfn$b*6zTM7lSSoj5v~fR$vG3PVJo>dx;qnt>L9;nD z?I+|Co6guljWwbDz$)k7sJ-88n~U7|t%&UM`y;ZQ41Ye<3r+<-V~#lA>i$^=5+^mj zM+zZBg`7v9ZSE<5a)GPU94zvSOLIM>H$CNeEh2uU{^qe2`PjnT5{lT{Lh*t8yLivm zb{dP7EA)&H0*u_zTqFgy`%Vo_g*-hmE}O6GbF)3!>pLAtHa}-;8`HQSG;#$}a`v$A zjD4Xm&n=y6rx=>LKR(6z(N=!E%7I%jqL@BJQSs6Rk}tvI;j>)xiXnX#L(h!w`}*+H ztmx+s2sgdWU@3M(0tRz{6|-pGybQ0O!gb0nhGa}o*JY$zUga%)o%_cdKl48FK| z_MdO=&ku25-M+J8-I>wLRpru~*Xhl&e;eazpUjzW6BSgsv0c4bbfR*4PNT4AiIQV8 z+f<%S>9(QgOkVN{&*dy#@BL3p?{s z4kvuAO}D>m7Ey_B(5RPdTvzY7nAW_)N3+pFmET)sj+e3zcf)${oln;8H*L(w7IHD{ zcU+1Czx7m$XmtaJY_XTzT(+3o>g)CDhh}>MhaU6c!?ynG0_{W@hdX_v5tzP_tJR_; z87reiq2-NuTj`rNu^hsDI@(Qiv6r@Tjy~aTVMpGmCO=LqjK32e6_Nfe$ErSDYk!-< zOZHlQT@;QxX@>LHxZwh)L82m8_d{=r2;M{U`jRb;V)CJ_B>^2jH%lR2sg(ApPnJ(~ zQpA$4Z%v;fD$^?8E$K5=&5l{7y@yXdZBr54kN_{$gT~ewzywi?YX5W{sL(Ih6)inS zbIW+Xt(e^ObL9jcLFFE4@rjr#Vq%mo8@H~fQ~b}jkF`#`eOBpY$0PDYqZ@BI%JIEF z-AX7}?e4edTrQA}ZGNndp`a0bpM|q^_Y(2govp<_2lAMmDMvceJ@uHF?x`}Ay;i?d z5-orH$74K>@7Zax#xjyTCQ4sm)JeZqCj6&0?GoUsEe0{f-rkK_i)Q?x+MtpBrfS5G z9~}=ucq|65FWAiM9#!_HQz%ViiCLmic_WcY&Y7qwjZ0>hqdhhA0VAm! z0(Xy&dP(|HW8B<{DJ%v_M1=GQ+4-@aaGmy48J4ZA#>sa?$oE* z+qhUxL?>eZUa^*PQx6aAXf{UnxNmct=%HS7RlV95l2T=5<{HHt4l6d9SL^GZ@60n% zbHy^ee481;%Uf{qLkRk=gs0|lK)f)z@1=z|ds6G!S%mxe@kn2HWox0u4&VQN)lU97a#%^}u%*&9Qa-5{ zz2+<@kqSF?9Px@*>BFk8Vc1V@xBy^WU0>9Ao#2Ax0&@jX>(XuJisFjQFDpc(0ZHGk z6#baV8OhT=_Fw;gV&dUN67p2%{IP2j7yrE0ol-KA^Q!Ao=gyePZ;IYtxiDV49PTyg zAJW59lr(JX#zHFT>qD0I>E~%q!U=t|(VR zaKIw7W}PWhYGG|YTshpy+On}Qoam=3>T*G3QHt-@|M#=Ooi&%5J32rm`S2KCvxIncfPPe_k70B_PuLpFy~a+D6MSvIYUzZ z>5KkZCh8mf=6!zJnOlK#H%w~g%a^nA?fYfTM+>_hD#hIQ)9#0v5$MJz)plR*cKbXy z=yKG?i#{pMeDZj)H^42(^2v%{_sM^~67pk-%@Yo2%>DazeDMOwGvipH&|kfKcz(R3 z4CJ{vibk?wm@=pCH=6~It9>ERG~|o)^%%uGanm$QnIdI9)J87Qk-)m_EzL3|vi0@v z+O(Ow&Y74|?z)1wj_+vng=o62RE75)X9@ASX0-qPfsxl%k`;lH;or9j&oakSWxOyC zOEWrQ>pC-Aojmnrj9x-SnJ|(g!^+3H0?@nX29U}-xs|l4bF_WG%?lNKMh2UA>2>v=YDkh)h1X(H3Y#} z`)ufDt||!y=`8!#JRutrSq!>BY-bywOiYh_3`;YH>tF+i`T8H_&PNior93!61f| zWTvfoV%YOx3uOnZ-EzJTP#rY}(@t$yBr?c4Rep-mDyQ_i&o(`CX@3Xhb@CytG2VaQ z{-2*D#b7%9pIh8r^%A>MhvQF5(;H=9i}Mf5o@^LPJ+AB=l8n$F$|!AVX|{G^VZ3|? 
z>L;}XSLT+3&x8CCBCZzivx!O0#|_LgN`-iCnfe>|$X7iGu!~#G3fhV&@38iNVidCNuX6Y0zi#s2_=q-oNvcyXY?sj2su2nT z|A05vf{um`MQ$kNTN`vj7wG!cH5``<3OjU4F>j7wbmZoB2_~pr~a#s%_^hY;uRcD0SyG%2$^z_(Zs*+Xs+B zM8D3gVtO#op#IawkMVO!I5|?jZ?2YUVY$YvG0vj#WO}JxV>gg9 zFE7Bp?FXlMpXTV{fTF;AkM!@^1n|U~3o)LX96qP9DKN^J8Yys1itxhi3>!MkxaY@Q z@hAO4xJ(WblP`y;4Y*ZzFy@WUuiljusl@Uv5l6d zRHYdG@+LkS4m82lnom>r?`p@u5=|+mqq4-qM6I=$Im}(s7MDub%KxM_4$G5rJG*>t zV*u|14bH0@DYQRc->2ow(1!@}Q%`UD0hBoL(VUTNAZR8YUzX%WVFU_WK+iyOJ`i;e~_SL1{E9A9SpiD&@%HdZkmU=9+e!%3Cs9J}LyaR7~Qj6Yp|1Rh>)F?!Eg zP57AP_~p2r0JMYQav@HT-IC77#~Hy=N%{g@l=N+}VnnSnLQ>Shp}^?vhE~LLc)oqB zZvHz^)8CT6HA^E|vAP=K(adWl%?D2&Mrt^h9cpsejgH_ZH67laU-9)zf<7)S=a7te zFpxJS+`w_!xCCpaSL$qrdA5`yNzQ7w<-7n8 zcu#m#%6Zc(Oy@87#;j;@+{V^>!5&b#>+5 zSgE2ezg#pt_VNfjeuLy^yNQEup4D(m9Pv;WIDGP0nr%xz=9)ghn$ta>K0^927cf&U zzk|^Z>RbAnV>TBF#xdk(Uj5j?z-u5onF?!9~7 ze{zRho$c5loZ7vf+|pPM&N9)JQPh?5`u=0(!NXcyT!LbrfvNHn8G~Ip_UXL^V8P-y zcI*JDAXFgdCEI<2a>UhSdg6!C&Q0JaRF(5M=dn7+8t(PXSyB+0a5!FvVM<1ZhC+ay>_A+?_p0~P>^PVE)u2l$9ty|LWqJ)`=TwGaMP1<5nplhZ9 zlN#Q5tIwB)sZ(WM6}AD%5DotV^V)%OBFqGnZAT}p%8}*W&OyrghGVa{4J_f1kWDw; z@5@fs6j7`v7w^4TAH1r336Sx9b*b0DPfSD1h@=hQ7{$7fn%tyx*PS~(xo3KxdyIh@ z8uRSgy{l^@{8pJgYmzJaz72!A5Ee3&&yqC6$Knt!C11-kLUmJxD?I$&GGA?Z`#J`H zE$P;a?+5=`Suh<7yIK`l{re-z%VlcRqaCoS@&ITo0KMpjy6U0DhNaK4*TJ@Fnl6bomhtacQ9xThC&vleH{W z7zeDZd3=4*QS3a0tycTt=!479S zGh9*}b`rH(H_tW#*ER4911Sw3C|?DS$g8X?mzL=EbfmsDlXW!qIW;Bd*P*68mJ z;qGApR$hAQ(Rt|mO-tM0Z+re+*d%E)w+UA@HYGGOPMz=ScqF?@s@CcFXMCYDY_+FJ&d;=hw_l+I{fh#fS2A|9IC6dWYX5cz?FJ^-=#|3;F$YQ~q!rz@;L(q4%&W%&g9miR~q}wUibe2vnvG zKAwgnPsv3~B(WSVnTV-%LN8>PuV1;jOc*#bHTH(wap&Po<=r-4|=yaCi4!O=_`PnEkhd7Zw*$ae}GVh&eTr&jaL|nqBrO zF}75}#mo)17S-_N=+8}-6Y&|x;}Aclt}I!d_E%3f^{ ziy7{Uw*er@0=o`u_^JltPP!XYA+xvy|I=38Ylft2N`TAJZ33fs`ol*NKL3KQEnO)_ zU(mkn4Aq#$^QqAJ3E#E}%aXft#+LqK3)eX(q~q~PxI-u~m}OwRfuwY8s{zcpwU!FC(pKHjk#%zydH z(@^igr-Rf7Wk;7x2=o0+=C;$dGRJL-!e6ErPA@9om5m&^7xAHfs+-Q0;rY&;M~}Bi z0bsZfe{X>YbIzQNuN+}tuNcojk{9S!r!CvSWKT8Rv$T`GqnHD*gcfgi`o=RmQqx{! zUqHRS=fFe323p`#@&H3}@iu2R{JXDC1EN5f4lF_{;*j2TOiHsQt`J3nuTKdD65$TR zqp%xWduAADF)wFwvr;uns)o%$sJtG0;o23DYXoWVp@7!eadzutg7;|dSSduDi;pV& zdFAgECNgyN0x2jG{;m&0nHhNO+8d&AcxzgI{^0n zX>_q?sP!!90CTFB5wHorNc%Wn8{Qp=r=>p;NfTfAeQ~*6!#)e*H}&-3M+i7!;e1Bl zL_83L3yR~=V&S0SIQ5$K)LPvzAr3{3X<4&8&Z(-oc^#Qyz%jb@1vgeMgfVJmn?{4F zggAy!YkX2g6d2Q;?A;X=HrGclJK70Qq3%fw{jf|Gb^XIiG9o-K^(){GB$cxhU6XM? zMGg_&`z1~qqV{D!nDqj8!A1M}J~%kVYx1L(eOmnteOgI6dV<+415^kcgVbQQwNT8)VpEu?~3x=d?;Zz4rJMvsD`3Dl_% z6ein@_RfO~B^XA5vQ+*LJ)+4pnC#NrJ zI9*;{S+%>K`;&LQn6@{1`qs#9Wya?k74hd67_OXIJlbRX-np1byKvKd-!zXelO^as*V9=tRnHC!IF*g~1A z)iBAqHCI0;@98rrhHP$;^pbgT&%3XJg@Z(t2=dXj>*4P$8`-oYU=;DVH!<3|OGKMA z5}ze5OT~hLQ4QAkJZd$vX6>@X*ax=y_L0W7I?@f}eZzDfSIq4kL zv~zT&1=U!~gUsWXp4u3%g%{~!WB@O>#qjl;cg(PNw7fl)KBG8f0k|$vA1BSY+dx7N zLiePrY0*G7QCJQ*0oF+df02bV<$hx8K1lShPP1L8oCF7W*O`}=9q}dV_jiA+! 
zm`7b%oY+egBFq-&H?pD_=norX#Q&MA?|*JrEL@&mJjWvYR@+O6t~*=DKK@2%AEgez z)%o(_ImNnYq3~o$4~%rVFf%K_-?C>rZYOcT3tQ)PZJhJs7xDPfYig&IqT(`|?lZPM z!0O$7JHglT%wImR@+=+CqD`7RVM0JC0Fn_*`ikKa2EyC}04aCI6SYuq>0%=LbbCqV~ zSW~{fhZq$K7l{%1s1jmRfr4bFVC0Z!jKvaQCZjvbg7i*+ zngqr`ySDzll){V2F}~)Bnm9O7{h7Tz0kPnFbs@sB>Nv0jiW&T&{=PTZDX2_v29X`zd5?Lrb3A z@IjG%`@&4tPir?d-8aJJcy$ib21U=DI^x``y}8)u1i2*f0c3+Q&sWBeFe+a$)7Sf9 z*$X)?T);I_qkFzh`ffW0+2dFI2%R|GUq5+R!VCMb(f_R2)GERuIu1PG2Cej6Tp@~d z#meMi8sypF)!JRKc(=iYia6t;zRP(nPb2x*Feg0*38X+*g4f5M10>ZqZQSf+kgpa!Ravhu+8)LaG78vnSMqFYHldAEevm&*-mE`VZXI6imnVD>g zc}d~&bQ;Q8T{&nU&2Um$SAV07q|(OYYnsV=_v%&*?p)R9(<0)J9GtSgAoKR!=0eZ6 zi7oBuq@ru(!(Q^)?Xl}rS2E5NVk`uPEDM~lDE9h>2=mtE=QXeV%005-y1EgU+g9-H zX~T=kkD6~llfc(WW)@mlW_BN3qC!if@LMMdqLhO)cr;@j6BAAj>dfcFSv`C%f;3v0 zy`q?Ev@{?28mUXbd>FM|QYGNslT3_jX-^pik%Pw%_e;Z^9c}-3BmdxEfR{x&Hw|An z>}?e5OmS*FlWi*!=xBR#Kq+<}Kv6mBOc&QoexhG}Qqf&7TI`W!-l8W;#%7LA^mdXD zgpYSG*kwxYrts&25q0HRg|rMJK9Z}&A$vFKuaB>W;)IJ9=aV0pOr>z0R9&8Ppei0^ z-;5SdB#&QeZr<&A(9tj4_VeVYS3?zjLOkTm-8oA>D3ivbX!rLT$0#Cw#ijX;4V2O~ z>;5l`DRqM+W5YdtQ&TfzWHZ%fq=bf^sU1>(tb8`Tc5r5d!$lLw$~Bg&Gd?JF#`KX< zeS?P3l2YA9eZXBpmc%QNzjFk00A3(&@xhBrVlBUe9G2SL_ zyc!Ok{d`VTD>ay>E#6utQ_|vFWo4z|cl{fPC1DaNTqt8&co+z){()?kjq&H){@#Uu zwl>p-opFAMSG%&HBYtM~2!-sWcFDjjDB(T5R!N1FpA@ zG+XqP;Y=9QGXFx+I3~iAg5v0HOH2Q_;w4{U(%1$i(_+v}-_^4U;Bun`{m5S&qRJ=HLiH&ZKp0o^p}!> zDR#O^a|?4Y!IyC*4TyGwY)W)am0PWNO4qG-xPZ%1QXhQ6AIJU2Uw#jK#z3TpDsQiM zZ$6PPzS8xNKYGGv_V*_#0BMH2J$UGbHD2M03lCyt-bHh(rtCiH{ROGth;)Dzs}jZN z8Eb^B``C!e;aNtEQ}fGnS(ZNK@;(t#*0w!<8wX(kkq&b(+E5?8u@Ndr>u zSdK}5R71(K#9m*&cm9?ihx~R@T+D`0*US6F#dM-2xp|kvv7y1iXtNT{heSu6iD&(U zX3-Y)^{*z)BrRLlh*{hbzhL>`0xtfV`NdN5Q-q(SRA92np zZP(0s8fH3S5mtbC$%o~p5ZTRVKdmJ2AYiED^a1jwX1}v?|vee@wi&S%IyNot8rOS!XeW}~|+ zWQi*a$NhtX43X6FG~vecRlPvj^F5CJ@9n-@$1*N!Ai3_RBP-Sg2X746)Bf0d>sm{i z$tfhu@|t!TXjuVznH5^{vPr(v^*x|yI58*4Sgl~~pltH9j5Mn_L5jOdM6pI4- zn~3{)RQ9+DhTlaEbD#JWbA5ln^*6o~^EOC~Yl9Xg9;7y1t^;*e9B61KT?H7R9A|z~ zx(l$S55DSKbvtu%+c*#3Ovf;HXzM%0x({K~9zTiP{U;AyS{67|Ql%R>()SRirWQNl zv62$d@i>O@9%?OT*E#@oMV>WYgiZU4+`aoTfQ&y?7bK`PQ%ru=ISj}rub6MC{XhdM z+Nf>4mi;SRlEOkVH}MY&5%W36vlOKvjCo|uUO~MW2ZxI!l(W5Kbr)fhviNuYc+flhWa@Lmq zHv5P5fH#wK{x|IoZIElGgO=43+lD8)nulep;>lNg$5)VKce1?$P3W=erig#`#+Tb{ zfRA1DG+Z&1H>@=@qs@GFiDFHsRM{(M?g=F1z1PERfbZK8dTBJmL*6;I8&b9C)mho~ zUC_GRzT9M8pzbkWp&>H~bVX<7NA$(^LjhBNG}w(^Y;=Cgcv&Z;QqdFSvS6xBgOFU3_eCXQoW8XUF~ z?{}j5mh=%xpLE@MVvFkfi&*+S39QuTP(CNVCBGvl9jYR8`BTa+~!eP& zJ;{j#9{$#xd!eq+pRZJoKpp)Cy<@-4bdN1F!!-kpAvkQD(?&tk5xOJ+tD9^JrTW@a zT{-3_mV-D@i(F(f1%=HS0RagM;ozt@Zysh!ZcQG=HCbEouB~nf7FW>N5F&n(njeHJ zq%1v<2l^4D2+NZg<-=eiZ9Q)N22k37&eF7Yv)(vpEP<9Qbmfv}`Y?cT=%C`0#$9v{ zLvT3z+>}4uLX35me5z&!L2SCf#6`n-;II16NbL{yOalH&Nx83aNyr`wdtXdV^$EpG zePU2Z9Cm)JjMaN_`JlxU@+l4BhEB+P@QvODX%EDBQk0Zlito*~+lZ5+aki>Zd7)p7 zv~ZM;?L}lf%9MfhR-H-$kUkJ=9-bm;p}zqKrVzh%L!N7Kq*%4k_+5=Y7OKHkGf%x9 zMMon;#PgCTF#SLyjG$61rLrlZ=|)5Rb0=>~l(MoUq9-JJy}&;xTnxLiO`;WQ0t~v; zW(>z^YUtTORm!y?9GHAv{?9_#a9VC)A|o$3JO_Q}l?V|-+kmx34Yvc^Y)`$kJS zh|<$@7F^jq7)mUUu1$~P@k>^3P-6F|r^;D5@~F;kqbFqS36pf(3F$cU+@}iEb?FF) z)qErsBH)zv8;AyDfMRKJNzA9WyIo~>%nIfsI|A(@BgzJ%g4a9|6GN0+kavLcq}?3| zn7+krnvi>^HxZoXlHxs(o!PUqK~aiJY>;XM2UQ;GqV$9r5g)8%NG zqeU}JGhY9wYz0goO7s#L?Wl7sA3Jygzchxd9b-Mz{uV02-%9Lw(uKNNWrR!*$o_sY zrNfO4@ewxaGN`#cU`8=nS(agf)|%pbse?7>C?Z1>@0V-ce|t|pUYcdcuL~-h#TIN| zP!)}}gH3ad1ZZ(h{LK6P7H>UyDDrma1xak21JqdCeyGeKy+KeWntxPx70NLMu`)6+ zRQ5{v57DT`#+{`&mIS06^<2*va%bNq(CTjt4ysP9h%vnZK?{?rYQsa(&HxbxIu2^j zOlgaOs+|l(HV)TEPUs+jC?*XBIo3@3-+R3hUN|^7v_{@=6qf@WiKMCR?f)eNUjT7s zb=gY^F`U3bgg}ErQhiPyNPq}A^y_v~^oTcV;l>El36_RQwD6V{kj5=YD7s|o(pR&gjoC}RHDk5$c 
zlO5)S$(H{1>F$L+MWl64y+DQ2LZtOf)dHi~)S`*xhSo?;jm(;x%uhjgNaQ*R^1VFV zP$B6%(Vi9y9iK=vS4uMnngsngAIo(h&vJt}E!vNi0DvGvwMR{>w=26f!Aa=B&hrP- zR3@5q+*c`$+OBu8p+~*+Euez0!mrPqcz2W@N(P|bm^hN!l37387fqkqZzc#l*Eh@F z*>^oz>4B8|Y8 zr2LnxSA1E-sXn9M#kYq7q)S^JG9tH!bue2|Uc`17nqNW2ZTaIMgWE+AQZIscfhB|# zU8e>U)s>6j3-^Mi16USM2Wh~ei2mNTwh*_@b8%dlJ$W-N{cB;6UuiY|DG2@_Z*TtT z;(={J@V`Dm@sAw;Umux0{Qswe%>C+`{w{IssH`Ydlew@6{rSJP@65j{cjrONQ2+_P zmi_-#8T)N78}a{dCj6ff{&w)qmkB9Ih}6Im4Q7h_vvexiGi&CQZ{5mMt*A&4b+A`R z9?{2a-|8hB^-I_M_hDkccVr9}kfX$anl(hH;vF*zNAlge&{jwh{AhyCN?+s=s|}g&p2& zxGU9>ji0rH`~Ph0f;UCnM1HmC{cAs)rf*k~j=jVag42J7p`6Qp)d3=V&i-nL{MSx} z=gyveATGYQjy<~7xN|kRlX>d{BKczfd7`SQR{2ax76tpmaK3G=*x&x2BIv(PLh%lw z+=ZGUfg5e)vn9zS?!~ZgDXJ5=4Xf@Y&V1C?rxU*Y*Oa(Ct(6-td1c}tE!6$k=x(lM z?jYZVw#HF3-azeB$vnGip=wRVy^uQ1?M@gL;(s41L35U=pk(GM!8A}MWY*FdfW#a27>B;m}9u3d_iPv?qJE9fg2b%D_ZJJxKSP1f+AuIhgc;QS}~ zm2_l>=_!j3QfauWt0~1!Mn0}d+CMHF`=2iCA0?P3 z$%WnJah_Vh7_(gim1X_!i%PIp59nmR)(ZBwqN&UCD^w@U{`c4p%Wt5JRLQQLy4Tv+ zIdMpRZF7tKyzndf|27G-{LIV@^3BhlcjepraB^{OvWEIy9=npQk^_={0#F3XA$5{S zsh-9h2WMwX` z5&Y5nWv)l#?U&kerUYw&{**6$ zRcU&h(<^%UD&PYOckaaT8n5RfHc@EY?HvX;GQ1W>V_urN?a^#Sf=A#s%cVWG)z%Z) z{MSAZLnG(KC8V!rb!WZif`SEXrL^p*zlZT=0S1F{b$Wq2DH23JL`QRv2VH0dbt_YH zuE3<*`Wn^h!Gx2M_&pc}sY*S{B!*2@`{~H7>iuCT`x_-Q{4s4>X&Hi~={2Uia2 zr6dFr>eQUulsu2Iv5rUnj?h_lZ@G#KCEJ-xJ&l?x3mmV!kt!05X7&S6*n?A%E)-j3 zGxaDWDYK@_;MN?l3aPh9=*VQ z6fBiSZRxCV$Dz#LUV*d{OiFL8N`up*6ALiDF?U8Bn}9&|p#~$C|>q~9aDl}-hS__)j4k)UwL`87* zhot>&iVN0|ekpX_>4QeUXv@rlWcJH->{P(7yW@ee=c|{{4(jyQNokD|pxhLD+$4R+ZVgaT05V;b{m>I7k9@2ohQDbCr zXROU6f^lrWABwN+#=L$I5n}&d<2<-tSFWIVVhK~zQXO{E+lzttBHZbzt9!3-$0Vnx zN%s;1M!VyPzEDy^`&+%hJroJ0ba!|f{g_Z`7tEaM+XKOw@qETfA8&L{y9#4zily(V zx`)7VOFgNl*EP_`NXD37cjX@WwLF)Yn3$+eE%McAEz79&zs_GEES{x#5A+s% zz=^J=x1jA%^FrmGBF!L>U8C96nZ3CtU782rl*n#`RjIC)KShMx@b|L`VEwi1Rk5#M zOHt+P#ArC*YPsTN_ayH}SFS(IP^T`aav9Tn=kNM{;>;xUD%EhCjU0C$qehymhNeh?>>gU3OORsPQXMEg? zjwdP(m-GTbHr}1~Svb(_fZ_}$NH1`#dDD{-Cq*_|)n5RZ<>e_X@k$nSOQpmN_1A4O zLFOzI9AlB7eOi=Rv(6EO9lEfoLY&w`&4|>5OW8oAPC>%iX5{)B`&MVV1r*DJj%O>V z+g>;HwYK3ZkbB-|m1Unqy7zM_w+K76O?Pb22qj4dvIl8Addnb)Q3t3;Px{ zpyUChd00 zhqKZC{_7>qB9;qbI*XlaW{Dc=l9@GN!Xj5wdY8?-!e;%WI0QH)rhce~Aj_$g0&CLi zfIo!&nCZYHdXsy$U9I(__RgR6$(Jwf-vzyk)pG7l<2rc~Q#k`6Gomgec2ng~Vi4P* zhg|5a{T}34k$<2&&)SA%(68<@~2UZoXsAkNI zmA{7&*`A~;xtUsPc-}9y+q*9jim-=~HGUOf^2oC0P4TK;nVedGj|?jPiFSne-q%qb zCWkk>r)2kAooc+TCACovk3TY$blbmH2;nVGDok(e)W#2uUC?oPPQ2T9qmI2SRzVOV z+jhRWczuhmin!q-H64XEzUkQ6uAXuyS?f2nL0LJ;44Px(bW@m&Yl>H~cI8_~jBuz=TK4GAI7g0sW67-9P4WE|p&Fguz1xwx*d&tC$sp9hVL2K4 z92#!lzC96p{QP*%vs|Ov4s&;%24ON%QUoxO#qUgV4Iq8&X|yIjFW(gAU-n!bF@wey zuMl2*bF3uo#9B(qq_HU0PQ(hyHu5){rh@6ZSXn*5|8=mKnXBu4Mj~|-Ux~*0n{}Zx zGaqj6^;6#yqtSVzmofJE+`KuJzL=9dqya6x3{%=9Tg7soYautaRj#TWbBlOR`PYsF5Po z2h?ltNDjy{%;w5&!4*XR~o! za;~sFd1o8t{;Qh=)L*tgoX=px8g&S4$)!qoWzj-Z@ZzWaH$Xy6A8XF{v+HdvZzPt_ zE9eF4`UQlW8hG#-BCS#`5@T^YiVjZAFRCch)AGNjqHt+`W4HVwilHkZ2kGdn4gzDc z5~zebYbFazvY!2esD6Ro1o^J7+3hk~98JFr1O_eg=Pj{2T*lnyFr#(!+ z*_@@`gy4sjh|5zuqPRYodcP07U9eE88Yj3Vo!Wsd0 z^0TnBaxwBKgi_kGEY}*LeM-Op-bJ<}*M$Y;`h~HuP>Yl6`-U~Le&<5}=BvXEz~8B+ z&{#`kxFXD!8aI_h1-GVn^LUX`Kv3%BBk~vjhg?07Pu8q! 
z6n+!>PY0gt2nN2!b8pQ=(9@|pah}Dhixz3y1&Mbjw@~or$&mj_j}q`F$vbT2ns5Bm zp!J|u^L|O~P86aJlr!Y|_<>9VJB;id*^oXviYvX7w1jMCR*_7`)zsE>LfI#~;Gz%3ow?g0 z*}HAP{`)y7bSG~~1j9EsB!cN(Pg!o`BP2s8n^5~hF4wj*Qb}A8;;8R>cbga<_&SDoqN|G zNfrrTQ~|0vySg54;jq{oIMDTxqj6GgYYqNy1;K^$$<>(;IVl}cbu~GiZoDIcd1t?- zaU|XQyMQ{cV)1(eRliw=o38|X%L@%_%zHx2w9tp$4~}-oonG;h;4l% z&e@pBYP`P*iGLK|buVxx8+{46I?;Ll0SyuL0fd1<&KX$P@k|~of9s}4OP9AqBaztu z6`3yjmle(MWR<9sAbPwL9KRrgFB6Q&$rjYY-Jn77cTIju+m|9lp-;`s*dCwmpnJT1 z?B#PoIm;(vuhyR@qkBT4kIG&6HpLG^G%2C>G#vVr23PUclK)yW>SIBtbidOnB>a76CDvYvhgjP##3X* z?j2RlRm_-i#=D2*V35x+p3+^q9Ds&7Vl$kXaW7u3PaVdqwZSq(eo1LZWiLk4s+K&? zEFXl)pzTTsFls;W_8h~?)cdEI4Xn1f)7U4KJ+&5ysrT{`*xf35F}4~q&+DnDrJCxq zUqUV0n(*eE`lZ16ohk=v>aNF&zVzsw_V_fvHZAY7hm(ikuCca)iRFPE2t)3 zM7qbWy)B83KYps;_&TY!H7PD^ zhyU&66f9Sh+md#!J;-vXqBBv>{5qcZxpK9qMzCuI;7~o$GB-hxp{F4{{r2Jbpso%w=t*OnG;5*_9OJ@$yvY~rr`(AYY0#rF(Dh~YBcK0hwSTJ>3OIzZ2zvt;&Q-6N*|LHtX82Y z6KX9;kCOeQnwC~Tz!6-|e&6M=;5RUkC zw|Q8x|eLTB4BQ93dMHFn?eLoy<(LSHRG8AOmfQ!^S zQ8Ez6mm)C9Q-o9r`LpWS_ty7eI+(?uW&a_yzf5goSfD$7<9Twj<}C)RX^_!tNv_8dw?K5^nAhxf5KfrN5I%YfrlE#now>n7nWUO;Gcfz{I(iN=o)G*#CL$1`&Uiu*Rp$(*cLS!}mX zVpOhcp9TXNk1~VYhb||oy;6TLu(0Ygb(UN@!Kl$RTAYTb~>+K^DJ0F z*V3SLYG_IOG(AL*3E@2S+jaqG4uK?o$dsz;EmBvN#EzVQ#i^Z5GTe8_+vte)K3`3U zC0@3=d7ZivzTS-2OfmaeZpZtCU&ZcI(_zat?Y{}DlXJS~zOQg7DKXsE#Nj1c zs%Isnvt}0`hE#xSWaM5?Px@0?^i;b#wSM=Y1%h!rF#zE|1VYQ3KqDA5D z^2>!tj`%|eI}(`Afho9ifhHC|6lIyJ^+|mg0iC=qV+!`#1>M}7gh5$3d+LiO# z9Yg%Mm*;k{Ygu?}p1mN~%}?z`Dh(VC}1*dVR#(fY^M$;~l#z1d^<) zYIWyn($eN;dM9 zsiJvM`cC@`S1y{IE{@1`7*BV)B$NlkY8t|^tj;i}O10A}#eLrK`JSplt%=;iS?v$&ehWSZfEuuB+Q|SFz>;Y!Kb7?qUIt4I8+`C z1ye@w%@>zZ?n}Bqgf~^zIk*e&ps_c*L_oUfa_~{stn9pInecY-${b8Y2)5=?MuM%e z{~ynH8rx6eIY_a>d|Fx`&{Hzk6NI{nXma?@t|G^hDW{!(pc7_Ng+lwx%$FVjXsHXD zi=9%=JoeTUpR1bMwSw2&2=`CY5tHr`3Mu_gVKs3qosmy78PZ!%pII61ZW&%JRnbI! 
zExA2K*-XvDVWolGHO(NTi32pbxbgX&0A?uW+_=Gnc{kUTg@tvtw^`IN^Crmqsvk!8 z%HIkJG~vox^O|P%jD?4Vi2FkNF{~~L_p{`G^Q5emyqiNNbm7fusdf_&{S)ze8D@RrlGAwkrk2>pPR#02pA>?zKWlH^zOViJ>Qx zlvnmDau7EB!GYf`cn$mw_Q*XuP;&93qKirILPP_C_7Ci?102c`=a70e{h+mdSXS$U zz~@q24XhA}SG~q#g9)xC=%5fs+ceh(r{VSN2@LuImQeiOy6GdeZ>zv*MaSt0?F4=} zx%YVnT*2268hWq#;9}7LB)D_pqn{d1t>PTAHP?7AbasWUS;^WM>vTI_4$7~^ zT)GtZwqVhk%riy0yw`HE)I64n-{koGB&5w$Y{kVi^nq?G^{RC2EFrZWW%W?)(|}cY zI#7?w>=6WDpqz{DkBdonx3fC;v2O@oEO3ko&Z1mcwurp8WX2z9l1wHR$&$KmXTOy7 zL8q1Fp z{xD&?JA19;42T?SNRIu`b&G}M&P$1I`gWxsz|N_qp{XN#3&K!mFH0uc`xx|(oQsVO z|8yX$&e{E5_ieikJf#dSH)hlp?Q^U!Et7CLF9KFMncN`^A@(iA-7d$v^BZOnnk{^N zgFN+Qx(}_&oVx(tnv~2=%>6e@|7c^JZLzLb#C8(=I z3!@!*imz?6&bXzDxKRFYdETbhD0?DVri5#;mv=p#nR20qW_`WEncbFKX5gI4+X+V0Ps#UB=+oq%gW{M~ZsGUXe6?T7=T&Su1Q$}|m9pKmjQMRfNBAokYClBWn{PA{zaW1704MYq%< zK&C#T)0FN8t9zAlciVKW{*j!FIDT_UbX+g0<a1+etl!JYK=orR(gcwVQOC9D8q@G;u+DhWU^ZU?TzsW(GLQ0<|f=?FhVr*muc zo0`vXimb8JnaMwvV#sbc(*Rx|(ZTXZsyi{%Rfkg0mZf*L%X9oh>MVI}b?s4~^JKjR zL3oU%5N9ei!jHfah(Vau;=!b8#CyR1dq-~%SU?1lpC~>{-G7#{nnv~#y_34EQ<1Mc z-*tH%O$O84VrFxexsS(6h#c5-^Ld0CDz}hy);fedS{=pcg$Dt!GYXXhE9QX&%#N>3 z*8xaj6m|%2zAq;;|3IHoSwl5|mV-Rl18bM)UrgO4alJM=5+xXMYO<}d{X^lMR~F&9 zM1ujmfNKBewM1~1PpagswG?(1vh6hZMmaI^b7+zGff?eu%K`$q?m_rEy<3xHrBISY zw=jQ~%GTjcGj%+E)=bR5kn48eY^572%c9p*4eL}?VALI#tu*5OVxLk?#|cm&Pyf;* zUj4cy_*W*inJjf$`^(koxhczB;W5vzx_N&|m^%Ns5jgOA_dPpHqr^Q<$Edxi4pPgu zjQp8{fye!es}?`L_ZDqVYpmMg&ZFj^x4>5z z*6S@s*#!qzFBKl6qHAl5)ZUMe^!$#V5Qnp7vO{Z8*Pfdw)xpOKgVIqvk8WPmW;F83jG(UDqB z+*JdC$aEjdaj+KV>lhbOxujr;WHxORF(*O2<4Yp{59&1eYx;**7SJ5d)QAhKnW(<0=|7qzV;mPVU2G(E_Cp*+>DfF&5m&1t74{(orSm(Gq zCEQYkGZX#>bwFwebD0tO2gR%&1H>(OV_!C`Z*<@0@S;iw2sZ+Oj(t;_OP14MpE-_d z{)Dx8!dL0%#ti&Y(zbN1|)6_`^*%=xI%BDK#%6R0@=DB}R6-!#{Pg|%iiV$|HIJDH6a}Qla z@(C>B}6v6z}2{zPx4={Zd{ zdSkj1^pR>FeKCNPj2iROycvWO>+rc~0F@>?-UBVhJoAsukC>k)eLS#g3}-ei$d3Q; zq1?n`1QlT&SDWJn8AZyhgjcemkhgEjiZ`Y6LuRM`yRe{WfN)-UqrBaj4K-_`;}@jy zVXT={4J1fV4@z$Eb)@FD8MR5uNLwpg06sNdYitUyuCHzKl#yy_4QimZO?7!yVYJbb zHgmDR{E7^&%FG9`;r@yhhun`W1jefyC~1j?PQt zPY-oQeJKY$XVP-|b=D=(LjlfDta7!VfQqaDAp%iJWENjUb4*IG8t;^5*hiufwg@zt>^86j%hTAIp5;CfEc_N*zl&+cxgd_zD5T8`X`D<&JboxX0d zjRuK-B@<-o?DUGGwuSDrk=7!ytdV3w4v`2h1DcQdTUSOxTdIe6&_!x7jW4SOX@PpQ81#rXpA>Xn$-2>b^YUuD$7pH z!8 zZvOzcb&(oHlMsAWAX9U8uV9z!Nl|<|@@k?JSHX7WTIrL%dM#prgit@pA}FdiS}L?_05C8gy< z16daxIO4SGG$0`=NbbzScT^u@3?@6AFX1n{;shK}0%hA3jh^Z&*77<_H%16*I2m39 zfd_$7xIX(XUuOrXY2#4w5Mh$0u z+E(KMjJ#Tmh<+Rhvuy!GEPK!Vt`Pqws#oB2d-MXmP(HTxw7`wFogJc60FoxIRh(;=LrA3~-_ z$bqd#T#n9e+I1-;r^DDVQJrl&lTU`54Wx%qdgw|thc!?sE>Pd2(%FDdgQ`x)Z=0Y! 
zNz7EdR#fsi;-yu&s1?!Eng!xd@5b;9QU;+XAERYyS^f?*Z22vV9L* zjvd4bNVTGXAkw8PDjh_67wH|P6GE|o#{xgWcV_lpd#yEl=;6u%wb}ms5Hcepe|QdG^h-$}skt`!UM~bs?mWqD zXPWH=xPZnK{q4pfSVc7kE}T{x|KP17*F>*dEpUanU2@)P?O27UTt^)ho2mGJU^dW= zyud%6QyuCS`PZohTST3oAuEq#``1s3QPGZX{_S>@(bRN($EpwZpZ}j^3HdOx5a?<; zvZ2SVmp8OFJQuwyfmz_k1}HEj%}=rr-L6GWNz>-)Wix+8KxU&0fypes6TJ#_(EG$U zs*fZN^~|R*L$s6g!LlEHmB=IVo4`Pfdpau_QdYu(SHCfF-dpOwk~yXz?TtyPp27DP z8A3mwxhIg(I`oqKZ6@7T94}@AeZf_;9+)!s>IUr?jYI8(IoE_sO>mCI6Ls>9RT`aJFtNfHOkPlz90SrGmO9E>9M#$l#A3+o5; z*2Lus>P1@V$g*wJQMAP217SnrMIkK+}L{>mvt8@5l5a^mZK~Fz- z38aS=+ge8db{0YNz^@jplI*e%2GxRBA?ETEimSe8PX3gi8ys5{p;79oW-Qi|MnKg1 z*n@2Nq0OaGyAzsk28`TD?H$>$a0n(I0;NndAtKwZp<$*#YpB9_QDYXynN>FOza z$%)?4B7HV$u@Mbn^6Z1SRjIh0+@aHYWQe7<T6JBSuz$#oz|LtPrsMDhBfNuj@m%)|HxT^Uxit;`L$;%uJn~V!gTG6 z2^tWu5u)WKI}JD~^er!~@jV(Ew!Zr!Kl3IkvD{_aB<}j?NqGVd$IHwv+B^iO0F4Rn z_EbB#M?QN=Sgs#Jp*nKRY3?SBv%4i#Gv-S|=y35vNs$QN-(U-faay%8{j68ldDZ8+ z_kxIifAncdI$A|7L2rCDYkrdgZvdEqxaz?Q*@R5A*#2G_Tw2`m<;kiuMXhi7jh6ue zf^FO)il}UUS{i>6^vldIA@w3Kohsk2aVo{sT^g&buVjOHtZ8Xn4lAo<=fFsL+NZ%i$N{W0ytni^Y%aZ;jGdu_4f2(rV}UfCT+*e4h}}V zr*(1xydNv+FQsD#nkO~ndhAML96iXSU#e>642oP+s3qBui z=Oqd1cb7(GcYjTg&asex!vj8zb(Dj<1N~t7VYT=&J*{j5dJYh0yI}rtz2J}g50;TS zVHrZusrWf5-4Zb}NW%h=0n(YjF+sHNSx*Q*!Jfu_j2uZu;^G^hO9?}XCj;bJ;Tlj{ zUpkun=u=6oxC35klVdiryxnWBvP2u^`m9ygZv)nj8CXegx%m>3JNO_Jolq|0n)e)3 z8*Sw!x97Pr*JGy!{)PBC#7cZ+IeSDra=4=a23T__d}a@CS^n_&bDgW~YWOp?3c?m) z6HCh*r=#PY{K%*M-ym2LhbznuU}``sYw=_MGm7J{!Upzf|5ZlMmK8x-isIeX}-7rJgg?K1#l0TNwW`b4xfJ$U^Cm|x%=4_ z8I8SH@n-Cq^bV!{I9eXWxm{sv7yyjeBG5+DdO-%~^L`9*qS9;2b%Ub2vToX(I)qLJ^5{yN|4n=@ma@2R-Ug|Yw&VpD$>-Z60D=Eg$OP*^-K5N@XejXKc>Q+&)%WTkRJAOU|M6Qqg4OQOmUiu;k{F%Jy>|S`D$^riHmMVLVU8XVbQMTb ze+vm?8tQtb7{XI9V>HQjEI9nS<(!y8R{EZP=CD7P%6+-~U19A`J$XSX$cfv*c(<>H z!i&m7K#z`d8V?=&D~>Zm9QZHwq*r3~gbtu?)0io=DRAL3gHE^SJB|bT(dZ?cYmhzH zax6Lj8UlZjf*WGd&rO}kjC=QJw7w<=$n$y`z5ChwU7|yz<4~^i6+%$I&><+mpQXtW z$Bxt<5NGeJN-;6}CuAK9{Nr)Tsyev2{8X1I*EJFes6C%_ zl6l7Gy_z|IpHrKIUu(9Bf>6WDu9@pfJJ00Mlx%e0RYEd|EOV|iAmEIn4O_xm`{b|3s(_AR_m;BKUBhU;2koq2^{6-Rx3Rr2M!eNF#a6aikOxYIP>B~H z7Xa~$h?-B#eP0|e9Wr;z1j&B|DqUpba@V_U1rJx7m!kYb_6P5_#Blyn8M$0V)A=zR z`9q0rW8XPScNr?-O57dN@e!AN+Erv9EkF8bsKgK>wRYnCS)s$@)oL%}Z8@S3y zs4nW(v^0o3S@t6BH31dPYuB6`Ei-IbaX>*$fq3y!$404E^gI74sc9$cxlZV(yRPfC zM7hd#ij=*`;z!%7OnV5xK`R^w4;{k(+kcO`sq#cfwsI>(Gj+B~O-81k+TsDvlJoH# zy_IO{J#K!8E~zh`UYiyN>eQEHKv^BT!pZ+jYJ32E1H;k^7X< zECsb(iYn4N(`%NHZ-H`$B%ljxJu)7to?EM*3Cas=1GolsA#=wBJXpDpIuvCzy2~6a z1f~9Zen<9kI2%{$fMsG@;QK7(5$yqBSG0e*Rc&xELQDbqTqUVyG0G&%3V3i}um@IU zx zv-2XF&nO$VXZ2I8{gK<) zbdCRM^Q8LEi~aLgpb|@Puh?IriS5-R3l$9NY~zdTuzJ%W%)dvz8aLI%4zl{=<7Xa~ zXYjVf{o;!C;ppS{JBpagvN+{5=!JpF=!?pQ!ai~dM9Ii zhfF}JV{PfYHds>R1>%4msQYwN;U&we{oP&h;E=;3ZjF9!-;4Lx!;_4neAKdx{SB&` zI@&(+3d*y;zJcb#47htCE0?EA41? 
zpfn)1YNlSoYrhEy$fF4(Fj;Zsz*Ws>x%;55##X4RSo3Q^j)O?l$or+9=9D1CZX88% zn;bp!=JgS=Wl-)OZw{Nr-E5CCC!D?mux|y#+*3Ybj5{pp7nOOnWs`LS8!b|Cx5cZy zeh=U0DA)zcK->L*M={ZQU@`BJZMU?Xno@F^hCeyhif<)KT+&BIiqDrw*0Lx)5$#(yToz5UUC)TgVn)K)A9G$DGjM}^UJkENB3 z!-RsHBbl)^K0Bb66*h6m9WSYsXgl`Z2L`%vw}0a63p2%AcG&Yp-`swUZ;$rcV1QE& zF;4<>8%u$k__%(>!N#k=4dQ)PU0hZY@vIE+OB;w*LU&lcguO49B8fI3A| z%bN6S!JVE3+dk73B-riZevR*aWk1#wypsj&u5W+}h&3@xw*h+(QwdERoxaoa2r&y) zE64LYWbloom_(kvqHscN$J$l_{r44JIYYM$NxJ01vT zpn|v+{t&wmSnWBR73%s<^7o81zwSRGoAinCjdFfl_%sxhT}P7yX9hq#I=FfJ?|)d( z^F}9lWsPz{_bwy3^yK*=^T~=zbzIViUagwu_kQ3g*3Y3aAq`c6#=!CMfYgR7yP*UQ zJ=~jChM}srh?~F%%^Mslui5Y=;Awulyt1({`VYub)FtW3!L+7`%|9N8IeOJNOZP_bF57WIxpw{!_!^(vsE?t#_VJeM`$` z-6zyVFi~m?fF!_-qsgl&UWD^z1lUTspD8qh5{a6iUS<-{#qxSb+ji8t3A?em6 zp}!srM1vwp(;(}Qaqms*BRwA}#qkmdp` z`w|0#!}9eD&7-!@ma#Vu^_p;b$W>wC2?}PFCPWbH#iUi!1U0Xh~29bYxLeeT>B!yS~1u+6Mr_2$ljOG(FIUR@#DLr1~CPd4P5 z5q?rH@&1V{BRCU_h6;9}g~AYbNS1nP*VqwBr5l%XD#1k_a4 zi+O=!l>>o?X*4!ujVqE%0i*haYsT3(>nO1%xYJ{O0dqWyH-NZ^HC*a5+iK}ZT>rV( zzQ5e`-Gn}kWSjpk(jrFl8M8J#>!A}Yachq@{j8$o4uvYnK@T0kwFAzihpoNfIw*E& z{0w%XN4NLBsS5Z6$c1HAE4*|YQ8whZmxmXeInIc*TAEc0>jmPvDH|pDt~y7~1bdvX zR1D-7f6C(o-8M6$91-eryP5mux1fyt*b5@Jw&=-kRLXLQ&hQi;GnFQ3eDRBCtD-Xchpk?GJ3#00 z)t+_O26Ua5>{XMKD+b{Lo#93LIa$Babnl+6nHpX3a-n(>7&riX>p{j6nQhgiIoIn7 z0cgcsnGxFlZYPh?BCJ0Q$&k#!B!@rKCR!)yDF%INzuU^f9GW>R`BPN2D(~W)#IMeB zAlOYSOutC;H1W?btqoYm^WTFyTB4U=5?2EHca*fpt`#? z)bG(+7QtqxUAs(a9WSL_snrzE&OP}?#UK8!70^BYbT;MQj^h@~{U_UfPk|jf*1&+W zyc9596T;C4-;XT!svRhUo_bv`5Y9p+wGd)J#x(gca8LVh$pBug4-jn9EM?ZuuQSu8 zN?4nlHOTSu;*Mi4_S^RrM%^M|dAb6SE@bofWv+o)?l)eSCUXYpU)D$iG1N@g1$!-5 z%mIF5eEhWRuXJ%Jj265dM~Paw>WZ;g!&3}f$TS*dJh|!QQUrUIta5;jKl(HQ=H0C% z4TCgb#rfOYL9?PUX}lLQZqQMx%=K>-)-_*jYA5t{3|EL24(>Y*gAo%jPyJf{7<_=e z#87#SYvA((p+fiLXEX#YnoQjU66WXvxFqOr#)Z495!O!(jYW^+>B-^2Aq}4P*0z$q z1^=>SJy=svIJ-Th--Tz)+2nu9Efi4LQDUxSu8q1a-k++StmFD8Zb9P0YsBu}QOEE?PscyAaE+Bn30*e!rrV&i7}9$Vg7 zOyC^`3o-CF{te3a4ReTyHvHj$e0_l`d=$Ye)3-Xh#=>`i+*uv}at*=u933=p|l-)YnF6aIihKl$k;r88$@D_uvQ9Oa@G6x&oUY0*r?fkpRI zgKIyYk5>$55gxhr_a7oSaIwW6Z9m1pySx*^iDBz2_26|PRN$u;e&k)1Uj0_g;=f67 zZ+we?R-{>`o)yiv(F`qQBJTm6Mysgg2_fCGy<_c!YmimzJ9c;V$lHQ>#2O=Vd%>r$RMqsXwTr4r=i~6>B$$^d>Pt~pBlwdI| ziv|&|V>GhGToURv9PD)gltwp3E7w@Dj)nPpKo;!#tJn3$S`U#>*2=2{eyj&#C@~N6 zBXs^BSI}9|M5Z_|Tf>hymb-YWN#PpDc0Z#b zBV+e>bA3All2~QTxd${cSy}e6;dS;@nvy<{=@z-{Pz0wPYD2BoOV6MKB03l;iq4x{ z24kXnMg|F#`s(eGkm>%+1eV)a9UymT**VQDf_&?Y(dBQr=R>FKX(nv9Qc`E-ZITDr z4{3C?`7l_1`=wiUk2!a}K9t0a^#Fff&oqwt4mz}>@K*$`$}!%=+a6#8ZB`C>u#+>v z{2Xa3Tkr9z7lg8Y(<6W;gvmRx_Cq4No#YFrapk2u;?xk@UNitK?ye!%Gz+-f*;@W@E9nodaC%JDfPO$OfuQ*exdq>=By% z0Z;4$%YekY@QF?q!+dH|wPaXsQ5%&?-ST0(b<;pwPVKKdx0%012{ZNdxC^uJx+OJb z)Y}(1E;?*l$D0{hI|||(xYYPEDVDmSx-YGxC!Wx~Jo4oIq4Li+6@PX=@()-a|L(E6 z_W;AsF@H=exnUDoKtfgz&ykP?ZY2v%2B(?W^lrPeCLx1+^U^6Svkk+Q)*%XLUyQ=# z)(UmE%hk4IUMb_vx2u$n^q7S?dqJTh5udgd=l0Eq3wk^Je@NGkUB_3aI5nVb)r$P1^G7^KCo*h!)YNvbzN@OkaqpEf$E}+I<=gvXcE99At$)P!ZBD;ZW4Z46sbxW^Z0>pq zimj^*ySpGA_9n_k;@@2LiFT??t<*@?S;GQt2d?n z34RWUI$SH9Mrh#Nj;|%YO-YFCT6xlQMED|?)=B!L4;i{nb@V||BE59q9$&u1rzAG$ z_69F-Ndvms3_O=TJ;B7fr=R6kf2Pxj#Y)@_!?KZ2XNM?*!a6!Ts6O)~TT%e zs{huS@57^;dzCxO#Sb6zJavVfeRGU*=Ii;s>X;Oxqy%em>ydK*24d@?nsU;p)Cg5u zGuz~aYv!-e$5wOL?dcoB71*U0N-ssFl8(+cFI|aT{v2wIK`&n%smLkHX-n$f7}y66 zPUA5BAt|qyc;ZtsQQ2gcI>>#EC2`7)vE6<1jE;^$`ryz;u#DI4gbK`3La+vw?n6`? 
zatA}X$}OfRz6G_&@)=Cc(L*r^b|}sj3R(nK?RN+5_B0rvC|-N4)=&R!3mQwx29aA!^v>6p;8vwUnF1}d+eNR7}-vwWxYRG zx8+vK?z)liHDcw@EH{o@zgA0|>VF=ANq6qt?cYY3;I9P4^q$Sv-*^YMo%Tq}eEr4f zuEtPLW_ZEIaZ~4c^$cg z_8#5j;gu#Y`hnl%NmZ&ZB`8K@uCo?wOk8WK_i{ze(FG_US0VXLJUw39VB=M7OzT;J zdBNw;(876y=+|WRK8J;`YvdA*$tpb)J(=!Bj~wF)E$C+JD-aFN z>k*Oa5xvsWiBwa3spam7a6pWQYoI40R*p92=55_^sV(}^AvTRO_YBI)zfeamFf#W3 zh+PthRi)H9H6$-=+DMHrA|oP@9mm{xlq_7BHAyUk6^|9o|!z=xtWk z!1oSrzMNgr73%0%VT;SF;FL~2uGL*7BPM=T@j^z1UE9ytI4PoU(qBm3c(iNq$A{;) zs-^gK4#+gDxUx5}u`#9@B_>X59*Y+meV0;cv7_c3rlO?M)#{envtgR&?~tD&?Mu;- zF({sXfUzIHbQ;Fi)qC(my%2q)gwR=O9XDCu&*PLfL-9H?a#2_AL7;P@Wc}EhI?O)l z%B^cm7blCqZWVWVVOMY)};2KJl7O+UV4vp;r8x>O#|0&vS@oQgj*f1H zlufIo8UDFBZuZMtEFV32G5j!jcv5^_O0e=xO0`Gp(agTK0Z%yy-B+6Y8OR&XNka<| zZB{tD8#~sp7gQ2vk0>C1Zn4GVS3EdBSeA#xo*AG#HVD2NU93+0HVWBjIjJ>YFxjD4Mlk`UfO!LQkp*T1RTpTXU#Z&bH!pWUX| zZV-bzdMEo=gF94R46ZU_KDFm|hxFK0@7a5;JIf)&uJy8+(WdCurqP>9&#cjM*&%6z zFoGBXXIka@e3jg!Gsj!WD^nQZe^orAU6#`2%o+R)-2*Aa#%t((+v@HcE?XXd!~34i z_T0=4Fm$QjY)~e}#j7;CQ(tGg52X>f3lUW~1%;_!3nvmJpe^51g%a=4!*v^DnI>b8 zTpFGB#JQ~q0?DsfM{dCmyeBuQh#-ny+3W%l`_O71 zq}wQN`nO+nk3ErpBD@rZFF{z8s&h!D|GaSd6hz9SCiq;7j|&@K!_yI`oU-yK)=S+UY0Ix%zpm67x7oQ@ zBGN;Dvy7NCXq*C{rOfkzjy&B!&X`q^^|^JsSe&XjK~I`E}p1<9nR}Eeyi7t7-`zhZnCuh zNW1v;${h)xyxY$=>S-76!QeiFy_t^-=00Efrg1G}Vx@X?%E!s#f{4Q?T~s{p@{o(t2papI3+t8jcwbtXo9;elW?WfQdgv+J_4>`gsW{aPE|%*eYNY5yswowI z%XZyzto%Yo+SEvalmlX@i*9^7Q|@;%=aJOR7>ZwIzD1X-7Uz~=V)jh!bLvIpdYMMA z&20t;46(x!HPMB14cDzjV_9w+nBGr`nKMo^Ke<0!>#+jbo;4%KY)9|1zf^W=#dAaS z_jMD`Z#^YPc4XKQEsF92KP%JEG*hfm;Rigm?o$C;mc1Jf2tdm)iRBGCGL;OFU*ZmM zud)x+eL4y;(|5U&C<3u^p<0r(xbA7sIH)>t(n_L=_yWe-$}6nf=5-NkAKblij>2GR z3)cIF+r~2p4~@F_a1VVq+U~Fj`Ui$nymTryY}kk0Zp-H3t=O*>DC8B?5?NZ~Yd-r9 ze=+gFRmiAlUO#SsoWpV3xtEl4s&1!Xm>`N|1N9x>5 z#Ndt|oilG;yQ@fwhCz1rdJ~8E%(U6IYRWMfmATKWHl?#0Ra8TGc_h?BRbc(moVy4JRI z4aH3w*>Ht@81W7FXYaS%`Ty=;YOBq|%NvEbVpiW6M%r}3dz6p7SX?~PBXr5&+&R9e z_Z>P-wQAaqs*MsRmk}DFqW(+yKdh{p2NUu;I-;j{@E2ERd*wUX`Sr9dpJ@=GMqFNA zYB<|=?&$&jr;^oHtm|hOS#pB(R<*3G()aA$OAftul?OZg0g3gb@=|yN1nz;FfcLVV zCegFzqgu;7DsRcNh4g6a>Q3vMzguKrx?)4veVT1UQ&Ond-XoxB?N`?N`Xfx7-=oCh z4dwgU19|N=`neZmIeJH*T;Sp3>mCV$>t2hBioz{!Z`}L*LewcT^8Lu-d*?$Y^kdk; z09H}w(gMP1BqG-)Y*xLG5c~DxL)z;@d;Rc=Jk`GsnMC2r%7%;``4mx(O4WXD_!4 zpdg~l7%b)9Wf0D5S7?zfJF=;&__F&$MF6%+f$B`o^`RzI%yr^>1xucr1UwkAdKv13 zg3pAli6iO9Z(9n+E`R;lD7+RFlRp-x(#%#9``O4Z-Qspd8aGc4R|9%fu*DN*OS`N84EHq5%rLJnY3vn+0S&hj> zrM+ERCLB!1m`6O-3+xXbq`P;Ez>g*M9?-DQdsSNa&YqQZYGI+5q}g`^QIupaxfb|& z7pZEx(i1IBz%N0_rTg+Cf@o}(S-IP0vG@qx7xypbH>cma`Xn@>>E#C(2N+l+%-_Ny z!PdM_v&HQC_7t-?^W=!r&~++TNDwyk>shD9rPXR%L6J_tG$nq+3fM=@Dm*i{q}k8B zP0UF-$-hO|OG9N5jp}ExwF0&Ju%3fu^hicAuG8>YS`DPMP{V7b<}OaA=w5Tz+*A*@ zRiA`@RPLyMbc{ChlS{;_iP`hiwI3QvDuRDqVVq!+q)&NGI-U!m_gz!ZZTPssmAUOo zN-9M%uj($yd(7pQk@6LyS)M!)5-MU37R_j9wmQfJR8$f+uPcIYOO0lhGujf)F{8!@54XdBt9)FYa z*7*~Lg5_sh2Qq!G@qa2MKrxKoXmt_xpwKmX++gOvM*YO$J7X6%+y3^;USR}wh<^?` z?4t@f%uKW*511z$l0qU@vuC*{ zG(sZMa7wT-GF_MIkkEH7GCXon=%$z|rhzybf9RP*Z^D*&(4($Un{N6i!WxGFU z)Ce1O7Ls+0y}y!3?n*CoUJM|`DVCFfAYrE~T10y?r z_}DLder_$Oj4|8Yydnm)m$}x893pePfYmpkF_`@#3#%UHv%Fy-L(t%pH=5kgR=&Hf@ko?bx7*JLH(7Q8=*!XaQ{eNRG70ND58t0nT)JG_3!WcW?sZS} z_5K>_dzoOlHr(MbIvs7*XHX+wN;q;$Dtsx`WW-$ceu;_gMb%s-=4*(O8XyTOhM2-h)f^B==Is8IpIfj0yF2p8d9WTXH{huy!F;|yBYMqj*8+-G z*6Z9(ojxy!y4{&h=2h&ryvw%-ev z)4DzI=~1zumknurqkY{n(E}}Jlcb&~1&4ff!;xoKr5C>m@$Z$b_W0G9S_J1ytbHZG z7CCmGnQY?n&m}kVt=kUG!x5>qQuA3YkCc^FR+^@-^%PkhQHvEfjm1ba2=C&*6BHCP z6`XA$vKoB6xn24rb>>HB=Y{8#g?SBH!jPVYq;N@zPu?&Ku(5~IixrB`qqPsRKJLaX z9gK$aZQZ!}ITy5xMnRF#8n;A70#2b!i&`=qY=Khe{5Y)VURFAGQnJ^V&oH=Dpr>C# 
zaFD#fAS0v|i(j+t)+3NgHcE;-C$3z~OAo(and9)#hQ5NLyl=nhgsYzdT0)ZR9Jjhl z4fRP)>-v{Uu8DBLOHhA*b-`r8JWO9!&s!221b_PX@4Hz99|xjQ?NQtwEX5g+82Ub; z;>27g=F6X@iSv0A?Z$~5nmb=E;L0xR5uC-Dp16;{t+6MZQ?4b)j%ao`!nOF4tBGjm zmAf&UP54b0!C2w#wnHp#LpnI1kq!s&=(YtMz2aaK`nxFVfm+(0{Y=68Huk*D&!;sA z7tqouJ^P|z{p)7?d=yrCnFy1%UoB5O<5^s zIcGb`&XI=hQ{R4tJQTOIIz-_vyLYMHnmkl#284?a;LXUAi+9{@^II?_T9CG!8xJ2n z)b&Hpx!WsgJGl-KmNhy5XkW;FZYW13+?~wE$|q3c7``S^AAdd> zqbz*Z;E{QJ3o6i;FaN%0*(f?4&4T*$;uxC5Hm<2ph_cvZ!OEn0sD$88p;u%<}G#|CQ%NA2#c}B%UDG24<%)?n^_so z#kljo7K9hw+qk0TZJA{<slqcmVnyQ1PNSpqwVFBmpDgv98tP13%p`DH40jt8M|I_V`DX2TM1rEG_IIiG%GhBW@LkA7X8Qs!T z>S~I}EOGFP)(xn^vfPV}F-wkju2!KgeL1mj|NhGSQI`NSOM1#oYSbSI0Rr-lh2dSSfw^Y=2xsprDPn$4)oB!n3x{f z%Qlph$lb(p!!K)e=NCtF_acK#(Qp!U;fYzhd_5(%j?^21HY+F_e%k-umcFVe$=6-O z(|0a+uWVBM=6j6-7txrhcb5}8A?!SV9&A7Vd1rbkH{_$Fv9k&0xyx3aYa4;7I=hdt zYSOi;YOBukW>VqxP8R9M<>3MGS=Oy_*(2Y>D86-fo%+iMSRNY+4Wez77xx(;MjoDG z_KpSeK9<;p3G%kfdOhxf!M7Uqs{^6c02}k^9fTuIfo=v)z_ z`zy$OQ(o%xO0BRKv*HuwNpo{NxQ4sFY-i6vQ4aX6{$g)67tpi&(=zSlDwpQ6fzCu% zs>5uA+>`$$m4*aK61hYYp9^VqLZ&AK5L($<9c{{ZpW5L`z}VDX&Qx05^56H@es7OA z@KaRZjn%Tm=WcPTYI(4WY_m537Q1rNb>;9)k7ZAnJ~j-$ltB`%X3pNt)M?g4IWBA` z{7?6HSIz%9yrKNyOw#~SK)Co*RzJ&aeJ21zA^?8m&7hbGKXFpndRcMk(b{%!$5%dr zOLs@fO=pL!qi*iP{*I~&>0#{oeZ%V^=lUOS9N_G|F;m=QY~9Di+Fh8aixX;*U03X? z_}qxtk&@SZcRW1pz-(=wt1(P0y)~QpWM1-YdkA$YduSoN=i{CEeeo`>Z~NeMB7Lri zCW_%obpSx7I%6+u>+Bv!U0cN$?FF2Gefms&z6e;a;L=yg`7pa%a7i%DXQXc^E31BD zb!o++=*TePwztQr4)1C8*!|7r`ynlO2M{l2W-SyITaJ|O{7$tj!8y%5@{-vBxA+up zSv)NKpffw`=YWB#wHDgH{`(N#rM6}A`|E%I3B)06ngr$Xnr|)DfCcs?M={YMU8=SV zV7rP!AI5XWHfD^lnkma8%7w%6(_rS#s6=ZaOxKlCf3yzBwVyifNpF9v!zODe$ZREP z&^vjNh?6^44OWAo3?+xQ4tR|`X&cb&uu;?E3he`MNQYz%tVpLo! zEf}QCQ3z&Zt5qFukFBUGm6oa~tDuoK847(Lghzd2Q1PlO+%O()OB*p_DgNxw|>%Fc1WOLkxB z$LPr8oPOH54mS?pV;A%e>`rdCPrYkuWgAfIgc-Vyp%yoi9!dL{{Vom4%eVX}uR-SW z04nyPUHw*ARFcWay#A~4yc*-^6c3g51F{)umzpepCQXToM#)k?f>vU#F-BGm!fA6-^baY2=2fz4YYHdu8>GEvs&&8EN`HkSxOwwhgrw2^!Y$XrpS9P4I z^n8R+bq{F)6GPX19bH|@FG}gnkDd*bJt7X8HUqw>-+ zeWH;Iu3wIq5nprM7)&ipJ^zM!&=blo4I+D2QF>#e^;}aY6haH@>))OgD052U$d4d^ z3&yJQhjGjJ4|dcnPzVTP*mrzAcs2M<{L1I;HLQQ>+ezQ5dWWEiN7g2&vpG72>OHnTlku@-knrN1}iDw*SF0AP@Ye|Uu*z@ z3C+R_?fy~{DkL$c4r|{4BOs9K_wK!|Z7D$CzP1{jA8}cOn@8}YdLJjOs`iP8td2%_ zoTs(*MjqGid$V&g;#9uD5t}Qr7&$`0aK;b1&Dy7-FaPd%yRf$hI80zDcKFXQ$-&2V zMta7zAdiErL!&b1zH=NtNH_1*^!tSZ_Xq!XQL~F3z@?>Ym8+`KjI1g;*J9Hto9$5n=dNu4yx63{a(M>^5LG?%%4pMa!&s<+pWbaA zR)Qk5p{s3?+gn%j`;vEgr~Y00=N4eG2}|-cqmXe|peE5J&hq0uP}< zl(n4SCBK=waz=)|R>ruqtSp60dsk-W$+C)HGm`g6Zbs!l1*J{V2RI-=UI$3*oSNv|}CUc_peWe*fr_?*o6@9O|7f7jjrK2GOeJJ!ar zi(QENQFCfc@JQXEA7k^U+Vu6-&6o7p?a*e!03b*I_) zaRonr&Vr+nGvxAK8o!&8FnLzBCpvR%Zl|3?UNRkVKcTM!oD6`D=#h}Nl4&#YjRULd zv`XpPkNC|cvX>D(nZV^yEZMM{NNybIZtVYet1GFZogFk}0johh47g;` zXHOd%ZcEas($OF4{0A=-G~`cQDR=!fpTf~vAF)t7{He2}X2L7guPC&Sy2Ap9n}y(s@I2Y9E*-?RSyHp6^*z9yB=U10o9)4USl?p zU?6eilL1a$awq>R_V8QHR+ynM)hsYbuTN_6lXAF$otu5X!PJU)6lC(k5>{VQk728* z{$gQfqSk66TV|skAt}z|21b^QFB2(|1u_s2SVHLL@XM&+V+JmKeq)^c4%bBDMds@F z@e>Kho&`U<@*CFmUeJ&l6P_0WJz%Knzh@nE6OYQ+`E%_0zdtKK^`nByq4vS^71{(K zAdN!#dP`np)$2lUfm7tL_ofMEzKM1vm$<~)#drj7L0$K0zhiV{^mHg@t}idv>q?D^ z&GX=AisA4PJ@P&&0_4;n|M(&J`qB0rYQgp>%q;8zr_abWdqPdV#}*_3NiTBgkvkvU z$2=hI${02FzYq1DEjs|iI(DxCwte14agNCiz~}Z*KmWm zwxk_fTj#2swczSXcwU&)!t1+7Ided^AD)MXx;Faird&9UiCT#{z_H>rMOz+E zALV3eFJ&ZAO%uYS(|{CMn-Y0k%HGjjTX5e;={e|6rv?h&NO+QxA(w`ySuM$ETi-2R zS@%hjotU$#u=T6BDi3f-G5ks0ljTW7z=D_3|1+Ax|J!PwNe(KJi7l%P6J~2AyKE8F zj&}+iuAHTh=8|l_JMtF?--aPGtP`laC7+Khg&|(~zBfSg!GmN4H*c7pdFCl1@(j68 z&pH8O#sZK+)aYsqZ(el`5QaN&|Vmr-4j5PnJt1gWv9i(N89FuA 
zw4~6}EW-#8k!GIg=BmZ}rJ796LbTYDrfsua?`T_Tn0unMjo~>f@HzK7;Kyu%+J4;Z z`~I8Y&@g?Y3D%t%zD+vh(1FXKxc;2GRGFN^le;{PqzdHO0F*(Su&@4&YXNm~g}7ru zS$Cg-@@hNMC<~NK+|rN?*bi*HrNJKZf6`!&sAw0V<^KtJf$hIjb6`@!<9UYie|{7g zS#cTP;oNz?pKxinTmn5@N^Z%cukkL>Ph-K4gJ=1$IF8D`#_wHR;KFQ!v-yAmc zL_WbJ6QDUqz&h2N4LmylFS^GVsGMrGWaql_wkI-4{y)uQRWp1ITGGD##pgeE8o4f; zR7t1eaY>ZRFd*?!uE~p36i_GA6%}E2z0=p&_P>)bs=cvB7+c*@xN~>14ez{qWn>u! zj|WyP!fmxe8vAabc((KipU0`?GcW?Q!mM+6C6oJPqe?h_+l}ZX#*3V>4Tyn^8t7YU zK5i}l>DTh*%iJk0OsEF^?FmEqtkJ+_?euo$@m^*cPXz@`J{g=#Am_}?jaQ6_`j^zZ z%8nR~PaLkJG^Br{SQ-P^@pSotoi4-XRHc30SW)4LjD45?iwDjQ`X3?+4`(Ee!bOrC1^R;k?*d=(Y}z8eMsG0nf;6-a(E3@F@au+E{-60TM_~DM{Z$GY%hE#kmtEh6vTz)@7*xv;1xP9}t zh02|^KH$1nZ8t+raO*Tv@t{E+^UYSu=cptL%=JsOobA!uAxzpXPEY}@ki(?WQg9077wl?6tl`!yA~?AnA}>8h zzgab6<=$K(H^02;-o4j^j0v7eI=MctGo{!uhATTw(vhe^c^{nCg@b6(IeEBC##Nxh z3q9kUUH2{1Gfx7`PjIq^60XWV7uVO&}H1jp9$* zgkdh$m9f5p&{wiwRiz6&RHr%4Owavi+x+KCf&ZBg8EWx40kzsr8_v0C7kmM`GHGsG z*6<%LAi0pT(r30Z5JRFvnwy*Nz)47c84~Y{4MBkZ+`_-4l=T28A@G7~1A*u}vV+hu zFvxB;^IzcUFEWI0(TczVSpk8Lc|&u}!dvo&6uwi}0SCa`;)a^!kT|li<7|6WkQRg+!izU+{Or z>63jouO~=)vwf!9HD8^!b0|2l>BnnF)UCe)n)%%O19Gx@as{%1CSbuSZIgD0Z(pgc zMk9kgwN?KL3YF0n(iSaHa7@^c7ide#GTez?QT*wvUe5y)$c;ch`G zvwLJ|Mh$wRmCLS}f~}c7d_JD(b#F(5#wgrXq3xpoS)<7XzuVJ)X4nyTDmpi_lr)F> zjb#H_uzuT7DMOVyydo~cKL}R4<*KtCI#PLqf`V2!zMeUKMiVvG2MU}0TaSU9uuZpf zMAIZhSqe9>5Gg}YkpjEb5eU{Nr>6b_FM>X|n^J6xMeWe1R?1S8JjK3inZSEK2VI-F zYgfz%AQnB5jf*eYRwW+`2(_D2N2L{{j&4fnXisaJ#KeJ3W6|r-8}-(=t>64TsKu?a zNxlSrM;A&&L}W+VtGY1|hoj(&hCzw@1mx;dT|Pq#WPVlJtz5?ZM=U?>3O}ItBlncM z8(g(XfG_@Y{q*~Om$~aV2{P66WIkEP?oJEYY$V=u=bx?syLQ#Wku+u^b8`Vg%dN4} zEpQul1>U^~i0`*y0(d0hG3=D{Q0;*B2U&8L{x=4D<#YL!3) zac&ohjkA^W&B1wvb~zLi3Ri8cEqTGB0(8Ydf`yFcgs1VYH-yes<*NWMqT&NeOP|We z@d9SKOaclwYU6e+(CSD*Q+W$k{{elXK`RVIfq9O|T7o%n$0)#xu*4k* zG|_7`f=o4C2hW_0K}g(OY~W{tI2~htOMeDIcx4hv=&(P)&zMUvbgXhR9{7${^ zVysW_gAPbmYH@POyLUA*N>QUtI*8zl(N&(ijp=9u1wK3tyvI6^Mh{j0f9!pSS5w&+ zE@M}kj)e|7_690Vx{8XZfC@;FfQWSIoe+WztYiiY(m@22DkZcCs3?dC0f7WUml7a^ zfDi(Kgtsrk%s9XI{(!fZS!))V#N6E6d(J+4fBXBsGe+JH6iX|N8uS@S867X-btWrs z=!3$i4LqVb-Z6=MPI@<$1&0q`%Hx%&23gptNq`O|HI8kw(-AA>UpbvRqWsE(i3cHyI(skolWNVYrB?vNk+?y^2@H(u4baO#~ zUq8oh7+N~PZEfayk>t)hS&K!+cYw3e+!e>wATL` zHl!NsG-dNSp0XW{gB2Hgqeeiphu$gh({))|H6?qsVO}Y{ae=OzMDt&1GMlh{qorr4 zBOGYtr_Nkvkdjh!PA@bGt(`wBys;=wsz?J=oD_MjbE#1%ou(j;2Tmi5AcR*oID>)d z%rZN>r4~9K$qon7lz4SccBje$$Oq< zkBws|G|?W4DArR(ptTnhVd71#jr0pSrg;odUmpH*Fcvh8Q??-#kL}w&7O5#hbeHrC ze9Fap;x9^1H&W*6zZ2!J3|}V1*WIrI9mWjoNjz!{oa%4>A4HES;(BKX6z(t}(4}jn z@F@%Z{tU=<{C(9OVLb{#j8kB2;|)z*&1Op-xQ7!!QHCw-u0gh+0;6Z5ycosqv05J; zy0c39h$~#D4s5qo%DWr*^Uj%d)jraP8r7~iM8be`fc``ZIfwsfMketR+!@SdzbQgb z@`~BM5lN{IDOx*HG<)bH*tC4hVAH){Gzk1NZDWSX?XxG_J)gj&N0A6HC-D;Gt@*X@ zqVp4S+As9b5KT4w*+5AbGIO<<)H zKX4|Ih?4>Ia{!Y^CQDC;vb<_2sWdU*Awd`JRqA&vzq}6&s0a508W+~@R@O|v;U&CZ z)5ZDDdyqwZ_AA=>x=fJfWX){;`}T)#UWgT=iBLhLwY6zdzPC?fkGpY z0^tSZFze7V+~GCsXocyMNt1(p02#}j)7T^~oA(I4!*|1kNkgG_BHr68_{(a(E%omB&m8ee0KFm4@6EgUV-DOI z)tu4+8o;(9z6j^DpKCU*5m^0w`dQ75DEVisF)qDGU91|yqHLe-kabOfnvIoV+)Nq- z3uLJnY-ZwZ7jJjr+9{#)^Zzq-z9CtMS>eb#@m3Q@%a!3mE9a?HzIK@UjC;%iWZ88? 
zRF>_aHu09T43eNAe=FdNKo9nzh}wnw9cJm!>QGUYhKkDyt9ZfcB+dYG-LmQxbwL}Y zR4naSgPrBHgTQaIX=APAm<67!xfwP1eIVbVsXt(IO03O+kKykAd(jtZ{vXX}9oC*i9Fd%w{`l?)9WgeF8lFGcS3czDdTATM|iXe$PL zvwtsPe!AdTcNsZJBUT*&|1Uyq@C4#i;%E^MJ9i%G&yw0 zx=;w{AWZ;dbmP>~U%$tU0MaOBAs|-$ojKu7A|H%Y9BS&U}Y&7ANLmt1ereD1|28IhF{&w%7~5aSPJjM>{zSw#R`D zP`;^g8WT^8sn&HS0q0Sh_O#B>gL|l<&N&^HZioe}q3bgYL#UN&qm>e8%a`AS{Rj}| ziw^Cg_ZAnPuDe;K&Ao>{eT^`GRjYZxG>`8qs2TI~TYc^z+)97d$qJb=*;_UO#Y1a= zjk6_=QJVdA?N<|hQmr_m!~Ax8?}pz=7i33fZlY@P-$ggA`5v0Onej4#STBl{Kti#` zcB>lltfhB1eBK`C06H>xPVHT1G9^D5$uN(0oUe%;d)%NDtK$4Dxl`#Ko6;coX85%< ztSzPOuXTU}o=l+5YA+2*@(-}!k=qB@#2NwKoE1|WU_Q5*b;}*thpwKsJqFZHW+Cvc z%MU}!rNV!HkD`(Ajc|Zt#e@Aa{HS|1RX}losJo=&Ac$5XA8{Arw5AE5U!Z?T_Phd39|TmC;J@mVXj%U^kwc zL89uzP}Z`HfNXky9P6+4(tHU-dXjM|GU0A+rJn@b=us|PqC{aRIqil2RLp)@+g|g2 z!)I)vMU_Zm10}2X9=heBasg;-?`N-Mkq$u4BUQ#ODt~R?dN9djOItqiLD-et2KfTI z&xD^1%po>QQ3;jh(1ezOc>pw6YQ0A_U=BEYG3k*rp-WC`pgl5}nBGs;H8Qf%$8oVn{Bv< z!;SkW&Cfrx0~wd_Uy*_;E}-yOl)`KjVxb`#1Z@g{jjwVLQd8NP;>+Wq4+mc+*x%eB z)^HxENwDgU8fWc&aD6Y3EuVn;F@-Fw`X(-+1o9;#YO7>ivswo`_v4j%34O`Bz2hC; zq7(MjW*3tEa9q@!X)(ptV;{)SA{5oW9RemTOe~Oe%>aW|QX(#%uRK~JPGp-v}hF1-2wQqWxF^mR4pqkyrmu+ zQxB$XpP)Zqr%x|$1R@QSOfV;dh}QwALI^nq)e$XNhd+Ul2S zns*XgLh}kg8kc|Z%bam_KW4yv7z)xxWE_c341R+3kX@O`m>%wU42mzDR_C}~Nh*Ai z-7@Gi7C|Q_ZW@ zI9(R}G}hQX8P(L~N6I&Z|L(X-lv1W<1caGp30tvPVR*$?4c1!60sa2uYe8hbD!yP?9M_u2_x^6($wK0!UD_X;srm1rA}en%1>SJOFe7&(EUEI$fqz zKY#gj@puS{M0M6tY87b4j68f6nGC=>$qiEE34P}T0h@l2SRV*YA`ZGkolV7=BcxVR z>Iui;8<(}t>gYpu$%k?5nErgu#mg8R)@$9TFz>D?ISMrjL(UJ<%U3AsuJU3G+%g0j zKx)PtHgp z3IDpG%+4T8L?Q##0Jne;4*UClYRp?y_KC7%ij|@w}bchu1$V?-EM#>M; z9Er_ra;G}J7D3NIn3%)sUdmSAKke7{cd&(mNsmBpX1HrYL z^2~Gmux3y{_yb7Z#yD3sJKI@+rY-@M>pW;yM}2GgD!3Fl&5fYcDY-zrUDn7qYerDk zR6HxOS6jDSCy{cZkWU@O$@22XfRd-Rdo~u!v_x70M*SsD^cjG~+tnQ-wP?|JSuGG9 zOmxZ0S|5r?=d3%Q!G!0`w3m_`lQpaY}@dEJ z!?Mq{Vo0ONl;t^yU4}XR^{$kjCAhj10#(wWnZhQUXKJbmD10i!fX?L2oMNyzjDDy zLMZ^(Xo#xVr)bp7jL;Ka`Ze*&AF6A5Cgki$-dRlnPYc%ri@=NcqZCawACZ?9Sm`Z( zC=Y4&baUN;ZL^`agVYoX!3*t~cNA|a+t^&C1*0>wTQ$WrC|7(kRQfw#AZk8tL{?ie z0~!dUs_ig#Pqxc_p_L2{m6%BpLnBi)S4q^!g3JIws-b7i?_LK&rAU(-9tXN0`0=0B z$>n9pMMUm`e}PQRQjf=5E_*>52++l+5f86QGVpRBDf+-@sde%A>TE-o2vlcO$>dzSb+) z)|FWd!8WOXf0R9ZQ@Ijz3*0n$gK@}2NJG%!MsXpuhi!zCzpv!qAkEto`PN$FOg_lVbzB ztTU#smfYjj`*bC{aY%G`$$Eaed|W`r=g$^qvEX6DXBmt1Dqt<8RWP2T+DIAOo7`VZuP&Im2Ht{vOZ-})M8}{`jt|J@-t-k2{ zTTrbl^Lvp5@)5rcL@xkfM0;#FWHf#aLIW_}B+Wi5ey*V&^L6BAK^uVFbKLt{PrV!5+hE%p z+d{~>{7K{6PUS_?C1n)n%v@+8!3=7HEg#)p@S)GND)I#%TBxNfQ^1@BkluF5)YSAG z-f^!@6T4#sV@lymgSp_oF^%?nqX3|9gWtt0iyD?Mq=V&eyC4HAw>=F&a!5k3JWx3u zKN$2JBq&m_!3NE3h#;tKSaUp7=%)%ALu zSUkT>8n<#0i;?dlN}yQMiaY9M$Z3n6teaY~ZUR6^#%099C+fhbI-l>xcPp(rnP20g zGt1y^C8k(d52(Ch-6^-Cc@PhWL8~ZM^HEIZeY#KI7%Q2;w^INVQ$X49nl77Az7DLH|@ z$2d0`T0h8;#Sko`w%OB%_#Q^dQ7rG+>L9)n-=9?{wo;gg=DPY|$OG}qg^U@N`1gU> zC3eq{w#uq2Z8MxeLRV~LmI|-lS;^)a(E+>mSsm2=mjajGZK-jT z_YU;F4r^}hwA=d|av+`z|4o}Y_^C8eU4jUUG=FfafU~C;mgPdkbHzIexi37rjbNmC zxsr>BRG0>!=fq1wThtn(6UKfpaGqN1RPNw(^E6Ep)>UZEAzl&~1|;mB-zpy|M_zh~ z^!&c$ntI)dCv>T}fNab2a^)6@79WEX%PS?!TlKAXes;h(gZwP`PmR_i2esb4_XuPe z07cl$we5A8tG$xGGvACE%RCrC5v;pxp%9VQ*m_Mi=N zC_|wBY_$r0AQiuHTq8eUFK$l(ujcGj6OrFZPh;j0d#CV?i$-AIFk=2BT(tazA~163 za9<1U_6W6$aSD66+UUbXaJ-_GNuq8*Sp8c|cVIBCF~elDJ%^MNI5; zEC3a>dUS%qwVTh?Ff4nd zmyt=yxH5gPKR~HxL0hu^4N+>`VbKZsZBqBQ=`72?W8{w~N0u5%OesMkkHo#cY*0WV z=|#yQhEqGIGQ_L(Lkqi~mI0NkG*L{oO?bulw3Y?*{?1f*Yob3j8HrLr*4x~~{!}s! 
z-dpS&8BtD~?N7%Y#}Vn66FxK|98Y5ZE>J%V1;v_^2Sx&y*qL1`zl-1~b?Q%zuG+!q zdT{10o*NzU*M2BFyO~lXy1O5$%55r65u|Va7+t6LN*+E`u^SNTAjwS8dqRH2E#9SQ z)QI_0^ozn4oV=#F;|fEQ8HE@|nxD)1881Bdn%ERuKXc3@9WrRI%+^WKb@fgG#nQ2W z-RA^jI{@dY;SBtSG-pN%z0Sx!20chFZNw}N_VqSzyf-c72>6PZ0A}xQ1u+T$ellN{ zkZw3V+a*P<=z5QZ0IOMg>Qd0ACCy6%uTMt%K$!@>Y^`Uc-9% z^4*|msm}*LpaTUD9~!tePm3a|e03_q63Qd*uQ>-$s?fC<-C5tSbt7$R6+Ip*JG>9% z%>%_q{httXXjfIDn*^pHRWSCJ2U%esJEuRqY3L1<3%;1p#3Q;PR)VyVN#R!3S--m+ zRt~#OHuU(%ta@%8V0q6s4g9I4IIS_fLRGv7FS6HTpSU!nwlD$LEjrPVR$(nbMDi| zrDzcf-0*gdp4}UJB}z6sUbImlVf{QW!8Q{TBF=W`3m$gHXTq}zU_G?k$vRuOnZ^^} zy-Hy9p=BI#-~z6i%I071zA3gj)8lah3Z^O`EbgSW%EBAhpHJI-&<9}s^qI|)t{b|x zX}hT9G*T0y@YZ4-21fb6ImuIDD!tCs1`NErgQV&Lh)2kF#)hnE9Rp>Gd+oVxy21StBpYiXDZ*2+D3Kpg0zst$*{wPuGKUpn2 z)I)f(HUHhC*mOM%@8so9ie%65O!R(>FfpFDm13JM2--d>R#htn3^$e6una-aOUf$G`@IZFd# zr{RQ+8oPNCrP4T3xKzn6c(Q3wQ#@4B{aPd`v-Z!T-hinGMK8|g+Ny3JG&^OtV1M+- zbF&bGriucw2iMK3h30UT)L4bM$*E?Ce9PGUYfH~xx>cwyd8Fs@Vs$M9K82u}*ZvyI zm^7JzqFk9@j}MG9V7+&INrj@7bnu#HZS-xymK7p?X$Q)tZe(79oBLe(q%qYDM9Y2S z>WDX(+325Dq3z1h0%(1`sfotm&9~wd&g}%a>JM$@*FAsYHLTBtbr(i%xd57E#DC0D z1@KQ&+zf6|pDOC4X=O(T#0V=AWrr{;p1$u#Tw&Z91NZ@|mqF({3@bI&RZO%*}o87lP0s;{HK z450%B!5TSD9G%wD>Bx%hsM zo!oIPU%j7Jt@enZ(n&uEqGf_w1<0cZ^>lMeT~d)YJcKy|xRz|(U{6!-qN%#;;38+} zpOI8}1Z4H+K)g=ksRK>xIe_>v%#VuR@8ZDc0R?6fUn~z;1JML@Dl!U6d^5A;Fue?o zkXv|9QAr9|v6vfAJnL5<%~|t(SknN(m_$E`NVJ-hI}V@5WS-6D{W!6)byUw~1GjhP z90A1qS>b+5?3$Ig`^e?x7W5h~#XAJsZJ(qKxktkl$f$HdM%dio5;+9=wsZ)>G=B)tkAn4d37b9ANpDWb#)VRM`FcSH3J~5NZ2*zT zqLU`FFZUpg8YNhjHHBM2_69~%snAwjt1A-z`z2~53=YD?P;Z0$*b|_}46{>VWI%fX z3=kO4u#X@uLB$$*KR-JH@m82JvP6~k{6JY@?e((4JVR{%;Wd5xzQ|mJqEf!gdw~pM zzbgmgS!R|^yt}5pGuH(0n&J9v@If?~TE1#{yjP^>^zoIfGT-k0BK-+2uawFgMF2!h zG|;I_d1oLM$z`m7Y(kUZk8291`{i4q%g^(qW}ya-c=%iZl8gE&wJ!y2W|qw^rERiJ zohq0uz#$H?tZE{^63}^Co52KFsF)MFxD*u1$*0fV_~%~#cxHA5gwBranDyX+9>>pt zaE^gkQ0@VEA(Dtvq!E*!uO`sR&GC*E3=>l#dA&7-0TNj@aInT=;+f^OONc$e=Qjn+ zN=I7nC?zS@RQZ>H0Ellgkg?2y z%L#CmFJ8t8R)yveJIb;r{y+q|X16kAx!;EO*67%+;qqCS<|4x74?wrWONzPLz|D;1 zOsvxZ>en~_!-DHNpl0X#lWEBM0cd#jh?j@N%{|>MvK3Fj476R+3QTzbJ9He;65Iu( zK7y?ye(3FcB-4lz++Po^Agm0|2-@1CQ_Y)WGx$#$_RX2s(OLo-V>X=z8Y;+GtG2ZT zi2?#@_3Qf8cmFeG?f(Xo|8*Zvm0Wm)H0DgfpAOFm*-wxIsk;D93i=oG>(y4=$01 zT?wS7(B4@)yVez%7t=)gO<@fH)OO3K&n9neYQ~{4oCQUh7w}SCOCo6|mk{RHitSC7 z+tASXLfYUYZ}x-pl37{vX{OqimNK6=mSV#c>{%&XZ7T#*G8sTh1mz{D)jrx>0?5en zkmZnazITl!Q2Dw-p%4V*c^Ay7g?fYY+BqCVnYk+Y;HLz0IUV}<&Ki!O*EFF;N`G{{ z9E%V0L{cW&<`D&^F^$%2T<(q5ZKK?Mt^0+SJXk=^G_M)C-r7<^fhL(+3ia_)> z!aqB^0|;Wf!6c=D5gniNwSiJ%mDPc4&3wcBbhGPK$VRvhw&a;m4 zm_05QJe?Pe0Shk$mb02v^6);M9K`Pcf)wZw_G?p0>~ZY}!GjKMSP2Eo_^#nx2wOr! zO{q@J%P*AyGLEQ295*9!ou`13DiUM9J=%Q62UcnRcXuIBKttGf!&Mt}9I)_!2LnA| zaQz{9sgLz&OVsSoD;?Q)IoIY-YcJ!gHJ5A6%Ztf=T!e~0(}_9vT#sA;zJaiVDNFHH z=QDMDaY24}>O98XIu4GR0$D!EZNif+`wAmkk!G@Za92xCyHS|!_2^b^NkbxjKqDEq z<+xXCna>`W&A8ec&{#uxp(^mUJZU%hiU1;!2I(d6m%X-8wdfsrm7_{qwP7K5UWN6{ z%ewJ{n%b%FjzD(f>R6y_3ol#7P1dD`M>(>nVZg6OU^BP~5DvdW9z>ZK^YOlZ8$3o& z;YwLTakVW*#KhuT3ZwvYiwD@?S5&l)xZ!UtSyY2!lR^v*xx zEz{=aY!~;AU3ihlbv4*u0eo;i<=}HG`qHzyk}_`ovoE45O|S~#j|Y#`7QKE7;1~i^ zuqpiDgFHS9=|CNTUBzIvMmKNGMO(yF2(0#y>hIH=ne1f)OEB;&lDwB-(V6mj`tFiE z_G$fY5m7nx8{1U2`*`{Vf=d(9xS#_QP<72L+kbk8CK|G9A=Srp98TvPlt`}=;k`UA zN5@OJUIpOxL3bT-AbAMih=RPl5%YTdGP&`5Ni9)ERu5gOGFLRP?^DxNhbx6C`9~i? 
z$WI$@n=hZ7SRSZ(JJbS%!{^EpGY_eUiHkJnf=fX@gFB|_T(hWxy-syQfc!h5LD%eZ}ANfwMGgC~L0r)Js!4Z%EwNCqCwNHFf9b$_)1wJ-cG zgwhVE2HkwlZB9>`2!PlyW4D>XwemT4A1c}YO=R5GSP(%Y)^L$>ROND6Gi8=|mQ(!8 z#A*3u(5%6RcQq1Ih`yd}RaPm|#ya`v=cVAwbGWRxTnus)JXDYn!34_ZfS}vU_7sEqHRiOZ$ll_7Pc2_R>7~U4W33;DT6UUQ=;Bn&~2Tst9@jn4*2< zgJZDzmIP4$(Ch@3Z82YLx(F$fS!0f;apw0-$3y4`JC^3-4t#2BxBK=1hZ{T;=O=fOv?5+SY8x;;67cb#no0*P%GOy}j z$qlgr`oq^Z8YRzx`w(Io0v&8Ro858?18$R4PYcKi2JkJ6=(Nt8Yw@rUK(oqb01fNA zE-fq54YBwgcVhIo8p%1snGjrUs;0#Mx+>%V+2EJXBn!`Lnqlo~d($8U(3jk~gy&Dg zNNhRk9Y+kC2{JNL_LS#-zgKdJMRwl?Y~bYBP^8kem2Wfz`_%*tG~{u$)iZk6$z=CU z7wNj-Wi8~XbFI7xQ)Dascr^=`#^f%}W4XZpjq+XwfBH1C$IFbIUNR4_WKmn0h}qSI zFW`gVtDBvV5bu!&?ssg!HM*DvT@;G{aPod>)r9Ebo56mChagutJrV&X92*Qw^{4Hr z*FA4$7Qe5{G$~X0czb7?xfsvTAK9JLN4`Z~zXh->KzB^L2hcntxXNht5Uz-snwYyw zKX(6z8UR4oDQM~s)`UE@YrkMM8c%rpnGHI>3qZLRVM8(8W^Mc`3-OTxdVg>!;4u?m zP?86uL;~&+q~*vRg$g^G_dLFcICAVlYsE=Ztn?Qdg=pHIJXz`Oq;5ZX%_FquHZRk& zD)JCnhV-+XABvLtM6#L-iV8FVzi*AFYBT}XB#QrE&f~AJxN=zsWqXMBAkpL8yz?!jO__vP+62YO| z+zJQ}!Q)7i^3evA3`<-H2LXvTUUPm|W(Ewuc0i)}F!{tm=T~FF-Zk-8e4tHW{wFn1 zGyKq}sfaM7hNMbVP1o0XZ}?{BtS)qI*goDHifaVns`E*HlSA^LeN6*2DdI1>+cSHh zvsRz$+0^~c+X|hN1%+0lzj`St$LmZM5#dp;TIX^tj4f8Sa>ku__R1gm|wfT#-# zIO-H+pTX-7Evo-Vv;M1UhNY$O*~6ldJ@ZXU#J;@o(2LuMl*U<8!4W*qGfv2qI1Ug1 zmoA_t2@Je?az-7WZv;k#TkwBk+_>@ zF&|NM9%)JFz95`P%v}67Muo^@Bz#bqKA_G@90%TtKd`kBiinA397H$qPppRV!*xS^-#^qaJ&A!1x@Hx_JAHcY!ZJ7P9p`A5WT1hs(XQ;#+_}AVYc%9F#-LXwb7m z=#T(*1djoz;xN@K2lK~zfC0V3ZdGUk%2DYACNZT$vGH5gjM@T|Yx1z1Gt5zujz?OP z5_n#ecW)G~Og2x&TXv*=W8c7wlxK51tOH55q3;M3BqZ$B`&C2!>vsB+2RqwYn;=$oUOGZ ze&DTh@~n=ms?a=uUXpKID;nzq^}hU2vSsjrhM_ykX~MAo#U1S!4H=rRN6YB_!?Ldk=DIcIR4=G<0FzD}-7Xo=sLQI?NJLzUQ@blM1 zy7&!Da2P`pDt-(=YJKh;NY4=&k5>&4^OaE-yGNU=Rq`8dZ-RY|iv^d&9|gel0-!f2 zEKu*wANCAPL35$tH+%VrvpSwo6sWGb_NaEbH{y98D9BFF#C(mW@j05nQVAO+ppRz8 z0e3_a#TCLG`vUsJS3p6iak4Xz*J6>|C0DQFb2ZjG?ir;_WFhh=sfH&|RLQdvSIR*n zqMUVq?>kHC`))~ppQqJ4fK1Jdbr`v@21F6Zez4VV?-+;08alk7Ql(aigY)b<{q~Tx z;7W#g1VVA=?*DM82-ypQ>?jIeXt2dZ^4(lD-Vr|V%>2;OBL}2Ereys6V3}7~BTY1* z@ZezdwoeIJo|jm+w%+G-dk00Ni^d zqF>F9_EN45uRg?gG#1K)d(hnx_vkb3iPIJWAQp0vq8#u#Z|B%Gk+|W9o7Fv>G%X!P zAx3K=bV|rnG&9oyP1^E77P{PL%j_L|+Tz_REUsol76|;8A=gE*?v|cTQ<3mbhgj*B~9ACOdpgef!{u67~XsF_Ne` zikQ~m6F-+$re1Li=5S+IH;3=H1VCoIDj%J6$T27^*vLX2MmXDvxD4-z zm7b+_aRaQmueH$L3|Vn8AQC@v=9UOD_V7e20=Bvbnkj5R_9h%2=vGP4kIlUQM2lu$f1DRIz@dfLC_2#e!-FyDX~rdg@ohu)&o1jgmyw_9cPJd3R}oVx!sTpO zAst`ZfQ&2H4di|*5{%E;n5O;m9rq(|A{|HwvW7s*ic!7zyc|uuedQz=Tml@^bTIoS zlwfG_?AMvR{-FY9b5%DH|FSjDI`|HsD;SZ07(d;#Ke+o z(wDUH-f(k8aRjs)c*?ct-*?H7f+q0q$*z=a+p$bt+#`+0IlU9JGQSt@x&^#xw+323j(r zjs&b`QB4W7X^X>uHI}6@-z+5A!pM}{ zD5YeXvGwqaqQKAl8|&MS%`dol+5Ul{u`&5aAABu(-2!5SdyIAysJqiWTEQI~Sa#=r zGzXed@ut^;>2541NSeNzD9UW|8e1 zaBRv>vqrL3*e>aU7SnxdxDB3JK7inwH5a`&HSJ(y^2PLf@wK5(;2!YK)@G`EUHu#UA zhMmoGoJZ>{Xc~n$%Lh=11)uvGX&DzL$I9}UU{{}+0EIHC6tX7yTVOqrE3`=sA$Y2q z=d!(ruD<%?x#Q*V`RGIDzL$dlK%uERJJn{<13+Jg4BiGZ z=&{uVc4yErulyduXzUQ|Drb&<>~Mdf9Q4&9mE0VTE=`76hqz7DPeHi@os}ynAs1@b zE4XI`Gpz?gptXHd7c(7LnPl}PW2+&w;dq}+Yawdboi`^JMOzQtI*M#*08;X?VgaLZ*@#HQS#FO-u#XwS;os7AL-jLYv!to~sH z3Gh=v*Q@@ta@6T`u-5yv+nxB21(H@ z7&=}9l#Vp!OBbIK&s;4Q2C`{&{_KgQwNfkpBvE5u>Ffv5AktwqS5twEbqH8w<`;;E z7p`p(T8Hzq-6p(I_iWoDD)}t_=FGU4-JE6LEHwyLN#Rh$e-#oY{TYGtrzPt2>4}Np z{_lWk&!=zYqr(OQ6GO9bH77oe=OQa#r_F3U2JQX?SozTFAVear($cH5ooSFx`Hx0K zd~Pe%iYjHBr_`iu*QO`7zB+Sb-1QN-$Tm?wZ&)yW4mJ%#?|6CQ6t9=gaQYS_i^22p zgU~WXES$|Vv})eglm6;<#^LmauO8kB8^surL3uBPgK82(@Q(sMJ(!q_M@@fr(!Ke8 z+aTb>FzhnuAb;nm@tgfIZxZXD3Q%Mox==yO#N`W%H4#D>vxF2_zfCPIuZmxtcN%Wz~9 z0;G&R6t@fEHG2XRnZ~&bE6b+=g(0vZhCF`*n1uGpCPxAg4yC(2%ylo=j>|issR?}lEtAz0iGx94jwHJ)oa 
zdsf@2o_V^;hF*aSLxL1FugzWO!_Xh_9#GnnO6hl6q)BJFk(1cecj0bp9{tv22|i)> z&J=wZ5KO;qy&-}0=!%`@%t9T_8;@{W&1Gouf|;lL zTMhS+N{8+?n0_92T_Hf_(xg@QytC3M04)~G#$t*bGHnVtL+UV1F-I#50$PVr&d^!t zc3p*@yy=O(fBQ#&ZY?M9AlnK5F{2~r0^@OZHVEF6pTCq@r(wSd-JSmZ^^a9ofYS*` zV%lI@_6fO&;+-trZGg_OY`y}Dlx_kFBdt0bLC>}=6PxX#ahjn)NZs%LeAw9_K zJcOXLw6c@yH$IpfBK;yLgb~(e*(@=k;E$16lWAJ;t&%!+Po0xiZ-J?|d6T@CqDarpNaMjOc7p|-~OstkqlMv31M zR-_4m@sfw`-Q6&rbRMDlg9JdlHb&5CXx8$^rp9xV%Fzqh1YIBj@x=?S8S?6R2W7O@ z6?L+TC6x+YC7+FpSYUa_k4(i<3E8~=t*{{q*h<{ z;8%hWYU{ze8&Xq1Lf0J*XRJujkM_PpPQS(BFG3kuw6CJP68TCPdHc|vRX%3Sdu~DO z|9IH7i?&(6o_)ZV^A4PoT5l9<>kDm^^hdj$fnbS=i|= ziJkblc>aj~T1>U66+3_93Z*pF3AtB35{%q(L$}he?qMEo@9$6J<9k#L5jj@}RWG;} z&0o>Utex9-MMq~IkALV@JC~Q|B{1(*^=jCpX7hsSTHPoyJ>_g`q_Ml#E8N;*uYp#~ z5SsrA=XcY4AIfsnLpVLuhNVtrK5ES-CY|R@J>Unpq{8<3l5)q6X_^XkQ|Ho`Hgc9H zt)^!j%I%BVvL-M0tk*iZt7z=+KN+=|#a4aIws{(~zY%B>LQOr>?jbfnD7(M%(uH$Q zgm9g=^zF)X>e+0=gJjeCC&i1!R=C;J%s1CIL=(^Cq>p|##;WQ)u_YfjpW^HmD!k^k zxDb%nf;nS5)%&(<7XCB)+4NBfx4oVEBP7+X>)6J3vnH>LBJCA!L83KxlZ``|L#6~d z^&U0vEGw^N9jUg{&rIvuBZ+Cf$q}LNeqmwFZ^WB%eY18k*KLPqK!{w0zjM zqYq=+nUQ&2cC$9CVxeTiY_>@}U?(Uf*h*zHBvW;7?*u)+Vz!)!L-Rb}bu%;jY1KGJ zkIB2$JYvsdxJxYmx=C-xpe&i#H~wp%PjwCWGk1{?&o2s;a$)bqjqY0rH0`x+Y6S!@4mea=whX_C#!c*T;-A{v-Ce6Jh09OK3+T7N&58A`b%zMvh~~I> zYjv{mp?bv5M=X7GL}+5OX(v7X=hnO2B#>Fj66m9^{D;cWS_sKM(u zJLi*&)|MAFtJM!Y%bKhx{#p?fw<6R2_Hbr;i+*#N1#Om!WA?zutUEHN(Q*c=5wH2W zCVF1AHJ4$cIk+ngJ!*$Mtg12n8r~D#Bou2wukL}eq&(W^8@L;_@~|Tz5~~-zTP^kA2m~D?JzZ{KRFO z*K-t9!=?NhY+Y$~>Ll=2PAcUi4m!a*8u`=0hH0t`pQE!M1qeT#_cY5V4{?`}QEnt{pK9(V8YJ+$_nq(T08S;ju_>A9EU{5Mu+-0|@< z+uJ3*`sSmxb&5I?_db7>TIO+e?5)NOPAOEB>InM`HbsY@5o@}dqoRB>z1R0pp9A*q zTb*`i;ODO)uKjFbXFD^9o)C z1yb|(YW?F$$u`$9$Z19kaTrpsIVZH)O_lXo-wV+tAR?ga8uVA_>gp^aZQ`r`l-_Gq zU_&!W$J^m@WVViWjN;<`nfNY-=Dc+$%D1Iesvt+$t0s{7V1 zaU2v)ii_)O4T&`$EOJz1-l@ zYFH6jZV=NRUD|$_^Or1ct%~TJH?7%ZH66pf4L@P8W(_7e&t^|D9R6`gegRIW z{ij=VH|(!9_nPL?+4^?3Y|n#=?e*K48$wq&C7O+Q8!cZbEMav41#<(&O~MawYVt2Q)oag!D_a0$jWrdc+HSC7YMbIygJ zH**@sGE{Rj)M^(@*gZWt{fKg%rluD*1?akTJdDd4TJG~IDk>MvtG(Kv+hc$0@R55U z-gzuWpYsy>qsp)v_VZ*?a%-f&WH>sJX`i`s!bYpunkp0FsmPoh6-Xwy~KGv{f`;0)Me!_B~Sf6%J5{a zD5p89OEOwB?9wpNTdKNQbw&gTJMD?=6wVv3cFm{yUR@=Tnf9sW9f%G>9x&f$UIDWt zU%x|SwP4Aoqs17zH?-W89J4jySbMlGy--e5n`*VW8LDw%Xg;Uhv09YFZCeTXi+QC^ z#ONR%x+=#w;&e@4ALv_}&~k2XxY6f1`q^AFTG=|Rwzn76^#FW^%K&Ul^=zZ4@0?+h zIDv6qxL8KFMMEV;WqvmdSSt>Vwb+JjtKqzY`dnEJm8p&E;4_#tVy!W^FmQ{QX;c)IXOiEl6rp9nz#%b&spalF zIjdUyzt+uO>0~ZFs@NAcOv6cPX19q>Axc;9pi0UfGfK=88kwB!{#=|h5TOwVNhjyb zN{SB?uB?+RB#g?9&bF5rhuS>Rx9(u-7M!z>@zFj^331vlLxD>#NLsu9UOEcX9Utgg z{qaj@fz0CoZK1k3T``r~sO2*88!oss`0GSFLz?yNUJ|nYbI+bzUfX!Hx;c5RwfZbI zn*BCJ7P6kgFUxVWa{<@GJUz9Aeej4$OQqrCPCOjUqZ3Su0h2ZS^Lp{|wss?DGBjDC zeIA9i6gVqZnZo2z?a~$3d*x*mu=7o3J_@z&Ugx!+u}vs1bXJ@C&MENKkZ&{{6!vr~-pr+d?0^y98qY&D+M+WPrQqWSgAuD|VRHg?kk^(G^j-VY!Z{WvC`iaw*_|nA4P7iz-`s?AKP%@oFYBxwI(vs5J zPE)YAJu|H@KJ(x)A;s2a^Y|N`3|T8rpN81#_^p=#{3=2!%5rK6*g=PiSG5-|8kxjD zkG?41m#zRYVuxW?^tp7^TZPSn@uC{t%mpn?GNN z(S1qB>)jBZaQYn@)BpG;LZ_uPsKe*c1~HFd#d?m@6fGgO#KNj3ddWEeB3YB$Uuf4U z7C!3V{f0BaVhZ9SObev>)BM+t2)J*|8iMHtFPzv9KGdwP60hloO6aSa*-ob}dd8dj z-^|KzKU{;Ko%&Szu@+5C<8UIUmw{xeiN93l^5xXeF6Q}0pFU|SrFzNG4*DUc5j=1B zYd+;_78d*bd0)(rUhGPA)Z1tyT7$kOGl%OTdGrFaW08B|FStdObUYKj&zc#sd#b)# zt7eL)D_5~yCE+t3Jth@%Mf{}*;iY)7haw9d?{jgraqe&Nida6B?NmDawfNo{k04?W zz7M`yeKQdRhAg*tYQtBnv2J-4@XS+Wu+jgyzxhgG{Z=UJb?WGH%Z<-fs=-K7XJ$<^0R{xnHH^Ay$CB}f#!DKzF5I$$*% zIW^etR{gfwOICN2b*epL^(R~UdTGSofM+o_OAoP>_07%A#&$3Q{oiX~dqQM$w$)hM zUOhD_&GwY7={bymI;=NcSG!fz?3E=sGdkk;)I#!*Fr4VBi0S7RS3y^7%wIVXC76D~ zyz6>;=3rK~-u#f1mzk=N@ucMq+ub|4z9(Q8tmnEdc0~GzxS;b{Tt}B<=7!G8xKdNP 
zB-~w{x*WUXu)sQRr|qg!Zzl;fWH}0G) zjMVUWk#ci6yEUA5RiL`~0$zd2(9Lzd7-g>m*3EahHlnLehT5*-upeE~)RSunmY|1;02=sBGmLF0gC}M)t$u-8|1hDy}UVDZ90yW!IHp!#Tj_P7rM^W zpU)}WYU$T|N!4^C`dhP1!%1{cTDi8H{uZFG7w;lkUC7=dE7sSgpIeQD7* zKeT9{adny?(bgGunK2E2AO65`M{0ogrqpU%dnu!e*WloLSG!kmxz1$=@CFG6{@Zwi znY;s&=7H4<**cmylb~ifbM1ouPv=WYL!x!|!rIp%%6YwA!^Yv6=(55Y8ozVEQr25r zXzZw-J8b%+QBtNIom*C6jQ8K&po))u!T1OqqwdiHH*cq&tD>Ab;+);f0W(i8eaF6m z0gi0%12+#zfmg1^|9DdxYDg@x8SZgvRer(eLu8gNGBiV$eBf#655~pW;~C`3>U>@q zD*Cg2m;`G-RK#)VY0D8~=N=ULpy8WU^Ic;D64Is1OLj(P5+x)T(<-h`HI_D|zEF|D zKRZ|UYWY$3&gQu0o7d4cB~m|=sBryQ9qTSTCI;#x!yv|v$oB;<=q`nmLP8t|-)>jC zf|<+5_RW1h*~fM#qAShJkuo_Edu677b?axFSj1`08|#zj%1TzYPX>l^h6@!8qQ(AN z2GqjKZ}WuS83tAOLS4S+D{R!=Ch~+3oBe+9EGjnC8Gc9Cby9FiKezvH*LaIYnW;q; zyg$?YK!TJ;t`clmN6m+_v0x&*o|!5jH@M{t-WoA7;>#>=AWa)}ggrBRye+rB_6q)?<6DqIQd3jL&M2^67PN>B$&{^qKtkH`i9gl8G)Wb*$qa><7|SEEq~ydM zb?@|u=ekvdc7;ezOH9+2X$m}oLP#rN zW#z?$*vyh3s+WBz`=Vi~755M=PF71!JiREVtL@x#iz8g9d+fTvViWzTASUISpMm`r zk(rx%%c}~I3}ha<91?9ZRAlxXtV3VPj0|IwzLD8etBA0)w%jSY+mdrdB`WCNid1JS1q6~Q!q=>)5(p`6 zPdrVHu3gH)n>ET4F!v1ZJ%qm|Q8Sj34yC*aVvmJ~N%3pX;LbbN{i;h{3k2Oaqb^ry zwk~1PybfwegB78KMdxZ&@A0=qY*C`xIQltT2XjS6$YskkF)uONO$)oXPq3X=l_-Oc zx{(M}P-QE9Ok{I1dn=yhKmL9cv4ck!P9Ua`OfMMnZ1n}z^{U9BfNqF`4c9lV<)))j zWo31)DJxp2lFW4%D1{8Rj-N(bsX`G%*UTCYKPb_$-SFTbG0>A_elQw;+8k=wa9@ts*BwZcV0C~v23p6qu2Y2=DgecQmzVI)U!{6 zDy*qv1j?JdMs77Y+QT$mZ!gZR{olQ;Rj0FGuH><^raYf@QH|ZxsNlC{YFUI`y57-{ z`Z_{q*XSTN025iKL1w zvyk`i!|I(M8XA$vbYSoEfc1VSnuMuyI3FwPUk1&R-yhq@0sdUt)3YXP|EslfN`A%< z(#EG_<#;9c*f79ZP8cDWjFV_G>|*^ z_0m$6wu-j3T|Fm`_Oq0T;n7mCVoW@vry$Y+zu(2U4)1}>dUv^V05hIj(^r*glLr*IlKQpY4 z;KGoXEMoX~$iO~>;eC%r|Id;~q4Z+B#gUpxOWBN30bXS#WYP~DNxW4cU3oZs=+~Po z6|*7eL9NrgBlJMpPLzG6H>=k{19W>E=h@I#4_kPr|>p- zWy_uv!bB4OeWkY=Jm5I}_b2x59wH*^bw3Gk%;vcQ2NUQ$_vJ%(A5qxc9Rc)I75*2Dx&QhxrA`Afeelh4;-Bku3yop>4&yivaoAzv!^!fusTuYqK}!-n`FC zSKB!P4{Sw;eSU59&8Qy!`r>z*&$_*~5UZGWVSYj|-0?FId_T9_jsLC!e!4;R1@m9@ z^_ecUGbR4(;q&KNr;{(lw9K?8;I={saVAr>mpQM`*joIxAHxMqeQO5!)IVB3Q@WZ8 zCF?&+MF#~XWPx_Dk?CotV@$LrjgYf)`Yrfi$g(=K>W--G?lwk?Mr08XP7l7}N}7LA zgy|`ggiX9=GEln4H+W;&$R9E{H{$=j)HDBcsYippeR%x%;oUor%=b%63<^|co%P}{ zJ;hr6LIb0=^9jTyUB*^zxJId0@pVgQ!f)u69uo~#{AHN5SYX-~dD(M7Ak=hUsugw^ z7|(8~l0#~l1QlxgPxHP3VxrT<<|U^8kG(&Shq`V5$8oJnNgIWtMI<5F$(B-*E&HDA zWXl$V(MBaHg(61weaXHjA^X^~4592hGnlb_j~6Q4kL&&W_xI1oHA%Y8#it$+`98H)MaLOh^S6{fuvXF;&O3(&d{~bq}NQy$;kJ2tc64uIolAszQ=z2?RhsoFU_Ck36`wqBKdb;?cj~{~qE?$NC zcOkbtWjGENvm~^NZKd?T+3oMY3frQO`TZD-`0yWF7=5W>c2D5gxpU`QA9OuJ4EHg5 zeXsw%9C>u}>ZZ-`*jH5 zaR2ogVS51R{`AfBt-mE5Gm-2JJ3obcEv_4mq_#9uv5frXMAb4Ds1jlZe*BwAtPi?9v6`CNvr{KN)uN@t0SOgMtS8+gtyU*^&b(xZy6vvO`1{ksN@1rI zyFZq8EUoCNXy0obbATX{-=Oivv^V|PRuB5JvNJCCOHOl$?MJeS zsq$cp?pEag&8%(+TGL*SBh=q^N3Zzi3ZdU%Rm5n+`#kzxE})RzpGkv7_l)TWy}K`|pi2TP19- zP9luA!S^Z*IgWwj*2LsW`!#cakDjA#^1omnX_X_YjZ0KbkNFcwu||m+Bvgz9#;NYZ zM=cgvh2ME6xEGIoeJCXs{v|=qV)2A--Ixzu1Y}57UYQA7@p^hxR~>#Y&IGZ)XkfDIVkS^P6H}?!Y+n=&4e#A*cFP&E?Ca^}fDP z&TD{eP`|u2m+SH=#ps!&1KG}HEamnLy^mR|_}6y0gY}l|+5?@-hK7$pppHJkn}UMc z!Qe%^C4}FwH*D-ybo=i2`=l9XZ~wgKz=7W9Pj|H)PjCEbBcQCToH%xzkrh=T2U}$$ zm=*go687GRaCq$`(@z zD0vojA@IihzrA~gWDnE}Yj4x>B{73mHelXr_1jngJcHfwbAp&rc~9YXSUdX`Q)b97 za&nG}e7gTY^jk-wCAud(21E3ZmV1BK(PHflkY75gwzkin%LKXa>L-o1?uSWmG7nYT z!5Z2dPZ`?{aTS;`zWj~!Dnse9+uBCvJ)V>c?l2UzD3xkyX$`{B!21ZL(RoScXOEkH z;+P@sO7GnI<20YvtLUhk4czjaO)DRCh)s;$4fRJju|o|JoBUx|^?h`8PtbWVc%~u5 zgC~uAL=Et1tUx${g+(rr%)_zjj^p|%$$0YPcJQgwr(uP=QmrfoNH=8Vx$}A9MfSHw z_1w}c@4;rBI=!2;T`1wN#m6TlwWJw55qgDEU>sjkha(y z20Fn$DG21Z52De4*}+TJe}?Rj?el6!b~4PH&Z%jLxw_)zasK=>>;E-F5^3@vKVeDg 
z=V_AM>@wY71#}xO_s?I7{mvEf)_qmvc^+3e0?Zg9@hqAa-YX=Cg>Fx8E0<|3#qK?X4O)7Wfj6B!LQ9I zmqSbT*#KV>;`2y|Eg``n=ZAVJpGLcItyR6u+YEJjLD{DU6{SZ@6#3e`m~iYp_N^8y2{+hEwy$%=suepYmk&-(1|o-t^B0E z5Ob8QZgaN5Yo;B}%zP$R&y+D;Ya7|SNYQWU3-)nEOEYZH%ZX8n!X$)$x&g7*INRi$ z-i~)to>UD0^UwvbFfsGQQ?_Q}@}&J-ydHnMGtG_aSi7Q4zdZK6=((^~aocMw6Y(w9 z=yYvwnoF@gtG)6ad3Z^kJlFQM6b`%Ps+5$e`HZ3L`%8+jr;r?nJ=#cYzyn786ZfBo zaghyOF}E049(@mcUylwrzo}D9C3bEhzGCvp1Qdg3=)iw zLq$0@_NQ|GuDKVx7{gCW?i5b)k-0IqI87K^ZXV?#r((Fbh;O**vrVGS;!Uq3B!<+0 zVVCG{qi#*XL zFWBXp&3CspZuaLs$xJ07T0q(s{cn*px$GQKZHpipke`fl^Go|3?7L>(m~#b*`QO+s z+mn$MDe0#)M?I&2O1|-R|C0{%(9`EPV{a>b~hQ>Vv1e zqbye5`L>ul(CK#Y^UTFNOgxmT5A^3<%3bO^a8m!Vn4kH=hDiWAW!oXk~!SFG2j&WZmX53}q z)PX7Vd0u6`=$e{5-YSdaQEoT3e2#)-d*r~euBqe#WqmaOQA87AKf?|!hA;W81)ATx zi?!NZgOZKIux$>M#0Vq<)gO{4yO9W^g<{tB!hCEf6Upo9kQ7DB28BVEiKT4wN(gQ- zCp43V=`42izAF55!|NqO)9;<`dVC%gf6XRU%7;!O%!K83rl!$yPs&H1N0(Ff?o5KP zUp8}3`j-1jwnP}jyFl1dlLujVz5X@}r8js8z4psLF5r%eXm2CpRlbH@Sl><)!tgdS z>02G+*~7KG*8JzuITloFbwHq6xYfBP;B z&>bZEstfcts}GgOS5Als+58BrPWxle#gyB<_cDgZ9V*`__D+%Q6%U-rEd7{1Sgovl z?*=0}C<6WScxw~|%6q<3pgA1^5#6QxuUwvhqwcM+&7<=!Yuvw6b9J@pRy0q0`Pz2I z&3FeYJN7=T3l)iJii#3)o_YIRYd??HhA;ghKDuGc2e(puwvXt)#=D!;m1^caA%`g= zN5ZuIRv3QJ#`m&<=@@5k0$RoMORSIly>D@HEI*I3Q_YR0ziIEbT=L9m&R~b!@Grl< z6Hlb7ShW3fXaQ2)j@~lFqsjK&4o=2nm+-5JstNEvqNgVmn*;HAI9=9jK9Y-zJNzO= zP-6bbDg2!MT87#Y4kFLqEg`Soh5bH!NbNSf1^bMWk)g$K@zxU^wr=$V78+5*wFKAn z0l4SJNeiL~4hlzl(d~ZwM~Db0|nvV7!Y%sy`cH z(RGJPSq=Ec&{NIR-?>k@K9%x0_OZ9rixnP%Yi3IT8&z)D0Z7v)*4UGWjWny?dIpsu z_Ie9$Wy7b0Tynkv$!DSHOT}HLmzStF%6DAavK&1EJ| zk3UfhydF%nxJ*q5uJxfkjAK*PMh0oCQ;4W{7*U&)_M}Y#1DpKsO?{}PiQE7fp|BZt z)p$=BTU5r3bw_L?-eu2exzKwwulYkWK&u!lR1}fBLT;%t+FXbgxGFWdz6a6luJW zecnb*xUt+ezz(1UP9+mH*i+h*^w-?TY<6r$CBxnf1DFNzVi<5JB*O;;73$iY_FfDe zY#xU%_s_hbi-W+&GPFlLL0HXEV#7^ZuHMgnPWR*yZv~hIHT|KakFQ@-47a?Q+HFq< zP;m>`G!j_@E=K7W;&kJx1343Nej27ribE>k?_a0{86JJ>0Wyj(LD($S;=Eg8d9xf) z=b5VOt-gE)=b7tt-(mjqH<0Gcc2PV z;pnxVw5&fh^z{#J*u8(nc})Eyl9XH%7E?_l-{Jx((Qm*q((^*7Ex(WIZ-hUHOehu1 z9{R%x{a06;uc1UO5At;#^LZo!xxoqg$+Pwn5NYpbod7VGfA1W6+}eebmkrYv-ixS^jbi~4t_$T<(-a^^qe@`?qRpX(f37`c+lWiEU3G zv>*C*;JAPdw%<)?4_Ok`PG5SOr~_Q3Eu0y??Mmd>ea+YE%#aC-G9(qpFWpmyvZ1zH&G*OcGxDLah|}Za(=5DvdDknl z4)uOAEOG|H!&F+OG!J4h4bQ?)rEC+|akIW~$~Xtcwz*JMWE~1USOt;Uxr-vx)q0R} z#3{wdjaG9OH-;8sf^Z1PPM5Z0huOlEKE-&Gm&aOdwsB_#UyDC&^4xnh&eMSHvYk4m z4X)@76XJ^8-5(&?cK@rC^)R8VN~rU6K)itQl@E}*9GKZ*e?d?9m=LPyCW-2JL+?Wa z+0gNUcVNinXU?44HdSf@V>-(N)`U@9pW*-mNd%Lr2q0HCh8{F?~3}*wqQ|_J?58kbgY(bikMVKwNZ5fg}Uk`qkR*7kK{q z+ooxT4C~O!dXK=Z)O3$8W4+{qETt1%4uq4o<)QM3*yl;D=_|ZgY~C@TKlH|C4%+ zEnC8VpMT4+CZTl~ojJ%zp0SEOK<#|Aaq#Aqh5bHsf|pCZsFd#B&445!2)d+}VYU}? 
zB0VQ5XZ9UOa&_Ng{#dG=dI*$_fagujve=wpDYgO6j}*-3kFHWA=u6s*rfa9q0mqu)Fb8E}+eZ@4#Fqt#OS;zYobdkBp3RN745485>+9(4E2fs#){#5TxihnS?8m;_`brj!{BE|kMKg~g z)lB(NFC8`|<&K>Js_;Q+ zuY#gtn%rUP;Ws$cC5oLazwg`*#*X${z?{`?(iJboTg2Qf*PJHH&n5=YWcOQHK2`(G z$2rzR)C_#svElnbML!i-Z);a7YEiUU{>yzg=`KC6yATT*chp}dj)bN|rhy zxe%bq@gL+XUVvPR&-4ss&#qby3cVwnoG~QP?5b=z+tS%76gdova zT6$>gg+`7<^X~qQF2jeZ4PPJj=q&59HW2IN0Fe;ISzk5sFLG*Ib{TIsLWp$@?|x42 z_T*iKZx|O|dE-W|IdAw;vSXpd(KicbS#P=b?)IH*BNze;tCp^fB{HVh?g%$eyTg@+?sotE4*VqM_|mK{O3csq%T{@6H0 z!o)G(?4C3yJ+D2aM3~z$Zp||G@$sSS=pFcNs`!O*NH~B0hwBKu*A#Vb zVeGM-zq{3ad6)jG+x70tL?{e`M0ChZvjRNP13w2?S@=`*1$mvrnVFc+LId=2(fTHM zIGz9+gPH^eJZsQ-3~?zD1_18!@WGgOS=WAKS|*7P1B^NtqiiuJU8_LIifWxE#+FT` zPCcyNdKHUgkkghKXaByp=zl^hmK7n4eBIV=NlVN^PPa{}j$WBhe zI+RmIb2u_bUvZ>+S?KCsiZ?Ain9mmkS_GPBYjS6Wv~6y#eiKfVp(fBw2H%z^(|N79 zFsPTllDt+sVf+izdo5i52dAK)R9L^XFj-^>3+9^ga2zmUEtWqh_XIyV$g`B1l`Kc1 z8XU!o-L;`#S|k!FtJIfl5^9GLi(W4H*zdedMgJk$7rZn~v^I6DlrF;P@0YPwkt~Br z(Z6uM(m$2nQ1i&6VyIoUJv} z@Lh=qkDof+fl_OY5IV%Qfe(|;H!Jxe#@qbu0G+7oyE_IEEo@9zX!eL5+)CzCDh6Ms z?Lf4TgOJ+_S#ZvB5IXKlGBmq0L}xowh5^^y3_)-74Y3pO8Oj>VQ4hD~RgX+sI|_>$ zq{OS%^7rtZ_nO<~NOz(Ab&H-b_n^9#T8!YHz}WClRlgj)2kKj3N+hsTeQD{ffcfKt z*L<*@la|dLl=PI@wcQ~Rf@We0$k^ZeHW>I%LeJbr=31C4SA#5&v77wk(mh2=FD{#! z$QP?T+Bt_Dk_7Wf%f+zL$8NyU6AY0=+5SS|YErcD^QIPsPMuBK8y1@UuxISIb-_RF zrKL`WIU)^)eZK~e42`wh3F&CYKO(0ZzIXMi_scecy%{U3Da``==OC0(P%IDvYY=zA zbXlS^4Qc@pv zV`J6H?@^6JgfJ$2uO$hsgAkfU_xxRUcJf#@ZbzEPQawkj+l<#P85@U?jbcU(Jf(E0 zKfTkte;KBHl&#!LR61FF-Hj8~lYq5)Tb6TP*T20p6yL+h$h zwaUEt1`j1LrX{98kSZycN)2)}V#<8_-V`UGMm5#Xz7#31*H*rglRVm5Yo=ff2LxUr zithOdq3mTL=>r3dl0UbiJCJaoG@^@Pz6#azlx5x>B7^19?fj5;XzVp?@eYAA0E*--#Nf(CrUf5oB+a+)LVTCk&~o7rXmK!kmWE z*RWL2@wGmS8BQ-}Y;2?Z5ouj|>CEuA>C5eGYrXN56Bn`ZZaf&{dok$6${- z>;aHwGVat0Av2HB-5;7?q>mp5s@cGK?K=V$U>sK4`fl71Yx@!tYDc?OjS~_TBwqoi z$SIN=sj^KBO)Br?7=^LI)b?H{WF&hycUDcOlX#(th(>ewjfCc`*4g`JTP8-%)u`0V z-&*?Sc$_d?>ug%qGf3jvE<=DF67o-c2*8FiJ^mhKUk;fJg@m4japSLCl+Cop^LNn;y>5J zdD4E#ex5@S#`PKIf5pepHboz8MfesM&o+m^$ZqBer`8-$aiBPOWjtOQ0`t)#_?dI1 z3X#&18%i)KK^zfi({nQ-m8zyNO{)S;fTVkTs;#35U*IGUc zl*eH_iG#mb+>VCBEOXu?1)*MsYT=FZom2An5DOZg^N!ljU3q5_$R3#V}x4TP)R}Z#8AfD=whZ38v(gf<+lrq-XgpV;gBpqI*+1uZ(y|Y;_vayq;r!)w$hCcjj9Xx9frKO5nYIJP)QAi z1Qn7h#3K*q%xAg&u>C&xQbf=*6K4HAd+qlGh$M6^J*2@vcEug!+_SYXXHL-H0<7Coh{r>eO|T-q zkta!!6rjk0d=3wL6nnZte#rktMQ380JmL6-L|KOUN%Xal4?L}A-(lZqVdT@A!2J&9 zj6uf0!f_Rnd}6oGd^Snp8-bw9nXf*ff;B0GgPaHFJ+Li53civ7IejqR+Sw-#U;l~x zgeCk_qImEj{hmzC4aC+jQrOsyB$HTwsqkiF@ivJwtgC2r??%)Q%J|7dI~p zPqUgt#JB`N3zW1Mb53FrI3H$NlTI#ClVNQ&l2oA_L;9=8B-h65n%gIoFBJ8rzbEmc zV`1?T?(Z#^XpR0hxPM@Ls~~K%QJ~V9A0@m8Lx*Efy1%+S%XQ(k{`ku3b*DUcQgR`~ zjxPw=(Xi%1sVF@{Iqt|X_P-{A8bOHW48j6u73x>-rB(MA2IsaL?bikofQMW2`OniI zM;iF^9L|k{vrxlzrqVF*h3R-vd~t=EDY@OiZg!96;m~1(n^3CI$hPg7*Hx*b6N8UJ z@{=j#2w%bSKkR(UM|14(S=hkN`6Vd+&d&O%b^I)F9&LK$rTMFH1*9%=Mn(f!UQ$zA zEbC?QphQzrdct0mQPH=mCF=|lmPJj)#VCDWDBuJ2%&NT@a|?=KGnGbuY6bDCt?v+7 z&p$A*hhwE~+!OH`uq0`{Uem3@G4R?k-1X;YO>0f*`9U#`D1f%FpiFuCBxjq9&pu_rFc4!_+xKRpO9rhpdAmMF#do*X5q{ zo2fd;MzKhIqdTDn^x9gj^pw1p?7in=t>AkmFvnxq$8d8oo?@JkoRyv^LC9-BC#tc^ zR*Ee~!>~9nr-*DD)jlXc^dPrdo>h)RJl@5nwA|kJ-9oxReF+vdy7k5gM!n%~1Ski) zVDm@aWp0Q$(aybxL2uy0TJAcQsitC45qd1i?!e)KpY1&(Fzn59KHCdb3>8Oxl1tMs z5VA~JZ6j%T4lcu>H7!HRv@;cGi2!%Z?27$c7?b(KVEAE5nE1flR<42jmd}P!gG8Ho z6h8VJtWD3V^P66g-Z2Yg;tKhx{|6TnCvS$7cp8HMp~w1d1n3@$pBWxSyCwDcy^+ll zL=is71qSHy5`L6OhH!ko3;jPz^h=*-s^Xusiw{#D5WNBoH>p=FWS`mp!XnTSw=YV( zi6Xeh!#E{AJ|%7alS5?PhB__kIV=u-rUlX;LRn>`PaFH9c-F_Wcec2@{vPc`=U>N0f&WCewdcjdUc`R^Wj_&0PW=|nE%0mE5lm1;f~FOgt2T8evI;-~`2 zl(y%R18?1Bi=>qmz%SwhL$$*p@LhE6Mw|sqw9i&9k5!dZWzn}X`^x0F=rpRwOqM;b 
zI``{Hs0~T??s=$+0|zJ6_PBQ#t^nl@f}LXH&}+nd$Mhq0GwzxNs_m*kT(iq09z^8w z$7|a~=wKn9MM%X03qSBd25`clczN`~G*z z61L+H)QPNan0KmX(kIO%z_YGA&c0s^9fUk}-HDc?CK3@v`|0V&T_?6XZ6)gv2T*_0 z357vn;_Xq#wFb*ZV>0787(CBjCvR8+gwhDI9CD+uI&ae!dJ|gZF&7))`=KKWNf`*q zSHyBn2nEzQ)RNfp)vb;W4OqRzBMbl_S5iLurK?w5GOR6PloH>Hs2ixG0UNR^y)v%S z!m;gSkIv#l{m&nPe{0w1ep{euZcMQBgrFV}aIhm4;+?zqQP|U#f7f2~t*AwaYkYf? z-bMnRkVL0B92>{w8US^G0Ok1cfm1`RZ}$ImK}b>%FFo-XoLx+B?jD%-2Om><;)XTR z)UwDW1^pqtXUaA6XW_&MgTs;QB#mk57<~+a9beJ~rkpmx06*UbERkl|WHwe1^st^R zvrT1khiPMEU=WZ_UEkTO*t2A0uw`b}`E`FZ_n=(H-F>nJgL?lf=xG{P6{A2B%K3R^Yzdjl~ z@yrDNF>cu9x7)g0SJM&gHCfaYF<)($(9M$wF)+{VC2+amn`{}*33Qte)@V(Hy27IKT?%+~~ zS|O6Y&dlxc9JB_?IDijRrg|06pHo@@mS88`4df57gBS^P?QV>PCG7XD^5-qJbitQI zA7Jxob`2l6+mU8XcVM)^F}a(|zojLqqkk0r{kA$povKtmL9@iBRC(zRv_Q{z{%2hk z%LR@&F_ww1Qj{${?yGs7Hn-EqYkBwCz6ADDJE{<@IOWkx78W0e{91N{YxLb1Dg6X4 z{wsN1gYy%4dcYRv0UmxWMd(JIioNtto%(l0Qug}3m@>_O%~iUl^Dz#bhG8O-$L|b5 zP92g*s=7_t~T>7b$Qc*+%M;{GDO?yxi^PP7f z)(Dnr0=$7#0xH%%xk6suyMXEXX8E(W@y^{+S0AG?bXKvVo|%x=t>IjACIJ3GG8n__ zbnL52cI*Up0uX~FJuHS-`;!?M7er~|q&}pHuSpX%ivf%{J zxYUI{9!tihDk9JZL}1_-M;&NT9O=vttmuu|xZ8qjoA@SwSAc8^F;93>n`4?y>KHf>RVgC7zkSgb4(C9JWmjPY|zzo8Ty&zSKJw zbiZ2X#|K4$^9uD%DD08Rl!5PQ6$^sFBFn0Q2JG(cz~=9pD|{8KcLLtC+GObR;+a5_ zFOY@b+BLS^KbYA5!6|NCCLk7SC!AF?`HW8sc)vV%o1?uOyao@dePNmFLAxRG@kR0_5INsYzjZI0O+iPuzm2Jw_}h z9(omqm_VDXG$>(8#bhVdSWmoN_QE17^uuoC8hwC@GP;^n>rebq-xjJ94IHqf4(|0U z#kT(++too0oHj`JKrY2b1w_5Xq1O7euqf>Zs$d_iSfpY341x0eCREpMUY!+NZJLrBuPJa-1B$GfW3{df2urLkcTf}- zt%!gYiv_`*;r=B_G--MC(~h0*M56Zfu97gC$Lly>VJhnz0`*p8XsU}(Gf{YZJ5*Zn zljiI-e+6>#3-*ATO5EvdzVC4~9{X;8`yqiqFF;v^MG2$)2dnqyg?wDH>0;+ikZb{O&-S4c_{)zO}leYPzf|cQd`9`xwbqTX-V9to4rWx*wn(_Sgb)jKqa=S#)xQeW-k7%t3Hnl5X?Wrdfy|JSf z@YxClSFum_9C7IVa@DeaMZKZa3acc8{)iO?fksL&AP?UI08gUb#xLX}jp)xjN2=NW z87^KkGVvq&kbprS0qiM{(oD|R8|Y)8k$Sal_9xtUlsr`{(J z6$(Pp{n#SJRROB?fuIRB9m0wS%;;`Tio`Y|v*}Ul~LWq#+mPel=qd z@0K?nYOd__{B?$ENf`!ogK(-b4#{H$ekQo?)MNfUs>K>BkuZrpM8;=X*;rnyA>@#! 
z?5AwA#9h18Aksln3|OBeoZ|W-ip!ZUS2uE-nk){xbbshwvhPtt*cDiH);hhM?}pA_ z*!yeO2rFx=yzxt%Wi7MW{|PD>e`r17oS5+ASWV}i4c>02Leb)`2S0@X%m zF+eB7Kvr$e97CVAgtNgU*P$W9*qhP`@k0SXO%7iRNK7D&^TdNlCf@M?g&?@rQfHT0 z-JUd-3L{+SJTb1m-X$X>|COTNCnCuLv}5TQX!Y+~0yqBt9MV3juYDI9Purw~*HBFB zuI`E)Z)cS8BaDm1rWr`RU)PgmB!bpZhEr+s!k+Qsg5sq2D1=TWRP@l`s%${MQ*qrI zjcfC=lG)#Z&dOd3FA9oGXaL3_fF8q>IylDpiHw}*2(j{Ny$=AtW7tuhdk6|s00H-o zB0Y47E^-zT{de|Ow2;T-%?_7HcEAwF%pQ59QwH-p)rv|@ea!Z3;%iVzBHz%O2Z zs8xSf8mJ#oGM8%BP`tFI>%(NJ^f>~t5M4)*^9C{kAi zbR_L*8BCkec9#fEn4_(TB3b(h`(~P>KQ4;TiuLGDUOYSj$QptroP%l}Sx}I76z87Q z3j4{}9kHW26yLryfOUTA#_#CO|q=_S{lvWAd}^CWx0L}s=Afgal9uUx^RJ8N^7Ur#Admp z&+xPD5&)GTs1uL2`g}Q$DgV%jD`Zy{hw9ObVee5$H%cgZT$nI)@2XMhy?O?b=6okk zU=RM;n)L_RmjLQz+-J;waeUmqFyTZNv?MNduNqM0?ZbXZw|NE@t)~KUt7&~xXc_p# zVY{?k_l--#gi)tt)FeKF6}tQ!KJsDbGNM5Mh3MVZMT=y>eK-a+kmeLJauNh@ zh_ra1Im)+$JQ+~96kmj%-*#iAR|LtjuT{LH9U8w6X=0z@9Zaj0RVOmj-gso3S8!C`QjflA^3WC#>%y^l=<=;djoM6xiZqRoKH5P)1-)~ zNo1=rMQvK^u(cx^lr%M7Qi*F9b0K z-yFi49|q-_opZNIbaY0C+XW~i^5+@nhAsuZ4ON5^6G>(pxb#qt8VQdVbet^t1MwjK zXWrC$7vuz6Q#l}&7o99s13z^ZnuXP$aa~JWOD&E7HEX>@j!_|^AfrAb4KDRCOcM2z zKwp;%b({~ zO>Bi6Y<{uO8m6-{=ZRG@?(Kamu9W~K&^Gx|4iXU;0hbXi|31=VqHP!e+TjGW8@n{3 zN);&a4X&Xk6eP)Yq~=GHTRKyU!(W)_H50~s9R|h(fA1jhf*-s*!D^f6LKtc=#C=#S z2E82zF%e`K$dn3pFR#Bstm)VDD&|q^V#~aLr~_*oc%c$>uq<5);F3hT9q%=&r_3WF ze~3lp$^sq^ja##EPag(yw=AZb`lm<p@as(|8eq0M{H)36kE>w9addBSQ)Yn` zcJNTKt1>Q6!zS4fS5jNrGv!8UWdw?%~bt zt_*YE-?j7cnHqFdA1l_~@WI5nNg~tddL`&GAo6rmA3#sL2?k{;A6+{HlUK;<@L4WnBb@6% zLy`*p>RduGI=FVaYS9>E4dKYNJUD#~U>$jY^)E4C;N-DpzfgiX9kY{;(2PK2hlW68 zm=%P~Q|mLdVuBhm*&xpbfM6<|aYX7B8k<$ah%TbbSP6#5S=5z!`&eQO(lJ3?d}q^| z{n{XXuT(y_DkjXaQktV-2o{ergF*tg^L-s`;?tqCh%h2JsoX9;ILlN$ZZQyH@IoiD zSKJI*3mF;sQY8!8-1IjT^#J)_8z1d{14l0iguS? z&d3P;J+)wwb284_G&W8ISPo5~_+h0k2h^D!(ah%NP^|ODY+aF~0NvKU4aLGmJv-T$ z-BRglDhUI@))0^{AF1Nxt>J<>ce5>+(6jmK9MWiCitZM}FE{tlx9N1sLS8d@crY)sSxg?#Za*d%)@(mFSUrX+PAF)juU8?? 
zJNSQ{#a;%|;92KZU;2O;x%Y`Cbj!;{8Udp3KLUl$RK0FxRu2Y{QrayYGzY8d*qmF( zDXZhMqldDKHLQ@LAI5xK^-`85WUo@?OIRnm0A7Nacg8E^OoWp&EN7T&xp z2I(X=4&j-rcI?HuciKCQ8uR8KM_#`9$J7qm@t&z)0I4v}ZZaP!?M8wzdHeY;VSGcKyt0N$RH@gjiheemNss%wh*HoQ&|TTlc{j+W`u-W_u1HX= zSWkmDU*X@!*pKkOmmzN8L5%?B7|=YO>-YcnCI9>Ab~qx)_y79`tMPweZ-4)6Vb|ut z_P@{nedOW$|K8;9YOFu{e{1r;D}tbM|IbeTcUAuXye1y|k|4B>A^JSntFDRwLL(26 zf7^oSW&YQX^&G`+vrivjND#8LHC<8YEPBXIEIr{2I`1oUg9iw1;w%RaHcCqG=4XvG zNo`+a=uU0fOwIlK#>bzR-BE-=D_|VlRS!fLD4aHkSO;-Kw1;!;kQ{p`lDYe$+#zdP z6GsHAA#?pHY`i|+^Y?Am9(iP2kxJ%+&!eXr)}?_wCm{m$rAzCLJ+CeC)pCh-P6zb=KmH?3J5s1iM(=lIXpI3i8 z9we#wf@{YJSm~oU6TjwcYrQomFHSD?pN+<_QiPbyP*Reot=Nz{@kAwM_}LSx_iJUe z&_^KW28hnzdE;P~-NWqtX!PdJ0dWKXu!ZeAyNX2H6cKYBruPnmm=Ch zWGc+(E-tSnmCv$+C%sdWrEf%cHdD(g_o^xgOR^$>ObE??h6Im&!vtC}?7;kLJdE>D zjv{g{@tnoYxZ(6>pU;RAU52=&We9=cTTA1hZo#>s*+K6k`(N#}KPeN)!J7zuF}-^Z zDT^He>{<(&NV_`Y)c;{Keo$^%>$%Eh<3#f-iNKsO1Gcz}X}7j)uJZ>(f?|U-h`R{y z(vb_495+J!K{qf9kO`?A=1VtOu3JLt}6WrWQqg?4G^{r#=>9tl!3KQ}V(F~qg|MA#Y z4|Ht3o_b}N4R!e?rgmuU$}mH~JId&^RsyXGAgx}6EJ&Os(9f^0!7(>CH_vtPQ&Yz~ zE(62#_O6N2HI@zrL>Bli0h5@_gZ95Wt@ac^-U2xZQ(!h87rdv;^8EQ*bcQ&g>IOot zUaNYs5;sFGYNT;B1o>Th4`f~OV9g@e+-sd8(_asDWLW}a+88C%f}E3>#sSBPF0(b^ z8u3H_$3lg`LTxbMKxGNJ;h{U`z&4rgyBd?pCV3Q?Cn^dGVSA2n#D;~5`yD=hJn8Av zUHG8}NjR3H1|YpxYFqbx+!547Vls6LT`KR`9#>QXZg^9Sd_qo+x{8WQRegORQ~O=M ztJkiL=8YydP4$)*j>JXPVZ^M}XPI_2_1X-5eRhblt;{N0#;@F$kyAQvI6An)Uet1U zX>pda&1QZA=eOd%SDYK+r)+p`l$Jwv)&dBq<@IN>&71FUj<;WH))i;flW!BFfW-*-os+^{g+ z?|^e?Gi#1jlpPqXB_Rrg|1%PlvM|jsd4d!6QIvZ#qFAwG~26IfoKkl@zZ`N>E|(qjOq=J=)AdOiGIUeQ*!d0b5XWKn1T ztU6q;Q@3-)upm!IO3dVGSdWG7f0aW2F45YfJiqL27a8zF4)lpR_QSUjLj}nQ%TL`P zMk*qL8-52W52LWt40!$QPRd83llhaPrTf1CsVXWu+U3}sMvt0{OJ3;xzTJYyUaAVF zcu_I(?R|4oULJ-pa~(!iW#3;cyD-$$E_hvDz6~y0%ZS4=Dl1EM&9mF|UiL-ME?8e; zd%e43i<{X`FuRyWx(-R}L*K3Zuo3UFL;sXZiG1d!rTTuNz{Fg%Y zK$Oy*kdAS--`8`$pR#baPS`JUX~{KR9tOh11OFW&`}fu}w&dfI=etq$3D{DIFs#DD zqQq<7D@ELs7PMkCl3`fNh`3C>%Nb*F*AWVLke9UpJY}9nb1?W$@Dg2f8>r z2(nhR+){7`LH!!UHssMnC3c;>=1~LPB3F&s!TQP62N=hE+$gryg)~2@W^Nuo=9phS zzK>D#?tPt0B6XJ8xB-H|cv)+bR^dKgonrUGYD@TP2ocI~e?;g2%&bNfaez8>|LL9A z?+UwfW!e`#RCYvf&ZX-)HJh@QR;Qtky?sV6A(+o{#*a4`)jJ?Ut zh=POi+0t|reCUVnJ&oCIT(CGci8w1K`nmUhe&KU-_N4~RjDruz87`3Qb2Cx%R!Qn< zUq)@M3#L6ma8G|Z#Ho848N754ZbEpruBDALI{)s_=6~xlJ63?48(@jV%VR`gWIYVx zwBgaC&cNCP)nNJlavuwIARpWJxXlmy@OS5!lfXv7@xj!3sV=N|@3e8vJKGi^kUx!* z38b6g6cN!8BR2RpC#Z3Pass?EM8Nj4Ou%_ty0gfeBDRE_>c((8P}8tDs@WcKjf!2b zsO!Q|dRA7e(b2vjUJHA$Lca5iqORX>y z+01`c`FEu1m#TCkW&x+f>R#x`abd@rsRgxAt>I48-Rb~#xox)I-rlgpUHd6qs@bzm z;W@v^#KiM4^$#LFEl1k#W~mCt_>dU5jbB8|W0JprSN}CQXi>dd-yPy*qSC$CzuNpW zVfT#_k573XxO)25kt?s&kDoI?kb1S5=E=ngEzNAzJz@@5lBg>#s>@_%mSrj>mA`H+ zJeZx=vLo_SY!qEt<3ZLlkIU$m_-N7Sg)^m+p)5?j(sGw|Q$|NrCP;TDNgodJ?(eQY z?aQO_eev6A->;8%AEK*mYiTKiK+VmdrmTDzVvBF#e7Eq->}*7IG_x*hJkv9!h21}Q zD;d@?OGaq(wyr%ES4&O1KS}7c_N{!wQUT{bA~KSR`R2R*^IdlGZvz61FA;pvO9m0S zr+qJ7>|cQ+Wqq;QBBe0Ewll}JpRQImn2)`G#mLBr%X#J}&=aiAKCP**_dO};dBxe; z`7kwgl!#SpqK2Z9Qqu@YLrv|-5gmKQj>f9+@bI&ooYtSZrrU~F9Lg_#WmSsiITvy4 zqT%&Rb680;jMCMsSDRZ~vs~E6&99?#8#061TPM zE-M?GH{_US%3shPg8$3~Hk$J*KM+o?MZ;aANDdt7I0UrY{_oA4Tbd*$gsO+!-- zXRZ%c%X-!UR+PtRlg{tP(vCCfJTv@i^5nVLQZE%Z!#qSOo**_>)=KrrMet)h@B3e0 zV~`fs)70#U-Jcd84?@-F8YlE_-8wkxOU1*b#T&fSRteQ8`KV z)-AuRAyB%=HIvcOLoeeq23gOZHNDu=9H$fk)*USyOamFii)#$dOrb6#aqliU-)qda z@|N;t2$>%34YEkCT)+nnkFqJfHox)}xiZlNWOp?*G#cUC9l~UkHc5GJtN~UIl+XNa zwcxv>Ixdzy{ibd%=)Sa&F-uKNt*E9J=%{6BX<0rSf||(n8PIi|E_?BwGv~)c_GIU8 zuTIj+?5nP>*0Q5FsQVzuWzhu@uvdG|W)E>B;2$|JivK^A0_!87>w_4P4(YYF>t*9n~N z&w~8qg*pG#*p5N3EWQ91zEFmSh4ncT$n?_d(i$G_*mGy3InFG* 
z_Jbg|=U&=RWp+F&e7O}2{~vqr9o1CU#ew2Dj^EfAEFfZFY#`F4BPG#MVGse4CL*9B zO`1qA0Y*nfX-e-ws?vLhI08ZGMS6)6=>!O!5R$hqsH5LFZ>{&oTd%C|t~HWL?mhS1 zeRlc%_Bm%?S21Uc8#iwB%(jk?kLNYLh<9FF=md2R3OZ+HWffRs>@vXRHghJvcr&x? zNs=?SO4NSsYc}UKw6u)*6UjaoE?kJ{LtAuZ`Se1tjKoH_wzdkOQw5ARF~wRAS*`JY z)%lH&T>%2Yzsrm1H3P=#eZilXm30ytX${eCjFc3zqzr)viCmai&zWiCUSb+_6`JmoXP! z8I8!hcs@BeIM`<9udRjSQJ?}g045XbcXCMZ8#hF>e-wtVSSUKqLtZRLZcdjk`Sm;W zrdJXx?LY&|tFcs2xydF_{|l9z&w?w?`~rmoYlERLlY(pzO~$&; zk3O5f%*537!*?7)bQd_v)<$h@ZOIc#Q5X;R9;^}~Y{!oFcB8VX7}u%pNgt>+Xc_8i zL(cD%&sp40(BC4KsBLPGvl{U!P1U9VriPRkJOg(noX;u8??g=$U()ILz4w?*ki zMN-gqd`^xCKv=FO48yHwNjwnhXYZGFx>H%&rvsZNkD(47j%zo7bdMM==hMR9(T|8p z(VoiyGxhRXA{s=93}gp;@IEe3&sp^hqO5YcJ+*^sSECwqYUdCyDQ3$SMTH*Y9bJ}7 zL6o@j*76j=(xWI?1=Lkl<~*lPknyNclT7< zY^?2(4g>*67&wBQf`Wn&Y!FvaAVHr`##FxcUC0L?6Zs;aMuR?zMeT@VUrh?K;Dtk#D*ny}S_Zd<$jYHdkXx>dpuf9C#Wf;hd&{ES|S1BVw99HxnN}u!?(qqI+omKg@Hg8DeGuvW=61VShH=ocs`y*IrM;) z6__9bGTt;a_(IStS|0Rno6^?N3AZU{VOb>?8sugcr@x(-)!)w;y8T8e3w8O_Y&&)<}PJ@qM zq^5e$&d%ob7HYVi0%zybE5Z6$BJriU!nwU{h4jsQ3VsLKj~zQ^R;U~;Ee@v9YQ+KC zOR%i~QKUDvu^)7L{|9@aQHcTh5C(&R@;97d#cCQFK@K<|(gmYY(d7rl+ORLu()tE- zU1s0%1o|<572`ZrhWQLs)(+J>&kZ|LxqgGJTTh1orc~U>ZFV^@!0+5o2CKDQLD|es zAiLp}_VT4$lk! z1+IK**VZ(pfCu-+)b;ALRel>0T6LV@8SUjdi6r6E7Dxzla!PM zm=a>Mx>K_6sY3;D5O(jq5EB35&O{)kWZ#{$_2HsC!}Z~NwzKkkrKF_1NJ$yri|LeE zkBS8vzW`ccUm|*VR6;2dzIZr6jdDv*PftTrvl zvsTy==(+3mhu77VkLpbc_0i}X6ZE*fmLpj=5_w4=FGfvdOCDzHPqm7?b3j1xQyXo0(|##Z9_xDm0in$INDJ% z4M1ga^&Bf3TMdGHts((2_Br%gviqmMeEC>Os9cP+IM;;nNbS@EJSc2{NPqq^iCOS1 zSd+Ixd%h5#-(0k-U3SHXVgk!beqryvPN zf?S9{yQr~P8g{Viu`oFAFgaI;cvms_B(*x3YY3PJoHh%zPf54NntMNZ5uRZHT3K+> zF*(|y)+yuSG}fd=quT_Wde2fGgn^3H{%K?Wi(8mtvBy6FFi(}yC=Z!qMsAGFcr7}Z zpV(#fTyDYDB|i1xpQdznuQ~9lNFQKFhTtqSXNZVnH{O61m8OSC0dE+Cm#w@izQ&1?R(Gub} zSsur;8<4_lWk7j28UU{t@nie?06>dcMT(ni_1&r5sivvfeo|XqeUI&M{o}`bc#Qc` zQ4*F{b35|#S27?Em528izWedb@e?O*FGhg=4HooTty_D|HRqw=XOWSS!3%l$9M)yR zWH)azs#tTL}cS40y*dsQQJ~i_qlF z!K!(#n=bP$3dzx2(srf}4#%(6Veu6Q3)jY{oJPtSYtQR!E1cb< zTo=P_sAfEl5nQ_gB=L#EUoOoZ?dH;F zv)UBr6si~C2eAc3m^N@fmZT>sm!h4i5{{q0XT6#mAOY?aV|>kC`Qi$_eW&f*aD%_K zfygndFglkG-w<@BAPu2|bMGr{Dt$3ldE~(y>tiTpQ$x@m-9DJ9I0{i+gw?`o91@Kd zi}`bNbB47+yuyYApS)gR7NQ~7jcZTAP^9%oBHRfEsi(CGfOps_TjPr{vZsvDh`Uk6 zahl+;i&~j40WEM-$ zDRAP%ll#nQsl$f>AeZN@THDbmHObNN>f0u#bg!mu!X|*@SXf|}xgQ|v4Ait;ia(nA zqPphVoxiqFqNm+>NQ(#)n1&c=cJ9wXRVL3rEkdEH`1%n$1gMM@!bLNQUx8{7qJDB# zZuKvnPNbk|C&$7d49TRXxd{P%WV5N)_E*)NcdmA<;MlyxuPDu}+iV>=STra@R1Qyt zA@O7M{3*qV^EL!0r3U7NT>jKc89!aTTpB8If?D=I&GzQxix$P2K4V{$cuC7rFftx2 z3ch~5)5}0WD%e0*6>ENE^0d4R?v(#pk(9d=$)D6kA|wk?Y3%~JaYVstMTIRrSCo~N za~I969*?|Gi@$2Omg%uJ0s04(T$R8%^uZfTbQsMc&yFa_D=JbkY(fyTL8ayA&OY4n zI4373y~GM+(P!7DtfQyoig47khi;jJOQ{Az-f9IoO7{?ovxSD9UjHR}w?@lIo>^Oj zZ0FBBQO$mi8Q7g%XjzCL9dU-B9D|QAv>QOPxlvq9v30X7!5eaIR$S;!lybMPOIt3S zyX;Yv1&%>J5gqSGzuRiFiOEx9p1!ISKU=;O**9`8xAAdc4KB=1BVz0$g@@3yFfqvE z*cT##u^ym>Nm{#C7Ts!+yXi0Y(@_)~*v2*6is@6%G({UayU8(>eWYV&eWc;6xOO#P zW|LAleqm9D)X_U_v7ebkLTsDs;zemeCMJtL5p!}aKB*ax@BhU0)J`LIGrRb0jruU* z__ym=@!F4JkO%41bJRO%TQ7`I2YDC_?_5ME>?Tpq7UG(GR|HEZa}X!oW`+M4pcGeD zR@OoTNVEso{;eorsRH_Kgt(@HE2_GD{QN$VRwcadYZC_66s%%0mU$~xu6i{pT|j{poA z8~TCC&MbOb<`a>Q{p-wqX0Jz#i5RJgLGPQeJ&t|Jb@ru6fFI0@9Zhh|)E6et^4+X@ z&WNzvd)3rXv1RxHE;7Y!dQE^nqz_1Y8{!;%P~y^K`}KayeEqs;;UQObaiEZu#&YpX zWeigJ(NqR;^D^#s>O1;!kB%Kby;c3nmHiO6)DXCcQw-$Jr_zvQz#~qVV219k-di2{4}Ezs1`B1E0CZDmym0bot?#|%u~tOa{7n#`%Y2`LN9 zl9+-J7>hmJ7rlhAV@ft}kA;e@G+Kz+3|1kzncNs7;PHrW@dMJh+&}HkmTTN_TPp*S zVEZxdZc%QtZyQde^TP-r>8L3HhFbeV3qsLYY96%YzfdJF7otO`}x}R?Al+x=KlJYxL58@IX;v>@Ikt z!KFEUZw?Q7E`$I9w$pH@?)Fmc%rK|do$-@FZHK{Wj!Ee?LR&XKJ~lAl@39cpVpHj% 
zA8N?#5LYX9KuUl&)vyT{+zj^}#!PR^o{GS;C1v+=U(9jXo^)H#d(rR&LOWg2jWQJE-K#QFcG}T*2!-A<;w|FNEvT zZ)hTbHWuENoJ^87o!Mm5{C@KR0(cxsAwKh>G%Vq;5S^Wyjn}?$qg5CJmZm0lQS=8y zX}!G8?mBgJTja9!pFqBRdVgZ(jGDg-J#eo5^ybOy9GwvKONK^xP=Gi*C1NKbSuyGb zaSn@LpLx}}1jtH(3v1V73CTAb$j&G9Yaz&gz{!v>HwF1BMWLNc*CPw{TEDX#;djjq z>xax+w2UNTUO=^!!ZU$yl}?rL!H`8R7J)tQhvajs(q^Wv;cqe<8%72FBB)4x_!Hud ze2=B5alIE{UW70hLUz&{$(|a%9g%NKQdV+Ng!!U-pEC%buZF*#yGADr7%u!*^8f~( z5Sdcf*B?l;g-N<&kbMi#-wh}~>$?wN$`96+0eyuT4fBOlNDM=tFeQ6bbRT61@KN`gKgP!Nanv7i)_g)zl`-a@YIyYw+LQgH<_{ zc0>8(dc9L?)b#|;#$`JU2sQcn ziZy?GiHJ+!=>6q%SQsvo!oE+?nMH0FJ`N9aqgbG+QBRuRCy$?(J?&wAyG&wf;fkhS z|H-A=iMgAjEeZr}1x{38g#_AxP8*5OGPz+_*QAQz>raQB8CPcqA;6IcYe%r$Q{ihw zv!&wqWjRvLt$!xOBhd#H^hHvJ{Ut*+m3Ur%QiX(Y;*99wK}s_2Ne#~Z#F*FT2ANKa zqNvraQViXZqewVo8h!Qm#H-zwa+iJb)z(~3&f>2C!2NPj>^>K^ebr4K(xOvk2u8O3paTN9Ut_?4*Oiod0Nb(#sU>7~%vT)*} z^e$+k`Ry0A|8vK|E4w#u4XNtN^n2{(#X+AdSx@SIx)Jh6*X;j0PX3GSeAR#W_&=}v zNgn_AhuQy)z=o**p94Do8-f3gz$bIx`15}qEB{Z0{-;9!Q=z}1{a3R1AC&kcjQ{7L z#Ozn;(N_f)E3+xf{c}Qudt1B%{MOS;Uxv%>v?d>&(}O-S+htZ`mCa!Li5s;Q5-i5A zJo^m`>80>;TjKt_rGb`oUH8^!(sJuLh~}R{gfgEz;&hg+oyMV3Qy@|IH+w;vw!1F# zaj3?xfBPVgvBV%y^TicwU~QJS%nFg4FqVAriTm11&4b!lbMxidfIij& z+^-ssfVpN9NnV=8^Cm2;xWI9v5FtpHwx^m87G)~eUzvBNRb>*q)GFg0ar!WM^HmdY zUf4*C+cD0v>Bn8zCOy*NOJF&Oh=$1zlwc7LjSc@eahhJ-vN&-5Jf$?5OMAKdhbe^T zSR`INkK4T1BOr)#u$kw|GXMB-_$cbr*5*&Q7>(WLV(0JW#mHS_Jg%sg;n|Qr$>Tn! zeZ7~H?O$A^Pz+vZAguJojnx#GvvaSkFB;5wtEfO0`$V$r9G3bqrD`oBFNrkyJ{N|+ z;Vb=8KI*60hcAy3tKMb(hHa3FyN22tvI}M(N3_?ukz&n+*m|bumDVzh!}42!+{?o% zYr~R*LAE?!Q>9NWUFGMIRJLIz7J5*SBD>l;E>oWZ7*GlI{^+hey{h`F&6pS>o{AtWIdLwVxwQXprNrhcicFp{2{=acg&-9K$r;+_x zomRZI$eJs=G?g>!i35uVn(ghpBo+ncjcmo2hebZzs@7K7A1vY?H#nAX#C{uYb-qm% zqb!Zc#ZEp|yR!=KCslIe+Ld>r^wpOJ>uZ+2y8ic9=0a zYcN?h4%uD*95X_BwdA4UBy)aO$y`!h_=-BUiuOIHjRZR#9S?n)e`sxNur%i-!b-CNT znb}Or4;?TqGaU9EMt2&do%JWrr0J;>b)NNB^Me!B&^Ov!z$ZW{eoL_mp-*NCoLMT*pvDipp z+5Mt4km3WsONI#O4WzAgcBFZ4opKlBELv?D#l#vMbZiaN3N<&rY+%6E?vV*;gnKdl zh)5Ew_psB3D~J*irPGaDWi+~zm%fRqqbKBjRCT0|@HKFq_D{C3Tno;eX=KQL6MRTX6I zjyZf-uxvgpE%dR&Ek5nE3(~*k3OJ{Vbin%n#PAklvsXYy^JBd-V_qyRw56A$a}Ri@ zUtIq5N~MiKiWM#Sn37X8^Nz1)lIDuN;V(5m^ITB~T-)y(nfarpFr+qQBg_*crO@R1 zrQ`fwXPGflyHLVnF8daI%F0kTSIfwTa1sj(L-Nk+WvYj&%oDa6}R^xN4@tE;G}JmNk6jl(ugN6Rev zsnu|@sG_OYBBL1fbKKnJ=U`FCeR2Sj++otw#A301^Y?o>_We5L@jkX_jIC|ZXO6on z7ONDeN*qvY&UO7bME(U!UGzg^U2DQfem?9N{0Ua8asB8wR!l?JX5p6M!j%Rqf3olTo#UgNX zGr8G!t5eb&i_9J$MYS2`;x4d<(kK3)1ZVnQ4A+%nkP{3c-=({q5oGsGEbw#)4h-C5 zVj#tvI?|WzS}4U=Hv%W*@NIaB=6#4ZAetLgtdv>)>#h_1`09~y(Bf;hq@W^d06XuT zhm)m^ ?5c{LXLKlUVBPs{g$ohx>moNuoUB5iYVuy)`z(|wtl=|7I|%#%_#_wS`~ zR?ID~Qd+1URk#6np@*CzQdH(qwl(UKZ&Xy2gDwMXn0RtZ?Mn008|q8_E7wy8(i<$h zW)$6K1`KYt6SWl1otut!J#T}Lj*m52;+|GNSM{(*^`>P*$-|YTf#e39s{X((Wqm{S ziH8Mz%bWLJUi*uhIC-aC)Fm1URu(2ECP@9qlnkNeno^|@ z`K5cPcxxLRVy*MtiH_dH#GM@I$M)tO3g@b1tntK?#DKC5<=^=CbdC5|e(_E1ir)b_ zS(kbK-(BD(8K?KGN7U0loZxg8H0tv>jZZ&ecrede5fP!&obNF;YUWm}>^rls+cJJ< zCmIOzkqUg*MSib!yO12;LoUsTkCr|lyo8UIR!Q11I4((5e7vT)PN6aOnrpW|@^HL! zh`xR3T;-BFE{LCslMq^Sb9~zx^h1N1A#~aVH!n3@6n_&O1E>KgtRZ;IbV0vs1$NGT zvTJ_w<^wN4n`}BB1m0|}IObcBKaB066>~aNWi2lbnOF;x^6E^I`uu$*uo80Hw%_cj zS1NQ8647ITLnR1VV}B*vOT>7LFAyKoH&ZPs6;?1-vJ$H=<*V_L9JYOGHn09#y_J}? 
zQJUqGmR5hBQZ7vC+L2f3XDzd&iW7-H2asASFT`G zUi%0N1c#XXJXdpOgA#XX3YAJN8rxcE!e8EJO5so%r1DukfsraSi|LLGKF5ik_1=l6 z&k9$n7nIM;rwX7MCE^pLnoc68dd12bcV`(mIYf_z$vIr-!4pJgINO37wdoXCszLmt z;ii!!@P^*oRZ9=+yEs_oW@b|+7LdnoXqz^jm4--Djhsh6>1Y!eZbF>966Mf$UYDQh z;}r&0DZEyTtMbppm-kN#eXTK1Q8H6q0u?p{iY_*054ZcV0;2nBLMR|-i!)a1v$(OT z&TB&tcu648YH&pXe(mqGjyybnuF8OME%@?G7wFM)X#W!qLb>Bqu-!Wc$F+Tu)!-%0 zojd3M`d-(PXSw0%q5FthlY-gwReAroMk54-(Hf>}n%o!eOYj=#f%=>c@oI`u#Z{E6 z6V|rQn>#&Z?prgBn(;M^h&D3J9!r94q=@^nnR$M8uyTkG2`ot#y;ttKtk|=OUrSCW zCM)CKCXEc3%nW3A#%^L^F~c>1T)g}{7)J;P9dG4kY)TA}vd_i* zIv}g7TTB{uRVRe%yPk=UYf|=w=Z%5~;I;K{a8K#(RMMU=R%2{h8#E^_U#RW#ZYt|A z#fvjCci6R&B2LJGDZkY+LMGI{9MU15!gb3aX_5Mx(`NL;S!9_ulbHG!t1$n7yG#IQ zp(=ej(rFn#Qbp@zh5Zi;=ADIR1p_imx7%$HFKE~@B5&v?BeFhnv{~D^xXD5F!w@*r zh72GcUMW>H`}jEdAnf#Uc9yBSp|HanktYiaf|`Q)w9bQ3M9MOT6JwJ@qTjt3wN-!Q zxBmdXY~hcG7ZIBWABx0Z9o@O%kTOknxXOAzf^6`cBmL{T{PjKC7A6vQyL^ z4~JZkvAuiuCgRcowbi`BE_!*H&iz&^NxTvuLhq5RXiX!{&N_I33GDTjJXLY3X$0tx zWzLKpsaP_2YK=(KZ**SnmU{ml8^M2~;qBYUz%&Xz^nZgsdJxDKh5e9F&li9%tYw0zDq@(XQuUOalG-jiD&&nnU>j`u)VrOap?P5FcCHmxSzhEPvu+EU*gb^vNk^Or(h8!ZFc72pHAPXMd`Bf4}0F zEpnxWANs#Z%}C$NQBv`(-X2m_!_{uwH}ayTGG3wS-uk?+K*Bos?CJhw8J#EHN61Ti z_&!=F)kKSN11RZ9G1ShuvX3ps;e4x-X+Buq1$KPtEwj|=>LftA>K9~=9Nx9Rq>0sE z1{6)*4NQ3y~5MTeHx|9R^LMmu|We?34KaqqA zxB3|bhUpT~U)#g4n?8q6o0^8D#Qp)xWSZ1io?)OLFs8hXEP|);s9G0I6+Nl>PSmIj zfM43JD`5EY;qJ3I$NXj4Pllc-kcA+nv|Gn=0ed1m=_Vp@aC|w&)04mob`N#(o z#NuuG6b%R8SMKrx-V%VgroWF{EibYas&Z9n=nySI^E~pOj}0D^3yElCcKkHo22XX6 zn1A{g1>DgA!kmHb-b()ymL1u^KX2b%RcdPV1-|xdes`gwoC~bE_^r=z7 zKw^l~hcRtK+FL`8hUwN!)z-Sn-FNTa9gJ;V1Lasl*c0WWVN=Z~_AB1w2&izUL8qv@ zyCz^IAtVa?eVk_1yQ`^l&1oFS?|Gb%j}4FC1!QqT^My}a3riyYSe<+(;ttL>uA%}x zc15ptd|X@rZ49k83@l~Vq3d0S!A<5LuMCmBvlMSeEbUivXmD99XpYZ{ZLcX9UMbW9 z3=V730iYwEr6}J`zR!WN(+1*KH7J!*=`>(U0-&*80nFb$Lwp@|vQI`H<$A8Rk9Q;{ z3|Yu=o$uquZo85Q;9$u<2&YDXFo;iC`f6zQxOHmpJdMjkBT__gv;~{mw6bg+qsYDl z4^B5K9Au3qWqM>>;N77I3Ly2ujuC_f<_Rg07gxKJmQwN;<*rXVtrR=&B*7zzQ+sL? 
z(s7KWZ7v0q-E8xdoNyOk@Lq&13?e%v-J_Hrp~qbn1;%>()Jm?kxo=!TfpD_NM1qsy z)Zp?eF~u-bXp-d-Ct3_yCi6;bV>{K)Qj{P=5&{7BGiE!s{;f?NFU!GzaxeTcf-iWX zIYI>JYb+O!fGMLgfLFTt83+&Auy1m`k}j=+ALv;|#Vz^D^+>I;SnPabI6#uvpE7A_ zH1|XM&U~EMYF7iU3A_MLc~l;vF%NMD`$jH?MTuY-`R6|`WZ4*i>l}ff*kA_E#Wu=$ zpelg`zvwEiC)kc&pInYtuSFxmU!8J0-Y8}LTWgDS@Vchx7dou5`Xel_kd^*Yv(weJ zMaH*QxXpPjc*{D7wQQF*@Nm9S>%`c+H9uv)1JI0?_^e+MZ0+TRgqYfWsR#DCEy9pN z-h`8d1yEFLT0mR(l5EJVih&DxOGAlvdOnCjc;o{E)wH(8(NNLUBulv&zu=$&-mOMl zQbm#13=88(A=fp*ewfc%b-FY)Gu0ncv7(|O=!_ZAFFJ}W6=iDXcHqcT3L=WjjO_>9 zz1aaZBh+HuSK=|@* z)P_V(rYr_y72~cDX{4`~FA^Tv{wE}L)Cmo|6t|(3m7PYR0t>xF+4AZi>HuS1p%EB# zCL4h1qn-F^t@ODMjF>?4;ioHFeI* z2jlAS?5jIherr`L`SWw~r2VB4i6eNi0Yd$92VRdAa!9uqosj_yjrP#%iEBat*`l*V zCphRJ%4SE*0U%dD6Z16ybRj(DA*Cn6p7*?8q%QQ62F7p|TaF}tj#<2) zwNCG_{i7MUew+wGaL56$mto`h)hx4o-X%(6DpGhM&g{vjb-V?wt&WU?Y&d7x!SAg!ge89@{?uCLSYTpu%I8}bb=0dbao;(fUtSc zaAI{{o%>#dqsl+{1Du&ZfV_L2^Z@(SVEs)%SeJ`wBCgrQw9IV8au2t9V#9KA2 ziaF3)c-K>0#z263cDEpIY9h0UmldC#a3cfqum`1Ynx;-0DGfm;4N&Pr1#f1TInI|EChB3m$*Kl%X z;&Wne;32f%td+zQv!;x=G!n}=A_s`VHK=Y=bpNVdr@#E2jJ~GD8;BS{s6- z3Zv{lB)b}3V|GID36dNfze-Z&T7u@t=o5- zf@r!kM4IZMPQ4&Xi6sTS6O$r_BpKTTap9Loi{Z*Vqi_yeeY@^%o#IddP0;_od2chD z$|Wl^JFdvuGT|k-_TAK$#7{;KyAQ^ubf;lU*{PKP>zbA$_=67~$OZcSXZ`nP-o4W~ zO6Sh`uw(6@a0?-GsvKigzSbCo6-3X?I(=LkdWOS=+Vxrbt*>evWr9$A6VRL?#8gu8 zlTgxvHtQ{ZjXn)T9!SPvkzMvA^fz?j8p2{7gu2C%INw_MLNR$DTH-0s+EYZCdq$L*YtbMMbSs%kN0R?Ryp3M>^t z%7c(yiXVJ7av-?XY6>A6qc#e+^51;%ld(sxsx1^Q4>e6Ih4XdW`Y(?)hz7fDz$J?Z zu>VkMHLcCC>#wtyu|>p2WNM(bP%!>ea<~Rk!#EMlClOexIMvj+(zLowqT0 z$TSA?dk6Z zXQ@-oX0u;lYTkpMo}IiNXT%N<4s$Us!6qu?I33~(y8GG0ciJ{-mhG)krQL3ADd7D_ z!I9A(ednqCAe*-_uICir#wgZoG~UG>?b;E}OCq9|Z?!%vIQ1>OWb!<>WD$is)0PaE8hau-)5IeLbbI-lvU!9$(eX6*Q8){%Y#;q-mG&(C{!J z>&992VfxJYz3*Q%LmdpVao2~X3LDK7nANy1@F^;EC4F9i{rMidHPy?xjjmnk68*fI zy*F>F-|+%F1G)EyJg0+ctIhS)@7~1I^3|uE#gTSCS}$SSYK+@>xD#v{ zyxEoYT$mh@8lRPA(wZHs`S&vGZ+e+!kL*d6v1}O`9&U)_3RFSMtrx2JA^rP%E2Oh) zYnKM;Tiv7We9}_jei@~7OrYPUe?5IWeyu;@NZw-}oC`Or;I$~b9j0HZ0({exW<@9+ zAipr)u2^F#I;W4r+5PyhX9^hT2iyHAu|RkYJMcVN>%O-e1>%x)@N_vbX=J*XggN4| z%6Pm9D|v2-Mtx`Q5yr)3-6g1%(T^TyjM!SA+;MlI6i{59Nd6 zPRf&-*Lu?oYYj&CWYS;$N}m}}$^O6?K3e|c?_bu!fAe6!?pl4_6DIxP{;1xi&Ot*S z&I%E3`0wyxd8~I{QHZ7I%XWpoZ~Q9^BWZn++H`q!AuCkV@76Zgu0+=-{x1>^^Pm zG}vQaEfn?G!@`27EL?IUC}=W#z39u{y9KLa?ru}1YSMSGk^0GwXY~4HhR8*eY_5bj zW`_$d-OYmEQ_81Nb3!FWZLCV0Lb$gD*ms+nzduI1(~YS zKjl6C6j}am20ryZEF({lV&yPT6pzqZ6E}BS+c0L?s8wy-w~^9~WaQTvZi1;g=;b+F zU zX%26U$tSbhuv(~1(`9^+QLU+{@Wzx#XLz0Gnhdhr@h0UHhyCW+?F3KtVh@BKWP~~` z*GlGIo-clxBg{qEj=A#7eMiReF{2+KVm%-6W7JbU3QgLw{+NIgQ#uMUN1!&K|Urv&K3<{LYg)p>^0*ngC%)EgeE(m&ja{=C)Y?>1s}T&9+r zmK*SRb{K-fx`kST+kVW-u*|dE{$xcV5WIA{&{Vk~v@#BgA?HXlX%+r8^*e4xH>eO+ zIDlm$?kS^F09AypkTf%Et%HSYDyesp{R|Sq@oNmIOqFxu%buCPJ94@v1ny{489trw z3Uey)%qfqVwX~{Hs;oR%zJeDyv*k`xD{Q3u7{;lwbbO&&Uz$H*^FewEz! 
zqxzwl3r`{fjODcs5cw-ojrDUHF6h=y8;1^=_c|2XcjD5GR@30WHef3|Jj_6L6@oj$YOKw)^t=K4EDWikE={<0&sg44FcC!;nE!$P6`UQj|; z`&x_U4ZV`%+)0^bWmg>p44&dnEm7=SDo{o^Q=O}exkxSS?;f5%MMf{j54uE)idSat zzMDtN!Y>-FhGV5~cfsrb(5tlV>H4`HlgHuW7x8{jK-`hqU$R!9J47NGIvn&rxvD>~ zujNK!ZM|9e2v@o~NA4H)`j6Y0mb|{#Xg<5(Y+Up7=$eg762B$g(!u98Z0dCx4<8X1 zHDg_eg$)|+Xa3N?TJ$?bav+Xu-~snx68LeQVf~}c^)Ggm+j%5)v!vM1W;Qo$PDh$2 z_m>(Vmn%nNBxR6?@B7$NGh#Z{Ti4+C!9hWV8isnVb!y9iu>wu8T9@{cxPw!{ZDcwMHhm2hwykH8jFdEdm9o|A@)_ zmVP(QvcZq3Xy5Ra18QNEK@06K^Qn)oZA2S8h!u8$2|4_jtcBfo)8WTgO3nDTP zpIbApM4~p`k@=@HUvC`w=iM)z6H(nPQ3rle>lPQR#@X51-TA0Pn}&=PUT9v+wDm2)q~e7b!gVYV&U; z_63o2N)4{w@1sPf^=MB*Et=T3Kj|R6&C`&Hv*q73J6##nD#m5lx&N+t95@y)WIZm- zkEJJ0AEy&^=aV0JYpAK+aEQ{XYaJ#+&%u(m?-750p>%B0um-u=DJVs2YZ%rooKns( zdeP0oCmyICT9=MqQk z8?Ra*XL)irMVRMD3JM6czl}QRPZp9*v#FB|^e?)}w1Q#HudJ z6jP`9wjcTm?=z-um)6f6fx~ z@uU6~8Cjb8Ny!7=@3y|(`tB|MHK}Ik%{}}6JfYsEAnaz}2mg%ZNH0C*uWQVAgS*F5|F5GjEUA$z#O!YYP_DPMk_3+~bz z8Q>%D#@qDH7%P|t(2H#z(J52hr4z{dab;ps15PPm;40M*Q6B{*%yuQg&NXl9=(MT_ zakQp*JoT6KZr0Y()h!2f|1CV_{iG_pV(05DI;_%Zm?@Vw+I4(AtsU{Rcj);wg=>k) zMAzRyvc0~Rur6$Is^8k-T;Ao8w2{0!9jRBT^LIDSE-nnIh^IJSq>N4=eO?>sOSKzL z(~d`VMsif<%~V?X7GPdo^j!-6W@Qf|AmnJifJCj8En6JshWW7;ZQI0T^zUN}cwQ)mVy9{oj{kmFe_adYRm;`Ne_GPe#T6JMI}AMk0BgwSvB$x% zCLI{UL0}!Lv`^i~$vXTD%NBAiquYAn-F&3&eCnSE3#~EaJ6qS>?U{$J6my-QX#F>< z7w86KGP=c?A1P_vqil?%XgnkBwoSh%5MxwJLr#Lzz;{efmvKn`# zpCu=$7g~wR1n7315Nb;x+B9<D`HR>_FwB2_ zrS~uKekPX7gZ37wA;UD2Re-i^2d}d$36&CKBEWCKVuS-%PPNm%6nKK>cGEe?A~-Z0 zRcXiWD!QOfP%L#OnH<~buGQUeN^~SetQ; z>F!xu2=SnMYKU2EsOyz7{wExwRaOJTO=!m=3wYy4z2AhIHZ$ifhaY1b|3w0D7pChk zKn0Su&qbil28d1g#(?&S__^aOCFLeNK$X^QLC?iMw8TELa_HW@9~%6|yXQ{wH%|WN zbhPuoC=J!TF^L2Zr4*hoCi7QLwZ_w5(j@9%Aif0R)a zmpYh)W zRZP0q_)YH!#eOJGUZ{r` z$hOzCZMca%tQ;YAMm?$T<;=;ZeP}@j|HbCm;XH#^Jgx#q##pZq3ll&mIp+#vs{r!y7A4wU$-2h9zb6k(H z?XbNnHKEyn%YM3nH7ph)@~PF!idU(Ncm2tgfB}s0?doQaWW>g^1+H`|n-@F~8%Iap z&s{MgC&~OMK1D(vg3toW+iy28*7mTjLd+wE;8XE-@{LuCu+$YLR{U@avlx@?<0n_Y z(*-6^trLdUNyAN7OgC^yi>?y!*->S@Nkmy$pMk*5)Uu0={G!k)F5Aoa3vVc+*-_vg zctRA+3<_`xOZN%(!?!B#=G>j1E54i?4>u6oKB80yt_k0h6#MFQdQT}J=f`VnD3@dpfXae41OTk~H=rG+fX~p7AqQ zSVu9ad(tn?K7Geq7N4QGKgFj!+{Li*a_R1Bpr4&pKa=)x^b!qEf!B!bHP|0JCOI;y z{qre^65ugizzG52cBaX`o?EnL=+})@s6t64ShBiM9r}_wSf?R(l7#LTw-0c0a|0_8 z-?;IQLH8sq3=s=Dj`xNOJC-2vVFQyq2P`93IIFngIA0jas0vg~g^(4$TJU$s z8gKBuA$>Q;eFy{kIQr~ZOUc@UO?$BgOrP+L%fd6s>nj%z(NcmUan)NC4r!lR0e<&1a-BkIevzMG4@e@4PUyI9rA zRtBgt{OAYn&Ry4Y*8Hz<+<4U`d0k3Sz zCH(;{ zWy;Jn-M_ku@#s`#=Si(Wkl941|EN%1Xl>%I}`!rt(PpPAS{fW-Zd<6qSP@Yvlj!O^M9m}~#7 zA3dk_PmKGR*(aZ)?u~4U{RG3x;au$7_+4M7GuQj_UgAN=;o#s9h?>VXo!K-?F*8|r zFtO*Q>_s{!|06hYuvHPzAka^ylkJJsGHV%Q=17!aDpGqCyK5Y*9w(xr1ur7Uyy)tp zi-*7~WD2`&Oac31x}?D9 zD^NGrPB3-sLFdN5)4A2e86k@;v!eMsAs(-Jv@tX$Q`o;10cIg1RGTAtcmIkQA3OVq zS%!bLjN9sk^iNy)@UgGJAxj$1EAM3%&yjn%h~b-`9#uu=6i;ur0SJ*lm#Zfo=~3V^ zkkRidHrVq4FnvVYKSox}z<_o{bB}dNvJNZuLV!5_5?4k~tn2Q;`6f%5hV?DQV%rlF zjUtwMF=HnW-6Hb3HR`qB-oGhNqu zcKL)Kq%3u+b>L+NHhwse_k-pU=y?q0N~$dT3kKqY=_yTG|2LGk)zR!eY-e;`%f@r| zehbk{Oki;R@C}-3iYV&*W1&WJk+E!%mb}7VT145Ngu}C8Y+7#`$QU9##j-seBaFrR z`1-b5F@##6pA`_!Ura5)mjA)X$e0tG%L1=}^s1r!qvK!q>L5BQdQ4v5GPKevNT1Fc z^>gum+wZ6i=3MH$Pccp2+k}nXmJ()fKqxoq2H*#}cZ|kXSSWXVFfgjIzG?0nP0isV znX}^@k5_;%QX=U{7_#l{oF@%=-7?sqps8oFzC(*(H<{+N>eW#y&e{w&KWgyRvod%v z@OfZ+wSsw*=Ou4Ovt_ZBwiUURMUpig>c~`7ddmJ{ZS&Rt1@{tj z*05Xc#3Ky)4IFI){iKpv*C+nSdjTLIAvgDN#&!jhQO*3@lWG>tI^1*0PvqT^uYO@_ z)MzrT8J}Bj-<&~PU4qzfyYG45YF@r2sXLZ5zq+bwZT*)^4~gY+ObKgHHM7&j=3N#l z1|E5x5)Y+C)0l6#H)tIFe&J8#3WvdZyYQ4|(G!^!JLnrA=^glA?0kv{XQ$qGkdtH| z2LKqnKNEl1ZHHQT=XujTqQwj1!wrc4)=uZn%z*jK?pF9KGOqHuv{XAmc1+rz%a5V` 
zQwu*K`Y@xsgZd|YOY;$OjeGcgX)W;~yxu&8o*T4Z6b%plPxTyvy=(pfe@es*Ih|MF zxfC*rs@0r1R~*{DZhHc-2!62uT;RTGWK6Y7_D4+u77!KOle=NN{S++G`$`6W?Kd7F zj{n7^L$KMcTz(ck5VE}px-pr9FQi&g+s9_A3+RsoTs0hBzrkHoV<0FSUax|km-n+^ zgdb+&TfS|UP?O)+*c1O%p=xT-0Cz6D_Ux%ew+!LzrEWqE_Y7nwk+wUK6r8W#`>>K) zgQ%km$Dx{u2<4J9Uh=8nfQ+Knc@wZ1K&0Khx90ow6V|MeW|!Zup(Fgvb%0}T-&Ffc zE79PvtiZ+TmW_YO4pzI`CCtD3%SbwcnW<*}pk_aEEmT-guZ;{V!Ax+ei-m}c%2 z*`C>N%nPVMLR~*sml=Nlw;YI>zL)C zg7926(c91^88O0aiuD;|iAWm*4-;ri+s&=z@>YgOAXmPvglcq>Evi|{9rjCEMmnb~ zEH%4gXZjGW&5U)9!G4n}UwVOP70M3bi=km6zS&KncA7M!cy& z^7EECmFxYL5p(E%qU3-jt&mYvX^*2bYo#&cXO_9|7b+LNfBtRb;CeBpg0WuDTmU^| zagVw0&z`YoV`eo8B`Vs=F}(%}D3+cbd%czeSI%{p6U`fZu|KUR}a4V?MKDOimmQK}5ZbwhYioL;|!bWRaclo5$sEFohFoo!4hYsaLJl>5_77z`~NXHa>3%shl90IrrK4QdeU7$>v0P?|P zqvR&f$PU4%=79S(eP`(qeN)}6y!*~dxIBJd9cFKS7N;kl9}JwETistDma1)`JGAvW z^K|^cRmVYYTMl;qGUb6nU9Yz`yJ9VyXGf{cIwzc08a%>T@`D6)_zdKNxK^L3$l1R; z(n57c@`jOv74{K-UER*uayQ@8Twl0)9)9#OlU)x{)3$F|gZQYrBib@E2;II-N=-Br z&Dwv5cTHZcH!_T+OZ?l03)1AC30czHDTD7`M0&#kYp)$WfPl#T{bwy#hah86tm`31 z5P;NP)pfTdtwUNU)Use-MWP=iX_{;r`bS30Eni#<~_e zGl=n|>XmqM)I0dr*QYgL^c^{vm*!22!R*TI6gAj8flHnEz+)Zp$$wfe=Q{L>@bQ5zRQZt4pUo_!jU*}P-ovkAysB0p%QgsnoU zMXkRZNJP?*p!FY@6=r!9>K2OWg0J6B;lC8q6)O>`BDA4~K(!gL@q#fw81jTIB6oxd z7w|04jU@L`d|`vE$83Ra>eKm!i+uNEWQK~r9S3M>?T;TkP@;dTHcDx$JIR)= z)G|@Um+`P2j#x2Mod;8>@^dy$QUHY|_G4g*V}(o69f?S6cj;vn4)I*|Dv5E|hR>}1 ziE3`!LC0QD$~5ewu=c_J-PqU5o3OQpc}9HXt!N%{uxkt4ta9sTcXQhxa8nT_(PX`Hp*yF)=f}@ll=-Qrwx=)BCT>KE7@&_k{D1{CTROIO zc33yLOEU(`sF^CMC+Dx{o?%+!z#=TSG>1{K4H8WqY`%>HLuWI|U!o8b)Uh-Pr@=-% zua}tcp~L2@l{t1Yz4HBE{&t#_4Xgi|q=Zj^;IYz*3HH%k)?iBS6l{>)QlJzkswS)q zk7Ou0sH?|LATW7Iwna}xx)RBdK=L^F3&x_pa7LdWv(Us(5tVm? zq=OVZxBsd3c;c0=+af2vuM}CN3u|)!i60_^-Q^Mp zWQgYs5ZUemA$&>GwznfO>4x2aS z)YjFmEZ*}r*O~o#5Eu9lm(VFp8I(9R7Z><&%-9I}W^c5KGhyG9oO^v@) z#T4J{BHpk(1zU+G!dbIPgulS4U($Rhx`gr5Gi1S#M?$gg1#+V~>i}33b}Q$#pzRYAhu{2m<8r+7ik+-#dj1wl3-yQL@|TAz zbRAl-<+Qu~8dQ_ch;Jogv`!|%_N-nXq&qxWyCX{LlKwsb8#tYU8n^*fREc5(R774t zz9-nKL8T0fOz&G&T{AhJwBRuRBVG4$bTgvpW6*R`bl*zC3l8~AWqcYc zuWf4u#3=-z|CC15?xBe3#;O%YFbFVN>s2KMJ-EP{K@&>gj{P z2ILjImtxdO+1p5Y3v``{?W}+$`ZwX6GjfwTqIUqdGD4p%cnE=12cH1g13)|n4;}nO zhV0yrLRzG%*)O!wnN$cwPNAL(Azq=yAn?7YhQG4TF#$oxusR8Mnv&cnxQFfQs{uUl zyO|ANc{jzKoA8LFp7*4zasxI%`kO1Ua<<6}e7N0-UQ#`Z9i>i_kFR`Q7LoS-fbc4of=BwQTqAhH?O$Wj?&xzux zO?+rbx*yiEHTBKfO#z-}Q(a>cDyX2mXWnum?Tn`@*B0$E?1+zUj=3EZ3B=_Q%~hgw z<$`l$q@uZ{8m7po1h=3;IR^>BIPmSK*^n@i#`#`Mj($Q-tyqc!#EkEfr^6%aUiE&{ z%(!BIl+a(KuzG-k?cguWF{=GWHZ{;fk~&FU>T51jf1_+)B9NL|c;Qbxr`adl8k@cq zMR5=}z=N&B{q55|ZBDrC&bea8C>4-g@CkeMz|h4a12)_s!NXJm~6n=E%dev3TFE=YieXPUhjA4P77O)QrlKe*PTcF;T7@;FwHI zY{Qr9xtiFdz3~B%k@Q1=ZFfd}7sOuZFrupN&nyUcg~^M^P>A z(+0#^ErSsQp7rW|#W89QWykJM$L7CkSQ3x43V>T86vzIu&CjJZUIb1pqm-+aj5i?5 z*!)xnM@ogF^AmlnN5oK^*Rh)_M!tS~_zN}R<)0soB&|)*!|$k7Rvr7^nga)D?^+g0 z#uq=Om!@iO2&zKp!@avFo>Vw__muR*0@2jE=dN4GcoF}ra{wc|TXec+I`A-~Qui{gy0$c0;kAZ#|W*hEG7 zp1qyv@_qit7o+!e*0hz>$HRqK2voSBtU?XHAQhwV{k;Yzqo(mI^eVZqQl177gAmN2 zxfV9E(?)cahlITfL=~1?XXBZ^5X-{nP^KVd;Df8x7c3JB6TipBG{O3@u(Q}W4dgpY zBtP|7OR_iB-sJb3DwS2?G_#nfJEtPyy~4az-^84#rVwj4w8O2xDRY8y0|}X((;X$| zK8~=Jy&ARGgERDQJ3}d9Iq$du(F%4@3a1-HNlU!g=hwR;sT`_SX`RJY9j9kVEIN=Z z_(`{jzj+=qd$A}`*(vlWQgbNtGtc##YCTG=#|wnbhqf)*_)z&7koPQ^zw(qkTpB)E zVEtLL!ZLpR0LA&g4*;^_$YW47h;O%Jbb5)l{Ew5t)-$F+vh)&>7Us4$w?J|i@rbMs z!?Sjv;Su1G7+XUI*xXfrK)W!7?I_V+)EL|A4>ORE*(WEPYI|yYmW@PJXT+}ZoULX2 zToTe!LD|%qvJA?jsC3z_`B>a=E#Mne^xT=n&p0`m??3`DQ|AfJc8Kj;#3DgE!I44x zmQLTMwt`MNKY@Q_xe+% zoGG2XAZtbd`&ri$<;)%W3-R1Ek-M-QQ*JtX5;YH!i4#jgTRQPLMOCTp{2eW(YfNpQ 
z+=yR}R-)iEUX_9~)tzPPS2?msxu^oA!OCQh4D1jTdlttJHf1Q2dBzdjsvlLva&c+7ki(-0bwI1c_{{&T1zfp?arjx~s z#X=!B3n3+~qej}_2CKzQ8LZQzOm|P5pm_4X`gM0L&^wr}KM?t^13*sIH!7VXd4*po zff_TuDgfi*`35gzOfP6erc`pdV0YnO-B5#{z5HPDg=pEL4VKb7EbvKtyqV}RaqtIB zz-%v^sN_VB1)cEcVuzgC2Bncax1Kwd`|E~dBF{aSU)}KenBCO1{)u*Miuw^<16jw6%@S(=o>WN#IwZAqldK)j_wY$si8bLuX$MG=y+uOz%Gp z({(k;8rt~P3)zP+JyL{X4tc$jkvS*^(&_jgL(|x9T?WPTi8`_Hg}Kp;y%_kk^74SwJQwi(@G z^WVY`Cg&|snpiF3;f|(pv-WM|qw>#R0;j08Ezy5!qLwpmA^ZSE^RPBPt__aL%5kgi z4furn_&0jMHk=`tuyCYo+e7y%pg0%P!~%$k8I&46FCu8`!VFK+?a(y;93VPC;2dcBFY2TyW0TF0Lt6#=iZFR>DUG8Vl6il2Tau zTw?R`_vpyu*d}4uZqt3wd}d&W67hkXPj)o%&ja(=t&~fldeio)`IZ=UDt(D^k+;-! zYrhLGL>Jjk@g;-un(U4Byp*-#{+UwL+Q&0gHCVB{yRhXTyamFM94aa@Se;=I7Q6CC zpD_QXSYHxkIKVNdm#1-f+oJ0ZbZ6xbP+a`KDpVgfp;YByK{2BmWVJN|KzLS(h zr-&q$K2vg!ji7FyhA_&YpvzC+?I0{daKz-z(jhRGb?~d+yKK(rIITrDq!MdQ93LOv z&jz{|u{m^q$nH;WfA+KDLD=m&%U-zmYv8u?dEan><&TxlpJ3~>lWzWCOXZv;>PA~> zb5uTU1-QM9BzA>Q>;cAOH5F9^)t)RR@qsVhxZO59f9mw9r-eSTB)_77k*6a@eP-&_ zEK;I}B?WbM>-{!(rFrPjxvM74iO%-7*{%3Yk1dyXBsye0y8qC-n9))?EO0}brko*M zCr8{cLJ`1wzh~wq1Vi|b&%^gPz}{#L;5V~S&z1@%Wm9!VE)4TXU)F)Pgk^5s)M8_r zFJY`j<-K(HJ_?xw|01k!bK8Ohy(b@&c5-*;0~eQH2pPm>KhM2+4d{$3>Fw_O)R;wB z>xmB;P&yTNfm{^;Ig2##u1#XQcOx<75&)i&NkWi~A38iem2#N*#x40uk19e)Vf5nS zu*F3qN=l)OJ#{TyoQ5E1#~}Q3U5pG9kv)*^m&9Ip)cd$fWV9#(R5Mjznljt2O_N4s z*GRKVDaUJ&ntfSk;h5jlWwCf{Pdm*qJVSJrm_KWHbzVg+>;&99=Ppe7^q38dENY!o{`A44+ zZ});vxZnd=a{BUg?#+<|WDjrN(!IgIN?5Yl9g8 zvHW_t8%2Q#Zq6-J)4UV&FbSxh+uXUPOB3i_Mz{c?%Yz5l%9nGPy@6i>&H);8)e9Gb zH=S#1G|PgX#y-%yr@&xE$gLe3L}E);-uO|*Z@1$gDDy)0@Ofn332rT;0jO%B5iu*3et`9 z%lYt=?10-)O@&Njri$- z;(6$~>cx94R)+vm+;cYGTk+=a7Q35tW$UDk0Az$Z_Y$C?>4?7O0QW5JJt-lS{;-U^ zgv1-_vQ1j1NcbMMhoCM&1S`m&UwYTQIyCA+4xU8?;>u9rUQvcN)qhDM!gu|c#^T%C;4ST}dldmFjYW@6c7gQ7X$K#a2?CjACd$q%VUiB}u&jt3}i7tf21PVN{W^$gdVAzjdil1G5e`4WYaUz>Uc z7IE~LDwJ^7oQOzNyAx@3+Zl^Rq1oZGi8w`Ev{@p{-Iq|P*fxtBXl+Ncpt1!mVed%lTOQgje z;zct6xOk;?I8V*5KtXNF)`0Q3p63N30ps{U7!Rx{qb~i!jfzlr*{58W`CX$ zdYl-X!|O%s?Ja%W&t|>r581kNukRh;HO5!6(>&x;Q+X5bd!F8{B{;`pT1J*ZO-J4k z@SwV>wC>UvqoJuJqR5Pl4Aw5%#HiyN$y~IvarLIulDK31*BTcB--(^2Ohsxd9_<}M zpkJrvUj_h;=GF9bvR+NubUHvWs0j5yjA2vy^8gIwTXMZSuts~&ue|SZpW8jV>Av>3 zvEuuGp7%XR(_W^H`bO-bnhPxJaWsdZ>9O{`_04dZ_nfy{quzvd(FG1!+oUM^@t>mT zZcy1=;7P7PTlSAE7YBTX(>S@nRok8KwllHhayd=XF49c1@^&8(%INn-Nii{Zb6mIIR>m}eM z!n~I`8wq^xz1%-jGCeRb_O%g|Gqo{&?ORjv(apzE5>3ymEoS5psab$;^%fKTxyNhh znZNYAdQR&R&u+gOQUtXbn>0DiEw#q5c#@J0ULrUy{rc)(LpG^l6|&qDeJ@I{jzJwq zA$9+b@{AoFsazVAi2ush$WmTGAsR!f=;R-ml~+~G$J@YqyC1qBVC}yDzgc;x`PelPcjyi&TqV2EH!{!slCArIl zF?8LQ)9MJ!V7Qrcx+Mgn3tuB#hcM!H*=&ZOW6(mj0p3aPn*iindFL`0>;4V>no|j5 ztiH#NJ1@_~Lnm+oyK6*d7W%E7^8$Ge3EUuwd_s2(WxL)fXN%8_OWN4bHVyX@;{ct#Jqm|zpyQNLp*^{{xbyqnOMj&qUu5ueTIsnY_)brqNsFZV^>s+a5+XpTwA<@pN zumVu5#~eNTD0tYQH1~Kz^tF3;M8U+b(f_l5Oyu;|Ha|4oNc+9G1LaIDe^l93X2_}q zSjyGO&W16=*uxL_-n~+kgPhEJF~x~!ex^G%XUHJXek&IWaU>i%3~&K=!t#U7uGnuY zY(L=1b=wd*XO|0`m&0~9N{7^?TAEPZs(`Ko<`v&D){oIgrB!Xt{w~Z_ovH0OF)r${ zwkf00zUEPGQ+D?3=_pG~y7lzos77D(P66b1NsKe?3Mz9a+51}Rr%eqS#&(yGi7xyy%KdLIfy&E+X`^wO5_=5$e@ zT=L(Nhp3nA0j^?6Omb|vuetZym-M!_x2TmCamwC296@gRgz$+n=QVQ>Y!k;})Fg%! 
zvCH&k7cO!`*xrC(vVnRgh>+Dq#Vew2*A3`I1I5EH9HVtnbPV^oP7HTqi5m@1&lNev z*@v9?%&yVC2{poyxVU`m+0Rg_nCnV$_NFVm)E_W+Vm4fWp*`qzm9>?hs&Cpp>r}Wv zIDKk3v+#SOI*-}yKP8*Y&C$fcqj__KN^`?$K6MrH04L%-^=ml3sGVVyY6Eyv=%OA# z(tsnh4&{}+*b%CnsjjZ!sT;mT;ch(Bm0M%F5D)om*7XIbWQ*Hw{9ymV)&@9g8g2&& zics7(6iIDB6WsXhJVCprRZ>8-*EE|kRf&uB4HH;a^4<{ZNwZh08?7Ce)a>2HNWr!D)Wx+3XNQlzwJo| zKY_?tMk>ODe-38q5w5w&22S+&dDOcjHUem90tPI*Gz`jf zUQin!^nl{Ri{ZTKCnmVSkSTe`X@G7Si3>g63F4u~A)yBlV3T+3p+XTOT~dJ#GexkV zJPP@Pu+k^TGuolt3=DVe)xRbX1O=830SngswDDgb_Is}8f_~o z2!KW)FdsN^t@J`z)?fGl&%d^96bAf88%;_9Etx+0Ss{`P?cWLSiACSYbX+;qy2mLm zF);&@1-TgQs1@QCGCDB0LCkaQoj0rT@dys@w75#MfvEyJezQTj^Nc@y{fqkE#fmy{hduy4Dsha&?7YEe^%}v<64k(4PnUHI+|6~t z;E%CTR~Lk?)eZDX>?v8#io1S>(9qLh&Rn$%GBUdIQzv28$F&2b=zaC0No~# zPK_E%v}`*?hjMuy_`8a@v<_``ORH_~H*?duaFJGh&VWwX_PzcpVCHZEU|<&FtPX(? zhhW`T-iQE*lUADAJ0BtGe`rA93t`Q#bvRpSW={J~7H*;2>i|+q4zX@WM*T%mnbm8M zel9Wy8<=vbAak>#`o__`@pJ8_8@~G0i~(a)ZTlgJJ~sO;Y>*J{xv zC#XR;ia*)pk5yaR#ici`{$>gT`&0ve9VXU3cRDCeTE$rL)5H7&xd1?TX^}6(h9I0G z>QNHAHo~78kV5IxzQR})Gfe4Ktq2*T189B2mKF4Wj_sj6BG&|X`xanZ05K;ahqs7O-IyAz;Gdf#?cG;%xuZf=hQ zGU?GWL{{M2qx20t^%0wzO-qw(O$KqGlCX|7 z&+jyKqruwg)pGpkH2sj@plYg4F`!<&QuK|G1zlgqrl|eMB8OSPyvhZr-2hNf@>3|b z-aRXNOA$(sDO4B&*WQ&hX7V%`L0o!}FUd2wC&Ul@u|w z{+=d*0Bzq1JxbtKaG+!o*<4v?ThWQ%T5|c-j?b?YUg6U+1iN~V_Y5Yh`Uh@*#BJ?fF8x64E~|k7h1PoJ5P==>vlq|G z`=ue3f*GYpWk=3KW}&W<%I39Au^e!7wBAp4N41;d%*S-<1q5VdMtO zA|}m=>LuAtAYp4Gcl32^{oUu&3w1r~eHAM8$|qsYAD;GQ1mVTQL)Xp+9YF}2MLFv8 z)I6Y1yTQuU7KcHaTgb1)z!7Wic3+qEjyFrC8%-H9@HCm1$y#B5y!E1CRcIzh8bE%t zfv=LUAMoxDR5|!V2B^$#oLX?F^ov%&n{}+!G>H9nFYnw)I`j%xYN)tbs{pA;0!j!` z#zqCRi0^V#-Yo}$*QsydJZJGlO7`4MTW_&AN{QYFOrQL7rwpJQt?el_W zd+owG5RA0)&Auq?1!|n?Jick@@Ug;sOH7M4&Kzo)e#QcJlFk@43Mh!rl33WlnB>~jyWaK!+NIVJ#i18Knb^|=OMKS zqL%;>oZ@)28~LqLhxC#cx!Ao9jv98;Yzwqq6@IRQBNc&q^6H|i&6R$pGgB}nH_Wsr zJ(xucHQ7c^f>OXep5Srx%8az@3)yfpn7jlXT<(py&~e#1v-RrU$>6ms4Zef&Fx7>| zDI8d@1o5tW5YcNg9bL5x$e0kMq4~^PIk7Mj_*^`Wem07X5-ZNYM#AKK@(Y9TFbY*T z5iAwrwi`(w^g-a}MxxbAeZVWuf!d+5N?FWRY?0Dr&Muf-b$guEhs3K3DiwbG-Jtj< zfCc2SpW!a`)5j+$d{9vuNc&9+-Dzv`YKLA|}5(nwo$&@a_>(}v3e}hQc zWAvW0)FX~E;|lQeS=Yq;CzmU!K2z$QWaQ`@&7H(?{7|dZc9?k<1TsFPHrw3X009iL z?W!9uUI9+H(h$m?+!s5=ntBsZ`!4@5prmM^ZxCVcS|=C_eT=l?_uDHksfUK5okrRx zP!qZM_9duCn#;Eb%whZpWV^6`Ao2Vp&l7+7k+L`KdD45*c-M22KsX!#K06y1%;Go_Z64`| zog1MIM!@=~`zU1FRXslC+hC#OB-S<~Qm|Rd6djhjMp1N8^ zK!h;~IjM6~)Y{6dc=i*fq^4r+pyweWwv(7EjzvI-_RSVTeNT@MBrHc*!IjXr(Yd$n z(zrhNM9AZyaM)a5qZSaYbVjc5bcA~Q^22q?PC_yeyZz|!u49*}e#KRvP*jSfZx-|V ziv~0Zm54zWGU41`iIt|AiTh5;K0@m`6D@#bQvenNsor2O1J0GPLYs_;yn@8?zFE5f zIgM^Q10n!g$YqG!Er=~&_l8fH;GCsjSS+a4rk`}{Crl)PHWczv08(&ERwn{V++Ot) zBu`)^pDdoA%hmQ` zXBMY-j^NKG3i&I|~Hkzb6S` ze-z$PfA*;_H!ri9_|S5tp7cR8*yRiHy#MD;emh`Mqrp4|jSK%K96}y}Nh|53}ox==3GlNKGRSwCQ`Y}tP`+T#K=qj3WU)=Ux)gn(~og>nW1cI;f2a` zm-jkH%^umZ2xsUrvJXCdbc2M*uGj< z@IXv4Yh4EAHoEE2EV_o+qvI5DBR{B(Fg>FP+g4a)HU1iaW4MV8Z)DGP@+zXl%x;z* za_>ks?YzssPkh@>5jxeGai8KH+C3$ck7f7h{Sv~txeP@>_(8Dy_oMl{;rRL_%&$@n z$R7~RzN(4Lt?ZGCJ`-YZ+rWk)-blc<(m!NAQ}WpW@<_0O3&Mv!8c33< z%d8^?8%6yS1#tiYS98*{b{rzZN%l{(`7HH*(w?>2ey zV#bgiHmGg_@VSU8;bLE zr(r?A3k@_;(*&A8bKx-p;8TFTuvO&(v?j889HpyWJ=u5`(H9ZGx32(?G5_%E*lWso zLvGvMDhb83#ms8ZxC*F)7W82YVKxs5; zVl|-3X1IBIKSl0i?p(@I>-BctZTi={6gw`lW zE|@il4^cfcIL@f8uyD+rKGiySx*QS$uR{^=TIMx(;yD(r&A>MWw@1lZqypxju1;a3DIrZ+bHj?n4WL%ik~xRgA${e4c>UzeTE-O@9I z#ca@7yZ@2>47wT+K>%@dCZ=m1LR>@?qLvZ0hp$lFwr~hat?J(%fC2FN*cJriPV3~o z_jD${;&GoDS#|@Tf(OqTs5p2?X%U($DwVqq-6~d=$8m0K;|C%Ts)`!|lSw`90NMYC zI0u_rMx$8{k>HBi!I!@%0}`TE9g(Vu?TPS@XGCALF%5OBk)uO^rynKZ_GvwDFZ9<~ zvjQ30XD9#AeJC4A7%slD7dMl>_L$6s1y*#Z%x2-(*+B6;n~#BNlErYLS`#ZA7(UT* 
zhvpas(cr&e9)epH(Z@l;@y*nP52sIW{dMR?td^JS&+)}@QjvWdWhbmNNc2ESYhJu) zN7uq-EJY&`@y4|w$(<`pwQT;py>BzC4^r5}Tn)`?R^4loXTj}1ks&PMJ&xVq0hrv)>75W;0FSwyR)P&NrO7i5s%oDi!X;D?XXy+b zP*WO_NC?|K5O{NnR!Drk0q8WyyB8{bj6ts~n<;N)9v;Gi+G>PpCAiLQOJ;Idx+L_d21Jzf-D)FLSEIh>2W<7b4#U;F6xmG4kOL;c8yhj2OrxK=c6p@5z< zS|2hH6tK{$=X?&9DwC*^`xiT}xv-y{xsLVnWn@rPs87^@7asxK=RP~-y})H*b!tl_ z+>-0m(?ub;P-J_7o&66pBU>-KvAH6xnjU9&3^ z%$vgY=~@m_$7$ym0Wz|X(lQcD2}PoQ4DFlL7$5sy7+!>5=EqPBBTMU2yxILg29J=> zT@&q$VZ~*9nn{rh)av}xHu^;?ukiTAYdNtpRx*2SShnkFNxlO~cbdyECiIkv*6?mV z&bsQXlV0v4IXZFCpL?W0R0#)%v*8^t()>qO%^Rnvu5y(&jD$x%Z zEWoA;wCNhY{NpQ)XERD&UeX?tRp*Bhb3uQ5Gql^cC0)=m7Lw*1eM&8~)%u(kKXe3{q?u3R9TE+ zvU>&G0!v%odITxw<>qkLi9a;Gg6R}e5)VR~VvT+;c$)USno=x-v;a{L86jIm&$L~o zPt+NY3N|VVQzSxQ@PFY}`kk@kfh|i;5#mDP?~?DGhb(?)u|vf0B{TO^6Bw*(P(bqG z10TQc>94gq)ivUr31MZbH4^Y2yf93|IP5FA!?;tB4}F8UOUrtkRAXDkKILz*0T-K} z3B-wRa(+Hn5F6Q@b#3f^zdYLZ4g7@4yr8dh^HZi$gdUu$$ zf7b{!!60BUnEA~K5PaT9oah#7(9^z=0Sb<%oGMVDM}+J8%M4u#>WS`|&XCBMqH!=| zV78pd{%LOWoim&1SU}#*Qa%!5;05G2*c&0_FQEor1ID`JRKP+ykSR3|5GrK$3LKqd z>OGe;UDs(9423naVVw9OGzDKDq7ytUrdzMInh3a92i?fNA*mo$G29wA^R=)mA8&gp zsT(`|GkKpoGajraiZL}6Muhda7x*9AP84mRlTk?vGVcj*coqxYNLB1p6Ecl)kf1|8 z6dht2qY|EJXc}gNLg8PsS|66}4#iVni=7}LpK^1rbwG$QDwf;&}LNm8F+#6pK+z6x>q&P4We6CayOWTFQ?T&EA*D!Nlz0r~-$s zS5&^eQVh5OUSOf1_e|t7hV#ZSMWe;I=bDzlGMKeEr z?n4d^e$+#Ul4*Xaj0nj(J%O;}}PL48|-LM}mpwck&8f_&)0yTcy&3xEZocRb2tq zU;TN|#;dU>Z|=IjA60REU7OK8@ASlGX{z=|9id@OcH52;t1h@3oB7du{5fzwcJ5tO z%jjk;;!;7NB0tCTXFOolci;=OBw`E$^0ZmuF!lo`{ z`5)U3jHb!-B!Dx-FV5>A@xuQSBV&BRfgO(ud+^W7+Fb-X5h^z5LUbloD#zNNEGxjoMwr>CV9+*%vGCev0OYa+rkrMPXFe^4Hs5%bqzCGRQOMy9MMJTb-p z2pd!E*6kVKFK7Sbwx16>VR7r669Vu#R4@ykN!uq4Q!_p~s(Svt^iSC=ac^>{I%|06 z*HO*d&w}j&&q+0?3+K)BlTX=prPgxi3)jDX{}@L3W3slJ3tut;?%Z9Jc-B7%_73+TGda^4B8FTx+}Y+#ZCn+-(GATkyqu9-U$rsI5Wt|}lFBU{ekpc>#fZ!AZ3!^qdKZ4k;rpp)~J0Ijn zHoN?y5rhZqTGdqlrD6h{o#0oz9<1g5A}z@55a-neJdlf(y<0)QtS6jW)k7$qX(=zU z%*AAh6NbI6zpbW*sjhva&Q@PZj@6Qw*3!KuUyvOFrrW<@Hu7H_`|g_DonLk&m^a!Q zX<1lA;GN|<<+QE7YX{c9hQV>PVjv)-DX`4CDSC=#JW`$BZRd#MCJ$Z8gX zM?(=Edi-de7@x3Puw>9%U_KMPFuNnJPh`4$oV%AhOsFhP2nKxfJ!`+k8OcvT4g$9z z$fmDzvMC>`-*!fzk@y~p`nx*8_QQZ=kKyW#^DkoJLH1L zrJE(K=BhIg8*PKb(2><3?Z5T~`UnceUwBs$9~^Vd+O71@4l5_&Ux%a*3<#(-EELuV zsA(qqi5&etWh7K#x}7}wN&smeF~WPeyEpaJi{ExbqkTmUM#i!qkjmJoU7{lL<{b6g z-y+WCz(Z#;In0MGO<+Dv=9_@T5^Sz(_C(=qHvnTFwqW@gF4tln7*KQ%jvY>f>&rmrBjN`KTPPD7 zZ;;UI#)BI^^V`~M9Rg|&f~7l;6_E8{$Ww&+uxb%k`b2;81)-M*V{|FaqWU+Y;ZT*$ zfo_qFC4!Dj6Uj*>Mo>c(In9LNFPaocrE5qwT?BIuQlv~YZ`X1uxDL(Ih;#~(7+1Fd zC@j%5(HyEWT%N?qwk{=DdD7jlr6*`h1@pThHCeRPUxkH+L}Bds6IJ!=V0Q6|=GQ2x zlH2Tvuw%sfypKDhTt0OnHiXw*cbq)h{I+!*6w4Ja8xXq@z=t!!{x)oN zTP^$J_uKV{szrhYiukJi0pAXq8?iQ@h4H?&AinGRvwut&ySIZD^TE5V@d%Ub9>}kz z)t+#(VF6JgOI-Y9#Nie!0fgOP@q-#!4Pf}^o*JGV1eVIN8eQzN{~P;O-seYN<#T*Y z-f?7x|91!RzfJ&ZZGe$+z#ozG1?1Bgj|agUp$4jA^$y*(?7Vn&XH_nz^z>U`8fWOl zi42XtpqiQ1g-brXOMi4P_tw48`{fQA_#%*%`n=p#l?geq+BGSfAs^ldMK%3uJEE4F zmL^?-B3zpfdbHP?QzkasJ*$7x{zc5r1OQtH6Ano>%Q-jI9fwi){qp)J%@CMVDQS;9 zaQ8C>KGm;eCZeH?1?wWqmf}LPW53>;;YrUx$sp~OVMV<&S#uAD3uJnb3JQPYY>)-r33 zE0eYaef~|y$Vb({4pB+#fl#E=ljo`Ts1A6IQ6PqkZ4o+0MOw~(-jz8Z`s86gJR}>g z8`pAXD_<9$xh%%Amn#UN9*wiX=>w4aEm$ROiVkTYCZNwn?buH?;7s00d%H*8ND^Bee9Jk%|h}BO>PR$36&aw%khY3%Y#xO$vb@KTpq3zG$wY5!@j! 
zCxG6`l#zZql^hi0^VBEs-21h_ADX(yEadL4Se>J`M?)_G0xSV&zl2ndc{fLh`8+8h??Q`FruoM-u zj+Pmy%Z~9Pb?S3Wfi+NxJg~C-MAgsDVCrbYQ8fU(IRn7>^4o&&z{H3D`HrCmehH$c zjR84Q!lqGw%^SRV2rdBsR1xrrNY7C!P7BemnDb}$oZ<~qNjq8UcR_@1kOC#xa+`@U zaLtv9`QYwO2R>8;tcdL64*t2M*C*Gy;6nNNxEo4bY`ce@ySTVHCzL6jw-2-d{&~F1 zoF$lfe0Pg;l~O*PbPn4`fZXSYm1=!~l&gA>$ZTkx`kPmZj3L(%7)&-AK|(-#U1)O{ zYkIGgzs3OHPX7;NKWA6v1s;c$>CL^Zk9_^-JH~Ok089E#{jqhaK|9lJff#EWO@QxX zL6YXHb&O*V=2nA+J!+tD>s)dp?}EUoG=VVml+i9t6<9O{D-+WpAcjinY}Rl=ZrPrzZxg_VQ;x3KVNiJZAvA zajZItem;x{RO3tH)JQi7Ho6H;BCW;{M==VqN@4>Ia{zq z#0-FVcO_lhRt;R;C&#U!6|O>Y{=f{m!*TYFDL7wwCyo{oF^IukVfBqJGj1=mT3#(^ zGt&?7-WUEnAMehq1PzE_Bw#ICXUfX2{ipQqgY+>ds-N0*|8jF4#0?VPq@gCBYNs&dp}Xiaftb5M%ZosgMQ!c zoh(#7!-{WKGBU(vmMcuXcvQ6EOlYWTb5DKTr2IpcQF`FP_obCeM{b4l&?VPJHY$Pv8pT)4f@L z7FzSz+>C?X-n)PN*2npIw~cfmR9f@Fv}5SU?@J5szF$~iDYA>Z>_I_`r3Y*%IQU*g zpDIeYS*izxX!~&Oz+EK7ZaHFm=bKX5z>--PboNXaX(Y^a>r*&wc8^>%d5=^U9#L(a zjm&sZ))!uGE=z)#ZbZq#e)>wHUA+kXpGMF)0x-pQD%pv4086}JERk?AS)g>4Rtk)2 zcrtenejTC|rMH-T`*nnC>^rQ0vD2L?YMO_?hwYIbp+b_oMM=nYAFZ$D!z^aegW*a;9}wK3nj$mUnM@v4az^p2Ef&cXZH`-!+kz&7}3tE`)M4Ifg6 zYS6|(2VCk%IIk%cujvh>RWcf4YZPp<^VJ5>>@Mtv;Ea# z5=@-T7{0H;?<9?Oik)QE^K9jRsj9K)gv$`BqvrOSM>UTxZuUOHJ8$4U?cZ=+X`Lsy zefPgXM`FF773oDJlUXO*3NPEP(&YX62>-OkfRN|ICYodXYm}G)Eh(qc-)oRQB_9&J z`TY6Tub!FWnH(QQP&r`WQEPzSQn~#b<)RmZlczO+Di|_N6n0$l6qMVw=Z(xu)Dba; zb+4Kep_^OF$V>0G$R z_aR-FO#&&&$*PmQyr~~$;_K?vwRsfdKY#Z1J78-|HQsUt+$KZjr-o=P^@*=Je7v9(QI52Nf)s**IP=rJ z*6oKN*0ehIfOkMepUe@2?=&g3!O3_(!iyy&qwXL&fLDlwDN7d~5?Uqn4xQ>E;1e+; z(EJVf2wJP8yQtZ^%Yy(FhXk^!P4QV-#sX@`s{)CI&P1M9}p)uwxxs9*L}86OYmxup88 zlIY5LkTJMzR_#ZH7{BBp++#p=d`@(eB32(fy$f^_Wz2~f!GN7XG~%T_jm}kP6*!NH zPJX|(A>1O+UEGxM6!Ng>7msp3VqwN+mB_p zPq049G}isiXJunC>Uh?tsK~u?`!Yw$dcP?;Y$poLiat2-FhfUz6-WOBAwt^q>~Vns zIYe+_nOkTdp53Z-e;gWp`WL#&V2;2Dt-#cQbc%AB&6|Sl#P3c4z#9c3qntaXYPEUeKAMQHFQ z1u*#Zd;AXb2kYt;lmJt!qC?_s{*0(Y)$EjX z#p&m)?7Ga)4p38F29%_l4`jFVOej}w>>quM3O`}2Gaez4inYC_shKUz%3n|BEne=x zUPn2dHA*tAi-;?T$l6#AfLH7c5RZP>3+SeaOMC63!8uIGA~dDHy^VRiQ=+X94H+O5 zI(|}ag2jw8;m^q!ny-FEL4u+VRpxUsDH~R3K)FLfLmNZ_m^?#CDf3)fPRA~CNj@2| zdViveYe?$bu!hX31A;wvNSi&BT5hK^!C>u9wLLXjR*x?t^d*gu4?8+(u4DCMn zJi~8T0K~0coS4@;I1U6+jae3;MaK#Z3LQ`^1$#^(?aQ0J(b(3uQN101nusw63o2i7 zd(HLf`=Sdf^8j!QcYMZA+jp+!;=_l%K@wz)?d*xK27S(bC~jXjAp>$0_+}8l*F+;> z9#)5?t9MU-gN1C0ANEaVYrb~=aowES)Y!~`NGaRuS6|{xonL*rJn_sT%d*}Uk$-hq@Ux-xfn%(|+j&D&Tp>i~B` zER_(zDGgwwxR>!FIzdU|M)I5yScj?pN~_0x(s1d=Ew%+-sC&&VLQ|eyv{y;P6`F{3b{-T1@6LzH) zr3KA*$PuoL6<^(H(!KMn2Ge#GVCX|EN}N~7-bI3Idq%oBE|as%tI?gPm7Z|L7SJ3tBBhI)h+*nF0l`1^r2-HVVjj%75t#@jKYlZuO$CH;*#Wb}*WPn^leL;qUqdJ#qJ`&h>^^#et8>LcS`x zlcN$KHAw|>YMcjxO_}P|54P4&*m+eL)x%utYVqMy^v`Q7>iYaKGYc-a%P)pPsAS0> z5}umlI>p|Xa{#{}Z9x?lzTC3yQc{c>F zYdc`k#J>G`uV*)k==1$v$UlwNgg!#gK8A0A<1d$mWHhN{Z$&e3NFZ|sloh*2(EQdt z)E_{(YK!O|@vsjwt&2X}wVP-}3Si*abjD@My?30j!DG(pfKR&icBFj;+0U`?-fuq! 
zkqGg``!XE>OOm2LpEBfZ*hP?gvExg-EOhaS1S;%(6UY#otXt@(1H-k&a0()HM^I#f9ko_ z>+7hd-Vu%Ul~3B;^y6mfyG0=}byy2WQwVd4$IXKF4%`=(SGug3q3Mgf%7H(%T#%>U zV9-ul!vACM&7-0G!}swim3C>PgvwS)QbJ@)l7wW7Y?16FTXwB>$}TaA?EAitvQ_qV zvV<(z$vXD=-OtRBdVjwEeb4!w?>V12scFnS^E|KRUatGPug9w6;+m`6d|d^P8uX^O z5rWoQeSQ5FKrQ4<=UuWUzV|m4sZeb`1>F6H=J^4lD`s9QP%2=D2@G-qA(e}|=h>JV z$FWdvZUmBP=?L#$ojWNM_HsF!jgqgVTGZ_egJc{6<6f^3AbcH}TVYoTe&6cuv$D4r zC?QCW#}#?%0Vo)Z)Avnew;TyBg;KPBj#bvA@kCATWf};shXxy~`rvA^XkA zg~=heBOp*BgSb0Y>@flBEj5OZkcM_>5XX|bfW5n9q5SEGPrQODav<){zuX~8km&^Ui2z@RV+R^A+W0Xo8G965Dv`>y^{rogwCPx)0<$AqWj;^Q3an^s4M=o}0e zqAPMh<_4&g0A4c;5^LsAX@lOP!Hc*~;LSD8wC;dN?_yuBt?lx{^rd4v03w)O;03DI z2tZpEW+znb!$jK?N_goR^ir@#9M##E(D7jJ8zrs!VnMqK!88;5-vJZBBH1^qvl8UCp-#HY#DHJ=3dCaqNB{L6MwBc%;OZlJ5 zfKtJapBQLrVH5|?kL2rxO9FKeOWU=1N6)j!Jkl_{7<_Ad5K9XB_a3`aiY&W0gb7lV zCJcoZIR*3EkkL69(bFFK_$(Q(*b`Q~FX!Br*ZGFAjC+QzqF_lvjM`3%M^I^l=Bpq3 zt4^!}1^K)bs7c5-oiFioEfr+73J`xf|UIMY3Ii@Z%07RSRVo;pZ@M~F?-CG?$bGbV(ID2=Os4elmQ zt`L8?76^Zu`zr$xNj{REf?Bo#5m)<>tq2`_5GtLvs5e}Vp?4D7Xq2DLGw|_AsOR^y z!iZCLyR9ZbACqn)Cd#F@WEvi3F9opx z=0UIFDul>ZS%o7wQB;}&ThnKMsn*fSog6Q>ETy14IQ}*%%X?SE#6*XpXcYCo`W&m4 zknlyy7bi9yxQ^eoNu|h62}BAIsRhWJCSihL_7HhKu}=^w>O%Si>*ZG=kzpj{1kfGm zz2Fh3YbPz^D?4Zq09q5*F*N5({3}8`Ch7{q#>L1`4B#cc7+(Q`A8+3(kDcp`_<*yD z2R*!zEI`Whd+oP4qHZ+k@S}&W&5HYv^MtSPQN*MtHYSgyr>Fa6FWW$Cu@Jbn*qyeJ3|arEIW_O=6Py4liJt_{q8R$V^|R{h@qBzf5g4B^ z7!D+@mA1}5(N{xbql@iktuK#y(JNgv`<6GAbEUUl>2pWT?1fUMnC_&cyI&~g2Et#c6GBSvGspuDcS$zlBy%Y*^A$HJH+;7MQEqxg2<>^qP ziM@*#uZo{(E$X6O@gx>8Z&n7E*b-Z{orLY?hIA0!8Pf#l=Ak!gPlh;*t^(+jNWS_C zZX^;C@%u{(VAq5TrjD4GAOhvY|7wFVLQ-~ZZU39Nz#-61qvQPDe|r8Rq>iK04=Lw2 zm*=Ke8^wLMy$c?(qt;Aa^w0W5mczsJ%YX|Mi(Q|rg*;KU-D>oxxot$Z^qoA_6H%`< z=;@psuBdJ)4jizB&}psVG!n|AfA-3iuBWD>q`(=WRb8IJtK%!N8M;3=w~nQ6Um@$g zF)=>?JG{BBE(RJ{W1rM^3ipy)920A|K0^X8h{Dh9Benljlnss8rr$kRm$uTbR8}D98SU_N#qcuCZ2zTr zXn>Pch=*=ovg1NF6@nWUb?wBSMRLcfiMyf)k5Nb12zKEo2ON`My}JI_t|ab%x_$FN zzzBh8vs^XTPLyp=v>X&ePrO{EJ^fWV^YP0xDxK4pnuQ1HgN`GK_s${*Rblk5Eelhf z#THM)f!;K?aO4CP(QKe{hU0MwEl9uGu`cLSfdt!O42c4OGtT+lnOMLtKfkf>^8721 z4J7qeV__CncCOBNBb6<`O#SzO81KVxDR%-2PSc|d88Lx`hSgB=rq-4r>>LTHOu{$> z6SAPaPBK0k=>S&IG}w|U{uW#8vHt|tQqqcX8<@(OM!w(udE1NYsIga$1g*q44vamWU|9 zNY+Xh6+&1^o2H`NLSB9f`c|Pbg*Y*X-=o zaN`$~114ctd-5N|D@ldI-1m$Mk$C0qA%{cvmq`vHDZw*>5FWH2-=(@}+7+E{g4|a&pJ&JMHOrAD#Cp zaQ`5)8#T}QSAXL4L4@R?t2?u>aL(Gxu|}=?wsOs}4;UeoiY_Mn zRfH$+)P4WCi7Vj%Z(y=&YijOvPK{)OwSeKU>wzT%=sMCUZeihRIew-gAkpN>k0fHw z=RZ?cyAAg4>C-=?b`x3QDhGh4HQU%F5F@>}Y1o^6tO1H1=qzY=A4UK|XtCJ|ns1{e z#L?{Cg4x0463J+bZg<}6nDAmt5`qjk88otFdfFSa-+g^#^fYm&sWgHWN3IFnAt}(; zo#>>HE6_>4DN(uYM^SoWy7Zug&tnC41q2K;XZ-Ted%$~n6MK?M(T9LnvT;P1b!La| z!Z66~DyK1uO*b?EEd|L@%X2$Fp`q+Vn;6$X-_FFTRI&6DXrlF znu2s>>K~UX7ouJkpEqHHi!2(xjS zPl0mv7|XHKU+s+@X9iZ3XSG+yp-uu1bt2c|4I^Y7JLb#@uN7|*88zG2>Yl#Xu_~dL zYh!cnn_<{7qC8E6F?NZ5Cb#gB@WZuXLXi>pE^`uh0Nglp#!DBtblX;X9Ri6sE>7u} z+iKOjiz22MwJtUDUi$95%fkio(MH!{UtWjN($~FVB;$OTP31VzRLPSS4?;#^f7J={ zjkbIoS>ejW#1l_cUMKlMP>LZUzf$uDsL3}|;{CR%Z9h(oh9mnHEI&SAu3YO6fNv5i z;}1OhwZ0&OlTbC^?_}Ty9Iysa6nOVs$Gtuz;EYeu9ZlvXBl|gO3n&l$c@&eaK6|iF zHYzC7T--}wIi|F5bVU#^s-7Dh{|+`=ciZEbf}7|j&N-!rnaa=k5qcFo=1+WjDEq8B z3TR>@06JS#2(q$z>klO13Rj;xnuL82G)un&p<~g0+8r>RQ3#b@K&9AZ9qKu@FK>3M zsWhrr7!E=b6R@$dG4Jj^Q8>3aA+N0bFh>$%7~r@vEbs6TS?NV1fs=?x1mZN@O2q7n z_6e=WIP3vXxdoWkA9bp{5DB0XV`P84fmOxNqzA!no`6;Gg^L5M2tyKZSkwyp<~1X3 zPU!?LTOSCR0I?Gyp3+TtJp?g5NvIm@Z5m;!mAy z6{uWYbjC}<@I40_GvTpY1mY(IZ32x5UWeruNTjD>or~>OG=UV5Y;(AeAD|$2-ig1x ztn;epVflVkaS-Bh*=E_m)HXxW5hw_W%DGorrveV(wt#TqU!;oR;iF?6@{w%!y3Th* z{awdonIiXqr3&I*ZlIkG`}%ysu1!QfOY3>mw=&fE!=#X=Nv;ubaog+yW)C5=pgMsg 
zBT8COt@RQUuEhz}!+y&yMPFMD-EK6%JJapxcoknO>;0{iJsLt{jc56G3vR26pVS9^ z-IKGRmMd<@r#91fP{=Q5Nr;nJYTdYCs9v6xMA&Q%Z3xoy$tIUiCD@*IDvss7+_0A2z$ z1)`q7MmZ*^zzEJ@_BpYx+9u>Fm)XuO+MPgsHA8X2K7zIRgNiJr^e!jc3EuC{o(82a z=XUYh7s&rf<3-viU|bFaH|y7fe?t(h4YGV{U1a+RIv>Of1wZcfJ(?ow7KcGE-cE%A z|3fe*aOGo$isMQ*`%DvlMK-t^)K~#8Kxtg#2xC)gR&?6tr)2h`91J zUIgEY5~!yaxBnn`$k|(W$xGE?EIqC-2sZMpQVZYm0bd@;EXI(8)Wa_yO^WU=Y5m2O z*_1s9)oNY^1^zz5piKp+CyVZ)`eKY`@^w|uL?lK5xcE(2kR%Zsqy~=ttpmbfh-e8> z(B<84Kv!qI@KAX^E)J2A4Dg2I;=GWBtG=lYT)#IJ*$u!g*4cETTMm{Sj`@PQ)&ZiD zm4;lbMDq_ClWT*5TrTzeu7mXUvjLZER?7nARaJY{<^mfYv)!`}8DHrfQ})lkCO&uM z!eqrmcr*1tO(q&U7iyciRO?u;ezTI?gI(1p0>4no{)}Iec+@T8_`zC47#wlNf)1Kr zHrngC#CjjCu{+U%YZBLFdXtuA6t0YXAdXEqEM%Hz!H$ubId zetg2}+#3wmncja!X|-cIWEOsYWMlls{F?WJ`YRG^JJ_03(18^EhltFf*64*^OX6{6 z?gNoFQ~Am^p?gxH4G>1gB#egN-(BUxx_MZ9V&0qaYB=^aSV;X?6Pc3wUTB1A`_0j_ zOR3M*iGXoO8b8_W)a0p%RFA+jVU#zwG<9Z?98Q?&Yq+g4~ayK!u8nh5g*i!xmX=^A-l!YX4lk38bJ4E}l#U6Fs>DB(#4q)da-76#ow_Npm-*CDbn@zV0 znFX@>GvdlnZEbCGviG6^HoUE>S_rI=R~e329j6a$WEBX{Cr(=*?Y{ZO$!)wf?|>Ou z5e&=M@@i)P>|#!WGPs43wgJduEr0=c`Hqivc5ROpJxT*^snF>3djO>524{8b5#uEP z_|}?QUu2{?U)M(Y)8ZFpCUFa2A6(FDOqVw3BBvih}5b5|3xBCi5`CY7@QWBVBw&bO;%HOP6exF8O;9^)ms~`OnLH9HAJO zUjj`|8|1QroK&dD9`W4ipdx!6`77XXGnP9hHdYA?-)`2;XTGioG&G8Cie%RePufLx zD_B2(ZD@Glz_aj(5d|;&kB}OG9rAX9DzFVmgcvu+|@&-7uO*Hv4{!ETBnJ?A(J@UlH}IzGsnC zs7mi3N0KnyWLg<$(t<5o)YfLNY4AJ!OGBE7YM>6hlbMeuvP$J-Os}y>bH%<-mC@EF z-}C%_5%jv5S2%i{6l$P4aeT=Ali z(ux2#aXdX3ZM0eFeC}s7O_5pj#>(>w6hg*tbP;*nZ+sfYUf`T|ocr|9OntWo$n@i; zV0c6I_~Qd)qc=YmMp3DMm?~u7+wJ!3<{gmx37~do)Qh|H?tCf*sn4`SxEAS?)SsAI zyj-0M;LbD1oUo=Cy*#nFCIzl+JbZ zuFW$uGckn`B7Yd4;Rvt&g^;R^`mtig%kfnCcv|B z+`I#OJ2O*d<(NY1fD7YLZ<@XZ2uv?K0H~W$HzcRrcIipgr3df?q`hcvmEb>aw zP^u8U=2?pl2Fb7n8%NALU;j8_$nQN;8@x7*r~`)%p;3@O+`M&*dev+FWfY1I2_YuY zt02t9D1PXWqR2&{s~#6`E7HQc)Bv{9N?YJ(x9f#XqaYn`5YYr~>AS1)sb=;1nKrEl z5G=#3XZpPsO3OVjg zw3(J{HWF`(|G-w#u0HW3hJ8_8SAz!*+j;#=pGl-uch_P1s3jg-$niB zx{F`*?8MJdppfP#pGY|3tjegVm=iA_pSZ};1cBCe=Ml7RB3eQRLZ3`I;~w1thclZwpkA$Kqx&CSgR7^mN`_Vn(w0lL{G zCc~X4=uu9Q@js;%N29?_PZTY&6i)tX*J3W-qoMhzOo++ap1uv3+0dmYu6|9JNe?pF z4m_O%)eh=PmHUqIWgozHzvaE8A}(&(vx-7-J2QwvHo*llnDPn=dLuL@`Vf_V3OfvU zYxyS0og@dtCogwQM!R3epO$A3>xc#$CkIRwvU3OD5VIUe|Cph>jL>-zvp!=7_NGd& z{|B1chv+vy78NekGuAW3FEx%!lP|y5hJ#vhe$g%vd!|AuEk6D_nzfrjoY1K>-@FCo zGy&KJa(@-%$3Kw1g&1Oa!Z;USDut@Ah+knuIJ64javUajVw!N8n|zvDTVo-2+-Arf zCh{=JC7RLdi>!~~j7m#vUuV(s!G{?U7oc%3P8Ouw1Z z+JTl-rAjp0Dl~HRsGPF$FRTOqRw0o3d`sM~2$A>*?Rx^xhzLU<7u$Q-QHEGZAY8f- z6fHK1`;@$C0HuRtUv5Bk%jVqNe1QY8XK&}1)S11Oa}P7`cKl=jS>U&i;LtZPUCiHQY4a_c8T;GHWIt7eI^mhu>tp`aO?t=5686LEDFA+o}%*MtJa|-Gs z*fWYKPRJi{X(UIb)!ZD%V|NY88&kFzkOn1Rs$2ica;A-NIlK(*mB61OVj9nip~Db) z>LHPYifIgf575HQyq0$y5V0jCtC8js;!=%y+m=TR9|;H6`1}_@z4m1uB09>o+u(Rk zrsWVO5wiw>Jn{p%i9`bAn8Xv^IVZP{ACzDE=M)LYE9QawImn>fr{9@9=Vw9{^Y|0s zd0$4v2v*SBBEs*aaMj<>Z< z?q`Pb4HOx9W}VW@;z#76RTzMchON+UejDN+2ZbJBX5GHk!%4_y5O9R}DBp4uj08No z-`-EXenxNZY_ns-pqhe`5>xa=hVQcLr-Qu54%9=m`Ew@~?F0Jz2eyG`ypzE#BKe0% zdxIo(i_rS`UL)yts%RBG9?|EYb`+c_DjecwABLMO>?B~F$U%t@bm6i|a$HnaQf=BZ zJ}h!*)RR36Pv5gE zcvs`0Ax*J<0`TNFg5EC%Mx*Nic%M-+_RJdens8wF-bjYJ^$4IK(p@_UI}0N0-Eb0m zLd&u~5Y6AbC3Tw!g`c{>dzHYWA_$8R4)3s5%TBsXvASF&kSZ;{*jb1hTEf{WcIEYQ z5gaC|=0QB4*NmX+Oz7%kmllkG+d*yf5L_5>ltZ}njayJE&xrs%9a;FFO*3VHUMsq7 zt|m6pl-0P%+5%2-Dj*mYE%1=U3G6mkldby~%3A*++?-pVP}>pKhK=sC>Q>aTxs%%2vA2`1dC;o$lbMc&|@ zP=DBHaRu{5)01%LIBLH&KnyWUjN(K#nt~+oLRh`yMvH|(4123>9O2l=UzvX{yK}9G_at8xeV%MaL`Ai_%9CGqvFn?j6~hK2yC%9ouk7-KX`+ z%$%cNxdozRUhKTi{429G0Xpvg6nQM4`^gUsoZPJP#E#XEK;D1x6;Rg4 zWqh{IeD|c^)yKYS^eNYgxwsE(Fn#mV2tCyp{r)j}3Z|n(Oh0Ms+Lit}HsRU%`-l+z 
zgQe9?@6!=lN6)Q>s>>H2xTfk7^gMQ=&Ud>#fjoY&aTova_^1L14y-3Hwr+#kIi~ec z?tQDDdlV0(4_51O?Ait z3ya^whygjB`Q$cKL^o9S?OV*g(oPvd@)&(oYUFCgRuHF#J1p$mvVs{Gy|=9^%IY|0 zS#V^EOG-Sx=2ifs%aj>!i=6g{hQhX|V-XbRE+ge&2&JLAw|hxT|sw0Hl&dwIIP`D#8Ec?e>M zs)y6zBU{gC44*#NCdk+Z>0(>`Z7k06>ryW#uy3Q3qw+9a(nMUw8gAdfuRl$6V+Tqy zrw^RoFz$bz-g`(Z6Jpdtzi+eTI$EC3ph3ufj-rak5aP}ILJ3+O-X^$_ z8wesm>s>#G;W6E?t6?w)OvL==%W)M0n+o9rhFqW|Ki&>x+8q1}gt2+}v7Ij*+E!u# zUu2y1e7=`ybjd<)j+P0Nmc`S#=A`-WKRt->?IYk(r4K@rx6+dA_L#sDm1%wnW-jQh7IeYw@WaaGns` zD7cG)I?o|W%J&f8T#)9fM0!TdQ4F6ST$puzcC%=!Fm8NF$R`67f9}_QYG2wUeWLEm z;P%hmgwy%^jx-z2(CI`9pbN@u5HT!n`2?tuIdR7vRoEiu!}n|8^r}CYl+TGjkv4b& zSh(xNZH;$3uh*FYR^$^2JOI7nrbkRntuTtfqJ3`9>%{T(-o6>IvHs?>{7whwsbjMD z=|Ts+>(?EK`Vh0}(fh=?EnkavH=GaQsSpa>I>Z!U;Ha+)L18&A>Af(-SRzIAiZgKM znBvsV6C`mdPLfWW0Dti%*1N*eL*($J4XK?U}?KFficO_mp%K1jPv9=Q_$w03)3Y;I7_(B`oV_eEq4+mW*}0 zx0H7E4AHv3vQvj!_X3*7#LRMvxU7*9we`1kqs+Of_o%Qm>N=r&YZyAns3%q&!qd}b z5d@rX%P^8I0AT(=BJH$wkt%`saks0HX~?q4{^0n_4BknRrrKw&6VlEo4*Y0RB=B}P zNg(<9PGM<{o)NqQ0aFf&-Z%K@b2qUlVJLL4z;O#FOl=q3`2`--CmkMdzy&s(riUf>*-BJG7X>IT5A#_j|_5r2HKg#p+M}zXpkU^M7n5R zjk_eYy2&(31nyWSXB8=|I;KkU5e$}?E&Q9PCl7Bn0!YdE6LAsU|1v=if&s)V42jf` zrnn`3!s!WgN84;T&p%I#d9=7fw1FUA&u(&}4onP;BSH<|{J0AWH{<(s@zg=?tlQ6s zL(nGS15IdYrFR*{iAp1;dq(J|Nuf-9Y!YH)8Em@X%(k>?-2d(7#z0z!A`a}`r%~gJ zYUuU2?DqeSMb+yi|x*z)Cn^Tl!HaM)i;QA#KadG$Z3Y{hZU(Mg2$0H zG~+h^r@2%{sZ|(31F);AT_a15Xz9ve27vayPLIJj2CO>@!s$hLp||~yo(P5Ar#MqN zi$|PVh03NB>3l(XNP!rUh=37^4R_9~*m*w7>q}4@UA2O+Xxx#7xO2oC zFP9C@Ck$Lv)t!JlRV%32+hv9EM9A)mwjfFcdVaeCOrV=mQa&Wd1mQx|3(27aO=M1~ z-w1w?=C?q&)s~F7;UI{~oM#F##4>xVZY5Qvis=Li)_7q4B1$=W+mamNlFqN+6Y1Zn zcs(PEy5$3$3-JcfC1Dar_;-x4M6*CLs109zCG~X;{WvcO*LB*E3){Oszj+6Uc9NNP zlm|y>clbv_9YG-5^`%!*=cZb(i?sR%ms!N)_!`2K@G7hLbgPL;)s|WOJUZ=A$hYrK z`?wh1xNhIP`I&2E-wsX+ZMw7!Z{}ES{a}Zts&t2DP6@u#?-%N)UIauklTSW>>38yY zxXZh3r^qR@IGQBer*O`Kb3ZkYZ7aw&pusJTwTxw9M|k;@M(o)tk_zc37K>Vf&1Yo9 zu3amAEwW84gezYRn`A3$EMHp&#WrT?b&XUOaX=6*>nm9kj*V^J*Os)XL#q% zol_CVry^3!1_s-+tu$N9Uxw2={JwwoZqBKbCz&yv{f*lF{p~9Ivlu$<`V&i^*t#)r zpWE%}^1V<4_uUp>$d&6LV|u@D^lkcr!FLY9w6yf2V@EkNnk%u4fkP4wXWN_wAb;12$oom<70_cF^B=-ER-Lkp4xo<*3XkLc$VGlcr*)4W0 zew66)I>f42-WM2EnxS`5Q&Mk;Ys}VcYk25f8^hdiE0aRuidtJV5<#|0wi5Ay^;Pm( zW@MKz!$~*ye9>!+J3qJFQ{Q=%lEa{7gpyK5Musx|q2uE02T`U>?e}_ZIp@s_spt&e z^$uj#y3>*6MceFodsSG+{*YBW-&|+D!NI-(v1B#=2+5Ei5eC@8lUIXTgrZ_tfHB0W)?UA6C z=I+@JO4Qa&ZBr;o-ZfkF~EcR>H|gNrq4w+>~e($n=Q2EBk(V6mLabAZqIK zS4VSoWIlXJ?;Q}JYHjVjO7S{ibaeF8Xh*J8>eSRB)}BK7_})b8Y0Z{Vtk&`iS3NOs zV@}%!xc%3Mi10+nHxNc02A^_Lk3+YssL&wLDDD{N$q#vXyb6UlwXr89K6NfGR-HN8L`RzSK%r?YaEmk!DA9?P{j_z^}gRFdO`HyM-j= zZ@;GBx^+QDrGxDi1>c#IC(p35YPF{I??l3NT}r_f#IiDFW=MD7$xzT+v~Od zGK{-zE&(5jNGQ0{Y;gNEnjS_c_8}@jpU8>BfUQH zZNp6w*Iz~*7x~Xw6MiZ5>ePl?to?bv({7o67clXcM?U=T1rSE!|JDqQUVB(=!_Tp5 zhvjedYWnc+blj1z6ACgtRy8ph_qWEdyf!@m|MZuPggmZK{ zYy^8~f3H_4UleVu=)Y4MIqdFC7ep<%JlLxzMhqi#>kqS?5%=ITHw?6lJS_UpqI?~< zs($4-M;Nb#vYet~G%c~UBwY0EWiv~}eQ;p2!j0Nel@I>2IE@)=e1D`DW&xMZ9`#OsWFD;jLTxKYu+I9HsEj}XnDU`0Ls!EF- z;JR0yAV-(UZSu30`3#1CmwWMENR3^|C#)Gctd>nHz)X_9g%1SWRtRU-^KwA#)_=cc^tJPZ<4KP3 zx+42y7v|mdD)bK%f{Eh2UutT+l8;?H%VAY1k8#1O+g=>g-Eew@rzxs*M;Gp1c~Z;Z znp?@fB=Z1tCutLvzW4h|3oI~>?SnVfy|bUcuf3aZE^t!?vTwb#q=UQ1${MIuXc9OEX^Wl8zJqG(e|7j@)!PZvwY0@ZR#jUAX_ z9=uO_giaUT8{d!v3tNpEy@-Zw!%;K8RkoBzn)>#dI3A&AnhtDI>U?j%L_Iwq7NMH| zEF*%8h7edC7vWNzCvoTge7d@MS_^+vRP?1@g|33~M*^gv>{haucE6*|Be^z?KkpoBhOzyD$g#W z6xrs;v5k)BxaE&a++#e|VgH9X^RTmWO&SQ)|NFkw zhZw$nSF!^`<>&AJk@oNO&MEYojLiKuplVp&B-n}Yrw|}lB(x@1nDvHv3$6~;XY{TO z2l3w&-^FhnB_9A?=o!WHG7z^R7=X@mXuGMeEPBrHHm88djEu7m+*1Nu|H<|^FAL7E 
zT2D3~DB6#h9kER)-Bf~67TU%?J^DRo)D0c=&pD#92Bj&y+*Vii-r(DY5^|3uBytz} zxG_ui%Zu4KV@qqqQg#=2>Q(&3PeN(1q&%r_I2DX};68F{c?(LHnrI8jxD;nvEzsZMNbBwBB57?7BN^`yV?}r*n|W(rA9UHC5PEM=1ReeD+{{j6Kd})Z=sw>8OCV% z&_y+O6MNfhnn@7cF3-2omkNvQy~PIv=Zz0Q$|`-;J<)3BvYJV@xSl6dse;!nEiKz7 zQ7B~(qNQ^BHNF^PEa|bUZJ^>y#l)SpZUxeDgnS>@y(a6W!gcusg>ah)d(h(S+Unus z#GECCEUNYi+pp?-uY>ZJ9F5(VpJnNZ{WI6RbAlR}`C;xNpP;3c6$&JS2`42heBHU8 z0v|^!ONtJL?!?{r*OwZ$ECkxIaD>74`UM1hL_SzXboG4Q-i|h$9IJ!`<{$r(1zOvz zMT?%vjK5a7EbWlkL3uYJJ4p!zwf2!Ji!&Vz|C;#}dGfvEfz;j(Q|w4?r^`Wf*_JlA zAM&oI7-9PcufHc5HQS(5yy`kxV7&|ZRh=dw6Ls4C70ELGd238yN5?x}Gl0O5(ENV@ zSYWV6Uo7q!_NB%wjmIIuNXGV7q9ZxlrB_lT;XzMhWzzQMZRDQh#_soIy*;dAk+gcz zDf<1xIlnPUakTk!8z&xBcOJ6FDB4Xh(nNnJ3968iRqwc11ne*u6waw1s_ax#$~Wu^ z`=N>B$`^I!Zz~hv&%3UmpwMQ@(9`&=+}iALvsXT@gOAs8I7M|SE^hIZTIhv~9Nc}s zh4s*yz6vr$=N?KrUgW?=;~m}vb}|1ql;Jx>81=drXW5hvA3%p3)2b*)X|;&D_4X|_ zh2Z3Gp81qbu~d#VA?c?aEZXfedv~!5>f~1nuztfb%^i9_YGx2-H*aDw9wJ+FML?2pq zH45ic;lx*e=U1FOsrh|}NQ78&PX@buL@Zq2*!Ffm3)!uy5%e~^MV5;ku>r-#F-168 zDXEcC{ydXfV?D8deR~Wy;Syh`<(iyz;WH`VZV8ZYQfkRM&TcdNgPQ|aUX{(2JKmP& zHb1TxSUzi|meqWx#y$C%OtRswUu<5deWR{lD|LFlcx~q9lK!PIw7w9rX-O|D z3oRurpT+k+OL=+vvUf_;J#Uuq`z0g-m;zL7zmAzM@nf$ACW-RBGpR3N(o;V<{Ik}3 zIdN{JbH%1~^oW^3g?9%&kXd$aVR2jT#_vO(E60(nN2t8$LLRRu+Gy8Q&mteYdpgFk z?V-B>!^y{wmBsMFg}p6ZDu)V}W@FNq6|;X(UAlCM@15!R$aG`TEyJ=7?!%=pHf@us zEVM_W-&xy@J<@oAz8V#b_KmFW!ug50xS=VuXoR?Q9 zFX#M%qP+bKQmzmOZtE`EgdFrtf+00w5i(dJZ8&6Xu7Bj zD$nQ-lzUm9Kvoi!D#wSKm27@oV|lW{HN(`=^tj=O92^ye(cImUX)oI8s~+&P8%)>y zoHdI{>asc9mN5N+HM~+OBP)BLYpLasg_{j`$HQ-tu^eOlN|Vp=0=Di;>grRM7vD#7 z=hrM2P@)g#s@actXT|T^&?-XMU%D!Z4Fnz~Ux=r~UWSuWf=lsy1ziaVXGNX%7>?8I z;<`7Jn`*$G%m1q=HT+XrQ&UqZnttJ)W|9tiJX4*XC&xsyl>vv>&ac7cN2}(Q?XRAF zSQ|tW80IJ<=a+1dV%Dir?q0K~S9l?b+p=nxY1vk8EZ<0jk2Z4h*(w$|*ty)CFo;in zcd~Hh+cke{o>XU1%P9L$E2*6`vrIF*?D($*tebfdO-SG@xxmD~qo$zaJr6G9wPk`b zW)U62aI$Bj7SAek{mHlBii-DQ4yl#}>4)5lVAMLSdkZNk>pvREyn8bId$7}NLKmO= z?C!LhQWbavW>3e{aPs%^t1|;T=d2`K0yvE!F77Af=3wJX$EefU@JgD(i?wOatNSA&b=-P)^ zjc*Em^lj>ze_j?{tauEF%Rr|JaOWDewke+}Z+k-5V2*nBJx`4UgX<&(ykfwwla zBUr&Dkud|?IUI%(S>qqhh4wtRpucXdBzon^M6WB>0k&`ttiQ|}19fD4)p6~m9j{<* z1y@`V$ja(AGg1=mFz;vA>Wj$!�Pb-sD3T#i-$#mLv%e+*m8y*o_@v&RxI`T#Nqf z<+>pClGY&5VzlYK@XoU#X^7o`1I~_s!>uvp1Z+kASl0K6^YwOsEs&L-E= z;#7Q}*dw+H{9_C@602UHX zP$@d2kf>6qinrAna?9v7xr4+5Q@`9#va%ZZR+>Flu8_nOT;brJ%yoQlMGQPy5Wkl3 zSPs#q#?;vG+Zk4B!M4W534EKk zHi-d)L0u^Ga|o{dfc1hASH7g0L&*FPo8YagAGmnu5&J51=Tu4cI4D!FN(oHQmv*FO zE4C;ln@d`2+QCwKb84aEE(9EHRp2yl9E|_h1xm`Ud873E;#$`X(;`!us8{M5Exu8Q zxPhDKOd9hN7*DA_Y@dXM>DRsYTJc^$ZpH(p1m#>C_m^Jn`pXyDnaf$icx2_{zgLL9@1ZrO2dx`P#b3Zv$~nsaEf}=Ns=%IyjdPpCk;U3$L(p z2+Tz?j1-QQC)XwnnHXDyupNp^%8vb91G?0uIXO#eu9FZd^ajL_C#ClLmS-M^wON`OJgQZ(+%;ciS3$pM*apW+K;-C{(89 zZeEmKdYOB5>uDW2dS|+Yvaz5Qp!*(yYqYCilP zsJO*T5p(ZoJ?ofUFg7+t%l#}zqBpE=DUYtV>E~ae=lf|pC&1ooL73NO3aj+XS0mbZ zg{00NOIiW!ISxA55tEsX-mMR;^_bnoX&STqG{;*%PAg#b8o89{^4PRL+-rxPqGn~`}m%=gn;(a~5K}41JpW!`TLd-ZAD%)eOJ?fv5pY5FcTv^-N zGo{B~d-I^H&t+ue@?Bkq(2}Vfv5=qfdHVTtmtc_&3dUH4E*z>Hb`(rov6(sJ<&2+e z4-(8~40|#-1g2-485?pzPybdh=OrZU5|5^*5u>#lWR_6=2c33`S{twD6g2xFt-8$@ zwzRlxW5Ne(@}8<#t4_&xl_OeOv9YnP5EF)`uJAv`t8*#MBrEZWrrav7j&gs&g#R()^!6(AY&9g!cIMfdjp* z7k`z*Y(!{9@6{7)*8}$*(>mt(E93H!A6*b#gcMAGb=R*m@I9@Xceo5*$?OFma}sg_NNJ@l)_43J&?~y}S#|0^FBnb|0feW=0F|!2bc#wd;KUdP|LV z(vlDGZOoIB`N`n;U$Xl-9hzUue6cfrwET(iLR>8~tM__OHwjyTTll5N+( zWMmCFs^fp9x~_OR{wf230`aQ;KE_=a@e4S`5&zT=cB;1=e)>u;9Lvr6!h117fl0zMCAt3_?ZmMr+t#0^dva^Ip5jvVqHNr?(P1CrK zoxbh;vr@2`yiBjyRP*N++6yF!j&t0ScEUgKVUt0j=dKIx`4gEU0}G>Q7Ups>daJvT zni)+yK+eSI#()iwH=GJP0yfdrVp#U^qH`DjVI(@AZ#Ke}y~JC{pAE`>;#^~GFn4Cv 
zPIEb1P*6R23?#BOk(|iPesLnZc?|Klcjv%@RzS*vH9$iYVLD{dC0*5=ZT@fF=Y}(3 zHf{(5x}oxT*4i}}S+Ga|*r*;{SiKaoRQ@cd=UK>zAN9kAj6j`BixzEfA!uSOhfM2- z{?nd$9kq_}UQgm!xAyRYrSZztK9t#ePSqWO7A)zJM1ePXq`rbS>&Ft`LRIg9UKTo= zyc%Bh%R?@uPavVowyA-mNmk8{Hrl(l-#UtKutbYoL&69o0v>w;wsQ7NW$0xiDa|WS z28(S91^Dii=650~G5|pCyYM4%Ib*}>6|~qD%XpmZwNgmv_xJAEb?-l>L5w(Ok+4;G zv?F@T;7k!zz)Ns>F4e|=*Ez@vt$tiMlT_F@cUL*>Y-)lPhOQn!oc2(Yl9OdF((as! zi7KIU+U@!^I$=Qo*r4ZoFLKx%=nt691j}HYJ3L-}m{S1~QNfj;My|PynCVPfUJTfe z1hXQ(J9K$qV_YYrZeHMnG~krqF}`0fpY(Xw8a{Re2R`|QJ}l0c1Gau>uBUf?1cbxb zxrS+z?>)F`6ko`BUz%z;D8e|-LpFN#q3^36q1b%)Q}%ABd4_Qz{7n7mHo8q4ysDeFNwNnx8e>Zgl4R2_ z<67VN`)B@WMf~_h?x{eEvbqS0SXxMIZ?^@WugU!#Vmii@jrBr;!G!qv<6b-uZ}jy! zDDzx4sre8zH1k;_eiyx^JIs2~)*J}-Rcgb6 zwzfmkcvIDZA_2@?fdtmCLQwtPVZ{6oPzS2wR(Q?!m2zJejQAT<)Dpmyb0R&W=|&#;yuA0n3i&=$5k&+I4UBT z0OD(TadOwL)e8KexpeWLu}?7~dIYyLICCUvfOtV6+U5L7O$5-2DiiWH~mSy})DZ@cw*K(QF2SqP67i z4QYT?aV~eUfcA`vDt^&K7eJGgWTuR}s z`*6TGQRpzOy=00@jL^rWXC-L*c1h6CfGL_Bo0jVQX^@-v>X^VOa2NhLs}R^VvznDg z<4gc1^9-Z;=aOSEq*~ueBp{hbb68bCjD-RHTZo>i)EuNb7DM7M+?mPk6jk(affz_n zJuESAv4viA!ht_4uM>H*2$ltcI3x_YrQoAQ znY8i3tAMlKkb3F|A{UKJ(^OlQYZ3XJN0@z{800ZVyybmVYT_Qw z#{_iUt(mV(s~z4@-8xGt3iN&`n8yyw&G@4z?9SN{h z(vkHF?zha8g>JNyT<~R4`@_S~7%l*V<4FYka=@XV1=2`2| z#hEU4a6hBJ9T0FxC8Kv;S6Z^EJOKK#TtbOlnIADdrDb@cr3;WI+j;JSu z_@1zMXYjKuD^z!6!T7!6%=L2smxN z+%PEoNf}G^QDsPy7-wr&WB4YmmiD%?^5JxKM~2mb_s+d&ybaSB50Ds4)h|E1yFA#l z>!9Ojsi?HHGlz@h_p^JUk4f_{8n&6j5r*gI7NRS@>V#3}=3Yihq#GViynevEpb$*( zQ14g?neRo@5x$x;Mh8tpRV%=KF3*Y`0*qCw`Xp81wl_}Oh`r`YI=7KtSs{ZlNcY5qgUoEbs5JRUz5O_g)^2 z2@>edghTC#1VWJA#J#{%i(sbS`r~PF-aIQoW(1xPjNpbxxB>(OZ+%I{($a^4^qeUs z?;dqfA{ASwF$I}PhpKn856_-eIyy~tHa30vUKbrDLP43Cuo7yMdAj=y`#qY0FdI|9 z#M51UQ$5DE6^hbV1g55@8W2XzXk@3gj?!7RG)58+{Xpfn7Qc$zXtiSu^cYq`1#VqH zhj3bc_APK%+E!+F0o>`;t19z_CKXRt{y+Ap|_Ca&K2GJY|nSjcmLdVbboN|e64;NWOw&TfK{_^iIIajms2O9gz~=9JgM=8 zp-BriL4k5%(yxl|UvqQ_nr+hzr#Y$G# zEI>wM#ry!sfckw$7<`|z1ft$S!e#JL?EpauBOl7=X#{O5$A962&F08IMFWY-YQ=t2 zUeYr*W`|_v&Rl_B6+id}(@~H^r48+KS6hQ*_=Yg?QrS@WbWrVS!C&a=nTwXM49dFfs@x(vyvn$>qSLE^cqYjNe+ zXs{S#FIu_M%wrs-PVpoiadWAHi~rlE@$3(v6smT1fUhn-E+I2ItA`5(H3T7Z4V)+; z3FLCF5ZD=>jZaZmiacg9PhC0*($g4CtVDo{0voi4^_{u8g*S|c{w1}b%a0ojAgPjo z+uk0sHEffH9%>_&&8=+Idar;&0mt4&hIQc1;R8u?9~}XAI2W(N%xEa5WpfR7Tmm`W zZZr0XPICs^-?opyU3Z*%`wH&<{Tf(8(M~5p85toG*o9CQys=u!-V~XA?WKIsTupe;$Ix;nx~M7d{fazTNo z^Z24yqTUEU$N)HQCDkRhFfp?<3I;ev)F7x`-MzW2a4ptar z?)d!zlmo-O`S{Mm` zESGyH4c@*L(SCW5NDw4}oWxgA&SpyJW0CJbw!z*W?R0~Ts{p!OkcB>D>d!RHj(fha zdj@_5B$uD{87bu)!)VJYL5(wP+|Su{)cwA7n$Z)?MlU%rx$FV*&J1+?e z1cU5CkRrm1=wukApHDM+{t42L2WIn9_z}_(BEB%^jg`XJ?_%G1f_kri=18?=D6SCm zqeV7F=8C*{Y2yhJ%Q@nzKz{#`Q36YY>Pt?ZGd5|5;XHTZCL&qOgsvlKSdOFjOh1;^ZH>0?{`?G0H*s`EwwV`_+5Q zNM8R(XvTn{Qr`8i`F(KS^2T^@fHU_ z5fHDW?QN(^4(d;lrW;awzD}I=#p7Kp1Id#+ER?3wB_4EA3&-RxPO7DRRqTMD21sQ0 zdeeztbHo%C`?>LRt#@?*lHsv0K-rU7)nJI;V zDj)`(R8%vTh<*v1s}Lp~>vz<$vjd&WT^+h9W9sE`K)~}Rl${el2#SfuCT=0SGtg|1 zh@JrwJ>W_a)$J*Dt1Ke8VLuj(0T;s6eU9{3hoFBCEllxvhbb-mu-@1;81h$@K{N1v zB#|K^ll=0l{M1a&I(8_dDMuvj6{KhiJLd@+iUsV^n-e@nyfaW~! 
z{7YQGe{GLASB={M!zM{O#^q6`lNAc;v=|n2MbdNN31dMdF4+G_rQH4U2xPK?<;&ch zD#ZI~gNF-#rhI|l7}pg-^M`y!ky*pH5~-vl*CkOUO^{W_bp7DFckf(VJY-sS z7s3+3#6l)H$5>oz!zNgGFl>EU?u!db<^sb|?_)&Se>MB6Y*MJ|mFJUPxFed^K(5c} zCq`2 z+MpozBy`AHH*x}(-DDy|T+qDA0hO<@_=4|Lyxjm?+$WVa^iprsnVwAflP5RbsmD*A z?AEtd{oWW86l7GmnyMy-TP)y=k{Zj-RSOV&1R)M7D2wf5nAtErqR7ehI|LobHONf!b2WtuKwpZgm{2wOT$+ z?#t`lZ|lHM_l(?^lBt~w0|l!ri!rzB8sD`SyO0aSfqu6I{kCFlcYgR6{ZS>gOptY1 zf8jLfd40`8B9&4Wv|g~jG|^CXNVw4?1#uM&=i-~L={K^A0eL0sPl{*s&lDse_6d*W zy%`K`()NWJ7~X4OKju1{JT`I zdVacQqqob$g7a<`Qn6l3JO~`fyQE9|V$%Uv1r2c#^YIkdjqjRw(O(<+1`VH8Ods=T zza@e4aW-8_$;!R2>fDzCFqW2qLAuW9r!V9>E1}BR*VIgL8m~;{acAV#erR6%Qh+2$ zUi|&t-HTonF9rYJn3r$;ogJysKB#rd==(+M{sQQ%Guw}OCWs?mu`ngjyT}CuSh?T$q#h z0!jg_$Bp*V;l!KpLynb8f!RTpR83qJm_lW?lz)~r23QLMWm}lHa=RhiX-tNM)Hzld zR|SrPP`;W)vDP>M2GMokaas0~3Q)9J_XqG>Y-?r;GNfa!V_c`>pwroJ{mifZW!eI) z$G$N)-Von3B?r2^1LsHRFl?D*Uh7T)NW|?Lf`W9Ef{vq#Q7Mp7=WZzutm^;7Xvs13 zx6PNRr(k7kLu@2tWm@NeNU8I;N$x{}LD#!A$nD34zM3D)u8f7uob6FLbmns~&MPpr zOrULvU|Jci+kX6AiADS!h~*$rZ)}upFf_ve&G0L>r+~;Qt)2z-ZE9aZ7omxR+=cSv z^6Adg(B0kn9KMiqa18&wgB75OS|Em>_!2`|eY#xRs@NfS!o>Z(v?elz&+?77lA}J9 zEFG^h4y;Guinu0BjgN2fbE0A7!JqS@^7}$BMlJGKVSIW(C!uy!T5TNe083-n(DTWK zx9l9yvv5zajJD^g<&|j9CoS@e0m+7cf{^x;4rBA+*D5e+nd>1FXXEE)-cAEK(~dtI zxnK42Y@zzk-hAP`cw4LQu=EViv#hHy2$}C-6z}}?bo_URX_vR8uJOkVufdYRc>Dl~ z;I4EEh?Xa`taps)$@Yq4r=Ld# z<9(7rp-#~>@^J7wDEj32O?{YPwpXw}-4(1f*<1|_uHOM4L;u8QnZD?msmA{Mf3*8ky5)P{~KOfm#9#jG(!ABi5HSxF-RKEvGRpTM$>pRSS6TGEjtFhC2 zGS&)`S4r1jHR7?;aw@KqjRk^^`sm8Ru=Y2g2o;F7?Vl#An9jduB%lSnF>EhZ=VeO0 z0^fcPv8ynUGS+_^6tEXs6oaz!R`^V_EUo13Eyq&r=z%EETG!w{jbcr390EkED1hQs z36JRlD7`+1MS9Ci<8p1OU=>>fKTKpF%u5?)^7tEd7S$1YT^v$|&tdt~g;|q8Ia+#% zLJk;+a4_|6-|70}u!ZwN5|m~Cd+~NiIsh>I0@Qnezj;zJ)>>ZikdrC?!!11%Zm$6! zLtfAI9B==}GHV?r3{Cknf2rdws$N3~&9Ah%F#ZA%J|+RL6X(LRyNKY-W$Mm+s5^15%C&OvW&nvhfR<;ZN-_$fZ#SApF9ykO5WD<67cR4DZ)8U%O= zo3pV{z*ZE$Rk>)C|51{ovP#w*f%LtZk#PQB&*PT>c*P~||! zOk#d&$$3_4---50I{bBhx{Mk#lO)LAWgHT5x2vCASHNjq3l5B6Yf1r?stsuG2MUT4 z&pGtgOYKQExX_#$IBk9CmcRCSB9VD4xPkR$s~HRm@d8ARbn{NaO%1t%%`JGv6{OQz z7olkmvekWxo;e4qLv!@4$f!RUI>v^(5cle@aWUPx2`YX$P1kGWU4~m6(40f%PM7_w zC~@ub#UM?08&m2t)~BSz5|?6G+x2SMm1jbA=~$hx2Gqvn<*CW_y=-IeN7NfmJU*xm zb5`>=>N-c-bF0w%-JQnGl9{<7$2_7Nu+6I2$G(@NA4%J(?_i3LTNM!lhhUF^%vh24 zrRQnhGl|bTun=m-6EM#OEFQS_>0bOep;Qj3@`U)%;_lPeHYPEnV_{0#&Ft))^<0Mj zf@Q0H!g+bo$cG>{L&JrFf<)m$$H257^BxX>j`8Soz4-B-6^_og#iQw^`ExZTsTl9S z*VD6ynZDFulyhJA&fGjMupS#%diDjyQxDWfv|Y2ZvZ+(4*Ow(WS5&cN4*A4jW{LsYYbOZJUl!q7V4v+05aG6dpM-<)vQxEKl-<(reHW? 
zH7pcOF}-KwcWn<7EE*_)_aav3^u#Vc#O)a5)LE#C&D+Obu9AF&o0D#<+k~0+%S+y1Mu`!d2QIuY>>f^d##(9Pv6g!-N%D_j< z#e#US6;ctRM&>Yja5;((Q%d2qYiJ7Mv>un4LPaXbWFji@iX#KrwF^Uv(xh)pO!uw; zfp8tCM%KR*GS`C@19dk?TG@k=%3BY7Fza@dyG>J^iba6fAWDMqUmZuicmf99bmI|vEZ&% z?Gvt@cIrJDLMi@Yq&-@L;K&@I`p2R0^YTZFbfB6XmiSASlpdEhK$Ze7T@z-YCSv-U zCtWxX7n-ZF6gk;VxwJIJ|FxXMkyUKK7< zk4lao6akK+11Cv`90pU{6*a<|@=IP1`X7;y@kp?Mo!UGw6${F(K(Jo!R%I-iv-ZUWrMoU##y&oGV(6t?5Ea=Lu!~>nIDaX^o#YK%xhswZpUfvL zNRzY+g962ZyQea;zJ*R&q+66;H>XrEDt*Bb5iG=jn$1&jYzi`-{tU$ci>VbxV#jRF zJp#`tY_5eeB0Go&_Y#9!rRd>C@d7uqy}f3P;AMyGYT{p&8splUd(~mK;sS; zVFiy>Z$Wvc;mZ_vx$D5hzz8$2uLhqlQX}p)Mmyd8X>BerS}*>}K$TiIe{Fw$N`N@2 z_3vHJOCkHy^qyyQoy?yOOlAPbSx^zi$_+p7%;k3JMB3QJu2VoLC7gZBA;91Eq@b1Rct^ph?`46YRi~4Vy9VsfOuq)t(9dDiD!_vGNHRcf zwf5A;b2<;Os%&nCgDAULnMxq1nw<$bh%F*s=}NOG(2zmzv6i1@Pt^vQXB_y}LP;^6 zFW&=e!0!#GLIKuT`UY0awT_w45qkLWVa_dQG&C=5wFPIdAP>#flN0RctVIVaH57!r zt;~5Cfu`usny}fob7d5n=L^6zLuG%}+E&=nkFr|@3C~LzypjlHPg5f)^_=y9-ALm@ zd8m4t&X*$?z87&87~QwcV9iab4OwZ{$6VK{7_wFcfL0q33fO@#w1kpvVVOX5(eTbL zH{e`v2mcJEV6nhpc5Ym5WM-Xt1Cg<|?&^r8`P%i|b{#d~x*!t_0v_`$zY+&~mmih$ zgLC^p>>O}RrqXNxR^)cBp^Az{hcWmM;Nq3f+0*hhO%)!Xjm%!hHZEXG)1N`8wb~$Q zNgT=t3po0Y!SgFQ*%q`qG3Z}gcfNi7>Oh-Pjr4S3#oi0yXT;_rJS9Q>=9=!fx`fN- zmNymcQctl0QQ&S+Nw+>Fg_j}H);7O#-oYvce{@UHZcoP}K|nB%`dmc|srbcR;D;N6 zgFrv{3>Df$IkLVvy&ZcU7d0Jq#^>D=BywX<;;v2IfpnTq8VU7i<9nu1XCLplu z2Ab_AB8HtaY(NMiL)PmP(z%z7vYQ>)ln|-u&`K*;{;HTzzUwy@BKLeD%ZYpPO+R2o z&-Ou{YvD&E_ecxiigitY@+cUrkH4a(j61z}S^Y6joJdXvk z0lut7jO`zrR`Lw#VrN-O9#N+cN>m{xk%QI!k1~aCaeA5G_ZpBM)rK}pMQgw5t~FRb zI~c6Hrr36-AsUccveeZ$K%3QUnN15X_0`I_kb6%C!#SHxiH!Iq5AcTzLU9AO*u8W9 z(@QLyF~ftov16mJ;fFwI-XqOIS=bHffYx=$Q4EwiO#WP~MWskIFhoE|m&tSv6yJs; zbYOz*@gYU}qjjnx=6MyuPi8&0Udx3Oi!nGP#DWwTgH4eRkn)d1COPAb17UbW1N-R1 z{;{INUQmb`u(ylPPTFGBMG3CB=!%0GC38@Wj`^BjMBR=af>KM;D3akQEo zO&>(MyipoLtZnTd4V=je9jOVmpXZgL_nubMnLDO~>41j{6#P-OS=W`TV>vEOx(f`D zg50>9gq^1fKp_3L_*d?ZX~c`!2=fEf&B%{KkrcJoERegf6U6NXZ!?0{A~Hj5$O6!T zdPka|(}Fm}jgQC>?mH#k`v9#FEf5Z|EekC*gCQyeyweh#pX%4Jj=$w2G2T)X5-!Jb z%U3&RC>}8e{48VC*BnW5DtFnD20$7isa4YIyZa2yw&i5+cN)k+HS0e}2+6XY7df35 zS@=^dJg**5uvacPYx~CYGKdOQfTsbTc9~1P)1*lhfybC&|E%8U% zrmQeD2|(Rn0{-eU-I?6Lxt1C9h(PDtS#F5uAT@&p2{u5QjLQ0GBSnl}rw_Xcc)+w&)m!+#UD0Bmkp@nEmipi$W^1}7_S z_t^bVLc}5J&c;C zcNCCZ1;?HC_0>XQr^K&-Xxh4%>%q6-g1h9S5WjFI1`ij{^7|nKQA-8K>k#F2z`b8> z4gq_DfH{&WZ>fIB+!0EMhp4Qd?BIugv61n*ifc+BcbTn~dUe*=^eyC=FJ!Eppfn&K zuWB`Gc_T-y@nl^2c!Wq1;F=i3)|XwhjoGH>uHcMIUgUHelwI-466)ntWNi3*D29_*j?%4!{6&aORvg@=*6OSGJQU6^WoE^Ob%% z3Qm^5 zl%2ZsA*IS`zyjvH#)SGyOTfooxImk@z~arPZhO7}4g~#}m-pb!3#T?u4=!FgHXrkv28q?3_B`;NTWap3TdBg&`tUO95*_>Y+J4eINB zJ26=Pn&WmderaA|7HLzp6iaS)uDS?Y`k*WWHl3J|9YRT7iTGz>yiy0crrHu}*tuTF z*^PRK@V@3JTE}ner}srwM!p&lZfgOUECB8PV)m2sUc8Ji8QdF)Q(Ef$sw}hd!t*on zW+HyfnNP4uD&nt}m%X;<&r`pB+e2ObF0`F+4?PAYWiSm-tNwn}cmRjn;=EUg`P$Xc zgYIvjJJde8dLuzEw#)IsUWkK~%7P>Nv$;q_qI<$$Q?Suo8y5N?od`E$@Oux{{lV6z$Pt?R>TzF?Y-2$ZAw4iYa6q4Jj zjMtGqs!vCjKPW||TF^#cgWGNX(Bj3Cf$}m{>%YwTJ2eFSALw)Prm06F{)FX;ht2>< zGHp5jc)P5YjYVf)h=nhyv#nSNRHQsHS7au*XK#TKFE66wa8h_vHez!8^Qiq&A?98$ zh)0QkfC#G^yNhr>r`;j{{t1n%SBOamX`5N#%>u+^zwz$0V%v?G>|JGCA-vSjCG9hy zzFhr_!T+HKXaV!+ou>D+Bq#E6pyPi!jhI(|z_afZzUcDF@4i zTXYgVil59d6Hb8-3jV}y*xnK!HQxL{bJt>{gRpZs$M^%r{%1>7z^Mu=o05|g7I0j!*rpeHZQ5F zT_y&;7VMPzL+$%_X6DMq0DP6A|BYv>{)vJcs(mZ%bILLRnZ>tIY;o3z0=xX7X&qvh zo}yp9LtFK>>28uQi>79m6#l-Z%dLC#*yAoZ{@8Q!bqx7~50i)X0+|rk_&IsUhu5us zA4k;s>b!;%JsVp%Ug`AE8EjiH4ZMw9JE-yb;G70{jn@D&`j3|ankB4#Zc%i(>xOa(EvfF&W%htq!{Uxz~-q4V>k@#1Q%bjxMn_z-j*D3g{m8-|!Lrh76Af zS{xwUxn_g4uWv5@DW&^ve_O>!SI&QH+5zncnSQax`Y59jRf26T0nDpId-p&Iqa_!5 
zQ9@_Lv-h~`K@FRp&f5EL0(ZQvH4=?~gX5eUSeuM`_tcChS)%x5$E)1$)Ak{+$GJ7* zqahp1TtqKStH}71Jw^B?5UB^WiYna=(4T1wmnwbK(5u96--487IBjeBmyN;aHH)sr zp}NzxN9$InvR3)}c9%hiYMghSC@1T?C8$BtwSwC4%Xh=N3B+bk#+hZj7VKaq+}sk# zAKF`X+B6#&!!Dtu;xZxk{qE401l=C(sMD0UFi(1cV{hy)Na$AKbrL;a}nZuwbVtJ-_bt7l8^`W0XABby3$(ajO z#NcMF2OBuyNs;trQ!sAHJj=SF=$(i_Fe&A6C0fkttb5#-z~jTaC_dwJdkJq{m(GxA z-eD1n3u=jWaT|m?h;TOoEeTIK>8En*{c0e0)k53 zA*1xFYu4)#85`G?kFAWiL6)CyYk>sm{!<6ee^=<^eCr=`jN`St5pPy%`!N%{%z;Qq zXMxtDXiaI2WiAgA*OnG(2_dMVqr{1AD_~U(ThYMDiA!W=b7UfZ%8+#HgAmo43wlzh zUUBY9IwY?IGW{5FLZxFaY)x|iW^Hyh_ug6NnynOf0lMAQDCJM%xLJ;Li1LZ+m6VId!zDktCy*Na>ngG5YgPLA<{8Fed61b-!~OE; z_Z!A8a#zBE)o5%I-cArvjubP$!I&sAtuFtk(rjmGx~31D4Nqn9WK-os=WhMX)Oe%G zDkVIzGFsro9vyr%WXbznn> z_jOfW@=Ma7o1^dN5d%J{}cu7v@Yd%gELF2wxVlc&F2f1>v5*iO-rO*QhnSUAbjX}&uv z%u~jfaQ1@aub0W~8v^sh#DqHW&x8vw4E>oOBlVVOvn-7Czrc^VB|cEfmpob*41_D| z!RfjrfWNpqFaEp(y+XA_t7Yp%#_n~0Gmq-?vK3%>>FO|fRnL9yY|wj5RsU_t7=M)7 zl-k>MHam3>V@yn@R=%t`(!Pk#pH4AaX3F9w9LaK7-Yf^-<;;L1OonH^{$A$-we$C7 zSmCUgy;cR`EvEb1{_o!W(C-@dcRPRovcBj4H}dyn{QBkp&76Fbdgu%kbz1j^Z+*^e znb_+*BN5T=h!>4ZpVLY|-4)W-2=}e^bm|TamWsS->@zsBH8K+0d7=anT*xVkeM9ZC z&YxQT@StYn&d2f{zJ+a#g~?xO%{6i#Yt}= z%0oHO?y@b9)X81z5;dnZ(By%1RjinoJYUwexY_I!?Az$pVjlvUxzfaTym=KJMw{tU zmVr#sf4NEmabHiUaa;N1qqRK5>r(5W3ZsaMS zE`FvKr@_^ymiL923UmKdnDudMB!1{V8P7 zDZAo(_RXmYhN@>eQJF&~0s+(ed-dw=nprLp_GxEkD3Bu$eR=!iGRw8A*RMYKc!W}K z!QRg{|17-YA_Efrm-hn9clZgS=p>lvl@7F_AD^jD!TKi-#d6bZF^j~0=Sp2;rPkx8 zI#?X=PVIdjw@zu?RZp32e>QhtZW)JP{lRIkgG?ASE!0-bBrJcU7TYP)pD24BZcFZQ zlvZ4CQrc0*#k74km_RACn`u=eE5cnoe_Hnc(w{74(3>L`e6krrD9p`q_DO#O8=pG4 z%H=+LMAbzLy0lBg$*-`TzEb@$ndl$jJ6aw`t$mB-mUh3b-2VXdoc{y8Az=5b?M^K) zdg+%bFNf8)uG5}~w_uUa$WH^n!xbPP;+4|KqGsZ42i8JF$HkBk2K5Q_b;7s?Jm9#c}4{QaK9cw&IquWWSc*38Z{(3k>2|D_vd!;=$E7GS28qjQ8XZ8@_vzm)FXi0vg0oCtB83_Jf{ z-V~4y>3DTl#>d5qxGn096dcs;qFXMVJ-TTx_K-VM!8h+s^*q^J}N1!Jneyn;kt!pvH|d=G=Q@f=<@)BXt#$}( z)p}8c!bLpEQ^b4OJTe*P^ZA+_&jU02rjIs}4z=@NcCc?ZzjAe!9rV_A8xO5Yu}$m0 z^cniI;o+}vW}=s7gdB?pFHBxPj5^9+b6-2S?qgi{9MIq_o{nckP_N};G8wn^r&E8` zt+9p!`k3CF}Qd7_ZNsC?Dmf!7XL=U>)x6Vj;|R^pu@J&uJui4tO<0y~R_9AE7S>nw)k1$OUJ^br!AZ zqC&=ywmBr=Kq?cxs-b<;+dGx!vDjlxNfw5>G6`Fg7ty>r+jZt>_iJEQ&Ye-XV+6mm zemjMF+u&L6WWk91EM@xD;_A^MP4PHokF3JcRGvG>%-I%h!5eCGbhit4(MB)D^TJctgSGkTi4VpNZR1LTL5%xvr z1@9q$f|Sy^(xJ*y*%%m2c7Bly2OeJB7YS6P5Xm3ilI%rudkl#&^^w<)POq z$!SPvfbh($mptg}WtNC&<=+((kHO1mELdQvlAQVw7mxREHS zt=WV8o%nIY_pH3d3OjV!Vr9?y&YqW4SleLz?AbB*u}?gv`&OhVF+EEC($QKN^L3#b zCRYj#q=#|}xu38vq5vgkfWhbjYa$<04^OuJ)?4zmSK9VlyL)TV>lSOhkuN7x zhNQs!;~*2R;~+y=Pt&)n?9{itBD`T^Y;pBEn)h`aiWNy!O5 zXKLCN$|tL00R^D{^UB;|iQaZUcOo1&pL4iox1T8nSP1HxwiHVlWWQ6kIctHcISuf! 
zxO!8c#eE}`_wwt!{F}aI*-_7Tqs{H&nL==a0G+t<(q3llbgM`S4)8s;ub3ucWN_tznngH(LQxV z0n(=0i{d)P zn&)?)!L~xr0M7?7XYxoedTzbLeZy~YKhX7SYj0da`DH|A4)=G7Bc@qxRM1)wN-z^} zGpqi37$bk*D6E;DsJ@hCR;&tvm)MKR#Ly`{1?QTl!BWm^w}67Q_|Bh=WL6^ZX-Yzn ztFiZ0x)Pk~7WqgEkp@f{viRPITO>H>uy|#W2^byTMS%W6en;_I8Oxkw^iG{h;tCY$ z3?I~g*VSlT);{Il^^WuCk-^+vO2WvA7o>hIzCRzcvwOwMcQ!O z^qgJy%1jm=8A(7ITWd^lL8!}{Z@5GVlLI&EFutGs0h&COu)$*z3cVw=Sx15FR0VPY zs>o?5b2mBK+6W1x#gZ@WtOo0l1!aT-n#628B?)gZ6T2o|!rMG~1vsEs1^o$KBQy;7 zyqWJq;!+|O^Tf(bX(G4xGx=`HOy2}`F?gThxsA6*vh$%k_2Q{QI z)`2)u1<1o;;X)~T;?}@+boYnzvgV0-&(i(SGwj^_M~M0mk^LVo)4IN$W!nrBx3XD3 zMCi1{{aCpGq}#cCMGa>c8Y?^QSqZQ~SGUfhf;|oZml*)}v+m zSE;$6Q}eYM0ATbEg!FQUL%Kc)pb~wGk>M$s{fhmQrmq#^!z@1-N7EO^x21SQ8pd2C zM27MzS{hB9sa{BDeug;0+na;H+{$ZCmwzDrV>3fB)HI{?9=D^{}(wEn57mDFS z<~j;X5AxRpzrxO5IN3a+U~Tz=(?M*?c0nVh$zxTf$(HK_o7%b~L2+E;%twOY5$%~cJxgMmQ8PUv0?Fa3~j`bJfprAN1*LLJm%j(1X0#As2j?z?G z?|3rqm{T@}hHnT|2?8F+eyK=2YG+I+msID-qzmecWzBf>(-!Gvpno-Fr+noB0L0AK z5@n#k!kzx|X@t%%E-$qD%Vv|Fwp4L~uoOhd z@EXhtxuW<>BBqpuKKGI@_A@G0!2k^IYPj~hnC`$ZxD>StijqKx!W(J>aN31f84y9Q zVWUix;6DC?j+r2s8j*LVp3=FPOe<;ko&a&PuTAgAO}zEcHkOo&4}c}&P+0kB=>yzi zOWBh33du=7r7g6NI=&5ub!EhXlU=@`rQ+h2kGqD1&whW-Eumm+nQ%fZqy`|)6l1S@ zJTx-^t4ccjBl0yj`u9uDeE2ZUEx<>dU=lrf;hgDrE)CWu-&6JSv6#~gxh!VEk(w;|1owhMk2}h+aMnb*9b3uyvpWdJM%OOTd^HrNFBHR1?P zqE)hfApyAmXxFmNb)%F+H6#qM@KF4_jmkFQY*#RMb{oP78Qh8GSH~ zwFv~?8e(&7t8FG0{Y8l4skNX)@^#88xBMpzvT|N0aE3O$D8+rz-Zqi+E-L)Td1Y6? z4LC|i*Ep2}u+B;;oO;kQeSUv7t9$$~pfK!BXg}_{aU%p&pLlfuEs9;mV1AI!o)__( z8bT37hf!^U)EHvAw0gh ze*nP5FCWq1$d@!RGR_fPH%`UAn~Fm$h}Pr*OD5_w!5LD7{S_`jWDXz;)SoTM?K|ty zus;~pEP}WcWR9jvicJ%h;>)g5uO( zSz>-l;jR35Z1UWDxGJD!xFeJ{C`rJK@E4O;I8liWlPy^X&p~kkSorDTQ$5`un)eFw zR$%}_dqx(3xi*ea0J!!nWz|SliC}?51vvGCob|R#~O)<{dAl|r? zxVE`k#I5ae(A=r9y=vnPDmFfj_0K0Es~cZ0sj?2qQ21^4f=55lDvhDHZ+pK=3wyc8 zpVF;K7H||z%GGFhDmwWVpQ2^fPGl1d+Am%9q#&EjpVt3-IETt85ck`|FkUopk zI>qxDSsV~NI4kbVnUFWm=VM%e=?d#^X9pnk+FFuN*npkvAa|S6myyZub2vT=-_B&! z{Q?&TQ1toEn&hTs8;fo819R9Vsj+&rU7av_YUWXmc?c+o46USpl_v|URL8*FS}9kD z!_9csc9cPT>HqVpW*>`EmSGA3fgzwYWTlq6zbB8&XqTw~Bq|{S0Wis+f+R==&AnIn zaQC&+#O}2j6QcyNB0hwqsnt1H>f%29_5*Rqt!?O471ePcoN_)7O0qHOejS${%F##% z=mQ^E!NgY`z&;@Oo$YOo4x5;}YJ;hi+RxTPpT3@NN1~|RyKI!`WCS>v5x50J<0@*^ zY!o@m9Q@8F;h%%^C@>OfjJwb}pTE)S$MY6}lTBHXUF=C*^W`h)YXbrJOpAW#VY!u1QbLUg_D1?jt%{VVvCjbRQ zxLHt>4`6Y7AE`}}G7T0R0xbisch3wMt9?WwV1xaK0S zA(f4$eDNwL6S(^o6mC&{4x6a5hQu!c4(%rw{! 
zYNiSnFf+7O8(NcWlFin2YWR_g40fizta5ZzR%*#ir=0vpT4i69iN3!1t5arW_chel z1aJ{+Md_a_`Ue*eB#*0chji<-`x8#p3P*RgW1l zKNA@%Ui`36Lm9wvTJa{sh2pm5qygs7c1$f}D&IS63t#dbS|mRhD-dVkVA zlXkI+)lw1gD3<5V?_GW$QCX1F^>n99p`Z(zfzFypUNW-1?Nv4|Iyfp>z=Gda+6>$W z{KVPQa=FIe#wTG*(NDja1MIra*H>KJrBgW*1++fUDinSshKD zcMjJ%s{v=}mU`18GdLyrc-OLosL~QbHJ@ylNWNu*!ZklqH)Hwk8QEfe@xaA*(?uWh zl{*-Yzkit85@E?j%#+<3k3+ld#!WDcVB zHsxS@?&Y%@DqmEld3kTr%N+a*0(JO5yR7oI4;eU0vz!~Kw9D4{3de%4)0M19e8gQ) z!^gs`InzD0dmqce1xcELLKqcl%oTXtS{) zaPK4mJnEfDaz1+h{T3?03&eIU+d283igG_aJP6o^mcfht(jYh}sR#w|Eh+-#m)A)` z#!0L+IBU5~0J_&?@oRM8ILd!>G80bJyKr#L!^Tf;zpmoCR&F2*)#iGL?)#ZR#O~On#I?~c2}wGha&O)sg;UH?n+${&ZHGqlEL@- z^xInO4>_Q_n0kuvN^#(P1rEF+cE$@fI|N?*Lf}wJyG%-XfyGfyUi_qhJY2lJphYW1 zF)zDk#9w`tU1sK4vWHKe={u?}-!tl*0PpSOHsfWhhA5hOVJFS86yA}P;3E=rKY&3yb31UH5Mgu>$y`iqpU;!Z^fzQlb-Dbo& zZpV_x?4jB2Wov-$?*xm^Ty0w}?>y67STSmve_u(RviY88`YRl^qu5Z#q1Za1s6ji% zewq4)-gV+xQ~tBSUE#*c&PKLZ;-e?qfQ?`fS-4EI;w0VFnV$Bvt5;`}er!WPp`6Ym z2{uY?xk9FKNxtx}u!a37e|zrOklX=1*0+zo3MP-Y>*9;NqfMzGU6ZRQJwrjKN2pCr z>HzI!FgbZai9hKu*QLF*9RYNpa zByCHI0gPqv?ey0x2g#y0Z0`k4R9#R{7|9YS?{SPQT&X1W^92{;GpTfGx2?0sj) zll3@p>g6W+rf!;nxbaK}Qd|RsJ1bO~3&L6-yh4^_=b`GD|9RPJ_4!!r^tS`B3EgeO z5oN{k&Gr=7&eX&urLZnKHSD+oY&?g8uZ!|yxL!Eh3NXJWi_NB(orujS9izb3!3rG%7+$cU!7GM^#6{2TdjUZ!#l;T9 z)+UY++u*5;2c;Lq`GO1G&d6Iuil1<=15XsHIXNEVkhi~lHGeUABu-5w2v|1ctp;IZ zfERq{Q*Z-uE;B+t5W^frZF7;DvLNF%QRXDcQzc37L+n5ZXN4T5E6$7(8_Kh7*#4DlMwp$C zeiG$i{KkVBWbtFTQ&AJ7;~Ya&p-AR~GyO*9oHQfT`vc=v)bZzD1$I_hO)6*>fClVV z!V5RO^|Ks{nw+sNhPU))zDm-^T&Y>h4*`Y`d=TN>*^oYP@V?XZn0{#>U^B-jIf&vq z6?#Kh>3*w9T2H^2lMDsvT2$~=96=PW+u5+P#hhJQIn<8;k!~;C6Xf9?Dawd{tM+0C#~&U^x4OEe&MvQe?&F+Gk^5I7;O z01k>*%%P2G>3*!Tn^b|x1u_ctt;#OO9JL^09>Uh0)&KF9HT-j-fV16e_+l*}qGlpK zgvlS!S$(UtT{|%`OG@;h^ZW<9vospdws zb){T!`Z4%BVc#C7)a4$#*mLJRV0%{Q2TvmMQ|3zisWSql>@SS<%`-z>99J!|8wRdl zT%mTeXVT7PcjItq54r!QZkbRTVwT!>X38|cz|X9pt%19kk_I75?Ubw9(Gjryxn)rx z64C=k`Q6(d&8f7<*B%Qf=7MAgttL!QrN#a@-*!uA!$%b5kjW1yBMZv$lM2_u$Eb?~ z8}=Bnel9QOBZ?3k)t%+Q|H|+uhvg2)mKv z<|WV|x~r>0B`ACuit8mnvB>R4;Sc2}Du6;mz~7pT0$cLEG9hn%m|U^Vj>~OHhS+@& zfuwJi~{0}1qe+?6qTm*4$26k z6lu~$M4EKzC14FRN>LG!j?z0w3DPA3D$)!+KtL1-El4MU@a==XzyDi4o(IQgge3Re zbM{_)t+n@6hAOh>%Q&A=EBW>koIu6=CZz&wW#wpn1Ew~QY(~xZD39<*G{?8m>(f7v z<@P!=dhNO?`hNC&O?32m;t$oOS#}-cGuvMN4w=*N?dSq;QeYurhW3a#JZ`s6qjAc? 
z!O;p@50vCMVh!Ab7Rw)na_Z_bLYua2e?-5CdU9P1^G20>v}^S0IyPl>9P<4zDEsYp7?QUR4`R(T&3w3}y|7&nWK2w6qDg>+r^W zIgJm2)B7!y|3ZUzN2!JU6kXr0)Oy2(=v=b=oi&#JDPE=E|3AzQp9RjpPIv!x)rU-v z0jA&UX*MYxb4~vqUO$j8fm{|p{EAucQvlfZoxbOaW97*Nd&kU2QBUI<5S#{OD7#@f zp;4=zpvu~eI9)E=J+jR=XYR2oEZBMR0M<8{-Pc`AV@T80a5JKK<@26> zePe5bB4@RI*vP!@%<^FmzRuDA_D5Jt?Agu3AvPfl_xEll2kdO^ z>qA4;sF0nn;2C0Obfr`6lgAx-T37Ig;&^b_?(WWtLHz8SK9PWf6%HtmW+`J{LydT| zs6^jcD-TBCzs9rYon?)Z*em0eZe^T-?TZ!x<(zL!meo*LcDc>F>1Y05n&8W+7v1O3 z96-_{(LXZD$2yzou)y1@Htu`QpCFvq{Ib$ z*%G*if@oy#vPzD#K+x-`snHuCi}r1n;*xbL)%hw0OV+_hAsW{Ll~YA(a=VPvjp>O_ z73sDoDNnRD`hf=Ir^KZ+S(qmMGp7wCjpK5SqG$KX&Mw{Gzj;~wrA2b;?s3*afI&F$ zBHNI(mX0#x&WI<9aiHVu95fGsYBAQk3tvLdxPx5!_x@BO=Eu>MI zl2iwK`(zo~dg%q8dn8Ad&EJ3>f5f0I*G_FjwcYBd4|xu;t$Avd5-vJ;Enog_%g_9E z81rRf$#Ro!je1DxQJGj_O*`1|w~vFAn!B`TKl|3s67EA3z#RK!PjZPLZR!)7SJg!7 zO4JS&L48OU!c^xKvkll86BM?fJkABzjsc@F`yV*61hdM(ipE`si=*&AB$$g?SlK$J zFW?=&YBr68LFmcO_B6L+csDiK?zulIx!kS0uUY7YH=+rii9h3h5q&UPmc=DEyK2~r zPmcMwJcZGro%=ADn5gtzS;SXckei@DeRHD0VOG%NAcpgX3>AvgfAeart)e*5bLhE&JeVBPM2hzpr$ zeMn>(Z=qMxn@V6){2tF9RI8cf(wF9ONH|!V??3zGzy6XeS#hGBwHyP>rVF?6lOQ*_ zE(nZPAR!p)2&*KdhB91a&n%*#~8&CX^Qy8`-q-*xbVDUCMcV0z z39-S(o?%X(?DBJ!5fVKmqHkzZsy?OYe65VIpI8Lt(LmzT%p;Z#MICbY^hW-Vj#wFS zf@iQjMuN*;U>}B$3+Ht<*!aTu#P2pOA3Imb35HJO3J-lnT2nWgGM7BF8fccNM83L0 z*Yv&l^2FftSKC!-ANvKbzBB5$fARmygwwy%S)S0T3AZqcZ%I2?#GpRd64r(iEW0|B>!ySLJ-a&4ttw=f-r z&C?17sfF4t_b-RmC-Ka%{UkL*33r+x!G{2Ea{chN=}9Sj=2u-%G)V+~K663cA_x4C zog;P))AZgCpdBKTRc^N|GPn2(B0X9p`TtEZF&`#3EDLgn)$v+FXOTE!?oKW7QDdW=g-IvqXIqQ3^mmhd$7t}4yJo_0Nrg)1|YgS^?zy3H-3#k-HPhI$LP zzS=9?f_}w1`#7|r&ZYq9#qxqH37IUT85=z=yX0Mfk9MejhN`5z6Hi#fOoR?AuJv*& zifu;VhzdVGH=EMdN5K@WWKmX}+@#|ZNYU%fD85r)Q=O;ej4zPgOL?AZ=*Spz^G|8Z zNLd=7X(U=@lGs8PrJeNsCyx+@>%re^vG;pqs?QMVy77x=@e>Q!PiBzh;&vaa1s?{o zjYTKJ9-UEr7Vz#Z+Ixk$P(6*Tly}(qk(LJ^Q9JVb=+dcd?lRR4H^9($qd1vR%UEgJ zm%#^i54LJb%#dQ(eP8pz-L@N49gsjg$LM>bUSN?WPb6$RQXRX)IQOYuORK2U@Y$Rl zcj_IJIJgA5NbC#Cp5tExus49GuC1*?D4`^xH`9IF$`*Qy-xAFY~h1S>?6de1I$w_x}*5VyIs{oA$+@5xUI5P1WPs>3Y$VW0^9c zXg5{t|Izf&tX%5n%ElZI&j%S6Jtz3>WP%Gd0g?Z?yJ-q4tdQ^IB)KWZZS@#YX?PZ# zL$4k6c_Pk6*S#mE+rrTAT!NWyubmUw1!n{FhN)+hM}NSB^BE!~Sly7*bbjV}7CLvxAl zO@s0qbgz|(UToLb6~$YYH%teqMZa7<_wEHJx6kd5m=tqV?T*)sn9BP}0S0A}W{cx3 zi2F3rm+ZB?Q0=^EYwf%VjPf*M5Z8hL)e|N(;$HU2S={9B+AfZl0a#d!qwRqgaa-4h-k(Yxi@P zM*D|ODJsh8I7W(P>-(uCc2~p9#>WpvLyM*Se$Jp)-jO&v%i!Mze{*~W;>KyLroV20#cGl&yAvqR!)5u!dD1=r>7dvroX!h#x|{;eg3AE zrL^nfDQEQ5N!Gws6a0}>AoiCgFhQ7V>V){D(46y6=Toe$SKo%qN3o5OtPm}N9NFI| zVPsv;;&(Sk3ER#qd0?4S)zLC_%_keK4m0!%S1yp+B zXW$z^#Ux0@?Ai$DNF2!JO{o>Pef5v9a8ZHZlh$m>VB+Y4E!PYAsbI?h+mBP=8mY8x zM%IV`S0V3H|E#RykAq6N7pHY;+Gf>9?HQmN?a0P7uH9Bh%wG>U+J((1KX4=-pHb(Q z8@YftCsIXSBOv8&gsS(vba_I|geZGeO;K>cQR#-UlUef|a6ncdN5EAn|M)9*P6 zZ;I0nwe*WY6#|T*SiAtFcfg&npB#o+|GjiE+$egshqiuC&Zu5$<73y-zkK`cIF=o{ zSJZZgO;VQ?yw7%@BzogZRW#`j->5;SQz;z9ylZc+0t*!&6V-}-nZf&&Uva#UxtN6j9`UAWx zucWv7knTvE5@gD%a=&Ca|5K!PN?82DzXV!&VBVc#Dc=JhD7g{UtaMNBs!`E@9k4f9 zr?(ynmKP2p40&4i^^Btr=GwTuk>RyLC+_I@r{BC)GWvXpq_c5-GQ_Q;5uUEY!Cmm` z08;?>y~%t;xjU~s#Wo24xlEa|<~LgVir;Ecf=*K&a8x2`6#5TpO`e$Q1Kpzg>^Zon z?rQ9rB$2YAl=2$65%%bjPWIfDr^F{Zl>&Tmqh}VI4O`C_ng0Q<_=te{Pe_*2RPDV6 zy#!!q@~ak(t(e&y%#3FZSM35%y+ikX9b_Qo*&PO{y(gC`NsRSoz&8fTw2)oV`rsGJ z2ary+nt^wl{9@+;XCxgA4i^@9-xf0}le2(#qb=1dbM9`fUYeC7Ua5f-nlbXv-5W*5 zi=gJ-0?U@9`HHAF)d%y>z!&GB`65~Wn+!fyl2chV@&dVC%W?^jCE0O5^ptxBNEO5Q zu6-xN#NPgL@h^dR0lyTf0G{y<9%5X`0Ix8kyJU0YXnsH_fx15b$YaYYBJt87n`9Tt zunomL{W`Z%cx7DqsuH*nW!78qeNM4U!}{q9+L-8}X~p99&)~AMAIpo36K?KXi5^rp z^OhM%gieMmUEit5TFY#_vOeo+6{@0%T)#UNk{ADY(`cT;tFarS}wn 
zNkb%*K7=MP(+EiXl*TxL%*|w_y~|(A*|y~$gyP7njfnT-KGEy=q)TAtEiNZ2js}!NN+EoQip+udX@ItZ*)6jBtHxN(p8#L{`W-91mnm-u6%YIrTsm z!zgR+!RGFLJRC2xt^5-p^reE}Z(r8_g4TNE8`T9S^0e$R4HFMBvu|p>=}OIghPAu^ z)g}Cv`CB>btoaH)IWHIT-O~m%JbY2Z-DKmeE`$ZB^7Ji6Q+5stmFS`q2y{l+5Ob%y z6g2e=Y9ETa)i)}bpj6i!x=~Tp{$c;1grnZ$-!5Z7-(qs=(;K|sI(;{5S!Hu1GK4nc zgX4iCkdU5TWePcJaIf`+DSdM{^Q7S3RbJJ`t?`P`cnC?N7kOY^k~UmLb}`Qa8Fi|v z$HY9#6mukA%?FZRZDU{8rFpi1WW$n}Ep zK)nN7zhWtW4Qxl5diu`}ekr<-nv&?_0h)xd)QbJrQ|g-9+rw)nuSNh{Ce6CX9HdQc z0?oAl4r5QcvsC{z=+=?nNKAq*qV(+v&!&zh9K#93y{N&=@>$fB?h-vnFvCC)0HU%-DLA4gtW zb{j=r=-M7gezGhn;ckEG`89u5hk7s*%s2+X(}A7WMS}M%S5np6TL~@~`Cfd3b0fQiyyW!J< z-N2Xpws@kU?Y8%+;DGH%<>xLsYdNQc9O?VnOl=3L2q_!87FU{&(I4#`WSy%|Y*}mJ zeHr7|<9iC`s)4;P?;!+w`5r1#T$OT)Vb|tF=XR}RwB7;&?@q4Kch2#ZPxylY$(vx* zaGY!_?i4M>>r5#d8cfc+iqp&Zj!YMakruUGuupP%tOF4Q+BSR|2-#d!w*cV)hgY_W ztIsdrO-kJAkN~#nGI1aD1~|B-h%?Nib1rm8h%bU;hJ}6`EN)V+^z%TrE>LbKKnI`K zh{x>ov8R7n(Pl=K9b>{5?w*yFVWR`0@W7al3$b(#Qh2cw-UwS+A#wzC%!af(y7G2M z>+*4mG*SxNpjCNW6tE%l`0mMf^z>a7F<}b6WC>#1zXBvd@P@K-O|F0R&v%iwdzB0> zH_FOgej9C$k>>n`f5AJFKJw(5-#^~qZv{?RmbH25Auw~=1S`coN5OY!aD(g{G3HpV z;8VN6_jXe%geuQ`UDB)v`QzZS!TiRDBi)325yg+ZR>(pYtPJWqvF!$-u5(w!5NSG~ zV_N=$E{`p2XbPK-;`Hs{iX3pulP7n#`@jE*4Zi#n`fBgviEr44(7s=&d{^G<%k^Em znEWirK#?=~N5xM@_nYZd(KClE*3T6)e=tWB`#7Jk76d)eJ;mE|4(?A}wD65S6A=wh zsFiBZnf3rjJ<#dg*tcG_NX~RuTxNFM$-@u0_)hnTZ-)kts9ct$Yfr*`KAF^{MomoC zxl{kfv+umKIXY8R9)rAlAe3%r>n{nJUSI#7ROTxWZ4`eLF5ySoZsA{x@Q(S2DW`_r#wX)_^jZ z-@&o(pTU1?dV491@TLL3j&G_oNh$VeMyU;A7H)tv1IiVD!&!(5>G) zeeQs!H)vRWd&o_2LwrK~uc}CJZ@UM?g>p(7Ah&tfcp;pPwSo)S&n0K8j zNr8Q}-+1JyRjj8z4%ne{PqyF5IbT2Q8^tZ?;k$4N@}NKKn@i9BB+^cTy%%VN9~yOb zW(P66ulpXm^Z;CT1h)nd<72vtfRpv0P{fFJuc@BMm)LWEF!;(nX?An(>CU6D#iW6J zH-Op_rDWvi?6P9}(`2sOT*PVD4k2;JTgh?5Wvx9M;;ZAO<3<$t-i3z9@bzzg??K}Q*g=v&b$TDVSq16Y|a0ByF5|E4>t$&x>0XEO^AC6?WSluoR9o*K=lG6jFf z4Dq}7?=A%p5xN12mV$X$ajz6G0K_Ww+0MdO&Z0+N!~KIKE55;$tKpg@ZikJBNnh6g*A+1()8i9T!laMa?Fjoro4$jn}4p|^(;tyL2Kno}$bK#fN zkH)cXg|Hy^2&1Hj$S79-)$(oEVWav6YV|QQFeFwev^sEr70g6BT(=cxj%SDf!-K*g z>ijW!HHkj-@>K1D3`+p4j3UI~+TfM4bgIg>p<<@M^f2z%<*P1tUG{A&$u#4Q`Ezb= zo25qU#V=qiTE)u3X6k%WwA0YgB+r!%o1tj7{FbMz9BQm3Z3?CPLMYgGDkTQ{zU-}C zxtnpmevqxDATO)>)=t}mJ;1oaYfM1>movTUaA&Pv2|)mUG?B0=+S$kNngwUdT+FU zQ9GX)_ei>?s4vCC9CVG7Hi2&lZ{Cb%cok)v55QKy;hRB?f0I6|DaI+BR*93W8ifSBYAi@n+n83sNw#~ZRanl*hQ!HD`hgb74 z{q-+uDL;OsF-K1ZR|*YyeMlpB^(Z8HG$`htsk;2KxoN*mla6CUy$@?MHRx8w>#zA9 zS@~Dkf|6VvftKMxk@?Qm=~2VhwvQodID3@LtJ{|L6^w%`r%Qlux;=cUQmZYG54c6B zh>hMamFoo0O{w)bfn1lJJ%4^ep4mmmUOPR1@}1Bh zgvtkN&IzW>ydpmbM^#`zv%(P+JI?n<@|4BisfYYm}^2s^xj5$Ug&i zRXom@17WL0!oUf&)m?B|Jc7ZIRtwz?091KIR>3v_^5B;M(OPic{zq3&cvoCk;SMyE%O72CS! 
z@rkp5AX=)tZeuyDL5ztw&r=Og6}zWID=;dhH}9lP1!JeL%6tb%e-fx0CFaOnuE=$@ z+F3jA3#H9+)JKK}F;pmNZNB*XfzSNOIT;A<3~*Cx-}ixF zeh?~P2RdJpv>7-Tkd;6rtOHsNyK?2^2T$zZcnu6&2^xNSvjX@Z+F@qW$^}eb$+d4Z2WLJx0r7E)DxZP zEkf$9%zJ_Fl4%O*sj=OOKM%L`aGzk<>ODxlsu&k zF*qNNmVZD8h9I}#!TcO>D&F*@E(3R_a8gbrKYR^wR3(>p988d?ndaFe~?t2h@RD zd=va#yDn%a1s7AuYste$v|c`a1)9&7?v=8*ZVH7^rEuur`cCC8uFmPw4NNPjKS+fP zn*VU{(6r3NA|4003;CctQ3)RC;zN`Vrz=Q^{Vm;+m-ulV<2BmLpS^k|N-!8g4-`jKX`)aJ?NqvG(7+oWKRSQc!zT|P zc7tgP4>pvs*1(DPn_YkGr96o;2FCYx_7$ZKD~hZrj% zGFUj!b*B22Dvj(YsI~mk|exh9~C@O zkdHK{+&WLVzL13$dcxY^hR9=R(E&cwxuiOEj7$NUvI{KJQ+17i2bKC(z%B2fSi{g> z2?6Yg40nhpQm#?BvMCD>YDj^~@F1pN7hKa0%bXwM&*;Hn8<A%>il&xkvsPxo>FVcMWX) zp7RDVtqBEJZzTCvE&CLRD&g~HTu*Q7`N?&zH2cVHtE;_1_ZAn*fhx1T4Ek!WKskyG z7~*5(*98kTXR(--fA0JLn^&LZUxT-Prgb*#-z0zD^<`NyS*1Sz4ZryNt!Wt{Gv zO-A=&kD5e&1t~O#Z8LBIPvyM!8@RG&dYH`s1GRhK1e1T*<`34T?Md6*O6lMSB8J#M zP{U2!-7N3Pi#wC2;f{(7R4U-4>u3wO>!3p6LU}UJr9_rYgpX^w)_)JYyBRN>cV&b< z#Pj<2If922VruLSf*Y2nX1iG5{Nh3%)yTP&=l(MB0>vVyxM?3e@pQ0EGKo{04jlt6 zLmu8W`Oy3C$2Zdj#3%0VWng5NFgLrBbvJCZ9xar+!g+6hgCzaNb+hL$yVU_QO2*ID zufd3?$wSS+BO%ZFt6*+j#`2DkrU&rrX*TJIl*(~C0_6%sfn=~bfv8bLuz9G4JLzZe zthq^vHR5cIvRC)qOYgR^bz3&~ixX}65<0+FSdHgz3YpMR@V~u-sJEDwFa#ch!9%A)F+eoq;&!} z^3=y~x770y1=tKB5X1Y+mOXBWa0Jj)!_FpH zDd_pz0TWwxzXyNE!E2JBn%0VlLL&NR;~OVBX&=@s25hLpzPC<-Qzs>d*s;}owBk*L z4#;BaBQLeq0?382iL>ftH&V zEbmwN@6=)B{xZptsybf>LxJ2CXz>n{??j!&^$bmBu*ZepCs~z(Wu#)9LMj2@F^(M* z9Gq>w#SgO989Gc1Z}%U=Gw%eY0Vbh_jBdMARGZi9WU2)Hk-UQ?Tu?8O*fqOy7KYr6 zgJ8#y+dNE!W;97hGz_EIJKsjdJ67r;OEY6BV$q~kB!d_Pu2WP{AuCdiSknv_{xNg7 zBT@b>PWrsGr?2bNeNdu6p$viqrO78kT7k_EGC1*iI*wb*?%=Ws;9_9ixe{w{Wu}&E zZaI_1K`sE--AU;YsD(=YKnNNg7e-lN?%aFBKRyw^{9-k=?!Qc}_~8xY531Dw#de06 zzRiZSgp7b`MOZ^WuN3Q~#gD3zV16KsS{X%L;1&*GJPP6A9qm`uXj%SyrEdeqf1(w`C!o+;2rHnscs|bFs9pbwc%CHEg@K;uZ!F5^ek24K<0^1h-3V#V#AVMyk9y494EHR% z7s_lqat5FPr_w{S?8?G66AoJfvn&@b1hY#VJmgdBeW$@^OaTX?q4(F&W*ti^H_xk= zyt;DcL+dzN`|nNc_l~oKOIK~*bQP3}Q<9}kRk_ArdpVYcC9%n{(j_Vj^|AvVNpX7| z_k0Jz<&*fhf%kezjk00B^+SP!he!=+%n`ybyg^#>q1?C(2hMzL_dZHS(ty&#J327W z5}e7Dmb#$_e#VhsLHp0PR0CgtUL&jCic$~0m-WGx#4joz)Sq%jUuNu9p|(Eo-~Rrz z;)>vp7TRrXeTds>mi{1x*q}qq)XaaT=wVD2M_2S@QdFkP; zJyN~fgf-$Vq=9{&3;8E!6+Dx(z^3apGH^jFQSoD%^-kBOWn{evYdo*wT%CTf-&daK zpaP%wyn8*tA(L<|2VTZhdx%^sA>0-;tX41ye>T#X9FS3e4)^zQ{wBdP%axq@yMH4B zYU#5kY5~O&pTk0f6OU9Y%c1yO7!VD}(OVnA(;_va*@+qIp`{-7?k*;S3TtHiW^|fF z?{vkUZ6t3xSUpatg;*9K*@#yTtxF)^+=^Z*Q|JFPOxO1v^sdC&YKc|pVp|qu)1|SH zan4jFpMf`G9*Q*!F0z3rnOMHYwhuLfly{V<0c^yvhc5BW%B+e7sl>y~jRu>3XjKSd z?>humAt0z|70pw~Dt=#WNxI5654xXWpPr3ZwsNbw$E|RVYts`5i!=QEu*ihCg8Z5O zzBGZnB!wI6?R)4C3s8K$&C2~nM|w!%+y>Lq$UB*B!}k>97POkHt2vFI&6>+-J=@`( zl2z;>$)0m-$CrG;4%Ed^?3rTZRRu2G=p8q~pc}xw*w_Bw=(F0P<6l5=LFri85Ddf><$=Xq{rIWT|N&_3^bZymE72+LzIc1YWMr?oQfGC+f$Y zYzj+1vdtq@?_l02pLy3+#}8btGWVA-0Qd*)q`@xv8t>Jcw=Dy_hPY^J z&C1T+MG1oJi&|APpDM0@Ina83_Y1`uy@hDyZ*x0_x!#0uU3*aVlbvv&buJ*}#q#_j z*fo0UP*MDwHt_m@oKA+Vb^N|VX?vZPd^c4}XrZB{PH*_U(hO|B^zYz{Fv}HwHNvyA zpPd0_`zgJ4J7DS_KJfAMRr=xrCBD{9S13G|Caa`1qYqez>4BowdF<>}@Ccn{)sw5j zf0!Flt@%1#b>kQ`3jFZ56hTQ{4yNaI%$Mp3WXY>yLQatEQ3}r zm+?+v`K%A6h=^-qf?gQeYPI?(o8tguPM19viLbQB)F1D%NGh}+91Wri)3+j@-BTC~ z(&65dWo^tfV_Zw}+>pxq3IG0OaN9SiB zU}F=QwJ%X&pf>H>t9?!0uoyL8>y^kpRs+tPJatZ?Nl9kOnY|+E8>3STD}dXUrq<|r zRA(Q|V&B|L{@{SS6dFv&$OMB7(a~coN*vf+v}~h78?clwA?aHe98G zATNSrByj%bE%n6UVaZ=;P!8wS&d_1j+POWof|-VlF7|$?q;B>vQA1np>zHyUll*gV zQ_Hh!>jN-H_#EyinC%xA?V@1$R0a0AT=o7stU=^Tqp?xBph||?k*#m1lKM>+L*u_Q z&}8a{7;DXfWAcFKLN>D&1jc6WQlL9Ew6 zp-_c-Gtc9TdJ*r9iC1FFCLzF6{9zACGE>)q%pW1_MW})A(k%m*_Z_OtZ{YKScStF4 z;}~Hy39z_>r?5J~I&{U-*iMEF{D7;sCfCsM+>R+A|2WHhZeBGlK9aphT+9Hlk1*9m 
[GIT binary patch payload omitted: unreproducible base85-encoded contents of a newly added binary file]
zwlB;Hqv9_f{J9R=@xWRpiWcdyY>f(rSmU;!S?x}~jpx+M{?v|z8vJ#tyLPJV;urlz zie|Y{_=?S%FKpYHaTn_hw~+f9ep2oSpSp`;+o9D9DM4i`rAYZ~l-GdmN30|#z5;Ld z26$}ahiksnMe^V)NJmjx4kg!a%KfE!MD)k6!1#-c6b>VH{Hn09uvHV8N?T-jGeY6z z4q4Y2n`QOaL$qy+lFYB3zrlhP-7<;iE~$J3bf0)Lfdi%HyH*h<@!U$Y+3!H#7{7a$fHB z<9WSw1Eg)pNduOTOv75q{uQR@JhY-dd8I4%rz5-#kQvydow8&Or`<@^BOkN4hxcP{ zVm^!ls`FzmEywn>pkrI6Fx))*9P;jCy3A2Rc42Ry*w?iGX9?2Fp3SEtG1k;d*8Y5h z10zUVXAJR*5o3gT3(MaE^rE4~tgPRgkc%Bri}5d92q{cK8jn!K=1<%KOP6~8uysgc zzOnM9uy3rh1;NwjjMSM|M=Bi4Vx7(!(-Y!&9ZF;T73dV)kTCPzmG7d0@l5O1(mNVU zAhIrXjIz*{H}N?u5us3g^lhMpgtG9L=;_D^k2=3_TxC16lV;@@J{sDri%!GMDsn$o z7?JlpQaJlSr+1cjNe0qkW_DL}6VkbfrU0eq0gQoEA$ppQ@=V+JypMgbSGvd(=%8y$ zUw^LOYk;!gilx>SstznTs*aKL6}>&NizBAh$yq5{LG?`y7ckl2u|*o#-J%nQOA64D zx|5FHweSfS5-xl6A|_$wD&ZBY*DttZSO|-oV~pINjbU^@iXqYkl%G=;^Hs}7m&if} zmLpJ5D#lU%N2X>m^bx?@mXIf6oac zhvIDqDUo{i)nhlIQ4d<&bPn}^f|dKQ zq%sL%S33n(nl79)o-5)&$cc(MJ=wsRb6=(k;lxgtk-DE6!cxm?bm4S@jsfNpJcVk* zpWjAiyrd{I5%Ey9@;g-CFCXJEqP00LRM%aa3}mQ4m%RU6v*-w8nn_7* zgI`$M^yBv~<<6j^v66cb-{#gnOY$1QRi4I;@J{#@dMH3c3)E33E%O>k46{M%b@C>X z7jU>y?spTJZIj`oMaMF4VygGo&|CBODW=1*FpfL)SZr;>GFnEN?L_RGj88T& z3g%#pnR3F|7Ogi}b;1|hdMQ$}I!cU?h0sDTB^{NC=_%UO1q3Isx3YZ=jo)GhRq&=W z4FpZ53xy^mJjY6@uh!0zuRXI3xQR5c1zGIwHq-X-VTqwC$~2AIGfdgW`0mCNEql+& z_6n=wtjKg}O|<=DF0*H+Q)#8~!0aP*Mg-BxIY&nGU92-gzZ)7{jh57CA~OOkU)uM7 z&0eld()Tyv#sA5V2>JIqSC0Ve^}qA=?>MZsihu3l{|0*qyQ{!ea2dfOT4z?G#>zYg z*F=vB>~d+Z(PdUOX#K$4Nq6^%@!ykc9|Ec4MDe>;gD(v`zl5akX~pnqXB-A6!>52qRvbrFj{Zh$69 zp*dpLd#}apf1awZ5=S#t7c{+M6mj7I#SnpTT05Bo;KT*3tXSG71iVD*FtDzL?H8F; z%&L==}za}M1On(zRE9B&=Bf`tQ*#+dnN+E_2?zt7ivhww$VN+wTV*%~=ln*a| zN0c@81?sLLEmDjoALuNR$|5BICGgN0p1ns4MMHTdbk6n1Cm|Zc>TT}>nR9gP9$svQbM=>@2|329VEiN$Y$fGyfdQtn zfsfYqQAW7Yq#{_D84gOp;mIeE-R?cYtmKo^fk(+qE^L;t9Na7sN5A0+ z9q!7h^eHBUUH&@quhwr&1=z^2K%(-Qc*PFG?Fi!W(OC7kr`5af|F#;;t^P!7`+r+G zC=^mFw0ghg*Z;oH8i{|$;oossUK7Z_|JuX+B^Nv-|^lz`a5C`8mF{RrkpUl^S=P11(2=) literal 0 HcmV?d00001 diff --git a/ci/utilities/README.md b/ci/utilities/README.md new file mode 100644 index 000000000000..35af5241767b --- /dev/null +++ b/ci/utilities/README.md @@ -0,0 +1,16 @@ +# JAX CI Utility Scripts + +This docpage gives a brief overview of the different utility scripts and what +they are used for. + +- **setup_build_environment.sh**: Sets up the build environment such as + cloning the latest XLA, adjusting file paths (for Windows), etc. +- **convert_msys_paths_to_win_paths.py**: Converts MSYS Linux-like paths + stored in env variables to Windows paths. +- **install_wheels_locally.sh**: Used by Pytest scripts to install JAX wheels + and any additional extras on the system. +- **run_auditwheel.sh**: Verifies that the Linux artifacts are "manylinux" + compliant. +- **run_docker_container.sh**: Runs a Docker container called "jax". Images + are read from the `JAXCI_DOCKER_IMAGE` environment variable in + [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env). From a940100a1edf4b7f45af49c49c8fe2000550995a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 10 Apr 2025 15:06:15 -0400 Subject: [PATCH 0537/1769] Enable execution of explicit-sharding notebook in docs. 
--- docs/conf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index bcc0b7a762a3..cddb63653a17 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -223,7 +223,6 @@ def _do_not_evaluate_in_jax( 'jep/9407-type-promotion.*', # TODO(jakevdp): enable execution on the following if possible: 'notebooks/Distributed_arrays_and_automatic_parallelization.*', - 'notebooks/explicit-sharding.*', 'notebooks/autodiff_remat.*', # Fails on readthedocs with Kernel Died 'notebooks/convolutions.ipynb', From ae29f63e815427a12065e147f632e1c334babb77 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 10 Apr 2025 19:23:11 +0000 Subject: [PATCH 0538/1769] Don't use default quant config --- jax/_src/nn/functions.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 2c7c62d28222..eed3f7658c5d 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1381,7 +1381,7 @@ def scaled_dot_general( configs (list of BlockScaleConfig, optional): Scaling configurations for lhs, rhs, and gradients. Users can obtain valid configurations via `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` - are supported. If `None`, `mxfp8` is used. + are supported. If `None`, falls back to `lax.dot_general`. implementation: str (Deprecated) Backend selector, now ignored. The system chooses the backend automatically. Scheduled for removal in future releases. @@ -1422,19 +1422,9 @@ def scaled_dot_general( warnings.warn("Backend selector, now ignored. The system chooses the " "backend automatically.", DeprecationWarning) - # Create configs if not provided if configs is None: - if dtypes.float8_e8m0fnu is None: - raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") - mxfp8_config = BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - configs = [mxfp8_config for _ in range(3)] + return lax.dot_general(lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type) out = cudnn_scaled_dot_general( lhs, rhs, dimension_numbers, From 7117aa03faa11ee41a80b392f1d1b6c35476ab2b Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 10 Apr 2025 12:56:38 -0700 Subject: [PATCH 0539/1769] [Mosaic GPU] Skip WGMMA with cluster example on non H100 GPUs. PiperOrigin-RevId: 746140286 --- tests/pallas/mosaic_gpu_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index fca5c09f6b73..9ef1fcdea24e 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2413,7 +2413,7 @@ class WarpSpecializedPipelineWGTest( ... -class CoreMapTest(PallasTest): +class CoreMapTest(PallasTest, jtu.CudaArchSpecificTest): def test_multiple_wg(self): @@ -2543,6 +2543,7 @@ def kernel(ref): def test_realistic_matmul_with_cluster(self): self.skip_if_wg_semantics() # Needs WGMMA to support slices. + self.skip_unless_sm90a() # Requires WGMMA. dtype = jnp.float16 swizzle = 128 From 2807ae4e34ffc122bfcb5f1a685c2e09e96be468 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 10 Apr 2025 13:12:33 -0700 Subject: [PATCH 0540/1769] [Pallas] Fix ()-shaped vectors being materialized in Pallas lowering. This fixes some non-intuitive errors where scalar-shaped values in VREGs were being used in operations that expected SREGs. 
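To make the failure mode concrete, here is a minimal sketch of the pattern this commit fixes, adapted from the `test_reshape_to_scalar` case added below; the standalone `pl.pallas_call` wrapper and the input shape are illustrative framing, not part of the patch:

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
)
def kernel(x_ref, o_ref):
  o_ref[...] = jnp.zeros_like(o_ref)
  # Reshaping a (1, 1) value to () requires an implicit VREG -> SREG
  # copy on TPU before the scalar can be used as a dynamic ref index.
  scalar_val = jnp.reshape(x_ref[1:2, 0:1], ())
  o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val

x = jnp.arange(16, dtype=jnp.int32).reshape(4, 4)
result = kernel(x)  # previously hit a non-intuitive verification error
```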
PiperOrigin-RevId: 746146037 --- jax/_src/pallas/mosaic/lowering.py | 15 ++++++++---- tests/pallas/indexing_test.py | 28 ++++++++++++++++++++++ tests/pallas/ops_test.py | 21 ++++++++++++++++ tests/pallas/pallas_error_handling_test.py | 11 +++++---- 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 2ded01275a0f..343c7d79aab6 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1504,10 +1504,13 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): starts, ) if load_aval != aval_out: - vec_type = ir.VectorType.get(aval_out.shape, - _dtype_to_ir_type(aval_out.dtype, - is_kernel_boundary=True)) - load_val = vector.shape_cast(vec_type, load_val) + if aval_out.shape: + vec_type = ir.VectorType.get(aval_out.shape, + _dtype_to_ir_type(aval_out.dtype, + is_kernel_boundary=True)) + load_val = vector.shape_cast(vec_type, load_val) + else: + load_val = vector.extract(load_val, [], [0] * len(load_aval.shape)) return _maybe_cast_load_to_bool(ctx, aval_out, load_val) def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: @@ -1692,6 +1695,8 @@ def _masked_swap_lowering_rule( result = vector.load(mem_aval_vec_type, ref, starts) val = _maybe_cast_store_to_memref_type(ctx, val_aval, val) if mem_aval != aval_out: + if not aval_out.shape: + raise ValueError("Cannot swap scalars to VMEM.") # We are slicing a scalar so provided dummy 1 indices result_vec_type = ir.VectorType.get(aval_out.shape, _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True)) @@ -2174,6 +2179,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ), x, ) + if not ctx.avals_out[0].shape: + return vector.extract(x, [], [0] * len(ctx.avals_in[0].shape)) return vector.shape_cast( aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index c3f3fa6e80a8..5430009c5d28 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -641,6 +641,34 @@ def kernel(x_ref, indices, y_ref): )(x, indices) self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.) + def test_scalar_load_from_vmem(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Requires TPU v4 or later") + def kernel(x_ref, o_ref, sem_ref): + o_ref[...] = jnp.zeros_like(o_ref) + scalar_val = x_ref[1, 2] + # Use scalar_val in both async_copy and store. 
+ o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + desc = pltpu.make_async_copy( + o_ref.at[scalar_val], + o_ref.at[scalar_val + 1], + sem_ref, + ) + desc.start() + desc.wait() + + x = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.int32) + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 8, 128), jnp.int32), + grid=(1,), + scratch_shapes=[pltpu.SemaphoreType.DMA] + )(x) + expected = jnp.zeros_like(res) + expected = expected.at[6].set(jnp.ones((8, 128), jnp.int32) * 6) + expected = expected.at[7].set(jnp.ones((8, 128), jnp.int32) * 6) + self.assertArraysEqual(res, expected) + class IndexerOpsInterpretTest(IndexerOpsTest): INTERPRET = True diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index ff02c334f45c..90d8bafb3c91 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1748,6 +1748,27 @@ def f(x_ref, o_ref): expected = x.reshape(out_shape) np.testing.assert_allclose(f(x), expected) + def test_reshape_to_scalar(self): + self.skip_if_mosaic_gpu() + # Test reshapes from (1, 1) to (). + # Because TPUs distinguish between VREGs/SREGs this tests an implicit + # copy from VREG -> SREG that must be inserted by Pallas. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), + ) + def f(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + vector_val = x_ref[1:2, 0:1] + scalar_val = jnp.reshape(vector_val, ()) + o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + + in_shape = (4, 4) + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.int32).reshape(in_shape) + expected = jnp.zeros((8, 128), jnp.int32) + expected = expected.at[x[1, 0]].set(x[1, 0]) + np.testing.assert_allclose(f(x), expected) + def test_num_programs(self): self.skip_if_mosaic_gpu() diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index cd5ceecfc9a8..84e38f3d09db 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -92,7 +92,7 @@ def kernel_in_jitted_fn(x): tb_string = "".join(tb_string) self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") - def test_invalid_smem_vmem_verification_error(self): + def test_index_with_f32_verification_error(self): input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( @@ -105,7 +105,8 @@ def test_invalid_smem_vmem_verification_error(self): @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) def test_kernel(input_ref, output_ref): - output_ref[0, 0] = input_ref[0, 0] + idx = input_ref[0, 0] + output_ref[idx, 0] = input_ref[0, 0] # Test that a verification error is raised. This assert is a guard against # underlying changes in Pallas lowering. @@ -113,8 +114,8 @@ def test_kernel(input_ref, output_ref): # the test example to force a different error. 
with self.assertRaisesRegex( error_handling.VerificationError, - "'memref.store' op failed to verify that type of 'value' matches " - "element type of 'memref'", + "must be signless-integer-like or memref of signless-integer, " + "but got 'f32'" ): test_kernel(input_arr) @@ -125,7 +126,7 @@ def test_kernel(input_ref, output_ref): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "output_ref[0, 0] = input_ref[0, 0]\n") + self.assertEndsWith(tb_string, "output_ref[idx, 0] = input_ref[0, 0]\n") def test_parse_location_string(self): name, frames = error_handling.parse_location_string(LOCATION_TEST_STRING) From 654b91b0949df535bdcb3c8138c8ce79a762635c Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 10 Apr 2025 15:24:39 -0400 Subject: [PATCH 0541/1769] Fix grep for label on Read the Docs. --- .readthedocs.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index f80953e7233a..3b7ba275a0d6 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -13,10 +13,8 @@ build: post_checkout: # Skip building PRs unless tagged with the "documentation" label. - | - if [ "$READTHEDOCS_VERSION_TYPE" = "external" ] && (curl -s "https://api.github.com/repos/jax-ml/jax/issues/$READTHEDOCS_VERSION/labels" | grep -vq "https://api.github.com/repos/jax-ml/jax/labels/documentation") - then - exit 183; - fi + [ "${READTHEDOCS_VERSION_TYPE}" != "external" ] && echo "Building latest" && exit 0 + (curl -sL https://api.github.com/repos/jax-ml/jax/issues/${READTHEDOCS_VERSION}/labels | grep -q "https://api.github.com/repos/jax-ml/jax/labels/documentation") && echo "Building PR with label" || exit 183 # Build documentation in the docs/ directory with Sphinx sphinx: From 7e5966b1f3d2b23268cc93ef769a0a855cb15e65 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 10 Apr 2025 13:14:44 -0700 Subject: [PATCH 0542/1769] Make sure direct-linearize handles res_names correctly post vma in types being enabled by default PiperOrigin-RevId: 746146834 --- jax/experimental/shard_map.py | 39 +++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 855ac291df67..2d114c6c3a2b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1733,38 +1733,50 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz] - res_names = _all_newly_manual_mesh_names(mesh, auto, trace) + all_names = _all_newly_manual_mesh_names(mesh, auto, trace) @as_hashable_function(closure=linearize_outs_thunk) def fwd_out_names_thunk(): - _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + res_avals, _, _, _, _, _ = linearize_outs_thunk() out_names = out_names_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - # This is incorrect so we set `check_rep=False` in the tangent (as in JVP). 
- return (*({0: res_names} for _ in range(num_res_out)), *out_names) + if check_rep and config.varying_axes_in_types.value: + res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} + for a in res_avals] + else: + res_names = [{0: all_names}] * len(res_avals) + return (*res_names, *out_names) fwd_params = dict( mesh=mesh, in_names=in_names, out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, rewrite=rewrite, auto=auto) all_fwd_results = shard_map_p.bind_with_trace( trace.parent_trace, (f_primal, *primals), fwd_params) - residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) non_fwd_res = all_fwd_results[:num_res_out] primals_out = all_fwd_results[num_res_out:] residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None - for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)] + for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] with (_extend_axis_env(mesh, auto), use_abstract_mesh(_as_manual_mesh(mesh, auto)), config._check_rep(check_rep)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() - residual_names = [in_names[f1] if f1 is not None else - out_names[f2] if f2 is not None else - {0: res_names} for f1, f2 in zip(in_fwd, out_fwd)] - new_in_names = (*residual_names, *({} for _ in range(len(env))), + res_avals_iter = iter(res_avals) + res_names = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_names.append(in_names[f1]) + elif f2 is not None: + res_names.append(out_names[f2]) + else: + if check_rep and config.varying_axes_in_types.value: + res_vma = next(res_avals_iter).vma + res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) + else: + res_names.append({0: all_names}) + new_in_names = (*res_names, *({} for _ in range(len(env))), *(ax for ax, nz in zip(in_names, nzs_in) if nz)) tangent_out_names = tuple(ax for ax, nz in zip(out_names_thunk(), nzs_out) if nz) @as_hashable_function(closure=tangent_out_names) @@ -1772,7 +1784,8 @@ def tangent_out_names_thunk(): return tangent_out_names tangent_params = dict( mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_rep=False, rewrite=rewrite, auto=auto) + check_rep=(check_rep if config.varying_axes_in_types.value else False), + rewrite=rewrite, auto=auto) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here def f_tangent(*args): From 48e14dcc0cb9c09d347f81b10ed3dc8e5cbc50e0 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 10 Apr 2025 13:16:34 -0700 Subject: [PATCH 0543/1769] Implement mutation by replacing the contents of a jax.Array with a result jax.Array. 
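A rough sketch of the dispatch-side flow this enables, mirroring the `pxla.py` hunk below (the function and argument names here are invented for illustration; `_replace_with` is the new private binding added to `PyArray` in this patch):

```python
# Outputs whose entry in `out_mut` name an input index are not returned;
# instead the corresponding input's buffer is rebound in place, so callers
# observe the mutation through the original jax.Array.
def handle_mutated_outputs(args, outs, out_mut, jaxlib_extension_version):
  results = []
  for target, out in zip(out_mut, outs):
    if target is None:
      results.append(out)
    elif jaxlib_extension_version < 330:
      args[target]._buf = out               # old behavior: Python rebind
    else:
      args[target]._buf._replace_with(out)  # new: alias contents in C++
  return results
```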
PiperOrigin-RevId: 746147571 --- jax/_src/interpreters/pxla.py | 6 +++++- jaxlib/xla/py_array.cc | 36 +++++++++++++++++++++++++++++++++++ jaxlib/xla/py_array.h | 4 ++++ jaxlib/xla/xla_client.py | 2 +- 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 45bdd4e17e8e..fb7e352bf600 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -57,6 +57,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla +from jax._src.lib import jaxlib_extension_version from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir @@ -1314,7 +1315,10 @@ def __call__(self, *args): out_ = [] for i, o in zip(self.mut.out_mut, out): if i is not None: - args[i]._buf = o + if jaxlib_extension_version < 330: + args[i]._buf = o + else: + args[i]._buf._replace_with(o) else: out_.append(o) return out_ diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index ce5ceacbad99..84b7a726fbc4 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -1479,6 +1479,31 @@ absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); } +absl::Status PyArray::ReplaceWithAlias(PyArray o) { + auto& storage = GetStorage(); + auto& o_storage = o.GetStorage(); + if (storage.py_client.get() != o_storage.py_client.get()) { + return absl::InvalidArgumentError( + "Unable to replace a PyArray with a PyArray from a different client."); + } + storage.aval = o_storage.aval; + storage.weak_type = o_storage.weak_type; + storage.dtype = o_storage.dtype; + storage.shape = o_storage.shape; + storage.sharding = o_storage.sharding; + storage.npy_value = o_storage.npy_value; + storage.committed = o_storage.committed; + storage.traceback = o_storage.traceback; + storage.ifrt_array = o_storage.ifrt_array; + storage.fully_replicated_array = o_storage.fully_replicated_array; + storage.py_arrays = o_storage.py_arrays; + storage.host_value.Clear(); + storage.dynamic_shape = o_storage.dynamic_shape; + storage.result_status = o_storage.result_status; + + return absl::OkStatus(); +} + std::vector PyClient::LiveArrays() const { std::vector result; for (auto& shard : arrays_) { @@ -1899,6 +1924,12 @@ absl::Status PyHostValue::CopyToHostAsync( return absl::OkStatus(); } +void PyHostValue::Clear() { + ready_ = {}; + value_ = {}; + string_array_contents_ = {}; +} + namespace { PyMemberDef PyBaseArray_members[] = { #if PY_VERSION_HEX < 0x030C0000 @@ -2059,6 +2090,11 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) { xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); }, nb::is_method()); + type.attr("_replace_with") = nb::cpp_function( + [](PyArray& self, PyArray& o) { + xla::ThrowIfError(self.ReplaceWithAlias(o)); + }, + nb::is_method()); type.attr("block_until_ready") = nb::cpp_function( [](PyArray self) -> nb::object { xla::ThrowIfError(self.BlockUntilReady()); diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h index 7fa2434c7c9f..7c7a6fefe3a2 100644 --- a/jaxlib/xla/py_array.h +++ b/jaxlib/xla/py_array.h @@ -70,6 +70,8 @@ class PyHostValue { absl::StatusOr> AsNumPyArray( std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + void Clear(); + private: absl::Status CopyStringArrayToHostAsync( std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); @@ -314,6 +316,8 @@ class PyArray : public 
nanobind::object { static absl::Status BatchedBlockUntilReady( std::vector objs); + absl::Status ReplaceWithAlias(PyArray o); + private: absl::StatusOr AssertUnsharded(absl::string_view api); diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index b6eb5ae7fb39..543664682c08 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 329 +_version = 330 # An internal increasing version number for protecting jaxlib code against # ifrt changes. From 92be510f0b504d8f87a181721801ed759886dedc Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 10 Apr 2025 13:58:29 -0700 Subject: [PATCH 0544/1769] [Mosaic GPU] Implement warp-level thread semantics. Adds a new WarpMesh object which when used in conjunction with core_map, allows the user to drop into warp-level code rather than programming at the warpgroup level. PiperOrigin-RevId: 746163942 --- jax/_src/pallas/mosaic_gpu/core.py | 23 +++++ jax/_src/pallas/mosaic_gpu/lowering.py | 95 +++++++++++++++++-- jax/_src/pallas/mosaic_gpu/primitives.py | 14 ++- jax/experimental/mosaic/gpu/__init__.py | 1 + .../mosaic/gpu/dialect_lowering.py | 6 +- .../mosaic/gpu/examples/flash_attention.py | 4 +- .../mosaic/gpu/fragmented_array.py | 2 +- jax/experimental/mosaic/gpu/launch_context.py | 3 +- jax/experimental/mosaic/gpu/utils.py | 32 ++++--- jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 60 ++++++++++++ 11 files changed, 213 insertions(+), 28 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index d964a8a90144..bb08d8f090a7 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -152,6 +152,8 @@ class PrimitiveSemantics(enum.Enum): # Convenience constants for (lowering, primitive) thread semantics pairs. LANExWG_SEMANTICS = ( mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warpgroup) +LANExWARP_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warp) WGxWG_SEMANTICS = ( mgpu.LoweringSemantics.Warpgroup, PrimitiveSemantics.Warpgroup) @@ -748,6 +750,27 @@ def shape(self) -> collections.OrderedDict[object, int]: def discharges_effect(self, effect: jax_core.Effect): return effect is _wgmma_pipeline_effect or effect is _memory_effect +@dataclasses.dataclass(frozen=True, kw_only=True) +class WarpMesh: + """Represents a mesh over individual warps within a warpgroup. + + When used in conjunction with `core_map`, the warp ID will be visible + within the body of the wrapped scope by querying `lax.axis_index` with + the specified axis name. 
+ """ + + _NUM_WARPS_PER_WARPGROUP: ClassVar[int] = 4 + axis_name: str + + @property + def shape(self): + return collections.OrderedDict([ + (self.axis_name, self._NUM_WARPS_PER_WARPGROUP), + ]) + + def discharges_effect(self, effect: jax_core.Effect): + del effect + return False def _gpu_mesh_discharge_rule( in_avals, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 3c6c1ee98b03..49e337d2ba49 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -299,6 +299,7 @@ class ModuleContext: program_ids: Sequence[ir.Value] | None approx_math: bool single_wg_lane_predicate: ir.Value | None + single_warp_lane_predicate: ir.Value | None smem_requested_bytes: int smem_used_bytes: int tmem_requested_cols: int @@ -310,6 +311,21 @@ class ModuleContext: squashed_dims: tuple[int, ...] lowering_semantics: mgpu.LoweringSemantics primitive_semantics: gpu_core.PrimitiveSemantics + warp_axis_name: str | None = None + + @property + def single_lane_predicate(self) -> ir.Value: + """Returns a predicate that is True for a single lane within the current + thread semantics. + """ + assert self.lowering_semantics == mgpu.LoweringSemantics.Lane + match self.primitive_semantics: + case gpu_core.PrimitiveSemantics.Warpgroup: + return self.single_wg_lane_predicate + case gpu_core.PrimitiveSemantics.Warp: + return self.single_warp_lane_predicate + case _: + raise ValueError(f"Unknown semantics: {self.primitive_semantics}") @contextlib.contextmanager def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: @@ -737,16 +753,21 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): tmem_cols = 0 if lowering_semantics == mgpu.LoweringSemantics.Lane: - single_lane_predicate = mgpu.single_thread_predicate(per_block=False) + single_wg_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARPGROUP) + single_warp_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARP) else: # Warpgroup semantics do not have a single lane predicate. - single_lane_predicate = None + single_wg_lane_predicate = None + single_warp_lane_predicate = None module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), axis_names, [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, - single_lane_predicate, + single_wg_lane_predicate, + single_warp_lane_predicate, smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, tmem_requested_cols=tmem_cols, @@ -826,6 +847,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mosaic_lowering_rules = { # Lowering rules when using Mosaic GPU lane semantics. (mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup): {} , + gpu_core.LANExWARP_SEMANTICS: {} , # Lowering rules when using Mosaic GPU warpgroup semantics. 
(mgpu.LoweringSemantics.Warpgroup, gpu_core.PrimitiveSemantics.Warpgroup): {}, @@ -1372,11 +1394,17 @@ def _broadcast_in_dim_lowering_rule_wg( @register_lowering_rule(lax.convert_element_type_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): del weak_type, sharding [x_aval] = ctx.avals_in + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if x_aval.shape != (): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") return _ensure_fa(x, x_aval.dtype).astype( mgpu_utils.dtype_to_ir_type(new_dtype), is_signed=mgpu_utils.is_signed(new_dtype) ) @@ -1489,11 +1517,15 @@ def convert(ty, x): def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if not all(aval_in.shape == () for aval_in in ctx.avals_in): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return impl(x, y) - -mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ +for semantics in [gpu_core.LANExWG_SEMANTICS, gpu_core.LANExWARP_SEMANTICS]: + mosaic_lowering_rules[semantics].update({ lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), @@ -1509,8 +1541,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y), lax.max_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x.max(y)), lax.min_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x.min(y)), -}) - + }) def _binary_op_lowering_rule_wg( ctx: LoweringRuleContext, x, y, *, ui_impl, si_impl, f_impl=None @@ -1903,6 +1934,15 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): _block_id(ctx, gpu_dialect.Dimension(idx)), ) +@register_lowering_rule(lax.axis_index_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) +def _axis_index_warp_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): + if axis_name == ctx.module_ctx.warp_axis_name: + return mgpu.warp_idx(sync=True) + raise ValueError( + "Named axes can only refer to the warp axis name inside of core_map." + ) + @register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane) def _debug_print_lowering_rule( @@ -2307,6 +2347,8 @@ def _while_lowering_rule( @register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cond_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) @register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in @@ -2416,6 +2458,45 @@ def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): return (result,) if len(ctx.avals_in) == 1 else result +@register_lowering_rule(pallas_core.core_map_p, mgpu.LoweringSemantics.Lane) +def _core_map_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr, + mesh, + **_, +): + if isinstance(mesh, gpu_core.WarpMesh): + # A core_map over a WarpMesh represents a fork/join over individual + # warps in a warpgroup. 
+ if (ctx.module_ctx.warp_axis_name or + ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp): + raise LoweringError( + "Cannot nest core_maps. Already under core_map with warp_axis_name " + f"{ctx.module_ctx.warp_axis_name}.") + module_ctx = dataclasses.replace( + ctx.module_ctx, + warp_axis_name=mesh.axis_name, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp, + ) + for aval_in in ctx.avals_in: + if isinstance(aval_in, jax_core.ShapedArray) and aval_in.shape: + raise LoweringError( + "Can only close over scalars and Refs when using core_map with " + f"WarpMesh. Found array of shape {aval_in}." + ) + _ = lower_jaxpr_to_mosaic_gpu( + module_ctx, + ctx.launch_ctx, + jaxpr, + args=(), + consts=args, + ) + mgpu.warpgroup_barrier() + return [] + raise ValueError(f"Unsupported mesh: {mesh}") + + def _bcast( x: ir.Value, y: ir.Value, diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index f4ccc2865cc0..99b42c3e1518 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -218,6 +218,9 @@ def _copy_smem_to_gmem_pp_eqn( @lowering.register_lowering_rule( copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp) @lowering.register_lowering_rule( copy_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) @@ -240,12 +243,12 @@ def _copy_smem_to_gmem_lowering( if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if predicate is not None: - assert ctx.module_ctx.single_wg_lane_predicate is not None + assert ctx.module_ctx.single_lane_predicate is not None predicate = arith_dialect.andi( - predicate, ctx.module_ctx.single_wg_lane_predicate + predicate, ctx.module_ctx.single_lane_predicate ) else: - predicate = ctx.module_ctx.single_wg_lane_predicate + predicate = ctx.module_ctx.single_lane_predicate flat_src_transforms, flat_dst_transforms = util.split_list( flat_args, @@ -443,6 +446,9 @@ def _copy_gmem_to_smem_pp_eqn( @lowering.register_lowering_rule( copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp) @lowering.register_lowering_rule( copy_gmem_to_smem_p, mgpu.LoweringSemantics.Warpgroup ) @@ -506,7 +512,7 @@ def _copy_gmem_to_smem_lowering( dst_ref=dst, barrier=barrier, arrive=False, - predicate=ctx.module_ctx.single_wg_lane_predicate, + predicate=ctx.module_ctx.single_lane_predicate, collective=collective, **copy_params, ) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 3dc73606852c..c1275396036c 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -72,6 +72,7 @@ DynamicSlice as DynamicSlice, Partition as Partition, Partition1D as Partition1D, + ThreadSubset as ThreadSubset, bitwidth as bitwidth, bytewidth as bytewidth, c as c, diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 31f2fdb04bb1..1239a20ba865 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -1171,9 +1171,11 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]: sub_op.operation.regions[0].blocks[0] ): assert block_predicate is None - block_predicate = 
utils.single_thread_predicate(per_block=True) + block_predicate = utils.single_thread_predicate( + scope=utils.ThreadSubset.BLOCK + ) warpgroup_predicate = utils.single_thread_predicate( - per_block=False + scope=utils.ThreadSubset.WARPGROUP ) if block_predicate is None: diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index dc59dda3a6e5..57f30b8603c8 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -299,7 +299,7 @@ def kv_loop(kv_step, carry): scf.yield_([]) with ir.InsertionPoint(if_compute.else_block): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) @@ -391,7 +391,7 @@ def only_wg(idx): kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def kv_copy_init(slot, kv_seq_base): - with single_thread(per_block=False): + with single_thread(ThreadSubset.WARPGROUP): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index df1e03627f94..76f7d549cf55 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1588,7 +1588,7 @@ def reduce_sum(self, scratch: ir.Value | None = None): memref.store(warp_result, scratch, [warp_id]) utils.warpgroup_barrier() zero_index = c(0, index) - with mgpu.single_thread(per_block=False): + with mgpu.single_thread(scope=mgpu.ThreadSubset.WARPGROUP): scratch_vec = vector.load( ir.VectorType.get((4,), self.mlir_dtype), scratch, diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 9d3ab8c2e744..64cdedc779c8 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -657,7 +657,8 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ] uniform_ctx = ( - functools.partial(utils.single_thread, per_block=False) + functools.partial( + utils.single_thread, scope=utils.ThreadSubset.WARPGROUP) if uniform and predicate is None else contextlib.nullcontext ) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 47401440fac2..3c7532dde99d 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -164,7 +164,7 @@ def debug_print(fmt, *args, uniform=True): raise NotImplementedError(arg.type) type_formats.append(ty_format) ctx = ( - functools.partial(single_thread, per_block=False) + functools.partial(single_thread, scope=ThreadSubset.WARPGROUP) if uniform else contextlib.nullcontext ) @@ -258,6 +258,7 @@ def warpgroup_idx(sync=True): class ThreadSubset(enum.IntEnum): + WARP = enum.auto() WARPGROUP = enum.auto() BLOCK = enum.auto() @@ -266,25 +267,34 @@ class ThreadSubset(enum.IntEnum): _ONCE_PER: ThreadSubset | None = None -def single_thread_predicate(per_block=True): +def single_thread_predicate(scope: ThreadSubset = ThreadSubset.BLOCK): + """Returns a predicate that selects a single thread. + + Args: + scope: What level of the thread hierarchy to select a thread from. 
+ For example, if the scope is BLOCK, only one thread per block will be + selected. + """ + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + if scope == ThreadSubset.WARP: + return elected warp = warp_idx() - if not per_block: + if scope is not ThreadSubset.BLOCK: warp = arith.remui(warp, c(4, warp.type)) first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) - elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) return arith.andi(first_warp, elected) @contextlib.contextmanager -def single_thread(per_block=True): +def single_thread(scope: ThreadSubset = ThreadSubset.BLOCK): """Runs the context only from a single thread. Args: - per_block: If True, only one thread per block will run the context. - Otherwise, only one thread per warp group will run the context. + scope: What level of the thread hierarchy to select a thread from. + For example, if the scope is BLOCK, only one thread per block will be + selected. """ global _ONCE_PER - scope = ThreadSubset.BLOCK if per_block else ThreadSubset.WARPGROUP # If we're already in a single-thread context, we don't have to do anything. if _ONCE_PER is not None and _ONCE_PER >= scope: yield @@ -293,7 +303,7 @@ def single_thread(per_block=True): prev_scope = _ONCE_PER _ONCE_PER = scope try: - if_op = scf.IfOp(single_thread_predicate(per_block)) + if_op = scf.IfOp(single_thread_predicate(scope)) with ir.InsertionPoint(if_op.then_block): yield scf.YieldOp([]) @@ -708,7 +718,7 @@ def initialize(address: ir.Value, num_barriers: int, arrival_count: int = 1) -> ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>") phases = memref.alloca(ir.MemRefType.get((), i32), [], []) memref.store(c(0, i32), phases, []) - with single_thread(per_block=True): + with single_thread(scope=ThreadSubset.BLOCK): for i in range(num_barriers): nvvm.mbarrier_init_shared( llvm.getelementptr(ptr, address, [], [i], i64), @@ -870,7 +880,7 @@ def arrive(self): if self.barrier.num_barriers != 1: raise ValueError("Can only arrive on a single barrier") if self.cluster_mask is None: - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): self.barrier.arrive() return i32 = ir.IntegerType.get_signless(32) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 85e512d03290..63ace2baa64e 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -32,6 +32,7 @@ from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform +from jax._src.pallas.mosaic_gpu.core import WarpMesh as WarpMesh from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 9ef1fcdea24e..fd4d5bb73f52 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -29,6 +29,7 @@ from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.pallas import pallas_call +from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as 
mgpu_pipeline @@ -1524,6 +1525,64 @@ def kernel(x_ref, y_ref, o_ref): y = jax.lax.iota(jnp.float32, 128) * 3 np.testing.assert_array_equal(kernel(x, y), x + y) + def test_warp_specialization_axis_index(self): + if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: + self.skipTest("Test only works on Lane semantics") + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32)) + def kernel(y_ref): + def scope(ones_smem_ref, threes_smem_ref): + # Prepare data to copy. + ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32) + threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3 + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + # We cannot load/store inside of core_map, so we issue async + # copies instead to produce a testable result. + @pl.when(warp_id == 1) + def _(): + plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1]) + @pl.when(warp_id == 3) + def _(): + plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2]) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + plgpu.SMEM((1, 128), jnp.int32), + plgpu.SMEM((1, 128), jnp.int32) + ) + result = kernel() + expected = jnp.stack((jnp.ones((128,), jnp.int32), + jnp.ones((128,), jnp.int32) * 3), axis=0) + np.testing.assert_array_equal(result, expected) + + def test_warp_mesh_errors_when_closing_over_array(self): + if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: + self.skipTest("Test only works on Lane semantics") + # We currently do not allow closing over arrays when mapping over + # a mesh, since we would need to present a view of the array local + # to each warp. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32), + scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)]) + def kernel(out_ref, smem_ref): + arr = jnp.ones((32, 32), dtype=jnp.float32) + @pl.core_map(warp_mesh) + def _(): + smem_ref[...] = arr + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, out_ref) + plgpu.wait_smem_to_gmem(0) + with self.assertRaisesRegex( + mgpu_lowering.LoweringError, + "Can only close over scalars and Refs when using core_map with " + "WarpMesh", + ): + kernel() + class PallasCallWGTest( PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup @@ -1549,6 +1608,7 @@ def test_missing_primitive_lowerings_are_tracked(self): mgpu_primitives.broadcasted_iota_p, mgpu_primitives.load_p, lax.slice_p, + pallas_core.core_map_p, } # TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. From cf8a52463c4d077e78ad8ecca67cf5b09c62e4f4 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 10 Apr 2025 14:00:16 -0700 Subject: [PATCH 0545/1769] Update test shardings. This change primarily reduces sharding, although in a few cases it also increases shardings. It is harmful to performance to overshard tests since there's a startup and teardown cost to each test run. In a few cases, change tests to be non-accelerator tests. 
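For reference, the knob being tuned throughout this change is `shard_count`; a representative (hypothetical) target showing the trade-off:

```python
# BUILD (Starlark) -- illustrative target, not part of the diff.
jax_multiplatform_test(
    name = "example_test",
    srcs = ["example_test.py"],
    # Each shard runs as a separate test process, so higher counts
    # parallelize long suites but re-pay startup/teardown per shard.
    shard_count = {
        "cpu": 10,
        "gpu": 5,
        "tpu": 5,
    },
)
```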
PiperOrigin-RevId: 746164539 --- tests/BUILD | 144 +++++++++++++++++----------------------------------- 1 file changed, 47 insertions(+), 97 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 5e2272b57f4f..876b760bc4d3 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -60,10 +60,13 @@ jax_multiplatform_test( srcs = ["device_test.py"], ) -jax_multiplatform_test( +jax_py_test( name = "dynamic_api_test", srcs = ["dynamic_api_test.py"], - shard_count = 2, + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -206,7 +209,7 @@ jax_multiplatform_test( ], # Times out on TPU with asan/tsan. }, shard_count = { - "tpu": 20, + "tpu": 10, "cpu": 20, "gpu": 10, }, @@ -244,9 +247,7 @@ jax_multiplatform_test( # using matplotlib plots # env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, shard_count = { - "cpu": 48, - "gpu": 48, - "tpu": 48, + "cpu": 8, }, deps = [ "//jax:experimental_sparse", @@ -257,9 +258,9 @@ jax_multiplatform_test( name = "svd_test", srcs = ["svd_test.py"], shard_count = { - "cpu": 10, + "cpu": 20, "gpu": 10, - "tpu": 40, + "tpu": 15, }, ) @@ -285,9 +286,6 @@ jax_multiplatform_test( "gpu_p100x2_shardy", "tpu_v5e_4x2_shardy", ], - shard_count = { - "tpu": 5, - }, deps = [ "//jax:experimental", ], @@ -307,9 +305,8 @@ jax_multiplatform_test( "gpu_h100x2", ], shard_count = { - "cpu": 5, - "gpu": 5, - "tpu": 5, + "cpu": 3, + "tpu": 4, }, tags = ["multiaccelerator"], deps = [ @@ -423,8 +420,8 @@ jax_multiplatform_test( srcs = ["image_test.py"], shard_count = { "cpu": 10, - "gpu": 20, - "tpu": 10, + "gpu": 10, + "tpu": 8, }, tags = ["noasan"], # Linking TF causes a linker OOM. deps = py_deps("pil") + py_deps("tensorflow_core"), @@ -433,8 +430,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], - deps = [ - ], ) jax_multiplatform_test( @@ -468,7 +463,7 @@ jax_multiplatform_test( srcs = ["jet_test.py"], shard_count = { "cpu": 10, - "gpu": 10, + "gpu": 4, }, deps = [ "//jax:jet", @@ -481,8 +476,8 @@ jax_multiplatform_test( srcs = ["lax_control_flow_test.py"], shard_count = { "cpu": 30, - "gpu": 40, - "tpu": 30, + "gpu": 30, + "tpu": 20, }, ) @@ -547,11 +542,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, ) jax_multiplatform_test( @@ -559,8 +549,8 @@ jax_multiplatform_test( srcs = ["lax_numpy_ufuncs_test.py"], shard_count = { "cpu": 10, - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, ) @@ -573,9 +563,9 @@ jax_multiplatform_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], shard_count = { - "cpu": 20, + "cpu": 30, "gpu": 20, - "tpu": 20, + "tpu": 8, }, deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) @@ -588,8 +578,8 @@ jax_multiplatform_test( }, shard_count = { "cpu": 10, - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, ) @@ -605,7 +595,7 @@ jax_multiplatform_test( }, shard_count = { "cpu": 20, - "gpu": 20, + "gpu": 30, "tpu": 20, }, deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), @@ -616,8 +606,8 @@ jax_multiplatform_test( srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { "cpu": 20, - "gpu": 10, - "tpu": 10, + "gpu": 8, + "tpu": 8, }, deps = [ "//jax:internal_test_util", @@ -658,7 +648,7 @@ jax_multiplatform_test( srcs = ["lax_autodiff_test.py"], shard_count = { "cpu": 40, - "gpu": 40, + "gpu": 30, "tpu": 20, }, ) @@ -831,7 +821,7 @@ jax_multiplatform_test( ], shard_count = { "cpu": 30, - "gpu": 30, 
+ "gpu": 10, "tpu": 30, }, tags = ["multiaccelerator"], @@ -846,7 +836,7 @@ jax_multiplatform_test( # No implementation of nonsymmetric Eigendecomposition. enable_backends = ["cpu"], shard_count = { - "cpu": 10, + "cpu": 5, }, # This test ends up calling Fortran code that initializes some memory and # passes it to C code. MSan is not able to detect that the memory was @@ -907,29 +897,12 @@ jax_multiplatform_test( "notsan", # Times out ], }, - shard_count = 10, + shard_count = 8, ) jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], - backend_tags = { - "cpu": [ - "notsan", # Times out - "nomsan", # Times out - ], - "tpu": [ - "optonly", - "nomsan", # Times out - "notsan", # Times out - ], - }, - shard_count = { - "cpu": 30, - "gpu": 30, - "tpu": 40, - }, - tags = ["noasan"], # Times out ) jax_multiplatform_test( @@ -962,25 +935,7 @@ jax_multiplatform_test( name = "random_test_with_custom_prng", srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], - backend_tags = { - "cpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - ], - "tpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - "optonly", - ], - }, main = "random_test.py", - shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, - }, ) jax_multiplatform_test( @@ -1049,9 +1004,9 @@ jax_multiplatform_test( "tpu": ["nomsan"], # Times out }, shard_count = { - "cpu": 40, - "gpu": 30, - "tpu": 40, + "cpu": 50, + "gpu": 50, + "tpu": 50, }, tags = [ "noasan", @@ -1078,8 +1033,8 @@ jax_multiplatform_test( }, shard_count = { "cpu": 50, - "gpu": 50, - "tpu": 50, + "gpu": 30, + "tpu": 20, }, tags = [ "noasan", @@ -1182,10 +1137,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], - shard_count = { - "cpu": 5, - "gpu": 5, - }, deps = ["//jax:stax"], ) @@ -1314,7 +1265,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "ann_test", srcs = ["ann_test.py"], - shard_count = 10, + shard_count = { + "cpu": 5, + "gpu": 5, + "tpu": 10, + }, ) jax_py_test( @@ -1337,9 +1292,13 @@ jax_multiplatform_test( srcs = ["garbage_collection_guard_test.py"], ) -jax_multiplatform_test( +jax_py_test( name = "name_stack_test", srcs = ["name_stack_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1437,8 +1396,6 @@ jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { - "cpu": 20, - "gpu": 10, "tpu": 20, }, ) @@ -1456,10 +1413,6 @@ jax_multiplatform_test( enable_configs = [ "gpu_p100x2_shardy", ], - shard_count = { - "gpu": 5, - "tpu": 5, - }, tags = [ "multiaccelerator", ], @@ -1583,9 +1536,9 @@ jax_multiplatform_test( "cpu_x32", ], shard_count = { - "cpu": 4, - "gpu": 6, - "tpu": 4, + "cpu": 30, + "gpu": 20, + "tpu": 25, }, tags = [ "noasan", # Times out @@ -1633,9 +1586,6 @@ jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], enable_backends = ["gpu"], - shard_count = { - "gpu": 4, - }, tags = ["multiaccelerator"], ) From 3864c4f335d1d236d5367264f3885dfce8721d9d Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 10 Apr 2025 15:08:11 -0700 Subject: [PATCH 0546/1769] Allow ctrl-c to cancel block_until_ready(). Partially addresses: https://github.com/jax-ml/jax/issues/18246. If compile can also be a future, this code can be used to safely block on that as well. 
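To illustrate (a hypothetical sketch, not code from this patch; the function and
sizes are made up): the user-visible effect is that ctrl-c is now honored while
blocking on a device computation, because the wait loop polls for pending Python
signals roughly every 200ms instead of blocking indefinitely.

```python
import jax
import jax.numpy as jnp

@jax.jit
def slow(x):
  # Chain enough matmuls to keep the device busy for a while.
  for _ in range(100):
    x = x @ x
  return x

x = jnp.ones((4096, 4096))
try:
  jax.block_until_ready(slow(x))  # ctrl-c here now raises instead of hanging
except KeyboardInterrupt:
  print("interrupted while waiting for the device")
```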
PiperOrigin-RevId: 746189742 --- jaxlib/xla/BUILD | 6 ++++++ jaxlib/xla/py_array.cc | 6 +++++- jaxlib/xla/util.cc | 27 ++++++++++++++++++++++++++- jaxlib/xla/util.h | 3 +++ 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 56990712c441..a6a4cf660408 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -869,9 +869,15 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@nanobind", "@xla//xla:util", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/python:version", "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:async_value", "@xla//xla/tsl/concurrency:ref_count", ], ) diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index 84b7a726fbc4..e3321c4c88ce 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -814,6 +814,7 @@ absl::Status PyArray::BlockUntilResultStatusIsReady() { if (!result_status.IsReady()) { // Only release the gil if we need to Await(). nb::gil_scoped_release release_gil; + BlockUntilReadyWithCancel(result_status); return result_status.Await(); } return result_status.Await(); @@ -1761,7 +1762,9 @@ absl::StatusOr> PyHostValue::AsNumPyArray( nb::gil_scoped_release gil; TF_ASSIGN_OR_RETURN(hold_ptr->external_reference_hold, pjrt_buffer->AcquireExternalReference()); - TF_RETURN_IF_ERROR(ifrt_array->GetReadyFuture().Await()); + auto fut = ifrt_array->GetReadyFuture(); + BlockUntilReadyWithCancel(fut); + TF_RETURN_IF_ERROR(fut.Await()); } void* data = hold_ptr->external_reference_hold->OpaqueDeviceMemoryDataPointer(); @@ -1775,6 +1778,7 @@ absl::StatusOr> PyHostValue::AsNumPyArray( TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array)); if (!ready_.IsReady()) { nb::gil_scoped_release gil; + BlockUntilReadyWithCancel(ready_); TF_RETURN_IF_ERROR(ready_.Await()); } else { TF_RETURN_IF_ERROR(ready_.Await()); diff --git a/jaxlib/xla/util.cc b/jaxlib/xla/util.cc index ef0fb2ac3afd..5fb3f352ba2c 100644 --- a/jaxlib/xla/util.cc +++ b/jaxlib/xla/util.cc @@ -15,19 +15,44 @@ limitations under the License. 
#include "jaxlib/xla/util.h" +#include #include #include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" #include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/pjrt_future.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/value.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" namespace xla { +void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future) { +#if JAX_IFRT_VERSION_NUMBER >= 5 + future.BlockUntilReady([](tsl::AsyncValue* value) { + auto state = std::make_shared(); + value->AndThen([state]() { state->Notify(); }); + while (true) { + if (state->WaitForNotificationWithTimeout(absl::Milliseconds(200))) { + break; + } + nanobind::gil_scoped_acquire gil_acquire; + if (PyErr_CheckSignals() != 0) { + throw nanobind::python_error(); + } + } + }); +#endif +} + absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { if (ifrt_arrays.empty()) { return absl::OkStatus(); @@ -45,7 +70,7 @@ absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { ifrt::Client* const client = ifrt_arrays.front()->client(); future = client->GetReadyFuture(values); } - + BlockUntilReadyWithCancel(future); absl::Status s = future.Await(); if (!s.ok()) { // Fix up error string because some clients rely on it. diff --git a/jaxlib/xla/util.h b/jaxlib/xla/util.h index ef5fc735fc33..ed3b03d733dd 100644 --- a/jaxlib/xla/util.h +++ b/jaxlib/xla/util.h @@ -22,6 +22,9 @@ limitations under the License. namespace xla { +// Waits until future is ready but will cancel if ctrl-c is pressed. +void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future); + // Requests if given buffers are ready, awaits for results and returns OK if // all of the buffers are ready or the last non-ok status. absl::Status AwaitBuffersReady(absl::Span ifrt_arrays); From 59068ae679ce0a97bc4d11992a8a6ba6539e9142 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 10 Apr 2025 16:25:52 -0700 Subject: [PATCH 0547/1769] Remove unused jaxlib_mlir_capi targets. Also remove some unnecessary LINKOPTS. These are no longer needed now we use the pywrap rules instead. PiperOrigin-RevId: 746216832 --- jaxlib/jax.bzl | 76 ----------------- jaxlib/mlir/_mlir_libs/BUILD.bazel | 132 +---------------------------- jaxlib/tools/BUILD.bazel | 5 +- 3 files changed, 2 insertions(+), 211 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index dadb62f7117e..3c234f5f8c37 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -160,82 +160,6 @@ def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pyt def py_extension(name, srcs, copts, deps, linkopts = []): nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name) -def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []): - """Workaround DLL building issue. - - 1. cc_binary with linkshared enabled cannot produce DLL with symbol - correctly exported. - 2. Even if the DLL is correctly built, the resulting target cannot be - correctly consumed by other targets. 
- - Args: - name: the name of the output target - out: the name of the output DLL filename - deps: deps - srcs: srcs - """ - - # create a dummy library to get the *.def file - dummy_library_name = name + ".dummy.dll" - native.cc_binary( - name = dummy_library_name, - linkshared = 1, - linkstatic = 1, - deps = deps, - target_compatible_with = ["@platforms//os:windows"], - ) - - # .def file with all symbols, not usable - full_def_name = name + ".full.def" - native.filegroup( - name = full_def_name, - srcs = [dummy_library_name], - output_group = "def_file", - target_compatible_with = ["@platforms//os:windows"], - ) - - # say filtered_symbol_prefixes == ["mlir", "chlo"], then construct the regex - # pattern as "^\\s*(mlir|clho)" to use grep - pattern = "^\\s*(" + "|".join(exported_symbol_prefixes) + ")" - - # filtered def_file, only the needed symbols are included - filtered_def_name = name + ".filtered.def" - filtered_def_file = out + ".def" - native.genrule( - name = filtered_def_name, - srcs = [full_def_name], - outs = [filtered_def_file], - cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep -E '{}' $(location :{}) >> $@""".format(out, pattern, full_def_name), - target_compatible_with = ["@platforms//os:windows"], - ) - - # create the desired library - native.cc_binary( - name = out, # this name must be correct, it will be the filename - linkshared = 1, - deps = deps, - win_def_file = filtered_def_file, - target_compatible_with = ["@platforms//os:windows"], - ) - - # however, the created cc_library (a shared library) cannot be correctly - # consumed by other cc_*... - interface_library_file = out + ".if.lib" - native.filegroup( - name = interface_library_file, - srcs = [out], - output_group = "interface_library", - target_compatible_with = ["@platforms//os:windows"], - ) - - # but this one can be correctly consumed, this is our final product - native.cc_import( - name = name, - interface_library = interface_library_file, - shared_library = out, - target_compatible_with = ["@platforms//os:windows"], - ) - ALL_BACKENDS = ["cpu", "gpu", "tpu"] def if_building_jaxlib( diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 6599e50695d4..25f2162685b9 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -15,7 +15,6 @@ load( "//jaxlib:jax.bzl", "if_windows", - "windows_cc_shared_mlir_library", ) load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") load("//jaxlib:symlink_files.bzl", "symlink_inputs") @@ -32,24 +31,12 @@ COPTS = [ "-frtti", ] -LINKOPTS = select({ - "@xla//xla/tsl:macos": [ - "-Wl,-rpath,@loader_path/", - "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", - ], - "@xla//xla/tsl:windows": [], - "//conditions:default": [ - "-Wl,-rpath,$$ORIGIN/", - ], -}) - nanobind_pywrap_extension( name = "_mlir", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:MLIRBindingsPythonCore", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", @@ -63,7 +50,6 @@ nanobind_pywrap_extension( "@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPIGPU", "@llvm-project//mlir:CAPIIR", @@ -78,7 +64,6 @@ nanobind_pywrap_extension( "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPIGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", @@ -92,7 +77,6 @@ nanobind_pywrap_extension( 
"@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:CAPINVGPU", @@ -107,7 +91,6 @@ nanobind_pywrap_extension( "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:CAPILLVM", @@ -122,7 +105,6 @@ nanobind_pywrap_extension( "@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", @@ -136,7 +118,6 @@ nanobind_pywrap_extension( "@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", @@ -148,7 +129,6 @@ nanobind_pywrap_extension( name = "_mosaic_gpu_ext", srcs = ["mosaic_gpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", "@llvm-project//mlir:CAPIIR", @@ -166,7 +146,6 @@ nanobind_pywrap_extension( name = "_tpu_ext", srcs = ["tpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ "//jaxlib/mosaic:tpu_dialect_capi", "@com_google_absl//absl/log:check", @@ -194,7 +173,6 @@ nanobind_pywrap_extension( name = "_triton_ext", srcs = ["triton_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, pytype_srcs = ["_triton_ext.pyi"], deps = [ "@nanobind", @@ -222,37 +200,11 @@ symlink_inputs( ], ) -cc_library( - name = "jaxlib_mlir_capi_shims", - srcs = ["jaxlib_mlir_capi_shims.cc"], - hdrs = ["jaxlib_mlir_capi_shims.h"], - deps = [ - "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:CAPIIR", - "@llvm-project//mlir:GPUPipelines", - "@llvm-project//mlir:GPUToLLVMIRTranslation", - "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:NVVMTarget", - "@llvm-project//mlir:NVVMToLLVMIRTranslation", - ], - alwayslink = 1, -) - -cc_library( - name = "jaxlib_mlir_capi_shims_hdrs", - hdrs = ["jaxlib_mlir_capi_shims.h"], - deps = [ - "@llvm-project//mlir:CAPIIR", - ], -) - # JAX-specific registrations. 
nanobind_pywrap_extension( name = "register_jax_dialects", srcs = ["register_jax_dialects.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ "//jaxlib/mosaic/gpu:mlir_capi", "@llvm-project//mlir:CAPIArith", @@ -283,7 +235,6 @@ nanobind_pywrap_extension( "@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", @@ -303,7 +254,6 @@ nanobind_pywrap_extension( "@shardy//shardy/integrations/python/ir:sdy_module.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", @@ -324,7 +274,6 @@ nanobind_pywrap_extension( "@stablehlo//:chlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", @@ -340,7 +289,6 @@ nanobind_pywrap_extension( "@stablehlo//:stablehlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", @@ -349,82 +297,4 @@ nanobind_pywrap_extension( "@nanobind", "@stablehlo//:stablehlo_capi", ], -) - -# Shared C++ extension library - -cc_library( - name = "jaxlib_mlir_capi_shared_library", - srcs = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"], - "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"], - "//conditions:default": [":libjaxlib_mlir_capi.so"], - }), - deps = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"], - "//conditions:default": [], - }), -) - -cc_library( - name = "jaxlib_mlir_capi_objects", - deps = [ - "//jaxlib/mosaic:tpu_dialect_capi_objects", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects", - "//jaxlib/mosaic/gpu:mlir_capi_objects", - "@llvm-project//mlir:CAPIArithObjects", - "@llvm-project//mlir:CAPIGPUObjects", - "@llvm-project//mlir:CAPIIRObjects", - "@llvm-project//mlir:CAPILLVMObjects", - "@llvm-project//mlir:CAPIMathObjects", - "@llvm-project//mlir:CAPIMemRefObjects", - "@llvm-project//mlir:CAPINVGPUObjects", - "@llvm-project//mlir:CAPINVVMObjects", - "@llvm-project//mlir:CAPISCFObjects", - "@llvm-project//mlir:CAPISparseTensorObjects", - "@llvm-project//mlir:CAPITransformsObjects", - "@llvm-project//mlir:CAPIVectorObjects", - "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", - "@shardy//shardy/integrations/c:sdy_capi_objects", - "@stablehlo//:chlo_capi_objects", - "@stablehlo//:stablehlo_capi_objects", - "@xla//xla/mlir_hlo:CAPIObjects", - ] + if_windows( - [], - [ - "//jaxlib/triton:triton_dialect_capi_objects", - ], - ), -) - -cc_binary( - name = "libjaxlib_mlir_capi.so", - linkopts = [ - "-Wl,-soname=libjaxlib_mlir_capi.so", - "-Wl,-rpath='$$ORIGIN'", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -cc_binary( - name = "libjaxlib_mlir_capi.dylib", - linkopts = [ - "-Wl,-rpath,@loader_path/", - "-Wl,-install_name,@loader_path/libjaxlib_mlir_capi.dylib", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -windows_cc_shared_mlir_library( - name = "jaxlib_mlir_capi_dll", - out = "jaxlib_mlir_capi.dll", - exported_symbol_prefixes = [ - "mlir", - "chlo", - "sdy", - "stablehlo", - ], - deps = [":jaxlib_mlir_capi_objects"], -) +) \ No newline at end of file diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 5b182817339b..3dffe556d821 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -29,7 +29,6 @@ load( load( "//jaxlib:jax.bzl", "PLATFORM_TAGS_DICT", - "if_windows", "jax_py_test", "jax_wheel", 
"pytype_strict_library", @@ -71,9 +70,7 @@ py_binary( "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", - ] + if_windows([ - "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", - ]), + ], deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", From 41a8805d9603ddd1aec105467aed0b0a0e80757e Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 10 Apr 2025 16:27:45 -0700 Subject: [PATCH 0548/1769] [pallas:mgpu] Return types allowed in mgpu.inline_mgpu. PiperOrigin-RevId: 746217405 --- jax/_src/pallas/mosaic_gpu/primitives.py | 190 +++++++++++++++++------ jax/experimental/pallas/mosaic_gpu.py | 3 +- tests/pallas/mosaic_gpu_test.py | 24 ++- 3 files changed, 161 insertions(+), 56 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 99b42c3e1518..070e37f64e2c 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1147,7 +1147,7 @@ def check_no_args(): case Layout.WG_SPLAT: return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter case Layout.WG_STRIDED: - return mgpu.WGStridedFragLayout(*args, **kwargs) + return mgpu.WGStridedFragLayout(*args, **kwargs) # pytype: disable=missing-parameter @dataclasses.dataclass(frozen=True) class ParameterizedLayout: @@ -1462,6 +1462,14 @@ def jaxpr_call( program_ids_treedef=program_ids_treedef, ) + +@dataclasses.dataclass(frozen=True) +class GPUShapeDtypeStruct: + shape: tuple[int, ...] + dtype: jnp.dtype + layout: ParameterizedLayout | Layout + + inline_mgpu_p = jax_core.Primitive("inline_mgpu_p") inline_mgpu_p.multiple_results = True @@ -1471,39 +1479,96 @@ class RefType: ... -def inline_mgpu(*args, arg_types): - flat_args, treedef = jax.tree.flatten(tuple(args)) - flat_types, treedef_ty = jax.tree.flatten(tuple(arg_types)) - if treedef != treedef_ty: - raise ValueError(f"Mismatched type shape: {treedef} != {treedef_ty}") - - # Strip the transforms from the refs since they will be recorded in - # the types. - raw_refs_flat_args = [] - for a, t in zip(flat_args, flat_types): - def traced_ty(ty): - return isinstance(a, jax_core.Tracer) and isinstance(a.aval, ty) - - if isinstance(t, (ParameterizedLayout, Layout)) and traced_ty(jax_core.ShapedArray): - raw_refs_flat_args.append(a) - elif isinstance(t, RefType) and traced_ty(_Ref): - ref, transforms = a, () - if isinstance(a, state_types.TransformedRef): - ref, transforms = ref.ref, ref.transforms - - raw_refs_flat_args.append(ref) - if transforms: - raise NotImplementedError("Transformed refs (or types) are not supported.") - else: - raise ValueError(f"Mismatched type: {a, t}") +def inline_mgpu(arg_types=(), return_type=None): + """Decorate a function that inlines mgpu code. - def inner(f): - return inline_mgpu_p.bind( - *raw_refs_flat_args, - args_treedef=treedef, - flat_types=flat_types, - mgpu_fn=f, + Arguments provided to the decorated function may be Pallas + references or array values. The body will accept the corresponding + mgpu values. + + The decorated function may return a tree of `FragmentedArray`s. 
+
+  ```
+  layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4)
+  @plgpu.inline_mgpu(
+      arg_types=(plgpu.RefType(),),
+      return_type=plgpu.GPUShapeDtypeStruct(
+          (128, 128), dtype, layout=layout
+      ),
+  )
+  def foo(ctx, smem_ref):
+    del ctx
+    x = mgpu.FragmentedArray.load_tiled(smem_ref)
+    y = mgpu.FragmentedArray.splat(
+        mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout
+    )
+    return (x + y)
+
+  arr = foo(smem_ref)
+  ```
+
+  Args:
+
+    arg_types: a sequence of pytrees where the leaves are `RefType` for
+      reference arguments or `Layout`/`ParameterizedLayout` for array
+      arguments.
+
+    return_type: A pytree where the leaves are `GPUShapeDtypeStruct`
+      representing the arrays returned by the decorated function.
+
+  Returns:
+    A decorator that creates a function that inlines mgpu code.
+
+  """
+  flat_arg_types, treedef_ty = jax.tree.flatten(tuple(arg_types))
+  flat_ret_ty, pytree_ret_ty = jax.tree.flatten(return_type)
+  if return_type and not all(isinstance(r, GPUShapeDtypeStruct) for r in flat_ret_ty):
+    raise ValueError(
+        "inline_mgpu_p only supports GPUShapeDtypeStruct return types."
+    )
+  if not all(isinstance(r, (Layout, ParameterizedLayout, RefType)) for r in flat_arg_types):
+    raise ValueError(
+        "inline_mgpu_p only supports Layout, ParameterizedLayout and"
+        " RefType arg types."
+    )
+
+  def inner(f):
+    def wrapper(*args):
+      flat_args, treedef = jax.tree.flatten(tuple(args))
+      if treedef != treedef_ty:
+        raise ValueError(f"Mismatched type shape: {treedef} != {treedef_ty}")
+
+      # Strip the transforms from the refs since they will be recorded in
+      # the types.
+      raw_refs_flat_args = []
+      for a, t in zip(flat_args, flat_arg_types):
+        def traced_ty(ty):
+          return isinstance(a, jax_core.Tracer) and isinstance(a.aval, ty)
+
+        if isinstance(t, (ParameterizedLayout, Layout)) and traced_ty(jax_core.ShapedArray):
+          raw_refs_flat_args.append(a)
+        elif isinstance(t, RefType) and traced_ty(_Ref):
+          ref, transforms = a, ()
+          if isinstance(a, state_types.TransformedRef):
+            ref, transforms = ref.ref, ref.transforms
+
+          raw_refs_flat_args.append(ref)
+          if transforms:
+            raise NotImplementedError("Transformed refs (or types) are not supported.")
+        else:
+          raise ValueError(f"Mismatched type: {a, t}")
+
+      flat_ret = inline_mgpu_p.bind(
+          *raw_refs_flat_args,
+          args_treedef=treedef,
+          flat_ret_ty=flat_ret_ty,
+          pytree_ret_ty=pytree_ret_ty,
+          flat_arg_types=flat_arg_types,
+          mgpu_fn=f,
+      )
+      return jax.tree.unflatten(pytree_ret_ty, flat_ret)
+    return wrapper
+  return inner

@@ -1511,12 +1576,17 @@ def inner(f):
 def _inline_mgpu_abstract_eval(
     *flat_args,
     args_treedef,
-    flat_types,
+    flat_arg_types,
+    flat_ret_ty,
+    pytree_ret_ty,
     mgpu_fn,
 ):
-  del args_treedef, flat_types, mgpu_fn  # Unused.
+  del args_treedef, flat_arg_types, pytree_ret_ty, mgpu_fn  # Unused.
+  aval_return = tuple(
+      jax_core.ShapedArray(x.shape, x.dtype) for x in flat_ret_ty
+  )
   # TODO(cperivol): Let the user set the effects.
- return (), { + return aval_return, { gpu_core._wgmma_pipeline_effect, gpu_core._memory_effect, *itertools.chain.from_iterable( @@ -1529,27 +1599,51 @@ def _inline_mgpu_abstract_eval( @discharge.register_partial_discharge_rule(inline_mgpu_p) def _inline_mgpu_discharge(*args, **kwargs): + del args, kwargs raise NotImplementedError("inline_mgpu_p does not support discharge.") + +def _type_check_mgpu(v, ty): + match (ty, v): + case (RefType(), ir.Value()) if ir.MemRefType.isinstance(v.type): + pass + case (GPUShapeDtypeStruct(), mgpu.FragmentedArray()): + mlir_dtype = mgpu_utils.dtype_to_ir_type(ty.dtype) + if v.mlir_dtype != mlir_dtype or ty.shape != v.shape or v.layout != ty.layout.to_mgpu(): + raise ValueError(f"Array type mismatch at {v} != {ty}.") + case (Layout() , mgpu.FragmentedArray()) | (ParameterizedLayout(), mgpu.FragmentedArray()): + if ty.to_mgpu() != v.layout: + raise ValueError(f"Unexpected layout for {v} (expected: {ty})") + case _: + raise ValueError(f"Unexpected type {ty} for value {v}") + + @lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Lane) def _inline_mgpu_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, mgpu_fn: Callable[..., Any], - flat_types, + flat_arg_types, + flat_ret_ty, + pytree_ret_ty, args_treedef, ): - for a, t in zip(flat_args, flat_types, strict=True): - match a: - case ir.Value() if ir.MemRefType.isinstance(a.type): - # We checked the memory spaces at tracing time. - pass - case mgpu.FragmentedArray(): - if a.layout != t.to_mgpu(): - raise ValueError(f"Unexpected layout for {a} (expected: {t})") - case _: - raise ValueError(f"Unexpected argument {a}") + for a, t in zip(flat_args, flat_arg_types): + _type_check_mgpu(a, t) args = jax.tree.unflatten(args_treedef, flat_args) - mgpu_fn(ctx.launch_ctx, *args) - return () + ret = mgpu_fn(ctx.launch_ctx, *args) + ret_leaves, ret_tree = jax.tree.flatten( + ret, is_leaf=lambda x: isinstance(x, mgpu.FragmentedArray) + ) + + if ret_tree != pytree_ret_ty: + return_type = jax.tree.unflatten(pytree_ret_ty, flat_ret_ty) + raise ValueError( + f"inline_mgpu_p return type tree mismatch: {ret} != {return_type}" + ) + + for ty, r in zip(flat_ret_ty, ret_leaves): + _type_check_mgpu(r, ty) + + return ret_leaves diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 63ace2baa64e..d74ffe6eae1b 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -21,8 +21,8 @@ from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh +from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import kernel as kernel from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform @@ -42,6 +42,7 @@ from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group +from jax._src.pallas.mosaic_gpu.primitives import GPUShapeDtypeStruct as GPUShapeDtypeStruct from 
jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index fd4d5bb73f52..9ad9038dfc49 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -388,17 +388,27 @@ def test_inline_mgpu(self): def kernel(x_ref, o_ref, smem_ref, barrier): plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier) plgpu.barrier_wait(barrier) - arr = jnp.ones_like(x_ref) + layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) @plgpu.inline_mgpu( - smem_ref, - o_ref, - arr, - arg_types=[plgpu.RefType(), plgpu.RefType(), plgpu.Layout.WG_SPLAT(x_ref.shape)], + arg_types=(plgpu.RefType(),), + return_type=plgpu.GPUShapeDtypeStruct( + (128, 128), dtype, layout=layout + ), ) - def _(ctx, smem_ref, o_ref, y): + def foo(ctx, smem_ref): del ctx x = mgpu.FragmentedArray.load_strided(smem_ref) - (x + y).store_untiled(o_ref) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout + ) + return (x + y) + + arr = foo(smem_ref) + @plgpu.inline_mgpu(arg_types=(layout, plgpu.RefType())) + def store(ctx, arr, o_ref): + del ctx + arr.store_untiled(o_ref) + store(arr, o_ref) key = jax.random.key(0) x = (jax.random.uniform(key, (128, 128)) * 42).astype(dtype) From b73bf1a03a7084b59bccf77451e73ef5a3c7f025 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 10 Apr 2025 16:42:54 -0700 Subject: [PATCH 0549/1769] Update JAX continuous workflow to run once every 3 hours instead of 2. We are seeing a higher number of cancellations of the continuous job recently: ``` Canceling since a higher priority waiting request for CI - Wheel Tests (Continuous)-refs/heads/main exists ``` PiperOrigin-RevId: 746222323 --- .github/workflows/wheel_tests_continuous.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 3739c9267730..175fc2f22d4a 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -19,7 +19,7 @@ name: CI - Wheel Tests (Continuous) on: schedule: - - cron: "0 */2 * * *" # Run once every 2 hours + - cron: "0 */3 * * *" # Run once every 3 hours workflow_dispatch: # allows triggering the workflow run manually concurrency: From b352763a177b0cb0d503f6359a8ed9bd60b72e59 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 10 Apr 2025 16:56:41 -0700 Subject: [PATCH 0550/1769] Fix Pallas tests so they work with JAX_TEST_NUM_THREADS >= 1. 
PiperOrigin-RevId: 746226562 --- jax/_src/pallas/fuser/block_spec.py | 7 +++++-- tests/mosaic/matmul_test.py | 1 + tests/pallas/indexing_test.py | 2 ++ tests/pallas/ops_test.py | 1 + tests/pallas/tpu_all_gather_test.py | 1 + tests/pallas/tpu_gmm_test.py | 1 + tests/pallas/tpu_ops_test.py | 2 ++ tests/pallas/tpu_pallas_pipeline_test.py | 1 + tests/pallas/tpu_pallas_test.py | 2 ++ tests/pallas/tpu_splash_attention_kernel_test.py | 1 + 10 files changed, 17 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index de0cdd204f3c..146191bab9b3 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -170,8 +170,11 @@ def get_out_block_indices(self): _illegal = object() -_sp_env = threading.local() -_sp_env.scalar_prefetch = None +class _SpEnv(threading.local): + def __init__(self): + self.scalar_prefetch = None + +_sp_env = _SpEnv() @contextlib.contextmanager diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index d598d7d0c0ec..41e60fbe4c29 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -48,6 +48,7 @@ def wrapper(self, seed): @jtu.with_config(jax_traceback_filtering="off") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class MatmulTestCase(jtu.JaxTestCase): def setUp(self): diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 5430009c5d28..6e9d552e379f 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -127,6 +127,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class IndexerTest(jtu.JaxTestCase): """These are unit tests for the indexer logic, not using pallas_call.""" @@ -246,6 +247,7 @@ def test_ndindexer(self, data): indexer.get_indexer_shape()) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class IndexerOpsTest(PallasBaseTest): def test_multi_indexing_interpreter_only(self): diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 90d8bafb3c91..b3dd61757df8 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -307,6 +307,7 @@ def skip_if_mosaic_gpu(self): self.skipTest("TODO: Mosaic GPU does not support this yet") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsTest(PallasBaseTest): @parameterized.named_parameters( diff --git a/tests/pallas/tpu_all_gather_test.py b/tests/pallas/tpu_all_gather_test.py index 98b3e5b40135..0c9a4b545591 100644 --- a/tests/pallas/tpu_all_gather_test.py +++ b/tests/pallas/tpu_all_gather_test.py @@ -81,6 +81,7 @@ def _array_dtypes(draw): ) + @jtu.thread_unsafe_test_class() # hypothesis is not thread safe class AllGatherTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/pallas/tpu_gmm_test.py b/tests/pallas/tpu_gmm_test.py index 9c416dabaeb1..cadba4c15fa0 100644 --- a/tests/pallas/tpu_gmm_test.py +++ b/tests/pallas/tpu_gmm_test.py @@ -172,6 +172,7 @@ def tolerances( # TODO(tgale): Fix errors with strict dtype promotion. 
@jtu.with_config(jax_numpy_dtype_promotion="standard") + @jtu.thread_unsafe_test_class() # hypothesis is not thread safe class GroupedMatmulTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index c8def2627462..53e5462e20c2 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -66,6 +66,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsTest(PallasBaseTest): @parameterized.product( @@ -491,6 +492,7 @@ def kernel(x, out): np.testing.assert_array_equal(output, expected) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 8e72c49e2598..95014d9e9683 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -1496,6 +1496,7 @@ def run(acc_scratch_ref): grid=(num_cores,), )(x, y) + @jtu.thread_unsafe_test_class() # hypothesis is not thread safe class PaddedPipelineEmitterTest(parameterized.TestCase): def setUp(self): diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 2e773b88fbad..0bb7b45d7944 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1577,6 +1577,7 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): )(x) np.testing.assert_array_equal(y, x) + @jtu.thread_unsafe_test() # Uses a lot of TPU memory. def test_large_array_indexing(self): n = 6 dtype = jnp.bfloat16 @@ -2331,6 +2332,7 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x[8:16, :128]) +@jtu.thread_unsafe_test_class() # debug print test is not thread safe class PallasCallPrintTest(PallasBaseTest): def test_debug_print(self): diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index 240a9c91c02d..8a73f221bb6d 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -329,6 +329,7 @@ def _assert_allclose(self, x, y, **kwargs): np.testing.assert_allclose(x, y, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class SplashAttentionTest(PallasBaseTest): @parameterized.product( is_mqa=(False, True), From 6d57f00b584cc30833a41143c9430f5a80cf2365 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 10 Apr 2025 17:03:27 -0700 Subject: [PATCH 0551/1769] [Mosaic:TPU][Relayout] Add implicit 2nd minor PiperOrigin-RevId: 746228503 --- .../tpu/transforms/apply_vector_layout.cc | 106 +++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 25aebefa4506..9f12da5237bd 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5508,6 +5508,58 @@ void rotateLanes(OpBuilder &builder, xla::Array &vregs, rotateVregs(builder, vregs, amount, 1); } +// Rotate a vreg by a certain amount of rows, and get the low or high bits of +// each sublane after rotation. 
+// +// For these purposes, the vreg is considered to have shape (row_packing * +// target_shape[0], target_shape[1]) +// +// Args: +// vreg: The vreg to rotate +// rotate_amount: The amount to rotate the vreg by. +// rows_per_sublane: The number of rows in a sublane. +// is_high: If true, get the high bits of each sublane, otherwise get low bits. +// +// Returns: +// The rotated vreg. +Value rotateVregRows(OpBuilder &builder, Location loc, Value vreg, + const int64_t rotate_amount, + const int64_t rows_per_sublane, const bool is_high, + const std::array target_shape) { + CHECK_LE(0, rotate_amount); + CHECK_LT(0, rows_per_sublane); + const int64_t bits_per_row = 32 / rows_per_sublane; + const int64_t sublane_rotate_amount = + rotate_amount / rows_per_sublane + (is_high ? 0 : 1); + const int64_t within_sublane_rotate_amount = rotate_amount % rows_per_sublane; + vreg = builder.create(vreg.getLoc(), vreg, + /*amount=*/sublane_rotate_amount, + /*dimension=*/0, /*stride=*/nullptr, + /*stride_dimension=*/nullptr); + if (within_sublane_rotate_amount != 0) { + const VectorType vreg_ty = cast(vreg.getType()); + const VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), target_shape); + vreg = builder.create(loc, i32_vreg_ty, vreg); + if (is_high) { + auto shift_amt = builder.create( + loc, + builder.getIntegerAttr(builder.getI32Type(), + bits_per_row * within_sublane_rotate_amount)); + vreg = builder.create(loc, vreg, shift_amt); + } else { + auto shift_amt = builder.create( + loc, builder.getIntegerAttr( + builder.getI32Type(), + bits_per_row * + (rows_per_sublane - within_sublane_rotate_amount))); + vreg = builder.create(loc, vreg, shift_amt); + } + vreg = builder.create(loc, vreg_ty, vreg); + } + return vreg; +} + // Relayout src_vregs from layout src to layout dst, where dst is the same as // src except that the column offset is dst_col_offset. FailureOr> doColumnShiftRelayout( @@ -6649,6 +6701,59 @@ FailureOr>> changeImplicitDim( src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape)); return std::make_pair(src_candidate, vregs); } + const int64_t sublanes_per_tile = src.sublanesPerTile(target_shape); + CHECK_GT(sublanes_per_tile, 0); + if (src.tiling()[0] % sublanes_per_tile != 0) { + // Tilings such as 32-bit (4, 256) are not used and not supported. + return emitError( + loc, "Not implemented: Rows within tile span multiple sublanes"); + } + const int64_t rows_per_sublane = src.tiling()[0] / sublanes_per_tile; + // Add second minor implicit dim + if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone && + dst_implicit_dim == VectorLayout::ImplicitDim::kSecondMinor) { + // TODO(tlongeri): Detect replicated source 2nd minor as a no-op above + const int64_t src_offset = src.offsets()[0].value_or(0); + // TODO(tlongeri): Do broadcast (different path) for replicated output + const int64_t dst_offset = dst_offset_hints[0].value_or(0); + VectorLayout dst(src.bitwidth(), {dst_offset, src.offsets()[1]}, + src.tiling(), dst_implicit_dim); + xla::Array new_vregs( + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + DCHECK_EQ(*(new_vregs.dimensions().end() - 2), 1); + // Define src_idx outside loop to avoid reallocation + SmallVector src_idx; + new_vregs.Each([&](const absl::Span idx, Value *new_vreg) { + // Shift the desired row from the source vreg to the desired offset for + // the destination vreg. This is done with rotates and, for packed types + // with multiple rows per sublane, bitshifts. 
+ // Note that the offset of the source row varies but the destination + // offset is always the same. + const int64_t dst_offset_in_sublane = dst_offset % rows_per_sublane; + // src_row_with_offset is the row of the padded implicit shape that we + // will place in the destination vreg. The first dst vreg along the + // non-implicit 2nd minor has the source row at offset src_offset, the + // second has the source row at offset src_offset+1, etc. + const int64_t src_row_with_offset = *(idx.end() - 3) + src_offset; + src_idx.assign(idx.begin(), idx.end() - 3); + src_idx.push_back(src_row_with_offset / src.tiling()[0]); + src_idx.push_back(idx.back()); + Value vreg = vregs(src_idx); + const int64_t src_offset_in_vreg = src_row_with_offset % src.tiling()[0]; + const int64_t src_offset_in_sublane = + src_row_with_offset % rows_per_sublane; + int64_t row_rotate_amt = dst_offset - src_offset_in_vreg; + if (row_rotate_amt < 0) { + row_rotate_amt += rows_per_sublane * target_shape[0]; + } + *new_vreg = rotateVregRows( + builder, loc, vreg, row_rotate_amt, rows_per_sublane, + /*is_high=*/src_offset_in_sublane <= dst_offset_in_sublane, + ctx.target_shape); + }); + return std::make_pair(dst, new_vregs); + } + // Remove second minor implicit dim, for values that have (m, 128) tiling (for // m that is a power of 2). if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && @@ -6675,7 +6780,6 @@ FailureOr>> changeImplicitDim( // For example, extended offsets allow us to skip copies of low sublanes // in tiles with idx.back() == 0. const int tiles_per_vreg = src.tilesPerVreg(target_shape); - const int sublanes_per_tile = src.sublanesPerTile(target_shape); src_idx[dst_2nd_minor_idx] = src.tiling()[0] * idx[dst_2nd_minor_idx] + dst_sl_start - dst_sublane_offset; for (int dst_sl_idx = dst_sl_start; From 6e52b1e95b62d3c0c215596955a127307fdbed81 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 11 Apr 2025 00:25:42 +0000 Subject: [PATCH 0552/1769] optimize while_loop by moving readonly carry components to be consts also fix a bug in ordered effects in cond_fun lowering fixes google/flax#4700 --- jax/_src/lax/control_flow/loops.py | 42 ++++++++++++++++++++++++------ tests/checkify_test.py | 4 +-- tests/lax_control_flow_test.py | 23 +++++++++++++++- 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index babffa1d47d7..689e1f535259 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1461,9 +1461,34 @@ def _create_jaxpr(init_val): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') + + # If the body forwards an input carry to an output carry, *and* it's not used + # by the cond fun, it can be moved to be a body const. Doing so can lead to + # efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch + # the carry too, but not the body consts. 
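For intuition, here is a NumPy model (an illustration only, not the MLIR
implementation; the packing layout is an assumption for the demo) of the
arithmetic in `rotateVregRows` above: a rotation of the packed row dimension
decomposes into a whole-sublane rotation plus an in-sublane bit shift, and the
function returns either the high or the low half, with OR-ing the two halves
giving the full rotation.

```python
import numpy as np

def rotate_packed_rows(sublanes, amount, rows_per_sublane, bits_per_row):
  # Each uint64 entry packs `rows_per_sublane` rows of `bits_per_row` bits,
  # with row 0 in the low bits. Rotating the row dimension by `amount` is a
  # whole-sublane rotate (np.roll) plus an in-sublane shift, mirroring the
  # sublane_rotate_amount / within_sublane_rotate_amount split above.
  mask = np.uint64((1 << (rows_per_sublane * bits_per_row)) - 1)
  q, r = divmod(amount, rows_per_sublane)
  high = (np.roll(sublanes, q) << np.uint64(bits_per_row * r)) & mask
  if r == 0:
    return high  # the rotation is sublane-aligned; no bit shifting needed
  low = np.roll(sublanes, q + 1) >> np.uint64(bits_per_row * (rows_per_sublane - r))
  return high | low

# Check against a direct rotation of the unpacked rows.
R, S, b = 4, 8, 8  # rows per sublane, sublanes, bits per row
rng = np.random.default_rng(0)
rows = rng.integers(0, 1 << b, size=R * S, dtype=np.uint64)
pack = lambda rs: np.array(
    [sum(int(rs[s * R + j]) << (j * b) for j in range(R)) for s in range(S)],
    dtype=np.uint64)
packed = pack(rows)
for amount in range(R * S):
  assert np.array_equal(rotate_packed_rows(packed, amount, R, b),
                        pack(np.roll(rows, amount)))
```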
+ body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr) + _, carry_fwd = split_list(body_fwd, [len(body_consts)]) + cond_jaxpr_, keep_cond = pe.dce_jaxpr( + cond_jaxpr.jaxpr, [True], + [True] * len(cond_consts) + [i != f for i, f in enumerate(body_fwd)]) + _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) + move_to_const = [i == f and not k for i, (f, k) + in enumerate(zip(body_fwd, keep_cond_carry))] + if any(move_to_const): + cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) + body_jaxpr = pe.prune_closed_jaxpr_outputs( + body_jaxpr, [not m for m in move_to_const]) + body_jaxpr = pe.move_binders_to_front( + body_jaxpr, [False] * len(body_consts) + move_to_const) + init_vals, new_body_consts = partition_list(move_to_const, init_vals) + body_consts = [*new_body_consts, *body_consts] + outs = while_p.bind(*cond_consts, *body_consts, *init_vals, cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr, body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) + + if any(move_to_const): + outs = pe.merge_lists(move_to_const, outs, new_body_consts) return tree_unflatten(body_tree, outs) @@ -1839,18 +1864,19 @@ def cond(args): pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): - return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) + return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args) def new_cond(pred_args): - pred, _ = pred_args + pred, *_ = pred_args return pred def new_body(pred_args): - _, args = pred_args - args = body(args) - pred = cond(args) - return pred, args + _, cond_consts, body_consts, carry = pred_args + carry = body((*body_consts, *carry)) + pred = cond((*cond_consts, *carry)) + return pred, cond_consts, body_consts, carry def fun(*args): - pred = cond(args) - _, out = while_loop(new_cond, new_body, (pred, args)) + cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) + pred = cond((*cond_consts, *carry)) + *_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry)) return out return mlir.lower_fun(fun)(ctx, *args) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 5ea99d20a2ab..2f4b7d511fbe 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -492,8 +492,8 @@ def f(x: jax.Array) -> jax.Array: def test_while_loop_body_and_cond_error(self): def while_cond(val): i, cond_val, _ = val - _ = jnp.sin(cond_val) - return i < 2 + j = jnp.sin(cond_val) + return i + (0. * j) < 2 # don't let the sin value be dead code def while_body(val): i, cond_val, body_val = val diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index a987d9e4c192..3034096cee57 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2362,7 +2362,7 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x), + x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), 1., lambda x: x) elif loop == "fori_inside_scan": func = lambda x: lax.scan( @@ -3122,6 +3122,27 @@ def body(c): return x + y jax.linearize(f, 1., 2.) 
# don't crash + def test_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4700 + def foo(w, x, c_max): + def while_cond(val): + c, x, w = val + return c < c_max + + def while_body(val): + c, x, w = val + return c + 1, x @ w, w + + _, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w)) + return w, x + + w = jnp.ones((2, 2)) + xs = jnp.ones((4, 2)) + c_maxs = jnp.arange(4) + w_, _ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0) + )(w, xs, c_maxs) # doesn't crash + self.assertAllClose(w, w_, check_dtypes=False) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From ffc33abb5dfcbbc80f68d2dc034d0009785ccae1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Apr 2025 01:41:31 +0000 Subject: [PATCH 0553/1769] Bump scipy build requirement on Python 3.13. We need v1.15.2 for Linux aarch64 3.13-t wheels to exist. --- build/requirements.in | 3 +- build/requirements_lock_3_13_ft.txt | 88 +++++++++++++++-------------- 2 files changed, 49 insertions(+), 42 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index ec7fc71b07e1..b023cedfbd19 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -13,7 +13,8 @@ numpy~=2.1.0; python_version>="3.13" # # runtime deps # -scipy>=1.13.1 +scipy>=1.13.1; python_version<="3.12" +scipy>=1.15.2; python_version>="3.13" ml_dtypes>=0.4.0 opt_einsum diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index a96a3e6e489b..507e896ab8db 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -596,47 +596,53 @@ rich==13.9.4 \ --hash=sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098 \ --hash=sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90 # via -r build/test-requirements.txt -scipy==1.15.0 \ - --hash=sha256:0e5b34f8894f9904cc578008d1a9467829c1817e9f9cb45e6d6eeb61d2ab7731 \ - --hash=sha256:0fcb16eb04d84670722ce8d93b05257df471704c913cb0ff9dc5a1c31d1e9422 \ - --hash=sha256:129f899ed275c0515d553b8d31696924e2ca87d1972421e46c376b9eb87de3d2 \ - --hash=sha256:161f80a98047c219c257bf5ce1777c574bde36b9d962a46b20d0d7e531f86863 \ - --hash=sha256:1b29e4fc02e155a5fd1165f1e6a73edfdd110470736b0f48bcbe48083f0eee37 \ - --hash=sha256:1e2448acd79c6374583581a1ded32ac71a00c2b9c62dfa87a40e1dd2520be111 \ - --hash=sha256:2240e1fd0782e62e1aacdc7234212ee271d810f67e9cd3b8d521003a82603ef8 \ - --hash=sha256:300742e2cc94e36a2880ebe464a1c8b4352a7b0f3e36ec3d2ac006cdbe0219ac \ - --hash=sha256:327163ad73e54541a675240708244644294cb0a65cca420c9c79baeb9648e479 \ - --hash=sha256:351899dd2a801edd3691622172bc8ea01064b1cada794f8641b89a7dc5418db6 \ - --hash=sha256:35c68f7044b4e7ad73a3e68e513dda946989e523df9b062bd3cf401a1a882192 \ - --hash=sha256:36be480e512d38db67f377add5b759fb117edd987f4791cdf58e59b26962bee4 \ - --hash=sha256:37ce9394cdcd7c5f437583fc6ef91bd290014993900643fdfc7af9b052d1613b \ - --hash=sha256:46e91b5b16909ff79224b56e19cbad65ca500b3afda69225820aa3afbf9ec020 \ - --hash=sha256:4e08c6a36f46abaedf765dd2dfcd3698fa4bd7e311a9abb2d80e33d9b2d72c34 \ - --hash=sha256:52475011be29dfcbecc3dfe3060e471ac5155d72e9233e8d5616b84e2b542054 \ - --hash=sha256:5972e3f96f7dda4fd3bb85906a17338e65eaddfe47f750e240f22b331c08858e \ - --hash=sha256:5abbdc6ede5c5fed7910cf406a948e2c0869231c0db091593a6b2fa78be77e5d \ - --hash=sha256:5beb0a2200372b7416ec73fdae94fe81a6e85e44eb49c35a11ac356d2b8eccc6 \ - --hash=sha256:61513b989ee8d5218fbeb178b2d51534ecaddba050db949ae99eeb3d12f6825d \ - 
--hash=sha256:6d26f17c64abd6c6c2dfb39920f61518cc9e213d034b45b2380e32ba78fde4c0 \ - --hash=sha256:6f376d7c767731477bac25a85d0118efdc94a572c6b60decb1ee48bf2391a73b \ - --hash=sha256:767e8cf6562931f8312f4faa7ddea412cb783d8df49e62c44d00d89f41f9bbe8 \ - --hash=sha256:82bff2eb01ccf7cea8b6ee5274c2dbeadfdac97919da308ee6d8e5bcbe846443 \ - --hash=sha256:952d2e9eaa787f0a9e95b6e85da3654791b57a156c3e6609e65cc5176ccfe6f2 \ - --hash=sha256:9c8254fe21dd2c6c8f7757035ec0c31daecf3bb3cffd93bc1ca661b731d28136 \ - --hash=sha256:aeac60d3562a7bf2f35549bdfdb6b1751c50590f55ce7322b4b2fc821dc27fca \ - --hash=sha256:b1432102254b6dc7766d081fa92df87832ac25ff0b3d3a940f37276e63eb74ff \ - --hash=sha256:bdca4c7bb8dc41307e5f39e9e5d19c707d8e20a29845e7533b3bb20a9d4ccba0 \ - --hash=sha256:c9624eeae79b18cab1a31944b5ef87aa14b125d6ab69b71db22f0dbd962caf1e \ - --hash=sha256:ccb6248a9987193fe74363a2d73b93bc2c546e0728bd786050b7aef6e17db03c \ - --hash=sha256:cd9d9198a7fd9a77f0eb5105ea9734df26f41faeb2a88a0e62e5245506f7b6df \ - --hash=sha256:d13bbc0658c11f3d19df4138336e4bce2c4fbd78c2755be4bf7b8e235481557f \ - --hash=sha256:d35aef233b098e4de88b1eac29f0df378278e7e250a915766786b773309137c4 \ - --hash=sha256:de112c2dae53107cfeaf65101419662ac0a54e9a088c17958b51c95dac5de56d \ - --hash=sha256:e9baff912ea4f78a543d183ed6f5b3bea9784509b948227daaf6f10727a0e2e5 \ - --hash=sha256:eb1533c59f0ec6c55871206f15a5c72d1fae7ad3c0a8ca33ca88f7c309bbbf8c \ - --hash=sha256:ec915cd26d76f6fc7ae8522f74f5b2accf39546f341c771bb2297f3871934a52 \ - --hash=sha256:fde0f3104dfa1dfbc1f230f65506532d0558d43188789eaf68f97e106249a913 \ - --hash=sha256:fe00169cf875bed0b3c40e4da45b57037dc21d7c7bf0c85ed75f210c281488f1 +scipy==1.15.2 ; python_version >= "3.13" \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + --hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + 
--hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db # via -r build/requirements.in six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ From 9f5f6edb85487569127429c7ac8be70b3d8cb2f9 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 10 Apr 2025 19:09:44 -0700 Subject: [PATCH 0554/1769] [Pallas] Fix integer array indexing Fixes https://github.com/google/jax/issues/22783 jax-fixit PiperOrigin-RevId: 746260869 --- jax/_src/state/discharge.py | 218 ++++++++++++++++++++++++---------- jax/_src/state/indexing.py | 20 +++- jax/_src/state/primitives.py | 197 ++++++++++++++++++++++++++---- tests/pallas/indexing_test.py | 19 ++- tests/state_test.py | 49 ++++---- 5 files changed, 377 insertions(+), 126 deletions(-) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 7ab77d5b1c37..615fa862bf31 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -275,33 +275,97 @@ def _maybe_convert_to_dynamic_slice( return starts, sizes, squeeze_dims -def _convert_to_array_indexer(indexer: indexing.NDIndexer - ) -> tuple[int | Array, ...]: - # This is the general gather case. We need to create the gather arrays. - is_integer_indexer, _, integer_indexer = ( - indexing.unpack_ndindexer(indexer) +# In this code, indexing is handled in three ways: `slice`, `dynamic_slice`, and +# gather. For the gather case, the goal is to create a gather array, which means +# that we need to convert all other types of indexers into integer array +# indexers. 
This is done by looping over all indexers and checking if they are +# not integer array indexers, and if not, performing the conversion. However, +# during this process, the indexing semantics may change. Specifically, +# according to the indexing rules of NumPy, when there are integer array +# indexers separated by other indexers, the axes corresponding to the integer +# array indexers need to be moved to the front. After we convert all other +# indexers to integer array indexers, the distinction between integer array +# indexers and other types of indexers is lost. As a result, it becomes +# impossible to determine which axes should be moved to the front. In this case, +# we need to transpose the target array before the gather operation. We also +# need to transpose the target array back after the gather operation, if it is +# used in subsequent computations. +def _maybe_transpose_before_gather( + indexer: indexing.NDIndexer +) -> tuple[int, ...] | None: + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) - total_shape = indexer.get_indexer_shape() - int_indexer_shape = indexer.int_indexer_shape - slice_shape = total_shape[len(int_indexer_shape):] - slice_dims = tuple( - i + len(int_indexer_shape) for i in range(len(slice_shape)) + if int_indexers_contiguous: + return None # no transpose needed + + int_indexer_idxs: list[int] = [] + non_int_indexer_idxs: list[int] = [] + for i, is_int_index in enumerate(is_int_indexing): + (int_indexer_idxs if is_int_index else non_int_indexer_idxs).append(i) + transpose_order = (*int_indexer_idxs, *non_int_indexer_idxs) + return transpose_order + + +def _perform_transpose_before_gather( + target_arr: Array, + indexer: indexing.NDIndexer, + transpose_order: tuple[int, ...], +) -> tuple[Array, indexing.NDIndexer]: + new_target_arr = target_arr.transpose(transpose_order) + reordered_indices = tuple(indexer.indices[i] for i in transpose_order) + new_indexer = indexing.NDIndexer( + indices=reordered_indices, + shape=indexer.shape, + int_indexer_shape=indexer.int_indexer_shape, ) - slice_dim_iter = iter(slice_dims) - slice_indexer: list[Array] = [] - for idx, is_int_index in zip(indexer.indices, is_integer_indexer): - if not is_int_index: - assert isinstance(idx, indexing.Slice) - slice_indices = lax.broadcasted_iota( - np.dtype("int32"), total_shape, next(slice_dim_iter) - ) * idx.stride + idx.start - slice_indexer.append(slice_indices) - integer_indexer = tuple( - lax.expand_dims(idx, (-1,)) for idx in integer_indexer + return new_target_arr, new_indexer + + +def _convert_to_gather_arrays(indexer: indexing.NDIndexer) -> tuple[Array, ...]: + # This is the general gather case. We need to create the gather arrays. 
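The NumPy behavior described in the comment block above is easy to demonstrate. The following snippet is illustrative only (plain NumPy, not part of the change); it shows both the axis-promotion rule and the equivalent pure-gather formulation that `_convert_to_gather_arrays` builds below:

```python
import numpy as np

x = np.zeros((2, 3, 4, 5))
a = np.array([0, 1])  # integer array indexers of shape (2,)

# Contiguous integer array indexers keep their position in the result...
assert x[:, a, a, :].shape == (2, 2, 5)
# ...but once they are separated by another indexer, NumPy moves the
# integer-indexing dimensions to the front of the result.
assert x[a, :, a, :].shape == (2, 3, 5)

# Mixed indexing such as y[a, :, b] is equivalent to a pure gather in which
# the slice is replaced by an iota that broadcasts against the other
# integer indexers.
y = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b = np.array([1, 2])
np.testing.assert_array_equal(
    y[a[:, None], np.arange(3)[None, :], b[:, None]], y[a, :, b])
```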
+ total_shape = indexer.get_indexer_shape() + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + if any(is_int_indexing): + n_idxers = len(indexer.indices) + int_indexer_shape = indexer.int_indexer_shape + n_int_indexers = sum(1 for p in is_int_indexing if p) + last_int_index_idx = n_idxers - 1 - is_int_indexing[::-1].index(True) + n_slice_index_dims_after_int = n_idxers - last_int_index_idx - 1 + + def get_idx_in_shape_after_indexing(i): + if not any(is_int_indexing): + return i + + if i < n_idxers - n_slice_index_dims_after_int - n_int_indexers: + return i + if i < n_idxers - n_slice_index_dims_after_int: + raise ValueError + return i - n_int_indexers + len(int_indexer_shape) + + arrs = [] + for i, idxer in enumerate(indexer.indices): + if isinstance(idxer, indexing.Slice): + idx_in_shape_after_indexing = get_idx_in_shape_after_indexing(i) + arr = ( + lax.iota(np.int32, total_shape[idx_in_shape_after_indexing]) + * idxer.stride + + idxer.start ) - continue - assert next(slice_dim_iter, None) is None - return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + diff = len(total_shape) - idx_in_shape_after_indexing - 1 + arr = arr.reshape(arr.shape + (1,) * diff) + arrs.append(arr) + elif isinstance(idxer, (np.ndarray, Array)): + diff = n_idxers - 1 - last_int_index_idx + arr = idxer.reshape(idxer.shape + (1,) * diff) + arrs.append(arr) + else: + raise ValueError(f"Invalid type of idxer: {type(idxer).__name__}") + + return tuple(arrs) @register_discharge_rule(get_p) @@ -313,20 +377,8 @@ def _get_discharge_rule( y = _get_discharge(x, idx, tree) return (None,) * (len(idx) + 1), y -def _prepend_gather(x, indexer): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - return x[None][(np.array(0, 'int32'), *indexer)] -def _prepend_scatter(x, indexer, val, *, add=False): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - # However, since this is scatter, we need to remove the 1-sized dimension - # we added at the front. - if add: - return x[None].at[(0, *indexer)].add(val)[0] - return x[None].at[(0, *indexer)].set(val)[0] - - -def _index_array(x, indexer): +def _index_array(x, indexer: indexing.NDIndexer): if _is_trivial_indexer(indexer): return x # Try the three APIs in the following order: `lax.slice`, @@ -336,13 +388,16 @@ def _index_array(x, indexer): # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. - elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice + elif maybe_dynamic_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_dynamic_slice y = lax_slicing.dynamic_slice(x, starts, sizes) x = lax.squeeze(y, squeeze_dims) else: - indexer = _convert_to_array_indexer(indexer) - x = x[None][(np.array(0, "int32"), *indexer)] + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) + arrays = _convert_to_gather_arrays(indexer) + x = x[arrays] return x @@ -367,53 +422,79 @@ def transform_array(x, transforms): def transform_swap_array(x, transforms, val): if transforms is None: transforms = [] - result = x - result_val = val - # Compute updated "val" (result). - _results = [x] + + # Will hold the value read from `x` before the swap, and will have the same + # shape as `val`. 
+ new_val = x + # List of intermediate results by transforming `x`. + intermediates = [x] + + # Read phase (forward loop) for transform in transforms: match transform: case indexing.NDIndexer(): indexer = transform if _is_trivial_indexer(indexer): - _results.append(_results[-1]) + intermediates.append(intermediates[-1]) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice - result_old = lax_slicing.dynamic_slice(result, starts, sizes) - result = lax.squeeze(result_old, squeeze_dims) + new_val = lax.squeeze( + lax_slicing.dynamic_slice(new_val, starts, sizes), squeeze_dims + ) else: - indexer = _convert_to_array_indexer(indexer) - result = _prepend_gather(result, indexer) - _results.append(result) + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + new_val, indexer = _perform_transpose_before_gather( + new_val, indexer, transpose_order + ) + arrays = _convert_to_gather_arrays(indexer) + new_val = new_val[arrays] + # Here, we don't need to transpose `new_val` back because it now holds + # the result of the indexing, and is no longer the original array that + # was indexed into. + intermediates.append(new_val) case RefBitcaster(): - _results.append(bitcast(result, transform.dtype)) + intermediates.append(bitcast(new_val, transform.dtype)) case RefReshaper(): - _results.append(result.reshape(transform.shape)) + intermediates.append(new_val.reshape(transform.shape)) case _: raise NotImplementedError(f"Unsupported transform: {transform}") - # Compute updated "x" (result_val) - for i, transform in reversed(list(enumerate(transforms))): + # Will hold the final state of the `x` after `val` has been written to the + # transformed location, and will have the same shape as `x`. 
+  new_x = val
+
+  # Write phase (reversed loop)
+  for intermediate, transform in reversed(list(zip(intermediates[:-1], transforms))):
     if isinstance(transform, indexing.NDIndexer):
       indexer = transform
       if _is_trivial_indexer(indexer):
         continue
       if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
         starts, _, squeeze_dims = maybe_slice
-        result_val = lax.expand_dims(result_val, squeeze_dims)
-        result_val = lax_slicing.dynamic_update_slice(
-            _results[i], result_val, starts
+        new_x = lax_slicing.dynamic_update_slice(
+            intermediate, lax.expand_dims(new_x, squeeze_dims), starts
         )
       else:
-        indexer = _convert_to_array_indexer(indexer)
-        result_val = _prepend_scatter(_results[i], indexer, result_val)
+        transpose_order = _maybe_transpose_before_gather(indexer)
+        if transpose_order is not None:
+          intermediate, indexer = _perform_transpose_before_gather(
+              intermediate, indexer, transpose_order
+          )
+        arrays = _convert_to_gather_arrays(indexer)
+        new_x = intermediate.at[arrays].set(new_x)  # pytype: disable=attribute-error
+        if transpose_order is not None:
+          transpose_order_inversed = np.argsort(transpose_order)
+          new_x = new_x.transpose(transpose_order_inversed)
     else:
       raise NotImplementedError(f"Unsupported transform: {transform}")
-  return result, result_val
+
+  return new_val, new_x
+

 def _get_discharge(x, idx, tree):
   transforms = tree_util.tree_unflatten(tree, idx)
@@ -446,8 +527,10 @@ def _addupdate_discharge(x, val, idx, tree):
   if len(transforms) > 1:
     raise NotImplementedError("Only single indexer is supported.")
   indexer = transforms[0]
+
   if _is_trivial_indexer(indexer):
     return x + val
+
   # If everything in the indexer is a slice or ()-shaped, we can also
   # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
   # We need to squeeze out the 1-sized slices at the end.
@@ -457,8 +540,17 @@ def _addupdate_discharge(x, val, idx, tree):
     val = lax.expand_dims(val, squeeze_dims)
     y = lax_slicing.dynamic_update_slice(x, x_old + val, starts)
     return y
-  indexer = _convert_to_array_indexer(indexer)
-  return _prepend_scatter(x, indexer, val, add=True)
+
+  transpose_order = _maybe_transpose_before_gather(indexer)
+  if transpose_order is not None:
+    x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order)
+  arrays = _convert_to_gather_arrays(indexer)
+  x = x.at[arrays].add(val)
+  if transpose_order is not None:
+    transpose_order_inversed = np.argsort(transpose_order)
+    x = x.transpose(transpose_order_inversed)
+  return x
+

 @weakref_lru_cache
 def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr):
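The write phase above undoes the pre-gather transpose with `np.argsort(transpose_order)`. This works because the argsort of a permutation is exactly its inverse, which this illustrative snippet (plain NumPy, not part of the change) confirms:

```python
import numpy as np

perm = (2, 0, 3, 1)
inv = tuple(np.argsort(perm))  # (1, 3, 0, 2): where each axis went in perm

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
# Transposing by a permutation and then by its argsort is the identity.
np.testing.assert_array_equal(x.transpose(perm).transpose(inv), x)
```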
diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
index e7b581680efe..e6e6b8a5ee25 100644
--- a/jax/_src/state/indexing.py
+++ b/jax/_src/state/indexing.py
@@ -272,11 +272,21 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer:
     return cls(indices, shape, int_indexer_shape, validate=True)

   def get_indexer_shape(self) -> tuple[int | Array, ...]:
-    _, slice_indexers, _ = unpack_ndindexer(self)
-    slice_shape = [s.size for s in slice_indexers]
-    # In NDIndexers, the int_indexer_shape is *always* at the front of the
-    # result.
-    return (*self.int_indexer_shape, *slice_shape)
+    is_int_indexing, slice_indexers, _ = unpack_ndindexer(self)
+
+    slice_shape = tuple(s.size for s in slice_indexers)
+    int_indexers_contiguous = bool(
+        np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
+    )
+    if not int_indexers_contiguous:
+      return self.int_indexer_shape + slice_shape
+
+    has_int_indexers = any(is_int_indexing)
+    if has_int_indexers:
+      pos = is_int_indexing.index(True)
+      return slice_shape[:pos] + self.int_indexer_shape + slice_shape[pos:]
+
+    return slice_shape

   def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
     del shape  # Unused
diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py
index 578ef0bbc328..1237da57f217 100644
--- a/jax/_src/state/primitives.py
+++ b/jax/_src/state/primitives.py
@@ -454,11 +454,52 @@ def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn):

 ## get/swap/addupdate batching rules

-def _batch_indexer(indexer: indexing.NDIndexer, dims,
-                   axis_size: int,
-                   ref_shape: tuple[int, ...],
-                   ref_dim: int | batching.NotMapped,
-                   idx_is_batched: bool) -> indexing.NDIndexer:
+def _batch_indexer(
+    indexer: indexing.NDIndexer,
+    dims,
+    axis_size: int,
+    ref_shape: tuple[int, ...],
+    ref_dim: int | batching.NotMapped,
+    idx_is_batched: bool,
+) -> indexing.NDIndexer:
+  """Converts a batched indexer into an unbatched one.
+
+  This function handles the complexity of `vmap`-style batching, where the
+  `ref` being indexed, the indexer, or both may have batched dimensions. The
+  goal is to produce a new indexer that acts as if applied in a batched
+  context, but without actual batching, enabling downstream code to process
+  it as usual.
+
+  If any index in `indexer` is batched, all array indexers are normalized. If
+  an array indexer contains a batched dimension, that dimension is moved to
+  the front (axis 0). If an array indexer is not batched, it is broadcast to
+  include a batch dimension at the front. This guarantees that all array
+  indexers still have the same shape.
+
+  Slices are passed through unchanged unless they contain dynamic elements
+  and are themselves batched, which is currently unsupported.
+
+  If `ref` is batched (`ref_dim` is not `NotMapped`), we simulate per-example
+  indexing by inserting a new iota array at the position corresponding to
+  `ref_dim` in the indexer.
+
+  It is worth noting that if the array indexers in the original indexer are
+  contiguous, but become non-contiguous in the new indexer due to the
+  insertion of the iota, the dimensions corresponding to the array indexers
+  are moved to the front of the indexing result. The batched dimension will
+  be at axis 0, while the dimensions corresponding to the array indexers in
+  the original indexer will start from axis 1. This causes a mismatch between
+  the original indexer and the new indexer. Callers must take this behavior
+  into account and transpose the arrays involved accordingly.
+
+  Args:
+    indexer: An `NDIndexer` that indexes into `ref`.
+    dims: A pytree with the same structure as `indexer`, indicating which
+      dimension (if any) is batched for each array indexer.
+    axis_size: Size of the batch dimension.
+    ref_shape: Shape of `ref`.
+    ref_dim: The dimension of `ref` that is batched (if any).
+    idx_is_batched: Whether any index in the `indexer` is batched.
+  """
   indices = indexer.indices
   indices_dims = dims.indices
   new_indices: list[Array | indexing.Slice | int] = []
@@ -510,9 +551,9 @@ def _batch_indexer(indexer: indexing.NDIndexer, dims,
   if ref_dim is not batching.not_mapped:
     iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0)
     new_indices.insert(ref_dim, iota)
-  return indexing.NDIndexer(tuple(new_indices), ref_shape,
-                            new_integer_indexer_shape,
-                            validate=True)
+  return indexing.NDIndexer(
+      tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True
+  )

 def _get_vmap(batched_args, batched_dims, *, tree):
   axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
                 if d is not batching.not_mapped}
@@ -527,11 +568,42 @@ def _get_vmap(batched_args, batched_dims, *, tree):
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
   # TODO(sharadmv): handle vmap of multiple indexers
-  indexers = tuple(_batch_indexer(indexer, dims, axis_size,
+  new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                   ref.shape, ref_dim, idx_is_batched)
                    for indexer, dims in zip(indexers, indexers_dims))
-  flat_indexers, tree = tree_util.tree_flatten(indexers)
-  return get_p.bind(ref, *flat_indexers, tree=tree), 0
+  flat_indexers, tree = tree_util.tree_flatten(new_indexers)
+
+  is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
+  int_indexers_contiguous = bool(
+      np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
+  )
+  is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
+  new_int_indexers_contiguous = bool(
+      np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
+  )
+
+  out = get_p.bind(ref, *flat_indexers, tree=tree)
+  if not int_indexers_contiguous:  # will always be moved to the front
+    out_bdim = 0
+  else:  # originally not going to be moved to the front
+    if new_int_indexers_contiguous:  # now not going to be moved to the front
+      out_bdim = is_new_int_indexing.index(True)
+    else:  # now going to be moved to the front
+      original_pos = is_int_indexing.index(True)
+      array_indexer_shape = new_indexers[0].int_indexer_shape
+      array_indexer_len = len(array_indexer_shape)
+
+      transpose_order = list(range(len(out.shape)))
+      transpose_order = (
+          transpose_order[0],
+          *transpose_order[array_indexer_len:array_indexer_len+original_pos],
+          *transpose_order[1:array_indexer_len],
+          *transpose_order[array_indexer_len+original_pos:],
+      )
+
+      out = lax.transpose(out, transpose_order)
+      out_bdim = 0
+  return out, out_bdim
 batching.primitive_batchers[get_p] = _get_vmap
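When `ref` itself carries the batch dimension, the iota that `_batch_indexer` inserts at `ref_dim` implements per-example indexing: example `i` of the batch reads through index `i` of the ref. The same trick in plain NumPy (illustrative, not part of the change):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)  # a batch of 3 rows
idx = np.array([0, 2, 1])        # one column index per batch element
iota = np.arange(3)              # plays the role of the inserted iota
np.testing.assert_array_equal(
    x[iota, idx], np.array([x[0, 0], x[1, 2], x[2, 1]]))
```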

 def _swap_vmap(batched_args, batched_dims, *, tree):
@@ -549,15 +621,59 @@ def _swap_vmap(batched_args, batched_dims, *, tree):
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
   # TODO(sharadmv): handle vmap of multiple indexers
-  indexers = tuple(_batch_indexer(indexer, dims, axis_size,
+  new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                   ref.shape, ref_dim, idx_is_batched)
                    for indexer, dims in zip(indexers, indexers_dims))
-  flat_indexers, tree = tree_util.tree_flatten(indexers)
-  if (ref_is_batched or idx_is_batched) and not val_is_batched:
-    val = batching.broadcast(val, axis_size, 0)
-  if val_is_batched:
-    val = batching.moveaxis(val, val_dim, 0)
-  return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0
+  flat_indexers, tree = tree_util.tree_flatten(new_indexers)
+
+  is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
+  int_indexers_contiguous = bool(
+      np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
+  )
+  is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
+  new_int_indexers_contiguous = bool(
+      np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
+  )
+
+  if not new_int_indexers_contiguous:  # will be moved to the front
+    batched_dim_in_result = 0
+  else:
+    batched_dim_in_result = is_new_int_indexing.index(True)
+
+  if not val_is_batched:
+    if ref_is_batched or idx_is_batched:
+      val = batching.broadcast(val, axis_size, batched_dim_in_result)
+  else:
+    val = batching.moveaxis(val, val_dim, batched_dim_in_result)
+
+  transpose_order_inversed = None
+
+  # Originally not going to be moved to the front, but now going to be moved
+  # to the front.
+  if int_indexers_contiguous and not new_int_indexers_contiguous:
+    original_pos = is_int_indexing.index(True)
+    array_indexer_shape = new_indexers[0].int_indexer_shape
+    array_indexer_len = len(array_indexer_shape)
+
+    transpose_order = list(range(len(val.shape)))
+    transpose_order = (
+        transpose_order[0],
+        *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)],
+        *transpose_order[1:1+original_pos],
+        *transpose_order[(1+original_pos)+(array_indexer_len-1):],
+    )
+    val = val.transpose(transpose_order)
+    transpose_order_inversed = np.argsort(transpose_order)
+
+  out = swap_p.bind(ref, val, *flat_indexers, tree=tree)
+
+  # `val` should not be transposed, but we needed to transpose it to match
+  # `swap_p`. As a result, the output of `swap_p` is also transposed. Now we
+  # need to transpose it back.
+  if transpose_order_inversed is not None:
+    out = out.transpose(transpose_order_inversed)
+
+  return out, batched_dim_in_result
 batching.primitive_batchers[swap_p] = _swap_vmap

 def _addupdate_vmap(batched_args, batched_dims, *, tree):
@@ -575,14 +691,47 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree):
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
   # TODO(sharadmv): handle vmap of multiple indexers
-  indexers = tuple(_batch_indexer(indexer, dims, axis_size,
+  new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                   ref.shape, ref_dim, idx_is_batched)
                    for indexer, dims in zip(indexers, indexers_dims))
-  flat_indexers, tree = tree_util.tree_flatten(indexers)
-  if (ref_is_batched or idx_is_batched) and not val_is_batched:
-    val = batching.broadcast(val, axis_size, 0)
-  if val_is_batched:
-    val = batching.moveaxis(val, val_dim, 0)
+  flat_indexers, tree = tree_util.tree_flatten(new_indexers)
+
+  is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
+  int_indexers_contiguous = bool(
+      np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
+  )
+  is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
+  new_int_indexers_contiguous = bool(
+      np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
+  )
+
+  if not new_int_indexers_contiguous:  # will be moved to the front
+    batched_dim_in_result = 0
+  else:
+    batched_dim_in_result = is_new_int_indexing.index(True)
+
+  if not val_is_batched:
+    if ref_is_batched or idx_is_batched:
+      val = batching.broadcast(val, axis_size, batched_dim_in_result)
+  else:
+    val = batching.moveaxis(val, val_dim, batched_dim_in_result)
+
+  # Originally not going to be moved to the front, but now going to be moved
+  # to the front.
+ if int_indexers_contiguous and not new_int_indexers_contiguous: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(val.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], + *transpose_order[1:1+original_pos], + *transpose_order[(1+original_pos)+(array_indexer_len-1):], + ) + val = val.transpose(transpose_order) + return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), [] batching.primitive_batchers[addupdate_p] = _addupdate_vmap diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 6e9d552e379f..3de0c1f305c6 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -218,12 +218,13 @@ def test_indexer_with_all_types(self): indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None]) indexer = NDIndexer.from_indices_shape(indices, shape) - self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2)) + self.assertTupleEqual(indexer.get_indexer_shape(), (2, 5, 4)) @hp.given(hps.data()) def test_ndindexer(self, data): shape = data.draw(hnp.array_shapes()) indexer = data.draw(nd_indexer_strategy(shape)) + is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices] rest_indexers, int_indexers = util.partition_list( is_int_indexer, indexer.indices @@ -235,16 +236,12 @@ def test_ndindexer(self, data): self.assertTupleEqual( indexer.int_indexer_shape, expected_int_indexer_shape ) + for idx in rest_indexers: self.assertIsInstance(idx, (np.ndarray, Slice)) if isinstance(idx, np.ndarray): self.assertTupleEqual(idx.shape, ()) self.assertEqual(idx.dtype, np.dtype("int32")) - rest_shape = tuple( - r.size for r in rest_indexers if not isinstance(r, np.ndarray) - ) - self.assertTupleEqual((*indexer.int_indexer_shape, *rest_shape), - indexer.get_indexer_shape()) @jtu.thread_unsafe_test_class() # hypothesis is not thread safe @@ -692,18 +689,18 @@ class IndexerOpsInterpretTest(IndexerOpsTest): ((4, 3), lambda arr, a, b, c, d: arr[a, 2]), # slice + 1-D array ((4, 3), lambda arr, a, b, c, d: arr[a, :]), - # ((4, 3), lambda arr, a, b, c, d: arr[:, a]), + ((4, 3), lambda arr, a, b, c, d: arr[:, a]), ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]), - # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), + ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]), ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]), # slice + array w/ broadcasting - ((8, 8, 3, 6), lambda arr, a, b, c, d: \ + ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b[:, None], ::4, a[None], a[:, None]]), # integer + slice + 1-D array ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]), diff --git a/tests/state_test.py b/tests/state_test.py index ab11ab829a66..65f6f0427a00 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -28,6 +28,7 @@ from jax import lax from jax._src import core from jax._src import config +from jax._src import dtypes from jax._src import linear_util as lu 
from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu @@ -477,27 +478,17 @@ def g(r, rdot): op=[ lambda x_ref, indexer: [x_ref[indexer]], lambda x_ref, indexer: [ - ref_swap(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)])], + ref_swap(x_ref, indexer, jnp.ones_like(x_ref[indexer]))], lambda x_ref, indexer: ( - ref_addupdate(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)]) - or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]]) + ref_addupdate(x_ref, indexer, jnp.ones_like(x_ref[indexer])) + or [jnp.ones_like(x_ref[indexer])]), ], ) def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims, idx_bdims, out_bdim, op): - - float_ = (jnp.dtype('float64') if config.enable_x64.value else - jnp.dtype('float32')) - int_ = (jnp.dtype('int64') if config.enable_x64.value else - jnp.dtype('int32')) + intx = dtypes.canonicalize_dtype(jnp.int64) + floatx = dtypes.canonicalize_dtype(jnp.float64) axis_size = 7 - out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b) - if any(indexed_dims): - out_shape = (*idx_shape, *out_shape) def maybe_insert(shape, idx): if idx is None: @@ -505,13 +496,13 @@ def maybe_insert(shape, idx): return tuple_insert(shape, idx, axis_size) batched_ref_shape = maybe_insert(ref_shape, ref_bdim) - ref_aval = shaped_array_ref(ref_shape, float_) - bat_ref_aval = shaped_array_ref(batched_ref_shape, float_) + ref_aval = shaped_array_ref(ref_shape, floatx) + bat_ref_aval = shaped_array_ref(batched_ref_shape, floatx) - idx_avals = [core.ShapedArray(idx_shape, int_) + idx_avals = [core.ShapedArray(idx_shape, intx) for _ in idx_bdims] bat_idx_avals = [ - core.ShapedArray(maybe_insert(idx_shape, idx_bdim), int_) + core.ShapedArray(maybe_insert(idx_shape, idx_bdim), intx) for idx_bdim in idx_bdims] def f(x_ref, *idxs): @@ -531,6 +522,7 @@ def f(x_ref, *idxs): wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs) + # vmap-of-discharge stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) @@ -830,11 +822,22 @@ def index_params(draw): min_size=len(ref_shape), max_size=len(ref_shape))) idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) - if any(indexed_dims): - sliced_shape = (s for s, b in zip(ref_shape, indexed_dims) if not b) - slice_shape = (*idx_shape, *sliced_shape) - else: + if not any(indexed_dims): slice_shape = ref_shape + else: + sliced_shape = tuple(s for s, b in zip(ref_shape, indexed_dims) if not b) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(indexed_dims)[0]) == 1) + ) + if not int_indexers_contiguous: + slice_shape = (*idx_shape, *sliced_shape) + else: + insert_pos = indexed_dims.index(True) + slice_shape = ( + *sliced_shape[:insert_pos], + *idx_shape, + *sliced_shape[insert_pos:], + ) ref_aval = shaped_array_ref(ref_shape, np.float32) idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in range(sum(indexed_dims))) From 7b7d36a8e6105d4bbc4e7cb0b86f171bbf2c884b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 10 Apr 2025 21:32:07 -0700 Subject: [PATCH 0555/1769] Add a 2D test in memories_test. 
PiperOrigin-RevId: 746295338 --- tests/memories_test.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/memories_test.py b/tests/memories_test.py index 570b0c375834..278044eabfd0 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -791,6 +791,36 @@ def f(x): lowered_text = f.lower(jnp.arange(8)).as_text() self.assertIn('_xla_compute_type', lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) + def h(x): + y = g(x) + return y * 3 + + out2 = h(inp) + self.assertArraysEqual(out2, inp * 6) + self.assertEqual(out2.sharding.memory_kind, "pinned_host") + + def test_compute_on_2d(self): + out_s = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host") + + @compute_on("device_host") + @jax.jit + def g(x): + return x * 2 + + @jax.jit + def f(x): + y = g(x) + return y * 3 + + inp = jnp.arange(9943.0) + inp = jnp.reshape(inp, (61, 163)) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(inp).as_text() + self.assertIn("_xla_compute_type", lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) def h(x): y = g(x) From d42d2e88b47eb881928c423b042337bc2fbb1ac9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 11 Apr 2025 01:53:20 -0700 Subject: [PATCH 0556/1769] [Pallas] Interpret dimensions with parallel semantics by traversing the corresponding grid coordinates in randomized order. Note that dynamic grid dimensions with 'parallel' semantics are disallowed. This enables the computation of grid points, with randomized coordinates along 'parallel' dimensions, in Jax/on device. If randomization of grid dimensions with dynamic sizes (i.e. sizes not known at Jax trace time) were allowed, this would require computing these randomizations on the host/on CPU (where one can have arrays of dynamic shape). PiperOrigin-RevId: 746365669 --- jax/_src/pallas/mosaic/interpret.py | 135 ++++++++++++++++-- tests/pallas/tpu_pallas_interpret_test.py | 160 ++++++++++++++++++++++ 2 files changed, 285 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index a980590ba134..13b321424f81 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -20,7 +20,7 @@ import itertools import math import threading -from typing import Any, Literal +from typing import Any, Callable,Literal import jax from jax import lax @@ -38,6 +38,7 @@ from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives +from jax._src.typing import Array from jax._src.util import ( safe_map, safe_zip, @@ -74,10 +75,10 @@ class TPUInterpretParams: is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: "on_wait". - detect_races: If True, a dynamic, happens-before race detector will be - used to detect data races during kernel interpretation. If any races are - detected, a message will be printed and `races.races_found` will be set - to True. + detect_races: If True, a dynamic, happens-before race detector will be used + to detect data races during kernel interpretation. If any races are + detected, a message will be printed and `races.races_found` will be set to + True. Default: False. 
    skip_floating_point_ops: If True, operations that produce only floating
      point values will not be interpreted; instead, their results will be
      replaced with (arrays of) `jnp.inf`. Additionally, any floating point
      operands to any operation will be replaced with (arrays of) `jnp.inf`.
      Default: False.
    uninitialized_memory: If "nan", allocated buffers are initialized to
-      to contain all NaNs (or to their maximum possible value for integers).
-      If "zero", allocated buffers are initialized to all zeros.
+      contain all NaNs (or to their maximum possible value for integers). If
+      "zero", allocated buffers are initialized to all zeros.
       Default: "nan".
+    random_seed: Seed for the random number generator used during
+      interpretation. Currently random numbers are used to randomize the grid
+      coordinates along dimensions with 'parallel' semantics.
+      Default: None.
+    grid_point_recorder: Callback that is invoked by the interpreter for each
+      grid point in the order in which the grid points are traversed. This is
+      intended for inspecting the randomization of coordinates along grid
+      dimensions with 'parallel' semantics.
+      Default: None.
  """
  dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
  detect_races: bool = False
  skip_floating_point_ops: bool = False
  uninitialized_memory: Literal["nan", "zero"] = "nan"
+  random_seed: int | None = None
+  grid_point_recorder: Callable[[tuple[jnp.int32, ...]], None] | None = None


 VectorClock = np.ndarray
@@ -1358,6 +1370,96 @@ def _get_next_indices(grid, indices):
     next_indices.append(jnp.where(carry, 0, i))
   return tuple(reversed(next_indices))

+def _get_parallel_dim_semantics(
+    compiler_params: dict[str, Any], grid: tuple[int, ...]
+) -> tuple[bool, ...]:
+  """Returns a tuple of booleans indicating whether the corresponding dimension in `grid` is parallel."""
+  dimension_semantics = compiler_params.get('mosaic', {}).get(
+      'dimension_semantics', None
+  )
+  if dimension_semantics is None:
+    return (False,) * len(grid)
+  return tuple(ds == 'parallel' for ds in dimension_semantics)
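Because the interpreter receives the compiler params in dictionary form, the semantics lookup above is a plain nested-dict read. This illustrative sketch (made-up literal values, not part of the change) mirrors that lookup and the deterministic permutation drawn for each 'parallel' dimension:

```python
import jax
import jax.numpy as jnp

compiler_params = {'mosaic': {'dimension_semantics': ('parallel', 'arbitrary')}}
grid = (4, 4)
dimension_semantics = compiler_params.get('mosaic', {}).get(
    'dimension_semantics', None)
parallel = tuple(ds == 'parallel' for ds in dimension_semantics)
assert parallel == (True, False)

# A fixed seed yields a reproducible traversal order for the parallel dim.
key = jax.random.key(0)
key, subkey = jax.random.split(key)
coords = jax.random.permutation(subkey, jnp.arange(grid[0], dtype=jnp.int32))

# 'arbitrary' dimensions keep their loop index; 'parallel' dimensions are
# looked up in the permuted coordinates, as _get_grid_point does below.
loop_idx = (2, 3)
grid_point = (int(coords[loop_idx[0]]), loop_idx[1])
```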
+
+_GridPointCoordinatesPerDim = tuple[Array, ...]
+
+def _get_randomized_grid_coordinates(
+    grid: tuple[int, ...],
+    compiler_params: dict[str, Any],
+    random_seed: int | None,
+) -> _GridPointCoordinatesPerDim:
+  """Returns a tuple of randomized coordinates for each 'parallel' dimension in `grid`.
+
+  For a dimension with 'parallel' semantics at position `d` in the grid, the
+  returned tuple contains a random permutation of the sequence `[0,...,
+  grid[d] - 1]` at index `d`. For each dimension with 'arbitrary' semantics,
+  the resulting tuple contains an empty array. (Inserting an empty array for
+  an 'arbitrary' dimension at position `d` in the grid, instead of the
+  sequence `[0,..., grid[d] - 1]`, allows `grid[d]` to be a dynamic value,
+  i.e. a value not known at Jax trace time.)
+
+  Args:
+    grid: Tuple of sizes of the dimensions in the grid.
+    compiler_params: Representation of a `mosaic_core.TPUCompilerParams`
+      object as a dictionary.
+    random_seed: The seed to use for randomizing coordinates in parallel
+      dimensions.
+  """
+  parallel_semantics_per_dim = _get_parallel_dim_semantics(
+      compiler_params, grid
+  )
+
+  key = jax.random.key(random_seed or 0)
+  grid_point_coordinates = []
+  for dim_size, parallel_dim in zip(grid, parallel_semantics_per_dim):
+    if parallel_dim:
+      # The size of a dimension with `parallel` semantics must be known at
+      # Jax trace time. This ensures that the arguments to `jnp.arange` and
+      # `jax.random.permutation` below are valid.
+      dim_size = jax_core.concrete_or_error(None, dim_size)
+
+      coordinates_along_dim = jnp.arange(dim_size, dtype=jnp.int32)
+      key, subkey = jax.random.split(key)
+      coordinates_along_dim = jax.random.permutation(
+          subkey, coordinates_along_dim
+      )
+      grid_point_coordinates.append(coordinates_along_dim)
+    else:
+      grid_point_coordinates.append(jnp.array((), dtype=jnp.int32))
+
+  return tuple(grid_point_coordinates)
+
+
+def _get_grid_point(
+    loop_indices: tuple[Array, ...],
+    grid_point_coordinates: _GridPointCoordinatesPerDim,
+) -> Array:
+  """Indexes each entry in `grid_point_coordinates` with the corresponding entry in `loop_indices`.
+
+  If an entry in `grid_point_coordinates` is an empty array, the
+  corresponding entry in the returned array is the corresponding entry in
+  `loop_indices`. Otherwise, the returned array contains the entry in
+  `grid_point_coordinates` indexed with the corresponding entry in
+  `loop_indices`.
+
+  Args:
+    loop_indices: A tuple of loop indices.
+    grid_point_coordinates: A tuple of coordinate arrays for each dimension
+      in the grid. Dimensions with 'arbitrary' semantics are represented by
+      empty arrays. Dimensions with 'parallel' semantics are represented by
+      arrays of randomized coordinates.
+
+  Returns:
+    A 1-dimensional array containing the coordinates for the grid point
+    corresponding to the specified `loop_indices`.
+  """
+  grid_point = []
+  for li, coords in zip(loop_indices, grid_point_coordinates):
+    grid_point.append(li if jnp.size(coords) == 0 else coords[li])
+  return jnp.array(grid_point, dtype=np.int32)
+
+
 def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
   start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx)
   output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
@@ -1411,7 +1513,7 @@ def interpret_pallas_call(
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
     mesh: pallas_core.Mesh | None,
-    compiler_params: mosaic_core.TPUCompilerParams,
+    compiler_params: dict[str, Any],
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
     interpret_params: TPUInterpretParams,
@@ -1568,6 +1670,10 @@ def interpret_pallas_call(
     # Base case is always one iteration when grid is ()
     num_iterations = 1

+  randomized_grid_coordinates = _get_randomized_grid_coordinates(
+      grid, compiler_params, interpret_params.random_seed  # type: ignore[arg-type]
+  )
+
   def _get_local_grid_env(loop_idx):
     if grid_mapping.local_grid_env is not None:
       return grid_mapping.local_grid_env(loop_idx, grid)
@@ -1607,13 +1713,19 @@ def body(
       The carry for the next iteration.
""" iteration_idx, loop_idx, prev_start_indices, cur_start_indices = carry + if interpret_params.grid_point_recorder is not None: + grid_point = _get_grid_point(loop_idx, randomized_grid_coordinates) + callback.io_callback(interpret_params.grid_point_recorder, (), grid_point) with pallas_core.grid_env(_get_local_grid_env(loop_idx)): next_loop_idx = _get_next_indices(grid, loop_idx) + next_grid_point = _get_grid_point( + next_loop_idx, randomized_grid_coordinates + ) next_start_indices = [ _compute_start_indices( bm, - next_loop_idx, + next_grid_point, *scalar_buffer_ids, compiler_params=compiler_params, interpret_params=interpret_params, @@ -1739,11 +1851,14 @@ def _store_to_output_buffer(index, output_var): return iteration_idx + 1, next_loop_idx, cur_start_indices, next_start_indices initial_loop_idx = (jnp.int32(0),) * len(grid) + initial_grid_point = _get_grid_point( + initial_loop_idx, randomized_grid_coordinates + ) with pallas_core.grid_env(_get_local_grid_env(initial_loop_idx)): initial_start_indices = [ _compute_start_indices( bm, - initial_loop_idx, + initial_grid_point, *scalar_buffer_ids, compiler_params=compiler_params, interpret_params=interpret_params, diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index afb573f8cf44..c4bf07f39cef 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -32,6 +32,7 @@ jax.config.parse_flags_with_absl() +jax.config.update('jax_threefry_partitionable', True) class CountStoreCallbacksContext(object): @@ -58,6 +59,29 @@ def num_stores(self): return self._num_stores +class GridPointRecorderContext(object): + """Records grid points in the order in which they are traversed.""" + + def __init__(self): + self._grid_points = [] + + def __enter__(self): + return self + + def __exit__(self, ty, value, traceback): + ... 
+ + def get_recorder(self): + def _recorder(grid_point): + self._grid_points.append(grid_point) + + return _recorder + + @property + def grid_points(self): + return self._grid_points + + class InterpretTest(jtu.JaxTestCase): def setUp(self): @@ -326,6 +350,142 @@ def kernel_call(x, s): np.testing.assert_allclose(result[::8, ::256], [[1.0], [5.0]]) self.assertEqual(store_callbacks_counter.num_stores, 5) + def test_randomization_of_parallel_dimensions(self): + def kernel(s_ref, o_ref): + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = jax.lax.full_like(o_ref, s) + + def kernel_call_dimensions_arbitrary_parallel(s, grid_point_recorder): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=mosaic_interpret.TPUInterpretParams( + random_seed=12345, grid_point_recorder=grid_point_recorder + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('arbitrary', 'parallel') + ), + )(s) + + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit( + kernel_call_dimensions_arbitrary_parallel, static_argnums=1 + )( + jnp.zeros((1,), jnp.int32), + grid_point_recorder.get_recorder(), + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [ 2.0, 3.0, 0.0, 1.0], + [ 6.0, 7.0, 4.0, 5.0], + [10.0, 11.0, 8.0, 9.0], + [14.0, 15.0, 12.0, 13.0], + ], + ) + np.testing.assert_array_equal( + grid_point_recorder.grid_points, + [ + [0, 2], + [0, 3], + [0, 0], + [0, 1], + [1, 2], + [1, 3], + [1, 0], + [1, 1], + [2, 2], + [2, 3], + [2, 0], + [2, 1], + [3, 2], + [3, 3], + [3, 0], + [3, 1], + ], + ) + + def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=mosaic_interpret.TPUInterpretParams( + random_seed=12345, grid_point_recorder=grid_point_recorder + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel', 'arbitrary') + ), + )(s) + + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit( + kernel_call_dimensions_parallel_arbitrary, static_argnums=1 + )( + jnp.zeros((1,), jnp.int32), + grid_point_recorder.get_recorder(), + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [ 8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [ 0.0, 1.0, 2.0, 3.0], + [ 4.0, 5.0, 6.0, 7.0], + ], + ) + np.testing.assert_array_equal( + grid_point_recorder.grid_points, + [ + [2, 0], + [2, 1], + [2, 2], + [2, 3], + [3, 0], + [3, 1], + [3, 2], + [3, 3], + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [1, 0], + [1, 1], + [1, 2], + [1, 3], + ], + ) + + def test_dynamic_parallel_dimension_raises(self): + def kernel(o_ref): + o_ref[0] = 42.0 + + @jax.jit + def kernel_call_dynamic_parallel_dimension(): + dim_size = jax.random.randint( + jax.random.key(0), (), 10, 20, dtype=jnp.int32 + ) + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1,), jnp.float32), + grid=(dim_size,), + in_specs=[], + out_specs=pl.BlockSpec((1,), lambda _: (0,)), + interpret=mosaic_interpret.TPUInterpretParams(), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel',) + ), + )() + + with self.assertRaises(jax.errors.ConcretizationTypeError): + kernel_call_dynamic_parallel_dimension() + if __name__ == '__main__': 
absltest.main(testLoader=jtu.JaxTestLoader()) From 96d38a6b66afb63484c97cacad0b8afa483fb6a7 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 10 Apr 2025 12:43:19 +0200 Subject: [PATCH 0557/1769] [cache_misses] Skip tracing-cache-miss explanations for JAX internal functions About half of the tracing-cache-miss explanations in a large benchmark end up being from JAX-internal functions, such as `jax.numpy` functions. These cache misses are not what the JAX user wants to see, so we filter them out, using the same mechanism used for filtering tracebacks. --- jax/_src/linear_util.py | 13 +++++++++++++ jax/_src/pjit.py | 37 ++++++++++++++++++++----------------- jax/_src/traceback_util.py | 6 ++++-- tests/api_test.py | 13 ++++++++++--- tests/debug_info_test.py | 5 +++++ 5 files changed, 52 insertions(+), 22 deletions(-) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 1497597ebd62..1231d3066062 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -326,6 +326,18 @@ def replace_func_name(self, name: str) -> DebugInfo: func_src_comps[0] = name return self._replace(func_src_info=" ".join(func_src_comps)) + @property + def func_filename(self) -> str | None: + m = _re_func_src_info.match(self.func_src_info) + if not m: return None + return m.group(3) + + @property + def func_lineno(self) -> int | None: + m = _re_func_src_info.match(self.func_src_info) + if not m or m.group(4) is None: return None + return int(m.group(4)) + def safe_arg_names(self, expected: int) -> tuple[str, ...]: """Get the arg_names with a safety check.""" if len(self.arg_names) == expected: @@ -352,6 +364,7 @@ def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]: assert self.result_paths is not None and not callable(self.result_paths), self return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b) +_re_func_src_info = re.compile(r"([^ ]+)( at (.+):(\d+))?$") def _missing_debug_info(for_what: str) -> DebugInfo: warnings.warn( diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 99b0a6403937..83ddb8709d4c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1142,7 +1142,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] return in_shardings_flat, in_layouts_flat -callsites: set[str] = set() +callsites_with_tracing_cache_miss: set[str] = set() def explain_tracing_cache_miss( fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): @@ -1157,6 +1157,10 @@ def unpack(key): if inline: return debug_info = fun.debug_info + func_filename = debug_info.func_filename + if func_filename and not traceback_util.include_filename(func_filename): + return + msg: list[str] = [] p = msg.append done = lambda: logger.log(logging.WARNING, '\n'.join(msg)) @@ -1165,19 +1169,18 @@ def unpack(key): p(f"TRACING CACHE MISS at {callsite} because:") # have we seen this function before at all? 
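The `_re_func_src_info` pattern introduced in linear_util.py above backs the new `func_filename` and `func_lineno` properties. A quick illustrative check of how it splits a `func_src_info` string (the function name and path here are hypothetical):

```python
import re

_re_func_src_info = re.compile(r"([^ ]+)( at (.+):(\d+))?$")

m = _re_func_src_info.match("my_fn at /home/user/prog.py:42")
assert m.group(1) == "my_fn"
assert m.group(3) == "/home/user/prog.py" and m.group(4) == "42"

# Builtins such as `max` carry no source location, so groups 3/4 are None.
m = _re_func_src_info.match("max")
assert m.group(1) == "max" and m.group(3) is None
```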
- fun_name = getattr(fun.f, '__qualname__', fun.f) - if debug_info.func_src_info: - # TODO(necula): clean up the extraction of the source info - _, *rest = debug_info.func_src_info.split(' at ') - src_info = " defined at " + ' '.join(rest) - else: - src_info = '' + src_info = debug_info.func_name + if func_filename: + src_info += f" defined at {func_filename}" + if func_lineno := debug_info.func_lineno: + src_info += f":{func_lineno}" if unseen_f: p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}") - if callsite in callsites: + if callsite in callsites_with_tracing_cache_miss: p(" but seen another function defined on the same line; maybe the function is\n" " being re-defined repeatedly, preventing caching?") - callsites.add(callsite) + else: + callsites_with_tracing_cache_miss.add(callsite) return done() else: p(f" for {fun_name}{src_info}") @@ -1239,8 +1242,8 @@ def unpack(key): types_match = [k for k in trees_match if k[1] == in_type] if not types_match: if len(in_type) < 5: - in_type_str = ':\n {}'.format(', '.join( - f'{n}: {ty.str_short(short_dtypes=True)}' + in_type_str = ":\n {}".format(", ".join( + f"{n}: {ty.str_short(short_dtypes=True)}" for n, ty in zip(debug_info.arg_names, in_type))) else: in_type_str = '' @@ -1257,8 +1260,8 @@ def unpack(key): if type(ty1) == type(ty2) == core.ShapedArray: s1, s2 = ty1.str_short(True), ty2.str_short(True) if ty1.weak_type != ty2.weak_type: - s1 += f'{{weak_type={ty1.weak_type}}}' - s2 += f'{{weak_type={ty2.weak_type}}}' + s1 += f"{{weak_type={ty1.weak_type}}}" + s2 += f"{{weak_type={ty2.weak_type}}}" add_weak_type_hint = True elif ty1.sharding != ty2.sharding: s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True) @@ -1267,9 +1270,9 @@ def unpack(key): s1, s2 = str(ty1), str(ty2) p(f" * at {name}, seen {s1}, but now given {s2}") if add_weak_type_hint: - p('where weak_type=True often means a Python builtin numeric value, and ') - p('weak_type=False means a jax.Array.') - p('See https://docs.jax.dev/en/latest/type_promotion.html#weak-types') + p("where weak_type=True often means a Python builtin numeric value, and ") + p("weak_type=False means a jax.Array.") + p("See https://docs.jax.dev/en/latest/type_promotion.html#weak-types") return done() # we think this is unreachable... diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index d66cbb912a99..cde9e4a30f99 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -56,8 +56,10 @@ def _path_starts_with(path: str, path_prefix: str) -> bool: return False def include_frame(f: types.FrameType) -> bool: - return not any(_path_starts_with(f.f_code.co_filename, path) - for path in _exclude_paths) + return include_filename(f.f_code.co_filename) + +def include_filename(filename: str) -> bool: + return not any(_path_starts_with(filename, path) for path in _exclude_paths) # When scanning stack traces, we might encounter frames from cpython that are # removed from printed stack traces, such as frames from parts of importlib. 
We diff --git a/tests/api_test.py b/tests/api_test.py index ec946003d570..ac1623f3beee 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4525,6 +4525,13 @@ def f(x, y): msg = cm.output[0] self.assertIn("tracing context doesn't match", msg) + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_skip_internals(self): + with config.explain_cache_misses(True): + with self.assertNoLogs(level='WARNING'): + for i in range(2): + jnp.sin(jnp.arange(i + 1, dtype=np.float32)) + @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_new_function_in_loop(self): @jax.jit @@ -4560,12 +4567,12 @@ def f(key): f(jax.random.key(seed=123)) if is_persistent_cache_enabled(): - # 5 warnings from tracing cache, 5-10 from persistent cache depending on + # 4 warnings from tracing cache, 5-10 from persistent cache depending on # the backend - self.assertTrue(10 <= len(cm.output) <= 15) + self.assertTrue(9 <= len(cm.output) <= 15) self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output)) else: - self.assertLen(cm.output, 5) + self.assertLen(cm.output, 4) for msg in cm.output: self.assertIn("TRACING CACHE MISS", msg) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index dad974c2cbaf..1f5ddba89e27 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -297,12 +297,17 @@ def test_debug_info_no_source_info_built_in(self): # built-in function "max" does not have an inspect.Signature dbg = api_util.debug_info("jit", max, (1,), {}) self.assertEqual(dbg.func_src_info, "max") + self.assertEqual(dbg.func_name, "max") + self.assertEqual(dbg.func_filename, None) + self.assertEqual(dbg.func_lineno, None) self.assertEqual(dbg.arg_names, ("args[0]",)) def test_debug_info_lambda(self): # built-in function "int" does not have an inspect.Signature dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {}) self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") + self.assertEndsWith(dbg.func_filename, "debug_info_test.py") + self.assertIsNotNone(dbg.func_lineno) self.assertEqual(dbg.arg_names, ("my_arg",)) def test_debug_info_save_wrapped_fun_source_info(self): From 81722201fd75e35d33bc64a765916fa651027902 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 11 Apr 2025 03:16:30 -0700 Subject: [PATCH 0558/1769] Remove legacy CPU custom call kernels that have been unused since v0.4.34. As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering: * getrf * geqrf / orgqr * potrf * gesdd * syevd * geev * gehrd Following our compatibility policy, these are now safe to remove. 
PiperOrigin-RevId: 746388529 --- jax/_src/export/_export.py | 8 - .../cpu_cholesky_lapack_potrf.py | 337 ------- .../cpu_eig_lapack_geev.py | 271 +----- .../cpu_eigh_lapack_syev.py | 368 -------- .../cpu_hessenberg_lapack_gehrd.py | 267 ------ .../cpu_lu_lapack_getrf.py | 519 ----------- .../cpu_qr_lapack_geqrf.py | 854 +++--------------- .../cpu_svd_lapack_gesdd.py | 427 --------- jaxlib/cpu/_lapack/__init__.pyi | 36 - jaxlib/cpu/cpu_kernels.cc | 64 -- jaxlib/cpu/lapack.cc | 150 --- jaxlib/cpu/lapack_kernels.cc | 668 +------------- jaxlib/cpu/lapack_kernels.h | 188 +--- jaxlib/cpu/lapack_kernels_using_lapack.cc | 164 ---- tests/export_back_compat_test.py | 52 +- 15 files changed, 175 insertions(+), 4198 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index eacb04890f1f..31132dc77c82 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1086,14 +1086,8 @@ def _check_lowering(lowering) -> None: "cu_threefry2x32_ffi", # Triton IR does not guarantee stability. # "__gpu$xla.gpu.triton", - # cholesky on CPU - "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on TPU "Eigh", - # eig on CPU - "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", - # svd on CPU - "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", # triangular_solve on CPU @@ -1102,8 +1096,6 @@ def _check_lowering(lowering) -> None: "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", # tridiagonal on CPU "lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd", - # hessenberg on CPU - "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd", # lu on TPU "LuDecomposition", # ApproxTopK on TPU diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py index eb4143615da6..ee06d902d235 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py @@ -17,345 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_06_19 = {} - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_spotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 24.343887, 13.603932, 20.50489 , 12.063956], - [ 13.603932, 58.879757, -31.84056 , 16.328012], - [ 20.50489 , -31.84056 , 66.890755, -9.92216 ], - [ 12.063956, 16.328012, -9.92216 , 23.640734]], dtype=float32),), - expected_outputs=(array([[ 4.9339523, 0. , 0. , 0. ], - [ 2.7572079, 7.1608353, 0. , 0. ], - [ 4.155875 , -6.0466647, 3.6134892, 0. 
], - [ 2.4450896, 1.3387254, -3.3177967, 2.2050648]], dtype=float32),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf32> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf32> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_spotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc6) - return %17 : tensor<4x4xf32> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf32> loc(unknown)) -> tensor<4x4xf32> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc12) - return %8 : tensor<4x4xf32> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc7 = 
loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b\x1fO\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02J\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07
\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_spotrf\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 23.022171138130666 , -16.79765603341739 , 0.9133449305189146, - -25.36636199966769 ], - [-16.79765603341739 , 31.655770252600092 , -1.5189878284433445, - 20.0344758332268 ], - [ 0.9133449305189146, -1.5189878284433445, 10.940134497877208 , - 8.169020034607513 ], - [-25.36636199966769 , 20.0344758332268 , 8.169020034607513 , - 37.054603917509596 ]]),), - expected_outputs=(array([[ 4.7981424674691215 , 0. , 0. , - 0. ], - [-3.500866459740129 , 4.404509539513645 , 0. , - 0. ], - [ 0.19035385812557523, -0.1935707899825621 , 3.2964268922333835 , - 0. 
], - [-5.286704630312426 , 0.3465604732420997 , 2.8037778311164425 , - 1.060228174247855 ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf64> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor<f64> loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf64> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor<i32> loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_dpotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<i32>) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc6) - return %17 : tensor<4x4xf64> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf64> loc(unknown)) -> tensor<4x4xf64> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor<i32> loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = stablehlo.constant dense<0.000000e+00> : tensor<f64> loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc12) - return %8 : tensor<4x4xf64> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc7 = 
loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02z\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x07\
x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_dpotrf\x00', - xla_call_module_version=6, -) # End paste - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 38.089394 +6.36582342e-09j, 3.3509154+3.13455486e+01j, - -0.5972489-3.80308151e+01j, -19.04205 +1.22770605e+01j], - [ 3.3509154-3.13455486e+01j, 73.875755 +4.06565448e-09j, - -12.427276 -1.23379612e+01j, 41.542507 -9.63993359e+00j], - [ -0.5972489+3.80308151e+01j, -12.427276 +1.23379612e+01j, - 73.04141 -4.18667753e-07j, 8.193126 -2.60565052e+01j], - [-19.04205 -1.22770605e+01j, 41.542507 +9.63993359e+00j, - 8.193126 +2.60565052e+01j, 52.977036 -1.09952367e-07j]], - dtype=complex64),), - expected_outputs=(array([[ 6.1716604 +0.j , 0. +0.j , - 0. +0.j , 0. +0.j ], - [ 0.542952 -5.078949j , 6.912687 +0.j , - 0. +0.j , 0. +0.j ], - [-0.09677281+6.162169j , 2.7373738 +1.3719271j, - 5.0679703 +0.j , 0. 
+0.j ], - [-3.0854013 -1.9892638j, 4.7903748 +3.8177056j, - 0.3555784 +0.5865844j, 1.2276335 +0.j ]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex<f32>> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex<f32>> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xf32> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf32> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex<f32>> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex<f32>> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex<f32>> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor<i32> loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor<i32> loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor<i32> loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_cpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<4x4xcomplex<f32>>) -> (tensor<4x4xcomplex<f32>>, tensor<i32>) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc10) - return %21 : tensor<4x4xcomplex<f32>> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex<f32>> loc(unknown)) -> tensor<4x4xcomplex<f32>> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : tensor<i32> loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc16) - return %8 : tensor<4x4xcomplex<f32>> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = 
loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02\xe6\x07\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\
x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_cpotrf\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 77.35445791180521 -6.4555004827448569e-16j, - 16.89356598261691 -5.4959586590823566e+00j, - -21.124380423202325+6.4431220601700787e+01j, - 55.385054340628855+2.5198457006849742e+00j], - [ 16.89356598261691 +5.4959586590823566e+00j, - 67.125263428637 -3.2921739472953976e-16j, - 25.14078382035968 +1.2783276691803774e+01j, - 
 51.116221409460884-2.2635508887939348e+00j], - [-21.124380423202325-6.4431220601700787e+01j, - 25.14078382035968 -1.2783276691803774e+01j, - 107.43449297637208 -2.8959717546347756e-15j, - 12.493792156221616-5.7556567757218694e+01j], - [ 55.385054340628855-2.5198457006849715e+00j, - 51.116221409460884+2.2635508887939326e+00j, - 12.493792156221616+5.7556567757218708e+01j, - 78.9856503203742 +2.0971925518284437e-16j]]),), - expected_outputs=(array([[ 8.795138311124232 +0.j , - 0. +0.j , - 0. +0.j , - 0. +0.j ], - [ 1.9207845726825759+0.624885984127274j , - 7.940111306576433 +0.j , - 0. +0.j , - 0. +0.j ], - [-2.401824698593298 -7.325776846534311j , - 4.3238621722485755-0.026813746599595675j, - 5.413152651345813 +0.j , - 0. +0.j ], - [ 6.297235174866659 -0.28650438589440164j , - 4.936910868956218 +0.849977768846063j , - 0.7751580530200595+1.279980716041562j , - 3.451611642915363 +0.j ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex<f64>> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex<f64>> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xf64> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf64> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex<f64>> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex<f64>> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex<f64>> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor<i32> loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor<i32> loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor<i32> loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_zpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<4x4xcomplex<f64>>) -> (tensor<4x4xcomplex<f64>>, tensor<i32>) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc10) - return %21 : tensor<4x4xcomplex<f64>> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex<f64>> loc(unknown)) -> tensor<4x4xcomplex<f64>> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : 
tensor<i32> loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc16) - return %8 : tensor<4x4xcomplex<f64>> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0bOO\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02F\x08\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x1
9C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_zpotrf\x00", - xla_call_module_version=6, -) # End paste - data_2024_05_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py index bc28857fa325..e6792dc2d1b4 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py @@ -15,279 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 - -data_2023_06_19 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464241e+01+0.j, -2.4642489e+00+0.j, 1.4189274e-07+0.j, - -4.0686123e-07+0.j], dtype=complex64), array([[-0.40377745 +0.j, -0.82883257 +0.j, -0.06733338 +0.j, - -0.5208027 +0.j], - [-0.46480742 +0.j, -0.4371466 +0.j, 0.49492982 +0.j, - 0.82081676 +0.j], - [-0.52583724 +0.j, -0.045459956+0.j, -0.78785884 +0.j, - -0.07922471 +0.j], - [-0.5868671 +0.j, 0.3462263 +0.j, 0.36026272 +0.j, - -0.2207891 +0.j]], dtype=complex64), array([[-0.11417642+0.j, -0.73277813+0.j, 0.16960056+0.j, - -0.5435681 +0.j], - [-0.33000448+0.j, -0.28974825+0.j, 0.16204938+0.j, - 0.67456985+0.j], - [-0.54583275+0.j, 0.15328142+0.j, -0.8329006 +0.j, - 0.28156415+0.j], - [-0.761661 +0.j, 0.5963111 +0.j, 0.5012507 +0.j, - 
-0.41256607+0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex<f32>> {jax.result_info = "[0]"}, tensor<4x4xcomplex<f32>> {jax.result_info = "[1]"}, tensor<4x4xcomplex<f32>> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor<i32> loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor<ui8> loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor<ui8> loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_sgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<ui8>, tensor<ui8>, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xcomplex<f32>>, tensor<4x4xcomplex<f32>>, tensor<i32>) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex<f32>> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor<complex<f32>>) -> tensor<4xcomplex<f32>> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex<f32>> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex<f32>>, tensor<4x4xcomplex<f32>>, tensor<4x4xcomplex<f32>> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 
= loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02v\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\t\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729802e+00+0.j, - -1.5210037805054253e-15+0.j, 1.2568096307462507e-16+0.j]), array([[-0.4037774907686232 +0.j, 0.8288327563197505 +0.j, - 0.5454962288885842 +0.j, -0.2420483778598153 +0.j], - [-0.46480737115848986 +0.j, 0.43714638836388725 +0.j, - -0.7640998541831632 +0.j, -0.04349021275982002 +0.j], - [-0.5258372515483576 +0.j, 0.045460020408024715+0.j, - -0.10828897829942748 +0.j, 0.8131255590990858 +0.j], - [-0.5868671319382249 +0.j, -0.3462263475478384 +0.j, - 0.32689260359400607 +0.j, -0.5275869684794504 +0.j]]), array([[-0.11417645138733863+0.j, 0.7327780959803554 +0.j, - 0.49133754464261303+0.j, -0.04933420991901029+0.j], - [-0.33000459866554765+0.j, 0.28974835239692637+0.j, - -0.8355289351028521 +0.j, -0.3408099365295394 +0.j], - [-0.545832745943757 +0.j, -0.1532813911865017 +0.j, - 0.1970452362778633 +0.j, 0.8296225028161098 +0.j], - [-0.7616608932219663 +0.j, -0.5963111347699308 +0.j, - 0.14714615418237506+0.j, -0.43947835636755994+0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex<f64>> {jax.result_info = "[0]"}, tensor<4x4xcomplex<f64>> {jax.result_info = "[1]"}, tensor<4x4xcomplex<f64>> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor<i32> loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor<ui8> loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor<ui8> loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_dgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<ui8>, tensor<ui8>, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<i32>) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex<f64>> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor<complex<f64>>) -> 
tensor<4xcomplex<f64>> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex<f64>> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\x96\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\x0b\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464237e+01+0.j, -2.4642489e+00+0.j, -5.7737714e-07+0.j, - 1.4719126e-07+0.j], dtype=complex64), array([[ 0.4037776 +0.j, 0.8288327 +0.j, -0.53126234 -0.j, - 0.052026853-0.j], - [ 0.46480742 +0.j, 0.43714646 -0.j, 0.80768156 +0.j, - -0.47577178 -0.j], - [ 0.52583724 +0.j, 0.045459922-0.j, -0.021575088-0.j, - 0.79546237 +0.j], - [ 0.5868671 +0.j, -0.3462263 -0.j, -0.25484383 -0.j, - -0.3717177 -0.j]], dtype=complex64), array([[ 0.114176475+0.j, 0.7327782 +0.j, -0.5452461 -0.j, - -0.13326685 -0.j], - [ 0.3300045 +0.j, 0.28974816 -0.j, 0.68821603 +0.j, - -0.2182906 -0.j], - [ 0.5458328 +0.j, -0.1532814 -0.j, 0.25930583 -0.j, - 0.8363818 +0.j], - [ 0.76166093 +0.j, -0.5963111 -0.j, -0.40227592 -0.j, - -0.4848244 -0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex<f32>> {jax.result_info = "[0]"}, tensor<4x4xcomplex<f32>> {jax.result_info = "[1]"}, tensor<4x4xcomplex<f32>> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex<f32>> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor<i32> loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor<ui8> loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor<ui8> loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_cgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<ui8>, tensor<ui8>, tensor<4x4xcomplex<f32>>) -> (tensor<4x4xcomplex<f32>>, tensor<8xf32>, tensor<4xcomplex<f32>>, tensor<4x4xcomplex<f32>>, tensor<4x4xcomplex<f32>>, tensor<i32>) loc(#loc5) - %7 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5) - %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc5) - %11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<complex<f32>>) -> tensor<4xcomplex<f32>> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex<f32>> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc5) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : 
(tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %14, %19, %24 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02Z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\t)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07
\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\xfe\x0c[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x85\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572965e+01+0.j, -2.4642491965729807e+00+0.j, - -1.6035677295293283e-15+0.j, 1.2218554396786611e-16+0.j]), array([[ 0.40377749076862335 +0.j, 0.8288327563197505 +0.j, - -0.5457111210844892 +0.j, -0.2322136424094458 -0.j], - [ 0.46480737115848997 +0.j, 0.4371463883638875 -0.j, - 0.7625701354883243 +0.j, -0.06012408092789514 -0.j], - [ 0.5258372515483578 +0.j, 0.045460020408024694-0.j, - 0.1119930922768192 +0.j, 0.8168890890841272 +0.j], - [ 0.5868671319382247 +0.j, -0.34622634754783854 -0.j, - -0.32885210668065423 +0.j, -0.5245513657467864 -0.j]]), array([[ 0.11417645138733871+0.j, 0.7327780959803554 +0.j, - -0.49606131100796214+0.j, -0.04689746607984153-0.j], - [ 0.3300045986655476 +0.j, 0.2897483523969264 -0.j, - 0.8344969112540657 +0.j, -0.34421909950105706-0.j], - [ 0.5458327459437571 +0.j, -0.15328139118650172-0.j, - -0.18080988948424467+0.j, 0.8291305972416383 +0.j], - [ 0.7616608932219663 +0.j, -0.5963111347699308 -0.j, - -0.1576257107618584 +0.j, -0.4380140316607401 -0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_zgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : 
tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<ui8>, tensor<ui8>, tensor<4x4xcomplex<f64>>) -> (tensor<4x4xcomplex<f64>>, tensor<8xf64>, tensor<4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<i32>) loc(#loc5)
-    %7 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5)
-    %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5)
-    %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5)
-    %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc5)
-    %11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5)
-    %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<complex<f64>>) -> tensor<4xcomplex<f64>> loc(#loc5)
-    %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5)
-    %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex<f64>> loc(#loc5)
-    %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5)
-    %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5)
-    %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc5)
-    %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5)
-    %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc5)
-    %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5)
-    %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5)
-    %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc5)
-    %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5)
-    %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc5)
-    return %14, %19, %24 : tensor<4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>> loc(#loc)
-  } loc(#loc)
-} loc(#loc)
-#loc = loc(unknown)
-#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0)
-#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0)
-#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]"(#loc1))
-#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1))
-#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2))
-""",
-
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\x02\r[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev\x00",
-    xla_call_module_version=6,
-)  # End paste
-
+from numpy import array, complex64
 data_2024_08_19 = {}
-
 # Pasted from the test output (see export_back_compat_test_util.py module docstring)
 data_2024_08_19["c128"] = dict(
     testdata_version=1,
diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py
index f0696db1aeda..cd5f5c55caf9 100644
--- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py
+++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py
@@ -17,376 +17,8 @@ import datetime
 from numpy import array, float32, complex64
-data_2023_03_17 = dict(
-    # Pasted from the test output (see back_compat_test.py module docstring)
-    f32=dict(
-        testdata_version=1,
-        platform='cpu',
-        custom_call_targets=['lapack_ssyevd'],
-        serialized_date=datetime.date(2023, 3, 17),
-        inputs=(),
-        expected_outputs=(array([[-0.6185769 , -0.20142993 , -0.09725195 , 0.62983674 ,
-        -0.07926044 , 0.3605001 , -0.019093221 , -0.18446997 ],
-       [-0.47070873 , 0.29325768 , -0.19454119 , -0.6394365 ,
-        0.0622955 , 0.33249345 , 0.28112718 , -0.22856665 ],
-       [-0.32284075 , -0.12361939 , 0.20547704 , -0.18307868 ,
-        0.47294614 , -0.3170349 , -0.6373532 , -0.27266347 ],
-       [-0.17497246 , -0.079641335 , 0.15042791 , -0.15416273 ,
-        -0.815209 , -0.38054234 , -0.083263926 , -0.31676024 ],
-       [-0.027104253 , -0.26490977 , 0.32271704 , 0.08653544 ,
-        0.30305928 , -0.33998996 , 0.6926741 , -0.360857 ],
-       [ 0.12076397 , 0.43288827 , -0.64385164 , 0.2652551 ,
-        0.09482376 , -0.37435007 , 0.00091664493, -0.40495378 ],
-       [ 0.26863196 , 0.51607686 , 0.53846526 , 0.16969058 ,
-        -0.021670295 , 0.35755336 , -0.113144726 , -0.4490505 ],
-       [ 0.4165004 , -0.57262254 , -0.2814425 , -0.17463988 ,
-        -0.01698498 , 0.3613705 , -0.12186296 , -0.49314725 ]],
-      dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05,
-       -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02],
-      dtype=float32)),
-        mlir_module_text=r"""
-module @jit__lambda_ {
-  func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<64xf32>
-    %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32>
-    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32>
-    %3 = stablehlo.add %1, %2 : tensor<8x8xf32>
-    %4 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
-    %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<8x8xf32>
-    %6 = stablehlo.divide %3, %5 : tensor<8x8xf32>
-    %7 = call @tril(%6) : (tensor<8x8xf32>) -> tensor<8x8xf32>
-    %8 = stablehlo.constant dense<1> : tensor<i32>
-    %9 = stablehlo.constant dense<1> : tensor<i32>
-    %10 = stablehlo.constant dense<8> : tensor<i32>
-    %11 = stablehlo.custom_call @lapack_ssyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<8x8xf32>) -> tuple<tensor<8x8xf32>, tensor<8xf32>, tensor<i32>, tensor<177xf32>, tensor<43xi32>>
-    %12 = stablehlo.get_tuple_element %11[0] : (tuple<tensor<8x8xf32>, tensor<8xf32>, tensor<i32>, tensor<177xf32>, tensor<43xi32>>) -> tensor<8x8xf32>
-    %13 = stablehlo.get_tuple_element %11[1] : (tuple<tensor<8x8xf32>, tensor<8xf32>, tensor<i32>, tensor<177xf32>, tensor<43xi32>>) -> tensor<8xf32>
-    %14 = stablehlo.get_tuple_element %11[2] : (tuple<tensor<8x8xf32>, tensor<8xf32>, tensor<i32>, tensor<177xf32>, tensor<43xi32>>) -> tensor<i32>
-    %15 = stablehlo.get_tuple_element %11[3] : (tuple<tensor<8x8xf32>, tensor<8xf32>, tensor<i32>, tensor<177xf32>, tensor<43xi32>>) -> tensor<177xf32>
-    %16 = stablehlo.get_tuple_element %11[4] : (tuple<tensor<8x8xf32>, tensor<8xf32>, tensor<i32>, tensor<177xf32>, tensor<43xi32>>) -> tensor<43xi32>
-    %17 = stablehlo.constant dense<0> : tensor<i32>
-    %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<i32>) -> tensor<i32>
-    %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %21 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
-    %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor<f32>) -> tensor<8x8xf32>
-    %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1>
-    %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf32>
-    %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor<i1>) -> tensor<1xi1>
-    %26 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
-    %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor<f32>) -> tensor<8xf32>
-    %28 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1>
-    %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf32>
-    return %24, %29 : tensor<8x8xf32>, tensor<8xf32>
-  }
-  func.func private @tril(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
-    %0 = stablehlo.iota dim = 0 : tensor<8x8xi32>
-    %1 = stablehlo.constant dense<0> : tensor<i32>
-    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<8x8xi32>
-    %3 = stablehlo.add %0, %2 : tensor<8x8xi32>
-    %4 = stablehlo.iota dim = 1 : tensor<8x8xi32>
-    %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1>
-    %6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-    %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<f32>) -> tensor<8x8xf32>
-    %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf32>
-    return %8 : tensor<8x8xf32>
-  }
-}
-""",
-
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02\n\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\t)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x1
1\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_ssyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dsyevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01, 2.4081403770912022e-01, - 3.5662489253627483e-01, -6.3034019033669797e-01, - 1.0043483479985752e-16, -2.8842036081919542e-02, - 7.7164692943283169e-25, -1.8446994643771725e-01], - [-4.7070881487314614e-01, 4.7473787464450845e-01, - -4.8036836210243367e-01, 4.3802686872516400e-01, - 1.7961797619639258e-01, 8.3080980076741355e-03, - 2.1415294457221756e-01, -2.2856669794666584e-01], - [-3.2284062926217072e-01, -5.4336490915553370e-01, - 2.2181041859724990e-01, 2.9947877954402297e-01, - -3.6491813600134632e-01, 3.2867679819727436e-01, - 3.8223299448843473e-01, -2.7266344945561438e-01], - [-1.7497244365119530e-01, -8.9251550609769414e-02, - -6.3518515114898394e-02, 1.9162997359209971e-01, - -2.2087281326110139e-01, 5.9957027043505064e-02, - -8.7632498908241274e-01, -3.1676020096456303e-01], - [-2.7104258040220038e-02, -3.3772873786627672e-01, - 2.5901386593721748e-01, 1.7032650752287815e-01, - 6.7521217612940332e-01, -4.5036136532965476e-01, - -1.2279030059078447e-02, -3.6085695247351163e-01], - [ 1.2076392757075530e-01, -3.3834734096469254e-01, - 
-6.5506827461665540e-01, -5.0472498521116749e-01,
-        6.9987430903492118e-02, 1.0595648906599275e-01,
-        8.3443844143082022e-02, -4.0495370398246017e-01],
-       [ 2.6863211318173097e-01, 2.2958613191407318e-01,
-        6.3952843755683941e-02, 1.8776775771084137e-02,
-        -5.3523731432241317e-01, -5.9199531677602002e-01,
-        1.7916671834524248e-01, -4.4905045549140887e-01],
-       [ 4.1650029879270661e-01, 3.6355449432857079e-01,
-        2.9755313100756142e-01, 1.6826270392615944e-02,
-        1.9621068035557282e-01, 5.6830030587314817e-01,
-        2.9607517592514246e-02, -4.9314720700035747e-01]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14,
-       -1.9932120610662194e-14, -5.7323356091157378e-15,
-       -4.5459724251334835e-16, 4.0479851042511616e-14,
-        9.2325194924982089e-14, 2.7659880477613365e+02])),
-        mlir_module_text=r"""
-module @jit__lambda_ {
-  func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<64xf64>
-    %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64>
-    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64>
-    %3 = stablehlo.add %1, %2 : tensor<8x8xf64>
-    %4 = stablehlo.constant dense<2.000000e+00> : tensor<f64>
-    %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f64>) -> tensor<8x8xf64>
-    %6 = stablehlo.divide %3, %5 : tensor<8x8xf64>
-    %7 = call @tril(%6) : (tensor<8x8xf64>) -> tensor<8x8xf64>
-    %8 = stablehlo.constant dense<1> : tensor<i32>
-    %9 = stablehlo.constant dense<1> : tensor<i32>
-    %10 = stablehlo.constant dense<8> : tensor<i32>
-    %11 = stablehlo.custom_call @lapack_dsyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<8x8xf64>) -> tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<177xf64>, tensor<43xi32>>
-    %12 = stablehlo.get_tuple_element %11[0] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<177xf64>, tensor<43xi32>>) -> tensor<8x8xf64>
-    %13 = stablehlo.get_tuple_element %11[1] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<177xf64>, tensor<43xi32>>) -> tensor<8xf64>
-    %14 = stablehlo.get_tuple_element %11[2] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<177xf64>, tensor<43xi32>>) -> tensor<i32>
-    %15 = stablehlo.get_tuple_element %11[3] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<177xf64>, tensor<43xi32>>) -> tensor<177xf64>
-    %16 = stablehlo.get_tuple_element %11[4] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<177xf64>, tensor<43xi32>>) -> tensor<43xi32>
-    %17 = stablehlo.constant dense<0> : tensor<i32>
-    %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<i32>) -> tensor<i32>
-    %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
-    %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor<f64>) -> tensor<8x8xf64>
-    %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1>
-    %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf64>
-    %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor<i1>) -> tensor<1xi1>
-    %26 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
-    %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor<f64>) -> tensor<8xf64>
-    %28 =
stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf64> - return %24, %29 : tensor<8x8xf64>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xf64>) -> tensor<8x8xf64> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf64> - return %8 : tensor<8x8xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b/O/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02:\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\x0b)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x
05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x11\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_dsyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-0.6185769 +0.j, -0.20142993 +0.j, -0.09725195 +0.j, - 0.62983674 +0.j, -0.07926044 +0.j, 0.3605001 -0.j, - -0.019093221 +0.j, -0.18446997 +0.j], - 
[-0.47070873 +0.j, 0.29325768 +0.j, -0.19454116 +0.j,
-        -0.6394365 +0.j, 0.06229549 +0.j, 0.33249345 +0.j,
-         0.28112718 +0.j, -0.22856665 +0.j],
-       [-0.32284075 +0.j, -0.12361939 +0.j, 0.20547704 +0.j,
-        -0.18307868 +0.j, 0.47294614 +0.j, -0.3170349 +0.j,
-        -0.6373532 +0.j, -0.27266347 +0.j],
-       [-0.17497246 +0.j, -0.079641335 +0.j, 0.15042792 +0.j,
-        -0.15416273 +0.j, -0.815209 +0.j, -0.38054234 +0.j,
-        -0.083263926 +0.j, -0.31676024 +0.j],
-       [-0.027104257 +0.j, -0.26490977 +0.j, 0.32271704 +0.j,
-         0.08653544 +0.j, 0.30305928 +0.j, -0.33998996 +0.j,
-         0.6926741 +0.j, -0.360857 +0.j],
-       [ 0.120763965 +0.j, 0.43288827 +0.j, -0.64385164 +0.j,
-         0.2652551 +0.j, 0.094823755 +0.j, -0.37435007 +0.j,
-         0.00091664493+0.j, -0.40495378 +0.j],
-       [ 0.26863196 +0.j, 0.51607686 +0.j, 0.53846526 +0.j,
-         0.16969058 +0.j, -0.0216703 +0.j, 0.35755336 +0.j,
-        -0.113144726 +0.j, -0.4490505 +0.j],
-       [ 0.4165004 +0.j, -0.57262254 +0.j, -0.28144246 +0.j,
-        -0.17463988 +0.j, -0.016984984 +0.j, 0.3613705 +0.j,
-        -0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05,
-       -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02],
-      dtype=float32)),
-        mlir_module_text=r"""
-module @jit__lambda_ {
-  func.func public @main() -> (tensor<8x8xcomplex<f32>> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<64xcomplex<f32>>
-    %1 = stablehlo.reshape %0 : (tensor<64xcomplex<f32>>) -> tensor<8x8xcomplex<f32>>
-    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex<f32>>) -> tensor<8x8xcomplex<f32>>
-    %3 = stablehlo.real %2 : (tensor<8x8xcomplex<f32>>) -> tensor<8x8xf32>
-    %4 = stablehlo.imag %2 : (tensor<8x8xcomplex<f32>>) -> tensor<8x8xf32>
-    %5 = stablehlo.negate %4 : tensor<8x8xf32>
-    %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex<f32>>
-    %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex<f32>>
-    %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f32>>
-    %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor<complex<f32>>) -> tensor<8x8xcomplex<f32>>
-    %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex<f32>>
-    %11 = call @tril(%10) : (tensor<8x8xcomplex<f32>>) -> tensor<8x8xcomplex<f32>>
-    %12 = stablehlo.constant dense<1> : tensor<i32>
-    %13 = stablehlo.constant dense<1> : tensor<i32>
-    %14 = stablehlo.constant dense<8> : tensor<i32>
-    %15 = stablehlo.custom_call @lapack_cheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<8x8xcomplex<f32>>) -> tuple<tensor<8x8xcomplex<f32>>, tensor<8xf32>, tensor<i32>, tensor<81xcomplex<f32>>, tensor<169xf32>, tensor<43xi32>>
-    %16 = stablehlo.get_tuple_element %15[0] : (tuple<tensor<8x8xcomplex<f32>>, tensor<8xf32>, tensor<i32>, tensor<81xcomplex<f32>>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8x8xcomplex<f32>>
-    %17 = stablehlo.get_tuple_element %15[1] : (tuple<tensor<8x8xcomplex<f32>>, tensor<8xf32>, tensor<i32>, tensor<81xcomplex<f32>>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8xf32>
-    %18 = stablehlo.get_tuple_element %15[2] : (tuple<tensor<8x8xcomplex<f32>>, tensor<8xf32>, tensor<i32>, tensor<81xcomplex<f32>>, tensor<169xf32>, tensor<43xi32>>) -> tensor<i32>
-    %19 = stablehlo.get_tuple_element %15[3] : (tuple<tensor<8x8xcomplex<f32>>, tensor<8xf32>, tensor<i32>, tensor<81xcomplex<f32>>, tensor<169xf32>, tensor<43xi32>>) -> tensor<81xcomplex<f32>>
-    %20 = stablehlo.get_tuple_element %15[4] : (tuple<tensor<8x8xcomplex<f32>>, tensor<8xf32>, tensor<i32>, tensor<81xcomplex<f32>>, tensor<169xf32>, tensor<43xi32>>) -> tensor<169xf32>
-    %21 = stablehlo.get_tuple_element %15[5] : (tuple<tensor<8x8xcomplex<f32>>, tensor<8xf32>, tensor<i32>, tensor<81xcomplex<f32>>, tensor<169xf32>, tensor<43xi32>>) -> tensor<43xi32>
-    %22 = stablehlo.constant dense<0> : tensor<i32>
-    %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<i32>) -> tensor<i32>
-    %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %26 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>>
-    %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor<complex<f32>>) -> tensor<8x8xcomplex<f32>>
-    %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1>
-    %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex<f32>>
-    %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor<i1>) -> tensor<1xi1>
-    %31 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
-    %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor<f32>) -> tensor<8xf32>
-    %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1>
-    %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf32>
-    return %29, %34 : tensor<8x8xcomplex<f32>>, tensor<8xf32>
-  }
-  func.func private @tril(%arg0: tensor<8x8xcomplex<f32>>) -> tensor<8x8xcomplex<f32>> {
-    %0 = stablehlo.iota dim = 0 : tensor<8x8xi32>
-    %1 = stablehlo.constant dense<0> : tensor<i32>
-    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<8x8xi32>
-    %3 = stablehlo.add %0, %2 : tensor<8x8xi32>
-    %4 = stablehlo.iota dim = 1 : tensor<8x8xi32>
-    %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1>
-    %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
-    %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<complex<f32>>) -> tensor<8x8xcomplex<f32>>
-    %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex<f32>>
-    return %8 : tensor<8x8xcomplex<f32>>
-  }
-}
-""",
-
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0b/O\x1f/\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02&\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\t\x00\x00\xc0\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\t)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\
x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00F\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8dW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_cheevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c128=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01+0.j, 2.4081403770912022e-01+0.j, - 3.5662489253627483e-01+0.j, -6.3034019033669797e-01+0.j, - 1.0043483479985752e-16+0.j, -2.8842036081919542e-02+0.j, - 7.7164692943283169e-25+0.j, -1.8446994643771725e-01+0.j], - [-4.7070881487314609e-01+0.j, 4.7473787464450828e-01+0.j, - -4.8036836210243361e-01+0.j, 4.3802686872516400e-01+0.j, - 1.7961797619639255e-01+0.j, 8.3080980076741355e-03+0.j, - 2.1415294457221759e-01+0.j, 
-2.2856669794666584e-01+0.j],
-       [-3.2284062926217072e-01+0.j, -5.4336490915553370e-01+0.j,
-        2.2181041859724987e-01+0.j, 2.9947877954402286e-01+0.j,
-        -3.6491813600134637e-01+0.j, 3.2867679819727436e-01+0.j,
-        3.8223299448843473e-01+0.j, -2.7266344945561438e-01+0.j],
-       [-1.7497244365119527e-01+0.j, -8.9251550609769331e-02+0.j,
-        -6.3518515114898352e-02+0.j, 1.9162997359209963e-01+0.j,
-        -2.2087281326110142e-01+0.j, 5.9957027043505008e-02+0.j,
-        -8.7632498908241274e-01+0.j, -3.1676020096456303e-01+0.j],
-       [-2.7104258040220017e-02+0.j, -3.3772873786627688e-01+0.j,
-        2.5901386593721754e-01+0.j, 1.7032650752287815e-01+0.j,
-        6.7521217612940321e-01+0.j, -4.5036136532965476e-01+0.j,
-        -1.2279030059078447e-02+0.j, -3.6085695247351163e-01+0.j],
-       [ 1.2076392757075533e-01+0.j, -3.3834734096469249e-01+0.j,
-        -6.5506827461665529e-01+0.j, -5.0472498521116760e-01+0.j,
-        6.9987430903492132e-02+0.j, 1.0595648906599270e-01+0.j,
-        8.3443844143082035e-02+0.j, -4.0495370398246017e-01+0.j],
-       [ 2.6863211318173102e-01+0.j, 2.2958613191407312e-01+0.j,
-        6.3952843755683969e-02+0.j, 1.8776775771084192e-02+0.j,
-        -5.3523731432241317e-01+0.j, -5.9199531677602002e-01+0.j,
-        1.7916671834524250e-01+0.j, -4.4905045549140887e-01+0.j],
-       [ 4.1650029879270667e-01+0.j, 3.6355449432857068e-01+0.j,
-        2.9755313100756148e-01+0.j, 1.6826270392616000e-02+0.j,
-        1.9621068035557282e-01+0.j, 5.6830030587314817e-01+0.j,
-        2.9607517592514260e-02+0.j, -4.9314720700035747e-01+0.j]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14,
-       -1.9932120610662194e-14, -5.7323356091157378e-15,
-       -4.5459724251334835e-16, 4.0479851042511616e-14,
-        9.2325194924982089e-14, 2.7659880477613365e+02])),
-        mlir_module_text=r"""
-module @jit__lambda_ {
-  func.func public @main() -> (tensor<8x8xcomplex<f64>> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<64xcomplex<f64>>
-    %1 = stablehlo.reshape %0 : (tensor<64xcomplex<f64>>) -> tensor<8x8xcomplex<f64>>
-    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xcomplex<f64>>
-    %3 = stablehlo.real %2 : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xf64>
-    %4 = stablehlo.imag %2 : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xf64>
-    %5 = stablehlo.negate %4 : tensor<8x8xf64>
-    %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex<f64>>
-    %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex<f64>>
-    %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f64>>
-    %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor<complex<f64>>) -> tensor<8x8xcomplex<f64>>
-    %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex<f64>>
-    %11 = call @tril(%10) : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xcomplex<f64>>
-    %12 = stablehlo.constant dense<1> : tensor<i32>
-    %13 = stablehlo.constant dense<1> : tensor<i32>
-    %14 = stablehlo.constant dense<8> : tensor<i32>
-    %15 = stablehlo.custom_call @lapack_zheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<8x8xcomplex<f64>>) -> tuple<tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>, tensor<81xcomplex<f64>>, tensor<169xf64>, tensor<43xi32>>
-    %16 = stablehlo.get_tuple_element %15[0] : (tuple<tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>, tensor<81xcomplex<f64>>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8x8xcomplex<f64>>
-    %17 = stablehlo.get_tuple_element %15[1] : (tuple<tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>, tensor<81xcomplex<f64>>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8xf64>
-    %18 = stablehlo.get_tuple_element %15[2] : (tuple<tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>, tensor<81xcomplex<f64>>, tensor<169xf64>, tensor<43xi32>>) -> tensor<i32>
-    %19 = stablehlo.get_tuple_element %15[3] : (tuple<tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>, tensor<81xcomplex<f64>>, tensor<169xf64>, tensor<43xi32>>) -> tensor<81xcomplex<f64>>
-    %20 = stablehlo.get_tuple_element %15[4] : (tuple<tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>, tensor<81xcomplex<f64>>, tensor<169xf64>, tensor<43xi32>>) -> tensor<169xf64>
-    %21 = stablehlo.get_tuple_element %15[5] : (tuple<tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>, tensor<81xcomplex<f64>>, tensor<169xf64>, tensor<43xi32>>) -> tensor<43xi32>
-    %22 = stablehlo.constant dense<0> : tensor<i32>
-    %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<i32>) -> tensor<i32>
-    %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %26 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>>
-    %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor<complex<f64>>) -> tensor<8x8xcomplex<f64>>
-    %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1>
-    %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex<f64>>
-    %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor<i1>) -> tensor<1xi1>
-    %31 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
-    %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor<f64>) -> tensor<8xf64>
-    %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1>
-    %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf64>
-    return %29, %34 : tensor<8x8xcomplex<f64>>, tensor<8xf64>
-  }
-  func.func private @tril(%arg0: tensor<8x8xcomplex<f64>>) -> tensor<8x8xcomplex<f64>> {
-    %0 = stablehlo.iota dim = 0 : tensor<8x8xi32>
-    %1 = stablehlo.constant dense<0> : tensor<i32>
-    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<8x8xi32>
-    %3 = stablehlo.add %0, %2 : tensor<8x8xi32>
-    %4 = stablehlo.iota dim = 1 : tensor<8x8xi32>
-    %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1>
-    %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f64>>
-    %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<complex<f64>>) -> tensor<8x8xcomplex<f64>>
-    %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex<f64>>
-    return %8 : tensor<8x8xcomplex<f64>>
-  }
-}
-""",
-
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0bOO//\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02\x96\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\x0b)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x
1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00J\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd\x00", - xla_call_module_version=4, - ), # End paste -) - data_2024_08_19 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_19["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py index 204af8f55396..8d87c2524e64 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py @@ -17,275 
+17,8 @@ import datetime from numpy import array, float32, complex64 -data_2024_08_30 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.7137638961069523 +2.4533812415320035e+00j, - -0.3272236912989258 -3.2003874808591863e+00j, - -3.065817294924296 +1.6978219378771007e+00j, - -3.3971558164664 +2.6931967836060400e-01j], - [ 6.346214936866542 +0.0000000000000000e+00j, - 2.083218259144673 -1.2191838498692813e+00j, - 1.9552582313969427 -3.3216313521481879e+00j, - 2.7451664155727293 +2.5460553490974451e+00j], - [-0.16133388943502391 +3.6906265775683444e-01j, - -4.698636849217318 +0.0000000000000000e+00j, - 2.5396292124414077 -3.3038474840573420e+00j, - 2.5410992366186456 +4.1958389320867528e-01j], - [ 0.47396123039280513 +3.9524384493417053e-03j, - 0.058880409351504966-7.8934332132630333e-02j, - 0.9469634796174572 +0.0000000000000000e+00j, - -3.130422531669044 -8.8070401977461810e-01j]], - - [[-6.7065483048969465 -4.1981401054281309e-01j, - -0.21813268822330256 -3.8602920478381799e+00j, - -0.8248337528620167 -2.9073223456990824e+00j, - -3.597231249446879 +2.7626541679004930e+00j], - [-6.812126638479044 +0.0000000000000000e+00j, - -0.20651586628458585 -1.0948249928988512e+00j, - -1.6675586608354327 +4.2553627621795744e+00j, - -2.410110723267707 +3.6065122124698634e-01j], - [ 0.038235817369200516-3.7823713529009173e-01j, - -8.508141062606947 +0.0000000000000000e+00j, - 4.260708077719245 -6.8052584397204630e-02j, - 5.345997177836541 -1.1955161503390279e+00j], - [-0.18541509608158574 -1.2016051097247168e-01j, - -0.02698777746917469 -4.4847463691672246e-01j, - 6.149305574585603 +0.0000000000000000e+00j, - -2.483131585236393 +2.8524912589603817e+00j]]]), array([[1.2286220194325557+0.5121060656500841j , - 1.9529937219183482-0.23299856112387676j, - 1.5940499664125072-0.8044281430962614j ], - [1.6682114302246909-0.11372755955977935j, - 1.4075913155446236-0.6008708461880701j , - 1.5086928152468893-0.8609480935086589j ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(0.71376389610695234,2.4533812415320035), (-1.0686093138739379,-1.885041510645256), (3.2629529488994033,-0.87160041258342402), (2.4332168907311504,3.4960248990882183)], [(-1.450884474619478,-3.249935163088522), (0.53920035905924757,-5.0056840575116066), (0.13157186736298554,2.5015499854549939), (-1.2451270607408882,0.24345856951924827)], [(2.457366083193417,-2.3532935513245605), (-0.37595429769485644,1.5729223427874068), (3.5877693970448052,-0.30904304334212157), (-1.685615117470264,2.6148811836470265)], [(-3.6826776618664727,-1.5711608241015744), (-0.12407609317204518,-4.7137561145212281), (1.3298255603911306,-1.6739172003954141), (-2.6345448161870149,-0.089008252847513236)]], [[(-6.7065483048969465,-0.41981401054281309), (-2.1586544949255457,0.34815132010709054), (-5.1462488701272413,3.440817752555807), (1.0301804086076078,-0.6994760434270566)], [(4.551940883969797,-0.77472653800638502), (4.4485186470774796,-0.0024458890677252756), 
(0.66610302132250898,2.5976571401862039), (-5.0693248202533674,-5.7405538897950699)], [(0.14148406399087146,-4.3279346473525058), (-2.353557113110897,2.0880432773400326), (-3.2524452107293618,-0.42398740171508631), (3.7200566224095519,-0.56951559566037058)], [(-2.2001612082232613,-1.2218661647417151), (0.72437359623190833,8.6381970213061301), (0.72314820631775734,0.058458198280771749), (0.37498718985014962,2.1160469724471378)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_zgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x10\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/OoO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xce\x0f\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x08p\t\xdba\'\xd7\xe6?\xa8\xff\'X\x86\xa0\x03@\x0c\xa2t\x14\x06\x19\xf1\xbfT.}I!)\xfe\xbf\x0fG_\x13\x87\x1a\n@\xae:g\x8c&\xe4\xeb\xbf\xeb\x1e\xcej:w\x03@N\xaf\xfc\xe6\xdb\xf7\x0b@\x9f<\x8c\xa3\xd26\xf7\xbf^\xaf\xbc\x01\xde\xff\t\xc0b\xd4\x84\x1c!A\xe1?\xd6{\xa4\n\xd2\x05\x14\xc0\xf0\xe6\xb2\xd1X\xd7\xc0?2\xb5\x86\xa3,\x03\x04@\x91\xf2SZ\n\xec\xf3\xbf\x04\x10\x02\x81\xa6)\xcf?8\xec\x8c\x8c\xaf\xa8\x03@\r\x9d\xc6\x91\x8b\xd3\x02\xc0\xb0\xf6X\x9d\xa2\x0f\xd8\xbf\xbd\xb6V\x9e\xb0*\xf9?7-\x0fq\xc0\xb3\x0c@{|\ry\\\xc7\xd3\xbf\x04\xd9\xb2\x8eG\xf8\xfa\xbf\x9b\x84u\xd3F\xeb\x04@\xf4h\xbb\xb4\x1fv\r\xc0\xdc\\D\x88y#\xf9\xbf\x9a\xaecjs\xc3\xbf\xbf<\xc1\x04\xe2\xe2\xda\x12\xc0\x89<\xb4*\xf7F\xf5?\x1b\x90\xfef]\xc8\xfa\xbf\xdc\xf4\x8a;\x8c\x13\x05\xc0\xf8\xdd\r\xaf>\xc9\xb6\xbfvN\x1af\x81\xd3\x1a\xc0Z\xc6k\x95;\xde\xda\xbf\x87\x8c\xd8\xa5\xecD\x01\xc0\xdd\xd3zy\x1cH\xd6?\x04\x18\x89C\xc2\x95\x14\xc0\x8c\xc95u\xcb\x86\x0b@\x881\xbfs\x9e{\xf0?\x92Y[\x95\x1bb\xe6\xbf\x06\xe7\xb7\xfd/5\x12@L\x95\x02O\x8f\xca\xe8\xbf2`\xe3xH\xcb\x11@>\xda\xc6\xb1f\td\xbfZ\x1a\x8bH\xb7P\xe5?\xa8\x90zw\x00\xc8\x04@<(\xef\x15\xfdF\x14\xc0\xb4aF\xc2S\xf6\x16\xc0\xc1{\xdfY&\x1c\xc2?\xcfj\xa6\x19\xceO\x11\xc0\xc4\xa2p\xc0\x15\xd4\x02\xc0\xfcv\xa6\x08P\xb4\x00@^\xea\xa0\xfe\x01\x05\n\xc0^\x11\x12\x0e\x9c"\xdb\xbfR#\xe4\x0b\xad\xc2\r@F\x8b=\xc5x9\xe2\xbfZ\xf9\x99\x1e\xee\x99\x01\xc0My\x1a\x89\xc3\x8c\xf3\xbf\xd1\xdc<\x89\x11.\xe7?2\xd4\x8d\xc2\xc1F!@mw\t\xb5\x07$\xe7?G\x16\x99\xa3;\xee\xad?M\xd24E\xca\xff\xd7?\xa2\xae\xfb\x08\xaa\xed\x00@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\x0b)\x03\x02\x86\x0b)\x03\x01\x0f)\x0
3\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 5.2023945 -0.878671j , -2.8841915 -0.47488597j , - 1.3024182 +0.6651789j , 4.9291854 -1.9147056j ], - [ 6.3457894 +0.j , 1.6869383 -4.6557646j , - 0.88955224-1.7617276j , 2.9149916 +4.342665j ], - [-0.2465725 -0.5776757j , -5.3007755 +0.j , - -0.9786545 -0.0633831j , -1.3690261 -1.5921416j ], - [ 0.35462287+0.35993803j , -0.38403815-0.46558398j , - 2.8020499 +0.j , 0.5636822 -6.218306j ]], - - [[ 1.0687767 -3.88293j , -4.0144 -2.5885587j , - 5.3900986 -0.8850739j , 2.079677 +3.5515747j ], - [ 7.5675693 +0.j , 0.5971966 -3.6699948j , - 2.246994 -1.0858283j , -0.8870981 -0.022960603j], - [-0.2183232 +0.10552277j , 5.860886 +0.j , - -5.091036 +6.2841997j , 5.008773 +1.8765848j ], - [ 0.1378771 +0.427895j , 0.63263524-0.3470098j , - 6.4528017 +0.j , -4.233642 -0.84165764j ]]], - dtype=complex64), array([[1.0933675-0.3605358j , 1.1987956+0.5659744j , - 1.9999101-0.013409062j], - [1.4504763-0.44363326j , 1.3110259-0.07426627j , - 1.227255 +0.97383535j ]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(5.20239449,-0.87867099), (-0.211780012,-0.923053801), (-5.25181627,1.90887547), (-1.61342144,-1.98000157)], [(-5.924900e-01,2.28788424), (-1.74142945,-3.25563216), 
(3.08765078,-3.25260139), (-3.35189271,-0.571629047)], [(3.032444,3.44394636), (1.22205484,0.808871626), (2.58686161,-7.47011566), (1.9139297,-2.57945323)], [(-3.28396916,-1.68601465), (2.62759161,-0.953538239), (-2.78763294,-0.0429570749), (0.426534384,-0.211706176)]], [[(1.06877673,-3.882930e+00), (-0.0192247611,5.96663713), (1.15329504,-5.0599103), (-1.76508892,-1.98541296)], [(-3.40901089,3.35722542), (-6.13531398,2.55851483), (-4.8095789,0.164206699), (-0.247624069,-3.13545418)], [(2.04217815,-1.89123917), (-1.18974173,-1.69466627), (-2.28673625,-0.487834573), (3.01541853,-1.85637176)], [(-2.9499588,-4.23393869), (8.44624137,5.57274485), (-1.09048736,2.4864223), (-0.305431545,-0.298133373)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_cgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xae\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\x04z\xa6@\x95\xf0`\xbf\xdc\xdcX\xbeAMl\xbf\xe1\x0e\xa8\xc0\x08V\xf4?\x98\x84\xce\xbf\xb1p\xfd\xbfm\xad\x17\xbf\xb2l\x12@)\xe7\xde\xbfG\\P\xc0\x12\x9cE@\x9f*P\xc0i\x85V\xc0HV\x12\xbf\x90\x13B@\x9ei\\@Kl\x9c?6\x12O?$\x8f%@0\x0b\xef\xc0\xa6\xfb\xf4?\xc3\x15%\xc0\x8d,R\xc0T\xcf\xd7\xbfv*(@\x15\x1bt\xbf\x94h2\xc0\xc2\xf3/\xbd\xb7b\xda>\x81\xc9X\xbe\xad\xcd\x88?\xed\x81x\xc0?}\x9d\xbc\xb1\xee\xbe@,\x9f\x93?\xc9\xea\xa1\xc0o\xee\xe1\xbf\x03"\xfe\xbf<-Z\xc0\xc8\xdcV@~T\xc4\xc0\xb5\xbe#@\x12\xe8\x99\xc0\xcd%(>*\x91}\xbeH\xabH\xc0\x0c\xb3\x02@ \x14\xf2\xbfuI\x98\xbf\xd3\xea\xd8\xbf\xe3Y\x12\xc0t\xc5\xf9\xbe\x9e\xfc@@\x97\x9d\xed\xbf 
\xcc<\xc0m|\x87\xc0\xce#\x07A\xedS\xb2@\x17\x95\x8b\xbf\x8b!\x1f@\x86a\x9c\xbe\xf0\xa4\x98\xbe\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\t)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[-3.5237675 , -6.1161256 , -0.549011 , -4.7706876 ], - [ 5.8401766 , 3.424213 , 0.3059119 , 2.3492367 ], - [ 0.63135445 , 2.7238827 , -0.106214404, -0.82470125 ], - [-0.27146497 , 0.09917235 , 0.2545611 , -0.5113605 ]], - - [[ 4.297168 , -1.8758869 , 0.33528137 , 5.867136 ], - [-7.129698 , -3.3118155 , -1.3492918 , -2.8959117 ], - [-0.7266852 , -3.506432 , 4.77164 , -4.0780373 ], - [ 0.14084078 , 0.3389384 , 2.3910007 , -0.79807365 ]]], - dtype=float32), array([[1.3584172, 1.9805213, 0. ], - [1.2920669, 1.7939165, 0. 
]], dtype=float32)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[-3.52376747, -0.758410036, 4.85795927, -6.0243597], [-2.09321976, -1.27957773, -0.956288218, -1.11928439], [-5.00878525, 0.51314038, 3.53047514, -2.91282868], [2.15363932, 0.635739565, -0.21264787, 0.555740714]], [[4.29716778, -3.86209464, -2.39021468, 4.17441607], [2.08234859, -1.03958249, 4.09025383, 5.22586823], [-6.69425774, 3.43749118, -0.691099107, 1.59547663], [1.29743183, -2.00156212, 3.08750296, 2.39243269]]]> : tensor<2x4x4xf32> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_sgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<4288xf32>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc2) - return %6, %10 : tensor<2x4x4xf32>, tensor<2x3xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x04\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/\x1foO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\x96\t\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x02h\x85a\xc0)\'B\xbfgt\x9b@\x8e\xc7\xc0\xc0P\xf7\x05\xc04\xc9\xa3\xbfN\xcft\xbf\xb6D\x8f\xbf\xf8G\xa0\xc0+]\x03?N\xf3a@\xc9k:\xc0:\xd5\t@\xd4\xbf"?]\xc0Y\xbe\x06E\x0e?f\x82\x89@\x8f,w\xc0G\xf9\x18\xc0\xd1\x94\x85@3E\x05@\n\x11\x85\xbf\\\xe3\x82@P:\xa7@\\7\xd6\xc0\xdb\xff[@\xdf\xeb0\xbf\x948\xcc??\x12\xa6?\x98\x19\x00\xc0\xa6\x99E@\x9e\x1d\x19@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\t\x00\x00\xc0\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\t)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_t
ype\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.9307390587491866 , -0.35692982324474015 , - -0.1271353200176119 , -0.43952156917870067 ], - [ 2.2633695323673964 , 0.9965090965971986 , - -1.3244131008423046 , 1.7324542351344163 ], - [ 0.24558316247256504 , 2.922776762811796 , - 3.630059093036474 , 1.4330664619737252 ], - [-0.2856727718012896 , -0.4601276537179077 , - -2.8602148466873802 , 1.9928744545245372 ]], - - [[-0.5351339571818844 , 5.753313169426148 , - 0.1385440281649789 , 2.8445493054193807 ], - [ 4.676815781213274 , 2.920688567170204 , - -2.610159425457712 , 4.0359806870679655 ], - [-0.16963242599901043 , -2.342935131066633 , - 4.179999589709703 , -0.6810604472011716 ], - [ 0.030645999613174775, -0.2271804227402005 , - -2.2755242550977153 , 0.7136684502626782 ]]]), array([[1.751436143556826 , 1.6505497938190505, 0. ], - [1.9422862513069978, 1.9018440331997255, 0. ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[0.93073905874918661, 0.18483901505653183, -0.11804347408930886, -0.53725392025434981], [-1.700777672846173, 1.3531570270421245, -2.4375034855727518, 2.2945174202226699], [-0.97352780716312858, -0.8319788592736328, 2.4986640885328582, -2.8118637941861766], [1.1324489199416958, -1.9301638714393787, 1.5523821278819048, 2.7676215285832253]], [[-0.53513395718188439, -5.2137633671981938, 2.9644475919777618, 2.2891023676266191], [-4.4068992105328642, 1.2751848926168665, -2.8947257279736456, -2.6817410994805888], [1.5408926111334784, -0.85423691880254915, 6.4217874587762065, -0.43997818045540715], [-0.27837952612324207, 1.1509460853774549, -0.21686805683301608, 0.11738425574951133]]]> : tensor<2x4x4xf64> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_dgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : 
tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<4288xf64>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc2) - return %6, %10 : tensor<2x4x4xf64>, tensor<2x3xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xa6\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\xa6\x00NG\x9d\xc8\xed?\xf2\xa8X\n\xce\xa8\xc7?#E\xb8\xdc\x188\xbe\xbf\xb8|$"/1\xe1\xbf\xc4B*\xa6b6\xfb\xbf\xe8\xf9\x97\xfb\x87\xa6\xf5?)^\xd3\xd3\x01\x80\x03\xc0T\xab\xff\xf2+[\x02@4d\xb0\xc9#\'\xef\xbf~e\xf1 
\x92\x9f\xea\xbf\x96\x81\xff\x98C\xfd\x03@W\xb0\xe6q\xb2~\x06\xc0F\xa48\xc2\x82\x1e\xf2?\xcc\x0b\xfc\x82\xf3\xe1\xfe\xbf\xdc\\b\xa4\x8e\xd6\xf8?\x8c\xc3\x87\xc1\x16$\x06@\x83h\xa2?\xd1\x1f\xe1\xbf\xdc\xcb\xbc\xc8\xe4\xda\x14\xc0\xe6\x00\x92L0\xb7\x07@Q8\xf1\xe6\x14P\x02@\t\x07\xc8/\xaa\xa0\x11\xc0\x8eH"F(g\xf4?\xf5Jd\xf6e(\x07\xc0\x9e\xddt\xad4t\x05\xc0\x1cv\xb7\x02\x7f\xa7\xf8?B^\xa9\xa9\xe8U\xeb\xbf\x1e:5\r\xe9\xaf\x19@\xa2\x9c\x00>\x9a(\xdc\xbf\xc1\xd1$\\\xf8\xd0\xd1\xbf}|BqFj\xf2?6\x8b\xd2\x1dU\xc2\xcb\xbfdk\x82\x03\xe5\x0c\xbe?\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x0b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - data_2024_08_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py index 72d97df53a4f..2290db62e436 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py @@ -17,527 +17,8 @@ 
@@
import datetime from numpy import array, int32, float32, complex64 -data_2023_06_14 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['f32'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. , 7. , 8. ], - [0. , 1. , 2. ], - [0.5, 0.5, 0. ]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_sgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) 
- %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor 
loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":550:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":551:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b\x1fO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02J\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\x9e\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\x9a\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\t\x00\x00\xc0\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\t\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)
\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True 
mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00",
-    xla_call_module_version=6,
-)  # End paste
-
-
-# Pasted from the test output (see back_compat_test.py module docstring)
-data_2023_06_14['f64'] = dict(
-    testdata_version=1,
-    platform='cpu',
-    custom_call_targets=['lapack_dgetrf'],
-    serialized_date=datetime.date(2023, 6, 14),
-    inputs=(),
-    expected_outputs=(array([[6. , 7. , 8. ],
-       [0. , 1. , 2. ],
-       [0.5, 0.5, 0. ]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)),
-    mlir_module_text=r"""
-#loc = loc(unknown)
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-  func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc3)
-    %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc4)
-    %2 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %3 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %4 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %6 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %7 = stablehlo.reshape %6 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
-    %9 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %10 = stablehlo.convert %9 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %11 = stablehlo.reshape %10 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5)
-    %13 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5)
-    %14 = stablehlo.constant dense<3> : tensor<i32> loc(#loc5)
-    %15 = stablehlo.constant dense<3> : tensor<i32> loc(#loc5)
-    %16:3 = stablehlo.custom_call @lapack_dgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xi32>, tensor<i32>) loc(#loc5)
-    %17 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5)
-    %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<i32>) -> tensor<3xi32> loc(#loc5)
-    %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5)
-    %20 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5)
-    %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5)
-    %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5)
-    %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5)
-    %24 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc5)
-    %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor<f64>) -> tensor<3x3xf64> loc(#loc5)
-    %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5)
-    %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc5)
-    %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6)
-    %29 = stablehlo.constant dense<0> : tensor<i32> loc(#loc7)
-    %30 = stablehlo.constant dense<0> : tensor<i64> loc(#loc8)
-    %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor<i64>, tensor<i32>, tensor<3xi32>, tensor<3xi32>
-     cond {
-      %32 = stablehlo.constant dense<3> : tensor<i64> loc(#loc9)
-      %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc10)
-      stablehlo.return %33 : tensor<i1> loc(#loc9)
-    } do {
-      %32 = stablehlo.constant dense<1> : tensor<i32> loc(#loc9)
-      %33 = stablehlo.add %iterArg_0, %32 : tensor<i32> loc(#loc11)
-      %34 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %36 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %37 = stablehlo.add %iterArg_0, %36 : tensor<i32> loc(#loc11)
-      %38 = stablehlo.select %35, %37, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %39 = stablehlo.convert %38 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather<offset_dims = [], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 0>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor<i32> loc(#loc16)
-      %42 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %44 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %45 = stablehlo.add %iterArg_0, %44 : tensor<i32> loc(#loc11)
-      %46 = stablehlo.select %43, %45, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %47 = stablehlo.convert %46 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather<offset_dims = [], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 0>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor<i32> loc(#loc16)
-      %50 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %52 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %53 = stablehlo.add %41, %52 : tensor<i32> loc(#loc11)
-      %54 = stablehlo.select %51, %53, %41 : tensor<i1>, tensor<i32> loc(#loc13)
-      %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32> loc(#loc17)
-      %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor<i32> loc(#loc18)
-      %57 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %59 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %60 = stablehlo.add %iterArg_0, %59 : tensor<i32> loc(#loc11)
-      %61 = stablehlo.select %58, %60, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %62 = stablehlo.convert %61 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({
-      ^bb0(%arg0: tensor<i32> loc(unknown), %arg1: tensor<i32> loc(unknown)):
-        stablehlo.return %arg1 : tensor<i32> loc(#loc19)
-      }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor<i32>) -> tensor<3xi32> loc(#loc19)
-      %65 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %67 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %68 = stablehlo.add %41, %67 : tensor<i32> loc(#loc11)
-      %69 = stablehlo.select %66, %68, %41 : tensor<i1>, tensor<i32> loc(#loc13)
-      %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %71 = "stablehlo.scatter"(%64, %70, %49) ({
-      ^bb0(%arg0: tensor<i32> loc(unknown), %arg1: tensor<i32> loc(unknown)):
-        stablehlo.return %arg1 : tensor<i32> loc(#loc19)
-      }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor<i32>) -> tensor<3xi32> loc(#loc19)
-      %72 = stablehlo.constant dense<1> : tensor<i64> loc(#loc9)
-      %73 = stablehlo.add %iterArg, %72 : tensor<i64> loc(#loc11)
-      stablehlo.return %73, %33, %71, %iterArg_2 : tensor<i64>, tensor<i32>, tensor<3xi32>, tensor<3xi32> loc(#loc9)
-    } loc(#loc9)
-    return %27, %19, %31#2 : tensor<3x3xf64>, tensor<3xi32>, tensor<3xi32> loc(#loc)
-  } loc(#loc)
-} loc(#loc)
-#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0)
-#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0)
-#loc3 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1))
-#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1))
-#loc5 = loc("jit()/jit(main)/lu"(#loc2))
-#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2))
-#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2))
-#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2))
-#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2))
-#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2))
-#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2))
-#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2))
-#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2))
-#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2))
-#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2))
-#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2))
-#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2))
-#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2))
-#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2))
-""",
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02Z\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x0b\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x0
79!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True 
mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00",
-    xla_call_module_version=6,
-)  # End paste
-
-
-
-# Pasted from the test output (see back_compat_test.py module docstring)
-data_2023_06_14['c64'] = dict(
-    testdata_version=1,
-    platform='cpu',
-    custom_call_targets=['lapack_cgetrf'],
-    serialized_date=datetime.date(2023, 6, 14),
-    inputs=(),
-    expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j],
-       [0. +0.j, 1. +0.j, 2. +0.j],
-       [0.5+0.j, 0.5+0.j, 0. +0.j]], dtype=complex64), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)),
-    mlir_module_text=r"""
-#loc = loc(unknown)
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-  func.func public @main() -> (tensor<3x3xcomplex<f32>> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<9xcomplex<f32>> loc(#loc3)
-    %1 = stablehlo.reshape %0 : (tensor<9xcomplex<f32>>) -> tensor<3x3xcomplex<f32>> loc(#loc4)
-    %2 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %3 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %4 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %6 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %7 = stablehlo.reshape %6 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
-    %9 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %10 = stablehlo.convert %9 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %11 = stablehlo.reshape %10 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5)
-    %13 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5)
-    %14 = stablehlo.constant dense<3> : tensor<i32> loc(#loc5)
-    %15 = stablehlo.constant dense<3> : tensor<i32> loc(#loc5)
-    %16:3 = stablehlo.custom_call @lapack_cgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xcomplex<f32>>) -> (tensor<3x3xcomplex<f32>>, tensor<3xi32>, tensor<i32>) loc(#loc5)
-    %17 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5)
-    %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<i32>) -> tensor<3xi32> loc(#loc5)
-    %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5)
-    %20 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5)
-    %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5)
-    %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5)
-    %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5)
-    %24 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc5)
-    %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor<complex<f32>>) -> tensor<3x3xcomplex<f32>> loc(#loc5)
-    %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5)
-    %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex<f32>> loc(#loc5)
-    %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6)
-    %29 = stablehlo.constant dense<0> : tensor<i32> loc(#loc7)
-    %30 = stablehlo.constant dense<0> : tensor<i64> loc(#loc8)
-    %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor<i64>, tensor<i32>, tensor<3xi32>, tensor<3xi32>
-     cond {
-      %32 = stablehlo.constant dense<3> : tensor<i64> loc(#loc9)
-      %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc10)
-      stablehlo.return %33 : tensor<i1> loc(#loc9)
-    } do {
-      %32 = stablehlo.constant dense<1> : tensor<i32> loc(#loc9)
-      %33 = stablehlo.add %iterArg_0, %32 : tensor<i32> loc(#loc11)
-      %34 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %36 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %37 = stablehlo.add %iterArg_0, %36 : tensor<i32> loc(#loc11)
-      %38 = stablehlo.select %35, %37, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %39 = stablehlo.convert %38 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather<offset_dims = [], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 0>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor<i32> loc(#loc16)
-      %42 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %44 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %45 = stablehlo.add %iterArg_0, %44 : tensor<i32> loc(#loc11)
-      %46 = stablehlo.select %43, %45, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %47 = stablehlo.convert %46 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather<offset_dims = [], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 0>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor<i32> loc(#loc16)
-      %50 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %52 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %53 = stablehlo.add %41, %52 : tensor<i32> loc(#loc11)
-      %54 = stablehlo.select %51, %53, %41 : tensor<i1>, tensor<i32> loc(#loc13)
-      %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32> loc(#loc17)
-      %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor<i32> loc(#loc18)
-      %57 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %59 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %60 = stablehlo.add %iterArg_0, %59 : tensor<i32> loc(#loc11)
-      %61 = stablehlo.select %58, %60, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %62 = stablehlo.convert %61 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({
-      ^bb0(%arg0: tensor<i32> loc(unknown), %arg1: tensor<i32> loc(unknown)):
-        stablehlo.return %arg1 : tensor<i32> loc(#loc19)
-      }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor<i32>) -> tensor<3xi32> loc(#loc19)
-      %65 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %67 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %68 = stablehlo.add %41, %67 : tensor<i32> loc(#loc11)
-      %69 = stablehlo.select %66, %68, %41 : tensor<i1>, tensor<i32> loc(#loc13)
-      %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %71 = "stablehlo.scatter"(%64, %70, %49) ({
-      ^bb0(%arg0: tensor<i32> loc(unknown), %arg1: tensor<i32> loc(unknown)):
-        stablehlo.return %arg1 : tensor<i32> loc(#loc19)
-      }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor<i32>) -> tensor<3xi32> loc(#loc19)
-      %72 = stablehlo.constant dense<1> : tensor<i64> loc(#loc9)
-      %73 = stablehlo.add %iterArg, %72 : tensor<i64> loc(#loc11)
-      stablehlo.return %73, %33, %71, %iterArg_2 : tensor<i64>, tensor<i32>, tensor<3xi32>, tensor<3xi32> loc(#loc9)
-    } loc(#loc9)
-    return %27, %19, %31#2 : tensor<3x3xcomplex<f32>>, tensor<3xi32>, tensor<3xi32> loc(#loc)
-  } loc(#loc)
-} loc(#loc)
-#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0)
-#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0)
-#loc3 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1))
-#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1))
-#loc5 = loc("jit()/jit(main)/lu"(#loc2))
-#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2))
-#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2))
-#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2))
-#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2))
-#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2))
-#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2))
-#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2))
-#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2))
-#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2))
-#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2))
-#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2))
-#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2))
-#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2))
-#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2))
-""",
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02b\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\t)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x031\x033\r\x06\x01\x
03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00~%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8b\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True 
indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00",
-    xla_call_module_version=6,
-)  # End paste
-
-
-# Pasted from the test output (see back_compat_test.py module docstring)
-data_2023_06_14['c128'] = dict(
-    testdata_version=1,
-    platform='cpu',
-    custom_call_targets=['lapack_zgetrf'],
-    serialized_date=datetime.date(2023, 6, 14),
-    inputs=(),
-    expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j],
-       [0. +0.j, 1. +0.j, 2. +0.j],
-       [0.5+0.j, 0.5+0.j, 0. +0.j]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)),
-    mlir_module_text=r"""
-#loc = loc(unknown)
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-  func.func public @main() -> (tensor<3x3xcomplex<f64>> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<9xcomplex<f64>> loc(#loc3)
-    %1 = stablehlo.reshape %0 : (tensor<9xcomplex<f64>>) -> tensor<3x3xcomplex<f64>> loc(#loc4)
-    %2 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %3 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %4 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %6 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %7 = stablehlo.reshape %6 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
-    %9 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
-    %10 = stablehlo.convert %9 : (tensor<i64>) -> tensor<i32> loc(#loc5)
-    %11 = stablehlo.reshape %10 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
-    %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5)
-    %13 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5)
-    %14 = stablehlo.constant dense<3> : tensor<i32> loc(#loc5)
-    %15 = stablehlo.constant dense<3> : tensor<i32> loc(#loc5)
-    %16:3 = stablehlo.custom_call @lapack_zgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 3, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xcomplex<f64>>) -> (tensor<3x3xcomplex<f64>>, tensor<3xi32>, tensor<i32>) loc(#loc5)
-    %17 = stablehlo.constant dense<1> : tensor<i32> loc(#loc5)
-    %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<i32>) -> tensor<3xi32> loc(#loc5)
-    %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5)
-    %20 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5)
-    %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5)
-    %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5)
-    %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5)
-    %24 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5)
-    %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor<complex<f64>>) -> tensor<3x3xcomplex<f64>> loc(#loc5)
-    %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5)
-    %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex<f64>> loc(#loc5)
-    %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6)
-    %29 = stablehlo.constant dense<0> : tensor<i32> loc(#loc7)
-    %30 = stablehlo.constant dense<0> : tensor<i64> loc(#loc8)
-    %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor<i64>, tensor<i32>, tensor<3xi32>, tensor<3xi32>
-     cond {
-      %32 = stablehlo.constant dense<3> : tensor<i64> loc(#loc9)
-      %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc10)
-      stablehlo.return %33 : tensor<i1> loc(#loc9)
-    } do {
-      %32 = stablehlo.constant dense<1> : tensor<i32> loc(#loc9)
-      %33 = stablehlo.add %iterArg_0, %32 : tensor<i32> loc(#loc11)
-      %34 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %36 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %37 = stablehlo.add %iterArg_0, %36 : tensor<i32> loc(#loc11)
-      %38 = stablehlo.select %35, %37, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %39 = stablehlo.convert %38 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather<offset_dims = [], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 0>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor<i32> loc(#loc16)
-      %42 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %44 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %45 = stablehlo.add %iterArg_0, %44 : tensor<i32> loc(#loc11)
-      %46 = stablehlo.select %43, %45, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %47 = stablehlo.convert %46 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather<offset_dims = [], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 0>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor<i32> loc(#loc16)
-      %50 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %52 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %53 = stablehlo.add %41, %52 : tensor<i32> loc(#loc11)
-      %54 = stablehlo.select %51, %53, %41 : tensor<i1>, tensor<i32> loc(#loc13)
-      %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32> loc(#loc17)
-      %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor<i32> loc(#loc18)
-      %57 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %59 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %60 = stablehlo.add %iterArg_0, %59 : tensor<i32> loc(#loc11)
-      %61 = stablehlo.select %58, %60, %iterArg_0 : tensor<i1>, tensor<i32> loc(#loc13)
-      %62 = stablehlo.convert %61 : (tensor<i32>) -> tensor<i32> loc(#loc14)
-      %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({
-      ^bb0(%arg0: tensor<i32> loc(unknown), %arg1: tensor<i32> loc(unknown)):
-        stablehlo.return %arg1 : tensor<i32> loc(#loc19)
-      }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor<i32>) -> tensor<3xi32> loc(#loc19)
-      %65 = stablehlo.constant dense<0> : tensor<i32> loc(#loc9)
-      %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc12)
-      %67 = stablehlo.constant dense<3> : tensor<i32> loc(#loc9)
-      %68 = stablehlo.add %41, %67 : tensor<i32> loc(#loc11)
-      %69 = stablehlo.select %66, %68, %41 : tensor<i1>, tensor<i32> loc(#loc13)
-      %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc15)
-      %71 = "stablehlo.scatter"(%64, %70, %49) ({
-      ^bb0(%arg0: tensor<i32> loc(unknown), %arg1: tensor<i32> loc(unknown)):
-        stablehlo.return %arg1 : tensor<i32> loc(#loc19)
-      }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor<i32>) -> tensor<3xi32> loc(#loc19)
-      %72 = stablehlo.constant dense<1> : tensor<i64> loc(#loc9)
-      %73 = stablehlo.add %iterArg, %72 : tensor<i64> loc(#loc11)
-      stablehlo.return %73, %33, %71, %iterArg_2 : tensor<i64>, tensor<i32>, tensor<3xi32>, tensor<3xi32> loc(#loc9)
-    } loc(#loc9)
-    return %27, %19, %31#2 : tensor<3x3xcomplex<f64>>, tensor<3xi32>, tensor<3xi32> loc(#loc)
-  } loc(#loc)
-} loc(#loc)
-#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0)
-#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0)
-#loc3 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1))
-#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1))
-#loc5 = loc("jit()/jit(main)/lu"(#loc2))
-#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2))
-#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2))
-#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2))
-#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2))
-#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2))
-#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2))
-#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2))
-#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2))
-#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2))
-#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2))
-#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2))
-#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2))
-#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2))
-#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2))
-""",
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0bOO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02\x82\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\x0b)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\
x05\x07\x01}\x031\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00\x82%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8d\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) 
unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00",
-    xla_call_module_version=6,
-)  # End paste
-
 data_2024_05_31 = {}
-
 # Pasted from the test output (see export_back_compat_test_util.py module docstring)
 data_2024_05_31["c128"] = dict(
     testdata_version=1,
diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py
index 94314a7ae518..bf41f3c3445c 100644
--- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py
+++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py
@@ -17,259 +17,13 @@ import datetime
 from numpy import array, float32, complex64
 
-data_2023_03_17 = {}
+data_2025_04_02 = {}
 
-# Pasted from the test output (see back_compat_test.py module docstring)
-data_2023_03_17["f32"] = dict(
+data_2025_04_02['c128'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_sgeqrf', 'lapack_sorgqr'],
-    serialized_date=datetime.date(2023, 3, 17),
-    inputs=(),
-    expected_outputs=(array([[ 0.        ,  0.91287076,  0.4082487 ],
-       [-0.44721356,  0.36514866, -0.8164965 ],
-       [-0.8944271 , -0.18257445,  0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00],
-       [ 0.0000000e+00,  1.0954441e+00,  2.1908894e+00],
-       [ 0.0000000e+00,  0.0000000e+00,  7.1525574e-07]], dtype=float32)),
-    mlir_module_text=r"""
-module @jit__lambda_ {
-  func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<9xf32>
-    %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32>
-    %2 = stablehlo.constant dense<1> : tensor<i32>
-    %3 = stablehlo.constant dense<3> : tensor<i32>
-    %4 = stablehlo.constant dense<3> : tensor<i32>
-    %5 = stablehlo.constant dense<96> : tensor<i32>
-    %6 = stablehlo.custom_call @lapack_sgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 4, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xf32>) -> tuple<tensor<3x3xf32>, tensor<3xf32>, tensor<i32>, tensor<96xf32>>
-    %7 = stablehlo.get_tuple_element %6[0] : (tuple<tensor<3x3xf32>, tensor<3xf32>, tensor<i32>, tensor<96xf32>>) -> tensor<3x3xf32>
-    %8 = stablehlo.get_tuple_element %6[1] : (tuple<tensor<3x3xf32>, tensor<3xf32>, tensor<i32>, tensor<96xf32>>) -> tensor<3xf32>
-    %9 = stablehlo.get_tuple_element %6[2] : (tuple<tensor<3x3xf32>, tensor<3xf32>, tensor<i32>, tensor<96xf32>>) -> tensor<i32>
-    %10 = stablehlo.get_tuple_element %6[3] : (tuple<tensor<3x3xf32>, tensor<3xf32>, tensor<i32>, tensor<96xf32>>) -> tensor<96xf32>
-    %11 = stablehlo.constant dense<0> : tensor<i32>
-    %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<i32>) -> tensor<i32>
-    %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %15 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
-    %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
-    %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1>
-    %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf32>
-    %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1xi1>
-    %20 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
-    %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<f32>) -> tensor<3xf32>
-    %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1>
-    %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf32>
-    %24 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-    %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor<f32>) -> tensor<3x3xf32>
-    %26 = stablehlo.constant dense<1> : tensor<i32>
-    %27 = stablehlo.constant dense<3> : tensor<i32>
-    %28 = stablehlo.constant dense<3> : tensor<i32>
-    %29 = stablehlo.constant dense<3> : tensor<i32>
-    %30 = stablehlo.constant dense<96> : tensor<i32>
-    %31 = stablehlo.custom_call @lapack_sorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xf32>, tensor<3xf32>) -> tuple<tensor<3x3xf32>, tensor<i32>, tensor<96xf32>>
-    %32 = stablehlo.get_tuple_element %31[0] : (tuple<tensor<3x3xf32>, tensor<i32>, tensor<96xf32>>) -> tensor<3x3xf32>
-    %33 = stablehlo.get_tuple_element %31[1] : (tuple<tensor<3x3xf32>, tensor<i32>, tensor<96xf32>>) -> tensor<i32>
-    %34 = stablehlo.get_tuple_element %31[2] : (tuple<tensor<3x3xf32>, tensor<i32>, tensor<96xf32>>) -> tensor<96xf32>
-    %35 = stablehlo.constant dense<0> : tensor<i32>
-    %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor<i32>) -> tensor<i32>
-    %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %39 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
-    %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
-    %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1>
-    %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf32>
-    %43 = call @triu(%18) : (tensor<3x3xf32>) -> tensor<3x3xf32>
-    return %42, %43 : tensor<3x3xf32>, tensor<3x3xf32>
-  }
-  func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
-    %0 = stablehlo.iota dim = 0 : tensor<3x3xi32>
-    %1 = stablehlo.constant dense<-1> : tensor<i32>
-    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<3x3xi32>
-    %3 = stablehlo.add %0, %2 : tensor<3x3xi32>
-    %4 = stablehlo.iota dim = 1 : tensor<3x3xi32>
-    %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1>
-    %6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-    %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
-    %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32>
-    return %8 : tensor<3x3xf32>
-  }
-}
-""",
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xae\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\t\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03
\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf\x00lapack_sorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeqrf', 'lapack_dorgqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128709291752773 , 0.40824829046386235],
-       [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ],
-       [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00,
-        -9.3914855054991175e+00],
-       [ 0.0000000000000000e+00, 1.0954451150103341e+00,
-        2.1908902300206665e+00],
-       [ 0.0000000000000000e+00, 0.0000000000000000e+00,
-        -8.8817841970012523e-16]])),
-    mlir_module_text=r"""
-module @jit__lambda_ {
-  func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3x3xf64> {jax.result_info = "[1]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<9xf64>
-    %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64>
-    %2 = stablehlo.constant dense<1> : tensor<i32>
-    %3 = stablehlo.constant dense<3> : tensor<i32>
-    %4 = stablehlo.constant dense<3> : tensor<i32>
-    %5 = stablehlo.constant dense<96> : tensor<i32>
-    %6 = stablehlo.custom_call @lapack_dgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xf64>) -> tuple<tensor<3x3xf64>, tensor<3xf64>, tensor<i32>, tensor<96xf64>>
-    %7 = stablehlo.get_tuple_element %6[0] : (tuple<tensor<3x3xf64>, tensor<3xf64>, tensor<i32>, tensor<96xf64>>) -> tensor<3x3xf64>
-    %8 = stablehlo.get_tuple_element %6[1] : (tuple<tensor<3x3xf64>, tensor<3xf64>, tensor<i32>, tensor<96xf64>>) -> tensor<3xf64>
-    %9 = stablehlo.get_tuple_element %6[2] : (tuple<tensor<3x3xf64>, tensor<3xf64>, tensor<i32>, tensor<96xf64>>) -> tensor<i32>
-    %10 = stablehlo.get_tuple_element %6[3] : (tuple<tensor<3x3xf64>, tensor<3xf64>, tensor<i32>, tensor<96xf64>>) -> tensor<96xf64>
-    %11 = stablehlo.constant dense<0> : tensor<i32>
-    %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<i32>) -> tensor<i32>
-    %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %15 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
-    %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<f64>) -> tensor<3x3xf64>
-    %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1>
-    %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf64>
-    %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1xi1>
-    %20 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
-    %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<f64>) -> tensor<3xf64>
-    %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1>
-    %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf64>
-    %24 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
-    %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor<f64>) -> tensor<3x3xf64>
-    %26 = stablehlo.constant dense<1> : tensor<i32>
-    %27 = stablehlo.constant dense<3> : tensor<i32>
-    %28 = stablehlo.constant dense<3> : tensor<i32>
-    %29 = stablehlo.constant dense<3> : tensor<i32>
-    %30 = stablehlo.constant dense<96> : tensor<i32>
-    %31 = stablehlo.custom_call @lapack_dorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xf64>, tensor<3xf64>) -> tuple<tensor<3x3xf64>, tensor<i32>, tensor<96xf64>>
-    %32 = stablehlo.get_tuple_element %31[0] : (tuple<tensor<3x3xf64>, tensor<i32>, tensor<96xf64>>) -> tensor<3x3xf64>
-    %33 = stablehlo.get_tuple_element %31[1] : (tuple<tensor<3x3xf64>, tensor<i32>, tensor<96xf64>>) -> tensor<i32>
-    %34 = stablehlo.get_tuple_element %31[2] : (tuple<tensor<3x3xf64>, tensor<i32>, tensor<96xf64>>) -> tensor<96xf64>
-    %35 = stablehlo.constant dense<0> : tensor<i32>
-    %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor<i32>) -> tensor<i32>
-    %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %39 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
-    %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor<f64>) -> tensor<3x3xf64>
-    %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1>
-    %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf64>
-    %43 = call @triu(%18) : (tensor<3x3xf64>) -> tensor<3x3xf64>
-    return %42, %43 : tensor<3x3xf64>, tensor<3x3xf64>
-  }
-  func.func private @triu(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> {
-    %0 = stablehlo.iota dim = 0 : tensor<3x3xi32>
-    %1 = stablehlo.constant dense<-1> : tensor<i32>
-    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<3x3xi32>
-    %3 = stablehlo.add %0, %2 : tensor<3x3xi32>
-    %4 = stablehlo.iota dim = 1 : tensor<3x3xi32>
-    %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1>
-    %6 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
-    %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<f64>) -> tensor<3x3xf64>
-    %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf64>
-    return %8 : tensor<3x3xf64>
-  }
-}
-""",
-
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xce\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x0b\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x
06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf\x00lapack_dorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeqrf', 'lapack_cungqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. 
+0.j, 0.91287076+0.j, 0.4082487 +0.j],
-       [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j],
-       [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]],
-      dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j],
-       [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j],
-       [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]],
-      dtype=complex64)),
-    mlir_module_text=r"""
-module @jit__lambda_ {
-  func.func public @main() -> (tensor<3x3xcomplex<f32>> {jax.result_info = "[0]"}, tensor<3x3xcomplex<f32>> {jax.result_info = "[1]"}) {
-    %0 = stablehlo.iota dim = 0 : tensor<9xcomplex<f32>>
-    %1 = stablehlo.reshape %0 : (tensor<9xcomplex<f32>>) -> tensor<3x3xcomplex<f32>>
-    %2 = stablehlo.constant dense<1> : tensor<i32>
-    %3 = stablehlo.constant dense<3> : tensor<i32>
-    %4 = stablehlo.constant dense<3> : tensor<i32>
-    %5 = stablehlo.constant dense<96> : tensor<i32>
-    %6 = stablehlo.custom_call @lapack_cgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xcomplex<f32>>) -> tuple<tensor<3x3xcomplex<f32>>, tensor<3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>
-    %7 = stablehlo.get_tuple_element %6[0] : (tuple<tensor<3x3xcomplex<f32>>, tensor<3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>) -> tensor<3x3xcomplex<f32>>
-    %8 = stablehlo.get_tuple_element %6[1] : (tuple<tensor<3x3xcomplex<f32>>, tensor<3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>) -> tensor<3xcomplex<f32>>
-    %9 = stablehlo.get_tuple_element %6[2] : (tuple<tensor<3x3xcomplex<f32>>, tensor<3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>) -> tensor<i32>
-    %10 = stablehlo.get_tuple_element %6[3] : (tuple<tensor<3x3xcomplex<f32>>, tensor<3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>) -> tensor<96xcomplex<f32>>
-    %11 = stablehlo.constant dense<0> : tensor<i32>
-    %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<i32>) -> tensor<i32>
-    %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %15 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>>
-    %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<complex<f32>>) -> tensor<3x3xcomplex<f32>>
-    %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1>
-    %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex<f32>>
-    %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1xi1>
-    %20 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>>
-    %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<complex<f32>>) -> tensor<3xcomplex<f32>>
-    %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1>
-    %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex<f32>>
-    %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
-    %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex<f32>>, tensor<complex<f32>>) -> tensor<3x3xcomplex<f32>>
-    %26 = stablehlo.constant dense<1> : tensor<i32>
-    %27 = stablehlo.constant dense<3> : tensor<i32>
-    %28 = stablehlo.constant dense<3> : tensor<i32>
-    %29 = stablehlo.constant dense<3> : tensor<i32>
-    %30 = stablehlo.constant dense<96> : tensor<i32>
-    %31 = stablehlo.custom_call @lapack_cungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<3x3xcomplex<f32>>, tensor<3xcomplex<f32>>) -> tuple<tensor<3x3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>
-    %32 = stablehlo.get_tuple_element %31[0] : (tuple<tensor<3x3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>) -> tensor<3x3xcomplex<f32>>
-    %33 = stablehlo.get_tuple_element %31[1] : (tuple<tensor<3x3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>) -> tensor<i32>
-    %34 = stablehlo.get_tuple_element %31[2] : (tuple<tensor<3x3xcomplex<f32>>, tensor<i32>, tensor<96xcomplex<f32>>>) -> tensor<96xcomplex<f32>>
-    %35 = stablehlo.constant dense<0> : tensor<i32>
-    %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor<i32>) -> tensor<i32>
-    %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
-    %39 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>>
-    %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor<complex<f32>>) -> tensor<3x3xcomplex<f32>>
-    %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1>
-    %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex<f32>>
-    %43 = call @triu(%18) : (tensor<3x3xcomplex<f32>>) -> tensor<3x3xcomplex<f32>>
-    return %42, %43 : tensor<3x3xcomplex<f32>>, tensor<3x3xcomplex<f32>>
-  }
-  func.func private @triu(%arg0: tensor<3x3xcomplex<f32>>) -> tensor<3x3xcomplex<f32>> {
-    %0 = stablehlo.iota dim = 0 : tensor<3x3xi32>
-    %1 = stablehlo.constant dense<-1> : tensor<i32>
-    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<3x3xi32>
-    %3 = stablehlo.add %0, %2 : tensor<3x3xi32>
-    %4 = stablehlo.iota dim = 1 : tensor<3x3xi32>
-    %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1>
-    %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
-    %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<complex<f32>>) -> tensor<3x3xcomplex<f32>>
-    %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex<f32>>
-    return %8 : tensor<3x3xcomplex<f32>>
-  }
-}
-""",
-
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xd6\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\t\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x
03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xce\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8bW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf\x00lapack_cungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeqrf', 'lapack_zungqr'], - serialized_date=datetime.date(2023, 3, 17), + custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], + serialized_date=datetime.date(2025, 4, 2), inputs=(), expected_outputs=(array([[ 0. 
+0.j, 0.9128709291752773 +0.j, 0.40824829046386235+0.j], @@ -283,531 +37,199 @@ [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, -8.8817841970012523e-16+0.j]])), mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3x3xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_zgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xcomplex>) -> tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %8 = stablehlo.get_tuple_element %6[1] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3xcomplex> - %9 = stablehlo.get_tuple_element %6[2] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor>) -> tensor<3xcomplex> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex> - %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_zungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, 
dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> tuple>, tensor, tensor<96xcomplex>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %33 = stablehlo.get_tuple_element %31[1] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex> - %43 = call @triu(%18) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> - return %42, %43 : tensor<3x3xcomplex>, tensor<3x3xcomplex> - } - func.func private @triu(%arg0: tensor<3x3xcomplex>) -> tensor<3x3xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> - return %8 : tensor<3x3xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0bOO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x16\n\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x0b\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17
\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xd2\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8dW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf\x00lapack_zungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -data_2024_08_22 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), - inputs=(), - expected_outputs=( - array([ - [0.0 + 0.0j, 0.9128709291752773 + 0.0j, 0.40824829046386235 + 0.0j], - [ - -0.447213595499958 - 0.0j, - 0.3651483716701102 + 0.0j, - -0.8164965809277263 + 0.0j, - ], - [ - -0.894427190999916 - 0.0j, - -0.1825741858350548 + 0.0j, - 0.40824829046386324 + 0.0j, - ], - ]), - array([ - [ - -6.7082039324993694e00 + 0.0j, - -8.0498447189992444e00 + 0.0j, - -9.3914855054991175e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 
0.0j, - 1.0954451150103341e00 + 0.0j, - 2.1908902300206665e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 0.0j, - 0.0000000000000000e00 + 0.0j, - -8.8817841970012523e-16 + 0.0j, - ], - ]), - ), - mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_zungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> 
tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_zungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0bOO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xf2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x0
0\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x0b\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xaa\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 
0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf_ffi\x00lapack_zungqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1fO\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xd2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\x0b)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_zgeqrf_ffi\x00lapack_zungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c64'] = dict( +data_2025_04_02['c64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 
2), inputs=(), - expected_outputs=( - array( - [ - [0.0 + 0.0j, 0.91287076 + 0.0j, 0.4082487 + 0.0j], - [-0.44721356 - 0.0j, 0.36514866 + 0.0j, -0.8164965 + 0.0j], - [-0.8944271 - 0.0j, -0.18257445 + 0.0j, 0.40824816 + 0.0j], - ], - dtype=complex64, - ), - array( - [ - [ - -6.7082043e00 + 0.0j, - -8.0498438e00 + 0.0j, - -9.3914852e00 + 0.0j, - ], - [0.0000000e00 + 0.0j, 1.0954441e00 + 0.0j, 2.1908894e00 + 0.0j], - [ - 0.0000000e00 + 0.0j, - 0.0000000e00 + 0.0j, - 7.1525574e-07 + 0.0j, - ], - ], - dtype=complex64, - ), - ), + expected_outputs=(array([[ 0. +0.j, 0.91287076+0.j, 0.4082487 +0.j], + [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j], + [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]], + dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j], + [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]], + dtype=complex64)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_cungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], 
result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_cungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant 
dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xb2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\t\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03
\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xa6\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8bW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf_ffi\x00lapack_cungqr\x00callee\x00' - ), + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xb2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\t)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_cgeqrf_ffi\x00lapack_cungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f32'] = dict( +data_2025_04_02['f32'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array( - [ - [0.0, 0.91287076, 0.4082487], - [-0.44721356, 0.36514866, -0.8164965], - [-0.8944271, -0.18257445, 0.40824816], - ], - dtype=float32, - ), - array( - [ - [-6.7082043e00, -8.0498438e00, 
-9.3914852e00], - [0.0000000e00, 1.0954441e00, 2.1908894e00], - [0.0000000e00, 0.0000000e00, 7.1525574e-07], - ], - dtype=float32, - ), - ), + expected_outputs=(array([[ 0. , 0.91287076, 0.4082487 ], + [-0.44721356, 0.36514866, -0.8164965 ], + [-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00], + [ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00], + [ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "result[0]"}, tensor<3x3xf32> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_sorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<96xf32>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) 
- %cst_10 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc10) - return %10, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) + %4 = stablehlo.custom_call @lapack_sorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc13) + return %4, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc16) - return %6 : tensor<3x3xf32> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b\x1fO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\x8a\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\t\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05
\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\t\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 
0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf_ffi\x00lapack_sorgqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\x9a\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\t\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\t\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_sgeqrf_ffi\x00lapack_sorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f64'] = dict( +data_2025_04_02['f64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array([ - [0.0, 
0.9128709291752773, 0.40824829046386235], - [-0.447213595499958, 0.3651483716701102, -0.8164965809277263], - [-0.894427190999916, -0.1825741858350548, 0.40824829046386324], - ]), - array([ - [ - -6.7082039324993694e00, - -8.0498447189992444e00, - -9.3914855054991175e00, - ], - [ - 0.0000000000000000e00, - 1.0954451150103341e00, - 2.1908902300206665e00, - ], - [ - 0.0000000000000000e00, - 0.0000000000000000e00, - -8.8817841970012523e-16, - ], - ]), - ), + expected_outputs=(array([[ 0. , 0.9128709291752773 , 0.40824829046386235], + [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ], + [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00, + -9.3914855054991175e+00], + [ 0.0000000000000000e+00, 1.0954451150103341e+00, + 2.1908902300206665e+00], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -8.8817841970012523e-16]])), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "result[0]"}, tensor<3x3xf64> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_dorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} 
: (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> (tensor<3x3xf64>, tensor, tensor<96xf64>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf64>) -> tensor<3x3xf64> loc(#loc10) - return %10, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc7) + %4 = stablehlo.custom_call @lapack_dorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>, tensor<3xf64>) -> tensor<3x3xf64> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc13) + return %4, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc16) - return %6 : 
tensor<3x3xf64> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xaa\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x0b\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\
x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf_ffi\x00lapack_dorgqr\x00callee\x00' - ), + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\xaa\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x0b\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_dgeqrf_ffi\x00lapack_dorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py index 2d71308caeda..995847a03a60 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py @@ -17,435 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_06_19 = {} - - - -# Pasted from the test output (see back_compat_test.py module docstring) 
-data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.5410905 , -2.775912 , -2.374003 , 4.028736 ], - [-0.56933475, 1.6115232 , 0.9041465 , -0.8321383 ], - [-5.382895 , 4.734856 , 2.1972926 , 1.5553856 ], - [ 0.5109847 , -1.1969309 , 3.3766198 , -1.3678027 ]], - - [[ 2.2637439 , 3.406768 , 4.809871 , 2.8010902 ], - [-1.9981416 , -0.6599986 , 0.5138156 , 4.5982494 ], - [-2.335944 , -9.151717 , -1.0481138 , 2.272443 ], - [-8.257684 , 1.8223318 , 0.38403794, 5.0769973 ]]], - dtype=float32),), - expected_outputs=(array([[[-0.48540133 , 0.6682397 , -0.48819906 , -0.28196266 ], - [ 0.2180054 , -0.13631375 , 0.14819765 , -0.95495003 ], - [ 0.8457052 , 0.44643915 , -0.27943406 , 0.08597418 ], - [ 0.040523227, -0.57928085 , -0.8133977 , -0.03429017 ]], - - [[-0.21146733 , 0.46376425 , 0.786309 , 0.34917438 ], - [ 0.3461469 , 0.21883713 , 0.3399653 , -0.84659094 ], - [ 0.6526192 , -0.5834038 , 0.3972404 , 0.2755518 ], - [ 0.6399631 , 0.6298203 , -0.32915345 , 0.2922879 ]]], - dtype=float32), array([[ 8.551608 , 5.3574076, 2.8073738, 0.5226082], - [11.457576 , 10.041606 , 5.6716514, 1.4754109]], dtype=float32), array([[[-0.6319046 , 0.6612254 , 0.39110154 , -0.102553196], - [-0.2971051 , 0.13673358 , -0.50112 , 0.80119365 ], - [ 0.08969147 , 0.4433047 , -0.73647296 , -0.5030348 ], - [-0.7101976 , -0.5895471 , -0.23135659 , -0.30745354 ]], - - [[-0.6964344 , -0.5023085 , -0.11150039 , 0.50023323 ], - [-0.32121164 , 0.7889568 , 0.3183193 , 0.41598475 ], - [ 0.5096958 , -0.31399378 , 0.60193455 , 0.5284816 ], - [-0.3898877 , -0.16322286 , 0.7238198 , -0.5453721 ]]], - dtype=float32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf32> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf32> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xf32> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_sgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>, tensor<32xi32>, tensor<268xf32>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = 
stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b/\x1fOo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xb6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\t\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\t)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x
07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 0.3445689867809981 , 3.5114993759427104 , - 4.702602090972179 , -0.2702264758497052 ], - [ 2.209901632583705 , -2.6286702510632773 , - 4.591276599385847 , 3.4465035398844828 ], - [-1.5083742421154478 , 3.3225165204269635 , - 1.2596205557926703 , 3.524804355848018 ], - [ 1.5118969169108838 , 1.838885943509677 , - 2.818520751293422 , 3.06002540493494 ]], - - [[-2.4045510943950843 , -1.5657555633438576 , - -0.6061472334580296 , -0.23926156407779164], - [ 4.087879920053448 , -3.2507640936811715 , - -2.2556577657517476 , 6.090369998330348 ], - [ 1.1165401344486945 , 2.2134726894037247 , - 5.225178515435584 , 1.9794693474107725 ], - [-4.127878192684534 , -0.37313660200336163, - 0.7893465897510026 , -2.0315217791342848 ]]]),), - expected_outputs=(array([[[-0.5109626909166218 , -0.41744996156105785, - -0.731253241567692 , 0.1729779025790829 ], - [-0.5623501368035175 , 0.7608931604238581 , - 0.03470920608540986, 0.32186828528169453], - [-0.39585755254587435, -0.4954770291405409 , - 0.6561880513437818 , 0.4089212062978684 ], - [-0.5157288533916834 , -0.03577207859388855, - 0.18297871183094833, -0.8362194085221047 ]], - - [[-0.12124821978030875, -0.30260506534356213, - -0.5817463045715607 , -0.7451847292758064 ], - [ 0.8877417367326685 , -0.15794001239879188, - -0.3761180739267688 , 
0.2133184375808915 ], - [ 0.03055221675864994, 0.9244545314395409 , - -0.3686107533067095 , -0.09260936183071355], - [-0.44303503260363514, -0.16990864078317836, - -0.619864940232637 , 0.624994775612963 ]]]), array([[8.951386926411189 , 5.762891699811626 , 3.839104008889441 , - 1.2696468971033248 ], - [9.21500688857692 , 6.477297670883227 , 3.24626945855818 , - 0.05112101994354587]]), array([[[-0.17890276924244797 , -0.2881812520705063 , - -0.7749616998111006 , -0.5332726590950898 ], - [ 0.38712159387038353 , -0.8985113987184378 , - 0.1397618670046424 , 0.15258033445914954 ], - [-0.23140697924040152 , -0.03708202130554661 , - -0.5045854966104308 , 0.8309447696839614 ], - [-0.8744034999217865 , -0.32901938548360005 , - 0.35396957633060866 , -0.043246992182741084]], - - [[ 0.6276106632546885 , -0.26728735347872895 , - -0.22995258718774078 , 0.6941067163520401 ], - [ 0.2802931697592562 , 0.4781137804659157 , - 0.808362569504731 , 0.19847646746808023 ], - [ 0.6187014005224262 , 0.47714095343944474 , - -0.3740686697560633 , -0.49961757159793246 ], - [-0.3804591585793503 , 0.6872417290515944 , - -0.3921025301835001 , 0.47875384105714014 ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf64> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf64> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xf64> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_dgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>, tensor<32xi32>, tensor<268xf64>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant 
dense<0x7FF8000000000000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b//Oo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xc6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\x0b)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03
\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.6052934 +0.45878917j, 4.587192 -4.5177283j , - 0.4177733 -1.9419309j , -2.2248359 -4.5042715j ], - [-7.083374 -8.127356j , 2.7596245 -4.991001j , - -0.52622825+5.033981j , -0.35441273-1.8215327j ], - [-0.7996552 -2.4052901j , -0.8506142 -3.164714j , - -0.3090829 +2.2020447j , 1.2367196 +2.8830793j ], - [ 1.4633094 -0.5451007j , -3.7833478 +6.6770763j , - -3.1279542 -2.2322626j , -2.1099617 -2.9661314j ]], - - [[ 1.2560439 -5.4743752j , -2.0085676 +2.0063214j , - -0.8132642 -3.4407883j , -0.17360081+0.6419895j ], - [ 2.3756726 +6.3315964j , -0.31447247-1.9387872j , - 4.6732006 -4.286903j , 1.7702469 -1.4957623j ], - [ 1.6918924 -0.52161306j, 0.49963537+4.7751374j , - -1.9243752 -4.5870543j , 2.8829405 +1.7382988j ], - [ 1.4884951 -0.44194785j, -1.3645276 -2.8733373j , - -0.39430943+2.4366508j , -0.76268387+5.2014065j ]]], - dtype=complex64),), - expected_outputs=(array([[[ 0.016725361+0.19210356j , 0.5452691 +0.5572638j , - 0.41363996 +0.18964858j , -0.26152334 -0.28195143j ], - [ 0.53678626 +0.64057267j , -0.21783225 -0.21288812j , - 0.28426644 +0.30535883j , 0.15201284 +0.10768581j ], - [ 0.21286921 +0.154735j , 0.066471666-0.25652882j , - -0.4074613 -0.10356682j , -0.11794163 -0.81844836j ], - [-0.39079374 -0.20583564j , -0.18335931 -0.4421772j , - 0.63489586 +0.19758748j , 0.038680226-0.36351213j ]], - - [[-0.3178596 +0.39032036j , -0.1273337 -0.30841744j , - 0.26394194 +0.26815224j , -0.21332254 -0.66947937j ], - [-0.39241245 -0.60790956j , -0.14006221 +0.41040683j , - -0.0830612 -0.10184447j , -0.45091942 -0.2603987j ], - [-0.36103728 +0.2876153j , -0.4965461 +0.10084368j , - -0.13752826 -0.6203828j , 0.35439825 -0.028546419j], - [ 0.062335093-0.078214265j, 0.35014474 -0.5668197j , - -0.42214075 -0.5090833j , -0.2889288 -0.15894148j ]]], - dtype=complex64), array([[15.135655 , 9.373035 
, 7.444931 , 0.41523397], - [12.316969 , 8.661011 , 5.005059 , 2.115905 ]], - dtype=float32), array([[[-0.6537865 +0.j , -0.20306697 -0.6166746j , - 0.29948467 +0.24257992j , -0.007604365+0.04945353j ], - [ 0.52712685 +0.j , -0.11291563 -0.7116954j , - -0.089219 -0.36348897j , -0.23654723 -0.08269388j ], - [-0.31538543 +0.j , -0.014410622+0.15958191j , - -0.17958623 -0.13690898j , -0.6930434 -0.58613425j ], - [-0.44185135 +0.j , 0.17604677 -0.050492246j, - -0.4213856 -0.69485146j , 0.22373371 +0.2465445j ]], - - [[-0.64551586 +0.j , 0.32932255 -0.11672116j , - -0.093527466+0.6710145j , -0.038554154+0.02716677j ], - [ 0.4241116 +0.j , 0.031135002-0.539813j , - -0.26271763 +0.22760014j , -0.63609654 -0.04817467j ], - [-0.4577485 +0.j , -0.15202768 +0.2734652j , - 0.18931003 -0.3297506j , -0.7331101 -0.10269702j ], - [ 0.44034657 +0.j , 0.29474002 +0.63307834j , - 0.31271848 +0.4216674j , -0.20595454 -0.020532424j]]], - dtype=complex64)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_cgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf32>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 
= stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b/\x1fO/o\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x1e\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\xc0\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\t)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\
x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgesdd\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[-0.9247611722912019-1.3615157109291343j , - -1.0663457975211892+4.73170030936092j , - -1.4918732811689488-2.880861991859318j , - -1.111356346434667 -2.869701609083459j ], - [-4.71291623424314 -1.5444012898828912j , - -5.232967549101415 -0.41287816948482003j, - 0.8905737109262459+9.50245186328329j , - 4.397722119094926 -6.842005210371916j ], - [ 1.9369405063276903+2.3496014107398917j , - -1.5609345742256133+4.2102103739897805j , - 0.6596030248996742+5.195353435247212j , - 0.6315014498240328-1.2778849649354402j ], - [ 5.115159214503849 -0.8856276268773485j , - 1.3719934567460779-2.236070491368575j , - 0.4974504006612811-3.0462081956756637j , - -0.2620346712025989+4.424682727912594j ]], - - [[-1.8242711798401063-0.8543252170262536j , - -2.724527211360488 +2.256038331706666j , - -1.2777487543905157+0.976556823566376j , - 3.7438974536713223-0.4994301527847589j ], - [-0.6359051102028691+2.730662301129662j , - -1.2877728943263032+3.9124921723649053j , - -3.4618573226579894+1.7835551986994034j , - -1.4710491660152465+2.144967500163963j ], - [-3.6013691182532828+2.8182351980619034j , - 2.0045935428878803+1.1146211993017152j , - -2.332213857689336 -0.874915651404938j , - -1.5393862406530452+0.6852883119580928j ], - [-2.674897392856801 +2.0724239502976984j , - -3.349108041292141 -1.0215359152295307j , - 0.2603515088197114-1.9093411474619364j , - 5.41252457188561 +8.634368042893094j ]]]),), - expected_outputs=(array([[[-0.04173678258633362+0.10796693731538423j , - 0.6813428383170976 +0.34327979589293334j , - -0.41770229002865755+0.20028957850808823j , - -0.43443513665085287+0.034743251442636465j], - [-0.8408468609573512 -0.1326064604464803j , - -0.21674151028481228+0.015170556885426551j, - 0.17147327711152338+0.1531041615298256j , - -0.3568765623609291 +0.21904384306708768j ], - [-0.2673618144044136 +0.1379833616281103j , - 
-0.17534278352558025-0.378992615769627j , - -0.8179957069096054 -0.037506032257391624j, - 0.25392637883428526-0.009771014463849802j], - [ 0.40569239968065934-0.08297706578106905j , - -0.4321527034953765 +0.09791545663574397j , - -0.23439193826962654-0.08427130532228161j , - -0.42348296145608866+0.6251448114949291j ]], - - [[ 0.0272684373986653 +0.36312055550335454j , - 0.270297713559288 +0.1304616587162563j , - 0.04286867013923673-0.4765859417602139j , - 0.7242702256119968 +0.15420620503522459j ], - [-0.08593436615104483+0.1189990183325552j , - 0.37050286109355285-0.6240865462984536j , - 0.46902056878806025-0.34747949920770266j , - -0.31667671459632074-0.10340064369932994j ], - [-0.07914843440873574-0.033487314943774035j, - 0.4110353453489128 -0.455090805566563j , - -0.431131803930273 +0.40910871949632j , - 0.13782730102420274+0.49428280062680086j ], - [-0.7478497242333215 +0.5283836938016964j , - -0.08345894989956631+0.011807690067190268j, - -0.27178304569905287+0.056526279406748176j, - -0.09911954913441999-0.2598859654000683j ]]]), array([[16.80132997488892 , 7.744755614558116 , 5.831221808032041 , - 1.1195288361137765], - [12.39537594694893 , 8.218551160453814 , 4.683634850274079 , - 1.8820915363839188]]), array([[[ 0.35796251040556704 +0.j , - 0.40179383774178046 -0.1269359716702074j , - -0.0751486661300563 -0.6109813931761136j , - -0.23049271148274278 +0.51209309438597j ], - [-0.4682861415308549 +0.j , - -0.013958972669495105+0.4210606476774211j , - -0.6006888466394119 -0.3766516564723718j , - -0.24264518623237025 -0.20408557153193485j ], - [-0.6392945524816095 +0.j , - 0.2432388607602898 -0.6679928485374246j , - 0.18168178910997038 -0.08126854868489754j , - -0.2030612067046724 -0.07124733621915219j ], - [-0.49383540371426055 +0.j , - -0.010402968929686592+0.3734624991410737j , - 0.27994282704104956 +0.01949406216762731j , - 0.32588905219319236 +0.6569569657140543j ]], - - [[ 0.2666920370516844 +0.j , - 0.24929033811571413 +0.27271089049933883j , - -0.012922512768026735+0.16383354123801513j , - 0.07388201893235022 -0.8717175469187741j ], - [-0.6156140469162428 +0.j , - -0.33787077397020143 +0.37797154650923376j , - -0.3916043058726119 -0.2839601305776179j , - -0.2714888604157674 -0.23729034093304682j ], - [ 0.5618758038857617 +0.j , - -0.5788776267734554 -0.13833058883452312j , - -0.48995086206819644 +0.19259594116096765j , - -0.22967101640965012 -0.012926826751577613j], - [-0.48393210641613593 +0.j , - -0.1049229605428438 -0.4911419972025977j , - -0.07782239226461217 +0.6751317817750165j , - 0.11941657609231515 -0.19354808489959852j ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_zgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : 
tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf64>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b//OOo\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02N\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\x0b)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(ma
in)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd\x00', - xla_call_module_version=6, -) # End paste - data_2024_08_13 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_13["c128"] = dict( testdata_version=1, diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 4275d8e48813..f8b9a023b480 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -17,39 +17,3 @@ from . import eig as eig def initialize() -> None: ... def registrations() -> dict: ... - - -# Old-style LAPACK Workspace Size Queries -def cgesdd_rwork_size(m: int, n: int, compute_uv: int) -> int: ... -def cgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def dgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def gesdd_iwork_size(m: int, n: int) -> int: ... -def heevd_rwork_size(n: int) -> int: ... -def heevd_work_size(n: int) -> int: ... -def lapack_cgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_cgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_chetrd_workspace(lda: int, n: int) -> int: ... -def lapack_cungqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_dgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_dorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dsytrd_workspace(lda: int, n: int) -> int: ... -def lapack_sgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_sgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_sorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_ssytrd_workspace(lda: int, n: int) -> int: ... -def lapack_zgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_zgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_zhetrd_workspace(lda: int, n: int) -> int: ... -def lapack_zungqr_workspace(m: int, n: int, k: int) -> int: ... -def sgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def syevd_iwork_size(n: int) -> int: ... -def syevd_work_size(n: int) -> int: ... -def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... - - -# FFI Kernel LAPACK Workspace Size Queries -def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... 
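[Editorial aside, not part of the patch.] All of the workspace-size stubs deleted above wrap one LAPACK convention: calling a routine with lwork = -1 performs a workspace query, writing the optimal lwork into work[0] instead of computing a factorization; the C++ bindings removed below (e.g. Geqrf<T>::Workspace) do exactly this. A minimal standalone sketch of that convention follows — the Fortran-style dgeqrf_ prototype is assumed to come from whatever LAPACK library is linked in (e.g. -llapack), and nothing here is jaxlib API:

    #include <cstdio>

    // Conventional Fortran LAPACK symbol; assumed to be provided by the
    // linked LAPACK library. Not a jaxlib function.
    extern "C" void dgeqrf_(const int* m, const int* n, double* a,
                            const int* lda, double* tau, double* work,
                            const int* lwork, int* info);

    int main() {
      const int m = 128, n = 64;
      double query = 0.0;
      int lwork = -1;  // lwork == -1 asks dgeqrf for its optimal workspace.
      int info = 0;
      // In query mode the matrix arguments are not referenced, so null
      // pointers are fine for a and tau; the optimal size lands in `query`.
      dgeqrf_(&m, &n, nullptr, &m, nullptr, &query, &lwork, &info);
      if (info != 0) return 1;
      std::printf("optimal lwork for %dx%d dgeqrf: %d\n", m, n,
                  static_cast<int>(query));
      return 0;
    }

The four *_workspace_ffi queries deleted at the end of the stub file were thin wrappers over the same idea for the FFI kernels; after this patch no workspace query is exposed to Python at all.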
diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc
index 6ed42496f2f2..a118c20a4490 100644
--- a/jaxlib/cpu/cpu_kernels.cc
+++ b/jaxlib/cpu/cpu_kernels.cc
@@ -42,70 +42,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ctrsm",
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ztrsm",
                                          Trsm<std::complex<double>>::Kernel,
                                          "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgetrf", Getrf<float>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgetrf", Getrf<double>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgetrf",
-                                         Getrf<std::complex<float>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgetrf",
-                                         Getrf<std::complex<double>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeqrf", Geqrf<float>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeqrf", Geqrf<double>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgeqrf",
-                                         Geqrf<std::complex<float>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgeqrf",
-                                         Geqrf<std::complex<double>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sorgqr", Orgqr<float>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dorgqr", Orgqr<double>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cungqr",
-                                         Orgqr<std::complex<float>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zungqr",
-                                         Orgqr<std::complex<double>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_spotrf", Potrf<float>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dpotrf", Potrf<double>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cpotrf",
-                                         Potrf<std::complex<float>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zpotrf",
-                                         Potrf<std::complex<double>>::Kernel,
-                                         "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgesdd",
-                                         RealGesdd<float>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgesdd",
-                                         RealGesdd<double>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-    "lapack_cgesdd", ComplexGesdd<std::complex<float>>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-    "lapack_zgesdd", ComplexGesdd<std::complex<double>>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_ssyevd",
-                                         RealSyevd<float>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dsyevd",
-                                         RealSyevd<double>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-    "lapack_cheevd", ComplexHeevd<std::complex<float>>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-    "lapack_zheevd", ComplexHeevd<std::complex<double>>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeev",
-                                         RealGeev<float>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeev",
-                                         RealGeev<double>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-    "lapack_cgeev", ComplexGeev<std::complex<float>>::Kernel, "Host");
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-    "lapack_zgeev", ComplexGeev<std::complex<double>>::Kernel, "Host");
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgees",
                                          RealGees<float>::Kernel, "Host");
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgees",
diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc
index 7cc4fa9e2dbd..1bb3f1f13405 100644
--- a/jaxlib/cpu/lapack.cc
+++ b/jaxlib/cpu/lapack.cc
@@ -58,19 +58,11 @@ void GetLapackKernelsFromScipy() {
   auto lapack_ptr = [&](const char* name) {
     return nb::cast<nb::capsule>(lapack_capi[name]).data();
   };
-  AssignKernelFn<Getrf<float>>(lapack_ptr("sgetrf"));
-  AssignKernelFn<Getrf<double>>(lapack_ptr("dgetrf"));
-  AssignKernelFn<Getrf<std::complex<float>>>(lapack_ptr("cgetrf"));
-  AssignKernelFn<Getrf<std::complex<double>>>(lapack_ptr("zgetrf"));
   AssignKernelFn<LuDecomposition<ffi::DataType::F32>>(lapack_ptr("sgetrf"));
   AssignKernelFn<LuDecomposition<ffi::DataType::F64>>(lapack_ptr("dgetrf"));
   AssignKernelFn<LuDecomposition<ffi::DataType::C64>>(lapack_ptr("cgetrf"));
   AssignKernelFn<LuDecomposition<ffi::DataType::C128>>(lapack_ptr("zgetrf"));
-  AssignKernelFn<Geqrf<float>>(lapack_ptr("sgeqrf"));
-  AssignKernelFn<Geqrf<double>>(lapack_ptr("dgeqrf"));
-  AssignKernelFn<Geqrf<std::complex<float>>>(lapack_ptr("cgeqrf"));
-  AssignKernelFn<Geqrf<std::complex<double>>>(lapack_ptr("zgeqrf"));
   AssignKernelFn<QrFactorization<ffi::DataType::F32>>(lapack_ptr("sgeqrf"));
   AssignKernelFn<QrFactorization<ffi::DataType::F64>>(lapack_ptr("dgeqrf"));
   AssignKernelFn<QrFactorization<ffi::DataType::C64>>(lapack_ptr("cgeqrf"));
@@ -85,28 +77,16 @@ void GetLapackKernelsFromScipy() {
   AssignKernelFn<PivotingQrFactorization<ffi::DataType::C128>>(
       lapack_ptr("zgeqp3"));
-  AssignKernelFn<Orgqr<float>>(lapack_ptr("sorgqr"));
-  AssignKernelFn<Orgqr<double>>(lapack_ptr("dorgqr"));
-  AssignKernelFn<Orgqr<std::complex<float>>>(lapack_ptr("cungqr"));
-  AssignKernelFn<Orgqr<std::complex<double>>>(lapack_ptr("zungqr"));
   AssignKernelFn<OrthogonalQr<ffi::DataType::F32>>(lapack_ptr("sorgqr"));
   AssignKernelFn<OrthogonalQr<ffi::DataType::F64>>(lapack_ptr("dorgqr"));
   AssignKernelFn<OrthogonalQr<ffi::DataType::C64>>(lapack_ptr("cungqr"));
   AssignKernelFn<OrthogonalQr<ffi::DataType::C128>>(lapack_ptr("zungqr"));
-  AssignKernelFn<Potrf<float>>(lapack_ptr("spotrf"));
-  AssignKernelFn<Potrf<double>>(lapack_ptr("dpotrf"));
-  AssignKernelFn<Potrf<std::complex<float>>>(lapack_ptr("cpotrf"));
-  AssignKernelFn<Potrf<std::complex<double>>>(lapack_ptr("zpotrf"));
   AssignKernelFn<CholeskyFactorization<ffi::DataType::F32>>(lapack_ptr("spotrf"));
   AssignKernelFn<CholeskyFactorization<ffi::DataType::F64>>(lapack_ptr("dpotrf"));
   AssignKernelFn<CholeskyFactorization<ffi::DataType::C64>>(lapack_ptr("cpotrf"));
   AssignKernelFn<CholeskyFactorization<ffi::DataType::C128>>(lapack_ptr("zpotrf"));
-  AssignKernelFn<RealGesdd<float>>(lapack_ptr("sgesdd"));
-  AssignKernelFn<RealGesdd<double>>(lapack_ptr("dgesdd"));
-  AssignKernelFn<ComplexGesdd<std::complex<float>>>(lapack_ptr("cgesdd"));
-  AssignKernelFn<ComplexGesdd<std::complex<double>>>(lapack_ptr("zgesdd"));
   AssignKernelFn<SingularValueDecomposition<ffi::DataType::F32>>(lapack_ptr("sgesdd"));
   AssignKernelFn<SingularValueDecomposition<ffi::DataType::F64>>(lapack_ptr("dgesdd"));
   AssignKernelFn<SingularValueDecompositionComplex<ffi::DataType::C64>>(lapack_ptr("cgesdd"));
@@ -116,10 +96,6 @@ void GetLapackKernelsFromScipy() {
   AssignKernelFn<SingularValueDecompositionQRComplex<ffi::DataType::C64>>(lapack_ptr("cgesvd"));
   AssignKernelFn<SingularValueDecompositionQRComplex<ffi::DataType::C128>>(lapack_ptr("zgesvd"));
-  AssignKernelFn<RealSyevd<float>>(lapack_ptr("ssyevd"));
-  AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
-  AssignKernelFn<ComplexHeevd<std::complex<float>>>(lapack_ptr("cheevd"));
-  AssignKernelFn<ComplexHeevd<std::complex<double>>>(lapack_ptr("zheevd"));
   AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F32>>(
       lapack_ptr("ssyevd"));
   AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F64>>(
@@ -129,10 +105,6 @@ void GetLapackKernelsFromScipy() {
   AssignKernelFn<EigenvalueDecompositionHermitian<ffi::DataType::C128>>(
       lapack_ptr("zheevd"));
-  AssignKernelFn<RealGeev<float>>(lapack_ptr("sgeev"));
-  AssignKernelFn<RealGeev<double>>(lapack_ptr("dgeev"));
-  AssignKernelFn<ComplexGeev<std::complex<float>>>(lapack_ptr("cgeev"));
-  AssignKernelFn<ComplexGeev<std::complex<double>>>(lapack_ptr("zgeev"));
   AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F32>>(lapack_ptr("sgeev"));
   AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F64>>(lapack_ptr("dgeev"));
   AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C64>>(
@@ -151,10 +123,6 @@ void GetLapackKernelsFromScipy() {
   AssignKernelFn<SchurDecompositionComplex<ffi::DataType::C128>>(
       lapack_ptr("zgees"));
-  AssignKernelFn<Gehrd<float>>(lapack_ptr("sgehrd"));
-  AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
-  AssignKernelFn<Gehrd<std::complex<float>>>(lapack_ptr("cgehrd"));
-  AssignKernelFn<Gehrd<std::complex<double>>>(lapack_ptr("zgehrd"));
   AssignKernelFn<HessenbergDecomposition<ffi::DataType::F32>>(
       lapack_ptr("sgehrd"));
   AssignKernelFn<HessenbergDecomposition<ffi::DataType::F64>>(
@@ -186,63 +154,12 @@ nb::dict Registrations() {
   dict["blas_dtrsm"] = EncapsulateFunction(Trsm<double>::Kernel);
   dict["blas_ctrsm"] = EncapsulateFunction(Trsm<std::complex<float>>::Kernel);
   dict["blas_ztrsm"] = EncapsulateFunction(Trsm<std::complex<double>>::Kernel);
-  dict["lapack_sgetrf"] = EncapsulateFunction(Getrf<float>::Kernel);
-  dict["lapack_dgetrf"] = EncapsulateFunction(Getrf<double>::Kernel);
-  dict["lapack_cgetrf"] =
-      EncapsulateFunction(Getrf<std::complex<float>>::Kernel);
-  dict["lapack_zgetrf"] =
-      EncapsulateFunction(Getrf<std::complex<double>>::Kernel);
-  dict["lapack_sgeqrf"] = EncapsulateFunction(Geqrf<float>::Kernel);
-  dict["lapack_dgeqrf"] = EncapsulateFunction(Geqrf<double>::Kernel);
-  dict["lapack_cgeqrf"] =
-      EncapsulateFunction(Geqrf<std::complex<float>>::Kernel);
-  dict["lapack_zgeqrf"] =
-      EncapsulateFunction(Geqrf<std::complex<double>>::Kernel);
-  dict["lapack_sorgqr"] = EncapsulateFunction(Orgqr<float>::Kernel);
-  dict["lapack_dorgqr"] = EncapsulateFunction(Orgqr<double>::Kernel);
-  dict["lapack_cungqr"] =
-      EncapsulateFunction(Orgqr<std::complex<float>>::Kernel);
-  dict["lapack_zungqr"] =
-      EncapsulateFunction(Orgqr<std::complex<double>>::Kernel);
-  dict["lapack_spotrf"] = EncapsulateFunction(Potrf<float>::Kernel);
-  dict["lapack_dpotrf"] = EncapsulateFunction(Potrf<double>::Kernel);
-  dict["lapack_cpotrf"] =
-      EncapsulateFunction(Potrf<std::complex<float>>::Kernel);
-  dict["lapack_zpotrf"] =
-      EncapsulateFunction(Potrf<std::complex<double>>::Kernel);
-  dict["lapack_sgesdd"] = EncapsulateFunction(RealGesdd<float>::Kernel);
-  dict["lapack_dgesdd"] = EncapsulateFunction(RealGesdd<double>::Kernel);
-  dict["lapack_cgesdd"] =
-      EncapsulateFunction(ComplexGesdd<std::complex<float>>::Kernel);
-  dict["lapack_zgesdd"] =
-      EncapsulateFunction(ComplexGesdd<std::complex<double>>::Kernel);
-  dict["lapack_ssyevd"] = EncapsulateFunction(RealSyevd<float>::Kernel);
-  dict["lapack_dsyevd"] = EncapsulateFunction(RealSyevd<double>::Kernel);
-  dict["lapack_cheevd"] =
-      EncapsulateFunction(ComplexHeevd<std::complex<float>>::Kernel);
-  dict["lapack_zheevd"] =
-      EncapsulateFunction(ComplexHeevd<std::complex<double>>::Kernel);
-  dict["lapack_sgeev"] = EncapsulateFunction(RealGeev<float>::Kernel);
-  dict["lapack_dgeev"] = EncapsulateFunction(RealGeev<double>::Kernel);
-  dict["lapack_cgeev"] =
-      EncapsulateFunction(ComplexGeev<std::complex<float>>::Kernel);
-  dict["lapack_zgeev"] =
-      EncapsulateFunction(ComplexGeev<std::complex<double>>::Kernel);
-
   dict["lapack_sgees"] = EncapsulateFunction(RealGees<float>::Kernel);
   dict["lapack_dgees"] = EncapsulateFunction(RealGees<double>::Kernel);
   dict["lapack_cgees"] =
       EncapsulateFunction(ComplexGees<std::complex<float>>::Kernel);
   dict["lapack_zgees"] =
       EncapsulateFunction(ComplexGees<std::complex<double>>::Kernel);
-
-  dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd<float>::Kernel);
-  dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd<double>::Kernel);
-  dict["lapack_cgehrd"] =
-      EncapsulateFunction(Gehrd<std::complex<float>>::Kernel);
-  dict["lapack_zgehrd"] =
-      EncapsulateFunction(Gehrd<std::complex<double>>::Kernel);
-
   dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd<float>::Kernel);
   dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd<double>::Kernel);
   dict["lapack_chetrd"] =
@@ -335,73 +252,6 @@ NB_MODULE(_lapack, m) {
   nb::enum_<schur::Sort>(schur, "Sort")
       .value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues)
       .value("kSortEigenvalues", schur::Sort::kSortEigenvalues);
-
-  // Old-style LAPACK Workspace Size Queries
-  m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
-        nb::arg("n"));
-  m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace, nb::arg("m"),
-        nb::arg("n"));
-  m.def("lapack_cgeqrf_workspace", &Geqrf<std::complex<float>>::Workspace,
-        nb::arg("m"), nb::arg("n"));
-  m.def("lapack_zgeqrf_workspace", &Geqrf<std::complex<double>>::Workspace,
-        nb::arg("m"), nb::arg("n"));
-  m.def("lapack_sorgqr_workspace", &Orgqr<float>::Workspace, nb::arg("m"),
-        nb::arg("n"), nb::arg("k"));
-  m.def("lapack_dorgqr_workspace", &Orgqr<double>::Workspace, nb::arg("m"),
-        nb::arg("n"), nb::arg("k"));
-  m.def("lapack_cungqr_workspace", &Orgqr<std::complex<float>>::Workspace,
-        nb::arg("m"), nb::arg("n"), nb::arg("k"));
-  m.def("lapack_zungqr_workspace", &Orgqr<std::complex<double>>::Workspace,
-        nb::arg("m"), nb::arg("n"), nb::arg("k"));
-  m.def("gesdd_iwork_size", &GesddIworkSize, nb::arg("m"), nb::arg("n"));
-  m.def("sgesdd_work_size", &RealGesdd<float>::Workspace, nb::arg("m"),
-        nb::arg("n"), nb::arg("job_opt_compute_uv"),
-        nb::arg("job_opt_full_matrices"));
-  m.def("dgesdd_work_size", &RealGesdd<double>::Workspace, nb::arg("m"),
-        nb::arg("n"), nb::arg("job_opt_compute_uv"),
-        nb::arg("job_opt_full_matrices"));
-  m.def("cgesdd_rwork_size", &ComplexGesddRworkSize, nb::arg("m"), nb::arg("n"),
-        nb::arg("compute_uv"));
-  m.def("cgesdd_work_size", &ComplexGesdd<std::complex<float>>::Workspace,
-        nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"),
-        nb::arg("job_opt_full_matrices"));
-  m.def("zgesdd_work_size", &ComplexGesdd<std::complex<double>>::Workspace,
-        nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"),
-        nb::arg("job_opt_full_matrices"));
-  m.def("syevd_work_size", &SyevdWorkSize, nb::arg("n"));
-  m.def("syevd_iwork_size", &SyevdIworkSize, nb::arg("n"));
-  m.def("heevd_work_size", &HeevdWorkSize, nb::arg("n"));
-  m.def("heevd_rwork_size", &HeevdRworkSize, nb::arg("n"));
-
-  m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace, nb::arg("lda"),
-        nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
-  m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace, nb::arg("lda"),
-        nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
-  m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace,
-        nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
-  m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace,
-        nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
-  m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace, nb::arg("lda"),
-        nb::arg("n"));
-  m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace, nb::arg("lda"),
-        nb::arg("n"));
-  m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace,
-        nb::arg("lda"), nb::arg("n"));
-  m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace,
-        nb::arg("lda"), nb::arg("n"));
-  // FFI Kernel LAPACK Workspace Size Queries
-  m.def("lapack_sorgqr_workspace_ffi",
-        &OrthogonalQr<ffi::DataType::F32>::GetWorkspaceSize, nb::arg("m"),
-        nb::arg("n"), nb::arg("k"));
-  m.def("lapack_dorgqr_workspace_ffi",
-        &OrthogonalQr<ffi::DataType::F64>::GetWorkspaceSize, nb::arg("m"),
-        nb::arg("n"), nb::arg("k"));
-  m.def("lapack_cungqr_workspace_ffi",
-        &OrthogonalQr<ffi::DataType::C64>::GetWorkspaceSize, nb::arg("m"),
-        nb::arg("n"), nb::arg("k"));
-  m.def("lapack_zungqr_workspace_ffi",
-        &OrthogonalQr<ffi::DataType::C128>::GetWorkspaceSize, nb::arg("m"),
-        nb::arg("n"), nb::arg("k"));
 }
 
 }  // namespace
diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc
index ddc93261eeb5..3b510708a8bb 100644
--- a/jaxlib/cpu/lapack_kernels.cc
+++ b/jaxlib/cpu/lapack_kernels.cc
@@ -149,8 +149,7 @@ ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
     ffi::Buffer<dtype> x, ffi::Buffer<dtype> y,
     // TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after
     // the release of JAX 0.5.2.
-    ffi::RemainingArgs,
-    ffi::ResultBuffer<dtype> y_out, MatrixParams::Side side,
+    ffi::RemainingArgs, ffi::ResultBuffer<dtype> y_out, MatrixParams::Side side,
     MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x,
     MatrixParams::Diag diag) {
   CopyIfDiffBuffer(y, y_out);
@@ -189,42 +188,6 @@ template struct TriMatrixEquationSolver<ffi::DataType::C128>;
 
 //== LU Decomposition ==//
 
-// lapack getrf
-
-template <typename T>
-typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;
-
-template <typename T>
-void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
-  int b = *(reinterpret_cast<int32_t*>(data[0]));
-  int m = *(reinterpret_cast<int32_t*>(data[1]));
-  int n = *(reinterpret_cast<int32_t*>(data[2]));
-  const T* a_in = reinterpret_cast<T*>(data[3]);
-
-  void** out = reinterpret_cast<void**>(out_tuple);
-  T* a_out = reinterpret_cast<T*>(out[0]);
-  int* ipiv = reinterpret_cast<int*>(out[1]);
-  int* info = reinterpret_cast<int*>(out[2]);
-  if (a_out != a_in) {
-    std::memcpy(a_out, a_in,
-                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
-                    static_cast<int64_t>(n) * sizeof(T));
-  }
-  for (int i = 0; i < b; ++i) {
-    fn(&m, &n, a_out, &m, ipiv, info);
-    a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
-    ipiv += std::min(m, n);
-    ++info;
-  }
-}
-
-template struct Getrf<float>;
-template struct Getrf<double>;
-template struct Getrf<std::complex<float>>;
-template struct Getrf<std::complex<double>>;
-
-// FFI Kernel
-
 template <ffi::DataType dtype>
 ffi::Error LuDecomposition<dtype>::Kernel(
     ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
@@ -261,55 +224,6 @@ template struct LuDecomposition<ffi::DataType::C128>;
 
 //== QR Factorization ==//
 
-// lapack geqrf
-
-template <typename T>
-typename Geqrf<T>::FnType* Geqrf<T>::fn = nullptr;
-
-template <typename T>
-void Geqrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
-  int b = *(reinterpret_cast<int32_t*>(data[0]));
-  int m = *(reinterpret_cast<int32_t*>(data[1]));
-  int n = *(reinterpret_cast<int32_t*>(data[2]));
-  int lwork = *(reinterpret_cast<int32_t*>(data[3]));
-  const T* a_in = reinterpret_cast<T*>(data[4]);
-
-  void** out = reinterpret_cast<void**>(out_tuple);
-  T* a_out = reinterpret_cast<T*>(out[0]);
reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += std::min(m, n); - ++info; - } -} - -template -int64_t Geqrf::Workspace(lapack_int m, lapack_int n) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&m, &n, nullptr, &m, nullptr, &work, &lwork, &info); - return info == 0 ? static_cast(std::real(work)) : -1; -} - -template struct Geqrf; -template struct Geqrf; -template struct Geqrf>; -template struct Geqrf>; - -// FFI Kernel - template ffi::Error QrFactorization::Kernel(ffi::Buffer x, ffi::ResultBuffer x_out, @@ -430,56 +344,6 @@ template struct PivotingQrFactorization; //== Orthogonal QR ==// //== Computes orthogonal matrix Q from QR Decomposition ==// -// lapack orgqr - -template -typename Orgqr::FnType* Orgqr::fn = nullptr; - -template -void Orgqr::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - int k = *(reinterpret_cast(data[3])); - int lwork = *(reinterpret_cast(data[4])); - const T* a_in = reinterpret_cast(data[5]); - T* tau = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - T* work = reinterpret_cast(out[2]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, &k, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += k; - ++info; - } -} - -template -int64_t Orgqr::Workspace(int m, int n, int k) { - T work = 0; - int lwork = -1; - int info = 0; - fn(&m, &n, &k, nullptr, &m, nullptr, &work, &lwork, &info); - return info ? -1 : static_cast(std::real(work)); -} - -template struct Orgqr; -template struct Orgqr; -template struct Orgqr>; -template struct Orgqr>; - -// FFI Kernel - template ffi::Error OrthogonalQr::Kernel(ffi::Buffer x, ffi::Buffer tau, @@ -535,42 +399,6 @@ template struct OrthogonalQr; //== Cholesky Factorization ==// -// lapack potrf - -template -typename Potrf::FnType* Potrf::fn = nullptr; - -template -void Potrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - char uplo = lower ? 
'L' : 'U'; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&uplo, &n, a_out, &n, info); - a_out += static_cast(n) * static_cast(n); - ++info; - } -} - -template struct Potrf; -template struct Potrf; -template struct Potrf>; -template struct Potrf>; - -// FFI Kernel - template ffi::Error CholeskyFactorization::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, @@ -604,162 +432,6 @@ template struct CholeskyFactorization; //== Singular Value Decomposition (SVD) ==// //== using a divide and conquer method ==// -// lapack gesdd - -static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { - if (!job_opt_compute_uv) { - return 'N'; - } else if (!job_opt_full_matrices) { - return 'S'; - } - return 'A'; -} - -lapack_int GesddIworkSize(int64_t m, int64_t n) { - return CastNoOverflow(8 * std::min(m, n), "gesdd iwork"); -} - -template -typename RealGesdd::FnType* RealGesdd::fn = nullptr; - -template -void RealGesdd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - T* work = reinterpret_cast(out[6]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, - info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t RealGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, &info); - return info ? 
-1 : static_cast(work); -} - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) { - int64_t mn = std::min(m, n); - if (compute_uv == 0) { - return CastNoOverflow(7 * mn, "complex gesdd rwork"); - } - int64_t mx = std::max(m, n); - return CastNoOverflow( - std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn), - "complex gesdd rwork"); -} - -template -typename ComplexGesdd::FnType* ComplexGesdd::fn = nullptr; - -template -void ComplexGesdd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - typename T::value_type* rwork = - reinterpret_cast(out[6]); - T* work = reinterpret_cast(out[7]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, - iwork, info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t ComplexGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, nullptr, &info); - return info ? 
-1 : static_cast(work.real()); -} - -template struct RealGesdd; -template struct RealGesdd; -template struct ComplexGesdd>; -template struct ComplexGesdd>; - -// FFI Kernel - namespace internal { template @@ -949,16 +621,16 @@ static ffi::Error SvdQRKernel( for (int64_t i = 0; i < batch_count; ++i) { if constexpr (ffi::IsComplexType()) { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, rwork.get(), - info_data); + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, + x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, + vt_data, &vt_leading_dim_v, work_data.get(), + &workspace_dim_v, rwork.get(), info_data); } else { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, info_data); + svd::SVDQRType::fn( + &mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, vt_data, + &vt_leading_dim_v, work_data.get(), &workspace_dim_v, info_data); } x_out_data += x_out_step; singular_values_data += singular_values_step; @@ -970,9 +642,8 @@ static ffi::Error SvdQRKernel( } template -static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode) { +static absl::StatusOr SvdQRGetWorkspaceSize( + lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { ffi::NativeType optimal_size = {}; lapack_int info = 0; lapack_int workspace_query = -1; @@ -994,7 +665,8 @@ static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, &u_leading_dim_v, nullptr, &vt_leading_dim_v, &optimal_size, &workspace_query, &info); } - return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) : -1; + return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) + : -1; } } // namespace internal @@ -1053,7 +725,8 @@ ffi::Error SingularValueDecompositionQRComplex::Kernel( } template -absl::StatusOr SingularValueDecompositionQR::GetWorkspaceSize( +absl::StatusOr +SingularValueDecompositionQR::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdQRGetWorkspaceSize(x_rows, x_cols, mode); } @@ -1077,7 +750,8 @@ absl::StatusOr svd::GetRealWorkspaceSize( 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim)); } -absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols) { +absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols) { return CastNoOverflow(5 * std::min(x_rows, x_cols)); } @@ -1098,109 +772,6 @@ template struct SingularValueDecompositionQRComplex; //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -// # Workspace sizes, taken from the LAPACK documentation. 
-lapack_int SyevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 6 * n + 2 * n * n, "syevd lwork"); -} - -lapack_int SyevdIworkSize(int64_t n) { - return CastNoOverflow(3 + 5 * n, "syevd iwork"); -} - -template -typename RealSyevd::FnType* RealSyevd::fn = nullptr; - -template -void RealSyevd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* w_out = reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - int* iwork = reinterpret_cast(out[4]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 'L' : 'U'; - - lapack_int lwork = SyevdWorkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork, - info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -// Workspace sizes, taken from the LAPACK documentation. -lapack_int HeevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 2 * n + n * n, "heevd work"); -} - -lapack_int HeevdRworkSize(int64_t n) { - return CastNoOverflow(1 + 5 * n + 2 * n * n, "heevd rwork"); -} - -template -typename ComplexHeevd::FnType* ComplexHeevd::fn = nullptr; - -template -void ComplexHeevd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* w_out = - reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - typename T::value_type* rwork = - reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 
'L' : 'U'; - - lapack_int lwork = HeevdWorkSize(n); - lapack_int lrwork = HeevdRworkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork, iwork, - &liwork, info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -template struct RealSyevd; -template struct RealSyevd; -template struct ComplexHeevd>; -template struct ComplexHeevd>; - -// FFI Kernel - absl::StatusOr eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { switch (mode) { @@ -1339,155 +910,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// lapack geev - -template -typename RealGeev::FnType* RealGeev::fn = nullptr; - -template -void RealGeev::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - T* vl_work = reinterpret_cast(out[1]); - T* vr_work = reinterpret_cast(out[2]); - - T* wr_out = reinterpret_cast(out[3]); - T* wi_out = reinterpret_cast(out[4]); - std::complex* vl_out = reinterpret_cast*>(out[5]); - std::complex* vr_out = reinterpret_cast*>(out[6]); - int* info_out = reinterpret_cast(out[7]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int, - vr_work, &n_int, &work_query, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - if (!std::isfinite(a_work[j * n + k])) { - return false; - } - } - } - return true; - }; - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, - &n_int, vr_work, &n_int, work, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - if (info_out[0] == 0) { - UnpackEigenvectors(n, wi_out, vl_work, vl_out); - UnpackEigenvectors(n, wi_out, vr_work, vr_out); - } - } else { - *info_out = -4; - } - a_in += n * n; - wr_out += n; - wi_out += n; - vl_out += n * n; - vr_out += n * n; - ++info_out; - } - delete[] work; -} - -template -typename ComplexGeev::FnType* ComplexGeev::fn = nullptr; - -template -void ComplexGeev::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - typename T::value_type* 
r_work = - reinterpret_cast(out[1]); - - T* w_out = reinterpret_cast(out[2]); - T* vl_out = reinterpret_cast(out[3]); - T* vr_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, &work_query, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query.real()); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - T v = a_work[j * n + k]; - if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) { - return false; - } - } - } - return true; - }; - - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, work, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - } else { - *info_out = -4; - } - a_in += n * n; - w_out += n; - vl_out += n * n; - vr_out += n * n; - info_out += 1; - } - delete[] work; -} - -template struct RealGeev; -template struct RealGeev; -template struct ComplexGeev>; -template struct ComplexGeev>; - -// FFI Kernel - template ffi::Error EigenvalueDecomposition::Kernel( ffi::Buffer x, eig::ComputationMode compute_left, @@ -1968,60 +1390,6 @@ template struct SchurDecompositionComplex; //== Hessenberg Decomposition ==// -// lapack gehrd - -template -typename Gehrd::FnType* Gehrd::fn = nullptr; - -template -void Gehrd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t n = *reinterpret_cast(data[0]); - int32_t ilo = *reinterpret_cast(data[1]); - int32_t ihi = *reinterpret_cast(data[2]); - int32_t lda = *reinterpret_cast(data[3]); - int32_t batch = *reinterpret_cast(data[4]); - int32_t lwork = *reinterpret_cast(data[5]); - T* a = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a) { - std::memcpy(a_out, a, - static_cast(batch) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - int64_t a_plus = static_cast(lda) * static_cast(n); - - for (int i = 0; i < batch; ++i) { - fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info); - a_out += a_plus; - tau += n - 1; - ++info; - } -} - -template -int64_t Gehrd::Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info); - return info == 0 ? 
static_cast(std::real(work)) : -1; -} - -template struct Gehrd; -template struct Gehrd; -template struct Gehrd>; -template struct Gehrd>; - -// FFI Kernel - template ffi::Error HessenbergDecomposition::Kernel( ffi::Buffer x, lapack_int low, lapack_int high, diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index e075ff29387f..71ba8b8a5e0c 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -154,19 +154,6 @@ struct TriMatrixEquationSolver { //== LU Decomposition ==// -// lapack getrf - -template -struct Getrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - lapack_int* ipiv, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct LuDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -182,21 +169,6 @@ struct LuDecomposition { //== QR Factorization ==// -// lapack geqrf - -template -struct Geqrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - T* tau, T* work, lapack_int* lwork, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct QrFactorization { using ValueType = ::xla::ffi::NativeType; @@ -240,23 +212,8 @@ struct PivotingQrFactorization { static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; - //== Orthogonal QR ==// -// lapack orgqr - -template -struct Orgqr { - using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct OrthogonalQr { using ValueType = ::xla::ffi::NativeType; @@ -276,16 +233,6 @@ struct OrthogonalQr { //== Cholesky Factorization ==// -// lapack potrf - -template -struct Potrf { - using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - template <::xla::ffi::DataType dtype> struct CholeskyFactorization { using ValueType = ::xla::ffi::NativeType; @@ -302,41 +249,6 @@ struct CholeskyFactorization { //== Singular Value Decomposition (SVD) ==// -// lapack gesdd - -lapack_int GesddIworkSize(int64_t m, int64_t n); - -template -struct RealGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, T* s, T* u, lapack_int* ldu, T* vt, - lapack_int* ldvt, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv); - -template -struct ComplexGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* s, T* u, - lapack_int* ldu, T* vt, lapack_int* ldvt, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, 
XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct SingularValueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -407,8 +319,8 @@ struct SingularValueDecompositionQR { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -432,8 +344,8 @@ struct SingularValueDecompositionQRComplex { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; namespace svd { @@ -451,42 +363,13 @@ using SVDQRType = std::conditional_t<::xla::ffi::IsComplexType(), absl::StatusOr GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); absl::StatusOr GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, ComputationMode mode); -absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols); +absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols); } // namespace svd //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -lapack_int SyevdWorkSize(int64_t n); -lapack_int SyevdIworkSize(int64_t n); - -template -struct RealSyevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, T* w, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* liwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -lapack_int HeevdWorkSize(int64_t n); -lapack_int HeevdRworkSize(int64_t n); - -template -struct ComplexHeevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* w, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* lrwork, lapack_int* iwork, lapack_int* liwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - namespace eig { // Eigenvalue Decomposition @@ -544,8 +427,6 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::ResultBuffer info, eig::ComputationMode mode); }; -// lapack geev - // LAPACK uses a packed representation to represent a mixture of real // eigenvectors and complex conjugate pairs. This helper unpacks the // representation into regular complex matrices. 
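 // For example, when *geev reports the j-th and (j+1)-st eigenvalues as a
 // complex conjugate pair, packed column j holds Re(v) and column j+1 holds
 // Im(v), so the unpacked eigenvectors are col_j + i*col_{j+1} and its
 // conjugate col_j - i*col_{j+1}.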
@@ -574,28 +455,6 @@ static void UnpackEigenvectors(Int n, const T* eigenvals_imag, const T* packed, } } -template -struct RealGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* wr, T* wi, T* vl, lapack_int* ldvl, - T* vr, lapack_int* ldvr, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -template -struct ComplexGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* w, T* vl, lapack_int* ldvl, T* vr, - lapack_int* ldvr, T* work, lapack_int* lwork, - typename T::value_type* rwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct EigenvalueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -737,32 +596,6 @@ struct SchurDecompositionComplex { //== Hessenberg Decomposition ==// //== Reduces a non-symmetric square matrix to upper Hessenberg form ==// -// lapack gehrd - -template -struct Gehrd { - using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi); -}; - -template -struct real_type { - typedef T type; -}; -template -struct real_type> { - typedef T type; -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct HessenbergDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -785,6 +618,15 @@ struct HessenbergDecomposition { //== Tridiagonal Reduction ==// //== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// +template +struct real_type { + typedef T type; +}; +template +struct real_type> { + typedef T type; +}; + // lapack sytrd/hetrd template diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 3c8ddf11cf29..e771aa0e37d1 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -118,114 +118,6 @@ static_assert( std::is_same_v::FnType, jax::Trsm>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - 
std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); static_assert( std::is_same_v::FnType, jax::Sytrd::FnType>, @@ -258,22 +150,6 @@ static_assert( std::is_same_v::FnType, jax::ComplexGees>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); #undef JAX_KERNEL_FNTYPE_MISMATCH_MSG @@ -283,51 +159,11 @@ static auto init = []() -> int { AssignKernelFn>>(ctrsm_); AssignKernelFn>>(ztrsm_); - AssignKernelFn>(sgetrf_); - AssignKernelFn>(dgetrf_); - AssignKernelFn>>(cgetrf_); - AssignKernelFn>>(zgetrf_); - - AssignKernelFn>(sgeqrf_); - AssignKernelFn>(dgeqrf_); - AssignKernelFn>>(cgeqrf_); - AssignKernelFn>>(zgeqrf_); - - AssignKernelFn>(sorgqr_); - AssignKernelFn>(dorgqr_); - AssignKernelFn>>(cungqr_); - AssignKernelFn>>(zungqr_); - - AssignKernelFn>(spotrf_); - AssignKernelFn>(dpotrf_); - AssignKernelFn>>(cpotrf_); - AssignKernelFn>>(zpotrf_); - - AssignKernelFn>(sgesdd_); - AssignKernelFn>(dgesdd_); - AssignKernelFn>>(cgesdd_); - AssignKernelFn>>(zgesdd_); - - AssignKernelFn>(ssyevd_); - AssignKernelFn>(dsyevd_); - AssignKernelFn>>(cheevd_); - AssignKernelFn>>(zheevd_); - - AssignKernelFn>(sgeev_); - AssignKernelFn>(dgeev_); - AssignKernelFn>>(cgeev_); - AssignKernelFn>>(zgeev_); - AssignKernelFn>(sgees_); 
AssignKernelFn>(dgees_); AssignKernelFn>>(cgees_); AssignKernelFn>>(zgees_); - AssignKernelFn>(sgehrd_); - AssignKernelFn>(dgehrd_); - AssignKernelFn>>(cgehrd_); - AssignKernelFn>>(zgehrd_); - AssignKernelFn>(ssytrd_); AssignKernelFn>(dsytrd_); AssignKernelFn>>(chetrd_); diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index fd2b349f6c95..888b234e94c0 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -120,7 +120,7 @@ def test_custom_call_coverage(self): targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, - cpu_qr_lapack_geqrf.data_2024_08_22, + cpu_qr_lapack_geqrf.data_2025_04_02, cpu_eig_lapack_geev.data_2024_08_19, cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, @@ -134,12 +134,7 @@ def test_custom_call_coverage(self): # stable covering_testdatas = [ *cpu_ffi_testdatas, - cpu_cholesky_lapack_potrf.data_2023_06_19, - cpu_eig_lapack_geev.data_2023_06_19, - cpu_eigh_lapack_syev.data_2023_03_17, - cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2024_07_30, - cpu_lu_lapack_getrf.data_2023_06_14, cuda_lu_pivots_to_permutation.data_2025_04_01, cuda_lu_cusolver_getrf.data_2024_08_19, cuda_qr_cusolver_geqrf.data_2024_09_26, @@ -149,9 +144,7 @@ def test_custom_call_coverage(self): cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, - cpu_svd_lapack_gesdd.data_2023_06_19, cpu_triangular_solve_blas_trsm.data_2023_07_16, - cpu_hessenberg_lapack_gehrd.data_2024_08_30, cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, @@ -213,10 +206,6 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -277,10 +266,6 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) - data = self.load_testdata(cpu_eig_lapack_geev.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_eig_results, - expect_current_custom_calls=info["custom_call_targets"]) @staticmethod def eigh_input(shape, dtype): @@ -334,12 +319,6 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # Legacy custom call test - data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand), - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", dtype_name=dtype_name, variant=variant) @@ -419,7 +398,7 @@ def test_cuda_lu_pivots_to_permutation(self): dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in 
("f32", "f64", "c64", "c128")) - def test_cuda_lu_lapack_getrf(self, dtype_name:str): + def test_cuda_lu_cusolver_getrf(self, dtype_name:str): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") dtype = dict(f32=np.float32, f64=np.float64, @@ -446,15 +425,10 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): c64=np.complex64, c128=np.complex128)[dtype_name] func = lambda: CompatTest.qr_harness((3, 3), dtype) - info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + info = cpu_qr_lapack_geqrf.data_2025_04_02[dtype_name] data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol) - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -529,14 +503,6 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): check_results=partial(self.check_lu_results, operand, dtype=dtype)) - # TODO(b/357034884): Remove legacy custom call test after mid March 2025. - legacy_data = self.load_testdata( - cpu_lu_lapack_getrf.data_2023_06_14[dtype_name]) - self.run_one_test(func, legacy_data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, - dtype=dtype), - expect_current_custom_calls=info["custom_call_targets"]) - def check_svd_results(self, input, res_run, res_exp, rtol=None, atol=None): # Following linalg_test.testSVD @@ -655,12 +621,6 @@ def func(operand): check_results=partial(self.check_svd_results, *data.inputs)) - data = self.load_testdata(cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_svd_results, - *data.inputs), - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}", dtype_name=dtype_name, algorithm_name=algorithm_name) @@ -751,12 +711,6 @@ def func(): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata( - cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) From 7eb397d1e56de021bb7c7ce8bcdbabd2e74f6a2b Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 9 Apr 2025 14:59:23 +0200 Subject: [PATCH 0559/1769] Make `trace` and `lower` class attributes for `jax.jit`. Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use: ``` jax.jit(f).trace(...) ``` The new attributes create problems when `jax.jit` is used along `functools.wraps`. Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`. This works as expected if you just call the result, but if you try to use it with `lower` and `trace`, the `wrapper` is bypassed. 
This is because `wraps` copies the attributes `trace` and `lower` from
`jax.jit(f)` onto the resulting function, so when `trace` is invoked
the `wrapper` is bypassed entirely. See #27829 and #27825.

The solution proposed here is to make `trace` and `lower` class
attributes, so that they are not copied by `functools.wraps`. Thus, if
you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error. That is better
than silently ignoring the wrapper. The workaround is to apply
`jax.jit` last among your wrappers.

Fixes: #27829
---
 CHANGELOG.md               |  6 +++
 jax/_src/api.py            | 92 +++++++++++++++++++++-----------------
 jax/_src/pjit.py           | 86 +++++++++++++++++------------------
 tests/memories_test.py     |  2 +-
 tests/pjit_test.py         | 14 +++++-
 tests/xla_metadata_test.py |  4 +-
 6 files changed, 116 insertions(+), 88 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 00c96b228262..d118ef10f51f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     which was added temporarily in v0.4.36 to allow users to opt out of the
     new "stackless" tracing machinery.
   * Removed the `config.jax_eager_pmap` config option.
+  * Disallow calling the `lower` and `trace` AOT APIs on the result
+    of `jax.jit` if there have been subsequent wrappers applied.
+    Previously this worked, but silently ignored the wrappers.
+    The workaround is to apply `jax.jit` last among the wrappers,
+    and similarly for `jax.pmap`.
+    See {jax-issue}`#27873`.

 * Changes
   * The minimum CuDNN version is v9.8.

diff --git a/jax/_src/api.py b/jax/_src/api.py
index 2149510a2914..d170d2632d6f 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -1569,11 +1569,13 @@ def _cpp_pmap(
       out_axes)
   del static_broadcasted_argnums, donate_argnums

+  prepare_pmap_fn = partial(_prepare_pmap,
+      fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
+      devices, backend, axis_size)
+
   @api_boundary
   def cache_miss(*args, **kwargs):
-    p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-                      donate_tuple, devices, backend,
-                      axis_size, args, kwargs)
+    p = prepare_pmap_fn(args, kwargs)
     for arg in p.flat_args:
       dispatch.check_arg(arg)
@@ -1650,48 +1652,56 @@ def cache_miss(*args, **kwargs):
   _pmap_cache_clears.add(cpp_mapped_f)

   pmap_f = wraps(fun)(cpp_mapped_f)
+  # Store some data for the `lower` and `trace` methods
   pmap_f._fun = fun
+  pmap_f._prepare_pmap = prepare_pmap_fn
+  pmap_f._backend = backend
+  pmap_f._axis_name = axis_name
+  pmap_f._donate_tuple = donate_tuple
+
+  # TODO(necula): move these to top-level; we don't need to do this for
+  # every pmap
+  cpp_mapped_f_class = type(pmap_f)
+  cpp_mapped_f_class.lower = _cpp_mapped_lower
+  cpp_mapped_f_class.trace = _cpp_mapped_trace
+  # We return directly the function produced by pmap_lib.pmap, because we do not
+  # want to have Python in the dispatch path.
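+  # NOTE: `lower` and `trace` are assigned on the class above (not on the
+  # instance), so `functools.wraps` -- which copies `__dict__` but not class
+  # attributes -- can no longer silently forward them past an outer wrapper.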
+ return pmap_f - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() +@api_boundary +def _cpp_mapped_trace(pmap_f, *args, **kwargs): + p = pmap_f._prepare_pmap(args, kwargs) + abstract_args = list(map(shaped_abstractify, p.flat_args)) + closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( + p.flat_fun, pmap_f._backend, pmap_f._axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + out_axes_thunk=p.out_axes_thunk, + avals=abstract_args) + lower_callable = partial( + pxla.lower_parallel_callable, p.flat_fun, pmap_f._axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + donated_invars=p.donated_invars, + is_explicit_global_axis_size=p.is_explicit_global_axis_size, + avals=abstract_args, + closed_jaxpr=closed_jaxpr, + backend=xc_backend, + replicas=replicas, + shards=shards, + pci=pci) + args_info = stages.make_args_info(p.in_tree, abstract_args, pmap_f._donate_tuple) + return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, + p.out_tree(), lower_callable) - @api_boundary - def trace(*args, **kwargs): - p = _prepare_pmap( - fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - devices, backend, axis_size, args, kwargs) - abstract_args = list(map(shaped_abstractify, p.flat_args)) - closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( - p.flat_fun, backend, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - out_axes_thunk=p.out_axes_thunk, - avals=abstract_args) - lower_callable = partial( - pxla.lower_parallel_callable, p.flat_fun, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - donated_invars=p.donated_invars, - is_explicit_global_axis_size=p.is_explicit_global_axis_size, - avals=abstract_args, - closed_jaxpr=closed_jaxpr, - backend=xc_backend, - replicas=replicas, - shards=shards, - pci=pci) - args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) - return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, - p.out_tree(), lower_callable) - - pmap_f.lower = lower - pmap_f.trace = trace +@api_boundary +def _cpp_mapped_lower(pmap_f, *args, **kwargs): + return _cpp_mapped_trace(pmap_f, *args, **kwargs).lower() - return pmap_f _pmap_cache_clears = weakref.WeakSet() # type: ignore diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 83ddb8709d4c..871fb0a6c870 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -308,12 +308,6 @@ def _get_fastpath_data( return fastpath_data -def _cpp_pjit_evict_fn(self): - self._clear_cache() - _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error - _infer_params_cached.cache_clear() - - # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. 
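To see the `functools.wraps` semantics that the `api.py` and `pjit.py` changes above rely on, here is a minimal, self-contained sketch (`Jitted` and `wrapper` are illustrative names, not JAX internals):

```
import functools

class Jitted:
  def trace(self):                # class attribute: not copied by wraps
    return "traced"

jitted = Jitted()
jitted.lower = lambda: "lowered"  # instance attribute: lives in jitted.__dict__

def wrapper():
  pass

wrapped = functools.wraps(jitted)(wrapper)
assert hasattr(wrapped, "lower")      # __dict__ entries are copied onto wrapper
assert not hasattr(wrapped, "trace")  # class attributes are not copied
```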
@@ -374,9 +368,50 @@ def cache_miss(*args, **kwargs): cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun - type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn + cpp_pjitted_f._jit_info = jit_info + # TODO(necula): move these to top-level; we don't need to do this for + # every jit + cpp_jitted_f_class = type(cpp_pjitted_f) + # TODO(necula): make clear_cache private, no need to have it part of the API + cpp_jitted_f_class.clear_cache = jit_evict_fn + cpp_jitted_f_class.lower = jit_lower + cpp_jitted_f_class.trace = jit_trace + cpp_jitted_f_class.eval_shape = jit_eval_shape + # We return directly the function produced by _xla.pjit, because we do not + # want to have Python in the dispatch path. return cpp_pjitted_f +@api_boundary +def jit_trace(jit_func, *args, **kwargs) -> stages.Traced: + p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs) + donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) + args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) + lower_callable = partial(_resolve_and_lower, args_flat, **p.params, + pgle_profiler=None) + return stages.Traced( + p.params['jaxpr'], args_info, p.params["name"], p.out_tree, + lower_callable, args_flat, p.arg_names, p.num_consts) + + +@api_boundary +def jit_lower(jit_func, *args, **kwargs): + return jit_trace(jit_func, *args, **kwargs).lower() + +@api_boundary +def jit_eval_shape(jit_func, *args, **kwargs): + p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs) + out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] + # TODO(yashkatariya): Add `Layout` to SDS. + out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, + weak_type=x.weak_type) + for x, s in zip(p.params['jaxpr'].out_avals, out_s)] + return tree_unflatten(p.out_tree, out) + +def jit_evict_fn(self): + self._clear_cache() + _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error + _infer_params_cached.cache_clear() + def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) @@ -484,41 +519,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, use_resource_env=use_resource_env, compiler_options_kvs=compiler_options_kvs) - -def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): - - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() - - @api_boundary - def eval_shape(*args, **kwargs): - p, _ = _infer_params(fun, jit_info, args, kwargs) - out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] - # TODO(yashkatariya): Add `Layout` to SDS. 
- out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, - weak_type=x.weak_type) - for x, s in zip(p.params['jaxpr'].out_avals, out_s)] - return tree_unflatten(p.out_tree, out) - - @api_boundary - def trace(*args, **kwargs) -> stages.Traced: - p, args_flat = _infer_params(fun, jit_info, args, kwargs) - donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) - args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) - lower_callable = partial(_resolve_and_lower, args_flat, **p.params, - pgle_profiler=None) - return stages.Traced( - p.params['jaxpr'], args_info, p.params["name"], p.out_tree, - lower_callable, args_flat, p.arg_names, p.num_consts) - - wrapped = _cpp_pjit(fun, jit_info) - wrapped.lower = lower - wrapped.eval_shape = eval_shape - wrapped.trace = trace - return wrapped - - def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnums: int | Sequence[int] | None, donate_argnames: str | Iterable[str] | None, @@ -533,7 +533,7 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, keep_unused, inline, compiler_options, use_resource_env) - return _make_jit_wrapper(fun, jit_info) + return _cpp_pjit(fun, jit_info) class PjitParams(NamedTuple): diff --git a/tests/memories_test.py b/tests/memories_test.py index fca36570216f..203bc4abb613 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1501,8 +1501,8 @@ def test_mem_kind_donation_pinned_host(self): s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device') - @compute_on('device_host') @functools.partial(jax.jit, out_shardings=(s, s_dev), donate_argnums=(0, 1)) + @compute_on('device_host') def f(inp1, inp2): return inp1 * 2, inp2 * 2 diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2db75be18475..f90b02cbdaf7 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -14,7 +14,7 @@ from collections import OrderedDict, namedtuple import re -from functools import partial +from functools import partial, wraps import logging import json import math @@ -940,6 +940,18 @@ def testWithCustomPRNGKey(self): # Make sure this doesn't crash pjit(lambda x: x, in_shardings=None, out_shardings=None)(key) + def test_lower_with_wrapper_error(self): + @jax.jit + def f(x): + return x + + self.assertAllClose(1., f(1.)) + self.assertAllClose(1., f.lower(1.).compile()(1.)) + wrapped_f = wraps(f)(lambda x: f(x + 1)) + + with self.assertRaisesRegex(AttributeError, "has no attribute 'lower'"): + wrapped_f.lower(1.) 
+
   @jtu.with_mesh([('x', 2), ('y', 2)])
   def testLowerCompile(self):
     @partial(pjit,
diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py
index 33fd7a08b1de..ba2120fcb7b9 100644
--- a/tests/xla_metadata_test.py
+++ b/tests/xla_metadata_test.py
@@ -288,10 +288,10 @@ def f2(x, y):
       with set_xla_metadata(a="b"):
         return (x + y, y * 2.0)

-    f_vmap_jaxpr = jax.make_jaxpr(jax.vmap(f2, in_axes=(0, None)))
+    f2_vmap = jax.vmap(f2, in_axes=(0, None))
     self.assertIn(
         'mhlo.frontend_attributes = {a = "b"}',
-        f_vmap_jaxpr.lower(jnp.arange(5.0), 1.0).as_text(),
+        jax.jit(f2_vmap).lower(jnp.arange(5.0), 1.0).as_text(),
     )

   def test_multiple_instructions(self):

From 896557f07b133ba635245ec4091de25b80e88edd Mon Sep 17 00:00:00 2001
From: Henning Becker
Date: Fri, 11 Apr 2025 06:14:36 -0700
Subject: [PATCH 0560/1769] Register NVPTX LLVM backend from Mosaic custom call

So far Mosaic was implicitly relying on XLA to register the NVPTX
target, which caused problems in cases where only a Mosaic kernel gets
compiled and XLA didn't initialize the LLVM NVPTX target.

PiperOrigin-RevId: 746433654
---
 jaxlib/mosaic/gpu/BUILD          |  3 ++-
 jaxlib/mosaic/gpu/custom_call.cc | 21 ++++++++++++++++++---
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index bb27930a703c..0eb24781379e 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -133,11 +133,12 @@ cc_library(
     name = "custom_call",
     srcs = ["custom_call.cc"],
     deps = [
+        ":mosaic_gpu_comm",
         ":passes",
         ":target",
         "//jaxlib/cuda:cuda_vendor",
         "//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
-        "//jaxlib/mosaic/gpu:mosaic_gpu_comm",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index 38a388224d65..a933b72ad55a 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include
 #include

+#include "absl/base/call_once.h"
 #include "absl/base/optimization.h"
 #include "absl/cleanup/cleanup.h"
 #include "absl/container/flat_hash_map.h"
@@ -105,6 +106,16 @@ namespace ffi = xla::ffi;
 using MosaicInitFunc = void(void****);
 using MosaicHostFunc = void(void**);

+void EnsureLLVMNVPTXTargetIsRegistered() {
+  static absl::once_flag register_nvptx_target_flag;
+  absl::call_once(register_nvptx_target_flag, []() {
+    LLVMInitializeNVPTXTarget();
+    LLVMInitializeNVPTXTargetInfo();
+    LLVMInitializeNVPTXTargetMC();
+    LLVMInitializeNVPTXAsmPrinter();
+  });
+}
+
 absl::StatusOr> GetSmAndPtxIsaVersion() {
   // Assumes driver has been initialized and a context exists.
XLA already has // some utilities to query this, but we try to stay runtime-agnostic, so we @@ -123,13 +134,18 @@ absl::StatusOr> GetSmAndPtxIsaVersion() { device) != CUDA_SUCCESS) { return absl::InternalError("Failed to get minor compute capability"); } + EnsureLLVMNVPTXTargetIsRegistered(); return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor); } + mlir::FailureOr GetPassPipeline( mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) { - static bool register_once = []() { + static absl::once_flag register_passes_flag; + absl::call_once(register_passes_flag, []() { + EnsureLLVMNVPTXTargetIsRegistered(); + llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -157,8 +173,7 @@ mlir::FailureOr GetPassPipeline( mosaic::gpu::registerByvalInsertionPass(); mlir::arith::registerArithExpandOpsPass(); return true; - }(); - (void)register_once; + }); return mlir::parsePassPipeline(absl::StrCat( R"( builtin.module( From b49972d1ce9ad46dc3348af697d9af962b80ba14 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Apr 2025 06:18:39 -0700 Subject: [PATCH 0561/1769] Move test skip for unary_ops_accuracy_test to a setUp method. The skip decorator being used here only worked for test methods, not test classes, so it accidentally had the effect of skipping all the tests. But we don't really need a special decorator here anyway. PiperOrigin-RevId: 746434607 --- jax/_src/test_util.py | 14 -------------- tests/unary_ops_accuracy_test.py | 10 +++++++--- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c493d8297f21..3fff55d9ed1c 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -601,20 +601,6 @@ def skip(test_method): return skip -def skip_if_stablehlo_version_less_than(required_version): - def skip(test_method): - @functools.wraps(test_method) - def test_method_wrapper(self, *args, **kwargs): - if not stablehlo_version_at_least(required_version): - plugin_version = xla_bridge.backend_stablehlo_version() - raise unittest.SkipTest( - f"Skipping since test requires StableHLO v{required_version}, and plugin" - f" version is v{plugin_version}.") - return test_method(self, *args, **kwargs) - return test_method_wrapper - return skip - - def format_test_name_suffix(opname, shapes, dtypes): arg_descriptions = (format_shape_dtype_string(shape, dtype) for shape, dtype in zip(shapes, dtypes)) diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py index e74b3a2d9669..6f51651af687 100644 --- a/tests/unary_ops_accuracy_test.py +++ b/tests/unary_ops_accuracy_test.py @@ -15,7 +15,6 @@ """Unit test for result accuracy for unary ops.""" from typing import Any, Callable, NamedTuple, Union -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -170,10 +169,15 @@ def generate_test_cases(op_names): return test_cases -@unittest.skipIf(not jtu.is_device_tpu(), "Skipping test on non TPU devices.") -@jtu.skip_if_stablehlo_version_less_than("1.10.0") class UnaryOpsAccuracyTest(jtu.JaxTestCase): + def setUp(self): + if not jtu.stablehlo_version_at_least("1.10.0"): + self.skipTest("Test requires StableHLO v1.10.0 or higher.") + if not jtu.is_device_tpu(): + self.skipTest("Skipping test on non TPU devices.") + super().setUp() + def test_result_accuracy_mode_attr(self): with ir.Context() as context: hlo.register_dialect(context) From 
8082186fa7cf909c630351783277d48a32de3672 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 11 Apr 2025 06:48:54 -0700 Subject: [PATCH 0562/1769] Fix api_test on persistent cache enabled platform Follow-up from https://github.com/jax-ml/jax/pull/27916. jax-fixit PiperOrigin-RevId: 746442635 --- tests/api_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/api_test.py b/tests/api_test.py index ac1623f3beee..72623279192f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4527,6 +4527,9 @@ def f(x, y): @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_skip_internals(self): + if is_persistent_cache_enabled(): + self.skipTest('With persistent cache, we see the cache misses') + with config.explain_cache_misses(True): with self.assertNoLogs(level='WARNING'): for i in range(2): From 614ef37ce7a07691ed36950089457382d160bc86 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Apr 2025 06:50:23 -0700 Subject: [PATCH 0563/1769] Fix test flakiness in tpu_pallas_test when JAX_TEST_NUM_THREADS > 1. stdout redirection is inherently racy; mark test cases doing it as thread unsafe. PiperOrigin-RevId: 746443039 --- tests/pallas/tpu_pallas_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 0bb7b45d7944..ce9348b594b0 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2416,6 +2416,7 @@ def kernel(x_ref, o_ref): class PallasCallTraceTest(PallasBaseTest): + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_trace_start_stop_match(self): def kernel(o_ref): with jax.named_scope('scope1'): @@ -2435,6 +2436,7 @@ def kernel(o_ref): self.assertEqual(num_start, 1) self.assertEqual(num_stop, 1) + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_run_scoped(self): def kernel(o_ref): def scope1(): From d543df13248a0f88ceb3f4656cdb736fa6562540 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 11 Apr 2025 06:55:08 -0700 Subject: [PATCH 0564/1769] [pallas:mosaic_gpu] Added support for `unroll=True` to the `lax.fori_loop` lowering PiperOrigin-RevId: 746444372 --- jax/_src/pallas/mosaic_gpu/lowering.py | 35 ++++++++++++++++++-------- jax/_src/pallas/mosaic_gpu/pipeline.py | 11 +++++--- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 49e337d2ba49..726d89bfffc5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2140,12 +2140,12 @@ def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: ir.Value, - length: ir.Value, + length: ir.Value | int, consts, *args, has_loop_index: bool, + unroll: bool = False, ): - _consts_avals, arg_avals = util.split_list(ctx.avals_in, [len(consts)]) arg_avals = arg_avals[has_loop_index:] out_avals = [] @@ -2164,7 +2164,6 @@ def as_values(vals, avals): ) return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] - @mgpu.fori(length, as_values(args, arg_avals)) def loop(loop_index, body_args): if has_loop_index: loop_index = arith_dialect.addi(loop_index, start) @@ -2176,7 +2175,16 @@ def loop(loop_index, body_args): ) return as_values(outs, out_avals) - return loop.results + if unroll: + assert isinstance(length, int) + outs = as_values(args, arg_avals) + for i in range(length): + outs = loop(_ir_constant(i, start.type), outs) + return outs + else: + if not 
isinstance(length, ir.Value): + length = _ir_constant(length, start.type) + return mgpu.fori(length, as_values(args, arg_avals))(loop).results @register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane) @@ -2197,10 +2205,10 @@ def _scan_lowering_rule( if ( (num_extensive := len(args) - num_consts - num_carry) or reverse - or unroll != 1 + or not (unroll == 1 or unroll == length) ): raise NotImplementedError - del linear, num_extensive, reverse, unroll + del linear, num_extensive, reverse jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts if jaxpr_consts: @@ -2216,17 +2224,24 @@ def _scan_lowering_rule( start, *args = args index_aval, *_ = arg_avals start: ir.Value = _ensure_ir_value(start, index_aval.dtype) - length = _ir_constant(length, start.type) else: start = _i32_constant(0) - length = _i32_constant(length) + for_out = _lower_jaxpr_to_for_loop( - ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index + ctx, + jaxpr, + start, + length, + consts, + *args, + has_loop_index=has_loop_index, + unroll=unroll == length, ) if has_loop_index: # Need to return the final loop index value if the outer scan expects # it as an output. - return [length, *for_out] + loop_index = arith_dialect.addi(start, _ir_constant(length, start.type)) + return [loop_index, *for_out] return for_out diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 0b72763ebc20..1e52f8701fcb 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -655,8 +655,10 @@ def _init_step(step, indices): buf_slot = _get_slot(step, not bref.is_index_invariant) bref.copy_in(buf_slot, indices, barrier) return _inc_grid_by_1(indices, grid) - # TODO(apaszke): Unroll when grid is static (need support in lowering). - indices = jax.lax.fori_loop(0, prologue_steps, _init_step, indices) + + indices = jax.lax.fori_loop( + 0, prologue_steps, _init_step, indices, unroll=not has_dynamic_grid + ) def memory_loop_body(step, carry): indices, = carry @@ -687,8 +689,9 @@ def memory_loop_body(step, carry): def _epi_step(step, _): for barrier in consumed_barrier_refs: gpu_primitives.barrier_wait(barrier.at[step]) - # TODO(apaszke): Unroll when grid is static (need support in lowering). - jax.lax.fori_loop(0, prologue_steps, _epi_step, None) + jax.lax.fori_loop( + 0, prologue_steps, _epi_step, None, unroll=not has_dynamic_grid + ) wg_idx = lax.axis_index(wg_axis) lax.cond( From b3c0ec04865c1ec58e3d3d47900eefaed962d665 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 11 Apr 2025 07:11:00 -0700 Subject: [PATCH 0565/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ca9011742bb84b3d2158feb262ddca221957ccc9. PiperOrigin-RevId: 746448816 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index bd126eaf032f..fb219ad99cb6 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "9f2aa85b909a615eef2212286f1e0c5684fc6b6c" -XLA_SHA256 = "e6c8ec9983d9c5ef0d12f8479079d17d382c59f2e243bd886d9c3df2e61206fd" +XLA_COMMIT = "ca9011742bb84b3d2158feb262ddca221957ccc9" +XLA_SHA256 = "7a0eb3d157236c0e9b4bdf2598d411a216e3fb7bbc0b47d20810746fb0ba772c" def repo(): tf_http_archive( From 8b7319afe9264938f9aa4767225ca14fb6d8ccfc Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Apr 2025 08:08:53 -0700 Subject: [PATCH 0566/1769] [JAX] Remove calls to jax.dlpack.to_dlpack(), and avoid passing DLPack capsules to jax.dlpack.from_dlpack(). to_dlpack() is not needed in the current version of the dlpack protocol. The from_dlpack() method accepts an object that implements __dlpack__(). In most cases, a JAX array can be passed directly to functions like torch.dlpack.from_dlpack(), and vice versa for other frameworks. The main exception is TensorFlow which does not implement the current protocol. PiperOrigin-RevId: 746464890 --- jax/experimental/jax2tf/call_tf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 3b175cd64c4c..bb2af54025bc 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -348,8 +348,7 @@ def _arg_jax_to_tf(arg_jax): if (isinstance(arg_jax, jax.Array) and list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES): - arg_dlpack = jax.dlpack.to_dlpack(arg_jax) - return tf.experimental.dlpack.from_dlpack(arg_dlpack) + return tf.experimental.dlpack.from_dlpack(arg_jax.__dlpack__()) # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. # TODO(necula): on TPU this copies to the host! From 3736e5ba8538cfc503c6de2346c298930f09ddee Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Apr 2025 09:33:52 -0700 Subject: [PATCH 0567/1769] Bump the JAX version to v0.6.0, which will be the next release version. PiperOrigin-RevId: 746490665 --- jax/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/version.py b/jax/version.py index 21662d078f7f..c0e59702692e 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.5.4" +_version = "0.6.0" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None From 5adac1cb8a5ce9f26ea62016582063bad57c9fa0 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 11 Apr 2025 09:53:08 -0700 Subject: [PATCH 0568/1769] Fix the printing of the function name in tracing-cache-miss explanations jax-fixit PiperOrigin-RevId: 746496570 --- jax/_src/pjit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 871fb0a6c870..459d2ebea80e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1169,13 +1169,13 @@ def unpack(key): p(f"TRACING CACHE MISS at {callsite} because:") # have we seen this function before at all? 
- src_info = debug_info.func_name + src_info = "" if func_filename: src_info += f" defined at {func_filename}" if func_lineno := debug_info.func_lineno: src_info += f":{func_lineno}" if unseen_f: - p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}") + p(f" never seen function:\n {debug_info.func_name} id={id(fun.f)}{src_info}") if callsite in callsites_with_tracing_cache_miss: p(" but seen another function defined on the same line; maybe the function is\n" " being re-defined repeatedly, preventing caching?") @@ -1183,7 +1183,7 @@ def unpack(key): callsites_with_tracing_cache_miss.add(callsite) return done() else: - p(f" for {fun_name}{src_info}") + p(f" for {debug_info.func_name}{src_info}") seen_keys = map(unpack, cache.keys()) From 88dae18c684caeea17cfe7563b71f5bde08a7ab5 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 11 Apr 2025 09:57:39 -0700 Subject: [PATCH 0569/1769] [Pallas] Fix potential race condition in Pallas TPU docs --- docs/pallas/tpu/distributed.ipynb | 156 +++++++++++++++--------------- docs/pallas/tpu/distributed.md | 127 ++++++++++++------------ 2 files changed, 142 insertions(+), 141 deletions(-) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index b52ec579f508..ad047963fbce 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -17,12 +17,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "metadata": { "executionInfo": { - "elapsed": 1978, + "elapsed": 52, "status": "ok", - "timestamp": 1722904801801, + "timestamp": 1744390458993, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -30,18 +30,19 @@ "user_tz": 420 }, "id": "PyAGnWc9yI8T", - "outputId": "1d8229bd-cab5-495f-93e9-fff2e41db480" + "outputId": "c5912653-c34b-4810-c373-4a2787691317" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running with 4 TPU v5 lite devices.\n" + "Running with 4 TPU v4 devices.\n" ] } ], "source": [ + "import functools\n", "import jax\n", "from jax import lax\n", "from jax import numpy as jnp\n", @@ -215,12 +216,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": { "executionInfo": { - "elapsed": 1606, + "elapsed": 152, "status": "ok", - "timestamp": 1722904803566, + "timestamp": 1744390459367, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -228,7 +229,7 @@ "user_tz": 420 }, "id": "YkyIKN2thZ-V", - "outputId": "9b7ed142-d161-4237-fed8-cbce41adc5f0" + "outputId": "26719bb9-87ff-46dd-af90-a114ce332417" }, "outputs": [ { @@ -338,12 +339,12 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": { "executionInfo": { - "elapsed": 812, + "elapsed": 209, "status": "ok", - "timestamp": 1722904804531, + "timestamp": 1744390459789, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -351,7 +352,7 @@ "user_tz": 420 }, "id": "ojQEZB5mBRqM", - "outputId": "e1648f54-737c-4921-ca3b-b4c639a38d2b" + "outputId": "3a4373f8-1fb5-4a6b-b88e-3461c2609021" }, "outputs": [ { @@ -483,7 +484,7 @@ { "cell_type": "markdown", "metadata": { - "id": "KgU7HI2pS4om" + "id": "EDCmAaHVtY7x" }, "source": [ "## Advanced Techniques\n", @@ -651,12 +652,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": { "executionInfo": { - "elapsed": 254, + "elapsed": 248, "status": "ok", - "timestamp": 1722904804952, + "timestamp": 1744390460289, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" 
@@ -664,7 +665,7 @@ "user_tz": 420 }, "id": "XrY5bMlvBroQ", - "outputId": "77497000-4496-462e-cc3c-73fb640cc14c" + "outputId": "9216e749-48d2-43ff-d64b-bd419acf3e11" }, "outputs": [ { @@ -674,7 +675,7 @@ "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n", "Pallas result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", "lax.psum result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", - "Difference |Pallas - lax.psum| = 1.4959369e-08\n" + "Difference |Pallas - lax.psum| = 1.0535587e-08\n" ] } ], @@ -687,6 +688,41 @@ "input_arr = jax.device_put(input_arr, sharding)\n", "\n", "\n", + "def local_barrier(left_neighbor, right_neighbor, double_barrier=True):\n", + " \"\"\"Performs a barrier with neighbors on the global barrier semaphore.\n", + "\n", + " Optionally performs a second barrier, which prevents a potential race\n", + " when re-using the same collective_id across kernel invocations.\n", + " \"\"\"\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " for neighbor in [left_neighbor, right_neighbor]:\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(neighbor,),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + " if double_barrier:\n", + " # The double-barrier prevents a race condition where one neighbor can\n", + " # re-enter the kernel again on a subsequent call and increment the\n", + " # barrier semaphore a second time. This would unblock the current device\n", + " # even if the other neighbor is not ready yet.\n", + " # To implement a double-barrier, we stack-allocate a second REGULAR\n", + " # semaphore using run_scoped.\n", + " @functools.partial(pl.run_scoped,\n", + " second_barrier=pltpu.SemaphoreType.REGULAR)\n", + " def _(second_barrier):\n", + " for neighbor in [left_neighbor, right_neighbor]:\n", + " pltpu.semaphore_signal(\n", + " second_barrier,\n", + " inc=1,\n", + " device_id=(neighbor,),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(second_barrier, 2)\n", + "\n", + "\n", "def all_reduce_kernel(\n", " x_ref,\n", " o_ref,\n", @@ -709,20 +745,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " # Initialize o_ref, acc_scratch, and hbm_scratch.\n", " o_ref[...] 
= jnp.zeros_like(o_ref)\n", @@ -892,12 +915,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "metadata": { "executionInfo": { - "elapsed": 544, + "elapsed": 362, "status": "ok", - "timestamp": 1722904805699, + "timestamp": 1744390460871, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1017,20 +1040,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.\n", " o_ref[...] = jnp.zeros_like(o_ref[...])\n", @@ -1179,12 +1189,12 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "metadata": { "executionInfo": { - "elapsed": 596, + "elapsed": 917, "status": "ok", - "timestamp": 1722904806442, + "timestamp": 1744390461967, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1192,7 +1202,7 @@ "user_tz": 420 }, "id": "E-NMh-_teoi4", - "outputId": "24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0" + "outputId": "6c8b82bc-ed64-4cc1-8c5f-65e29cdb333c" }, "outputs": [ { @@ -1356,12 +1366,12 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": { "executionInfo": { - "elapsed": 1341, + "elapsed": 997, "status": "ok", - "timestamp": 1722904807930, + "timestamp": 1744390463178, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1474,20 +1484,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " initial_left_copy.start()\n", " initial_left_copy.wait()\n", @@ -1635,12 +1632,12 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": { "executionInfo": { - "elapsed": 768, + "elapsed": 1132, "status": "ok", - "timestamp": 1722904808851, + "timestamp": 1744390464532, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1648,7 +1645,7 @@ "user_tz": 420 }, "id": "cTEyiMDyx9Y0", - "outputId": "1de26695-3713-430e-9ab4-4ea646691680" + "outputId": "70ce154e-dab2-4ae0-e297-c4774d29da85" }, "outputs": [ { @@ -1710,6 +1707,13 @@ } ], "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, "jupytext": { "formats": "ipynb,md:myst", "main_language": "python" @@ -1733,5 +1737,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 0 } diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index c1f216c6153e..deed916ceb62 100644 --- 
a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -26,16 +26,17 @@ Some recommended readings beforehand: ```{code-cell} ipython3 --- executionInfo: - elapsed: 1978 + elapsed: 52 status: ok - timestamp: 1722904801801 + timestamp: 1744390458993 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: PyAGnWc9yI8T -outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480 +outputId: c5912653-c34b-4810-c373-4a2787691317 --- +import functools import jax from jax import lax from jax import numpy as jnp @@ -195,15 +196,15 @@ In order to call the kernel in distributed mode, we wrap the `pallas_call` in a ```{code-cell} ipython3 --- executionInfo: - elapsed: 1606 + elapsed: 152 status: ok - timestamp: 1722904803566 + timestamp: 1744390459367 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: YkyIKN2thZ-V -outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0 +outputId: 26719bb9-87ff-46dd-af90-a114ce332417 --- partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) @@ -296,15 +297,15 @@ We can re-purpose Pallas's `grid` argument to implement the loop. Rather than it ```{code-cell} ipython3 --- executionInfo: - elapsed: 812 + elapsed: 209 status: ok - timestamp: 1722904804531 + timestamp: 1744390459789 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: ojQEZB5mBRqM -outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b +outputId: 3a4373f8-1fb5-4a6b-b88e-3461c2609021 --- partition = P('x', None) mesh = jax.make_mesh((num_devices,), ('x',)) @@ -411,7 +412,7 @@ print('Difference |Pallas - lax.all_gather| = ', A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. -+++ {"id": "KgU7HI2pS4om"} ++++ {"id": "EDCmAaHVtY7x"} ## Advanced Techniques @@ -563,15 +564,15 @@ Note that this is not an optimal or fully general kernel, as the block sizes mus ```{code-cell} ipython3 --- executionInfo: - elapsed: 254 + elapsed: 248 status: ok - timestamp: 1722904804952 + timestamp: 1744390460289 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: XrY5bMlvBroQ -outputId: 77497000-4496-462e-cc3c-73fb640cc14c +outputId: 9216e749-48d2-43ff-d64b-bd419acf3e11 --- partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) @@ -581,6 +582,41 @@ input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices)) input_arr = jax.device_put(input_arr, sharding) +def local_barrier(left_neighbor, right_neighbor, double_barrier=True): + """Performs a barrier with neighbors on the global barrier semaphore. + + Optionally performs a second barrier, which prevents a potential race + when re-using the same collective_id across kernel invocations. 
+ """ + barrier_sem = pltpu.get_barrier_semaphore() + for neighbor in [left_neighbor, right_neighbor]: + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + if double_barrier: + # The double-barrier prevents a race condition where one neighbor can + # re-enter the kernel again on a subsequent call and increment the + # barrier semaphore a second time. This would unblock the current device + # even if the other neighbor is not ready yet. + # To implement a double-barrier, we stack-allocate a second REGULAR + # semaphore using run_scoped. + @functools.partial(pl.run_scoped, + second_barrier=pltpu.SemaphoreType.REGULAR) + def _(second_barrier): + for neighbor in [left_neighbor, right_neighbor]: + pltpu.semaphore_signal( + second_barrier, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(second_barrier, 2) + + def all_reduce_kernel( x_ref, o_ref, @@ -603,20 +639,7 @@ def all_reduce_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) # Initialize o_ref, acc_scratch, and hbm_scratch. o_ref[...] = jnp.zeros_like(o_ref) @@ -772,9 +795,9 @@ In terms of construction of the kernel, we introduce an additional `phase` dimen ```{code-cell} ipython3 --- executionInfo: - elapsed: 544 + elapsed: 362 status: ok - timestamp: 1722904805699 + timestamp: 1744390460871 user: displayName: Justin Fu userId: '17543197034567316452' @@ -890,20 +913,7 @@ def reduce_scatter_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies. o_ref[...] = jnp.zeros_like(o_ref[...]) @@ -1053,15 +1063,15 @@ pallas_result = jax.block_until_ready(pallas_result) ```{code-cell} ipython3 --- executionInfo: - elapsed: 596 + elapsed: 917 status: ok - timestamp: 1722904806442 + timestamp: 1744390461967 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: E-NMh-_teoi4 -outputId: 24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0 +outputId: 6c8b82bc-ed64-4cc1-8c5f-65e29cdb333c --- # Compare our result to XLA. def lax_reduce_sum_scatter(x): @@ -1197,9 +1207,9 @@ The full kernel is as follows: ```{code-cell} ipython3 --- executionInfo: - elapsed: 1341 + elapsed: 997 status: ok - timestamp: 1722904807930 + timestamp: 1744390463178 user: displayName: Justin Fu userId: '17543197034567316452' @@ -1308,20 +1318,7 @@ def reduce_scatter_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. 
- barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) initial_left_copy.start() initial_left_copy.wait() @@ -1470,15 +1467,15 @@ pallas_result = jax.block_until_ready(pallas_result) ```{code-cell} ipython3 --- executionInfo: - elapsed: 768 + elapsed: 1132 status: ok - timestamp: 1722904808851 + timestamp: 1744390464532 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: cTEyiMDyx9Y0 -outputId: 1de26695-3713-430e-9ab4-4ea646691680 +outputId: 70ce154e-dab2-4ae0-e297-c4774d29da85 --- # Now we compare our result to XLA. def lax_reduce_sum_scatter(x): From b1c96d47ed9876a74ee2686234201aacd7cd7791 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 11 Apr 2025 10:57:42 -0700 Subject: [PATCH 0570/1769] Remove unused execute_sharded_* functions. PiperOrigin-RevId: 746520758 --- jaxlib/xla/py_executable.cc | 142 ++++++++++---------------- jaxlib/xla/py_executable.h | 7 -- jaxlib/xla/xla.cc | 9 -- jaxlib/xla/xla_client_test.py | 19 ++-- jaxlib/xla/xla_extension/__init__.pyi | 10 -- 5 files changed, 64 insertions(+), 123 deletions(-) diff --git a/jaxlib/xla/py_executable.cc b/jaxlib/xla/py_executable.cc index 71e6cfbdba7f..eaf5af34f883 100644 --- a/jaxlib/xla/py_executable.cc +++ b/jaxlib/xla/py_executable.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" @@ -131,56 +132,49 @@ std::vector> PyLoadedExecutable::AddressableDevices() namespace { -// Traits classes of common methods for std::vector. -template -struct ShardedBufferAdapter; - -template <> -struct ShardedBufferAdapter { - static int num_devices(const ExecuteShardedArg& arg) { - if (std::holds_alternative(arg)) { - return std::get(arg).num_addressable_shards(); - } else { - return std::get>(arg).size(); - } +static int GetNumDevices(const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return std::get(arg).num_addressable_shards(); + } else { + return std::get>(arg).size(); } - static tsl::RCReference GetIfRtArray( - const ExecuteShardedArg& arg) { - if (std::holds_alternative(arg)) { - return tsl::FormRef(std::get(arg).ifrt_array()); - } - auto& arg_vector = std::get>(arg); - - // TODO(hyeontaek): This on-demand Array creation is not efficient and has - // insufficient information about the shape (a dummy shape is used). This - // should be removed if possible and only be used in the context where the - // shape information is unused. - std::vector> ifrt_arrays; - ifrt_arrays.reserve(arg_vector.size()); - absl::InlinedVector devices; - devices.reserve(arg_vector.size()); - for (auto& arr : arg_vector) { - CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1) - << arr.ifrt_array()->sharding().DebugString(); - ifrt_arrays.push_back(tsl::FormRef(arr.ifrt_array())); - devices.push_back( - arr.ifrt_array()->sharding().devices()->devices().front()); - } - CHECK(!ifrt_arrays.empty()); - // Use a dummy shape. - // TODO(hyeontaek): Find a way to compute a correct shape. - // TODO(yashkatariya): Plumb sharding or memory_kind here. 
- ifrt::Client* client = ifrt_arrays.front()->client(); - auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( - ifrt_arrays.front()->shape(), - ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices), - ifrt::MemoryKind()), - absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, - ifrt::SingleDeviceShardSemantics::kAddressableShards); - TF_CHECK_OK(ifrt_array.status()); - return *ifrt_array; +} +static tsl::RCReference GetIfRtArray( + const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return tsl::FormRef(std::get(arg).ifrt_array()); } -}; + auto& arg_vector = std::get>(arg); + + // TODO(hyeontaek): This on-demand Array creation is not efficient and has + // insufficient information about the shape (a dummy shape is used). This + // should be removed if possible and only be used in the context where the + // shape information is unused. + std::vector> ifrt_arrays; + ifrt_arrays.reserve(arg_vector.size()); + absl::InlinedVector devices; + devices.reserve(arg_vector.size()); + for (auto& arr : arg_vector) { + CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1) + << arr.ifrt_array()->sharding().DebugString(); + ifrt_arrays.push_back(tsl::FormRef(arr.ifrt_array())); + devices.push_back( + arr.ifrt_array()->sharding().devices()->devices().front()); + } + CHECK(!ifrt_arrays.empty()); + // Use a dummy shape. + // TODO(hyeontaek): Find a way to compute a correct shape. + // TODO(yashkatariya): Plumb sharding or memory_kind here. + ifrt::Client* client = ifrt_arrays.front()->client(); + auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( + ifrt_arrays.front()->shape(), + ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices), + ifrt::MemoryKind()), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(ifrt_array.status()); + return *ifrt_array; +} void PopulateExecuteShardedResults( const nb_class_ptr& client, @@ -206,10 +200,10 @@ void PopulateExecuteShardedResults( } } -template > absl::StatusOr ExecuteShardedOnLocalDevicesInternal( const ifrt::ExecuteOptions& options, const nb_class_ptr& client, - ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, + ifrt::LoadedExecutable* ifrt_loaded_executable, + absl::Span args, std::optional>>& returned_futures) { std::vector> output_arrays; std::unique_ptr> returned_future; @@ -218,20 +212,22 @@ absl::StatusOr ExecuteShardedOnLocalDevicesInternal( { nb::gil_scoped_release gil_release; for (const auto& arg : args) { - if (ArgAdapter::num_devices(arg) != num_computations) { + if (GetNumDevices(arg) != num_computations) { return InvalidArgument( "Expected args to execute_sharded_on_local_devices to have %d " "shards, got: [%s]", num_computations, - absl::StrJoin(args, ", ", [](std::string* out, const ArgT& arg) { - out->append(std::to_string(ArgAdapter::num_devices(arg))); - })); + absl::StrJoin(args, ", ", + [](std::string* out, const ExecuteShardedArg& arg) { + out->append(std::to_string(GetNumDevices(arg))); + })); } } std::vector> arg_arrays(args.size()); - absl::c_transform(args, arg_arrays.begin(), [&](const ArgT& arg) mutable { - return ArgAdapter::GetIfRtArray(arg); - }); + absl::c_transform(args, arg_arrays.begin(), + [&](const ExecuteShardedArg& arg) mutable { + return GetIfRtArray(arg); + }); TF_ASSIGN_OR_RETURN(auto result, ifrt_loaded_executable->Execute( absl::MakeSpan(arg_arrays), options, /*devices=*/std::nullopt)); @@ -368,38 +364,6 @@ std::vector 
PyExecuteResults::ConsumeWithHandlers( return outputs; } -absl::StatusOr>> -PyLoadedExecutable::ExecuteShardedOnLocalDevices( - absl::Span args) { - xla::ifrt::ExecuteOptions options = options_; - options.launch_id = GetNextLaunchId(); - options.fill_status = false; - options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); - std::optional>> returned_futures; - TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, - ExecuteShardedOnLocalDevicesInternal( - options, client_, ifrt_loaded_executable_.get(), args, - returned_futures)); - return outputs_and_tokens.DisassembleIntoSingleDeviceArrays(); -} - -absl::StatusOr>, PyShardedToken>> -PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens( - absl::Span args) { - xla::ifrt::ExecuteOptions options = options_; - options.launch_id = GetNextLaunchId(); - options.fill_status = true; - options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); - std::optional>> returned_futures; - returned_futures.emplace(); - TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, - ExecuteShardedOnLocalDevicesInternal( - options, client_, ifrt_loaded_executable_.get(), args, - returned_futures)); - return std::make_pair(outputs_and_tokens.DisassembleIntoSingleDeviceArrays(), - outputs_and_tokens.ConsumeToken()); -} - absl::StatusOr PyLoadedExecutable::ExecuteSharded( std::vector args, bool with_tokens) { xla::ifrt::ExecuteOptions options = options_; diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h index 804682db717e..9c8ce8010c90 100644 --- a/jaxlib/xla/py_executable.h +++ b/jaxlib/xla/py_executable.h @@ -173,13 +173,6 @@ class PyLoadedExecutable { // PjRtExecutable::Execute. The result is similarly transposed back into the // argid,deviceid format. // args is [num_args x num_devices]. - absl::StatusOr>> - ExecuteShardedOnLocalDevices(absl::Span args); - - absl::StatusOr>, PyShardedToken>> - ExecuteShardedOnLocalDevicesWithTokens( - absl::Span args); - absl::StatusOr ExecuteSharded( std::vector args, bool with_tokens); diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 0aad3163c203..225d45f53b4b 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -512,15 +512,6 @@ NB_MODULE(xla_extension, m) { "get_compiled_memory_stats", xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) .def("delete", &PyLoadedExecutable::Delete) - .def("execute_sharded_on_local_devices", - xla::ValueOrThrowWrapper( - &PyLoadedExecutable::ExecuteShardedOnLocalDevices), - nb::arg("arguments")) - .def("execute_sharded_on_local_devices_with_tokens", - xla::ValueOrThrowWrapper( - &PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens), - nb::arg("arguments")) - // TODO(parkers): Switch execute_sharded_on_local_devices* to this. 
.def("execute_sharded", xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), nb::arg("arguments"), nb::arg("with_tokens") = false) diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index d2d9c585745c..15c307145b29 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -3612,9 +3612,9 @@ def testExecuteShardedOnLocalDevicesWithTokens(self): options.num_replicas = num_replicas compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build()), compile_options=options) - results, sharded_token = ( - compiled_c.execute_sharded_on_local_devices_with_tokens([]) - ) + py_results = compiled_c.execute_sharded([], with_tokens=True) + results = py_results.disassemble_into_single_device_arrays() + sharded_token = py_results.consume_token() sharded_token.block_until_ready() self.assertLen(results, 1) self.assertLen(results[0], 1) @@ -3666,14 +3666,16 @@ def testExecuteShardedOverloadEmptyInput(self): compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build()), compile_options=options) - results = compiled_c.execute_sharded_on_local_devices([]) + results = compiled_c.execute_sharded( + []).disassemble_into_single_device_arrays() self.assertLen(results, 1) self.assertIsInstance(results[0], list) self.assertLen(results[0], 1) results[0][0].block_until_ready() self.assertIsInstance(results[0][0], xla_client.ArrayImpl) - results, _ = compiled_c.execute_sharded_on_local_devices_with_tokens([]) + results = compiled_c.execute_sharded( + [], with_tokens=True).disassemble_into_single_device_arrays() self.assertLen(results, 1) self.assertIsInstance(results[0], list) self.assertLen(results[0], 1) @@ -3692,15 +3694,16 @@ def testExecuteShardedOverloadBufferInput(self): buffer = self.backend.buffer_from_pyval(arg) - results = compiled_c.execute_sharded_on_local_devices([[buffer]]) + results = compiled_c.execute_sharded( + [[buffer]]).disassemble_into_single_device_arrays() self.assertLen(results, 1) self.assertIsInstance(results[0], list) self.assertLen(results[0], 1) results[0][0].block_until_ready() self.assertIsInstance(results[0][0], xla_client.ArrayImpl) - results, _ = compiled_c.execute_sharded_on_local_devices_with_tokens( - [[buffer]]) + results = compiled_c.execute_sharded( + [[buffer]], with_tokens=True).disassemble_into_single_device_arrays() self.assertLen(results, 1) self.assertIsInstance(results[0], list) self.assertLen(results[0], 1) diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi index aa2060f59102..791bfe44e85f 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -736,16 +736,6 @@ class LoadedExecutable: def local_devices(self) -> List[Device]: ... def size_of_generated_code_in_bytes(self) -> int: ... def delete(self) -> None: ... - def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... - def execute_with_token( - self, arguments: Sequence[ArrayImpl] - ) -> Tuple[List[ArrayImpl], Token]: ... - def execute_sharded_on_local_devices( - self, arguments: Sequence[List[ArrayImpl]] - ) -> List[List[ArrayImpl]]: ... - def execute_sharded_on_local_devices_with_tokens( - self, arguments: Sequence[List[ArrayImpl]] - ) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... def execute_sharded( self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... ) -> ExecuteResults: ... 
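For downstream code that called the removed methods, here is a minimal migration sketch mirroring the `xla_client_test.py` changes above. `compiled` and `buffer` are placeholder names for a compiled `LoadedExecutable` and an input buffer; per the test changes, a token is only available when `with_tokens=True` is passed.

```python
# Before (removed APIs):
#   results = compiled.execute_sharded_on_local_devices([[buffer]])
#   results, token = compiled.execute_sharded_on_local_devices_with_tokens([[buffer]])

# After: execute_sharded returns an ExecuteResults object.
results = compiled.execute_sharded(
    [[buffer]]).disassemble_into_single_device_arrays()

# With tokens:
out = compiled.execute_sharded([[buffer]], with_tokens=True)
results = out.disassemble_into_single_device_arrays()
token = out.consume_token()
token.block_until_ready()
```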
From a39b6232be9b435fb3c8a82445708b7f8f71081d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 11 Apr 2025 11:17:54 -0700 Subject: [PATCH 0571/1769] Make sure the order passed to `make_jit` and `_parse_jit_arguments` is the same as the order of arguments received in `jit` API and make it keyword-only PiperOrigin-RevId: 746527807 --- jax/_src/api.py | 9 ++++++--- jax/_src/pjit.py | 50 +++++++++++++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index d170d2632d6f..23c0a610cee9 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -288,9 +288,12 @@ def jit( Array([ 0, 1, 256, 6561], dtype=int32) """ return pjit.make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=False) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=False) @contextmanager diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 459d2ebea80e..1dd1c6609a62 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -434,14 +434,16 @@ def _split_layout_and_sharding(entries): return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings) -def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, device: xc.Device | None, + backend: str | None, inline: bool, + abstracted_axes: Any | None, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. 
@@ -519,20 +521,29 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, use_resource_env=use_resource_env, compiler_options_kvs=compiler_options_kvs) -def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def make_jit(fun: Callable, + *, + in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, + device: xc.Device | None, + backend: str | None, + inline: bool, + abstracted_axes: Any | None, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" jit_info = _parse_jit_arguments( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=use_resource_env) return _cpp_pjit(fun, jit_info) @@ -995,9 +1006,12 @@ def pjit( [ 0.5 2. 4. 6. 8. 10. 12. 10. ] """ return make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=True) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=True) def hashable_pytree(pytree): From ab882735960011486772d171b2ddf40d400166da Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Apr 2025 11:25:27 -0700 Subject: [PATCH 0572/1769] Deprecate jax.dlpack.to_dlpack. This is not needed under the newer DLPack protocol for users, and there's an equivalent (`__dlpack__`). PiperOrigin-RevId: 746530351 --- CHANGELOG.md | 4 ++++ jax/dlpack.py | 29 +++++++++++++++++++++++++- tests/pytorch_interoperability_test.py | 2 ++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d118ef10f51f..7072772a051d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a callable. + * `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX + `Array` directly to the `from_dlpack` function of another framework. If you + need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an + array. 
* Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, diff --git a/jax/dlpack.py b/jax/dlpack.py index a65496ec0cbf..d008608fc356 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -12,8 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import jax._src.dlpack +import jax._src.deprecations + from jax._src.dlpack import ( - to_dlpack as to_dlpack, from_dlpack as from_dlpack, SUPPORTED_DTYPES as SUPPORTED_DTYPES, ) + +_deprecations = { + "to_dlpack": ( + ( + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0. Please use the newer DLPack API based on" + " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" + " a JAX array directly to the `from_dlpack` function of another" + " framework without using `to_dlpack`." + ), + jax._src.dlpack.to_dlpack, + ), +} + + +import typing as _typing + +if _typing.TYPE_CHECKING: + to_dlpack = jax._src.dlpack.to_dlpack +else: + __getattr__ = jax._src.deprecations.deprecation_getattr( + __name__, _deprecations + ) +del _typing diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index e41c4329b95b..3035e68d234c 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -67,6 +67,8 @@ def testTorchToJaxFailure(self): y, client, client) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) + @jtu.ignore_warning(message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning) def testJaxToTorch(self, shape, dtype): if not config.enable_x64.value and dtype in [ jnp.int64, From 8e9fca1d08bae15e5128647a388aff0e15aedc3a Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 11 Apr 2025 12:02:51 -0700 Subject: [PATCH 0573/1769] document SPMD pipeline parallelism PiperOrigin-RevId: 746543312 --- docs/gpu_performance_tips.md | 235 ++++++++++++++++++++++++++++++++--- 1 file changed, 221 insertions(+), 14 deletions(-) diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index b3643cb8e292..bade464d22a1 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -243,20 +243,6 @@ Run the real workflow, if you found these loggings in the running log, it means By adjusting this factor, users can fine-tune the trade-off between memory efficiency and performance optimizations. -* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism, - this flag enables overlapping the (i+1)-th layer weight `AllGather` with the - i-th layer computation. It also enables overlapping (i+1)-th layer - weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default - value is False. **There are some bugs when this flag is turned on.** -* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when - performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a - nonzero threshold decomposes `CollectivePermute`s into - `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that - computation can be performed between each corresponding - `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the - threshold is 0 and there is no decomposition. Setting it to threshold > 0 such - as `--xla_gpu_collective_permute_decomposer_threshold=1024` can enable this - feature. 
 * **--xla_gpu_all_gather_combine_threshold_bytes**
   **--xla_gpu_reduce_scatter_combine_threshold_bytes**
   **--xla_gpu_all_reduce_combine_threshold_bytes**
@@ -268,6 +254,227 @@ Run the real workflow, if you found these loggings in the running log, it means
   combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`.
   By default, the `combine_threshold_bytes` is set to 256.

+### Pipeline Parallelism on GPU
+
+XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling technique
+where the forward and backward passes are split into multiple pipeline stages.
+Each device (or device group) processes the result of the previous
+pipeline stage (or the pipeline input) and sends its partial result to the next
+stage until the end of the pipeline is reached. This optimization works best
+when the computation latency is larger than the communication latency. At compile
+time, the operations will be rearranged to overlap communication with
+computation.
+
+For an optimized schedule, we recommend these XLA flags:
+```
+--xla_gpu_enable_latency_hiding_scheduler=true
+--xla_gpu_enable_command_buffer=''
+--xla_disable_hlo_passes=collective-permute-motion
+--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE
+```
+
+The following JAX example demonstrates a pattern where communication operations
+are scheduled to overlap with computations. In this example we illustrate
+how to set up an optimized pipeline parallelism schedule using 4 GPUs that
+form a communication ring (device 0 -> device 1 -> device 2 -> device 3 ->
+device 0). We refer to the pattern `0 -> 1 -> 2 -> 3` as the forward edge, and
+`3 -> 0` as the back edge.
+
+```
+# Imports and setup
+import functools
+import jax
+from jax import sharding
+from jax.experimental import mesh_utils
+import jax.numpy as jnp
+import jax.random
+
+NUM_DEVICES = 4
+NUM_MICROBATCHES = 5
+NUM_CIRC_REPEATS = 2
+CONTRACTING_DIM_SIZE = 4096
+NON_CONTRACTING_DIM_SIZE = 8192
+COMPUTE_INTENSITY = 32
+
+# Creates a collective permute for the "forward edge".
+# 0->1, 1->2, ... (N-2)->(N-1)
+def shift_right(arr):
+  padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
+  # Use lax.slice to guarantee the gradient is a pad.
+  return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)
+
+
+# Creates a collective permute for the "back edge".
+# (N-1)->0
+def cycle_back(arr):
+  padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
+  return jax.lax.slice(
+      jnp.pad(arr, padding),
+      [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
+      (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
+  )
+
+
+def select_on_first_device(then_value, else_value):
+  assert then_value.shape == else_value.shape
+  is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
+  return jnp.where(is_first_device, then_value, else_value)
+
+
+def select_on_last_device(then_value, else_value):
+  assert then_value.shape == else_value.shape
+  is_last_device = (
+      jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
+  )
+  return jnp.where(is_last_device, then_value, else_value)
+
+
+def select_on_first_cycle(i, then_value, else_value):
+  assert then_value.shape == else_value.shape
+  is_first_cycle = i < NUM_MICROBATCHES
+  return jnp.where(is_first_cycle, then_value, else_value)
+
+
+def while_body(carry, i):
+  """Body of the pipeline while loop."""
+  weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry
+
+  # Read input data from input buffer.
+  input_data = jax.lax.dynamic_slice(
+      input_buffer,
+      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
+      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
+  )
+
+  # Collective permute on the "forward edge" shifts data to the next stage.
+  fwd_edge_data = shift_right(fwd_edge_data)
+
+  # Select compute argument based on device and pipeline cycle.
+  compute_argument = select_on_first_device(
+      select_on_first_cycle(i, input_data, bwd_edge_data),
+      fwd_edge_data,
+  ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))
+
+  # A few matmuls to simulate compute.
+  tmp = compute_argument
+  for _ in range(COMPUTE_INTENSITY):
+    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
+  compute_result = tmp.reshape(
+      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
+  )
+
+  # Read data from buffer to pass it to the first device of the pipeline on the
+  # "back edge".
+  bwd_edge_data = jax.lax.dynamic_slice(
+      output_buffer,
+      (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
+      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
+  )
+
+  # Collective permute on the "back edge" passes data to the first device.
+  bwd_edge_data = cycle_back(bwd_edge_data)
+
+  # Update output buffer. We do this after reading from it to avoid the data
+  # dependency.
+  output_buffer = jax.lax.dynamic_update_slice(
+      output_buffer,
+      compute_result,
+      (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
+  )
+
+  fwd_edge_data = compute_result
+  carry = (
+      weights,
+      input_buffer,
+      output_buffer,
+      fwd_edge_data,
+      bwd_edge_data,
+  )
+  return carry, i
+
+
+@functools.partial(jax.jit, static_argnames=["mesh"])
+def entry_computation(weights, input_buffer, mesh):
+
+  # Init output buffer.
+  output_buffer = jnp.zeros_like(input_buffer)
+
+  # Init dummy data for forward and backward edge passed through the while loop.
+  dummy_data = jnp.zeros(
+      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
+  ).astype(jnp.float32)
+  dummy_data = jax.device_put(
+      dummy_data,
+      sharding.NamedSharding(
+          mesh, sharding.PartitionSpec("the_one_and_only_axis")
+      ),
+  )
+
+  # Start pipeline.
+  carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
+  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
+  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
+  _, _, output_buffer, _, _ = carry
+
+  return output_buffer
+
+
+def main(_):
+
+  # Expect constant number of devices.
+  assert NUM_DEVICES == jax.local_device_count()
+
+  # Create mesh.
+  mesh = sharding.Mesh(
+      mesh_utils.create_device_mesh([NUM_DEVICES]),
+      axis_names=["the_one_and_only_axis"],
+  )
+
+  # Init weights.
+  weights = 1.0 / CONTRACTING_DIM_SIZE
+  weights = jax.lax.broadcast_in_dim(
+      weights,
+      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
+      broadcast_dimensions=(),
+  )
+  weights = jax.device_put(
+      weights,
+      sharding.NamedSharding(
+          mesh, sharding.PartitionSpec("the_one_and_only_axis")
+      ),
+  )
+
+  # Init random input and replicate it across all devices.
+  random_key = jax.random.key(0)
+  input_buffer = jax.random.uniform(
+      random_key,
+      shape=(
+          NUM_MICROBATCHES,
+          CONTRACTING_DIM_SIZE,
+          NON_CONTRACTING_DIM_SIZE,
+      ),
+  )
+  input_buffer = jax.lax.broadcast_in_dim(
+      input_buffer,
+      shape=(
+          NUM_DEVICES,
+          NUM_MICROBATCHES,
+          CONTRACTING_DIM_SIZE,
+          NON_CONTRACTING_DIM_SIZE,
+      ),
+      broadcast_dimensions=[1, 2, 3],
+  )
+  input_buffer = jax.device_put(
+      input_buffer,
+      sharding.NamedSharding(
+          mesh, sharding.PartitionSpec("the_one_and_only_axis")
+      ),
+  )
+
+  # Run computation.
+  output_buffer = entry_computation(weights, input_buffer, mesh)
+  print(f"output_buffer = \n{output_buffer}")
+```

 ## NCCL flags

 These Nvidia NCCL flag values may be useful for single-host multi-device

From 5cf74cc72b266a044b096f881dc006beabab0e7f Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Fri, 11 Apr 2025 12:10:49 -0700
Subject: [PATCH 0574/1769] Use dash instead of underscore for extras.

The new `METADATA` specification disallows the use of underscores and
automatically converts them to dashes.
https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should prevent the following error from appearing in future JAX
releases: https://github.com/jax-ml/jax/issues/27874

PiperOrigin-RevId: 746546162
---
 CHANGELOG.md | 5 +++++
 docs/installation.md | 4 ++--
 jax_plugins/cuda/plugin_setup.py | 2 +-
 setup.py | 13 +++----------
 4 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7072772a051d..4bac678ca14f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,11 +30,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     The workaround is to apply `jax.jit` last among the wrappers, and similarly
     for `jax.pmap`. See {jax-issue}`#27873`.
+  * The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]`
+    instead.

 * Changes
   * The minimum CuDNN version is v9.8.
   * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
     supported.
+  * JAX package extras are now updated to use dash instead of underscore to
+    align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]`
+    to install JAX, run `pip install jax[cuda12-local]` instead.

 * Deprecations
diff --git a/docs/installation.md b/docs/installation.md
index 34274d7596aa..500347e04ab1 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -158,7 +158,7 @@ pip install --upgrade pip

 # Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
 # Note: wheels only available on linux.
-pip install --upgrade "jax[cuda12_local]" +pip install --upgrade "jax[cuda12-local]" ``` **These `pip` installations do not work with Windows, and may fail silently; refer to the table @@ -296,7 +296,7 @@ pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.co - NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` - NVIDIA GPU (CUDA 12) legacy: diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index db9928f6cf61..b9220cd29283 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -52,7 +52,7 @@ def has_ext_modules(self): python_requires=">=3.10", install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], extras_require={ - 'with_cuda': [ + 'with-cuda': [ "nvidia-cublas-cu12>=12.1.3.1", "nvidia-cuda-cupti-cu12>=12.1.105", "nvidia-cuda-nvcc-cu12>=12.6.85", diff --git a/setup.py b/setup.py index 629836b30862..6f9b4a9dd2cb 100644 --- a/setup.py +++ b/setup.py @@ -89,24 +89,17 @@ def load_version_module(pkg_path): 'cuda': [ f"jaxlib>={_current_jaxlib_version},<={_jax_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], 'cuda12': [ f"jaxlib>={_current_jaxlib_version},<={_jax_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", - ], - - # Deprecated alias for cuda12, kept to avoid breaking users who wrote - # cuda12_pip in their CI. - 'cuda12_pip': [ - f"jaxlib>={_current_jaxlib_version},<={_jax_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. 
-    'cuda12_local': [
+    'cuda12-local': [
         f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
         f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}",
     ],

From 27c07f7cd393870dffd617c9b4eae66fcd52ba44 Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Fri, 11 Apr 2025 12:12:32 -0700
Subject: [PATCH 0575/1769] [Pallas] Allow 1D iota

PiperOrigin-RevId: 746546870
---
 jax/_src/pallas/mosaic/lowering.py | 8 ++++++++
 tests/pallas/ops_test.py           | 3 ++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 343c7d79aab6..065d2b4c3b14 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -2254,6 +2254,14 @@ def _split_lowering_rule(
 def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
                         sharding):
+  if len(shape) == 1:
+    if dimension != 0:
+      raise ValueError("Dimension must be 0 for 1D iota.")
+    def _1d_iota_helper():
+      # Lower a 1D iota as a (1, n) 2D iota along axis 1, then drop the
+      # leading unit axis.
+      iota_2d = lax.iota_p.bind(dtype=dtype, shape=(1,) + shape,
+                                dimension=1, sharding=sharding)
+      return iota_2d[0]
+    return lower_fun(_1d_iota_helper, multiple_results=False)(ctx)
   out_type = aval_to_ir_type(
       ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
   )
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index b3dd61757df8..0580d7ec5824 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -1535,11 +1535,12 @@ def kernel(x_ref, y_ref, o_ref):
     np.testing.assert_allclose(f(x, y), kernel(x, y))

   @parameterized.parameters(
+      ((32,), jnp.int32, 0),
       ((8, 4), jnp.int32, 0),
       ((8, 16), jnp.float32, 1),
       ((8, 16, 2), jnp.int8, 1),
   )
-  def test_broadcasted_iota(self, shape, dtype, dimension):
+  def test_iota(self, shape, dtype, dimension):
     self.skip_if_mosaic_gpu()

     if jtu.test_device_matches(["tpu"]):

From 904419cb0ecdbcf3ace16130301031b4420e4772 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 11 Apr 2025 12:14:14 -0700
Subject: [PATCH 0576/1769] Rename TPU bazel test tags.

Use a count of chips (or omit it if 1) rather than specifying an ICI
topology.
Examples: * tpu_v5e_1x1 -> tpu_v5e * tpu_v5e_4x2 -> tpu_v5e_x8 PiperOrigin-RevId: 746547477 --- jax/experimental/array_serialization/BUILD | 2 +- tests/BUILD | 72 +++++++++++----------- tests/pallas/BUILD | 40 ++++++------ 3 files changed, 57 insertions(+), 57 deletions(-) diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD index ab1ee3fd393e..84e5b9300912 100644 --- a/jax/experimental/array_serialization/BUILD +++ b/jax/experimental/array_serialization/BUILD @@ -45,7 +45,7 @@ jax_multiplatform_test( name = "serialization_test", srcs = ["serialization_test.py"], enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], deps = [ "//jax:experimental", diff --git a/tests/BUILD b/tests/BUILD index 876b760bc4d3..234e577e015a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -34,7 +34,7 @@ jax_generate_backend_suites() jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], shard_count = 10, deps = [ "//jax:experimental", @@ -44,7 +44,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "debug_info_test", srcs = ["debug_info_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], deps = [ "//jax:experimental", "//jax:pallas", @@ -279,12 +279,12 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100x2", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v5p_2x2", - "tpu_v5e_4x2", + "tpu_v3_x4", + "tpu_v4_x4", + "tpu_v5p_x4", + "tpu_v5e_x8", "gpu_p100x2_shardy", - "tpu_v5e_4x2_shardy", + "tpu_v5e_x8_shardy", ], deps = [ "//jax:experimental", @@ -300,8 +300,8 @@ jax_multiplatform_test( }, enable_configs = [ "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", + "tpu_v3_x4_shardy", + "tpu_v3_x4", "gpu_h100x2", ], shard_count = { @@ -321,8 +321,8 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, enable_configs = [ - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", + "tpu_v3_x4_shardy", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ @@ -334,10 +334,10 @@ jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], enable_configs = [ - "tpu_v3_2x2", - "tpu_v5e_4x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", + "tpu_v3_x4", + "tpu_v5e_x8", + "tpu_v4_x4", + "tpu_v3_x4_shardy", ], deps = [ "//jax:experimental", @@ -397,7 +397,7 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. 
}, enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ @@ -763,7 +763,7 @@ jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", "gpu_h100x2", ], ) @@ -817,7 +817,7 @@ jax_multiplatform_test( }, enable_configs = [ "gpu_v100", - "tpu_v3_2x2", + "tpu_v3_x4", ], shard_count = { "cpu": 30, @@ -1117,7 +1117,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], shard_count = { "gpu": 2, "tpu": 4, @@ -1310,9 +1310,9 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", ], tags = ["multiaccelerator"], ) @@ -1323,11 +1323,11 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", "gpu_a100_shardy", - "tpu_v3_2x2_shardy", + "tpu_v3_x4_shardy", ], ) @@ -1338,10 +1338,10 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", + "tpu_v3_x4_shardy", "gpu_p100x2_shardy", ], tags = ["multiaccelerator"], @@ -1359,9 +1359,9 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", ], ) @@ -1426,7 +1426,7 @@ jax_multiplatform_test( srcs = ["shard_map_test.py"], enable_configs = [ "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", + "tpu_v3_x4_shardy", ], shard_count = { "cpu": 50, @@ -1519,8 +1519,8 @@ jax_multiplatform_test( enable_configs = [ "cpu_shardy", "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", + "tpu_v3_x4_shardy", + "tpu_v3_x4", ], tags = [], ) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index ba5d9d5f4ae7..730c354d8fdb 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -107,7 +107,7 @@ jax_multiplatform_test( "gpu_a100_x32", "gpu_h100", "gpu_h100_x32", - "tpu_v6e_1x1", + "tpu_v6e", ], shard_count = { "cpu": 16, @@ -315,7 +315,7 @@ jax_multiplatform_test( ], enable_backends = [], enable_configs = [ - "tpu_v5e_4x2", + "tpu_v5e_x8", ], deps = [ "//jax:pallas_tpu_ops", @@ -352,7 +352,7 @@ jax_multiplatform_test( enable_backends = ["tpu"], enable_configs = [ "tpu_v5e", - "tpu_v5p_1x1", + "tpu_v5p", ], deps = [ "//jax:extend", @@ -383,10 +383,10 @@ jax_multiplatform_test( srcs = ["tpu_pallas_distributed_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2", + "tpu_v5e_x8", + "tpu_v5p_x4", + "tpu_v4_x4", + "tpu_v3_x4", ], deps = [ "//jax:extend", @@ -400,8 +400,8 @@ jax_multiplatform_test( srcs = ["tpu_pallas_pipeline_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", ], shard_count = 5, tags = [ @@ -421,8 +421,8 @@ jax_multiplatform_test( srcs = ["tpu_pallas_async_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", ], deps = [ "//jax:pallas_tpu", @@ -451,7 +451,7 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5p_2x2", + "tpu_v5p_x4", ], deps = [ "//jax:pallas", @@ -491,7 +491,7 @@ jax_multiplatform_test( name = 
"tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], shard_count = 5, @@ -509,7 +509,7 @@ jax_multiplatform_test( name = "tpu_ragged_paged_attention_test", srcs = ["tpu_ragged_paged_attention_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], shard_count = 24, @@ -544,8 +544,8 @@ jax_multiplatform_test( name = "tpu_splash_attention_kernel_sharded_test", srcs = ["tpu_splash_attention_kernel_sharded_test.py"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_2x2", + "tpu_v5e_x8", + "tpu_v5p_x4", ], shard_count = 5, deps = [ @@ -705,7 +705,7 @@ jax_multiplatform_test( name = "tpu_fusable_matmul_test", srcs = ["tpu_fusable_matmul_test.py"], disable_configs = [ - "tpu_v3_1x1", + "tpu_v3", "tpu_pjrt_c_api", "gpu_v100", "gpu_v100_x32", @@ -719,10 +719,10 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], enable_configs = [ - "tpu_v4_1x1", + "tpu_v4", "tpu_v5e", - "tpu_v5p_1x1", - "tpu_v6e_1x1", + "tpu_v5p", + "tpu_v6e", ], shard_count = 4, tags = [ From e9364f4b0a2af8fceb7acb6fa9445385b2ee36c2 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 11 Apr 2025 12:36:36 -0700 Subject: [PATCH 0577/1769] Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7 PiperOrigin-RevId: 746554582 --- jax/_src/lax/control_flow/loops.py | 42 ++++++------------------------ tests/checkify_test.py | 4 +-- tests/lax_control_flow_test.py | 23 +--------------- 3 files changed, 11 insertions(+), 58 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 689e1f535259..babffa1d47d7 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1461,34 +1461,9 @@ def _create_jaxpr(init_val): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') - - # If the body forwards an input carry to an output carry, *and* it's not used - # by the cond fun, it can be moved to be a body const. Doing so can lead to - # efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch - # the carry too, but not the body consts. 
- body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr) - _, carry_fwd = split_list(body_fwd, [len(body_consts)]) - cond_jaxpr_, keep_cond = pe.dce_jaxpr( - cond_jaxpr.jaxpr, [True], - [True] * len(cond_consts) + [i != f for i, f in enumerate(body_fwd)]) - _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) - move_to_const = [i == f and not k for i, (f, k) - in enumerate(zip(body_fwd, keep_cond_carry))] - if any(move_to_const): - cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) - body_jaxpr = pe.prune_closed_jaxpr_outputs( - body_jaxpr, [not m for m in move_to_const]) - body_jaxpr = pe.move_binders_to_front( - body_jaxpr, [False] * len(body_consts) + move_to_const) - init_vals, new_body_consts = partition_list(move_to_const, init_vals) - body_consts = [*new_body_consts, *body_consts] - outs = while_p.bind(*cond_consts, *body_consts, *init_vals, cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr, body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) - - if any(move_to_const): - outs = pe.merge_lists(move_to_const, outs, new_body_consts) return tree_unflatten(body_tree, outs) @@ -1864,19 +1839,18 @@ def cond(args): pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): - return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args) + return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) def new_cond(pred_args): - pred, *_ = pred_args + pred, _ = pred_args return pred def new_body(pred_args): - _, cond_consts, body_consts, carry = pred_args - carry = body((*body_consts, *carry)) - pred = cond((*cond_consts, *carry)) - return pred, cond_consts, body_consts, carry + _, args = pred_args + args = body(args) + pred = cond(args) + return pred, args def fun(*args): - cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) - pred = cond((*cond_consts, *carry)) - *_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry)) + pred = cond(args) + _, out = while_loop(new_cond, new_body, (pred, args)) return out return mlir.lower_fun(fun)(ctx, *args) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 2f4b7d511fbe..5ea99d20a2ab 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -492,8 +492,8 @@ def f(x: jax.Array) -> jax.Array: def test_while_loop_body_and_cond_error(self): def while_cond(val): i, cond_val, _ = val - j = jnp.sin(cond_val) - return i + (0. * j) < 2 # don't let the sin value be dead code + _ = jnp.sin(cond_val) + return i < 2 def while_body(val): i, cond_val, body_val = val diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3034096cee57..a987d9e4c192 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2362,7 +2362,7 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), + x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x), 1., lambda x: x) elif loop == "fori_inside_scan": func = lambda x: lax.scan( @@ -3122,27 +3122,6 @@ def body(c): return x + y jax.linearize(f, 1., 2.) 
# don't crash - def test_readonly_carry_optimization(self): - # https://github.com/google/flax/issues/4700 - def foo(w, x, c_max): - def while_cond(val): - c, x, w = val - return c < c_max - - def while_body(val): - c, x, w = val - return c + 1, x @ w, w - - _, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w)) - return w, x - - w = jnp.ones((2, 2)) - xs = jnp.ones((4, 2)) - c_maxs = jnp.arange(4) - w_, _ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0) - )(w, xs, c_maxs) # doesn't crash - self.assertAllClose(w, w_, check_dtypes=False) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 6efcf44b1ac32c5b5c3ef621f2f5e4ddff4c8b80 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 11 Apr 2025 13:05:16 -0700 Subject: [PATCH 0578/1769] Deprecate `PositionalSharding` and `GSPMDSharding` PiperOrigin-RevId: 746564071 --- jax/_src/dispatch.py | 10 +++--- jax/_src/export/_export.py | 5 +-- .../array_serialization/serialization_test.py | 4 +-- jax/sharding.py | 32 +++++++++++++++++-- tests/array_test.py | 27 ++++++++++------ tests/pickle_test.py | 3 +- 6 files changed, 59 insertions(+), 22 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index baab6d519291..eea687145c0e 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -50,9 +50,9 @@ from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding -from jax._src.sharding_impls import ( NamedSharding, - SingleDeviceSharding, TransferToMemoryKind, - is_single_device_sharding) +from jax._src.sharding_impls import ( + NamedSharding, SingleDeviceSharding, TransferToMemoryKind, GSPMDSharding, + PositionalSharding, is_single_device_sharding) import numpy as np @@ -133,11 +133,11 @@ def get_token_input( # TODO(yueshengys): This might still be buggy in a multi-process SPMD # scenario. Revise the logic later. A distributed shutdown barrier inside # the XLA program may be needed. - return jax.device_put(tok, jax.sharding.PositionalSharding(devices)) + return jax.device_put(tok, PositionalSharding(devices)) # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. 
- s = jax.sharding.GSPMDSharding.get_replicated(devices) + s = GSPMDSharding.get_replicated(devices) sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 31132dc77c82..d02ef44f8318 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1212,10 +1212,11 @@ def _hlo_sharding_to_xla_compatible_sharding( def _hlo_sharding_to_gspmd_sharding( hlo_sharding: HloSharding | None, - device_assignment: Sequence[jax.Device]) -> sharding.GSPMDSharding | None: + device_assignment: Sequence[jax.Device] + ) -> sharding_impls.GSPMDSharding | None: if hlo_sharding is None: return None - return sharding.GSPMDSharding(device_assignment, hlo_sharding) + return sharding_impls.GSPMDSharding(device_assignment, hlo_sharding) def _hlo_sharding_to_named_sharding( diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 9f4539fc63c8..280c2f58b348 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -26,7 +26,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import array -from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding +from jax._src.sharding_impls import NamedSharding, GSPMDSharding, SingleDeviceSharding from jax.sharding import PartitionSpec as P from jax.experimental.array_serialization import serialization from jax.experimental.layout import Layout, DeviceLocalLayout as DLL @@ -620,7 +620,7 @@ def test_deserialization_with_int4(self): ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path) # Run serialization. - sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + sharding = GSPMDSharding.get_replicated(jax.devices()) tspecs = jax.tree_util.tree_map( serialization.get_tensorstore_spec, [ckpt_dir] ) diff --git a/jax/sharding.py b/jax/sharding.py index bacf848f07ed..66692069d19b 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -20,8 +20,8 @@ NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as PmapSharding, - GSPMDSharding as GSPMDSharding, - PositionalSharding as PositionalSharding, + GSPMDSharding as _deprecated_GSPMDSharding, + PositionalSharding as _deprecated_PositionalSharding, use_mesh as use_mesh, set_mesh as set_mesh, ) @@ -34,3 +34,31 @@ AxisType as AxisType, get_abstract_mesh as get_abstract_mesh, ) + +_deprecations = { + # Added April 11, 2025. + "PositionalSharding": ( + ( + "jax.sharding.PositionalSharding is deprecated. Use" + " jax.NamedSharding instead." + ), + _deprecated_PositionalSharding, + ), + "GSPMDSharding": ( + ( + "jax.sharding.GSPMDSharding is deprecated. Use" + " jax.NamedSharding instead." 
+ ), + _deprecated_GSPMDSharding, + ), +} + +import typing +if typing.TYPE_CHECKING: + PositionalSharding = _deprecated_PositionalSharding + GSPMDSharding = _deprecated_GSPMDSharding +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/tests/array_test.py b/tests/array_test.py index f097497cef51..87227d4d61e2 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -899,7 +899,7 @@ def test_op_sharding_indices(self, pspec): shape = (8, 4) mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) - ops = jax.sharding.GSPMDSharding( + ops = GSPMDSharding( list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape))) self.assertDictEqual( ops.devices_indices_map(shape), mps.devices_indices_map(shape)) @@ -975,7 +975,7 @@ def test_gspmd_sharding_repr(self): op.tile_assignment_dimensions = [4, 1, 2] op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7] op.replicate_on_last_tile_dim = True - s = jax.sharding.GSPMDSharding(jax.devices(), op) + s = GSPMDSharding(jax.devices(), op) # memory kind also appears in the repr but only for TPU. self.assertIn( 'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 ' @@ -983,7 +983,7 @@ def test_gspmd_sharding_repr(self): op2 = xc.OpSharding() op2.type = xc.OpSharding.Type.REPLICATED - s2 = jax.sharding.GSPMDSharding(jax.devices(), op2) + s2 = GSPMDSharding(jax.devices(), op2) # memory kind also appears in the repr but only for TPU. self.assertIn('GSPMDSharding({replicated}', repr(s2)) @@ -1008,7 +1008,7 @@ def test_positional_sharding_op_sharding_lowering( mps = jax.sharding.NamedSharding(mesh, pspec) devices = jax.local_devices()[:8] # Taking up to 8 devices - devices_sharding = jax.sharding.PositionalSharding(devices) + devices_sharding = PositionalSharding(devices) devices_sharding = devices_sharding.reshape(shape).replicate(axes) if transpose: devices_sharding = devices_sharding.T @@ -1110,7 +1110,7 @@ def test_devices_sharding_respects_init_mesh_shape(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) - devices_sharding = jax.sharding.PositionalSharding(mesh.devices) + devices_sharding = PositionalSharding(mesh.devices) op1 = mps._to_xla_hlo_sharding(len(value_shape)) op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) @@ -1129,7 +1129,7 @@ def test_pmap_sharding_repr(self): def test_positional_sharding_repr(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') - s = jax.sharding.PositionalSharding(jax.devices()).reshape(jax.device_count(), 1) + s = PositionalSharding(jax.devices()).reshape(jax.device_count(), 1) repr(s) # doesn't crash str(s) # doesn't crash @@ -1200,9 +1200,9 @@ def test_are_shardings_equivalent(self): op1 = xc.OpSharding() op1.type = xc.OpSharding.Type.REPLICATED - s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1) + s6 = GSPMDSharding([jax.devices()[0]], op1) - s7 = jax.sharding.GSPMDSharding(jax.devices(), op1) + s7 = GSPMDSharding(jax.devices(), op1) # The OpSharding is replicated but the Sharding itself are on different # devices. 
@@ -1212,7 +1212,7 @@ def test_are_shardings_equivalent(self):
     op2.type = xc.OpSharding.Type.OTHER
     op2.tile_assignment_devices = [0, 1]
     op2.tile_assignment_dimensions = [2, 1]
-    s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2)
+    s8 = GSPMDSharding(list(mesh2.devices.flat), op2)

     self.assertTrue(s1.is_equivalent_to(s6, 2))
     self.assertTrue(s5.is_equivalent_to(s8, 2))
@@ -1225,7 +1225,7 @@
     op3.tile_assignment_devices = [0, 1]
     op3.tile_assignment_dimensions = [1, 1, 2]
     op3.replicate_on_last_tile_dim = True
-    s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3)
+    s10 = GSPMDSharding(list(mesh2.devices.flat), op3)

     self.assertTrue(s9.is_equivalent_to(s10, 2))

@@ -1444,6 +1444,13 @@ def test_memory_kind_with_abstract_mesh(self):
         ValueError, 'Got invalid memory kind'):
       NamedSharding(abstract_mesh, P(), memory_kind='weird_device')

+  def test_pos_gspmd_sharding_warnings(self):
+    with self.assertWarns(DeprecationWarning):
+      jax.sharding.PositionalSharding(jax.devices())
+
+    with self.assertWarns(DeprecationWarning):
+      jax.sharding.GSPMDSharding.get_replicated(jax.devices())
+

 @jtu.with_config(jax_use_shardy_partitioner=True)
 class ShardyShardingTest(jtu.JaxTestCase):
diff --git a/tests/pickle_test.py b/tests/pickle_test.py
index 185eebd90726..a3dc5be6e11c 100644
--- a/tests/pickle_test.py
+++ b/tests/pickle_test.py
@@ -28,6 +28,7 @@
 from jax.interpreters import pxla
 from jax._src import test_util as jtu
 from jax._src.lib import xla_client as xc
+from jax._src.sharding_impls import GSPMDSharding
 import numpy as np

@@ -182,7 +183,7 @@ def test_pickle_pmap_sharding(self):
     self.assertEqual(s, pickle.loads(pickle.dumps(s)))

   def test_pickle_gspmd_sharding(self):
-    s = jax.sharding.GSPMDSharding.get_replicated(jax.devices())
+    s = GSPMDSharding.get_replicated(jax.devices())
     self.assertEqual(s, pickle.loads(pickle.dumps(s)))

 @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")

From c0d97a6872cd965757af5e9fde98e6e96d1bdb33 Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Fri, 11 Apr 2025 13:09:07 -0700
Subject: [PATCH 0579/1769] Roll back: the removed type annotations appear to
 be used, and are actually defined in Python via a patch.

Reverts b1c96d47ed9876a74ee2686234201aacd7cd7791

PiperOrigin-RevId: 746565341
---
 jaxlib/xla/xla_extension/__init__.pyi | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi
index 791bfe44e85f..de9bb02f6343 100644
--- a/jaxlib/xla/xla_extension/__init__.pyi
+++ b/jaxlib/xla/xla_extension/__init__.pyi
@@ -736,6 +736,10 @@ class LoadedExecutable:
   def local_devices(self) -> List[Device]: ...
   def size_of_generated_code_in_bytes(self) -> int: ...
   def delete(self) -> None: ...
+  def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ...
+  def execute_with_token(
+      self, arguments: Sequence[ArrayImpl]
+  ) -> Tuple[List[ArrayImpl], Token]: ...
   def execute_sharded(
       self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ...
   ) -> ExecuteResults: ...

From c90751bc54f9cdfc9cc582e2f41b44dffb101960 Mon Sep 17 00:00:00 2001
From: ywrt
Date: Sat, 12 Apr 2025 07:20:39 +1000
Subject: [PATCH 0580/1769] Fix typo in jax.lax.linalg.symmetric_product
 description

Missing space in '..math::' meant that the math wasn't rendering
correctly.
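For reference, Sphinx only recognizes the directive when the two dots are
followed by a space. An illustrative before/after, mirroring the one-line
fix below:

    ..math::     <- treated as plain text, so the math is not rendered
    .. math::    <- rendered as display math
        \alpha \, A \, A^T + \beta \, C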
--- jax/_src/lax/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 5c3b962f6ba2..a3e0a71a671c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -528,7 +528,7 @@ def symmetric_product( Computes the symmetric product - ..math:: + .. math:: \alpha \, A \, A^T + \beta \, C where :math:`A` is a rectangular matrix and :math:`C` is a symmetric matrix. From 6fc78a5a6db6492a2e01555a3c7fd7da0eb33391 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Apr 2025 14:41:12 -0700 Subject: [PATCH 0581/1769] Deprecate jax.lax.infeed and jax.lax.outfeed. These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users. PiperOrigin-RevId: 746595857 --- CHANGELOG.md | 2 ++ jax/lax/__init__.py | 47 ++++++++++++++++++++++++++++++++++++++++++++ tests/infeed_test.py | 4 ++++ 3 files changed, 53 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bac678ca14f..2cb99e58db45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `Array` directly to the `from_dlpack` function of another framework. If you need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an array. + * `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and + `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 53d8790874ca..e8ec74f59a7d 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -395,3 +395,50 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p + +import jax._src.lax.lax + +_deprecations = { + "infeed": ( + ( + "jax.lax.infeed was deprecated in JAX v0.6.0 and will be removed in" + " JAX v0.7.0." + ), + jax._src.lax.lax.infeed, + ), + "infeed_p": ( + ( + "jax.lax.infeed_p was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.infeed_p, + ), + "outfeed": ( + ( + "jax.lax.outfeed was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.outfeed, + ), + "outfeed_p": ( + ( + "jax.lax.outfeed_p was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.outfeed_p, + ), +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + infeed = jax._src.lax.lax.infeed + infeed_p = jax._src.lax.lax.infeed_p + outfeed = jax._src.lax.lax.outfeed + outfeed_p = jax._src.lax.lax.outfeed_p +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 060502ae68cd..79d4dc038fc2 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -78,6 +78,8 @@ def f(x): self.assertAllClose(f(x), to_infeed) @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. 
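+  # The APIs deprecated above now emit DeprecationWarning; silence it in the
+  # remaining tests until the removal in JAX v0.7.0 noted in the CHANGELOG.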
+ @jtu.ignore_warning(category=DeprecationWarning, + message=".*(infeed|outfeed) was deprecated.*") def testInfeedThenOutfeed(self): @jax.jit @@ -99,6 +101,8 @@ def f(x): execution.join() self.assertAllClose(out, y + np.float32(1)) + @jtu.ignore_warning(category=DeprecationWarning, + message=".*(infeed|outfeed) was deprecated.*") def testInfeedThenOutfeedInALoop(self): def doubler(_, token): From b2a8df718377ff94bfe8630a6e005a0242f623d5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 11 Apr 2025 15:14:30 -0700 Subject: [PATCH 0582/1769] Add the `method` argument to `jax.numpy.isin` stub. This parameter is available from https://github.com/google/jax/pull/23040 and documented in https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isin.html. PiperOrigin-RevId: 746606206 --- jax/numpy/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 259f6e3ed2ee..df6454c9a1f1 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -583,7 +583,7 @@ def iscomplexobj(x: Any) -> builtins.bool: ... def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... def isin(element: ArrayLike, test_elements: ArrayLike, - assume_unique: builtins.bool = ..., invert: builtins.bool = ...) -> Array: ... + assume_unique: builtins.bool = ..., invert: builtins.bool = ..., method: str = ...) -> Array: ... def isinf(x: ArrayLike, /) -> Array: ... def isnan(x: ArrayLike, /) -> Array: ... def isneginf(x: ArrayLike, /) -> Array: ... From b3f49e42d9d6e0c7510f1fbd444b560b58427eca Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 11 Apr 2025 21:14:03 +0000 Subject: [PATCH 0583/1769] Re-landing #27937 with fewer bugs and more tests. --- jax/_src/lax/control_flow/loops.py | 42 ++++++++++++--- tests/checkify_test.py | 4 +- tests/jaxpr_effects_test.py | 4 +- tests/lax_control_flow_test.py | 86 +++++++++++++++++++++++++++++- 4 files changed, 122 insertions(+), 14 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index babffa1d47d7..de53ee14ca0d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1461,9 +1461,34 @@ def _create_jaxpr(init_val): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') + + # If the body forwards an input carry to an output carry, *and* it's not used + # by the cond fun, it can be moved to be a body const. Doing so can lead to + # efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch + # the carry too, but not the body consts. 
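+  # (Illustrative sketch, for intuition only: in
+  #    while_loop(lambda c: c[0] < n, lambda c: (c[0] + 1, c[1] @ c[2], c[2]), (0, x, w))
+  #  the `w` component is forwarded unchanged and never read by the cond, so
+  #  it can be hoisted into the body consts.)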
+ body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr) + carry_nofwd = [len(body_consts) + i != f for i, f in enumerate(body_fwd)] + cond_jaxpr_, keep_cond = pe.dce_jaxpr( + cond_jaxpr.jaxpr, [True], [True] * len(cond_consts) + carry_nofwd) + _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) + move_to_const = _map(operator.not_, keep_cond_carry) + + if any(move_to_const): + cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) + body_jaxpr = pe.prune_closed_jaxpr_outputs( + body_jaxpr, [not m for m in move_to_const]) + body_jaxpr = pe.move_binders_to_front( + body_jaxpr, [False] * len(body_consts) + move_to_const) + init_vals, new_body_consts = partition_list(move_to_const, init_vals) + body_consts = [*new_body_consts, *body_consts] + outs = while_p.bind(*cond_consts, *body_consts, *init_vals, cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr, body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) + + if any(move_to_const): + outs = pe.merge_lists(move_to_const, outs, new_body_consts) + return tree_unflatten(body_tree, outs) @@ -1839,18 +1864,19 @@ def cond(args): pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): - return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) + return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args) def new_cond(pred_args): - pred, _ = pred_args + pred, *_ = pred_args return pred def new_body(pred_args): - _, args = pred_args - args = body(args) - pred = cond(args) - return pred, args + _, cond_consts, body_consts, carry = pred_args + carry = body((*body_consts, *carry)) + pred = cond((*cond_consts, *carry)) + return pred, cond_consts, body_consts, carry def fun(*args): - pred = cond(args) - _, out = while_loop(new_cond, new_body, (pred, args)) + cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) + pred = cond((*cond_consts, *carry)) + *_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry)) return out return mlir.lower_fun(fun)(ctx, *args) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 5ea99d20a2ab..2f4b7d511fbe 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -492,8 +492,8 @@ def f(x: jax.Array) -> jax.Array: def test_while_loop_body_and_cond_error(self): def while_cond(val): i, cond_val, _ = val - _ = jnp.sin(cond_val) - return i < 2 + j = jnp.sin(cond_val) + return i + (0. 
* j) < 2 # don't let the sin value be dead code def while_body(val): i, cond_val, body_val = val diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index c331bfaf438a..d5574f8a9a1d 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -947,7 +947,7 @@ def make_fun(index): def f(x): def body(y): input_effect(x, y, index=index) - return y + return 2 * y lax.while_loop(lambda _: True, body, y) return f jaxpr = jax.make_jaxpr(make_fun(0))(0) @@ -959,7 +959,7 @@ def body(y): def f(x): def body(y): input_effect(x, y, index=1) - return y + return 2 * y lax.while_loop(lambda _: (x > 0).all(), body, y) jaxpr = jax.make_jaxpr(f)(0) self.assertIn(InputEffect(0), jaxpr.effects) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index a987d9e4c192..0c36d3ff7d88 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -33,7 +33,7 @@ from jax import random from jax._src import test_util as jtu from jax import tree_util -from jax._src.util import unzip2 +from jax._src.util import unzip2, split_list from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp @@ -2362,7 +2362,7 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x), + x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), 1., lambda x: x) elif loop == "fori_inside_scan": func = lambda x: lax.scan( @@ -3122,6 +3122,88 @@ def body(c): return x + y jax.linearize(f, 1., 2.) # don't crash + def test_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4700 + def foo(w, x, c_max): + def while_cond(val): + c, x, w = val + return c < c_max + + def while_body(val): + c, x, w = val + return c + 1, x @ w, w + + _, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w)) + return w, x + + w = jnp.ones((2, 2)) + xs = jnp.ones((4, 2)) + c_maxs = jnp.arange(4) + w_, _ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0) + )(w, xs, c_maxs) # doesn't crash + self.assertAllClose(w, w_, check_dtypes=False) + + @parameterized.parameters(itertools.product(range(3), repeat=5)) + @jtu.run_on_devices("cpu") + def test_while_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds_cond_uses, + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds): + + num_fwds = (num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use + + num_noninplace_fwds) + num_carry = num_fwds + 4 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def cond_fun(c): + i, c = c + c = [c[i] for i in iperm] + c, _ = split_list(c, [num_inplace_fwds_cond_uses]) + return (i < 2) + (0. 
* jnp.array(sum(c))).astype(bool) + + def body_fun(c): + i, c = c + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return (i + 1, new_c) + + i, outs = jax.lax.while_loop(cond_fun, body_fun, (0, init_vals)) + self.assertEqual(i, 2) + _, outs_ref = body_fun(body_fun((0, init_vals))) + self.assertAllClose(outs, outs_ref, check_dtypes=False) + + def test_while_constification_correctness_manually(self): + # regression test for a particular index-offset logic bug + + def cond_fun(c): + # cond doesn't use first or third element of the carry + _, i, _ = c + return i == 0 + + def body_fun(c): + # two body consts + for _ in range(2): jnp.sin(np.zeros(3)) + # first element of the carry is forwarded to third element of the carry + return 0., 1., c[0] + + outs = jax.lax.while_loop(cond_fun, body_fun, (5., 0., 3.14)) + self.assertAllClose(outs, (0., 1., 5.)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 0fa732ea4558571e629a78b3578382740cd87259 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 11 Apr 2025 15:49:04 -0700 Subject: [PATCH 0584/1769] [ragged-paged-attn][NFC] Make validate_inputs functions take same inputs as attention call. PiperOrigin-RevId: 746616128 --- .../pallas/ops/tpu/ragged_paged_attention.py | 77 +++++++++++++++++-- .../pallas/tpu_ragged_paged_attention_test.py | 8 +- 2 files changed, 73 insertions(+), 12 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 5c3f17206eaa..760a5eef089b 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -83,8 +83,17 @@ def ref_ragged_paged_attention( soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, ): - validate_static_inputs( - queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap + static_validate_inputs( + queries, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, ) if mask_value is None: mask_value = DEFAULT_MASK_VALUE @@ -130,17 +139,39 @@ def ref_ragged_paged_attention( # Expect to run these checks during runtime. -def validate_dynamic_inputs( +def dynamic_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, sliding_window: int | None = None, soft_cap: float | None = None, + mask_value: float | None = None, + # Kernel specific params. 
+    num_kv_pages_per_block: int | None = None,
+    num_queries_per_block: int | None = None,
+    vmem_limit_bytes: int | None = None,
 ):
-  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap)
+  static_validate_inputs(
+      q,
+      kv_pages,
+      kv_lens,
+      page_indices,
+      cu_q_lens,
+      num_seqs,
+      sm_scale=sm_scale,
+      sliding_window=sliding_window,
+      soft_cap=soft_cap,
+      mask_value=mask_value,
+      num_kv_pages_per_block=num_kv_pages_per_block,
+      num_queries_per_block=num_queries_per_block,
+      vmem_limit_bytes=vmem_limit_bytes,
+  )
   max_num_batched_tokens = q.shape[0]
   page_size = kv_pages.shape[1]
   max_num_seqs, pages_per_seq = page_indices.shape
@@ -168,15 +199,23 @@

 # Expect to run these checks during compile time.
-def validate_static_inputs(
+def static_validate_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32[1]
+    num_seqs: jax.Array,  # i32[1]
+    *,
+    # These inputs are optional. If not specified, we will not validate them.
+    sm_scale: float | None = None,
     sliding_window: int | None = None,
     soft_cap: float | None = None,
+    mask_value: float | None = None,
+    # Kernel specific params.
+    num_kv_pages_per_block: int | None = None,
+    num_queries_per_block: int | None = None,
+    vmem_limit_bytes: int | None = None,
 ):
   _, num_q_heads, head_dim = q.shape
   _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
@@ -215,6 +254,14 @@
     raise ValueError(f"{sliding_window=} must be positive.")
   if soft_cap is not None and soft_cap == 0.0:
     raise ValueError(f"{soft_cap=} must not be 0.0.")
+  if num_kv_pages_per_block is not None and num_kv_pages_per_block <= 0:
+    raise ValueError(f"{num_kv_pages_per_block=} must be positive.")
+  if num_queries_per_block is not None and num_queries_per_block <= 0:
+    raise ValueError(f"{num_queries_per_block=} must be positive.")
+  if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
+    raise ValueError(f"{vmem_limit_bytes=} must be positive.")
+  del sm_scale  # No constraints on sm_scale.
+  del mask_value  # No constraints on mask_value.


 def ragged_paged_attention_kernel(
@@ -676,7 +723,21 @@
   Returns:
     The output of the attention.
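+
+  Example (an illustrative sketch; all arguments must satisfy the constraints
+  checked by static_validate_inputs above):
+
+    out = ragged_paged_attention(
+        q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
+        sm_scale=q.shape[-1] ** -0.5)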
""" - validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap) + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) if mask_value is None: mask_value = DEFAULT_MASK_VALUE num_q_tokens, num_q_heads, head_dim = q.shape diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index 8d48bc281400..e617a1b3b06c 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -19,9 +19,9 @@ import jax from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( + dynamic_validate_inputs, ragged_paged_attention, ref_ragged_paged_attention, - validate_dynamic_inputs, ) import jax.numpy as jnp @@ -91,15 +91,15 @@ def _test_ragged_paged_attention( num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) - validate_dynamic_inputs( + dynamic_validate_inputs( q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, - sliding_window, - soft_cap, + sliding_window=sliding_window, + soft_cap=soft_cap, ) actual_num_q_tokens = cu_q_lens[num_seqs[0]] From 29f65f04ed492afa97a09160f29a7e80ffbf1db3 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 11 Apr 2025 23:22:26 +0000 Subject: [PATCH 0585/1769] re-index jaxpr input effects in move_binders_to_front --- jax/_src/interpreters/partial_eval.py | 8 ++++++-- tests/lax_control_flow_test.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 516579b7ac39..f8ce92e7f97f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1559,10 +1559,14 @@ def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] 
) -> ClosedJaxpr: assert len(closed_jaxpr.in_avals) == len(to_move) new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) + id_map = {id(v): i for i, v in enumerate(new_invars)} + idx_map = {i: id_map[id(v)] for i, v in enumerate(closed_jaxpr.jaxpr.invars)} + new_effs = {e.replace(input_index=idx_map[e.input_index]) + if isinstance(e, effects.JaxprInputEffect) else e + for e in closed_jaxpr.jaxpr.effects} new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - closed_jaxpr.jaxpr.effects, - closed_jaxpr.jaxpr.debug_info) + new_effs, closed_jaxpr.jaxpr.debug_info) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 0c36d3ff7d88..242b0548023e 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3153,7 +3153,7 @@ def test_while_constification_correctness( num_inplace_fwds_cond_doesnt_use, num_noninplace_fwds): - num_fwds = (num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use + + num_fwds = (num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use + num_noninplace_fwds) num_carry = num_fwds + 4 From 8afc833c24011398e5fa5bd2272e411dca804004 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 11 Apr 2025 17:41:45 -0700 Subject: [PATCH 0586/1769] Rename is_closed to is_open in the shardy shardings PiperOrigin-RevId: 746645422 --- jax/_src/callback.py | 2 +- jax/_src/debugging.py | 2 +- jax/_src/interpreters/mlir.py | 2 +- jax/_src/named_sharding.py | 13 ++++++------- jax/_src/sharding_impls.py | 6 +++--- jax/experimental/shard_map.py | 2 +- tests/array_test.py | 12 ++++++------ tests/pjit_test.py | 6 +++--- 8 files changed, 22 insertions(+), 23 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 20cdec781265..630539ab0db8 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -167,7 +167,7 @@ def _callback_op_sharding( sharding_impls.SdyArraySharding( mesh_shape=(), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) + sharding_impls.SdyDimSharding(axes=[], is_open=False) ] * avals_out[0].ndim, logical_device_ids=())]) else: diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 79b67219b490..dc140c22650d 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -170,7 +170,7 @@ def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params) sharding_impls.SdyArraySharding( mesh_shape=(), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) + sharding_impls.SdyDimSharding(axes=[], is_open=False) ] * ctx.avals_out[0].ndim, logical_device_ids=())]) else: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 84bbd685258e..becdfd46bd92 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1795,7 +1795,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: s = SdyArraySharding( mesh_shape=None, dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim) + sharding_impls.SdyDimSharding(axes=[], is_open=i < aval.ndim) for i in range(physical_ndim) ]) return wrap_with_sharding_op(ctx, val, aval, s) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index b712e69dd5b9..9bdadd8a8570 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -42,7 +42,7 @@ def __init__(self, mesh: mesh_lib.Mesh): self.mesh = mesh def 
_to_sdy_sharding(self, ndim: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=False) + dim_shardings = [SdyDimSharding(axes=[], is_open=True) for _ in range(ndim)] return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) @@ -242,11 +242,11 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=True) + dim_shardings = [SdyDimSharding(axes=[], is_open=False) for _ in range(num_dimensions)] for i, dim_spec in enumerate(self.spec): if dim_spec is PartitionSpec.UNCONSTRAINED: - dim_shardings[i].is_closed = False + dim_shardings[i].is_open = True elif dim_spec is None: # Already empty and closed sharding. pass @@ -274,14 +274,13 @@ def get_array_mapping( @dataclasses.dataclass class SdyDimSharding: axes: Sequence[str] - is_closed: bool + is_open: bool priority: int | None = None def build(self) -> sdy.DimensionShardingAttr: return sdy.DimensionShardingAttr.get( [sdy.AxisRefAttr.get(axis) for axis in self.axes], - is_closed=self.is_closed, - priority=self.priority) + is_closed=not self.is_open, priority=self.priority) def __repr__(self): return f'SdyDimSharding({self._custom_repr()})' @@ -289,7 +288,7 @@ def __repr__(self): def _custom_repr(self): axes_repr = ', '.join(f"'{a}'" for a in self.axes) open_repr = '' - if not self.is_closed: + if self.is_open: open_repr = ', ?' if self.axes else '?' priority_repr = '' if self.priority is None else f'p{self.priority}' return f'{{{axes_repr}{open_repr}}}{priority_repr}' diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index d95b12f244ba..e462ee2f0ba0 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -101,8 +101,8 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): dim_shardings, used_axes = [], [] # type: ignore for d in sdy_sharding.dimension_shardings: # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? - dim_shardings.append(SdyDimSharding(axes=[], is_closed=False) - if not d.axes and d.is_closed else d) + dim_shardings.append(SdyDimSharding(axes=[], is_open=True) + if not d.axes and not d.is_open else d) used_axes.extend(d.axes) remaining_axes = set(mesh.axis_names) - set(used_axes) # Sort wrt mesh axis names so order is deterministic and doesn't hang in @@ -185,7 +185,7 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return replicated_hlo_sharding def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True) + sdy_dim_sharding = [SdyDimSharding(axes=[], is_open=False) for _ in range(num_dimensions)] return SdyArraySharding(None, sdy_dim_sharding) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 2d114c6c3a2b..a6998d29e897 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -706,7 +706,7 @@ def _shardy_shard_map_sharding( for dim_sharding in sdy_sharding.dimension_shardings: # Only allow dimensions which have no sharding to be auto-sharded. 
if not dim_sharding.axes: - dim_sharding.is_closed = False + dim_sharding.is_open = True return sdy_sharding diff --git a/tests/array_test.py b/tests/array_test.py index 87227d4d61e2..1780213bcc61 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1464,9 +1464,9 @@ def test_long_axis_names(self): SdyArraySharding( mesh.shape_tuple, [SdyDimSharding( - ('sequence', 'data'), True), - SdyDimSharding(('model',), True), - SdyDimSharding([], True)])) + ('sequence', 'data'), False), + SdyDimSharding(('model',), False), + SdyDimSharding([], False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( @@ -1483,9 +1483,9 @@ def test_unconstrained(self): sdy_sharding, SdyArraySharding( mesh.shape_tuple, - [SdyDimSharding([], True), - SdyDimSharding([], False), - SdyDimSharding(('x',), True)])) + [SdyDimSharding([], False), + SdyDimSharding([], True), + SdyDimSharding(('x',), False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index f90b02cbdaf7..0e7867cc2cdd 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -8188,8 +8188,8 @@ def test_array_sharding_repr_with_priority(self): sharding = sharding_impls.SdyArraySharding( mesh_shape=(('data', 4), ('model', 8), ('expert', 2)), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True), - sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)]) + sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_open=False), + sharding_impls.SdyDimSharding(axes=['model'], is_open=True, priority=2)]) self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])") def test_array_sharding_repr_with_logical_ids(self): @@ -8202,7 +8202,7 @@ def test_array_sharding_repr_with_logical_ids(self): def test_dimension_sharding_repr(self): dim_sharding = sharding_impls.SdyDimSharding( - axes=['data', 'model'], is_closed=False, priority=2) + axes=['data', 'model'], is_open=True, priority=2) self.assertEqual(repr(dim_sharding), "SdyDimSharding({'data', 'model', ?}p2)") From e1cad34522638db43223a0d8d3d7b96f2f9ea2f1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 11 Apr 2025 18:58:12 -0700 Subject: [PATCH 0587/1769] Add `ChunkedCausalMask` for Splash Attention to support attention masking similar to Llama4. Llama4 uses (interleaved) chunk attention to support long context. PiperOrigin-RevId: 746661156 --- .../splash_attention/splash_attention_mask.py | 100 +++++++++- .../pallas/tpu_splash_attention_mask_test.py | 184 ++++++++++++++++++ 2 files changed, 281 insertions(+), 3 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index eab2a695dc02..e43f30e7791c 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -92,6 +92,35 @@ def make_local_attention_mask( return mask.astype(np.bool_) +def make_chunk_attention_mask( + shape: tuple[int, int], chunk_size: int +) -> np.ndarray: + """Makes a chunked causal attention mask. + + Args: + shape: The desired shape of the mask (q_seq_len, kv_seq_len). + chunk_size: The size of the attention chunks. + + Returns: + A boolean mask of shape `mask_shape` where True indicates attention is + allowed according to chunked causal rules, and False otherwise. 
+
+  Raises:
+    ValueError: If chunk_size is not positive.
+  """
+  if chunk_size <= 0:
+    raise ValueError('chunk_size must be positive')
+
+  q_seq_len, kv_seq_len = shape
+  q_idx = np.arange(q_seq_len, dtype=np.int32)
+  kv_idx = np.arange(kv_seq_len, dtype=np.int32)
+
+  # Chunked causal rule: same chunk, and the key must not come after the query.
+  same_chunk = (q_idx[:, None] // chunk_size) == (kv_idx[None, :] // chunk_size)
+  mask = same_chunk & (q_idx[:, None] >= kv_idx[None, :])
+  return mask
+
+
 def make_random_mask(
     shape: tuple[int, int], sparsity: float, seed: int
 ) -> np.ndarray:
@@ -196,15 +225,20 @@ def __hash__(self):
 class _ComputableMask(Mask):
   """Superclass for all masks that can be computed inside the kernel using a callable object.
 
+  This class is designed to be used with Splash Attention.
+  It allows the mask logic to be computed on-the-fly or fused into the attention
+  kernel, avoiding the memory cost of materializing the full
+  (sequence_length, sequence_length) boolean mask array, which can be excessive
+  for long sequences.
+
   Attributes:
     _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
     offset: Offset of q start wrt kv. A positive offset shifts the bottom
      triangle upward, a negative one shifts it downward. A negative offset
      makes the first 'offset' rows of the attention matrix all 0s which leads
      to undefined softmax.
-    q_sequence: Indices of Q sequence.
-      q_sequence is reused across __getitem__ calls which is important for
-      compile-time performance.
+    q_sequence: Indices of Q sequence. q_sequence is reused across __getitem__
+      calls which is important for compile-time performance.
     mask_function: Function used by the SplashAttention kernel to compute the
       mask rather than loading it.
   """
@@ -314,6 +348,66 @@ def __hash__(self):
     ))
 
 
+class ChunkedCausalMask(_ComputableMask):
+  """Lazy chunked causal mask.
+
+  Attention is causal within each chunk: tokens in the same chunk (0, K),
+  (K, 2K), (2K, 3K), ... attend to each other, but not across chunks.
+  Llama4 models use interleaved chunk attention along with global attention.
+
+  Attributes:
+    chunk_size: The size of each attention chunk.
+  """
+
+  chunk_size: int
+
+  def __init__(
+      self,
+      shape: tuple[int, int],
+      chunk_size: int,
+      shard_count: int = 1,
+  ):
+    if chunk_size <= 0:
+      raise ValueError('chunk_size must be positive')
+    self.chunk_size = chunk_size
+
+    # Define the mask function for chunk attention
+    def chunked_causal_mask_function(q_ids, kv_ids):
+      """Computes the mask logic for the given slice indices."""
+      # Condition 1: Same chunk
+      same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size)
+
+      # Condition 2: Causal
+      causal = q_ids >= kv_ids
+
+      return same_chunk & causal
+
+    super().__init__(
+        shape=shape,
+        mask_function=chunked_causal_mask_function,
+        shard_count=shard_count,
+    )
+
+  def __eq__(self, other: object):
+    if not isinstance(other, type(self)):
+      return NotImplemented
+
+    return (
+        self.shape == other.shape
+        and self.chunk_size == other.chunk_size
+        and np.array_equal(self.q_sequence, other.q_sequence)
+    )
+
+  def __hash__(self):
+    return hash((
+        type(self),
+        self.shape,
+        self.chunk_size,
+        self.q_sequence.tobytes() if self.q_sequence is not None else None,
+    ))
+
+
 class LocalMask(Mask):
   """Lazy local mask, prevents model from attending to tokens outside window.
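For reference while reviewing this patch: the dense and lazy paths added above are
meant to agree, since `make_chunk_attention_mask` materializes the same predicate
that `chunked_causal_mask_function` evaluates lazily inside the kernel. A standalone
NumPy sketch of that rule (not part of the patch; names here are illustrative):

```python
import numpy as np

def reference_chunked_causal(shape, chunk_size):
  # Same rule as chunked_causal_mask_function above: a query position may
  # attend to a key position only if both fall in the same chunk and the
  # key does not come after the query.
  q = np.arange(shape[0], dtype=np.int32)[:, None]
  kv = np.arange(shape[1], dtype=np.int32)[None, :]
  return ((q // chunk_size) == (kv // chunk_size)) & (q >= kv)

# When chunk_size >= both sequence lengths, every position is in chunk 0,
# so the mask degenerates to an ordinary causal mask, matching the
# "chunk_size_greater_equal_seqlen" subtest in the test diff below.
assert (reference_chunked_causal((4, 4), 4)
        == np.tril(np.ones((4, 4), dtype=np.bool_))).all()
```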
diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index 5379eb10990f..7c4b53529169 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -44,6 +44,15 @@ def _make_local_attention_mask(*args, **kwargs): return mask_lib.make_local_attention_mask(*args, **kwargs) +def _make_lazy_chunked_causal_mask(shape, chunk_size): + mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + return mask[:, :] + + +def _make_chunked_causal_mask(shape, chunk_size): + return mask_lib.make_chunk_attention_mask(shape=shape, chunk_size=chunk_size) + + class SplashAttentionMaskTest(jtu.JaxTestCase): @parameterized.parameters([_make_lazy_causal_mask, _make_causal_mask]) @@ -412,6 +421,181 @@ def test_lazy_local_mask_chunking( block_size, ) + @parameterized.parameters( + [_make_lazy_chunked_causal_mask, _make_chunked_causal_mask] + ) + def test_chunked_causal_mask(self, make_chunked_mask): + """Tests the chunked causal mask logic for various shapes and chunk sizes.""" + with self.subTest("unit"): + expected = np.array([[1]], dtype=np.bool_) + actual = make_chunked_mask(shape=(1, 1), chunk_size=1) + self.assertArraysEqual(actual, expected) + actual = make_chunked_mask(shape=(1, 1), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_exact_chunks"): + # Chunk 0: [0, 1], Chunk 1: [2, 3] + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_uneven_chunks"): + expected = np.array( + [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(5, 5), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("wide_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 6), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("tall_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(6, 4), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_1"): + # Should only allow self-attention q==k and chunk_size == 1 + expected = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=1) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_greater_equal_seqlen"): + # Should behave like a normal causal mask + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + # Test chunk_size == seqlen + actual_eq = make_chunked_mask(shape=(4, 4), chunk_size=4) + self.assertArraysEqual(actual_eq, expected) + # Test chunk_size > seqlen + actual_gt = make_chunked_mask(shape=(4, 4), chunk_size=5) + self.assertArraysEqual(actual_gt, expected) + + @parameterized.product( + block_size=[(128, 128), (256, 128), (128, 256)], + shape=[(512, 512), (512, 1024), (1024, 512)], + chunk_size=[64, 128, 256, 512, 1024], + ) + def 
test_lazy_chunked_causal_mask_chunking( + self, + block_size: tuple[int, int], + shape: tuple[int, int], + chunk_size: int, + ): + """Compares lazy chunked mask evaluation against the dense version block-by-block.""" + q_len, kv_len = shape + # Adjust block size if it exceeds shape dimensions + adjusted_block_size = ( + min(block_size[0], q_len), + min(block_size[1], kv_len), + ) + + if ( + q_len % adjusted_block_size[0] != 0 + or kv_len % adjusted_block_size[1] != 0 + ): + self.skipTest( + f"Shape {shape} not divisible by block_size {adjusted_block_size}" + ) + + dense_mask = _make_chunked_causal_mask(shape=shape, chunk_size=chunk_size) + lazy_mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + self._compare_masks( + dense_mask, + lazy_mask, + adjusted_block_size, + ) + + def test_chunked_causal_mask_invalid_chunk_size(self): + """Tests that invalid chunk_size raises ValueError.""" + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=0) + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=-1) + with self.assertRaises(ValueError): + mask_lib.make_chunk_attention_mask(shape=(10, 10), chunk_size=0) + + def test_chunked_causal_mask_minimal_equality_hash(self): + """Tests for __eq__ and __hash__ of ChunkedCausalMask.""" + shape1, chunk_size1 = (128, 256), 16 + shape2, chunk_size2 = (128, 128), 32 # Different shape/chunk_size + + # Create three masks: two identical, one with different shape/chunk_size. + mask1 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask2 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask_diff_shape = mask_lib.ChunkedCausalMask( + shape=shape2, chunk_size=chunk_size1 + ) + mask_diff_chunk = mask_lib.ChunkedCausalMask( + shape=shape1, chunk_size=chunk_size2 + ) + other_obj = object() + + # Test __eq__ + self.assertEqual(mask1, mask2) + self.assertNotEqual(mask1, mask_diff_shape) + self.assertNotEqual(mask1, mask_diff_chunk) + self.assertNotEqual(mask1, other_obj) + + # Test __hash__ of identical masks + self.assertEqual(hash(mask1), hash(mask2)) + + mask_set = {mask1, mask2, mask_diff_chunk} + self.assertLen(mask_set, 2) # mask1 and mask2 are duplicates + self.assertIn(mask1, mask_set) + self.assertIn(mask_diff_chunk, mask_set) + self.assertNotIn(mask_diff_shape, mask_set) + def test_using_logical_operators_raises_exception(self): mask_1 = mask_lib.NumpyMask( mask_lib.make_random_mask((256, 256), 0.5, seed=1) From dc10200906d7b4115d5a9eed793eb3841af41c64 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 11 Apr 2025 21:53:19 -0700 Subject: [PATCH 0588/1769] [explain-cache-miss] Improve the detection of user file names When we print explanations for tracing cache misses, we use traceback_util to ignore JAX-internal functions. Here we change the detection mechanism to use source_info_util, which has a more exhaustive list of JAX internals. This removes a lot of uninteresting explanations from a large benchmark. 
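For context, the explanations that this commit (and the later explain-cache-miss
commit below) refine are emitted as log warnings behind a configuration flag. A
minimal sketch of how an end user surfaces them (illustrative usage, not part of
the patch):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_explain_cache_misses", True)

@jax.jit
def f(x):
  return jnp.sin(x)

f(jnp.ones((3,), jnp.float32))  # First call: a "never seen function" miss.
f(jnp.ones((4,), jnp.float32))  # New input shape: a warning explains the retrace.
```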
jax-fixit
PiperOrigin-RevId: 746703003
---
 jax/_src/pjit.py  |  2 +-
 tests/api_test.py | 22 ----------------------
 2 files changed, 1 insertion(+), 23 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 1dd1c6609a62..d1e5c3bfbc54 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1172,7 +1172,7 @@ def unpack(key):
 
   debug_info = fun.debug_info
   func_filename = debug_info.func_filename
-  if func_filename and not traceback_util.include_filename(func_filename):
+  if func_filename and not source_info_util.is_user_filename(func_filename):
     return
 
   msg: list[str] = []
diff --git a/tests/api_test.py b/tests/api_test.py
index 72623279192f..a5e192a9f826 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -4557,28 +4557,6 @@ def f(x, y):
     _, msg = cm.output
     self.assertIn('another function defined on the same line', msg)
 
-  @jtu.thread_unsafe_test()  # logging is not thread-safe
-  def test_cache_miss_explanations_unpacks_transforms(self):
-    # Tests that the explain_tracing_cache_miss() function does not throw an
-    # error when unpacking `transforms` with a length greater than 3.
-    @jax.jit
-    def f(key):
-      return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32)
-
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level="WARNING") as cm:
-        f(jax.random.key(seed=123))
-
-    if is_persistent_cache_enabled():
-      # 4 warnings from tracing cache, 5-10 from persistent cache depending on
-      # the backend
-      self.assertTrue(9 <= len(cm.output) <= 15)
-      self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output))
-    else:
-      self.assertLen(cm.output, 4)
-      for msg in cm.output:
-        self.assertIn("TRACING CACHE MISS", msg)
-
   def test_cache_miss_explanations_no_source_info(self):
     # ``operator.add`` is a built-in function and does not have source info.
     with config.explain_cache_misses(True):

From 19d3d954bf417ba43690a6a3c6268ade32252985 Mon Sep 17 00:00:00 2001
From: Roy Frostig
Date: Fri, 11 Apr 2025 22:08:56 -0700
Subject: [PATCH 0589/1769] unify `stages.Executable` and `stages.XlaExecutable`

We no longer have many different implicit types conforming to
`Executable`, only `pxla.MeshExecutable` and `pxla.PmapExecutable`. Both
are `XlaExecutable` subtypes. So define just one common base class, call
it `Executable`, and inherit from just that in both concrete internal
executable subtypes.
PiperOrigin-RevId: 746706712 --- jax/_src/interpreters/pxla.py | 8 +- jax/_src/stages.py | 163 ++++++++++++++-------------------- 2 files changed, 70 insertions(+), 101 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index fb7e352bf600..2cd9f6b8d795 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1116,7 +1116,7 @@ def from_hlo(hlo: ir.Module, jaxpr_debug_info=jaxpr_debug_info).load() -class PmapExecutable(stages.XlaExecutable): +class PmapExecutable(stages.Executable): __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", "fingerprint", "in_avals", "_unloaded_executable"] @@ -1136,7 +1136,7 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable @@ -3122,7 +3122,7 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat): return tree_util.dispatch_registry.flatten(out_unflat, None) -class MeshExecutable(stages.XlaExecutable): +class MeshExecutable(stages.Executable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", @@ -3158,7 +3158,7 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 19cd0822aa58..55ba5e319e9d 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,6 +30,7 @@ """ from __future__ import annotations +import abc import functools from collections.abc import Sequence from dataclasses import dataclass @@ -61,98 +62,9 @@ CompilerOptions = dict[str, Union[str, bool]] -# -- Internal protocols - -class Executable(Protocol): - """Protocol for executables, which a user-facing ``Compiled`` encapsulates.""" - - def call(self, *args_flat) -> Sequence[Any]: - """Execute on the flat list of arguments, returning flat outputs.""" - # TODO(frostig): improve annotation (sequences of arrays/buffers) - raise NotImplementedError - - def input_shardings(self) -> Sequence[jax.sharding.Sharding]: - """Flat sequence of input shardings. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - raise NotImplementedError - - def output_shardings(self) -> Sequence[jax.sharding.Sharding]: - """Flat sequence of output shardings. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - raise NotImplementedError - - def input_layouts(self): - raise NotImplementedError - - def output_layouts(self): - raise NotImplementedError - - def as_text(self) -> str: - """A human-readable text representation of this executable. - - Intended for visualization and debugging purposes. This need not be a valid - nor reliable serialization. It is relayed directly to external callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - raise NotImplementedError - - def cost_analysis(self) -> Any: - """A summary of execution cost estimates. - - Intended for visualization and debugging purposes. 
The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def memory_analysis(self) -> Any: - """A summary of estimated memory requirements. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def runtime_executable(self) -> Any: - """An arbitrary object representation of this executable. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. - """ - raise NotImplementedError - - def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: - """Optionally constructs a fast c++ dispatcher.""" - return None - +# -- Internal types +# TODO(frostig): collapse with XlaLowering class Lowering(Protocol): """Protocol for lowerings, which a user-facing ``Lowered`` encapsulates.""" @@ -208,21 +120,37 @@ def cost_analysis(self) -> Any: raise NotImplementedError -# -- Internal adapters from XLA-related objects to the above protocols - -class XlaExecutable(Executable): +class Executable(metaclass=util.StrictABCMeta): def xla_extension_executable(self) -> xc.LoadedExecutable: - raise NotImplementedError("must override") + raise NotImplementedError( + "compiled executable carries no loaded XLA executable. It may be " + f"that {type(self)} defines an incomplete implementation.") + @abc.abstractmethod def call(self, *args_flat) -> Sequence[Any]: - raise NotImplementedError("must override") + """Execute on the flat list of arguments, returning flat outputs.""" + pass + + def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: + """Optionally constructs a fast c++ dispatcher.""" + return None def input_shardings(self) -> Sequence[jax.sharding.Sharding]: + """Flat sequence of input shardings. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ raise NotImplementedError( "compiled executable carries no input sharding information") def output_shardings(self) -> Sequence[jax.sharding.Sharding]: + """Flat sequence of output shardings. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. 
+ """ raise NotImplementedError( "compiled executable carries no output sharding information") @@ -235,6 +163,14 @@ def output_layouts(self): "compiled executable carries no input layout information") def as_text(self) -> str: + """A human-readable text representation of this executable. + + Intended for visualization and debugging purposes. This need not be a valid + nor reliable serialization. It is relayed directly to external callers. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ xla_ext_exe = self.xla_extension_executable() err_msg = ("text view unsupported on current XLA backend: " f"{type(xla_ext_exe)}") @@ -249,7 +185,19 @@ def as_text(self) -> str: else: raise - def cost_analysis(self) -> dict[str, float]: + def cost_analysis(self) -> Any: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ xla_ext_exe = self.xla_extension_executable() if hasattr(xla_ext_exe, "cost_analysis"): @@ -273,6 +221,18 @@ def cost_analysis(self) -> dict[str, float]: ) def memory_analysis(self) -> Any: + """A summary of estimated memory requirements. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ xla_ext_exe = self.xla_extension_executable() err_msg = ("memory analysis unsupported on current XLA backend: " f"{type(xla_ext_exe)}") @@ -288,6 +248,15 @@ def memory_analysis(self) -> Any: raise def runtime_executable(self) -> Any: + """An arbitrary object representation of this executable. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. + """ return self.xla_extension_executable() From 4ff78e6a0eb7b3a543b4bb72ecbb6e093399c4c5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 11 Apr 2025 23:01:55 -0700 Subject: [PATCH 0590/1769] Remove various methods from `MeshExecutable` These are thin and their implementations can be inlined directly at call sites in `XlaExecutable`. 
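The user-visible entry points are unchanged by this cleanup; only the internal
plumbing moves. A sketch of the public surface that ends up reading these
fields (the properties shown are the ones touched in the `stages.py` hunks
below; the printed values depend on your devices):

```python
import jax
import jax.numpy as jnp

# Lower and compile a trivial function, then inspect the compiled artifact.
compiled = jax.jit(lambda x: x * 2).lower(jnp.arange(4.0)).compile()
print(compiled.input_shardings)   # PyTree of input shardings
print(compiled.output_shardings)  # PyTree of output shardings
print(compiled.input_layouts)     # PyTree of input layouts
```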
Co-authored-by: Roy Frostig PiperOrigin-RevId: 746716734 --- jax/_src/interpreters/pxla.py | 14 -------------- jax/_src/stages.py | 14 ++++++++------ 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 2cd9f6b8d795..afba3cef2005 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3186,20 +3186,6 @@ def call(self, *args): self._kept_var_idx) return self.unsafe_call(*args) # pylint: disable=not-callable - def input_shardings(self) -> Sequence[JSharding]: - return self._in_shardings - - def output_shardings(self) -> Sequence[JSharding]: - return self._out_shardings - - def input_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)] - - def output_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)] - def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and not self.unsafe_call.has_unordered_effects and diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 55ba5e319e9d..3f90ae9f78bb 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -457,7 +457,7 @@ def runtime_executable(self) -> Any | None: @property def input_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.input_shardings() + shardings_flat = self._executable._in_shardings # Some input shardings got DCE'd if self.in_tree.num_leaves > len(shardings_flat): iter_shardings_flat = iter(shardings_flat) @@ -467,13 +467,14 @@ def input_shardings(self): # PyTree[sharding.Sharding] @property def output_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.output_shardings() + shardings_flat = self._executable._out_shardings return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error @property def input_layouts(self): - layouts_flat = self._executable.input_layouts() - assert all(isinstance(l, Layout) for l in layouts_flat) + dll_flat = self._executable._xla_in_layouts + layouts_flat = [Layout(l, s) + for l, s in zip(dll_flat, self._executable._in_shardings)] # Some input layouts got DCE'd if self.in_tree.num_leaves > len(layouts_flat): iter_layouts_flat = iter(layouts_flat) @@ -483,8 +484,9 @@ def input_layouts(self): @property def output_layouts(self): - layouts_flat = self._executable.output_layouts() - assert all(isinstance(l, Layout) for l in layouts_flat) + dll_flat = self._executable._xla_out_layouts + layouts_flat = [Layout(l, s) + for l, s in zip(dll_flat, self._executable._out_shardings)] return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error @staticmethod From 99ca14601d5794e5e01e9eac4574a5d84455fac0 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 11 Apr 2025 23:48:27 -0700 Subject: [PATCH 0591/1769] revert making `Executable` an ABC PiperOrigin-RevId: 746726071 --- jax/_src/stages.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 3f90ae9f78bb..99e75768959f 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,7 +30,6 @@ """ from __future__ import annotations -import abc import functools from collections.abc import Sequence from dataclasses import dataclass @@ -120,17 +119,16 @@ def cost_analysis(self) -> Any: raise NotImplementedError -class Executable(metaclass=util.StrictABCMeta): +class Executable: def xla_extension_executable(self) -> 
xc.LoadedExecutable: raise NotImplementedError( "compiled executable carries no loaded XLA executable. It may be " f"that {type(self)} defines an incomplete implementation.") - @abc.abstractmethod def call(self, *args_flat) -> Sequence[Any]: """Execute on the flat list of arguments, returning flat outputs.""" - pass + raise NotImplementedError("compiled executable does not support invocation") def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: """Optionally constructs a fast c++ dispatcher.""" From 566d0775a808ff36efa041968903676ee04ed11e Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sat, 12 Apr 2025 00:30:27 -0700 Subject: [PATCH 0592/1769] unify `stages.Lowering` and `stages.XlaLowering` We no longer have many different implicit types conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes. PiperOrigin-RevId: 746735857 --- jax/_src/interpreters/pxla.py | 8 +-- jax/_src/stages.py | 111 ++++++++++++++-------------------- 2 files changed, 51 insertions(+), 68 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index afba3cef2005..d5a18ad2f439 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -922,7 +922,7 @@ def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None, raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}") -class PmapComputation(stages.XlaLowering): +class PmapComputation(stages.Lowering): _hlo: ir.Module _executable: PmapExecutable | None @@ -931,7 +931,7 @@ def __init__(self, hlo: ir.Module, **compile_args): self._hlo = hlo self.compile_args = compile_args - # -- stages.XlaLowering overrides + # -- stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo @@ -2433,7 +2433,7 @@ def _to_logical_sharding( raise TypeError(aval) -class MeshComputation(stages.XlaLowering): +class MeshComputation(stages.Lowering): _hlo: ir.Module _executable: MeshExecutable | None @@ -2449,7 +2449,7 @@ def __init__(self, name: str, hlo: ir.Module, self.compile_args = compile_args self._executable = None - # -- stages.XlaLowering overrides + # -- stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 99e75768959f..b813037a3204 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -63,61 +63,6 @@ # -- Internal types -# TODO(frostig): collapse with XlaLowering -class Lowering(Protocol): - """Protocol for lowerings, which a user-facing ``Lowered`` encapsulates.""" - - def compile( - self, compiler_options: CompilerOptions | None = None) -> Executable: - """Compile and return a corresponding ``Executable``.""" - raise NotImplementedError - - def as_text(self, dialect: str | None = None, *, - debug_info: bool = False) -> str: - """A human-readable text representation of this lowering. - - Intended for visualization and debugging purposes. This need not be a valid - nor reliable serialization. It is relayed directly to external callers. - """ - raise NotImplementedError - - def compiler_ir(self, dialect: str | None = None) -> Any: - """An arbitrary object representation of this lowering. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. 
- - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. - - Args: - dialect: Optional string specifying a representation dialect - (e.g. "stablehlo") - """ - raise NotImplementedError - - def cost_analysis(self) -> Any: - """A summary of execution cost estimates. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - This function estimates execution cost in the absence of compiler - optimizations, which may drastically affect the cost. For execution cost - estimates after optimizations, compile this lowering and see - ``Compiled.cost_analysis``. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - class Executable: @@ -258,8 +203,7 @@ def runtime_executable(self) -> Any: return self.xla_extension_executable() -class XlaLowering(Lowering): - """Adapts our various internal XLA-backed computations into a ``Lowering``.""" +class Lowering: compile_args: dict[str, Any] @@ -273,15 +217,23 @@ def hlo(self) -> xc.XlaComputation: def stablehlo(self) -> ir.Module: """Return a StableHLO representation of this computation.""" - raise NotImplementedError("must override") + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") def compile( self, compiler_options: CompilerOptions | None = None) -> Executable: - raise NotImplementedError("must override") + """Compile and return a corresponding ``Executable``.""" + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") def as_text(self, dialect: str | None = None, *, debug_info: bool = False) -> str: + """A human-readable text representation of this lowering. + + Intended for visualization and debugging purposes. This need not be a valid + nor reliable serialization. It is relayed directly to external callers. + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -295,6 +247,19 @@ def as_text(self, dialect: str | None = None, raise ValueError(f"unknown dialect: {dialect}") def compiler_ir(self, dialect: str | None = None) -> Any: + """An arbitrary object representation of this lowering. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. + + Args: + dialect: Optional string specifying a representation dialect + (e.g. "stablehlo") + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -304,8 +269,26 @@ def compiler_ir(self, dialect: str | None = None) -> Any: else: raise ValueError(f"unknown dialect: {dialect}") - def cost_analysis(self) -> dict[str, float]: - raise NotImplementedError("must override") + def cost_analysis(self) -> Any: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. 
nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + This function estimates execution cost in the absence of compiler + optimizations, which may drastically affect the cost. For execution cost + estimates after optimizations, compile this lowering and see + ``Compiled.cost_analysis``. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") # -- Public-facing API, plus helpers @@ -562,14 +545,14 @@ class Lowered(Stage): lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). """ __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"] - _lowering: XlaLowering + _lowering: Lowering args_info: Any # PyTree of ArgInfo out_tree: tree_util.PyTreeDef _no_kwargs: bool def __init__( self, - lowering: XlaLowering, + lowering: Lowering, args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, no_kwargs: bool = False): @@ -581,7 +564,7 @@ def __init__( @classmethod def from_flat_info(cls, - lowering: XlaLowering, + lowering: Lowering, in_tree: tree_util.PyTreeDef, in_avals, donate_argnums: tuple[int, ...], From c69e61e1a9528d6abec0703d52e61406192d69c3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 12 Apr 2025 06:17:13 -0700 Subject: [PATCH 0593/1769] Remove jax.lib.xla_client.{XlaComputation,Shape}. PiperOrigin-RevId: 746803082 --- CHANGELOG.md | 3 ++- jax/lib/xla_client.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cb99e58db45..db5e0f9763d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,7 +65,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, - `ops`, `register_custom_call_target`, `shape_from_pyval`. + `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`, + `XlaComputation`. * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index bf66bda3b149..bd4d98462f11 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -66,16 +66,16 @@ None, ), "Shape": ( - "Shape is deprecated; use StableHLO instead.", - _xc.Shape, + "Shape has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "XlaBuilder": ( "XlaBuilder has been removed in JAX v0.6.0; use StableHLO instead.", None, ), "XlaComputation": ( - "XlaComputation is deprecated; use StableHLO instead.", - _xc.XlaComputation, + "XlaComputation has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), # Added Nov 20 2024, finalized 2025-04-09 "ArrayImpl": ( From 69173a289cac2797799df708d92f7a5eec197f80 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 12 Apr 2025 07:35:31 -0700 Subject: [PATCH 0594/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/007ab7fd0d30d585b802efcad403863d94e8b1c9. 
PiperOrigin-RevId: 746816179 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fb219ad99cb6..3cd7b1d56459 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ca9011742bb84b3d2158feb262ddca221957ccc9" -XLA_SHA256 = "7a0eb3d157236c0e9b4bdf2598d411a216e3fb7bbc0b47d20810746fb0ba772c" +XLA_COMMIT = "007ab7fd0d30d585b802efcad403863d94e8b1c9" +XLA_SHA256 = "0208d8ec8c7013d115173f48c21ac035c86612ce74d53c08843db8553c3872c4" def repo(): tf_http_archive( From ca50cae5a490b1babaa129c2bf17611ec567c33b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 9 Apr 2025 09:58:23 +0000 Subject: [PATCH 0595/1769] Properly center and size the SM image in the GPU docs --- docs/pallas/gpu/reference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md index 416679d9654c..11a84d251d9d 100644 --- a/docs/pallas/gpu/reference.md +++ b/docs/pallas/gpu/reference.md @@ -27,7 +27,7 @@ but multiple warps can be assigned to the same SM subdivision. At each clock cyc warp scheduler from each subdivision tries to select one of its resident warps to execute the next instruction. -![A diagram of one SM](../../_static/pallas/gpu/nvidia_sm.svg) +

+<figure style="text-align: center;"><img src="../../_static/pallas/gpu/nvidia_sm.svg" alt="A diagram of one NVIDIA SM" style="width: 70%;"><figcaption>A diagram of one NVIDIA SM</figcaption></figure>
Going further, recent CUDA versions also outline the concept of a _warpgroup_, which are 4 consecutive warps. Knowing how the hardware looks like, we can see where this is comming From 7edd5d50dde5b82ad8c18e7bd5b8ae639800a7fd Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 8 Apr 2025 13:08:32 +0000 Subject: [PATCH 0596/1769] Add reference docs for Pallas:MGPU synchronization primitives --- docs/pallas/gpu/reference.md | 100 +++++++++++++++++++++++++++++++---- 1 file changed, 91 insertions(+), 9 deletions(-) diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md index 11a84d251d9d..8ed443220acb 100644 --- a/docs/pallas/gpu/reference.md +++ b/docs/pallas/gpu/reference.md @@ -78,6 +78,10 @@ or [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong For more information on how warp scheduling and instruction issue works, we recommend reading [Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481). +### Memory spaces + +TODO: GMEM, SMEM, RMEM, (maybe TMEM) + ## Array layouts and reference transforms TODO @@ -105,27 +109,105 @@ TODO ### `commit_smem` -TODO +Regular reads/writes to references are guaranteed to produce values consistent +with the sequential program order. For example, in the following program, it is +guaranteed that `value` is equal to `value2`. +```python +ref[...] = value +value2 = ref[...] +``` + +This guarantee, however, does not extend to asynchronous primitives such as async +copies or MMA operations. To make the SMEM writes visible to those primitives, you +are required to explicitly synchronize with them using the `plgpu.commit_smem()` function. + +For example: +```python +smem_ref[...] = value +plgpu.commit_smem() +plgpu.copy_smem_to_gmem(smem_ref, ...) +``` +or: +```python +smem_ref[...] = value +plgpu.commit_smem() +plgpu.wgmma(smem_ref, ...) +``` + +Failing to call this function is likely to cause subtle data races, due to those asynchronous +hardware units reading stale data from SMEM. Unfortunately, this function is relatively expensive, +which is why we rely on you, the user, to insert it in the minimal number of places where it's necessary. ### `Barrier` This is essentially a thin wrapper around an array of PTX `mbarrier` types and is passed in as a reference. All functions involving barriers expect to only get a single barrier argument, and so if the reference contains multiple, you have to extract one -of them explicitly using `barriers.at[index]`. +of them explicitly using `barriers.at[index]`. `Barrier`s are always allocated in SMEM +and as such have relatively low overheads. Each barrier can be configured to complete +after a fixed number of "arrivals" (by default 1). -`Barrier`s are always allocated in SMEM and as such have relatively low overheads. -There are three primary use cases that require the use of `Barrier`s: +To block a thread until a barrier completes, use the following function: +```python +plgpu.barrier_wait(barrier) +``` -1. Awaiting asynchronous GMEM-to-SMEM copies +There are three operations that can complete a barrier: -TODO +> It is critical to ensure that the synchronization scheme makes it impossible for two + barrier completions to happen without a call to `plgpu.barrier_wait` in between them. + For example, if you use `Barrier`s to synchronize two producer/consumer threads, you + need to perform barrier synchronization going both ways to introduce "backpressure" + that will stop one thread from arriving twice before the other one had a chance to await. 
+  Failing to satisfy this will corrupt the data structure and can cause surprising failures
+  (including CUDA runtime errors). See below for an example of a valid program with two threads.

#### Asynchronous GMEM-to-SMEM copies

When an asynchronous GMEM-to-SMEM copy is being executed by the TMA engine, it will
post progress updates to the barrier given to `plgpu.copy_gmem_to_smem`. Once the copy
is complete, it will record one arrival on the barrier as well.

#### Explicit arrival (cross-thread synchronization)

Any thread can explicitly arrive at a barrier using the following function:
```python
plgpu.barrier_arrive(barrier)
```

This is especially useful when synchronizing two threads that are in producer/consumer
roles. In this case, we recommend allocating two arrays of `Barrier`s, with size equal
to the size of the "queue" used to pass data between the two threads. For example,
assume one thread continues writing tiles of an array to SMEM while another thread
reads them. We triple-buffer the SMEM region to allow more asynchrony between the two
threads:

```python
tid = jax.lax.axis_index("thread")
assert queue.shape == (buffering, *item_shape)
assert produced.shape == consumed.shape == (buffering,)

def thread0_body(i, _):
  slot = jax.lax.rem(i, buffering)
  @pl.when(i >= buffering)
  def _await_consumed():
    plgpu.barrier_wait(consumed.at[slot])  # Wait for consumption of the value before overwriting it
  # Option 1: Compute the next value
  queue[slot] = produce()
  plgpu.barrier_arrive(produced.at[slot])  # Signal the value is ready
  # Option 2: Produce the value through async_copy
  # plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot])
pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None))

def thread1_body(i, _):
  slot = jax.lax.rem(i, buffering)
  plgpu.barrier_wait(produced.at[slot])  # Wait for the value to be ready
  consume(queue[slot])  # Load and compute
  plgpu.barrier_arrive(consumed.at[slot])  # Signal that the value is consumed
pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None))
```

#### Awaiting `tcgen05` TensorCore instructions

TODO

From 4fd610fc2d6bf17442ee3d8d0fa4d09400b55912 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Sun, 13 Apr 2025 07:32:50 -0700
Subject: [PATCH 0597/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/01b33c6596f2afeefaf76233cbb43cf6de66c1c9.

PiperOrigin-RevId: 747085967
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 3cd7b1d56459..f02028de99e3 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
-XLA_COMMIT = "007ab7fd0d30d585b802efcad403863d94e8b1c9"
-XLA_SHA256 = "0208d8ec8c7013d115173f48c21ac035c86612ce74d53c08843db8553c3872c4"
+XLA_COMMIT = "01b33c6596f2afeefaf76233cbb43cf6de66c1c9"
+XLA_SHA256 = "d18a63f603b206e6befda14e29a041c95a4036093b535bdb09b826d5808a2b89"
 
 def repo():
     tf_http_archive(

From f070cdecb3504897551b5ce1b4f9f0cf029e0483 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Thu, 10 Apr 2025 08:46:11 +0200
Subject: [PATCH 0598/1769] [explain-cache-miss] Improve tracing-cache-miss
 explanations

The previous approach was to report, for several elements of the cache key,
the closest mismatch. Some parts of the cache key were ignored, which led to
"explanation unavailable". The same happened when we had two keys close to
the current one, each differing in a different part of the key. No explanation
was produced because for each part of the key, there was a matching key
already in the cache, even though the key taken as a whole did not match.

Now, we scan *all* parts of the key and compute the differences. We keep track
of the "size" of the differences, and we explain the differences relative to
the keys that are closest (possibly more than one key if equidistant). For
example, for shape differences we'll report the closest matching shape. If a
type differs in both the dtype and some parts of the shape, or sharding, it is
considered farther away.

We add new tests and explanations for different static argnums and argnames.

There are still cases when we do not produce an explanation, but now the
"explanation unavailable" includes a description of which component of the key
is different, and what the difference is. This may still be hard for the user
to understand, but at least they can file a clearer bug.

Refactored the tests, and added a few new ones.
---
 jax/_src/pjit.py         | 333 +++++++++++++++++++++++++++------------
 tests/api_test.py        | 223 +++++++++++++++++++++-----
 tests/debug_info_test.py |  60 -------
 tests/pjit_test.py       |   5 +-
 4 files changed, 414 insertions(+), 207 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index d1e5c3bfbc54..afc7a5bed52f 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -21,7 +21,6 @@
 from functools import partial
 import inspect
 import logging
-import operator as op
 import weakref
 from typing import NamedTuple, Any, Union, cast
 import warnings
@@ -1158,17 +1157,209 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
 
 callsites_with_tracing_cache_miss: set[str] = set()
 
+def diff_tracing_cache_keys(
+    k: tuple, oldk: tuple, debug_info: lu.DebugInfo) -> tuple[Sequence[str], int]:
+  """Explanations of differences between the cache keys, along with diff sizes.
+
+  Result: a pair of a list of explanations for differences, and the total size
+  of the differences. The sizes are used to pick the old key with the smallest
+  difference size for the explanation that is shown to the user.
+  """
+  (fun_transforms_k, fun_params_k, fun_in_type_k,
+   (arg_in_type_k, arg_attr_data_k, arg_inline_k), ctx_k) = k
+  (fun_transforms_ok, fun_params_ok, fun_in_type_ok,
+   (arg_in_type_ok, arg_attr_data_ok, arg_inline_ok), ctx_ok) = oldk
+
+  diffs: list[tuple[str, int]] = []  # each difference with its size
+  def unavailable(key_field: str, what_k, what_ok):
+    diffs.append(
+        (f"different {key_field}:\n now: {what_k}\n != before: {what_ok}.\n"
+         "explanation unavailable! 
" + "please open an issue at https://github.com/jax-ml/jax.", + 10)) + + def list_diff_size(s1: Sequence, s2: Sequence) -> int: + min_len = min(len(s1), len(s2)) + diff_size = max(len(s1), len(s2)) - min_len + diff_size += sum(e1 != e2 for e1, e2 in zip(s1[:min_len], + s2[:min_len])) + return diff_size + + different_leaf_count = False + + def explain_transform_argnums_partial(param_k: tuple, param_ok: tuple): + dyn_argnums_k, static_args_k = param_k + dyn_argnums_ok, static_args_ok = param_ok + if dyn_argnums_k != dyn_argnums_ok: + diffs.append( + ("different static_argnums:\n" + f" dynamic argnums now {dyn_argnums_k} and before {dyn_argnums_ok}", + 1)) + if static_args_k != static_args_ok: + diffs.append( + ("different value of static args:\n" + f" now {', '.join(repr(a.val) for a in static_args_k)}" + f" and before {', '.join(repr(a.val) for a in static_args_ok)}", + list_diff_size(static_args_k, static_args_ok))) + + def explain_transform_argnames_partial(param_k: tuple, param_ok: tuple): + static_kwargs_k, = param_k + static_kwargs_ok, = param_ok + static_kwargs_k = [(k, v.val) for k, v in + sorted(static_kwargs_k.val.items())] + static_kwargs_ok = [(k, v.val) for k, v in + sorted(static_kwargs_ok.val.items())] + if static_kwargs_k != static_kwargs_ok: + diffs.append( + ("different value of static kwargs:\n" + f" now {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_k)}}}" + f" and before {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_ok)}}}", + list_diff_size(static_kwargs_k, static_kwargs_ok))) + + def explain_in_tree_diff(in_tree_k: PyTreeDef, in_tree_ok: PyTreeDef): + nonlocal different_leaf_count + different_leaf_count = (in_tree_k.num_leaves != in_tree_ok.num_leaves) + if not different_leaf_count: + # Look for the special case of passing positional args as kwargs or + # vice-versa; the common prefix of positional args match. 
+      args_tree_k, kwargs_tree_k = treedef_children(in_tree_k)
+      nr_args_k = len(treedef_children(args_tree_k))
+      args_tree_ok, kwargs_tree_ok = treedef_children(in_tree_ok)
+      nr_args_ok = len(treedef_children(args_tree_ok))
+      if (treedef_children(args_tree_k)[:min(nr_args_k, nr_args_ok)] ==
+          treedef_children(args_tree_ok)[:min(nr_args_k, nr_args_ok)]):
+        keys_k = kwargs_tree_k.node_data()[1]  # type: ignore[index]
+        keys_ok = kwargs_tree_ok.node_data()[1]  # type: ignore[index]
+        diffs.append(
+            (("different number of args and kwargs, but same total number.\n"
+              f"  now {nr_args_k} args and kwargs "
+              f"with keys {keys_k}\n"
+              f"  before {nr_args_ok} args and kwargs "
+              f"with keys {keys_ok}"),
+             abs(nr_args_ok - nr_args_k)))
+        return
+
+    in_tree_k_str = str(in_tree_k)
+    in_tree_k_str = (in_tree_k_str if len(in_tree_k_str) < 73
+                     else in_tree_k_str[:73] + "...")
+    in_tree_ok_str = str(in_tree_ok)
+    in_tree_ok_str = (in_tree_ok_str if len(in_tree_ok_str) < 73
+                      else in_tree_ok_str[:73] + "...")
+    diff = [f"different input pytree:\n  now: {in_tree_k_str}\n"
+            f"  before: {in_tree_ok_str}"]
+
+    errs = list(tree_util.equality_errors_pytreedef(in_tree_k, in_tree_ok))
+    for path, thing1, thing2, explanation in errs:
+      fst, *path = path  # type: ignore
+      base = ["args", "kwargs"][fst.idx]
+      diff.append(
+          f"  * at {base}{keystr(tuple(path))}, now {thing1} and before {thing2},"
+          f" so {explanation}")
+    diffs.append(("\n".join(diff), len(errs)))
+
+  def explain_args_type_diff(args_k: tuple[core.AbstractValue],
+                             args_ok: tuple[core.AbstractValue]):
+    diff_size = 0
+    arg_names = debug_info.safe_arg_names(len(args_k))
+    def arg_type_to_str(at):
+      if hasattr(at, "str_short"):
+        return at.str_short(short_dtypes=True)
+      else:
+        return str(at)
+    args_k_str = ", ".join(f"{an}: {arg_type_to_str(at)}"
+                           for an, at in zip(arg_names, args_k))
+    args_k_str = args_k_str if len(args_k_str) < 73 else args_k_str[:73] + "..."
+    diff = [f"different input types:\n types now: {args_k_str}"]
+    add_weak_type_hint = False
+
+    for name, arg_t_k, arg_t_ok in zip(arg_names, args_k, args_ok):
+      if arg_t_k == arg_t_ok: continue
+      this_arg_diff_size = 0
+      if type(arg_t_k) == type(arg_t_ok) == core.ShapedArray:
+        s1, s2 = arg_type_to_str(arg_t_k), arg_type_to_str(arg_t_ok)
+        this_arg_diff_size += list_diff_size(arg_t_k.shape, arg_t_ok.shape)  # type: ignore
+
+        if arg_t_k.weak_type != arg_t_ok.weak_type:  # type: ignore
+          s1 += f"{{weak_type={arg_t_k.weak_type}}}"  # type: ignore
+          s2 += f"{{weak_type={arg_t_ok.weak_type}}}"  # type: ignore
+          add_weak_type_hint = True
+          this_arg_diff_size += 1
+        elif arg_t_k.sharding != arg_t_ok.sharding:  # type: ignore
+          s1 = arg_t_k.str_short(short_dtypes=True, mesh_axis_types=True)  # type: ignore
+          s2 = arg_t_ok.str_short(short_dtypes=True, mesh_axis_types=True)  # type: ignore
+          this_arg_diff_size += 1
+      else:
+        s1, s2 = str(arg_t_k), str(arg_t_ok)
+      diff_size += max(1, this_arg_diff_size)
+      diff.append(f"  * at {name}, now {s1} and before {s2}")
+
+    if add_weak_type_hint:
+      diff.append(
+          "where weak_type=True often means a Python builtin numeric value, and \n"
+          "weak_type=False means a jax.Array.\n"
+          "See https://docs.jax.dev/en/latest/type_promotion.html#weak-types.")
+    diffs.append(("\n".join(diff), diff_size))
+
+  if fun_transforms_k != fun_transforms_ok:
+    if len(fun_transforms_k) != len(fun_transforms_ok):
+      different_leaf_count = True  # Skip other more precise checks
+      unavailable("fun_transforms length",
+                  fun_transforms_k, fun_transforms_ok)
+    else:
+      for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)):
+        t_name = t[0].__name__
+        if t == ot: continue
+        if t[0] != ot[0]:
+          unavailable(f"fun_transforms[{i}] transform", t, ot)
+          continue
+
+        if t_name == "flatten_fun":
+          explain_in_tree_diff(t[1][0], ot[1][0])
+          continue
+        if t_name == "_argnums_partial":
+          explain_transform_argnums_partial(t[1], ot[1])
+          continue
+        if t_name == "_argnames_partial":
+          explain_transform_argnames_partial(t[1], ot[1])
+          continue
+        unavailable(f"fun_transforms.{t_name} params", t[1:], ot[1:])
+        continue
+
+  # If we had different leaf counts, we can discard the _argnums_partial
+  # difference. That transform sometimes occurs before the flatten_fun.
+  if different_leaf_count:
+    diffs = [d for d in diffs if "fun_transforms._argnums_partial" not in d[0]]
+  if fun_params_k != fun_params_ok:
+    unavailable("fun_params", fun_params_k, fun_params_ok)
+  if fun_in_type_k != fun_in_type_ok:
+    unavailable("fun_in_type", fun_in_type_k, fun_in_type_ok)
+  if arg_in_type_k != arg_in_type_ok and not different_leaf_count:
+    explain_args_type_diff(arg_in_type_k, arg_in_type_ok)
+  if arg_attr_data_k != arg_attr_data_ok:
+    unavailable("arg_attr_data", arg_attr_data_k, arg_attr_data_ok)
+  if arg_inline_k != arg_inline_ok:
+    unavailable("arg_inline", arg_inline_k, arg_inline_ok)
+  if ctx_k != ctx_ok:
+    assert len(ctx_k) == len(ctx_ok)
+    idxs = [f" [{i}]: now {c_k} and before {c_ok}"
+            for i, (c_k, c_ok) in enumerate(zip(ctx_k, ctx_ok)) if c_k != c_ok]
+    diffs.append(
+        ("different tracing context, e.g. 
due to config or context manager.\n" + "found differences at positions\n" + + ", and\n".join(idxs) + + "\ncompare to tuple returned by " + "config.trace_context() in jax/_src/config.py.", + len(idxs))) + if not diffs: # Should never happen, but let's not crash + unavailable("something (unexpected empty diffs)", k, oldk) + diffs_and_sizes = util.unzip2(sorted(diffs, key=lambda d: d[1])) + return (diffs_and_sizes[0], sum(diffs_and_sizes[1])) + + def explain_tracing_cache_miss( fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): if config.check_tracer_leaks.value: return - - def unpack(key): - transforms, (), _, (in_type, _, inline), *_, ctx = key - # TODO(dougalm,mattjj): enable cache miss explanation with attrs - _, (_, (in_tree,)), *_ = transforms - return in_tree, in_type, inline.val, ctx - in_tree, in_type, inline, ctx = unpack(key) - if inline: return + if key[3][2].val: return # No explanations for "inline" functions debug_info = fun.debug_info func_filename = debug_info.func_filename @@ -1177,7 +1368,7 @@ def unpack(key): msg: list[str] = [] p = msg.append - done = lambda: logger.log(logging.WARNING, '\n'.join(msg)) + done = lambda: logger.log(logging.WARNING, "\n".join(msg)) callsite = source_info_util.summarize(source_info_util.current()) p(f"TRACING CACHE MISS at {callsite} because:") @@ -1188,110 +1379,42 @@ def unpack(key): src_info += f" defined at {func_filename}" if func_lineno := debug_info.func_lineno: src_info += f":{func_lineno}" - if unseen_f: - p(f" never seen function:\n {debug_info.func_name} id={id(fun.f)}{src_info}") + func_name = debug_info.func_name + if unseen_f or not cache: + p(f" never seen function:\n {func_name} id={id(fun.f)}{src_info}") if callsite in callsites_with_tracing_cache_miss: p(" but seen another function defined on the same line; maybe the function is\n" " being re-defined repeatedly, preventing caching?") else: callsites_with_tracing_cache_miss.add(callsite) return done() - else: - p(f" for {debug_info.func_name}{src_info}") - - seen_keys = map(unpack, cache.keys()) - - # have we maybe switched some args to be kwargs or visa-versa? - args_tree, kwargs_tree = treedef_children(in_tree) - args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys] - args_kwargs_match = [t for t in args_kwargs_trees - if t == [args_tree, kwargs_tree]] - if not args_kwargs_match: - num_args = len(treedef_children(args_tree)) - _, kwarg_keys = kwargs_tree.node_data() # type: ignore - p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} " - "keyword args with keys:\n" - f" {', '.join(map(repr, kwarg_keys))}") - dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore - if t != [args_tree, kwargs_tree]] - close_kwargs = min( - dont_match, key=set(kwarg_keys).symmetric_difference, default=None - ) - if not close_kwargs: - p(" closest seen is passing no keyword args") - else: - p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n" - f" {', '.join(map(repr, close_kwargs))}") - return done() - # have we never seen this tracing context before? - ctxs_match = [c for *_, c in seen_keys if c == ctx] - if not ctxs_match: - p(" tracing context doesn't match, e.g. 
due to config or context manager") - dont_match = [c for *_, c in seen_keys if c != ctx] - closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx))) - idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2] - p(" closest seen context tuple differs at positions:\n" - f" {', '.join(map(str, idxs))}\n" - " compare to tuple returned by config._trace_context() in jax/_src/config.py.") - return done() - - # have we never seen this input pytree before? - trees_match = [k for k in seen_keys if k[0] == in_tree] - if not trees_match: - in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else '' - p(f" never seen input pytree{in_tree_str}") - dont_match = [t for t, *_ in seen_keys if t != in_tree] - closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves)) - errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type] - p(f" closest seen input pytree has {len(errs)} mismatches, including:") - for path, thing1, thing2, explanation in errs: - fst, *path = path # type: ignore - base = ['args', 'kwargs'][fst.idx] - p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1}," - f" so {explanation}") - return done() + p(f" for {func_name}{src_info}") + + diffs = [diff_tracing_cache_keys(key, ok, debug_info) + for ok in cache.keys() if key != ok] + assert diffs, "we must find some diffs if key differs from all cache keys" + min_diff = min(diffs, key=lambda v: v[1]) + smallest_diffs: Sequence[Sequence[str]] # the diffs for the closest keys + smallest_diffs = [d[0] for d in diffs if d[1] == min_diff[1]] + def indent_subsequent_lines(indent: int, msg: str) -> str: + return msg.replace("\n", "\n" + " " * indent) + def p_one_diff(diff: Sequence[str]): + for d in diff: + p(" * key with " + indent_subsequent_lines(4, d)) + + if len(smallest_diffs) == 1: + p(" all previously seen cache keys are different. Closest previous key:") + p_one_diff(smallest_diffs[0]) + else: + p(" all previously seen cache keys are different. " + "Several previous keys are closest:") + for d in smallest_diffs: + p_one_diff(d) - # have we never seen these input types (eg shapes, dtypes) before? 
- types_match = [k for k in trees_match if k[1] == in_type] - if not types_match: - if len(in_type) < 5: - in_type_str = ":\n {}".format(", ".join( - f"{n}: {ty.str_short(short_dtypes=True)}" - for n, ty in zip(debug_info.arg_names, in_type))) - else: - in_type_str = '' - p(f" never seen input type signature{in_type_str}") - dont_match = [t for _, t, *_ in trees_match if t != in_type] - closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type))) - num_mismatch = sum(map(op.ne, closest_ty, in_type)) - p(f" closest seen input type signature has {num_mismatch} mismatches, including:") - add_weak_type_hint = False - arg_names = debug_info.safe_arg_names(len(in_type)) - - for name, ty1, ty2 in zip(arg_names, closest_ty, in_type): - if ty1 != ty2: - if type(ty1) == type(ty2) == core.ShapedArray: - s1, s2 = ty1.str_short(True), ty2.str_short(True) - if ty1.weak_type != ty2.weak_type: - s1 += f"{{weak_type={ty1.weak_type}}}" - s2 += f"{{weak_type={ty2.weak_type}}}" - add_weak_type_hint = True - elif ty1.sharding != ty2.sharding: - s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True) - s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True) - else: - s1, s2 = str(ty1), str(ty2) - p(f" * at {name}, seen {s1}, but now given {s2}") - if add_weak_type_hint: - p("where weak_type=True often means a Python builtin numeric value, and ") - p("weak_type=False means a jax.Array.") - p("See https://docs.jax.dev/en/latest/type_promotion.html#weak-types") - return done() + done() + return - # we think this is unreachable... - p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax") - return done() @partial(lu.cache, explain=explain_tracing_cache_miss) def _create_pjit_jaxpr( diff --git a/tests/api_test.py b/tests/api_test.py index a5e192a9f826..8705d2021577 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4466,74 +4466,219 @@ def add(x): self.assertEqual(tracing_add_count, 2) @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] + def test_cache_miss_explanations_skip_internals(self): + if is_persistent_cache_enabled(): + self.skipTest('With persistent cache, we see the cache misses') + + with config.explain_cache_misses(True): + with self.assertNoLogs(level='WARNING'): + for i in range(2): + jnp.sin(jnp.arange(i + 1, dtype=np.float32)) + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_first_miss(self): + @jax.jit + def f(x): return x x = jnp.float32(1.) - y = {'hi': jnp.arange(3., dtype='float32')} expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - # print on first miss, not on hit with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) + with self.assertLogs(level="WARNING") as cm: + f(x) + f(x) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("TRACING CACHE MISS", msg) + self.assertIn("never seen function", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_in_tree(self): + @jax.jit + def f(*args, **kwargs): return args[0] + + f(0., 1., y=(2., 2.1)) + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + # Same number of leaves but different trees + f(0., (1., 1.1), y=2.) 
+ expected_log_len = 1 if not is_persistent_cache_enabled() else 3 + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("different input pytree", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_arg_passed_as_kwarg(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + f(0., 1.) + + # kwarg change + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y=1.) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("different number of args and kwargs, but same total number", msg) + self.assertIn("now 1 args and kwargs with keys ['y']", msg) + self.assertIn("before 1 args and kwargs with keys []", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnums(self): + @partial(jax.jit, static_argnums=(0, 2)) + def f(x, y, z): + return y + + f(1., 2., "foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(1., 2., "bar") + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("different value of static args", msg) + self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg) + self.assertNotIn('explanation unavailable!', msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnames(self): + @partial(jax.jit, static_argnames='foo') + def f(*, foo): + return 1 + + f(foo="foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(foo="bar") + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) + self.assertIn("different value of static kwargs", msg) + self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg) + self.assertNotIn('explanation unavailable!', msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_dtype(self): + @jax.jit + def f(x, y): return x + f(np.float32(0), np.float32(1)) - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(x, y_) + f(np.float32(0), np.int32(1)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) + self.assertIn("different input types", msg) + self.assertIn("at y, now i32[] and before f32[]", msg) + self.assertNotIn("explanation unavailable!", msg) + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_weak_type(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + y = jnp.arange(4, dtype="float32") + f(jnp.float32(0.), y) # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, 
expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg) + if config.enable_x64.value: + self.skipTest("Work only for 32 bit mode") + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}", msg) + self.assertIn("https://docs.jax.dev/en/latest/type_promotion.html#weak-types", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + f(np.float32(0), np.arange(1, dtype=np.float32)) - # kwarg change with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(1, y=y) + f(np.float32(0), np.arange(2, dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) + self.assertIn("different input types", msg) + self.assertIn("at y, now f32[2] and before f32[1]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape_explain_closest(self): + @jax.jit + def f(x): return x + f(np.ones((1, 2), dtype=np.float32)) + f(np.ones((10, 20, 30), dtype=np.float32)) + f(np.ones((1, 2, 3), dtype=np.float32)) - # tracing config change with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings + f(np.ones((10, 2, 30), dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[10,2,30] and before f32[10,20,30]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_tracing_config(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + f(0., 1.) + # tracing config change + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + with jax.numpy_rank_promotion("warn"): + with jax.default_matmul_precision("high"): + f(0., 1.) 
+ + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertTrue(1 <= len(cm.output) <= expected_log_len) msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) + self.assertIn("key with different tracing context", msg) + self.assertIn("now warn and before", msg) + self.assertIn("now high and before", msg) + self.assertNotIn("explanation unavailable!", msg) @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations_skip_internals(self): - if is_persistent_cache_enabled(): - self.skipTest('With persistent cache, we see the cache misses') + def test_cache_miss_explanations_multiple_changes(self): + @jax.jit + def f(x): return jnp.sin(x) + + call_1 = f(np.arange(4, dtype=np.float32)) + with jax.numpy_rank_promotion("warn"): + call_2 = f(np.arange(8, dtype=np.float32)) with config.explain_cache_misses(True): - with self.assertNoLogs(level='WARNING'): - for i in range(2): - jnp.sin(jnp.arange(i + 1, dtype=np.float32)) + with self.assertLogs(level='WARNING') as cm: + # Matches call_2 in shape but not context, and call_1 in context but + # not in shape. + f(np.arange(8, dtype=np.float32)) + + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[8] and before f32[4]", msg) + self.assertIn("key with different tracing context", msg) + self.assertNotIn("explanation unavailable!", msg) @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_new_function_in_loop(self): diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 1f5ddba89e27..0fc1aabbaf57 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -392,66 +392,6 @@ def f(x): with self.assertRaisesRegex(TypeError, err_str): jax.jit(f)(jnp.int32) - @jtu.thread_unsafe_test() # logging is not thread-safe - def test_arg_names_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] - - x = jnp.float32(1.) 
- y = {'hi': jnp.arange(3., dtype='float32')} - - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - - # print on first miss, not on hit - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) - - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y_) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) - - # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg) - - # kwarg change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1, y=y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) - - # tracing config change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings - self.assertTrue(1 <= len(cm.output) <= expected_log_len) - msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) - @jtu.thread_unsafe_test() # logging is not thread-safe def test_arg_names_cache_miss_explanations_new_function_in_loop(self): @jax.jit diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0e7867cc2cdd..025512121a6d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3467,9 +3467,8 @@ def f(x, y): f(x_, y) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg) + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg) def test_pjit_function_cache_cpp(self): def f(x): From 2336cd169554e04a91b142c1191ccefcca7b31a2 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Sun, 13 Apr 2025 15:17:11 -0400 Subject: [PATCH 0599/1769] Minor improvements to doc for jax.nn.logsumexp. --- jax/_src/ops/special.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index fe4c46394832..6ed8804bc8e2 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -47,16 +47,15 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, JAX implementation of :func:`scipy.special.logsumexp`. .. math:: - \mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij}) + \operatorname{logsumexp} a = \log \sum_i b_i \exp a_i - where the :math:`j` indices range over one or more dimensions to be reduced. + where the :math:`i` indices range over one or more dimensions to be reduced. 
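+
+  For example, with all the scaling factors :math:`b_i = 1`,
+  :math:`\operatorname{logsumexp}([0, 0]) = \log(e^0 + e^0) = \log 2`.
+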
  Args:
    a: the input array
    axis: int or sequence of ints, default=None. Axis along which the sum is to
      be computed. If None, the sum is computed along all the axes.
-    b: scaling factors for :math:`\mathrm{exp}(a)`. Must be broadcastable to the
-      shape of `a`.
+    b: scaling factors for the exponentials. Must be broadcastable to the shape of `a`.
    keepdims: If ``True``, the axes that are reduced are left in the output as
      dimensions of size 1.
    return_sign: If ``True``, the output will be a ``(result, sign)`` pair,

From 2e4c0ec7ae42ae6124277e406e90ac6a8d21a697 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1s=20Longeri?=
Date: Sun, 13 Apr 2025 17:57:43 -0700
Subject: [PATCH 0600/1769] [Mosaic:TPU] Add some invariant checking in
 VectorLayout ctor

PiperOrigin-RevId: 747194404
---
 jaxlib/mosaic/dialect/tpu/layout.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h
index 12bf66cfcec0..8261d09697e3 100644
--- a/jaxlib/mosaic/dialect/tpu/layout.h
+++ b/jaxlib/mosaic/dialect/tpu/layout.h
@@ -233,6 +233,11 @@ class VectorLayout {
         implicit_dim_(implicit_dim) {
     // TODO(b/275751535): Allow more bitwidths.
     CHECK(llvm::has_single_bit(bitwidth_) && bitwidth_ <= 32);
+    CHECK_GT(tiling_[0], 0);
+    CHECK_GT(tiling_[1], 0);
+    CHECK_GE(offsets_[0].value_or(0), 0);
+    CHECK_GE(offsets_[1].value_or(0), 0);
+    CHECK_LT(offsets_[0].value_or(0), tiling_[0]);
   }
 
   static int num_implicit_dims(const ImplicitDim implicit_dim) {

From 13c7183cfc03753731ca8a451dc90f496d1646f7 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 11 Apr 2025 22:51:47 +0000
Subject: [PATCH 0601/1769] add a brief description of the
 jax.Array-has-no-__iadd__ gotcha

fixes #226

---
 docs/notebooks/Common_Gotchas_in_JAX.ipynb | 42 +++++++++++++++++++++-
 docs/notebooks/Common_Gotchas_in_JAX.md    | 24 ++++++++++++-
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
index de6da98b7d62..a9d4a9424f9f 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -307,7 +307,7 @@
    "id": "go3L4x3w4-9p"
   },
   "source": [
-    "If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)"
+    "If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__! (☉_☉)"
   ]
  },
 {
@@ -357,6 +357,45 @@
    "jax_array[1, :] = 1.0"
   ]
  },
+  {
+   "cell_type": "markdown",
+   "id": "8f520bec",
+   "metadata": {},
+   "source": [
+    "And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! (☉_☉) (☉_☉)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "20fbed45",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jax_array = jnp.array([10, 20])\n",
+    "jax_array_new = jax_array\n",
+    "jax_array_new += 10\n",
+    "print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but...\n",
+    "print(jax_array) # the original value is unmodified as [10, 20] !\n",
+    "\n",
+    "numpy_array = np.array([10, 20])\n",
+    "numpy_array_new = numpy_array\n",
+    "numpy_array_new += 10\n",
+    "print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated\n",
+    "print(numpy_array) # in-place, so both are [20, 30] !"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2604e220",
+   "metadata": {},
+   "source": [
+    "That's because NumPy defines `__iadd__` to perform in-place mutation. In\n",
+    "contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats\n",
+    "`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new +\n",
+    "10`, rebinding the variable without mutating any arrays."
+   ]
+  },
 {
  "cell_type": "markdown",
  "metadata": {
@@ -415,6 +454,7 @@
   }
  ],
 "source": [
+    "jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n",
  "updated_array = jax_array.at[1, :].set(1.0)\n",
  "print(\"updated array:\\n\", updated_array)"
 ]
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md
index 9fbc26a46c8f..0857edc132fa 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.md
+++ b/docs/notebooks/Common_Gotchas_in_JAX.md
@@ -177,7 +177,7 @@ print(numpy_array)
 
 +++ {"id": "go3L4x3w4-9p"}
 
-If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)
+If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__! (☉_☉)
 
 ```{code-cell} ipython3
 :id: iOscaa_GecEK
@@ -197,6 +197,27 @@ jax_array = jnp.zeros((3,3), dtype=jnp.float32)
 jax_array[1, :] = 1.0
 ```
 
+And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! (☉_☉) (☉_☉)
+
+```{code-cell} ipython3
+jax_array = jnp.array([10, 20])
+jax_array_new = jax_array
+jax_array_new += 10
+print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but...
+print(jax_array) # the original value is unmodified as [10, 20] !
+
+numpy_array = np.array([10, 20])
+numpy_array_new = numpy_array
+numpy_array_new += 10
+print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated
+print(numpy_array) # in-place, so both are [20, 30] !
+```
+
+That's because NumPy defines `__iadd__` to perform in-place mutation. In
+contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats
+`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new +
+10`, rebinding the variable without mutating any arrays.
+
 +++ {"id": "7mo76sS25Wco"}
 
 Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.
@@ -219,6 +240,7 @@ For example, the update above can be written as:
 :id: PBGI-HIeCP_s
 :outputId: de13f19a-2066-4df1-d503-764c34585529
 
+jax_array = jnp.zeros((3,3), dtype=jnp.float32)
 updated_array = jax_array.at[1, :].set(1.0)
 print("updated array:\n", updated_array)
 ```

From 214d32373bf3bc33c899f8be4e88eb0883a998ce Mon Sep 17 00:00:00 2001
From: vfdev
Date: Mon, 14 Apr 2025 13:46:43 +0200
Subject: [PATCH 0602/1769] Fix race suppression for 3.13

---
 .github/workflows/tsan-suppressions_3.13.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt
index b600e38276cc..dac134bf5169 100644
--- a/.github/workflows/tsan-suppressions_3.13.txt
+++ b/.github/workflows/tsan-suppressions_3.13.txt
@@ -23,7 +23,8 @@ race_top:PyMember_GetOne
 
 # https://github.com/python/cpython/issues/131680
 # Fixed in Python 3.14, but not backported to 3.13.
-race_top: new_reference
+race_top:new_reference
+race:_Py_IsOwnedByCurrentThread
 
 # https://github.com/python/cpython/issues/129748
 race:mi_block_set_nextx

From 3a7cec8563df7ce7d3bcc3fddf486186c93611f7 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Sun, 13 Apr 2025 09:15:58 +0000
Subject: [PATCH 0603/1769] Add Pallas:MGPU documentation for WGMMA

---
 docs/pallas/gpu/reference.md | 176 +++++++++++++++++++++++++++++++++--
 1 file changed, 169 insertions(+), 7 deletions(-)

diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md
index 8ed443220acb..1e23d2eb7813 100644
--- a/docs/pallas/gpu/reference.md
+++ b/docs/pallas/gpu/reference.md
@@ -49,6 +49,8 @@ discuss later). One notable addition here is that we still allow you to co-sched
 of those Pallas-level threads on the same SM so that they can cooperate and communicate
 through shared memory (we realize that by putting them in the same CUDA block).
 
+> From now on, whenever we say "thread", we refer to the Pallas thread, not a CUDA thread/lane.
+
 > This is very similar to a programming model popularized by [Triton](https://triton-lang.org/),
 but as you will see there are a few differences. Mosaic GPU tends to be more low level, which
 usually means you will have to put in more work, but it also puts you more in control.
@@ -69,7 +71,7 @@ TensorCore operations are so big and take so many cycles to complete, that it is
 try to use other units in the meantime.
 
 To extend this even further, we can take advantage of this hardware-unit-level parallelism by
-allowing multiple Pallas threads (warpgroups) to run concurrently. If one of the threads primarily
+allowing multiple Pallas threads to run concurrently. If one of the threads primarily
 occupies the ALU, while another one primarily issues TensorCore related instructions, we can take
 advantage of the efficient context switching built into the warp schedulers to keep both units busy.
 This is one of the core ideas behind algorithms such as [FlashAttention 3](https://arxiv.org/abs/2407.08608)
@@ -89,17 +91,159 @@ TODO
 
 ## MMA (TensorCore)
 
 In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit.
-NVIDIA continues to change the programming interface of the TensorCore significantly
-between different hardware generations, which is why the lowest-level interfaces
-differ in Pallas:MGPU as well.
+The programming interface of the TensorCore changes significantly between different
+NVIDIA GPU generations, which is why the lowest-level interfaces differ in Pallas:MGPU as well.
+
+Each MMA operation is associated with three operands:
+* the accumulator `D` of shape `(M, N)`,
+* the left input `A` of shape `(M, K)`,
+* the right input `B` of shape `(K, N)`.
+All operands must have the same element type.
+
+Each use of MMA involves a few steps:
+1. Allocating the space for the accumulator (MMA implicitly performs `D += A @ B`)
+2. Preparing the `A` and `B` operands
+3. Issuing the operation
+4. Waiting for the operation to complete
+5. Reading out the result
+
+Steps 2.-4. are usually performed in a loop over the contraction dimension (`K`).
+
+### Memory space of `A` and `B` operands
+
+The `A` and `B` operands are generally best passed in through SMEM, where they can
+be conveniently loaded using `plgpu.copy_gmem_to_smem`. For those operands to be
+compatible with MMA operations, they need to have the appropriate tiling and swizzling
+transforms specified upon their allocation. For all currently supported generations,
+the TensorCore requires the data to be laid out into row-major 2D tiles of shape
+`(8, swizzle_elems)`, where `swizzle_elems` is derived by dividing the swizzle by the
+element type bytewidth. The currently supported swizzles are: 128, 64, and 32. Larger
+swizzles are preferable as they improve the performance of GMEM-to-SMEM copies.
+
+```python
+def mma_transforms(shape_dtype: jax.ShapeDtypeStruct):
+  assert len(shape_dtype.shape) == 2
+  if shape_dtype.shape[0] % 8:
+    raise ValueError("Number of rows must be divisible by 8")
+  for swizzle_bytes in (128, 64, 32):
+    swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize
+    if shape_dtype.shape[-1] % swizzle_elems == 0:
+      return (plgpu.TilingTransform((8, swizzle_elems)),
+              plgpu.SwizzleTransform(swizzle_bytes))
+  raise ValueError("Failed to find transforms for the specified window type")
+```
+
+If the operands need to be transformed, the `A` operand can be passed in through a different
+memory space (architecture dependent, see below). The `B` operand _must_ be located in SMEM.
+
+### Transposed operands
+
+When performing MMA on 16-bit operands, the TensorCore can automatically transpose the
+input data. For example, the `A` reference is allowed to be of shape `(K, M)`, but it
+has to be transposed before passing it into the MMA function. For example:
+```python
+assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N)
+a_ref_t = plgpu.transpose_ref(a_ref, (1, 0))
+assert a_ref_t.shape == (M, K)  # The shape expected by plgpu.wgmma
+plgpu.wgmma(acc_ref, a_ref_t, b_ref)
+```
+An analogous operation is allowed on the `B` reference in this case too.
 
 ### Hopper (`wgmma`)
 
-TODO
+In this section, we cover the basics of using the Hopper-generation TensorCores, exposed in
+PTX as the [`wgmma.mma_async` instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma).
+
+#### Allocating the accumulator
+
+In the Hopper hardware architecture the accumulator is allocated in registers, but in Pallas
+it is modeled as a mutable reference, as each MMA operation accumulates in-place.
+There are two ways to allocate the accumulator.
+
+To create a zero-initialized accumulator you can use `pl.run_scoped` with a
+`plgpu.ACC((m, n), dtype)` type.
+```python
+def compute(acc_ref):
+  ...
+  return acc_ref[...]
+output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))
+```
+Dereferencing the accumulator reference, as seen at the end of the `compute` function,
+will implicitly await all outstanding WGMMA operations.
+
+If you'd like to initialize it with an existing array, you can use `pl.run_state` with
+`plgpu.ACC.init(init_array)`:
+```python
+def compute(acc_ref):
+  ...
+  return  # pl.run_state only returns the final value of the accumulator
+output = pl.run_state(compute)(plgpu.ACC.init(init_array))
+```
+If `pl.run_state` has accumulator operands, it implicitly awaits all outstanding WGMMA
+operations before returning the final values.
+
+#### Preparing the `A` and `B` operands
+
+As discussed above, we recommend passing in `A` and `B` through shared memory. In this
+case the correct tiling and swizzling transforms must be specified.
+
+`plgpu.wgmma` additionally allows passing in `A` through registers (i.e. not an SMEM
+reference but as a regular JAX array). This mode, however, comes with a number of
+significant drawbacks and it is very difficult to ensure sufficient synchronization to
+make this safe.
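+
+As a rough sketch only (the variable names and the surrounding setup here are
+assumptions made for illustration, not part of the documented API), the registers
+mode replaces the SMEM reference for `A` with an array value:
+```python
+a_regs = a_smem[...]  # read `A` out of SMEM into registers as a plain array
+plgpu.wgmma(acc_ref, a_regs, b_smem)  # `B` must still be an SMEM reference
+```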
+
+TODO: Explain the conditions under which it is acceptable to do this.
+
+#### Issuing the operation
+
+The supported MMA shapes are such that:
+* `M` is divisible by 64
+* `N` is divisible by 8 and smaller than 256
+* `K` is a multiple of `swizzle` divided by the bytewidth of the element type
+
+The currently supported data types are: `jnp.float32`, `jnp.bfloat16` and `jnp.float16`.
+The accumulator `D` must be a `jnp.float32`, with the exception of `jnp.float16` inputs,
+in which case it is allowed to be `jnp.float16` as well.
+
+#### Waiting for the operation to complete
+
+Each `plgpu.wgmma` call implicitly synchronizes with all previous `plgpu.wgmma` calls, such
+that once control returns from it, we guarantee that no WGMMA other than the last issued
+one is still running. As such, any SMEM regions that were read by previously issued WGMMA
+instructions can be reused. This is especially relevant for pipelining WGMMA with async memory copies:
+```python
+buffers = 3  # In reality you might want even more
+assert a_smem.shape == (buffers, m, k)
+assert b_smem.shape == (buffers, k, n)
+assert acc_ref.shape == (m, n)
+
+def fetch_a_b(ki, slot):
+  a_slice = ...  # Replace with the right M/K slice
+  b_slice = ...  # Replace with the right K/N slice
+  plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
+  plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])
+
+def loop_body(i, _):
+  slot = jax.lax.rem(i, buffers)
+  plgpu.barrier_wait(a_loaded.at[slot])
+  plgpu.barrier_wait(b_loaded.at[slot])
+  plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
+  # We know that only the last issued WGMMA can still be running, so we can
+  # issue an async load into the buffer that is no longer being read.
+  load_i = i + buffers - 1
+  load_slot = jax.lax.rem(load_i, buffers)
+  @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
+  def _do_fetch():
+    fetch_a_b(load_i, load_slot)
+for slot in range(buffers):
+  fetch_a_b(slot, slot)
+jax.lax.fori_loop(0, num_steps, loop_body, None)
+```
 
 ### Blackwell (`tcgen05`)
 
-TODO
+While Mosaic GPU supports `tcgen05` MMA instructions, exposing this capability to Pallas
+is still work in progress. Stay tuned!
 
 ## Using `core_map`
 
 TODO
 
 ## Synchronization structures and primitives
 
+In this section, we go over the most important functions and data structures
+used for synchronization between threads and also some asynchronous operations.
+
 ### `commit_smem`
 
 Regular reads/writes to references are guaranteed to produce values consistent
@@ -162,6 +309,20 @@ There are three operations that can complete a barrier:
 Failing to satisfy this will corrupt the data structure and can cause surprising failures
 (including CUDA runtime errors). See below for an example of a valid program with two threads.
 
+> Another critical restriction is that the number of barrier completions must equal the
+  number of barrier waits throughout the barrier's lifetime. It is not allowed to end a scoped
+  allocation of a barrier when it has an unawaited completion. Otherwise, when it is
+  reused by the compiler, leaving it in this state can cause problems downstream.
+
+> Finally, it is crucial to ensure that each thread that ever waits on a `Barrier`
+  takes part in all `wait` operations on it. It is not allowed to e.g. await every
+  other completion of a barrier from one thread, and all other completions from another
+  one. Doing so will lead to deadlocks.
To recap: when a `Barrier` is used to wait in + some thread, it must observe every single completion of that barrier (by waiting on it). + + + Note that the `Barrier` can receive arrivals from any source, without restrictions. + #### Asynchronous GMEM-to-SMEM copies When an asynchronous GMEM-to-SMEM copy is being executed by the TMA engine, it will @@ -209,7 +370,8 @@ pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None)) #### Awaiting `tcgen05` TensorCore instructions -TODO +While Mosaic GPU supports `tcgen05` MMA instructions, exposing this capability to Pallas +is still work in progress. Stay tuned! ### `ClusterBarrier` From b8df4749652c4f046b504a28996bdd3173fd5db8 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 14 Apr 2025 12:53:59 +0300 Subject: [PATCH 0604/1769] [explain_cache_miss] Add to explanations the duration of the missed function call This enables the user to focus on the most important call sites. jax-fixit --- jax/_src/linear_util.py | 12 ++++++++---- jax/_src/pjit.py | 5 +++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 1231d3066062..bfe87430554e 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -67,6 +67,7 @@ def trans1(static_arg, *dynamic_args, **kwargs): from collections.abc import Callable, Sequence from functools import partial import re +import time from typing import Any, Hashable, NamedTuple import warnings import weakref @@ -446,7 +447,7 @@ def valid_size(d) -> bool: def cache(call: Callable, *, - explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None): + explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None): """Memoization decorator for functions taking a WrappedFun as first argument. Args: @@ -455,7 +456,8 @@ def cache(call: Callable, *, memoization cache key. explain: a function that is invoked upon cache misses to log an explanation - of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`. + of the miss. + Invoked with `(fun, is_cache_first_use, cache, key, elapsed_sec)`. Returns: A memoized version of ``call``. @@ -470,9 +472,11 @@ def memoized_fun(fun: WrappedFun, *args): ans, stores = result fun.populate_stores(stores) else: + if do_explain := explain and config.explain_cache_misses.value: + start = time.time() ans = call(fun, *args) - if explain and config.explain_cache_misses.value: - explain(fun, cache is new_cache, cache, key) + if do_explain: + explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore cache[key] = (ans, fun.stores) return ans diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index afc7a5bed52f..a10856cbced3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1357,7 +1357,8 @@ def arg_type_to_str(at): def explain_tracing_cache_miss( - fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): + fun: lu.WrappedFun, unseen_f: bool, cache: dict, + key: tuple, elapsed_sec: float): if config.check_tracer_leaks.value: return if key[3][2].val: return # No explanations for "inline" functions @@ -1371,7 +1372,7 @@ def explain_tracing_cache_miss( done = lambda: logger.log(logging.WARNING, "\n".join(msg)) callsite = source_info_util.summarize(source_info_util.current()) - p(f"TRACING CACHE MISS at {callsite} because:") + p(f"TRACING CACHE MISS at {callsite} costing {elapsed_sec * 1e3:.3f} ms because:") # have we seen this function before at all? 
  src_info = ""

From 11a6abc6f7a94ea8991dfe423264130e4afa75df Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 14 Apr 2025 13:26:09 +0000
Subject: [PATCH 0605/1769] Remove accidental tab characters from Pallas:MGPU
 docs

---
 docs/pallas/gpu/reference.md | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md
index 1e23d2eb7813..5cebdf364e5b 100644
--- a/docs/pallas/gpu/reference.md
+++ b/docs/pallas/gpu/reference.md
@@ -220,21 +220,21 @@ assert acc_ref.shape == (m, n)
 def fetch_a_b(ki, slot):
   a_slice = ...  # Replace with the right M/K slice
   b_slice = ...  # Replace with the right K/N slice
-	plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
-	plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])
+  plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
+  plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])
 
 def loop_body(i, _):
-	slot = jax.lax.rem(i, buffers)
-	plgpu.barrier_wait(a_loaded.at[slot])
-	plgpu.barrier_wait(b_loaded.at[slot])
-	plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
-	# We know that only the last issued WGMMA can still be running, so we can
-	# issue an async load into the buffer that is no longer being read.
-	load_i = i + buffers - 1
-	load_slot = jax.lax.rem(load_i, buffers)
-	@pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
-	def _do_fetch():
-	fetch_a_b(load_i, load_slot)
+  slot = jax.lax.rem(i, buffers)
+  plgpu.barrier_wait(a_loaded.at[slot])
+  plgpu.barrier_wait(b_loaded.at[slot])
+  plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
+  # We know that only the last issued WGMMA can still be running, so we can
+  # issue an async load into the buffer that is no longer being read.
+  load_i = i + buffers - 1
+  load_slot = jax.lax.rem(load_i, buffers)
+  @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
+  def _do_fetch():
+    fetch_a_b(load_i, load_slot)
 for slot in range(buffers):
   fetch_a_b(slot, slot)
 jax.lax.fori_loop(0, num_steps, loop_body, None)

From 077d134d6fe4867b253e97e0a914ba5ea7594ef2 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Mon, 14 Apr 2025 07:06:11 -0700
Subject: [PATCH 0606/1769] Adjust test expectations for the
 tracing-cache-miss-explanations

This fixes an error introduced in
https://github.com/jax-ml/jax/pull/27980 for the case when we use a
compilation cache.

PiperOrigin-RevId: 747401715
---
 tests/api_test.py | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/tests/api_test.py b/tests/api_test.py
index 8705d2021577..59a6211e3cbc 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -4504,8 +4504,7 @@ def f(*args, **kwargs): return args[0]
     with self.assertLogs(level="WARNING") as cm:
       # Same number of leaves but different trees
       f(0., (1., 1.1), y=2.)
-    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
-    self.assertLen(cm.output, expected_log_len)
+    self.assertLen(cm.output, 1)
     msg = cm.output[0]
     self.assertIn("different input pytree", msg)
     self.assertNotIn("explanation unavailable!", msg)
@@ -4521,8 +4520,8 @@ def f(x, y): return jnp.sin(x) + y
 
     with config.explain_cache_misses(True):
       with self.assertLogs(level="WARNING") as cm:
        f(0., y=1.)
- expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - self.assertLen(cm.output, expected_log_len) + + self.assertLen(cm.output, 1) msg = cm.output[0] self.assertIn("different number of args and kwargs, but same total number", msg) self.assertIn("now 1 args and kwargs with keys ['y']", msg) @@ -4540,16 +4539,15 @@ def f(x, y, z): with config.explain_cache_misses(True): with self.assertLogs(level="WARNING") as cm: f(1., 2., "bar") - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - self.assertLen(cm.output, expected_log_len) + self.assertLen(cm.output, 1) msg = cm.output[0] self.assertIn("different value of static args", msg) self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg) - self.assertNotIn('explanation unavailable!', msg) + self.assertNotIn("explanation unavailable!", msg) @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_other_static_argnames(self): - @partial(jax.jit, static_argnames='foo') + @partial(jax.jit, static_argnames="foo") def f(*, foo): return 1 @@ -4558,8 +4556,7 @@ def f(*, foo): with config.explain_cache_misses(True): with self.assertLogs(level="WARNING") as cm: f(foo="bar") - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - self.assertLen(cm.output, expected_log_len) + self.assertLen(cm.output, 1) msg = cm.output[0] self.assertIn("different value of static kwargs", msg) self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg) @@ -4574,8 +4571,7 @@ def f(x, y): return x with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: f(np.float32(0), np.int32(1)) - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - self.assertLen(cm.output, expected_log_len) + self.assertLen(cm.output, 1) msg = cm.output[0] self.assertIn("different input types", msg) self.assertIn("at y, now i32[] and before f32[]", msg) @@ -4672,8 +4668,7 @@ def f(x): return jnp.sin(x) # not in shape. f(np.arange(8, dtype=np.float32)) - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - self.assertLen(cm.output, expected_log_len) + self.assertLen(cm.output, 1) msg = cm.output[0] self.assertIn("key with different input types", msg) self.assertIn("at x, now f32[8] and before f32[4]", msg) From 1b1bd071bce99e8436129eebf4b1b08da607565c Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 14 Apr 2025 07:43:02 -0700 Subject: [PATCH 0607/1769] Finalize deprecation of vectorized argument in callbacks. The `vectorized` argument to `pure_callback` and `ffi_call` was deprecated in JAX v0.4.34 (released Oct 4 2024), then added to the CHANGELOG in v0.4.35 (doh! released Oct 22). The JAX compatibility policy requires 3 months of compatible releases before a deprecation is finalized, so it is time to remove this parameter from the public API. The `vmap_method` parameter can be used instead, and the docs for [`pure_callback`](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) provide more details. This change has one other (non-obvious!) affect on the user facing APIs. (Note that this change in behavior has also been protected by a deprecation warning since the `vectorized` parameter was deprecated.) The default behavior of `pure_callback` and `ffi_call` under `vmap` is to now raise an exception, rather than silently producing a loop. To opt in to the previous default behavior, use `vmap_method="sequential"`. 
PiperOrigin-RevId: 747413383 --- CHANGELOG.md | 2 ++ jax/_src/callback.py | 36 +++++++++----------------------- jax/_src/ffi.py | 39 +++++++---------------------------- tests/ffi_test.py | 8 ------- tests/python_callback_test.py | 17 +-------------- 5 files changed, 21 insertions(+), 81 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index db5e0f9763d7..810e23b10139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most have no public replacement, though a few are available at {mod}`jax.extend.core`. + * The `vectorized` argument to {func}`~jax.pure_callback` and + {func}`~jax.ffi.ffi_call`. Use the `vmap_method` parameter instead. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 630539ab0db8..25bdb801edce 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -23,7 +23,6 @@ import jax from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -48,10 +47,6 @@ logger = logging.getLogger(__name__) -# TODO(dfm): Remove after 6 months. -# Added Oct 1, 2024 -deprecations.register("jax-callback-vectorized") - # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") pure_callback_p.multiple_results = True @@ -83,10 +78,9 @@ def pure_callback_impl( result_avals, callback: _FlatCallback, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del sharding, vectorized, vmap_method, result_avals + del sharding, vmap_method, result_avals try: cpu_device, *_ = jax.local_devices(backend="cpu") except RuntimeError as e: @@ -114,10 +108,9 @@ def pure_callback_abstract_eval( callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del avals, callback, sharding, vectorized, vmap_method + del avals, callback, sharding, vmap_method return result_avals @@ -292,7 +285,7 @@ def pure_callback( When `vmap`-ed the behavior will depend on the value of the ``vmap_method``. * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` - is deprecated and it will eventually raise ``NotImplementedError``. + raises a ``NotImplementedError``. * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over the batched arguments, calling ``callback`` once for each batch element. * ``vmap_method="sequential_unrolled"`` is like ``sequential``, but the loop @@ -302,9 +295,8 @@ def pure_callback( * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the inputs are tiled to the expected batched shape. - If necessary, the legacy behavior provided by the deprecated - ``vectorized=True`` argument can be recovered using - ``vmap_method="legacy_vectorized"``. + If necessary, the legacy behavior provided by the removed ``vectorized=True`` + argument can be recovered using ``vmap_method="legacy_vectorized"``. The current default behavior is to use ``vmap_method="sequential"`` when not specified, but this behavior is deprecated, and in the future, the @@ -373,18 +365,11 @@ def pure_callback( .. 
_External Callbacks: https://docs.jax.dev/en/latest/external-callbacks.html """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of jax.pure_callback is deprecated and setting " - "it will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of jax.pure_callback cannot " - "be used together. Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. + if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.pure_callback was removed in JAX " + "v0.6.0. Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -402,7 +387,6 @@ def pure_callback( callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, - vectorized=vectorized, vmap_method=vmap_method, ) return tree_util.tree_unflatten(out_tree, out_flat) diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index d1ad267543b8..22d54d39913d 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -24,7 +24,6 @@ import jax from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import effects from jax._src import util @@ -328,7 +327,7 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), ) -> Callable[..., Array]: ... @@ -345,7 +344,7 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), ) -> Callable[..., Sequence[Array]]: ... @@ -361,7 +360,7 @@ def ffi_call( input_output_aliases: dict[int, int] | None = None, custom_call_api_version: int = 4, legacy_backend_config: str | None = None, - vectorized: bool | DeprecatedArg = DeprecatedArg(), + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), ) -> Callable[..., Array | Sequence[Array]]: """Call a foreign function interface (FFI) target. @@ -422,18 +421,11 @@ def ffi_call( to execute the FFI handler. Any keyword arguments are passed as named attributes to the FFI handler using XLA's FFI interface. """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of ffi_call is deprecated and setting " - "it will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of ffi_call cannot " - "be used together. Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. 
+ if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.ffi.ffi_call was removed in JAX " + "v0.6.0. Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -511,7 +503,6 @@ def wrapped(*args: ArrayLike, **kwargs: Any): results = ffi_call_p.bind( *args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, target_name=target_name, has_side_effect=has_side_effect, @@ -665,21 +656,10 @@ def ffi_batching_rule( args, dims, *, - vectorized: bool | None | DeprecatedArg, vmap_method: str | None, result_avals: Sequence[core.ShapedArray], **kwargs: Any, ): - if isinstance(vectorized, DeprecatedArg) and vmap_method is None: - deprecations.warn( - "jax-callback-vectorized", - f"The default behavior of {prim.name} under vmap will soon " - "change. Currently, the default behavior is to generate a sequential " - "vmap (i.e. a loop), but in the future the default will be to raise " - "an error. To keep the current default, set vmap_method='sequential'.", - stacklevel=6) - vmap_method = "sequential" - axis_size, = {a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped} new_args = [arg if dim is batching.not_mapped else @@ -707,7 +687,6 @@ def ffi_batching_rule( for layout, d in zip(kwargs["input_layouts"], dims)) outvals = prim.bind( *new_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -723,7 +702,6 @@ def ffi_batching_rule( for layout in kwargs["input_layouts"]) outvals = prim.bind( *bcast_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -736,7 +714,6 @@ def _batch_fun(batched_args): return prim.bind( *merged_args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, **kwargs, ) diff --git a/tests/ffi_test.py b/tests/ffi_test.py index abd80096643b..978415194e55 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -200,14 +200,6 @@ def test_ffi_call_batching(self, shape, vmap_method): else: self.assertArraysEqual(a, b) - @jtu.run_on_devices("gpu", "cpu") - def test_vectorized_deprecation(self): - x = self.rng().randn(3, 5, 4).astype(np.float32) - with self.assertWarns(DeprecationWarning): - ffi_call_geqrf(x, vectorized=True) - with self.assertWarns(DeprecationWarning): - jax.vmap(ffi_call_geqrf)(x) - def test_input_output_aliases(self): def fun(x): return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index a8442b4a1356..e4aeb8d66f9e 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1041,26 +1041,11 @@ def f(x): def test_vmap_method_raise(self): @jax.vmap def f(x): - # Setting vectorized to None disables the current default behavior of - # falling back on sequential. 
-      return jax.pure_callback(np.sin, x, x, vectorized=None)
+      return jax.pure_callback(np.sin, x, x)
 
     with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"):
       f(jnp.arange(4.))
 
-  def test_deprecated_vectorized(self):
-    def f(x, **kwargs):
-      return jax.pure_callback(np.sin, x, x, **kwargs)
-
-    with self.assertWarnsRegex(DeprecationWarning, "The default behavior"):
-      jax.vmap(f)(jnp.arange(4.0))
-
-    with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
-      f(jnp.arange(4.0), vectorized=True)
-
-    with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
-      f(jnp.arange(4.0), vectorized=False)
-
   def test_vmap_method_expand_dims(self):
     def callback(x, y):
       self.assertTupleEqual(x.shape, (4,))

From ceca6ec1fc5b30853fa530f7c548667f920c8f9d Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Fri, 11 Apr 2025 11:58:19 -0700
Subject: [PATCH 0609/1769] jax.jit: deprecate non-standard call signature.

---
 CHANGELOG.md             |  3 +++
 jax/_src/api.py          | 34 +++++++++++++++++++++++++++++++++-
 jax/_src/deprecations.py |  1 +
 tests/api_test.py        | 14 ++++++++++++++
 4 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 810e23b10139..3fb64ef09bb1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * JAX package extras are now updated to use dash instead of underscore to
     align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]`
     to install JAX, run `pip install jax[cuda12-local]` instead.
+  * {func}`jax.jit` now requires `fun` to be passed by position, and additional
+    arguments to be passed by keyword. Doing otherwise will result in a
+    DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
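+    For example (with `f` standing in for any function of your own),
+    `jax.jit(fun=f)` and `jax.jit(f, None)` now emit the warning; write
+    `jax.jit(f)` and `jax.jit(f, in_shardings=None)` instead.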
* Deprecations diff --git a/jax/_src/api.py b/jax/_src/api.py index 23c0a610cee9..43ab7729a348 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -37,6 +37,7 @@ import numpy as np from contextlib import contextmanager +from jax._src import deprecations from jax._src import linear_util as lu from jax._src import stages from jax._src.tree_util import ( @@ -147,8 +148,39 @@ def _update_debug_special_thread_local(_): float0 = dtypes.float0 +# TODO(jakevdp): remove this for v0.7.0 (~July 2025) +def _allow_deprecated_jit_signature(f: F) -> F: + """Temporary decorator for the jit signature deprecation.""" + @wraps(f) + def wrapped(*args, **kwargs): + if len(args) == 1 or deprecations.is_accelerated('jax-jit-positional-args'): + # Fast path for typical usage. + return f(*args, **kwargs) + if 'fun' in kwargs: + deprecations.warn( + 'jax-jit-positional-args', + ('jax.jit: passing fun by keyword is deprecated.' + ' Pass it by position to silence this warning.'), + stacklevel=2 + ) + return f(kwargs.pop('fun'), **kwargs) + if len(args) > 1: + deprecations.warn( + 'jax-jit-positional-args', + ('jax.jit: passing optional arguments by position is deprecated. ' + ' Pass them by keyword to silence this warning.'), + stacklevel=2 + ) + sig = inspect.signature(f) + kwds = dict(unsafe_zip((p.name for p in sig.parameters.values()), args)) + return f(kwds.pop('fun'), **kwds, **kwargs) + return f(*args, **kwargs) + return cast(F, wrapped) + + +@_allow_deprecated_jit_signature def jit( - fun: Callable, + fun: Callable, /, *, in_shardings: Any = sharding_impls.UNSPECIFIED, out_shardings: Any = sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 6c39c893a111..329491b1e8a8 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -134,3 +134,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') register('jax-scipy-special-sph-harm') +register('jax-jit-positional-args') diff --git a/tests/api_test.py b/tests/api_test.py index 59a6211e3cbc..7590b6e6af71 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -52,6 +52,7 @@ from jax._src import config from jax._src import core from jax._src import custom_derivatives +from jax._src import deprecations from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge @@ -10516,6 +10517,19 @@ def tp(r, t): return 2 * fn(r, t) self.assertAllClose(f_(x), g_(x)) self.assertAllClose(f_t(x), g_t(x)) + def test_jit_signature_deprecation(self): + fun = lambda x: x + if deprecations.is_accelerated('jax-jit-positional-args'): + with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'): + jax.jit(fun=fun) + with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'): + jax.jit(fun, None) + else: + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'): + jax.jit(fun=fun) + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'): + jax.jit(fun, None) + def test_cond(self): def f(x, y): @custom_transpose(jnp.ones(2)) From 42542feac61ddab94ef866309eba31f1b7bcd0b8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 14 Apr 2025 10:42:29 -0700 Subject: [PATCH 0610/1769] jnp.power: better docs for invalid input --- 
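Note: a minimal sketch of the documented behavior (exact error messages may
vary across versions):

    import jax
    import jax.numpy as jnp

    jnp.power(2, -1)           # TypeError: integer raised to a concrete
                               # negative integer power
    jnp.power(2.0, -1)         # 0.5; floating-point inputs are fine
    jax.jit(jnp.power)(2, -1)  # the power is non-concrete under jit, so the
                               # result is implementation-defined, not an error
    jnp.power(-2.0, 0.5)       # nan: negative value raised to a non-integer power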
jax/_src/numpy/ufuncs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 3fe63545e6df..509b046554d3 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2615,8 +2615,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.lax.integer_pow`. - When ``x2`` is a traced scalar or an array, ``jnp.power`` lowers to :func:`jax.lax.pow`. - - ``jnp.power`` raises a ``TypeError`` for integer type raised to negative - integer power. + - ``jnp.power`` raises a ``TypeError`` for integer type raised to a concrete + negative integer power. For a non-concrete power, the operation is invalid + and the returned value is implementation-defined. - ``jnp.power`` returns ``nan`` for negative value raised to the power of non-integer values. From 785d07759da9a4440bd7fbdb6e10a81141b1e607 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 14 Apr 2025 11:01:55 -0700 Subject: [PATCH 0611/1769] Disable unknown warning option error on Mac. Clang 19 requires `-Wno-error=c23-extensions` (https://github.com/jax-ml/jax/commit/3d006e9f1566afc6e450243cd4dfdf792433a73d) but this flag is not supported on Apple Clang in XCode 16.0 so we suppress unknown warning option errors on Mac CI builds. Fixes https://btx.cloud.google.com/invocations/15ded1a8-956e-462e-9da9-9748b4e4a03e/targets/ml_oss%2Fjax%2Fbuild_artifacts%2Fnightly%2Fjaxlib%2Fmacos_arm64%2Fpy310/log PiperOrigin-RevId: 747489006 --- .bazelrc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.bazelrc b/.bazelrc index 22a7a072e467..755572f21355 100644 --- a/.bazelrc +++ b/.bazelrc @@ -257,6 +257,10 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 +# Clang 19 requires `-Wno-error=c23-extensions` but this flag is not supported +# on Apple Clang in XCode 16.0 so we suppress unknown warning option errors +# on Mac CI builds. +build:ci_darwin_arm64 --copt=-Wno-unknown-warning-option build:ci_darwin_arm64 --config=macos_cache_push build:ci_darwin_arm64 --verbose_failures=true build:ci_darwin_arm64 --color=yes From 1fcb2b47055cc18b144b617adb6f1fb29eb9a207 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 14 Apr 2025 18:45:23 +0000 Subject: [PATCH 0612/1769] Reinstate lifegiving chaos line to docs. --- docs/random-numbers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 00f77e3473bb..134b690839e0 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -152,7 +152,7 @@ print(random.normal(key)) Re-using the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable. -**The rule of thumb is: never reuse keys (unless you want identical outputs).** +**The rule of thumb is: never reuse keys (unless you want identical outputs). Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__.** JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. 
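+To make the rule of thumb concrete, reusing a key simply replays the same
+draw (a minimal sketch):
+
+```python
+from jax import random
+
+key = random.key(0)
+print(random.normal(key))  # some value
+print(random.normal(key))  # the identical value again: the key was reused
+```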
In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function:

From 760d0e0e97467dac41fd026eea74930a3e16823e Mon Sep 17 00:00:00 2001
From: cjkkkk
Date: Mon, 14 Apr 2025 20:33:30 +0000
Subject: [PATCH 0613/1769] disable packed layout test on old arch prior to
Hopper

---
jax/_src/cudnn/fused_attention_stablehlo.py | 5 +++--
tests/fused_attention_stablehlo_test.py | 2 ++
2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py
index 818bc018cdf5..64e47dcd6a4b 100644
--- a/jax/_src/cudnn/fused_attention_stablehlo.py
+++ b/jax/_src/cudnn/fused_attention_stablehlo.py
@@ -367,8 +367,9 @@ def check_is_flash_attention(
f"Unsupported sequence length Q {T}, KV {S}."
)

- if is_packed and cudnn_version < 90600:
- raise NotImplementedError("Packed layout requires cudnn version >= 9.6.")
+ if is_packed and (cudnn_version < 90600 or not check_compute_capability("9.0")):
+ raise NotImplementedError(
+ "Packed layout requires cudnn version >= 9.6 and at least hopper arch.")

def check_cudnn_version():
# check if cuDNN is installed

diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py
index af0b18b02f37..b0d040f8e6ec 100644
--- a/tests/fused_attention_stablehlo_test.py
+++ b/tests/fused_attention_stablehlo_test.py
@@ -618,6 +618,8 @@ def test_sdpa_packed_layout(self):
return
if cudnn_version < 90600:
self.skipTest("Requires >= cuDNN 9.6.0")
+ if not jtu.is_cuda_compute_capability_at_least("9.0"):
+ self.skipTest("Requires at least Hopper arch")
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(
k1, (4, 512, 4, 64), dtype=jnp.bfloat16)

From 8930a67e635906aaccd326cf40ade4bcd28c5350 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 14 Apr 2025 13:32:50 -0700
Subject: [PATCH 0614/1769] Fix stablehlo version comparison in test
utilities.

PiperOrigin-RevId: 747547427
---
jax/_src/test_util.py | 4 +++-
jax/_src/xla_bridge.py | 4 ++--
2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 3fff55d9ed1c..caff1c73145b 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -429,7 +429,9 @@ def stablehlo_version_at_least(required_version: str):
plugin_version = xla_bridge.backend_stablehlo_version()
if plugin_version is None:
return True
- return hlo.get_smaller_version(plugin_version, required_version) == plugin_version
+ return hlo.get_smaller_version(
+ ".".join(map(str, plugin_version)), required_version
+ ) == required_version

def get_tpu_version() -> int:
if device_under_test() != "tpu":
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index c235101839d8..5fb42c333605 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -32,7 +32,7 @@
import platform as py_platform
import threading
import traceback
-from typing import Any, Union
+from typing import Any, Sequence, Union
import warnings

from jax._src import config
@@ -1086,7 +1086,7 @@ def backend_xla_version(platform=None) -> int | None:
backend = get_backend(platform)
return getattr(backend, "xla_version", None)

-def backend_stablehlo_version(platform=None) -> int | None:
+def backend_stablehlo_version(platform=None) -> Sequence[int] | None:
"""Returns the StableHLO version of the backend.
Returns None if the backend does not use PJRT C API or does not have From 73305e03feb0f0426d3c75f7c72d5a334a6f89a3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 14 Apr 2025 14:35:30 -0700 Subject: [PATCH 0615/1769] Update issue template with correct URL for untemplated issue --- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 628310519b66..1f8c2b2ac254 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -24,7 +24,7 @@ body: [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues - [Raw report]: http://github.com/jax-ml/jax/issues/new + [Raw report]: https://github.com/jax-ml/jax/issues/new?template=none - type: textarea attributes: label: Description From a88486ca7037737cbd6f718bfcc856c21cf0b346 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 14 Apr 2025 15:15:32 -0700 Subject: [PATCH 0616/1769] Fix warnings in array_interoperability_test. PiperOrigin-RevId: 747586780 --- tests/BUILD | 3 --- tests/array_interoperability_test.py | 8 ++++++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 234e577e015a..c46ea7556a44 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -102,9 +102,6 @@ jax_multiplatform_test( enable_configs = [ "gpu_h100x2", ], - env = { - "PYTHONWARNINGS": "default", # TODO(b/394123878): protobuf, via TensorFlow, issues a Python warning under Python 3.12+ sometimes. - }, tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 80a4d8ef5a25..15fd6306da6f 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -95,6 +95,10 @@ def setUp(self): message="Calling from_dlpack with a DLPack tensor", category=DeprecationWarning, ) + @jtu.ignore_warning( + message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning, + ) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) @@ -188,6 +192,10 @@ def testTensorFlowToJax(self, shape, dtype): dtype=dlpack_dtypes, ) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning( + message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning, + ) def testJaxToTensorFlow(self, shape, dtype): if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): From ab600c3e82c0eef512c56e4d3f1ac12af7b4fd77 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Mon, 14 Apr 2025 16:32:24 -0700 Subject: [PATCH 0617/1769] Remove obsolete python key path registry. PiperOrigin-RevId: 747613761 --- jax/_src/tree_util.py | 49 +------------------------------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 883937fcce6e..b73d84b330de 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -21,7 +21,7 @@ from functools import partial import operator as op import textwrap -from typing import Any, NamedTuple, TypeVar, overload +from typing import Any, TypeVar, overload from jax._src import traceback_util from jax._src.lib import pytree @@ -762,42 +762,6 @@ def _simple_entrystr(key: KeyEntry) -> str: return str(key) -# TODO(ivyzheng): remove this after another jaxlib release. 
-class _RegistryWithKeypathsEntry(NamedTuple): - flatten_with_keys: Callable[..., Any] - unflatten_func: Callable[..., Any] - - -def _register_keypaths( - ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]] -) -> None: - def flatten_with_keys(xs): - children, treedef = _registry[ty].to_iter(xs) - return list(zip(handler(xs), children)), treedef - if ty in _registry: - _registry_with_keypaths[ty] = _RegistryWithKeypathsEntry( - flatten_with_keys, _registry[ty].from_iter - ) - -_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {} - -_register_keypaths( - tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths( - list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs))) - -_register_keypaths( - collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - -_register_keypaths( - collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - - @export def register_pytree_with_keys( nodetype: type[T], @@ -867,9 +831,6 @@ def flatten_func_impl(tree): register_pytree_node( nodetype, flatten_func, unflatten_func, flatten_with_keys ) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) @export @@ -1062,11 +1023,6 @@ def register_dataclass( msg += f" Unexpected fields: {unexpected}." raise ValueError(msg) - def flatten_with_keys(x): - meta = tuple(getattr(x, name) for name in meta_fields) - data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) - return data, meta - def unflatten_func(meta, data): meta_args = tuple(zip(meta_fields, meta)) data_args = tuple(zip(data_fields, data)) @@ -1082,9 +1038,6 @@ def flatten_func(x): none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) return nodetype From 7b4b2f47c96ec703a4f1d61f42006ffc544df3d7 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 14 Apr 2025 13:46:45 +0000 Subject: [PATCH 0618/1769] Fixed the way to skip tests using optional python dependencies Description: - Fixed the way to skip tests using optional python dependencies - for example, tests using cloudpickle - Remove if/else, CAN_USE_HYPOTHESIS for hypothesis --- tests/colocated_python_test.py | 12 +- tests/mosaic/matmul_test.py | 10 +- tests/pallas/indexing_test.py | 19 +- tests/pallas/ops_test.py | 35 +- tests/pallas/tpu_all_gather_test.py | 196 ++-- tests/pallas/tpu_gmm_test.py | 621 +++++----- tests/pallas/tpu_ops_test.py | 8 +- tests/pallas/tpu_pallas_pipeline_test.py | 218 ++-- .../tpu_splash_attention_kernel_test.py | 16 +- tests/state_test.py | 1018 ++++++++--------- 10 files changed, 1058 insertions(+), 1095 deletions(-) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 52d494904fe6..ada17bc61c82 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -18,7 +18,6 @@ import threading import time from typing import Sequence -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -36,8 +35,9 @@ try: import cloudpickle # noqa + HAS_CLOUDPICKLE = True except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on cloudpickle library") + 
HAS_CLOUDPICKLE = False def _colocated_cpu_devices( @@ -68,10 +68,14 @@ class ColocatedPythonTest(jtu.JaxTestCase): def setUp(self): super().setUp() + if not HAS_CLOUDPICKLE: + self.skipTest( + "ColocatedPythonTest depends on cloudpickle library" + ) if np.lib.NumpyVersion(np.__version__) < "2.0.0": self.skipTest( - "Serialization in Colocated Python needs StringDType, and thus" - " requires NumPy 2.0.0 or later" + "Serialization in Colocated Python needs StringDType, and thus" + " requires NumPy 2.0.0 or later" ) def testMakeColocatedPythonProgram(self): diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index 41e60fbe4c29..9634718d2d44 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -15,12 +15,15 @@ """Test different parameterizations of a matmul.""" import os -import unittest from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp + +import hypothesis as hp +import hypothesis.strategies as hps + try: # We only import this to see if Mosaic is available. import jax.experimental.mosaic.gpu # noqa: F401 @@ -28,11 +31,6 @@ matmul = None else: from jax.experimental.mosaic.gpu.examples import matmul -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") config.parse_flags_with_absl() diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 3de0c1f305c6..cb862b406603 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -14,7 +14,6 @@ from __future__ import annotations import sys -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -32,11 +31,7 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps @@ -95,7 +90,7 @@ def array_indexer_strategy(draw, shape) -> jax.Array: @hps.composite def indexer_strategy(draw, dim, int_indexer_shape - ) -> int | Slice | jax.Array: + ) -> int | Slice | jax.Array: return draw(hps.one_of( int_indexer_strategy(dim), slice_indexer_strategy(dim), @@ -372,7 +367,7 @@ def permute_columns_in_row_kernel(left, right, new_left, new_right): def test_vmap_nd_indexing(self, data): self.skipTest("TODO(necula): enable this test; was in jax_triton.") vmap_shape = data.draw(hnp.array_shapes(min_dims=1, max_dims=3, min_side=2), - label="vmap_shape") + label="vmap_shape") el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape") # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs # hp.assume(len(el_shape) >= 2) @@ -389,7 +384,7 @@ def kernel(x_ref, y_ref): shape = el_shape for vmap_dim in vmap_shape[::-1]: index = data.draw(hps.integers(min_value=0, - max_value=max(0, len(shape) - 2)), + max_value=max(0, len(shape) - 2)), label="index") # hp.assume(index <= max(0, len(shape) - 2)) # TODO(sharadmv,apaszke): enable vmapping over batch axes in 2 minormost @@ -649,9 +644,9 @@ def kernel(x_ref, o_ref, sem_ref): # Use scalar_val in both async_copy and store. 
o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val desc = pltpu.make_async_copy( - o_ref.at[scalar_val], - o_ref.at[scalar_val + 1], - sem_ref, + o_ref.at[scalar_val], + o_ref.at[scalar_val + 1], + sem_ref, ) desc.start() desc.wait() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0580d7ec5824..709828186480 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -48,14 +48,11 @@ plgpu_triton = None pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps + # There are many inherited redefinitions of _ # ruff: noqa: F811 @@ -188,7 +185,7 @@ def select_n_strategy( else: pred_dtype = np.int32 pred = draw(arrays(shape=pred_shape, dtype=pred_dtype, - elements=allowed_elements)) + elements=allowed_elements)) cases = ( draw( arrays(shape=case_shape_dtype.shape, dtype=case_shape_dtype.dtype) @@ -332,7 +329,7 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.full((8, 128), 4, dtype=dtype) y = jnp.full((8, 128), 2 if jnp.issubdtype(dtype, jnp.integer) else 2.0, - dtype=dtype) + dtype=dtype) np.testing.assert_allclose(kernel(x, y), fn(x, y)) @parameterized.named_parameters( @@ -1071,8 +1068,8 @@ def kernel(x_ref, o_ref): ( # fmt: off [jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt, jnp.tan, jnp.asin, - jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, - jnp.acosh, jnp.atanh], + jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, + jnp.acosh, jnp.atanh], # fmt: on ["bfloat16", "float32", "float64"], ), @@ -1096,7 +1093,7 @@ def test_elementwise(self, fn, dtype): self.skipTest("int16 and float16 are not supported on TPU") if ( fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log, - jnp.sqrt, lax.rsqrt) + jnp.sqrt, lax.rsqrt) and dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6) ): @@ -1474,7 +1471,7 @@ def kernel(x_ref, y_ref, o_ref): ( # fmt: off [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor, - jnp.bitwise_left_shift, jnp.bitwise_right_shift], + jnp.bitwise_left_shift, jnp.bitwise_right_shift], # fmt: on ["int32", "uint32"], ), @@ -1918,7 +1915,7 @@ def dot(x_ref, y_ref, o_ref): # Pallas always accumulates in FP32, so we are explicit about # preferred_element_type here. 
expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y, - preferred_element_type=jnp.float32).astype(dtype) + preferred_element_type=jnp.float32).astype(dtype) np.testing.assert_allclose( out.astype(jnp.float32), expected.astype(jnp.float32), @@ -2107,7 +2104,7 @@ def test_masked_oob_swap_slice(self): @functools.partial( self.pallas_call, out_shape=(jax.ShapeDtypeStruct((n,), floatx), - jax.ShapeDtypeStruct((m,), floatx)), + jax.ShapeDtypeStruct((m,), floatx)), input_output_aliases={0: 0, 1: 1}, ) def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): @@ -2237,7 +2234,7 @@ def swap(_, lock_ref, out_ref): lock, out = swap(init_value) np.testing.assert_allclose(lock, new_value if cmp == init_value else - init_value) + init_value) np.testing.assert_allclose(out, init_value) @parameterized.parameters(1, 2, 3, 4, 8) @@ -2603,15 +2600,15 @@ def body(x_ref): @parameterized.parameters(*[ (lambda: (pl.dslice(0, 4), slice(None), slice(None)), - "c:i32[4,3,2], a[:,:,:] <-"), + "c:i32[4,3,2], a[:,:,:] <-"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), - "c:i32[3,3,2], a[:3,:,:] <-"), + "c:i32[3,3,2], a[:3,:,:] <-"), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), - "c:i32[3,3,4], a[1:,:,:4] <-"), + "c:i32[3,3,4], a[1:,:,:4] <-"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), - "e:i32[5,3,4], a[b,:,:4] <-"), + "e:i32[5,3,4], a[b,:,:4] <-"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), - "o:i32[5,3,4], a[m,n,:4] <-"), + "o:i32[5,3,4], a[m,n,:4] <-"), ]) def test_swap_pretty_print(self, expr, expected): def body(x_ref): diff --git a/tests/pallas/tpu_all_gather_test.py b/tests/pallas/tpu_all_gather_test.py index 0c9a4b545591..47168e1c35b4 100644 --- a/tests/pallas/tpu_all_gather_test.py +++ b/tests/pallas/tpu_all_gather_test.py @@ -25,115 +25,109 @@ import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() P = jax.sharding.PartitionSpec -if CAN_USE_HYPOTHESIS: - - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=50, - print_blob=True, - verbosity=hp.Verbosity.verbose, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=50, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile("deterministic") + + +@hps.composite +def _array_shapes(draw): + # TODO(sharadmv, apaszke): enable this on a wider variety of shapes + valid_shapes = [ + (128, 128), + (256, 128), + (256, 512), + (256, 1024), + # TODO(sharadmv,apaszke): enable these shapes + # (256, 129), + # (129, 128), + # (64, 64), + # (1, 1), + ] + return draw(hps.sampled_from(valid_shapes)) + + +@hps.composite +def _array_dtypes(draw): + return draw( + hps.sampled_from([ + jnp.float32, + jnp.bfloat16, + jnp.int32, + # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather + # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather + # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather + ]) ) - hp.settings.load_profile("deterministic") - - - @hps.composite - def _array_shapes(draw): - # TODO(sharadmv, apaszke): enable this on a wider variety of shapes - valid_shapes = [ - (128, 128), - (256, 128), - (256, 512), - (256, 1024), - # 
TODO(sharadmv,apaszke): enable these shapes - # (256, 129), - # (129, 128), - # (64, 64), - # (1, 1), - ] - return draw(hps.sampled_from(valid_shapes)) - - - @hps.composite - def _array_dtypes(draw): - return draw( - hps.sampled_from([ - jnp.float32, - jnp.bfloat16, - jnp.int32, - # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather - # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather - # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather - ]) - ) - @jtu.thread_unsafe_test_class() # hypothesis is not thread safe - class AllGatherTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Need TPU devices") - if not jtu.is_device_tpu(version=5, variant="e"): - # TODO(sharadmv,apaszke): expand support to more versions - self.skipTest("Currently only supported on TPU v5e") - - super().setUp() - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) - def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): - if jax.device_count() < 2: - self.skipTest("Need more devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (jax.device_count(),) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] - ) - leading, *rest = shape - shape = (mesh.shape["x"] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x"))) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", - memory_space=memory_space) - np.testing.assert_array_equal(y, x) - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), - hps.sampled_from(["x", "y"])) - def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, - axis_name): - if jax.device_count() < 2: - self.skipTest("Need more devices") - if jax.device_count() % 2: - self.skipTest("Need an even number of devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (2, jax.device_count() // 2) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] - ) - if axis_name == "x": - sharding = jax.sharding.NamedSharding(mesh, P("x", None)) - else: - sharding = jax.sharding.NamedSharding(mesh, P("y", None)) - leading, *rest = shape - shape = (mesh.shape[axis_name] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, sharding) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, - memory_space=memory_space) - np.testing.assert_array_equal(y, x) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class AllGatherTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + if not jtu.is_device_tpu(version=5, variant="e"): + # TODO(sharadmv,apaszke): expand support to more versions + self.skipTest("Currently only supported on TPU v5e") + + super().setUp() + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) + def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): + if jax.device_count() < 2: + self.skipTest("Need more devices") + memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + mesh_shape = (jax.device_count(),) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] + ) + leading, *rest = shape + shape = (mesh.shape["x"] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, 
jax.sharding.NamedSharding(mesh, P("x"))) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", + memory_space=memory_space) + np.testing.assert_array_equal(y, x) + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), + hps.sampled_from(["x", "y"])) + def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, + axis_name): + if jax.device_count() < 2: + self.skipTest("Need more devices") + if jax.device_count() % 2: + self.skipTest("Need an even number of devices") + memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + mesh_shape = (2, jax.device_count() // 2) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] + ) + if axis_name == "x": + sharding = jax.sharding.NamedSharding(mesh, P("x", None)) + else: + sharding = jax.sharding.NamedSharding(mesh, P("y", None)) + leading, *rest = shape + shape = (mesh.shape[axis_name] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, sharding) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, + memory_space=memory_space) + np.testing.assert_array_equal(y, x) if __name__ == "__main__": diff --git a/tests/pallas/tpu_gmm_test.py b/tests/pallas/tpu_gmm_test.py index cadba4c15fa0..7bc698794f09 100644 --- a/tests/pallas/tpu_gmm_test.py +++ b/tests/pallas/tpu_gmm_test.py @@ -24,12 +24,8 @@ import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() @@ -37,327 +33,326 @@ partial = functools.partial -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=10, - print_blob=True, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=10, + print_blob=True, +) +hp.settings.load_profile("deterministic") + +def seed_strategy() -> hps.SearchStrategy[int]: + return hps.integers(min_value=0, max_value=4) + +@hps.composite +def group_strategy( + draw: hps.DrawFn, + max_groups: int = 32, + max_stride: int = 32, + min_groups: int = 1, +) -> tuple[int, int]: + assert max_stride <= max_groups + + # Sample the number of groups owned by each shard. + group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) + + # Sample the number of groups as a multiple of the stride to ensure that we + # have an equal number of groups per shard. Round down s.t. num_groups <= + # max_groups. + num_groups = group_stride * draw( + hps.integers(min_value=min_groups, max_value=max_groups // group_stride) ) - hp.settings.load_profile("deterministic") - - def seed_strategy() -> hps.SearchStrategy[int]: - return hps.integers(min_value=0, max_value=4) - - @hps.composite - def group_strategy( - draw: hps.DrawFn, - max_groups: int = 32, - max_stride: int = 32, - min_groups: int = 1, - ) -> tuple[int, int]: - assert max_stride <= max_groups - - # Sample the number of groups owned by each shard. - group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) - - # Sample the number of groups as a multiple of the stride to ensure that we - # have an equal number of groups per shard. Round down s.t. num_groups <= - # max_groups. 
- num_groups = group_stride * draw( - hps.integers(min_value=min_groups, max_value=max_groups // group_stride) - ) - return num_groups, group_stride - - @hps.composite - def group_sizes_strategy( - draw: hps.DrawFn, m: int, num_groups: int - ) -> jnp.ndarray: - # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer - # sample with replacement so that it's possible to get zero-sized groups. Get - # 'num_groups - 1' run ends. The final group will end at 'm'. - ends_no_final = np.sort( - np.array( - [ - draw(hps.integers(min_value=0, max_value=m)) - for _ in range(num_groups - 1) - ], - dtype=np.int32, - ), + return num_groups, group_stride + +@hps.composite +def group_sizes_strategy( + draw: hps.DrawFn, m: int, num_groups: int +) -> jnp.ndarray: + # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer + # sample with replacement so that it's possible to get zero-sized groups. Get + # 'num_groups - 1' run ends. The final group will end at 'm'. + ends_no_final = np.sort( + np.array( + [ + draw(hps.integers(min_value=0, max_value=m)) + for _ in range(num_groups - 1) + ], + dtype=np.int32, + ), + ) + ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) + + # Calculate the run starts by shifting ends 1 to the right. The first run + # starts at zero. + starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) + return jnp.array(ends - starts, dtype=jnp.int32) + +GROUPED_MATMUL_TESTS = ( + (128, 128, 128), # Small + (512, 2048, 256), # Big + (128, 8, 16), # Test partial tiles. +) + +def random_dense( + shape: tuple[int, ...], + key: jax.Array, + dtype: jnp.dtype, + limit: int | None = None, +) -> jnp.ndarray: + if limit is None: + limit = 1 / np.prod(shape) + x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type + return x.astype(jnp.bfloat16).astype(dtype) + +def dot( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + transpose_lhs: bool = False, + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + lhs = jnp.transpose(lhs) if transpose_lhs else lhs + rhs = jnp.transpose(rhs) if transpose_rhs else rhs + return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) + +def reference_gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = dot( + lhs[start : start + size, :], + rhs[i, :, :], + preferred_element_type=preferred_element_type, ) - ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) - # Calculate the run starts by shifting ends 1 to the right. The first run - # starts at zero. 
- starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) - return jnp.array(ends - starts, dtype=jnp.int32) + out.append(result) + start += group_sizes[i] + return jnp.concatenate(out, axis=0) + +def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: + dtypes = [jnp.float32, jnp.bfloat16] + + result = [] + for x in xs: + for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): + result.append(x + dtypes_tuple) + return tuple(result) + +def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: + flags = [False, True] + result = [] + for x in xs: + for flag in flags: + result.append(x + (flag,)) + return tuple(result) + +def tolerances( + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype +) -> tuple[float, float]: + if ( + lhs_dtype == jnp.bfloat16 + or rhs_dtype == jnp.bfloat16 + or out_dtype == jnp.bfloat16 + ): + return 1e-3, 1e-2 # atol, rtol + return 1e-3, 1e-5 # atol, rtol + +# TODO(tgale): Fix errors with strict dtype promotion. +@jtu.with_config(jax_numpy_dtype_promotion="standard") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class GroupedMatmulTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test requires TPU device.") + + super().setUp() + self.key = jax.random.PRNGKey(1234) + + def assert_allclose( + self, + out: jnp.ndarray, + expected_out: jnp.ndarray, + *, + atol: float = 1e-5, + rtol: float = 1e-5, + ): + self.assertEqual(out.dtype, expected_out.dtype) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=atol, + rtol=rtol, + ) - GROUPED_MATMUL_TESTS = ( - (128, 128, 128), # Small - (512, 2048, 256), # Big - (128, 8, 16), # Test partial tiles. - ) + def gmm_test( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + interpret: bool = False, + ): + seed = data.draw(seed_strategy()) + num_groups, _ = data.draw(group_strategy(max_stride=1)) + lhs_dtype, rhs_dtype, out_dtype = [ + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ] + transpose_rhs = data.draw(hps.booleans()) + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, vjpfun = jax.vjp( + partial( + mblx.gmm, + preferred_element_type=out_dtype, + transpose_rhs=transpose_rhs, + interpret=interpret, + ), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) - def random_dense( - shape: tuple[int, ...], - key: jax.Array, - dtype: jnp.dtype, - limit: int | None = None, - ) -> jnp.ndarray: - if limit is None: - limit = 1 / np.prod(shape) - x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type - return x.astype(jnp.bfloat16).astype(dtype) - - def dot( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - transpose_lhs: bool = False, - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - lhs = jnp.transpose(lhs) if transpose_lhs else lhs - rhs = jnp.transpose(rhs) if transpose_rhs else rhs - return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) - - def reference_gmm( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - - start = 0 - out = [] - for i, size 
in enumerate(group_sizes): - result = dot( - lhs[start : start + size, :], - rhs[i, :, :], - preferred_element_type=preferred_element_type, + def reference_fn(lhs, rhs, group_sizes, preferred_element_type): + rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs + return reference_gmm( + lhs, rhs, group_sizes, preferred_element_type=preferred_element_type ) - out.append(result) - start += group_sizes[i] - return jnp.concatenate(out, axis=0) - - def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: - dtypes = [jnp.float32, jnp.bfloat16] - - result = [] - for x in xs: - for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): - result.append(x + dtypes_tuple) - return tuple(result) - - def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: - flags = [False, True] - result = [] - for x in xs: - for flag in flags: - result.append(x + (flag,)) - return tuple(result) - - def tolerances( - lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype - ) -> tuple[float, float]: - if ( - lhs_dtype == jnp.bfloat16 - or rhs_dtype == jnp.bfloat16 - or out_dtype == jnp.bfloat16 - ): - return 1e-3, 1e-2 # atol, rtol - return 1e-3, 1e-5 # atol, rtol - - # TODO(tgale): Fix errors with strict dtype promotion. - @jtu.with_config(jax_numpy_dtype_promotion="standard") - @jtu.thread_unsafe_test_class() # hypothesis is not thread safe - class GroupedMatmulTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test requires TPU device.") - - super().setUp() - self.key = jax.random.PRNGKey(1234) - - def assert_allclose( - self, - out: jnp.ndarray, - expected_out: jnp.ndarray, - *, - atol: float = 1e-5, - rtol: float = 1e-5, - ): - self.assertEqual(out.dtype, expected_out.dtype) - np.testing.assert_allclose( - out.astype(jnp.float32), - expected_out.astype(jnp.float32), - atol=atol, - rtol=rtol, - ) + expected_out, reference_vjpfun = jax.vjp( + partial(reference_fn, preferred_element_type=out_dtype), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + grad_lhs, grad_rhs, *_ = vjpfun(cotangent) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.gmm_test(m, k, n, data) + + # NOTE: Run fewer tests with interpret mode. We just want to sanity check that + # changes do not break running these kernels with interpret=True. 
+ @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) + @hp.given(hps.data()) + def test_gmm_interpret( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.skipTest("interpret mode with dynamic grids is unsupported") + self.gmm_test( + m, + k, + n, + data=data, + interpret=True, + ) - def gmm_test( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - interpret: bool = False, - ): - seed = data.draw(seed_strategy()) - num_groups, _ = data.draw(group_strategy(max_stride=1)) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - transpose_rhs = data.draw(hps.booleans()) - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, vjpfun = jax.vjp( - partial( - mblx.gmm, - preferred_element_type=out_dtype, - transpose_rhs=transpose_rhs, - interpret=interpret, + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm_sharded_groups( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + seed = data.draw(seed_strategy()) + num_groups, group_stride = data.draw(group_strategy()) + lhs_dtype, rhs_dtype, out_dtype = [ + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ] + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, shard_vjpfun = jax.vjp( + partial(mblx.gmm, preferred_element_type=out_dtype), + lhs, + rhs[0:group_stride], + group_sizes, + ) + vjpfuns = [shard_vjpfun] + for group_offset in range(group_stride, num_groups, group_stride): + out, shard_vjpfun = jax.vjp( + lambda lhs, rhs, group_sizes, out: mblx.gmm( + lhs, + rhs, + group_sizes, + out_dtype, + group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop + existing_out=out, ), lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, - group_sizes, - ) - - def reference_fn(lhs, rhs, group_sizes, preferred_element_type): - rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs - return reference_gmm( - lhs, rhs, group_sizes, preferred_element_type=preferred_element_type - ) - - expected_out, reference_vjpfun = jax.vjp( - partial(reference_fn, preferred_element_type=out_dtype), - lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, + rhs[group_offset : group_offset + group_stride], group_sizes, + out, ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - grad_lhs, grad_rhs, *_ = vjpfun(cotangent) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) - - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.gmm_test(m, k, n, data) - - # NOTE: Run 
fewer tests with interpret mode. We just want to sanity check that - # changes do not break running these kernels with interpret=True. - @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) - @hp.given(hps.data()) - def test_gmm_interpret( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.skipTest("interpret mode with dynamic grids is unsupported") - self.gmm_test( - m, - k, - n, - data=data, - interpret=True, - ) + vjpfuns.append(shard_vjpfun) - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm_sharded_groups( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], + expected_out, reference_vjpfun = jax.vjp( + partial(reference_gmm, preferred_element_type=out_dtype), + lhs, + rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) + grad_lhs = shard_grad_lhs + grad_rhs = [shard_grad_rhs] + for i, group_offset in enumerate( + range(group_stride, num_groups, group_stride) ): - seed = data.draw(seed_strategy()) - num_groups, group_stride = data.draw(group_strategy()) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, shard_vjpfun = jax.vjp( - partial(mblx.gmm, preferred_element_type=out_dtype), - lhs, - rhs[0:group_stride], - group_sizes, - ) - vjpfuns = [shard_vjpfun] - for group_offset in range(group_stride, num_groups, group_stride): - out, shard_vjpfun = jax.vjp( - lambda lhs, rhs, group_sizes, out: mblx.gmm( - lhs, - rhs, - group_sizes, - out_dtype, - group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop - existing_out=out, - ), - lhs, - rhs[group_offset : group_offset + group_stride], - group_sizes, - out, - ) - vjpfuns.append(shard_vjpfun) - - expected_out, reference_vjpfun = jax.vjp( - partial(reference_gmm, preferred_element_type=out_dtype), - lhs, - rhs, - group_sizes, - ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) - grad_lhs = shard_grad_lhs - grad_rhs = [shard_grad_rhs] - for i, group_offset in enumerate( - range(group_stride, num_groups, group_stride) - ): - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) - grad_lhs += shard_grad_lhs - grad_rhs.append(shard_grad_rhs) - grad_rhs = jnp.concatenate(grad_rhs, axis=0) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) + grad_lhs += shard_grad_lhs + grad_rhs.append(shard_grad_rhs) + grad_rhs = jnp.concatenate(grad_rhs, axis=0) 
+ expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) if __name__ == "__main__": diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 53e5462e20c2..2cb0cfff09e8 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -15,7 +15,6 @@ import functools import math import sys -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -32,13 +31,10 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.strategies as hps + jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=100) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 95014d9e9683..535eaaedf058 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -26,25 +26,21 @@ from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False - - -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - 'deterministic', - database=None, - derandomize=True, - deadline=None, - max_examples=200, - print_blob=True, - verbosity=hp.Verbosity.verbose, - ) - hp.settings.load_profile('deterministic') + +import hypothesis as hp +import hypothesis.strategies as hps + + +hp.settings.register_profile( + 'deterministic', + database=None, + derandomize=True, + deadline=None, + max_examples=200, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile('deterministic') jax.config.parse_flags_with_absl() @@ -1450,105 +1446,103 @@ def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5) -if CAN_USE_HYPOTHESIS: - - @partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) - def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): +@partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) +def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): - m, k = x.shape - _, n = y.shape + m, k = x.shape + _, n = y.shape - def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): - grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) - - def run(acc_scratch_ref): - pltpu.emit_pipeline( - partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), - in_specs=[ - pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), - pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), - ], - out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), - grid=grid, - core_axis=0, - dimension_semantics=( - pltpu.PARALLEL, - pltpu.PARALLEL, - pltpu.ARBITRARY, - ), - )(x_hbm_ref, y_hbm_ref, o_hbm_ref) - - accum_dtype = ( - jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 - ) - pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) - num_cores = jax.devices()[0].num_cores - return pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), - ], - 
out_specs=pl.BlockSpec(memory_space=pltpu.ANY), - grid=(num_cores,), - )(x, y) - - @jtu.thread_unsafe_test_class() # hypothesis is not thread safe - class PaddedPipelineEmitterTest(parameterized.TestCase): + def run(acc_scratch_ref): + pltpu.emit_pipeline( + partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), + in_specs=[ + pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), + pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), + grid=grid, + core_axis=0, + dimension_semantics=( + pltpu.PARALLEL, + pltpu.PARALLEL, + pltpu.ARBITRARY, + ), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + + accum_dtype = ( + jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 + ) + pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + + num_cores = jax.devices()[0].num_cores + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + )(x, y) + +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class PaddedPipelineEmitterTest(parameterized.TestCase): - def setUp(self): - super().setUp() - if not jtu.is_device_tpu_at_least(4): - self.skipTest('Only TPU v4+ allowed.') + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only TPU v4+ allowed.') - @parameterized.named_parameters( - ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') - ) - @hp.given( - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.sampled_from([8, 16, 32, 128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.integers(0, 4), - ) - def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): - if dtype == 'int8' and jtu.is_device_tpu_at_least(6): - self.skipTest('Not implemented for TPU v6.') - - def align_up_to(x, y): - return (x + y - 1) // y * y - - hp.assume(bm <= m) - hp.assume(bn <= n) - hp.assume(bk <= k) - if dtype == 'bfloat16': - hp.assume(bm >= 16) - if dtype == 'int8': - if not jtu.is_device_tpu_at_least(5): - self.skipTest('Only TPU v5+ allowed for int8.') - hp.assume(bm >= 32) - # TODO(apaszke): Relax DMA restrictions and remove this. 
- packing = 4 // jnp.dtype(dtype).itemsize - if packing != 1: - m = align_up_to(m, 8 * packing) - k = align_up_to(k, 8 * packing) - k1, k2 = jax.random.split(jax.random.key(seed)) - x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) - y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) - - out = matmul(x, y, bm=bm, bk=bk, bn=bn) - expected = x @ y - atol = rtol = 2.3e-5 - if dtype == 'bfloat16': - out = out.astype('float32') - expected = expected.astype('float32') - atol = rtol = 1e-2 - np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + @parameterized.named_parameters( + ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') + ) + @hp.given( + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.sampled_from([8, 16, 32, 128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.integers(0, 4), + ) + def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): + if dtype == 'int8' and jtu.is_device_tpu_at_least(6): + self.skipTest('Not implemented for TPU v6.') + + def align_up_to(x, y): + return (x + y - 1) // y * y + + hp.assume(bm <= m) + hp.assume(bn <= n) + hp.assume(bk <= k) + if dtype == 'bfloat16': + hp.assume(bm >= 16) + if dtype == 'int8': + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only TPU v5+ allowed for int8.') + hp.assume(bm >= 32) + # TODO(apaszke): Relax DMA restrictions and remove this. + packing = 4 // jnp.dtype(dtype).itemsize + if packing != 1: + m = align_up_to(m, 8 * packing) + k = align_up_to(k, 8 * packing) + k1, k2 = jax.random.split(jax.random.key(seed)) + x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) + y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) + + out = matmul(x, y, bm=bm, bk=bk, bn=bn) + expected = x @ y + atol = rtol = 2.3e-5 + if dtype == 'bfloat16': + out = out.astype('float32') + expected = expected.astype('float32') + atol = rtol = 1e-2 + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) if __name__ == '__main__': diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index 8a73f221bb6d..a494a62745d1 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -32,11 +32,9 @@ import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") +import hypothesis as hp +import hypothesis.strategies as hps + jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=5) @@ -515,9 +513,9 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, 1)) mask = jnp.array(masks[0].get_mask()[:, :]) attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy(), - label="logit_cap") + label="logit_cap") attn_ref = partial(splash.attention_reference, mask, - attn_logits_soft_cap=attn_logits_soft_cap) + attn_logits_soft_cap=attn_logits_soft_cap) attn_custom = partial(splash.attention_reference_custom, mask, attn_logits_soft_cap=attn_logits_soft_cap) attn_custom_vanilla = partial(splash.attention_reference_custom, mask, @@ -525,7 +523,7 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): attn_logits_soft_cap=attn_logits_soft_cap) o_ref, attn_vjp_ref = jax.vjp(attn_ref, q, k, v, segment_ids) q32, k32, v32 = jax.tree.map(lambda x: x.astype(jnp.float32), - (q, k, v)) + (q, 
k, v)) o_custom = attn_custom(q32, k32, v32, segment_ids) _, attn_vjp = jax.vjp(attn_custom, q32, k32, v32, segment_ids) _, attn_vanilla_vjp = jax.vjp(attn_custom_vanilla, q32, k32, v32, @@ -624,7 +622,7 @@ def test_splash_attention_bwd( mask = jnp.array(mask[:, :, :]) block_sizes = data.draw( block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True, - use_fused_bwd_kernel=use_fused_bwd_kernel) + use_fused_bwd_kernel=use_fused_bwd_kernel) ) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask, backward_impl="custom") diff --git a/tests/state_test.py b/tests/state_test.py index 65f6f0427a00..d9bf66eb3f50 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -37,13 +37,9 @@ import jax.numpy as jnp from jax._src.lax.control_flow import for_loop -try: - import hypothesis as hp - import hypothesis.extra.numpy as hnp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.extra.numpy as hnp +import hypothesis.strategies as hps from jax._src.state.discharge import (run_state, run_state_reference, discharge_state) @@ -798,305 +794,303 @@ def body(i, st): self.assertLen(jaxpr.outvars, 3) -if CAN_USE_HYPOTHESIS: - - def index_arrays(size, idx_shape): - valid_idx = hps.integers(min_value=-size, max_value=size - 1) - return hnp.arrays(np.int32, idx_shape, elements=valid_idx) - - Shape = tuple[int, ...] - - class IndexParam(NamedTuple): - ref_aval: shaped_array_ref - ref_shape: Shape - indexed_dims: list[bool] - idx_avals: tuple[core.ShapedArray, ...] - idx_shape: Shape - slice_aval: core.ShapedArray - slice_shape: Shape - - @hps.composite - def index_params(draw): - ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') - indexed_dims = draw(hps.lists(hps.booleans(), - min_size=len(ref_shape), - max_size=len(ref_shape))) - idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) - if not any(indexed_dims): - slice_shape = ref_shape +def index_arrays(size, idx_shape): + valid_idx = hps.integers(min_value=-size, max_value=size - 1) + return hnp.arrays(np.int32, idx_shape, elements=valid_idx) + +Shape = tuple[int, ...] + +class IndexParam(NamedTuple): + ref_aval: shaped_array_ref + ref_shape: Shape + indexed_dims: list[bool] + idx_avals: tuple[core.ShapedArray, ...] 
+ idx_shape: Shape + slice_aval: core.ShapedArray + slice_shape: Shape + +@hps.composite +def index_params(draw): + ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') + indexed_dims = draw(hps.lists(hps.booleans(), + min_size=len(ref_shape), + max_size=len(ref_shape))) + idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) + if not any(indexed_dims): + slice_shape = ref_shape + else: + sliced_shape = tuple(s for s, b in zip(ref_shape, indexed_dims) if not b) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(indexed_dims)[0]) == 1) + ) + if not int_indexers_contiguous: + slice_shape = (*idx_shape, *sliced_shape) else: - sliced_shape = tuple(s for s, b in zip(ref_shape, indexed_dims) if not b) - int_indexers_contiguous = bool( - np.all(np.diff(np.where(indexed_dims)[0]) == 1) + insert_pos = indexed_dims.index(True) + slice_shape = ( + *sliced_shape[:insert_pos], + *idx_shape, + *sliced_shape[insert_pos:], ) - if not int_indexers_contiguous: - slice_shape = (*idx_shape, *sliced_shape) - else: - insert_pos = indexed_dims.index(True) - slice_shape = ( - *sliced_shape[:insert_pos], - *idx_shape, - *sliced_shape[insert_pos:], - ) - ref_aval = shaped_array_ref(ref_shape, np.float32) - idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in - range(sum(indexed_dims))) - slice_aval = core.ShapedArray(slice_shape, np.float32) - return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, - slice_aval, slice_shape) - - class VmappableIndexParam(NamedTuple): - index_param: IndexParam - ref_bdim: int | None - non_slice_idx_bdims: tuple[int | None, ...] - slice_bdim: int - bat_ref_aval: shaped_array_ref - bat_ref_shape: Shape - bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] - bat_non_slice_idx_shapes: tuple[Shape, ...] - bat_slice_aval: core.ShapedArray - bat_slice_shape: Shape - - def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, - val: Any) -> tuple[Any, ...]: - if idx is None: - return t - return tuple_insert(t, idx, val) - - @hps.composite - def vmappable_index_params(draw, *, op_type: str): - axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') - index_param: IndexParam = draw(index_params()) - non_slice_idx_bdims = tuple( - draw(hps.one_of( - hps.none(), - hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) - for b in index_param.indexed_dims if b) - bat_non_slice_idx_shapes = tuple( - maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) - for idx_bdim in non_slice_idx_bdims) - if op_type == "swap": - # In a swap, the ref *must* be batched - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): - # If it's a swap, if indices are batched, val must be batched. 
- slice_bdim = draw(hps.integers( - min_value=0, max_value=len(index_param.slice_shape))) - else: - slice_bdim = draw(hps.one_of(hps.none(), hps.integers( - min_value=0, max_value=len(index_param.slice_shape)))) - elif op_type == "get": - # In a get, the indices must be batched or ref is batched - if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - else: - ref_bdim = draw(hps.one_of(hps.none(), - hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + ref_aval = shaped_array_ref(ref_shape, np.float32) + idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in + range(sum(indexed_dims))) + slice_aval = core.ShapedArray(slice_shape, np.float32) + return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, + slice_aval, slice_shape) + +class VmappableIndexParam(NamedTuple): + index_param: IndexParam + ref_bdim: int | None + non_slice_idx_bdims: tuple[int | None, ...] + slice_bdim: int + bat_ref_aval: shaped_array_ref + bat_ref_shape: Shape + bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] + bat_non_slice_idx_shapes: tuple[Shape, ...] + bat_slice_aval: core.ShapedArray + bat_slice_shape: Shape + +def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, + val: Any) -> tuple[Any, ...]: + if idx is None: + return t + return tuple_insert(t, idx, val) + +@hps.composite +def vmappable_index_params(draw, *, op_type: str): + axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') + index_param: IndexParam = draw(index_params()) + non_slice_idx_bdims = tuple( + draw(hps.one_of( + hps.none(), + hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) + for b in index_param.indexed_dims if b) + bat_non_slice_idx_shapes = tuple( + maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) + for idx_bdim in non_slice_idx_bdims) + if op_type == "swap": + # In a swap, the ref *must* be batched + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): + # If it's a swap, if indices are batched, val must be batched. 
slice_bdim = draw(hps.integers( min_value=0, max_value=len(index_param.slice_shape))) + else: + slice_bdim = draw(hps.one_of(hps.none(), hps.integers( + min_value=0, max_value=len(index_param.slice_shape)))) + elif op_type == "get": + # In a get, the indices must be batched or ref is batched + if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + else: + ref_bdim = draw(hps.one_of(hps.none(), + hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + slice_bdim = draw(hps.integers( + min_value=0, max_value=len(index_param.slice_shape))) + + bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) + bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) + bat_non_slice_idx_avals = tuple( + core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) + bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) + bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) + return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, + slice_bdim, bat_ref_aval, bat_ref_shape, + bat_non_slice_idx_avals, bat_non_slice_idx_shapes, + bat_slice_aval, bat_slice_shape) + +class GetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_idxs: tuple[np.ndarray, ...] + +@hps.composite +def get_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw( + vmappable_index_params(op_type="get")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) + +class SetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_val: np.ndarray + bat_idxs: tuple[np.ndarray, ...] 
+ +@hps.composite +def set_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( + op_type="swap")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) + return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) + +Indexer = tuple[Union[int, slice, np.ndarray]] + +def _unpack_idx(idx: Indexer + ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: + indexed_dims = [type(i) != slice for i in idx] + non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] + return non_slice_idx, indexed_dims + +def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], + indexed_dims: Sequence[bool]) -> Indexer: + idx_ = iter(non_slice_idx) + idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) + assert next(idx_, None) is None + return idx + +@jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe +class StateHypothesisTest(jtu.JaxTestCase): + + @hp.given(get_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_get_vmap(self, get_vmap_param: GetVmapParams): + + indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + return [ref_get(ref, idx)] + ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = get_vmap_param.vmap_index_param.ref_bdim + idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims + out_bdim = get_vmap_param.vmap_index_param.slice_bdim + non_slice_idx = get_vmap_param.bat_idxs + idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals + ref = get_vmap_param.bat_ref + + f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, *idx_bdims), + out_axes=[out_bdim, ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) + + + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_set_vmap(self, set_vmap_param: SetVmapParams): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Scatter is nondeterministic on GPU") + indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + 
ref_set(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) - bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) - bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) - bat_non_slice_idx_avals = tuple( - core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) - bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) - bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) - return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, - slice_bdim, bat_ref_aval, bat_ref_shape, - bat_non_slice_idx_avals, bat_non_slice_idx_shapes, - bat_slice_aval, bat_slice_shape) - - class GetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_idxs: tuple[np.ndarray, ...] - - @hps.composite - def get_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw( - vmappable_index_params(op_type="get")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) - - class SetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_val: np.ndarray - bat_idxs: tuple[np.ndarray, ...] 
- - @hps.composite - def set_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( - op_type="swap")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) - return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) - - Indexer = tuple[Union[int, slice, np.ndarray]] - - def _unpack_idx(idx: Indexer - ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: - indexed_dims = [type(i) != slice for i in idx] - non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] - return non_slice_idx, indexed_dims - - def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], - indexed_dims: Sequence[bool]) -> Indexer: - idx_ = iter(non_slice_idx) - idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) - assert next(idx_, None) is None - return idx - - @jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe - class StateHypothesisTest(jtu.JaxTestCase): - - @hp.given(get_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_get_vmap(self, get_vmap_param: GetVmapParams): - - indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - return [ref_get(ref, idx)] - ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = get_vmap_param.vmap_index_param.ref_bdim - idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims - out_bdim = get_vmap_param.vmap_index_param.slice_bdim - non_slice_idx = get_vmap_param.bat_idxs - idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals - ref = get_vmap_param.bat_ref - - f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, *idx_bdims), - out_axes=[out_bdim, ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_set_vmap(self, set_vmap_param: SetVmapParams): - if jtu.test_device_matches(["gpu"]): - self.skipTest("Scatter is nondeterministic on GPU") - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, 
indexed_dims) - ref_set(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): - - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - ref_addupdate(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - 
out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) + + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): + + indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + ref_addupdate(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) class StateControlFlowTest(jtu.JaxTestCase): @@ -1634,220 +1628,218 @@ def _body(ref): jtu.check_grads(f, (0.5,), order=3) -if CAN_USE_HYPOTHESIS: - - class FuncSpec(NamedTuple): - fun: Callable[..., Any] - name: str - min_rank: int = 0 - max_rank: int = 4 - min_dim: int = 0 - max_dim: int = 4 - - def call(self, *args): - return run_state(self.fun)(*args) - - def ref(self, *args): - return run_state_reference(self.fun)(*args) - - def sin_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.sin(x_ref[...]) - - sin_spec = FuncSpec(sin_stateful, "sin") - - def cos_stateful(refs): +class FuncSpec(NamedTuple): + fun: Callable[..., Any] + name: str + min_rank: int = 0 + max_rank: int = 4 + min_dim: int = 0 + max_dim: int = 4 + + def call(self, *args): + return run_state(self.fun)(*args) + + def ref(self, *args): + return run_state_reference(self.fun)(*args) + +def sin_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = jnp.sin(x_ref[...]) + +sin_spec = FuncSpec(sin_stateful, "sin") + +def cos_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = jnp.cos(x_ref[...]) + +cos_spec = FuncSpec(cos_stateful, "cos") + +def mul2_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = x_ref[...] + y_ref[...] = y_ref[...] + x_ref[...] 
+ +mul2_spec = FuncSpec(mul2_stateful, "mul2") + +def mul2_stateful_with_constant(refs): + x_ref, y_ref = refs + y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] + +mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") + +def crazy_identity_stateful(refs): + x_ref, y_ref = refs + x = x_ref[...] + x_ref[...] = (x + x) / 2 + y_ref[...] = x_ref[...] + y = y_ref[...] + y_ref[...] = (y + y) / 2 + +crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") + +def func_spec(depth: int = 4): + raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, + mul2_constant_spec, crazy_identity_spec]) + if depth > 0: + return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), + compose_spec(depth - 1)]) + return raw_specs + +@hps.composite +def compose_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(*args): + f1.fun(*args) + f2.fun(*args) + return FuncSpec(wrapped_impl, + f"({f2.name} . {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@hps.composite +def nest_spec(draw, depth): + f = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - y_ref[...] = jnp.cos(x_ref[...]) - - cos_spec = FuncSpec(cos_stateful, "cos") - - def mul2_stateful(refs): + x, y = x_ref[...], y_ref[...] + x, y = run_state(f.fun)((x, y)) + x_ref[...], y_ref[...] = x, y + return FuncSpec(wrapped_impl, + f"nest({f.name})", + min_rank=f.min_rank, + max_rank=f.max_rank, + min_dim=f.min_dim, + max_dim=f.max_dim) + + +@hps.composite +def add_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - y_ref[...] = x_ref[...] - y_ref[...] = y_ref[...] + x_ref[...] - - mul2_spec = FuncSpec(mul2_stateful, "mul2") - - def mul2_stateful_with_constant(refs): - x_ref, y_ref = refs - y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] - - mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") - - def crazy_identity_stateful(refs): - x_ref, y_ref = refs - x = x_ref[...] - x_ref[...] = (x + x) / 2 - y_ref[...] = x_ref[...] - y = y_ref[...] - y_ref[...] = (y + y) / 2 - - crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") - - def func_spec(depth: int = 4): - raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, - mul2_constant_spec, crazy_identity_spec]) - if depth > 0: - return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), - compose_spec(depth - 1)]) - return raw_specs - - @hps.composite - def compose_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(*args): - f1.fun(*args) - f2.fun(*args) - return FuncSpec(wrapped_impl, - f"({f2.name} . {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @hps.composite - def nest_spec(draw, depth): - f = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x, y = run_state(f.fun)((x, y)) - x_ref[...], y_ref[...] 
= x, y - return FuncSpec(wrapped_impl, - f"nest({f.name})", - min_rank=f.min_rank, - max_rank=f.max_rank, - min_dim=f.min_dim, - max_dim=f.max_dim) - - - @hps.composite - def add_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x1, y1 = run_state(f1.fun)((x, y)) - x2, y2 = run_state(f2.fun)((x, y)) - x_ref[...], y_ref[...] = x1 + x2, y1 + y2 - return FuncSpec(wrapped_impl, - f"({f2.name} + {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @jtu.thread_unsafe_test_class() # because of hypothesis - class RunStateHypothesisTest(jtu.JaxTestCase): - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_jvp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - t = random.normal(k2, x.shape) - y, y_t = jax.jvp(impl, (x,), (t,)) - y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) - self.assertAllClose(y, y_ref) - self.assertAllClose(y_t, y_ref_t) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_linearize(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_vjp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - key, k1, k2 = random.split(random.PRNGKey(0), 3) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - - # First order - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t)) - - y, impl_vjp = jax.vjp(impl, x) - y_ref, ref_vjp = jax.vjp(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(jax.random.clone(k2), x.shape) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp(t), ref_vjp(t)) - - if jtu.SKIP_SLOW_TESTS.value: - # Skip second order tests if JAX_SKIP_SLOW_TESTS=true - return - - # Second order - key, k1, k2 = random.split(key, 3) - t2 = random.normal(k2, t.shape) - - (x,), impl_lin2 = 
jax.linearize(impl_vjp, t2) - (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2) - self.assertAllClose(x, x_ref) - y2 = random.normal(k1, y.shape) - self.assertAllClose(impl_lin2(y2), ref_lin2(y2)) - - (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) - (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) - self.assertAllClose(x, x_ref) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) + x, y = x_ref[...], y_ref[...] + x1, y1 = run_state(f1.fun)((x, y)) + x2, y2 = run_state(f2.fun)((x, y)) + x_ref[...], y_ref[...] = x1 + x2, y1 + y2 + return FuncSpec(wrapped_impl, + f"({f2.name} + {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@jtu.thread_unsafe_test_class() # because of hypothesis +class RunStateHypothesisTest(jtu.JaxTestCase): + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_jvp(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + k1, k2 = random.split(random.PRNGKey(0)) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + t = random.normal(k2, x.shape) + y, y_t = jax.jvp(impl, (x,), (t,)) + y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) + self.assertAllClose(y, y_ref) + self.assertAllClose(y_t, y_ref_t) + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_linearize(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + + k1, k2 = random.split(random.PRNGKey(0)) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + y, impl_lin = jax.linearize(impl, x) + y_ref, ref_lin = jax.linearize(ref, x) + self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2) + t = random.normal(k2, x.shape) + self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2) + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_vjp(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + + key, k1, k2 = random.split(random.PRNGKey(0), 3) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + + # First order + y, impl_lin = jax.linearize(impl, x) + y_ref, ref_lin = jax.linearize(ref, x) + self.assertAllClose(y, y_ref) + t = random.normal(k2, x.shape) + self.assertAllClose(impl_lin(t), ref_lin(t)) + + y, impl_vjp = jax.vjp(impl, x) + y_ref, ref_vjp = jax.vjp(ref, x) + self.assertAllClose(y, y_ref) + t = random.normal(jax.random.clone(k2), x.shape) + y2 = random.normal(jax.random.clone(k1), y.shape) + self.assertAllClose(impl_vjp(t), ref_vjp(t)) + + if jtu.SKIP_SLOW_TESTS.value: + # Skip second order tests if JAX_SKIP_SLOW_TESTS=true + return + + # 
Second order + key, k1, k2 = random.split(key, 3) + t2 = random.normal(k2, t.shape) + + (x,), impl_lin2 = jax.linearize(impl_vjp, t2) + (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2) + self.assertAllClose(x, x_ref) + y2 = random.normal(k1, y.shape) + self.assertAllClose(impl_lin2(y2), ref_lin2(y2)) + + (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) + (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) + self.assertAllClose(x, x_ref) + y2 = random.normal(jax.random.clone(k1), y.shape) + self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 57e33bcbcdd20dcc3675edbd921ff8a104ff27ae Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 14 Apr 2025 17:19:50 -0700 Subject: [PATCH 0619/1769] Deprecate the contents of jax.util. PiperOrigin-RevId: 747629222 --- CHANGELOG.md | 4 + benchmarks/api_benchmark.py | 4 +- jax/_src/pallas/mosaic/interpret.py | 8 +- jax/experimental/colocated_python/api.py | 2 +- jax/experimental/colocated_python/func.py | 2 +- jax/experimental/colocated_python/obj.py | 2 +- .../colocated_python/serialization.py | 2 +- jax/experimental/sparse/ad.py | 2 +- jax/experimental/sparse/bcoo.py | 2 +- jax/experimental/sparse/bcsr.py | 2 +- jax/experimental/sparse/random.py | 2 +- jax/experimental/sparse/test_util.py | 2 +- jax/experimental/sparse/transform.py | 2 +- jax/experimental/sparse/util.py | 2 +- jax/util.py | 148 ++++++++++++++++-- tests/mosaic/gpu_test.py | 2 +- tests/sparse_bcoo_bcsr_test.py | 2 +- tests/sparse_test.py | 2 +- 18 files changed, 156 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fb64ef09bb1..3374071f190a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a callable. + * The following internal APIs in `jax.util` are deprecated: + `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, + `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, + `toposort`, `unzip2`, `wrap_name`, and `wraps`. * `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX `Array` directly to the `from_dlpack` function of another framework. 
If you need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index cabebce2227c..a62b78d66ced 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -847,7 +847,7 @@ def safe_map(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) def f(*args): return tuple(args) while state: - jax.util.safe_map(f, *args) + jax._src.util.safe_map(f, *args) @google_benchmark.register @google_benchmark.option.arg_names(['arg_lengths', 'num_args']) @@ -855,7 +855,7 @@ def f(*args): return tuple(args) def safe_zip(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) while state: - jax.util.safe_zip(*args) + jax._src.util.safe_zip(*args) @google_benchmark.register diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 13b321424f81..5ac6bb6564ba 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1036,7 +1036,7 @@ def write(var, value): value = Placeholder(value.shape, value.dtype) env[var] = value - jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) + jax._src.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) # Get the device ID. axis_sizes = jax_core.get_axis_env().axis_sizes @@ -1062,7 +1062,7 @@ def write(var, value): # not need to do any reads if `interpret_params.skip_floating_point_ops` # is True. If this is the case, we want to avoid materializing the read # array into the jaxpr when this function is traced. - deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars) + deferred_invals = functools.partial(jax._src.util.safe_map, read, eqn.invars) if prim is primitives.load_p: (ref, transforms, mask, _) = jax.tree.unflatten( @@ -1337,9 +1337,9 @@ def f(*args, jaxpr): out = prim.bind(*subfuns, *deferred_invals(), **bind_params) out = out if prim.multiple_results else [out] - jax.util.safe_map(write, eqn.outvars, out) + jax._src.util.safe_map(write, eqn.outvars, out) - return jax.util.safe_map(read, jaxpr.outvars) + return jax._src.util.safe_map(read, jaxpr.outvars) def _compute_start_indices( block_mapping, loop_idx, *args, compiler_params, interpret_params): diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index e72e04c6ded9..81db9b965e7c 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -33,7 +33,7 @@ def colocated_cpu_devices( return _colocated_cpu_devices_cached(devices) -@jax.util.cache(max_size=1024, trace_context_in_key=False) +@jax._src.util.cache(max_size=1024, trace_context_in_key=False) def _colocated_cpu_devices_cached( devices: tuple[jax.Device, ...], ) -> Sequence[jax.Device]: diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index effca1fe77b7..b7188d9da7ad 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -279,7 +279,7 @@ def _make_async_execution_fun( ) -@jax.util.cache(max_size=None) +@jax._src.util.cache(max_size=None) def _get_specialized_func( info: FunctionInfo, specialization: Specialization ) -> Callable[..., Any]: diff --git a/jax/experimental/colocated_python/obj.py b/jax/experimental/colocated_python/obj.py index b1e7a0b1eade..d7d40e88f925 100644 --- a/jax/experimental/colocated_python/obj.py +++ b/jax/experimental/colocated_python/obj.py @@ -58,7 +58,7 @@ def pop_instance(self, uid: int) -> 
set[jax.Device]: SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry() -@jax.util.cache(max_size=4096) +@jax._src.util.cache(max_size=4096) def _update_instance_devices( uid: int, shardings: tuple[jax.sharding.Sharding, ...] ) -> None: diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index 1ca29ab12660..a8a62d78359f 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -35,7 +35,7 @@ DeviceList = xc.DeviceList -@jax.util.cache(max_size=None) +@jax._src.util.cache(max_size=None) def _get_cpu_device_map() -> dict[int, jax.Device]: """Returns a map from a device id to a matching device.""" cpu_device_map: dict[int, jax.Device] = {} diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index 018047e3d5e1..861ef5289cdd 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -22,7 +22,7 @@ from jax._src import core from jax import tree_util from jax._src.api_util import _ensure_index, _ensure_index_tuple -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.util import split_list, wraps from jax._src.traceback_util import api_boundary from jax.experimental.sparse._base import JAXSparse diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 42820fe73651..0365f93d551a 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -38,7 +38,7 @@ from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p from jax._src.interpreters import mlir import jax.numpy as jnp -from jax.util import safe_zip, unzip2, split_list +from jax._src.util import safe_zip, unzip2, split_list from jax._src import api_util from jax._src import config from jax._src import core diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index dc8be2237544..4b01f362bb83 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -33,7 +33,7 @@ from jax.experimental.sparse.util import ( nfold_vmap, _count_stored_elements, _csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape) -from jax.util import split_list, safe_zip +from jax._src.util import split_list, safe_zip from jax._src import api_util from jax._src import config diff --git a/jax/experimental/sparse/random.py b/jax/experimental/sparse/random.py index f90c2572d282..a9146b7746e0 100644 --- a/jax/experimental/sparse/random.py +++ b/jax/experimental/sparse/random.py @@ -18,7 +18,7 @@ from jax import dtypes from jax import vmap from jax import random -from jax.util import split_list +from jax._src.util import split_list import jax.numpy as jnp from jax.experimental import sparse diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 77c97513041c..63e035d2d1ac 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -29,7 +29,7 @@ from jax._src.typing import DTypeLike from jax.experimental import sparse import jax.numpy as jnp -from jax.util import safe_zip, split_list +from jax._src.util import safe_zip, split_list import numpy as np MATMUL_TOL = { diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index ce1d3f4af9d0..a16756d42c45 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -68,7 +68,7 @@ from jax._src.lib import pytree from jax._src.interpreters import partial_eval as pe from jax.tree_util 
import tree_flatten, tree_map, tree_unflatten -from jax.util import safe_map, safe_zip, split_list +from jax._src.util import safe_map, safe_zip, split_list from jax._src.lax.control_flow import _check_tree_and_avals from jax._src.numpy import indexing as jnp_indexing from jax.experimental import sparse diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 36e9a9c51664..7c6bfb1ec345 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -25,7 +25,7 @@ from jax._src import core from jax._src.api_util import flatten_axes import jax.numpy as jnp -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.lax.lax import _dot_general_shape_rule, DotDimensionNumbers from jax._src.typing import Array diff --git a/jax/util.py b/jax/util.py index 8071f77dffe2..b2c9df205206 100644 --- a/jax/util.py +++ b/jax/util.py @@ -15,19 +15,135 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.util import ( - HashableFunction as HashableFunction, - as_hashable_function as as_hashable_function, - cache as cache, - safe_map as safe_map, - safe_zip as safe_zip, - split_dict as split_dict, - split_list as split_list, - split_list_checked as split_list_checked, - split_merge as split_merge, - subvals as subvals, - toposort as toposort, - unzip2 as unzip2, - wrap_name as wrap_name, - wraps as wraps, -) +import jax._src.deprecations +import jax._src.util + + +_deprecations = { + "to_dlpack": ( + ( + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0. Please use the newer DLPack API based on" + " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" + " a JAX array directly to the `from_dlpack` function of another" + " framework without using `to_dlpack`." + ), + jax._src.dlpack.to_dlpack, + ), + "HashableFunction": ( + ( + "HashableFunction was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.util.HashableFunction, + ), + "as_hashable_function": ( + ( + "as_hashable_function was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0." + ), + jax._src.util.as_hashable_function, + ), + "cache": ( + "cache was deprecated in JAX v0.6.0 and will be removed in JAX v0.7.0.", + jax._src.util.cache, + ), + "safe_map": ( + ( + "safe_map was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.safe_map, + ), + "safe_zip": ( + ( + "safe_zip was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.safe_zip, + ), + "split_dict": ( + ( + "split_dict was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.split_dict, + ), + "split_list": ( + ( + "split_list was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.split_list, + ), + "split_list_checked": ( + ( + "split_list_checked was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0." + ), + jax._src.util.split_list_checked, + ), + "split_merge": ( + ( + "split_merge was deprecated in JAX v0.6.0 and will be removed in" + " JAX v0.7.0." + ), + jax._src.util.split_merge, + ), + "subvals": ( + ( + "subvals was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.subvals, + ), + "toposort": ( + ( + "toposort was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." 
+ ), + jax._src.util.toposort, + ), + "unzip2": ( + ( + "unzip2 was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.unzip2, + ), + "wrap_name": ( + ( + "wrap_name was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.wrap_name, + ), + "wraps": ( + "wraps was deprecated in JAX v0.6.0 and will be removed in JAX v0.7.0.", + jax._src.util.wraps, + ), +} + + +import typing as _typing + +if _typing.TYPE_CHECKING: + HashableFunction = jax._src.util.HashableFunction + as_hashable_function = jax._src.util.as_hashable_function + cache = jax._src.util.cache + safe_map = jax._src.util.safe_map + safe_zip = jax._src.util.safe_zip + split_dict = jax._src.util.split_dict + split_list = jax._src.util.split_list + split_list_checked = jax._src.util.split_list_checked + split_merge = jax._src.util.split_merge + subvals = jax._src.util.subvals + toposort = jax._src.util.toposort + unzip2 = jax._src.util.unzip2 + wrap_name = jax._src.util.wrap_name + wraps = jax._src.util.wraps +else: + __getattr__ = jax._src.deprecations.deprecation_getattr( + __name__, _deprecations + ) +del _typing diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index d8d57a06b573..9ffaff121849 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2552,7 +2552,7 @@ def set_in_transforms( in_transforms = [] smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable - for _, result_transforms in jax.util.safe_zip(smem_refs, transforms): + for _, result_transforms in jax._src.util.safe_zip(smem_refs, transforms): in_transforms.append( ir.ArrayAttr.get([t.attr() for t in result_transforms]) ) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index e839bacbe5fc..1224717570d1 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -36,7 +36,7 @@ from jax.experimental.sparse import util as sparse_util import jax.numpy as jnp import jax.random -from jax.util import split_list +from jax._src.util import split_list import numpy as np jax.config.parse_flags_with_absl() diff --git a/tests/sparse_test.py b/tests/sparse_test.py index eb8d70be1f05..219875d4b7d0 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -38,7 +38,7 @@ from jax._src import test_util as jtu from jax.interpreters import mlir import jax.numpy as jnp -from jax.util import split_list +from jax._src.util import split_list import numpy as np import scipy.sparse From 4fa3cd91d3df16fe969b0b530309e44357638636 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 14 Apr 2025 19:08:12 -0700 Subject: [PATCH 0620/1769] [Pallas/Fuser] Add basic closed over consts support to pull_block_spec PiperOrigin-RevId: 747657069 --- jax/_src/pallas/fuser/block_spec.py | 10 +++++--- tests/pallas/fuser_block_spec_test.py | 35 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 146191bab9b3..9524ce4ca4d2 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -239,9 +239,7 @@ def wrapped(*args, **kwargs): jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr( f, *args, **kwargs ) - # TODO(sharadmv): handle these consts better, they should correspond to - # scalar prefetch. 
- del consts, out_tree_ + del out_tree_ jaxpr_out_usages = [{Usage.REGULAR}] * len(jaxpr.outvars) block_specs_ = jax.tree.map( _unwrap_block_spec_scalar_prefetch, out_block_specs @@ -263,6 +261,7 @@ def wrapped(*args, **kwargs): ) kernel_fn = make_kernel_function( jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -408,6 +407,7 @@ def _get_in_block_spec(v, usage): def make_kernel_function( jaxpr: core.Jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -505,6 +505,8 @@ def read_env(atom): def write_env(var, val): env[var] = val + for const, constvar in zip(consts, jaxpr.constvars): + env[constvar] = const for invar, arg, usage in zip(jaxpr.invars, flat_args, invar_usages): if Usage.REGULAR in usage: env[invar] = arg @@ -1232,6 +1234,7 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs): ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, @@ -1289,6 +1292,7 @@ def _custom_jvp_call_eval_rule( ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index 1b3a215876ec..377901933b4e 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -769,6 +769,41 @@ def f(x): kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x ) + def test_pull_block_spec_handles_closed_over_constants(self): + x = jnp.ones((2, 512, 512)) + i = jnp.array(1) + + def f(): + return x[i] + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertLen(scalar_prefetch_values, 1) + + block_spec = pl.BlockSpec( + (None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l) + ) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + scalar_prefetch_values = jax.tree.map( + lambda x: x[None], scalar_prefetch_values + ) + fn = lambda x: kernel_fn((0, 0, 0, 0), scalar_prefetch_values, x) + new_values_type = (jax.ShapeDtypeStruct((1, 128, 128), jnp.float32),) + # Try pulling again + # This should not raise an error. + _ = block_spec_lib.pull_block_spec( + fn, + block_spec, + grid=(1,), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values_type) + class PushBlockSpecTest(parameterized.TestCase): From 0ed0fb7c5493e710380472b2c824b6c41f3603cf Mon Sep 17 00:00:00 2001 From: Mark Sandler Date: Mon, 14 Apr 2025 19:08:48 -0700 Subject: [PATCH 0621/1769] Adds a debugging message to assert, otherwise the error is pretty cryptic. PiperOrigin-RevId: 747657234 --- jax/_src/sharding_impls.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index e462ee2f0ba0..2b79e09c49e1 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -898,7 +898,11 @@ def parse_flatten_op_sharding( while dim_size > 1: axis = next(mesh_axis) axis_size = mesh_shape[axis] - assert dim_size % axis_size == 0 + if dim_size % axis_size != 0: + raise ValueError( + f'{shape=} is incompatible with {mesh_shape=}: ' + f'{dim_size=} is not divisible by {axis_size=}.' 
+ ) dim_size //= axis_size dim_partitions.append(axis) partitions.append(tuple(dim_partitions)) From 1926b99bfd3524f8c62521db454f69dcb942a1e5 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 14 Apr 2025 19:35:09 -0700 Subject: [PATCH 0622/1769] [pallas] Fix spelling of 'fusible'. PiperOrigin-RevId: 747663692 --- jax/BUILD | 2 +- jax/_src/pallas/fuser/BUILD | 16 ++-- jax/_src/pallas/fuser/__init__.py | 2 +- .../pallas/fuser/{fusable.py => fusible.py} | 18 ++--- .../{fusable_dtype.py => fusible_dtype.py} | 30 +++---- jax/_src/pallas/fuser/jaxpr_fusion.py | 40 +++++----- jax/experimental/pallas/fuser.py | 2 +- tests/pallas/BUILD | 4 +- tests/pallas/fusion_test.py | 20 ++--- ...mul_test.py => tpu_fusible_matmul_test.py} | 80 +++++++++---------- 10 files changed, 108 insertions(+), 106 deletions(-) rename jax/_src/pallas/fuser/{fusable.py => fusible.py} (86%) rename jax/_src/pallas/fuser/{fusable_dtype.py => fusible_dtype.py} (95%) rename tests/pallas/{tpu_fusable_matmul_test.py => tpu_fusible_matmul_test.py} (94%) diff --git a/jax/BUILD b/jax/BUILD index cb9a39efb6f4..ca0fb0268c46 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -721,7 +721,7 @@ pytype_strict_library( ":pallas", # build_cleaner: keep "//jax/_src/pallas/fuser:block_spec", "//jax/_src/pallas/fuser:custom_evaluate", - "//jax/_src/pallas/fuser:fusable", + "//jax/_src/pallas/fuser:fusible", "//jax/_src/pallas/fuser:fusion", "//jax/_src/pallas/fuser:jaxpr_fusion", ], diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 8339ad6705ff..a62a9937d91d 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -33,7 +33,7 @@ pytype_strict_library( deps = [ ":block_spec", ":custom_evaluate", - ":fusable", + ":fusible", ":fusion", ":jaxpr_fusion", ], @@ -58,9 +58,9 @@ pytype_strict_library( ) pytype_strict_library( - name = "fusable", + name = "fusible", srcs = [ - "fusable.py", + "fusible.py", ], deps = [ ":fusion", @@ -91,8 +91,8 @@ pytype_strict_library( "jaxpr_fusion.py", ], deps = [ - ":fusable", - ":fusable_dtype", + ":fusible", + ":fusible_dtype", ":fusion", "//jax", "//jax:api_util", @@ -104,13 +104,13 @@ pytype_strict_library( ) pytype_strict_library( - name = "fusable_dtype", + name = "fusible_dtype", srcs = [ - "fusable_dtype.py", + "fusible_dtype.py", ], deps = [ ":block_spec", - ":fusable", + ":fusible", "//jax", "//jax:api_util", "//jax:core", diff --git a/jax/_src/pallas/fuser/__init__.py b/jax/_src/pallas/fuser/__init__.py index 3295c8f1061a..39720100eb1d 100644 --- a/jax/_src/pallas/fuser/__init__.py +++ b/jax/_src/pallas/fuser/__init__.py @@ -17,6 +17,6 @@ from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusible.py similarity index 86% rename from jax/_src/pallas/fuser/fusable.py rename to jax/_src/pallas/fuser/fusible.py index d9d0ee0b4682..289a9dc268b4 100644 --- a/jax/_src/pallas/fuser/fusable.py +++ b/jax/_src/pallas/fuser/fusible.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Fusable primitive.""" +"""Fusible primitive.""" from typing import Any import jax @@ -25,8 +25,8 @@ from jax._src.interpreters import partial_eval as pe from jax._src.pallas.fuser import fusion as fusion_lib -fusable_p = jax_core.Primitive('fusable') -fusable_p.multiple_results = True +fusible_p = jax_core.Primitive('fusible') +fusible_p.multiple_results = True def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: @@ -37,7 +37,7 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: ) -def fusable(f=None, *, output_fusion_prefix: Any = True): +def fusible(f=None, *, output_fusion_prefix: Any = True): def decorator(f): def wrapper(*args): def wrapped(*args): @@ -45,14 +45,14 @@ def wrapped(*args): return f(*in_fusions, None) flat_args, in_tree = tree_util.tree_flatten(args) - debug_info = api_util.debug_info('fusable', wrapped, args, {}) + debug_info = api_util.debug_info('fusible', wrapped, args, {}) flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(wrapped, debug_info=debug_info), in_tree ) flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) out_tree = out_tree_thunk() - out = fusable_p.bind( + out = fusible_p.bind( *consts, *flat_args, jaxpr=jaxpr, @@ -71,16 +71,16 @@ def wrapped(*args): return decorator -@fusable_p.def_impl +@fusible_p.def_impl def _(*consts_and_args, jaxpr, num_consts, **_): consts, args = util.split_list(consts_and_args, [num_consts]) return jax_core.eval_jaxpr(jaxpr, consts, *args) -mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl)) +mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl)) -@fusable_p.def_abstract_eval +@fusible_p.def_abstract_eval def _(*args, jaxpr, **kwargs): del args, kwargs return [v.aval for v in jaxpr.outvars] diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py similarity index 95% rename from jax/_src/pallas/fuser/fusable_dtype.py rename to jax/_src/pallas/fuser/fusible_dtype.py index 99c80e652791..8e6cfefcc9eb 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Custom fusable dtypes.""" +"""Custom fusible dtypes.""" import abc import dataclasses @@ -34,7 +34,7 @@ from jax._src.pallas import pallas_call from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.fuser import block_spec -from jax._src.pallas.fuser.fusable import fusable_p +from jax._src.pallas.fuser.fusible import fusible_p from jax._src.state import discharge as state_discharge from jax._src.state import primitives as state_primitives from jax._src.util import foreach @@ -54,7 +54,7 @@ @pack_dtype_p.def_abstract_eval def pack_dtype_abstract_eval(*xs, dtype): - if dtypes.issubdtype(dtype, FusableElementDType): + if dtypes.issubdtype(dtype, fusibleElementDType): return dtype.abstract_pack(*xs) raise ValueError("Attempted to pack non-fusion dtype: {dtype}") @@ -69,7 +69,7 @@ def pack(*xs, dtype): @unpack_dtype_p.def_abstract_eval def unpack_dtype_abstract_eval(x): - if dtypes.issubdtype(x.dtype, FusableElementDType): + if dtypes.issubdtype(x.dtype, fusibleElementDType): return x.dtype.abstract_unpack(x) elif isinstance(x.dtype, pallas_core.AbstractMemoryRef): raise NotImplementedError() @@ -80,20 +80,20 @@ def unpack(x): return unpack_dtype_p.bind(x) -class FusableElementDType(dtypes.extended): - """Scalar dtype for fusable dtypes.""" +class fusibleElementDType(dtypes.extended): + """Scalar dtype for fusible dtypes.""" -class FusableTyRules: +class fusibleTyRules: allow_conversion: bool = False class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta): - """Base class for fusable extended dtypes.""" + """Base class for fusible extended dtypes.""" _op_registry = {} - _rules = FusableTyRules - type = FusableElementDType + _rules = fusibleTyRules + type = fusibleElementDType @abc.abstractmethod def abstract_unpack(self, x) -> Sequence[Any]: @@ -124,7 +124,7 @@ def pull_block_spec_one_step(self, *args, **kwargs): def physicalize(f): - """Runs a function that contains fusable extended dtypes.""" + """Runs a function that contains fusible extended dtypes.""" def wrapper(*args, **kwargs): if kwargs: @@ -203,7 +203,7 @@ class Context: def physicalize_interp( jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value ): - """Physicalizes a jaxpr by replacing fusable dtypes with physical types.""" + """Physicalizes a jaxpr by replacing fusible dtypes with physical types.""" # TODO: Merge into JAX core. 
env: dict[core.Var, Any] = {} @@ -446,12 +446,12 @@ def _pack_dtype_pull_rule( return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error -def _fusable_physicalize_rule( +def _fusible_physicalize_rule( _, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func ): consts, _ = util.split_list(consts_and_args, [num_consts]) new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts)) - return fusable_p.bind( + return fusible_p.bind( *consts_and_args, jaxpr=new_jaxpr.jaxpr, num_consts=num_consts, @@ -461,4 +461,4 @@ def _fusable_physicalize_rule( ) -_physicalize_rules[fusable_p] = _fusable_physicalize_rule +_physicalize_rules[fusible_p] = _fusible_physicalize_rule diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 3c3c2a3d7b66..95768d71f792 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -23,22 +23,22 @@ from jax._src import linear_util as lu from jax._src import tree_util from jax._src.interpreters import partial_eval as pe -from jax._src.pallas.fuser import fusable_dtype +from jax._src.pallas.fuser import fusible_dtype from jax._src.pallas.fuser import fusion as fusion_lib -from jax._src.pallas.fuser.fusable import fusable_p +from jax._src.pallas.fuser.fusible import fusible_p def fuse(f=None, *, physicalize: bool = False, debug: bool = False): - """Fuses a function into a single fusable. + """Fuses a function into a single fusible. Args: f: The function to fuse. physicalize: (experimental) whether to physicalize the function. debug: Whether to print debug information. - There should be a single call to a `fusable` inside the body of `f`. `fuse` + There should be a single call to a `fusible` inside the body of `f`. `fuse` returns a transformed function that will fuse the surrounding computation into - the fusable and invoke it. + the fusible and invoke it. """ def decorator(f): @@ -58,7 +58,7 @@ def wrapper(*args, **kwargs): return tree_util.tree_unflatten(out_tree, out_flat) if physicalize: - wrapper = fusable_dtype.physicalize(wrapper) + wrapper = fusible_dtype.physicalize(wrapper) return wrapper if f is not None: @@ -66,7 +66,7 @@ def wrapper(*args, **kwargs): return decorator -_fusable: dict[jax_core.Primitive, Any] = {} +_fusible: dict[jax_core.Primitive, Any] = {} def _construct_fusion_jaxpr( @@ -148,11 +148,11 @@ def _construct_output_fusions( jaxpr, out_tree, fusion_eqn_index, - fusion_eqn_outvars, # Flat list of vars output by the fusable eqn - fusion_eqn_out_tree, # Tree structure of the fusable eqn outputs + fusion_eqn_outvars, # Flat list of vars output by the fusible eqn + fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs output_fusion_prefix, # Pytree defining output groups ): - # 1. Create jaxpr_out: represents computation *after* the fusable + # 1. Create jaxpr_out: represents computation *after* the fusible # Inputs: fusion_eqn_outvars # Outputs: jaxpr.outvars jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr( @@ -164,15 +164,15 @@ def _construct_output_fusions( tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs tree_util.tree_unflatten( fusion_eqn_out_tree, fusion_eqn_outvars - ), # Fusable outputs as inputs + ), # Fusible outputs as inputs ) - # 2. Group fusable outputs based on the mask - unflat_fusable_outvars = jax.tree.unflatten( + # 2. 
Group fusible outputs based on the mask + unflat_fusible_outvars = jax.tree.unflatten( fusion_eqn_out_tree, fusion_eqn_outvars ) partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to( - unflat_fusable_outvars + unflat_fusible_outvars ) # 3. Calculate dependencies and check disjointness @@ -180,10 +180,10 @@ def _construct_output_fusions( already_used_final_outputs = set() # Indices of final outputs already claimed for outvars_group in partial_flat: # Identify vars in this group - used_fusable_outvars = set(jax.tree.leaves(outvars_group)) + used_fusible_outvars = set(jax.tree.leaves(outvars_group)) # Create mask for jaxpr_out inputs corresponding to this group in_used_mask = [ - True if v in used_fusable_outvars else False for v in jaxpr_out.invars + True if v in used_fusible_outvars else False for v in jaxpr_out.invars ] # Trace dependencies through jaxpr_out to find which final outputs are affected downstream_used_mask = _find_downstream( @@ -257,11 +257,11 @@ def fuse_jaxpr( # Collect input fusions for i, eqn in enumerate(jaxpr.eqns): - if eqn.primitive is fusable_p: + if eqn.primitive is fusible_p: fusion_eqn_index = i break if fusion_eqn_index is None: - raise ValueError("No fusable eqn found") + raise ValueError("No fusible eqn found") fusion_eqn = jaxpr.eqns[fusion_eqn_index] # Now let's check if we need to do any fusion at all, e.g. do the outputs of @@ -269,13 +269,13 @@ def fuse_jaxpr( # with all the inputs and outputs to check if there is a dependence. dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True) - if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns): + if not any(eqn.primitive is fusible_p for eqn in dced_jaxpr.eqns): # Short circuit if there is nothing to fuse. return jax_core.eval_jaxpr(dced_jaxpr, consts, *args) candidate_values = [*consts, *args] - # Construct fusions for non-constant inputs to the fusable. + # Construct fusions for non-constant inputs to the fusible. 
in_fusions_flat = [ construct_fusion( candidate_values, diff --git a/jax/experimental/pallas/fuser.py b/jax/experimental/pallas/fuser.py index 729a447b7408..d4ec7e89cc7d 100644 --- a/jax/experimental/pallas/fuser.py +++ b/jax/experimental/pallas/fuser.py @@ -19,6 +19,6 @@ from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 730c354d8fdb..fa98c0af4be8 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -702,8 +702,8 @@ jax_multiplatform_test( ) jax_multiplatform_test( - name = "tpu_fusable_matmul_test", - srcs = ["tpu_fusable_matmul_test.py"], + name = "tpu_fusible_matmul_test", + srcs = ["tpu_fusible_matmul_test.py"], disable_configs = [ "tpu_v3", "tpu_pjrt_c_api", diff --git a/tests/pallas/fusion_test.py b/tests/pallas/fusion_test.py index 2edcf78f1aba..4bd02345ca62 100644 --- a/tests/pallas/fusion_test.py +++ b/tests/pallas/fusion_test.py @@ -28,7 +28,7 @@ def test_basic_fusion(self): @jax.jit @fuser.fuse - @fuser.fusable + @fuser.fusible def f(x_fn, y_fn): x = x_fn() if y_fn is None: @@ -40,7 +40,7 @@ def f(x_fn, y_fn): def test_separate_output_fusions_trivial(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -63,7 +63,7 @@ def g(x, y): def test_separate_output_fusions_should_error_if_not_disjoint(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -89,7 +89,7 @@ def g(x, y): def test_separate_output_fusions_allows_permute(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -112,7 +112,7 @@ def g(x, y): def test_separate_output_fusions_with_nesting(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -136,7 +136,7 @@ def g(x, y): def test_separate_output_fusions_with_nesting_and_permutation(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -160,7 +160,7 @@ def g(x, y): def test_separate_output_fusions_with_deep_output_mask(self): - @fuser.fusable(output_fusion_prefix=(True, (True, True))) + @fuser.fusible(output_fusion_prefix=(True, (True, True))) def f(x_fn, y_fn, z_fn, o_fns): x = x_fn() y = y_fn() @@ -185,7 +185,8 @@ def g(x, y, z): np.testing.assert_array_equal(z_out, z + z) def test_separate_output_fusions_with_reused_value(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -209,7 +210,8 @@ def g(x, y, a): np.testing.assert_array_equal(y_out, y + a) def test_empty_fusion(self): - @fuser.fusable + + @fuser.fusible def f(x_fn, y_fn): x = x_fn() if y_fn is None: diff --git a/tests/pallas/tpu_fusable_matmul_test.py 
b/tests/pallas/tpu_fusible_matmul_test.py similarity index 94% rename from tests/pallas/tpu_fusable_matmul_test.py rename to tests/pallas/tpu_fusible_matmul_test.py index 93523b174774..2382c09f26ac 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusible_matmul_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Fusable matmul test.""" +"""Fusible matmul test.""" import functools from typing import Any @@ -75,7 +75,7 @@ def _(): jax.tree.map(lambda ref, x: ref.set(x), o_ref, out) -def _fusable_matmul( +def _fusible_matmul( x: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation y: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation z: fuser.Fusion[[jax.Array], jax.Array] | None, # pytype: disable=invalid-annotation @@ -191,7 +191,7 @@ def z_index_map(i, j, k, *_): )[0] -def fusable_matmul( +def fusible_matmul( x: jax.Array, y: jax.Array, *, @@ -201,9 +201,9 @@ def fusable_matmul( debug: bool = False, interpret: bool = False, ) -> jax.Array: - return fuser.fusable( + return fuser.fusible( functools.partial( - _fusable_matmul, + _fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -213,7 +213,7 @@ def fusable_matmul( )(x, y) -class FusableMatmulTest(jtu.JaxTestCase): +class FusibleMatmulTest(jtu.JaxTestCase): def setUp(self): if not jtu.is_device_tpu_at_least(4): @@ -226,7 +226,7 @@ def test_matmul(self, dtype): x = jax.random.normal(k0, (512, 512), dtype) y = jax.random.normal(k1, (512, 512), dtype) np.testing.assert_allclose( - jax.jit(fusable_matmul)(x, y), mm_ref(x, y), atol=5e-5 + jax.jit(fusible_matmul)(x, y), mm_ref(x, y), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -238,7 +238,7 @@ def test_matmul_with_activation(self, dtype): @jax.jit @fuser.fuse def matmul_relu(x, y): - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) x = jnp.maximum(x, 0.0) return x @@ -258,7 +258,7 @@ def test_matmul_with_bias(self, dtype): @jax.jit @fuser.fuse def matmul_bias(x, y, b): - x = fusable_matmul(x, y).astype(dtype) + b + x = fusible_matmul(x, y).astype(dtype) + b x = jnp.maximum(x, 0.0) return x @@ -277,7 +277,7 @@ def test_matmul_with_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1]) + x = fusible_matmul(x, y[1]) return x np.testing.assert_allclose(matmul_slice(x, y), mm_ref(x, y[1]), atol=5e-5) @@ -291,7 +291,7 @@ def test_matmul_with_dynamic_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i): - x = fusable_matmul(x, y[i]) + x = fusible_matmul(x, y[i]) return x np.testing.assert_allclose( @@ -308,7 +308,7 @@ def test_matmul_with_dynamic_slice_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j): - x = fusable_matmul(x, y[j]).astype(dtype) + b[i] + x = fusible_matmul(x, y[j]).astype(dtype) + b[i] return x np.testing.assert_allclose( @@ -326,7 +326,7 @@ def test_matmul_with_multi_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1, 1]) + x = fusible_matmul(x, y[1, 1]) return x np.testing.assert_allclose( @@ -342,7 +342,7 @@ def test_matmul_with_multiple_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1][1]) + x = fusible_matmul(x, y[1][1]) return x np.testing.assert_allclose( @@ -358,7 +358,7 @@ def test_matmul_with_multiple_dynamic_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[i][j]) + x = fusible_matmul(x, y[i][j]) return x for i in range(2): 
@@ -376,7 +376,7 @@ def test_matmul_with_mixed_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[2][i, j]) + x = fusible_matmul(x, y[2][i, j]) return x for i in range(2): @@ -397,7 +397,7 @@ def test_matmul_with_multiple_mixed_slices_and_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j, k): - x = fusable_matmul(x[k][3], y[2][i, j]).astype(dtype) + x = fusible_matmul(x[k][3], y[2][i, j]).astype(dtype) return x + b[i, j] @jit_no_excess_precision @@ -428,7 +428,7 @@ def test_matmul_input_concat_output(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jax.jit @@ -454,7 +454,7 @@ def test_matmul_input_concat_contract(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -482,7 +482,7 @@ def test_matmul_double_concat(self, dtype): def matmul_concat(x, ys, y3): y = jnp.concatenate(ys, axis=0) y = jnp.concatenate([y, y3], axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -509,7 +509,7 @@ def test_matmul_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -534,7 +534,7 @@ def test_matmul_slice_concat_slice(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=1)[1] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -559,7 +559,7 @@ def test_matmul_dynamic_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2, i, j): y = jnp.concatenate([y1, y2[i]], axis=1)[j] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -585,7 +585,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -607,7 +607,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -629,7 +629,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -651,7 +651,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -673,7 +673,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -695,7 +695,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -716,7 +716,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - 
functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -738,7 +738,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -760,7 +760,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -782,7 +782,7 @@ def matmul(impl, x, y): return z.T impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -804,7 +804,7 @@ def matmul(impl, x, y): return z.T * 2 impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -867,7 +867,7 @@ def matmul(impl, x, y): impl = fuser.fuse( functools.partial( matmul, - fusable_matmul, + fusible_matmul, ) ) ref = functools.partial(matmul, dot_ref) @@ -893,7 +893,7 @@ def matmul(impl, x, y): out_ref = jit_no_excess_precision(ref)(x, y) - impl = fuser.fuse(functools.partial(matmul, fusable_matmul)) + impl = fuser.fuse(functools.partial(matmul, fusible_matmul)) out = jax.jit(impl)(x, y) self.assertAllClose(out, out_ref, atol=0) @@ -917,7 +917,7 @@ def matmul(impl, x, y): impl = fuser.fuse( functools.partial( matmul, - functools.partial(fusable_matmul, bk=256, bn=128), + functools.partial(fusible_matmul, bk=256, bn=128), ) ) out = jax.jit(impl)(x, y) @@ -953,7 +953,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -990,7 +990,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -1025,7 +1025,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, From 09edc494d2d2cc1aacd4507cc770f9b2f25c7b59 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 14 Apr 2025 20:48:42 -0700 Subject: [PATCH 0623/1769] Add explicit_axes section to the doc --- docs/notebooks/explicit-sharding.ipynb | 48 ++++++++++++++++++++++++-- docs/notebooks/explicit-sharding.md | 34 +++++++++++++++++- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index d656e12d4068..d0df799b51f6 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -59,7 +59,7 @@ "import numpy as np\n", "import jax.numpy as jnp\n", "from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n", - "from jax.experimental.shard import reshard, auto_axes\n", + "from jax.experimental.shard import reshard, auto_axes, explicit_axes\n", "\n", "jax.config.update('jax_num_cpu_devices', 8)" ] @@ -652,7 +652,51 @@ "id": "_3sfJjRq8w9f" }, "source": [ - "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." 
+ "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.\n", + "\n", + "\n", + "You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a102e9c7", + "metadata": {}, + "outputs": [], + "source": [ + "auto_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Auto, AxisType.Auto))\n", + "\n", + "@functools.partial(explicit_axes, axes=('X', 'Y'))\n", + "def explicit_g(y):\n", + " print(f'mesh inside g: {get_abstract_mesh()}')\n", + " print(f'y.sharding inside g: {jax.typeof(y) = }')\n", + " z = y * 2\n", + " print(f'z.sharding inside g: {jax.typeof(z) = }', end='\\n\\n')\n", + " return z\n", + "\n", + "@jax.jit\n", + "def f(arr1):\n", + " print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n", + " x = jnp.sin(arr1)\n", + "\n", + " z = explicit_g(x, in_shardings=P(\"X\", \"Y\"))\n", + "\n", + " return z + 1\n", + "\n", + "with jax.sharding.use_mesh(auto_mesh):\n", + " some_x = jax.device_put(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", + " f(some_x)" + ] + }, + { + "cell_type": "markdown", + "id": "e64d40de", + "metadata": {}, + "source": [ + "As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.\n", + "Because of that, sharding is visible on the type of arrays inside `g`." ] }, { diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index 7c59a675d8ec..46315cc536d6 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -56,7 +56,7 @@ import jax import numpy as np import jax.numpy as jnp from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh -from jax.experimental.shard import reshard, auto_axes +from jax.experimental.shard import reshard, auto_axes, explicit_axes jax.config.update('jax_num_cpu_devices', 8) ``` @@ -403,6 +403,38 @@ f(some_x) As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`. + +You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes. + +```{code-cell} ipython3 +auto_mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Auto, AxisType.Auto)) + +@functools.partial(explicit_axes, axes=('X', 'Y')) +def explicit_g(y): + print(f'mesh inside g: {get_abstract_mesh()}') + print(f'y.sharding inside g: {jax.typeof(y) = }') + z = y * 2 + print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n') + return z + +@jax.jit +def f(arr1): + print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n') + x = jnp.sin(arr1) + + z = explicit_g(x, in_shardings=P("X", "Y")) + + return z + 1 + +with jax.sharding.use_mesh(auto_mesh): + some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y")) + f(some_x) +``` + +As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`. +Because of that, sharding is visible on the type of arrays inside `g`. 
+ +++ {"id": "sJcWbfAh7UcO"} ## Concrete array shardings can mention `Auto` mesh axis From ae6a18d70d9c01de68a4040e9debf05e9006749f Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Fri, 4 Apr 2025 10:59:50 +0000 Subject: [PATCH 0624/1769] Add jax.fwd_and_bwd --- jax/__init__.py | 1 + jax/_src/api.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++ tests/api_test.py | 21 ++++++++++++++++ 3 files changed, 83 insertions(+) diff --git a/jax/__init__.py b/jax/__init__.py index 32ae955ae5b8..b7190d37faa6 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -100,6 +100,7 @@ from jax._src.api import disable_jit as disable_jit from jax._src.api import eval_shape as eval_shape from jax._src.dtypes import float0 as float0 +from jax._src.api import fwd_and_bwd as fwd_and_bwd from jax._src.api import grad as grad from jax._src.api import hessian as hessian from jax._src.xla_bridge import host_count as host_count diff --git a/jax/_src/api.py b/jax/_src/api.py index 43ab7729a348..5a448503d8f8 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -575,6 +575,67 @@ def _check_output_dtype_revderiv(name, holomorphic, x): "jax.vjp directly.") _check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad") +def fwd_and_bwd( + fun: Callable, has_aux: bool = False, jitted: bool = True + ) -> tuple[Callable, Callable]: + """Creates functions ``fwd`` and ``bwd`` corresponding to the forward and + backward pass of a given function ``fun``. The forward function ``fwd(*args)`` + functionally behaves much like ``y, fun_vjp = jax.vjp(fun, *args)``, but allows + reuse of the backward function ``bwd`` across multiple iterations, which is + useful to avoid recompilation when the forward and backward do not end up in a + single jitted function: + + >>> import jax + >>> + >>> x = W = cot_out = jax.numpy.ones((4,4)) + >>> + >>> def f(x, W): + ... return x @ W + ... + >>> f_jitted = jax.jit(f) + >>> for i in range(3): + ... y, f_vjp = jax.vjp(f_jitted, x, W) + ... cot_x, cot_W = f_vjp(cot_out) # not jitted + ... cot_x, cot_W = jax.jit(f_vjp)(cot_out) # recompiles on every iteration + ... + >>> fwd, bwd = jax.fwd_and_bwd(f) + >>> for i in range(3): + ... y, residuals = fwd(x, W) + ... cot_x, cot_W = bwd(residuals, cot_out) # jitted, compiles once + ... + + Args: + fun: Function to produce a forward and backward of. + has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + jitted: Optional, bool. Indicates whether to return the ``jax.jit`` of + forward and backward. Note that jit-ing only the backward but not the + forward will result in the backward recompiling on every invocation, so we + default to jit-ing both. + + Returns: + The two functions, ``fwd`` and ``bwd``. + + If ``has_aux`` is ``False``, ``fwd(*primals)`` returns a tuple + ``(primals_out, residuals)``, where ``primals_out`` is ``fun(*primals)``. + If ``has_aux`` is ``True``, returns a ``(primals_out, residuals, aux)`` tuple + where ``aux`` is the auxiliary data returned by ``fun``. + + ``bwd`` is a function from ``residuals`` and a cotangent vector with the same + shape as ``primals_out`` to a tuple of cotangent vectors with the same number + and shapes as ``primals``, representing the vector-Jacobian product of ``fun`` + evaluated at ``primals``. 
+ """ + def fwd(*args): + return vjp(fun, *args, has_aux=has_aux) # type: ignore + def bwd(f_vjp, outgrad): + return f_vjp(outgrad) + if jitted: + fwd = jit(fwd) + bwd = jit(bwd) + return fwd, bwd + def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False) -> Callable: diff --git a/tests/api_test.py b/tests/api_test.py index 7590b6e6af71..b1f04a223ed6 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1627,6 +1627,27 @@ def f(x): assert g(2.0) == 4.0 assert len(side) == 1 + @jtu.thread_unsafe_test() # Concurrent ache eviction means we may retrace. + def test_fwd_and_bwd(self): + def f(x, W): + return x @ W + + x = W = cot_out = jnp.ones((4,4)) + expected_y, f_vjp = api.vjp(f, x, W) + expected_cot_x, expected_cot_W = f_vjp(cot_out) + + fwd, bwd = api.fwd_and_bwd(f) + y, residuals = fwd(x, W) + cot_x, cot_W = bwd(residuals, cot_out) + + self.assertArraysAllClose(y, expected_y) + self.assertArraysAllClose(cot_x, expected_cot_x) + self.assertArraysAllClose(cot_W, expected_cot_W) + + with jax.no_tracing(): + y, residuals = fwd(x, W) + cot_x, cot_W = bwd(residuals, cot_out) # no recompilation + @parameterized.named_parameters( {"testcase_name": f"_{transform.__name__}", "transform": transform} for transform in [grad, jacfwd, jacrev]) From aed3297fe2c3cb87f2918e68f1cbe6dbc351be96 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Tue, 15 Apr 2025 02:25:38 -0700 Subject: [PATCH 0625/1769] [CI] Update GPU optional presubmit naming PiperOrigin-RevId: 747782436 --- ...azel_optional_cuda.yml => bazel_optional_b200.yml} | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) rename .github/workflows/{bazel_optional_cuda.yml => bazel_optional_b200.yml} (90%) diff --git a/.github/workflows/bazel_optional_cuda.yml b/.github/workflows/bazel_optional_b200.yml similarity index 90% rename from .github/workflows/bazel_optional_cuda.yml rename to .github/workflows/bazel_optional_b200.yml index 71936aeb9ae8..6335fbacaf2c 100644 --- a/.github/workflows/bazel_optional_cuda.yml +++ b/.github/workflows/bazel_optional_b200.yml @@ -1,5 +1,6 @@ -name: CI - Bazel Optional CUDA tests +name: CI - Bazel Optional B200 CUDA tests on: + # Runs on PR if label "CI Optional GPU Presubmit" is present. 
workflow_dispatch: inputs: halt-for-connection: @@ -25,13 +26,9 @@ concurrency: jobs: run_tests: if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} - runs-on: ${{ matrix.runner }} + runs-on: linux-x86-a4-224-b200-1gpu container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' - strategy: - matrix: - # Optional gpus to run against - runner: ["linux-x86-a4-224-b200-1gpu"] - name: "Bazel single accelerator CUDA tests (${{ matrix.runner }})" + name: "Bazel single B200 CUDA tests" # End Presubmit Naming Check github-cuda-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 From 06a77b7ba517594a9a58561af2340f1b78bb7420 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Tue, 15 Apr 2025 06:25:03 -0700 Subject: [PATCH 0626/1769] [CI] Propagate halt connection to tpu tests PiperOrigin-RevId: 747848417 --- .github/workflows/cloud-tpu-ci-presubmit.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index a92e3cc19313..40c99735c2de 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -62,4 +62,5 @@ jobs: python: "3.10" libtpu-version-type: "nightly" gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }} + halt-for-connection: ${{ inputs.halt-for-connection || false }} # End Presubmit Naming Check github-tpu-presubmits \ No newline at end of file From 0b047396a12d817f7bf81a4fe21eda71c60c2708 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 15 Apr 2025 07:34:23 -0700 Subject: [PATCH 0627/1769] Deprecate the remaining exports from jax.lib.xla_client. Not all of these have replacements, but I want to mark these as deprecated now so they can be removed in a future release. PiperOrigin-RevId: 747867544 --- CHANGELOG.md | 4 ++ jax/lib/xla_client.py | 92 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3374071f190a..d46d9d01a0ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a callable. + * The following exports in `jax.lib.xla_client` are deprecated: + `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, + `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, + `Traceback`. 
* The following internal APIs in `jax.util` are deprecated: `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index bd4d98462f11..cc4fa78eb576 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -14,17 +14,6 @@ from jax._src.lib import xla_client as _xc -get_topology_for_devices = _xc.get_topology_for_devices -heap_profile = _xc.heap_profile -mlir_api_version = _xc.mlir_api_version -Client = _xc.Client -CompileOptions = _xc.CompileOptions -DeviceAssignment = _xc.DeviceAssignment -Frame = _xc.Frame -HloSharding = _xc.HloSharding -OpSharding = _xc.OpSharding -Traceback = _xc.Traceback - _deprecations = { # Finalized 2025-03-25; remove after 2025-06-25 "FftType": ( @@ -85,6 +74,77 @@ ), None, ), + # Added April 4 2025. + "get_topology_for_devices": ( + ( + "jax.lib.xla_client.get_topology_for_devices was deprecated in JAX" + " v0.6.0 and will be removed in JAX v0.7.0" + ), + _xc.get_topology_for_devices, + ), + "heap_profile": ( + ( + "jax.lib.xla_client.heap_profile was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.heap_profile, + ), + "mlir_api_version": ( + ( + "jax.lib.xla_client.mlir_api_version was deprecated in JAX v0.6.0" + " and will be removed in JAX v0.7.0" + ), + _xc.mlir_api_version, + ), + "Client": ( + ( + "jax.lib.xla_client.Client was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0" + ), + _xc.Client, + ), + "CompileOptions": ( + ( + "jax.lib.xla_client.CompileOptions was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.CompileOptions, + ), + "DeviceAssignment": ( + ( + "jax.lib.xla_client.DeviceAssignment was deprecated in JAX v0.6.0" + " and will be removed in JAX v0.7.0" + ), + _xc.DeviceAssignment, + ), + "Frame": ( + ( + "jax.lib.xla_client.Frame was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0" + ), + _xc.Frame, + ), + "HloSharding": ( + ( + "jax.lib.xla_client.HloSharding was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.HloSharding, + ), + "OpSharding": ( + ( + "jax.lib.xla_client.OpSharding was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.OpSharding, + ), + "Traceback": ( + ( + "jax.lib.xla_client.Traceback was deprecated in JAX v0.6.0 and will" + " be removed in JAX v0.7.0" + ), + _xc.Traceback, + ), } import typing as _typing @@ -92,6 +152,16 @@ if _typing.TYPE_CHECKING: Shape = _xc.Shape XlaComputation = _xc.XlaComputation + get_topology_for_devices = _xc.get_topology_for_devices + heap_profile = _xc.heap_profile + mlir_api_version = _xc.mlir_api_version + Client = _xc.Client + CompileOptions = _xc.CompileOptions + DeviceAssignment = _xc.DeviceAssignment + Frame = _xc.Frame + HloSharding = _xc.HloSharding + OpSharding = _xc.OpSharding + Traceback = _xc.Traceback else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr From 4d692d159a7331f4e263f07d5d4e8c3da1548048 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 15 Apr 2025 07:45:58 -0700 Subject: [PATCH 0628/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0bfaa69a78c4306d6421f7ba78638427d9ce53e8. 
PiperOrigin-RevId: 747871119 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index afb9b809d2f6..23cb610724f2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f0b96b5f16b9374121fb21e9b751d1f941352932" -XLA_SHA256 = "c37eefef6204cd1215760fc90608e9d270126f959a502f2920c8e7332ab69353" +XLA_COMMIT = "0bfaa69a78c4306d6421f7ba78638427d9ce53e8" +XLA_SHA256 = "0605b276868b46ec8007fc9ef597e3ca8b3185eaa38a5cab67d7995ed2f3a3e5" def repo(): tf_http_archive( From c56cf4f68d0623c2cd2ec30760ad3385e8dc3c54 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 15 Apr 2025 08:18:30 -0700 Subject: [PATCH 0629/1769] jax.random.bernoulli: use higher-resolution sampler --- CHANGELOG.md | 3 +++ jax/_src/random.py | 8 +++++++- tests/random_lax_test.py | 6 ++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d46d9d01a0ff..3bd701777c36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.jit` now requires `fun` to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in a DeprecationWarning in v0.6.X, and an error in starting in v0.7.X. + * {func}`jax.random.bernoulli` now has higher resolution, and can correctly handle + values of `p` down to about `1E-10`. Previously results were incorrect for `p` + smaller than about `1E-7`. ({jax-issue}`#28022`) * Deprecations diff --git a/jax/_src/random.py b/jax/_src/random.py index b29c1dca7b08..b50a9653a993 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -958,7 +958,13 @@ def _bernoulli(key, p, shape) -> Array: else: _check_shape("bernoulli", shape, np.shape(p)) - return uniform(key, shape, lax.dtype(p)) < p + # we could return uniform(key, shape) < p. But uniform sacrifices some + # resolution, so instead we sample in the space of the generated bits.
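+  # For an nbits-wide float dtype we draw nbits random bits and succeed iff
+  # they fall below the integer cutoff p * 2**nbits, for a resolution of
+  # 2**-nbits (about 2.3e-10 at 32 bits) versus the ~1e-7 achievable by
+  # thresholding a float32 uniform sample.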
+ nbits = dtypes.finfo(p.dtype).bits + unsigned_dtype = UINT_DTYPES[nbits] + samples = bits(key, shape, unsigned_dtype) + cutoff = (p * _lax_const(p, 1 << nbits)).astype(unsigned_dtype) + return (p > 0) & (samples < cutoff) | (p >= 1) def beta(key: ArrayLike, diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index b4d2853abd65..6230b647149b 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -461,6 +461,12 @@ def testBernoulliShape(self): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) + def testBernoulliSmallProbability(self): + # Regression test for https://github.com/jax-ml/jax/issues/28017 + key = jax.random.key(0) + samples = jax.random.bernoulli(key, p=1E-10, shape=int(1E8)) + self.assertEqual(samples.sum(), 0) + @jtu.sample_product( a=[0.2, 5.], b=[0.2, 5.], From 34c2dbfdc81fa1b6fbd8e4916ded6ea26dd6d40c Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Thu, 20 Mar 2025 15:13:19 -0500 Subject: [PATCH 0630/1769] Fix lowering code for ROCm RNN --- jax/experimental/rnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index 55cf2b3bae70..be06aba2db13 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -463,7 +463,7 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw): rnn_fwd_p.def_abstract_eval(rnn_abstract_eval) if gpu_rnn: mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda') - if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"): + if hasattr(gpu_rnn, "miopen_rnn_lowering"): mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_lowering, platform='rocm') From c527ddb7bf4528115c39bc680110437bb849d884 Mon Sep 17 00:00:00 2001 From: Tomás Longeri Date: Tue, 15 Apr 2025 10:40:24 -0700 Subject: [PATCH 0631/1769] [Mosaic:TPU] Fix bug in `rotateVregRows` Shift ops expect RHS to have the same type as LHS, but a scalar was being used for RHS.
PiperOrigin-RevId: 747934619 --- .../dialect/tpu/transforms/apply_vector_layout.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 9f12da5237bd..d6452cba8b8d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5544,15 +5544,17 @@ Value rotateVregRows(OpBuilder &builder, Location loc, Value vreg, if (is_high) { auto shift_amt = builder.create( loc, - builder.getIntegerAttr(builder.getI32Type(), - bits_per_row * within_sublane_rotate_amount)); + DenseElementsAttr::get( + i32_vreg_ty, static_cast(bits_per_row * + within_sublane_rotate_amount))); vreg = builder.create(loc, vreg, shift_amt); } else { auto shift_amt = builder.create( - loc, builder.getIntegerAttr( - builder.getI32Type(), - bits_per_row * - (rows_per_sublane - within_sublane_rotate_amount))); + loc, + DenseElementsAttr::get( + i32_vreg_ty, static_cast( + bits_per_row * (rows_per_sublane - + within_sublane_rotate_amount)))); vreg = builder.create(loc, vreg, shift_amt); } vreg = builder.create(loc, vreg_ty, vreg); From 6e00b5e02d5bc6dd7d78bc783397013baea418a5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 15 Apr 2025 11:01:49 -0700 Subject: [PATCH 0632/1769] [NFC] Rename `standard_insert_pbroadcast` to `standard_insert_pvary` PiperOrigin-RevId: 747943230 --- jax/_src/ad_util.py | 2 +- jax/_src/core.py | 2 +- jax/_src/ffi.py | 2 +- jax/_src/lax/control_flow/solves.py | 2 +- jax/_src/lax/convolution.py | 2 +- jax/_src/lax/lax.py | 70 ++++++++++++++--------------- jax/_src/lax/linalg.py | 14 +++--- jax/_src/lax/slicing.py | 20 ++++----- jax/_src/lax/special.py | 14 +++--- jax/_src/lax/windowed_reductions.py | 8 ++-- jax/_src/prng.py | 2 +- 11 files changed, 69 insertions(+), 69 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index c8e64ce5c2ef..8cfd7b214338 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,7 +31,7 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') diff --git a/jax/_src/core.py b/jax/_src/core.py index 9455f849a1eb..f66fb0928ab2 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2026,7 +2026,7 @@ def _pvary_abstract_eval(*args, axes, axis_index_groups): pvary_p.def_abstract_eval(_pvary_abstract_eval) -def standard_insert_pbroadcast(*args): +def standard_insert_pvary(*args): if not config.varying_axes_in_types.value: return args if not config._check_rep.value: diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 22d54d39913d..f0c7d761ac2b 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -499,7 +499,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any): "and an output with a different layout " f"{static_output_layouts[o_idx]}.") static_input_output_aliases += ((i_idx, o_idx),) - args = core.standard_insert_pbroadcast(*args) + args = core.standard_insert_pvary(*args) results = ffi_call_p.bind( *args, result_avals=result_avals, diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 2c736f403044..f34c98c6aaae 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -310,7 +310,7 @@ def f_aux(x): matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) args = 
_flatten(all_consts) + b_flat - args = core.standard_insert_pbroadcast(*args) + args = core.standard_insert_pvary(*args) out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs) return tree_unflatten(out_tree, out_flat) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 32294bbd72cf..28d67adb6413 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -158,7 +158,7 @@ def conv_general_dilated( preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) - lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f1e41eed3021..d68e742a3fd7 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -369,7 +369,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array: For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``. """ - x1, x2 = core.standard_insert_pbroadcast(x1, x2) + x1, x2 = core.standard_insert_pvary(x1, x2) return nextafter_p.bind(x1, x2) @export @@ -775,7 +775,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2 """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return atan2_p.bind(x, y) @export @@ -845,7 +845,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return complex_p.bind(x, y) @export @@ -917,7 +917,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return pow_p.bind(x, y) @export @@ -1072,7 +1072,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.and: https://openxla.org/stablehlo/spec#and """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return and_p.bind(x, y) @export @@ -1099,7 +1099,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.or: https://openxla.org/stablehlo/spec#or """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return or_p.bind(x, y) @export @@ -1126,7 +1126,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return xor_p.bind(x, y) @export @@ -1191,7 +1191,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.add: https://openxla.org/stablehlo/spec#add """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return add_p.bind(x, y) @export @@ -1215,7 +1215,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return sub_p.bind(x, y) @export @@ -1239,7 +1239,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array: .. 
_stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return mul_p.bind(x, y) @export @@ -1269,7 +1269,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return div_p.bind(x, y) @export @@ -1297,7 +1297,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return rem_p.bind(x, y) @export @@ -1323,7 +1323,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return max_p.bind(x, y) @export @@ -1349,7 +1349,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return min_p.bind(x, y) @export @@ -1375,7 +1375,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return shift_left_p.bind(x, y) @export @@ -1402,7 +1402,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return shift_right_arithmetic_p.bind(x, y) @export @@ -1429,7 +1429,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return shift_right_logical_p.bind(x, y) @export @@ -1460,7 +1460,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return eq_p.bind(x, y) @export @@ -1491,7 +1491,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return ne_p.bind(x, y) @export @@ -1522,7 +1522,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return ge_p.bind(x, y) @export @@ -1553,7 +1553,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return gt_p.bind(x, y) @export @@ -1584,7 +1584,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return le_p.bind(x, y) @export @@ -1615,7 +1615,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: .. 
_stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ - x, y = core.standard_insert_pbroadcast(x, y) + x, y = core.standard_insert_pvary(x, y) return lt_p.bind(x, y) @export @@ -1771,7 +1771,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: x & \text{otherwise} \end{cases}`. """ - min, x, max = core.standard_insert_pbroadcast(min, x, max) + min, x, max = core.standard_insert_pvary(min, x, max) return clamp_p.bind(min, x, max) @@ -1878,7 +1878,7 @@ def _decorator(*args, **kwargs): closed_jaxpr, out_tree = _trace_composite_to_jaxpr( partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) - flat_args = core.standard_insert_pbroadcast(*flat_args) + flat_args = core.standard_insert_pvary(*flat_args) out_flat = composite_p.bind( *flat_args, name=name, @@ -1996,7 +1996,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: op, = operands if isinstance(op, Array): return op - operands = core.standard_insert_pbroadcast(*operands) + operands = core.standard_insert_pvary(*operands) return concatenate_p.bind(*operands, dimension=dimension) @@ -2520,7 +2520,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) - lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), @@ -2656,7 +2656,7 @@ def ragged_dot_general( extra leading dimension of size `g` in the case where the lhs ragged dimension is a contracting dimension. """ - lhs, rhs, group_sizes = core.standard_insert_pbroadcast(lhs, rhs, group_sizes) + lhs, rhs, group_sizes = core.standard_insert_pvary(lhs, rhs, group_sizes) return ragged_dot_general_p.bind( lhs, rhs, @@ -2840,7 +2840,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ - operand, padding_value = core.standard_insert_pbroadcast(operand, padding_value) + operand, padding_value = core.standard_insert_pvary(operand, padding_value) return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -2873,7 +2873,7 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: """ # Caution! The select_n_p primitive has the *opposite* order of arguments to # select(). This is because it implements `select_n`. 
- pred, on_false, on_true = core.standard_insert_pbroadcast( + pred, on_false, on_true = core.standard_insert_pvary( pred, on_false, on_true) return select_n_p.bind(pred, on_false, on_true) @@ -2900,7 +2900,7 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: """ if len(cases) == 0: raise ValueError("select_n() must have at least one case") - which, *cases = core.standard_insert_pbroadcast(which, *cases) + which, *cases = core.standard_insert_pvary(which, *cases) return select_n_p.bind(which, *cases) @@ -3262,7 +3262,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1, if not (1 <= num_keys <= len(operand)): raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}") dimension = canonicalize_axis(dimension, len(operand[0].shape)) - operand = core.standard_insert_pbroadcast(*operand) + operand = core.standard_insert_pvary(*operand) return tuple(sort_p.bind(*operand, dimension=dimension, is_stable=is_stable, num_keys=num_keys)) @@ -8111,7 +8111,7 @@ def after_all(*operands): """Merges one or more XLA token values. Experimental. Wraps the XLA AfterAll operator.""" - operands = core.standard_insert_pbroadcast(*operands) + operands = core.standard_insert_pvary(*operands) return after_all_p.bind(*operands) def _after_all_abstract_eval(*operands): @@ -8246,7 +8246,7 @@ def rng_uniform(a, b, shape): This API may be removed at any time. """ - a, b = core.standard_insert_pbroadcast(a, b) + a, b = core.standard_insert_pvary(a, b) return rng_uniform_p.bind(a, b, shape=tuple(shape)) def _rng_uniform_abstract_eval(a, b, *, shape): @@ -8930,7 +8930,7 @@ def optimization_barrier(operand, /): """ flat_args, treedef = tree_util.tree_flatten(operand) # TODO(yashkatariya): Enable this - # flat_args = core.standard_insert_pbroadcast(flat_args) + # flat_args = core.standard_insert_pvary(flat_args) out = optimization_barrier_p.bind(*flat_args) return tree_util.tree_unflatten(treedef, out) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index a3e0a71a671c..a49936373835 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -121,7 +121,7 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: A new upper-triangular matrix :math:`R` defining the Cholesky decomposition of :math:`A + w \, w^T`. """ - r_matrix, w_vector = core.standard_insert_pbroadcast(r_matrix, w_vector) + r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector) return cholesky_update_p.bind(r_matrix, w_vector) @@ -269,7 +269,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array: A batch of orthogonal (unitary) matrices with the same shape as ``a``, containing the products of the elementary Householder reflectors. """ - a, taus = core.standard_insert_pbroadcast(a, taus) + a, taus = core.standard_insert_pvary(a, taus) return householder_product_p.bind(a, taus) @@ -547,7 +547,7 @@ def symmetric_product( ``symmetrize_output`` is ``True``, the upper triangle is filled with the transpose of the lower triangle, and the whole matrix is valid. 
""" - a_matrix, c_matrix = core.standard_insert_pbroadcast(a_matrix, c_matrix) + a_matrix, c_matrix = core.standard_insert_pvary(a_matrix, c_matrix) result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) if symmetrize_output: upper_half = lax.transpose( @@ -605,7 +605,7 @@ def triangular_solve( singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: b = lax.expand_dims(b, (-1 if left_side else -2,)) - a, b = core.standard_insert_pbroadcast(a, b) + a, b = core.standard_insert_pvary(a, b) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -665,7 +665,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: Returns: Solution ``X`` of tridiagonal system. """ - dl, d, du, b = core.standard_insert_pbroadcast(dl, d, du, b) + dl, d, du, b = core.standard_insert_pvary(dl, d, du, b) return tridiagonal_solve_p.bind(dl, d, du, b) @@ -1658,7 +1658,7 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): if m == 0 or k == 0: return permutation upper = np.array(k, np.int32) if is_constant_dim(k) else k - permutation, swaps = core.standard_insert_pbroadcast(permutation, swaps) + permutation, swaps = core.standard_insert_pvary(permutation, swaps) result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, (permutation, swaps)) return result @@ -1774,7 +1774,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *, elementary Householder reflectors, and ``jpvt`` is the column-pivot indices such that ``a[:, jpvt] = q @ r``. """ - a, jpvt = core.standard_insert_pbroadcast(a, jpvt) + a, jpvt = core.standard_insert_pvary(a, jpvt) a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma) return a_out, jpvt_out, taus diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 6a9dff8b4823..ed7c2f9f2777 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -173,7 +173,7 @@ def dynamic_slice( else: dynamic_sizes = [] static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore - operand, *start_indices = core.standard_insert_pbroadcast( + operand, *start_indices = core.standard_insert_pvary( operand, *start_indices) return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes, slice_sizes=tuple(static_sizes)) @@ -236,7 +236,7 @@ def dynamic_update_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) - operand, update, *start_indices = core.standard_insert_pbroadcast( + operand, update, *start_indices = core.standard_insert_pvary( operand, update, *start_indices) return dynamic_update_slice_p.bind(operand, update, *start_indices) @@ -420,7 +420,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, raise ValueError(f"Unsupported dtype for gather fill_value {dtype}") else: fill_value = None - operand, start_indices = core.standard_insert_pbroadcast(operand, start_indices) + operand, start_indices = core.standard_insert_pvary(operand, start_indices) return gather_p.bind( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=core.canonicalize_shape(slice_sizes), @@ -510,7 +510,7 @@ def scatter_add( """ jaxpr, consts = lax._reduction_jaxpr(lax.add, core.get_aval(lax._const(operand, 0))) - operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates = core.standard_insert_pvary( operand, scatter_indices, updates) return scatter_add_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, @@ 
-566,7 +566,7 @@ def scatter_sub( jaxpr, consts = lax._reduction_jaxpr( lax.sub, core.get_aval(lax._const(operand, 0)) ) - operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates = core.standard_insert_pvary( operand, scatter_indices, updates) return scatter_sub_p.bind( operand, @@ -622,7 +622,7 @@ def scatter_mul( """ jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(lax._const(operand, 1))) - operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates = core.standard_insert_pvary( operand, scatter_indices, updates) return scatter_mul_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, @@ -671,7 +671,7 @@ def scatter_min( """ jaxpr, consts = lax._reduction_jaxpr(lax.min, core.get_aval(lax._const(operand, 0))) - operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates = core.standard_insert_pvary( operand, scatter_indices, updates) return scatter_min_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, @@ -720,7 +720,7 @@ def scatter_max( """ jaxpr, consts = lax._reduction_jaxpr(lax.max, core.get_aval(lax._const(operand, 0))) - operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates = core.standard_insert_pvary( operand, scatter_indices, updates) return scatter_max_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, @@ -786,7 +786,7 @@ def scatter_apply( pass jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand))) # TODO: implement this via its own primitive so we can define appropriate autodiff rules. - operand, scatter_indices, unused = core.standard_insert_pbroadcast( + operand, scatter_indices, unused = core.standard_insert_pvary( operand, scatter_indices, unused) return scatter_p.bind( operand, scatter_indices, unused, update_jaxpr=jaxpr, @@ -871,7 +871,7 @@ def scatter( ... 
mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) Array([0., 2., 3., 0., 4.], dtype=float32) """ - operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates = core.standard_insert_pvary( operand, scatter_indices, updates) return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=None, diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index a59d62523c9f..a486bda28486 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -59,7 +59,7 @@ def up_and_broadcast(*args): def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" - a, b, x = core.standard_insert_pbroadcast(a, b, x) + a, b, x = core.standard_insert_pvary(a, b, x) return regularized_incomplete_beta_p.bind(a, b, x) def lgamma(x: ArrayLike) -> Array: @@ -72,33 +72,33 @@ def digamma(x: ArrayLike) -> Array: def polygamma(m: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`.""" - m, x = core.standard_insert_pbroadcast(m, x) + m, x = core.standard_insert_pvary(m, x) return polygamma_p.bind(m, x) def igamma(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete gamma function.""" - a, x = core.standard_insert_pbroadcast(a, x) + a, x = core.standard_insert_pvary(a, x) return igamma_p.bind(a, x) def igammac(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise complementary regularized incomplete gamma function.""" - a, x = core.standard_insert_pbroadcast(a, x) + a, x = core.standard_insert_pvary(a, x) return igammac_p.bind(a, x) def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of the regularized incomplete gamma function.""" - a, x = core.standard_insert_pbroadcast(a, x) + a, x = core.standard_insert_pvary(a, x) return igamma_grad_a_p.bind(a, x) @_up_and_broadcast def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" - a, x = core.standard_insert_pbroadcast(a, x) + a, x = core.standard_insert_pvary(a, x) return random_gamma_grad_impl(a, x, dtype=dtype) def zeta(x: ArrayLike, q: ArrayLike) -> Array: r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" - x, q = core.standard_insert_pbroadcast(x, q) + x, q = core.standard_insert_pvary(x, q) return zeta_p.bind(x, q) def bessel_i0e(x: ArrayLike) -> Array: diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 472b92d858f9..8ebdcd6c3784 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -97,7 +97,7 @@ def _reduce_window( raise ValueError( 'reduce_window output must have the same tree structure as the operands' f' {operand_tree} vs. 
{out_tree}')
-  flat_operands = core.standard_insert_pbroadcast(*flat_operands)
+  flat_operands = core.standard_insert_pvary(*flat_operands)
   out_flat = reduce_window_p.bind(
       *flat_operands, *flat_init_values,
@@ -251,7 +251,7 @@ def _select_and_scatter(operand: Array, select: Callable,
       select, core.get_aval(init_value))
   scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
       scatter, core.get_aval(init_value))
-  operand, source, init_value = core.standard_insert_pbroadcast(
+  operand, source, init_value = core.standard_insert_pvary(
       operand, source, init_value)
   return select_and_scatter_p.bind(
       operand, source, init_value, select_jaxpr=select_jaxpr,
@@ -264,7 +264,7 @@ def _select_and_scatter_add(source: Array, operand: Array,
                             window_dimensions: core.Shape,
                             window_strides: Sequence[int],
                             padding: Sequence[tuple[int, int]]) -> Array:
-  source, operand = core.standard_insert_pbroadcast(source, operand)
+  source, operand = core.standard_insert_pvary(source, operand)
   return select_and_scatter_add_p.bind(
       source, operand, select_prim=select_prim,
       window_dimensions=tuple(window_dimensions),
@@ -300,7 +300,7 @@ def _select_and_gather_add(tangents: Array, operand: Array,
     An array containing the elements in `tangents` corresponding to the output
     of the reduction of `operand` in each window.
   """
-  tangents, operand = core.standard_insert_pbroadcast(tangents, operand)
+  tangents, operand = core.standard_insert_pvary(tangents, operand)
   return select_and_gather_add_p.bind(
       tangents, operand, select_prim=select_prim,
       window_dimensions=tuple(window_dimensions),
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 1dc7e9c0df0e..6945c2f51f7c 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -618,7 +618,7 @@ def random_split_lowering(ctx, keys, *, shape):
 
 def random_fold_in(keys, msgs):
   msgs = jnp.asarray(msgs)
-  keys, msgs = core.standard_insert_pbroadcast(keys, msgs)
+  keys, msgs = core.standard_insert_pvary(keys, msgs)
   return random_fold_in_p.bind(keys, msgs)
 
 random_fold_in_p = core.Primitive('random_fold_in')

From d093b3ee35900f3a0baddeb94533b7c24e9fc6bf Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Sun, 13 Apr 2025 10:16:34 +0000
Subject: [PATCH 0633/1769] Add docs for GPU memory spaces in Pallas:MGPU

---
 docs/_static/pallas/gpu/grid_tiling_off.svg | 175 +++++++++++++++++++
 docs/_static/pallas/gpu/grid_tiling_on.svg  | 183 ++++++++++++++++++++
 docs/pallas/gpu/reference.md                |  93 +++++++++-
 3 files changed, 449 insertions(+), 2 deletions(-)
 create mode 100644 docs/_static/pallas/gpu/grid_tiling_off.svg
 create mode 100644 docs/_static/pallas/gpu/grid_tiling_on.svg

diff --git a/docs/_static/pallas/gpu/grid_tiling_off.svg b/docs/_static/pallas/gpu/grid_tiling_off.svg
new file mode 100644
index 000000000000..b11d85759ce4
--- /dev/null
+++ b/docs/_static/pallas/gpu/grid_tiling_off.svg
@@ -0,0 +1,175 @@
+[SVG figure: matmul access pattern of the first 16 blocks without grid tiling.
+ Panels: A (6x16 tiles), B (16x16 tiles), C = A @ B (6x16 tiles). Full markup omitted.]
diff --git a/docs/_static/pallas/gpu/grid_tiling_on.svg b/docs/_static/pallas/gpu/grid_tiling_on.svg
new file mode 100644
index 000000000000..9d24a8187179
--- /dev/null
+++ b/docs/_static/pallas/gpu/grid_tiling_on.svg
@@ -0,0 +1,183 @@
+[SVG figure: matmul access pattern of the first 16 blocks with grid tiling.
+ Panels: A (6x16 tiles), B (16x16 tiles), C = A @ B (6x16 tiles). Full markup omitted.]
diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md
index 5cebdf364e5b..b3ad1d85d132 100644
--- a/docs/pallas/gpu/reference.md
+++ b/docs/pallas/gpu/reference.md
@@ -27,7 +27,7 @@ but multiple warps can be assigned to the same SM subdivision.
At each clock cycle, the warp scheduler from each subdivision tries to select
one of its resident warps to execute the next instruction.
 
-
A diagram of one NVIDIA SM
+
A diagram of one NVIDIA SM
 Going further, recent CUDA versions also outline the concept of a _warpgroup_,
 which is a group of 4 consecutive warps. Knowing what the hardware looks like,
 we can see where this is coming
@@ -82,7 +82,96 @@ For more information on how warp scheduling and instruction issue works, we reco
 
 ### Memory spaces
 
-TODO: GMEM, SMEM, RMEM, (maybe TMEM)
+The GPU features a few different memory spaces that can be totally ordered from the largest
+(in terms of capacity) and slowest (in both total bandwidth and latency of a single access)
+to the smallest and fastest.
+
+The biggest memory space is `plgpu.GMEM`, for _global memory_. In recent data-center grade GPUs
+this memory space is often measured in tens or even hundreds of gigabytes, but it is also the
+slowest one.
+
+The next memory space, used for the L2 cache, is also more or less global in the
+sense that it is shared by the whole GPU, but its use can only be influenced indirectly through
+cache hints. As such, there is no way to manually place values in it, and so this memory space
+is not exposed in Pallas:MGPU. While only about 100MB in size, this memory has considerably
+higher bandwidth than GMEM, and so it is still often recommended to take advantage of it when
+writing high-performance kernels.
+
+Next in line is _shared memory_, or `plgpu.SMEM`. This memory is located directly inside each SM
+and so it is partitioned. Unless block clusters are used (see the section on clusters below),
+each block is only allowed to access its own SMEM allocations.
+
+Finally, the lowest-level memory space is _register memory_. This is where every single value
+(i.e. JAX array) in a Pallas kernel is located. If the compiler runs out of registers to
+store those arrays, it will insert _spills_, meaning that it will periodically store values to
+memory and reload them. Spills usually introduce significant performance degradation, so
+we recommend avoiding them. The warnings about spills can be clearly seen in the `ptxas`
+messages printed during kernel compilation. To make them visible, run with
+`MOSAIC_GPU_DUMP_PTXAS=1` in your environment.
+
+The Blackwell GPU generation has one additional memory space called _tensor memory_, or `plgpu.TMEM`.
+TMEM is very similar to register memory, except that it is explicitly allocated and managed by you.
+It is used to store the MMA accumulator, operand metadata (for sparsity or scaling),
+and optionally the left MMA operand. See the Blackwell MMA section for more information about TMEM.
+
+#### Requesting/allocating memory in specific memory spaces
+
+Kernel inputs or outputs are placed in SMEM by default. If you want to access them as GMEM
+references, add `memory_space=plgpu.GMEM` to their `BlockSpec`. If you want the kernel to be called
+with the whole input or output array in GMEM, it is sufficient to specify
+`BlockSpec(memory_space=plgpu.GMEM)`.
+
+`SMEM` and `TMEM` can be allocated explicitly in the `scratch_shapes` argument of `pl.pallas_call`,
+or using `pl.run_scoped`. To allocate a reference, simply call the memory space object with the
+requested shape and dtype. For example, `plgpu.SMEM((128, 128), jnp.float16)` will allocate a 128x128
+array of float16 elements in shared memory, as in the sketch below.
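+
+To make this concrete, here is a minimal sketch (our addition, not part of the original
+patch) of a kernel that stages its input through an explicitly allocated SMEM scratch
+buffer. It assumes the Mosaic GPU backend is selected, which we request here via the
+`backend="mosaic_gpu"` argument:
+
+```python
+import jax
+import jax.numpy as jnp
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import mosaic_gpu as plgpu
+
+def kernel(x_ref, o_ref, scratch_ref):
+  # x_ref and o_ref are SMEM references (the default), while scratch_ref is
+  # the SMEM buffer requested via scratch_shapes below.
+  scratch_ref[...] = x_ref[...] * 2
+  o_ref[...] = scratch_ref[...]
+
+x = jnp.ones((128, 128), jnp.float16)
+y = pl.pallas_call(
+    kernel,
+    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+    scratch_shapes=[plgpu.SMEM((128, 128), jnp.float16)],
+    backend="mosaic_gpu",
+)(x)
+```
+
+Passing `in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)]` instead would hand the kernel a
+reference to the whole input in GMEM, leaving the copies into SMEM up to you.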
+#### Taking advantage of the L2 cache
+
+While the L2 cache cannot be managed manually, its noticeably higher bandwidth compared to global
+memory makes it worth thinking about. The simplest way to take advantage of it is to reorder
+the parallel grid dimensions so that invocations that are scheduled in similar time periods also
+access the same input data.
+
+While the CUDA programming model does not guarantee anything about the order in which the blocks
+are assigned to SMs, in recent generations the heuristic seems to simply iterate over the
+`(x, y, z)` CUDA grid in column-major order (i.e. `x` is the fastest-changing dimension and
+`z` is the slowest). Similarly, Pallas:MGPU does not guarantee how a user-specified grid is mapped to
+the CUDA grid (Pallas supports grids of arbitrary rank, not just up to 3D). However, you can assume that
+the iteration will happen in _row-major_ order. That is, if a grid has dimensions `(a, b)`, then
+`b` will be the fastest-changing dimension and `a` will be the slower one.
+
+To give a practical example of this, consider a plain matrix multiplication kernel. There, one
+usually uses two parallel grid dimensions `(m, n)`, corresponding to tiling the two non-contracting
+dimensions. If we use this simple scheme, in Pallas:MGPU all programs with id `(0, ...)` will be
+scheduled before any program with id `(1, ...)`. And, collectively, the programs with `m=0` have to
+read all of the `B` operand! If the `n` or `k` dimensions are very large, there is no chance that
+we'll be able to get cache hits in the `(1, ...)` programs from accesses made by the `(0, ...)`
+programs. For simplicity, assuming we can only run 16 blocks at a time, we see this access pattern
+from the first scheduled wave:
+
+ + Your browser does not support SVGs or scripting is disabled. + This would be an image showing the access pattern of first 16 blocks without grid tiling. + +
+
+However, if we simply rearrange the grid to be `(m // mt, n, mt)` (and then replace `pl.program_id(0)`
+with `pl.program_id(0) * mt + pl.program_id(2)` in the kernel), it is straightforward to see that a
+band of programs along both dimensions will be scheduled concurrently (instead of scheduling a single
+row). This greatly increases the number of concurrent programs that load similar slices of data, which
+usually significantly improves L2 utilization and hence the overall performance of the kernel
+(if it is memory bound). Continuing our example with 16 blocks and using `mt=4`, we get the following
+access pattern (a short code sketch of this grid transformation follows the figure):
+
+ + Your browser does not support SVGs or scripting is disabled. + This would be an image showing the access pattern of first 16 blocks with grid tiling. + +
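+
+To make the transformation concrete, here is a minimal sketch (our addition, not part of
+the original patch) of a matmul grid with this tiling applied. The sizes `m`, `n`, `mt`
+and the tile shapes are hypothetical, and the index maps implement the
+`pl.program_id(0) * mt + pl.program_id(2)` remapping described above:
+
+```python
+import jax
+import jax.numpy as jnp
+from jax.experimental import pallas as pl
+
+m, n = 8, 16  # output grid, in tiles (hypothetical sizes)
+mt = 4        # m-dimension tiling factor; must divide m
+tm = tn = 64  # output tile shape
+K = 512       # full contracting dimension, loaded as a single block here
+
+def matmul_kernel(a_ref, b_ref, o_ref):
+  o_ref[...] = a_ref[...] @ b_ref[...]
+
+matmul = pl.pallas_call(
+    matmul_kernel,
+    grid=(m // mt, n, mt),  # the tiled grid, instead of the plain (m, n) one
+    in_specs=[
+        # The index maps recover the logical (row, column) tile indices.
+        pl.BlockSpec((tm, K), lambda i, j, t: (i * mt + t, 0)),
+        pl.BlockSpec((K, tn), lambda i, j, t: (0, j)),
+    ],
+    out_specs=pl.BlockSpec((tm, tn), lambda i, j, t: (i * mt + t, j)),
+    out_shape=jax.ShapeDtypeStruct((m * tm, n * tn), jnp.float32),
+)
+```
+
+Calling `matmul(a, b)` with `a` of shape `(512, 512)` and `b` of shape `(512, 1024)` computes
+the same result as the plain `(m, n)` grid, but concurrently scheduled programs now share
+many more operand tiles.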
+ +Note that even though the number of active blocks hasn't changed, the total footprint of the data they +access has halved! We get a much higher chance of getting L2 hits now. ## Array layouts and reference transforms From 33513f6eba6457fde25180f612d03aea80fa2b97 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Fri, 11 Apr 2025 19:18:23 +0000 Subject: [PATCH 0634/1769] Make Clang use manylinux C++ standard library --- .github/workflows/rocm-ci.yml | 2 +- .../build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm | 8 +++++++- build/rocm/build_wheels/clang.cfg | 3 +++ build/rocm/ci_build | 5 ++++- build/rocm/tools/fixwheel.py | 2 +- 5 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 build/rocm/build_wheels/clang.cfg diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 713e9099e381..0ce20726ce63 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -19,7 +19,7 @@ jobs: BASE_IMAGE: "ubuntu:22.04" TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} PYTHON_VERSION: "3.10" - ROCM_VERSION: "6.2.4" + ROCM_VERSION: "6.3.3" WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} steps: - name: Clean up old runs diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 8afe8b17252c..3ca491568911 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM # manylinux base image. However, adding this does fix an issue where Bazel isn't able # to find them. RUN --mount=type=cache,target=/var/cache/dnf \ - dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel + dnf install -y numactl-devel RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ @@ -25,5 +25,11 @@ RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/a mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \ make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project +# Set some clang config +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang.cfg + # Stop git from erroring out when we don't own the repo RUN git config --global --add safe.directory '*' diff --git a/build/rocm/build_wheels/clang.cfg b/build/rocm/build_wheels/clang.cfg new file mode 100644 index 000000000000..767c04c03ae7 --- /dev/null +++ b/build/rocm/build_wheels/clang.cfg @@ -0,0 +1,3 @@ +# Tell clang where it can find gcc so that it can use gcc's standard libraries +--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/ + diff --git a/build/rocm/ci_build b/build/rocm/ci_build index ef43a95044d8..71ce747d7e86 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -98,7 +98,10 @@ def dist_wheels( bw_cmd.append("/jax") - cmd = ["docker", "run"] + cmd = [ + "docker", + "run", + ] mounts = [ "-v", diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py index ea77162728d5..7d8c1fcce055 100644 --- a/build/rocm/tools/fixwheel.py +++ 
b/build/rocm/tools/fixwheel.py @@ -87,7 +87,7 @@ def fix_wheel(path): exclude = list(ext_libs.keys()) # call auditwheel repair with excludes - cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"] + cmd = ["auditwheel", "-v", "repair", "--plat", plat, "--only-plat"] for ex in exclude: cmd.append("--exclude") From ba8877789dc9ab4c1b15f3834ee26474002fefaf Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 15 Apr 2025 12:59:08 -0700 Subject: [PATCH 0635/1769] Roll back https://github.com/jax-ml/jax/pull/28022 due to test breakages. Reverts b336daf747940301de5956dce4ebe790298e6b5b PiperOrigin-RevId: 747988862 --- CHANGELOG.md | 3 --- jax/_src/random.py | 8 +------- tests/random_lax_test.py | 6 ------ 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bd701777c36..d46d9d01a0ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,9 +43,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.jit` now requires `fun` to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in a DeprecationWarning in v0.6.X, and an error in starting in v0.7.X. - * {func}`jax.random.beroulli` now has higher resolution, and can correctly handle - values of `p` down to about `1E-10`. Previously results were incorrect for `p` - smaller than about `1E-7`. ({jax-issue}`#28022`) * Deprecations diff --git a/jax/_src/random.py b/jax/_src/random.py index b50a9653a993..b29c1dca7b08 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -958,13 +958,7 @@ def _bernoulli(key, p, shape) -> Array: else: _check_shape("bernoulli", shape, np.shape(p)) - # we could return uniform(key, shape) < p. But uniform sacrifices some - # resolution, so instead we sample in the space of the generated bits. - nbits = dtypes.finfo(p.dtype).bits - unsigned_dtype = UINT_DTYPES[nbits] - samples = bits(key, shape, unsigned_dtype) - cutoff = (p * _lax_const(p, 1 << nbits)).astype(unsigned_dtype) - return (p > 0) & (samples < cutoff) | (p >= 1) + return uniform(key, shape, lax.dtype(p)) < p def beta(key: ArrayLike, diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 6230b647149b..b4d2853abd65 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -461,12 +461,6 @@ def testBernoulliShape(self): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) - def testBernoulliSmallProbabilty(self): - # Regression test for https://github.com/jax-ml/jax/issues/28017 - key = jax.random.key(0) - samples = jax.random.bernoulli(key, p=1E-10, shape=int(1E8)) - self.assertEqual(samples.sum(), 0) - @jtu.sample_product( a=[0.2, 5.], b=[0.2, 5.], From 393c555a02ae7a947c9cb591c2477c8d282747ea Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 15 Apr 2025 13:24:08 -0700 Subject: [PATCH 0636/1769] Fix bugs in tp_traverse handlers. * A garbage collected type bound by nanobind is allocated with `PyType_GenericAlloc` (https://github.com/wjakob/nanobind/blob/8f245bd6e5544ff828beb46435af24b8769bfcdc/src/nb_type.cpp#L81). * `PyType_GenericAlloc` calls `_PyObject_GC_TRACK` (https://github.com/python/cpython/blob/884df116d79b05d9342e05e50484d61c684ecb8b/Objects/typeobject.c#L2357), and as soon as an object is tracked by the GC, it may be traversed (https://docs.python.org/3/c-api/gcsupport.html#c.PyObject_GC_Track). * However, at this point, the nanobind-bound object has been allocated, but its C++ constructor need not yet have run. 
The constructor runs when `__init__` is called. * In tp_traverse methods of nanobind-bound classes, we should test whether the instance is ready before visiting any C++ state. Instance readiness will be set at the end of the dispatch of `__init__` (https://github.com/wjakob/nanobind/blob/8f245bd6e5544ff828beb46435af24b8769bfcdc/src/nb_func.cpp#L857). In addition, all nanobind-bound classes have heap-allocated types. Make sure we visit `Py_TYPE(self)`. PiperOrigin-RevId: 747998575 --- jaxlib/weakref_lru_cache.cc | 5 ++++- jaxlib/xla/config.cc | 4 ++++ jaxlib/xla/py_client.cc | 4 ++++ jaxlib/xla/pytree.cc | 10 ++++++++-- 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/jaxlib/weakref_lru_cache.cc b/jaxlib/weakref_lru_cache.cc index a6199a22e38c..0e3b9b831b82 100644 --- a/jaxlib/weakref_lru_cache.cc +++ b/jaxlib/weakref_lru_cache.cc @@ -349,8 +349,11 @@ void WeakrefLRUCache::Clear() { /*static*/ int WeakrefLRUCache::tp_traverse(PyObject* self, visitproc visit, void* arg) { - WeakrefLRUCache* cache = nb::inst_ptr(self); Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + WeakrefLRUCache* cache = nb::inst_ptr(self); Py_VISIT(cache->cache_context_fn_.ptr()); Py_VISIT(cache->fn_.ptr()); for (const auto& [wr_key, wr_value] : cache->entries_) { diff --git a/jaxlib/xla/config.cc b/jaxlib/xla/config.cc index c68ff7f4ac54..8804b783eb72 100644 --- a/jaxlib/xla/config.cc +++ b/jaxlib/xla/config.cc @@ -291,6 +291,10 @@ void Config::SetGlobal(nb::object value) { /* static */ int Config::tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } Config* c = nb::inst_ptr(self); // For the purposes of GC, we pretend that this object owns both the global // and any thread-local values corresponding to this key. 
diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 2ce11e7e76c7..c4e8449b85c3 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -662,6 +662,10 @@ absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( /* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } PyClient* c = nb::inst_ptr(self); for (const auto& [ifrt_device, py_device] : c->devices_) { Py_VISIT(py_device.ptr()); diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc index 175e753515d0..9359165b19dd 100644 --- a/jaxlib/xla/pytree.cc +++ b/jaxlib/xla/pytree.cc @@ -420,8 +420,11 @@ nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, /* static */ int PyTreeRegistry::tp_traverse(PyObject* self, visitproc visit, void* arg) { - PyTreeRegistry* registry = nb::inst_ptr(self); Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeRegistry* registry = nb::inst_ptr(self); nb::ft_lock_guard lock(registry->mu_); for (const auto& [key, value] : registry->registrations_) { Py_VISIT(key.ptr()); @@ -1596,8 +1599,11 @@ int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const { /* static */ int PyTreeDef::tp_traverse(PyObject* self, visitproc visit, void* arg) { - PyTreeDef* treedef = nb::inst_ptr(self); Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeDef* treedef = nb::inst_ptr(self); Py_VISIT(treedef->registry_ref_.ptr()); for (const auto& node : treedef->traversal_) { node.tp_traverse(visit, arg); From b271a67bbc04359a9087905fd6f2badfe157f880 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 15 Apr 2025 14:36:56 -0700 Subject: [PATCH 0637/1769] Clean up softmax initial deprecation --- jax/_src/nn/functions.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index eed3f7658c5d..640c8b89d001 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -48,12 +48,6 @@ from jax._src.ops.special import logsumexp as _logsumexp -class Unspecified: - def __repr__(self): - return "_UNSPECIFIED" -_UNSPECIFIED = Unspecified() - - # activations @jax.jit def identity(x: ArrayLike) -> Array: @@ -525,8 +519,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: @partial(jax.jit, static_argnames=("axis",)) def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + where: ArrayLike | None = None) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -552,10 +545,6 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.") - del initial numpy_util.check_arraylike("log_softmax", x) x_arr = jnp.asarray(x) x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True) @@ -573,8 +562,7 @@ def log_softmax(x: ArrayLike, # @partial(jax.jit, static_argnames=("axis",)) def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + where: ArrayLike | None = None) -> Array: r"""Softmax function. 
Computes the function which rescales elements to the range :math:`[0, 1]` @@ -600,10 +588,6 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.") - del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition # of `_softmax` and incorrectly concludes that `_softmax` returns From 738891306bfa3981eec0e02789317eac5b8d5cac Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 15 Apr 2025 14:38:39 -0700 Subject: [PATCH 0638/1769] [ragged-paged-attn][NFC] Set kv_pages_per_blk uplimit. PiperOrigin-RevId: 748027325 --- .../pallas/ops/tpu/ragged_paged_attention.py | 11 ++++++++--- tests/pallas/tpu_ragged_paged_attention_test.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 760a5eef089b..203dc8a7602a 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -221,7 +221,7 @@ def static_validate_inputs( _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape assert num_combined_kv_heads % 2 == 0 num_kv_heads = num_combined_kv_heads // 2 - max_num_seqs, _ = page_indices.shape + max_num_seqs, pages_per_seq = page_indices.shape if num_seqs.shape != (1,): raise ValueError(f"{num_seqs.shape=} must be (1,)") if head_dim_k != head_dim: @@ -254,8 +254,13 @@ def static_validate_inputs( raise ValueError(f"{sliding_window=} must be positive.") if soft_cap is not None and soft_cap == 0.0: raise ValueError(f"{soft_cap=} must not be 0.0.") - if num_kv_pages_per_block is not None and num_kv_pages_per_block <= 0: - raise ValueError(f"{num_kv_pages_per_block=} must be positive.") + if ( + num_kv_pages_per_block is not None + and not 0 < num_kv_pages_per_block <= pages_per_seq + ): + raise ValueError( + f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]." + ) if num_queries_per_block is not None and num_queries_per_block <= 0: raise ValueError(f"{num_queries_per_block=} must be positive.") if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index e617a1b3b06c..f86d54575519 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -110,7 +110,7 @@ def _test_ragged_paged_attention( page_indices, cu_q_lens, num_seqs=num_seqs, - num_kv_pages_per_block=num_kv_pages_per_block, + num_kv_pages_per_block=min(num_kv_pages_per_block, pages_per_seq), num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, sliding_window=sliding_window, From 655bfcac39285e54ed64220f734717f06f66d725 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 15 Apr 2025 14:38:45 -0700 Subject: [PATCH 0639/1769] Enable standard_insert_pvary for optimization_barrier which was disabled before. 
PiperOrigin-RevId: 748027360 --- jax/_src/lax/lax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index d68e742a3fd7..79bd42607290 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8929,13 +8929,13 @@ def optimization_barrier(operand, /): Array(0., dtype=float32, weak_type=True) """ flat_args, treedef = tree_util.tree_flatten(operand) - # TODO(yashkatariya): Enable this - # flat_args = core.standard_insert_pvary(flat_args) + flat_args = core.standard_insert_pvary(*flat_args) out = optimization_barrier_p.bind(*flat_args) return tree_util.tree_unflatten(treedef, out) def _optimization_barrier_abstract_eval(*args): + core.standard_vma_rule('optimization_barrier', *args) return args def _optimization_barrier_lowering_rule(ctx, *args): From 47bc2f55dc0e583c7b59880af66401dcee27867e Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 15 Apr 2025 17:10:37 -0700 Subject: [PATCH 0640/1769] convert NumPy RNG key data to uncommitted default-device-backed `jax.Array` data Generally, we want to maintain that key data backing a `PRNGKeyArray` is a `jax.Array`. This change converts NumPy arrays on construction. Co-authored-by: Yash Katariya PiperOrigin-RevId: 748077900 --- jax/_src/prng.py | 20 ++++++++++++-------- tests/random_test.py | 20 ++++++++++++++++---- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 6945c2f51f7c..e588a3894be4 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -51,7 +51,8 @@ from jax._src.numpy.array_methods import ( _array_operators, _set_array_base_attributes, _IndexUpdateHelper) from jax._src.sharding_impls import ( - NamedSharding, PmapSharding, physical_sharding, logical_sharding) + NamedSharding, PmapSharding, SingleDeviceSharding, physical_sharding, + logical_sharding) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip @@ -156,7 +157,7 @@ class behave like an array whose base elements are keys, hiding the # device_buffer, device_buffers, __cuda_interface__() _impl: PRNGImpl - _base_array: typing.Array + _base_array: jax.Array _consumed: bool | np.ndarray # Used in jax.experimental.key_reuse. _source_info: None | source_info_util.SourceInfo = None @@ -164,8 +165,14 @@ def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) _check_prng_key_data(impl, key_data) self._impl = impl - self._base_array = key_data self._consumed = False # TODO(jakevdp): default to True here? 
+ # If key_data is a numpy array, convert it to an uncommitted CPU jax.Array + if isinstance(key_data, np.ndarray): + aval = core.get_aval(key_data) + device = pxla.get_default_device() + key_data = pxla.batched_device_put(aval, SingleDeviceSharding(device), + [key_data], [device], committed=False) + self._base_array = key_data def block_until_ready(self): _ = self._base_array.block_until_ready() @@ -176,11 +183,8 @@ def copy_to_host_async(self): @property def aval(self): - logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') - else None) - vma = (self._base_array.aval.vma if hasattr(self._base_array, 'aval') - else frozenset()) - return keys_shaped_array(self._impl, self.shape, logical_sharding, vma) + vma = self._base_array.aval.vma + return keys_shaped_array(self._impl, self.shape, self.sharding, vma) @property def shape(self): diff --git a/tests/random_test.py b/tests/random_test.py index e5db7c55dafb..d2b465456d38 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -602,10 +602,26 @@ def assertKeysEqual(self, key1, key2): self.assertEqual(key1.dtype, key2.dtype) self.assertArraysEqual(random.key_data(key1), random.key_data(key2)) + def make_keys(self, *shape, seed=28): + seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) + return jax.vmap(random.key)(seeds).reshape(shape) + def test_construction(self): key = random.key(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) + def test_numpy_construction(self): + key = random.wrap_key_data(np.array([42, 173], dtype=np.uint32), + impl='threefry2x32') + self.assertIsInstance(key, prng_internal.PRNGKeyArray) + self.assertIsInstance(key._base_array, jax.Array) + self.assertEqual(key._base_array.device, jax.devices()[0]) + self.assertEqual(key.device, jax.devices()[0]) + + def test_device_property(self): + key = random.key(42) + self.assertEqual(key.device, key._base_array.device) + def test_random_clone(self): # Here we test value semantics and compatibility with jit/vmap # key reuse semantics are tested in key_reuse_test.py @@ -632,10 +648,6 @@ def test_construction_upgrade_flag(self): key = random.PRNGKey(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) - def make_keys(self, *shape, seed=28): - seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) - return jax.vmap(random.key)(seeds).reshape(shape) - def test_key_as_seed(self): key = self.make_keys() with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): From 90af5977863fcc47994ad7fdb65758812ee4e145 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 15 Apr 2025 17:38:52 -0700 Subject: [PATCH 0641/1769] remove inaccurate inline comment in `PRNGKeyArray` constructor PiperOrigin-RevId: 748085747 --- jax/_src/prng.py | 1 - tests/random_test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index e588a3894be4..dd91097fcf98 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -166,7 +166,6 @@ def __init__(self, impl, key_data: Any): _check_prng_key_data(impl, key_data) self._impl = impl self._consumed = False # TODO(jakevdp): default to True here? 
- # If key_data is a numpy array, convert it to an uncommitted CPU jax.Array if isinstance(key_data, np.ndarray): aval = core.get_aval(key_data) device = pxla.get_default_device() diff --git a/tests/random_test.py b/tests/random_test.py index d2b465456d38..d75f3a9c5e2e 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -612,7 +612,7 @@ def test_construction(self): def test_numpy_construction(self): key = random.wrap_key_data(np.array([42, 173], dtype=np.uint32), - impl='threefry2x32') + impl='threefry2x32') self.assertIsInstance(key, prng_internal.PRNGKeyArray) self.assertIsInstance(key._base_array, jax.Array) self.assertEqual(key._base_array.device, jax.devices()[0]) From 2beff6a1df03cc5299ae97e9298e829550b9b0ad Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 15 Apr 2025 17:42:55 -0700 Subject: [PATCH 0642/1769] [pallas] Fix case of `Fusible{ElementDtype,TyRules}`. The first letter was inadvertently made lower-case in the previous re-naming CL. PiperOrigin-RevId: 748086763 --- jax/_src/pallas/fuser/fusible_dtype.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 8e6cfefcc9eb..628d253e090a 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -54,7 +54,7 @@ @pack_dtype_p.def_abstract_eval def pack_dtype_abstract_eval(*xs, dtype): - if dtypes.issubdtype(dtype, fusibleElementDType): + if dtypes.issubdtype(dtype, FusibleElementDType): return dtype.abstract_pack(*xs) raise ValueError("Attempted to pack non-fusion dtype: {dtype}") @@ -69,7 +69,7 @@ def pack(*xs, dtype): @unpack_dtype_p.def_abstract_eval def unpack_dtype_abstract_eval(x): - if dtypes.issubdtype(x.dtype, fusibleElementDType): + if dtypes.issubdtype(x.dtype, FusibleElementDType): return x.dtype.abstract_unpack(x) elif isinstance(x.dtype, pallas_core.AbstractMemoryRef): raise NotImplementedError() @@ -80,11 +80,11 @@ def unpack(x): return unpack_dtype_p.bind(x) -class fusibleElementDType(dtypes.extended): +class FusibleElementDType(dtypes.extended): """Scalar dtype for fusible dtypes.""" -class fusibleTyRules: +class FusibleTyRules: allow_conversion: bool = False @@ -92,8 +92,8 @@ class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta): """Base class for fusible extended dtypes.""" _op_registry = {} - _rules = fusibleTyRules - type = fusibleElementDType + _rules = FusibleTyRules + type = FusibleElementDType @abc.abstractmethod def abstract_unpack(self, x) -> Sequence[Any]: From 552005536899a0eb3bcf5b1a1e2f861cf124af9c Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 15 Apr 2025 17:47:28 -0700 Subject: [PATCH 0643/1769] Fix test flakyness by blocking until the data is ready. PiperOrigin-RevId: 748087650 --- tests/array_interoperability_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 15fd6306da6f..adfd34627e76 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -111,6 +111,8 @@ def _check_copy(x: jax.Array, y: jax.Array, expect_copy): x = jax.device_put(np, jax.devices("cpu")[0]) device = jax.devices("gpu")[0] y = jax.device_put(x, device) + # TODO(parkers): Remove after setting 'stream' properly below. 
+ jax.block_until_ready(y) dl_device = y.__dlpack_device__() if use_stream: stream = tuple(y.devices())[0].get_stream_for_external_ready_events() @@ -153,6 +155,8 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): raise unittest.SkipTest("Skipping GPU test case on CPU") device = jax.devices("gpu" if gpu else "cpu")[0] x = jax.device_put(np, device) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(x) y = jax.dlpack.from_dlpack(x) self.assertEqual(y.devices(), {device}) self.assertAllClose(np.astype(x.dtype), y) @@ -206,6 +210,8 @@ def testJaxToTensorFlow(self, shape, dtype): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) x = jnp.array(np) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(x) # TODO(b/171320191): this line works around a missing context initialization # bug in TensorFlow. _ = tf.add(1, 1) @@ -327,6 +333,8 @@ def testJaxToCuPy(self, shape, dtype): rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) y = jnp.array(x) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(y) z = cupy.asarray(y) self.assertEqual(y.__cuda_array_interface__["data"][0], z.__cuda_array_interface__["data"][0]) @@ -362,6 +370,8 @@ def testCaiToJax(self, shape, dtype): device = jax.devices('cuda')[-1] with jax.default_device(device): y = jnp.array(x, dtype=dtype) + # TODO(parkers): Remove after setting 'stream' properly below. + jax.block_until_ready(y) self.assertEqual(y.dtype, dtype) # Using a jax array CAI provider support to construct an object From 770dae72cb1883979b7b1709c5d5bde56232566c Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 16 Apr 2025 01:00:25 -0700 Subject: [PATCH 0644/1769] [Pallas][Mosaic][TPU] Add `disable_bounds_checks` compiler params When we run the program with "--xla_jf_bounds_check=true", we can selectively disable bounds checks for pallas kernels now. PiperOrigin-RevId: 748193719 --- jax/_src/pallas/mosaic/core.py | 20 +++++++++------- .../pallas/mosaic/pallas_call_registration.py | 1 + jax/_src/tpu_custom_call.py | 14 +++++++++++ tests/pallas/tpu_ops_test.py | 24 +++++++++++++++++++ 4 files changed, 50 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 37b6e51892c7..0fe825d44858 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -64,23 +64,23 @@ class TPUCompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. Attributes: - dimension_semantics: A list of dimension semantics for each grid - dimension of the kernel. Either "parallel" for dimensions that can - execute in any order, or "arbitrary" for dimensions that must be - executed sequentially. + dimension_semantics: A list of dimension semantics for each grid dimension + of the kernel. Either "parallel" for dimensions that can execute in any + order, or "arbitrary" for dimensions that must be executed sequentially. allow_input_fusion: A list of booleans indicating whether input fusion is allowed for each argument. - vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note - that this must be used in conjunction with the + vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note that + this must be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes. - collective_id: Indicates which barrier semaphore to use for the kernel. 
- Note that using the same collective_id does not guarantee that - the same barrier semaphore will be allocated between kernels. + collective_id: Indicates which barrier semaphore to use for the kernel. Note + that using the same collective_id does not guarantee that the same barrier + semaphore will be allocated between kernels. internal_scratch_in_bytes: The size of the internal scratch space used by Mosaic. flags: A dictionary of command line flags for the kernel. serialization_format: The serialization format for the kernel body. device_type: The device type to compile for. + disable_bounds_checks: Disable bounds checks in the kernel. """ PLATFORM: ClassVar[str] = "mosaic" dimension_semantics: ( @@ -94,7 +94,9 @@ class TPUCompilerParams(pallas_core.CompilerParams): internal_scratch_in_bytes: int | None = None serialization_format: int = 1 device_type: str | None = None + disable_bounds_checks: bool = False + # Replace is a method, not a field. replace = dataclasses.replace class TPUMemorySpace(enum.Enum): diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 896af0c464c5..824eb7e89716 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -243,6 +243,7 @@ def _maybe_cast_inputs(*args): collective_id=mosaic_params.get("collective_id", None), has_side_effects=mosaic_params.get("has_side_effects", False), output_memory_spaces=output_memory_spaces, + disable_bounds_checks=mosaic_params.get("disable_bounds_checks"), ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index cbec7f873156..32236bb6ae90 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -143,6 +143,7 @@ class CustomCallBackendConfig: serialization_format: int | None internal_scratch_in_bytes: int | None output_memory_spaces: tuple[MemorySpace | None, ...] | None + disable_bounds_checks: bool # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -193,6 +194,9 @@ def to_json(self) -> bytes: color = memory_space.color if memory_space is not None else -1 config.write(str(color).encode("ascii")) config.write(b"]") + if self.disable_bounds_checks: + config.write(b', "disable_bounds_checks": ') + config.write(str(self.disable_bounds_checks).lower().encode("ascii")) config.write(b"}") # End of custom_call_config. if self.device_type is not None: config.write(b', "device_type": ') @@ -573,6 +577,7 @@ def _lower_to_custom_call_config( output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, kernel_name: str | None = None, ir_version: int | None = None, + disable_bounds_checks: bool = False, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -601,6 +606,7 @@ def _lower_to_custom_call_config( needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, output_memory_spaces=output_memory_spaces, + disable_bounds_checks=disable_bounds_checks, ) @@ -620,6 +626,7 @@ def _lowered_to_custom_call_config( needs_layout_passes: bool, device_type: str | None, output_memory_spaces: tuple[MemorySpace | None, ...] 
| None = None, + disable_bounds_checks: bool = False, ): if has_custom_barrier: if collective_id is None: @@ -650,6 +657,7 @@ def _lowered_to_custom_call_config( serialization_format, internal_scratch_in_bytes, output_memory_spaces, + disable_bounds_checks, ) return config @@ -672,6 +680,7 @@ def lower_module_to_custom_call( serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None, device_type: str | None, + disable_bounds_checks: bool = False, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( module, @@ -687,6 +696,7 @@ def lower_module_to_custom_call( output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, ir_version=get_ir_version(ctx), + disable_bounds_checks=disable_bounds_checks, ) return _tpu_custom_call_lowering( ctx, @@ -715,6 +725,7 @@ def as_tpu_kernel( has_side_effects: bool = False, serialization_format: int | None = 1, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + disable_bounds_checks: bool = False, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" device_type = _get_device_type(module) @@ -731,6 +742,7 @@ def as_tpu_kernel( serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, + disable_bounds_checks=disable_bounds_checks, ) return _as_jax_callable( config, @@ -760,6 +772,7 @@ def lowered_as_tpu_kernel( input_output_aliases: tuple[tuple[int, int], ...] = (), serialization_format: int | None = None, internal_scratch_in_bytes: int | None = None, + disable_bounds_checks: bool = False, ) -> Callable[..., Any]: lowered_module_asm = lowered_module.operation.get_asm( binary=True, enable_debug_info=True @@ -778,6 +791,7 @@ def lowered_as_tpu_kernel( has_communication=has_communication, needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, + disable_bounds_checks=disable_bounds_checks, ) return _as_jax_callable( config, diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 2cb0cfff09e8..1fb0bc24701b 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -487,6 +487,30 @@ def kernel(x, out): expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) np.testing.assert_array_equal(output, expected) + # We need to manually run the test with the env variable + # `export LIBTPU_INIT_ARGS="--xla_jf_bounds_check=true"` + def test_disable_bounds_check(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 16): + self.skipTest("Requires libtpu built after 2025-04-16") + if jtu.get_tpu_version() < 4: + self.skipTest("Requires TPUv4+") + src_shape = (8, 128) + tgt_shape = (16, 256) + + def kernel(src, tgt): + tgt[:] = pl.load(src, tuple(pl.ds(0, d) for d in tgt.shape)) + + x = jnp.arange(np.prod(src_shape), dtype=jnp.float32).reshape(src_shape) + run = pl.pallas_call( + kernel, + jax.ShapeDtypeStruct(tgt_shape, jnp.float32), + compiler_params=pltpu.TPUCompilerParams(disable_bounds_checks=True), + ) + output = run(x) + np.testing.assert_array_equal( + output[tuple(slice(0, d) for d in src_shape)], x + ) + @jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsInterpretTest(OpsTest): From 98a9ef449f5afe02ffbe5ffb539d112fd990505c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 16 Apr 2025 10:31:38 +0000 Subject: [PATCH 0645/1769] [Pallas:MGPU] Add a diagram outlining GPU memory spaces --- docs/_static/pallas/gpu/memory_spaces.svg | 96 +++++++++++++++++++++++ docs/pallas/gpu/reference.md | 2 + 2 files changed, 98 insertions(+) 
 create mode 100644 docs/_static/pallas/gpu/memory_spaces.svg

diff --git a/docs/_static/pallas/gpu/memory_spaces.svg b/docs/_static/pallas/gpu/memory_spaces.svg
new file mode 100644
index 000000000000..73dc31a12406
--- /dev/null
+++ b/docs/_static/pallas/gpu/memory_spaces.svg
@@ -0,0 +1,96 @@
+[SVG figure: GPU memory space hierarchy, ordered from "Faster / Smaller Capacity" to "Slower / Larger Capacity".
+ Registers (RMEM): fastest latency & BW, smallest capacity; holds arrays (in Pallas); spills if full!
+ Tensor Memory (TMEM): fastest latency & BW, smallest capacity; explicitly managed; Blackwell specific.
+ Shared Memory (SMEM): fast (close to compute), small capacity (per SM); partitioned into private slices for each CUDA block/cluster.
+ L2 Cache: moderate speed, moderate capacity (~100MBs); shared between SMs, not directly programmable.
+ Global Memory (GMEM): slowest latency & bandwidth, largest capacity (GBs); main GPU memory (HBM/GDDR technology).
+ Full markup omitted.]
diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md
index b3ad1d85d132..a958b3704327 100644
--- a/docs/pallas/gpu/reference.md
+++ b/docs/pallas/gpu/reference.md
@@ -85,6 +85,8 @@ For more information on how warp scheduling and instruction issue works, we reco
 ### Memory spaces
 
 The GPU features a few different memory spaces that can be totally ordered from the largest
 (in terms of capacity) and slowest (in both total bandwidth and latency of a single access)
 to the smallest and fastest.
 
+
A diagram of memory spaces of an NVIDIA GPU
+ The biggest memory space is `plgpu.GMEM`, for _global memory_. In recent data-center grade GPUs this memory space is often measured in tens or even hudreds of gigabytes, but it is also the slowest one. From 19e2315b1f329392f376286d3ab0e8f36531d610 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 16 Apr 2025 06:02:03 -0700 Subject: [PATCH 0646/1769] [pallas] Use `*Params` dataclasses to pass `compiler_params=` to `pl.pallas_call` Dataclasses provide better ergonomics and also improve type safety, making it easier for Pallas maintainers to change the structure of `compiler_params=`. PiperOrigin-RevId: 748266655 --- jax/experimental/pallas/ops/gpu/layer_norm.py | 2 +- jax/experimental/pallas/ops/gpu/rms_norm.py | 6 +- tests/pallas/tpu_pallas_distributed_test.py | 2 +- tests/pallas/tpu_pallas_pipeline_test.py | 62 ++++++++++--------- 4 files changed, 37 insertions(+), 35 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index d37afaf4d9e0..187d74ee1fd9 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -247,7 +247,7 @@ def layer_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index ff224c6dfde7..baeaeb8a57b3 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -228,7 +228,7 @@ def rms_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -264,8 +264,8 @@ def rms_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict( - triton=dict(num_warps=num_warps, num_stages=num_stages) + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps, num_stages=num_stages ), grid=(), out_shape=out_shape, diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index 737ab5137e99..3d4d441d7cd0 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -235,7 +235,7 @@ def body(x): in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=x, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), )(x) device_mesh = mesh_utils.create_device_mesh( diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 535eaaedf058..f29182e56314 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -716,20 +716,20 @@ def _wait_on_prev_dma(): pl.BlockSpec(memory_space=memory_space), pl.BlockSpec(memory_space=memory_space), ], - out_specs=[pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space)], + out_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], grid=(outer_steps, 2), - scratch_shapes=[ - pltpu.VMEM((tm, tn), 
jnp.float32)] + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict(collective_id=0, - # must set scoped vmem flag *larger* than below! e.g.: - # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1006,15 +1006,13 @@ def _loop_epilogue(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1269,15 +1267,13 @@ def _prefetch_accumulator(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. 
flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072
+          vmem_limit_bytes=int(134217728 * 0.9),  # 0.9 * 128MiB
       ),
   )

@@ -1358,7 +1354,9 @@ def mul_kernel(iters_ref, x_ref, y_ref):
           out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
           grid=(num_cores,),
       ),
-      compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=('parallel',)
+      ),
   )
   x = jax.random.uniform(jax.random.key(0), (640, 640))
   np.testing.assert_allclose(func(jnp.array([5]), x), x * 2)

@@ -1392,7 +1390,9 @@ def matmul_kernel(x_ref, y_ref):
       ],
       out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
       grid=(num_cores,),
-      compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=('parallel',)
+      ),
   )
   np.testing.assert_allclose(func(x), x * 2)

@@ -1441,7 +1441,9 @@ def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn):
       ],
       out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
       grid=(num_cores,),
-      compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=('parallel',)
+      ),
   )
   np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5)

From dabe7b7ee99aab553aad1f6f3e4d373472df04f7 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 14 Apr 2025 15:22:19 +0000
Subject: [PATCH 0647/1769] Add basic documentation for Pallas:MGPU layouts
 and transforms

---
 docs/pallas/gpu/reference.md | 114 ++++++++++++++++++++++++++++++++++-
 1 file changed, 112 insertions(+), 2 deletions(-)

diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md
index a958b3704327..d67f80fb0ebe 100644
--- a/docs/pallas/gpu/reference.md
+++ b/docs/pallas/gpu/reference.md
@@ -175,9 +175,119 @@ access pattern:
 Note that even though the number of active blocks hasn't changed, the total
 footprint of the data they access has halved! We get a much higher chance of
 getting L2 hits now.

-## Array layouts and reference transforms
+## Array layouts and memory reference transforms

-TODO
+In Pallas, the data structures you work with (arrays and references) have a
+**logical shape** (e.g., a 128x128 matrix). This
+logical shape must be mapped to a **physical representation** (how the data is
+actually represented in the GPU's memory). The specific mapping depends on where the
+data resides:
+
+1. **Array Layouts:** Arrays are stored in register memory, and we call this mapping
+   a _layout_. Layouts define how the elements of an array are
+   distributed across the registers available to the CUDA lanes that form a Pallas thread.
+2. **Memory Reference Transforms:** For mutable references pointing
+   to `SMEM`, this mapping is called a _transform_.
+   Transforms describe how the logical data structure is arranged within that
+   block of memory.
+
+These concepts are crucial for performance, especially when interacting with
+specialized hardware units like TensorCores or optimizing memory access
+patterns.
+
+> We are working on a mode that will deal with assigning layouts and transforms fully
+  automatically (although with a way to provide hints and more control). The APIs listed
+  below will likely continue to function, but will become optional.
+
+### Memory reference transforms
+
+Transforms are applied when a memory reference is first allocated. Pallas
+primitives that operate on these references will automatically account for their
+associated transforms. 
+
+```
+def body(..., scratch_ref):
+  # Asynchronous copy will reformat the GMEM data to match the SMEM transforms
+  plgpu.copy_gmem_to_smem(..., scratch_ref, barrier)
+  barrier.wait()
+  plgpu.wgmma(..., scratch_ref)  # wgmma only accepts properly transformed refs
+  ...
+```
+
+There are two ways in which references are allocated and each has a way to select
+the desired transforms:
+
+**1. Using `GPUBlockSpec`**
+
+```python
+transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
+f = pl.pallas_call(
+  in_specs=plgpu.GPUBlockSpec(in_block_shape, in_index_map, transforms=transforms),
+  out_specs=plgpu.GPUBlockSpec(out_block_shape, out_index_map, transforms=transforms),
+  ...
+)
+```
+
+**2. Specifying the `transforms` argument on the allocated `SMEM`**
+
+```python
+transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
+f = pl.pallas_call(
+  scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms),
+  ...
+)
+```
+
+The available transforms are:
+* `plgpu.TileTransform(tile_shape)`, which organizes the data into contiguous,
+  non-overlapping tiles of shape `tile_shape`. The data of one tile is always
+  fully linearized (row-major) before another tile begins (tiles are also
+  traversed in row-major order). As an example, applying `TileTransform((8,
+  64))` to a `(128, 128)` reference means the data corresponding to the logical
+  slice `[0:8, 0:64]` will be stored first (row-major), followed by
+  `[0:8, 64:128], [8:16, 0:64], [8:16, 64:128]`, and so on. A different way to achieve
+  this would be to take the input array `x` and traverse
+  `x.reshape(128 // 8, 128 // 64, 8, 64).transpose(0, 2, 1, 3)` in row-major order.
+* `plgpu.SwizzleTransform(swizzle_in_bytes)`, which transforms the data as described in the
+  [PTX docs](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-swizzling-modes) and
+  [CUDA docs](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#the-swizzle-modes).
+  Swizzling is useful because it allows transferring data in MMA-related layouts
+  between registers and shared memory without bank conflicts. The exact details
+  of what the memory looks like after swizzling _are not that important_, since
+  all primitives will account for it automatically. Note that the swizzle amount
+  is specified in bytes (only 128, 64, 32 and 16 are supported), and is usually
+  accompanied by a `TileTransform` (which uses elements in its shape!).
+* `plgpu.TransposeTransform(permutation)`, which permutes the dimensions of the array before it is linearized.
+  This is primarily useful in that it lets you change the layout during the GMEM-SMEM copies (just
+  keep in mind that changing the minormost/last dimension is not supported by the hardware).
+
+
+### Array layouts
+
+There are a few useful layouts we have defined for you so far:
+* `plgpu.Layout.WGMMA`, which is the layout in which the Hopper-generation TensorCore
+  expects the MMA accumulator or 16-bit input operands to be held in registers.
+* `plgpu.Layout.WGMMA_ROW`, which is the layout obtained from the above after reducing
+  it along the rows. Re-broadcasting the rows is free and will produce a value with `WGMMA`
+  layout.
+* `plgpu.Layout.WGMMA_COL`, which is an analogue of the one above, only reduced along
+  columns instead of rows.
+* `plgpu.Layout.WG_STRIDED`, where the value is partitioned equally among the 128
+  CUDA lanes making up a Pallas thread. The consecutive elements (after vectorization)
+  are assigned to the lanes in a round-robin fashion. 
Very simple and effective when + no interaction with TensorCores is needed. +* `plgpu.Layout.WG_SPLAT`, indicating that the value is constant. Each CUDA lane will + hold a single register that contains the value. You normally never have to interact + with this layout, as it is implicitly used when constant values are created and + is always implicitly convertible to other layouts. + +At the moment, in the default mode of operation, array layout propagation happens +only in a forward direction and there is little implicit support for reconciling +layout conflicts: only splat layouts can be implicitly converted into any other +layout. If you e.g. try to add two arrays that have a different layout, the lowering +will complain and fail. There are very limited facilities that let you convert between +layouts, and we usually recommend storing the value to SMEM and reading it back in +the target layout. ## MMA (TensorCore) From 7d66cdd308a9e11f665034a7072f5c192fb63b3f Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 16 Apr 2025 07:20:43 -0700 Subject: [PATCH 0648/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4bc073325136932a634b3f972c5493b68f95a0d4. PiperOrigin-RevId: 748284934 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 23cb610724f2..d0cca161bd2d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0bfaa69a78c4306d6421f7ba78638427d9ce53e8" -XLA_SHA256 = "0605b276868b46ec8007fc9ef597e3ca8b3185eaa38a5cab67d7995ed2f3a3e5" +XLA_COMMIT = "4bc073325136932a634b3f972c5493b68f95a0d4" +XLA_SHA256 = "61c1b1116c94a7155dcf5bfe9a407befb266538f304735413f7f21c1719972ab" def repo(): tf_http_archive( From cd5d48214bf735f746296f96581b5c1cf63d5186 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 15 Apr 2025 17:54:36 -0700 Subject: [PATCH 0649/1769] [Pallas] Refactor pipelining docs. --- docs/conf.py | 2 + docs/pallas/index.rst | 1 + docs/pallas/pipelining.ipynb | 861 +++++++++++++++++++++++++++++++++++ docs/pallas/pipelining.md | 589 ++++++++++++++++++++++++ 4 files changed, 1453 insertions(+) create mode 100644 docs/pallas/pipelining.ipynb create mode 100644 docs/pallas/pipelining.md diff --git a/docs/conf.py b/docs/conf.py index cddb63653a17..153d3b7ab867 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,6 +133,7 @@ def _do_not_evaluate_in_jax( # These are kept in sync using the jupytext pre-commit hook. 'notebooks/*.md', 'pallas/quickstart.md', + 'pallas/pipelining.md', 'pallas/tpu/pipelining.md', 'pallas/tpu/distributed.md', 'pallas/tpu/sparse.md', @@ -228,6 +229,7 @@ def _do_not_evaluate_in_jax( 'notebooks/convolutions.ipynb', # Requires accelerators 'pallas/quickstart.*', + 'pallas/pipelining.*', 'pallas/tpu/pipelining.*', 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 6c1a048298c1..8e1a9816212c 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -22,6 +22,7 @@ See also the :class:`jax.experimental.pallas` module API documentation. 
:maxdepth: 2
 
    quickstart
+   pipelining
    grid_blockspec

diff --git a/docs/pallas/pipelining.ipynb b/docs/pallas/pipelining.ipynb
new file mode 100644
index 000000000000..69029698c0c5
--- /dev/null
+++ b/docs/pallas/pipelining.ipynb
@@ -0,0 +1,861 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "last_runtime": {
+        "build_target": "//experimental/users/justinfu/pallas:colab",
+        "kind": "private"
+      }
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Software Pipelining\n",
+        "\n",
+        "Software pipelining is an important performance optimization technique that overlaps multiple asynchronous operations, even when there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n",
+        "\n",
+        "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see the [TPU](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html) or GPU (coming soon!) pipelining references.\n"
+      ],
+      "metadata": {
+        "id": "C93Xlf0DRW9H"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import jax\n",
+        "from jax import numpy as jnp\n",
+        "from jax.experimental import pallas as pl\n",
+        "import numpy as np"
+      ],
+      "metadata": {
+        "id": "YkOjspo5BKPD"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Memory Hierarchies\n",
+        "\n",
+        "The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that trade off capacity vs. latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:\n",
+        "- **Registers** are the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.\n",
+        "- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.\n",
+        "SRAM on modern ML accelerators typically ranges from 10-100MB (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).\n",
+        "It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register.\n",
+        "- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, the latency to access it is roughly 10x longer than that of SRAM.\n",
+        "- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. 
We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices.\n", + "\n", + "\n", + "\n", + "\n", + "![memory_hierarchy](data:image/svg+xml,%3Csvg%20width%3D%22600%22%20height%3D%22200%22%20viewBox%3D%220%200%20600%20200%22%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%3E%0A%0A%20%20%3Cdefs%3E%0A%20%20%20%20%3ClinearGradient%20id%3D%22grad1%22%20x1%3D%220%25%22%20y1%3D%220%25%22%20x2%3D%22100%25%22%20y2%3D%220%25%22%3E%0A%20%20%20%20%20%20%3Cstop%20offset%3D%220%25%22%20style%3D%22stop-color%3Argb%28255%2C0%2C0%29%3B%22%20%2F%3E%0A%20%20%20%20%20%20%3Cstop%20offset%3D%22100%25%22%20style%3D%22stop-color%3Argb%280%2C0%2C255%29%3B%22%20%2F%3E%0A%20%20%20%20%3C%2FlinearGradient%3E%0A%20%20%3C%2Fdefs%3E%0A%0A%20%20%3Cpath%20d%3D%22M%2050%20100%20L%2070%2090%20L%20530%2090%20L%20550%20100%20L%20530%20110%20L%2070%20110%20Z%22%0A%20%20%20%20%20%20%20%20fill%3D%22url%28%23grad1%29%22%0A%20%20%20%20%20%20%20%20stroke%3D%22black%22%0A%20%20%20%20%20%20%20%20stroke-width%3D%221%22%2F%3E%0A%0A%20%20%3Ctext%20x%3D%22127.5%22%20y%3D%2280%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2214%22%3ERegisters%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22242.5%22%20y%3D%2280%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2214%22%3ESRAM%2FCaches%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22357.5%22%20y%3D%2280%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2214%22%3EDRAM%2FHBM%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22472.5%22%20y%3D%2280%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2214%22%3ENetwork%3C%2Ftext%3E%0A%0A%20%20%3Ctext%20x%3D%22127.5%22%20y%3D%22130%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2212%22%3EFastest%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22242.5%22%20y%3D%22130%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2212%22%3EFast%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22357.5%22%20y%3D%22130%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2212%22%3ESlow%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22472.5%22%20y%3D%22130%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2212%22%3ESlowest%3C%2Ftext%3E%0A%0A%20%20%3Ctext%20x%3D%22127.5%22%20y%3D%22145%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2210%22%3ELowest%20Capacity%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22242.5%22%20y%3D%22145%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2210%22%3ELow%20Capacity%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22357.5%22%20y%3D%22145%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2210%22%3EHigh%20Capacity%3C%2Ftext%3E%0A%20%20%3Ctext%20x%3D%22472.5%22%20y%3D%22145%22%20text-anchor%3D%22middle%22%20font-family%3D%22sans-serif%22%20font-size%3D%2210%22%3EHighest%20Capacity%3C%2Ftext%3E%0A%0A%3C%2Fsvg%3E)\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "shnVghWUSvpx" + } + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "In order to perform computation on values X and Y that live in HBM, we need to:\n", + "\n", + "1. Copy the values x and y into SRAM.\n", + "2. Load the values from SRAM into registers.\n", + "3. Execute the computation and store the result into registers.\n", + "4. 
Store the values in the output registers into SRAM.\n",
+        "5. Copy the output values in SRAM back to HBM.\n",
+        "\n",
+        "Let’s implement a Pallas function that does just that!"
+      ],
+      "metadata": {
+        "id": "WvW6Lo7d2jfb"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Note: This is a TPU example.\n",
+        "\n",
+        "def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):\n",
+        "  # Load x and y from SRAM into registers\n",
+        "  x_regs = x_sram_ref[:, :]\n",
+        "  y_regs = y_sram_ref[:, :]\n",
+        "  # Execute a vectorized add\n",
+        "  z_regs = x_regs + y_regs\n",
+        "  # Store the output values in registers back into SRAM\n",
+        "  z_sram_ref[:, :] = z_regs\n",
+        "\n",
+        "\n",
+        "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n",
+        "  # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.\n",
+        "  # It will then copy `x` and `y` from HBM into SRAM.\n",
+        "  z = pl.pallas_call(\n",
+        "      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
+        "  )(x, y)\n",
+        "  # pallas_call will also copy the output from SRAM back into HBM.\n",
+        "  return z\n",
+        "\n",
+        "\n",
+        "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
+        "add_matrices(x, y)"
+      ],
+      "metadata": {
+        "id": "IrPhDFnT3Nvw",
+        "executionInfo": {
+          "status": "ok",
+          "timestamp": 1744764235906,
+          "user_tz": 420,
+          "elapsed": 108,
+          "user": {
+            "displayName": "Justin Fu",
+            "userId": "17543197034567316452"
+          }
+        },
+        "outputId": "8bc03872-fd9f-4610-9d53-d4b46be560f4"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "Array([[2., 2., 2., ..., 2., 2., 2.],\n",
+              "       [2., 2., 2., ..., 2., 2., 2.],\n",
+              "       [2., 2., 2., ..., 2., 2., 2.],\n",
+              "       ...,\n",
+              "       [2., 2., 2., ..., 2., 2., 2.],\n",
+              "       [2., 2., 2., ..., 2., 2., 2.],\n",
+              "       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 8
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n",
+        "\n",
+        "`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from an SRAM `Ref` produces a value that lives in registers. Values in registers behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.\n",
+        "\n",
+        "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_sram_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_sram_ref` to HBM, resulting in an output `jax.Array`.\n",
+        "\n",
+        "Pallas exposes access to lower level memory spaces like SRAM, but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:\n",
+        "\n",
+        "- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n",
+        "\n",
+        "- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. 
The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.\n",
+        "\n",
+        "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "gGjtwv9u3UNK"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Pipelining Basics\n",
+        "\n",
+        "\n",
+        "How can we take advantage of the strengths of each type of memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.\n",
+        "\n",
+        "The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages):\n",
+        "\n",
+        "1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.\n",
+        "2. **compute**: Load `X` into registers, compute a result, and store it in SRAM `Y`.\n",
+        "3. **copy_out**: Copy result `Y` back into HBM `A[i]`.\n",
+        "\n",
+        "Note that there is a data dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`, as the sketch below illustrates.\n",
+        "\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "0Ebs2pCDgsEW"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
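+        "In pseudocode, the resulting overlapped schedule looks roughly like the following. This is a minimal sketch: `copy_in`, `compute`, and `copy_out` are hypothetical placeholders for the asynchronous copies and compute that Pallas manages for you, not real Pallas APIs.\n",
+        "\n",
+        "```python\n",
+        "# Schematic 3-stage pipeline; the helpers are placeholders, not Pallas APIs.\n",
+        "def copy_in(i): ...   # start HBM -> SRAM copy of block A[i]\n",
+        "def compute(i): ...   # load block i into registers, compute, store to SRAM\n",
+        "def copy_out(i): ...  # start SRAM -> HBM copy of result block i\n",
+        "\n",
+        "def pipelined(n_blocks: int):\n",
+        "  copy_in(0)                # warm up: fetch the first block\n",
+        "  for i in range(n_blocks):\n",
+        "    if i + 1 < n_blocks:\n",
+        "      copy_in(i + 1)        # prefetch block i+1 while block i is processed\n",
+        "    compute(i)              # in practice: wait for block i's copy, then compute\n",
+        "    copy_out(i)             # write back the finished block\n",
+        "```\n",
+        "\n",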
"![pipeline_fig](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABEsAAACqCAYAAABVj8FPAACAAElEQVR4Xuy9d5RUxRb2/d5LzlFyjgICkrMkyUkBSRIFJYggICAZAZGccw6CZJCcc04jIijmnG6+lz++9X7fWvXNr9pqT9fp6emBPj06s/daz6JP1a465/QMVXue2uH//B8RERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERjyXmq38VifnmYS+BIBgefP5zl4c3b/YSCILhP3dudH3w60e9BIKg+NtHXew9R0RERERERERERORPITHf/O+FmG//pwSCYPj0sx9vxf5RrASCYPjXnVufPvj1rhIIguPe32LXkTMCQVx4eOvWidi15IxAEAyf/fLRvk9/vXtGIAiGz3/9qIT9d42IiIhIREXIEkEoCFkiCAUhSwQh8YsmS1zrikBgELuO/F97XREIDB78cvdH17oiEPyGT3+997T9d42IiIhIREXIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIkkfHtc/+pg5dvqcuffyzq+9xcfzGp+r9cx+oKw9+cfVFE0KW/PHxn+vX1Z09e9S9/ftdfV5DyJI/F87cPq6OXz2o7v0Q4+rzBEKWCOLBQyFLBCEgZIkgFIQsERER8VyELHl0rN5+WMV+heqtuStcfY+KXSevqyfLVdTzgvQZMqhXR0506UULQpZEDpveeUcdX73a1f44WDt1qsqZLZv/96Vo/vzqwLJlLj2vIGRJ4uHSR2fU62Nec7UHw+Z961XBwgX8vyfZc2RT7yya5tKLOIQseSwMHD5OzVm5xdX+ONh3NkYNH/+2GjBsrJo0e5m6eP9Hl0408VDIkojhH5cvu9oeFwdj95PxAwaosa+8ovYtXqz+e/26S8dLCFmSeLj/4wfq/k93XO3BsOfEdjVi/OvqtZGD1ILVc9Td72+5dLyAkCUiIiKei5Alj45tRy6pqrXqqaWb9rj6HgUX7v2gcuXJp/LkL6gN5PeOXFTPde6h/7h5c+pcl340IGRJZPD+kiX657h11ixX36MCI/avf/2ralK7tjq1bp0mSZ5+8kmVLk0adXvXLpe+FxCyJPFQsXIFVbpsKVe7jRPXD6t06dOpJ8uVVut2rlLbDm1WDZvWV3/5y1/U6veWu/QjCiFLHgt/TZFCtXy+s6v9UTFm2jyVImVKlTFTZlWkeCmVJm06lTNXHrXl4DmXbrTwUMiSx8b/btxQb/TpowZ26eLqe1T86+pV1bZRI71v5X3iCVWiUCH9uU7lyurn8+dd+l5ByJLEwZ6TO/T+cvn+OVefjX6DX9K/G1mzZVXFSxWLXWNSqKIliqgzt4+5dCMNIUtEREQ8FyFL/jh4bfRkveGs3Xk0oL16nQYqW46c6saX/3SN8RpClkQGuxcujDhZUqNCBZUnZ07168WL/ravjh9XGdKnV11atnTpewEhSxIP5SqWDYss6fhie5UyVUp1+tbvhithOEWKF1blKpRx6UcUQpY8FqbOX6nW7jrman8U7DxxTZMvtRs08XuTHLz0kSboCxQuqm599W/XmGjgoZAljw08Sthf+nXs6Op7VEwdMsTnOTt4sL/t3ZkzNcnav1Mnl75XELIkcfDKkL765x8fWbJ882Ktxz7z0fe3ddvu49s0QV+rXg2XfqQhZImIiIjnkphkyeVPflETZy5RbV94UbVq30UNHTNVnbv7XYDO7W/+q70sOnbvq3V6vDJEe3Q4dY5df6Bdig9cvKsWrN2unuvSM1a3q5o8Z7m6/sU//HrTFqzSbs03v/yX61lmLFmvhoyZ4moPBfKVcF/zPOc/+l5fE0qDV0iX3v31c/QZNFwbpfZ4G+UrV1N58hVwtRPmE/ujUiu27nf1eY1okyXnN21Sg7t103/svxJrkB1ZudKlc//AAe0SjE6Ptm3V0gkTXO7Hq6dMUWti8cOZM2rCwIGqa6zugM6d1a2dO3U/nhev/nafMS+/rL4/fdo/9qP9+/X8Dw4fVltmzVLd27RRL7ZurefktM15n4mDBmkD0n7GxePGqXmjRunPR1etUp2aN9c/w/ZNmujnceriFcKz8Sz8y7U9n42vT5zQRmsw47hdo0YqY/r0rmf1AtEiSzjl6jOwp2rbsbV68aWuauPuNS6dszEntAswOh26Pqemzp3kcgWes2yGDj+5dO+sGjjsFdXuhTZ63kMX9vl1yOvBPMHucfHuad234+gWV198uPPtTTVlziT9bDzjkFGvqvN3TgXobNm/Uc9/9ZPzAe3HrhzU7Qcv7NXXfM6VJ5fK+UQO/XnrgY2u+xkQclO9TjVXO+/P7+Spm0ddfRGDkCV/GLw8dLT+ee85fTOg/ZXX39Tt249edo2JBh4KWfLY8IIsqVi6tCpWsKCrHe/FArlzu9q9gpAliYNwyZI2HVqq1KlTq9tfXg1s79hK2ygxX11zjYkkhCwRERHxXBKLLIE8yFegkHYJrlmvkarbqJlKmTKVDkE5eu0TrXPt87+rOg2b6gW7ROmyqlHzNtplmNADYq7NXJveP6N1atRtqN2Km7XpoOo3aalP0cpVrKIu3v9J642Y+I7WW7RhV8CzkKg1U+Ysepz9nKFg5yzhnbhu2rq9fq+SZZ5S5StV1W24PZOw1Z7DCfKT8D3Y7Zv3n9VzQCbZfV4jmmTJiN699XsWzpdPkwpFYv/V36/jZGv7vHkqfdq0mgxo9cwz2ruCDbls8eLqi2PH/HoNq1fXRl3RAgV0Ho/alSqp1KlS6XErJ09WGdKlU1XKllXlSpTQ90AXV2bGHl7hI6eYA72OTZuq5nXr6vvUevpp9XcHMZM21kjgWe13Ye7SRYvqz6veest/nzLFiqlGNWr49Yb36qXbeU7IkpKFC+vrcf37u+Z0wjzjgjFjXH1v9uun+wwx5CWiQZYMGt5fv0++AvlUi3bN/Lk3hr452K+zcstSfZKVPkN61ahZA1W5eiW9TpQoXVxd+PB3QqJeo7raPThv/jyqzFNPqvZd2qn8BfNpt+E5y2dqHU7HsmXPpipULu96luHjhurfg6OXD7j6QuHcBye1JwfPVKVGZf2M5nnX7lgZMD/vdvji+wHjeT/aF66Zq69r1I393cyYQY+HCJm7wvfsNshrwrge/bq5+hhD3+L18119EUMikiX7z99Rzdu9EPs9ZdLv+UTuvGrQiPHq1tf/8euQnBsCPkfOXFonXez6wPptk9vsP0Pe
fEv1HzZG7xXoFiv5pPb8MDoQ7uxp89dscz3L6ClzdJ89b3zA46PbS4P814R9vvTqCPX2wjWqQKEi+jkyZMyoXujRT1399FfXeBunYr50tfXs7/udI5eJ3RcNPIwSWUKujbmjRum9gvdlP2hcs6a6sWNHgN7+pUtVzYoVVYrY/6vosY4vGT8+QId1l/2JxNrMgV6q2D2fNfyXCxe0J6FZy7NmzhywhwHW+5nDh+s9jz0JvadKllQbp08P0OvZtq2qUKqU612mDRmi90n2vIvvvqsK5c2r52Au2s9u3Kj12NO4D/ejP03sftWiXr2wwzTxVLTbShUpot/NbvcKXpMlH/98R02aOU576fEd4YVX+5ma6sC5PQF6G3av0Wt3ihS+34tiJYtq8tupA0HPfkL+jkpVfbnnWKMhDsw+9MkvH+oQyppBvC4+/O6m3rN69e/h6osPhLy0bt9S34/7QpL3fKV7AImx/+xu/XwzFgfmq+JQgXYIEq4h8zNlzqjnYa/k+e37ORGMUGnSsrH+Lnknuy+SELJERETEc0kssgSCJHWatGrj3lP+NsJPIDjwNOG67+A39GI9fMJ0vw4GIYQCf7Cs231ctxmyJG269Nqjw+jOXrFZt/cZOExfYyhCyDRr2zHgWWYt36T1Epp7JC6yJHXqNAFeIHi00I5Rbs9hgNGOTrtO3V19Zt5Q471CtMgSPEh4xxeaNVP/vnZNt+EZAWGRMvZ34uODB/0hJk8WK6a+dBhxexct0joQGqaNccz38gsv+EkQPEBo4w9WEtUZXQxS2s9t2qSvDRGRKUMGddNhSEOy0E6yO9MWDlkCgoXh7Jg3T7f169DBnzSPZ8XjhXaew57XYEOsUW3PZzD7Dd//m2BeOZGG12QJHhO8S/O2Tf0VXEg6V7dhHW204hWBFwbEAcYr5IAZS34OjDV0TRtkCfNhDGIk0/bB19dV+UpPaSPz2oMLug1yAT08OpzPQzw2RrD9nPEBd+RUqVOpTXvW+tswngsXK6yNUnPfcMkSEE4YzvtndulxwRLBrt+1WvdNnjXe1RcxJBJZwpqZJWs2lTV7DvX62Gl6L2jxXCf9vpAD6EDGP/V0Fb3ndO0zMHadXq1DIRlH2KOT2IBwYS4IFwiSpZv3xv4uNdPzjX17vtY5++G3ek9r0LSV63kgW8pWqOxqjw92zpKCRYpp0oXn6TVgqJo0a6k+JOA58GS0x4cCHpbsTRBEvIvdHy08jBJZwl6g/+83aqT3glkjRqhc2bOr7Fmy+Il21lVsC4iLBW++qdf8Z2vV0uMgNsxcU157Tevlz5VL71nLJ07UpDp6DWL3nmyZM2vvxIVjx2qCnHZIGDMeIob7FsyTR62bNk3tid3DzJ4FuW70IDaYy36XkS/5ckTg/YiXIWQO1/WrVdPPbPbHvu3b63Y8G9kr5o8erb1FsmTKpD7Ys8c1byh8e+qUGtS1q54vGEnvFbwmS3q+/KJ+p8bNG6pFa+eqcW+/qXLmyqkyZ83s9/yjHbuh5JMl9Ho5c8nb6pnGvr2k76De/rlGT/btveTugIiHYJk4Y6zKmCmjKly0kN/L4qVBvkMS1mfns5AYlfaE5pKCKMHLMHOWTHqtn73sHdV7QE99CMA+gVcjeruOvafnn/DO2IDxEBr696RHR30Ngd+gyTO6bfz0MQGEfnxgrlGTR+jvi+/W7o80hCwRERHxXBKDLDl5+wu9kHbo1sfVN3LyTG3A8TnHE7m1d4atc+Tqx3o8IS5cG7IEg9fWtUNbGjZrrb1PjLcJeKZxc20EJzRmOy6ypHHLdgF6GKWQNI1btHXNYXD2zjd6LOFGdh9lhOlLqDEcCUSLLOnz/PP6Z2qfZHFqNmnQIB16Y0iAnfPnu8YTJoPxCqnCNYYn1xh4RoekdIyvW7lywFhDPGybO1dfG7IEg9S+Dx4inKyZ68chS5rVqaPJH2fOEfC3S5e09wzGtz2vgSFumNfuw0Cnz0kIeQWvyZKuvTvrnyPhL872fad3qmFjh2iyhFNB3nfpxoWu8cRR02dCTSBLIFDsMBeMU/SmL5yqr/ee2qGvB43o79cxhqZ9mhgfyBXCOGOIOoFXB30Y1FxHmizZeXSrHocRb/e9+/4G3Td26ihXX8SQSGTJ81176fXEDtmEWIDQYP0naTbvbyfPhnDHM9C5XkNOsIaT98O0sV/guQi5YvJJ4ZWSKlVqvZ4bva2HLwS9TzgIRpbw/+HdA2f9bewvefMXUrnz5neNjwsQPMzDcxUtUVqdjvnKpRMtPIwCWXJ33z6fzWGtqWath9gglOWJ7NlV8UKFtHeI0YHAxouR8cZbD7KEcRDdRo8S7iRCpZ2E3qb98lbf/0En2QJZwt7x4d69/jYOB/B6YQ5DnodDlnAdLAwH7xLaCGd1joVIwdulZew72fPGBfYq5gKEvtr9XsJLsgSSgZ9rszZNA9oJh+Rd8WrE0xDypFCRggFeGniI4D3B/6OD533hkYYsITTFOR97E+1jpozU14RTcu0kWgAEBeGV4VagMSCclPdgvXe2T5vnsxNGThqur8MlS0C4YThOGBII4PHotVcJELJERETEc0kMsmTltoN6MeVUzO4zgFBBB/diuw/kL1jYT6QYsgQPEVuvS58Buo9TP67JacK1ITjwNsEwNt4nCUFcZAk5SmxdXLfrP9vC1W5w5cEveiz5Vuy+w1fu6b7u/Qa7+rxGtMiS6uXL63AZu92JXu3a6e/BmV/EwJysGSIFsgSD0KmDAYoO+Uuc7e/N8Z3mbJ09W18bAzoYEYEXCsaRCcV5HLKE5Kyc8JETxUbmjBkDxtvAXZv5DMHjBK7m9IXyTIkUvCZLcHsm/MZud6JLL9+J8ZWPAwkQQH4S+pZtWqSvIUs4HbT1GIsebsumjTAdQn4wirkmV0ratGnUzc+vuMaHwvLYezO37foMMETpM0ZqpMkS3K4ZR0lHu49TT/oMUeMJEoksgfx+umpNVzv5rU7e+lx/rvVMYx366MxrZUAfpLrpgyyhzdaDdOA7XLPjiL7GO5Fr420CCKPB29DsQQlBMLKE8BtbD28WPETs9rjA806Zt0LnKyGMB2+VI1fuu/SigYdRIEvMmnhs1SpXH2E4kA3Gu/HtoUNdOuSdog+ShGtDlji9RQCVYgiFMd6M4MezvjBaJ5EBWdK2YUPXfQivQffC5s36+nHIEhPW6vSONCBciDAkZ0hpKFCiHoLe5N5q06BB1EoIe0mWTJ49Qb9PsPxUEBr8sW+8G4Otoe8d9HklG889Q5ZsPxKY04o9JEfO7AEJTwnzzJ03l9/DkTxaeIJQWca+TygwNx4lVWtWdvVBumTJlkVVqva0vvaaLFnx7hK9zxFGZLxabn2RsP0yoRCyRERExHNJDLJk3mrfgk2Mt91nQLJWdHA1tvtAySfLacORz4YswTXa1us98HXdhzcK15zCEZ9OhRmu35g0Q/fbSe/CQVxkCQaorZs5S9aQZAmJbDHOba8UwEkn8w4eNcnV5zWiRZYQWoPrs93uhHFzDpa41HiHbJ4xQ19DlnBC59QxZAlkhLMdwoF
2myw5sWaN6z6c0tFnCJu4yJLKYZAleJVgCNerUiUonmvc2DWvAWWCmW/FpEmuPsKE6Lv63nuuvkjDa7IEMoAYbrvdCUJqeF/jauwE7tP0mZwekCXGcHSCUBz0Xujewd82bpovISYGscljQky4PTY+mNwgi9bNc/V98M0N3WdiwuMiSzBCaU8oWYIbOeOCGeDmu5m/ararL2JIBLIEggNC04RzxoXCxUrofcRuBxAcfDdm34AsceYOMVj27j6tB/HANd4meHhUrFJDX+Nxkj3HE6pJq+ddY8NBMLIEb0lbj/nxmLHbw4E5vIhkieKE4GEUyJLXe/bU7/jZkSOuPgPjrRfMcxFvDP3/6DcywpAlV7ZuDdAjDCZfrlwBbb/85tFokyXBPBfZH9A1ScMfhyx5/tlndVso3H3/fdfc8cGEibLn2n1ewEuypP9QX34vknrbfQYkBUcnmOci4ZP0GZLBkCXBCHXylEC+m2vCefT3+BtRg4cf185k4+HgxqeX9Djn3uXE01UqaG8VPntNljiBBybjX31jgKsvkhCyRERExHNJDLKE3CSxt1aj35rt6sO9eMPek+rCvR+0wWtCbWwQP26MRkOWEHNu6zEeg5P4dNNG7g/a8CohASuwx4WDSJIloMxTTwc13mcu851sUBXI7vMa0SJLSJyKp4XdTogKxuvnR4/qSjF8D58cOuTSe2fYMN1nvCkiQZaYayeId0+XJo3/5JDPtNl6JMCLjywhKV/5IMn7wsGnh32/e8EMbk4NyeFih/d4Aa/JElx5SVRnt0OMQCCQ94NTLL6LYFVdiD+nj/wlXEOWkGjV1jt+7ZDWc4bdEKpDnhHirldt9ZFTCYndNjC5QYJ5cJj7Go8WTi+5Nm7dBoZwSShZAsiJgru43U4YE3PacfMRRSKQJeS14r0ehywxHokm2ThkSTDPviUbfZ47JFw1bSRgZe/S1dnW+cK5FltJxcNFMLLEEDFOPA5ZAiB4gnmsRAMPo0CWUGGNn0M4ZAm5pOw+9h/6bLLETg6bELJkdN++rvtA9qNrPAYhS2wPSTDsN/InFFlCVTR+D5dNnKjfLRjwerHnjg+XtmzR9xrYpYurzwt4SZb0e7WPfpdwyJIlGxa4+oxHok2W2JVhwFNPl9PJxc31jc8uqzSx9gO5TbguW75M0KTi8SE+ssR4sPDZkCXkIXHqmMOCSJIl7NH8/tWpX8vVF0kIWSIiIuK5JAZZQjJTvCiCVX6BuDA5RkqXq6BP5Sgx7NQxJEWPl1/T14YsofKBUw+jGS8SKgg42ynti/6gNyboOM9x0xe4niMcRJosIXyH58HIdraTmPBR3bgfF9EiS0xVGPukjuo3tFO21xiSJOazx1O9gEz/xviLBFlCHhWn3k/nzumwGWci2RxZs6pqTz0VoEfeFciKYGSJs8xw5xYtAuLgneOJm4/PGCW+nRwqTpdv8p2QtPCZqlVd+l7Aa7LElLeluoCz3ZAXhNlgxPI5WF6OarWqaMIDg5JryBIMuONXAxO3kpBO/3ze3xDQ3qx1Ex0G1Llnx9h1KbffZToh4JQRozhY9YM3JvhIPhL7cc2JH9drtgUm+Ove15dY0UmWYHyXKlPSNacNkuPipm0b8Jx0cuL4KO8UNhKBLAHsG5Wq1XK1k3ibJNqs1bXrP6uTgjuJdIPqtetr8sGUmYcsIbeVrUfycX4uVCwzbVSV4Xds2Lhpqk3HF3UFt4TmwzKIJFkCAUToqrMakAEJbYuXKuNqjwYeRoEsYc/g5xQsDIckqJR5pw8dE2rjxMHfPPmmDhmiryNBlgTLSWW8Ak3oTOv69TUhb+uRowu9UGTJkO7dgz4jYM8B5Fmx+wBEO96RwfYgStszrzMHi5fwkiyBwOZdgoXhdOvTRZMK2w75CgVALts6m/etD+gzZAk5r5x6JCWHtG7YtH5AO56KhMmQSJxxj5JsmzAcZ6iNE9yXtZ/qcFzvOekjb00OEwMOGmh/FLKkUfOGsbZtPVc7+x7jSZxr90USQpaIiIh4LolBlgBTrpDQEkgNCJSXh/jcEE1ZYFPNBlLFuEOv3XVMJ7MjqZ459TNkCX94kiAWN2y8Rig1THuwUz2qIEDYYGCeu/udqz8cRJosIWacuHOejbkw1Ck5yXslRnJXEC2yhMSsVJ/hj//r27frttPr1+uShwVy59YJ9wi/obIA4SsQJxh6P5w5o/r/FhozyuFlEQmyBIN20dixujoPJ4tNatfWPwtnlRlTUpjqACSQPb9pk5+4cZIlx1f7vAswcol/pw1iiLhxEsaaSjwf7d+vQ3DQDeYO7gQhOOjhcYNx+93p0zokiOeJRnJX4DVZQsldjMySpUv4SzmSxI4qMvyhT3UBDEI8LCjFS0gJ1xhqeGvw/bz82u8hKKYaDl4ZECYYmpAtVMKhyo3JT2Jgwl/IVYLLtv184cKcYPYZ2FPHcPOMlCpOly6tKlehjA7zQc+c/EFkEIoDyUM4ENV+aHeSJaZ8MM9IEln7ngbEz/N7i1HLSSgu14TlMJ+nyV1BIpElz3XuoYkGZ0JWYBJ8s+abfCOj3poVoLP10Hk9Fl3TBlnCXnHgwof+NvatIsVL6f2IMErnHBA1eD5CQsQVShoOIkmWsC/xvnb46zuLfX/8Ps5zPg4eRoEsocSv3ketfFWsu7w7Xh7/vHJF5c6RQ1eLgeAwOpDRrPOMN3tTJMgS9jFDdgD2D8oROxOIQ9jb92GdJxEt7WY8exTX5NQyeob8sd+ZpOcQ6uyrcZEloGLp0jr/Cnufs53qP8zLnmaP8QJekiWQBFRVa/V8i4B2SAX2UQgDqrCx1xQolN+V4BUigN8L451nyBLIded80xdM0e0mgbgBnoq0U0oeQt2Q+gkF3ik8h50r5a3ZE/X8Jt8K+ynXLdo1C9AbOnqwbneSJeaggjH2/ZwghJTvyk4uO/RN35wJTYieUAhZIiIi4rkkFllCLDfVcExGfsBiT0Ub58kXFQRIwmd0QLGSTwaUCDZkCcZtxkyZ/XokrovLa8QYys3adHD1hYtIkyUAt22SwTrfl3EY5rZuNBAtsgRAIlCK0fnukCXO3BuE4NSoUCFAB2KCcBSn4RcJsgQ3Zgxacx+8SuwcIZzOUf7R+TwYp91atQogSyB78BYxOsYdfNeCBdpAd47nJHHOyJEB94kLxOLz/8aMxaNl+uuvu/S8gtdkCaAyQd78gd8xhqvz9A6DjtMzp07q2N8LiBKICaMHWUJJSIgDpy4eKM6ywwaMfSL3E3qdOnr5gKs/XDAPCWIxzJ335Xns+1JRwLkukhiQUpB8dpIlJiYckOTWvqcTGM18H857U2nIJocijkQiSyDS8S7Bs5DSweQUMeQ55YHRgVSHeICQgIxGB/KetZpxzoSnkCUkAidcBUJ+wozFuhIOY8nBZd+f5OXme8aT0e4PF5EkSwhtZTzeNIQK8b6EpFLlB6+S8x997xoTDTyMAlkChvbwheuRy4MQFJKp4hnI+g0BgQ6efxAZkPaUDl46YYJqXLOmHjfm5Zf9c0WCLGGthhwh+SwVzCqUKqXbnJV09i
7yJYdGb8bw4eqtwYP1Z/YW2p1kS85s2TS5wf5mPDTZh9AjnIdwHEh9yBj2jC1Bys47AdkCmc/9qETHeEgj5sMbx9b3Cl6SJcB4UDRt9axOToqXIWsunoSmahqEOms3pD3V196e/5afeB/w+sv+uQxZQnLT5zq11fMxP96NkPEQL85749Vn9jabwEgI2P+o2MPBAsQH9+WwgOcgDOfu97f8ulzr/wed2+n3gOzgfSl37CRLjJcjiWOHjHrVdU+DkzeO6LF4t5Bzi3t36Pqc3sNIaGu/c6QhZImIiIjnklhkicGhy/d0To7pi9bqqi92P7h4/0e1cP1OnZPEWTLRwJAlECucGOKRgn4o449ksIyheoHdFy6uffY3/fx4xXCNJwjXwe6L4Y23i90eDLwv+UkwZik7afdHE9EkSwDuxJAVGLO4PnNiZusAjEFCcyAbnOWBDUjI5zQkDe7t36++OXkyoA2vDNpNjg9DlvAM6OLFwn04+bPnAzwzVRHWxD7P7V27dBvjyCti65E0lhM5TjFNOxUJIIq4H8QN3jL2PUIBbxTuTcK9YO/sJaJBlgCMvc1712lDjBwgcRlgVH+ZtXS69rawywMDDFwMQz5zEgYJEV/ODoxLDF27/VFAwlUSvUJ6hCJfMEDxPKEKDrHfkC14j5AQ1qlHzhO+lxPXD7vmsEG1BTxvZi55O+S9I4pEIksAazFkON56/H8uVLS4Gj9jUYAOIZ54VOTKk0/rQLC079Zbl2t36kGW1G/SUhPheDVSHrhy9do6rMe+L6A0MeQFhIrdlxAQDvTqyIn+a4iTzr1ecekNGDY2aLUeG+xBhCHxPrwvni8QRRAptm608DBKZAmYP3q0zifFu2dIl0574tk5sA4tX66r2kBcQMRTqY3qY06d9W+/rT0A2Tec7eRGsRNzExqJLmSHaYMsoYwx4adUPiNROKGTwcKElk+cqL1deGaIGIgaCBXm/PrECb8e5AeJ0gvny+dPvsoBwczY+xpyJVXKlPqwYc+iRa77BAMhNzwX4xjPPBAutp6X8JosAVTFKVbS9x3hoUgpYTuPCessObQoPQ/5gfcfa7RTx5AlrLGGvIdsp0SwHQZpYJLM2qGXCcXZmBOaoMmYKaOer3DRQjoHl534nD2I98OThXep26C29mKENIEoMnp4QKKXv2A+nSTWvp8T7EPPtmikPSX172mBvDqxq/GY9BJCloiIiHguiU2WRAJOssTuiwucxOE+/aix5MkF0SZL/ghwkiV2nyAQ0SJLIgUnWRIOCP3hhAw3artPEAYSkSyJJCAXnGE58YFwHX5vxkyb5+oTBOJhFMmSPwogS8hZZbcL3IgGWRIpGLIkPgLeCaq5kYCVMr92nyB+CFkiIiLiuSQnsgSPD04RKc8bTJ9QF7w5wkFcJ4pJDUKWuPsFvyOpkiVUQCCvB67YJHglz4ezn6R/eLmEg2AeLskGyZAsIYyUSjyEgSamx8afBQ+FLBGEQFImS6h6hofH4JEDXX2C8CBkiYiIiOeSFMiSbUcuxf5BU0hNnb/S1ecEYTLkPyEOvOOLL7mqAuCizDzhIL6ylEkFyZEswfUYV+b4YroFfz6yhDhtkrva7TaIvY5dHrULMklX7X7i0OkLB+EazkkSyYwsofJalmy+5JtDxkxx9bNvoBMfqtdp4BqbVPFQyBJBCCRFsoQcKOwN5EHBq4Qyws5+E/ITDiD27fmTE4QsERER8VySAlmSUNiVCwRxIzmSJYLw8WcjSxICz5OfJgckEbJkyJtvqRlLN7jabVC9rHGLtjqBuN0HCMshv0h8GDRivGtsUsXDZEiWjOvfX20VMj4s/JnIEjwOXxs5SOeHsvucIL8I+UC69uoUNN/UwQt79TzhgATo9vjkBCFLREREPJfkSJYIwoeQJYJQSMpkiSACSCJkicA7PEyGZIkgfPyZyBJB9CFkiYiIiOcS8/V/WsV8+/ArgSAYPv3s5+P/irnznUAQDH+/e/f6Fz/e/U4gCIbPfvzovv3HsUDgxL9v3/l//3X7AyUQBMPnP3z4yxc/3VUCQTB8/tO9Fh//FFNEIAiG7767nM7+u1dEREQkonL7u//Vt41bgcDg3pd//9Q+CRQIDP4dE/P3uz/8pASCuBDz7X//P3tdEQgM/nP71kN7XREIDD775a54twrixMe/fNTM/rtGREREJKIiZIkgFIQsEYTCf2Nu/c02XgQCJ2LXkf/HXlcEAoP/3br1g72uCAQGn/189569pggEBkKWiIiIeC5ClghCQcgSQSgIWSKIDzFClghCQMgSQSgIWSIIBSFLREREPBchSwShIGSJIBSELBHEhxghSwQhIGSJIBSELBGEgpAlIiIinouQJYJQELJEEApClgjiQ4yQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiCAUhSwTxIUbIEkEICFkiCAUhSwShIGSJiIiI5yJkiSAUhCwRhIKQJYL4ECNkiSAEhCwRhIKQJYJQELJERETEcxGy5NHx/rkPVL4ChdS46QtcfZHArpPXVYFCRdSBi3ddfdGCkCWRw0/nzqlbO3e62iOFOSNHqmeqVnW1ewkhSxIXO45ucbXFh9XvLVf5C+ZT93+64+rzAjFCljwyyleuplo+39nVHimwd7GHHb5yz9UXLQhZEhn878YNdW7TJld7pMD8zz/7rGpYvbqrz0sIWZJ4OHnjiLry8XlXe3x4ZUhfvcd88suHrr5IQ8gSERERz0XIkkcHZEbsV6iGjJni6ntcHL/xqSpYpJief++Z267+aEHIksjg/oEDKk/OnGrBm2+6+iKBPYsWqTSpU6sShQq5+ryEkCWJBwzSnLlyutpD4cC5PSpb9mx6Xbn/4weufi8QI2TJIyNXnnyqep0GrvZIYM/pmypN2nT6d2H/+Tuu/mhByJLIoG7lyqptw4au9kgBMp7fldJFi7r6vISQJYmDVVuXqbRp06j9Z3e7+kJh3c5V6q9//av+XRGyREREJEmIkCV/PKzeflgbybE/HiFLkgg48eNnGWmy5D/Xr6u3hw5VqVKm1PMLWZJ8UP/ZegkiS+avmq0yZsroX1eELEm+uP7FP9ST5Sqqv6ZIIWRJEkHWzJk9I0tu7tih0qVJo1LE/hEsZEnywLhpo31rQwLIkusPLsbarrlUihRCloiIiCQhSWyy5OyH36olG3eraQtWqd2nbrj6wdVPf1Ur3zugpsxboTbsPalufPnPgP6bX/5LHbp8T13+5Betu3jDLjV31VaXa/GpmC+13q2v/+O6x7m738XZFxd4DsZcuPeDvr79zX/19cX7P+nPm/ef1c+86f0zYc/bvd9g3x+9pcuqNh1fTJZkyflNm9S6adPUrgUL1K8XL7r6wQd79qgN06errbNmqQeHD7v6vzl5Un125Ij+fG//frUxVvfwihXq39eu+XUubN6sVk6e7HJd/tfVq3oM9wY7589X782ZE/Q+tH1x7Jir/avjx9Unhw7pzz+fP6+2z5unf5YTBg7Uc//3+nW/LoQHz8CzHFq+XP3t0iXXfMHwzytXVJliPu+jLi1bq
vKlSiVpsmTPie1q7oqZasW7S9QH39xw9YPjVw+qeStnqcXr56uzMSdc/Vc/Oa/O3zmlPx+5tF/NWTZDbT2wMSAk5aPvb6vTt45pw88e//HPd+LsCweMW7Zpkb7v+2d2ufpjvrqm5//wu5sB7eaZbn95VV/zbjXqVlfZc2QL63kaNW+of0+q16mmSRY+JweyhDV63e7jeh1eu/OouvXVv106YOvhC1pnUezeYdZzJ45df6DOfPC1/rxp32k1fdFate3IpQAd1n2zD9nj2Zfi6guFo9c+USdvf+G/Pnnrc72P8Zkw0BlL1uv9M6Hz9hk4TGXJll117N432ZEld99/X22dPVttnjHDv0fYYP9gzWcfurZtm6v/l9g1nXWctfvHs2f1+r5t7lz1/enTfp2PYvvZo/D6Y612jmff+C5Wl9CWk2vXar1g9/n21Cl9H7udPYV25mVP43PmjBnVs7Vq+fcupz77JXsgz8jeZM8XCtyjQuzewtyVy5ZNsmTJqZtH1ZINC9TCNXP9e4SNaw8u+NfvYCTCB19f1+sxa+vl++fUonXztIfGjc8uB+ixfl/4MPg9Lt49Hef948Pd72+pzXvXqRmLp6l339+gr5397HM8n/084Mzt4/qZ+cw++eobA/TasGbb8jif1Uaz1k1U0RJFVLM2TYUsERERSTqSWGQJ5AEGW+o0af0nnaB+k5bqyoPfDb9ZyzeprNlzBOjkyV9Qe18YHYxQ2pkvd978ek6QImVK1e2lQX6iYvCoSVpvxdb9ruepXKOOzg8CyWH3xQU7DIfn5vrlIaNU2QqVA565TPlKmhiy57DRtHV7NXTMVH3yZ543uZAlF999Vz352x//BtkyZ1Z7Y41No4Px2KKe7489A1w++3XoEGCQdmreXJMHb/Tpo/7yl7/8/nOInf/jgwdVg+rVA+aAbDBjb+zYodtG9+2rcmXPrk/UCHHBe2NE794Bz1wwTx5Vr0oV17vwjDw7nxeOHRtwL/DDmTP+d4bgcPZxT+c7xwUM5oqlS2tjnOvq5csnSbJk3+mdquSTJQK+o8xZM6uVW5b6dW58eslPCBjwe9GpR8cAg/G5Tm21Mde9b1ffPFky6X8LFy2k9p7aoXUgYvDAqFqzsutZNuxeo/VnL3vH1RcfXhs5SKWJ/V1yPmPFyhW0kW50pi+Yots371sfMHbXsfd0+4R3xurrnE8Erom9B/R03c+JarWqqOkLp2rjtUPX5/SYpE6WzF6xWeXImSvgeypcrETAegpBX6ps+QCdtOnSqxET3wmYi3HPtnxOVa3lW3syZc6i/+X6dMxXWmfLwXO6rdeAoa5nGfTGBL0OQXDYfaFgh+FUqlZLPdO4uWrfrXfAM7NHbtx7yjU+GNbsOKL/b/D9DBoxXo9PDmQJ5ETr+vUDvreUKVKosa+84teBvBj50ksqdapUAXo1K1b0k98AYpv2JePHq0wZMvj1+Hxq3Tq9dzj3nVJFiuj7m/H5cuVSnVu0UDUqVNB6WTL51iEICSfh0vs53/9Vm2wxITEcLED+OJ8VQAShx1z2fsm7jR8wwPX9xIVhPXuq7FmyaGKpShIkSyAOmrRsHPAd4RkxaHh/vw7r5qAR/VXqWDvAqVepakV15vYxvx4kCu0jJw7TISyZMmfUP1/2E9ZfowdhzV5gkxaQLRkyZlDtXmjjes74ACmDp6Hz+Z7I/YTOUWV0IGloHzjsFdd49sJGzRroz9zfOQ/7r61v451F01TKVCn1XsU+yzghS0RERJKEJBZZAsEQe3tt9J2984269vnf1ZA339JteFego426WGOGJHc7T1zTbRiERYqX1Abt3jO3dJshS9iUXuz7qrr22d/0fD1eGaLbh0+YrvU4pWM+PDacz3Lw0kd67MDh41zPGQpxkSUpU6ZSzdu9oPu5Z5c+Poa+98DXXXOEQnIiSzDqcmbLpvI+8YT2rqDt+vbt+o//jOnT+0/DiMuGtJg/erT2wOBUb3ivXvp7gjAx80GWYAgXjx1/bNUq9WXseIxi9HBX7tC0qbqzZ482NA1xYjxMDFnC7wSGM4bq3y9fVv06dtTtkB/mPuGQJXF5luCRkiNrVlU0f351Ys0arctpJEla08YaZcFOGkMhKZIlNz+/ookBjMD1u1brtkMX9qnipYqp9BnS61M42mrWq6FSpEyhyYQ7395Ut764oga8/rL+zl/o3sE/H0YcP9c8+XKrPSd95AgkSe68uTQwVmmDZEGPUzjn8zAewzcuz5a48OZbvj9u2nZsrZ8ZomLR2rnamIaoMfcNlyxJqGeJE8mBLIG4gCx/6ukqOjcHbUs37439QyRTrOFfTpPieJA8kTuvypYjp1q4fqcm1VmvGzZrrb+fSbOW+ueDLOH3AbKCfFK0QTakTp1G1a7/rF8Pj0AIe9uTkPFPV63pes74EIwsYX8pXa6C3h9P3PxMTVuw2t9mj7fBO3PY0LpDN32dnMiSprVr6z2BUEjWc7xHnmvs+yMZjwt0Jg4apK97tm2r9wy8DPH6yBC7B0Hk/yN2HHqGLIEcWTFpktbdt3ix3pvYXyDlT69fr9sh2NFlbvMskCX8PrH3mL1t9ZQp+vnaNWrk1wuHLAnlWVK/WjW9l/CMzEGS8SHdu+ux80aNcn1HNo6sXKmJNUO+JEWyBIIAcoS1FY8+vEdaPtdcf0d4hqDzxoRh+vr5zu3UpY/O6LVzweo5mtiAfGfPQc+QJZAGeHfQhpdG7Wdq6u9x2+F3dRveK+hNmzc54Fkg4WnfuHuN6zlDAa/LVKlTqSfLlfaT/vxbumwpTfDQT1u4ZElCPUtIBMt3MWTUq/payBIREZEkJYlBlmCoYlByomcblY2at9HGKp9rPdNYJ6AzbscGECcYGm1f8JEehiwpV7FKgGcInzEg8+Yv5G/DsM0Qa1Q4vVcgSdjImMd5n/gQF1lCYlbnexEmxD05hbTnCIXkRJZAfvCuO+bNC2g/sGyZKleihPa0OLpqldaxvTsAJ4bEU5tQGcgSdA/Gjjc6GIsYo5AyTuOTcB90l02cqK8NWQIx47wHLteQERjNpi0csgQEy1nC6SNtZzZsCBgLcZQ+bdoAb5dwkBTJkilzfP8HcI92tuNqXKpMSe1d8t5B33fbd1Bv13jcgvm/bU7/jBHHKZxTD6OYdmPgbju0WV+/PuY1vw4hMhA0HV9s77pPKOD6nDVbVn06RxiPs4/TRu4zefYE33WYZAlIaM4Sg+RAlrTr1F2TJUeufhzQDnGO1x9r/fDxb+vvYc6KdwN08OrDy5A9yuwnkB0QI5ATTt2e/YfqOcwazfxcr9x20K8DwU/bxJlLAsaGg2BkCXPZ5AYemeyJPLs9hxMtnuuk90MTapRcyBKIcd5zQOfOAe0Q2RAbkNjsCZAfhJrgYeLUM/sTJAnXhizB68KpZ4h355oOwcIaBEFv2iBLIGCc3iYAQp6fo9nHwiFLTJuds4RDAnSCeZHUid3bcufIERAOagPvR/a3ro59KKmRJXj18R117d05oB3ymvUab8B7P8RoIqHMU0+61m+zP709/y19bciSXv17BOhR
UQZPkjYdWuprvB3ZE6rXrhqgV6d+LZWvQMKryDRv21QfFtjkPtcQN01bPauvwyVLQLg5S9jfKlV7WntJmnBWIUtERESSlCQGWUIp3Nhbq76D33D1OYH3CISJ3Q6KlyqjDVo+G7IEzxRb76VXR+g+k79kxtIN+vqdxev8OpAb1Wo/4xobH+IiS4KVeixUtLgqX6mqqz0UkhNZ0jHWkMQ9mJM8u88Ag5bvw2kgGhBbTh+ngFwbsoTYcqcenhyctjnbcJtGd8GYMfrakCUYpPZ9MLbp41SS68chS3DtzpAunT69swEhw9z2vKGQFMmSNh1baSPQjr12YsT41/V3u/2Iu4wup3/0mbAZjLh06dO5jF48RfQfNN2e97dxYogHi7meucT3x/WW/Rtd9wkFjE3G4cZt95GDhD+OWrf3GdFClkQGBQoX1eS53e5Eg6atVKpUqYMSDCZ3lCERIEuC7RGrth3y/VxmLNbXEPt4eUDWGB3ygrCXXbz/o2t8fAhGluAJY+t16unzogqWb8WAPY/fcTxSTFtyIUsgwvV7Ll3q6jM4u3Gj1pn86quuPtZ7+nq0bauvDVlCziyn3gvNmun/z87cWIB13klkQJa0fOYZ133Ij8W87GdcPw5ZYjwp2XPs/eXF1j7vqdu7drmewYB3KZQ3r/beNG1JjSwxa7ozVMUGHojoDB092NVH+Cd9JmzGkCXB9gjCOvFoNNc9+nULIPLJU6LDf4LsE/GBfQCywm4HEBkQM3z2gizhe2FPJV+YaROyREREJElJYpAlJMiLvbUa9dYsV58BCevQMe7CNmrUbahdqvlsyJKp81e69MxJ37sHzuprQnSyZM2m6jZqpq/X7/FtHsHGxoe4yJKOL77k0i1SvJR2CbfbQyE5kSWEnpCrw253YmCXLvr7CJZoldAd+ub+5loMWYIhYutBluCO7WyLiyzZYhnCgFM6+m7t3KmvH4csKVqggG6LC7hP2/OGQlIkSzhpM4ZeXOgzsKf+vuxTNUCSO/rGTh2lrzHiChUp6NIDeI04jUXjem1cmGvVq6FDZhJqAOJSzTyTZo5z9QHcl3lPPgtZEhngyVfvtzU+LlSsUiP2+8vjagfkjeI7Yq/iGrIkGAmOlyN6r46c6G+DhGFvIqkr4aDsN8HGhoNgZIk5JHDChHqe/+h7Vx8gdIg8K3hWku/LAFKHcSSsZS+0x0UD0SBL3hrsI7+ubN3q6jMwHobLf/MwtIFXYqvfCA5DluD56NRh38HD0R6LF4lNlrzUvr1Lj9Ad5p01YoS+fhyyxISNhgLeJ/YzAJLB0g/h4iRZyL0CgcLnUERLJOElWTJqsu8wzazxwbB2+wqtM3XuJFcfwGOEtZjPhiw5cf2wSw/vDkJizLUhYYaPG6qv2W8g2oKNjQ8cKDzbopGrHXBf5mXfijRZwjvgucKegrenQd0GtfVYPptwV68gZImIiIjnkhhkyXtHLuqF9LXRk119BlQt4I/dxi3auvoAxEOOJ3Lrz4YsGffOQpfegGG+5JrOxHqcwuGiTXUDTv0wrBNaTQDESZbEzmnrClkSGhAY5Cax250gWSvfRzAjjZhz+oyhG5fRmhCyZM2UKa7xr/f0/WFuqihAluDSbOs1rF49XrIE9+/C+fLpGPNguH/ggGveUEiKZAnGW7p0aUMSFJzE8d0ePL/X1UeYDn3G0IUsobShrYerNad65BQxbcRp0wYZw6kf69HQN92ni/EBjxeegaR/dh+EBYZu4+YN9bUhSzbtWRugx0kl7UKWhAe8L4J5gjjhI9wzutoBSbr5jnYcv6qvIUuC7UVUZ0Pvzalz/W3z12zTbSQnpyobn4MlFQ8HkSJLqCZHfyhA6tjjooFokCUzhw/X78hab/cZUC0NnWAeheTHog8PSK4NWWLyaxnEte8EI0uc4S0GeL7o35dJk/S1IUvIseLUmzbEl48tFFkyuFs3rXN89WrX3mJgz2swqKsvAXYo4J1ij/MCXpIlrKe8SzBPEAOz9o6ZMtLVh8cjRATVX7g2ZEmwvYi8JeTfcraRY8QkTyWsNFhS8XAA4Q6Zb7cDiHj2UD4bsqT/0H4uPRLSJpQseWu2z2MrFMpXeso1LpIQskRERMRzSQyyBFIBV+XGLdu5+l7o0U9lz/GEuvTxzzoRH8aiXe4Rd2ZymRgj0pAlz3Xp6ZqPkzRKJJI3xLSZqgWTZi/TCf6cLtMJgZAlkYMhIShx6GzHi4SkeRiHhNigY3KLBBtvYsXjMloTQpZgMNrjaz39tDZ0TUx7sYIFdVlFW69IvnwBZAlGrb6HgyzB8OYP8E+DeMqQz8SpGw6SIlmCUcf3RlJXZztJUjnR4lSORKnoED8e13jymnBtErzapRlNjhIMRGc7hESBQvnV5Fnjf3OZPu66R3wg2ayTEHGCssXcl2S0XJsEf8s3LQrQo9IA7UKWhAeIEnvdB/NWv6fSZ8igk732ePk1/T0YQsQeT44SvEO4hiwh14ddLc14LlIe3rRRrpgqPCT5xjOShKp2bq5wESmy5Nzd7wI8SpKbZ4khIYznoROElbAnEGrD/3FDiDiBBwnjTf6PSJAlwcJZxvX3Eb/sQVy//MIL+tou94tXiv55hyBLODhAx+RZcQIyBq8RcrbYfeDDvXtdoTtJ0bMEUprvyHgeOgF5gYcEoTaQ5oYQcQLPCcYPHjlQXxuyhPXaqUfp9yzZsqgGTZ4JaOe+6JNEVf8/XDDFdY9wUKVGZV0hzg5X5Zr2p54up68pDcx9Xnypa4AeBwO0J5QsYZzTo0Q8S0RERJKkJAZZAnBLhjBxxk9DCmTMlFnVadhUXxOmE/uI+pTP6ECcGANv5rKNus2QJZQLhggxuhjG/GFE+WD7/hAxxLUzzvkMCYGQJZEDlV8wMts0aODPW0LyOU7W+BlSYpcM/09kz66NNWcZxwubN2tjFNLCkBhxGa0JIUuY01mRxuRFGfXSS/425uI+lx3u3cZF2kmWMA9tb/br528j+SxtxK47T/h4DtrtZITxISmSJZAkEA0Nm9bX3h+0kW+ka69O+vdi59GtuhIBpEHe/HkCyjhSDYATN4xek6PExFK3er6FPxkdyfwqV6+kT99MdR0DU7WA0J24Tu7CAYn9+EPMSYJQ6adC5fLajfvo5QO6zXihkLDPeNNQfYHcKbQ7yRLKXfLMJJ617xcKyYEsIaySd+w/bIy/De/BClWqq8xZsqqL93/SITRUR4MYcXoWUl2GsU7yHbKENlNZDVBFDe/GYiWfdJEoEDHsZZT0de5fCUWkyJK4kFxylpBDBC8+9o7Pjx71t5twkymvvaavIRsIt6GyjdEhCSt7C0m3qVZGWyTIEsYvclRWi9m9W+8Z5AUxbTN+84h5Z9gwf5vZ7/TP20GWkLC12lNP+a/JNQKBgvej00uRUCS8OCFr7ES28SGp5SxhD2BtZ+9wEugm4bdJ8K0TqKb4a0BicMr+lqtYVntkmNAZQ5YULlZYExNGt9+rPq9YiH3n/Un8ShUbnoG8H+Swsp8xHBgyvecr3f37Bv/2fPlF3W48K9kHCbfJXzCfJoF
oYx8gNxh6TrKEpOO0EYZk3y8+SM4SERGRJCWJRZacjvlKl1kkHAZypP6zLbS3CBUIMELR4TSuSavn9aKLLiV/MRS57jNwmH8uQ5bkL1hYJ+wjZrxmvUb6jxOSqmIY2/c3J4LMZxu64ULIksiCOG3+AMao7dyihT7F4v05ATM6nGhRsYCEeZRYbFanjvY8If8HFQ+MXlxGa0LIEp6DvCFtY+9jqhwQcoNLthnLiSXGdZpYParnkJiVe1Cdx0mWkGiW52aOLJkyqbv79ul2yBPaMJ55Z4xRriuVKaPLPNrPHwpJkSwBEAT8X6ZKAEadIQ4GDf89ER6u0pThhTygAg4neBiheIUYIgIYIw5jsUjxwjrshrAcvFTmrpjpurepWsAYkyT2UUA5Slyu+f2m5C9kDWV/MbSd98W4rFarir4fz1e9TjWtYwgOJ1lCpQbaIITad2nnumdcSA5kCWs6Xh28Z5mnnlZNW7fXxAP7AyS60Rv79nxNmJC7hEoxptpMlZp1A/YNyBK8RUjUSglg9iU8VCBDth255Lq/yWXCz/txiAghSyIH8oGw9lJit32TJnot591J+G1KAlPOnfBI1hvWfCrYZM+SRe83zhxWkSBL8uTMqfcNcl5RwjhdmjS6zemx8fWJE7pN/07G7g01KlTQ+51J0OokSygbTBvvaLwSIX14doge9jH2PhKpszdButjPGR+SGlkCINzxvqAkPKQIay7fI5VqKCWMDsQHpDv/nylTT2nhbNmz6f2GJOJmLkOWsL+wb5C4myo6tHXu2dF1b0BOEfpNkthHBcnJmYfE5OwH/Mu1XemHhKy0s//wrrnz5lLlKpTR+5OTLDHeluyNceX5igtCloiIiCQpSSyyBJBslTwjkCAYs+QwOfvhty69Bet26PCcZm06aC8Rk6zVwJAlr7z+plqycbcuKdyqfRd9uhis0gHYe+aWHkPZYLsvXED4kBNl3e7j+hr3a64Xbdjl0iWmPaGlI9fuOqbnO3vnG1dftBBNsgRgwBH+gvGICzJx5LYOJ4O4Q+MujdG4cOxY9cuFCwE6VCnApdkeSzjP2qlTA9rwUoGQMYanIUumxuqSC6Vbq1Ya699+21XlAFzaskV7gWBYj3zpJR1Wg9cI4516eJ8QR07iPXNCCajC8GpsO+RPz7ZtdZiRndAvHCydMCGom7mXiAZZAvASoRwjrtAYfxt2r3Hp4IFBThGICIw1TsZsrwtjxF395LwmWzBmKTnsJFRstGjXTBvSVMyx+xICXLFxs8aoxdjm/sGS0mKgQ4rwHhAbnHLiVQM5gmFv9GibOGOs1iF23J4nLizbtEjPZVcE8goxiUCWGCxcv1OT18+2fE71eGWI2n3qhkuHMJxeA4bq/eX5rr10KWE79BOypHzlanrf6NpnoGrWtqMms9kD7PkM8uQroMkNuz0hgNR3Jh+fPGd50MTo7HvsFeypdl8o4FWp95gg+260EC2yBECGsC+wVrN3EI5ir+kQJ4vHjdOVb9Bjb3B6MgLWctqd6ziIa9+hwg5eLOYasgQy5vr27eqVTp105Rn2C7uUMCAEh3vxLOwd7BfclzaTOwtArLAvdm/TJiDxLKGsPBP3ABMHDXKF9YQLvpd5Ud5jvCZLwKV7ZzWJwLrMHsE6bZPJEOeUCGa9Re/VNwYEeDICQ5bggcKazBrepdcLat3OVa57Gkyb5yPeSARu9yUU3Kdbny76+Qi1MeGnNvBw7NSjo9Z7862R2juTMvb2gcHyzYv1fOy9xhMzHCzduFDvMXa7FxCyRERExHNJTLIkUnCSJXZfXBg5eaY+UTQlhQXBEW2y5I8AJ1li9wkCES2yJFJI6IkXBAcncBiWdp8gPMQkIlkSKRiyxG6PC5Ay/J6RF8vuEwQimmTJHwWGLLHbBW5EgyyJFAxZsnbHSldfXMCbBQ/IcPckQSCELBEREfFckhNZwmnhqm2H9MkhSWQbNW/j0hEEQsgSd7/gdyRVsuT41YM6OR2niLjkh0pyJwiNmGRElpAza/mW97VHCXuMSRAriBtClghCISmSJXg2sr+Mnz5G50LBu8PWEYQHIUtEREQ8l+RElgDizNElMR/jnH2QKfSFA+Lg7bmTIoQscfcLfkdSJUtGTvIlVgTByiwSQmOvCXGB8sX2+OSEmGREllBemJ85ubgoG2z3k3/E/v2IC5S2t8cnRQhZIgiFpEiW7Du90///vFLVitqD0dn/7vsbXOtBXCCJqz1/coKQJSIiIp5LUiBLKA8J8RFOcjtO/d5ZvC7OGG3mCQcnbn7mGpsUkRzJEqrx3Nu/X1cTsPsEgfizkSWc6AXLE2KD/Cezlk6PM+abSgjMEw6IB7fHJyfEJAGy5Nj1B2Gt+ST9njJvhdp3NsbVB07e/sK1l8SFRy03/GdDciRLyCPy5SPmDUlu+DORJVRWC2fNJ/8JlXHIH2UTJabf3kfiwvUHF13jkxOELBEREfFckgJZIvAOyZEsEYSPPxtZIog+YpIAWSLwDsmRLBGEjz8TWSKIPoQsERER8VyELBGEgpAlglAQskQQH2KELBGEgJAlglAQskQQCkKWiIiIeC43v/hnlpjv/tdAIAiGO1/9s8udb/77k0AQDB9+88/P/nfz5k8CQVy4//Nn39//+YufBIJg+PDr/3xsrysCgcG/YmJu22uKQGDw+Y93Tzz49aOvBIKg+OXu8/bfvSIiIiIRlVtf/6+MfRIoEBjc+frff7NPAgUCJx78evf/2qeBAoFB7Dryo72uCAQG/711+769pggEBp/9dPeavaYIBAaf/HK3u/13jYiIiEhERcgSQSgIWSKIDw+ELBGEQIyQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiiA8PhCwRhECMkCWCEBCyRBAKQpYIQkHIEhEREc9FyBJBKAhZIogPD4QsEYRAjJAlghAQskQQCkKWCEJByBIRERHPRcgSQSgIWSKIDw+ELBGEQIyQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiiA8PhCwRhECMkCWCEBCyRBAKQpYIQkHIEhEREc9FyJJHx6WPf1ZtX3hRzViy3tX3OFiycbfq0meAnrtrn4Fq5baDLp1oQciSyOLbU6dcbY+DXy9eVIvHjVO92rVT3du0UWNfeUXd3bfPpeclHghZkmi49uCCqy0Y7v/4gZq/arbq2ruzat+lneo7qLfac2K7S88LxAhZ8sh45fU31ctDRrnaHxWXP/lFTZm3Iii2HDzn0o8GhCyJHCK9v8Ts3q1WTp4cFF8eP+7S9wJCliQe2Ddufn7F1W7j0kdn1IzF04Li6OUDLv1IQsgSERERz0XIkkfHmQ++VrFfoerx8muuvkfBra//o5q17ajnLFS0uKpaq57Km7+Qvu7Yva9LPxoQsiQy+OXCBdWiXj1NZth9j4pPDx9WxQsVUn/9619V1XLlVL0qVVSG9OlV2tSp1dZZs1z6XuGBkCWJgsXr56us2bK62m188M0NVbVmZb2OlCpTUlWvU03lyJld/eUvf1EjJw136UcaMUKWPDKeLFdRlXyynKv9Ub
F251H9exAMPfsPdelHA0KWRAbsLewBdvvj4PWePV2/JwZHV61y6XsBIUsSB5AcJZ8soTbsXuPqszF94VTX74fB5NkTXPqRhJAlIiIinouQJY+OG1/+U63eflgduPChq+9RMG3BKr259Bk4zN9266t/qzYdX9TtC9Zud43xGkKWRAYfHzyof4aRJEuef/ZZlSplSnV4xQp/G6d9JQsXVpkzZlQ/nj3rGuMFHghZkijo0PU5TXjY7TYGjxyof/emzp3kb4NAqVWvhibaDp7f6xoTScQIWfLI2HrofEQ9Pka9NUv/LuCteOjyvQCcu/udSz8aELIkMmDdr1mxoqv9cdCoRg1VJF8+dW//fhf+fvmyS98LCFmSOJizfKZeK8IhS3q+/KJKkTKFOnbloDp961gAbn951aUfSQhZIiIi4rn8EciSs3e+0cbatc//7uozuHDvB61z9sNvXX02Tsd8pQ5fuedqjzYI0+GZz3/0vasvGOo3aanSpkvv+h6O3/hUb1odX3zJNcZrJAZZ8u9r17Qx9lU8br73DxzQev+7ccPV5wT9webD24P2n8+fd41xgnGQHXZ7QhAfWUI4TbBnjAv/unpVpUmdWrVp0MDVt3DsWH2vXQsWuPq8wIMokSX3f7qjja9L9866+pw4c9tnpH3yy4euPifu/RCj9cJxM44krnx8Xt/3g6+vu/oSgnDJkpKlS6gixQu72t99f4P+PfHauyQmkcmSKw9+0eswe4jdZ3D9i39oHdZau8+GWdevffY3V180AZHOc5y89bmrLy6069Rdpc+QQXsx2n2JhcQgS/DKY/+w252AeGZN/tulS64+G18cO+aa77/Xr+vxX5844dJ34qdz57TeP69ccfUlBKHIErMHJnQfy5E1q2rfpImrPZqIJlly/s4pvX/Y7U5cvn/Ot35/c8PVZ+PcByd1iIrd7iUgJ3i+cEM040JCyJJqtapoLxS7PRoQskRERMRzSUyyZOnmvarMU0/73fXSpE2nerwyRBuBRgevjToNm+oTUKNXo25Dtef0Tb/O1U9/1e2DRoxXjVu28+tly5FTjX5rtl9vyJgpun3ppj2uZ4GISJ0mbYJO14KF4WTImFH1HfyG6jVgqH4f+v+aIoVq1qaDjhe353Bi25FLauV7B1ztBy99pOfp0K2Pq89rRJMswWjs17GjSp82rf9nWKFUKXXx3Xf9Ohh97wwbpnJlz+7XyZktm5r++usBcw3u1k3lyZlTvb9kiSqUN69fl1CYX86fV1Nee02HrNCGd8aol17yj8W4pX3qkCFa34zNnSOHWjJ+fMB9qpQtq0oXLep6l55t2+oxkBrvzZnjn8Pggz17tB5EDbqEzpi+6uXLq0tbtrjmdOI/sYb4kZUr1eWtW119s994Q8+zc/58V58XeOAxWYLx9+JLXVW69L7/T6B02VJq9/Ftfh2IkbFTR6mcuXL6dbJlz6ZGThwWQJr0e7WPypw1s1q6caEORTG6zzSu6zeS73x7U2XOkklVqlrR9SwYv6xF/Yf2c/XFh81716my5cv475k69mfetmNrdePTS36dlVuW6r6Fa+YGjOXZaB80or++5v2dv09tOrZy3c9g2+F31XsHN7na1+9arccmVbLkwMW7moBm/f3959w8gFxgTSYvFCS10SlcrISat/q9gLmerlpT1W3UTA16Y4J/XU+VKrUOjzTr+tbDF3R7sLDMOSve1X0L1u1w9YWCHYbTqHkbVb5yNTVj6QaV44nc/mcuU75SwJ4YF5ivUrVa+vPF+z+pY9cfuHSijWiSJTOHD1f5cuXyf2+s6aveeitAhz2jXIkSfh3W5h6xazT7k9HZs2iR7mONrVPZF+IGShQqpG7s2KH2L12qCubJ429vUru23nfMePa15xo3Vm/06eNf+/l3UNeuAR4bI2P3JfoeHD4c8Iwbpk/X7Xtjn4N5zX0M5o8e7ddlv2SPNH35Y99/3bRpAfMFA/dEn72Sa8ij706fdul5jWiQJePeflPlyff7/6ecT+TQoSVOHUiDJ8uV9uukSZNGPd+5XQDZvm6nzzsYosG51pcoXVyvt0avQZNn9Pp//cFF17NAbj/1dDlXe3w4G3NCPduikUqR4ndbmfDLfad3+nU+/O6mbu/Uo6NrPIT601Uq6M8vDeoV8PvEXmnrG7C/ZsyUUX8XXPNO8R1oRBJCloiIiHguiUWWrNp2SBuxpcqWV3NXbVWb9p3WhEXsI6k+g4ZrnRM3P4v94yePypI1mxr79nwd8jJp1tLYhTuXboNEQM+QJRi8Tz1dRc+16+R11aTV87p93PQFWo9TwxSxfxi3fL5zwLPgyZE5S1ZNaNjPGQpxkSVZsmXXOUcmzlyiiRkMXPQgUOw5wgFhOYyftXyTq89rRIssgQTB6EwZ+zsx5uWX1en169XGWIMw7xNPqOxZsvhP54zxiKGJUQs4+aJtWM+e/vkgS9LFGjNZM2dW4/r318QCSVDRKx9rqGLUrn/7bW3sViztM4BOrVunxxqyBNKmdqVK6tymTerqe++pVs88o9vXTJniv084ZAmneRisXOMJQnI8DG88aDgJxENk0qBB6uzGjWrr7NmqbPHiKlOGDOrOb4RKQgCJgqHPu0fLsH3gIVmCIVajbnVtAA4a3l9tP7JFEwkYtpAexih79Y0B+vtt0rKxWrt9hTZaWz3fwreeDOzpnw+yBHfhdOnSqtGT31Cnbh5V81bO0nMVLlrIf1rYtVcn7bVx4vrhgOcZMf513Z7QpHVb9m9UqVKnUsVKFtXPD3Hy+pjXdFuZp57URix64ZIlyzct0oYwz0ISPea37xkfmrX2/b/Zc3KHqy+SiEkEsgRvxSdy59X7xLh3Fuo9YfSUObF/oKRRZStUVre/+a9GrWcaa/ILkp79hVBH9hC+V+d6C1nC/pI9xxOa+ICIeW30ZD0Wgt7olSzzlL6vk/AHkDbsWze//JfrWUMhGFnCXgVeHTlRrdi6X5Pz7KWly1VwjXcC7xkInopVaqhyFavonz0oWKSYmr9mm0s/WogWWcI+YPaOQ8uX6/DFWk/7Dmv2LV6sdQ4sW6b3INbQLbNm6X0Db0DaILFZz9EzZAn7y4utW2vSYs7IkVqPvcUQ+NyncwvfOsT9zbNAlrC/sL/tmDdPfbh3r74Pv3fdWrXy64VDlrCPsKdwOEAOKz7f3rVL63EIgF7Xli3VsVWr9Du3a9RItzn3sWBgb0SvQ9OmmlTiM89HaA7Pa+t7Ba/JkmFjh+h3Y++AEGFtrl67qm5jPUZn0561KmWqlJrIWLTWt34PGfWqbitf6Sn10fe3tZ4hS9hf2r3QRh2++L4mqlnjWev3nvKttYvWzdN6b82eGPAsJN2mfeKMsa7nDAUIivwF88Xanxk08cPzzVzytt4naeM50AuXLOEgokuvF7QueybJwW19A/ZI9Oo1qqvn4DOALIKot/UjDSFLREREPJfEIkuq166vMmXOogkHZ3vl6rW1UYmb8It9X9WLLkasU4cTPIzUVu276mtDlmTNniPAMwSDFeM1V5582jCmjdPBdOnTa1dqo4dRz
PhgHiehEBdZglF9KuZLfxu5TfBygRiy54gPeN9A8FSoUj1RXKejRZYcjDVS+S6dBiUwRikVXz47ckR7gTSuWdM1vmnt2tpQxdWYa8gSxs0aMcKvg5szJESK2N+du++/72/HiET37aFD9bUhSziBdJ4GYihjCBcrWNDfFg5ZwnWwMBxO92hzngICTvDwenmhWTPXvPHh1d/ee8LAga4+r/DAQ7Jk8771+n0MSWBgjNLJs8ZrwgRDtGa9Gq7Qm0bNGmiixZAeGH6Mg6hw6i1Y7fP+MXk9dhzdoq8xiJ16xUsVC+pxEh8YkylzRpdLNsnvuI8xjsMlS0C4YTjBMGXOJD1fmw4tXX2RRkwikCUQGbyfvaYPHD5Of2fvHbmoFm3YpXUgG5w6eIqQWDt33vx+0gOyBF27MpnZoyDnuX5j0gx9vXzL+34dwkJTpkwV1OMkPgQjS5ifAwan3rMtn9Pt3MuewwDPRXQgjCDuyZH1+thpmtxhP00swiQaZAmhNBnSpdOEvDNs84czZ/RaS/4nriHSCT2xq8oYsnv5xIn62uxLzevWDdBr+RuhvnTCBH8bewD3frZWLX8bZAl6ZzZsCBjf+znf/+mPftvHwiFLTJsdhmP2y7YNGwaM5f3Rg1wx+1MwsIdwDzwzef/VU6Zozxf2WciTz48edY3xAl6SJXe/v6XJhErVng7YO259cUV7S0CgcF2uYlmVJVsWHULpHA8xwXdk9g2zL5FE2znfhQ9PqbRp06gW7Zrpa8gVPB+r1KgcMB+5P+LyOAkFc1gAie5sP3JpvyZ08DjhOlyyBIQbhmOIH7wxh44erGYtna5eGzlIf39433hddU3IEhEREc8lMcgSYr0xHls818nVh2uw+VykeCl96mXrgPKVqmpyhM+GLGnfrbdLzxjNO45f1dezV2zW1xiKRgfX7GCngfEhLrKkco06Ll3IjgKFirjaQwFjHuKFceHE0nuBaJEleIXwXRqyw4C4b3KL8BljDZ3NM2a4xm+bO1f3LRo7Vl8bsuTWzp0BehAbGJXONjw40DVEjSFLIB7s+xjj1ZAtj0OWcOKIYRwsZwpGeJZMmVztcQED+LUXfYmAOQnke7N1vMIDD8mSAa+/rN/p+LVDAe0Yoibnx9wVPqOOf+3xhnyYNHOcvjZkycW7pwP0KJGIMdusTVN/G67ThYv9nutj17H39FiIBvs+oXDjs8v6j1FCbuw+jHWMY9yyuY4GWcJpJs9TsXIFFfPVNVd/pBGTCGRJtdrPaC8Qu529x3h3UJ6d75S8H7Zev9dG6j5IFa4hS/IVKOTSe/fAWa03fMJ0fQ1ZgfdG6w7d/DomqeqOY1dc4+NDMLKEn52dL4USw9wjVLJx3pMyxHaoJ+3kMYlrr/Ua0SBL8Bjh+7HDKIFZf8k7gk6f558PqgPJbkgVQ5bgTeLUG9C5s263c4MUzpdPe7GYa8gSPAjt++DJwnizjz0OWbJiko8QxVvRvg/hSPThNWn3GeC1STUcO48WhBFjg+2PXsBLsgQPDN4lWMUWk5gUgoR1lvXW1sETETLCkCqGLIEwsHUJ9YRQMNcQI8yLdyPX5NDKniNbwB4ULipULq9DUO3DAlD7mZra04WcX16QJXgmvjKkrzpwbk9AO14lvF/dhnVcYyIJIUtEREQ8l8QgS0i+Gntr9dKrI1x9TkA84CZtt4Pm7XwugoTQGLKEnCS23rQFvrj8xRt26WtckQmTMfPiAQJx46xAEy7iIksaNmvt0oVACWZsxwXCjvAoKVG6bKIRJSBaZIkhDkL9kf/W4MH6+z4fxMC7vt3nvkoID9eGLLENvTLFiqlKZcoEtN3dt0/r2mTJvFGjXPdZMGaM7sMbhevHIUuMC3goOD1b4gKnpiYUifsSimPreIkHHpIlxEHzXhiSdp/BqMkjtA7eIHYfJ2v0QbpwDVlC7hNbDxQolD/AWCSXh3Pe7n27akIloQlhD13w/X7Z3jEGBQsX0G7afPaSLMGQxqhlHrxwvK5SYBCTCGQJeUcIt7HbnWjcoq0OXzFeh05Mmu374xrCmmvIEnJl2XrkP0Gve7/B/jYIDfYBEstyTT4RcnPZY8NBMLIEz0hbj1wqPMf75z5w9YUDQlAZnxg5TKJBlhjigFwidp/Bhc2+gxT2GbsP4E1Ro0IF/dmQJWunTg3QMZ59dghk0fz5XWRJ6/r1Xfcwe9Hovn319eOQJew16IRCMCIlPrC/4KFJ+Krd5wW8JEtmL3tHfw+r31vu6jN4/4zPA23om4NdfYBQF8gKPhuyZNuhzS69zj076j6z7pJLxDkvXiFcm9CfhCB33lxxejya+1795LwnZEkokHsF7xK7PZIQskRERMRzSQyy5MiV+3oRji+HB3HZnBDa7YA4cU7YOCU0ZAleJLYeOU7oI0eKaeNEESMZQ3fkZN+GEE5yPBtxkSUYtLZuuGQJhjvzMW/Neo3Uxfs/unSiiWiRJV1attTvHMotmDAZdExuESeMoWvCTwxZYrtTJ4QsIVGqfR+TPNWcyEGWlCpSxKVHjLjzfYKRJbiEY3RymhgXiEe353aCZ8WA54/m8QMGuPqjgQcekiXtu/jIEpKu2n0GJHZFJ1jeDmPoUj6Xa8gSCA9bD+TKk0u7T5trQmbIb9KjXzdN1uA23bp9wsNWKKfIM0BU2H0AY9sk9DNkiR0jbuLCH5UswUhu3rapnoPvNBT5FGnEJBJZEl8Oj6at2+vvL1geEfJc8V0RBsk1ZEmwvcgQ/85QHpK40vbO4nVq96kb+jP5Uuyx4SAYWYIXiK33uGRJl96+fB77z99x9XmNaJAl5PHg/XYvXOjqMyBZNjpxraPZMmfWazafDVlCXi2nTkLIEkJ27Hvc3OH73SGHFdeGLPnk0KEAPeNlGYos4T3QIYTV3lcM7MOEcEFoTrB9zwt4SZbMWfZb2JwVvuIEpdXRIdTF7gMkPyWMh8+GLNl6wL0XQfyz3jjXXvJ6kMeKz3iU4B2Cl6M9Nj7kK5AvzqSwZg/Fw9GQJS907+DSsw8LIkGW4E2DnY5Xi90XKQhZIiIi4rkkBllCDg/CSwh/sfsgN/D6ILkroTbk+ggWHlOgcFGVv2Bh/dmQJa3ad3HpaWIkdrF2VkDAtRr98TMWqSo16+r72OPCgRdkyQs9+uk5n+vSM6gRH21Eiywxp2AkUnW2QxZgZOLubEJt8O6wxxMjTh9JW7mOBFnSt317133w3CAOnFh3riEqqHpg6zWoXl3PEYosYS6MJ9tlG2CM4zputzvxfaxB/lTJkip1qlT+904MPPCQLOHUje/NTkKK8UW5QogSQzBMeMedFI/kp/QZt2gThnPyxpEAPeNqTdUdZzs5TyAzMBgZt3bHStc94oMJtcFwtPsoQ8l9SUbL9Zpty/V93lk0LUDPlPl9FLIE47v+s76qTsSU2/1eIyYRyBJyU7EWs9c428ktAvFBThETarP10HnXeJNs3JAHjCHZuO2FYoiRtxeu8bexblOphjwikBjkCCHhrH2PcBBJsoSEsOxBO09cc/VVr9NAV/mxy9ZHA9EgS06s8f3/taum
<base64 PNG image data elided (figure: idealized pipeline schedule across time)>)"
+      ],
+      "metadata": {
+        "id": "8vCtShhBjzTd"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally \"hide\" the cost of transferring between HBM and SRAM behind computation and keep the processor busy with as much uptime as possible.\n",
+        "\n",
+        "The initial startup time and final teardown time are known as \"bubbles\", where only a subset of the stages are being executed while the pipeline is being \"filled\" or \"drained\". The bulk of the time is spent in the \"steady-state\" phase of the pipeline, where each pipeline stage is executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or by processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor."
+      ],
+      "metadata": {
+        "id": "Qs3F--kwiOJm"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Deriving a Double-Buffered Pipeline\n",
+        "\n",
+        "Now let's look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the loaded value, and store the result back to HBM with `copy_out`:\n",
+        "\n",
+        "<pre>
\n",
+        "for i in range(N):\n",
+        "  copy_in(A[i], X)\n",
+        "  Y = X + 1\n",
+        "  copy_out(Y, A[i])\n",
+        "
\n", + "The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to \"pre-fetch\" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously.\n", + "\n", + "In order to reason about the code transformation we will make, lets unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:\n", + "
\n",
+        "  # Itr 1\n",
+        "  copy_in_start(A[0], X)\n",
+        "  copy_in_wait(X)\n",
+        "  Y = X + 1\n",
+        "  copy_out_start(Y, A[0])\n",
+        "  copy_out_wait(Y)\n",
+        "\n",
+        "  # Itr 2\n",
+        "  copy_in_start(A[1], X)\n",
+        "  copy_in_wait(X)\n",
+        "  Y = X + 1\n",
+        "  copy_out_start(Y, A[1])\n",
+        "  copy_out_wait(Y)\n",
+        "\n",
+        "  # Itr 3\n",
+        "  copy_in_start(A[2], X)\n",
+        "  copy_in_wait(X)\n",
+        "  Y = X + 1\n",
+        "  copy_out_start(Y, A[2])\n",
+        "  copy_out_wait(Y)\n",
+        "\n",
+        "  # Itr 4\n",
+        "  copy_in_start(A[3], X)\n",
+        "  copy_in_wait(X)\n",
+        "  Y = X + 1\n",
+        "  copy_out_start(Y, A[3])\n",
+        "  copy_out_wait(Y)\n",
+        "
\n", + "\n", + "Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows:\n", + "
\n",
+        "  # Prologue\n",
+        "  copy_in_start(A[0], X[0])\n",
+        "  \n",
+        "  # Itr 1\n",
+        "  copy_in_start(A[1], X[1])\n",
+        "  copy_in_wait(X[0])\n",
+        "  Y[0] = X[0] + 1\n",
+        "  copy_out_start(Y[0], A[0])\n",
+        "  copy_out_wait(Y[0])\n",
+        "\n",
+        "  # Itr 2 - Steady state\n",
+        "  copy_in_start(A[2], X[0])\n",
+        "  copy_in_wait(X[1])\n",
+        "  Y[1] = X[1] + 1\n",
+        "  copy_out_start(Y[1], A[1])\n",
+        "  copy_out_wait(Y[1])\n",
+        "\n",
+        "  # Itr 3 - Steady state\n",
+        "  copy_in_start(A[3], X[1])\n",
+        "  copy_in_wait(X[0])\n",
+        "  Y[0] = X[0] + 1\n",
+        "  copy_out_start(Y[0], A[2])\n",
+        "  copy_out_wait(Y[0])\n",
+        "\n",
+        "  # Itr 4 - No copy-in\n",
+        "  copy_in_wait(X[1])\n",
+        "  Y[1] = X[1] + 1\n",
+        "  copy_out_start(Y[1], A[2])\n",
+        "  copy_out_wait(Y[1])\n",
+        "
\n", + "\n", + "Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration.\n", + "\n", + "
\n",
+        "  # Prologue\n",
+        "  copy_in_start(A[0], X[0])\n",
+        "  \n",
+        "  # Itr 1\n",
+        "  copy_in_start(A[1], X[1])\n",
+        "  copy_in_wait(X[0])\n",
+        "  Y[0] = X[0] + 1\n",
+        "  copy_out_start(Y[0], A[0])\n",
+        "\n",
+        "  # Itr 2 - Steady state\n",
+        "  copy_in_start(A[2], X[0])\n",
+        "  copy_in_wait(X[1])\n",
+        "  Y[1] = X[1] + 1\n",
+        "  copy_out_start(Y[1], A[1])\n",
+        "  copy_out_wait(Y[0])\n",
+        "\n",
+        "  # Itr 3 - Steady state\n",
+        "  copy_in_start(A[3], X[1])\n",
+        "  copy_in_wait(X[0])\n",
+        "  Y[0] = X[0] + 1\n",
+        "  copy_out_start(Y[0], A[2])\n",
+        "  copy_out_wait(Y[1])\n",
+        "\n",
+        "  # Itr 4 - No copy-in\n",
+        "  copy_in_wait(X[1])\n",
+        "  Y[1] = X[1] + 1\n",
+        "  copy_out_start(Y[1], A[2])\n",
+        "  copy_out_wait(Y[0])\n",
+        "\n",
+        "  # Epilogue\n",
+        "  copy_out_wait(Y[1])\n",
+        "
\n", + "\n", + "Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:\n", + "\n", + "```\n", + "# Prologue\n", + "copy_in_start(A[0], X[0])\n", + "\n", + "# Main loop\n", + "for i in range(N):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + "\n", + " if i < N:\n", + " copy_in_start(A[i+1], X[next_slot])\n", + " \n", + " copy_in_wait(X[cur_slot])\n", + " Y[cur_slot] = X[cur_slot] + 1\n", + " copy_out_start(Y[cur_slot], A[i])\n", + "\n", + " if i > 0:\n", + " copy_out_wait(Y[next_slot])\n", + "\n", + "# Epilogue\n", + "copy_out_wait(Y[1])\n", + "```\n", + "\n", + "If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:\n", + "\n", + "- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.\n", + "- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.\n", + "- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`.\n", + "\n", + "By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:\n", + "```python\n", + "def double_buffered_pipeline(\n", + " grid: tuple[int, ...],\n", + " kernel: Callable,\n", + " in_slices: Callable,\n", + " out_slices: Callable):\n", + " # Prologue\n", + " copy_in_start(in_hbm[in_slices(0)], in_sram[0])\n", + "\n", + " # Main loop\n", + " grid_size = prod(grid)\n", + " for i in range(grid_size):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + " if i < grid_size:\n", + " copy_in_start(in_hbm[data_slices(i+1)], in_sram[next_slot])\n", + " copy_in_wait(in_sram[cur_slot])\n", + "\n", + " kernel(inputs, outputs)\n", + "\n", + " copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])\n", + " if i > 0:\n", + " copy_out_wait(out_sram[next_slot])\n", + "\n", + " # Epilogue\n", + " copy_out_wait(out_sram[1])\n", + "```" + ], + "metadata": { + "id": "ZcSzl4N6pPbG" + } + }, + { + "cell_type": "markdown", + "source": [ + "Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API." + ], + "metadata": { + "id": "ziBuvv8jDgxo" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Pallas Pipelining API\n", + "\n", + "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", + "\n", + "\n", + "### Grid\n", + "\n", + "The program **grid** is a tuple of integers specifying the number of subproblems as an array. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop.\n", + "\n", + "```\n", + "# For grid (N, M, K)\n", + "for n in range (N):\n", + " for m in range(M):\n", + " for k in range(K):\n", + " kernel()\n", + "```\n", + "\n", + "The kernel will be invoked a total of `prod(grid)` times. 
+        "\n",
+        "### BlockSpecs\n",
+        "\n",
+        "A BlockSpec specifies the size and slice of the data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to - by default this will be SRAM.\n",
+        "\n",
+        "```python\n",
+        "pl.BlockSpec(\n",
+        "  block_shape: tuple[int, ...],\n",
+        "  index_map: Callable,\n",
+        "  memory_space: pl.MemorySpace\n",
+        ")\n",
+        "```\n",
+        "There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n",
+        "\n",
+        "### Kernel\n",
+        "\n",
+        "The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs; instead, all outputs should be written into the output buffers that are passed into the kernel. All input and output buffers are SRAM buffers by default (unless the user has overridden this behavior by specifying a `memory_space` on the corresponding `BlockSpec`).\n",
+        "\n",
+        "```python\n",
+        "def kernel(*buffers):  # input buffers first, followed by output buffers\n",
+        "  # ... perform compute\n",
+        "  # ... store result into output buffers\n",
+        "```\n",
+        "\n",
+        "The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`.\n",
+        "\n",
+        "\n",
+        "### Pallas Call\n",
+        "\n",
+        "The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature:\n",
+        "```python\n",
+        "def pallas_call(\n",
+        "  kernel,\n",
+        "  grid: tuple[int, ...],\n",
+        "  in_specs: Sequence[PyTree[BlockSpec]],\n",
+        "  out_specs: PyTree[BlockSpec],\n",
+        "  out_shape: PyTree[jax.ShapeDtypeStruct],\n",
+        ") -> Callable:\n",
+        "```\n",
+        "`pallas_call` will return a function that, when invoked with input values, returns outputs of the same shape as `out_shape`.\n",
+        "\n",
+        "`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match.\n"
+      ],
+      "metadata": {
+        "id": "niMr39cPkJ2m"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Example - Elementwise Kernel revisited\n",
+        "\n",
+        "Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration."
+      ],
+      "metadata": {
+        "id": "0mHZ63eAq_8j"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Note: This is a TPU example.\n",
+        "\n",
+        "total_shape = (4096, 4096)\n",
+        "block_shape = (512, 512)\n",
+        "\n",
+        "def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):\n",
+        "  o_ref[...] = x_ref[...] + y_ref[...]\n",
+        "\n",
+        "def add_matrices_pipelined(x: jax.Array, y: jax.Array):\n",
+        "  return pl.pallas_call(\n",
+        "      add_matrices_pipelined_kernel,\n",
+        "      grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),\n",
+        "      in_specs=[\n",
+        "          pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n",
+        "          pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))\n",
+        "      ],\n",
+        "      out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n",
+        "      out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),\n",
+        "  )(x, y)\n",
+        "\n",
+        "x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)\n",
+        "y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)\n",
+        "result = add_matrices_pipelined(x, y)\n",
+        "np.testing.assert_array_equal(\n",
+        "    result, x + y\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "iqr_qjONAHN9"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "It turns out that with this API, writing a pipelined kernel is not many more lines of code than writing our original naive addition kernel!"
+      ],
+      "metadata": {
+        "id": "UWHD0_qm6DL7"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Parameterizing a Kernel\n",
+        "\n",
+        "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop, where each iteration has less work to do). Let's write a function that does so:"
+      ],
+      "metadata": {
+        "id": "BZ-4U6Cv6cvU"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def add_matrices_pipelined_param(\n",
+        "    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n",
+        ") -> jax.Array:\n",
+        "  m, n = x.shape\n",
+        "  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n",
+        "  return pl.pallas_call(\n",
+        "      add_matrices_kernel,\n",
+        "      out_shape=x,\n",
+        "      in_specs=[block_spec, block_spec],\n",
+        "      out_specs=block_spec,\n",
+        "      grid=(m // bm, n // bn),\n",
+        "  )(x, y)\n",
+        "\n",
+        "np.testing.assert_array_equal(\n",
+        "    add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y\n",
+        ")\n",
+        "np.testing.assert_array_equal(\n",
+        "    add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y\n",
+        ")\n",
+        "np.testing.assert_array_equal(\n",
+        "    add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "RZTAiwrZ6srD"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Sharp edges\n",
+        "\n",
+        "While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges arising from the use of intermediate buffers that are not fully hidden from the user and that can result in subtle bugs.\n",
+        "\n",
+        "### Buffer Revisiting\n",
+        "\n",
+        "In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers as write-only**.\n",
+        "\n",
+        "Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, its updated value is never read into SRAM. This issue is analogous to the staleness issues encountered when using caches in general.\n",
+        "\n",
+        "There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing the `input_output_aliases` argument to `pallas_call`.\n",
+        "\n",
+        "\n",
+        "### Reductions and accumulation\n",
+        "\n",
+        "**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**\n",
+        "\n",
+        "Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.\n",
+        "The Pallas pipeline emitter performs an optimization where, if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n",
+        "\n",
+        "As a concrete example, let's consider performing the following computation, which reduces an `(8, 1024, 1024)` array along the first axis into a `(1024, 1024)` array:\n",
+        "\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "vO8VkbYj_ral"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "x = jnp.ones((8, 1024, 1024))\n",
+        "jnp.sum(x, axis=0)"
+      ],
+      "metadata": {
+        "id": "4qz1ET-_f9fJ",
+        "executionInfo": {
+          "status": "ok",
+          "timestamp": 1744763773938,
+          "user_tz": 420,
+          "elapsed": 244,
+          "user": {
+            "displayName": "Justin Fu",
+            "userId": "17543197034567316452"
+          }
+        },
+        "outputId": "e43067ef-933a-45a5-912a-e224151cfa60"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "Array([[8., 8., 8., ..., 8., 8., 8.],\n",
+              "       [8., 8., 8., ..., 8., 8., 8.],\n",
+              "       [8., 8., 8., ..., 8., 8., 8.],\n",
+              "       ...,\n",
+              "       [8., 8., 8., ..., 8., 8., 8.],\n",
+              "       [8., 8., 8., ..., 8., 8., 8.],\n",
+              "       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 5
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first."
+      ],
+      "metadata": {
+        "id": "yX762DRrgCOG"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Note: This is a TPU example.\n",
+        "\n",
+        "# Warning: this implementation is incorrect!\n",
+        "def incorrect_sum_kernel(x_ref, o_ref):\n",
+        "  o_ref[...] += x_ref[...]\n",
+        "\n",
+        "def incorrect_sum(x: jax.Array,\n",
+        "                  block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n",
+        "  reduction_size, *out_shape = x.shape\n",
+        "  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))\n",
+        "  return pl.pallas_call(\n",
+        "      incorrect_sum_kernel,\n",
+        "      grid=grid,\n",
+        "      # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
+        "      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],\n",
+        "      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),\n",
+        "      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n",
+        "  )(x)\n",
+        "\n",
+        "result = incorrect_sum(x)\n",
+        "print(result)"
+      ],
+      "metadata": {
+        "id": "ZEi1_vQVf-81",
+        "executionInfo": {
+          "status": "ok",
+          "timestamp": 1744763774254,
+          "user_tz": 420,
+          "elapsed": 79,
+          "user": {
+            "displayName": "Justin Fu",
+            "userId": "17543197034567316452"
+          }
+        },
+        "outputId": "581744b7-ddc1-4dc1-98ec-03c852772eda"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "[[65. 65. 65. ... 66. 66. 66.]\n",
+            " [65. 65. 65. ... 66. 66. 66.]\n",
+            " [65. 65. 65. ... 66. 66. 66.]\n",
+            " ...\n",
+            " [71. 71. 71. ... 72. 72. 72.]\n",
+            " [71. 71. 71. ... 72. 72. 72.]\n",
+            " [71. 71. 71. ... 72. 72. 72.]]\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "This result is completely wrong!\n",
+        "\n",
+        "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values, and thus we need to initialize it to zeros before we begin accumulation.\n",
+        "\n",
+        "After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating that we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`."
+      ],
+      "metadata": {
+        "id": "MglScPDD9618"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Note: This is a TPU example.\n",
+        "\n",
+        "def correct_sum_kernel(x_ref, o_ref):\n",
+        "  @pl.when(pl.program_id(2) == 0)\n",
+        "  def _():\n",
+        "    o_ref[...] = jnp.zeros_like(o_ref)\n",
+        "  o_ref[...] += x_ref[...]\n",
+        "\n",
+        "def correct_sum(x: jax.Array,\n",
+        "                block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n",
+        "  reduction_size, *out_shape = x.shape\n",
+        "  # We moved the reduction to the last axis of the grid.\n",
+        "  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)\n",
+        "  return pl.pallas_call(\n",
+        "      correct_sum_kernel,\n",
+        "      grid=grid,\n",
+        "      # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
+        "      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],\n",
+        "      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),\n",
+        "      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n",
+        "  )(x)\n",
+        "\n",
+        "result = correct_sum(x)\n",
+        "print(result)"
+      ],
+      "metadata": {
+        "id": "XtgD4nMa9_Bd",
+        "executionInfo": {
+          "status": "ok",
+          "timestamp": 1744763774523,
+          "user_tz": 420,
+          "elapsed": 104,
+          "user": {
+            "displayName": "Justin Fu",
+            "userId": "17543197034567316452"
+          }
+        },
+        "outputId": "9ef07cdf-9e22-4dc8-c17f-c96172639801"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "[[8. 8. 8. ... 8. 8. 8.]\n",
+            " [8. 8. 8. ... 8. 8. 8.]\n",
+            " [8. 8. 8. ... 8. 8. 8.]\n",
+            " ...\n",
+            " [8. 8. 8. ... 8. 8. 8.]\n",
+            " [8. 8. 8. ... 8. 8. 8.]\n",
+            " [8. 8. 8. ... 8. 8. 8.]]\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "\n",
+        "## Analyzing the performance\n",
+        "\n",
+        "What is the performance of a pipelined kernel? The answer depends on where the bottleneck in the hardware lies. We are typically interested in 3 quantities:\n",
+        "- **Memory latency** $\\alpha$, the minimum latency of a memory transfer.\n",
+        "- **Memory bandwidth** $\\beta$, the rate in bytes/second at which we can transfer data from HBM to SRAM.\n",
+        "- **FLOP/s** $F$, or floating-point operations per second, the number of calculations per second that the processor can perform.\n",
+        "\n",
+        "We refer to a program as **compute-bound** if the processing speed (FLOP/s) is the bottleneck, and as **memory-bound** if the bandwidth or latency is the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.\n",
+        "\n",
+        "Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, both scale linearly with the problem size. However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.\n",
+        "\n",
+        "In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\\alpha + X/\\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply it by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady state of the pipeline. Assuming that $N$ is large and there is enough work to produce a long pipeline, the dominant term in the runtime is $N(Y/F)$, so the runtime is governed by $F$, the processing speed of the accelerator.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "BckuFg6qcnVw"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "\n",
+        "![compute_bound](<figure: compute-bound pipeline timeline; inline SVG data elided>)\n",
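+        "\n",
+        "To make the formula concrete, here is a rough back-of-the-envelope sketch in plain Python. The hardware constants below are hypothetical, chosen only to illustrate the bookkeeping (they are not measurements of any real chip), and the steady state is modeled as being limited by the slower of the memory and compute stages:\n",
+        "\n",
+        "```python\n",
+        "# Hypothetical hardware parameters -- illustrative only.\n",
+        "alpha = 1e-6  # memory latency, seconds\n",
+        "beta = 1e12   # HBM bandwidth, bytes/second\n",
+        "F = 2e14      # peak processing speed, FLOP/s\n",
+        "\n",
+        "# Per-iteration workload for a (512, 512) float32 elementwise add:\n",
+        "# two input blocks copied in, one output block copied out.\n",
+        "X = 3 * 512 * 512 * 4  # bytes transferred per iteration\n",
+        "Y = 512 * 512          # FLOPs per iteration (one add per element)\n",
+        "N = 64                 # number of pipeline iterations\n",
+        "\n",
+        "bubble = alpha + X / beta          # initial pipeline fill\n",
+        "steady = N * max(X / beta, Y / F)  # slower stage dominates each iteration\n",
+        "regime = 'compute-bound' if Y / F > X / beta else 'memory-bound'\n",
+        "print(regime, '- approx. total time:', bubble + steady, 'seconds')\n",
+        "```\n",
+        "\n",
+        "With these numbers the elementwise add comes out memory-bound, which matches the intuition above: it performs only one floating-point operation per element transferred.\n",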
0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m530.1496 587.09973l199.74805 0l0 27.433105l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m597.63904 604.1894l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 
0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm20.355896 0l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m729.89764 587.09973l121.82678 0l0 27.433105l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m754.80084 604.1894l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 
-3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 
6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m851.7244 614.53284l121.82678 0l0 27.433044l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m876.6276 631.6225l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 
-0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.265625l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m651.9764 614.53284l199.74805 0l0 27.433044l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m719.46576 631.6225l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 
2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m100.64829 416.6509l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m168.13771 433.74057l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 
-5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m300.39633 416.6509l121.82675 0l0 27.433075l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m325.29953 433.74057l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 
2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.625702 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m222.47507 444.08398l199.74802 0l0 27.433075l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m289.96448 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 
3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m422.22308 444.08398l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m447.1263 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 
0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m422.22308 416.6509l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m489.7125 433.74057l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 
-1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m621.9711 416.6509l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m646.8743 433.74057l1.609375 
0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 
-3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m544.04987 444.08398l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m611.5393 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 
-1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m743.7979 444.08398l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m768.7011 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 
[... remaining inline SVG path data omitted ...]</svg>)"
+ ],
+ "metadata": {
+ "id": "NDY4mcae_nMO"
+ }
+ },
"markdown", + "source": [ + "In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\\alpha + X / \\beta$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\\beta$ is orders of magnitude slower than the processing speed $F$." + ], + "metadata": { + "id": "HFWcaAudW4z1" + } + }, + { + "cell_type": "markdown", + "source": [ + "![bandwidth_bound](data:image/svg+xml;base64,<svg version="1.1" viewBox="0.0 0.0 1100.3333333333333 140.48556430446195" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l1100.3334 0l0 140.48557l-1100.3334 0l0 -140.48557z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l1100.3334 0l0 140.48557l-1100.3334 0z" fill-rule="evenodd"/><path fill="#cfe2f3" d="m83.22835 428.5328l199.74803 0l0 27.433075l-199.74803 0z" fill-rule="evenodd"/><path fill="#000000" d="m150.71776 445.62247l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 
+ {
+ "cell_type": "markdown",
+ "source": [
+ "![bandwidth_bound](data:image/svg+xml;base64,[... inline SVG path data omitted ...]
0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m570.53546 455.96588l199.74799 0l0 27.433075l-199.74799 0z" fill-rule="evenodd"/><path fill="#000000" d="m638.02484 473.05554l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 
-2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm20.355896 0l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m770.28345 455.96588l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m795.18665 473.05554l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 
1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m892.1102 483.39896l121.82678 0l0 27.433075l-121.82678 0z" 
fill-rule="evenodd"/><path fill="#000000" d="m917.0134 500.48862l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 
-3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.265625l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m692.3622 483.39896l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m759.8516 500.48862l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 
0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m61.349083 29.82677l199.74802 0l0 27.433071l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m128.83849 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.6718674 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.5781174 0 2.5781174 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8124924 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.6093674 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 
0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m261.0971 29.82677l121.82678 0l0 27.433071l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m286.0003 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 
1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m261.0971 57.259842l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m328.58652 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 
0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m460.84515 57.259842l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m485.74835 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 
2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.7812805 0.796875 0.7812805 2.453125l0 6.640625l-1.6406555 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m460.84515 29.82677l199.74805 0l0 27.433071l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m528.33453 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 
2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m660.5932 29.82677l121.82672 0l0 27.433071l-121.82672 0z" fill-rule="evenodd"/><path fill="#000000" d="m685.4964 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 
-1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 
0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m660.5932 57.259842l199.74799 0l0 27.433075l-199.74799 0z" fill-rule="evenodd"/><path fill="#000000" d="m728.0826 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 
-2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#d9ead3" d="m547.33813 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m577.5614 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 
0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625671 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm14.944763 -0.109375l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m410.69836 -172.0105l136.62988 0l0 27.433075l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m446.6287 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 
-1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.0781555 -5.640625l1.640625 0l-3.6875305 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.8906555 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m547.34076 -172.0105l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.6455 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 
-2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 
0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m683.98315 -172.0105l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m714.2064 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 
-1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.335327 -2.0625l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m547.3458 -145.22572l136.62988 0l0 27.433075l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m583.2762 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 
-0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm20.355896 0l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m683.9882 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m716.29297 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 
0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#d9ead3" d="m820.6306 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m850.8539 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 
-1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm11.585327 1.46875l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m683.9836 -117.91601l136.62994 0l0 27.433067l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m719.91394 
-100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m820.62604 -117.91601l136.62988 0l0 27.433067l-136.62988 
0z" fill-rule="evenodd"/><path fill="#000000" d="m852.9308 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 
0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.265625l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m957.26843 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m987.4917 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 
-0.015625l0 -1.1875l10.859436 0l0 1.1875l-10.859436 0zm11.281921 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.335327 -2.03125l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m40.562992 -50.43176l1004.3779 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m40.562992 -50.43176l998.3779 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m1038.9409 -48.780025l4.538086 -1.6517334l-4.538086 -1.6517334z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m404.46588 -50.430447l314.96063 0l0 42.015747l-314.96063 0z" fill-rule="evenodd"/><path fill="#000000" d="m546.40765 -23.510448l0 -11.78125l-4.40625 0l0 -1.578125l10.578125 0l0 1.578125l-4.40625 0l0 11.78125l-1.765625 0zm7.093872 -11.46875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 
fill-opacity="0.0" d="m765.91077 95.50131l107.77954 0l0 42.015747l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m806.4951 122.42131l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948914 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235107 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/></g></svg>)" + ], + "metadata": { + "id": "gqcCDsGg_sca" + } + }, + { + "cell_type": "markdown", + "source": [ + "If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or latency bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. 
The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages, so it is important to make the pipeline long enough that the bubble does not take up a substantial fraction of the total runtime; a rough numerical sketch of this trade-off follows the figure below.\n"
      ],
      "metadata": {
        "id": "V4YQCZf1W7X5"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "![latency_multi_stage](data:image/svg+xml;base64,...)"
      ],
      "metadata": {}
    },
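    {
      "cell_type": "markdown",
      "source": [
        "To make the bubble/stage trade-off concrete, the cell below is a minimal sketch in plain Python (not the Pallas API; `bubble_fraction`, `num_stages`, and `num_steps` are illustrative names). An S-stage pipeline run for N grid steps spends S - 1 of its N + S - 1 stage-times filling and draining, so the bubble fraction is (S - 1) / (N + S - 1), and more stages need proportionally more grid steps to amortize.\n"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def bubble_fraction(num_stages: int, num_steps: int) -> float:\n",
        "    # Total runtime in stage-times: the first result appears after\n",
        "    # num_stages stage-times (pipeline fill); after that, one result\n",
        "    # completes per stage-time.\n",
        "    total = num_steps + num_stages - 1\n",
        "    return (num_stages - 1) / total\n",
        "\n",
        "# More stages hide more latency but enlarge the bubble, so the grid\n",
        "# must be long enough to amortize the fill/drain cost.\n",
        "for stages in (2, 4, 8):\n",
        "    for steps in (16, 256):\n",
        "        print(f'stages={stages} steps={steps} '\n",
        "              f'bubble={bubble_fraction(stages, steps):.1%}')\n"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },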
9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.4687653 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.6406403 0zm15.105911 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m340.55905 241.05249l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m365.46225 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 
1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m262.6378 268.48557l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m330.1272 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 
-2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m462.38583 268.48557l121.82675 0l0 27.433075l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m487.28903 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 
1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m462.38583 241.05249l199.74802 0l0 27.433075l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m529.87524 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 
2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m662.13385 241.05249l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m687.03705 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 
0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 
-1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m584.2126 268.48557l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m651.702 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 
0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m783.96063 268.48557l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m808.86383 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 
-1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m575.7664 307.61154l96.724365 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m590.9264 307.61154l66.40442 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m590.9264 314.19156l-13.159973 0l0 -13.160004l13.159973 0l0 13.160004z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m657.3308 301.03156l13.159973 0l0 13.160004l-13.159973 0l0 -13.160004z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m570.23883 302.6929l107.77954 0l0 42.015747l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m610.8232 329.6129l0 
-13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948975 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235046 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m194.55118 158.44882l111.18109 0l0 27.433075l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m217.75713 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 
1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.85939 0l0 1.1875l-10.85939 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m305.73227 158.44882l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m330.6355 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 
-0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m427.55905 185.8819l121.82675 0l0 27.43306l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m452.46225 202.97154l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 
-2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656555 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375305 0 3.1562805 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187805 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.9062805 0 1.5469055 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906555 -2.65625l5.4062805 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312805 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043427 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 
[... SVG path data for the pipelining diagram elided ...]</svg>)"
+      ],
+      "metadata": {
+        "id": "Sj5PFl0s_yc6"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency.
On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via an argument to the pipeline emitter). See the platform-specific pipelining documentation for more details."
+      ],
+      "metadata": {
+        "id": "ar4NVxxFfKEb"
+      }
+    }
+  ]
+}
\ No newline at end of file
diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md
new file mode 100644
index 000000000000..92b2f2b6bcca
--- /dev/null
+++ b/docs/pallas/pipelining.md
@@ -0,0 +1,589 @@
+---
+jupyter:
+  jupytext:
+    text_representation:
+      extension: .md
+      format_name: markdown
+      format_version: '1.3'
+    jupytext_version: 1.16.4
+  kernelspec:
+    display_name: Python 3
+    name: python3
+---
+
+
+# Software Pipelining
+
+Software pipelining is an important performance-optimization technique that overlaps multiple asynchronous operations, even when there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will focus solely on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, then outline the Pallas API for writing pipelines, and finally go over some realistic examples using the API.
+
+This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see the [TPU](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html) or GPU (coming soon!) pipelining references.
+
+
+
+```python id="YkOjspo5BKPD"
+import jax
+from jax import numpy as jnp
+from jax.experimental import pallas as pl
+import numpy as np
+```
+
+
+## Memory Hierarchies
+
+The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that trade off capacity vs. latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:
+- **Registers** are the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.
+- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.
+SRAM on modern ML accelerators is typically in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).
+It's reasonable to expect SRAM access latency to be on the order of 10x longer than register access.
+- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, its access latency is roughly on the order of 10x that of SRAM.
+- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices.
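+
+To make these capacity figures concrete, a quick back-of-the-envelope check helps when sizing blocks. The helper below is an illustrative sketch added here (the name `array_size_mib` is made up for this example, not a Pallas utility):
+
+```python
+import numpy as np
+
+def array_size_mib(shape, dtype=np.float32) -> float:
+  # Number of elements times bytes per element, reported in MiB.
+  return float(np.prod(shape)) * np.dtype(dtype).itemsize / 2**20
+
+# An f32[2048, 2048] array is 16 MiB, so only a handful of blocks that size
+# fit in the tens of megabytes of SRAM on a modern accelerator.
+print(array_size_mib((2048, 2048)))  # 16.0
+```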
+
+
+
+![memory_hierarchy: a gradient bar running from Registers (fastest, lowest capacity) through SRAM/Caches and DRAM/HBM to Network (slowest, highest capacity)](data:image/svg+xml,<svg data elided>)
+
+
+
+
+
+In order to perform computation on values `x` and `y` that live in HBM, we need to:
+
+1. Copy the values `x` and `y` into SRAM.
+2. Load the values from SRAM into registers.
+3. Execute the computation and store the result into registers.
+4. Store the values in the output registers into SRAM.
+5. Copy the output values in SRAM back to HBM.
+
+Let’s implement a Pallas function that does just that!
+
+
+```python id="IrPhDFnT3Nvw" executionInfo={"status": "ok", "timestamp": 1744764235906, "user_tz": 420, "elapsed": 108, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4"
+# Note: This is a TPU example.
+
+def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):
+  # Load x and y from SRAM into registers
+  x_regs = x_sram_ref[:, :]
+  y_regs = y_sram_ref[:, :]
+  # Execute a vectorized add
+  z_regs = x_regs + y_regs
+  # Store the output values in registers back into SRAM
+  z_sram_ref[:, :] = z_regs
+
+
+def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
+  # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.
+  # It will then copy `x` and `y` from HBM into SRAM.
+  z = pl.pallas_call(
+      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
+  )(x, y)
+  # pallas_call will also copy the output from SRAM back into HBM.
+  return z
+
+
+x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
+add_matrices(x, y)
+```
+
+
+We've written two functions: `add_matrices_kernel` and `add_matrices`.
+
+`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from an SRAM `Ref` produces a value that lives in registers. Values in registers behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.
+
+The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_sram_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_sram_ref` to HBM, resulting in an output `jax.Array`.
+
+Pallas exposes access to lower-level memory spaces like SRAM, but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:
+
+- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.
+
+- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.
+
+With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.
+
+
+
+
+## Pipelining Basics
+
+
+How can we take advantage of the strengths of each type of memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.
+
+The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages), with a sketch of the resulting schedule just below:
+
+1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.
+2. **compute**: Load `X` into registers, compute a result, and store it in SRAM `Y`.
+3. **copy_out**: Copy the result `Y` back into HBM `A[i]`.
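+
+As a preview of the overlap discussed in the next paragraph, here is the resulting schedule as a minimal plain-Python sketch. The `copy_in`, `compute`, and `copy_out` helpers are hypothetical stand-ins (on real hardware the copies would be asynchronous DMAs running concurrently with compute); this is the pattern Pallas automates for us, not its actual API:
+
+```python
+# Hypothetical stand-ins so the sketch runs; on a real accelerator, copy_in
+# and copy_out would enqueue asynchronous copies rather than run serially.
+def copy_in(i): print(f"copy_in:  HBM A[{i}] -> SRAM X")
+def compute(i): print(f"compute:  X -> registers -> SRAM Y (block {i})")
+def copy_out(i): print(f"copy_out: SRAM Y -> HBM A[{i}]")
+
+def pipelined_loop(num_blocks):
+  copy_in(0)  # prologue: fetch the first block before any compute begins
+  for i in range(num_blocks):
+    if i + 1 < num_blocks:
+      copy_in(i + 1)  # stage (1) for block i+1 can proceed while...
+    compute(i)        # ...stage (2) runs on block i, and...
+    copy_out(i)       # ...stage (3) drains block i.
+
+pipelined_loop(3)
+```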
+
+Note that there is a data dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`.
+
+
+
+![pipeline_fig](data:image/png;base64,<base64 PNG data elided>)
bZ45rn7KILLkgaRoxImvr2zV6XeRbEvwPskBKWdDWXV+C1aY01kvobbvABhWxrWFJQh+Srr5JlW21l25f175+/HjfHsE8fOw3ItTeYL43uOLWTJ9u2YmGJT5IHDrTnsPuba2/cYYfsNQwZUnDN2DElLEEsj1rZxr2Qx2rEgwKKsH+xmiV3332332fMmDG+3YUVetF/+eWXbZum9CgE0Eowrt+4ceP8/r8zXybVdtZZZ/m2cIqOQgvX7sKSd999N9lggw1sm8IY13fvvfe2bQPN8yD8GYrpwp4dd9wxeeyxx2ybRpJoupA+h6sCFatZoqWXNdVH1x2ucKPQSH011UmBitqaq1myzz772DYFK+EIGV2T2rVqUXzdiLhqJSzpHLamwOvSQYOSmhkz8vu1NSwx9++a11/37XUTJ/rjuJqOYViyyDw3XV+VEGjq18+2h6P447Ck5u2388e87jrfr/bpp317OLoeO6yEJV1NF5YosHDz46T+0BuDIWdKSN0+zYYlDz6Y/4MuEpbU/f73Bf39cYxV//xn9npyYcmSww4r6KsRIf4movBmQfGwpGndde3nxh/8wM9LlBplYn/OVhRpwvaXsASxvGoaikZDKGDQ6BDzGPeGtUeKhSXh1BLV+dB0HqnpKa49HPEhNWpFU3iOOuqogukvLkTZb7/97GeNJAn30ygWN2IlLPDq6q9om2qPaCSMGwWjmiiun6bjaGRHqMKQsI6IVF0S/YwKfMLgQhYLS5wKRBT+XHjhhfb63LVKBSrq01xY4qYXuSlRzt122822b2Ke1/H5EHHVSljSOQzDEk17Cd9hFp90UtK0/vp2m6bLqF6i3a+NYcmyb3+78PxffunLCiw2zx57PUFYUvvyywX9/YIZvXr5tjgs0agTt3/DxRfn32FuvdW3a2R+wXVgR5SwpKvZUs0Sza1busce2T/Snj3tXDq1+5CjR4+C/uGNq1hYUjAUzqihbG5b9ccfZ6/HFXjN3Xyc1RUV+ZvIyJG2LRWWzJ3rPzdn0wYbFP6M2CElLEEsvapPoqDgmWeeKWjXC79e9s2j3OqmichiYcmQIUN83+a8ytzv1feRRx6xYYBr11SbLbfc0n92hVhdQKDpPvF1KzTQtjAsUeFUdwyFEap3ov+OV/NxozRCNf1F4Yqm3bhpMKEa6RGOFmkuLFGhVhd4SE3XCZdmVoijfs2FJXGtklj9ruLfBSKuWglLOoct1iwx1rz6qt+u8MS2h2HJeecV9HfTZYqFJY3bbps6vgtjbKkCXU9Y4DUaAbLoF7/w29yo9zgsqR81yvdpTr03xdeBHU7Ckq5mi2GJcdEZZ+T/+KdPz7YFI0L8VBdjmH4WC0viAqu2Fom29ejhV7rxBV733rugb3jT0xxAtaXCEtOm+Yn6rDolBUshO6dMKTgudkwJSxBLq6aVrLHGGuZ2mS2aWmy7CpNq+1577eXbi4UlqlGiNqmwQ9NmYlXYVcd0o0gUmNxzzz121MbDDz/s93cjUFy9EBVMDafAhNcdLx3sptJouoobmbLrrrsW9GluZInbrulCqlmiYrOuBkl8rmJhiVbEcUGLrl3hjdrPPfdcfww3QqW5sKRPnz62TdcU//4kywwjll7Cks7h8sISO/rD1TA092rb9vXXfh+934T9XWmAYmGJSgOExVkXfPFFfmTJ0KHZ6wnCkrpJkwqOveSoo7LHWXvt/PnisMQ8+9z+C++6K/3+YtSqo+FxsUNKWNLVbCks0fw5Tc+xf7zmpuCmyjRccon/g642DxbX3w8zyxQPS5aaL60a/aH2qs8/T5aZL5z23MHUGF/gtVev/LA546Izz8wfO7d+ebGwRCGLPtuaJW6VHXOD00gVHaPuoYdSPyd2PAlLEEtvWKvkyiuvLNgWTqE5+eSTfbsLEMLw4I477vB9w5olv/nNb2xwoLonr776qp3q4/oNHz7c97vW3MtduxtZoromrk1LBhfrG4clF198sW1XwOJWpdE0lrBPMbVqjqbx7LHHHva8rl3BjAq+6jg77bSTb3dtgwYN8m2aUuSu68knn/Tthx56qG93I0s01ci13Xvvvb6v+/9Dv2PXVx555JHJ4MGDbf2U8LoRcdVLWNI5bDEs+eKLpOHyy/32xeYZ5rZpWo7alhx8sG9TCOHCj2JhiT3HuHH5c5v7tmt3/xBcsBpOcOyqf/zDjmq3xw6KtMZhSc2bb/r9C2qWvPhissQ8A7QEcfVf/uLbscNKWNLVXF6BV2c4LaZuwgTf3rj99nYdcdUYcSvWyGJhiWqcKJRRsdbG73/ft2vFG389LiwxfVUBWsPk7LFzfZf++Me+b7GwROuk+7677prU33RTdgkvdy7zJTX+HWDHk7AEsfROnDjRjx6RWhlHoxrCGiKq+/Hss8/6fdyUGdXi0Co3mu6iUMFNrVFIMWzYMBu2aBUbtWmpYE3t0VQXt8KNptfo5V/9tOSvO9/o0aPtebRUsdtf16ARIccdd5zd343giMMS7RMuNaxr0Tnjn7uYO+ywg91H+2ukjEKWU0891Ra5VbvaXF8FJ66vfgcjRoxIbr31Vn9ejTjR6BEFHCrs6tpfeeUVu79GiLi2bbbZxi59rGKyGmnj+m+77bY2/NHxXV/Vg4mvGxFXrYQlncPWFHi1mmec+0dWqXcD226eI4vNs0r/ALzMPPvcO0zRsMT0TbRCjem/+JRT/GqhemdxI+Pj1XAUjKguiv7x1rWHSw3HYYm9tlx9RR1f/8Bbb56RWrZYbU3BubBDS1jS1VxuWGK+uGkVG40E8ftp/XJzEwj7ae6eCq+6z8XCElWM1uo64X52+Fq47nguLNFNQssEh301AqV69mzft1hYYtvNeTQyJdzXHlNFasNhdNhhJSxBLI96yXfTP2I15UajRsL+p5gvimEfLTGsdo0acSMuQhWKhPU+9MIf91EgoIKq+m/VP3F9NfUkDFKkAhJXzyQOS+T+++/v+2oqTby9ObUqkJvGE6vzhEVew6k18uijj7aBkUKSsF01SBSYuM+333673X/GjBl+OpPT1VW5/PLLCwIf5xFHHFEwHQkRSyNhSeewNWGJAoZ4RLmms7gp+05N3Vdoof8uFpYsOeQQ26fg2H37JrVTp+avJwhL9A+1bqSK1bxLqWhreB3FwhLVhgwX1/DnMj9H7bRpqd8BdkgJS7qa+uOL58R5zZfHqs8+S+1jnTs3qXvsMVt1WkPQNEVHgYrbt+q//7X9tPyva9OwOA0hU1KqIkXF/vBVQdrOy9NQM9N/4bhx2XM88EB+Wk3O6jlz8scusm2h+ZKvfRWeaEpRfC7suBKWIJZPBQEKTbRqjEZQKAy47bbbUqvASK1Go2kk6qepNM8995zfptEjWtpX+2u7puEUO8aECRNsH51v/Pjxtk11QxSOPGS+2IZ9NRrjMnMf15K8bglh1zcMYZzh9CFdS7y9JRVGaHSHAh1dv6bkqJ5K3E/q9+P6uGk32l81V1y7C0A0kkTXq0DJ7a/iuuqjvjfeeKNd/thtm2q+gI8aNcpuU6FdFcWNz4+IpZGwpHOoYCH13hKoa
S2qURLvJ/WOoH/gbRgxIqk1zzDbZu7ffr9cP7072DaVBZg/3y5Uoek9qp1Y9emnhddTWZl/J9HxzLOyfvTopMHcywuWL87p3r9qZs4s3GauWe31V1+dff8ZO9a/U2GnkLAEEUsvYQkirqgKIDSSxXwFSTbaaCNGYiDiCktYgogrIWEJIpZewhJEbK0qCquQxBV1ldT3QMS2SFiCiCshYQkill7CEkRsrWHBVBVkdTVE4n6IiMuTsAQRV0LCEkQsvYQliNhaVUdF9UBUy0TLAMfbERFbK2EJIq6EhCWIWHoJSxAREbHcEpYg4kpIWIKIpZewBBEREcstYQkiroStD0vmzZs3wOwwEBFxRZ0zZ85mFRUVAxERERHLpb5/xN9JEBFbo/KPOBMBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWDH6xA0AAAAAAAAAAN2ZBuOjxj3iDQAAAAAAAAAA3ZGmTDYwWWD8l/H8DKNNOhyjjAkiIiIiIiIilkWFJe6/G3P/W2O8yNgzAwAAAAAAAADQzViayY4qmWY8z7iLsXdBDwAAAAAAAACAbgTBCAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAl2JN48BW2ts4wNg/AwAAAAAAAADQRTnEmLTSvYx/M75v9wQAAAAAAAAA6IJsahwaWWX8vyLtmxgnGO+xewIAAAAAAAAAdBMUlFTGjQAAAAAAAAAA3ZWWwpJzjKcFn0/KZEedrGe8zDjeeIdx29z2rY235NqvNvbNtYeoz5hMts+9xiMKNwMAAAAAAAAAtC8thSVxzZIpxg+Nnxj/bXzTuMhYazw597+zjB9nsnVP1Hc1u2eWIcYlxi+NT2Syx1a/SVE/AAAAAAAAAIB2Y0XDEoUbD2Ty4YZGhqityXhYrk3cl2vfPfd5i0w2WHnNuI7rZBiWyfa7IGgDAAAAAAAAAGg32hKWfDNoWyOTDUreCdrEMZls32Nzn0fnPu/se+T5cyZ7LgAAAAAAAACAdmdFwxJNtYlpNP4hajs4kw1Hjs99npz7rBV2VK8k9H8y2cBFwQsAAAAAAAAAQLuyomHJV8Fnh8KSh6O2OCyZYVycyU7Dac4Ncn0BAAAAAAAAANqNcoUlz2ayxV17+x4AAAAAAAAAAB2QcoUlw3OfT/A9sqhQrFbVmWbsEW0DAAAAAAAAACg75QpL+ho/N84zHprJBiN9jGNz/Ubm+gEAAAAAAAAAtCvlCkvEDsY5uXanCrveZewZ9AMAAAAAAAAAaDe+bRwQN+ZQe//g8yaZbP+YzY39orY1jQONa0ftmnazq3FoJhukbFa4GQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKOKqGIAAAAdSURBVAAAAAAAAAAAAAAAAAAAAAAAAAAAAKBj8P+i3c1w7pzeqAAAAABJRU5ErkJggg==) + + + +The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally "hide" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible. + +The initial startup time and final teardown time known as "bubbles", where only a subset of the stages are being executed while the pipeline is being "filled" or "drained". The bulk of the time is spent in the "steady-state" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor. + + + +### Deriving a Double-Buffered Pipeline + +Now lets look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`: + +
+```
+for i in range(N):
+  copy_in(A[i], X)
+  Y = X + 1
+  copy_out(Y, A[i])
+```
+
+The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations, so we are forced to wait for the copies to finish while the GPU/TPU sits idle, and then perform compute while the memory is idle. What we would like to do instead is to asynchronously "pre-fetch" the input value required on the next iteration of the loop while performing the computation for the current iteration, so that compute and memory communication happen simultaneously.
+
+In order to reason about the code transformation we will make, let's unroll the loop for N=4 and decompose the copy instructions into separate `copy_start` and `copy_wait` operations so that we can express asynchrony:
+```
+  # Itr 1
+  copy_in_start(A[0], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[0])
+  copy_out_wait(Y)
+
+  # Itr 2
+  copy_in_start(A[1], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[1])
+  copy_out_wait(Y)
+
+  # Itr 3
+  copy_in_start(A[2], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[2])
+  copy_out_wait(Y)
+
+  # Itr 4
+  copy_in_start(A[3], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[3])
+  copy_out_wait(Y)
+```
+
+Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible and `copy_wait` instructions as late as possible (right before we need the value). However, in the current state of the loop there is a false data dependency through X: we cannot start an asynchronous copy into X while X is still being used for computation, or else we have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers we could push it 2 iterations ahead, and so on; a sketch of the general K-buffer pattern appears at the end of this derivation), and we rewrite our loop as follows:
+```
+  # Prologue
+  copy_in_start(A[0], X[0])
+
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+  copy_out_wait(Y[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[1])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[0])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[1])
+```
+ +Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration. + +
+```
+  # Prologue
+  copy_in_start(A[0], X[0])
+
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[0])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[1])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[0])
+
+  # Epilogue
+  copy_out_wait(Y[1])
+```
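+
+As an aside, the same transformation generalizes from 2 to K buffers. The sketch below is our own illustration rather than part of the derivation above; it assumes K SRAM buffers for each of X and Y, and K = 2 recovers the loop above (modulo exactly where the waits are placed):
+
+```
+  # Prologue: start the first K-1 input copies.
+  for j in range(K - 1):
+    copy_in_start(A[j], X[j])
+
+  for i in range(N):
+    cur_slot = i % K
+    # Prefetch the input needed K-1 iterations from now.
+    if i + K - 1 < N:
+      copy_in_start(A[i + K - 1], X[(i + K - 1) % K])
+    copy_in_wait(X[cur_slot])
+    # Before overwriting Y[cur_slot], drain the copy issued K iterations ago.
+    if i >= K:
+      copy_out_wait(Y[cur_slot])
+    Y[cur_slot] = X[cur_slot] + 1
+    copy_out_start(Y[cur_slot], A[i])
+
+  # Epilogue: drain the remaining outstanding output copies.
+  for j in range(max(N - K, 0), N):
+    copy_out_wait(Y[j % K])
+```
+
+Extra buffers let us tolerate more latency between issuing a copy and needing its result, at the cost of extra SRAM.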

Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:

```
# Prologue
copy_in_start(A[0], X[0])

# Main loop
for i in range(N):
  cur_slot = i % 2
  next_slot = (i + 1) % 2

  if i + 1 < N:
    copy_in_start(A[i+1], X[next_slot])

  copy_in_wait(X[cur_slot])
  Y[cur_slot] = X[cur_slot] + 1
  copy_out_start(Y[cur_slot], A[i])

  if i > 0:
    copy_out_wait(Y[next_slot])

# Epilogue
copy_out_wait(Y[(N - 1) % 2])
```

If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:

- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.
- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.
- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffers. In our example the data slice was the identity function `lambda i: i`.

By allowing the user to specify these pieces of information, we can write a wide variety of programs following this pattern:
```python
def double_buffered_pipeline(
    grid: tuple[int, ...],
    kernel: Callable,
    in_slices: Callable,
    out_slices: Callable):
  # Prologue
  copy_in_start(in_hbm[in_slices(0)], in_sram[0])

  # Main loop
  grid_size = prod(grid)
  for i in range(grid_size):
    cur_slot = i % 2
    next_slot = (i + 1) % 2
    if i + 1 < grid_size:
      copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])
    copy_in_wait(in_sram[cur_slot])

    kernel(in_sram[cur_slot], out_sram[cur_slot])

    copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])
    if i > 0:
      copy_out_wait(out_sram[next_slot])

  # Epilogue
  copy_out_wait(out_sram[(grid_size - 1) % 2])
```

Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API.

## Pallas Pipelining API

Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.

### Grid

The program **grid** is a tuple of integers specifying the number of subproblems along each axis. The pipeline can be interpreted as a nested for-loop in which the bounds of each loop are given by the corresponding entry of the grid:

```
# For grid (N, M, K)
for n in range(N):
  for m in range(M):
    for k in range(K):
      kernel()
```

The kernel will be invoked a total of `prod(grid)` times. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).

### BlockSpecs

A BlockSpec specifies the size and slice of the data copied to the kernel on each subproblem. The basic constructor of `pl.BlockSpec` involves specifying `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape `block_shape`. The `memory_space` argument specifies which memory space to copy the inputs to; by default this will be SRAM.
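
To make the blocked-index convention concrete, here is a small plain-Python illustration of our own (not Pallas code; the 256x256 block shape is just an assumed example) of how a blocked index returned by `index_map` translates into element offsets:

```python
# Illustrative only: how a blocked index maps into the source buffer.
# With block_shape = (256, 256), the blocked index (i, j) selects
# rows i*256:(i+1)*256 and columns j*256:(j+1)*256 of the source array.
block_shape = (256, 256)

def block_to_element_slices(i: int, j: int) -> tuple[slice, slice]:
  bm, bn = block_shape
  return (slice(i * bm, (i + 1) * bm), slice(j * bn, (j + 1) * bn))

assert block_to_element_slices(1, 2) == (slice(256, 512), slice(512, 768))
```

The constructor has the following signature: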
```python
pl.BlockSpec(
    block_shape: tuple[int, ...],
    index_map: Callable,
    memory_space: pl.MemorySpace
)
```
There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).

### Kernel

The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs; instead, all outputs should be written into the output buffers that are passed into the kernel. All input and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`).

```python
def kernel(*buffers):
  # buffers consists of all input buffers followed by all output buffers.
  # ... perform compute
  # ... store result into output buffers
```

The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`.

### Pallas Call

The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature:
```python
def pallas_call(
    kernel,
    grid: tuple[int, ...],
    in_specs: Sequence[PyTree[BlockSpec]],
    out_specs: PyTree[BlockSpec],
    out_shape: PyTree[jax.ShapeDtypeStruct],
) -> Callable:
```
`pallas_call` returns a callable function that, when invoked with input values, will return outputs of the same shape as `out_shape`.

`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTree for `in_specs` should match that of the input buffers supplied to the kernel, and the PyTrees for `out_specs` and `out_shape` should also match.

### Example - Elementwise Kernel revisited

Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration.

```python id="iqr_qjONAHN9"
# Note: This is a TPU example.

total_shape = (4096, 4096)
block_shape = (512, 512)

def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add_matrices_pipelined(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
      add_matrices_pipelined_kernel,
      grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),
      in_specs=[
          pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
          pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))
      ],
      out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
      out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),
  )(x, y)

x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)
y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)
result = add_matrices_pipelined(x, y)
np.testing.assert_array_equal(
    result, x + y
)
```

It turns out that with this API, writing a pipelined kernel does not take many more lines of code than our original naive addition kernel!

### Parameterizing a Kernel

It's common to parameterize the block shapes in our kernel.
Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop, where each iteration has less work to do). Let's write a function that does so:

```python id="RZTAiwrZ6srD"
def add_matrices_pipelined_param(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)

np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y
)
```

## Sharp edges

While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges arising from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs.

### Buffer Revisiting

In general, a good rule of thumb to follow is that **the input buffers passed into the kernel function should be treated as read-only, and the output buffers as write-only**.

Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is read, its value was never copied in from HBM to begin with. This issue is analogous to staleness issues encountered when using caches in general.

There are two cases where a buffer supports both reads and writes: accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing the `input_output_aliases` argument to `pallas_call`.

### Reductions and accumulation

**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the output buffer should be initialized manually first.**

Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.
The Pallas pipeline emitter performs an optimization where, if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.

As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axis into a `(1024, 1024)` array.

```python id="4qz1ET-_f9fJ"
x = jnp.ones((8, 1024, 1024))
jnp.sum(x, axis=0)
```

To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first.

```python id="ZEi1_vQVf-81"
# Note: This is a TPU example.

# Warning: this implementation is incorrect!
def incorrect_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def incorrect_sum(x: jax.Array,
                  block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))
  return pl.pallas_call(
      incorrect_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = incorrect_sum(x)
print(result)
```

This result is completely wrong!

There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values, so we need to initialize it to zeros before we begin accumulation.

After fixing these two issues, we obtain the following corrected kernel. In the new kernel, we use `@pl.when` to create a conditional that checks whether the program ID along the reduction axis is `0`, indicating that we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`.

```python id="XtgD4nMa9_Bd"
# Note: This is a TPU example.

def correct_sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...]

def correct_sum(x: jax.Array,
                block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  # We moved the reduction to the last axis of the grid.
  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)
  return pl.pallas_call(
      correct_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = correct_sum(x)
print(result)
```

## Analyzing the performance

What is the performance of a pipelined kernel? The answer varies depending on where the bottleneck in the hardware lies. We are typically interested in 3 quantities:
- **Memory latency** $α$, the minimum latency of a memory transfer.
+- **Memory bandwidth** $β$, the rate in bytes/second that we can transfer from HBM to SRAM. +- **FLOP/s** $F$, or floating-point-operations per second, the number of calculations per second that the processor can perform. + +We refer to a program as **compute-bound** if the processing speed FLOPs/s is the bottleneck, and as **memory-bound** if the bandwidth or latency are the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware. + +Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically. + +In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\alpha + X/\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. Assuming that N is large and there is enough work to produce a long pipeline, the dominating term in the runtime is $F$, the processing speed of the accelerator. + + + + + + +![compute_bound](data:image/svg+xml;base64,<svg version="1.1" viewBox="0.0 0.0 1018.0314960629921 103.2782152230971" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l1018.0315 0l0 103.27821l-1018.0315 0l0 -103.27821z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l1018.0315 0l0 103.27821l-1018.0315 0z" fill-rule="evenodd"/><path fill="#cfe2f3" d="m42.842518 559.6667l199.74803 0l0 27.433044l-199.74803 0z" fill-rule="evenodd"/><path fill="#000000" d="m110.33193 576.75635l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.6249924 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 
2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.1406174 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.0937424 0 1.8749924 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8124924 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.81321 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m242.59055 559.6667l121.82678 0l0 27.433044l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m267.49374 576.75635l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 
1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m164.6693 587.09973l199.74803 0l0 
27.433105l-199.74803 0z" fill-rule="evenodd"/><path fill="#000000" d="m232.1587 604.1894l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.54685974 0.140625 0.93748474 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.6718597 -9.6875l1.7656097 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0624847 -0.203125zm7.8906097 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m364.41733 587.09973l121.82675 0l0 27.433105l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m389.32053 604.1894l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 
-1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.625702 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 
0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m286.49606 614.53284l199.74802 0l0 27.433044l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m353.98547 631.6225l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 
-0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m486.24408 614.53284l121.82678 0l0 27.433044l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m511.1473 631.6225l1.6093445 0.21875q-0.265625 1.65625 -1.3593445 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q0.9999695 0.796875 1.2812195 2.265625l-1.5937195 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.4062195 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 
-0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m573.9764 19.102362l121.82678 0l0 27.433073l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m598.8796 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 
1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m695.80316 46.149605l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m720.70636 63.239265l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 
-1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 
-2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m330.32285 46.535435l111.18109 0l0 27.433071l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m353.52878 63.62509l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 
-1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m452.1496 19.102362l111.18109 0l0 27.433073l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m475.35556 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 
1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m573.9764 46.149605l111.18109 0l0 27.433075l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m597.1823 63.239265l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 
In a **memory-bound** regime it is useful to identify whether the bottleneck is latency or bandwidth. If bandwidth is the bottleneck, the total runtime is $\alpha + X / \beta$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal, as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\beta$ is orders of magnitude slower than the processing speed $F$.
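To make the $\alpha + X / \beta$ model concrete, here is a minimal Python sketch that compares transfer time against compute time. The hardware constants (`alpha`, `beta`, `F`) are illustrative assumptions, not measurements from any particular device.

```python
# Minimal sketch of the latency/bandwidth cost model described above.
# All hardware numbers are illustrative assumptions, not measurements.

alpha = 1e-6   # fixed per-transfer latency in seconds (assumed)
beta = 1.6e12  # memory bandwidth in bytes/second (assumed)
F = 1.0e14     # processing speed in FLOP/s (assumed)

def transfer_time(x_bytes: float) -> float:
    """Runtime of a single memory transfer: alpha + X / beta."""
    return alpha + x_bytes / beta

def is_bandwidth_bound(x_bytes: float, flops: float) -> bool:
    """True when moving the data takes longer than computing on it."""
    return transfer_time(x_bytes) > flops / F

# Example: an elementwise add of two 1 GiB float32 arrays moves ~3 GiB
# (two reads plus one write) but performs only ~2.7e8 additions, so on
# this model it is firmly bandwidth-bound.
x_bytes = 3 * 2**30
flops = 2**30 // 4  # one add per float32 output element
print(transfer_time(x_bytes))          # ~2.0e-3 s to move the data
print(flops / F)                       # ~2.7e-6 s to do the math
print(is_bandwidth_bound(x_bytes, flops))  # True
```

Under these assumed numbers the data movement takes roughly a thousand times longer than the arithmetic, which is the idle-processor gap the paragraph above describes.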
*Figure (bandwidth_bound): execution timeline in the bandwidth-bound regime, with memory copies running back-to-back at the saturated bandwidth.*
-1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 
2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m892.1102 483.39896l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m917.0134 500.48862l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 
0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.265625l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m692.3622 483.39896l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m759.8516 500.48862l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 
0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m61.349083 29.82677l199.74802 0l0 27.433071l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m128.83849 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.6718674 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.5781174 0 2.5781174 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8124924 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.6093674 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 
-2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m261.0971 29.82677l121.82678 0l0 27.433071l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m286.0003 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 
-0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m261.0971 57.259842l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m328.58652 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 
-0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m460.84515 57.259842l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m485.74835 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 
-1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.7812805 0.796875 0.7812805 2.453125l0 6.640625l-1.6406555 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 
-1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m460.84515 29.82677l199.74805 0l0 27.433071l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m528.33453 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 
-0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m660.5932 29.82677l121.82672 0l0 27.433071l-121.82672 0z" fill-rule="evenodd"/><path fill="#000000" d="m685.4964 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 
0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m660.5932 57.259842l199.74799 0l0 27.433075l-199.74799 0z" fill-rule="evenodd"/><path fill="#000000" d="m728.0826 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 
-0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m860.3412 57.259842l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m885.2444 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 
-1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m576.0945 
0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.0781555 -5.640625l1.640625 0l-3.6875305 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.8906555 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" 
fill-rule="nonzero"/><path fill="#f4cccc" d="m547.34076 -172.0105l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.6455 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 
-1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m683.98315 -172.0105l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m714.2064 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 
0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.335327 -2.0625l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m547.3458 -145.22572l136.62988 0l0 27.433075l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m583.2762 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 
0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm20.355896 0l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m683.9882 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m716.29297 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 
-1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#d9ead3" d="m820.6306 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m850.8539 -128.13606l1.609375 0.21875q-0.265625 1.65625 
-1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 
-0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm11.585327 1.46875l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m683.9836 -117.91601l136.62994 0l0 27.433067l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m719.91394 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 
-2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m820.62604 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m852.9308 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 
9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.265625l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m957.26843 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m987.4917 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 
-0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859436 0l0 1.1875l-10.859436 0zm11.281921 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.335327 -2.03125l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m40.562992 -50.43176l1004.3779 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m40.562992 -50.43176l998.3779 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m1038.9409 -48.780025l4.538086 -1.6517334l-4.538086 -1.6517334z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m404.46588 -50.430447l314.96063 0l0 42.015747l-314.96063 0z" fill-rule="evenodd"/><path fill="#000000" d="m546.40765 -23.510448l0 -11.78125l-4.40625 0l0 -1.578125l10.578125 0l0 1.578125l-4.40625 0l0 11.78125l-1.765625 0zm7.093872 -11.46875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 
-1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm22.165771 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#efefef" d="m3.2650874 -84.173225l273.25986 0l0 27.433067l-273.25986 0z" fill-rule="evenodd"/><path fill="#ff0000" d="m109.64443 -76.89607l5.34375 0q1.59375 0 2.359375 0.140625q0.78125 0.125 1.390625 0.546875q0.625 0.421875 1.03125 1.125q0.40625 0.6875 0.40625 1.546875q0 0.9375 -0.5 1.734375q-0.5 0.78125 -1.375 1.171875q1.21875 0.359375 1.875 1.21875q0.65625 0.84375 0.65625 2.0q0 0.921875 -0.421875 1.78125q-0.421875 0.859375 -1.15625 1.375q-0.734375 0.5156288 -1.796875 0.6250038q-0.6875 0.078125 -3.265625 0.09375l-4.546875 0l0 -13.359379zm2.703125 2.234375l0 3.078125l1.765625 0q1.578125 0 1.953125 -0.046875q0.703125 -0.078125 1.09375 -0.46875q0.40625 -0.40625 0.40625 -1.046875q0 -0.625 -0.34375 -1.0q-0.34375 -0.390625 -1.015625 -0.484375q-0.40625 -0.03125 -2.3125 -0.03125l-1.546875 0zm0 5.296875l0 3.578125l2.5 0q1.453125 0 1.84375 -0.078125q0.609375 -0.109375 0.984375 -0.53125q0.375 -0.421875 0.375 -1.140625q0 -0.59375 -0.296875 -1.015625q-0.28125 -0.421875 -0.84375 -0.609375q-0.546875 -0.203125 -2.390625 -0.203125l-2.171875 0zm17.113564 5.828129l0 -1.4531288q-0.53125 0.78125 -1.390625 1.2343788q-0.8593674 0.4375 -1.8124924 0.4375q-0.96875 0 -1.75 -0.421875q-0.765625 -0.4375038 -1.125 -1.2031288q-0.34375 -0.78125 -0.34375 -2.140625l0 -6.125l2.5625 0l0 4.4375q0 2.046875 0.140625 2.515625q0.140625 0.453125 0.515625 0.71875q0.375 0.265625 0.953125 0.265625q0.65625 0 1.1718674 -0.359375q0.515625 -0.359375 0.703125 -0.890625q0.203125 -0.53125 0.203125 -2.609375l0 -4.078125l2.5468903 0l0 9.671879l-2.3750153 0zm4.927231 0l0 -13.359379l2.5625 0l0 4.8125q1.171875 -1.34375 2.796875 -1.34375q1.765625 0 2.921875 1.28125q1.15625 1.28125 1.15625 3.671875q0 2.484375 -1.1875 3.828125q-1.171875 1.3281288 -2.859375 1.3281288q-0.828125 0 -1.640625 -0.40625q-0.796875 -0.4218788 -1.375 -1.2343788l0 1.4218788l-2.375 0zm2.53125 -5.046879q0 1.5 0.484375 2.21875q0.65625 1.03125 1.765625 1.03125q0.84375 0 1.4375 -0.71875q0.59375 -0.734375 0.59375 -2.296875q0 -1.65625 -0.609375 -2.390625q-0.59375 -0.734375 -1.53125 -0.734375q-0.921875 0 -1.53125 0.71875q-0.609375 0.71875 -0.609375 2.171875zm8.864746 5.046879l0 -13.359379l2.5625 0l0 4.8125q1.171875 -1.34375 2.796875 -1.34375q1.765625 0 2.921875 1.28125q1.15625 1.28125 1.15625 3.671875q0 2.484375 -1.1875 3.828125q-1.171875 
1.3281288 -2.859375 1.3281288q-0.828125 0 -1.640625 -0.40625q-0.796875 -0.4218788 -1.375 -1.2343788l0 1.4218788l-2.375 0zm2.53125 -5.046879q0 1.5 0.484375 2.21875q0.65625 1.03125 1.765625 1.03125q0.84375 0 1.4375 -0.71875q0.59375 -0.734375 0.59375 -2.296875q0 -1.65625 -0.609375 -2.390625q-0.59375 -0.734375 -1.53125 -0.734375q-0.921875 0 -1.53125 0.71875q-0.609375 0.71875 -0.609375 2.171875zm8.974091 5.046879l0 -13.359379l2.5625 0l0 13.359379l-2.5625 0zm10.777069 -3.0781288l2.546875 0.421875q-0.484375 1.40625 -1.546875 2.140625q-1.0625 0.7343788 -2.65625 0.7343788q-2.515625 0 -3.734375 -1.6562538q-0.953125 -1.3125 -0.953125 -3.328125q0 -2.40625 1.25 -3.765625q1.265625 -1.359375 3.1875 -1.359375q2.15625 0 3.40625 1.421875q1.25 1.421875 1.1875 4.375l-6.40625 0q0.03125 1.140625 0.609375 1.78125q0.59375 0.625 1.484375 0.625q0.59375 0 1.0 -0.328125q0.421875 -0.328125 0.625 -1.0625zm0.15625 -2.59375q-0.03125 -1.109375 -0.578125 -1.6875q-0.546875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.390625 0.609375q-0.546875 0.609375 -0.53125 1.65625l3.828125 0z" fill-rule="nonzero"/><path fill="#cccccc" d="m276.5223 -84.173225l546.61414 0l0 27.433067l-546.61414 0z" fill-rule="evenodd"/><path fill="#000000" d="m495.029 -67.88045l2.625 -0.25q0.234375 1.3125 0.953125 1.9375q0.734375 0.609375 1.96875 0.609375q1.296875 0 1.953125 -0.546875q0.671875 -0.546875 0.671875 -1.28125q0 -0.484375 -0.28125 -0.8125q-0.28125 -0.328125 -0.96875 -0.578125q-0.484375 -0.15625 -2.171875 -0.578125q-2.15625 -0.546875 -3.03125 -1.328125q-1.234375 -1.09375 -1.234375 -2.6875q0 -1.015625 0.578125 -1.90625q0.578125 -0.890625 1.65625 -1.34375q1.09375 -0.46875 2.640625 -0.46875q2.515625 0 3.78125 1.109375q1.28125 1.09375 1.34375 2.9375l-2.703125 0.109375q-0.171875 -1.03125 -0.75 -1.46875q-0.5625 -0.453125 -1.703125 -0.453125q-1.171875 0 -1.84375 0.46875q-0.421875 0.3125 -0.421875 0.84375q0 0.46875 0.40625 0.796875q0.5 0.4375 2.46875 0.90625q1.96875 0.453125 2.90625 0.953125q0.953125 0.5 1.484375 1.359375q0.53125 0.859375 0.53125 2.125q0 1.15625 -0.640625 2.15625q-0.640625 1.0 -1.8125 1.4843788q-1.15625 0.484375 -2.890625 0.484375q-2.53125 0 -3.890625 -1.1718788q-1.359375 -1.171875 -1.625 -3.40625zm17.552917 -5.328125l0 2.03125l-1.7499695 0l0 3.90625q0 1.1875 0.046875 1.390625q0.046875 0.1875 0.21875 0.3125q0.1875 0.125 0.4375 0.125q0.359375 0 1.0312195 -0.25l0.21875 2.0000038q-0.8905945 0.375 -2.0155945 0.375q-0.703125 0 -1.265625 -0.234375q-0.546875 -0.234375 -0.8125 -0.5937538q-0.25 -0.375 -0.34375 -1.0q-0.09375 -0.453125 -0.09375 -1.8125l0 -4.21875l-1.171875 0l0 -2.03125l1.171875 0l0 -1.921875l2.578125 -1.5l0 3.421875l1.7499695 0zm7.3689575 6.59375l2.546875 0.421875q-0.484375 1.40625 -1.546875 2.140625q-1.0625 0.7343788 -2.65625 0.7343788q-2.515625 0 -3.734375 -1.6562538q-0.953125 -1.3125 -0.953125 -3.328125q0 -2.40625 1.25 -3.765625q1.265625 -1.359375 3.1875 -1.359375q2.15625 0 3.40625 1.421875q1.25 1.421875 1.1875 4.375l-6.40625 0q0.03125 1.140625 0.609375 1.78125q0.59375 0.625 1.484375 0.625q0.59375 0 1.0 -0.328125q0.421875 -0.328125 0.625 -1.0625zm0.15625 -2.59375q-0.03125 -1.109375 -0.578125 -1.6875q-0.546875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.390625 0.609375q-0.546875 0.609375 -0.53125 1.65625l3.828125 0zm6.5319824 -1.046875l-2.328125 -0.421875q0.40625 -1.40625 1.359375 -2.078125q0.953125 -0.671875 2.84375 -0.671875q1.703125 0 2.546875 0.40625q0.84375 0.40625 1.171875 1.03125q0.34368896 0.625 0.34368896 2.28125l-0.015625 3.0q0 1.265625 0.109375 1.875q0.125 0.609375 0.46875 1.2968788l-2.531189 0q-0.109375 -0.25 -0.25 
d="m771.4383 100.419945l96.72443 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m786.5983 100.419945l66.40436 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m786.5983 106.99995l-13.160034 0l0 -13.159996l13.160034 0l0 13.159996z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m853.00275 93.83995l13.159973 0l0 13.159996l-13.159973 0l0 -13.159996z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m765.91077 95.50131l107.77954 0l0 42.015747l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m806.4951 122.42131l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948914 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235107 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/></g></svg>) + + + +If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or latency bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime. 
+
+![latency_multi_stage](…)
-1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m462.38583 268.48557l121.82675 0l0 27.433075l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m487.28903 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 
-0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m462.38583 241.05249l199.74802 0l0 27.433075l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m529.87524 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 
1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m662.13385 241.05249l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m687.03705 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 
2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 
-1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m584.2126 268.48557l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m651.702 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 
0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m783.96063 268.48557l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m808.86383 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 
-1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m575.7664 307.61154l96.724365 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m590.9264 307.61154l66.40442 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m590.9264 314.19156l-13.159973 0l0 -13.160004l13.159973 0l0 13.160004z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m657.3308 301.03156l13.159973 0l0 13.160004l-13.159973 0l0 -13.160004z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m570.23883 302.6929l107.77954 0l0 42.015747l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m610.8232 329.6129l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948975 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 
2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235046 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m194.55118 158.44882l111.18109 0l0 27.433075l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m217.75713 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.85939 0l0 1.1875l-10.85939 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 
-9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m305.73227 158.44882l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m330.6355 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 
-0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m427.55905 185.8819l121.82675 0l0 27.43306l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m452.46225 202.97154l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 
-2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656555 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375305 0 3.1562805 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187805 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.9062805 0 1.5469055 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906555 -2.65625l5.4062805 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312805 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043427 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m549.3858 158.44882l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m574.289 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 
-2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 
2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m671.2126 185.49606l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m696.1158 202.58572l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 
[inline SVG path data elided: pipelining figure]</svg>)

Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via an argument to the pipeline emitter), as in the sketch below. See the platform-specific pipelining documentation for more details.
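As a minimal sketch (the kernel, block size, and stage count below are illustrative, and the exact parameter names may differ across JAX versions), the pipeline depth on the Triton backend is requested through `compiler_params`:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def double_kernel(x_ref, o_ref):
  # Elementwise compute on the block currently resident in SRAM.
  o_ref[...] = x_ref[...] * 2

x = jnp.arange(4096, dtype=jnp.float32)
y = pl.pallas_call(
    double_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(8,),
    in_specs=[pl.BlockSpec((512,), lambda i: (i,))],
    out_specs=pl.BlockSpec((512,), lambda i: (i,)),
    # Ask the Triton backend for a three-stage pipeline instead of the default.
    compiler_params=plgpu.TritonCompilerParams(num_stages=3),
)(x)
```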
From 8073d78aa9016e3ab1051c8ad508826eee12741d Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Wed, 16 Apr 2025 11:57:45 -0700
Subject: [PATCH 0650/1769] Update job name to show "build only" on Linux ARM64.

Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64.
Instead, we cross-compile the test targets on the Linux x86 RBE pool.
The job name on Linux Arm64 runs will now show "build only" to avoid any confusion. Also, run only a single Python version for Linux Arm64 PiperOrigin-RevId: 748374021 --- .github/workflows/bazel_cpu_rbe.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index d6816d492d1d..ef5084960b30 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -46,7 +46,10 @@ jobs: enable-x_64: 1 - python: "3.13" enable-x_64: 0 - name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" + # Only test a single Python version on Arm64 as we don't run the tests. + - python: "3.10" + runner: "linux-arm64-c4a-16" + name: "Bazel CPU ${{ (contains(matrix.runner, 'linux-arm64') && 'build only' || 'tests') }} (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" # End Presubmit Naming Check github-cpu-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -54,5 +57,7 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CPU Tests with RBE + # Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64. Instead, we + # cross-compile the tests on the Linux x86 RBE pool. + - name: ${{ (contains(matrix.runner, 'linux-arm64') && 'Build' || 'Run') }} Bazel CPU Tests with RBE run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file From 7f7b5c6e9ca560895fb9f451145e79b7fc5dd8c8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 16 Apr 2025 12:38:30 -0700 Subject: [PATCH 0651/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0d1b60216ea13b0d261d59552a0f7ef20c4f76c5. PiperOrigin-RevId: 748387300 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d0cca161bd2d..0ac09a4a6594 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "4bc073325136932a634b3f972c5493b68f95a0d4" -XLA_SHA256 = "61c1b1116c94a7155dcf5bfe9a407befb266538f304735413f7f21c1719972ab" +XLA_COMMIT = "0d1b60216ea13b0d261d59552a0f7ef20c4f76c5" +XLA_SHA256 = "357b37cc7c439580344ce0305bad88ef841f29743a99ea8e2253e64a32e139c6" def repo(): tf_http_archive( From 4ad5dd83f772d41ff31b1a30628cfada73bf600b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 16 Apr 2025 10:14:46 -0700 Subject: [PATCH 0652/1769] Add explicit sharding to Introduction to parallel programming doc --- docs/conf.py | 1 - docs/notebooks/explicit-sharding.ipynb | 2 +- docs/notebooks/explicit-sharding.md | 2 +- docs/sharded-computation.ipynb | 320 ++++++++++++------------- docs/sharded-computation.md | 181 +++++++------- 5 files changed, 255 insertions(+), 251 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index cddb63653a17..52ff876b4958 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -232,7 +232,6 @@ def _do_not_evaluate_in_jax( 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', - 'sharded-computation.*', 'distributed_data_loading.*' ] diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index d0df799b51f6..dada1f0db507 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -28,7 +28,7 @@ "of work and it's also easy to make mistakes that way because there's no way to\n", "check that the shardings make sense together. More commonly, people add just\n", "enough sharding annotations to constrain the compiler. But this is a slow\n", - "iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n", + "iterative process. It's hard to know ahead of time what XLA's GSPMD pass will\n", "do (it's a whole-program optimization) so all you can do is add annotations,\n", "inspect XLA's sharding choices to see what happened, and repeat.\n", "\n", diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index 46315cc536d6..a091060393b6 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -31,7 +31,7 @@ constraints? You could put them on every single intermediate but that's a lot of work and it's also easy to make mistakes that way because there's no way to check that the shardings make sense together. More commonly, people add just enough sharding annotations to constrain the compiler. But this is a slow -iterative process. It's hard to know ahead of time what XLA's gSPMD pass will +iterative process. It's hard to know ahead of time what XLA's GSPMD pass will do (it's a whole-program optimization) so all you can do is add annotations, inspect XLA's sharding choices to see what happened, and repeat. diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index d3ddac4edbdb..568a0d4c6e3d 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -13,21 +13,44 @@ "\n", "The tutorial covers three modes of parallel computation:\n", "\n", - "- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. 
\"the compiler takes the wheel\").\n", - "- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n", - "- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", + "- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", + "- *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n", + " you're writing a global-view program. The difference is that the sharding\n", + " of each array is part of the array's JAX-level type making it an explicit\n", + " part of the programming model. These shardings are propagated at the JAX\n", + " level and queryable at trace time. It's still the compiler's responsibility\n", + " to turn the whole-array program into per-device programs (turning `jnp.sum`\n", + " into `psum` for example) but the compiler is heavily constrained by the\n", + " user-supplied shardings.\n", + "- _Fully manual sharding with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", "\n", - "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n", + "A summary table:\n", "\n", - "If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)." + "| Mode | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|\n", + "| Auto | No | No |\n", + "| Explicit (new) | Yes | No |\n", + "| Manual | Yes | Yes |\n", + "\n", + "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7efa1e66", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "jax.config.update('jax_num_cpu_devices', 8)" ] }, { "cell_type": "code", "execution_count": 1, - "metadata": { - "outputId": "18905ae4-7b5e-4bb9-acb4-d8ab914cb456" - }, + "metadata": {}, "outputs": [ { "data": { @@ -48,7 +71,6 @@ } ], "source": [ - "import jax\n", "jax.devices()" ] }, @@ -84,7 +106,9 @@ } ], "source": [ + "import numpy as np\n", "import jax.numpy as jnp\n", + "\n", "arr = jnp.arange(32.0).reshape(4, 8)\n", "arr.devices()" ] @@ -264,51 +288,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UEObolTqw4pp" - }, - "source": [ - "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", - "\n", - "The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n", - "\n", - "To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aKNeOHTJnqmS", - "outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pinned_host\n", - "device\n" - ] - } - ], - "source": [ - "s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n", - "s_dev = s_host.with_memory_kind('device')\n", - "arr_host = jax.device_put(arr, s_host)\n", - "arr_dev = jax.device_put(arr, s_dev)\n", - "print(arr_host.sharding.memory_kind)\n", - "print(arr_dev.sharding.memory_kind)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jDHYnVqHwaST" - }, + "metadata": {}, "source": [ "## 1. Automatic parallelism via `jit`\n", "\n", @@ -402,146 +382,157 @@ "source": [ "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n", "\n", - "### 1.1 Sharding transformation between memory types\n", - "\n", - "The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n", + "## 2. Explicit sharding\n", "\n", - "#### Example 1: Pinned host to device memory\n", - "\n", - "In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory." + "The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n", + "the JAX-level _type_ of a value includes a description of how the value is sharded.\n", + "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n", + "scalar) using `jax.typeof`:" ] }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PXu3MhafyRHo", - "outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b" - }, + "execution_count": 9, + "metadata": {}, "outputs": [ + { + "data": { + "text/html": [ + "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
+       "                                                                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "device\n" + "[48. 52. 56. 60. 64. 68. 72. 76.]\n" ] } ], "source": [ - "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", - "out_dev = f(arr_host)\n", - "print(out_dev)\n", - "print(out_dev.sharding.memory_kind)" + "some_array = np.arange(8)\n", + "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")" ] }, { "cell_type": "markdown", - "metadata": { - "id": "LuYFqpcBySiX" - }, + "metadata": {}, + "source": [ + "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n", + "is almost _defined_ as \"the information about a value we have access to while\n", + "under a jit)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffe62839", + "metadata": {}, + "outputs": [], "source": [ - "#### Example 2: Device to pinned_host memory\n", + "@jax.jit\n", + "def foo(x):\n", + " print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n", + " return x + x\n", "\n", - "In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory." + "foo(some_array)" + ] + }, + { + "cell_type": "markdown", + "id": "74995421", + "metadata": {}, + "source": [ + "To start seeing shardings in the type we need to set up an explicit-sharding mesh." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qLsgNlKfybRw", - "outputId": "a16448b9-7e39-408f-b200-505f65ad4464" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 
7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "pinned_host\n" - ] - } - ], + "id": "e785a694", + "metadata": {}, + "outputs": [], "source": [ - "g = jax.jit(lambda x: x, out_shardings=s_host)\n", - "out_host = g(arr_dev)\n", - "print(out_host)\n", - "print(out_host.sharding.memory_kind)" + "from jax.sharding import AxisType\n", + "\n", + "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))" ] }, { "cell_type": "markdown", - "metadata": { - "id": "7BGD31-owaSU" - }, + "id": "8d81409c", + "metadata": {}, + "source": [ + "Now we can create some sharded arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4969cabd", + "metadata": {}, + "outputs": [], "source": [ - "## 2. Semi-automated sharding with constraints\n", + "replicated_array = np.arange(8).reshape(4, 2)\n", + "sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P(\"X\", None)))\n", "\n", - "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n", + "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c09acf7d", + "metadata": {}, + "source": [ + "We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n", + "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", + "axes\"\n", "\n", - "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" + "These shardings associated with JAX-level types propagate through operations. For example:" ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "outputId": "8468f5c6-76ca-4367-c9f2-93c723687cfd" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
-       "                                                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[48. 52. 56. 60. 64. 68. 72. 76.]\n" - ] - } - ], + "execution_count": null, + "id": "ab2f9500", + "metadata": {}, + "outputs": [], "source": [ + "arg0 = jax.device_put(np.arange(4).reshape(4, 1),\n", + " jax.NamedSharding(mesh, P(\"X\", None)))\n", + "arg1 = jax.device_put(np.arange(8).reshape(1, 8),\n", + " jax.NamedSharding(mesh, P(None, \"Y\")))\n", + "\n", "@jax.jit\n", - "def f_contract_2(x):\n", - " out = x.sum(axis=0)\n", - " sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - " return jax.lax.with_sharding_constraint(out, sharding)\n", + "def add_arrays(x, y):\n", + " ans = x + y\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"y sharding: {jax.typeof(y)}\")\n", + " print(f\"ans sharding: {jax.typeof(ans)}\")\n", + " return ans\n", "\n", - "result = f_contract_2(arr_sharded)\n", - "jax.debug.visualize_array_sharding(result)\n", - "print(result)" + "with jax.sharding.use_mesh(mesh):\n", + " add_arrays(arg0, arg1)" ] }, { "cell_type": "markdown", + "id": "dda3d0c5", "metadata": {}, "source": [ - "This gives you a function with the particular output sharding you'd like.\n", + "That's the gist of it. Shardings propagate deterministically at trace time and\n", + "we can query them at trace time.\n", "\n", "## 3. 
Manual parallelism with `shard_map`\n", "\n", @@ -757,7 +748,8 @@ "source": [ "You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n", "\n", - "If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:" + "If you shard the leading axis of both `x` and make `weights` fully replicated,\n", + "then the matrix multiplication will automatically happen in parallel:" ] }, { @@ -780,10 +772,8 @@ ], "source": [ "mesh = jax.make_mesh((8,), ('x',))\n", - "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - "\n", - "x_sharded = jax.device_put(x, sharding)\n", - "weights_sharded = jax.device_put(weights, sharding)\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))\n", "\n", "layer(x_sharded, weights_sharded, bias)" ] @@ -792,15 +782,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:" + "Alternatively, you can use explicit sharding mode too:" ] }, { "cell_type": "code", "execution_count": 17, - "metadata": { - "outputId": "bb63e8da-ff4f-4e95-f083-10584882daf4" - }, + "metadata": {}, "outputs": [ { "data": { @@ -814,13 +802,22 @@ } ], "source": [ + "explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))\n", + "\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))\n", + "\n", "@jax.jit\n", "def layer_auto(x, weights, bias):\n", - " x = jax.lax.with_sharding_constraint(x, sharding)\n", - " weights = jax.lax.with_sharding_constraint(weights, sharding)\n", - " return layer(x, weights, bias)\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"weights sharding: {jax.typeof(weights)}\")\n", + " print(f\"bias sharding: {jax.typeof(bias)}\")\n", + " out = layer(x, weights, bias)\n", + " print(f\"out sharding: {jax.typeof(out)}\")\n", + " return out\n", "\n", - "layer_auto(x, weights, bias) # pass in unsharded inputs" + "with jax.sharding.use_mesh(explicit_mesh):\n", + " layer_auto(x_sharded, weights_sharded, bias)" ] }, { @@ -871,6 +868,7 @@ "\n", "To learn about each SPMD method in-depth, check out these docs:\n", "- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n", + "- {doc}`../notebooks/explicit-sharding`\n", "- {doc}`../notebooks/shard_map`" ] } diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index b05eb8d5f66e..ae9f44aba832 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -20,18 +20,34 @@ This tutorial serves as an introduction to device parallelism for Single-Program The tutorial covers three modes of parallel computation: -- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). -- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint` -- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives +- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). 
+- *Explicit Sharding* (\*new\*) is similar to automatic sharding in that + you're writing a global-view program. The difference is that the sharding + of each array is part of the array's JAX-level type making it an explicit + part of the programming model. These shardings are propagated at the JAX + level and queryable at trace time. It's still the compiler's responsibility + to turn the whole-array program into per-device programs (turning `jnp.sum` + into `psum` for example) but the compiler is heavily constrained by the + user-supplied shardings. +- _Fully manual sharding with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives + +A summary table: + +| Mode | Explicit sharding? | Explicit Collectives? | +|---|---|---| +| Auto | No | No | +| Explicit (new) | Yes | No | +| Manual | Yes | Yes | Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices. -If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with). - ```{code-cell} -:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456 - import jax + +jax.config.update('jax_num_cpu_devices', 8) +``` + +```{code-cell} jax.devices() ``` @@ -46,7 +62,9 @@ In the simplest cases, arrays are sharded on a single device, as demonstrated be ```{code-cell} :outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a +import numpy as np import jax.numpy as jnp + arr = jnp.arange(32.0).reshape(4, 8) arr.devices() ``` @@ -90,31 +108,6 @@ print(arr_sharded) jax.debug.visualize_array_sharding(arr_sharded) ``` -+++ {"id": "UEObolTqw4pp"} - -The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. - -The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host. - -To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: aKNeOHTJnqmS -outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2 ---- -s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') -s_dev = s_host.with_memory_kind('device') -arr_host = jax.device_put(arr, s_host) -arr_dev = jax.device_put(arr, s_dev) -print(arr_host.sharding.memory_kind) -print(arr_dev.sharding.memory_kind) -``` - -+++ {"id": "jDHYnVqHwaST"} - ## 1. Automatic parallelism via `jit` Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. @@ -156,69 +149,76 @@ print(result) The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on. 
-### 1.1 Sharding transformation between memory types
-
-The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.
+## 2. Explicit sharding

-#### Example 1: Pinned host to device memory
-
-In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory.
+The main idea behind explicit shardings (a.k.a. sharding-in-types) is that
+the JAX-level _type_ of a value includes a description of how the value is sharded.
+We can query the JAX-level type of any JAX value (or Numpy array, or Python
+scalar) using `jax.typeof`:

```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: PXu3MhafyRHo
-outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b
----
-f = jax.jit(lambda x: x, out_shardings=s_dev)
-out_dev = f(arr_host)
-print(out_dev)
-print(out_dev.sharding.memory_kind)
+some_array = np.arange(8)
+print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
```

-+++ {"id": "LuYFqpcBySiX"}
+Importantly, we can query the type even while tracing under a `jit` (the JAX-level type
+is almost _defined_ as "the information about a value we have access to while
+under a jit").
+
+```{code-cell}
+@jax.jit
+def foo(x):
+  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
+  return x + x

-#### Example 2: Device to pinned_host memory
+foo(some_array)
+```

-In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory.
+To start seeing shardings in the type we need to set up an explicit-sharding mesh.

```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: qLsgNlKfybRw
-outputId: a16448b9-7e39-408f-b200-505f65ad4464
----
-g = jax.jit(lambda x: x, out_shardings=s_host)
-out_host = g(arr_dev)
-print(out_host)
-print(out_host.sharding.memory_kind)
+from jax.sharding import AxisType
+
+mesh = jax.make_mesh((2, 4), ("X", "Y"),
+                     axis_types=(AxisType.Explicit, AxisType.Explicit))
```

-+++ {"id": "7BGD31-owaSU"}
+Now we can create some sharded arrays:

-## 2. Semi-automated sharding with constraints
+```{code-cell}
+replicated_array = np.arange(8).reshape(4, 2)
+sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None)))
+
+print(f"replicated_array type: {jax.typeof(replicated_array)}")
+print(f"sharded_array type: {jax.typeof(sharded_array)}")
+```

-If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constrains how the intermediate values and outputs are distributed.
+We should read the type `f32[4@X, 2]` as "a 4-by-2 array of 32-bit floats whose first dimension
+is sharded along mesh axis 'X'. The array is replicated along all other mesh
+axes".

-For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:
+These shardings associated with JAX-level types propagate through operations. For example:
```{code-cell}
-:outputId: 8468f5c6-76ca-4367-c9f2-93c723687cfd
+arg0 = jax.device_put(np.arange(4).reshape(4, 1),
+                      jax.NamedSharding(mesh, P("X", None)))
+arg1 = jax.device_put(np.arange(8).reshape(1, 8),
+                      jax.NamedSharding(mesh, P(None, "Y")))

@jax.jit
-def f_contract_2(x):
-  out = x.sum(axis=0)
-  sharding = jax.sharding.NamedSharding(mesh, P('x'))
-  return jax.lax.with_sharding_constraint(out, sharding)
-
-result = f_contract_2(arr_sharded)
-jax.debug.visualize_array_sharding(result)
-print(result)
+def add_arrays(x, y):
+  ans = x + y
+  print(f"x sharding: {jax.typeof(x)}")
+  print(f"y sharding: {jax.typeof(y)}")
+  print(f"ans sharding: {jax.typeof(ans)}")
+  return ans
+
+with jax.sharding.use_mesh(mesh):
+  add_arrays(arg0, arg1)
```

-This gives you a function with the particular output sharding you'd like.
+That's the gist of it. Shardings propagate deterministically at trace time and
+we can query them at trace time.

## 3. Manual parallelism with `shard_map`

@@ -320,32 +320,38 @@ layer(x, weights, bias)

You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.

-If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:
+If you shard the leading axis of `x` and make `weights` fully replicated,
+then the matrix multiplication will automatically happen in parallel:

```{code-cell}
:outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5

mesh = jax.make_mesh((8,), ('x',))
-sharding = jax.sharding.NamedSharding(mesh, P('x'))
-
-x_sharded = jax.device_put(x, sharding)
-weights_sharded = jax.device_put(weights, sharding)
+x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
+weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))

layer(x_sharded, weights_sharded, bias)
```

-Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:
+Alternatively, you can use the explicit sharding mode:

```{code-cell}
-:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4
+explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))
+
+x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))
+weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))

@jax.jit
def layer_auto(x, weights, bias):
-  x = jax.lax.with_sharding_constraint(x, sharding)
-  weights = jax.lax.with_sharding_constraint(weights, sharding)
-  return layer(x, weights, bias)
-
-layer_auto(x, weights, bias)  # pass in unsharded inputs
+  print(f"x sharding: {jax.typeof(x)}")
+  print(f"weights sharding: {jax.typeof(weights)}")
+  print(f"bias sharding: {jax.typeof(bias)}")
+  out = layer(x, weights, bias)
+  print(f"out sharding: {jax.typeof(out)}")
+  return out
+
+with jax.sharding.use_mesh(explicit_mesh):
+  layer_auto(x_sharded, weights_sharded, bias)
```

Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product:

@@ -371,4 +377,5 @@ This tutorial serves as a brief introduction of sharded and parallel computation

To learn about each SPMD method in-depth, check out these docs:
- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`
+- {doc}`../notebooks/explicit-sharding`
- {doc}`../notebooks/shard_map`

From 9afc047bf02c52e17cdea2185300521987eb9da6 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Wed, 16 Apr 2025 16:18:10 -0400
Subject: [PATCH
0653/1769] Fix bug in argnums_partial_except when static_argnums is unsorted. --- jax/_src/api_util.py | 2 +- tests/api_test.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 451d2e490a15..163bade2065c 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -259,7 +259,7 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], dyn_args = tuple(args[i] for i in dyn_argnums) fixed_args = [] - for i in static_argnums: + for i in sorted(static_argnums): # TODO(shoyer): set allow_invalid=True permanently after static_argnames. if allow_invalid and i >= len(args): continue diff --git a/tests/api_test.py b/tests/api_test.py index 7590b6e6af71..2d1055516074 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4313,6 +4313,21 @@ def g(x, y): for i in range(3): # Loop verifies we exercise both Python and C++ dispatch self.assertEqual(2 * i, g(2, i), msg=i) + def test_make_jaxpr_static_argnums_order(self): + # https://github.com/jax-ml/jax/issues/28065 + def f(a, b, c): + x = a + c + y = b * c + z = x - y + return z + + for static_argnums in [(1, 0), (0, 1)]: + val = jax.jit(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(val, -2) + jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(jaxpr.eqns[0].invars[0].val, 1) + self.assertEqual(jaxpr.eqns[1].invars[0].val, 2) + def test_fastpath_cache_confusion(self): # https://github.com/jax-ml/jax/issues/12542 @jax.jit From b61f87f5bdfcadce57acfe489ea9ea8bdf5bf65d Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 16 Apr 2025 13:48:21 -0700 Subject: [PATCH 0654/1769] [Pallas] Move pipelining doc figures to static folder --- .../pallas/pipelining_bandwidth_bound.svg | 1 + .../pallas/pipelining_compute_bound.svg | 1 + docs/_static/pallas/pipelining_example.svg | 1 + .../pallas/pipelining_latency_multistage.svg | 1 + .../pallas/pipelining_mem_hierarchy.svg | 30 + docs/pallas/pipelining.ipynb | 1700 +++++++++-------- docs/pallas/pipelining.md | 20 +- 7 files changed, 899 insertions(+), 855 deletions(-) create mode 100644 docs/_static/pallas/pipelining_bandwidth_bound.svg create mode 100644 docs/_static/pallas/pipelining_compute_bound.svg create mode 100644 docs/_static/pallas/pipelining_example.svg create mode 100644 docs/_static/pallas/pipelining_latency_multistage.svg create mode 100644 docs/_static/pallas/pipelining_mem_hierarchy.svg diff --git a/docs/_static/pallas/pipelining_bandwidth_bound.svg b/docs/_static/pallas/pipelining_bandwidth_bound.svg new file mode 100644 index 000000000000..45b78a7ce35e --- /dev/null +++ b/docs/_static/pallas/pipelining_bandwidth_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_compute_bound.svg b/docs/_static/pallas/pipelining_compute_bound.svg new file mode 100644 index 000000000000..cb3b58eaef99 --- /dev/null +++ b/docs/_static/pallas/pipelining_compute_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_example.svg b/docs/_static/pallas/pipelining_example.svg new file mode 100644 index 000000000000..59ca5b433b11 --- /dev/null +++ b/docs/_static/pallas/pipelining_example.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_latency_multistage.svg b/docs/_static/pallas/pipelining_latency_multistage.svg new file mode 100644 index 000000000000..2c40f1692b9a --- /dev/null +++ b/docs/_static/pallas/pipelining_latency_multistage.svg 
@@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_mem_hierarchy.svg b/docs/_static/pallas/pipelining_mem_hierarchy.svg new file mode 100644 index 000000000000..d7a2e6cbabd8 --- /dev/null +++ b/docs/_static/pallas/pipelining_mem_hierarchy.svg @@ -0,0 +1,30 @@ + + + + + + + + + + + + Registers + SRAM/Caches + DRAM/HBM + Network + + Fastest + Fast + Slow + Slowest + + Lowest Capacity + Low Capacity + High Capacity + Highest Capacity + + diff --git a/docs/pallas/pipelining.ipynb b/docs/pallas/pipelining.ipynb index 69029698c0c5..8342bee2002d 100644 --- a/docs/pallas/pipelining.ipynb +++ b/docs/pallas/pipelining.ipynb @@ -1,861 +1,865 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "last_runtime": { - "build_target": "//experimental/users/justinfu/pallas:colab", - "kind": "private" - } - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "C93Xlf0DRW9H" + }, + "source": [ + "# Software Pipelining\n", + "\n", + "Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n", + "\n", + "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see the [TPU](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html), or GPU (coming soon!) specific pipelining references.\n" + ] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Software Pipelining\n", - "\n", - "Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n", - "\n", - "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see the [TPU](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html), or GPU (coming soon!) 
specific pipelining references.\n"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "import jax\n",
    "from jax import numpy as jnp\n",
    "from jax.experimental import pallas as pl\n",
    "import numpy as np"
   ],
   "metadata": {
    "id": "YkOjspo5BKPD"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Memory Hierarchies\n",
    "\n",
    "The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that trade off capacity vs. latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:\n",
    "- **Registers** are the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.\n",
    "- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.\n",
    "SRAM sizes on modern ML accelerators typically fall in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).\n",
    "It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register.\n",
    "- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, the latency is roughly on the order of 10x longer to access compared to SRAM.\n",
    "- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices.\n",
    "\n",
    "\n",
    "\n",
    "![memory_hierarchy](data:image/svg+xml,[URL-encoded SVG elided; the figure shows Registers, SRAM/Caches, DRAM/HBM, and Network, ordered fastest/lowest-capacity to slowest/highest-capacity])\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "id": "shnVghWUSvpx"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "\n",
    "In order to perform computation on values X and Y that live in HBM, we need to:\n",
    "\n",
    "1. Copy the values x and y into SRAM.\n",
    "2. Load the values from SRAM into registers.\n",
    "3. Execute the computation and store the result into registers.\n",
    "4. Store the values in the output registers into SRAM.\n",
    "5. Copy the output values in SRAM back to HBM.\n",
    "\n",
    "Let's implement a Pallas function that does just that!"
   ],
   "metadata": {
    "id": "WvW6Lo7d2jfb"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Note: This is a TPU example.\n",
    "\n",
    "def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):\n",
    "  # Load x and y from SRAM into registers\n",
    "  x_regs = x_sram_ref[:, :]\n",
    "  y_regs = y_sram_ref[:, :]\n",
    "  # Execute a vectorized add\n",
    "  z_regs = x_regs + y_regs\n",
    "  # Store the output values in registers back into SRAM\n",
    "  z_sram_ref[:, :] = z_regs\n",
    "\n",
    "\n",
    "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n",
    "  # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.\n",
    "  # It will then copy `x` and `y` from HBM into SRAM.\n",
    "  z = pl.pallas_call(\n",
    "      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
    "  )(x, y)\n",
    "  # pallas_call will also copy the output from SRAM back into HBM.\n",
    "  return z\n",
    "\n",
    "\n",
    "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
    "add_matrices(x, y)"
   ],
   "metadata": {
    "id": "IrPhDFnT3Nvw",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1744764235906,
     "user_tz": 420,
     "elapsed": 108,
     "user": {
      "displayName": "Justin Fu",
      "userId": "17543197034567316452"
     }
    },
    "outputId": "8bc03872-fd9f-4610-9d53-d4b46be560f4"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "Array([[2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       ...,\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)"
      ]
     },
     "metadata": {},
     "execution_count": 8
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n",
    "\n",
    "`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.\n",
    "\n",
    "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_sram_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_sram_ref` to HBM, resulting in an output `jax.Array`.\n",
    "\n",
    "Pallas exposes access to lower level memory spaces like SRAM but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:\n",
    "\n",
    "- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n",
    "\n",
    "- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.\n",
    "\n",
    "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.\n",
    "\n"
   ],
   "metadata": {
    "id": "gGjtwv9u3UNK"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Pipelining Basics\n",
    "\n",
    "\n",
    "How can we take advantage of the strengths of each type of memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.\n",
    "\n",
    "The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages):\n",
    "\n",
    "1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.\n",
    "2. **compute**: Load `X` into registers, compute a result, and store it in SRAM `Y`.\n",
    "3. **copy_out**: Copy result `Y` back into HBM `A[i]`.\n",
    "\n",
    "Note that there is a data dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`.\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "id": "0Ebs2pCDgsEW"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
"![pipeline_fig](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABEsAAACqCAYAAABVj8FPAACAAElEQVR4Xuy9d5RUxRb2/d5LzlFyjgICkrMkyUkBSRIFJYggICAZAZGccw6CZJCcc04jIijmnG6+lz++9X7fWvXNr9pqT9fp6emBPj06s/daz6JP1a465/QMVXue2uH//B8RERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERERjyXmq38VifnmYS+BIBgefP5zl4c3b/YSCILhP3dudH3w60e9BIKg+NtHXew9R0RERERERERERORPITHf/O+FmG//pwSCYPj0sx9vxf5RrASCYPjXnVufPvj1rhIIguPe32LXkTMCQVx4eOvWidi15IxAEAyf/fLRvk9/vXtGIAiGz3/9qIT9d42IiIhIREXIEkEoCFkiCAUhSwQh8YsmS1zrikBgELuO/F97XREIDB78cvdH17oiEPyGT3+997T9d42IiIhIREXIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIEkEoCFkiCAUhSwQhIWSJIB48FLJEEAJClghCQcgSERERz0XIkkfHtc/+pg5dvqcuffyzq+9xcfzGp+r9cx+oKw9+cfVFE0KW/PHxn+vX1Z09e9S9/ftdfV5DyJI/F87cPq6OXz2o7v0Q4+rzBEKWCOLBQyFLBCEgZIkgFIQsERER8VyELHl0rN5+WMV+heqtuStcfY+KXSevqyfLVdTzgvQZMqhXR0506UULQpZEDpveeUcdX73a1f44WDt1qsqZLZv/96Vo/vzqwLJlLj2vIGRJ4uHSR2fU62Nec7UHw+Z961XBwgX8vyfZc2RT7yya5tKLOIQseSwMHD5OzVm5xdX+ONh3NkYNH/+2GjBsrJo0e5m6eP9Hl0408VDIkojhH5cvu9oeFwdj95PxAwaosa+8ovYtXqz+e/26S8dLCFmSeLj/4wfq/k93XO3BsOfEdjVi/OvqtZGD1ILVc9Td72+5dLyAkCUiIiKei5Alj45tRy6pqrXqqaWb9rj6HgUX7v2gcuXJp/LkL6gN5PeOXFTPde6h/7h5c+pcl340IGRJZPD+kiX657h11ixX36MCI/avf/2ralK7tjq1bp0mSZ5+8kmVLk0adXvXLpe+FxCyJPFQsXIFVbpsKVe7jRPXD6t06dOpJ8uVVut2rlLbDm1WDZvWV3/5y1/U6veWu/QjCiFLHgt/TZFCtXy+s6v9UTFm2jyVImVKlTFTZlWkeCmVJm06lTNXHrXl4DmXbrTwUMiSx8b/btxQb/TpowZ26eLqe1T86+pV1bZRI71v5X3iCVWiUCH9uU7lyurn8+dd+l5ByJLEwZ6TO/T+cvn+OVefjX6DX9K/G1mzZVXFSxWLXWNSqKIliqgzt4+5dCMNIUtEREQ8FyFL/jh4bfRkveGs3Xk0oL16nQYqW46c6saX/3SN8RpClkQGuxcujDhZUqNCBZUnZ07168WL/ravjh9XGdKnV11atnTpewEhSxIP5SqWDYss6fhie5UyVUp1+tbvhithOEWKF1blKpRx6UcUQpY8FqbOX6nW7jrman8U7DxxTZMvtRs08XuTHLz0kSboCxQuqm599W/XmGjgoZAljw08Sthf+nXs6Op7VEwdMsTnOTt4sL/t3ZkzNcnav1Mnl75XELIkcfDKkL765x8fWbJ882Ktxz7z0fe3ddvu49s0QV+rXg2XfqQhZImIiIjnkphkyeVPflETZy5RbV94UbVq30UNHTNVnbv7XYDO7W/+q70sOnbvq3V6vDJEe3Q4dY5df6Bdig9cvKsWrN2unuvSM1a3q5o8Z7m6/sU//HrTFqzSbs03v/yX61lmLFmvhoyZ4moPBfKVcF/zPOc/+l5fE0qDV0iX3v31c/QZNFwbpfZ4G+UrV1N58hVwtRPmE/ujUiu27nf1eY1okyXnN21Sg7t103/svxJrkB1ZudKlc//AAe0SjE6Ptm3V0gkTXO7Hq6dMUWti8cOZM2rCwIGqa6zugM6d1a2dO3U/nhev/nafMS+/rL4/fdo/9qP9+/X8Dw4fVltmzVLd27RRL7ZurefktM15n4mDBmkD0n7GxePGqXmjRunPR1etUp2aN9c/w/ZNmujnceriFcKz8Sz8y7U9n42vT5zQRmsw47hdo0YqY/r0rmf1AtEiSzjl6jOwp2rbsbV68aWuauPuNS6dszEntAswOh26Pqemzp3kcgWes2yGDj+5dO+sGjjsFdXuhTZ63kMX9vl1yOvBPMHucfHuad234+gWV198uPPtTTVlziT9bDzjkFGvqvN3TgXobNm/Uc9/9ZPzAe3HrhzU7Qcv7NXXfM6VJ5fK+UQO/XnrgY2u+xkQclO9TjVXO+/P7+Spm0ddfRGDkCV/GLw8dLT+ee85fTOg/ZXX39Tt249edo2JBh4KWfLY8IIsqVi6tCpWsKCrHe/FArlzu9q9gpAliYNwyZI2HVqq1KlTq9tfXg1s79hK2ygxX11zjYkkhCwRERHxXBKLLIE8yFegkHYJrlmvkarbqJlKmTKVDkE5eu0TrXPt87+rOg2b6gW7ROmyqlHzNtplmNADYq7NXJveP6N1atRtqN2Km7XpoOo3aalP0cpVrKIu3v9J642Y+I7WW7RhV8CzkKg1U+Ysepz9nKFg5yzhnbhu2rq9fq+SZZ5S5StV1W24PZOw1Z7DCfKT8D3Y7Zv3n9VzQCbZfV4jmmTJiN699XsWzpdPkwpFYv/V36/jZGv7vHkqfdq0mgxo9cwz2ruCDbls8eLqi2PH/HoNq1fXRl3RAgV0Ho/alSqp1KlS6XErJ09WGdKlU1XKllXlSpTQ90AXV2bGHl7hI6eYA72OTZuq5nXr6vvUevpp9XcHMZM21kjgWe13Ye7SRYvqz6veest/nzLFiqlGNWr49Yb36qXbeU7IkpKFC+vrcf37u+Z0wjzjgjFjXH1v9uun+wwx5CWiQZYMGt5fv0++AvlUi3bN/Lk3hr452K+zcstSfZKVPkN61ahZA1W5eiW9TpQoXVxd+PB3QqJeo7raPThv/jyqzFNPqvZd2qn8BfNpt+E5y2dqHU7HsmXPpipULu96luHjhurfg6OXD7j6QuHcBye1JwfPVKVGZf2M5nnX7lgZMD/vdvji+wHjeT/aF66Zq69r1I393cyYQY+HCJm7wvfsNshrwrge/bq5+hhD3+L18119EUMikiX7z99Rzdu9EPs9ZdLv+UTuvGrQiPHq1tf/8euQnBsCPkfOXFonXez6wPptk9vsP0Pe
fEv1HzZG7xXoFiv5pPb8MDoQ7uxp89dscz3L6ClzdJ89b3zA46PbS4P814R9vvTqCPX2wjWqQKEi+jkyZMyoXujRT1399FfXeBunYr50tfXs7/udI5eJ3RcNPIwSWUKujbmjRum9gvdlP2hcs6a6sWNHgN7+pUtVzYoVVYrY/6vosY4vGT8+QId1l/2JxNrMgV6q2D2fNfyXCxe0J6FZy7NmzhywhwHW+5nDh+s9jz0JvadKllQbp08P0OvZtq2qUKqU612mDRmi90n2vIvvvqsK5c2r52Au2s9u3Kj12NO4D/ejP03sftWiXr2wwzTxVLTbShUpot/NbvcKXpMlH/98R02aOU576fEd4YVX+5ma6sC5PQF6G3av0Wt3ihS+34tiJYtq8tupA0HPfkL+jkpVfbnnWKMhDsw+9MkvH+oQyppBvC4+/O6m3rN69e/h6osPhLy0bt9S34/7QpL3fKV7AImx/+xu/XwzFgfmq+JQgXYIEq4h8zNlzqjnYa/k+e37ORGMUGnSsrH+Lnknuy+SELJERETEc0kssgSCJHWatGrj3lP+NsJPIDjwNOG67+A39GI9fMJ0vw4GIYQCf7Cs231ctxmyJG269Nqjw+jOXrFZt/cZOExfYyhCyDRr2zHgWWYt36T1Epp7JC6yJHXqNAFeIHi00I5Rbs9hgNGOTrtO3V19Zt5Q471CtMgSPEh4xxeaNVP/vnZNt+EZAWGRMvZ34uODB/0hJk8WK6a+dBhxexct0joQGqaNccz38gsv+EkQPEBo4w9WEtUZXQxS2s9t2qSvDRGRKUMGddNhSEOy0E6yO9MWDlkCgoXh7Jg3T7f169DBnzSPZ8XjhXaew57XYEOsUW3PZzD7Dd//m2BeOZGG12QJHhO8S/O2Tf0VXEg6V7dhHW204hWBFwbEAcYr5IAZS34OjDV0TRtkCfNhDGIk0/bB19dV+UpPaSPz2oMLug1yAT08OpzPQzw2RrD9nPEBd+RUqVOpTXvW+tswngsXK6yNUnPfcMkSEE4YzvtndulxwRLBrt+1WvdNnjXe1RcxJBJZwpqZJWs2lTV7DvX62Gl6L2jxXCf9vpAD6EDGP/V0Fb3ndO0zMHadXq1DIRlH2KOT2IBwYS4IFwiSpZv3xv4uNdPzjX17vtY5++G3ek9r0LSV63kgW8pWqOxqjw92zpKCRYpp0oXn6TVgqJo0a6k+JOA58GS0x4cCHpbsTRBEvIvdHy08jBJZwl6g/+83aqT3glkjRqhc2bOr7Fmy+Il21lVsC4iLBW++qdf8Z2vV0uMgNsxcU157Tevlz5VL71nLJ07UpDp6DWL3nmyZM2vvxIVjx2qCnHZIGDMeIob7FsyTR62bNk3tid3DzJ4FuW70IDaYy36XkS/5ckTg/YiXIWQO1/WrVdPPbPbHvu3b63Y8G9kr5o8erb1FsmTKpD7Ys8c1byh8e+qUGtS1q54vGEnvFbwmS3q+/KJ+p8bNG6pFa+eqcW+/qXLmyqkyZ83s9/yjHbuh5JMl9Ho5c8nb6pnGvr2k76De/rlGT/btveTugIiHYJk4Y6zKmCmjKly0kN/L4qVBvkMS1mfns5AYlfaE5pKCKMHLMHOWTHqtn73sHdV7QE99CMA+gVcjeruOvafnn/DO2IDxEBr696RHR30Ngd+gyTO6bfz0MQGEfnxgrlGTR+jvi+/W7o80hCwRERHxXBKDLDl5+wu9kHbo1sfVN3LyTG3A8TnHE7m1d4atc+Tqx3o8IS5cG7IEg9fWtUNbGjZrrb1PjLcJeKZxc20EJzRmOy6ypHHLdgF6GKWQNI1btHXNYXD2zjd6LOFGdh9lhOlLqDEcCUSLLOnz/PP6Z2qfZHFqNmnQIB16Y0iAnfPnu8YTJoPxCqnCNYYn1xh4RoekdIyvW7lywFhDPGybO1dfG7IEg9S+Dx4inKyZ68chS5rVqaPJH2fOEfC3S5e09wzGtz2vgSFumNfuw0Cnz0kIeQWvyZKuvTvrnyPhL872fad3qmFjh2iyhFNB3nfpxoWu8cRR02dCTSBLIFDsMBeMU/SmL5yqr/ee2qGvB43o79cxhqZ9mhgfyBXCOGOIOoFXB30Y1FxHmizZeXSrHocRb/e9+/4G3Td26ihXX8SQSGTJ81176fXEDtmEWIDQYP0naTbvbyfPhnDHM9C5XkNOsIaT98O0sV/guQi5YvJJ4ZWSKlVqvZ4bva2HLwS9TzgIRpbw/+HdA2f9bewvefMXUrnz5neNjwsQPMzDcxUtUVqdjvnKpRMtPIwCWXJ33z6fzWGtqWath9gglOWJ7NlV8UKFtHeI0YHAxouR8cZbD7KEcRDdRo8S7iRCpZ2E3qb98lbf/0En2QJZwt7x4d69/jYOB/B6YQ5DnodDlnAdLAwH7xLaCGd1joVIwdulZew72fPGBfYq5gKEvtr9XsJLsgSSgZ9rszZNA9oJh+Rd8WrE0xDypFCRggFeGniI4D3B/6OD533hkYYsITTFOR97E+1jpozU14RTcu0kWgAEBeGV4VagMSCclPdgvXe2T5vnsxNGThqur8MlS0C4YThOGBII4PHotVcJELJERETEc0kMsmTltoN6MeVUzO4zgFBBB/diuw/kL1jYT6QYsgQPEVuvS58Buo9TP67JacK1ITjwNsEwNt4nCUFcZAk5SmxdXLfrP9vC1W5w5cEveiz5Vuy+w1fu6b7u/Qa7+rxGtMiS6uXL63AZu92JXu3a6e/BmV/EwJysGSIFsgSD0KmDAYoO+Uuc7e/N8Z3mbJ09W18bAzoYEYEXCsaRCcV5HLKE5Kyc8JETxUbmjBkDxtvAXZv5DMHjBK7m9IXyTIkUvCZLcHsm/MZud6JLL9+J8ZWPAwkQQH4S+pZtWqSvIUs4HbT1GIsebsumjTAdQn4wirkmV0ratGnUzc+vuMaHwvLYezO37foMMETpM0ZqpMkS3K4ZR0lHu49TT/oMUeMJEoksgfx+umpNVzv5rU7e+lx/rvVMYx366MxrZUAfpLrpgyyhzdaDdOA7XLPjiL7GO5Fr420CCKPB29DsQQlBMLKE8BtbD28WPETs9rjA806Zt0LnKyGMB2+VI1fuu/SigYdRIEvMmnhs1SpXH2E4kA3Gu/HtoUNdOuSdog+ShGtDlji9RQCVYgiFMd6M4MezvjBaJ5EBWdK2YUPXfQivQffC5s36+nHIEhPW6vSONCBciDAkZ0hpKFCiHoLe5N5q06BB1EoIe0mWTJ49Qb9PsPxUEBr8sW+8G4Otoe8d9HklG889Q5ZsPxKY04o9JEfO7AEJTwnzzJ03l9/DkTxaeIJQWca+TygwNx4lVWtWdvVBumTJlkVVqva0vvaaLFnx7hK9zxFGZLxabn2RsP0yoRCyRERExHNJDLJk3mrfgk2Mt91nQLJWdHA1tvtAySfLacORz4YswTXa1us98HXdhzcK15zCEZ9OhRmu35g0Q/fbSe/CQVxkCQaorZs5S9aQZAmJbDHOba8UwEkn8w4eNcnV5zWiRZYQWoPrs93uhHFzDpa41HiHbJ4xQ19DlnBC59QxZAlkhLMdwoF
2myw5sWaN6z6c0tFnCJu4yJLKYZAleJVgCNerUiUonmvc2DWvAWWCmW/FpEmuPsKE6Lv63nuuvkjDa7IEMoAYbrvdCUJqeF/jauwE7tP0mZwekCXGcHSCUBz0Xujewd82bpovISYGscljQky4PTY+mNwgi9bNc/V98M0N3WdiwuMiSzBCaU8oWYIbOeOCGeDmu5m/ararL2JIBLIEggNC04RzxoXCxUrofcRuBxAcfDdm34AsceYOMVj27j6tB/HANd4meHhUrFJDX+Nxkj3HE6pJq+ddY8NBMLIEb0lbj/nxmLHbw4E5vIhkieKE4GEUyJLXe/bU7/jZkSOuPgPjrRfMcxFvDP3/6DcywpAlV7ZuDdAjDCZfrlwBbb/85tFokyXBPBfZH9A1ScMfhyx5/tlndVso3H3/fdfc8cGEibLn2n1ewEuypP9QX34vknrbfQYkBUcnmOci4ZP0GZLBkCXBCHXylEC+m2vCefT3+BtRg4cf185k4+HgxqeX9Djn3uXE01UqaG8VPntNljiBBybjX31jgKsvkhCyRERExHNJDLKE3CSxt1aj35rt6sO9eMPek+rCvR+0wWtCbWwQP26MRkOWEHNu6zEeg5P4dNNG7g/a8CohASuwx4WDSJIloMxTTwc13mcu851sUBXI7vMa0SJLSJyKp4XdTogKxuvnR4/qSjF8D58cOuTSe2fYMN1nvCkiQZaYayeId0+XJo3/5JDPtNl6JMCLjywhKV/5IMn7wsGnh32/e8EMbk4NyeFih/d4Aa/JElx5SVRnt0OMQCCQ94NTLL6LYFVdiD+nj/wlXEOWkGjV1jt+7ZDWc4bdEKpDnhHirldt9ZFTCYndNjC5QYJ5cJj7Go8WTi+5Nm7dBoZwSShZAsiJgru43U4YE3PacfMRRSKQJeS14r0ehywxHokm2ThkSTDPviUbfZ47JFw1bSRgZe/S1dnW+cK5FltJxcNFMLLEEDFOPA5ZAiB4gnmsRAMPo0CWUGGNn0M4ZAm5pOw+9h/6bLLETg6bELJkdN++rvtA9qNrPAYhS2wPSTDsN/InFFlCVTR+D5dNnKjfLRjwerHnjg+XtmzR9xrYpYurzwt4SZb0e7WPfpdwyJIlGxa4+oxHok2W2JVhwFNPl9PJxc31jc8uqzSx9gO5TbguW75M0KTi8SE+ssR4sPDZkCXkIXHqmMOCSJIl7NH8/tWpX8vVF0kIWSIiIuK5JAZZQjJTvCiCVX6BuDA5RkqXq6BP5Sgx7NQxJEWPl1/T14YsofKBUw+jGS8SKgg42ynti/6gNyboOM9x0xe4niMcRJosIXyH58HIdraTmPBR3bgfF9EiS0xVGPukjuo3tFO21xiSJOazx1O9gEz/xviLBFlCHhWn3k/nzumwGWci2RxZs6pqTz0VoEfeFciKYGSJs8xw5xYtAuLgneOJm4/PGCW+nRwqTpdv8p2QtPCZqlVd+l7Aa7LElLeluoCz3ZAXhNlgxPI5WF6OarWqaMIDg5JryBIMuONXAxO3kpBO/3ze3xDQ3qx1Ex0G1Llnx9h1KbffZToh4JQRozhY9YM3JvhIPhL7cc2JH9drtgUm+Ove15dY0UmWYHyXKlPSNacNkuPipm0b8Jx0cuL4KO8UNhKBLAHsG5Wq1XK1k3ibJNqs1bXrP6uTgjuJdIPqtetr8sGUmYcsIbeVrUfycX4uVCwzbVSV4Xds2Lhpqk3HF3UFt4TmwzKIJFkCAUToqrMakAEJbYuXKuNqjwYeRoEsYc/g5xQsDIckqJR5pw8dE2rjxMHfPPmmDhmiryNBlgTLSWW8Ak3oTOv69TUhb+uRowu9UGTJkO7dgz4jYM8B5Fmx+wBEO96RwfYgStszrzMHi5fwkiyBwOZdgoXhdOvTRZMK2w75CgVALts6m/etD+gzZAk5r5x6JCWHtG7YtH5AO56KhMmQSJxxj5JsmzAcZ6iNE9yXtZ/qcFzvOekjb00OEwMOGmh/FLKkUfOGsbZtPVc7+x7jSZxr90USQpaIiIh4LolBlgBTrpDQEkgNCJSXh/jcEE1ZYFPNBlLFuEOv3XVMJ7MjqZ459TNkCX94kiAWN2y8Rig1THuwUz2qIEDYYGCeu/udqz8cRJosIWacuHOejbkw1Ck5yXslRnJXEC2yhMSsVJ/hj//r27frttPr1+uShwVy59YJ9wi/obIA4SsQJxh6P5w5o/r/FhozyuFlEQmyBIN20dixujoPJ4tNatfWPwtnlRlTUpjqACSQPb9pk5+4cZIlx1f7vAswcol/pw1iiLhxEsaaSjwf7d+vQ3DQDeYO7gQhOOjhcYNx+93p0zokiOeJRnJX4DVZQsldjMySpUv4SzmSxI4qMvyhT3UBDEI8LCjFS0gJ1xhqeGvw/bz82u8hKKYaDl4ZECYYmpAtVMKhyo3JT2Jgwl/IVYLLtv184cKcYPYZ2FPHcPOMlCpOly6tKlehjA7zQc+c/EFkEIoDyUM4ENV+aHeSJaZ8MM9IEln7ngbEz/N7i1HLSSgu14TlMJ+nyV1BIpElz3XuoYkGZ0JWYBJ8s+abfCOj3poVoLP10Hk9Fl3TBlnCXnHgwof+NvatIsVL6f2IMErnHBA1eD5CQsQVShoOIkmWsC/xvnb46zuLfX/8Ps5zPg4eRoEsocSv3ketfFWsu7w7Xh7/vHJF5c6RQ1eLgeAwOpDRrPOMN3tTJMgS9jFDdgD2D8oROxOIQ9jb92GdJxEt7WY8exTX5NQyeob8sd+ZpOcQ6uyrcZEloGLp0jr/Cnufs53qP8zLnmaP8QJekiWQBFRVa/V8i4B2SAX2UQgDqrCx1xQolN+V4BUigN8L451nyBLIded80xdM0e0mgbgBnoq0U0oeQt2Q+gkF3ik8h50r5a3ZE/X8Jt8K+ynXLdo1C9AbOnqwbneSJeaggjH2/ZwghJTvyk4uO/RN35wJTYieUAhZIiIi4rkkFllCLDfVcExGfsBiT0Ub58kXFQRIwmd0QLGSTwaUCDZkCcZtxkyZ/XokrovLa8QYys3adHD1hYtIkyUAt22SwTrfl3EY5rZuNBAtsgRAIlCK0fnukCXO3BuE4NSoUCFAB2KCcBSn4RcJsgQ3Zgxacx+8SuwcIZzOUf7R+TwYp91atQogSyB78BYxOsYdfNeCBdpAd47nJHHOyJEB94kLxOLz/8aMxaNl+uuvu/S8gtdkCaAyQd78gd8xhqvz9A6DjtMzp07q2N8LiBKICaMHWUJJSIgDpy4eKM6ywwaMfSL3E3qdOnr5gKs/XDAPCWIxzJ335Xns+1JRwLkukhiQUpB8dpIlJiYckOTWvqcTGM18H857U2nIJocijkQiSyDS8S7Bs5DSweQUMeQ55YHRgVSHeICQgIxGB/KetZpxzoSnkCUkAidcBUJ+wozFuhIOY8nBZd+f5OXme8aT0e4PF5EkSwhtZTzeNIQK8b6EpFLlB6+S8x997xoTDTyMAlkChvbwheuRy4MQFJKp4hnI+g0BgQ6efxAZkPaUDl46YYJqXLOmHjfm5Zf9c0WCLGGthhwh+SwVzCqUKqXbnJV09i
7yJYdGb8bw4eqtwYP1Z/YW2p1kS85s2TS5wf5mPDTZh9AjnIdwHEh9yBj2jC1Bys47AdkCmc/9qETHeEgj5sMbx9b3Cl6SJcB4UDRt9axOToqXIWsunoSmahqEOms3pD3V196e/5afeB/w+sv+uQxZQnLT5zq11fMxP96NkPEQL85749Vn9jabwEgI2P+o2MPBAsQH9+WwgOcgDOfu97f8ulzr/wed2+n3gOzgfSl37CRLjJcjiWOHjHrVdU+DkzeO6LF4t5Bzi3t36Pqc3sNIaGu/c6QhZImIiIjnklhkicGhy/d0To7pi9bqqi92P7h4/0e1cP1OnZPEWTLRwJAlECucGOKRgn4o449ksIyheoHdFy6uffY3/fx4xXCNJwjXwe6L4Y23i90eDLwv+UkwZik7afdHE9EkSwDuxJAVGLO4PnNiZusAjEFCcyAbnOWBDUjI5zQkDe7t36++OXkyoA2vDNpNjg9DlvAM6OLFwn04+bPnAzwzVRHWxD7P7V27dBvjyCti65E0lhM5TjFNOxUJIIq4H8QN3jL2PUIBbxTuTcK9YO/sJaJBlgCMvc1712lDjBwgcRlgVH+ZtXS69rawywMDDFwMQz5zEgYJEV/ODoxLDF27/VFAwlUSvUJ6hCJfMEDxPKEKDrHfkC14j5AQ1qlHzhO+lxPXD7vmsEG1BTxvZi55O+S9I4pEIksAazFkON56/H8uVLS4Gj9jUYAOIZ54VOTKk0/rQLC079Zbl2t36kGW1G/SUhPheDVSHrhy9do6rMe+L6A0MeQFhIrdlxAQDvTqyIn+a4iTzr1ecekNGDY2aLUeG+xBhCHxPrwvni8QRRAptm608DBKZAmYP3q0zifFu2dIl0574tk5sA4tX66r2kBcQMRTqY3qY06d9W+/rT0A2Tec7eRGsRNzExqJLmSHaYMsoYwx4adUPiNROKGTwcKElk+cqL1deGaIGIgaCBXm/PrECb8e5AeJ0gvny+dPvsoBwczY+xpyJVXKlPqwYc+iRa77BAMhNzwX4xjPPBAutp6X8JosAVTFKVbS9x3hoUgpYTuPCessObQoPQ/5gfcfa7RTx5AlrLGGvIdsp0SwHQZpYJLM2qGXCcXZmBOaoMmYKaOer3DRQjoHl534nD2I98OThXep26C29mKENIEoMnp4QKKXv2A+nSTWvp8T7EPPtmikPSX172mBvDqxq/GY9BJCloiIiHguiU2WRAJOssTuiwucxOE+/aix5MkF0SZL/ghwkiV2nyAQ0SJLIgUnWRIOCP3hhAw3artPEAYSkSyJJCAXnGE58YFwHX5vxkyb5+oTBOJhFMmSPwogS8hZZbcL3IgGWRIpGLIkPgLeCaq5kYCVMr92nyB+CFkiIiLiuSQnsgSPD04RKc8bTJ9QF7w5wkFcJ4pJDUKWuPsFvyOpkiVUQCCvB67YJHglz4ezn6R/eLmEg2AeLskGyZAsIYyUSjyEgSamx8afBQ+FLBGEQFImS6h6hofH4JEDXX2C8CBkiYiIiOeSFMiSbUcuxf5BU0hNnb/S1ecEYTLkPyEOvOOLL7mqAuCizDzhIL6ylEkFyZEswfUYV+b4YroFfz6yhDhtkrva7TaIvY5dHrULMklX7X7i0OkLB+EazkkSyYwsofJalmy+5JtDxkxx9bNvoBMfqtdp4BqbVPFQyBJBCCRFsoQcKOwN5EHBq4Qyws5+E/ITDiD27fmTE4QsERER8VySAlmSUNiVCwRxIzmSJYLw8WcjSxICz5OfJgckEbJkyJtvqRlLN7jabVC9rHGLtjqBuN0HCMshv0h8GDRivGtsUsXDZEiWjOvfX20VMj4s/JnIEjwOXxs5SOeHsvucIL8I+UC69uoUNN/UwQt79TzhgATo9vjkBCFLREREPJfkSJYIwoeQJYJQSMpkiSACSCJkicA7PEyGZIkgfPyZyBJB9CFkiYiIiOcS8/V/WsV8+/ArgSAYPv3s5+P/irnznUAQDH+/e/f6Fz/e/U4gCIbPfvzovv3HsUDgxL9v3/l//3X7AyUQBMPnP3z4yxc/3VUCQTB8/tO9Fh//FFNEIAiG7767nM7+u1dEREQkonL7u//Vt41bgcDg3pd//9Q+CRQIDP4dE/P3uz/8pASCuBDz7X//P3tdEQgM/nP71kN7XREIDD775a54twrixMe/fNTM/rtGREREJKIiZIkgFIQsEYTCf2Nu/c02XgQCJ2LXkf/HXlcEAoP/3br1g72uCAQGn/189569pggEBkKWiIiIeC5ClghCQcgSQSgIWSKIDzFClghCQMgSQSgIWSIIBSFLREREPBchSwShIGSJIBSELBHEhxghSwQhIGSJIBSELBGEgpAlIiIinouQJYJQELJEEApClgjiQ4yQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiCAUhSwTxIUbIEkEICFkiCAUhSwShIGSJiIiI5yJkiSAUhCwRhIKQJYL4ECNkiSAEhCwRhIKQJYJQELJERETEcxGy5NHx/rkPVL4ChdS46QtcfZHArpPXVYFCRdSBi3ddfdGCkCWRw0/nzqlbO3e62iOFOSNHqmeqVnW1ewkhSxIXO45ucbXFh9XvLVf5C+ZT93+64+rzAjFCljwyyleuplo+39nVHimwd7GHHb5yz9UXLQhZEhn878YNdW7TJld7pMD8zz/7rGpYvbqrz0sIWZJ4OHnjiLry8XlXe3x4ZUhfvcd88suHrr5IQ8gSERERz0XIkkcHZEbsV6iGjJni6ntcHL/xqSpYpJief++Z267+aEHIksjg/oEDKk/OnGrBm2+6+iKBPYsWqTSpU6sShQq5+ryEkCWJBwzSnLlyutpD4cC5PSpb9mx6Xbn/4weufi8QI2TJIyNXnnyqep0GrvZIYM/pmypN2nT6d2H/+Tuu/mhByJLIoG7lyqptw4au9kgBMp7fldJFi7r6vISQJYmDVVuXqbRp06j9Z3e7+kJh3c5V6q9//av+XRGyREREJEmIkCV/PKzeflgbybE/HiFLkgg48eNnGWmy5D/Xr6u3hw5VqVKm1PMLWZJ8UP/ZegkiS+avmq0yZsroX1eELEm+uP7FP9ST5Sqqv6ZIIWRJEkHWzJk9I0tu7tih0qVJo1LE/hEsZEnywLhpo31rQwLIkusPLsbarrlUihRCloiIiCQhSWyy5OyH36olG3eraQtWqd2nbrj6wdVPf1Ur3zugpsxboTbsPalufPnPgP6bX/5LHbp8T13+5Betu3jDLjV31VaXa/GpmC+13q2v/+O6x7m738XZFxd4DsZcuPeDvr79zX/19cX7P+nPm/ef1c+86f0zYc/bvd9g3x+9pcuqNh1fTJZkyflNm9S6adPUrgUL1K8XL7r6wQd79qgN06errbNmqQeHD7v6vzl5Un125Ij+fG//frUxVvfwihXq39eu+XUubN6sVk6e7HJd/tfVq3oM9wY7589X782ZE/Q+tH1x7Jir/avjx9Unhw7pzz+fP6+2z5unf5YTBg7Uc//3+nW/LoQHz8CzHFq+XP3t0iXXfMHwzytXVJliPu+jLi1bq
vKlSiVpsmTPie1q7oqZasW7S9QH39xw9YPjVw+qeStnqcXr56uzMSdc/Vc/Oa/O3zmlPx+5tF/NWTZDbT2wMSAk5aPvb6vTt45pw88e//HPd+LsCweMW7Zpkb7v+2d2ufpjvrqm5//wu5sB7eaZbn95VV/zbjXqVlfZc2QL63kaNW+of0+q16mmSRY+JweyhDV63e7jeh1eu/OouvXVv106YOvhC1pnUezeYdZzJ45df6DOfPC1/rxp32k1fdFate3IpQAd1n2zD9nj2Zfi6guFo9c+USdvf+G/Pnnrc72P8Zkw0BlL1uv9M6Hz9hk4TGXJll117N432ZEld99/X22dPVttnjHDv0fYYP9gzWcfurZtm6v/l9g1nXWctfvHs2f1+r5t7lz1/enTfp2PYvvZo/D6Y612jmff+C5Wl9CWk2vXar1g9/n21Cl9H7udPYV25mVP43PmjBnVs7Vq+fcupz77JXsgz8jeZM8XCtyjQuzewtyVy5ZNsmTJqZtH1ZINC9TCNXP9e4SNaw8u+NfvYCTCB19f1+sxa+vl++fUonXztIfGjc8uB+ixfl/4MPg9Lt49Hef948Pd72+pzXvXqRmLp6l339+gr5397HM8n/084Mzt4/qZ+cw++eobA/TasGbb8jif1Uaz1k1U0RJFVLM2TYUsERERSTqSWGQJ5AEGW+o0af0nnaB+k5bqyoPfDb9ZyzeprNlzBOjkyV9Qe18YHYxQ2pkvd978ek6QImVK1e2lQX6iYvCoSVpvxdb9ruepXKOOzg8CyWH3xQU7DIfn5vrlIaNU2QqVA565TPlKmhiy57DRtHV7NXTMVH3yZ543uZAlF999Vz352x//BtkyZ1Z7Y41No4Px2KKe7489A1w++3XoEGCQdmreXJMHb/Tpo/7yl7/8/nOInf/jgwdVg+rVA+aAbDBjb+zYodtG9+2rcmXPrk/UCHHBe2NE794Bz1wwTx5Vr0oV17vwjDw7nxeOHRtwL/DDmTP+d4bgcPZxT+c7xwUM5oqlS2tjnOvq5csnSbJk3+mdquSTJQK+o8xZM6uVW5b6dW58eslPCBjwe9GpR8cAg/G5Tm21Mde9b1ffPFky6X8LFy2k9p7aoXUgYvDAqFqzsutZNuxeo/VnL3vH1RcfXhs5SKWJ/V1yPmPFyhW0kW50pi+Yots371sfMHbXsfd0+4R3xurrnE8Erom9B/R03c+JarWqqOkLp2rjtUPX5/SYpE6WzF6xWeXImSvgeypcrETAegpBX6ps+QCdtOnSqxET3wmYi3HPtnxOVa3lW3syZc6i/+X6dMxXWmfLwXO6rdeAoa5nGfTGBL0OQXDYfaFgh+FUqlZLPdO4uWrfrXfAM7NHbtx7yjU+GNbsOKL/b/D9DBoxXo9PDmQJ5ETr+vUDvreUKVKosa+84teBvBj50ksqdapUAXo1K1b0k98AYpv2JePHq0wZMvj1+Hxq3Tq9dzj3nVJFiuj7m/H5cuVSnVu0UDUqVNB6WTL51iEICSfh0vs53/9Vm2wxITEcLED+OJ8VQAShx1z2fsm7jR8wwPX9xIVhPXuq7FmyaGKpShIkSyAOmrRsHPAd4RkxaHh/vw7r5qAR/VXqWDvAqVepakV15vYxvx4kCu0jJw7TISyZMmfUP1/2E9ZfowdhzV5gkxaQLRkyZlDtXmjjes74ACmDp6Hz+Z7I/YTOUWV0IGloHzjsFdd49sJGzRroz9zfOQ/7r61v451F01TKVCn1XsU+yzghS0RERJKEJBZZAsEQe3tt9J2984269vnf1ZA339JteFego426WGOGJHc7T1zTbRiERYqX1Abt3jO3dJshS9iUXuz7qrr22d/0fD1eGaLbh0+YrvU4pWM+PDacz3Lw0kd67MDh41zPGQpxkSUpU6ZSzdu9oPu5Z5c+Poa+98DXXXOEQnIiSzDqcmbLpvI+8YT2rqDt+vbt+o//jOnT+0/DiMuGtJg/erT2wOBUb3ivXvp7gjAx80GWYAgXjx1/bNUq9WXseIxi9HBX7tC0qbqzZ482NA1xYjxMDFnC7wSGM4bq3y9fVv06dtTtkB/mPuGQJXF5luCRkiNrVlU0f351Ys0arctpJEla08YaZcFOGkMhKZIlNz+/ookBjMD1u1brtkMX9qnipYqp9BnS61M42mrWq6FSpEyhyYQ7395Ut764oga8/rL+zl/o3sE/H0YcP9c8+XKrPSd95AgkSe68uTQwVmmDZEGPUzjn8zAewzcuz5a48OZbvj9u2nZsrZ8ZomLR2rnamIaoMfcNlyxJqGeJE8mBLIG4gCx/6ukqOjcHbUs37439QyRTrOFfTpPieJA8kTuvypYjp1q4fqcm1VmvGzZrrb+fSbOW+ueDLOH3AbKCfFK0QTakTp1G1a7/rF8Pj0AIe9uTkPFPV63pes74EIwsYX8pXa6C3h9P3PxMTVuw2t9mj7fBO3PY0LpDN32dnMiSprVr6z2BUEjWc7xHnmvs+yMZjwt0Jg4apK97tm2r9wy8DPH6yBC7B0Hk/yN2HHqGLIEcWTFpktbdt3ix3pvYXyDlT69fr9sh2NFlbvMskCX8PrH3mL1t9ZQp+vnaNWrk1wuHLAnlWVK/WjW9l/CMzEGS8SHdu+ux80aNcn1HNo6sXKmJNUO+JEWyBIIAcoS1FY8+vEdaPtdcf0d4hqDzxoRh+vr5zu3UpY/O6LVzweo5mtiAfGfPQc+QJZAGeHfQhpdG7Wdq6u9x2+F3dRveK+hNmzc54Fkg4WnfuHuN6zlDAa/LVKlTqSfLlfaT/vxbumwpTfDQT1u4ZElCPUtIBMt3MWTUq/payBIREZEkJYlBlmCoYlByomcblY2at9HGKp9rPdNYJ6AzbscGECcYGm1f8JEehiwpV7FKgGcInzEg8+Yv5G/DsM0Qa1Q4vVcgSdjImMd5n/gQF1lCYlbnexEmxD05hbTnCIXkRJZAfvCuO+bNC2g/sGyZKleihPa0OLpqldaxvTsAJ4bEU5tQGcgSdA/Gjjc6GIsYo5AyTuOTcB90l02cqK8NWQIx47wHLteQERjNpi0csgQEy1nC6SNtZzZsCBgLcZQ+bdoAb5dwkBTJkilzfP8HcI92tuNqXKpMSe1d8t5B33fbd1Bv13jcgvm/bU7/jBHHKZxTD6OYdmPgbju0WV+/PuY1vw4hMhA0HV9s77pPKOD6nDVbVn06RxiPs4/TRu4zefYE33WYZAlIaM4Sg+RAlrTr1F2TJUeufhzQDnGO1x9r/fDxb+vvYc6KdwN08OrDy5A9yuwnkB0QI5ATTt2e/YfqOcwazfxcr9x20K8DwU/bxJlLAsaGg2BkCXPZ5AYemeyJPLs9hxMtnuuk90MTapRcyBKIcd5zQOfOAe0Q2RAbkNjsCZAfhJrgYeLUM/sTJAnXhizB68KpZ4h355oOwcIaBEFv2iBLIGCc3iYAQp6fo9nHwiFLTJuds4RDAnSCeZHUid3bcufIERAOagPvR/a3ro59KKmRJXj18R117d05oB3ymvUab8B7P8RoIqHMU0+61m+zP709/y19bciSXv17BOhR
UQZPkjYdWuprvB3ZE6rXrhqgV6d+LZWvQMKryDRv21QfFtjkPtcQN01bPauvwyVLQLg5S9jfKlV7WntJmnBWIUtERESSlCQGWUIp3Nhbq76D33D1OYH3CISJ3Q6KlyqjDVo+G7IEzxRb76VXR+g+k79kxtIN+vqdxev8OpAb1Wo/4xobH+IiS4KVeixUtLgqX6mqqz0UkhNZ0jHWkMQ9mJM8u88Ag5bvw2kgGhBbTh+ngFwbsoTYcqcenhyctjnbcJtGd8GYMfrakCUYpPZ9MLbp41SS68chS3DtzpAunT69swEhw9z2vKGQFMmSNh1baSPQjr12YsT41/V3u/2Iu4wup3/0mbAZjLh06dO5jF48RfQfNN2e97dxYogHi7meucT3x/WW/Rtd9wkFjE3G4cZt95GDhD+OWrf3GdFClkQGBQoX1eS53e5Eg6atVKpUqYMSDCZ3lCERIEuC7RGrth3y/VxmLNbXEPt4eUDWGB3ygrCXXbz/o2t8fAhGluAJY+t16unzogqWb8WAPY/fcTxSTFtyIUsgwvV7Ll3q6jM4u3Gj1pn86quuPtZ7+nq0bauvDVlCziyn3gvNmun/z87cWIB13klkQJa0fOYZ133Ij8W87GdcPw5ZYjwp2XPs/eXF1j7vqdu7drmewYB3KZQ3r/beNG1JjSwxa7ozVMUGHojoDB092NVH+Cd9JmzGkCXB9gjCOvFoNNc9+nULIPLJU6LDf4LsE/GBfQCywm4HEBkQM3z2gizhe2FPJV+YaROyREREJElJYpAlJMiLvbUa9dYsV58BCevQMe7CNmrUbahdqvlsyJKp81e69MxJ37sHzuprQnSyZM2m6jZqpq/X7/FtHsHGxoe4yJKOL77k0i1SvJR2CbfbQyE5kSWEnpCrw253YmCXLvr7CJZoldAd+ub+5loMWYIhYutBluCO7WyLiyzZYhnCgFM6+m7t3KmvH4csKVqggG6LC7hP2/OGQlIkSzhpM4ZeXOgzsKf+vuxTNUCSO/rGTh2lrzHiChUp6NIDeI04jUXjem1cmGvVq6FDZhJqAOJSzTyTZo5z9QHcl3lPPgtZEhngyVfvtzU+LlSsUiP2+8vjagfkjeI7Yq/iGrIkGAmOlyN6r46c6G+DhGFvIqkr4aDsN8HGhoNgZIk5JHDChHqe/+h7Vx8gdIg8K3hWku/LAFKHcSSsZS+0x0UD0SBL3hrsI7+ubN3q6jMwHobLf/MwtIFXYqvfCA5DluD56NRh38HD0R6LF4lNlrzUvr1Lj9Ad5p01YoS+fhyyxISNhgLeJ/YzAJLB0g/h4iRZyL0CgcLnUERLJOElWTJqsu8wzazxwbB2+wqtM3XuJFcfwGOEtZjPhiw5cf2wSw/vDkJizLUhYYaPG6qv2W8g2oKNjQ8cKDzbopGrHXBf5mXfijRZwjvgucKegrenQd0GtfVYPptwV68gZImIiIjnkhhkyXtHLuqF9LXRk119BlQt4I/dxi3auvoAxEOOJ3Lrz4YsGffOQpfegGG+5JrOxHqcwuGiTXUDTv0wrBNaTQDESZbEzmnrClkSGhAY5Cax250gWSvfRzAjjZhz+oyhG5fRmhCyZM2UKa7xr/f0/WFuqihAluDSbOs1rF49XrIE9+/C+fLpGPNguH/ggGveUEiKZAnGW7p0aUMSFJzE8d0ePL/X1UeYDn3G0IUsobShrYerNad65BQxbcRp0wYZw6kf69HQN92ni/EBjxeegaR/dh+EBYZu4+YN9bUhSzbtWRugx0kl7UKWhAe8L4J5gjjhI9wzutoBSbr5jnYcv6qvIUuC7UVUZ0Pvzalz/W3z12zTbSQnpyobn4MlFQ8HkSJLqCZHfyhA6tjjooFokCUzhw/X78hab/cZUC0NnWAeheTHog8PSK4NWWLyaxnEte8EI0uc4S0GeL7o35dJk/S1IUvIseLUmzbEl48tFFkyuFs3rXN89WrX3mJgz2swqKsvAXYo4J1ij/MCXpIlrKe8SzBPEAOz9o6ZMtLVh8cjRATVX7g2ZEmwvYi8JeTfcraRY8QkTyWsNFhS8XAA4Q6Zb7cDiHj2UD4bsqT/0H4uPRLSJpQseWu2z2MrFMpXeso1LpIQskRERMRzSQyyBFIBV+XGLdu5+l7o0U9lz/GEuvTxzzoRH8aiXe4Rd2ZymRgj0pAlz3Xp6ZqPkzRKJJI3xLSZqgWTZi/TCf6cLtMJgZAlkYMhIShx6GzHi4SkeRiHhNigY3KLBBtvYsXjMloTQpZgMNrjaz39tDZ0TUx7sYIFdVlFW69IvnwBZAlGrb6HgyzB8OYP8E+DeMqQz8SpGw6SIlmCUcf3RlJXZztJUjnR4lSORKnoED8e13jymnBtErzapRlNjhIMRGc7hESBQvnV5Fnjf3OZPu66R3wg2ayTEHGCssXcl2S0XJsEf8s3LQrQo9IA7UKWhAeIEnvdB/NWv6fSZ8igk732ePk1/T0YQsQeT44SvEO4hiwh14ddLc14LlIe3rRRrpgqPCT5xjOShKp2bq5wESmy5Nzd7wI8SpKbZ4khIYznoROElbAnEGrD/3FDiDiBBwnjTf6PSJAlwcJZxvX3Eb/sQVy//MIL+tou94tXiv55hyBLODhAx+RZcQIyBq8RcrbYfeDDvXtdoTtJ0bMEUprvyHgeOgF5gYcEoTaQ5oYQcQLPCcYPHjlQXxuyhPXaqUfp9yzZsqgGTZ4JaOe+6JNEVf8/XDDFdY9wUKVGZV0hzg5X5Zr2p54up68pDcx9Xnypa4AeBwO0J5QsYZzTo0Q8S0RERJKkJAZZAnBLhjBxxk9DCmTMlFnVadhUXxOmE/uI+pTP6ECcGANv5rKNus2QJZQLhggxuhjG/GFE+WD7/hAxxLUzzvkMCYGQJZEDlV8wMts0aODPW0LyOU7W+BlSYpcM/09kz66NNWcZxwubN2tjFNLCkBhxGa0JIUuY01mRxuRFGfXSS/425uI+lx3u3cZF2kmWMA9tb/br528j+SxtxK47T/h4DtrtZITxISmSJZAkEA0Nm9bX3h+0kW+ka69O+vdi59GtuhIBpEHe/HkCyjhSDYATN4xek6PExFK3er6FPxkdyfwqV6+kT99MdR0DU7WA0J24Tu7CAYn9+EPMSYJQ6adC5fLajfvo5QO6zXihkLDPeNNQfYHcKbQ7yRLKXfLMJJ617xcKyYEsIaySd+w/bIy/De/BClWqq8xZsqqL93/SITRUR4MYcXoWUl2GsU7yHbKENlNZDVBFDe/GYiWfdJEoEDHsZZT0de5fCUWkyJK4kFxylpBDBC8+9o7Pjx71t5twkymvvaavIRsIt6GyjdEhCSt7C0m3qVZGWyTIEsYvclRWi9m9W+8Z5AUxbTN+84h5Z9gwf5vZ7/TP20GWkLC12lNP+a/JNQKBgvej00uRUCS8OCFr7ES28SGp5SxhD2BtZ+9wEugm4bdJ8K0TqKb4a0BicMr+lqtYVntkmNAZQ5YULlZYExNGt9+rPq9YiH3n/Un8ShUbnoG8H+Swsp8xHBgyvecr3f37Bv/2fPlF3W48K9kHCbfJXzCfJoF
oYx8gNxh6TrKEpOO0EYZk3y8+SM4SERGRJCWJRZacjvlKl1kkHAZypP6zLbS3CBUIMELR4TSuSavn9aKLLiV/MRS57jNwmH8uQ5bkL1hYJ+wjZrxmvUb6jxOSqmIY2/c3J4LMZxu64ULIksiCOG3+AMao7dyihT7F4v05ATM6nGhRsYCEeZRYbFanjvY8If8HFQ+MXlxGa0LIEp6DvCFtY+9jqhwQcoNLthnLiSXGdZpYParnkJiVe1Cdx0mWkGiW52aOLJkyqbv79ul2yBPaMJ55Z4xRriuVKaPLPNrPHwpJkSwBEAT8X6ZKAEadIQ4GDf89ER6u0pThhTygAg4neBiheIUYIgIYIw5jsUjxwjrshrAcvFTmrpjpurepWsAYkyT2UUA5Slyu+f2m5C9kDWV/MbSd98W4rFarir4fz1e9TjWtYwgOJ1lCpQbaIITad2nnumdcSA5kCWs6Xh28Z5mnnlZNW7fXxAP7AyS60Rv79nxNmJC7hEoxptpMlZp1A/YNyBK8RUjUSglg9iU8VCBDth255Lq/yWXCz/txiAghSyIH8oGw9lJit32TJnot591J+G1KAlPOnfBI1hvWfCrYZM+SRe83zhxWkSBL8uTMqfcNcl5RwjhdmjS6zemx8fWJE7pN/07G7g01KlTQ+51J0OokSygbTBvvaLwSIX14doge9jH2PhKpszdButjPGR+SGlkCINzxvqAkPKQIay7fI5VqKCWMDsQHpDv/nylTT2nhbNmz6f2GJOJmLkOWsL+wb5C4myo6tHXu2dF1b0BOEfpNkthHBcnJmYfE5OwH/Mu1XemHhKy0s//wrrnz5lLlKpTR+5OTLDHeluyNceX5igtCloiIiCQpSSyyBJBslTwjkCAYs+QwOfvhty69Bet26PCcZm06aC8Rk6zVwJAlr7z+plqycbcuKdyqfRd9uhis0gHYe+aWHkPZYLsvXED4kBNl3e7j+hr3a64Xbdjl0iWmPaGlI9fuOqbnO3vnG1dftBBNsgRgwBH+gvGICzJx5LYOJ4O4Q+MujdG4cOxY9cuFCwE6VCnApdkeSzjP2qlTA9rwUoGQMYanIUumxuqSC6Vbq1Ya699+21XlAFzaskV7gWBYj3zpJR1Wg9cI4516eJ8QR07iPXNCCajC8GpsO+RPz7ZtdZiRndAvHCydMCGom7mXiAZZAvASoRwjrtAYfxt2r3Hp4IFBThGICIw1TsZsrwtjxF395LwmWzBmKTnsJFRstGjXTBvSVMyx+xICXLFxs8aoxdjm/sGS0mKgQ4rwHhAbnHLiVQM5gmFv9GibOGOs1iF23J4nLizbtEjPZVcE8goxiUCWGCxcv1OT18+2fE71eGWI2n3qhkuHMJxeA4bq/eX5rr10KWE79BOypHzlanrf6NpnoGrWtqMms9kD7PkM8uQroMkNuz0hgNR3Jh+fPGd50MTo7HvsFeypdl8o4FWp95gg+260EC2yBECGsC+wVrN3EI5ir+kQJ4vHjdOVb9Bjb3B6MgLWctqd6ziIa9+hwg5eLOYasgQy5vr27eqVTp105Rn2C7uUMCAEh3vxLOwd7BfclzaTOwtArLAvdm/TJiDxLKGsPBP3ABMHDXKF9YQLvpd5Ud5jvCZLwKV7ZzWJwLrMHsE6bZPJEOeUCGa9Re/VNwYEeDICQ5bggcKazBrepdcLat3OVa57Gkyb5yPeSARu9yUU3Kdbny76+Qi1MeGnNvBw7NSjo9Z7862R2juTMvb2gcHyzYv1fOy9xhMzHCzduFDvMXa7FxCyRERExHNJTLIkUnCSJXZfXBg5eaY+UTQlhQXBEW2y5I8AJ1li9wkCES2yJFJI6IkXBAcncBiWdp8gPMQkIlkSKRiyxG6PC5Ay/J6RF8vuEwQimmTJHwWGLLHbBW5EgyyJFAxZsnbHSldfXMCbBQ/IcPckQSCELBEREfFckhNZwmnhqm2H9MkhSWQbNW/j0hEEQsgSd7/gdyRVsuT41YM6OR2niLjkh0pyJwiNmGRElpAza/mW97VHCXuMSRAriBtClghCISmSJXg2sr+Mnz5G50LBu8PWEYQHIUtEREQ8l+RElgDizNElMR/jnH2QKfSFA+Lg7bmTIoQscfcLfkdSJUtGTvIlVgTByiwSQmOvCXGB8sX2+OSEmGREllBemJ85ubgoG2z3k3/E/v2IC5S2t8cnRQhZIgiFpEiW7Du90///vFLVitqD0dn/7vsbXOtBXCCJqz1/coKQJSIiIp5LUiBLKA8J8RFOcjtO/d5ZvC7OGG3mCQcnbn7mGpsUkRzJEqrx3Nu/X1cTsPsEgfizkSWc6AXLE2KD/Cezlk6PM+abSgjMEw6IB7fHJyfEJAGy5Nj1B2Gt+ST9njJvhdp3NsbVB07e/sK1l8SFRy03/GdDciRLyCPy5SPmDUlu+DORJVRWC2fNJ/8JlXHIH2UTJabf3kfiwvUHF13jkxOELBEREfFckgJZIvAOyZEsEYSPPxtZIog+YpIAWSLwDsmRLBGEjz8TWSKIPoQsERER8VyELBGEgpAlglAQskQQH2KELBGEgJAlglAQskQQCkKWiIiIeC43v/hnlpjv/tdAIAiGO1/9s8udb/77k0AQDB9+88/P/nfz5k8CQVy4//Nn39//+YufBIJg+PDr/3xsrysCgcG/YmJu22uKQGDw+Y93Tzz49aOvBIKg+OXu8/bfvSIiIiIRlVtf/6+MfRIoEBjc+frff7NPAgUCJx78evf/2qeBAoFB7Dryo72uCAQG/711+769pggEBp/9dPeavaYIBAaf/HK3u/13jYiIiEhERcgSQSgIWSKIDw+ELBGEQIyQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiiA8PhCwRhECMkCWCEBCyRBAKQpYIQkHIEhEREc9FyBJBKAhZIogPD4QsEYRAjJAlghAQskQQCkKWCEJByBIRERHPRcgSQSgIWSKIDw+ELBGEQIyQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiiA8PhCwRhECMkCWCEBCyRBAKQpYIQkHIEhEREc9FyJJHx6WPf1ZtX3hRzViy3tX3OFiycbfq0meAnrtrn4Fq5baDLp1oQciSyOLbU6dcbY+DXy9eVIvHjVO92rVT3du0UWNfeUXd3bfPpeclHghZkmi49uCCqy0Y7v/4gZq/arbq2ruzat+lneo7qLfac2K7S88LxAhZ8sh45fU31ctDRrnaHxWXP/lFTZm3Iii2HDzn0o8GhCyJHCK9v8Ts3q1WTp4cFF8eP+7S9wJCliQe2Ddufn7F1W7j0kdn1IzF04Li6OUDLv1IQsgSERERz0XIkkfHmQ++VrFfoerx8muuvkfBra//o5q17ajnLFS0uKpaq57Km7+Qvu7Yva9LPxoQsiQy+OXCBdWiXj1NZth9j4pPDx9WxQsVUn/9619V1XLlVL0qVVSG9OlV2tSp1dZZs1z6XuGBkCWJgsXr56us2bK62m188M0NVbVmZb2OlCpTUlWvU03lyJld/eUvf1EjJw136UcaMUKWPDKeLFdRlXyynKv9Ub
F251H9exAMPfsPdelHA0KWRAbsLewBdvvj4PWePV2/JwZHV61y6XsBIUsSB5AcJZ8soTbsXuPqszF94VTX74fB5NkTXPqRhJAlIiIinouQJY+OG1/+U63eflgduPChq+9RMG3BKr259Bk4zN9266t/qzYdX9TtC9Zud43xGkKWRAYfHzyof4aRJEuef/ZZlSplSnV4xQp/G6d9JQsXVpkzZlQ/nj3rGuMFHghZkijo0PU5TXjY7TYGjxyof/emzp3kb4NAqVWvhibaDp7f6xoTScQIWfLI2HrofEQ9Pka9NUv/LuCteOjyvQCcu/udSz8aELIkMmDdr1mxoqv9cdCoRg1VJF8+dW//fhf+fvmyS98LCFmSOJizfKZeK8IhS3q+/KJKkTKFOnbloDp961gAbn951aUfSQhZIiIi4rn8EciSs3e+0cbatc//7uozuHDvB61z9sNvXX02Tsd8pQ5fuedqjzYI0+GZz3/0vasvGOo3aanSpkvv+h6O3/hUb1odX3zJNcZrJAZZ8u9r17Qx9lU8br73DxzQev+7ccPV5wT9webD24P2n8+fd41xgnGQHXZ7QhAfWUI4TbBnjAv/unpVpUmdWrVp0MDVt3DsWH2vXQsWuPq8wIMokSX3f7qjja9L9866+pw4c9tnpH3yy4euPifu/RCj9cJxM44krnx8Xt/3g6+vu/oSgnDJkpKlS6gixQu72t99f4P+PfHauyQmkcmSKw9+0eswe4jdZ3D9i39oHdZau8+GWdevffY3V180AZHOc5y89bmrLy6069Rdpc+QQXsx2n2JhcQgS/DKY/+w252AeGZN/tulS64+G18cO+aa77/Xr+vxX5844dJ34qdz57TeP69ccfUlBKHIErMHJnQfy5E1q2rfpImrPZqIJlly/s4pvX/Y7U5cvn/Ot35/c8PVZ+PcByd1iIrd7iUgJ3i+cEM040JCyJJqtapoLxS7PRoQskRERMRzSUyyZOnmvarMU0/73fXSpE2nerwyRBuBRgevjToNm+oTUKNXo25Dtef0Tb/O1U9/1e2DRoxXjVu28+tly5FTjX5rtl9vyJgpun3ppj2uZ4GISJ0mbYJO14KF4WTImFH1HfyG6jVgqH4f+v+aIoVq1qaDjhe353Bi25FLauV7B1ztBy99pOfp0K2Pq89rRJMswWjs17GjSp82rf9nWKFUKXXx3Xf9Ohh97wwbpnJlz+7XyZktm5r++usBcw3u1k3lyZlTvb9kiSqUN69fl1CYX86fV1Nee02HrNCGd8aol17yj8W4pX3qkCFa34zNnSOHWjJ+fMB9qpQtq0oXLep6l55t2+oxkBrvzZnjn8Pggz17tB5EDbqEzpi+6uXLq0tbtrjmdOI/sYb4kZUr1eWtW119s994Q8+zc/58V58XeOAxWYLx9+JLXVW69L7/T6B02VJq9/Ftfh2IkbFTR6mcuXL6dbJlz6ZGThwWQJr0e7WPypw1s1q6caEORTG6zzSu6zeS73x7U2XOkklVqlrR9SwYv6xF/Yf2c/XFh81716my5cv475k69mfetmNrdePTS36dlVuW6r6Fa+YGjOXZaB80or++5v2dv09tOrZy3c9g2+F31XsHN7na1+9arccmVbLkwMW7moBm/f3959w8gFxgTSYvFCS10SlcrISat/q9gLmerlpT1W3UTA16Y4J/XU+VKrUOjzTr+tbDF3R7sLDMOSve1X0L1u1w9YWCHYbTqHkbVb5yNTVj6QaV44nc/mcuU75SwJ4YF5ivUrVa+vPF+z+pY9cfuHSijWiSJTOHD1f5cuXyf2+s6aveeitAhz2jXIkSfh3W5h6xazT7k9HZs2iR7mONrVPZF+IGShQqpG7s2KH2L12qCubJ429vUru23nfMePa15xo3Vm/06eNf+/l3UNeuAR4bI2P3JfoeHD4c8Iwbpk/X7Xtjn4N5zX0M5o8e7ddlv2SPNH35Y99/3bRpAfMFA/dEn72Sa8ij706fdul5jWiQJePeflPlyff7/6ecT+TQoSVOHUiDJ8uV9uukSZNGPd+5XQDZvm6nzzsYosG51pcoXVyvt0avQZNn9Pp//cFF17NAbj/1dDlXe3w4G3NCPduikUqR4ndbmfDLfad3+nU+/O6mbu/Uo6NrPIT601Uq6M8vDeoV8PvEXmnrG7C/ZsyUUX8XXPNO8R1oRBJCloiIiHguiUWWrNp2SBuxpcqWV3NXbVWb9p3WhEXsI6k+g4ZrnRM3P4v94yePypI1mxr79nwd8jJp1tLYhTuXboNEQM+QJRi8Tz1dRc+16+R11aTV87p93PQFWo9TwxSxfxi3fL5zwLPgyZE5S1ZNaNjPGQpxkSVZsmXXOUcmzlyiiRkMXPQgUOw5wgFhOYyftXyTq89rRIssgQTB6EwZ+zsx5uWX1en169XGWIMw7xNPqOxZsvhP54zxiKGJUQs4+aJtWM+e/vkgS9LFGjNZM2dW4/r318QCSVDRKx9rqGLUrn/7bW3sViztM4BOrVunxxqyBNKmdqVK6tymTerqe++pVs88o9vXTJniv084ZAmneRisXOMJQnI8DG88aDgJxENk0qBB6uzGjWrr7NmqbPHiKlOGDOrOb4RKQgCJgqHPu0fLsH3gIVmCIVajbnVtAA4a3l9tP7JFEwkYtpAexih79Y0B+vtt0rKxWrt9hTZaWz3fwreeDOzpnw+yBHfhdOnSqtGT31Cnbh5V81bO0nMVLlrIf1rYtVcn7bVx4vrhgOcZMf513Z7QpHVb9m9UqVKnUsVKFtXPD3Hy+pjXdFuZp57URix64ZIlyzct0oYwz0ISPea37xkfmrX2/b/Zc3KHqy+SiEkEsgRvxSdy59X7xLh3Fuo9YfSUObF/oKRRZStUVre/+a9GrWcaa/ILkp79hVBH9hC+V+d6C1nC/pI9xxOa+ICIeW30ZD0Wgt7olSzzlL6vk/AHkDbsWze//JfrWUMhGFnCXgVeHTlRrdi6X5Pz7KWly1VwjXcC7xkInopVaqhyFavonz0oWKSYmr9mm0s/WogWWcI+YPaOQ8uX6/DFWk/7Dmv2LV6sdQ4sW6b3INbQLbNm6X0Db0DaILFZz9EzZAn7y4utW2vSYs7IkVqPvcUQ+NyncwvfOsT9zbNAlrC/sL/tmDdPfbh3r74Pv3fdWrXy64VDlrCPsKdwOEAOKz7f3rVL63EIgF7Xli3VsVWr9Du3a9RItzn3sWBgb0SvQ9OmmlTiM89HaA7Pa+t7Ba/JkmFjh+h3Y++AEGFtrl67qm5jPUZn0561KmWqlJrIWLTWt34PGfWqbitf6Sn10fe3tZ4hS9hf2r3QRh2++L4mqlnjWev3nvKttYvWzdN6b82eGPAsJN2mfeKMsa7nDAUIivwF88Xanxk08cPzzVzytt4naeM50AuXLOEgokuvF7QueybJwW19A/ZI9Oo1qqvn4DOALIKot/UjDSFLREREPJfEIkuq166vMmXOogkHZ3vl6rW1UYmb8It9X9WLLkasU4cTPIzUVu276mtDlmTNniPAMwSDFeM1V5582jCmjdPBdOnTa1dqo4dRz
PhgHiehEBdZglF9KuZLfxu5TfBygRiy54gPeN9A8FSoUj1RXKejRZYcjDVS+S6dBiUwRikVXz47ckR7gTSuWdM1vmnt2tpQxdWYa8gSxs0aMcKvg5szJESK2N+du++/72/HiET37aFD9bUhSziBdJ4GYihjCBcrWNDfFg5ZwnWwMBxO92hzngICTvDwenmhWTPXvPHh1d/ee8LAga4+r/DAQ7Jk8771+n0MSWBgjNLJs8ZrwgRDtGa9Gq7Qm0bNGmiixZAeGH6Mg6hw6i1Y7fP+MXk9dhzdoq8xiJ16xUsVC+pxEh8YkylzRpdLNsnvuI8xjsMlS0C4YTjBMGXOJD1fmw4tXX2RRkwikCUQGbyfvaYPHD5Of2fvHbmoFm3YpXUgG5w6eIqQWDt33vx+0gOyBF27MpnZoyDnuX5j0gx9vXzL+34dwkJTpkwV1OMkPgQjS5ifAwan3rMtn9Pt3MuewwDPRXQgjCDuyZH1+thpmtxhP00swiQaZAmhNBnSpdOEvDNs84czZ/RaS/4nriHSCT2xq8oYsnv5xIn62uxLzevWDdBr+RuhvnTCBH8bewD3frZWLX8bZAl6ZzZsCBjf+znf/+mPftvHwiFLTJsdhmP2y7YNGwaM5f3Rg1wx+1MwsIdwDzwzef/VU6Zozxf2WciTz48edY3xAl6SJXe/v6XJhErVng7YO259cUV7S0CgcF2uYlmVJVsWHULpHA8xwXdk9g2zL5FE2znfhQ9PqbRp06gW7Zrpa8gVPB+r1KgcMB+5P+LyOAkFc1gAie5sP3JpvyZ08DjhOlyyBIQbhmOIH7wxh44erGYtna5eGzlIf39433hddU3IEhEREc8lMcgSYr0xHls818nVh2uw+VykeCl96mXrgPKVqmpyhM+GLGnfrbdLzxjNO45f1dezV2zW1xiKRgfX7GCngfEhLrKkco06Ll3IjgKFirjaQwFjHuKFceHE0nuBaJEleIXwXRqyw4C4b3KL8BljDZ3NM2a4xm+bO1f3LRo7Vl8bsuTWzp0BehAbGJXONjw40DVEjSFLIB7s+xjj1ZAtj0OWcOKIYRwsZwpGeJZMmVztcQED+LUXfYmAOQnke7N1vMIDD8mSAa+/rN/p+LVDAe0Yoibnx9wVPqOOf+3xhnyYNHOcvjZkycW7pwP0KJGIMdusTVN/G67ThYv9nutj17H39FiIBvs+oXDjs8v6j1FCbuw+jHWMY9yyuY4GWcJpJs9TsXIFFfPVNVd/pBGTCGRJtdrPaC8Qu529x3h3UJ6d75S8H7Zev9dG6j5IFa4hS/IVKOTSe/fAWa03fMJ0fQ1ZgfdG6w7d/DomqeqOY1dc4+NDMLKEn52dL4USw9wjVLJx3pMyxHaoJ+3kMYlrr/Ua0SBL8Bjh+7HDKIFZf8k7gk6f558PqgPJbkgVQ5bgTeLUG9C5s263c4MUzpdPe7GYa8gSPAjt++DJwnizjz0OWbJiko8QxVvRvg/hSPThNWn3GeC1STUcO48WhBFjg+2PXsBLsgQPDN4lWMUWk5gUgoR1lvXW1sETETLCkCqGLIEwsHUJ9YRQMNcQI8yLdyPX5NDKniNbwB4ULipULq9DUO3DAlD7mZra04WcX16QJXgmvjKkrzpwbk9AO14lvF/dhnVcYyIJIUtEREQ8l8QgS0i+Gntr9dKrI1x9TkA84CZtt4Pm7XwugoTQGLKEnCS23rQFvrj8xRt26WtckQmTMfPiAQJx46xAEy7iIksaNmvt0oVACWZsxwXCjvAoKVG6bKIRJSBaZIkhDkL9kf/W4MH6+z4fxMC7vt3nvkoID9eGLLENvTLFiqlKZcoEtN3dt0/r2mTJvFGjXPdZMGaM7sMbhevHIUuMC3goOD1b4gKnpiYUifsSimPreIkHHpIlxEHzXhiSdp/BqMkjtA7eIHYfJ2v0QbpwDVlC7hNbDxQolD/AWCSXh3Pe7n27akIloQlhD13w/X7Z3jEGBQsX0G7afPaSLMGQxqhlHrxwvK5SYBCTCGQJeUcIt7HbnWjcoq0OXzFeh05Mmu374xrCmmvIEnJl2XrkP0Gve7/B/jYIDfYBEstyTT4RcnPZY8NBMLIEz0hbj1wqPMf75z5w9YUDQlAZnxg5TKJBlhjigFwidp/Bhc2+gxT2GbsP4E1Ro0IF/dmQJWunTg3QMZ59dghk0fz5XWRJ6/r1Xfcwe9Hovn319eOQJew16IRCMCIlPrC/4KFJ+Krd5wW8JEtmL3tHfw+r31vu6jN4/4zPA23om4NdfYBQF8gKPhuyZNuhzS69zj076j6z7pJLxDkvXiFcm9CfhCB33lxxejya+1795LwnZEkokHsF7xK7PZIQskRERMRzSQyy5MiV+3oRji+HB3HZnBDa7YA4cU7YOCU0ZAleJLYeOU7oI0eKaeNEESMZQ3fkZN+GEE5yPBtxkSUYtLZuuGQJhjvzMW/Neo3Uxfs/unSiiWiRJV1attTvHMotmDAZdExuESeMoWvCTwxZYrtTJ4QsIVGqfR+TPNWcyEGWlCpSxKVHjLjzfYKRJbiEY3RymhgXiEe353aCZ8WA54/m8QMGuPqjgQcekiXtu/jIEpKu2n0GJHZFJ1jeDmPoUj6Xa8gSCA9bD+TKk0u7T5trQmbIb9KjXzdN1uA23bp9wsNWKKfIM0BU2H0AY9sk9DNkiR0jbuLCH5UswUhu3rapnoPvNBT5FGnEJBJZEl8Oj6at2+vvL1geEfJc8V0RBsk1ZEmwvcgQ/85QHpK40vbO4nVq96kb+jP5Uuyx4SAYWYIXiK33uGRJl96+fB77z99x9XmNaJAl5PHg/XYvXOjqMyBZNjpxraPZMmfWazafDVlCXi2nTkLIEkJ27Hvc3OH73SGHFdeGLPnk0KEAPeNlGYos4T3QIYTV3lcM7MOEcEFoTrB9zwt4SZbMWfZb2JwVvuIEpdXRIdTF7gMkPyWMh8+GLNl6wL0XQfyz3jjXXvJ6kMeKz3iU4B2Cl6M9Nj7kK5AvzqSwZg/Fw9GQJS907+DSsw8LIkGW4E2DnY5Xi90XKQhZIiIi4rkkBllCDg/CSwh/sfsgN/D6ILkroTbk+ggWHlOgcFGVv2Bh/dmQJa3ad3HpaWIkdrF2VkDAtRr98TMWqSo16+r72OPCgRdkyQs9+uk5n+vSM6gRH21Eiywxp2AkUnW2QxZgZOLubEJt8O6wxxMjTh9JW7mOBFnSt317133w3CAOnFh3riEqqHpg6zWoXl3PEYosYS6MJ9tlG2CM4zputzvxfaxB/lTJkip1qlT+904MPPCQLOHUje/NTkKK8UW5QogSQzBMeMedFI/kp/QZt2gThnPyxpEAPeNqTdUdZzs5TyAzMBgZt3bHStc94oMJtcFwtPsoQ8l9SUbL9Zpty/V93lk0LUDPlPl9FLIE47v+s76qTsSU2/1eIyYRyBJyU7EWs9c428ktAvFBThETarP10HnXeJNs3JAHjCHZuO2FYoiRtxeu8bexblOphjwikBjkCCHhrH2PcBBJsoSEsOxBO09cc/VVr9NAV/mxy9ZHA9EgS06s8f3/taum
gVc6ddJJVU2oDQlQbR2zflMVh+tIkCVFCxRw3cd4jFBBjes3+/nsAVM9zcB4WYYiS8xcy37Ls+LEla1bdWJZ9hC7z+CZqlVVtaeecrX/ePas3gNJeG73eQEvyRIShvMdvTFhmKuPMBmqjBHuSahN4+YNXTqUGmb8c53a6mtDlgQL1YQYsUu4m5wn7G94PPYd1Ns1LhwQapM+Q3p/onAn8FqE6Ocz+yb3s3NVsUdw/0chS/BSgWixPRXxZCTprDOU1QsIWSIiIuK5JAZZApq17ahju00uEYChRkLWPPkLaoKEOPDYRwwo/wumzvfF9fYe+Lq+NmQJ7snOkzE8WPBOqVqrnuv+nDpSgpFQF1MtJ6GINFkyZd4KPZ9JXPtHQLTIEkgS/cdq69YB7ZR15DvhZJDcJZzuYRQ6yzhi5BIKg5eGMVIjQZaQM8R5okcFAJIBUk7YtFFZAQPbWbkGQxQCgzkMWUIyPK6JATd6VGCgDS8UZ9JBvguMUXOKGRe4N9/ZpnfecfVFEw88JEsIYeEd7dK4xpB7e/5b2uuEWG+q2XB6ZnTIx1HyyRLaiMQFmTZDllDtxjlfr/49dLvtnbJ4va8iBIn4IE0+/vnRTshI7IeXil0doFufLnp+7sM1cd9cG+MbcM+WzzXX7U6yBCOVtlBeN4Ax6JF0z+6LBmISgSwxHoVvTp0b0N65l6/iCKTJ9qOX9WfCJp3Js/edjdHruDN0xiR4dea6goghITnkxfmPvg+4T8/+Q3U7Sb2d1XISikiSJYbYgYh3ti97d58+ULDbo4VokCWQ7pTMJUG3M7SR3FNUDiMMlGvWdtZuZ7l6AEnCdwdJwnUkyBL0ts6a5W+jZHDlsmX1HvfrxYu6beHYsVqPf42e2e9od5Il7G3Ma64h9DPG7lfslxAcpp33xyuEvSwUWULuFu7BIYWzfWAX35plCB2v4SVZAnmAVwZ7h/OP/dO3juk8H6y7XJMgFcKEvFXO8WYNXr55sb42ZAml3U1lNbBqqy+szw7FvPbggibSK1evpPvxYrGfMRyQYJbxVPZxti/duFC3O/c7yiIT+ukkVsx4J1likp6H8roBY6b4SOfh44YGtJvk5XZ7pCFkiYiIiOeSWGTJkasf66SqkBnEew8YNlbn5yA8BqMOHfKLQDJgyEGuDBoxXhMJEByUdzQx4YYswYjEE4X8I+RD4SSQijvBTtJM+E3qNGkDKugkBJEkSzixhCRiPr4XdG106vmya5zXiBZZAob36qXfn9KEEBeU+oU0qFu5sj8PB8Yl2fg5lcNFmdKIVKfBwKXUo5krEmQJxiSlHbkPCWipkgCcxMj2eb5M8FQWGNqjh35mxuFxQrshS0xFBFCvShV/4lmT24QTQdymSdJKOUqIH0gX53M6YVzG+S5IHhgMTkPcSzzwkCwBA4f5vI5q1auhq9NgoGK4QmAYl+VlsQYdbZxw9R/aT5cZ5hSPNmdIiyFLIFDwGsHTghND2oK5Jpuke/Qzr90fLs59cFKTLRjgECTct079Wnreji+29+txGscJpP5/EPtc5FrBvZo2woecxjYVfdCDEHr5tZdc9wQ3Pr2k74kefxRw0mdj5ET3qWokEZMIZAlkO2GM7CfsGVTBqf+sr4QrFWyMHglPacO7kP2FtZxywyQPdxL5kCXsFYBE4uhCpkC2kN/Evr8JvwEL1+909YeLSJIlZjx6VKPjO+FdOLQoWqL0I3u/PC6iQZYACA72CcgDPDbYb1jPn8ieXd0/cEDrUIWG3CSQDCRrZU1mvdb//zt1CpiLtschSygdnDb2D+V+HTtqj0O8BLF18AgxeiSdzRxrU1Benko5Q7p3156MZn9xkiXN6tTRbfSZssD8C5nPfjCid299H4gSfm9NEtm4AMFfIHduvQdzb/ZH9mbuwb5l63sFL8kSgLcghAX7BesraylrPjAeiOSMIkyTfYPcVazf5H3iu3ASEYYsQY91GYKa9Z75S5Up6fK+AITfMCauMJpwAOlD6V7madrqWf18Hbo9r/c/7uvMs2X2QDxO2CfZ/0g8i56TLDEV4Xhv9qpgyWMBpAvlk9HlXdij8ZTkd5mwVlNW2SsIWSIiIuK5JBZZAgi1IVaaqjeQAYTlrNt9PEAHb5OhY6bqZH3oVKxSQw0f/7YmUoyOsxoOpRuLl8LdsZR6vmuvOKsDnLz9hTYYIGHsvnBB+eG2L7yoZixZ72/r2L2vyxMGkE+F00a73eDCvR/0XKEwZto81zivEU2yBJAwD+MU465ciRLauCOBqVPn7MaNuuILOhh+eGZAHjh1CEvp3qaNK0Eq5IozFAYQt42uOUEzZAllEqe89po+xQP9OnTwG9VOvDtzpqpdqZJ+HoidXQsWaK8R5nQmW905f74uH8n7HV+9WrfhUULyQVyeGY8hj3eNk5AJhqOrVun5QyFYbhcv8MBjsgTMWzlL1ahbXf9xj1FH/Ljz5A7sPr5NG2kQJkVLFNHVZ+zwHWMoEtaCkch8VWtWVtMXTo3TGMQYZq04evmAqy8hIOSm94CeusoO96VKAO8VTK9r787afRkDHndwjF2eY8mGBX49DG+S92GQBiN6wPGrB3XMeigsWhuYTDbSiEkEsgSwR1AaF1KDvQNPwokzl7j05qzcookVdPBshEw5eu2TAB3IEqqS4YVRqVotrdugaStXKWEnmCtHzlyPFU5JlRsIHXPN3kKIkK23YO12vUewr9l9TkAicVBQrmIV/Q7FSj6pDxbYf2zdaCFaZAkgOTjJsFlrixcqpP/ot8MgKbmL9wRrPnqQEDbxjPcfa6xd+te/7/xWwc0Aj0KThwRAlkCOEAoDucF92jZq5N8X7HvRZ/ZEyviyD3Gfa9u2+fXYMwghYn+h2o1pZx/gnYv8RqJTaY39yb5PMFDGHtLIfBeQ+lTDcXpCeg2vyRLAPoEXiQ4bKVpIr4t4lzh1yGHV85Xu/vWbsEp77XSWtGfdLlSkoHqyXGlNSlCO2L4vMKV3Tfn4RwUHB4SlUuWM5yPpLF4d9j4JsQJBjvcLes1aN1EHL+zVYax2EtvRk9/Q+wsIldice5AXzMwJEcM9vCZKgJAlIiIinktikiWRgiFLghmRcWHFVl+VDJPATxAc0SZL/ggwZAleHnafwI0HUSBLIgVDlly6d9bVFxdI3mcS+AkSjphEIksiCUOW2O1xASIdD5BQBLnAh2iSJX8UGLLEbhe4EQ2yJFIwZAmJY+2+uADxT8WY6w8uuvoE8UPIEhEREc8luZElhO4QPoOHSsEixQJi1QVuCFni7hcE4kESJEtI6se/ploCJSZtHUF4iElGZMnlT37R3pCElhIClBjVZf5sELJEEApJkSzB4wJPEHKUEO5DpRxbRxAehCwRERHxXJIbWZI7b36tiyEbLJYc12oSwsYHKh3YY5MihCxx9wsC8SAJkiUm4R0g+Z5d+pCErMY9OT6Q3M+ePzkhJhmRJe06dff/3kCY2P3kOrH3krhAWWJ7fFKEkCWCUEiKZAkhPmadyJItizpz+3hAP9f2PhIXEqPC2R8JQpaIiIh4LkmBLCEOm0o
yG/eecvXZmPz/t3ceUFYV9x9/iAG7gqghGCWWxNgTIye26FETey8oxoOY2I3tYBeIIrZojL2BBCxRsSuIKLFgQ0xWJSRuEv8m+ef4V0HYxi5l2fuf77w38+bNfbssC+9t+3zO+Zz45s4tuzl777tfZn5z8z3JkFPPTsY+8VJqm1T9ER1red50z8TUvl3R7hiWaFUCrb7zxsSJqW2YtrIThSWqa6LlhIstsRg6/YOpdn66Kv0XK8o35a3n7HFao+qGxPt3Jyu6QFhy54SnkpvvfTjVHjtu0lT7fNEzIl5m2G2PnyXN2Z51RMppdwxLVKtExcHjdkzbmcKSGR9Pt/f8eGn6WI0qUR0TFfCeNnNyarvqg8TPkeZ86Nnxqf27k4QlAFByukJYgqWzO4YluGJWdqKwBMtvRRcIS7B0dsewBFtvZwpLsPwSlgBAySEswZYkLMHlWUlYgi1YQViCLUhYgi1JWIItSVgCACXng89r+lX8b+0oxGLO/s/8Kxb+6U+jEJuzcu7HIyvnzh6FWMwP/11zWXxfQXRWVXw0PL6nIDo/+++HF8b3FETnJ1/O3jF+rwEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBSUV1dPWjBggX7ICKuqHPmzNlz1qxZ+yAiIiKWy8rKyj3j7ySIiK1R+UeciTTL/PnzPzE7JYiIK2pFRcXMDz74IEFEREQsl5WVlTPj7ySIiK20Ms5EmoWwBBHbKmEJIiIillvCEkRcCQlLELH0EpYgIiJiuSUsQcSVkLAEEUsvYQkiIiKWW8ISRFwJCUsQsfQSliAiImK5JSxBxJWQsAQRSy9hCSIiIpZbwhJEXAkJSxCx9BKWIJbXt956K3nqqaeS+++/P5k8eXJqOyJid5CwBBFXQsKSruayTTZJkkymqMu+9a1k8eDBSfXs2an9Wmvt88/749U98URqe+yyjTe2fRcfd1xqW2z9tdf6Y1f/7W+p7dh5JSxBLI8KRn76058mvXr1MrfSjLdv377Jr371q+T9998vus8111yTai+X3zLPJl3jvvvum9pWLqdNm5aMHDky1b4i3nzzzckT5rkYtyNi+0lY0jlcOHZs6r3F2bTmmknj9tsn9eYeG++3Ii752c/s8Rq33jq1LXbh/ff789e88UZqe+yyb34z+75zzDGpbdipJSzparYUljiXDRiQVH32WWrf1khYgm2RsASx9L755pvJpptuam6h+ZAk9oQTTijYRwHKmuaL6M/Ml8j4eOWyvcOSSy65JFlnnXWS3XffPbWtNU6ZMiUZNGiQ/Rk0kifejojtJ2FJ57ClsCS04dJLU/u2VsISbIOEJV1NF5boRlD7wgt5n3kmWXL00f4Pv37MmNS+rZGwBNsiYQli6b3sssvM7TMbipx99tnJK6+8ksyaNcu+wGtkidpXW2215MUXX/T7bLjhhra9O4clm2++uT1/W8OSm266yf/eCUsQO5aEJZ3DMCxRIBK+w2hbU+59QqNMqr74IrV/ayQswTZIWNLV9GHJD3+Y2lb1r3/5P/xF557r22umT08W3n23Nexf/dFHvr3q009tWxyW1Lz3XtIwcmTSYL6k1z75ZOqcYVhS9Z//JPW//a3tu/B3v7Ofw74thSXVf/5zUn/NNXbf+htvTKrnzEmdCzuuhCWIpffEE0+0L+w9evRIXnvttYJtl5ovnwolpF7o33nnneTXv/61HVGhfbbbbjv7+emnn/b7vGfu79ddd11y+umnJ2eccUZy1113pc4pdazrr7/e9pMKau67775UP/mC+eJ7/vnn23433HCDnRYUhyWaEqRrufPOO1P7/9Y8Q7TtlltuSW2Lffvtt/31S5134sSJfruCJB2rX79+9vxbbbWV/fz444/7PlOnTk2uuOIKf4yLLrqo4Hf06KOPJscff7zdXw4bNiy5+uqrC65DP/MFF1zg99fn+FoRsTQSlnQOw7Ck7sEHU9sbhg/PvyN88km2ff58/55SY555Yf86cx9Xe/huUhCWzJ1rz2nfK8z7R/Vf/1p4PVFYUvPHPyYNV16ZNJjnQe1LL6Wur9mwZN68pM48J3QeWfeHP9i2eH/ssBKWdDVbCkvqzA3D34jGj/fti8wXW9u+2mqF/U0ff6PI3YTCsGTJoYcmyeqr+8+27cADkwVffpm/nlxYsvRHP7LTf8K+urHUmC/Zrm9zYYlt/8Y3CvZNevdOFt5+e+pnxI4pYQli6dV0EvO4tm5tvgzqxV+1OOJ+cvr06b5vqIIBbX/e3Ou/853vpLb/0DxbwiDmJfOlcbPNNkv1kwcffHDBORVcxLVUdDw36sWFJXvuuaf9vO6669ogxu3/+uuv+/0VUMQ/U6iuv3///qlrkscee6ztM3PmzNQ2ed5559ntCoc0RSnevrp57o0aNcr2OfPMM1Pbe/bs6a9Do32+YZ5f4XZ91v9X8TUj4qqXsKRz2GJYMndusnS//bLvDuZ5o5DEts+b5/dZdP75Bfs07rJL9v3DPE9cmwtLVMNR7yVuX6kRK3WPPJK/niAsWXLEEUnSo0dB/8U//3n+OhYUD0uq/v53+z4W7icbd9wxH/hgR5ewpKvpwpKmddaxNwinvWnkAoclhx9ekGq2NSyxN4XBg5OFt91mz+HalJz668mFJfbmsN12Sf0ttySLzjorMd8ms23f/W6y4Ouvbd9iYUntU0/5G9TSvfZK6s2X7SW5G6aOoaQ3/h1gx5OwBLH0qmbJgAEDzO2x8OV9iy22SIYOHZpMmjSpoO9hhx2W9O7d2/bR6A591ogQjbj43ve+Z9vXXnttOyLitNNOS9Zaay3bduCBB/rjHHTQQbZN2xQcKATQKBV3blfwVKMpXNCx/vrr21opcr311vN9XVgSTmvRSBJ3Lo04ce0TJkxI/fyhmlbkfnYVbtVxDjjgAL+/Rq1oVIt+Zje6ZmPzvNLn28wzTaNq9Fntmqajn+ucc87x05a0TefRcXcxz1d3XE3lOdw8Y7XtgQcesNOe1L7zzjvbUSU/+clP7Ge1jzUvB/F1I+KqlbCkcxiGJY3m+RO+wyzbdFPb3mSeM3ov8Pu1MSyRel+qHz06abjqqqTJPIfs8c2zoLqyMns9QVii96NFv/xlstDc7xvNvdy1a5S8O3axsGSpeabZ45rn7KILLkgaRoxImvr2zV6XeRbEvwPskBKWdDWXV+C1aY01kvobbvABhWxrWFJQh+Srr5JlW21l25f175+/HjfHsE8fOw3ItTeYL43uOLWTJ9u2YmGJT5IHDrTnsPuba2/cYYfsNQwZUnDN2DElLEEsj1rZxr2Qx2rEgwKKsH+xmiV3332332fMmDG+3YUVetF/+eWXbZum9CgE0Eowrt+4ceP8/r8zXybVdtZZZ/m2cIqOQgvX7sKSd999N9lggw1sm8IY13fvvfe2bQPN8yD8GYrpwp4dd9wxeeyxx2ybRpJoupA+h6sCFatZoqWXNdVH1x2ucKPQSH011UmBitqaq1myzz772DYFK+EIGV2T2rVqUXzdiLhqJSzpHLamwOvSQY
OSmhkz8vu1NSwx9++a11/37XUTJ/rjuJqOYViyyDw3XV+VEGjq18+2h6P447Ck5u2388e87jrfr/bpp317OLoeO6yEJV1NF5YosHDz46T+0BuDIWdKSN0+zYYlDz6Y/4MuEpbU/f73Bf39cYxV//xn9npyYcmSww4r6KsRIf4movBmQfGwpGndde3nxh/8wM9LlBplYn/OVhRpwvaXsASxvGoaikZDKGDQ6BDzGPeGtUeKhSXh1BLV+dB0HqnpKa49HPEhNWpFU3iOOuqogukvLkTZb7/97GeNJAn30ygWN2IlLPDq6q9om2qPaCSMGwWjmiiun6bjaGRHqMKQsI6IVF0S/YwKfMLgQhYLS5wKRBT+XHjhhfb63LVKBSrq01xY4qYXuSlRzt122822b2Ke1/H5EHHVSljSOQzDEk17Cd9hFp90UtK0/vp2m6bLqF6i3a+NYcmyb3+78PxffunLCiw2zx57PUFYUvvyywX9/YIZvXr5tjgs0agTt3/DxRfn32FuvdW3a2R+wXVgR5SwpKvZUs0Sza1busce2T/Snj3tXDq1+5CjR4+C/uGNq1hYUjAUzqihbG5b9ccfZ6/HFXjN3Xyc1RUV+ZvIyJG2LRWWzJ3rPzdn0wYbFP6M2CElLEEsvapPoqDgmWeeKWjXC79e9s2j3OqmichiYcmQIUN83+a8ytzv1feRRx6xYYBr11SbLbfc0n92hVhdQKDpPvF1KzTQtjAsUeFUdwyFEap3ov+OV/NxozRCNf1F4Yqm3bhpMKEa6RGOFmkuLFGhVhd4SE3XCZdmVoijfs2FJXGtklj9ruLfBSKuWglLOoct1iwx1rz6qt+u8MS2h2HJeecV9HfTZYqFJY3bbps6vgtjbKkCXU9Y4DUaAbLoF7/w29yo9zgsqR81yvdpTr03xdeBHU7Ckq5mi2GJcdEZZ+T/+KdPz7YFI0L8VBdjmH4WC0viAqu2Fom29ejhV7rxBV733rugb3jT0xxAtaXCEtOm+Yn6rDolBUshO6dMKTgudkwJSxBLq6aVrLHGGuZ2mS2aWmy7CpNq+1577eXbi4UlqlGiNqmwQ9NmYlXYVcd0o0gUmNxzzz121MbDDz/s93cjUFy9EBVMDafAhNcdLx3sptJouoobmbLrrrsW9GluZInbrulCqlmiYrOuBkl8rmJhiVbEcUGLrl3hjdrPPfdcfww3QqW5sKRPnz62TdcU//4kywwjll7Cks7h8sISO/rD1TA092rb9vXXfh+934T9XWmAYmGJSgOExVkXfPFFfmTJ0KHZ6wnCkrpJkwqOveSoo7LHWXvt/PnisMQ8+9z+C++6K/3+YtSqo+FxsUNKWNLVbCks0fw5Tc+xf7zmpuCmyjRccon/g642DxbX3w8zyxQPS5aaL60a/aH2qs8/T5aZL5z23MHUGF/gtVev/LA546Izz8wfO7d+ebGwRCGLPtuaJW6VHXOD00gVHaPuoYdSPyd2PAlLEEtvWKvkyiuvLNgWTqE5+eSTfbsLEMLw4I477vB9w5olv/nNb2xwoLonr776qp3q4/oNHz7c97vW3MtduxtZoromrk1LBhfrG4clF198sW1XwOJWpdE0lrBPMbVqjqbx7LHHHva8rl3BjAq+6jg77bSTb3dtgwYN8m2aUuSu68knn/Tthx56qG93I0s01ci13Xvvvb6v+/9Dv2PXVx555JHJ4MGDbf2U8LoRcdVLWNI5bDEs+eKLpOHyy/32xeYZ5rZpWo7alhx8sG9TCOHCj2JhiT3HuHH5c5v7tmt3/xBcsBpOcOyqf/zDjmq3xw6KtMZhSc2bb/r9C2qWvPhissQ8A7QEcfVf/uLbscNKWNLVXF6BV2c4LaZuwgTf3rj99nYdcdUYcSvWyGJhiWqcKJRRsdbG73/ft2vFG389LiwxfVUBWsPk7LFzfZf++Me+b7GwROuk+7677prU33RTdgkvdy7zJTX+HWDHk7AEsfROnDjRjx6RWhlHoxrCGiKq+/Hss8/6fdyUGdXi0Co3mu6iUMFNrVFIMWzYMBu2aBUbtWmpYE3t0VQXt8KNptfo5V/9tOSvO9/o0aPtebRUsdtf16ARIccdd5zd343giMMS7RMuNaxr0Tnjn7uYO+ywg91H+2ukjEKWU0891Ra5VbvaXF8FJ66vfgcjRoxIbr31Vn9ejTjR6BEFHCrs6tpfeeUVu79GiLi2bbbZxi59rGKyGmnj+m+77bY2/NHxXV/Vg4mvGxFXrYQlncPWFHi1mmec+0dWqXcD226eI4vNs0r/ALzMPPvcO0zRsMT0TbRCjem/+JRT/GqhemdxI+Pj1XAUjKguiv7x1rWHSw3HYYm9tlx9RR1f/8Bbb56RWrZYbU3BubBDS1jS1VxuWGK+uGkVG40E8ftp/XJzEwj7ae6eCq+6z8XCElWM1uo64X52+Fq47nguLNFNQssEh301AqV69mzft1hYYtvNeTQyJdzXHlNFasNhdNhhJSxBLI96yXfTP2I15UajRsL+p5gvimEfLTGsdo0acSMuQhWKhPU+9MIf91EgoIKq+m/VP3F9NfUkDFKkAhJXzyQOS+T+++/v+2oqTby9ObUqkJvGE6vzhEVew6k18uijj7aBkUKSsF01SBSYuM+333673X/GjBl+OpPT1VW5/PLLCwIf5xFHHFEwHQkRSyNhSeewNWGJAoZ4RLmms7gp+05N3Vdoof8uFpYsOeQQ26fg2H37JrVTp+avJwhL9A+1bqSK1bxLqWhreB3FwhLVhgwX1/DnMj9H7bRpqd8BdkgJS7qa+uOL58R5zZfHqs8+S+1jnTs3qXvsMVt1WkPQNEVHgYrbt+q//7X9tPyva9OwOA0hU1KqIkXF/vBVQdrOy9NQM9N/4bhx2XM88EB+Wk3O6jlz8scusm2h+ZKvfRWeaEpRfC7suBKWIJZPBQEKTbRqjEZQKAy47bbbUqvASK1Go2kk6qepNM8995zfptEjWtpX+2u7puEUO8aECRNsH51v/Pjxtk11QxSOPGS+2IZ9NRrjMnMf15K8bglh1zcMYZzh9CFdS7y9JRVGaHSHAh1dv6bkqJ5K3E/q9+P6uGk32l81V1y7C0A0kkTXq0DJ7a/iuuqjvjfeeKNd/thtm2q+gI8aNcpuU6FdFcWNz4+IpZGwpHOoYCH13hKoaS2qURLvJ/WOoH/gbRgxIqk1zzDbZu7ffr9cP7072DaVBZg/3y5Uoek9qp1Y9emnhddTWZl/J9HxzLOyfvTopMHcywuWL87p3r9qZs4s3GauWe31V1+dff8ZO9a/U2GnkLAEEUsvYQkirqgKIDSSxXwFSTbaaCNGYiDiCktYgogrIWEJIpZewhJEbK0qCquQxBV1ldT3QMS2SFiCiCshYQkill7CEkRsrWHBVBVkdTVE4n6IiMuTsAQRV0LCEkQsvYQliNhaVUdF9UBUy0TLAMfbERFbK2EJIq6EhCWIWHoJSxAREbHcEpYg4kpIWIKIpZewBBEREcstYQkiroStD0vmzZs3wOwwEBFxRZ0zZ85mFRUVAxERERHLpb5/xN9JEBFbo/KPO
BMBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWDH6xA0AAAAAAAAAAN2ZBuOjxj3iDQAAAAAAAAAA3ZGmTDYwWWD8l/H8DKNNOhyjjAkiIiIiIiIilkWFJe6/G3P/W2O8yNgzAwAAAAAAAADQzViayY4qmWY8z7iLsXdBDwAAAAAAAACAbgTBCAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAl2JN48BW2ts4wNg/AwAAAAAAAADQRTnEmLTSvYx/M75v9wQAAAAAAAAA6IJsahwaWWX8vyLtmxgnGO+xewIAAAAAAAAAdBMUlFTGjQAAAAAAAAAA3ZWWwpJzjKcFn0/KZEedrGe8zDjeeIdx29z2rY235NqvNvbNtYeoz5hMts+9xiMKNwMAAAAAAAAAtC8thSVxzZIpxg+Nnxj/bXzTuMhYazw597+zjB9nsnVP1Hc1u2eWIcYlxi+NT2Syx1a/SVE/AAAAAAAAAIB2Y0XDEoUbD2Ty4YZGhqityXhYrk3cl2vfPfd5i0w2WHnNuI7rZBiWyfa7IGgDAAAAAAAAAGg32hKWfDNoWyOTDUreCdrEMZls32Nzn0fnPu/se+T5cyZ7LgAAAAAAAACAdmdFwxJNtYlpNP4hajs4kw1Hjs99npz7rBV2VK8k9H8y2cBFwQsAAAAAAAAAQLuyomHJV8Fnh8KSh6O2OCyZYVycyU7Dac4Ncn0BAAAAAAAAANqNcoUlz2ayxV17+x4AAAAAAAAAAB2QcoUlw3OfT/A9sqhQrFbVmWbsEW0DAAAAAAAAACg75QpL+ho/N84zHprJBiN9jGNz/Ubm+gEAAAAAAAAAtCvlCkvEDsY5uXanCrveZewZ9AMAAAAAAAAAaDe+bRwQN+ZQe//g8yaZbP+YzY39orY1jQONa0ftmnazq3FoJhukbFa4GQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKOKqGIAAAAdSURBVAAAAAAAAAAAAAAAAAAAAAAAAAAAAKBj8P+i3c1w7pzeqAAAAABJRU5ErkJggg==)" - ], - "metadata": { - "id": "8vCtShhBjzTd" - } - }, - { - "cell_type": "markdown", - "source": [ - "The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally \"hide\" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible.\n", - "\n", - "The initial startup time and final teardown time known as \"bubbles\", where only a subset of the stages are being executed while the pipeline is being \"filled\" or \"drained\". The bulk of the time is spent in the \"steady-state\" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor." - ], - "metadata": { - "id": "Qs3F--kwiOJm" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Deriving a Double-Buffered Pipeline\n", - "\n", - "Now lets look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`:\n", - "\n", - "
\n",
-        "for i in range(N):\n",
-        "  copy_in(A[i], X)\n",
-        "  Y = X + 1\n",
-        "  copy_out(Y, A[i])\n",
-        "
\n", - "The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to \"pre-fetch\" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously.\n", - "\n", - "In order to reason about the code transformation we will make, lets unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:\n", - "
\n",
-        "  # Itr 1\n",
-        "  copy_in_start(A[0], X)\n",
-        "  copy_in_wait(X)\n",
-        "  Y = X + 1\n",
-        "  copy_out_start(Y, A[0])\n",
-        "  copy_out_wait(Y)\n",
-        "\n",
-        "  # Itr 2\n",
-        "  copy_in_start(A[1], X)\n",
-        "  copy_in_wait(X)\n",
-        "  Y = X + 1\n",
-        "  copy_out_start(Y, A[1])\n",
-        "  copy_out_wait(Y)\n",
-        "\n",
-        "  # Itr 3\n",
-        "  copy_in_start(A[2], X)\n",
-        "  copy_in_wait(X)\n",
-        "  Y = X + 1\n",
-        "  copy_out_start(Y, A[2])\n",
-        "  copy_out_wait(Y)\n",
-        "\n",
-        "  # Itr 4\n",
-        "  copy_in_start(A[3], X)\n",
-        "  copy_in_wait(X)\n",
-        "  Y = X + 1\n",
-        "  copy_out_start(Y, A[3])\n",
-        "  copy_out_wait(Y)\n",
-        "
\n", - "\n", - "Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows:\n", - "
\n",
-        "  # Prologue\n",
-        "  copy_in_start(A[0], X[0])\n",
-        "  \n",
-        "  # Itr 1\n",
-        "  copy_in_start(A[1], X[1])\n",
-        "  copy_in_wait(X[0])\n",
-        "  Y[0] = X[0] + 1\n",
-        "  copy_out_start(Y[0], A[0])\n",
-        "  copy_out_wait(Y[0])\n",
-        "\n",
-        "  # Itr 2 - Steady state\n",
-        "  copy_in_start(A[2], X[0])\n",
-        "  copy_in_wait(X[1])\n",
-        "  Y[1] = X[1] + 1\n",
-        "  copy_out_start(Y[1], A[1])\n",
-        "  copy_out_wait(Y[1])\n",
-        "\n",
-        "  # Itr 3 - Steady state\n",
-        "  copy_in_start(A[3], X[1])\n",
-        "  copy_in_wait(X[0])\n",
-        "  Y[0] = X[0] + 1\n",
-        "  copy_out_start(Y[0], A[2])\n",
-        "  copy_out_wait(Y[0])\n",
-        "\n",
-        "  # Itr 4 - No copy-in\n",
-        "  copy_in_wait(X[1])\n",
-        "  Y[1] = X[1] + 1\n",
-        "  copy_out_start(Y[1], A[2])\n",
-        "  copy_out_wait(Y[1])\n",
-        "
\n", - "\n", - "Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration.\n", - "\n", - "
\n",
-        "  # Prologue\n",
-        "  copy_in_start(A[0], X[0])\n",
-        "  \n",
-        "  # Itr 1\n",
-        "  copy_in_start(A[1], X[1])\n",
-        "  copy_in_wait(X[0])\n",
-        "  Y[0] = X[0] + 1\n",
-        "  copy_out_start(Y[0], A[0])\n",
-        "\n",
-        "  # Itr 2 - Steady state\n",
-        "  copy_in_start(A[2], X[0])\n",
-        "  copy_in_wait(X[1])\n",
-        "  Y[1] = X[1] + 1\n",
-        "  copy_out_start(Y[1], A[1])\n",
-        "  copy_out_wait(Y[0])\n",
-        "\n",
-        "  # Itr 3 - Steady state\n",
-        "  copy_in_start(A[3], X[1])\n",
-        "  copy_in_wait(X[0])\n",
-        "  Y[0] = X[0] + 1\n",
-        "  copy_out_start(Y[0], A[2])\n",
-        "  copy_out_wait(Y[1])\n",
-        "\n",
-        "  # Itr 4 - No copy-in\n",
-        "  copy_in_wait(X[1])\n",
-        "  Y[1] = X[1] + 1\n",
-        "  copy_out_start(Y[1], A[2])\n",
-        "  copy_out_wait(Y[0])\n",
-        "\n",
-        "  # Epilogue\n",
-        "  copy_out_wait(Y[1])\n",
-        "
\n", - "\n", - "Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:\n", - "\n", - "```\n", - "# Prologue\n", - "copy_in_start(A[0], X[0])\n", - "\n", - "# Main loop\n", - "for i in range(N):\n", - " cur_slot = i % 2\n", - " next_slot = (i + 1) % 2\n", - "\n", - " if i < N:\n", - " copy_in_start(A[i+1], X[next_slot])\n", - " \n", - " copy_in_wait(X[cur_slot])\n", - " Y[cur_slot] = X[cur_slot] + 1\n", - " copy_out_start(Y[cur_slot], A[i])\n", - "\n", - " if i > 0:\n", - " copy_out_wait(Y[next_slot])\n", - "\n", - "# Epilogue\n", - "copy_out_wait(Y[1])\n", - "```\n", - "\n", - "If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:\n", - "\n", - "- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.\n", - "- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.\n", - "- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`.\n", - "\n", - "By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:\n", - "```python\n", - "def double_buffered_pipeline(\n", - " grid: tuple[int, ...],\n", - " kernel: Callable,\n", - " in_slices: Callable,\n", - " out_slices: Callable):\n", - " # Prologue\n", - " copy_in_start(in_hbm[in_slices(0)], in_sram[0])\n", - "\n", - " # Main loop\n", - " grid_size = prod(grid)\n", - " for i in range(grid_size):\n", - " cur_slot = i % 2\n", - " next_slot = (i + 1) % 2\n", - " if i < grid_size:\n", - " copy_in_start(in_hbm[data_slices(i+1)], in_sram[next_slot])\n", - " copy_in_wait(in_sram[cur_slot])\n", - "\n", - " kernel(inputs, outputs)\n", - "\n", - " copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])\n", - " if i > 0:\n", - " copy_out_wait(out_sram[next_slot])\n", - "\n", - " # Epilogue\n", - " copy_out_wait(out_sram[1])\n", - "```" - ], - "metadata": { - "id": "ZcSzl4N6pPbG" - } - }, - { - "cell_type": "markdown", - "source": [ - "Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API." - ], - "metadata": { - "id": "ziBuvv8jDgxo" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Pallas Pipelining API\n", - "\n", - "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", - "\n", - "\n", - "### Grid\n", - "\n", - "The program **grid** is a tuple of integers specifying the number of subproblems as an array. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop.\n", - "\n", - "```\n", - "# For grid (N, M, K)\n", - "for n in range (N):\n", - " for m in range(M):\n", - " for k in range(K):\n", - " kernel()\n", - "```\n", - "\n", - "The kernel will be invoked a total of `prod(grid)` times. 
-        "For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", - "\n", - "### BlockSpecs\n", - "\n", - "A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying the `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to; by default this will be SRAM.\n", - "\n", - "```python\n", - "pl.BlockSpec(\n", - "    block_shape: tuple[int, ...],\n", - "    index_map: Callable,\n", - "    memory_space: pl.MemorySpace\n", - ")\n", - "```\n", - "There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", - "\n", - "### Kernel\n", - "\n", - "The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs; instead, all outputs should be written into the output buffers that are passed into the kernel. All input and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`).\n", - "\n", - "```python\n", - "def kernel(*buffers):  # input buffers first, then output buffers\n", - "  # ... perform compute\n", - "  # ... store results into the output buffers\n", - "```\n", - "\n", - "The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`.\n", - "\n", - "\n", - "### Pallas Call\n", - "\n", - "The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature:\n", - "```python\n", - "def pallas_call(\n", - "    kernel,\n", - "    grid: tuple[int, ...],\n", - "    in_specs: Sequence[PyTree[BlockSpec]],\n", - "    out_specs: PyTree[BlockSpec],\n", - "    out_shape: PyTree[jax.ShapeDtypeStruct],\n", - ") -> Callable:\n", - "```\n", - "`pallas_call` will return a callable that, when invoked with input values, returns outputs of the same shape as `out_shape`.\n", - "\n", - "`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTree for `in_specs` should match that of the input buffers supplied to the kernel, and the PyTrees for `out_specs` and `out_shape` should also match.\n" - ], - "metadata": { - "id": "niMr39cPkJ2m" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Example - Elementwise Kernel revisited\n", - "\n", - "Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, this time using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration."
- ], - "metadata": { - "id": "0mHZ63eAq_8j" - } - }, - { - "cell_type": "code", - "source": [ - "# Note: This is a TPU example.\n", - "\n", - "total_shape = (4096, 4096)\n", - "block_shape = (512, 512)\n", - "\n", - "def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):\n", - " o_ref[...] = x_ref[...] + y_ref[...]\n", - "\n", - "def add_matrices_pipelined(x: jax.Array, y: jax.Array):\n", - " return pl.pallas_call(\n", - " add_matrices_pipelined_kernel,\n", - " grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),\n", - " in_specs=[\n", - " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", - " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))\n", - " ],\n", - " out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", - " out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),\n", - " )(x, y)\n", - "\n", - "x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)\n", - "y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)\n", - "result = add_matrices_pipelined(x, y)\n", - "np.testing.assert_array_equal(\n", - " result, x + y\n", - ")" - ], - "metadata": { - "id": "iqr_qjONAHN9" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "It turns out that with this API, writing a pipelined kernel is not much more lines of code than writing our original naive addition kernel!" - ], - "metadata": { - "id": "UWHD0_qm6DL7" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Parameterizing a Kernel\n", - "\n", - "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). 
Let's write a a function that does so:" - ], - "metadata": { - "id": "BZ-4U6Cv6cvU" - } - }, - { - "cell_type": "code", - "source": [ - "def add_matrices_pipelined_param(\n", - " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", - ") -> jax.Array:\n", - " m, n = x.shape\n", - " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", - " return pl.pallas_call(\n", - " add_matrices_kernel,\n", - " out_shape=x,\n", - " in_specs=[block_spec, block_spec],\n", - " out_specs=block_spec,\n", - " grid=(m // bm, n // bn),\n", - " )(x, y)\n", - "\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y\n", - ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y\n", - ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y\n", - ")" - ], - "metadata": { - "id": "RZTAiwrZ6srD" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Sharp edges\n", - "\n", - "While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs.\n", - "\n", - "### Buffer Revisiting\n", - "\n", - "In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers are write only**.\n", - "\n", - "Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, it's updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general.\n", - "\n", - "There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`.\n", - "\n", - "\n", - "### Reductions and accumulation\n", - "\n", - "**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**\n", - "\n", - "Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.\n", - "The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. 
This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n", - "\n", - "As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array.\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], - "metadata": { - "id": "vO8VkbYj_ral" - } - }, - { - "cell_type": "code", - "source": [ - "x = jnp.ones((8, 1024, 1024))\n", - "jnp.sum(x, axis=0)" - ], - "metadata": { - "id": "4qz1ET-_f9fJ", - "executionInfo": { - "status": "ok", - "timestamp": 1744763773938, - "user_tz": 420, - "elapsed": 244, - "user": { - "displayName": "Justin Fu", - "userId": "17543197034567316452" - } - }, - "outputId": "e43067ef-933a-45a5-912a-e224151cfa60" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Array([[8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " ...,\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" - ] - }, - "metadata": {}, - "execution_count": 5 - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first." - ], - "metadata": { - "id": "yX762DRrgCOG" - } + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YkOjspo5BKPD" + }, + "outputs": [], + "source": [ + "import jax\n", + "from jax import numpy as jnp\n", + "from jax.experimental import pallas as pl\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "shnVghWUSvpx" + }, + "source": [ + "## Memory Hierarchies\n", + "\n", + "The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capicity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:\n", + "- **Registers** are the the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.\n", + "- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.\n", + "SRAM on modern ML accelerators typically range in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).\n", + "It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register.\n", + "- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, the latency is roughly on the order of 10x longer to access compared to SRAM.\n", + "- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. 
We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices.\n", + "\n", + "\n", + "\n", + "\n", + "![memory_hierarchy](../../_static/pallas/pipelining_mem_hierarchy.svg)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WvW6Lo7d2jfb" + }, + "source": [ + "\n", + "In order to perform computation on values X and Y that live in HBM, we need to:\n", + "\n", + "1. Copy the values x and y into SRAM.\n", + "2. Load the values from SRAM into registers.\n", + "3. Execute the computation and store the result into registers.\n", + "4. Store the values in the output registers into SRAM.\n", + "5. Copy the output values in SRAM back to HBM.\n", + "\n", + "Let’s implement a Pallas function that does just that!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 108, + "status": "ok", + "timestamp": 1744764235906, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 }, + "id": "IrPhDFnT3Nvw", + "outputId": "8bc03872-fd9f-4610-9d53-d4b46be560f4" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "# Note: This is a TPU example.\n", - "\n", - "# Warning: this implementation is incorrect!\n", - "def incorrect_sum_kernel(x_ref, o_ref):\n", - " o_ref[...] += x_ref[...]\n", - "\n", - "def incorrect_sum(x: jax.Array,\n", - " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", - " reduction_size, *out_shape = x.shape\n", - " grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))\n", - " return pl.pallas_call(\n", - " incorrect_sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],\n", - " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", - " )(x)\n", - "\n", - "result = incorrect_sum(x)\n", - "print(result)" - ], - "metadata": { - "id": "ZEi1_vQVf-81", - "executionInfo": { - "status": "ok", - "timestamp": 1744763774254, - "user_tz": 420, - "elapsed": 79, - "user": { - "displayName": "Justin Fu", - "userId": "17543197034567316452" - } - }, - "outputId": "581744b7-ddc1-4dc1-98ec-03c852772eda" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[[65. 65. 65. ... 66. 66. 66.]\n", - " [65. 65. 65. ... 66. 66. 66.]\n", - " [65. 65. 65. ... 66. 66. 66.]\n", - " ...\n", - " [71. 71. 71. ... 72. 72. 72.]\n", - " [71. 71. 71. ... 72. 72. 72.]\n", - " [71. 71. 71. ... 72. 72. 
72.]]\n" - ] - } + "data": { + "text/plain": [ + "Array([[2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " ...,\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):\n", + " # Load x and y from SRAM into registers\n", + " x_regs = x_sram_ref[:, :]\n", + " y_regs = y_sram_ref[:, :]\n", + " # Execute a vectorized add\n", + " z_regs = x_regs + y_regs\n", + " # Store the output values in registers back into SRAM\n", + " z_sram_ref[:, :] = z_regs\n", + "\n", + "\n", + "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n", + " # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.\n", + " # It will then copy `x` and `y` from HBM into SRAM.\n", + " z = pl.pallas_call(\n", + " add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + " )(x, y)\n", + " # pallas_call will also copy the output from SRAM back into HBM.\n", + " return z\n", + "\n", + "\n", + "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", + "add_matrices(x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gGjtwv9u3UNK" + }, + "source": [ + "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", + "\n", + "`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.\n", + "\n", + "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into pallas_call. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`.\n", + "\n", + "Pallas exposes access to lower level memory spaces like SRAM but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:\n", + "\n", + "- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n", + "\n", + "- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. 
The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.\n",
+    "\n",
+    "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "0Ebs2pCDgsEW"
+   },
+   "source": [
+    "## Pipelining Basics\n",
+    "\n",
+    "\n",
+    "How can we take advantage of the strengths of each type of memory in the hierarchy, and operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.\n",
+    "\n",
+    "The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages):\n",
+    "\n",
+    "1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.\n",
+    "2. **compute**: Load `X` into registers, compute a result, and store it in SRAM `Y`.\n",
+    "3. **copy_out**: Copy result `Y` back into HBM `A[i]`.\n",
+    "\n",
+    "Note that there is a data dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`.\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "8vCtShhBjzTd"
+   },
+   "source": [
+    "\n",
+    "![pipelining_example](../../_static/pallas/pipelining_example.svg)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Qs3F--kwiOJm"
+   },
+   "source": [
+    "The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally \"hide\" the cost of transferring between HBM/SRAM behind computation and keep the processor busy with as much uptime as possible.\n",
+    "\n",
+    "The initial startup and final teardown phases are known as \"bubbles\": only a subset of the stages is executing while the pipeline is being \"filled\" or \"drained\". The bulk of the time is spent in the \"steady-state\" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor."
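,
+    "\n",
+    "The bubble overhead can be made precise with a little arithmetic. The following is an idealized sketch, assuming one timeslot per stage and perfect overlap, which matches the 6/8 case in the figure:\n",
+    "\n",
+    "```python\n",
+    "def compute_utilization(n_subproblems: int, n_stages: int = 3) -> float:\n",
+    "    # N subproblems in an n_stages-deep pipeline occupy N + (n_stages - 1)\n",
+    "    # timeslots, and compute is active in N of them.\n",
+    "    return n_subproblems / (n_subproblems + n_stages - 1)\n",
+    "\n",
+    "print(compute_utilization(6))     # 0.75 -- the 6/8 case in the figure\n",
+    "print(compute_utilization(1000))  # ~0.998 -- bubbles wash out for long pipelines\n",
+    "```"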
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "ZcSzl4N6pPbG"
+   },
+   "source": [
+    "### Deriving a Double-Buffered Pipeline\n",
+    "\n",
+    "Now let's look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`:\n",
+    "\n",
+    "<pre>
\n",
+    "for i in range(N):\n",
+    "  copy_in(A[i], X)\n",
+    "  Y = X + 1\n",
+    "  copy_out(Y, A[i])\n",
+    "
\n", + "The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to \"pre-fetch\" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously.\n", + "\n", + "In order to reason about the code transformation we will make, lets unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:\n", + "
\n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[0], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[0])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 2\n",
+    "  copy_in_start(A[1], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[1])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 3\n",
+    "  copy_in_start(A[2], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[2])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 4\n",
+    "  copy_in_start(A[3], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[3])\n",
+    "  copy_out_wait(Y)\n",
+    "
\n", + "\n", + "Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows:\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[2])\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration.\n", + "\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[2])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Epilogue\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:\n", + "\n", + "```\n", + "# Prologue\n", + "copy_in_start(A[0], X[0])\n", + "\n", + "# Main loop\n", + "for i in range(N):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + "\n", + " if i < N:\n", + " copy_in_start(A[i+1], X[next_slot])\n", + " \n", + " copy_in_wait(X[cur_slot])\n", + " Y[cur_slot] = X[cur_slot] + 1\n", + " copy_out_start(Y[cur_slot], A[i])\n", + "\n", + " if i > 0:\n", + " copy_out_wait(Y[next_slot])\n", + "\n", + "# Epilogue\n", + "copy_out_wait(Y[1])\n", + "```\n", + "\n", + "If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:\n", + "\n", + "- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.\n", + "- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.\n", + "- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`.\n", + "\n", + "By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:\n", + "```python\n", + "def double_buffered_pipeline(\n", + " grid: tuple[int, ...],\n", + " kernel: Callable,\n", + " in_slices: Callable,\n", + " out_slices: Callable):\n", + " # Prologue\n", + " copy_in_start(in_hbm[in_slices(0)], in_sram[0])\n", + "\n", + " # Main loop\n", + " grid_size = prod(grid)\n", + " for i in range(grid_size):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + " if i < grid_size:\n", + " copy_in_start(in_hbm[data_slices(i+1)], in_sram[next_slot])\n", + " copy_in_wait(in_sram[cur_slot])\n", + "\n", + " kernel(inputs, outputs)\n", + "\n", + " copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])\n", + " if i > 0:\n", + " copy_out_wait(out_sram[next_slot])\n", + "\n", + " # Epilogue\n", + " copy_out_wait(out_sram[1])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ziBuvv8jDgxo" + }, + "source": [ + "Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "niMr39cPkJ2m" + }, + "source": [ + "## Pallas Pipelining API\n", + "\n", + "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", + "\n", + "\n", + "### Grid\n", + "\n", + "The program **grid** is a tuple of integers specifying the number of subproblems as an array. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop.\n", + "\n", + "```\n", + "# For grid (N, M, K)\n", + "for n in range (N):\n", + " for m in range(M):\n", + " for k in range(K):\n", + " kernel()\n", + "```\n", + "\n", + "The kernel will be invoked a total of `prod(grid)` times. 
For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n",
+    "\n",
+    "### BlockSpecs\n",
+    "\n",
+    "A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor for `pl.BlockSpec` involves specifying the `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to - by default this will be SRAM.\n",
+    "\n",
+    "```python\n",
+    "pl.BlockSpec(\n",
+    "    block_shape: tuple[int, ...],\n",
+    "    index_map: Callable,\n",
+    "    memory_space: pl.MemorySpace\n",
+    ")\n",
+    "```\n",
+    "There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n",
+    "\n",
+    "### Kernel\n",
+    "\n",
+    "The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs; instead, all outputs should be written into the output buffers that are passed into the kernel. All input and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`).\n",
+    "\n",
+    "```python\n",
+    "def kernel(*buffers):\n",
+    "  # buffers contains all input buffers, followed by all output buffers\n",
+    "  # ... perform compute\n",
+    "  # ... store results into the output buffers\n",
+    "```\n",
+    "\n",
+    "The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`.\n",
+    "\n",
+    "\n",
+    "### Pallas Call\n",
+    "\n",
+    "The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature:\n",
+    "```python\n",
+    "def pallas_call(\n",
+    "    kernel,\n",
+    "    grid: tuple[int, ...],\n",
+    "    in_specs: Sequence[PyTree[BlockSpec]],\n",
+    "    out_specs: PyTree[BlockSpec],\n",
+    "    out_shape: PyTree[jax.ShapeDtypeStruct],\n",
+    ") -> Callable:\n",
+    "```\n",
+    "`pallas_call` will return a callable that, when invoked with input values, returns outputs of the same shape as `out_shape`.\n",
+    "\n",
+    "`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "0mHZ63eAq_8j"
+   },
+   "source": [
+    "### Example - Elementwise Kernel revisited\n",
+    "\n",
+    "Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration."
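,
+    "\n",
+    "Concretely, a blocked index `(i, j)` with `block_shape=(512, 512)` selects elements `x[512*i:512*(i+1), 512*j:512*(j+1)]` of the source buffer. The helper below (illustrative only, not part of the Pallas API) spells out that convention:\n",
+    "\n",
+    "```python\n",
+    "def block_to_slices(block_shape, blocked_idx):\n",
+    "    # Map an index_map result to element-level slices of the source buffer.\n",
+    "    return tuple(\n",
+    "        slice(size * idx, size * (idx + 1))\n",
+    "        for size, idx in zip(block_shape, blocked_idx)\n",
+    "    )\n",
+    "\n",
+    "# block_to_slices((512, 512), (1, 2)) == (slice(512, 1024), slice(1024, 1536))\n",
+    "```"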
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "iqr_qjONAHN9"
+   },
+   "outputs": [],
+   "source": [
+    "# Note: This is a TPU example.\n",
+    "\n",
+    "total_shape = (4096, 4096)\n",
+    "block_shape = (512, 512)\n",
+    "\n",
+    "def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):\n",
+    "  o_ref[...] = x_ref[...] + y_ref[...]\n",
+    "\n",
+    "def add_matrices_pipelined(x: jax.Array, y: jax.Array):\n",
+    "  return pl.pallas_call(\n",
+    "      add_matrices_pipelined_kernel,\n",
+    "      grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),\n",
+    "      in_specs=[\n",
+    "          pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n",
+    "          pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))\n",
+    "      ],\n",
+    "      out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n",
+    "      out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),\n",
+    "  )(x, y)\n",
+    "\n",
+    "x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)\n",
+    "y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)\n",
+    "result = add_matrices_pipelined(x, y)\n",
+    "np.testing.assert_array_equal(\n",
+    "    result, x + y\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "UWHD0_qm6DL7"
+   },
+   "source": [
+    "It turns out that with this API, writing a pipelined kernel takes barely any more lines of code than our original naive addition kernel!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "BZ-4U6Cv6cvU"
+   },
+   "source": [
+    "### Parameterizing a Kernel\n",
+    "\n",
+    "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). 
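\n",
+    "\n",
+    "To see how the block size drives the number of pipeline iterations, here is the grid arithmetic for a few block choices (plain Python, reusing `total_shape` from the example above):\n",
+    "\n",
+    "```python\n",
+    "total_shape = (4096, 4096)\n",
+    "for bm, bn in [(512, 512), (256, 256), (128, 128)]:\n",
+    "    grid = (total_shape[0] // bm, total_shape[1] // bn)\n",
+    "    print((bm, bn), grid, grid[0] * grid[1], \"iterations\")\n",
+    "```\n",
+    "\n",
+    "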
Let's write a a function that does so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RZTAiwrZ6srD" + }, + "outputs": [], + "source": [ + "def add_matrices_pipelined_param(\n", + " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", + ") -> jax.Array:\n", + " m, n = x.shape\n", + " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", + " return pl.pallas_call(\n", + " add_matrices_kernel,\n", + " out_shape=x,\n", + " in_specs=[block_spec, block_spec],\n", + " out_specs=block_spec,\n", + " grid=(m // bm, n // bn),\n", + " )(x, y)\n", + "\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vO8VkbYj_ral" + }, + "source": [ + "## Sharp edges\n", + "\n", + "While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs.\n", + "\n", + "### Buffer Revisiting\n", + "\n", + "In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers are write only**.\n", + "\n", + "Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, it's updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general.\n", + "\n", + "There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`.\n", + "\n", + "\n", + "### Reductions and accumulation\n", + "\n", + "**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**\n", + "\n", + "Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.\n", + "The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. 
This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n", + "\n", + "As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array.\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1744763773938, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 }, + "id": "4qz1ET-_f9fJ", + "outputId": "e43067ef-933a-45a5-912a-e224151cfa60" + }, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "This result is completely wrong!\n", - "\n", - "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` is initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.\n", - "\n", - "After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`." - ], - "metadata": { - "id": "MglScPDD9618" - } - }, - { - "cell_type": "code", - "source": [ - "# Note: This is a TPU example.\n", - "\n", - "def correct_sum_kernel(x_ref, o_ref):\n", - " @pl.when(pl.program_id(2) == 0)\n", - " def _():\n", - " o_ref[...] = jnp.zeros_like(o_ref)\n", - " o_ref[...] += x_ref[...]\n", - "\n", - "def correct_sum(x: jax.Array,\n", - " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", - " reduction_size, *out_shape = x.shape\n", - " # We moved the reduction to the last axis of the grid.\n", - " grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)\n", - " return pl.pallas_call(\n", - " correct_sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],\n", - " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", - " )(x)\n", - "\n", - "result = correct_sum(x)\n", - "print(result)" - ], - "metadata": { - "id": "XtgD4nMa9_Bd", - "executionInfo": { - "status": "ok", - "timestamp": 1744763774523, - "user_tz": 420, - "elapsed": 104, - "user": { - "displayName": "Justin Fu", - "userId": "17543197034567316452" - } - }, - "outputId": "9ef07cdf-9e22-4dc8-c17f-c96172639801" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[[8. 8. 8. ... 8. 8. 8.]\n", - " [8. 8. 8. ... 8. 8. 8.]\n", - " [8. 8. 8. ... 8. 8. 8.]\n", - " ...\n", - " [8. 8. 8. ... 8. 8. 8.]\n", - " [8. 8. 8. ... 8. 8. 8.]\n", - " [8. 8. 8. ... 8. 8. 
8.]]\n" - ] - } + "data": { + "text/plain": [ + "Array([[8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " ...,\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = jnp.ones((8, 1024, 1024))\n", + "jnp.sum(x, axis=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yX762DRrgCOG" + }, + "source": [ + "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 79, + "status": "ok", + "timestamp": 1744763774254, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 }, + "id": "ZEi1_vQVf-81", + "outputId": "581744b7-ddc1-4dc1-98ec-03c852772eda" + }, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "\n", - "## Analyzing the performance\n", - "\n", - "What is the performance of a pipelined kernel? This question can vary depending on where the bottleneck is the hardware is. We are typically interested in 3 quantities:\n", - "- **Memory latency** $α$, the minimum latency of a memory transfer.\n", - "- **Memory bandwidth** $β$, the rate in bytes/second that we can transfer from HBM to SRAM.\n", - "- **FLOP/s** $F$, or floating-point-operations per second, the number of calculations per second that the processor can perform.\n", - "\n", - "We refer to a program as **compute-bound** if the processing speed FLOPs/s is the bottleneck, and as **memory-bound** if the bandwidth or latency are the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.\n", - "\n", - "Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.\n", - "\n", - "In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\\alpha + X/\\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. 
Assuming that N is large and there is enough work to produce a long pipeline, the dominating term in the runtime is $F$, the processing speed of the accelerator.\n", - "\n" - ], - "metadata": { - "id": "BckuFg6qcnVw" - } - }, - { - "cell_type": "markdown", - "source": [ - "\n", - "![compute_bound](data:image/svg+xml;base64,<svg version="1.1" viewBox="0.0 0.0 1018.0314960629921 103.2782152230971" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l1018.0315 0l0 103.27821l-1018.0315 0l0 -103.27821z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l1018.0315 0l0 103.27821l-1018.0315 0z" fill-rule="evenodd"/><path fill="#cfe2f3" d="m42.842518 559.6667l199.74803 0l0 27.433044l-199.74803 0z" fill-rule="evenodd"/><path fill="#000000" d="m110.33193 576.75635l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.6249924 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.1406174 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.0937424 0 1.8749924 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8124924 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.81321 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 
0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m242.59055 559.6667l121.82678 0l0 27.433044l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m267.49374 576.75635l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 
0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m164.6693 587.09973l199.74803 0l0 27.433105l-199.74803 0z" fill-rule="evenodd"/><path fill="#000000" d="m232.1587 604.1894l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 
-13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.54685974 0.140625 0.93748474 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.6718597 -9.6875l1.7656097 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0624847 -0.203125zm7.8906097 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m364.41733 587.09973l121.82675 0l0 27.433105l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m389.32053 604.1894l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 
-2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235107 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m219.14174 19.102362l111.18111 0l0 27.433073l-111.18111 0z" fill-rule="evenodd"/><path fill="#000000" d="m242.34767 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 
0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m330.32285 19.102362l121.82675 0l0 27.433073l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m355.22604 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 
1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.625702 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m452.1496 46.535435l121.82678 0l0 27.433071l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m477.0528 63.62509l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 
-2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.1718445 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.0155945 0.625 -2.1405945 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.7812195 -0.9375 0.7812195 -2.875q0 -1.84375 -0.7655945 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m573.9764 19.102362l121.82678 0l0 27.433073l-121.82678 0z" fill-rule="evenodd"/><path 
fill="#000000" d="m598.8796 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 
-1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m695.80316 46.149605l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m720.70636 63.239265l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 
1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m330.32285 46.535435l111.18109 0l0 27.433071l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m353.52878 63.62509l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 
1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m452.1496 19.102362l111.18109 0l0 27.433073l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m475.35556 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 
-0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m573.9764 46.149605l111.18109 0l0 27.433075l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m597.1823 63.239265l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 
-3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/></g></svg>)" - ], - "metadata": { - "id": "NDY4mcae_nMO" - } - }, - { - "cell_type": "markdown", - "source": [ - "In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\\alpha + X / \\beta$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. 
Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\\beta$ is orders of magnitude slower than the processing speed $F$." - ], - "metadata": { - "id": "HFWcaAudW4z1" - } - }, - { - "cell_type": "markdown", - "source": [ - "![bandwidth_bound](data:image/svg+xml;base64,<svg version="1.1" viewBox="0.0 0.0 1100.3333333333333 140.48556430446195" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l1100.3334 0l0 140.48557l-1100.3334 0l0 -140.48557z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l1100.3334 0l0 140.48557l-1100.3334 0z" fill-rule="evenodd"/><path fill="#cfe2f3" d="m83.22835 428.5328l199.74803 0l0 27.433075l-199.74803 0z" fill-rule="evenodd"/><path fill="#000000" d="m150.71776 445.62247l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 
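The cell above can be sanity-checked numerically. The sketch below is not part of the original notebook; the latency and bandwidth constants are assumed for illustration only. It evaluates the alpha + X / beta model and classifies transfer sizes as latency- or bandwidth-bound, with the crossover at alpha * beta bytes, where the two terms are equal.

# Minimal sketch of the "alpha + X / beta" transfer-time model discussed in
# the cell above. ALPHA and BETA are assumed, illustrative constants, not
# measurements of any real device.

ALPHA = 1e-6   # fixed per-transfer latency in seconds (assumed)
BETA = 1e12    # sustained memory bandwidth in bytes/second (assumed)

def transfer_time(num_bytes: float) -> float:
    """Runtime of one memory transfer under the alpha + X / beta model."""
    return ALPHA + num_bytes / BETA

# Below the crossover size ALPHA * BETA (1 MB with these constants) the fixed
# latency term dominates (latency-bound); above it, the X / BETA term
# dominates and copies serialize on the saturated bus (bandwidth-bound).
for x in (1e3, 1e6, 1e9):  # 1 KB, 1 MB, 1 GB
    regime = "latency-bound" if ALPHA > x / BETA else "bandwidth-bound"
    print(f"{x:10.0e} bytes -> {transfer_time(x):.3e} s  ({regime})")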
-  {
-   "cell_type": "markdown",
-   "source": [
-    "![bandwidth_bound](data:image/svg+xml;base64,<svg version="1.1" ... [inline SVG path data elided; the encoded figure data continues below] ...
0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m448.70865 428.5328l199.74802 0l0 27.433075l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m516.19806 445.62247l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 
-1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m648.45667 428.5328l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m673.3599 445.62247l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 
-2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m570.53546 455.96588l199.74799 0l0 27.433075l-199.74799 0z" fill-rule="evenodd"/><path fill="#000000" d="m638.02484 473.05554l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 
0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm20.355896 0l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m770.28345 455.96588l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m795.18665 473.05554l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 
2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m892.1102 483.39896l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m917.0134 500.48862l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 
0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 
-1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.265625l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m692.3622 483.39896l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m759.8516 500.48862l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 
-5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m61.349083 29.82677l199.74802 0l0 27.433071l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m128.83849 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.6718674 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.5781174 0 2.5781174 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8124924 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.6093674 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 
1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m261.0971 29.82677l121.82678 0l0 27.433071l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m286.0003 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 
0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m261.0971 57.259842l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m328.58652 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 
1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m460.84515 57.259842l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m485.74835 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.7812805 0.796875 0.7812805 2.453125l0 6.640625l-1.6406555 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 
-0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m460.84515 29.82677l199.74805 0l0 27.433071l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m528.33453 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 
1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m660.5932 29.82677l121.82672 0l0 27.433071l-121.82672 0z" fill-rule="evenodd"/><path fill="#000000" d="m685.4964 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 
-1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m410.6957 -117.91601l136.62991 0l0 27.433067l-136.62991 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.0005 
-100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 
-3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#d9ead3" d="m547.33813 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m577.5614 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 
-0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625671 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm14.944763 -0.109375l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m410.69836 -172.0105l136.62988 0l0 27.433075l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m446.6287 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 
-0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.0781555 -5.640625l1.640625 0l-3.6875305 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.8906555 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m547.34076 -172.0105l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.6455 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 
1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 
0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m683.98315 -172.0105l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m714.2064 -154.92084l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 
-2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.335327 -2.0625l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m547.3458 -145.22572l136.62988 0l0 27.433075l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m583.2762 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 
0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm20.355896 0l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m683.9882 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m716.29297 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 
0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#d9ead3" d="m820.6306 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m850.8539 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 
-0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm11.585327 1.46875l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m683.9836 -117.91601l136.62994 0l0 27.433067l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m719.91394 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 
-0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m820.62604 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m852.9308 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 
0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 
…</svg>)"
-   ],
-   "metadata": {
-    "id": "gqcCDsGg_sca"
-   }
-  },
-  {
-   "cell_type": "markdown",
-   "source": [
-    "If the bottleneck is specifically the latency and not the bandwidth, the problem can be fixed by inserting additional pipeline stages, at the cost of the extra SRAM required to store more buffers. With sufficient stages, the kernel becomes either compute bound or bandwidth bound, depending on which bottleneck is hit first during the steady-state phase of the pipeline. The downside of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages, so it is important to make the pipeline long enough that the bubble does not take up a substantial fraction of the total runtime.\n"
-   ],
-   "metadata": {
-    "id": "V4YQCZf1W7X5"
-   }
-  },
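To make the bubble trade-off concrete, here is a back-of-the-envelope sketch (not part of the original notebook; the stage and iteration counts are made up): with S pipeline stages and N iterations of useful work, roughly S - 1 of the N + S - 1 total steps are spent filling the pipeline, so the bubble fraction shrinks as the grid grows.

```python
# Hypothetical model of the multi-stage pipeline bubble: with `num_stages`
# stages and `num_iterations` steps of useful work, the pipeline spends
# roughly (num_stages - 1) extra steps filling and draining.

def bubble_fraction(num_stages: int, num_iterations: int) -> float:
    """Fraction of total runtime spent in the pipeline bubble."""
    return (num_stages - 1) / (num_iterations + num_stages - 1)

for num_stages in (2, 4, 8):
    short = bubble_fraction(num_stages, 16)
    long = bubble_fraction(num_stages, 1024)
    print(f"{num_stages} stages: {short:.1%} bubble over 16 iterations, "
          f"{long:.2%} over 1024")
```

Under this model, more stages hide more latency but grow the bubble linearly, which is why the paragraph above stresses that the grid must be long relative to the stage count.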
-  {
-   "cell_type": "markdown",
-   "source": [
-    "![latency_multi_stage](data:image/svg+xml;base64,<svg …
0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m140.81102 241.05249l199.74803 0l0 27.433075l-199.74803 0z" fill-rule="evenodd"/><path fill="#000000" d="m208.30043 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 
0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.4687653 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.6406403 0zm15.105911 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m340.55905 241.05249l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m365.46225 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 
-0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 
1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m262.6378 268.48557l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m330.1272 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 
2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m462.38583 268.48557l121.82675 0l0 27.433075l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m487.28903 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 
0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m462.38583 241.05249l199.74802 0l0 27.433075l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m529.87524 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 
1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m662.13385 241.05249l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m687.03705 258.14215l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 
0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m584.2126 268.48557l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m651.702 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 
-0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m783.96063 268.48557l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m808.86383 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 
3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 
-1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m575.7664 307.61154l96.724365 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m590.9264 307.61154l66.40442 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m590.9264 314.19156l-13.159973 0l0 -13.160004l13.159973 0l0 13.160004z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m657.3308 301.03156l13.159973 0l0 13.160004l-13.159973 0l0 -13.160004z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m570.23883 302.6929l107.77954 0l0 42.015747l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m610.8232 329.6129l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948975 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235046 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m194.55118 158.44882l111.18109 0l0 27.433075l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m217.75713 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 
0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.85939 0l0 1.1875l-10.85939 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m305.73227 158.44882l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m330.6355 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 
-1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 
-    "![...](data:image/svg+xml,<remainder of inline SVG path data elided: vector glyph outlines, no recoverable text></svg>)"
-   ],
-   "metadata": {
-    "id": "Sj5PFl0s_yc6"
-   }
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[65. 65. 65. ... 66. 66. 66.]\n",
+      " [65. 65. 65. ... 66. 66. 66.]\n",
+      " [65. 65. 65. ... 66. 66. 66.]\n",
+      " ...\n",
+      " [71. 71. 71. ... 72. 72. 72.]\n",
+      " [71. 71. 71. ... 72. 72. 72.]\n",
+      " [71. 71. 71. ... 72. 72. 72.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Note: This is a TPU example.\n",
+    "\n",
+    "# Warning: this implementation is incorrect!\n",
+    "def incorrect_sum_kernel(x_ref, o_ref):\n",
+    "  o_ref[...] += x_ref[...]\n",
+    "\n",
+    "def incorrect_sum(x: jax.Array,\n",
+    "                  block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n",
+    "  reduction_size, *out_shape = x.shape\n",
+    "  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))\n",
+    "  return pl.pallas_call(\n",
+    "      incorrect_sum_kernel,\n",
+    "      grid=grid,\n",
+    "      # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
+    "      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],\n",
+    "      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),\n",
+    "      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n",
+    "  )(x)\n",
+    "\n",
+    "result = incorrect_sum(x)\n",
+    "print(result)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "MglScPDD9618"
+   },
+   "source": [
+    "This result is completely wrong!\n",
+    "\n",
+    "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values, and thus we need to initialize it to zeros before we begin accumulation.\n",
+    "\n",
+    "After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating that we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "executionInfo": {
+     "elapsed": 104,
+     "status": "ok",
+     "timestamp": 1744763774523,
+     "user": {
+      "displayName": "Justin Fu",
+      "userId": "17543197034567316452"
+     },
+     "user_tz": 420
+    },
+    "id": "XtgD4nMa9_Bd",
+    "outputId": "9ef07cdf-9e22-4dc8-c17f-c96172639801"
+   },
+   "outputs": [
+    {
-   "cell_type": "markdown",
-   "source": [
-    "Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via argument to the pipeline emitter). See the platform-specific pipelining documentation for more details."
-   ],
-   "metadata": {
-    "id": "ar4NVxxFfKEb"
-   }
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[8. 8. 8. ... 8. 8. 8.]\n",
+      " [8. 8. 8. ... 8. 8. 8.]\n",
+      " [8. 8. 8. ... 8. 8. 8.]\n",
+      " ...\n",
+      " [8. 8. 8. ... 8. 8. 8.]\n",
+      " [8. 8. 8. ... 8. 8. 8.]\n",
+      " [8. 8. 8. ... 8. 8. 8.]]\n"
+     ]
   }
- ]
-}
\ No newline at end of file
+   ],
+   "source": [
+    "# Note: This is a TPU example.\n",
+    "\n",
+    "def correct_sum_kernel(x_ref, o_ref):\n",
+    "  @pl.when(pl.program_id(2) == 0)\n",
+    "  def _():\n",
+    "    o_ref[...] = jnp.zeros_like(o_ref)\n",
+    "  o_ref[...] += x_ref[...]\n",
+    "\n",
+    "def correct_sum(x: jax.Array,\n",
+    "                block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n",
+    "  reduction_size, *out_shape = x.shape\n",
+    "  # We moved the reduction to the last axis of the grid.\n",
+    "  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)\n",
+    "  return pl.pallas_call(\n",
+    "      correct_sum_kernel,\n",
+    "      grid=grid,\n",
+    "      # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
+    "      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],\n",
+    "      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),\n",
+    "      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n",
+    "  )(x)\n",
+    "\n",
+    "result = correct_sum(x)\n",
+    "print(result)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "BckuFg6qcnVw"
+   },
+   "source": [
+    "\n",
+    "## Analyzing the performance\n",
+    "\n",
+    "What is the performance of a pipelined kernel? The answer can vary depending on where in the hardware the bottleneck is. We are typically interested in 3 quantities:\n",
+    "- **Memory latency** $\alpha$, the minimum latency of a memory transfer.\n",
+    "- **Memory bandwidth** $\beta$, the rate in bytes/second that we can transfer from HBM to SRAM.\n",
+    "- **FLOP/s** $F$, or floating-point operations per second, the number of calculations per second that the processor can perform.\n",
+    "\n",
+    "We refer to a program as **compute-bound** if the processing speed (FLOP/s) is the bottleneck, and as **memory-bound** if the bandwidth or latency is the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.\n",
+    "\n",
+    "Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. 
However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.\n",
+    "\n",
+    "In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\alpha + X/\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply it by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady state of the pipeline. Assuming that $N$ is large and there is enough work to produce a long pipeline, the dominating term in the runtime is $N (Y/F)$, which is governed by $F$, the processing speed of the accelerator.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "NDY4mcae_nMO"
+   },
+   "source": [
+    "\n",
+    "![pipelining_compute](../../_static/pallas/pipelining_compute_bound.svg)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "HFWcaAudW4z1"
+   },
+   "source": [
+    "In a **memory-bound** regime, it is useful to identify whether the problem is the latency or the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\alpha + N (X/\beta)$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal, as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\beta$ is orders of magnitude slower than the processing speed $F$."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "gqcCDsGg_sca"
+   },
+   "source": [
+    "\n",
+    "![pipelining_bandwidth](../../_static/pallas/pipelining_bandwidth_bound.svg)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "V4YQCZf1W7X5"
+   },
+   "source": [
+    "If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages, at the cost of the additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute- or bandwidth-bound, depending on which bottleneck we hit first during the steady-state stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages, so it is important to make sure the pipeline is long enough that the bubble does not take up a substantial amount of the total runtime.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Sj5PFl0s_yc6"
+   },
+   "source": [
+    "\n",
+    "![pipelining_latency](../../_static/pallas/pipelining_latency_multistage.svg)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "ar4NVxxFfKEb"
+   },
+   "source": [
+    "Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via an argument to the pipeline emitter). See the platform-specific pipelining documentation for more details."
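To make the bound analysis above concrete, here is a back-of-the-envelope sketch (not part of the patch itself) that plugs hypothetical hardware constants into the runtime model $(\alpha + X/\beta) + N (Y/F)$ from the notebook text. Every number below is an illustrative assumption, not a measurement of any real TPU or GPU; the `max(...)` generalizes the steady-state term so the same estimate covers both the compute-bound and bandwidth-bound cases.

```python
# Back-of-the-envelope pipeline runtime estimate (illustrative numbers only).
alpha = 1e-6   # assumed memory latency: 1 microsecond per transfer
beta = 1e12    # assumed memory bandwidth: 1 TB/s from HBM to SRAM
F = 1e14       # assumed processing speed: 100 TFLOP/s

# Per-iteration work for one 256x256 float32 block of an elementwise kernel:
X = 2 * 256 * 256 * 4  # bytes moved per iteration (one block in, one block out)
Y = 256 * 256          # floating-point ops per iteration (one op per element)
N = 1024               # number of pipeline iterations (grid size)

bubble = alpha + X / beta           # cost of filling the pipeline once
steady = N * max(Y / F, X / beta)   # steady state runs at the slower of compute and copy

print(f"bubble: {bubble:.2e} s, steady state: {steady:.2e} s")
# Here Y/F (~6.6e-10 s) is far smaller than X/beta (~5.2e-7 s), so this
# elementwise kernel is memory-bound: the copies, not the FLOPs, set the
# steady-state rate, and the bubble is negligible relative to the total runtime.
```

Under these assumed constants an elementwise kernel cannot be compute-bound; only raising the arithmetic intensity $Y/X$ (as a matrix multiplication does) tips the balance toward the compute-bound regime the text describes.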
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "last_runtime": {
+    "build_target": "//experimental/users/justinfu/pallas:colab",
+    "kind": "private"
+   },
+   "provenance": []
+  },
+  "jupytext": {
+   "main_language": "python"
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md
index 92b2f2b6bcca..f5570f645429 100644
--- a/docs/pallas/pipelining.md
+++ b/docs/pallas/pipelining.md
@@ -1,6 +1,7 @@
 ---
 jupyter:
   jupytext:
+    main_language: python
     text_representation:
       extension: .md
       format_name: markdown
@@ -41,9 +42,7 @@ It's reasonable to expect the latency to access SRAM to be on the order of 10x l
-![memory_hierarchy](data:image/svg+xml,<inline SVG elided; it drew the memory hierarchy: Registers (Fastest, Lowest Capacity), SRAM/Caches (Fast, Low Capacity), DRAM/HBM (Slow, High Capacity), Network (Slowest, Highest Capacity)>)
-
-
+![memory_hierarchy](../../_static/pallas/pipelining_mem_hierarchy.svg)
@@ -125,7 +124,9 @@ Note that there is a data-dependence between steps 1-3, and we cannot trivially
-![pipeline_fig](data:image/png;base64,<inline base64 PNG data elided: pipeline figure, no recoverable text> ...
A1W5eiW9TpQoXVxd+PB3QqJeo7raPThv/jyqzFNPqvZd2qn8BfNpt+E5y2dqHU7HsmXPpipULu96luHjhurfg6OXD7j6QuHcBye1JwfPVKVGZf2M5nnX7lgZMD/vdvji+wHjeT/aF66Zq69r1I393cyYQY+HCJm7wvfsNshrwrge/bq5+hhD3+L18119EUMikiX7z99Rzdu9EPs9ZdLv+UTuvGrQiPHq1tf/8euQnBsCPkfOXFonXez6wPptk9vsP0PefEv1HzZG7xXoFiv5pPb8MDoQ7uxp89dscz3L6ClzdJ89b3zA46PbS4P814R9vvTqCPX2wjWqQKEi+jkyZMyoXujRT1399FfXeBunYr50tfXs7/udI5eJ3RcNPIwSWUKujbmjRum9gvdlP2hcs6a6sWNHgN7+pUtVzYoVVYrY/6vosY4vGT8+QId1l/2JxNrMgV6q2D2fNfyXCxe0J6FZy7NmzhywhwHW+5nDh+s9jz0JvadKllQbp08P0OvZtq2qUKqU612mDRmi90n2vIvvvqsK5c2r52Au2s9u3Kj12NO4D/ejP03sftWiXr2wwzTxVLTbShUpot/NbvcKXpMlH/98R02aOU576fEd4YVX+5ma6sC5PQF6G3av0Wt3ihS+34tiJYtq8tupA0HPfkL+jkpVfbnnWKMhDsw+9MkvH+oQyppBvC4+/O6m3rN69e/h6osPhLy0bt9S34/7QpL3fKV7AImx/+xu/XwzFgfmq+JQgXYIEq4h8zNlzqjnYa/k+e37ORGMUGnSsrH+Lnknuy+SELJERETEc0kssgSCJHWatGrj3lP+NsJPIDjwNOG67+A39GI9fMJ0vw4GIYQCf7Cs231ctxmyJG269Nqjw+jOXrFZt/cZOExfYyhCyDRr2zHgWWYt36T1Epp7JC6yJHXqNAFeIHi00I5Rbs9hgNGOTrtO3V19Zt5Q471CtMgSPEh4xxeaNVP/vnZNt+EZAWGRMvZ34uODB/0hJk8WK6a+dBhxexct0joQGqaNccz38gsv+EkQPEBo4w9WEtUZXQxS2s9t2qSvDRGRKUMGddNhSEOy0E6yO9MWDlkCgoXh7Jg3T7f169DBnzSPZ8XjhXaew57XYEOsUW3PZzD7Dd//m2BeOZGG12QJHhO8S/O2Tf0VXEg6V7dhHW204hWBFwbEAcYr5IAZS34OjDV0TRtkCfNhDGIk0/bB19dV+UpPaSPz2oMLug1yAT08OpzPQzw2RrD9nPEBd+RUqVOpTXvW+tswngsXK6yNUnPfcMkSEE4YzvtndulxwRLBrt+1WvdNnjXe1RcxJBJZwpqZJWs2lTV7DvX62Gl6L2jxXCf9vpAD6EDGP/V0Fb3ndO0zMHadXq1DIRlH2KOT2IBwYS4IFwiSpZv3xv4uNdPzjX17vtY5++G3ek9r0LSV63kgW8pWqOxqjw92zpKCRYpp0oXn6TVgqJo0a6k+JOA58GS0x4cCHpbsTRBEvIvdHy08jBJZwl6g/+83aqT3glkjRqhc2bOr7Fmy+Il21lVsC4iLBW++qdf8Z2vV0uMgNsxcU157Tevlz5VL71nLJ07UpDp6DWL3nmyZM2vvxIVjx2qCnHZIGDMeIob7FsyTR62bNk3tid3DzJ4FuW70IDaYy36XkS/5ckTg/YiXIWQO1/WrVdPPbPbHvu3b63Y8G9kr5o8erb1FsmTKpD7Ys8c1byh8e+qUGtS1q54vGEnvFbwmS3q+/KJ+p8bNG6pFa+eqcW+/qXLmyqkyZ83s9/yjHbuh5JMl9Ho5c8nb6pnGvr2k76De/rlGT/btveTugIiHYJk4Y6zKmCmjKly0kN/L4qVBvkMS1mfns5AYlfaE5pKCKMHLMHOWTHqtn73sHdV7QE99CMA+gVcjeruOvafnn/DO2IDxEBr696RHR30Ngd+gyTO6bfz0MQGEfnxgrlGTR+jvi+/W7o80hCwRERHxXBKDLDl5+wu9kHbo1sfVN3LyTG3A8TnHE7m1d4atc+Tqx3o8IS5cG7IEg9fWtUNbGjZrrb1PjLcJeKZxc20EJzRmOy6ypHHLdgF6GKWQNI1btHXNYXD2zjd6LOFGdh9lhOlLqDEcCUSLLOnz/PP6Z2qfZHFqNmnQIB16Y0iAnfPnu8YTJoPxCqnCNYYn1xh4RoekdIyvW7lywFhDPGybO1dfG7IEg9S+Dx4inKyZ68chS5rVqaPJH2fOEfC3S5e09wzGtz2vgSFumNfuw0Cnz0kIeQWvyZKuvTvrnyPhL872fad3qmFjh2iyhFNB3nfpxoWu8cRR02dCTSBLIFDsMBeMU/SmL5yqr/ee2qGvB43o79cxhqZ9mhgfyBXCOGOIOoFXB30Y1FxHmizZeXSrHocRb/e9+/4G3Td26ihXX8SQSGTJ81176fXEDtmEWIDQYP0naTbvbyfPhnDHM9C5XkNOsIaT98O0sV/guQi5YvJJ4ZWSKlVqvZ4bva2HLwS9TzgIRpbw/+HdA2f9bewvefMXUrnz5neNjwsQPMzDcxUtUVqdjvnKpRMtPIwCWXJ33z6fzWGtqWath9gglOWJ7NlV8UKFtHeI0YHAxouR8cZbD7KEcRDdRo8S7iRCpZ2E3qb98lbf/0En2QJZwt7x4d69/jYOB/B6YQ5DnodDlnAdLAwH7xLaCGd1joVIwdulZew72fPGBfYq5gKEvtr9XsJLsgSSgZ9rszZNA9oJh+Rd8WrE0xDypFCRggFeGniI4D3B/6OD533hkYYsITTFOR97E+1jpozU14RTcu0kWgAEBeGV4VagMSCclPdgvXe2T5vnsxNGThqur8MlS0C4YThOGBII4PHotVcJELJERETEc0kMsmTltoN6MeVUzO4zgFBBB/diuw/kL1jYT6QYsgQPEVuvS58Buo9TP67JacK1ITjwNsEwNt4nCUFcZAk5SmxdXLfrP9vC1W5w5cEveiz5Vuy+w1fu6b7u/Qa7+rxGtMiS6uXL63AZu92JXu3a6e/BmV/EwJysGSIFsgSD0KmDAYoO+Uuc7e/N8Z3mbJ09W18bAzoYEYEXCsaRCcV5HLKE5Kyc8JETxUbmjBkDxtvAXZv5DMHjBK7m9IXyTIkUvCZLcHsm/MZud6JLL9+J8ZWPAwkQQH4S+pZtWqSvIUs4HbT1GIsebsumjTAdQn4wirkmV0ratGnUzc+vuMaHwvLYezO37foMMETpM0ZqpMkS3K4ZR0lHu49TT/oMUeMJEoksgfx+umpNVzv5rU7e+lx/rvVMYx366MxrZUAfpLrpgyyhzdaDdOA7XLPjiL7GO5Fr420CCKPB29DsQQlBMLKE8BtbD28WPETs9rjA806Zt0LnKyGMB2+VI1fuu/SigYdRIEvMmnhs1SpXH2E4kA3Gu/HtoUNdOuSdog+ShGtDlji9RQCVYgiFMd6M4MezvjBaJ5EBWdK2YUPXfQivQffC5s36+nHIEhPW6vSONCBciDAkZ0hpKFCiHoLe5N5q06BB1EoIe0mWTJ49Qb9PsPxUEBr8sW+8G4Otoe8d9HklG889Q5ZsPxKY04o9JEfO7AEJTwnzzJ0
3l9/DkTxaeIJQWca+TygwNx4lVWtWdvVBumTJlkVVqva0vvaaLFnx7hK9zxFGZLxabn2RsP0yoRCyRERExHNJDLJk3mrfgk2Mt91nQLJWdHA1tvtAySfLacORz4YswTXa1us98HXdhzcK15zCEZ9OhRmu35g0Q/fbSe/CQVxkCQaorZs5S9aQZAmJbDHOba8UwEkn8w4eNcnV5zWiRZYQWoPrs93uhHFzDpa41HiHbJ4xQ19DlnBC59QxZAlkhLMdwoF2myw5sWaN6z6c0tFnCJu4yJLKYZAleJVgCNerUiUonmvc2DWvAWWCmW/FpEmuPsKE6Lv63nuuvkjDa7IEMoAYbrvdCUJqeF/jauwE7tP0mZwekCXGcHSCUBz0Xujewd82bpovISYGscljQky4PTY+mNwgi9bNc/V98M0N3WdiwuMiSzBCaU8oWYIbOeOCGeDmu5m/ararL2JIBLIEggNC04RzxoXCxUrofcRuBxAcfDdm34AsceYOMVj27j6tB/HANd4meHhUrFJDX+Nxkj3HE6pJq+ddY8NBMLIEb0lbj/nxmLHbw4E5vIhkieKE4GEUyJLXe/bU7/jZkSOuPgPjrRfMcxFvDP3/6DcywpAlV7ZuDdAjDCZfrlwBbb/85tFokyXBPBfZH9A1ScMfhyx5/tlndVso3H3/fdfc8cGEibLn2n1ewEuypP9QX34vknrbfQYkBUcnmOci4ZP0GZLBkCXBCHXylEC+m2vCefT3+BtRg4cf185k4+HgxqeX9Djn3uXE01UqaG8VPntNljiBBybjX31jgKsvkhCyRERExHNJDLKE3CSxt1aj35rt6sO9eMPek+rCvR+0wWtCbWwQP26MRkOWEHNu6zEeg5P4dNNG7g/a8CohASuwx4WDSJIloMxTTwc13mcu851sUBXI7vMa0SJLSJyKp4XdTogKxuvnR4/qSjF8D58cOuTSe2fYMN1nvCkiQZaYayeId0+XJo3/5JDPtNl6JMCLjywhKV/5IMn7wsGnh32/e8EMbk4NyeFih/d4Aa/JElx5SVRnt0OMQCCQ94NTLL6LYFVdiD+nj/wlXEOWkGjV1jt+7ZDWc4bdEKpDnhHirldt9ZFTCYndNjC5QYJ5cJj7Go8WTi+5Nm7dBoZwSShZAsiJgru43U4YE3PacfMRRSKQJeS14r0ehywxHokm2ThkSTDPviUbfZ47JFw1bSRgZe/S1dnW+cK5FltJxcNFMLLEEDFOPA5ZAiB4gnmsRAMPo0CWUGGNn0M4ZAm5pOw+9h/6bLLETg6bELJkdN++rvtA9qNrPAYhS2wPSTDsN/InFFlCVTR+D5dNnKjfLRjwerHnjg+XtmzR9xrYpYurzwt4SZb0e7WPfpdwyJIlGxa4+oxHok2W2JVhwFNPl9PJxc31jc8uqzSx9gO5TbguW75M0KTi8SE+ssR4sPDZkCXkIXHqmMOCSJIl7NH8/tWpX8vVF0kIWSIiIuK5JAZZQjJTvCiCVX6BuDA5RkqXq6BP5Sgx7NQxJEWPl1/T14YsofKBUw+jGS8SKgg42ynti/6gNyboOM9x0xe4niMcRJosIXyH58HIdraTmPBR3bgfF9EiS0xVGPukjuo3tFO21xiSJOazx1O9gEz/xviLBFlCHhWn3k/nzumwGWci2RxZs6pqTz0VoEfeFciKYGSJs8xw5xYtAuLgneOJm4/PGCW+nRwqTpdv8p2QtPCZqlVd+l7Aa7LElLeluoCz3ZAXhNlgxPI5WF6OarWqaMIDg5JryBIMuONXAxO3kpBO/3ze3xDQ3qx1Ex0G1Llnx9h1KbffZToh4JQRozhY9YM3JvhIPhL7cc2JH9drtgUm+Ove15dY0UmWYHyXKlPSNacNkuPipm0b8Jx0cuL4KO8UNhKBLAHsG5Wq1XK1k3ibJNqs1bXrP6uTgjuJdIPqtetr8sGUmYcsIbeVrUfycX4uVCwzbVSV4Xds2Lhpqk3HF3UFt4TmwzKIJFkCAUToqrMakAEJbYuXKuNqjwYeRoEsYc/g5xQsDIckqJR5pw8dE2rjxMHfPPmmDhmiryNBlgTLSWW8Ak3oTOv69TUhb+uRowu9UGTJkO7dgz4jYM8B5Fmx+wBEO96RwfYgStszrzMHi5fwkiyBwOZdgoXhdOvTRZMK2w75CgVALts6m/etD+gzZAk5r5x6JCWHtG7YtH5AO56KhMmQSJxxj5JsmzAcZ6iNE9yXtZ/qcFzvOekjb00OEwMOGmh/FLKkUfOGsbZtPVc7+x7jSZxr90USQpaIiIh4LolBlgBTrpDQEkgNCJSXh/jcEE1ZYFPNBlLFuEOv3XVMJ7MjqZ459TNkCX94kiAWN2y8Rig1THuwUz2qIEDYYGCeu/udqz8cRJosIWacuHOejbkw1Ck5yXslRnJXEC2yhMSsVJ/hj//r27frttPr1+uShwVy59YJ9wi/obIA4SsQJxh6P5w5o/r/FhozyuFlEQmyBIN20dixujoPJ4tNatfWPwtnlRlTUpjqACSQPb9pk5+4cZIlx1f7vAswcol/pw1iiLhxEsaaSjwf7d+vQ3DQDeYO7gQhOOjhcYNx+93p0zokiOeJRnJX4DVZQsldjMySpUv4SzmSxI4qMvyhT3UBDEI8LCjFS0gJ1xhqeGvw/bz82u8hKKYaDl4ZECYYmpAtVMKhyo3JT2Jgwl/IVYLLtv184cKcYPYZ2FPHcPOMlCpOly6tKlehjA7zQc+c/EFkEIoDyUM4ENV+aHeSJaZ8MM9IEln7ngbEz/N7i1HLSSgu14TlMJ+nyV1BIpElz3XuoYkGZ0JWYBJ8s+abfCOj3poVoLP10Hk9Fl3TBlnCXnHgwof+NvatIsVL6f2IMErnHBA1eD5CQsQVShoOIkmWsC/xvnb46zuLfX/8Ps5zPg4eRoEsocSv3ketfFWsu7w7Xh7/vHJF5c6RQ1eLgeAwOpDRrPOMN3tTJMgS9jFDdgD2D8oROxOIQ9jb92GdJxEt7WY8exTX5NQyeob8sd+ZpOcQ6uyrcZEloGLp0jr/Cnufs53qP8zLnmaP8QJekiWQBFRVa/V8i4B2SAX2UQgDqrCx1xQolN+V4BUigN8L451nyBLIded80xdM0e0mgbgBnoq0U0oeQt2Q+gkF3ik8h50r5a3ZE/X8Jt8K+ynXLdo1C9AbOnqwbneSJeaggjH2/ZwghJTvyk4uO/RN35wJTYieUAhZIiIi4rkkFllCLDfVcExGfsBiT0Ub58kXFQRIwmd0QLGSTwaUCDZkCcZtxkyZ/XokrovLa8QYys3adHD1hYtIkyUAt22SwTrfl3EY5rZuNBAtsgRAIlCK0fnukCXO3BuE4NSoUCFAB2KCcBSn4RcJsgQ3Zgxacx+8SuwcIZzOUf7R+TwYp91atQogSyB78BYxOsYdfNeCBdpAd47nJHHOyJEB94kLxOLz/8aMxaNl+uuvu/S8gtdkCaAyQd78gd8xhqvz9A6DjtMzp07q2N8LiBKICaMHWUJJSIgDpy4eKM6ywwaMfSL3E3qdOnr5gKs/XDAPCWIxzJ335Xns+1JRwLkukh
iQUpB8dpIlJiYckOTWvqcTGM18H857U2nIJocijkQiSyDS8S7Bs5DSweQUMeQ55YHRgVSHeICQgIxGB/KetZpxzoSnkCUkAidcBUJ+wozFuhIOY8nBZd+f5OXme8aT0e4PF5EkSwhtZTzeNIQK8b6EpFLlB6+S8x997xoTDTyMAlkChvbwheuRy4MQFJKp4hnI+g0BgQ6efxAZkPaUDl46YYJqXLOmHjfm5Zf9c0WCLGGthhwh+SwVzCqUKqXbnJV09i7yJYdGb8bw4eqtwYP1Z/YW2p1kS85s2TS5wf5mPDTZh9AjnIdwHEh9yBj2jC1Bys47AdkCmc/9qETHeEgj5sMbx9b3Cl6SJcB4UDRt9axOToqXIWsunoSmahqEOms3pD3V196e/5afeB/w+sv+uQxZQnLT5zq11fMxP96NkPEQL85749Vn9jabwEgI2P+o2MPBAsQH9+WwgOcgDOfu97f8ulzr/wed2+n3gOzgfSl37CRLjJcjiWOHjHrVdU+DkzeO6LF4t5Bzi3t36Pqc3sNIaGu/c6QhZImIiIjnklhkicGhy/d0To7pi9bqqi92P7h4/0e1cP1OnZPEWTLRwJAlECucGOKRgn4o449ksIyheoHdFy6uffY3/fx4xXCNJwjXwe6L4Y23i90eDLwv+UkwZik7afdHE9EkSwDuxJAVGLO4PnNiZusAjEFCcyAbnOWBDUjI5zQkDe7t36++OXkyoA2vDNpNjg9DlvAM6OLFwn04+bPnAzwzVRHWxD7P7V27dBvjyCti65E0lhM5TjFNOxUJIIq4H8QN3jL2PUIBbxTuTcK9YO/sJaJBlgCMvc1712lDjBwgcRlgVH+ZtXS69rawywMDDFwMQz5zEgYJEV/ODoxLDF27/VFAwlUSvUJ6hCJfMEDxPKEKDrHfkC14j5AQ1qlHzhO+lxPXD7vmsEG1BTxvZi55O+S9I4pEIksAazFkON56/H8uVLS4Gj9jUYAOIZ54VOTKk0/rQLC079Zbl2t36kGW1G/SUhPheDVSHrhy9do6rMe+L6A0MeQFhIrdlxAQDvTqyIn+a4iTzr1ecekNGDY2aLUeG+xBhCHxPrwvni8QRRAptm608DBKZAmYP3q0zifFu2dIl0574tk5sA4tX66r2kBcQMRTqY3qY06d9W+/rT0A2Tec7eRGsRNzExqJLmSHaYMsoYwx4adUPiNROKGTwcKElk+cqL1deGaIGIgaCBXm/PrECb8e5AeJ0gvny+dPvsoBwczY+xpyJVXKlPqwYc+iRa77BAMhNzwX4xjPPBAutp6X8JosAVTFKVbS9x3hoUgpYTuPCessObQoPQ/5gfcfa7RTx5AlrLGGvIdsp0SwHQZpYJLM2qGXCcXZmBOaoMmYKaOer3DRQjoHl534nD2I98OThXep26C29mKENIEoMnp4QKKXv2A+nSTWvp8T7EPPtmikPSX172mBvDqxq/GY9BJCloiIiHguiU2WRAJOssTuiwucxOE+/aix5MkF0SZL/ghwkiV2nyAQ0SJLIgUnWRIOCP3hhAw3artPEAYSkSyJJCAXnGE58YFwHX5vxkyb5+oTBOJhFMmSPwogS8hZZbcL3IgGWRIpGLIkPgLeCaq5kYCVMr92nyB+CFkiIiLiuSQnsgSPD04RKc8bTJ9QF7w5wkFcJ4pJDUKWuPsFvyOpkiVUQCCvB67YJHglz4ezn6R/eLmEg2AeLskGyZAsIYyUSjyEgSamx8afBQ+FLBGEQFImS6h6hofH4JEDXX2C8CBkiYiIiOeSFMiSbUcuxf5BU0hNnb/S1ecEYTLkPyEOvOOLL7mqAuCizDzhIL6ylEkFyZEswfUYV+b4YroFfz6yhDhtkrva7TaIvY5dHrULMklX7X7i0OkLB+EazkkSyYwsofJalmy+5JtDxkxx9bNvoBMfqtdp4BqbVPFQyBJBCCRFsoQcKOwN5EHBq4Qyws5+E/ITDiD27fmTE4QsERER8VySAlmSUNiVCwRxIzmSJYLw8WcjSxICz5OfJgckEbJkyJtvqRlLN7jabVC9rHGLtjqBuN0HCMshv0h8GDRivGtsUsXDZEiWjOvfX20VMj4s/JnIEjwOXxs5SOeHsvucIL8I+UC69uoUNN/UwQt79TzhgATo9vjkBCFLREREPJfkSJYIwoeQJYJQSMpkiSACSCJkicA7PEyGZIkgfPyZyBJB9CFkiYiIiOcS8/V/WsV8+/ArgSAYPv3s5+P/irnznUAQDH+/e/f6Fz/e/U4gCIbPfvzovv3HsUDgxL9v3/l//3X7AyUQBMPnP3z4yxc/3VUCQTB8/tO9Fh//FFNEIAiG7767nM7+u1dEREQkonL7u//Vt41bgcDg3pd//9Q+CRQIDP4dE/P3uz/8pASCuBDz7X//P3tdEQgM/nP71kN7XREIDD775a54twrixMe/fNTM/rtGREREJKIiZIkgFIQsEYTCf2Nu/c02XgQCJ2LXkf/HXlcEAoP/3br1g72uCAQGn/189569pggEBkKWiIiIeC5ClghCQcgSQSgIWSKIDzFClghCQMgSQSgIWSIIBSFLREREPBchSwShIGSJIBSELBHEhxghSwQhIGSJIBSELBGEgpAlIiIinouQJYJQELJEEApClgjiQ4yQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiCAUhSwTxIUbIEkEICFkiCAUhSwShIGSJiIiI5yJkiSAUhCwRhIKQJYL4ECNkiSAEhCwRhIKQJYJQELJERETEcxGy5NHx/rkPVL4ChdS46QtcfZHArpPXVYFCRdSBi3ddfdGCkCWRw0/nzqlbO3e62iOFOSNHqmeqVnW1ewkhSxIXO45ucbXFh9XvLVf5C+ZT93+64+rzAjFCljwyyleuplo+39nVHimwd7GHHb5yz9UXLQhZEhn878YNdW7TJld7pMD8zz/7rGpYvbqrz0sIWZJ4OHnjiLry8XlXe3x4ZUhfvcd88suHrr5IQ8gSERERz0XIkkcHZEbsV6iGjJni6ntcHL/xqSpYpJief++Z267+aEHIksjg/oEDKk/OnGrBm2+6+iKBPYsWqTSpU6sShQq5+ryEkCWJBwzSnLlyutpD4cC5PSpb9mx6Xbn/4weufi8QI2TJIyNXnnyqep0GrvZIYM/pmypN2nT6d2H/+Tuu/mhByJLIoG7lyqptw4au9kgBMp7fldJFi7r6vISQJYmDVVuXqbRp06j9Z3e7+kJh3c5V6q9//av+XRGyREREJEmIkCV/PKzeflgbybE/HiFLkgg48eNnGWmy5D/Xr6u3hw5VqVKm1PMLWZJ8UP/ZegkiS+avmq0yZsroX1eELEm+uP7FP9ST5Sqqv6ZIIWRJEkHWzJk9I0tu7tih0qVJo1LE/hEsZEnywLhpo31rQwLIkusPLsbarrlUihRCloiIiCQhSWyy5OyH36olG3eraQtWqd2nbrj6wdVPf1Ur3zugpsxboTbsPalufPnPgP6bX/5LHbp8T13+5Betu3jDLjV31VaXa/GpmC+13q2v/+O6x7m738XZFxd4DsZcuPeDv
r79zX/19cX7P+nPm/ef1c+86f0zYc/bvd9g3x+9pcuqNh1fTJZkyflNm9S6adPUrgUL1K8XL7r6wQd79qgN06errbNmqQeHD7v6vzl5Un125Ij+fG//frUxVvfwihXq39eu+XUubN6sVk6e7HJd/tfVq3oM9wY7589X782ZE/Q+tH1x7Jir/avjx9Unhw7pzz+fP6+2z5unf5YTBg7Uc//3+nW/LoQHz8CzHFq+XP3t0iXXfMHwzytXVJliPu+jLi1bqvKlSiVpsmTPie1q7oqZasW7S9QH39xw9YPjVw+qeStnqcXr56uzMSdc/Vc/Oa/O3zmlPx+5tF/NWTZDbT2wMSAk5aPvb6vTt45pw88e//HPd+LsCweMW7Zpkb7v+2d2ufpjvrqm5//wu5sB7eaZbn95VV/zbjXqVlfZc2QL63kaNW+of0+q16mmSRY+JweyhDV63e7jeh1eu/OouvXVv106YOvhC1pnUezeYdZzJ45df6DOfPC1/rxp32k1fdFate3IpQAd1n2zD9nj2Zfi6guFo9c+USdvf+G/Pnnrc72P8Zkw0BlL1uv9M6Hz9hk4TGXJll117N432ZEld99/X22dPVttnjHDv0fYYP9gzWcfurZtm6v/l9g1nXWctfvHs2f1+r5t7lz1/enTfp2PYvvZo/D6Y612jmff+C5Wl9CWk2vXar1g9/n21Cl9H7udPYV25mVP43PmjBnVs7Vq+fcupz77JXsgz8jeZM8XCtyjQuzewtyVy5ZNsmTJqZtH1ZINC9TCNXP9e4SNaw8u+NfvYCTCB19f1+sxa+vl++fUonXztIfGjc8uB+ixfl/4MPg9Lt49Hef948Pd72+pzXvXqRmLp6l339+gr5397HM8n/084Mzt4/qZ+cw++eobA/TasGbb8jif1Uaz1k1U0RJFVLM2TYUsERERSTqSWGQJ5AEGW+o0af0nnaB+k5bqyoPfDb9ZyzeprNlzBOjkyV9Qe18YHYxQ2pkvd978ek6QImVK1e2lQX6iYvCoSVpvxdb9ruepXKOOzg8CyWH3xQU7DIfn5vrlIaNU2QqVA565TPlKmhiy57DRtHV7NXTMVH3yZ543uZAlF999Vz352x//BtkyZ1Z7Y41No4Px2KKe7489A1w++3XoEGCQdmreXJMHb/Tpo/7yl7/8/nOInf/jgwdVg+rVA+aAbDBjb+zYodtG9+2rcmXPrk/UCHHBe2NE794Bz1wwTx5Vr0oV17vwjDw7nxeOHRtwL/DDmTP+d4bgcPZxT+c7xwUM5oqlS2tjnOvq5csnSbJk3+mdquSTJQK+o8xZM6uVW5b6dW58eslPCBjwe9GpR8cAg/G5Tm21Mde9b1ffPFky6X8LFy2k9p7aoXUgYvDAqFqzsutZNuxeo/VnL3vH1RcfXhs5SKWJ/V1yPmPFyhW0kW50pi+Yots371sfMHbXsfd0+4R3xurrnE8Erom9B/R03c+JarWqqOkLp2rjtUPX5/SYpE6WzF6xWeXImSvgeypcrETAegpBX6ps+QCdtOnSqxET3wmYi3HPtnxOVa3lW3syZc6i/+X6dMxXWmfLwXO6rdeAoa5nGfTGBL0OQXDYfaFgh+FUqlZLPdO4uWrfrXfAM7NHbtx7yjU+GNbsOKL/b/D9DBoxXo9PDmQJ5ETr+vUDvreUKVKosa+84teBvBj50ksqdapUAXo1K1b0k98AYpv2JePHq0wZMvj1+Hxq3Tq9dzj3nVJFiuj7m/H5cuVSnVu0UDUqVNB6WTL51iEICSfh0vs53/9Vm2wxITEcLED+OJ8VQAShx1z2fsm7jR8wwPX9xIVhPXuq7FmyaGKpShIkSyAOmrRsHPAd4RkxaHh/vw7r5qAR/VXqWDvAqVepakV15vYxvx4kCu0jJw7TISyZMmfUP1/2E9ZfowdhzV5gkxaQLRkyZlDtXmjjes74ACmDp6Hz+Z7I/YTOUWV0IGloHzjsFdd49sJGzRroz9zfOQ/7r61v451F01TKVCn1XsU+yzghS0RERJKEJBZZAsEQe3tt9J2984269vnf1ZA339JteFego426WGOGJHc7T1zTbRiERYqX1Abt3jO3dJshS9iUXuz7qrr22d/0fD1eGaLbh0+YrvU4pWM+PDacz3Lw0kd67MDh41zPGQpxkSUpU6ZSzdu9oPu5Z5c+Poa+98DXXXOEQnIiSzDqcmbLpvI+8YT2rqDt+vbt+o//jOnT+0/DiMuGtJg/erT2wOBUb3ivXvp7gjAx80GWYAgXjx1/bNUq9WXseIxi9HBX7tC0qbqzZ482NA1xYjxMDFnC7wSGM4bq3y9fVv06dtTtkB/mPuGQJXF5luCRkiNrVlU0f351Ys0arctpJEla08YaZcFOGkMhKZIlNz+/ookBjMD1u1brtkMX9qnipYqp9BnS61M42mrWq6FSpEyhyYQ7395Ut764oga8/rL+zl/o3sE/H0YcP9c8+XKrPSd95AgkSe68uTQwVmmDZEGPUzjn8zAewzcuz5a48OZbvj9u2nZsrZ8ZomLR2rnamIaoMfcNlyxJqGeJE8mBLIG4gCx/6ukqOjcHbUs37439QyRTrOFfTpPieJA8kTuvypYjp1q4fqcm1VmvGzZrrb+fSbOW+ueDLOH3AbKCfFK0QTakTp1G1a7/rF8Pj0AIe9uTkPFPV63pes74EIwsYX8pXa6C3h9P3PxMTVuw2t9mj7fBO3PY0LpDN32dnMiSprVr6z2BUEjWc7xHnmvs+yMZjwt0Jg4apK97tm2r9wy8DPH6yBC7B0Hk/yN2HHqGLIEcWTFpktbdt3ix3pvYXyDlT69fr9sh2NFlbvMskCX8PrH3mL1t9ZQp+vnaNWrk1wuHLAnlWVK/WjW9l/CMzEGS8SHdu+ux80aNcn1HNo6sXKmJNUO+JEWyBIIAcoS1FY8+vEdaPtdcf0d4hqDzxoRh+vr5zu3UpY/O6LVzweo5mtiAfGfPQc+QJZAGeHfQhpdG7Wdq6u9x2+F3dRveK+hNmzc54Fkg4WnfuHuN6zlDAa/LVKlTqSfLlfaT/vxbumwpTfDQT1u4ZElCPUtIBMt3MWTUq/payBIREZEkJYlBlmCoYlByomcblY2at9HGKp9rPdNYJ6AzbscGECcYGm1f8JEehiwpV7FKgGcInzEg8+Yv5G/DsM0Qa1Q4vVcgSdjImMd5n/gQF1lCYlbnexEmxD05hbTnCIXkRJZAfvCuO+bNC2g/sGyZKleihPa0OLpqldaxvTsAJ4bEU5tQGcgSdA/Gjjc6GIsYo5AyTuOTcB90l02cqK8NWQIx47wHLteQERjNpi0csgQEy1nC6SNtZzZsCBgLcZQ+bdoAb5dwkBTJkilzfP8HcI92tuNqXKpMSe1d8t5B33fbd1Bv13jcgvm/bU7/jBHHKZxTD6OYdmPgbju0WV+/PuY1vw4hMhA0HV9s77pPKOD6nDVbVn06RxiPs4/TRu4zefYE33WYZAlIaM4Sg+RAlrTr1F2TJUeufhzQDnGO1x9r/fDxb+vvYc6KdwN08OrDy5A9yuwnkB0QI5ATTt2e/Yfq
Ocwazfxcr9x20K8DwU/bxJlLAsaGg2BkCXPZ5AYemeyJPLs9hxMtnuuk90MTapRcyBKIcd5zQOfOAe0Q2RAbkNjsCZAfhJrgYeLUM/sTJAnXhizB68KpZ4h355oOwcIaBEFv2iBLIGCc3iYAQp6fo9nHwiFLTJuds4RDAnSCeZHUid3bcufIERAOagPvR/a3ro59KKmRJXj18R117d05oB3ymvUab8B7P8RoIqHMU0+61m+zP709/y19bciSXv17BOhRUQZPkjYdWuprvB3ZE6rXrhqgV6d+LZWvQMKryDRv21QfFtjkPtcQN01bPauvwyVLQLg5S9jfKlV7WntJmnBWIUtERESSlCQGWUIp3Nhbq76D33D1OYH3CISJ3Q6KlyqjDVo+G7IEzxRb76VXR+g+k79kxtIN+vqdxev8OpAb1Wo/4xobH+IiS4KVeixUtLgqX6mqqz0UkhNZ0jHWkMQ9mJM8u88Ag5bvw2kgGhBbTh+ngFwbsoTYcqcenhyctjnbcJtGd8GYMfrakCUYpPZ9MLbp41SS68chS3DtzpAunT69swEhw9z2vKGQFMmSNh1baSPQjr12YsT41/V3u/2Iu4wup3/0mbAZjLh06dO5jF48RfQfNN2e97dxYogHi7meucT3x/WW/Rtd9wkFjE3G4cZt95GDhD+OWrf3GdFClkQGBQoX1eS53e5Eg6atVKpUqYMSDCZ3lCERIEuC7RGrth3y/VxmLNbXEPt4eUDWGB3ygrCXXbz/o2t8fAhGluAJY+t16unzogqWb8WAPY/fcTxSTFtyIUsgwvV7Ll3q6jM4u3Gj1pn86quuPtZ7+nq0bauvDVlCziyn3gvNmun/z87cWIB13klkQJa0fOYZ133Ij8W87GdcPw5ZYjwp2XPs/eXF1j7vqdu7drmewYB3KZQ3r/beNG1JjSwxa7ozVMUGHojoDB092NVH+Cd9JmzGkCXB9gjCOvFoNNc9+nULIPLJU6LDf4LsE/GBfQCywm4HEBkQM3z2gizhe2FPJV+YaROyREREJElJYpAlJMiLvbUa9dYsV58BCevQMe7CNmrUbahdqvlsyJKp81e69MxJ37sHzuprQnSyZM2m6jZqpq/X7/FtHsHGxoe4yJKOL77k0i1SvJR2CbfbQyE5kSWEnpCrw253YmCXLvr7CJZoldAd+ub+5loMWYIhYutBluCO7WyLiyzZYhnCgFM6+m7t3KmvH4csKVqggG6LC7hP2/OGQlIkSzhpM4ZeXOgzsKf+vuxTNUCSO/rGTh2lrzHiChUp6NIDeI04jUXjem1cmGvVq6FDZhJqAOJSzTyTZo5z9QHcl3lPPgtZEhngyVfvtzU+LlSsUiP2+8vjagfkjeI7Yq/iGrIkGAmOlyN6r46c6G+DhGFvIqkr4aDsN8HGhoNgZIk5JHDChHqe/+h7Vx8gdIg8K3hWku/LAFKHcSSsZS+0x0UD0SBL3hrsI7+ubN3q6jMwHobLf/MwtIFXYqvfCA5DluD56NRh38HD0R6LF4lNlrzUvr1Lj9Ad5p01YoS+fhyyxISNhgLeJ/YzAJLB0g/h4iRZyL0CgcLnUERLJOElWTJqsu8wzazxwbB2+wqtM3XuJFcfwGOEtZjPhiw5cf2wSw/vDkJizLUhYYaPG6qv2W8g2oKNjQ8cKDzbopGrHXBf5mXfijRZwjvgucKegrenQd0GtfVYPptwV68gZImIiIjnkhhkyXtHLuqF9LXRk119BlQt4I/dxi3auvoAxEOOJ3Lrz4YsGffOQpfegGG+5JrOxHqcwuGiTXUDTv0wrBNaTQDESZbEzmnrClkSGhAY5Cax250gWSvfRzAjjZhz+oyhG5fRmhCyZM2UKa7xr/f0/WFuqihAluDSbOs1rF49XrIE9+/C+fLpGPNguH/ggGveUEiKZAnGW7p0aUMSFJzE8d0ePL/X1UeYDn3G0IUsobShrYerNad65BQxbcRp0wYZw6kf69HQN92ni/EBjxeegaR/dh+EBYZu4+YN9bUhSzbtWRugx0kl7UKWhAe8L4J5gjjhI9wzutoBSbr5jnYcv6qvIUuC7UVUZ0Pvzalz/W3z12zTbSQnpyobn4MlFQ8HkSJLqCZHfyhA6tjjooFokCUzhw/X78hab/cZUC0NnWAeheTHog8PSK4NWWLyaxnEte8EI0uc4S0GeL7o35dJk/S1IUvIseLUmzbEl48tFFkyuFs3rXN89WrX3mJgz2swqKsvAXYo4J1ij/MCXpIlrKe8SzBPEAOz9o6ZMtLVh8cjRATVX7g2ZEmwvYi8JeTfcraRY8QkTyWsNFhS8XAA4Q6Zb7cDiHj2UD4bsqT/0H4uPRLSJpQseWu2z2MrFMpXeso1LpIQskRERMRzSQyyBFIBV+XGLdu5+l7o0U9lz/GEuvTxzzoRH8aiXe4Rd2ZymRgj0pAlz3Xp6ZqPkzRKJJI3xLSZqgWTZi/TCf6cLtMJgZAlkYMhIShx6GzHi4SkeRiHhNigY3KLBBtvYsXjMloTQpZgMNrjaz39tDZ0TUx7sYIFdVlFW69IvnwBZAlGrb6HgyzB8OYP8E+DeMqQz8SpGw6SIlmCUcf3RlJXZztJUjnR4lSORKnoED8e13jymnBtErzapRlNjhIMRGc7hESBQvnV5Fnjf3OZPu66R3wg2ayTEHGCssXcl2S0XJsEf8s3LQrQo9IA7UKWhAeIEnvdB/NWv6fSZ8igk732ePk1/T0YQsQeT44SvEO4hiwh14ddLc14LlIe3rRRrpgqPCT5xjOShKp2bq5wESmy5Nzd7wI8SpKbZ4khIYznoROElbAnEGrD/3FDiDiBBwnjTf6PSJAlwcJZxvX3Eb/sQVy//MIL+tou94tXiv55hyBLODhAx+RZcQIyBq8RcrbYfeDDvXtdoTtJ0bMEUprvyHgeOgF5gYcEoTaQ5oYQcQLPCcYPHjlQXxuyhPXaqUfp9yzZsqgGTZ4JaOe+6JNEVf8/XDDFdY9wUKVGZV0hzg5X5Zr2p54up68pDcx9Xnypa4AeBwO0J5QsYZzTo0Q8S0RERJKkJAZZAnBLhjBxxk9DCmTMlFnVadhUXxOmE/uI+pTP6ECcGANv5rKNus2QJZQLhggxuhjG/GFE+WD7/hAxxLUzzvkMCYGQJZEDlV8wMts0aODPW0LyOU7W+BlSYpcM/09kz66NNWcZxwubN2tjFNLCkBhxGa0JIUuY01mRxuRFGfXSS/425uI+lx3u3cZF2kmWMA9tb/br528j+SxtxK47T/h4DtrtZITxISmSJZAkEA0Nm9bX3h+0kW+ka69O+vdi59GtuhIBpEHe/HkCyjhSDYATN4xek6PExFK3er6FPxkdyfwqV6+kT99MdR0DU7WA0J24Tu7CAYn9+EPMSYJQ6adC5fLajfvo5QO6zXihkLDPeNNQfYHcKbQ7yRLKXfLMJJ617xcKyYEsIaySd+w/bIy/De/BClWqq8xZsqqL93/SITRUR4MYcXoWUl2GsU7yHbKENlNZDVBFDe/GYiWfdJEoEDHsZZT0de5fCUWkyJK4kFxylpBDBC8+9o7
Pjx71t5twkymvvaavIRsIt6GyjdEhCSt7C0m3qVZGWyTIEsYvclRWi9m9W+8Z5AUxbTN+84h5Z9gwf5vZ7/TP20GWkLC12lNP+a/JNQKBgvej00uRUCS8OCFr7ES28SGp5SxhD2BtZ+9wEugm4bdJ8K0TqKb4a0BicMr+lqtYVntkmNAZQ5YULlZYExNGt9+rPq9YiH3n/Un8ShUbnoG8H+Swsp8xHBgyvecr3f37Bv/2fPlF3W48K9kHCbfJXzCfJoFoYx8gNxh6TrKEpOO0EYZk3y8+SM4SERGRJCWJRZacjvlKl1kkHAZypP6zLbS3CBUIMELR4TSuSavn9aKLLiV/MRS57jNwmH8uQ5bkL1hYJ+wjZrxmvUb6jxOSqmIY2/c3J4LMZxu64ULIksiCOG3+AMao7dyihT7F4v05ATM6nGhRsYCEeZRYbFanjvY8If8HFQ+MXlxGa0LIEp6DvCFtY+9jqhwQcoNLthnLiSXGdZpYParnkJiVe1Cdx0mWkGiW52aOLJkyqbv79ul2yBPaMJ55Z4xRriuVKaPLPNrPHwpJkSwBEAT8X6ZKAEadIQ4GDf89ER6u0pThhTygAg4neBiheIUYIgIYIw5jsUjxwjrshrAcvFTmrpjpurepWsAYkyT2UUA5Slyu+f2m5C9kDWV/MbSd98W4rFarir4fz1e9TjWtYwgOJ1lCpQbaIITad2nnumdcSA5kCWs6Xh28Z5mnnlZNW7fXxAP7AyS60Rv79nxNmJC7hEoxptpMlZp1A/YNyBK8RUjUSglg9iU8VCBDth255Lq/yWXCz/txiAghSyIH8oGw9lJit32TJnot591J+G1KAlPOnfBI1hvWfCrYZM+SRe83zhxWkSBL8uTMqfcNcl5RwjhdmjS6zemx8fWJE7pN/07G7g01KlTQ+51J0OokSygbTBvvaLwSIX14doge9jH2PhKpszdButjPGR+SGlkCINzxvqAkPKQIay7fI5VqKCWMDsQHpDv/nylTT2nhbNmz6f2GJOJmLkOWsL+wb5C4myo6tHXu2dF1b0BOEfpNkthHBcnJmYfE5OwH/Mu1XemHhKy0s//wrrnz5lLlKpTR+5OTLDHeluyNceX5igtCloiIiCQpSSyyBJBslTwjkCAYs+QwOfvhty69Bet26PCcZm06aC8Rk6zVwJAlr7z+plqycbcuKdyqfRd9uhis0gHYe+aWHkPZYLsvXED4kBNl3e7j+hr3a64Xbdjl0iWmPaGlI9fuOqbnO3vnG1dftBBNsgRgwBH+gvGICzJx5LYOJ4O4Q+MujdG4cOxY9cuFCwE6VCnApdkeSzjP2qlTA9rwUoGQMYanIUumxuqSC6Vbq1Ya699+21XlAFzaskV7gWBYj3zpJR1Wg9cI4516eJ8QR07iPXNCCajC8GpsO+RPz7ZtdZiRndAvHCydMCGom7mXiAZZAvASoRwjrtAYfxt2r3Hp4IFBThGICIw1TsZsrwtjxF395LwmWzBmKTnsJFRstGjXTBvSVMyx+xICXLFxs8aoxdjm/sGS0mKgQ4rwHhAbnHLiVQM5gmFv9GibOGOs1iF23J4nLizbtEjPZVcE8goxiUCWGCxcv1OT18+2fE71eGWI2n3qhkuHMJxeA4bq/eX5rr10KWE79BOypHzlanrf6NpnoGrWtqMms9kD7PkM8uQroMkNuz0hgNR3Jh+fPGd50MTo7HvsFeypdl8o4FWp95gg+260EC2yBECGsC+wVrN3EI5ir+kQJ4vHjdOVb9Bjb3B6MgLWctqd6ziIa9+hwg5eLOYasgQy5vr27eqVTp105Rn2C7uUMCAEh3vxLOwd7BfclzaTOwtArLAvdm/TJiDxLKGsPBP3ABMHDXKF9YQLvpd5Ud5jvCZLwKV7ZzWJwLrMHsE6bZPJEOeUCGa9Re/VNwYEeDICQ5bggcKazBrepdcLat3OVa57Gkyb5yPeSARu9yUU3Kdbny76+Qi1MeGnNvBw7NSjo9Z7862R2juTMvb2gcHyzYv1fOy9xhMzHCzduFDvMXa7FxCyRERExHNJTLIkUnCSJXZfXBg5eaY+UTQlhQXBEW2y5I8AJ1li9wkCES2yJFJI6IkXBAcncBiWdp8gPMQkIlkSKRiyxG6PC5Ay/J6RF8vuEwQimmTJHwWGLLHbBW5EgyyJFAxZsnbHSldfXMCbBQ/IcPckQSCELBEREfFckhNZwmnhqm2H9MkhSWQbNW/j0hEEQsgSd7/gdyRVsuT41YM6OR2niLjkh0pyJwiNmGRElpAza/mW97VHCXuMSRAriBtClghCISmSJXg2sr+Mnz5G50LBu8PWEYQHIUtEREQ8l+RElgDizNElMR/jnH2QKfSFA+Lg7bmTIoQscfcLfkdSJUtGTvIlVgTByiwSQmOvCXGB8sX2+OSEmGREllBemJ85ubgoG2z3k3/E/v2IC5S2t8cnRQhZIgiFpEiW7Du90///vFLVitqD0dn/7vsbXOtBXCCJqz1/coKQJSIiIp5LUiBLKA8J8RFOcjtO/d5ZvC7OGG3mCQcnbn7mGpsUkRzJEqrx3Nu/X1cTsPsEgfizkSWc6AXLE2KD/Cezlk6PM+abSgjMEw6IB7fHJyfEJAGy5Nj1B2Gt+ST9njJvhdp3NsbVB07e/sK1l8SFRy03/GdDciRLyCPy5SPmDUlu+DORJVRWC2fNJ/8JlXHIH2UTJabf3kfiwvUHF13jkxOELBEREfFckgJZIvAOyZEsEYSPPxtZIog+YpIAWSLwDsmRLBGEjz8TWSKIPoQsERER8VyELBGEgpAlglAQskQQH2KELBGEgJAlglAQskQQCkKWiIiIeC43v/hnlpjv/tdAIAiGO1/9s8udb/77k0AQDB9+88/P/nfz5k8CQVy4//Nn39//+YufBIJg+PDr/3xsrysCgcG/YmJu22uKQGDw+Y93Tzz49aOvBIKg+OXu8/bfvSIiIiIRlVtf/6+MfRIoEBjc+frff7NPAgUCJx78evf/2qeBAoFB7Dryo72uCAQG/711+769pggEBp/9dPeavaYIBAaf/HK3u/13jYiIiEhERcgSQSgIWSKIDw+ELBGEQIyQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiiA8PhCwRhECMkCWCEBCyRBAKQpYIQkHIEhEREc9FyBJBKAhZIogPD4QsEYRAjJAlghAQskQQCkKWCEJByBIRERHPRcgSQSgIWSKIDw+ELBGEQIyQJYIQELJEEApClghCQcgSERERz0XIEkEoCFkiiA8PhCwRhECMkCWCEBCyRBAKQpYIQkHIEhEREc9FyJJHx6WPf1ZtX3hRzViy3tX3OFiycbfq0meAnrtrn4Fq5baDLp1oQciSyOLbU6dcbY+DXy9eVIvHjVO92rVT3du0UWNfeUXd3bfPpeclHghZkmi49uCCqy0Y7v/4gZq/arbq2ruzat+lneo7qLfac2K7S88LxAhZ8sh45fU31ctDRrnaHxWXP/lFTZm3Iii2HDzn0o8GhCyJHCK9v8Ts3q1WTp4cFF8eP+7S9wJCliQe2Ddufn
[… base64-encoded PNG image data elided …])
+
+![pipelining_example](../../_static/pallas/pipelining_example.svg)
+
@@ -564,7 +565,8 @@ In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\a
-![compute_bound](data:image/svg+xml;base64,[… inline SVG image data elided …])
1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m100.64829 416.6509l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m168.13771 433.74057l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 
2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m300.39633 416.6509l121.82675 0l0 27.433075l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m325.29953 433.74057l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 
0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.625702 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 
-0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m222.47507 444.08398l199.74802 0l0 27.433075l-199.74802 0z" fill-rule="evenodd"/><path fill="#000000" d="m289.96448 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m422.22308 444.08398l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m447.1263 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 
-1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 
0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m422.22308 416.6509l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m489.7125 433.74057l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 
0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m621.9711 416.6509l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m646.8743 433.74057l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 
5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m544.04987 444.08398l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m611.5393 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 
0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m743.7979 444.08398l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m768.7011 461.17365l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 
-2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 
-0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m535.6037 483.20996l96.724365 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m550.7637 483.20996l66.40442 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m550.7637 489.78998l-13.159973 0l0 -13.160004l13.159973 0l0 13.160004z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m617.1681 476.62997l13.159973 0l0 13.160004l-13.159973 0l0 -13.160004z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m530.0761 478.29135l107.77954 0l0 42.015717l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m570.66046 505.21133l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948914 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235107 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m219.14174 19.102362l111.18111 0l0 27.433073l-111.18111 0z" fill-rule="evenodd"/><path fill="#000000" d="m242.34767 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 
0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m330.32285 19.102362l121.82675 0l0 27.433073l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m355.22604 36.19202l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 
2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.625702 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 
-1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m452.1496 46.535435l121.82678 0l0 27.433071l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m477.0528 63.62509l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.1718445 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.0155945 0.625 -2.1405945 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.7812195 -0.9375 0.7812195 -2.875q0 -1.84375 -0.7655945 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 
[… inline SVG path data elided: tail of a removed diagram that was embedded in the doc as a data:image/svg+xml URI …]</svg>)
+![pipelining_compute](../../_static/pallas/pipelining_compute_bound.svg)
+
@@ -572,7 +574,9 @@ In a **memory-bound** regime it is useful to identify if the problem is the late
-![bandwidth_bound](data:image/svg+xml;base64,<svg version="1.1" viewBox="0.0 0.0 1100.3333333333333 140.48556430446195" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg">
[… inline SVG path data elided: the removed bandwidth_bound diagram, likewise embedded as a data URI; the SVG markup continues beyond this excerpt …]
-0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m261.0971 29.82677l121.82678 0l0 27.433071l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m286.0003 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 
5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m261.0971 57.259842l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m328.58652 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 
1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144806 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m460.84515 57.259842l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m485.74835 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 
-1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.7812805 0.796875 0.7812805 2.453125l0 6.640625l-1.6406555 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043396 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m460.84515 29.82677l199.74805 0l0 27.433071l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m528.33453 46.91643l1.609375 
0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path 
fill="#f4cccc" d="m660.5932 29.82677l121.82672 0l0 27.433071l-121.82672 0z" fill-rule="evenodd"/><path fill="#000000" d="m685.4964 46.91643l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 
0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m660.5932 57.259842l199.74799 0l0 27.433075l-199.74799 0z" fill-rule="evenodd"/><path fill="#000000" d="m728.0826 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 
9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891296 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m860.3412 57.259842l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m885.2444 74.3495l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 
-0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m576.0945 99.82546l96.72443 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m591.25446 99.82546l66.40448 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" 
stroke-linecap="butt" d="m591.2545 106.40546l-13.160034 0l0 -13.159996l13.160034 0l0 13.159996z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m657.6589 93.24546l13.160034 0l0 13.159996l-13.160034 0l0 -13.159996z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m570.56696 94.90682l107.77948 0l0 42.015755l-107.77948 0z" fill-rule="evenodd"/><path fill="#000000" d="m611.1513 121.82682l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948914 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235107 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.81665 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m208.70866 240.2441l111.18111 0l0 27.43306l-111.18111 0z" fill-rule="evenodd"/><path fill="#000000" d="m231.91461 257.33374l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781097 0.515625 -2.7499847q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8124847q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6874847 1.484375 -3.9687347q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.8593597 -0.5625 2.9374847q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.8437347q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.7656097zm9.297592 8.546875l0 -13.374985l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.4843597 -0.53125 2.6718597q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 
0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.8749847q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.8437347zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.687485l1.765625 0l2.015625 5.5937347q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.6406097l1.640625 0l-3.6875 9.82811q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.17186l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468735l0 -9.67186l1.640625 0l0 9.67186l-1.640625 0zm4.1448364 0l0 -9.67186l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.9531097l-1.640625 0l0 -5.8906097q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.2812347l-1.640625 0zm15.105896 -6.5937347q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.3593597 -0.484375 3.8124847q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4374847zm1.671875 0q0 3.2968597 0.765625 4.3906097q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.3749847q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m319.88977 240.2441l121.82675 0l0 27.43306l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m344.79297 257.33374l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781097 0.515625 -2.7499847q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8124847q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6874847 1.484375 -3.9687347q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.8593597 -0.5625 2.9374847q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.8437347q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.7656097zm9.297577 4.84375l0 -9.67186l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 
2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.6406097l-1.640625 0l0 -6.0937347q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.6249847l-1.640625 0l0 -6.2812347q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.0156097l-1.640625 0zm15.540802 3.703125l0 -13.374985l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.4843597 -0.53125 2.6718597q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.8749847q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.8437347zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.9843597l1.640625 0l0 5.3593597q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1874847l1.640625 0l0 9.67186l-1.46875 0zm7.625702 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.5781097l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.6718597q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.4531097 1.25 -3.7968597q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.7031097q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.6562347l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.3593597 -0.484375 3.8124847q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4374847zm1.671875 0q0 3.2968597 0.765625 4.3906097q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.3749847q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m441.71652 267.67715l121.82681 0l0 27.433075l-121.82681 0z" fill-rule="evenodd"/><path fill="#000000" d="m466.61975 284.7668l1.609375 0.21875q-0.265625 
1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.1250305 1.640625 -3.0625305 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.26565552 -0.65625 0.26565552 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 
1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm20.355896 0l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m683.9882 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m716.29297 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 
0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6256714 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230225 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm19.121521 5.765625l0 -3.203125l-5.796875 0l0 -1.5l6.09375 -8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#d9ead3" d="m820.6306 -145.22572l136.62994 0l0 27.433075l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m850.8539 -128.13606l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 
0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813171 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.281982 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm11.585327 1.46875l0 -3.203125l-5.796875 0l0 -1.5l6.09375 
-8.65625l1.34375 0l0 8.65625l1.796875 0l0 1.5l-1.796875 0l0 3.203125l-1.640625 0zm0 -4.703125l0 -6.015625l-4.1875 6.015625l4.1875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m683.9836 -117.91601l136.62994 0l0 27.433067l-136.62994 0z" fill-rule="evenodd"/><path fill="#000000" d="m719.91394 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1447754 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.5l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 
-0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m820.62604 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m852.9308 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 
-2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.265625l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m957.26843 -117.91601l136.62988 0l0 27.433067l-136.62988 0z" fill-rule="evenodd"/><path fill="#000000" d="m987.4917 -100.826355l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 
-0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859436 0l0 1.1875l-10.859436 0zm11.281921 -8.546875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm15.625732 4.84375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.335327 -2.03125l1.71875 -0.140625q0.1875 1.25 0.875 1.890625q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.890625q0.828125 -0.890625 0.828125 -2.359375q0 -1.40625 -0.796875 -2.21875q-0.78125 -0.8125 -2.0625 -0.8125q-0.78125 0 -1.421875 0.359375q-0.640625 0.359375 -1.0 0.9375l-1.546875 -0.203125l1.296875 -6.859375l6.640625 0l0 1.5625l-5.328125 0l-0.71875 3.59375q1.203125 -0.84375 2.515625 -0.84375q1.75 0 2.953125 1.21875q1.203125 1.203125 1.203125 3.109375q0 1.8125 -1.046875 3.140625q-1.296875 1.625 -3.515625 1.625q-1.8125 0 -2.96875 -1.015625q-1.15625 -1.03125 -1.3125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m40.562992 -50.43176l1004.3779 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m40.562992 -50.43176l998.3779 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m1038.9409 -48.780025l4.538086 -1.6517334l-4.538086 -1.6517334z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m404.46588 -50.430447l314.96063 0l0 42.015747l-314.96063 0z" fill-rule="evenodd"/><path fill="#000000" d="m546.40765 -23.510448l0 -11.78125l-4.40625 0l0 -1.578125l10.578125 0l0 1.578125l-4.40625 0l0 11.78125l-1.765625 0zm7.093872 -11.46875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 
-1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm22.165771 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#efefef" d="m3.2650874 -84.173225l273.25986 0l0 27.433067l-273.25986 0z" fill-rule="evenodd"/><path fill="#ff0000" d="m109.64443 -76.89607l5.34375 0q1.59375 0 2.359375 0.140625q0.78125 0.125 1.390625 0.546875q0.625 0.421875 1.03125 1.125q0.40625 0.6875 0.40625 1.546875q0 0.9375 -0.5 1.734375q-0.5 0.78125 -1.375 1.171875q1.21875 0.359375 1.875 1.21875q0.65625 0.84375 0.65625 2.0q0 0.921875 -0.421875 1.78125q-0.421875 0.859375 -1.15625 1.375q-0.734375 0.5156288 -1.796875 0.6250038q-0.6875 0.078125 -3.265625 0.09375l-4.546875 0l0 -13.359379zm2.703125 2.234375l0 3.078125l1.765625 0q1.578125 0 1.953125 -0.046875q0.703125 -0.078125 1.09375 -0.46875q0.40625 -0.40625 0.40625 -1.046875q0 -0.625 -0.34375 -1.0q-0.34375 -0.390625 -1.015625 -0.484375q-0.40625 -0.03125 -2.3125 -0.03125l-1.546875 0zm0 5.296875l0 3.578125l2.5 0q1.453125 0 1.84375 -0.078125q0.609375 -0.109375 0.984375 -0.53125q0.375 -0.421875 0.375 -1.140625q0 -0.59375 -0.296875 -1.015625q-0.28125 -0.421875 -0.84375 -0.609375q-0.546875 -0.203125 -2.390625 -0.203125l-2.171875 0zm17.113564 5.828129l0 -1.4531288q-0.53125 0.78125 -1.390625 1.2343788q-0.8593674 0.4375 -1.8124924 0.4375q-0.96875 0 -1.75 -0.421875q-0.765625 -0.4375038 -1.125 -1.2031288q-0.34375 -0.78125 -0.34375 -2.140625l0 -6.125l2.5625 0l0 4.4375q0 2.046875 0.140625 2.515625q0.140625 0.453125 0.515625 0.71875q0.375 0.265625 0.953125 0.265625q0.65625 0 1.1718674 -0.359375q0.515625 -0.359375 0.703125 -0.890625q0.203125 -0.53125 0.203125 -2.609375l0 -4.078125l2.5468903 0l0 9.671879l-2.3750153 0zm4.927231 0l0 -13.359379l2.5625 0l0 4.8125q1.171875 -1.34375 2.796875 -1.34375q1.765625 0 2.921875 1.28125q1.15625 1.28125 1.15625 3.671875q0 2.484375 -1.1875 3.828125q-1.171875 1.3281288 -2.859375 1.3281288q-0.828125 0 -1.640625 -0.40625q-0.796875 -0.4218788 -1.375 -1.2343788l0 1.4218788l-2.375 0zm2.53125 -5.046879q0 1.5 0.484375 2.21875q0.65625 1.03125 1.765625 1.03125q0.84375 0 1.4375 -0.71875q0.59375 -0.734375 0.59375 -2.296875q0 -1.65625 -0.609375 -2.390625q-0.59375 -0.734375 -1.53125 -0.734375q-0.921875 0 -1.53125 0.71875q-0.609375 0.71875 -0.609375 2.171875zm8.864746 5.046879l0 -13.359379l2.5625 0l0 4.8125q1.171875 -1.34375 2.796875 -1.34375q1.765625 0 2.921875 1.28125q1.15625 1.28125 1.15625 3.671875q0 2.484375 -1.1875 3.828125q-1.171875 1.3281288 -2.859375 1.3281288q-0.828125 0 -1.640625 -0.40625q-0.796875 -0.4218788 -1.375 -1.2343788l0 1.4218788l-2.375 0zm2.53125 -5.046879q0 1.5 0.484375 2.21875q0.65625 1.03125 1.765625 1.03125q0.84375 0 1.4375 -0.71875q0.59375 -0.734375 0.59375 
-2.296875q0 -1.65625 -0.609375 -2.390625q-0.59375 -0.734375 -1.53125 -0.734375q-0.921875 0 -1.53125 0.71875q-0.609375 0.71875 -0.609375 2.171875zm8.974091 5.046879l0 -13.359379l2.5625 0l0 13.359379l-2.5625 0zm10.777069 -3.0781288l2.546875 0.421875q-0.484375 1.40625 -1.546875 2.140625q-1.0625 0.7343788 -2.65625 0.7343788q-2.515625 0 -3.734375 -1.6562538q-0.953125 -1.3125 -0.953125 -3.328125q0 -2.40625 1.25 -3.765625q1.265625 -1.359375 3.1875 -1.359375q2.15625 0 3.40625 1.421875q1.25 1.421875 1.1875 4.375l-6.40625 0q0.03125 1.140625 0.609375 1.78125q0.59375 0.625 1.484375 0.625q0.59375 0 1.0 -0.328125q0.421875 -0.328125 0.625 -1.0625zm0.15625 -2.59375q-0.03125 -1.109375 -0.578125 -1.6875q-0.546875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.390625 0.609375q-0.546875 0.609375 -0.53125 1.65625l3.828125 0z" fill-rule="nonzero"/><path fill="#cccccc" d="m276.5223 -84.173225l546.61414 0l0 27.433067l-546.61414 0z" fill-rule="evenodd"/><path fill="#000000" d="m495.029 -67.88045l2.625 -0.25q0.234375 1.3125 0.953125 1.9375q0.734375 0.609375 1.96875 0.609375q1.296875 0 1.953125 -0.546875q0.671875 -0.546875 0.671875 -1.28125q0 -0.484375 -0.28125 -0.8125q-0.28125 -0.328125 -0.96875 -0.578125q-0.484375 -0.15625 -2.171875 -0.578125q-2.15625 -0.546875 -3.03125 -1.328125q-1.234375 -1.09375 -1.234375 -2.6875q0 -1.015625 0.578125 -1.90625q0.578125 -0.890625 1.65625 -1.34375q1.09375 -0.46875 2.640625 -0.46875q2.515625 0 3.78125 1.109375q1.28125 1.09375 1.34375 2.9375l-2.703125 0.109375q-0.171875 -1.03125 -0.75 -1.46875q-0.5625 -0.453125 -1.703125 -0.453125q-1.171875 0 -1.84375 0.46875q-0.421875 0.3125 -0.421875 0.84375q0 0.46875 0.40625 0.796875q0.5 0.4375 2.46875 0.90625q1.96875 0.453125 2.90625 0.953125q0.953125 0.5 1.484375 1.359375q0.53125 0.859375 0.53125 2.125q0 1.15625 -0.640625 2.15625q-0.640625 1.0 -1.8125 1.4843788q-1.15625 0.484375 -2.890625 0.484375q-2.53125 0 -3.890625 -1.1718788q-1.359375 -1.171875 -1.625 -3.40625zm17.552917 -5.328125l0 2.03125l-1.7499695 0l0 3.90625q0 1.1875 0.046875 1.390625q0.046875 0.1875 0.21875 0.3125q0.1875 0.125 0.4375 0.125q0.359375 0 1.0312195 -0.25l0.21875 2.0000038q-0.8905945 0.375 -2.0155945 0.375q-0.703125 0 -1.265625 -0.234375q-0.546875 -0.234375 -0.8125 -0.5937538q-0.25 -0.375 -0.34375 -1.0q-0.09375 -0.453125 -0.09375 -1.8125l0 -4.21875l-1.171875 0l0 -2.03125l1.171875 0l0 -1.921875l2.578125 -1.5l0 3.421875l1.7499695 0zm7.3689575 6.59375l2.546875 0.421875q-0.484375 1.40625 -1.546875 2.140625q-1.0625 0.7343788 -2.65625 0.7343788q-2.515625 0 -3.734375 -1.6562538q-0.953125 -1.3125 -0.953125 -3.328125q0 -2.40625 1.25 -3.765625q1.265625 -1.359375 3.1875 -1.359375q2.15625 0 3.40625 1.421875q1.25 1.421875 1.1875 4.375l-6.40625 0q0.03125 1.140625 0.609375 1.78125q0.59375 0.625 1.484375 0.625q0.59375 0 1.0 -0.328125q0.421875 -0.328125 0.625 -1.0625zm0.15625 -2.59375q-0.03125 -1.109375 -0.578125 -1.6875q-0.546875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.390625 0.609375q-0.546875 0.609375 -0.53125 1.65625l3.828125 0zm6.5319824 -1.046875l-2.328125 -0.421875q0.40625 -1.40625 1.359375 -2.078125q0.953125 -0.671875 2.84375 -0.671875q1.703125 0 2.546875 0.40625q0.84375 0.40625 1.171875 1.03125q0.34368896 0.625 0.34368896 2.28125l-0.015625 3.0q0 1.265625 0.109375 1.875q0.125 0.609375 0.46875 1.2968788l-2.531189 0q-0.109375 -0.25 -0.25 -0.7500038q-0.0625 -0.234375 -0.09375 -0.3125q-0.65625 0.6406288 -1.40625 0.9687538q-0.734375 0.3125 -1.59375 0.3125q-1.484375 0 -2.34375 -0.8125038q-0.859375 -0.8125 -0.859375 -2.046875q0 -0.828125 0.390625 -1.46875q0.390625 -0.640625 1.09375 
-0.96875q0.703125 -0.34375 2.03125 -0.609375q1.796875 -0.328125 2.484375 -0.625l0 -0.25q0 -0.75 -0.359375 -1.0625q-0.359375 -0.3125 -1.375 -0.3125q-0.6875 0 -1.078125 0.28125q-0.375 0.265625 -0.609375 0.9375zm3.421875 2.078125q-0.484375 0.15625 -1.5625 0.390625q-1.0625 0.21875 -1.390625 0.4375q-0.5 0.359375 -0.5 0.90625q0 0.53125 0.40625 0.9375q0.40625 0.390625 1.015625 0.390625q0.703125 0 1.328125 -0.46875q0.46875 -0.34375 0.609375 -0.84375q0.09375 -0.328125 0.09375 -1.25l0 -0.5zm13.922546 4.640629l-2.390625 0l0 -1.4218788q-0.59375 0.828125 -1.40625 1.2343788q-0.796875 0.40625 -1.609375 0.40625q-1.671875 0 -2.859375 -1.3437538q-1.1875 -1.34375 -1.1875 -3.75q0 -2.453125 1.15625 -3.734375q1.15625 -1.28125 2.921875 -1.28125q1.625 0 2.8125 1.34375l0 -4.8125l2.5625 0l0 13.359379zm-6.84375 -5.046879q0 1.546875 0.4375 2.234375q0.609375 1.015625 1.71875 1.015625q0.890625 0 1.5 -0.75q0.625 -0.765625 0.625 -2.25q0 -1.671875 -0.609375 -2.40625q-0.59375 -0.734375 -1.53125 -0.734375q-0.90625 0 -1.53125 0.734375q-0.609375 0.71875 -0.609375 2.15625zm8.145996 -4.625l2.71875 0l2.328125 6.859375l2.25 -6.859375l2.65625 0l-3.421875 9.312504l-0.609375 1.6875q-0.328125 0.84375 -0.640625 1.28125q-0.3125 0.453125 -0.703125 0.71875q-0.390625 0.28125 -0.96875 0.4375q-0.578125 0.15625 -1.3125 0.15625q-0.734375 0 -1.453125 -0.15625l-0.21875 -2.0q0.59375 0.125 1.078125 0.125q0.890625 0 1.3125 -0.53125q0.4375 -0.515625 0.671875 -1.328125l-3.6875 -9.703129zm11.297607 6.109375l0 -2.5625l5.03125 0l0 2.5625l-5.03125 0zm5.6032715 0.796875l2.5625 -0.390625q0.171875 0.75 0.671875 1.140625q0.5 0.390625 1.40625 0.390625q0.984375 0 1.484375 -0.375q0.34375 -0.25 0.34375 -0.671875q0 -0.296875 -0.1875 -0.484375q-0.1875 -0.1875 -0.859375 -0.34375q-3.09375 -0.6875 -3.921875 -1.25q-1.140625 -0.78125 -1.140625 -2.171875q0 -1.265625 0.984375 -2.109375q1.0 -0.859375 3.078125 -0.859375q1.984375 0 2.953125 0.65625q0.96875 0.640625 1.328125 1.90625l-2.40625 0.4375q-0.15625 -0.5625 -0.59375 -0.859375q-0.421875 -0.296875 -1.234375 -0.296875q-1.0 0 -1.4375 0.28125q-0.296875 0.203125 -0.296875 0.515625q0 0.265625 0.25 0.46875q0.34375 0.25 2.390625 0.71875q2.046875 0.453125 2.859375 1.140625q0.796875 0.671875 0.796875 1.890625q0 1.34375 -1.109375 2.296875q-1.109375 0.9531288 -3.28125 0.9531288q-1.984375 0 -3.140625 -0.7968788q-1.140625 -0.8125 -1.5 -2.1875zm15.719482 -6.90625l0 2.03125l-1.75 0l0 3.90625q0 1.1875 0.046875 1.390625q0.046875 0.1875 0.21875 0.3125q0.1875 0.125 0.4375 0.125q0.359375 0 1.03125 -0.25l0.21875 2.0000038q-0.890625 0.375 -2.015625 0.375q-0.703125 0 -1.265625 -0.234375q-0.546875 -0.234375 -0.8125 -0.5937538q-0.25 -0.375 -0.34375 -1.0q-0.09375 -0.453125 -0.09375 -1.8125l0 -4.21875l-1.171875 0l0 -2.03125l1.171875 0l0 -1.921875l2.578125 -1.5l0 3.421875l1.75 0zm3.6813965 2.953125l-2.328125 -0.421875q0.40625 -1.40625 1.359375 -2.078125q0.953125 -0.671875 2.84375 -0.671875q1.703125 0 2.546875 0.40625q0.84375 0.40625 1.171875 1.03125q0.34375 0.625 0.34375 2.28125l-0.015625 3.0q0 1.265625 0.109375 1.875q0.125 0.609375 0.46875 1.2968788l-2.53125 0q-0.109375 -0.25 -0.25 -0.7500038q-0.0625 -0.234375 -0.09375 -0.3125q-0.65625 0.6406288 -1.40625 0.9687538q-0.734375 0.3125 -1.59375 0.3125q-1.484375 0 -2.34375 -0.8125038q-0.859375 -0.8125 -0.859375 -2.046875q0 -0.828125 0.390625 -1.46875q0.390625 -0.640625 1.09375 -0.96875q0.703125 -0.34375 2.03125 -0.609375q1.796875 -0.328125 2.484375 -0.625l0 -0.25q0 -0.75 -0.359375 -1.0625q-0.359375 -0.3125 -1.375 -0.3125q-0.6875 0 -1.078125 0.28125q-0.375 0.265625 -0.609375 0.9375zm3.421875 
2.078125q-0.484375 0.15625 -1.5625 0.390625q-1.0625 0.21875 -1.390625 0.4375q-0.5 0.359375 -0.5 0.90625q0 0.53125 0.40625 0.9375q0.40625 0.390625 1.015625 0.390625q0.703125 0 1.328125 -0.46875q0.46875 -0.34375 0.609375 -0.84375q0.09375 -0.328125 0.09375 -1.25l0 -0.5zm9.485107 -5.03125l0 2.03125l-1.75 0l0 3.90625q0 1.1875 0.046875 1.390625q0.046875 0.1875 0.21875 0.3125q0.1875 0.125 0.4375 0.125q0.359375 0 1.03125 -0.25l0.21875 2.0000038q-0.890625 0.375 -2.015625 0.375q-0.703125 0 -1.265625 -0.234375q-0.546875 -0.234375 -0.8125 -0.5937538q-0.25 -0.375 -0.34375 -1.0q-0.09375 -0.453125 -0.09375 -1.8125l0 -4.21875l-1.171875 0l0 -2.03125l1.171875 0l0 -1.921875l2.578125 -1.5l0 3.421875l1.75 0zm7.3689575 6.59375l2.546875 0.421875q-0.484375 1.40625 -1.546875 2.140625q-1.0625 0.7343788 -2.65625 0.7343788q-2.515625 0 -3.734375 -1.6562538q-0.953125 -1.3125 -0.953125 -3.328125q0 -2.40625 1.25 -3.765625q1.265625 -1.359375 3.1875 -1.359375q2.15625 0 3.40625 1.421875q1.25 1.421875 1.1875 4.375l-6.40625 0q0.03125 1.140625 0.609375 1.78125q0.59375 0.625 1.484375 0.625q0.59375 0 1.0 -0.328125q0.421875 -0.328125 0.625 -1.0625zm0.15625 -2.59375q-0.03125 -1.109375 -0.578125 -1.6875q-0.546875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.390625 0.609375q-0.546875 0.609375 -0.53125 1.65625l3.828125 0z" fill-rule="nonzero"/><path fill="#efefef" d="m823.1365 -84.173225l273.2599 0l0 27.433067l-273.2599 0z" fill-rule="evenodd"/><path fill="#ff0000" d="m929.5158 -76.89607l5.34375 0q1.59375 0 2.359375 0.140625q0.78125 0.125 1.390625 0.546875q0.625 0.421875 1.03125 1.125q0.40625 0.6875 0.40625 1.546875q0 0.9375 -0.5 1.734375q-0.5 0.78125 -1.375 1.171875q1.21875 0.359375 1.875 1.21875q0.65625 0.84375 0.65625 2.0q0 0.921875 -0.421875 1.78125q-0.421875 0.859375 -1.15625 1.375q-0.734375 0.5156288 -1.796875 0.6250038q-0.6875 0.078125 -3.265625 0.09375l-4.546875 0l0 -13.359379zm2.703125 2.234375l0 3.078125l1.765625 0q1.578125 0 1.953125 -0.046875q0.703125 -0.078125 1.09375 -0.46875q0.40625 -0.40625 0.40625 -1.046875q0 -0.625 -0.34375 -1.0q-0.34375 -0.390625 -1.015625 -0.484375q-0.40625 -0.03125 -2.3125 -0.03125l-1.546875 0zm0 5.296875l0 3.578125l2.5 0q1.453125 0 1.84375 -0.078125q0.609375 -0.109375 0.984375 -0.53125q0.375 -0.421875 0.375 -1.140625q0 -0.59375 -0.296875 -1.015625q-0.28125 -0.421875 -0.84375 -0.609375q-0.546875 -0.203125 -2.390625 -0.203125l-2.171875 0zm17.113586 5.828129l0 -1.4531288q-0.53125 0.78125 -1.390625 1.2343788q-0.859375 0.4375 -1.8125 0.4375q-0.96875 0 -1.75 -0.421875q-0.765625 -0.4375038 -1.125 -1.2031288q-0.34375 -0.78125 -0.34375 -2.140625l0 -6.125l2.5625 0l0 4.4375q0 2.046875 0.140625 2.515625q0.140625 0.453125 0.515625 0.71875q0.375 0.265625 0.953125 0.265625q0.65625 0 1.171875 -0.359375q0.515625 -0.359375 0.703125 -0.890625q0.203125 -0.53125 0.203125 -2.609375l0 -4.078125l2.546875 0l0 9.671879l-2.375 0zm4.927246 0l0 -13.359379l2.5625 0l0 4.8125q1.171875 -1.34375 2.796875 -1.34375q1.765625 0 2.921875 1.28125q1.15625 1.28125 1.15625 3.671875q0 2.484375 -1.1875 3.828125q-1.171875 1.3281288 -2.859375 1.3281288q-0.828125 0 -1.640625 -0.40625q-0.796875 -0.4218788 -1.375 -1.2343788l0 1.4218788l-2.375 0zm2.53125 -5.046879q0 1.5 0.484375 2.21875q0.65625 1.03125 1.765625 1.03125q0.84375 0 1.4375 -0.71875q0.59375 -0.734375 0.59375 -2.296875q0 -1.65625 -0.609375 -2.390625q-0.59375 -0.734375 -1.53125 -0.734375q-0.921875 0 -1.53125 0.71875q-0.609375 0.71875 -0.609375 2.171875zm8.864746 5.046879l0 -13.359379l2.5625 0l0 4.8125q1.171875 -1.34375 2.796875 -1.34375q1.765625 0 2.921875 1.28125q1.15625 1.28125 
1.15625 3.671875q0 2.484375 -1.1875 3.828125q-1.171875 1.3281288 -2.859375 1.3281288q-0.828125 0 -1.640625 -0.40625q-0.796875 -0.4218788 -1.375 -1.2343788l0 1.4218788l-2.375 0zm2.53125 -5.046879q0 1.5 0.484375 2.21875q0.65625 1.03125 1.765625 1.03125q0.84375 0 1.4375 -0.71875q0.59375 -0.734375 0.59375 -2.296875q0 -1.65625 -0.609375 -2.390625q-0.59375 -0.734375 -1.53125 -0.734375q-0.921875 0 -1.53125 0.71875q-0.609375 0.71875 -0.609375 2.171875zm8.97406 5.046879l0 -13.359379l2.5625 0l0 13.359379l-2.5625 0zm10.7771 -3.0781288l2.546875 0.421875q-0.484375 1.40625 -1.546875 2.140625q-1.0625 0.7343788 -2.65625 0.7343788q-2.515625 0 -3.734375 -1.6562538q-0.953125 -1.3125 -0.953125 -3.328125q0 -2.40625 1.25 -3.765625q1.265625 -1.359375 3.1875 -1.359375q2.15625 0 3.40625 1.421875q1.25 1.421875 1.1875 4.375l-6.40625 0q0.03125 1.140625 0.609375 1.78125q0.59375 0.625 1.484375 0.625q0.59375 0 1.0 -0.328125q0.421875 -0.328125 0.625 -1.0625zm0.15625 -2.59375q-0.03125 -1.109375 -0.578125 -1.6875q-0.546875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.390625 0.609375q-0.546875 0.609375 -0.53125 1.65625l3.828125 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m374.46982 99.82546l96.724396 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m389.6298 99.82546l66.40445 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m389.62982 106.40546l-13.160004 0l0 -13.159996l13.160004 0l0 13.159996z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m456.03424 93.24546l13.159973 0l0 13.159996l-13.159973 0l0 -13.159996z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m368.94226 94.90682l107.77951 0l0 42.015755l-107.77951 0z" fill-rule="evenodd"/><path fill="#000000" d="m409.5266 121.82682l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948944 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235077 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m771.4383 100.419945l96.72443 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m786.5983 100.419945l66.40436 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" 
stroke-width="4.0" stroke-linecap="butt" d="m786.5983 106.99995l-13.160034 0l0 -13.159996l13.160034 0l0 13.159996z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m853.00275 93.83995l13.159973 0l0 13.159996l-13.159973 0l0 -13.159996z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m765.91077 95.50131l107.77954 0l0 42.015747l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m806.4951 122.42131l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948914 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235107 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/></g></svg>) + +![pipelining_bandwidth](../../_static/pallas/pipelining_bandwidth_bound.svg) + @@ -581,7 +585,9 @@ If the bottleneck is specifically the latency and not the bandwidth, it is possi -![latency_multi_stage](data:image/svg+xml;base64,<svg version="1.1" viewBox="0.0 0.0 1018.0314960629921 103.2782152230971" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l1018.0315 0l0 103.27821l-1018.0315 0l0 -103.27821z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l1018.0315 0l0 103.27821l-1018.0315 0z" fill-rule="evenodd"/><path fill="#cfe2f3" d="m49.39895 8.514436l199.74805 0l0 27.433071l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m116.88836 25.604094l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.2656326 1.296875 1.2656326 3.609375q0 1.859375 -0.5625076 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 
-1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.2976 8.546877l0 -13.375002l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703127l-1.640625 0zm1.484375 -8.484377q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.500002l-0.1875 -1.5312519q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125019 -1.0 1.1875019q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875019l10.859375 0l0 1.1875019l-10.859375 0zm11.891342 -15.171877l0 -1.890624l1.640625 0l0 1.890624l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.144821 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78124905 2.421875 -0.78124905q1.078125 0 1.890625 0.43749905q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m249.14699 8.514436l121.82677 0l0 27.433071l-121.82677 0z" fill-rule="evenodd"/><path fill="#000000" d="m274.05017 25.604094l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 
[SVG figure data elided: exported-drawing vector paths for a block diagram built from alternating light-blue (#cfe2f3) and light-pink (#f4cccc) rectangular boxes. The box labels were rendered as glyph-outline paths rather than text, so they are not recoverable; no caption, axis names, or readable labels survive in this span.]
2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m584.2126 268.48557l199.74805 0l0 27.433075l-199.74805 0z" fill-rule="evenodd"/><path fill="#000000" d="m651.702 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 
-0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -3.53125l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m783.96063 268.48557l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m808.86383 285.57523l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 
-1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5407715 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219482 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 
-0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m575.7664 307.61154l96.724365 0" fill-rule="evenodd"/><path stroke="#cc0000" stroke-width="4.0" stroke-linejoin="round" stroke-linecap="butt" d="m590.9264 307.61154l66.40442 0" fill-rule="evenodd"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m590.9264 314.19156l-13.159973 0l0 -13.160004l13.159973 0l0 13.160004z" fill-rule="nonzero"/><path fill="#cc0000" stroke="#cc0000" stroke-width="4.0" stroke-linecap="butt" d="m657.3308 301.03156l13.159973 0l0 13.160004l-13.159973 0l0 -13.160004z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m570.23883 302.6929l107.77954 0l0 42.015747l-107.77954 0z" fill-rule="evenodd"/><path fill="#000000" d="m610.8232 329.6129l0 -13.359375l1.765625 0l0 13.359375l-1.765625 0zm10.948975 0l0 -1.21875q-0.90625 1.4375 -2.703125 1.4375q-1.15625 0 -2.125 -0.640625q-0.96875 -0.640625 -1.5 -1.78125q-0.53125 -1.140625 -0.53125 -2.625q0 -1.453125 0.484375 -2.625q0.484375 -1.1875 1.4375 -1.8125q0.96875 -0.625 2.171875 -0.625q0.875 0 1.546875 0.375q0.6875 0.359375 1.109375 0.953125l0 -4.796875l1.640625 0l0 13.359375l-1.53125 0zm-5.171875 -4.828125q0 1.859375 0.78125 2.78125q0.78125 0.921875 1.84375 0.921875q1.078125 0 1.828125 -0.875q0.75 -0.890625 0.75 -2.6875q0 -1.984375 -0.765625 -2.90625q-0.765625 -0.9375 -1.890625 -0.9375q-1.078125 0 -1.8125 0.890625q-0.734375 0.890625 -0.734375 2.8125zm9.235046 4.828125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm10.816711 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m194.55118 158.44882l111.18109 0l0 27.433075l-111.18109 0z" fill-rule="evenodd"/><path fill="#000000" d="m217.75713 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 
0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813217 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.85939 0l0 1.1875l-10.85939 0zm11.891342 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm15.105896 -6.59375q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m305.73227 158.44882l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m330.6355 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 
0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 
-2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm13.871521 -0.828125q0 -2.359375 0.484375 -3.796875q0.484375 -1.453125 1.4375 -2.234375q0.96875 -0.78125 2.421875 -0.78125q1.078125 0 1.890625 0.4375q0.8125 0.421875 1.328125 1.25q0.53125 0.8125 0.828125 1.984375q0.3125 1.15625 0.3125 3.140625q0 2.359375 -0.484375 3.8125q-0.484375 1.4375 -1.453125 2.234375q-0.953125 0.78125 -2.421875 0.78125q-1.921875 0 -3.03125 -1.390625q-1.3125 -1.671875 -1.3125 -5.4375zm1.671875 0q0 3.296875 0.765625 4.390625q0.78125 1.078125 1.90625 1.078125q1.140625 0 1.90625 -1.09375q0.765625 -1.09375 0.765625 -4.375q0 -3.296875 -0.765625 -4.375q-0.765625 -1.078125 -1.921875 -1.078125q-1.125 0 -1.796875 0.953125q-0.859375 1.21875 -0.859375 4.5z" fill-rule="nonzero"/><path fill="#f4cccc" d="m427.55905 185.8819l121.82675 0l0 27.43306l-121.82675 0z" fill-rule="evenodd"/><path fill="#000000" d="m452.46225 202.97154l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.540802 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219452 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 
0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230194 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656555 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375305 0 3.1562805 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187805 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.9062805 0 1.5469055 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906555 -2.65625l5.4062805 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312805 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm20.043427 5.765625l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m549.3858 158.44882l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m574.289 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 
-1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm22.480896 4.1875l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#f4cccc" d="m671.2126 185.49606l121.82678 0l0 27.433075l-121.82678 0z" fill-rule="evenodd"/><path fill="#000000" d="m696.1158 202.58572l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 
-0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 4.84375l0 -9.671875l1.46875 0l0 1.359375q0.453125 -0.71875 1.203125 -1.140625q0.765625 -0.4375 1.71875 -0.4375q1.078125 0 1.765625 0.453125q0.6875 0.4375 0.96875 1.234375q1.15625 -1.6875 2.984375 -1.6875q1.453125 0 2.21875 0.796875q0.78125 0.796875 0.78125 2.453125l0 6.640625l-1.640625 0l0 -6.09375q0 -0.984375 -0.15625 -1.40625q-0.15625 -0.4375 -0.578125 -0.703125q-0.421875 -0.265625 -0.984375 -0.265625q-1.015625 0 -1.6875 0.6875q-0.671875 0.671875 -0.671875 2.15625l0 5.625l-1.640625 0l0 -6.28125q0 -1.09375 -0.40625 -1.640625q-0.40625 -0.546875 -1.3125 -0.546875q-0.6875 0 -1.28125 0.359375q-0.59375 0.359375 -0.859375 1.0625q-0.25 0.703125 -0.25 2.03125l0 5.015625l-1.640625 0zm15.5408325 3.703125l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm15.219421 4.78125l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm7.6257324 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm8.230164 -1.640625l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 
2.015625zm13.871521 2.234375l1.640625 -0.21875q0.28125 1.40625 0.953125 2.015625q0.6875 0.609375 1.65625 0.609375q1.15625 0 1.953125 -0.796875q0.796875 -0.796875 0.796875 -1.984375q0 -1.125 -0.734375 -1.859375q-0.734375 -0.734375 -1.875 -0.734375q-0.46875 0 -1.15625 0.171875l0.1875 -1.4375q0.15625 0.015625 0.265625 0.015625q1.046875 0 1.875 -0.546875q0.84375 -0.546875 0.84375 -1.671875q0 -0.90625 -0.609375 -1.5q-0.609375 -0.59375 -1.578125 -0.59375q-0.953125 0 -1.59375 0.609375q-0.640625 0.59375 -0.8125 1.796875l-1.640625 -0.296875q0.296875 -1.640625 1.359375 -2.546875q1.0625 -0.90625 2.65625 -0.90625q1.09375 0 2.0 0.46875q0.921875 0.46875 1.40625 1.28125q0.5 0.8125 0.5 1.71875q0 0.859375 -0.46875 1.578125q-0.46875 0.703125 -1.375 1.125q1.1875 0.28125 1.84375 1.140625q0.65625 0.859375 0.65625 2.15625q0 1.734375 -1.28125 2.953125q-1.265625 1.21875 -3.21875 1.21875q-1.765625 0 -2.921875 -1.046875q-1.15625 -1.046875 -1.328125 -2.71875z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m305.73227 185.8819l111.18112 0l0 27.43306l-111.18112 0z" fill-rule="evenodd"/><path fill="#000000" d="m328.93823 202.97154l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297577 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891327 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 
9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm21.277771 0l-1.640625 0l0 -10.453125q-0.59375 0.5625 -1.5625 1.140625q-0.953125 0.5625 -1.71875 0.84375l0 -1.59375q1.375 -0.640625 2.40625 -1.5625q1.03125 -0.921875 1.453125 -1.78125l1.0625 0l0 13.40625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m427.55905 158.44882l111.18112 0l0 27.433075l-111.18112 0z" fill-rule="evenodd"/><path fill="#000000" d="m450.76498 175.53848l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297607 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813202 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.8594055 0l0 1.1875l-10.8594055 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 
-0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm23.715271 -1.578125l0 1.578125l-8.828125 0q-0.015625 -0.59375 0.1875 -1.140625q0.34375 -0.90625 1.078125 -1.78125q0.75 -0.875 2.15625 -2.015625q2.171875 -1.78125 2.9375 -2.828125q0.765625 -1.046875 0.765625 -1.96875q0 -0.984375 -0.703125 -1.640625q-0.6875 -0.671875 -1.8125 -0.671875q-1.1875 0 -1.90625 0.71875q-0.703125 0.703125 -0.703125 1.953125l-1.6875 -0.171875q0.171875 -1.890625 1.296875 -2.875q1.140625 -0.984375 3.03125 -0.984375q1.921875 0 3.046875 1.0625q1.125 1.0625 1.125 2.640625q0 0.796875 -0.328125 1.578125q-0.328125 0.78125 -1.09375 1.640625q-0.75 0.84375 -2.53125 2.34375q-1.46875 1.234375 -1.890625 1.6875q-0.421875 0.4375 -0.6875 0.875l6.546875 0z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m549.3858 185.49606l111.18115 0l0 27.433075l-111.18115 0z" fill-rule="evenodd"/><path fill="#000000" d="m572.5918 202.58572l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm2.40625 -1.296875q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297546 8.546875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.813232 8.5l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125zm7.890625 -0.015625l0 -1.1875l10.859375 0l0 1.1875l-10.859375 0zm11.891357 -15.171875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm4.1448364 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 
[inline SVG figure data elided: several thousand glyph-path coordinates from a pre-rendered pipelining diagram, superseded by the static asset referenced below]
+
+![pipelining_latency](../../_static/pallas/pipelining_latency_multistage.svg)
+

From 5f6b99a143749e3f3c89b00fb6da8e0633f8b083 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 16 Apr 2025 14:09:22 -0700
Subject: [PATCH 0655/1769] Fix a bug in reduce_window sharding rule where padding is a tuple but we were checking for a scalar instead.

Fixes https://github.com/jax-ml/jax/issues/28070

PiperOrigin-RevId: 748418451
---
 jax/_src/lax/windowed_reductions.py |  5 +++--
 tests/pjit_test.py                  | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 8ebdcd6c3784..acbfae37eaf5 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -530,7 +530,7 @@ def reduce_window_sharding_rule(operand, window_dimensions, window_strides,
                                 base_dilation, window_dilation):
     if spec is None:
       continue
-    if not (wdim == 1 and ws == 1 and pd == 1 and bd == 1 and wdil == 1):
+    if not (wdim == 1 and ws == 1 and pd == (0, 0) and bd == 1 and wdil == 1):
       raise core.ShardingTypeError(
           "Only trivial windowing is supported along non-replicated"
           f" dimensions. Got {operand.sharding.spec=}")
@@ -639,7 +639,8 @@ def _reduce_window_lower(
 ):
   operand_aval, = ctx.avals_in
-  scalar_aval = operand_aval.update(shape=())
+  scalar_aval = operand_aval.update(
+      shape=(), sharding=operand_aval.sharding.with_spec(()))
 
   return mlir.reduce_window(
       ctx,
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 025512121a6d..8f30475eee32 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -7034,6 +7034,18 @@ def f(x):
     self.assertArraysEqual(out, np.cumsum(np_inp))
     self.assertEqual(out.sharding, NamedSharding(mesh, P(None)))
 
+    @jax.jit
+    def f(x):
+      x = jnp.expand_dims(x, 1)
+      self.assertEqual(x.aval.sharding.spec, P('x', None))
+      out = jnp.cumsum(x, axis=1)
+      self.assertEqual(out.aval.sharding.spec, P('x', None))
+      return out
+
+    arr2 = jax.device_put(np.arange(8), P('x'))
+    out = f(arr2)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
   def test_device_put_under_use_mesh(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     x = jnp.zeros((4, 4), dtype=jnp.int32)
From b3d5085a1384721b1970f7fcbdd5cd4e8556ba86 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 16 Apr 2025 15:54:09 -0400
Subject: [PATCH 0656/1769] JAX release v0.6.0

---
 CHANGELOG.md   | 2 +-
 jax/version.py | 2 +-
 setup.py       | 6 +++---
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d46d9d01a0ff..a4164a51119a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,7 +14,7 @@ Remember to align the itemized text with the first line of an item within a list
    When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
 -->
 
-## Unreleased
+## JAX 0.6.0
 
 * Breaking changes
 
diff --git a/jax/version.py b/jax/version.py
index c0e59702692e..5f419d4baea9 100644
--- a/jax/version.py
+++ b/jax/version.py
@@ -152,7 +152,7 @@ def make_release_tree(self, base_dir, files):
 
 __version__ = _get_version_string()
 
-_minimum_jaxlib_version = "0.5.3"
+_minimum_jaxlib_version = "0.6.0"
 
 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
diff --git a/setup.py b/setup.py
index 6f9b4a9dd2cb..823354adb70d 100644
--- a/setup.py
+++ b/setup.py
@@ -19,11 +19,11 @@
 
 project_name = 'jax'
 
-_current_jaxlib_version = '0.5.3'
+_current_jaxlib_version = '0.6.0'
 # The following should be updated after each new jaxlib release.
-_latest_jaxlib_version_on_pypi = '0.5.3'
+_latest_jaxlib_version_on_pypi = '0.6.0'
 
-_libtpu_version = '0.0.11.*'
+_libtpu_version = '0.0.13.*'
 
 def load_version_module(pkg_path):
   spec = importlib.util.spec_from_file_location(

From fb16d53335e27812933f7dd85d7e1a9c84d33506 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 16 Apr 2025 14:09:22 -0700
Subject: [PATCH 0657/1769] Fix a bug in reduce_window sharding rule where padding is a tuple but we were checking for a scalar instead.

Fixes https://github.com/jax-ml/jax/issues/28070

PiperOrigin-RevId: 748418451
---
 jax/_src/lax/windowed_reductions.py |  5 +++--
 tests/pjit_test.py                  | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 8ebdcd6c3784..acbfae37eaf5 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -530,7 +530,7 @@ def reduce_window_sharding_rule(operand, window_dimensions, window_strides,
                                 base_dilation, window_dilation):
     if spec is None:
       continue
-    if not (wdim == 1 and ws == 1 and pd == 1 and bd == 1 and wdil == 1):
+    if not (wdim == 1 and ws == 1 and pd == (0, 0) and bd == 1 and wdil == 1):
       raise core.ShardingTypeError(
           "Only trivial windowing is supported along non-replicated"
           f" dimensions. Got {operand.sharding.spec=}")
@@ -639,7 +639,8 @@ def _reduce_window_lower(
 ):
   operand_aval, = ctx.avals_in
-  scalar_aval = operand_aval.update(shape=())
+  scalar_aval = operand_aval.update(
+      shape=(), sharding=operand_aval.sharding.with_spec(()))
 
   return mlir.reduce_window(
       ctx,
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 025512121a6d..8f30475eee32 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -7034,6 +7034,18 @@ def f(x):
     self.assertArraysEqual(out, np.cumsum(np_inp))
     self.assertEqual(out.sharding, NamedSharding(mesh, P(None)))
 
+    @jax.jit
+    def f(x):
+      x = jnp.expand_dims(x, 1)
+      self.assertEqual(x.aval.sharding.spec, P('x', None))
+      out = jnp.cumsum(x, axis=1)
+      self.assertEqual(out.aval.sharding.spec, P('x', None))
+      return out
+
+    arr2 = jax.device_put(np.arange(8), P('x'))
+    out = f(arr2)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
   def test_device_put_under_use_mesh(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     x = jnp.zeros((4, 4), dtype=jnp.int32)

From a31e53a6c88ee990236a1954841a61432b7c0bd4 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 16 Apr 2025 15:14:39 -0700
Subject: [PATCH 0658/1769] Return False in `is_env_present` if importing kubernetes leads to a ModuleNotFoundError

PiperOrigin-RevId: 748440123
---
 jax/_src/clusters/k8s_cluster.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py
index a3b415df580a..23b2d68e11a5 100644
--- a/jax/_src/clusters/k8s_cluster.py
+++ b/jax/_src/clusters/k8s_cluster.py
@@ -34,7 +34,7 @@ def is_env_present(cls) -> bool:
     if 'KUBERNETES_SERVICE_HOST' in os.environ:
       try:
         import kubernetes as k8s  # pytype: disable=import-error
-      except ImportError as e:
+      except (ImportError, ModuleNotFoundError):
         warnings.warn(
             '\n'.join([
                 textwrap.fill(

From 127aa7621868cb77e552b5d1f90e4a42b09c13fa Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 16 Apr 2025 15:14:39 -0700
Subject: [PATCH 0659/1769] Return False in `is_env_present` if importing kubernetes leads to a ModuleNotFoundError

PiperOrigin-RevId: 748440123
---
 jax/_src/clusters/k8s_cluster.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py
index a3b415df580a..23b2d68e11a5 100644
--- a/jax/_src/clusters/k8s_cluster.py
+++ b/jax/_src/clusters/k8s_cluster.py
@@ -34,7 +34,7 @@ def is_env_present(cls) -> bool:
     if 'KUBERNETES_SERVICE_HOST' in os.environ:
       try:
         import kubernetes as k8s  # pytype: disable=import-error
-      except ImportError as e:
+      except (ImportError, ModuleNotFoundError):
         warnings.warn(
             '\n'.join([
                 textwrap.fill(
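
A note on the `k8s_cluster` change above: `ModuleNotFoundError` is a subclass
of `ImportError` in Python 3, so widening the `except` clause is defensive;
the point of the patch, per its subject line, is that `is_env_present` reports
the environment as absent rather than raising when the optional `kubernetes`
client is missing. A simplified sketch of that guard, reconstructed from the
commit message (the real method is a classmethod and emits a longer warning):

    import os
    import warnings

    def is_env_present() -> bool:
        # Only attempt Kubernetes cluster autodetection inside a pod.
        if 'KUBERNETES_SERVICE_HOST' not in os.environ:
            return False
        try:
            import kubernetes as k8s  # noqa: F401  (optional dependency)
        except (ImportError, ModuleNotFoundError):
            warnings.warn("kubernetes client not installed; "
                          "k8s cluster autodetection is disabled")
            return False
        return True
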
fbd6db3f117a16f81c439849d652977b0b40c965 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 16 Apr 2025 20:07:20 -0400 Subject: [PATCH 0660/1769] Update version numbers after release. --- CHANGELOG.md | 4 +++- jax/version.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a4164a51119a..7e027db7e32b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,9 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## JAX 0.6.0 +## Unreleased + +## JAX 0.6.0 (April 16, 2025) * Breaking changes diff --git a/jax/version.py b/jax/version.py index 5f419d4baea9..9301848b0cfb 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.6.0" +_version = "0.6.1" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None From cbcf3ab65e3e24d841c7c36637105f3545dbf148 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 16 Apr 2025 17:22:06 -0700 Subject: [PATCH 0661/1769] [Pallas] Fix relative links in pipelining docs --- docs/pallas/pipelining.ipynb | 10 +++++----- docs/pallas/pipelining.md | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/pallas/pipelining.ipynb b/docs/pallas/pipelining.ipynb index 8342bee2002d..d7871a3b67db 100644 --- a/docs/pallas/pipelining.ipynb +++ b/docs/pallas/pipelining.ipynb @@ -46,7 +46,7 @@ "\n", "\n", "\n", - "![memory_hierarchy](../../_static/pallas/pipelining_mem_hierarchy.svg)\n", + "![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg)\n", "\n" ] }, @@ -181,7 +181,7 @@ }, "source": [ "\n", - "![pipelining_example](../../_static/pallas/pipelining_example.svg)\n" + "![pipelining_example](../_static/pallas/pipelining_example.svg)\n" ] }, { @@ -790,7 +790,7 @@ }, "source": [ "\n", - "![pipelining_compute](../../_static/pallas/pipelining_compute_bound.svg)\n" + "![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg)\n" ] }, { @@ -809,7 +809,7 @@ }, "source": [ "\n", - "![pipelining_bandwidth](../../_static/pallas/pipelining_bandwidth_bound.svg)\n" + "![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg)\n" ] }, { @@ -828,7 +828,7 @@ }, "source": [ "\n", - "![pipelining_latency](../../_static/pallas/pipelining_latency_multistage.svg)\n" + "![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg)\n" ] }, { diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md index f5570f645429..42b91e368238 100644 --- a/docs/pallas/pipelining.md +++ b/docs/pallas/pipelining.md @@ -42,7 +42,7 @@ It's reasonable to expect the latency to access SRAM to be on the order of 10x l -![memory_hierarchy](../../_static/pallas/pipelining_mem_hierarchy.svg) +![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg) @@ -125,7 +125,7 @@ Note that there is a data-dependence between steps 1-3, and we cannot trivially -![pipelining_example](../../_static/pallas/pipelining_example.svg) +![pipelining_example](../_static/pallas/pipelining_example.svg) @@ -565,7 +565,7 @@ In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\a -![pipelining_compute](../../_static/pallas/pipelining_compute_bound.svg) +![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg) @@ -575,7 +575,7 @@ In a **memory-bound** regime it 
is useful to identify if the problem is the late -![pipelining_bandwidth](../../_static/pallas/pipelining_bandwidth_bound.svg) +![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg) @@ -586,7 +586,7 @@ If the bottleneck is specifically the latency and not the bandwidth, it is possi -![pipelining_latency](../../_static/pallas/pipelining_latency_multistage.svg) +![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg) From 0a9d0bec5b5a60a38a038248762a146082d49858 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 16 Apr 2025 21:49:02 -0700 Subject: [PATCH 0662/1769] Remove _manual_axes from NamedSharding since we can now track the manual axes on the mesh. PiperOrigin-RevId: 748534841 --- jax/BUILD | 1 + jax/_src/interpreters/mlir.py | 30 +++++++++++++++++++-------- jax/_src/named_sharding.py | 24 ++++++--------------- jax/experimental/shard_map.py | 9 ++++---- jaxlib/xla/sharding.cc | 12 +++-------- jaxlib/xla/sharding.h | 4 +--- jaxlib/xla/xla_extension/__init__.pyi | 2 -- tests/shard_map_test.py | 11 +++------- 8 files changed, 39 insertions(+), 54 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index ca0fb0268c46..862679681c39 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -609,6 +609,7 @@ pytype_strict_library( ":dtypes", ":effects", ":layout", + ":mesh", ":op_shardings", ":partial_eval", ":partition_spec", diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index becdfd46bd92..9979ea151b76 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -49,6 +49,7 @@ from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout from jax._src.partition_spec import PartitionSpec +from jax._src.mesh import AxisType from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import (AUTO, NamedSharding, modify_sdy_sharding_wrt_axis_types, @@ -1017,18 +1018,29 @@ class LoweringResult(NamedTuple): def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): - mesh = axis_ctx.mesh + mesh = axis_ctx.mesh.abstract_mesh + sharding_mesh = sharding.mesh.abstract_mesh if (isinstance(sharding, sharding_impls.NamedSharding) and - sharding.mesh.shape == mesh.shape): - return sharding_impls.NamedSharding( - sharding.mesh, sharding.spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + sharding_mesh.shape == mesh.shape): + out_mesh, spec = sharding_mesh, sharding.spec else: - spec = sharding_impls.parse_flatten_op_sharding( + out_mesh, spec = mesh, sharding_impls.parse_flatten_op_sharding( sharding._to_xla_hlo_sharding(ndim), mesh)[0] - return sharding_impls.NamedSharding( - mesh, spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + + out_mesh = out_mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) + out = sharding_impls.NamedSharding(out_mesh, spec, + memory_kind=sharding.memory_kind) + manual_axes = out.mesh.manual_axes + if any(p in manual_axes for s in out.spec + if s is not None and s is not PartitionSpec.UNCONSTRAINED + for p in (s if isinstance(s, tuple) else (s,))): + raise ValueError( + f'pspec {out.spec} contains a manual axes {manual_axes} of mesh' + f' which is not allowed. 
If you are using a' + ' with_sharding_constraint under a shard_map, only use the' + ' mesh axis in PartitionSpec which are not manual.') + return out def _to_physical_op_sharding( diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 9bdadd8a8570..2c6741ab4c9a 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -112,20 +112,17 @@ class NamedSharding(JSharding.Sharding): mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh spec: PartitionSpec _memory_kind: str | None - _manual_axes: frozenset[MeshAxisName] _logical_device_ids: tuple[int, ...] | None @use_cpp_method() def __init__( self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *, - memory_kind: str | None = None, _manual_axes=frozenset(), - _logical_device_ids=None): + memory_kind: str | None = None, _logical_device_ids=None): self.mesh = mesh self.spec = spec self._memory_kind = memory_kind - self._manual_axes = _manual_axes self._logical_device_ids = _logical_device_ids - check_pspec(self.mesh, self.spec, self._manual_axes) + check_pspec(self.mesh, self.spec) def __repr__(self): mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}' @@ -137,7 +134,6 @@ def __repr__(self): def __reduce__(self): return (type(self), (self.mesh, self.spec), {'memory_kind': self.memory_kind, - '_manual_axes': self._manual_axes, '_logical_device_ids': self._logical_device_ids}) @property @@ -147,8 +143,7 @@ def memory_kind(self) -> str | None: def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( - (self.mesh, self.memory_kind, self.spec, self._manual_axes, - self._logical_device_ids)) + (self.mesh, self.memory_kind, self.spec, self._logical_device_ids)) return self._hash def __eq__(self, other): @@ -158,7 +153,6 @@ def __eq__(self, other): return True if (self.spec != other.spec or self.memory_kind != other.memory_kind - or self._manual_axes != other._manual_axes or self._logical_device_ids != other._logical_device_ids): return False return self.mesh is other.mesh or self.mesh == other.mesh @@ -333,9 +327,7 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() - if t == mesh_lib.AxisType.Manual} - manual_axes = self._manual_axes.union(mesh_manual_axes) + manual_axes = frozenset(self.mesh.manual_axes) if manual_axes: axis_names = self.mesh.axis_names for manual_axis in manual_axes: @@ -420,7 +412,7 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): @cache(max_size=128, trace_context_in_key=False) def check_pspec(mesh, spec, _manual_axes=frozenset()): _check_unique_resources(spec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, spec, _manual_axes) + _check_mesh_resource_axis(mesh, spec) class DuplicateSpecError(Exception): def __init__(self, message, mesh, pspec): @@ -455,7 +447,7 @@ def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None mesh=mesh, pspec=pspec) -def _check_mesh_resource_axis(mesh, pspec, _manual_axes): +def _check_mesh_resource_axis(mesh, pspec): for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: continue @@ -465,10 +457,6 @@ def _check_mesh_resource_axis(mesh, pspec, _manual_axes): raise ValueError( f"Resource axis: {r} of {pspec} " f"is not found in mesh: {tuple(mesh.shape.keys())}.") - if r in _manual_axes: - raise ValueError( - f"Axis: {r} of {pspec} " - f"is also found in manual_axes: {_manual_axes}.") from None if not 
all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p): raise ValueError( 'AxisTypes should be the same in a tuple subset of PartitionSpec:' diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index a6998d29e897..3ffaee614691 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -785,13 +785,12 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, def _make_scoped_manual_sharding(ctx, mesh, axes): axis_ctx = ctx.module_context.axis_context + mesh = mesh.abstract_mesh if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): - manual_axes = axis_ctx.manual_axes - else: - manual_axes = frozenset({}) + mesh = mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) return NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes), # pytype: disable=wrong-arg-types - _manual_axes=manual_axes) + mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in, aval_out, x): diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index ff1539764864..b7c7e0a7de72 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -167,8 +167,6 @@ bool ShardingEqual(nb::handle a, nb::handle b) { a_named_sharding->spec().equal(b_named_sharding->spec()) && a_named_sharding->memory_kind().equal( b_named_sharding->memory_kind()) && - a_named_sharding->manual_axes().equal( - b_named_sharding->manual_axes()) && a_named_sharding->logical_device_ids().equal( b_named_sharding->logical_device_ids()); } @@ -204,7 +202,7 @@ static const std::array valid_memory_kinds = { }; NamedSharding::NamedSharding(nb::object mesh, nb::object spec, - nb::object memory_kind, nb::object manual_axes, + nb::object memory_kind, nb::object logical_device_ids) : Sharding(/*num_devices=*/[&mesh]() { return nb::cast(mesh.attr("size")); @@ -212,7 +210,6 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, mesh_(std::move(mesh)), spec_(std::move(spec)), memory_kind_(std::move(memory_kind)), - manual_axes_(std::move(manual_axes)), logical_device_ids_(std::move(logical_device_ids)) { if (spec_.is_none()) { throw nb::type_error( @@ -261,7 +258,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, } return output; }(); - (*check_pspec)(mesh_, spec_, manual_axes_); + (*check_pspec)(mesh_, spec_); } /*static*/ PyObject* NamedSharding::type_ = nullptr; @@ -352,16 +349,13 @@ void RegisterSharding(nb::module_& m) { nb::class_(m, "Sharding").def(nb::init<>()); nb::class_(m, "NamedSharding", nb::dynamic_attr()) - .def(nb::init(), + .def(nb::init(), nb::arg("mesh"), nb::arg("spec").none(), nb::arg("memory_kind").none() = nb::none(), - nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)), nb::arg("_logical_device_ids").none() = nb::none()) .def_prop_ro("mesh", &NamedSharding::mesh) .def_prop_ro("spec", &NamedSharding::spec) .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) - .def_prop_ro("_manual_axes", &NamedSharding::manual_axes) .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids) .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { return xla::ValueOrThrow(s.internal_device_list()); diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index 4b602bd14324..e0c54592259b 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -75,13 +75,12 @@ bool ShardingEqual(nanobind::handle a, nanobind::handle b); class NamedSharding : public Sharding { public: 
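// Note: which axes are manual is now recovered from the mesh object itself
// (cf. mesh.manual_axes in the named_sharding.py hunk above), so the
// constructor below no longer threads a separate manual_axes handle.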
NamedSharding(nanobind::object mesh, nanobind::object spec, - nanobind::object memory_kind, nanobind::object manual_axes, + nanobind::object memory_kind, nanobind::object logical_device_ids); const nanobind::object& mesh() const { return mesh_; } const nanobind::object& spec() const { return spec_; } const nanobind::object& memory_kind() const { return memory_kind_; } - const nanobind::object& manual_axes() const { return manual_axes_; } const nanobind::object& logical_device_ids() const { return logical_device_ids_; } @@ -102,7 +101,6 @@ class NamedSharding : public Sharding { nanobind::object mesh_; nanobind::object spec_; nanobind::object memory_kind_; - nanobind::object manual_axes_; nanobind::object logical_device_ids_; std::optional> internal_device_list_; static PyObject* type_; diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi index de9bb02f6343..7bfb2b1f675b 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -926,14 +926,12 @@ class NamedSharding(Sharding): spec: Any, *, memory_kind: Optional[str] = None, - _manual_axes: frozenset[Any] = frozenset(), _logical_device_ids: tuple[int, ...] | None = None, ): ... mesh: Any spec: Any _memory_kind: Optional[str] _internal_device_list: DeviceList - _manual_axes: frozenset[Any] _logical_device_ids: tuple[int, ...] | None class SingleDeviceSharding(Sharding): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 5ae66417cf59..844e60696479 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2232,17 +2232,12 @@ def g(x): return x * x def h(x): - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + return shard_map(h, mesh, in_specs=P('i', None), out_specs=P('i', None), + check_rep=False, auto=frozenset({'j'}))(x) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) From 82215f660e979f5802c0e7817882e025e004e0d3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 16 Apr 2025 22:27:02 -0700 Subject: [PATCH 0663/1769] Remove jax_varying_axes_in_types config and `rewrite` from `shard_map_p` PiperOrigin-RevId: 748545142 --- jax/_src/config.py | 9 --- jax/_src/core.py | 6 -- jax/_src/interpreters/batching.py | 6 +- jax/_src/lax/parallel.py | 15 +--- jax/experimental/shard_map.py | 116 ++++++++++++------------------ tests/debug_info_test.py | 11 +-- tests/shard_map_test.py | 4 -- 7 files changed, 52 insertions(+), 115 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index a4d9b5582566..aec8b1450fd0 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -235,7 +235,6 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, use_direct_linearize.value, - varying_axes_in_types.value, softmax_custom_jvp.value, disable_jit.value, debug_key_reuse.value, @@ -1092,14 +1091,6 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help=('Use direct linearization instead JVP followed by partial eval'), include_in_jit_key=True) -varying_axes_in_types = bool_state( - name='jax_varying_axes_in_types', - default=True, - help=('Adds varying manual axes to ShapedArray to track which mesh axes the' - ' array is varying over. 
This will help to remove the efficient' - ' transpose rewrite machinery in shard_map'), - include_in_jit_key=True) - # TODO make it so people don't use this, this is internal... _check_rep = bool_state( name='check_rep', diff --git a/jax/_src/core.py b/jax/_src/core.py index f66fb0928ab2..013219021e05 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2006,8 +2006,6 @@ def pvary(x, axis_name): pvary_p.def_impl(lambda *args, axes, axis_index_groups: args) def _pvary_abstract_eval(*args, axes, axis_index_groups): - if not config.varying_axes_in_types.value: - return args if not config._check_rep.value: return args assert isinstance(axes, tuple) @@ -2027,8 +2025,6 @@ def _pvary_abstract_eval(*args, axes, axis_index_groups): def standard_insert_pvary(*args): - if not config.varying_axes_in_types.value: - return args if not config._check_rep.value: return args if not args: @@ -2040,8 +2036,6 @@ def standard_insert_pvary(*args): if out_vma - src else arg for arg, src in zip(args, in_vma)] def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: - if not config.varying_axes_in_types.value: - return frozenset() if not config._check_rep.value: return frozenset() avals = tuple(a for a in avals if a is not abstract_token) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index e3bed5bda6c9..c97c8d558608 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -407,7 +407,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, def aval(self): aval = core.get_aval(self.val) if self._trace.axis_data.spmd_name is not None: - if config._check_rep.value and config.varying_axes_in_types.value: + if config._check_rep.value: aval = aval.update( vma=aval.vma - frozenset(self._trace.axis_data.spmd_name)) if self.batch_dim is not_mapped: @@ -776,7 +776,7 @@ def _batch_jaxpr2( aval = core.unmapped_aval( axis_data.size, b, aval, axis_data.explicit_mesh_axis) if axis_data.spmd_name is not None: - if config._check_rep.value and config.varying_axes_in_types.value: + if config._check_rep.value: aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name)) # type: ignore avals_in2.append(aval) jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) @@ -1111,7 +1111,7 @@ def broadcast(x, sz, axis, mesh_axis=None): # out how to ensure jaxpr arguments always have the context mesh. 
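  # (The broadcast below is staged with sharding.mesh as the ambient abstract
  # mesh, so the out_sharding argument is resolved against that mesh's axes
  # rather than whatever mesh happens to be in scope at the call site.)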
with mesh_lib.use_abstract_mesh(sharding.mesh): x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) - if config._check_rep.value and config.varying_axes_in_types.value: + if config._check_rep.value: # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026 spmd_names = core.get_axis_env().spmd_axis_names if len(spmd_names) > 1: diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 8ec43c77cdaf..1df3de6f1bee 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -144,7 +144,7 @@ def pos_reduce(x): size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) else: - if config.varying_axes_in_types.value and config._check_rep.value: + if config._check_rep.value: out_flat = bind_psum_invariant( leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) else: @@ -827,9 +827,6 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups): - if not config.varying_axes_in_types.value: - return psum_p.abstract_eval( - *args, axes=axes, axis_index_groups=axis_index_groups) if not config._check_rep.value: return psum_p.abstract_eval( *args, axes=axes, axis_index_groups=axis_index_groups) @@ -865,9 +862,6 @@ def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups): # TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups): - if not config.varying_axes_in_types.value: - return _allreduce_effectful_abstract_eval( - *args, axes=axes, axis_index_groups=axis_index_groups) if not config._check_rep.value: return _allreduce_effectful_abstract_eval( *args, axes=axes, axis_index_groups=axis_index_groups) @@ -1417,8 +1411,6 @@ def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') def insert_collective_pvary(axis_name, x): - if not config.varying_axes_in_types.value: - return x if not config._check_rep.value: return x @@ -1551,8 +1543,6 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, def collective_vma_rule(prim_name, axis_name, x_aval): - if not config.varying_axes_in_types.value: - return frozenset() if not config._check_rep.value: return frozenset() axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name @@ -1921,8 +1911,7 @@ def _axis_index_effectful_abstract_eval(*, axis_name): mesh = get_abstract_mesh() sharding = NamedSharding(mesh, P()) vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset()) - if config.varying_axes_in_types.value and config._check_rep.value - else frozenset()) + if config._check_rep.value else frozenset()) return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3ffaee614691..51e6acae8aee 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -190,18 +190,13 @@ def out_names_thunk(): raise e('shard_map out_specs') from None return tuple(map(_canonicalize_spec, out_specs_flat)) - rewrite = check_rep - if rewrite: - if config.varying_axes_in_types.value: - fun = _implicit_pvary_on_output(fun, out_names_thunk) - else: - 
fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk) + if check_rep: + fun = _implicit_pvary_on_output(fun, out_names_thunk) try: out_flat = shard_map_p.bind( fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite, - auto=auto) + out_names_thunk=out_names_thunk, check_rep=check_rep, auto=auto) except _SpecError as e: fails, = e.args if not callable(out_specs): @@ -521,7 +516,6 @@ def _shard_map_staging( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, - rewrite: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: in_tracers = map(trace.to_jaxpr_tracer, in_tracers) @@ -550,7 +544,7 @@ def _shard_map_staging( jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_rep=check_rep, rewrite=rewrite, auto=auto) + check_rep=check_rep, auto=auto) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, effs, source_info) @@ -586,7 +580,7 @@ def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_rep, names: AxisNames manual_mesh = _as_manual_mesh(mesh, auto) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) vma = (frozenset({n for ns in names.values() for n in ns}) - if config.varying_axes_in_types.value and check_rep else frozenset()) + if check_rep else frozenset()) return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array @@ -618,7 +612,7 @@ def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, new_sharding = NamedSharding(new_mesh, out_spec) manual_axes = set(new_mesh.manual_axes) vma = (frozenset(v for v in aval.vma if v in manual_axes) - if config.varying_axes_in_types.value and check_rep else frozenset()) + if check_rep else frozenset()) return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array @@ -627,7 +621,7 @@ def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, RepType = Any def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): + check_rep, auto): # TODO(mattjj,parkers): check auto for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): if not core.typecompat(v.aval, _shard_aval( @@ -637,11 +631,7 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, with _extend_axis_env(mesh, auto), config._check_rep(check_rep): core.check_jaxpr(jaxpr) if check_rep: - if config.varying_axes_in_types.value: - out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] - else: - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _check_rep(mesh, jaxpr, in_rep) + out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] for rep, dst in zip(out_rep, out_names): if not _valid_repeats(mesh, auto, rep, dst): raise core.JaxprTypeError("shard_map can't prove output is " @@ -760,7 +750,7 @@ def _shard_map_lowering_shardy( def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): + check_rep, auto): if config.use_shardy_partitioner.value: return _shard_map_lowering_shardy( ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep) @@ -866,7 +856,7 @@ def _vma_to_rep(mesh, auto, vma): return frozenset((set(mesh.axis_names) - 
auto) - vma) def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_rep, rewrite, auto): + check_rep, auto): if auto: raise NotImplementedError del prim if isinstance(mesh, AbstractMesh): @@ -878,7 +868,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, outs, out_rep = _run_shmap(fun, mesh, auto, args, in_rep, check_rep, cur_mesh) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types - if check_rep and config.varying_axes_in_types.value: + if check_rep: _check_reps(mesh, auto, out_names_thunk(), out_rep) src_pspecs = tuple(_rep_to_spec(mesh, auto, r) for r in out_rep) else: @@ -912,7 +902,7 @@ def _unmatch_spec(mesh: Mesh, check_rep, src: AxisNames, x: JaxType, def _unmatch(mesh, check_rep, src_tup, x): src = _names_to_pspec(dict(src_tup)) - if check_rep and config.varying_axes_in_types.value: + if check_rep: used_axes = {i for _, ns in src_tup for i in ns} dst = P(tuple(i for i in mesh.axis_names if i in used_axes)) else: @@ -951,8 +941,6 @@ def _match_spec(mesh: Mesh, check_rep, src_pspec: PartitionSpec, return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x) def _match(mesh, check_rep, src_pspec, dst_pspec, x): - if not config.varying_axes_in_types.value: - check_rep = False return shard_map(_rem_singleton, mesh, src_pspec, dst_pspec, check_rep=check_rep)(x) @@ -990,29 +978,24 @@ def to_val_rep_pair(self, val): raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: val_ = _unmatch_spec(self.mesh, self.check, {}, val, self.context_mesh) - if self.check and config.varying_axes_in_types.value: + if self.check: return val_, frozenset(self.mesh.axis_names) - self.auto else: return val_, None def process_primitive(self, prim, tracers, params): in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - if config.varying_axes_in_types.value: - if self.check: - in_vma = tuple(map(partial(_rep_to_vma, self.mesh, self.auto), in_rep)) - out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) - out_avals = tuple(out_avals) if type(out_avals) is list else out_avals - out_vma = tree_map(lambda a: a.vma, out_avals) - out_rep = tree_map(partial(_vma_to_rep, self.mesh, self.auto), out_vma) - in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) - out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) - else: - out_rep = frozenset() - in_specs = out_specs = P(self.mesh.axis_names) + if self.check: + in_vma = tuple(map(partial(_rep_to_vma, self.mesh, self.auto), in_rep)) + out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) + out_avals = tuple(out_avals) if type(out_avals) is list else out_avals + out_vma = tree_map(lambda a: a.vma, out_avals) + out_rep = tree_map(partial(_vma_to_rep, self.mesh, self.auto), out_vma) + in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) + out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) else: + out_rep = frozenset() in_specs = out_specs = P(self.mesh.axis_names) - rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) - out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() eager_rule = eager_rules.get(prim) if eager_rule: @@ -1083,9 +1066,8 @@ def aval(self): _as_manual_mesh(self._trace.mesh, self._trace.auto), out.sharding.spec) # pytype: disable=attribute-error manual_axes = set(self._trace.mesh.axis_names) - self._trace.auto - vma = (frozenset(manual_axes - 
self.rep) - if config.varying_axes_in_types.value and config._check_rep.value - else frozenset()) + vma = (frozenset(manual_axes - self.rep) if config._check_rep.value else + frozenset()) return out.update(sharding=new_sharding, vma=vma) def to_concrete_value(self): @@ -1113,8 +1095,6 @@ def apply(*args): outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) return tree_map(_add_singleton, outs) out_specs = list(out_specs) if type(out_specs) is tuple else out_specs - if not config.varying_axes_in_types.value: - check_rep = False return shard_map(apply, mesh, in_specs, out_specs, check_rep=check_rep)(*args) eager_rules: dict[core.Primitive, Callable] = {} @@ -1578,7 +1558,6 @@ def _shard_map_batch( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, - rewrite: bool, auto: frozenset) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): @@ -1604,7 +1583,7 @@ def new_out_names_thunk(): new_params = dict(mesh=mesh, in_names=new_in_names, out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + auto=auto) with core.set_current_trace(trace.parent_trace): out_vals = prim.bind(fun, *in_vals, **new_params) make_tracer = partial(batching.BatchTracer, trace, @@ -1627,7 +1606,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names): # Autodiff def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): + out_names_thunk, check_rep, auto): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] @@ -1642,7 +1621,7 @@ def new_out_names_thunk(): return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + auto=auto) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) primal_out, tangent_out = tree_unflatten(out_tree(), result) @@ -1653,7 +1632,7 @@ def new_out_names_thunk(): def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): + out_names_thunk, check_rep, auto): tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) @@ -1670,7 +1649,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, def known_out_names(): _, _, out_knowns, res_avals, _, _ = aux() _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) - if check_rep and config.varying_axes_in_types.value: + if check_rep: res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} for a in res_avals] else: @@ -1679,7 +1658,7 @@ def known_out_names(): known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_rep=check_rep, - rewrite=rewrite, auto=auto) + auto=auto) out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux() @@ -1698,7 +1677,7 @@ def known_out_names(): elif f2 is not None: res_names.append(known_out_names_[f2]) else: - if check_rep and 
config.varying_axes_in_types.value: + if check_rep: res_vma = next(res_avals_iter).vma res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) else: @@ -1710,9 +1689,7 @@ def known_out_names(): out_avals_sharded = [v.aval for v in jaxpr.outvars] unk_params = dict(mesh=mesh, in_names=unk_in_names, out_names=unk_out_names, jaxpr=jaxpr, - check_rep=(check_rep if config.varying_axes_in_types.value - else False), - rewrite=rewrite, auto=auto) + check_rep=check_rep, auto=auto) out_avals = map(partial(_unshard_aval, mesh, check_rep), unk_out_names, out_avals_sharded) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) @@ -1727,7 +1704,7 @@ def known_out_names(): def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): + out_names_thunk, check_rep, auto): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) @@ -1738,7 +1715,7 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, def fwd_out_names_thunk(): res_avals, _, _, _, _, _ = linearize_outs_thunk() out_names = out_names_thunk() - if check_rep and config.varying_axes_in_types.value: + if check_rep: res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} for a in res_avals] else: @@ -1746,8 +1723,7 @@ def fwd_out_names_thunk(): return (*res_names, *out_names) fwd_params = dict( mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, auto=auto) all_fwd_results = shard_map_p.bind_with_trace( trace.parent_trace, (f_primal, *primals), fwd_params) res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() @@ -1770,7 +1746,7 @@ def fwd_out_names_thunk(): elif f2 is not None: res_names.append(out_names[f2]) else: - if check_rep and config.varying_axes_in_types.value: + if check_rep: res_vma = next(res_avals_iter).vma res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) else: @@ -1783,8 +1759,7 @@ def tangent_out_names_thunk(): return tangent_out_names tangent_params = dict( mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_rep=(check_rep if config.varying_axes_in_types.value else False), - rewrite=rewrite, auto=auto) + check_rep=check_rep, auto=auto) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here def f_tangent(*args): @@ -1846,11 +1821,11 @@ def _unmentioned2(mesh: Mesh, names: AxisNames, def _shard_map_transpose(out_cts, *args, jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): + check_rep, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ ad.Zero(_shard_aval(mesh, auto, check_rep, ns, x.aval)) - if type(x) is ad.Zero else x if rewrite or dtypes.dtype(x) == dtypes.float0 + if type(x) is ad.Zero else x if check_rep or dtypes.dtype(x) == dtypes.float0 else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) for ns, x in zip(out_names, out_cts) ] @@ -1871,7 +1846,7 @@ def fun_trans_callable(out_cts, args): )[len(res_reshaped):] _, in_ct_names = partition_list(in_undef, in_names) in_cts = [ad.Zero(_unshard_aval(mesh, check_rep, ns, x.aval)) - if type(x) is ad.Zero else x if rewrite + if type(x) is ad.Zero else x if check_rep else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) for ns, x 
in zip(in_ct_names, in_cts)] res_zeros = [ad_util.zero_from_primal(r) for r in res] @@ -1891,7 +1866,7 @@ def new_out_names_thunk(): try: out_flat = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, + out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto) except (FloatingPointError, ZeroDivisionError) as e: print("Invalid nan value encountered in the backward pass of a shard_map " @@ -1903,7 +1878,7 @@ def new_out_names_thunk(): _ = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + auto=auto) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: @@ -1946,11 +1921,10 @@ def _partial_eval_jaxpr_custom_rule( for var, w in zip(jaxpr_staged.invars[:num_res], which): if w: rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore - if check_rep and config.varying_axes_in_types.value - else {0: _all_newly_manual_mesh_names(mesh, auto)}) + if check_rep else {0: _all_newly_manual_mesh_names(mesh, auto)}) residuals.append(newvar(_unshard_aval(mesh, check_rep, rn, var.aval))) staged_in_res_names.append(rn) - if check_rep and config.varying_axes_in_types.value: + if check_rep: out_res_names_known = [ {0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} for var, o in zip(res_vars, out_fwd) if o is None @@ -2028,8 +2002,6 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), out_names=tuple(out_names_staged)) - if not config.varying_axes_in_types.value: - new_params_staged.update(check_rep=False) return new_params_known, new_params_staged # TODO(mattjj): remove this mechanism when we revise mesh scopes diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 0fc1aabbaf57..d014c86c2506 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -1763,13 +1763,6 @@ def my_f(x): tracer_spy.append(x) return jnp.sin(jnp.sin(x)) - if config.varying_axes_in_types.value: - expected_tracer_debug_infos = [ - "traced_for=shard_map, fun=my_f, arg_names=x, from x" - ] - else: - expected_tracer_debug_infos = ["None"] - self._check_tracers_and_jaxprs( jax.jit(jax.grad(lambda x: my_f(x).sum())), jnp.arange(2, dtype=np.float32), @@ -1781,7 +1774,9 @@ def my_f(x): "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result", "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=", ], - expected_tracer_debug_infos=expected_tracer_debug_infos) + expected_tracer_debug_infos=[ + "traced_for=shard_map, fun=my_f, arg_names=x, from x" + ]) def test_remat_saved_residuals(self): @functools.partial(jax.remat, diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 844e60696479..1dd59daee873 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1665,7 +1665,6 @@ def example(x, y): dx, dy = example(x, y) self.assertEqual(dy.dtype, jax.dtypes.float0) - @config.varying_axes_in_types(True) def test_pvary(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -2738,7 +2737,6 @@ def test_pmax(self): mesh=mesh, in_specs=P('i'), out_specs=P())(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) - @config.varying_axes_in_types(True) def test_pmax_vma_in_types(self): mesh = 
jtu.create_mesh((4,), ('i',)) x = jnp.arange(8., dtype=np.float32) @@ -2748,7 +2746,6 @@ def test_pmax_vma_in_types(self): self.assertIn("pvary[axes=('i',)", str(jaxpr)) f(x) # doesn't crash - @config.varying_axes_in_types(True) def test_mul_with_vma_in_types(self): mesh = jtu.create_mesh((2,), ('x',)) x = np.arange(8.) @@ -2770,7 +2767,6 @@ def f(x): # return jnp.sum(f(x, y)) # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr) - @config.varying_axes_in_types(True) def test_all_gather_with_vma_in_types(self): mesh = jtu.create_mesh((2,), ('x',)) x = np.arange(8.) From 76bb4953f94fab2393dec2ab3d1dc8f139d1639b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 16 Apr 2025 22:53:40 -0700 Subject: [PATCH 0664/1769] Use llvm::cast/dyn_cast/isa since alternatives are deprecated in https://github.com/llvm/llvm-project/pull/135556 PiperOrigin-RevId: 748551976 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 2 +- jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc | 2 +- jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 5ed5e94b13c0..341ead8431b4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1096,7 +1096,7 @@ LogicalResult ConcatenateOp::verify() { if (getOperands().size() < 2) { return emitOpError("Expected at least 2 operands for concatenate op."); } - auto first_type = getOperand(0).getType().cast(); + auto first_type = cast(getOperand(0).getType()); auto first_shape = first_type.getShape(); auto first_dtype = first_type.getElementType(); for (auto operand : getOperands()) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 373a5db6b4f6..247f47431745 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -257,7 +257,7 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, auto matmul_res = dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(), sliced_acc.getResult()); - auto res_ty = matmul_res.getType().cast(); + auto res_ty = cast(matmul_res.getType()); auto res_shape = res_ty.getShape(); // reshape to 1x[prior_shape] auto reshape_shape = llvm::to_vector(res_shape); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 54ac777fc205..c81701d9a398 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -136,10 +136,10 @@ class VectorLayoutInferer { bool has_vector_io = false; for (auto op : any_op.getOperands()) { - has_vector_io |= op.getType().isa(); + has_vector_io |= isa(op.getType()); } for (auto r : any_op.getResults()) { - has_vector_io |= r.getType().isa(); + has_vector_io |= isa(r.getType()); } if (!has_vector_io && any_op.getRegions().empty()) { SmallVector in_layout(any_op.getNumOperands(), kNoLayout); @@ -1293,7 +1293,7 @@ class VectorLayoutInferer { (*(offsets.end() - 1) + *input_layout->offsets()[1]) % vreg_slice[1]; } for (auto stride : strides_attr) { - TPU_CHECK_OP(stride.cast().getInt() == 1, + TPU_CHECK_OP(cast(stride).getInt() == 1, "Only trivial strides supported."); } From c576d328bd1e3078b0e49739685bd1c1fd755e17 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 17 Apr 2025 02:21:40 -0700 
Subject: [PATCH 0665/1769] Added `lax.axis_size` and switched all existing usage of `psum(1, ...)` to it PiperOrigin-RevId: 748604842 --- CHANGELOG.md | 4 + docs/jax.lax.rst | 1 + docs/notebooks/shard_map.ipynb | 12 +-- docs/notebooks/shard_map.md | 12 +-- jax/_src/lax/parallel.py | 74 +++++++++++++++---- jax/_src/nn/functions.py | 2 +- jax/_src/numpy/reductions.py | 2 +- jax/_src/pallas/mosaic/helpers.py | 4 +- jax/_src/pallas/mosaic/pipeline.py | 2 +- .../jax2tf/tests/sharding_test.py | 4 +- jax/experimental/pallas/ops/tpu/all_gather.py | 6 +- jax/lax/__init__.py | 1 + tests/api_test.py | 2 +- tests/export_back_compat_test.py | 4 +- tests/export_test.py | 2 +- tests/lax_control_flow_test.py | 2 +- tests/pallas/mosaic_gpu_test.py | 2 +- tests/pallas/pallas_test.py | 8 +- tests/pallas/tpu_pallas_async_test.py | 4 +- tests/pallas/tpu_pallas_distributed_test.py | 6 +- tests/pallas/tpu_pallas_state_test.py | 2 +- tests/pmap_test.py | 10 +-- tests/shard_map_test.py | 8 +- 23 files changed, 113 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e027db7e32b..f5494cc9d7a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* New features: + * Added {func}`jax.lax.axis_size` which returns the size of the mapped axis + given its name. + ## JAX 0.6.0 (April 16, 2025) * Breaking changes diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 9db79f591a4e..43937130e5f4 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -222,6 +222,7 @@ Parallel operators pshuffle pswapaxes axis_index + axis_size Sharding-related operators -------------------------- diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index eb8f54f70d7a..f916147ef589 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -1230,7 +1230,7 @@ "source": [ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f7(x_block):\n", - " sz = jax.lax.psum(1, 'i')\n", + " sz = jax.lax.axis_size('i')\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])\n", " print('AFTER:\\n', y_block)\n", @@ -1287,7 +1287,7 @@ "outputs": [], "source": [ "def psum_scatter(x, axis_name, *, tiled=False):\n", - " size = jax.lax.psum(1, axis_name)\n", + " size = jax.lax.axis_size(axis_name)\n", " idx = jax.lax.axis_index(axis_name) # function instance index along axis_name\n", " if tiled:\n", " x = x.reshape(size, -1, *x.shape[1:]) # split leading axis\n", @@ -1550,7 +1550,7 @@ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1599,7 +1599,7 @@ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1708,7 +1708,7 @@ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def 
matmul_psumscatter_overlapped(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i - 1) % size) for i in range(size)])\n", @@ -1751,7 +1751,7 @@ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index ae9206059b1e..ba23d7d17f3e 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -866,7 +866,7 @@ instance to each destination: ```{code-cell} @partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f7(x_block): - sz = jax.lax.psum(1, 'i') + sz = jax.lax.axis_size('i') print('BEFORE:\n', x_block) y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)]) print('AFTER:\n', y_block) @@ -911,7 +911,7 @@ this iteration. In code, it might look like this: ```{code-cell} def psum_scatter(x, axis_name, *, tiled=False): - size = jax.lax.psum(1, axis_name) + size = jax.lax.axis_size(axis_name) idx = jax.lax.axis_index(axis_name) # function instance index along axis_name if tiled: x = x.reshape(size, -1, *x.shape[1:]) # split leading axis @@ -1084,7 +1084,7 @@ multiplies: @partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather_overlapped(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -1115,7 +1115,7 @@ each half in each direction: @partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather_overlapped_bidi(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift_up = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -1182,7 +1182,7 @@ interleave the communication steps with local matrix multiplies: @partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter_overlapped(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i - 1) % size) for i in range(size)]) @@ -1207,7 +1207,7 @@ bidirectional version: @partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift_up = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 1df3de6f1bee..89a5799d9541 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -43,6 +43,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, unzip2) +import jax.numpy as jnp import numpy as np unsafe_map, map = map, safe_map # type: ignore @@ -195,7 +196,7 @@ def pmean(x, axis_name, *, 
axis_index_groups=None): [0. 0.6666667 1.3333334 2. ] """ x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups) - n = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups) + n = _axis_size(axis_name, axis_index_groups) return tree_util.tree_map(lambda v: v / n, x) def pmax(x, axis_name, *, axis_index_groups=None): @@ -446,14 +447,14 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size) where ``axis_size`` is the size of the mapped axis named ``axis_name`` in - the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``. + the input ``x``. Otherwise array with shape similar to the input shape, except with split_axis divided by axis size and concat_axis multiplied by axis size. """ axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) def bind(x, split_axis=split_axis, concat_axis=concat_axis): - group_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + group_size = _axis_size(axis_name, axis_index_groups) if tiled: if x.shape[split_axis] % group_size != 0: raise ValueError(f"The size of all_to_all split_axis ({x.shape[split_axis]}) " @@ -638,7 +639,7 @@ def ragged_all_to_all( axis_index_groups=axis_index_groups) -def axis_index(axis_name): +def axis_index(axis_name: AxisName) -> jax.Array: """Return the index along the mapped axis ``axis_name``. Args: @@ -654,16 +655,16 @@ def axis_index(axis_name): ... def f(_): ... return lax.axis_index('i') ... - >>> f(np.zeros(4)) + >>> f(jnp.zeros(4)) Array([0, 1, 2, 3], dtype=int32) - >>> f(np.zeros(8)) + >>> f(jnp.zeros(8)) Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) >>> @partial(jax.pmap, axis_name='i') ... @partial(jax.pmap, axis_name='j') ... def f(_): ... return lax.axis_index('i'), lax.axis_index('j') ... - >>> x, y = f(np.zeros((4, 2))) + >>> x, y = f(jnp.zeros((4, 2))) >>> print(x) [[0 0] [1 1] @@ -679,12 +680,54 @@ def axis_index(axis_name): return axis_index_p.bind(axis_name=axis_name) else: inner_size = 1 - index = 0 + index = jnp.asarray(0) for name in reversed(axis_name): index += axis_index(name) * inner_size - inner_size *= psum(1, name) + inner_size *= axis_size(name) return index + +def axis_size(axis_name: AxisName) -> int: + """Return the size of the mapped axis ``axis_name``. + + Args: + axis_name: hashable Python object used to name the mapped axis. + + Returns: + An integer representing the size. + + For example, with 8 XLA devices available: + + >>> from functools import partial + >>> from jax.experimental.shard_map import shard_map + >>> from jax.sharding import PartitionSpec as P + >>> mesh = jax.make_mesh((8,), 'i') + >>> @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) + ... def f(_): + ... return lax.axis_size('i') + ... + >>> f(jnp.zeros(16)) + Array(8, dtype=int32, weak_type=True) + >>> mesh = jax.make_mesh((4, 2), ('i', 'j')) + >>> @partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P()) + ... def f(_): + ... return lax.axis_size(('i', 'j')) + ... 
+ >>> f(jnp.zeros((16, 8))) + Array(8, dtype=int32, weak_type=True) + """ + return _axis_size(axis_name) + + +def _axis_size( + axis_name: AxisName, + axis_index_groups: Sequence[Sequence[int]] | None = None, + /, +) -> int: + axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + return psum(1, axis_name, axis_index_groups=axis_index_groups) + + def pgather(src, idx, axes: int | AxisName): """Uses the last positional axis of idx to index into src's axes.""" if not isinstance(axes, (tuple, list)): @@ -692,7 +735,6 @@ def pgather(src, idx, axes: int | AxisName): # TODO: Canonicalize exes! return pgather_p.bind(src, idx, axes=tuple(axes)) - ### parallel primitives def _names_in_param(pname: str, params: core.ParamDict) -> tuple[str]: @@ -1254,7 +1296,11 @@ def _all_to_all_effectful_abstract_eval( axis_name = (axis_name,) _check_axis_names(axis_name) shape = list(input_aval.shape) - axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) + axis_size = ( + _axis_size(axis_name) + if axis_index_groups is None + else len(axis_index_groups[0]) + ) assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) shape[split_axis] //= axis_size shape[concat_axis] *= axis_size @@ -1487,7 +1533,7 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): if not isinstance(axis_name, tuple): axis_name = axis_name, axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + axis_size = _axis_size(axis_name, axis_index_groups) def bind(leaf): leaf = insert_collective_pvary(axis_name, leaf) return all_gather_p.bind( @@ -1495,7 +1541,7 @@ def bind(leaf): all_gather_dimension=canonicalize_axis( axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1), axis_name=axis_name, axis_index_groups=axis_index_groups, - axis_size=int(axis_size), tiled=tiled) + axis_size=axis_size, tiled=tiled) return tree_util.tree_map(bind, x) def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): @@ -1849,7 +1895,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, """ if not isinstance(axis_name, tuple): axis_name = axis_name, - axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + axis_size = _axis_size(axis_name, axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) def bind(leaf): leaf = insert_collective_pvary(axis_name, leaf) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 640c8b89d001..f16deb41a69b 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -663,7 +663,7 @@ def _one_hot(x: Array, num_classes: int, *, try: output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) except TypeError: - axis_size = lax.psum(1, axis) + axis_size = lax.axis_size(axis) if num_classes != axis_size: raise ValueError(f"Expected num_classes to match the size of axis {axis}, " f"but {num_classes} != {axis_size}") from None diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 96b2782edc13..d708f3dd4c87 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -796,7 +796,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): size = 1 a_shape = np.shape(a) for a in axis_seq: - size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) + size *= maybe_named_axis(a, lambda i: a_shape[i], jax.lax.axis_size) return size diff --git 
a/jax/_src/pallas/mosaic/helpers.py b/jax/_src/pallas/mosaic/helpers.py index 24cd7cad6086..80bb4ef4abed 100644 --- a/jax/_src/pallas/mosaic/helpers.py +++ b/jax/_src/pallas/mosaic/helpers.py @@ -60,7 +60,7 @@ def _copy_start_or_wait(action, src_ref, dst_ref): def run_on_first_core(core_axis_name: str): """Runs a function on the first core in a given axis.""" - num_cores = jax.lax.psum(1, core_axis_name) + num_cores = jax.lax.axis_size(core_axis_name) if num_cores == 1: return lambda f: f() @@ -77,7 +77,7 @@ def _(): def core_barrier(sem, *, core_axis_name: str): """Synchronizes all cores in a given axis.""" - num_cores = jax.lax.psum(1, core_axis_name) + num_cores = jax.lax.axis_size(core_axis_name) core_id = jax.lax.axis_index(core_axis_name) @pl_helpers.when(num_cores > 1) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 9b0a9322c94d..5153ee6dcce3 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -988,7 +988,7 @@ def _partition_grid( num_cores = pl.num_programs(core_axis) core_id = pl.program_id(core_axis) else: - num_cores = jax.lax.psum(1, core_axis) + num_cores = jax.lax.axis_size(core_axis) core_id = jax.lax.axis_index(core_axis) # Check that num_cores is statically known if not isinstance(num_cores, int): diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 653ddce7dca4..8651fe4e62d4 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -576,7 +576,7 @@ def test_repro_xla_bug_shmap_collective_permute(self): @partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, 'x', perm=perm) @@ -612,7 +612,7 @@ def test_shmap_collective_permute(self, poly=None): @partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, 'x', perm=perm) diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index 8fb975504e26..dbfde9eb5177 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -48,7 +48,7 @@ def get_neighbor( idx if i == which_axis else lax.axis_index(a) for i, a in enumerate(axis_names) ] - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) if direction == "right": next_idx = lax.rem(idx + 1, axis_size) else: @@ -67,7 +67,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait() with jax.named_scope("neighbour_lookup"): - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left") right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right") @@ -131,7 +131,7 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str], # We can short-circuit here if our axis size is 1 return x def ag_local(x_shard): - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) out_shape = jax.ShapeDtypeStruct((axis_size, *x_shard.shape), x_shard.dtype) out = pl.pallas_call( 
functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index e8ec74f59a7d..c6df458ba91d 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -361,6 +361,7 @@ all_to_all_p as all_to_all_p, axis_index as axis_index, axis_index_p as axis_index_p, + axis_size as axis_size, pbroadcast as pbroadcast, pmax as pmax, pmax_p as pmax_p, diff --git a/tests/api_test.py b/tests/api_test.py index 2d1055516074..e8d0d6590f2b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1483,7 +1483,7 @@ def f(k): def test_caches_depend_on_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 - f = lambda: lax.psum(1, "i") + f = lambda: lax.axis_size("i") g = jax.jit(f) expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)() ans = jax.vmap(g, axis_name="i", axis_size=2, out_axes=None)() diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 888b234e94c0..a5a3c984a0c8 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -805,7 +805,7 @@ def test_tpu_sharding(self): @partial(shard_map, mesh=mesh, in_specs=(P('a', None),), out_specs=P('a', None)) def func(x): # b: f32[2, 4] - axis_size = lax.psum(1, 'a') + axis_size = lax.axis_size('a') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(x, 'a', perm=perm) @@ -1001,7 +1001,7 @@ def func(x): # x: f32[4, 4] @partial(shard_map, mesh=old_mesh, in_specs=(P('a', None),), out_specs=P('a', None)) def shard_map_func(x): # b: f32[2, 4] - axis_size = lax.psum(1, 'a') + axis_size = lax.axis_size('a') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(x, 'a', perm=perm) x = jax.lax.with_sharding_constraint(x, NS(old_mesh, P('a', None))) diff --git a/tests/export_test.py b/tests/export_test.py index e50738ba2480..fcd5572edf5a 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1343,7 +1343,7 @@ def test_shard_map_collective_permute(self, poly=None): shard_map, mesh=mesh, in_specs=(P("x", None),), out_specs=P("x", None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, "x") + axis_size = lax.axis_size("x") perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, "x", perm=perm) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 242b0548023e..8876fb7d06be 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2237,7 +2237,7 @@ def body(x): def test_caches_depend_on_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 - scanned_f = lambda _, __: (lax.psum(1, 'i'), None) + scanned_f = lambda _, __: (lax.axis_size('i'), None) f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() self.assertEqual(ans, 2) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 9ad9038dfc49..3d6a853e7461 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2515,7 +2515,7 @@ def kernel(o_ref): xy_idx = jax.lax.axis_index(("x", "y")) yx_idx = jax.lax.axis_index(("y", "x")) wg_idx = jax.lax.axis_index("wg") - num_wgs = jax.lax.psum(1, "wg") + num_wgs = jax.lax.axis_size("wg") o_ref[xy_idx, wg_idx] = jnp.broadcast_to( yx_idx * num_wgs + wg_idx, (128,) ) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 781934ecd682..045c47a8dd71 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -2411,8 +2411,8 @@ def 
kernel(x_ref, y_ref): def test_can_query_named_grid_size_in_kernel_via_psum(self): def kernel(x_ref, y_ref): - self.assertEqual(lax.psum(1, "i"), 2) - self.assertEqual(lax.psum(1, "j"), 4) + self.assertEqual(lax.axis_size("i"), 2) + self.assertEqual(lax.axis_size("j"), 4) y_ref[...] = x_ref[...] x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128)) @@ -2432,8 +2432,8 @@ def test_can_query_named_dynamic_grid_size_in_kernel_via_psum(self): self.skipTest("Not supported.") def kernel(x_ref, y_ref): - self.assertEqual(lax.psum(1, "i"), 2) - self.assertEqual(lax.psum(1, "j"), 4) + self.assertEqual(lax.axis_size("i"), 2) + self.assertEqual(lax.axis_size("j"), 4) y_ref[...] = x_ref[...] x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128)) diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index 3dfc9bf1637a..36aba917e4ed 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -398,7 +398,7 @@ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): del aliased_x_ref - axis_size = jax.lax.psum(1, axis_name) + axis_size = jax.lax.axis_size(axis_name) left_neighbor = jax.lax.rem( jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size ) @@ -492,7 +492,7 @@ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): del aliased_x_ref - axis_size = jax.lax.psum(1, axis_name) + axis_size = jax.lax.axis_size(axis_name) left_neighbor = jax.lax.rem( jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size ) diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index 3d4d441d7cd0..bb46d1d18772 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -98,7 +98,7 @@ def test_pallas_call_axis_index(self, direction): def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') - num_devices = lax.psum(1, 'x') + num_devices = lax.axis_size('x') if direction == 'right': neighbor = lax.rem(my_id + 1, num_devices) else: @@ -152,7 +152,7 @@ def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') my_other_id = lax.axis_index('y') - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') if direction == 'right': neighbor = lax.rem(my_id + 1, axis_size) else: @@ -208,7 +208,7 @@ def test_barrier_semaphore(self): def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') - num_devices = lax.psum(1, 'x') + num_devices = lax.axis_size('x') neighbor = lax.rem(my_id + 1, num_devices) barrier_sem = pltpu.get_barrier_semaphore() pltpu.semaphore_signal(barrier_sem, device_id=neighbor) diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index 46f98c087110..9977c18d939d 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -228,7 +228,7 @@ def inner(refs): x_ref, y_ref = refs @pl.core_map(mesh) def _(): - num_cores = jax.lax.psum(1, "x") + num_cores = jax.lax.axis_size("x") slc_size = 16 // num_cores def alloc(x_vmem_ref, y_vmem_ref, sem): core_index = jax.lax.axis_index("x") diff --git a/tests/pmap_test.py b/tests/pmap_test.py index a07a9e271907..e22333609ab7 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -499,7 +499,7 @@ def testReduceScatterReplicaGroupsTiled(self): def 
testTrees(self): ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0) def protate(x, axis_name): - n = lax.psum(1, axis_name) + n = lax.axis_size(axis_name) return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)]) tree_f = lambda f: partial(jax.tree.map, f) @@ -1395,7 +1395,7 @@ def testNestedPmapConstantError(self): def testCollectiveConstant(self): device_count = jax.device_count() - f = self.pmap(lambda x: lax.psum(1, 'i'), 'i') + f = self.pmap(lambda x: lax.axis_size('i'), 'i') x = jnp.arange(device_count) ans = f(x) expected = np.repeat(device_count, device_count) @@ -1408,9 +1408,9 @@ def testCollectiveConstantNested(self): def f(x): @partial(self.pmap, axis_name='j') def g(y): - a = lax.psum(1, 'i') - b = lax.psum(1, 'j') - c = lax.psum(1, ('i', 'j')) + a = lax.axis_size('i') + b = lax.axis_size('j') + c = lax.axis_size(('i', 'j')) return a, b, c return g(x) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 1dd59daee873..ae7f532c1526 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -218,7 +218,7 @@ def test_collective_permute(self): shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) @@ -240,8 +240,8 @@ def test_collective_permute_with_multiple_axis_names(self): out_specs=P('x', ('y', 'z')), ) def fwd(a): - xy_axis_size = lax.psum(1, ('x', 'y')) - yz_axis_size = lax.psum(1, ('y', 'z')) + xy_axis_size = lax.axis_size(('x', 'y')) + yz_axis_size = lax.axis_size(('y', 'z')) xy_perm = [(j, (j + 1) % xy_axis_size) for j in range(xy_axis_size)] yz_perm = [(j, (j + 1) % yz_axis_size) for j in range(yz_axis_size)] return ( @@ -3383,7 +3383,7 @@ def test_shardy_collective_permute(self): shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) From 519f490916e255b50da38c65b9ff4d15b2a0f753 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 17 Apr 2025 11:34:21 +0200 Subject: [PATCH 0666/1769] Suppress race between split_keys_entry_added and dict_dict_merge in 3.13 TSAN CI --- .github/workflows/tsan-suppressions_3.13.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt index dac134bf5169..f4fbf830ddc2 100644 --- a/.github/workflows/tsan-suppressions_3.13.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -67,6 +67,10 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130547 # race:split_keys_entry_added +# https://github.com/python/cpython/issues/132245 +race:split_keys_entry_added +race_top:dict_dict_merge + # https://github.com/python/cpython/issues/129547 # Maybe fixed? # race:type_get_annotations From 4ceb4b0526cd68c0f78d6bab325008d437a0f0d6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 17 Apr 2025 04:36:29 -0700 Subject: [PATCH 0667/1769] Do not use `-> ...` It is a non-standard pytype feature which is not supported by any other type checker. 
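
For illustration, a minimal self-contained sketch of the two spellings (the
name `Client` stands in for the concrete types used in this change, i.e.
`Client` and `ir.OpResultList`; the helper names are hypothetical):

    class Client: ...

    # pytype-only: `...` here means "unspecified return type".
    def make_client_loose() -> ...:
        return Client()

    # Portable: name the return type explicitly, as done below.
    def make_client_typed() -> Client:
        return Client()
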
PiperOrigin-RevId: 748636378 --- jax/_src/tpu_custom_call.py | 2 +- jaxlib/xla/xla_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 32236bb6ae90..d7a921f1d952 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -252,7 +252,7 @@ def _tpu_custom_call_lowering( kernel_name: str | None, out_avals: Any, input_output_aliases: tuple[tuple[int, int], ...], -) -> ...: +) -> ir.OpResultList: result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals] axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 543664682c08..69f24ee13f6f 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -78,7 +78,7 @@ def make_cpu_client( num_nodes=1, collectives=None, num_devices=None, -) -> ...: +) -> Client: register_custom_call_handler('cpu', _xla.register_custom_call_target) register_custom_type_id_handler('cpu', _xla.register_custom_type_id) return _xla.get_tfrt_cpu_client( From f2946ddceb4fd1415b844741ff302c88373ebcbf Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 17 Apr 2025 04:52:03 -0700 Subject: [PATCH 0668/1769] Automated Code Change PiperOrigin-RevId: 748639826 --- jaxlib/xla/BUILD | 1 + jaxlib/xla/mlir.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index a6a4cf660408..8a19adc854ef 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -397,6 +397,7 @@ cc_library( "@xla//xla/pjrt:mlir_to_hlo", "@xla//xla/pjrt:status_casters", "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/service:hlo_proto_cc", "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/platform:logging", "@xla//xla/tsl/platform:statusor", diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc index 29ef86d50df6..76663a79556a 100644 --- a/jaxlib/xla/mlir.cc +++ b/jaxlib/xla/mlir.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/status_casters.h" #include "xla/python/refine_polymorphic_shapes.h" +#include "xla/service/hlo.pb.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" From 0149a32c67af207ea96f5dd7f9d0b49a2b23fa08 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 17 Apr 2025 06:52:22 -0700 Subject: [PATCH 0669/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a5c81f946fd2717c13075221b692db59b513a6eb. PiperOrigin-RevId: 748666210 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 0ac09a4a6594..846ec8d56d97 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
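 # For example, using the revision taken in this change (illustrative only):
 #   curl -L https://github.com/openxla/xla/archive/a5c81f946fd2717c13075221b692db59b513a6eb.tar.gz | sha256sum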
-XLA_COMMIT = "0d1b60216ea13b0d261d59552a0f7ef20c4f76c5" -XLA_SHA256 = "357b37cc7c439580344ce0305bad88ef841f29743a99ea8e2253e64a32e139c6" +XLA_COMMIT = "a5c81f946fd2717c13075221b692db59b513a6eb" +XLA_SHA256 = "ae1f5afc050d3d6add7c8b03e8f82227c0e9931382889347e4d22ca38af1026c" def repo(): tf_http_archive( From 21f9d236581277e0fc25ec96f1a08f390efa855d Mon Sep 17 00:00:00 2001 From: Scott Staniewicz Date: Wed, 2 Oct 2024 11:13:00 -0400 Subject: [PATCH 0670/1769] BUG: return one covariance matrix per `rhs` in `polyfit` Addresses #24073 Taking the first residual from `resids` means that only the first set of coefficients would get a covariance matrix. This moves the line to make a `(1,)` shape array into an `int` to the branch where there is only one `rhs`. Fixes typos in `polyfit` docstring fix int/ shape (1,) array test failures add failing test with multiple `rhs` Make `weights` of shape `(M,)` in `polyfit` test Co-authored-by: Dan Foreman-Mackey Fix rank promotion problems for `polyfit` Move `fac` expansion out to any size `y_arr` Fix last `fac` broadcasting test case --- jax/_src/numpy/polynomial.py | 37 ++++++++++++++++++++++-------------- tests/lax_numpy_test.py | 6 ++++-- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 81d320cb7403..2b2923ba93ce 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -146,7 +146,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, rcond: Relative condition number of the fit. Default value is ``len(x) * eps``. It must be specified statically. full: Switch that controls the return value. Default is ``False`` which - restricts the return value to the array of polynomail coefficients ``p``. + restricts the return value to the array of polynomial coefficients ``p``. If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``. It must be specified statically. w: Array of weights of shape ``(M,)``. If None, all data points are considered @@ -154,8 +154,8 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where :math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None. cov: Boolean or string. If ``True``, returns the covariance matrix scaled - by ``resids/(M-deg-1)`` along with ploynomial coefficients. If - ``cov='unscaled'``, returns the unscaaled version of covariance matrix. + by ``resids/(M-deg-1)`` along with polynomial coefficients. If + ``cov='unscaled'``, returns the unscaled version of covariance matrix. Default is ``False``. ``cov`` is ignored if ``full=True``. It must be specified statically. 
@@ -224,7 +224,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, >>> p, C = jnp.polyfit(x, y, 2, cov=True) >>> p.shape, C.shape - ((3, 3), (3, 3, 1)) + ((3, 3), (3, 3, 3)) """ if w is None: x_arr, y_arr = ensure_arraylike("polyfit", x, y) @@ -233,7 +233,6 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, del x, y deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 - # check arguments if deg < 0: raise ValueError("expected deg >= 0") if x_arr.ndim != 1: @@ -245,7 +244,6 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, if x_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected x and y to have same length") - # set rcond if rcond is None: rcond = len(x_arr) * float(finfo(x_arr.dtype).eps) rcond = core.concrete_or_error(float, rcond, "rcond must be float") @@ -268,9 +266,17 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, # scale lhs to improve condition number and solve scale = sqrt((lhs*lhs).sum(axis=0)) - lhs /= scale[np.newaxis,:] + lhs /= scale[np.newaxis, :] c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond) - c = (c.T/scale).T # broadcast scale coefficients + + # Broadcasting scale coefficients + if c.ndim > 1: + # For multi-dimensional output, make scale (1, order) to divide + # across the c.T of shape (num_rhs, order) + c = (c.T / scale[np.newaxis, :]).T + else: + # Simple case for 1D output + c = c / scale if full: assert rcond is not None @@ -278,22 +284,25 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, elif cov: Vbase = linalg.inv(dot(lhs.T, lhs)) Vbase /= outer(scale, scale) + if cov == "unscaled": - fac = 1 + fac = array(1.0) else: if len(x_arr) <= order: - raise ValueError("the number of data points must exceed order " - "to scale the covariance matrix") + raise ValueError("the number of data points must exceed order" + " to scale the covariance matrix") fac = resids / (len(x_arr) - order) - fac = fac[0] #making np.array() of shape (1,) to int + if y_arr.ndim == 1: + fac = atleast_1d(fac)[np.newaxis] + # For 1D output, simple scalar multiplication return c, Vbase * fac else: - return c, Vbase[:, :, np.newaxis] * fac + # For multiple rhs, broadcast fac to match shape + return c, Vbase[:, :, np.newaxis] * atleast_1d(fac)[np.newaxis, np.newaxis, :] else: return c - @export @jit def poly(seq_of_zeros: ArrayLike) -> Array: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index de943c3b613a..9910d1eedc68 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -968,6 +968,7 @@ def np_fun(lhs, rhs): @jtu.sample_product( dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]], shape=[shape for shape in one_dim_array_shapes if shape != (1,)], + num_rhs=[1, 5], deg=[1, 2, 3], rcond=[None, -1, 10e-3, 10e-5, 10e-10], full=[False, True], @@ -975,12 +976,13 @@ def np_fun(lhs, rhs): cov=[False, True, "unscaled"], ) @jax.default_matmul_precision("float32") - def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): + def testPolyfit(self, shape, num_rhs, dtype, deg, rcond, full, w, cov): rng = jtu.rand_default(self.rng()) tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} tol = jtu.tolerance(dtype, tol_spec) _w = lambda a: abs(a) if w else None - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + rhs_shape = shape + (num_rhs,) if num_rhs > 1 else shape + args_maker = lambda: [rng(shape, dtype), rng(rhs_shape, 
dtype), rng(shape, dtype)] jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) np_fun = jtu.ignore_warning( message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) From 7d5c26910809215587de5ac925f8cb5caa07e41f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 17 Apr 2025 08:31:52 -0700 Subject: [PATCH 0671/1769] Remove shard_map rewrite machinery since we have `vma` in types by default PiperOrigin-RevId: 748690834 --- jax/experimental/roofline/rooflines.py | 2 + jax/experimental/shard_map.py | 697 +------------------------ 2 files changed, 15 insertions(+), 684 deletions(-) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 1a84095a0e31..9db3bd4e289c 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -109,6 +109,8 @@ def _unary_p_roofline( roofline.register_roofline(special.erfc_p)(_unary_p_roofline) roofline.register_roofline(special.lgamma_p)(_unary_p_roofline) +roofline.register_standard_roofline(core.pvary_p) + def _binary_p_roofline( ctx: roofline.RooflineRuleContext, *args, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 51e6acae8aee..713655897789 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -17,7 +17,6 @@ import enum from functools import partial import inspect -import itertools as it from math import prod import operator as op from typing import Any, TypeVar, Union @@ -27,39 +26,28 @@ import jax import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec -from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api_util -from jax._src import callback from jax._src import config from jax._src import core -from jax._src import custom_derivatives as cd from jax._src import debugging from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu -from jax._src import ops -from jax._src import pjit -from jax._src import prng -from jax._src import random from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util from jax._src import util -from jax._src.core import pvary, pvary_p +from jax._src.core import pvary from jax._src.core import Tracer, typeof from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, get_abstract_mesh) from jax._src.api import _shared_code_pmap, _prepare_pmap -from jax._src.lax import (lax, parallel as lax_parallel, slicing, - windowed_reductions, convolution, fft, linalg, - special, control_flow, ann) -from jax._src import ffi from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2, foreach) + merge_lists, split_list, subs_list2) from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -451,6 +439,12 @@ def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] # Primitive +@lu.transformation2 +def _implicit_pvary_on_output(f, out_names_thunk, *args, **kwargs): + out_flat = f(*args, **kwargs) + return [pvary(o, tuple(_names_to_vma(n) - typeof(o).vma)) + for o, n in zip(out_flat, out_names_thunk())] + JaxType = Any MaybeTracer = Union[JaxType, Tracer] @@ -646,31 +640,6 
@@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: return set(mesh.axis_names) - {n for ns in names.values() for n in ns} -def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType] - ) -> Sequence[RepType]: - env: dict[core.Var, RepType] = {} - - def read(x: core.Atom) -> RepType: - return env[x] if type(x) is core.Var else None - - def write(v: core.Var, val: RepType) -> None: - env[v] = val - - foreach(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars)) - foreach(write, jaxpr.invars, in_rep) - last_used = core.last_used(jaxpr) - for e in jaxpr.eqns: - rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive)) - out_rep = rule(mesh, *map(read, e.invars), **e.params) - if e.primitive.multiple_results: - out_rep = (out_rep if isinstance(out_rep, (list, tuple)) else - [out_rep] * len(e.outvars)) - foreach(write, e.outvars, out_rep) - else: - write(e.outvars[0], out_rep) - core.clean_up_dead_vars(e, env, last_used) - return map(read, jaxpr.outvars) - def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) @@ -682,7 +651,6 @@ def _rule_missing(prim: core.Primitive, *_, **__): # Lowering - def _shardy_shard_map_sharding( ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in ) -> sharding_impls.SdyArraySharding: @@ -929,11 +897,6 @@ def _check_reps(mesh, auto, names, reps): class _RepError(Exception): pass -def _check_reps2(mesh, reps_dest, reps): - fail = [src if not dst.issubset(src) else no_fail - for dst, src in zip(reps_dest, reps)] - if any(f is not no_fail for f in fail): raise _RepError(fail) - def _match_spec(mesh: Mesh, check_rep, src_pspec: PartitionSpec, dst_pspec: PartitionSpec, x: JaxType) -> JaxType: fn = HashablePartial(_match, mesh, check_rep, src_pspec, dst_pspec) @@ -978,10 +941,7 @@ def to_val_rep_pair(self, val): raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: val_ = _unmatch_spec(self.mesh, self.check, {}, val, self.context_mesh) - if self.check: - return val_, frozenset(self.mesh.axis_names) - self.auto - else: - return val_, None + return val_, frozenset(self.mesh.axis_names) - self.auto def process_primitive(self, prim, tracers, params): in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) @@ -1127,428 +1087,6 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): return xs eager_rules[dispatch.device_put_p] = _device_put_eager_rule -# Rewrite rules and static replication checking for efficient transposition - -_rewrite_rules: dict[core.Primitive, Callable] = {} -register_rewrite = lambda prim: lambda r: _rewrite_rules.setdefault(prim, r) -register_standard_rewrite = lambda prim: \ - _rewrite_rules.setdefault(prim, partial(_standard_rewrite_rule, prim)) -register_norewrite = lambda p: \ - _rewrite_rules.setdefault(p, partial(_no_rewrite, p, _check_rules[p])) - -_check_rules: dict[core.Primitive, Callable] = {} -register_check = lambda prim: lambda rule: _check_rules.setdefault(prim, rule) -register_standard_check = \ - lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim)) - -def _eq_rep(mesh, r1, r2) -> bool: - if r1 != r2 and r1 is None or r2 is None: - r1, r2 = _remove_none_rep(mesh, r1), _remove_none_rep(mesh, r2) - return r1 == r2 - -def _remove_none_rep(mesh, r): - return set(mesh.axis_names) if r is None else r - -def _no_rewrite(prim, rule, mesh, in_rep, 
*args, **params): - out_vals = prim.bind(*args,**params) - out_rep = rule(mesh, *in_rep, **params) - if prim.multiple_results: - out_rep_ = out_rep if type(out_rep) is list else [out_rep] * len(out_vals) - else: - out_vals, out_rep_ = [out_vals], [out_rep] - return out_vals, out_rep_ - -def _standard_rewrite_rule(prim, mesh, in_rep, *args, **params): - # The standard rewrite inserts pbroadcasts but doesn't change the primitive. - out_rep_ = set.intersection(*in_rep) if in_rep else set(mesh.axis_names) - args_ = [pvary(x, tuple(n for n in src if n not in out_rep_)) - if src - out_rep_ else x for x, src in zip(args, in_rep)] - out_vals_ = prim.bind(*args_, **params) - out_rep = [out_rep_] * len(out_vals_) if prim.multiple_results else [out_rep_] - out_vals = [out_vals_] if not prim.multiple_results else out_vals_ - return out_vals, out_rep - -def _standard_check(prim, mesh, *in_rep, **__): - # The standard check require args' and outputs' replications to be the same, - # except for Nones which correspond to constants. - in_rep_ = [r for r in in_rep if r is not None] - if in_rep_ and in_rep_[:-1] != in_rep_[1:]: - raise Exception(f"Primitive {prim} requires argument replication types " - f"to match, but got {in_rep}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return in_rep_[0] if in_rep_ else None - -def register_standard_collective(prim): - register_check(prim)(partial(_standard_collective_check, prim)) - register_rewrite(prim)(partial(_standard_collective_rewrite, prim)) - -def register_reduction_collective(prim): - register_check(prim)(partial(_reduction_collective_check, prim)) - register_rewrite(prim)(partial(_reduction_collective_rewrite, prim)) - -def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): - # The standard collective check is varying -> varying over axis_name. - del mesh, params - if x_rep is None or axis_name in x_rep: - raise Exception(f"Collective {prim} must be applied to a device-varying " - f"replication type, but got {x_rep} for collective acting " - f"over axis name {axis_name}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return x_rep - -def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): - # The standard collective rewrite may insert a pbroadcast on the input. - axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name - x_rep, = in_rep - axis_name_set = set(axis_name) - if pbroadcast_axis_name := axis_name_set & x_rep: - x = pvary(x, tuple(pbroadcast_axis_name)) - out_val = prim.bind(x, axis_name=axis_name, **params) - return [out_val], [x_rep - axis_name_set] - -def _reduction_collective_check(prim, mesh, x_rep, *, axes, **params): - # The reduction collective check is varying -> replicated over axes. - del mesh, params - axes = (axes,) if not isinstance(axes, tuple) else axes - if x_rep is None or any(a in x_rep for a in axes): - raise Exception(f"Collective {prim} must be applied to a device-varying " - f"replication type, but got {x_rep} for collective acting " - f"over axis name {axes}. 
Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return x_rep | set(axes) - -def _reduction_collective_rewrite(prim, mesh, in_rep, x, axes, **params): - # The standard collective rewrite may insert a pbroadcast on the input. - axes = (axes,) if not isinstance(axes, tuple) else axes - x_rep, = in_rep - axes_set = set(axes) - if pbroadcast_axes := axes_set & x_rep: - x = pvary(x, tuple(pbroadcast_axes)) - out_val, = prim.bind(x, axes=axes, **params) - return [out_val], [x_rep | axes_set] - - -for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(), - windowed_reductions.__dict__.values(), - special.__dict__.values(), convolution.__dict__.values(), - fft.__dict__.values(), linalg.__dict__.values(), - ops.__dict__.values(), ad_util.__dict__.values(), - prng.__dict__.values(), ann.__dict__.values(), - random.__dict__.values()): - if isinstance(o, core.Primitive): - register_standard_check(o) - register_standard_rewrite(o) - -for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p, - control_flow.loops.cumprod_p, control_flow.loops.cummax_p, - control_flow.loops.cummin_p, pjit.sharding_constraint_p, - pjit.mesh_cast_p]: - register_standard_check(p) - register_standard_rewrite(p) - - -@register_check(lax_parallel.psum_p) -def _psum_check(_, *in_rep, axes, axis_index_groups): - assert False # should be rewritten away - -@register_rewrite(lax_parallel.psum_p) -def _psum_rewrite(mesh, in_rep, *args, axes, axis_index_groups): - # Replace the psum with psum2, insert pbroadcasts on input, replicated output. - if axis_index_groups is not None: raise NotImplementedError - axes = (axes,) if not isinstance(axes, tuple) else axes - axes_ = set(axes) - out_rep = [r | axes_ for r in in_rep] # TODO determinism (and elsewhere) - args_ = [pvary(x, tuple(n for n in mesh.axis_names if n in axes_ & src)) - for x, src in zip(args, in_rep)] - out_val = lax_parallel.psum_invariant_p.bind( - *args_, axes=axes, axis_index_groups=axis_index_groups) - return out_val, out_rep - - -@register_check(lax_parallel.psum_invariant_p) -def _psum2_check(mesh, *in_rep, axes, axis_index_groups): - assert type(axes) is tuple - if any(set(axes) & r for r in in_rep if r is not None): - raise Exception("Collective psum must be applied to a device-varying " - f"replication type, but got {in_rep} for collective acting " - f"over axis name {axes}. Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) - return [r | set(axes) for r in in_rep] -register_norewrite(lax_parallel.psum_invariant_p) - - -@register_check(pvary_p) -def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): - assert type(axes) is tuple - if not all(r is None or set(axes) & r for r in in_rep): - raise Exception("Collective pbroadcast must be applied to a " - "non-device-varying " - f"replication type, but got {in_rep} for collective acting " - f"over axis name {axes}. 
Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) - return [r - set(axes) for r in in_rep] -register_norewrite(pvary_p) - - -register_standard_collective(lax_parallel.all_gather_p) -register_standard_collective(lax_parallel.all_to_all_p) -register_standard_collective(lax_parallel.ppermute_p) -register_standard_collective(lax_parallel.reduce_scatter_p) -register_reduction_collective(lax_parallel.pmin_p) -register_reduction_collective(lax_parallel.pmax_p) - - -@register_check(lax_parallel.axis_index_p) -def _axis_index_check(mesh, *, axis_name): - axis_name = (axis_name,) if not type(axis_name) is tuple else axis_name - return set(mesh.shape) - set(axis_name) -register_norewrite(lax_parallel.axis_index_p) - - -@register_rewrite(pjit.pjit_p) -def _pjit_rewrite(mesh, in_rep, *args, jaxpr, **kwargs): - jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep) - out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs) - return out_vals, out_rep - -@register_check(pjit.pjit_p) -def _pjit_check(mesh, *in_rep, jaxpr, **kwargs): - return _check_rep(mesh, jaxpr.jaxpr, in_rep) - - -@register_rewrite(ad_checkpoint.remat_p) -def _remat_rewrite(mesh, in_rep, *args, jaxpr, **kwargs): - jaxpr_ = pe.close_jaxpr(jaxpr) - jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr_, in_rep) - jaxpr, () = jaxpr_.jaxpr, jaxpr_.consts - out_vals = ad_checkpoint.remat_p.bind(*args, jaxpr=jaxpr, **kwargs) - return out_vals, out_rep - -@register_check(ad_checkpoint.remat_p) -def _remat_check(mesh, *in_rep, jaxpr, **kwargs): - return _check_rep(mesh, jaxpr, in_rep) - - -@register_check(core.call_p) -def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs): - return _check_rep(mesh, call_jaxpr, in_rep) - - -@register_check(debugging.debug_callback_p) -def _debug_callback_rule(mesh, *in_rep, **_): - return [] -register_norewrite(debugging.debug_callback_p) - - -@register_check(callback.pure_callback_p) -def _pure_callback_rule(mesh, *_, result_avals, **__): - return [set()] * len(result_avals) -register_norewrite(callback.pure_callback_p) - - -@register_check(callback.io_callback_p) -def _io_callback_rule(mesh, *_, result_avals, **__): - return [set()] * len(result_avals) -register_norewrite(callback.io_callback_p) - - -@register_check(dispatch.device_put_p) -def _device_put_rule(mesh, *xs, **_): - return list(xs) -register_norewrite(dispatch.device_put_p) - - -@register_check(ad.custom_lin_p) -def _custom_lin_rule(mesh, *_, out_avals, **__): - return [set()] * len(out_avals) -register_norewrite(ad.custom_lin_p) - - -@register_check(control_flow.loops.scan_p) -def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): - _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry]) - out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep) - carry_rep_out, _ = split_list(out_rep, [num_carry]) - if not all(map(partial(_eq_rep, mesh), carry_rep_in, carry_rep_out)): - raise Exception("Scan carry input and output got mismatched replication " - f"types {carry_rep_in} and {carry_rep_out}. 
Please open an " - "issue at https://github.com/jax-ml/jax/issues, and as a " - "temporary workaround pass the check_rep=False argument to " - "shard_map") - return out_rep - -@register_rewrite(control_flow.loops.scan_p) -def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): - const_rep, carry_rep_in, xs_rep = split_list(in_rep, [num_consts, num_carry]) - for _ in range(1 + num_carry): - in_rep_ = [*const_rep, *carry_rep_in, *xs_rep] - _, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep_) - carry_rep_out, ys_rep = split_list(out_rep, [num_carry]) - carry_rep_out = map(op.and_, carry_rep_in, carry_rep_out) - if carry_rep_in == carry_rep_out: - break - else: - carry_rep_in = carry_rep_out - else: - assert False, 'Fixpoint not reached' - - args = [pvary(x, tuple(n for n in src if n not in dst)) - if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] - out_rep = [*carry_rep_out, *ys_rep] - jaxpr_ = _replication_rewrite_match(mesh, jaxpr, in_rep_, out_rep) - - out_vals = control_flow.loops.scan_p.bind( - *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) - return out_vals, out_rep - -@register_check(control_flow.loops.while_p) -def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_): - _, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts]) - carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in]) - if tuple(carry_rep_in) != tuple(carry_rep_out): - raise Exception("while_loop carry input and output got mismatched " - f"replication types {carry_rep_in} and {carry_rep_out}. " - "Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return carry_rep_out - -@register_rewrite(control_flow.loops.while_p) -def _while_rewrite(mesh, in_rep, *args, cond_jaxpr, body_jaxpr, cond_nconsts, - body_nconsts): - # while while isn't transposable, we insert pbroadcasts for consistent carry - cconst_rep, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts]) - num_carry = len(args) - cond_nconsts - body_nconsts - for _ in range(1 + num_carry): - in_rep_ = [*bconst_rep, *carry_rep_in] - _, carry_rep_out = _replication_rewrite_nomatch(mesh, body_jaxpr, in_rep_) - if tuple(carry_rep_in) == tuple(carry_rep_out): - break - carry_rep_in = map(op.and_, carry_rep_in, carry_rep_out) - else: - assert False, "Fixpoint not reached" - - cond_jaxpr_, _ = _replication_rewrite_nomatch( - mesh, cond_jaxpr, (*cconst_rep, *carry_rep_in)) - body_jaxpr_ = _replication_rewrite_match( - mesh, body_jaxpr, (*bconst_rep, *carry_rep_in), carry_rep_out) - args_ = [pvary(x, tuple(n for n in src if n not in dst)) - if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] - out_vals = control_flow.loops.while_p.bind( - *args_, cond_jaxpr=cond_jaxpr_, body_jaxpr=body_jaxpr_, - cond_nconsts=cond_nconsts, body_nconsts=body_nconsts) - return out_vals, carry_rep_out - -@register_check(control_flow.conditionals.cond_p) -def _cond_rule(mesh, *in_rep, branches): - _, *args_rep = in_rep - out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) - for branch in branches[1:]: - out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) - if not all(map(partial(_eq_rep, mesh), out_rep, out_rep_)): - raise Exception("The branches of cond produced mismatched replication " - "types. 
Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a " - "temporary workaround pass the check_rep=False argument " - "to shard_map") - return out_rep - -@register_rewrite(control_flow.conditionals.cond_p) -def _cond_rewrite(mesh, in_rep, *args, branches): - pred_rep, *args_rep = in_rep - _, out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) - for branch in branches[1:]: - _, out_rep_ = _replication_rewrite_nomatch(mesh, branch, args_rep) - if out_rep: - out_rep = map(op.and_, out_rep, out_rep_) - else: - out_rep = out_rep_ - out_rep = map(partial(op.and_, pred_rep), out_rep) - branches_ = tuple(_replication_rewrite_match(mesh, branch, args_rep, out_rep) - for branch in branches) - out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) - return out_vals, out_rep - -@register_check(control_flow.conditionals.platform_index_p) -def _platform_index_rule(mesh, *_, **__): - return set(mesh.axis_names) -register_norewrite(control_flow.conditionals.platform_index_p) - -@register_rewrite(core.closed_call_p) -def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): - new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep) - out_vals = core.closed_call_p.bind(*args, jaxpr=new_jaxpr, **kwargs) - return out_vals, out_rep - -@register_check(core.closed_call_p) -def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs): - return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) - - -@register_check(cd.custom_jvp_call_p) -def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_fun, - num_consts, symbolic_zeros): - return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) - -@register_rewrite(cd.custom_vjp_call_jaxpr_p) -def _custom_vjp_call_jaxpr_rewrite( - mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, - symbolic_zeros): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and as" - " a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - - fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep) - _, in_rep_ = split_list(in_rep, [num_consts]) - out_rep2 = [] - - @pe._memoize - def fwd_jaxpr_thunk_(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) - fwd_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fwd_jaxpr, in_rep_) - out_rep2.append(out_rep) - return fwd_jaxpr_.jaxpr, fwd_jaxpr_.consts - - bwd_ = _rewrite_bwd(bwd, mesh, lambda: out_rep2[0], in_rep_) - - outs = cd.custom_vjp_call_jaxpr_p.bind( - *args, fun_jaxpr=fun_jaxpr_, fwd_jaxpr_thunk=fwd_jaxpr_thunk_, bwd=bwd_, - num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - out_rep = out_rep2[0] if out_rep2 else out_rep - return outs, out_rep - -@register_check(cd.custom_vjp_call_jaxpr_p) -def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): - return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) - -@register_check(control_flow.solves.linear_solve_p) -def _linear_solve_check(mesh, *in_rep, jaxprs, **_): - out_rep = _standard_check(control_flow.solves.linear_solve_p, mesh, *in_rep) - return [out_rep] * len(jaxprs.solve.out_avals) -register_standard_rewrite(control_flow.solves.linear_solve_p) - -@register_check(ffi.ffi_call_p) -def _ffi_call_check(mesh, *in_rep, result_avals, **_): - out_rep = _standard_check(ffi.ffi_call_p, mesh, *in_rep) - return [out_rep] * len(result_avals) -register_standard_rewrite(ffi.ffi_call_p) - -del _check_rules[lax.tie_p] - 
-@register_check(lax.tie_p) -def _tie_check(mesh, x_rep, y_rep): - return x_rep -register_norewrite(lax.tie_p) - # Batching @@ -1643,7 +1181,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, f = _promote_scalar_residuals(f) f_known, aux = pe.partial_eval_wrapper_nounits2( f, (*in_knowns,), (*in_avals_sharded,)) - all_names = _all_newly_manual_mesh_names(mesh, auto, trace) + all_names = _all_newly_manual_mesh_names(mesh, auto) @as_hashable_function(closure=out_names_thunk) def known_out_names(): @@ -1709,7 +1247,7 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - all_names = _all_newly_manual_mesh_names(mesh, auto, trace) + all_names = _all_newly_manual_mesh_names(mesh, auto) @as_hashable_function(closure=linearize_outs_thunk) def fwd_out_names_thunk(): @@ -2006,16 +1544,14 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd( - mesh: Mesh, auto: frozenset[AxisName], trace=None -) -> tuple[AxisName, ...]: + mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: axis_env = core.get_axis_env() spmd_names = axis_env.spmd_axis_names return tuple(name for name in mesh.axis_names if name not in spmd_names and name not in auto) def _all_newly_manual_mesh_names( - mesh: Mesh, auto: frozenset[AxisName], trace=None -) -> tuple[AxisName, ...]: + mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: axis_env = core.get_axis_env() vmap_spmd_names = set(axis_env.spmd_axis_names) if not (ctx_mesh := get_abstract_mesh()).empty: @@ -2128,210 +1664,3 @@ def _get_devices(p, backend): if jax.process_count() > 1: return devs[:p.global_axis_size] return devs[:p.local_axis_size] - -@lu.transformation2 -def _implicit_pvary_on_output(f, out_names_thunk, *args, **kwargs): - out_flat = f(*args, **kwargs) - return [pvary(o, tuple(_names_to_vma(n) - typeof(o).vma)) - for o, n in zip(out_flat, out_names_thunk())] - -### Rewrite! 
- -Val = Any - -class RewriteTracer(core.Tracer): - rep: set[AxisName] - val: Val - - def __init__(self, trace, rep, val): - self._trace = trace - self.rep = rep - self.val = val - - @property - def aval(self) -> core.AbstractValue: - return core.get_aval(self.val) - - def to_concrete_value(self): - return core.to_concrete_value(self.val) - - def __str__(self) -> str: - return str(self.val) # TODO(mattjj): could show replication info here - __repr__ = __str__ # for debuggers, like `p x` - -class RewriteTrace(core.Trace): - __slots__ = ("parent_trace", "tag", "mesh") - - parent_trace : core.Trace - tag : core.TraceTag - mesh: Mesh - - def __init__(self, parent_trace, tag, mesh): - super().__init__() - self.parent_trace = parent_trace - self.tag = tag - self.mesh = mesh - - def to_val_rep_pair(self, val): - # TODO: add a tag to tell if self - if isinstance(val, RewriteTracer) and val._trace.tag is self.tag: - return val.val, val.rep - else: - return val, set(self.mesh.axis_names) - - def process_primitive(self, prim, in_tracers, params): - rule = _rewrite_rules.get(prim, partial(_rule_missing, prim)) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) - with core.set_current_trace(self.parent_trace): - out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params) - out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals) - return out_tracers if prim.multiple_results else out_tracers[0] - - def process_call(self, call_primitive, f, in_tracers, params): - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) - f, out_reps = _rewrite_subtrace(f, self.tag, self.mesh, tuple(in_reps)) - with core.set_current_trace(self.parent_trace): - out_vals = call_primitive.bind(f, *in_vals, **params) - return map(partial(RewriteTracer, self), out_reps(), out_vals) - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) - fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) - jvp, out_reps2 = _rewrite_jvp_subtrace(jvp, self.tag, self.mesh, in_reps * 2) - with core.set_current_trace(self.parent_trace): - out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) - fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) - return map(partial(RewriteTracer, self), out_reps, out_vals) - - def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, - fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], - symbolic_zeros: bool): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " - "as a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) - fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) - fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]] - fwd, out_reps2 = _rewrite_subtrace(fwd, self.tag, self.mesh, fwd_in_reps) - bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps) - with core.set_current_trace(self.parent_trace): - out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) - if not fst: - _, res_tree = out_trees() - _, out_reps = split_list(out_reps, [res_tree.num_leaves]) - return map(partial(RewriteTracer, self), out_reps, out_vals) - -def _efficient_transpose_rewrite(fun, mesh, 
in_names, out_names_thunk): - in_reps = map(partial(_in_names_to_rep, mesh), in_names) - out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()] - fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) - return _match_rep(fun, mesh, out_reps_src, out_reps_dst) - -@lu.transformation_with_aux2 -def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): - with core.take_current_trace() as parent: - tag = core.TraceTag() - t = RewriteTrace(parent_trace=parent, tag=tag, mesh=mesh) - in_tracers = map(partial(RewriteTracer, t), in_reps, args) - with core.set_current_trace(t): - ans = f(*in_tracers) - out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) - del t, in_tracers, ans - store.store(out_reps) - return out_vals - -@lu.transformation2 -def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): - outs = f(*args) - out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ - out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ - _check_reps2(mesh, out_reps_dst, out_reps_src) - outs = [pvary(x, tuple(n for n in src if n not in dst)) if src - dst - else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)] - return outs - -# TODO(mattjj): caching -def _replication_rewrite_match( - mesh: Mesh, - jaxpr: core.ClosedJaxpr, - in_rep: Sequence[set[AxisName]], - out_rep_dst: Sequence[set[AxisName]], -) -> core.ClosedJaxpr: - f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), - debug_info=jaxpr.jaxpr.debug_info) - f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - f = _match_rep(f, mesh, out_rep, out_rep_dst) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) - return core.ClosedJaxpr(jaxpr_, consts) - -# TODO(mattjj): caching -def _replication_rewrite_nomatch( - mesh: Mesh, - jaxpr: core.ClosedJaxpr, - in_rep: Sequence[set[AxisName]], -) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: - f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), - debug_info=jaxpr.jaxpr.debug_info) - f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) - return core.ClosedJaxpr(jaxpr_, consts), out_rep() - -@lu.transformation_with_aux2 -def _rewrite_subtrace(f: Callable, store: lu.Store, tag: core.TraceTag, - mesh: Mesh, in_reps, *in_vals): - with core.take_current_trace() as parent_trace: - assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) - t = RewriteTrace(parent_trace, tag, mesh) - in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) - with core.set_current_trace(t): - outs = f(*in_tracers) - out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs)) - store.store(out_reps) - return out_vals - -@lu.transformation_with_aux2 -def _rewrite_jvp_subtrace(f: Callable, store: lu.Store, tag: core.TraceTag, - mesh: Mesh, in_reps, *in_vals): - with core.take_current_trace() as parent_trace: - assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) - t = RewriteTrace(parent_trace, tag, mesh) - in_tracers = [x if type(x) is cd.SymbolicZero else RewriteTracer(t, r, x) - for r, x in zip(in_reps, in_vals)] - with core.set_current_trace(t): - out_tracers: list[RewriteTracer | cd.SymbolicZero] = f(*in_tracers) - out_vals, out_reps = unzip2(map(t.to_val_rep_pair, out_tracers)) - out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) - out_primal_reps, out_tangent_reps = split_list(out_reps, [len(out_vals) // 2]) - 
out_reps = map(_merge_reps, out_primal_reps, out_tangent_reps, out_tangents) - out_tangents = map(_match_replication, out_tangent_reps, out_reps, out_tangents) - store.store(out_reps) - return out_primals + out_tangents - -def _merge_reps(primal_rep, tangent_rep, error_message_val): - if primal_rep - tangent_rep: - raise ValueError("custom_jvp primal output is more replicated than its " - "corresponding tangent of type " - f"{core.typeof(error_message_val).str_short()}") - return primal_rep - -def _rewrite_bwd(bwd: lu.WrappedFun, - mesh: Mesh, in_reps, reps_dst) -> lu.WrappedFun: - def new_bwd(*args): - tag = core.TraceTag() - bwd_, reps_thunk = _rewrite_subtrace(bwd, tag, mesh, in_reps()) - out = bwd_.call_wrapped(*args) - return map(_match_replication, reps_thunk(), reps_dst, out) - return lu.wrap_init(new_bwd, debug_info=bwd.debug_info) - -def _match_replication(src, dst, x): - if dst - src: - x, = lax_parallel.psum_invariant_p.bind( - x, axes=tuple(n for n in dst if n not in src), axis_index_groups=None) - if src - dst: - x = pvary(x, tuple(n for n in src if n not in dst)) - return x From 85fdd1aae67b4f62df8e03ee6e56c303de727476 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 17 Apr 2025 08:53:53 -0700 Subject: [PATCH 0672/1769] Make `ShardMapTracer` track `vma` instead of `rep`. PiperOrigin-RevId: 748696644 --- jax/experimental/shard_map.py | 94 +++++++++++++++-------------------- 1 file changed, 41 insertions(+), 53 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 713655897789..775b68022853 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -522,8 +522,8 @@ def _shard_map_staging( jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: - out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] - _check_reps(mesh, auto, out_names_thunk(), out_rep) + out_vma = [v.aval.vma for v in jaxpr.outvars] + _check_reps(mesh, auto, out_names_thunk(), out_vma) out_avals = map(_check_shapedarray, out_avals_) out_avals = [_check_shapedarray(_unshard_aval(mesh, check_rep, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] @@ -628,8 +628,8 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] for rep, dst in zip(out_rep, out_names): if not _valid_repeats(mesh, auto, rep, dst): - raise core.JaxprTypeError("shard_map can't prove output is " - "sufficiently replicated") + raise core.JaxprTypeError( + "shard_map can't prove output is sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] out_avals = map(partial(_unshard_aval, mesh, check_rep), out_names, out_avals_sharded) @@ -643,11 +643,6 @@ def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) -def _rule_missing(prim: core.Primitive, *_, **__): - raise NotImplementedError( - f"No replication rule for {prim}. As a workaround, pass the " - "`check_rep=False` argument to `shard_map`. 
To get this fixed, open an " - "issue at https://github.com/jax-ml/jax/issues") # Lowering @@ -808,12 +803,6 @@ def get_mesh_from_args(args_flat, mesh): assert isinstance(mesh, Mesh) return mesh -def _rep_to_vma(mesh, auto, rep: frozenset[AxisName]) -> frozenset[AxisName]: - return frozenset((set(mesh.axis_names) - auto) - rep) - -def _rep_to_spec(mesh, auto, rep): - return _vma_to_spec(mesh, _rep_to_vma(mesh, auto, rep)) - def _vma_to_spec(mesh, vma): return P(tuple(i for i in mesh.axis_names if i in vma)) @@ -832,29 +821,29 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, cur_mesh = get_abstract_mesh() args = map(partial(_unmatch_spec, mesh, check_rep, context_mesh=cur_mesh), in_names, args) - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - outs, out_rep = _run_shmap(fun, mesh, auto, args, in_rep, check_rep, cur_mesh) + in_vma = map(_names_to_vma, in_names) + outs, out_vma = _run_shmap(fun, mesh, auto, args, in_vma, check_rep, cur_mesh) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_rep: - _check_reps(mesh, auto, out_names_thunk(), out_rep) - src_pspecs = tuple(_rep_to_spec(mesh, auto, r) for r in out_rep) + _check_reps(mesh, auto, out_names_thunk(), out_vma) + src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) else: - src_pspecs = tuple(P(mesh.axis_names) for _ in out_rep) + src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) dst_pspecs = map(_names_to_pspec, out_names_thunk()) return map(partial(_match_spec, mesh, check_rep), src_pspecs, dst_pspecs, outs) core.EvalTrace.process_shard_map = _shard_map_impl -def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh): +def _run_shmap(f, mesh, auto, args, vmas, check_rep, context_mesh): trace = ShardMapTrace(mesh, auto, check_rep, context_mesh) - in_tracers = map(partial(ShardMapTracer, trace), reps, args) + in_tracers = map(partial(ShardMapTracer, trace), vmas, args) manual_mesh = _as_manual_mesh(mesh, auto) with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), config._check_rep(check_rep)): ans = f.call_wrapped(*in_tracers) - outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) - return outs, out_rep + outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) + return outs, out_vma def _names_to_pspec(names: AxisNames) -> PartitionSpec: ndmin = max(names) + 1 if names else 0 @@ -888,7 +877,8 @@ def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] class _SpecError(Exception): pass -def _check_reps(mesh, auto, names, reps): +def _check_reps(mesh, auto, names, vmas): + reps = [_vma_to_rep(mesh, auto, v) for v in vmas] fail = [r if not _valid_repeats(mesh, auto, r, n) else no_fail for n, r in zip(names, reps)] if any(f is not no_fail for f in fail): @@ -934,27 +924,25 @@ def __init__(self, mesh, auto, check, context_mesh): self.check = check self.context_mesh = context_mesh - def to_val_rep_pair(self, val): + def to_val_vma_pair(self, val): if isinstance(val, ShardMapTracer): - return val.val, val.rep + return val.val, val.vma elif isinstance(val, Tracer): raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: val_ = _unmatch_spec(self.mesh, self.check, {}, val, self.context_mesh) - return val_, frozenset(self.mesh.axis_names) - self.auto + return val_, frozenset() def process_primitive(self, prim, tracers, params): - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, 
tracers)) + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) if self.check: - in_vma = tuple(map(partial(_rep_to_vma, self.mesh, self.auto), in_rep)) out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) out_avals = tuple(out_avals) if type(out_avals) is list else out_avals out_vma = tree_map(lambda a: a.vma, out_avals) - out_rep = tree_map(partial(_vma_to_rep, self.mesh, self.auto), out_vma) in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) else: - out_rep = frozenset() + out_vma = frozenset() in_specs = out_specs = P(self.mesh.axis_names) eager_rule = eager_rules.get(prim) @@ -969,10 +957,10 @@ def process_primitive(self, prim, tracers, params): out_vals = jax.jit(f)(*in_vals) _maybe_check_special(out_vals) if prim.multiple_results: - out_rep = (out_rep if isinstance(out_rep, (list, tuple)) - else [out_rep] * len(out_vals)) - return map(partial(ShardMapTracer, self), out_rep, out_vals) - return ShardMapTracer(self, out_rep, out_vals) + out_vma = (out_vma if isinstance(out_vma, (list, tuple)) + else [out_vma] * len(out_vals)) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + return ShardMapTracer(self, out_vma, out_vals) def process_call(self, call_primitive, fun, tracers, params): raise NotImplementedError( @@ -990,10 +978,10 @@ def process_map(self, map_primitive, fun, tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. del prim, jvp, symbolic_zeros - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, self.auto, in_vals, in_rep, self.check, - self.context_mesh) - return map(partial(ShardMapTracer, self), out_rep, out_vals) + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, + self.check, self.context_mesh) + return map(partial(ShardMapTracer, self), out_vma, out_vals) def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): @@ -1003,19 +991,22 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, self.auto, in_vals, in_rep, self.check, - self.context_mesh) - return map(partial(ShardMapTracer, self), out_rep, out_vals) + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, + self.check, self.context_mesh) + return map(partial(ShardMapTracer, self), out_vma, out_vals) class ShardMapTracer(core.Tracer): - rep: RepType + vma: frozenset[AxisName] val: JaxType - def __init__(self, trace, rep, val): + def __init__(self, trace, vma, val): self._trace = trace - self.rep = rep + if isinstance(vma, set): + vma = frozenset(vma) + assert isinstance(vma, frozenset) + self.vma = vma self.val = val @property @@ -1025,21 +1016,18 @@ def aval(self): new_sharding = NamedSharding( _as_manual_mesh(self._trace.mesh, self._trace.auto), out.sharding.spec) # pytype: disable=attribute-error - manual_axes = set(self._trace.mesh.axis_names) - self._trace.auto - vma = (frozenset(manual_axes - self.rep) if config._check_rep.value else - frozenset()) + vma = self.vma 
if config._check_rep.value else frozenset() return out.update(sharding=new_sharding, vma=vma) def to_concrete_value(self): - if self.rep == set(self._trace.mesh.axis_names): + if self.vma == frozenset(): with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): return core.to_concrete_value(self.val[0]) else: return None def __str__(self) -> str: - pb_names = set(self._trace.mesh.axis_names) - _rep_to_vma( - self._trace.mesh, self._trace.auto, self.rep) + pb_names = set(self._trace.mesh.axis_names) - self.vma self = pvary(self, tuple(pb_names)) with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): blocks = list(self.val) From 9bd728e2b53f2efa5b2ff014a7ff05f03b8243c3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 17 Apr 2025 08:57:06 -0700 Subject: [PATCH 0673/1769] Exclude running tests against the oldest supported libtpu for releases PiperOrigin-RevId: 748697528 --- .github/workflows/wheel_tests_nightly_release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index f6d2aa9b97c6..e2466ee43d96 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -87,6 +87,7 @@ jobs: libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} + - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'oldest_supported_libtpu' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} # Run a single Python version for v4-8 - tpu-specs: From 25095153246a976fd41e9c8e5b40993cd8b9c2c7 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Thu, 17 Apr 2025 16:46:58 +0000 Subject: [PATCH 0674/1769] [Mosaic GPU] Align on device profiler array in smem to 8 bytes --- jax/experimental/mosaic/gpu/core.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 860b41e7e8e3..a5ac6b9bf142 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -370,7 +370,13 @@ def _launch( smem_bytes = user_smem_bytes if profiler_spec is not None: - smem_bytes += profiler_spec.smem_bytes(block=block) + # Profiler array stores values in 64 bit chunks (vectors of size 2 + # of 32-bit elements), and so the starting address needs to be 64 + # bit = 8 byte aligned. + # https://docs.nvidia.com/cuda/parallel-thread-execution/#addresses-as-operands:~:text=The%20address%20must%20be%20naturally%20aligned%20to%20a%20multiple%20of%20the%20access%20size. + align = 8 + profiler_start = (smem_bytes + align - 1) & ~(align - 1) + smem_bytes = profiler_start + profiler_spec.smem_bytes(block=block) # TODO(cperivol): Query the shared memory size programmatically. if smem_bytes > 228 * 1024: @@ -407,7 +413,7 @@ def _launch( (profiler_spec.smem_i32_elements(block=block),), i32, memory_space=smem, ), - dynamic_smem, c(user_smem_bytes, index), [], + dynamic_smem, c(profiler_start, index), [], ) prof = profiler.OnDeviceProfiler( profiler_spec, prof_smem, maybe_prof_buffer From 06ad3528e96eb942ee8c01cabbaf20fd26a81493 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 17 Apr 2025 10:47:53 -0700 Subject: [PATCH 0675/1769] Use _make_lengths_same for explicit mode too. 
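A minimal standalone sketch of the length-normalization rule this change relies
on, as described in the body below. Plain tuples stand in for the internal
PartitionSpec, and make_lengths_same here is an illustrative stand-in for the
private helper, not the actual implementation:

    def make_lengths_same(spec: tuple, ndim: int) -> tuple:
        # Pad with None when the array has more dimensions than the spec names.
        if ndim > len(spec):
            return spec + (None,) * (ndim - len(spec))
        # Only trailing unsharded (None) entries may be dropped; dropping a
        # sharded axis would silently change the sharding, so error instead.
        if ndim < len(spec):
            assert all(s is None for s in spec[ndim:]), (ndim, spec)
            return spec[:ndim]
        return spec

    assert make_lengths_same((None, "x"), 3) == (None, "x", None)
    assert make_lengths_same(("x", None, None), 2) == ("x", None)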
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs. PiperOrigin-RevId: 748735303 --- jax/_src/core.py | 22 ++++++++-------------- jax/_src/lax/control_flow/loops.py | 3 ++- jax/_src/lax/windowed_reductions.py | 3 ++- jax/experimental/shard_map.py | 2 +- tests/pjit_test.py | 20 ++++++++++++++++++++ 5 files changed, 33 insertions(+), 17 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 013219021e05..0183a9942524 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1808,7 +1808,7 @@ def _make_lengths_same(sharding, ndim): if ndim > len(sharding.spec): return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) if ndim < len(sharding.spec): - assert all(s is None for s in sharding.spec[ndim:]) + assert all(s is None for s in sharding.spec[ndim:]), (ndim, sharding.spec) return sharding.with_spec(sharding.spec[:ndim]) assert False, "unreachable" @@ -1829,19 +1829,13 @@ def modify_spec_for_auto_manual(spec, mesh) -> P: def _maybe_modify_sharding(sharding, ndim): if len(sharding.spec) == 0 or all(s is None for s in sharding.spec): - if len(sharding.spec) != ndim: - return _make_lengths_same(sharding, ndim) - return sharding - - if sharding.mesh._are_all_axes_explicit: - if ndim > len(sharding.spec): - return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) - return sharding - - out = sharding.with_spec(modify_spec_for_auto_manual( - sharding.spec, sharding.mesh)) - if (len(out.spec) != ndim and - (out.mesh.empty or out.mesh._are_all_axes_auto_or_manual)): + out = sharding + elif sharding.mesh._are_all_axes_explicit: + out = sharding + else: + out = sharding.with_spec(modify_spec_for_auto_manual( + sharding.spec, sharding.mesh)) + if len(out.spec) != ndim: out = _make_lengths_same(out, ndim) return out diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index de53ee14ca0d..39498ad624bc 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1921,7 +1921,8 @@ def fun(*args): name_stack=cond_name_stack, primitive=None, avals_in=[pred_aval], - avals_out=[pred_aval.update(shape=())], + avals_out=[pred_aval.update( + shape=(), sharding=pred_aval.sharding.with_spec(()))], tokens_in=mlir.TokenSet(), tokens_out=None) pred, = lax._unary_reduce_lower( diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index acbfae37eaf5..c159dcab8bfa 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -690,7 +690,8 @@ def _select_and_scatter_lower( window_strides, padding): operand_aval, source_aval, init_value_aval = ctx.avals_in aval_out, = ctx.avals_out - scalar_aval = operand_aval.update(shape=()) + scalar_aval = operand_aval.update( + shape=(), sharding=operand_aval.sharding.with_spec(())) scalar_type = mlir.aval_to_ir_type(scalar_aval) op = hlo.SelectAndScatterOp( mlir.aval_to_ir_type(aval_out), diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 775b68022853..b81a4656546b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -585,7 +585,7 @@ def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, for i, sz in enumerate(aval.shape)) names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) if aval.ndim == 0: - out_spec = names_spec + out_spec = P() else: out_spec = [] # type: ignore for name_s, aval_s in zip(names_spec, 
aval.sharding.spec): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8f30475eee32..5b69a415de1b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7021,6 +7021,26 @@ def f(x, y): "PartitionSpec passed to einsum cannot contain axis names.*Auto.*Manual"): f(arr, arr2) + def test_broadcasted_iota_mix_axes(self): + mesh = jtu.create_mesh( + (2, 2, 2), ('x', 'y', 'z'), + axis_types=(AxisType.Auto, AxisType.Explicit, AxisType.Explicit)) + yz_sharding = NamedSharding(mesh, P(('y', 'z'))) + + @jax.jit + def iota(): + out = jax.lax.broadcasted_iota( + dtype=jnp.int32, + shape=(16, 24), + dimension=1, + out_sharding=yz_sharding) + self.assertEqual(out.aval.sharding.spec, P(('y', 'z'), None)) + return out + + with jax.sharding.use_mesh(mesh): + out = iota() + self.assertEqual(out.sharding, yz_sharding) + @jtu.with_user_mesh((2,), ('x',)) def test_cumsum(self, mesh): np_inp = np.arange(16).reshape(8, 2) From 1d652ab7f4346a974cf5888ab23713e1a9c2ddba Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 17 Apr 2025 15:31:38 -0400 Subject: [PATCH 0676/1769] Don't recompute source_info for each tracer during staging. --- jax/_src/interpreters/ad.py | 24 ++++--- jax/_src/interpreters/partial_eval.py | 99 ++++++++++++++++----------- jax/_src/pjit.py | 4 +- jax/experimental/attrs.py | 3 +- jax/experimental/shard_map.py | 7 +- 5 files changed, 80 insertions(+), 57 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 98cda2df4964..0f11e0d72f12 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -89,12 +89,14 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, nzs_in: Sequence[bool], debug_info: core.DebugInfo, *primals, **params): + source_info = source_info_util.current() with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace(debug_info) tangent_trace.tag = _tag linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) tracers = [LinearizeTracer(linearize_trace, p, - tangent_trace.new_arg(get_aval(p).to_tangent_aval())) + tangent_trace.new_arg(get_aval(p).to_tangent_aval(), + source_info)) if nz else p for p, nz in zip(primals, nzs_in)] with core.set_current_trace(linearize_trace, check_leaks=True): @@ -103,7 +105,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, del linearize_trace, ans, tracers nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) - out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] + out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) if attrs_tracked: raise NotImplementedError("TODO: attrs") @@ -172,13 +174,14 @@ def _linearize_jaxpr( lin_trace = LinearizeTrace(primal_trace, tangent_trace) tangent_trace.tag = lin_trace.tag - def new_arg(trace, primal_aval, nz): - primal = primal_trace.new_arg(primal_aval) + def new_arg(trace, primal_aval, nz, source_info): + primal = primal_trace.new_arg(primal_aval, source_info) tangent_aval = primal_aval.to_tangent_aval() - tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval) + tangent = tangent_trace.new_arg(tangent_aval, source_info) if nz else Zero(tangent_aval) return LinearizeTracer(trace, primal, tangent) - tracers = [new_arg(lin_trace, v.aval, nz) + source_info = 
source_info_util.current() + tracers = [new_arg(lin_trace, v.aval, nz, source_info) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)] with core.set_current_trace(lin_trace, check_leaks=True): @@ -188,7 +191,7 @@ def new_arg(trace, primal_aval, nz): debug_info = jaxpr.jaxpr.debug_info nzs_out = [type(t) is not Zero for t in out_tangents] - out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t) + out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t, source_info) for (nz, t) in zip(nzs_out, out_tangents) if nz) tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) tangent_trace.invalidate() @@ -200,7 +203,7 @@ def new_arg(trace, primal_aval, nz): tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used] residuals_and_primals = (*tangent_consts, *out_primals) - residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) + residuals_and_primals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info), residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) primal_trace.invalidate() num_residuals = len(tangent_consts) @@ -212,8 +215,9 @@ def new_arg(trace, primal_aval, nz): def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *, has_aux=False, tag=None): with core.take_current_trace() as parent_trace: + source_info = source_info_util.current() tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info) - tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals] tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) tangent_trace.tag = linearize_trace.tag @@ -234,7 +238,7 @@ def direct_linearize(traceable: lu.WrappedFun, del linearize_trace, ans, tracers out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] - out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) + out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_nz_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) tangent_trace.invalidate() jaxpr, used_consts, _ = pe.dce_jaxpr_consts( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index f8ce92e7f97f..8be78f1f6eb9 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -41,6 +41,7 @@ JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) +from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure, register_static) @@ -1729,7 +1730,8 @@ def to_jaxpr( invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) - state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x))] + source_info = source_info_util.current() + state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x, source_info))] for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars @@ -1899,51 
+1901,51 @@ def invalidate(self): self.frame.constid_to_tracer = {} self.frame.constvar_to_val = {} - def to_jaxpr_tracer(self, x): + def to_jaxpr_tracer(self, x, source_info: SourceInfo): as_local_var = self.frame.tracer_to_var.get(id(x)) if as_local_var is None: if hasattr(x, "dimension_as_value"): # Used for shape_poly._DimExpr with core.set_current_trace(self): x = x.dimension_as_value() - return self.to_jaxpr_tracer(x) + return self.to_jaxpr_tracer(x, source_info) else: - return self.new_const(x) + return self.new_const(x, source_info) else: return x - def new_arg(self, aval): - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + def new_arg(self, aval, source_info: SourceInfo): + tracer = DynamicJaxprTracer(self, aval, source_info) self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.invars.append(var) return tracer - def new_const(self, c): + def new_const(self, c, source_info: SourceInfo): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: aval = get_aval(c) if hasattr(aval, "weak_type"): aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, c) + aval = self._lift_tracers_in_aval(aval, source_info) + tracer = self._new_const(aval, c, source_info) return tracer pure = lift = new_const - def _new_const(self, aval, c) -> DynamicJaxprTracer: - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: + tracer = DynamicJaxprTracer(self, aval, source_info) self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.constid_to_tracer[id(c)] = tracer self.frame.constvar_to_val[var] = c return tracer - def _lift_tracers_in_aval(self, aval): + def _lift_tracers_in_aval(self, aval, source_info: SourceInfo): if (not isinstance(aval, DShapedArray) or not any(isinstance(d, Tracer) for d in aval.shape)): return aval - shape = [self.to_jaxpr_tracer(d) if isinstance(d, Tracer) else d + shape = [self.to_jaxpr_tracer(d, source_info) if isinstance(d, Tracer) else d for d in aval.shape] return aval.update(shape=tuple(shape)) @@ -1966,7 +1968,9 @@ def is_const(self, tracer): def process_primitive(self, primitive, tracers, params): if (config.eager_constant_folding.value and all(map(self.is_const, tracers))): return primitive.bind_with_trace(core.eval_trace, tracers, params) - jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + jaxpr_tracers = map(to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) return self.default_process_primitive(primitive, jaxpr_tracers, params) @@ -1989,17 +1993,19 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f: lu.WrappedFun, explicit_tracers, params): + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) if f.in_type is None: f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) assert f.in_type is not None - implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) - in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) 
+ implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers, + source_info) + in_tracers = map(to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) - source_info = source_info_util.current() out_tracers: list[Tracer] = [] for aval, _ in out_type: if type(aval) is DShapedArray: @@ -2009,7 +2015,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, aval = aval.update(shape=tuple(get_referent(d) for d in shape)) out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) update_params = call_param_updaters.get(call_primitive) @@ -2022,7 +2028,9 @@ def process_call(self, call_primitive, f: lu.WrappedFun, return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) @@ -2041,10 +2049,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): out_avals = [core.unmapped_aval(axis_size, out_axis, a) if out_axis is not None else a for a, out_axis in zip(reduced_out_avals, out_axes)] - source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_in_axes = (None,) * len(consts) + params['in_axes'] new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, @@ -2062,7 +2069,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, jvp: lu.WrappedFun, tracers, symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] in_tangent_avals = [t.to_tangent_aval() for t in in_avals] fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) @@ -2079,7 +2088,7 @@ def jvp_jaxpr_thunk(*in_zeros): out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, @@ -2088,7 +2097,7 @@ def jvp_jaxpr_thunk(*in_zeros): num_consts=len(consts), symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers 
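The recurring pattern in this file is worth making explicit: the stack walk in
source_info_util.current() is hoisted out of the per-tracer path and bound once
with functools.partial. A minimal sketch of the idea with stand-in names
(current_source_info and make_tracer are illustrative, not the actual JAX
internals):

    from functools import partial

    def current_source_info():
        # Stand-in for source_info_util.current(), which inspects the
        # Python stack and is therefore relatively expensive.
        return "user_code.py:42"

    def make_tracer(value, *, source_info):
        return (value, source_info)

    def stage(values):
        # One lookup per staged call, shared by every value, instead of
        # one lookup per tracer.
        source_info = current_source_info()
        to_tracer = partial(make_tracer, source_info=source_info)
        return [to_tracer(v) for v in values]

    assert stage([1, 2]) == [(1, "user_code.py:42"), (2, "user_code.py:42")]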
@@ -2097,7 +2106,9 @@ def process_custom_vjp_call(self, prim: core.Primitive, fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, out_trees: Callable[[], Sequence[PyTreeDef]], symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) @@ -2110,9 +2121,9 @@ def fwd_jaxpr_from_zeros(*zeros): if attrs: raise NotImplementedError return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, # pytype: disable=attribute-error @@ -2122,7 +2133,7 @@ def fwd_jaxpr_from_zeros(*zeros): bwd=bwd, out_trees=out_trees, symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers @@ -2132,7 +2143,9 @@ def process_custom_transpose(self, prim: core.Primitive, # type: ignore[overrid out_types, lin_tree: PyTreeDef, res_tree: PyTreeDef, out_tree: PyTreeDef): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) in_avals_p = [t.aval for t in tracers] @@ -2152,9 +2165,9 @@ def transpose_jaxpr_thunk(): jaxpr, _, consts, () = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, call_consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, call_consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_call_jaxpr, @@ -2162,7 +2175,7 @@ def transpose_jaxpr_thunk(): out_types=out_types, res_tree=res_tree, lin_tree=lin_tree, out_tree=out_tree), closed_call_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers @@ -2216,13 +2229,15 @@ def trace_to_jaxpr_dynamic( keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + source_info = source_info_util.current() + in_tracers = _input_type_to_tracers( + partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] try: with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) + out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) _check_no_returned_refs(fun.debug_info, out_tracers) jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, 
fun.debug_info) del fun, in_tracers, out_tracers, ans @@ -2269,12 +2284,14 @@ def trace_to_jaxpr_dynamic2( trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): + source_info = source_info_util.current() in_avals, keep_inputs = unzip2(fun.in_type) - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + in_tracers = _input_type_to_tracers( + partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) + out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info) del trace, in_tracers, out_tracers, ans @@ -2449,7 +2466,7 @@ def __hash__(self): def _extract_implicit_args( trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]], - explicit_tracers: Sequence[DynamicJaxprTracer] + explicit_tracers: Sequence[DynamicJaxprTracer], source_info: SourceInfo, ) -> Sequence[DynamicJaxprTracer]: # First, construct a list to represent the full argument list, leaving the # implicit arguments as Nones for now. @@ -2467,8 +2484,8 @@ def _extract_implicit_args( for d1, d2 in zip(aval.shape, tracer.aval.shape): if isinstance(d1, DBIdx): if tracers[d1.val] is None: - tracers[d1.val] = trace.to_jaxpr_tracer(d2) - assert tracers[d1.val] is trace.to_jaxpr_tracer(d2) + tracers[d1.val] = trace.to_jaxpr_tracer(d2, source_info) + assert tracers[d1.val] is trace.to_jaxpr_tracer(d2, source_info) assert all(t is not None for t in tracers) return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore @@ -2616,13 +2633,13 @@ def inline_jaxpr_into_trace( trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - const_tracers = map(trace.new_const, consts) + src = source_info_util.current() + const_tracers = map(partial(trace.new_const, source_info=src), consts) constvars = map(trace.getvar, const_tracers) argvars = map(trace.getvar, arg_tracers) env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], [*constvars, *argvars])) - src = source_info_util.current() for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] outvars = [Var('', v.aval) for v in eqn.outvars] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a10856cbced3..595300f2c1b2 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1928,12 +1928,12 @@ def pjit_staging_rule(trace, *args, **params): trace, jaxpr.jaxpr, jaxpr.consts, *args) jaxpr = params['jaxpr'] + source_info = source_info_util.current() if config.dynamic_shapes.value: jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( jaxpr, params['out_shardings'], params['out_layouts']) params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) - source_info = source_info_util.current() out_tracers = [] for aval in _out_type(jaxpr): if type(aval) is core.DShapedArray: @@ -1952,7 +1952,7 @@ def pjit_staging_rule(trace, *args, **params): assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(trace.new_const, consts) + consts = map(partial(trace.new_const, source_info=source_info), consts) in_shardings = 
(*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 0d40938a85c4..54fd0fe0b02f 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -87,10 +87,11 @@ def _check_append_type_agreement(_, attr, curtype, valtype): def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): frame = trace.frame + source_info = source_info_util.current() def new_tracer(x): aval = core.get_aval(x) - tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) + tracer = pe.DynamicJaxprTracer(trace, aval, source_info) var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) frame.attrs_vars.append(var) frame.tracers.append(tracer) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index b81a4656546b..b66b22304e17 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -512,7 +512,9 @@ def _shard_map_staging( check_rep: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: - in_tracers = map(trace.to_jaxpr_tracer, in_tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) + in_tracers = map(to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh, auto, check_rep), in_names, in_avals) @@ -527,10 +529,9 @@ def _shard_map_staging( out_avals = map(_check_shapedarray, out_avals_) out_avals = [_check_shapedarray(_unshard_aval(mesh, check_rep, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] - source_info = source_info_util.current() out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) + constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), From 309d9d262505ec1875917d6caa9596cf308c01d2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 17 Apr 2025 12:52:26 -0700 Subject: [PATCH 0677/1769] Enable cross-compilation builds of the wheels. The feature is achieved through using transitive dependencies in the wheel targets. 
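Concretely, cross-compilation hinges on configuring the wheel-builder tool for
the execution platform while the files it packages stay in the target
configuration. A schematic Starlark sketch of the attribute shape involved
(illustrative, not the exact rule text; _jax_wheel_impl is assumed defined
elsewhere):

    _jax_wheel = rule(
        implementation = _jax_wheel_impl,
        attrs = {
            # The packaging tool runs on the machine doing the build ...
            "wheel_binary": attr.label(executable = True, cfg = "exec"),
            # ... while the sources it packages are built for the target
            # platform, so the two configurations may differ.
            "source_files": attr.label_list(allow_files = True),
        },
    )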
PiperOrigin-RevId: 748778634 --- BUILD.bazel | 39 ++- ci/run_bazel_test_cpu_rbe.sh | 4 +- jaxlib/jax.bzl | 41 ++- jaxlib/tools/BUILD.bazel | 236 ++++++++++++++--- jaxlib/tools/build_gpu_kernels_wheel.py | 137 ++++++---- jaxlib/tools/build_gpu_plugin_wheel.py | 98 ++++--- jaxlib/tools/build_utils.py | 48 +++- jaxlib/tools/build_wheel.py | 333 ++++++++++++++---------- 8 files changed, 642 insertions(+), 294 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 8dbf2bed0902..906ae83796b8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -16,23 +16,19 @@ load( "@xla//third_party/py:py_import.bzl", "py_import", ) -load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", "jax_source_package", "jax_wheel", "py_deps", "pytype_test", + "wheel_sources", ) -collect_data_files( - name = "transitive_py_data", - deps = ["//jax"], -) - -transitive_py_deps( - name = "transitive_py_deps", - deps = [ +wheel_sources( + name = "jax_sources", + data_srcs = ["//jax"], + py_srcs = [ "//jax", "//jax:compilation_cache", "//jax:experimental", @@ -60,6 +56,14 @@ transitive_py_deps( "//jax/tools:jax_to_ir", "//jax/tools:pgo_nsys_converter", ], + static_srcs = [ + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", + ], ) py_binary( @@ -73,21 +77,10 @@ py_binary( ], ) -WHEEL_SOURCE_FILES = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", -] - jax_wheel( name = "jax_wheel", platform_independent = True, - source_files = WHEEL_SOURCE_FILES, + source_files = [":jax_sources"], wheel_binary = ":build_wheel", wheel_name = "jax", ) @@ -96,14 +89,14 @@ jax_wheel( name = "jax_wheel_editable", editable = True, platform_independent = True, - source_files = WHEEL_SOURCE_FILES, + source_files = [":jax_sources"], wheel_binary = ":build_wheel", wheel_name = "jax", ) jax_source_package( name = "jax_source_package", - source_files = WHEEL_SOURCE_FILES, + source_files = [":jax_sources"], source_package_binary = ":build_wheel", source_package_name = "jax", ) diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index d8cb190079e0..7eeb2adef0b3 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -53,7 +53,9 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ) --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + //tests:cpu_tests //tests:backend_independent_tests \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test else echo "Running RBE CPU tests..." 
bazel test --config=rbe_${os}_${arch} \ diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 3c234f5f8c37..d6030a7ce03f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -23,7 +23,8 @@ load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_roc load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") load("@rules_python//python:defs.bzl", "py_test") -load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") +load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load("@xla//xla/tsl:tsl.bzl", "transitive_hdrs", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") # Explicitly re-exports names to avoid "unused variable" warnings from .bzl @@ -434,10 +435,9 @@ def _jax_wheel_impl(ctx): _jax_wheel = rule( attrs = { "wheel_binary": attr.label( - default = Label("//jaxlib/tools:build_wheel"), + default = Label("//jaxlib/tools:build_wheel_tool"), executable = True, - # b/365588895 Investigate cfg = "exec" for multi platform builds - cfg = "target", + cfg = "exec", ), "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), @@ -585,3 +585,36 @@ def if_oss(oss_value, google_value = []): """ _ = (google_value, oss_value) # buildifier: disable=unused-variable return oss_value + +def wheel_sources( + name, + py_srcs = [], + data_srcs = [], + symlink_data_srcs = [], + hdr_srcs = [], + static_srcs = []): + """Create a filegroup containing the list of source files for a wheel. + + The sources are collected from the static files and from the transitive dependencies of the + given srcs. + + Args: + name: the target name + py_srcs: targets which transitive python dependencies should be included in the wheel + data_srcs: targets which platform-dependent data dependencies should be included in the wheel + symlink_data_srcs: targets which symlinked data dependencies should be included in the wheel + hdr_srcs: targets which transitive header dependencies should be included in the wheel + static_srcs: the platform-independent file dependencies of the wheel + """ + transitive_py_deps(name = "{}_py".format(name), deps = py_srcs) + collect_data_files( + name = "{}_data".format(name), + deps = data_srcs, + symlink_deps = symlink_data_srcs, + ) + transitive_hdrs(name = "{}_hdrs".format(name), deps = hdr_srcs) + native.filegroup(name = name, srcs = [ + ":{}_py".format(name), + ":{}_data".format(name), + ":{}_hdrs".format(name), + ] + static_srcs) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 3dffe556d821..c7d606eba458 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -33,6 +33,7 @@ load( "jax_wheel", "pytype_strict_library", "pytype_test", + "wheel_sources", ) licenses(["notice"]) # Apache 2 @@ -167,6 +168,10 @@ py_binary( ], ) +# Targets and configurations for the new wheel build rules. + +# Platform configurations. + selects.config_setting_group( name = "macos", match_any = [ @@ -223,6 +228,8 @@ selects.config_setting_group( ], ) +# Flags for the new wheel build rules. 
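To make the new macro concrete: a wheel_sources call such as the
jaxlib_sources target added below expands, roughly, into three collector
targets plus a filegroup that unions them with the static files. A sketch that
just follows the macro body above, with py_srcs and friends standing for the
caller's arguments:

    transitive_py_deps(name = "jaxlib_sources_py", deps = py_srcs)
    collect_data_files(
        name = "jaxlib_sources_data",
        deps = data_srcs,
        symlink_deps = symlink_data_srcs,
    )
    transitive_hdrs(name = "jaxlib_sources_hdrs", deps = hdr_srcs)
    filegroup(
        name = "jaxlib_sources",
        srcs = [
            ":jaxlib_sources_py",
            ":jaxlib_sources_data",
            ":jaxlib_sources_hdrs",
        ] + static_srcs,
    )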
+ string_flag( name = "jaxlib_git_hash", build_setting_default = "", @@ -233,60 +240,129 @@ string_flag( build_setting_default = "dist", ) -NVIDIA_WHEELS_DEPS = [ - "@pypi_nvidia_cublas_cu12//:whl", - "@pypi_nvidia_cuda_cupti_cu12//:whl", - "@pypi_nvidia_cuda_runtime_cu12//:whl", - "@pypi_nvidia_cudnn_cu12//:whl", - "@pypi_nvidia_cufft_cu12//:whl", - "@pypi_nvidia_cusolver_cu12//:whl", - "@pypi_nvidia_cusparse_cu12//:whl", - "@pypi_nvidia_nccl_cu12//:whl", - "@pypi_nvidia_nvjitlink_cu12//:whl", -] +# Wheel targets. + +# Jaxlib wheel targets. +py_binary( + name = "build_wheel_tool", + srcs = ["build_wheel.py"], + main = "build_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + +wheel_sources( + name = "jaxlib_sources", + data_srcs = [ + "//jaxlib", + "//jaxlib:jaxlib_binaries", + "//jaxlib/xla:xla_extension", + ], + hdr_srcs = [ + "@xla//xla/ffi/api:ffi", + ], + py_srcs = [ + "//jaxlib", + ], + static_srcs = [ + "//jaxlib:README.md", + "LICENSE.txt", + "//jaxlib:setup.py", + "//jaxlib/xla:xla_client.py", + ], + symlink_data_srcs = [ + "//jaxlib", + ], +) jax_wheel( name = "jaxlib_wheel", no_abi = False, - wheel_binary = ":build_wheel", + source_files = [":jaxlib_sources"], + wheel_binary = ":build_wheel_tool", wheel_name = "jaxlib", ) -py_import( - name = "jaxlib_py_import", - wheel = ":jaxlib_wheel", -) - jax_wheel( name = "jaxlib_wheel_editable", editable = True, - wheel_binary = ":build_wheel", + source_files = [":jaxlib_sources"], + wheel_binary = ":build_wheel_tool", wheel_name = "jaxlib", ) +# JAX plugin wheel targets. +pytype_strict_library( + name = "version", + srcs = ["//jaxlib:version"], +) + +py_binary( + name = "build_gpu_kernels_wheel_tool", + srcs = ["build_gpu_kernels_wheel.py"], + main = "build_gpu_kernels_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + +wheel_sources( + name = "jax_plugin_sources", + data_srcs = [ + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "@local_config_cuda//cuda:cuda-nvvm", + "//jaxlib/cuda:cuda_plugin_extension", + "//jaxlib/mosaic/gpu:mosaic_gpu", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + "//jaxlib/rocm:rocm_plugin_extension", + ]), + py_srcs = [":version"] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "//jaxlib/mosaic/gpu:mosaic_gpu", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + ]), + static_srcs = [ + "LICENSE.txt", + ] + if_cuda([ + "//jax_plugins/cuda:plugin_pyproject.toml", + "//jax_plugins/cuda:plugin_setup.py", + ]) + if_rocm([ + "//jax_plugins/rocm:plugin_pyproject.toml", + "//jax_plugins/rocm:plugin_setup.py", + ]), +) + jax_wheel( name = "jax_cuda_plugin_wheel", enable_cuda = True, no_abi = False, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_cuda12_plugin", ) -py_import( - name = "jax_cuda_plugin_py_import", - wheel = ":jax_cuda_plugin_wheel", - wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), -) - jax_wheel( name = "jax_cuda_plugin_wheel_editable", editable = True, enable_cuda = True, # TODO(b/371217563) May use hermetic cuda version here. 
platform_version = "12", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_cuda12_plugin", ) @@ -295,7 +371,8 @@ jax_wheel( enable_rocm = True, no_abi = False, platform_version = "60", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_rocm60_plugin", ) @@ -304,33 +381,79 @@ jax_wheel( editable = True, enable_rocm = True, platform_version = "60", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_rocm60_plugin", ) +# JAX PJRT wheel targets. +pytype_strict_library( + name = "pjrt_c_api_gpu_plugin_so", + data = [":pjrt_c_api_gpu_plugin.so"], +) + +py_binary( + name = "build_gpu_plugin_wheel_tool", + srcs = ["build_gpu_plugin_wheel.py"], + main = "build_gpu_plugin_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + +wheel_sources( + name = "jax_pjrt_sources", + data_srcs = [ + ":pjrt_c_api_gpu_plugin_so", + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + ]), + py_srcs = [ + ":version", + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + ]), + static_srcs = [ + "LICENSE.txt", + ] + if_cuda([ + "//jax_plugins/cuda:pyproject.toml", + "//jax_plugins/cuda:setup.py", + "//jax_plugins/cuda:__init__.py", + ]) + if_rocm([ + "//jax_plugins/rocm:pyproject.toml", + "//jax_plugins/rocm:setup.py", + "//jax_plugins/rocm:__init__.py", + ]), +) + jax_wheel( name = "jax_cuda_pjrt_wheel", enable_cuda = True, no_abi = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_cuda12_pjrt", ) -py_import( - name = "jax_cuda_pjrt_py_import", - wheel = ":jax_cuda_pjrt_wheel", - wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), -) - jax_wheel( name = "jax_cuda_pjrt_wheel_editable", editable = True, enable_cuda = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_cuda12_pjrt", ) @@ -339,7 +462,8 @@ jax_wheel( enable_rocm = True, no_abi = True, platform_version = "60", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_rocm60_pjrt", ) @@ -348,10 +472,46 @@ jax_wheel( editable = True, enable_rocm = True, platform_version = "60", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_rocm60_pjrt", ) +# Py_import targets. 
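For orientation, the py_import targets grouped below let other targets depend
on the freshly built wheels instead of the source tree. A hypothetical
consumer, not part of this patch (the test name and source file are made up
for illustration):

    py_test(
        name = "gpu_wheel_smoke_test",
        srcs = ["gpu_wheel_smoke_test.py"],
        deps = [
            ":jaxlib_py_import",
            ":jax_cuda_plugin_py_import",
            ":jax_cuda_pjrt_py_import",
        ],
    )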
+filegroup( + name = "nvidia_wheel_deps", + srcs = [ + "@pypi_nvidia_cublas_cu12//:whl", + "@pypi_nvidia_cuda_cupti_cu12//:whl", + "@pypi_nvidia_cuda_runtime_cu12//:whl", + "@pypi_nvidia_cudnn_cu12//:whl", + "@pypi_nvidia_cufft_cu12//:whl", + "@pypi_nvidia_cusolver_cu12//:whl", + "@pypi_nvidia_cusparse_cu12//:whl", + "@pypi_nvidia_nccl_cu12//:whl", + "@pypi_nvidia_nvjitlink_cu12//:whl", + ], +) + +py_import( + name = "jaxlib_py_import", + wheel = ":jaxlib_wheel", +) + +py_import( + name = "jax_cuda_plugin_py_import", + wheel = ":jax_cuda_plugin_wheel", + wheel_deps = if_cuda([":nvidia_wheel_deps"]), +) + +py_import( + name = "jax_cuda_pjrt_py_import", + wheel = ":jax_cuda_pjrt_wheel", + wheel_deps = if_cuda([":nvidia_wheel_deps"]), +) + +# Wheel tests. + AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")]) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index e9684108caf0..835a8b72de9f 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -26,7 +26,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--output_path", default=None, @@ -61,6 +61,9 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) args = parser.parse_args() r = runfiles.Create() @@ -79,78 +82,106 @@ def write_setup_cfg(sources_path, cpu): def prepare_wheel_cuda( - sources_path: pathlib.Path, *, cpu, cuda_version + wheel_sources_path: pathlib.Path, *, cpu, cuda_version, wheel_sources ): - """Assembles a source tree for the cuda kernel wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + """Assembles a source tree for the cuda kernel wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, + root_packages=[ + "jax_plugins", + f"jax_cuda{cuda_version}_plugin", + "jaxlib", + ], + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - copy_runfiles( - "__main__/jax_plugins/cuda/plugin_pyproject.toml", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/cuda/plugin_pyproject.toml", + dst_dir=wheel_sources_path, dst_filename="pyproject.toml", ) - copy_runfiles( - "__main__/jax_plugins/cuda/plugin_setup.py", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/cuda/plugin_setup.py", + dst_dir=wheel_sources_path, dst_filename="setup.py", ) - build_utils.update_setup_with_cuda_version(sources_path, cuda_version) - write_setup_cfg(sources_path, cpu) + build_utils.update_setup_with_cuda_version(wheel_sources_path, cuda_version) + write_setup_cfg(wheel_sources_path, cpu) - plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" - copy_runfiles( + plugin_dir = wheel_sources_path / f"jax_cuda{cuda_version}_plugin" + copy_files( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/cuda/_solver.{pyext}", - f"__main__/jaxlib/cuda/_linalg.{pyext}", - f"__main__/jaxlib/cuda/_prng.{pyext}", - f"__main__/jaxlib/cuda/_rnn.{pyext}", - f"__main__/jaxlib/cuda/_sparse.{pyext}", - 
f"__main__/jaxlib/cuda/_triton.{pyext}", - f"__main__/jaxlib/cuda/_hybrid.{pyext}", - f"__main__/jaxlib/cuda/_versions.{pyext}", - f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}", - f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", - "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jaxlib/cuda/_solver.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_linalg.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_prng.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_rnn.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_sparse.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_triton.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_hybrid.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_versions.{pyext}", + f"{source_file_prefix}jaxlib/cuda/cuda_plugin_extension.{pyext}", + f"{source_file_prefix}jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", + f"{source_file_prefix}jaxlib/version.py", ], ) + def prepare_wheel_rocm( - sources_path: pathlib.Path, *, cpu, rocm_version + wheel_sources_path: pathlib.Path, *, cpu, rocm_version, wheel_sources ): - """Assembles a source tree for the rocm kernel wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + """Assembles a source tree for the rocm kernel wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, + root_packages=[ + "jax_plugins", + f"jax_rocm{rocm_version}_plugin", + "jaxlib", + ], + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - copy_runfiles( - "__main__/jax_plugins/rocm/plugin_pyproject.toml", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/rocm/plugin_pyproject.toml", + dst_dir=wheel_sources_path, dst_filename="pyproject.toml", ) - copy_runfiles( - "__main__/jax_plugins/rocm/plugin_setup.py", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/rocm/plugin_setup.py", + dst_dir=wheel_sources_path, dst_filename="setup.py", ) - build_utils.update_setup_with_rocm_version(sources_path, rocm_version) - write_setup_cfg(sources_path, cpu) + build_utils.update_setup_with_rocm_version(wheel_sources_path, rocm_version) + write_setup_cfg(wheel_sources_path, cpu) - plugin_dir = sources_path / f"jax_rocm{rocm_version}_plugin" - copy_runfiles( + plugin_dir = wheel_sources_path / f"jax_rocm{rocm_version}_plugin" + copy_files( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/rocm/_linalg.{pyext}", - f"__main__/jaxlib/rocm/_prng.{pyext}", - f"__main__/jaxlib/rocm/_solver.{pyext}", - f"__main__/jaxlib/rocm/_sparse.{pyext}", - f"__main__/jaxlib/rocm/_hybrid.{pyext}", - f"__main__/jaxlib/rocm/_rnn.{pyext}", - f"__main__/jaxlib/rocm/_triton.{pyext}", - f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jaxlib/rocm/_linalg.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_prng.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_solver.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_sparse.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_hybrid.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_rnn.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_triton.{pyext}", + f"{source_file_prefix}jaxlib/rocm/rocm_plugin_extension.{pyext}", + f"{source_file_prefix}jaxlib/version.py", ], ) + # Build wheel for cuda kernels if 
args.enable_rocm: tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") @@ -161,12 +192,18 @@ def prepare_wheel_rocm( os.makedirs(args.output_path, exist_ok=True) if args.enable_cuda: prepare_wheel_cuda( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + cuda_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = f"jax cuda{args.platform_version} plugin" elif args.enable_rocm: prepare_wheel_rocm( - pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + rocm_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = f"jax rocm{args.platform_version} plugin" if args.editable: diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index d52cc7da36e8..337bedab4591 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -26,7 +26,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--sources_path", default=None, @@ -67,6 +67,9 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) args = parser.parse_args() r = runfiles.Create() @@ -85,58 +88,77 @@ def write_setup_cfg(sources_path, cpu): """ ) +def prepare_cuda_plugin_wheel( + wheel_sources_path: pathlib.Path, *, cpu, cuda_version, wheel_sources +): + """Assembles a source tree for the wheel in `wheel_sources_path`""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jax_plugins", "jaxlib"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) -def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): - """Assembles a source tree for the wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) - - plugin_dir = sources_path / "jax_plugins" / f"xla_cuda{cuda_version}" - copy_runfiles( - dst_dir=sources_path, + plugin_dir = wheel_sources_path / "jax_plugins" / f"xla_cuda{cuda_version}" + copy_files( + dst_dir=wheel_sources_path, src_files=[ - "__main__/jax_plugins/cuda/pyproject.toml", - "__main__/jax_plugins/cuda/setup.py", + f"{source_file_prefix}jax_plugins/cuda/pyproject.toml", + f"{source_file_prefix}jax_plugins/cuda/setup.py", ], ) - build_utils.update_setup_with_cuda_version(sources_path, cuda_version) - write_setup_cfg(sources_path, cpu) - copy_runfiles( + build_utils.update_setup_with_cuda_version(wheel_sources_path, cuda_version) + write_setup_cfg(wheel_sources_path, cpu) + copy_files( dst_dir=plugin_dir, src_files=[ - "__main__/jax_plugins/cuda/__init__.py", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jax_plugins/cuda/__init__.py", + f"{source_file_prefix}jaxlib/version.py", ], ) - copy_runfiles( - "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + copy_files( + f"{source_file_prefix}jaxlib/tools/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_cuda_plugin.so", ) -def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): - """Assembles a source tree for the ROCm wheel in `sources_path`.""" - copy_runfiles = 
functools.partial(build_utils.copy_file, runfiles=r) +def prepare_rocm_plugin_wheel( + wheel_sources_path: pathlib.Path, *, cpu, rocm_version, wheel_sources +): + """Assembles a source tree for the ROCm wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jax_plugins", "jaxlib"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - plugin_dir = sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" - copy_runfiles( - dst_dir=sources_path, - src_files=[ - "__main__/jax_plugins/rocm/pyproject.toml", - "__main__/jax_plugins/rocm/setup.py", + plugin_dir = wheel_sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" + copy_files( + dst_dir=wheel_sources_path, + src_files=[ + f"{source_file_prefix}jax_plugins/rocm/pyproject.toml", + f"{source_file_prefix}jax_plugins/rocm/setup.py", ], ) - build_utils.update_setup_with_rocm_version(sources_path, rocm_version) - write_setup_cfg(sources_path, cpu) - copy_runfiles( + build_utils.update_setup_with_rocm_version(wheel_sources_path, rocm_version) + write_setup_cfg(wheel_sources_path, cpu) + copy_files( dst_dir=plugin_dir, src_files=[ - "__main__/jax_plugins/rocm/__init__.py", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jax_plugins/rocm/__init__.py", + f"{source_file_prefix}jaxlib/version.py", ], ) - copy_runfiles( - "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + copy_files( + f"{source_file_prefix}jaxlib/tools/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_rocm_plugin.so", ) @@ -153,12 +175,18 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.enable_cuda: prepare_cuda_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + cuda_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = "jax cuda plugin" elif args.enable_rocm: prepare_rocm_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + rocm_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = "jax rocm plugin" else: diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 582a0c9f1d6f..1cba68e87fd4 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -27,29 +27,63 @@ from jaxlib.tools import platform_tags +MAIN_RUNFILES_DIR = "__main__/" + + def is_windows() -> bool: return sys.platform.startswith("win32") +def create_wheel_sources_map(wheel_sources, root_packages): + """Returns a map of paths relative to the root package to the full paths.""" + wheel_sources_map = {} + for source in wheel_sources: + for package in root_packages: + if source.startswith("{}/".format(package)): + wheel_sources_map[source] = source + continue + root_package_ind = source.find("/{}/".format(package)) + if root_package_ind >= 0: + wheel_sources_map[source[root_package_ind + 1:]] = source + return wheel_sources_map + + +# TODO(ybaturina): remove the method when we switch to the new wheel build rules +# and the runfiles are not needed. 
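+# When explicit wheel sources are provided they are resolved via the sources
+# map, whose keys carry no prefix; only runfiles lookups need MAIN_RUNFILES_DIR.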
+def get_source_file_prefix(wheel_sources): + return "" if wheel_sources else MAIN_RUNFILES_DIR + + def copy_file( src_files: str | Sequence[str], dst_dir: pathlib.Path, - dst_filename = None, - runfiles = None, + dst_filename=None, + runfiles=None, + wheel_sources_map=None, ) -> None: dst_dir.mkdir(parents=True, exist_ok=True) if isinstance(src_files, str): src_files = [src_files] for src_file in src_files: - src_file_rloc = runfiles.Rlocation(src_file) - if src_file_rloc is None: + if wheel_sources_map: + src_file_loc = wheel_sources_map.get(src_file, None) + # TODO(ybaturina): remove the runfiles part when we switch to the new wheel + # build rules and the runfiles are not needed. + elif runfiles: + src_file_loc = runfiles.Rlocation(src_file) + else: + raise RuntimeError( + "Either runfiles or wheel_sources_map should be provided!" + ) + if src_file_loc is None: raise ValueError(f"Unable to find wheel source file {src_file}") - src_filename = os.path.basename(src_file_rloc) + + src_filename = os.path.basename(src_file_loc) dst_file = os.path.join(dst_dir, dst_filename or src_filename) if is_windows(): - shutil.copyfile(src_file_rloc, dst_file) + shutil.copyfile(src_file_loc, dst_file) else: - shutil.copy(src_file_rloc, dst_file) + shutil.copy(src_file_loc, dst_file) def platform_tag(cpu: str) -> str: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index ba0eedfd393e..b4b1ec72e8c4 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -29,7 +29,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--sources_path", default=None, @@ -56,6 +56,9 @@ action="store_true", help="Create an 'editable' jaxlib build instead of a wheel.", ) +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) args = parser.parse_args() r = runfiles.Create() @@ -68,15 +71,23 @@ def _is_mac(): pyext = "pyd" if build_utils.is_windows() else "so" -def exists(src_file): - path = r.Rlocation(src_file) - if path is None: - return False - return os.path.exists(path) +def _get_file_path(src_file, runfiles=None, wheel_sources_map=None): + if wheel_sources_map: + return wheel_sources_map.get( + src_file.replace(build_utils.MAIN_RUNFILES_DIR, ""), None + ) + # TODO(ybaturina): remove the runfiles part when we switch to the new wheel + # build rules and the runfiles are not needed. 
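+      # (Rlocation resolves a runfiles-relative path such as
+      # "__main__/jaxlib/version.py" to a real path on disk.)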
+ elif runfiles: + return runfiles.Rlocation(src_file) + else: + raise RuntimeError("Either runfiles or wheel_sources should be provided!") -def patch_copy_mlir_import(src_file, dst_dir): - src_file = r.Rlocation(src_file) +def patch_copy_mlir_import( + src_file, dst_dir, runfiles=None, wheel_sources_map=None +): + src_file = _get_file_path(src_file, runfiles, wheel_sources_map) src_filename = os.path.basename(src_file) with open(src_file) as f: src = f.read() @@ -105,11 +116,17 @@ def patch_copy_mlir_import(src_file, dst_dir): ] -def patch_copy_xla_extension_stubs(dst_dir): +def patch_copy_xla_extension_stubs( + dst_dir, runfiles=None, wheel_sources_map=None +): xla_extension_dir = os.path.join(dst_dir, "xla_extension") os.makedirs(xla_extension_dir) for stub_name in _XLA_EXTENSION_STUBS: - stub_path = r.Rlocation("__main__/jaxlib/xla/xla_extension/" + stub_name) + stub_path = _get_file_path( + f"__main__/jaxlib/xla/xla_extension/{stub_name}", + runfiles, + wheel_sources_map, + ) stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). with open(stub_path) as f: src = f.read() @@ -120,7 +137,9 @@ def patch_copy_xla_extension_stubs(dst_dir): f.write(src) -def verify_mac_libraries_dont_reference_chkstack(): +def verify_mac_libraries_dont_reference_chkstack( + runfiles=None, wheel_sources_map=None +): """Verifies that xla_extension.so doesn't depend on ____chkstk_darwin. We don't entirely know why this happens, but in some build environments @@ -131,8 +150,11 @@ def verify_mac_libraries_dont_reference_chkstack(): """ if not _is_mac(): return + file_path = _get_file_path( + f"__main__/jaxlib/xla_extension.{pyext}", runfiles, wheel_sources_map + ) nm = subprocess.run( - ["nm", "-g", r.Rlocation(f"__main__/jaxlib/xla_extension.{pyext}")], + ["nm", "-g", file_path], capture_output=True, text=True, check=False, @@ -159,211 +181,249 @@ def write_setup_cfg(sources_path, cpu): ) -def prepare_wheel(sources_path: pathlib.Path, *, cpu): - """Assembles a source tree for the wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) +def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): + """Assembles a source tree for the wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + # The wheel sources provided by the transitive rules might have different path + # prefixes, so we need to create a map of paths relative to the root package + # to the full paths. + # E.g. 
if we have the wheel sources paths like + # bazel-out/k8-opt/bin/jaxlib/mlir/_mlir_libs/register_jax_dialects.py and + # external/xla/xla/ffi/api/c_api.h, the resulting map will be + # {'jaxlib/mlir/_mlir_libs/register_jax_dialects.py': + # 'bazel-out/k8-opt/bin/jaxlib/mlir/_mlir_libs/register_jax_dialects.py', + # 'xla/ffi/api/c_api.h': 'external/xla/xla/ffi/api/c_api.h'} + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jaxlib", "xla"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - verify_mac_libraries_dont_reference_chkstack() - copy_runfiles( - dst_dir=sources_path, + verify_mac_libraries_dont_reference_chkstack( + runfiles=r, wheel_sources_map=wheel_sources_map + ) + copy_files( + dst_dir=wheel_sources_path, src_files=[ - "__main__/jaxlib/tools/LICENSE.txt", - "__main__/jaxlib/README.md", - "__main__/jaxlib/setup.py", + f"{source_file_prefix}jaxlib/tools/LICENSE.txt", + f"{source_file_prefix}jaxlib/README.md", + f"{source_file_prefix}jaxlib/setup.py", ], ) - write_setup_cfg(sources_path, cpu) + write_setup_cfg(wheel_sources_path, cpu) - jaxlib_dir = sources_path / "jaxlib" - copy_runfiles( - "__main__/jaxlib/init.py", dst_dir=jaxlib_dir, dst_filename="__init__.py" + jaxlib_dir = wheel_sources_path / "jaxlib" + copy_files( + f"{source_file_prefix}jaxlib/init.py", + dst_dir=jaxlib_dir, + dst_filename="__init__.py", ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir, src_files=[ - f"__main__/jaxlib/cpu_feature_guard.{pyext}", - f"__main__/jaxlib/utils.{pyext}", - "__main__/jaxlib/jax_common.dll" if build_utils.is_windows() else f"__main__/jaxlib/libjax_common.{soext}", - "__main__/jaxlib/lapack.py", - "__main__/jaxlib/hlo_helpers.py", - "__main__/jaxlib/gpu_prng.py", - "__main__/jaxlib/gpu_linalg.py", - "__main__/jaxlib/gpu_rnn.py", - "__main__/jaxlib/gpu_triton.py", - "__main__/jaxlib/gpu_common_utils.py", - "__main__/jaxlib/gpu_solver.py", - "__main__/jaxlib/gpu_sparse.py", - "__main__/jaxlib/plugin_support.py", - "__main__/jaxlib/version.py", - "__main__/jaxlib/xla/xla_client.py", - f"__main__/jaxlib/weakref_lru_cache.{pyext}", - "__main__/jaxlib/weakref_lru_cache.pyi", - f"__main__/jaxlib/xla_extension.{pyext}", + f"{source_file_prefix}jaxlib/cpu_feature_guard.{pyext}", + f"{source_file_prefix}jaxlib/utils.{pyext}", + f"{source_file_prefix}jaxlib/jax_common.dll" + if build_utils.is_windows() + else f"{source_file_prefix}jaxlib/libjax_common.{soext}", + f"{source_file_prefix}jaxlib/lapack.py", + f"{source_file_prefix}jaxlib/hlo_helpers.py", + f"{source_file_prefix}jaxlib/gpu_prng.py", + f"{source_file_prefix}jaxlib/gpu_linalg.py", + f"{source_file_prefix}jaxlib/gpu_rnn.py", + f"{source_file_prefix}jaxlib/gpu_triton.py", + f"{source_file_prefix}jaxlib/gpu_common_utils.py", + f"{source_file_prefix}jaxlib/gpu_solver.py", + f"{source_file_prefix}jaxlib/gpu_sparse.py", + f"{source_file_prefix}jaxlib/plugin_support.py", + f"{source_file_prefix}jaxlib/version.py", + f"{source_file_prefix}jaxlib/xla/xla_client.py", + f"{source_file_prefix}jaxlib/weakref_lru_cache.{pyext}", + f"{source_file_prefix}jaxlib/weakref_lru_cache.pyi", + f"{source_file_prefix}jaxlib/xla_extension.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing # type stubs. 
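   # An empty file suffices; per PEP-561, only the file's presence matters.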
with open(jaxlib_dir / "py.typed", "w"): pass - patch_copy_xla_extension_stubs(jaxlib_dir) + patch_copy_xla_extension_stubs( + jaxlib_dir, runfiles=r, wheel_sources_map=wheel_sources_map + ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "cpu", src_files=[ - f"__main__/jaxlib/cpu/_lapack.{pyext}", + f"{source_file_prefix}jaxlib/cpu/_lapack.{pyext}", ], ) mosaic_python_dir = jaxlib_dir / "mosaic" / "python" - copy_runfiles( + copy_files( dst_dir=mosaic_python_dir, src_files=[ - "__main__/jaxlib/mosaic/python/layout_defs.py", - "__main__/jaxlib/mosaic/python/mosaic_gpu.py", - "__main__/jaxlib/mosaic/python/tpu.py", + f"{source_file_prefix}jaxlib/mosaic/python/layout_defs.py", + f"{source_file_prefix}jaxlib/mosaic/python/mosaic_gpu.py", + f"{source_file_prefix}jaxlib/mosaic/python/tpu.py", ], ) # TODO (sharadmv,skyewm): can we avoid patching this file? patch_copy_mlir_import( - "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir + f"{source_file_prefix}jaxlib/mosaic/python/_tpu_gen.py", + dst_dir=mosaic_python_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu" os.makedirs(mosaic_gpu_dir) patch_copy_mlir_import( - "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py", + f"{source_file_prefix}jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py", dst_dir=mosaic_gpu_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) patch_copy_mlir_import( - "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py", + f"{source_file_prefix}jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py", dst_dir=mosaic_gpu_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir", src_files=[ - "__main__/jaxlib/mlir/ir.py", - "__main__/jaxlib/mlir/ir.pyi", - "__main__/jaxlib/mlir/passmanager.py", - "__main__/jaxlib/mlir/passmanager.pyi", + f"{source_file_prefix}jaxlib/mlir/ir.py", + f"{source_file_prefix}jaxlib/mlir/ir.pyi", + f"{source_file_prefix}jaxlib/mlir/passmanager.py", + f"{source_file_prefix}jaxlib/mlir/passmanager.pyi", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects", src_files=[ - "__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_func_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_math_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_ods_common.py", - "__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_sdy_enums_gen.py", - "__main__/jaxlib/mlir/dialects/_sdy_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py", - 
"__main__/jaxlib/mlir/dialects/arith.py", - "__main__/jaxlib/mlir/dialects/builtin.py", - "__main__/jaxlib/mlir/dialects/chlo.py", - "__main__/jaxlib/mlir/dialects/func.py", - "__main__/jaxlib/mlir/dialects/math.py", - "__main__/jaxlib/mlir/dialects/memref.py", - "__main__/jaxlib/mlir/dialects/mhlo.py", - "__main__/jaxlib/mlir/dialects/scf.py", - "__main__/jaxlib/mlir/dialects/sdy.py", - "__main__/jaxlib/mlir/dialects/sparse_tensor.py", - "__main__/jaxlib/mlir/dialects/stablehlo.py", - "__main__/jaxlib/mlir/dialects/vector.py", - "__main__/jaxlib/mlir/dialects/nvgpu.py", - "__main__/jaxlib/mlir/dialects/nvvm.py", - "__main__/jaxlib/mlir/dialects/llvm.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_arith_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_arith_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_builtin_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_chlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_func_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_math_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_memref_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_mhlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_ods_common.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_scf_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sdy_enums_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sdy_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_stablehlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_vector_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_vector_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_gpu_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_gpu_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvgpu_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvgpu_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvvm_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvvm_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/arith.py", + f"{source_file_prefix}jaxlib/mlir/dialects/builtin.py", + f"{source_file_prefix}jaxlib/mlir/dialects/chlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/func.py", + f"{source_file_prefix}jaxlib/mlir/dialects/math.py", + f"{source_file_prefix}jaxlib/mlir/dialects/memref.py", + f"{source_file_prefix}jaxlib/mlir/dialects/mhlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/scf.py", + f"{source_file_prefix}jaxlib/mlir/dialects/sdy.py", + f"{source_file_prefix}jaxlib/mlir/dialects/sparse_tensor.py", + f"{source_file_prefix}jaxlib/mlir/dialects/stablehlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/vector.py", + f"{source_file_prefix}jaxlib/mlir/dialects/nvgpu.py", + f"{source_file_prefix}jaxlib/mlir/dialects/nvvm.py", + f"{source_file_prefix}jaxlib/mlir/dialects/llvm.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "extras", src_files=[ - "__main__/jaxlib/mlir/extras/meta.py", + f"{source_file_prefix}jaxlib/mlir/extras/meta.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu", src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/__init__.py", + f"{source_file_prefix}jaxlib/mlir/dialects/gpu/__init__.py", ], ) - 
copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes", src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py", + f"{source_file_prefix}jaxlib/mlir/dialects/gpu/passes/__init__.py", ], ) mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs" - copy_runfiles( + copy_files( dst_dir=mlir_libs_dir, src_files=[ - "__main__/jaxlib/mlir/_mlir_libs/__init__.py", - f"__main__/jaxlib/_mlir.{pyext}", - f"__main__/jaxlib/_chlo.{pyext}", - f"__main__/jaxlib/_mlirHlo.{pyext}", - f"__main__/jaxlib/_mlirDialectsSparseTensor.{pyext}", - f"__main__/jaxlib/_mlirSparseTensorPasses.{pyext}", - f"__main__/jaxlib/_mosaic_gpu_ext.{pyext}", - f"__main__/jaxlib/_tpu_ext.{pyext}", - f"__main__/jaxlib/_sdy.{pyext}", - f"__main__/jaxlib/_stablehlo.{pyext}", - f"__main__/jaxlib/register_jax_dialects.{pyext}", - f"__main__/jaxlib/_mlirDialectsGPU.{pyext}", - f"__main__/jaxlib/_mlirDialectsLLVM.{pyext}", - f"__main__/jaxlib/_mlirDialectsNVGPU.{pyext}", - f"__main__/jaxlib/_mlirGPUPasses.{pyext}", + f"{source_file_prefix}jaxlib/mlir/_mlir_libs/__init__.py", + f"{source_file_prefix}jaxlib/_mlir.{pyext}", + f"{source_file_prefix}jaxlib/_chlo.{pyext}", + f"{source_file_prefix}jaxlib/_mlirHlo.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsSparseTensor.{pyext}", + f"{source_file_prefix}jaxlib/_mlirSparseTensorPasses.{pyext}", + f"{source_file_prefix}jaxlib/_mosaic_gpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/_tpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/_sdy.{pyext}", + f"{source_file_prefix}jaxlib/_stablehlo.{pyext}", + f"{source_file_prefix}jaxlib/register_jax_dialects.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsGPU.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsLLVM.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsNVGPU.{pyext}", + f"{source_file_prefix}jaxlib/_mlirGPUPasses.{pyext}", ] + ( [] if build_utils.is_windows() else [ - f"__main__/jaxlib/_triton_ext.{pyext}", - "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", + f"{source_file_prefix}jaxlib/_triton_ext.{pyext}", + f"{source_file_prefix}jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] ), ) triton_dir = jaxlib_dir / "triton" - copy_runfiles( + copy_files( dst_dir=triton_dir, src_files=[ - "__main__/jaxlib/triton/__init__.py", - "__main__/jaxlib/triton/dialect.py", + f"{source_file_prefix}jaxlib/triton/__init__.py", + f"{source_file_prefix}jaxlib/triton/dialect.py", ], ) patch_copy_mlir_import( - "__main__/jaxlib/triton/_triton_enum_gen.py", dst_dir=triton_dir + f"{source_file_prefix}jaxlib/triton/_triton_enum_gen.py", + dst_dir=triton_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) patch_copy_mlir_import( - "__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir + f"{source_file_prefix}jaxlib/triton/_triton_ops_gen.py", + dst_dir=triton_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) - copy_runfiles( - dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", - src_files=[ - "xla/xla/ffi/api/c_api.h", - "xla/xla/ffi/api/api.h", - "xla/xla/ffi/api/ffi.h", - ], + copy_files( + dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", + src_files=[ + "xla/xla/ffi/api/c_api.h", + "xla/xla/ffi/api/api.h", + "xla/xla/ffi/api/ffi.h", + ], ) tmpdir = None @@ -377,6 +437,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): prepare_wheel( pathlib.Path(sources_path), cpu=args.cpu, + wheel_sources=args.srcs, ) package_name = "jaxlib" if args.editable: From 7634230cdcd2d3cb42d1093f6ab255f47f9869d5 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 17 Apr 2025 
13:32:12 -0700 Subject: [PATCH 0678/1769] Remove unused jax_spmd_mode flag. PiperOrigin-RevId: 748792684 --- jax/__init__.py | 1 - jax/_src/config.py | 13 ------------- 2 files changed, 14 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 32ae955ae5b8..b0f20b0cb087 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -70,7 +70,6 @@ transfer_guard_host_to_device as transfer_guard_host_to_device, transfer_guard_device_to_device as transfer_guard_device_to_device, transfer_guard_device_to_host as transfer_guard_device_to_host, - spmd_mode as spmd_mode, ) from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval from jax._src.environment_info import print_environment_info as print_environment_info diff --git a/jax/_src/config.py b/jax/_src/config.py index aec8b1450fd0..288cfc64c027 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1016,19 +1016,6 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help='If True, pmap and shard_map API will be merged.') -spmd_mode = enum_state( - name='jax_spmd_mode', - enum_values=['allow_all', 'allow_jit'], - default='allow_jit', - help=("Decides whether Math on ``jax.Array`` objects that are not fully addressable " - "(i.e. spans across multiple processes) is allowed. The options are:\n\n" - "* ``allow_jit``: Default, ``pjit`` and ``jax.jit`` computations are allowed " - " to execute on non-fully addressable ``jax.Array`` objects\n" - "* ``allow_all``: ``jnp``, normal math (like ``a + b``, etc), ``pjit``, " - " ``jax.jit`` and all other operations are allowed to " - " execute on non-fully addressable ``jax.Array`` objects.")) - - distributed_debug = bool_state( name='jax_distributed_debug', default=False, From 23c973e4fafe03450e05c101803c371e743a897a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 17 Apr 2025 13:39:08 -0700 Subject: [PATCH 0679/1769] [pallas:mosaic] Replaced `device_type=` with `kernel_type` in `TPUCompilerParams` The `device_type` can be inferred from the `tpu.core_type` on the kernel. `kernel_type`, on the other hand, can also be used to define specialized lowering rules for scalar/vector subcores. PiperOrigin-RevId: 748794989 --- jax/_src/pallas/mosaic/core.py | 9 +++++++-- jax/_src/pallas/mosaic/pallas_call_registration.py | 1 - jax/_src/tpu_custom_call.py | 8 ++------ jax/experimental/pallas/tpu.py | 1 + 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 0fe825d44858..130e0eabb413 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -59,6 +59,12 @@ ) +class KernelType(enum.Enum): + TC = 0 + SC_SCALAR_SUBCORE = 1 + SC_VECTOR_SUBCORE = 2 + + @dataclasses.dataclass(frozen=True) class TPUCompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. @@ -79,7 +85,6 @@ class TPUCompilerParams(pallas_core.CompilerParams): Mosaic. flags: A dictionary of command line flags for the kernel. serialization_format: The serialization format for the kernel body. - device_type: The device type to compile for. disable_bounds_checks: Disable bounds checks in the kernel. """ PLATFORM: ClassVar[str] = "mosaic" @@ -93,7 +98,7 @@ class TPUCompilerParams(pallas_core.CompilerParams): flags: dict[str, Any] | None = None internal_scratch_in_bytes: int | None = None serialization_format: int = 1 - device_type: str | None = None + kernel_type: KernelType = KernelType.TC disable_bounds_checks: bool = False # Replace is a method, not a field. 
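
An illustrative sketch of how a kernel might opt into the new field, assuming
the `pltpu` re-exports added later in this patch (the `pallas_call` plumbing
is elided):

    import jax.experimental.pallas.tpu as pltpu

    # Target the vector subcore rather than the default TensorCore (TC) path.
    params = pltpu.TPUCompilerParams(
        kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE,
    )
    # ...then pass compiler_params=params to pl.pallas_call.
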
diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 824eb7e89716..66be2be76113 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -238,7 +238,6 @@ def _maybe_cast_inputs(*args): allow_input_fusion=mosaic_params.get("allow_input_fusion"), input_output_aliases=input_output_aliases, serialization_format=mosaic_params.get("serialization_format", 1), - device_type=mosaic_params.get("device_type"), internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), collective_id=mosaic_params.get("collective_id", None), has_side_effects=mosaic_params.get("has_side_effects", False), diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index d7a921f1d952..d9f746506d57 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -566,7 +566,6 @@ def _lower_to_custom_call_config( module: ir.Module, *, backend: str, - device_type: str | None, vmem_limit_bytes: int | None, cost_estimate: CostEstimate | None, flags: dict[str, bool | int | float] | None, @@ -579,6 +578,7 @@ def _lower_to_custom_call_config( ir_version: int | None = None, disable_bounds_checks: bool = False, ) -> CustomCallBackendConfig: + device_type = _get_device_type(module) lowered_module_asm, ( has_communication, has_custom_barrier, @@ -679,7 +679,6 @@ def lower_module_to_custom_call( has_side_effects: bool, serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None, - device_type: str | None, disable_bounds_checks: bool = False, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( @@ -691,7 +690,6 @@ def lower_module_to_custom_call( allow_input_fusion=allow_input_fusion, internal_scratch_in_bytes=internal_scratch_in_bytes, collective_id=collective_id, - device_type=device_type, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, @@ -728,11 +726,9 @@ def as_tpu_kernel( disable_bounds_checks: bool = False, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" - device_type = _get_device_type(module) config = _lower_to_custom_call_config( module, backend=backend, - device_type=device_type, vmem_limit_bytes=vmem_limit_bytes, cost_estimate=cost_estimate, flags=flags, @@ -761,7 +757,6 @@ def lowered_as_tpu_kernel( cost_estimate: CostEstimate | None = None, needs_hlo_passes: bool = False, needs_layout_passes: bool = False, - device_type: str | None = None, has_communication: bool = False, has_side_effects: bool = False, has_custom_barrier: bool = False, @@ -774,6 +769,7 @@ def lowered_as_tpu_kernel( internal_scratch_in_bytes: int | None = None, disable_bounds_checks: bool = False, ) -> Callable[..., Any]: + device_type = _get_device_type(lowered_module) lowered_module_asm = lowered_module.operation.get_asm( binary=True, enable_debug_info=True ) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 21976c47166b..5ed6968c673e 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -19,6 +19,7 @@ from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics +from jax._src.pallas.mosaic.core import KernelType as KernelType from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL 
 from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec
 from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType

From 6c4998e27650c3da10c5ac454b0233a31642fca9 Mon Sep 17 00:00:00 2001
From: Zac Mustin
Date: Thu, 17 Apr 2025 13:51:47 -0700
Subject: [PATCH 0680/1769] Remove roofline subtests so they work with pytest.

PiperOrigin-RevId: 748799657
---
 tests/roofline_test.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/tests/roofline_test.py b/tests/roofline_test.py
index 140beb3c6e71..9ed85c506814 100644
--- a/tests/roofline_test.py
+++ b/tests/roofline_test.py
@@ -467,13 +467,12 @@ def collective_matmul(a, b):
   def test_unary_ops(self, f, dtype):
     data = jnp.zeros((3, 8), dtype=dtype)
     out, result = roofline.roofline(f)(data)
-    with self.subTest("flops"):
-      self.assertEqual(result.unfused_flops, 3 * 8)
-    with self.subTest("hbm_bytes"):
-      self.assertEqual(
-        result.unfused_hbm_bytes,
-        data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8,
-      )
+
+    self.assertEqual(result.unfused_flops, 3 * 8)
+    self.assertEqual(
+        result.unfused_hbm_bytes,
+        data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8,
+    )

   def test_binary_ops(self):
     for f in [

From 3e585c68328a31a019458cce22a2b828ed6f4ffa Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Thu, 17 Apr 2025 15:54:54 -0700
Subject: [PATCH 0681/1769] jnp.frexp: add custom JVP rule for proper derivatives.

---
 jax/_src/numpy/ufuncs.py | 13 +++++++++++++
 tests/lax_numpy_test.py  | 11 ++++++++++-
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py
index 509b046554d3..c5a5d23764d1 100644
--- a/jax/_src/numpy/ufuncs.py
+++ b/jax/_src/numpy/ufuncs.py
@@ -3044,7 +3044,10 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
   x, = promote_dtypes_inexact(x)
   if dtypes.issubdtype(x.dtype, np.complexfloating):
     raise TypeError("frexp does not support complex-valued inputs")
+  return _frexp(x)

+@custom_jvp
+def _frexp(x):
   dtype = dtypes.dtype(x)
   info = dtypes.finfo(dtype)
   mask = (1 << info.nexp) - 1
@@ -3061,6 +3064,16 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
   return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)


+@_frexp.defjvp
+def _frexp_jvp(primals, tangents):
+  x, = primals
+  t, = tangents
+  m, e = frexp(x)
+  mdot = t * exp2(-e.astype(t.dtype))
+  edot = np.empty(e.shape, dtypes.float0)
+  return (m, e), (mdot, edot)
+
+
 @export
 @jit
 def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index de943c3b613a..6a88363c2780 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -6326,9 +6326,18 @@ def testGradLogaddexp2Complex(self, shapes, dtype):
   )
   def testGradLdexp(self, n, dtype):
     rng = jtu.rand_default(self.rng())
-    x = rng((), dtype)
+    x = rng((10,), dtype)
     check_grads(lambda x: jnp.ldexp(x, n), (x,), 1)

+  @jtu.sample_product(
+    n=range(-4, 5),
+    dtype=[jnp.float32, jnp.float64],
+  )
+  def testGradFrexp(self, n, dtype):
+    rng = jtu.rand_default(self.rng())
+    x = rng((10,), dtype) * 2 ** n
+    check_grads(lambda x: jnp.frexp(x)[0], (x,), 1)
+

 class NumpySignaturesTest(jtu.JaxTestCase):

From 62a46ada2b679064976d5ab8672a530b7ccb94a9 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 17 Apr 2025 15:55:11 -0700
Subject: [PATCH 0682/1769] Return an empty `wheel sources map` if `wheel_sources` is `None`.

PiperOrigin-RevId: 748839843 --- jaxlib/tools/build_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 1cba68e87fd4..bf64a36ef0b7 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -37,6 +37,8 @@ def is_windows() -> bool: def create_wheel_sources_map(wheel_sources, root_packages): """Returns a map of paths relative to the root package to the full paths.""" wheel_sources_map = {} + if not wheel_sources: + return wheel_sources_map for source in wheel_sources: for package in root_packages: if source.startswith("{}/".format(package)): From 474dcd409d6fa4c048014851922460f9d4fc199e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 17 Apr 2025 16:43:46 -0700 Subject: [PATCH 0683/1769] Remove code to support jaxlib < v0.6. New minimum jaxlib_extension_version is 330. PiperOrigin-RevId: 748853497 --- jax/_src/array.py | 3 - jax/_src/basearray.py | 10 +- jax/_src/callback.py | 134 ++++++------------ jax/_src/dispatch.py | 3 +- jax/_src/interpreters/mlir.py | 17 +-- jax/_src/interpreters/pxla.py | 6 +- jax/_src/lib/__init__.py | 6 +- jax/_src/profiler.py | 6 +- jax/_src/test_util.py | 5 - jax/experimental/jax2tf/call_tf.py | 8 +- .../mosaic/gpu/dialect_lowering.py | 13 +- .../mosaic/gpu/layout_inference.py | 13 +- jaxlib/weakref_lru_cache.pyi | 5 +- tests/array_test.py | 4 - tests/debugging_primitives_test.py | 6 - tests/experimental_rnn_test.py | 9 +- tests/mosaic/gpu_layout_inference_test.py | 4 - tests/pallas/mosaic_gpu_test.py | 9 -- tests/pjit_test.py | 3 - tests/python_callback_test.py | 5 - 20 files changed, 62 insertions(+), 207 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 760593da9fa9..a802d122a257 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -39,7 +39,6 @@ from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe -from jax._src.lib import jaxlib_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, @@ -1094,8 +1093,6 @@ def _get_aval_array(self): return core.update_aval_with_sharding(self.aval, self.sharding) core.pytype_aval_mappings[ArrayImpl] = _get_aval_array -if jaxlib_extension_version < 325: - basearray.Array.register(ArrayImpl) def _array_mlir_constant_handler(val): try: diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index 6cd60deda3b0..01a988782671 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -16,12 +16,10 @@ from __future__ import annotations -import abc from collections.abc import Sequence import sys from typing import Any, Union -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.util import use_cpp_class import numpy as np @@ -173,13 +171,7 @@ def copy_to_host_async(self): raise NotImplementedError -if jaxlib_extension_version >= 325: - Array = use_cpp_class(xc.Array)(Array) -else: - class Array(Array, metaclass=abc.ABCMeta): - ... 
- - +Array = use_cpp_class(xc.Array)(Array) Array.__module__ = "jax" diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 25bdb801edce..a44fb6fb2783 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -32,7 +32,6 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb -from jax._src.lib import jaxlib_extension_version from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -788,9 +787,6 @@ def emit_python_callback( if platform not in {"cpu", "cuda", "rocm"}: raise ValueError( f"Partitioned callback not supported on {platform} backend.") - if jaxlib_extension_version < 329: - raise ValueError( - "Partitioned callback not supported on jaxlib version < 329.") if result_avals: raise ValueError("Partitioned callback not supported with return values.") backend = ctx.module_context.get_backend() @@ -855,98 +851,48 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None - # TODO(dsuo): Remove this once we bump minimum_jaxlib_version to "0.5.4". - if jaxlib_extension_version <= 320: - result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) - if token: - - callback_without_token = _wrapped_callback - def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined - return (token, *callback_without_token(*args)) - - operand_shapes = [ - _aval_to_xla_shape(core.abstract_token), *operand_shapes - ] - result_shapes = [ - _aval_to_xla_shape(core.abstract_token), *result_shapes - ] - operands = [token, *operands] - result_types = [mlir.token_type(), *result_types] - operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] - result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, ifrt_callback = ( - backend.get_emit_python_callback_descriptor(_wrapped_callback, - operand_shapes, - result_shapes)) - ctx.module_context.add_host_callback(ifrt_callback) - descriptor_operand = mlir.ir_constant(callback_descriptor) - callback_operands = [descriptor_operand, *operands] - if operand_mlir_layouts is not None: - operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] - result_type = ir.TupleType.get_tuple(result_types) - call_target_name = ("xla_python_gpu_callback" - if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = hlo.CustomCallOp( - [result_type], - callback_operands, - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - api_version=mlir.i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(str(callback_descriptor)), - operand_layouts=( - None if operand_mlir_layouts is None - else ir.ArrayAttr.get(operand_mlir_layouts)), - result_layouts=( - None if result_mlir_layouts is None - else ir.ArrayAttr.get(result_mlir_layouts))) - if sharding is not None: - mlir.set_sharding(result, sharding) - results = [ - hlo.get_tuple_element(result, mlir.i32_attr(i)) - for i in range(len(result_types)) - ] - else: - device = "gpu" if platform in {"cuda", "rocm"} else "cpu" - partition = "_partitioned" if partitioned else "" - call_target_name = f"xla_ffi{partition}_python_{device}_callback" - if token: - callback_without_token = _wrapped_callback - def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined - return (token, *callback_without_token(*args)) - operands = [token, 
*operands] - if ( - config.use_shardy_partitioner.value - and sharding is not None - and len(ctx.avals_out) > 0 - and isinstance(sharding, sharding_impls.SdyArrayShardingList) - ): - # Add a sharding annotation for the token if we have at least one - # output. Otherwise, the single shardy annotation required of all ops - # (even those without any results) can annotate the token. - sharding = sharding_impls.SdyArrayShardingList( - [*sharding.shardings, sharding.shardings[-1]] - ) - ctx = dataclasses.replace( - ctx, - avals_in=[core.abstract_token, *ctx.avals_in], - avals_out=[core.abstract_token, *ctx.avals_out], + device = "gpu" if platform in {"cuda", "rocm"} else "cpu" + partition = "_partitioned" if partitioned else "" + call_target_name = f"xla_ffi{partition}_python_{device}_callback" + if token: + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + operands = [token, *operands] + if ( + config.use_shardy_partitioner.value + and sharding is not None + and len(ctx.avals_out) > 0 + and isinstance(sharding, sharding_impls.SdyArrayShardingList) + ): + # Add a sharding annotation for the token if we have at least one + # output. Otherwise, the single shardy annotation required of all ops + # (even those without any results) can annotate the token. + sharding = sharding_impls.SdyArrayShardingList( + [*sharding.shardings, sharding.shardings[-1]] ) + ctx = dataclasses.replace( + ctx, + avals_in=[core.abstract_token, *ctx.avals_in], + avals_out=[core.abstract_token, *ctx.avals_out], + ) - # TODO(dsuo): Remove this line once we deprecate the XLA custom call - # handler. - ifrt_callback = _wrapped_callback - ctx.module_context.add_host_callback(ifrt_callback) - index = np.uint64(len(ctx.module_context.host_callbacks) - 1) - result = ffi.build_ffi_lowering_function( # type: ignore - call_target_name, - has_side_effect=has_side_effect, - )(ctx, *operands, index=np.uint64(index)) + # TODO(dsuo): Remove this line once we deprecate the XLA custom call + # handler. 
+ ifrt_callback = _wrapped_callback + ctx.module_context.add_host_callback(ifrt_callback) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + result = ffi.build_ffi_lowering_function( # type: ignore + call_target_name, + has_side_effect=has_side_effect, + )(ctx, *operands, index=np.uint64(index)) - if sharding is not None: - mlir.set_sharding(result, sharding) + if sharding is not None: + mlir.set_sharding(result, sharding) + + results = result.results # type: ignore - results = result.results # type: ignore if token: - token, *results = results - return results, token, ifrt_callback + token, *results = results # type: ignore + + return results, token, ifrt_callback # type: ignore diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index eea687145c0e..6d9b79b9b58b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -44,7 +44,6 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, Layout -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.mesh import AbstractMesh, Mesh from jax._src.monitoring import record_event_duration_secs, record_event_time_span @@ -496,7 +495,7 @@ def _device_put_sharding_impl(x, aval, device, copy): return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) elif is_single_device_sharding(x.sharding): device = x.sharding._device_assignment[0] if device is None else device - if copy == CopySemantics.COPY and jaxlib_extension_version >= 327: + if copy == CopySemantics.COPY: return xc.batched_device_put(aval, SingleDeviceSharding(device), [x], [device], True, True) return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 9979ea151b76..80cbf76242aa 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -585,11 +585,7 @@ def module_to_bytecode(module: ir.Module) -> bytes: # Create one global thread pool that can be shared between multiple ir.Contexts # and enabling multi-threading -# TODO: remove this check after jaxlib 0.5.4 -if hasattr(ir, "ThreadPool"): - global_thread_pool = ir.ThreadPool() -else: - global_thread_pool = None +global_thread_pool = ir.ThreadPool() class JaxIrContext(ir.Context): @@ -606,16 +602,7 @@ def make_ir_context() -> ir.Context: context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() - # TODO: remove this check after v0.5.4 jaxlib - if global_thread_pool is not None: - context.set_thread_pool(global_thread_pool) - else: - # If threading is enabled, each MLIR context will keep alive a thread pool. - # Since we cache MLIR modules (and hence contexts), this means we might keep - # several threads alive for each cache entry. This is a terrible idea. However - # we don't do any heavy computation on MLIR modules from Python anyway, so we - # just disable threading. 
- context.enable_multithreading(False) + context.set_thread_pool(global_thread_pool) dialects.sdy.register_dialect(context) dialects.mhlo.register_mhlo_dialect(context) dialects.chlo.register_dialect(context) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index d5a18ad2f439..815183b6ebd4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -57,7 +57,6 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla -from jax._src.lib import jaxlib_extension_version from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir @@ -1315,10 +1314,7 @@ def __call__(self, *args): out_ = [] for i, o in zip(self.mut.out_mut, out): if i is not None: - if jaxlib_extension_version < 330: - args[i]._buf = o - else: - args[i]._buf._replace_with(o) + args[i]._buf._replace_with(o) else: out_.append(o) return out_ diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index b011bf0084d4..5d65e005d897 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -102,11 +102,7 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_extension_version: int = getattr(xla_client, '_version', 0) ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) -# TODO(phawkins): remove type: ignore once the minimum jaxlib is bumped. -if jaxlib_extension_version >= 328: - import jaxlib.weakref_lru_cache as weakref_lru_cache # type: ignore # noqa: F401 -else: - weakref_lru_cache = xla_extension # type: ignore # noqa: F401 +import jaxlib.weakref_lru_cache as weakref_lru_cache # noqa: F401 # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 912c90182977..c787ea4c0223 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -33,7 +33,6 @@ from jax._src import xla_bridge from jax._src.lib import xla_client -from jax._src.lib import version as jaxlib_version _profiler_server: xla_client.profiler.ProfilerServer | None = None @@ -425,10 +424,7 @@ def trace(cls, runner: PGLEProfiler | None): else: options = xla_client.profiler.ProfileOptions() options.enable_hlo_proto = True - - # ToDo(patrios): Remove when jaxlib version is updated to 0.5.4. 
- if jaxlib_version > (0, 5, 3): - options.raise_error_on_start_failure = True + options.raise_error_on_start_failure = True runner.current_session = xla_client.profiler.ProfilerSession(options) try: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index caff1c73145b..37a73eb148b3 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -54,7 +54,6 @@ from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir -from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 @@ -370,8 +369,6 @@ def supported_dtypes(): _dtypes.bfloat16, np.float16, np.float32, np.complex64, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e5m2} - if jaxlib_extension_version < 327: - types -= {_dtypes.int4, _dtypes.uint4} elif device_under_test() == "gpu": types = {np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -385,8 +382,6 @@ def supported_dtypes(): _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64, _dtypes.bfloat16, np.float16, np.float32, np.float64, np.complex64, np.complex128} - if jaxlib_extension_version < 327: - types -= {_dtypes.int4, _dtypes.uint4} if not config.enable_x64.value: types -= {np.uint64, np.int64, np.float64, np.complex128} return types diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index bb2af54025bc..73b7544f991a 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -40,9 +40,7 @@ from jax._src import core from jax._src import effects from jax._src import util -from jax._src.lib import xla_client from jax._src.lib import xla_extension as _xla -from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo @@ -598,11 +596,7 @@ def convert_to_spec(x): "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e - if jaxlib_extension_version >= 324: - stablehlo = _xla.mlir.hlo_to_stablehlo(func_tf_hlo) - else: - xla_comp = xla_client.XlaComputation(func_tf_hlo) - stablehlo = _xla.mlir.xla_computation_to_mlir_module(xla_comp) + stablehlo = _xla.mlir.hlo_to_stablehlo(func_tf_hlo) submodule = ir.Module.parse(stablehlo) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 1239a20ba865..0ee33b4bfa92 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -22,7 +22,6 @@ import operator from typing import Any, Sequence, Type, cast -from jax._src import lib as jaxlib from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir @@ -470,13 +469,11 @@ def _vector_reduction_op_lowering_rule( return [_fragmented_array_to_ir(result, op.result.type)] -# TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. 
-if jaxlib.version >= (0, 5, 4): - @_register_lowering(mgpu.LayoutCastOp) - def _mgpu_layout_cast_op_lowering_rule( - _: LoweringContext, layout_cast_op: mgpu.LayoutCastOp - ) -> Sequence[ir.Value]: - return [layout_cast_op.x] +@_register_lowering(mgpu.LayoutCastOp) +def _mgpu_layout_cast_op_lowering_rule( + _: LoweringContext, layout_cast_op: mgpu.LayoutCastOp +) -> Sequence[ir.Value]: + return [layout_cast_op.x] def swizzle_and_transforms_from_transforms_attr( diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index c9c565f331c9..ee7571b1db72 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -21,7 +21,6 @@ import math from typing import cast -from jax._src import lib as jaxlib from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -444,13 +443,11 @@ def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts: return None -# TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. -if jaxlib.version >= (0, 5, 4): - @partial(_add_layout_inference_rule, mgpu.LayoutCastOp) - def _infer_layout_cast_op_layout( - layout_cast_op: mgpu.LayoutCastOp, - ) -> OptionalLayouts: - return [layout_cast_op.new_layout], [layout_cast_op.new_layout] +@partial(_add_layout_inference_rule, mgpu.LayoutCastOp) +def _infer_layout_cast_op_layout( + layout_cast_op: mgpu.LayoutCastOp, +) -> OptionalLayouts: + return [layout_cast_op.new_layout], [layout_cast_op.new_layout] @partial(_add_layout_inference_rule, mgpu.WGMMAOp) diff --git a/jaxlib/weakref_lru_cache.pyi b/jaxlib/weakref_lru_cache.pyi index 5b91ba1f5fc5..ed965d7be811 100644 --- a/jaxlib/weakref_lru_cache.pyi +++ b/jaxlib/weakref_lru_cache.pyi @@ -14,10 +14,11 @@ # ============================================================================== from collections.abc import Callable +from typing import Any class WeakrefLRUCache: - def __call__(self, arg0: object, /, *args, **kwargs) -> object: ... - def cache_keys(self) -> list[object]: ... + def __call__(self, arg0: Any, /, *args, **kwargs) -> Any: ... + def cache_keys(self) -> list[Any]: ... def cache_info(self) -> WeakrefLRUCache.WeakrefLRUCacheInfo: ... def cache_clear(self) -> None: ... 
diff --git a/tests/array_test.py b/tests/array_test.py index 1780213bcc61..4e01eb33841e 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -29,7 +29,6 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip from jax._src.mesh import AxisType, AbstractMesh @@ -1430,9 +1429,6 @@ def test_make_mesh_axis_types(self): self.assertNotEqual(hash(mesh1), hash(mesh2)) def test_memory_kind_with_abstract_mesh(self): - if jaxlib_extension_version < 326: - self.skipTest('Requires jaxlib_extension_version >= 326') - abstract_mesh = AbstractMesh((2,), ('x',)) ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host') self.assertEqual(ns.memory_kind, 'pinned_host') diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index d9d50a546e57..becd18033d6d 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -25,7 +25,6 @@ from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu -from jax._src.lib import jaxlib_extension_version import jax.numpy as jnp import numpy as np @@ -1222,11 +1221,6 @@ def setUp(self): raise unittest.SkipTest( f"Test requires CPU or GPU devices. Got {jtu.device_under_test()}" ) - if jaxlib_extension_version < 329: - self.skipTest( - "Requires jaxlib_extension_version >= 329. Got" - f" {jaxlib_extension_version}." - ) if len(jax.devices()) < 2: raise unittest.SkipTest("Test requires >= 2 devices.") diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 7fa3b93f3c42..58f5291e9375 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -213,18 +213,11 @@ def f(k1, k2, k3, k4): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - if jtu.jaxlib_version() <= (0, 5, 2): - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', - stablehlo) - else: - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', + self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', stablehlo) @jtu.run_on_devices("cuda") def test_no_workspace_overflow(self): - if jtu.jaxlib_version() <= (0, 5, 2): - self.skipTest("Older versions fail because of integer overflow.") - # Problem sizes known to cause overflows on older versions. 
batch_size, max_seq_length, input_size = 256, 500, 512 num_layers, hidden_size = 1, 256 diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 104b088bbdd2..5355adfb2c7b 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -19,7 +19,6 @@ from absl.testing import parameterized import jax from jax._src import config -from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir @@ -214,9 +213,6 @@ def body(lhs, rhs): self.assertSequenceEqual(add.attributes["out_layouts"], [layout_attr]) def test_infer_layout_cast_layout(self): - # TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. - if jaxlib.version < (0, 5, 4): - self.skipTest("Test requires jaxlib version >= 0.5.4") add = cast = None shape = (128, 64) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 3d6a853e7461..6e30b146a606 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -26,7 +26,6 @@ from absl.testing import parameterized import jax from jax import lax -from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.pallas import pallas_call from jax._src.pallas import core as pallas_core @@ -1399,10 +1398,6 @@ def rotate(src, dst): np.testing.assert_array_equal(f(x), expected) def test_layout_cast(self, shape=(256, 64)): - # TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. - if jaxlib.version < (0, 5, 4): - self.skip_if_wg_semantics() - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), @@ -1621,10 +1616,6 @@ def test_missing_primitive_lowerings_are_tracked(self): pallas_core.core_map_p, } - # TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. 
- if jaxlib.version < (0, 5, 4): - expected_missing_primitives.add(mgpu_primitives.layout_cast_p) - self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5b69a415de1b..391dd39a9dd2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -62,7 +62,6 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import jaxlib_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -1413,8 +1412,6 @@ def test_zero_literal_equality(self): self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) def test_device_put_copy_donate(self): - if jaxlib_extension_version < 327: - raise unittest.SkipTest("Copy not supported in device put.") x = np.arange(1000) y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False) z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index e4aeb8d66f9e..0d78730b8ca8 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -28,7 +28,6 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util -from jax._src.lib import jaxlib_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -588,8 +587,6 @@ def fun(x): @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") def get(x): return x def f(x): @@ -613,8 +610,6 @@ def f(x): @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") def get(): return np.arange(8, dtype=dtype) From 7de522c5a3442a202eb304cb62023807f98c5e09 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 17 Apr 2025 17:41:30 -0700 Subject: [PATCH 0684/1769] Enter into auto mode for `.at[...].get(...)` a bit earlier so that all ops inside `_gather` are in auto mode. Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`. 
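For illustration, a minimal sketch of the pattern this enables (the mesh
setup here is hypothetical and assumes >= 4 devices; the exact repros are
the new test_select_batch and test_where_vmap tests added below):

    import jax
    import jax.numpy as jnp
    from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

    # Assumed setup: a 2x2 mesh with both axes in Explicit mode.
    mesh = jax.make_mesh((2, 2), ('x', 'y'),
                         axis_types=(AxisType.Explicit, AxisType.Explicit))
    with jax.sharding.use_mesh(mesh):
      xy = NamedSharding(mesh, P('x', 'y', None))
      a = jax.device_put(jnp.ones((4, 2, 3), jnp.float32), xy)
      m = jax.device_put(jnp.ones((4, 2, 3), dtype=bool), xy)

      # select_n's batching rule now threads axis_data.explicit_mesh_axis
      # into the broadcast done by bdim_at_front, so the vmapped where
      # keeps the explicit ('x', 'y', None) sharding of its inputs.
      out = jax.jit(jax.vmap(lambda a, m: jnp.where(m, a, 0.)))(a, m)
      assert out.sharding.spec == xy.spec
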
PiperOrigin-RevId: 748867490 --- jax/_src/interpreters/batching.py | 4 +-- jax/_src/lax/lax.py | 7 +++--- jax/_src/numpy/indexing.py | 29 +++++++++++----------- tests/BUILD | 4 +-- tests/pjit_test.py | 41 +++++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 22 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index c97c8d558608..ee0b46feddb7 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -1153,9 +1153,9 @@ def __init__(self, leaf_idx, src, dst): self.src = src self.dst = dst -def bdim_at_front(x, bdim, size): +def bdim_at_front(x, bdim, size, mesh_axis=None): if bdim is not_mapped: - return broadcast(x, size, 0) + return broadcast(x, size, 0, mesh_axis=mesh_axis) else: return moveaxis(x, bdim, 0) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 79bd42607290..36b19b47a324 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -7323,7 +7323,7 @@ def _select_transpose_rule(t, which, *cases): if ad.is_undefined_primal(case) else None for i, case in enumerate(cases) ] -def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): +def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): which, *cases = batched_args which_bdim, *case_bdims = batch_dims size = next(x.shape[i] for x, i in zip(batched_args, batch_dims) @@ -7350,7 +7350,7 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): which = (batching.bdim_at_front(which, which_bdim, size) if np.shape(which) else which) if not all(() == np.shape(c) for c in cases): - cases = [batching.bdim_at_front(c, bdim, size) + cases = [batching.bdim_at_front(c, bdim, size, axis_data.explicit_mesh_axis) for c, bdim in zip(cases, case_bdims)] assert all(np.shape(cases[0]) == np.shape(c) for c in cases[1:]) if 0 < np.ndim(which) < np.ndim(cases[0]): @@ -7440,7 +7440,8 @@ def _select(offset, cases): vma_rule=partial(core.standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule -batching.primitive_batchers[select_n_p] = _select_batch_rule +batching.fancy_primitive_batchers[select_n_p] = _select_batch_rule +batching.skippable_batchers[select_n_p] = lambda _: () mlir.register_lowering(select_n_p, _select_hlo_lowering) pe.def_trivial_padding(select_n_p) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 6982cc4080e6..21a97277e433 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -28,7 +28,6 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import errors -from jax._src import mesh as mesh_lib from jax._src.api import jit from jax._src.lax import lax as lax_internal from jax._src.numpy import einsum @@ -637,14 +636,21 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, return lax.dynamic_index_in_dim(arr, idx, keepdims=False) treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape) - return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, - unique_indices, mode, fill_value, out_sharding) + internal_gather = partial( + _gather, treedef=treedef, static_idx=static_idx, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode, fill_value=fill_value) + if out_sharding is not None: + return auto_axes(internal_gather, out_shardings=out_sharding + )(arr, dynamic_idx) + return internal_gather(arr, dynamic_idx) + # TODO(phawkins): re-enable jit after fixing excessive 
recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). # @partial(jit, static_argnums=(1, 2)) -def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, - unique_indices, mode, fill_value, out_sharding): +def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted, + unique_indices, mode, fill_value): idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) @@ -668,26 +674,19 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, # We avoid generating a gather when indexer.gather_indices.size is empty. if not core.is_empty_shape(indexer.gather_indices.shape): - internal_gather = partial( - lax.gather, - dimension_numbers=indexer.dnums, - slice_sizes=indexer.gather_slice_shape, + y = lax.gather( + y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape, unique_indices=unique_indices or indexer.unique_indices, indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted, mode=mode, fill_value=fill_value) - if out_sharding is not None: - internal_gather = auto_axes( - internal_gather, axes=mesh_lib.get_abstract_mesh().axis_names, - out_shardings=out_sharding) - y = internal_gather(y, indexer.gather_indices) # Reverses axes with negative strides. if indexer.reversed_y_dims: y = lax.rev(y, indexer.reversed_y_dims) - # This adds np.newaxis/None dimensions. return lax.expand_dims(y, indexer.newaxis_dims) + class _Indexer(NamedTuple): # The expected shape of the slice output. slice_shape: Sequence[int] diff --git a/tests/BUILD b/tests/BUILD index c46ea7556a44..5f142c097889 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -302,8 +302,8 @@ jax_multiplatform_test( "gpu_h100x2", ], shard_count = { - "cpu": 3, - "tpu": 4, + "cpu": 5, + "tpu": 5, }, tags = ["multiaccelerator"], deps = [ diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 391dd39a9dd2..27c5a053a638 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7447,6 +7447,47 @@ def step(carry, x): xs = jax.device_put(xs, sharding) scan(xs) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_select_batch(self, mesh): + y_sharding = NamedSharding(mesh, P('y', None)) + xy_sharding = NamedSharding(mesh, P('x', 'y', None)) + batch_a = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.float32), xy_sharding) + batch_b = jax.device_put(jnp.ones((4, 2, 2), dtype=jnp.int32), xy_sharding) + + out_s = NamedSharding(mesh, P('x', 'y', None, None)) + + def select(a, b): + c = a.at[b].get(out_sharding=y_sharding) + return c + + @jax.jit + def vmap_select(batch_a, batch_b): + out = jax.vmap(select)(batch_a, batch_b) + self.assertEqual(out.aval.sharding.spec, out_s.spec) + return out + + out = vmap_select(batch_a, batch_b) + self.assertEqual(out.sharding, out_s) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_where_vmap(self, mesh): + xy_sharding = NamedSharding(mesh, P('x', 'y', None)) + batch_a = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.float32), xy_sharding) + batch_b = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.bool), xy_sharding) + + def where(a, b): + out = jnp.where(b, a, 0) + return out + + @jax.jit + def vmap_where(batch_a, batch_b): + out = jax.vmap(where)(batch_a, batch_b) + self.assertEqual(out.aval.sharding.spec, xy_sharding.spec) + return out + + out = vmap_where(batch_a, batch_b) + self.assertEqual(out.sharding, xy_sharding) + 
@jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From a2ebdf6d71004a7d108f0fa9005e3b0f7adeb85f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 17 Apr 2025 18:40:39 -0700 Subject: [PATCH 0685/1769] Rename `with_user_mesh` to `with_explicit_mesh` PiperOrigin-RevId: 748880870 --- jax/_src/test_util.py | 2 +- tests/error_check_test.py | 4 +- tests/pjit_test.py | 162 +++++++++++++++++++------------------- tests/shard_map_test.py | 4 +- 4 files changed, 86 insertions(+), 86 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 37a73eb148b3..8cd0b0d7f6f4 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1415,7 +1415,7 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) -def with_user_mesh(sizes, names, axis_types=None): +def with_explicit_mesh(sizes, names, axis_types=None): axis_types = ((mesh_lib.AxisType.Explicit,) * len(names) if axis_types is None else axis_types) def decorator(fn): diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 0c77989b8a43..a7eeb4dbf86b 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -225,7 +225,7 @@ def f(x): jax.jit(error_check.raise_if_error)() @parameterized.product(jit=[True, False]) - @jtu.with_user_mesh((2, 2), ("x", "y")) + @jtu.with_explicit_mesh((2, 2), ("x", "y")) def test_error_check_explicit_mode(self, mesh, jit): def f(x): error_check.set_error_if(x <= 0, "x must be greater than 0") @@ -254,7 +254,7 @@ def f(x): error_check.raise_if_error() @parameterized.product(jit=[True, False]) - @jtu.with_user_mesh( + @jtu.with_explicit_mesh( (2, 2), ("x", "y"), axis_types=(mesh_lib.AxisType.Auto, mesh_lib.AxisType.Auto), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 27c5a053a638..53f0d042e801 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4958,7 +4958,7 @@ def check_wsc_in_lowered(self, text): else: self.assertIn('@Sharding', text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_basic_mul(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5003,7 +5003,7 @@ def g(x): jax.jit(jax.grad(g)).lower(sds) # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_fully_replicated_array_mul(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5052,7 +5052,7 @@ def g(x, y): ('fsdp', P('x', None), P('x', None), P('x', None), 'all-gather'), ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) @@ -5094,7 +5094,7 @@ def g(x, y): self.assertEqual(out[1].sharding, arr2.sharding) @parameterized.parameters([True, False]) - @jtu.with_user_mesh((4,), ('x',)) + @jtu.with_explicit_mesh((4,), ('x',)) def test_dot_general_out_sharding(self, use_jit, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) @@ -5150,7 +5150,7 @@ def f(x, y): ('other_half_tp', P(None, 'y'), P('y', None), 'Contracting dimensions are sharded', core.ShardingTypeError), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general_error(self, spec1, 
spec2, error_msg, error_type, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) @@ -5163,7 +5163,7 @@ def f(x, y): with self.assertRaisesRegex(error_type, error_msg): f(arr1, arr2) - @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_dot_general_batch_error(self, mesh): arr1 = jax.device_put(np.ones((8, 4, 2)), NamedSharding(mesh, P('x', 'y', 'z'))) @@ -5182,7 +5182,7 @@ def test_dot_general_batch_error(self, mesh): ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) - @jtu.with_user_mesh((2, 2), ('model', 'data')) + @jtu.with_explicit_mesh((2, 2), ('model', 'data')) def test_aval_repr(self, mesh): mesh = mesh.abstract_mesh aval = core.ShapedArray((128, 64), np.float32, @@ -5201,7 +5201,7 @@ def test_aval_repr(self, mesh): aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), None))) self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]') - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_jnp_ones_mesh_context_eager(self, mesh): s = NamedSharding(mesh, P('x', None)) out = jnp.ones((8, 2), dtype=jnp.int32, device=s) @@ -5218,7 +5218,7 @@ def test_jnp_ones_mesh_context_eager(self, mesh): ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, in_spec) @@ -5249,7 +5249,7 @@ def f(x): ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, in_spec) @@ -5290,7 +5290,7 @@ def g(x): ('2', 2, P('x', 'y', None)), ('-1', -1, P('x', 'y', None)), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_broadcast_in_dim(self, axis, out_spec, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5316,7 +5316,7 @@ def f(x): ('3', 3), ('4', 4), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_integer_pow(self, pow, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5335,7 +5335,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((1,), 'x') + @jtu.with_explicit_mesh((1,), 'x') def test_broadcasting_nary_error(self, mesh): mesh2 = Mesh([jax.devices()[0]], 'y', axis_types=(mesh_lib.AxisType.Explicit,)) @@ -5351,7 +5351,7 @@ def f(x, y): ValueError, "For primitive.*context mesh.*aval mesh"): f(arr1, arr2) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_sin_unop(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5369,7 +5369,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_jnp_array(self, mesh): np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5385,7 +5385,7 @@ def f(x): f(arr) - 
@jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_lax_transpose_rule(self, mesh): np_inp = np.arange(16).reshape(4, 2, 2) s = NamedSharding(mesh, P('x', 'y', 'z')) @@ -5404,7 +5404,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_broadcasted_iota_with_sharding(self, mesh): np_inp = np.arange(4) s = NamedSharding(mesh, P('x')) @@ -5429,7 +5429,7 @@ def g(x): _, out = g(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_einsum_with_out_sharding(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5474,7 +5474,7 @@ def h2(x, y): self.assertEqual(out[0].sharding, arr3.sharding) self.assertEqual(out[1].sharding, arr4.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_einsum_inverse(self, mesh): np_inp = np.arange(64.) @@ -5528,7 +5528,7 @@ def h2(x, y): ('11', (1024, 2048, 2, 1, 1, 1), (1024, 4096), P(None, 'x', None, None, None, None), P(None, 'x'), False), ) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, use_sharding_arg, mesh): np_inp = np.arange(math.prod(src_shape), @@ -5611,7 +5611,7 @@ def g(x): P(None, 'y', None, 'x'), None, 'This reshape is not supported' ), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshape_split_merge_one_axis(self, src_shape, dst_shape, src_spec, dst_spec, error_msg, mesh): np_inp = np.arange(math.prod(src_shape), @@ -5643,7 +5643,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_select(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5712,7 +5712,7 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_mesh_cast_reshard_error(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5739,7 +5739,7 @@ def g(x): ' mesh and the target mesh'): g(arr) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit, AxisType.Auto)) def test_mesh_cast_explicit_data_movement_error(self, mesh): np_inp = np.arange(16).reshape(8, 2) @@ -5756,7 +5756,7 @@ def f(x): ValueError, 'Explicit data movement in mesh_cast is not allowed'): f(arr) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shard_map_full_manual(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5782,7 +5782,7 @@ def f(x, y): self.assertArraysEqual(out, (np_inp * np_inp) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shard_map_dot(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5810,7 +5810,7 @@ def f(x, y): self.assertArraysEqual(out, (np_inp @ 
np_inp.T) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_slice(self, mesh): np_inp = np.arange(16.).reshape(4, 4) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5841,7 +5841,7 @@ def g(x): with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_squeeze(self, mesh): np_inp = np.arange(16.).reshape(4, 4, 1) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) @@ -5867,7 +5867,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_pad(self, mesh): np_inp = np.arange(8.) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) @@ -5909,7 +5909,7 @@ def g(x): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_concatenate(self, mesh): np_inp = np.arange(16.).reshape(4, 4) s = NamedSharding(mesh, P('x', 'y')) @@ -5951,7 +5951,7 @@ def g(x, y): out = jax.jit(jax.grad(g))(arr1, arr2) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_scan(self, mesh): carry = jax.device_put(np.arange(16.).reshape(2, 8), NamedSharding(mesh, P(None, 'x'))) @@ -5989,7 +5989,7 @@ def g(carry, arr): ValueError, "0th dimension of all xs should be replicated"): f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_argminmax(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6010,7 +6010,7 @@ def f(x): self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) self.check_wsc_in_lowered(f.lower(arr).as_text()) - @jtu.with_user_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2) def test_only_auto(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -6082,7 +6082,7 @@ def f(x, x2): "AxisTypes should be the same in a tuple subset of PartitionSpec"): NamedSharding(mesh2, P(('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_where_with_scalar(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6092,7 +6092,7 @@ def test_where_with_scalar(self, mesh): self.assertArraysEqual(out, x) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_user_to_full_auto(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6119,7 +6119,7 @@ def f(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_full_user(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6144,7 +6144,7 @@ def f(x): jaxpr = f.trace(arr).jaxpr 
core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_user_to_auto_user_mix(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6171,7 +6171,7 @@ def f(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_user_auto_mix_error(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6188,7 +6188,7 @@ def f(x, y): ValueError, "For primitive dot_general, context mesh.*aval mesh"): f(arr, arr.T) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_split(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6217,7 +6217,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_return_output_different_context(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x')) @@ -6236,7 +6236,7 @@ def f(x): self.assertDictEqual(out.sharding.mesh._axis_types_dict, {AxisType.Auto: ('x',)}) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_device_put_use_mesh(self, mesh): out = jax.device_put(np.arange(8), P('x')) self.assertArraysEqual(out, np.arange(8)) @@ -6249,7 +6249,7 @@ def test_device_put_no_use_mesh_error(self): ' passed to device_put'): jax.device_put(np.arange(8), P('x')) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_inputs_different_context(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x')) @@ -6271,7 +6271,7 @@ def f(x, y): self.assertDictEqual(out2.sharding.mesh._axis_types_dict, {AxisType.Auto: ('x',)}) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_output_different_context_error(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) @@ -6298,7 +6298,7 @@ def g(x, y): ValueError, "PartitionSpec.*cannot contain axis names.*Auto"): g(arr1, arr2) - @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'), + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z'), axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Auto)) def test_out_sharding_mix_axis_types(self, mesh): @@ -6323,7 +6323,7 @@ def f(x): else: self.assertTrue(lowered_text.count("unspecified_dims=[1,2]") == 3) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_auto_mode_mix(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6352,7 +6352,7 @@ def g(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((4,), ('x',)) + @jtu.with_explicit_mesh((4,), ('x',)) def test_concat_vmap(self, mesh): @jax.jit def _f(sharded_array, replicated_array): @@ -6383,7 +6383,7 @@ def test_aval_spec_explicit_auto_complete(self): out = core.ShapedArray((8, 2), jnp.int32, sharding=s) self.assertEqual(out.sharding.spec, P('x', None)) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_user_mode(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6413,7 +6413,7 @@ def f(x): jaxpr = f.trace(arr).jaxpr 
core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) @@ -6437,7 +6437,7 @@ def f(x, y, z): self.assertEqual(out.shape, (16, 8, 16)) self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum_auto_complete_spec(self, mesh): s = NamedSharding(mesh, P('data')) @@ -6489,7 +6489,7 @@ def f(condition, x, y): f = jax.jit(f, in_shardings=(sharding, sharding, sharding)) f(condition, x, x).block_until_ready() - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum_conflict_error(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) @@ -6513,7 +6513,7 @@ def f(x, y, z): 'dot_general operation.*produces an illegally sharded result'): f(arr1, arr2, arr3) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Explicit, mesh_lib.AxisType.Auto)) def test_mix_to_full_user_mode(self, mesh): @@ -6540,7 +6540,7 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_partial_user(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6566,7 +6566,7 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_auto_gather_out_sharding(self, mesh): embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16), jax.NamedSharding(mesh, P(None, 'x'))) @@ -6601,7 +6601,7 @@ def g(x, y): out = jax.jit(jax.grad(g))(embed, tok) self.assertEqual(out.sharding, embed.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshard_error(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6716,7 +6716,7 @@ def matmul_reshard(arr1, arr2): with jax.sharding.use_mesh(mesh): matmul_reshard(arr1, arr2) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_auto_outside_jit(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6736,7 +6736,7 @@ def f(x): out = hf(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_full_visible_outside_jit(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6789,7 +6789,7 @@ def f(x): self.assertTupleEqual(out2.sharding._device_assignment, tuple(mesh2.devices.flat)) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_svd(self, mesh): np_inp = np.zeros([128, 128]) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, None))) @@ -6814,7 +6814,7 @@ def f(x, y): self.assertNotIn("mhlo.sharding", lowered_text) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_mul_vmap(self, use_jit, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6851,7 +6851,7 @@ def g(x): self.assertEqual(out.sharding, arr.sharding) 
@parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general_vmap(self, use_jit, mesh): np_inp1 = np.arange(16.).reshape(4, 2, 2) np_inp2 = np.arange(16.).reshape(2, 4, 2) @@ -6870,7 +6870,7 @@ def f(x, y): self.assertEqual(out.shape, (2, 2, 4)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshape_vmap(self, mesh): np_inp = np.arange(16).reshape(2, 8) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x'))) @@ -6886,7 +6886,7 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y'))) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shit_vmap_error_check(self, use_jit, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -6915,7 +6915,7 @@ def f(x, y): "Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"): jax.vmap(f, spmd_axis_name='y')(arr, arr) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_unmapped_last_vmap(self, mesh): np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x',))) @@ -6928,7 +6928,7 @@ def f(x): self.assertEqual(out.shape, (4, 8)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=AxisType.Auto) def test_shmap_close_over(self, mesh): const = jnp.arange(8) def f(): @@ -6938,7 +6938,7 @@ def f(): shmap_f() # doesn't crash jax.jit(shmap_f)() # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_shmap_close_over_partial_auto(self, mesh): const = jnp.arange(8) @@ -6954,7 +6954,7 @@ def f(): jaxpr = f.trace().jaxpr self.assertIn('mesh_cast', str(jaxpr)) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_wsc_error(self, mesh): s = NamedSharding(mesh, P('x')) with self.assertRaisesRegex( @@ -7001,7 +7001,7 @@ def f(x, y): "Using PartitionSpec when.*not under a mesh context.*is not allowed"): f(arr, arr2) - @jtu.with_user_mesh((2, 1), ('x', 'y'), + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_error_on_canonicalize_under_auto_mode(self, mesh): np_inp = np.arange(16).reshape(8, 2) @@ -7038,7 +7038,7 @@ def iota(): out = iota() self.assertEqual(out.sharding, yz_sharding) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_cumsum(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P())) @@ -7110,7 +7110,7 @@ def test_wsc_pspec_use_mesh(self, sharded_inp): self.assertArraysEqual(out2, np_inp) self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 1), ('x', 'y'), + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_axes_api_error_manual_to_auto_explicit(self, mesh): def g(x): @@ -7122,7 +7122,7 @@ def g(x): jax.jit(shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y')) )(np.arange(16).reshape(8, 2)) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_auto_axes_numpy_array(self, mesh): @jax.jit def f(x): @@ -7140,7 +7140,7 @@ def f(x): 
jtu.dtypes.all_integer + jtu.dtypes.all_unsigned), shape_and_spec=[((), P()), ((2,), P('x')), ((2, 4), P('x', 'y'))], ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_bitcast_convert_type(self, from_dtype, to_dtype, shape_and_spec, mesh): shape, spec = shape_and_spec @@ -7179,7 +7179,7 @@ def f(x): self.assertEqual(out.shape, expected_shape) self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec)) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_dynamic_slice(self, mesh): np_inp = np.arange(16., dtype=np.float32) s = NamedSharding(mesh, P('x')) @@ -7224,7 +7224,7 @@ def test_divisbility_aval_error(self): ValueError, 'does not evenly divide the dimension size'): core.ShapedArray((5, 2), jnp.int32, sharding=s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_scan_unroll(self, mesh): np_inp = np.arange(64, dtype=jnp.float32).reshape(8, 8) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'y'))) @@ -7238,7 +7238,7 @@ def body(carry, x): f(carry, arr) # doesn't crash - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_reshard_with_np_array(self, mesh): out = reshard(np.arange(8), P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) @@ -7259,7 +7259,7 @@ def test_set_mesh(self): finally: jax.sharding.set_mesh(prev_mesh) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_auto_axes_late_bind(self, mesh): @auto_axes def f(x): @@ -7269,7 +7269,7 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=AxisType.Auto) def test_explicit_axes_late_bind(self, mesh): @explicit_axes def f(x): @@ -7279,7 +7279,7 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_rng_bit_generator(self, mesh): def f(key): out = lax.rng_bit_generator(key, shape=(4, 8), out_sharding=P('x')) @@ -7300,7 +7300,7 @@ def f(key): self.assertArraysEqual(out1[0], out2[0]) self.assertArraysEqual(out1[1], out2[1]) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_fold_in(self, mesh): key = jax.random.key(72) key = jax.device_put(key, NamedSharding(mesh, P())) @@ -7325,7 +7325,7 @@ def f(key): x=np.arange(8 * 12).reshape(8, 12)), P('x', 'y')), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_random_functions(self, fun, out_spec, mesh): @jax.jit def f(key): @@ -7354,7 +7354,7 @@ def f(key): 'mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}', lowered_text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_random_truncated_normal(self, mesh): @jax.jit def f(key, lower): @@ -7447,7 +7447,7 @@ def step(carry, x): xs = jax.device_put(xs, sharding) scan(xs) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_select_batch(self, mesh): y_sharding = NamedSharding(mesh, P('y', None)) xy_sharding = NamedSharding(mesh, P('x', 'y', None)) @@ -7469,7 +7469,7 @@ def vmap_select(batch_a, batch_b): out = vmap_select(batch_a, batch_b) self.assertEqual(out.sharding, out_s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + 
@jtu.with_explicit_mesh((2, 2), ('x', 'y'))
   def test_where_vmap(self, mesh):
     xy_sharding = NamedSharding(mesh, P('x', 'y', None))
     batch_a = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.float32), xy_sharding)
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index ae7f532c1526..5449c68577e0 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1995,7 +1995,7 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
     self.assertAllClose(v * v, out, check_dtypes=False)

-  @jtu.with_user_mesh((2, 2), ('i', 'j'))
+  @jtu.with_explicit_mesh((2, 2), ('i', 'j'))
   def test_partial_auto_explicit(self, mesh):
     def g(x):
       self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
@@ -2039,7 +2039,7 @@ def h(x):
     jax.grad(h)(v)  # doesn't crash
     jax.jit(jax.grad(h))(v)  # doesn't crash

-  @jtu.with_user_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l'))
+  @jtu.with_explicit_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l'))
   def test_partial_auto_explicit_multi_explicit(self, mesh):
     def g(x):
       self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,

From eab1dfccbca4647c025e694ef9fed150dca11dad Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Thu, 17 Apr 2025 18:45:46 -0700
Subject: [PATCH 0686/1769] [Pallas] Generalize BlockSpec to support a
 different indexing mode for each dim in the block shape
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Currently block_shape is tuple[int | None, …]. We propose generalizing
block_shape to take in more types in the tuple to support:

* Squeeze dimension (currently None, could be pl.Squeezed())
* Unblocked: currently the entire index_map needs to be Unblocked or not. This
  will allow individual indices to be Blocked/Unblocked, e.g.
  pl.BlockSpec((pl.Unblocked(...), 512), …)
* Ragged sizes: the index_map will return a pl.ds with a dynamic size (bounded
  by something). For example:
  pl.BlockSpec((pl.DynamicSizedSlice(512), 1024), lambda i, j: (pl.ds(...), j)).

This will make BlockSpecs a lot more flexible and will enable things like
arbitrary slicing in the pipeline emitter.

PiperOrigin-RevId: 748881960
---
 docs/pallas/CHANGELOG.md               |   9 ++
 docs/pallas/grid_blockspec.md          |  46 ++++----
 jax/_src/pallas/core.py                | 142 ++++++++++++++++++-------
 jax/_src/pallas/hlo_interpreter.py     |  65 ++++++-----
 jax/_src/pallas/mosaic/interpret.py    |  41 +++----
 jax/_src/pallas/mosaic/lowering.py     |  80 +++++++-------
 jax/_src/pallas/mosaic_gpu/lowering.py |  15 +--
 jax/_src/pallas/mosaic_gpu/pipeline.py |  18 +++-
 jax/_src/pallas/pallas_call.py         |  18 ++--
 jax/_src/pallas/triton/lowering.py     |  70 +++++++-----
 jax/experimental/pallas/__init__.py    |  10 +-
 tests/pallas/pallas_test.py            |  33 +++---
 tests/pallas/tpu_pallas_state_test.py  |   2 +
 13 files changed, 334 insertions(+), 215 deletions(-)

diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md
index 7533e6eda053..c960280a8891 100644
--- a/docs/pallas/CHANGELOG.md
+++ b/docs/pallas/CHANGELOG.md
@@ -11,6 +11,15 @@ For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changel
 Remember to align the itemized text with the first line of an item within a list.
 -->

+## Released with jax 0.6.1
+
+* Changes
+
+  * {func}`jax.experimental.pallas.BlockSpec` now takes in special types in
+    addition to ints/None in the `block_shape`. `indexing_mode` has been
+    removed. To achieve "Unblocked", pass a `pl.Element(size)` into
+    `block_shape` for each entry that needs unblocked indexing.
+
 ## Released with jax 0.5.0

 * New functionality
diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md
index ea1df15f2fd4..d74a91c96f54 100644
--- a/docs/pallas/grid_blockspec.md
+++ b/docs/pallas/grid_blockspec.md
@@ -151,8 +151,7 @@ over the second axis:

 ```python
 >>> def show_program_ids(x_shape, block_shape, grid,
-...                     index_map=lambda i, j: (i, j),
-...                     indexing_mode=pl.Blocked()):
+...                     index_map=lambda i, j: (i, j)):
 ...   def program_ids_kernel(o_ref):  # Fill the output block with 10*program_id(1) + program_id(0)
 ...     axes = 0
 ...     for axis in range(len(grid)):
@@ -162,7 +161,7 @@ over the second axis:
 ...       out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
 ...       grid=grid,
 ...       in_specs=[],
-...       out_specs=pl.BlockSpec(block_shape, index_map, indexing_mode=indexing_mode),
+...       out_specs=pl.BlockSpec(block_shape, index_map),
 ...       interpret=True)()
 ...   print(res)

@@ -227,7 +226,8 @@ See {ref}`pallas_tpu_noteworthy_properties`.

 A `None` value appearing as a dimension value in the `block_shape` behaves
 as the value `1`, except that the corresponding
-block axis is squeezed. In the example below, observe that the
+block axis is squeezed (you could also pass in `pl.Squeezed()` instead of
+`None`). In the example below, observe that the
 shape of the `o_ref` is (2,) when the block shape was specified as `(None, 2)`
 (the leading dimension was squeezed).

@@ -269,27 +269,33 @@ used:
 `index_map=lambda *invocation_indices: (0,) * len(block_shape)`.
 ```

-### The "unblocked" indexing mode
+### The "element" indexing mode

-The behavior documented above applies to the `indexing_mode=pl.Blocked()`.
-When using the `pl.Unblocked` indexing mode the values returned by the
+The behavior documented above applies to the default "blocked" indexing mode.
+When integers are used in the `block_shape` tuple, e.g. `(4, 8)`, it is
+equivalent to passing in a `pl.Blocked(block_size)` object instead, e.g.
+`(pl.Blocked(4), pl.Blocked(8))`. Blocked indexing mode means the indices
+returned by `index_map` are *block indices*. We can pass in objects other than
+`pl.Blocked` to change the semantics of `index_map`, most notably,
+`pl.Element(block_size)`.
+When using the `pl.Element` indexing mode, the values returned by the
 index map function are used directly as the array indices, without
 first scaling them by the block size.
-When using the unblocked mode you can specify virtual padding
-of the array as a tuple of low-high paddings for each dimension: the
+When using the `pl.Element` mode, you can specify virtual padding
+of the array as a tuple of low-high paddings for the dimension: the
 behavior is as if the overall array is padded on input. No guarantees
-are made for the padding values in the unblocked mode, similarly to the padding
+are made for the padding values in `pl.Element` mode, similarly to the padding
 values for the blocked indexing mode when the block shape does not divide
 the overall array shape.

-The unblocked mode is currently supported only on TPUs.
+The `pl.Element` mode is currently supported only on TPUs.

 ```python
->>> # unblocked without padding
->>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2),
-...                  index_map=lambda i, j: (2*i, 3*j),
-...                  indexing_mode=pl.Unblocked())
+>>> # element without padding
+>>> show_program_ids(x_shape=(8, 6), block_shape=(pl.Element(2), pl.Element(3)),
+...                  grid=(4, 2),
+...                  index_map=lambda i, j: (2*i, 3*j))
 [[ 0 0 0 1 1 1]
 [ 0 0 0 1 1 1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
[30 30 30 31 31 31] [30 30 30 31 31 31]] ->>> # unblocked, first pad the array with 1 row and 2 columns. ->>> show_program_ids(x_shape=(7, 7), block_shape=(2, 3), grid=(4, 3), -... index_map=lambda i, j: (2*i, 3*j), -... indexing_mode=pl.Unblocked(((1, 0), (2, 0)))) +>>> # element, first pad the array with 1 row and 2 columns. +>>> show_program_ids(x_shape=(7, 7), +... block_shape=(pl.Element(2, (1, 0)), +... pl.Element(3, (2, 0))), +... grid=(4, 3), +... index_map=lambda i, j: (2*i, 3*j)) [[ 0 1 1 1 2 2 2] [10 11 11 11 12 12 12] [10 11 11 11 12 12 12] diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index a74206c46ce7..6af78d3e6ec8 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -24,7 +24,7 @@ import functools import itertools import threading -from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable +from typing import Any, ClassVar, Hashable, Protocol, TypeAlias, Union, runtime_checkable import jax from jax._src import api_util @@ -332,36 +332,88 @@ def current_grid_env() -> GridEnv | None: return _pallas_tracing_env.grid_env_stack[-1] -class Mapped: - """Used as a block shape dimension to denote a mapped dimension. - A mapped dimension behaves like `1` except it is squeezed from the block. - See :ref:`pallas_blockspec` for more details. - """ - def __repr__(self): - return "Mapped" -mapped = Mapped() +@dataclasses.dataclass(frozen=True) +class Element: + """Use to index an array using an elementwise start index.""" + block_size: int + padding: tuple[int, int] = (0, 0) + def __str__(self): + if self.padding == (0, 0): + return f"Element({self.block_size})" + return f"Element({self.block_size}, padding={self.padding})" @dataclasses.dataclass(frozen=True) -class Unblocked: - padding: tuple[tuple[int, int], ...] | None = None - - def __repr__(self): - return f"Unblocked(padding={self.padding})" -unblocked = Unblocked() +class Squeezed: + """Represents a one-sized block dimension that is squeezed out in the kernel.""" +squeezed = Squeezed() +@dataclasses.dataclass(frozen=True) class Blocked: - def __repr__(self): - return "Blocked" -blocked = Blocked() + """The default BlockShape type.""" + block_size: int + def __str__(self): + return f"Blocked({self.block_size})" + +BlockDim: TypeAlias = Element | Squeezed | Blocked -IndexingMode = Union[Blocked, Unblocked] def default_index_map(ndim: int) -> Callable: return lambda *args: (0,) * ndim + +def _canonicalize_block_dim(dim: BlockDim | int | None) -> BlockDim: + match dim: + case None: + return squeezed + case int(): + return Blocked(int(dim)) + case Squeezed() | Blocked() | Element(): + return dim + case _: + # Handle case where the dim is a symbolic dimension so we assume it is + # Blocked. + if jax_core.is_symbolic_dim(dim): + return Blocked(dim) + try: + return Blocked(int(dim)) + except Exception as e: + raise ValueError( + f"Unsupported block dimension type: {type(dim)}. Allowed types:" + " `pl.Squeezed`, `pl.Blocked`, `pl.Element`, `int`, `None`." 
+ ) from e + +def _canonicalize_block_shape(block_shape: Sequence[BlockDim | int | None] + ) -> tuple[BlockDim, ...]: + return tuple(_canonicalize_block_dim(dim) for dim in block_shape) + + +def _get_block_dim_size(dim: BlockDim) -> int: + match dim: + case Squeezed(): + return 1 + case Blocked(block_size): + return block_size + case Element(): + return dim.block_size + case _: + raise ValueError(f"Unsupported block shape type: {type(dim)}") + + +def _get_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: + return tuple(_get_block_dim_size(dim) for dim in block_shape) + +def _get_ref_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: + # Special handling for squeezed here (don't include Squeezed dims in the Ref + # shape). + return tuple( + _get_block_dim_size(dim) + for dim in block_shape + if not isinstance(dim, Squeezed) + ) + @dataclasses.dataclass class BlockSpec: """Specifies how an array should be sliced for each invocation of a kernel. @@ -369,12 +421,21 @@ class BlockSpec: See :ref:`pallas_blockspec` for more details. """ # An internal canonicalized version is in BlockMapping. - block_shape: Sequence[int | None] | None = None + block_shape: Sequence[BlockDim | int | None] | None = None index_map: Callable[..., Any] | None = None memory_space: Any | None = dataclasses.field(kw_only=True, default=None) - indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked) + indexing_mode: Any | None = None pipeline_mode: Buffered | None = None + def __post_init__(self): + # TODO(sharadmv): Remove this check. + if self.indexing_mode is not None: + raise ValueError( + "indexing_mode has been removed. Please pass in `pl.Element` for each" + " block dimension in `block_shape` instead to enable 'Unblocked'" + " indexing." + ) + def to_block_mapping( self, origin: OriginStr, @@ -392,9 +453,9 @@ def to_block_mapping( else: index_map_func = self.index_map if self.block_shape is None: - block_shape = array_aval.shape + block_shape = _canonicalize_block_shape(array_aval.shape) else: - block_shape = self.block_shape # type: ignore + block_shape = _canonicalize_block_shape(self.block_shape) if len(array_aval.shape) != len(block_shape): raise ValueError( f"Block shape for {origin} (= {block_shape}) " @@ -402,8 +463,8 @@ def to_block_mapping( f"array shape {array_aval.shape}." ) - unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_array_aval = array_aval.update(shape=unmapped_block_shape) + ref_block_shape = _get_ref_block_shape(block_shape) + block_array_aval = array_aval.update(shape=ref_block_shape) if isinstance(array_aval, jax_core.DShapedArray): # Get the "max" shape for the ragged array. block_array_aval = jax_core.ShapedArray( @@ -435,7 +496,6 @@ def to_block_mapping( flat_index_map_fun, index_map_avals ) - mapped_block_shape = tuple(mapped if s is None else s for s in block_shape) if len(out_avals) != len(block_shape): raise ValueError( f"Index map function {debug.func_src_info} for " @@ -460,10 +520,9 @@ def to_block_mapping( array_aval_shape = _max_shape_from_aval(array_aval) mapping = BlockMapping( - block_shape=mapped_block_shape, + block_shape=block_shape, transformed_block_aval=block_aval, # There are no transforms by default index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), - indexing_mode=self.indexing_mode, array_shape_dtype=jax.ShapeDtypeStruct( array_aval_shape, array_aval.dtype ), @@ -502,10 +561,9 @@ class BlockMapping: """ # TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform. 
# After all, it's just indexing out singleton dimensions. - block_shape: tuple[Mapped | int, ...] + block_shape: tuple[BlockDim, ...] transformed_block_aval: AbstractMemoryRef index_map_jaxpr: jax_core.ClosedJaxpr - indexing_mode: IndexingMode array_shape_dtype: jax.ShapeDtypeStruct # The whole array origin: OriginStr transforms: Sequence[MemoryRefTransform] = () @@ -514,8 +572,8 @@ class BlockMapping: def check_invariants(self) -> None: if not config.enable_checks.value: return - unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped) - assert unmapped_block_shape == self.ref_aval.shape, ( + ref_block_shape = _get_ref_block_shape(self.block_shape) + assert ref_block_shape == self.ref_aval.shape, ( self.block_shape, self.ref_aval.shape) assert len(self.block_shape) == len(self.array_shape_dtype.shape), ( self.block_shape, self.array_shape_dtype @@ -563,18 +621,22 @@ def compute_start_indices_interpret(self, loop_idx, *args): # updated values since we only care about the return values. block_indices, _ = split_list(block_indices_and_rest, [len(self.block_shape)]) - if isinstance(self.indexing_mode, Blocked): - return tuple(i if b is mapped else b * i - for b, i in zip(self.block_shape, block_indices)) - elif isinstance(self.indexing_mode, Unblocked): - return block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {self.indexing_mode}") + def _get_start_index(i, b): + match b: + case Squeezed() | Element(): + return i + case Blocked(block_size): + return block_size * i + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") + return tuple( + _get_start_index(i, b) for i, b in zip(block_indices, self.block_shape) + ) def has_trivial_window(self): """If block shape is same as the array shape and index_map returns 0s.""" for b, s in zip(self.block_shape, self.array_shape_dtype.shape): - if b != s and not (b is mapped and s == 1): + if _get_block_dim_size(b) != s: return False for atom in self.index_map_jaxpr.jaxpr.outvars: if not (isinstance(atom, jax_core.Literal) and atom.val == 0): diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 6fbe5e914bfe..f3d2c46ad9a9 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -83,18 +83,19 @@ def _logical_aval_to_interpret_mode_aval(aval): return aval -def _dynamic_slice(start_idx, block_shape, value, is_indexing): +def _dynamic_slice( + start_idx, block_shape: tuple[int, ...], value, is_squeeze, +): start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) - squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, - dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) + squeeze_dims = tuple(np.arange(len(is_squeeze))[np.array(is_squeeze, + dtype=np.bool_)]) + return lax.squeeze(output, squeeze_dims) # type: ignore[arg-type] -def _dynamic_update_slice(start_idx, block_shape, value, update, - is_indexing): +def _dynamic_update_slice(start_idx, block_shape, value, update, is_squeeze): start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) - broadcast_dims = tuple(i for i, b in enumerate(is_indexing) + broadcast_dims = tuple(i for i, b in enumerate(is_squeeze) if not b) update = lax.broadcast_in_dim(update, block_shape, broadcast_dims) assert update.shape == block_shape @@ -112,8 +113,7 @@ def _get_next_indices(grid, indices): return tuple(reversed(next_indices)) -def _pad_to_block_dimension(value, - block_shape): 
+def _pad_to_block_dimension(value, block_shape: tuple[int, ...]): """Pads values so the shape evenly divides into block dimensions. For example, if values has a shape of (33, 2, 5) with a block_shape of @@ -121,8 +121,7 @@ def _pad_to_block_dimension(value, Args: value: Array to be padded. - block_shape: Block shapes to use for padding. If None, no padding will - be performed. + block_shape: Block shapes to use for padding. Returns: A padded array. @@ -377,23 +376,21 @@ def pallas_call_hlo_interpret( carry = [] for x, bm in zip(itertools.chain(block_args, out), grid_mapping.block_mappings): - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - padding = bm.indexing_mode.padding - if padding is not None and any(p != (0, 0) for p in padding): - if input_output_aliases: - raise NotImplementedError("Padding with aliasing not supported.") - pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) - x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) + padding = [bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape] + if padding is not None and any(p != (0, 0) for p in padding): + if input_output_aliases: + raise NotImplementedError("Padding with aliasing not supported.") + pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) + x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) carry.append(x) - is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + block_shapes = [pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings] + is_squeeze_dim = [ + tuple(isinstance(bd, pallas_core.Squeezed) for bd in bm.block_shape) for bm in grid_mapping.block_mappings ] - block_shapes = [ - tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) - ] # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. 
We pad with NaN, to make it easier @@ -444,7 +441,7 @@ def body(carry): for bm in grid_mapping.block_mappings ] blocks = map(_dynamic_slice, start_indices, block_shapes, - carry_consts_ins, is_indexing_dim) + carry_consts_ins, is_squeeze_dim) with pallas_core.grid_env(local_grid_env): assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len( scratch_values @@ -462,7 +459,7 @@ def body(carry): _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) out_carry = map(_dynamic_update_slice, start_indices, block_shapes, - carry_consts_ins, out_inout, is_indexing_dim) + carry_consts_ins, out_inout, is_squeeze_dim) return (i + 1, _get_next_indices(grid, loop_idx), *out_carry, *out_scratch) @@ -473,14 +470,14 @@ def body(carry): out_out = carry[len(block_args):len(block_args) + len(out)] out_nopad = [] for o, bm in zip(out_out, grid_mapping.block_mappings_output): - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - padding = bm.indexing_mode.padding - if padding is not None and any(p != (0, 0) for p in padding): - if input_output_aliases: - raise NotImplementedError("Padding with aliasing not supported.") - pad_low, pad_high = zip(*padding) - limit_indices = [s - p for s, p in zip(o.shape, pad_high)] - o = lax.slice(o, pad_low, limit_indices) + padding = [bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape] + if padding is not None and any(p != (0, 0) for p in padding): + if input_output_aliases: + raise NotImplementedError("Padding with aliasing not supported.") + pad_low, pad_high = zip(*padding) + limit_indices = [s - p for s, p in zip(o.shape, pad_high)] + o = lax.slice(o, pad_low, limit_indices) if o.shape != bm.array_shape_dtype.shape: o = lax.slice(o, (0,) * o.ndim, bm.array_shape_dtype.shape) out_nopad.append(o) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 5ac6bb6564ba..5921b161d598 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1347,18 +1347,23 @@ def _compute_start_indices( block_indices = _interpret_jaxpr( jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, compiler_params=compiler_params, interpret_params=interpret_params) - if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = jnp.array( - tuple( - i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices) - ), - dtype=jnp.int32, - ) - elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - ret = block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") + def _get_start_index(i, b): + match b: + case pallas_core.Squeezed(): + return i + case pallas_core.Element(): + return i + case pallas_core.Blocked(): + return i * b.block_size + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") + ret = jnp.array( + tuple( + _get_start_index(i, b) + for i, b in zip(block_indices, block_mapping.block_shape) + ), + dtype=jnp.int32, + ) return ret def _get_next_indices(grid, indices): @@ -1548,13 +1553,13 @@ def interpret_pallas_call( ordered=True) # Pad input arguments. 
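The rounding that `_pad_to_block_dimension` performs is easy to state on its own; here is a minimal NumPy sketch (reusing the (33, 2, 5) / (32, 2, 3) example from the docstring above, and padding with NaN, which only stands in for the uninitialized fill value the interpreter actually uses):

    import numpy as np

    def pad_to_block_dimension(value, block_shape):
      # Round every dimension up to the next multiple of its block size.
      padded_shape = tuple(-(-v // b) * b for v, b in zip(value.shape, block_shape))
      pad_width = [(0, p - v) for v, p in zip(value.shape, padded_shape)]
      return np.pad(value, pad_width, constant_values=np.nan)

    y = pad_to_block_dimension(np.ones((33, 2, 5)), (32, 2, 3))
    assert y.shape == (64, 2, 6)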
- is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + is_squeeze_dim = [ + tuple(isinstance(b, pallas_core.Squeezed) for b in bm.block_shape) for bm in grid_mapping.block_mappings ] block_shapes = [ - tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings ] num_inputs = grid_mapping.num_inputs input_args = [ @@ -1745,7 +1750,7 @@ def _store_slice_to_kernel_input(index, input_var): for st, sz, iid in zip( cur_start_indices[index], block_shapes[index], - is_indexing_dim[index], + is_squeeze_dim[index], ) ), shape=input_args[index].shape, @@ -1813,7 +1818,7 @@ def _store_to_output_buffer(index, output_var): for st, sz, iid in zip( cur_start_indices[num_inputs + index], block_shapes[num_inputs + index], - is_indexing_dim[num_inputs + index], + is_squeeze_dim[num_inputs + index], ) ), shape=output_vals[index].shape, diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 065d2b4c3b14..3fba962ba0a9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -162,7 +162,7 @@ class LoweringContext: grid_names: tuple[Hashable, ...] | None mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. user_grid_indices: Sequence[ir.Value] | None - block_shapes: list[tuple[int | pallas_core.Mapped, ...]] + block_shapes: list[tuple[int | pallas_core.Squeezed, ...]] name_stack: source_info_util.NameStack mesh_context: MeshContext | None replace = dataclasses.replace @@ -197,7 +197,7 @@ class LoweringRuleContext: lowering_context: LoweringContext avals_in: Sequence[jax_core.AbstractValue] avals_out: Sequence[jax_core.AbstractValue] - block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None] + block_shapes: Sequence[tuple[int | pallas_core.Squeezed, ...] | None] replace = dataclasses.replace @property @@ -354,7 +354,13 @@ def _get_arg_type( ), aval.shape, ) - shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape) + shape = pallas_core._get_block_shape(block_mapping.block_shape) + # Keep around squeezed as a sentinel for the lowering rules + block_shape = tuple( + pallas_core.squeezed if isinstance(b, pallas_core.Squeezed) + else pallas_core._get_block_dim_size(b) + for b in block_mapping.block_shape + ) return ( aval_to_ir_type( dynamic_shape_replacement_fn, @@ -362,7 +368,7 @@ def _get_arg_type( shape=shape, memory_space=memory_space, ), - block_mapping.block_shape, + block_shape, ) @@ -589,8 +595,7 @@ def err_details(): "only blocks having the same block shape as the array shape " "and a trivial index_map (returning all 0s)." + err_details()) - unmapped_bs = [ - 1 if bs is pallas_core.mapped else bs for bs in bm.block_shape] + unmapped_bs = pallas_core._get_block_shape(bm.block_shape) bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1] if rank >= 2: bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2] @@ -735,9 +740,7 @@ def dynamic_shape_replacement_fn( dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, ) assert mlir_func.verify(), mlir_func - block_shape = [ - 1 if b is pallas_core.mapped else b for b in bm.block_shape - ] + block_shape = list(pallas_core._get_block_shape(bm.block_shape)) # Force single-buffering pipelining for trivial windowing in VMEM. 
pipeline_mode = bm.pipeline_mode @@ -756,11 +759,16 @@ def dynamic_shape_replacement_fn( window_bounds=window_shape, transform_indices=ir.FlatSymbolRefAttr.get(func_name), ) - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - if bm.indexing_mode.padding is None: - pad_low = pad_high = [0] * len(bm.block_shape) - else: - pad_low, pad_high = map(list, zip(*bm.indexing_mode.padding)) + is_element_block = [isinstance(bd, pallas_core.Element) + for bd in bm.block_shape] + if any(is_element_block): + if not all(is_element_block): + raise NotImplementedError( + "All block dimensions must be Elements or none of them can be" + " Elements." + ) + padding = [bd.padding for bd in bm.block_shape] # pytype: disable=attribute-error + pad_low, pad_high = map(list, zip(*padding)) block_params["window_kind"] = ir.Attribute.parse( f"#tpu.element_window<{pad_low},{pad_high}>" ) @@ -1240,7 +1248,7 @@ def _index_to_start_size_stride( def _indexer_to_start_size_stride( indexer: NDIndexer, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], *, cast_to_index: bool, ) -> tuple[ @@ -1248,21 +1256,21 @@ def _indexer_to_start_size_stride( tuple[int | ir.Value, ...], tuple[int, ...], tuple[bool, ...], - tuple[int | pallas_core.Mapped, ...], + tuple[int | pallas_core.Squeezed, ...], ]: indices_iter = iter(indexer.indices) starts, sizes, strides, squeeze_dims = [], [], [], [] for s in ref_block_shape: - start, size, stride, squeeze_dim = ( - ( - _maybe_cast_to_index(cast_to_index, 0), - 1, - 1, - True, + match s: + case pallas_core.Squeezed(): + start = _maybe_cast_to_index(cast_to_index, 0) + size = 1 + stride = 1 + squeeze_dim = True + case _: + start, size, stride, squeeze_dim = _index_to_start_size_stride( + next(indices_iter), cast_to_index ) - if s is pallas_core.mapped - else _index_to_start_size_stride(next(indices_iter), cast_to_index) - ) starts.append(start) sizes.append(size) strides.append(stride) @@ -1284,8 +1292,8 @@ def _slice_memref( ref: ir.Value, indexer: NDIndexer, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...]]: assert ref_block_shape is not None target_shape = indexer.get_indexer_shape() starts, sizes, strides, squeeze_dims, ref_block_shape = ( @@ -1324,8 +1332,8 @@ def _bitcast_memref( ref: ir.Value, bitcaster: RefBitcaster, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Squeezed, ...]]: src_bitwidth = dtype_bitwidth(ref_dtype) dst_bitwidth = dtype_bitwidth(bitcaster.dtype) if src_bitwidth != dst_bitwidth: @@ -1333,7 +1341,7 @@ def _bitcast_memref( raise NotImplementedError( "Bitcast 1D ref with bitwidth change is not supported." ) - if ref_block_shape[-2] is pallas_core.mapped: + if ref_block_shape[-2] is pallas_core.squeezed: raise NotImplementedError( "Bitcast a ref whose 2nd minormost dimension is squeezed when" " bitwidth changes." 
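At the API surface, the change these lowering hunks support is the move from a spec-wide `indexing_mode` to per-dimension block types. A hedged sketch of the migration (the `pl.Element(block_size, padding)` signature matches its use in the updated tests near the end of this series):

    from jax.experimental import pallas as pl

    # Before: padding rode along on the indexing mode.
    #   pl.BlockSpec((8, 128), lambda i: (8 * i, 0),
    #                indexing_mode=pl.Unblocked(((0, 2), (0, 0))))
    # After: each dimension carries its own size and padding.
    spec = pl.BlockSpec((pl.Element(8, (0, 2)), pl.Element(128)),
                        lambda i: (8 * i, 0))

As the error raised above enforces, a spec must use `pl.Element` for every dimension or for none of them.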
@@ -1347,7 +1355,7 @@ def _bitcast_memref( new_ref_block_shape = list(ref_block_shape) if ( len(new_ref_block_shape) >= 2 - and new_ref_block_shape[-2] is not pallas_core.mapped + and new_ref_block_shape[-2] is not pallas_core.squeezed ): new_ref_block_shape[-2] = ( new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth @@ -1363,8 +1371,8 @@ def _reshape_memref( ref: ir.Value, reshaper: RefReshaper, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Squeezed, ...]]: if ref_dtype != reshaper.dtype: raise ValueError( f"Reshape a ref with dtype change: {reshaper.dtype} vs {ref_dtype}" @@ -1372,8 +1380,8 @@ def _reshape_memref( if len(ref_block_shape) < 2: raise NotImplementedError("Reshape 1D ref is not supported.") if ( - ref_block_shape[-2] is pallas_core.mapped - or ref_block_shape[-1] is pallas_core.mapped + ref_block_shape[-2] is pallas_core.squeezed + or ref_block_shape[-1] is pallas_core.squeezed ): raise NotImplementedError( "Reshape a ref with squeezed dimension on last two dimensions." @@ -1677,7 +1685,7 @@ def _masked_swap_lowering_rule( mem_slice_shape.insert(i, 1) mem_slice_shape_iter = iter(mem_slice_shape) mem_slice_shape = [ - 1 if b is pallas_core.mapped else next(mem_slice_shape_iter) + 1 if b is pallas_core.squeezed else next(mem_slice_shape_iter) for b in ref_block_shape ] mem_aval = aval_out.update( diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 726d89bfffc5..b465d05ba7d4 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -471,11 +471,13 @@ def _eval_index_map( ) result = [] for i, b in zip(block_indices, block_mapping.block_shape): - if b is pallas_core.mapped: - result.append(i) - else: - # TODO(slebedev): Use a type-agnostic multiplication wrapper. 
-      result.append(arith_dialect.muli(_as_index(i), _as_index(b)))
+    match b:
+      case pallas_core.Squeezed() | pallas_core.Element():
+        result.append(i)
+      case pallas_core.Blocked():
+        result.append(arith_dialect.muli(_as_index(i), _as_index(b.block_size)))
+      case _:
+        raise ValueError(f"Unsupported block dim type: {b}")
   return tuple(result)
 
 
@@ -507,7 +509,7 @@ def err_details(bm: pallas_core.BlockMapping) -> str:
         + err_details(bm)
     )
 
-  if not isinstance(bm.indexing_mode, pallas_core.Blocked):
+  if any(isinstance(b, pallas_core.Element) for b in bm.block_shape):
     raise NotImplementedError(
         "Only Blocked indexing mode is supported in Mosaic GPU lowering.\n\n"
         + err_details(bm)
@@ -548,7 +550,6 @@ def index_map(*indices):
       bm.block_shape,
       index_map,
       memory_space=bm.transformed_block_aval.memory_space,
-      indexing_mode=bm.indexing_mode,
       transforms=bm.transforms,
   )
 
diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py
index 1e52f8701fcb..1aa75eb0bc0e 100644
--- a/jax/_src/pallas/mosaic_gpu/pipeline.py
+++ b/jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -40,6 +40,19 @@
 map = util.safe_map
 zip = util.safe_zip
 
+def _get_block_size(bd: pl.Blocked | pl.Element | pl.Squeezed | int | None
+                    ) -> int:
+  match bd:
+    case int():
+      return bd
+    case pl.Blocked(block_size):
+      return block_size
+    case _:
+      raise NotImplementedError(f"Unsupported block size type: {type(bd)}")
+
+def _get_block_shape(spec: pallas_core.BlockSpec):
+  assert spec.block_shape is not None
+  return tuple(_get_block_size(bd) for bd in spec.block_shape)
 
 @jax.tree_util.register_dataclass
 @dataclasses.dataclass(frozen=True)
@@ -64,10 +77,11 @@ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     # We don't allow Python scalars here, because they are interpreted
     # differently depending on the x32/x64 mode.
assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices) + sizes = _get_block_shape(self.spec) return tuple( pl.Slice(idx * size, size) # type: ignore[arg-type] for idx, size in zip( - index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type] + index_map(*grid_indices), sizes # type: ignore[arg-type] ) ) @@ -201,7 +215,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_smem_refs, out_smem_refs = util.split_list( [ gpu_core.SMEM( - (max_concurrent_steps, *spec.block_shape), # type: ignore + (max_concurrent_steps, *_get_block_shape(spec)), # type: ignore ref.dtype, transforms=tuple( t.batch(1) for t in getattr(spec, "transforms", ()) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index f6a3757ca8d6..500aaf125ca0 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -257,10 +257,10 @@ def _block_map_function(new_idx, *args): new_block_shape = shape stacked_axis = dim.stacked_axis new_block_shape = tuple_insert( - new_block_shape, stacked_axis, pallas_core.mapped + new_block_shape, stacked_axis, pallas_core.squeezed ) else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed) array_shape = block_mapping.array_shape_dtype.shape if isinstance(dim, batching.RaggedAxis): @@ -663,7 +663,7 @@ def get_size(i, x, d): for block_mapping in batched_grid_mapping.block_mappings: mapped_dim_idxs = [] for i, d in enumerate(block_mapping.block_shape): - if d is pallas_core.mapped: + if isinstance(d, pallas_core.Squeezed): mapped_dim_idxs.append(i) else: mapped_dim_idxs.append(None) # type: ignore[arg-type] @@ -754,7 +754,7 @@ def when_wrapped_kernel(lengths_ref, *args, **kwargs): continue arg_i_idx = ( primitives.program_id(ragged_axis_dim) - * block_shapes[i][ragged_axis_dim] + * pallas_core._get_block_dim_size(block_shapes[i][ragged_axis_dim]) ) run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len) @@ -800,7 +800,8 @@ def index_rewrite_kernel(*indexer_args): nargs = list(rest_indexer_args) if ragged_axis_dim is not None: - val_at_ragged_dim = batched_block_mapping.block_shape[ragged_axis_dim] + val_at_ragged_dim = pallas_core._get_block_dim_size( + batched_block_mapping.block_shape[ragged_axis_dim]) # The current index into the ragged dimension. # Invariant: There is only one ragged dimension, enforced above. @@ -965,13 +966,12 @@ def pallas_call_checkify_oob_grid(error: checkify.Error, num_iterations = 1 is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + tuple(isinstance(b, pallas_core.Squeezed) for b in bm.block_shape) for bm in grid_mapping.block_mappings ] block_shapes = [ - None if iid is None - else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings ] # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) # i:int32 is the interation index diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 150ae9b8b2d7..592abd1915d0 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -90,7 +90,7 @@ class BlockInfo: full_shape_dtype: jax.ShapeDtypeStruct start_indices: Sequence[Any] start_indices_alignment: Sequence[int] - block_shape: tuple[int | pallas_core.Mapped, ...] + block_shape: tuple[int | pallas_core.Squeezed, ...] 
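`Squeezed` replaces `pallas_core.mapped` as the per-dimension sentinel throughout these signatures. User code keeps the old spelling; a hedged sketch, assuming `None` entries still canonicalize to `Squeezed()` as the helpers in `pallas/core.py` earlier in this diff suggest:

    from jax.experimental import pallas as pl

    # (None, 128) squeezes the leading dimension out of the kernel Ref: the
    # block Ref has shape (128,), internally (Squeezed(), Blocked(128)).
    spec = pl.BlockSpec((None, 128), lambda i, j: (i, j))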
@dataclasses.dataclass @@ -125,28 +125,39 @@ def _eval_index_map( _ensure_ir_value(i, jax_core.ShapedArray((), jnp.int32)) for i in block_indices ) - if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - if block_mapping.indexing_mode.padding is not None: - raise NotImplementedError( - "Unblocked indexing with padding is not supported in Triton lowering." - ) - if block_mapping.pipeline_mode is not None: - raise NotImplementedError( - "Pipeline mode is not supported in Triton lowering." - ) - return tuple(block_indices) + if block_mapping.pipeline_mode is not None: + raise NotImplementedError( + "Pipeline mode is not supported in Triton lowering." + ) + if any( + isinstance(b, pallas_core.Element) and b.padding != (0, 0) + for b in block_mapping.block_shape + ): + raise NotImplementedError( + "Unblocked indexing with padding is not supported in Triton lowering." + ) + def _get_start_index(i, b): + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + return i + case pallas_core.Blocked(): + return _mul(i, _ir_constant(b.block_size, i.type)) + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") return tuple( - i if b is pallas_core.mapped else _mul(i, _ir_constant(b, i.type)) - for i, b in zip(block_indices, block_mapping.block_shape) + _get_start_index(i, b) for i, b in + zip(block_indices, block_mapping.block_shape) ) def _get_index_alignment(block_mapping: BlockMapping) -> tuple[int, ...]: - if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - return (1,) * len(block_mapping.block_shape) - return tuple( - 1 if b is pallas_core.mapped else b for b in block_mapping.block_shape - ) + def _get_bdim_alignment(b: pallas_core.BlockDim): + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + return 1 + case pallas_core.Blocked(): + return b.block_size + return tuple(_get_bdim_alignment(b) for b in block_mapping.block_shape) def _bcast_to(a: ir.Value, shape: tuple[int, ...]) -> ir.Value: @@ -274,8 +285,9 @@ def _new_ir_context() -> ir.Context: # this). This check is only needed to obtain a nicer error message; the # Triton lowering will fail anyway but it will crash with a C++ exception. # We currently apply this check only to load/store operations. 
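Concretely, the guarded condition multiplies the non-squeezed block dimensions and demands a power of two. A standalone sketch of the same predicate (the real check raises a `ValueError` rather than returning a bool):

    import math

    def block_size_is_power_of_two(block_shape, is_squeezed):
      size = math.prod(1 if sq else d for d, sq in zip(block_shape, is_squeezed))
      return size & (size - 1) == 0

    assert block_size_is_power_of_two((8, 128), (False, False))      # 1024
    assert not block_size_is_power_of_two((3, 128), (False, False))  # 384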
-def _check_tensor_size(shape: tuple[int | pallas_core.Mapped, ...]): - size = math.prod(1 if d is pallas_core.mapped else d for d in shape) +def _check_tensor_size(shape: tuple[int | pallas_core.Squeezed, ...]): + size = math.prod(1 if isinstance(d, pallas_core.Squeezed) else d + for d in shape) power_of_2 = (size & (size - 1)) == 0 if not power_of_2: raise ValueError( @@ -347,7 +359,9 @@ def lower_jaxpr_to_triton_module( block_mapping.array_shape_dtype, _eval_index_map(ctx, program_ids, block_mapping), _get_index_alignment(block_mapping), - block_mapping.block_shape, + tuple(pallas_core.squeezed if isinstance(b, pallas_core.Squeezed) + else pallas_core._get_block_dim_size(b) + for b in block_mapping.block_shape), ) for block_mapping in grid_mapping.block_mappings ] @@ -1833,7 +1847,8 @@ def _compute_offsets_from_indices( block_info: BlockInfo, nd_indexer: NDIndexer ) -> ir.Value: full_shape = block_info.full_shape_dtype.shape - num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape) + num_squeezed_dims = sum(isinstance(b, pallas_core.Squeezed) + for b in block_info.block_shape) strides = pallas_utils.strides_from_shape(full_shape) indexer_shape = nd_indexer.get_indexer_shape() int_indexer_shape = nd_indexer.int_indexer_shape @@ -1841,7 +1856,7 @@ def _compute_offsets_from_indices( indices = nd_indexer.indices other_shape = indexer_shape[len(int_indexer_shape) :] other_shape_idx = 0 - assert len(indices) + num_mapped_dims == len(full_shape) + assert len(indices) + num_squeezed_dims == len(full_shape) assert len(block_info.start_indices) == len(full_shape) array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype) @@ -1857,10 +1872,11 @@ def _compute_offsets_from_indices( for dim_stride, dim_block_size, start_offset in zip( strides, block_info.block_shape, block_info.start_indices ): - if dim_block_size is pallas_core.mapped: - index = _ir_constant(0, offset_eltype) - else: - index = next(indexer_iter) + match dim_block_size: + case pallas_core.Squeezed(): + index = _ir_constant(0, offset_eltype) + case int(): + index = next(indexer_iter) if isinstance(index, slice): index = primitives.Slice.from_slice(index, dim_block_size) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 2144be0fb18b..c05e1645ddbe 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -20,19 +20,19 @@ from jax._src.pallas.core import Blocked as Blocked from jax._src.pallas.core import BlockSpec as BlockSpec +from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import CompilerParams as CompilerParams from jax._src.pallas.core import core_map as core_map from jax._src.pallas.core import CostEstimate as CostEstimate +from jax._src.pallas.core import Element as Element from jax._src.pallas.core import GridSpec as GridSpec -from jax._src.pallas.core import IndexingMode as IndexingMode from jax._src.pallas.core import lower_as_mlir as lower_as_mlir from jax._src.pallas.core import MemoryRef as MemoryRef from jax._src.pallas.core import MemorySpace as MemorySpace -from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import no_block_spec as no_block_spec from jax._src.pallas.core import semaphore as semaphore -from jax._src.pallas.core import Unblocked as Unblocked -from jax._src.pallas.core import unblocked as unblocked +from jax._src.pallas.core import Squeezed as Squeezed +from jax._src.pallas.core import squeezed as squeezed from jax._src.pallas.cost_estimate import 
estimate_cost as estimate_cost from jax._src.pallas.helpers import empty as empty from jax._src.pallas.helpers import empty_like as empty_like @@ -48,6 +48,7 @@ from jax._src.pallas.primitives import atomic_xchg as atomic_xchg from jax._src.pallas.primitives import atomic_xor as atomic_xor from jax._src.pallas.primitives import debug_print as debug_print +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import dot as dot from jax._src.pallas.primitives import load as load from jax._src.pallas.primitives import max_contiguous as max_contiguous @@ -61,7 +62,6 @@ from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.primitives import store as store from jax._src.pallas.primitives import swap as swap -from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.utils import cdiv as cdiv from jax._src.pallas.utils import next_power_of_2 as next_power_of_2 from jax._src.pallas.utils import strides_from_shape as strides_from_shape diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 045c47a8dd71..54e4e4c47784 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -866,14 +866,14 @@ class PallasCallInterpretTest(PallasCallTest): INTERPRET = True -class PallasCallUnblockedIndexingTest(PallasBaseTest): +class PallasCallElementIndexingTest(PallasBaseTest): - def test_block_spec_unblocked(self): + def test_block_spec_element(self): def show_program_ids( - *, shape, block_shape, grid, indexing_mode: pl.IndexingMode + *, shape, block_shape, grid, ): def kernel(o1_ref): - assert o1_ref.shape == block_shape + assert o1_ref.shape == (8, 128) o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) return self.pallas_call( @@ -881,16 +881,15 @@ def kernel(o1_ref): jax.ShapeDtypeStruct(shape, dtype=np.int32), grid=grid, out_specs=pl.BlockSpec( - block_shape, lambda i: (8 * i, 0), indexing_mode=indexing_mode + block_shape, lambda i: (8 * i, 0), ), )() # No padding pids = show_program_ids( shape=(16, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8), pl.Element(128)), grid=(2,), - indexing_mode=pl.Unblocked(), ) expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 8, dtype=np.int32) self.assertAllClose(pids, expected_pids) @@ -901,9 +900,8 @@ def kernel(o1_ref): # Only high padding pids = show_program_ids( shape=(14, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8, (0, 2)), pl.Element(128, (0, 0))), grid=(2,), - indexing_mode=pl.Unblocked(((0, 2), (0, 0))), ) expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 6, dtype=np.int32) self.assertAllClose(pids, expected_pids) @@ -912,15 +910,14 @@ def kernel(o1_ref): self.skipTest("TODO: low padding not supported yet") pids = show_program_ids( shape=(11, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8, (3, 2)), pl.Element(128, (0, 0))), grid=(2,), - indexing_mode=pl.Unblocked(((3, 2), (0, 0))), ) expected_pids = np.array([[0] * 128] * 5 + [[1] * 128] * 6, dtype=np.int32) self.assertAllClose(pids, expected_pids) @parameterized.parameters("int32", "float32") - def test_block_spec_unblocked_padding_is_nan(self, dtype_name): + def test_block_spec_element_padding_is_nan(self, dtype_name): if not self.INTERPRET: self.skipTest("Only applicable for the interpret mode") @@ -935,7 +932,7 @@ def copy_kernel(x_ref, o_ref): grid=(1,), in_specs=[ pl.BlockSpec( - (6,), lambda i: 0, indexing_mode=pl.Unblocked(((1, 2),)) + (pl.Element(6, (1, 2)),), lambda i: 0, ) ], )(np.full((3,), 42, 
dtype=dtype)) @@ -949,7 +946,7 @@ def copy_kernel(x_ref, o_ref): ), ) - def test_unblocked_indexing(self): + def test_element_indexing(self): shape = (16 * 8, 128) result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) @@ -962,11 +959,12 @@ def kernel(x_ref, o_ref): grid=(15,), in_specs=( pl.BlockSpec( - (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked + (pl.Element(2 * 8), pl.Element(128)), lambda i: (i * 8, 0), ), ), out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), out_shape=result_ty, + debug=True, )(x) ref = [] for i in range(15): @@ -991,9 +989,8 @@ def kernel(x_ref, y_ref): grid=(1,), in_specs=( pl.BlockSpec( - (2 * 8, 128), + (pl.Element(2 * 8, (0, 8)), pl.Element(128)), lambda i: (0, 0), - indexing_mode=pl.Unblocked(((0, 8), (0, 0))), ), ), out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), @@ -1002,7 +999,7 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x) -class PallasCallUnblockedIndexingInterpretTest(PallasCallUnblockedIndexingTest): +class PallasCallElementIndexingInterpretTest(PallasCallElementIndexingTest): INTERPRET = True diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index 9977c18d939d..0e0f80a0c2b5 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -117,6 +117,8 @@ def f_stateful(refs): x = pl.pallas_call( functools.partial(copy_kernel, x_ref, y_ref), + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA], out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype), input_output_aliases={0: 0}, From 1d25c82a7d18e11881c0deebabf2923bf977b5ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 17 Apr 2025 21:14:59 -0700 Subject: [PATCH 0687/1769] [Mosaic TPU] Enable non-sublane-aligned 2D int8 load/stores PiperOrigin-RevId: 748914721 --- jaxlib/mosaic/dialect/tpu/layout.cc | 66 +++++++++++++++-------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 7ae8681e6980..f041570b8371 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -249,25 +249,22 @@ class TiledRectangularVregBounds : public VRegDataBounds { FailureOr> getVectorMask( OpBuilder& builder, const Location loc, const int generation, const std::array target_shape) const override { + const int8_t bitwidth = layout_.bitwidth(); + const int packing = layout_.packing(); + const int max_subelems = generation < 4 ? 1 : generation < 5 ? 2 : 4; const IntegerType i1 = builder.getI1Type(); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType mask_vreg_ty, [&]() -> FailureOr { - // I'm pretty sure this works for all bitwidths, but it's untested. - if (maskVariesAlong(Direction::kSubelements, target_shape)) { - if (layout_.packing() != 2) { - // TODO(b/300082350): Generalize this - return emitError(loc, "Not implemented: packing != 2"); - } - // For older TPUs, we virtualize masking - if (generation < 4) { - return VectorType::get(target_shape, i1); - } else { - return VectorType::get( - {target_shape[0], target_shape[1], layout_.packing()}, i1); - } - } + const VectorType mask_vreg_ty = [&]() { + if (maskVariesAlong(Direction::kSubelements, target_shape)) { + // When CreateSubelementMask isn't supported, we virtualize masking. 
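+      // (Worked example, assuming a hypothetical int8 layout with packing = 4
+      // and bitwidth = 8 on a generation where max_subelems = 2: we take the
+      // virtualized branch below, and a start_row % packing == 3 later yields
+      // the row bitmask 0xFFFFFFFF << 3 * 8 == 0xFF000000.)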
+ if (packing > max_subelems) { return VectorType::get(target_shape, i1); - }()); + } else { + return VectorType::get( + {target_shape[0], target_shape[1], packing}, i1); + } + } + return VectorType::get(target_shape, i1); + }(); if (isComplete(target_shape)) { return cast>( builder @@ -279,7 +276,6 @@ class TiledRectangularVregBounds : public VRegDataBounds { } Value mask = nullptr; CHECK_GE(num_tiles_, 0); - const int packing = layout_.packing(); const int64_t start_sub = start_offsets_[0] / packing; const int64_t end_sub = llvm::divideCeil(end_offsets_[0], packing); CHECK_LE(0, start_sub); @@ -308,20 +304,20 @@ class TiledRectangularVregBounds : public VRegDataBounds { if (maskVariesAlong(Direction::kSubelements, target_shape)) { int64_t start_row = start_offsets_[0] + row_offset; int64_t end_row = end_offsets_[0] + row_offset; - if (generation >= 4) { + if (packing <= max_subelems) { // Only use non-trivial start/end if they don't fall on sublane // boundary. Otherwise CreateMaskOp already does the right thing. This // lets us use cheaper instruction sequences on TPUv4. - if (start_offsets_[0] % layout_.packing() == 0) { + if (start_offsets_[0] % packing == 0) { start_row = 0; } - if (end_offsets_[0] % layout_.packing() == 0) { - end_row = target_shape[0] * layout_.packing(); + if (end_offsets_[0] % packing == 0) { + end_row = target_shape[0] * packing; } auto submask = builder.create( loc, mask_vreg_ty, start_row, end_row); tile_mask = builder.create(loc, tile_mask, submask); - } else { // generation < 4 + } else { // packing > max_subelems const auto getMaskCst = [&](const uint64_t v) { const auto int_mask_ty = VectorType::get(target_shape, builder.getI32Type()); @@ -333,25 +329,33 @@ class TiledRectangularVregBounds : public VRegDataBounds { }; tile_mask = builder.create( loc, tile_mask, getMaskCst(0xFFFFFFFF), getMaskCst(0)); - if (start_row % 2 != 0) { + if (const int64_t row_in_sublane = start_row % packing; + row_in_sublane != 0) { auto row_mask = builder.create( loc, mask_vreg_ty, - ValueRange{boundIdxConst(start_row / 2), boundIdxConst(0)}, - ValueRange{boundIdxConst(start_row / 2 + 1), + ValueRange{boundIdxConst(start_row / packing), + boundIdxConst(0)}, + ValueRange{boundIdxConst(start_row / packing + 1), boundIdxConst(target_shape[1])}); auto row_bitmask = builder.create( - loc, row_mask, getMaskCst(0xFFFF0000), getMaskCst(0xFFFFFFFF)); + loc, row_mask, + getMaskCst(0xFFFFFFFF << row_in_sublane * bitwidth), + getMaskCst(0xFFFFFFFF)); tile_mask = builder.create(loc, tile_mask, row_bitmask); } - if (end_row % 2 != 0) { + if (const int64_t row_in_sublane = end_row % packing; + row_in_sublane != 0) { auto row_mask = builder.create( loc, mask_vreg_ty, - ValueRange{boundIdxConst(end_row / 2), boundIdxConst(0)}, - ValueRange{boundIdxConst(end_row / 2 + 1), + ValueRange{boundIdxConst(end_row / packing), boundIdxConst(0)}, + ValueRange{boundIdxConst(end_row / packing + 1), boundIdxConst(target_shape[1])}); auto row_bitmask = builder.create( - loc, row_mask, getMaskCst(0xFFFF), getMaskCst(0xFFFFFFFF)); + loc, row_mask, + getMaskCst(0xFFFFFFFFu >> + (packing - row_in_sublane) * bitwidth), + getMaskCst(0xFFFFFFFF)); tile_mask = builder.create(loc, tile_mask, row_bitmask); } From 767c741da62399590da2554081b9547e5411823b Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 18 Apr 2025 06:13:20 -0700 Subject: [PATCH 0688/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9477be9b2b3afdc2c93b7c272d00f8034750f41c. 
PiperOrigin-RevId: 749017950 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 846ec8d56d97..7a393c317533 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a5c81f946fd2717c13075221b692db59b513a6eb" -XLA_SHA256 = "ae1f5afc050d3d6add7c8b03e8f82227c0e9931382889347e4d22ca38af1026c" +XLA_COMMIT = "9477be9b2b3afdc2c93b7c272d00f8034750f41c" +XLA_SHA256 = "07f295daea8cee7b72c1be64c0546b26f121b31b7897061c3fe86acfb52e238d" def repo(): tf_http_archive( From 13686bd8867f53aa031c825c31fcd14675fca9a9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 18 Apr 2025 08:05:00 -0700 Subject: [PATCH 0689/1769] Remove unused type forwarding declarations for jax.lib.xla_client.{Shape,XlaComputation}. This should not be a breaking change; the underlying definitions were already removed. PiperOrigin-RevId: 749037564 --- jax/lib/xla_client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index cc4fa78eb576..d4571ae297b9 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -150,8 +150,6 @@ import typing as _typing if _typing.TYPE_CHECKING: - Shape = _xc.Shape - XlaComputation = _xc.XlaComputation get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile mlir_api_version = _xc.mlir_api_version From 492cd3d9313cfd45e8bd63a8f51aa63d92924cd5 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 18 Apr 2025 09:02:08 -0700 Subject: [PATCH 0690/1769] Reverts c2ba1790417ca206a4d88b25aef4d5ae510dd717 PiperOrigin-RevId: 749049676 --- jax/_src/interpreters/ad.py | 24 +++---- jax/_src/interpreters/partial_eval.py | 99 +++++++++++---------------- jax/_src/pjit.py | 4 +- jax/experimental/attrs.py | 3 +- jax/experimental/shard_map.py | 7 +- 5 files changed, 57 insertions(+), 80 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 0f11e0d72f12..98cda2df4964 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -89,14 +89,12 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, nzs_in: Sequence[bool], debug_info: core.DebugInfo, *primals, **params): - source_info = source_info_util.current() with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace(debug_info) tangent_trace.tag = _tag linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) tracers = [LinearizeTracer(linearize_trace, p, - tangent_trace.new_arg(get_aval(p).to_tangent_aval(), - source_info)) + tangent_trace.new_arg(get_aval(p).to_tangent_aval())) if nz else p for p, nz in zip(primals, nzs_in)] with core.set_current_trace(linearize_trace, check_leaks=True): @@ -105,7 +103,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, del linearize_trace, ans, tracers nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) - out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] + out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] jaxpr, consts, attrs_tracked = 
tangent_trace.to_jaxpr(out_tangents, debug_info) if attrs_tracked: raise NotImplementedError("TODO: attrs") @@ -174,14 +172,13 @@ def _linearize_jaxpr( lin_trace = LinearizeTrace(primal_trace, tangent_trace) tangent_trace.tag = lin_trace.tag - def new_arg(trace, primal_aval, nz, source_info): - primal = primal_trace.new_arg(primal_aval, source_info) + def new_arg(trace, primal_aval, nz): + primal = primal_trace.new_arg(primal_aval) tangent_aval = primal_aval.to_tangent_aval() - tangent = tangent_trace.new_arg(tangent_aval, source_info) if nz else Zero(tangent_aval) + tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval) return LinearizeTracer(trace, primal, tangent) - source_info = source_info_util.current() - tracers = [new_arg(lin_trace, v.aval, nz, source_info) + tracers = [new_arg(lin_trace, v.aval, nz) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)] with core.set_current_trace(lin_trace, check_leaks=True): @@ -191,7 +188,7 @@ def new_arg(trace, primal_aval, nz, source_info): debug_info = jaxpr.jaxpr.debug_info nzs_out = [type(t) is not Zero for t in out_tangents] - out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t, source_info) + out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t) for (nz, t) in zip(nzs_out, out_tangents) if nz) tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) tangent_trace.invalidate() @@ -203,7 +200,7 @@ def new_arg(trace, primal_aval, nz, source_info): tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used] residuals_and_primals = (*tangent_consts, *out_primals) - residuals_and_primals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info), residuals_and_primals) + residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) primal_trace.invalidate() num_residuals = len(tangent_consts) @@ -215,9 +212,8 @@ def new_arg(trace, primal_aval, nz, source_info): def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *, has_aux=False, tag=None): with core.take_current_trace() as parent_trace: - source_info = source_info_util.current() tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info) - tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals] + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) tangent_trace.tag = linearize_trace.tag @@ -238,7 +234,7 @@ def direct_linearize(traceable: lu.WrappedFun, del linearize_trace, ans, tracers out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] - out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_nz_tangents) + out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) tangent_trace.invalidate() jaxpr, used_consts, _ = pe.dce_jaxpr_consts( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 8be78f1f6eb9..f8ce92e7f97f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -41,7 +41,6 @@ JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, 
InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) -from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure, register_static) @@ -1730,8 +1729,7 @@ def to_jaxpr( invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) - source_info = source_info_util.current() - state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x, source_info))] + state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x))] for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars @@ -1901,51 +1899,51 @@ def invalidate(self): self.frame.constid_to_tracer = {} self.frame.constvar_to_val = {} - def to_jaxpr_tracer(self, x, source_info: SourceInfo): + def to_jaxpr_tracer(self, x): as_local_var = self.frame.tracer_to_var.get(id(x)) if as_local_var is None: if hasattr(x, "dimension_as_value"): # Used for shape_poly._DimExpr with core.set_current_trace(self): x = x.dimension_as_value() - return self.to_jaxpr_tracer(x, source_info) + return self.to_jaxpr_tracer(x) else: - return self.new_const(x, source_info) + return self.new_const(x) else: return x - def new_arg(self, aval, source_info: SourceInfo): - tracer = DynamicJaxprTracer(self, aval, source_info) + def new_arg(self, aval): + tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.invars.append(var) return tracer - def new_const(self, c, source_info: SourceInfo): + def new_const(self, c): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: aval = get_aval(c) if hasattr(aval, "weak_type"): aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) - aval = self._lift_tracers_in_aval(aval, source_info) - tracer = self._new_const(aval, c, source_info) + aval = self._lift_tracers_in_aval(aval) + tracer = self._new_const(aval, c) return tracer pure = lift = new_const - def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: - tracer = DynamicJaxprTracer(self, aval, source_info) + def _new_const(self, aval, c) -> DynamicJaxprTracer: + tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.constid_to_tracer[id(c)] = tracer self.frame.constvar_to_val[var] = c return tracer - def _lift_tracers_in_aval(self, aval, source_info: SourceInfo): + def _lift_tracers_in_aval(self, aval): if (not isinstance(aval, DShapedArray) or not any(isinstance(d, Tracer) for d in aval.shape)): return aval - shape = [self.to_jaxpr_tracer(d, source_info) if isinstance(d, Tracer) else d + shape = [self.to_jaxpr_tracer(d) if isinstance(d, Tracer) else d for d in aval.shape] return aval.update(shape=tuple(shape)) @@ -1968,9 +1966,7 @@ def is_const(self, tracer): def process_primitive(self, primitive, tracers, params): if (config.eager_constant_folding.value and all(map(self.is_const, tracers))): return primitive.bind_with_trace(core.eval_trace, tracers, params) - source_info = source_info_util.current() - to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) - jaxpr_tracers = map(to_jaxpr_tracer, tracers) + 
jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) return self.default_process_primitive(primitive, jaxpr_tracers, params) @@ -1993,19 +1989,17 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f: lu.WrappedFun, explicit_tracers, params): - source_info = source_info_util.current() - to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) if f.in_type is None: f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) assert f.in_type is not None - implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers, - source_info) - in_tracers = map(to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) + implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) + in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) + source_info = source_info_util.current() out_tracers: list[Tracer] = [] for aval, _ in out_type: if type(aval) is DShapedArray: @@ -2015,7 +2009,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, aval = aval.update(shape=tuple(get_referent(d) for d in shape)) out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) update_params = call_param_updaters.get(call_primitive) @@ -2028,9 +2022,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): - source_info = source_info_util.current() - to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) - tracers = map(to_jaxpr_tracer, tracers) + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) @@ -2049,9 +2041,10 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): out_avals = [core.unmapped_aval(axis_size, out_axis, a) if out_axis is not None else a for a, out_axis in zip(reduced_out_avals, out_axes)] + source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_in_axes = (None,) * len(consts) + params['in_axes'] new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, @@ -2069,9 +2062,7 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, jvp: lu.WrappedFun, tracers, symbolic_zeros: bool): - source_info = source_info_util.current() - to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) - tracers = map(to_jaxpr_tracer, tracers) + tracers = map(self.to_jaxpr_tracer, 
tracers) in_avals = [t.aval for t in tracers] in_tangent_avals = [t.to_tangent_aval() for t in in_avals] fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) @@ -2088,7 +2079,7 @@ def jvp_jaxpr_thunk(*in_zeros): out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, @@ -2097,7 +2088,7 @@ def jvp_jaxpr_thunk(*in_zeros): num_consts=len(consts), symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info) + source_info_util.current()) self.frame.add_eqn(eqn) return out_tracers @@ -2106,9 +2097,7 @@ def process_custom_vjp_call(self, prim: core.Primitive, fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, out_trees: Callable[[], Sequence[PyTreeDef]], symbolic_zeros: bool): - source_info = source_info_util.current() - to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) - tracers = map(to_jaxpr_tracer, tracers) + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) @@ -2121,9 +2110,9 @@ def fwd_jaxpr_from_zeros(*zeros): if attrs: raise NotImplementedError return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] + out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, # pytype: disable=attribute-error @@ -2133,7 +2122,7 @@ def fwd_jaxpr_from_zeros(*zeros): bwd=bwd, out_trees=out_trees, symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info) + source_info_util.current()) self.frame.add_eqn(eqn) return out_tracers @@ -2143,9 +2132,7 @@ def process_custom_transpose(self, prim: core.Primitive, # type: ignore[overrid out_types, lin_tree: PyTreeDef, res_tree: PyTreeDef, out_tree: PyTreeDef): - source_info = source_info_util.current() - to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) - tracers = map(to_jaxpr_tracer, tracers) + tracers = map(self.to_jaxpr_tracer, tracers) tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) in_avals_p = [t.aval for t in tracers] @@ -2165,9 +2152,9 @@ def transpose_jaxpr_thunk(): jaxpr, _, consts, () = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] + out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(to_jaxpr_tracer, call_consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, call_consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_call_jaxpr, @@ -2175,7 +2162,7 @@ def transpose_jaxpr_thunk(): out_types=out_types, res_tree=res_tree, lin_tree=lin_tree, out_tree=out_tree), closed_call_jaxpr.effects, - source_info) + source_info_util.current()) self.frame.add_eqn(eqn) return out_tracers @@ -2229,15 +2216,13 @@ def 
trace_to_jaxpr_dynamic( keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - source_info = source_info_util.current() - in_tracers = _input_type_to_tracers( - partial(trace.new_arg, source_info=source_info), in_avals) + in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] try: with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) + out_tracers = map(trace.to_jaxpr_tracer, ans) _check_no_returned_refs(fun.debug_info, out_tracers) jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) del fun, in_tracers, out_tracers, ans @@ -2284,14 +2269,12 @@ def trace_to_jaxpr_dynamic2( trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - source_info = source_info_util.current() in_avals, keep_inputs = unzip2(fun.in_type) - in_tracers = _input_type_to_tracers( - partial(trace.new_arg, source_info=source_info), in_avals) + in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) + out_tracers = map(trace.to_jaxpr_tracer, ans) jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info) del trace, in_tracers, out_tracers, ans @@ -2466,7 +2449,7 @@ def __hash__(self): def _extract_implicit_args( trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]], - explicit_tracers: Sequence[DynamicJaxprTracer], source_info: SourceInfo, + explicit_tracers: Sequence[DynamicJaxprTracer] ) -> Sequence[DynamicJaxprTracer]: # First, construct a list to represent the full argument list, leaving the # implicit arguments as Nones for now. 
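For orientation, `DynamicJaxprTrace` and these helpers are the machinery that `jax.make_jaxpr` drives; a small example that exercises the reverted code paths (its output is unchanged by whether `source_info` is threaded explicitly):

    import jax
    import jax.numpy as jnp

    print(jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0))
    # Prints, roughly:
    # { lambda ; a:f32[]. let b:f32[] = sin a; c:f32[] = mul b 2.0 in (c,) }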
@@ -2484,8 +2467,8 @@ def _extract_implicit_args( for d1, d2 in zip(aval.shape, tracer.aval.shape): if isinstance(d1, DBIdx): if tracers[d1.val] is None: - tracers[d1.val] = trace.to_jaxpr_tracer(d2, source_info) - assert tracers[d1.val] is trace.to_jaxpr_tracer(d2, source_info) + tracers[d1.val] = trace.to_jaxpr_tracer(d2) + assert tracers[d1.val] is trace.to_jaxpr_tracer(d2) assert all(t is not None for t in tracers) return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore @@ -2633,13 +2616,13 @@ def inline_jaxpr_into_trace( trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - src = source_info_util.current() - const_tracers = map(partial(trace.new_const, source_info=src), consts) + const_tracers = map(trace.new_const, consts) constvars = map(trace.getvar, const_tracers) argvars = map(trace.getvar, arg_tracers) env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], [*constvars, *argvars])) + src = source_info_util.current() for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] outvars = [Var('', v.aval) for v in eqn.outvars] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 595300f2c1b2..a10856cbced3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1928,12 +1928,12 @@ def pjit_staging_rule(trace, *args, **params): trace, jaxpr.jaxpr, jaxpr.consts, *args) jaxpr = params['jaxpr'] - source_info = source_info_util.current() if config.dynamic_shapes.value: jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( jaxpr, params['out_shardings'], params['out_layouts']) params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) + source_info = source_info_util.current() out_tracers = [] for aval in _out_type(jaxpr): if type(aval) is core.DShapedArray: @@ -1952,7 +1952,7 @@ def pjit_staging_rule(trace, *args, **params): assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(partial(trace.new_const, source_info=source_info), consts) + consts = map(trace.new_const, consts) in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 54fd0fe0b02f..0d40938a85c4 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -87,11 +87,10 @@ def _check_append_type_agreement(_, attr, curtype, valtype): def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): frame = trace.frame - source_info = source_info_util.current() def new_tracer(x): aval = core.get_aval(x) - tracer = pe.DynamicJaxprTracer(trace, aval, source_info) + tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) frame.attrs_vars.append(var) frame.tracers.append(tracer) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index b66b22304e17..b81a4656546b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -512,9 +512,7 @@ def _shard_map_staging( check_rep: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: - source_info = source_info_util.current() - to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, 
source_info=source_info) - in_tracers = map(to_jaxpr_tracer, in_tracers) + in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh, auto, check_rep), in_names, in_avals) @@ -529,9 +527,10 @@ def _shard_map_staging( out_avals = map(_check_shapedarray, out_avals_) out_avals = [_check_shapedarray(_unshard_aval(mesh, check_rep, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] + source_info = source_info_util.current() out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) + constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), From 96865709b1e1594beda4ec7d399ab21472204d18 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 18 Apr 2025 09:28:13 -0700 Subject: [PATCH 0691/1769] Allow the CPU collective implementation to be overridden to None. PiperOrigin-RevId: 749055960 --- jax/_src/config.py | 4 +++- jax/_src/xla_bridge.py | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 288cfc64c027..d427b9c84450 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1802,10 +1802,12 @@ def _update_garbage_collection_guard(state, key, val): include_in_jit_key=True ) +DEFAULT_CPU_COLLECTIVES_IMPL = "gloo" + cpu_collectives_implementation = optional_enum_state( name='jax_cpu_collectives_implementation', enum_values=["gloo", "mpi", "megascale"], - default=None, + default=DEFAULT_CPU_COLLECTIVES_IMPL, help=( "Cross-process collective implementation used on CPU. Must be one of " '("gloo", "mpi")'), diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 5fb42c333605..25df9501d49a 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -65,8 +65,6 @@ MIN_COMPUTE_CAPABILITY = 52 -_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo' - # TODO(phawkins): Remove jax_xla_backend. _XLA_BACKEND = config.string_flag( 'jax_xla_backend', '', @@ -251,8 +249,6 @@ def make_cpu_client( '"jax_cpu_collectives_implementation", "gloo")` instead.', DeprecationWarning, ) - if collectives_impl is None: - collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL if collectives_impl == 'gloo': collectives = xla_client._xla.make_gloo_tcp_collectives( From 95156068927834ccccfa8e4133c213cde2a003ce Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 18 Apr 2025 10:16:39 -0700 Subject: [PATCH 0692/1769] [JAX] Remove jax.lib.xla_client.mlir_api_version and its uses. (We leave the name exported by JAX to avoid breaking users, but fixed to its last known value.) PiperOrigin-RevId: 749070199 --- jax/_src/lib/__init__.py | 3 --- jax/lib/xla_client.py | 4 ++-- jaxlib/xla/xla_client.py | 3 --- jaxlib/xla/xla_client.pyi | 2 -- 4 files changed, 2 insertions(+), 10 deletions(-) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 5d65e005d897..e9f9d95f608f 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -128,9 +128,6 @@ def _xla_gc_callback(*args): import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error # noqa: F401 import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401 -# Version number for MLIR:Python APIs, provided by jaxlib. 
-mlir_api_version = xla_client.mlir_api_version - # TODO(rocm): check if we need the same for rocm. def _cuda_path() -> str | None: diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index d4571ae297b9..12f48b21f1c3 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -94,7 +94,7 @@ "jax.lib.xla_client.mlir_api_version was deprecated in JAX v0.6.0" " and will be removed in JAX v0.7.0" ), - _xc.mlir_api_version, + 58, ), "Client": ( ( @@ -152,7 +152,7 @@ if _typing.TYPE_CHECKING: get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile - mlir_api_version = _xc.mlir_api_version + mlir_api_version = 58 Client = _xc.Client CompileOptions = _xc.CompileOptions DeviceAssignment = _xc.DeviceAssignment diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 69f24ee13f6f..0cbd2b3f3b4d 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -58,9 +58,6 @@ # In JAX, reference this via jax._src.lib.ifrt_version. _ifrt_version = _xla.ifrt_version_number -# Version number for MLIR:Python components. -mlir_api_version = 58 - xla_platform_names = { 'cpu': 'Host', 'gpu': 'CUDA', diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index f45b27c461e8..c10556e83920 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -60,8 +60,6 @@ _version: int _ifrt_version: int -mlir_api_version: int - XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] From 934a0db43e2dd276d7382c0bc63c9a8039a95dc7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 18 Apr 2025 11:06:52 -0700 Subject: [PATCH 0693/1769] Don't use the deprecated jax.dlpack.to_dlpack in tests. PiperOrigin-RevId: 749086392 --- tests/array_interoperability_test.py | 29 ++------------------------ tests/pytorch_interoperability_test.py | 5 +---- 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index adfd34627e76..a61c19ab4e0b 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -113,35 +113,11 @@ def _check_copy(x: jax.Array, y: jax.Array, expect_copy): y = jax.device_put(x, device) # TODO(parkers): Remove after setting 'stream' properly below. 
jax.block_until_ready(y) - dl_device = y.__dlpack_device__() - if use_stream: - stream = tuple(y.devices())[0].get_stream_for_external_ready_events() - dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream) - else: - dlpack = jax.dlpack.to_dlpack(y, copy=copy) - z = jax.dlpack.from_dlpack(dlpack) + z = jax.dlpack.from_dlpack(y) self.assertEqual(z.devices(), {device}) self.assertAllClose(np.astype(x.dtype), z) - self.assertRaisesRegex(RuntimeError, - "DLPack tensor may be consumed at most once", - lambda: jax.dlpack.from_dlpack(dlpack)) - - if shape in nonempty_array_shapes: - _check_copy(y, z, bool(copy)) - - # Check if the destination device can be specified - make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy) - if copy == False: - self.assertRaisesRegex(ValueError, "copy=False", make_dlpack) - return - - z = jax.dlpack.from_dlpack(make_dlpack()) - self.assertEqual(z.devices(), {device}) - self.assertAllClose(x, z) - if shape in nonempty_array_shapes: - _check_copy(x, z, True) @jtu.sample_product( shape=all_shapes, @@ -215,8 +191,7 @@ def testJaxToTensorFlow(self, shape, dtype): # TODO(b/171320191): this line works around a missing context initialization # bug in TensorFlow. _ = tf.add(1, 1) - dlpack = jax.dlpack.to_dlpack(x) - y = tf.experimental.dlpack.from_dlpack(dlpack) + y = tf.experimental.dlpack.from_dlpack(x.__dlpack__()) self.assertAllClose(np, y.numpy()) @unittest.skipIf(not tf, "Test requires TensorFlow") diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index 3035e68d234c..4b8f58cd6982 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -67,8 +67,6 @@ def testTorchToJaxFailure(self): y, client, client) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) - @jtu.ignore_warning(message="jax.dlpack.to_dlpack was deprecated.*", - category=DeprecationWarning) def testJaxToTorch(self, shape, dtype): if not config.enable_x64.value and dtype in [ jnp.int64, @@ -79,8 +77,7 @@ def testJaxToTorch(self, shape, dtype): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) x = jnp.array(np) - dlpack = jax.dlpack.to_dlpack(x) - y = torch.utils.dlpack.from_dlpack(dlpack) + y = torch.utils.dlpack.from_dlpack(x) if dtype == jnp.bfloat16: # .numpy() doesn't work on Torch bfloat16 tensors. self.assertAllClose(np, From 854b2c85db03c9f0803c8423e3d5bfe7ac3901ff Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 18 Apr 2025 11:16:36 -0700 Subject: [PATCH 0694/1769] Drop into `Auto` mode for `.at[...].set(...)` but instead of taking an `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved. Fixes https://github.com/jax-ml/jax/issues/28111 PiperOrigin-RevId: 749089846 --- jax/_src/numpy/array_methods.py | 6 +++++- jax/_src/numpy/indexing.py | 15 ++++++++++----- jax/_src/ops/scatter.py | 20 +++++++++++++++----- tests/pjit_test.py | 22 ++++++++++++++++++++++ 4 files changed, 52 insertions(+), 11 deletions(-) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 0d7c50ee3358..a64a6662e2a7 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -795,9 +795,13 @@ def set(self, values, *, indices_are_sorted=False, unique_indices=False, See :mod:`jax.ops` for details. 
""" + out_s = core.typeof(self.array).sharding + if out_s.mesh.empty or out_s.mesh._are_all_axes_auto_or_manual: + out_s = None return scatter._scatter_update(self.array, self.index, values, lax.scatter, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + out_sharding=out_s) def apply(self, func, *, indices_are_sorted=False, unique_indices=False, mode=None): diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 21a97277e433..044b5175a46a 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -521,7 +521,8 @@ def _is_contiguous_slice(idx): (idx.stop is None or _is_integer_index(idx.stop)) and (idx.step is None or (_is_integer_index(idx.step) and idx.step == 1))) -def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> Array | None: +def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None, + out_sharding=None) -> Array | None: # attempt to compute _rewriting_take via lax.slice(); return None if not possible. idx = idx if isinstance(idx, tuple) else (idx,) @@ -604,9 +605,12 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> jnp_error._check_precondition_oob_dynamic_slice( arr.shape, start_indices, slice_sizes, allow_negative_indices ) - arr = lax.dynamic_slice( - arr, start_indices=start_indices, slice_sizes=slice_sizes, - allow_negative_indices=allow_negative_indices) + internal_ds = partial(lax.dynamic_slice, slice_sizes=slice_sizes, + allow_negative_indices=allow_negative_indices) + if out_sharding is not None: + arr = auto_axes(internal_ds, out_shardings=out_sharding)(arr, start_indices) + else: + arr = internal_ds(arr, start_indices) if int_indices: arr = lax.squeeze(arr, tuple(int_indices)) return arr @@ -621,7 +625,8 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, # For simplicity of generated primitives, we call lax.dynamic_slice in the # simplest cases: i.e. non-dynamic arrays indexed with integers and slices. - if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None: + result = _attempt_rewriting_take_via_slice(arr, idx, mode, out_sharding) + if result is not None: return result # TODO(mattjj,dougalm): expand dynamic shape indexing support diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index eccbbdde006e..baf6a79328b4 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -19,6 +19,7 @@ from collections.abc import Callable, Sequence from typing import Union import warnings +from functools import partial import numpy as np @@ -30,6 +31,7 @@ from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.numpy import indexing +from jax._src.pjit import auto_axes from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions from jax._src.numpy.util import check_arraylike, promote_dtypes @@ -43,7 +45,8 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, - unique_indices, mode=None, normalize_indices=True): + unique_indices, mode=None, normalize_indices=True, + out_sharding=None): """Helper for indexed updates. Computes the value of x that would result from computing:: @@ -74,15 +77,22 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, # XLA gathers and scatters are very similar in structure; the scatter logic # is more or less a transpose of the gather equivalent. 
treedef, static_idx, dynamic_idx = indexing.split_index_for_jit(idx, x.shape) - return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, - indices_are_sorted, unique_indices, mode, - normalize_indices) + + internal_scatter = partial( + _scatter_impl, scatter_op=scatter_op, treedef=treedef, + static_idx=static_idx, indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode, + normalize_indices=normalize_indices) + if out_sharding is not None: + return auto_axes(internal_scatter, out_shardings=out_sharding + )(x, y, dynamic_idx) + return internal_scatter(x, y, dynamic_idx) # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). # @partial(jit, static_argnums=(2, 3, 4)) -def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, +def _scatter_impl(x, y, dynamic_idx, *, scatter_op, treedef, static_idx, indices_are_sorted, unique_indices, mode, normalize_indices): dtype = lax.dtype(x) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 53f0d042e801..c8a4ed906ff8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7488,6 +7488,28 @@ def vmap_where(batch_a, batch_b): out = vmap_where(batch_a, batch_b) self.assertEqual(out.sharding, xy_sharding) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_scatter_gather(self, mesh): + x = np.random.uniform(size=(mesh.size * 2, 3)) + i = np.random.randint(0, x.shape[1], len(x)) + j = np.random.randint(0, x.shape[1], len(x)) + x = jax.device_put(x, P("x")) + i = jax.device_put(i, P("x")) + j = jax.device_put(j, P("x")) + + @jax.jit + def f1(x, i, j): + x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding) + return x.at[:, i].set(x_a_j) + f1(x,i,j) # doesn't crash + + @jax.jit + @jax.vmap + def f2(x, i, j): + x_j = x.at[j].get(out_sharding=jax.typeof(x).sharding) + return x.at[i].set(x_j) + f2(x,i,j) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 7f7d7805e4594e6c9702706674fd90722f46e848 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 18 Apr 2025 11:17:01 -0700 Subject: [PATCH 0695/1769] Mark CliDebuggerTest as a thread-unsafe test. Fixes some test flakiness with JAX_TEST_NUM_THREADS > 1. PiperOrigin-RevId: 749089988 --- tests/debugger_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 419e7b18dfed..0d66cd47d8cc 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -43,6 +43,10 @@ def _format_multiline(text): foo = 2 +# This test is thread-unsafe because jax.effects_barrier() is global. This means +# that we can create a deadlock if running tests in multiple threads because we +# can introduce false dependencies via the effects barrier. +@jtu.thread_unsafe_test_class() class CliDebuggerTest(jtu.JaxTestCase): def setUp(self): From c6482ed636332336acb48d2348ccbd916aa11bcb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 16 Apr 2025 14:21:39 -0400 Subject: [PATCH 0696/1769] Ensure outputs are tracers when inlining jit. 
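
The staging rule now funnels the results of both branches (core.eval_jaxpr under
dynamic shapes, pe.inline_jaxpr_into_trace otherwise) through
trace.to_jaxpr_tracer, so callers of an inlined jit always receive tracers, even
for constant outputs. A minimal sketch of the failure mode, adapted from the
regression test added below (the constant-returning function is illustrative):

    import jax
    import jax.numpy as jnp

    def bar(x):
      return jnp.array(1)  # a constant, unrelated to x

    def foo(x):
      x = jax.jit(bar, inline=True)(x)
      # Previously the inlined call could hand back the constant itself
      # rather than a tracer, so trace-time operations on it could fail.
      assert x.shape == ()
      return x

    jax.jit(foo)(0)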
--- jax/_src/pjit.py | 7 ++++--- tests/pjit_test.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a10856cbced3..34f2ef0487a7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1921,11 +1921,12 @@ def pjit_staging_rule(trace, *args, **params): # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, # but redundantly performs abstract evaluation again. with core.set_current_trace(trace): - return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, - propagate_source_info=False) + out = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, + propagate_source_info=False) else: - return pe.inline_jaxpr_into_trace( + out = pe.inline_jaxpr_into_trace( trace, jaxpr.jaxpr, jaxpr.consts, *args) + return [trace.to_jaxpr_tracer(x) for x in out] jaxpr = params['jaxpr'] if config.dynamic_shapes.value: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index c8a4ed906ff8..bd2ffb8e1f58 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3264,6 +3264,17 @@ def g(x): jaxpr = jax.make_jaxpr(g)(3) self.assertNotIn('pjit', str(jaxpr)) + def test_pjit_inline_literal(self): + # https://github.com/jax-ml/jax/issues/27545 + def bar(x): + return jnp.array(1) + + def foo(x): + x = pjit(bar, inline=True)(x) + self.assertEqual(x.shape, ()) + + pjit(foo)(0) # doesn't crash + def test_pmap_in_axis_resources_error(self): pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) From d7c22eb1f1aaf4cf41e2e0aa103d8c67ab206daa Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 18 Apr 2025 13:05:58 -0700 Subject: [PATCH 0697/1769] #sdy Do not close any partially sharded dimensions if using auto axes in a shard_map. This change reverts cl/731724837 ([github link](https://github.com/jax-ml/jax/commit/4997e45743e3b243ef153674a11a826843ab37b0)), which is a temporary solution to solve the inconsistent padding on the boundary of manual computation. Now that we have a better solution cl/746600070 ([github link](https://github.com/openxla/xla/pull/25080)), we revert this temporary solution. We still keep the added `test_partially_sharded_dim_with_auto` to verify the correctness. Reverts 4997e45743e3b243ef153674a11a826843ab37b0 PiperOrigin-RevId: 749123936 --- jax/experimental/shard_map.py | 4 +--- tests/shard_map_test.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index b81a4656546b..878c42a8d591 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -657,9 +657,7 @@ def _shardy_shard_map_sharding( sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) if auto: for dim_sharding in sdy_sharding.dimension_shardings: - # Only allow dimensions which have no sharding to be auto-sharded. 
- if not dim_sharding.axes: - dim_sharding.is_open = True + dim_sharding.is_open = True return sdy_sharding diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 5449c68577e0..6235ae3b60ec 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1955,8 +1955,8 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) if config.use_shardy_partitioner.value: self.assertIn( - 'in_shardings=[<@mesh, [{"i"}, {?}]>]' - ' out_shardings=[<@mesh, [{"i"}, {?}]>] manual_axes={"i"}', + 'in_shardings=[<@mesh, [{"i", ?}, {?}]>]' + ' out_shardings=[<@mesh, [{"i", ?}, {?}]>] manual_axes={"i"}', f.lower(v).as_text(), ) else: From 80d1fbac42923613130b7f022a324cec6663a4d8 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 18 Apr 2025 13:10:40 -0700 Subject: [PATCH 0698/1769] Handle `sharding` param in convert_element_type's batching rule properly by adding the explicit mesh axis on dim 0 PiperOrigin-RevId: 749125322 --- jax/_src/interpreters/batching.py | 2 ++ jax/_src/lax/lax.py | 22 +++++++++++++++++----- tests/pjit_test.py | 17 +++++++++++++++++ 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index ee0b46feddb7..c17fe892325e 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -1105,6 +1105,8 @@ def broadcast(x, sz, axis, mesh_axis=None): shape.insert(axis, sz) broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis)) x_aval = core.get_aval(x) + if x_aval.sharding.mesh.empty: + mesh_axis = None new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis)) sharding = x_aval.sharding.with_spec(new_spec) # TODO(dougalm, yashkatariya): Delete this context manager once we figure diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 36b19b47a324..da9d64535cd1 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4888,7 +4888,16 @@ def _convert_element_type_bind_with_trace(trace, args, params): partial(core.standard_vma_rule, convert_element_type_p.name))) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule -batching.defvectorized(convert_element_type_p) + +def _convert_element_type_batching_rule( + axis_data, batched_args, batch_dims, *, new_dtype, weak_type, sharding): + if sharding is not None: + sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0) + new_params = dict(new_dtype=new_dtype, weak_type=weak_type, sharding=sharding) + return convert_element_type_p.bind(*batched_args, **new_params), batch_dims[0] +batching.fancy_primitive_batchers[convert_element_type_p] = _convert_element_type_batching_rule +batching.skippable_batchers[convert_element_type_p] = lambda _: () + pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule pe.def_trivial_padding(convert_element_type_p) @@ -7336,7 +7345,8 @@ def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): else: # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 - which = broadcast_in_dim(which, cases[0].shape, [which_bdim]) + which = broadcast_in_dim(which, cases[0].shape, [which_bdim], + out_sharding=core.typeof(cases[0]).sharding) return select_n(which, *cases), which_bdim elif np.ndim(which) == 0 and all(bdim is not None for bdim in case_bdims): if all(case_bdims[0] == bdim for bdim in 
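
In other words: when vmap batches a convert_element_type whose `sharding`
parameter names only the per-example axes, the new batching rule inserts the
vmapped explicit mesh axis at dim 0 before binding the primitive, whereas the
old defvectorized rule passed the per-example sharding through unchanged, at the
wrong rank for the batched value. A condensed sketch of the regression test
added below (it relies on the test's 2x2 explicit mesh with axes ('x', 'y') and
the test module's imports; `_convert_element_type` is a private lax helper):

    @jax.jit
    @jax.vmap
    def f(x):
      # Per-example sharding is P('y'); the rule prepends the explicit mesh
      # axis of the batched dimension, so the result comes out as P('x', 'y').
      return lax_internal._convert_element_type(
          x, jnp.bfloat16, sharding=NamedSharding(mesh.abstract_mesh, P('y')))

    arr = jax.device_put(np.arange(16).reshape(8, 2), P('x', 'y'))
    out = f(arr)  # out.sharding == NamedSharding(mesh, P('x', 'y'))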
case_bdims[1:]): @@ -7347,8 +7357,9 @@ def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): for c, c_bdim in zip(cases[1:], case_bdims[1:])] return select_n(which, cases[0], *other_cases), bdim - which = (batching.bdim_at_front(which, which_bdim, size) if np.shape(which) - else which) + which = (batching.bdim_at_front(which, which_bdim, size, + axis_data.explicit_mesh_axis) + if np.shape(which) else which) if not all(() == np.shape(c) for c in cases): cases = [batching.bdim_at_front(c, bdim, size, axis_data.explicit_mesh_axis) for c, bdim in zip(cases, case_bdims)] @@ -7356,7 +7367,8 @@ def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): if 0 < np.ndim(which) < np.ndim(cases[0]): # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 - which = broadcast_in_dim(which, cases[0].shape, [0]) + which = broadcast_in_dim(which, cases[0].shape, [0], + out_sharding=core.typeof(cases[0]).sharding) if np.ndim(which) > np.ndim(cases[0]): assert np.ndim(cases[0]) == 0 cases = [broadcast(c, which.shape) for c in cases] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index bd2ffb8e1f58..d8b2cf9483be 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7499,6 +7499,23 @@ def vmap_where(batch_a, batch_b): out = vmap_where(batch_a, batch_b) self.assertEqual(out.sharding, xy_sharding) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_convert_element_type_vmap(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + am = mesh.abstract_mesh + + @jax.jit + @jax.vmap + def f(x): + y = lax_internal._convert_element_type( + x, jnp.bfloat16, sharding=NamedSharding(am, P('y'))) + self.assertEqual(y.aval.sharding.spec, P('y')) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.with_explicit_mesh((2,), ('x',)) def test_scatter_gather(self, mesh): x = np.random.uniform(size=(mesh.size * 2, 3)) From 0ae613ee480c037061f4936e4105c8218676c2a8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 18 Apr 2025 15:08:24 -0700 Subject: [PATCH 0699/1769] Makes Effort_02 the default value for memory_fitting_level. PiperOrigin-RevId: 749159983 --- jax/_src/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index d427b9c84450..83c5654b87d0 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1794,12 +1794,12 @@ def _update_garbage_collection_guard(state, key, val): 'O2', 'O3', ], - default='UNKNOWN', + default='O2', help=( 'The degree to which the compiler should attempt to make the program' ' fit in memory' ), - include_in_jit_key=True + include_in_jit_key=True, ) DEFAULT_CPU_COLLECTIVES_IMPL = "gloo" From 0e1b34196baf32bab1f2d913e957a7b9a5b556e5 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Fri, 18 Apr 2025 16:06:41 -0700 Subject: [PATCH 0700/1769] Refactor array serialization into separate JAX and tensorstore logic Array serialization in array_serialization.py contains a mixture of JAX specific serialization logic and tensorstore driver. This change separates JAX and tensorstore methods (a) making serialization more modular and (b) potentially allowing for alternative array serialization backends in the future. 
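
As a sketch of the new layering (all names as in the diff below), the public
module keeps its surface but now delegates to, and re-exports from, the new
backend module, pinning the old zarr defaults for backward compatibility:

    from jax.experimental.array_serialization import tensorstore_impl as ts_impl

    # for compatibility with older zarr format
    _get_metadata = functools.partial(ts_impl._get_tensorstore_metadata,
                                      driver='zarr')
    get_tensorstore_spec = functools.partial(ts_impl.get_tensorstore_spec,
                                             driver='zarr', ocdbt=False)

New code can instead call into ts_impl directly and pick up the new defaults
listed below.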
Additional clean-up changes include: - making the ocdbt kvstore driver the default in tensorstore - making array serialization tests more robust, especially on multi-host - explicit tensorstore array chunking to ensure chunk file size does not blow up PiperOrigin-RevId: 749175295 --- build/BUILD.bazel | 4 +- build/requirements.in | 1 + jax/_src/distributed.py | 4 +- jax/experimental/array_serialization/BUILD | 18 +- .../array_serialization/serialization.py | 402 ++------ .../array_serialization/serialization_test.py | 179 +++--- .../array_serialization/tensorstore_impl.py | 518 ++++++++++++++++++ jaxlib/jax.bzl | 1 + jaxlib/xla/xla_extension/__init__.pyi | 5 +- 9 files changed, 673 insertions(+), 459 deletions(-) create mode 100644 jax/experimental/array_serialization/tensorstore_impl.py diff --git a/build/BUILD.bazel b/build/BUILD.bazel index f088cd58aa74..d5c05504e70d 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -63,5 +63,5 @@ compile_pip_requirements( py_library( name = "all_py_deps", - deps = all_py_deps(["zstandard"]), -) \ No newline at end of file + deps = all_py_deps(["zstandard", "tensorstore"]), +) diff --git a/build/requirements.in b/build/requirements.in index b023cedfbd19..edd7982e74ca 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -19,5 +19,6 @@ scipy>=1.15.2; python_version>="3.13" ml_dtypes>=0.4.0 opt_einsum zstandard +tensorstore etils[epath] setuptools diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index fb0aebb0e642..1dc75b6d8dfc 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -37,8 +37,8 @@ class State: process_id: int = 0 num_processes: int = 1 - service: Any | None = None - client: Any | None = None + service: xla_extension.DistributedRuntimeService | Any | None = None + client: xla_extension.DistributedRuntimeClient | Any | None = None preemption_sync_manager: Any | None = None coordinator_address: str | None = None slice_index: int | None = None diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD index 84e5b9300912..d9f7e21e73f6 100644 --- a/jax/experimental/array_serialization/BUILD +++ b/jax/experimental/array_serialization/BUILD @@ -35,9 +35,12 @@ pytype_library( "serialization.py", ], visibility = ["//visibility:public"], - deps = ["//jax"] + py_deps([ - "numpy", + deps = [ + "//jax", + "//jax/experimental/array_serialization:tensorstore_impl", + ] + py_deps([ "absl/logging", + "numpy", ]), ) @@ -48,7 +51,16 @@ jax_multiplatform_test( "tpu_v3_x4", ], deps = [ + ":serialization", "//jax:experimental", - "//jax/experimental/array_serialization:serialization", ], ) + +pytype_library( + name = "tensorstore_impl", + srcs = ["tensorstore_impl.py"], + visibility = ["//visibility:public"], + deps = ["//jax"] + py_deps([ + "numpy", + ]), +) diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 8a082b6e912d..fd37694e2ce7 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -17,34 +17,43 @@ import abc import asyncio -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Callable, Sequence -from functools import partial +import functools import itertools import logging -import os import re import threading import time -from typing import Any, Optional +from typing import Any import jax from jax._src import array from jax._src import distributed -from
jax._src.layout import Layout from jax._src import typing from jax._src import util +from jax._src.layout import Layout from jax._src.lib import xla_extension as xe -import jax.numpy as jnp -import numpy as np -import tensorstore as ts +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +# ruff: noqa: F401 +# pylint: disable=unused-import +# import tensorstore-backed methods for backward compatibility. +from jax.experimental.array_serialization.tensorstore_impl import ( + _run_deserialization as run_deserialization, + _run_serialization as run_serialization, + async_serialize, async_deserialize, _TS_CONTEXT as TS_CONTEXT, + _DEFAULT_BASE_DRIVER as _DEFAULT_DRIVER, _LimitInFlightBytes) + +# for compatibility with older zarr format +_get_metadata = functools.partial(ts_impl._get_tensorstore_metadata, + driver='zarr') +get_tensorstore_spec = functools.partial(ts_impl.get_tensorstore_spec, + driver='zarr', ocdbt=False) +# pylint: enable=unused-import -TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}}) -_REMOVED_VALUE = 'Value removed' _CHECKPOINT_SUCCESS = 'checkpoint_write_success' _module_unique_count = itertools.count() -_DEFAULT_DRIVER = 'file' _DISTRIBUTED_SYSTEM_MSG = ( 'Please initialize the distributed system via ' '`jax.distributed.initialize()` at the start of your program.') @@ -54,7 +63,7 @@ {'driver': 's3', 'path_regex': None}, ] -class BarrierTimeoutException(Exception): +class BarrierTimeoutError(Exception): pass _BARRIER_TIMED_OUT_MSG = ( @@ -66,68 +75,6 @@ class BarrierTimeoutException(Exception): logger = logging.getLogger(__name__) -async def create_async_array_from_callback( - global_shape: array.Shape, - inp_sharding: jax.sharding.Sharding, - data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], -): - device_to_index_map = inp_sharding.devices_indices_map(global_shape) - addressable_da = inp_sharding._addressable_device_assignment - future_arrays = [data_callback(device_to_index_map[d], d) - for d in addressable_da] - dbs = await asyncio.gather(*future_arrays) - return array.make_array_from_single_device_arrays( - global_shape, inp_sharding, dbs) - - -def _get_metadata(arr): - local_shape = arr.addressable_data(0).shape - return { - 'compressor': {'id': 'zstd'}, - 'shape': arr.shape, - 'chunks': np.array(np.maximum(1, local_shape)), - } - - -def _spec_has_metadata(tree): - if not isinstance(tree, dict): - return False - return 'metadata' in tree or any( - _spec_has_metadata(subtree) for _, subtree in tree.items()) - -def _get_kvstore_for_gcs(ckpt_path: str): - m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL) - if m is None: - raise ValueError('The ckpt_path should contain the bucket name and the ' - f'file path inside the bucket. Got: {ckpt_path}') - gcs_bucket = m.group(1) - path_without_bucket = m.group(2) - return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket} - -def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): - # Normalize path to exclude trailing '/'. In GCS path case, we will need to - # fix the path prefix to add back the stripped '/'. - ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://') - is_gcs_path = ckpt_path.startswith('gs://') - spec = {'driver': 'zarr', 'kvstore': {}} - if ocdbt: - if not is_gcs_path and not os.path.isabs(ckpt_path): - raise ValueError(f'Checkpoint path should be absolute. 
Got {ckpt_path}') - base_path = os.path.dirname(ckpt_path) - spec['kvstore'] = { - 'driver': 'ocdbt', - 'base': base_path if is_gcs_path else f'{_DEFAULT_DRIVER}://{base_path}', - 'path': os.path.basename(ckpt_path), - } - else: - if is_gcs_path: - spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path) - else: - spec['kvstore'] = {'driver': _DEFAULT_DRIVER, 'path': ckpt_path} - - return spec - - def is_remote_storage(tspec: dict[str, Any] | str) -> bool: """Detect if user is using cloud storages. @@ -157,278 +104,6 @@ def is_remote_storage(tspec: dict[str, Any] | str) -> bool: return False - -# Lifted from T5X. -class _LimitInFlightBytes: - """Limits in-flight bytes when reading/writing checkpoints per process.""" - - def __init__(self, num_bytes): - self._max_bytes = num_bytes - self._available_bytes = num_bytes - self._cv = asyncio.Condition(lock=asyncio.Lock()) - - async def wait_for_bytes(self, requested_bytes): - if requested_bytes > self._max_bytes: - raise ValueError('Requested more bytes than we reserved space for: ' - f'{requested_bytes} > {self._max_bytes}') - async with self._cv: - await self._cv.wait_for(lambda: self._available_bytes > requested_bytes) - self._available_bytes -= requested_bytes - assert self._available_bytes >= 0 - - async def release_bytes(self, requested_bytes): - async with self._cv: - self._available_bytes += requested_bytes - assert self._available_bytes <= self._max_bytes - self._cv.notify_all() - - -async def transfer_shard_to_host(shard: array.Shard) -> np.ndarray: - data = shard.data - has_pinned_host = any( - m.kind == "pinned_host" for m in shard.device.addressable_memories()) - if has_pinned_host: - # If available, transfer to pinned host memory - sharding = jax.sharding.SingleDeviceSharding(shard.device, - memory_kind="pinned_host") - data = jax.device_put(data, sharding) - else: - data.copy_to_host_async() - # Allow other transfers to be scheduled simultaneously - await asyncio.sleep(0) - # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore - # implicitly converts the written data to a numpy array, and would otherwise - # silently copy host-to-host. - return np.array(data, copy=False) - - -async def async_serialize( - arr_inp, - tensorstore_spec, - commit_future=None, - context=TS_CONTEXT, - primary_host: int | None = 0, - replica_id: int = 0, - transaction: Optional[ts.Transaction] = None, -): - """Serialize an array using TensorStore. - - Args: - arr_inp: The array to serialize. - tensorstore_spec: The tensorstore spec to use. - commit_future: A list of futures that will be appended to. The futures can - be awaited asynchronously. If None, the futures will be awaited - synchronously by this method. - context: ts.Context instance. - primary_host: Primary host, which indicates the host that will be treated as - the "leader". If None, all hosts are treated as the primary. DO NOT USE - unless you are sure you know what you are doing. - replica_id: Allows overriding the shard replica id that will be saved. DO - NOT USE unless you are sure you know what you are doing. - transaction: TensorStore transaction to use for opening and writing the - array. If not specified, a non-transactional write will be used. - """ - if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and - arr_inp.is_fully_addressable): - raise ValueError( - f'Passing fully addressable arrays to a multiprocess ' - f'serialization is not allowed, as this may lead to a race condition ' - f'between processes. 
Serialization have failed for the array with ' - f'the path "{tensorstore_spec["kvstore"]["path"]}".') - - # 'metadata' may not be present at the top level (for example, if we are using - # a 'cast' driver). - if not _spec_has_metadata(tensorstore_spec): - tensorstore_spec['metadata'] = _get_metadata(arr_inp) - - # Set dtype if it's not in spec - if 'dtype' not in tensorstore_spec: - tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name - - # If primary_host is None, all hosts will checkpoint. This is used - # for checkpointing to local filesystem. - if primary_host is None or jax.process_index() == primary_host: - open_future = ts.open( - ts.Spec(tensorstore_spec), - create=True, - open=True, - context=context, - transaction=transaction, - ) - # Asynchronous case. - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(open_future) - else: - await open_future - - # `ts.open` runs twice for process `primary_host` because for the first time, - # we just get the future to be awaited upon in the background thread. The - # second one runs with `assume_metadata=True` which does no I/O operation and - # returns the tensorstore object. - # For every process other than `primary_host`, we open with - # `assume_metadata=True`. - t = await ts.open( - ts.Spec(tensorstore_spec), - open=True, - assume_metadata=True, - context=context, - transaction=transaction, - ) - - async def _write_array(shard): - if shard.replica_id == replica_id: - data = await transfer_shard_to_host(shard) - write_future = t[shard.index].write( - data, - # Avoid additional copy of input array into the TensorStore chunk - # cache. If `arr_inp` is a jax.Array, the result of converting - # it to a NumPy array, as is done internally by TensorStore, is - # guaranteed to be immutable and therefore it is safe to retain a - # reference indefinitely. - can_reference_source_data_indefinitely=isinstance( - arr_inp, array.ArrayImpl - ), - ) - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(write_future.commit) - await write_future.copy - else: - await write_future.commit - - local_shards = arr_inp.addressable_shards - future_write_state = jax.tree_util.tree_map(_write_array, local_shards) - return await asyncio.gather(*future_write_state) - - -def run_serialization(arrays, tensorstore_specs): - async def _run_serializer(): - future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs) - return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) - - -def estimate_read_memory_footprint(t: ts.TensorStore, - domain: ts.IndexDomain) -> int: - rank = t.rank - num_bytes = t.dtype.numpy_dtype.itemsize - chunk_template = t.chunk_layout.read_chunk_template - if domain is None: - domain = t.domain - origin = domain.origin - shape = domain.shape - chunk_origin = chunk_template.origin - chunk_shape = chunk_template.shape - - # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver. - # For those, instead of returning a near-infinite memory footprint, estimate - # the footprint as the entire shape. - for i in range(rank): - if not chunk_template[i].finite: - return domain.size * num_bytes - - # Otherwise, if we have a chunked driver, estimate based on chunk size. 
- for i in range(rank): - origin_value = origin[i] - chunk_origin_value = chunk_origin[i] - chunk_size = chunk_shape[i] - lower = origin_value - chunk_origin_value - upper = origin_value + shape[i] - chunk_origin_value - lower_aligned = lower // chunk_size * chunk_size - upper_aligned = -(-upper // chunk_size) * chunk_size - num_bytes *= (upper_aligned - lower_aligned) - - return num_bytes - - -async def async_deserialize( - user_in_sharding: jax.sharding.Sharding | Layout, - tensorstore_spec: ts.Spec | dict[str, Any], - global_shape: Sequence[int] | None = None, - dtype=None, - byte_limiter: _LimitInFlightBytes | None = None, - context=TS_CONTEXT, - assume_metadata: bool = False, -): - in_sharding = (user_in_sharding.sharding - if isinstance(user_in_sharding, Layout) else user_in_sharding) - if not isinstance(in_sharding, jax.sharding.Sharding): - raise ValueError( - 'sharding passed to deserialization should be specified, concrete and' - f' an instance of `jax.sharding.Sharding`. Got {in_sharding}') - dll = (user_in_sharding.device_local_layout - if isinstance(user_in_sharding, Layout) else None) - t = await ts.open( - tensorstore_spec, - open=True, - assume_metadata=assume_metadata, - context=context, - ) - shape = t.shape if global_shape is None else global_shape - new_shard_shape = in_sharding.shard_shape(tuple(shape)) - - async def cb(index: array.Index, device: jax.Device): - requested_domain = ts.IndexTransform(input_shape=shape)[index].domain - restricted_domain = t.domain.intersect(requested_domain) - requested_bytes = estimate_read_memory_footprint(t, restricted_domain) - # Limit the bytes read for every shard. - if byte_limiter is not None: - await byte_limiter.wait_for_bytes(requested_bytes) - # This maybe needed because the shape the array was saved with is smaller - # than the requested shape of the array in which it will be reloaded. So - # the extra values will be filled with 0s. - out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) - await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][ - restricted_domain].write(t[restricted_domain]) - if dtype is not None: - # Cast while reloading on process to avoid 2 copies on device if the - # casting is done on device. - out = out.astype(dtype) - # Convert to jnp array so that layouts are initialized properly for - # sub-byte dtypes. - # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to - # make this work. - if out.dtype == jnp.int4: - out = jnp.asarray(out) # type: ignore - result = jax.device_put( - out, Layout(dll, jax.sharding.SingleDeviceSharding(device))) - if byte_limiter is not None: - # NB: `out` actually might not be ready for garbage collection by the - # time we call release_bytes . Thus peak memory usage still might grow - # beyond what byte_limiter limit suggests it should. The simplest option - # would be to call `result.block_until_ready()`` here. However it - # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU - # transfer instead of loading data. In the future, if memory pressure - # becomes a problem, we can instead instrument bytelimiter to - # keep track of all in-flight tensors and only block_until_ready, if byte - # limiter hits the limit to get reduced memory usage, without losing - # performance in common use cases. 
- await byte_limiter.release_bytes(requested_bytes) - return result - - return await create_async_array_from_callback(tuple(shape), in_sharding, cb) - - -def run_deserialization(shardings: Sequence[sharding.Sharding | Layout], - tensorstore_specs: Sequence[dict[str, Any]], - global_shapes: Sequence[array.Shape] | None = None, - dtypes: Sequence[typing.DTypeLike] | None = None, - concurrent_gb: int = 32): - concurrent_bytes = concurrent_gb * 10**9 - - async def _run_deserializer(): - # Object should be created once per process. - byte_limiter = _LimitInFlightBytes(concurrent_bytes) - future_arrays = jax.tree_util.tree_map( - partial(async_deserialize, byte_limiter=byte_limiter), - list(shardings), list(tensorstore_specs), - [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, - [None] * len(tensorstore_specs) if dtypes is None else dtypes) - return await asyncio.gather(*future_arrays) - return asyncio.run(_run_deserializer()) - - def _get_key(key: int): return f'tensorstore_checkpoint_{key}' @@ -510,8 +185,7 @@ def __init__(self, timeout_secs=300): if jax.process_count() > 1 and distributed.global_state.client is None: raise ValueError(_DISTRIBUTED_SYSTEM_MSG) - if jax.process_count() > 1: - self._client = distributed.global_state.client + self._client = distributed.global_state.client self._count = None def __del__(self): @@ -533,7 +207,9 @@ def _thread_func(self): logger.info('Finished committing to storage layer by process: %s', current_process) + key_for_barrier = None if process_count > 1: + assert self._client is not None # All processes will wait at the barrier. When all processes are at the # barrier, the barrier will be satisfied. If not, then it will timeout. key_for_barrier = _get_key(self._count) @@ -544,9 +220,11 @@ def _thread_func(self): current_process) if current_process == 0: - self._on_commit_callback() - logger.info('on_commit_callback successfully ran!') + if self._on_commit_callback is not None: + self._on_commit_callback() + logger.info('on_commit_callback successfully ran!') if process_count > 1: + assert self._client is not None self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS) logger.info('Process 0 successfully set key %s in the kv store', key_for_barrier) @@ -555,7 +233,7 @@ def _thread_func(self): '/jax/checkpoint/write/async/thread_duration_sec', time.time() - thread_start_time) - except Exception as e: + except Exception as e: # pylint: disable=broad-except self._exception = e def _start_async_commit(self, on_commit_callback): @@ -572,7 +250,7 @@ def check_for_errors(self): self._exception = None if (isinstance(exception, xe.XlaRuntimeError) and 'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)): - raise BarrierTimeoutException( + raise BarrierTimeoutError( '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG])) raise exception # pylint: disable=raising-bad-type @@ -586,6 +264,7 @@ def wait_until_finished(self): logger.info('Error check finished successfully') if jax.process_count() > 1 and self._count is not None: + assert self._client is not None # Block until process 0 writes success value to the key value store. # If it fails to write it, then `blocking_key_value_get` will time out. 
get_key = _get_key(self._count) @@ -605,8 +284,8 @@ def serialize( arrays, tensorstore_specs, *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None] | None = None, + transaction: ts_impl.Transaction | None = None, ): """Serializes Arrays or Arrays via TensorStore asynchronously. @@ -635,11 +314,11 @@ def serialize( logger.info('Waiting for previous serialization to finish.') self.wait_until_finished() - commit_futures: list[ts.Future] = [] + commit_futures: list[ts_impl.Future] = [] async def _run_serializer(): future_writer = jax.tree_util.tree_map( - lambda arr_inp, tensorstore_spec: async_serialize( + lambda arr_inp, tensorstore_spec: ts_impl.async_serialize( arr_inp, tensorstore_spec, commit_future=commit_futures, @@ -649,7 +328,6 @@ async def _run_serializer(): tensorstore_specs, ) return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) self._add_futures(commit_futures) @@ -663,11 +341,11 @@ def serialize_with_paths( arrays: Sequence[jax.Array], paths: Sequence[str], *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None] | None = None, + transaction: ts_impl.Transaction | None = None, ): tspecs = jax.tree.map(get_tensorstore_spec, paths) - self.serialize( + return self.serialize( arrays, tspecs, on_commit_callback=on_commit_callback, @@ -680,8 +358,8 @@ def deserialize(self, shardings: Sequence[sharding.Sharding | Layout], dtypes: Sequence[typing.DTypeLike] | None = None, concurrent_gb: int = 32): self.wait_until_finished() - return run_deserialization(shardings, tensorstore_specs, - global_shapes, dtypes, concurrent_gb) + return ts_impl._run_deserialization( + shardings, tensorstore_specs, global_shapes, dtypes, concurrent_gb) def deserialize_with_paths( self, shardings: Sequence[sharding.Sharding], diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 280c2f58b348..9a6b91d04c9a 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# pylint: disable=g-importing-member import asyncio -import math from functools import partial +import math import os import pathlib import tracemalloc as tm @@ -22,21 +23,31 @@ from absl.testing import absltest from absl.testing import parameterized import jax -import jax.numpy as jnp +from jax._src import array from jax._src import config from jax._src import test_util as jtu -from jax._src import array -from jax._src.sharding_impls import NamedSharding, GSPMDSharding, SingleDeviceSharding -from jax.sharding import PartitionSpec as P +from jax._src.layout import DeviceLocalLayout as DLL +from jax._src.layout import Layout from jax.experimental.array_serialization import serialization -from jax.experimental.layout import Layout, DeviceLocalLayout as DLL +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +import jax.numpy as jnp + +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +from jax.sharding import SingleDeviceSharding import numpy as np import tensorstore as ts +# pylint: enable=g-importing-member jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +def _get_replicated_sharding(devices): + return NamedSharding( + jax.make_mesh(np.shape(devices), P('x'), devices=devices), P()) + + class CheckpointTest(jtu.JaxTestCase): def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir): @@ -93,13 +104,14 @@ def test_memory_consumption(self): manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [inp], [tspec], - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() async def deserialize_with_byte_limit(): r = await serialization.async_deserialize( - sharding, tspec, inp_shape, - byte_limiter=serialization._LimitInFlightBytes(4_200_000)) + sharding, tspec, inp_shape, + byte_limiter=serialization._LimitInFlightBytes(4_200_000)) r.block_until_ready() tm.start() @@ -133,24 +145,22 @@ def test_memory_consumption_for_save(self): inp_shape, sharding, lambda idx: src[idx] ) ckpt_dir = pathlib.Path(self.create_tempdir('memprofsave').full_path) - tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) + tspec = ts_impl.get_tensorstore_spec(str(ckpt_dir), ocdbt=False, + driver='zarr3') tspec['metadata'] = { 'shape': inp.shape, - 'compressor': None, - 'chunks': inp.shape, + 'data_type': jnp.dtype(inp.dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': np.array(np.maximum(1, inp.shape))} + } } - is_cpu = jtu.test_device_matches(['cpu']) tm.start() try: manager = serialization.GlobalAsyncCheckpointManager() - manager.serialize( - [inp], - [tspec], - on_commit_callback=partial( - self._on_commit_callback, ckpt_dir, ckpt_dir - ), - ) + manager.serialize([inp], [tspec], on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() unused_current, peak = tm.get_traced_memory() self.assertLess(peak, src.nbytes * (1 * (not is_cpu) + 0.5)) @@ -176,7 +186,8 @@ def test_checkpointing_with_path_variant(self): manager = serialization.GlobalAsyncCheckpointManager() manager.serialize_with_paths( [a1], ckpt_paths, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() m1, = manager.deserialize_with_paths( @@ -201,7 +212,8 @@ def test_checkpointing_jax_array(self): inp_shape, 
NamedSharding(global_mesh, pspec), lambda idx: global_input_data1[idx]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) - ckpt_path1 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) + ckpt_path1 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/first').full_path) # Second Array global_input_data2 = np.arange( @@ -209,7 +221,8 @@ def test_checkpointing_jax_array(self): a2 = array.make_array_from_callback( inp_shape, NamedSharding(global_mesh, pspec), lambda idx: global_input_data2[idx]) - ckpt_path2 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/second').full_path) + ckpt_path2 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/second').full_path) # Third Array def cb3(_): @@ -217,15 +230,17 @@ def cb3(_): global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3) - ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path) + ckpt_path3 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/third').full_path) ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [a1, a2, a3], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() m1, m2, m3 = serialization.run_deserialization( @@ -295,9 +310,8 @@ def cb3(_): ckpt_path3 = ckpt_dir / 'third' ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] - tspecs = jax.tree_util.tree_map( - lambda p: serialization.get_tensorstore_spec(p, ocdbt=True), ckpt_paths - ) + tspecs = jax.tree.map(partial(ts_impl.get_tensorstore_spec, ocdbt=True), + ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() with ts.Transaction(atomic=True) as transaction: @@ -312,13 +326,8 @@ def cb3(_): manager.wait_until_finished() m1, m2, m3 = serialization.run_deserialization( - [ - NamedSharding(global_mesh, pspec), - NamedSharding(global_mesh, P('x')), - NamedSharding(global_mesh1d, P(None)), - ], - tspecs, - ) + [NamedSharding(global_mesh, pspec), NamedSharding(global_mesh, P('x')), + NamedSharding(global_mesh1d, P(None))], tspecs) self.assertIsInstance(m1, array.ArrayImpl) self.assertArraysEqual( @@ -367,12 +376,13 @@ def cb1(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), @@ -395,15 +405,16 @@ def cb1(index): for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) - new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32]) + new_ds = _get_replicated_sharding(list(global_mesh.devices.flat)) + m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], + [np.float32]) 
for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data1.astype('float32')) @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) def test_checkpointing_with_int4(self, input_dtype): if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") + self.skipTest('TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT') global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -418,12 +429,13 @@ def cb(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), @@ -448,8 +460,9 @@ def cb(index): for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) - new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype]) + new_ds = _get_replicated_sharding(list(global_mesh.devices.flat)) + m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], + [target_dtype]) for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data.astype(target_dtype)) @@ -463,22 +476,17 @@ def test_checkpointing_scalar_jax_array(self): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) - + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [array1], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((2,), ('x')), P(None)) - m1, = serialization.run_deserialization( - [ds], - tspecs, - [()], - [np.float32] - ) + m1, = serialization.run_deserialization([ds], tspecs, [()], [np.float32]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32)) @@ -488,9 +496,7 @@ def test_deserialize_tensorstore_array_jax_array(self): data = np.arange(1024) tspec = ts.array(data).spec() m1, = serialization.run_deserialization( - [NamedSharding(global_mesh, P(None))], - [tspec] - ) + [NamedSharding(global_mesh, P(None))], [tspec]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data) @@ -507,9 +513,9 @@ def test_spec_has_metadata(self): }, 'f': 4 } - self.assertTrue(serialization._spec_has_metadata(spec)) + self.assertTrue(ts_impl._spec_has_metadata(spec)) self.assertTrue( - serialization._spec_has_metadata({ + ts_impl._spec_has_metadata({ 'driver': 'zarr', 'kvstore': 'gfile', 'metadata': { @@ -531,39 +537,40 @@ def test_spec_has_no_metadata(self): }, 'f': 4 } - self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) def test_empty_spec_has_no_metadata(self): spec = {} - 
self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) @parameterized.named_parameters( ('gcs', 'gs://my/ckpt/dir/path'), ('file', '/my/ckpt/dir/path') ) def test_get_tensorstore_spec_ocdbt(self, path): - spec = serialization.get_tensorstore_spec(path, ocdbt=True) + spec = ts_impl.get_tensorstore_spec(path, ocdbt=True) is_gcs_path = path.startswith('gs://') + # for OCDBT the last part of the path is the key in the kvstore + expected_path = os.path.split(path)[0] if is_gcs_path: - self.assertEqual(spec['kvstore']['base'], os.path.dirname(path)) + self.assertEqual(spec['kvstore']['base']['driver'], 'gcs') + self.assertTrue(expected_path.endswith(spec['kvstore']['base']['path'])) else: - self.assertEqual(spec['kvstore']['base'], - f'{serialization._DEFAULT_DRIVER}://{os.path.dirname(path)}') - self.assertEqual(spec['kvstore']['path'], 'path') + self.assertEqual(spec['kvstore']['base']['path'], expected_path) def test_get_tensorstore_spec_not_absolute_path(self): path = 'my/ckpt/path' with self.assertRaisesRegex(ValueError, - "Checkpoint path should be absolute"): - serialization.get_tensorstore_spec(path, ocdbt=True) + 'Checkpoint path should be absolute'): + ts_impl.get_tensorstore_spec(path, ocdbt=True) def test_maybe_cloud_storage(self): - gs_path = 'gs://some-buck/path' - gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True) + gs_path = 'gs://some-buck/path/array_name' + gs_spec = ts_impl.get_tensorstore_spec(gs_path, ocdbt=True) self.assertTrue(serialization.is_remote_storage(gs_spec)) - local_path = '/tmp/checkpoint' - local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True) + local_path = '/tmp/checkpoint/array_name' + local_spec = ts_impl.get_tensorstore_spec(local_path, ocdbt=True) self.assertFalse(serialization.is_remote_storage(local_spec)) nested_tspec = { @@ -571,7 +578,8 @@ def test_maybe_cloud_storage(self): 'dtype': 'int32', 'base': { 'driver': 'zarr', - 'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'}, + 'kvstore': {'driver': 'ocdbt', + 'base': 's3://some-bucket/path/array_name'}, }, } self.assertTrue(serialization.is_remote_storage(nested_tspec)) @@ -592,12 +600,13 @@ def test_load_with_layout(self): ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path]) + tspecs = jax.tree.map(ts_impl.get_tensorstore_spec, [ckpt_path]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() out, = serialization.run_deserialization([out_layout], tspecs) @@ -610,7 +619,7 @@ def test_load_with_layout(self): def test_deserialization_with_int4(self): if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") + self.skipTest('TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT') if jtu.test_device_matches(['gpu']): self.skipTest("Fails on GPU. Enable after it's fixed") dtype = jnp.int4 @@ -620,10 +629,8 @@ def test_deserialization_with_int4(self): ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path) # Run serialization. 
- sharding = GSPMDSharding.get_replicated(jax.devices()) - tspecs = jax.tree_util.tree_map( - serialization.get_tensorstore_spec, [ckpt_dir] - ) + sharding = _get_replicated_sharding(list(jax.devices())) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, [ckpt_dir]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], @@ -634,11 +641,8 @@ def test_deserialization_with_int4(self): # Run deserialization. deserialized_arr, = serialization.run_deserialization( - shardings=[sharding], - tensorstore_specs=tspecs, - global_shapes=[shape], - dtypes=[dtype], - ) + shardings=[sharding], tensorstore_specs=tspecs, global_shapes=[shape], + dtypes=[dtype]) out = deserialized_arr.astype(jnp.int8) # doesn't crash self.assertEqual(out.dtype, jnp.int8) @@ -650,13 +654,14 @@ class TransferShardTest(jtu.JaxTestCase): @jtu.skip_on_devices('cpu') def test_transfer_shard_to_host(self): np_inp = np.arange(16).reshape((4, 4)) - sharding = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + sharding = SingleDeviceSharding(jax.devices()[0], memory_kind='device') arr = jax.device_put(np_inp, sharding) shard = arr.addressable_shards[0] - np_out = asyncio.run(serialization.transfer_shard_to_host(shard)) + np_out = asyncio.run(ts_impl._transfer_shard_to_host(shard)) self.assertArraysEqual(np_out, np_inp) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py new file mode 100644 index 000000000000..873cc82da95e --- /dev/null +++ b/jax/experimental/array_serialization/tensorstore_impl.py @@ -0,0 +1,518 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from functools import partial +import functools +import os +from os import PathLike +import re +from typing import Any, Awaitable, Callable, Sequence +import math +import logging + +import jax +from jax import numpy as jnp +from jax._src import array +from jax._src.layout import Layout +from jax._src import typing +import numpy as np +import tensorstore as ts + +_TS_ARRAY_DRIVER = "zarr3" + +_TS_CONTEXT = ts.Context({ + 'file_io_concurrency': {'limit': 128}, + 'cache_pool': {'total_bytes_limit': 10_000_000_000}, # 10 GB RAM limit + 'cache_pool#remote': {'total_bytes_limit': 10_000_000_000}, + 'data_copy_concurrency': {'limit': 128} +}) +_TS_CHUNK_LAYOUT = ts.ChunkLayout({ + "chunk": {"elements": 100_000_000}, # 100M (800MB for float64) file size +}) + +_DEFAULT_BASE_DRIVER = 'file' +_PROCESS_DIR_FORMAT = "process_{}" +_FILE_SIZE_TARGET = 2 * 1024 ** 3 # 2 GB + +Future, Transaction = ts.Future, ts.Transaction + +logger = logging.getLogger(__name__) + +# Lifted from T5X. 
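+# A rough usage sketch for the limiter below (illustrative only; the budget
+# value and the read step are hypothetical, not part of this module):
+#
+#   limiter = _LimitInFlightBytes(2 * 1024**3)   # e.g. a 2 GB scratch budget
+#   await limiter.wait_for_bytes(nbytes)         # blocks until budget allows
+#   try:
+#     ...                                        # perform the read or write
+#   finally:
+#     await limiter.release_bytes(nbytes)        # return budget, wake waiters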
+class _LimitInFlightBytes: + """Limits host scratch memory usage when reading/writing checkpoints per process.""" + + def __init__(self, host_memory_bytes_limit: int): + self._max_bytes = host_memory_bytes_limit + self._available_bytes = host_memory_bytes_limit + self._cv = asyncio.Condition(lock=asyncio.Lock()) + + async def wait_for_bytes(self, requested_bytes): + if requested_bytes > self._max_bytes: + logger.debug("A single array item requests more bytes than we reserved" + " space for in the parallel pool: %d > %d. Increasing the" + " limit to %d.", requested_bytes, self._max_bytes, + requested_bytes) + self._max_bytes = requested_bytes + async with self._cv: + await self._cv.wait_for(lambda: self._available_bytes >= requested_bytes) + self._available_bytes -= requested_bytes + assert self._available_bytes >= 0 + + async def release_bytes(self, requested_bytes): + async with self._cv: + self._available_bytes += requested_bytes + assert self._available_bytes <= self._max_bytes + self._cv.notify_all() + +def _prime_factors(x: int) -> list[int]: + # find prime factors of axis sizes to help efficiently find divisor chunks + factors = [] + while x % 2 == 0: + factors.append(2) + x //= 2 + for i in range(3, int(math.sqrt(x)) + 1, 2): + while x % i == 0: + factors.append(i) + x //= i + if x > 1: + factors.append(x) + return sorted(factors) + +@functools.lru_cache(maxsize=1024) +def _compute_chunk_shape( + local_shape: Sequence[int], dtype: str | jnp.dtype, + file_size_target: int = _FILE_SIZE_TARGET) -> list[int]: + """Compute a chunk such that it divides the local shape and is less than + target file size. This helps the tensorstore kvstore driver limit the largest + file size on disk to below the ``file_size_target``. We compute a chunk with a + byte size at most 110% of the ``file_size_target``. + """ + local_shape = list(local_shape) + if len(local_shape) == 0 or math.prod(local_shape) == 0: + # a zero size array needs a non-zero chunk passed to tensorstore for compat. + return [max(z, 1) for z in local_shape] + total_size = math.prod(local_shape) * jnp.dtype(dtype).itemsize + axis_prime_factors = [_prime_factors(z) for z in local_shape] + chunk_shape, chunk_size = list(local_shape), total_size + # while chunk_size exceeds target size, reduce chunk_shape + while chunk_size > 1.1 * file_size_target: # 10% buffer + # 1. find the smallest axis divisor across all axes + chosen_axis_idx, chosen_divisor = None, 1 + for axis_idx in range(len(chunk_shape)): + if len(axis_prime_factors[axis_idx]) == 1: # ignore axes sizes == 1 + continue + if (chosen_axis_idx is None + or chosen_divisor > axis_prime_factors[axis_idx][0]): + chosen_axis_idx = axis_idx + chosen_divisor = axis_prime_factors[axis_idx][0] + # 2. if no divisor found, give up, return current chunk shape + if chosen_axis_idx is None: + return chunk_shape + # 3. remove the applied divisor from prime factors + prime_factors = axis_prime_factors[chosen_axis_idx] + prime_factors.pop(0) + # 4. 
apply the found divisor to reduce the chunk size + chunk_shape[chosen_axis_idx] //= chosen_divisor + chunk_size //= chosen_divisor + return chunk_shape + +def _get_tensorstore_metadata(arr, is_remote: bool = False, + file_size_target: int = _FILE_SIZE_TARGET, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + global_shape, dtype = arr.shape, arr.dtype + if hasattr(arr, 'addressable_data'): # jax.Array + local_shape = arr.addressable_data(0).shape + else: # np.ndarray + local_shape = global_shape + return _get_tensorstore_metadata_cached(global_shape, dtype, local_shape, + is_remote, file_size_target, driver) + +@functools.lru_cache(maxsize=1024) +def _get_tensorstore_metadata_cached( + global_shape: Sequence[int], dtype: jnp.dtype, local_shape: Sequence[int], + is_remote: bool = False, file_size_target: int = _FILE_SIZE_TARGET, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + if driver == "zarr3": + codecs = ([{"name": "zstd"}] if is_remote else []) + return { + 'codecs': codecs, + 'shape': global_shape, + 'data_type': jnp.dtype(dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': _compute_chunk_shape( + local_shape, dtype, file_size_target=file_size_target)} + } + } + elif driver == "zarr": # in zarr dtype goes in the base spec + return {'compressor': {'id': 'zstd'}, 'shape': global_shape, + 'chunks': np.array(np.maximum(1, local_shape)).tolist()} + else: + raise ValueError(f"Unsupported driver: {driver}") + +def _spec_has_metadata(tree): + if not isinstance(tree, dict): + return False + return 'metadata' in tree or any( + _spec_has_metadata(subtree) for _, subtree in tree.items()) + +def _get_kvstore_for_gcs(ckpt_path: str): + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 'gcs', 'bucket': bucket, 'path': path_without_bucket} + +def _get_kvstore_for_s3(ckpt_path: str): + m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 's3', 'bucket': bucket, 'path': path_without_bucket} + +def get_tensorstore_spec( + ckpt_path: str | PathLike[str], ocdbt: bool = True, + process_num: int | None = None, arr: jax.Array | None = None, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + + # Normalize path to exclude trailing '/'. 
In GCS path case, normpath will
+  # replace the double '//' with a single '/' and we need to restore the
+  # filesystem type:// prefix for GCS (gs://) and S3 paths (s3://)
+  ckpt_path = os.path.normpath(str(ckpt_path))
+  ckpt_path = re.sub(r"^([a-z]+):/", r"\1://", ckpt_path)
+
+  # in cases of multi-process writes, we need to write to a different location
+  # for each process and finally create a combined symlink to the final
+  # location; tensorstore can do this via ts.KvStore.experimental_copy_range_to
+  if process_num is not None:
+    _parent, _name = os.path.split(ckpt_path)
+    ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_num),
+                             _name)
+
+  is_gcs_path = ckpt_path.startswith('gs://')
+  is_s3_path = ckpt_path.startswith('s3://')
+  spec = {'driver': driver, 'kvstore': {}}
+
+  # use a combined OCDBT store; the actual path is the parent path, and
+  # the name (filename/last part of the path) is the key in the ocdbt kvstore
+  entry_key = None
+  if ocdbt:
+    (ckpt_path, entry_key), org_ckpt_path = os.path.split(ckpt_path), ckpt_path
+    if is_gcs_path:
+      m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path)
+    elif is_s3_path:
+      m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path)
+    else:
+      m = re.match("a", "a")  # make it True
+    if m is None:
+      raise ValueError('Using OCDBT requires the bucket name, the directory'
+                       ' name and the array name; your path is: '
+                       f'{org_ckpt_path}')
+
+  if is_gcs_path:
+    base_kvstore = _get_kvstore_for_gcs(ckpt_path)
+  elif is_s3_path:
+    base_kvstore = _get_kvstore_for_s3(ckpt_path)
+  else:
+    base_kvstore = {'driver': _DEFAULT_BASE_DRIVER, 'path': ckpt_path}
+
+  if ocdbt:
+    if not is_gcs_path and not is_s3_path and not os.path.isabs(ckpt_path):
+      raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
+    spec['kvstore'] = {'driver': 'ocdbt', 'base': base_kvstore,
+                       'path': entry_key}
+  else:
+    spec['kvstore'] = base_kvstore
+  # done writing tensorstore spec based on destination path
+  # optionally, if array is provided, we can add metadata to the spec
+  if arr is not None:
+    spec["metadata"] = _get_tensorstore_metadata(
+        arr, driver=str(spec["driver"]))
+  return spec
+
+async def _create_async_array_from_callback(
+    global_shape: array.Shape,
+    inp_sharding: jax.sharding.Sharding,
+    data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
+):
+  device_to_index_map = inp_sharding.devices_indices_map(global_shape)
+  addressable_da = inp_sharding._addressable_device_assignment
+  future_arrays = [data_callback(device_to_index_map[d], d)
+                   for d in addressable_da]
+  dbs = await asyncio.gather(*future_arrays)
+  return array.make_array_from_single_device_arrays(
+      global_shape, inp_sharding, dbs)
+
+async def _transfer_shard_to_host(shard: array.Shard) -> np.ndarray:
+  data = shard.data
+  has_pinned_host = any(
+      m.kind == "pinned_host" for m in shard.device.addressable_memories())
+  if has_pinned_host:
+    # If available, transfer to pinned host memory
+    sharding = jax.sharding.SingleDeviceSharding(shard.device,
+                                                 memory_kind="pinned_host")
+    data = jax.device_put(data, sharding)
+  else:
+    data.copy_to_host_async()
+  # Allow other transfers to be scheduled simultaneously
+  await asyncio.sleep(0)
+  # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore
+  # implicitly converts the written data to a numpy array, and would otherwise
+  # silently copy host-to-host.
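+  # (Under NumPy 2.x semantics, copy=False raises if a zero-copy view is
+  # impossible, so a hidden host-to-host copy here would surface as an error
+  # rather than a silent slowdown.)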
+  return np.array(data, copy=False)
+
+async def async_serialize(
+    arr_inp,
+    tensorstore_spec,
+    commit_future=None,
+    context=_TS_CONTEXT,
+    chunk_layout=_TS_CHUNK_LAYOUT,
+    primary_host: int | None = None,
+    replica_id: int = 0,
+    transaction: ts.Transaction | None = None,
+):
+  """Serialize an array using TensorStore.
+
+  Args:
+    arr_inp: The array to serialize.
+    tensorstore_spec: The tensorstore spec to use.
+    commit_future: A list of futures that will be appended to. The futures can
+      be awaited asynchronously. If None, the futures will be awaited
+      synchronously by this method.
+    context: ts.Context instance.
+    primary_host: Primary host, which indicates the host that will be treated as
+      the "leader". If None, all hosts are treated as the primary. DO NOT USE
+      unless you are sure you know what you are doing.
+    replica_id: Allows overriding the shard replica id that will be saved. DO
+      NOT USE unless you are sure you know what you are doing.
+    transaction: TensorStore transaction to use for opening and writing the
+      array. If not specified, a non-transactional write will be used.
+  """
+  if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
+      arr_inp.is_fully_addressable):
+    raise ValueError(
+        f'Passing fully addressable arrays to a multiprocess '
+        f'serialization is not allowed, as this may lead to a race condition '
+        f'between processes. Serialization has failed for the array with '
+        f'the path from kvstore: "{tensorstore_spec["kvstore"]}".')
+
+  # 'metadata' may not be present at the top level (for example, if we are using
+  # a 'cast' driver).
+  if not _spec_has_metadata(tensorstore_spec):
+    tensorstore_spec['metadata'] = _get_tensorstore_metadata(
+        arr_inp, driver=tensorstore_spec['driver'])
+  ## zarr driver requires specifying the dtype in the spec base
+  if tensorstore_spec['driver'] == 'zarr' and 'dtype' not in tensorstore_spec:
+    tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name
+
+  # If primary_host is None, all hosts will checkpoint. This is used
+  # for checkpointing to local filesystem.
+  if primary_host is None or jax.process_index() == primary_host:
+    open_future = ts.open(
+        ts.Spec(tensorstore_spec),
+        create=True,
+        open=True,
+        context=context,
+        chunk_layout=chunk_layout,
+        transaction=transaction,
+    )
+    # Asynchronous case.
+    if commit_future is not None:
+      assert isinstance(commit_future, list)
+      commit_future.append(open_future)
+    else:
+      await open_future
+
+  # `ts.open` runs twice for process `primary_host` because for the first time,
+  # we just get the future to be awaited upon in the background thread. The
+  # second one runs with `assume_metadata=True` which does no I/O operation and
+  # returns the tensorstore object.
+  # For every process other than `primary_host`, we open with
+  # `assume_metadata=True`.
+  t = await ts.open(
+      ts.Spec(tensorstore_spec),
+      open=True,
+      assume_metadata=True,
+      context=context,
+      chunk_layout=chunk_layout,
+      transaction=transaction,
+  )
+
+  async def _write_array(shard):
+    if shard.replica_id == replica_id:
+      data = await _transfer_shard_to_host(shard)
+      write_future = t[shard.index].write(
+          data,
+          # Avoid additional copy of input array into the TensorStore chunk
+          # cache. If `arr_inp` is a jax.Array, the result of converting
+          # it to a NumPy array, as is done internally by TensorStore, is
+          # guaranteed to be immutable and therefore it is safe to retain a
+          # reference indefinitely.
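+          # (This is also why the flag below is gated on isinstance: a plain
+          # np.ndarray input may be mutated by the caller, so TensorStore
+          # must not retain a reference to it.)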
+          can_reference_source_data_indefinitely=isinstance(
+              arr_inp, array.ArrayImpl
+          ),
+      )
+      if commit_future is not None:
+        assert isinstance(commit_future, list)
+        commit_future.append(write_future.commit)
+        await write_future.copy
+      else:
+        await write_future.commit
+
+  local_shards = arr_inp.addressable_shards
+  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
+  return await asyncio.gather(*future_write_state)
+
+
+# TODO(rdyro): Remove this function.
+def _run_serialization(arrays, tensorstore_specs):
+  """Legacy serialization of a list of arrays."""
+  async def _run_serializer():
+    future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs)
+    return await asyncio.gather(*future_writer)
+  asyncio.run(_run_serializer())
+
+
+def estimate_read_memory_footprint(t: ts.TensorStore,
+                                   domain: ts.IndexDomain) -> int:
+  rank = t.rank
+  num_bytes = t.dtype.numpy_dtype.itemsize
+  chunk_template = t.chunk_layout.read_chunk_template
+  if domain is None:
+    domain = t.domain
+  origin = domain.origin
+  shape = domain.shape
+  chunk_origin = chunk_template.origin
+  chunk_shape = chunk_template.shape
+
+  # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
+  # For those, instead of returning a near-infinite memory footprint, estimate
+  # the footprint as the entire shape.
+  for i in range(rank):
+    if not chunk_template[i].finite:
+      return domain.size * num_bytes
+
+  # Otherwise, if we have a chunked driver, estimate based on chunk size.
+  for i in range(rank):
+    origin_value = origin[i]
+    chunk_origin_value = chunk_origin[i]
+    chunk_size = chunk_shape[i]
+    lower = origin_value - chunk_origin_value
+    upper = origin_value + shape[i] - chunk_origin_value
+    lower_aligned = lower // chunk_size * chunk_size
+    upper_aligned = -(-upper // chunk_size) * chunk_size
+    num_bytes *= (upper_aligned - lower_aligned)
+
+  return num_bytes
+
+
+async def async_deserialize(
+    user_in_sharding: jax.sharding.Sharding | Layout,
+    tensorstore_spec: ts.Spec | dict[str, Any],
+    global_shape: Sequence[int] | None = None,
+    dtype=None,
+    byte_limiter: _LimitInFlightBytes | None = None,
+    context=_TS_CONTEXT,
+    chunk_layout=_TS_CHUNK_LAYOUT,
+    assume_metadata: bool = False,
+):
+  """Main performant deserialization routine for arrays using tensorstore."""
+  in_sharding = (user_in_sharding.sharding
+                 if isinstance(user_in_sharding, Layout) else user_in_sharding)
+  if not isinstance(in_sharding, jax.sharding.Sharding):
+    raise ValueError(
+        'sharding passed to deserialization should be specified, concrete and'
+        f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
+  dll = (user_in_sharding.device_local_layout
+         if isinstance(user_in_sharding, Layout) else None)
+  t = await ts.open(
+      tensorstore_spec,
+      open=True,
+      assume_metadata=assume_metadata,
+      context=context,
+      chunk_layout=chunk_layout,
+  )
+  shape = t.shape if global_shape is None else global_shape
+  new_shard_shape = in_sharding.shard_shape(tuple(shape))
+
+  async def cb(index: array.Index, device: jax.Device):
+    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
+    restricted_domain = t.domain.intersect(requested_domain)
+    requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
+    # Limit the bytes read for every shard.
+    if byte_limiter is not None:
+      await byte_limiter.wait_for_bytes(requested_bytes)
+    # This may be needed because the shape the array was saved with is smaller
+    # than the requested shape of the array in which it will be reloaded.
So
+    # the extra values will be filled with 0s.
+    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
+    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
+        restricted_domain].write(t[restricted_domain])
+    if dtype is not None:
+      # Cast while reloading on process to avoid 2 copies on device if the
+      # casting is done on device.
+      out = out.astype(dtype)
+    # Convert to jnp array so that layouts are initialized properly for
+    # sub-byte dtypes.
+    # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to
+    # make this work.
+    if out.dtype == jnp.int4:
+      out = jnp.asarray(out)  # type: ignore
+    result = jax.device_put(
+        out, Layout(dll, jax.sharding.SingleDeviceSharding(device)))
+    if byte_limiter is not None:
+      # NB: `out` actually might not be ready for garbage collection by the
+      # time we call release_bytes. Thus peak memory usage still might grow
+      # beyond what the byte_limiter limit suggests it should. The simplest
+      # option would be to call `result.block_until_ready()` here. However, it
+      # also comes with a ~15-20% perf penalty as we would be waiting for the
+      # CPU->GPU transfer instead of loading data. In the future, if memory
+      # pressure becomes a problem, we can instead instrument the byte_limiter
+      # to keep track of all in-flight tensors and only call block_until_ready
+      # if the byte limiter hits the limit, to get reduced memory usage without
+      # losing performance in common use cases.
+      await byte_limiter.release_bytes(requested_bytes)
+    return result
+
+  return await _create_async_array_from_callback(tuple(shape), in_sharding, cb)
+
+
+# TODO(rdyro): Remove this function.
+def _run_deserialization(shardings: Sequence[jax.sharding.Sharding | Layout],
+                         tensorstore_specs: Sequence[dict[str, Any]],
+                         global_shapes: Sequence[array.Shape] | None = None,
+                         dtypes: Sequence[typing.DTypeLike] | None = None,
+                         concurrent_gb: int = 32):
+  """Legacy deserialization of a list of arrays. Optionally pass global_shapes
+  and dtypes for type-checking.
+  """
+  concurrent_bytes = concurrent_gb * 10**9
+
+  async def _run_deserializer():
+    # Object should be created once per process.
+    byte_limiter = _LimitInFlightBytes(concurrent_bytes)
+
+    future_arrays = jax.tree_util.tree_map(
+        partial(async_deserialize, byte_limiter=byte_limiter),
+        list(shardings), list(tensorstore_specs),
+        [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
+        [None] * len(tensorstore_specs) if dtypes is None else dtypes)
+    return await asyncio.gather(*future_arrays)
+  return asyncio.run(_run_deserializer())
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index d6030a7ce03f..718401c1477f 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -108,6 +108,7 @@ _py_deps = {
     "numpy": ["@pypi_numpy//:pkg"],
     "scipy": ["@pypi_scipy//:pkg"],
     "tensorflow_core": [],
+    "tensorstore": get_optional_dep("@pypi_tensorstore//:pkg"),
     "torch": [],
     "zstandard": get_zstandard(),
 }
diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi
index 7bfb2b1f675b..491ba4a2e50a 100644
--- a/jaxlib/xla/xla_extension/__init__.pyi
+++ b/jaxlib/xla/xla_extension/__init__.pyi
@@ -852,9 +852,8 @@ class DistributedRuntimeClient:
   def key_value_set_bytes(self, key: str, value: bytes,
                           allow_overwrite: bool = False) -> _Status: ...
   def key_value_delete(self, key: str) -> _Status: ...
-  def wait_at_barrier(
-      self, barrier_id: str, timeout_in_ms: int, process_ids: Optional[List[int]]
-  ) -> _Status: ...
+  def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int,
+                      process_ids: Optional[List[int]] = None) -> _Status: ...
   def get_live_nodes(self, process_ids: List[int]) -> _Status: ...

 def get_distributed_runtime_service(

From 8318157f8d5cc523edb0966e0c85adc3ecdbc571 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 18 Apr 2025 18:23:10 -0700
Subject: [PATCH 0701/1769] Update changelog to add information about breaking
 change with respect to tracing cache after sharding_in_types config was
 turned on, which led to `sharding` always being available on `ShapedArray`

PiperOrigin-RevId: 749206500
---
 CHANGELOG.md | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f5494cc9d7a7..efb515b73283 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -120,6 +120,30 @@ Patch release of 0.5.1

 ## jax 0.5.1 (Feb 24, 2025)

+* Breaking changes
+  * The jit tracing cache now keys on input NamedShardings. Previously, the
+    tracing cache did not include sharding information at all
+    (although subsequent jit caches, like the lowering and compilation caches,
+    did), so two calls with equivalent shardings of different types would not
+    retrace, but now they do. For example:
+    ```python
+    @jax.jit
+    def f(x):
+      return x
+
+    # inp1.sharding is of type SingleDeviceSharding
+    inp1 = jnp.arange(8)
+    f(inp1)
+
+    mesh = jax.make_mesh((1,), ('x',))
+    # inp2.sharding is of type NamedSharding
+    inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))
+    f(inp2)  # tracing cache miss
+    ```
+    In the above example, calling `f(inp1)` and then `f(inp2)` will lead to a
+    tracing cache miss because the shardings have changed on the abstract
+    values while tracing.
+
 * New Features
  * Added an experimental {func}`jax.experimental.custom_dce.custom_dce`
    decorator to support customizing the behavior of opaque functions under

From 580fe9efe93968ebd2c5e2ca4f56d39a4587e9b9 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 18 Apr 2025 21:09:14 -0700
Subject: [PATCH 0702/1769] Make `Shape::add_dimensions()` validate arguments
 by default.
PiperOrigin-RevId: 749243140 --- .../xla_client_backend_independent_test.py | 61 ++++++++++++------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/jaxlib/xla/xla_client_backend_independent_test.py b/jaxlib/xla/xla_client_backend_independent_test.py index ee1c33feb40c..611c602a73b5 100644 --- a/jaxlib/xla/xla_client_backend_independent_test.py +++ b/jaxlib/xla/xla_client_backend_independent_test.py @@ -34,17 +34,22 @@ class ShapeTest(absltest.TestCase): def testInvalidShapes(self): - with self.assertRaisesRegex(xla_client.XlaRuntimeError, "invalid shape"): + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, "Invalid dimension size" + ): xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) with self.assertRaisesRegex( - RuntimeError, "layout minor_to_major field contains 1 element.*"): + RuntimeError, "layout minor_to_major field contains 1 element.*" + ): xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3]) with self.assertRaisesRegex( - RuntimeError, "layout minor_to_major field has out-of-bounds value.*"): - xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], - [1, -1]) + RuntimeError, "layout minor_to_major field has out-of-bounds value.*" + ): + xla_client.Shape.array_shape( + xla_client.PrimitiveType.F32, [2, 4], [1, -1] + ) class ComputationPrinting(absltest.TestCase): @@ -52,8 +57,9 @@ class ComputationPrinting(absltest.TestCase): def ExampleComputation(self): builder = xla_client.XlaBuilder("acomputation") p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter(builder, 1, - xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + p1 = ops.Parameter( + builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) + ) x = ops.Mul(p0, p1) ops.Add(x, x) return builder.build() @@ -92,7 +98,8 @@ def testHloModuleFromText(self): def testHloModuleToHloGraph(self): computation = self.ExampleComputation() hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( - computation.as_hlo_module()) + computation.as_hlo_module() + ) self.assertTrue(hlo_dot_graph.startswith("digraph ")) @@ -101,15 +108,17 @@ class ComputationHashTest(absltest.TestCase): def testHash(self): builder0 = xla_client.XlaBuilder("computation0") p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter(builder0, 1, - xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + p1 = ops.Parameter( + builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) + ) ops.Mul(p0, p1) computation0 = builder0.build() builder1 = xla_client.XlaBuilder("computation1") p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter(builder1, 1, - xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + p1 = ops.Parameter( + builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) + ) ops.Mul(p0, p1) computation1 = builder1.build() @@ -121,13 +130,19 @@ class AliasTest(absltest.TestCase): def testSetUpAlias(self): c = xla_client.XlaBuilder(self.id()) p1 = ops.Parameter( - c, 0, - xla_client.shape_from_pyval(np.array( - 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + c, + 0, + xla_client.shape_from_pyval( + np.array(1.0, np.float32) + ).with_major_to_minor_layout_if_absent(), + ) p2 = ops.Parameter( - c, 1, - xla_client.shape_from_pyval(np.array( - 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + c, + 1, + xla_client.shape_from_pyval( + np.array(1.0, np.float32) + ).with_major_to_minor_layout_if_absent(), 
+ ) out = ops.Add(p1, p2) c.setup_alias([], 0, []) c.build(out) @@ -159,8 +174,9 @@ class HloModuleGroupTest(absltest.TestCase): def testHloModuleGroup(self): builder0 = xla_client.XlaBuilder("computation0") p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter(builder0, 1, - xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + p1 = ops.Parameter( + builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) + ) root = ops.Mul(p0, p1) computation0 = builder0.build(root) @@ -179,8 +195,9 @@ class RunHloPassTest(absltest.TestCase): def testHloDCE(self): b = xla_client.XlaBuilder("acomputation") p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter(b, 1, - xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + p1 = ops.Parameter( + b, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) + ) root = ops.Mul(p0, p1) # Dead instructions From 0573706b25a0ba7436924875477fed970002da0d Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 19 Apr 2025 07:07:07 -0700 Subject: [PATCH 0703/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1dd84dc2e7f87d79ba9f77b9874ff4a50227ad5e. PiperOrigin-RevId: 749332712 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 7a393c317533..180cf700ddcb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9477be9b2b3afdc2c93b7c272d00f8034750f41c" -XLA_SHA256 = "07f295daea8cee7b72c1be64c0546b26f121b31b7897061c3fe86acfb52e238d" +XLA_COMMIT = "1dd84dc2e7f87d79ba9f77b9874ff4a50227ad5e" +XLA_SHA256 = "2b16a6c708710443e0571b07ef107b23aa278eb2f32ad013eba94a0d033e5011" def repo(): tf_http_archive( From 6c56d651e5e47cfc3713e1eb0ffc61462bbe4092 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 19 Apr 2025 11:03:46 -0700 Subject: [PATCH 0704/1769] jax.random.bernoulli: add mode='high' for improved sampling for small p --- jax/_src/random.py | 25 ++++++++++++++++++++----- tests/random_lax_test.py | 23 +++++++++++++++++++++-- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index b29c1dca7b08..d167b9471f3a 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -918,7 +918,8 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: def bernoulli(key: ArrayLike, p: RealArray = np.float32(0.5), - shape: Shape | None = None) -> Array: + shape: Shape | None = None, + mode: str = 'low') -> Array: r"""Sample Bernoulli random values with given shape and mean. The values are distributed according to the probability mass function: @@ -935,6 +936,10 @@ def bernoulli(key: ArrayLike, shape: optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``p.shape``. The default (None) produces a result shape equal to ``p.shape``. + mode: optional, "high" or "low" for how many bits to use when sampling. + default='low'. Set to "high" for correct sampling at small values of + `p`. When sampling in float32, bernoulli samples with mode='low' produce + incorrect results for p < ~1E-7. 
Returns: A random array with boolean dtype and shape given by ``shape`` if ``shape`` @@ -942,23 +947,33 @@ def bernoulli(key: ArrayLike, """ if shape is not None: shape = core.canonicalize_shape(shape) + if mode not in ['high', 'low']: + raise ValueError(f"got {mode=}, expected 'high' or 'low'") key, _ = _check_prng_key("bernoulli", key) dtype = dtypes.canonicalize_dtype(lax.dtype(p)) if not jnp.issubdtype(dtype, np.floating): msg = "bernoulli probability `p` must have a floating dtype, got {}." raise TypeError(msg.format(dtype)) p = lax.convert_element_type(p, dtype) - return _bernoulli(key, p, shape) + return _bernoulli(key, p, shape, mode=mode) -@partial(jit, static_argnums=(2,)) -def _bernoulli(key, p, shape) -> Array: + +@partial(jit, static_argnames=['shape', 'mode']) +def _bernoulli(key: Array, p: Array, shape: Shape | None, mode: str) -> Array: if shape is None: # TODO: Use the named part of `p` as well shape = np.shape(p) else: _check_shape("bernoulli", shape, np.shape(p)) + dtype = lax.dtype(p) - return uniform(key, shape, lax.dtype(p)) < p + if mode == 'high': + u1, u2 = uniform(key, (2, *shape), dtype) + # resolution of uniform samples is 2 ** -n_mantissa + u2 *= 2 ** -dtypes.finfo(dtype).nmant + return u2 < p - u1 + else: + return uniform(key, shape, lax.dtype(p)) < p def beta(key: ArrayLike, diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index b4d2853abd65..f87b079b759c 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -372,11 +372,13 @@ def testPermutationErrors(self): @jtu.sample_product( p=[0.1, 0.5, 0.9], dtype=jtu.dtypes.floating, + mode=[None, 'low', 'high'], ) - def testBernoulli(self, p, dtype): + def testBernoulli(self, p, dtype, mode): key = lambda: self.make_key(0) p = np.array(p, dtype=dtype) - rand = lambda key, p: random.bernoulli(key, p, (10000,)) + kwds = {} if mode is None else {'mode': mode} + rand = lambda key, p: random.bernoulli(key, p, (10000,), **kwds) crand = jax.jit(rand) uncompiled_samples = rand(key(), p) @@ -461,6 +463,23 @@ def testBernoulliShape(self): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) + def testBernoulliSmallProbabilty(self): + # Regression test for https://github.com/jax-ml/jax/issues/28017 + key = jax.random.key(0) + + # Choose such that N * p is much less than 1. 
+ p = jnp.float32(1E-10) + N = int(1E8) + + # mode='low' fails for p<~1E-7 in float32 + samples = jax.random.bernoulli(key, p=p, shape=N, mode='low') + self.assertNotEqual(samples.sum(), 0) + + # mode='high' is good up to p<~1E-14 in float32 + samples = jax.random.bernoulli(key, p=p, shape=N, mode='high') + self.assertEqual(samples.sum(), 0) + + @jtu.sample_product( a=[0.2, 5.], b=[0.2, 5.], From 8117043dd0f70993d0e49e1828160ff6e3052787 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Sat, 19 Apr 2025 22:08:48 -0700 Subject: [PATCH 0705/1769] fix BUILD file syntax error PiperOrigin-RevId: 749464614 --- build/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/BUILD.bazel b/build/BUILD.bazel index d5c05504e70d..4d8dd9c1b7d8 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -63,5 +63,5 @@ compile_pip_requirements( py_library( name = "all_py_deps", - deps = all_py_deps(["zstandard", "tensorstore"]]), + deps = all_py_deps(["zstandard", "tensorstore"]), ) From f3224caf462eb4f5618d16d37f122f15c919b4ae Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 20 Apr 2025 01:12:35 -0700 Subject: [PATCH 0706/1769] [Pallas/Fuser] Add support for pl.Element in fuser BlockSpec PiperOrigin-RevId: 749492881 --- jax/_src/pallas/fuser/block_spec.py | 12 +++++++-- tests/pallas/fuser_block_spec_test.py | 37 +++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 9524ce4ca4d2..89b5121f5691 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -420,9 +420,17 @@ def make_kernel_function( invar_usages = util.safe_map(read_usage_env, jaxpr.invars) bs_env, scalar_prefetch_fn_env = block_spec_env - def _remove_nones(shape: tuple[int | None, ...] | None) -> tuple[int, ...]: + def _block_size(dim: pallas_core.Element | int | None) -> int | None: + if isinstance(dim, pallas_core.Element): + return dim.block_size + return dim + + def _remove_nones( + shape: tuple[pallas_core.Element | int | None, ...] 
| None + ) -> tuple[int, ...]: assert shape is not None - return tuple(s for s in shape if s is not None) + new_shape = tuple(_block_size(s) for s in shape) + return tuple(s for s in new_shape if s is not None) _no_aval = object() diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index 377901933b4e..ac82cd5f1b35 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -695,6 +695,43 @@ def f(): kernel_fn((0, 0, 3, 0), scalar_prefetch_values, (x,)), x ) + def test_element_indexing(self): + + x = np.zeros((512, 512), dtype=np.float32) + + def f(): + return x + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertEmpty(scalar_prefetch_values) + + # Block spec with an offset on the first dimension + block_spec = pl.BlockSpec( + (pl.Element(128, (0, 16)), 128), lambda i, j, k: (128 * i + 16, j) + ) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(1, 1, 1), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + self.assertEmpty(scalar_prefetch_values) + self.assertEqual(value_block_specs[0].block_shape, (pl.Element(128, (0, 16)), 128)) + self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (16, 1)) + self.assertEqual(value_block_specs[0].index_map(1, 1, 2), (128 + 16, 1)) + + x_block = np.ones((128, 128), dtype=np.float32) + np.testing.assert_array_equal( + kernel_fn( + (0, 0, 0), + scalar_prefetch_values, + (np.ones((128, 128), dtype=np.float32),), + ), + x_block, + ) + class PullBlockSpecHOPTest(jtu.JaxTestCase): From 88617818da799dc1c7bed2f8054bd7cec614ef74 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 20 Apr 2025 07:37:20 -0700 Subject: [PATCH 0707/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f94b36b783b9d955ec8cc966fc0b76cf9e265382. PiperOrigin-RevId: 749548364 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 180cf700ddcb..63bb9a1db73c 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "1dd84dc2e7f87d79ba9f77b9874ff4a50227ad5e" -XLA_SHA256 = "2b16a6c708710443e0571b07ef107b23aa278eb2f32ad013eba94a0d033e5011" +XLA_COMMIT = "f94b36b783b9d955ec8cc966fc0b76cf9e265382" +XLA_SHA256 = "93f9961a3e0920c883c5807fb6637f4b61cf3edde536ffd8ae52f4240884609a" def repo(): tf_http_archive( From ed719ecb74d334b2887f58e51a04a211945b70a8 Mon Sep 17 00:00:00 2001 From: Franz Srambical <79149449+emergenz@users.noreply.github.com> Date: Mon, 21 Apr 2025 10:08:16 +0200 Subject: [PATCH 0708/1769] fix: typo --- tests/multi_device_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 38a37844ebf8..c4bcd98472d5 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -102,7 +102,7 @@ def test_computation_follows_data(self): self.assert_uncommitted_to_device(z3, devices[0]) - # A jitted computation with an device specification behaves as if the + # A jitted computation with a device specification behaves as if the # arguments are first device_put to the specified device. The result # will be committed on the specified. # The `device` parameter is experimental, and subject to change. From 3dad9de765b8b40edf22bc983ea7815c55dfbca1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 21 Apr 2025 06:23:49 -0700 Subject: [PATCH 0709/1769] Add support for TPU7x in Mosaic. PiperOrigin-RevId: 749779206 --- jax/_src/pallas/mosaic/pipeline.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 5153ee6dcce3..0f0e4a342fc7 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -82,8 +82,11 @@ def _get_tpu_generation() -> int: kind = get_default_device().device_kind if kind.endswith(' lite'): kind = kind[:-len(' lite')] - assert kind[:5] == "TPU v", kind - return int(kind[5]) + if kind.startswith("TPU v"): + return int(kind[5]) + else: + assert "TPU7x" in kind + return 7 def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions From b4c413552aaa37fca4e95b79856e4542256267a4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 21 Apr 2025 06:51:31 -0700 Subject: [PATCH 0710/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9219fd7ef180a01f814d3fce9f8aecfd80b9fd6c. PiperOrigin-RevId: 749784566 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 63bb9a1db73c..a25dee5e3f83 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "f94b36b783b9d955ec8cc966fc0b76cf9e265382" -XLA_SHA256 = "93f9961a3e0920c883c5807fb6637f4b61cf3edde536ffd8ae52f4240884609a" +XLA_COMMIT = "9219fd7ef180a01f814d3fce9f8aecfd80b9fd6c" +XLA_SHA256 = "9c8cfa363951ee90e36d86d861c9d07f703562cc4b044de77877edde113261bc" def repo(): tf_http_archive( From c4bb9b671d3ca624ca14ae94177afdd52a04f6dd Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 21 Apr 2025 15:33:05 +0000 Subject: [PATCH 0711/1769] Fixed cached mlir module mutation issue in export Description: - Copy mlir module before adding new attributes Fixes #27991 --- jax/_src/export/_export.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index d02ef44f8318..91b093dd05bd 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -669,6 +669,12 @@ def _export_lowered( # For pmap module_kept_var_idx = tuple(range(len(args_avals_flat))) shape_poly_state = lowering.compile_args["shape_poly_state"] + + # Make a copy of mlir module as we should not mutate it + # because it may be cached + context = mlir.make_ir_context() + with context, ir.Location.unknown(context): + mlir_module = ir.Module.parse(mlir.module_to_bytecode(mlir_module)) if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) or lowering.compile_args.get("ordered_effects", [])): mlir_module = _wrap_main_func( @@ -840,7 +846,7 @@ def _wrap_main_func( See calling convention documentation https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. Args: - module: the HLO module as obtained from lowering. + module: a copy of HLO module as obtained from lowering. args_avals_flat: the avals for all the arguments of the lowered function, which correspond to the array arguments of the `module`. args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error @@ -854,10 +860,9 @@ def _wrap_main_func( Returns the wrapped module, without dimension and token arguments. """ dim_vars = shape_poly.all_dim_vars(args_avals_flat) - context = mlir.make_ir_context() + context = module.context + wrapped_module = module with context, ir.Location.unknown(context): - # Make a copy, do not mutate because it may be cached - wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) symbol_table = ir.SymbolTable(wrapped_module.operation) orig_main = symbol_table["main"] orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") From 8d2e9a853a2b7c86012479cc133ed864bd1d3ada Mon Sep 17 00:00:00 2001 From: Jon Barron Date: Mon, 21 Apr 2025 09:23:36 -0700 Subject: [PATCH 0712/1769] Simplify (and potentially accelerate) fftfreq(), as a single vectorized op instead of multiple .at[] calls. PiperOrigin-RevId: 749818535 --- jax/_src/numpy/fft.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index f962438f23bb..2316ad73ffeb 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -1186,21 +1186,8 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." 
% list(d)) - k = jnp.zeros(n, dtype=dtype, device=device) - if n % 2 == 0: - # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype)) - - # k[n // 2:] = jnp.arange(-n // 2, -1) - k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype)) - - else: - # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) - k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)) - - # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) - k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype)) - + i = jnp.arange(n, dtype=dtype, device=device) + k = ((i + n//2) % n - n//2) return k / jnp.array(d * n, dtype=dtype, device=device) From 85ad1fda5f760eeeb192b4b5bca2b021ab10c10f Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 21 Apr 2025 11:34:05 -0700 Subject: [PATCH 0713/1769] implement _replace_with for PRNGKeyArray --- jax/_src/prng.py | 3 +++ tests/mutable_array_test.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index dd91097fcf98..51211e62afc2 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -173,6 +173,9 @@ def __init__(self, impl, key_data: Any): [key_data], [device], committed=False) self._base_array = key_data + def _replace_with(self, value: PRNGKeyArray): + self._base_array._replace_with(value._base_array) + def block_until_ready(self): _ = self._base_array.block_until_ready() return self diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 5b7669f3db2d..865d4f8520f1 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -259,6 +259,11 @@ def test_implicit_cast_in_swap(self): v = core.mutable_array(jnp.array(0, dtype='bfloat16')) v[...] += 1.0 # don't crash + def test_rng_key(self): + key = core.mutable_array(jax.random.key(0)) + # test read/write + key[...] = jax.random.fold_in(key[...], 1) # don't crash + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From 504af332277c10a32106e8b373925d6c5ea4014e Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Tue, 15 Apr 2025 16:30:18 -0700 Subject: [PATCH 0714/1769] Add a document for activation and parameter offloading --- docs/notebooks/host-offloading.ipynb | 508 +++++++++++++++++++++++++++ docs/notebooks/host-offloading.md | 351 ++++++++++++++++++ 2 files changed, 859 insertions(+) create mode 100644 docs/notebooks/host-offloading.ipynb create mode 100644 docs/notebooks/host-offloading.md diff --git a/docs/notebooks/host-offloading.ipynb b/docs/notebooks/host-offloading.ipynb new file mode 100644 index 000000000000..269371fe11af --- /dev/null +++ b/docs/notebooks/host-offloading.ipynb @@ -0,0 +1,508 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "bQbS50fIdHw1" + }, + "source": [ + "(host-offloading)=\n", + "# Host Offloading\n", + "\n", + "\n", + "\n", + "This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on:\n", + "\n", + "- Activation offloading\n", + "- Parameter offloading\n", + "\n", + "By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement.\n", + "\n", + "## Building Blocks for Offloading\n", + "\n", + "JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. 
In the following sections, you'll explore:\n", + "\n", + "- How to specify data distribution with sharding\n", + "- How to control memory placement between host and device\n", + "- How to manage data movement in jitted functions\n", + "- How to control internal sharding within computations\n", + "\n", + "### NamedSharding and Memory Kinds\n", + "\n", + "{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. It includes:\n", + "\n", + "- Basic data distribution configuration\n", + "- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`)\n", + "- By default, `memory_kind` is set to `device` memory\n", + "- `with_memory_kind` method for creating new sharding with modified memory type" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f-6sxUlqrlBn", + "outputId": "79e7fbda-de0e-4951-9949-77039b2fae81" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)\n", + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", + "import numpy as np\n", + "\n", + "# Create mesh\n", + "# 1x1 mesh represents a single device with two named dimensions (x and y)\n", + "mesh = Mesh(np.array(jax.devices()[0]).reshape(1,1), ('x','y'))\n", + "\n", + "# Device sharding - partitions data along x and y dimensions\n", + "s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind=\"device\")\n", + "\n", + "# Host sharding - same partitioning but in pinned host memory\n", + "s_host = s_dev.with_memory_kind('pinned_host')\n", + "\n", + "print(s_dev) # Shows device memory sharding\n", + "print(s_host) # Shows pinned host memory sharding" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R_pB9465VoMP" + }, + "source": [ + "### Data Placement with device_put\n", + "\n", + "{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OJFnf7FGp6Lj", + "outputId": "a6c1fcdd-e49e-4017-c7aa-8be2e394c3a4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pinned_host\n", + "device\n" + ] + } + ], + "source": [ + "# Create a 4x8 array\n", + "arr = jnp.arange(32.0).reshape(4, 8)\n", + "\n", + "# Move arrays to different memory locations based on sharding objects\n", + "arr_host = jax.device_put(arr, s_host) # Places in pinned host memory\n", + "arr_dev = jax.device_put(arr, s_dev) # Places in device memory\n", + "\n", + "# Verify memory locations\n", + "print(arr_host.sharding.memory_kind) # Output: pinned_host\n", + "print(arr_dev.sharding.memory_kind) # Output: device" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HHXvBpQKTMCR" + }, + "source": [ + "### Input/Output Sharding Controls\n", + "\n", + "Shardings determine how data are split across devices. JAX provides two key parameters for controlling data placement in jitted functions:\n", + "1. `in_shardings`: controls how input arrays are partitioned when entering a jitted function\n", + "2. 
`out_shardings`: controls how output arrays are partitioned when leaving a jitted function\n", + " - Can differ from input sharding\n", + " - Allows different memory kinds for outputs\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "ZXNj9NUeaIdX" + }, + "outputs": [], + "source": [ + "# Function with different input and output shardings\n", + "def compute_function(x):\n", + " return x * 2\n", + "\n", + "compute_function = jax.jit(\n", + " compute_function,\n", + " in_shardings=s_host, # Input arrays will be in host memory\n", + " out_shardings=s_dev # Output arrays will be in device memory\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EbE-eBrJTBuS" + }, + "source": [ + "### Internal Sharding Control\n", + "\n", + "{func}`jax.lax.with_sharding_constraint` is a function that allows you to specify how an array should be sharded at a particular point within a JAX computation. It allows you:\n", + "- Controls sharding within computations for intermediate values and outputs\n", + "- Alternative to {func}`jax.device_put`\n", + "- Works with {func}`jax.jit`" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "LIP5A01fVcrY" + }, + "outputs": [], + "source": [ + "from jax.lax import with_sharding_constraint\n", + "\n", + "@jax.jit\n", + "def func(x):\n", + " # Force x to be sharded across devices in a specific way\n", + " x = with_sharding_constraint(x, P('x'))\n", + " return x + 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6nn8I7weaz8r" + }, + "source": [ + "With these sharding and memory placement techniques, you can apply them flexibly according to your needs in the offloading strategies. The combination of:\n", + "- {class}`~jax.sharding.NamedSharding` for data distribution\n", + "- `memory_kind` and `with_memory_kind` for memory type control\n", + "- {func}`jax.device_put` for explicit data placement\n", + "- `in_shardings` and `out_shardings` for input/output data placement in jitted functions\n", + "- {func}`jax.lax.with_sharding_constraint` for internal sharding control\n", + "\n", + "provides a comprehensive toolkit for managing data placement and movement between host and device memory, enabling efficient implementation of various offloading patterns." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UhLVvRO2p6Lj" + }, + "source": [ + "## Activation Offloading\n", + "\n", + "The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation.\n", + "\n", + "To implement activation offloading effectively, you need to understand checkpoint names and policies. Here's how they work in a simple example:\n", + "\n", + "### Checkpoint Names\n", + "\n", + "The {func}`checkpoint_name` function allows you to label activations for memory management during computation. 
Here's a simple example:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "sLO9ceS6p6Lj" + }, + "outputs": [], + "source": [ + "from jax.ad_checkpoint import checkpoint_name\n", + "\n", + "def layer(x, w):\n", + " w1, w2 = w\n", + " x = checkpoint_name(x, \"x\")\n", + " y = x @ w1\n", + " return y @ w2, None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-_T92oCOp6Lk" + }, + "source": [ + "This example shows:\n", + "\n", + "* A simple neural network layer with two matrix multiplications\n", + "* Labeling of input activation x with identifier `\"x\"`\n", + "* Sequential operations:\n", + " 1. First multiplication: `x @ w1`\n", + " 2. Second multiplication: `y @ w2`\n", + "\n", + "The checkpoint name helps the system decide whether to:\n", + "* Keep the activation in device memory or\n", + "* Offload it to host memory during computation\n", + "\n", + "This pattern is common in neural networks, where multiple transformations are applied sequentially to input data.\n", + "\n", + "\n", + "### Checkpoint Policies\n", + "\n", + "The {func}`jax.remat` transformation manages memory by handling intermediate values through three strategies:\n", + "\n", + "1. Recomputing during backward pass (default behavior)\n", + "2. Storing on device\n", + "3. Offloading to host memory after forward pass and loading back during backward pass\n", + "\n", + "Example of setting an offloading checkpoint policy:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "W8Usw_wOp6Lk" + }, + "outputs": [], + "source": [ + "from jax import checkpoint_policies as cp\n", + "\n", + "policy = cp.save_and_offload_only_these_names(\n", + " names_which_can_be_saved=[], # No values stored on device\n", + " names_which_can_be_offloaded=[\"x\"], # Offload activations labeled \"x\"\n", + " offload_src=\"device\", # Move from device memory\n", + " offload_dst=\"pinned_host\" # To pinned host memory\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J0XslpYzp6Lk" + }, + "source": [ + "Since {func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "xCrxjTx_p6Lk" + }, + "outputs": [], + "source": [ + "def scanned(w, x):\n", + " remat_layer = jax.remat(layer,\n", + " policy=policy, # Use our offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UasMfG8Sp6Lk" + }, + "source": [ + "Key components:\n", + "\n", + "* {func}`jax.remat` applies our checkpoint policy to the layer function\n", + "* `prevent_cse=False` enables XLA's common subexpression elimination for better performance\n", + "* {func}`jax.lax.scan` iterates the rematerialized layer along an axis\n", + "\n", + "### Example Execution\n", + "\n", + "Here's how the code initializes and executes the computation:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "y_xX3eb7p6Lk" + }, + "outputs": [], + "source": [ + "# Initialize input and weights with small values (0.0001)\n", + "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.0001 # Input matrix: 256 x 256\n", + "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.0001 # 10 layers of 256 x 1024 matrices\n", + "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.0001 # 10 layers of 1024 x 256 matrices\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", + "result_activation = f((w1, w2), input) # Execute the function with weights and input" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0tx7aara42pY" + }, + "source": [ + "### Summary of Activation Offloading\n", + "\n", + "Activation offloading provides a powerful way to manage memory in large computations by:\n", + "\n", + "* Using checkpoint names to mark specific activations\n", + "* Applying policies to control where and how activations are stored\n", + "* Supporting common JAX patterns like scan operations\n", + "* Moving selected activations to host memory when device memory is under budget\n", + "\n", + "This approach is particularly useful when working with large models that would otherwise exceed device memory capacity.\n", + "\n", + "## Parameter Offloading\n", + "\n", + "Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind.\n", + "\n", + "While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier.\n", + "\n", + "### Parameter Placement for Computation\n", + "\n", + "Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "id": "1qGN2hBQdheo" + }, + "outputs": [], + "source": [ + "# Hybrid version: Both activation and parameter offloading\n", + "def hybrid_layer(x, w):\n", + " # Move model parameters w1 and w2 to host memory via device_put\n", + " w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)\n", + " x = checkpoint_name(x, \"x\") # Offload activation x to host memory\n", + " y = x @ w1\n", + " return y @ w2, None\n", + "\n", + "def hybrid_scanned(w, x):\n", + " remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer\n", + " policy=policy, # Use offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zcgpNztNp6Lk" + }, + "source": [ + "Note that the activation offloading implementation remains unchanged, using the same:\n", + "* Checkpoint name `\"x\"`\n", + "* Checkpoint policy\n", + "* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan`\n", + "\n", + "### Parameter Initialization with Host Offloading\n", + "\n", + "During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`. Note that {func}`jax.device_put` is used here instead of `in_shardings` because:\n", + "- `in_shardings` would need to be specified in the {func}`jax.jit` decoration, affecting all inputs (both `(w1, w2)` and `input`).\n", + "- Using {func}`jax.device_put` outside the jitted function allows us to selectively place only the parameters on host memory while keeping the `input` variable on device." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lHEoG9qGp6Lk", + "outputId": "7290e342-f0f1-4c85-8155-8fc374f88f47" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results match within tolerance: True\n" + ] + } + ], + "source": [ + "# Move model parameters w1 and w2 to the host via device_put\n", + "# Initialize input and weights with small values (0.0001)\n", + "wh1 = jax.device_put(w1, s_host)\n", + "wh2 = jax.device_put(w2, s_host)\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation\n", + "result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading\n", + "\n", + "# Verify numerical correctness\n", + "are_close = jnp.allclose(\n", + " result_activation[0], # Result from activation offloading only\n", + " result_both[0], # Result from both activation and parameter offloading\n", + " rtol=1e-5,\n", + " atol=1e-5\n", + ")\n", + "print(f\"Results match within tolerance: {are_close}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SVpozzwHflQk" + }, + "source": [ + "The matching results verify that initializing parameters on host memory maintains computational correctness.\n", + "\n", + "### Limitation of Parameter Offloading\n", + "\n", + "{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. 
Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms.\n", + "\n", + "## Tools for Host Offloading\n", + "\n", + "For device memory analysis, refer to :doc:`device_memory_profiling`. The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading." + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V28", + "provenance": [], + "toc_visible": true + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/notebooks/host-offloading.md b/docs/notebooks/host-offloading.md new file mode 100644 index 000000000000..6f3975f9c1f8 --- /dev/null +++ b/docs/notebooks/host-offloading.md @@ -0,0 +1,351 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + ++++ {"id": "bQbS50fIdHw1"} + +(host-offloading)= +# Host Offloading + + + +This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on: + +- Activation offloading +- Parameter offloading + +By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement. + +## Building Blocks for Offloading + +JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. In the following sections, you'll explore: + +- How to specify data distribution with sharding +- How to control memory placement between host and device +- How to manage data movement in jitted functions +- How to control internal sharding within computations + +### NamedSharding and Memory Kinds + +{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. 
It includes: + +- Basic data distribution configuration +- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`) +- By default, `memory_kind` is set to `device` memory +- `with_memory_kind` method for creating new sharding with modified memory type + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: f-6sxUlqrlBn +outputId: 79e7fbda-de0e-4951-9949-77039b2fae81 +--- +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np + +# Create mesh +# 1x1 mesh represents a single device with two named dimensions (x and y) +mesh = Mesh(np.array(jax.devices()[0]).reshape(1,1), ('x','y')) + +# Device sharding - partitions data along x and y dimensions +s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind="device") + +# Host sharding - same partitioning but in pinned host memory +s_host = s_dev.with_memory_kind('pinned_host') + +print(s_dev) # Shows device memory sharding +print(s_host) # Shows pinned host memory sharding +``` + ++++ {"id": "R_pB9465VoMP"} + +### Data Placement with device_put + +{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: OJFnf7FGp6Lj +outputId: a6c1fcdd-e49e-4017-c7aa-8be2e394c3a4 +--- +# Create a 4x8 array +arr = jnp.arange(32.0).reshape(4, 8) + +# Move arrays to different memory locations based on sharding objects +arr_host = jax.device_put(arr, s_host) # Places in pinned host memory +arr_dev = jax.device_put(arr, s_dev) # Places in device memory + +# Verify memory locations +print(arr_host.sharding.memory_kind) # Output: pinned_host +print(arr_dev.sharding.memory_kind) # Output: device +``` + ++++ {"id": "HHXvBpQKTMCR"} + +### Input/Output Sharding Controls + +Shardings determine how data are split across devices. JAX provides two key parameters for controlling data placement in jitted functions: +1. `in_shardings`: controls how input arrays are partitioned when entering a jitted function +2. `out_shardings`: controls how output arrays are partitioned when leaving a jitted function + - Can differ from input sharding + - Allows different memory kinds for outputs + +Example: + +```{code-cell} ipython3 +:id: ZXNj9NUeaIdX + +# Function with different input and output shardings +def compute_function(x): + return x * 2 + +compute_function = jax.jit( + compute_function, + in_shardings=s_host, # Input arrays will be in host memory + out_shardings=s_dev # Output arrays will be in device memory +) +``` + ++++ {"id": "EbE-eBrJTBuS"} + +### Internal Sharding Control + +{func}`jax.lax.with_sharding_constraint` is a function that allows you to specify how an array should be sharded at a particular point within a JAX computation. It allows you: +- Controls sharding within computations for intermediate values and outputs +- Alternative to {func}`jax.device_put` +- Works with {func}`jax.jit` + +```{code-cell} ipython3 +:id: LIP5A01fVcrY + +from jax.lax import with_sharding_constraint + +@jax.jit +def func(x): + # Force x to be sharded across devices in a specific way + x = with_sharding_constraint(x, P('x')) + return x + 1 +``` + ++++ {"id": "6nn8I7weaz8r"} + +With these sharding and memory placement techniques, you can apply them flexibly according to your needs in the offloading strategies. 
The combination of: +- {class}`~jax.sharding.NamedSharding` for data distribution +- `memory_kind` and `with_memory_kind` for memory type control +- {func}`jax.device_put` for explicit data placement +- `in_shardings` and `out_shardings` for input/output data placement in jitted functions +- {func}`jax.lax.with_sharding_constraint` for internal sharding control + +provides a comprehensive toolkit for managing data placement and movement between host and device memory, enabling efficient implementation of various offloading patterns. + ++++ {"id": "UhLVvRO2p6Lj"} + +## Activation Offloading + +The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation. + +To implement activation offloading effectively, you need to understand checkpoint names and policies. Here's how they work in a simple example: + +### Checkpoint Names + +The {func}`checkpoint_name` function allows you to label activations for memory management during computation. Here's a simple example: + +```{code-cell} ipython3 +:id: sLO9ceS6p6Lj + +from jax.ad_checkpoint import checkpoint_name + +def layer(x, w): + w1, w2 = w + x = checkpoint_name(x, "x") + y = x @ w1 + return y @ w2, None +``` + ++++ {"id": "-_T92oCOp6Lk"} + +This example shows: + +* A simple neural network layer with two matrix multiplications +* Labeling of input activation x with identifier `"x"` +* Sequential operations: + 1. First multiplication: `x @ w1` + 2. Second multiplication: `y @ w2` + +The checkpoint name helps the system decide whether to: +* Keep the activation in device memory or +* Offload it to host memory during computation + +This pattern is common in neural networks, where multiple transformations are applied sequentially to input data. + + +### Checkpoint Policies + +The {func}`jax.remat` transformation manages memory by handling intermediate values through three strategies: + +1. Recomputing during backward pass (default behavior) +2. Storing on device +3. Offloading to host memory after forward pass and loading back during backward pass + +Example of setting an offloading checkpoint policy: + +```{code-cell} ipython3 +:id: W8Usw_wOp6Lk + +from jax import checkpoint_policies as cp + +policy = cp.save_and_offload_only_these_names( + names_which_can_be_saved=[], # No values stored on device + names_which_can_be_offloaded=["x"], # Offload activations labeled "x" + offload_src="device", # Move from device memory + offload_dst="pinned_host" # To pinned host memory +) +``` + ++++ {"id": "J0XslpYzp6Lk"} + +Since {func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context. 
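+
+Before looking at the scanned version used in this tutorial (shown in the next cell), it may help to see roughly what {func}`jax.lax.scan` does with `layer`. The following is an illustrative Python sketch of the scan semantics (hypothetical helper names, not the actual implementation):
+
+```python
+# Sketch: jax.lax.scan(layer, x, w) behaves roughly like this loop, where each
+# leaf of w is stacked along the leading axis (one slice per layer).
+def scan_like(layer_fn, carry, stacked_w):
+    w1_stack, w2_stack = stacked_w
+    for i in range(w1_stack.shape[0]):  # iterate over the leading axis
+        carry, _ = layer_fn(carry, (w1_stack[i], w2_stack[i]))
+    return carry                        # layer returns (new_carry, None)
+```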
+ +```{code-cell} ipython3 +:id: xCrxjTx_p6Lk + +def scanned(w, x): + remat_layer = jax.remat(layer, + policy=policy, # Use our offloading policy + prevent_cse=False) # Allow CSE optimizations + result = jax.lax.scan(remat_layer, x, w)[0] + return jnp.sum(result) +``` + ++++ {"id": "UasMfG8Sp6Lk"} + +Key components: + +* {func}`jax.remat` applies our checkpoint policy to the layer function +* `prevent_cse=False` enables XLA's common subexpression elimination for better performance +* {func}`jax.lax.scan` iterates the rematerialized layer along an axis + +### Example Execution + +Here's how the code initializes and executes the computation: + +```{code-cell} ipython3 +:id: y_xX3eb7p6Lk + +# Initialize input and weights with small values (0.0001) +input = jnp.ones((256, 256), dtype=jnp.float32) * 0.0001 # Input matrix: 256 x 256 +w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.0001 # 10 layers of 256 x 1024 matrices +w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.0001 # 10 layers of 1024 x 256 matrices + +# Compile and compute gradients of the scanned function +f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation +result_activation = f((w1, w2), input) # Execute the function with weights and input +``` + ++++ {"id": "0tx7aara42pY"} + +### Summary of Activation Offloading + +Activation offloading provides a powerful way to manage memory in large computations by: + +* Using checkpoint names to mark specific activations +* Applying policies to control where and how activations are stored +* Supporting common JAX patterns like scan operations +* Moving selected activations to host memory when device memory is under budget + +This approach is particularly useful when working with large models that would otherwise exceed device memory capacity. + +## Parameter Offloading + +Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind. + +While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier. + +### Parameter Placement for Computation + +Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes. 
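+
+As a minimal standalone sketch of this fetch-on-use pattern (separate from the full `hybrid_layer` in the next cell; `w_host` and `apply` are hypothetical names, while `s_host` and `s_dev` are the shardings defined earlier):
+
+```python
+w_host = jax.device_put(jnp.ones((4, 4)), s_host)  # parameter resident on host
+
+@jax.jit
+def apply(x, w):
+    w = jax.device_put(w, s_dev)  # copy the parameter to device at point of use
+    return x @ w
+
+out = apply(jnp.ones((4, 4)), w_host)
+```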
+ +```{code-cell} ipython3 +:id: 1qGN2hBQdheo + +# Hybrid version: Both activation and parameter offloading +def hybrid_layer(x, w): + # Move model parameters w1 and w2 to host memory via device_put + w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w) + x = checkpoint_name(x, "x") # Offload activation x to host memory + y = x @ w1 + return y @ w2, None + +def hybrid_scanned(w, x): + remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer + policy=policy, # Use offloading policy + prevent_cse=False) # Allow CSE optimizations + result = jax.lax.scan(remat_layer, x, w)[0] + return jnp.sum(result) +``` + ++++ {"id": "zcgpNztNp6Lk"} + +Note that the activation offloading implementation remains unchanged, using the same: +* Checkpoint name `"x"` +* Checkpoint policy +* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan` + +### Parameter Initialization with Host Offloading + +During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`. Note that {func}`jax.device_put` is used here instead of `in_shardings` because: +- `in_shardings` would need to be specified in the {func}`jax.jit` decoration, affecting all inputs (both `(w1, w2)` and `input`). +- Using {func}`jax.device_put` outside the jitted function allows us to selectively place only the parameters on host memory while keeping the `input` variable on device. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: lHEoG9qGp6Lk +outputId: 7290e342-f0f1-4c85-8155-8fc374f88f47 +--- +# Move model parameters w1 and w2 to the host via device_put +# Initialize input and weights with small values (0.0001) +wh1 = jax.device_put(w1, s_host) +wh2 = jax.device_put(w2, s_host) + +# Compile and compute gradients of the scanned function +f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation +result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading + +# Verify numerical correctness +are_close = jnp.allclose( + result_activation[0], # Result from activation offloading only + result_both[0], # Result from both activation and parameter offloading + rtol=1e-5, + atol=1e-5 +) +print(f"Results match within tolerance: {are_close}") +``` + ++++ {"id": "SVpozzwHflQk"} + +The matching results verify that initializing parameters on host memory maintains computational correctness. + +### Limitation of Parameter Offloading + +{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms. + +## Tools for Host Offloading + +For device memory analysis, refer to :doc:`device_memory_profiling`. The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading. From 6a680941ba3f3ab6a8db3c2fec79f51d9380c9a9 Mon Sep 17 00:00:00 2001 From: Caslyn Tonelli Date: Mon, 21 Apr 2025 11:40:49 -0700 Subject: [PATCH 0715/1769] [docs] fix typo in doc link Amend the scheme format and top-level domain. 
--- docs/jit-compilation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 093f5ec4ab72..a4e2f8b41f0d 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -55,7 +55,7 @@ The {ref}`jax-internals-jaxpr` section of the documentation provides more inform Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. -If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https:docs.jax.devio/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can't detect when side effects are present. From 333cc59c3dd05a0f4a1ab95e69852b47324fcf29 Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Mon, 21 Apr 2025 11:34:57 -0700 Subject: [PATCH 0716/1769] Modify according to the code review --- docs/notebooks/host-offloading.ipynb | 238 ++++++++++++--------------- docs/notebooks/host-offloading.md | 157 +++++++----------- 2 files changed, 168 insertions(+), 227 deletions(-) diff --git a/docs/notebooks/host-offloading.ipynb b/docs/notebooks/host-offloading.ipynb index 269371fe11af..a14953b12850 100644 --- a/docs/notebooks/host-offloading.ipynb +++ b/docs/notebooks/host-offloading.ipynb @@ -7,7 +7,7 @@ }, "source": [ "(host-offloading)=\n", - "# Host Offloading\n", + "# JAX Memories and Host Offloading\n", "\n", "\n", "\n", @@ -25,7 +25,6 @@ "- How to specify data distribution with sharding\n", "- How to control memory placement between host and device\n", "- How to manage data movement in jitted functions\n", - "- How to control internal sharding within computations\n", "\n", "### NamedSharding and Memory Kinds\n", "\n", @@ -45,7 +44,7 @@ "base_uri": "https://localhost:8080/" }, "id": "f-6sxUlqrlBn", - "outputId": "79e7fbda-de0e-4951-9949-77039b2fae81" + "outputId": "691a3df2-8341-44a9-a4a0-5521c2d891e3" }, "outputs": [ { @@ -65,7 +64,7 @@ "\n", "# Create mesh\n", "# 1x1 mesh represents a single device with two named dimensions (x and y)\n", - "mesh = Mesh(np.array(jax.devices()[0]).reshape(1,1), ('x','y'))\n", + "mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))\n", "\n", "# Device sharding - partitions data along x and y dimensions\n", "s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind=\"device\")\n", @@ -90,13 +89,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OJFnf7FGp6Lj", - "outputId": "a6c1fcdd-e49e-4017-c7aa-8be2e394c3a4" + "outputId": "c762e1df-2453-4ed9-9d53-0defb6a05ce2" }, "outputs": [ { @@ -109,8 +108,8 @@ } ], "source": [ - "# Create a 4x8 array\n", - "arr = jnp.arange(32.0).reshape(4, 8)\n", + "# Create a 2x4 array\n", + "arr = jnp.arange(8.0).reshape(2, 4)\n", "\n", "# Move arrays to different memory locations based on sharding 
objects\n", "arr_host = jax.device_put(arr, s_host) # Places in pinned host memory\n", @@ -127,34 +126,44 @@ "id": "HHXvBpQKTMCR" }, "source": [ - "### Input/Output Sharding Controls\n", + "### Output Sharding Controls\n", "\n", - "Shardings determine how data are split across devices. JAX provides two key parameters for controlling data placement in jitted functions:\n", - "1. `in_shardings`: controls how input arrays are partitioned when entering a jitted function\n", - "2. `out_shardings`: controls how output arrays are partitioned when leaving a jitted function\n", + "Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function.\n", + "\n", + "Key Features:\n", " - Can differ from input sharding\n", " - Allows different memory kinds for outputs\n", "\n", - "Example:" + "Examples:\n", + "\n", + "#### Device Output Sharding" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 3, "metadata": { - "id": "ZXNj9NUeaIdX" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZXNj9NUeaIdX", + "outputId": "399321ef-082a-4a77-c33a-9de3421f429b" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 7.]]\n" + ] + } + ], "source": [ - "# Function with different input and output shardings\n", - "def compute_function(x):\n", - " return x * 2\n", - "\n", - "compute_function = jax.jit(\n", - " compute_function,\n", - " in_shardings=s_host, # Input arrays will be in host memory\n", - " out_shardings=s_dev # Output arrays will be in device memory\n", - ")" + "f = jax.jit(lambda x:x, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D: \\n\", out_dev)" ] }, { @@ -163,45 +172,34 @@ "id": "EbE-eBrJTBuS" }, "source": [ - "### Internal Sharding Control\n", - "\n", - "{func}`jax.lax.with_sharding_constraint` is a function that allows you to specify how an array should be sharded at a particular point within a JAX computation. It allows you:\n", - "- Controls sharding within computations for intermediate values and outputs\n", - "- Alternative to {func}`jax.device_put`\n", - "- Works with {func}`jax.jit`" + "#### Host Output Sharding" ] }, { "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "LIP5A01fVcrY" - }, - "outputs": [], - "source": [ - "from jax.lax import with_sharding_constraint\n", - "\n", - "@jax.jit\n", - "def func(x):\n", - " # Force x to be sharded across devices in a specific way\n", - " x = with_sharding_constraint(x, P('x'))\n", - " return x + 1" - ] - }, - { - "cell_type": "markdown", + "execution_count": 4, "metadata": { - "id": "6nn8I7weaz8r" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FjZzkxI8ky4r", + "outputId": "2a1b6e7a-1c29-4347-c020-7b47c27a5cc3" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of D2H: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 7.]]\n" + ] + } + ], "source": [ - "With these sharding and memory placement techniques, you can apply them flexibly according to your needs in the offloading strategies. 
The combination of:\n",
-    "- {class}`~jax.sharding.NamedSharding` for data distribution\n",
-    "- `memory_kind` and `with_memory_kind` for memory type control\n",
-    "- {func}`jax.device_put` for explicit data placement\n",
-    "- `in_shardings` and `out_shardings` for input/output data placement in jitted functions\n",
-    "- {func}`jax.lax.with_sharding_constraint` for internal sharding control\n",
-    "\n",
-    "provides a comprehensive toolkit for managing data placement and movement between host and device memory, enabling efficient implementation of various offloading patterns."
+    "f = jax.jit(lambda x: x, out_shardings=s_host)\n",
+    "out_host = f(arr_dev)  # Input array in device memory; output array in host memory\n",
+    "print(\"Result value of D2H: \\n\", out_host)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
@@ -223,7 +221,7 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 5,
   "metadata": {
    "id": "sLO9ceS6p6Lj"
   },
@@ -272,7 +270,7 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 6,
   "metadata": {
    "id": "W8Usw_wOp6Lk"
   },
@@ -291,61 +289,53 @@
  {
   "cell_type": "markdown",
   "metadata": {
-    "id": "J0XslpYzp6Lk"
+    "id": "iuDRCXu7ky4r"
   },
   "source": [
-    "Since {func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context."
+    "Since {func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context.\n",
+    "\n",
+    "Key components:\n",
+    "* {func}`jax.remat` applies our checkpoint policy to the layer function\n",
+    "* `prevent_cse=False` enables XLA's common subexpression elimination for better performance\n",
+    "* {func}`jax.lax.scan` iterates the rematerialized layer along an axis"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 7,
   "metadata": {
-    "id": "xCrxjTx_p6Lk"
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "xCrxjTx_p6Lk",
+    "outputId": "13d46584-9b25-4622-b3c3-f50c1dac02c2"
   },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Sample of results: [3.7363498e-07 3.7363498e-07 3.7363498e-07 3.7363498e-07 3.7363498e-07]\n"
+     ]
+    }
+   ],
   "source": [
     "def scanned(w, x):\n",
     "    remat_layer = jax.remat(layer,\n",
     "                            policy=policy,     # Use our offloading policy\n",
     "                            prevent_cse=False) # Allow CSE optimizations\n",
     "    result = jax.lax.scan(remat_layer, x, w)[0]\n",
-    "    return jnp.sum(result)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "UasMfG8Sp6Lk"
-   },
-   "source": [
-    "Key components:\n",
-    "\n",
-    "* {func}`jax.remat` applies our checkpoint policy to the layer function\n",
-    "* `prevent_cse=False` enables XLA's common subexpression elimination for better performance\n",
-    "* {func}`jax.lax.scan` iterates the rematerialized layer along an axis\n",
-    "\n",
-    "### Example Execution\n",
-    "\n",
-    "Here's how the code initializes and executes the computation:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "metadata": {
-    "id": "y_xX3eb7p6Lk"
-   },
-   "outputs": [],
-   "source": [
+    "    return jnp.sum(result)\n",
+    "\n",
-    "# Initialize input and weights with small values (0.0001)\n",
-    "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.0001  # Input matrix: 256 x 256\n",
-    "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.0001  # 10 layers of 256 x 1024 matrices\n",
-    "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.0001  # 10 layers 
of 1024 x 256 matrices\n",
+    "# Initialize input and weights with small values (0.001)\n",
+    "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001  # Input matrix: 256 x 256\n",
+    "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001  # 10 layers of 256 x 1024 matrices\n",
+    "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001  # 10 layers of 1024 x 256 matrices\n",
     "\n",
     "# Compile and compute gradients of the scanned function\n",
     "f = jax.jit(jax.grad(scanned))  # Apply JIT compilation to gradient computation\n",
-    "result_activation = f((w1, w2), input)  # Execute the function with weights and input"
+    "result_activation = f((w1, w2), input)  # Execute the function with weights and input\n",
+    "print(\"Sample of results: \", result_activation[0][0, 0, :5])"
   ]
  },
  {
@@ -373,39 +363,8 @@
    "\n",
    "### Parameter Placement for Computation\n",
    "\n",
-    "Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 21,
-   "metadata": {
-    "id": "1qGN2hBQdheo"
-   },
-   "outputs": [],
-   "source": [
-    "# Hybrid version: Both activation and parameter offloading\n",
-    "def hybrid_layer(x, w):\n",
-    "    # Move model parameters w1 and w2 to host memory via device_put\n",
-    "    w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)\n",
-    "    x = checkpoint_name(x, \"x\") # Offload activation x to host memory\n",
-    "    y = x @ w1\n",
-    "    return y @ w2, None\n",
+    "Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameters `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes.\n",
     "\n",
-    "def hybrid_scanned(w, x):\n",
-    "    remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer\n",
-    "                            policy=policy,     # Use offloading policy\n",
-    "                            prevent_cse=False) # Allow CSE optimizations\n",
-    "    result = jax.lax.scan(remat_layer, x, w)[0]\n",
-    "    return jnp.sum(result)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "zcgpNztNp6Lk"
-   },
-   "source": [
     "Note that the activation offloading implementation remains unchanged, using the same:\n",
     "* Checkpoint name `\"x\"`\n",
     "* Checkpoint policy\n",
     "* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan`\n",
     "\n",
     "### Parameter Initialization with Host Offloading\n",
     "\n",
-    "During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`. Note that {func}`jax.device_put` is used here instead of `in_shardings` because:\n",
-    "- `in_shardings` would need to be specified in the {func}`jax.jit` decoration, affecting all inputs (both `(w1, w2)` and `input`).\n",
-    "- Using {func}`jax.device_put` outside the jitted function allows us to selectively place only the parameters on host memory while keeping the `input` variable on device."
+    "During initialization, parameters `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device."
] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "lHEoG9qGp6Lk", - "outputId": "7290e342-f0f1-4c85-8155-8fc374f88f47" + "id": "1qGN2hBQdheo", + "outputId": "48c09658-f8b6-4be3-ef0e-02e0e2566e10" }, "outputs": [ { @@ -438,6 +395,21 @@ } ], "source": [ + "# Hybrid version: Both activation and parameter offloading\n", + "def hybrid_layer(x, w):\n", + " # Move model parameters w1 and w2 to host memory via device_put\n", + " w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)\n", + " x = checkpoint_name(x, \"x\") # Offload activation x to host memory\n", + " y = x @ w1\n", + " return y @ w2, None\n", + "\n", + "def hybrid_scanned(w, x):\n", + " remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer\n", + " policy=policy, # Use offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", "# Move model parameters w1 and w2 to the host via device_put\n", "# Initialize input and weights with small values (0.0001)\n", "wh1 = jax.device_put(w1, s_host)\n", diff --git a/docs/notebooks/host-offloading.md b/docs/notebooks/host-offloading.md index 6f3975f9c1f8..96f59ee7f46e 100644 --- a/docs/notebooks/host-offloading.md +++ b/docs/notebooks/host-offloading.md @@ -15,7 +15,7 @@ kernelspec: +++ {"id": "bQbS50fIdHw1"} (host-offloading)= -# Host Offloading +# JAX Memories and Host Offloading @@ -33,7 +33,6 @@ JAX provides several key components for controlling where and how data are store - How to specify data distribution with sharding - How to control memory placement between host and device - How to manage data movement in jitted functions -- How to control internal sharding within computations ### NamedSharding and Memory Kinds @@ -49,7 +48,7 @@ JAX provides several key components for controlling where and how data are store colab: base_uri: https://localhost:8080/ id: f-6sxUlqrlBn -outputId: 79e7fbda-de0e-4951-9949-77039b2fae81 +outputId: 691a3df2-8341-44a9-a4a0-5521c2d891e3 --- import jax import jax.numpy as jnp @@ -58,7 +57,7 @@ import numpy as np # Create mesh # 1x1 mesh represents a single device with two named dimensions (x and y) -mesh = Mesh(np.array(jax.devices()[0]).reshape(1,1), ('x','y')) +mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y')) # Device sharding - partitions data along x and y dimensions s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind="device") @@ -81,10 +80,10 @@ print(s_host) # Shows pinned host memory sharding colab: base_uri: https://localhost:8080/ id: OJFnf7FGp6Lj -outputId: a6c1fcdd-e49e-4017-c7aa-8be2e394c3a4 +outputId: c762e1df-2453-4ed9-9d53-0defb6a05ce2 --- -# Create a 4x8 array -arr = jnp.arange(32.0).reshape(4, 8) +# Create a 2x4 array +arr = jnp.arange(8.0).reshape(2, 4) # Move arrays to different memory locations based on sharding objects arr_host = jax.device_put(arr, s_host) # Places in pinned host memory @@ -97,62 +96,46 @@ print(arr_dev.sharding.memory_kind) # Output: device +++ {"id": "HHXvBpQKTMCR"} -### Input/Output Sharding Controls +### Output Sharding Controls -Shardings determine how data are split across devices. JAX provides two key parameters for controlling data placement in jitted functions: -1. `in_shardings`: controls how input arrays are partitioned when entering a jitted function -2. 
`out_shardings`: controls how output arrays are partitioned when leaving a jitted function
+Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function.
+
+Key Features:
 - Can differ from input sharding
 - Allows different memory kinds for outputs
 
-Example:
-
-```{code-cell} ipython3
-:id: ZXNj9NUeaIdX
+Examples:
 
-# Function with different input and output shardings
-def compute_function(x):
-    return x * 2
+#### Device Output Sharding
 
-compute_function = jax.jit(
-    compute_function,
-    in_shardings=s_host,   # Input arrays will be in host memory
-    out_shardings=s_dev    # Output arrays will be in device memory
-)
+```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: ZXNj9NUeaIdX
+outputId: 399321ef-082a-4a77-c33a-9de3421f429b
+---
+f = jax.jit(lambda x:x, out_shardings=s_dev)
+out_dev = f(arr_host)
+print("Result value of H2D: \n", out_dev)
 ```
 
 +++ {"id": "EbE-eBrJTBuS"}
 
-### Internal Sharding Control
-
-{func}`jax.lax.with_sharding_constraint` is a function that allows you to specify how an array should be sharded at a particular point within a JAX computation. It allows you:
-- Controls sharding within computations for intermediate values and outputs
-- Alternative to {func}`jax.device_put`
-- Works with {func}`jax.jit`
+#### Host Output Sharding
 
 ```{code-cell} ipython3
-:id: LIP5A01fVcrY
-
-from jax.lax import with_sharding_constraint
-
-@jax.jit
-def func(x):
-    # Force x to be sharded across devices in a specific way
-    x = with_sharding_constraint(x, P('x'))
-    return x + 1
+---
+colab:
+  base_uri: https://localhost:8080/
+id: FjZzkxI8ky4r
+outputId: 2a1b6e7a-1c29-4347-c020-7b47c27a5cc3
+---
+f = jax.jit(lambda x: x, out_shardings=s_host)
+out_host = f(arr_dev)  # Input array in device memory; output array in host memory
+print("Result value of D2H: \n", out_host)
 ```
 
-+++ {"id": "6nn8I7weaz8r"}
-
-With these sharding and memory placement techniques, you can apply them flexibly according to your needs in the offloading strategies. The combination of:
-- {class}`~jax.sharding.NamedSharding` for data distribution
-- `memory_kind` and `with_memory_kind` for memory type control
-- {func}`jax.device_put` for explicit data placement
-- `in_shardings` and `out_shardings` for input/output data placement in jitted functions
-- {func}`jax.lax.with_sharding_constraint` for internal sharding control
-
-provides a comprehensive toolkit for managing data placement and movement between host and device memory, enabling efficient implementation of various offloading patterns.
-
 +++ {"id": "UhLVvRO2p6Lj"}
 
 ## Activation Offloading
@@ -217,44 +200,38 @@ policy = cp.save_and_offload_only_these_names(
 )
 ```
 
-+++ {"id": "J0XslpYzp6Lk"}
++++ {"id": "iuDRCXu7ky4r"}
 
 Since {func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context. 
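+
+One way to check what a given policy actually saves or offloads is the {func}`jax.ad_checkpoint.print_saved_residuals` helper, which reports the residuals kept by the forward pass. A small sketch (the probe operands below are hypothetical and exist only for inspection):
+
+```python
+from jax.ad_checkpoint import print_saved_residuals
+
+x_probe = jnp.ones((256, 256), jnp.float32)
+w_probe = (jnp.ones((256, 1024), jnp.float32),
+           jnp.ones((1024, 256), jnp.float32))
+print_saved_residuals(jax.remat(layer, policy=policy, prevent_cse=False),
+                      x_probe, w_probe)
+```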
-```{code-cell} ipython3 -:id: xCrxjTx_p6Lk +Key components: +* {func}`jax.remat` applies our checkpoint policy to the layer function +* `prevent_cse=False` enables XLA's common subexpression elimination for better performance +* {func}`jax.lax.scan` iterates the rematerialized layer along an axis +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: xCrxjTx_p6Lk +outputId: 13d46584-9b25-4622-b3c3-f50c1dac02c2 +--- def scanned(w, x): remat_layer = jax.remat(layer, policy=policy, # Use our offloading policy prevent_cse=False) # Allow CSE optimizations result = jax.lax.scan(remat_layer, x, w)[0] return jnp.sum(result) -``` - -+++ {"id": "UasMfG8Sp6Lk"} - -Key components: - -* {func}`jax.remat` applies our checkpoint policy to the layer function -* `prevent_cse=False` enables XLA's common subexpression elimination for better performance -* {func}`jax.lax.scan` iterates the rematerialized layer along an axis - -### Example Execution - -Here's how the code initializes and executes the computation: - -```{code-cell} ipython3 -:id: y_xX3eb7p6Lk # Initialize input and weights with small values (0.0001) -input = jnp.ones((256, 256), dtype=jnp.float32) * 0.0001 # Input matrix: 256 x 256 -w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.0001 # 10 layers of 256 x 1024 matrices -w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.0001 # 10 layers of 1024 x 256 matrices +input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256 +w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices +w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices # Compile and compute gradients of the scanned function f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation result_activation = f((w1, w2), input) # Execute the function with weights and input +print("Sample of results: ", result_activation[0][0, 0, :5]) ``` +++ {"id": "0tx7aara42pY"} @@ -280,9 +257,22 @@ While parameter offloading and activation offloading are distinct memory optimiz Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes. -```{code-cell} ipython3 -:id: 1qGN2hBQdheo +Note that the activation offloading implementation remains unchanged, using the same: +* Checkpoint name `"x"` +* Checkpoint policy +* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan` + +### Parameter Initialization with Host Offloading +During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device. 
+ +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: 1qGN2hBQdheo +outputId: 48c09658-f8b6-4be3-ef0e-02e0e2566e10 +--- # Hybrid version: Both activation and parameter offloading def hybrid_layer(x, w): # Move model parameters w1 and w2 to host memory via device_put @@ -297,28 +287,7 @@ def hybrid_scanned(w, x): prevent_cse=False) # Allow CSE optimizations result = jax.lax.scan(remat_layer, x, w)[0] return jnp.sum(result) -``` - -+++ {"id": "zcgpNztNp6Lk"} - -Note that the activation offloading implementation remains unchanged, using the same: -* Checkpoint name `"x"` -* Checkpoint policy -* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan` -### Parameter Initialization with Host Offloading - -During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`. Note that {func}`jax.device_put` is used here instead of `in_shardings` because: -- `in_shardings` would need to be specified in the {func}`jax.jit` decoration, affecting all inputs (both `(w1, w2)` and `input`). -- Using {func}`jax.device_put` outside the jitted function allows us to selectively place only the parameters on host memory while keeping the `input` variable on device. - -```{code-cell} ipython3 ---- -colab: - base_uri: https://localhost:8080/ -id: lHEoG9qGp6Lk -outputId: 7290e342-f0f1-4c85-8155-8fc374f88f47 ---- # Move model parameters w1 and w2 to the host via device_put # Initialize input and weights with small values (0.0001) wh1 = jax.device_put(w1, s_host) From 36cd3137923a76c794890a222828c3a83e987cba Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 21 Apr 2025 12:42:59 -0700 Subject: [PATCH 0717/1769] Allow all reshapes if the operand is fully replicated PiperOrigin-RevId: 749885945 --- jax/_src/lax/lax.py | 2 ++ tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index da9d64535cd1..8db5b09d913c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -7032,6 +7032,8 @@ def _merge_on_one_axis(operand, new_sizes): def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): if sharding is not None: return sharding + if operand.sharding.is_fully_replicated: + return operand.sharding non_1s_op_shape = [s for s in operand.shape if s != 1] non_1s_new_shape = [s for s in new_sizes if s != 1] if non_1s_op_shape == non_1s_new_shape: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d8b2cf9483be..4d38d612e470 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5519,6 +5519,20 @@ def h2(x, y): self.assertEqual(out[0].sharding, arr1.sharding) self.assertEqual(out[1].sharding, arr2.sharding) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_fully_replicated_reshape(self, mesh): + np_inp = np.arange(64).reshape(64, 1) + arr = jax.device_put(np_inp, P(('x', 'y'))) + + @jax.jit + def f(x): + x = reshard(x, P(None, None)) + return jax.lax.reshape(x, (2, 32, 1)) + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None, None))) + self.assertArraysEqual(out, np_inp.reshape(2, 32, 1)) + @parameterized.named_parameters( ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False), ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), From d3a84289e3d50921a891912e83767f1a34407bcf Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 21 Apr 2025 13:54:19 -0700 Subject: [PATCH 0718/1769] Switch JAX to the new `ProgramShape::AddParameter()` API. 
PiperOrigin-RevId: 749908416 --- jaxlib/xla/xla_compiler.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc index bea3062c64e4..11ec4cdc3846 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla/xla_compiler.cc @@ -691,7 +691,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { [](ProgramShape* self, absl::Span params, Shape result) { new (self) ProgramShape(); for (const Shape& param : params) { - *self->add_parameters() = param; + self->AddParameter(param, ""); } *self->mutable_result() = result; }) From 33311b50ecf6268ca81f8a8eb4a058026e7a5d2b Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Thu, 20 Mar 2025 15:55:45 -0700 Subject: [PATCH 0719/1769] [Pallas] Support the new TPU interpret mode with pl.core_map. NOTE: The new TPU interpret mode does not yet support Megacore, so this only enables pl.core_map over a TensorCoreMesh with shape (axis_name, 1). Also adds a num_cores argument to pltpu.create_tensorcore_mesh. --- jax/_src/pallas/core.py | 20 +++++++++++-- jax/_src/pallas/mosaic/core.py | 14 +++++---- jax/_src/pallas/mosaic/interpret.py | 25 ++++++++++------ tests/pallas/tpu_pallas_interpret_test.py | 35 +++++++++++++++++++++++ 4 files changed, 79 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 6af78d3e6ec8..9dda95d0c7f0 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1168,11 +1168,19 @@ def wrapped(f): @core_map_p.def_effectful_abstract_eval -def _core_map_abstract_eval(*args, jaxpr, mesh, **_): +def _core_map_abstract_eval(*args, jaxpr, mesh, **kwargs): del args if jaxpr.outvars: raise ValueError("core_map must not return any outputs.") + interpret = kwargs.get('interpret', False) effs = set() + if interpret: + try: + from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + effs = mosaic_tpu_interpret.get_interpret_effects() + except ImportError: + pass for eff in jaxpr.effects: if mesh.discharges_effect(eff): continue @@ -1264,10 +1272,18 @@ def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwa def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs): - del in_atoms, kwargs + del in_atoms with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())): jax_core.check_jaxpr(jaxpr) + interpret = kwargs.get('interpret', False) effs = set() + if interpret: + try: + from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + effs = mosaic_tpu_interpret.get_interpret_effects() + except ImportError: + pass for eff in jaxpr.effects: if mesh.discharges_effect(eff): continue diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 130e0eabb413..3e5403d889c8 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -196,12 +196,16 @@ def discharges_effect(self, effect: jax_core.Effect): def create_tensorcore_mesh( - axis_name: str, devices: Sequence[jax.Device] | None = None + axis_name: str, + devices: Sequence[jax.Device] | None = None, + num_cores: int | None = None, ) -> TensorCoreMesh: - # TODO(b/355036384): emit a better error if we don't have tensorcores. 
- if devices is None: - devices = jax.devices() - num_cores = devices[0].num_cores + if devices is not None and num_cores is not None: + raise ValueError('cannot specify both devices and num_cores') + if num_cores is None: + if devices is None: + devices = jax.devices() + num_cores = devices[0].num_cores return TensorCoreMesh( np.array([TensorCore(i) for i in range(num_cores)]), [axis_name], diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 5921b161d598..95dec0cd937e 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -622,8 +622,10 @@ def _allocate_semaphores(device_id, shape): ).reshape(shape) -TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | None, int] = { +TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | pallas_core.MemorySpace | None, int] = { v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)} +TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = ( + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY]) TPU_MEMORY_SPACE_NAMES = { i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)} @@ -1019,7 +1021,7 @@ class Placeholder: shape: tuple[int, ...] dtype: jnp.dtype -def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): +def _interpret_jaxpr(jaxpr, *args, mesh, compiler_params, interpret_params): env = {} def read(var): @@ -1051,7 +1053,7 @@ def write(var, value): # - Handle other higher-order primitives? # - Megacore. _interpret = functools.partial( - _interpret_jaxpr, compiler_params=compiler_params, + _interpret_jaxpr, mesh=mesh, compiler_params=compiler_params, interpret_params=interpret_params) for eqn in jaxpr.eqns: with source_info_util.user_context( @@ -1110,6 +1112,12 @@ def write(var, value): elif prim is verification.pretend_p: out = [] + elif ((prim is lax.axis_index_p) + and (mesh is not None) and (eqn.params['axis_name'] in mesh.shape)): + # For now, there can only be one core. + # TODO(jburnim): Support two Megacore cores. + out = jnp.int32(0) + elif prim is lax.cond_p: def _make_branch(jaxpr): return lambda *args: _interpret(jaxpr, *args) @@ -1342,10 +1350,10 @@ def f(*args, jaxpr): return jax._src.util.safe_map(read, jaxpr.outvars) def _compute_start_indices( - block_mapping, loop_idx, *args, compiler_params, interpret_params): + block_mapping, loop_idx, *args, mesh, compiler_params, interpret_params): jaxpr = block_mapping.index_map_jaxpr block_indices = _interpret_jaxpr( - jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, + jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, mesh=mesh, compiler_params=compiler_params, interpret_params=interpret_params) def _get_start_index(i, b): match b: @@ -1523,7 +1531,7 @@ def interpret_pallas_call( out_avals: tuple[jax_core.AbstractValue, ...], interpret_params: TPUInterpretParams, ): - del debug, mesh, cost_estimate, out_avals + del debug, cost_estimate, out_avals # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) dynamic_grid_args, scalars, input_args = split_list( @@ -1732,12 +1740,12 @@ def body( bm, next_grid_point, *scalar_buffer_ids, + mesh=mesh, compiler_params=compiler_params, interpret_params=interpret_params, ) for bm in grid_mapping.block_mappings ] - # Copy slices of the input to the kernel buffers. def _store_slice_to_kernel_input(index, input_var): @@ -1795,7 +1803,7 @@ def _store_slice_to_kernel_input(index, input_var): ) # Invoke the kernel. 
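+  # The mesh is threaded through so that lax.axis_index over a mesh axis
+  # (e.g. the TensorCore axis introduced by pl.core_map) can be resolved
+  # during interpretation; see the axis_index_p case in _interpret_jaxpr.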
- _interpret_jaxpr(jaxpr, *kernel_buffer_ids, + _interpret_jaxpr(jaxpr, *kernel_buffer_ids, mesh=mesh, compiler_params=compiler_params, interpret_params=interpret_params) @@ -1865,6 +1873,7 @@ def _store_to_output_buffer(index, output_var): bm, initial_grid_point, *scalar_buffer_ids, + mesh=mesh, compiler_params=compiler_params, interpret_params=interpret_params, ) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index c4bf07f39cef..3f40f3cce0a2 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -486,6 +486,41 @@ def kernel_call_dynamic_parallel_dimension(): with self.assertRaises(jax.errors.ConcretizationTypeError): kernel_call_dynamic_parallel_dimension() + def test_core_map_over_one_core(self): + mesh = pltpu.create_tensorcore_mesh("x", num_cores=1) + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + @pl.core_map(mesh, interpret=mosaic_interpret.TPUInterpretParams()) + def _(): + num_cores = jax.lax.psum(1, "x") + slc_size = 16 // num_cores + def alloc(x_vmem_ref, y_vmem_ref, sem): + core_index = jax.lax.axis_index("x") + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + sem, + ).wait() + y = x_vmem_ref[...] + 1 + jax.lax.axis_index("x") + y_vmem_ref[...] = y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() + pl.run_scoped( + alloc, + pltpu.VMEM((slc_size, 128), x_ref.dtype), + pltpu.VMEM((slc_size, 128), y_ref.dtype), + pltpu.SemaphoreType.DMA, + ) + _, y = pl.run_state(inner)((x, y)) + return y + x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) + y = f(x) + np.testing.assert_array_equal(y, x + 1) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 640de9e1534692b43dab5f23ee8fa82a395c0414 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 21 Apr 2025 17:01:01 -0700 Subject: [PATCH 0720/1769] [ragged-paged-attn] Add auto-tuned table that is being used for vLLM Migrate auto-tuned table from https://github.com/pytorch/xla/blob/master/torch_xla/experimental/tuned_block_sizes.py PiperOrigin-RevId: 749965181 --- .../tpu/ragged_paged_attention/__init__.py | 23 ++++ .../kernel.py} | 10 +- .../tuned_block_sizes.py | 106 ++++++++++++++++++ 3 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py rename jax/experimental/pallas/ops/tpu/{ragged_paged_attention.py => ragged_paged_attention/kernel.py} (98%) create mode 100644 jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py new file mode 100644 index 000000000000..3830adfa7fd6 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
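+
+# The re-exports below keep the public names importable from the original
+# module path now that it has become a package, e.g. (unchanged for callers):
+#   from jax.experimental.pallas.ops.tpu.ragged_paged_attention import (
+#       ragged_paged_attention)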
+ +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import kernel +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import tuned_block_sizes + +cdiv = kernel.cdiv +dynamic_validate_inputs = kernel.dynamic_validate_inputs +ragged_paged_attention = kernel.ragged_paged_attention +ref_ragged_paged_attention = kernel.ref_ragged_paged_attention +static_validate_inputs = kernel.static_validate_inputs +get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py similarity index 98% rename from jax/experimental/pallas/ops/tpu/ragged_paged_attention.py rename to jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index 203dc8a7602a..a74727adff64 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -24,6 +24,7 @@ from jax import lax from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes import jax.numpy as jnp DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) @@ -700,8 +701,8 @@ def ragged_paged_attention( sliding_window: int | None = None, soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, - num_kv_pages_per_block: int = 16, - num_queries_per_block: int = 128, + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, ): """Ragged paged attention that supports mixed prefill and decode. @@ -749,8 +750,13 @@ def ragged_paged_attention( _, page_size, num_combined_kv_heads, _ = kv_pages.shape assert num_combined_kv_heads % 2 == 0 num_kv_heads = num_combined_kv_heads // 2 + _, pages_per_seq = page_indices.shape num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block + if num_q_per_blk is None or num_kv_pages_per_blk is None: + num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes( + num_q_heads, num_kv_heads, num_q_tokens, page_size, pages_per_seq + ) num_q_heads_per_kv_head = num_q_heads // num_kv_heads num_q_blks = cdiv(num_q_tokens, num_q_per_blk) num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py new file mode 100644 index 000000000000..85f22f58ae3f --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py @@ -0,0 +1,106 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
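+
+# Note on usage (see the kernel.py change above): callers of
+# ragged_paged_attention() may now leave num_kv_pages_per_block and
+# num_queries_per_block as None, in which case the kernel calls
+# get_tuned_block_sizes() below, falling back to (128, 32) when the
+# simplified key is not in the table.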
+
+"""Auto-tuned block sizes for ragged paged attention."""
+
+import jax
+
+
+# TODO: add more tuned block sizes to the table
+# ragged_paged_attention
+# key: (num_q_head, num_kv_head, num_q_tokens, max_model_len)
+# value: (num_kv_pages_per_block, num_queries_per_block)
+TUNED_BLOCK_SIZES = {
+    # go/keep-sorted start
+    (1, 1, 1024, 128): (32, 32),
+    (1, 1, 1024, 2048): (64, 32),
+    (1, 1, 1024, 4096): (64, 32),
+    (1, 1, 1024, 64): (32, 32),
+    (32, 8, 1024, 128): (32, 32),
+    (32, 8, 1024, 2048): (64, 32),
+    (32, 8, 1024, 4096): (64, 32),
+    (32, 8, 1024, 64): (32, 32),
+    (32, 8, 2048, 128): (32, 32),
+    (32, 8, 2048, 2048): (128, 32),
+    (32, 8, 2048, 4096): (128, 32),
+    (32, 8, 2048, 64): (32, 32),
+    (32, 8, 4096, 128): (32, 32),
+    (32, 8, 4096, 2048): (128, 64),
+    (32, 8, 4096, 4096): (128, 64),
+    (32, 8, 4096, 64): (32, 32),
+    (4, 1, 2048, 128): (32, 32),
+    (4, 1, 2048, 2048): (128, 64),
+    (4, 1, 2048, 4096): (128, 64),
+    (4, 1, 2048, 64): (32, 32),
+    (4, 1, 4096, 128): (32, 32),
+    (4, 1, 4096, 2048): (128, 128),
+    (4, 1, 4096, 4096): (128, 128),
+    (4, 1, 4096, 64): (32, 32),
+    # go/keep-sorted end
+}
+
+
+def next_power_of_2(x: int):
+  """Finds the smallest power of 2 >= x using bit manipulation.
+
+  Args:
+    x: The input number (should be an integer).
+
+  Returns:
+    The smallest integer power of 2 that is >= x.
+  """
+  assert x > 0
+  if x == 1:
+    return 1
+  return 1 << (x - 1).bit_length()
+
+
+def simplify_key(num_q_head, num_kv_head, num_q_tokens, max_model_len):
+  num_q_tokens = next_power_of_2(num_q_tokens)
+  max_model_len = next_power_of_2(max_model_len)
+  return num_q_head, num_kv_head, num_q_tokens, max_model_len
+
+
+def get_tpu_version() -> int:
+  """Returns the numeric version of the TPU, or -1 if not on TPU."""
+  kind = jax.devices()[0].device_kind
+  if 'TPU' not in kind:
+    return -1
+  if kind.endswith(' lite'):
+    kind = kind[: -len(' lite')]
+  assert kind[:-1] == 'TPU v', kind
+  return int(kind[-1])
+
+
+def get_tuned_block_sizes(
+    num_q_head, num_kv_head, num_q_tokens, page_size, pages_per_seq
+) -> tuple[int, int]:
+  """Searches for the best (num_kv_pages_per_blk, num_queries_per_blk)."""
+  if get_tpu_version() < 4:
+    raise NotImplementedError("TPU version must be 4 or higher.")
+  if get_tpu_version() == 4:
+    # This default block size is not tuned; it only ensures there is no
+    # VMEM OOM.
+    num_kv_pages_per_blk = 16
+    num_queries_per_blk = 128
+    return num_kv_pages_per_blk, num_queries_per_blk
+
+  max_model_len = pages_per_seq * page_size
+  key = simplify_key(num_q_head, num_kv_head, num_q_tokens, max_model_len)
+  num_kv_pages_per_blk, num_queries_per_blk = TUNED_BLOCK_SIZES.get(
+      key, (128, 32)
+  )
+  num_kv_pages_per_blk = min(num_kv_pages_per_blk, pages_per_seq)
+  num_queries_per_blk = min(num_queries_per_blk, num_q_tokens)
+  return num_kv_pages_per_blk, num_queries_per_blk

From d0b6eb2e05cdbcc9c83e770b6fcbb0378248e511 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 21 Apr 2025 17:51:38 -0700
Subject: [PATCH 0721/1769] [Pallas][jax] Better error message for unexpected
 types in standard abstract eval

This can happen if a user forgets to unwrap a ref! @asabne had this happen to
him today, and he was confused as to what was going on.
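A minimal sketch of the failure mode, mirroring the regression test added
below (with ordinary matmul dimension numbers rather than the ones used in
the test):

    dimension_numbers = (((1,), (0,)), ((), ()))  # standard matmul contraction

    def dot_general_kernel(x_ref, y_ref, o_ref):
      # Wrong: passes the Refs themselves into the primitive.
      o_ref[...] = jax.lax.dot_general(x_ref, y_ref, dimension_numbers)
      # Right: unpack the Refs first.
      o_ref[...] = jax.lax.dot_general(x_ref[...], y_ref[...], dimension_numbers)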
The prior error is unclear:

AssertionError: (MemRef{float32[2,1024,1024]}, MemRef{float32[1,1024,1024]})

PiperOrigin-RevId: 749979253
---
 jax/_src/lax/utils.py       | 16 ++++++++++++++--
 tests/pallas/pallas_test.py | 23 ++++++++++++++++++++++-
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index 8e97621912f1..9e033cadd933 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -22,9 +22,10 @@
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
-from jax._src.util import safe_zip
+from jax._src import state
+from jax._src.named_sharding import DuplicateSpecError, NamedSharding
 from jax._src.partition_spec import PartitionSpec as P
-from jax._src.named_sharding import NamedSharding, DuplicateSpecError
+from jax._src.util import safe_zip

 zip, unsafe_zip = safe_zip, zip

@@ -103,6 +104,17 @@ def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule,

 def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                            sharding_rule, vma_rule, *avals, **kwargs):
+  for a in avals:
+    if isinstance(a, state.AbstractRef):
+      raise ValueError(
+          f' Attempting to pass a Ref {a} to a primitive:'
+          f' {prim} - did you forget to unpack ([...]) the ref?'
+      )
+    if not isinstance(a, core.UnshapedArray):
+      raise ValueError(
+          f'Attempting to pass an unexpected type {a} to a'
+          f' primitive: {prim}'
+      )
   assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
   assert not prim.multiple_results
   weak_type = weak_type_rule(*avals, **kwargs)
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 54e4e4c47784..fd54c1f6065a 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -1100,7 +1100,6 @@ def my_index_map():
                                 "Currently returning 2 values."):
       f(dict(one=a, two=a))

-
   def test_pallas_call_index_map_wrong_return_type(self):
     a = np.arange(256, dtype=np.int32)
     def my_index_map(i):
@@ -1214,6 +1213,28 @@ def test_pallas_call_input_output_aliases_errors(self):
         out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)],
         input_output_aliases={1: 0})(x, x)

+  def test_pallas_error_for_ref_to_jax(self):
+    m, n, k = 8, 16, 32
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
+    )
+    def dot_general_kernel(x_ref, y_ref, o_ref):
+      o_ref[...] = jax.lax.dot_general(x_ref, y_ref, (((2), (1)), ((1,), (2,))))
+
+    key1, key2 = random.split(random.key(0))
+    x = random.normal(key1, (m, k), dtype=jnp.float32)
+    y = random.normal(key2, (k, n), dtype=jnp.float32)
+    with self.assertRaisesRegex(
+        ValueError,
+        r" Attempting to pass a Ref"
+        r" MemRef{float32\[8,32\]}"
+        r" to a primitive: dot_general - did you forget to unpack \(\[...\]\)"
+        r" the ref?",
+    ):
+      dot_general_kernel(x, y)
+

 class ApiErrorInterpretTest(ApiErrorTest):
   INTERPRET = True

From f346fd0977012478ef4d078416495fffe4e1c906 Mon Sep 17 00:00:00 2001
From: Hyeontaek Lim
Date: Mon, 21 Apr 2025 18:32:28 -0700
Subject: [PATCH 0722/1769] [JAX] Use
 `xla::ifrt::Client::MakeArraysFromHostBufferShards()` in Array creation when
 possible

This change makes use of the new
`xla::ifrt::Client::MakeArraysFromHostBufferShards()` API when possible. This
API needs a single call to create a multi-shard IFRT Array (to be wrapped as a
JAX `PyArray`), which provides more optimization opportunities for the runtime
than creating single-device IFRT Arrays and then assembling them.
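For orientation, a sketch of the user-level path that reaches this code; only
standard JAX APIs appear here, and the batching is an internal detail (a
replicated sharding is chosen so the example runs on any device count):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ("x",))
    sharding = NamedSharding(mesh, P())  # replicated across "x"
    x = jax.device_put(np.arange(16.0), sharding)

When every shard handed to `BatchedDevicePut` is a plain host buffer, this now
lowers to a single `MakeArraysFromHostBufferShards()` call.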
Please note that the `xla::ifrt::Client::MakeArraysFromHostBufferShards()`
implementation in PjRt-IFRT is not yet optimized, so there are no immediate
performance benefits for McJAX.

As an exception, it takes the previous path of array assembly if any shard for
`BatchedDevicePut` is not a host buffer but already a single-device array,
because `xla::ifrt::Client::MakeArraysFromHostBufferShards()` works only if
all shards of the input are host buffers.

With batching possible at the IFRT level, we now skip the `DevicePutResultFn`
step; `DevicePut` (now `DevicePutWithDevice` and `DevicePutWithSharding`)
internally calls per-shard functions (with GIL released) and returns a final
IFRT Array.

This change includes a code cleanup for
`xla::DevicePutResult::owning_pybuffer`, which was originally intended to hold
a Python object to keep an IFRT Array valid when it is created from
`DevicePut()` implementations, but this role has been entirely covered by the
`on_done_with_host_buffer` function supplied at IFRT Array creation time.

PiperOrigin-RevId: 749989229
---
 jaxlib/xla/pjit.cc      |  42 ++-
 jaxlib/xla/pmap_lib.cc  |  72 +----
 jaxlib/xla/py_array.cc  |  74 +----
 jaxlib/xla/py_client.cc |  34 +-
 jaxlib/xla/py_values.cc | 665 ++++++++++++++++++++++++++++------------
 jaxlib/xla/py_values.h  |  82 +++--
 jaxlib/xla/sharding.cc  |   3 +-
 jaxlib/xla/sharding.h   |   2 +-
 8 files changed, 587 insertions(+), 387 deletions(-)

diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc
index 50bdc750d3a4..02d8fc6efd01 100644
--- a/jaxlib/xla/pjit.cc
+++ b/jaxlib/xla/pjit.cc
@@ -434,13 +434,15 @@ void CallShardArgFallback(
 // Prepares the input PjRtBuffers from the python arguments. This is equivalent
 // to shard_args() in pxla.py but for only a few supported cases.
 absl::StatusOr>>
-PrepareIfrtInputs(const xla::PyLoadedExecutable& executable,
-                  absl::Span flat_dynamic_args,
-                  bool enable_x64, const std::vector& kept_args,
-                  const std::vector& in_shardings,
-                  const std::vector& in_device_local_layouts,
-                  const nb::callable& shard_arg_fallback,
-                  std::vector& keep_alive_objects) {
+PrepareIfrtInputs(
+    const xla::PyLoadedExecutable& executable,
+    absl::Span flat_dynamic_args,
+    absl::Span flat_dynamic_arg_signatures,
+    bool enable_x64, const std::vector& kept_args,
+    const std::vector& in_shardings,
+    const std::vector& in_device_local_layouts,
+    const nb::callable& shard_arg_fallback,
+    std::vector& keep_alive_objects) {
   const auto& addressable_devices =
       executable.ifrt_loaded_executable()->addressable_devices();
   const auto& num_global_devices =
@@ -484,20 +486,11 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable,
       TF_RETURN_IF_ERROR(
           jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter));
       TF_ASSIGN_OR_RETURN(
-          auto on_device_fn,
-          DevicePut(arg, executable.ifrt_loaded_executable()->client(),
-                    data_device, options, xla::ifrt::MemoryKind()));
-      TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device, [&]() {
-        // Must release the GIL before calling IFRT because backends may
-        // decide to block/sleep for device buffer allocation.
- nb::gil_scoped_release gil_release; - return std::move(on_device_fn)(); - }()); - - num_args_arrays.push_back(std::move(on_device.ifrt_array)); - if (on_device.owning_pybuffer) { - keep_alive_objects.push_back(std::move(on_device.owning_pybuffer)); - } + auto device_put_result, + DevicePutWithDevice(arg, + executable.ifrt_loaded_executable()->client(), + data_device, xla::ifrt::MemoryKind(), options)); + num_args_arrays.push_back(std::move(device_put_result.ifrt_array)); continue; } else { CallShardArgFallback(arg, in_shardings[dce_index], @@ -750,9 +743,10 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, // A vector of [num_inputs]. auto num_args_arrays = PrepareIfrtInputs( *cache_entry->executable, flat_dynamic_args, - call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, - cache_entry->in_shardings, cache_entry->in_device_local_layouts, - shard_arg_fallback_, keep_alive_objects); + call_signature.dynamic_arg_signatures, call_signature.jax_enable_x64, + cache_entry->kept_var_bitvec, cache_entry->in_shardings, + cache_entry->in_device_local_layouts, shard_arg_fallback_, + keep_alive_objects); if (!num_args_arrays.ok()) { VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc index 295ac8bfccfb..94a79e8ba0b9 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/xla/pmap_lib.cc @@ -56,7 +56,6 @@ limitations under the License. #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharded_device_array.h" #include "jaxlib/xla/sharding.h" -#include "jaxlib/xla/to_ifrt_sharding.h" #include "jaxlib/xla/traceback.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" @@ -65,7 +64,6 @@ limitations under the License. #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/memory.h" -#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" @@ -186,74 +184,36 @@ absl::StatusOr ShardArg( indices.size(), n_devices); } - std::vector> per_device_arrays; - per_device_arrays.reserve(n_devices); - absl::InlinedVector devices; - devices.reserve(n_devices); - // TODO(hyeontaek): The created array will never be disassembled. We should - // omit collecting shapes and make the OpaqueSharding non-disassemblable? 
- std::vector shapes; - shapes.reserve(n_devices); - - nb::list owning_pylist; ShardArgResult result; - result.owning_sda = owning_pylist; const bool jax_enable_x64 = GetEnableX64(); - std::vector device_put_fns; - device_put_fns.reserve(n_devices); + std::vector owning_args; + std::vector args; + owning_args.reserve(n_devices); + args.reserve(n_devices); xla::DevicePutOptions options; options.squash_64bit_types = !jax_enable_x64; options.allow_zero_copy = true; + xla::ifrt::Client* ifrt_client = nullptr; for (size_t i = 0; i < n_devices; ++i) { auto to_device = nb::cast(py_devices_list[i]); if (to_device->client().get() == nullptr) { return xla::InvalidArgument("Cannot copy to unattached devices."); } - - TF_ASSIGN_OR_RETURN( - device_put_fns.emplace_back(), - DevicePut(arg[indices[i]], to_device->client()->ifrt_client(), - to_device->device(), options, xla::ifrt::MemoryKind())); - } - std::vector device_puts; - device_puts.reserve(n_devices); - { - nb::gil_scoped_release gil_release; - for (auto& device_put_fn : device_put_fns) { - TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)()); - device_puts.push_back(std::move(device_put)); - } - } - for (auto& device_put : device_puts) { - per_device_arrays.push_back(std::move(device_put.ifrt_array)); - devices.push_back( - per_device_arrays.back()->sharding().devices()->devices().front()); - shapes.push_back(per_device_arrays.back()->shape()); - if (device_put.owning_pybuffer) { - owning_pylist.append(device_put.owning_pybuffer); + if (i == 0) { + ifrt_client = to_device->client()->ifrt_client(); } + owning_args.push_back(arg[indices[i]]); + args.push_back(owning_args.back()); } - - if (per_device_arrays.empty()) { - return xla::InvalidArgument("Per-device arrays must not be empty."); - } - // TODO(hyeontaek): The logical shape here is inaccurate. We - // may want to avoid creating a new Array or specialize Array - // to disallow access to the logical shape. 
- xla::ifrt::Shape shape = per_device_arrays.front()->shape(); - TF_ASSIGN_OR_RETURN( - auto ifrt_sharding, - xla::GetIfrtConcreteSharding(input_spec.array_sharding, shape, shapes)); + CHECK(ifrt_client != nullptr); TF_ASSIGN_OR_RETURN( - result.ifrt_array, - per_device_arrays.front() - ->client() - ->AssembleArrayFromSingleDeviceArrays( - std::move(shape), std::move(ifrt_sharding), - absl::MakeSpan(per_device_arrays), - xla::ifrt::ArrayCopySemantics::kReuseInput, - xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + xla::DevicePutResult device_put_result, + xla::DevicePutWithSharding( + args, ifrt_client, ndarray.dtype(), + nb::cast>(ndarray.attr("shape")), + input_spec.array_sharding, options)); + result.ifrt_array = std::move(device_put_result.ifrt_array); return result; } tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index e3321c4c88ce..584f895e32d2 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -1257,12 +1257,7 @@ absl::StatusOr PyArray::BatchedDevicePut( options.allow_zero_copy = (!force_copy && (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); - if (!dst_devices.empty()) { - options.ifrt_user_context = - dst_devices.front()->client()->ifrt_client()->CreateUserContext(); - } - nb::list owning_pylist; std::vector> ifrt_arrays; absl::InlinedVector devices; @@ -1270,12 +1265,9 @@ absl::StatusOr PyArray::BatchedDevicePut( std::vector shapes; shapes.reserve(n_devices); - ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); - - std::vector device_put_fns; - device_put_fns.reserve(xs.size()); - size_t i = 0; - for (auto& x : xs) { + std::vector args; + args.reserve(xs.size()); + for (const nb::object& x : xs) { if (PyArray::IsPyArray(x)) { TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); @@ -1283,63 +1275,23 @@ absl::StatusOr PyArray::BatchedDevicePut( TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); } - TF_ASSIGN_OR_RETURN( - device_put_fns.emplace_back(), - DevicePut(x, dst_devices[i]->client()->ifrt_client(), - dst_devices[i]->device(), options, dst_memory_kind)); - ++i; - } - std::vector device_puts; - device_puts.reserve(device_put_fns.size()); - { - nb::gil_scoped_release gil_release; - for (auto& device_put_fn : device_put_fns) { - TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)()); - device_puts.push_back(std::move(device_put)); - } - } - for (auto& device_put : device_puts) { - ifrt_arrays.push_back(std::move(device_put.ifrt_array)); - devices.push_back( - ifrt_arrays.back()->sharding().devices()->devices().front()); - shapes.push_back(ifrt_arrays.back()->shape()); - if (device_put.owning_pybuffer) { - owning_pylist.append(device_put.owning_pybuffer); - } + args.push_back(x); } - - // TODO(phawkins): it's highly suspicious to me that owning_pylist isn't - // consumed here. Look into this. - auto weak_type = nb::cast(aval.attr("weak_type")); auto dtype = aval.attr("dtype"); auto shape = nb::cast>(aval.attr("shape")); + TF_ASSIGN_OR_RETURN(nb_class_ptr py_device_list, + jax::GetPyDeviceList(sharding)); TF_ASSIGN_OR_RETURN( - auto ifrt_sharding, - sharding.type().is(jax::PmapSharding::type()) - ? 
xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), - std::move(shapes)) - : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape))); - TF_ASSIGN_OR_RETURN(auto ifrt_dtype, DtypeToIfRtDType(dtype)); - // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are - // supported. - ifrt::DType array_dtype = - ifrt_arrays.empty() ? ifrt_dtype : ifrt_arrays.front()->dtype(); - TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding)); - TF_ASSIGN_OR_RETURN( - auto ifrt_array, - py_device_list->py_client() - ->ifrt_client() - ->AssembleArrayFromSingleDeviceArrays( - array_dtype, ifrt::Shape(shape), std::move(ifrt_sharding), - absl::MakeSpan(ifrt_arrays), - xla::ifrt::ArrayCopySemantics::kReuseInput, - xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); - - return PyArray(aval, weak_type, dtype, std::move(shape), sharding, + DevicePutResult device_put_result, + DevicePutWithSharding(args, py_device_list->py_client()->ifrt_client(), + dtype, shape, sharding, options)); + + return PyArray(aval, weak_type, dtype, std::move(shape), std::move(sharding), py_device_list->py_client(), Traceback::Get(), - std::move(ifrt_array), committed, /*skip_checks=*/true); + std::move(device_put_result.ifrt_array), committed, + /*skip_checks=*/true); } absl::StatusOr PyArray::ReorderShards( diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index c4e8449b85c3..8b42da2fc9bd 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -57,6 +57,7 @@ limitations under the License. #include "jaxlib/xla/py_memory_space.h" #include "jaxlib/xla/py_values.h" #include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/sharding.h" #include "jaxlib/xla/traceback.h" #include "xla/literal.h" #include "xla/pjrt/exceptions.h" @@ -66,6 +67,7 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -339,25 +341,19 @@ absl::Status PyClient::Defragment() { options.allow_zero_copy = (!force_copy && (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); - TF_ASSIGN_OR_RETURN(auto put_fn, - DevicePut(argument, client->ifrt_client_.get(), device, - options, ifrt::MemoryKind())); - TF_ASSIGN_OR_RETURN(auto put, [&]() { - // Must release the GIL before calling IFRT because backends may - // decide to block/sleep for device buffer allocation. 
- nb::gil_scoped_release gil_release; - return std::move(put_fn)(); - }()); - - if (put.ifrt_array) { - auto traceback = Traceback::Get(); - return PyArray::MakeFromSingleDeviceArray( - std::move(client), std::move(traceback), std::move(put.ifrt_array), - /*weak_type=*/false, - /*committed=*/false); - } else { - return put.owning_pybuffer; - } + TF_ASSIGN_OR_RETURN(DevicePutResult device_put_result, + DevicePutWithDevice(argument, client->ifrt_client_.get(), + device, ifrt::MemoryKind(), options)); + auto sharding = make_nb_class( + client, client->ifrt_client()->MakeDeviceList({device}), + /*memory_kind=*/nb::none()); + + auto traceback = Traceback::Get(); + return PyArray::MakeFromIfrtArrayAndSharding( + std::move(client), std::move(traceback), + std::move(device_put_result.ifrt_array), std::move(sharding), + /*weak_type=*/false, /*committed=*/false, + /*skip_checks=*/true); } namespace { diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index 90dd77209694..2ad0d5849ba8 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include #include @@ -30,6 +31,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -44,14 +47,17 @@ limitations under the License. #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/to_ifrt_sharding.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/types.h" @@ -71,6 +77,45 @@ namespace xla { namespace { +// Prepared data for creating a single shard of an array. Holds a single-device +// IFRT array or a host buffer. +struct Shard { + explicit Shard(tsl::RCReference ifrt_array, bool weak_type) + : ifrt_array_or_host_buffer(std::move(ifrt_array)), + weak_type(weak_type), + // host_buffer_semantics is not meaningful when + // `ifrt_array_or_host_buffer` is an IFRT Array. + host_buffer_semantics( + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall) {} + + Shard(ifrt::Client::HostBuffer ifrt_host_buffer, bool weak_type, + ifrt::Client::HostBufferSemantics host_buffer_semantics) + : ifrt_array_or_host_buffer(std::move(ifrt_host_buffer)), + weak_type(weak_type), + host_buffer_semantics(host_buffer_semantics) {} + + Shard(const Shard&) = delete; + Shard& operator=(const Shard&) = delete; + Shard(Shard&&) noexcept = default; + Shard& operator=(Shard&&) noexcept = default; + + bool is_ifrt_array() const { + return std::holds_alternative>( + ifrt_array_or_host_buffer); + } + ifrt::DType ifrt_dtype() const; + const ifrt::Shape& ifrt_shape() const; + + // Points to the on-device array or on-host buffer. + std::variant, ifrt::Client::HostBuffer> + ifrt_array_or_host_buffer; + bool weak_type; + ifrt::Client::HostBufferSemantics host_buffer_semantics; +}; + +// A function that creates a `Shard` from a Python object when called. 
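+// Note: constructing a ShardFn requires the GIL, since it inspects Python
+// objects; invoking the returned function does only C++ work and is expected
+// to run with the GIL released (see MakeShardFn below).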
+using ShardFn = absl::AnyInvocable() &&>; + absl::StatusOr> StringDTypeArrayToCords( PyArrayObject* py_array_obj) { if (PyArray_SIZE(py_array_obj) == 0) { @@ -97,14 +142,114 @@ absl::StatusOr> StringDTypeArrayToCords( return cords; } -using DevicePutFunc = std::function( - nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind)>; +// Handler that creates a `Shard` from a Python object. +using DevicePutHandler = std::function( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options)>; + +// Shared logic that makes a single-device IFRT array from a `shard`. `shard` +// will be consumed. +// +// `user_context` will be used for a new IFRT array created from the host +// buffer, and be not applied when reusing an existing IFRT array. +// +// Expected to be called without holding GIL. +absl::StatusOr> +MakeSingleDeviceIfrtArrayFromShard( + xla::ifrt::Client* ifrt_client, xla::ifrt::Device* ifrt_device, + xla::ifrt::MemoryKind ifrt_memory_kind, Shard& shard, + tsl::RCReference user_context) { + if (auto* ifrt_array = std::get_if>( + &shard.ifrt_array_or_host_buffer)) { + return std::move(*ifrt_array); + } else { + auto host_buffer_shard = std::get( + std::move(shard.ifrt_array_or_host_buffer)); + std::shared_ptr ifrt_sharding = + ifrt::SingleDeviceSharding::Create(ifrt_device, ifrt_memory_kind); + return ifrt_client->MakeArrayFromHostBuffer( + host_buffer_shard.data, host_buffer_shard.dtype, + std::move(host_buffer_shard.shape), + std::move(host_buffer_shard.byte_strides), std::move(ifrt_sharding), + shard.host_buffer_semantics, std::move(host_buffer_shard.on_done), + std::move(user_context)); + } +} + +// Makes an IFRT Array from `shards` using a batched array creation API (fast +// path). `shards` will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr> MakeIfrtArrayFromShardsInBatch( + ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + std::shared_ptr ifrt_sharding, + absl::Span shards, + tsl::RCReference user_context) { + absl::InlinedVector< + std::pair, ifrt::Client::HostBuffer>, 1> + host_buffers; + host_buffers.reserve(shards.size()); + ifrt::Client::HostBufferSemantics safe_host_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + // TODO(hyeontaek): Deduplicate shards here or early on to create a unique + // HostBuffer for each set of replicated shards. + for (int64_t i = 0; i < shards.size(); ++i) { + host_buffers.push_back({{i}, + std::get(std::move( + shards[i].ifrt_array_or_host_buffer))}); + // The minimum host buffer semantics is a safe semantics that can be used + // for all shards when they are created in a single batch. + safe_host_semantics = + std::min(safe_host_semantics, shards[i].host_buffer_semantics); + } + + std::vector specs; + specs.push_back(ifrt::Client::MakeArraysFromHostBufferShardsSpec{ + std::move(host_buffers), + ifrt::ArraySpec{/*dtype=*/ifrt_dtype, + /*shape=*/std::move(ifrt_shape), + /*sharding=*/std::move(ifrt_sharding), + /*layout=*/nullptr}}); + TF_ASSIGN_OR_RETURN( + auto arrays, + ifrt_client->MakeArraysFromHostBufferShards( + absl::MakeSpan(specs), safe_host_semantics, std::move(user_context))); + return std::move(arrays.front()); +} + +// Makes an IFRT Array from `shards` using an array assembly API (slow path). +// `shards` will be consumed. +// +// Expected to be called without holding GIL. 
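+// Used when the batched path cannot be taken, e.g. when at least one shard is
+// already an IFRT array rather than a host buffer (see DevicePutWithSharding).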
+absl::StatusOr> +MakeIfrtArrayFromShardsWithAssembly( + ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + std::shared_ptr ifrt_sharding, + ifrt::DeviceList* ifrt_addressable_device_list, + ifrt::MemoryKind ifrt_memory_kind, absl::Span shards, + tsl::RCReference user_context) { + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + std::vector> ifrt_array_shards; + ifrt_array_shards.reserve(shards.size()); + for (int64_t i = 0; i < shards.size(); ++i) { + TF_ASSIGN_OR_RETURN(tsl::RCReference ifrt_array_shard, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_addressable_devices[i], + ifrt_memory_kind, shards[i], user_context)); + ifrt_array_shards.push_back(std::move(ifrt_array_shard)); + } + return ifrt_client->AssembleArrayFromSingleDeviceArrays( + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_array_shards), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); +} template -absl::StatusOr HandlePythonScalar( - nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, - const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandlePythonScalar(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { T value; try { value = nb::cast(obj); @@ -128,27 +273,24 @@ absl::StatusOr HandlePythonScalar( data.template emplace<1>(static_cast(value)); type = primitive_util::NativeToPrimitiveType(); } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); - return [client, data, type, to_device, to_memory_kind, - options]() -> absl::StatusOr { + return [data, ifrt_dtype]() -> absl::StatusOr { const void* ptr = std::visit( [](const auto& v) { return static_cast(&v); }, data); - TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); - // TODO(yashkatariya): Plumb sharding or memory_kind here. - TF_ASSIGN_OR_RETURN( - auto ifrt_array, - client->MakeArrayFromHostBuffer( - ptr, ifrt_dtype, /*shape=*/ifrt::Shape({}), /*byte_strides=*/{}, - ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), - ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, - /*on_done_with_host_buffer=*/{}, options.ifrt_user_context)); - return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); }; } -absl::StatusOr HandlePythonInt( - nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, - const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandlePythonInt(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { PrimitiveType type; std::variant data; @@ -175,28 +317,24 @@ absl::StatusOr HandlePythonInt( } type = S64; } - return [client, data, type, to_device, to_memory_kind, - options]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, ifrt_dtype]() -> absl::StatusOr { const void* ptr = std::visit( [](const auto& v) { return static_cast(&v); }, data); - TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); - // TODO(yashkatariya): Plumb sharding or memory_kind here. 
- TF_ASSIGN_OR_RETURN( - auto ifrt_array, - client->MakeArrayFromHostBuffer( - ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}), - /*byte_strides=*/{}, - ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), - ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, - /*on_done_with_host_buffer=*/nullptr, options.ifrt_user_context)); - return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); }; } template -absl::StatusOr HandleNumpyScalar( - nb::handle h, ifrt::Client* client, ifrt::Device* to_device, - const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandleNumpyScalar(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { std::variant data; PrimitiveType type; // For extension types, ScalarAsCtype returns a pointer to the data. @@ -256,8 +394,9 @@ absl::StatusOr HandleNumpyScalar( py_buffer_ref = GlobalPyRefManager()->ManageReference(nb::cast(h)); } - return [client, data, py_buffer_ref, type, to_device, options, - to_memory_kind]() mutable -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, py_buffer_ref = std::move(py_buffer_ref), + ifrt_dtype]() mutable -> absl::StatusOr { const void* ptr = std::visit( [](const auto& v) -> const void* { if constexpr (std::is_same_v, void*>) { @@ -267,32 +406,26 @@ absl::StatusOr HandleNumpyScalar( } }, data); - TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); - // TODO(yashkatariya): Plumb sharding or memory_kind here. 
- TF_ASSIGN_OR_RETURN( - auto ifrt_array, - client->MakeArrayFromHostBuffer( - ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}), - /*byte_strides=*/{}, - ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), - ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, - /*on_done_with_host_buffer=*/ - [py_buffer_ref = std::move( - py_buffer_ref)]() { /* keeps py_buffer_ref alive */ }, - options.ifrt_user_context)); - return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/ + [py_buffer_ref = + std::move(py_buffer_ref)]() { /* keeps py_buffer_ref alive */ }}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); }; } -absl::StatusOr HandleStringNumpyArray( +absl::StatusOr HandleStringNumpyArray( nb::handle h, ifrt::Client* client, ifrt::Device* to_device, - const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options) { xla::nb_numpy_ndarray array = nb::cast(h); auto py_array_obj = reinterpret_cast(array.ptr()); TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); // Assemble all the parameters of MakeArrayFromHostBuffer - void* data = cords.data(); + const void* data = cords.data(); // Make an explicit copy of the shape elements so we won't run into complex // endianness and precision issues that might arise if we reinterpret-casted @@ -305,36 +438,30 @@ absl::StatusOr HandleStringNumpyArray( } ifrt::Shape shape(std::move(dims)); - std::shared_ptr sharding = - xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind); - auto on_done_with_host_buffer = [cords = std::move(cords)] {}; - return [client, data = data, shape = std::move(shape), - sharding = std::move(sharding), - on_done_with_host_buffer = std::move(on_done_with_host_buffer), - options]() mutable -> absl::StatusOr { - TF_ASSIGN_OR_RETURN( - auto ifrt_array, - client->MakeArrayFromHostBuffer( - data, ifrt::DType(ifrt::DType::kString), std::move(shape), - /*byte_strides=*/std::nullopt, std::move(sharding), - ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes, - std::move(on_done_with_host_buffer), options.ifrt_user_context)); - - return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + return [data, shape = std::move(shape), + on_done_with_host_buffer = std::move( + on_done_with_host_buffer)]() mutable -> absl::StatusOr { + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(on_done_with_host_buffer)}; + return Shard( + std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes); }; } -absl::StatusOr HandleNumpyArray( - nb::handle h, ifrt::Client* client, ifrt::Device* to_device, - const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandleNumpyArray(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { xla::nb_numpy_ndarray array = nb::cast(h); // String numpy arrays require substantially different processing. 
if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') { - return HandleStringNumpyArray(h, client, to_device, options, - to_memory_kind); + return HandleStringNumpyArray(h, client, to_device, to_memory_kind, + options); } TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); @@ -355,7 +482,7 @@ absl::StatusOr HandleNumpyArray( } absl::InlinedVector dims(array.ndim()); - absl::InlinedVector byte_strides(array.ndim()); + ifrt::Client::HostBuffer::ByteStrides byte_strides(array.ndim()); for (int i = 0; i < array.ndim(); ++i) { dims[i] = array.shape(i); byte_strides[i] = array.strides(i); @@ -363,16 +490,16 @@ absl::StatusOr HandleNumpyArray( const void* data = array.data(); std::shared_ptr py_buffer_ref = GlobalPyRefManager()->ManageReference(std::move(array)); - return [client, data, squashed_type, dims = std::move(dims), + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(squashed_type)); + return [data, ifrt_dtype, dims = std::move(dims), byte_strides = std::move(byte_strides), - py_buffer_ref = std::move(py_buffer_ref), options, to_device, - to_memory_kind]() mutable -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(squashed_type)); - + py_buffer_ref = std::move(py_buffer_ref), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { ifrt::Client::HostBufferSemantics host_buffer_semantics = ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall; std::function on_done_with_host_buffer; - if (options.allow_zero_copy) { + if (allow_zero_copy) { on_done_with_host_buffer = [py_buffer_ref{ std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; @@ -380,20 +507,18 @@ absl::StatusOr HandleNumpyArray( ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; } - TF_ASSIGN_OR_RETURN( - auto ifrt_array, - client->MakeArrayFromHostBuffer( - data, ifrt_dtype, ifrt::Shape(dims), byte_strides, - xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), - host_buffer_semantics, std::move(on_done_with_host_buffer), - options.ifrt_user_context)); - return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt_dtype, ifrt::Shape(dims), std::move(byte_strides), + std::move(on_done_with_host_buffer)}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + host_buffer_semantics); }; } -absl::StatusOr HandlePyArray( - nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, - const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandlePyArray(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { auto py_array = nb::borrow(obj); // We only allow single device case for PyArray in device put. 
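At the JAX level this is the fast path of `device_put` on an existing
`jax.Array`; a sketch, assuming a default single-process setup:

    import jax
    import jax.numpy as jnp

    d = jax.devices()[0]
    x = jax.device_put(jnp.ones((4,)), d)
    y = jax.device_put(x, d)  # same device/memory kind: buffer reused, no copy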
@@ -413,8 +538,8 @@ absl::StatusOr HandlePyArray( if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() || ifrt_array->sharding().devices()->devices().front()->client() != to_device->client()) { - return HandleNumpyArray(obj.attr("_value"), client, to_device, options, - to_memory_kind); + return HandleNumpyArray(obj.attr("_value"), client, to_device, + to_memory_kind, options); } if (ifrt_array->sharding().devices()->devices().front() == to_device && @@ -422,14 +547,13 @@ absl::StatusOr HandlePyArray( (!to_memory_kind.memory_kind().has_value() || !ifrt_array->sharding().memory_kind().memory_kind().has_value() || ifrt_array->sharding().memory_kind() == to_memory_kind)) { - DevicePutResult result(tsl::FormRef(ifrt_array), py_array.weak_type(), - /*owning_pybuffer=*/nb::borrow(obj)); + Shard result(tsl::FormRef(ifrt_array), py_array.weak_type()); return [result = std::move(result)]() mutable { return std::move(result); }; } else { return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind, - owning_pybuffer = py_array.weak_type(), - allow_zero_copy = options.allow_zero_copy]() mutable - -> absl::StatusOr { + weak_type = py_array.weak_type(), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { auto* ifrt_client = ifrt_array->client(); TF_ASSIGN_OR_RETURN( auto copied_ifrt_arrays, @@ -438,101 +562,112 @@ absl::StatusOr HandlePyArray( ifrt_client->MakeDeviceList({to_device}), to_memory_kind, allow_zero_copy ? ifrt::ArrayCopySemantics::kReuseInput : ifrt::ArrayCopySemantics::kAlwaysCopy)); - return DevicePutResult(std::move(copied_ifrt_arrays[0]), - std::move(owning_pybuffer)); + return Shard(std::move(copied_ifrt_arrays.front()), weak_type); }; } } -} // namespace +ifrt::DType Shard::ifrt_dtype() const { + if (is_ifrt_array()) { + return std::get>(ifrt_array_or_host_buffer) + ->dtype(); + } else { + return std::get(ifrt_array_or_host_buffer).dtype; + } +} -absl::StatusOr DevicePut(nb::handle arg, - ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind) { - tsl::profiler::TraceMe traceme("DevicePut"); - static const absl::flat_hash_map* const handlers = - [] { - auto p = new absl::flat_hash_map(); - const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); - // Python scalar types. - static_assert(sizeof(bool) == 1, - "Conversion code assumes bool is 1 byte"); - (*p)[reinterpret_cast(&PyBool_Type)] = - HandlePythonScalar; - (*p)[reinterpret_cast(&PyLong_Type)] = HandlePythonInt; - (*p)[reinterpret_cast(&PyFloat_Type)] = - HandlePythonScalar; - (*p)[reinterpret_cast(&PyComplex_Type)] = - HandlePythonScalar; - - (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; - - // Numpy scalar types. For some of them, we share the handler with - // Python types (np_int64, np_float64, np_complex128). 
- (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar; - if (dtypes.np_int2.has_value()) { - (*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar; - } - (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar; - if (dtypes.np_uint2.has_value()) { - (*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar; - } - (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; - if (dtypes.np_float4_e2m1fn.has_value()) { - (*p)[dtypes.np_float4_e2m1fn->ptr()] = - HandleNumpyScalar; - } - if (dtypes.np_float8_e3m4.has_value()) { - (*p)[dtypes.np_float8_e3m4->ptr()] = - HandleNumpyScalar; - } - if (dtypes.np_float8_e4m3.has_value()) { - (*p)[dtypes.np_float8_e4m3->ptr()] = - HandleNumpyScalar; - } - (*p)[dtypes.np_float8_e4m3fn.ptr()] = - HandleNumpyScalar; - (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = - HandleNumpyScalar; - (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = - HandleNumpyScalar; - (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = - HandleNumpyScalar; - if (dtypes.np_float8_e8m0fnu.has_value()) { - (*p)[dtypes.np_float8_e8m0fnu->ptr()] = - HandleNumpyScalar; - } - (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar; - (*p)[dtypes.np_complex128.ptr()] = - HandleNumpyScalar; - static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT - "long long must be the same size as int64_t"); - (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar; - static_assert(sizeof(int) == sizeof(int32_t), - "int must be the same size as int32_t"); - (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; +const ifrt::Shape& Shard::ifrt_shape() const { + if (is_ifrt_array()) { + return std::get>(ifrt_array_or_host_buffer) + ->shape(); + } else { + return std::get(ifrt_array_or_host_buffer).shape; + } +} - return p; - }(); +// Creates a `ShardFn` that copies `arg` to `to_device` and `to_memory_kind`. +// +// Requires GIL. The returned `ShardFn` should be called without GIL held. +absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + static const absl::flat_hash_map* const handlers = [] { + auto p = new absl::flat_hash_map(); + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + // Python scalar types. + static_assert(sizeof(bool) == 1, "Conversion code assumes bool is 1 byte"); + (*p)[reinterpret_cast(&PyBool_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyLong_Type)] = HandlePythonInt; + (*p)[reinterpret_cast(&PyFloat_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyComplex_Type)] = + HandlePythonScalar; + + (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; + + // Numpy scalar types. For some of them, we share the handler with + // Python types (np_int64, np_float64, np_complex128). 
+ (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar; + if (dtypes.np_int2.has_value()) { + (*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar; + if (dtypes.np_uint2.has_value()) { + (*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = + HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } + (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex128.ptr()] = HandleNumpyScalar; + static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT + "long long must be the same size as int64_t"); + (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar; + static_assert(sizeof(int) == sizeof(int32_t), + "int must be the same size as int32_t"); + (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; + + return p; + }(); if (arg.type().ptr() == PyArray::type().ptr()) { auto array = nb::borrow(arg); - return HandlePyArray(arg, client, to_device, options, to_memory_kind); + return HandlePyArray(arg, client, to_device, to_memory_kind, options); } auto res = handlers->find(arg.type().ptr()); @@ -540,7 +675,7 @@ absl::StatusOr DevicePut(nb::handle arg, for (auto base_class : arg.type().attr("__mro__")) { res = handlers->find(base_class.ptr()); if (res != handlers->end()) { - return res->second(arg, client, to_device, options, to_memory_kind); + return res->second(arg, client, to_device, to_memory_kind, options); } } return InvalidArgument( @@ -550,9 +685,11 @@ absl::StatusOr DevicePut(nb::handle arg, "(see implementation), or Python scalars. 
Got type ", nb::cast(nb::str(arg.type())))); } - return res->second(arg, client, to_device, options, to_memory_kind); + return res->second(arg, client, to_device, to_memory_kind, options); } +} // namespace + bool IsFloat0(xla::nb_numpy_ndarray arg) { static const auto* dtypes_module = new nb::module_(nb::module_::import_("jax.dtypes")); @@ -756,4 +893,152 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, return res->second(arg, jax_enable_x64); } +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, ifrt::Client* ifrt_client, + ifrt::Device* ifrt_device, ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions& options) { + tsl::profiler::TraceMe traceme("DevicePut"); + + if (!ifrt_device->IsAddressable()) { + return InvalidArgument("Cannot copy array to non-addressable device: %s", + ifrt_device->DebugString()); + } + + TF_ASSIGN_OR_RETURN(ShardFn shard_fn, + MakeShardFn(addressable_shard, ifrt_client, ifrt_device, + ifrt_memory_kind, options)); + + tsl::RCReference ifrt_user_context = + ifrt_client->CreateUserContext(); + + nb::gil_scoped_release gil_release; + + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fn)()); + TF_ASSIGN_OR_RETURN(tsl::RCReference ifrt_array, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_device, ifrt_memory_kind, shard, + std::move(ifrt_user_context))); + return DevicePutResult(std::move(ifrt_array), shard.weak_type); +} + +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + ifrt::Client* ifrt_client, const nb_dtype& dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions& options) { + tsl::profiler::TraceMe traceme("DevicePutWithSharding"); + + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef ifrt_device_list, + GetIfrtDeviceList(sharding)); + ifrt::DeviceList* ifrt_addressable_device_list = + ifrt_device_list->AddressableDeviceList(); + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + // Pmap sharding requires special handling because it needs a shard shape + // upfront. + const bool is_pmap_sharding = sharding.type().is(jax::PmapSharding::type()); + + if (addressable_shards.size() != ifrt_addressable_devices.size()) { + // Try to generate a friendly error message if the user attempted to copy to + // a non-addressable device. + if (addressable_shards.size() > ifrt_addressable_devices.size()) { + for (ifrt::Device* device : ifrt_device_list->devices()) { + if (!device->IsAddressable()) { + return InvalidArgument( + "Cannot copy array to non-addressable device: %s", + device->DebugString()); + } + } + } + // Otherwise, generate a generic error message. + return InvalidArgument( + "Number of addressable shard data does not match the number " + "of addressable devices in the sharding: %d vs. 
%d", + addressable_shards.size(), ifrt_addressable_devices.size()); + } + if (is_pmap_sharding && addressable_shards.empty()) { + return InvalidArgument( + "Pmap sharding requires at least one addressable shard."); + } + + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, DtypeToIfRtDType(dtype)); + ifrt::Shape ifrt_shape(shape); + ifrt::MemoryKind ifrt_memory_kind = GetMemoryKind(sharding); + + std::vector shard_fns; + shard_fns.reserve(addressable_shards.size()); + for (int i = 0; i < addressable_shards.size(); ++i) { + TF_ASSIGN_OR_RETURN( + ShardFn shard, + MakeShardFn(addressable_shards[i], ifrt_client, + ifrt_addressable_devices[i], ifrt_memory_kind, options)); + shard_fns.push_back(std::move(shard)); + } + + std::shared_ptr ifrt_sharding; + if (is_pmap_sharding) { + CHECK(!shard_fns.empty()); + // IFRT Sharding will be determined once we discover the shard shape. + } else { + TF_ASSIGN_OR_RETURN(ifrt_sharding, + GetIfrtHloSharding(sharding, ifrt_shape)); + } + tsl::RCReference ifrt_user_context = + ifrt_client->CreateUserContext(); + + nb::gil_scoped_release gil_release; + + // Whether to build an IFRT array from host buffers as a single batch. We do + // not batch any shard is already an IFRT array. + bool should_batch = true; +#if JAX_IFRT_VERSION_NUMBER < 2 + // PjRt-IFRT would fail `xla::ifrt::Client::MakeArrayFromHostBuffer()` invoked + // by `xla::ifrt::ClientMakeArraysFromHostBufferShards()` for a fully + // replicated sharding if the sharding has any non-addressable device. + should_batch = false; +#endif + + std::vector shards; + shards.reserve(shard_fns.size()); + for (int64_t i = 0; i < shard_fns.size(); ++i) { + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fns[i])()); + if (shard.is_ifrt_array()) { + // If any shard is an IFRT array, we should assemble shards. + should_batch = false; + } + shards.push_back(std::move(shard)); + } + + // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are + // supported. + if (!shards.empty()) { + ifrt_dtype = shards.front().ifrt_dtype(); + } + if (is_pmap_sharding) { + ifrt_sharding = ifrt::ConcreteEvenSharding::Create( + ifrt::DeviceListRef(tsl::FormRef(ifrt_addressable_device_list)), + ifrt_memory_kind, ifrt_shape, + /*shard_shape=*/shards.front().ifrt_shape(), + /*is_fully_replicated=*/false); + } + + tsl::RCReference ifrt_array; + if (should_batch) { + TF_ASSIGN_OR_RETURN(ifrt_array, + MakeIfrtArrayFromShardsInBatch( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), absl::MakeSpan(shards), + std::move(ifrt_user_context))); + } else { + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromShardsWithAssembly( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), ifrt_addressable_device_list, + ifrt_memory_kind, absl::MakeSpan(shards), + std::move(ifrt_user_context))); + } + const bool weak_type = shards.empty() ? false : shards.front().weak_type; + return DevicePutResult(std::move(ifrt_array), weak_type); +} + } // namespace xla diff --git a/jaxlib/xla/py_values.h b/jaxlib/xla/py_values.h index b64895100d8c..afa59a839d5d 100644 --- a/jaxlib/xla/py_values.h +++ b/jaxlib/xla/py_values.h @@ -24,14 +24,13 @@ limitations under the License. 
#include #include "absl/container/inlined_vector.h" -#include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" -#include "xla/python/ifrt/user_context.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/xla_data.pb.h" @@ -39,53 +38,67 @@ limitations under the License. namespace xla { struct DevicePutResult { - explicit DevicePutResult( - tsl::RCReference ifrt_array, bool weak_type, - nanobind::object owning_pybuffer = nanobind::object()) - : ifrt_array(std::move(ifrt_array)), - weak_type(weak_type), - owning_pybuffer(owning_pybuffer) {} - - // Disallow copy since copying `DevicePutResult` without holding GIL may be - // dangerous due to `owning_pybuffer`. + DevicePutResult(tsl::RCReference ifrt_array, bool weak_type) + : ifrt_array(std::move(ifrt_array)), weak_type(weak_type) {} + + // Disallow copy. `DevicePutResult` is expected to be consumed by one user. DevicePutResult(const DevicePutResult&) = delete; DevicePutResult& operator=(const DevicePutResult&) = delete; DevicePutResult(DevicePutResult&&) noexcept = default; DevicePutResult& operator=(DevicePutResult&&) noexcept = default; - // Points to the on-device array. Not owned. + // Points to the on-device array. tsl::RCReference ifrt_array; bool weak_type; +}; - nanobind::object owning_pybuffer; +// Options for `DevicePut`. +struct DevicePutOptions { + bool squash_64bit_types = false; + bool allow_zero_copy = true; }; -// Copies a buffer-like object to be on device. +// Copies a buffer-like object to be on device. This version is designed for +// creating a single-device array. // -// If `arg` is not convertible to a `PjRtBuffer` from C++, an error will be -// returned; float0s are not supported yet. -// If the value is known to be a PyBuffer object, py_buffer can be passed as -// an optimization to avoid a Python->C++ cast. +// If `addressable_shard` is not convertible to a `PjRtBuffer` from C++, an +// error will be returned; float0s are not supported yet. // -// This function performs Python work inline but postpones C++ work until the -// returned function is called. The returned function must be called after -// releasing GIL. Useful for batching GIL release when there are many device_put -// to execute. +// If the value is known to be a PyBuffer object, py_buffer can be passed as an +// optimization to avoid a Python->C++ cast. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. // // May throw exceptions from nanobind in addition to failing via an error // absl::Status. (We could catch these if needed, but there seems little point.) -struct DevicePutOptions { - bool squash_64bit_types = false; - bool allow_zero_copy = true; - tsl::RCReference ifrt_user_context; -}; -using DevicePutResultFn = - absl::AnyInvocable() &&>; -absl::StatusOr DevicePut(nanobind::handle arg, - ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind); +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, ifrt::Client* ifrt_client, + ifrt::Device* ifrt_device, ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions& options); + +// Copies a buffer-like object to be on device. 
This version is optimized for +// creating a multi-device array. +// +// `addressable_shards` is a list of buffer-like objects to be copied to +// addressable devices specified in `sharding`. +// +// `shape` and `sharding` determine the shape and sharding of the returned IFRT +// Array. +// +// The size of `addressable_shards` must match the number of addressable devices +// in `sharding`. For a Pmap sharding, there must be at least one addressable +// device. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. +// +// See the above `DevicePutWithDevice` for other details. +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + ifrt::Client* ifrt_client, const nb_dtype& dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions& options); // Returns `true` if `arg` is a JAX float0 array. bool IsFloat0(xla::nb_numpy_ndarray arg); @@ -122,6 +135,7 @@ H AbslHashValue(H h, const xla::PyArgSignature& s) { h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); return h; } + } // namespace xla #endif // JAXLIB_XLA_PY_VALUES_H_ diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index b7c7e0a7de72..858cb677e10a 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -51,8 +51,7 @@ namespace nb = nanobind; // Gets `jax::PyDeviceList` from a JAX Sharding. absl::StatusOr> GetPyDeviceList( - nb::handle sharding_py) { - nb::handle sharding(sharding_py.ptr()); + nb::handle sharding) { if (sharding.type().is(jax::NamedSharding::type())) { TF_ASSIGN_OR_RETURN( auto ns_device_list, diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index e0c54592259b..7ce10e7ed763 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -57,7 +57,7 @@ class Sharding { // Gets `jax::PyDeviceList` from a JAX Sharding. absl::StatusOr> GetPyDeviceList( - nanobind::handle sharding_py); + nanobind::handle sharding); // Checks if the memory kind is valid, and canonicalizes the // memory kind to default memory on backends that support memories. From b12e05c514a033a455876ae9ff629481460bbfc3 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Tue, 22 Apr 2025 02:47:36 -0700 Subject: [PATCH 0723/1769] #sdy Add more debug info when there is a mesh mismatch in JAX export PiperOrigin-RevId: 750113790 --- jax/_src/export/_export.py | 31 +++++++++++++++++++++---------- tests/export_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 91b093dd05bd..dc298f935d9e 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -745,18 +745,29 @@ def export_sharding(s: LoweringSharding, if _device_assignment_for_internal_jax2tf_use_only is not None: _device_assignment_for_internal_jax2tf_use_only[0] = device_assignment - mesh = None + cur_mesh = cur_arg = cur_k_path = None + # lowered.args_info is a tree of the args, but we need the out avals too to + # get the key paths for. 
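+  # (Each key path, e.g. args[0]['bar'], lets the mesh-mismatch error below
+  # name the offending input or output.)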
+  out_avals_tree = jax.tree_util.tree_unflatten(lowered.out_tree, out_avals_flat)
   if config.use_shardy_partitioner.value:
-    for sharding in itertools.chain.from_iterable(
-        [all_in_shardings, lowering.compile_args["out_shardings"]]):
+    for sharding, (k_path, arg) in zip(
+        itertools.chain.from_iterable([
+            all_in_shardings, lowering.compile_args["out_shardings"]]),
+        itertools.chain.from_iterable([
+            jax.tree.flatten_with_path(lowered.args_info)[0],
+            jax.tree.flatten_with_path(out_avals_tree)[0]])):
       if isinstance(sharding, sharding_impls.NamedSharding):
-        if mesh is not None and mesh.shape_tuple != sharding.mesh.shape_tuple:
+        if cur_mesh is None:
+          cur_mesh, cur_arg, cur_k_path = sharding.mesh, arg, k_path
+        elif cur_mesh.shape_tuple != sharding.mesh.shape_tuple:
           raise ValueError(
-              f'Mesh for all inputs should be equal. Got one mesh: {mesh} and'
-              f' another mesh: {sharding.mesh}')
-        mesh = sharding.mesh
-  if mesh and isinstance(mesh, mesh_lib.Mesh):
-    mesh = mesh.abstract_mesh
+              "Mesh for all inputs/outputs should be equal. Got one mesh "
+              f"{cur_mesh} on an array {cur_arg._aval} at "
+              f"{shape_poly.args_kwargs_path_to_str(cur_k_path)} and another mesh: "
+              f"{sharding.mesh} on an array {arg._aval} at "
+              f"{shape_poly.args_kwargs_path_to_str(k_path)}")
+  if cur_mesh and isinstance(cur_mesh, mesh_lib.Mesh):
+    cur_mesh = cur_mesh.abstract_mesh
 
 def _get_exported_vjp(exp_primal: Exported) -> Exported:
   # Turn the primal jaxpr into a function, in preparation for exporting
@@ -774,7 +785,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
       device_assignment=device_assignment,
       apply_jit=True,
       flat_primal_fun=True,
-      mesh=mesh)  # type: ignore[arg-type]
+      mesh=cur_mesh)  # type: ignore[arg-type]
   return export(fun_vjp_jax,  # type: ignore[arg-type]
                 platforms=exp_primal.platforms,
                 disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
diff --git a/tests/export_test.py b/tests/export_test.py
index fcd5572edf5a..41f74fe11a1d 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -2008,6 +2008,32 @@ def f(x, y):
     r = jax.jit(exp.call, out_shardings=NamedSharding(old_mesh_0, P("old_b")))(a, b)
     self.assertAllClose(a + b, r)
 
+  def test_lower_with_different_meshes_axis_names(self):
+    mesh1 = jtu.create_mesh((4, 2), ("a", "b"))
+    mesh2 = jtu.create_mesh((4, 2), ("x", "y"))
+    @jax.jit
+    def f(tree):
+      return tree['foo'] + tree['bar']
+
+    args = {
+        'foo': jax.ShapeDtypeStruct(
+            (32, 32), dtype=np.float32,
+            sharding=NamedSharding(mesh1, P(None, "a"))),
+        'bar': jax.ShapeDtypeStruct(
+            (32, 32), dtype=np.float32,
+            sharding=NamedSharding(mesh2, P("y"))),
+    }
+
+    if config.use_shardy_partitioner.value:
+      with self.assertRaisesRegex(
+          ValueError,
+          r'Mesh for all inputs/outputs should be equal.*'
+          r"args\[0\]\['bar'\].*"):
+        get_exported(f)(args)
+    else:
+      get_exported(f)(args)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())

From dd472e47b24b1f5d3d0718e9d43dbcacb827c7bd Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Tue, 22 Apr 2025 03:36:55 -0700
Subject: [PATCH 0724/1769] [mosaic_gpu] Fixed the export code path in
 `_mosaic_gpu_lowering_rule`

PiperOrigin-RevId: 750125100
---
 jax/experimental/mosaic/gpu/core.py |  2 +-
 tests/pallas/mosaic_gpu_test.py     | 20 +++++++++++++++++++-
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index a5ac6b9bf142..2161d0d6d725 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -142,7
+142,7 @@ def _mosaic_gpu_lowering_rule( operands=args, operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module, + backend_config=kernel_id + module_asm, operand_output_aliases=dict(input_output_aliases), ) else: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6e30b146a606..e13165fd076b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -25,10 +25,11 @@ from absl.testing import absltest from absl.testing import parameterized import jax +from jax import export from jax import lax from jax._src import test_util as jtu -from jax._src.pallas import pallas_call from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline @@ -2807,6 +2808,23 @@ def scope(acc_ref): ) +class ExportTest(PallasTest): + + def test_export_succeeds(self): + out_shape = jax.ShapeDtypeStruct([128], jnp.float32) + + @functools.partial(self.pallas_call, out_shape=out_shape) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + _ = export.export( + kernel, + disabled_checks=[ + export.DisabledSafetyCheck.custom_call("mosaic_gpu"), + ], + )(out_shape) + + class ExamplesTest(PallasTest): # Basic From 972cec160f56f95a139509683d05acd9c6296450 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 22 Apr 2025 05:24:01 -0700 Subject: [PATCH 0725/1769] typing: improve annotations for scatter implementations --- jax/_src/lax/slicing.py | 2 +- jax/_src/numpy/array_methods.py | 65 ++++++++++++++++++++------------- jax/_src/ops/scatter.py | 29 +++++++++------ 3 files changed, 58 insertions(+), 38 deletions(-) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index ed7c2f9f2777..9f4645dca975 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -307,7 +307,7 @@ class GatherScatterMode(enum.Enum): ONE_HOT = enum.auto() @staticmethod - def from_any(s: str | GatherScatterMode | None): + def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode: if isinstance(s, GatherScatterMode): return s if s == "clip": diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index a64a6662e2a7..3783e09bf694 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -26,7 +26,7 @@ import abc from functools import partial, wraps import math -from typing import Any, Sequence +from typing import Any, Callable, Sequence import numpy as np import jax @@ -740,13 +740,15 @@ class _IndexUpdateHelper: """ __slots__ = ("array",) - def __init__(self, array): + array: Array + + def __init__(self, array: Array): self.array = array - def __getitem__(self, index): + def __getitem__(self, index: scatter.Index) -> _IndexUpdateRef: return _IndexUpdateRef(self.array, index) - def __repr__(self): + def __repr__(self) -> str: return f"_IndexUpdateHelper({self.array!r})" @@ -759,15 +761,19 @@ class _IndexUpdateRef: """ __slots__ = ("array", "index") - def __init__(self, array, index): + array: Array + index: scatter.Index + + def __init__(self, array: Array, index: scatter.Index): self.array = array self.index = index def __repr__(self) -> str: return f"_IndexUpdateRef({self.array!r}, {self.index!r})" - def get(self, *, indices_are_sorted=False, unique_indices=False, - 
-          mode=None, fill_value=None, out_sharding=None):
+  def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False,
+          mode: str | jax.lax.GatherScatterMode | None = None,
+          fill_value: ArrayLike | None = None, out_sharding: Sharding | None = None):
     """Equivalent to ``x[idx]``.
 
     Returns the value of ``x`` that would result from the NumPy-style
@@ -786,8 +792,9 @@ def get(self, *, indices_are_sorted=False, unique_indices=False,
                              fill_value=fill_value,
                              out_sharding=out_sharding)
 
-  def set(self, values, *, indices_are_sorted=False, unique_indices=False,
-          mode=None):
+  def set(self, values: ArrayLike, *, indices_are_sorted: bool = False,
+          unique_indices: bool = False,
+          mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
     """Pure equivalent of ``x[idx] = y``.
 
     Returns the value of ``x`` that would result from the NumPy-style
@@ -803,8 +810,9 @@ def set(self, values, *, indices_are_sorted=False, unique_indices=False,
                            unique_indices=unique_indices, mode=mode,
                            out_sharding=out_s)
 
-  def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
-            mode=None):
+  def apply(self, func: Callable[[ArrayLike], Array], *,
+            indices_are_sorted: bool = False, unique_indices: bool = False,
+            mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
     """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
 
     Returns the value of ``x`` that would result from applying the unary
@@ -826,8 +834,9 @@ def _scatter_apply(x, indices, y, dims, **kwargs):
                                indices_are_sorted=indices_are_sorted,
                                unique_indices=unique_indices, mode=mode)
 
-  def add(self, values, *, indices_are_sorted=False, unique_indices=False,
-          mode=None):
+  def add(self, values: ArrayLike, *,
+          indices_are_sorted: bool = False, unique_indices: bool = False,
+          mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
     """Pure equivalent of ``x[idx] += y``.
 
     Returns the value of ``x`` that would result from the NumPy-style
@@ -840,8 +849,9 @@ def add(self, values, *, indices_are_sorted=False, unique_indices=False,
                        indices_are_sorted=indices_are_sorted,
                        unique_indices=unique_indices, mode=mode)
 
-  def subtract(self, values, *, indices_are_sorted=False, unique_indices=False,
-               mode=None):
+  def subtract(self, values: ArrayLike, *,
+               indices_are_sorted: bool = False, unique_indices: bool = False,
+               mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
     """Pure equivalent of ``x[idx] -= y``.
 
     Returns the value of ``x`` that would result from the NumPy-style
@@ -854,8 +864,9 @@ def subtract(self, values, *, indices_are_sorted=False, unique_indices=False,
                        indices_are_sorted=indices_are_sorted,
                        unique_indices=unique_indices, mode=mode)
 
-  def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
-               mode=None):
+  def multiply(self, values: ArrayLike, *,
+               indices_are_sorted: bool = False, unique_indices: bool = False,
+               mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
     """Pure equivalent of ``x[idx] *= y``.
 
     Returns the value of ``x`` that would result from the NumPy-style
@@ -870,8 +881,9 @@ def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
                        mode=mode)
   mul = multiply
 
-  def divide(self, values, *, indices_are_sorted=False, unique_indices=False,
-             mode=None):
+  def divide(self, values: ArrayLike, *,
+             indices_are_sorted: bool = False, unique_indices: bool = False,
+             mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
     """Pure equivalent of ``x[idx] /= y``.
Returns the value of ``x`` that would result from the NumPy-style @@ -886,8 +898,9 @@ def divide(self, values, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode)) - def power(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def power(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | jax.lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] **= y``. Returns the value of ``x`` that would result from the NumPy-style @@ -902,8 +915,9 @@ def power(self, values, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode)) - def min(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def min(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | jax.lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style @@ -917,8 +931,9 @@ def min(self, values, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - def max(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def max(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | jax.lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index baf6a79328b4..fcb3759c5cae 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence -from typing import Union +from typing import Any, Union import warnings from functools import partial @@ -28,6 +28,8 @@ from jax._src import config from jax._src import core from jax._src import dtypes +from jax._src import sharding +from jax._src import tree_util from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.numpy import indexing @@ -44,9 +46,10 @@ Scalar = Union[complex, float, int, np.number] -def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, - unique_indices, mode=None, normalize_indices=True, - out_sharding=None): +def _scatter_update(x: ArrayLike, idx: Index, y: ArrayLike, scatter_op: Callable[..., Array], + indices_are_sorted: bool, unique_indices: bool, + mode: lax.GatherScatterMode | str | None = None, normalize_indices: bool = True, + out_sharding: sharding.Sharding | None = None): """Helper for indexed updates. Computes the value of x that would result from computing:: @@ -92,9 +95,11 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). 
# @partial(jit, static_argnums=(2, 3, 4)) -def _scatter_impl(x, y, dynamic_idx, *, scatter_op, treedef, static_idx, - indices_are_sorted, unique_indices, mode, - normalize_indices): +def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *, + scatter_op: Callable[..., Array], + treedef: tree_util.PyTreeDef, static_idx: tuple[Any, ...], + indices_are_sorted: bool, unique_indices: bool, + mode: lax.GatherScatterMode | str | None, normalize_indices: bool): dtype = lax.dtype(x) weak_type = dtypes.is_weakly_typed(x) @@ -178,7 +183,7 @@ def _segment_update(name: str, unique_indices: bool = False, bucket_size: int | None = None, reducer: Callable | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: check_arraylike(name, data, segment_ids) mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode data = jnp.asarray(data) @@ -217,7 +222,7 @@ def segment_sum(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the sum within segments of an array. Similar to TensorFlow's `segment_sum @@ -272,7 +277,7 @@ def segment_prod(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the product within segments of an array. Similar to TensorFlow's `segment_prod @@ -327,7 +332,7 @@ def segment_max(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the maximum within segments of an array. Similar to TensorFlow's `segment_max @@ -381,7 +386,7 @@ def segment_min(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the minimum within segments of an array. Similar to TensorFlow's `segment_min From 45726ec43c921aba10b5b5d0a75943bac5828fe7 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 16 Apr 2025 10:45:21 -0400 Subject: [PATCH 0726/1769] Inline literals while tracing instead of in a separate pass. 
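
A minimal sketch of the new `core.is_literalable` predicate this change adds
(assuming, as JAX does elsewhere, that numpy scalar and ndarray types are
registered in `literalable_types`):

    import numpy as np
    from jax._src import core

    core.is_literalable(3.0)            # True: Python scalar, becomes a Literal
    core.is_literalable(np.float32(1))  # True: rank-0 value of a literalable type
    core.is_literalable(np.ones(3))     # False: non-scalar, stays a constvar

With this check available at trace time, `_new_const` can map a tracer
directly to a `Literal` instead of minting a fresh `Var`, so the old
`_inline_literals` cleanup pass reduces to `_drop_unused_vars`.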
--- jax/_src/core.py | 24 +++--- jax/_src/interpreters/partial_eval.py | 107 +++++++++++--------------- 2 files changed, 59 insertions(+), 72 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 0183a9942524..39ed1dd48447 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -442,34 +442,36 @@ def __init__(self, aval: AbstractValue): def __repr__(self): return '_' class Literal: - __slots__ = ["val", "aval", "hash"] + __slots__ = ["val", "aval"] val: Any aval: AbstractValue - hash: int | None def __init__(self, val, aval): self.val = val self.aval = aval + + @property + def hash(self): try: - self.hash = hash(val) + return hash(self.val) except TypeError: - if type(val) in literalable_types: + if type(self.val) in literalable_types: try: - self.hash = hash((val.item(), val.dtype)) + return hash((self.val.item(), self.val.dtype)) except (TypeError, AttributeError, ValueError): - self.hash = None + return None __hash__ = None # type: ignore def __repr__(self): - if hasattr(self, 'hash'): - return f'{self.val}' - else: - return f'Literal(val={self.val})' + return f'{self.val}' literalable_types: set[type] = set() +def is_literalable(x: Any) -> bool: + return type(x) in dtypes.python_scalar_dtypes or (type(x) in literalable_types and not np.shape(x)) + Atom = Union[Var, Literal] class Primitive: @@ -2061,7 +2063,7 @@ class DShapedArray(UnshapedArray): array_abstraction_level: int = 3 def __init__(self, shape, dtype, weak_type=False): - self.shape = shape + self.shape = tuple(d.val if isinstance(d, Literal) else d for d in shape) self.dtype = dtype self.weak_type = weak_type diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index f8ce92e7f97f..a37c70f1ff37 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -198,7 +198,7 @@ def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: if const is None: return tracer else: - if type(const) in core.literalable_types and np.shape(const) == (): + if core.is_literalable(const): return self.new_instantiated_literal(const) else: return self.new_instantiated_const(const) @@ -1647,7 +1647,8 @@ def _origin_msg(self): def get_referent(self): frame = self._trace.frame - val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) + var = frame.tracer_to_var.get(id(self)) + val = frame.constvar_to_val.get(var) if isinstance(var, Var) else None return self if val is None else get_referent(val) core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval @@ -1687,7 +1688,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: class JaxprStackFrame: gensym: Callable[[AbstractValue], Var] - tracer_to_var: dict[TracerId, Var] + tracer_to_var: dict[TracerId, Atom] constid_to_tracer: dict[ConstId, Tracer] constvar_to_val: dict[Var, Any] tracers: list[DynamicJaxprTracer] # hold onto strong refs for all tracers @@ -1725,7 +1726,8 @@ def to_jaxpr( debug_info: core.DebugInfo, ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] + assert len(vars) == len(set(vars)) invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) @@ -1738,14 +1740,15 @@ def to_jaxpr( jaxpr = Jaxpr(constvars, invars, outvars, 
self.eqns, jaxpr_effects, debug_info) jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) + jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], debug_info: core.DebugInfo): # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] + assert len(vars) == len(set(vars)) constvars, constvals = unzip2(self.constvar_to_val.items()) expl_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, @@ -1754,7 +1757,7 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], jaxpr_effects, debug_info) # We can't run check_jaxpr until after we normalize. jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) + jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) jaxpr, out_type = _add_implicit_outputs(jaxpr) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, out_type, constvals @@ -1779,14 +1782,15 @@ def find_progenitors(self, tracer): active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns - if {v for v in eqn.invars if type(v) is Var} & constvars] + const_eqns = [eqn for eqn in self.eqns if any( + v in constvars if type(v) is Var else type(v) is Literal + for v in eqn.invars)] return invar_positions, const_eqns def _const_folding_and_forwarding( jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]: consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) - var_subs: dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined + var_subs: dict[Var, Atom] = {} new_eqns = [] def apply_var_sub(a: Atom) -> Atom: return var_subs.get(a, a) if isinstance(a, Var) else a @@ -1797,14 +1801,20 @@ def apply_var_sub(a: Atom) -> Atom: has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) for eff in eqn.effects) if (eqn.primitive in const_fold_rules and - any(v in consts for v in eqn.invars if isinstance(v, Var)) and + any(v in consts if isinstance(v, Var) + else isinstance(v, Literal) for v in eqn.invars) and not has_input_effect): - consts_in = [consts.get(v) if isinstance(v, Var) else None + consts_in = [consts.get(v) if isinstance(v, Var) else + v.val if isinstance(v, Literal) else None for v in eqn.invars] consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) assert (new_eqn is None) == all(c is not None for c in consts_out) for v, c in zip(eqn.outvars, consts_out): - if c is not None: consts[v] = c + if c is not None: + if core.is_literalable(c): + var_subs[v] = Literal(c, v.aval) + else: + consts[v] = c if new_eqn is None: continue else: eqn = new_eqn # if the application trivially maps some inputs to outputs, simplify @@ -1836,54 +1846,26 @@ def apply_var_sub(a: Atom) -> Atom: forwarding_rules: dict[Primitive, ForwardingRule] = {} -def _inline_literals( +def _drop_unused_vars( jaxpr: Jaxpr, constvals: Sequence[Any] ) -> tuple[Jaxpr, list[Any]]: - # This function also prunes 
unused constants and inserts `dropvar` symbols. - input_effects = {eff for eff in jaxpr.effects - if isinstance(eff, effects.JaxprInputEffect)} - # Don't inline any literal with an input effect - has_input_effect = [any(eff.input_index == i for eff in input_effects) - for i in range(len(constvals))] - lits = {v: Literal(c, v.aval) for v, c, e in zip(jaxpr.constvars, constvals, - has_input_effect) - if type(c) in core.literalable_types and not np.shape(c) and not e} - def lit(a: Atom) -> Literal | None: - return (a if isinstance(a, Literal) else lits.get(a) if isinstance(a, Var) - else None) - newname: Callable[[AbstractValue], Var] = core.gensym() - newvars: dict[Var, Var] = {} - newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) - var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) - lit_or_var = ( - lambda a: a if isinstance(a, Literal) else (lit(a) or var(a)) - ) - dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) - - def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: + def vars(atom: Atom) -> list[Var]: + if isinstance(atom, Literal): + return [] + aval = atom.aval if isinstance(aval, DShapedArray): - return [d for d in aval.shape if isinstance(d, Var)] - return [] - - used = {v for eqn in jaxpr.eqns for atom in eqn.invars - for v in it.chain([atom], vars_in_shape(atom.aval)) - if isinstance(atom, Var)} - used |= {v for outvar in jaxpr.outvars - for v in it.chain([outvar], vars_in_shape(outvar.aval))} - new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] - new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) - if v in used and not lit(v)] - new_invars = [var(v) for v in jaxpr.invars] - new_eqns = [] - for eqn in jaxpr.eqns: - invars = [lit_or_var(x) for x in eqn.invars] - outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] - new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) - new_outvars = [lit_or_var(v) for v in jaxpr.outvars] - effs = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns) - new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, effs, - jaxpr.debug_info) - return new_jaxpr, new_constvals + return [atom] + [d for d in aval.shape if isinstance(d, Var)] + return [atom] + used: set[Var] = {v for atom in jaxpr.outvars for v in vars(atom)} + for eqn in jaxpr.eqns[::-1]: + eqn.outvars = [v if v in used else DropVar(v.aval) for v in eqn.outvars] + used.update(v for atom in eqn.invars for v in vars(atom)) + cvars, constvals = unzip2( + (v, val) for v, val in zip(jaxpr.constvars, constvals) if v in used) + jaxpr._constvars = list(cvars) + jaxpr._effects = make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, + jaxpr.outvars, jaxpr.eqns) + return jaxpr, list(constvals) class DynamicJaxprTrace(core.Trace): @@ -1934,9 +1916,12 @@ def new_const(self, c): def _new_const(self, aval, c) -> DynamicJaxprTracer: tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) self.frame.tracers.append(tracer) - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) - self.frame.constid_to_tracer[id(c)] = tracer - self.frame.constvar_to_val[var] = c + if core.is_literalable(c): + self.frame.tracer_to_var[id(tracer)] = Literal(c, aval) + else: + self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) + self.frame.constid_to_tracer[id(c)] = tracer + self.frame.constvar_to_val[var] = c return tracer def _lift_tracers_in_aval(self, aval): From 59a546500ed1103c94c2bd3ccbe4df5a30dcd045 Mon Sep 17 
00:00:00 2001 From: jax authors Date: Tue, 22 Apr 2025 06:30:04 -0700 Subject: [PATCH 0727/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3e18a59cd822a8426db29c8c36e912e7b2dbaae4. PiperOrigin-RevId: 750168674 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a25dee5e3f83..6135c8c1da65 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9219fd7ef180a01f814d3fce9f8aecfd80b9fd6c" -XLA_SHA256 = "9c8cfa363951ee90e36d86d861c9d07f703562cc4b044de77877edde113261bc" +XLA_COMMIT = "3e18a59cd822a8426db29c8c36e912e7b2dbaae4" +XLA_SHA256 = "b00d2e514d5a7bb7276ab7a82d2b1c380c27709ac1b91f92c8b3ccfb87c285c0" def repo(): tf_http_archive( From c477cb27e112a1ed5823743d16b898f9140680d0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 21 Apr 2025 20:53:05 -0400 Subject: [PATCH 0728/1769] Add support for generating freethreaded pip lockfiles. Regenerate lockfiles for all Python versions. --- build/BUILD.bazel | 109 ++++++---- build/build.py | 7 +- build/freethreading-requirements.txt | 2 + build/nonfreethreading-requirements.txt | 10 + build/requirements.in | 15 -- build/requirements_lock_3_10.txt | 66 ++++-- build/requirements_lock_3_11.txt | 64 ++++-- build/requirements_lock_3_12.txt | 64 ++++-- build/requirements_lock_3_13.txt | 168 +++++++++------ build/requirements_lock_3_13_ft.txt | 266 +++++++----------------- tests/lax_scipy_test.py | 2 +- 11 files changed, 391 insertions(+), 382 deletions(-) create mode 100644 build/freethreading-requirements.txt create mode 100644 build/nonfreethreading-requirements.txt diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 4d8dd9c1b7d8..539c156d3ac4 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -13,53 +13,80 @@ # limitations under the License. 
-licenses(["notice"])
-
 load("@python//:defs.bzl", "compile_pip_requirements")
 load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
 load("//jaxlib:jax.bzl", "all_py_deps")
 
-compile_pip_requirements(
-    name = "requirements",
-    extra_args = [
-        "--allow-unsafe",
-        "--build-isolation",
-        "--rebuild",
-    ],
-    requirements_in = "requirements.in",
-    requirements_txt = REQUIREMENTS,
-    generate_hashes = True,
-    data = ["test-requirements.txt", "gpu-test-requirements.txt"]
-)
+licenses(["notice"])
 
-compile_pip_requirements(
-    name = "requirements_nightly",
-    extra_args = [
-        "--allow-unsafe",
-        "--build-isolation",
-        "--extra-index-url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple",
-        "--pre",
-        "--upgrade"
-    ],
-    requirements_in = "requirements.in",
-    requirements_txt = REQUIREMENTS,
-    generate_hashes = False,
-    data = ["test-requirements.txt", "gpu-test-requirements.txt"]
-)
+COMMON_REQUIREMENTS = [
+    "requirements.in",
+    "test-requirements.txt",
+    "gpu-test-requirements.txt",
+]
 
-compile_pip_requirements(
-    name = "requirements_dev",
-    extra_args = [
-        "--allow-unsafe",
-        "--build-isolation",
-        "--upgrade",
-        "--rebuild",
-    ],
-    requirements_in = "requirements.in",
-    requirements_txt = REQUIREMENTS,
-    generate_hashes = False,
-    data = ["test-requirements.txt", "gpu-test-requirements.txt"]
-)
+# It isn't possible to constrain based on free-threaded vs non-free-threaded
+# in a requirements file. So we do it by having two separate sets of requirement
+# files and two sets of build rules.
+FREETHREADING_REQUIREMENTS = COMMON_REQUIREMENTS + [
+    "freethreading-requirements.txt",
+]
+NON_FREETHREADING_REQUIREMENTS = COMMON_REQUIREMENTS + [
+    "nonfreethreading-requirements.txt",
+]
+
+COMBOS = [
+    ("", NON_FREETHREADING_REQUIREMENTS),
+    ("_ft", FREETHREADING_REQUIREMENTS),
+]
+
+[
+    compile_pip_requirements(
+        name = "requirements" + suffix,
+        extra_args = [
+            "--allow-unsafe",
+            "--build-isolation",
+            "--rebuild",
+        ],
+        srcs = requirements,
+        requirements_txt = REQUIREMENTS,
+        generate_hashes = True,
+    )
+    for suffix, requirements in COMBOS
+]
+
+[
+    compile_pip_requirements(
+        name = "requirements_nightly" + suffix,
+        extra_args = [
+            "--allow-unsafe",
+            "--build-isolation",
+            "--extra-index-url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple",
+            "--pre",
+            "--upgrade",
+        ],
+        srcs = requirements,
+        requirements_txt = REQUIREMENTS,
+        generate_hashes = False,
+    )
+    for suffix, requirements in COMBOS
+]
+
+[
+    compile_pip_requirements(
+        name = "requirements_dev" + suffix,
+        extra_args = [
+            "--allow-unsafe",
+            "--build-isolation",
+            "--upgrade",
+            "--rebuild",
+        ],
+        srcs = requirements,
+        requirements_txt = REQUIREMENTS,
+        generate_hashes = False,
+    )
+    for suffix, requirements in COMBOS
+]
 
 py_library(
     name = "all_py_deps",
diff --git a/build/build.py b/build/build.py
index 87aa36aeba8b..a5d39e559a9d 100755
--- a/build/build.py
+++ b/build/build.py
@@ -424,6 +424,7 @@ async def main():
   else:
     bazel_command_base.append("build")
 
+  freethreaded = False
   if args.python_version:
     # Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
     # if bazel_options override it
@@ -439,6 +440,7 @@
     )
     # Let's interpret X.YY-ft version as free-threading python and set rules_python config flag:
     if args.python_version.endswith("-ft"):
+      freethreaded = True
       bazel_command_base.append(
           "--@rules_python//python/config_settings:py_freethreaded='yes'"
       )
@@
-456,14 +458,15 @@ async def main(): for option in args.bazel_options: requirements_command.append(option) + ft_suffix = "_ft" if freethreaded else "" if args.nightly_update: logging.info( "--nightly_update is set. Bazel will run" " //build:requirements_nightly.update" ) - requirements_command.append("//build:requirements_nightly.update") + requirements_command.append(f"//build:requirements{ft_suffix}_nightly.update") else: - requirements_command.append("//build:requirements.update") + requirements_command.append(f"//build:requirements{ft_suffix}.update") result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) if result.return_code != 0: diff --git a/build/freethreading-requirements.txt b/build/freethreading-requirements.txt new file mode 100644 index 000000000000..2bbaf1fe8443 --- /dev/null +++ b/build/freethreading-requirements.txt @@ -0,0 +1,2 @@ +# Under free-threading, we need an up-to-date numpy at least for the moment. +numpy~=2.2.5 diff --git a/build/nonfreethreading-requirements.txt b/build/nonfreethreading-requirements.txt new file mode 100644 index 000000000000..19b9cb51686d --- /dev/null +++ b/build/nonfreethreading-requirements.txt @@ -0,0 +1,10 @@ +numpy~=2.0.0; python_version<="3.12" +numpy~=2.1.0; python_version>="3.13" + +# These packages have not released free-threaded wheels. +zstandard +tensorstore + +# portpicker is architecture independent, but it depends on psutil which has not +# released a 3.13-ft wheel. +portpicker diff --git a/build/requirements.in b/build/requirements.in index edd7982e74ca..108c5f7492b0 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -1,15 +1,3 @@ -# -# test deps -# --r test-requirements.txt --r gpu-test-requirements.txt - -# -# build deps -# -numpy~=2.0.0; python_version<="3.12" -numpy~=2.1.0; python_version>="3.13" - # # runtime deps # @@ -17,8 +5,5 @@ scipy>=1.13.1; python_version<="3.12" scipy>=1.15.2; python_version>="3.13" ml_dtypes>=0.4.0 -opt_einsum -zstandard -tensorstore etils[epath] setuptools diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 8bf5293bd948..51d09c6638bb 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -270,7 +270,7 @@ markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.8.4 ; python_version <= "3.10" \ +matplotlib==3.8.4 ; python_version == "3.10" \ --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ @@ -329,7 +329,9 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 @@ -381,76 +383,75 @@ numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ 
--hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy # matplotlib # ml-dtypes # opt-einsum # scipy + # tensorstore nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # via -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # via -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cusolver-cu12 nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ 
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # via -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -542,10 +543,12 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 \ +portpicker==1.6.0 ; python_version < "3.13" \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt + # via + # -r build/nonfreethreading-requirements.txt + # -r build/test-requirements.txt psutil==5.9.8 \ --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ @@ -596,7 +599,7 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.1 \ +scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ @@ -631,6 +634,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.73 \ + --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ + --hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ + --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ + --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ + --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ + --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ + --hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ + --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ + --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ + --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ + --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ + 
--hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ + --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ + --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ + --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ + --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ + --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ + --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ + --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ + --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ + --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b + # via -r build/nonfreethreading-requirements.txt tomli==2.0.1 \ --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f @@ -696,7 +722,7 @@ zstandard==0.22.0 \ --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: setuptools==76.0.0 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 487346ab6d12..00e9af7ea2dc 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -324,7 +324,9 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 @@ -376,76 +378,75 @@ numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy # matplotlib # ml-dtypes # opt-einsum # scipy + # tensorstore nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ 
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cusolver-cu12 nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -537,10 +538,12 @@ pluggy==1.5.0 \ 
--hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 \ +portpicker==1.6.0 ; python_version < "3.13" \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt + # via + # -r build/nonfreethreading-requirements.txt + # -r build/test-requirements.txt psutil==5.9.8 \ --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ @@ -591,7 +594,7 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.1 \ +scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ @@ -626,6 +629,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.73 \ + --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ + --hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ + --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ + --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ + --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ + --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ + --hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ + --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ + --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ + --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ + --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ + --hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ + --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ + --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ + --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ + --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ + --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ + --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ + --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ + --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ + --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b + # via -r build/nonfreethreading-requirements.txt typing-extensions==4.12.0rc1 \ --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe @@ -685,7 +711,7 @@ zstandard==0.22.0 \ 
--hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: setuptools==76.0.0 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index e2f76cab8abc..3bf4f29bfac8 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -324,7 +324,9 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 @@ -376,76 +378,75 @@ numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy # matplotlib # ml-dtypes # opt-einsum # scipy + # tensorstore nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ 
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cusolver-cu12 nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -537,10 +538,12 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 \ +portpicker==1.6.0 ; python_version < "3.13" \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt + # via + # -r build/nonfreethreading-requirements.txt + # -r build/test-requirements.txt psutil==5.9.8 \ --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ @@ -591,7 +594,7 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.1 \ +scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ 
--hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ @@ -626,6 +629,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.73 \ + --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ + --hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ + --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ + --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ + --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ + --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ + --hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ + --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ + --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ + --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ + --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ + --hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ + --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ + --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ + --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ + --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ + --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ + --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ + --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ + --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ + --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b + # via -r build/nonfreethreading-requirements.txt typing-extensions==4.12.0rc1 \ --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe @@ -685,7 +711,7 @@ zstandard==0.22.0 \ --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: setuptools==76.0.0 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 403d0ad8a061..d0508fc3e8bc 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -372,7 +372,9 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # tensorstore mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c @@ -432,75 +434,74 @@ numpy==2.1.2 ; 
python_version >= "3.13" \ --hash=sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1 \ --hash=sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy # matplotlib # ml-dtypes # scipy + # tensorstore nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cusolver-cu12 nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ 
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/test-requirements.txt packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 @@ -600,25 +601,18 @@ pluggy==1.5.0 \ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==6.0.0 \ - --hash=sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35 \ - --hash=sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0 \ - --hash=sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c \ - --hash=sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1 \ - --hash=sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3 \ - --hash=sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c \ - --hash=sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd \ - --hash=sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3 \ - --hash=sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0 \ - --hash=sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2 \ - --hash=sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6 \ - --hash=sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d \ - --hash=sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c \ - --hash=sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0 \ - --hash=sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132 \ - --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ - --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 + # via -r build/nonfreethreading-requirements.txt +psutil==7.0.0 \ + --hash=sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25 \ + --hash=sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e \ + --hash=sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91 \ + --hash=sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da \ + --hash=sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34 \ + --hash=sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553 \ + --hash=sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456 \ + 
--hash=sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17 \ + --hash=sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993 \ + --hash=sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99 # via portpicker pyelftools==0.31 \ --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ @@ -652,40 +646,53 @@ rich==13.9.2 \ --hash=sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c \ --hash=sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1 # via -r build/test-requirements.txt -scipy==1.14.1 \ - --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ - --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ - --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ - --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ - --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ - --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ - --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ - --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ - --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ - --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ - --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ - --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ - --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ - --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ - --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ - --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ - --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ - --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ - --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ - --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ - --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ - --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ - --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ - --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ - --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ - --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ - --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ - --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ - --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ - --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ - --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ - --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ - --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 +scipy==1.15.2 ; python_version >= "3.13" \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + 
--hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + --hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + 
--hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db # via -r build/requirements.in six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ @@ -695,6 +702,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.73 \ + --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ + --hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ + --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ + --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ + --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ + --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ + --hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ + --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ + --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ + --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ + --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ + --hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ + --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ + --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ + --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ + --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ + --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ + --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ + --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ + --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ + --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b + # via -r build/nonfreethreading-requirements.txt typing-extensions==4.12.2 \ --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 @@ -805,7 +835,7 @@ zstandard==0.23.0 \ --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 - # via -r build/requirements.in + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: setuptools==76.0.0 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 507e896ab8db..7fce0eef6a8a 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.13 # by the following command: # -# pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in +# bazel run //build:requirements_ft.update # 
absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ @@ -328,64 +328,64 @@ mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.2.1 ; python_version >= "3.13" \ - --hash=sha256:059e6a747ae84fce488c3ee397cee7e5f905fd1bda5fb18c66bc41807ff119b2 \ - --hash=sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5 \ - --hash=sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60 \ - --hash=sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71 \ - --hash=sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631 \ - --hash=sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8 \ - --hash=sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2 \ - --hash=sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16 \ - --hash=sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa \ - --hash=sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591 \ - --hash=sha256:3d03883435a19794e41f147612a77a8f56d4e52822337844fff3d4040a142964 \ - --hash=sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821 \ - --hash=sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484 \ - --hash=sha256:4250888bcb96617e00bfa28ac24850a83c9f3a16db471eca2ee1f1714df0f957 \ - --hash=sha256:4511d9e6071452b944207c8ce46ad2f897307910b402ea5fa975da32e0102800 \ - --hash=sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918 \ - --hash=sha256:48fd472630715e1c1c89bf1feab55c29098cb403cc184b4859f9c86d4fcb6a95 \ - --hash=sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0 \ - --hash=sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e \ - --hash=sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d \ - --hash=sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73 \ - --hash=sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59 \ - --hash=sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51 \ - --hash=sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355 \ - --hash=sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348 \ - --hash=sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e \ - --hash=sha256:5edb4e4caf751c1518e6a26a83501fda79bff41cc59dac48d70e6d65d4ec4440 \ - --hash=sha256:61048b4a49b1c93fe13426e04e04fdf5a03f456616f6e98c7576144677598675 \ - --hash=sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84 \ - --hash=sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046 \ - --hash=sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab \ - --hash=sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712 \ - --hash=sha256:7671dc19c7019103ca44e8d94917eba8534c76133523ca8406822efdd19c9308 \ - --hash=sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315 \ - --hash=sha256:7ba9cc93a91d86365a5d270dee221fdc04fb68d7478e6bf6af650de78a8339e3 \ - --hash=sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008 \ - --hash=sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5 \ - --hash=sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2 \ - 
--hash=sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e \ - --hash=sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7 \ - --hash=sha256:a7746f235c47abc72b102d3bce9977714c2444bdfaea7888d241b4c4bb6a78bf \ - --hash=sha256:aa3017c40d513ccac9621a2364f939d39e550c542eb2a894b4c8da92b38896ab \ - --hash=sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd \ - --hash=sha256:b541032178a718c165a49638d28272b771053f628382d5e9d1c93df23ff58dbf \ - --hash=sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8 \ - --hash=sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb \ - --hash=sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268 \ - --hash=sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d \ - --hash=sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780 \ - --hash=sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716 \ - --hash=sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e \ - --hash=sha256:f62aa6ee4eb43b024b0e5a01cf65a0bb078ef8c395e8713c6e8a12a697144528 \ - --hash=sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af \ - --hash=sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7 \ - --hash=sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51 +numpy==2.2.5 \ + --hash=sha256:0255732338c4fdd00996c0421884ea8a3651eea555c3a56b84892b66f696eb70 \ + --hash=sha256:02f226baeefa68f7d579e213d0f3493496397d8f1cff5e2b222af274c86a552a \ + --hash=sha256:059b51b658f4414fff78c6d7b1b4e18283ab5fa56d270ff212d5ba0c561846f4 \ + --hash=sha256:0bcb1d057b7571334139129b7f941588f69ce7c4ed15a9d6162b2ea54ded700c \ + --hash=sha256:0cd48122a6b7eab8f06404805b1bd5856200e3ed6f8a1b9a194f9d9054631beb \ + --hash=sha256:19f4718c9012e3baea91a7dba661dcab2451cda2550678dc30d53acb91a7290f \ + --hash=sha256:1a161c2c79ab30fe4501d5a2bbfe8b162490757cf90b7f05be8b80bc02f7bb8e \ + --hash=sha256:1f4a922da1729f4c40932b2af4fe84909c7a6e167e6e99f71838ce3a29f3fe26 \ + --hash=sha256:261a1ef047751bb02f29dfe337230b5882b54521ca121fc7f62668133cb119c9 \ + --hash=sha256:262d23f383170f99cd9191a7c85b9a50970fe9069b2f8ab5d786eca8a675d60b \ + --hash=sha256:2ba321813a00e508d5421104464510cc962a6f791aa2fca1c97b1e65027da80d \ + --hash=sha256:2c1a1c6ccce4022383583a6ded7bbcda22fc635eb4eb1e0a053336425ed36dfa \ + --hash=sha256:352d330048c055ea6db701130abc48a21bec690a8d38f8284e00fab256dc1376 \ + --hash=sha256:369e0d4647c17c9363244f3468f2227d557a74b6781cb62ce57cf3ef5cc7c610 \ + --hash=sha256:36ab5b23915887543441efd0417e6a3baa08634308894316f446027611b53bf1 \ + --hash=sha256:37e32e985f03c06206582a7323ef926b4e78bdaa6915095ef08070471865b906 \ + --hash=sha256:3a801fef99668f309b88640e28d261991bfad9617c27beda4a3aec4f217ea073 \ + --hash=sha256:3d14b17b9be5f9c9301f43d2e2a4886a33b53f4e6fdf9ca2f4cc60aeeee76372 \ + --hash=sha256:422cc684f17bc963da5f59a31530b3936f57c95a29743056ef7a7903a5dbdf88 \ + --hash=sha256:4520caa3807c1ceb005d125a75e715567806fed67e315cea619d5ec6e75a4191 \ + --hash=sha256:47834cde750d3c9f4e52c6ca28a7361859fcaf52695c7dc3cc1a720b8922683e \ + --hash=sha256:47f9ed103af0bc63182609044b0490747e03bd20a67e391192dde119bf43d52f \ + --hash=sha256:498815b96f67dc347e03b719ef49c772589fb74b8ee9ea2c37feae915ad6ebda \ + --hash=sha256:54088a5a147ab71a8e7fdfd8c3601972751ded0739c6b696ad9cb0343e21ab73 \ + --hash=sha256:55f09e00d4dccd76b179c0f18a44f041e5332fd0e022886ba1c0bbf3ea4a18d0 \ + 
--hash=sha256:5a0ac90e46fdb5649ab6369d1ab6104bfe5854ab19b645bf5cda0127a13034ae \ + --hash=sha256:6411f744f7f20081b1b4e7112e0f4c9c5b08f94b9f086e6f0adf3645f85d3a4d \ + --hash=sha256:6413d48a9be53e183eb06495d8e3b006ef8f87c324af68241bbe7a39e8ff54c3 \ + --hash=sha256:7451f92eddf8503c9b8aa4fe6aa7e87fd51a29c2cfc5f7dbd72efde6c65acf57 \ + --hash=sha256:8b4c0773b6ada798f51f0f8e30c054d32304ccc6e9c5d93d46cb26f3d385ab19 \ + --hash=sha256:8dfa94b6a4374e7851bbb6f35e6ded2120b752b063e6acdd3157e4d2bb922eba \ + --hash=sha256:97c8425d4e26437e65e1d189d22dff4a079b747ff9c2788057bfb8114ce1e133 \ + --hash=sha256:9d75f338f5f79ee23548b03d801d28a505198297534f62416391857ea0479571 \ + --hash=sha256:9de6832228f617c9ef45d948ec1cd8949c482238d68b2477e6f642c33a7b0a54 \ + --hash=sha256:a4cbdef3ddf777423060c6f81b5694bad2dc9675f110c4b2a60dc0181543fac7 \ + --hash=sha256:a9c0d994680cd991b1cb772e8b297340085466a6fe964bc9d4e80f5e2f43c291 \ + --hash=sha256:aa70fdbdc3b169d69e8c59e65c07a1c9351ceb438e627f0fdcd471015cd956be \ + --hash=sha256:abe38cd8381245a7f49967a6010e77dbf3680bd3627c0fe4362dd693b404c7f8 \ + --hash=sha256:b13f04968b46ad705f7c8a80122a42ae8f620536ea38cf4bdd374302926424dd \ + --hash=sha256:b4ea7e1cff6784e58fe281ce7e7f05036b3e1c89c6f922a6bfbc0a7e8768adbe \ + --hash=sha256:b6f91524d31b34f4a5fee24f5bc16dcd1491b668798b6d85585d836c1e633a6a \ + --hash=sha256:c26843fd58f65da9491165072da2cccc372530681de481ef670dcc8e27cfb066 \ + --hash=sha256:c42365005c7a6c42436a54d28c43fe0e01ca11eb2ac3cefe796c25a5f98e5e9b \ + --hash=sha256:c8b82a55ef86a2d8e81b63da85e55f5537d2157165be1cb2ce7cfa57b6aef38b \ + --hash=sha256:ced69262a8278547e63409b2653b372bf4baff0870c57efa76c5703fd6543282 \ + --hash=sha256:d2e3bdadaba0e040d1e7ab39db73e0afe2c74ae277f5614dad53eadbecbbb169 \ + --hash=sha256:d403c84991b5ad291d3809bace5e85f4bbf44a04bdc9a88ed2bb1807b3360bb8 \ + --hash=sha256:d7543263084a85fbc09c704b515395398d31d6395518446237eac219eab9e55e \ + --hash=sha256:d8882a829fd779f0f43998e931c466802a77ca1ee0fe25a3abe50278616b1471 \ + --hash=sha256:e4f0b035d9d0ed519c813ee23e0a733db81ec37d2e9503afbb6e54ccfdee0fa7 \ + --hash=sha256:e8b025c351b9f0e8b5436cf28a07fa4ac0204d67b38f01433ac7f9b870fa38c6 \ + --hash=sha256:eb7fd5b184e5d277afa9ec0ad5e4eb562ecff541e7f60e69ee69c8d59e9aeaba \ + --hash=sha256:ec31367fd6a255dc8de4772bd1658c3e926d8e860a0b6e922b615e532d320ddc \ + --hash=sha256:ee461a4eaab4f165b68780a6a1af95fb23a29932be7569b9fab666c407969051 \ + --hash=sha256:f5045039100ed58fa817a6227a356240ea1b9a1bc141018864c306c1a16d4175 # via - # -r build/requirements.in + # -r build/freethreading-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -395,65 +395,63 @@ nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ 
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cusolver-cu12 nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt + # via -r build/gpu-test-requirements.txt nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # -r build/gpu-test-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via - # -r build/test-requirements.txt - # -r build/requirements.in + # via -r build/test-requirements.txt packaging==24.2 \ --hash=sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759 \ --hash=sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f @@ -541,29 +539,6 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ 
--hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 \ - --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ - --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==6.1.1 \ - --hash=sha256:018aeae2af92d943fdf1da6b58665124897cfc94faa2ca92098838f83e1b1bca \ - --hash=sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377 \ - --hash=sha256:1924e659d6c19c647e763e78670a05dbb7feaf44a0e9c94bf9e14dfc6ba50468 \ - --hash=sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3 \ - --hash=sha256:384636b1a64b47814437d1173be1427a7c83681b17a450bfc309a1953e329603 \ - --hash=sha256:6d4281f5bbca041e2292be3380ec56a9413b790579b8e593b1784499d0005dac \ - --hash=sha256:8be07491f6ebe1a693f17d4f11e69d0dc1811fa082736500f649f79df7735303 \ - --hash=sha256:8df0178ba8a9e5bc84fed9cfa61d54601b371fbec5c8eebad27575f1e105c0d4 \ - --hash=sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160 \ - --hash=sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8 \ - --hash=sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003 \ - --hash=sha256:c777eb75bb33c47377c9af68f30e9f11bc78e0f07fbf907be4a5d70b2fe5f030 \ - --hash=sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777 \ - --hash=sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5 \ - --hash=sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53 \ - --hash=sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649 \ - --hash=sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8 - # via portpicker pyelftools==0.31 \ --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 @@ -664,112 +639,11 @@ zipp==3.21.0 \ --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 # via etils -# python 3.13t can't compile 0.23.0 -# due to https://github.com/indygreg/python-zstandard/issues/231 -# zstandard==0.23.0 \ -# --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ -# --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ -# --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ -# --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ -# --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ -# --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ -# --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ -# --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ -# --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ -# --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ -# --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ -# --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ -# --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ -# --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ -# --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ -# 
--hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ -# --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ -# --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ -# --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ -# --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ -# --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ -# --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ -# --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ -# --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ -# --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ -# --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ -# --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ -# --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ -# --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ -# --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ -# --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ -# --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ -# --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ -# --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ -# --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ -# --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ -# --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ -# --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ -# --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ -# --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ -# --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ -# --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ -# --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ -# --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ -# --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ -# --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ -# --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ -# --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ -# --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ -# --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ -# --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ -# --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ -# --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ -# --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ -# --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ -# --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ -# --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ -# 
--hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ -# --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ -# --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ -# --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ -# --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ -# --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ -# --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ -# --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ -# --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ -# --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ -# --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ -# --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ -# --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ -# --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ -# --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ -# --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ -# --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ -# --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ -# --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ -# --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ -# --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ -# --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ -# --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ -# --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ -# --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ -# --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ -# --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ -# --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ -# --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ -# --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ -# --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ -# --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ -# --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ -# --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ -# --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ -# --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ -# --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ -# --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ -# --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ -# --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 -# # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: setuptools==70.3.0 \ 
--hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc # via - # -r build/test-requirements.txt # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index bc80ed4e1cc2..20ed169d9405 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -113,7 +113,7 @@ def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b): if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0): self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.") - if not jtu.test_device_matches(["cpu"]): + if not jtu.test_device_matches(["cpu", "gpu"]): rng = jtu.rand_some_inf_and_nan(self.rng()) else: rng = jtu.rand_default(self.rng()) From 6f3a8b4a755f68bebc9aca2c6288ae38ef9c301b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 17 Apr 2025 10:55:05 -0400 Subject: [PATCH 0729/1769] Fix ensure_compile_time_eval context after stackless. --- jax/_src/interpreters/partial_eval.py | 2 +- tests/api_test.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index f8ce92e7f97f..0e888c1591aa 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1964,7 +1964,7 @@ def is_const(self, tracer): return self.frame.tracer_to_var.get(id(tracer)) is None def process_primitive(self, primitive, tracers, params): - if (config.eager_constant_folding.value and all(map(self.is_const, tracers))): + if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: diff --git a/tests/api_test.py b/tests/api_test.py index 0b77580f2839..e351ebef37d4 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4955,6 +4955,11 @@ def my_sin_lin(nzs, x): with config.use_direct_linearize(True): jax.grad(my_sin_p.bind)(1.0) # doesn't crash + def test_ensure_compile_time_eval_no_leaks(self): + # https://github.com/jax-ml/jax/issues/25847 + with jax.ensure_compile_time_eval(): + jnp.linalg.solve(jnp.eye(3), jnp.ones(3)) # doesn't crash + class RematTest(jtu.JaxTestCase): From 99252277f2f806eb342e73a4562b35f2e979480a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 22 Apr 2025 11:46:37 -0400 Subject: [PATCH 0730/1769] Handle DShapedArray in caller. 
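
Move `Literal`-unwrapping of shape dimensions out of `DShapedArray.__init__`
and into the caller: the constructor now asserts that no `Literal`s remain in
the shape, and `DynamicJaxprTrace.newvar` unwraps them after substituting
jaxpr variables for tracer dimensions. A minimal sketch of the caller-side
normalization (names follow the diff below; illustrative only, not the exact
implementation):

    # Replace tracer dims with their jaxpr variables, then unwrap any
    # Literal dims to their raw values before building the new aval.
    new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d
                 for d in aval.shape]
    new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape]
    aval = aval.update(shape=tuple(new_shape))
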
--- jax/_src/core.py | 3 ++- jax/_src/interpreters/partial_eval.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 39ed1dd48447..472745871f6a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2063,7 +2063,8 @@ class DShapedArray(UnshapedArray): array_abstraction_level: int = 3 def __init__(self, shape, dtype, weak_type=False): - self.shape = tuple(d.val if isinstance(d, Literal) else d for d in shape) + assert not any(isinstance(d, Literal) for d in shape) + self.shape = shape self.dtype = dtype self.weak_type = weak_type diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a37c70f1ff37..07979b7f3885 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1767,6 +1767,7 @@ def newvar(self, aval): # this aval may have tracers in it, so we replace those with variables new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d for d in aval.shape] + new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape] aval = aval.update(shape=tuple(new_shape)) return self.gensym(aval) From aae929313f674e7d0b86ba14b99322ba86dd502c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 21 Apr 2025 15:50:14 -0700 Subject: [PATCH 0731/1769] jnp.ldexp: avoid overflow for large exponent --- jax/_src/numpy/ufuncs.py | 11 ++++++++++- tests/lax_numpy_test.py | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index c5a5d23764d1..07758a87750c 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -3003,7 +3003,16 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}") x1, = promote_args_inexact("ldexp", x1) x2 = lax.convert_element_type(x2, dtypes.dtype(x1)) - x = x1 * (2 ** x2) + + # Split off the exponent to avoid overflow for small x1 and large x2. + m, e = frexp(x1) + e = (e.astype(x2.dtype) + x2).astype(x1.dtype) + + # exponent may overflow by 1 and still have a finite result. 
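+  # (Concrete float32 illustration: ldexp(0.5, 128) gives frexp m=0.5, e=0,
+  # so e + x2 = 128 and 2.0**128 overflows to inf, yet the true result
+  # 0.5 * 2**128 == 2**127 is finite; the rescaling below recovers it.)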
+ m = _where(e > 0, m * 2, m) + e = _where(e > 0, e - 1, e) + + x = m * (2 ** e.astype(m.dtype)) return _where(isinf(x1) | (x1 == 0), x1, x) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6a88363c2780..887ef7a804de 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2754,6 +2754,28 @@ def np_fun(x1, x2): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.parameters(*float_dtypes) + def testLdexpOverflow(self, dtype): + # Regression test for https://github.com/jax-ml/jax/issues/28040 + args_maker = lambda: [np.array(0.5, dtype), 1 << (dtypes.finfo(dtype).nexp - 1)] + def np_ldexp(x1, x2): + return np.ldexp(x1, x2).astype(x1.dtype) + self._CheckAgainstNumpy(np_ldexp, jnp.ldexp, args_maker) + self._CompileAndCheck(jnp.ldexp, args_maker) + + @parameterized.parameters(*float_dtypes) + def testLdexpExtremeValues(self, dtype): + # Regression test for https://github.com/jax-ml/jax/issues/28040 + def args_maker(): + info = dtypes.finfo(dtype) + span = int(np.log2(float(info.max)) - np.log2(float(info.tiny))) - 1 + return [np.array([info.tiny, info.max], dtype=dtype), + np.array([span, -span])] + def np_ldexp(x1, x2): + return np.ldexp(x1, x2).astype(x1.dtype) + self._CheckAgainstNumpy(np_ldexp, jnp.ldexp, args_maker) + self._CompileAndCheck(jnp.ldexp, args_maker) + @jtu.sample_product( rng_factory=[ jtu.rand_some_inf_and_nan, From 720488fde4b644278944ccd99385c92c1d37a580 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 22 Apr 2025 10:29:13 -0700 Subject: [PATCH 0732/1769] [Pallas][Mosaic TPU] Improve error message for 1D block specs when the block size is too small. PiperOrigin-RevId: 750244014 --- tests/pallas/pallas_error_handling_test.py | 33 ++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index 84e38f3d09db..f7ea17852d56 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -16,13 +16,16 @@ import traceback from absl.testing import absltest +from absl.testing import parameterized import jax from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import xla_client from jax._src.pallas.mosaic import error_handling from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +import numpy as np config.parse_flags_with_absl() @@ -128,6 +131,36 @@ def test_kernel(input_ref, output_ref): tb_string = "".join(tb_string) self.assertEndsWith(tb_string, "output_ref[idx, 0] = input_ref[0, 0]\n") + @parameterized.parameters( + ((2048,), (256,)), + ((2048,), (512,)), + ) + def test_small_1d_block_spec_raises(self, total_shape, block_shape): + # https://github.com/jax-ml/jax/issues/25379 + dtype = jnp.float32 + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] * 2 + + x = jnp.arange(np.prod(total_shape), dtype=dtype).reshape(total_shape) + x_spec = pl.BlockSpec(block_shape, lambda *args: args) + fn = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(total_shape, dtype), + in_specs=[x_spec], + out_specs=x_spec, + grid=tuple(tot // blk for tot, blk in zip(total_shape, block_shape, + strict=True)), + ) + # Having a block size that is too small should raise a suggestion + # to increase the block size. 
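+      # (The failure surfaces as an XlaRuntimeError when the kernel is
+      # compiled on first invocation, hence the error type matched below.)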
+ with self.assertRaisesRegex( + xla_client.XlaRuntimeError, + r"Try changing your kernel block shape to \([0-9,\s]+\) to align with" + " the XLA layout", + ): + fn(x) + def test_parse_location_string(self): name, frames = error_handling.parse_location_string(LOCATION_TEST_STRING) self.assertEqual(name, "/squeeze") From 60ebfd70c517866cd251d96a43fa1b18cddb443e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 22 Apr 2025 10:40:02 -0700 Subject: [PATCH 0733/1769] jax.numpy: make standard input utilities respect __jax_array__ --- jax/_src/numpy/util.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index bcfb12673806..825c79f507da 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -141,6 +141,10 @@ def _arraylike_asarray(x: Any) -> Array: return lax.asarray(x) +def _check_jax_array_protocol(x: Any) -> Any: + return x.__jax_array__() if hasattr(x, '__jax_array__') else x + + @overload def ensure_arraylike(fun_name: str, /) -> tuple[()]: ... @overload @@ -223,6 +227,7 @@ def check_for_prngkeys(fun_name: str, *args: Any): def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: """Convenience function to apply Numpy argument shape and dtype promotion.""" check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes(*args)) @@ -230,6 +235,7 @@ def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: def promote_args_numeric(fun_name: str, *args: ArrayLike) -> list[Array]: check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes_numeric(*args)) @@ -240,6 +246,7 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]: Promotes non-inexact types to an inexact type.""" check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes_inexact(*args)) From 91deb25e192340eabcbe44f3bf8d2710db7e6072 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Tue, 22 Apr 2025 11:19:08 -0700 Subject: [PATCH 0734/1769] [CI] Add tpu v6e-8 to nightly test and release test. PiperOrigin-RevId: 750262738 --- .github/workflows/cloud-tpu-ci-nightly.yml | 11 ++++++++++- .github/workflows/wheel_tests_nightly_release.yml | 11 +++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index b50b07d5cc4a..fd799a3f70b5 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -27,9 +27,18 @@ jobs: jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] python-version: ["3.10"] + # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints. 
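+      # (A matrix `exclude` entry drops every combination whose values match
+      # all keys listed in that entry.)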
+ exclude: + - tpu: + type: "v6e-8" + jaxlib-version: "nightly+oldest_supported_libtpu" + - tpu: + type: "v6e-8" + jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20241205 diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index e2466ee43d96..fe8b191c9530 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -82,20 +82,27 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'oldest_supported_libtpu' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8 + # Run a single Python version for v4-8 and v6e-8 - tpu-specs: type: "v4-8" python: "3.10" - tpu-specs: type: "v4-8" python: "3.11" + - tpu-specs: + type: "v6e-8" + python: "3.10" + - tpu-specs: + type: "v6e-8" + python: "3.11" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" From 5dfaee016e7802c87fcf3f20ae629b327fb23e9e Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Tue, 22 Apr 2025 18:27:34 +0000 Subject: [PATCH 0735/1769] Automatically initialize distributed runs in K8s indexed jobs --- .github/workflows/k8s.yaml | 58 +++++++------ .github/workflows/k8s/indexed-job.yaml | 42 ++++++++++ .github/workflows/k8s/jobset.yaml | 34 ++++++++ .pre-commit-config.yaml | 6 +- examples/k8s/example.yaml | 17 ++-- examples/k8s/svc-acct.yaml | 6 +- jax/_src/clusters/k8s_cluster.py | 111 +++++++++++++++++++++++-- 7 files changed, 230 insertions(+), 44 deletions(-) create mode 100644 .github/workflows/k8s/indexed-job.yaml create mode 100644 .github/workflows/k8s/jobset.yaml diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index 1042388fe9c6..470a899a187e 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -31,18 +31,25 @@ jobs: distributed-initialize: runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + controller: [jobset, indexed-job] steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4 with: path: jax - + - name: Start Minikube cluster uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 - name: Install K8s Jobset + if: matrix.controller == 'jobset' run: | - kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml + kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.8.0/manifests.yaml + kubectl wait --for=condition=established crd/jobsets.jobset.x-k8s.io --timeout=60s + kubectl rollout status -n jobset-system deploy/jobset-controller-manager --timeout=120s - name: Build image run: | @@ -56,43 +63,44 @@ jobs: minikube image build -t local/jax:latest . 
- name: Create service account for K8s job introspection - run: | - kubectl apply -f jax/examples/k8s/svc-acct.yaml - - - name: Prepare test job - run: | - export VERSION=v4.44.3 - export BINARY=yq_linux_amd64 - wget https://github.com/mikefarah/yq/releases/download/${VERSION}/${BINARY} -O /usr/bin/yq && chmod +x /usr/bin/yq - - cat jax/examples/k8s/example.yaml |\ - yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].image = "local/jax:latest"' |\ - yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].imagePullPolicy = "Never"' |\ - tee example.yaml + run: kubectl apply -f jax/examples/k8s/svc-acct.yaml - name: Submit test job - run: | - kubectl apply -f example.yaml + run: kubectl apply -f jax/.github/workflows/k8s/${{ matrix.controller }}.yaml - name: Check job status shell: bash -e -o pipefail {0} run: | while true; do - status=$(kubectl get jobset example -o yaml | yq .status.conditions[0].type) + + completion_time=$( + kubectl get jobs -o yaml | \ + yq '.items[] | select(.metadata.name | test("^jaxjob")) | .status.completionTime' + ) timestamp=$(date +"%Y-%m-%d %H:%M:%S") echo "[$timestamp] Checking job status..." - if [ "$status" == "Completed" ]; then - echo "[$timestamp] Job has completed successfully!" - exit 0 - elif [ "$status" == "Failed" ]; then - echo "[$timestamp] Job has failed!" - exit 1 - else + if [ "$completion_time" == "null" ]; then echo "[$timestamp] Job is still running. Current pod status:" kubectl get pods --no-headers echo "[$timestamp] Waiting for 3 seconds before checking again..." sleep 3 + else + succeeded=$( + kubectl get jobs -o yaml | \ + yq '.items[] | select(.metadata.name | test("^jaxjob")) | .status.succeeded' + ) + parallelism=$( + kubectl get jobs -o yaml | \ + yq '.items[] | select(.metadata.name | test("^jaxjob")) | .spec.parallelism' + ) + if [ "$succeeded" == "$parallelism" ]; then + echo "[$timestamp] Job has completed successfully!" + exit 0 + else + echo "[$timestamp] Job has failed!" 
+ exit 1 + fi fi done diff --git a/.github/workflows/k8s/indexed-job.yaml b/.github/workflows/k8s/indexed-job.yaml new file mode 100644 index 000000000000..c38a8c9991a2 --- /dev/null +++ b/.github/workflows/k8s/indexed-job.yaml @@ -0,0 +1,42 @@ +apiVersion: v1 +kind: Service +metadata: + name: jaxpods +spec: + publishNotReadyAddresses: true + clusterIP: None + selector: + job-name: jaxjob +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: jaxjob +spec: + parallelism: 8 + completions: 8 + completionMode: Indexed + backoffLimit: 0 + template: + spec: + subdomain: jaxpods # must match headless service name + serviceAccountName: jax-job-sa + restartPolicy: Never + containers: + - name: main + image: local/jax:latest + imagePullPolicy: IfNotPresent + resources: + limits: + cpu: 100m + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/.github/workflows/k8s/jobset.yaml b/.github/workflows/k8s/jobset.yaml new file mode 100644 index 000000000000..00150d0a9095 --- /dev/null +++ b/.github/workflows/k8s/jobset.yaml @@ -0,0 +1,34 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 8 + completions: 8 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa + restartPolicy: Never + containers: + - name: main + image: local/jax:latest + imagePullPolicy: Never + resources: + limits: + cpu: 100m + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89ce80d9a815..36c4981bd413 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,11 @@ repos: - id: check-merge-conflict - id: check-toml - id: check-yaml - exclude: examples/k8s/svc-acct.yaml + exclude: | + (?x)^( + examples/k8s/svc-acct\.yaml | + \.github/workflows/k8s/indexed-job\.yaml + )$ - id: end-of-file-fixer # only include python files files: \.py$ diff --git a/examples/k8s/example.yaml b/examples/k8s/example.yaml index deee1950683a..9039626e9c82 100644 --- a/examples/k8s/example.yaml +++ b/examples/k8s/example.yaml @@ -1,7 +1,7 @@ apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: - name: example + name: jaxjob spec: replicatedJobs: - name: workers @@ -12,20 +12,19 @@ spec: backoffLimit: 0 template: spec: - serviceAccountName: training-job-sa + serviceAccountName: jax-job-sa # kubectl apply -f svc-acct.yaml restartPolicy: Never - imagePullSecrets: + imagePullSecrets: + # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/ - name: null containers: - name: main - image: PLACEHOLDER - imagePullPolicy: IfNotPresent + image: null # e.g. 
ghcr.io/nvidia/jax:jax + imagePullPolicy: Always resources: - requests: - cpu: 900m - nvidia.com/gpu: null limits: - cpu: 1 + cpu: 900m + # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/ nvidia.com/gpu: null command: - python diff --git a/examples/k8s/svc-acct.yaml b/examples/k8s/svc-acct.yaml index d05fb9b0cd2a..c1523964c515 100644 --- a/examples/k8s/svc-acct.yaml +++ b/examples/k8s/svc-acct.yaml @@ -1,7 +1,7 @@ apiVersion: v1 kind: ServiceAccount metadata: - name: training-job-sa + name: jax-job-sa namespace: default --- apiVersion: rbac.authorization.k8s.io/v1 @@ -10,7 +10,7 @@ metadata: name: pod-reader rules: - apiGroups: [""] - resources: ["pods"] + resources: ["pods", "services"] verbs: ["get", "list", "watch"] - apiGroups: ["batch"] resources: ["jobs"] @@ -23,7 +23,7 @@ metadata: namespace: default subjects: - kind: ServiceAccount - name: training-job-sa + name: jax-job-sa namespace: default roleRef: kind: Role diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index 23b2d68e11a5..11f93e36f647 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -103,13 +103,111 @@ def _job(cls): ) @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: - return '{job_name}-0.{jobset_name}:{port}'.format( - job_name=cls._pod().metadata.labels['job-name'], - jobset_name=cls._job().metadata.labels['jobset.sigs.k8s.io/jobset-name'], - port=cls._coordinator_port + @cache + def _headless_svc(cls): + with cls._handle_api_exception(): + services = cls._core_api.list_namespaced_service(cls._namespace()).items + + pod_labels = cls._pod().metadata.labels or {} + for svc in services: + if svc.spec.cluster_ip == "None": # if headless service + svc_selector = svc.spec.selector or {} + if all(pod_labels.get(k) == v for k, v in svc_selector.items()): + return svc + + # returns None if no headless service targets the current pod + return None + + @classmethod + @cache + def _controller(cls): + # https://github.com/kubernetes/apimachinery/blob/7b4292b/pkg/apis/meta/v1/types.go#L235 + # states that there cannot be more than one managing controller. + for owner in cls._pod().metadata.owner_references: + if owner.controller is True: + return owner + + raise RuntimeError( + 'Cannot automatically initialize distributed workload: ' + f'pod {cls._pod().metadata.name} does not have a controller.' ) + @classmethod + def get_coordinator_address(cls, timeout_secs: int | None) -> str: + controller = cls._controller() + job = cls._job() + pod = cls._pod() + if controller.kind == 'Job': + # if job belongs to a jobset + if 'jobset.sigs.k8s.io/jobset-name' in job.metadata.labels: + return '{job_name}-0.{subdomain}:{port}'.format( + job_name=job.metadata.name, + subdomain=job.metadata.labels['jobset.sigs.k8s.io/jobset-name'], + port=cls._coordinator_port + ) + # if job is standalone + else: + # check if the job is associated with a headless service, which is + # necessary for pods to communicate with each other + if pod.spec.subdomain is None: + # check if a headless service exists but not specified as subdomain + svc = cls._headless_svc() + err_msg = ( + "Pods within a job need a headless service in order to " + "communicate with each other. " + ) + if svc: + err_msg += ( + f"A headless service '{svc.metadata.name}' is found that " + "targets this job, but it is not specified as the job subdomain. 
" + "Please add the following to the job specification: " + ) + fix_msg = [ + "```", + "kind: Job", + "spec:", + " ...", + " template:", + " spec:", + f" subdomain: {svc.metadata.name}", + "```", + ] + else: + err_msg += "To fix, add the following to the job specification:" + fix_msg = [ + "```", + "apiVersion: v1", + "kind: Service", + "metadata:", + " name: jaxpods", + "spec:", + " publishNotReadyAddresses: true", + " clusterIP: None", + " selector:", + f" job-name: {job.metadata.name}", + "---", + "kind: Job", + "spec:", + " ...", + " template:", + " spec:", + " subdomain: jaxpods", + "```", + ] + + raise RuntimeError('\n'.join([textwrap.fill(err_msg)] + fix_msg)) + + return '{job_name}-0.{subdomain}:{port}'.format( + job_name=job.metadata.name, + subdomain=pod.spec.subdomain, + port=cls._coordinator_port + ) + + else: + raise RuntimeError( + 'In K8s, cluster automatic bootstrap only supports Job/JobSet.' + ) + @classmethod def get_process_count(cls) -> int: # https://kubernetes.io/docs/concepts/workloads/controllers/job/#controlling-parallelism @@ -122,5 +220,6 @@ def get_process_id(cls) -> int: return int(os.environ['JOB_COMPLETION_INDEX']) except KeyError: raise RuntimeError( - 'K8s job must be run with `completionMode: "Indexed"`.' + 'To enable automatic bootstrap in a K8s cluster, ' + 'jobs must be indexed by setting `completionMode: "Indexed"`.' ) From 64c645dc96ec732fbe6a5ca543094b17a6fefbb8 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 17 Apr 2025 10:56:22 -0700 Subject: [PATCH 0736/1769] [Pallas] Update TPU pipelining docs --- docs/pallas/design/async_note.md | 1 + docs/pallas/pipelining.ipynb | 13 +- docs/pallas/pipelining.md | 13 +- docs/pallas/quickstart.ipynb | 1 + docs/pallas/quickstart.md | 1 + docs/pallas/tpu/pipelining.ipynb | 657 +++++-------------------------- docs/pallas/tpu/pipelining.md | 465 ++++------------------ 7 files changed, 198 insertions(+), 953 deletions(-) diff --git a/docs/pallas/design/async_note.md b/docs/pallas/design/async_note.md index 42e32a074fd7..0fda9fe0a4e2 100644 --- a/docs/pallas/design/async_note.md +++ b/docs/pallas/design/async_note.md @@ -1,3 +1,4 @@ +(pallas_async)= # Pallas Async Operations ## Background \+ Motivation diff --git a/docs/pallas/pipelining.ipynb b/docs/pallas/pipelining.ipynb index d7871a3b67db..6770351d7760 100644 --- a/docs/pallas/pipelining.ipynb +++ b/docs/pallas/pipelining.ipynb @@ -6,11 +6,14 @@ "id": "C93Xlf0DRW9H" }, "source": [ + "\n", + "(pallas_software_pipelining)=\n", + "\n", "# Software Pipelining\n", "\n", "Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n", "\n", - "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see the [TPU](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html), or GPU (coming soon!) specific pipelining references.\n" + "This tutorial only covers the conceptual foundations of pipelining. 
For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or GPU (coming soon!) specific pipelining references.\n" ] }, { @@ -391,7 +394,7 @@ "source": [ "## Pallas Pipelining API\n", "\n", - "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", + "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in {ref}`pallas_quickstart`, so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", "\n", "\n", "### Grid\n", @@ -574,7 +577,7 @@ "**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**\n", "\n", "Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.\n", - "The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n", + "The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the data slice changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n", "\n", "As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array.\n", "\n", @@ -799,7 +802,7 @@ "id": "HFWcaAudW4z1" }, "source": [ - "In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\\alpha + X / \\beta$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. 
Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\\beta$ is orders of magnitude slower than the processing speed $F$." + "In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\\alpha + N(X / \\beta)$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\\beta$ is orders of magnitude slower than the processing speed $F$." ] }, { @@ -818,7 +821,7 @@ "id": "V4YQCZf1W7X5" }, "source": [ - "If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or latency bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime.\n" + "If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or bandwidth bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime.\n" ] }, { diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md index 42b91e368238..a79876a0ca97 100644 --- a/docs/pallas/pipelining.md +++ b/docs/pallas/pipelining.md @@ -13,11 +13,14 @@ jupyter: --- + +(pallas_software_pipelining)= + # Software Pipelining Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API. -This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see the [TPU](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html), or GPU (coming soon!) specific pipelining references. +This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or GPU (coming soon!) specific pipelining references. 
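The corrected bandwidth-bound runtime in the hunks below, $\alpha + N(X / \beta)$, is easy to sanity-check numerically. A back-of-envelope sketch of the two regimes; every constant here is invented for illustration, not a measured TPU figure:

```python
# Two-stage pipeline runtime model; all constants are made up (assumed).
alpha = 1e-6             # per-copy latency in seconds
beta = 8e11              # memory bandwidth in bytes/second
X = 2 * 256 * 512 * 4    # bytes per grid step: two (256, 512) f32 blocks
compute = 5e-7           # compute time per grid step in seconds
N = 64                   # number of grid steps

copy = X / beta
# Bandwidth-bound: copies serialize, so the latency alpha is paid only once.
bandwidth_bound = alpha + N * copy
# Compute-bound: one startup bubble, after which compute hides every copy.
compute_bound = (alpha + copy) + N * compute
# Whichever term dominates indicates which regime the pipeline is in.
print(f"copy/step={copy:.2e}s bandwidth-bound={bandwidth_bound:.2e}s "
      f"compute-bound={compute_bound:.2e}s")
```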
@@ -316,7 +319,7 @@ Now that we've seen how to manually implement a pipelined loop, let's look into ## Pallas Pipelining API -Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining. +Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in {ref}`pallas_quickstart`, so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining. ### Grid @@ -466,7 +469,7 @@ There are two cases where a buffer supports both reads and writes - accumulation **Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.** Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle. -The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again. +The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the data slice changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again. As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array. @@ -570,7 +573,7 @@ In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\a -In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\alpha + X / \beta$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. 
Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\beta$ is orders of magnitude slower than the processing speed $F$. +In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\alpha + N(X / \beta)$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\beta$ is orders of magnitude slower than the processing speed $F$. @@ -580,7 +583,7 @@ In a **memory-bound** regime it is useful to identify if the problem is the late -If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or latency bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime. +If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or bandwidth bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime. diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 11dd2108e405..6460c1d5e739 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -5,6 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(pallas_quickstart)=\n", "# Pallas Quickstart\n", "\n", "\n", diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index fff1dcb730f3..d4865488a15b 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -12,6 +12,7 @@ kernelspec: name: python3 --- +(pallas_quickstart)= # Pallas Quickstart diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 10de587105f2..68932f4d1e40 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -2,8 +2,9 @@ "cells": [ { "cell_type": "markdown", - "id": "7704d3bb", - "metadata": {}, + "metadata": { + "id": "7704d3bb" + }, "source": [ "(pallas_tpu_pipelining)=" ] @@ -14,7 +15,7 @@ "id": "teoJ_fUwlu0l" }, "source": [ - "# Pipelining\n", + "# TPU Pipelining\n", "\n", "" ] @@ -25,14 +26,24 @@ "id": "gAJDZh1gBh-h" }, "source": [ - "In this guide we'll cover how memory spaces in TPU work and how to write\n", - "pipelines in Pallas that overlap memory I/O with compute." 
+ "This guide serves as a reference for TPU-specific pipelining concerns.\n", + "We'll review the memory hierarchy and compute units on TPUs, and TPU-specific features of the pipelining API. For a more general-purpose overview of pipelining, see the {ref}`pallas_software_pipelining`." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { + "executionInfo": { + "elapsed": 54, + "status": "ok", + "timestamp": 1744908474512, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, "id": "ejAVO6ikUUuF" }, "outputs": [], @@ -48,9 +59,8 @@ }, { "cell_type": "markdown", - "id": "0e212a5e", "metadata": { - "id": "TWKESTKAlyjT" + "id": "0e212a5e" }, "source": [ "(tpu_and_its_memory_spaces)=\n", @@ -60,7 +70,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "NnWW9GV4kW6P" + }, "source": [ "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n", "registers (which temporarily store scalar and array values) and compute units\n", @@ -83,568 +95,81 @@ " Values can be loaded into memory from their respective caches (VMEM for\n", " VREGs and SMEM for SREGs).\n", "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and\n", - " matrix unit (MXU) that can do numerical computation.\n", + " matrix unit (MXU) that can do numerical computation. Each of these compute units can operate asynchronously, but this is managed by the TPU compiler and thus from the programmer's perspective a TPU program is single-threaded.\n", " Compute units operate on values that live in SREGs and VREGs and output\n", - " values into those registers as well.\n", - "\n", - "In order to do a vectorized computation on our values `x` and `y` that live\n", - "in HBM, we need to:\n", - "\n", - "1. Copy the values `x` and `y` into VMEM.\n", - "2. Load the values from VMEM into VREGs.\n", - "3. Execute the computation using the VPU or MXU, storing the output in VREGs.\n", - "4. Store the values in the output VREGs into VMEM.\n", - "5. Copy the output values in VMEM back to HBM." + " values into those registers as well." ] }, { "cell_type": "markdown", "metadata": { - "id": "TzctMbNsn3vc" + "id": "8Tl3wt5Wk3Ek" }, "source": [ - "Let's implement a Pallas function that does just that!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2IXQxNWrKJyb", - "outputId": "d62eb493-5f92-4496-f113-d3cd24cb0b9f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " ...,\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n", - " # Load x and y from VMEM into VREGs\n", - " x_vregs = x_vmem_ref[:, :]\n", - " y_vregs = y_vmem_ref[:, :]\n", - " # Execute a vectorized add\n", - " z_vregs = x_vregs + y_vregs\n", - " # Store the output values in VREGs back into VMEM\n", - " z_vmem_ref[:, :] = z_vregs\n", - "\n", - "\n", - "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.\n", - " # It will then copy `x` and `y` from HBM into VMEM.\n", - " z = pl.pallas_call(\n", - " add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " )(x, y)\n", - " # pallas_call will also copy the output from VMEM back into HBM.\n", - " return z\n", - "\n", - "\n", - "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", - "add_matrices(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HMENNLy8okCL" - }, - "source": [ - "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", + "## TPU-specific Pipelining Features\n", "\n", - "`add_matrices_kernel` operates using `Ref`s that live in VMEM.\n", - "Loading from a VMEM `Ref` produces a value that lives in VREGs.\n", - "Values in VREGs behave like `jax.Array`s in that we can use `jnp` and\n", - "`jax.lax` operations on them to produce new values that live in VREGs.\n", - "When we produce the values we'd like to return, we store them in the output\n", - "VMEM `Ref`.\n", - "\n", - "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`.\n", - "Inside it, we pass `x` and `y` into `pallas_call`.\n", - "`pallas_call` is responsible for copying `x` and `y` into VMEM and for\n", - "allocating the VMEM buffers that the kernel operates on (including allocating\n", - "`z_vmem_ref`, the output VMEM buffer).\n", - "After the kernel function is finished running, `pallas_call` will also copy\n", - "the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`." + "Pallas TPU supports the following platform-specific features." ] }, { "cell_type": "markdown", "metadata": { - "id": "5kWr-1tKpYro" + "id": "1jg5WmExk47l" }, "source": [ - "## Constraints of using VMEM/SMEM\n", - "\n", - "Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n", - "writing kernels utilizing them adds some considerations.\n", + "### TPU Memory Spaces\n", "\n", - "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB\n", - " and SMEM ranges in the tens to hundreds of KiB.\n", - " If our arrays are too big, we won't even be able to fit them into VMEM at all.\n", - " For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n", - " scale beyond moderately sized arrays.\n", + "Pallas exposes all levels of the TPU memory hierarchy to users. The following table maps from Pallas TPU memory spaces to their standard memory types (DRAM/SRAM):\n", "\n", - "2. 
Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least\n", - " compared to most compute instructions.\n", - " The `add_matrices` function above will likely spend more time copying\n", - " between HBM and VMEM than actually performing the addition itself.\n", + "| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n", + "| --- | --- | --- |\n", + "| `pltpu.TPUMemorySpace.ANY` | HBM (usually) or VMEM | DRAM |\n", + "| `pltpu.TPUMemorySpace.VMEM` | VMEM | SRAM |\n", + "| `pltpu.TPUMemorySpace.SMEM` | SMEM | SRAM |\n", + "| `pltpu.TPUMemorySpace.SEMAPHORE` | Semaphore | SRAM |\n", "\n", - "With these two constraints in mind, we'll have to rethink our strategy for\n", - "getting performance out of our TPUs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_NTqvlbetB3P" - }, - "source": [ - "## Primer: Pipelining\n", + "- `TPUMemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n", + "- `TPUMemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n", + "- `TPUMemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`.\n", + "- `TPUMemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details.\n", "\n", - "Pipelining our computation offers a way of dealing with both the memory\n", - "capacity and bandwidth constraints in one fell swoop.\n", - "What do we mean by pipelining?\n", + "Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM.\n", "\n", - "The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our\n", - "compute units.\n", - "Naively this is difficult because in our program above we copy *all* of `x`\n", - "and `y` before we start doing any compute with them, creating a dependence\n", - "between the copy and the compute.\n", + "While not specific to pipelining, it is possible to gain manual control over the memory space of input and output buffers, you can specify the `memory_space` argument on a `BlockSpec`. Note that pipelining is not allowed unless the `memory_space` is marked as `VMEM`. Memory spaces can also be used to specify scratch arguments to a kernel via the `scratch_shapes` argument on `pallas_call`. Scratch buffers are persistent across kernel iterations and are useful for storing intermediate results such as partial accumulations and reductions. A scratch buffer must reside in `VMEM`, `SMEM`, or `SEMAPHORE`.\n", "\n", - "However, if we can chunk up our computation into several subcomputations\n", - "(e.g. when we add two matrices, we can express that as addition of \"blocks\"\n", - "of the original matrices together), we can now overlap the copies of one of\n", - "those subcomputations with the compute of the other. 
Let's walk through a\n", - "simple example:\n", - "\n", - "Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for\n", - "example, split along the leading axis, resulting in two `(256, 512)` arrays\n", - "for each input.\n", - "We can now execute the following pipelined computation.\n", - "\n", - "1. Copy `x1` and `y1` into VMEM.\n", - "1. Start copying `x2` and `y2` into VMEM\n", - "2. Load `x1, y1` from VMEM into VREGs.\n", - "3. Execute the `z1 = x1 + y1` using the compute units.\n", - "4. Store `z1` into VMEM.\n", - "5. Start copying `z1` from VMEM back into HBM.\n", - "6. Wait until `x2, y2` have been copied into VMEM.\n", - "7. Load `x2, y2` from VMEM into VREGs.\n", - "8. Execute the `z2 = x2 + y2` using the compute units.\n", - "9. Store `z2` into VMEM.\n", - "10. Wait until `z1` is copied into HBM.\n", - "10. Start copying `z2` from VMEM back into HBM.\n", - "10. Wait until `z2` is copied into HBM.\n", - "\n", - "Any time we are doing compute here, we are asynchronously copying something.\n", - "This means that some of the time spent copying is not wasted.\n", - "\n", - "The two most important numbers for determining how efficient a pipelined\n", - "computation are a) how many floating point operations (FLOPs) we need to\n", - "execute and b) how many bytes we need to copy to execute that computation.\n", - "The ratio of these two (FLOPs/memory usage) is called the\n", - "*arithmetic intensity* of an operation and determines if our pipeline will\n", - "be compute bound or memory bound." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gutx7y8uvZKH" - }, - "source": [ - "## Pipelining in Pallas" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "U-dPTjlBverB" - }, - "source": [ - "How do we implement a pipeline like the one above in Pallas?\n", - "It seems like a complex sequence of asynchronous data operations and\n", - "executing kernels that would be a pain to implement manually.\n", - "Fear not! Pallas offers an API for expressing pipelines without too much\n", - "boilerplate, namely through `grid`s and `BlockSpec`s.\n", - "\n", - "See how in the above pipelined example, we are executing the same logic\n", - "multiple times: steps 3-5 and 8-10 both execute the same operations,\n", - "only on different inputs.\n", - "The {func}`jax.experimental.pallas.pallas_call` provides a way to\n", - "execute a kernel multiple times, by using the `grid` argument.\n", - "See {ref}`pallas_grid`.\n", - "\n", - "We also use {class}`jax.experimental.pallas.BlockSpec` to specify\n", - "how to construct the input of each kernel invocation.\n", - "See {ref}`pallas_blockspec`.\n", - "\n", - "In the pipelining example above, we had `(512, 512)`-shaped arrays and\n", - "split them along the leading dimension into two `(256, 512)`-shaped arrays.\n", - "In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.\n", - "On the 1st iteration we'd\n", - "like to select `x1` and on the second iteration we'd like to use `x2`.\n", - "This can be expressed with the following `index_map`:\n", - "\n", - "```python\n", - "def x_index_map(i):\n", - " return (i, 0)\n", - "```\n", - "\n", - "We'd then construct the `BlockSpec`:\n", - "```python\n", - "block_spec = pl.BlockSpec((256, 512), x_index_map)\n", - "```\n", - "\n", - "The `BlockSpec`s for `y` and `z` will be the same as the one for `x`." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "noybOKghzjwG" - }, - "source": [ - "### Putting it together\n", - "\n", - "We provide these arguments to `pallas_call` via `grid`, `in_specs` and\n", - "`out_specs` (`in_specs` corresponds to the tuple of positional arguments,\n", - "and `out_specs` corresponds to the output)." + "As an example for using multiple manual memory space assignments in a kernel, the following program copies a slice of an HBM buffer `x_hbm_ref` into a scratch VMEM buffer `scratch_vmem_ref` before using it for arithmetic and storing the result into an output VMEM buffer:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { - "id": "ehKAYAwIojfv", - "outputId": "504bab29-83f3-4e1f-8664-1860ad15b6de" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " ...,\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" - ] + "executionInfo": { + "elapsed": 65, + "status": "ok", + "timestamp": 1744908591430, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", - " return pl.pallas_call(\n", - " add_matrices_kernel,\n", - " out_shape=x,\n", - " in_specs=[block_spec, block_spec],\n", - " out_specs=block_spec,\n", - " grid=(2,)\n", - " )(x, y)\n", - "\n", - "add_matrices_pipelined(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rkytgIZYzz4t" - }, - "source": [ - "We've only added a little bit of code to our original function to add\n", - "automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy\n", - "lifting!\n", - "\n", - "How does it work? Well, the `BlockSpec`s provide enough information to start\n", - "*prefetching* blocks of our input from HBM into VMEM.\n", - "For example, if we are starting iteration `i` of our `grid`, we can pass\n", - "`i + 1` into the `index_map` functions to obtain the blocks needed for the\n", - "next iteration. We can then start an asynchronous copy for those blocks.\n", - "Similarly for outputs, we can wait for the outputs of the previous iteration\n", - "to be copied before starting the copy for the current iteration's outputs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7Xtz9oMs0ZRL" - }, - "source": [ - "### Parameterizing a pipeline" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "esY4GcIB0bqQ" - }, - "source": [ - "It's common to parameterize the block shapes in our kernel. Block sizes are\n", - "perhaps the most important parameter to tune when optimizing the performance\n", - "of Pallas kernels! They give us control over the pipeline (for example,\n", - "picking smaller blocks adds more iterations to our pipelined loop where each\n", - "iteration has less work to do).\n", - "\n", - "Furthermore, we could also carve up the inputs and outputs along the 2nd\n", - "dimension (we are only splitting along the first right now). Let's write a\n", - "more general kernel that handles both of these features." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VartelFd0YfY" + "user_tz": 420 + }, + "id": "zcqz1CA_o50a" }, "outputs": [], "source": [ - "def add_matrices_pipelined_2d(\n", - " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", - ") -> jax.Array:\n", - " m, n = x.shape\n", - " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", - " return pl.pallas_call(\n", - " add_matrices_kernel,\n", - " out_shape=x,\n", - " in_specs=[block_spec, block_spec],\n", - " out_specs=block_spec,\n", - " grid=(m // bm, n // bn),\n", - " )(x, y)\n", - "\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n", - ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y\n", - ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KrfeYwaW1QA-" - }, - "source": [ - "## Handling reductions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P3SqEKDe3Mar" - }, - "source": [ - "How would you implement something like `jnp.sum` using `pallas_call`?\n", - "Specifically, we'd like to pipeline across the reduction dimension.\n", - "\n", - "Take the example of reducing a `(8, 512, 512)`-shaped array to a\n", - "`(512, 512)`-shaped one." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JoT-ZKEk1R7l", - "outputId": "fd842223-98a5-4e5c-87fc-5dadc94da4fa" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " ...,\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = jnp.ones((8, 512, 512))\n", - "jnp.sum(x, axis=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5O3ByvuT3iyC" - }, - "source": [ - "To do this using `pallas_call`, we could use a grid of size `(8,)` and in\n", - "each iteration `i` load `x[i]` into VMEM.\n", - "Then we could add `x[i]` to an output VMEM buffer. Let's implement this\n", - "naively first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hqvv_WRQ3bvP", - "outputId": "200648d2-3f4d-4d1a-b95a-d2c1352cd7b8" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " ...,\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Warning: this implementation is incorrect!\n", - "\n", - "def naive_sum_kernel(x_ref, o_ref):\n", - " o_ref[...] 
+= x_ref[...]\n", - "\n", - "def naive_sum(x: jax.Array) -> jax.Array:\n", - " grid, *out_shape = x.shape\n", - " return pl.pallas_call(\n", - " naive_sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", - " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", - " )(x)\n", - "naive_sum(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Kv9qJYJY4jbK" - }, - "source": [ - "Notice how we've set up the `BlockSpec`s: we're loading the entirety of\n", - "the `(512, 512)` dimension into VMEM (no pipelining there) but selecting\n", - "the `i`-th dimension of `x` each iteration in the `index_map`.\n", - "We are using a `None` for that dimension in the block shape, which indicates\n", - "that we are selecting a singleton dimension from `x` that we would like\n", - "to squeeze away in the kernel.\n", - "Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n", - "\n", - "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that\n", - "`o_ref` is unchanged over the course of the pipeline.\n", - "This means that we can update its value each iteration by reading from and\n", - "writing to it. Or can it?\n", - "Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll\n", - "be accumulating into garbage.\n", - "This will result in the overall function outputting the incorrect value!\n", - "\n", - "Therefore, **whenever we do a reduction in a kernel, we need to make sure\n", - "to initialize the `Ref` that is storing the reduced value**.\n", - "We can accomplish this by conditionally writing a value to `out_ref`\n", - "when we're on iteration 0.\n", - "We can do this with the helper function `pl.when`, a convenience wrapper\n", - "around `jax.lax.cond`, and `pl.program_id`,\n", - "which queries which iteration in a grid axis we are in." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JXN2RthX5cSw", - "outputId": "195df19b-a889-479b-95b6-1fb7281f1518" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " ...,\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def sum_kernel(x_ref, o_ref):\n", - " @pl.when(pl.program_id(axis=0) == 0)\n", - " def _():\n", - " o_ref[...] = jnp.zeros_like(o_ref)\n", + "def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):\n", + " pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref)\n", + " out_vmem_ref[...] = scratch_vmem_ref[...] + 1\n", "\n", - " o_ref[...] 
+= x_ref[...]\n", - "\n", - "def sum(x: jax.Array) -> jax.Array:\n", - " grid, *out_shape = x.shape\n", - " return pl.pallas_call(\n", - " sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", - " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n", - " )(x)\n", + "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n", + "out = pl.pallas_call(hbm_vmem_kernel,\n", + " in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],\n", + " out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n", + " scratch_shapes=(pltpu.TPUMemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", + ")(x)\n", "\n", - "sum(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2828qXBI5ksZ" - }, - "source": [ - "This `sum` function now outputs the correct values!\n", - "\n", - "One last thing to note about reductions in Pallas are that **they must be\n", - "done in the minormost (rightmost) dimensions of our grid** (our grid is\n", - "1-dimensional in the above example so we are reducing over its minormost\n", - "dimension). This is because the pipeline that Pallas generates using\n", - "the `BlockSpec`s, `grid` and kernel function *does not read outputs back\n", - "from HBM*.\n", - "Once you've written an output value back to HBM you cannot revisit it.\n", - "Therefore, you cannot do a reduction across a grid dimension that has any\n", - "revisiting and therefore all reductions need to happen in the rightmost\n", - "dimensions." + "np.testing.assert_allclose(out, x[0:1] + 1)" ] }, { @@ -655,7 +180,7 @@ "source": [ "(pallas_tpu_megacore)=\n", "\n", - "## TPUs in Megacore configuration" + "### TPUs in Megacore configuration" ] }, { @@ -683,10 +208,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { + "executionInfo": { + "elapsed": 106, + "status": "ok", + "timestamp": 1744910274556, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, "id": "nQNa8RaQ-TR1", - "outputId": "385ed87c-d95c-466c-af77-df3845c979f2" + "outputId": "29c0b574-3528-49a5-8a88-b6987efc69ce" }, "outputs": [ { @@ -701,12 +236,21 @@ " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n", + " # Load x and y from VMEM into VREGs\n", + " x_vregs = x_vmem_ref[:, :]\n", + " y_vregs = y_vmem_ref[:, :]\n", + " # Execute a vectorized add\n", + " z_vregs = x_vregs + y_vregs\n", + " # Store the output values in VREGs back into VMEM\n", + " z_vmem_ref[:, :] = z_vregs\n", + "\n", "def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n", " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", " return pl.pallas_call(\n", @@ -715,7 +259,8 @@ " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\",))\n", " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", @@ -737,28 +282,16 @@ "\n", "> Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. 
Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available)." ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1ZJ2rV5W8FAe" - }, - "source": [ - "## Conclusion\n", - "\n", - "In this guide we covered how to express TPU pipelines using `pallas_call`,\n", - "`grid` and `BlockSpec`s. We covered how to express nested loops via a\n", - "multi-dimensional grid and how to handle reductions by initialize our\n", - "accumulators at the beginning of the reduction.\n", - "We also learned how to handle Megacore by adding annotations to the kernel.\n", - "\n", - "Exercises left to the reader:\n", - "* Try implementing a `sum` kernel that pipelines the other dimensions as well\n", - "* Add megacore support to the `add` kernel and the `sum` kernel as well." - ] } ], "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, "jupytext": { "formats": "ipynb,md:myst" }, diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index df570cf0806c..b9ed41f937c8 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -11,22 +11,33 @@ kernelspec: name: python3 --- ++++ {"id": "7704d3bb"} + (pallas_tpu_pipelining)= +++ {"id": "teoJ_fUwlu0l"} -# Pipelining +# TPU Pipelining +++ {"id": "gAJDZh1gBh-h"} -In this guide we'll cover how memory spaces in TPU work and how to write -pipelines in Pallas that overlap memory I/O with compute. +This guide serves as a reference for TPU-specific pipelining concerns. +We'll review the memory hierarchy and compute units on TPUs, and TPU-specific features of the pipelining API. For a more general-purpose overview of pipelining, see the {ref}`pallas_software_pipelining`. ```{code-cell} -:id: ejAVO6ikUUuF - +--- +executionInfo: + elapsed: 54 + status: ok + timestamp: 1744908474512 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: ejAVO6ikUUuF +--- #@title Imports import jax @@ -36,13 +47,13 @@ import jax.numpy as jnp import numpy as np ``` -+++ {"id": "TWKESTKAlyjT"} ++++ {"id": "0e212a5e"} (tpu_and_its_memory_spaces)= ## TPU and its memory spaces -+++ ++++ {"id": "NnWW9GV4kW6P"} A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units @@ -65,384 +76,71 @@ Let's talk about the components of this diagram in more detail: Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs). * **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and - matrix unit (MXU) that can do numerical computation. + matrix unit (MXU) that can do numerical computation. Each of these compute units can operate asynchronously, but this is managed by the TPU compiler and thus from the programmer's perspective a TPU program is single-threaded. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well. -In order to do a vectorized computation on our values `x` and `y` that live -in HBM, we need to: - -1. Copy the values `x` and `y` into VMEM. -2. Load the values from VMEM into VREGs. -3. Execute the computation using the VPU or MXU, storing the output in VREGs. -4. Store the values in the output VREGs into VMEM. -5. Copy the output values in VMEM back to HBM. 
- -+++ {"id": "TzctMbNsn3vc"} - -Let's implement a Pallas function that does just that! - -```{code-cell} -:id: 2IXQxNWrKJyb -:outputId: d62eb493-5f92-4496-f113-d3cd24cb0b9f - -def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref): - # Load x and y from VMEM into VREGs - x_vregs = x_vmem_ref[:, :] - y_vregs = y_vmem_ref[:, :] - # Execute a vectorized add - z_vregs = x_vregs + y_vregs - # Store the output values in VREGs back into VMEM - z_vmem_ref[:, :] = z_vregs - - -def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array: - # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM. - # It will then copy `x` and `y` from HBM into VMEM. - z = pl.pallas_call( - add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) - # pallas_call will also copy the output from VMEM back into HBM. - return z - - -x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) -add_matrices(x, y) -``` - -+++ {"id": "HMENNLy8okCL"} - -We've written two functions: `add_matrices_kernel` and `add_matrices`. - -`add_matrices_kernel` operates using `Ref`s that live in VMEM. -Loading from a VMEM `Ref` produces a value that lives in VREGs. -Values in VREGs behave like `jax.Array`s in that we can use `jnp` and -`jax.lax` operations on them to produce new values that live in VREGs. -When we produce the values we'd like to return, we store them in the output -VMEM `Ref`. - -The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. -Inside it, we pass `x` and `y` into `pallas_call`. -`pallas_call` is responsible for copying `x` and `y` into VMEM and for -allocating the VMEM buffers that the kernel operates on (including allocating -`z_vmem_ref`, the output VMEM buffer). -After the kernel function is finished running, `pallas_call` will also copy -the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. - -+++ {"id": "5kWr-1tKpYro"} - -## Constraints of using VMEM/SMEM - -Pallas exposes access to lower level memory spaces like VMEM and SMEM but -writing kernels utilizing them adds some considerations. - -1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB - and SMEM ranges in the tens to hundreds of KiB. - If our arrays are too big, we won't even be able to fit them into VMEM at all. - For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't - scale beyond moderately sized arrays. - -2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least - compared to most compute instructions. - The `add_matrices` function above will likely spend more time copying - between HBM and VMEM than actually performing the addition itself. - -With these two constraints in mind, we'll have to rethink our strategy for -getting performance out of our TPUs. - -+++ {"id": "_NTqvlbetB3P"} - -## Primer: Pipelining - -Pipelining our computation offers a way of dealing with both the memory -capacity and bandwidth constraints in one fell swoop. -What do we mean by pipelining? - -The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our -compute units. -Naively this is difficult because in our program above we copy *all* of `x` -and `y` before we start doing any compute with them, creating a dependence -between the copy and the compute. - -However, if we can chunk up our computation into several subcomputations -(e.g. 
when we add two matrices, we can express that as addition of "blocks" -of the original matrices together), we can now overlap the copies of one of -those subcomputations with the compute of the other. Let's walk through a -simple example: - -Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for -example, split along the leading axis, resulting in two `(256, 512)` arrays -for each input. -We can now execute the following pipelined computation. - -1. Copy `x1` and `y1` into VMEM. -1. Start copying `x2` and `y2` into VMEM -2. Load `x1, y1` from VMEM into VREGs. -3. Execute the `z1 = x1 + y1` using the compute units. -4. Store `z1` into VMEM. -5. Start copying `z1` from VMEM back into HBM. -6. Wait until `x2, y2` have been copied into VMEM. -7. Load `x2, y2` from VMEM into VREGs. -8. Execute the `z2 = x2 + y2` using the compute units. -9. Store `z2` into VMEM. -10. Wait until `z1` is copied into HBM. -10. Start copying `z2` from VMEM back into HBM. -10. Wait until `z2` is copied into HBM. - -Any time we are doing compute here, we are asynchronously copying something. -This means that some of the time spent copying is not wasted. - -The two most important numbers for determining how efficient a pipelined -computation are a) how many floating point operations (FLOPs) we need to -execute and b) how many bytes we need to copy to execute that computation. -The ratio of these two (FLOPs/memory usage) is called the -*arithmetic intensity* of an operation and determines if our pipeline will -be compute bound or memory bound. - -+++ {"id": "gutx7y8uvZKH"} - -## Pipelining in Pallas - -+++ {"id": "U-dPTjlBverB"} - -How do we implement a pipeline like the one above in Pallas? -It seems like a complex sequence of asynchronous data operations and -executing kernels that would be a pain to implement manually. -Fear not! Pallas offers an API for expressing pipelines without too much -boilerplate, namely through `grid`s and `BlockSpec`s. - -See how in the above pipelined example, we are executing the same logic -multiple times: steps 3-5 and 8-10 both execute the same operations, -only on different inputs. -The {func}`jax.experimental.pallas.pallas_call` provides a way to -execute a kernel multiple times, by using the `grid` argument. -See {ref}`pallas_grid`. - -We also use {class}`jax.experimental.pallas.BlockSpec` to specify -how to construct the input of each kernel invocation. -See {ref}`pallas_blockspec`. - -In the pipelining example above, we had `(512, 512)`-shaped arrays and -split them along the leading dimension into two `(256, 512)`-shaped arrays. -In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`. -On the 1st iteration we'd -like to select `x1` and on the second iteration we'd like to use `x2`. -This can be expressed with the following `index_map`: - -```python -def x_index_map(i): - return (i, 0) -``` - -We'd then construct the `BlockSpec`: -```python -block_spec = pl.BlockSpec((256, 512), x_index_map) -``` - -The `BlockSpec`s for `y` and `z` will be the same as the one for `x`. - -+++ {"id": "noybOKghzjwG"} - -### Putting it together - -We provide these arguments to `pallas_call` via `grid`, `in_specs` and -`out_specs` (`in_specs` corresponds to the tuple of positional arguments, -and `out_specs` corresponds to the output). 
- -```{code-cell} -:id: ehKAYAwIojfv -:outputId: 504bab29-83f3-4e1f-8664-1860ad15b6de - -def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array: - block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) - return pl.pallas_call( - add_matrices_kernel, - out_shape=x, - in_specs=[block_spec, block_spec], - out_specs=block_spec, - grid=(2,) - )(x, y) - -add_matrices_pipelined(x, y) -``` - -+++ {"id": "rkytgIZYzz4t"} - -We've only added a little bit of code to our original function to add -automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy -lifting! - -How does it work? Well, the `BlockSpec`s provide enough information to start -*prefetching* blocks of our input from HBM into VMEM. -For example, if we are starting iteration `i` of our `grid`, we can pass -`i + 1` into the `index_map` functions to obtain the blocks needed for the -next iteration. We can then start an asynchronous copy for those blocks. -Similarly for outputs, we can wait for the outputs of the previous iteration -to be copied before starting the copy for the current iteration's outputs. ++++ {"id": "8Tl3wt5Wk3Ek"} -+++ {"id": "7Xtz9oMs0ZRL"} +## TPU-specific Pipelining Features -### Parameterizing a pipeline +Pallas TPU supports the following platform-specific features. -+++ {"id": "esY4GcIB0bqQ"} ++++ {"id": "1jg5WmExk47l"} -It's common to parameterize the block shapes in our kernel. Block sizes are -perhaps the most important parameter to tune when optimizing the performance -of Pallas kernels! They give us control over the pipeline (for example, -picking smaller blocks adds more iterations to our pipelined loop where each -iteration has less work to do). +### TPU Memory Spaces -Furthermore, we could also carve up the inputs and outputs along the 2nd -dimension (we are only splitting along the first right now). Let's write a -more general kernel that handles both of these features. +Pallas exposes all levels of the TPU memory hierarchy to users. The following table maps from Pallas TPU memory spaces to their standard memory types (DRAM/SRAM): -```{code-cell} -:id: VartelFd0YfY - -def add_matrices_pipelined_2d( - x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256 -) -> jax.Array: - m, n = x.shape - block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j)) - return pl.pallas_call( - add_matrices_kernel, - out_shape=x, - in_specs=[block_spec, block_spec], - out_specs=block_spec, - grid=(m // bm, n // bn), - )(x, y) - -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y -) -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y -) -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y -) -``` - -+++ {"id": "KrfeYwaW1QA-"} - -## Handling reductions +| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) | +| --- | --- | --- | +| `pltpu.TPUMemorySpace.ANY` | HBM (usually) or VMEM | DRAM | +| `pltpu.TPUMemorySpace.VMEM` | VMEM | SRAM | +| `pltpu.TPUMemorySpace.SMEM` | SMEM | SRAM | +| `pltpu.TPUMemorySpace.SEMAPHORE` | Semaphore | SRAM | -+++ {"id": "P3SqEKDe3Mar"} +- `TPUMemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified. +- `TPUMemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM. +- `TPUMemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. 
A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`.
+- `TPUMemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details.
 
-How would you implement something like `jnp.sum` using `pallas_call`?
-Specifically, we'd like to pipeline across the reduction dimension.
+Pipelining on TPUs is typically done between HBM (DRAM) and VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM.
 
-Take the example of reducing a `(8, 512, 512)`-shaped array to a
-`(512, 512)`-shaped one.
-
-```{code-cell}
-:id: JoT-ZKEk1R7l
-:outputId: fd842223-98a5-4e5c-87fc-5dadc94da4fa
-
-x = jnp.ones((8, 512, 512))
-jnp.sum(x, axis=0)
-```
-
-+++ {"id": "5O3ByvuT3iyC"}
-
-To do this using `pallas_call`, we could use a grid of size `(8,)` and in
-each iteration `i` load `x[i]` into VMEM.
-Then we could add `x[i]` to an output VMEM buffer. Let's implement this
-naively first.
-
-```{code-cell}
-:id: hqvv_WRQ3bvP
-:outputId: 200648d2-3f4d-4d1a-b95a-d2c1352cd7b8
-
-# Warning: this implementation is incorrect!
-
-def naive_sum_kernel(x_ref, o_ref):
-  o_ref[...] += x_ref[...]
-
-def naive_sum(x: jax.Array) -> jax.Array:
-  grid, *out_shape = x.shape
-  return pl.pallas_call(
-      naive_sum_kernel,
-      grid=grid,
-      # None in `block_shape` means we pick a size of 1 and squeeze it away
-      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
-      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
-      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
-  )(x)
-naive_sum(x)
-```
+While not specific to pipelining, you can gain manual control over the memory space of input and output buffers by specifying the `memory_space` argument on a `BlockSpec`. Note that pipelining is not allowed unless the `memory_space` is marked as `VMEM`. Memory spaces can also be used to specify scratch arguments to a kernel via the `scratch_shapes` argument on `pallas_call`. Scratch buffers are persistent across kernel iterations and are useful for storing intermediate results such as partial accumulations and reductions. A scratch buffer must reside in `VMEM`, `SMEM`, or `SEMAPHORE`.
 
-+++ {"id": "Kv9qJYJY4jbK"}
-
-Notice how we've set up the `BlockSpec`s: we're loading the entirety of
-the `(512, 512)` dimension into VMEM (no pipelining there) but selecting
-the `i`-th dimension of `x` each iteration in the `index_map`.
-We are using a `None` for that dimension in the block shape, which indicates
-that we are selecting a singleton dimension from `x` that we would like
-to squeeze away in the kernel.
-Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.
-
-`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that
-`o_ref` is unchanged over the course of the pipeline.
-This means that we can update its value each iteration by reading from and
-writing to it. Or can it?
-Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll
-be accumulating into garbage.
-This will result in the overall function outputting the incorrect value!
- -Therefore, **whenever we do a reduction in a kernel, we need to make sure -to initialize the `Ref` that is storing the reduced value**. -We can accomplish this by conditionally writing a value to `out_ref` -when we're on iteration 0. -We can do this with the helper function `pl.when`, a convenience wrapper -around `jax.lax.cond`, and `pl.program_id`, -which queries which iteration in a grid axis we are in. +As an example for using multiple manual memory space assignments in a kernel, the following program copies a slice of an HBM buffer `x_hbm_ref` into a scratch VMEM buffer `scratch_vmem_ref` before using it for arithmetic and storing the result into an output VMEM buffer: ```{code-cell} -:id: JXN2RthX5cSw -:outputId: 195df19b-a889-479b-95b6-1fb7281f1518 - -def sum_kernel(x_ref, o_ref): - @pl.when(pl.program_id(axis=0) == 0) - def _(): - o_ref[...] = jnp.zeros_like(o_ref) - - o_ref[...] += x_ref[...] - -def sum(x: jax.Array) -> jax.Array: - grid, *out_shape = x.shape - return pl.pallas_call( - sum_kernel, - grid=grid, - # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], - out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), - out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype) - )(x) - -sum(x) +--- +executionInfo: + elapsed: 65 + status: ok + timestamp: 1744908591430 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: zcqz1CA_o50a +--- +def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref): + pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref) + out_vmem_ref[...] = scratch_vmem_ref[...] + 1 + +x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) +out = pl.pallas_call(hbm_vmem_kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32), + scratch_shapes=(pltpu.TPUMemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),) +)(x) + +np.testing.assert_allclose(out, x[0:1] + 1) ``` -+++ {"id": "2828qXBI5ksZ"} - -This `sum` function now outputs the correct values! - -One last thing to note about reductions in Pallas are that **they must be -done in the minormost (rightmost) dimensions of our grid** (our grid is -1-dimensional in the above example so we are reducing over its minormost -dimension). This is because the pipeline that Pallas generates using -the `BlockSpec`s, `grid` and kernel function *does not read outputs back -from HBM*. -Once you've written an output value back to HBM you cannot revisit it. -Therefore, you cannot do a reduction across a grid dimension that has any -revisiting and therefore all reductions need to happen in the rightmost -dimensions. - +++ {"id": "KvPFez9N8cKJ"} (pallas_tpu_megacore)= -## TPUs in Megacore configuration +### TPUs in Megacore configuration +++ {"id": "0f4HAVzQ8n71"} @@ -463,8 +161,26 @@ We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`. 
```{code-cell} -:id: nQNa8RaQ-TR1 -:outputId: 385ed87c-d95c-466c-af77-df3845c979f2 +--- +executionInfo: + elapsed: 106 + status: ok + timestamp: 1744910274556 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: nQNa8RaQ-TR1 +outputId: 29c0b574-3528-49a5-8a88-b6987efc69ce +--- +def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref): + # Load x and y from VMEM into VREGs + x_vregs = x_vmem_ref[:, :] + y_vregs = y_vmem_ref[:, :] + # Execute a vectorized add + z_vregs = x_vregs + y_vregs + # Store the output values in VREGs back into VMEM + z_vmem_ref[:, :] = z_vregs def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) @@ -474,7 +190,8 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",)) + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel",)) )(x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) @@ -491,17 +208,3 @@ simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically. > Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available). - -+++ {"id": "1ZJ2rV5W8FAe"} - -## Conclusion - -In this guide we covered how to express TPU pipelines using `pallas_call`, -`grid` and `BlockSpec`s. We covered how to express nested loops via a -multi-dimensional grid and how to handle reductions by initialize our -accumulators at the beginning of the reduction. -We also learned how to handle Megacore by adding annotations to the kernel. - -Exercises left to the reader: -* Try implementing a `sum` kernel that pipelines the other dimensions as well -* Add megacore support to the `add` kernel and the `sum` kernel as well. From 82a79c598593c287e2fcf06707c656e45a6bf307 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 21 Apr 2025 12:26:52 -0400 Subject: [PATCH 0737/1769] Fix handling of SymbolicZero output when batching custom_jvp. 
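
A minimal reproduction, distilled from the regression test added below
(`test_symbolic_zero_under_vmap_of_jit`); it follows that test, including its
use of `jax.custom_derivatives.zero_from_primal`:

```
import jax
import jax.numpy as jnp
from jax import custom_derivatives

@jax.custom_jvp
def f(x):
  return x + 1

@f.defjvp
def f_jvp(primals, tangents):
  (x,) = primals
  # Return a symbolic zero as the tangent; batching must preserve its type
  # (SymbolicZero) rather than rebuild it as a plain Zero.
  z = custom_derivatives.zero_from_primal(x, symbolic_zeros=True)
  return f(x), z

x = jnp.arange(3.0)
jax.jvp(jax.vmap(jax.jit(f)), (x,), (x,))  # crashed before this change
```
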
Co-authored-by: Matthew Johnson --- jax/_src/interpreters/batching.py | 6 +++--- tests/api_test.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index c17fe892325e..200189502db6 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -941,11 +941,11 @@ def _matchaxis_symzeros(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): return x elif type(src) == type(dst) == int: aval = core.mapped_aval(sz, src, x.aval) - return Zero(core.unmapped_aval(sz, dst, aval, mesh_axis)) + return type(x)(core.unmapped_aval(sz, dst, aval, mesh_axis)) elif src is not_mapped and dst is not not_mapped: - return Zero(core.unmapped_aval(sz, dst, x.aval, mesh_axis)) + return type(x)(core.unmapped_aval(sz, dst, x.aval, mesh_axis)) elif dst is not_mapped and sum_match: - return Zero(core.mapped_aval(sz, src, x.aval)) + return type(x)(core.mapped_aval(sz, src, x.aval)) else: raise ValueError((axis_name, x, src, dst)) else: diff --git a/tests/api_test.py b/tests/api_test.py index 0b77580f2839..5122b11c6406 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8375,6 +8375,22 @@ def g(x): jax.jvp(jax.vmap(g), (jnp.ones(3),), (jnp.ones(3),)) # don't crash + def test_symbolic_zero_under_vmap_of_jit(self): + # https://github.com/jax-ml/jax/issues/28144 + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(x, t): + (x,) = x + (t,) = t + z = custom_derivatives_public.zero_from_primal(x, symbolic_zeros=True) + return f(x), z + + x = jnp.arange(3.0) + jax.jvp(jax.vmap(jax.jit(f)), (x,), (x,)) # doesn't crash + class CustomVJPTest(jtu.JaxTestCase): From e81dae611437beb3a4c281d42bca774ac66f7b30 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 22 Apr 2025 12:19:30 -0700 Subject: [PATCH 0738/1769] [Pallas Fuser] Allow multiple BlockSpec inputs to select_n push rule if they are identical PiperOrigin-RevId: 750284947 --- jax/_src/pallas/fuser/block_spec.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 89b5121f5691..19b813a03f10 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -1525,9 +1525,14 @@ def _select_n_push_rule( ): del ctx block_specs = [b for b in args if b is not pallas_core.no_block_spec] + assert len(block_specs) > 0 + block_spec = block_specs[0] if len(block_specs) > 1: - raise NotImplementedError('select_n with multiple inputs not supported yet') - return block_specs[0] + if any(b is not block_spec for b in block_specs): + raise NotImplementedError( + 'select_n with multiple differing inputs not supported yet' + ) + return block_spec @register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p) From 5c22694c1935b660f0712ecd0671c7e03a89cfbb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 22 Apr 2025 12:25:09 -0700 Subject: [PATCH 0739/1769] DOC: link to ai-stack tutorials from JAX's front page --- docs/index.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index aafe38da9f4c..35906f1a5534 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -63,8 +63,9 @@ JAX: High performance array computing :link-type: ref :class-card: user-guides -If you're looking to train neural networks, use Flax_ and start with its tutorials. -For an end-to-end transformer library built on JAX, see MaxText_. 
+If you're looking to use JAX to train neural networks, start with the +`JAX AI Stack Tutorials`_, and then check out the `JAX AI Stack Examples`_ +to see how JAX models can be implemented using the Flax_ framework. Ecosystem --------- @@ -183,6 +184,8 @@ maintains an up-to-date list. .. _Grain: https://github.com/google/grain .. _Hugging Face Datasets: https://huggingface.co/docs/datasets/ .. _JAX MD: https://jax-md.readthedocs.io/ +.. _JAX AI Stack Tutorials: https://docs.jaxstack.ai/en/latest/tutorials.html +.. _JAX AI Stack Examples: https://docs.jaxstack.ai/en/latest/examples.html .. _Keras: https://keras.io/ .. _Levanter: https://github.com/stanford-crfm/levanter .. _Lineax: https://github.com/patrick-kidger/lineax From d52c87235e59c4312815be6ea5ebcbc7d6240aa7 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Thu, 6 Mar 2025 15:33:24 -0600 Subject: [PATCH 0740/1769] Update install docs for AMDGPU Remove experimental tag for linux-x86_64 and point folks to the newly refreshed `build/rocm/README.md` --- docs/installation.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 500347e04ab1..c9bf3a62942b 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -33,7 +33,7 @@ The table below shows all supported platforms and installation options. Check if | CPU | {ref}`yes ` | {ref}`yes ` | {ref}`jax≤0.4.38 only ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | | NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | no | n/a | no | {ref}`experimental ` | | Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | {ref}`experimental ` | no | {ref}`experimental ` | n/a | no | no | +| AMD GPU | {ref}`yes ` | no | {ref}`experimental ` | n/a | no | no | | Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a | | Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no | @@ -226,10 +226,10 @@ refer to (install-amd-gpu)= ## AMD GPU (Linux) -JAX has experimental ROCm support. There are two ways to install JAX: +AMD GPU support is provided by a ROCm JAX plugin supported by AMD. -* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or -* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://docs.jax.dev/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +There are several ways to use JAX on AMDGPU devices. +Please see [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md) for details. (install-intel-gpu)= ## Intel GPU From 3d470f431200538828a9901bd8ecae32ba4d947f Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 22 Apr 2025 12:50:02 -0700 Subject: [PATCH 0741/1769] [Mosaic GPU] Do not create a barrier with `arrival_count == 0` in the pipeline emitter if there's nothing to wait for. Also enforce that `arrival_count` is always > 0. 
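
In sketch form (a standalone distillation of the two changes below;
`make_barrier` is an illustrative helper for exposition, not part of the
Mosaic GPU API):

```
import dataclasses

@dataclasses.dataclass(frozen=True)
class Barrier:
  arrival_count: int
  num_barriers: int = 1

  def __post_init__(self):
    # A barrier that nothing will ever arrive at is a bug at the call site.
    if self.arrival_count < 1:
      raise ValueError(
          f"Arrival count must be at least 1, but got {self.arrival_count}")

def make_barrier(num_smem_inputs: int, max_concurrent_steps: int):
  # If no inputs land in SMEM there is nothing to wait for, so allocate no
  # barrier at all rather than one with arrival_count == 0.
  if num_smem_inputs == 0:
    return None
  return Barrier(num_smem_inputs, num_barriers=max_concurrent_steps)

assert make_barrier(0, 2) is None
assert make_barrier(3, 2).arrival_count == 3
```
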
PiperOrigin-RevId: 750294068 --- jax/_src/pallas/mosaic_gpu/core.py | 5 +++++ jax/_src/pallas/mosaic_gpu/pipeline.py | 11 +++++++---- jax/experimental/mosaic/gpu/core.py | 6 ++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index bb08d8f090a7..5015ebdb7e1f 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -624,6 +624,11 @@ def get_ref_aval(self) -> AbstractMemoryRef: ) return AbstractMemoryRef(aval, SMEM) + def __post_init__(self): + if self.num_arrivals < 1: + raise ValueError( + f"Num arrivals must be at least 1, but got {self.num_arrivals}" + ) @dataclasses.dataclass(frozen=True) class ClusterBarrier: diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 1aa75eb0bc0e..f7becf1ec6da 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -227,6 +227,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): ], [len(in_specs)], ) + arrival_count = sum(map(_in_smem, in_specs)) return pl.run_scoped( functools.partial( scoped_pipeline, @@ -235,9 +236,11 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): ), in_smem_refs=in_smem_refs, out_smem_refs=out_smem_refs, - barrier_ref=gpu_core.Barrier( + barrier_ref=None + if arrival_count == 0 + else gpu_core.Barrier( # TODO(slebedev): Change this to arrive only once. - sum(map(_in_smem, in_specs)), + arrival_count, num_barriers=max_concurrent_steps, ), ) @@ -274,8 +277,8 @@ def loop_body(step, carry): slot = lax.rem(step, max_concurrent_steps) indices, fetch_indices, last_store_slices = carry - if in_specs: - # Wait for the current GMEM->SMEM copy to complete. + if barrier_ref is not None: + # Wait for the current GMEM->SMEM copy to complete, if any. gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. 
if copies_out_in_loop: diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 2161d0d6d725..8e240d55d4cc 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -188,6 +188,12 @@ class Barrier: arrival_count: int num_barriers: int = 1 + def __post_init__(self): + if self.arrival_count < 1: + raise ValueError( + f"Arrival count must be at least 1, but got {self.arrival_count}" + ) + @dataclasses.dataclass(frozen=True) class ClusterBarrier: collective_dims: Sequence[gpu.Dimension] From 54862961e9c835521e1161cf578d013fefbe9fca Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 22 Apr 2025 12:57:44 -0700 Subject: [PATCH 0742/1769] [Pallas Fuser] Change physicalize to resolve_fusion_dtypes PiperOrigin-RevId: 750296702 --- jax/_src/pallas/fuser/fusible_dtype.py | 4 ++-- jax/_src/pallas/fuser/jaxpr_fusion.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 628d253e090a..7d9c2ca67855 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -300,9 +300,9 @@ def _pallas_call_physicalize_rule( def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _assert_no_fusion_types(ctx.avals_out) - physicalized_branches = [ + physicalized_branches = tuple( physicalize_closed_jaxpr(branch) for branch in branches - ] + ) flat_args = jax.tree.leaves(args) return conditionals.cond_p.bind( *flat_args, branches=physicalized_branches, **kwargs diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 95768d71f792..d1e375e33ef1 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -28,12 +28,13 @@ from jax._src.pallas.fuser.fusible import fusible_p -def fuse(f=None, *, physicalize: bool = False, debug: bool = False): +def fuse(f=None, *, resolve_fusion_dtypes: bool = True, debug: bool = False): """Fuses a function into a single fusible. Args: f: The function to fuse. - physicalize: (experimental) whether to physicalize the function. + resolve_fusion_dtypes: (experimental) whether or not to resolve fusion + dtypes (which don't correspond to physical dtypes) debug: Whether to print debug information. There should be a single call to a `fusible` inside the body of `f`. 
`fuse` @@ -57,7 +58,7 @@ def wrapper(*args, **kwargs): out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args) return tree_util.tree_unflatten(out_tree, out_flat) - if physicalize: + if resolve_fusion_dtypes: wrapper = fusible_dtype.physicalize(wrapper) return wrapper From 1727657d523a624da8c46900839fe10a7f8adedc Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 22 Apr 2025 13:05:48 -0700 Subject: [PATCH 0743/1769] [Pallas] Propagate Jaxpr effects through pl.fusible PiperOrigin-RevId: 750299496 --- jax/_src/pallas/fuser/fusible.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/fuser/fusible.py b/jax/_src/pallas/fuser/fusible.py index 289a9dc268b4..f0d03cb18d94 100644 --- a/jax/_src/pallas/fuser/fusible.py +++ b/jax/_src/pallas/fuser/fusible.py @@ -80,7 +80,7 @@ def _(*consts_and_args, jaxpr, num_consts, **_): mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl)) -@fusible_p.def_abstract_eval +@fusible_p.def_effectful_abstract_eval def _(*args, jaxpr, **kwargs): del args, kwargs - return [v.aval for v in jaxpr.outvars] + return [v.aval for v in jaxpr.outvars], jaxpr.effects From 1c9ef1dbc629a5832780c0f071528be89b47cd33 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 22 Apr 2025 13:15:04 -0700 Subject: [PATCH 0744/1769] Set core_index to default for tpu_pallas_async_test PiperOrigin-RevId: 750302979 --- tests/pallas/tpu_pallas_async_test.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index 36aba917e4ed..f7588bc0776c 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -412,7 +412,7 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): src_neighbor = right_neighbor dst_neighbor = left_neighbor barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor, core_index=0) + pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor) pltpu.semaphore_wait(barrier_sem, 1) pltpu.make_async_remote_copy( x_ref, o_ref, send_sem, recv_sem, device_id=dst_neighbor, @@ -500,10 +500,8 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): jax.lax.axis_index(axis_name) + 1, axis_size ) barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor, core_index=0) - pltpu.semaphore_signal( - barrier_sem, device_id=right_neighbor, core_index=0 - ) + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) pltpu.semaphore_wait(barrier_sem, 2) assert x.shape[0] % 2 == 0, x.shape pltpu.make_async_remote_copy( From 9425e1c20788a57fb8fdd497d253c120b3f134f7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 22 Apr 2025 17:51:46 -0400 Subject: [PATCH 0745/1769] Update tsan requirements patch after lockfile update. 
--- .../workflows/requirements_lock_3_13_ft.patch | 124 +++++++++--------- 1 file changed, 63 insertions(+), 61 deletions(-) diff --git a/.github/workflows/requirements_lock_3_13_ft.patch b/.github/workflows/requirements_lock_3_13_ft.patch index 0b63cb5b8711..7e45fe2b3e26 100644 --- a/.github/workflows/requirements_lock_3_13_ft.patch +++ b/.github/workflows/requirements_lock_3_13_ft.patch @@ -1,85 +1,87 @@ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt -index e7a2968e9..d37e11ee3 100644 +index 7fce0eef6..06e2cc5d4 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt -@@ -4,6 +4,11 @@ +@@ -4,6 +4,12 @@ # - # pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in + # bazel run //build:requirements_ft.update # + +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy ++ + absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff -@@ -328,68 +333,6 @@ mpmath==1.3.0 \ +@@ -328,68 +334,7 @@ mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt --numpy==2.2.1 ; python_version >= "3.13" \ -- --hash=sha256:059e6a747ae84fce488c3ee397cee7e5f905fd1bda5fb18c66bc41807ff119b2 \ -- --hash=sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5 \ -- --hash=sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60 \ -- --hash=sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71 \ -- --hash=sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631 \ -- --hash=sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8 \ -- --hash=sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2 \ -- --hash=sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16 \ -- --hash=sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa \ -- --hash=sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591 \ -- --hash=sha256:3d03883435a19794e41f147612a77a8f56d4e52822337844fff3d4040a142964 \ -- --hash=sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821 \ -- --hash=sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484 \ -- --hash=sha256:4250888bcb96617e00bfa28ac24850a83c9f3a16db471eca2ee1f1714df0f957 \ -- --hash=sha256:4511d9e6071452b944207c8ce46ad2f897307910b402ea5fa975da32e0102800 \ -- --hash=sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918 \ -- --hash=sha256:48fd472630715e1c1c89bf1feab55c29098cb403cc184b4859f9c86d4fcb6a95 \ -- --hash=sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0 \ -- --hash=sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e \ -- --hash=sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d \ -- --hash=sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73 \ -- --hash=sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59 \ -- --hash=sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51 \ -- --hash=sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355 \ -- --hash=sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348 
\ -- --hash=sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e \ -- --hash=sha256:5edb4e4caf751c1518e6a26a83501fda79bff41cc59dac48d70e6d65d4ec4440 \ -- --hash=sha256:61048b4a49b1c93fe13426e04e04fdf5a03f456616f6e98c7576144677598675 \ -- --hash=sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84 \ -- --hash=sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046 \ -- --hash=sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab \ -- --hash=sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712 \ -- --hash=sha256:7671dc19c7019103ca44e8d94917eba8534c76133523ca8406822efdd19c9308 \ -- --hash=sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315 \ -- --hash=sha256:7ba9cc93a91d86365a5d270dee221fdc04fb68d7478e6bf6af650de78a8339e3 \ -- --hash=sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008 \ -- --hash=sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5 \ -- --hash=sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2 \ -- --hash=sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e \ -- --hash=sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7 \ -- --hash=sha256:a7746f235c47abc72b102d3bce9977714c2444bdfaea7888d241b4c4bb6a78bf \ -- --hash=sha256:aa3017c40d513ccac9621a2364f939d39e550c542eb2a894b4c8da92b38896ab \ -- --hash=sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd \ -- --hash=sha256:b541032178a718c165a49638d28272b771053f628382d5e9d1c93df23ff58dbf \ -- --hash=sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8 \ -- --hash=sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb \ -- --hash=sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268 \ -- --hash=sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d \ -- --hash=sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780 \ -- --hash=sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716 \ -- --hash=sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e \ -- --hash=sha256:f62aa6ee4eb43b024b0e5a01cf65a0bb078ef8c395e8713c6e8a12a697144528 \ -- --hash=sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af \ -- --hash=sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7 \ -- --hash=sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51 +-numpy==2.2.5 \ +- --hash=sha256:0255732338c4fdd00996c0421884ea8a3651eea555c3a56b84892b66f696eb70 \ +- --hash=sha256:02f226baeefa68f7d579e213d0f3493496397d8f1cff5e2b222af274c86a552a \ +- --hash=sha256:059b51b658f4414fff78c6d7b1b4e18283ab5fa56d270ff212d5ba0c561846f4 \ +- --hash=sha256:0bcb1d057b7571334139129b7f941588f69ce7c4ed15a9d6162b2ea54ded700c \ +- --hash=sha256:0cd48122a6b7eab8f06404805b1bd5856200e3ed6f8a1b9a194f9d9054631beb \ +- --hash=sha256:19f4718c9012e3baea91a7dba661dcab2451cda2550678dc30d53acb91a7290f \ +- --hash=sha256:1a161c2c79ab30fe4501d5a2bbfe8b162490757cf90b7f05be8b80bc02f7bb8e \ +- --hash=sha256:1f4a922da1729f4c40932b2af4fe84909c7a6e167e6e99f71838ce3a29f3fe26 \ +- --hash=sha256:261a1ef047751bb02f29dfe337230b5882b54521ca121fc7f62668133cb119c9 \ +- --hash=sha256:262d23f383170f99cd9191a7c85b9a50970fe9069b2f8ab5d786eca8a675d60b \ +- --hash=sha256:2ba321813a00e508d5421104464510cc962a6f791aa2fca1c97b1e65027da80d \ +- --hash=sha256:2c1a1c6ccce4022383583a6ded7bbcda22fc635eb4eb1e0a053336425ed36dfa \ +- 
--hash=sha256:352d330048c055ea6db701130abc48a21bec690a8d38f8284e00fab256dc1376 \ +- --hash=sha256:369e0d4647c17c9363244f3468f2227d557a74b6781cb62ce57cf3ef5cc7c610 \ +- --hash=sha256:36ab5b23915887543441efd0417e6a3baa08634308894316f446027611b53bf1 \ +- --hash=sha256:37e32e985f03c06206582a7323ef926b4e78bdaa6915095ef08070471865b906 \ +- --hash=sha256:3a801fef99668f309b88640e28d261991bfad9617c27beda4a3aec4f217ea073 \ +- --hash=sha256:3d14b17b9be5f9c9301f43d2e2a4886a33b53f4e6fdf9ca2f4cc60aeeee76372 \ +- --hash=sha256:422cc684f17bc963da5f59a31530b3936f57c95a29743056ef7a7903a5dbdf88 \ +- --hash=sha256:4520caa3807c1ceb005d125a75e715567806fed67e315cea619d5ec6e75a4191 \ +- --hash=sha256:47834cde750d3c9f4e52c6ca28a7361859fcaf52695c7dc3cc1a720b8922683e \ +- --hash=sha256:47f9ed103af0bc63182609044b0490747e03bd20a67e391192dde119bf43d52f \ +- --hash=sha256:498815b96f67dc347e03b719ef49c772589fb74b8ee9ea2c37feae915ad6ebda \ +- --hash=sha256:54088a5a147ab71a8e7fdfd8c3601972751ded0739c6b696ad9cb0343e21ab73 \ +- --hash=sha256:55f09e00d4dccd76b179c0f18a44f041e5332fd0e022886ba1c0bbf3ea4a18d0 \ +- --hash=sha256:5a0ac90e46fdb5649ab6369d1ab6104bfe5854ab19b645bf5cda0127a13034ae \ +- --hash=sha256:6411f744f7f20081b1b4e7112e0f4c9c5b08f94b9f086e6f0adf3645f85d3a4d \ +- --hash=sha256:6413d48a9be53e183eb06495d8e3b006ef8f87c324af68241bbe7a39e8ff54c3 \ +- --hash=sha256:7451f92eddf8503c9b8aa4fe6aa7e87fd51a29c2cfc5f7dbd72efde6c65acf57 \ +- --hash=sha256:8b4c0773b6ada798f51f0f8e30c054d32304ccc6e9c5d93d46cb26f3d385ab19 \ +- --hash=sha256:8dfa94b6a4374e7851bbb6f35e6ded2120b752b063e6acdd3157e4d2bb922eba \ +- --hash=sha256:97c8425d4e26437e65e1d189d22dff4a079b747ff9c2788057bfb8114ce1e133 \ +- --hash=sha256:9d75f338f5f79ee23548b03d801d28a505198297534f62416391857ea0479571 \ +- --hash=sha256:9de6832228f617c9ef45d948ec1cd8949c482238d68b2477e6f642c33a7b0a54 \ +- --hash=sha256:a4cbdef3ddf777423060c6f81b5694bad2dc9675f110c4b2a60dc0181543fac7 \ +- --hash=sha256:a9c0d994680cd991b1cb772e8b297340085466a6fe964bc9d4e80f5e2f43c291 \ +- --hash=sha256:aa70fdbdc3b169d69e8c59e65c07a1c9351ceb438e627f0fdcd471015cd956be \ +- --hash=sha256:abe38cd8381245a7f49967a6010e77dbf3680bd3627c0fe4362dd693b404c7f8 \ +- --hash=sha256:b13f04968b46ad705f7c8a80122a42ae8f620536ea38cf4bdd374302926424dd \ +- --hash=sha256:b4ea7e1cff6784e58fe281ce7e7f05036b3e1c89c6f922a6bfbc0a7e8768adbe \ +- --hash=sha256:b6f91524d31b34f4a5fee24f5bc16dcd1491b668798b6d85585d836c1e633a6a \ +- --hash=sha256:c26843fd58f65da9491165072da2cccc372530681de481ef670dcc8e27cfb066 \ +- --hash=sha256:c42365005c7a6c42436a54d28c43fe0e01ca11eb2ac3cefe796c25a5f98e5e9b \ +- --hash=sha256:c8b82a55ef86a2d8e81b63da85e55f5537d2157165be1cb2ce7cfa57b6aef38b \ +- --hash=sha256:ced69262a8278547e63409b2653b372bf4baff0870c57efa76c5703fd6543282 \ +- --hash=sha256:d2e3bdadaba0e040d1e7ab39db73e0afe2c74ae277f5614dad53eadbecbbb169 \ +- --hash=sha256:d403c84991b5ad291d3809bace5e85f4bbf44a04bdc9a88ed2bb1807b3360bb8 \ +- --hash=sha256:d7543263084a85fbc09c704b515395398d31d6395518446237eac219eab9e55e \ +- --hash=sha256:d8882a829fd779f0f43998e931c466802a77ca1ee0fe25a3abe50278616b1471 \ +- --hash=sha256:e4f0b035d9d0ed519c813ee23e0a733db81ec37d2e9503afbb6e54ccfdee0fa7 \ +- --hash=sha256:e8b025c351b9f0e8b5436cf28a07fa4ac0204d67b38f01433ac7f9b870fa38c6 \ +- --hash=sha256:eb7fd5b184e5d277afa9ec0ad5e4eb562ecff541e7f60e69ee69c8d59e9aeaba \ +- --hash=sha256:ec31367fd6a255dc8de4772bd1658c3e926d8e860a0b6e922b615e532d320ddc \ +- --hash=sha256:ee461a4eaab4f165b68780a6a1af95fb23a29932be7569b9fab666c407969051 \ +- 
--hash=sha256:f5045039100ed58fa817a6227a356240ea1b9a1bc141018864c306c1a16d4175
- # via
-- # -r build/freethreading-requirements.txt
- # contourpy
- # matplotlib
- # ml-dtypes
- # scipy
++
 nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
 --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
 --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3
From 7b1f6d1681249e50af133fa8dcc65032a07937d7 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 22 Apr 2025 15:12:42 -0700
Subject: [PATCH 0746/1769] Initial commit of non-experimental `shard_map`.

The signature is:

```
jax.shard_map(f, /, *, out_specs, axis_names=set(), in_specs=None, mesh=None, check_vma=True)
```

This API is a drop-in replacement for the experimental shard_map endpoint with just two small changes: `check_rep` is renamed to `check_vma`, and all arguments (except `f`) to `shard_map` are keyword-only while `f` is positional-only.

**But why are mesh and in_specs optional? And what is the new `axis_names` argument?**

* `mesh` is optional because it can be inferred from the context if the user sets the mesh via `jax.sharding.use_mesh(mesh)`.

* `in_specs` is optional because it can be inferred from the arguments passed to `shard_map` if all mesh axes are `Explicit`.

* `axis_names`: `axis_names` tells `shard_map` which axes are `Manual`. If empty, it implies the `shard_map` is `Manual` over all mesh axes. Previously, in the experimental endpoint of `shard_map`, this argument was called `auto`. But after the advent of `sharding_in_types`, mesh axes can be `Auto`, `Explicit` or `Manual`, so `auto` was no longer precise enough: the non-manual axes can be `Explicit` too. That's why `jax.shard_map` flips the argument to `axis_names`.

**If `in_specs` is optional, why is `out_specs` compulsory?**

This is because we still need to know which dimension to concatenate over; it can't be inferred automatically since the choice can be anything.

PiperOrigin-RevId: 750343135
---
 jax/experimental/shard_map.py | 108 +++++++++++++++++++++++++++------
 tests/shard_map_test.py       |  95 ++++++++++++++++++++++++++----
 2 files changed, 172 insertions(+), 31 deletions(-)

diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 878c42a8d591..52dc45b68ef0 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -127,37 +127,85 @@ def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs,
   ..
_SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html """ - return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto) - -def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, - out_specs: Specs | Callable[[], Specs], - check_rep: bool, auto: frozenset[AxisName]): + axis_names = frozenset(mesh.axis_names) - auto + return shard_map2(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, + check_vma=check_rep, axis_names=axis_names) + + +def shard_map2(f, /, *, out_specs: Specs, + axis_names: set[AxisName] | frozenset[AxisName] = set(), + in_specs: Specs | None = None, + mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True): + return _shard_map(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, + axis_names=axis_names, check_vma=check_vma) + +def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, + in_specs: Specs, out_specs: Specs | Callable[[], Specs], + axis_names: set[AxisName] | frozenset[AxisName], + check_vma: bool): if not callable(f): raise TypeError("shard_map requires a callable for its first argument, " f"but got {f} of type {type(f)}.") + + if mesh is None: + mesh = get_abstract_mesh() + if mesh.empty: + raise ValueError( + "The context mesh cannot be empty. Either use" + " `jax.sharding.use_mesh(mesh)` to enter into a mesh context or pass" + " a mesh to `shard_map` via the `mesh` keyword argument.") if not isinstance(mesh, (Mesh, AbstractMesh)): raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " "`jax.sharding.AbstractMesh` instance for its " f"second argument, but got {mesh} of type {type(mesh)}.") + + if not isinstance(axis_names, (frozenset, set)): + raise TypeError( + "`axis_names` argument of shard_map should be of type `frozenset` or" + f" `set`. Got type: {type(axis_names)}") + if isinstance(axis_names, set): + axis_names = frozenset(axis_names) + if not axis_names: + axis_names = frozenset(mesh.axis_names) + auto = frozenset(mesh.axis_names) - frozenset(axis_names) if not auto.issubset(mesh.axis_names): raise ValueError(f"shard_map requires auto={auto} to be a subset of " f"mesh.axis_names={mesh.axis_names}") - _check_specs(SpecErrorType.input, in_specs, auto) + + if in_specs is not None: + _check_specs(SpecErrorType.input, in_specs, auto) if not callable(out_specs): _check_specs(SpecErrorType.out, out_specs, auto) @util.wraps(f) @traceback_util.api_boundary def wrapped(*args): - fun = lu.wrap_init(f, - debug_info=api_util.debug_info("shard_map", f, args, {})) + fun = lu.wrap_init( + f, debug_info=api_util.debug_info("shard_map", f, args, {})) args_flat, in_tree = tree_flatten(args) fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) - try: in_specs_flat = broadcast_prefix(in_specs, args, - is_leaf=lambda x: x is None) + + # TODO(yashkatariya): Maybe we don't have to be this strict? + if mesh._any_axis_auto and in_specs is None: + raise TypeError( + "shard_map in_specs argument must be a pytree of" + " `jax.sharding.PartitionSpec` instances, but it was None when mesh" + f" {mesh} has `Auto` axes.\n") + + try: + in_specs_flat = broadcast_prefix( + in_specs, args, is_leaf=lambda x: x is None) except ValueError: e, *_ = prefix_errors(in_specs, args) raise e('shard_map in_specs') from None + + # TODO(yashkatariya): Relax this and convert only `None`s in `in_specs_flat` + # and accept the other specs as is. 
+ if mesh._are_all_axes_explicit and in_specs is None: + arg_s = [typeof(a).sharding for a in args_flat] + assert all(i is None for i in in_specs_flat), in_specs_flat + in_specs_flat = [_manual_spec(axis_names, s.spec) for s in arg_s] + dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) if s is not None) fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) @@ -172,19 +220,20 @@ def out_names_thunk(): else: out_specs_ = out_specs dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) - try: out_specs_flat = broadcast_prefix(out_specs_, dummy) + try: + out_specs_flat = broadcast_prefix(out_specs_, dummy) except ValueError: e, *_ = prefix_errors(out_specs_, dummy) raise e('shard_map out_specs') from None return tuple(map(_canonicalize_spec, out_specs_flat)) - if check_rep: + if check_vma: fun = _implicit_pvary_on_output(fun, out_names_thunk) try: out_flat = shard_map_p.bind( fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_rep, auto=auto) + out_names_thunk=out_names_thunk, check_rep=check_vma, auto=auto) except _SpecError as e: fails, = e.args if not callable(out_specs): @@ -202,6 +251,7 @@ def out_names_thunk(): return tree_unflatten(out_tree(), out_flat) return wrapped + # Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: @@ -211,6 +261,23 @@ def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: else: return spec +def _manual_spec(manual_axes, spec: P) -> P: + out = [] # type: ignore + for s in spec: + if s is None: + out.append(s) + elif isinstance(s, tuple): + temp = [p if p in manual_axes else None for p in s] + while temp and temp[-1] is None: + temp.pop() + if None in temp: + raise ValueError(f"Invalid spec: {spec}") + out.append(None if len(temp) == 0 else tuple(temp)) + else: + out.append(s if s in manual_axes else None) + return P(*out) + + # Error checking and messages SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) @@ -259,7 +326,7 @@ class NoFail: pass no_fail = NoFail() def _check_specs_vs_args( - f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs, + f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs, dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], xs: Sequence) -> None: in_avals = map(core.shaped_abstractify, xs) @@ -331,7 +398,7 @@ def _spec_rank_error( return msg def _spec_divisibility_error( - f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, + f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, fails: list[core.ShapedArray | NoFail]) -> str: ba = _try_infer_args(f, tree) fun_name = getattr(f, '__name__', str(f)) @@ -377,8 +444,8 @@ def _spec_divisibility_error( f"padding the input and adapting '{fun_name}' appropriately.") return msg -def _inout_rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, - fails: list[set | NoFail]) -> str: +def _inout_rep_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, + specs: Specs, fails: list[set | NoFail]) -> str: fun_name = getattr(f, '__name__', str(f)) msgs = [] for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): @@ -413,7 +480,7 @@ def _inout_rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, "check_rep=False argument to shard_map.") return msg -def _unmentioned(mesh: Mesh, names: AxisNames) -> 
list[AxisName]: +def _unmentioned(mesh: Mesh | AbstractMesh, names: AxisNames) -> list[AxisName]: name_set = {n for ns in names.values() for n in ns} return [n for n in mesh.axis_names if n not in name_set] @@ -1621,8 +1688,9 @@ def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat)) out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk()) fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk) - return (_shard_map(fun.call_wrapped, mesh, in_specs, out_specs, - check_rep=False, auto=frozenset()), + return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=False, + axis_names=set(mesh.axis_names)), in_specs, out_specs) @lu.transformation2 diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 6235ae3b60ec..2c2d84ca03ef 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -40,7 +40,7 @@ from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, get_abstract_mesh from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util @@ -48,7 +48,7 @@ import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning -from jax.experimental.shard_map import shard_map +from jax.experimental.shard_map import shard_map, shard_map2 config.parse_flags_with_absl() @@ -2007,10 +2007,7 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - auto=frozenset({'j'}))(x) + x = shard_map2(g, out_specs=P('i', None), axis_names=frozenset({'i'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x @@ -2052,10 +2049,8 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, - in_specs=P('i', 'j', None, None), - out_specs=P('i', 'j', None, None), - auto=frozenset({'k', 'l'}))(x) + x = shard_map2(g, out_specs=P('i', 'j', None, None), + axis_names=frozenset({'i', 'j'}))(x) self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None)) return x @@ -2198,7 +2193,7 @@ def f(x): v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - with self.assertRaisesRegex(ValueError, "to be a subset of mesh.axis_names"): + with self.assertRaisesRegex(ValueError, "contains a manual axes.*of mesh"): f(v) def test_partial_auto_error_wrong_in_specs(self): @@ -2934,6 +2929,84 @@ def body(carry, _): g(x, y, z) # doesn't crash + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_shmap2_full_manual_context_explicit(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(shard_map2, out_specs=P('x', 'y')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_shmap2_partial_manual_explicit(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(shard_map2, axis_names=frozenset('x'), out_specs=P('x')) + def 
f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().explicit_axes, ('y',)) + self.assertEqual(x.aval.sharding.spec, P(None, 'y')) + self.assertEqual(x.aval.vma, {'x'}) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P(None, 'y')) + self.assertEqual(out.aval.vma, {'x'}) + return out + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_shmap2_full_manual_context_auto(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(shard_map2, in_specs=P('x', 'y'), out_specs=P('x', 'y')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_shmap2_partial_manual_auto(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(shard_map2, axis_names=frozenset('x'), in_specs=P('x'), + out_specs=P('x')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y',)) + self.assertEqual(x.aval.vma, {'x'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x'}) + return out + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + + def test_no_mesh_context_error(self): + with self.assertRaisesRegex(ValueError, "The context mesh cannot be empty"): + shard_map2(lambda x: x, in_specs=P(), out_specs=P())(np.arange(8)) + class FunSpec(NamedTuple): name: str From 14399f381488e7b83cb0caa43a0843d065d3cb4d Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 10 Apr 2025 15:46:54 +0000 Subject: [PATCH 0747/1769] Account for versioned clang binaries --- build/tools/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/build/tools/utils.py b/build/tools/utils.py index c52b89a1e6d2..4bf871067501 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -206,8 +206,13 @@ def get_clangpp_path(clang_path): clang_path = pathlib.Path(clang_path) clang_exec_name = clang_path.name clangpp_exec_name = clang_exec_name - if "clang++" not in clang_exec_name: - clangpp_exec_name = clang_exec_name.replace("clang", "clang++") + clangpp_path = clang_path.parent / clang_exec_name + # Try and match what the user passed in (either clang-18 or clang) + if "clang++" not in clangpp_exec_name: + clangpp_exec_name = clangpp_exec_name.replace("clang", "clang++") + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + clangpp_exec_name = "clang++" clangpp_path = clang_path.parent / clangpp_exec_name if not clangpp_path.exists(): raise FileNotFoundError( From 4adeec0feeb714e663d554a93432198bff9ea6a0 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 22 Apr 2025 16:41:41 -0700 Subject: [PATCH 0748/1769] [ragged-paged-attn] Autotune set up and increase page size to avoid SREG spill. 
Taking new factors into account for auto tuning:

- q_dtype_name
- kv_dtype_name
- num_q_heads_per_blk
- num_kv_heads_per_blk
- head_dim
- page_size
- max_num_batched_tokens
- max_model_len = page_size * pages_per_seq

We only have 32 SREGs in the TensorCore. If the page size is small, we can
easily spill SREGs. This CL suggests using `page_size = max_model_len // 16`,
which makes sure that at most 16 SREGs are used for the KV page indices of
each sequence.

PiperOrigin-RevId: 750370022
---
 .../ops/tpu/ragged_paged_attention/kernel.py |  15 +-
 .../tuned_block_sizes.py                     | 455 ++++++++++++++++--
 2 files changed, 414 insertions(+), 56 deletions(-)

diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py
index a74727adff64..3500ba3ee9fd 100644
--- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py
+++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py
@@ -751,17 +751,24 @@ def ragged_paged_attention(
   assert num_combined_kv_heads % 2 == 0
   num_kv_heads = num_combined_kv_heads // 2
   _, pages_per_seq = page_indices.shape
+  num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
+      num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype
+  )
   num_q_per_blk = num_queries_per_block
   num_kv_pages_per_blk = num_kv_pages_per_block
   if num_q_per_blk is None or num_kv_pages_per_blk is None:
     num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes(
-        num_q_heads, num_kv_heads, num_q_tokens, page_size, pages_per_seq
+        q.dtype,
+        kv_pages.dtype,
+        num_q_heads_per_blk,
+        num_combined_kv_heads_per_blk // 2,
+        head_dim,
+        page_size,
+        num_q_tokens,
+        pages_per_seq,
     )
   num_q_heads_per_kv_head = num_q_heads // num_kv_heads
   num_q_blks = cdiv(num_q_tokens, num_q_per_blk)
-  num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
-      num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype
-  )
   assert num_combined_kv_heads_per_blk % 2 == 0
   num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
   assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0

diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py
index 85f22f58ae3f..df2f1c4ea83f 100644
--- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py
+++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py
@@ -15,42 +15,347 @@
 """Auto-tuned block sizes for ragged paged attention."""
 
 import jax
+import jax.numpy as jnp
 
+# The page size should not be too small: we only have 32 SREGs in the TC. If
+# the pages per seq is too large, SREGs will spill.
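+# (Illustrative arithmetic, not taken from the tuned tables: with
+# max_model_len = 4096, the suggested page_size is 4096 // 16 = 256, so one
+# sequence needs 4096 / 256 = 16 page indices, which fit in at most 16 SREGs.)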
+MAX_PAGES_PER_SEQ = 16 -# TODO: add more tuned block sizes in the table -# ragged_paged_attention -# key: (num_q_head, num_kv_head, num_q_tokens, max_model_len) -# value: (num_kv_pages_per_block, num_queries_per_block) +# key: +# - q_dtype_name +# - kv_dtype_name +# - num_q_heads_per_blk +# - num_kv_heads_per_blk +# - head_dim +# - page_size +# - max_num_batched_tokens +# - max_model_len = page_size * pages_per_seq +# value: +# - num_kv_pages_per_block +# - num_queries_per_block TUNED_BLOCK_SIZES = { - # go/keep-sorted start - (1, 1, 1024, 128): (32, 32), - (1, 1, 1024, 2048): (64, 32), - (1, 1, 1024, 4096): (64, 32), - (1, 1, 1024, 64): (32, 32), - (32, 8, 1024, 128): (32, 32), - (32, 8, 1024, 2048): (64, 32), - (32, 8, 1024, 4096): (64, 32), - (32, 8, 1024, 64): (32, 32), - (32, 8, 2048, 128): (32, 32), - (32, 8, 2048, 2048): (128, 32), - (32, 8, 2048, 4096): (128, 32), - (32, 8, 2048, 64): (32, 32), - (32, 8, 4096, 128): (32, 32), - (32, 8, 4096, 2048): (128, 64), - (32, 8, 4096, 4096): (128, 64), - (32, 8, 4096, 64): (32, 32), - (4, 1, 2048, 128): (32, 32), - (4, 1, 2048, 2048): (128, 64), - (4, 1, 2048, 4096): (128, 64), - (4, 1, 2048, 64): (32, 32), - (4, 1, 4096, 128): (32, 32), - (4, 1, 4096, 2048): (128, 128), - (4, 1, 4096, 4096): (128, 128), - (4, 1, 4096, 64): (32, 32), - # go/keep-sorted end + 'TPU v6': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 2048): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 
'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 
2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 
64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, + 'TPU v5': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 32), + 
('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 
32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 
64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, } - def next_power_of_2(x: int): """Finds the smallest power of 2 >= x using bit manipulation. @@ -66,10 +371,28 @@ def next_power_of_2(x: int): return 1 << (x - 1).bit_length() -def simplify_key(num_q_head, num_kv_head, num_q_tokens, max_model_len): - num_q_tokens = next_power_of_2(num_q_tokens) - max_model_len = next_power_of_2(max_model_len) - return num_q_head, num_kv_head, num_q_tokens, max_model_len +def simplify_key(key): + """Simplify the key to reduce the number of combinations.""" + ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) = key + return ( + jnp.dtype(q_dtype).name, + jnp.dtype(kv_dtype).name, + next_power_of_2(num_q_heads_per_blk), + next_power_of_2(num_kv_heads_per_blk), + (head_dim + 127) // 128 * 128, + next_power_of_2(page_size), + next_power_of_2(max_num_batched_tokens), + next_power_of_2(page_size * pages_per_seq), + ) def get_tpu_version() -> int: @@ -83,24 +406,52 @@ def get_tpu_version() -> int: return int(kind[-1]) +def get_device_name(num_devices:int | None = None): + name = ' '.join(jax.devices()[0].device_kind.split()[:2]) + if num_devices is not None: + name += f'-{num_devices}' + return name + + def get_tuned_block_sizes( - num_q_head, num_kv_head, num_q_tokens, page_size, pages_per_seq + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, ) -> tuple[int, int]: - """Searchs for best (num_kv_pages_per_blk, num_queries_per_blk).""" - if get_tpu_version() < 4: - raise NotImplementedError("TPU version must be 4 or higher.") - if get_tpu_version() == 4: + """Look up for the best (num_kv_pages_per_blk, num_queries_per_blk) from auto-tuned table.""" + tpu_version = get_tpu_version() + if tpu_version < 4: + raise NotImplementedError('TPU version must be 4 or higher.') + key = ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) + key = simplify_key(key) + device_name = get_device_name() + + # Default block sizes. 
+ bkv, bq = (128, 32) + if tpu_version == 4: # This default block size is not tuned, only make sure there's no # OOM in vmem - num_kv_pages_per_blk = 16 - num_queries_per_blk = 128 - return num_kv_pages_per_blk, num_queries_per_blk - - max_model_len = pages_per_seq * page_size - key = simplify_key(num_q_head, num_kv_head, num_q_tokens, max_model_len) - num_kv_pages_per_blk, num_queries_per_blk = TUNED_BLOCK_SIZES.get( - key, (128, 32) - ) - num_kv_pages_per_blk = min(num_kv_pages_per_blk, pages_per_seq) - num_queries_per_blk = min(num_queries_per_blk, num_q_tokens) - return num_kv_pages_per_blk, num_queries_per_blk + bkv, bq = (32, 32) + elif device_name in TUNED_BLOCK_SIZES: + if key in TUNED_BLOCK_SIZES[device_name]: + bkv, bq = TUNED_BLOCK_SIZES[device_name][key] + return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq)) + + +def get_min_page_size(max_model_len, min_page_size=16): + """Recommended min page size for high-performance kernel.""" + return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size) From 57d1df4d69ae65e784f239f7408ab3abdfdad553 Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Tue, 22 Apr 2025 17:35:20 -0700 Subject: [PATCH 0749/1769] Skip unary_ops_accuracy test for TPU version 7 and above. PiperOrigin-RevId: 750385718 --- tests/unary_ops_accuracy_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py index 6f51651af687..02cc81695d57 100644 --- a/tests/unary_ops_accuracy_test.py +++ b/tests/unary_ops_accuracy_test.py @@ -176,6 +176,10 @@ def setUp(self): self.skipTest("Test requires StableHLO v1.10.0 or higher.") if not jtu.is_device_tpu(): self.skipTest("Skipping test on non TPU devices.") + # TODO(b/412112097): Enable this test on TPU version 7 and above once + # accuracy analysis is done. 
+ if jtu.get_tpu_version() >= 7: + self.skipTest("Accuracy analysis is not yet done on TPU version 7 and above.") super().setUp() def test_result_accuracy_mode_attr(self): From 3c8698202b6dc28ed0abac8cba4ea16c44d4afaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Tue, 22 Apr 2025 17:53:47 -0700 Subject: [PATCH 0750/1769] [Mosaic:TPU][Relayout] Row shifts for packed types and non-native tilings PiperOrigin-RevId: 750390355 --- .../tpu/transforms/apply_vector_layout.cc | 445 ++++++++---------- 1 file changed, 194 insertions(+), 251 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index d6452cba8b8d..dadf0498db3c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2169,6 +2169,34 @@ LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: unsupported layout change"); } +Value createSubelementMask(OpBuilder &builder, const Location loc, + const int bitwidth, const int64_t from, + const int64_t to, + const std::array target_shape) { + auto create_index_const = [&](const int64_t idx) { + return builder.create( + loc, builder.getIntegerAttr(builder.getIndexType(), idx)); + }; + const int packing = 32 / bitwidth; + const VectorType vmask_ty = + getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); + // Prefer CreateMaskOp if possible - more efficient and supports unpacked + // TODO: b/412754162 - We can probably always use the CreateSubelementMaskOp + // if (1) optimize it on TPUv4 and (2) Add support for unpacked types in some + // of the invariants in lower_to_llo. + if (from % packing == 0 && to % packing == 0) { + const int64_t from_sublane = from / packing; + const int64_t to_sublane = to / packing; + return builder.create( + loc, vmask_ty, + ArrayRef{create_index_const(from_sublane), + create_index_const(0)}, + ArrayRef{create_index_const(to_sublane), + create_index_const(target_shape[1])}); + } + return builder.create(loc, vmask_ty, from, to); +} + // TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate. So // we do not need template for the op type and to explicitly force amount // argument to dynamic. @@ -2698,23 +2726,9 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const VectorType vmask_ty = getNativeVregOrVmaskType( builder.getI1Type(), bitwidth, ctx.target_shape); if (tiling_dim.value() == 0) { // sublane - if (operand_offset % packing != 0) { - // Packed case, degenerate where we have a half or quarter - // sublane. - // TODO(mvoz): We can probably always use the - // CreateSubelementMaskOp if (1) optimize it on TPUv4 and (2) Add - // support for unpacked types in some of the invariants in - // lower_to_llo. - mask = builder.create( - op.getLoc(), vmask_ty, 0, operand_offset); - } else { - auto sublane_offset = operand_offset / packing; - mask = builder.create( - op.getLoc(), vmask_ty, - ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(sublane_offset), - boundIdxConst(layout->tiling()[1])}); - } + mask = createSubelementMask(builder, op.getLoc(), bitwidth, + /*from=*/0, /*to=*/operand_offset, + ctx.target_shape); } else { // lane mask = builder.create( op.getLoc(), vmask_ty, @@ -5305,182 +5319,6 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, return src_vreg; } -// This function is based on tpu_rotate_rule. 
It applies a shift of amount to -// a given dim. A major difference is that it "overflows", i.e. if the shift -// amount is such that it pushes us into a new vreg, we create a new vreg and -// fill it in with the remaining rows. -// -// The shift is the difference between layout_in and layout_out, on the -// given dim. -FailureOr> tpu_rotate_with_overflow( - OpBuilder &builder, const std::array target_shape, - const Location loc, const VectorType vty, xla::Array in_tiles, - int64_t dim, const VectorLayout &layout_in, - const LayoutOffsets offsets_out) { - if (!layout_in.hasNativeTiling(target_shape)) { - return emitError(loc, "Not implemented: non-native tiling for layout"); - } - if (layout_in.bitwidth() != 32) { - return emitError(loc, - "Not implemented: multi-row shift with " - "bitwidth != 32"); - } - // TODO(apaszke,mvoz): Just use offsets_out instead of this. - VectorLayout layout_out(layout_in.bitwidth(), offsets_out, layout_in.tiling(), - layout_in.implicit_dim()); - - int64_t tiling_dim = dim - (in_tiles.num_dimensions() - 2); - if (tiling_dim != 0) { - return emitError(loc, - "Rotate with overflow untested for " - "dim != 0"); - } - auto amount = - *layout_out.offsets()[tiling_dim] - *layout_in.offsets()[tiling_dim]; - - SmallVector dst_tiles_shape = - layout_out.tileArrayImplicitShape(vty.getShape(), target_shape); - - const VectorType res_vreg_ty = - getNativeVregType(vty.getElementType(), target_shape); - - xla::Array out_tiles(dst_tiles_shape); - - // We update the result vregs in the following way: - // - If the offset is positive, write the first tile as is, if the offset - // is negative, blend it with the next tile. - // - Blend the rest of the tiles with the prior (positive offset) or next - // (negative offset) tile. - // - (In positive cases, we can get an extra vreg (overflow)) we write the - // remaining tiles. - // This only happens if the original input vreg size is smaller than the - // result vreg size (an offset) can "push" us into a new vreg. - // - // Ex: (30, 128), starting offset 0, shift by 6, native tiling (8, 128) - // The input is (4, 1), where the first 3 vregs are full (0-24) - // and the last vreg is filled in rows 0-6. When we offset it by 6, we - // need a 4th vreg, as now vreg 0 is filled in 6-8 (2 total), vreg 1, 2, 3 - // are filled in fully (8-16, 16-24, 24-32) (2 + 24 total), and vreg 4 is - // filled in 0-4. (2 + 24 + 4 = 30). - - // Negative offset amount means we: - // - // Ex 1: (30, 128), input offset 6, shift by -2, native tiling (8, 128) - // (The result of the last example, for simplicity). In this case, we have - // (5, 1) vregs as decribed above. Because the shift does not cause us to - // shift back from the 5th vreg, we still need it. In such a case, the result - // vreg is still (5, 1). - // - // - Write the first vreg as is. - // - The next vregs are blended with the prior one (except the last), - // where we blend by the shift amount. Ex: Vreg 1 goes from 6-8 to 4-8, - // pulling 2 rows from the next vreg. - // - The last tile is masked to only write the remaining rows. - // Ex: Vreg 4 goes from 0-4 to 0-2. - // - // Ex 2: (30, 128), starting offset 6, shift by -6, native tiling (8, 128) - // In this case, we have (5, 1) vregs as described above. Because the shift - // causes us to shift back from the 5th vreg, we don't need it anymore. - // In such a case, the result vreg is (4, 1). - // - // - All vregs are blended with the next one (except the last), - // where we blend by the shift amount. 
Ex: Vreg 1 goes from 6-8 to 0-8, - // pulling 6 rows from the next vreg. - // - The last tile is discarded - it was fully subsumed by the prior blends. - // - // Ex 3: (30, 128), starting offset 0, shift by -6, native tiling (8, 128) - // In this case, we have (4, 1) vregs as described above. - // In such a case, the result vreg is (4, 1), where the first vreg is filled - // in rows 2-8 (6), and vregs 1 and 2 are filled in fully (8-16, 16-24), and - // vreg 3 is filled in rows 0-6. - // - // NOTE - in such cases, where the abs(shift) in a negative shift > starting - // offset, we can actually implement this as a positive shift of the delta - // from the native tile size. - // in the example above, the delta is 8 - 6 + 0 = 2. The resulting vregs are - // the same as if we had shifted by 2, starting at offset 0. - // - // Another example to demonstrate the point: - // Ex 4: (30, 128), starting offset 2, shift by -4, native tiling (8, 128) - // In this case, we start with (4, 1) vregs as described above. - // (2-8)(8-16)(16-24)(0-4). Shifting by -4 is the same as 8 - 4 + 2 = 6. - // So we can just shift by 6, starting at offset 0. - // Vreg 0 is filled in 6-8 (2 total), vreg 1, 2 and 3 are filled in fully - // (8-16, 16-24, 24-32) (2 + 24 total = 26) vreg 4 is filled with the - // remainder, 0-4 (30 total). - // - // This means that no matter what the shift is, we should always - // rotate and compute the shift amount in such a way that the first input - // vreg is the first output vreg. - - // Compute the mask for the blend. - // Positive blends blend "forward" and negative blends blend "backward". - auto mask_val = amount; - auto vreg_rot_amount = amount; - if (amount < 0) { - mask_val = layout_in.tiling()[tiling_dim] - std::abs(amount); - vreg_rot_amount += target_shape[tiling_dim]; - } - auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, loc); - auto mask = builder.create( - loc, VectorType::get(target_shape, builder.getI1Type()), - ValueRange{boundIdxConst(0), boundIdxConst(0)}, - ValueRange{boundIdxConst(mask_val), boundIdxConst(target_shape[1])}); - - // Actually do the rotation. - in_tiles.Each([&](absl::Span idxs, Value *v) { - if (dim >= in_tiles.num_dimensions() - 2) { - *v = builder.create(loc, res_vreg_ty, in_tiles(idxs), - vreg_rot_amount, tiling_dim, nullptr, - nullptr); - } - }); - - // Walk the result tiles. - // TODO(mvoz): There is a micro-optimization here where we can avoid - // allocating blend indices per vreg. - out_tiles.Each([&](absl::Span idxs, Value *v) { - if (idxs[dim] == 0) { - // A negative shift amount means we need to blend the first tile with the - // next one, but only if we're not at the end of the input. - if (amount < 0 && (idxs[dim] + 1 < in_tiles.dim(dim))) { - SmallVector next_idx = {idxs.begin(), idxs.end()}; - next_idx[dim] = idxs[dim] + 1; - *v = builder.create(loc, mask, in_tiles(idxs), - in_tiles(next_idx)); - } else { - // Positive shift, or negative shift at the end of the input. - *v = in_tiles(idxs); - } - } else if (idxs[dim] < in_tiles.dim(dim)) { - // write the rest as blended up to the end of the input - if (amount < 0) { - if (idxs[dim] + 1 < in_tiles.dim(dim)) { - SmallVector next_idx = {idxs.begin(), idxs.end()}; - next_idx[dim] = idxs[dim] + 1; - *v = builder.create(loc, mask, in_tiles(idxs), - in_tiles(next_idx)); - } else { - // Nothing to blend with, just write the last tile. 
- *v = in_tiles(idxs); - } - } else { - SmallVector prior_idx = {idxs.begin(), idxs.end()}; - prior_idx[dim] = idxs[dim] - 1; - *v = builder.create(loc, mask, in_tiles(prior_idx), - in_tiles(idxs)); - } - } else { - // write trailing if it's there (positive shift, increasing vreg count) - // Use the last prior - SmallVector prior_idx = {idxs.begin(), idxs.end()}; - prior_idx[dim] = idxs[dim] - 1; - *v = in_tiles(prior_idx); - } - }); - - return out_tiles; -} void rotateVregs(OpBuilder &builder, xla::Array &vregs, const int64_t amount, const int dimension) { @@ -5514,6 +5352,9 @@ void rotateLanes(OpBuilder &builder, xla::Array &vregs, // For these purposes, the vreg is considered to have shape (row_packing * // target_shape[0], target_shape[1]) // +// Note: When rotating by a whole number of sublanes, there are no low bits, so +// null is returned when is_high is false. +// // Args: // vreg: The vreg to rotate // rotate_amount: The amount to rotate the vreg by. @@ -5530,12 +5371,11 @@ Value rotateVregRows(OpBuilder &builder, Location loc, Value vreg, CHECK_LT(0, rows_per_sublane); const int64_t bits_per_row = 32 / rows_per_sublane; const int64_t sublane_rotate_amount = - rotate_amount / rows_per_sublane + (is_high ? 0 : 1); + (rotate_amount / rows_per_sublane + (is_high ? 0 : 1)) % target_shape[0]; const int64_t within_sublane_rotate_amount = rotate_amount % rows_per_sublane; - vreg = builder.create(vreg.getLoc(), vreg, - /*amount=*/sublane_rotate_amount, - /*dimension=*/0, /*stride=*/nullptr, - /*stride_dimension=*/nullptr); + if (within_sublane_rotate_amount == 0 && !is_high) { + return nullptr; + } if (within_sublane_rotate_amount != 0) { const VectorType vreg_ty = cast(vreg.getType()); const VectorType i32_vreg_ty = @@ -5559,7 +5399,159 @@ Value rotateVregRows(OpBuilder &builder, Location loc, Value vreg, } vreg = builder.create(loc, vreg_ty, vreg); } - return vreg; + return builder.create(vreg.getLoc(), vreg, + /*amount=*/sublane_rotate_amount, + /*dimension=*/0, /*stride=*/nullptr, + /*stride_dimension=*/nullptr); +} + +FailureOr> doRowShiftRelayout( + OpBuilder &builder, const Location loc, const ArrayRef shape, + xla::Array src_vregs, const VectorLayout &src_layout, + const int64_t dst_row_offset, const std::array target_shape) { + constexpr int32_t kNativeBitwidth = 32; + const std::array tiling = src_layout.tiling(); + const std::array tiled_ishape = + src_layout.getImplicitTiledDims(shape, 1); + const int64_t sublanes_per_tile = src_layout.sublanesPerTile(target_shape); + const int64_t tiles_per_vreg = src_layout.tilesPerVreg(target_shape); + const LayoutOffsets &src_offsets = src_layout.offsets(); + CHECK(src_offsets[0].has_value()); + CHECK_GE(*src_offsets[0], 0); + CHECK_LT(*src_offsets[0], tiling[0]); + CHECK_GE(dst_row_offset, 0); + CHECK_LT(dst_row_offset, tiling[0]); + CHECK_EQ(tiling[0] % sublanes_per_tile, 0); + const int64_t rows_per_sublane = tiling[0] / sublanes_per_tile; + const int64_t bits_per_row = kNativeBitwidth / rows_per_sublane; + const int64_t row_shift_amount = dst_row_offset - *src_offsets[0]; + // How many rows to shift (positive): + const int64_t shift_in_tile = (row_shift_amount + tiling[0]) % tiling[0]; + // How many rows to shift within a single sublane: + const int64_t shift_in_sublane = shift_in_tile % rows_per_sublane; + CHECK(src_vregs.begin() != src_vregs.end()); + const VectorType vreg_ty = cast(src_vregs.begin()->getType()); + const VectorType int_vreg_ty = + getNativeVregType(builder.getIntegerType(bits_per_row), target_shape); + + // The mask 
selects the first row_shift_amount full/half/quarter/etc-sublanes + // of each tile that contains data. + Value mask = nullptr; + for (int64_t i = 0; i < tiles_per_vreg; ++i) { + const int64_t start = i * sublanes_per_tile * rows_per_sublane; + // TODO: b/412753800 - Skip tiles that never contain data + Value tile_mask = + createSubelementMask(builder, loc, bits_per_row, /*from=*/start, + /*to=*/start + shift_in_tile, target_shape); + mask = mask == nullptr ? tile_mask + : builder.create(loc, mask, tile_mask); + } + + xla::Array res_vregs( + VectorLayout(src_layout.bitwidth(), {dst_row_offset, src_offsets[1]}, + src_layout.tiling(), src_layout.implicit_dim()) + .tileArrayImplicitShape(shape, target_shape)); + // rotate_rows_and_blend returns the combined high and low bits of a vreg + // after rotation by shift_in_tile. data_start and data_end (exclusive) are + // the rows of interest in the resulting vreg. + auto rotate_rows_and_blend = [&](Value vreg, const int64_t data_start, + const int64_t data_end) -> Value { + CHECK(vreg != nullptr); + // The split between low and high bits is at shift_in_sublane rows. + Value low_bits, high_bits; + // start_sublane is the first sublane in a tile that contains data + const int64_t start_sublane = data_start / rows_per_sublane; + // end_sublane the last sublane in a tile that contains data, inclusive + const int64_t end_sublane = (data_end - 1) / rows_per_sublane; + + // If data is in the high bits only, skip low bits + // This happens iff data is in a single sublane and begins after the split + if (start_sublane != end_sublane || + data_start % rows_per_sublane < shift_in_sublane) { + // Note that if shift_in_sublane is 0, rotateVregRows will return null + // since there are no low bits. + low_bits = + rotateVregRows(builder, loc, vreg, shift_in_tile, rows_per_sublane, + /*is_high=*/false, target_shape); + } + // If data is in the low bits only, skip high bits + // This happens iff data is in a single sublane and ends before the split + if (start_sublane != end_sublane || + (data_end - 1) % rows_per_sublane >= shift_in_sublane) { + high_bits = + rotateVregRows(builder, loc, vreg, shift_in_tile, rows_per_sublane, + /*is_high=*/true, target_shape); + } + if (low_bits != nullptr && high_bits != nullptr) { + return builder.create(loc, low_bits, high_bits); + } else if (low_bits != nullptr) { + return low_bits; + } else { + CHECK(high_bits != nullptr); + return high_bits; + } + }; + const int64_t res_low_idx_delta = *src_offsets[0] < dst_row_offset ? -1 : 0; + const int64_t res_high_idx_delta = *src_offsets[0] < dst_row_offset ? 0 : 1; + res_vregs.Each([&](absl::Span idxs, Value *v) { + // Each vreg of the result is (usually) a combination of two vregs from the + // source. If we are shifting *down* by 5 rows, the first 5 rows of result + // vreg i (along 2nd minor) will come from source vreg i-1, while the + // following rows will come from source vreg i. + + // The split of data between low and high is at shift_in_tile rows. + Value low, high; + // The start row of data in the vreg + const int64_t res_data_start = *(idxs.end() - 2) == 0 ? dst_row_offset : 0; + // The end row of data in the vreg, exclusive + const int64_t res_data_end = + *(idxs.end() - 2) == *(res_vregs.dimensions().end() - 2) - 1 + // -+ 1 before/after modulo so result is (1, tiling[0]) inclusive + ? 
(dst_row_offset + tiled_ishape[0] - 1) % tiling[0] + 1 + : tiling[0]; + // If data begins after the split, skip the low rows + if (res_data_start < shift_in_tile) { + SmallVector low_idxs(toArrayRef(idxs)); + *(low_idxs.end() - 2) += res_low_idx_delta; + low = builder.create(loc, int_vreg_ty, + src_vregs(low_idxs)); + low = rotate_rows_and_blend( + low, res_data_start, + /*data_end=*/std::min(res_data_end, shift_in_tile)); + // By doing the tile rotate after, rotate_rows_and_blend can be CSE'd + // since the low part of this vreg is the high part of the previous vreg. + // If there is no next previous or there is no benefit in CSE (e.g. we + // only use high bits and next vreg only uses low bits), the rotates + // should get merged anyway. + // TODO(tlongeri): Think more about the order in which rotates happen. + // Doing OR before rotate may be better. + low = builder.create( + loc, low, (tiles_per_vreg - 1) * sublanes_per_tile, 0, nullptr, + nullptr); + } + // If data ends before the split, skip high rows. + if (res_data_end > shift_in_tile) { + SmallVector high_idxs(toArrayRef(idxs)); + *(high_idxs.end() - 2) += res_high_idx_delta; + high = builder.create(loc, int_vreg_ty, + src_vregs(high_idxs)); + high = rotate_rows_and_blend( + high, + /*data_start=*/std::max(res_data_start, shift_in_tile), res_data_end); + } + + if (low != nullptr && high != nullptr) { + *v = builder.create(loc, mask, low, high); + } else if (low != nullptr) { + *v = low; + } else { + CHECK(high != nullptr); + *v = high; + } + *v = builder.create(loc, vreg_ty, *v); + }); + + return res_vregs; } // Relayout src_vregs from layout src to layout dst, where dst is the same as @@ -5819,8 +5811,6 @@ FailureOr>> changeOffsets( const auto &target_shape = ctx.target_shape; const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), src.implicit_dim()); - const int packing = src.packing(); - const int8_t bitwidth = src.bitwidth(); int row_diff; if (!src.offsets()[0].has_value()) { @@ -5846,56 +5836,9 @@ FailureOr>> changeOffsets( } const SmallVector implicit_shape = src.implicitShape(vty.getShape()); - if (implicit_shape[implicit_shape.size() - 2] != 1) { - // Multi row shift - // TODO(mvoz): This should take the vregs array, not the value. - FAILUREOR_ASSIGN_OR_RETURN( - vregs, tpu_rotate_with_overflow( - builder, target_shape, loc, vty, std::move(vregs), - /*dim*/ implicit_shape.size() - 2, src, dst_offsets)); - } else { - // Single row case - // TODO(mvoz): The single row case has a broader set of supported - // operations: non-native tiling, packed types, implicit dim. We should - // support these cases in tpu_rotate_with_overflow and remove this - // branch. 
-    const int64_t src_sublane = *src.offsets()[0] / packing;
-    const int64_t dst_sublane = *dst_offsets[0] / packing;
-    if (int64_t sublane_diff = dst_sublane - src_sublane) {
-      if (sublane_diff < 0) {
-        sublane_diff += target_shape[0];
-      }
-      rotateSublanes(builder, vregs, sublane_diff);
-    }
-    const int src_subelem = *src.offsets()[0] % packing;
-    const int dst_subelem = *dst.offsets()[0] % packing;
-    if (src_subelem != dst_subelem) {
-      const int subelem_diff = dst_subelem - src_subelem;
-      const int shift_bits = bitwidth * std::abs(subelem_diff);
-      VectorType bits_vreg_ty =
-          VectorType::get(target_shape, builder.getI32Type());
-      auto shift_vreg = builder.create<arith::ConstantOp>(
-          loc, bits_vreg_ty,
-          DenseElementsAttr::get(bits_vreg_ty, shift_bits));
-      vregs.Each([&](absl::Span<const int64_t> /*idx*/, Value *tile) {
-        auto bit_tile =
-            builder.create<tpu::BitcastVregOp>(loc, bits_vreg_ty, *tile);
-        Operation *shift_tile;
-        if (subelem_diff > 0) {
-          shift_tile =
-              builder.create<arith::ShLIOp>(loc, bit_tile, shift_vreg);
-        } else {  // subelem_diff < 0
-          CHECK_LT(subelem_diff, 0);
-          shift_tile =
-              builder.create<arith::ShRUIOp>(loc, bit_tile, shift_vreg);
-        }
-        *tile = builder
-                    .create<tpu::BitcastVregOp>(loc, tile->getType(),
-                                                shift_tile->getResult(0))
-                    .getResult();
-      });
-    }
-  }
+  FAILUREOR_ASSIGN_OR_RETURN(
+      vregs, doRowShiftRelayout(builder, loc, vty.getShape(), vregs, src,
+                                *dst_offsets[0], ctx.target_shape));
 }
 
 // Rows are now correctly aligned. Time to offset columns.

From 377393c61c354f20129cb740a0f8a22b1a8932a4 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 22 Apr 2025 18:12:08 -0700
Subject: [PATCH 0751/1769] Cleanup: don't redefine softmax in jax.random

---
 jax/_src/random.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/jax/_src/random.py b/jax/_src/random.py
index d167b9471f3a..ef02719de146 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -39,6 +39,7 @@ from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.lax import lax as lax_internal
+from jax._src.nn.functions import softmax
 from jax._src.numpy.lax_numpy import _convert_and_clip_integer
 from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
 from jax._src.pjit import auto_axes
@@ -1136,16 +1137,7 @@ def _dirichlet(key, alpha, shape, dtype) -> Array:
   # Compute gamma in log space, otherwise small alpha can lead to poor behavior.
   log_gamma_samples = loggamma(key, alpha, shape + np.shape(alpha)[-1:], dtype)
-  return _softmax(log_gamma_samples, -1)
-
-
-def _softmax(x, axis) -> Array:
-  """Utility to compute the softmax of x along a given axis."""
-  if not dtypes.issubdtype(x.dtype, np.floating):
-    raise TypeError(f"_softmax only accepts floating dtypes, got {x.dtype}")
-  x_max = jnp.max(x, axis, keepdims=True)
-  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
-  return unnormalized / unnormalized.sum(axis, keepdims=True)
+  return softmax(log_gamma_samples, -1)
 
 
 def exponential(key: ArrayLike,

From 910a88f4cc89742761cfd8de0cddc8ee0f1c2675 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 22 Apr 2025 18:33:01 -0700
Subject: [PATCH 0752/1769] Expose `jax.shard_map` as the public
 non-experimental API and move shard_map.py to `jax/_src`

The signature is:
`jax.shard_map(f, /, *, out_specs, axis_names=set(), in_specs=None, mesh=None, check_vma=True)`

This API is a drop-in replacement for the experimental shard_map endpoint with
just two small changes: `check_rep` is renamed to `check_vma`, and all
arguments to shard_map except `f` are keyword-only while `f` is
positional-only.
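For illustration, here is a minimal migration sketch (the mesh shape, the
function, and the data below are made up for this example, and it assumes at
least two devices):

    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P

    mesh = jax.make_mesh((2,), ('x',))

    def f(block):  # block: this device's shard, shape (4,)
      return jax.lax.psum(block, 'x')

    # Before: from jax.experimental.shard_map import shard_map
    #         g = shard_map(f, mesh, in_specs=P('x'), out_specs=P(), check_rep=True)
    g = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P(),
                      check_vma=True)
    print(g(jnp.arange(8.)))  # elementwise sum of the two shards, shape (4,)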
**But why are mesh and in_specs optional? And what is the new axis_names
argument?**

mesh is optional because it can be inferred from the context if the user sets
the mesh via jax.sharding.use_mesh(mesh).

in_specs is optional because it can be inferred from the arguments passed to
shard_map if all mesh axes are Explicit.

axis_names tells shard_map which axes are Manual. If empty, the shard_map is
Manual over all mesh axes. In the experimental endpoint of shard_map, this
argument was called auto. But after the advent of sharding_in_types, mesh axes
can be Auto, Explicit, or Manual, so auto was not enough, since the non-Manual
axes can be Explicit too. That's why jax.shard_map renames the argument to
axis_names.

**If in_specs is optional, why is out_specs compulsory?**

This is because we still need to know which dimension to concatenate the
outputs over; it can't be inferred automatically since the choice can be
anything. For example, with per-shard outputs of shape (4, 8) on a two-device
mesh axis 'x', out_specs=P('x') produces an (8, 8) result while
out_specs=P(None, 'x') produces (4, 16).

END_PUBLIC

PiperOrigin-RevId: 750401402
---
 jax/BUILD                             |    1 +
 jax/__init__.py                       |    2 +
 jax/_src/api.py                       |    4 +-
 jax/_src/checkify.py                  |   15 +-
 jax/_src/dispatch.py                  |    2 +-
 jax/_src/mesh.py                      |    7 +
 jax/_src/shard_map.py                 | 1705 +++++++++++++++++++++++++
 jax/experimental/key_reuse/_core.py   |    2 +-
 jax/experimental/roofline/roofline.py |   19 +-
 jax/experimental/shard_map.py         | 1660 +----------------------
 tests/shard_map_test.py               |   24 +-
 11 files changed, 1760 insertions(+), 1681 deletions(-)
 create mode 100644 jax/_src/shard_map.py

diff --git a/jax/BUILD b/jax/BUILD
index 862679681c39..d7c48f019096 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -254,6 +254,7 @@ py_library_providing_imports_info(
         "_src/public_test_util.py",
         "_src/random.py",
         "_src/shard_alike.py",
+        "_src/shard_map.py",
         "_src/sourcemap.py",
         "_src/stages.py",
         "_src/tree.py",
diff --git a/jax/__init__.py b/jax/__init__.py
index f8ceb41dd3be..18465c28bc84 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -131,6 +131,8 @@ from jax._src.sharding_impls import NamedSharding as NamedSharding
 from jax._src.sharding_impls import make_mesh as make_mesh
 
+from jax._src.shard_map import shard_map as shard_map
+
 # Force import, allowing jax.interpreters.* to be used after import jax.
 from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla
 del ad, batching, mlir, partial_eval, pxla, xla
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 5a448503d8f8..9e13c4438d50 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -1468,10 +1468,8 @@ def pmap(
           " removed from JAX. Please migrate to pjit and remove global_arg_shapes"
           " from pmap.")
 
-  # TODO(yashkatariya): Move this out after shard_map is out of experimental and
-  # in _src
   if config.pmap_shmap_merge.value:
-    from jax.experimental.shard_map import pmap
+    from jax._src.shard_map import pmap
     return pmap(fun, axis_name, in_axes=in_axes, out_axes=out_axes,
                 static_broadcasted_argnums=static_broadcasted_argnums,
                 devices=devices, backend=backend,
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 959ff36881e1..298d44f023de 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -25,7 +25,10 @@
 from jax import dtypes
 from jax import lax
-from jax.experimental import shard_map
+# TODO(yashkatariya): Remove the experimental import after users are migrated
+# to `jax.shard_map`.
+from jax.experimental import shard_map # noqa: F401 +from jax._src import shard_map as jshmap from jax._src import api from jax._src import api_util from jax._src import ad_checkpoint @@ -972,8 +975,8 @@ def shard_map_error_check( raise ValueError(f'Unsupported aval type: {type(v)}') in_avals[i] = sharder(mesh, auto, check_rep, new_in_names[i], v) - with (shard_map._extend_axis_env(mesh, auto), - mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto)), + with (jshmap._extend_axis_env(mesh, auto), + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, auto)), config._check_rep(check_rep)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( @@ -1008,11 +1011,11 @@ def expand_errors_leading_dim(*xs): out_names=new_out_names, **kwargs, ) - _, new_params = shard_map.shard_map_p.get_bind_params(new_params) + _, new_params = jshmap.shard_map_p.get_bind_params(new_params) - err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params) + err_and_out = jshmap.shard_map_p.bind(subfun, *new_vals_in, **new_params) return tree_unflatten(out_tree, err_and_out) -error_checks[shard_map.shard_map_p] = shard_map_error_check +error_checks[jshmap.shard_map_p] = shard_map_error_check def custom_jvp_call_rule(in_err: Error, enabled_errors: set, *in_vals, num_consts, diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6d9b79b9b58b..28ed39a1fe2a 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -240,7 +240,7 @@ class SourceInfo(NamedTuple): def get_intermediate_shardings( jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]: from jax._src import pjit - from jax.experimental import shard_map + from jax._src import shard_map out = [] for eqn in jaxpr.eqns: diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index aa6c49f0ccdd..b96bd2f832dc 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -178,6 +178,13 @@ def _any_axis_auto(self) -> bool: def _any_axis_explicit(self) -> bool: return any_axis_types_match(self._axis_types, AxisType.Explicit) + @functools.cached_property + def _any_axis_auto_or_manual(self) -> bool: + if not self._axis_types: + return False + return any(t == AxisType.Auto or t == AxisType.Manual + for t in self._axis_types) + @functools.cached_property def auto_axes(self): return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py new file mode 100644 index 000000000000..df79af0ba8f8 --- /dev/null +++ b/jax/_src/shard_map.py @@ -0,0 +1,1705 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
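+
+# This module implements `jax.shard_map`: map a function over shards of data
+# laid out across a mesh of devices. See the `shard_map` docstring below and
+# https://jax.readthedocs.io/en/latest/notebooks/shard_map.html for details.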
+from __future__ import annotations + +from collections.abc import Callable, Hashable, Sequence, Set +import enum +from functools import partial +import inspect +from math import prod +import operator as op +from typing import Any, TypeVar, Union + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec +from jax._src import ad_util +from jax._src import api_util +from jax._src import config +from jax._src import core +from jax._src import debugging +from jax._src import dispatch +from jax._src import dtypes +from jax._src import linear_util as lu +from jax._src import sharding_impls +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src import util +from jax._src.core import pvary +from jax._src.core import Tracer, typeof +from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, + get_abstract_mesh) +from jax._src.api import _shared_code_pmap, _prepare_pmap +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import sdy +from jax._src.util import (HashableFunction, HashablePartial, unzip2, + as_hashable_function, memoize, partition_list, + merge_lists, split_list, subs_list2) +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import pxla +from jax._src.interpreters import ad +from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, + tree_structure, tree_leaves, keystr) +from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, + generate_key_paths, KeyPath) +from jax.experimental.multihost_utils import (host_local_array_to_global_array, + global_array_to_host_local_array) + +P = PartitionSpec + +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip +traceback_util.register_exclusion(__file__) + +# API + +Specs = Any # PyTree[PartitionSpec] +AxisName = Hashable + + +def shard_map(f, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), + in_specs: Specs | None = None, + mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True): + """Map a function over shards of data using a mesh of devices. + + See the docs at https://jax.readthedocs.io/en/latest/notebooks/shard_map.html. + + Args: + f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, + takes as input a shard of the mapped-over arguments and produces a shard + of the output. + mesh: (optional, default None) a ``jax.sharding.Mesh`` representing the + array of devices over which to shard the data and on which to execute + instances of ``f``. The names of the ``Mesh`` can be used in collective + communication operations in ``f``. If mesh is None, it will be inferred + from the context which can be set via `jax.sharding.use_mesh` context + manager. + in_specs: (optional, default None) a pytree with + ``jax.sharding.PartitionSpec`` instances as leaves, with a tree structure + that is a tree prefix of the args tuple to be mapped over. Similar to + ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the + corresponding argument (or subtree of arguments) should be sharded along + the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a + ``mesh`` axis name at a position expresses sharding the corresponding + argument array axis along that positional axis; not mentioning an axis + name expresses replication. 
If ``None``, all mesh axes must be in explicit
+      mode, in which case the in_specs are inferred from the argument types.
+    out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree
+      structure that is a tree prefix of the output of ``f``. Each
+      ``PartitionSpec`` represents how the corresponding output shards should be
+      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name
+      at a position expresses concatenation of that mesh axis's shards along the
+      corresponding positional axis; not mentioning a ``mesh`` axis name
+      expresses a promise that the output values are equal along that mesh axis,
+      and that, rather than concatenating, only a single value should be
+      produced.
+    axis_names: (optional, default the empty set) set of axis names from
+      ``mesh`` over which the function ``f`` is manual. If empty, ``f`` is
+      manual over all mesh axes.
+    check_vma: (optional) boolean (default True) representing whether to enable
+      additional validity checks and automatic differentiation optimizations.
+      The validity checks concern whether any mesh axis names not mentioned in
+      ``out_specs`` are consistent with how the outputs of ``f`` are replicated.
+
+  Returns:
+    A callable representing a mapped version of ``f``, which accepts positional
+    arguments corresponding to those of ``f`` and produces output corresponding
+    to that of ``f``.
+  """
+  return _shard_map(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
+                    axis_names=axis_names, check_vma=check_vma)
+
+def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None,
+               in_specs: Specs, out_specs: Specs | Callable[[], Specs],
+               axis_names: Set[AxisName], check_vma: bool):
+  if not callable(f):
+    raise TypeError("shard_map requires a callable for its first argument, "
+                    f"but got {f} of type {type(f)}.")
+
+  if mesh is None:
+    mesh = get_abstract_mesh()
+    if mesh.empty:
+      raise ValueError(
+          "The context mesh cannot be empty. Either use"
+          " `jax.sharding.use_mesh(mesh)` to enter into a mesh context or pass"
+          " a mesh to `shard_map` via the `mesh` keyword argument.")
+  if not isinstance(mesh, (Mesh, AbstractMesh)):
+    raise TypeError("shard_map requires a `jax.sharding.Mesh` or a "
+                    "`jax.sharding.AbstractMesh` instance for its "
+                    f"`mesh` argument, but got {mesh} of type {type(mesh)}.")
+
+  if not isinstance(axis_names, (frozenset, set)):
+    raise TypeError(
+        "`axis_names` argument of shard_map should be of type `frozenset` or"
+        f" `set`. Got type: {type(axis_names)}")
+  if isinstance(axis_names, set):
+    axis_names = frozenset(axis_names)
+  if not axis_names:
+    axis_names = frozenset(mesh.axis_names)
+  auto = frozenset(mesh.axis_names) - frozenset(axis_names)
+  if not auto.issubset(mesh.axis_names):
+    raise ValueError(f"shard_map requires auto={auto} to be a subset of "
+                     f"mesh.axis_names={mesh.axis_names}")
+
+  # TODO(yashkatariya): Maybe we don't have to be this strict?
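+  # Note: in_specs inference (see `wrapped` below) relies on the argument
+  # avals carrying full sharding information, which is only the case when
+  # every mesh axis is Explicit; with Auto or Manual axes present, in_specs
+  # must be given explicitly.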
+ if mesh._any_axis_auto_or_manual and in_specs is None: + raise TypeError( + "shard_map in_specs argument must be a pytree of" + " `jax.sharding.PartitionSpec` instances, but it was None when mesh" + f" {mesh} has `Auto` axes.\n") + + if in_specs is not None: + _check_specs(SpecErrorType.input, in_specs, auto) + if not callable(out_specs): + _check_specs(SpecErrorType.out, out_specs, auto) + + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + fun = lu.wrap_init( + f, debug_info=api_util.debug_info("shard_map", f, args, {})) + args_flat, in_tree = tree_flatten(args) + fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) + + try: + in_specs_flat = broadcast_prefix( + in_specs, args, is_leaf=lambda x: x is None) + except ValueError: + e, *_ = prefix_errors(in_specs, args) + raise e('shard_map in_specs') from None + + # TODO(yashkatariya): Relax this and convert only `None`s in `in_specs_flat` + # and accept the other specs as is. + if mesh._are_all_axes_explicit and in_specs is None: + arg_s = [typeof(a).sharding for a in args_flat] + assert all(i is None for i in in_specs_flat), in_specs_flat + in_specs_flat = [_manual_spec(axis_names, s.spec) for s in arg_s] + + dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) + if s is not None) + fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) + _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) + in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) + + @memoize + def out_names_thunk(): + if callable(out_specs): + out_specs_ = out_specs() + _check_specs(SpecErrorType.out, out_specs_, auto) + else: + out_specs_ = out_specs + dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) + try: + out_specs_flat = broadcast_prefix(out_specs_, dummy) + except ValueError: + e, *_ = prefix_errors(out_specs_, dummy) + raise e('shard_map out_specs') from None + return tuple(map(_canonicalize_spec, out_specs_flat)) + + if check_vma: + fun = _implicit_pvary_on_output(fun, out_names_thunk) + + try: + out_flat = shard_map_p.bind( + fun, *args_flat, mesh=mesh, in_names=in_names_flat, + out_names_thunk=out_names_thunk, check_rep=check_vma, auto=auto) + except _SpecError as e: + fails, = e.args + if not callable(out_specs): + msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails) + if any(fail is not no_fail and not fail.shape for fail in fails): + msg += (" In particular, for rank 0 outputs which are not constant " + "over the mesh, add at least one (singleton) axis to them so " + "that they can be concatenated using out_specs.") + raise ValueError(msg) from None + except _RepError as e: + fails, = e.args + if not callable(out_specs): + msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails) + raise ValueError(msg) from None + return tree_unflatten(out_tree(), out_flat) + return wrapped + + +# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs +AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable +def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: + if isinstance(spec, PartitionSpec): + return {i: names if isinstance(names, tuple) else (names,) + for i, names in enumerate(spec) if names is not None} + else: + return spec + +def _manual_spec(manual_axes, spec: P) -> P: + out = [] # type: ignore + for s in spec: + if s is None: + out.append(s) + elif isinstance(s, tuple): + temp = [p if p in manual_axes else None for p in s] + while temp and 
temp[-1] is None: + temp.pop() + if None in temp: + raise ValueError(f"Invalid spec: {spec}") + out.append(None if len(temp) == 0 else tuple(temp)) + else: + out.append(s if s in manual_axes else None) + return P(*out) + + +# Error checking and messages + +SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) + +def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None: + if error_type == SpecErrorType.input and specs is None: + raise TypeError( + "shard_map in_specs argument must be a pytree of " + "`jax.sharding.PartitionSpec` instances, but it was None.\n" + "Instead of `in_specs=None`, did you mean `in_specs=P()`, " + "where `P = jax.sharding.PartitionSpec`?") + def check_spec(p): + if not isinstance(p, PartitionSpec): + return False + for names in p: + if not isinstance(names, tuple): + names = (names,) + for name in names: + if name in auto: + return False + return True + if all(check_spec(p) for p in tree_leaves(specs)): return + prefix = 'in' if error_type == SpecErrorType.input else 'out' + msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " + for key, x in generate_key_paths(specs) if not isinstance(x, P)] + if not msgs: + for key, p in generate_key_paths(specs): + for names in p: + if not isinstance(names, tuple): + names = (names,) + for name in names: + if name in auto: + msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") + raise ValueError( + f"shard_map {prefix}_specs argument cannot refer to an axis " + f"marked auto ({auto}), but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values passed to shard_map.") + raise TypeError( + f"shard_map {prefix}_specs argument must be a pytree of " + f"`jax.sharding.PartitionSpec` instances, but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values passed to shard_map.") + +class NoFail: pass +no_fail = NoFail() + +def _check_specs_vs_args( + f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs, + dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], + xs: Sequence) -> None: + in_avals = map(core.shaped_abstractify, xs) + fail = [a if not len(p) <= a.ndim else no_fail + for p, a in zip(in_specs_flat, in_avals)] + if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) + msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) + raise ValueError(msg) + in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) + fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) + for d, ns in names.items()) else no_fail + for a, names in zip(in_avals, in_names_flat)] + if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) + msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) + raise ValueError(msg) + +def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], + fail: Sequence[core.ShapedArray | NoFail] + ) -> list[core.ShapedArray | NoFail]: + fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves + for i, f in zip(dyn_argnums, fail): + fail_[i] = f + return fail_ + +def _spec_rank_error( + error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, + fails: list[core.ShapedArray | NoFail]) -> str: + fun_name = getattr(f, '__name__', str(f)) + if error_type == SpecErrorType.input: + prefix, base = 'in', 'args' + ba = _try_infer_args(f, tree) + else: + prefix, base = 'out', f'{fun_name}(*args)' + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + extra = 
"" + if error_type == SpecErrorType.input and ba is not None: + arg_key, *_ = fail_key + if arg_key.idx < len(ba.arguments): + param_name = list(ba.arguments.keys())[arg_key.idx] + extra = (f", where {base}{arg_key} is bound to {fun_name}'s " + f"parameter '{param_name}',") + else: + param = list(ba.signature.parameters.values())[-1] + assert param.kind == inspect.Parameter.VAR_POSITIONAL + extra = (f", where {base}{arg_key} is the index " + f"{arg_key.idx - len(ba.signature.parameters) + 1} component " + f"of {fun_name}'s varargs parameter '{param.name}',") + msgs.append( + f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length " + f"{len(spec)}, but " + f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " + f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given an " + f"{prefix}_specs entry which is too long to be compatible with the " + f"corresponding {prefix}put value from the function:\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Entries in {prefix}_specs must be of length no greater than the " + f"number of axes in the corresponding {prefix}put value.\n\n" + f"Either revise the spec to be shorter, or modify '{fun_name}' so " + f"that its {prefix}puts have sufficient rank.") + if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)): + msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs " + "entry of `P()`, where `P = jax.sharding.PartitionSpec`.") + return msg + +def _spec_divisibility_error( + f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, + fails: list[core.ShapedArray | NoFail]) -> str: + ba = _try_infer_args(f, tree) + fun_name = getattr(f, '__name__', str(f)) + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + extra = "" + if ba is not None: + arg_key, *_ = fail_key + if arg_key.idx < len(ba.arguments): + param_name = list(ba.arguments.keys())[arg_key.idx] + extra = (f", where args{arg_key} is bound to {fun_name}'s " + f"parameter '{param_name}',") + else: + param = list(ba.signature.parameters.values())[-1] + assert param.kind == inspect.Parameter.VAR_POSITIONAL + extra = (f", where args{arg_key} is the index " + f"{arg_key.idx - len(ba.signature.parameters) + 1} component " + f"of {fun_name}'s varargs parameter '{param.name}',") + names = _canonicalize_spec(spec) + for d, ns in names.items(): + if aval.shape[d] % prod(mesh.shape[n] for n in ns): + axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" + total = 'total ' if len(ns) > 1 else '' + sz = prod(mesh.shape[n] for n in ns) + msgs.append( + f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} " + f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " + f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " + f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " + f"{aval.shape[d]}") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given argument " + f"arrays with axis sizes that are not evenly divisible by the " + f"corresponding mesh axis sizes:\n\n" + f"The mesh given has shape {tuple(mesh.shape.values())} with " + f"corresponding axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Array arguments' axis sizes must be evenly divisible by the mesh " + f"axis or axes indicated by the 
corresponding elements of the " + f"argument's in_specs entry. Consider checking that in_specs are " + f"correct, and if so consider changing the mesh axis sizes or else " + f"padding the input and adapting '{fun_name}' appropriately.") + return msg + +def _inout_rep_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, + specs: Specs, fails: list[set | NoFail]) -> str: + fun_name = getattr(f, '__name__', str(f)) + msgs = [] + for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): + dst = _canonicalize_spec(spec) + unmentioned = _unmentioned(mesh, dst) + if len(unmentioned) > 1: + need_rep = ','.join(map(str, unmentioned)) + got_rep = ','.join(map(str, rep)) + diff = ','.join(map(str, [n for n in unmentioned if n not in rep])) + msgs.append( + f"* out_specs{keystr(spec_key)} is {spec} which implies that the " + f"corresponding output value is replicated across mesh axes " + f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, " + f"which is missing the required axes {diff}") + else: + need_rep_, = unmentioned + msgs.append( + f"* out_specs{keystr(spec_key)} is {spec} which implies that the " + f"corresponding output value is replicated across mesh axis " + f"'{need_rep_}', but could not infer replication over any axes") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given " + f"out_specs which require replication which can't be statically " + f"inferred given the mesh:\n\n" + f"The mesh given has shape {tuple(mesh.shape.values())} with " + f"corresponding axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + "Check if these output values are meant to be replicated over those " + "mesh axes. If not, consider revising the corresponding out_specs " + "entries. 
If so, consider disabling the check by passing the " + "check_rep=False argument to shard_map.") + return msg + +def _unmentioned(mesh: Mesh | AbstractMesh, names: AxisNames) -> list[AxisName]: + name_set = {n for ns in names.values() for n in ns} + return [n for n in mesh.axis_names if n not in name_set] + + +def _try_infer_args(f, tree): + dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) + try: + return inspect.signature(f).bind(*dummy_args) + except (TypeError, ValueError): + return None + +T = TypeVar('T') +def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] + ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]: + failures = tree_unflatten(tree, fails) + failures_aug = generate_key_paths(failures) + specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs)) + leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P + specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) + return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) + in zip(specs_aug, failures_aug) + if s is not None and fail_data is not no_fail] + +# Primitive + +@lu.transformation2 +def _implicit_pvary_on_output(f, out_names_thunk, *args, **kwargs): + out_flat = f(*args, **kwargs) + return [pvary(o, tuple(_names_to_vma(n) - typeof(o).vma)) + for o, n in zip(out_flat, out_names_thunk())] + +JaxType = Any +MaybeTracer = Union[JaxType, Tracer] + +class ShardMapPrimitive(core.Primitive): + multiple_results = True + + def bind(self, *args, **params): + return self._true_bind(*args, **params) + + def bind_with_trace(self, trace, fun_and_args, params): + fun: lu.WrappedFun + fun, *args = fun_and_args + return trace.process_shard_map(shard_map_p, fun, args, **params) + + def get_bind_params(self, params): + new_params = dict(params) + jaxpr: core.Jaxpr = new_params.pop('jaxpr') + subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr, + debug_info=jaxpr.debug_info), + jaxpr, ()) + axes = new_params.pop('out_names') + new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes) + return [subfun], new_params + +shard_map_p = ShardMapPrimitive('shard_map') + +# Staging + +@util.cache(max_size=256, trace_context_in_key=True) +def _as_manual_mesh(mesh, auto: frozenset): + manual_axes = tuple(set(mesh.axis_names) - auto) + cur_mesh = get_abstract_mesh() + if cur_mesh.empty: + cur_mesh = mesh + explicit_axes, auto_axes = set(), set() # type: ignore + for a in auto: + if cur_mesh._name_to_type[a] == AxisType.Auto: + auto_axes.add(a) + else: + assert cur_mesh._name_to_type[a] == AxisType.Explicit + explicit_axes.add(a) + + new_axis_types = [] + for n in mesh.axis_names: + if n in manual_axes: + new_axis_types.append(AxisType.Manual) + elif n in auto_axes: + new_axis_types.append(AxisType.Auto) + else: + assert n in explicit_axes + new_axis_types.append(AxisType.Explicit) + return AbstractMesh(mesh.axis_sizes, mesh.axis_names, + axis_types=tuple(new_axis_types)) + + +def _extend_axis_env(mesh, auto): + return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() + if k not in auto]) + +def _shard_map_staging( + trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, + in_tracers: Sequence[Any], *, mesh: Mesh, + in_names: tuple[AxisNames, ...], + out_names_thunk: Callable[[], tuple[AxisNames, ...]], + check_rep: bool, + auto: frozenset, + ) -> Sequence[pe.DynamicJaxprTracer]: + in_tracers = map(trace.to_jaxpr_tracer, in_tracers) + in_avals = [t.aval for t in in_tracers] + in_avals_ = map(partial(_shard_aval, 
mesh, auto, check_rep), in_names, + in_avals) + manual_mesh = _as_manual_mesh(mesh, auto) + with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + config._check_rep(check_rep)): + jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) + _check_names(out_names_thunk(), out_avals_) + if check_rep: + out_vma = [v.aval.vma for v in jaxpr.outvars] + _check_reps(mesh, auto, out_names_thunk(), out_vma) + out_avals = map(_check_shapedarray, out_avals_) + out_avals = [_check_shapedarray(_unshard_aval(mesh, check_rep, names, aval)) + for names, aval in zip(out_names_thunk(), out_avals)] + source_info = source_info_util.current() + out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] + invars = map(trace.getvar, in_tracers) + constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) + outvars = map(trace.makevar, out_tracers) + in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore + with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + config._check_rep(check_rep)): + jaxpr = pe.convert_constvars_jaxpr(jaxpr) + params = dict(mesh=mesh, in_names=in_names_staged, + out_names=tuple(out_names_thunk()), jaxpr=jaxpr, + check_rep=check_rep, auto=auto) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, + effs, source_info) + trace.frame.add_eqn(eqn) + return out_tracers +pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging + +# TODO add underscore version, for direct-linearize to consume + +def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: + assert isinstance(aval, core.ShapedArray) + return aval + +def _shard_aval(mesh: Mesh, auto, check_rep, names: AxisNames, + aval: core.AbstractValue) -> core.AbstractValue: + if type(aval) in core.shard_aval_handlers: + return core.shard_aval_handlers[type(aval)](mesh, auto, check_rep, names, + aval) + raise NotImplementedError(f"Unsupported aval type: {type(aval)}") + +def _unshard_aval(mesh: Mesh, check_rep, names: AxisNames, + aval: core.AbstractValue) -> core.AbstractValue: + if type(aval) in core.unshard_aval_handlers: + return core.unshard_aval_handlers[type(aval)](mesh, check_rep, names, aval) + else: + raise NotImplementedError(f"Unsupported aval type: {type(aval)}") + +def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_rep, names: AxisNames, + aval: core.AbstractValue) -> core.AbstractValue: + assert isinstance(aval, core.ShapedArray) + new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + manual_mesh = _as_manual_mesh(mesh, auto) + new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) + vma = (frozenset({n for ns in names.values() for n in ns}) + if check_rep else frozenset()) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) +core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array + +def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, + aval: core.AbstractValue,) -> core.AbstractValue: + assert isinstance(aval, core.ShapedArray) + new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) + if aval.ndim == 0: + out_spec = P() + else: + out_spec = [] # type: ignore + for name_s, aval_s in zip(names_spec, aval.sharding.spec): + if name_s and not aval_s: + out_spec.append(name_s) + elif aval_s and not name_s: + 
out_spec.append(aval_s) + elif not name_s and not aval_s: + out_spec.append(None) + else: + assert name_s and aval_s + name_s = name_s if isinstance(name_s, tuple) else (name_s,) + aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,) + out_spec.append(name_s + aval_s) + out_spec = PartitionSpec(*out_spec) + new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else + get_abstract_mesh()) + new_sharding = NamedSharding(new_mesh, out_spec) + manual_axes = set(new_mesh.manual_axes) + vma = (frozenset(v for v in aval.vma if v in manual_axes) + if check_rep else frozenset()) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) +core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array + +# Type-checking + +RepType = Any + +def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, + check_rep, auto): + # TODO(mattjj,parkers): check auto + for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): + if not core.typecompat(v.aval, _shard_aval( + mesh, auto, check_rep, in_name, x.aval)): + raise core.JaxprTypeError("shard_map argument avals not compatible with " + "jaxpr binder avals and in_names") + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + core.check_jaxpr(jaxpr) + if check_rep: + out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] + for rep, dst in zip(out_rep, out_names): + if not _valid_repeats(mesh, auto, rep, dst): + raise core.JaxprTypeError( + "shard_map can't prove output is sufficiently replicated") + out_avals_sharded = [x.aval for x in jaxpr.outvars] + out_avals = map(partial(_unshard_aval, mesh, check_rep), out_names, + out_avals_sharded) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + return out_avals, effs +core.custom_typechecks[shard_map_p] = _shard_map_typecheck + +def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: + return set(mesh.axis_names) - {n for ns in names.values() for n in ns} + +def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: + return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) + + +# Lowering + +def _shardy_shard_map_sharding( + ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in +) -> sharding_impls.SdyArraySharding: + axes = {name: i for i, ns in names.items() for name in ns} + ns = _make_scoped_manual_sharding(ctx, mesh, axes) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) + if auto: + for dim_sharding in sdy_sharding.dimension_shardings: + dim_sharding.is_open = True + return sdy_sharding + + +def _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep): + in_avals_ = [v.aval for v in jaxpr.invars] + if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): + # Nested `ManualComputationOp`s cannot refer to axes that are already + # manual. So figure out what axes are free thus far. + free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes + shardy_manual_axes = free_axes - auto + else: + shardy_manual_axes = frozenset(mesh.axis_names) - auto + new_axis_context = sharding_impls.SPMDAxisContext( + mesh, frozenset(mesh.axis_names) - auto) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + + # The order of manual axes should match the order of mesh.axis_names to avoid + # non-determinism issues. 
+ manual_axes = [a for a in mesh.axis_names + if a in shardy_manual_axes] + if np.prod([mesh.shape[a] for a in manual_axes]) == 1: + # No need for a `ManualComputationOp` if all manual axes are size 1. + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + out_nodes, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, + dim_var_values=ctx.dim_var_values) + return out_nodes + + in_shardings = sharding_impls.SdyArrayShardingList(map( + partial(_shardy_shard_map_sharding, ctx, mesh, auto), + in_names, ctx.avals_in)).build() + out_shardings = sharding_impls.SdyArrayShardingList(map( + partial(_shardy_shard_map_sharding, ctx, mesh, auto), + out_names, ctx.avals_out)).build() + output_types = map(mlir.aval_to_ir_type, ctx.avals_out) + manual_computation_op = sdy.ManualComputationOp( + output_types, in_nodes, in_shardings, out_shardings, + sdy.ManualAxesAttr.get( + ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) + block = ir.Block.create_at_start( + manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) + with (ir.InsertionPoint(block), _extend_axis_env(mesh, auto), + config._check_rep(check_rep)): + out_nodes_, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, + dim_var_values=ctx.dim_var_values) + sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) + + return manual_computation_op.results + + +def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, + check_rep, auto): + if config.use_shardy_partitioner.value: + return _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep) + + in_avals_ = [v.aval for v in jaxpr.invars] + out_avals_ = [x.aval for x in jaxpr.outvars] + in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, + in_avals_, in_nodes) + manual_axes = frozenset(mesh.axis_names) - auto + new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + out_nodes_, tokens_out = mlir.call_lowering( + "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, + out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, + arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), + result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) + ctx.set_tokens_out(tokens_out) + return map(partial(_xla_unshard, ctx, mesh, auto), out_names, out_avals_, + ctx.avals_out, out_nodes_) +mlir.register_lowering(shard_map_p, _shard_map_lowering) + +def _make_scoped_manual_sharding(ctx, mesh, axes): + axis_ctx = ctx.module_context.axis_context + mesh = mesh.abstract_mesh + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): + mesh = mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) + return NamedSharding( + mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore + +def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, + aval_in, aval_out, x): + if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: + return x + axes = {name: i for i, ns in names.items() for name in ns} + ns = _make_scoped_manual_sharding(ctx, mesh, axes) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() + unspecified = set(range(aval_in.ndim)) 
if auto else set() + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, + unspecified_dims=unspecified) + manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) + return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) + +def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, + aval_in, aval_out, x): + if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: + return x + axes = {name: i for i, ns in names.items() for name in ns} + ns = _make_scoped_manual_sharding(ctx, mesh, axes) + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_out, ns) + aval_out = core.physical_aval(aval_out) + unspecified = set(range(aval_out.ndim)) if auto else set() + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + aval_in = core.physical_aval(aval_in) + manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) + shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, + unspecified) + +def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: + if isinstance(aval, core.ShapedArray): + return str(map(names.get, range(aval.ndim))) + return '' + +# Eager evaluation + +def get_mesh_from_args(args_flat, mesh): + for a in args_flat: + if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): + if a.sharding.mesh.shape_tuple != mesh.shape_tuple: + aval = core.shaped_abstractify(a) + raise ValueError( + f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" + " match the mesh shape passed to shard_map " + f" {mesh.shape_tuple} for shape {aval.str_short()}") + mesh = a.sharding.mesh + if isinstance(mesh, AbstractMesh): + raise ValueError( + "Please pass `jax.Array`s with a `NamedSharding` as input to" + " `shard_map` when passing `AbstractMesh` to the mesh argument.") + assert isinstance(mesh, Mesh) + return mesh + +def _vma_to_spec(mesh, vma): + return P(tuple(i for i in mesh.axis_names if i in vma)) + +def _names_to_vma(names): + return {n for ns in names.values() for n in ns} + +def _vma_to_rep(mesh, auto, vma): + return frozenset((set(mesh.axis_names) - auto) - vma) + +def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, + check_rep, auto): + if auto: raise NotImplementedError + del prim + if isinstance(mesh, AbstractMesh): + mesh = get_mesh_from_args(args, mesh) + cur_mesh = get_abstract_mesh() + args = map(partial(_unmatch_spec, mesh, check_rep, context_mesh=cur_mesh), + in_names, args) + in_vma = map(_names_to_vma, in_names) + outs, out_vma = _run_shmap(fun, mesh, auto, args, in_vma, check_rep, cur_mesh) + out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] + _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types + if check_rep: + _check_reps(mesh, auto, out_names_thunk(), out_vma) + src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) + else: + src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) + dst_pspecs = map(_names_to_pspec, out_names_thunk()) + return map(partial(_match_spec, mesh, check_rep), src_pspecs, dst_pspecs, + outs) +core.EvalTrace.process_shard_map = _shard_map_impl + +def _run_shmap(f, mesh, auto, args, vmas, check_rep, context_mesh): + trace = ShardMapTrace(mesh, auto, check_rep, context_mesh) + in_tracers = 
map(partial(ShardMapTracer, trace), vmas, args) + manual_mesh = _as_manual_mesh(mesh, auto) + with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), + use_abstract_mesh(manual_mesh), config._check_rep(check_rep)): + ans = f.call_wrapped(*in_tracers) + outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) + return outs, out_vma + +def _names_to_pspec(names: AxisNames) -> PartitionSpec: + ndmin = max(names) + 1 if names else 0 + unpack = lambda t: t[0] if t is not None and len(t) == 1 else t + return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) + +def _unmatch_spec(mesh: Mesh, check_rep, src: AxisNames, x: JaxType, + context_mesh) -> JaxType: + with (core.eval_context(), jax.disable_jit(False), + use_abstract_mesh(context_mesh)): + return jax.jit(HashablePartial(_unmatch, mesh, check_rep, + tuple(src.items())))(x) + +def _unmatch(mesh, check_rep, src_tup, x): + src = _names_to_pspec(dict(src_tup)) + if check_rep: + used_axes = {i for _, ns in src_tup for i in ns} + dst = P(tuple(i for i in mesh.axis_names if i in used_axes)) + else: + dst = P(mesh.axis_names) + check_rep = False + return shard_map(_add_singleton, mesh=mesh, in_specs=(src,), out_specs=dst, + check_vma=check_rep)(x) + +def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] + ) -> None: + fail = [a if n and not max(n) < a.ndim else no_fail + for n, a in zip(names, avals)] + if any(f is not no_fail for f in fail): + raise _SpecError(fail) + +class _SpecError(Exception): + pass + +def _check_reps(mesh, auto, names, vmas): + reps = [_vma_to_rep(mesh, auto, v) for v in vmas] + fail = [r if not _valid_repeats(mesh, auto, r, n) else no_fail + for n, r in zip(names, reps)] + if any(f is not no_fail for f in fail): + raise _RepError(fail) + +class _RepError(Exception): + pass + +def _match_spec(mesh: Mesh, check_rep, src_pspec: PartitionSpec, + dst_pspec: PartitionSpec, x: JaxType) -> JaxType: + fn = HashablePartial(_match, mesh, check_rep, src_pspec, dst_pspec) + with core.eval_context(), jax.disable_jit(False): + return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x) + +def _match(mesh, check_rep, src_pspec, dst_pspec, x): + return shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec, + out_specs=dst_pspec, check_vma=check_rep)(x) + +def _rem_singleton(x): return jnp.squeeze(x, axis=0) +def _add_singleton(x): return jnp.expand_dims(x, axis=0) + +def _maybe_check_special(outs): + if not config.debug_nans.value and not config.debug_infs.value: return + bufs = [s.data for leaf in tree_leaves(outs) + for s in getattr(leaf, 'addressable_shards', [])] + try: + dispatch.check_special('shard_map', bufs) + except dispatch.InternalFloatingPointError as e: + raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None + +class ShardMapTrace(core.Trace): + __slots__ = ("mesh", "auto", "check", "context_mesh") + + mesh: Mesh + auto: frozenset[AxisName] + check: bool + context_mesh: AbstractMesh + + def __init__(self, mesh, auto, check, context_mesh): + super().__init__() + self.mesh = mesh + self.auto = auto + self.check = check + self.context_mesh = context_mesh + + def to_val_vma_pair(self, val): + if isinstance(val, ShardMapTracer): + return val.val, val.vma + elif isinstance(val, Tracer): + raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") + else: + val_ = _unmatch_spec(self.mesh, self.check, {}, val, self.context_mesh) + return val_, frozenset() + + def process_primitive(self, prim, tracers, params): + in_vals, in_vma 
= unzip2(map(self.to_val_vma_pair, tracers)) + if self.check: + out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) + out_avals = tuple(out_avals) if type(out_avals) is list else out_avals + out_vma = tree_map(lambda a: a.vma, out_avals) + in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) + out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) + else: + out_vma = frozenset() + in_specs = out_specs = P(self.mesh.axis_names) + + eager_rule = eager_rules.get(prim) + if eager_rule: + out_vals = eager_rule(self.mesh, *in_vals, **params) + else: + f = HashablePartial( + _prim_applier, prim, self.check, tuple(params.items()), self.mesh, + in_specs, out_specs) + with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), + jax.debug_infs(False), use_abstract_mesh(self.context_mesh)): + out_vals = jax.jit(f)(*in_vals) + _maybe_check_special(out_vals) + if prim.multiple_results: + out_vma = (out_vma if isinstance(out_vma, (list, tuple)) + else [out_vma] * len(out_vals)) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + return ShardMapTracer(self, out_vma, out_vals) + + def process_call(self, call_primitive, fun, tracers, params): + raise NotImplementedError( + f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " + "yet supported. Put a `jax.jit` around the `shard_map`-decorated " + "function, and open a feature request at " + "https://github.com/jax-ml/jax/issues !") + + def process_map(self, map_primitive, fun, tracers, params): + raise NotImplementedError( + "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/jax-ml/jax/issues !") + + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + # Since ShardMapTrace is only used as a base main, we can drop the jvp. 
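+    # (Eager shard_map evaluation simply re-runs the primal function under
+    # this trace, so the custom JVP rule itself is never invoked here.)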
+ del prim, jvp, symbolic_zeros + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, + self.check, self.context_mesh) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + symbolic_zeros): + if symbolic_zeros: + msg = ("custom_vjp symbolic_zeros support with shard_map is not " + "implemented; please open an issue at " + "https://github.com/jax-ml/jax/issues") + raise NotImplementedError(msg) + del prim, fwd, bwd, out_trees, symbolic_zeros + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, + self.check, self.context_mesh) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + + +class ShardMapTracer(core.Tracer): + vma: frozenset[AxisName] + val: JaxType + + def __init__(self, trace, vma, val): + self._trace = trace + if isinstance(vma, set): + vma = frozenset(vma) + assert isinstance(vma, frozenset) + self.vma = vma + self.val = val + + @property + def aval(self): + aval = core.get_aval(self.val) + out = core.mapped_aval(self._trace.mesh.size, 0, aval) + new_sharding = NamedSharding( + _as_manual_mesh(self._trace.mesh, self._trace.auto), + out.sharding.spec) # pytype: disable=attribute-error + vma = self.vma if config._check_rep.value else frozenset() + return out.update(sharding=new_sharding, vma=vma) + + def to_concrete_value(self): + if self.vma == frozenset(): + with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): + return core.to_concrete_value(self.val[0]) + else: + return None + + def __str__(self) -> str: + pb_names = set(self._trace.mesh.axis_names) - self.vma + self = pvary(self, tuple(pb_names)) + with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): + blocks = list(self.val) + mesh = self._trace.mesh + axis_names = f"({', '.join(map(str, mesh.axis_names))},)" + return '\n'.join( + f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" + for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) + __repr__ = __str__ # for debuggers, like `p x` + +def _prim_applier(prim, check_rep, params_tup, mesh, in_specs, out_specs, *args): + def apply(*args): + outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) + return tree_map(_add_singleton, outs) + out_specs = list(out_specs) if type(out_specs) is tuple else out_specs + return shard_map(apply, mesh=mesh, in_specs=in_specs, out_specs=out_specs, + check_vma=check_rep)(*args) + +eager_rules: dict[core.Primitive, Callable] = {} + + +# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually +def _debug_callback_eager_rule( + mesh, + *args, + callback: Callable[..., Any], + effect: debugging.DebugEffect, + partitioned: bool, +): + del effect + with core.eval_context(): + all_blocks = zip(*map(list, args)) + for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): + callback(*blocks) + return [] + + +eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule + +def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): + del mesh, srcs, copy_semantics + for device in devices: + if device is not None: + raise ValueError("device_put with explicit device not allowed within " + f"shard_map-decorated functions, but got device {device}") + return xs +eager_rules[dispatch.device_put_p] = _device_put_eager_rule + + +# Batching + +def 
_shard_map_batch( + trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, + in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, + in_names: tuple[AxisNames, ...], + out_names_thunk: Callable[[], tuple[AxisNames, ...]], + check_rep: bool, + auto: frozenset) -> Sequence[batching.BatchTracer]: + in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) + if any(isinstance(d, batching.RaggedAxis) for d in in_dims): + raise NotImplementedError + new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] + for ax in names} for names, d in zip(in_names, in_dims)] + spmd_axis_name = trace.axis_data.spmd_name + if spmd_axis_name is not None: + used = {n for names in in_names for ns in names.values() for n in ns} + if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") + new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped + else ns for ns, d in zip(new_in_names, in_dims)] + new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) + new_axis_data = batching.AxisData(trace.axis_data.name, new_size, + trace.axis_data.spmd_name, None) + else: + new_axis_data = trace.axis_data + fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) + @as_hashable_function(closure=out_names_thunk) + def new_out_names_thunk(): + return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) + + new_params = dict(mesh=mesh, in_names=new_in_names, + out_names_thunk=new_out_names_thunk, check_rep=check_rep, + auto=auto) + with core.set_current_trace(trace.parent_trace): + out_vals = prim.bind(fun, *in_vals, **new_params) + make_tracer = partial(batching.BatchTracer, trace, + source_info=source_info_util.current()) + return map(make_tracer, out_vals, out_dims()) +batching.BatchTrace.process_shard_map = _shard_map_batch + +def _batch_out_names(spmd_axis_name, dims, out_names): + out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] + for ax in names} for names, d in zip(out_names, dims)] + if spmd_axis_name is not None: + used = {n for names in out_names for ns in names.values() for n in ns} + if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") + out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped + else ns for ns, d in zip(out_names_, dims)] + return out_names_ + + +# Autodiff + +def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, + out_names_thunk, check_rep, auto): + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) + which_nz = [ type(t) is not ad.Zero for t in tangents] + tangents = [t if type(t) is not ad.Zero else None for t in tangents] + args, in_tree = tree_flatten((primals, tangents)) + f_jvp = ad.jvp_subtrace(f, trace.tag) + f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) + tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] + + @as_hashable_function(closure=out_names_thunk) + def new_out_names_thunk(): + out_ax = out_names_thunk() + return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) + params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), + out_names_thunk=new_out_names_thunk, check_rep=check_rep, + auto=auto) + f_jvp, out_tree = ad.traceable(f_jvp, in_tree) + result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) + primal_out, 
+
+def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
+                            f: lu.WrappedFun, tracers, mesh, in_names,
+                            out_names_thunk, check_rep, auto):
+  tracers = map(trace.to_jaxpr_tracer, tracers)
+  in_pvals = [t.pval for t in tracers]
+  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
+  unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
+  in_avals_sharded = map(partial(_shard_aval, mesh, auto, check_rep),
+                         unk_in_names, in_avals)
+  f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
+  f = _promote_scalar_residuals(f)
+  f_known, aux = pe.partial_eval_wrapper_nounits2(
+      f, (*in_knowns,), (*in_avals_sharded,))
+  all_names = _all_newly_manual_mesh_names(mesh, auto)
+
+  @as_hashable_function(closure=out_names_thunk)
+  def known_out_names():
+    _, _, out_knowns, res_avals, _, _ = aux()
+    _, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
+    if check_rep:
+      res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)}
+                   for a in res_avals]
+    else:
+      res_names = [{0: all_names}] * len(res_avals)
+    return (*out_known_names, *res_names)
+
+  known_params = dict(mesh=mesh, in_names=(*known_in_names,),
+                      out_names_thunk=known_out_names, check_rep=check_rep,
+                      auto=auto)
+  out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts),
+                                    known_params)
+  in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux()
+  num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
+  out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
+  assert not jaxpr.constvars
+  unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
+  known_out_names_ = known_out_names()
+  res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
+  # TODO make res_avals be the full set, not just the non-fwd ones
+  res_avals_iter = iter(res_avals)
+  res_names = []
+  for f1, f2 in zip(in_fwd, out_fwd):
+    if f1 is not None:
+      res_names.append(known_in_names[f1])
+    elif f2 is not None:
+      res_names.append(known_out_names_[f2])
+    else:
+      if check_rep:
+        res_vma = next(res_avals_iter).vma
+        res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)})
+      else:
+        res_names.append({0: all_names})
+  unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)  # type: ignore[assignment]
+  const_tracers = map(trace.new_instantiated_const, res)
+  env_tracers = map(trace.to_jaxpr_tracer, env)
+  unk_arg_tracers = [t for t in tracers if not t.is_known()]
+  out_avals_sharded = [v.aval for v in jaxpr.outvars]
+  unk_params = dict(mesh=mesh, in_names=unk_in_names,
+                    out_names=unk_out_names, jaxpr=jaxpr,
+                    check_rep=check_rep, auto=auto)
+  out_avals = map(partial(_unshard_aval, mesh, check_rep), unk_out_names,
+                  out_avals_sharded)
+  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
+                 for a in out_avals]
+  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
+  eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers),
+                          out_tracers, shard_map_p, unk_params,
+                          effs, source_info_util.current())
+  for t in out_tracers: t.recipe = eqn
+  return merge_lists(out_knowns, out_tracers, out_consts)
+pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
+
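The partial-eval rule above is the path `jax.linearize` takes: a known shard_map runs the primal computation while an unknown one is staged out for the tangents, and residuals that are just forwarded primal inputs or outputs are substituted rather than stored twice. A rough sketch, not from this patch, under the same 2-device mesh assumptions as earlier:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
g = shard_map(lambda a: jnp.sin(a) * a, mesh, in_specs=P('x'), out_specs=P('x'))
y, g_lin = jax.linearize(g, jnp.arange(8.))   # splits via the rule above
y_dot = g_lin(jnp.ones(8))                    # tangents run as their own shard_map
+def 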
_shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, + tracers, mesh, in_names, + out_names_thunk, check_rep, auto): + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) + nzs_in = tuple(type(t) is not ad.Zero for t in tangents) + f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) + f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) + all_names = _all_newly_manual_mesh_names(mesh, auto) + + @as_hashable_function(closure=linearize_outs_thunk) + def fwd_out_names_thunk(): + res_avals, _, _, _, _, _ = linearize_outs_thunk() + out_names = out_names_thunk() + if check_rep: + res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} + for a in res_avals] + else: + res_names = [{0: all_names}] * len(res_avals) + return (*res_names, *out_names) + fwd_params = dict( + mesh=mesh, in_names=in_names, + out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, auto=auto) + all_fwd_results = shard_map_p.bind_with_trace( + trace.parent_trace, (f_primal, *primals), fwd_params) + res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + non_fwd_res = all_fwd_results[:num_res_out] + primals_out = all_fwd_results[num_res_out:] + residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) + args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None + for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] + with (_extend_axis_env(mesh, auto), + use_abstract_mesh(_as_manual_mesh(mesh, auto)), + config._check_rep(check_rep)): + lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) + out_names = out_names_thunk() + res_avals_iter = iter(res_avals) + res_names = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_names.append(in_names[f1]) + elif f2 is not None: + res_names.append(out_names[f2]) + else: + if check_rep: + res_vma = next(res_avals_iter).vma + res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) + else: + res_names.append({0: all_names}) + new_in_names = (*res_names, *({} for _ in range(len(env))), + *(ax for ax, nz in zip(in_names, nzs_in) if nz)) + tangent_out_names = tuple(ax for ax, nz in zip(out_names_thunk(), nzs_out) if nz) + @as_hashable_function(closure=tangent_out_names) + def tangent_out_names_thunk(): + return tangent_out_names + tangent_params = dict( + mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, + check_rep=check_rep, auto=auto) + + # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here + def f_tangent(*args): + return core.eval_jaxpr(lin_jaxpr, (), *args) + + nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] + nz_tangents_out = shard_map_p.bind_with_trace( + trace.tangent_trace, + (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), + *residuals, *env, *nz_tangents_in), tangent_params) + nz_tangents_out_iter = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) + for nz, primal in zip(nzs_out, primals_out)] + return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out) +ad.LinearizeTrace.process_shard_map = _shard_map_linearize + +@lu.transformation2 +def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): + ans = f(*args, **kwargs) + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is 
None for f1, f2 in zip(in_fwd, out_fwd)) + residuals = ans[:num_res_out] + primals = ans[num_res_out:] + residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in residuals] + return *residuals, *primals + +@lu.transformation2 +def _promote_scalar_residuals(f: Callable, *args, **kwargs): + jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) + which = [f1 is None and f2 is None and not v.aval.shape + for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] + jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) + out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in out_consts] + return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) + +def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]): + def fun(*res_and_args): + res, args = split_list(res_and_args, [len(jaxpr.constvars)]) + res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] + return core.eval_jaxpr(jaxpr, res, *args) + res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval + for v, w in zip(jaxpr.constvars, which)] + in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals) + return jaxpr + + +def _unmentioned2(mesh: Mesh, names: AxisNames, + auto: frozenset[AxisName]) -> list[AxisName]: + # We use a filtered-down version of unmentioned to avoid defensive-psum over + # more chips than required in the transpose-no-check-rep case. + name_set = {n for ns in names.values() for n in ns} | auto + return [n for n in _all_mesh_names_except_spmd(mesh, auto) + if n not in name_set] + + +def _shard_map_transpose(out_cts, *args, + jaxpr: core.Jaxpr, mesh, in_names, out_names, + check_rep, auto): + mb_div = lambda x, y: x / y if y != 1 else x + out_cts = [ + ad.Zero(_shard_aval(mesh, auto, check_rep, ns, x.aval)) + if type(x) is ad.Zero else x if check_rep or dtypes.dtype(x) == dtypes.float0 + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) + for ns, x in zip(out_names, out_cts) + ] + args = tuple(x if type(x) is not ad.UndefinedPrimal else + ad.UndefinedPrimal(_shard_aval(mesh, auto, check_rep, ns, x.aval)) + for ns, x in zip(in_names, args)) + all_args, in_tree = tree_flatten((out_cts, args)) + + def fun_trans_callable(out_cts, args): + # TODO(mattjj): when #26811 lands, delete this and just run backward_pass + in_undef = map(ad.is_undefined_primal, args) + res, undefs = partition_list(in_undef, args) + jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( + pe.close_jaxpr(jaxpr), in_undef, False) + res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res) + in_cts = ad.backward_pass( + jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts + )[len(res_reshaped):] + _, in_ct_names = partition_list(in_undef, in_names) + in_cts = [ad.Zero(_unshard_aval(mesh, check_rep, ns, x.aval)) + if type(x) is ad.Zero else x if check_rep + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) + for ns, x in zip(in_ct_names, in_cts)] + res_zeros = [ad_util.zero_from_primal(r) for r in res] + return merge_lists(in_undef, res_zeros, in_cts) + + fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info) + fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) + fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) + + new_in_names = \ + [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \ + [n for n, x in 
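Both `_promote_scalar_residuals` paths rewrite rank-0 residuals to shape `(1,)` so that each residual has an axis 0 able to carry an `in_names`/`out_names` entry across the shard_map boundary. The promotion in isolation, purely illustrative:

import jax
import jax.numpy as jnp

x = jnp.float32(3.0)                   # stands in for a scalar residual
x1 = jax.lax.broadcast(x, (1,))        # promoted: shape (1,), axis 0 is nameable
assert x1.shape == (1,)
assert jnp.squeeze(x1, axis=0) == x    # what _rem_singleton undoes inside the body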
zip(in_names, args) if type(x) is not ad.UndefinedPrimal] + + def new_out_names_thunk(): + return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) + + try: + out_flat = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), + out_names_thunk=new_out_names_thunk, check_rep=check_rep, + auto=auto) + except (FloatingPointError, ZeroDivisionError) as e: + print("Invalid nan value encountered in the backward pass of a shard_map " + "function. Calling the de-optimized backward pass.") + try: + # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)` + # in eager mode so that output of shmap are not manual. + with jax.disable_jit(True): + _ = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), + out_names_thunk=new_out_names_thunk, check_rep=check_rep, + auto=auto) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + dispatch._raise_no_nan_in_deoptimized(e) + return tree_unflatten(out_tree(), out_flat) +ad.primitive_transposes[shard_map_p] = _shard_map_transpose + +# Remat + +def _partial_eval_jaxpr_custom_rule( + saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool], + inst_in: Sequence[bool], eqn: core.JaxprEqn +) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], + list[core.Var]]: + jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] + check_rep, auto = eqn.params['check_rep'], eqn.params['auto'] + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ + pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) + num_out_primals = len(jaxpr_known.outvars) - num_res + in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] + out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) + idx_map = {id(v): i for i, v in enumerate(out_vars)} + out_fwd = [idx_map.get(id(v)) for v in res_vars] + which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] + mesh = eqn.params['mesh'] + with (_extend_axis_env(mesh, auto), + use_abstract_mesh(_as_manual_mesh(mesh, auto)), + config._check_rep(check_rep)): + jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) + jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) + jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) + jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names) + ins_known, _ = partition_list(unks_in, eqn.invars) + out_binders_known, _ = partition_list(unks_out, eqn.outvars) + _, ins_staged = partition_list(inst_in, eqn.invars) + _, out_binders_staged = partition_list(inst_out, eqn.outvars) + newvar = core.gensym() + residuals, staged_in_res_names = [], [] + for var, w in zip(jaxpr_staged.invars[:num_res], which): + if w: + rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore + if check_rep else {0: _all_newly_manual_mesh_names(mesh, auto)}) + residuals.append(newvar(_unshard_aval(mesh, check_rep, rn, var.aval))) + staged_in_res_names.append(rn) + if check_rep: + out_res_names_known = [ + {0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} + for var, o in zip(res_vars, out_fwd) if o is None + ] + else: + out_res_names_known = [{0: _all_newly_manual_mesh_names(mesh, auto)}] * sum(which) + params_known, params_staged = _pe_custom_params( + unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, + out_res_names_known, 
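The transpose rule above is what `jax.grad` reaches; note the `mb_div`/`psum` handling of cotangents along mesh axes the specs leave unmentioned. A small end-to-end sketch, not from this patch, with the same toy 2-device mesh:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))

def loss(w, xs):
  body = lambda w, xs: jax.lax.psum(jnp.sum((xs * w) ** 2), 'x')
  return shard_map(body, mesh, in_specs=(P(), P('x')), out_specs=P())(w, xs)

gw = jax.grad(loss)(1.0, jnp.arange(8.))  # w's cotangent is summed across shards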
staged_in_res_names, + dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) + eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], + eqn.primitive, params_known, jaxpr_known.effects, + eqn.source_info, eqn.ctx) + full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) + eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, + eqn.primitive, params_staged, + jaxpr_staged.effects, eqn.source_info, eqn.ctx) + assert len(eqn_staged.invars) == len(jaxpr_staged.invars) + new_inst = [x for x, inst in zip(eqn.invars, inst_in) + if type(x) is core.Var and not inst] + new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] + return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals +pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ + _partial_eval_jaxpr_custom_rule + +def _add_reshapes(which: Sequence[bool], + jaxpr_known: core.Jaxpr, + jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: + # add singleton axes to residuals which are from jaxpr_known and are scalars + which_ = [w and not v.aval.shape # pytype: disable=attribute-error + for w, v in zip(which, jaxpr_staged.invars[:len(which)])] + if not any(which_): return jaxpr_known, jaxpr_staged + assert not jaxpr_known.constvars and not jaxpr_staged.constvars + + def known(*args): + out = core.eval_jaxpr(jaxpr_known, (), *args) + out_known, res = split_list(out, [len(out) - sum(which)]) + res = [_add_singleton(x) if not x.shape else x for x in res] + return [*out_known, *res] + avals_in = [v.aval for v in jaxpr_known.invars] + jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in) + + def staged(*args): + res_, ins = split_list(args, [len(which)]) + res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)] + return core.eval_jaxpr(jaxpr_staged, (), *res, *ins) + res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval + for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] + avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]] + jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in) + + return jaxpr_known, jaxpr_staged + +def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, + in_fwd, out_fwd, out_res_names_known, staged_in_res_names, + params_known, params_staged): + # prune inputs to jaxpr_known according to unks_in + in_names_known, _ = partition_list(unks_in, params_known['in_names']) + _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) + out_names_known = out_names_known + out_res_names_known + assert len(out_names_known) == len(params_known['jaxpr'].outvars) + new_params_known = dict(params_known, in_names=tuple(in_names_known), + out_names=tuple(out_names_known)) + + # added num_res new inputs to jaxpr_staged, pruning according to inst_in + _, in_names_staged = partition_list(inst_in, params_staged['in_names']) + iter_staged = iter(staged_in_res_names) + res_names = [in_names_known[f1] if f1 is not None else + out_names_known[f2] if f2 is not None else + next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)] + + in_names_staged = res_names + in_names_staged + _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) + new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), + out_names=tuple(out_names_staged)) + return new_params_known, 
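This custom partial-eval rule is what `jax.checkpoint` (remat) of a shard_map goes through: residuals that are merely forwarded inputs or outputs of the known part are substituted via `subs_list2` instead of being saved a second time. Sketch, not from this patch:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
h = shard_map(lambda a: jnp.sin(jnp.cos(a)), mesh,
              in_specs=P('x'), out_specs=P('x'))
y, h_vjp = jax.vjp(jax.checkpoint(h), jnp.arange(8.))
(ct,) = h_vjp(jnp.ones(8))   # intermediates recomputed inside the staged shard_map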
new_params_staged + +# TODO(mattjj): remove this mechanism when we revise mesh scopes +def _all_mesh_names_except_spmd( + mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: + axis_env = core.get_axis_env() + spmd_names = axis_env.spmd_axis_names + return tuple(name for name in mesh.axis_names if name not in spmd_names and + name not in auto) + +def _all_newly_manual_mesh_names( + mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: + axis_env = core.get_axis_env() + vmap_spmd_names = set(axis_env.spmd_axis_names) + if not (ctx_mesh := get_abstract_mesh()).empty: + mesh = ctx_mesh + already_manual_names = set(ctx_mesh.manual_axes) + else: + # TODO(mattjj): remove this mechanism when we revise mesh scopes + already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names + return tuple(name for name in mesh.axis_names + if name not in auto | vmap_spmd_names | already_manual_names) + + +# DCE + +# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? +def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn + ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + mesh = eqn.params["mesh"] + auto = eqn.params["auto"] + check_rep = eqn.params["check_rep"] + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) + if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: + return used_inputs, None + else: + _, in_names = partition_list(used_inputs, eqn.params['in_names']) + _, out_names = partition_list(used_outputs, eqn.params['out_names']) + new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names), + out_names=tuple(out_names)) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + new_eqn = pe.new_jaxpr_eqn( + [v for v, used in zip(eqn.invars, used_inputs) if used], + [x for x, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) + return used_inputs, new_eqn +pe.dce_rules[shard_map_p] = _shard_map_dce + +# Implementing pmap in terms of shard_map + +def pmap(f, axis_name=None, *, in_axes=0, out_axes=0, + static_broadcasted_argnums=(), devices=None, backend=None, + axis_size=None, donate_argnums=(), global_arg_shapes=None): + devices = tuple(devices) if devices is not None else devices + axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( + f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes) + + def infer_params(*args, **kwargs): + p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, + donate_tuple, devices, backend, axis_size, args, kwargs) + for arg in p.flat_args: + dispatch.check_arg(arg) + mesh = Mesh(_get_devices(p, backend), (axis_name,)) + _pmapped, in_specs, out_specs = _cached_shard_map( + p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name) + flat_global_args = host_local_array_to_global_array( + p.flat_args, mesh, list(in_specs)) + jitted_f = jax.jit( + _pmapped, + donate_argnums=(i for i, val in enumerate(p.donated_invars) if val)) + return jitted_f, flat_global_args, p.out_tree, mesh, out_specs + + def wrapped(*args, **kwargs): + (jitted_f, flat_global_args, out_tree, mesh, + out_specs) = infer_params(*args, **kwargs) + outs = jitted_f(*flat_global_args) + outs = global_array_to_host_local_array(outs, mesh, out_specs()) + return tree_unflatten(out_tree(), outs) + + def 
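The pmap reduction above amounts to the following equivalence, written out by hand as a sketch (not from this patch; assumes `n = jax.device_count()` local devices and the squeeze/expand that `_handle_reshapes` performs):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

n = jax.device_count()
xs = jnp.arange(float(n))
via_pmap = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

mesh = Mesh(np.array(jax.devices()), ('i',))
body = lambda x: jnp.expand_dims(jax.lax.psum(jnp.squeeze(x, 0), 'i'), 0)
via_shmap = jax.jit(shard_map(body, mesh, in_specs=P('i'), out_specs=P('i')))(xs)
np.testing.assert_allclose(via_pmap, via_shmap)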
lower(*args, **kwargs): + jitted_f, _, _, _, _ = infer_params(*args, **kwargs) + return jitted_f.lower(*args, **kwargs) + wrapped.lower = lower + + return wrapped + + +@lu.cache +def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): + in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat)) + out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk()) + fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk) + return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=False, + axis_names=set(mesh.axis_names)), + in_specs, out_specs) + +@lu.transformation2 +def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): + args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), + list(args), list(in_axes)) + out = f(*args) + return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), + list(out), list(out_axes_thunk())) + +def _axis_to_spec(axis_name, ax): + if isinstance(ax, int): + specs = [None] * ax + [axis_name] + return P(*specs) + elif ax is None: + return P() + else: + raise TypeError(ax) + +def _get_devices(p, backend): + if backend is not None and p.devices is None: + devs = jax.devices(backend=backend) + else: + devs = jax.devices() if p.devices is None else p.devices + if jax.process_count() > 1: + return devs[:p.global_axis_size] + return devs[:p.local_axis_size] diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 7275046f556d..6f604f1195a0 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -38,7 +38,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.util import weakref_lru_cache -from jax.experimental.shard_map import shard_map_p +from jax._src.shard_map import shard_map_p import numpy as np diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py index 6a7f2916b503..fcfe3ff4b9ff 100644 --- a/jax/experimental/roofline/roofline.py +++ b/jax/experimental/roofline/roofline.py @@ -29,11 +29,11 @@ from jax._src.mesh import AbstractMesh, Mesh from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map from jax._src.util import foreach -from jax.experimental import shard_map +from jax._src.shard_map import shard_map, shard_map_p ShapeDtypeStructTree = Any - +Specs = Any map = util.safe_map @@ -230,8 +230,8 @@ def wrapped(*args): def roofline( f: Callable, mesh: Mesh | AbstractMesh | None = None, - in_specs: shard_map.Specs | None = None, - out_specs: shard_map.Specs | None = None, + in_specs: Specs | None = None, + out_specs: Specs | None = None, *, pin_lhs_in_vmem: bool = False, pin_rhs_in_vmem: bool = False, @@ -243,14 +243,15 @@ def roofline( def wrapped(*args): wrapped_f = f if in_specs is not None and out_specs is not None and mesh is not None: - wrapped_f = shard_map.shard_map(wrapped_f, mesh, in_specs, out_specs) + wrapped_f = shard_map(wrapped_f, mesh=mesh, in_specs=in_specs, + out_specs=out_specs) if vjp: wrapped_f = _f_with_vjp(wrapped_f) jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args) def make_sharded_shape_dtype_struct( - shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs + shape: api.ShapeDtypeStruct, out_spec: Specs ) -> api.ShapeDtypeStruct: return api.ShapeDtypeStruct( shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec) # type: ignore @@ -267,7 +268,7 @@ def make_sharded_shape_dtype_struct( used_outputs = (True,) * 
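After this change, downstream modules such as key_reuse and roofline import from the new internal home, while the experimental module remains as a thin forwarder. Both import paths keep working, as the hunks above show; illustrative only:

from jax.experimental.shard_map import shard_map            # forwarding wrapper
from jax._src.shard_map import shard_map as shard_map_impl  # new internal home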
len(jaxpr.jaxpr.outvars) jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs) shard_map_eqns = [ - e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p + e for e in jaxpr.eqns if e.primitive == shard_map_p ] if shard_map_eqns: try: @@ -307,8 +308,8 @@ def standard_rule(ctx: RooflineRuleContext, *args, **kwargs): def roofline_and_grad( f: Callable, mesh: Mesh | AbstractMesh, - in_specs: shard_map.Specs, - out_specs: shard_map.Specs, + in_specs: Specs, + out_specs: Specs, *, pin_lhs_in_vmem: bool = False, pin_rhs_in_vmem: bool = False, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 52dc45b68ef0..afb61159f55b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -11,71 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from collections.abc import Callable, Hashable, Sequence -import enum -from functools import partial -import inspect -from math import prod -import operator as op -from typing import Any, TypeVar, Union - -import numpy as np - -import jax -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec -from jax._src import ad_util -from jax._src import api_util -from jax._src import config -from jax._src import core -from jax._src import debugging -from jax._src import dispatch -from jax._src import dtypes -from jax._src import linear_util as lu -from jax._src import sharding_impls -from jax._src import source_info_util +from collections.abc import Callable, Hashable +from typing import Any from jax._src import traceback_util -from jax._src import util -from jax._src.core import pvary -from jax._src.core import Tracer, typeof -from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, - get_abstract_mesh) -from jax._src.api import _shared_code_pmap, _prepare_pmap -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import sdy -from jax._src.util import (HashableFunction, HashablePartial, unzip2, - as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2) -from jax._src.interpreters import batching -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import pxla -from jax._src.interpreters import ad -from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, - tree_structure, tree_leaves, keystr) -from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, - generate_key_paths, KeyPath) -from jax.experimental.multihost_utils import (host_local_array_to_global_array, - global_array_to_host_local_array) - -P = PartitionSpec +from jax.sharding import Mesh, AbstractMesh +from jax._src import shard_map as jshmap -map, unsafe_map = util.safe_map, map -zip, unsafe_zip = util.safe_zip, zip -traceback_util.register_exclusion(__file__) - -# API - -Specs = Any # PyTree[PartitionSpec] +Specs = Any AxisName = Hashable - @traceback_util.api_boundary -def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, - out_specs: Specs, check_rep: bool = True, - auto: frozenset[AxisName] = frozenset()): +def shard_map( + f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs, + check_rep: bool = True, auto: frozenset[AxisName] = frozenset()): """Map a function over shards of data. 
Note: @@ -128,1593 +77,6 @@ def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, .. _SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html """ axis_names = frozenset(mesh.axis_names) - auto - return shard_map2(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, - check_vma=check_rep, axis_names=axis_names) - - -def shard_map2(f, /, *, out_specs: Specs, - axis_names: set[AxisName] | frozenset[AxisName] = set(), - in_specs: Specs | None = None, - mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True): - return _shard_map(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, - axis_names=axis_names, check_vma=check_vma) - -def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, - in_specs: Specs, out_specs: Specs | Callable[[], Specs], - axis_names: set[AxisName] | frozenset[AxisName], - check_vma: bool): - if not callable(f): - raise TypeError("shard_map requires a callable for its first argument, " - f"but got {f} of type {type(f)}.") - - if mesh is None: - mesh = get_abstract_mesh() - if mesh.empty: - raise ValueError( - "The context mesh cannot be empty. Either use" - " `jax.sharding.use_mesh(mesh)` to enter into a mesh context or pass" - " a mesh to `shard_map` via the `mesh` keyword argument.") - if not isinstance(mesh, (Mesh, AbstractMesh)): - raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " - "`jax.sharding.AbstractMesh` instance for its " - f"second argument, but got {mesh} of type {type(mesh)}.") - - if not isinstance(axis_names, (frozenset, set)): - raise TypeError( - "`axis_names` argument of shard_map should be of type `frozenset` or" - f" `set`. Got type: {type(axis_names)}") - if isinstance(axis_names, set): - axis_names = frozenset(axis_names) - if not axis_names: - axis_names = frozenset(mesh.axis_names) - auto = frozenset(mesh.axis_names) - frozenset(axis_names) - if not auto.issubset(mesh.axis_names): - raise ValueError(f"shard_map requires auto={auto} to be a subset of " - f"mesh.axis_names={mesh.axis_names}") - - if in_specs is not None: - _check_specs(SpecErrorType.input, in_specs, auto) - if not callable(out_specs): - _check_specs(SpecErrorType.out, out_specs, auto) - - @util.wraps(f) - @traceback_util.api_boundary - def wrapped(*args): - fun = lu.wrap_init( - f, debug_info=api_util.debug_info("shard_map", f, args, {})) - args_flat, in_tree = tree_flatten(args) - fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) - - # TODO(yashkatariya): Maybe we don't have to be this strict? - if mesh._any_axis_auto and in_specs is None: - raise TypeError( - "shard_map in_specs argument must be a pytree of" - " `jax.sharding.PartitionSpec` instances, but it was None when mesh" - f" {mesh} has `Auto` axes.\n") - - try: - in_specs_flat = broadcast_prefix( - in_specs, args, is_leaf=lambda x: x is None) - except ValueError: - e, *_ = prefix_errors(in_specs, args) - raise e('shard_map in_specs') from None - - # TODO(yashkatariya): Relax this and convert only `None`s in `in_specs_flat` - # and accept the other specs as is. 
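For reference, the canonical use of the API whose wrapper is being slimmed down here: a contraction sharded over `'x'`, finished with a `psum`. A sketch, not from this patch, assuming 2 devices and toy shapes:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))

def matmul_block(a_blk, b_blk):             # a_blk: (4, 2), b_blk: (2, 4)
  return jax.lax.psum(a_blk @ b_blk, 'x')   # sum partial products over 'x'

mm = shard_map(matmul_block, mesh, in_specs=(P(None, 'x'), P('x', None)),
               out_specs=P(None, None))
c = mm(jnp.ones((4, 4)), jnp.ones((4, 4)))  # == jnp.ones((4, 4)) @ jnp.ones((4, 4))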
- if mesh._are_all_axes_explicit and in_specs is None: - arg_s = [typeof(a).sharding for a in args_flat] - assert all(i is None for i in in_specs_flat), in_specs_flat - in_specs_flat = [_manual_spec(axis_names, s.spec) for s in arg_s] - - dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) - if s is not None) - fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) - _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) - in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - - @memoize - def out_names_thunk(): - if callable(out_specs): - out_specs_ = out_specs() - _check_specs(SpecErrorType.out, out_specs_, auto) - else: - out_specs_ = out_specs - dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) - try: - out_specs_flat = broadcast_prefix(out_specs_, dummy) - except ValueError: - e, *_ = prefix_errors(out_specs_, dummy) - raise e('shard_map out_specs') from None - return tuple(map(_canonicalize_spec, out_specs_flat)) - - if check_vma: - fun = _implicit_pvary_on_output(fun, out_names_thunk) - - try: - out_flat = shard_map_p.bind( - fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_vma, auto=auto) - except _SpecError as e: - fails, = e.args - if not callable(out_specs): - msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails) - if any(fail is not no_fail and not fail.shape for fail in fails): - msg += (" In particular, for rank 0 outputs which are not constant " - "over the mesh, add at least one (singleton) axis to them so " - "that they can be concatenated using out_specs.") - raise ValueError(msg) from None - except _RepError as e: - fails, = e.args - if not callable(out_specs): - msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails) - raise ValueError(msg) from None - return tree_unflatten(out_tree(), out_flat) - return wrapped - - -# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs -AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable -def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: - if isinstance(spec, PartitionSpec): - return {i: names if isinstance(names, tuple) else (names,) - for i, names in enumerate(spec) if names is not None} - else: - return spec - -def _manual_spec(manual_axes, spec: P) -> P: - out = [] # type: ignore - for s in spec: - if s is None: - out.append(s) - elif isinstance(s, tuple): - temp = [p if p in manual_axes else None for p in s] - while temp and temp[-1] is None: - temp.pop() - if None in temp: - raise ValueError(f"Invalid spec: {spec}") - out.append(None if len(temp) == 0 else tuple(temp)) - else: - out.append(s if s in manual_axes else None) - return P(*out) - - -# Error checking and messages - -SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) - -def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None: - if error_type == SpecErrorType.input and specs is None: - raise TypeError( - "shard_map in_specs argument must be a pytree of " - "`jax.sharding.PartitionSpec` instances, but it was None.\n" - "Instead of `in_specs=None`, did you mean `in_specs=P()`, " - "where `P = jax.sharding.PartitionSpec`?") - def check_spec(p): - if not isinstance(p, PartitionSpec): - return False - for names in p: - if not isinstance(names, tuple): - names = (names,) - for name in names: - if name in auto: - return False - return True - if all(check_spec(p) for p in tree_leaves(specs)): return - 
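Internally, specs become `AxisNames` dicts mapping an array axis index to the tuple of mesh axis names sharding it, which is exactly what `_canonicalize_spec` computes. A standalone illustration of that mapping:

from jax.sharding import PartitionSpec as P

# P('x', None, ('y', 'z')) canonicalizes to {0: ('x',), 2: ('y', 'z')}:
spec = P('x', None, ('y', 'z'))
names = {i: ns if isinstance(ns, tuple) else (ns,)
         for i, ns in enumerate(spec) if ns is not None}
assert names == {0: ('x',), 2: ('y', 'z')}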
prefix = 'in' if error_type == SpecErrorType.input else 'out' - msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " - for key, x in generate_key_paths(specs) if not isinstance(x, P)] - if not msgs: - for key, p in generate_key_paths(specs): - for names in p: - if not isinstance(names, tuple): - names = (names,) - for name in names: - if name in auto: - msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") - raise ValueError( - f"shard_map {prefix}_specs argument cannot refer to an axis " - f"marked auto ({auto}), but:\n\n" - + '\n\n'.join(msgs) + '\n\n' - f"Check the {prefix}_specs values passed to shard_map.") - raise TypeError( - f"shard_map {prefix}_specs argument must be a pytree of " - f"`jax.sharding.PartitionSpec` instances, but:\n\n" - + '\n\n'.join(msgs) + '\n\n' - f"Check the {prefix}_specs values passed to shard_map.") - -class NoFail: pass -no_fail = NoFail() - -def _check_specs_vs_args( - f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs, - dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], - xs: Sequence) -> None: - in_avals = map(core.shaped_abstractify, xs) - fail = [a if not len(p) <= a.ndim else no_fail - for p, a in zip(in_specs_flat, in_avals)] - if any(f is not no_fail for f in fail): - fail = _expand_fail(in_tree, dyn_argnums, fail) - msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) - raise ValueError(msg) - in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) - for d, ns in names.items()) else no_fail - for a, names in zip(in_avals, in_names_flat)] - if any(f is not no_fail for f in fail): - fail = _expand_fail(in_tree, dyn_argnums, fail) - msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) - raise ValueError(msg) - -def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], - fail: Sequence[core.ShapedArray | NoFail] - ) -> list[core.ShapedArray | NoFail]: - fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves - for i, f in zip(dyn_argnums, fail): - fail_[i] = f - return fail_ - -def _spec_rank_error( - error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, - fails: list[core.ShapedArray | NoFail]) -> str: - fun_name = getattr(f, '__name__', str(f)) - if error_type == SpecErrorType.input: - prefix, base = 'in', 'args' - ba = _try_infer_args(f, tree) - else: - prefix, base = 'out', f'{fun_name}(*args)' - msgs = [] - for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): - extra = "" - if error_type == SpecErrorType.input and ba is not None: - arg_key, *_ = fail_key - if arg_key.idx < len(ba.arguments): - param_name = list(ba.arguments.keys())[arg_key.idx] - extra = (f", where {base}{arg_key} is bound to {fun_name}'s " - f"parameter '{param_name}',") - else: - param = list(ba.signature.parameters.values())[-1] - assert param.kind == inspect.Parameter.VAR_POSITIONAL - extra = (f", where {base}{arg_key} is the index " - f"{arg_key.idx - len(ba.signature.parameters) + 1} component " - f"of {fun_name}'s varargs parameter '{param.name}',") - msgs.append( - f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length " - f"{len(spec)}, but " - f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " - f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given an " - 
f"{prefix}_specs entry which is too long to be compatible with the " - f"corresponding {prefix}put value from the function:\n\n" - + '\n\n'.join(msgs) + '\n\n' + - f"Entries in {prefix}_specs must be of length no greater than the " - f"number of axes in the corresponding {prefix}put value.\n\n" - f"Either revise the spec to be shorter, or modify '{fun_name}' so " - f"that its {prefix}puts have sufficient rank.") - if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)): - msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs " - "entry of `P()`, where `P = jax.sharding.PartitionSpec`.") - return msg - -def _spec_divisibility_error( - f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, - fails: list[core.ShapedArray | NoFail]) -> str: - ba = _try_infer_args(f, tree) - fun_name = getattr(f, '__name__', str(f)) - msgs = [] - for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): - extra = "" - if ba is not None: - arg_key, *_ = fail_key - if arg_key.idx < len(ba.arguments): - param_name = list(ba.arguments.keys())[arg_key.idx] - extra = (f", where args{arg_key} is bound to {fun_name}'s " - f"parameter '{param_name}',") - else: - param = list(ba.signature.parameters.values())[-1] - assert param.kind == inspect.Parameter.VAR_POSITIONAL - extra = (f", where args{arg_key} is the index " - f"{arg_key.idx - len(ba.signature.parameters) + 1} component " - f"of {fun_name}'s varargs parameter '{param.name}',") - names = _canonicalize_spec(spec) - for d, ns in names.items(): - if aval.shape[d] % prod(mesh.shape[n] for n in ns): - axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" - total = 'total ' if len(ns) > 1 else '' - sz = prod(mesh.shape[n] for n in ns) - msgs.append( - f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} " - f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " - f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " - f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " - f"{aval.shape[d]}") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given argument " - f"arrays with axis sizes that are not evenly divisible by the " - f"corresponding mesh axis sizes:\n\n" - f"The mesh given has shape {tuple(mesh.shape.values())} with " - f"corresponding axis names {mesh.axis_names}.\n\n" - + '\n\n'.join(msgs) + '\n\n' + - f"Array arguments' axis sizes must be evenly divisible by the mesh " - f"axis or axes indicated by the corresponding elements of the " - f"argument's in_specs entry. 
Consider checking that in_specs are " - f"correct, and if so consider changing the mesh axis sizes or else " - f"padding the input and adapting '{fun_name}' appropriately.") - return msg - -def _inout_rep_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, - specs: Specs, fails: list[set | NoFail]) -> str: - fun_name = getattr(f, '__name__', str(f)) - msgs = [] - for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): - dst = _canonicalize_spec(spec) - unmentioned = _unmentioned(mesh, dst) - if len(unmentioned) > 1: - need_rep = ','.join(map(str, unmentioned)) - got_rep = ','.join(map(str, rep)) - diff = ','.join(map(str, [n for n in unmentioned if n not in rep])) - msgs.append( - f"* out_specs{keystr(spec_key)} is {spec} which implies that the " - f"corresponding output value is replicated across mesh axes " - f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, " - f"which is missing the required axes {diff}") - else: - need_rep_, = unmentioned - msgs.append( - f"* out_specs{keystr(spec_key)} is {spec} which implies that the " - f"corresponding output value is replicated across mesh axis " - f"'{need_rep_}', but could not infer replication over any axes") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given " - f"out_specs which require replication which can't be statically " - f"inferred given the mesh:\n\n" - f"The mesh given has shape {tuple(mesh.shape.values())} with " - f"corresponding axis names {mesh.axis_names}.\n\n" - + '\n\n'.join(msgs) + '\n\n' + - "Check if these output values are meant to be replicated over those " - "mesh axes. If not, consider revising the corresponding out_specs " - "entries. 
If so, consider disabling the check by passing the " - "check_rep=False argument to shard_map.") - return msg - -def _unmentioned(mesh: Mesh | AbstractMesh, names: AxisNames) -> list[AxisName]: - name_set = {n for ns in names.values() for n in ns} - return [n for n in mesh.axis_names if n not in name_set] - - -def _try_infer_args(f, tree): - dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) - try: - return inspect.signature(f).bind(*dummy_args) - except (TypeError, ValueError): - return None - -T = TypeVar('T') -def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] - ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]: - failures = tree_unflatten(tree, fails) - failures_aug = generate_key_paths(failures) - specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs)) - leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P - specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) - return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) - in zip(specs_aug, failures_aug) - if s is not None and fail_data is not no_fail] - -# Primitive - -@lu.transformation2 -def _implicit_pvary_on_output(f, out_names_thunk, *args, **kwargs): - out_flat = f(*args, **kwargs) - return [pvary(o, tuple(_names_to_vma(n) - typeof(o).vma)) - for o, n in zip(out_flat, out_names_thunk())] - -JaxType = Any -MaybeTracer = Union[JaxType, Tracer] - -class ShardMapPrimitive(core.Primitive): - multiple_results = True - - def bind(self, *args, **params): - return self._true_bind(*args, **params) - - def bind_with_trace(self, trace, fun_and_args, params): - fun: lu.WrappedFun - fun, *args = fun_and_args - return trace.process_shard_map(shard_map_p, fun, args, **params) - - def get_bind_params(self, params): - new_params = dict(params) - jaxpr: core.Jaxpr = new_params.pop('jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr, - debug_info=jaxpr.debug_info), - jaxpr, ()) - axes = new_params.pop('out_names') - new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes) - return [subfun], new_params - -shard_map_p = ShardMapPrimitive('shard_map') - -# Staging - -@util.cache(max_size=256, trace_context_in_key=True) -def _as_manual_mesh(mesh, auto: frozenset): - manual_axes = tuple(set(mesh.axis_names) - auto) - cur_mesh = get_abstract_mesh() - if cur_mesh.empty: - cur_mesh = mesh - explicit_axes, auto_axes = set(), set() # type: ignore - for a in auto: - if cur_mesh._name_to_type[a] == AxisType.Auto: - auto_axes.add(a) - else: - assert cur_mesh._name_to_type[a] == AxisType.Explicit - explicit_axes.add(a) - - new_axis_types = [] - for n in mesh.axis_names: - if n in manual_axes: - new_axis_types.append(AxisType.Manual) - elif n in auto_axes: - new_axis_types.append(AxisType.Auto) - else: - assert n in explicit_axes - new_axis_types.append(AxisType.Explicit) - return AbstractMesh(mesh.axis_sizes, mesh.axis_names, - axis_types=tuple(new_axis_types)) - - -def _extend_axis_env(mesh, auto): - return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() - if k not in auto]) - -def _shard_map_staging( - trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, - in_tracers: Sequence[Any], *, mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, - auto: frozenset, - ) -> Sequence[pe.DynamicJaxprTracer]: - in_tracers = map(trace.to_jaxpr_tracer, in_tracers) - in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, 
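An example, not from this patch, of the kind of error these message builders produce: an argument axis whose size the mesh does not evenly divide is rejected at call time. Same toy 2-device mesh assumed:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
f = shard_map(lambda a: a, mesh, in_specs=P('x'), out_specs=P('x'))
try:
  f(jnp.ones(7))     # 7 is not divisible by mesh axis 'x' of size 2
except ValueError as e:
  print(e)           # names the offending argument, spec entry, and axis sizes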
mesh, auto, check_rep), in_names, - in_avals) - manual_mesh = _as_manual_mesh(mesh, auto) - with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), - config._check_rep(check_rep)): - jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) - _check_names(out_names_thunk(), out_avals_) - if check_rep: - out_vma = [v.aval.vma for v in jaxpr.outvars] - _check_reps(mesh, auto, out_names_thunk(), out_vma) - out_avals = map(_check_shapedarray, out_avals_) - out_avals = [_check_shapedarray(_unshard_aval(mesh, check_rep, names, aval)) - for names, aval in zip(out_names_thunk(), out_avals)] - source_info = source_info_util.current() - out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] - invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) - outvars = map(trace.makevar, out_tracers) - in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), - config._check_rep(check_rep)): - jaxpr = pe.convert_constvars_jaxpr(jaxpr) - params = dict(mesh=mesh, in_names=in_names_staged, - out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_rep=check_rep, auto=auto) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, - effs, source_info) - trace.frame.add_eqn(eqn) - return out_tracers -pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging - -# TODO add underscore version, for direct-linearize to consume - -def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: - assert isinstance(aval, core.ShapedArray) - return aval - -def _shard_aval(mesh: Mesh, auto, check_rep, names: AxisNames, - aval: core.AbstractValue) -> core.AbstractValue: - if type(aval) in core.shard_aval_handlers: - return core.shard_aval_handlers[type(aval)](mesh, auto, check_rep, names, - aval) - raise NotImplementedError(f"Unsupported aval type: {type(aval)}") - -def _unshard_aval(mesh: Mesh, check_rep, names: AxisNames, - aval: core.AbstractValue) -> core.AbstractValue: - if type(aval) in core.unshard_aval_handlers: - return core.unshard_aval_handlers[type(aval)](mesh, check_rep, names, aval) - else: - raise NotImplementedError(f"Unsupported aval type: {type(aval)}") - -def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_rep, names: AxisNames, - aval: core.AbstractValue) -> core.AbstractValue: - assert isinstance(aval, core.ShapedArray) - new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape)) - manual_mesh = _as_manual_mesh(mesh, auto) - new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - vma = (frozenset({n for ns in names.values() for n in ns}) - if check_rep else frozenset()) - return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) -core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array - -def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, - aval: core.AbstractValue,) -> core.AbstractValue: - assert isinstance(aval, core.ShapedArray) - new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape)) - names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) - if aval.ndim == 0: - out_spec = P() - else: - out_spec = [] # type: ignore - for name_s, aval_s in zip(names_spec, aval.sharding.spec): - if name_s and not aval_s: - out_spec.append(name_s) - elif aval_s and not name_s: - 
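`_shard_shaped_array` divides each named dimension by the product of its mesh axis sizes, and `_unshard_shaped_array` multiplies it back. The effect is observable from inside the body; a sketch with the same 2-device mesh:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))

def body(a):
  assert a.shape == (4, 6)   # global (8, 6), axis 0 sharded over 'x' of size 2
  return a

out = shard_map(body, mesh, in_specs=P('x'), out_specs=P('x'))(jnp.ones((8, 6)))
assert out.shape == (8, 6)   # unsharding restores the global shape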
out_spec.append(aval_s) - elif not name_s and not aval_s: - out_spec.append(None) - else: - assert name_s and aval_s - name_s = name_s if isinstance(name_s, tuple) else (name_s,) - aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,) - out_spec.append(name_s + aval_s) - out_spec = PartitionSpec(*out_spec) - new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else - get_abstract_mesh()) - new_sharding = NamedSharding(new_mesh, out_spec) - manual_axes = set(new_mesh.manual_axes) - vma = (frozenset(v for v in aval.vma if v in manual_axes) - if check_rep else frozenset()) - return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) -core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array - -# Type-checking - -RepType = Any - -def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_rep, auto): - # TODO(mattjj,parkers): check auto - for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): - if not core.typecompat(v.aval, _shard_aval( - mesh, auto, check_rep, in_name, x.aval)): - raise core.JaxprTypeError("shard_map argument avals not compatible with " - "jaxpr binder avals and in_names") - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): - core.check_jaxpr(jaxpr) - if check_rep: - out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] - for rep, dst in zip(out_rep, out_names): - if not _valid_repeats(mesh, auto, rep, dst): - raise core.JaxprTypeError( - "shard_map can't prove output is sufficiently replicated") - out_avals_sharded = [x.aval for x in jaxpr.outvars] - out_avals = map(partial(_unshard_aval, mesh, check_rep), out_names, - out_avals_sharded) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - return out_avals, effs -core.custom_typechecks[shard_map_p] = _shard_map_typecheck - -def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: - return set(mesh.axis_names) - {n for ns in names.values() for n in ns} - -def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: - return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) - - -# Lowering - -def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in -) -> sharding_impls.SdyArraySharding: - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_in, ns) - aval_in = core.physical_aval(aval_in) - sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) - if auto: - for dim_sharding in sdy_sharding.dimension_shardings: - dim_sharding.is_open = True - return sdy_sharding - - -def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep): - in_avals_ = [v.aval for v in jaxpr.invars] - if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): - # Nested `ManualComputationOp`s cannot refer to axes that are already - # manual. So figure out what axes are free thus far. - free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes - shardy_manual_axes = free_axes - auto - else: - shardy_manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) - auto) - sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - - # The order of manual axes should match the order of mesh.axis_names to avoid - # non-determinism issues. 
- manual_axes = [a for a in mesh.axis_names - if a in shardy_manual_axes] - if np.prod([mesh.shape[a] for a in manual_axes]) == 1: - # No need for a `ManualComputationOp` if all manual axes are size 1. - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): - out_nodes, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, - dim_var_values=ctx.dim_var_values) - return out_nodes - - in_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), - in_names, ctx.avals_in)).build() - out_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), - out_names, ctx.avals_out)).build() - output_types = map(mlir.aval_to_ir_type, ctx.avals_out) - manual_computation_op = sdy.ManualComputationOp( - output_types, in_nodes, in_shardings, out_shardings, - sdy.ManualAxesAttr.get( - ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) - block = ir.Block.create_at_start( - manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) - with (ir.InsertionPoint(block), _extend_axis_env(mesh, auto), - config._check_rep(check_rep)): - out_nodes_, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, - dim_var_values=ctx.dim_var_values) - sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) - - return manual_computation_op.results - - -def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_rep, auto): - if config.use_shardy_partitioner.value: - return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep) - - in_avals_ = [v.aval for v in jaxpr.invars] - out_avals_ = [x.aval for x in jaxpr.outvars] - in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, - in_avals_, in_nodes) - manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) - sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): - out_nodes_, tokens_out = mlir.call_lowering( - "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, - out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, - arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), - result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) - ctx.set_tokens_out(tokens_out) - return map(partial(_xla_unshard, ctx, mesh, auto), out_names, out_avals_, - ctx.avals_out, out_nodes_) -mlir.register_lowering(shard_map_p, _shard_map_lowering) - -def _make_scoped_manual_sharding(ctx, mesh, axes): - axis_ctx = ctx.module_context.axis_context - mesh = mesh.abstract_mesh - if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): - mesh = mesh.update_axis_types( - {a: AxisType.Manual for a in axis_ctx.manual_axes}) - return NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore - -def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, - aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: - return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_in, ns) - aval_in = core.physical_aval(aval_in) - shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() - unspecified = set(range(aval_in.ndim)) 
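One way to see what these lowering paths emit is to inspect the lowered module text. With the default partitioner the body is bracketed by SPMDFullToShardShape / SPMDShardToFullShape custom calls; with the Shardy partitioner enabled it becomes a manual computation op. The exact text varies by version and flags; illustrative only:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
f = shard_map(lambda a: 2 * a, mesh, in_specs=P('x'), out_specs=P('x'))
print(jax.jit(f).lower(jnp.ones(8)).as_text())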
if auto else set() - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, - unspecified_dims=unspecified) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) - return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) - -def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, - aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: - return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_out, ns) - aval_out = core.physical_aval(aval_out) - unspecified = set(range(aval_out.ndim)) if auto else set() - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - aval_in = core.physical_aval(aval_in) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) - shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, - unspecified) - -def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: - if isinstance(aval, core.ShapedArray): - return str(map(names.get, range(aval.ndim))) - return '' - -# Eager evaluation - -def get_mesh_from_args(args_flat, mesh): - for a in args_flat: - if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): - if a.sharding.mesh.shape_tuple != mesh.shape_tuple: - aval = core.shaped_abstractify(a) - raise ValueError( - f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" - " match the mesh shape passed to shard_map " - f" {mesh.shape_tuple} for shape {aval.str_short()}") - mesh = a.sharding.mesh - if isinstance(mesh, AbstractMesh): - raise ValueError( - "Please pass `jax.Array`s with a `NamedSharding` as input to" - " `shard_map` when passing `AbstractMesh` to the mesh argument.") - assert isinstance(mesh, Mesh) - return mesh - -def _vma_to_spec(mesh, vma): - return P(tuple(i for i in mesh.axis_names if i in vma)) - -def _names_to_vma(names): - return {n for ns in names.values() for n in ns} - -def _vma_to_rep(mesh, auto, vma): - return frozenset((set(mesh.axis_names) - auto) - vma) - -def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_rep, auto): - if auto: raise NotImplementedError - del prim - if isinstance(mesh, AbstractMesh): - mesh = get_mesh_from_args(args, mesh) - cur_mesh = get_abstract_mesh() - args = map(partial(_unmatch_spec, mesh, check_rep, context_mesh=cur_mesh), - in_names, args) - in_vma = map(_names_to_vma, in_names) - outs, out_vma = _run_shmap(fun, mesh, auto, args, in_vma, check_rep, cur_mesh) - out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] - _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types - if check_rep: - _check_reps(mesh, auto, out_names_thunk(), out_vma) - src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) - else: - src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) - dst_pspecs = map(_names_to_pspec, out_names_thunk()) - return map(partial(_match_spec, mesh, check_rep), src_pspecs, dst_pspecs, - outs) -core.EvalTrace.process_shard_map = _shard_map_impl - -def _run_shmap(f, mesh, auto, args, vmas, check_rep, context_mesh): - trace = ShardMapTrace(mesh, auto, check_rep, context_mesh) - in_tracers = 
map(partial(ShardMapTracer, trace), vmas, args) - manual_mesh = _as_manual_mesh(mesh, auto) - with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), - use_abstract_mesh(manual_mesh), config._check_rep(check_rep)): - ans = f.call_wrapped(*in_tracers) - outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) - return outs, out_vma - -def _names_to_pspec(names: AxisNames) -> PartitionSpec: - ndmin = max(names) + 1 if names else 0 - unpack = lambda t: t[0] if t is not None and len(t) == 1 else t - return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) - -def _unmatch_spec(mesh: Mesh, check_rep, src: AxisNames, x: JaxType, - context_mesh) -> JaxType: - with (core.eval_context(), jax.disable_jit(False), - use_abstract_mesh(context_mesh)): - return jax.jit(HashablePartial(_unmatch, mesh, check_rep, - tuple(src.items())))(x) - -def _unmatch(mesh, check_rep, src_tup, x): - src = _names_to_pspec(dict(src_tup)) - if check_rep: - used_axes = {i for _, ns in src_tup for i in ns} - dst = P(tuple(i for i in mesh.axis_names if i in used_axes)) - else: - dst = P(mesh.axis_names) - check_rep = False - return shard_map(_add_singleton, mesh, (src,), dst, check_rep=check_rep)(x) - -def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] - ) -> None: - fail = [a if n and not max(n) < a.ndim else no_fail - for n, a in zip(names, avals)] - if any(f is not no_fail for f in fail): - raise _SpecError(fail) - -class _SpecError(Exception): - pass - -def _check_reps(mesh, auto, names, vmas): - reps = [_vma_to_rep(mesh, auto, v) for v in vmas] - fail = [r if not _valid_repeats(mesh, auto, r, n) else no_fail - for n, r in zip(names, reps)] - if any(f is not no_fail for f in fail): - raise _RepError(fail) - -class _RepError(Exception): - pass - -def _match_spec(mesh: Mesh, check_rep, src_pspec: PartitionSpec, - dst_pspec: PartitionSpec, x: JaxType) -> JaxType: - fn = HashablePartial(_match, mesh, check_rep, src_pspec, dst_pspec) - with core.eval_context(), jax.disable_jit(False): - return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x) - -def _match(mesh, check_rep, src_pspec, dst_pspec, x): - return shard_map(_rem_singleton, mesh, src_pspec, dst_pspec, - check_rep=check_rep)(x) - -def _rem_singleton(x): return jnp.squeeze(x, axis=0) -def _add_singleton(x): return jnp.expand_dims(x, axis=0) - -def _maybe_check_special(outs): - if not config.debug_nans.value and not config.debug_infs.value: return - bufs = [s.data for leaf in tree_leaves(outs) - for s in getattr(leaf, 'addressable_shards', [])] - try: - dispatch.check_special('shard_map', bufs) - except dispatch.InternalFloatingPointError as e: - raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None - -class ShardMapTrace(core.Trace): - __slots__ = ("mesh", "auto", "check", "context_mesh") - - mesh: Mesh - auto: frozenset[AxisName] - check: bool - context_mesh: AbstractMesh - - def __init__(self, mesh, auto, check, context_mesh): - super().__init__() - self.mesh = mesh - self.auto = auto - self.check = check - self.context_mesh = context_mesh - - def to_val_vma_pair(self, val): - if isinstance(val, ShardMapTracer): - return val.val, val.vma - elif isinstance(val, Tracer): - raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") - else: - val_ = _unmatch_spec(self.mesh, self.check, {}, val, self.context_mesh) - return val_, frozenset() - - def process_primitive(self, prim, tracers, params): - in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) - if 
self.check: - out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) - out_avals = tuple(out_avals) if type(out_avals) is list else out_avals - out_vma = tree_map(lambda a: a.vma, out_avals) - in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) - out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) - else: - out_vma = frozenset() - in_specs = out_specs = P(self.mesh.axis_names) - - eager_rule = eager_rules.get(prim) - if eager_rule: - out_vals = eager_rule(self.mesh, *in_vals, **params) - else: - f = HashablePartial( - _prim_applier, prim, self.check, tuple(params.items()), self.mesh, - in_specs, out_specs) - with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), - jax.debug_infs(False), use_abstract_mesh(self.context_mesh)): - out_vals = jax.jit(f)(*in_vals) - _maybe_check_special(out_vals) - if prim.multiple_results: - out_vma = (out_vma if isinstance(out_vma, (list, tuple)) - else [out_vma] * len(out_vals)) - return map(partial(ShardMapTracer, self), out_vma, out_vals) - return ShardMapTracer(self, out_vma, out_vals) - - def process_call(self, call_primitive, fun, tracers, params): - raise NotImplementedError( - f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " - "yet supported. Put a `jax.jit` around the `shard_map`-decorated " - "function, and open a feature request at " - "https://github.com/jax-ml/jax/issues !") - - def process_map(self, map_primitive, fun, tracers, params): - raise NotImplementedError( - "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." - "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/jax-ml/jax/issues !") - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - # Since ShardMapTrace is only used as a base main, we can drop the jvp. 
- del prim, jvp, symbolic_zeros - in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) - out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, - self.check, self.context_mesh) - return map(partial(ShardMapTracer, self), out_vma, out_vals) - - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): - if symbolic_zeros: - msg = ("custom_vjp symbolic_zeros support with shard_map is not " - "implemented; please open an issue at " - "https://github.com/jax-ml/jax/issues") - raise NotImplementedError(msg) - del prim, fwd, bwd, out_trees, symbolic_zeros - in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) - out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, - self.check, self.context_mesh) - return map(partial(ShardMapTracer, self), out_vma, out_vals) - - -class ShardMapTracer(core.Tracer): - vma: frozenset[AxisName] - val: JaxType - - def __init__(self, trace, vma, val): - self._trace = trace - if isinstance(vma, set): - vma = frozenset(vma) - assert isinstance(vma, frozenset) - self.vma = vma - self.val = val - - @property - def aval(self): - aval = core.get_aval(self.val) - out = core.mapped_aval(self._trace.mesh.size, 0, aval) - new_sharding = NamedSharding( - _as_manual_mesh(self._trace.mesh, self._trace.auto), - out.sharding.spec) # pytype: disable=attribute-error - vma = self.vma if config._check_rep.value else frozenset() - return out.update(sharding=new_sharding, vma=vma) - - def to_concrete_value(self): - if self.vma == frozenset(): - with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): - return core.to_concrete_value(self.val[0]) - else: - return None - - def __str__(self) -> str: - pb_names = set(self._trace.mesh.axis_names) - self.vma - self = pvary(self, tuple(pb_names)) - with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): - blocks = list(self.val) - mesh = self._trace.mesh - axis_names = f"({', '.join(map(str, mesh.axis_names))},)" - return '\n'.join( - f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" - for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) - __repr__ = __str__ # for debuggers, like `p x` - -def _prim_applier(prim, check_rep, params_tup, mesh, in_specs, out_specs, *args): - def apply(*args): - outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) - return tree_map(_add_singleton, outs) - out_specs = list(out_specs) if type(out_specs) is tuple else out_specs - return shard_map(apply, mesh, in_specs, out_specs, check_rep=check_rep)(*args) - -eager_rules: dict[core.Primitive, Callable] = {} - - -# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually -def _debug_callback_eager_rule( - mesh, - *args, - callback: Callable[..., Any], - effect: debugging.DebugEffect, - partitioned: bool, -): - del effect - with core.eval_context(): - all_blocks = zip(*map(list, args)) - for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): - callback(*blocks) - return [] - - -eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule - -def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): - del mesh, srcs, copy_semantics - for device in devices: - if device is not None: - raise ValueError("device_put with explicit device not allowed within " - f"shard_map-decorated functions, but got device {device}") - return xs -eager_rules[dispatch.device_put_p] = _device_put_eager_rule - - -# Batching - -def _shard_map_batch( - trace: 
batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, - in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, - auto: frozenset) -> Sequence[batching.BatchTracer]: - in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) - if any(isinstance(d, batching.RaggedAxis) for d in in_dims): - raise NotImplementedError - new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(in_names, in_dims)] - spmd_axis_name = trace.axis_data.spmd_name - if spmd_axis_name is not None: - used = {n for names in in_names for ns in names.values() for n in ns} - if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: - raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") - new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(new_in_names, in_dims)] - new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) - new_axis_data = batching.AxisData(trace.axis_data.name, new_size, - trace.axis_data.spmd_name, None) - else: - new_axis_data = trace.axis_data - fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) - - new_params = dict(mesh=mesh, in_names=new_in_names, - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - auto=auto) - with core.set_current_trace(trace.parent_trace): - out_vals = prim.bind(fun, *in_vals, **new_params) - make_tracer = partial(batching.BatchTracer, trace, - source_info=source_info_util.current()) - return map(make_tracer, out_vals, out_dims()) -batching.BatchTrace.process_shard_map = _shard_map_batch - -def _batch_out_names(spmd_axis_name, dims, out_names): - out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(out_names, dims)] - if spmd_axis_name is not None: - used = {n for names in out_names for ns in names.values() for n in ns} - if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: - raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") - out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(out_names_, dims)] - return out_names_ - - -# Autodiff - -def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, auto): - primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) - which_nz = [ type(t) is not ad.Zero for t in tangents] - tangents = [t if type(t) is not ad.Zero else None for t in tangents] - args, in_tree = tree_flatten((primals, tangents)) - f_jvp = ad.jvp_subtrace(f, trace.tag) - f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) - tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] - - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - out_ax = out_names_thunk() - return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) - params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - auto=auto) - f_jvp, out_tree = ad.traceable(f_jvp, in_tree) - result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) - primal_out, tangent_out = 
tree_unflatten(out_tree(), result) - tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t - for p, t in zip(primal_out, tangent_out)] - return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] -ad.JVPTrace.process_shard_map = _shard_map_jvp - -def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, - f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, auto): - tracers = map(trace.to_jaxpr_tracer, tracers) - in_pvals = [t.pval for t in tracers] - in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) - unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - in_avals_sharded = map(partial(_shard_aval, mesh, auto, check_rep), - unk_in_names, in_avals) - f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) - f = _promote_scalar_residuals(f) - f_known, aux = pe.partial_eval_wrapper_nounits2( - f, (*in_knowns,), (*in_avals_sharded,)) - all_names = _all_newly_manual_mesh_names(mesh, auto) - - @as_hashable_function(closure=out_names_thunk) - def known_out_names(): - _, _, out_knowns, res_avals, _, _ = aux() - _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) - if check_rep: - res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} - for a in res_avals] - else: - res_names = [{0: all_names}] * len(res_avals) - return (*out_known_names, *res_names) - - known_params = dict(mesh=mesh, in_names=(*known_in_names,), - out_names_thunk=known_out_names, check_rep=check_rep, - auto=auto) - out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), - known_params) - in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux() - num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) - assert not jaxpr.constvars - unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk()) - known_out_names_ = known_out_names() - res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) - # TODO make res_avals be the full set, not just the non-fwd ones - res_avals_iter = iter(res_avals) - res_names = [] - for f1, f2 in zip(in_fwd, out_fwd): - if f1 is not None: - res_names.append(known_in_names[f1]) - elif f2 is not None: - res_names.append(known_out_names_[f2]) - else: - if check_rep: - res_vma = next(res_avals_iter).vma - res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) - else: - res_names.append({0: all_names}) - unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment] - const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.to_jaxpr_tracer, env) - unk_arg_tracers = [t for t in tracers if not t.is_known()] - out_avals_sharded = [v.aval for v in jaxpr.outvars] - unk_params = dict(mesh=mesh, in_names=unk_in_names, - out_names=unk_out_names, jaxpr=jaxpr, - check_rep=check_rep, auto=auto) - out_avals = map(partial(_unshard_aval, mesh, check_rep), unk_out_names, - out_avals_sharded) - out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in out_avals] - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), - out_tracers, shard_map_p, unk_params, - effs, source_info_util.current()) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) -pe.JaxprTrace.process_shard_map = _shard_map_partial_eval - -def _shard_map_linearize(trace, 
shard_map_p, f: lu.WrappedFun, - tracers, mesh, in_names, - out_names_thunk, check_rep, auto): - primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) - nzs_in = tuple(type(t) is not ad.Zero for t in tangents) - f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) - f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - all_names = _all_newly_manual_mesh_names(mesh, auto) - - @as_hashable_function(closure=linearize_outs_thunk) - def fwd_out_names_thunk(): - res_avals, _, _, _, _, _ = linearize_outs_thunk() - out_names = out_names_thunk() - if check_rep: - res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} - for a in res_avals] - else: - res_names = [{0: all_names}] * len(res_avals) - return (*res_names, *out_names) - fwd_params = dict( - mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, auto=auto) - all_fwd_results = shard_map_p.bind_with_trace( - trace.parent_trace, (f_primal, *primals), fwd_params) - res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - non_fwd_res = all_fwd_results[:num_res_out] - primals_out = all_fwd_results[num_res_out:] - residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) - args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None - for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto)), - config._check_rep(check_rep)): - lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) - out_names = out_names_thunk() - res_avals_iter = iter(res_avals) - res_names = [] - for f1, f2 in zip(in_fwd, out_fwd): - if f1 is not None: - res_names.append(in_names[f1]) - elif f2 is not None: - res_names.append(out_names[f2]) - else: - if check_rep: - res_vma = next(res_avals_iter).vma - res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) - else: - res_names.append({0: all_names}) - new_in_names = (*res_names, *({} for _ in range(len(env))), - *(ax for ax, nz in zip(in_names, nzs_in) if nz)) - tangent_out_names = tuple(ax for ax, nz in zip(out_names_thunk(), nzs_out) if nz) - @as_hashable_function(closure=tangent_out_names) - def tangent_out_names_thunk(): - return tangent_out_names - tangent_params = dict( - mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_rep=check_rep, auto=auto) - - # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here - def f_tangent(*args): - return core.eval_jaxpr(lin_jaxpr, (), *args) - - nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] - nz_tangents_out = shard_map_p.bind_with_trace( - trace.tangent_trace, - (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), - *residuals, *env, *nz_tangents_in), tangent_params) - nz_tangents_out_iter = iter(nz_tangents_out) - tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) - for nz, primal in zip(nzs_out, primals_out)] - return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out) -ad.LinearizeTrace.process_shard_map = _shard_map_linearize - -@lu.transformation2 -def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): - ans = f(*args, **kwargs) - _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in 
zip(in_fwd, out_fwd)) - residuals = ans[:num_res_out] - primals = ans[num_res_out:] - residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in residuals] - return *residuals, *primals - -@lu.transformation2 -def _promote_scalar_residuals(f: Callable, *args, **kwargs): - jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) - which = [f1 is None and f2 is None and not v.aval.shape - for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] - jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) - out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in out_consts] - return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) - -def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]): - def fun(*res_and_args): - res, args = split_list(res_and_args, [len(jaxpr.constvars)]) - res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] - return core.eval_jaxpr(jaxpr, res, *args) - res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval - for v, w in zip(jaxpr.constvars, which)] - in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals) - return jaxpr - - -def _unmentioned2(mesh: Mesh, names: AxisNames, - auto: frozenset[AxisName]) -> list[AxisName]: - # We use a filtered-down version of unmentioned to avoid defensive-psum over - # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} | auto - return [n for n in _all_mesh_names_except_spmd(mesh, auto) - if n not in name_set] - - -def _shard_map_transpose(out_cts, *args, - jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_rep, auto): - mb_div = lambda x, y: x / y if y != 1 else x - out_cts = [ - ad.Zero(_shard_aval(mesh, auto, check_rep, ns, x.aval)) - if type(x) is ad.Zero else x if check_rep or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) - for ns, x in zip(out_names, out_cts) - ] - args = tuple(x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, auto, check_rep, ns, x.aval)) - for ns, x in zip(in_names, args)) - all_args, in_tree = tree_flatten((out_cts, args)) - - def fun_trans_callable(out_cts, args): - # TODO(mattjj): when #26811 lands, delete this and just run backward_pass - in_undef = map(ad.is_undefined_primal, args) - res, undefs = partition_list(in_undef, args) - jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( - pe.close_jaxpr(jaxpr), in_undef, False) - res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res) - in_cts = ad.backward_pass( - jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts - )[len(res_reshaped):] - _, in_ct_names = partition_list(in_undef, in_names) - in_cts = [ad.Zero(_unshard_aval(mesh, check_rep, ns, x.aval)) - if type(x) is ad.Zero else x if check_rep - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) - for ns, x in zip(in_ct_names, in_cts)] - res_zeros = [ad_util.zero_from_primal(r) for r in res] - return merge_lists(in_undef, res_zeros, in_cts) - - fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info) - fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) - fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) - - new_in_names = \ - [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \ - [n for n, x in zip(in_names, args) if type(x) 
is not ad.UndefinedPrimal] - - def new_out_names_thunk(): - return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) - - try: - out_flat = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - auto=auto) - except (FloatingPointError, ZeroDivisionError) as e: - print("Invalid nan value encountered in the backward pass of a shard_map " - "function. Calling the de-optimized backward pass.") - try: - # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)` - # in eager mode so that output of shmap are not manual. - with jax.disable_jit(True): - _ = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - auto=auto) - except (FloatingPointError, ZeroDivisionError) as e2: - raise e2 from None - else: - dispatch._raise_no_nan_in_deoptimized(e) - return tree_unflatten(out_tree(), out_flat) -ad.primitive_transposes[shard_map_p] = _shard_map_transpose - -# Remat - -def _partial_eval_jaxpr_custom_rule( - saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool], - inst_in: Sequence[bool], eqn: core.JaxprEqn -) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], - list[core.Var]]: - jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] - check_rep, auto = eqn.params['check_rep'], eqn.params['auto'] - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): - jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ - pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) - num_out_primals = len(jaxpr_known.outvars) - num_res - in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] - out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) - idx_map = {id(v): i for i, v in enumerate(out_vars)} - out_fwd = [idx_map.get(id(v)) for v in res_vars] - which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] - mesh = eqn.params['mesh'] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto)), - config._check_rep(check_rep)): - jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) - jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) - jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) - jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names) - ins_known, _ = partition_list(unks_in, eqn.invars) - out_binders_known, _ = partition_list(unks_out, eqn.outvars) - _, ins_staged = partition_list(inst_in, eqn.invars) - _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() - residuals, staged_in_res_names = [], [] - for var, w in zip(jaxpr_staged.invars[:num_res], which): - if w: - rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore - if check_rep else {0: _all_newly_manual_mesh_names(mesh, auto)}) - residuals.append(newvar(_unshard_aval(mesh, check_rep, rn, var.aval))) - staged_in_res_names.append(rn) - if check_rep: - out_res_names_known = [ - {0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} - for var, o in zip(res_vars, out_fwd) if o is None - ] - else: - out_res_names_known = [{0: _all_newly_manual_mesh_names(mesh, auto)}] * sum(which) - params_known, params_staged = _pe_custom_params( - unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, - out_res_names_known, staged_in_res_names, - 
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) - eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], - eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info, eqn.ctx) - full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) - eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, - eqn.primitive, params_staged, - jaxpr_staged.effects, eqn.source_info, eqn.ctx) - assert len(eqn_staged.invars) == len(jaxpr_staged.invars) - new_inst = [x for x, inst in zip(eqn.invars, inst_in) - if type(x) is core.Var and not inst] - new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] - return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals -pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ - _partial_eval_jaxpr_custom_rule - -def _add_reshapes(which: Sequence[bool], - jaxpr_known: core.Jaxpr, - jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: - # add singleton axes to residuals which are from jaxpr_known and are scalars - which_ = [w and not v.aval.shape # pytype: disable=attribute-error - for w, v in zip(which, jaxpr_staged.invars[:len(which)])] - if not any(which_): return jaxpr_known, jaxpr_staged - assert not jaxpr_known.constvars and not jaxpr_staged.constvars - - def known(*args): - out = core.eval_jaxpr(jaxpr_known, (), *args) - out_known, res = split_list(out, [len(out) - sum(which)]) - res = [_add_singleton(x) if not x.shape else x for x in res] - return [*out_known, *res] - avals_in = [v.aval for v in jaxpr_known.invars] - jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in) - - def staged(*args): - res_, ins = split_list(args, [len(which)]) - res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)] - return core.eval_jaxpr(jaxpr_staged, (), *res, *ins) - res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval - for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] - avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]] - jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in) - - return jaxpr_known, jaxpr_staged - -def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, - in_fwd, out_fwd, out_res_names_known, staged_in_res_names, - params_known, params_staged): - # prune inputs to jaxpr_known according to unks_in - in_names_known, _ = partition_list(unks_in, params_known['in_names']) - _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + out_res_names_known - assert len(out_names_known) == len(params_known['jaxpr'].outvars) - new_params_known = dict(params_known, in_names=tuple(in_names_known), - out_names=tuple(out_names_known)) - - # added num_res new inputs to jaxpr_staged, pruning according to inst_in - _, in_names_staged = partition_list(inst_in, params_staged['in_names']) - iter_staged = iter(staged_in_res_names) - res_names = [in_names_known[f1] if f1 is not None else - out_names_known[f2] if f2 is not None else - next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)] - - in_names_staged = res_names + in_names_staged - _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) - new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), - out_names=tuple(out_names_staged)) - return new_params_known, new_params_staged - -# 
TODO(mattjj): remove this mechanism when we revise mesh scopes -def _all_mesh_names_except_spmd( - mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - spmd_names = axis_env.spmd_axis_names - return tuple(name for name in mesh.axis_names if name not in spmd_names and - name not in auto) - -def _all_newly_manual_mesh_names( - mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - vmap_spmd_names = set(axis_env.spmd_axis_names) - if not (ctx_mesh := get_abstract_mesh()).empty: - mesh = ctx_mesh - already_manual_names = set(ctx_mesh.manual_axes) - else: - # TODO(mattjj): remove this mechanism when we revise mesh scopes - already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names - return tuple(name for name in mesh.axis_names - if name not in auto | vmap_spmd_names | already_manual_names) - - -# DCE - -# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? -def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], core.JaxprEqn | None]: - if not any(used_outputs) and not pe.has_effects(eqn): - return [False] * len(eqn.invars), None - mesh = eqn.params["mesh"] - auto = eqn.params["auto"] - check_rep = eqn.params["check_rep"] - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): - jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) - if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: - return used_inputs, None - else: - _, in_names = partition_list(used_inputs, eqn.params['in_names']) - _, out_names = partition_list(used_outputs, eqn.params['out_names']) - new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names), - out_names=tuple(out_names)) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - new_eqn = pe.new_jaxpr_eqn( - [v for v, used in zip(eqn.invars, used_inputs) if used], - [x for x, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) - return used_inputs, new_eqn -pe.dce_rules[shard_map_p] = _shard_map_dce - -# Implementing pmap in terms of shard_map - -def pmap(f, axis_name=None, *, in_axes=0, out_axes=0, - static_broadcasted_argnums=(), devices=None, backend=None, - axis_size=None, donate_argnums=(), global_arg_shapes=None): - devices = tuple(devices) if devices is not None else devices - axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( - f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes) - - def infer_params(*args, **kwargs): - p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, devices, backend, axis_size, args, kwargs) - for arg in p.flat_args: - dispatch.check_arg(arg) - mesh = Mesh(_get_devices(p, backend), (axis_name,)) - _pmapped, in_specs, out_specs = _cached_shard_map( - p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name) - flat_global_args = host_local_array_to_global_array( - p.flat_args, mesh, list(in_specs)) - jitted_f = jax.jit( - _pmapped, - donate_argnums=(i for i, val in enumerate(p.donated_invars) if val)) - return jitted_f, flat_global_args, p.out_tree, mesh, out_specs - - def wrapped(*args, **kwargs): - (jitted_f, flat_global_args, out_tree, mesh, - out_specs) = infer_params(*args, **kwargs) - outs = jitted_f(*flat_global_args) - outs = global_array_to_host_local_array(outs, mesh, out_specs()) - return tree_unflatten(out_tree(), outs) - - def lower(*args, **kwargs): - 
jitted_f, _, _, _, _ = infer_params(*args, **kwargs) - return jitted_f.lower(*args, **kwargs) - wrapped.lower = lower - - return wrapped - - -@lu.cache -def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): - in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat)) - out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk()) - fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk) - return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs, - out_specs=out_specs, check_vma=False, - axis_names=set(mesh.axis_names)), - in_specs, out_specs) - -@lu.transformation2 -def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): - args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), - list(args), list(in_axes)) - out = f(*args) - return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), - list(out), list(out_axes_thunk())) - -def _axis_to_spec(axis_name, ax): - if isinstance(ax, int): - specs = [None] * ax + [axis_name] - return P(*specs) - elif ax is None: - return P() - else: - raise TypeError(ax) - -def _get_devices(p, backend): - if backend is not None and p.devices is None: - devs = jax.devices(backend=backend) - else: - devs = jax.devices() if p.devices is None else p.devices - if jax.process_count() > 1: - return devs[:p.global_axis_size] - return devs[:p.local_axis_size] + return jshmap.shard_map( + f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, + check_vma=check_rep, axis_names=axis_names) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 2c2d84ca03ef..3b63be231d90 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -48,7 +48,7 @@ import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning -from jax.experimental.shard_map import shard_map, shard_map2 +from jax.experimental.shard_map import shard_map config.parse_flags_with_absl() @@ -2007,7 +2007,7 @@ def g(x): @jax.jit def f(x): - x = shard_map2(g, out_specs=P('i', None), axis_names=frozenset({'i'}))(x) + x = jax.shard_map(g, out_specs=P('i', None), axis_names=frozenset({'i'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x @@ -2049,7 +2049,7 @@ def g(x): @jax.jit def f(x): - x = shard_map2(g, out_specs=P('i', 'j', None, None), + x = jax.shard_map(g, out_specs=P('i', 'j', None, None), axis_names=frozenset({'i', 'j'}))(x) self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None)) return x @@ -2930,11 +2930,11 @@ def body(carry, _): g(x, y, z) # doesn't crash @jtu.with_explicit_mesh((2, 2), ('x', 'y')) - def test_shmap2_full_manual_context_explicit(self, mesh): + def test_shmap_full_manual_context_explicit(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, P('x', 'y')) - @partial(shard_map2, out_specs=P('x', 'y')) + @partial(jax.shard_map, out_specs=P('x', 'y')) def f(x): self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) self.assertEqual(x.aval.vma, {'x', 'y'}) @@ -2948,11 +2948,11 @@ def f(x): jax.jit(f)(arr) # doesn't crash @jtu.with_explicit_mesh((2, 2), ('x', 'y')) - def test_shmap2_partial_manual_explicit(self, mesh): + def test_shmap_partial_manual_explicit(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, P('x', 'y')) - @partial(shard_map2, axis_names=frozenset('x'), out_specs=P('x')) + @partial(jax.shard_map, axis_names=frozenset('x'), out_specs=P('x')) def f(x): self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) 
      self.assertEqual(get_abstract_mesh().explicit_axes, ('y',))
@@ -2968,11 +2968,11 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
   @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2)
-  def test_shmap2_full_manual_context_auto(self, mesh):
+  def test_shmap_full_manual_context_auto(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     arr = jax.device_put(np_inp, P('x', 'y'))
 
-    @partial(shard_map2, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
+    @partial(jax.shard_map, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
     def f(x):
       self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y'))
       self.assertEqual(x.aval.vma, {'x', 'y'})
@@ -2986,11 +2986,11 @@ def f(x):
     jax.jit(f)(arr)  # doesn't crash
 
   @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2)
-  def test_shmap2_partial_manual_auto(self, mesh):
+  def test_shmap_partial_manual_auto(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     arr = jax.device_put(np_inp, P('x', 'y'))
 
-    @partial(shard_map2, axis_names=frozenset('x'), in_specs=P('x'),
+    @partial(jax.shard_map, axis_names=frozenset('x'), in_specs=P('x'),
              out_specs=P('x'))
     def f(x):
       self.assertEqual(get_abstract_mesh().manual_axes, ('x',))
@@ -3005,7 +3005,7 @@ def f(x):
 
   def test_no_mesh_context_error(self):
     with self.assertRaisesRegex(ValueError, "The context mesh cannot be empty"):
-      shard_map2(lambda x: x, in_specs=P(), out_specs=P())(np.arange(8))
+      jax.shard_map(lambda x: x, in_specs=P(), out_specs=P())(np.arange(8))
 
 
 class FunSpec(NamedTuple):

From f1803bef692b6d91815b02ec643980d9ab9b5132 Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Tue, 22 Apr 2025 20:03:32 -0700
Subject: [PATCH 0753/1769] [Pallas] Make Pallas PRNG more robust by improving
 the Mosaic seed broadcast in lower_to_llo.

- To make fold_in non-trivial, in Pallas the key is now represented as a
  (1, 2)-shaped key.
- Two new primitives were added for wrapping/unwrapping the key to and from
  scalars. This is needed because JAX's wrap/unwrap convert to and from
  vectors, whereas in Pallas we need to return a list of scalars.

PiperOrigin-RevId: 750422791
---
 jax/_src/pallas/mosaic/lowering.py     | 64 ++++++++++++--------------
 jax/_src/pallas/mosaic/primitives.py   | 50 +++++++++++++++++++-
 jax/_src/pallas/mosaic/random.py       | 60 +++++++++---------------
 tests/pallas/tpu_pallas_random_test.py |  7 ++-
 4 files changed, 105 insertions(+), 76 deletions(-)

diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 3fba962ba0a9..096aca115072 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -1535,13 +1535,12 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree
   ref_block_shape = aval_out.dtype._impl.key_shape
   if len(ref_block_shape) != 2:
-    raise NotImplementedError("Seed key_data must be 2D.")
-  if tuple(ref_block_shape) != (1, 1):
-    raise NotImplementedError(
-        f"Seed key_data of shape != (1, 1) not supported. Got: {ref_block_shape}")
+    raise NotImplementedError("Seed key_data must be 1D.")
+  if ref_block_shape[0] != 1:
+    raise NotImplementedError("Leading dimension of seed key_data must be 1.")
 
   load_ops = []
-  for i in range(ref_block_shape[0]):
+  for i in range(ref_block_shape[1]):
     idx = NDIndexer(indices=(0, i), shape=ref_block_shape,
                     int_indexer_shape=tuple())
     starts, _, _, _, _ = _indexer_to_start_size_stride(
@@ -3819,18 +3818,10 @@ def random_unwrap_lowering(ctx, key):
   impl = keys_aval.dtype._impl
   if not pl_random.is_pallas_impl(impl):
     return key
-  assert isinstance(key, KeyScalarBundle)
-  # Convert to a vector.
-  if tuple(key.key_shape) != (1, 1):
-    raise NotImplementedError(
-        "Seed key_data of shape != (1, 1) not supported. "
-        f"Got: {key.key_shape}")
-  scalar = key.scalars[0]
-  out_type = ir.VectorType.get(
-      key.key_shape, _dtype_to_ir_type(jnp.dtype('int32'))
+  raise ValueError(
+      "key_data not supported for Pallas PRNG keys. Use"
+      " unwrap_pallas_seed instead."
   )
-  val = vector.broadcast(out_type, scalar)
-  return val
 lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering
@@ -3838,27 +3829,32 @@ def random_wrap_lowering(ctx, key_data, *, impl):
   del ctx
   if not pl_random.is_pallas_impl(impl):
     return key_data
-  if isinstance(key_data.type, ir.VectorType):
-    # If the key data lives in vregs, need to unpack it to sregs.
-    key_data_list = []
-    key_data_shape = key_data.type.shape
-    if len(key_data_shape) != 2:
-      raise NotImplementedError("Seed key_data must be 2D.")
-    if tuple(key_data_shape) != (1, 1):
-      raise NotImplementedError(
-          "Seed key_data of shape != (1, 1) not supported. "
-          f"Got: {key_data_shape}")
-    for i in range(key_data_shape[1]):
-      key_data_list.append(vector.ExtractOp(key_data, [], [0, i]))
-    return KeyScalarBundle(
-        scalars=key_data_list, key_shape=tuple(key_data_shape))
-  if isinstance(key_data, KeyScalarBundle):
-    return key_data
-  else:
-    raise NotImplementedError(f"key_data wrap {type(key_data)}")
+  raise ValueError(
+      "wrap_key_data not supported for Pallas PRNG keys. Use"
+      " wrap_pallas_seed instead."
+  )
 lowering_rules[prng.random_wrap_p] = random_wrap_lowering
+
+def _split_key_lowering_rule(
+    ctx: LoweringRuleContext, key_data: KeyScalarBundle
+):
+  return key_data.scalars
+
+
+lowering_rules[tpu_primitives.split_key_p] = _split_key_lowering_rule
+
+
+def _join_key_lowering_rule(ctx: LoweringRuleContext, *scalars, impl):
+  if not pl_random.is_pallas_impl(impl):
+    raise ValueError(f"Can only join Pallas keys. Got impl={impl}")
+  return KeyScalarBundle(scalars=scalars, key_shape=impl.key_shape)
+
+
+lowering_rules[tpu_primitives.join_key_p] = _join_key_lowering_rule
+
+
 def _checkify_lowering_rule(
     ctx: LoweringRuleContext, *err_args, err_tree, debug):
   if not tpu_core.runtime_assert_enabled():
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index 59856c0ca7b2..33a1de12ebde 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -22,6 +22,8 @@ from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import pretty_printer as pp
+from jax._src import prng as jax_prng
+from jax._src import random as jax_random
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
@@ -685,8 +687,9 @@ def delay(nanos):
 
 prng_seed_p = jax_core.Primitive("prng_seed")
 prng_seed_p.multiple_results = True
 
+
 @prng_seed_p.def_abstract_eval
-def _(*_):
+def _prng_seed_abstract_eval(*_):
   return []
 
 
@@ -703,9 +706,52 @@ def prng_seed(*seeds: int | jax.Array) -> None:
 
 prng_random_bits_p = jax_core.Primitive(
     'prng_random_bits')
 
+
 @prng_random_bits_p.def_abstract_eval
-def _(*, shape):
+def _prng_random_bits_abstract_eval(*, shape):
   return jax_core.ShapedArray(shape, jnp.dtype("int32"))
 
+
 def prng_random_bits(shape):
   return prng_random_bits_p.bind(shape=shape)
+
+
+# PRNG wrap/unwrap ops.
+# We cannot use JAX's key_data and wrap_key_data because they return
+# vectors, and Pallas keys are represented as lists of scalars.
+
+split_key_p = jax_core.Primitive("prng_split")
+split_key_p.multiple_results = True
+
+
+@split_key_p.def_abstract_eval
+def _split_key_scalar_abstract_eval(seed):
+  key_shape = seed.dtype._impl.key_shape
+  if len(key_shape) != 2 or key_shape[0] != 1:
+    raise ValueError(f"Key shape must be (1, N), got {key_shape}")
+  return [jax_core.ShapedArray((), jnp.dtype("uint32"))] * key_shape[1]
+
+
+def unwrap_pallas_seed(seed):
+  """Splits a PRNG key into its scalar components."""
+  return split_key_p.bind(seed)
+
+
+join_key_p = jax_core.Primitive("prng_join")
+
+
+@join_key_p.def_abstract_eval
+def _join_key_scalar_abstract_eval(*seeds, impl):
+  if len(impl.key_shape) != 2 or impl.key_shape[0] != 1:
+    raise ValueError(f"Key shape must be (1, N), got {impl.key_shape}")
+  if len(seeds) != impl.key_shape[1]:
+    raise ValueError(
+        f"Number of seeds must match key shape, got {len(seeds)}"
+        f" != {impl.key_shape[1]}."
+    )
+  return jax_core.ShapedArray((), dtype=jax_prng.KeyTy(impl))
+
+
+def wrap_pallas_seed(*seeds, impl):
+  """Joins scalars into a single PRNG key."""
+  impl = jax_random.resolve_prng_impl(impl)
+  return join_key_p.bind(*seeds, impl=impl)
diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py
index fd8dcc720f07..6a2c557fd55d 100644
--- a/jax/_src/pallas/mosaic/random.py
+++ b/jax/_src/pallas/mosaic/random.py
@@ -13,18 +13,18 @@
 #  limitations under the License.
from collections.abc import Callable - import functools import jax from jax import numpy as jnp from jax import random as jax_api_random from jax._src import blocked_sampler from jax._src import dtypes +from jax._src import prng as jax_prng from jax._src import typing -from jax._src.pallas.mosaic.primitives import prng_seed -from jax._src.pallas.mosaic.primitives import prng_random_bits from jax._src.pallas import primitives -from jax._src import prng as jax_prng +from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic.primitives import prng_random_bits +from jax._src.pallas.mosaic.primitives import prng_seed Shape = jax_prng.Shape @@ -32,8 +32,8 @@ KeylessSampleFnType = Callable[..., jax.Array] set_seed = prng_seed - -FOLD_IN_ROUNDS = 128 +unwrap_pallas_seed = tpu_primitives.unwrap_pallas_seed +wrap_pallas_seed = tpu_primitives.wrap_pallas_seed def to_pallas_key(key: jax.Array) -> jax.Array: @@ -63,7 +63,7 @@ def is_pallas_impl(impl: jax_prng.PRNGImpl) -> bool: def _seed_func(seed: jnp.int32): seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) - return (seed_data + seed).astype(jnp.uint32) + return (seed_data + seed).astype(jnp.uint32) # Broadcast the seed. def _random_bits(key: typing.Array, bit_width: int, shape: Shape): if bit_width != 32: @@ -72,42 +72,26 @@ def _random_bits(key: typing.Array, bit_width: int, shape: Shape): return prng_random_bits(shape) def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): - # Roughly, we compute the new key as follows: - # new_key = random_bits(data)[..., 127] ^ random_bits(old_key)[..., 127] - # Because the TPU generates random numbers in (8, 128) blocks at once, we - # can generate that many values without additional cost which will reduce - # correlation between the old and new keys. - - # TODO(justinfu): The underlying TPU hardware PRNG doesn't produce robust - # random bits when applied in rounds such as below (measured via crush). - # We should consider a different strategy for generating keys. - key_shape = tpu_key_impl.key_shape - - prng_seed(data) - data_bits = prng_random_bits( - key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) - prng_seed(key) - key_bits = prng_random_bits( - key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) - - mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1] - assert mixed.shape == key_shape - return jax.random.wrap_key_data(mixed, impl="pallas_tpu") + key0, key1 = unwrap_pallas_seed(key) + # Perform a cheap mixing of data into the key. + key1 = key1 + data + [key0, key1] = jax_prng.apply_round([key0, key1], 13) + return wrap_pallas_seed(key0, key1, impl="pallas_tpu") def _split(key: typing.Array, shape: Shape): del key, shape - raise NotImplementedError() + raise NotImplementedError( + "Cannot split a Pallas key. Use fold_in instead to generate new keys." + ) tpu_key_impl = jax_prng.PRNGImpl( - # Pallas currently only supports 2D+ windows, so set the key_shape - # to be 2D to have better compatibility with setting BlockSpecs. 
- key_shape=(1, 1), - seed=_seed_func, - split=_split, - random_bits=_random_bits, - fold_in=_fold_in, - name="pallas_tpu", - tag="pl" + key_shape=(1, 2), + seed=_seed_func, + split=_split, + random_bits=_random_bits, + fold_in=_fold_in, + name="pallas_tpu", + tag="pl", ) jax_prng.register_prng(tpu_key_impl) diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index feeaa3cfceb9..74697e6c0b7f 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -143,7 +143,9 @@ def body(key_ref, o_ref): def test_key_data(self): def body(key_ref, o_ref): - o_ref[...] = jax.random.key_data(key_ref[...]) + x0, x1 = plrandom.unwrap_pallas_seed(key_ref[...]) + o_ref[0, 0] = x0 + o_ref[0, 1] = x1 rbg_key = jax_random.key(0, impl="rbg") key = plrandom.to_pallas_key(rbg_key) expected_key_data = jax.random.key_data(key) @@ -152,9 +154,10 @@ def body(key_ref, o_ref): result = pl.pallas_call( body, in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), out_shape=o_shape, )(key) - self.assertEqual(result, expected_key_data) + self.assertArraysEqual(result, expected_key_data) def test_fold_in(self): # Test that folding in a value results in different random numbers. From 922d9e9d42b52f443a0999908aa37ea54498b05a Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 23 Apr 2025 11:03:37 +0530 Subject: [PATCH 0754/1769] Replace reference to jax.readthedocs.io with docs.jax.dev in jax._src.shard_map.py --- jax/_src/shard_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index df79af0ba8f8..469087dc19f3 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -77,7 +77,7 @@ def shard_map(f, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True): """Map a function over shards of data using a mesh of devices. - See the docs at https://jax.readthedocs.io/en/latest/notebooks/shard_map.html. + See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html. Args: f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, From 30be6f3b555488ce5053f8795809ad16898c91fc Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 22 Apr 2025 23:06:17 -0700 Subject: [PATCH 0755/1769] Make `set_mesh` thread local so that it behaves exactly like `use_mesh` Co-authored-by: Matthew Johnson PiperOrigin-RevId: 750462521 --- jax/_src/sharding_impls.py | 11 ++++++----- tests/pjit_test.py | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2b79e09c49e1..faa0d31ee8ca 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -45,6 +45,7 @@ from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method import numpy as np +config_ext = xc._xla.config Shape = tuple[int, ...] Device = xc.Device @@ -1398,17 +1399,17 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. 
Got {type(mesh)}") + assert mesh is None or isinstance(mesh, mesh_lib.Mesh) if not core.trace_state_clean(): raise ValueError('`set_mesh` can only be used outside of `jax.jit`.') if mesh is None: - config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore + config.abstract_mesh_context_manager.set_local(mesh_lib.empty_abstract_mesh) # type: ignore else: - config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore + config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh) # type: ignore - prev_mesh = config.device_context.get_global() - config.device_context.set_global(mesh) - return prev_mesh + prev_mesh = config.device_context.swap_local(mesh) + return None if prev_mesh is config_ext.unset else prev_mesh @contextlib.contextmanager def use_concrete_mesh(mesh: mesh_lib.Mesh | None): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 4d38d612e470..3f10f7839a24 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7282,6 +7282,7 @@ def test_set_mesh(self): out = reshard(np.arange(8), P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) finally: + self.assertIsNone(prev_mesh) jax.sharding.set_mesh(prev_mesh) @jtu.with_explicit_mesh((2,), ('x',)) From ff20a62857246d61081fe04a95414a870b5debaa Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 23 Apr 2025 01:16:08 -0700 Subject: [PATCH 0756/1769] [pallas:mosaic_gpu] Added an export backward compatibility test for a trivial kernel PiperOrigin-RevId: 750494555 --- jax/_src/export/_export.py | 1 + .../pallas/mosaic_gpu_add_one.py | 88 +++++++++++++++++++ tests/export_back_compat_test.py | 1 + tests/pallas/BUILD | 1 + .../pallas/export_back_compat_pallas_test.py | 24 ++++- tests/pallas/mosaic_gpu_test.py | 7 +- 6 files changed, 113 insertions(+), 9 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index dc298f935d9e..cbd92f86b835 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1118,6 +1118,7 @@ def _check_lowering(lowering) -> None: "ApproxTopK", "stablehlo.dynamic_approx_top_k", "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) "tpu_custom_call", # Pallas/TPU kernels + "mosaic_gpu", # Pallas Mosaic GPU kernels # TODO(burmako): maintain backwards compatibility for these, until they # are upstreamed to StableHLO. # See https://github.com/openxla/stablehlo/issues/8. diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py new file mode 100644 index 000000000000..fd66c35de7c9 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py @@ -0,0 +1,88 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime + +from numpy import array, float32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_22 = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['mosaic_gpu'], + serialized_date=datetime.date(2025, 4, 22), + inputs=(array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., + 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., + 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., + 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., + 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., + 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., + 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., + 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., + 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., + 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., + 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., + 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., + 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., + 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., + 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., + 176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., + 187., 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., + 198., 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., + 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., + 220., 221., 222., 223., 224., 225., 226., 227., 228., 229., 230., + 231., 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., + 242., 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., + 253., 254., 255.], dtype=float32),), + expected_outputs=(array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., + 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., + 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., + 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., + 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., + 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., + 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., + 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121., + 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132., + 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143., + 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., 154., + 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165., + 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., 176., + 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187., + 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., 198., + 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., 209., + 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., 220., + 221., 222., 223., 224., 225., 226., 227., 228., 229., 230., 231., + 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., 242., + 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., 253., + 254., 255., 256.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("args[0]") +module @jit_wrapped attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<256xf32> 
loc("args[0]")) -> (tensor<256xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @mosaic_gpu(%arg0) {api_version = 2 : i32, backend_config = "\A9C\FB\81\9A1\C2?\0E\F4\E1\E4\E77\03\B6\97\E5G(]WR\98\EB{\BA\8A\84\01\12'#loc = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:83:4)\0A#loc1 = loc(\22-\22:94:40)\0A#loc2 = loc(\22-\22:94:47)\0A#loc3 = loc(\22-\22:94:54)\0A#loc4 = loc(\22-\22:94:116)\0A#loc5 = loc(\22-\22:94:123)\0A#loc6 = loc(\22-\22:94:130)\0A#loc7 = loc(\22-\22:94:65)\0A#loc8 = loc(\22-\22:94:78)\0A#loc9 = loc(\22-\22:94:91)\0A#loc10 = loc(\22-\22:94:141)\0A#loc11 = loc(\22-\22:94:157)\0A#loc12 = loc(\22-\22:94:174)\0A#loc17 = loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc))\0A\22builtin.module\22() <{sym_name = \22add_one\22}> ({\0A \22stable_mosaic_gpu.func.func\22() ({\0A }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = \22mosaic_gpu_init_tma_desc\22, sym_visibility = \22private\22} : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.llvm.mlir.global\22() ({\0A }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = \22global_scratch\22, unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.func\22() ({\0A ^bb0(%arg0: !llvm.ptr loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc)), %arg1: !llvm.ptr loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc))):\0A %0 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc17)\0A %1 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %2 = \22stable_mosaic_gpu.llvm.load\22(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %3 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %4 = \22stable_mosaic_gpu.llvm.insertvalue\22(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %5 = \22stable_mosaic_gpu.llvm.insertvalue\22(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %6 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc17)\0A %7 = \22stable_mosaic_gpu.llvm.insertvalue\22(%5, %6) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %8 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %9 = \22stable_mosaic_gpu.llvm.insertvalue\22(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %10 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %11 = \22stable_mosaic_gpu.llvm.insertvalue\22(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %12 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> 
memref<256xf32> loc(#loc17)\0A %13 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %14 = \22stable_mosaic_gpu.llvm.load\22(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %15 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %16 = \22stable_mosaic_gpu.llvm.insertvalue\22(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %17 = \22stable_mosaic_gpu.llvm.insertvalue\22(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %18 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc17)\0A %19 = \22stable_mosaic_gpu.llvm.insertvalue\22(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %20 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %21 = \22stable_mosaic_gpu.llvm.insertvalue\22(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %22 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %23 = \22stable_mosaic_gpu.llvm.insertvalue\22(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %24 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\0A %25 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %26 = \22stable_mosaic_gpu.llvm.alloca\22(%25) {alignment = 64 : i64, elem_type = !llvm.array<256 x i8>} : (i64) -> !llvm.ptr loc(#loc17)\0A %27 = \22stable_mosaic_gpu.llvm.getelementptr\22(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %28:4 = \22stable_mosaic_gpu.memref.extract_strided_metadata\22(%12) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\0A %29 = \22stable_mosaic_gpu.memref.extract_aligned_pointer_as_index\22(%12) : (memref<256xf32>) -> index loc(#loc17)\0A %30 = \22stable_mosaic_gpu.arith.index_cast\22(%29) : (index) -> i64 loc(#loc17)\0A %31 = \22stable_mosaic_gpu.llvm.inttoptr\22(%30) : (i64) -> !llvm.ptr loc(#loc17)\0A %32 = \22stable_mosaic_gpu.arith.index_cast\22(%28#1) : (index) -> i64 loc(#loc17)\0A %33 = \22stable_mosaic_gpu.llvm.getelementptr\22(%31, %32) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\0A %34 = \22stable_mosaic_gpu.arith.constant\22() {value = 6 : i64} : () -> i64 loc(#loc17)\0A %35 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %36 = \22stable_mosaic_gpu.arith.index_cast\22(%28#2) : (index) -> i64 loc(#loc17)\0A %37 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %38 = \22stable_mosaic_gpu.llvm.alloca\22(%37) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %39 = 
\22stable_mosaic_gpu.llvm.getelementptr\22(%38) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%36, %39) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %40 = \22stable_mosaic_gpu.arith.index_cast\22(%28#3) : (index) -> i64 loc(#loc17)\0A %41 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %42 = \22stable_mosaic_gpu.llvm.alloca\22(%41) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %43 = \22stable_mosaic_gpu.llvm.getelementptr\22(%42) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%40, %43) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %44 = \22stable_mosaic_gpu.arith.constant\22() {value = 16 : i64} : () -> i64 loc(#loc17)\0A %45 = \22stable_mosaic_gpu.arith.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %46 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %47 = \22stable_mosaic_gpu.llvm.alloca\22(%46) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %48 = \22stable_mosaic_gpu.llvm.getelementptr\22(%47) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%45, %48) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.call\22(%27, %33, %34, %35, %38, %42, %44, %47) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\0A %49 = \22stable_mosaic_gpu.llvm.getelementptr\22(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %50:4 = \22stable_mosaic_gpu.memref.extract_strided_metadata\22(%24) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\0A %51 = \22stable_mosaic_gpu.memref.extract_aligned_pointer_as_index\22(%24) : (memref<256xf32>) -> index loc(#loc17)\0A %52 = \22stable_mosaic_gpu.arith.index_cast\22(%51) : (index) -> i64 loc(#loc17)\0A %53 = \22stable_mosaic_gpu.llvm.inttoptr\22(%52) : (i64) -> !llvm.ptr loc(#loc17)\0A %54 = \22stable_mosaic_gpu.arith.index_cast\22(%50#1) : (index) -> i64 loc(#loc17)\0A %55 = \22stable_mosaic_gpu.llvm.getelementptr\22(%53, %54) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\0A %56 = \22stable_mosaic_gpu.arith.constant\22() {value = 6 : i64} : () -> i64 loc(#loc17)\0A %57 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %58 = \22stable_mosaic_gpu.arith.index_cast\22(%50#2) : (index) -> i64 loc(#loc17)\0A %59 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %60 = \22stable_mosaic_gpu.llvm.alloca\22(%59) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %61 = \22stable_mosaic_gpu.llvm.getelementptr\22(%60) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%58, %61) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %62 = \22stable_mosaic_gpu.arith.index_cast\22(%50#3) : (index) -> i64 loc(#loc17)\0A %63 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %64 = \22stable_mosaic_gpu.llvm.alloca\22(%63) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %65 = \22stable_mosaic_gpu.llvm.getelementptr\22(%64) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A 
\22stable_mosaic_gpu.llvm.store\22(%62, %65) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %66 = \22stable_mosaic_gpu.arith.constant\22() {value = 16 : i64} : () -> i64 loc(#loc17)\0A %67 = \22stable_mosaic_gpu.arith.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %68 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %69 = \22stable_mosaic_gpu.llvm.alloca\22(%68) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %70 = \22stable_mosaic_gpu.llvm.getelementptr\22(%69) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%67, %70) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.call\22(%49, %55, %56, %57, %60, %64, %66, %69) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\0A %71 = \22stable_mosaic_gpu.llvm.load\22(%26) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.array<256 x i8> loc(#loc17)\0A %72 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc17)\0A %73 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %74 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %75 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc17)\0A %76 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %77 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %78 = \22stable_mosaic_gpu.arith.constant\22() {value = 2056 : i32} : () -> i32 loc(#loc17)\0A %79 = \22stable_mosaic_gpu.gpu.launch\22(%0, %72, %73, %74, %75, %76, %77, %78) ({\0A ^bb0(%arg2: index loc(\22-\22:94:40), %arg3: index loc(\22-\22:94:47), %arg4: index loc(\22-\22:94:54), %arg5: index loc(\22-\22:94:116), %arg6: index loc(\22-\22:94:123), %arg7: index loc(\22-\22:94:130), %arg8: index loc(\22-\22:94:65), %arg9: index loc(\22-\22:94:78), %arg10: index loc(\22-\22:94:91), %arg11: index loc(\22-\22:94:141), %arg12: index loc(\22-\22:94:157), %arg13: index loc(\22-\22:94:174)):\0A %80 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc17)\0A %81 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%71) : (!llvm.array<256 x i8>) -> !llvm.ptr loc(#loc17)\0A %82 = \22stable_mosaic_gpu.llvm.getelementptr\22(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc18)\0A %83 = \22stable_mosaic_gpu.llvm.getelementptr\22(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc19)\0A %84 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc17)\0A %85 = \22stable_mosaic_gpu.memref.view\22(%80, %84) : (memref>, index) -> memref<2048xi8, #gpu.address_space> loc(#loc17)\0A %86 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%80) : (memref>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %87 = \22stable_mosaic_gpu.llvm.extractvalue\22(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc17)\0A %88 = \22stable_mosaic_gpu.llvm.extractvalue\22(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc17)\0A %89 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %90 
= \22stable_mosaic_gpu.llvm.mul\22(%88, %89) : (i64, i64) -> i64 loc(#loc17)\0A %91 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%87) : (!llvm.ptr<3>) -> i64 loc(#loc17)\0A %92 = \22stable_mosaic_gpu.llvm.add\22(%91, %90) : (i64, i64) -> i64 loc(#loc17)\0A %93 = \22stable_mosaic_gpu.llvm.inttoptr\22(%92) : (i64) -> !llvm.ptr<3> loc(#loc17)\0A %94 = \22stable_mosaic_gpu.llvm.getelementptr\22(%93) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\0A %95 = \22stable_mosaic_gpu.memref.alloca\22() {operandSegmentSizes = array} : () -> memref loc(#loc17)\0A %96 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.memref.store\22(%96, %95) : (i32, memref) -> () loc(#loc17)\0A %97 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %98 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %99 = \22stable_mosaic_gpu.arith.index_cast\22(%98) : (index) -> i32 loc(#loc17)\0A %100 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %101 = \22stable_mosaic_gpu.arith.index_cast\22(%100) : (index) -> i32 loc(#loc17)\0A %102 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %103 = \22stable_mosaic_gpu.arith.index_cast\22(%102) : (index) -> i32 loc(#loc17)\0A %104 = \22stable_mosaic_gpu.arith.muli\22(%103, %101) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %105 = \22stable_mosaic_gpu.arith.addi\22(%99, %104) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %106 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %107 = \22stable_mosaic_gpu.arith.index_cast\22(%106) : (index) -> i32 loc(#loc17)\0A %108 = \22stable_mosaic_gpu.arith.muli\22(%101, %107) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %109 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %110 = \22stable_mosaic_gpu.arith.index_cast\22(%109) : (index) -> i32 loc(#loc17)\0A %111 = \22stable_mosaic_gpu.arith.muli\22(%110, %108) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %112 = \22stable_mosaic_gpu.arith.addi\22(%105, %111) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %113 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %114 = \22stable_mosaic_gpu.arith.index_cast\22(%113) : (index) -> i32 loc(#loc17)\0A %115 = \22stable_mosaic_gpu.arith.muli\22(%108, %114) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %116 = \22stable_mosaic_gpu.arith.constant\22() {value = 5 : i32} : () -> i32 loc(#loc17)\0A %117 = \22stable_mosaic_gpu.arith.shrui\22(%112, %116) : (i32, i32) -> i32 loc(#loc17)\0A %118 = \22stable_mosaic_gpu.arith.constant\22() {value = -1 : i32} : () -> i32 loc(#loc17)\0A %119 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %120 = \22stable_mosaic_gpu.arith.constant\22() {value = 31 : i32} : () -> i32 loc(#loc17)\0A %121 = \22stable_mosaic_gpu.nvvm.shfl.sync\22(%118, %117, %119, %120) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\0A %122 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %123 = \22stable_mosaic_gpu.arith.cmpi\22(%121, %122) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\0A %124 = \22stable_mosaic_gpu.arith.andi\22(%123, %97) : (i1, i1) -> i1 loc(#loc17)\0A 
\22stable_mosaic_gpu.scf.if\22(%124) ({\0A %332 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\0A %333 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.nvvm.mbarrier.init.shared\22(%332, %333) : (!llvm.ptr<3>, i32) -> () loc(#loc17)\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc13)\0A }, {\0A }) : (i1) -> () loc(#loc17)\0A %125 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.nvvm.fence.mbarrier.init\22() : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.gpu.barrier\22() : () -> () loc(#loc17)\0A %126 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %127 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %128 = \22stable_mosaic_gpu.arith.index_cast\22(%127) : (index) -> i32 loc(#loc17)\0A %129 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %130 = \22stable_mosaic_gpu.arith.index_cast\22(%129) : (index) -> i32 loc(#loc17)\0A %131 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %132 = \22stable_mosaic_gpu.arith.index_cast\22(%131) : (index) -> i32 loc(#loc17)\0A %133 = \22stable_mosaic_gpu.arith.muli\22(%132, %130) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %134 = \22stable_mosaic_gpu.arith.addi\22(%128, %133) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %135 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %136 = \22stable_mosaic_gpu.arith.index_cast\22(%135) : (index) -> i32 loc(#loc17)\0A %137 = \22stable_mosaic_gpu.arith.muli\22(%130, %136) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %138 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %139 = \22stable_mosaic_gpu.arith.index_cast\22(%138) : (index) -> i32 loc(#loc17)\0A %140 = \22stable_mosaic_gpu.arith.muli\22(%139, %137) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %141 = \22stable_mosaic_gpu.arith.addi\22(%134, %140) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %142 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %143 = \22stable_mosaic_gpu.arith.index_cast\22(%142) : (index) -> i32 loc(#loc17)\0A %144 = \22stable_mosaic_gpu.arith.muli\22(%137, %143) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %145 = \22stable_mosaic_gpu.arith.constant\22() {value = 5 : i32} : () -> i32 loc(#loc17)\0A %146 = \22stable_mosaic_gpu.arith.shrui\22(%141, %145) : (i32, i32) -> i32 loc(#loc17)\0A %147 = \22stable_mosaic_gpu.arith.constant\22() {value = -1 : i32} : () -> i32 loc(#loc17)\0A %148 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %149 = \22stable_mosaic_gpu.arith.constant\22() {value = 31 : i32} : () -> i32 loc(#loc17)\0A %150 = \22stable_mosaic_gpu.nvvm.shfl.sync\22(%147, %146, %148, %149) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\0A %151 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i32} : () -> i32 loc(#loc17)\0A %152 = \22stable_mosaic_gpu.arith.remui\22(%150, %151) : (i32, i32) -> i32 loc(#loc17)\0A %153 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %154 = \22stable_mosaic_gpu.arith.cmpi\22(%152, %153) 
{predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\0A %155 = \22stable_mosaic_gpu.arith.andi\22(%154, %126) : (i1, i1) -> i1 loc(#loc17)\0A %156 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %157 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %158 = \22stable_mosaic_gpu.arith.index_cast\22(%157) : (index) -> i32 loc(#loc17)\0A %159 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc20)\0A %160 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc20)\0A %161 = \22stable_mosaic_gpu.memref.view\22(%159, %160) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\0A %162 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc20)\0A %163 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : index} : () -> index loc(#loc20)\0A %164 = \22stable_mosaic_gpu.memref.view\22(%162, %163) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\0A %165 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %166 = \22stable_mosaic_gpu.memref.subview\22(%161, %165) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\0A %167 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %168 = \22stable_mosaic_gpu.arith.index_castui\22(%167) : (index) -> i32 loc(#loc19)\0A %169 = \22stable_mosaic_gpu.arith.addi\22(%125, %168) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\0A %170 = \22stable_mosaic_gpu.arith.constant\22() {value = 8 : i32} : () -> i32 loc(#loc19)\0A %171 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared\22(%171, %170) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\0A %172 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %173 = \22stable_mosaic_gpu.arith.index_cast\22(%172) : (index) -> i32 loc(#loc19)\0A %174 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%166) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\0A %175 = \22stable_mosaic_gpu.llvm.extractvalue\22(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\0A %176 = \22stable_mosaic_gpu.llvm.extractvalue\22(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\0A %177 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc19)\0A %178 = \22stable_mosaic_gpu.llvm.mul\22(%176, %177) : (i64, i64) -> i64 loc(#loc19)\0A %179 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%175) : (!llvm.ptr<3>) -> i64 loc(#loc19)\0A %180 = \22stable_mosaic_gpu.llvm.add\22(%179, %178) : (i64, i64) -> i64 loc(#loc19)\0A %181 = \22stable_mosaic_gpu.llvm.inttoptr\22(%180) : (i64) -> !llvm.ptr<3> loc(#loc19)\0A %182 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : i32} : () -> i32 loc(#loc19)\0A %183 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A 
\22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global\22(%181, %83, %173, %183, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\0A %184 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc21)\0A %185 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc21)\0A %186 = \22stable_mosaic_gpu.arith.addi\22(%185, %184) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\0A %187 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %188 = \22stable_mosaic_gpu.arith.remsi\22(%186, %187) : (i32, i32) -> i32 loc(#loc22)\0A %189 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc23)\0A %190 = \22stable_mosaic_gpu.arith.index_castui\22(%189) : (index) -> i32 loc(#loc23)\0A %191 = \22stable_mosaic_gpu.arith.addi\22(%125, %190) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\0A %192 = \22stable_mosaic_gpu.memref.load\22(%95) : (memref) -> i32 loc(#loc23)\0A %193 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc23)\0A %194 = \22stable_mosaic_gpu.arith.shli\22(%193, %191) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\0A %195 = \22stable_mosaic_gpu.arith.andi\22(%192, %194) : (i32, i32) -> i32 loc(#loc23)\0A %196 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc23)\0A %197 = \22stable_mosaic_gpu.arith.cmpi\22(%195, %196) {predicate = 1 : i64} : (i32, i32) -> i1 loc(#loc23)\0A %198 = \22stable_mosaic_gpu.arith.xori\22(%192, %194) : (i32, i32) -> i32 loc(#loc23)\0A \22stable_mosaic_gpu.memref.store\22(%198, %95) : (i32, memref) -> () loc(#loc23)\0A %199 = \22stable_mosaic_gpu.arith.constant\22() {value = 10000000 : i32} : () -> i32 loc(#loc23)\0A %200 = \22stable_mosaic_gpu.arith.extui\22(%197) : (i1) -> i32 loc(#loc23)\0A %201 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %191) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc23)\0A \22stable_mosaic_gpu.nvvm.mbarrier.try_wait.parity.shared\22(%201, %200, %199) : (!llvm.ptr<3>, i32, i32) -> () loc(#loc23)\0A %202 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc24)\0A %203 = \22stable_mosaic_gpu.memref.subview\22(%161, %202) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\0A %204 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc24)\0A %205 = \22stable_mosaic_gpu.memref.subview\22(%164, %204) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\0A %206 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc24)\0A %207 = \22stable_mosaic_gpu.arith.index_cast\22(%206) : (index) -> i32 loc(#loc24)\0A %208 = \22stable_mosaic_gpu.memref.subview\22(%203) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\0A %209 = \22stable_mosaic_gpu.memref.collapse_shape\22(%208) {reassociation = [[0]]} : (memref<256xf32, strided<[1], 
offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\0A %210 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc25)\0A %211 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc25)\0A %212 = \22stable_mosaic_gpu.arith.remui\22(%210, %211) : (index, index) -> index loc(#loc25)\0A %213 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc25)\0A %214 = \22stable_mosaic_gpu.arith.muli\22(%212, %213) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\0A %215 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc25)\0A %216 = \22stable_mosaic_gpu.arith.addi\22(%214, %215) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\0A %217 = \22stable_mosaic_gpu.vector.load\22(%209, %216) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc25)\0A %218 = \22stable_mosaic_gpu.arith.constant\22() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc26)\0A %219 = \22stable_mosaic_gpu.vector.splat\22(%218) : (f32) -> vector<2xf32> loc(#loc26)\0A %220 = \22stable_mosaic_gpu.arith.addf\22(%217, %219) {fastmath = #arith.fastmath} : (vector<2xf32>, vector<2xf32>) -> vector<2xf32> loc(#loc26)\0A %221 = \22stable_mosaic_gpu.memref.subview\22(%205) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %222 = \22stable_mosaic_gpu.memref.collapse_shape\22(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %223 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc27)\0A %224 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc27)\0A %225 = \22stable_mosaic_gpu.arith.remui\22(%223, %224) : (index, index) -> index loc(#loc27)\0A %226 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc27)\0A %227 = \22stable_mosaic_gpu.arith.muli\22(%225, %226) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %228 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc27)\0A %229 = \22stable_mosaic_gpu.arith.addi\22(%227, %228) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %230 = \22stable_mosaic_gpu.vector.load\22(%222, %229) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc27)\0A %231 = \22stable_mosaic_gpu.memref.collapse_shape\22(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %232 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc27)\0A %233 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc27)\0A %234 = \22stable_mosaic_gpu.arith.remui\22(%232, %233) : (index, index) -> index loc(#loc27)\0A %235 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc27)\0A %236 = \22stable_mosaic_gpu.arith.muli\22(%234, %235) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %237 = 
\22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc27)\0A %238 = \22stable_mosaic_gpu.arith.addi\22(%236, %237) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A \22stable_mosaic_gpu.vector.store\22(%220, %231, %238) : (vector<2xf32>, memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> () loc(#loc27)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group\22() : () -> () loc(#loc28)\0A %239 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc29)\0A %240 = \22stable_mosaic_gpu.arith.addi\22(%186, %239) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\0A %241 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %242 = \22stable_mosaic_gpu.arith.remsi\22(%240, %241) : (i32, i32) -> i32 loc(#loc22)\0A %243 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc30)\0A %244 = \22stable_mosaic_gpu.arith.cmpi\22(%186, %243) {predicate = 9 : i64} : (i32, i32) -> i1 loc(#loc30)\0A %245 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc31)\0A %246 = \22stable_mosaic_gpu.arith.cmpi\22(%240, %245) {predicate = 6 : i64} : (i32, i32) -> i1 loc(#loc31)\0A %247 = \22stable_mosaic_gpu.arith.andi\22(%244, %246) : (i1, i1) -> i1 loc(#loc32)\0A %248 = \22stable_mosaic_gpu.arith.extui\22(%247) : (i1) -> i32 loc(#loc33)\0A %249 = \22stable_mosaic_gpu.arith.index_cast\22(%248) : (i32) -> index loc(#loc34)\0A \22stable_mosaic_gpu.scf.index_switch\22(%249) ({\0A %313 = \22stable_mosaic_gpu.arith.index_cast\22(%242) : (i32) -> index loc(#loc19)\0A %314 = \22stable_mosaic_gpu.memref.subview\22(%161, %313) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\0A %315 = \22stable_mosaic_gpu.arith.index_cast\22(%242) : (i32) -> index loc(#loc19)\0A %316 = \22stable_mosaic_gpu.arith.index_castui\22(%315) : (index) -> i32 loc(#loc19)\0A %317 = \22stable_mosaic_gpu.arith.addi\22(%125, %316) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\0A %318 = \22stable_mosaic_gpu.arith.constant\22() {value = 8 : i32} : () -> i32 loc(#loc19)\0A %319 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared\22(%319, %318) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\0A %320 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %321 = \22stable_mosaic_gpu.arith.index_cast\22(%320) : (index) -> i32 loc(#loc19)\0A %322 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%314) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\0A %323 = \22stable_mosaic_gpu.llvm.extractvalue\22(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\0A %324 = \22stable_mosaic_gpu.llvm.extractvalue\22(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\0A %325 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc19)\0A %326 = \22stable_mosaic_gpu.llvm.mul\22(%324, %325) : (i64, i64) -> i64 loc(#loc19)\0A %327 = 
\22stable_mosaic_gpu.llvm.ptrtoint\22(%323) : (!llvm.ptr<3>) -> i64 loc(#loc19)\0A %328 = \22stable_mosaic_gpu.llvm.add\22(%327, %326) : (i64, i64) -> i64 loc(#loc19)\0A %329 = \22stable_mosaic_gpu.llvm.inttoptr\22(%328) : (i64) -> !llvm.ptr<3> loc(#loc19)\0A %330 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : i32} : () -> i32 loc(#loc19)\0A %331 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global\22(%329, %83, %321, %331, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc16)\0A }, {\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc34)\0A }) {cases = array} : (index) -> () loc(#loc34)\0A %250 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc21)\0A %251 = \22stable_mosaic_gpu.arith.addi\22(%184, %250) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\0A \22stable_mosaic_gpu.nvvm.fence.proxy\22() {kind = #nvvm.proxy_kind, space = #nvvm.shared_space} : () -> () loc(#loc35)\0A %252 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %253 = \22stable_mosaic_gpu.arith.index_cast\22(%252) : (index) -> i32 loc(#loc35)\0A %254 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %255 = \22stable_mosaic_gpu.arith.index_cast\22(%254) : (index) -> i32 loc(#loc35)\0A %256 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %257 = \22stable_mosaic_gpu.arith.index_cast\22(%256) : (index) -> i32 loc(#loc35)\0A %258 = \22stable_mosaic_gpu.arith.muli\22(%257, %255) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %259 = \22stable_mosaic_gpu.arith.addi\22(%253, %258) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %260 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %261 = \22stable_mosaic_gpu.arith.index_cast\22(%260) : (index) -> i32 loc(#loc35)\0A %262 = \22stable_mosaic_gpu.arith.muli\22(%255, %261) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %263 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %264 = \22stable_mosaic_gpu.arith.index_cast\22(%263) : (index) -> i32 loc(#loc35)\0A %265 = \22stable_mosaic_gpu.arith.muli\22(%264, %262) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %266 = \22stable_mosaic_gpu.arith.addi\22(%259, %265) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %267 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %268 = \22stable_mosaic_gpu.arith.index_cast\22(%267) : (index) -> i32 loc(#loc35)\0A %269 = \22stable_mosaic_gpu.arith.muli\22(%262, %268) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %270 = \22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc35)\0A %271 = \22stable_mosaic_gpu.arith.shrui\22(%266, %270) : (i32, i32) -> i32 loc(#loc35)\0A %272 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc35)\0A %273 = \22stable_mosaic_gpu.arith.addi\22(%271, %272) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %274 = \22stable_mosaic_gpu.llvm.inline_asm\22(%273) {asm_string = \22bar.sync $0, 128;\22, 
constraints = \22r\22, has_side_effects} : (i32) -> !llvm.void loc(#loc35)\0A %275 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc22)\0A %276 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %277 = \22stable_mosaic_gpu.arith.remsi\22(%275, %276) : (i32, i32) -> i32 loc(#loc22)\0A %278 = \22stable_mosaic_gpu.arith.index_cast\22(%277) : (i32) -> index loc(#loc18)\0A %279 = \22stable_mosaic_gpu.memref.subview\22(%164, %278) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc18)\0A %280 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc18)\0A %281 = \22stable_mosaic_gpu.arith.index_cast\22(%280) : (index) -> i32 loc(#loc18)\0A %282 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%279) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc18)\0A %283 = \22stable_mosaic_gpu.llvm.extractvalue\22(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc18)\0A %284 = \22stable_mosaic_gpu.llvm.extractvalue\22(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc18)\0A %285 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc18)\0A %286 = \22stable_mosaic_gpu.llvm.mul\22(%284, %285) : (i64, i64) -> i64 loc(#loc18)\0A %287 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%283) : (!llvm.ptr<3>) -> i64 loc(#loc18)\0A %288 = \22stable_mosaic_gpu.llvm.add\22(%287, %286) : (i64, i64) -> i64 loc(#loc18)\0A %289 = \22stable_mosaic_gpu.llvm.inttoptr\22(%288) : (i64) -> !llvm.ptr<3> loc(#loc18)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.global.shared.cta\22(%82, %289, %281, %155) {operandSegmentSizes = array} : (!llvm.ptr, !llvm.ptr<3>, i32, i1) -> () loc(#loc18)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group\22() : () -> () loc(#loc28)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.wait_group\22() {group = 0 : i32} : () -> () loc(#loc36)\0A %290 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %291 = \22stable_mosaic_gpu.arith.index_cast\22(%290) : (index) -> i32 loc(#loc36)\0A %292 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %293 = \22stable_mosaic_gpu.arith.index_cast\22(%292) : (index) -> i32 loc(#loc36)\0A %294 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %295 = \22stable_mosaic_gpu.arith.index_cast\22(%294) : (index) -> i32 loc(#loc36)\0A %296 = \22stable_mosaic_gpu.arith.muli\22(%295, %293) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %297 = \22stable_mosaic_gpu.arith.addi\22(%291, %296) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %298 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %299 = \22stable_mosaic_gpu.arith.index_cast\22(%298) : (index) -> i32 loc(#loc36)\0A %300 = \22stable_mosaic_gpu.arith.muli\22(%293, %299) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %301 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %302 = \22stable_mosaic_gpu.arith.index_cast\22(%301) : (index) -> i32 
loc(#loc36)\0A %303 = \22stable_mosaic_gpu.arith.muli\22(%302, %300) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %304 = \22stable_mosaic_gpu.arith.addi\22(%297, %303) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %305 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %306 = \22stable_mosaic_gpu.arith.index_cast\22(%305) : (index) -> i32 loc(#loc36)\0A %307 = \22stable_mosaic_gpu.arith.muli\22(%300, %306) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %308 = \22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc36)\0A %309 = \22stable_mosaic_gpu.arith.shrui\22(%304, %308) : (i32, i32) -> i32 loc(#loc36)\0A %310 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc36)\0A %311 = \22stable_mosaic_gpu.arith.addi\22(%309, %310) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %312 = \22stable_mosaic_gpu.llvm.inline_asm\22(%311) {asm_string = \22bar.sync $0, 128;\22, constraints = \22r\22, has_side_effects} : (i32) -> !llvm.void loc(#loc36)\0A \22stable_mosaic_gpu.gpu.terminator\22() : () -> () loc(#loc17)\0A }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc17)\0A \22stable_mosaic_gpu.func.return\22() : () -> () loc(#loc17)\0A }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = \22mosaic_gpu_body\22} : () -> () loc(#loc17)\0A}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc17)\0A#loc13 = loc(\22-\22:141:7)\0A#loc14 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:78:19)\0A#loc15 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:78:6)\0A#loc16 = loc(\22-\22:279:7)\0A#loc18 = loc(\22/copy_smem_to_gmem\22(#loc))\0A#loc19 = loc(\22/copy_gmem_to_smem\22(#loc))\0A#loc20 = loc(\22/run_scoped\22(#loc))\0A#loc21 = loc(\22/scan\22(#loc))\0A#loc22 = loc(\22/rem\22(#loc))\0A#loc23 = loc(\22/barrier_wait\22(#loc))\0A#loc24 = loc(\22/jaxpr_call\22(#loc))\0A#loc25 = loc(\22/get\22(#loc14))\0A#loc26 = loc(\22/add\22(#loc14))\0A#loc27 = loc(\22/swap\22(#loc15))\0A#loc28 = loc(\22/commit_group\22(#loc))\0A#loc29 = loc(\22/add\22(#loc))\0A#loc30 = loc(\22/ge\22(#loc))\0A#loc31 = loc(\22/lt\22(#loc))\0A#loc32 = loc(\22/and\22(#loc))\0A#loc33 = loc(\22/convert_element_type\22(#loc))\0A#loc34 = loc(\22/cond\22(#loc))\0A#loc35 = loc(\22/commit_smem\22(#loc))\0A#loc36 = loc(\22/wait_smem_to_gmem\22(#loc))\0A", operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<256xf32>) -> tensor<256xf32> loc(#loc3) + return %0 : tensor<256xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":83:4) +#loc3 = loc("jit(wrapped)/jit(main)/pallas_call"(#loc2)) +""", + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.7\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03_=\x0f\x01\x1d\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x13\x0b\x03!\x0b\x0f\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b/\x01\x05\x0b\x0f\x03\x0b\x17\x17\x07\x13\x07\x02\xd5\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x19\x05\x17\x17\x1b\xa7\t\x05\x19\x03\x01\x03\x03;\x03\x03#\r\x01#\x07\x03\x03)\r\x03+-\x1d\x1b\x1d\x1d\x1d\x1f\x1d!\x0b\x05\x1d#\x1d%\x05\x01\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x03\x02\x08\t\x11\x03\x05\x03\x05\t)\x03\x05\r\x13\x04O\x05\x01Q\x01\x05\x01\x07\x04=\x03\x01\x05\x03P\x01\x03\x07\x04)\x03\x05\x0b\x03\x0b\x11\x00\x05F\x15\x05\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00D\xae\x05\'\x17\xa4\xa4\x05\x0f\x0b\x0f!\x85G\x11\x19%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_wrapped\x00args[0]\x00jit(wrapped)/jit(main)/pallas_call\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00jax.result_info\x00result\x00main\x00public\x00\xa9C\xfb\x81\x9a1\xc2?\x0e\xf4\xe1\xe4\xe77\x03\xb6\x97\xe5G(]WR\x98\xeb{\xba\x8a\x84\x01\x12\'#loc = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":83:4)\n#loc1 = loc("-":94:40)\n#loc2 = loc("-":94:47)\n#loc3 = loc("-":94:54)\n#loc4 = loc("-":94:116)\n#loc5 = loc("-":94:123)\n#loc6 = loc("-":94:130)\n#loc7 = loc("-":94:65)\n#loc8 = loc("-":94:78)\n#loc9 = loc("-":94:91)\n#loc10 = loc("-":94:141)\n#loc11 = loc("-":94:157)\n#loc12 = loc("-":94:174)\n#loc17 = loc("jit(wrapped)/jit(main)/pallas_call"(#loc))\n"builtin.module"() <{sym_name = "add_one"}> ({\n "stable_mosaic_gpu.func.func"() ({\n }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = "mosaic_gpu_init_tma_desc", sym_visibility = "private"} : () -> () loc(#loc17)\n "stable_mosaic_gpu.llvm.mlir.global"() ({\n }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = "global_scratch", unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc17)\n "stable_mosaic_gpu.func.func"() ({\n ^bb0(%arg0: !llvm.ptr loc("jit(wrapped)/jit(main)/pallas_call"(#loc)), %arg1: !llvm.ptr loc("jit(wrapped)/jit(main)/pallas_call"(#loc))):\n %0 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc17)\n %1 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %2 = "stable_mosaic_gpu.llvm.load"(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %3 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %4 = "stable_mosaic_gpu.llvm.insertvalue"(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %5 = "stable_mosaic_gpu.llvm.insertvalue"(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %6 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc17)\n %7 = "stable_mosaic_gpu.llvm.insertvalue"(%5, %6) {position = 
array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %8 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %9 = "stable_mosaic_gpu.llvm.insertvalue"(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %10 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %11 = "stable_mosaic_gpu.llvm.insertvalue"(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %12 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\n %13 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %14 = "stable_mosaic_gpu.llvm.load"(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %15 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %16 = "stable_mosaic_gpu.llvm.insertvalue"(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %17 = "stable_mosaic_gpu.llvm.insertvalue"(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %18 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc17)\n %19 = "stable_mosaic_gpu.llvm.insertvalue"(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %20 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %21 = "stable_mosaic_gpu.llvm.insertvalue"(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %22 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %23 = "stable_mosaic_gpu.llvm.insertvalue"(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %24 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\n %25 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %26 = "stable_mosaic_gpu.llvm.alloca"(%25) {alignment = 64 : i64, elem_type = !llvm.array<256 x i8>} : (i64) -> !llvm.ptr loc(#loc17)\n %27 = "stable_mosaic_gpu.llvm.getelementptr"(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %28:4 = "stable_mosaic_gpu.memref.extract_strided_metadata"(%12) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\n %29 = "stable_mosaic_gpu.memref.extract_aligned_pointer_as_index"(%12) : (memref<256xf32>) -> index loc(#loc17)\n %30 = "stable_mosaic_gpu.arith.index_cast"(%29) : (index) -> i64 
loc(#loc17)\n %31 = "stable_mosaic_gpu.llvm.inttoptr"(%30) : (i64) -> !llvm.ptr loc(#loc17)\n %32 = "stable_mosaic_gpu.arith.index_cast"(%28#1) : (index) -> i64 loc(#loc17)\n %33 = "stable_mosaic_gpu.llvm.getelementptr"(%31, %32) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\n %34 = "stable_mosaic_gpu.arith.constant"() {value = 6 : i64} : () -> i64 loc(#loc17)\n %35 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %36 = "stable_mosaic_gpu.arith.index_cast"(%28#2) : (index) -> i64 loc(#loc17)\n %37 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %38 = "stable_mosaic_gpu.llvm.alloca"(%37) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %39 = "stable_mosaic_gpu.llvm.getelementptr"(%38) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%36, %39) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %40 = "stable_mosaic_gpu.arith.index_cast"(%28#3) : (index) -> i64 loc(#loc17)\n %41 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %42 = "stable_mosaic_gpu.llvm.alloca"(%41) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %43 = "stable_mosaic_gpu.llvm.getelementptr"(%42) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%40, %43) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %44 = "stable_mosaic_gpu.arith.constant"() {value = 16 : i64} : () -> i64 loc(#loc17)\n %45 = "stable_mosaic_gpu.arith.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %46 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %47 = "stable_mosaic_gpu.llvm.alloca"(%46) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %48 = "stable_mosaic_gpu.llvm.getelementptr"(%47) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%45, %48) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n "stable_mosaic_gpu.func.call"(%27, %33, %34, %35, %38, %42, %44, %47) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\n %49 = "stable_mosaic_gpu.llvm.getelementptr"(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %50:4 = "stable_mosaic_gpu.memref.extract_strided_metadata"(%24) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\n %51 = "stable_mosaic_gpu.memref.extract_aligned_pointer_as_index"(%24) : (memref<256xf32>) -> index loc(#loc17)\n %52 = "stable_mosaic_gpu.arith.index_cast"(%51) : (index) -> i64 loc(#loc17)\n %53 = "stable_mosaic_gpu.llvm.inttoptr"(%52) : (i64) -> !llvm.ptr loc(#loc17)\n %54 = "stable_mosaic_gpu.arith.index_cast"(%50#1) : (index) -> i64 loc(#loc17)\n %55 = "stable_mosaic_gpu.llvm.getelementptr"(%53, %54) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\n %56 = "stable_mosaic_gpu.arith.constant"() {value = 6 : i64} : () -> i64 loc(#loc17)\n %57 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %58 = "stable_mosaic_gpu.arith.index_cast"(%50#2) : (index) -> i64 loc(#loc17)\n %59 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %60 = "stable_mosaic_gpu.llvm.alloca"(%59) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %61 = "stable_mosaic_gpu.llvm.getelementptr"(%60) 
{elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%58, %61) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %62 = "stable_mosaic_gpu.arith.index_cast"(%50#3) : (index) -> i64 loc(#loc17)\n %63 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %64 = "stable_mosaic_gpu.llvm.alloca"(%63) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %65 = "stable_mosaic_gpu.llvm.getelementptr"(%64) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%62, %65) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %66 = "stable_mosaic_gpu.arith.constant"() {value = 16 : i64} : () -> i64 loc(#loc17)\n %67 = "stable_mosaic_gpu.arith.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %68 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %69 = "stable_mosaic_gpu.llvm.alloca"(%68) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %70 = "stable_mosaic_gpu.llvm.getelementptr"(%69) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%67, %70) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n "stable_mosaic_gpu.func.call"(%49, %55, %56, %57, %60, %64, %66, %69) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\n %71 = "stable_mosaic_gpu.llvm.load"(%26) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.array<256 x i8> loc(#loc17)\n %72 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc17)\n %73 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %74 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %75 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc17)\n %76 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %77 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %78 = "stable_mosaic_gpu.arith.constant"() {value = 2056 : i32} : () -> i32 loc(#loc17)\n %79 = "stable_mosaic_gpu.gpu.launch"(%0, %72, %73, %74, %75, %76, %77, %78) ({\n ^bb0(%arg2: index loc("-":94:40), %arg3: index loc("-":94:47), %arg4: index loc("-":94:54), %arg5: index loc("-":94:116), %arg6: index loc("-":94:123), %arg7: index loc("-":94:130), %arg8: index loc("-":94:65), %arg9: index loc("-":94:78), %arg10: index loc("-":94:91), %arg11: index loc("-":94:141), %arg12: index loc("-":94:157), %arg13: index loc("-":94:174)):\n %80 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc17)\n %81 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%71) : (!llvm.array<256 x i8>) -> !llvm.ptr loc(#loc17)\n %82 = "stable_mosaic_gpu.llvm.getelementptr"(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc18)\n %83 = "stable_mosaic_gpu.llvm.getelementptr"(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc19)\n %84 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc17)\n %85 = "stable_mosaic_gpu.memref.view"(%80, %84) : (memref>, index) -> memref<2048xi8, #gpu.address_space> loc(#loc17)\n %86 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%80) : (memref>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %87 = 
"stable_mosaic_gpu.llvm.extractvalue"(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc17)\n %88 = "stable_mosaic_gpu.llvm.extractvalue"(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc17)\n %89 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %90 = "stable_mosaic_gpu.llvm.mul"(%88, %89) : (i64, i64) -> i64 loc(#loc17)\n %91 = "stable_mosaic_gpu.llvm.ptrtoint"(%87) : (!llvm.ptr<3>) -> i64 loc(#loc17)\n %92 = "stable_mosaic_gpu.llvm.add"(%91, %90) : (i64, i64) -> i64 loc(#loc17)\n %93 = "stable_mosaic_gpu.llvm.inttoptr"(%92) : (i64) -> !llvm.ptr<3> loc(#loc17)\n %94 = "stable_mosaic_gpu.llvm.getelementptr"(%93) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\n %95 = "stable_mosaic_gpu.memref.alloca"() {operandSegmentSizes = array} : () -> memref loc(#loc17)\n %96 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.memref.store"(%96, %95) : (i32, memref) -> () loc(#loc17)\n %97 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %98 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %99 = "stable_mosaic_gpu.arith.index_cast"(%98) : (index) -> i32 loc(#loc17)\n %100 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %101 = "stable_mosaic_gpu.arith.index_cast"(%100) : (index) -> i32 loc(#loc17)\n %102 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %103 = "stable_mosaic_gpu.arith.index_cast"(%102) : (index) -> i32 loc(#loc17)\n %104 = "stable_mosaic_gpu.arith.muli"(%103, %101) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %105 = "stable_mosaic_gpu.arith.addi"(%99, %104) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %106 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %107 = "stable_mosaic_gpu.arith.index_cast"(%106) : (index) -> i32 loc(#loc17)\n %108 = "stable_mosaic_gpu.arith.muli"(%101, %107) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %109 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %110 = "stable_mosaic_gpu.arith.index_cast"(%109) : (index) -> i32 loc(#loc17)\n %111 = "stable_mosaic_gpu.arith.muli"(%110, %108) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %112 = "stable_mosaic_gpu.arith.addi"(%105, %111) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %113 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %114 = "stable_mosaic_gpu.arith.index_cast"(%113) : (index) -> i32 loc(#loc17)\n %115 = "stable_mosaic_gpu.arith.muli"(%108, %114) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %116 = "stable_mosaic_gpu.arith.constant"() {value = 5 : i32} : () -> i32 loc(#loc17)\n %117 = "stable_mosaic_gpu.arith.shrui"(%112, %116) : (i32, i32) -> i32 loc(#loc17)\n %118 = "stable_mosaic_gpu.arith.constant"() {value = -1 : i32} : () -> i32 loc(#loc17)\n %119 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %120 = "stable_mosaic_gpu.arith.constant"() {value = 31 : i32} : () -> i32 loc(#loc17)\n %121 = "stable_mosaic_gpu.nvvm.shfl.sync"(%118, %117, %119, %120) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\n %122 = "stable_mosaic_gpu.arith.constant"() 
{value = 0 : i32} : () -> i32 loc(#loc17)\n %123 = "stable_mosaic_gpu.arith.cmpi"(%121, %122) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\n %124 = "stable_mosaic_gpu.arith.andi"(%123, %97) : (i1, i1) -> i1 loc(#loc17)\n "stable_mosaic_gpu.scf.if"(%124) ({\n %332 = "stable_mosaic_gpu.llvm.getelementptr"(%94) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\n %333 = "stable_mosaic_gpu.arith.constant"() {value = 128 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.nvvm.mbarrier.init.shared"(%332, %333) : (!llvm.ptr<3>, i32) -> () loc(#loc17)\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc13)\n }, {\n }) : (i1) -> () loc(#loc17)\n %125 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.nvvm.fence.mbarrier.init"() : () -> () loc(#loc17)\n "stable_mosaic_gpu.gpu.barrier"() : () -> () loc(#loc17)\n %126 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %127 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %128 = "stable_mosaic_gpu.arith.index_cast"(%127) : (index) -> i32 loc(#loc17)\n %129 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %130 = "stable_mosaic_gpu.arith.index_cast"(%129) : (index) -> i32 loc(#loc17)\n %131 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %132 = "stable_mosaic_gpu.arith.index_cast"(%131) : (index) -> i32 loc(#loc17)\n %133 = "stable_mosaic_gpu.arith.muli"(%132, %130) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %134 = "stable_mosaic_gpu.arith.addi"(%128, %133) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %135 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %136 = "stable_mosaic_gpu.arith.index_cast"(%135) : (index) -> i32 loc(#loc17)\n %137 = "stable_mosaic_gpu.arith.muli"(%130, %136) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %138 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %139 = "stable_mosaic_gpu.arith.index_cast"(%138) : (index) -> i32 loc(#loc17)\n %140 = "stable_mosaic_gpu.arith.muli"(%139, %137) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %141 = "stable_mosaic_gpu.arith.addi"(%134, %140) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %142 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %143 = "stable_mosaic_gpu.arith.index_cast"(%142) : (index) -> i32 loc(#loc17)\n %144 = "stable_mosaic_gpu.arith.muli"(%137, %143) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %145 = "stable_mosaic_gpu.arith.constant"() {value = 5 : i32} : () -> i32 loc(#loc17)\n %146 = "stable_mosaic_gpu.arith.shrui"(%141, %145) : (i32, i32) -> i32 loc(#loc17)\n %147 = "stable_mosaic_gpu.arith.constant"() {value = -1 : i32} : () -> i32 loc(#loc17)\n %148 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %149 = "stable_mosaic_gpu.arith.constant"() {value = 31 : i32} : () -> i32 loc(#loc17)\n %150 = "stable_mosaic_gpu.nvvm.shfl.sync"(%147, %146, %148, %149) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\n %151 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i32} : () -> i32 loc(#loc17)\n %152 = "stable_mosaic_gpu.arith.remui"(%150, %151) : (i32, i32) -> i32 loc(#loc17)\n %153 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %154 = 
"stable_mosaic_gpu.arith.cmpi"(%152, %153) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\n %155 = "stable_mosaic_gpu.arith.andi"(%154, %126) : (i1, i1) -> i1 loc(#loc17)\n %156 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %157 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %158 = "stable_mosaic_gpu.arith.index_cast"(%157) : (index) -> i32 loc(#loc17)\n %159 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc20)\n %160 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc20)\n %161 = "stable_mosaic_gpu.memref.view"(%159, %160) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\n %162 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc20)\n %163 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : index} : () -> index loc(#loc20)\n %164 = "stable_mosaic_gpu.memref.view"(%162, %163) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\n %165 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %166 = "stable_mosaic_gpu.memref.subview"(%161, %165) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\n %167 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %168 = "stable_mosaic_gpu.arith.index_castui"(%167) : (index) -> i32 loc(#loc19)\n %169 = "stable_mosaic_gpu.arith.addi"(%125, %168) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\n %170 = "stable_mosaic_gpu.arith.constant"() {value = 8 : i32} : () -> i32 loc(#loc19)\n %171 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared"(%171, %170) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\n %172 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %173 = "stable_mosaic_gpu.arith.index_cast"(%172) : (index) -> i32 loc(#loc19)\n %174 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%166) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\n %175 = "stable_mosaic_gpu.llvm.extractvalue"(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\n %176 = "stable_mosaic_gpu.llvm.extractvalue"(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\n %177 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc19)\n %178 = "stable_mosaic_gpu.llvm.mul"(%176, %177) : (i64, i64) -> i64 loc(#loc19)\n %179 = "stable_mosaic_gpu.llvm.ptrtoint"(%175) : (!llvm.ptr<3>) -> i64 loc(#loc19)\n %180 = "stable_mosaic_gpu.llvm.add"(%179, %178) : (i64, i64) -> i64 loc(#loc19)\n %181 = "stable_mosaic_gpu.llvm.inttoptr"(%180) : (i64) -> !llvm.ptr<3> loc(#loc19)\n %182 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : i32} : () -> i32 loc(#loc19)\n %183 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global"(%181, %83, %173, %183, %155) 
{operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\n %184 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc21)\n %185 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc21)\n %186 = "stable_mosaic_gpu.arith.addi"(%185, %184) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\n %187 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %188 = "stable_mosaic_gpu.arith.remsi"(%186, %187) : (i32, i32) -> i32 loc(#loc22)\n %189 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc23)\n %190 = "stable_mosaic_gpu.arith.index_castui"(%189) : (index) -> i32 loc(#loc23)\n %191 = "stable_mosaic_gpu.arith.addi"(%125, %190) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\n %192 = "stable_mosaic_gpu.memref.load"(%95) : (memref) -> i32 loc(#loc23)\n %193 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc23)\n %194 = "stable_mosaic_gpu.arith.shli"(%193, %191) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\n %195 = "stable_mosaic_gpu.arith.andi"(%192, %194) : (i32, i32) -> i32 loc(#loc23)\n %196 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc23)\n %197 = "stable_mosaic_gpu.arith.cmpi"(%195, %196) {predicate = 1 : i64} : (i32, i32) -> i1 loc(#loc23)\n %198 = "stable_mosaic_gpu.arith.xori"(%192, %194) : (i32, i32) -> i32 loc(#loc23)\n "stable_mosaic_gpu.memref.store"(%198, %95) : (i32, memref) -> () loc(#loc23)\n %199 = "stable_mosaic_gpu.arith.constant"() {value = 10000000 : i32} : () -> i32 loc(#loc23)\n %200 = "stable_mosaic_gpu.arith.extui"(%197) : (i1) -> i32 loc(#loc23)\n %201 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %191) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc23)\n "stable_mosaic_gpu.nvvm.mbarrier.try_wait.parity.shared"(%201, %200, %199) : (!llvm.ptr<3>, i32, i32) -> () loc(#loc23)\n %202 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc24)\n %203 = "stable_mosaic_gpu.memref.subview"(%161, %202) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\n %204 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc24)\n %205 = "stable_mosaic_gpu.memref.subview"(%164, %204) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\n %206 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc24)\n %207 = "stable_mosaic_gpu.arith.index_cast"(%206) : (index) -> i32 loc(#loc24)\n %208 = "stable_mosaic_gpu.memref.subview"(%203) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\n %209 = "stable_mosaic_gpu.memref.collapse_shape"(%208) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\n %210 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc25)\n %211 = 
"stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc25)\n %212 = "stable_mosaic_gpu.arith.remui"(%210, %211) : (index, index) -> index loc(#loc25)\n %213 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc25)\n %214 = "stable_mosaic_gpu.arith.muli"(%212, %213) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\n %215 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc25)\n %216 = "stable_mosaic_gpu.arith.addi"(%214, %215) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\n %217 = "stable_mosaic_gpu.vector.load"(%209, %216) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc25)\n %218 = "stable_mosaic_gpu.arith.constant"() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc26)\n %219 = "stable_mosaic_gpu.vector.splat"(%218) : (f32) -> vector<2xf32> loc(#loc26)\n %220 = "stable_mosaic_gpu.arith.addf"(%217, %219) {fastmath = #arith.fastmath} : (vector<2xf32>, vector<2xf32>) -> vector<2xf32> loc(#loc26)\n %221 = "stable_mosaic_gpu.memref.subview"(%205) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %222 = "stable_mosaic_gpu.memref.collapse_shape"(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %223 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc27)\n %224 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc27)\n %225 = "stable_mosaic_gpu.arith.remui"(%223, %224) : (index, index) -> index loc(#loc27)\n %226 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc27)\n %227 = "stable_mosaic_gpu.arith.muli"(%225, %226) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %228 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc27)\n %229 = "stable_mosaic_gpu.arith.addi"(%227, %228) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %230 = "stable_mosaic_gpu.vector.load"(%222, %229) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc27)\n %231 = "stable_mosaic_gpu.memref.collapse_shape"(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %232 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc27)\n %233 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc27)\n %234 = "stable_mosaic_gpu.arith.remui"(%232, %233) : (index, index) -> index loc(#loc27)\n %235 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc27)\n %236 = "stable_mosaic_gpu.arith.muli"(%234, %235) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %237 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc27)\n %238 = "stable_mosaic_gpu.arith.addi"(%236, %237) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n "stable_mosaic_gpu.vector.store"(%220, %231, %238) : (vector<2xf32>, memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> () loc(#loc27)\n 
"stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group"() : () -> () loc(#loc28)\n %239 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc29)\n %240 = "stable_mosaic_gpu.arith.addi"(%186, %239) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\n %241 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %242 = "stable_mosaic_gpu.arith.remsi"(%240, %241) : (i32, i32) -> i32 loc(#loc22)\n %243 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc30)\n %244 = "stable_mosaic_gpu.arith.cmpi"(%186, %243) {predicate = 9 : i64} : (i32, i32) -> i1 loc(#loc30)\n %245 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc31)\n %246 = "stable_mosaic_gpu.arith.cmpi"(%240, %245) {predicate = 6 : i64} : (i32, i32) -> i1 loc(#loc31)\n %247 = "stable_mosaic_gpu.arith.andi"(%244, %246) : (i1, i1) -> i1 loc(#loc32)\n %248 = "stable_mosaic_gpu.arith.extui"(%247) : (i1) -> i32 loc(#loc33)\n %249 = "stable_mosaic_gpu.arith.index_cast"(%248) : (i32) -> index loc(#loc34)\n "stable_mosaic_gpu.scf.index_switch"(%249) ({\n %313 = "stable_mosaic_gpu.arith.index_cast"(%242) : (i32) -> index loc(#loc19)\n %314 = "stable_mosaic_gpu.memref.subview"(%161, %313) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\n %315 = "stable_mosaic_gpu.arith.index_cast"(%242) : (i32) -> index loc(#loc19)\n %316 = "stable_mosaic_gpu.arith.index_castui"(%315) : (index) -> i32 loc(#loc19)\n %317 = "stable_mosaic_gpu.arith.addi"(%125, %316) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\n %318 = "stable_mosaic_gpu.arith.constant"() {value = 8 : i32} : () -> i32 loc(#loc19)\n %319 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared"(%319, %318) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\n %320 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %321 = "stable_mosaic_gpu.arith.index_cast"(%320) : (index) -> i32 loc(#loc19)\n %322 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%314) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\n %323 = "stable_mosaic_gpu.llvm.extractvalue"(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\n %324 = "stable_mosaic_gpu.llvm.extractvalue"(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\n %325 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc19)\n %326 = "stable_mosaic_gpu.llvm.mul"(%324, %325) : (i64, i64) -> i64 loc(#loc19)\n %327 = "stable_mosaic_gpu.llvm.ptrtoint"(%323) : (!llvm.ptr<3>) -> i64 loc(#loc19)\n %328 = "stable_mosaic_gpu.llvm.add"(%327, %326) : (i64, i64) -> i64 loc(#loc19)\n %329 = "stable_mosaic_gpu.llvm.inttoptr"(%328) : (i64) -> !llvm.ptr<3> loc(#loc19)\n %330 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : i32} : () -> i32 loc(#loc19)\n %331 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n 
"stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global"(%329, %83, %321, %331, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc16)\n }, {\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc34)\n }) {cases = array} : (index) -> () loc(#loc34)\n %250 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc21)\n %251 = "stable_mosaic_gpu.arith.addi"(%184, %250) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\n "stable_mosaic_gpu.nvvm.fence.proxy"() {kind = #nvvm.proxy_kind, space = #nvvm.shared_space} : () -> () loc(#loc35)\n %252 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %253 = "stable_mosaic_gpu.arith.index_cast"(%252) : (index) -> i32 loc(#loc35)\n %254 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %255 = "stable_mosaic_gpu.arith.index_cast"(%254) : (index) -> i32 loc(#loc35)\n %256 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %257 = "stable_mosaic_gpu.arith.index_cast"(%256) : (index) -> i32 loc(#loc35)\n %258 = "stable_mosaic_gpu.arith.muli"(%257, %255) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %259 = "stable_mosaic_gpu.arith.addi"(%253, %258) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %260 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %261 = "stable_mosaic_gpu.arith.index_cast"(%260) : (index) -> i32 loc(#loc35)\n %262 = "stable_mosaic_gpu.arith.muli"(%255, %261) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %263 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %264 = "stable_mosaic_gpu.arith.index_cast"(%263) : (index) -> i32 loc(#loc35)\n %265 = "stable_mosaic_gpu.arith.muli"(%264, %262) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %266 = "stable_mosaic_gpu.arith.addi"(%259, %265) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %267 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %268 = "stable_mosaic_gpu.arith.index_cast"(%267) : (index) -> i32 loc(#loc35)\n %269 = "stable_mosaic_gpu.arith.muli"(%262, %268) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %270 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc35)\n %271 = "stable_mosaic_gpu.arith.shrui"(%266, %270) : (i32, i32) -> i32 loc(#loc35)\n %272 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc35)\n %273 = "stable_mosaic_gpu.arith.addi"(%271, %272) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %274 = "stable_mosaic_gpu.llvm.inline_asm"(%273) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects} : (i32) -> !llvm.void loc(#loc35)\n %275 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc22)\n %276 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %277 = "stable_mosaic_gpu.arith.remsi"(%275, %276) : (i32, i32) -> i32 loc(#loc22)\n %278 = "stable_mosaic_gpu.arith.index_cast"(%277) : (i32) -> index loc(#loc18)\n %279 = "stable_mosaic_gpu.memref.subview"(%164, %278) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, 
#gpu.address_space> loc(#loc18)\n %280 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc18)\n %281 = "stable_mosaic_gpu.arith.index_cast"(%280) : (index) -> i32 loc(#loc18)\n %282 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%279) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc18)\n %283 = "stable_mosaic_gpu.llvm.extractvalue"(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc18)\n %284 = "stable_mosaic_gpu.llvm.extractvalue"(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc18)\n %285 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc18)\n %286 = "stable_mosaic_gpu.llvm.mul"(%284, %285) : (i64, i64) -> i64 loc(#loc18)\n %287 = "stable_mosaic_gpu.llvm.ptrtoint"(%283) : (!llvm.ptr<3>) -> i64 loc(#loc18)\n %288 = "stable_mosaic_gpu.llvm.add"(%287, %286) : (i64, i64) -> i64 loc(#loc18)\n %289 = "stable_mosaic_gpu.llvm.inttoptr"(%288) : (i64) -> !llvm.ptr<3> loc(#loc18)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.global.shared.cta"(%82, %289, %281, %155) {operandSegmentSizes = array} : (!llvm.ptr, !llvm.ptr<3>, i32, i1) -> () loc(#loc18)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group"() : () -> () loc(#loc28)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.wait_group"() {group = 0 : i32} : () -> () loc(#loc36)\n %290 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %291 = "stable_mosaic_gpu.arith.index_cast"(%290) : (index) -> i32 loc(#loc36)\n %292 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %293 = "stable_mosaic_gpu.arith.index_cast"(%292) : (index) -> i32 loc(#loc36)\n %294 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %295 = "stable_mosaic_gpu.arith.index_cast"(%294) : (index) -> i32 loc(#loc36)\n %296 = "stable_mosaic_gpu.arith.muli"(%295, %293) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %297 = "stable_mosaic_gpu.arith.addi"(%291, %296) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %298 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %299 = "stable_mosaic_gpu.arith.index_cast"(%298) : (index) -> i32 loc(#loc36)\n %300 = "stable_mosaic_gpu.arith.muli"(%293, %299) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %301 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %302 = "stable_mosaic_gpu.arith.index_cast"(%301) : (index) -> i32 loc(#loc36)\n %303 = "stable_mosaic_gpu.arith.muli"(%302, %300) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %304 = "stable_mosaic_gpu.arith.addi"(%297, %303) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %305 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %306 = "stable_mosaic_gpu.arith.index_cast"(%305) : (index) -> i32 loc(#loc36)\n %307 = "stable_mosaic_gpu.arith.muli"(%300, %306) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %308 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc36)\n %309 = "stable_mosaic_gpu.arith.shrui"(%304, %308) : (i32, i32) -> i32 loc(#loc36)\n %310 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc36)\n %311 = 
"stable_mosaic_gpu.arith.addi"(%309, %310) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %312 = "stable_mosaic_gpu.llvm.inline_asm"(%311) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects} : (i32) -> !llvm.void loc(#loc36)\n "stable_mosaic_gpu.gpu.terminator"() : () -> () loc(#loc17)\n }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc17)\n "stable_mosaic_gpu.func.return"() : () -> () loc(#loc17)\n }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = "mosaic_gpu_body"} : () -> () loc(#loc17)\n}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc17)\n#loc13 = loc("-":141:7)\n#loc14 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":78:19)\n#loc15 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":78:6)\n#loc16 = loc("-":279:7)\n#loc18 = loc("/copy_smem_to_gmem"(#loc))\n#loc19 = loc("/copy_gmem_to_smem"(#loc))\n#loc20 = loc("/run_scoped"(#loc))\n#loc21 = loc("/scan"(#loc))\n#loc22 = loc("/rem"(#loc))\n#loc23 = loc("/barrier_wait"(#loc))\n#loc24 = loc("/jaxpr_call"(#loc))\n#loc25 = loc("/get"(#loc14))\n#loc26 = loc("/add"(#loc14))\n#loc27 = loc("/swap"(#loc15))\n#loc28 = loc("/commit_group"(#loc))\n#loc29 = loc("/add"(#loc))\n#loc30 = loc("/ge"(#loc))\n#loc31 = loc("/lt"(#loc))\n#loc32 = loc("/and"(#loc))\n#loc33 = loc("/convert_element_type"(#loc))\n#loc34 = loc("/cond"(#loc))\n#loc35 = loc("/commit_smem"(#loc))\n#loc36 = loc("/wait_smem_to_gmem"(#loc))\n\x00mosaic_gpu\x00\x08\'\x07\x05\x1f\x01\x0b!%\'/1\x11357\x1d9\x1f\x1d\x1f', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index a5a3c984a0c8..1fa8cbaa765f 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -169,6 +169,7 @@ def test_custom_call_coverage(self): covered_targets = covered_targets.union({ "tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py "tpu_custom_call", # tested separately + "mosaic_gpu", # tested in pallas/export_back_compat_pallas_test.py "__gpu$xla.gpu.triton", # tested in pallas/export_back_compat_pallas_test.py # The following require ROCm to test "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index fa98c0af4be8..92ff732e2200 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -237,6 +237,7 @@ jax_multiplatform_test( "//jax:internal_export_back_compat_test_util", "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_mosaic_gpu", # build_cleaner: keep "//jax:pallas_tpu_ops", # build_cleaner: keep ], ) diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index addf14d73792..c37bbbfec2a0 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -17,6 +17,7 @@ update these tests. 
""" +import functools import math import unittest @@ -25,6 +26,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_gpu_add_one from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_matmul from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_semaphore_dma from jax._src.internal_test_util.export_back_compat_test_data.pallas import triton_add_one @@ -43,9 +45,6 @@ class CompatTest(bctu.CompatTestBase): def setUp(self): if jax.config.x64_enabled: self.skipTest("Only works in 32-bit") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() @unittest.skip("This test is checking backwards compatibility " @@ -53,6 +52,9 @@ def setUp(self): "compatibility for its IR, and we have since removed " "the corresponding custom call from the guaranteed stable list.") def test_triton_add_one(self): + if not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on GPUs with capability >= sm80") + def func(x): def add_one(x_ref, o_ref): o_ref[0] = x_ref[0] + 1 @@ -65,6 +67,22 @@ def add_one(x_ref, o_ref): self.run_one_test(func, data) + def test_mosaic_gpu_add_one(self): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on GPUs with capability >= sm90") + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((128 * 2,), jnp.float32), + grid=2, + backend="mosaic_gpu", + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + data = self.load_testdata(mosaic_gpu_add_one.data_2025_04_22) + self.run_one_test(add_one, data) + @jax.default_matmul_precision("bfloat16") def test_mosaic_matmul(self): # TODO(apaszke): Remove after 12 weeks have passed. diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e13165fd076b..ad80e3cf4c0a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2817,12 +2817,7 @@ def test_export_succeeds(self): def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] + 1.0 - _ = export.export( - kernel, - disabled_checks=[ - export.DisabledSafetyCheck.custom_call("mosaic_gpu"), - ], - )(out_shape) + _ = export.export(kernel)(out_shape) class ExamplesTest(PallasTest): From 8eb80b953469af4a61e6dc619288c3a70b478fd0 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 23 Apr 2025 05:31:56 -0700 Subject: [PATCH 0757/1769] Include MGPU in grid/blockspec tutorial + use proper note/warning formatting There are no restrictions on window sizes in that backend. Also, replace Markdown quotations with Note/Warning blocks in the GPU reference for added clarity. PiperOrigin-RevId: 750555285 --- docs/pallas/gpu/reference.md | 78 ++++++++++++++++++++--------------- docs/pallas/grid_blockspec.md | 10 ++++- 2 files changed, 53 insertions(+), 35 deletions(-) diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md index d67f80fb0ebe..0db31e11b459 100644 --- a/docs/pallas/gpu/reference.md +++ b/docs/pallas/gpu/reference.md @@ -34,10 +34,12 @@ Going further, recent CUDA versions also outline the concept of a _warpgroup_, w from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions that utilize the whole SM. 
-> A GPU can be viewed in many different ways and in here we want to focus on a slightly
-  simplified model that is very TensorCore-centric. This should help you navigate the
-  complexities of writing kernels involving the TensorCore, but keep in mind that the
-  real picture is more complicated.
+```{note}
+A GPU can be viewed in many different ways and here we want to focus on a slightly
+simplified model that is very TensorCore-centric. This should help you navigate the
+complexities of writing kernels involving the TensorCore, but keep in mind that the
+real picture is more complicated.
+```

 For our purposes, TensorCore operations have grown so big that it no longer makes much
 sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores
@@ -49,14 +51,18 @@ discuss later). One notable addition here is that we still allow you to co-sched
 of those Pallas-level threads on the same SM so that they can cooperate and communicate through
 shared memory (we realize that by putting them in the same CUDA block).

-> From now on, whenever we say "thread", we refer to the Pallas thread, not a CUDA thread/lane.
+```{note}
+From now on, whenever we say "thread", we refer to the Pallas thread, not a CUDA thread/lane.
+```

-> This is very similar to a programming model popularized by [Triton](https://triton-lang.org/),
-  but as you will see there are a few differences. Mosaic GPU tends to be more low level,
-  which usually means you will have to put in more work, but it also puts you more in control.
-  In our view both approaches have their merits and we encourage you to pick the backend that
-  suits your needs the best! Pallas supports and will continue to support Triton as an alternative
-  GPU backend.
+```{note}
+This is very similar to a programming model popularized by [Triton](https://triton-lang.org/),
+but as you will see there are a few differences. Mosaic GPU tends to be lower level,
+which usually means you will have to put in more work, but it also puts you more in control.
+In our view both approaches have their merits and we encourage you to pick the backend that
+suits your needs the best! Pallas supports and will continue to support Triton as an alternative
+GPU backend.
+```

 ### In-order execution & using multiple hardware units

@@ -195,9 +201,11 @@ These concepts are crucial for performance, especially when interacting
 with specialized hardware units like TensorCores or optimizing memory access
 patterns.

-> We are working on a mode that will deal with assigning layouts and transforms fully
-  automatically (although with way to provide hints and more control). The APIs listed
-  below will likely continue to function, but will become optional.
+```{note}
+We are working on a mode that will deal with assigning layouts and transforms fully
+automatically (although with ways to provide hints and more control). The APIs listed
+below will likely continue to function, but will become optional.
+```

 ### Memory reference transforms

@@ -261,7 +269,6 @@ The available transforms are:
   This is primarily useful in that it lets you change the layout during the
   GMEM-SMEM copies (just keep in mind that changing the minormost/last dimension is
   not supported by the hardware).
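As a rough usage sketch (not taken from this patch; the block spec arguments, tile sizes, and swizzle value below are illustrative assumptions rather than a definitive recipe), transforms are attached to a ref through its block spec:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def add_one(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((128, 32), jnp.float32)
y = pl.pallas_call(
    add_one,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    in_specs=[plgpu.GPUBlockSpec(
        x.shape, lambda: (0, 0),
        # Tile the ref into (8, 32) tiles and apply a 128-byte swizzle: with
        # f32 elements the minor tile dimension spans 32 * 4 = 128 bytes.
        transforms=(plgpu.TilingTransform((8, 32)),
                    plgpu.SwizzleTransform(128)),
    )],
    out_specs=plgpu.GPUBlockSpec(x.shape, lambda: (0, 0)),
    backend="mosaic_gpu",
)(x)
```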
- ### Array layouts There are a few useful layouts we have defined for you so far: @@ -502,27 +509,32 @@ plgpu.barrier_wait(barrier) There are three operations that can complete a barrier: -> It is critical to ensure that the synchronization scheme makes it impossible for two - barrier completions to happen without a call to `plgpu.barrier_wait` in between them. - For example, if you use `Barrier`s to synchronize two producer/consumer threads, you - need to perform barrier synchronization going both ways to introduce "backpressure" - that will stop one thread from arriving twice before the other one had a chance to await. - Failing to satisfy this will corrupt the data structure and can cause surprising failures - (including CUDA runtime errors). See below for an example of a valid program with two threads. - -> Another critical restriction is that the number of barrier completions must equal the - number of barrier waits throughout the barrier's lifetime. It is not allowed to end a scoped - allocation of a barrier when it has an unawaited completion. Otherwise, when it is - reused by the compiler, leaving it in this state can cause problems downstream. +```{warning} +It is critical to ensure that the synchronization scheme makes it impossible for two +barrier completions to happen without a call to `plgpu.barrier_wait` in between them. +For example, if you use `Barrier`s to synchronize two producer/consumer threads, you +need to perform barrier synchronization going both ways to introduce "backpressure" +that will stop one thread from arriving twice before the other one had a chance to await. +Failing to satisfy this will corrupt the data structure and can cause surprising failures +(including CUDA runtime errors). See below for an example of a valid program with two threads. +``` -> Finally, it is crucial to ensure that each thread that ever waits on a `Barrier` - takes part in all `wait` operations on it. It is not allowed to e.g. await every - other completion of a barrier from one thread, and all other completions from another - one. Doing so will lead to deadlocks. To recap: when a `Barrier` is used to wait in - some thread, it must observe every single completion of that barrier (by waiting on it). +```{warning} +Another critical restriction is that the number of barrier completions must equal the +number of barrier waits throughout the barrier's lifetime. It is not allowed to end a scoped +allocation of a barrier when it has an unawaited completion. Otherwise, when it is +reused by the compiler, leaving it in this state can cause problems downstream. +``` +```{warning} +Finally, it is crucial to ensure that each thread that ever waits on a `Barrier` +takes part in all `wait` operations on it. It is not allowed to e.g. await every +other completion of a barrier from one thread, and all other completions from another +one. Doing so will lead to deadlocks. To recap: when a `Barrier` is used to wait in +some thread, it must observe every single completion of that barrier (by waiting on it). - Note that the `Barrier` can receive arrivals from any source, without restrictions. +Note that the `Barrier` can receive arrivals from any source, without restrictions. +``` #### Asynchronous GMEM-to-SMEM copies diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index d74a91c96f54..d360e3e660b5 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -80,8 +80,14 @@ Not all block shapes are supported. 
must be equal to the array dimension, or be divisible by `128 * (32 / bitwidth(dtype))`.

-  * On GPU, the size of the blocks themselves is not restricted, but each
-    operation must operate on arrays whose size is a power of 2.
+  * On GPU, when using the Mosaic GPU backend, the size of the blocks is
+    unrestricted. However, due to hardware limitations, the size of the minormost
+    array dimension must be such that it is a multiple of 16 bytes. For example,
+    it must be a multiple of 8 if the input is `jnp.float16`.
+
+  * On GPU, when using the Triton backend, the size of the blocks themselves is
+    unrestricted, but each operation (including a load or store) must operate
+    on arrays whose size is a power of 2.
 ```

 If the block shape does not evenly divide the overall shape then the

From d5de9e8cae0d1e20d94187478ec96dc2820f5c34 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Wed, 23 Apr 2025 06:32:09 -0700
Subject: [PATCH 0758/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/99b7c3bf05c3877c70ad587439b7481889810564.

PiperOrigin-RevId: 750569770
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 6135c8c1da65..912ea661a8b9 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "3e18a59cd822a8426db29c8c36e912e7b2dbaae4"
-XLA_SHA256 = "b00d2e514d5a7bb7276ab7a82d2b1c380c27709ac1b91f92c8b3ccfb87c285c0"
+XLA_COMMIT = "99b7c3bf05c3877c70ad587439b7481889810564"
+XLA_SHA256 = "148505b7fbbab60879608b43e7d038a7e8c97ddd6e2c6f45c11aca37e348b6a9"

 def repo():
     tf_http_archive(

From 92453fbc2a4758c9dd28ad0cfff7b9ba6f941b3b Mon Sep 17 00:00:00 2001
From: vfdev
Date: Wed, 23 Apr 2025 15:44:51 +0200
Subject: [PATCH 0759/1769] TSAN FT CI, install cython from nightly wheels

---
 .github/workflows/tsan.yaml | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml
index ef1cf99d6d74..8336d86120d1 100644
--- a/.github/workflows/tsan.yaml
+++ b/.github/workflows/tsan.yaml
@@ -130,8 +130,10 @@ jobs:
 export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH

 python3 -m pip install uv~=0.5.30
-          # Make sure to install a compatible Cython version (master branch is best for now)
-          NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython
+
+          # Install Cython same as in numpy CI: https://github.com/numpy/numpy/blob/9ead596ce4f8df0189f9ba3d54937e22e2785a5e/.github/workflows/linux.yml#L75C21-L75C96
+          python3 -m uv pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple cython
+
 python3 -m uv pip install -r requirements/build_requirements.txt

 CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v .
--no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized
@@ -197,8 +199,10 @@ jobs:
 export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH

 python3 -m pip install uv~=0.5.30
-          # Make sure to install a compatible Cython version (master branch is best for now)
-          NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython
+
+          # Install Cython same as in numpy CI: https://github.com/numpy/numpy/blob/9ead596ce4f8df0189f9ba3d54937e22e2785a5e/.github/workflows/linux.yml#L75C21-L75C96
+          python3 -m uv pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple cython
+
 python3 -m uv pip install -U --pre numpy --extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/

 python3 -m uv pip install pythran pybind11 meson-python ninja

From 5227daa75ad066cd8105424288b35faf9e562578 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Wed, 23 Apr 2025 06:54:30 -0700
Subject: [PATCH 0760/1769] Fix typo "dataclasss".

PiperOrigin-RevId: 750575644
---
 jaxlib/xla/pytree.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc
index 9359165b19dd..08feeec94f1c 100644
--- a/jaxlib/xla/pytree.cc
+++ b/jaxlib/xla/pytree.cc
@@ -1053,7 +1053,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const {
 auto* registration = registry_->Lookup(object.type());
 if (registration != node.custom) {
 throw std::invalid_argument(absl::StrFormat(
-          "Custom dataclasss node type mismatch: expected type: %s, value: "
+          "Custom dataclass node type mismatch: expected type: %s, value: "
 "%s.",
 nb::cast<std::string_view>(nb::repr(node.custom->type)),
 nb::cast<std::string_view>(nb::repr(std::move(object)))));

From c483a86c91a4fde7bd53b65c6b8e0772d314dfd4 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 23 Apr 2025 08:21:52 -0700
Subject: [PATCH 0761/1769] Fix some typos.

PiperOrigin-RevId: 750600271
---
 jaxlib/xla/pmap_lib.cc | 9 ++++-----
 jaxlib/xla/sharded_device_array.h | 6 +++---
 jaxlib/xla/xla_client_test.py | 2 +-
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc
index 94a79e8ba0b9..01b301d008af 100644
--- a/jaxlib/xla/pmap_lib.cc
+++ b/jaxlib/xla/pmap_lib.cc
@@ -893,8 +893,7 @@ void BuildPmapSubmodule(nb::module_& m) {
 [](const NoSharding& self) { return nb::make_tuple(); })
 .def("__setstate__",
 [](NoSharding& self, nb::tuple t) { new (&self) NoSharding(); })
-      .def("__repr__",
-           [](const NoSharding& chuncked) { return "NoSharding()"; })
+      .def("__repr__", [](const NoSharding& self) { return "NoSharding()"; })
 .def("__eq__",
 [](const NoSharding& self, nb::object obj) {
 return nb::isinstance<NoSharding>(obj);
@@ -914,9 +913,9 @@ void BuildPmapSubmodule(nb::module_& m) {
 })
 .def_ro("chunks", &Chunked::chunks)
 .def("__repr__",
-           [](const Chunked& chuncked) {
-             return absl::StrCat("Chunked(",
-                                 absl::StrJoin(chuncked.chunks, ","), ")");
+           [](const Chunked& self) {
+             return absl::StrCat("Chunked(", absl::StrJoin(self.chunks, ","),
+                                 ")");
 })
 .def("__eq__", [](const Chunked& self, nb::object other) {
 if (!nb::isinstance<Chunked>(other)) {
 diff --git a/jaxlib/xla/sharded_device_array.h b/jaxlib/xla/sharded_device_array.h
index 6e014789a289..b0b5597f9d41 100644
--- a/jaxlib/xla/sharded_device_array.h
+++ b/jaxlib/xla/sharded_device_array.h
@@ -34,7 +34,7 @@ namespace jax {
 // High level introduction.
 //
 // pmap and other parallel computation functions distribute some computation on
-// several devices. On December 2020, the devices mesh (i.e. N-dimentional array
+// several devices. 
As of December 2020, the devices mesh (i.e. N-dimensional array
 // of devices on which we map the computation) is defined by the user.
 //
 // We describe how to shard the inputs, and how to map them to the mesh of devices
@@ -157,7 +157,7 @@ using MeshDimAssignment = std::variant<ShardedAxis, Replicated>;
 //   mesh_mapping = [ShardedAxis(0)]
 //
 // 2. With an input array of shape [6], that we want to chunk into [2, 3]
-//    Assuming an device mesh [3, 4, 2] of devices, we will have:
+//    Assuming a device mesh [3, 4, 2] of devices, we will have:
 //
 //    sharding = [Chunked([2, 3])]
 //    mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)]
@@ -193,7 +193,7 @@ class ShardingSpec {
 private:
 // `sharding` specifies how the array is supposed to get partitioned into
-  // chunks. Its length matchs the rank of the array. See the docstring
+  // chunks. Its length matches the rank of the array. See the docstring
 // of `AvalDimSharding` for the supported partitioning schemes.
 std::vector<AvalDimSharding> sharding_;
 // `mesh_mapping` describes an assignment of the array chunks created by
 diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py
index 15c307145b29..7779b314ce0e 100644
--- a/jaxlib/xla/xla_client_test.py
+++ b/jaxlib/xla/xla_client_test.py
@@ -141,7 +141,7 @@ def _Aligned(x, alignment=_XLA_CPU_MAX_ALIGNMENT):
 # Return an unaligned copy of `x`. The result buffer's memory address is
 # guaranteed to not be aligned to `alignment`. This function is useful for
-# testing failiures.
+# testing failures.
 def _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT):
 if (x.ctypes.data % alignment) != 0:
 return x

From 0948435faa4110381765a1393dc7634b4332602c Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 23 Apr 2025 08:33:09 -0700
Subject: [PATCH 0762/1769] Fix mypy errors related to jaxlib.

At the moment mypy isn't correctly detecting errors related to jaxlib. In a
future change this will be fixed, and this PR fixes errors that will be
revealed by that change.

PiperOrigin-RevId: 750603531
---
 jax/_src/custom_partitioning_sharding_rule.py | 3 +-
 jax/_src/named_sharding.py | 2 +-
 jax/_src/pallas/mosaic_gpu/core.py | 4 +-
 .../mosaic/gpu/inference_utils.py | 4 +-
 jaxlib/gpu_linalg.py | 9 +-
 jaxlib/gpu_prng.py | 9 +-
 jaxlib/gpu_solver.py | 13 +-
 jaxlib/gpu_sparse.py | 9 +-
 jaxlib/mosaic/python/tpu.py | 5 +-
 jaxlib/xla/xla_extension/__init__.pyi | 155 +++++++++---------
 jaxlib/xla/xla_extension/ops.pyi | 28 ++--
 jaxlib/xla/xla_extension/pmap_lib.pyi | 6 +-
 jaxlib/xla/xla_extension/pytree.pyi | 14 +-
 jaxlib/xla/xla_extension/sdy.pyi | 2 +-
 14 files changed, 141 insertions(+), 122 deletions(-)

diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py
index 5e2e5f4e0479..b5563634352c 100644
--- a/jax/_src/custom_partitioning_sharding_rule.py
+++ b/jax/_src/custom_partitioning_sharding_rule.py
@@ -15,6 +15,7 @@
 """Implements SdyShardingRule."""

 from collections import OrderedDict
+from typing import Union

 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import sdy
@@ -28,7 +29,7 @@
 _BATCHING_DIM_FACTOR_PREFIX = "?"

 # A Jax value in general corresponds to an ir.Type or a tuple of ir.Types.
-IrTypes = ir.Type | tuple[ir.Type, ...]
+IrTypes = Union[ir.Type, tuple[ir.Type, ...]]

 def _check_factor(factor:str):
 """Validates a factor.
diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 2c6741ab4c9a..fba54f438471 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -349,7 +349,7 @@ def named_sharding_to_xla_hlo_sharding( last_tile_dims = [] if replicated_mesh_axes: - axes_by_type = collections.defaultdict(list) + axes_by_type: dict[Any, list[int]] = collections.defaultdict(list) size_by_type = collections.defaultdict(lambda: 1) # type: ignore assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys())) for i, size in replicated_mesh_axes: diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 5015ebdb7e1f..e775d9f3cc64 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -22,7 +22,7 @@ import dataclasses import enum import itertools as it -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, Union import jax from jax._src import core as jax_core @@ -240,7 +240,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: shape=self.to_gpu_transform().transform_shape(aval.shape) ) -Index = mgpu.DynamicSlice | slice | int | ir.Value +Index = Union[mgpu.DynamicSlice, slice, int, ir.Value] @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py index 73ce23c427cd..ff3dbf665681 100644 --- a/jax/experimental/mosaic/gpu/inference_utils.py +++ b/jax/experimental/mosaic/gpu/inference_utils.py @@ -18,11 +18,11 @@ import enum from functools import partial import itertools -from typing import cast +from typing import cast, Union from jax._src.lib.mlir import ir -MlirOperation = ir.Operation | ir.OpView +MlirOperation = Union[ir.Operation, ir.OpView] def in_layouts(op: MlirOperation) -> Sequence[ir.Attribute]: """Returns the in_layouts attribute of the given operation. 
diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index c747c0abbe8b..967dacdbacff 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -19,12 +19,17 @@ _cuda_linalg = import_from_plugin("cuda", "_linalg") _hip_linalg = import_from_plugin("rocm", "_linalg") + def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]: if module: registrations[platform].extend( - (*i, 1) for i in module.registrations().items()) + (*i, 1) for i in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index b112534c0575..17da46de699f 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -19,11 +19,16 @@ _cuda_prng = import_from_plugin("cuda", "_prng") _hip_prng = import_from_plugin("rocm", "_prng") + def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]: if module: registrations[platform].extend( (name, value, int(name.endswith("_ffi"))) - for name, value in module.registrations().items()) + for name, value in module.registrations().items() + ) return registrations diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index c846c63e2ff8..cdcd2b6199f9 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -24,7 +24,10 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: if module: registrations[platform].extend( @@ -34,17 +37,17 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]: if module: registrations[platform].extend( - (*i, 1) for i in module.registrations().items()) + (*i, 1) for i in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type def batch_partitionable_targets() -> list[str]: - targets = [] + targets: list[str] = [] for module in [_cusolver, _hipsolver]: if module: targets.extend( - name for name in module.registrations() - if name.endswith("_ffi") + name for name in module.registrations() if name.endswith("_ffi") ) for module in [_cuhybrid, _hiphybrid]: if module: diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index cc2b2ad08e55..af03eb6e6a8a 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -22,11 +22,16 @@ cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) + def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: if module: registrations[platform].extend( (name, value, int(name.endswith("_ffi"))) - for name, value in module.registrations().items()) + for name, value in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type diff --git a/jaxlib/mosaic/python/tpu.py 
b/jaxlib/mosaic/python/tpu.py index a1c7f79ba769..8083b9759f1b 100644 --- a/jaxlib/mosaic/python/tpu.py +++ b/jaxlib/mosaic/python/tpu.py @@ -19,6 +19,7 @@ # pylint: disable=g-bad-import-order +from . import _tpu_gen from ._tpu_gen import * # pylint: disable=wildcard-import from ._tpu_gen import _Dialect from jaxlib.mlir._mlir_libs._tpu_ext import * # pylint: disable=wildcard-import @@ -32,7 +33,7 @@ @_cext.register_operation(_Dialect, replace=True) -class TraceOp(TraceOp): # noqa: F405 +class TraceOp(_tpu_gen.TraceOp): # noqa: F405 """An extension to the automatically generated TraceOp bindings.""" def __init__(self, results, message, level, *, loc=None, ip=None): @@ -45,7 +46,7 @@ def body(self): @_cext.register_operation(_Dialect, replace=True) -class RegionOp(RegionOp): # noqa: F405 +class RegionOp(_tpu_gen.RegionOp): # noqa: F405 """An extension to the automatically generated RegionOp bindings.""" def __init__(self, results, *, loc=None, ip=None): diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi index 491ba4a2e50a..537c18c7adf1 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -29,7 +29,6 @@ from typing import ( Optional, Sequence, Tuple, - Type, TypeVar, Union, overload, @@ -63,45 +62,45 @@ class XlaRuntimeError(RuntimeError): pass class PrimitiveType(enum.IntEnum): - PRIMITIVE_TYPE_INVALID: PrimitiveType - PRED: PrimitiveType - S2: PrimitiveType - S4: PrimitiveType - S8: PrimitiveType - S16: PrimitiveType - S32: PrimitiveType - S64: PrimitiveType - U2: PrimitiveType - U4: PrimitiveType - U8: PrimitiveType - U16: PrimitiveType - U32: PrimitiveType - U64: PrimitiveType - F4E2M1FN: PrimitiveType - F8E3M4: PrimitiveType - F8E4M3: PrimitiveType - F8E4M3FN: PrimitiveType - F8E4M3B11FNUZ: PrimitiveType - F8E4M3FNUZ: PrimitiveType - F8E5M2: PrimitiveType - F8E5M2FNUZ: PrimitiveType - F8E8M0FNU: PrimitiveType - BF16: PrimitiveType - F16: PrimitiveType - F32: PrimitiveType - F64: PrimitiveType - C64: PrimitiveType - C128: PrimitiveType - TUPLE: PrimitiveType - OPAQUE_TYPE: PrimitiveType - TOKEN: PrimitiveType + PRIMITIVE_TYPE_INVALID = ... + PRED = ... + S2 = ... + S4 = ... + S8 = ... + S16 = ... + S32 = ... + S64 = ... + U2 = ... + U4 = ... + U8 = ... + U16 = ... + U32 = ... + U64 = ... + F4E2M1FN = ... + F8E3M4 = ... + F8E4M3 = ... + F8E4M3FN = ... + F8E4M3B11FNUZ = ... + F8E4M3FNUZ = ... + F8E5M2 = ... + F8E5M2FNUZ = ... + F8E8M0FNU = ... + BF16 = ... + F16 = ... + F32 = ... + F64 = ... + C64 = ... + C128 = ... + TUPLE = ... + OPAQUE_TYPE = ... + TOKEN = ... # === BEGIN xla_compiler.cc class ArrayCopySemantics(enum.IntEnum): - ALWAYS_COPY: ArrayCopySemantics - REUSE_INPUT: ArrayCopySemantics - DONATE_INPUT: ArrayCopySemantics + ALWAYS_COPY = ... + REUSE_INPUT = ... + DONATE_INPUT = ... class Layout: @overload @@ -114,8 +113,8 @@ class Layout: def tiling(self) -> Sequence[Tuple[int, ...]]: ... def element_size_in_bits(self) -> int: ... def to_string(self) -> str: ... - def __eq__(self, other: Layout) -> bool: ... - def __ne__(self, other: Layout) -> bool: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Shape: @@ -150,8 +149,8 @@ class Shape: def tuple_shapes(self) -> List[Shape]: ... def leaf_count(self) -> int: ... def with_major_to_minor_layout_if_absent(self) -> Shape: ... - def __eq__(self, other: Shape) -> bool: ... - def __ne__(self, other: Shape) -> bool: ... + def __eq__(self, other: Any) -> bool: ... 
+ def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... @@ -162,14 +161,14 @@ class ProgramShape: def __repr__(self) -> str: ... class ShapeIndex: - def __init__(self, indices: List[int]) -> ShapeIndex: ... - def __eq__(self, other: Shape) -> bool: ... - def __ne__(self, other: Shape) -> bool: ... + def __init__(self, indices: List[int]) -> None: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... class Literal: - def __init__(self, shape: Shape) -> Literal: ... + def __init__(self, shape: Shape) -> None: ... def __repr__(self) -> str: ... def __array__( self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None @@ -306,9 +305,9 @@ def register_custom_call_as_batch_partitionable( def register_custom_type_id(type_name: str, type_id: Any) -> None: ... class AutotuneCacheMode(enum.IntEnum): - UNSPECIFIED: AutotuneCacheMode - UPDATE: AutotuneCacheMode - READ: AutotuneCacheMode + UNSPECIFIED = ... + UPDATE = ... + READ = ... class DebugOptions: def __repr__(self) -> str: ... @@ -383,15 +382,15 @@ class ExecutableBuildOptions: def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ... class PrecisionConfig_Precision(enum.IntEnum): - DEFAULT: int - HIGH: int - HIGHEST: int + DEFAULT = ... + HIGH = ... + HIGHEST = ... class ResultAccuracy_Mode(enum.IntEnum): - DEFAULT: int - HIGHEST: int - TOLERANCE: int + DEFAULT = ... + HIGHEST = ... + TOLERANCE = ... class ResultAccuracy: mode: ResultAccuracy_Mode @@ -400,22 +399,22 @@ class ResultAccuracy: ulps: int class OpSharding_Type(enum.IntEnum): - REPLICATED: int - MAXIMAL: int - TUPLE: int - OTHER: int - MANUAL: int - UNKNOWN: int + REPLICATED = ... + MAXIMAL = ... + TUPLE = ... + OTHER = ... + MANUAL = ... + UNKNOWN = ... class OpSharding_ShardGroupType(enum.IntEnum): - AS: int - LIKE: int + AS = ... + LIKE = ... class OpSharding: Type: typing.Type[OpSharding_Type] type: OpSharding_Type replicate_on_last_tile_dim: bool - last_tile_dims: Sequence[Type] + last_tile_dims: Sequence[OpSharding_Type] tile_assignment_dimensions: Sequence[int] tile_assignment_devices: Sequence[int] iota_reshape_dims: Sequence[int] @@ -443,7 +442,7 @@ class HloSharding: dims: Sequence[int], reshape_dims: Sequence[int], transpose_perm: Sequence[int], - subgroup_types: Sequence[OpSharding.Type], + subgroup_types: Sequence[OpSharding_Type], ) -> HloSharding: ... @staticmethod def replicate() -> HloSharding: ... @@ -454,8 +453,8 @@ class HloSharding: @staticmethod def subgroup_with_device_ordering( tile_assignment: np.ndarray, - subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... - def __eq__(self, other: HloSharding) -> bool: ... + subgroup_types: Sequence[OpSharding_Type]) -> HloSharding: ... + def __eq__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... def tile(self, shape: Shape) -> Shape: ... @@ -469,15 +468,15 @@ class HloSharding: def num_dimensions(self) -> int: ... def tile_assignment_dimensions(self) -> Sequence[int]: ... def tile_assignment_devices(self) -> Sequence[int]: ... - def subgroup_types(self) -> Sequence[OpSharding.Type]: ... + def subgroup_types(self) -> Sequence[OpSharding_Type]: ... def replicate_on_last_tile_dim(self) -> bool: ... def to_proto(self) -> OpSharding: ... class FftType(enum.IntEnum): - FFT: FftType - IFFT: FftType - RFFT: FftType - IRFFT: FftType + FFT = ... + IFFT = ... + RFFT = ... 
+ IRFFT = ... # === END xla_compiler.cc @@ -511,7 +510,7 @@ class Memory: class PjRtLayout: def __str__(self) -> str: ... - def __eq__(self, other: PjRtLayout) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __getstate__(self) -> Any: ... def __setstate__(self, _: Any): ... @@ -519,10 +518,10 @@ class PjRtLayout: class GpuAllocatorConfig: class Kind(enum.IntEnum): - DEFAULT: int - PLATFORM: int - BFC: int - CUDA_ASYNC: int + DEFAULT = ... + PLATFORM = ... + BFC = ... + CUDA_ASYNC = ... def __init__( self, @@ -533,9 +532,9 @@ class GpuAllocatorConfig: ) -> None: ... class HostBufferSemantics(enum.IntEnum): - IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics - IMMUTABLE_UNTIL_TRANSFER_COMPLETES: HostBufferSemantics - ZERO_COPY: HostBufferSemantics + IMMUTABLE_ONLY_DURING_CALL = ... + IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ... + ZERO_COPY = ... class Client: platform: str diff --git a/jaxlib/xla/xla_extension/ops.pyi b/jaxlib/xla/xla_extension/ops.pyi index ff55de3a5cdc..f76ff1c2002c 100644 --- a/jaxlib/xla/xla_extension/ops.pyi +++ b/jaxlib/xla/xla_extension/ops.pyi @@ -39,28 +39,28 @@ _ReplicaGroup = Any _ScatterDimensionNumbers = Any class TriangularSolveOptions_Transpose(enum.IntEnum): - TRANSPOSE_INVALID: int - NO_TRANSPOSE: int - TRANSPOSE: int - ADJOINT: int + TRANSPOSE_INVALID = ... + NO_TRANSPOSE = ... + TRANSPOSE = ... + ADJOINT = ... class RandomAlgorithm(enum.IntEnum): - RNG_DEFAULT: int - RNG_THREE_FRY: int - RNG_PHILOX: int + RNG_DEFAULT = ... + RNG_THREE_FRY = ... + RNG_PHILOX = ... class CustomCallSchedule(enum.IntEnum): - SCHEDULE_NONE: int - SCHEDULE_LATEST: int - SCHEDULE_EARLIEST: int + SCHEDULE_NONE = ... + SCHEDULE_LATEST = ... + SCHEDULE_EARLIEST = ... # TODO(b/189822916): Remove this enum when all clients are migrated to the # status-returning API. class CustomCallApiVersion(enum.IntEnum): - API_VERSION_ORIGINAL: int - API_VERSION_STATUS_RETURNING: int - API_VERSION_STATUS_RETURNING_UNIFIED: int - API_VERSION_TYPED_FFI: int + API_VERSION_ORIGINAL = ... + API_VERSION_STATUS_RETURNING = ... + API_VERSION_STATUS_RETURNING_UNIFIED = ... + API_VERSION_TYPED_FFI = ... def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ... def AllGather( diff --git a/jaxlib/xla/xla_extension/pmap_lib.pyi b/jaxlib/xla/xla_extension/pmap_lib.pyi index 8733d6c27b21..f862e87c0fcd 100644 --- a/jaxlib/xla/xla_extension/pmap_lib.pyi +++ b/jaxlib/xla/xla_extension/pmap_lib.pyi @@ -45,14 +45,14 @@ class ShardedAxis: def axis(self) -> int: ... def __init__(self, __axis: int) -> None: ... def __repr__(self) -> str: ... - def __eq__(self, __other: ShardedAxis) -> bool: ... + def __eq__(self, __other: Any) -> bool: ... class Replicated: @property def replicas(self) -> int: ... def __init__(self, __replicas: int) -> None: ... def __repr__(self) -> str: ... - def __eq__(self, __other: Replicated) -> bool: ... + def __eq__(self, __other: Any) -> bool: ... class ShardingSpec: def __init__(self, @@ -62,7 +62,7 @@ class ShardingSpec: def sharding(self) -> Tuple[_AvalDimSharding, ...]: ... @property def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ... - def __eq__(self, __other: ShardingSpec) -> bool: ... + def __eq__(self, __other: Any) -> bool: ... def __hash__(self) -> int: ... 
_HAS_DYNAMIC_ATTRIBUTES = True diff --git a/jaxlib/xla/xla_extension/pytree.pyi b/jaxlib/xla/xla_extension/pytree.pyi index bfbad5de89d5..157d455e20ae 100644 --- a/jaxlib/xla/xla_extension/pytree.pyi +++ b/jaxlib/xla/xla_extension/pytree.pyi @@ -75,7 +75,7 @@ def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ... class SequenceKey(Hashable): idx: int - __match_args__: tuple = ... + __match_args__: Tuple = ... def __init__(self, idx: int): ... def __str__(self) -> str: ... def __repr__(self) -> str: ... @@ -86,7 +86,7 @@ class SequenceKey(Hashable): class DictKey(Hashable): key: Hashable - __match_args__: tuple = ... + __match_args__: Tuple = ... def __init__(self, key: Hashable): ... def __str__(self) -> str: ... def __repr__(self) -> str: ... @@ -97,7 +97,7 @@ class DictKey(Hashable): class GetAttrKey(Hashable): name: str - __match_args__: tuple = ... + __match_args__: Tuple = ... def __init__(self, name: str): ... def __str__(self) -> str: ... def __repr__(self) -> str: ... @@ -108,7 +108,7 @@ class GetAttrKey(Hashable): class FlattenedIndexKey(Hashable): key: int - __match_args__: tuple = ... + __match_args__: Tuple = ... def __init__(self, key: int): ... def __str__(self) -> str: ... def __repr__(self) -> str: ... @@ -140,8 +140,8 @@ class PyTreeDef: num_leaves: int num_nodes: int def __repr__(self) -> str: ... - def __eq__(self, __other: PyTreeDef) -> bool: ... - def __ne__(self, __other: PyTreeDef) -> bool: ... + def __eq__(self, __other: Any) -> bool: ... + def __ne__(self, __other: Any) -> bool: ... def __hash__(self) -> int: ... def __getstate__(self) -> Any: ... def __setstate__(self, state: Any): ... @@ -153,6 +153,6 @@ class PyTreeDef: _Children = TypeVar("_Children", bound=Iterable[Any]) _KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any]) -_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[_KeyLeafPair]) +_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[Tuple[Any, Any]]) _KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...]) _AuxData = TypeVar("_AuxData", bound=Hashable) diff --git a/jaxlib/xla/xla_extension/sdy.pyi b/jaxlib/xla/xla_extension/sdy.pyi index 34714e5c0219..520f93f11bc6 100644 --- a/jaxlib/xla/xla_extension/sdy.pyi +++ b/jaxlib/xla/xla_extension/sdy.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from mlir import ir +from jaxlib.mlir import ir def sdy_round_trip_export_pipeline( module: ir.module From 5687a523fe467467e4580becb182013d31294bf2 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 23 Apr 2025 09:41:32 -0700 Subject: [PATCH 0763/1769] [Pallas/Mosaic GPU] Allow explicit `smem` aliasing. Users should now be able to instantiate aliased `smem` buffers by using a `RefUnion`, which takes a variadic number of trees of refs as input. `RefUnion` represents a union/coproduct of all its operands: its operand groups alias (overlap in memory), while the elements within each group form a product and are laid out consecutively in memory. The resulting aliased `smem` ref can then be unfolded into a flat structure with an unpacking assignment inside the kernel.
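Before the full example below, a quick sketch may help make the layout rule concrete. This is a hedged illustration only, with invented names and sizes, assuming the 1024-byte SMEM alignment used by the implementation in this change:

```
# Hypothetical layout for RefUnion(A, [B, C]), where A is 2048 bytes and
# B and C are 512 bytes each (names and sizes invented for illustration):
#   group 0:  A -> bytes [0, 2048)
#   group 1:  B -> bytes [0, 512), C -> bytes [1024, 1536)
# Each ref within a group starts at the next 1024-byte-aligned offset, so C
# is padded past B. The two groups alias the same region, which means C
# overlaps the tail of A, and the union is sized to the largest group
# (2048 bytes here).
```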
Here is an example: ``` @functools.partial( pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), in_specs=[pl.BlockSpec((256,))], out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), scratch_shapes=[ plgpu.RefUnion( plgpu.SMEM((256,), jnp.float32), [ plgpu.SMEM((128,), jnp.float32), plgpu.SMEM((128,), jnp.float32), ], ) ], ) def kernel(x_ref, o_ref128, aliased_ref): smem_ref256, _, smem_ref128 = aliased_ref smem_ref256[...] = x_ref[...] + 1 plgpu.commit_smem() plgpu.copy_smem_to_gmem(smem_ref128, o_ref128) ``` PiperOrigin-RevId: 750624152 --- jax/_src/pallas/mosaic_gpu/core.py | 166 ++++++++++++++++++++++++- jax/_src/pallas/mosaic_gpu/lowering.py | 122 +++++++++++++++--- jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 125 +++++++++++++++++++ 4 files changed, 393 insertions(+), 21 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index e775d9f3cc64..6bbf53479d00 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -22,14 +22,15 @@ import dataclasses import enum import itertools as it +import math from typing import Any, ClassVar, Literal, Union import jax from jax._src import core as jax_core from jax._src import dtypes from jax._src import effects -from jax._src import tree_util from jax._src import pretty_printer as pp +from jax._src import tree_util from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as pallas_primitives @@ -47,6 +48,11 @@ DimensionSemantics = Literal["parallel", "sequential"] +# We align all our SMEM allocations to 1024 bytes. TMA and WGMMA are very +# sensitive to alignment and while this is quite conservative, it gets the job +# done. We should make this more refined in the future. +SMEM_ALIGNMENT = 1024 + def is_trivial_index(idx, shape) -> bool: """Checks if the index selects the entire shape.""" @@ -222,6 +228,130 @@ def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: return ref +def align_to(x: int, alignment: int): + if rem := x % alignment: + return x + alignment - rem + return x + + +# A tree of `GPUMemoryRef`s. +_GPUMemoryRefTree = Any + + +def _ref_group_size(refs: _GPUMemoryRefTree) -> int: + if isinstance(refs, GPUMemoryRef): + refs = (refs,) + size = 0 + for ref in jax.tree.leaves(refs): + # Make sure that the start of each ref is aligned with `SMEM_ALIGNMENT`. + size = align_to(size, SMEM_ALIGNMENT) + if jnp.issubdtype(ref.dtype, jnp.integer): + nbits = jnp.iinfo(ref.dtype).bits + elif jnp.issubdtype(ref.dtype, jnp.floating): + nbits = jnp.finfo(ref.dtype).bits + else: + raise NotImplementedError(f"Unsupported dtype: {ref.dtype}") + ref_bits = math.prod(ref.shape) * nbits + if ref_bits % 8: + raise ValueError("Only byte-aligned shapes are supported.") + size += ref_bits // 8 + return size + + +def flatten_ref_union( + ref_union: AbstractRefUnion, +) -> tuple[pallas_core.AbstractMemoryRef | state_types.TransformedRef, ...]: + """Flattens a union of trees of references into a tuple of references. + + This is the moral equivalent of `jax.tree.leaves` for aliased references. 
+ """ + flat_refs = [] + union_bytes = 0 + for ref_group in ref_union.refs: + byte_offset = 0 + for ref in jax.tree.leaves(ref_group): + byte_offset = align_to(byte_offset, SMEM_ALIGNMENT) + assert isinstance(ref, pallas_core.AbstractMemoryRef) or isinstance( + ref, pallas_core.TransformedRef + ) + if not isinstance(ref, pallas_core.TransformedRef): + ref = pallas_core.TransformedRef(ref, transforms=()) + transform = ExtractAliasedRef.from_transformed_ref(ref, byte_offset) + flat_refs.append( + pallas_core.TransformedRef( + ref_union, transforms=(transform, *ref.transforms) + ) + ) + if jnp.issubdtype(ref.dtype, jnp.integer): + nbits = jnp.iinfo(ref.dtype).bits + elif jnp.issubdtype(ref.dtype, jnp.floating): + nbits = jnp.finfo(ref.dtype).bits + else: + raise NotImplementedError(f"Unsupported dtype: {ref.dtype}") + ref_bits = math.prod(ref.shape) * nbits + if ref_bits % 8: + raise ValueError("Only byte-aligned shapes are supported.") + byte_offset += ref_bits // 8 + union_bytes = max(union_bytes, byte_offset) + assert union_bytes == ref_union.shape[0] + return tuple(flat_refs) + + +class AbstractRefUnion(pallas_core.AbstractMemoryRef): + refs: Sequence[_GPUMemoryRefTree] + + def __init__( + self, + aval, + refs: Sequence[_GPUMemoryRefTree], + memory_space, + ): + self.refs = refs + super().__init__(aval, memory_space=memory_space) + + def _iter(self, tracer): + return iter(flatten_ref_union(tracer)) + + def _getitem(self, tracer, index): + return list(iter(tracer))[index] + + def _setitem(self, tracer, index, value): + del tracer, index, value # Unused. + raise ValueError("Ref unions can't be assigned to.") + + +@dataclasses.dataclass(init=False, frozen=True) +class RefUnion(GPUMemoryRef): + """A sequence of trees of refs that are allowed to reuse the same memory. + + One should not make assumptions as to how each ref will map to the underlying + memory region, since arbitrary padding may be applied inbetween different + refs. + + As such, ref unions are only safe to use when the groups of refs that we + intend to alias have disjoint lifetimes (i.e. one should never attempt to read + data using a different ref than the one that was used to write the data). + """ + refs: Sequence[_GPUMemoryRefTree] = () + + def __init__(self, *refs: _GPUMemoryRefTree): + if any(ref.memory_space != SMEM for ref in jax.tree.leaves(refs)): + raise NotImplementedError("Only SMEM refs can be aliased.") + object.__setattr__(self, "refs", refs) + num_bytes = max(map(_ref_group_size, self.refs)) + super().__init__( + shape=(num_bytes,), + dtype=jnp.int8, + memory_space=SMEM, + transforms=(), + ) + + def get_ref_aval(self) -> AbstractRefUnion: + inner_aval = jax.core.ShapedArray(self.shape, self.dtype) + refs_aval = jax.tree.map(lambda ref: ref.get_ref_aval(), self.refs) + return AbstractRefUnion(inner_aval, refs_aval, memory_space=SMEM) + + class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC): @abc.abstractmethod def to_gpu_transform(self) -> mgpu.MemRefTransform: @@ -458,6 +588,40 @@ def unswizzle_ref(ref, swizzle: int) -> pallas_core.TransformedRef: return transform_ref(ref, UnswizzleRef(swizzle)) +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class ExtractAliasedRef(state_types.Transform): + """Bitcasts the underlying ref at the given offset to the given shape and dtype.""" + dtype: dtypes.DType + shape: tuple[int, ...] 
+ offset: int + + @classmethod + def from_transformed_ref( + cls, ref: pallas_core.TransformedRef, byte_offset: int + ): + return cls( + dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset + ) + + def transform_shape(self, shape): + if shape is None: + return None + return self.shape + + def transform_dtype(self, dtype): + del dtype # Unused. + return self.dtype + + def tree_flatten(self): + return (), (self.dtype, self.shape, self.offset) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + @dataclasses.dataclass(frozen=True) class SwizzleTransform(MemoryRefTransform): swizzle: int diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b465d05ba7d4..7c47e6e6309b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -73,17 +73,8 @@ partial = functools.partial SMEM = gpu_core.SMEM -# We align all our SMEM allocations to 1024 bytes. TMA and WGMMA are very -# sensitive to alignment and while this is quite conservative, it gets the job -# done. We should make this more refined in the future. -_SMEM_ALIGNMENT = 1024 WARPGROUP_SIZE = 128 -def _align_to(x: int, alignment: int): - if (rem := x % alignment): - return x + alignment - rem - return x - @dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: @@ -114,13 +105,13 @@ def __post_init__(self): object.__setattr__( self, "smem_scratch_bytes", - _align_to(self.smem_scratch_bytes, _SMEM_ALIGNMENT), + gpu_core.align_to(self.smem_scratch_bytes, gpu_core.SMEM_ALIGNMENT), ) object.__setattr__( self, "tmem_scratch_cols", # TMEM must be allocated in 128x8 chunks. - _align_to(self.tmem_scratch_cols, 8), + gpu_core.align_to(self.tmem_scratch_cols, 8), ) @property @@ -393,7 +384,7 @@ def scratch_view( ) views = [] off = initial_used_bytes = self.smem_used_bytes - assert off % _SMEM_ALIGNMENT == 0 + assert off % gpu_core.SMEM_ALIGNMENT == 0 for s in structs: scratch_ty = ir.MemRefType.get( s.shape, @@ -411,11 +402,12 @@ def scratch_view( view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32)) views.append(view) - off += _align_to( - math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, _SMEM_ALIGNMENT + off += gpu_core.align_to( + math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, + gpu_core.SMEM_ALIGNMENT, ) assert off <= self.smem_requested_bytes, "Ran out of scoped SMEM" - assert off % _SMEM_ALIGNMENT == 0 + assert off % gpu_core.SMEM_ALIGNMENT == 0 self.smem_used_bytes = off yield views @@ -1066,6 +1058,93 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): gpu_dialect.block_dim(gpu_dialect.Dimension(axis)), ) + +def _handle_dtype_bitcast( + ref: ir.Value, src_dtype: ir.Type, dst_dtype: ir.Type +) -> ir.Value: + """Allows bitcasting a SMEM ref from one element type to another. + + Args: + ref: the reference to bitcast. + src_dtype: the source element type. + dst_dtype: the destination element type. + + Returns: + A bitcasted version of `ref` with element type `dst_dtype`. + + Raises: + ValueError: if the source ref is not in SMEM. + """ + if src_dtype == dst_dtype: + return ref + if src_dtype != ir.IntegerType.get_signless(8): + raise NotImplementedError( + "Data type bitcast is only supported from i8 to other types." 
+ ) + ref_ty = ir.MemRefType(ref.type) + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"): + raise ValueError(f"Only workgroup memory is supported but got {ref}.") + if len(ref_ty.shape) != 1: + raise NotImplementedError( + "Data type bitcast is only supported for 1D arrays." + ) + [stride], _ = ref_ty.get_strides_and_offset() + if stride != 1: + raise ValueError( + "Data type bitcast is only supported for contiguous 1D arrays, but got " + f"stride={stride}." + ) + [shape_bytes] = ref_ty.shape + shape_bitwidth = shape_bytes * 8 + target_bitwidth = mgpu_utils.bitwidth(dst_dtype) + + if shape_bitwidth % target_bitwidth: + raise ValueError( + f"Can not bitcast memory region of size {shape_bitwidth} bits to dtype " + f"with {target_bitwidth} bits." + ) + + result_type = ir.MemRefType.get( + shape=(shape_bitwidth // target_bitwidth,), + element_type=dst_dtype, + memory_space=ref_ty.memory_space, + ) + + # Do a memref_ptr/ptr_as_memref roundtrip instead of using `memref.view`, + # which refuses to take in our source ref. This is because `memref.view` only + # works on a super restricted set of `memref`s. E.g., it does not work if an + # offset is specified, which can be the case for our SMEM refs. + smem = mgpu_utils.WORKGROUP_NVPTX_ADDRESS_SPACE + ref = mgpu_utils.memref_ptr(ref, memory_space=smem) + return mgpu_utils.ptr_as_memref(ref, result_type, ptr_memory_space=smem) + + +def _extract_aliased_ref( + ref: ir.Value, transforms: Sequence[gpu_core.Transform] +) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: + match transforms: + case ( + gpu_core.ExtractAliasedRef(dtype, transformed_shape, offset), + *other_transforms, + ): + mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype) + ref_bits = math.prod(transformed_shape) * mgpu_utils.bitwidth(mlir_dtype) + if ref_bits % 8: + raise NotImplementedError("Only byte-aligned bitcasts are supported.") + assert offset % gpu_core.SMEM_ALIGNMENT == 0 + ref_bytes = ref_bits // 8 + ref = mgpu.memref_slice(ref, slice(offset, offset + ref_bytes)) + ref = _handle_dtype_bitcast( + ref, + ir.MemRefType(ref.type).element_type, + mgpu_utils.dtype_to_ir_type(dtype), + ) + ref = mgpu.memref_reshape(ref, transformed_shape) + return ref, tuple(other_transforms) + case _: + return ref, transforms + + def _handle_transforms( ref: ir.Value, transforms: Sequence[gpu_core.Transform], @@ -1073,6 +1152,9 @@ def _handle_transforms( handle_transposes=True, handle_reshapes=True, ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: + # Before we handle other transforms, we resolve any possible leading aliasing
+ ref, transforms = _extract_aliased_ref(ref, transforms) transformed_ref = ref mlir_dtype = ir.MemRefType(ref.type).element_type new_transforms = [] @@ -1207,7 +1289,7 @@ def _swap_lowering_rule( raise TypeError(f"Can only store arrays (got {value}).") if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only store to references (got {x_smem}).") - x_aval = ctx.avals_in[0] + v_aval = ctx.avals_in[1] transforms = jax.tree.unflatten(tree, leaves) transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT x_smem, transforms = _handle_transforms( @@ -1219,7 +1301,7 @@ def _swap_lowering_rule( gpu_core.UntileRef(tiling), *maybe_transpose, ): - if tiling != (8, swizzle // x_aval.dtype.itemsize): + if tiling != (8, swizzle // v_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") if transposed_value != bool(maybe_transpose): @@ -1237,7 +1319,7 @@ def _swap_lowering_rule( old_value = mgpu.FragmentedArray.load_tiled( x_smem, - is_signed=mgpu_utils.is_signed(x_aval.dtype), + is_signed=mgpu_utils.is_signed(v_aval.dtype), swizzle=swizzle, layout=value.layout, ) @@ -1249,14 +1331,14 @@ def _swap_lowering_rule( old_value = mgpu.FragmentedArray.load_untiled( x_smem, layout=value.layout, - is_signed=mgpu_utils.is_signed(x_aval.dtype), + is_signed=mgpu_utils.is_signed(v_aval.dtype), optimized=False, ) value.store_untiled(x_smem, optimized=False) return old_value case _: old_value = mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + x_smem, is_signed=mgpu_utils.is_signed(v_aval.dtype) ) value.store_untiled(x_smem) return old_value diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index d74ffe6eae1b..63d0019fb99f 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -24,6 +24,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import kernel as kernel +from jax._src.pallas.mosaic_gpu.core import RefUnion as RefUnion from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index ad80e3cf4c0a..bbd6f57fd653 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1589,6 +1589,131 @@ def _(): ): kernel() + def test_smem_aliasing_works(self): + self.skip_if_wg_semantics() + + in_shape = (2, 256) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec(in_shape)], + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.RefUnion( + # Note: this test exposes internals that we don't particularly + # want to phold for the sake of testing the functionality of the + # API. It's expected that this test might end up breaking in the + # future, e.g. if we decide to change our alignment requirements + # on SMEM refs---and that's OK. Users should explicitly NOT rely + # on this exact behaviour. + # + # Use a value larger than the number of bytes used for SMEM + # alignment (1024) in order to make sure that the second ref + # in the second group aliases the single ref in the first group. 
+ plgpu.SMEM(in_shape, jnp.float32), + [ + plgpu.SMEM((256,), jnp.bfloat16), + # Add an arbitrary level of nesting to make sure that we + # support PyTrees. + [ + plgpu.SMEM( + (128,), + jnp.float32, + transforms=(plgpu.TilingTransform((64,)),), + ), + ] + ], + ) + ], + ) + def kernel(x_ref, o_ref128, aliased_ref): + smem_ref256, _, smem_ref128 = aliased_ref + # Ensure that extraction via index works the same as unfolding. + self.assertEqual(smem_ref128, aliased_ref[2]) + extract_alias_transform, tile_transform = smem_ref128.transforms + # Ensure that the transforms provided in the scratch shapes have been + # passed correctly. + self.assertIsInstance(extract_alias_transform, gpu_core.ExtractAliasedRef) + self.assertIsInstance(tile_transform, gpu_core.UntileRef) + smem_ref256[...] = x_ref[...] + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref128, o_ref128) + + x = jnp.arange(512).astype(jnp.float32) + np.testing.assert_array_equal( + kernel(x.reshape(in_shape)).reshape((128,)), x[256 : 256 + 128] + 1 + ) + + def test_smem_aliasing_works_with_subbyte_dtypes(self): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.uint4), + in_specs=[pl.BlockSpec((128,))], + out_specs=pl.BlockSpec((256,), memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.RefUnion( + # Note: this test exposes internals that we don't particularly + # want to uphold for the sake of testing the functionality of the + # API. It's expected that this test might end up breaking in the + # future, e.g. if we decide to change our alignment requirements + # on SMEM refs---and that's OK. Users should explicitly NOT rely + # on this exact behaviour. + # + # This allocation scheme is a bit complicated, but serves to + # test that + # 1. Refs are aligned correctly (currently to 1024 bytes); + # 2. (u)int4 references are not allocated more than 1 byte per + # 2 elements. + # The first group of refs serves to create two allocations, each + # aligned to 1024 bytes. The second group serves to create two + # allocations where the first one is exactly 1024 bytes, + # assuming 1 byte per 2 uint4 elements. As a result, if our + # implementation is correct, the second allocation of the second + # group should exactly alias the second allocation of the first + # group. + [ + plgpu.SMEM((128,), jnp.int8), + plgpu.SMEM((128,), jnp.int8), + ], + [plgpu.SMEM((2048,), jnp.uint4), plgpu.SMEM((256,), jnp.uint4)], + ) + ], + ) + def kernel(x_ref, o_refi4, aliased_ref): + _, smem_refi8, _, smem_refi4 = aliased_ref + smem_refi8[...] = x_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_refi4, o_refi4) + + def unpack_i4_as_i8(x): + x = x.reshape((128, 1)) + x_high = x >> 4 + x_low = x & 0xF + return jnp.concatenate([x_low, x_high], axis=-1).reshape((256,)) + + x = jnp.arange(128).astype(jnp.int8) + test_as_i8 = jax.lax.convert_element_type(kernel(x), new_dtype=jnp.int8) + np.testing.assert_array_equal(test_as_i8[:256], unpack_i4_as_i8(x)) + + def test_assigning_to_ref_union_raises(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec((128,))], + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.RefUnion(plgpu.SMEM((128,), jnp.float32))], + ) + def kernel(x_ref, o_ref128, aliased_ref): + aliased_ref[...] = x_ref[...]
+ 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(aliased_ref, o_ref128) + + with self.assertRaisesRegex(ValueError, "can't be assigned to"): + kernel(jnp.arange(128).astype(jnp.float32)) + class PallasCallWGTest( PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From 0f067c54a54f7896c32acd705a83144a5f956b90 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 23 Apr 2025 10:09:30 -0700 Subject: [PATCH 0764/1769] Add a pvary on the `ones` we create in grad, because the `ans` in the issue example is already varying on `x`, and we were dropping that variance when we called into `lax_internal._one`. Fixes https://github.com/jax-ml/jax/issues/28193 PiperOrigin-RevId: 750634402 --- jax/_src/lax/lax.py | 12 ++++++++---- tests/shard_map_test.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8db5b09d913c..6d9176a3e2d8 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8755,15 +8755,19 @@ def _const(example, val): def _zero(x): x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) + out = full_like(x, shape=(), fill_value=0, + sharding=x_aval.sharding.with_spec(P())) + out = core.pvary(out, tuple(x_aval.vma)) + return out _ones: Callable = partial(full_like, fill_value=1) def _one(x): x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=1, - sharding=x_aval.sharding.with_spec(P())) + out = full_like(x, shape=(), fill_value=1, + sharding=x_aval.sharding.with_spec(P())) + out = core.pvary(out, tuple(x_aval.vma)) + return out _twos: Callable = partial(full_like, fill_value=2) _two: Callable = partial(full_like, shape=(), fill_value=2) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 3b63be231d90..09cc79d17b5f 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3007,6 +3007,19 @@ def test_no_mesh_context_error(self): with self.assertRaisesRegex(ValueError, "The context mesh cannot be empty"): jax.shard_map(lambda x: x, in_specs=P(), out_specs=P())(np.arange(8)) + def test_pvary_in_shmap_of_grad(self): + mesh = jtu.create_mesh((2,), 'x') + + def g(x): + return jnp.mean(x ** 2) + + def f(x): + val, grad = jax.value_and_grad(g)(x) + return (jnp.atleast_1d(val), jnp.atleast_1d(grad)) + + jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x') + )(jnp.ones(2,)) # doesn't crash + class FunSpec(NamedTuple): name: str From 39d7c38229f199cb9d21fd52816f83784f1cfa54 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 23 Apr 2025 10:52:29 -0700 Subject: [PATCH 0765/1769] Move contents of jaxlib/xla into jaxlib/ Having the directory structure of the jaxlib wheel be different from the source tree confuses type checkers such as mypy, since sometimes they find type stubs in the installed jaxlib wheel, and sometimes in the installed source tree. Instead: * don't include type stubs in the jaxlib wheel * don't install the jaxlib wheel as part of pre-commit * make sure that the type stubs (and the underlying libraries) sit at the same position in the `jaxlib/` directory of the JAX source tree as they would in the installed jaxlib wheel. For now, we leave some stubs that forward from the old locations to the new locations for certain headers and modules. These will be removed after migrating some users.
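As a rough illustration of such a forwarding stub, the module left at an old location might look like the following minimal sketch (the path is taken from the rename list below; the contents of the actual stubs may differ):

```python
# Hypothetical forwarding stub kept at the old location
# (e.g. jaxlib/xla/xla_extension.py), re-exporting the moved module.
from jaxlib.xla_extension import *  # noqa: F401,F403
```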
PiperOrigin-RevId: 750650528 --- .pre-commit-config.yaml | 2 +- jax/_src/lib/BUILD | 4 +- jaxlib/BUILD | 1042 ++++++++++++++++- jaxlib/{xla => }/callback.cc | 4 +- jaxlib/{xla => }/callback.h | 6 +- jaxlib/{xla => }/config.cc | 4 +- jaxlib/{xla => }/config.h | 6 +- jaxlib/{xla => }/config_test.py | 2 +- jaxlib/{xla => }/custom_call_sharding.cc | 2 +- jaxlib/{xla => }/custom_call_sharding.h | 6 +- jaxlib/{xla => }/custom_calls_testlib.cc | 0 jaxlib/{xla => }/dlpack.cc | 14 +- jaxlib/{xla => }/dlpack.h | 10 +- jaxlib/{xla => }/guard_lib.cc | 2 +- jaxlib/{xla => }/guard_lib.h | 6 +- jaxlib/{xla => }/ifrt_proxy.cc | 6 +- jaxlib/{xla => }/ifrt_proxy.h | 6 +- jaxlib/{xla => }/jax_jit.cc | 8 +- jaxlib/{xla => }/jax_jit.h | 14 +- jaxlib/{xla => }/jax_jit_test.py | 2 +- jaxlib/{xla => }/mlir.cc | 2 +- jaxlib/{xla => }/mlir.h | 6 +- jaxlib/mlir/_mlir_libs/_triton_ext.pyi | 2 +- jaxlib/nb_class_ptr.h | 59 + jaxlib/{xla => }/pjit.cc | 24 +- jaxlib/{xla => }/pjit.h | 6 +- jaxlib/{xla => }/pmap_lib.cc | 28 +- jaxlib/{xla => }/pmap_lib.h | 6 +- jaxlib/{xla => }/py_array.cc | 24 +- jaxlib/py_array.h | 365 ++++++ jaxlib/{xla => }/py_client.cc | 24 +- jaxlib/py_client.h | 252 ++++ jaxlib/{xla => }/py_client_cpu.cc | 2 +- jaxlib/{xla => }/py_client_cpu.h | 6 +- jaxlib/{xla => }/py_compile_only_client.cc | 6 +- jaxlib/{xla => }/py_compile_only_client.h | 10 +- jaxlib/{xla => }/py_device.cc | 10 +- jaxlib/py_device.h | 83 ++ jaxlib/{xla => }/py_device_list.cc | 10 +- jaxlib/py_device_list.h | 136 +++ jaxlib/{xla => }/py_executable.cc | 12 +- jaxlib/py_executable.h | 254 ++++ jaxlib/{xla => }/py_host_callback.cc | 8 +- jaxlib/{xla => }/py_host_callback.h | 6 +- jaxlib/{xla => }/py_host_callback.proto | 0 jaxlib/{xla => }/py_memory_space.cc | 6 +- jaxlib/{xla => }/py_memory_space.h | 10 +- jaxlib/{xla => }/py_program.cc | 12 +- jaxlib/{xla => }/py_program.h | 6 +- jaxlib/{xla => }/py_socket_transfer.cc | 12 +- jaxlib/{xla => }/py_socket_transfer.h | 6 +- jaxlib/{xla => }/py_values.cc | 10 +- jaxlib/{xla => }/py_values.h | 6 +- jaxlib/{xla => }/python_ref_manager.cc | 2 +- jaxlib/python_ref_manager.h | 108 ++ jaxlib/{xla => }/pytree.cc | 6 +- jaxlib/pytree.h | 408 +++++++ jaxlib/{xla => }/pytree.proto | 0 jaxlib/{xla => }/pytree_test.py | 2 +- jaxlib/{xla => }/sdy.cc | 2 +- jaxlib/{xla => }/sdy.h | 6 +- jaxlib/setup.py | 3 +- jaxlib/{xla => }/sharded_device_array.h | 6 +- jaxlib/{xla => }/sharding.cc | 10 +- jaxlib/sharding.h | 241 ++++ jaxlib/{xla => }/to_ifrt_sharding.cc | 8 +- jaxlib/{xla => }/to_ifrt_sharding.h | 6 +- jaxlib/tools/BUILD.bazel | 8 +- jaxlib/tools/build_wheel.py | 40 +- jaxlib/{xla => }/traceback.cc | 4 +- jaxlib/traceback.h | 109 ++ jaxlib/{xla => }/util.cc | 2 +- jaxlib/{xla => }/util.h | 6 +- jaxlib/{xla => }/xla.cc | 50 +- jaxlib/xla/BUILD | 1012 +--------------- jaxlib/xla/nb_class_ptr.h | 44 +- jaxlib/xla/py_array.h | 348 +----- jaxlib/xla/py_client.h | 235 +--- jaxlib/xla/py_device.h | 66 +- jaxlib/xla/py_device_list.h | 119 +- jaxlib/xla/py_executable.h | 237 +--- jaxlib/xla/python_ref_manager.h | 91 +- jaxlib/xla/pytree.h | 391 +------ jaxlib/xla/sharding.h | 224 +--- jaxlib/xla/traceback.h | 92 +- jaxlib/xla/xla_client.py | 972 +-------------- jaxlib/{ => xla}/xla_extension.py | 9 +- jaxlib/xla_client.py | 973 ++++++++++++++- jaxlib/{xla => }/xla_client.pyi | 0 .../xla_client_backend_independent_test.py | 2 +- jaxlib/{xla => }/xla_client_test.py | 4 +- jaxlib/{xla => }/xla_compiler.cc | 6 +- jaxlib/{xla => }/xla_compiler.h | 6 +- jaxlib/{xla => }/xla_extension/__init__.pyi | 2 + 
jaxlib/{xla => }/xla_extension/config.pyi | 0 jaxlib/{xla => }/xla_extension/guard_lib.pyi | 0 .../{xla => }/xla_extension/ifrt_programs.pyi | 2 +- jaxlib/{xla => }/xla_extension/ifrt_proxy.pyi | 2 +- jaxlib/{xla => }/xla_extension/jax_jit.pyi | 2 +- jaxlib/{xla => }/xla_extension/mlir.pyi | 0 jaxlib/{xla => }/xla_extension/ops.pyi | 2 +- jaxlib/{xla => }/xla_extension/pmap_lib.pyi | 0 jaxlib/{xla => }/xla_extension/profiler.pyi | 0 jaxlib/{xla => }/xla_extension/pytree.pyi | 0 jaxlib/{xla => }/xla_extension/sdy.pyi | 0 .../xla_extension/transfer_guard_lib.pyi | 0 pyproject.toml | 6 + 107 files changed, 4354 insertions(+), 4075 deletions(-) rename jaxlib/{xla => }/callback.cc (98%) rename jaxlib/{xla => }/callback.h (96%) rename jaxlib/{xla => }/config.cc (99%) rename jaxlib/{xla => }/config.h (91%) rename jaxlib/{xla => }/config_test.py (98%) rename jaxlib/{xla => }/custom_call_sharding.cc (99%) rename jaxlib/{xla => }/custom_call_sharding.h (86%) rename jaxlib/{xla => }/custom_calls_testlib.cc (100%) rename jaxlib/{xla => }/dlpack.cc (99%) rename jaxlib/{xla => }/dlpack.h (92%) rename jaxlib/{xla => }/guard_lib.cc (99%) rename jaxlib/{xla => }/guard_lib.h (97%) rename jaxlib/{xla => }/ifrt_proxy.cc (98%) rename jaxlib/{xla => }/ifrt_proxy.h (84%) rename jaxlib/{xla => }/jax_jit.cc (99%) rename jaxlib/{xla => }/jax_jit.h (97%) rename jaxlib/{xla => }/jax_jit_test.py (97%) rename jaxlib/{xla => }/mlir.cc (99%) rename jaxlib/{xla => }/mlir.h (90%) create mode 100644 jaxlib/nb_class_ptr.h rename jaxlib/{xla => }/pjit.cc (99%) rename jaxlib/{xla => }/pjit.h (90%) rename jaxlib/{xla => }/pmap_lib.cc (98%) rename jaxlib/{xla => }/pmap_lib.h (91%) rename jaxlib/{xla => }/py_array.cc (99%) create mode 100644 jaxlib/py_array.h rename jaxlib/{xla => }/py_client.cc (98%) create mode 100644 jaxlib/py_client.h rename jaxlib/{xla => }/py_client_cpu.cc (99%) rename jaxlib/{xla => }/py_client_cpu.h (88%) rename jaxlib/{xla => }/py_compile_only_client.cc (97%) rename jaxlib/{xla => }/py_compile_only_client.h (88%) rename jaxlib/{xla => }/py_device.cc (98%) create mode 100644 jaxlib/py_device.h rename jaxlib/{xla => }/py_device_list.cc (98%) create mode 100644 jaxlib/py_device_list.h rename jaxlib/{xla => }/py_executable.cc (98%) create mode 100644 jaxlib/py_executable.h rename jaxlib/{xla => }/py_host_callback.cc (98%) rename jaxlib/{xla => }/py_host_callback.h (97%) rename jaxlib/{xla => }/py_host_callback.proto (100%) rename jaxlib/{xla => }/py_memory_space.cc (96%) rename jaxlib/{xla => }/py_memory_space.h (90%) rename jaxlib/{xla => }/py_program.cc (98%) rename jaxlib/{xla => }/py_program.h (88%) rename jaxlib/{xla => }/py_socket_transfer.cc (98%) rename jaxlib/{xla => }/py_socket_transfer.h (83%) rename jaxlib/{xla => }/py_values.cc (99%) rename jaxlib/{xla => }/py_values.h (98%) rename jaxlib/{xla => }/python_ref_manager.cc (98%) create mode 100644 jaxlib/python_ref_manager.h rename jaxlib/{xla => }/pytree.cc (99%) create mode 100644 jaxlib/pytree.h rename jaxlib/{xla => }/pytree.proto (100%) rename jaxlib/{xla => }/pytree_test.py (99%) rename jaxlib/{xla => }/sdy.cc (99%) rename jaxlib/{xla => }/sdy.h (90%) rename jaxlib/{xla => }/sharded_device_array.h (98%) rename jaxlib/{xla => }/sharding.cc (98%) create mode 100644 jaxlib/sharding.h rename jaxlib/{xla => }/to_ifrt_sharding.cc (97%) rename jaxlib/{xla => }/to_ifrt_sharding.h (94%) rename jaxlib/{xla => }/traceback.cc (99%) create mode 100644 jaxlib/traceback.h rename jaxlib/{xla => }/util.cc (98%) rename jaxlib/{xla => }/util.h (93%) rename 
jaxlib/{xla => }/xla.cc (97%) rename jaxlib/{ => xla}/xla_extension.py (78%) rename jaxlib/{xla => }/xla_client.pyi (100%) rename jaxlib/{xla => }/xla_client_backend_independent_test.py (99%) rename jaxlib/{xla => }/xla_client_test.py (99%) rename jaxlib/{xla => }/xla_compiler.cc (99%) rename jaxlib/{xla => }/xla_compiler.h (88%) rename jaxlib/{xla => }/xla_extension/__init__.pyi (99%) rename jaxlib/{xla => }/xla_extension/config.pyi (100%) rename jaxlib/{xla => }/xla_extension/guard_lib.pyi (100%) rename jaxlib/{xla => }/xla_extension/ifrt_programs.pyi (97%) rename jaxlib/{xla => }/xla_extension/ifrt_proxy.pyi (96%) rename jaxlib/{xla => }/xla_extension/jax_jit.pyi (98%) rename jaxlib/{xla => }/xla_extension/mlir.pyi (100%) rename jaxlib/{xla => }/xla_extension/ops.pyi (99%) rename jaxlib/{xla => }/xla_extension/pmap_lib.pyi (100%) rename jaxlib/{xla => }/xla_extension/profiler.pyi (100%) rename jaxlib/{xla => }/xla_extension/pytree.pyi (100%) rename jaxlib/{xla => }/xla_extension/sdy.pyi (100%) rename jaxlib/{xla => }/xla_extension/transfer_guard_lib.pyi (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89ce80d9a815..be3fc8a86bcf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0] + additional_dependencies: [types-requests==2.31.0, numpy>=2.2.0] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 78ddad29306b..0c8acdd76630 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -45,8 +45,8 @@ py_library_providing_imports_info( "//jaxlib:cpu_feature_guard", "//jaxlib:utils", "//jaxlib:weakref_lru_cache", - "//jaxlib/xla:xla_client", - "//jaxlib/xla:xla_extension", + "//jaxlib:xla_client", + "//jaxlib:xla_extension", "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 373f8cd17674..2e1e4072cb55 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -16,11 +16,17 @@ load( "//jaxlib:jax.bzl", + "cc_proto_library", + "if_oss", + "jax_visibility", "nanobind_extension", + "proto_library", "py_deps", "py_library_providing_imports_info", + "py_strict_library", "py_strict_test", "pytype_library", + "pytype_strict_library", ) load( "//jaxlib:pywrap.bzl", @@ -37,6 +43,13 @@ package( default_visibility = ["//jax:internal"], ) +package_group( + name = "xla_python", + includes = [ + "//jax:internal", + ], +) + py_library_providing_imports_info( name = "jaxlib", srcs = [ @@ -52,7 +65,6 @@ py_library_providing_imports_info( "lapack.py", "plugin_support.py", "xla_client.py", - "xla_extension.py", ":version", ], data = [":ffi_headers"], @@ -62,6 +74,8 @@ py_library_providing_imports_info( ":jax", ":utils", ":weakref_lru_cache", + "//jaxlib:xla_client", + "//jaxlib:xla_extension", "//jaxlib/cpu:_lapack", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", @@ -84,8 +98,6 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", - "//jaxlib/xla:xla_client", - "//jaxlib/xla:xla_extension", ], ) @@ -123,6 +135,7 @@ pywrap_library( deps = [ ":utils", ":weakref_lru_cache", + "//jaxlib:xla_extension", "//jaxlib/mlir/_mlir_libs:_chlo", "//jaxlib/mlir/_mlir_libs:_mlir", "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", @@ -138,7 +151,6 @@ 
pywrap_library( "//jaxlib/mlir/_mlir_libs:_tpu_ext", "//jaxlib/mlir/_mlir_libs:_triton_ext", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", - "//jaxlib/xla:xla_extension", ], ) @@ -261,3 +273,1025 @@ nanobind_pywrap_extension( "@xla//third_party/python_runtime:headers", ], ) + +nanobind_pywrap_extension( + name = "xla_extension", + srcs = ["xla.cc"], + pytype_deps = py_deps(["numpy"]), + pytype_srcs = glob(["xla_extension/*.pyi"]), + visibility = ["//visibility:public"], + deps = [ + ":config", + ":custom_call_sharding", + ":dlpack", + ":guard_lib", + ":ifrt_proxy", + ":jax_jit", + ":mlir", + ":nb_class_ptr", + ":pjit", + ":pmap_lib", + ":py_client", + ":python_ref_manager", + ":pytree", + ":sdy", + ":traceback", + ":util", + ":xla_compiler", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla/backends/cpu/collectives:cpu_collectives", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_api", + "@xla//xla/pjrt:pjrt_c_api_client", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:key_value_store_interface", + "@xla//xla/pjrt/distributed:protocol_proto_cc", + "@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/python:logging", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:ops", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:profiler", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/python:types", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/platform/cloud:gcs_file_system", + "@xla//xla/tsl/python/lib/core:numpy", + ] + select({ + # gloo tcp transport only builds on linux + "@xla//xla/tsl:macos": [ + "@gloo//:transport_uv", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + "@xla//xla/tsl:windows": [], + "//conditions:default": [ + ":py_socket_transfer", + "@gloo//:transport_tcp", + 
"@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + }) + select({ + # mpitrampoline does not build on windows + "@xla//xla/tsl:windows": [], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]), + }), +) + +cc_library( + name = "callback", + srcs = [ + "callback.cc", + ], + hdrs = [ + "callback.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "custom_call_sharding", + srcs = ["custom_call_sharding.cc"], + hdrs = ["custom_call_sharding.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/utils:hlo_sharding_util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:custom_call_batch_partitioner", + "@xla//xla/python:custom_partition_callback", + "@xla//xla/python:debug_callback_partitioner", + "@xla//xla/python:inspect_sharding", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "dlpack", + srcs = ["dlpack.cc"], + hdrs = ["dlpack.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":traceback", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@dlpack", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_client", + 
"@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "guard_lib", + srcs = ["guard_lib.cc"], + hdrs = ["guard_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla:util", + ], +) + +cc_library( + name = "ifrt_proxy", + srcs = ["ifrt_proxy.cc"], + hdrs = ["ifrt_proxy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@nanobind", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt_proxy/client:grpc_client", + "@xla//xla/python/ifrt_proxy/client:registry", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":py_client", + ":python_ref_manager", + ":pytree", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # build_cleaner: keep + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_inlined_vector", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "mlir", + srcs = ["mlir.cc"], + hdrs = ["mlir.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:Support", + "@nanobind", + "@stablehlo//:stablehlo_serialization", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/mlir_hlo:mhlo_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/tsl/platform:errors", + 
"@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "nb_class_ptr", + hdrs = ["nb_class_ptr.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/nb_class_ptr"), + deps = ["@nanobind"], +) + +cc_library( + name = "pjit", + srcs = ["pjit.cc"], + hdrs = ["pjit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":guard_lib", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":pytree", + ":traceback", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "pmap_lib", + srcs = ["pmap_lib.cc"], + hdrs = ["pmap_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":pytree", + ":traceback", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "py_client", + srcs = [ + "py_array.cc", + "py_client.cc", + "py_compile_only_client.cc", + "py_device.cc", + "py_device_list.cc", + "py_executable.cc", + "py_memory_space.cc", + "py_program.cc", + "py_values.cc", + "sharding.cc", + "to_ifrt_sharding.cc", + ], + hdrs = [ + "py_array.h", + "py_client.h", + "py_compile_only_client.h", + "py_device.h", + "py_device_list.h", + "py_executable.h", + "py_memory_space.h", + "py_program.h", + "py_values.h", + "sharded_device_array.h", + "sharding.h", + "to_ifrt_sharding.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + 
"-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/py_client"), + deps = [ + ":guard_lib", + ":nb_class_ptr", + ":py_client_cpu", + ":py_host_callback", + ":python_ref_manager", + ":traceback", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@nanobind", + "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:types", + "@xla//xla/python/compile_only_ifrt:client", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:custom_call_program", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/python/ifrt/hlo:hlo_program", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/service:platform_util", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/framework:allocator", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "py_client_cpu", + srcs = ["py_client_cpu.cc"], + hdrs = ["py_client_cpu.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + ], + alwayslink = 1, +) + 
+cc_library( + name = "py_host_callback", + srcs = ["py_host_callback.cc"], + hdrs = ["py_host_callback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":py_host_callback_cc_proto", + ":python_ref_manager", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +proto_library( + name = "py_host_callback_proto", + srcs = ["py_host_callback.proto"], +) + +cc_proto_library( + name = "py_host_callback_cc_proto", + visibility = jax_visibility("jaxlib/py_host_callback_cc_proto"), + deps = [":py_host_callback_proto"], +) + +cc_library( + name = "py_socket_transfer", + srcs = ["py_socket_transfer.cc"], + hdrs = ["py_socket_transfer.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + ":traceback", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", + "@nanobind", + "@tsl//tsl/platform:casts", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/transfer:event_loop", + "@xla//xla/python/transfer:socket-server", + "@xla//xla/python/transfer:socket_bulk_transport", + "@xla//xla/python/transfer:streaming", + "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/python/transfer:transfer_socket_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "python_ref_manager", + srcs = ["python_ref_manager.cc"], + hdrs = ["python_ref_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/python_ref_manager"), + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + ], +) + +proto_library( + name = "pytree_proto", + srcs = ["pytree.proto"], +) + +cc_proto_library( + name = "pytree_cc_proto", + deps = [":pytree_proto"], +) + +cc_library( + name = "pytree", + srcs = ["pytree.cc"], + hdrs = ["pytree.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/pytree"), + deps = [ + ":nb_class_ptr", + ":pytree_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "sdy", + srcs = ["sdy.cc"], + hdrs = ["sdy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@xla//xla/mlir_hlo:all_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/service/spmd/shardy:constants", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", + "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + ], +) + +cc_library( + name = "traceback", + srcs = ["traceback.cc"], + hdrs = ["traceback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/traceback"), + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:async_value", + "@xla//xla/tsl/concurrency:ref_count", + ], +) + +cc_library( + name = "xla_compiler", + srcs = ["xla_compiler.cc"], + hdrs = ["xla_compiler.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack", + ":py_client", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:array", + "@xla//xla:debug_options_flags", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla:xla_proto_cc", + "@xla//xla/client:executable_build_options", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + 
"@xla//xla/ffi/api:c_api", + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/ir:hlo_module_group", + "@xla//xla/hlo/parser:hlo_parser", + "@xla//xla/hlo/pass:hlo_pass", + "@xla//xla/hlo/transforms/simplifiers:flatten_call_graph", + "@xla//xla/hlo/transforms/simplifiers:hlo_dce", + "@xla//xla/hlo/transforms/simplifiers:tuple_simplifier", + "@xla//xla/pjrt:compile_options_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:call_inliner", + "@xla//xla/service:computation_placer", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:hlo_graph_dumper", + "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/service:name_uniquer", + "@xla//xla/tsl/lib/strings:proto_serialization", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +pytype_strict_library( + name = "xla_client", + srcs = ["xla_client.py"], + pytype_srcs = ["xla_client.pyi"], + visibility = [":xla_python"], + deps = py_deps([ + "numpy", + "ml_dtypes", + ]) + [":xla_extension"], +) + +py_strict_test( + name = "xla_client_backend_independent_test", + srcs = ["xla_client_backend_independent_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/testing", + "numpy", + "portpicker", + ]), +) + +py_strict_library( + name = "xla_client_test", + testonly = 1, + srcs = ["xla_client_test.py"], + visibility = [":xla_python"], + deps = [ + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +nanobind_extension( + name = "custom_calls_testlib", + testonly = 1, + srcs = ["custom_calls_testlib.cc"], + deps = [ + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + ], +) + +py_strict_test( + name = "xla_client_test_cpu", + srcs = ["xla_client_test.py"], + args = ["--backend=cpu"], + env = { + "XLA_FLAGS": "--xla_force_host_platform_device_count=4", + }, + main = "xla_client_test.py", + deps = [ + ":custom_calls_testlib", + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +py_strict_test( + name = "pytree_test", + srcs = ["pytree_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "config_test", + srcs = ["config_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "jax_jit_test", + srcs = ["jax_jit_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "numpy", + ]), +) diff --git a/jaxlib/xla/callback.cc b/jaxlib/callback.cc similarity index 98% rename from jaxlib/xla/callback.cc rename to jaxlib/callback.cc index b5519ed3bee3..1262a534961c 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/callback.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "jaxlib/xla/callback.h" +#include "jaxlib/callback.h" #include #include @@ -34,7 +34,7 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/python_ref_manager.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" diff --git a/jaxlib/xla/callback.h b/jaxlib/callback.h similarity index 96% rename from jaxlib/xla/callback.h rename to jaxlib/callback.h index ee1f35ce34a3..59844ebf73b9 100644 --- a/jaxlib/xla/callback.h +++ b/jaxlib/callback.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_CALLBACK_H_ -#define JAXLIB_XLA_CALLBACK_H_ +#ifndef JAXLIB_CALLBACK_H_ +#define JAXLIB_CALLBACK_H_ #include #include @@ -84,4 +84,4 @@ class CpuCallback { } // namespace xla -#endif // JAXLIB_XLA_CALLBACK_H_ +#endif // JAXLIB_CALLBACK_H_ diff --git a/jaxlib/xla/config.cc b/jaxlib/config.cc similarity index 99% rename from jaxlib/xla/config.cc rename to jaxlib/config.cc index 8804b783eb72..3d701c516990 100644 --- a/jaxlib/xla/config.cc +++ b/jaxlib/config.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/config.h" +#include "jaxlib/config.h" #include @@ -27,7 +27,7 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/python_ref_manager.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/config.h b/jaxlib/config.h similarity index 91% rename from jaxlib/xla/config.h rename to jaxlib/config.h index 2a9281f498b4..e42673cb66fa 100644 --- a/jaxlib/xla/config.h +++ b/jaxlib/config.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_CONFIG_H_ -#define JAXLIB_XLA_CONFIG_H_ +#ifndef JAXLIB_CONFIG_H_ +#define JAXLIB_CONFIG_H_ #include @@ -31,4 +31,4 @@ void BuildConfigSubmodule(nanobind::module_& m); } // namespace jax -#endif // JAXLIB_XLA_CONFIG_H_ +#endif // JAXLIB_CONFIG_H_ diff --git a/jaxlib/xla/config_test.py b/jaxlib/config_test.py similarity index 98% rename from jaxlib/xla/config_test.py rename to jaxlib/config_test.py index 8701a37acd1d..734e9ed78896 100644 --- a/jaxlib/xla/config_test.py +++ b/jaxlib/config_test.py @@ -17,7 +17,7 @@ from absl.testing import absltest -from jax.jaxlib.xla import xla_client +from jax.jaxlib import xla_client config = xla_client._xla.config diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/custom_call_sharding.cc similarity index 99% rename from jaxlib/xla/custom_call_sharding.cc rename to jaxlib/custom_call_sharding.cc index 00accd85aefd..3e16768d0c29 100644 --- a/jaxlib/xla/custom_call_sharding.cc +++ b/jaxlib/custom_call_sharding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "jaxlib/xla/custom_call_sharding.h" +#include "jaxlib/custom_call_sharding.h" #include diff --git a/jaxlib/xla/custom_call_sharding.h b/jaxlib/custom_call_sharding.h similarity index 86% rename from jaxlib/xla/custom_call_sharding.h rename to jaxlib/custom_call_sharding.h index 5a5f3776cc30..454f60c3a03d 100644 --- a/jaxlib/xla/custom_call_sharding.h +++ b/jaxlib/custom_call_sharding.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ -#define JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ +#ifndef JAXLIB_CUSTOM_CALL_SHARDING_H_ +#define JAXLIB_CUSTOM_CALL_SHARDING_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildCustomCallShardingPybindAPI(nanobind::module_& m); } // namespace xla -#endif // JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ +#endif // JAXLIB_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/xla/custom_calls_testlib.cc b/jaxlib/custom_calls_testlib.cc similarity index 100% rename from jaxlib/xla/custom_calls_testlib.cc rename to jaxlib/custom_calls_testlib.cc diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/dlpack.cc similarity index 99% rename from jaxlib/xla/dlpack.cc rename to jaxlib/dlpack.cc index c8d02e679036..ca11f665550f 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/dlpack.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/dlpack.h" +#include "jaxlib/dlpack.h" #include @@ -35,12 +35,12 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_array.h" -#include "jaxlib/xla/py_client.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/traceback.h" -#include "jaxlib/xla/util.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/traceback.h" +#include "jaxlib/util.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" diff --git a/jaxlib/xla/dlpack.h b/jaxlib/dlpack.h similarity index 92% rename from jaxlib/xla/dlpack.h rename to jaxlib/dlpack.h index 7fffdc345d79..54feb2b45dba 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/dlpack.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_DLPACK_H_ -#define JAXLIB_XLA_DLPACK_H_ +#ifndef JAXLIB_DLPACK_H_ +#define JAXLIB_DLPACK_H_ #include #include @@ -22,8 +22,8 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" #include "xla/python/ifrt/device.h" #include "xla/xla_data.pb.h" @@ -55,4 +55,4 @@ absl::StatusOr PrimitiveTypeToNbDLDataType( } // namespace xla -#endif // JAXLIB_XLA_DLPACK_H_ +#endif // JAXLIB_DLPACK_H_ diff --git a/jaxlib/xla/guard_lib.cc b/jaxlib/guard_lib.cc similarity index 99% rename from jaxlib/xla/guard_lib.cc rename to jaxlib/guard_lib.cc index 77866741819c..6ad1f8e5366c 100644 --- a/jaxlib/xla/guard_lib.cc +++ b/jaxlib/guard_lib.cc @@ -17,7 +17,7 @@ limitations under the License. // guards. // C++ backends are responsible for enforcing transfer guard levels. -#include "jaxlib/xla/guard_lib.h" +#include "jaxlib/guard_lib.h" #include #include diff --git a/jaxlib/xla/guard_lib.h b/jaxlib/guard_lib.h similarity index 97% rename from jaxlib/xla/guard_lib.h rename to jaxlib/guard_lib.h index 8ddf6e8e892e..93e632fb7c9a 100644 --- a/jaxlib/xla/guard_lib.h +++ b/jaxlib/guard_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_GUARD_LIB_H_ -#define JAXLIB_XLA_GUARD_LIB_H_ +#ifndef JAXLIB_GUARD_LIB_H_ +#define JAXLIB_GUARD_LIB_H_ #include #include @@ -112,4 +112,4 @@ void BuildGuardSubmodule(nanobind::module_& m); } // namespace jax -#endif // JAXLIB_XLA_GUARD_LIB_H_ +#endif // JAXLIB_GUARD_LIB_H_ diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/ifrt_proxy.cc similarity index 98% rename from jaxlib/xla/ifrt_proxy.cc rename to jaxlib/ifrt_proxy.cc index a89941f8581c..e91c4d9a3859 100644 --- a/jaxlib/xla/ifrt_proxy.cc +++ b/jaxlib/ifrt_proxy.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "jaxlib/xla/ifrt_proxy.h" +#include "jaxlib/ifrt_proxy.h" #include #include @@ -36,8 +36,8 @@ #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/unordered_map.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" diff --git a/jaxlib/xla/ifrt_proxy.h b/jaxlib/ifrt_proxy.h similarity index 84% rename from jaxlib/xla/ifrt_proxy.h rename to jaxlib/ifrt_proxy.h index a8fcb9e676ff..2bfc19062012 100644 --- a/jaxlib/xla/ifrt_proxy.h +++ b/jaxlib/ifrt_proxy.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ -#define JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#ifndef JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#define JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_ #include "nanobind/nanobind.h" @@ -28,4 +28,4 @@ void BuildIfrtProxySubmodule(nanobind::module_& m); } // namespace ifrt } // namespace xla -#endif // JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#endif // JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_ diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/jax_jit.cc similarity index 99% rename from jaxlib/xla/jax_jit.cc rename to jaxlib/jax_jit.cc index 4645c59c7147..c48aa6ab7d19 100644 --- a/jaxlib/xla/jax_jit.cc +++ b/jaxlib/jax_jit.cc @@ -24,7 +24,7 @@ limitations under the License. // (a) inspect arguments and describe their structure, dtype/shapes, etc. // (b) keep a mapping from function signatures to compiled XLA Executables. -#include "jaxlib/xla/jax_jit.h" +#include "jaxlib/jax_jit.h" #include @@ -53,9 +53,9 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/py_values.h" -#include "jaxlib/xla/pytree.h" -#include "jaxlib/xla/sharding.h" +#include "jaxlib/py_values.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/jax_jit.h similarity index 97% rename from jaxlib/xla/jax_jit.h rename to jaxlib/jax_jit.h index 9eba2e9d3228..dc025e63f1de 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/jax_jit.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_JAX_JIT_H_ -#define JAXLIB_XLA_JAX_JIT_H_ +#ifndef JAXLIB_JAX_JIT_H_ +#define JAXLIB_JAX_JIT_H_ #include @@ -35,10 +35,10 @@ limitations under the License. 
#include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "jaxlib/xla/py_values.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/pytree.h" -#include "jaxlib/xla/sharding.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/tsl/platform/logging.h" @@ -263,4 +263,4 @@ void BuildJaxjitSubmodule(nanobind::module_& m); } // namespace jax -#endif // JAXLIB_XLA_JAX_JIT_H_ +#endif // JAXLIB_JAX_JIT_H_ diff --git a/jaxlib/xla/jax_jit_test.py b/jaxlib/jax_jit_test.py similarity index 97% rename from jaxlib/xla/jax_jit_test.py rename to jaxlib/jax_jit_test.py index a090bc8dfadd..c242823566dc 100644 --- a/jaxlib/xla/jax_jit_test.py +++ b/jaxlib/jax_jit_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest -from jax.jaxlib.xla import xla_client +from jax.jaxlib import xla_client jax_jit = xla_client._xla.jax_jit pytree = xla_client._xla.pytree diff --git a/jaxlib/xla/mlir.cc b/jaxlib/mlir.cc similarity index 99% rename from jaxlib/xla/mlir.cc rename to jaxlib/mlir.cc index 76663a79556a..a632cac71d10 100644 --- a/jaxlib/xla/mlir.cc +++ b/jaxlib/mlir.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/mlir.h" +#include "jaxlib/mlir.h" #include diff --git a/jaxlib/xla/mlir.h b/jaxlib/mlir.h similarity index 90% rename from jaxlib/xla/mlir.h rename to jaxlib/mlir.h index ee95f5f95921..bcbacb57a485 100644 --- a/jaxlib/xla/mlir.h +++ b/jaxlib/mlir.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_MLIR_H_ -#define JAXLIB_XLA_MLIR_H_ +#ifndef JAXLIB_MLIR_H_ +#define JAXLIB_MLIR_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildMlirSubmodule(nanobind::module_& m); } // namespace xla -#endif // JAXLIB_XLA_MLIR_H_ +#endif // JAXLIB_MLIR_H_ diff --git a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi index 1e1a67405113..93a82010043c 100644 --- a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi +++ b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mlir import ir +from jaxlib.mlir import ir def register_dialect(context: ir.Context, load: bool = ...) -> None: ... diff --git a/jaxlib/nb_class_ptr.h b/jaxlib/nb_class_ptr.h new file mode 100644 index 000000000000..381c77e812b9 --- /dev/null +++ b/jaxlib/nb_class_ptr.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef JAXLIB_NB_CLASS_PTR_H_
+#define JAXLIB_NB_CLASS_PTR_H_
+
+#include "nanobind/nanobind.h"
+
+namespace xla {
+
+// A reference-counting smart pointer to a nanobind-wrapped class on the Python
+// heap. Type T must be a class known to nanobind via a nanobind::class_<T>
+// declaration. nb_class_ptr is useful for managing C++ classes that may be
+// allocated inline in Python objects on the Python heap.
+template <typename T>
+class nb_class_ptr : public nanobind::object {
+ public:
+  inline nb_class_ptr() : nanobind::object() {}
+  inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t)
+      : nanobind::object(h, ::nanobind::detail::borrow_t{}) {}
+  inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t)
+      : nanobind::object(h, ::nanobind::detail::steal_t{}) {}
+  inline static bool check_(nanobind::handle h) {
+    nanobind::handle type = nanobind::type<T>();
+    return h.type().is(type);
+  };
+
+  T* operator->() const { return nanobind::inst_ptr<T>(ptr()); }
+  T& operator*() const { return *nanobind::inst_ptr<T>(ptr()); }
+  T* get() const { return ptr() ? nanobind::inst_ptr<T>(ptr()) : nullptr; }
+};
+
+// This function is analogous to std::make_unique<T>(...), but it allocates
+// the object on the Python heap instead.
+template <typename T, typename... Args>
+nb_class_ptr<T> make_nb_class(Args&&... args) {
+  nanobind::handle type = nanobind::type<T>();
+  nanobind::object instance = nanobind::inst_alloc(type);
+  T* ptr = nanobind::inst_ptr<T>(instance);
+  new (ptr) T(std::forward<Args>(args)...);
+  nanobind::inst_mark_ready(instance);
+  return nb_class_ptr<T>(instance.release(), ::nanobind::detail::steal_t{});
+}
+
+}  // namespace xla
+
+#endif  // JAXLIB_NB_CLASS_PTR_H_
diff --git a/jaxlib/xla/pjit.cc b/jaxlib/pjit.cc
similarity index 99%
rename from jaxlib/xla/pjit.cc
rename to jaxlib/pjit.cc
index 02d8fc6efd01..24b99770b767 100644
--- a/jaxlib/xla/pjit.cc
+++ b/jaxlib/pjit.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/pjit.h"
+#include "jaxlib/pjit.h"
 
 #include
 
@@ -50,17 +50,17 @@ limitations under the License.
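Note: for context on the new `jaxlib/nb_class_ptr.h` header above, here is a hypothetical usage sketch (the `Foo` class and module name are illustrative only, not part of the patch): once a type is registered with nanobind, `make_nb_class<T>` constructs it directly inside a Python object, and `nb_class_ptr<T>` keeps it alive via Python reference counting.

#include "jaxlib/nb_class_ptr.h"

namespace nb = nanobind;

struct Foo {
  explicit Foo(int x) : x(x) {}
  int x;
};

NB_MODULE(example, m) {
  nb::class_<Foo>(m, "Foo");  // Foo must be known to nanobind first.
  m.def("make_foo", []() {
    // Allocates Foo inline in a Python `Foo` object on the Python heap.
    xla::nb_class_ptr<Foo> foo = xla::make_nb_class<Foo>(42);
    return foo;  // Handed back to Python as an ordinary Foo instance.
  });
}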
#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/config.h" -#include "jaxlib/xla/guard_lib.h" -#include "jaxlib/xla/jax_jit.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_array.h" -#include "jaxlib/xla/py_executable.h" -#include "jaxlib/xla/py_values.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/pytree.h" -#include "jaxlib/xla/sharding.h" -#include "jaxlib/xla/traceback.h" +#include "jaxlib/config.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" diff --git a/jaxlib/xla/pjit.h b/jaxlib/pjit.h similarity index 90% rename from jaxlib/xla/pjit.h rename to jaxlib/pjit.h index 8d47347ab9a2..d86fa6bddc3c 100644 --- a/jaxlib/xla/pjit.h +++ b/jaxlib/pjit.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_PJIT_H_ -#define JAXLIB_XLA_PJIT_H_ +#ifndef JAXLIB_PJIT_H_ +#define JAXLIB_PJIT_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -24,4 +24,4 @@ namespace jax { void BuildPjitSubmodule(nanobind::module_& m); } -#endif // JAXLIB_XLA_PJIT_H_ +#endif // JAXLIB_PJIT_H_ diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/pmap_lib.cc similarity index 98% rename from jaxlib/xla/pmap_lib.cc rename to jaxlib/pmap_lib.cc index 01b301d008af..527bc022237f 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/pmap_lib.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/pmap_lib.h" +#include "jaxlib/pmap_lib.h" #include @@ -44,19 +44,19 @@ limitations under the License. 
#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/config.h" -#include "jaxlib/xla/jax_jit.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_array.h" -#include "jaxlib/xla/py_client.h" -#include "jaxlib/xla/py_device.h" -#include "jaxlib/xla/py_executable.h" -#include "jaxlib/xla/py_values.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/pytree.h" -#include "jaxlib/xla/sharded_device_array.h" -#include "jaxlib/xla/sharding.h" -#include "jaxlib/xla/traceback.h" +#include "jaxlib/config.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharded_device_array.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/pmap_lib.h similarity index 91% rename from jaxlib/xla/pmap_lib.h rename to jaxlib/pmap_lib.h index 2bad85e59671..b7cc2cc13f36 100644 --- a/jaxlib/xla/pmap_lib.h +++ b/jaxlib/pmap_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_PMAP_LIB_H_ -#define JAXLIB_XLA_PMAP_LIB_H_ +#ifndef JAXLIB_PMAP_LIB_H_ +#define JAXLIB_PMAP_LIB_H_ // placeholder for index annotation headers @@ -31,4 +31,4 @@ void BuildPmapSubmodule(nanobind::module_& m); } // namespace jax -#endif // JAXLIB_XLA_PMAP_LIB_H_ +#endif // JAXLIB_PMAP_LIB_H_ diff --git a/jaxlib/xla/py_array.cc b/jaxlib/py_array.cc similarity index 99% rename from jaxlib/xla/py_array.cc rename to jaxlib/py_array.cc index 584f895e32d2..103c003fa89b 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/py_array.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/py_array.h" +#include "jaxlib/py_array.h" #include @@ -57,17 +57,17 @@ limitations under the License. 
#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/guard_lib.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" -#include "jaxlib/xla/py_device.h" -#include "jaxlib/xla/py_device_list.h" -#include "jaxlib/xla/py_values.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/sharding.h" -#include "jaxlib/xla/to_ifrt_sharding.h" -#include "jaxlib/xla/traceback.h" -#include "jaxlib/xla/util.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/util.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/pjrt/exceptions.h" diff --git a/jaxlib/py_array.h b/jaxlib/py_array.h new file mode 100644 index 000000000000..ddb09bc41771 --- /dev/null +++ b/jaxlib/py_array.h @@ -0,0 +1,365 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_ARRAY_H_ +#define JAXLIB_PY_ARRAY_H_ + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/traceback.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +// Private to PyArray, but you cannot forward declare member classes. +// Not thread safe; assumes the GIL is held. 
+class PyHostValue {
+ public:
+  PyHostValue();
+  ~PyHostValue();
+
+  PyHostValue(const PyHostValue&) = delete;
+  PyHostValue(PyHostValue&&) = delete;
+  PyHostValue& operator=(const PyHostValue&) = delete;
+  PyHostValue& operator=(PyHostValue&&) = delete;
+
+  absl::Status CopyToHostAsync(std::optional<Shape>& dynamic_shape_holder,
+                               ifrt::Array* ifrt_array);
+
+  absl::StatusOr<std::pair<nb_numpy_ndarray, bool>> AsNumPyArray(
+      std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);
+
+  void Clear();
+
+ private:
+  absl::Status CopyStringArrayToHostAsync(
+      std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);
+
+  absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array);
+
+  ifrt::Future<> ready_;
+  nb_numpy_ndarray value_;
+
+  // Optional field, only used for arrays of type kString. This vector of
+  // cords serves as the input buffer for the CopyToHostBuffer call. It holds
+  // the contents until they are lazily converted to a numpy array when the
+  // user calls `AsNumPyArray`.
+  std::shared_ptr<std::vector<absl::Cord>> string_array_contents_;
+};
+
+// Private to PyArray, but you cannot forward declare member classes.
+struct PyArray_Storage {
+  PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype,
+                  std::vector<int64_t> shape, nanobind::object sharding,
+                  bool committed, nb_class_ptr<PyClient> py_client,
+                  std::optional<Traceback> traceback,
+                  tsl::RCReference<ifrt::Array> ifrt_array,
+                  xla::PjRtFuture<> result_status);
+
+  ~PyArray_Storage();
+  nanobind::handle AsHandle();
+
+  nanobind::object aval;
+  bool weak_type = false;
+  nb_dtype dtype;
+  std::vector<int64_t> shape;
+
+  nanobind::object sharding;
+  nanobind::object npy_value = nanobind::none();
+  bool committed = false;
+
+  nb_class_ptr<PyClient> py_client;
+  std::optional<Traceback> traceback;
+  tsl::RCReference<ifrt::Array> ifrt_array;
+  nanobind::object fully_replicated_array = nanobind::none();
+
+  // Optional field, used only in Python.
+  std::vector<PyArray> py_arrays;
+  PyHostValue host_value;  // Protected by the GIL.
+  std::optional<Shape> dynamic_shape = std::nullopt;
+  // Only set if this Array was generated by a computation that has effects.
+  // This is the result status of the XLA computation that generated this
+  // array.
+  xla::PjRtFuture<> result_status;
+
+  // Doubly-linked list of all PyArrays known to the client. Protected by the
+  // GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be
+  // duplicate PjRtBuffers in this list.
+  PyArray_Storage* next;
+  PyArray_Storage* prev;
+
+  uint8_t thread_id_bucket;
+};
+
+// The C++ implementation of jax.Array. A few key methods and data members are
+// implemented in C++ for performance, while most of the functionality is
+// still implemented in Python.
+class PyArray : public nanobind::object {
+ public:
+  NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray);
+  PyArray() = default;
+
+  // "__init__" methods. Only used in Python.
+  static void PyInit(PyArray self, nanobind::object aval,
+                     nanobind::object sharding,
+                     absl::Span<const PyArray> py_arrays, bool committed,
+                     bool skip_checks);
+
+  // Only used in C++. `skip_checks` should only be set for Arrays created by
+  // jax that cannot possibly have consistency issues (e.g. `sharding` devices
+  // different from `ifrt_array` devices). Arrays created by users should be
+  // checked.
+  PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype,
+          std::vector<int64_t> shape, nanobind::object sharding,
+          nb_class_ptr<PyClient> py_client,
+          std::optional<Traceback> traceback,
+          tsl::RCReference<ifrt::Array> ifrt_array, bool committed,
+          bool skip_checks,
+          xla::PjRtFuture<> result_status = xla::PjRtFuture<>());
+
+  static PyArray MakeFromSingleDeviceArray(
+      nb_class_ptr<PyClient> py_client, std::optional<Traceback> traceback,
+      tsl::RCReference<ifrt::Array> ifrt_array, bool weak_type, bool committed,
+      xla::PjRtFuture<> result_status = xla::PjRtFuture<>());
+
+  static PyArray MakeFromIfrtArrayAndSharding(
+      nb_class_ptr<PyClient> py_client, std::optional<Traceback> traceback,
+      tsl::RCReference<ifrt::Array> ifrt_array, nanobind::object sharding,
+      bool weak_type, bool committed, bool skip_checks);
+
+  static absl::Status RegisterTypes(nanobind::module_& m);
+
+  static PyArray borrow(PyObject* ptr) {
+    return nanobind::borrow<PyArray>(ptr);
+  }
+
+  using Storage = PyArray_Storage;
+
+  const nanobind::object& aval() const { return GetStorage().aval; }
+  void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); }
+
+  bool weak_type() const { return GetStorage().weak_type; }
+
+  const nb_dtype& dtype() const { return GetStorage().dtype; }
+  absl::Span<const int64_t> shape() const { return GetStorage().shape; }
+
+  const nanobind::object& sharding() const { return GetStorage().sharding; }
+
+  absl::StatusOr<std::shared_ptr<const PjRtLayout>> layout() {
+    return ifrt_array()->layout();
+  }
+
+  bool committed() const { return GetStorage().committed; }
+
+  const nanobind::object& npy_value() const { return GetStorage().npy_value; }
+  void set_npy_value(nanobind::object v) {
+    GetStorage().npy_value = std::move(v);
+  }
+
+  const nb_class_ptr<PyClient>& py_client() const {
+    return GetStorage().py_client;
+  }
+
+  const std::optional<Traceback>& traceback() const {
+    return GetStorage().traceback;
+  }
+
+  // Returns xla::InvalidArgument if the buffer has been deleted.
+  // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`.
+  absl::StatusOr<bool> IsReady() {
+    ifrt::Array* ifrt_array_ptr = ifrt_array();
+    if (ifrt_array_ptr->IsDeleted()) {
+      return InvalidArgument("Array has been deleted.");
+    }
+    return ifrt_array_ptr->GetReadyFuture().IsReady();
+  }
+
+  const xla::PjRtFuture<>& result_status() const {
+    return GetStorage().result_status;
+  }
+
+  ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); }
+
+  // Short-term escape hatch to get PjRtBuffers from PyArray.
+  // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
+  absl::Span<const std::shared_ptr<PjRtBuffer>> pjrt_buffers() const {
+    ifrt::Array* ifrt_array_ptr = ifrt_array();
+    if (ifrt_array_ptr == nullptr) {
+      return {};
+    }
+    auto* arr =
+        llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array_ptr);
+    if (arr == nullptr) {
+      throw XlaRuntimeError(
+          "This operation is implemented for a PjRt-compatible backend only.");
+    }
+    return arr->pjrt_buffers();
+  }
+
+  int num_addressable_shards() const {
+    ifrt::Array* ifrt_array_ptr = ifrt_array();
+    if (ifrt_array_ptr == nullptr) {
+      return 0;
+    }
+    auto* arr =
+        llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array_ptr);
+    if (arr == nullptr) {
+      // TODO(hyeontaek): Add num_addressable_shards to ifrt.
+      return num_shards();
+    }
+    return arr->pjrt_buffers().size();
+  }
+
+  std::vector<PyArray>& py_arrays() { return GetStorage().py_arrays; }
+  const std::vector<PyArray>& py_arrays() const {
+    return GetStorage().py_arrays;
+  }
+  const std::vector<PyArray>& py_arrays_cached();
+
+  nanobind::object arrays();
+  absl::Status set_arrays(nanobind::object obj);
+  absl::StatusOr<PyArray> FullyReplicatedShard();
+
+  int num_shards() const {
+    ifrt::Array* ifrt_array_ptr = ifrt_array();
+    if (ifrt_array_ptr == nullptr) {
+      return 0;
+    }
+    return ifrt_array_ptr->sharding().devices()->size();
+  }
+
+  static nanobind::handle type() {
+    DCHECK(type_);
+    return nanobind::handle(type_);
+  }
+
+  static bool IsPyArray(nanobind::handle arg) {
+    return arg.type().is(PyArray::type());
+  }
+
+  absl::Status BlockUntilReady() const;
+
+  absl::Status BlockUntilResultStatusIsReady();
+
+  absl::StatusOr<size_t> GetOnDeviceSizeInBytes();
+  absl::StatusOr<std::pair<nb_numpy_ndarray, bool>>
+  SingleDeviceArrayToNumpyArrayDidCopy();
+  absl::StatusOr<nb_numpy_ndarray> SingleDeviceArrayToNumpyArray();
+  absl::Status CopySingleDeviceArrayToHostAsync();
+  nanobind::dict CudaArrayInterface();
+  absl::StatusOr<std::uintptr_t> UnsafeBufferPointer();
+
+  absl::Status Delete();
+
+  bool IsDeleted() const;
+
+  PyArray Clone() const;
+
+  static absl::StatusOr<std::vector<PyArray>> BatchedCopyToDeviceWithSharding(
+      absl::Span<const PyArray> py_arrays,
+      absl::Span<const ifrt::DeviceListRef> dst_device_lists,
+      absl::Span<const nanobind::object> dst_shardings,
+      absl::Span<const ifrt::ArrayCopySemantics> array_copy_semantics);
+
+  static absl::StatusOr<PyArray> BatchedDevicePut(
+      nanobind::object aval, nanobind::object sharding,
+      std::vector<nanobind::object> xs,
+      absl::Span<const PyDevice* const> dst_devices, bool committed,
+      bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics,
+      bool jax_enable_x64);
+
+  static absl::StatusOr<PyArray> ReorderShards(
+      PyArray x, nanobind::object dst_sharding,
+      ifrt::ArrayCopySemantics array_copy_semantics);
+
+  static absl::Status BatchedBlockUntilReady(
+      std::vector<nanobind::object> objs);
+
+  absl::Status ReplaceWithAlias(PyArray o);
+
+ private:
+  absl::StatusOr<ifrt::Array*> AssertUnsharded(absl::string_view api);
+
+  nanobind::object CheckAndRearrange(absl::Span<const PyArray> py_arrays,
+                                     nanobind::object sharding,
+                                     nanobind::object aval);
+
+  void SetIfrtArray(tsl::RCReference<ifrt::Array> ifrt_array);
+
+  Storage& GetStorage();
+  const Storage& GetStorage() const;
+
+  inline static PyObject* type_ = nullptr;
+};
+
+class PyArrayResultHandler {
+ public:
+  PyArrayResultHandler(nanobind::object aval, nanobind::object sharding,
+                       bool committed, bool skip_checks);
+
+  PyArray Call(absl::Span<const PyArray> py_arrays) const;
+  PyArray Call(PyArray py_array) const;
+
+  PyArray Call(nb_class_ptr<PyClient> py_client,
+               tsl::RCReference<ifrt::Array> ifrt_array,
+               xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const;
+
+ private:
+  nanobind::object aval_;
+  nanobind::object sharding_;
+  bool weak_type_;
+  bool committed_;
+  bool skip_checks_;
+
+  nb_dtype dtype_;
+  std::vector<int64_t> shape_;
+};
+
+absl::StatusOr<PyArray> CudaArrayInterfaceToBuffer(
+    const nanobind::dict& cai, nb_class_ptr<PyClient> cuda_client,
+    std::optional<int> device_id);
+
+}  // namespace xla
+
+#endif  // JAXLIB_PY_ARRAY_H_
diff --git a/jaxlib/xla/py_client.cc b/jaxlib/py_client.cc
similarity index 98%
rename from jaxlib/xla/py_client.cc
rename to jaxlib/py_client.cc
index 8b42da2fc9bd..e8251939592f 100644
--- a/jaxlib/xla/py_client.cc
+++ b/jaxlib/py_client.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/py_client.h"
+#include "jaxlib/py_client.h"
 
 #include
 
@@ -48,17 +48,17 @@ limitations under the License.
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/guard_lib.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_array.h" -#include "jaxlib/xla/py_device.h" -#include "jaxlib/xla/py_executable.h" -#include "jaxlib/xla/py_host_callback.h" -#include "jaxlib/xla/py_memory_space.h" -#include "jaxlib/xla/py_values.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/sharding.h" -#include "jaxlib/xla/traceback.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_host_callback.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" #include "xla/literal.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/mlir_to_hlo.h" diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h new file mode 100644 index 000000000000..3bc7057bc4ab --- /dev/null +++ b/jaxlib/py_client.h @@ -0,0 +1,252 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_CLIENT_H_ +#define JAXLIB_PY_CLIENT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/shape.h" + +namespace xla { + +class PyClient; +class PyLoadedExecutable; +class PyArray; +class PyDevice; +class PyMemorySpace; +struct PyArray_Storage; + +// Python wrapper around PjRtClient. +// We use a wrapper class to add Python-specific functionality. +class PyClient { + public: + static nb_class_ptr Make(std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. + explicit PyClient(std::shared_ptr ifrt_client); + virtual ~PyClient(); + + ifrt::Client* ifrt_client() const { return ifrt_client_.get(); } + const std::shared_ptr& shared_ptr_ifrt_client() const { + return ifrt_client_; + } + + // Short-term escape hatch to get PjRtClient from PyClient. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. 
+  xla::PjRtClient* pjrt_client() const {
+    auto* pjrt_client =
+        llvm::dyn_cast_or_null<ifrt::PjRtClient>(ifrt_client_.get());
+    if (pjrt_client == nullptr) {
+      throw XlaRuntimeError(
+          "This operation is implemented for a PjRt-compatible backend "
+          "only.");
+    }
+    return pjrt_client->pjrt_client();
+  }
+  std::shared_ptr<PjRtClient> shared_ptr_pjrt_client() {
+    auto* pjrt_client =
+        llvm::dyn_cast_or_null<ifrt::PjRtClient>(ifrt_client_.get());
+    if (pjrt_client == nullptr) {
+      throw XlaRuntimeError(
+          "This operation is implemented for a PjRt-compatible backend "
+          "only.");
+    }
+    return pjrt_client->shared_ptr_pjrt_client();
+  }
+
+  // Legacy aliases.
+  std::shared_ptr<PjRtClient> shared_pjrt_client() {
+    return shared_ptr_pjrt_client();
+  }
+
+  absl::string_view platform_name() const {
+    // TODO(phawkins): this is a temporary backwards compatibility shim. We
+    // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
+    // we haven't yet updated JAX clients that expect "gpu". Migrate users and
+    // remove this code.
+    if (ifrt_client_->platform_name() == "cuda" ||
+        ifrt_client_->platform_name() == "rocm") {
+      return "gpu";
+    } else {
+      return ifrt_client_->platform_name();
+    }
+  }
+  absl::string_view raw_platform_name() const {
+    // TODO(parkers): Once platform_name() is the same, remove this.
+    return ifrt_client_->platform_name();
+  }
+  absl::string_view platform_version() const {
+    return ifrt_client_->platform_version();
+  }
+  absl::string_view runtime_type() const {
+    return ifrt_client_->runtime_type();
+  }
+
+  // Returns implementation-specific attributes about this client, e.g. the PJRT
+  // C API version if applicable.
+  const xla::ifrt::AttributeMap& Attributes() const {
+    return client_attributes_;
+  }
+
+  int addressable_device_count() const {
+    return ifrt_client_->addressable_device_count();
+  }
+  int device_count() const { return ifrt_client_->device_count(); }
+  int process_index() const { return ifrt_client_->process_index(); }
+
+  std::vector<nb_class_ptr<PyDevice>> Devices();
+  std::vector<nb_class_ptr<PyDevice>> LocalDevices();
+  // Returns all devices in the client. Private API; only use this method for
+  // implementing backend._get_all_devices().
+  // TODO(hyeontaek): Remove this method once we have a unified API for
+  // enumerating devices with different criteria.
+  std::vector<nb_class_ptr<PyDevice>> GetAllDevices();
+  absl::StatusOr<nb_class_ptr<PyDevice>> DeviceFromLocalHardwareId(
+      int local_hardware_id);
+
+  // Returns the PyDevice associated with the given ifrt::Device.
+  nb_class_ptr<PyDevice> GetPyDevice(ifrt::Device* device);
+
+  // Returns the PyMemorySpace associated with the given ifrt::Memory.
+  nb_class_ptr<PyMemorySpace> GetPyMemorySpace(ifrt::Memory* memory_space);
+
+  // Returns a vector of live PyArray objects. PyArray objects may share
+  // PjRtBuffers, so there may be duplicates of the same underlying device
+  // buffer.
+  std::vector<nanobind::object> LiveBuffersOnDevice(ifrt::Device* device);
+
+  nanobind::list LiveExecutables();
+
+  // TODO(zhangqiaorjc): Remove when we have transparent defragmentation.
+  absl::Status Defragment();
+
+  static absl::StatusOr<nanobind::object> BufferFromPyval(
+      nb_class_ptr<PyClient> client, nanobind::handle argument,
+      ifrt::Device* device, bool force_copy,
+      ifrt::Client::HostBufferSemantics host_buffer_semantics);
+
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> CompileIfrtProgram(
+      nb_class_ptr<PyClient> client,
+      std::unique_ptr<ifrt::Program> ifrt_program,
+      std::unique_ptr<ifrt::CompileOptions> ifrt_options);
+
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> Compile(
+      nb_class_ptr<PyClient> client, std::string mlir_module,
+      CompileOptions options, std::vector<nanobind::capsule> host_callbacks);
+
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> Compile(
+      nb_class_ptr<PyClient> client, std::string mlir_module,
+      CompileOptions options, std::vector<nanobind::callable> host_callbacks);
+
+  absl::StatusOr<nanobind::bytes> SerializeExecutable(
+      const PyLoadedExecutable& executable) const;
+  static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> DeserializeExecutable(
+      nb_class_ptr<PyClient> client, nanobind::bytes serialized,
+      std::optional<CompileOptions> options,
+      std::vector<nanobind::capsule> host_callbacks);
+
+  absl::StatusOr<nanobind::bytes> HeapProfile();
+
+  // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable
+  // that takes in arguments of shapes `operand_shapes` and returns results of
+  // shapes `result_shapes`. The arguments correspond to Send ops in the HLO
+  // program through `send_channel_ids` and the results correspond to Recv ops
+  // through `recv_channel_ids`. It returns the host callback as an opaque
+  // object whose reference will keep the Python callback alive. The host
+  // callback can be passed to `PyClient::Compile` or
+  // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the
+  // XLA computation can trigger the execution of this host callback.
+  // `serializer` is a function that takes `callable` as an argument and returns
+  // a serialized callable as a string.
+  //
+  // The callable receives as arguments NumPy arrays for arguments with array
+  // types, and None for Token arguments. The callable must return a tuple of
+  // either arrays or None values.
+  absl::StatusOr<nanobind::object> MakePythonCallbackUsingHostSendAndRecv(
+      nanobind::callable callable, absl::Span<Shape const> operand_shapes,
+      absl::Span<Shape const> result_shapes,
+      absl::Span<uint16_t const> send_channel_ids,
+      absl::Span<uint16_t const> recv_channel_ids,
+      nanobind::callable serializer);
+
+  std::vector<PyArray> LiveArrays() const;
+
+  static void RegisterPythonTypes(nanobind::module_& m);
+
+ protected:
+  static void Initialize(nb_class_ptr<PyClient> client);
+
+ private:
+  friend class PyLoadedExecutable;
+  friend class PyArray;
+  friend struct PyArray_Storage;
+
+  static int tp_traverse(PyObject* self, visitproc visit, void* arg);
+  static int tp_clear(PyObject* self);
+  static PyType_Slot slots_[];
+
+  std::shared_ptr<ifrt::Client> ifrt_client_;
+  xla::ifrt::AttributeMap client_attributes_;
+  // Pointers to intrusive doubly-linked lists of arrays and executables, used
+  // to iterate over all known objects when heap profiling. The list structure
+  // is protected by the GIL.
+
+  nanobind::ft_mutex executables_mutex_;
+  // List guarded by executables_mutex_.
+  PyLoadedExecutable* executables_ = nullptr;
+
+#ifdef NB_FREE_THREADING
+  static constexpr size_t kNumArraysShards = 16;
+#else
+  static constexpr size_t kNumArraysShards = 1;
+#endif
+  struct ArraysShard {
+    mutable nanobind::ft_mutex mutex;
+    PyArray_Storage* arrays;
+  };
+  std::array<ArraysShard, kNumArraysShards> arrays_;
+
+  absl::flat_hash_map<ifrt::Device*, nb_class_ptr<PyDevice>> devices_;
+  absl::flat_hash_map<ifrt::Memory*, nb_class_ptr<PyMemorySpace>>
+      memory_spaces_;
+};
+
+}  // namespace xla
+
+#endif  // JAXLIB_PY_CLIENT_H_
diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/py_client_cpu.cc
similarity index 99%
rename from jaxlib/xla/py_client_cpu.cc
rename to jaxlib/py_client_cpu.cc
index 91f4e4ee42b9..647f33c59900 100644
--- a/jaxlib/xla/py_client_cpu.cc
+++ b/jaxlib/py_client_cpu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/py_client_cpu.h"
+#include "jaxlib/py_client_cpu.h"
 
 #include 
 
diff --git a/jaxlib/xla/py_client_cpu.h b/jaxlib/py_client_cpu.h
similarity index 88%
rename from jaxlib/xla/py_client_cpu.h
rename to jaxlib/py_client_cpu.h
index 0035b0a361fa..275a57fa06b5 100644
--- a/jaxlib/xla/py_client_cpu.h
+++ b/jaxlib/py_client_cpu.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef JAXLIB_XLA_PY_CLIENT_CPU_H_
-#define JAXLIB_XLA_PY_CLIENT_CPU_H_
+#ifndef JAXLIB_PY_CLIENT_CPU_H_
+#define JAXLIB_PY_CLIENT_CPU_H_
 
 #include "xla/ffi/api/ffi.h"
 
@@ -25,4 +25,4 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback);
 
 }  // namespace xla
 
-#endif  // JAXLIB_XLA_PY_CLIENT_CPU_H_
+#endif  // JAXLIB_PY_CLIENT_CPU_H_
diff --git a/jaxlib/xla/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc
similarity index 97%
rename from jaxlib/xla/py_compile_only_client.cc
rename to jaxlib/py_compile_only_client.cc
index 673dfc214346..f9914edac52a 100644
--- a/jaxlib/xla/py_compile_only_client.cc
+++ b/jaxlib/py_compile_only_client.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/py_compile_only_client.h"
+#include "jaxlib/py_compile_only_client.h"
 
 #include 
 #include 
 
@@ -30,8 +30,8 @@ limitations under the License.
 #include "nanobind/stl/shared_ptr.h"  // IWYU pragma: keep
 #include "nanobind/stl/string_view.h"  // IWYU pragma: keep
 #include "nanobind/stl/vector.h"  // IWYU pragma: keep
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "jaxlib/xla/py_client.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_client.h"
 #include "xla/pjrt/mlir_to_hlo.h"
 #include "xla/pjrt/pjrt_compiler.h"
 #include "xla/pjrt/pjrt_executable.h"
diff --git a/jaxlib/xla/py_compile_only_client.h b/jaxlib/py_compile_only_client.h
similarity index 88%
rename from jaxlib/xla/py_compile_only_client.h
rename to jaxlib/py_compile_only_client.h
index 6cc700e1d3a9..4b274871ee96 100644
--- a/jaxlib/xla/py_compile_only_client.h
+++ b/jaxlib/py_compile_only_client.h
@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_
-#define JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_
+#ifndef JAXLIB_PY_COMPILE_ONLY_CLIENT_H_
+#define JAXLIB_PY_COMPILE_ONLY_CLIENT_H_
 
 #include 
 
 // placeholder for index annotation headers
 #include "nanobind/nanobind.h"
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "jaxlib/xla/py_client.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_client.h"
 #include "xla/python/pjrt_ifrt/pjrt_topology.h"
 
 namespace xla {
 
@@ -42,4 +42,4 @@ void RegisterCompileOnlyClient(nanobind::module_& m);
 
 }  // namespace xla
 
-#endif  // JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_
+#endif  // JAXLIB_PY_COMPILE_ONLY_CLIENT_H_
diff --git a/jaxlib/xla/py_device.cc b/jaxlib/py_device.cc
similarity index 98%
rename from jaxlib/xla/py_device.cc
rename to jaxlib/py_device.cc
index 253bfd439a9b..f830b4f49448 100644
--- a/jaxlib/xla/py_device.cc
+++ b/jaxlib/py_device.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/py_device.h"
+#include "jaxlib/py_device.h"
 
 #include 
 
@@ -36,10 +36,10 @@ limitations under the License.
 #include "nanobind/stl/string_view.h"  // IWYU pragma: keep
 #include "nanobind/stl/variant.h"  // IWYU pragma: keep
 #include "nanobind/stl/vector.h"  // IWYU pragma: keep
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "jaxlib/xla/py_client.h"
-#include "jaxlib/xla/py_memory_space.h"
-#include "jaxlib/xla/python_ref_manager.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_client.h"
+#include "jaxlib/py_memory_space.h"
+#include "jaxlib/python_ref_manager.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
 #include "xla/pjrt/status_casters.h"
diff --git a/jaxlib/py_device.h b/jaxlib/py_device.h
new file mode 100644
index 000000000000..8366f8deae3e
--- /dev/null
+++ b/jaxlib/py_device.h
@@ -0,0 +1,83 @@
+/* Copyright 2024 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_PY_DEVICE_H_
+#define JAXLIB_PY_DEVICE_H_
+
+#include 
+
+#include 
+#include 
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_client.h"
+#include "xla/literal.h"
+#include "xla/python/ifrt/device.h"
+#include "xla/shape.h"
+
+namespace xla {
+
+class PyDevice {
+ public:
+  PyDevice(nb_class_ptr<PyClient> client, ifrt::Device* device);
+
+  // Devices are compared using Python object identity, so we don't allow them
+  // to be copied or moved.
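+  //
+  // Editorial illustration (hypothetical Python snippet, not from the
+  // original patch): because each ifrt::Device is expected to map to a single
+  // cached PyDevice wrapper (see PyClient::GetPyDevice), identity comparison
+  // works on the Python side, assuming the cache is used consistently:
+  //
+  //   d1 = client.devices()[0]
+  //   d2 = client.devices()[0]
+  //   assert d1 is d2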
+  PyDevice(const PyDevice&) = delete;
+  PyDevice(PyDevice&&) = delete;
+  PyDevice& operator=(const PyDevice&) = delete;
+  PyDevice& operator=(PyDevice&&) = delete;
+
+  const nb_class_ptr<PyClient>& client() const { return client_; }
+  ifrt::Device* device() const { return device_; }
+
+  int id() const;
+  int process_index() const;
+  absl::string_view platform() const;
+  absl::string_view device_kind() const;
+  std::optional<int> local_hardware_id() const;
+
+  absl::string_view Str() const;
+  absl::string_view Repr() const;
+
+  absl::Status TransferToInfeed(LiteralSlice literal);
+  absl::StatusOr<nanobind::object> TransferFromOutfeed(Shape shape);
+
+  absl::StatusOr<nb_class_ptr<PyMemorySpace>> Memory(
+      absl::string_view kind) const;
+  absl::StatusOr<nb_class_ptr<PyMemorySpace>> DefaultMemory() const;
+  nanobind::list AddressableMemories() const;
+  absl::StatusOr<std::optional<nanobind::dict>> MemoryStats() const;
+
+  absl::StatusOr<std::intptr_t> GetStreamForExternalReadyEvents() const;
+
+  static void RegisterPythonType(nanobind::module_& m);
+
+ private:
+  static int tp_traverse(PyObject* self, visitproc visit, void* arg);
+  static int tp_clear(PyObject* self);
+  static PyType_Slot slots_[];
+
+  nb_class_ptr<PyClient> client_;
+  ifrt::Device* device_;
+};
+
+}  // namespace xla
+
+#endif  // JAXLIB_PY_DEVICE_H_
diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/py_device_list.cc
similarity index 98%
rename from jaxlib/xla/py_device_list.cc
rename to jaxlib/py_device_list.cc
index 205c971b9317..3bf5480c5363 100644
--- a/jaxlib/xla/py_device_list.cc
+++ b/jaxlib/py_device_list.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/py_device_list.h"
+#include "jaxlib/py_device_list.h"
 
 #include 
 
@@ -32,10 +32,10 @@ limitations under the License.
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/string.h"  // IWYU pragma: keep
 #include "nanobind/stl/string_view.h"  // IWYU pragma: keep
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "jaxlib/xla/py_client.h"
-#include "jaxlib/xla/py_device.h"
-#include "jaxlib/xla/python_ref_manager.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_client.h"
+#include "jaxlib/py_device.h"
+#include "jaxlib/python_ref_manager.h"
 #include "xla/python/ifrt/device.h"
 #include "xla/python/ifrt/device_list.h"
 #include "xla/python/types.h"
diff --git a/jaxlib/py_device_list.h b/jaxlib/py_device_list.h
new file mode 100644
index 000000000000..5caba6f3dec7
--- /dev/null
+++ b/jaxlib/py_device_list.h
@@ -0,0 +1,136 @@
+/* Copyright 2023 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_PY_DEVICE_LIST_H_
+#define JAXLIB_PY_DEVICE_LIST_H_
+
+#include 
+#include 
+#include 
+#include 
+
+#include "absl/status/statusor.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_client.h"
+#include "xla/python/ifrt/device_list.h"
+
+namespace jax {
+
+// Device list with various caching and direct access to IFRT DeviceList.
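+//
+// Editorial sketch of the caching this enables (behavior inferred from the
+// "populated on demand" comments below; the handle is hypothetical): repeated
+// queries are served from memoized fields instead of being recomputed, e.g.
+//
+//   xla::nb_class_ptr<PyDeviceList> dl = ...;
+//   auto k1 = PyDeviceList::DefaultMemoryKind(dl);  // computes and caches
+//   auto k2 = PyDeviceList::DefaultMemoryKind(dl);  // served from the cache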
+class PyDeviceList {
+ public:
+  PyDeviceList(xla::nb_class_ptr<xla::PyClient> py_client,
+               xla::ifrt::DeviceListRef device_list);
+  explicit PyDeviceList(nanobind::tuple py_device_assignment);
+  ~PyDeviceList();
+
+  PyDeviceList(const PyDeviceList&) = delete;
+  PyDeviceList(PyDeviceList&&) = delete;
+  PyDeviceList& operator=(const PyDeviceList&) = delete;
+  PyDeviceList& operator=(PyDeviceList&&) = delete;
+
+  static nanobind::handle type() {
+    static auto type = nanobind::type<PyDeviceList>();
+    return type;
+  }
+
+  // These two methods are safe to call from C++ without GIL.
+  xla::nb_class_ptr<xla::PyClient> py_client() const { return py_client_; }
+  absl::StatusOr<xla::ifrt::DeviceListRef> ifrt_device_list() const;
+
+  int Len() const;  // Requires the GIL in GIL mode.
+  nanobind::object GetItem(int index);  // Requires the GIL in GIL mode.
+
+  // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode.
+  static xla::nb_class_ptr<PyDeviceList> AddressableDeviceList(
+      xla::nb_class_ptr<PyDeviceList> self);
+
+  // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode.
+  static absl::StatusOr<nanobind::object> DefaultMemoryKind(
+      xla::nb_class_ptr<PyDeviceList> self);
+
+  // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode.
+  static absl::StatusOr<nanobind::tuple> MemoryKinds(
+      xla::nb_class_ptr<PyDeviceList> self);
+
+  // go/pywald-pybind-annotation BEGIN
+  // refs {
+  //   module_path: "third_party/py/jax/jaxlib/xla.cc"
+  //   module_arg {}
+  // }
+  // go/pywald-pybind-annotation END
+  static void Register(nanobind::module_& m);
+
+ private:
+  nanobind::tuple AsTuple() const;
+
+  // Methods below require GIL.
+  nanobind::object GetSlice(nanobind::slice slice);
+  nanobind::iterator Iter();
+
+  std::string Str();
+
+  nanobind::tuple Dump() const;
+
+  int64_t Hash();  // Mutates hash_, needs self lock.
+
+  static bool Equal(xla::nb_class_ptr<PyDeviceList> self,
+                    nanobind::handle other);
+  static bool NotEqual(xla::nb_class_ptr<PyDeviceList> self,
+                       nanobind::handle other);
+
+  // Finds the memory kind info from an addressable device. Requires the GIL
+  // or self lock.
+  void PopulateMemoryKindInfo();
+  // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_`
+  // instead of `ifrt_device_list_` to support duck-typed device objects.
+  // Requires the GIL or self lock.
+  void PopulateMemoryKindInfoForDuckTypedDevices();
+
+  // Requires the self lock or GIL is held.
+  bool IsFullyAddressable();
+
+  // Valid only if `device_list_` contains an `xla::ifrt::DeviceListRef` and
+  // it is non-empty.
+  xla::nb_class_ptr<xla::PyClient> py_client_;
+
+  // Either C++ `ifrt::DeviceList` or Python duck-type devices.
+  // TODO(hyeontaek): Remove support for Python duck-type devices once all
+  // JAX backends and tests are migrated to use an `xla::ifrt::Device` type
+  // for JAX devices.
+  // Immutable after constructor; no locking needed.
+  std::variant<xla::ifrt::DeviceListRef, nanobind::tuple> device_list_;
+
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<int64_t> hash_;
+  // TODO(hyeontaek): Make the following property cached within
+  // `xla::ifrt::DeviceList`.
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<bool> is_fully_addressable_;
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<xla::nb_class_ptr<PyDeviceList>> addressable_device_list_;
+
+  struct MemoryKindInfo {
+    nanobind::object default_memory_kind;
+    nanobind::tuple memory_kinds;
+  };
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<absl::StatusOr<MemoryKindInfo>> memory_kind_info_;
+};
+
+}  // namespace jax
+
+#endif  // JAXLIB_PY_DEVICE_LIST_H_
diff --git a/jaxlib/xla/py_executable.cc b/jaxlib/py_executable.cc
similarity index 98%
rename from jaxlib/xla/py_executable.cc
rename to jaxlib/py_executable.cc
index eaf5af34f883..16cd512a4007 100644
--- a/jaxlib/xla/py_executable.cc
+++ b/jaxlib/py_executable.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/py_executable.h"
+#include "jaxlib/py_executable.h"
 
 #include 
 
@@ -36,11 +36,11 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "nanobind/nanobind.h"
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "jaxlib/xla/py_array.h"
-#include "jaxlib/xla/py_client.h"
-#include "jaxlib/xla/py_device.h"
-#include "jaxlib/xla/traceback.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_array.h"
+#include "jaxlib/py_client.h"
+#include "jaxlib/py_device.h"
+#include "jaxlib/traceback.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/pjrt/pjrt_future.h"
 #include "xla/pjrt/pjrt_layout.h"
diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h
new file mode 100644
index 000000000000..b6a39c6968b8
--- /dev/null
+++ b/jaxlib/py_executable.h
@@ -0,0 +1,254 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_PY_EXECUTABLE_H_
+#define JAXLIB_PY_EXECUTABLE_H_
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/Support/Casting.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_array.h"
+#include "jaxlib/py_client.h"
+#include "jaxlib/traceback.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/pjrt/exceptions.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/pjrt_future.h"
+#include "xla/pjrt/pjrt_layout.h"
+#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/attribute_map.h"
+#include "xla/python/ifrt/executable.h"
+#include "xla/python/pjrt_ifrt/pjrt_executable.h"
+#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/platform/status.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+
+class PyToken {
+ public:
+  PyToken() = default;
+  explicit PyToken(PjRtFuture<> future) : future_(std::move(future)) {}
+
+  static PyToken ReadyPyToken() {
+    return PyToken(PjRtFuture<>(absl::OkStatus()));
+  }
+
+  absl::Status Await();
+
+ private:
+  PjRtFuture<> future_;
+};
+
+// PyShardedToken contains a PyToken for each device's execution.
+class PyShardedToken {
+ public:
+  // Default construction creates an always-ready token.
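+  //
+  // Editorial usage sketch (based only on the declarations in this class; the
+  // token value is hypothetical): a caller holding a token from a sharded
+  // execution can await one device's result,
+  //
+  //   PyShardedToken token = ...;
+  //   absl::Status s = token.GetPyToken(/*device_id=*/0).Await();
+  //
+  // or call token.Await() to block on every device at once.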
+  PyShardedToken() = default;
+  explicit PyShardedToken(std::vector<PjRtFuture<>> futures)
+      : futures_(std::move(futures)) {}
+
+  PyToken GetPyToken(int device_id) const {
+    if (futures_.empty()) return PyToken::ReadyPyToken();
+    return PyToken(futures_.at(device_id));
+  }
+
+  absl::Status Await();
+
+ private:
+  std::vector<PjRtFuture<>> futures_;
+};
+
+class PyExecuteResults {
+ public:
+  PyExecuteResults(const nb_class_ptr<PyClient>& client,
+                   std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays,
+                   int num_computations, PyShardedToken token,
+                   PjRtFuture<> result_status = PjRtFuture<>());
+
+  std::vector<std::vector<PyArray>> DisassembleIntoSingleDeviceArrays();
+
+  std::vector<std::vector<PyArray>> DisassemblePrefixIntoSingleDeviceArrays(
+      size_t n);
+
+  std::vector<nanobind::object> ConsumeWithHandlers(
+      std::vector<std::variant<const PyArrayResultHandler*, nanobind::object>>
+          out_handlers);
+
+  std::vector<tsl::RCReference<ifrt::Array>> Consume();
+
+  PyShardedToken ConsumeToken();
+
+  size_t Size() const {
+    CheckNotDisassembled();
+    return ifrt_arrays_.size();
+  }
+
+  void CheckNotDisassembled() const;
+
+ private:
+  bool is_exploded_ = false;
+  bool token_consumed_ = false;
+  nb_class_ptr<PyClient> client_;
+  std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays_;
+  int num_computations_;
+  PyShardedToken token_;
+  // Only set if the computation has tokens.
+  PjRtFuture<> result_status_;
+};
+
+using ExecuteShardedArg = std::variant<PyArray, std::vector<PyArray>>;
+
+// Python wrapper around PjRtExecutable. We use a wrapper class:
+// a) to keep the PyClient alive via a std::shared_ptr<>
+// b) to add Python-specific functionality.
+class PyLoadedExecutable {
+ public:
+  PyLoadedExecutable(
+      nb_class_ptr<PyClient> client,
+      std::shared_ptr<ifrt::LoadedExecutable> ifrt_loaded_executable,
+      std::optional<Traceback> traceback,
+      std::optional<std::string> fingerprint);
+  ~PyLoadedExecutable();
+
+  nb_class_ptr<PyClient> client() const { return client_; }
+  ifrt::LoadedExecutable* ifrt_loaded_executable() const {
+    return ifrt_loaded_executable_.get();
+  }
+
+  std::shared_ptr<ifrt::LoadedExecutable> shared_ifrt_loaded_executable() {
+    return ifrt_loaded_executable_;
+  }
+
+  std::vector<nb_class_ptr<PyDevice>> AddressableDevices() const;
+
+  int64_t SizeOfGeneratedCodeInBytes() const {
+    return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes();
+  }
+
+  absl::StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const {
+    nanobind::gil_scoped_release scope;
+    return ifrt_loaded_executable_->GetCompiledMemoryStats();
+  }
+
+  absl::StatusOr<ifrt::AttributeMap> GetCostAnalysis() const {
+    return ifrt_loaded_executable_->GetCostAnalysis();
+  }
+
+  void Delete() {
+    // TODO(hyeontaek): Return absl::Status.
+    TF_CHECK_OK(ifrt_loaded_executable_->Delete().Await());
+  }
+
+  bool is_deleted() { return ifrt_loaded_executable_->IsDeleted(); }
+
+  // Takes args indexed by argid then deviceid, transposes them, and passes to
+  // PjRtExecutable::Execute. The result is similarly transposed back into the
+  // argid,deviceid format.
+  // args is [num_args x num_devices].
+  absl::StatusOr<PyExecuteResults> ExecuteSharded(
+      std::vector<ExecuteShardedArg> args, bool with_tokens);
+
+  absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> HloModules() const;
+
+  absl::StatusOr<std::vector<std::vector<absl::string_view>>>
+  GetOutputMemoryKinds() const;
+
+  absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
+  GetParameterLayouts() const;
+
+  absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
+  GetOutputLayouts() const;
+
+  std::optional<std::vector<OpSharding>> GetParameterShardings() const;
+
+  std::optional<std::vector<OpSharding>> GetOutputShardings() const;
+
+  const std::optional<Traceback>& traceback() { return traceback_; }
+
+  ifrt::LoadedExecutable* ifrt_executable() const {
+    return ifrt_loaded_executable_.get();
+  }
+
+  // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable.
+  // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
+  std::shared_ptr<PjRtLoadedExecutable> shared_ptr_pjrt_executable() {
+    auto* exec = llvm::dyn_cast_or_null<ifrt::PjRtLoadedExecutable>(
+        ifrt_loaded_executable_.get());
+    if (exec == nullptr) {
+      throw XlaRuntimeError(
+          "This operation is implemented for a PjRt-compatible backend "
+          "only.");
+    }
+    return exec->shared_ptr_pjrt_loaded_executable();
+  }
+
+  // Returns a template of execute options to pass to
+  // `ifrt_executable()->Execute()`. Note that the caller may need to override
+  // some options such as `launch_id` that change at each execution.
+  const ifrt::ExecuteOptions& options() const { return options_; }
+
+  // Returns a unique launch ID to use for the next execution.
+  int64_t GetNextLaunchId();
+
+  const std::optional<std::string>& fingerprint() const { return fingerprint_; }
+
+  // Keep `obj` alive as long as PyLoadedExecutable.
+  void KeepAlive(nanobind::object obj);
+
+ private:
+  friend class PyClient;
+
+  nb_class_ptr<PyClient> client_;
+  std::shared_ptr<ifrt::LoadedExecutable> ifrt_loaded_executable_;
+  std::optional<Traceback> traceback_;
+
+  // Identical executables (i.e. representing the same program) will have the
+  // same fingerprint. nullopt on platforms or executables where fingerprints
+  // aren't implemented.
+  std::optional<std::string> fingerprint_;
+
+  // Launch ID to use for the next execution.
+  std::atomic<int64_t> next_launch_id_;
+
+  // The options to pass to `executable_.Execute`.
+  ifrt::ExecuteOptions options_;
+
+  // Python objects to keep alive as requested by user.
+  std::vector<nanobind::object> keepalives_;
+
+  // Doubly-linked list of all executables known to the client. Protected by the
+  // GIL.
+  PyLoadedExecutable* next_;
+  PyLoadedExecutable* prev_;
+};
+
+}  // namespace xla
+
+#endif  // JAXLIB_PY_EXECUTABLE_H_
diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/py_host_callback.cc
similarity index 98%
rename from jaxlib/xla/py_host_callback.cc
rename to jaxlib/py_host_callback.cc
index fdb40c04b517..49525db53ca5 100644
--- a/jaxlib/xla/py_host_callback.cc
+++ b/jaxlib/py_host_callback.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/py_host_callback.h"
+#include "jaxlib/py_host_callback.h"
 
 #include 
 #include 
 
@@ -30,9 +30,9 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 #include "nanobind/nanobind.h"
-#include "jaxlib/xla/callback.h"
-#include "jaxlib/xla/py_host_callback.pb.h"
-#include "jaxlib/xla/python_ref_manager.h"
+#include "jaxlib/callback.h"
+#include "jaxlib/py_host_callback.pb.h"
+#include "jaxlib/python_ref_manager.h"
 #include "xla/layout_util.h"
 #include "xla/pjrt/host_callback.h"
 #include "xla/python/ifrt/client.h"
diff --git a/jaxlib/xla/py_host_callback.h b/jaxlib/py_host_callback.h
similarity index 97%
rename from jaxlib/xla/py_host_callback.h
rename to jaxlib/py_host_callback.h
index 1a1402a4eee2..b98338988bfd 100644
--- a/jaxlib/xla/py_host_callback.h
+++ b/jaxlib/py_host_callback.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -#ifndef JAXLIB_XLA_PY_HOST_CALLBACK_H_ -#define JAXLIB_XLA_PY_HOST_CALLBACK_H_ +#ifndef JAXLIB_PY_HOST_CALLBACK_H_ +#define JAXLIB_PY_HOST_CALLBACK_H_ #include #include @@ -116,4 +116,4 @@ class PyHostSendAndRecvLoadedHostCallback final } // namespace xla -#endif // JAXLIB_XLA_PY_HOST_CALLBACK_H_ +#endif // JAXLIB_PY_HOST_CALLBACK_H_ diff --git a/jaxlib/xla/py_host_callback.proto b/jaxlib/py_host_callback.proto similarity index 100% rename from jaxlib/xla/py_host_callback.proto rename to jaxlib/py_host_callback.proto diff --git a/jaxlib/xla/py_memory_space.cc b/jaxlib/py_memory_space.cc similarity index 96% rename from jaxlib/xla/py_memory_space.cc rename to jaxlib/py_memory_space.cc index 0409861dd3b9..2c123942a92d 100644 --- a/jaxlib/xla/py_memory_space.cc +++ b/jaxlib/py_memory_space.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/py_memory_space.h" #include @@ -22,8 +22,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" #include "xla/python/ifrt/device.h" namespace nb = ::nanobind; diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/py_memory_space.h similarity index 90% rename from jaxlib/xla/py_memory_space.h rename to jaxlib/py_memory_space.h index f38038af4870..2196a6cd9f30 100644 --- a/jaxlib/xla/py_memory_space.h +++ b/jaxlib/py_memory_space.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_PY_MEMORY_SPACE_H_ -#define JAXLIB_XLA_PY_MEMORY_SPACE_H_ +#ifndef JAXLIB_PY_MEMORY_SPACE_H_ +#define JAXLIB_PY_MEMORY_SPACE_H_ #include #include "absl/strings/string_view.h" #include "nanobind/nanobind.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" #include "xla/python/ifrt/memory.h" namespace xla { @@ -62,4 +62,4 @@ class PyMemorySpace { } // namespace xla -#endif // JAXLIB_XLA_PY_MEMORY_SPACE_H_ +#endif // JAXLIB_PY_MEMORY_SPACE_H_ diff --git a/jaxlib/xla/py_program.cc b/jaxlib/py_program.cc similarity index 98% rename from jaxlib/xla/py_program.cc rename to jaxlib/py_program.cc index b3828f5372d9..d01df5e82b1b 100644 --- a/jaxlib/xla/py_program.cc +++ b/jaxlib/py_program.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/py_program.h" +#include "jaxlib/py_program.h" #include #include @@ -34,11 +34,11 @@ limitations under the License. 
#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_device.h" -#include "jaxlib/xla/py_device_list.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/sharding.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_executable.h" diff --git a/jaxlib/xla/py_program.h b/jaxlib/py_program.h similarity index 88% rename from jaxlib/xla/py_program.h rename to jaxlib/py_program.h index 9fd30eeeed2f..7772d740c41e 100644 --- a/jaxlib/xla/py_program.h +++ b/jaxlib/py_program.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_PY_PROGRAM_H_ -#define JAXLIB_XLA_PY_PROGRAM_H_ +#ifndef JAXLIB_PY_PROGRAM_H_ +#define JAXLIB_PY_PROGRAM_H_ #include "nanobind/nanobind.h" @@ -24,4 +24,4 @@ void BuildIfrtProgramsSubmodule(nanobind::module_& m); } // namespace xla -#endif // JAXLIB_XLA_PY_PROGRAM_H_ +#endif // JAXLIB_PY_PROGRAM_H_ diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc similarity index 98% rename from jaxlib/xla/py_socket_transfer.cc rename to jaxlib/py_socket_transfer.cc index 4aa40cf66087..a0bd943333ee 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/py_socket_transfer.h" +#include "jaxlib/py_socket_transfer.h" #include #include @@ -35,11 +35,11 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_array.h" -#include "jaxlib/xla/py_client.h" -#include "jaxlib/xla/to_ifrt_sharding.h" -#include "jaxlib/xla/traceback.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "jaxlib/traceback.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" diff --git a/jaxlib/xla/py_socket_transfer.h b/jaxlib/py_socket_transfer.h similarity index 83% rename from jaxlib/xla/py_socket_transfer.h rename to jaxlib/py_socket_transfer.h index fa477f24e3e5..1b0236b56889 100644 --- a/jaxlib/xla/py_socket_transfer.h +++ b/jaxlib/py_socket_transfer.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ -#define JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ +#ifndef JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ +#define JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ #include "nanobind/nanobind.h" @@ -23,4 +23,4 @@ void RegisterTransferServerTypes(nanobind::module_& m); } // namespace aux -#endif // JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ +#endif // JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ diff --git a/jaxlib/xla/py_values.cc b/jaxlib/py_values.cc similarity index 99% rename from jaxlib/xla/py_values.cc rename to jaxlib/py_values.cc index 2ad0d5849ba8..b14c1f22708b 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/py_values.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/py_values.h" +#include "jaxlib/py_values.h" #include @@ -44,10 +44,10 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/complex.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "jaxlib/xla/py_array.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/sharding.h" -#include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/py_array.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" diff --git a/jaxlib/xla/py_values.h b/jaxlib/py_values.h similarity index 98% rename from jaxlib/xla/py_values.h rename to jaxlib/py_values.h index afa59a839d5d..40b186fb7fc0 100644 --- a/jaxlib/xla/py_values.h +++ b/jaxlib/py_values.h @@ -15,8 +15,8 @@ limitations under the License. // Helpers for converting Python values into buffers. -#ifndef JAXLIB_XLA_PY_VALUES_H_ -#define JAXLIB_XLA_PY_VALUES_H_ +#ifndef JAXLIB_PY_VALUES_H_ +#define JAXLIB_PY_VALUES_H_ #include #include @@ -138,4 +138,4 @@ H AbslHashValue(H h, const xla::PyArgSignature& s) { } // namespace xla -#endif // JAXLIB_XLA_PY_VALUES_H_ +#endif // JAXLIB_PY_VALUES_H_ diff --git a/jaxlib/xla/python_ref_manager.cc b/jaxlib/python_ref_manager.cc similarity index 98% rename from jaxlib/xla/python_ref_manager.cc rename to jaxlib/python_ref_manager.cc index 5b85d2ab84cb..64bd0041b625 100644 --- a/jaxlib/xla/python_ref_manager.cc +++ b/jaxlib/python_ref_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/python_ref_manager.h" #include diff --git a/jaxlib/python_ref_manager.h b/jaxlib/python_ref_manager.h new file mode 100644 index 000000000000..37eae1cae84d --- /dev/null +++ b/jaxlib/python_ref_manager.h @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_PYTHON_REF_MANAGER_H_
+#define JAXLIB_PYTHON_REF_MANAGER_H_
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+
+namespace xla {
+
+// Class that manages destruction of Python objects.
+//
+// We must not destroy Python objects without holding the GIL. However, we
+// frequently want to hold references to Python objects for the duration of
+// an asynchronous transfer on a Stream, and release our reference when the
+// transfer completes.
+//
+// This class holds references to Python objects outside a GIL scope that can
+// be collected later when the GIL is held by calling CollectGarbage().
+class PythonRefManager {
+ public:
+  PythonRefManager() = default;
+
+  // Holds references to a set of nanobind::objects, adding the references to
+  // the PythonRefManager on destruction.
+  class ManagedPyObjects {
+   public:
+    ManagedPyObjects() = default;
+    ManagedPyObjects(PythonRefManager* manager,
+                     absl::Span<nanobind::object> objects);
+
+    ~ManagedPyObjects();
+
+    ManagedPyObjects(const ManagedPyObjects& other) = delete;
+    ManagedPyObjects(ManagedPyObjects&& other) = default;
+    ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete;
+    ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default;
+
+   private:
+    PythonRefManager* manager_ = nullptr;
+    absl::InlinedVector<nanobind::object, 1> objects_;
+  };
+
+  // Creates a managed std::shared_ptr to an object. When the shared_ptr is
+  // destroyed, the reference to 'object' will be added to python_garbage_,
+  // and collected next time CollectGarbage() is called.
+  std::shared_ptr<ManagedPyObjects> ManageReference(nanobind::object object);
+  std::shared_ptr<ManagedPyObjects> ManageReferences(
+      absl::Span<nanobind::object> objects);
+
+  // Adds garbage objects to the manager.
+  void AddGarbage(nanobind::object garbage);
+  void AddGarbage(absl::Span<nanobind::object> garbage);
+  void AddGarbage(absl::Span<std::pair<PyCodeObject*, int> const> garbage);
+
+  // Releases the contents of python_garbage_. Requires that the GIL is held.
+  // The client calls this method during API entry points where the GIL is held
+  // to free any garbage that has accumulated.
+  void CollectGarbage();
+
+  // Cheaper version of CollectGarbage() with relaxed consistency and frequency.
+  // The purpose of this function is to amortize lock acquisition costs over
+  // a larger number of API calls.
+  void MaybeCollectGarbage() {
+    if (garbage_count_.load(std::memory_order_relaxed) >= 100) {
+      CollectGarbage();
+    }
+  }
+
+ private:
+  absl::Mutex mu_;
+  std::deque<nanobind::object> python_garbage_ ABSL_GUARDED_BY(mu_);
+
+  // Writes to garbage_count_ are protected by mu_, reads are not protected.
+  std::atomic<int> garbage_count_{0};
+};
+
+// A global PythonRefManager. Unless `CollectGarbage()` is called before
+// shutdown, this container will hold on to Python objects and thus cause a
+// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`.
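+//
+// Editorial usage sketch (assumes only the declarations above; `py_buffer` is
+// a hypothetical nanobind::object): code that must keep a Python buffer alive
+// across an asynchronous transfer, without holding the GIL when the last
+// reference dies, can do
+//
+//   auto refs = GlobalPyRefManager()->ManageReference(std::move(py_buffer));
+//   // ... enqueue async work that captures `refs`; when the last copy of
+//   // `refs` is destroyed, the Python reference moves onto the garbage list.
+//
+// GIL-holding API entry points then periodically call
+// GlobalPyRefManager()->MaybeCollectGarbage() to actually release it.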
+PythonRefManager* GlobalPyRefManager(); + +} // namespace xla + +#endif // JAXLIB_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/xla/pytree.cc b/jaxlib/pytree.cc similarity index 99% rename from jaxlib/xla/pytree.cc rename to jaxlib/pytree.cc index 08feeec94f1c..2700ac9e6c9a 100644 --- a/jaxlib/xla/pytree.cc +++ b/jaxlib/pytree.cc @@ -16,7 +16,7 @@ limitations under the License. // Caution: this code uses exceptions. The exception use is local to the // binding code and the idiomatic way to emit Python exceptions. -#include "jaxlib/xla/pytree.h" +#include "jaxlib/pytree.h" #include @@ -49,8 +49,8 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/tuple.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/pytree.pb.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" #include "xla/pjrt/exceptions.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/pytree.h b/jaxlib/pytree.h new file mode 100644 index 000000000000..0a012d933c70 --- /dev/null +++ b/jaxlib/pytree.h @@ -0,0 +1,408 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PYTREE_H_ +#define JAXLIB_PYTREE_H_ + +// See https://docs.jax.dev/en/latest/pytrees.html for the documentation +// about pytree. + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" + +namespace xla { + +enum class PyTreeKind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + kDataclass, // A dataclass. +}; + +// Registry of custom node types. +class PyTreeRegistry { + public: + PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, + bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry&) = delete; + PyTreeRegistry(PyTreeRegistry&&) = delete; + PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; + PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; + + struct Registration { + PyTreeKind kind; + + // The following values are populated for custom types. + // The Python type object, used to identify the type. 
+    nanobind::object type;
+    // A function with signature: object -> (iterable, aux_data)
+    nanobind::callable to_iterable;
+    // A function with signature: (aux_data, iterable) -> object
+    nanobind::callable from_iterable;
+    // A function with signature: (aux_data, iterable(keypath, leaf)) -> object
+    std::optional<nanobind::callable> to_iterable_with_keys;
+
+    // Helper that calls to_iterable and validates that it returns a pair
+    // of an iterable and an aux_data object.
+    std::pair<nanobind::iterable, nanobind::object> ToIterable(
+        nanobind::handle o) const;
+    // Helper that calls to_iterable_with_keys and validates that it returns a
+    // pair of an iterable of key-leaf pairs and an aux_data object. If
+    // to_iterable_with_keys is not available, returns a dummy key for each
+    // leaf, similar to the current jax.tree_util.FlattenedIndexKey.
+    std::pair<std::vector<std::pair<nanobind::object, nanobind::object>>,
+              nanobind::object>
+    ToIterableWithKeys(nanobind::handle o) const;
+
+    // For dataclasses.
+    std::vector<nanobind::str> data_fields;
+    std::vector<nanobind::str> meta_fields;
+
+    int tp_traverse(visitproc visit, void* arg);
+  };
+
+  // Registers a new custom type. Objects of `type` will be treated as container
+  // node types in PyTrees.
+  void Register(
+      nanobind::object type, nanobind::callable to_iterable,
+      nanobind::callable from_iterable,
+      std::optional<nanobind::callable> to_iterable_with_keys = std::nullopt);
+  // Same, but for dataclasses.
+  void RegisterDataclass(nanobind::object type,
+                         std::vector<nanobind::str> data_fields,
+                         std::vector<nanobind::str> meta_fields);
+
+  // Finds the custom type registration for `type`. Returns nullptr if none
+  // exists.
+  const Registration* Lookup(nanobind::handle type) const;
+
+  PyTreeKind KindOfObject(nanobind::handle obj,
+                          PyTreeRegistry::Registration const** custom) const;
+
+  // Flattens a pytree one level, returning either a tuple of the leaves and
+  // the node data, or None, if the entry is a leaf.
+  nanobind::object FlattenOneLevel(nanobind::handle x) const;
+  // Similar to above but returns a key-leaf pair for each leaf.
+  nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const;
+  // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys.
+  nanobind::object FlattenOneLevelImpl(nanobind::handle x,
+                                       bool with_keys) const;
+
+  static PyType_Slot slots_[];
+
+ private:
+  struct TypeHash {
+    using is_transparent = void;
+    size_t operator()(const nanobind::object& t) const {
+      return absl::HashOf(t.ptr());
+    }
+    size_t operator()(const nanobind::handle& t) const {
+      return absl::HashOf(t.ptr());
+    }
+  };
+  struct TypeEq {
+    using is_transparent = void;
+    bool operator()(const nanobind::object& a,
+                    const nanobind::object& b) const {
+      return a.ptr() == b.ptr();
+    }
+    bool operator()(const nanobind::object& a,
+                    const nanobind::handle& b) const {
+      return a.ptr() == b.ptr();
+    }
+  };
+  mutable nanobind::ft_mutex mu_;
+  absl::flat_hash_map<nanobind::object, std::unique_ptr<Registration>,
+                      TypeHash, TypeEq>
+      registrations_;  // Guarded by mu_
+  bool enable_namedtuple_;
+
+  static int tp_traverse(PyObject* self, visitproc visit, void* arg);
+  static int tp_clear(PyObject* self);
+};
+
+class SequenceKey {
+ public:
+  explicit SequenceKey(int idx) : idx_(idx) {};
+  std::string ToReprString() const;
+  std::string ToString() const;
+  bool Equals(const nanobind::object& other);
+  int idx() const { return idx_; }
+  static nanobind::tuple MatchArgs(nanobind::handle unused);
+
+ private:
+  int idx_;
+};
+
+class DictKey {
+ public:
+  explicit DictKey(nanobind::object key) : key_(key) {};
+  std::string ToReprString() const;
+  std::string ToString() const;
+  bool Equals(const nanobind::object& other);
+  nanobind::object key() const { return key_; }
+  static nanobind::tuple MatchArgs(nanobind::handle unused);
+  static PyType_Slot slots_[];
+
+ private:
+  nanobind::object key_;
+  static int tp_traverse(PyObject* self, visitproc visit, void* arg);
+  static int tp_clear(PyObject* self);
+};
+
+class GetAttrKey {
+ public:
+  explicit GetAttrKey(nanobind::str name) : name_(name) {};
+  std::string ToReprString() const;
+  std::string ToString() const;
+  bool Equals(const nanobind::object& other);
+  nanobind::str name() const { return name_; }
+  static nanobind::tuple MatchArgs(nanobind::handle unused);
+
+ private:
+  nanobind::str name_;
+};
+
+class FlattenedIndexKey {
+ public:
+  explicit FlattenedIndexKey(int key) : key_(key) {};
+  std::string ToReprString() const;
+  std::string ToString() const;
+  bool Equals(const nanobind::object& other);
+  int key() const { return key_; }
+  static nanobind::tuple MatchArgs(nanobind::handle unused);
+
+ private:
+  int key_;
+};
+
+// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
+// Python values, where the interior nodes are tuples, lists, dictionaries, or
+// user-defined containers, and the leaves are other objects.
+class PyTreeDef {
+ public:
+  // Unowned registry: the registry must remain live at least as long as the
+  // PyTreeDef. It is the caller's responsibility to enforce this.
+  explicit PyTreeDef(PyTreeRegistry* registry) : registry_(registry) {}
+
+  explicit PyTreeDef(nb_class_ptr<PyTreeRegistry> registry)
+      : registry_(registry.get()), registry_ref_(std::move(registry)) {}
+
+  // Flattens a Pytree into a list of leaves and a PyTreeDef.
+  // Returns references to the flattened objects, which might be temporary
+  // objects in the case of custom pytype handlers.
+  static std::pair<std::vector<nanobind::object>, nb_class_ptr<PyTreeDef>>
+  Flatten(nanobind::handle x, nb_class_ptr<PyTreeRegistry> registry,
+          std::optional<nanobind::callable> leaf_predicate = std::nullopt);
+
+  // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this).
+  // `leaves` owns references to the flattened objects, which might be
+  // temporary objects in the case of custom pytype handlers.
+  void Flatten(nanobind::handle handle,
+               std::vector<nanobind::object>& leaves,
+               std::optional<nanobind::callable> leaf_predicate = std::nullopt);
+  void Flatten(nanobind::handle handle,
+               absl::InlinedVector<nanobind::object, 2>& leaves,
+               std::optional<nanobind::callable> leaf_predicate = std::nullopt);
+  void Flatten(nanobind::handle handle, nanobind::list& leaves,
+               std::optional<nanobind::callable> leaf_predicate = std::nullopt);
+
+  void FlattenWithPath(
+      nanobind::handle handle, nanobind::list& leaves,
+      std::optional<nanobind::callable> leaf_predicate = std::nullopt);
+
+  // Tests whether the given list is a flat list of leaves.
+  static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x);
+
+  // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
+  // the tree-structure of 'x'. For example, if we flatten a value
+  // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
+  // list of leaves [1, (2, 3), {"foo": 4}].
+  nanobind::list FlattenUpTo(nanobind::handle x) const;
+
+  // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
+  nanobind::object Unflatten(nanobind::iterable leaves) const;
+  nanobind::object Unflatten(absl::Span<nanobind::object const> leaves) const;
+
+  // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
+  // `inner`. The returned PyTreeDef holds a reference to its registry.
+  nb_class_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;
+
+  // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
+  static nb_class_ptr<PyTreeDef> Tuple(nb_class_ptr<PyTreeRegistry> registry,
+                                       nanobind::list defs);
+
+  // The returned PyTreeDefs hold a reference to the registry.
+  std::vector<nb_class_ptr<PyTreeDef>> Children() const;
+
+  // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
+  // f_node(node, node_data) to each container node.
+  nanobind::object Walk(const nanobind::callable& f_node,
+                        nanobind::handle f_leaf,
+                        nanobind::iterable leaves) const;
+
+  // Given a tree of iterables with the same node/leaf structure as this PyTree,
+  // build the corresponding PyTree.
+  // TODO(phawkins): use flattening everywhere instead and delete this method.
+  nanobind::object FromIterableTree(nanobind::handle xs) const;
+
+  int num_leaves() const {
+    if (traversal_.empty()) {
+      return 0;
+    }
+    return traversal_.back().num_leaves;
+  }
+
+  int num_nodes() const { return traversal_.size(); }
+
+  PyTreeRegistry* registry() const { return registry_; }
+
+  size_t Hash() const;
+
+  bool operator==(const PyTreeDef& other) const;
+  bool operator!=(const PyTreeDef& other) const { return !(*this == other); }
+
+  std::string ToString() const;
+
+  // Transforms the PyTreeDef into a pickleable object. Used to implement
+  // `PyTreeDef.__getstate__`.
+  nanobind::object ToPickle() const;
+
+  // Transforms the object returned by `ToPickle()` back to PyTreeDef. Used
+  // to implement `PyTreeDef.__setstate__`.
+  void FromPickle(nanobind::object pickleable);
+
+  void SerializeTo(jax::PyTreeDefProto& result) const;
+
+  static nb_class_ptr<PyTreeDef> DeserializeFrom(
+      nb_class_ptr<PyTreeRegistry> registry, const jax::PyTreeDefProto& input);
+
+  std::optional<std::pair<nanobind::type_object, nanobind::object>>
+  GetNodeData() const;
+
+  static nb_class_ptr<PyTreeDef> MakeFromNodeDataAndChildren(
+      nb_class_ptr<PyTreeRegistry> registry,
+      std::optional<std::pair<nanobind::type_object, nanobind::object>>
+          node_data,
+      nanobind::iterable children);
+
+  static PyType_Slot slots_[];
+
+ private:
+  void SetNumLeavesAndNumNodes();
+
+  struct Node {
+    PyTreeKind kind = PyTreeKind::kLeaf;
+
+    // Arity for non-kLeaf types.
+    int arity = 0;
+
+    // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
+    // object. For a kDict, use the `sorted_dict_keys` field below. For a
+    // kCustom type, contains the auxiliary data returned by the `to_iterable`
+    // function.
+    nanobind::object node_data;
+
+    // Kind-specific auxiliary data specialized for kDict. Use a C++ vector
+    // to hold the sorted dict keys instead of a py::list to avoid creating
+    // a new Python list object when flattening a kDict. For deeply nested
+    // dicts, using a C++ vector instead of a py::list also avoids creating
+    // many Python objects, which would make Python GC sweeps slow.
+    std::vector<nanobind::object> sorted_dict_keys;
+
+    // Custom type registration. Must be null for non-custom types.
+    const PyTreeRegistry::Registration* custom = nullptr;
+
+    // Number of leaf nodes in the subtree rooted at this node.
+    int num_leaves = 0;
+
+    // Number of leaf and interior nodes in the subtree rooted at this node.
+    int num_nodes = 0;
+
+    int tp_traverse(visitproc visit, void* arg) const;
+  };
+  template <typename H>
+  friend H AbslHashValue(H h, const Node& n);
+
+  template <typename H>
+  friend H AbslHashValue(H h, const PyTreeDef& t);
+
+  // Helper that manufactures an instance of a node given its children.
+  static nanobind::object MakeNode(const Node& node,
+                                   absl::Span<nanobind::object> children);
+
+  // Recursive helper used to implement FromIterableTree()
+  nanobind::object FromIterableTreeHelper(
+      nanobind::handle xs,
+      absl::InlinedVector<Node, 1>::const_reverse_iterator* it) const;
+
+  template <typename T>
+  void FlattenImpl(nanobind::handle handle, T& leaves,
+                   const std::optional<nanobind::callable>& leaf_predicate,
+                   std::optional<std::vector<nanobind::object>>& keypath);
+
+  template <typename T>
+  nanobind::object UnflattenImpl(T leaves) const;
+
+  static int tp_traverse(PyObject* self, visitproc visit, void* arg);
+  static int tp_clear(PyObject* self);
+
+  // Pytree registry. Not owned.
+  PyTreeRegistry* registry_;
+  // If this class holds a reference to `registry`, it is held by
+  // `registry_ref_`.
+  nb_class_ptr<PyTreeRegistry> registry_ref_;
+
+  // Nodes, in a post-order traversal. We use an ordered traversal to minimize
+  // allocations, and post-order corresponds to the order we need to rebuild the
+  // tree structure.
+  absl::InlinedVector<Node, 1> traversal_;
+};
+
+template <typename H>
+H AbslHashValue(H h, const PyTreeDef::Node& n) {
+  h = H::combine(std::move(h), n.kind, n.arity, n.custom);
+  return h;
+}
+
+template <typename H>
+H AbslHashValue(H h, const PyTreeDef& t) {
+  h = H::combine(std::move(h), t.traversal_);
+  return h;
+}
+
+void BuildPytreeSubmodule(nanobind::module_& m);
+
+}  // namespace xla
+
+#endif  // JAXLIB_PYTREE_H_
diff --git a/jaxlib/xla/pytree.proto b/jaxlib/pytree.proto
similarity index 100%
rename from jaxlib/xla/pytree.proto
rename to jaxlib/pytree.proto
diff --git a/jaxlib/xla/pytree_test.py b/jaxlib/pytree_test.py
similarity index 99%
rename from jaxlib/xla/pytree_test.py
rename to jaxlib/pytree_test.py
index b5ac7dd5b4d2..a8846a91ea2b 100644
--- a/jaxlib/xla/pytree_test.py
+++ b/jaxlib/pytree_test.py
@@ -18,7 +18,7 @@
 
 from absl.testing import absltest
 
-from jax.jaxlib.xla import xla_client
+from jax.jaxlib import xla_client
 
 pytree = xla_client._xla.pytree
 
diff --git a/jaxlib/xla/sdy.cc b/jaxlib/sdy.cc
similarity index 99%
rename from jaxlib/xla/sdy.cc
rename to jaxlib/sdy.cc
index c6d1145517d8..ed908c28acd8 100644
--- a/jaxlib/xla/sdy.cc
+++ b/jaxlib/sdy.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -#include "jaxlib/xla/sdy.h" +#include "jaxlib/sdy.h" #include #include diff --git a/jaxlib/xla/sdy.h b/jaxlib/sdy.h similarity index 90% rename from jaxlib/xla/sdy.h rename to jaxlib/sdy.h index ef075855decd..60ce012738fb 100644 --- a/jaxlib/xla/sdy.h +++ b/jaxlib/sdy.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_SDY_H_ -#define JAXLIB_XLA_SDY_H_ +#ifndef JAXLIB_SDY_H_ +#define JAXLIB_SDY_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildSdySubmodule(nanobind::module_& m); } // namespace xla -#endif // JAXLIB_XLA_SDY_H_ +#endif // JAXLIB_SDY_H_ diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 5bd010525c96..8d7933953851 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -58,7 +58,7 @@ def has_ext_modules(self): long_description_content_type='text/markdown', author='JAX team', author_email='jax-dev@google.com', - packages=['jaxlib', 'jaxlib.xla_extension'], + packages=['jaxlib'], python_requires='>=3.10', install_requires=[ 'scipy>=1.11.1', @@ -107,7 +107,6 @@ def has_ext_modules(self): 'triton/*.so', 'include/xla/ffi/api/*.h', ], - 'jaxlib.xla_extension': ['*.pyi'], }, zip_safe=False, distclass=BinaryDistribution, diff --git a/jaxlib/xla/sharded_device_array.h b/jaxlib/sharded_device_array.h similarity index 98% rename from jaxlib/xla/sharded_device_array.h rename to jaxlib/sharded_device_array.h index b0b5597f9d41..97fb8702cae5 100644 --- a/jaxlib/xla/sharded_device_array.h +++ b/jaxlib/sharded_device_array.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ -#define JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ +#ifndef JAXLIB_SHARDED_DEVICE_ARRAY_H_ +#define JAXLIB_SHARDED_DEVICE_ARRAY_H_ #include #include @@ -213,4 +213,4 @@ H AbslHashValue(H h, const ShardingSpec& key) { } // namespace jax -#endif // JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ +#endif // JAXLIB_SHARDED_DEVICE_ARRAY_H_ diff --git a/jaxlib/xla/sharding.cc b/jaxlib/sharding.cc similarity index 98% rename from jaxlib/xla/sharding.cc rename to jaxlib/sharding.cc index 858cb677e10a..2d8a88a6509d 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/sharding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/sharding.h" +#include "jaxlib/sharding.h" #include @@ -33,10 +33,10 @@ limitations under the License. 
#include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" -#include "jaxlib/xla/py_device_list.h" -#include "jaxlib/xla/sharded_device_array.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device_list.h" diff --git a/jaxlib/sharding.h b/jaxlib/sharding.h new file mode 100644 index 000000000000..cb7c1b471a63 --- /dev/null +++ b/jaxlib/sharding.h @@ -0,0 +1,241 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_SHARDING_H_ +#define JAXLIB_SHARDING_H_ + +#include + +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class Sharding { + public: + Sharding() = default; + + // This constructor is used in the fast path to retrieve the number of devices + // without falling back to python. This is only used in the cpp path. + explicit Sharding(int num_devices) : num_devices_(num_devices) {} + + virtual ~Sharding() = default; + + static int SafeNumDevices(nanobind::handle sharding); + + private: + std::optional num_devices_; +}; + +// Gets `jax::PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nanobind::handle sharding); + +// Checks if the memory kind is valid, and canonicalizes the +// memory kind to default memory on backends that support memories. +nanobind::object CheckAndCanonicalizeMemoryKind( + nanobind::object memory_kind, + const xla::nb_class_ptr& device_list); + +// Returns a hash that may sometimes return different hashes for equal values. +// It is not a correct implementation of `__hash__` in python, but it's fine +// for jit/pjit dispatch since it only causes spurious cache misses. 
diff --git a/jaxlib/sharding.h b/jaxlib/sharding.h
new file mode 100644
index 000000000000..cb7c1b471a63
--- /dev/null
+++ b/jaxlib/sharding.h
@@ -0,0 +1,241 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_SHARDING_H_
+#define JAXLIB_SHARDING_H_
+
+#include <Python.h>
+
+#include <cstddef>
+#include <optional>
+#include <utility>
+
+// placeholder for index annotation headers
+#include "absl/hash/hash.h"
+#include "absl/status/statusor.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/nb_class_ptr.h"
+#include "jaxlib/py_client.h"
+#include "jaxlib/py_device_list.h"
+#include "jaxlib/sharded_device_array.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/pjrt/status_casters.h"
+#include "xla/python/ifrt/device_list.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace jax {
+
+class Sharding {
+ public:
+  Sharding() = default;
+
+  // This constructor is used in the fast path to retrieve the number of devices
+  // without falling back to python. This is only used in the cpp path.
+  explicit Sharding(int num_devices) : num_devices_(num_devices) {}
+
+  virtual ~Sharding() = default;
+
+  static int SafeNumDevices(nanobind::handle sharding);
+
+ private:
+  std::optional<int> num_devices_;
+};
+
+// Gets `jax::PyDeviceList` from a JAX Sharding.
+absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> GetPyDeviceList(
+    nanobind::handle sharding);
+
+// Checks if the memory kind is valid, and canonicalizes the
+// memory kind to default memory on backends that support memories.
+nanobind::object CheckAndCanonicalizeMemoryKind(
+    nanobind::object memory_kind,
+    const xla::nb_class_ptr<PyDeviceList>& device_list);
+
+// Returns a hash that may sometimes return different hashes for equal values.
+// It is not a correct implementation of `__hash__` in python, but it's fine
+// for jit/pjit dispatch since it only causes spurious cache misses.
+size_t ShardingHash(nanobind::handle sharding);
+
+bool ShardingEqual(nanobind::handle a, nanobind::handle b);
+
+class NamedSharding : public Sharding {
+ public:
+  NamedSharding(nanobind::object mesh, nanobind::object spec,
+                nanobind::object memory_kind,
+                nanobind::object logical_device_ids);
+
+  const nanobind::object& mesh() const { return mesh_; }
+  const nanobind::object& spec() const { return spec_; }
+  const nanobind::object& memory_kind() const { return memory_kind_; }
+  const nanobind::object& logical_device_ids() const {
+    return logical_device_ids_;
+  }
+
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
+
+  absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list() const {
+    if (internal_device_list_) {
+      return *internal_device_list_;
+    }
+    return xla::InvalidArgument(
+        "internal_device_list is not implemented for "
+        "`jax.sharding.AbstractMesh`");
+  }
+
+ private:
+  nanobind::object mesh_;
+  nanobind::object spec_;
+  nanobind::object memory_kind_;
+  nanobind::object logical_device_ids_;
+  std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
+  static PyObject* type_;
+};
+
+class SingleDeviceSharding : public Sharding {
+ public:
+  explicit SingleDeviceSharding(
+      nanobind::object device, nanobind::object memory_kind = nanobind::none());
+
+  // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`.
+  SingleDeviceSharding(xla::nb_class_ptr<xla::PyClient> client,
+                       xla::ifrt::DeviceListRef device_list,
+                       nanobind::object memory_kind);
+
+  const nanobind::object& device() const { return device_; }
+  const nanobind::object& memory_kind() const { return memory_kind_; }
+
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
+
+  xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
+    return internal_device_list_;
+  }
+
+ private:
+  nanobind::object device_;
+  nanobind::object memory_kind_;
+  xla::nb_class_ptr<PyDeviceList> internal_device_list_;
+
+  static PyObject* type_;
+};
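+
+// All of the concrete shardings in this header cache an internal_device_list
+// (a jax::PyDeviceList; for NamedSharding it may be absent when the mesh is a
+// jax.sharding.AbstractMesh), so dispatch-time queries such as SafeNumDevices
+// and ShardingHash can be answered without calling back into Python.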
+
+// The C++ implementation of jax.PmapSharding in python. It contains a few key
+// data members and methods that are performance-critical.
+class PmapSharding : public Sharding {
+ public:
+  PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec);
+
+  ~PmapSharding() override = default;
+
+  xla::nb_numpy_ndarray devices() const { return devices_; }
+
+  const ShardingSpec& sharding_spec() const { return sharding_spec_; }
+
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
+
+  xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
+    return internal_device_list_;
+  }
+
+ private:
+  xla::nb_numpy_ndarray devices_;
+  ShardingSpec sharding_spec_;
+  xla::nb_class_ptr<PyDeviceList> internal_device_list_;
+  static PyObject* type_;
+};
+
+class GSPMDSharding : public Sharding {
+ public:
+  GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding,
+                nanobind::object memory_kind, nanobind::object device_list)
+      : GSPMDSharding(
+            std::move(devices),
+            xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)),
+            std::move(memory_kind), std::move(device_list)) {}
+
+  GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding,
+                nanobind::object memory_kind, nanobind::object device_list);
+
+  const nanobind::tuple& devices() const { return devices_; }
+  const nanobind::object& memory_kind() const { return memory_kind_; }
+
+  size_t Hash() {
+    if (!hash_.has_value()) {
+      hash_ = CalculateHash();
+    }
+    return *hash_;
+  }
+
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
+
+  const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
+
+  bool operator==(const GSPMDSharding& other) const {
+    return AreOpShardingsEqual(*this, other) &&
+           this->devices().equal(other.devices()) &&
+           this->memory_kind().equal(other.memory_kind());
+  }
+
+  xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
+    return internal_device_list_;
+  }
+
+ private:
+  size_t CalculateHash() const {
+    // We only hash `hlo_sharding_` here for performance.
+    return absl::Hash<xla::HloSharding>()(hlo_sharding_);
+  }
+
+  static bool AreOpShardingsEqual(const GSPMDSharding& a,
+                                  const GSPMDSharding& b) {
+    // If the OpSharding object is the same, return true
+    if (&a.hlo_sharding() == &b.hlo_sharding()) {
+      return true;
+    }
+    // If both OpShardings are replicated, return true
+    if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) {
+      return true;
+    }
+    return a.hlo_sharding() == b.hlo_sharding();
+  }
+
+  bool IsOpShardingReplicated() const {
+    // For JAX, shardings with 1 device are considered as replicated in its
+    // semantics so that downstream things continue to work.
+    if (hlo_sharding_.tile_assignment().num_elements() == 1) {
+      return true;
+    }
+    return hlo_sharding().IsReplicated();
+  }
+
+  nanobind::tuple devices_;
+  xla::HloSharding hlo_sharding_;
+  nanobind::object memory_kind_;
+  std::optional<size_t> hash_;
+  xla::nb_class_ptr<PyDeviceList> internal_device_list_;
+
+  static PyObject* type_;
+};
+
+void RegisterSharding(nanobind::module_& m);
+
+}  // namespace jax
+
+#endif  // JAXLIB_SHARDING_H_
diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/to_ifrt_sharding.cc
similarity index 97%
rename from jaxlib/xla/to_ifrt_sharding.cc
rename to jaxlib/to_ifrt_sharding.cc
index 52879cfa9fbe..f42b13ae4f1c 100644
--- a/jaxlib/xla/to_ifrt_sharding.cc
+++ b/jaxlib/to_ifrt_sharding.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -#include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/to_ifrt_sharding.h" #include #include @@ -24,9 +24,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_device_list.h" -#include "jaxlib/xla/sharding.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharding.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" diff --git a/jaxlib/xla/to_ifrt_sharding.h b/jaxlib/to_ifrt_sharding.h similarity index 94% rename from jaxlib/xla/to_ifrt_sharding.h rename to jaxlib/to_ifrt_sharding.h index ebc999888297..6d97f61330a0 100644 --- a/jaxlib/xla/to_ifrt_sharding.h +++ b/jaxlib/to_ifrt_sharding.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_TO_IFRT_SHARDING_H_ -#define JAXLIB_XLA_TO_IFRT_SHARDING_H_ +#ifndef JAXLIB_TO_IFRT_SHARDING_H_ +#define JAXLIB_TO_IFRT_SHARDING_H_ #include #include @@ -59,4 +59,4 @@ GetIfrtConcreteSharding(nanobind::handle sharding, } // namespace xla -#endif // JAXLIB_XLA_TO_IFRT_SHARDING_H_ +#endif // JAXLIB_TO_IFRT_SHARDING_H_ diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index c7d606eba458..216313303f3e 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -66,8 +66,8 @@ py_binary( "//jaxlib:README.md", "//jaxlib:jaxlib_binaries", "//jaxlib:setup.py", - "//jaxlib/xla:xla_client.py", - "//jaxlib/xla:xla_extension", + "//jaxlib:xla_client.py", + "//jaxlib:xla_extension", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", @@ -261,7 +261,7 @@ wheel_sources( data_srcs = [ "//jaxlib", "//jaxlib:jaxlib_binaries", - "//jaxlib/xla:xla_extension", + "//jaxlib:xla_extension", ], hdr_srcs = [ "@xla//xla/ffi/api:ffi", @@ -273,7 +273,7 @@ wheel_sources( "//jaxlib:README.md", "LICENSE.txt", "//jaxlib:setup.py", - "//jaxlib/xla:xla_client.py", + "//jaxlib:xla_client.py", ], symlink_data_srcs = [ "//jaxlib", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index b4b1ec72e8c4..af17fed38804 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -102,41 +102,6 @@ def patch_copy_mlir_import( f.write(replaced) -_XLA_EXTENSION_STUBS = [ - "__init__.pyi", - "guard_lib.pyi", - "ifrt_programs.pyi", - "ifrt_proxy.pyi", - "jax_jit.pyi", - "ops.pyi", - "pmap_lib.pyi", - "profiler.pyi", - "pytree.pyi", - "transfer_guard_lib.pyi", -] - - -def patch_copy_xla_extension_stubs( - dst_dir, runfiles=None, wheel_sources_map=None -): - xla_extension_dir = os.path.join(dst_dir, "xla_extension") - os.makedirs(xla_extension_dir) - for stub_name in _XLA_EXTENSION_STUBS: - stub_path = _get_file_path( - f"__main__/jaxlib/xla/xla_extension/{stub_name}", - runfiles, - wheel_sources_map, - ) - stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). - with open(stub_path) as f: - src = f.read() - src = src.replace( - "from xla.python import xla_extension", "from .. 
import xla_extension"
-        )
-        with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
-            f.write(src)
-
-
 def verify_mac_libraries_dont_reference_chkstack(
     runfiles=None, wheel_sources_map=None
 ):
@@ -240,7 +205,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
         f"{source_file_prefix}jaxlib/gpu_sparse.py",
         f"{source_file_prefix}jaxlib/plugin_support.py",
         f"{source_file_prefix}jaxlib/version.py",
-        f"{source_file_prefix}jaxlib/xla/xla_client.py",
+        f"{source_file_prefix}jaxlib/xla_client.py",
         f"{source_file_prefix}jaxlib/weakref_lru_cache.{pyext}",
         f"{source_file_prefix}jaxlib/weakref_lru_cache.pyi",
         f"{source_file_prefix}jaxlib/xla_extension.{pyext}",
@@ -250,9 +215,6 @@
     # type stubs.
     with open(jaxlib_dir / "py.typed", "w"):
         pass
-    patch_copy_xla_extension_stubs(
-        jaxlib_dir, runfiles=r, wheel_sources_map=wheel_sources_map
-    )
 
     copy_files(
         dst_dir=jaxlib_dir / "cpu",
diff --git a/jaxlib/xla/traceback.cc b/jaxlib/traceback.cc
similarity index 99%
rename from jaxlib/xla/traceback.cc
rename to jaxlib/traceback.cc
index 35085b3e32fa..3eba5288335c 100644
--- a/jaxlib/xla/traceback.cc
+++ b/jaxlib/traceback.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/traceback.h"
+#include "jaxlib/traceback.h"
 
 #include <Python.h>
@@ -34,7 +34,7 @@ limitations under the License.
 #include "nanobind/stl/string.h"  // IWYU pragma: keep
 #include "nanobind/stl/string_view.h"  // IWYU pragma: keep
 #include "nanobind/stl/vector.h"  // IWYU pragma: keep
-#include "jaxlib/xla/nb_class_ptr.h"
+#include "jaxlib/nb_class_ptr.h"
 #include "xla/pjrt/exceptions.h"
 #include "tsl/platform/platform.h"
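
The traceback.h added below stores each frame as a (code object, lasti) pair
and resolves line numbers lazily. Its Frame fields correspond to what
CPython's own frame introspection exposes, e.g. (illustrative, standard
library only):

    import sys

    f = sys._getframe()
    print(f.f_code.co_filename,     # Traceback::Frame::file_name
          f.f_code.co_name,         # Traceback::Frame::function_name
          f.f_code.co_firstlineno,  # Traceback::Frame::function_start_line
          f.f_lineno)               # Traceback::Frame::line_num
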
diff --git a/jaxlib/traceback.h b/jaxlib/traceback.h
new file mode 100644
index 000000000000..97699a7b3de9
--- /dev/null
+++ b/jaxlib/traceback.h
@@ -0,0 +1,109 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_TRACEBACK_H_
+#define JAXLIB_TRACEBACK_H_
+
+#include <Python.h>
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// placeholder for index annotation headers
+#include "absl/container/inlined_vector.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/nb_class_ptr.h"
+
+namespace xla {
+
+// Represents a Python traceback. This object is designed to be allocated on
+// the Python heap; creating or destroying a traceback requires the GIL.
+class Traceback {
+ public:
+  // Requires GIL. Creates a Traceback object that requires destructor to be
+  // invoked with GIL held as well.
+  static std::optional<nb_class_ptr<Traceback>> Get();
+
+  // Requires GIL.
+  static bool enabled() { return enabled_; }
+  // Requires GIL.
+  static void SetEnabled(bool enabled);
+
+  // Requires GIL. Don't call this directly, you're looking for Get().
+  Traceback();
+  // Requires GIL.
+  ~Traceback();
+
+  Traceback(const Traceback&) = delete;
+  Traceback(Traceback&& other) noexcept;
+  Traceback& operator=(const Traceback&) = delete;
+  Traceback& operator=(Traceback&&) = delete;
+
+  // Requires the GIL be held.
+  std::string ToString() const;
+
+  struct Frame {
+    nanobind::str file_name;
+    nanobind::str function_name;
+    int function_start_line;
+    int line_num;
+
+    std::string ToString() const;
+  };
+  std::vector<Frame> Frames() const;
+
+  const absl::InlinedVector<std::pair<nanobind::object, int>, 32>& raw_frames()
+      const {
+    return frames_;
+  }
+
+  // Returns the traceback as a fake Python Traceback object, suitable for
+  // using as an exception traceback.
+  nanobind::object AsPythonTraceback() const;
+
+  bool operator==(const Traceback& other) const {
+    return frames_ == other.frames_;
+  }
+  bool operator!=(const Traceback& other) const {
+    return frames_ != other.frames_;
+  }
+
+ private:
+  // Each frame is a pair of a code object and a "lasti" instruction location
+  // in bytes. The size of _Py_CODEUNIT has changed across different Python
+  // versions; the lasti value here has already been multiplied by
+  // sizeof(_Py_CODEUNIT) if needed and is suitable for passing to functions
+  // like PyCode_Addr2Line().
+  absl::InlinedVector<std::pair<nanobind::object, int>, 32> frames_;
+
+  // Protected by GIL.
+  static bool enabled_;
+};
+
+using nb_traceback = nb_class_ptr<Traceback>;
+
+template <typename H>
+H AbslHashValue(H h, const Traceback& traceback) {
+  h = H::combine(std::move(h), traceback.raw_frames());
+  return h;
+}
+
+void BuildTracebackSubmodule(nanobind::module_& m);
+
+}  // namespace xla
+
+#endif  // JAXLIB_TRACEBACK_H_
diff --git a/jaxlib/xla/util.cc b/jaxlib/util.cc
similarity index 98%
rename from jaxlib/xla/util.cc
rename to jaxlib/util.cc
index 5fb3f352ba2c..814886b9a4d3 100644
--- a/jaxlib/xla/util.cc
+++ b/jaxlib/util.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "jaxlib/xla/util.h"
+#include "jaxlib/util.h"
 
 #include
 #include
diff --git a/jaxlib/xla/util.h b/jaxlib/util.h
similarity index 93%
rename from jaxlib/xla/util.h
rename to jaxlib/util.h
index ed3b03d733dd..14848bb0ccf8 100644
--- a/jaxlib/xla/util.h
+++ b/jaxlib/util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef JAXLIB_XLA_UTIL_H_
-#define JAXLIB_XLA_UTIL_H_
+#ifndef JAXLIB_UTIL_H_
+#define JAXLIB_UTIL_H_
 
 #include "absl/status/status.h"
 #include "absl/types/span.h"
@@ -31,4 +31,4 @@ absl::Status AwaitBuffersReady(absl::Span<ifrt::Array* const> ifrt_arrays);
 
 }  // namespace xla
 
-#endif  // JAXLIB_XLA_UTIL_H_
+#endif  // JAXLIB_UTIL_H_
diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla.cc
similarity index 97%
rename from jaxlib/xla/xla.cc
rename to jaxlib/xla.cc
index 225d45f53b4b..d05d24295482 100644
--- a/jaxlib/xla/xla.cc
+++ b/jaxlib/xla.cc
@@ -47,10 +47,10 @@ limitations under the License.
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/ifrt_proxy.h" -#include "jaxlib/xla/py_client.h" -#include "jaxlib/xla/py_program.h" -#include "jaxlib/xla/sdy.h" +#include "jaxlib/ifrt_proxy.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_program.h" +#include "jaxlib/sdy.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/client.h" @@ -73,7 +73,7 @@ limitations under the License. #if defined(__linux__) #include "gloo/transport/tcp/attr.h" #include "gloo/transport/tcp/device.h" -#include "jaxlib/xla/py_socket_transfer.h" +#include "jaxlib/py_socket_transfer.h" #include "xla/backends/cpu/collectives/gloo_collectives.h" #include "xla/backends/cpu/collectives/gloo_kv_store.h" #elif defined(__APPLE__) @@ -86,26 +86,26 @@ limitations under the License. #include "xla/backends/cpu/collectives/mpi_collectives.h" #endif // !_WIN32 && !PLATFORM_GOOGLE -#include "jaxlib/xla/config.h" -#include "jaxlib/xla/custom_call_sharding.h" -#include "jaxlib/xla/dlpack.h" -#include "jaxlib/xla/guard_lib.h" -#include "jaxlib/xla/jax_jit.h" -#include "jaxlib/xla/mlir.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/pjit.h" -#include "jaxlib/xla/pmap_lib.h" -#include "jaxlib/xla/py_array.h" -#include "jaxlib/xla/py_compile_only_client.h" -#include "jaxlib/xla/py_device.h" -#include "jaxlib/xla/py_device_list.h" -#include "jaxlib/xla/py_executable.h" -#include "jaxlib/xla/py_memory_space.h" -#include "jaxlib/xla/python_ref_manager.h" -#include "jaxlib/xla/pytree.h" -#include "jaxlib/xla/sharding.h" -#include "jaxlib/xla/traceback.h" -#include "jaxlib/xla/xla_compiler.h" +#include "jaxlib/config.h" +#include "jaxlib/custom_call_sharding.h" +#include "jaxlib/dlpack.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/mlir.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pjit.h" +#include "jaxlib/pmap_lib.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_compile_only_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/xla_compiler.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_api.h" diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 8a19adc854ef..70460f1a4392 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -1,4 +1,4 @@ -# Copyright 2025 The JAX Authors. +# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,17 +14,9 @@ load( "//jaxlib:jax.bzl", - "cc_proto_library", - "if_oss", "jax_visibility", - "nanobind_extension", - "proto_library", - "py_deps", - "py_strict_library", - "py_strict_test", - "pytype_strict_library", + "pytype_library", ) -load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") licenses(["notice"]) @@ -38,1026 +30,94 @@ package_group( includes = [ "//jax:internal", ], -) - -nanobind_pywrap_extension( - name = "xla_extension", - srcs = ["xla.cc"], - pytype_deps = py_deps(["numpy"]), - pytype_srcs = glob(["xla_extension/*.pyi"]), - visibility = ["//visibility:public"], - deps = [ - ":config", - ":custom_call_sharding", - ":dlpack", - ":guard_lib", - ":ifrt_proxy", - ":jax_jit", - ":mlir", - ":nb_class_ptr", - ":pjit", - ":pmap_lib", - ":py_client", - ":python_ref_manager", - ":pytree", - ":sdy", - ":traceback", - ":util", - ":xla_compiler", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:initialize", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@nanobind", - "@tsl//tsl/platform", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:literal", - "@xla//xla:shape_util", - "@xla//xla:types", - "@xla//xla:util", - "@xla//xla/backends/cpu/collectives:cpu_collectives", - "@xla//xla/ffi:ffi_api", - "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:mlir_to_hlo", - "@xla//xla/pjrt:pjrt_api", - "@xla//xla/pjrt:pjrt_c_api_client", - "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:pjrt_common", - "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/pjrt:pjrt_executable", - "@xla//xla/pjrt:pjrt_layout", - "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt/c:pjrt_c_api_hdrs", - "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/distributed:client", - "@xla//xla/pjrt/distributed:key_value_store_interface", - "@xla//xla/pjrt/distributed:protocol_proto_cc", - "@xla//xla/pjrt/distributed:service", - "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", - "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", - "@xla//xla/python:logging", - "@xla//xla/python:nb_absl_flat_hash_map", - "@xla//xla/python:nb_absl_span", - "@xla//xla/python:ops", - "@xla//xla/python:pprof_profile_builder", - "@xla//xla/python:profiler", - "@xla//xla/python:refine_polymorphic_shapes", - "@xla//xla/python:types", - "@xla//xla/python:version", - "@xla//xla/python/ifrt", - "@xla//xla/python/ifrt:plugin_program", - "@xla//xla/python/ifrt:plugin_program_serdes", - "@xla//xla/python/pjrt_ifrt", - "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", - "@xla//xla/python/pjrt_ifrt:xla_ifrt", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:status", - "@xla//xla/tsl/platform:statusor", - "@xla//xla/tsl/platform/cloud:gcs_file_system", - "@xla//xla/tsl/python/lib/core:numpy", - ] + select({ - # gloo tcp transport only builds on linux - "@xla//xla/tsl:macos": [ - "@gloo//:transport_uv", - "@xla//xla/backends/cpu/collectives:gloo_collectives", - "@xla//xla/backends/cpu/collectives:gloo_kv_store", - ], - "@xla//xla/tsl:windows": [], - 
"//conditions:default": [ - ":py_socket_transfer", - "@gloo//:transport_tcp", - "@xla//xla/backends/cpu/collectives:gloo_collectives", - "@xla//xla/backends/cpu/collectives:gloo_kv_store", - ], - }) + select({ - # mpitrampoline does not build on windows - "@xla//xla/tsl:windows": [], - # we support MPI collectives only in OSS builds - "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]), - }), -) - -cc_library( - name = "callback", - srcs = [ - "callback.cc", - ], - hdrs = [ - "callback.h", - ], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":python_ref_manager", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@nanobind", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:comparison_util", - "@xla//xla:xla_data_proto_cc", - "@xla//xla/pjrt:host_callback", - "@xla//xla/pjrt:transpose", - "@xla//xla/python:nb_numpy", - "@xla//xla/tsl/platform:statusor", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - -cc_library( - name = "config", - srcs = ["config.cc"], - hdrs = ["config.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":python_ref_manager", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@nanobind", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla/tsl/platform:logging", - ], -) - -cc_library( - name = "custom_call_sharding", - srcs = ["custom_call_sharding.cc"], - hdrs = ["custom_call_sharding.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@nanobind", - "@xla//third_party/python_runtime:headers", - "@xla//xla:shape_util", - "@xla//xla:util", - "@xla//xla/hlo/ir:hlo", - "@xla//xla/hlo/utils:hlo_sharding_util", - "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_helpers", - "@xla//xla/python:custom_call_batch_partitioner", - "@xla//xla/python:custom_partition_callback", - "@xla//xla/python:debug_callback_partitioner", - "@xla//xla/python:inspect_sharding", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:statusor", - ], -) - -cc_library( - name = "dlpack", - srcs = ["dlpack.cc"], - hdrs = ["dlpack.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":nb_class_ptr", - ":py_client", - ":python_ref_manager", - ":traceback", - ":util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@dlpack", - "@llvm-project//llvm:Support", - "@nanobind", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:shape_util", - "@xla//xla:status_macros", - "@xla//xla:util", - 
"@xla//xla:xla_data_proto_cc", - "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:pjrt_common", - "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/python:types", - "@xla//xla/python/ifrt", - "@xla//xla/python/pjrt_ifrt", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:statusor", - ], -) - -cc_library( - name = "guard_lib", - srcs = ["guard_lib.cc"], - hdrs = ["guard_lib.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@nanobind", - "@xla//xla:util", - ], -) - -cc_library( - name = "ifrt_proxy", - srcs = ["ifrt_proxy.cc"], - hdrs = ["ifrt_proxy.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":nb_class_ptr", - ":py_client", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log:log_entry", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/time", - "@nanobind", - "@xla//xla/pjrt:status_casters", - "@xla//xla/python/ifrt", - "@xla//xla/python/ifrt:attribute_map", - "@xla//xla/python/ifrt_proxy/client:grpc_client", - "@xla//xla/python/ifrt_proxy/client:registry", - "@xla//xla/tsl/platform:env", - "@xla//xla/tsl/platform:statusor", - ], -) - -cc_library( - name = "jax_jit", - srcs = ["jax_jit.cc"], - hdrs = ["jax_jit.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":py_client", - ":python_ref_manager", - ":pytree", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@nanobind", - "@tsl//tsl/profiler/lib:traceme", - "@xla//third_party/python_runtime:headers", # build_cleaner: keep - "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:pjrt_layout", - "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_absl_inlined_vector", - "@xla//xla/python:nb_absl_span", - "@xla//xla/python:types", - "@xla//xla/tsl/platform:logging", - ], -) - -cc_library( - name = "mlir", - srcs = ["mlir.cc"], - hdrs = ["mlir.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:BytecodeWriter", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - "@llvm-project//mlir:Support", - "@nanobind", - "@stablehlo//:stablehlo_serialization", - "@xla//xla/hlo/builder:xla_computation", - "@xla//xla/hlo/translate:stablehlo", - "@xla//xla/mlir_hlo:mhlo_passes", - "@xla//xla/pjrt:mlir_to_hlo", - "@xla//xla/pjrt:status_casters", - "@xla//xla/python:refine_polymorphic_shapes", - 
"@xla//xla/service:hlo_proto_cc", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:statusor", - ], -) - -cc_library( - name = "nb_class_ptr", - hdrs = ["nb_class_ptr.h"], - copts = ["-fexceptions"], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/xla/nb_class_ptr"), - deps = ["@nanobind"], -) - -cc_library( - name = "pjit", - srcs = ["pjit.cc"], - hdrs = ["pjit.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":config", - ":guard_lib", - ":jax_jit", - ":nb_class_ptr", - ":py_client", - ":python_ref_manager", - ":pytree", - ":traceback", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@nanobind", - "@tsl//tsl/profiler/lib:traceme", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:shape_util", - "@xla//xla:util", - "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:lru_cache", - "@xla//xla/python:nb_helpers", - "@xla//xla/python:nb_numpy", - "@xla//xla/python/ifrt", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/platform:env", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:statusor", - ], -) - -cc_library( - name = "pmap_lib", - srcs = ["pmap_lib.cc"], - hdrs = ["pmap_lib.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":config", - ":jax_jit", - ":nb_class_ptr", - ":py_client", - ":python_ref_manager", - ":pytree", - ":traceback", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@nanobind", - "@tsl//tsl/profiler/lib:traceme", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:status_macros", - "@xla//xla:util", - "@xla//xla:xla_data_proto_cc", - "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_helpers", - "@xla//xla/python:nb_numpy", - "@xla//xla/python:types", - "@xla//xla/python/ifrt", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/platform:env", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:statusor", - "@xla//xla/tsl/python/lib/core:numpy", - ], + packages = ["@xla//xla/python/..."], ) cc_library( name = "py_client", - srcs = [ - "py_array.cc", - "py_client.cc", - "py_compile_only_client.cc", - "py_device.cc", - "py_device_list.cc", - "py_executable.cc", - "py_memory_space.cc", - "py_program.cc", - "py_values.cc", - "sharding.cc", - "to_ifrt_sharding.cc", - ], hdrs = [ "py_array.h", "py_client.h", - "py_compile_only_client.h", "py_device.h", "py_device_list.h", "py_executable.h", - "py_memory_space.h", - "py_program.h", - "py_values.h", - "sharded_device_array.h", 
"sharding.h", - "to_ifrt_sharding.h", - ], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/xla/py_client"), - deps = [ - ":guard_lib", - ":nb_class_ptr", - ":py_client_cpu", - ":py_host_callback", - ":python_ref_manager", - ":traceback", - ":util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@nanobind", - "@tsl//tsl/platform:fingerprint", - "@tsl//tsl/platform:ml_dtypes", - "@tsl//tsl/profiler/lib:traceme", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:literal", - "@xla//xla:shape_util", - "@xla//xla:status_macros", - "@xla//xla:types", - "@xla//xla:util", - "@xla//xla:xla_data_proto_cc", - "@xla//xla/hlo/ir:hlo", - "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:lru_cache", - "@xla//xla/pjrt:mlir_to_hlo", - "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/pjrt:pjrt_executable", - "@xla//xla/pjrt:pjrt_future", - "@xla//xla/pjrt:pjrt_layout", - "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_helpers", - "@xla//xla/python:nb_numpy", - "@xla//xla/python:pprof_profile_builder", - "@xla//xla/python:types", - "@xla//xla/python/compile_only_ifrt:client", - "@xla//xla/python/ifrt", - "@xla//xla/python/ifrt:attribute_map", - "@xla//xla/python/ifrt:custom_call_program", - "@xla//xla/python/ifrt:plugin_program", - "@xla//xla/python/ifrt:plugin_program_serdes", - "@xla//xla/python/ifrt:user_context", - "@xla//xla/python/ifrt/hlo:hlo_program", - "@xla//xla/python/pjrt_ifrt", - "@xla//xla/python/pjrt_ifrt:pjrt_dtype", - "@xla//xla/python/pjrt_ifrt:xla_ifrt", - "@xla//xla/service:platform_util", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/framework:allocator", - "@xla//xla/tsl/platform:env", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:status", - "@xla//xla/tsl/platform:statusor", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - -cc_library( - name = "py_client_cpu", - srcs = ["py_client_cpu.cc"], - hdrs = ["py_client_cpu.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@nanobind", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:shape_util", - "@xla//xla:xla_data_proto_cc", - "@xla//xla/ffi:ffi_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/pjrt:host_callback", - 
"@xla//xla/pjrt:transpose", - "@xla//xla/python:nb_numpy", - "@xla//xla/python:types", - ], - alwayslink = 1, -) - -cc_library( - name = "py_host_callback", - srcs = ["py_host_callback.cc"], - hdrs = ["py_host_callback.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":callback", - ":py_host_callback_cc_proto", - ":python_ref_manager", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@nanobind", - "@xla//xla:shape_util", - "@xla//xla:status_macros", - "@xla//xla:util", - "@xla//xla:xla_data_proto_cc", - "@xla//xla/pjrt:host_callback", - "@xla//xla/python:types", - "@xla//xla/python/ifrt", - "@xla//xla/python/pjrt_ifrt", - "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/platform:statusor", - ], -) - -proto_library( - name = "py_host_callback_proto", - srcs = ["py_host_callback.proto"], -) - -cc_proto_library( - name = "py_host_callback_cc_proto", - visibility = jax_visibility("jaxlib/xla/py_host_callback_cc_proto"), - deps = [":py_host_callback_proto"], -) - -cc_library( - name = "py_socket_transfer", - srcs = ["py_socket_transfer.cc"], - hdrs = ["py_socket_transfer.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", ], + copts = ["-fexceptions"], features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/py_client"), deps = [ - ":nb_class_ptr", - ":py_client", - ":traceback", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", - "@nanobind", - "@tsl//tsl/platform:casts", - "@xla//xla:util", - "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_numpy", - "@xla//xla/python:types", - "@xla//xla/python/ifrt", - "@xla//xla/python/pjrt_ifrt", - "@xla//xla/python/pjrt_ifrt:pjrt_dtype", - "@xla//xla/python/transfer:event_loop", - "@xla//xla/python/transfer:socket-server", - "@xla//xla/python/transfer:socket_bulk_transport", - "@xla//xla/python/transfer:streaming", - "@xla//xla/python/transfer:streaming_ifrt", - "@xla//xla/python/transfer:transfer_socket_proto_cc", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/platform:statusor", + "//jaxlib:py_client", ], ) -cc_library( - name = "python_ref_manager", - srcs = ["python_ref_manager.cc"], - hdrs = ["python_ref_manager.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/xla/python_ref_manager"), - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@nanobind", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - ], -) - -proto_library( - name = "pytree_proto", - srcs = ["pytree.proto"], -) - -cc_proto_library( - name = "pytree_cc_proto", - deps = [":pytree_proto"], -) - cc_library( name = "pytree", - srcs = ["pytree.cc"], - hdrs = ["pytree.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", + hdrs = [ + "pytree.h", ], + copts = 
["-fexceptions"], features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/xla/pytree"), + visibility = jax_visibility("jaxlib/pytree"), deps = [ - ":nb_class_ptr", - ":pytree_cc_proto", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@nanobind", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla/pjrt:exceptions", - "@xla//xla/tsl/platform:logging", + "//jaxlib:pytree", ], ) cc_library( - name = "sdy", - srcs = ["sdy.cc"], - hdrs = ["sdy.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", + name = "nb_class_ptr", + hdrs = [ + "nb_class_ptr.h", ], + copts = ["-fexceptions"], features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/nb_class_ptr"), deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:BytecodeWriter", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@nanobind", - "@shardy//shardy/dialect/sdy/ir:dialect", - "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "@xla//xla/mlir_hlo:all_passes", - "@xla//xla/pjrt:mlir_to_hlo", - "@xla//xla/pjrt:status_casters", - "@xla//xla/service/spmd/shardy:constants", - "@xla//xla/service/spmd/shardy:utils", - "@xla//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", - "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", - "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + "//jaxlib:nb_class_ptr", ], ) cc_library( name = "traceback", - srcs = ["traceback.cc"], - hdrs = ["traceback.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/xla/traceback"), - deps = [ - ":nb_class_ptr", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@nanobind", - "@tsl//tsl/platform", - "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla/pjrt:exceptions", - "@xla//xla/tsl/platform:logging", - ], -) - -cc_library( - name = "util", - srcs = ["util.cc"], - hdrs = ["util.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", + hdrs = [ + "traceback.h", ], + copts = ["-fexceptions"], features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/traceback"), deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@nanobind", - "@xla//xla:util", - "@xla//xla/pjrt:pjrt_future", - "@xla//xla/python:version", - "@xla//xla/python/ifrt", - "@xla//xla/tsl/concurrency:async_value", - "@xla//xla/tsl/concurrency:ref_count", + "//jaxlib:traceback", ], ) cc_library( - name = "xla_compiler", - srcs = ["xla_compiler.cc"], - hdrs = ["xla_compiler.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", + name = "python_ref_manager", + hdrs = [ + "python_ref_manager.h", ], + copts = ["-fexceptions"], features = 
["-use_header_modules"], + visibility = jax_visibility("jaxlib/python_ref_manager"), deps = [ - ":dlpack", - ":py_client", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@nanobind", - "@xla//xla:array", - "@xla//xla:debug_options_flags", - "@xla//xla:literal", - "@xla//xla:shape_util", - "@xla//xla:util", - "@xla//xla:xla_data_proto_cc", - "@xla//xla:xla_proto_cc", - "@xla//xla/client:executable_build_options", - "@xla//xla/ffi", - "@xla//xla/ffi:ffi_api", - "@xla//xla/ffi/api:c_api", - "@xla//xla/hlo/builder:xla_builder", - "@xla//xla/hlo/builder:xla_computation", - "@xla//xla/hlo/ir:hlo", - "@xla//xla/hlo/ir:hlo_module_group", - "@xla//xla/hlo/parser:hlo_parser", - "@xla//xla/hlo/pass:hlo_pass", - "@xla//xla/hlo/transforms/simplifiers:flatten_call_graph", - "@xla//xla/hlo/transforms/simplifiers:hlo_dce", - "@xla//xla/hlo/transforms/simplifiers:tuple_simplifier", - "@xla//xla/pjrt:compile_options_proto_cc", - "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:pjrt_executable", - "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_helpers", - "@xla//xla/python:nb_numpy", - "@xla//xla/python:types", - "@xla//xla/service:call_inliner", - "@xla//xla/service:computation_placer", - "@xla//xla/service:custom_call_target_registry", - "@xla//xla/service:hlo_graph_dumper", - "@xla//xla/service:hlo_module_config", - "@xla//xla/service:hlo_proto_cc", - "@xla//xla/service:name_uniquer", - "@xla//xla/tsl/lib/strings:proto_serialization", - "@xla//xla/tsl/platform:env", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:logging", - "@xla//xla/tsl/platform:statusor", + "//jaxlib:python_ref_manager", ], ) -pytype_strict_library( +pytype_library( name = "xla_client", srcs = ["xla_client.py"], - pytype_srcs = ["xla_client.pyi"], visibility = [":xla_python"], - deps = py_deps([ - "numpy", - "ml_dtypes", - ]) + [":xla_extension"], -) - -py_strict_test( - name = "xla_client_backend_independent_test", - srcs = ["xla_client_backend_independent_test.py"], deps = [ - ":xla_client", - ] + py_deps([ - "absl/testing", - "numpy", - "portpicker", - ]), + ":xla_extension", + "//jaxlib:xla_client", + ], ) -py_strict_library( - name = "xla_client_test", - testonly = 1, - srcs = ["xla_client_test.py"], +pytype_library( + name = "xla_extension", + srcs = ["xla_extension.py"], visibility = [":xla_python"], deps = [ - ":xla_client", - "//jax", - "//jax:test_util", - "//jaxlib", - ] + py_deps([ - "absl/flags", - "absl/logging", - "absl/testing", - "ml_dtypes", - "numpy", - ]), -) - -nanobind_extension( - name = "custom_calls_testlib", - testonly = 1, - srcs = ["custom_calls_testlib.cc"], - deps = [ - "@com_google_absl//absl/status", - "@nanobind", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", + "//jaxlib:xla_extension", ], ) - -py_strict_test( - name = "xla_client_test_cpu", - srcs = ["xla_client_test.py"], - args = ["--backend=cpu"], - env = { - "XLA_FLAGS": "--xla_force_host_platform_device_count=4", - }, - main = "xla_client_test.py", - deps = [ - ":custom_calls_testlib", - ":xla_client", - "//jax", - "//jax:test_util", - "//jaxlib", - ] + py_deps([ - "absl/flags", - "absl/logging", - "absl/testing", - "ml_dtypes", - "numpy", - ]), -) - -py_strict_test( - name = "pytree_test", - srcs = 
["pytree_test.py"], - deps = [ - ":xla_client", - ] + py_deps([ - "absl/flags", - "absl/logging", - "absl/testing", - ]), -) - -py_strict_test( - name = "config_test", - srcs = ["config_test.py"], - deps = [ - ":xla_client", - ] + py_deps([ - "absl/flags", - "absl/logging", - "absl/testing", - ]), -) - -py_strict_test( - name = "jax_jit_test", - srcs = ["jax_jit_test.py"], - deps = [ - ":xla_client", - ] + py_deps([ - "absl/flags", - "absl/logging", - "absl/testing", - "numpy", - ]), -) diff --git a/jaxlib/xla/nb_class_ptr.h b/jaxlib/xla/nb_class_ptr.h index e468860dc661..0b539115a1cb 100644 --- a/jaxlib/xla/nb_class_ptr.h +++ b/jaxlib/xla/nb_class_ptr.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,44 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_NB_CLASS_PTR_H_ #define JAXLIB_XLA_NB_CLASS_PTR_H_ -#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" // IWYU pragma: keep -namespace xla { - -// A reference-counting smart pointer to a nanobind-wrapped class on the Python -// heap. Type T must be a class known to nanobind via a nanobind::class_ -// declaration. nb_class_ptr is useful for managing C++ classes that may be -// allocated inline in Python objects on the Python heap. -template -class nb_class_ptr : public nanobind::object { - public: - inline nb_class_ptr() : nanobind::object() {} - inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) - : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} - inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) - : nanobind::object(h, ::nanobind::detail::steal_t{}) {} - inline static bool check_(nanobind::handle h) { - nanobind::handle type = nanobind::type(); - return h.type().is(type); - }; - - T* operator->() const { return nanobind::inst_ptr(ptr()); } - T& operator*() const { return *nanobind::inst_ptr(ptr()); } - T* get() const { return ptr() ? nanobind::inst_ptr(ptr()) : nullptr; } -}; - -// This function is analogous to std::make_unique(...), but instead it -// allocates the object on the Python heap -template -nb_class_ptr make_nb_class(Args&&... args) { - nanobind::handle type = nanobind::type(); - nanobind::object instance = nanobind::inst_alloc(type); - T* ptr = nanobind::inst_ptr(instance); - new (ptr) T(std::forward(args)...); - nanobind::inst_mark_ready(instance); - return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); -} - -} // namespace xla - -#endif // JAXLIB_XLA_NB_CLASS_PTR_H_ +#endif // JAXLIB_XLA_NB_CLASS_PTR_H_ diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h index 7c7a6fefe3a2..fee6d2e24f16 100644 --- a/jaxlib/xla/py_array.h +++ b/jaxlib/xla/py_array.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,350 +16,6 @@ limitations under the License. 
diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h
index 7c7a6fefe3a2..fee6d2e24f16 100644
--- a/jaxlib/xla/py_array.h
+++ b/jaxlib/xla/py_array.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The JAX Authors
+/* Copyright 2025 The JAX Authors
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,350 +16,6 @@ limitations under the License.
 
 #ifndef JAXLIB_XLA_PY_ARRAY_H_
 #define JAXLIB_XLA_PY_ARRAY_H_
 
-#include <Python.h>
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-// placeholder for index annotation headers
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/cord.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "llvm/Support/Casting.h"
-#include "nanobind/nanobind.h"
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "jaxlib/xla/py_client.h"
-#include "jaxlib/xla/traceback.h"
-#include "xla/pjrt/exceptions.h"
-#include "xla/pjrt/pjrt_client.h"
-#include "xla/pjrt/pjrt_future.h"
-#include "xla/pjrt/pjrt_layout.h"
-#include "xla/python/ifrt/array.h"
-#include "xla/python/ifrt/device_list.h"
-#include "xla/python/ifrt/future.h"
-#include "xla/python/nb_numpy.h"
-#include "xla/python/pjrt_ifrt/pjrt_array.h"
-#include "xla/shape.h"
-#include "xla/tsl/concurrency/ref_count.h"
-#include "xla/util.h"
-
-namespace xla {
-
-// Private to PyArray, but you cannot forward declare member classes.
-// Not thread safe; assumes the GIL is held.
-class PyHostValue {
- public:
-  PyHostValue();
-  ~PyHostValue();
-
-  PyHostValue(const PyHostValue&) = delete;
-  PyHostValue(PyHostValue&&) = delete;
-  PyHostValue& operator=(const PyHostValue&) = delete;
-  PyHostValue& operator=(PyHostValue&&) = delete;
-
-  absl::Status CopyToHostAsync(std::optional<Shape>& dynamic_shape_holder,
-                               ifrt::Array* ifrt_array);
-
-  absl::StatusOr<std::pair<nanobind::object, bool>> AsNumPyArray(
-      std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);
-
-  void Clear();
-
- private:
-  absl::Status CopyStringArrayToHostAsync(
-      std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);
-
-  absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array);
-
-  ifrt::Future<> ready_;
-  nb_numpy_ndarray value_;
-
-  // Optional field, only used for arrays of type kString. This vector of cords
-  // serves as input buffer for the CopyToHostBuffer call. It holds these
-  // contents until it is lazily converted it to a numpy array when the user
-  // calls `AsNumPyArray`.
-  std::shared_ptr<std::vector<absl::Cord>> string_array_contents_;
-};
-
-// Private to PyArray, but you cannot forward declare member classes.
-struct PyArray_Storage {
-  PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype,
-                  std::vector<int64_t> shape, nanobind::object sharding,
-                  bool committed, nb_class_ptr<PyClient> py_client,
-                  std::optional<nb_traceback> traceback,
-                  tsl::RCReference<ifrt::Array> ifrt_array,
-                  xla::PjRtFuture<> result_status);
-
-  ~PyArray_Storage();
-  nanobind::handle AsHandle();
-
-  nanobind::object aval;
-  bool weak_type = false;
-  nb_dtype dtype;
-  std::vector<int64_t> shape;
-
-  nanobind::object sharding;
-  nanobind::object npy_value = nanobind::none();
-  bool committed = false;
-
-  nb_class_ptr<PyClient> py_client;
-  std::optional<nb_traceback> traceback;
-  tsl::RCReference<ifrt::Array> ifrt_array;
-  nanobind::object fully_replicated_array = nanobind::none();
-
-  // optional field, used only in python
-  std::vector<PyArray> py_arrays;
-  PyHostValue host_value;  // Protected by the GIL.
-  std::optional<Shape> dynamic_shape = std::nullopt;
-  // Only set if this Array was generated by a computation that has effects.
-  // This is the result status of the XLA computation that generated this
-  // array.
-  xla::PjRtFuture<> result_status;
-
-  // Doubly-linked list of all PyArrays known to the client. Protected by the
-  // GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be
-  // duplicate PjRtBuffers in this list.
- PyArray_Storage* next; - PyArray_Storage* prev; - - uint8_t thread_id_bucket; -}; - -// The C++ implementation of jax.Array. A few key methods and data members are -// implemented in C++ for performance, while most of the functionalities are -// still implemented in python. -class PyArray : public nanobind::object { - public: - NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); - PyArray() = default; - - // "__init__" methods. Only used in python - static void PyInit(PyArray self, nanobind::object aval, - nanobind::object sharding, - absl::Span py_arrays, bool committed, - bool skip_checks); - - // Only used in C++. `skip_checks` should only be set for Arrays created by - // jax that cannot possibly have consistency issues (e.g. `sharding` devices - // different than `ifrt_array` devices). Arrays created by users should be - // checked. - PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, - std::vector shape, nanobind::object sharding, - nb_class_ptr py_client, - std::optional traceback, - tsl::RCReference ifrt_array, bool committed, - bool skip_checks, - xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); - - static PyArray MakeFromSingleDeviceArray( - nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, bool weak_type, bool committed, - xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); - - static PyArray MakeFromIfrtArrayAndSharding( - nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, nanobind::object sharding, - bool weak_type, bool committed, bool skip_checks); - - static absl::Status RegisterTypes(nanobind::module_& m); - - static PyArray borrow(PyObject* ptr) { - return nanobind::borrow(ptr); - } - - using Storage = PyArray_Storage; - - const nanobind::object& aval() const { return GetStorage().aval; } - void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } - - bool weak_type() const { return GetStorage().weak_type; } - - const nb_dtype& dtype() const { return GetStorage().dtype; } - absl::Span shape() const { return GetStorage().shape; } - - const nanobind::object& sharding() const { return GetStorage().sharding; } - - absl::StatusOr> layout() { - return ifrt_array()->layout(); - } - - bool committed() const { return GetStorage().committed; } - - const nanobind::object& npy_value() const { return GetStorage().npy_value; } - void set_npy_value(nanobind::object v) { - GetStorage().npy_value = std::move(v); - } - - const nb_class_ptr& py_client() const { - return GetStorage().py_client; - } - - const std::optional& traceback() const { - return GetStorage().traceback; - } - - // Returns xla::InvalidArgument if the buffer has been deleted. - // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. - absl::StatusOr IsReady() { - ifrt::Array* ifrt_array_ptr = ifrt_array(); - if (ifrt_array_ptr->IsDeleted()) { - return InvalidArgument("Array has been deleted."); - } - return ifrt_array_ptr->GetReadyFuture().IsReady(); - } - - const xla::PjRtFuture<>& result_status() const { - return GetStorage().result_status; - } - - ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); } - - // Short-term escape hatch to get PjRtBuffers from PyArray. - // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. 
- absl::Span> pjrt_buffers() const { - ifrt::Array* ifrt_array_ptr = ifrt_array(); - if (ifrt_array_ptr == nullptr) { - return {}; - } - auto* arr = - llvm::dyn_cast_or_null(ifrt_array_ptr); - if (arr == nullptr) { - throw XlaRuntimeError( - "This operation is implemented for a PjRt-compatible backend only."); - } - return arr->pjrt_buffers(); - } - - int num_addressable_shards() const { - ifrt::Array* ifrt_array_ptr = ifrt_array(); - if (ifrt_array_ptr == nullptr) { - return 0; - } - auto* arr = - llvm::dyn_cast_or_null(ifrt_array_ptr); - if (arr == nullptr) { - // TODO(hyeontaek): Add num_addressable_shards to ifrt. - return num_shards(); - } - return arr->pjrt_buffers().size(); - } - - std::vector& py_arrays() { return GetStorage().py_arrays; } - const std::vector& py_arrays() const { - return GetStorage().py_arrays; - } - const std::vector& py_arrays_cached(); - - nanobind::object arrays(); - absl::Status set_arrays(nanobind::object obj); - absl::StatusOr FullyReplicatedShard(); - - int num_shards() const { - ifrt::Array* ifrt_array_ptr = ifrt_array(); - if (ifrt_array_ptr == nullptr) { - return 0; - } - return ifrt_array_ptr->sharding().devices()->size(); - } - - static nanobind::handle type() { - DCHECK(type_); - return nanobind::handle(type_); - } - - static bool IsPyArray(nanobind::handle arg) { - return arg.type().is(PyArray::type()); - } - - absl::Status BlockUntilReady() const; - - absl::Status BlockUntilResultStatusIsReady(); - - absl::StatusOr GetOnDeviceSizeInBytes(); - absl::StatusOr> - SingleDeviceArrayToNumpyArrayDidCopy(); - absl::StatusOr SingleDeviceArrayToNumpyArray(); - absl::Status CopySingleDeviceArrayToHostAsync(); - nanobind::dict CudaArrayInterface(); - absl::StatusOr UnsafeBufferPointer(); - - absl::Status Delete(); - - bool IsDeleted() const; - - PyArray Clone() const; - - static absl::StatusOr> BatchedCopyToDeviceWithSharding( - absl::Span py_arrays, - absl::Span dst_device_lists, - absl::Span dst_shardings, - absl::Span array_copy_semantics); - - static absl::StatusOr BatchedDevicePut( - nanobind::object aval, nanobind::object sharding, - std::vector xs, - absl::Span dst_devices, bool committed, - bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, - bool jax_enable_x64); - - static absl::StatusOr ReorderShards( - PyArray x, nanobind::object dst_sharding, - ifrt::ArrayCopySemantics array_copy_semantics); - - static absl::Status BatchedBlockUntilReady( - std::vector objs); - - absl::Status ReplaceWithAlias(PyArray o); - - private: - absl::StatusOr AssertUnsharded(absl::string_view api); - - nanobind::object CheckAndRearrange(absl::Span py_arrays, - nanobind::object sharding, - nanobind::object aval); - - void SetIfrtArray(tsl::RCReference ifrt_array); - - Storage& GetStorage(); - const Storage& GetStorage() const; - - inline static PyObject* type_ = nullptr; -}; - -class PyArrayResultHandler { - public: - PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, - bool committed, bool skip_checks); - - PyArray Call(absl::Span py_arrays) const; - PyArray Call(PyArray py_array) const; - - PyArray Call(nb_class_ptr py_client, - tsl::RCReference ifrt_array, - xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const; - - private: - nanobind::object aval_; - nanobind::object sharding_; - bool weak_type_; - bool committed_; - bool skip_checks_; - - nb_dtype dtype_; - std::vector shape_; -}; - -absl::StatusOr CudaArrayInterfaceToBuffer( - const nanobind::dict& cai, nb_class_ptr cuda_client, - std::optional device_id); - -} // 
namespace xla
+#include "jaxlib/py_array.h" // IWYU pragma: keep

#endif // JAXLIB_XLA_PY_ARRAY_H_
diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h
index 29a506d48864..b7e90fe5e24c 100644
--- a/jaxlib/xla/py_client.h
+++ b/jaxlib/xla/py_client.h
@@ -1,4 +1,4 @@
-/* Copyright 2020 The JAX Authors
+/* Copyright 2025 The JAX Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,237 +16,6 @@ limitations under the License.

#ifndef JAXLIB_XLA_PY_CLIENT_H_
#define JAXLIB_XLA_PY_CLIENT_H_

-#include
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "llvm/Support/Casting.h"
-#include "nanobind/nanobind.h"
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "xla/pjrt/exceptions.h"
-#include "xla/pjrt/pjrt_client.h"
-#include "xla/pjrt/pjrt_executable.h"
-#include "xla/python/ifrt/attribute_map.h"
-#include "xla/python/ifrt/client.h"
-#include "xla/python/ifrt/compiler.h"
-#include "xla/python/ifrt/device.h"
-#include "xla/python/ifrt/program.h"
-#include "xla/python/pjrt_ifrt/pjrt_client.h"
-#include "xla/shape.h"
-
-namespace xla {
-
-class PyClient;
-class PyLoadedExecutable;
-class PyArray;
-class PyDevice;
-class PyMemorySpace;
-struct PyArray_Storage;
-
-// Python wrapper around PjRtClient.
-// We use a wrapper class to add Python-specific functionality.
-class PyClient {
- public:
- static nb_class_ptr Make(std::shared_ptr ifrt_client);
-
- // Do not call the constructor directly. Use `PyClient::Make` instead.
- explicit PyClient(std::shared_ptr ifrt_client);
- virtual ~PyClient();
-
- ifrt::Client* ifrt_client() const { return ifrt_client_.get(); }
- const std::shared_ptr& shared_ptr_ifrt_client() const {
- return ifrt_client_;
- }
-
- // Short-term escape hatch to get PjRtClient from PyClient.
- // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
- xla::PjRtClient* pjrt_client() const {
- auto* pjrt_client =
- llvm::dyn_cast_or_null(ifrt_client_.get());
- if (pjrt_client == nullptr) {
- throw XlaRuntimeError(
- "This operation is implemented for a PjRt-compatible backend only.");
- }
- return pjrt_client->pjrt_client();
- }
- std::shared_ptr shared_ptr_pjrt_client() {
- auto* pjrt_client =
- llvm::dyn_cast_or_null(ifrt_client_.get());
- if (pjrt_client == nullptr) {
- throw XlaRuntimeError(
- "This operation is implemented for a PjRt-compatible backend only.");
- }
- return pjrt_client->shared_ptr_pjrt_client();
- }
-
- // Legacy aliases.
- std::shared_ptr shared_pjrt_client() {
- return shared_ptr_pjrt_client();
- }
-
- absl::string_view platform_name() const {
- // TODO(phawkins): this is a temporary backwards compatibility shim. We
- // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
- // we haven't yet updated JAX clients that expect "gpu". Migrate users and
- // remove this code.
- if (ifrt_client_->platform_name() == "cuda" ||
- ifrt_client_->platform_name() == "rocm") {
- return "gpu";
- } else {
- return ifrt_client_->platform_name();
- }
- }
- absl::string_view raw_platform_name() const {
- // TODO(parkers): Once platform_name() is the same, remove this.
- return ifrt_client_->platform_name(); - } - absl::string_view platform_version() const { - return ifrt_client_->platform_version(); - } - absl::string_view runtime_type() const { - return ifrt_client_->runtime_type(); - } - - // Returns implementation-specific attributes about this client, e.g. the PJRT - // C API version if applicable. - const xla::ifrt::AttributeMap& Attributes() const { - return client_attributes_; - } - - int addressable_device_count() const { - return ifrt_client_->addressable_device_count(); - } - int device_count() const { return ifrt_client_->device_count(); } - int process_index() const { return ifrt_client_->process_index(); } - - std::vector> Devices(); - std::vector> LocalDevices(); - // Returns all devices in the client. Private API; only use this method for - // implementing backend._get_all_devices(). - // TODO(hyeontaek): Remove this method once we have a unified API for - // enumerating devices with different criteria. - std::vector> GetAllDevices(); - absl::StatusOr> DeviceFromLocalHardwareId( - int local_hardware_id); - - // Returns the PyDevice associated with the given ifrt::Device. - nb_class_ptr GetPyDevice(ifrt::Device* device); - - // Returns the PyMemorySpace associated with the given ifrt::Memory. - nb_class_ptr GetPyMemorySpace(ifrt::Memory* memory_space); - - // Returns a vector of live PyArray objects. PyArray objects may share - // PjRtBuffers, so there may be duplicates of the same underlying device - // buffer. - std::vector LiveBuffersOnDevice(ifrt::Device* device); - - nanobind::list LiveExecutables(); - - // TODO(zhangqiaorjc): Remove when we have transparent defragmentation. - absl::Status Defragment(); - - static absl::StatusOr BufferFromPyval( - nb_class_ptr client, nanobind::handle argument, - ifrt::Device* device, bool force_copy, - ifrt::Client::HostBufferSemantics host_buffer_semantics); - - static absl::StatusOr> CompileIfrtProgram( - nb_class_ptr client, - std::unique_ptr ifrt_program, - std::unique_ptr ifrt_options); - - static absl::StatusOr> Compile( - nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks); - - static absl::StatusOr> Compile( - nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks); - - absl::StatusOr SerializeExecutable( - const PyLoadedExecutable& executable) const; - static absl::StatusOr> DeserializeExecutable( - nb_class_ptr client, nanobind::bytes serialized, - std::optional options, - std::vector host_callbacks); - - absl::StatusOr HeapProfile(); - - // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable - // that takes in arguments of shapes `operand_shapes` and returns results of - // shapes `result_shapes`. The arguments correspond to Send ops in the HLO - // program through `send_channel_ids` and the results correspond to Recv ops - // through `recv_channel_ids`. It returns the host callback as an opaque - // object whose reference will keep the Python callback alive. The host - // callback can be passed to `PyClient::Compile` or - // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the - // XLA computation can trigger the execution of this host callback. - // `serializer` is a function that takes `callable` as an argument and returns - // a serialized callable as a string. - // - // The callable receives as arguments NumPy arrays for arguments with array - // types, and None for Token argument. The callable must return a tuple of - // either arrays or None values. 
- absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( - nanobind::callable callable, absl::Span operand_shapes, - absl::Span result_shapes, - absl::Span send_channel_ids, - absl::Span recv_channel_ids, - nanobind::callable serializer); - - std::vector LiveArrays() const; - - static void RegisterPythonTypes(nanobind::module_& m); - - protected: - static void Initialize(nb_class_ptr client); - - private: - friend class PyLoadedExecutable; - friend class PyArray; - friend struct PyArray_Storage; - - static int tp_traverse(PyObject* self, visitproc visit, void* arg); - static int tp_clear(PyObject* self); - static PyType_Slot slots_[]; - - std::shared_ptr ifrt_client_; - xla::ifrt::AttributeMap client_attributes_; - // Pointers to intrusive doubly-linked lists of arrays and executables, used - // to iterate over all known objects when heap profiling. The list structure - // is protected by the GIL. - - nanobind::ft_mutex executables_mutex_; - // List guarded by executables_mutex_. - PyLoadedExecutable* executables_ = nullptr; - -#ifdef NB_FREE_THREADING - static constexpr size_t kNumArraysShards = 16; -#else - static constexpr size_t kNumArraysShards = 1; -#endif - struct ArraysShard { - mutable nanobind::ft_mutex mutex; - PyArray_Storage* arrays; - }; - std::array arrays_; - - absl::flat_hash_map> devices_; - absl::flat_hash_map> - memory_spaces_; -}; - -} // namespace xla +#include "jaxlib/py_client.h" // IWYU pragma: keep #endif // JAXLIB_XLA_PY_CLIENT_H_ diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h index 4e74992fb2ee..2b3beff864ae 100644 --- a/jaxlib/xla/py_device.h +++ b/jaxlib/xla/py_device.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,68 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_PY_DEVICE_H_ #define JAXLIB_XLA_PY_DEVICE_H_ -#include - -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "nanobind/nanobind.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" -#include "xla/literal.h" -#include "xla/python/ifrt/device.h" -#include "xla/shape.h" - -namespace xla { - -class PyDevice { - public: - PyDevice(nb_class_ptr client, ifrt::Device* device); - - // Devices are compared using Python object identity, so we don't allow them - // to be copied or moved. 
- PyDevice(const PyDevice&) = delete; - PyDevice(PyDevice&&) = delete; - PyDevice& operator=(const PyDevice&) = delete; - PyDevice& operator=(PyDevice&&) = delete; - - const nb_class_ptr& client() const { return client_; } - ifrt::Device* device() const { return device_; } - - int id() const; - int process_index() const; - absl::string_view platform() const; - absl::string_view device_kind() const; - std::optional local_hardware_id() const; - - absl::string_view Str() const; - absl::string_view Repr() const; - - absl::Status TransferToInfeed(LiteralSlice literal); - absl::StatusOr TransferFromOutfeed(Shape shape); - - absl::StatusOr> Memory( - absl::string_view kind) const; - absl::StatusOr> DefaultMemory() const; - nanobind::list AddressableMemories() const; - absl::StatusOr> MemoryStats() const; - - absl::StatusOr GetStreamForExternalReadyEvents() const; - - static void RegisterPythonType(nanobind::module_& m); - - private: - static int tp_traverse(PyObject* self, visitproc visit, void* arg); - static int tp_clear(PyObject* self); - static PyType_Slot slots_[]; - - nb_class_ptr client_; - ifrt::Device* device_; -}; - -} // namespace xla +#include "jaxlib/py_device.h" // IWYU pragma: keep #endif // JAXLIB_XLA_PY_DEVICE_H_ diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h index 0fa9b3965dfe..1b75286c3d3d 100644 --- a/jaxlib/xla/py_device_list.h +++ b/jaxlib/xla/py_device_list.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,121 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_PY_DEVICE_LIST_H_ #define JAXLIB_XLA_PY_DEVICE_LIST_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "nanobind/nanobind.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" -#include "xla/python/ifrt/device_list.h" - -namespace jax { - -// Device list with various caching and direct access to IFRT DeviceList. -class PyDeviceList { - public: - PyDeviceList(xla::nb_class_ptr py_client, - xla::ifrt::DeviceListRef device_list); - explicit PyDeviceList(nanobind::tuple py_device_assignment); - ~PyDeviceList(); - - PyDeviceList(const PyDeviceList&) = delete; - PyDeviceList(PyDeviceList&&) = delete; - PyDeviceList& operator=(const PyDeviceList&) = delete; - PyDeviceList& operator=(PyDeviceList&&) = delete; - - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } - - // These two methods are safe to call from C++ without GIL. - xla::nb_class_ptr py_client() const { return py_client_; } - absl::StatusOr ifrt_device_list() const; - - int Len() const; // Requires the GIL in GIL mode. - nanobind::object GetItem(int index); // Requires the GIL in GIL mode. - - // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. - static xla::nb_class_ptr AddressableDeviceList( - xla::nb_class_ptr self); - - // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. - static absl::StatusOr DefaultMemoryKind( - xla::nb_class_ptr self); - - // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. 
- static absl::StatusOr MemoryKinds( - xla::nb_class_ptr self); - - // go/pywald-pybind-annotation BEGIN - // refs { - // module_path: "third_party/py/jax/jaxlib/xla/xla.cc" - // module_arg {} - // } - // go/pywald-pybind-annotation END - static void Register(nanobind::module_& m); - - private: - nanobind::tuple AsTuple() const; - - // Methods below require GIL. - nanobind::object GetSlice(nanobind::slice slice); - nanobind::iterator Iter(); - - std::string Str(); - - nanobind::tuple Dump() const; - - int64_t Hash(); // Mutates hash_, needs self lock. - - static bool Equal(xla::nb_class_ptr self, - nanobind::handle other); - static bool NotEqual(xla::nb_class_ptr self, - nanobind::handle other); - - // Finds the memory kind info from an addressable device. Requires the GIL - // or self lock. - void PopulateMemoryKindInfo(); - // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_` - // instead of `ifrt_device_list_` to support duck-typed device objects. - // Requires the GIL or self lock. - void PopulateMemoryKindInfoForDuckTypedDevices(); - - // Requires the self lock or GIL is held. - bool IsFullyAddressable(); - - // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and - // non-empty. - xla::nb_class_ptr py_client_; - - // Either C++ `ifrt::DeviceList` or Python duck-type devices. - // TODO(hyeontaek): Remove support for Python duck-type devices once all - // JAX backends and tests are migrated to use an `xla::ifrt::Device` type - // for JAX devices. - // Immutable after constructor; no locking needed. - std::variant device_list_; - - // Populated on demand. Guarded by the object's self lock. - std::optional hash_; - // TODO(hyeontaek): Make the following property cached within - // `xla::ifrt::DeviceList`. - // Populated on demand. Guarded by the object's self lock. - std::optional is_fully_addressable_; - // Populated on demand. Guarded by the object's self lock. - std::optional> addressable_device_list_; - - struct MemoryKindInfo { - nanobind::object default_memory_kind; - nanobind::tuple memory_kinds; - }; - // Populated on demand. Guarded by the object's self lock. - std::optional> memory_kind_info_; -}; - -} // namespace jax +#include "jaxlib/py_device_list.h" // IWYU pragma: keep #endif // JAXLIB_XLA_PY_DEVICE_LIST_H_ diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h index 9c8ce8010c90..5cc0f2d6ac6c 100644 --- a/jaxlib/xla/py_executable.h +++ b/jaxlib/xla/py_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,239 +16,6 @@ limitations under the License. 
#ifndef JAXLIB_XLA_PY_EXECUTABLE_H_
#define JAXLIB_XLA_PY_EXECUTABLE_H_

-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "llvm/Support/Casting.h"
-#include "nanobind/nanobind.h"
-#include "jaxlib/xla/nb_class_ptr.h"
-#include "jaxlib/xla/py_array.h"
-#include "jaxlib/xla/py_client.h"
-#include "jaxlib/xla/traceback.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/pjrt/exceptions.h"
-#include "xla/pjrt/pjrt_client.h"
-#include "xla/pjrt/pjrt_executable.h"
-#include "xla/pjrt/pjrt_future.h"
-#include "xla/pjrt/pjrt_layout.h"
-#include "xla/python/ifrt/array.h"
-#include "xla/python/ifrt/attribute_map.h"
-#include "xla/python/ifrt/executable.h"
-#include "xla/python/pjrt_ifrt/pjrt_executable.h"
-#include "xla/tsl/concurrency/ref_count.h"
-#include "xla/tsl/platform/status.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-
-class PyToken {
- public:
- PyToken() = default;
- explicit PyToken(PjRtFuture<> future) : future_(std::move(future)) {}
-
- static PyToken ReadyPyToken() {
- return PyToken(PjRtFuture<>(absl::OkStatus()));
- }
-
- absl::Status Await();
-
- private:
- PjRtFuture<> future_;
-};
-
-// PyShardedToken contains a PyToken for each device's execution.
-class PyShardedToken {
- public:
- // Default construction creates an always-ready token.
- PyShardedToken() = default;
- explicit PyShardedToken(std::vector> futures)
- : futures_(std::move(futures)) {}
-
- PyToken GetPyToken(int device_id) const {
- if (futures_.empty()) return PyToken::ReadyPyToken();
- return PyToken(futures_.at(device_id));
- }
-
- absl::Status Await();
-
- private:
- std::vector> futures_;
-};
-
-class PyExecuteResults {
- public:
- PyExecuteResults(const nb_class_ptr& client,
- std::vector> ifrt_arrays,
- int num_computations, PyShardedToken token,
- PjRtFuture<> result_status = PjRtFuture<>());
-
- std::vector> DisassembleIntoSingleDeviceArrays();
-
- std::vector> DisassemblePrefixIntoSingleDeviceArrays(
- size_t n);
-
- std::vector ConsumeWithHandlers(
- std::vector>
- out_handlers);
-
- std::vector> Consume();
-
- PyShardedToken ConsumeToken();
-
- size_t Size() const {
- CheckNotDisassembled();
- return ifrt_arrays_.size();
- }
-
- void CheckNotDisassembled() const;
-
- private:
- bool is_exploded_ = false;
- bool token_consumed_ = false;
- nb_class_ptr client_;
- std::vector> ifrt_arrays_;
- int num_computations_;
- PyShardedToken token_;
- // Only set if the computation has tokens.
- PjRtFuture<> result_status_;
-};
-
-using ExecuteShardedArg = std::variant>;
-
-// Python wrapper around PjRtExecutable. We use a wrapper class:
-// a) to keep the PyClient alive via a std::shared_ptr<>
-// b) to add Python-specific functionality.
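// A minimal sketch of the wrapper pattern described above, with hypothetical
// `Client` and `Executable` types; the real class holds the client via an
// nb_class_ptr rather than a std::shared_ptr. Shared ownership of the client
// (point a) keeps it alive for as long as any executable created from it.
//
//   #include <memory>
//   #include <utility>
//
//   struct Client {};
//   struct Executable {};
//
//   class LoadedExecutableWrapper {
//    public:
//     LoadedExecutableWrapper(std::shared_ptr<Client> client,
//                             std::shared_ptr<Executable> exec)
//         : client_(std::move(client)), exec_(std::move(exec)) {}
//
//     const std::shared_ptr<Client>& client() const { return client_; }
//
//    private:
//     std::shared_ptr<Client> client_;    // (a) keeps the client alive
//     std::shared_ptr<Executable> exec_;  // (b) Python-facing helpers go here
//   };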
-class PyLoadedExecutable { - public: - PyLoadedExecutable( - nb_class_ptr client, - std::shared_ptr ifrt_loaded_executable, - std::optional traceback, - std::optional fingerprint); - ~PyLoadedExecutable(); - - nb_class_ptr client() const { return client_; } - ifrt::LoadedExecutable* ifrt_loaded_executable() const { - return ifrt_loaded_executable_.get(); - } - - std::shared_ptr shared_ifrt_loaded_executable() { - return ifrt_loaded_executable_; - } - - std::vector> AddressableDevices() const; - - int64_t SizeOfGeneratedCodeInBytes() const { - return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); - } - - absl::StatusOr GetCompiledMemoryStats() const { - nanobind::gil_scoped_release scope; - return ifrt_loaded_executable_->GetCompiledMemoryStats(); - } - - absl::StatusOr GetCostAnalysis() const { - return ifrt_loaded_executable_->GetCostAnalysis(); - } - - void Delete() { - // TODO(hyeontaek): Return absl::Status. - TF_CHECK_OK(ifrt_loaded_executable_->Delete().Await()); - } - - bool is_deleted() { return ifrt_loaded_executable_->IsDeleted(); } - - // Takes args indexed by argid then deviceid, transposes them, and passes to - // PjRtExecutable::Execute. The result is similarly transposed back into the - // argid,deviceid format. - // args is [num_args x num_devices]. - absl::StatusOr ExecuteSharded( - std::vector args, bool with_tokens); - - absl::StatusOr>> HloModules() const; - - absl::StatusOr>> - GetOutputMemoryKinds() const; - - absl::StatusOr>> - GetParameterLayouts() const; - - absl::StatusOr>> - GetOutputLayouts() const; - - std::optional> GetParameterShardings() const; - - std::optional> GetOutputShardings() const; - - const std::optional& traceback() { return traceback_; } - - ifrt::LoadedExecutable* ifrt_executable() const { - return ifrt_loaded_executable_.get(); - } - - // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable. - // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. - std::shared_ptr shared_ptr_pjrt_executable() { - auto* exec = llvm::dyn_cast_or_null( - ifrt_loaded_executable_.get()); - if (exec == nullptr) { - throw XlaRuntimeError( - "This operation is implemented for a PjRt-compatible backend only."); - } - return exec->shared_ptr_pjrt_loaded_executable(); - } - - // Returns a template of execute options to pass to - // `ifrt_executable()->Execute()`. Note that the caller may need to override - // some options such as `launch_id` that change at each execution. - const ifrt::ExecuteOptions& options() const { return options_; } - - // Returns a unique launch ID to use for the next execution. - int64_t GetNextLaunchId(); - - const std::optional& fingerprint() const { return fingerprint_; } - - // Keep `obj` alive as long as PyLoadedExecutable. - void KeepAlive(nanobind::object obj); - - private: - friend class PyClient; - - nb_class_ptr client_; - std::shared_ptr ifrt_loaded_executable_; - std::optional traceback_; - - // Identical executables (i.e. representing the same program) will have the - // same fingerprint. nullopt on platforms or executables where fingerprints - // aren't implemented. - std::optional fingerprint_; - - // Launch ID to use for the next execution. - std::atomic next_launch_id_; - - // The options to pass to `executable_.Execute`. - ifrt::ExecuteOptions options_; - - // Python objects to keep alive as requested by user. - std::vector keepalives_; - - // Doubly-linked list of all executables known to the client. Protected by the - // GIL. 
- PyLoadedExecutable* next_; - PyLoadedExecutable* prev_; -}; - -} // namespace xla +#include "jaxlib/py_executable.h" // IWYU pragma: keep #endif // JAXLIB_XLA_PY_EXECUTABLE_H_ diff --git a/jaxlib/xla/python_ref_manager.h b/jaxlib/xla/python_ref_manager.h index c0630da2ebd5..09f995c198e2 100644 --- a/jaxlib/xla/python_ref_manager.h +++ b/jaxlib/xla/python_ref_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,93 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_PYTHON_REF_MANAGER_H_ #define JAXLIB_XLA_PYTHON_REF_MANAGER_H_ -#include - -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "nanobind/nanobind.h" - -namespace xla { - -// Class that manages destruction of Python objects. -// -// We must not destroy Python objects without holding the GIL. However, we -// frequently want to hold references to Python objects for the duration of -// an asynchronous transfer on a Stream, and release our reference when the -// transfer completes. -// -// This class holds references to Python objects outside a GIL scope, that can -// be collected later when the GIL is held by calling CollectGarbage(). -class PythonRefManager { - public: - PythonRefManager() = default; - - // Holds references to a set of nanobind::objects, adding the references to - // the PythonRefManager on destruction. - class ManagedPyObjects { - public: - ManagedPyObjects() = default; - ManagedPyObjects(PythonRefManager* manager, - absl::Span objects); - - ~ManagedPyObjects(); - - ManagedPyObjects(const ManagedPyObjects& other) = delete; - ManagedPyObjects(ManagedPyObjects&& other) = default; - ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete; - ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default; - - private: - PythonRefManager* manager_ = nullptr; - absl::InlinedVector objects_; - }; - - // Creates a managed std::shared_ptr to an object. When the shared_ptr is - // destroyed, the reference to 'object' will be added to python_garbage_, - // and collected next time CollectGarbage() is called. - std::shared_ptr ManageReference(nanobind::object object); - std::shared_ptr ManageReferences( - absl::Span objects); - - // Adds garbage objects to the manager. - void AddGarbage(nanobind::object garbage); - void AddGarbage(absl::Span garbage); - void AddGarbage(absl::Span const> garbage); - - // Releases the contents of python_garbage_. Requires that the GIL is held. - // The client calls this method during API entry points where the GIL is held - // to free any garbage that has accumulated. - void CollectGarbage(); - - // Cheaper version of CollectGarbage() with relaxed consistency and frequency. - // The purpose of this function is to amortize lock acquisition costs over - // a larger number of API calls. - void MaybeCollectGarbage() { - if (garbage_count_.load(std::memory_order_relaxed) >= 100) { - CollectGarbage(); - } - } - - private: - absl::Mutex mu_; - std::deque python_garbage_ ABSL_GUARDED_BY(mu_); - - // Writes to garbage_count_ are protected by mu_, reads are not protected. - std::atomic garbage_count_{0}; -}; - -// A global PythonRefManager. 
Unless `CollectGarbage()` is called before -// shutdown, this container will hold on to Python objects and thus cause a -// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. -PythonRefManager* GlobalPyRefManager(); - -} // namespace xla +#include "jaxlib/python_ref_manager.h" // IWYU pragma: keep #endif // JAXLIB_XLA_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h index c0cf284c6dbd..dcb7089674f5 100644 --- a/jaxlib/xla/pytree.h +++ b/jaxlib/xla/pytree.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,393 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_PYTREE_H_ #define JAXLIB_XLA_PYTREE_H_ -// See https://docs.jax.dev/en/latest/pytrees.html for the documentation -// about pytree. - -#include - -#include -#include -#include -#include -#include -#include - -// placeholder for index annotation headers -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/hash/hash.h" -#include "absl/types/span.h" -#include "nanobind/nanobind.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/pytree.pb.h" - -namespace xla { - -enum class PyTreeKind { - kLeaf, // An opaque leaf node - kNone, // None. - kTuple, // A tuple - kNamedTuple, // A collections.namedtuple - kList, // A list - kDict, // A dict - kCustom, // A custom type. - kDataclass, // A dataclass. -}; - -// Registry of custom node types. -class PyTreeRegistry { - public: - PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, - bool enable_list, bool enable_dict); - - PyTreeRegistry(const PyTreeRegistry&) = delete; - PyTreeRegistry(PyTreeRegistry&&) = delete; - PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; - PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; - - struct Registration { - PyTreeKind kind; - - // The following values are populated for custom types. - // The Python type object, used to identify the type. - nanobind::object type; - // A function with signature: object -> (iterable, aux_data) - nanobind::callable to_iterable; - // A function with signature: (aux_data, iterable) -> object - nanobind::callable from_iterable; - // A function with signature: (aux_data, iterable(keypath, leaf)) -> object - std::optional to_iterable_with_keys; - - // Helper that calls to_iterable and validates that it returns a pair - // of an iterable and an aux_data object - std::pair ToIterable( - nanobind::handle o) const; - // Helper that calls to_iterable_with_keys and validates that it returns a - // pair of an iterable of key-leaf pairs and an aux_data object. If - // to_iterable_with_keys is not available, return a dummy key for each leaf, - // similar to the current jax.tree_util.FlattenedIndexKey. - std::pair>, - nanobind::object> - ToIterableWithKeys(nanobind::handle o) const; - - // For dataclasses. - std::vector data_fields; - std::vector meta_fields; - - int tp_traverse(visitproc visit, void* arg); - }; - - // Registers a new custom type. Objects of `type` will be treated as container - // node types in PyTrees. - void Register( - nanobind::object type, nanobind::callable to_iterable, - nanobind::callable from_iterable, - std::optional to_iterable_with_keys = std::nullopt); - // Same, but for dataclasses. 
- void RegisterDataclass(nanobind::object type, - std::vector data_fields, - std::vector meta_fields); - - // Finds the custom type registration for `type`. Returns nullptr if none - // exists. - const Registration* Lookup(nanobind::handle type) const; - - PyTreeKind KindOfObject(nanobind::handle obj, - PyTreeRegistry::Registration const** custom) const; - - // Flattens a pytree one level, returning either a tuple of the leaves and - // the node data, or None, if the entry is a leaf. - nanobind::object FlattenOneLevel(nanobind::handle x) const; - // Similar to above but returns a key-leaf pair for each leaf. - nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; - // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. - nanobind::object FlattenOneLevelImpl(nanobind::handle x, - bool with_keys) const; - - static PyType_Slot slots_[]; - - private: - struct TypeHash { - using is_transparent = void; - size_t operator()(const nanobind::object& t) const { - return absl::HashOf(t.ptr()); - } - size_t operator()(const nanobind::handle& t) const { - return absl::HashOf(t.ptr()); - } - }; - struct TypeEq { - using is_transparent = void; - bool operator()(const nanobind::object& a, - const nanobind::object& b) const { - return a.ptr() == b.ptr(); - } - bool operator()(const nanobind::object& a, - const nanobind::handle& b) const { - return a.ptr() == b.ptr(); - } - }; - mutable nanobind::ft_mutex mu_; - absl::flat_hash_map, TypeHash, - TypeEq> - registrations_; // Guarded by mu_ - bool enable_namedtuple_; - - static int tp_traverse(PyObject* self, visitproc visit, void* arg); - static int tp_clear(PyObject* self); -}; - -class SequenceKey { - public: - explicit SequenceKey(int idx) : idx_(idx) {}; - std::string ToReprString() const; - std::string ToString() const; - bool Equals(const nanobind::object& other); - int idx() const { return idx_; } - static nanobind::tuple MatchArgs(nanobind::handle unused); - - private: - int idx_; -}; - -class DictKey { - public: - explicit DictKey(nanobind::object key) : key_(key) {}; - std::string ToReprString() const; - std::string ToString() const; - bool Equals(const nanobind::object& other); - nanobind::object key() const { return key_; } - static nanobind::tuple MatchArgs(nanobind::handle unused); - static PyType_Slot slots_[]; - - private: - nanobind::object key_; - static int tp_traverse(PyObject* self, visitproc visit, void* arg); - static int tp_clear(PyObject* self); -}; - -class GetAttrKey { - public: - explicit GetAttrKey(nanobind::str name) : name_(name) {}; - std::string ToReprString() const; - std::string ToString() const; - bool Equals(const nanobind::object& other); - nanobind::str name() const { return name_; } - static nanobind::tuple MatchArgs(nanobind::handle unused); - - private: - nanobind::str name_; -}; - -class FlattenedIndexKey { - public: - explicit FlattenedIndexKey(int key) : key_(key) {}; - std::string ToReprString() const; - std::string ToString() const; - bool Equals(const nanobind::object& other); - int key() const { return key_; } - static nanobind::tuple MatchArgs(nanobind::handle unused); - - private: - int key_; -}; - -// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of -// Python values, where the interior nodes are tuples, lists, dictionaries, or -// user-defined containers, and the leaves are other objects. -class PyTreeDef { - public: - // Unowned registry: the registry must remain live at least as long as the - // PyTreeDef. 
It is the caller's responsibility to enforce this. - explicit PyTreeDef(PyTreeRegistry* registry) : registry_(registry) {} - - explicit PyTreeDef(nb_class_ptr registry) - : registry_(registry.get()), registry_ref_(std::move(registry)) {} - - // Flattens a Pytree into a list of leaves and a PyTreeDef. - // Returns references to the flattened objects, which might be temporary - // objects in the case of custom pytype handlers. - static std::pair, nb_class_ptr> - Flatten(nanobind::handle x, nb_class_ptr registry, - std::optional leaf_predicate = std::nullopt); - - // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). - // `leaves` owns references to the flattened objects, which might be - // temporary objects in the case of custom pytype handlers. - void Flatten(nanobind::handle handle, std::vector& leaves, - std::optional leaf_predicate = std::nullopt); - void Flatten(nanobind::handle handle, - absl::InlinedVector& leaves, - std::optional leaf_predicate = std::nullopt); - void Flatten(nanobind::handle handle, nanobind::list& leaves, - std::optional leaf_predicate = std::nullopt); - - void FlattenWithPath( - nanobind::handle handle, nanobind::list& leaves, - std::optional leaf_predicate = std::nullopt); - - // Tests whether the given list is a flat list of leaves. - static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x); - - // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of - // the tree-structure of 'x'. For example, if we flatten a value - // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the - // list of leaves [1, (2, 3), {"foo": 4}]. - nanobind::list FlattenUpTo(nanobind::handle x) const; - - // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. - nanobind::object Unflatten(nanobind::iterable leaves) const; - nanobind::object Unflatten(absl::Span leaves) const; - - // Composes two PyTreeDefs, replacing the leaves of this tree with copies of - // `inner`. The returned PyTreeDef holds a reference to its registry. - nb_class_ptr Compose(const PyTreeDef& inner) const; - - // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. - static nb_class_ptr Tuple(nb_class_ptr registry, - nanobind::list defs); - - // The returned PyTreeDefs hold a reference to the registry. - std::vector> Children() const; - - // Maps a function over a PyTree structure, applying f_leaf to each leaf, and - // f_node(node, node_data) to each container node. - nanobind::object Walk(const nanobind::callable& f_node, - nanobind::handle f_leaf, - nanobind::iterable leaves) const; - - // Given a tree of iterables with the same node/leaf structure as this PyTree, - // build the corresponding PyTree. - // TODO(phawkins): use flattening everywhere instead and delete this method. - nanobind::object FromIterableTree(nanobind::handle xs) const; - - int num_leaves() const { - if (traversal_.empty()) { - return 0; - } - return traversal_.back().num_leaves; - } - - int num_nodes() const { return traversal_.size(); } - - PyTreeRegistry* registry() const { return registry_; } - - size_t Hash() const; - - bool operator==(const PyTreeDef& other) const; - bool operator!=(const PyTreeDef& other) const { return !(*this == other); } - - std::string ToString() const; - - // Transforms the PyTreeDef into a pickleable object. Used to implement - // `PyTreeDef.__getstate__`. - nanobind::object ToPickle() const; - - // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used - // to implement `PyTreeDef.__setstate__`. 
- void FromPickle(nanobind::object pickleable); - - void SerializeTo(jax::PyTreeDefProto& result) const; - - static nb_class_ptr DeserializeFrom( - nb_class_ptr registry, const jax::PyTreeDefProto& input); - - std::optional> GetNodeData() - const; - - static nb_class_ptr MakeFromNodeDataAndChildren( - nb_class_ptr registry, - std::optional> node_data, - nanobind::iterable children); - - static PyType_Slot slots_[]; - - private: - void SetNumLeavesAndNumNodes(); - - struct Node { - PyTreeKind kind = PyTreeKind::kLeaf; - - // Arity for non-kLeaf types. - int arity = 0; - - // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type - // object. For a kDict, use `sorted_dict_keys` field below. For a kCustom - // type, contains the auxiliary data returned by the `to_iterable` function. - nanobind::object node_data; - - // Kind-specific auxiliary data specialized for kDict. Use a c++ vector - // to hold the sorted dict keys instead of a py::list to avoid creating - // a new python list object when flattening kDict. For deeply nested dict, - // using c++ vector instead of py::list avoids creating too many python - // objects that make python gc sweep slow. - std::vector sorted_dict_keys; - - // Custom type registration. Must be null for non-custom types. - const PyTreeRegistry::Registration* custom = nullptr; - - // Number of leaf nodes in the subtree rooted at this node. - int num_leaves = 0; - - // Number of leaf and interior nodes in the subtree rooted at this node. - int num_nodes = 0; - - int tp_traverse(visitproc visit, void* arg) const; - }; - template - friend H AbslHashValue(H h, const Node& n); - - template - friend H AbslHashValue(H h, const PyTreeDef& t); - - // Helper that manufactures an instance of a node given its children. - static nanobind::object MakeNode(const Node& node, - absl::Span children); - - // Recursive helper used to implement FromIterableTree() - nanobind::object FromIterableTreeHelper( - nanobind::handle xs, - absl::InlinedVector::const_reverse_iterator* it) - const; - - template - void FlattenImpl(nanobind::handle handle, T& leaves, - const std::optional& leaf_predicate, - std::optional>& keypath); - - template - nanobind::object UnflattenImpl(T leaves) const; - - static int tp_traverse(PyObject* self, visitproc visit, void* arg); - static int tp_clear(PyObject* self); - - // Pytree registry. Not owned. - PyTreeRegistry* registry_; - // If this class holds a reference to `registry`, it is held by - // `registry_ref_`. - nb_class_ptr registry_ref_; - - // Nodes, in a post-order traversal. We use an ordered traversal to minimize - // allocations, and post-order corresponds to the order we need to rebuild the - // tree structure. - absl::InlinedVector traversal_; -}; - -template -H AbslHashValue(H h, const PyTreeDef::Node& n) { - h = H::combine(std::move(h), n.kind, n.arity, n.custom); - return h; -} - -template -H AbslHashValue(H h, const PyTreeDef& t) { - h = H::combine(std::move(h), t.traversal_); - return h; -} - -void BuildPytreeSubmodule(nanobind::module_& m); - -} // namespace xla +#include "jaxlib/pytree.h" // IWYU pragma: keep #endif // JAXLIB_XLA_PYTREE_H_ diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index 7ce10e7ed763..f47dd265651a 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -16,226 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_SHARDING_H_ #define JAXLIB_XLA_SHARDING_H_ -#include - -#include -#include -#include - -// placeholder for index annotation headers -#include "absl/hash/hash.h" -#include "absl/status/statusor.h" -#include "nanobind/nanobind.h" -#include "jaxlib/xla/nb_class_ptr.h" -#include "jaxlib/xla/py_client.h" -#include "jaxlib/xla/py_device_list.h" -#include "jaxlib/xla/sharded_device_array.h" -#include "xla/hlo/ir/hlo_sharding.h" -#include "xla/pjrt/status_casters.h" -#include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_numpy.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" - -namespace jax { - -class Sharding { - public: - Sharding() = default; - - // This constructor is used in the fast path to retrieve the number of devices - // without falling back to python. This is only used in the cpp path. - explicit Sharding(int num_devices) : num_devices_(num_devices) {} - - virtual ~Sharding() = default; - - static int SafeNumDevices(nanobind::handle sharding); - - private: - std::optional num_devices_; -}; - -// Gets `jax::PyDeviceList` from a JAX Sharding. -absl::StatusOr> GetPyDeviceList( - nanobind::handle sharding); - -// Checks if the memory kind is valid, and canonicalizes the -// memory kind to default memory on backends that support memories. -nanobind::object CheckAndCanonicalizeMemoryKind( - nanobind::object memory_kind, - const xla::nb_class_ptr& device_list); - -// Returns a hash that may sometimes return different hashes for equal values. -// It is not a correct implementation of `__hash__` in python, but it's fine -// for jit/pjit dispatch since it only causes spurious cache misses. -size_t ShardingHash(nanobind::handle sharding); - -bool ShardingEqual(nanobind::handle a, nanobind::handle b); - -class NamedSharding : public Sharding { - public: - NamedSharding(nanobind::object mesh, nanobind::object spec, - nanobind::object memory_kind, - nanobind::object logical_device_ids); - - const nanobind::object& mesh() const { return mesh_; } - const nanobind::object& spec() const { return spec_; } - const nanobind::object& memory_kind() const { return memory_kind_; } - const nanobind::object& logical_device_ids() const { - return logical_device_ids_; - } - - static nanobind::handle type() { return type_; } - static void InitializeType(); - - absl::StatusOr> internal_device_list() const { - if (internal_device_list_) { - return *internal_device_list_; - } - return xla::InvalidArgument( - "internal_device_list is not implemented for " - "`jax.sharding.AbstractMesh`"); - } - - private: - nanobind::object mesh_; - nanobind::object spec_; - nanobind::object memory_kind_; - nanobind::object logical_device_ids_; - std::optional> internal_device_list_; - static PyObject* type_; -}; - -class SingleDeviceSharding : public Sharding { - public: - explicit SingleDeviceSharding( - nanobind::object device, nanobind::object memory_kind = nanobind::none()); - - // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`. 
- SingleDeviceSharding(xla::nb_class_ptr client, - xla::ifrt::DeviceListRef device_list, - nanobind::object memory_kind); - - const nanobind::object& device() const { return device_; } - const nanobind::object& memory_kind() const { return memory_kind_; } - - static nanobind::handle type() { return type_; } - static void InitializeType(); - - xla::nb_class_ptr internal_device_list() const { - return internal_device_list_; - } - - private: - nanobind::object device_; - nanobind::object memory_kind_; - xla::nb_class_ptr internal_device_list_; - - static PyObject* type_; -}; - -// The C++ implementation of jax.PmapSharding in python. It contains a few key -// data members and methods that are performance-critical. -class PmapSharding : public Sharding { - public: - PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec); - - ~PmapSharding() override = default; - - xla::nb_numpy_ndarray devices() const { return devices_; } - - const ShardingSpec& sharding_spec() const { return sharding_spec_; } - - static nanobind::handle type() { return type_; } - static void InitializeType(); - - xla::nb_class_ptr internal_device_list() const { - return internal_device_list_; - } - - private: - xla::nb_numpy_ndarray devices_; - ShardingSpec sharding_spec_; - xla::nb_class_ptr internal_device_list_; - static PyObject* type_; -}; - -class GSPMDSharding : public Sharding { - public: - GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding, - nanobind::object memory_kind, nanobind::object device_list) - : GSPMDSharding( - std::move(devices), - xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), - std::move(memory_kind), std::move(device_list)) {} - - GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding, - nanobind::object memory_kind, nanobind::object device_list); - - const nanobind::tuple& devices() const { return devices_; } - const nanobind::object& memory_kind() const { return memory_kind_; } - - size_t Hash() { - if (!hash_.has_value()) { - hash_ = CalculateHash(); - } - return *hash_; - } - - static nanobind::handle type() { return type_; } - static void InitializeType(); - - const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } - - bool operator==(const GSPMDSharding& other) const { - return AreOpShardingsEqual(*this, other) && - this->devices().equal(other.devices()) && - this->memory_kind().equal(other.memory_kind()); - } - - xla::nb_class_ptr internal_device_list() const { - return internal_device_list_; - } - - private: - size_t CalculateHash() const { - // We only hash `hlo_sharding_` here for performance. - return absl::Hash()(hlo_sharding_); - } - - static bool AreOpShardingsEqual(const GSPMDSharding& a, - const GSPMDSharding& b) { - // If the OpSharding object is the same, return true - if (&a.hlo_sharding() == &b.hlo_sharding()) { - return true; - } - // If both OpShardings are replicated, return true - if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) { - return true; - } - return a.hlo_sharding() == b.hlo_sharding(); - } - - bool IsOpShardingReplicated() const { - // For JAX, shardings with 1 device are considered as replicated in its - // semantics so that downstream things continue to work. 
- if (hlo_sharding_.tile_assignment().num_elements() == 1) { - return true; - } - return hlo_sharding().IsReplicated(); - } - - nanobind::tuple devices_; - xla::HloSharding hlo_sharding_; - nanobind::object memory_kind_; - std::optional hash_; - xla::nb_class_ptr internal_device_list_; - - static PyObject* type_; -}; - -void RegisterSharding(nanobind::module_& m); - -} // namespace jax +#include "jaxlib/sharding.h" // IWYU pragma: keep #endif // JAXLIB_XLA_SHARDING_H_ diff --git a/jaxlib/xla/traceback.h b/jaxlib/xla/traceback.h index 685ecc5f8793..bb993233850a 100644 --- a/jaxlib/xla/traceback.h +++ b/jaxlib/xla/traceback.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The JAX Authors +/* Copyright 2025 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,94 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_TRACEBACK_H_ #define JAXLIB_XLA_TRACEBACK_H_ -#include - -#include -#include -#include -#include - -// placeholder for index annotation headers -#include "absl/container/inlined_vector.h" -#include "nanobind/nanobind.h" -#include "jaxlib/xla/nb_class_ptr.h" - -namespace xla { - -// Represents a Python traceback. This object is designed to be allocated on -// the Python heap; creating or destroying a traceback requires the GIL. -class Traceback { - public: - // Requires GIL. Creates a Traceback object that requires destructor to be - // invoked with GIL held as well. - static std::optional> Get(); - - // Requires GIL. - static bool enabled() { return enabled_; } - // Requires GIL. - static void SetEnabled(bool enabled); - - // Requires GIL. Don't call this directly, you're looking for Get(). - Traceback(); - // Requires GIL. - ~Traceback(); - - Traceback(const Traceback&) = delete; - Traceback(Traceback&& other) noexcept; - Traceback& operator=(const Traceback&) = delete; - Traceback& operator=(Traceback&&) = delete; - - // Requires the GIL be held. - std::string ToString() const; - - struct Frame { - nanobind::str file_name; - nanobind::str function_name; - int function_start_line; - int line_num; - - std::string ToString() const; - }; - std::vector Frames() const; - - const absl::InlinedVector, 32>& raw_frames() - const { - return frames_; - } - - // Returns the traceback as a fake Python Traceback object, suitable for - // using as an exception traceback. - nanobind::object AsPythonTraceback() const; - - bool operator==(const Traceback& other) const { - return frames_ == other.frames_; - } - bool operator!=(const Traceback& other) const { - return frames_ != other.frames_; - } - - private: - // Each frame is a pair of a code object and a "lasti" instruction location - // in bytes. The size of _Py_CODEUNIT has changed across different Python - // versions; the lasti value here has already been multiplied by - // sizeof(_Py_CODEUNIT) if needed and is suitable for passing to functions - // like PyCode_Addr2Line(). - absl::InlinedVector, 32> frames_; - - // Protected by GIL. 
- static bool enabled_; -}; - -using nb_traceback = nb_class_ptr; - -template -H AbslHashValue(H h, const Traceback& traceback) { - h = H::combine(std::move(h), traceback.raw_frames()); - return h; -} - -void BuildTracebackSubmodule(nanobind::module_& m); - -} // namespace xla +#include "jaxlib/traceback.h" // IWYU pragma: keep #endif // JAXLIB_XLA_TRACEBACK_H_ diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 0cbd2b3f3b4d..4eb4a2d7939f 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -1,4 +1,4 @@ -# Copyright 2017 The JAX Authors +# Copyright 2017 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,970 +14,8 @@ # ============================================================================== """An XLA client in Python.""" -from __future__ import annotations +# ruff: noqa: F401 +# ruff: noqa: F403 -import atexit -from collections.abc import Mapping, Sequence -import contextlib -import enum -import gzip -import inspect -import logging -import os -import threading -from typing import Any, Protocol, Union - -import ml_dtypes -import numpy as np - -from jaxlib import xla_extension as _xla - -# Note this module does *not* depend on any Python protocol buffers. The XLA -# Python bindings are currently packaged both as part of jaxlib and as part -# of TensorFlow. If we use protocol buffers here, then importing both jaxlib -# and TensorFlow may fail with duplicate protocol buffer message definitions. - -# Most functions are snake_case for consistency with other modules, some -# method names are CamelCase for consistency with XLA. -# pylint: disable=invalid-name - -# Pylint has false positives for type annotations. -# pylint: disable=invalid-sequence-index - -ifrt_programs = _xla.ifrt_programs -ops = _xla.ops -profiler = _xla.profiler - -# Just an internal arbitrary increasing number to help with backward-compatible -# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 330 - -# An internal increasing version number for protecting jaxlib code against -# ifrt changes. -# lives in xla/python/version.h. -# In JAX, reference this via jax._src.lib.ifrt_version. 
-_ifrt_version = _xla.ifrt_version_number - -xla_platform_names = { - 'cpu': 'Host', - 'gpu': 'CUDA', -} - -logger = logging.getLogger(__name__) - -_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] - - -def make_cpu_client( - asynchronous=True, - distributed_client=None, - node_id=0, - num_nodes=1, - collectives=None, - num_devices=None, -) -> Client: - register_custom_call_handler('cpu', _xla.register_custom_call_target) - register_custom_type_id_handler('cpu', _xla.register_custom_type_id) - return _xla.get_tfrt_cpu_client( - asynchronous=asynchronous, - distributed_client=distributed_client, - node_id=node_id, - num_nodes=num_nodes, - collectives=collectives, - num_devices=num_devices, - ) - - -DeviceTopology = _xla.DeviceTopology -get_topology_for_devices = _xla.get_topology_for_devices - - -def make_tfrt_tpu_c_api_device_topology( - topology_name: str = '', **kwargs -) -> DeviceTopology: - """Creates a PJRT C API TopologyDescription.""" - return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) - - -def make_c_api_device_topology( - c_api: Any, topology_name: str = '', **kwargs -) -> DeviceTopology: - """Creates a PJRT C API TopologyDescription.""" - return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) - - -def pjrt_plugin_loaded(plugin_name: str) -> bool: - return _xla.pjrt_plugin_loaded(plugin_name) - - -def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: - return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) - - -def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: - return _xla.load_pjrt_plugin(plugin_name, None, c_api) - - -def pjrt_plugin_initialized(plugin_name: str) -> bool: - return _xla.pjrt_plugin_initialized(plugin_name) - - -def initialize_pjrt_plugin(plugin_name: str) -> None: - """Initializes a PJRT plugin. - - The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or - static linking) before this method is called. - Args: - plugin_name: the name of the PJRT plugin. - """ - _xla.initialize_pjrt_plugin(plugin_name) - - -def make_c_api_client( - plugin_name: str, - options: _NameValueMapping | None = None, - distributed_client: _xla.DistributedRuntimeClient | None = None, -): - """Creates a PJRT C API client for a PJRT plugin. - - It is required that load_pjrt_plugin_dynamically is called once with the same - plugin_name before this method is called. - - Args: - plugin_name: the name of the PJRT plugin. - options: extra platform-specific options. - distributed_client: distributed client. - - Returns: - A PJRT C API client for plugin_name. - """ - if options is None: - options = {} - return _xla.get_c_api_client(plugin_name, options, distributed_client) - - -def make_tpu_client( - library_path: str | None = None, options: _NameValueMapping | None = None -): - """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" - if not pjrt_plugin_loaded('tpu'): - c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') - profiler.register_plugin_profiler(c_api) - assert pjrt_plugin_loaded('tpu') - if not pjrt_plugin_initialized('tpu'): - initialize_pjrt_plugin('tpu') - if options is None: - options = {} - return _xla.get_c_api_client('tpu', options) - - -def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: - """Generates the PjRt GPU plugin options. - - Returns: - A dictionary of plugin options. 
- """ - - options = {} - options['platform_name'] = 'cuda' - allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() - memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '') - deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') - if deprecated_memory_fraction: - if memory_fraction: - raise ValueError( - 'XLA_CLIENT_MEM_FRACTION is specified together ' - 'with XLA_PYTHON_CLIENT_MEM_FRACTION. ' - 'Remove the latter one, it is deprecated.' - ) - else: - memory_fraction = deprecated_memory_fraction - preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') - collective_memory_size = os.getenv( - 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' - ) - if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): - raise ValueError( - 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' - '"bfc", or "cuda_async", got "%s"' % allocator - ) - options['allocator'] = allocator - if memory_fraction: - options['memory_fraction'] = float(memory_fraction) - if preallocate: - options['preallocate'] = preallocate not in ('false', 'False', '0') - if collective_memory_size: - options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) - return options - - -class OpMetadata: - """Python representation of a xla.OpMetadata protobuf.""" - - __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') - - def __init__(self, op_type='', op_name='', source_file='', source_line=0): - self.op_type = op_type - self.op_name = op_name - self.source_file = source_file - self.source_line = source_line - - -def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): - """Helper for use in source mapping that returns an OpMetadata object.""" - full_filename, lineno = inspect.stack()[skip_frames][1:3] - filename = os.path.basename(full_filename) - return OpMetadata( - op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno - ) - - -PrimitiveType = _xla.PrimitiveType - -XLA_ELEMENT_TYPE_TO_DTYPE = { - PrimitiveType.PRED: np.dtype('bool'), - PrimitiveType.S4: np.dtype(ml_dtypes.int4), - PrimitiveType.S8: np.dtype('int8'), - PrimitiveType.S16: np.dtype('int16'), - PrimitiveType.S32: np.dtype('int32'), - PrimitiveType.S64: np.dtype('int64'), - PrimitiveType.U4: np.dtype(ml_dtypes.uint4), - PrimitiveType.U8: np.dtype('uint8'), - PrimitiveType.U16: np.dtype('uint16'), - PrimitiveType.U32: np.dtype('uint32'), - PrimitiveType.U64: np.dtype('uint64'), - PrimitiveType.F4E2M1FN: np.dtype(ml_dtypes.float4_e2m1fn), - PrimitiveType.F8E3M4: np.dtype(ml_dtypes.float8_e3m4), - PrimitiveType.F8E4M3: np.dtype(ml_dtypes.float8_e4m3), - PrimitiveType.F8E4M3FN: np.dtype(ml_dtypes.float8_e4m3fn), - PrimitiveType.F8E4M3B11FNUZ: np.dtype(ml_dtypes.float8_e4m3b11fnuz), - PrimitiveType.F8E4M3FNUZ: np.dtype(ml_dtypes.float8_e4m3fnuz), - PrimitiveType.F8E5M2: np.dtype(ml_dtypes.float8_e5m2), - PrimitiveType.F8E5M2FNUZ: np.dtype(ml_dtypes.float8_e5m2fnuz), - PrimitiveType.F8E8M0FNU: np.dtype(ml_dtypes.float8_e8m0fnu), - PrimitiveType.BF16: np.dtype(ml_dtypes.bfloat16), - PrimitiveType.F16: np.dtype('float16'), - PrimitiveType.F32: np.dtype('float32'), - PrimitiveType.F64: np.dtype('float64'), - PrimitiveType.C64: np.dtype('complex64'), - PrimitiveType.C128: np.dtype('complex128'), - PrimitiveType.TUPLE: np.dtype(np.object_), - PrimitiveType.TOKEN: np.dtype(np.object_), -} - -# Note the conversion on the key. Numpy has a known issue wherein dtype hashing -# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). 
Thus, -# when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} - - -def dtype_to_etype(dtype): - """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" - return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] - - -Shape = _xla.Shape -Shape.__doc__ = """ -A Shape is an object defined in C++ that duck types like the following class: - -class Shape: - '''Represents an XLA shape. - - A shape is either an array shape, having rank-many integer - dimensions and an element type (represented by a Numpy dtype), or it - is a tuple shape, having a shape for every tuple component: - - type shape = - TupleShape of shape list - | ArrayShape of { dimensions: int list; element_type: dtype } - ''' - - @staticmethod - def tuple_shape(tuple_shapes) -> Shape: - "Construct a tuple shape." - - @staticmethod - def array_shape(element_type, dimensions, minor_to_major=None) -> Shape: - - @staticmethod - def from_pyval(pyval) -> Shape: - "Returns a Shape that describes a tuple-tree of Numpy arrays." - - def __init__(self, str) -> Shape: - "Parses a shape string." - def __eq__(self, other: Shape) -> bool: - def __ne__(self, other: Shape) -> bool: - def __hash__(self): - def __repr__(self): - def is_tuple(self) -> bool: - def is_array(self) -> bool: - def tuple_shapes(self) -> [Shape]: - def numpy_dtype(self) -> np.dtype: - "Like element_type(), but returns dtype('O') for a tuple shape." - def xla_element_type(self) -> PrimitiveType: - def element_type(self) -> np.dtype: - def dimensions(self) -> (int, int, ...): - def rank(self) -> int: - def with_major_to_minor_layout_if_absent(self) -> Shape: - "Returns a copy with missing layouts set to major-to-minor." - - def to_serialized_proto(self) -> bytes: - "Returns 'shape' as a serialized proto." -""" - -ProgramShape = _xla.ProgramShape -ProgramShape.__doc__ = """ -A ProgramShape is a C++ object that duck types like the following class. - -class ProgramShape: - def __init__(self, parameter_shapes, result_shape): - def parameter_shapes(self) -> [Shape]: - def result_shape(self) -> Shape: - def __repr__(self): -""" - -ShapeIndex = _xla.ShapeIndex -ShapeIndex.__doc__ = """ -A Shape is an object defined in C++ that duck types like the following class: - -class ShapeIndex: - '''Represents an XLA ShapeIndex. - - An index for specifying a particular nested subshape within a shape. Used in - ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through - the Shape tree where each element of ShapeIndex indexes into a tuple (or - nested tuple) within the shape. For a non-nested tuple, an index has a single - element. - ''' - - def __init__(self, List[int]) -> ShapeIndex: - def __eq__(self, other: Shape) -> bool: - def __ne__(self, other: Shape) -> bool: - def __hash__(self): - def __repr__(self): -""" - - -def shape_from_pyval(pyval, layout: Sequence[int] | None = None): - """Returns a Shape that describes a tuple-tree of Numpy arrays.""" - - def convert(pyval): - if isinstance(pyval, tuple): - if layout is not None: - raise NotImplementedError( - 'shape_from_pyval does not support layouts for tuple shapes' - ) - return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) - else: - return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) - - return convert(pyval) - - -DeviceAssignment = _xla.DeviceAssignment -DeviceAssignment.__doc__ = """ -A DeviceAssignment is a C++ object with the following signature. 
- -def create(assignment): - '''Builds a device assignment. - - Args: - assignment: a 2D numpy array of device ordinal integers, indexed by - [replica][computation_in_replica]. - Returns: - A device assignment. - ''' - -def replica_count(): - '''Returns the number of replicas.''' -def computation_count(): - '''Returns the number of computations per replica.''' -""" - -Device = _xla.Device -CompileOptions = _xla.CompileOptions - -HostBufferSemantics = _xla.HostBufferSemantics - -# An Executable is a C++ class that duck types with the following API: -# class Executable: -# def local_devices(self) -> [Device]: -# def execute(self, arguments : [Buffer]) -> Buffer: -# """Execute on one replica with Buffer arguments and return value.""" -# -# def size_of_generated_code_in_bytes(self) -> int: -# """Return generated binary size, or -1 if not known.""" -# -# def execute_sharded_on_local_devices(self, arguments: [[Buffer]]) -# -> [Buffer]: -# """Execute on many replicas with Buffer arguments and return value. -# -# Args: -# arguments: A sequence of sequences of Buffers. The i'th element of each -# sequence comprises the arguments for execution on the i'th local -# device. -# -# Returns: -# A list of the computation's outputs as a list of Buffers for each -# device. -# """ -# -# There are different implementations of Executable for different backends. - - -class PaddingType(enum.Enum): - VALID = 1 - SAME = 2 - - -def window_padding_type_to_pad_values( - padding_type, lhs_dims, rhs_dims, window_strides -): - """Maps PaddingType or string to pad values (list of pairs of ints).""" - if not isinstance(padding_type, (str, PaddingType)): - msg = 'padding_type must be str or PaddingType, got {}.' - raise TypeError(msg.format(type(padding_type))) - - if isinstance(padding_type, str): - if padding_type.upper() == 'VALID': - padding_type = PaddingType.VALID - elif padding_type.upper() == 'SAME': - padding_type = PaddingType.SAME - else: - msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' 
- raise ValueError(msg.format(padding_type)) - - if padding_type == PaddingType.VALID: - return [(0, 0)] * len(window_strides) - elif padding_type == PaddingType.SAME: - out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) - pad_sizes = [ - max((out_size - 1) * stride + filter_size - in_size, 0) - for out_size, stride, filter_size, in_size in zip( - out_shape, window_strides, rhs_dims, lhs_dims - ) - ] - return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] - else: - msg = 'Unexpected PaddingType value: {}' - raise ValueError(msg.format(padding_type)) - - -XlaBuilder = _xla.XlaBuilder -XlaComputation = _xla.XlaComputation -XlaOp = _xla.XlaOp -FftType = _xla.FftType -Client = _xla.Client -Memory = _xla.Memory -Array = _xla.Array -ArrayImpl = _xla.ArrayImpl -LoadedExecutable = _xla.LoadedExecutable -DeviceList = _xla.DeviceList -OpSharding = _xla.OpSharding -HloSharding = _xla.HloSharding -Sharding = _xla.Sharding -NamedSharding = _xla.NamedSharding -SingleDeviceSharding = _xla.SingleDeviceSharding -PmapSharding = _xla.PmapSharding -GSPMDSharding = _xla.GSPMDSharding -PjRtLayout = _xla.PjRtLayout -AutotuneCacheMode = _xla.AutotuneCacheMode -ResultAccuracyMode = _xla.ResultAccuracy_Mode - - -def LoadedExecutable_execute(self, arguments, device=None): - del device - results = self.execute_sharded(arguments) - return [x[0] for x in results.disassemble_into_single_device_arrays()] - - -def LoadedExecutable_execute_with_token(self, arguments, device=None): - del device - results = self.execute_sharded(arguments, with_tokens=True) - return ( - [x[0] for x in results.disassemble_into_single_device_arrays()], - results.consume_token().get_token(0), - ) - - -LoadedExecutable.execute = LoadedExecutable_execute -LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token - - -class CustomCallTargetTraits(enum.IntFlag): - DEFAULT = 0 - # Calls to custom call are safe to trace into the command buffer. It means - # that calls to custom call always launch exactly the same device operations - # (can depend on attribute values) that can be captured and then replayed. - # - # Supported only for custom calls implemented with XLA FFI. - COMMAND_BUFFER_COMPATIBLE = 1 - - -class CustomCallHandler(Protocol): - - def __call__( - self, - name: str, - fn: Any, - platform: str, - /, - api_version: int = ..., - traits: CustomCallTargetTraits = ..., - ) -> None: - ... - - -_custom_callback_handler: dict[str, CustomCallHandler] = {} -# Key is xla_platform_name, value is (function_name, function, api_version) -_custom_callback: dict[ - str, list[tuple[str, Any, int, CustomCallTargetTraits]] -] = {} -_custom_callback_lock = threading.Lock() - - -def register_custom_call_target( - name: str, - fn: Any, - platform: str = 'cpu', - api_version: int = 0, - traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT, -) -> None: - """Registers a custom call target. - - Args: - name: bytes containing the name of the function. - fn: a PyCapsule object containing the function pointer. - platform: the target platform. - api_version: the XLA FFI version to use. Supported versions are: 0 for the - untyped FFI and 1 for the typed FFI. - traits: custom call traits corresponding to XLA FFI handler traits. - """ - # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" - # Since that is hardcoded to CUDA, we are using the following as workaround. 
- xla_platform_name = xla_platform_names.get(platform, platform) - with _custom_callback_lock: - if xla_platform_name in _custom_callback_handler: - _custom_callback_handler[xla_platform_name]( - name, fn, xla_platform_name, api_version, traits - ) - else: - _custom_callback.setdefault(xla_platform_name, []).append( - (name, fn, api_version, traits) - ) - - -def register_custom_call_handler( - platform: str, handler: CustomCallHandler -) -> None: - """Registers a custom handler and use it to register existing custom calls. - - If a custom call handler for the platform already exist, calling this method - is a no-op and it will not register a new handler. - - Args: - platform: the target platform. - handler: the function to register a custom call. - """ - xla_platform_name = xla_platform_names.get(platform, platform) - with _custom_callback_lock: - if xla_platform_name in _custom_callback_handler: - logger.debug( - 'Custom call handler for %s is already register. Will not register a' - ' new one', - xla_platform_name, - ) - return - _custom_callback_handler[xla_platform_name] = handler - if xla_platform_name in _custom_callback: - for name, fn, api_version, traits in _custom_callback[xla_platform_name]: - handler(name, fn, xla_platform_name, api_version, traits) - del _custom_callback[xla_platform_name] - - -class CustomTypeIdHandler(Protocol): - - def __call__(self, name: str, capsule: Any) -> None: - ... - - -_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {} -_custom_type_id: dict[str, Any] = {} -_custom_type_id_lock = threading.Lock() - - -def register_custom_type_id( - type_name: str, - type_id: Any, - platform: str = 'cpu', -) -> None: - """Register a custom type id for use with the FFI. - - Args: - type_name: a unique name for the type. - type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``. - platform: the target platform. - """ - xla_platform_name = xla_platform_names.get(platform, platform) - with _custom_type_id_lock: - if xla_platform_name in _custom_type_id_handler: - _custom_type_id_handler[xla_platform_name](type_name, type_id) - else: - _custom_type_id.setdefault(xla_platform_name, []).append( - (type_name, type_id) - ) - - -def register_custom_type_id_handler( - platform: str, handler: CustomTypeIdHandler -) -> None: - """Register a custom type id handler and use it to register existing type ids. - - If a custom type id handler for the platform already exist, calling this - method is a no-op and it will not register a new handler. - - Args: - platform: the target platform. - handler: the function to register a custom type id. - """ - xla_platform_name = xla_platform_names.get(platform, platform) - with _custom_callback_lock: - if xla_platform_name in _custom_type_id_handler: - logger.debug( - 'Custom type id handler for %s is already register. 
Will not ' - 'register a new one', - xla_platform_name, - ) - return - _custom_type_id_handler[xla_platform_name] = handler - if xla_platform_name in _custom_type_id: - for name, capsule in _custom_type_id[xla_platform_name]: - handler(name, capsule) - del _custom_type_id[xla_platform_name] - - -register_custom_call_partitioner = _xla.register_custom_call_partitioner -encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback -hlo_sharding_util = _xla.hlo_sharding_util -register_custom_call_as_batch_partitionable = ( - _xla.register_custom_call_as_batch_partitionable -) - - -class PaddingConfigDimension: - """Python representation of a xla.PaddingConfigDimension protobuf.""" - - __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') - - edge_padding_low: int - edge_padding_high: int - interior_padding: int - - def __init__(self): - self.edge_padding_low = 0 - self.edge_padding_high = 0 - self.interior_padding = 0 - - -class PaddingConfig: - """Python representation of a xla.PaddingConfig protobuf.""" - - __slots__ = ('dimensions',) - - def __init__(self): - self.dimensions = [] - - -def make_padding_config( - padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]] -) -> PaddingConfig: - """Create PaddingConfig proto from list of triples of integers. - - Args: - padding_config: either a PaddingConfig or a list of integer triples - (edge_padding_low, edge_padding_high, interior_padding) representing the - configuration of the padding operation. - - Returns: - A `PaddingConfig` object. - """ - if not isinstance(padding_config, PaddingConfig): - triples = padding_config - padding_config = PaddingConfig() - for lo, hi, interior in triples: - dimension = PaddingConfigDimension() - dimension.edge_padding_low = lo - dimension.edge_padding_high = hi - dimension.interior_padding = interior - padding_config.dimensions.append(dimension) - return padding_config - - -class DotDimensionNumbers: - """Python representation of a xla.DotDimensionNumbers protobuf.""" - - __slots__ = ( - 'lhs_contracting_dimensions', - 'rhs_contracting_dimensions', - 'lhs_batch_dimensions', - 'rhs_batch_dimensions', - ) - - def __init__(self): - self.lhs_contracting_dimensions = [] - self.rhs_contracting_dimensions = [] - self.lhs_batch_dimensions = [] - self.rhs_batch_dimensions = [] - - -def make_dot_dimension_numbers( - dimension_numbers: Union[ - DotDimensionNumbers, - tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], - ] -) -> DotDimensionNumbers: - """Builds a DotDimensionNumbers object from a specification. - - Args: - dimension_numbers: either a `DotDimensionNumbers` or a nested tuple - `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of - integers representing the dimensions to treat as contracting dimensions - and batch dimensions on each input operand. - - Returns: - A `DotDimensionNumbers` object. 
- """ - if isinstance(dimension_numbers, (list, tuple)): - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers - dot_dims_proto = DotDimensionNumbers() - dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) - dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) - dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) - dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) - return dot_dims_proto - else: - return dimension_numbers - - -class ConvolutionDimensionNumbers: - """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" - - __slots__ = ( - 'input_batch_dimension', - 'input_feature_dimension', - 'input_spatial_dimensions', - 'kernel_input_feature_dimension', - 'kernel_output_feature_dimension', - 'kernel_spatial_dimensions', - 'output_batch_dimension', - 'output_feature_dimension', - 'output_spatial_dimensions', - ) - - def __init__(self): - self.input_batch_dimension = 0 - self.input_feature_dimension = 0 - self.input_spatial_dimensions = [] - self.kernel_input_feature_dimension = 0 - self.kernel_output_feature_dimension = 0 - self.kernel_spatial_dimensions = [] - self.output_batch_dimension = 0 - self.output_feature_dimension = 0 - self.output_spatial_dimensions = [] - - -def make_convolution_dimension_numbers( - dimension_numbers: Union[ - None, ConvolutionDimensionNumbers, tuple[str, str, str] - ], - num_spatial_dimensions: int, -) -> ConvolutionDimensionNumbers: - """Builds a ConvolutionDimensionNumbers object from a specification. - - Args: - dimension_numbers: optional, either a ConvolutionDimensionNumbers object or - a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length - N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the - output with the character 'N', (2) feature dimensions in lhs and the - output with the character 'C', (3) input and output feature dimensions in - rhs with the characters 'I' and 'O' respectively, and (4) spatial - dimension correspondences between lhs, rhs, and the output using any - distinct characters. For example, to indicate dimension numbers consistent - with the Conv operation with two spatial dimensions, one could use - ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension - numbers consistent with the TensorFlow Conv2D operation, one could use - ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution - dimension specification, window strides are associated with spatial - dimension character labels according to the order in which the labels - appear in the rhs_spec string, so that window_strides[0] is matched with - the dimension corresponding to the first character appearing in rhs_spec - that is not 'I' or 'O'. By default, use the same dimension numbering as - Conv and ConvWithGeneralPadding. - num_spatial_dimensions: the number of spatial dimensions. - - Returns: - A `ConvolutionDimensionNumbers` object. 
- """ - if dimension_numbers is None: - nd = num_spatial_dimensions - dimension_numbers = ConvolutionDimensionNumbers() - dimension_numbers.input_batch_dimension = 0 - dimension_numbers.input_feature_dimension = 1 - dimension_numbers.output_batch_dimension = 0 - dimension_numbers.output_feature_dimension = 1 - dimension_numbers.kernel_output_feature_dimension = 0 - dimension_numbers.kernel_input_feature_dimension = 1 - dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) - elif isinstance(dimension_numbers, tuple): - lhs_spec, rhs_spec, out_spec = dimension_numbers - dimension_numbers = ConvolutionDimensionNumbers() - - dimension_numbers.input_batch_dimension = lhs_spec.index('N') - dimension_numbers.input_feature_dimension = lhs_spec.index('C') - dimension_numbers.output_batch_dimension = out_spec.index('N') - dimension_numbers.output_feature_dimension = out_spec.index('C') - dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') - dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') - - dimension_numbers.kernel_spatial_dimensions.extend( - i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} - ) - dimension_numbers.input_spatial_dimensions.extend( - sorted( - (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(lhs_spec[i]), - ) - ) - dimension_numbers.output_spatial_dimensions.extend( - sorted( - (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(out_spec[i]), - ) - ) - return dimension_numbers - - -class PrecisionConfig: - """Python representation of a xla.PrecisionConfig protobuf.""" - - __slots__ = ('operand_precision',) - - Precision = _xla.PrecisionConfig_Precision - - def __init__(self): - self.operand_precision = [] - - -class ResultAccuracy: - """Python representation of a xla.ResultAccuracy protobuf.""" - - __slots__ = ('mode', 'atol', 'rtol', 'ulps') - - def __init__(self): - self.mode = _xla.ResultAccuracy_Mode.DEFAULT - self.atol = 0.0 - self.rtol = 0.0 - self.ulps = 0 - - -class GatherDimensionNumbers: - """Python representation of a xla.GatherDimensionNumbers protobuf.""" - - __slots__ = ( - 'offset_dims', - 'collapsed_slice_dims', - 'start_index_map', - 'index_vector_dim', - ) - - def __init__(self): - self.offset_dims = [] - self.collapsed_slice_dims = [] - self.start_index_map = [] - self.index_vector_dim = 0 - - -class ScatterDimensionNumbers: - """Python representation of a xla.ScatterDimensionNumbers protobuf.""" - - __slots__ = ( - 'update_window_dims', - 'inserted_window_dims', - 'scatter_dims_to_operand_dims', - 'index_vector_dim', - ) - - def __init__(self): - self.update_window_dims = [] - self.inserted_window_dims = [] - self.scatter_dims_to_operand_dims = [] - self.index_vector_dim = 0 - - -class ReplicaGroup: - """Python representation of a xla.ReplicaGroup protobuf.""" - - __slots__ = ('replica_ids',) - - def __init__(self): - self.replica_ids = [] - - -def _make_replica_group_proto(replica_group): - replica_group_proto = ReplicaGroup() - replica_group_proto.replica_ids.extend(replica_group) - return replica_group_proto - - -def make_replica_groups(replica_groups): - if replica_groups is None: - replica_groups_protos = [] # special value for XLA API - else: - replica_groups = list(replica_groups) - replica_groups_protos = [ - _make_replica_group_proto(group) for group in replica_groups - ] 
- return replica_groups_protos - - -Traceback = _xla.Traceback -Frame = _xla.Frame - - -@contextlib.contextmanager -def tracebacks(enabled=True): - """Context manager that enables or disables traceback collection.""" - saved = Traceback.enabled - Traceback.enabled = enabled - try: - yield - finally: - Traceback.enabled = saved - - -def heap_profile(client: Client) -> bytes: - """Returns a gzipped pprof protocol buffer containing a heap profile.""" - return gzip.compress(client.heap_profile()) - - -XlaRuntimeError = _xla.XlaRuntimeError - -# Perform one last garbage collection of deferred Python references. This is -# mostly to keep ASAN happy. -atexit.register(_xla.collect_garbage) - -array_result_handler = _xla.array_result_handler -batched_copy_array_to_devices_with_sharding = ( - _xla.batched_copy_array_to_devices_with_sharding -) -batched_device_put = _xla.batched_device_put -reorder_shards = _xla.reorder_shards -batched_block_until_ready = _xla.batched_block_until_ready -check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind -Layout = _xla.Layout -custom_call_targets = _xla.custom_call_targets -ArrayCopySemantics = _xla.ArrayCopySemantics +from jaxlib.xla_client import * # pylint: disable=wildcard-import +from jaxlib.xla_client import _xla # pylint: disable=unused-import diff --git a/jaxlib/xla_extension.py b/jaxlib/xla/xla_extension.py similarity index 78% rename from jaxlib/xla_extension.py rename to jaxlib/xla/xla_extension.py index e4fc7e96a1ab..5305b28f39dd 100644 --- a/jaxlib/xla_extension.py +++ b/jaxlib/xla/xla_extension.py @@ -1,4 +1,4 @@ -# Copyright 2025 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""An XLA client in Python.""" -from jaxlib.xla.xla_extension import * # noqa: F403 -from jaxlib.xla.xla_extension import sdy # noqa: F401 +# ruff: noqa: F401 +# ruff: noqa: F403 + +from jaxlib.xla_extension import * # pylint: disable=wildcard-import diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 01b01ecf704e..0cbd2b3f3b4d 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -1,4 +1,4 @@ -# Copyright 2025 The JAX Authors +# Copyright 2017 The JAX Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,972 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""An XLA client in Python.""" -from jaxlib.xla.xla_client import * # noqa: F403 -from jaxlib.xla.xla_client import _version # noqa: F401 -from jaxlib.xla.xla_client import _xla # noqa: F401 +from __future__ import annotations + +import atexit +from collections.abc import Mapping, Sequence +import contextlib +import enum +import gzip +import inspect +import logging +import os +import threading +from typing import Any, Protocol, Union + +import ml_dtypes +import numpy as np + +from jaxlib import xla_extension as _xla + +# Note this module does *not* depend on any Python protocol buffers. The XLA +# Python bindings are currently packaged both as part of jaxlib and as part +# of TensorFlow. 
If we use protocol buffers here, then importing both jaxlib +# and TensorFlow may fail with duplicate protocol buffer message definitions. + +# Most functions are snake_case for consistency with other modules, some +# method names are CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +# Pylint has false positives for type annotations. +# pylint: disable=invalid-sequence-index + +ifrt_programs = _xla.ifrt_programs +ops = _xla.ops +profiler = _xla.profiler + +# Just an internal arbitrary increasing number to help with backward-compatible +# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. +_version = 330 + +# An internal increasing version number for protecting jaxlib code against +# ifrt changes. +# lives in xla/python/version.h. +# In JAX, reference this via jax._src.lib.ifrt_version. +_ifrt_version = _xla.ifrt_version_number + +xla_platform_names = { + 'cpu': 'Host', + 'gpu': 'CUDA', +} + +logger = logging.getLogger(__name__) + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + + +def make_cpu_client( + asynchronous=True, + distributed_client=None, + node_id=0, + num_nodes=1, + collectives=None, + num_devices=None, +) -> Client: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + register_custom_type_id_handler('cpu', _xla.register_custom_type_id) + return _xla.get_tfrt_cpu_client( + asynchronous=asynchronous, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + collectives=collectives, + num_devices=num_devices, + ) + + +DeviceTopology = _xla.DeviceTopology +get_topology_for_devices = _xla.get_topology_for_devices + + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) + + +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) + + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + return _xla.pjrt_plugin_loaded(plugin_name) + + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + return _xla.load_pjrt_plugin(plugin_name, None, c_api) + + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + return _xla.pjrt_plugin_initialized(plugin_name) + + +def initialize_pjrt_plugin(plugin_name: str) -> None: + """Initializes a PJRT plugin. + + The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or + static linking) before this method is called. + Args: + plugin_name: the name of the PJRT plugin. + """ + _xla.initialize_pjrt_plugin(plugin_name) + + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: _xla.DistributedRuntimeClient | None = None, +): + """Creates a PJRT C API client for a PJRT plugin. + + It is required that load_pjrt_plugin_dynamically is called once with the same + plugin_name before this method is called. + + Args: + plugin_name: the name of the PJRT plugin. + options: extra platform-specific options. + distributed_client: distributed client. + + Returns: + A PJRT C API client for plugin_name. 
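+
+  Example (a sketch; the plugin name and library path are illustrative, not
+  part of this module):
+
+    if not pjrt_plugin_loaded('example'):
+      load_pjrt_plugin_dynamically('example', '/opt/example/libexample.so')
+    if not pjrt_plugin_initialized('example'):
+      initialize_pjrt_plugin('example')
+    client = make_c_api_client('example')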
+ """ + if options is None: + options = {} + return _xla.get_c_api_client(plugin_name, options, distributed_client) + + +def make_tpu_client( + library_path: str | None = None, options: _NameValueMapping | None = None +): + """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" + if not pjrt_plugin_loaded('tpu'): + c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') + profiler.register_plugin_profiler(c_api) + assert pjrt_plugin_loaded('tpu') + if not pjrt_plugin_initialized('tpu'): + initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _xla.get_c_api_client('tpu', options) + + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + """Generates the PjRt GPU plugin options. + + Returns: + A dictionary of plugin options. + """ + + options = {} + options['platform_name'] = 'cuda' + allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() + memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '') + deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') + if deprecated_memory_fraction: + if memory_fraction: + raise ValueError( + 'XLA_CLIENT_MEM_FRACTION is specified together ' + 'with XLA_PYTHON_CLIENT_MEM_FRACTION. ' + 'Remove the latter one, it is deprecated.' + ) + else: + memory_fraction = deprecated_memory_fraction + preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) + if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): + raise ValueError( + 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' + '"bfc", or "cuda_async", got "%s"' % allocator + ) + options['allocator'] = allocator + if memory_fraction: + options['memory_fraction'] = float(memory_fraction) + if preallocate: + options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) + return options + + +class OpMetadata: + """Python representation of a xla.OpMetadata protobuf.""" + + __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') + + def __init__(self, op_type='', op_name='', source_file='', source_line=0): + self.op_type = op_type + self.op_name = op_name + self.source_file = source_file + self.source_line = source_line + + +def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): + """Helper for use in source mapping that returns an OpMetadata object.""" + full_filename, lineno = inspect.stack()[skip_frames][1:3] + filename = os.path.basename(full_filename) + return OpMetadata( + op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno + ) + + +PrimitiveType = _xla.PrimitiveType + +XLA_ELEMENT_TYPE_TO_DTYPE = { + PrimitiveType.PRED: np.dtype('bool'), + PrimitiveType.S4: np.dtype(ml_dtypes.int4), + PrimitiveType.S8: np.dtype('int8'), + PrimitiveType.S16: np.dtype('int16'), + PrimitiveType.S32: np.dtype('int32'), + PrimitiveType.S64: np.dtype('int64'), + PrimitiveType.U4: np.dtype(ml_dtypes.uint4), + PrimitiveType.U8: np.dtype('uint8'), + PrimitiveType.U16: np.dtype('uint16'), + PrimitiveType.U32: np.dtype('uint32'), + PrimitiveType.U64: np.dtype('uint64'), + PrimitiveType.F4E2M1FN: np.dtype(ml_dtypes.float4_e2m1fn), + PrimitiveType.F8E3M4: np.dtype(ml_dtypes.float8_e3m4), + PrimitiveType.F8E4M3: np.dtype(ml_dtypes.float8_e4m3), + PrimitiveType.F8E4M3FN: np.dtype(ml_dtypes.float8_e4m3fn), + PrimitiveType.F8E4M3B11FNUZ: 
np.dtype(ml_dtypes.float8_e4m3b11fnuz),
+    PrimitiveType.F8E4M3FNUZ: np.dtype(ml_dtypes.float8_e4m3fnuz),
+    PrimitiveType.F8E5M2: np.dtype(ml_dtypes.float8_e5m2),
+    PrimitiveType.F8E5M2FNUZ: np.dtype(ml_dtypes.float8_e5m2fnuz),
+    PrimitiveType.F8E8M0FNU: np.dtype(ml_dtypes.float8_e8m0fnu),
+    PrimitiveType.BF16: np.dtype(ml_dtypes.bfloat16),
+    PrimitiveType.F16: np.dtype('float16'),
+    PrimitiveType.F32: np.dtype('float32'),
+    PrimitiveType.F64: np.dtype('float64'),
+    PrimitiveType.C64: np.dtype('complex64'),
+    PrimitiveType.C128: np.dtype('complex128'),
+    PrimitiveType.TUPLE: np.dtype(np.object_),
+    PrimitiveType.TOKEN: np.dtype(np.object_),
+}
+
+# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
+# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
+# when keying by dtype in this dict, we use the string form of dtypes.
+DTYPE_TO_XLA_ELEMENT_TYPE = {
+    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
+}
+
+
+def dtype_to_etype(dtype):
+  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
+  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
+
+
+Shape = _xla.Shape
+Shape.__doc__ = """
+A Shape is an object defined in C++ that duck types like the following class:
+
+class Shape:
+  '''Represents an XLA shape.
+
+  A shape is either an array shape, having rank-many integer
+  dimensions and an element type (represented by a Numpy dtype), or it
+  is a tuple shape, having a shape for every tuple component:
+
+  type shape =
+      TupleShape of shape list
+    | ArrayShape of { dimensions: int list; element_type: dtype }
+  '''
+
+  @staticmethod
+  def tuple_shape(tuple_shapes) -> Shape:
+    "Construct a tuple shape."
+
+  @staticmethod
+  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
+
+  @staticmethod
+  def from_pyval(pyval) -> Shape:
+    "Returns a Shape that describes a tuple-tree of Numpy arrays."
+
+  def __init__(self, str) -> Shape:
+    "Parses a shape string."
+  def __eq__(self, other: Shape) -> bool:
+  def __ne__(self, other: Shape) -> bool:
+  def __hash__(self):
+  def __repr__(self):
+  def is_tuple(self) -> bool:
+  def is_array(self) -> bool:
+  def tuple_shapes(self) -> [Shape]:
+  def numpy_dtype(self) -> np.dtype:
+    "Like element_type(), but returns dtype('O') for a tuple shape."
+  def xla_element_type(self) -> PrimitiveType:
+  def element_type(self) -> np.dtype:
+  def dimensions(self) -> (int, int, ...):
+  def rank(self) -> int:
+  def with_major_to_minor_layout_if_absent(self) -> Shape:
+    "Returns a copy with missing layouts set to major-to-minor."
+
+  def to_serialized_proto(self) -> bytes:
+    "Returns 'shape' as a serialized proto."
+"""
+
+ProgramShape = _xla.ProgramShape
+ProgramShape.__doc__ = """
+A ProgramShape is a C++ object that duck types like the following class.
+
+class ProgramShape:
+  def __init__(self, parameter_shapes, result_shape):
+  def parameter_shapes(self) -> [Shape]:
+  def result_shape(self) -> Shape:
+  def __repr__(self):
+"""
+
+ShapeIndex = _xla.ShapeIndex
+ShapeIndex.__doc__ = """
+A ShapeIndex is an object defined in C++ that duck types like the following
+class:
+
+class ShapeIndex:
+  '''Represents an XLA ShapeIndex.
+
+  An index for specifying a particular nested subshape within a shape. Used in
+  ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through
+  the Shape tree where each element of ShapeIndex indexes into a tuple (or
+  nested tuple) within the shape. For a non-nested tuple, an index has a single
+  element.
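+
+  Example (a sketch): given the tuple shape ((f32[2,3], s32[1]), pred[]),
+
+    ShapeIndex([])      # addresses the whole shape
+    ShapeIndex([1])     # addresses the pred[] leaf
+    ShapeIndex([0, 1])  # addresses the s32[1] leaf inside the first element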
+ ''' + + def __init__(self, List[int]) -> ShapeIndex: + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): +""" + + +def shape_from_pyval(pyval, layout: Sequence[int] | None = None): + """Returns a Shape that describes a tuple-tree of Numpy arrays.""" + + def convert(pyval): + if isinstance(pyval, tuple): + if layout is not None: + raise NotImplementedError( + 'shape_from_pyval does not support layouts for tuple shapes' + ) + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) + + return convert(pyval) + + +DeviceAssignment = _xla.DeviceAssignment +DeviceAssignment.__doc__ = """ +A DeviceAssignment is a C++ object with the following signature. + +def create(assignment): + '''Builds a device assignment. + + Args: + assignment: a 2D numpy array of device ordinal integers, indexed by + [replica][computation_in_replica]. + Returns: + A device assignment. + ''' + +def replica_count(): + '''Returns the number of replicas.''' +def computation_count(): + '''Returns the number of computations per replica.''' +""" + +Device = _xla.Device +CompileOptions = _xla.CompileOptions + +HostBufferSemantics = _xla.HostBufferSemantics + +# An Executable is a C++ class that duck types with the following API: +# class Executable: +# def local_devices(self) -> [Device]: +# def execute(self, arguments : [Buffer]) -> Buffer: +# """Execute on one replica with Buffer arguments and return value.""" +# +# def size_of_generated_code_in_bytes(self) -> int: +# """Return generated binary size, or -1 if not known.""" +# +# def execute_sharded_on_local_devices(self, arguments: [[Buffer]]) +# -> [Buffer]: +# """Execute on many replicas with Buffer arguments and return value. +# +# Args: +# arguments: A sequence of sequences of Buffers. The i'th element of each +# sequence comprises the arguments for execution on the i'th local +# device. +# +# Returns: +# A list of the computation's outputs as a list of Buffers for each +# device. +# """ +# +# There are different implementations of Executable for different backends. + + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + + +def window_padding_type_to_pad_values( + padding_type, lhs_dims, rhs_dims, window_strides +): + """Maps PaddingType or string to pad values (list of pairs of ints).""" + if not isinstance(padding_type, (str, PaddingType)): + msg = 'padding_type must be str or PaddingType, got {}.' + raise TypeError(msg.format(type(padding_type))) + + if isinstance(padding_type, str): + if padding_type.upper() == 'VALID': + padding_type = PaddingType.VALID + elif padding_type.upper() == 'SAME': + padding_type = PaddingType.SAME + else: + msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' 
+ raise ValueError(msg.format(padding_type)) + + if padding_type == PaddingType.VALID: + return [(0, 0)] * len(window_strides) + elif padding_type == PaddingType.SAME: + out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) + pad_sizes = [ + max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size in zip( + out_shape, window_strides, rhs_dims, lhs_dims + ) + ] + return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] + else: + msg = 'Unexpected PaddingType value: {}' + raise ValueError(msg.format(padding_type)) + + +XlaBuilder = _xla.XlaBuilder +XlaComputation = _xla.XlaComputation +XlaOp = _xla.XlaOp +FftType = _xla.FftType +Client = _xla.Client +Memory = _xla.Memory +Array = _xla.Array +ArrayImpl = _xla.ArrayImpl +LoadedExecutable = _xla.LoadedExecutable +DeviceList = _xla.DeviceList +OpSharding = _xla.OpSharding +HloSharding = _xla.HloSharding +Sharding = _xla.Sharding +NamedSharding = _xla.NamedSharding +SingleDeviceSharding = _xla.SingleDeviceSharding +PmapSharding = _xla.PmapSharding +GSPMDSharding = _xla.GSPMDSharding +PjRtLayout = _xla.PjRtLayout +AutotuneCacheMode = _xla.AutotuneCacheMode +ResultAccuracyMode = _xla.ResultAccuracy_Mode + + +def LoadedExecutable_execute(self, arguments, device=None): + del device + results = self.execute_sharded(arguments) + return [x[0] for x in results.disassemble_into_single_device_arrays()] + + +def LoadedExecutable_execute_with_token(self, arguments, device=None): + del device + results = self.execute_sharded(arguments, with_tokens=True) + return ( + [x[0] for x in results.disassemble_into_single_device_arrays()], + results.consume_token().get_token(0), + ) + + +LoadedExecutable.execute = LoadedExecutable_execute +LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token + + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + # Calls to custom call are safe to trace into the command buffer. It means + # that calls to custom call always launch exactly the same device operations + # (can depend on attribute values) that can be captured and then replayed. + # + # Supported only for custom calls implemented with XLA FFI. + COMMAND_BUFFER_COMPATIBLE = 1 + + +class CustomCallHandler(Protocol): + + def __call__( + self, + name: str, + fn: Any, + platform: str, + /, + api_version: int = ..., + traits: CustomCallTargetTraits = ..., + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[ + str, list[tuple[str, Any, int, CustomCallTargetTraits]] +] = {} +_custom_callback_lock = threading.Lock() + + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = 'cpu', + api_version: int = 0, + traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT, +) -> None: + """Registers a custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. + platform: the target platform. + api_version: the XLA FFI version to use. Supported versions are: 0 for the + untyped FFI and 1 for the typed FFI. + traits: custom call traits corresponding to XLA FFI handler traits. + """ + # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" + # Since that is hardcoded to CUDA, we are using the following as workaround. 
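+  # Example (hypothetical, for illustration only): a handler for an AMD GPU
+  # build can pass the XLA platform name directly,
+  #   register_custom_call_target('my_op', my_capsule, platform='ROCM',
+  #                               api_version=1)
+  # where 'my_op' and my_capsule stand in for a real target name and a
+  # PyCapsule wrapping the handler.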
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_callback_lock:
+    if xla_platform_name in _custom_callback_handler:
+      _custom_callback_handler[xla_platform_name](
+          name, fn, xla_platform_name, api_version, traits
+      )
+    else:
+      _custom_callback.setdefault(xla_platform_name, []).append(
+          (name, fn, api_version, traits)
+      )
+
+
+def register_custom_call_handler(
+    platform: str, handler: CustomCallHandler
+) -> None:
+  """Registers a custom handler and uses it to register existing custom calls.
+
+  If a custom call handler for the platform already exists, calling this method
+  is a no-op and it will not register a new handler.
+
+  Args:
+    platform: the target platform.
+    handler: the function to register a custom call.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_callback_lock:
+    if xla_platform_name in _custom_callback_handler:
+      logger.debug(
+          'Custom call handler for %s is already registered. Will not'
+          ' register a new one',
+          xla_platform_name,
+      )
+      return
+    _custom_callback_handler[xla_platform_name] = handler
+    if xla_platform_name in _custom_callback:
+      for name, fn, api_version, traits in _custom_callback[xla_platform_name]:
+        handler(name, fn, xla_platform_name, api_version, traits)
+      del _custom_callback[xla_platform_name]
+
+
+class CustomTypeIdHandler(Protocol):
+
+  def __call__(self, name: str, capsule: Any) -> None:
+    ...
+
+
+_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {}
+_custom_type_id: dict[str, Any] = {}
+_custom_type_id_lock = threading.Lock()
+
+
+def register_custom_type_id(
+    type_name: str,
+    type_id: Any,
+    platform: str = 'cpu',
+) -> None:
+  """Register a custom type id for use with the FFI.
+
+  Args:
+    type_name: a unique name for the type.
+    type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``.
+    platform: the target platform.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_type_id_lock:
+    if xla_platform_name in _custom_type_id_handler:
+      _custom_type_id_handler[xla_platform_name](type_name, type_id)
+    else:
+      _custom_type_id.setdefault(xla_platform_name, []).append(
+          (type_name, type_id)
+      )
+
+
+def register_custom_type_id_handler(
+    platform: str, handler: CustomTypeIdHandler
+) -> None:
+  """Register a custom type id handler and use it to register existing type ids.
+
+  If a custom type id handler for the platform already exists, calling this
+  method is a no-op and it will not register a new handler.
+
+  Args:
+    platform: the target platform.
+    handler: the function to register a custom type id.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_callback_lock:
+    if xla_platform_name in _custom_type_id_handler:
+      logger.debug(
+          'Custom type id handler for %s is already registered. 
Will not ' + 'register a new one', + xla_platform_name, + ) + return + _custom_type_id_handler[xla_platform_name] = handler + if xla_platform_name in _custom_type_id: + for name, capsule in _custom_type_id[xla_platform_name]: + handler(name, capsule) + del _custom_type_id[xla_platform_name] + + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback +hlo_sharding_util = _xla.hlo_sharding_util +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) + + +class PaddingConfigDimension: + """Python representation of a xla.PaddingConfigDimension protobuf.""" + + __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') + + edge_padding_low: int + edge_padding_high: int + interior_padding: int + + def __init__(self): + self.edge_padding_low = 0 + self.edge_padding_high = 0 + self.interior_padding = 0 + + +class PaddingConfig: + """Python representation of a xla.PaddingConfig protobuf.""" + + __slots__ = ('dimensions',) + + def __init__(self): + self.dimensions = [] + + +def make_padding_config( + padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]] +) -> PaddingConfig: + """Create PaddingConfig proto from list of triples of integers. + + Args: + padding_config: either a PaddingConfig or a list of integer triples + (edge_padding_low, edge_padding_high, interior_padding) representing the + configuration of the padding operation. + + Returns: + A `PaddingConfig` object. + """ + if not isinstance(padding_config, PaddingConfig): + triples = padding_config + padding_config = PaddingConfig() + for lo, hi, interior in triples: + dimension = PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) + return padding_config + + +class DotDimensionNumbers: + """Python representation of a xla.DotDimensionNumbers protobuf.""" + + __slots__ = ( + 'lhs_contracting_dimensions', + 'rhs_contracting_dimensions', + 'lhs_batch_dimensions', + 'rhs_batch_dimensions', + ) + + def __init__(self): + self.lhs_contracting_dimensions = [] + self.rhs_contracting_dimensions = [] + self.lhs_batch_dimensions = [] + self.rhs_batch_dimensions = [] + + +def make_dot_dimension_numbers( + dimension_numbers: Union[ + DotDimensionNumbers, + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], + ] +) -> DotDimensionNumbers: + """Builds a DotDimensionNumbers object from a specification. + + Args: + dimension_numbers: either a `DotDimensionNumbers` or a nested tuple + `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of + integers representing the dimensions to treat as contracting dimensions + and batch dimensions on each input operand. + + Returns: + A `DotDimensionNumbers` object. 
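+
+  Example (a sketch): a batched matmul contracting dimension 2 of a [b, m, k]
+  lhs with dimension 1 of a [b, k, n] rhs, batching over dimension 0 of both:
+
+    dnums = make_dot_dimension_numbers((([2], [1]), ([0], [0])))
+    # dnums.lhs_contracting_dimensions == [2]
+    # dnums.rhs_batch_dimensions == [0]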
+ """ + if isinstance(dimension_numbers, (list, tuple)): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto + else: + return dimension_numbers + + +class ConvolutionDimensionNumbers: + """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" + + __slots__ = ( + 'input_batch_dimension', + 'input_feature_dimension', + 'input_spatial_dimensions', + 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', + 'kernel_spatial_dimensions', + 'output_batch_dimension', + 'output_feature_dimension', + 'output_spatial_dimensions', + ) + + def __init__(self): + self.input_batch_dimension = 0 + self.input_feature_dimension = 0 + self.input_spatial_dimensions = [] + self.kernel_input_feature_dimension = 0 + self.kernel_output_feature_dimension = 0 + self.kernel_spatial_dimensions = [] + self.output_batch_dimension = 0 + self.output_feature_dimension = 0 + self.output_spatial_dimensions = [] + + +def make_convolution_dimension_numbers( + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: + """Builds a ConvolutionDimensionNumbers object from a specification. + + Args: + dimension_numbers: optional, either a ConvolutionDimensionNumbers object or + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length + N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the + output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions in + rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers consistent + with the Conv operation with two spatial dimensions, one could use + ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension + numbers consistent with the TensorFlow Conv2D operation, one could use + ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution + dimension specification, window strides are associated with spatial + dimension character labels according to the order in which the labels + appear in the rhs_spec string, so that window_strides[0] is matched with + the dimension corresponding to the first character appearing in rhs_spec + that is not 'I' or 'O'. By default, use the same dimension numbering as + Conv and ConvWithGeneralPadding. + num_spatial_dimensions: the number of spatial dimensions. + + Returns: + A `ConvolutionDimensionNumbers` object. 
+ """ + if dimension_numbers is None: + nd = num_spatial_dimensions + dimension_numbers = ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + elif isinstance(dimension_numbers, tuple): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} + ) + dimension_numbers.input_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]), + ) + ) + dimension_numbers.output_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]), + ) + ) + return dimension_numbers + + +class PrecisionConfig: + """Python representation of a xla.PrecisionConfig protobuf.""" + + __slots__ = ('operand_precision',) + + Precision = _xla.PrecisionConfig_Precision + + def __init__(self): + self.operand_precision = [] + + +class ResultAccuracy: + """Python representation of a xla.ResultAccuracy protobuf.""" + + __slots__ = ('mode', 'atol', 'rtol', 'ulps') + + def __init__(self): + self.mode = _xla.ResultAccuracy_Mode.DEFAULT + self.atol = 0.0 + self.rtol = 0.0 + self.ulps = 0 + + +class GatherDimensionNumbers: + """Python representation of a xla.GatherDimensionNumbers protobuf.""" + + __slots__ = ( + 'offset_dims', + 'collapsed_slice_dims', + 'start_index_map', + 'index_vector_dim', + ) + + def __init__(self): + self.offset_dims = [] + self.collapsed_slice_dims = [] + self.start_index_map = [] + self.index_vector_dim = 0 + + +class ScatterDimensionNumbers: + """Python representation of a xla.ScatterDimensionNumbers protobuf.""" + + __slots__ = ( + 'update_window_dims', + 'inserted_window_dims', + 'scatter_dims_to_operand_dims', + 'index_vector_dim', + ) + + def __init__(self): + self.update_window_dims = [] + self.inserted_window_dims = [] + self.scatter_dims_to_operand_dims = [] + self.index_vector_dim = 0 + + +class ReplicaGroup: + """Python representation of a xla.ReplicaGroup protobuf.""" + + __slots__ = ('replica_ids',) + + def __init__(self): + self.replica_ids = [] + + +def _make_replica_group_proto(replica_group): + replica_group_proto = ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto + + +def make_replica_groups(replica_groups): + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups + ] 
+ return replica_groups_protos + + +Traceback = _xla.Traceback +Frame = _xla.Frame + + +@contextlib.contextmanager +def tracebacks(enabled=True): + """Context manager that enables or disables traceback collection.""" + saved = Traceback.enabled + Traceback.enabled = enabled + try: + yield + finally: + Traceback.enabled = saved + + +def heap_profile(client: Client) -> bytes: + """Returns a gzipped pprof protocol buffer containing a heap profile.""" + return gzip.compress(client.heap_profile()) + + +XlaRuntimeError = _xla.XlaRuntimeError + +# Perform one last garbage collection of deferred Python references. This is +# mostly to keep ASAN happy. +atexit.register(_xla.collect_garbage) + +array_result_handler = _xla.array_result_handler +batched_copy_array_to_devices_with_sharding = ( + _xla.batched_copy_array_to_devices_with_sharding +) +batched_device_put = _xla.batched_device_put +reorder_shards = _xla.reorder_shards +batched_block_until_ready = _xla.batched_block_until_ready +check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind +Layout = _xla.Layout +custom_call_targets = _xla.custom_call_targets +ArrayCopySemantics = _xla.ArrayCopySemantics diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla_client.pyi similarity index 100% rename from jaxlib/xla/xla_client.pyi rename to jaxlib/xla_client.pyi diff --git a/jaxlib/xla/xla_client_backend_independent_test.py b/jaxlib/xla_client_backend_independent_test.py similarity index 99% rename from jaxlib/xla/xla_client_backend_independent_test.py rename to jaxlib/xla_client_backend_independent_test.py index 611c602a73b5..1cd2865bf9a9 100644 --- a/jaxlib/xla/xla_client_backend_independent_test.py +++ b/jaxlib/xla_client_backend_independent_test.py @@ -19,7 +19,7 @@ from absl.testing import absltest import numpy as np -from jax.jaxlib.xla import xla_client +from jax.jaxlib import xla_client # pylint: disable=g-import-not-at-top try: diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla_client_test.py similarity index 99% rename from jaxlib/xla/xla_client_test.py rename to jaxlib/xla_client_test.py index 7779b314ce0e..4bb7f7992e16 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla_client_test.py @@ -30,13 +30,13 @@ import ml_dtypes import numpy as np -from jax.jaxlib.xla import xla_client +from jax.jaxlib import xla_client import jax import jax._src.test_util # pylint: disable=g-import-not-at-top try: - from jax.jaxlib.xla import custom_calls_testlib + from jax.jaxlib import custom_calls_testlib except ImportError: custom_calls_testlib = None diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla_compiler.cc similarity index 99% rename from jaxlib/xla/xla_compiler.cc rename to jaxlib/xla_compiler.cc index 11ec4cdc3846..add3ba9cfc15 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/xla/xla_compiler.h" +#include "jaxlib/xla_compiler.h" #include #include @@ -43,8 +43,8 @@ limitations under the License. 
#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/dlpack.h" -#include "jaxlib/xla/py_client.h" +#include "jaxlib/dlpack.h" +#include "jaxlib/py_client.h" #include "xla/array.h" #include "xla/client/executable_build_options.h" #include "xla/debug_options_flags.h" diff --git a/jaxlib/xla/xla_compiler.h b/jaxlib/xla_compiler.h similarity index 88% rename from jaxlib/xla/xla_compiler.h rename to jaxlib/xla_compiler.h index ca5bc762a7d8..261f630d1cd3 100644 --- a/jaxlib/xla/xla_compiler.h +++ b/jaxlib/xla_compiler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_XLA_XLA_COMPILER_H_ -#define JAXLIB_XLA_XLA_COMPILER_H_ +#ifndef JAXLIB_XLA_COMPILER_H_ +#define JAXLIB_XLA_COMPILER_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildXlaCompilerSubmodule(nanobind::module_& m); } // namespace xla -#endif // JAXLIB_XLA_XLA_COMPILER_H_ +#endif // JAXLIB_XLA_COMPILER_H_ diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla_extension/__init__.pyi similarity index 99% rename from jaxlib/xla/xla_extension/__init__.pyi rename to jaxlib/xla_extension/__init__.pyi index 537c18c7adf1..d80185c0fecd 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla_extension/__init__.pyi @@ -56,6 +56,8 @@ _Status = Any _Dtype = Any _XlaOpMetadata = Any +ifrt_version_number: int + _T = TypeVar("_T") class XlaRuntimeError(RuntimeError): diff --git a/jaxlib/xla/xla_extension/config.pyi b/jaxlib/xla_extension/config.pyi similarity index 100% rename from jaxlib/xla/xla_extension/config.pyi rename to jaxlib/xla_extension/config.pyi diff --git a/jaxlib/xla/xla_extension/guard_lib.pyi b/jaxlib/xla_extension/guard_lib.pyi similarity index 100% rename from jaxlib/xla/xla_extension/guard_lib.pyi rename to jaxlib/xla_extension/guard_lib.pyi diff --git a/jaxlib/xla/xla_extension/ifrt_programs.pyi b/jaxlib/xla_extension/ifrt_programs.pyi similarity index 97% rename from jaxlib/xla/xla_extension/ifrt_programs.pyi rename to jaxlib/xla_extension/ifrt_programs.pyi index bcee365e5732..ddbfdbd15728 100644 --- a/jaxlib/xla/xla_extension/ifrt_programs.pyi +++ b/jaxlib/xla_extension/ifrt_programs.pyi @@ -15,7 +15,7 @@ from typing import Any, Sequence, Union -from jax.jaxlib.xla import xla_extension +from jaxlib import xla_extension class Program: ... 
diff --git a/jaxlib/xla/xla_extension/ifrt_proxy.pyi b/jaxlib/xla_extension/ifrt_proxy.pyi similarity index 96% rename from jaxlib/xla/xla_extension/ifrt_proxy.pyi rename to jaxlib/xla_extension/ifrt_proxy.pyi index 3b5de7aa97c9..636af5931333 100644 --- a/jaxlib/xla/xla_extension/ifrt_proxy.pyi +++ b/jaxlib/xla_extension/ifrt_proxy.pyi @@ -15,7 +15,7 @@ from typing import Any, Optional, Callable -from jax.jaxlib.xla import xla_extension +from jaxlib import xla_extension _Status = Any Client = xla_extension.Client diff --git a/jaxlib/xla/xla_extension/jax_jit.pyi b/jaxlib/xla_extension/jax_jit.pyi similarity index 98% rename from jaxlib/xla/xla_extension/jax_jit.pyi rename to jaxlib/xla_extension/jax_jit.pyi index 1f78d283333c..70d832bc2ce8 100644 --- a/jaxlib/xla/xla_extension/jax_jit.pyi +++ b/jaxlib/xla_extension/jax_jit.pyi @@ -16,7 +16,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple import numpy as np -from jax.jaxlib.xla import xla_extension +from jaxlib import xla_extension from . import pytree diff --git a/jaxlib/xla/xla_extension/mlir.pyi b/jaxlib/xla_extension/mlir.pyi similarity index 100% rename from jaxlib/xla/xla_extension/mlir.pyi rename to jaxlib/xla_extension/mlir.pyi diff --git a/jaxlib/xla/xla_extension/ops.pyi b/jaxlib/xla_extension/ops.pyi similarity index 99% rename from jaxlib/xla/xla_extension/ops.pyi rename to jaxlib/xla_extension/ops.pyi index f76ff1c2002c..2a6b0897a08f 100644 --- a/jaxlib/xla/xla_extension/ops.pyi +++ b/jaxlib/xla_extension/ops.pyi @@ -16,7 +16,7 @@ import enum from typing import Any, Optional, Sequence, overload -from jax.jaxlib.xla import xla_extension +from jaxlib import xla_extension FftType = xla_extension.FftType XlaBuilder = xla_extension.XlaBuilder diff --git a/jaxlib/xla/xla_extension/pmap_lib.pyi b/jaxlib/xla_extension/pmap_lib.pyi similarity index 100% rename from jaxlib/xla/xla_extension/pmap_lib.pyi rename to jaxlib/xla_extension/pmap_lib.pyi diff --git a/jaxlib/xla/xla_extension/profiler.pyi b/jaxlib/xla_extension/profiler.pyi similarity index 100% rename from jaxlib/xla/xla_extension/profiler.pyi rename to jaxlib/xla_extension/profiler.pyi diff --git a/jaxlib/xla/xla_extension/pytree.pyi b/jaxlib/xla_extension/pytree.pyi similarity index 100% rename from jaxlib/xla/xla_extension/pytree.pyi rename to jaxlib/xla_extension/pytree.pyi diff --git a/jaxlib/xla/xla_extension/sdy.pyi b/jaxlib/xla_extension/sdy.pyi similarity index 100% rename from jaxlib/xla/xla_extension/sdy.pyi rename to jaxlib/xla_extension/sdy.pyi diff --git a/jaxlib/xla/xla_extension/transfer_guard_lib.pyi b/jaxlib/xla_extension/transfer_guard_lib.pyi similarity index 100% rename from jaxlib/xla/xla_extension/transfer_guard_lib.pyi rename to jaxlib/xla_extension/transfer_guard_lib.pyi diff --git a/pyproject.toml b/pyproject.toml index be29e16beb9c..8e2bf17edc69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,11 +26,17 @@ module = [ "jaxlib.cpu_feature_guard", "jaxlib.cuda.*", "jaxlib.mlir.*", + "jaxlib.mosaic.dialect.gpu.*", + "jaxlib.mosaic.python._tpu_gen", + "jaxlib.triton.*", "jaxlib.utils", + "jaxlib.version", "jaxlib.xla_extension.utils", "jraph.*", "libtpu.*", "matplotlib.*", + "mlir.*", + "ml_dtypes.*", "nvidia.*", "numpy.*", "opt_einsum.*", From 9ee7bade3513d8a4919a9194a023107492a5d2dd Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 23 Apr 2025 10:58:05 -0700 Subject: [PATCH 0766/1769] Reverts 49e25c6167806ba90efe8370fb04db3f4966437c PiperOrigin-RevId: 750652690 --- jax/_src/core.py | 23 +++--- 
jax/_src/interpreters/partial_eval.py | 108 +++++++++++++++----------- 2 files changed, 71 insertions(+), 60 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 472745871f6a..0183a9942524 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -442,36 +442,34 @@ def __init__(self, aval: AbstractValue): def __repr__(self): return '_' class Literal: - __slots__ = ["val", "aval"] + __slots__ = ["val", "aval", "hash"] val: Any aval: AbstractValue + hash: int | None def __init__(self, val, aval): self.val = val self.aval = aval - - @property - def hash(self): try: - return hash(self.val) + self.hash = hash(val) except TypeError: - if type(self.val) in literalable_types: + if type(val) in literalable_types: try: - return hash((self.val.item(), self.val.dtype)) + self.hash = hash((val.item(), val.dtype)) except (TypeError, AttributeError, ValueError): - return None + self.hash = None __hash__ = None # type: ignore def __repr__(self): - return f'{self.val}' + if hasattr(self, 'hash'): + return f'{self.val}' + else: + return f'Literal(val={self.val})' literalable_types: set[type] = set() -def is_literalable(x: Any) -> bool: - return type(x) in dtypes.python_scalar_dtypes or (type(x) in literalable_types and not np.shape(x)) - Atom = Union[Var, Literal] class Primitive: @@ -2063,7 +2061,6 @@ class DShapedArray(UnshapedArray): array_abstraction_level: int = 3 def __init__(self, shape, dtype, weak_type=False): - assert not any(isinstance(d, Literal) for d in shape) self.shape = shape self.dtype = dtype self.weak_type = weak_type diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index aa9b4a64d6f0..0e888c1591aa 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -198,7 +198,7 @@ def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: if const is None: return tracer else: - if core.is_literalable(const): + if type(const) in core.literalable_types and np.shape(const) == (): return self.new_instantiated_literal(const) else: return self.new_instantiated_const(const) @@ -1647,8 +1647,7 @@ def _origin_msg(self): def get_referent(self): frame = self._trace.frame - var = frame.tracer_to_var.get(id(self)) - val = frame.constvar_to_val.get(var) if isinstance(var, Var) else None + val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) return self if val is None else get_referent(val) core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval @@ -1688,7 +1687,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: class JaxprStackFrame: gensym: Callable[[AbstractValue], Var] - tracer_to_var: dict[TracerId, Atom] + tracer_to_var: dict[TracerId, Var] constid_to_tracer: dict[ConstId, Tracer] constvar_to_val: dict[Var, Any] tracers: list[DynamicJaxprTracer] # hold onto strong refs for all tracers @@ -1726,8 +1725,7 @@ def to_jaxpr( debug_info: core.DebugInfo, ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: - vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] - assert len(vars) == len(set(vars)) + assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) @@ -1740,15 +1738,14 @@ def to_jaxpr( jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, debug_info) jaxpr, constvals = 
_const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) + jaxpr, constvals = _inline_literals(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], debug_info: core.DebugInfo): # It's not necessary, but we keep the tracer-to-var mapping injective: - vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] - assert len(vars) == len(set(vars)) + assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) constvars, constvals = unzip2(self.constvar_to_val.items()) expl_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, @@ -1757,7 +1754,7 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], jaxpr_effects, debug_info) # We can't run check_jaxpr until after we normalize. jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) + jaxpr, constvals = _inline_literals(jaxpr, constvals) jaxpr, out_type = _add_implicit_outputs(jaxpr) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, out_type, constvals @@ -1767,7 +1764,6 @@ def newvar(self, aval): # this aval may have tracers in it, so we replace those with variables new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d for d in aval.shape] - new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape] aval = aval.update(shape=tuple(new_shape)) return self.gensym(aval) @@ -1783,15 +1779,14 @@ def find_progenitors(self, tracer): active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns if any( - v in constvars if type(v) is Var else type(v) is Literal - for v in eqn.invars)] + const_eqns = [eqn for eqn in self.eqns + if {v for v in eqn.invars if type(v) is Var} & constvars] return invar_positions, const_eqns def _const_folding_and_forwarding( jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]: consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) - var_subs: dict[Var, Atom] = {} + var_subs: dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined new_eqns = [] def apply_var_sub(a: Atom) -> Atom: return var_subs.get(a, a) if isinstance(a, Var) else a @@ -1802,20 +1797,14 @@ def apply_var_sub(a: Atom) -> Atom: has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) for eff in eqn.effects) if (eqn.primitive in const_fold_rules and - any(v in consts if isinstance(v, Var) - else isinstance(v, Literal) for v in eqn.invars) and + any(v in consts for v in eqn.invars if isinstance(v, Var)) and not has_input_effect): - consts_in = [consts.get(v) if isinstance(v, Var) else - v.val if isinstance(v, Literal) else None + consts_in = [consts.get(v) if isinstance(v, Var) else None for v in eqn.invars] consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) assert (new_eqn is None) == all(c is not None for c in consts_out) for v, c in zip(eqn.outvars, consts_out): - if c is not None: - if core.is_literalable(c): - var_subs[v] = Literal(c, v.aval) - else: - consts[v] = c + if c is not None: consts[v] = c if new_eqn is None: continue else: eqn = new_eqn # if the application trivially maps 
some inputs to outputs, simplify @@ -1847,26 +1836,54 @@ def apply_var_sub(a: Atom) -> Atom: forwarding_rules: dict[Primitive, ForwardingRule] = {} -def _drop_unused_vars( +def _inline_literals( jaxpr: Jaxpr, constvals: Sequence[Any] ) -> tuple[Jaxpr, list[Any]]: - def vars(atom: Atom) -> list[Var]: - if isinstance(atom, Literal): - return [] - aval = atom.aval + # This function also prunes unused constants and inserts `dropvar` symbols. + input_effects = {eff for eff in jaxpr.effects + if isinstance(eff, effects.JaxprInputEffect)} + # Don't inline any literal with an input effect + has_input_effect = [any(eff.input_index == i for eff in input_effects) + for i in range(len(constvals))] + lits = {v: Literal(c, v.aval) for v, c, e in zip(jaxpr.constvars, constvals, + has_input_effect) + if type(c) in core.literalable_types and not np.shape(c) and not e} + def lit(a: Atom) -> Literal | None: + return (a if isinstance(a, Literal) else lits.get(a) if isinstance(a, Var) + else None) + newname: Callable[[AbstractValue], Var] = core.gensym() + newvars: dict[Var, Var] = {} + newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) + var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) + lit_or_var = ( + lambda a: a if isinstance(a, Literal) else (lit(a) or var(a)) + ) + dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) + + def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: if isinstance(aval, DShapedArray): - return [atom] + [d for d in aval.shape if isinstance(d, Var)] - return [atom] - used: set[Var] = {v for atom in jaxpr.outvars for v in vars(atom)} - for eqn in jaxpr.eqns[::-1]: - eqn.outvars = [v if v in used else DropVar(v.aval) for v in eqn.outvars] - used.update(v for atom in eqn.invars for v in vars(atom)) - cvars, constvals = unzip2( - (v, val) for v, val in zip(jaxpr.constvars, constvals) if v in used) - jaxpr._constvars = list(cvars) - jaxpr._effects = make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, - jaxpr.outvars, jaxpr.eqns) - return jaxpr, list(constvals) + return [d for d in aval.shape if isinstance(d, Var)] + return [] + + used = {v for eqn in jaxpr.eqns for atom in eqn.invars + for v in it.chain([atom], vars_in_shape(atom.aval)) + if isinstance(atom, Var)} + used |= {v for outvar in jaxpr.outvars + for v in it.chain([outvar], vars_in_shape(outvar.aval))} + new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] + new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) + if v in used and not lit(v)] + new_invars = [var(v) for v in jaxpr.invars] + new_eqns = [] + for eqn in jaxpr.eqns: + invars = [lit_or_var(x) for x in eqn.invars] + outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] + new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) + new_outvars = [lit_or_var(v) for v in jaxpr.outvars] + effs = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns) + new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, effs, + jaxpr.debug_info) + return new_jaxpr, new_constvals class DynamicJaxprTrace(core.Trace): @@ -1917,12 +1934,9 @@ def new_const(self, c): def _new_const(self, aval, c) -> DynamicJaxprTracer: tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) self.frame.tracers.append(tracer) - if core.is_literalable(c): - self.frame.tracer_to_var[id(tracer)] = Literal(c, aval) - else: - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) - self.frame.constid_to_tracer[id(c)] = tracer - 
self.frame.constvar_to_val[var] = c + self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) + self.frame.constid_to_tracer[id(c)] = tracer + self.frame.constvar_to_val[var] = c return tracer def _lift_tracers_in_aval(self, aval): From 7877bec7822ff5fd9261ef2ed6b1ecde630d815a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 23 Apr 2025 13:43:46 -0700 Subject: [PATCH 0767/1769] [JAX] Rename xla_extension to _jax. Renaming only, no functional changes intended. There are two reasons to do this: * I want to split some XLA specific things out of the JAX wheel and move them back into the XLA repository. It would be nice if the name "xla" could be reserved for that extension instead. * There are lots of jax-specific things in this extension. PiperOrigin-RevId: 750709831 --- .pre-commit-config.yaml | 2 +- docs/jax-primitives.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 2 +- docs/notebooks/thinking_in_jax.ipynb | 2 +- jax/_src/array.py | 4 +- jax/_src/compiler.py | 4 +- jax/_src/distributed.py | 12 +-- jax/_src/export/_export.py | 12 +-- jax/_src/interpreters/mlir.py | 4 +- jax/_src/lib/BUILD | 2 +- jax/_src/lib/__init__.py | 21 +++-- jax/_src/stages.py | 10 +-- jax/_src/traceback_util.py | 6 +- jax/_src/xla_bridge.py | 6 +- jax/experimental/_private_mm/mini_dime.py | 6 +- .../array_serialization/serialization.py | 4 +- jax/experimental/jax2tf/call_tf.py | 4 +- .../jax2tf/tests/back_compat_tf_test.py | 4 +- jax/experimental/topologies.py | 4 +- jax/extend/ifrt_programs.py | 6 +- jax/lib/xla_extension.py | 76 +++++++++---------- jaxlib/BUILD | 10 +-- jaxlib/{xla_extension => _jax}/__init__.pyi | 0 jaxlib/{xla_extension => _jax}/config.pyi | 0 jaxlib/{xla_extension => _jax}/guard_lib.pyi | 0 .../{xla_extension => _jax}/ifrt_programs.pyi | 6 +- jaxlib/{xla_extension => _jax}/ifrt_proxy.pyi | 4 +- jaxlib/{xla_extension => _jax}/jax_jit.pyi | 6 +- jaxlib/{xla_extension => _jax}/mlir.pyi | 0 jaxlib/{xla_extension => _jax}/ops.pyi | 20 ++--- jaxlib/{xla_extension => _jax}/pmap_lib.pyi | 0 jaxlib/{xla_extension => _jax}/profiler.pyi | 0 jaxlib/{xla_extension => _jax}/pytree.pyi | 0 jaxlib/{xla_extension => _jax}/sdy.pyi | 0 .../transfer_guard_lib.pyi | 0 jaxlib/pjit.cc | 2 +- jaxlib/pmap_lib.cc | 2 +- jaxlib/tools/BUILD.bazel | 4 +- jaxlib/tools/build_wheel.py | 6 +- jaxlib/xla.cc | 2 +- jaxlib/xla/BUILD | 2 +- jaxlib/xla/xla_extension.py | 2 +- jaxlib/xla_client.py | 2 +- jaxlib/xla_client.pyi | 68 ++++++++--------- pyproject.toml | 2 +- tests/api_test.py | 12 +-- tests/checkify_test.py | 4 +- tests/notebooks/colab_cpu.ipynb | 1 - tests/pallas/tpu_pallas_test.py | 4 +- tests/pjit_test.py | 10 +-- tests/pmap_test.py | 8 +- tests/unary_ops_accuracy_test.py | 4 +- 52 files changed, 191 insertions(+), 183 deletions(-) rename jaxlib/{xla_extension => _jax}/__init__.pyi (100%) rename jaxlib/{xla_extension => _jax}/config.pyi (100%) rename jaxlib/{xla_extension => _jax}/guard_lib.pyi (100%) rename jaxlib/{xla_extension => _jax}/ifrt_programs.pyi (89%) rename jaxlib/{xla_extension => _jax}/ifrt_proxy.pyi (94%) rename jaxlib/{xla_extension => _jax}/jax_jit.pyi (95%) rename jaxlib/{xla_extension => _jax}/mlir.pyi (100%) rename jaxlib/{xla_extension => _jax}/ops.pyi (97%) rename jaxlib/{xla_extension => _jax}/pmap_lib.pyi (100%) rename jaxlib/{xla_extension => _jax}/profiler.pyi (100%) rename jaxlib/{xla_extension => _jax}/pytree.pyi (100%) rename jaxlib/{xla_extension => _jax}/sdy.pyi (100%) rename jaxlib/{xla_extension => _jax}/transfer_guard_lib.pyi (100%) diff --git 
a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f36862711f04..3fcfdb54bada 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: hooks: - id: mypy files: (jax/|tests/typing_test\.py) - exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead + exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jaxlib/_jax/.* # Use pyi instead additional_dependencies: [types-requests==2.31.0, numpy>=2.2.0] args: [--config=pyproject.toml] diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index 819d0418e894..fab5334b4010 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -100,7 +100,7 @@ def trace(name): vtype = str(type(v)) if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: return "" - elif "jaxlib.xla_extension.XlaOp" in vtype: + elif "jaxlib._jax_.XlaOp" in vtype: return "".format(id(v)) elif ("partial_eval.JaxprTracer" in vtype or "batching.BatchTracer" in vtype or diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index a9d4a9424f9f..5879630ac818 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -346,7 +346,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index d6cbf6e02198..28d5f20deab6 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -139,7 +139,7 @@ { "data": { "text/plain": [ - "jaxlib.xla_extension.ArrayImpl" + "jaxlib._jax.ArrayImpl" ] }, "execution_count": 4, diff --git a/jax/_src/array.py b/jax/_src/array.py index a802d122a257..f2b070c8221d 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -38,7 +38,7 @@ from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension as xe +from jax._src.lib import _jax from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, @@ -554,7 +554,7 @@ def layout(self): try: return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), self.sharding) - except xe.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): return Layout(None, self.sharding) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 9ac47aa4f0ea..3d2ed0ccd050 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -575,7 +575,7 @@ def _share_fdo_profiles( devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, - global_client: lib.xla_extension.DistributedRuntimeClient, + global_client: lib._jax.DistributedRuntimeClient, min_process_id ) -> bytes | None: sym_name = computation.operation.attributes['sym_name'] @@ -638,7 +638,7 @@ def _compile_and_share_module( computation: ir.Module, 
compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], - global_client: lib.xla_extension.DistributedRuntimeClient, + global_client: lib._jax.DistributedRuntimeClient, module_name: str, cache_key: str, first_process_id: int diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 1dc75b6d8dfc..ef8c48a61293 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -22,7 +22,7 @@ from jax._src import clusters from jax._src import config from jax._src import xla_bridge -from jax._src.lib import xla_extension +from jax._src.lib import _jax logger = logging.getLogger(__name__) @@ -37,8 +37,8 @@ class State: process_id: int = 0 num_processes: int = 1 - service: xla_extension.DistributedRuntimeService | Any | None = None - client: xla_extension.DistributedRuntimeClient | Any | None = None + service: _jax.DistributedRuntimeService | Any | None = None + client: _jax.DistributedRuntimeClient | Any | None = None preemption_sync_manager: Any | None = None coordinator_address: str | None = None slice_index: int | None = None @@ -132,7 +132,7 @@ def initialize(self, logger.info( 'Starting JAX distributed service on %s', coordinator_bind_address ) - self.service = xla_extension.get_distributed_runtime_service( + self.service = _jax.get_distributed_runtime_service( coordinator_bind_address, num_processes, heartbeat_interval=service_heartbeat_interval_seconds, max_missing_heartbeats=service_max_missing_heartbeats) @@ -142,7 +142,7 @@ def initialize(self, if self.client is not None: raise RuntimeError('distributed.initialize should only be called once.') - self.client = xla_extension.get_distributed_runtime_client( + self.client = _jax.get_distributed_runtime_client( coordinator_address, process_id, init_timeout=initialization_timeout, heartbeat_interval=client_heartbeat_interval_seconds, max_missing_heartbeats=client_max_missing_heartbeats, use_compression=True) @@ -170,7 +170,7 @@ def initialize_preemption_sync_manager(self): raise RuntimeError( 'Preemption sync manager should only be initialized once.') self.preemption_sync_manager = ( - xla_extension.create_preemption_sync_manager()) + _jax.create_preemption_sync_manager()) self.preemption_sync_manager.initialize(self.client) global_state = State() diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index cbd92f86b835..e01eca4a62f4 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -43,7 +43,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib import xla_client -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.lib.mlir import ir, passmanager from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect @@ -691,7 +691,7 @@ def _export_lowered( # Shardy was used during lowering if we can find the Shardy mesh in the # module. Note that the mesh should have been lifted by the # `sdy-lift-inlined-meshes` pass in mlir.py. 
- shardy_enabled = xla_extension.sdy.lowered_with_shardy( + shardy_enabled = _jax.sdy.lowered_with_shardy( mlir.module_to_bytecode(mlir_module)) mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled) @@ -811,7 +811,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: if shardy_enabled: - mlir_str = xla_extension.sdy.sdy_round_trip_export_pipeline( + mlir_str = _jax.sdy.sdy_round_trip_export_pipeline( mlir.module_to_bytecode(module)) else: mlir_str = mlir.module_to_bytecode(module) @@ -1443,10 +1443,10 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) - shardy_enabled = xla_extension.sdy.lowered_with_shardy( + shardy_enabled = _jax.sdy.lowered_with_shardy( mlir.module_to_bytecode(submodule)) if shardy_enabled: - submodule = ir.Module.parse(xla_extension.sdy.sdy_round_trip_import_shardings( + submodule = ir.Module.parse(_jax.sdy.sdy_round_trip_import_shardings( mlir.module_to_bytecode(submodule))) with submodule.context: @@ -1456,7 +1456,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, mesh = None if shardy_enabled: - sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule)) + sdy_mesh_axes = _jax.sdy.get_mesh(mlir.module_to_bytecode(submodule)) mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else mesh_lib.empty_abstract_mesh) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 80cbf76242aa..a6debad7cbdb 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -56,7 +56,7 @@ SdyArraySharding, SdyArrayShardingList) from jax._src.util import foreach from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.lib.mlir import dialects, ir, passmanager from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir import register_jax_dialects @@ -3031,7 +3031,7 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: Then verifies that there are no more dynamic shapes in the module. 
""" try: - refine_polymorphic_shapes = partial(xla_extension.mlir.refine_polymorphic_shapes, + refine_polymorphic_shapes = partial(_jax.mlir.refine_polymorphic_shapes, mlir_module=module_to_bytecode(module), enable_shape_assertions=True, validate_static_shapes=True) diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 0c8acdd76630..dd8ab0557657 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -46,7 +46,7 @@ py_library_providing_imports_info( "//jaxlib:utils", "//jaxlib:weakref_lru_cache", "//jaxlib:xla_client", - "//jaxlib:xla_extension", + "//jaxlib:_jax", "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index e9f9d95f608f..7c75ac22cbe1 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -85,13 +85,22 @@ def _parse_version(v: str) -> tuple[int, ...]: import jaxlib.lapack as lapack # noqa: F401 import jaxlib.utils as utils # noqa: F401 -import jaxlib.xla_extension as xla_extension # noqa: F401 -from jaxlib.xla_extension import guard_lib as guard_lib # noqa: F401 -from jaxlib.xla_extension import jax_jit as jax_jit # noqa: F401 -from jaxlib.xla_extension import pmap_lib as pmap_lib # noqa: F401 -from jaxlib.xla_extension import pytree as pytree # noqa: F401 -from jaxlib.xla_extension import Device as Device # noqa: F401 +if version >= (0, 6, 1): + import jaxlib._jax as _jax # noqa: F401 + from jaxlib._jax import guard_lib as guard_lib # noqa: F401 + from jaxlib._jax import jax_jit as jax_jit # noqa: F401 + from jaxlib._jax import pmap_lib as pmap_lib # noqa: F401 + from jaxlib._jax import pytree as pytree # noqa: F401 + from jaxlib._jax import Device as Device # noqa: F401 +else: + import jaxlib.xla_extension as _jax # type: ignore # pytype: disable=import-error # noqa: F401 + from jaxlib.xla_extension import guard_lib as guard_lib # type: ignore # pytype: disable=import-error # noqa: F401 + from jaxlib.xla_extension import jax_jit as jax_jit # type: ignore # pytype: disable=import-error # noqa: F401 + from jaxlib.xla_extension import pmap_lib as pmap_lib # type: ignore # pytype: disable=import-error # noqa: F401 + from jaxlib.xla_extension import pytree as pytree # type: ignore # pytype: disable=import-error # noqa: F401 + from jaxlib.xla_extension import Device as Device # type: ignore # pytype: disable=import-error # noqa: F401 + import jaxlib.xla_client as xla_client # noqa: F401 diff --git a/jax/_src/stages.py b/jax/_src/stages.py index b813037a3204..3c5d710f3bdc 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -47,6 +47,7 @@ from jax._src.layout import Layout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir +from jax._src.lib import _jax from jax._src.lib import xla_client as xc @@ -54,7 +55,6 @@ traceback_util.register_exclusion(__file__) -xla_extension = xc._xla map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -121,7 +121,7 @@ def as_text(self) -> str: raise NotImplementedError(err_msg) try: return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()]) - except xla_extension.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): raise NotImplementedError(err_msg) from e @@ -146,7 +146,7 @@ def cost_analysis(self) -> Any: if hasattr(xla_ext_exe, "cost_analysis"): try: return xla_ext_exe.cost_analysis() - except xla_extension.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: 
msg, *_ = e.args if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")): raise @@ -183,7 +183,7 @@ def memory_analysis(self) -> Any: raise NotImplementedError(err_msg) try: return xla_ext_exe.get_compiled_memory_stats() - except xla_extension.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): raise NotImplementedError(err_msg) from e @@ -212,7 +212,7 @@ def hlo(self) -> xc.XlaComputation: hlo = self.stablehlo() m: str | bytes m = mlir.module_to_bytecode(hlo) - return xla_extension.mlir.mlir_module_to_xla_computation( + return _jax.mlir.mlir_module_to_xla_computation( m, use_tuple_args=self.compile_args["tuple_args"]) def stablehlo(self) -> ir.Module: diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index cde9e4a30f99..60276a22b4cf 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -24,7 +24,7 @@ from jax._src import config from jax._src import util -from jax._src.lib import xla_extension +from jax._src.lib import _jax C = TypeVar("C", bound=Callable[..., Any]) @@ -200,10 +200,10 @@ def reraise_with_filtered_traceback(*args, **kwargs): # just setting __traceback__ is enough. Since it is no longer needed, # the XLA extension no longer defines a traceback-replacing method at # Python 3.11 and onward. - if hasattr(xla_extension, "replace_thread_exc_traceback"): + if hasattr(_jax, "replace_thread_exc_traceback"): # TODO(kidger): remove this line once Python 3.11 is the minimum supported # version. - xla_extension.replace_thread_exc_traceback(filtered_tb) + _jax.replace_thread_exc_traceback(filtered_tb) if sys.version_info >= (3, 11) and mode == "quiet_remove_frames": e.add_note("--------------------\n" + _simplified_tb_msg) else: diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 25df9501d49a..644c395cb551 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -43,7 +43,7 @@ from jax._src.cloud_tpu_init import get_tpu_library_path from jax._src.lib import cuda_versions from jax._src.lib import xla_client -from jax._src.lib import xla_extension +from jax._src.lib import _jax logger = logging.getLogger(__name__) @@ -888,7 +888,7 @@ def _suggest_missing_backends(): assert _default_backend is not None default_platform = _default_backend.platform if "cuda" not in _backends and hardware_utils.has_visible_nvidia_gpu(): - if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors: + if hasattr(_jax, "GpuAllocatorConfig") and "cuda" in _backend_errors: err = _backend_errors["cuda"] warning_msg = f"CUDA backend failed to initialize: {err}." if "no supported devices found for platform CUDA." in err: @@ -1030,7 +1030,7 @@ def devices( ) -> list[xla_client.Device]: """Returns a list of all devices for a given backend. - .. currentmodule:: jaxlib.xla_extension + .. currentmodule:: jaxlib._jax Each device is represented by a subclass of :class:`Device` (e.g. :class:`CpuDevice`, :class:`GpuDevice`). 
The length of the returned list is diff --git a/jax/experimental/_private_mm/mini_dime.py b/jax/experimental/_private_mm/mini_dime.py index 971d5a016817..f12084b3a1ce 100644 --- a/jax/experimental/_private_mm/mini_dime.py +++ b/jax/experimental/_private_mm/mini_dime.py @@ -49,8 +49,8 @@ import jax import jax.numpy as jnp -import jaxlib.xla_extension as xe from jax._src import array +from jax._src.lib import _jax from jax._src.op_shardings import are_op_shardings_equal @@ -66,10 +66,10 @@ def _get_nccl_dtype_and_count(arr, count=None): return nccl_dtype, count -def get_distributed_client() -> xe.DistributedRuntimeClient: +def get_distributed_client() -> _jax.DistributedRuntimeClient: from jax._src.distributed import global_state - assert isinstance(global_state.client, xe.DistributedRuntimeClient) + assert isinstance(global_state.client, _jax.DistributedRuntimeClient) return global_state.client diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index fd37694e2ce7..82e9e3dc938b 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -33,7 +33,7 @@ from jax._src import typing from jax._src import util from jax._src.layout import Layout -from jax._src.lib import xla_extension as xe +from jax._src.lib import _jax from jax.experimental.array_serialization import tensorstore_impl as ts_impl # ruff: noqa: F401 # pylint: disable=unused-import @@ -248,7 +248,7 @@ def check_for_errors(self): # Clears self._exception so it is only raised once. exception = self._exception self._exception = None - if (isinstance(exception, xe.XlaRuntimeError) and + if (isinstance(exception, _jax.XlaRuntimeError) and 'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)): raise BarrierTimeoutError( '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG])) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 73b7544f991a..2aadc3a9d512 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -40,7 +40,7 @@ from jax._src import core from jax._src import effects from jax._src import util -from jax._src.lib import xla_extension as _xla +from jax._src.lib import _jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo @@ -596,7 +596,7 @@ def convert_to_spec(x): "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e - stablehlo = _xla.mlir.hlo_to_stablehlo(func_tf_hlo) + stablehlo = _jax.mlir.hlo_to_stablehlo(func_tf_hlo) submodule = ir.Module.parse(stablehlo) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 2cf363b0cfb2..0b75c679f5e6 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -30,7 +30,7 @@ import jax from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax.experimental import jax2tf from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function import jax.numpy as jnp @@ -96,7 +96,7 @@ def serialize( for op in tf_graph.get_operations(): if op.type == "XlaCallModule": serialized_module = 
op.get_attr("module") - module_str = xla_extension.mlir.deserialize_portable_artifact( + module_str = _jax.mlir.deserialize_portable_artifact( serialized_module ) module_version = op.get_attr("version") diff --git a/jax/experimental/topologies.py b/jax/experimental/topologies.py index 06be2b74853f..94b63769f101 100644 --- a/jax/experimental/topologies.py +++ b/jax/experimental/topologies.py @@ -19,7 +19,7 @@ import jax from jax.experimental import mesh_utils from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src import xla_bridge as xb Device = xc.Device @@ -46,7 +46,7 @@ def get_topology_desc( try: topology = xb.make_pjrt_topology(platform, topology_name, **kwargs) return TopologyDescription(topology._make_compile_only_devices()) - except xla_extension.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if msg.startswith("UNIMPLEMENTED"): raise NotImplementedError(msg) from e diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py index 715dfd43592c..13ba9088bc55 100644 --- a/jax/extend/ifrt_programs.py +++ b/jax/extend/ifrt_programs.py @@ -15,8 +15,8 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.lib import xla_extension as _xe +from jax._src.lib import _jax -ifrt_programs = _xe.ifrt_programs +ifrt_programs = _jax.ifrt_programs -del _xe +del _jax diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 8f1b27070e98..452d004b2f6d 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.lib import xla_extension as _xe +from jax._src.lib import _jax _deprecations = { "ArrayImpl": ( @@ -35,100 +35,100 @@ "jax.lib.xla_extension.DistributedRuntimeClient is" " deprecated; use jax.distributed instead." ), - _xe.DistributedRuntimeClient, + _jax.DistributedRuntimeClient, ), "get_distributed_runtime_client": ( ( "jax.lib.xla_extension.get_distributed_runtime_client is" " deprecated; use jax.distributed instead." ), - _xe.get_distributed_runtime_client, + _jax.get_distributed_runtime_client, ), "get_distributed_runtime_service": ( ( "jax.lib.xla_extension.get_distributed_runtime_service is" " deprecated; use jax.distributed instead." 
), - _xe.get_distributed_runtime_service, + _jax.get_distributed_runtime_service, ), "Device": ( "jax.lib.xla_extension.Device is deprecated; use jax.Device instead.", - _xe.Device, + _jax.Device, ), "PjitFunctionCache": ( "jax.lib.xla_extension.PjitFunctionCache is deprecated.", - _xe.PjitFunctionCache, + _jax.PjitFunctionCache, ), "ifrt_proxy": ( "jax.lib.xla_extension.ifrt_proxy is deprecated.", - _xe.ifrt_proxy, + _jax.ifrt_proxy, ), "jax_jit": ( "jax.lib.xla_extension.jax_jit is deprecated.", - _xe.jax_jit, + _jax.jax_jit, ), - "mlir": ("jax.lib.xla_extension.mlir is deprecated.", _xe.mlir), - "pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _xe.pmap_lib), + "mlir": ("jax.lib.xla_extension.mlir is deprecated.", _jax.mlir), + "pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _jax.pmap_lib), "profiler": ( "jax.lib.xla_extension.profiler is deprecated.", - _xe.profiler, + _jax.profiler, ), "pytree": ( "jax.lib.xla_extension.pytree is deprecated.", - _xe.pytree, + _jax.pytree, ), "hlo_module_cost_analysis": ( "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.", - _xe.hlo_module_cost_analysis, + _jax.hlo_module_cost_analysis, ), "hlo_module_to_dot_graph": ( "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.", - _xe.hlo_module_to_dot_graph, + _jax.hlo_module_to_dot_graph, ), "HloModule": ( "jax.lib.xla_extension.HloModule is deprecated.", - _xe.HloModule, + _jax.HloModule, ), "HloPrintOptions": ( "jax.lib.xla_extension.HloPrintOptions is deprecated.", - _xe.HloPrintOptions, + _jax.HloPrintOptions, ), "OpSharding": ( "jax.lib.xla_extension.OpSharding is deprecated.", - _xe.OpSharding, + _jax.OpSharding, ), "PjitFunction": ( "jax.lib.xla_extension.PjitFunction is deprecated.", - _xe.PjitFunction, + _jax.PjitFunction, ), "PmapFunction": ( "jax.lib.xla_extension.PmapFunction is deprecated.", - _xe.PmapFunction, + _jax.PmapFunction, ), } import typing as _typing if _typing.TYPE_CHECKING: - Device = _xe.Device - DistributedRuntimeClient = _xe.DistributedRuntimeClient - HloModule = _xe.HloModule - HloPrintOptions = _xe.HloPrintOptions - OpSharding = _xe.OpSharding - PjitFunction = _xe.PjitFunction - PjitFunctionCache = _xe.PjitFunctionCache - PmapFunction = _xe.PmapFunction + Device = _jax.Device + DistributedRuntimeClient = _jax.DistributedRuntimeClient + HloModule = _jax.HloModule + HloPrintOptions = _jax.HloPrintOptions + OpSharding = _jax.OpSharding + PjitFunction = _jax.PjitFunction + PjitFunctionCache = _jax.PjitFunctionCache + PmapFunction = _jax.PmapFunction - get_distributed_runtime_client = _xe.get_distributed_runtime_client - get_distributed_runtime_service = _xe.get_distributed_runtime_service - hlo_module_cost_analysis = _xe.hlo_module_cost_analysis - hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph - ifrt_proxy = _xe.ifrt_proxy - jax_jit = _xe.jax_jit - mlir = _xe.mlir - pmap_lib = _xe.pmap_lib - profiler = _xe.profiler - pytree = _xe.pytree + get_distributed_runtime_client = _jax.get_distributed_runtime_client + get_distributed_runtime_service = _jax.get_distributed_runtime_service + hlo_module_cost_analysis = _jax.hlo_module_cost_analysis + hlo_module_to_dot_graph = _jax.hlo_module_to_dot_graph + ifrt_proxy = _jax.ifrt_proxy + jax_jit = _jax.jax_jit + mlir = _jax.mlir + pmap_lib = _jax.pmap_lib + profiler = _jax.profiler + pytree = _jax.pytree else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr @@ -136,4 +136,4 @@ __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr 
del _typing -del _xe +del _jax diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 2e1e4072cb55..7c00e3f5b99d 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -74,8 +74,8 @@ py_library_providing_imports_info( ":jax", ":utils", ":weakref_lru_cache", + "//jaxlib:_jax", "//jaxlib:xla_client", - "//jaxlib:xla_extension", "//jaxlib/cpu:_lapack", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", @@ -133,9 +133,9 @@ pywrap_library( }), }, deps = [ + ":_jax", ":utils", ":weakref_lru_cache", - "//jaxlib:xla_extension", "//jaxlib/mlir/_mlir_libs:_chlo", "//jaxlib/mlir/_mlir_libs:_mlir", "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", @@ -275,10 +275,10 @@ nanobind_pywrap_extension( ) nanobind_pywrap_extension( - name = "xla_extension", + name = "_jax", srcs = ["xla.cc"], pytype_deps = py_deps(["numpy"]), - pytype_srcs = glob(["xla_extension/*.pyi"]), + pytype_srcs = glob(["_jax/*.pyi"]), visibility = ["//visibility:public"], deps = [ ":config", @@ -1190,7 +1190,7 @@ pytype_strict_library( deps = py_deps([ "numpy", "ml_dtypes", - ]) + [":xla_extension"], + ]) + [":_jax"], ) py_strict_test( diff --git a/jaxlib/xla_extension/__init__.pyi b/jaxlib/_jax/__init__.pyi similarity index 100% rename from jaxlib/xla_extension/__init__.pyi rename to jaxlib/_jax/__init__.pyi diff --git a/jaxlib/xla_extension/config.pyi b/jaxlib/_jax/config.pyi similarity index 100% rename from jaxlib/xla_extension/config.pyi rename to jaxlib/_jax/config.pyi diff --git a/jaxlib/xla_extension/guard_lib.pyi b/jaxlib/_jax/guard_lib.pyi similarity index 100% rename from jaxlib/xla_extension/guard_lib.pyi rename to jaxlib/_jax/guard_lib.pyi diff --git a/jaxlib/xla_extension/ifrt_programs.pyi b/jaxlib/_jax/ifrt_programs.pyi similarity index 89% rename from jaxlib/xla_extension/ifrt_programs.pyi rename to jaxlib/_jax/ifrt_programs.pyi index ddbfdbd15728..8c525de478be 100644 --- a/jaxlib/xla_extension/ifrt_programs.pyi +++ b/jaxlib/_jax/ifrt_programs.pyi @@ -15,7 +15,7 @@ from typing import Any, Sequence, Union -from jaxlib import xla_extension +from jaxlib import _jax class Program: ... @@ -26,7 +26,7 @@ def make_hlo_program(mlir_module: Union[str, bytes]) -> Program: ... def make_colocated_python_program( name : str, picked_function: bytes, - devices: Sequence[xla_extension.Device] | xla_extension.DeviceList, + devices: Sequence[_jax.Device] | _jax.DeviceList, input_avals: Sequence[Any], output_avals: Sequence[Any], ) -> Program: ... @@ -36,7 +36,7 @@ def make_plugin_program(data: Union[str, bytes]) -> Program: ... def make_colocated_python_compile_options() -> CompileOptions: ... def make_xla_compile_options( - compile_options: xla_extension.CompileOptions, + compile_options: _jax.CompileOptions, host_callbacks: Sequence[Any] ) -> CompileOptions: ... 
diff --git a/jaxlib/xla_extension/ifrt_proxy.pyi b/jaxlib/_jax/ifrt_proxy.pyi similarity index 94% rename from jaxlib/xla_extension/ifrt_proxy.pyi rename to jaxlib/_jax/ifrt_proxy.pyi index 636af5931333..77963eae0f7e 100644 --- a/jaxlib/xla_extension/ifrt_proxy.pyi +++ b/jaxlib/_jax/ifrt_proxy.pyi @@ -15,10 +15,10 @@ from typing import Any, Optional, Callable -from jaxlib import xla_extension +from jaxlib import _jax _Status = Any -Client = xla_extension.Client +Client = _jax.Client class ClientConnectionOptions: diff --git a/jaxlib/xla_extension/jax_jit.pyi b/jaxlib/_jax/jax_jit.pyi similarity index 95% rename from jaxlib/xla_extension/jax_jit.pyi rename to jaxlib/_jax/jax_jit.pyi index 70d832bc2ce8..fd39ef01963e 100644 --- a/jaxlib/xla_extension/jax_jit.pyi +++ b/jaxlib/_jax/jax_jit.pyi @@ -16,12 +16,12 @@ from typing import Any, Callable, Optional, Sequence, Tuple import numpy as np -from jaxlib import xla_extension +from jaxlib import _jax from . import pytree -Client = xla_extension.Client -Device = xla_extension.Device +Client = _jax.Client +Device = _jax.Device class JitState: diff --git a/jaxlib/xla_extension/mlir.pyi b/jaxlib/_jax/mlir.pyi similarity index 100% rename from jaxlib/xla_extension/mlir.pyi rename to jaxlib/_jax/mlir.pyi diff --git a/jaxlib/xla_extension/ops.pyi b/jaxlib/_jax/ops.pyi similarity index 97% rename from jaxlib/xla_extension/ops.pyi rename to jaxlib/_jax/ops.pyi index 2a6b0897a08f..7f5e46cabbdf 100644 --- a/jaxlib/xla_extension/ops.pyi +++ b/jaxlib/_jax/ops.pyi @@ -16,17 +16,17 @@ import enum from typing import Any, Optional, Sequence, overload -from jaxlib import xla_extension +from jaxlib import _jax -FftType = xla_extension.FftType -XlaBuilder = xla_extension.XlaBuilder -XlaComputation = xla_extension.XlaComputation -XlaOp = xla_extension.XlaOp -PrecisionConfig_Precision = xla_extension.PrecisionConfig_Precision -PrimitiveType = xla_extension.PrimitiveType -Shape = xla_extension.Shape -ShapeIndex = xla_extension.ShapeIndex -ResultAccuracy = xla_extension.ResultAccuracy +FftType = _jax.FftType +XlaBuilder = _jax.XlaBuilder +XlaComputation = _jax.XlaComputation +XlaOp = _jax.XlaOp +PrecisionConfig_Precision = _jax.PrecisionConfig_Precision +PrimitiveType = _jax.PrimitiveType +Shape = _jax.Shape +ShapeIndex = _jax.ShapeIndex +ResultAccuracy = _jax.ResultAccuracy _ChannelHandle = Any _ConvDimensionNumbers = Any diff --git a/jaxlib/xla_extension/pmap_lib.pyi b/jaxlib/_jax/pmap_lib.pyi similarity index 100% rename from jaxlib/xla_extension/pmap_lib.pyi rename to jaxlib/_jax/pmap_lib.pyi diff --git a/jaxlib/xla_extension/profiler.pyi b/jaxlib/_jax/profiler.pyi similarity index 100% rename from jaxlib/xla_extension/profiler.pyi rename to jaxlib/_jax/profiler.pyi diff --git a/jaxlib/xla_extension/pytree.pyi b/jaxlib/_jax/pytree.pyi similarity index 100% rename from jaxlib/xla_extension/pytree.pyi rename to jaxlib/_jax/pytree.pyi diff --git a/jaxlib/xla_extension/sdy.pyi b/jaxlib/_jax/sdy.pyi similarity index 100% rename from jaxlib/xla_extension/sdy.pyi rename to jaxlib/_jax/sdy.pyi diff --git a/jaxlib/xla_extension/transfer_guard_lib.pyi b/jaxlib/_jax/transfer_guard_lib.pyi similarity index 100% rename from jaxlib/xla_extension/transfer_guard_lib.pyi rename to jaxlib/_jax/transfer_guard_lib.pyi diff --git a/jaxlib/pjit.cc b/jaxlib/pjit.cc index 24b99770b767..8c8800e80706 100644 --- a/jaxlib/pjit.cc +++ b/jaxlib/pjit.cc @@ -1301,7 +1301,7 @@ void BuildPjitSubmodule(nb::module_& m) { } nb::object cfun = nb::borrow(PjitFunction_Type); - // Add PjitFunction to the 
xla_extension module so it can be pickled. + // Add PjitFunction to the _jax module so it can be pickled. m.attr("PjitFunction") = cfun; cfun.attr("__getstate__") = nb::cpp_function( [](const PjitFunction::object& self) { diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc index 527bc022237f..c29c2d1eb2b5 100644 --- a/jaxlib/pmap_lib.cc +++ b/jaxlib/pmap_lib.cc @@ -1046,7 +1046,7 @@ void BuildPmapSubmodule(nb::module_& m) { } nb::object cfun = nb::borrow(JaxPmapFunction_Type); - // Add PmapFunction to the xla_extension module so it can be pickled. + // Add PmapFunction to the _jax module so it can be pickled. m.attr("PmapFunction") = cfun; cfun.attr("__signature__") = diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 216313303f3e..78db77b8521d 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -67,7 +67,7 @@ py_binary( "//jaxlib:jaxlib_binaries", "//jaxlib:setup.py", "//jaxlib:xla_client.py", - "//jaxlib:xla_extension", + "//jaxlib:_jax", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", @@ -261,7 +261,7 @@ wheel_sources( data_srcs = [ "//jaxlib", "//jaxlib:jaxlib_binaries", - "//jaxlib:xla_extension", + "//jaxlib:_jax", ], hdr_srcs = [ "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index af17fed38804..b40ccd6f6870 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -105,7 +105,7 @@ def patch_copy_mlir_import( def verify_mac_libraries_dont_reference_chkstack( runfiles=None, wheel_sources_map=None ): - """Verifies that xla_extension.so doesn't depend on ____chkstk_darwin. + """Verifies that _jax.so doesn't depend on ____chkstk_darwin. We don't entirely know why this happens, but in some build environments we seem to target the wrong Mac OS version. @@ -116,7 +116,7 @@ def verify_mac_libraries_dont_reference_chkstack( if not _is_mac(): return file_path = _get_file_path( - f"__main__/jaxlib/xla_extension.{pyext}", runfiles, wheel_sources_map + f"__main__/jaxlib/_jax.{pyext}", runfiles, wheel_sources_map ) nm = subprocess.run( ["nm", "-g", file_path], @@ -208,7 +208,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): f"{source_file_prefix}jaxlib/xla_client.py", f"{source_file_prefix}jaxlib/weakref_lru_cache.{pyext}", f"{source_file_prefix}jaxlib/weakref_lru_cache.pyi", - f"{source_file_prefix}jaxlib/xla_extension.{pyext}", + f"{source_file_prefix}jaxlib/_jax.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index d05d24295482..8c70f4bc7646 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -171,7 +171,7 @@ bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } } // namespace -NB_MODULE(xla_extension, m) { +NB_MODULE(_jax, m) { // Initialize ABSL logging because code within XLA uses it. 
#ifndef PLATFORM_GOOGLE InitializeAbslLogging(); diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 70460f1a4392..e9fb7e791574 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -118,6 +118,6 @@ pytype_library( srcs = ["xla_extension.py"], visibility = [":xla_python"], deps = [ - "//jaxlib:xla_extension", + "//jaxlib:_jax", ], ) diff --git a/jaxlib/xla/xla_extension.py b/jaxlib/xla/xla_extension.py index 5305b28f39dd..798919c01450 100644 --- a/jaxlib/xla/xla_extension.py +++ b/jaxlib/xla/xla_extension.py @@ -17,4 +17,4 @@ # ruff: noqa: F401 # ruff: noqa: F403 -from jaxlib.xla_extension import * # pylint: disable=wildcard-import +from jaxlib._jax import * # pylint: disable=wildcard-import diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 0cbd2b3f3b4d..9badaf355f0c 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -30,7 +30,7 @@ import ml_dtypes import numpy as np -from jaxlib import xla_extension as _xla +from jaxlib import _jax as _xla # Note this module does *not* depend on any Python protocol buffers. The XLA # Python bindings are currently packaged both as part of jaxlib and as part diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index c10556e83920..445bb2287f8a 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -21,40 +21,40 @@ from typing import Any, Union import numpy -from jaxlib import xla_extension as _xla -from jaxlib.xla_extension import ArrayImpl as ArrayImpl -from jaxlib.xla_extension import AutotuneCacheMode as AutotuneCacheMode -from jaxlib.xla_extension import Client as Client -from jaxlib.xla_extension import CompileOptions as CompileOptions -from jaxlib.xla_extension import Device as Device -from jaxlib.xla_extension import DeviceAssignment as DeviceAssignment -from jaxlib.xla_extension import DeviceList as DeviceList -from jaxlib.xla_extension import DeviceTopology as DeviceTopology -from jaxlib.xla_extension import DistributedRuntimeClient as DistributedRuntimeClient -from jaxlib.xla_extension import FftType as FftType -from jaxlib.xla_extension import Frame as Frame -from jaxlib.xla_extension import GSPMDSharding as GSPMDSharding -from jaxlib.xla_extension import HloSharding as HloSharding -from jaxlib.xla_extension import HostBufferSemantics as HostBufferSemantics -from jaxlib.xla_extension import ifrt_programs as ifrt_programs -from jaxlib.xla_extension import Layout as Layout -from jaxlib.xla_extension import LoadedExecutable as LoadedExecutable -from jaxlib.xla_extension import Memory as Memory -from jaxlib.xla_extension import NamedSharding as NamedSharding -from jaxlib.xla_extension import ops as ops -from jaxlib.xla_extension import OpSharding as OpSharding -from jaxlib.xla_extension import PjRtLayout as PjRtLayout -from jaxlib.xla_extension import PmapSharding as PmapSharding -from jaxlib.xla_extension import PrimitiveType as PrimitiveType -from jaxlib.xla_extension import ArrayCopySemantics as ArrayCopySemantics -from jaxlib.xla_extension import profiler as profiler -from jaxlib.xla_extension import Shape as Shape -from jaxlib.xla_extension import Sharding as Sharding -from jaxlib.xla_extension import SingleDeviceSharding as SingleDeviceSharding -from jaxlib.xla_extension import Traceback as Traceback -from jaxlib.xla_extension import XlaBuilder as XlaBuilder -from jaxlib.xla_extension import XlaComputation as XlaComputation -from jaxlib.xla_extension import XlaOp as XlaOp +from jaxlib import _jax as _xla +from jaxlib._jax import ArrayImpl as ArrayImpl +from jaxlib._jax import AutotuneCacheMode as 
AutotuneCacheMode +from jaxlib._jax import Client as Client +from jaxlib._jax import CompileOptions as CompileOptions +from jaxlib._jax import Device as Device +from jaxlib._jax import DeviceAssignment as DeviceAssignment +from jaxlib._jax import DeviceList as DeviceList +from jaxlib._jax import DeviceTopology as DeviceTopology +from jaxlib._jax import DistributedRuntimeClient as DistributedRuntimeClient +from jaxlib._jax import FftType as FftType +from jaxlib._jax import Frame as Frame +from jaxlib._jax import GSPMDSharding as GSPMDSharding +from jaxlib._jax import HloSharding as HloSharding +from jaxlib._jax import HostBufferSemantics as HostBufferSemantics +from jaxlib._jax import ifrt_programs as ifrt_programs +from jaxlib._jax import Layout as Layout +from jaxlib._jax import LoadedExecutable as LoadedExecutable +from jaxlib._jax import Memory as Memory +from jaxlib._jax import NamedSharding as NamedSharding +from jaxlib._jax import ops as ops +from jaxlib._jax import OpSharding as OpSharding +from jaxlib._jax import PjRtLayout as PjRtLayout +from jaxlib._jax import PmapSharding as PmapSharding +from jaxlib._jax import PrimitiveType as PrimitiveType +from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics +from jaxlib._jax import profiler as profiler +from jaxlib._jax import Shape as Shape +from jaxlib._jax import Sharding as Sharding +from jaxlib._jax import SingleDeviceSharding as SingleDeviceSharding +from jaxlib._jax import Traceback as Traceback +from jaxlib._jax import XlaBuilder as XlaBuilder +from jaxlib._jax import XlaComputation as XlaComputation +from jaxlib._jax import XlaOp as XlaOp _version: int diff --git a/pyproject.toml b/pyproject.toml index 8e2bf17edc69..03e4ec0c9ffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ module = [ "jaxlib.triton.*", "jaxlib.utils", "jaxlib.version", - "jaxlib.xla_extension.utils", + "jaxlib._jax.utils", "jraph.*", "libtpu.*", "matplotlib.*", diff --git a/tests/api_test.py b/tests/api_test.py index d80404f25a4e..daea53d0fb38 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -63,7 +63,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_extension +from jax._src.lib import _jax import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -1362,7 +1362,7 @@ def f(x): "exec_time_optimization_effort": 0.0, })(1.0) # doesn't crash. - with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"): + with self.assertRaisesRegex(_jax.XlaRuntimeError, "No such"): f_jit = jit( f, compiler_options={ @@ -1403,12 +1403,12 @@ def f(x): lowered = f_jit.lower(1.) 
self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "is not a valid bool value.", + _jax.XlaRuntimeError, "is not a valid bool value.", lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) @@ -1423,7 +1423,7 @@ def f(x): # We should still error on invalid options after some valid compiles with self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"): + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'"): jit(f, compiler_options={"invalid_key": "invalid_value"})(1.) def test_lower_compile_with_compiler_options_multiple(self): @@ -1448,7 +1448,7 @@ def f(x): # We should still error on invalid options after some valid compiles self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 2f4b7d511fbe..816e1cb81472 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -30,7 +30,7 @@ from jax._src import core from jax._src import test_util as jtu from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError -from jax._src.lib import xla_extension +from jax._src.lib import _jax import jax.numpy as jnp config.parse_flags_with_absl() @@ -1387,7 +1387,7 @@ def f(x): checkify.check(x > 0, "x needs to be positive") return x - with self.assertRaisesRegex(xla_extension.XlaRuntimeError, + with self.assertRaisesRegex(_jax.XlaRuntimeError, "x needs to be positive"): f(-1.) diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index f5dcff837838..b49eed8a5e62 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -106,7 +106,6 @@ } ], "source": [ - "from jaxlib import xla_extension\n", "import jax\n", "key = jax.random.PRNGKey(1701)\n", "arr = jax.random.normal(key, (1000,))\n", diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index ce9348b594b0..53fdac98504c 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -32,7 +32,7 @@ from jax._src import state from jax._src import test_util as jtu from jax._src.interpreters import partial_eval as pe -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax._src.state import utils as state_utils from jax._src.state import discharge as state_discharge @@ -1874,7 +1874,7 @@ def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] 
x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with self.assertRaises(xla_extension.XlaRuntimeError): + with self.assertRaises(_jax.XlaRuntimeError): self.pallas_call( kernel, out_shape=x, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3f10f7839a24..d17851b3236d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -61,7 +61,7 @@ from jax._src.interpreters import pxla from jax._src import xla_bridge from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -7979,12 +7979,12 @@ def test_op_sharding_tuple_shardings(self): def test_hlo_sharding_iota_tile_error(self): self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: `dims` should not be empty.', lambda: xc.HloSharding.iota_tile(()) ) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: Cannot reshape from', lambda: xc.HloSharding.iota_tile( (2, 2), @@ -7993,7 +7993,7 @@ def test_hlo_sharding_iota_tile_error(self): ), ) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: `reshape_dims` and `transpose_perm` should have the' ' same size', lambda: xc.HloSharding.iota_tile( @@ -8002,7 +8002,7 @@ def test_hlo_sharding_iota_tile_error(self): ), ) self.assertRaisesWithLiteralMatch( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: `subgroup_types`(3) should not have more dimensions ' 'than `dims`(2).', lambda: xc.HloSharding.iota_tile( diff --git a/tests/pmap_test.py b/tests/pmap_test.py index e22333609ab7..84136e48ecb7 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -49,7 +49,7 @@ from jax._src.internal_test_util import lax_test_util from jax._src.interpreters import pxla from jax._src.lax import parallel -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.util import safe_map, safe_zip config.parse_flags_with_absl() @@ -318,12 +318,12 @@ def test_jit_lower_compile_with_compiler_options_invalid(self): lowered = f.lower(x) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "is not a valid bool value.", + _jax.XlaRuntimeError, "is not a valid bool value.", lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) @@ -356,7 +356,7 @@ def test_jit_lower_compile_with_compiler_options_multiple(self): # We should still error on invalid options after some valid compiles self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py index 02cc81695d57..23a8fa7f42a1 100644 --- a/tests/unary_ops_accuracy_test.py +++ b/tests/unary_ops_accuracy_test.py @@ -22,7 +22,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lax import lax -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp @@ -373,7 +373,7 @@ def 
test_invalid_accuracy(self):
   )
   def test_low_tol(self, op, x, **kwargs):
     with self.assertRaisesRegex(
-        xla_extension.XlaRuntimeError, "impl_type.ok()"
+        _jax.XlaRuntimeError, "impl_type.ok()"
     ):
       op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0))

From 735718bfe75e08c172b94b58dccf3f1355b1c07f Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 23 Apr 2025 14:11:02 -0700
Subject: [PATCH 0768/1769] Cleanup: remove superfluous jax.numpy utility

---
 jax/_src/numpy/lax_numpy.py  | 2 +-
 jax/_src/numpy/reductions.py | 4 ++--
 jax/_src/numpy/util.py       | 9 ++-------
 3 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 641422fceef3..6e9cb3de3985 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -2596,7 +2596,7 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike
   a, b = util.promote_args_inexact("isclose", a, b)
   dtype = _dtype(a)
   if issubdtype(dtype, np.complexfloating):
-    dtype = util._complex_elem_type(dtype)
+    dtype = np.array(0, dtype).real.dtype
     rtol = lax.convert_element_type(rtol, dtype)
     atol = lax.convert_element_type(atol, dtype)
   out = lax.le(
diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py
index d708f3dd4c87..77d8662fd9a9 100644
--- a/jax/_src/numpy/reductions.py
+++ b/jax/_src/numpy/reductions.py
@@ -30,7 +30,7 @@
 from jax._src import deprecations
 from jax._src import dtypes
 from jax._src.numpy.util import (
-    _broadcast_to, check_arraylike, _complex_elem_type, ensure_arraylike,
+    _broadcast_to, check_arraylike, ensure_arraylike,
     promote_dtypes_inexact, promote_dtypes_numeric, _where)
 from jax._src.lax import lax as lax_internal
 from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
@@ -1160,7 +1160,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
     dtype = dtypes.to_inexact_dtype(a_dtype)
     computation_dtype = dtype
   else:
-    dtype = _complex_elem_type(a_dtype)
+    dtype = np.array(0, a_dtype).real.dtype
     computation_dtype = a_dtype
   return _upcast_f16(computation_dtype), np.dtype(dtype)
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index bcfb12673806..edfe569349b6 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -27,8 +27,8 @@
 from jax._src.lib import xla_client as xc
 from jax._src.sharding_impls import SingleDeviceSharding
 from jax._src.util import safe_zip, safe_map, set_module
-from jax._src.typing import (Array, ArrayLike, DimSize, DType, DTypeLike,
-                             Shape, SupportsNdim, SupportsShape, SupportsSize)
+from jax._src.typing import (
+    Array, ArrayLike, DimSize, Shape, SupportsNdim, SupportsShape, SupportsSize)
 from jax.sharding import Sharding
 import numpy as np
@@ -124,11 +124,6 @@ def promote_dtypes_complex(*args: ArrayLike) -> list[Array]:
           for x in args]


-def _complex_elem_type(dtype: DTypeLike) -> DType:
-  """Returns the float type of the real/imaginary parts of a complex dtype."""
-  return np.abs(np.zeros((), dtype)).dtype
-
-
 def _arraylike(x: ArrayLike) -> bool:
   return (isinstance(x, np.ndarray) or isinstance(x, Array) or
           hasattr(x, '__jax_array__') or np.isscalar(x))

From 75e4aa9cba7ee588b4fadbae0429a8d24554fbd4 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 23 Apr 2025 14:20:24 -0700
Subject: [PATCH 0769/1769] Skip running the host_offloading notebook

PiperOrigin-RevId: 750722897
---
 docs/conf.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/conf.py b/docs/conf.py index
87fef6337f29..addf0cf50676 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -143,6 +143,7 @@ def _do_not_evaluate_in_jax( 'autodidax2_part1.md', 'sharded-computation.md', 'ffi.ipynb', + 'notebooks/host-offloading.ipynb', ] # The name of the Pygments (syntax highlighting) style to use. @@ -234,7 +235,8 @@ def _do_not_evaluate_in_jax( 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', - 'distributed_data_loading.*' + 'distributed_data_loading.*', + 'notebooks/host-offloading.*', ] # -- Options for HTMLHelp output --------------------------------------------- From 9fef4c0077b0d5b7cbe3f281594f632ced8616cd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 23 Apr 2025 14:47:40 -0700 Subject: [PATCH 0770/1769] internal change PiperOrigin-RevId: 750732175 --- tests/array_test.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/array_test.py b/tests/array_test.py index 4e01eb33841e..230c4cda336a 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1440,13 +1440,6 @@ def test_memory_kind_with_abstract_mesh(self): ValueError, 'Got invalid memory kind'): NamedSharding(abstract_mesh, P(), memory_kind='weird_device') - def test_pos_gspmd_sharding_warnings(self): - with self.assertWarns(DeprecationWarning): - jax.sharding.PositionalSharding(jax.devices()) - - with self.assertWarns(DeprecationWarning): - jax.sharding.GSPMDSharding.get_replicated(jax.devices()) - @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): From 7ec1a3dcc86b066b314c3198f94f03d31daae91b Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Wed, 23 Apr 2025 15:07:59 -0700 Subject: [PATCH 0771/1769] Set is_custom field for custom_partitioning_sharding_rule so that rules aren't removed in sharding propagation. 
PiperOrigin-RevId: 750739094 --- jax/_src/custom_partitioning_sharding_rule.py | 3 ++- tests/custom_partitioning_sharding_rule_test.py | 14 +++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py index b5563634352c..d17399beda5b 100644 --- a/jax/_src/custom_partitioning_sharding_rule.py +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -473,4 +473,5 @@ def build_dim_mapping_for_compound_factors(i, j, factors): return sdy.OpShardingRuleAttr.get( factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], operand_mappings=tensor_mappings[0:len(operand_types)], - result_mappings=tensor_mappings[len(operand_types):]) + result_mappings=tensor_mappings[len(operand_types):], + is_custom=True) diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py index f22721910408..d7e93ddec5b2 100644 --- a/tests/custom_partitioning_sharding_rule_test.py +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -383,7 +383,7 @@ def test_conversion_compound_then_individual(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>") + "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}, custom>") def test_conversion_elementwise_rule_scalar_instance(self): opnd0 = self.create_tensor_value(()) @@ -399,7 +399,7 @@ def test_conversion_elementwise_rule_scalar_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([], [])->([])>") + "#sdy.op_sharding_rule<([], [])->([]), custom>") def test_conversion_elementwise_rule_2D_instance(self): opnd0 = self.create_tensor_value((16, 32)) @@ -415,7 +415,7 @@ def test_conversion_elementwise_rule_2D_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>") + "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}, custom>") def test_conversion_vector_scalar_add_2D_instance(self): opnd0 = self.create_tensor_value((16, 32)) @@ -431,7 +431,7 @@ def test_conversion_vector_scalar_add_2D_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>") + "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}, custom>") def test_conversion_reshape_rule(self): opnd0 = self.create_tensor_value((2, 4)) @@ -446,7 +446,7 @@ def test_conversion_reshape_rule(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>") + "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}, custom>") def test_conversion_contracting_dim_matmul(self): opnd0 = self.create_tensor_value((16, 32)) @@ -462,7 +462,7 @@ def test_conversion_contracting_dim_matmul(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>") + "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}, custom>") def test_conversion_multiple_batching_groups(self): @@ -479,7 +479,7 @@ def test_conversion_multiple_batching_groups(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>") + "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}, custom>") if __name__ == "__main__": From 
42937146dcc5ac9ed177b14fa3cf63f43b360f4f Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 23 Apr 2025 15:16:48 -0700
Subject: [PATCH 0772/1769] Rename config `_check_rep` to `_check_vma` and the
 kwarg in shard_map_p from `check_rep` to `check_vma`.

PiperOrigin-RevId: 750741689
---
 jax/_src/checkify.py              |   8 +-
 jax/_src/config.py                |   6 +-
 jax/_src/core.py                  |   6 +-
 jax/_src/interpreters/batching.py |   6 +-
 jax/_src/lax/parallel.py          |  12 +--
 jax/_src/shard_map.py             | 164 +++++++++++++++---------------
 6 files changed, 101 insertions(+), 101 deletions(-)

diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 298d44f023de..a645f6c71249 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -969,15 +969,15 @@ def shard_map_error_check(
   new_vals_in = [*err_vals, *vals_in]
   in_avals = list(map(core.get_aval, new_vals_in))
   auto = kwargs.get('auto')
-  check_rep = kwargs.get('check_rep')
+  check_vma = kwargs.get('check_vma')
   for i, v in enumerate(in_avals):
     if not (sharder := core.shard_aval_handlers.get(type(v))):
       raise ValueError(f'Unsupported aval type: {type(v)}')
-    in_avals[i] = sharder(mesh, auto, check_rep, new_in_names[i], v)
+    in_avals[i] = sharder(mesh, auto, check_vma, new_in_names[i], v)

   with (jshmap._extend_axis_env(mesh, auto),
         mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, auto)),
-        config._check_rep(check_rep)):
+        config._check_vma(check_vma)):
     # jaxpr to checked_jaxpr
     checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
         pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
@@ -990,7 +990,7 @@ def expand_errors_leading_dim(*xs):
     errs = [lax.expand_dims(e, [0]) for e in errs]
     return *errs, *outs

-  with core.extend_axis_env_nd(mesh.shape.items()), config._check_rep(check_rep):
+  with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma):
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(expand_errors_leading_dim,
                      debug_info=checked_jaxpr.jaxpr.debug_info),
diff --git a/jax/_src/config.py b/jax/_src/config.py
index 83c5654b87d0..2abb457ef115 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -239,7 +239,7 @@ def trace_context():
       disable_jit.value,
       debug_key_reuse.value,
       jax_xla_profile_version.value,
-      _check_rep.value,
+      _check_vma.value,
       # Technically this affects jaxpr->stablehlo lowering, not tracing.
       hlo_source_file_canonicalization_regex.value,
       pgle_profiling_runs.value,
@@ -1079,8 +1079,8 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
     include_in_jit_key=True)

 # TODO make it so people don't use this, this is internal...
-_check_rep = bool_state( - name='check_rep', +_check_vma = bool_state( + name='check_vma', default=False, help='internal implementation detail of shard_map, DO NOT USE', include_in_jit_key=True) diff --git a/jax/_src/core.py b/jax/_src/core.py index 0183a9942524..b6f12745ea92 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2000,7 +2000,7 @@ def pvary(x, axis_name): pvary_p.def_impl(lambda *args, axes, axis_index_groups: args) def _pvary_abstract_eval(*args, axes, axis_index_groups): - if not config._check_rep.value: + if not config._check_vma.value: return args assert isinstance(axes, tuple) arg_vma = [a.vma for a in args] @@ -2019,7 +2019,7 @@ def _pvary_abstract_eval(*args, axes, axis_index_groups): def standard_insert_pvary(*args): - if not config._check_rep.value: + if not config._check_vma.value: return args if not args: return args @@ -2030,7 +2030,7 @@ def standard_insert_pvary(*args): if out_vma - src else arg for arg, src in zip(args, in_vma)] def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: - if not config._check_rep.value: + if not config._check_vma.value: return frozenset() avals = tuple(a for a in avals if a is not abstract_token) if not avals: diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 200189502db6..1c6e00861448 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -407,7 +407,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, def aval(self): aval = core.get_aval(self.val) if self._trace.axis_data.spmd_name is not None: - if config._check_rep.value: + if config._check_vma.value: aval = aval.update( vma=aval.vma - frozenset(self._trace.axis_data.spmd_name)) if self.batch_dim is not_mapped: @@ -776,7 +776,7 @@ def _batch_jaxpr2( aval = core.unmapped_aval( axis_data.size, b, aval, axis_data.explicit_mesh_axis) if axis_data.spmd_name is not None: - if config._check_rep.value: + if config._check_vma.value: aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name)) # type: ignore avals_in2.append(aval) jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) @@ -1113,7 +1113,7 @@ def broadcast(x, sz, axis, mesh_axis=None): # out how to ensure jaxpr arguments always have the context mesh. 
with mesh_lib.use_abstract_mesh(sharding.mesh): x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) - if config._check_rep.value: + if config._check_vma.value: # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026 spmd_names = core.get_axis_env().spmd_axis_names if len(spmd_names) > 1: diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 89a5799d9541..0b3e0c3acc59 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -145,7 +145,7 @@ def pos_reduce(x): size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) else: - if config._check_rep.value: + if config._check_vma.value: out_flat = bind_psum_invariant( leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) else: @@ -869,7 +869,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups): - if not config._check_rep.value: + if not config._check_vma.value: return psum_p.abstract_eval( *args, axes=axes, axis_index_groups=axis_index_groups) @@ -904,7 +904,7 @@ def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups): # TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups): - if not config._check_rep.value: + if not config._check_vma.value: return _allreduce_effectful_abstract_eval( *args, axes=axes, axis_index_groups=axis_index_groups) return _psum_invariant_abstract_eval( @@ -1457,7 +1457,7 @@ def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') def insert_collective_pvary(axis_name, x): - if not config._check_rep.value: + if not config._check_vma.value: return x axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name @@ -1589,7 +1589,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, def collective_vma_rule(prim_name, axis_name, x_aval): - if not config._check_rep.value: + if not config._check_vma.value: return frozenset() axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name if any(a not in x_aval.vma for a in axis_name): @@ -1957,7 +1957,7 @@ def _axis_index_effectful_abstract_eval(*, axis_name): mesh = get_abstract_mesh() sharding = NamedSharding(mesh, P()) vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset()) - if config._check_rep.value else frozenset()) + if config._check_vma.value else frozenset()) return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 469087dc19f3..3496f77078b9 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -216,7 +216,7 @@ def out_names_thunk(): try: out_flat = shard_map_p.bind( fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_vma, auto=auto) + out_names_thunk=out_names_thunk, check_vma=check_vma, auto=auto) except _SpecError as e: fails, = e.args if not callable(out_specs): @@ -460,7 +460,7 @@ def _inout_rep_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, "Check if these output values are meant to be replicated over those " "mesh axes. 
If not, consider revising the corresponding out_specs " "entries. If so, consider disabling the check by passing the " - "check_rep=False argument to shard_map.") + "check_vma=False argument to `jax.shard_map`.") return msg def _unmentioned(mesh: Mesh | AbstractMesh, names: AxisNames) -> list[AxisName]: @@ -559,23 +559,23 @@ def _shard_map_staging( in_tracers: Sequence[Any], *, mesh: Mesh, in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, + check_vma: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, mesh, auto, check_rep), in_names, + in_avals_ = map(partial(_shard_aval, mesh, auto, check_vma), in_names, in_avals) manual_mesh = _as_manual_mesh(mesh, auto) with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), - config._check_rep(check_rep)): + config._check_vma(check_vma)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) - if check_rep: + if check_vma: out_vma = [v.aval.vma for v in jaxpr.outvars] _check_reps(mesh, auto, out_names_thunk(), out_vma) out_avals = map(_check_shapedarray, out_avals_) - out_avals = [_check_shapedarray(_unshard_aval(mesh, check_rep, names, aval)) + out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] source_info = source_info_util.current() out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] @@ -584,11 +584,11 @@ def _shard_map_staging( outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), - config._check_rep(check_rep)): + config._check_vma(check_vma)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_rep=check_rep, auto=auto) + check_vma=check_vma, auto=auto) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, effs, source_info) @@ -602,21 +602,21 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval -def _shard_aval(mesh: Mesh, auto, check_rep, names: AxisNames, +def _shard_aval(mesh: Mesh, auto, check_vma, names: AxisNames, aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.shard_aval_handlers: - return core.shard_aval_handlers[type(aval)](mesh, auto, check_rep, names, + return core.shard_aval_handlers[type(aval)](mesh, auto, check_vma, names, aval) raise NotImplementedError(f"Unsupported aval type: {type(aval)}") -def _unshard_aval(mesh: Mesh, check_rep, names: AxisNames, +def _unshard_aval(mesh: Mesh, check_vma, names: AxisNames, aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.unshard_aval_handlers: - return core.unshard_aval_handlers[type(aval)](mesh, check_rep, names, aval) + return core.unshard_aval_handlers[type(aval)](mesh, check_vma, names, aval) else: raise NotImplementedError(f"Unsupported aval type: {type(aval)}") -def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_rep, names: AxisNames, +def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_vma, names: AxisNames, aval: core.AbstractValue) -> core.AbstractValue: assert 
isinstance(aval, core.ShapedArray) new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) @@ -624,11 +624,11 @@ def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_rep, names: AxisNames manual_mesh = _as_manual_mesh(mesh, auto) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) vma = (frozenset({n for ns in names.values() for n in ns}) - if check_rep else frozenset()) + if check_vma else frozenset()) return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array -def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, +def _unshard_shaped_array(mesh: Mesh, check_vma, names: AxisNames, aval: core.AbstractValue,) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) @@ -656,7 +656,7 @@ def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, new_sharding = NamedSharding(new_mesh, out_spec) manual_axes = set(new_mesh.manual_axes) vma = (frozenset(v for v in aval.vma if v in manual_axes) - if check_rep else frozenset()) + if check_vma else frozenset()) return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array @@ -665,23 +665,23 @@ def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, RepType = Any def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_rep, auto): + check_vma, auto): # TODO(mattjj,parkers): check auto for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): if not core.typecompat(v.aval, _shard_aval( - mesh, auto, check_rep, in_name, x.aval)): + mesh, auto, check_vma, in_name, x.aval)): raise core.JaxprTypeError("shard_map argument avals not compatible with " "jaxpr binder avals and in_names") - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + with _extend_axis_env(mesh, auto), config._check_vma(check_vma): core.check_jaxpr(jaxpr) - if check_rep: + if check_vma: out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] for rep, dst in zip(out_rep, out_names): if not _valid_repeats(mesh, auto, rep, dst): raise core.JaxprTypeError( "shard_map can't prove output is sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] - out_avals = map(partial(_unshard_aval, mesh, check_rep), out_names, + out_avals = map(partial(_unshard_aval, mesh, check_vma), out_names, out_avals_sharded) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) return out_avals, effs @@ -712,7 +712,7 @@ def _shardy_shard_map_sharding( def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep): + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_vma): in_avals_ = [v.aval for v in jaxpr.invars] if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): # Nested `ManualComputationOp`s cannot refer to axes that are already @@ -731,7 +731,7 @@ def _shard_map_lowering_shardy( if a in shardy_manual_axes] if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. 
- with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + with _extend_axis_env(mesh, auto), config._check_vma(check_vma): out_nodes, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, dim_var_values=ctx.dim_var_values) @@ -751,7 +751,7 @@ def _shard_map_lowering_shardy( block = ir.Block.create_at_start( manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) with (ir.InsertionPoint(block), _extend_axis_env(mesh, auto), - config._check_rep(check_rep)): + config._check_vma(check_vma)): out_nodes_, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, dim_var_values=ctx.dim_var_values) @@ -761,10 +761,10 @@ def _shard_map_lowering_shardy( def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_rep, auto): + check_vma, auto): if config.use_shardy_partitioner.value: return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep) + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_vma) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] @@ -773,7 +773,7 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, manual_axes = frozenset(mesh.axis_names) - auto new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + with _extend_axis_env(mesh, auto), config._check_vma(check_vma): out_nodes_, tokens_out = mlir.call_lowering( "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, @@ -861,34 +861,34 @@ def _vma_to_rep(mesh, auto, vma): return frozenset((set(mesh.axis_names) - auto) - vma) def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_rep, auto): + check_vma, auto): if auto: raise NotImplementedError del prim if isinstance(mesh, AbstractMesh): mesh = get_mesh_from_args(args, mesh) cur_mesh = get_abstract_mesh() - args = map(partial(_unmatch_spec, mesh, check_rep, context_mesh=cur_mesh), + args = map(partial(_unmatch_spec, mesh, check_vma, context_mesh=cur_mesh), in_names, args) in_vma = map(_names_to_vma, in_names) - outs, out_vma = _run_shmap(fun, mesh, auto, args, in_vma, check_rep, cur_mesh) + outs, out_vma = _run_shmap(fun, mesh, auto, args, in_vma, check_vma, cur_mesh) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types - if check_rep: + if check_vma: _check_reps(mesh, auto, out_names_thunk(), out_vma) src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) else: src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) dst_pspecs = map(_names_to_pspec, out_names_thunk()) - return map(partial(_match_spec, mesh, check_rep), src_pspecs, dst_pspecs, + return map(partial(_match_spec, mesh, check_vma), src_pspecs, dst_pspecs, outs) core.EvalTrace.process_shard_map = _shard_map_impl -def _run_shmap(f, mesh, auto, args, vmas, check_rep, context_mesh): - trace = ShardMapTrace(mesh, auto, check_rep, context_mesh) +def _run_shmap(f, mesh, auto, args, vmas, check_vma, context_mesh): + trace = ShardMapTrace(mesh, auto, check_vma, context_mesh) in_tracers = map(partial(ShardMapTracer, trace), vmas, args) manual_mesh = _as_manual_mesh(mesh, auto) with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), - 
use_abstract_mesh(manual_mesh), config._check_rep(check_rep)): + use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): ans = f.call_wrapped(*in_tracers) outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) return outs, out_vma @@ -898,23 +898,23 @@ def _names_to_pspec(names: AxisNames) -> PartitionSpec: unpack = lambda t: t[0] if t is not None and len(t) == 1 else t return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) -def _unmatch_spec(mesh: Mesh, check_rep, src: AxisNames, x: JaxType, +def _unmatch_spec(mesh: Mesh, check_vma, src: AxisNames, x: JaxType, context_mesh) -> JaxType: with (core.eval_context(), jax.disable_jit(False), use_abstract_mesh(context_mesh)): - return jax.jit(HashablePartial(_unmatch, mesh, check_rep, + return jax.jit(HashablePartial(_unmatch, mesh, check_vma, tuple(src.items())))(x) -def _unmatch(mesh, check_rep, src_tup, x): +def _unmatch(mesh, check_vma, src_tup, x): src = _names_to_pspec(dict(src_tup)) - if check_rep: + if check_vma: used_axes = {i for _, ns in src_tup for i in ns} dst = P(tuple(i for i in mesh.axis_names if i in used_axes)) else: dst = P(mesh.axis_names) - check_rep = False + check_vma = False return shard_map(_add_singleton, mesh=mesh, in_specs=(src,), out_specs=dst, - check_vma=check_rep)(x) + check_vma=check_vma)(x) def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] ) -> None: @@ -936,15 +936,15 @@ def _check_reps(mesh, auto, names, vmas): class _RepError(Exception): pass -def _match_spec(mesh: Mesh, check_rep, src_pspec: PartitionSpec, +def _match_spec(mesh: Mesh, check_vma, src_pspec: PartitionSpec, dst_pspec: PartitionSpec, x: JaxType) -> JaxType: - fn = HashablePartial(_match, mesh, check_rep, src_pspec, dst_pspec) + fn = HashablePartial(_match, mesh, check_vma, src_pspec, dst_pspec) with core.eval_context(), jax.disable_jit(False): return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x) -def _match(mesh, check_rep, src_pspec, dst_pspec, x): +def _match(mesh, check_vma, src_pspec, dst_pspec, x): return shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec, - out_specs=dst_pspec, check_vma=check_rep)(x) + out_specs=dst_pspec, check_vma=check_vma)(x) def _rem_singleton(x): return jnp.squeeze(x, axis=0) def _add_singleton(x): return jnp.expand_dims(x, axis=0) @@ -1065,7 +1065,7 @@ def aval(self): new_sharding = NamedSharding( _as_manual_mesh(self._trace.mesh, self._trace.auto), out.sharding.spec) # pytype: disable=attribute-error - vma = self.vma if config._check_rep.value else frozenset() + vma = self.vma if config._check_vma.value else frozenset() return out.update(sharding=new_sharding, vma=vma) def to_concrete_value(self): @@ -1087,13 +1087,13 @@ def __str__(self) -> str: for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) __repr__ = __str__ # for debuggers, like `p x` -def _prim_applier(prim, check_rep, params_tup, mesh, in_specs, out_specs, *args): +def _prim_applier(prim, check_vma, params_tup, mesh, in_specs, out_specs, *args): def apply(*args): outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) return tree_map(_add_singleton, outs) out_specs = list(out_specs) if type(out_specs) is tuple else out_specs return shard_map(apply, mesh=mesh, in_specs=in_specs, out_specs=out_specs, - check_vma=check_rep)(*args) + check_vma=check_vma)(*args) eager_rules: dict[core.Primitive, Callable] = {} @@ -1133,7 +1133,7 @@ def _shard_map_batch( in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, in_names: tuple[AxisNames, ...], out_names_thunk: 
Callable[[], tuple[AxisNames, ...]], - check_rep: bool, + check_vma: bool, auto: frozenset) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): @@ -1158,7 +1158,7 @@ def new_out_names_thunk(): return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) new_params = dict(mesh=mesh, in_names=new_in_names, - out_names_thunk=new_out_names_thunk, check_rep=check_rep, + out_names_thunk=new_out_names_thunk, check_vma=check_vma, auto=auto) with core.set_current_trace(trace.parent_trace): out_vals = prim.bind(fun, *in_vals, **new_params) @@ -1182,7 +1182,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names): # Autodiff def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, auto): + out_names_thunk, check_vma, auto): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] @@ -1196,7 +1196,7 @@ def new_out_names_thunk(): out_ax = out_names_thunk() return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, + out_names_thunk=new_out_names_thunk, check_vma=check_vma, auto=auto) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) @@ -1208,12 +1208,12 @@ def new_out_names_thunk(): def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, auto): + out_names_thunk, check_vma, auto): tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - in_avals_sharded = map(partial(_shard_aval, mesh, auto, check_rep), + in_avals_sharded = map(partial(_shard_aval, mesh, auto, check_vma), unk_in_names, in_avals) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) f = _promote_scalar_residuals(f) @@ -1225,7 +1225,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, def known_out_names(): _, _, out_knowns, res_avals, _, _ = aux() _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) - if check_rep: + if check_vma: res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} for a in res_avals] else: @@ -1233,7 +1233,7 @@ def known_out_names(): return (*out_known_names, *res_names) known_params = dict(mesh=mesh, in_names=(*known_in_names,), - out_names_thunk=known_out_names, check_rep=check_rep, + out_names_thunk=known_out_names, check_vma=check_vma, auto=auto) out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) @@ -1253,7 +1253,7 @@ def known_out_names(): elif f2 is not None: res_names.append(known_out_names_[f2]) else: - if check_rep: + if check_vma: res_vma = next(res_avals_iter).vma res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) else: @@ -1265,8 +1265,8 @@ def known_out_names(): out_avals_sharded = [v.aval for v in jaxpr.outvars] unk_params = dict(mesh=mesh, in_names=unk_in_names, out_names=unk_out_names, jaxpr=jaxpr, - check_rep=check_rep, auto=auto) - out_avals = map(partial(_unshard_aval, mesh, check_rep), unk_out_names, + check_vma=check_vma, 
auto=auto) + out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_names, out_avals_sharded) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in out_avals] @@ -1280,7 +1280,7 @@ def known_out_names(): def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, auto): + out_names_thunk, check_vma, auto): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) @@ -1291,7 +1291,7 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, def fwd_out_names_thunk(): res_avals, _, _, _, _, _ = linearize_outs_thunk() out_names = out_names_thunk() - if check_rep: + if check_vma: res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} for a in res_avals] else: @@ -1299,7 +1299,7 @@ def fwd_out_names_thunk(): return (*res_names, *out_names) fwd_params = dict( mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, auto=auto) + out_names_thunk=fwd_out_names_thunk, check_vma=check_vma, auto=auto) all_fwd_results = shard_map_p.bind_with_trace( trace.parent_trace, (f_primal, *primals), fwd_params) res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() @@ -1311,7 +1311,7 @@ def fwd_out_names_thunk(): for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] with (_extend_axis_env(mesh, auto), use_abstract_mesh(_as_manual_mesh(mesh, auto)), - config._check_rep(check_rep)): + config._check_vma(check_vma)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() res_avals_iter = iter(res_avals) @@ -1322,7 +1322,7 @@ def fwd_out_names_thunk(): elif f2 is not None: res_names.append(out_names[f2]) else: - if check_rep: + if check_vma: res_vma = next(res_avals_iter).vma res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) else: @@ -1335,7 +1335,7 @@ def tangent_out_names_thunk(): return tangent_out_names tangent_params = dict( mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_rep=check_rep, auto=auto) + check_vma=check_vma, auto=auto) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here def f_tangent(*args): @@ -1397,16 +1397,16 @@ def _unmentioned2(mesh: Mesh, names: AxisNames, def _shard_map_transpose(out_cts, *args, jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_rep, auto): + check_vma, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ - ad.Zero(_shard_aval(mesh, auto, check_rep, ns, x.aval)) - if type(x) is ad.Zero else x if check_rep or dtypes.dtype(x) == dtypes.float0 + ad.Zero(_shard_aval(mesh, auto, check_vma, ns, x.aval)) + if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0 else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) for ns, x in zip(out_names, out_cts) ] args = tuple(x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, auto, check_rep, ns, x.aval)) + ad.UndefinedPrimal(_shard_aval(mesh, auto, check_vma, ns, x.aval)) for ns, x in zip(in_names, args)) all_args, in_tree = tree_flatten((out_cts, args)) @@ -1421,8 +1421,8 @@ def fun_trans_callable(out_cts, args): jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts )[len(res_reshaped):] _, in_ct_names = partition_list(in_undef, in_names) - in_cts = [ad.Zero(_unshard_aval(mesh, check_rep, ns, x.aval)) - 
if type(x) is ad.Zero else x if check_rep + in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, ns, x.aval)) + if type(x) is ad.Zero else x if check_vma else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) for ns, x in zip(in_ct_names, in_cts)] res_zeros = [ad_util.zero_from_primal(r) for r in res] @@ -1442,7 +1442,7 @@ def new_out_names_thunk(): try: out_flat = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, + out_names_thunk=new_out_names_thunk, check_vma=check_vma, auto=auto) except (FloatingPointError, ZeroDivisionError) as e: print("Invalid nan value encountered in the backward pass of a shard_map " @@ -1453,7 +1453,7 @@ def new_out_names_thunk(): with jax.disable_jit(True): _ = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, + out_names_thunk=new_out_names_thunk, check_vma=check_vma, auto=auto) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None @@ -1470,8 +1470,8 @@ def _partial_eval_jaxpr_custom_rule( ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], list[core.Var]]: jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] - check_rep, auto = eqn.params['check_rep'], eqn.params['auto'] - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + check_vma, auto = eqn.params['check_vma'], eqn.params['auto'] + with _extend_axis_env(mesh, auto), config._check_vma(check_vma): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) num_out_primals = len(jaxpr_known.outvars) - num_res @@ -1483,7 +1483,7 @@ def _partial_eval_jaxpr_custom_rule( mesh = eqn.params['mesh'] with (_extend_axis_env(mesh, auto), use_abstract_mesh(_as_manual_mesh(mesh, auto)), - config._check_rep(check_rep)): + config._check_vma(check_vma)): jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) @@ -1497,10 +1497,10 @@ def _partial_eval_jaxpr_custom_rule( for var, w in zip(jaxpr_staged.invars[:num_res], which): if w: rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore - if check_rep else {0: _all_newly_manual_mesh_names(mesh, auto)}) - residuals.append(newvar(_unshard_aval(mesh, check_rep, rn, var.aval))) + if check_vma else {0: _all_newly_manual_mesh_names(mesh, auto)}) + residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval))) staged_in_res_names.append(rn) - if check_rep: + if check_vma: out_res_names_known = [ {0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} for var, o in zip(res_vars, out_fwd) if o is None @@ -1611,8 +1611,8 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] auto = eqn.params["auto"] - check_rep = eqn.params["check_rep"] - with _extend_axis_env(mesh, auto), config._check_rep(check_rep): + check_vma = eqn.params["check_vma"] + with _extend_axis_env(mesh, auto), config._check_vma(check_vma): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: return used_inputs, None From 987fc45ccfd5e657f1cd379e9fd2dd6586941cf0 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 23 Apr 
2025 15:46:56 -0700 Subject: [PATCH 0773/1769] [pallas] `pl.pallas_call` no longer allows `compiler_params=` to be a param->value dict The parameters must be specified via a dataclass or a mapping from a backend to the corresponding dataclass. PiperOrigin-RevId: 750750391 --- jax/_src/pallas/core.py | 6 +- jax/_src/pallas/mosaic/core.py | 2 +- jax/_src/pallas/mosaic/interpret.py | 23 +++-- jax/_src/pallas/mosaic/lowering.py | 6 +- .../pallas/mosaic/pallas_call_registration.py | 42 ++++---- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/core.py | 2 +- jax/_src/pallas/mosaic_gpu/lowering.py | 46 ++++----- .../mosaic_gpu/pallas_call_registration.py | 19 ++-- jax/_src/pallas/pallas_call.py | 95 +++++++++++-------- jax/_src/pallas/triton/BUILD | 3 +- jax/_src/pallas/triton/core.py | 2 +- .../pallas/triton/pallas_call_registration.py | 37 ++++---- jax/_src/tpu_custom_call.py | 12 +-- tests/pallas/pallas_test.py | 44 --------- 15 files changed, 159 insertions(+), 181 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 9dda95d0c7f0..04390a03d4d6 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -24,7 +24,7 @@ import functools import itertools import threading -from typing import Any, ClassVar, Hashable, Protocol, TypeAlias, Union, runtime_checkable +from typing import Any, ClassVar, Hashable, Literal, Protocol, TypeAlias, Union, runtime_checkable import jax from jax._src import api_util @@ -117,10 +117,12 @@ class BarrierSemaphore(AbstractSemaphoreTy): name = "barrier_semaphore" type = barrier_semaphore +Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] + @runtime_checkable class CompilerParams(Protocol): """Base class for compiler parameters.""" - PLATFORM: ClassVar[str] + BACKEND: ClassVar[Backend] # Subclasses must be dataclasses. __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 3e5403d889c8..18ad3029398e 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -87,7 +87,7 @@ class TPUCompilerParams(pallas_core.CompilerParams): serialization_format: The serialization format for the kernel body. disable_bounds_checks: Disable bounds checks in the kernel. 
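To make the new contract concrete, a hedged construction sketch (the field values are purely illustrative and not part of this patch):

```python
# Minimal sketch of the dataclass form that replaces the old
# {"mosaic": {...}} dict. TPUCompilerParams is assumed to be exported from
# jax.experimental.pallas.tpu, as the pallas_call docstring in this patch
# indicates.
from jax.experimental.pallas import tpu as pltpu

params = pltpu.TPUCompilerParams(
    dimension_semantics=("parallel", "arbitrary"),  # one entry per grid axis
    vmem_limit_bytes=64 * 1024 * 1024,
)
```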
""" - PLATFORM: ClassVar[str] = "mosaic" + BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu" dimension_semantics: ( Sequence[Literal["parallel", "arbitrary"] | GridDimensionSemantics] | None ) = None diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 95dec0cd937e..448690785986 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -20,7 +20,7 @@ import itertools import math import threading -from typing import Any, Callable,Literal +from typing import Any, Callable, Literal, cast import jax from jax import lax @@ -35,6 +35,7 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src import pjit +from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives @@ -1292,7 +1293,7 @@ def f(*args, jaxpr): get_barrier_semaphore, jax.ShapeDtypeStruct((), jnp.int16), device_id, - compiler_params['mosaic']['collective_id'], + _get_mosaic_params(compiler_params).collective_id, ordered=True) elif prim is primitives.semaphore_signal_p: @@ -1383,16 +1384,22 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) + +def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) -> tpu_core.TPUCompilerParams: + try: + return cast(tpu_core.TPUCompilerParams, compiler_params['mosaic_tpu']) + except KeyError: + return tpu_core.TPUCompilerParams() + + def _get_parallel_dim_semantics( compiler_params: dict[str, Any], grid: tuple[int, ...] ) -> tuple[bool, ...]: """Returns a tuple of booleans indicating whether the corresponding dimension in `grid` is parallel.""" - dimension_semantics = compiler_params.get('mosaic', {}).get( - 'dimension_semantics', None - ) - if dimension_semantics is None: + mosaic_params = _get_mosaic_params(compiler_params) + if mosaic_params.dimension_semantics is None: return (False,) * len(grid) - return tuple(ds == 'parallel' for ds in dimension_semantics) + return tuple(ds == 'parallel' for ds in mosaic_params.dimension_semantics) _GridPointCoordinatesPerDim = tuple[Array, ...] @@ -1666,7 +1673,7 @@ def interpret_pallas_call( var.aval.shape, var.aval.dtype, interpret_params), ordered=True)) - if compiler_params.get('mosaic', {}).get('collective_id', None) is None: + if _get_mosaic_params(compiler_params).collective_id is None: # The kernel doesn't specify its own barrier semaphore, so we do a global # barrier before running the first iteration of the kernel. callback.io_callback(_barrier, (), device_id, ordered=True) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 096aca115072..da230c89d892 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -401,7 +401,9 @@ def __init__( self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, - dimension_semantics: tuple[str | tpu_core.GridDimensionSemantics, ...] | None, + dimension_semantics: ( + Sequence[str | tpu_core.GridDimensionSemantics, ...] | None + ), mesh: mesh_lib.Mesh | None, dynamic_shape_replacement_fn: Callable[ [tuple[jax.DimSize, ...]], tuple[int, ...] @@ -654,7 +656,7 @@ def lower_jaxpr_to_module( jaxpr: jax_core.Jaxpr, *, dimension_semantics: ( - tuple[str | tpu_core.GridDimensionSemantics, None, ...] | None + Sequence[str | tpu_core.GridDimensionSemantics, None, ...] 
| None ), mesh: mesh_lib.Mesh | None = None, for_verification: bool = False, diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 66be2be76113..c1d1a8029c5f 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -18,7 +18,7 @@ import os import tempfile -from typing import Any +from typing import cast import jax from jax import dtypes @@ -28,7 +28,6 @@ from jax._src import tpu_custom_call from jax._src.interpreters import mlir from jax._src.lib.mlir import ir -from jax._src.pallas import core from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering @@ -72,7 +71,7 @@ def _get_memory_space_from_aval( ) -> tpu_custom_call.MemorySpace | None: if not isinstance(out_aval, jax_core.ShapedArray): raise ValueError('Memory spaces not defined for non-ShapedArrays') - if not isinstance(out_aval, core.ShapedArrayWithMemorySpace): + if not isinstance(out_aval, pallas_core.ShapedArrayWithMemorySpace): # If we are passed a regular old ShapedArray, we don't constrain the # memory space return None @@ -97,23 +96,24 @@ def _get_memory_spaces_from_avals( ) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None: output_memory_spaces = None if any( - isinstance(out_aval, core.ShapedArrayWithMemorySpace) + isinstance(out_aval, pallas_core.ShapedArrayWithMemorySpace) for out_aval in out_avals ): output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals)) return output_memory_spaces + def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, - grid_mapping: core.GridMapping, + grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, - compiler_params: dict[str, Any], - cost_estimate: core.CostEstimate | None, + compiler_params: dict[str, pallas_core.CompilerParams], + cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): """Lowers a pallas_call to a Mosaic TPU custom call.""" @@ -123,10 +123,13 @@ def pallas_call_tpu_lowering_rule( if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") print(jaxpr) - if "mosaic" in compiler_params: - mosaic_params = compiler_params["mosaic"] + + if "mosaic_tpu" in compiler_params: + mosaic_params = cast( + tpu_core.TPUCompilerParams, compiler_params["mosaic_tpu"] + ) else: - mosaic_params = {} + mosaic_params = tpu_core.TPUCompilerParams() jax_mesh = None axis_context = ctx.module_context.axis_context @@ -142,13 +145,12 @@ def lower_module(for_verification: bool): if for_verification or tpu_core.runtime_assert_enabled(): mlir_ctx.allow_unregistered_dialects = True with mlir_ctx, ir.Location.unknown(mlir_ctx): - dimension_semantics = mosaic_params.get("dimension_semantics", None) return lowering.lower_jaxpr_to_module( ctx, mlir_ctx, grid_mapping, jaxpr, - dimension_semantics=dimension_semantics, + dimension_semantics=mosaic_params.dimension_semantics, mesh=jax_mesh, for_verification=for_verification, dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), @@ -233,16 +235,16 @@ def _maybe_cast_inputs(*args): backend="tpu", kernel_name=mlir.sanitize_name(debug_info.func_name), cost_estimate=mosaic_cost_estimate, - vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), - flags=mosaic_params.get("flags"), - 
allow_input_fusion=mosaic_params.get("allow_input_fusion"), + vmem_limit_bytes=mosaic_params.vmem_limit_bytes, + flags=mosaic_params.flags, + allow_input_fusion=mosaic_params.allow_input_fusion, input_output_aliases=input_output_aliases, - serialization_format=mosaic_params.get("serialization_format", 1), - internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), - collective_id=mosaic_params.get("collective_id", None), - has_side_effects=mosaic_params.get("has_side_effects", False), + serialization_format=mosaic_params.serialization_format, + internal_scratch_in_bytes=mosaic_params.internal_scratch_in_bytes, + collective_id=mosaic_params.collective_id, + has_side_effects=mosaic_params.has_side_effects, output_memory_spaces=output_memory_spaces, - disable_bounds_checks=mosaic_params.get("disable_bounds_checks"), + disable_bounds_checks=mosaic_params.disable_bounds_checks, ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 554b9db878f6..35ce282234d2 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -42,6 +42,7 @@ pytype_strict_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", "//jax", "//jax:core", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 6bbf53479d00..5446cdb47add 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -90,7 +90,7 @@ class GPUCompilerParams(pallas_core.CompilerParams): events than this. profile_dir: The directory to which profiling traces will be written to. """ - PLATFORM: ClassVar[str] = "mosaic_gpu" + BACKEND: ClassVar[pallas_core.Backend] = "mosaic_gpu" approx_math: bool = False dimension_semantics: Sequence[DimensionSemantics] | None = None max_concurrent_steps: int = 1 diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 7c47e6e6309b..a961d5bf56c4 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -550,7 +550,7 @@ def lower_pipelined_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, jaxpr: jax_core.Jaxpr, - compiler_params: dict[str, Any], + params: gpu_core.GPUCompilerParams, cost_estimate: pallas_core.CostEstimate | None, ) -> LoweringResult: del cost_estimate # Unused. 
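As a companion to the TPU example above, a hedged sketch of the Mosaic GPU dataclass that `lower_pipelined_jaxpr_to_module` now consumes directly (field values are illustrative):

```python
# Sketch only: the lowering below reads fields such as dimension_semantics,
# max_concurrent_steps, and delay_release straight off this dataclass
# instead of a compiler_params dict. GPUCompilerParams is assumed to be
# exported from jax.experimental.pallas.mosaic_gpu.
from jax.experimental.pallas import mosaic_gpu as plgpu

params = plgpu.GPUCompilerParams(
    approx_math=True,
    max_concurrent_steps=2,
    dimension_semantics=("parallel",),  # one entry per grid axis
)
```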
@@ -580,14 +580,11 @@ def lower_pipelined_jaxpr_to_module( block = (128, 1, 1) grid = grid_mapping.grid - params = compiler_params.get("mosaic_gpu", {}) - dimension_semantics = params.get("dimension_semantics", None) - if dimension_semantics is None: + if params.dimension_semantics is None: which_parallel = [True] * len(grid) else: - assert len(dimension_semantics) == len(grid) - which_parallel = [ds == "parallel" for ds in dimension_semantics] - del dimension_semantics + assert len(params.dimension_semantics) == len(grid) + which_parallel = [ds == "parallel" for ds in params.dimension_semantics] sequential_grid = tuple( d for axis, d in enumerate(grid) if not which_parallel[axis] @@ -662,8 +659,8 @@ def body_fn(indices, *refs): _block_spec_from_block_mapping(bm, which_parallel) for bm in out_block_mappings ], - max_concurrent_steps=params.pop("max_concurrent_steps", 1), - delay_release=params.pop("delay_release", 0), + max_concurrent_steps=params.max_concurrent_steps, + delay_release=params.delay_release, )(*refs) with grid_mapping.trace_env(): @@ -696,7 +693,7 @@ def body_fn(indices, *refs): for r in semaphore_ref_avals ], new_jaxpr, - compiler_params, + params, new_consts, ) @@ -710,15 +707,12 @@ def lower_jaxpr_to_module( out_shapes: Sequence[jax.ShapeDtypeStruct], gmem_scratch_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, - compiler_params: dict[str, Any], + params: gpu_core.GPUCompilerParams, consts=(), ) -> LoweringResult: debug_info = jaxpr.debug_info - params = compiler_params.get("mosaic_gpu", {}) - approx_math = params.get("approx_math", False) - lowering_semantics = params.get( - "lowering_semantics", mgpu_core.LoweringSemantics.Lane - ) + approx_math = params.approx_math + lowering_semantics = params.lowering_semantics if len(cluster) < 3: cluster = cluster + (1,) * (3 - len(cluster)) @@ -785,27 +779,25 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ), jaxpr, ) - smem_scratch_bytes = params.get("smem_scratch_bytes") - if smem_scratch_bytes is None: - smem_scratch_bytes = rs.smem_scratch_bytes - tmem_scratch_cols = rs.tmem_scratch_cols scratch_buffers = [ - jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), - rs.barriers, + jax.ShapeDtypeStruct(shape=[rs.smem_scratch_bytes], dtype=np.int8), + rs.barriers, ] - if tmem_scratch_cols > 0: + if rs.tmem_scratch_cols > 0: scratch_buffers.append( - mgpu.TMEM(shape=[tcgen05.TMEM_ROWS, tmem_scratch_cols], dtype=np.int32), + mgpu.TMEM( + shape=[tcgen05.TMEM_ROWS, rs.tmem_scratch_cols], dtype=np.int32 + ), ) else: scratch_buffers.append(None) prof_ctx = prof_spec = None - if prof_space := params.get("profile_space", 0): + if params.profile_space: # Each range is 2 events, each event is 4 bytes. 
- prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) - prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) + prof_spec = mgpu_profiler.ProfilerSpec(params.profile_space * 2 * 4) + prof_ctx = ProfilerContext(params.profile_dir, prof_spec) module, new_out_shapes, _, launch_ctx, scratch_arr = ( mgpu_core._lower_as_gpu_kernel( body, diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index eb15aff21235..2e2cb976df8f 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -19,7 +19,7 @@ import os import time -from typing import Any +from typing import cast import warnings import jax @@ -27,6 +27,7 @@ from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering from jax.experimental.mosaic import gpu as mgpu import numpy as np @@ -41,7 +42,7 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: dict[str, Any], + compiler_params: dict[str, pallas_core.CompilerParams], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): @@ -58,17 +59,15 @@ def pallas_call_lowering( print(f"The grid mapping for pallas_call {debug_info.func_src_info}:") print(grid_mapping) - lowering_semantics = compiler_params.get("mosaic_gpu", {}).get( - "lowering_semantics", mgpu.LoweringSemantics.Lane - ) mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + if "mosaic_gpu" in compiler_params: + params = cast(gpu_core.GPUCompilerParams, compiler_params["mosaic_gpu"]) + else: + params = gpu_core.GPUCompilerParams() + lowering_result = lowering.lower_pipelined_jaxpr_to_module( - grid_mapping, - mesh, - jaxpr, - compiler_params, - cost_estimate, + grid_mapping, mesh, jaxpr, params, cost_estimate ) if debug: print(f"\nThe Mosaic GPU module for pallas_call {debug_info.func_src_info}:") diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 500aaf125ca0..1bcf47b9ddee 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -15,12 +15,11 @@ """Module for calling pallas functions from JAX.""" from __future__ import annotations -from collections.abc import Callable, Sequence -import dataclasses +from collections.abc import Callable, Mapping, Sequence import enum from functools import partial, reduce import types -from typing import Any, Literal, cast +from typing import Any, cast import jax from jax import lax @@ -68,6 +67,8 @@ no_block_spec = pallas_core.no_block_spec ScratchShapeTree = pallas_core.ScratchShapeTree CostEstimate = pallas_core.CostEstimate +Backend = pallas_core.Backend +CompilerParams = pallas_core.CompilerParams # See the docstring for GridMapping for the calling convention pallas_call_p = jax_core.Primitive('pallas_call') @@ -125,7 +126,7 @@ def _pallas_call_jvp_rule( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, ): debug_info = jaxpr.debug_info if grid_mapping.num_dynamic_grid_bounds: @@ -328,7 +329,7 @@ def _batch_with_explicit_loop( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: 
tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, ): """Batch the pallas_call by calling it in loop over the batch size. @@ -426,7 +427,7 @@ def _pallas_call_batching_rule( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, ): if mesh is not None: raise NotImplementedError( @@ -1237,14 +1238,12 @@ def _unsupported_lowering_error(platform: str) -> Exception: " https://docs.jax.dev/en/latest/installation.html." ) -_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] - def _pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, - backend: _Backend | None, + backend: Backend | None, **params, ): if params['jaxpr'].constvars: @@ -1361,7 +1360,7 @@ def _pallas_call_state_discharge_rule( compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None = None, + backend: Backend | None = None, ): del avals_out assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars) @@ -1479,13 +1478,15 @@ def pallas_call( in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, scratch_shapes: ScratchShapeTree = (), - input_output_aliases: dict[int, int] = {}, + input_output_aliases: Mapping[int, int] = {}, debug: bool = False, interpret: bool = False, name: str | None = None, - compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + compiler_params: ( + Mapping[Backend, CompilerParams] | CompilerParams | None + ) = None, cost_estimate: CostEstimate | None = None, - backend: _Backend | None = None, + backend: Backend | None = None, ) -> Callable[..., Any]: """Invokes a Pallas kernel on some inputs. @@ -1527,22 +1528,22 @@ def pallas_call( This is useful for debugging. name: if present, specifies the name to use for this kernel call in debugging and error messages. To this name we append the file and line - where the kernel function is defined, .e.g: - `{name} for kernel function {kernel_name} at {file}:{line}`. - If missing, then we use `{kernel_name} at {file}:{line}`. - compiler_params: Optional compiler parameters. If a dict is provided, it - should be of the form {platform: {param_name: param_value}}, where - platform is either 'mosaic' or 'triton'. It is also possible - to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and - `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs. - backend: Optional string literal one of "mosaic_tpu", "triton" or "mosaic_gpu" - determining the backend to be used. None means let pallas decide. - + where the kernel function is defined, e.g. `{name} for kernel function + {kernel_name} at {file}:{line}`. If missing, then we use `{kernel_name} at + {file}:{line}`. + compiler_params: Optional compiler parameters. The value should either be a + backend-specific dataclass + (:class:`jax.experimental.pallas.tpu.TPUCompilerParams`, + :class:`jax.experimental.pallas.triton.TritonCompilerParams`, + :class:`jax.experimental.pallas.mosaic_gpu.GPUCompilerParams`) or a dict + mapping backend name to the corresponding platform-specific dataclass. + backend: Optional string literal one of ``"mosaic_tpu"``, ``"triton"`` or + ``"mosaic_gpu"`` determining the backend to be used. None means let Pallas + decide. Returns: A function that can be called on a number of positional array arguments to invoke the Pallas kernel.
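To illustrate the two accepted forms described above, a minimal sketch (the `add_one` kernel and the shapes are illustrative, and the example assumes a TPU backend):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_one(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
params = pltpu.TPUCompilerParams(vmem_limit_bytes=64 * 1024 * 1024)

# Form 1: the backend-specific dataclass; the backend is read off BACKEND.
y1 = pl.pallas_call(add_one, out_shape=out_shape, compiler_params=params)(x)
# Form 2: a mapping keyed by backend name, per _normalize_compiler_params.
y2 = pl.pallas_call(add_one, out_shape=out_shape,
                    compiler_params={"mosaic_tpu": params})(x)
```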
- """ if grid_spec is None: grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) @@ -1578,30 +1579,46 @@ def pallas_call( ) +def _normalize_compiler_params( + compiler_params: Mapping[Backend, CompilerParams] | CompilerParams | None, +) -> Mapping[Backend, CompilerParams]: + if compiler_params is None: + return {} + if isinstance(compiler_params, pallas_core.CompilerParams): + compiler_params = {compiler_params.BACKEND: compiler_params} + assert isinstance(compiler_params, Mapping) + for backend, params in compiler_params.items(): + if backend not in ["mosaic_tpu", "mosaic_gpu", "triton"]: + raise ValueError(f"Unknown backend in compiler_params: {backend}") + if not isinstance(params, pallas_core.CompilerParams): + raise ValueError( + f"Unexpected compiler_params for backend {backend}: {params}" + ) + if params.BACKEND != backend: + raise ValueError( + f"Inconsistent backend in compiler_params: {params.BACKEND} !=" + f" {backend}" + ) + return compiler_params + + def _pallas_call( kernel: Callable[..., None], out_shape: Any, *, grid_spec: GridSpec, mesh: pallas_core.Mesh | None = None, - input_output_aliases: dict[int, int] = {}, + input_output_aliases: Mapping[int, int] = {}, debug: bool = False, interpret: bool = False, name: str | None = None, - compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + compiler_params: ( + Mapping[Backend, CompilerParams] | CompilerParams | None + ) = None, cost_estimate: CostEstimate | None = None, - backend: _Backend | None = None, + backend: Backend | None = None, ): - if compiler_params is None: - compiler_params = {} - if isinstance(compiler_params, pallas_core.CompilerParams): - if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: - raise ValueError( - f"Unknown platform in compiler params: {compiler_params.PLATFORM}" - ) - compiler_params = { - compiler_params.PLATFORM: dataclasses.asdict(compiler_params) - } + compiler_params = _normalize_compiler_params(compiler_params) if mesh is not None: if tuple(mesh.shape.values()) != grid_spec.grid: @@ -1611,7 +1628,7 @@ def _pallas_call( ) if backend is not None: raise ValueError("If `mesh` is specified, then `backend` must be `None`.") - backend = cast(_Backend, mesh.backend) + backend = cast(Backend, mesh.backend) grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index cde2aadd6013..2b8ee4eaa8f2 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -76,12 +76,11 @@ pytype_strict_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", "//jax", - "//jax:config", "//jax:core", "//jax:mlir", - "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", ], diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py index 097f8497e8f7..6b3e10f2b018 100644 --- a/jax/_src/pallas/triton/core.py +++ b/jax/_src/pallas/triton/core.py @@ -32,7 +32,7 @@ class TritonCompilerParams(pallas_core.CompilerParams): serialized_metadata: Additional compiler metadata. This field is unstable and may be removed in the future. 
""" - PLATFORM: ClassVar[str] = "triton" + BACKEND: ClassVar[pallas_core.Backend] = "triton" num_warps: int | None = None num_stages: int | None = None serialized_metadata: bytes | None = None diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 4e8775e514f0..b692bd43a0fa 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -17,16 +17,17 @@ from __future__ import annotations import io -from typing import Any +from typing import cast import zlib import jax import jax._src.core as jax_core from jax._src.interpreters import mlir -from jax._src.lib import triton from jax._src.lib import gpu_triton as triton_kernel_call_lib +from jax._src.lib import triton from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core +from jax._src.pallas.triton import core as triton_core from jax._src.pallas.triton import lowering @@ -51,7 +52,7 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: dict[str, Any], + compiler_params: dict[str, pallas_core.CompilerParams], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): @@ -67,16 +68,17 @@ def pallas_call_lowering( ) if mesh is not None: raise NotImplementedError("mesh is not supported in the Triton backend") - triton_params = compiler_params.get("triton", compiler_params) - num_warps = triton_params.get("num_warps", 4) - num_warps = 4 if num_warps is None else num_warps + [lowering_platform] = ctx.platforms or ctx.module_context.platforms - if lowering_platform == "rocm": - num_stages = triton_params.get("num_stages", 1) - num_stages = 1 if num_stages is None else num_stages + + if "triton" in compiler_params: + params = cast(triton_core.TritonCompilerParams, compiler_params["triton"]) else: - num_stages = triton_params.get("num_stages", 3) - num_stages = 3 if num_stages is None else num_stages + params = triton_core.TritonCompilerParams() + num_warps = 4 if params.num_warps is None else params.num_warps + num_stages = params.num_stages + if num_stages is None: + num_stages = 1 if lowering_platform == "rocm" else 3 if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") @@ -117,12 +119,11 @@ def pallas_call_lowering( grid_z=mlir.i32_attr(grid_z), debug=ir.BoolAttr.get(debug), ) - if "serialized_metadata" in (triton_params or {}): + if params.serialized_metadata is not None: # This field is unstable and may be removed in the future. 
- if triton_params["serialized_metadata"] is not None: - backend_config["serialized_metadata"] = ir.StringAttr.get( - triton_params["serialized_metadata"] - ) + backend_config["serialized_metadata"] = ir.StringAttr.get( + params.serialized_metadata + ) return mlir.custom_call( call_target_name="__gpu$xla.gpu.triton", result_types=out_types, @@ -178,10 +179,10 @@ def pallas_call_lowering( call_target_name="triton_kernel_call", result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], operands=in_nodes, - backend_config=zlib.compress( + backend_config=zlib.compress( kernel_call.to_proto( debug_info.func_name, - triton_params.get("serialized_metadata") or b"", + params.serialized_metadata or b"", ) ), operand_layouts=avals_to_layouts(ctx.avals_in), diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index d9f746506d57..ff0c28ed13f2 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -139,7 +139,7 @@ class CustomCallBackendConfig: needs_layout_passes: bool vmem_limit_bytes: int | None flags: dict[str, bool | int | float] | None - allow_input_fusion: list[bool] | None + allow_input_fusion: Sequence[bool] | None serialization_format: int | None internal_scratch_in_bytes: int | None output_memory_spaces: tuple[MemorySpace | None, ...] | None @@ -569,7 +569,7 @@ def _lower_to_custom_call_config( vmem_limit_bytes: int | None, cost_estimate: CostEstimate | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, @@ -616,7 +616,7 @@ def _lowered_to_custom_call_config( vmem_limit_bytes: int | None, cost_estimate: CostEstimate | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, @@ -672,7 +672,7 @@ def lower_module_to_custom_call( cost_estimate: CostEstimate | None, vmem_limit_bytes: int | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, input_output_aliases: tuple[tuple[int, int], ...], internal_scratch_in_bytes: int | None, collective_id: int | None, @@ -716,7 +716,7 @@ def as_tpu_kernel( kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, - allow_input_fusion: list[bool] | None = None, + allow_input_fusion: Sequence[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] = (), internal_scratch_in_bytes: int | None = None, collective_id: int | None = None, @@ -763,7 +763,7 @@ def lowered_as_tpu_kernel( kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, - allow_input_fusion: list[bool] | None = None, + allow_input_fusion: Sequence[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] 
= (), serialization_format: int | None = None, internal_scratch_in_bytes: int | None = None, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index fd54c1f6065a..128da748b233 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -30,11 +30,9 @@ from jax import random from jax._src import checkify from jax._src import config -from jax._src import core as jax_core from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax.control_flow.for_loop import for_loop -from jax._src.pallas import pallas_call from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl import jax.numpy as jnp @@ -2572,47 +2570,5 @@ class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest): INTERPRET = True -def _find_pallas_call_in_jaxpr( - jaxpr: jax_core.Jaxpr) -> jax_core.JaxprEqn | None: - for eqn in jaxpr.eqns: - call_eqn = None - if eqn.primitive == pallas_call.pallas_call_p: - call_eqn = eqn - elif 'jaxpr' in eqn.params: - call_eqn = _find_pallas_call_in_jaxpr(eqn.params['jaxpr']) - if call_eqn is not None: - return call_eqn - return None - - -class PallasCompilerParamsTest(PallasBaseTest): - def test_triton_params_consistent_across_double_jit(self): - # Test for https://github.com/jax-ml/jax/issues/25714 - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Triton backend only works on GPU.") - params = plgpu.TritonCompilerParams(num_warps=8) - - @jax.jit - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), - compiler_params=params) - def copy_kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] - - @functools.partial(jax.jit, static_argnames=["z"]) - def plus_z(x, z): - return copy_kernel(x+z) - - x = 0. - extracted_params = _find_pallas_call_in_jaxpr( - plus_z.trace(x, 1).jaxpr).params["compiler_params"] - self.assertEqual(plus_z(0., 1.), 1.) - self.assertEqual(extracted_params["triton"]["num_warps"], 8) - extracted_params = _find_pallas_call_in_jaxpr( - plus_z.trace(x, 2).jaxpr).params["compiler_params"] - self.assertEqual(plus_z(0., 2.), 2.) 
- self.assertEqual(extracted_params["triton"]["num_warps"], 8) - - if __name__ == "__main__": absltest.main() From d4dd0c4985cd8046a2b97efbd9798e0ebf397c8f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 23 Apr 2025 16:18:44 -0700 Subject: [PATCH 0774/1769] Migrate jax.experimental.shard_map to jax.shard_map in internal JAX and also change docs to point to `jax.shard_map` PiperOrigin-RevId: 750760353 --- docs/jax.rst | 49 +-- docs/multi_process.md | 4 +- docs/notebooks/shard_map.ipynb | 106 +++--- docs/notebooks/shard_map.md | 100 +++--- docs/persistent_compilation_cache.md | 3 +- docs/sharded-computation.ipynb | 17 +- docs/sharded-computation.md | 17 +- jax/_src/core.py | 4 +- jax/_src/debugging.py | 3 +- jax/_src/error_check.py | 2 +- jax/_src/interpreters/pxla.py | 2 +- jax/_src/lax/control_flow/conditionals.py | 4 +- jax/_src/lax/control_flow/loops.py | 4 +- jax/_src/lax/parallel.py | 14 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 4 +- .../jax2tf/tests/sharding_test.py | 2 +- jax/experimental/pallas/ops/tpu/all_gather.py | 4 +- jax/experimental/roofline/rooflines.py | 2 +- tests/checkify_test.py | 8 +- tests/debug_info_test.py | 2 +- tests/debug_nans_test.py | 2 +- tests/export_back_compat_test.py | 2 +- tests/export_test.py | 2 +- tests/ffi_test.py | 4 +- tests/memories_test.py | 8 +- tests/pallas/tpu_pallas_async_test.py | 12 +- tests/pallas/tpu_pallas_distributed_test.py | 24 +- .../tpu_pallas_interpret_distributed_test.py | 14 +- tests/pallas/tpu_pallas_pipeline_test.py | 10 +- tests/pallas/tpu_pallas_random_test.py | 4 +- tests/pallas/tpu_pallas_test.py | 4 +- ...pu_splash_attention_kernel_sharded_test.py | 6 +- tests/pjit_test.py | 8 +- tests/python_callback_test.py | 2 +- tests/ragged_collective_test.py | 14 +- tests/shard_alike_test.py | 4 +- tests/shard_map_test.py | 313 +++++++++--------- 37 files changed, 391 insertions(+), 393 deletions(-) diff --git a/docs/jax.rst b/docs/jax.rst index 2f16df613e5a..de901caf9414 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -106,6 +106,31 @@ Automatic differentiation closure_convert checkpoint +Vectorization (:code:`vmap`) +---------------------------- + +.. autosummary:: + :toctree: _autosummary + + vmap + numpy.vectorize + +Parallelization (:code:`pmap`) +------------------------------ + +.. autosummary:: + :toctree: _autosummary + + shard_map + pmap + devices + local_devices + process_index + device_count + local_device_count + process_count + process_indices + Customization ------------- @@ -217,30 +242,6 @@ Array properties and methods Array.T Array.mT -Vectorization (:code:`vmap`) ----------------------------- - -.. autosummary:: - :toctree: _autosummary - - vmap - numpy.vectorize - -Parallelization (:code:`pmap`) ------------------------------- - -.. autosummary:: - :toctree: _autosummary - - pmap - devices - local_devices - process_index - device_count - local_device_count - process_count - process_indices - Callbacks --------- diff --git a/docs/multi_process.md b/docs/multi_process.md index 32cfae126784..cebc75fbedc5 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -32,7 +32,7 @@ Key concepts: of all devices across all processes. * Use standard JAX parallelism APIs like {func}`~jax.jit` (see {doc}`/sharded-computation` tutorial) and - {func}`~jax.experimental.shard_map.shard_map`. jax.jit only accepts + {func}`~jax.shard_map`. jax.jit only accepts globally shaped arrays. shard_map allows you to drop to per-device shape. 
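As a single-host sketch of that per-device view (assuming the local device count divides the array size):

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ('i',))

@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def double(x_block):
  # Inside shard_map the block has the per-device shape, not the global one.
  return 2 * x_block

x = jnp.arange(2. * jax.device_count())
y = double(x)  # y is reassembled with the global shape
```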
* Make sure all processes run the same parallel computations in the same @@ -128,7 +128,7 @@ global devices. So how do you actually run a computation involving cross-process communication? **Use the same parallel evaluation APIs that you would in a single process!** -For example, {func}`~jax.experimental.shard_map.shard_map` can be used +For example, {func}`~jax.shard_map` can be used to run a parallel computation across multiple processes. (If you’re not already familiar with how to use `shard_map` to run across multiple devices within a single process, check out the diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index f916147ef589..e8128c4133f7 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -55,8 +55,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from jax.sharding import Mesh, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, PartitionSpec as P" ] }, { @@ -71,7 +70,7 @@ "a = jnp.arange( 8 * 16.).reshape(8, 16)\n", "b = jnp.arange(16 * 4.).reshape(16, 4)\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", " out_specs=P('x', None))\n", "def matmul_basic(a_block, b_block):\n", " # a_block: f32[2, 8]\n", @@ -249,7 +248,7 @@ "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n", "\n", "def check_shmap(f, y):\n", - " ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", + " ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", " expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])\n", " print(allclose(ans, expected))\n", "\n", @@ -296,7 +295,7 @@ "source": [ "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", "def f1(x_block):\n", " print(x_block.shape) # prints (3, 12)\n", " return x_block\n", @@ -327,7 +326,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n", "def f2(x_block):\n", " print(x_block.shape)\n", " return x_block\n", @@ -383,13 +382,13 @@ "source": [ "x = jnp.array([[3.]])\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n", "print(z) # prints the same as jnp.tile(x, (4, 2))\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n", "print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n", "print(z) # prints the same as jnp.tile(x, (1, 1)), or just x" ] }, @@ -410,7 +409,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n", "def f3(x_block):\n", " return jax.lax.psum(x_block, 'j')\n", "\n", @@ -439,7 +438,7 @@ "metadata": {}, "outputs": [], 
"source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", "def f4(x_block):\n", " return jax.lax.psum(x_block, 'i')\n", "\n", @@ -448,7 +447,7 @@ "print(y4.shape) # (3,12)\n", "\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n", "def f5(x_block):\n", " return jax.lax.psum(x_block, ('i', 'j'))\n", "\n", @@ -481,7 +480,7 @@ "`Array`s, or physically how to interpret the buffers across devices as the\n", "physical layout of a single logical `Array`.\n", "\n", - "#### Tracking how values vary over manual mesh axes, and `check_rep=True`\n", + "#### Tracking how values vary over manual mesh axes, and `check_vma=True`\n", "\n", "Under a `shard_map`, values can vary across function instances, or they can be\n", "the same. For example, when we use `in_specs` to split an argument over a mesh\n", @@ -497,7 +496,7 @@ "source": [ "mesh = jax.make_mesh((2,), ('i',))\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", "def f(x):\n", " print(x)\n", " return 2 * x\n", @@ -522,7 +521,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=P())\n", "def f(x):\n", " print(x)\n", " return 2 * x\n", @@ -548,7 +547,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", "def f(x):\n", " y = jax.lax.psum(x, 'i')\n", " print(y)\n", @@ -565,7 +564,7 @@ "source": [ "In general, each intermediate value in a `shard_map` can be either unvarying or\n", "possibly-varying over each manual mesh axis. That information can be tracked in\n", - "the JAX type system, enabled by the `check_rep=True` argument to `shard_map`:" + "the JAX type system, enabled by the `check_vma=True` argument to `shard_map`:" ] }, { @@ -575,7 +574,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", "def f(x):\n", " print(jax.typeof(x)) # f32[3]{i}\n", " y = jax.lax.psum(x, 'i')\n", @@ -610,7 +609,7 @@ "source": [ "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n", "def f(x):\n", " print(jax.typeof(x)) # f32[2,2]{i,j}\n", " y = jax.lax.psum(x, 'j')\n", @@ -635,7 +634,7 @@ "3. 
The correctness of `out_specs` can be checked, ruling out the potential bug\n", " example below.\n", "\n", - "For example, this `out_specs` bug is caught with `check_rep=True`, but uncaught\n", + "For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught\n", "without it:" ] }, @@ -650,7 +649,7 @@ "\n", "x = jnp.arange(6.)\n", "try:\n", - " y = shard_map(lambda x: x, mesh, in_specs=P('i'), out_specs=P())(x)\n", + " y = jax.shard_map(lambda x: x, mesh=mesh, in_specs=P('i'), out_specs=P())(x)\n", "except Exception as e:\n", " print(e)" ] @@ -662,8 +661,8 @@ "source": [ "Here the `out_specs` incorrectly promise that each function instance along mesh\n", "axis `'i'` produces the same value and thus we can choose just one of them.\n", - "With `check_rep=True` (the default) it raises an exception, while with\n", - "`check_rep=False` there is no exception and instead we get silent undefined\n", + "With `check_vma=True` (the default) it raises an exception, while with\n", + "`check_vma=False` there is no exception and instead we get silent undefined\n", "behavior.\n", "\n", "Sometimes we want to treat a value that is unvarying over a mesh axis as\n", @@ -677,7 +676,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None)\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=None)\n", "def f(x):\n", " print(jax.typeof(x)) # f32[6]\n", " y = jax.lax.pvary(x, 'i')\n", @@ -710,7 +709,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", "def f(x, y):\n", " return x * y\n", "\n", @@ -744,7 +743,7 @@ "source": [ "mesh = jax.make_mesh((2,), ('i',))\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", "def f(x, y):\n", " def body(carry, _):\n", " c1, c2 = carry\n", @@ -779,7 +778,7 @@ "source": [ "mesh = jax.make_mesh((2,), ('i',))\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", "def f(x, y):\n", " def body(carry, _):\n", " c1, c2 = carry\n", @@ -830,7 +829,7 @@ "def shard_map(\n", " f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,\n", " auto: collections.abc.Set[AxisName] = frozenset([]),\n", - " check_rep: bool = True,\n", + " check_vma: bool = True,\n", ") -> Callable:\n", " ...\n", "```\n", @@ -839,7 +838,7 @@ "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n", "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", - "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", + "* `check_vma` is 
an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", "\n", "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", @@ -861,7 +860,7 @@ "```python\n", "mesh = Mesh(jax.devices(), ('i',))\n", "x = jnp.arange(16.)\n", - "f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))\n", + "f_shmapped = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", "y = f_shmapped(x)\n", "```\n", "\n", @@ -933,8 +932,7 @@ "import jax.numpy as jnp\n", "from jax import lax\n", "\n", - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P" ] }, { @@ -946,7 +944,7 @@ "source": [ "mesh1d = Mesh(jax.devices()[:4], ('i',))\n", "\n", - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", "def f1(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, 'i')\n", @@ -1002,7 +1000,7 @@ "source": [ "mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n", "\n", - "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", "def f2(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, 'i')\n", @@ -1033,7 +1031,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", "def f3(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, ('i', 'j'))\n", @@ -1070,7 +1068,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f4(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_gather(x_block, 'i', tiled=True)\n", @@ -1109,7 +1107,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f5(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_gather(x_block, 'i', tiled=False)\n", @@ -1152,7 +1150,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f6(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n", @@ -1228,7 +1226,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f7(x_block):\n", " sz = jax.lax.axis_size('i')\n", " print('BEFORE:\\n', x_block)\n", @@ -1306,7 +1304,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + 
"@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f8(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = psum_scatter(x_block, 'i', tiled=True)\n", @@ -1354,7 +1352,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f9(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,\n", @@ -1426,8 +1424,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P" ] }, { @@ -1503,7 +1500,7 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather(lhs_block, rhs_block):\n", " rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)\n", @@ -1547,7 +1544,7 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped(lhs_block, rhs_block):\n", " size = jax.lax.axis_size('i')\n", @@ -1596,7 +1593,7 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):\n", " size = jax.lax.axis_size('i')\n", @@ -1677,7 +1674,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_psumscatter(lhs_block, rhs_block):\n", " out_summand = lhs_block @ rhs_block\n", @@ -1705,7 +1702,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_psumscatter_overlapped(lhs_block, rhs_block):\n", " size = jax.lax.axis_size('i')\n", @@ -1748,7 +1745,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):\n", " size = jax.lax.axis_size('i')\n", @@ -1885,7 +1882,6 @@ "from functools import partial\n", "\n", "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map\n", "\n", "mesh = jax.make_mesh((8,), ('batch',))\n", "\n", @@ -1895,7 +1891,7 @@ "\n", "# adapt the loss function to sum the losses across devices\n", "def loss_dp(params, batch):\n", - " @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())\n", + " @partial(jax.shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())\n", " def loss_spmd(local_batch):\n", " inputs, targets = local_batch\n", " predictions = predict(params, inputs) # use reference 'predict`\n", @@ -2000,7 +1996,7 @@ " return outputs\n", "\n", "def 
loss_fsdp(params, batch):\n", - " @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())\n", + " @partial(jax.shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())\n", " def loss_spmd(local_params, local_batch):\n", " inputs, targets = local_batch\n", " predictions = predict_fsdp(local_params, inputs)\n", @@ -2069,7 +2065,7 @@ " inputs = jax.nn.relu(outputs)\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh,\n", + "@partial(jax.shard_map, mesh=mesh,\n", " in_specs=(P(None, 'feats'), P('feats', None), P('feats')),\n", " out_specs=P(None, 'feats'))\n", "def gemm_tp(inputs, W, b):\n", @@ -2117,7 +2113,7 @@ " inputs = jax.nn.relu(outputs)\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh,\n", + "@partial(jax.shard_map, mesh=mesh,\n", " in_specs=(P(('feats', 'batch')), P('batch', 'feats')),\n", " out_specs=P())\n", "def loss_fsdp_tp(local_params, local_batch):\n", @@ -2227,7 +2223,7 @@ " outputs = jnp.dot(inputs, W_last) + b_last\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n", " out_specs=P())\n", "def loss_pp(params, batch):\n", " inputs, targets = batch\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index ba23d7d17f3e..43069110301d 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -46,7 +46,6 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P -from jax.experimental.shard_map import shard_map ``` ```{code-cell} @@ -55,7 +54,7 @@ mesh = jax.make_mesh((4, 2), ('x', 'y')) a = jnp.arange( 8 * 16.).reshape(8, 16) b = jnp.arange(16 * 4.).reshape(16, 4) -@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), +@partial(jax.shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', None)) def matmul_basic(a_block, b_block): # a_block: f32[2, 8] @@ -161,7 +160,7 @@ devices = np.array(jax.devices()[:4]) mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4 def check_shmap(f, y): - ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y) + ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y) expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])]) print(allclose(ans, expected)) @@ -196,7 +195,7 @@ then there's no splitting over that mesh axis. 
For example: ```{code-cell} mesh = jax.make_mesh((4, 2), ('i', 'j')) -@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) def f1(x_block): print(x_block.shape) # prints (3, 12) return x_block @@ -215,7 +214,7 @@ less efficient program where all mesh axes are mentioned but the caller performs a `jnp.tile`, for example: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) def f2(x_block): print(x_block.shape) return x_block @@ -259,13 +258,13 @@ using the same mesh as above: ```{code-cell} x = jnp.array([[3.]]) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() print(z) # prints the same as jnp.tile(x, (4, 2)) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,)) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() print(z) # prints the same as jnp.tile(x, (1, 1)), or just x ``` @@ -274,7 +273,7 @@ augment with a corresponding input pspec of P(None, None). As another example, following more closely to the other examples above: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) def f3(x_block): return jax.lax.psum(x_block, 'j') @@ -291,7 +290,7 @@ two more examples where we vary which mesh axes are mentioned in the output pspec: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f4(x_block): return jax.lax.psum(x_block, 'i') @@ -300,7 +299,7 @@ y4 = f4(x) print(y4.shape) # (3,12) -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) def f5(x_block): return jax.lax.psum(x_block, ('i', 'j')) @@ -328,7 +327,7 @@ Instead, `out_specs` just encodes how to assemble the block outputs into `Array`s, or physically how to interpret the buffers across devices as the physical layout of a single logical `Array`. -#### Tracking how values vary over manual mesh axes, and `check_rep=True` +#### Tracking how values vary over manual mesh axes, and `check_vma=True` Under a `shard_map`, values can vary across function instances, or they can be the same. 
For example, when we use `in_specs` to split an argument over a mesh @@ -337,7 +336,7 @@ axis, each function instance along that mesh axis gets a different value: ```{code-cell} mesh = jax.make_mesh((2,), ('i',)) -@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): print(x) return 2 * x @@ -350,7 +349,7 @@ If instead `in_specs` does not split the argument over a mesh axis, the value is the same for each function instance along that axis: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) +@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def f(x): print(x) return 2 * x @@ -364,7 +363,7 @@ example, applying a `psum` produces the same output on each function instance along an axis: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) def f(x): y = jax.lax.psum(x, 'i') print(y) @@ -376,10 +375,10 @@ f(x) In general, each intermediate value in a `shard_map` can be either unvarying or possibly-varying over each manual mesh axis. That information can be tracked in -the JAX type system, enabled by the `check_rep=True` argument to `shard_map`: +the JAX type system, enabled by the `check_vma=True` argument to `shard_map`: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) def f(x): print(jax.typeof(x)) # f32[3]{i} y = jax.lax.psum(x, 'i') @@ -402,7 +401,7 @@ axes over which the `shard_map` is acting: ```{code-cell} mesh = jax.make_mesh((4, 2), ('i', 'j')) -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i')) def f(x): print(jax.typeof(x)) # f32[2,2]{i,j} y = jax.lax.psum(x, 'j') @@ -422,7 +421,7 @@ Tracking varying manual axes can be useful: 3. The correctness of `out_specs` can be checked, ruling out the potential bug example below. -For example, this `out_specs` bug is caught with `check_rep=True`, but uncaught +For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught without it: ```{code-cell} @@ -430,22 +429,22 @@ mesh = jax.make_mesh((2,), ('i',)) x = jnp.arange(6.) try: - y = shard_map(lambda x: x, mesh, in_specs=P('i'), out_specs=P())(x) + y = jax.shard_map(lambda x: x, mesh=mesh, in_specs=P('i'), out_specs=P())(x) except Exception as e: print(e) ``` Here the `out_specs` incorrectly promise that each function instance along mesh axis `'i'` produces the same value and thus we can choose just one of them. -With `check_rep=True` (the default) it raises an exception, while with -`check_rep=False` there is no exception and instead we get silent undefined +With `check_vma=True` (the default) it raises an exception, while with +`check_vma=False` there is no exception and instead we get silent undefined behavior. Sometimes we want to treat a value that is unvarying over a mesh axis as varying over that mesh axis. 
That's what `jax.lax.pvary` does: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None) +@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=None) def f(x): print(jax.typeof(x)) # f32[6] y = jax.lax.pvary(x, 'i') @@ -466,7 +465,7 @@ JAX implicitly inserts `jax.lax.pvary` calls in many cases, especially for binary operations: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) def f(x, y): return x * y @@ -488,7 +487,7 @@ this code raises an error: ```{code-cell} mesh = jax.make_mesh((2,), ('i',)) -@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) def f(x, y): def body(carry, _): c1, c2 = carry @@ -511,7 +510,7 @@ the `scan`: ```{code-cell} mesh = jax.make_mesh((2,), ('i',)) -@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) def f(x, y): def body(carry, _): c1, c2 = carry @@ -557,7 +556,7 @@ Specs = PyTree[PartitionSpec] def shard_map( f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, auto: collections.abc.Set[AxisName] = frozenset([]), - check_rep: bool = True, + check_vma: bool = True, ) -> Callable: ... ``` @@ -566,7 +565,7 @@ where: * `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; * `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; * `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually; -* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). +* `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed @@ -588,7 +587,7 @@ so that this: ```python mesh = Mesh(jax.devices(), ('i',)) x = jnp.arange(16.) 
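# (Hedged aside, assuming the usual 8-device setup of this doc): with
# in_specs=P('i'), each of the len(jax.devices()) instances of f receives a
# block of shape (16 // len(jax.devices()),), e.g. f32[2] per instance on
# 8 devices, and out_specs=P('i') concatenates the results back to f32[16].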
-f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i')) +f_shmapped = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) y = f_shmapped(x) ``` @@ -654,13 +653,12 @@ import jax.numpy as jnp from jax import lax from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -from jax.experimental.shard_map import shard_map ``` ```{code-cell} mesh1d = Mesh(jax.devices()[:4], ('i',)) -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) def f1(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') @@ -698,7 +696,7 @@ each one separately, or over multiple axes at once: ```{code-cell} mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j')) -@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f2(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') @@ -717,7 +715,7 @@ If we apply the `psum` over both axes, the `y_block` value is equal along both axes: ```{code-cell} -@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) +@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) def f3(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, ('i', 'j')) @@ -742,7 +740,7 @@ each function application has a full copy of the data along that axis: Illustration of an all_gather computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f4(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=True) @@ -769,7 +767,7 @@ When `tiled=False` (the default), results are stacked along a new axis instead of concatenated: ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f5(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=False) @@ -800,7 +798,7 @@ The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like Illustration of a psum_scatter computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f6(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True) @@ -864,7 +862,7 @@ that mesh axis, `ppermute` sends its argument value from each source function instance to each destination: ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f7(x_block): sz = jax.lax.axis_size('i') print('BEFORE:\n', x_block) @@ -924,7 +922,7 @@ def psum_scatter(x, axis_name, *, tiled=False): ``` ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f8(x_block): print('BEFORE:\n', x_block) y_block = psum_scatter(x_block, 'i', tiled=True) @@ -960,7 +958,7 @@ transpose operating along one positional axis and one cross-device axis: Illustration of an all_to_all computation. 
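A plain-NumPy reference for the tiled case may help before the next example (a sketch assuming a size-4 axis; `all_to_all_reference` is a hypothetical helper, not a JAX API): each instance chops its block into `axis_size` pieces along `split_axis`, piece `j` goes to instance `j`, and received pieces are concatenated along `concat_axis`.

```python
import numpy as np

def all_to_all_reference(blocks, split_axis=0, concat_axis=0):
  # blocks: list of per-instance arrays, one per position along the mesh axis.
  size = len(blocks)
  pieces = [np.split(b, size, axis=split_axis) for b in blocks]
  # Destination instance dst gathers piece dst from every source instance.
  return [np.concatenate([pieces[src][dst] for src in range(size)],
                         axis=concat_axis)
          for dst in range(size)]

blocks = [np.arange(4.) + 4 * k for k in range(4)]  # instance k holds [4k..4k+3]
print(all_to_all_reference(blocks)[0])  # [ 0.  4.  8. 12.]
```

With singleton pieces, instance 0 ends up holding the leading element from every source block, which is the transpose-like picture in the illustration above.
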
```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f9(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0, @@ -1021,7 +1019,6 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -from jax.experimental.shard_map import shard_map ``` ```{code-cell} @@ -1055,7 +1052,7 @@ side: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather(lhs_block, rhs_block): rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True) @@ -1081,7 +1078,7 @@ multiplies: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather_overlapped(lhs_block, rhs_block): size = jax.lax.axis_size('i') @@ -1112,7 +1109,7 @@ each half in each direction: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather_overlapped_bidi(lhs_block, rhs_block): size = jax.lax.axis_size('i') @@ -1163,7 +1160,7 @@ rhs = device_put(rhs, rhs_spec) Here we can use a `reduce_scatter` to perform the contraction sum over shards: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter(lhs_block, rhs_block): out_summand = lhs_block @ rhs_block @@ -1179,7 +1176,7 @@ inline an implementation of `psum_scatter` in terms of `ppermute`, then interleave the communication steps with local matrix multiplies: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter_overlapped(lhs_block, rhs_block): size = jax.lax.axis_size('i') @@ -1204,7 +1201,7 @@ As in the previous example, to fully utilize interconnects on TPU, we'd run a bidirectional version: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block): size = jax.lax.axis_size('i') @@ -1299,7 +1296,6 @@ all-reduce-sums of parameter gradients in the backward pass.) 
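As a self-contained sketch of this pattern (a minimal illustration with a hypothetical stand-in `predict`, not the notebook's model): shard the batch over a `'batch'` mesh axis, compute a local mean loss, and combine with `jax.lax.pmean` so the result is unvarying and compatible with `out_specs=P()`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((len(jax.devices()),), ('batch',))

def predict(w, x):  # stand-in linear model, for illustration only
  return x @ w

def loss_dp_sketch(w, batch):
  # Assumes the leading batch dimension divides evenly by the device count.
  @partial(jax.shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def local_loss(local_batch):
    inputs, targets = local_batch[:, :-1], local_batch[:, -1:]
    err = predict(w, inputs) - targets
    # Per-shard mean, then an unvarying cross-shard mean via pmean; with
    # equal-sized shards this equals the global mean squared error.
    return jax.lax.pmean(jnp.mean(err ** 2), 'batch')
  return local_loss(batch)

w = jnp.zeros((3, 1))
batch = jnp.ones((16, 4))        # 16 examples: 3 feature columns + 1 target
print(loss_dp_sketch(w, batch))  # 1.0, replicated across devices
```

The patched `loss_dp` below follows the same structure, with the inner per-shard function closing over the parameters.
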
from functools import partial from jax.sharding import NamedSharding, Mesh, PartitionSpec as P -from jax.experimental.shard_map import shard_map mesh = jax.make_mesh((8,), ('batch',)) @@ -1309,7 +1305,7 @@ params = jax.device_put(params, NamedSharding(mesh, P())) # adapt the loss function to sum the losses across devices def loss_dp(params, batch): - @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P()) + @partial(jax.shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P()) def loss_spmd(local_batch): inputs, targets = local_batch predictions = predict(params, inputs) # use reference 'predict` @@ -1384,7 +1380,7 @@ def predict_fsdp(params_frag, inputs): return outputs def loss_fsdp(params, batch): - @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P()) + @partial(jax.shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P()) def loss_spmd(local_params, local_batch): inputs, targets = local_batch predictions = predict_fsdp(local_params, inputs) @@ -1429,7 +1425,7 @@ def predict_tp(params, inputs): inputs = jax.nn.relu(outputs) return outputs -@partial(shard_map, mesh=mesh, +@partial(jax.shard_map, mesh=mesh, in_specs=(P(None, 'feats'), P('feats', None), P('feats')), out_specs=P(None, 'feats')) def gemm_tp(inputs, W, b): @@ -1465,7 +1461,7 @@ def predict_fsdp_tp(params_frag, inputs): inputs = jax.nn.relu(outputs) return outputs -@partial(shard_map, mesh=mesh, +@partial(jax.shard_map, mesh=mesh, in_specs=(P(('feats', 'batch')), P('batch', 'feats')), out_specs=P()) def loss_fsdp_tp(local_params, local_batch): @@ -1545,7 +1541,7 @@ def predict_pp(params, inputs): outputs = jnp.dot(inputs, W_last) + b_last return outputs -@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')), +@partial(jax.shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')), out_specs=P()) def loss_pp(params, batch): inputs, targets = batch diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 0a5a89abe26d..e241e76e3c5f 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -264,10 +264,9 @@ However, if we were to wrap the layernorm primitive in shard_map and define a fu ```python import jax -from jax.experimental.shard_map import shard_map def G(x1, x2, gamma, beta, mesh, ispecs, ospecs): - ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, x2, gamma, beta) + ln_out = jax.shard_map(LayerNorm, mesh=mesh, in_specs=ispecs, out_specs=ospecs, check_vma=False)(x1, x2, gamma, beta) return ln_out @ x2 ispecs = jax.sharding.PartitionSpec(...) diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 568a0d4c6e3d..f9f33febb094 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -22,7 +22,7 @@ " to turn the whole-array program into per-device programs (turning `jnp.sum`\n", " into `psum` for example) but the compiler is heavily constrained by the\n", " user-supplied shardings.\n", - "- _Fully manual sharding with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", + "- _Fully manual sharding with manual control using {func}`jax.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", "\n", "A summary table:\n", "\n", @@ -536,14 +536,14 @@ "\n", "## 3. 
Manual parallelism with `shard_map`\n", "\n", - "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", + "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", "\n", "`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:\n", "\n", "- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.\n", "- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.\n", "\n", - "**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it." + "**Note:** {func}`jax.shard_map` code can work inside {func}`jax.jit` if you need it." ] }, { @@ -571,10 +571,9 @@ } ], "source": [ - "from jax.experimental.shard_map import shard_map\n", "mesh = jax.make_mesh((8,), ('x',))\n", "\n", - "f_elementwise_sharded = shard_map(\n", + "f_elementwise_sharded = jax.shard_map(\n", " f_elementwise,\n", " mesh=mesh,\n", " in_specs=P('x'),\n", @@ -615,7 +614,7 @@ " print(f\"device local shape: {x.shape=}\")\n", " return x * 2\n", "\n", - "y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + "y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { @@ -649,7 +648,7 @@ "def f(x):\n", " return jnp.sum(x, keepdims=True)\n", "\n", - "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + "jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { @@ -684,7 +683,7 @@ " sum_in_shard = x.sum()\n", " return jax.lax.psum(sum_in_shard, 'x')\n", "\n", - "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" + "jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" ] }, { @@ -849,7 +848,7 @@ "from functools import partial\n", "\n", "@jax.jit\n", - "@partial(shard_map, mesh=mesh,\n", + "@partial(jax.shard_map, mesh=mesh,\n", " in_specs=(P('x'), P('x', None), P(None)),\n", " out_specs=P(None))\n", "def layer_sharded(x, weights, bias):\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index ae9f44aba832..60e789a109b2 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -29,7 +29,7 @@ The tutorial covers three modes of parallel computation: to turn the whole-array program into per-device programs (turning `jnp.sum` into `psum` for example) but the compiler is heavily constrained by the user-supplied shardings. -- _Fully manual sharding with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives +- _Fully manual sharding with manual control using {func}`jax.shard_map`_: `shard_map` enables per-device code and explicit communication collectives A summary table: @@ -222,22 +222,21 @@ we can query them at trace time. ## 3. 
Manual parallelism with `shard_map` -In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function. +In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function. `shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below: - As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names. - The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together. -**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it. +**Note:** {func}`jax.shard_map` code can work inside {func}`jax.jit` if you need it. ```{code-cell} :outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2 -from jax.experimental.shard_map import shard_map mesh = jax.make_mesh((8,), ('x',)) -f_elementwise_sharded = shard_map( +f_elementwise_sharded = jax.shard_map( f_elementwise, mesh=mesh, in_specs=P('x'), @@ -259,7 +258,7 @@ def f(x): print(f"device local shape: {x.shape=}") return x * 2 -y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) ``` Because each of your functions only "sees" the device-local part of the data, it means that aggregation-like functions require some extra thought. @@ -272,7 +271,7 @@ For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like: def f(x): return jnp.sum(x, keepdims=True) -shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) ``` Your function `f` operates separately on each shard, and the resulting summation reflects this. @@ -286,7 +285,7 @@ def f(x): sum_in_shard = x.sum() return jax.lax.psum(sum_in_shard, 'x') -shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) +jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) ``` Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`). @@ -362,7 +361,7 @@ Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` from functools import partial @jax.jit -@partial(shard_map, mesh=mesh, +@partial(jax.shard_map, mesh=mesh, in_specs=(P('x'), P('x', None), P(None)), out_specs=P(None)) def layer_sharded(x, weights, bias): diff --git a/jax/_src/core.py b/jax/_src/core.py index b6f12745ea92..12f56e7e527b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2011,7 +2011,7 @@ def _pvary_abstract_eval(*args, axes, axis_index_groups): f"non-device-varying type, but got {arg_vma} for collective acting " f"over axis name {axes}. 
Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") + "workaround pass the check_vma=False argument to `jax.shard_map`") sharding = NamedSharding(mesh_lib.get_abstract_mesh(), P()) return [a.update(sharding=sharding, vma=a.vma.union(frozenset(axes))) for a in args] @@ -2041,7 +2041,7 @@ def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: f'Primitive {prim_name} requires varying manual axes ' f'to match, but got {[vma, *vmas]}. Please open an issue at ' 'https://github.com/jax-ml/jax/issues and as a temporary ' - 'workaround pass the check_rep=False argument to shard_map') + 'workaround pass the check_vma=False argument to `jax.shard_map`') return vma # Dynamic shape stuff below here! We keep the abstract values distinct just so diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index dc140c22650d..b44a4e434027 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -133,7 +133,6 @@ def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any], ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule def _debug_callback_partial_auto(axis_context, *args, **params): - from jax.experimental.shard_map import shard_map partial_auto = list(set(axis_context.mesh.axis_names) - axis_context.manual_axes) def f(): idx = jax.lax.with_sharding_constraint( @@ -142,7 +141,7 @@ def f(): return jax.lax.cond(idx == 0, lambda: debug_callback_p.bind(*args, **params), lambda: []) - return shard_map(f, axis_context.mesh, in_specs=(), out_specs=[])() + return jax.shard_map(f, mesh=axis_context.mesh, in_specs=(), out_specs=[])() def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params): axis_context = ctx.module_context.axis_context diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index b80def4fd2db..339407fa295f 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -27,7 +27,7 @@ from jax._src import source_info_util from jax._src import traceback_util import jax._src.mesh as mesh_lib -from jax.experimental import shard_map +from jax._src import shard_map import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 815183b6ebd4..f3d1265511d5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1876,7 +1876,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. " "See https://github.com/jax-ml/jax/issues/2926. Or " - "use jax.experimental.shard_map instead of pmap under jit compilation.") + "use jax.shard_map instead of pmap under jit compilation.") if nreps > xb.device_count(backend): raise ValueError( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index b6263c427a00..99fa72421ea1 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -448,8 +448,8 @@ def _cond_abstract_eval(*avals: core.AbstractValue, raise Exception("The branches of cond produced mismatched varying manual " f"axes. Got {b0_vma} and {b_vma}. 
Please open an issue " "at https://github.com/jax-ml/jax/issues, and as a " - "temporary workaround pass the check_rep=False argument " - "to shard_map") + "temporary workaround pass the check_vma=False argument " + "to `jax.shard_map`") return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 39498ad624bc..9bd358c2ae9a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -574,8 +574,8 @@ def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, 'Scan carry input and output got mismatched varying manual axes ' f'{in_carry_avals} and {out_carry_avals}. Please open an ' 'issue at https://github.com/jax-ml/jax/issues, and as a ' - 'temporary workaround pass the check_rep=False argument to ' - 'shard_map') + 'temporary workaround pass the check_vma=False argument to ' + '`jax.shard_map`') ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) return out_carry_avals + ys_avals, jaxpr.effects diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 0b3e0c3acc59..2c71565d92e4 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -699,17 +699,16 @@ def axis_size(axis_name: AxisName) -> int: For example, with 8 XLA devices available: >>> from functools import partial - >>> from jax.experimental.shard_map import shard_map >>> from jax.sharding import PartitionSpec as P >>> mesh = jax.make_mesh((8,), 'i') - >>> @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) + >>> @partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) ... def f(_): ... return lax.axis_size('i') ... >>> f(jnp.zeros(16)) Array(8, dtype=int32, weak_type=True) >>> mesh = jax.make_mesh((4, 2), ('i', 'j')) - >>> @partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P()) + >>> @partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P()) ... def f(_): ... return lax.axis_size(('i', 'j')) ... @@ -883,7 +882,7 @@ def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups): f"type, but got {arg_vma} for collective acting " f"over axis name {axes}. Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") + "workaround pass the check_vma=False argument to `jax.shard_map`") named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) @@ -1598,7 +1597,7 @@ def collective_vma_rule(prim_name, axis_name, x_aval): f" type, but got {x_aval.vma} for collective acting " f"over axis name {axis_name}. 
Please open an issue at " "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") + "workaround pass the check_vma=False argument to `jax.shard_map`") return x_aval.vma def _all_gather_effectful_abstract_eval( @@ -1923,12 +1922,11 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)): if axis_env.sizes[axis_pos] == 1: return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32))) - from jax.experimental.shard_map import shard_map def f(): return axis_index_p.bind(axis_name=axis_name) return mlir.lower_fun( - lambda: [shard_map(f, axis_context.mesh, check_rep=False, - in_specs=(), out_specs=P())()])(ctx)[0] + lambda: [jax.shard_map(f, mesh=axis_context.mesh, check_vma=False, + in_specs=(), out_specs=P())()])(ctx)[0] nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index b40b1a6d5571..3052b532cb97 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -38,7 +38,7 @@ from jax._src import xla_bridge as xb from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.experimental import pjit from jax.sharding import PartitionSpec as P @@ -1472,7 +1472,7 @@ def apply_transform(func, transform: str): in_shardings=(sharding.NamedSharding(mesh, P("a")),), out_shardings=sharding.NamedSharding(mesh, P("a"))), shard_map=( - shard_map(func, mesh, in_specs=(P("a", None),), + shard_map(func, mesh=mesh, in_specs=(P("a", None),), out_specs=P("a", None))), pmap=jax.pmap(func, in_axes=0, out_axes=0), )[transform] diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 8651fe4e62d4..cb28ab9f0dd1 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -36,7 +36,7 @@ from jax import lax from jax.experimental import jax2tf from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index dbfde9eb5177..a0eb07f719ec 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -30,7 +30,7 @@ import jax from jax import lax from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -151,5 +151,5 @@ def ag_local(x_shard): return shard_map.shard_map( ag_local, mesh=mesh, in_specs=P(axis_name), out_specs=P(None), - check_rep=False + check_vma=False )(x) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 9db3bd4e289c..2f3ce62a5744 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -22,6 +22,7 @@ from jax._src import ops from jax._src import prng from jax._src import random +from jax._src import shard_map from jax._src.lax import ( ann, convolution, @@ -34,7 +35,6 @@ 
windowed_reductions, ) from jax.experimental import roofline -from jax.experimental import shard_map # One FMA (Fused Multiply Add) takes 2 flops to compute. _FMA_FLOPS_FACTOR = 2 diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 816e1cb81472..c619fc8e915c 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -23,7 +23,7 @@ from jax import lax from jax.experimental import checkify from jax.experimental import pjit -from jax.experimental import shard_map +from jax._src import shard_map from jax.sharding import NamedSharding, PartitionSpec as P from jax._src import array from jax._src import config @@ -554,7 +554,7 @@ def g(x, y): self.assertStartsWith(b_err.get(), "division by zero") @parameterized.parameters(True, False) - def test_shard_map(self, check_rep): + def test_shard_map(self, check_vma): def f(x): # unary func return jax.lax.axis_index("dev") * x / x @@ -571,12 +571,12 @@ def g(x, y): x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) f = shard_map.shard_map( - f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep + f, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=check_vma ) f = jax.jit(f, in_shardings=ps, out_shardings=ps) f = checkify.checkify(f, errors=checkify.float_checks) g = shard_map.shard_map( - g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep + g, mesh=mesh, in_specs=(pspec, pspec), out_specs=pspec, check_vma=check_vma ) g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps) g = checkify.checkify(g, errors=checkify.float_checks) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index d014c86c2506..611b2495949a 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -30,7 +30,7 @@ from jax.experimental import checkify import jax.experimental.custom_dce from jax.experimental import pallas as pl -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp import jax.scipy as jsp diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index c80d23c416df..29fde318756c 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax import numpy as jnp from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import PartitionSpec as P jax.config.parse_flags_with_absl() diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 1fa8cbaa765f..937e4165a159 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -63,7 +63,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp from jax.sharding import Mesh diff --git a/tests/export_test.py b/tests/export_test.py index 41f74fe11a1d..598a6634e1e3 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -29,7 +29,7 @@ from jax import numpy as jnp from jax import export from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P diff --git a/tests/ffi_test.py b/tests/ffi_test.py index 978415194e55..fd41314350f3 100644 --- a/tests/ffi_test.py +++ 
b/tests/ffi_test.py @@ -35,7 +35,7 @@ from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo from jax._src.lax import linalg as lax_linalg_internal -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -334,7 +334,7 @@ def test_shard_map(self): x = self.rng().randn(8, 4, 5).astype(np.float32) @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"), - check_rep=False) + check_vma=False) def f(x): return batch_partitionable_ffi_call(x) diff --git a/tests/memories_test.py b/tests/memories_test.py index 203bc4abb613..9700acc649e3 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -35,7 +35,7 @@ SingleDeviceSharding, GSPMDSharding, TransferToMemoryKind, PartitionSpec as P) from jax.experimental.compute_on import compute_on -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import numpy as np config.parse_flags_with_absl() @@ -665,7 +665,7 @@ def test_ragged_copy_on_host(self): def write(x): return x.at[16 * 1024:].set(0) - x = shard_map(write, mesh, P(('x'),), P(('x')))(x) + x = shard_map(write, mesh=mesh, in_specs=P(('x'),), out_specs=P(('x')))(x) chunk_size = 8 def inner(state): @@ -686,8 +686,8 @@ def foo(x): _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output)) return cpu_x - fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')), - check_rep=False), + fn = jax.jit(shard_map(foo, mesh=mesh, in_specs=P(('x'),), + out_specs=P(('x')), check_vma=False), out_shardings=cpu_sharding) y = fn(x) jax.block_until_ready(y) diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index f7588bc0776c..e464214928e4 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax._src.state import discharge as state_discharge from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -623,7 +623,7 @@ def test_basic_remote_copy(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy('x') @@ -646,7 +646,7 @@ def test_multi_remote_copy(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy( @@ -679,7 +679,7 @@ def test_basic_collective_permute_loop(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy('x') @@ -704,7 +704,7 @@ def test_staggered_collective_permute_loop(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): assert x.shape[0] == 1 @@ -737,7 +737,7 @@ def test_bidi_collective_permute_loop(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): assert x.shape[0] == 1 diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index bb46d1d18772..966ed13fdad8 100644 --- 
a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -83,7 +83,7 @@ def body(x): mesh = jax.sharding.Mesh(devices, ['x']) y = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False ) )(x) expected = jnp.concatenate([x[8:], x[:8]]) @@ -136,7 +136,7 @@ def body(x): mesh = jax.sharding.Mesh(device_mesh, ['x']) y = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False ) )(x) if direction == 'right': @@ -192,10 +192,10 @@ def body(x): y = jax.jit( shard_map.shard_map( body, - mesh, + mesh=mesh, in_specs=P('x', None), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(x) if direction == 'right': @@ -243,7 +243,7 @@ def body(x): mesh = jax.sharding.Mesh(device_mesh, ['x']) y = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False ) )(x) expected = jnp.concatenate([x[-8:], x[:-8]]) @@ -316,7 +316,7 @@ def test_kernel(x_ref, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result = compiled_func(sharded_arr) perm = tuple((src, permute_fn(src)) for src in range(num_devices)) @@ -402,7 +402,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_interpret = compiled_func(sharded_arr) kernel = pl.pallas_call( @@ -415,7 +415,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_noninterpret = compiled_func(sharded_arr) np.testing.assert_allclose(result_interpret, result_noninterpret, @@ -497,7 +497,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_interpret = compiled_func(sharded_arr) kernel = pl.pallas_call( @@ -510,7 +510,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_noninterpret = compiled_func(sharded_arr) np.testing.assert_allclose(result_interpret, result_noninterpret, @@ -568,7 +568,7 @@ def _(i, _): previous_config = jax.config.read('jax_pallas_dump_promela_to') jax.config.update('jax_pallas_dump_promela_to', tmpdir) shard_map.shard_map( - kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_rep=False + kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_vma=False )(jnp.ones((8, 128, 128), jnp.float32)) jax.config.update('jax_pallas_dump_promela_to', previous_config) self.assertNotEmpty(os.listdir(tmpdir)) diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 1ed139e9e867..0fd94f1a8049 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -26,7 +26,7 @@ from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl -from jax.experimental import shard_map 
+from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -114,7 +114,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) @@ -237,7 +237,7 @@ def _(): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False + check_vma=False ) )(input_arr) @@ -396,7 +396,7 @@ def _(): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result)[0] @@ -680,7 +680,7 @@ def pallas_reduce_scatter(input_arr): mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result) @@ -984,7 +984,7 @@ def pallas_reduce_scatter(input_arr): mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result) @@ -1069,7 +1069,7 @@ def run(src_dst_ids): mesh=mesh, in_specs=(P(None), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, )(src_dst_ids, input_arr) run(jnp.array([[0, 1], [1, 2], [2, 3]], jnp.int32)).block_until_ready() diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index f29182e56314..1f00b22b6708 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -498,7 +498,7 @@ def _wait_on_prev_dma(): ), in_specs=(P(None, 'x'), P(None, None)), out_specs=P(None, None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(kernel)) @@ -741,7 +741,7 @@ def _wait_on_prev_dma(): ), in_specs=(P(None, 'x'), P(None, None)), out_specs=P(None, None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(kernel)) @@ -1025,7 +1025,7 @@ def _loop_epilogue(): ), in_specs=(P(None, 'x'), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(lambda x, y: kernel(x, y)[0, 0])) @@ -1286,7 +1286,7 @@ def _prefetch_accumulator(): ), in_specs=(P(None, 'x'), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(lambda x, y: kernel(x, y)[1])) diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index 74697e6c0b7f..78a81d168136 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -19,7 +19,7 @@ from jax._src import test_util as jtu from jax._src.pallas.mosaic import random as plrandom from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu.random import philox # pylint: disable=unused-import # noqa: F401 from jax.experimental.pallas.ops.tpu.random import threefry # pylint: disable=unused-import # noqa: F401 @@ -306,7 +306,7 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) jax_gen = generate(key_jax) pl_gen = 
generate(key_pallas) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 53fdac98504c..3746214eac18 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -39,7 +39,7 @@ from jax.experimental import mesh_utils from jax.experimental import mosaic from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu import example_kernel from jax.extend import linear_util as lu @@ -2584,7 +2584,7 @@ def kernel(x_ref, o_ref, send_sem, recv_sem): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False + check_vma=False ) )(input_arr) diff --git a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py index db14b44938e9..9edd425f24dd 100644 --- a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp from jax.sharding import PartitionSpec import numpy as np @@ -131,7 +131,7 @@ def test_dynamic_mask_manual_partitioning_mha( kv_spec, ), out_specs=q_spec, - check_rep=False, + check_vma=False, ) def f(kernel, q, k, v): return kernel(q, k, v) @@ -199,7 +199,7 @@ def test_dynamic_mask_manual_partitioning_mha_bwd( kv_spec, ), out_specs=q_spec, - check_rep=False, + check_vma=False, ) def f(kernel, q, k, v): return kernel(q, k, v) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d17851b3236d..468877d79be4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -42,7 +42,7 @@ from jax._src import prng from jax.sharding import PartitionSpec as P, Mesh from jax.experimental import multihost_utils -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax._src.compilation_cache import is_persistent_cache_enabled from jax.experimental.custom_partitioning import ( custom_partitioning, SdyShardingRule, BATCHING) @@ -4512,7 +4512,7 @@ def _f(x, y): return x + y[..., jnp.newaxis] f = jax.jit(shard_map( - _f, mesh, in_specs=(P(None, 'i'), P(None)), + _f, mesh=mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i'))) f(jnp.zeros((2, 16)), jnp.ones(2)) @@ -4530,7 +4530,7 @@ def _f(x, y): return x + y[..., jnp.newaxis] f = jax.jit(shard_map( - _f, mesh, in_specs=(P(None, 'i'), P(None)), + _f, mesh=mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i'))) f(jnp.zeros((2, 16)), jnp.ones(2)) @@ -6971,7 +6971,7 @@ def f(): return const * 2 shmap_f = shard_map(f, mesh=mesh, in_specs=(), out_specs=P('x'), - auto=frozenset({'y'})) + axis_names={'x'}) f = jax.jit(shmap_f) out = f() self.assertArraysEqual(out, jnp.concatenate([const * 2, const * 2])) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 0d78730b8ca8..13df0c1dd376 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -30,7 +30,7 @@ from jax._src import util from jax.experimental import io_callback from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp 
from jax.sharding import Mesh import numpy as np diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 1dd6ef657561..1734f67ff063 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -27,7 +27,7 @@ from jax._src import test_util as jtu import jax.numpy as jnp -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map config.parse_flags_with_absl() @@ -90,7 +90,7 @@ def test_ragged_all_to_all(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -176,7 +176,7 @@ def test_ragged_all_to_all_grad(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -257,7 +257,7 @@ def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -346,7 +346,7 @@ def test_ragged_all_to_all_degenerate_groups(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -491,7 +491,7 @@ def get_data_sharding(axis): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -555,7 +555,7 @@ def test_ragged_all_to_all_vmap_unsupported_axis_index_groups(self): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 2ad3e089e662..aeb218b478ad 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -19,7 +19,7 @@ from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec as P from jax.experimental.shard_alike import shard_alike -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -146,7 +146,7 @@ def g(x): @jax.jit def f(x): y = x @ x.T - s_out = shard_map(g, mesh, in_specs=P('x', 'y'), + s_out = shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P(None, 'y'))(y) z = s_out.T @ s_out return shard_alike(y, z) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 09cc79d17b5f..ecad33b99d8b 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -36,6 +36,7 @@ from jax._src import config from jax._src import core from jax._src import prng +from jax._src.shard_map import shard_map from jax._src import test_util as jtu from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists @@ -48,7 +49,6 @@ import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning -from jax.experimental.shard_map import shard_map config.parse_flags_with_absl() @@ -83,7 +83,7 @@ def identity(x): def fwd(a): c = shard_map( identity, - mesh, + mesh=mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) return c @@ -368,7 +368,7 @@ def f(x): def test_jvp_basic(self): mesh = 
jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) args = np.arange(4 * 4.).reshape(4, 4), jtu.check_grads(g, args, 2, ['fwd']) @@ -376,7 +376,7 @@ def test_jvp_basic(self): def test_linearize_basic(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) x = np.arange(4 * 4.).reshape(4, 4) @@ -390,7 +390,7 @@ def test_linearize_basic(self): def test_linearize_basic_repres(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh, + g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh=mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -404,7 +404,7 @@ def test_linearize_basic_repres(self): def test_linearize_basic_repres_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -423,7 +423,7 @@ def test_replication_checker_eager(self): def f(x): return 2 * x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) with self.assertRaisesRegex(ValueError, 'statically inferred'): g(x) @@ -431,7 +431,7 @@ def g(x): def f2(x): return jax.lax.psum(x, 'x') def g2(x): - return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f2, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) _ = g2(x) # doesn't crash def test_replication_checker_jit(self): @@ -441,7 +441,7 @@ def test_replication_checker_jit(self): def f(x): return 2 * x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) with self.assertRaisesRegex(ValueError, 'statically inferred'): jax.jit(g)(x) @@ -449,7 +449,7 @@ def g(x): def f2(x): return jax.lax.psum(x, 'x') def g2(x): - return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f2, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) _ = jax.jit(g2)(x) # doesn't crash def test_process_env_traces(self): @@ -458,7 +458,7 @@ def test_process_env_traces(self): def g(x): y = (3. * x).sum() - z = shard_map(lambda x: 2 * x * y, mesh, + z = shard_map(lambda x: 2 * x * y, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.)) return z @@ -476,13 +476,14 @@ def f(x): return -x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x) + return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x) y = g(x) self.assertAllClose(y, -x, check_dtypes=False) def test_outer_jit_detects_shard_map_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x')) + f = shard_map(lambda x: x.reshape(1, *x.shape), mesh=mesh, in_specs=P(), + out_specs=P('x')) _ = jax.jit(f)(jnp.array(2.0)) # doesn't crash def test_vmap_basic(self): @@ -490,7 +491,7 @@ def test_vmap_basic(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. 
* x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g)(x) self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -500,7 +501,7 @@ def test_vmap_basic_axis_name(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g, axis_name='i')(x) self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -510,7 +511,7 @@ def test_vmap_basic_axis_name_reuse_mesh_name(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g, axis_name='x')(x) # NOTE reuse same 'x' as on mesh self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -663,18 +664,18 @@ def test_check_rep_false_doesnt_hit_rep_rules(self): prim.def_impl(lambda: []) prim.def_abstract_eval(lambda: []) - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=True) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=True) def f(): prim.bind() - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=False) def f2(): prim.bind() f2() jax.jit(f2)() - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=False) def f3(): jax.jit(prim.bind)() @@ -713,7 +714,7 @@ def test_vmap_of_grad_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial( - shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False + shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_vma=False ) def f(x): return jnp.sin(jnp.sum(x)) @@ -939,7 +940,7 @@ def f(key): dtype=jnp.int32) pspec = P('x') if config.enable_custom_prng.value else P('x', None) - g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec) + g = shard_map(f, mesh=mesh, in_specs=(pspec,), out_specs=pspec) _ = g(sharded_rng) # don't crash! def test_functools_partial_rank_error(self): @@ -949,7 +950,7 @@ def test_functools_partial_rank_error(self): def f(x): return x - g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x',)) + g = shard_map(f, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x',)) x = jnp.arange(4) with self.assertRaises(ValueError): g(x) @@ -960,13 +961,13 @@ def test_in_specs_none_error(self): def f(x): return x with self.assertRaisesRegex(TypeError, "but it was None"): - shard_map(f, mesh, in_specs=None, out_specs=P())(3.) + shard_map(f, mesh=mesh, in_specs=None, out_specs=P())(3.) # TODO(mattjj): enable this test once we fix the tree_map(f, None, 3.0) bug # with self.assertRaises(TypeError): - # shard_map(f, mesh, in_specs=(None,), out_specs=P())(3.) + # shard_map(f, mesh=mesh, in_specs=(None,), out_specs=P())(3.) - shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash + shard_map(f, mesh=mesh, in_specs=P(), out_specs=P())(3.) 
# doesn't crash def test_scan_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) @@ -983,18 +984,18 @@ def body(c, _): x = jnp.arange(4) # doesn't crash, because out_spec assumes no replication (and there is none) - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(('x', 'y')))(x, x, x) # does crash, because output incorrectly promises replication with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('x'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('y'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(None))(x, x, x) def g(x, y, z): @@ -1005,12 +1006,12 @@ def body(c, _): return [jnp.expand_dims(a, 0) for a in out] # doesn't crash, because everything matches - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) # does crash, because the second guy is wrong with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) def test_while_rep_rule(self): @@ -1032,18 +1033,18 @@ def body(c): x = jnp.arange(4) # doesn't crash, because out_spec assumes no replication (and there is none) - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(('x', 'y')))(x, x, x) # does crash, because output incorrectly promises replication with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('x'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('y'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(None))(x, x, x) def g(x, y, z): @@ -1058,12 +1059,12 @@ def body(c): return [jnp.expand_dims(a, 0) for a in out] # doesn't crash, because everything matches - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) # does crash, because the second guy is wrong with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) def test_cond_rep_rule(self): @@ -1077,10 +1078,10 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(True, 
true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) def f(x, y): def true_fn(x, y): @@ -1089,10 +1090,10 @@ def false_fun(x, y): return lax.pvary(y, 'x') return jax.lax.cond(True, true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) def f(x, y): def true_fn(x, y): @@ -1101,10 +1102,10 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) def f(x, y): def true_fn(x, y): @@ -1113,8 +1114,8 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) # https://github.com/jax-ml/jax/issues/24418 def f(a): @@ -1123,7 +1124,7 @@ def f(a): mesh = jtu.create_mesh((2,), ('x',)) a = jnp.array([True, False]) - shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(a) def test_switch_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) @@ -1133,7 +1134,7 @@ def f(n, x, y): return jax.lax.switch( n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y) - shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) + shard_map(f, mesh=mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) def test_eager_custom_jvp_basic(self): @jax.custom_jvp @@ -1146,7 +1147,7 @@ def foo_jvp(primals, tangents): return foo(x), 3. * x_dot mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(foo, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False) @@ -1165,7 +1166,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(foo, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) self.assertAllClose(x_bar, 3. 
* jnp.ones(4), check_dtypes=False) @@ -1179,7 +1180,7 @@ def foo(): foo = jax.jit(foo) mesh = jtu.create_mesh((4,), ('x',)) - ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))() + ans = shard_map(foo, mesh=mesh, in_specs=(), out_specs=P('x'))() expected = jnp.arange(4.) self.assertAllClose(ans, expected, check_dtypes=False) @@ -1195,7 +1196,7 @@ def foo(): foo = jax.jit(foo) mesh = jtu.create_mesh((4, 2), ('i', 'j')) - ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(), + ans1, ans2, ans3 = shard_map(foo, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2)) expected2 = jnp.arange(2.)[None, :] + jnp.zeros((4, 2)) @@ -1262,7 +1263,7 @@ def test_key_array_with_replicated_last_tile_dim(self): def f(rng): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False) + check_vma=False) def g(rng): return jnp.array([jax.random.normal(rng[0])]) return g(jax.random.split(rng, 4)) @@ -1301,7 +1302,8 @@ def test_returned_out_sharding(self): mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(jnp.zeros((2, 2)), s) - out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp) + out = shard_map(lambda x: x, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P('x', 'y'))(inp) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, inp) @@ -1369,9 +1371,9 @@ def test_sharding_metadata_in_hlo_attrs(self): def foo(x): x = jnp.sin(x) - x = shard_map(lambda x: jnp.cos(x * y), mesh, + x = shard_map(lambda x: jnp.cos(x * y), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) - x = shard_map(lambda x: jnp.cos(x * y), mesh, + x = shard_map(lambda x: jnp.cos(x * y), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) return x @@ -1402,7 +1404,7 @@ def f(x): x)[0] * x mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) x = jnp.arange(4.) y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call self.assertAllClose(y, 2 * x * x, check_dtypes=True) @@ -1436,7 +1438,7 @@ def foo_jvp(primals, tangents): return foo(x), 2. 
* x_dot mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1464,7 +1466,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1492,7 +1494,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1540,7 +1542,7 @@ def foo_scan(x): return y mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo_scan(x) * x, mesh, + g = shard_map(lambda x: foo_scan(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1645,7 +1647,7 @@ def g_bwd(vjp_fn, result): def f_shmapped(x, y): return jax.lax.psum(f(x, y).sum(), axis_name=('x')) - @partial(shard_map, mesh=mesh, check_rep=False, + @partial(shard_map, mesh=mesh, check_vma=False, in_specs=P('x'), out_specs=(P('x'), P())) def f_shmapped2(x, y): return g(x, y) @@ -1719,7 +1721,7 @@ def f(q, k, v): def body(q, k, v): return q * k[None, :] + v[None, :] - out = shard_map(body, mesh, check_rep=False, + out = shard_map(body, mesh=mesh, check_vma=False, in_specs=(q_spec, kv_spec, kv_spec,), out_specs=q_spec)(q, k, v) return out.sum() @@ -1744,7 +1746,7 @@ def foo(x): @partial(jax.remat, policy=lambda *args, **kwargs: True) def bar(x): return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),), - out_specs=P('x'), check_rep=False)(x) + out_specs=P('x'), check_vma=False)(x) jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash @@ -1855,7 +1857,7 @@ def f(*args): return args[0] @ args[1] shard_f = shard_map( - f, mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y')) + f, mesh=mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y')) with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"): shard_f(jnp.ones((8, 8)), jnp.ones((8, 8))) @@ -1898,7 +1900,8 @@ def test_approx_top_k(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) x = jnp.array([3.0, 1.0, 4.0, 2.0]) - _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh, P('i'), P('i'))(x) + _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh=mesh, in_specs=P('i'), + out_specs=P('i'))(x) def test_disable_jit(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) @@ -1944,10 +1947,10 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - auto=frozenset({'j'}))(x) + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -1981,10 +1984,10 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - auto=frozenset({'j'}))(x) + axis_names=frozenset({'i'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x @@ -2072,11 +2075,11 @@ def g(x): def f(x): return shard_map( g, - mesh, + mesh=mesh, in_specs=P(), out_specs=P(), - check_rep=False, - auto=frozenset({'i'}), + check_vma=False, + axis_names=frozenset({'j', 'k'}), )(x) v = jnp.arange(32.0).reshape(4, 8) @@ -2108,7 +2111,7 @@ def update_fn(params, batch): def 
grad_fn(batch): return jax.value_and_grad(loss_fn)(params, batch) return shard_map(grad_fn, mesh=mesh, in_specs=P("data"), out_specs=P(), - check_rep=False)(batch) + check_vma=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) @@ -2126,7 +2129,7 @@ def update_fn(params, batch): def grad_fn(batch): return jax.value_and_grad(loss_fn)(params, batch) return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"), - out_specs=P("data"), check_rep=False)(batch) + out_specs=P("data"), check_vma=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) @@ -2160,11 +2163,11 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2183,11 +2186,11 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'k'}))(x) + check_vma=False, + axis_names=frozenset({'i', 'j'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2206,11 +2209,11 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2226,12 +2229,12 @@ def g(x): return x * x def h(x): - return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) + return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, auto=frozenset({'j'}))(x) + return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2246,18 +2249,18 @@ def g(x): def h(x): # auto: 'j', manual: 'i' - return shard_map(g, mesh, + return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): # auto: 'i', 'j' - return shard_map(h, mesh, + return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + check_vma=False, + axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2270,17 +2273,17 @@ def g(x): return x * x * x def h(x): - return shard_map(g, mesh, + return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, + return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + check_vma=False, + axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2294,11 +2297,11 @@ def h(x): @jax.jit def f(x): - return shard_map(h, mesh, + return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j', 'k'}))(x) + 
check_vma=False, + axis_names=frozenset({'i'}))(x) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2316,8 +2319,8 @@ def _make_zeros(): def f(): return shard_map( - h, mesh, in_specs=(), - out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))() + h, mesh=mesh, in_specs=(), + out_specs=P('i'), check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) @@ -2339,8 +2342,8 @@ def _make_zeros(): def f(): return shard_map( - h, mesh, in_specs=(), - out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))() + h, mesh=mesh, in_specs=(), + out_specs=P('i'), check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) @@ -2351,8 +2354,8 @@ def test_partial_auto_axis_index(self): @partial(jax.jit, out_shardings=out_sharding) def f(): return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1), - mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, auto=frozenset({'j'}))() + mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1)) @@ -2363,8 +2366,8 @@ def test_partial_auto_axis_index_degenerated_axis(self): @partial(jax.jit, out_shardings=out_sharding) def f(): return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1), - mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, auto=frozenset({'j'}))() + mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1)) @@ -2379,8 +2382,8 @@ def g(x): @jax.jit def f(x): return shard_map(g, - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(x) y = f(x) # don't crash self.assertAllClose(y, jnp.array([6., 7., 0., 1., 2., 3., 4., 5.]), @@ -2399,8 +2402,8 @@ def f(x): # @jax.jit # def f(x): # return shard_map(g, - # mesh, in_specs=P('i', None), out_specs=P(None, 'i'), - # check_rep=False, auto=frozenset({'j'}))(x) + # mesh=mesh, in_specs=P('i', None), out_specs=P(None, 'i'), + # check_vma=False, axis_names=frozenset({'i'}))(x) # # f(x) # don't crash @@ -2417,8 +2420,8 @@ def g(x): @jax.jit def f(x): return shard_map(g, - mesh, in_specs=P('i'), out_specs=None, - check_rep=False, auto=frozenset({'j'}))(x) + mesh=mesh, in_specs=P('i'), out_specs=None, + check_vma=False, axis_names=frozenset({'i'}))(x) y = f(x) # don't crash @@ -2429,8 +2432,8 @@ def test_partial_auto_of_random_keys(self): @jax.jit def f(x): return shard_map(lambda k: k, - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(keys) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(keys) y = f(keys) # doesn't crash self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys), @@ -2443,8 +2446,8 @@ def test_partial_auto_of_random_keys_slice(self): @jax.jit def f(x): return shard_map(lambda k: k[0], - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(x) f(keys) # doesn't crash @@ -2549,7 +2552,7 @@ def f(x): jax.vmap(f, spmd_axis_name='i')(xs) @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P(('i', 'j')), - check_rep=False) + 
check_vma=False) def g(x): return jnp.sin(x) @@ -2567,11 +2570,11 @@ def f(o, x): return jnp.sin(x) obj = object() - y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x) + y = shard_map(f, mesh=mesh, in_specs=(None, P('i')), out_specs=P('i'))(obj, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) obj = None - y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x) + y = shard_map(f, mesh=mesh, in_specs=(None, P('i')), out_specs=P('i'))(None, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f2(o, x): @@ -2580,7 +2583,7 @@ def f2(o, x): return jnp.sin(x) obj = {'a': object()} - y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x) + y = shard_map(f2, mesh=mesh, in_specs=({'a': None}, P('i')), out_specs=P('i'))(obj, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f3(x, o): @@ -2588,11 +2591,11 @@ def f3(x, o): return jnp.sin(x) obj = object() - y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + y = shard_map(f3, mesh=mesh, in_specs=(P('i'), None), out_specs=P('i'))(x, obj) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) obj = None - y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + y = shard_map(f3, mesh=mesh, in_specs=(P('i'), None), out_specs=P('i'))(x, obj) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f4(o1, o2, x, o3): @@ -2605,7 +2608,8 @@ def f4(o1, o2, x, o3): obj1 = object() obj2 = (object(), object()) obj3 = object() - y = shard_map(f4, mesh, (None, None, P('i'), None), P('i'))(obj1, obj2, x, obj3) + y = shard_map(f4, mesh=mesh, in_specs=(None, None, P('i'), None), + out_specs=P('i'))(obj1, obj2, x, obj3) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def test_in_spec_none_divisibility_errors(self): @@ -2613,44 +2617,48 @@ def test_in_spec_none_divisibility_errors(self): x = jnp.arange(4).reshape(2, 2) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (None, P('i')), None)(object(), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=(None, P('i')), + out_specs=None)(object(), x) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object()) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), None), + out_specs=None)(x, object()) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), None), None - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), None), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None, - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), (None, None)), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, ((None, None), P('i')), None, - )((object(), object()), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=((None, None), P('i')), + out_specs=None)((object(), object()), x) def test_in_spec_none_rank_errors(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(4) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=(None, P('i', 'j')), + out_specs=None)(object(), x) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object()) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 
'j'), None), + out_specs=None)(x, object()) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), None), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None, - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), (None, None)), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None, - )((object(), object()), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=((None, None), P('i', 'j')), + out_specs=None)((object(), object()), x) def test_custom_linear_solve_rep_rules(self): # https://github.com/jax-ml/jax/issues/20162 @@ -2671,7 +2679,7 @@ def test_temporary_error_suppression_flag(self): def f(x, y): z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True), mesh=mesh, in_specs=(P(None), P('i')), out_specs=P(None), - check_rep=False, + check_vma=False, )(x, y) return z @@ -2807,7 +2815,7 @@ def test_rep_none_canonicalization_again(self): mesh = jtu.create_mesh((2,), ('i',)) def f(x): return jnp.insert(x, 0, 0)[None] - f = shard_map(f, mesh, P('i'), P('i')) + f = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) f(jnp.zeros(100)) # don't crash def test_custom_jvp_symbolic_zeros(self): @@ -2831,7 +2839,7 @@ def f_jvp(primals, tangents): x = jax.random.normal(jax.random.key(0), (jax.device_count(), 20)) A = jax.random.normal(jax.random.key(1), (jax.device_count(), 20)) - g = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i')) + g = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) jax.jvp(lambda x: g(x, A), (x,), (x,)) # don't crash def test_cond_pvary_errors(self): @@ -2846,7 +2854,7 @@ def false_fun(x, y): with self.assertRaisesRegex( TypeError, r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) def test_cond_pvary_errors_pytree(self): mesh = jtu.create_mesh((1, 1), ('x', 'y')) @@ -2861,7 +2869,7 @@ def false_fun(x, y): with self.assertRaisesRegex( TypeError, r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) def test_scan_pvary_errors(self): mesh = jtu.create_mesh((1, 1), ('i', 'j')) @@ -3272,7 +3280,8 @@ def make_mesh(mesh_shape): def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - out = shard_map(fun, mesh, in_specs, out_specs)(*args) + out = shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs)(*args) expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) @@ -3281,7 +3290,8 @@ def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args) + out = jax.jit(shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs))(*args) 
expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) @@ -3294,7 +3304,8 @@ def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep) + f = shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=check_rep) if jit: f = jax.jit(f) jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2) @@ -3325,7 +3336,7 @@ def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - f = shard_map(fun, mesh, in_specs, out_specs) + f = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) if jit: f = jax.jit(f) ans = jax.vmap(f, bdims)(*args) @@ -3425,8 +3436,8 @@ def f(x): def fwd(a): c = shard_map( f, - mesh, - check_rep=False, + mesh=mesh, + check_vma=False, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) return c @@ -3443,8 +3454,8 @@ def g(x): @jax.jit def f(x): x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(('i', 'j')))) - re = shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + re = shard_map(g, mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names={'i'})(x) re = jax.lax.with_sharding_constraint(re, NamedSharding(mesh, P(('i', 'j')))) return re From 72b775baa136562f394f34abe41dd23662eca559 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Apr 2025 01:48:06 -0700 Subject: [PATCH 0775/1769] [pallas] Added a note on the recent `compiler_params=` change to the changelog PiperOrigin-RevId: 750900798 --- docs/pallas/CHANGELOG.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index c960280a8891..db63657626b3 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -2,7 +2,7 @@ # Pallas Changelog - + This is the list of changes specific to {class}`jax.experimental.pallas`. For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changelog.html). @@ -11,7 +11,7 @@ For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changel Remember to align the itemized text with the first line of an item within a list. --> -## Released with jax 0.6.1 +## Unreleased * Changes @@ -19,6 +19,8 @@ Remember to align the itemized text with the first line of an item within a list addition to ints/None in the `block_shape`. `indexing_mode` has been removed. To achieve "Unblocked", pass a `pl.Element(size)` into `block_shape` for each entry that needs unblocked indexing. + * {func}`jax.experimental.pallas.pallas_call` now requires `compiler_params` + to be a backend-specific dataclass instead of a param to value mapping. 
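+
+    A minimal sketch of the change, assuming the TPU backend and its
+    `TPUCompilerParams` dataclass (other backends take their own dataclasses):
+
+    ```python
+    import jax
+    import jax.numpy as jnp
+    from jax.experimental import pallas as pl
+    from jax.experimental.pallas import tpu as pltpu
+
+    def add_one_kernel(x_ref, o_ref):
+      o_ref[...] = x_ref[...] + 1.0
+
+    x = jnp.zeros((8, 128), jnp.float32)
+    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
+
+    # Old style (no longer accepted): a backend-name -> options mapping.
+    # y = pl.pallas_call(add_one_kernel, out_shape=out_shape,
+    #                    compiler_params=dict(mosaic=dict(vmem_limit_bytes=2**20)))
+
+    # New style: a backend-specific dataclass.
+    y = pl.pallas_call(add_one_kernel, out_shape=out_shape,
+                       compiler_params=pltpu.TPUCompilerParams(
+                           vmem_limit_bytes=2**20))(x)
+    ```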
## Released with jax 0.5.0 From b2de662a235c60734c7b4e9f06725800ffa5a15d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Apr 2025 09:34:49 +0100 Subject: [PATCH 0776/1769] Ran pyupgrade --py310-plus on .pyi files in jaxlib --- jaxlib/_jax/__init__.pyi | 284 ++++++++++++++--------------- jaxlib/_jax/guard_lib.pyi | 12 +- jaxlib/_jax/ifrt_programs.pyi | 7 +- jaxlib/_jax/ifrt_proxy.pyi | 9 +- jaxlib/_jax/jax_jit.pyi | 19 +- jaxlib/_jax/mlir.pyi | 9 +- jaxlib/_jax/ops.pyi | 63 +++---- jaxlib/_jax/pmap_lib.pyi | 7 +- jaxlib/_jax/profiler.pyi | 16 +- jaxlib/_jax/pytree.pyi | 41 ++--- jaxlib/_jax/transfer_guard_lib.pyi | 10 +- jaxlib/xla_client.pyi | 16 +- 12 files changed, 238 insertions(+), 255 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index d80185c0fecd..19ffbb805404 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -15,24 +15,12 @@ from __future__ import annotations +import builtins import enum import inspect import types -import typing -from typing import ( - Any, - Callable, - ClassVar, - Dict, - Iterator, - List, - Optional, - Sequence, - Tuple, - TypeVar, - Union, - overload, -) +from typing import Any, ClassVar, TypeVar, overload +from collections.abc import Callable, Iterator, Sequence import numpy as np @@ -106,13 +94,13 @@ class ArrayCopySemantics(enum.IntEnum): class Layout: @overload - def __init__(self, minor_to_major: Tuple[int, ...]): ... + def __init__(self, minor_to_major: tuple[int, ...]): ... @overload - def __init__(self, minor_to_major: Tuple[int, ...], - tiling: Tuple[Tuple[int, ...], ...], + def __init__(self, minor_to_major: tuple[int, ...], + tiling: tuple[tuple[int, ...], ...], element_size_in_bits: int): ... - def minor_to_major(self) -> Tuple[int, ...]: ... - def tiling(self) -> Sequence[Tuple[int, ...]]: ... + def minor_to_major(self) -> tuple[int, ...]: ... + def tiling(self) -> Sequence[tuple[int, ...]]: ... def element_size_in_bits(self) -> int: ... def to_string(self) -> str: ... def __eq__(self, other: Any) -> bool: ... @@ -125,16 +113,16 @@ class Shape: def tuple_shape(shapes: Sequence[Shape]) -> Shape: ... @staticmethod def array_shape( - type: Union[np.dtype, PrimitiveType], + type: np.dtype | PrimitiveType, dims_seq: Any = ..., layout_seq: Any = ..., - dynamic_dimensions: Optional[List[bool]] = ..., + dynamic_dimensions: list[bool] | None = ..., ) -> Shape: ... @staticmethod def token_shape() -> Shape: ... @staticmethod - def scalar_shape(type: Union[np.dtype, PrimitiveType]) -> Shape: ... - def dimensions(self) -> Tuple[int, ...]: ... + def scalar_shape(type: np.dtype | PrimitiveType) -> Shape: ... + def dimensions(self) -> tuple[int, ...]: ... def layout(self) -> Layout: ... def xla_element_type(self) -> PrimitiveType: ... def element_type(self) -> np.dtype: ... @@ -148,7 +136,7 @@ class Shape: def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ... def rank(self) -> int: ... def to_serialized_proto(self) -> bytes: ... - def tuple_shapes(self) -> List[Shape]: ... + def tuple_shapes(self) -> list[Shape]: ... def leaf_count(self) -> int: ... def with_major_to_minor_layout_if_absent(self) -> Shape: ... def __eq__(self, other: Any) -> bool: ... @@ -158,12 +146,12 @@ class Shape: class ProgramShape: def __init__(self, params: Sequence[Shape], result: Shape) -> None: ... - def parameter_shapes(self) -> List[Shape]: ... + def parameter_shapes(self) -> list[Shape]: ... def result_shape(self) -> Shape: ... def __repr__(self) -> str: ... 
class ShapeIndex: - def __init__(self, indices: List[int]) -> None: ... + def __init__(self, indices: list[int]) -> None: ... def __eq__(self, other: Any) -> bool: ... def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... @@ -173,7 +161,7 @@ class Literal: def __init__(self, shape: Shape) -> None: ... def __repr__(self) -> str: ... def __array__( - self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None + self, dtype: np.dtype | None = None, copy: bool | None = None ) -> np.ndarray: ... def shape(self) -> Shape: ... @@ -217,8 +205,8 @@ class HloComputation: def render_html(self) -> None: ... class HloModule: - spmd_output_sharding: Optional[OpSharding] - spmd_parameters_shardings: Optional[List[OpSharding]] + spmd_output_sharding: OpSharding | None + spmd_parameters_shardings: list[OpSharding] | None @property def name(self) -> str: ... def to_string(self, options: HloPrintOptions = ...) -> str: ... @@ -227,31 +215,31 @@ class HloModule: def from_serialized_hlo_module_proto( serialized_hlo_module_proto: bytes, ) -> HloModule: ... - def computations(self) -> List[HloComputation]: ... + def computations(self) -> list[HloComputation]: ... class HloModuleGroup: - def __init__(self, name: str, modules: List[HloModule]) -> None: ... + def __init__(self, name: str, modules: list[HloModule]) -> None: ... @property def name(self) -> str: ... def to_string(self) -> str: ... - def to_modules(self) -> List[HloModule]: ... + def to_modules(self) -> list[HloModule]: ... def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... def hlo_module_cost_analysis( client: Client, module: HloModule -) -> Dict[str, float]: ... +) -> dict[str, float]: ... class XlaOp: ... class XlaBuilder: def __init__(self, name: str) -> None: ... - def Build(self, root: Optional[XlaOp] = ...) -> XlaComputation: ... + def Build(self, root: XlaOp | None = ...) -> XlaComputation: ... def GetShape(self, __op: XlaOp) -> Shape: ... build = Build def clear_op_metadata(self) -> None: ... get_shape = GetShape - def get_program_shape(self, root: Optional[XlaOp] = ...) -> ProgramShape: ... + def get_program_shape(self, root: XlaOp | None = ...) -> ProgramShape: ... def is_constant(self, __op: XlaOp) -> bool: ... def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ... def set_sharding(self, sharding: OpSharding_Type) -> None: ... @@ -276,16 +264,16 @@ class CompileOptions: def ParseFromString(s: bytes) -> CompileOptions: ... def __init__(self) -> None: ... def SerializeAsString(self) -> bytes: ... - argument_layouts: Optional[List[Shape]] + argument_layouts: list[Shape] | None parameter_is_tupled_arguments: bool executable_build_options: ExecutableBuildOptions tuple_arguments: bool num_replicas: int num_partitions: int profile_version: int - device_assignment: Optional[DeviceAssignment] + device_assignment: DeviceAssignment | None compile_portable_executable: bool - env_option_overrides: List[Tuple[str, str]] + env_option_overrides: list[tuple[str, str]] def register_custom_call_target( fn_name: str, capsule: Any, platform: str, api_version: int = ..., @@ -296,12 +284,12 @@ def register_custom_call_partitioner( partition: Callable, infer_sharding_from_operands: Callable, can_side_effecting_have_replicated_sharding: bool = ..., - c_api: Optional[Any] = ..., + c_api: Any | None = ..., ) -> None: ... def encode_inspect_sharding_callback(handler: Any) -> bytes: ... 
def register_custom_call_as_batch_partitionable( target_name: str, - c_api: Optional[Any] = ..., + c_api: Any | None = ..., ) -> None: ... def register_custom_type_id(type_name: str, type_id: Any) -> None: ... @@ -370,16 +358,16 @@ class CompiledMemoryStats: class ExecutableBuildOptions: def __init__(self) -> None: ... def __repr__(self) -> str: ... - result_layout: Optional[Shape] - fdo_profile: Optional[bytes] + result_layout: Shape | None + fdo_profile: bytes | None num_replicas: int num_partitions: int debug_options: DebugOptions - device_assignment: Optional[DeviceAssignment] + device_assignment: DeviceAssignment | None use_spmd_partitioning: bool use_auto_spmd_partitioning: bool - auto_spmd_partitioning_mesh_shape: List[int] - auto_spmd_partitioning_mesh_ids: List[int] + auto_spmd_partitioning_mesh_shape: list[int] + auto_spmd_partitioning_mesh_ids: list[int] use_shardy_partitioner: bool def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ... @@ -413,7 +401,7 @@ class OpSharding_ShardGroupType(enum.IntEnum): LIKE = ... class OpSharding: - Type: typing.Type[OpSharding_Type] + Type: type[OpSharding_Type] type: OpSharding_Type replicate_on_last_tile_dim: bool last_tile_dims: Sequence[OpSharding_Type] @@ -424,7 +412,7 @@ class OpSharding: tuple_shardings: Sequence[OpSharding] is_shard_group: bool shard_group_id: int - ShardGroupType: typing.Type[OpSharding_ShardGroupType] + ShardGroupType: builtins.type[OpSharding_ShardGroupType] shard_group_type: OpSharding_ShardGroupType def ParseFromString(self, s: bytes) -> None: ... def SerializeToString(self) -> bytes: ... @@ -465,7 +453,7 @@ class HloSharding: def is_unknown(self) -> bool: ... def is_tiled(self) -> bool: ... def is_maximal(self) -> bool: ... - def tuple_elements(self) -> List[HloSharding]: ... + def tuple_elements(self) -> list[HloSharding]: ... def num_devices(self) -> int: ... def num_dimensions(self) -> int: ... def tile_assignment_dimensions(self) -> Sequence[int]: ... @@ -496,9 +484,9 @@ class Device: def transfer_from_outfeed(self, shape: Shape): ... def memory(self, kind: str) -> Memory: ... def default_memory(self) -> Memory: ... - def addressable_memories(self) -> List[Memory]: ... - def live_buffers(self) -> List[Any]: ... - def memory_stats(self) -> Optional[Dict[str, int]]: ... + def addressable_memories(self) -> list[Memory]: ... + def live_buffers(self) -> list[Any]: ... + def memory_stats(self) -> dict[str, int] | None: ... def get_stream_for_external_ready_events(self) -> int: ... def __getattr__(self, name: str) -> Any: ... @@ -508,7 +496,7 @@ class Memory: kind: str def __repr__(self) -> str: ... def __str__(self) -> str: ... - def addressable_by_devices(self) -> List[Device]: ... + def addressable_by_devices(self) -> list[Device]: ... class PjRtLayout: def __str__(self) -> str: ... @@ -545,25 +533,25 @@ class Client: runtime_type: str def device_count(self) -> int: ... def local_device_count(self) -> int: ... - def devices(self) -> List[Device]: ... - def local_devices(self) -> List[Device]: ... - def _get_all_devices(self) -> List[Device]: ... + def devices(self) -> list[Device]: ... + def local_devices(self) -> list[Device]: ... + def _get_all_devices(self) -> list[Device]: ... def device_from_local_hardware_id(self, int) -> Device: ... - def live_buffers(self) -> List[Any]: ... - def live_arrays(self) -> List[ArrayImpl]: ... - def live_executables(self) -> List[LoadedExecutable]: ... + def live_buffers(self) -> list[Any]: ... 
+ def live_arrays(self) -> list[ArrayImpl]: ... + def live_executables(self) -> list[LoadedExecutable]: ... def host_id(self) -> int: ... def process_index(self) -> int: ... def buffer_from_pyval( self, argument: Any, - device: Optional[Device] = ..., + device: Device | None = ..., force_copy: bool = ..., host_buffer_semantics: HostBufferSemantics = ..., ) -> ArrayImpl: ... def compile( self, - computation: Union[str, bytes], + computation: str | bytes, compile_options: CompileOptions = ..., host_callbacks: Sequence[Any] = ..., ) -> LoadedExecutable: ... @@ -576,7 +564,7 @@ class Client: def deserialize_executable( self, serialized: bytes, - options: Optional[CompileOptions], + options: CompileOptions | None, host_callbacks: Sequence[Any] = ..., ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... @@ -587,7 +575,7 @@ class Client: result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], recv_channel_ids: Sequence[int], - serializer: Optional[Callable] = ..., + serializer: Callable | None = ..., ) -> Any: ... def get_default_layout( self, dtype: np.dtype, shard_shape: Sequence[int], device: Device @@ -597,9 +585,9 @@ class Client: class CpuCollectives: ... def make_gloo_tcp_collectives( - distributed_client: Optional[DistributedRuntimeClient] = ..., - hostname: Optional[str] = ..., - interface: Optional[str] = ..., + distributed_client: DistributedRuntimeClient | None = ..., + hostname: str | None = ..., + interface: str | None = ..., ) -> CpuCollectives: ... class MpiCollectives(CpuCollectives): @@ -610,48 +598,48 @@ def make_mpi_collectives() -> MpiCollectives: ... def get_tfrt_cpu_client( asynchronous: bool = ..., - distributed_client: Optional[DistributedRuntimeClient] = ..., + distributed_client: DistributedRuntimeClient | None = ..., node_id: int = ..., num_nodes: int = ..., - collectives: Optional[CpuCollectives] = ..., + collectives: CpuCollectives | None = ..., num_devices: int | None = ..., ) -> Client: ... def get_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., - distributed_client: Optional[DistributedRuntimeClient] = ..., + distributed_client: DistributedRuntimeClient | None = ..., node_id: int = ..., num_nodes: int = ..., - allowed_devices: Optional[Any] = ..., - platform_name: Optional[str] = ..., - mock: Optional[bool] = ..., - mock_gpu_topology: Optional[str] = ..., + allowed_devices: Any | None = ..., + platform_name: str | None = ..., + mock: bool | None = ..., + mock_gpu_topology: str | None = ..., ) -> Client: ... def get_mock_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., - distributed_client: Optional[DistributedRuntimeClient] = ..., + distributed_client: DistributedRuntimeClient | None = ..., node_id: int = ..., - allowed_devices: Optional[Any] = ..., - platform_name: Optional[str] = ..., + allowed_devices: Any | None = ..., + platform_name: str | None = ..., ) -> Client: ... def get_c_api_client( platform_name: str, - options: Dict[str, Union[str, int, List[int], float, bool]], - distributed_client: Optional[DistributedRuntimeClient] = ..., + options: dict[str, str | int | list[int] | float | bool], + distributed_client: DistributedRuntimeClient | None = ..., ) -> Client: ... def get_default_c_api_topology( platform_name: str, topology_name: str, - options: Dict[str, Union[str, int, List[int], float]], + options: dict[str, str | int | list[int] | float], ) -> DeviceTopology: ... 
def get_c_api_topology( c_api: Any, topology_name: str, - options: Dict[str, Union[str, int, List[int], float]], + options: dict[str, str | int | list[int] | float], ) -> DeviceTopology: ... -def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... -def load_pjrt_plugin(platform_name: str, library_path: Optional[str], c_api: Optional[Any]) -> _Status: ... +def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: ... +def load_pjrt_plugin(platform_name: str, library_path: str | None, c_api: Any | None) -> _Status: ... def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... def initialize_pjrt_plugin(platform_name: str) -> _Status: ... @@ -688,7 +676,7 @@ ArrayImpl = Any def batched_copy_array_to_devices_with_sharding( arrays: Sequence[ArrayImpl], - devices: Sequence[List[Device]], + devices: Sequence[list[Device]], sharding: Sequence[Any], array_copy_semantics: Sequence[ArrayCopySemantics], ) -> Sequence[ArrayImpl]: ... @@ -699,7 +687,7 @@ def batched_device_put( aval: Any, sharding: Any, shards: Sequence[Any], - devices: List[Device], + devices: list[Device], committed: bool = True, ) -> ArrayImpl: ... @@ -710,8 +698,8 @@ def reorder_shards( ) -> ArrayImpl: ... def check_and_canonicalize_memory_kind( - memory_kind: Optional[str], device_list: DeviceList -) -> Optional[str]: ... + memory_kind: str | None, device_list: DeviceList +) -> str | None: ... def array_result_handler( aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... ) -> Callable: ... @@ -725,52 +713,52 @@ class ShardedToken: class ExecuteResults: def __len__(self) -> int: ... - def disassemble_into_single_device_arrays(self) -> List[List[ArrayImpl]]: ... + def disassemble_into_single_device_arrays(self) -> list[list[ArrayImpl]]: ... def disassemble_prefix_into_single_device_arrays( self, n: int - ) -> List[List[ArrayImpl]]: ... - def consume_with_handlers(self, handlers: List[Callable]) -> List[Any]: ... + ) -> list[list[ArrayImpl]]: ... + def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]: ... def consume_token(self) -> ShardedToken: ... class LoadedExecutable: client: Client - def local_devices(self) -> List[Device]: ... + def local_devices(self) -> list[Device]: ... def size_of_generated_code_in_bytes(self) -> int: ... def delete(self) -> None: ... - def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... + def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ... def execute_with_token( self, arguments: Sequence[ArrayImpl] - ) -> Tuple[List[ArrayImpl], Token]: ... + ) -> tuple[list[ArrayImpl], Token]: ... def execute_sharded( - self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... + self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ... ) -> ExecuteResults: ... - def hlo_modules(self) -> List[HloModule]: ... - def get_output_memory_kinds(self) -> List[List[str]]: ... + def hlo_modules(self) -> list[HloModule]: ... + def get_output_memory_kinds(self) -> list[list[str]]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... - def get_output_shardings(self) -> Optional[List[OpSharding]]: ... - def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... - def get_parameter_layouts(self) -> List[Layout]: ... - def get_output_layouts(self) -> List[Layout]: ... + def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_shardings(self) -> list[OpSharding] | None: ... 
+ def get_parameter_layouts(self) -> list[Layout]: ... + def get_output_layouts(self) -> list[Layout]: ... def keep_alive(self) -> None: ... - def cost_analysis(self) -> Dict[str, Any]: ... + def cost_analysis(self) -> dict[str, Any]: ... traceback: Traceback - fingerprint: Optional[bytes] + fingerprint: bytes | None class Executable: - def hlo_modules(self) -> List[HloModule]: ... - def get_output_memory_kinds(self) -> List[List[str]]: ... - def get_output_shardings(self) -> Optional[List[OpSharding]]: ... - def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... - def get_parameter_layouts(self) -> List[Layout]: ... - def get_output_layouts(self) -> List[Layout]: ... + def hlo_modules(self) -> list[HloModule]: ... + def get_output_memory_kinds(self) -> list[list[str]]: ... + def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_layouts(self) -> list[Layout]: ... + def get_output_layouts(self) -> list[Layout]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def serialize(self) -> str: ... - def cost_analysis(self) -> Dict[str, Any]: ... + def cost_analysis(self) -> dict[str, Any]: ... class DeviceTopology: platform: str platform_version: str - def _make_compile_only_devices(self) -> List[Device]: ... + def _make_compile_only_devices(self) -> list[Device]: ... def serialize(self) -> bytes: ... def __getattr__(self, name: str) -> Any: ... @@ -784,18 +772,18 @@ def dlpack_managed_tensor_to_buffer( @overload def dlpack_managed_tensor_to_buffer( # Legacy overload tensor: Any, - cpu_backend: Optional[Client] = ..., - gpu_backend: Optional[Client] = ..., + cpu_backend: Client | None = ..., + gpu_backend: Client | None = ..., ) -> ArrayImpl: ... def cuda_array_interface_to_buffer( - cai: Dict[str, Union[ - str, int, None, - Tuple[int, ...], Tuple[int, bool], - List[Tuple[str, str]], - List[Tuple[str, str, Tuple[int, ...]]]] + cai: dict[str, ( + str | int | None | + tuple[int, ...] | tuple[int, bool] | + list[tuple[str, str]] | + list[tuple[str, str, tuple[int, ...]]]) ], - gpu_backend: Optional[Client] = ..., + gpu_backend: Client | None = ..., device_id: int | None = None, ) -> ArrayImpl: ... @@ -822,13 +810,13 @@ class Traceback: frames: Sequence[Frame] def __str__(self) -> str: ... def as_python_traceback(self) -> Any: ... - def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... + def raw_frames(self) -> tuple[list[types.CodeType], list[int]]: ... @staticmethod def code_addr2line(code: types.CodeType, lasti: int) -> int: ... @staticmethod def code_addr2location( code: types.CodeType, lasti: int - ) -> Tuple[int, int, int, int]: ... + ) -> tuple[int, int, int, int]: ... def replace_thread_exc_traceback(traceback: Any): ... @@ -854,28 +842,28 @@ class DistributedRuntimeClient: allow_overwrite: bool = False) -> _Status: ... def key_value_delete(self, key: str) -> _Status: ... def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int, - process_ids: Optional[List[int]] = None) -> _Status: ... - def get_live_nodes(self, process_ids: List[int]) -> _Status: ... + process_ids: list[int] | None = None) -> _Status: ... + def get_live_nodes(self, process_ids: list[int]) -> _Status: ... 
def get_distributed_runtime_service( address: str, num_nodes: int, - heartbeat_interval: Optional[int] = ..., - max_missing_heartbeats: Optional[int] = ..., - cluster_register_timeout: Optional[int] = ..., - shutdown_timeout: Optional[int] = ..., + heartbeat_interval: int | None = ..., + max_missing_heartbeats: int | None = ..., + cluster_register_timeout: int | None = ..., + shutdown_timeout: int | None = ..., ) -> DistributedRuntimeService: ... def get_distributed_runtime_client( address: str, node_id: int, - rpc_timeout: Optional[int] = ..., - init_timeout: Optional[int] = ..., - shutdown_timeout: Optional[int] = ..., - heartbeat_interval: Optional[int] = ..., - max_missing_heartbeats: Optional[int] = ..., - missed_heartbeat_callback: Optional[Any] = ..., - shutdown_on_destruction: Optional[bool] = ..., - use_compression: Optional[bool] = ..., + rpc_timeout: int | None = ..., + init_timeout: int | None = ..., + shutdown_timeout: int | None = ..., + heartbeat_interval: int | None = ..., + max_missing_heartbeats: int | None = ..., + missed_heartbeat_callback: Any | None = ..., + shutdown_on_destruction: bool | None = ..., + use_compression: bool | None = ..., ) -> DistributedRuntimeClient: ... class PreemptionSyncManager: @@ -897,7 +885,7 @@ class PmapFunction: def _cache_clear(self) -> None: ... class DeviceList: - def __init__(self, device_assignment: Tuple[Device, ...]): ... + def __init__(self, device_assignment: tuple[Device, ...]): ... def __hash__(self) -> int: ... def __eq__(self, other: Any) -> bool: ... def __ne__(self, other: Any) -> bool: ... @@ -913,9 +901,9 @@ class DeviceList: @property def addressable_device_list(self) -> DeviceList: ... @property - def default_memory_kind(self) -> Optional[str]: ... + def default_memory_kind(self) -> str | None: ... @property - def memory_kinds(self) -> Tuple[str, ...]: ... + def memory_kinds(self) -> tuple[str, ...]: ... class Sharding: ... @@ -925,26 +913,26 @@ class NamedSharding(Sharding): mesh: Any, spec: Any, *, - memory_kind: Optional[str] = None, + memory_kind: str | None = None, _logical_device_ids: tuple[int, ...] | None = None, ): ... mesh: Any spec: Any - _memory_kind: Optional[str] + _memory_kind: str | None _internal_device_list: DeviceList _logical_device_ids: tuple[int, ...] | None class SingleDeviceSharding(Sharding): - def __init__(self, device: Device, *, memory_kind: Optional[str] = None): ... + def __init__(self, device: Device, *, memory_kind: str | None = None): ... _device: Device - _memory_kind: Optional[str] + _memory_kind: str | None _internal_device_list: DeviceList class PmapSharding(Sharding): def __init__( self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec ): ... - devices: List[Any] + devices: list[Any] sharding_spec: pmap_lib.ShardingSpec _internal_device_list: DeviceList @@ -952,14 +940,14 @@ class GSPMDSharding(Sharding): def __init__( self, devices: Sequence[Device], - op_sharding: Union[OpSharding, HloSharding], + op_sharding: OpSharding | HloSharding, *, - memory_kind: Optional[str] = None, - _device_list: Optional[DeviceList] = None, + memory_kind: str | None = None, + _device_list: DeviceList | None = None, ): ... - _devices: Tuple[Device, ...] + _devices: tuple[Device, ...] 
_hlo_sharding: HloSharding - _memory_kind: Optional[str] + _memory_kind: str | None _internal_device_list: DeviceList class PjitFunction: @@ -977,14 +965,14 @@ class PjitFunctionCache: def pjit( function_name: str, - fun: Optional[Callable], + fun: Callable | None, cache_miss: Callable, static_argnums: Sequence[int], static_argnames: Sequence[str], global_cache_key: Any, pytree_registry: pytree.PyTreeRegistry, shard_arg_fallback: Callable, - cache: Optional[PjitFunctionCache] = ..., + cache: PjitFunctionCache | None = ..., ) -> PjitFunction: ... class HloPassInterface: diff --git a/jaxlib/_jax/guard_lib.pyi b/jaxlib/_jax/guard_lib.pyi index cfa8b0c5fa5e..7f8896a4f75a 100644 --- a/jaxlib/_jax/guard_lib.pyi +++ b/jaxlib/_jax/guard_lib.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Any, List, Optional +from typing import Any class TransferGuardLevel: ALLOW: Any @@ -28,14 +28,14 @@ class GarbageCollectionGuardLevel: FATAL: Any class GuardState: - host_to_device: Optional[TransferGuardLevel] - device_to_device: Optional[TransferGuardLevel] - device_to_host: Optional[TransferGuardLevel] + host_to_device: TransferGuardLevel | None + device_to_device: TransferGuardLevel | None + device_to_host: TransferGuardLevel | None explicit_device_put: bool explicit_device_get: bool - garbage_collect_array: Optional[GarbageCollectionGuardLevel] + garbage_collect_array: GarbageCollectionGuardLevel | None def global_state() -> GuardState: ... def thread_local_state() -> GuardState: ... @@ -43,4 +43,4 @@ def thread_local_state() -> GuardState: ... class _TestingScopedLogSink: def __enter__(self) -> _TestingScopedLogSink: ... def __exit__(self, *args, **kwargs) -> None: ... - def logs(self) -> List[str]: ... + def logs(self) -> list[str]: ... diff --git a/jaxlib/_jax/ifrt_programs.pyi b/jaxlib/_jax/ifrt_programs.pyi index 8c525de478be..6fcd6525a95b 100644 --- a/jaxlib/_jax/ifrt_programs.pyi +++ b/jaxlib/_jax/ifrt_programs.pyi @@ -13,7 +13,8 @@ # limitations under the License. # ============================================================================== -from typing import Any, Sequence, Union +from typing import Any +from collections.abc import Sequence from jaxlib import _jax @@ -21,7 +22,7 @@ class Program: ... class CompileOptions: ... -def make_hlo_program(mlir_module: Union[str, bytes]) -> Program: ... +def make_hlo_program(mlir_module: str | bytes) -> Program: ... def make_colocated_python_program( name : str, @@ -31,7 +32,7 @@ def make_colocated_python_program( output_avals: Sequence[Any], ) -> Program: ... -def make_plugin_program(data: Union[str, bytes]) -> Program: ... +def make_plugin_program(data: str | bytes) -> Program: ... def make_colocated_python_compile_options() -> CompileOptions: ... diff --git a/jaxlib/_jax/ifrt_proxy.pyi b/jaxlib/_jax/ifrt_proxy.pyi index 77963eae0f7e..73688b2d9696 100644 --- a/jaxlib/_jax/ifrt_proxy.pyi +++ b/jaxlib/_jax/ifrt_proxy.pyi @@ -13,7 +13,8 @@ # limitations under the License. 
# ============================================================================== -from typing import Any, Optional, Callable +from typing import Any +from collections.abc import Callable from jaxlib import _jax @@ -22,9 +23,9 @@ Client = _jax.Client class ClientConnectionOptions: - on_disconnect: Optional[Callable[[_Status], None]] = None - on_connection_update: Optional[Callable[[str], None]] = None - connection_timeout_in_seconds: Optional[int] = None + on_disconnect: Callable[[_Status], None] | None = None + on_connection_update: Callable[[str], None] | None = None + connection_timeout_in_seconds: int | None = None def get_client( diff --git a/jaxlib/_jax/jax_jit.pyi b/jaxlib/_jax/jax_jit.pyi index fd39ef01963e..be7687f4eaa1 100644 --- a/jaxlib/_jax/jax_jit.pyi +++ b/jaxlib/_jax/jax_jit.pyi @@ -13,7 +13,8 @@ # limitations under the License. # ============================================================================== -from typing import Any, Callable, Optional, Sequence, Tuple +from typing import Any +from collections.abc import Callable, Sequence import numpy as np from jaxlib import _jax @@ -25,11 +26,11 @@ Device = _jax.Device class JitState: - disable_jit: Optional[bool] - enable_x64: Optional[bool] - default_device: Optional[Any] - extra_jit_context: Optional[Any] - post_hook: Optional[Callable[..., Any]] + disable_jit: bool | None + enable_x64: bool | None + default_device: Any | None + extra_jit_context: Any | None + post_hook: Callable[..., Any] | None def global_state() -> JitState: ... def thread_local_state() -> JitState: ... @@ -39,11 +40,11 @@ def set_thread_local_state_initialization_callback( function: Callable[[], None]): ... def swap_thread_local_state_disable_jit( - value: Optional[bool]) -> Optional[bool]: ... + value: bool | None) -> bool | None: ... class ArgSignature: dtype: np.dtype - shape: Tuple[int, ...] + shape: tuple[int, ...] weak_type: bool def _ArgSignatureOfValue( @@ -69,7 +70,7 @@ class ArgumentSignature: def parse_arguments( positional_args: Sequence[Any], keyword_args: Sequence[Any], - kwnames: Tuple[str, ...], + kwnames: tuple[str, ...], static_argnums: Sequence[int], static_argnames: Sequence[str], pytree_registry: pytree.PyTreeRegistry, diff --git a/jaxlib/_jax/mlir.pyi b/jaxlib/_jax/mlir.pyi index 961f01a0352c..9be8ef71b50d 100644 --- a/jaxlib/_jax/mlir.pyi +++ b/jaxlib/_jax/mlir.pyi @@ -13,22 +13,21 @@ # limitations under the License. # ============================================================================== -from typing import Union from . import XlaComputation def hlo_to_stablehlo(computation: bytes) -> bytes: ... def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... def mlir_module_to_xla_computation( - mlir_module: Union[bytes, str], + mlir_module: bytes | str, use_tuple_args: bool = ..., return_tuple: bool = ..., ) -> XlaComputation: ... -def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> bytes: ... -def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> bytes: ... +def mhlo_to_stablehlo(mlir_module: bytes | str) -> bytes: ... +def stablehlo_to_mhlo(mlir_module: bytes | str) -> bytes: ... def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ... def deserialize_portable_artifact(mlir_module: bytes) -> str: ... 
def refine_polymorphic_shapes( - mlir_module: Union[bytes, str], + mlir_module: bytes | str, enable_shape_assertions: bool = ..., validate_static_shapes: bool = ..., enable_shardy: bool = ..., diff --git a/jaxlib/_jax/ops.pyi b/jaxlib/_jax/ops.pyi index 7f5e46cabbdf..06a38b9090f6 100644 --- a/jaxlib/_jax/ops.pyi +++ b/jaxlib/_jax/ops.pyi @@ -14,7 +14,8 @@ # ============================================================================== import enum -from typing import Any, Optional, Sequence, overload +from typing import Any, overload +from collections.abc import Sequence from jaxlib import _jax @@ -68,15 +69,15 @@ def AllGather( all_gather_dimension: int, shard_count: int, replica_groups: Sequence[_ReplicaGroup] = ..., - channel_id: Optional[_ChannelHandle] = ..., - shape_with_layout: Optional[_Layout] = ..., - use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... + channel_id: _ChannelHandle | None = ..., + shape_with_layout: _Layout | None = ..., + use_global_device_ids: bool | None = ...) -> XlaOp: ... def AllReduce( operand: XlaOp, computation: XlaComputation, replica_groups: Sequence[_ReplicaGroup] = ..., - channel_id: Optional[_ChannelHandle] = ..., - shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ... + channel_id: _ChannelHandle | None = ..., + shape_with_layout: _Layout | None = ...) -> XlaOp: ... def ApproxTopK( builder: XlaBuilder, operands: Sequence[XlaOp], @@ -84,9 +85,9 @@ def ApproxTopK( top_k: int, reduction_dim: int, comparator: XlaComputation, - recall_target: Optional[float], - aggregate_to_topk: Optional[bool], - reduction_input_size_override: Optional[int]) -> XlaOp: ... + recall_target: float | None, + aggregate_to_topk: bool | None, + reduction_input_size_override: int | None) -> XlaOp: ... def ApproxTopKFallback( builder: XlaBuilder, operands: Sequence[XlaOp], @@ -94,33 +95,33 @@ def ApproxTopKFallback( top_k: int, reduction_dim: int, comparator: XlaComputation, - recall_target: Optional[float], - aggregate_to_topk: Optional[bool], - reduction_input_size_override: Optional[int]) -> XlaOp: ... + recall_target: float | None, + aggregate_to_topk: bool | None, + reduction_input_size_override: int | None) -> XlaOp: ... def ApproxTopKReductionOutputSize( input_size: int, rank: int, top_k: int, recall_target: float, - aggregate_to_topk: Optional[bool] = ..., - input_size_override: Optional[int] = ...) -> tuple[int, int]: ... + aggregate_to_topk: bool | None = ..., + input_size_override: int | None = ...) -> tuple[int, int]: ... def ReduceScatter( operand: XlaOp, computation: XlaComputation, scatter_dimension: int, shard_count: int, replica_groups: Sequence[_ReplicaGroup] = ..., - channel_id: Optional[_ChannelHandle] = ..., - layout: Optional[_Layout] = ..., - use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... + channel_id: _ChannelHandle | None = ..., + layout: _Layout | None = ..., + use_global_device_ids: bool | None = ...) -> XlaOp: ... def AllToAll( operand: XlaOp, split_dimension: int, concat_dimension: int, split_count: int, replica_groups: Sequence[_ReplicaGroup] = ..., - layout: Optional[_Layout] = ..., - channel_id: Optional[_ChannelHandle] = ...) -> XlaOp: ... + layout: _Layout | None = ..., + channel_id: _ChannelHandle | None = ...) -> XlaOp: ... def BitcastConvertType(operand: XlaOp, new_element_type: PrimitiveType) -> XlaOp: ... def Broadcast(operand: XlaOp, sizes: Sequence[int]) -> XlaOp: ... @@ -136,7 +137,7 @@ def Collapse(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... 
def CollectivePermute( operand: XlaOp, source_target_pairs: Sequence[tuple[int, int]], - channel_id: Optional[_ChannelHandle] = ..., + channel_id: _ChannelHandle | None = ..., inplace: bool = ...) -> XlaOp: ... def ConcatInDim(builder: XlaBuilder, operands: Sequence[XlaOp], @@ -165,9 +166,9 @@ def ConvGeneralDilated( dimension_numbers: _ConvDimensionNumbers, feature_group_count: int = ..., batch_group_count: int = ..., - precision_config: Optional[PrecisionConfig_Precision] = ..., - preferred_element_type: Optional[PrimitiveType] = ..., - window_reversal: Optional[Sequence[bool]] = ...) -> XlaOp: ... + precision_config: PrecisionConfig_Precision | None = ..., + preferred_element_type: PrimitiveType | None = ..., + window_reversal: Sequence[bool] | None = ...) -> XlaOp: ... def ConvertElementType( operand: XlaOp, new_element_type: PrimitiveType) -> XlaOp: ... @@ -209,14 +210,14 @@ def CustomCallWithAliasing( def Dot( lhs: XlaOp, rhs: XlaOp, - precision_config: Optional[PrecisionConfig_Precision] = ..., - preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... + precision_config: PrecisionConfig_Precision | None = ..., + preferred_element_type: PrimitiveType | None = ...) -> XlaOp: ... def DotGeneral( lhs: XlaOp, rhs: XlaOp, dimensions_numbers: _DotDimensionNumbers, - precision_config: Optional[PrecisionConfig_Precision] = ..., - preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... + precision_config: PrecisionConfig_Precision | None = ..., + preferred_element_type: PrimitiveType | None = ...) -> XlaOp: ... def DynamicReshape( operand: XlaOp, dim_sizes: Sequence[XlaOp], @@ -251,7 +252,7 @@ def GetTupleElement(tuple_data: XlaOp, index: int) -> XlaOp: ... def InfeedWithToken( token: XlaOp, shape: Shape, - config: Optional[str] = ...) -> XlaOp: ... + config: str | None = ...) -> XlaOp: ... @overload def Iota(builder: XlaBuilder, shape: Shape, iota_dimension: int) -> XlaOp: ... @overload @@ -266,14 +267,14 @@ def Map( def MultiCollectivePermute( operands: Sequence[XlaOp], source_target_pairs: Sequence[tuple[int, int]], - channel_id: Optional[_ChannelHandle] = ..., + channel_id: _ChannelHandle | None = ..., inplace: bool = ...) -> XlaOp: ... def NextAfter(__from: XlaOp, to: XlaOp) -> XlaOp: ... def OutfeedWithToken( operand: XlaOp, token: XlaOp, shape_with_layout: Shape, - outfeed_config: Optional[str] = ...) -> XlaOp: ... + outfeed_config: str | None = ...) -> XlaOp: ... def Pad( operand: XlaOp, padding_value: XlaOp, @@ -368,7 +369,7 @@ def SliceInDim( def Sort( builder: XlaBuilder, operands: Sequence[XlaOp], - comparator: Optional[XlaComputation] = ..., + comparator: XlaComputation | None = ..., dimension: int = ..., is_stable: bool = ...) -> XlaOp: ... def SVD( diff --git a/jaxlib/_jax/pmap_lib.pyi b/jaxlib/_jax/pmap_lib.pyi index f862e87c0fcd..3e26e7e1da84 100644 --- a/jaxlib/_jax/pmap_lib.pyi +++ b/jaxlib/_jax/pmap_lib.pyi @@ -14,7 +14,8 @@ # ============================================================================== import inspect -from typing import Any, Callable, Sequence, Iterable, Tuple +from typing import Any +from collections.abc import Callable, Sequence, Iterable from . import pytree @@ -59,9 +60,9 @@ class ShardingSpec: sharding: Iterable[_AvalDimSharding], mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ... @property - def sharding(self) -> Tuple[_AvalDimSharding, ...]: ... + def sharding(self) -> tuple[_AvalDimSharding, ...]: ... @property - def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ... 
+ def mesh_mapping(self) -> tuple[_MeshDimAssignment]: ... def __eq__(self, __other: Any) -> bool: ... def __hash__(self) -> int: ... diff --git a/jaxlib/_jax/profiler.pyi b/jaxlib/_jax/profiler.pyi index 95749f61978a..a2fcc67fbcb7 100644 --- a/jaxlib/_jax/profiler.pyi +++ b/jaxlib/_jax/profiler.pyi @@ -14,7 +14,7 @@ # ============================================================================== from types import TracebackType -from typing import Any, Optional, Type, Union, List, Tuple +from typing import Any _Status = Any @@ -24,13 +24,13 @@ def start_server(port: int) -> ProfilerServer: ... def register_plugin_profiler(c_api: Any) -> None: ... def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... -def get_instructins_profile(tensorboard_dir: str) -> List[Tuple[str, float]]: ... +def get_instructins_profile(tensorboard_dir: str) -> list[tuple[str, float]]: ... def get_fdo_profile( xspace: bytes, as_textproto: bool = ... -) -> Union[bytes, str]: ... +) -> bytes | str: ... class ProfilerSession: - def __init__(self, options: Optional[ProfileOptions] = ...) -> None: ... + def __init__(self, options: ProfileOptions | None = ...) -> None: ... def stop(self) -> bytes: ... def export(self, xspace: bytes, tensorboard_dir: str) -> _Status:... @@ -44,16 +44,16 @@ class ProfileOptions: repository_path: str raise_error_on_start_failure: bool -def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ... +def aggregate_profiled_instructions(profiles: list[bytes], percentile: int) -> str: ... class TraceMe: def __init__(self, name: str, **kwargs: Any) -> None: ... def __enter__(self) -> TraceMe: ... def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> Optional[bool]:... + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None) -> bool | None:... def set_metadata(self, **kwargs): ... @staticmethod def is_enabled() -> bool: ... diff --git a/jaxlib/_jax/pytree.pyi b/jaxlib/_jax/pytree.pyi index 157d455e20ae..ac5298c77964 100644 --- a/jaxlib/_jax/pytree.pyi +++ b/jaxlib/_jax/pytree.pyi @@ -13,18 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import ( - Any, - Callable, - Hashable, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +from builtins import tuple as Tuple +from typing import Any, TypeVar +from collections.abc import Callable, Hashable, Iterable, Sequence _T = TypeVar("_T") @@ -43,22 +34,22 @@ class PyTreeRegistry: def flatten( self, tree: Any, - leaf_predicate: Optional[Callable[[Any], bool]] = ..., - ) -> Tuple[List[Any], PyTreeDef]: ... + leaf_predicate: Callable[[Any], bool] | None = ..., + ) -> Tuple[list[Any], PyTreeDef]: ... def flatten_one_level( self, tree: Any - ) -> Optional[Tuple[Iterable[Any], Any]]: ... + ) -> Tuple[Iterable[Any], Any] | None: ... def flatten_one_level_with_keys( self, tree: Any - ) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ... + ) -> Tuple[Iterable[_KeyLeafPair], Any] | None: ... def flatten_with_path( self, tree: Any, - leaf_predicate: Optional[Callable[[Any], bool]] = ..., - ) -> Tuple[List[Tuple[_KeyPath, Any]], PyTreeDef]: ... + leaf_predicate: Callable[[Any], bool] | None = ..., + ) -> Tuple[list[Tuple[_KeyPath, Any]], PyTreeDef]: ... 
def register_node( self, - __type: Type[_T], + __type: type[_T], to_iterable: Callable[[_T], Tuple[_Children, _AuxData]], from_iterable: Callable[[_AuxData, _Children], _T], to_iterable_with_keys: ( @@ -66,7 +57,7 @@ class PyTreeRegistry: ) = ..., ) -> Any: ... def register_dataclass_node( - self, __type: Type[_T], meta_fields: List[str], data_fields: List[str] + self, __type: type[_T], meta_fields: list[str], data_fields: list[str] ) -> Any: ... def default_registry() -> PyTreeRegistry: ... @@ -119,21 +110,21 @@ class FlattenedIndexKey(Hashable): class PyTreeDef: def unflatten(self, __leaves: Iterable[Any]) -> Any: ... - def flatten_up_to(self, __xs: Any) -> List[Any]: ... + def flatten_up_to(self, __xs: Any) -> list[Any]: ... def compose(self, __inner: PyTreeDef) -> PyTreeDef: ... def walk( self, __f_node: Callable[[Any, Any], Any], - __f_leaf: Optional[Callable[[_T], Any]], + __f_leaf: Callable[[_T], Any] | None, leaves: Iterable[Any], ) -> Any: ... def from_iterable_tree(self, __xs: Any): ... - def node_data(self) -> Optional[Tuple[Type, Any]]: ... - def children(self) -> List[PyTreeDef]: ... + def node_data(self) -> Tuple[type, Any] | None: ... + def children(self) -> list[PyTreeDef]: ... @staticmethod def make_from_node_data_and_children( registry: PyTreeRegistry, - node_data: Optional[Tuple[Type, Any]], + node_data: Tuple[type, Any] | None, children: Iterable[PyTreeDef], ) -> PyTreeDef: ... diff --git a/jaxlib/_jax/transfer_guard_lib.pyi b/jaxlib/_jax/transfer_guard_lib.pyi index 091e1e10a742..d293f7c59798 100644 --- a/jaxlib/_jax/transfer_guard_lib.pyi +++ b/jaxlib/_jax/transfer_guard_lib.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Any, List, Optional +from typing import Any class TransferGuardLevel: ALLOW: Any @@ -23,9 +23,9 @@ class TransferGuardLevel: DISALLOW_EXPLICIT: Any class TransferGuardState: - host_to_device: Optional[TransferGuardLevel] - device_to_device: Optional[TransferGuardLevel] - device_to_host: Optional[TransferGuardLevel] + host_to_device: TransferGuardLevel | None + device_to_device: TransferGuardLevel | None + device_to_host: TransferGuardLevel | None explicit_device_put: bool explicit_device_get: bool @@ -36,4 +36,4 @@ def thread_local_state() -> TransferGuardState: ... class _TestingScopedLogSink: def __enter__(self) -> _TestingScopedLogSink: ... def __exit__(self, *args, **kwargs) -> None: ... - def logs(self) -> List[str]: ... + def logs(self) -> list[str]: ... diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index 445bb2287f8a..1a6751066e7c 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -160,7 +160,7 @@ class PaddingConfig: dimensions: list[PaddingConfigDimension] def make_padding_config( - padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]], + padding_config: PaddingConfig | Sequence[tuple[int, int, int]], ) -> PaddingConfig: ... @@ -175,10 +175,10 @@ class DotDimensionNumbers: rhs_batch_dimensions: list[int] def make_dot_dimension_numbers( - dimension_numbers: Union[ - DotDimensionNumbers, - tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], - ], + dimension_numbers: ( + DotDimensionNumbers | + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]] + ), ) -> DotDimensionNumbers: ... 
@@ -194,9 +194,9 @@ class ConvolutionDimensionNumbers: output_spatial_dimensions: list[int] def make_convolution_dimension_numbers( - dimension_numbers: Union[ - None, ConvolutionDimensionNumbers, tuple[str, str, str] - ], + dimension_numbers: ( + None | ConvolutionDimensionNumbers | tuple[str, str, str] + ), num_spatial_dimensions: int, ) -> ConvolutionDimensionNumbers: ... From aaa0279b7f1d14fb04bcf59e1d525d33dbc3f773 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Apr 2025 02:10:04 -0700 Subject: [PATCH 0777/1769] [pallas] Fixed the type of `MemoryRef.dtype` PiperOrigin-RevId: 750907599 --- jax/_src/pallas/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 04390a03d4d6..90d35c7949f7 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -199,7 +199,7 @@ def update( class MemoryRef: """Like jax.ShapeDtypeStruct but with memory spaces.""" shape: tuple[int, ...] - dtype: jnp.dtype + dtype: jnp.dtype | dtypes.ExtendedDType # TODO(b/368122763): Unify memory space types across backends memory_space: Any From 50ead60b8bb694d4f764e0e2055b9d80b18d57ad Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 24 Apr 2025 05:01:12 -0700 Subject: [PATCH 0778/1769] [Mosaic GPU] Introduce a dedicated `DialectBarrierRef` and handle warpgroup logic in the dialect lowering. The `DialectBarrierRef` class has the same interface as `BarrierRef`, but uses mgpu ops for initialization and `expect_arrive_tx`. This makes the IR cleaner and also allows us to take care of adjusting arrival counts and bytes in the dialect lowering. That makes the high-level code cleaner. The new lowering always has all threads in a warpgroup arrive when using WG semantics. The behavior so far was to have only a single thread arrive, but keeping this would have complicated things going forward. The existing tests (including the one that's no longer skipped) test the new behavior. PiperOrigin-RevId: 750948900 --- jax/_src/pallas/mosaic_gpu/lowering.py | 9 +- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/core.py | 27 ++++-- .../mosaic/gpu/dialect_lowering.py | 45 +++++++--- jax/experimental/mosaic/gpu/utils.py | 82 ++++++++++++++++--- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 4 - tests/mosaic/gpu_dialect_test.py | 9 +- tests/mosaic/gpu_test.py | 12 +-- tests/pallas/mosaic_gpu_test.py | 2 - 10 files changed, 144 insertions(+), 49 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a961d5bf56c4..5801441fdaf8 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -280,7 +280,9 @@ def __iter__(self) -> Iterable[Hashable]: ) -AnyBarrierRef = mgpu.BarrierRef | mgpu.CollectiveBarrierRef +AnyBarrierRef = ( + mgpu.BarrierRef | mgpu.DialectBarrierRef | mgpu.CollectiveBarrierRef +) @dataclasses.dataclass @@ -319,7 +321,9 @@ def single_lane_predicate(self) -> ir.Value: raise ValueError(f"Unknown semantics: {self.primitive_semantics}") @contextlib.contextmanager - def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: + def reserve_barrier( + self, barrier: mgpu.Barrier + ) -> mgpu.BarrierRef | mgpu.DialectBarrierRef | mgpu.CollectiveBarrierRef: """Reserves a barrier. 
Raises: @@ -807,6 +811,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): in_shapes=in_shapes, out_shape=(*out_shapes, *gmem_scratch_shapes), smem_scratch_shape=scratch_buffers, + lowering_semantics=lowering_semantics, module_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, ) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 070e37f64e2c..7e1c7254ddfb 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -526,7 +526,7 @@ def _copy_gmem_to_smem_lowering( indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) assert copy_params.get("swizzle") is None assert not copy_params.get("gmem_transform") - barrier_ref = barrier.as_dialect_barrier_memref() + barrier_ref = barrier.as_barrier_memref() mgpu.dialect.arrive_expect_tx(barrier_ref, bytes) mgpu.dialect.async_load( src, diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index c1275396036c..16b667690de4 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -68,6 +68,7 @@ ) from .utils import ( BarrierRef as BarrierRef, + DialectBarrierRef as DialectBarrierRef, CollectiveBarrierRef as CollectiveBarrierRef, DynamicSlice as DynamicSlice, Partition as Partition, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 8e240d55d4cc..bb0ecff96350 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -227,6 +227,7 @@ def _construct_smem_reftree( dynamic_smem: ir.Value, smem_buffers: ShapeTree, delayed_warp_init: list[Callable[[], None]], # Mutated by this function! + lowering_semantics: LoweringSemantics, dynamic_smem_offset: int = 0, ) -> Callable[[], RefTree]: index = ir.IndexType.get() @@ -237,6 +238,7 @@ def _construct_smem_reftree( smem_buffers, is_leaf=lambda x: isinstance(x, Union) ) smem_refs = [] + for ref_ty in flat_ref_tys: def get_barrier_ptr(num_barriers: int) -> ir.Value: nonlocal dynamic_smem_offset @@ -260,6 +262,7 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value: dynamic_smem, m, delayed_warp_init, + lowering_semantics, dynamic_smem_offset, ) for m in members @@ -271,11 +274,17 @@ def ref(member_thunks=member_thunks): return Union([t() for t in member_thunks]) case TMABarrier(num_barriers): - ref = utils.BarrierRef.initialize( + init_fn = utils.DialectBarrierRef.initialize if ( + lowering_semantics == LoweringSemantics.Warpgroup + ) else utils.BarrierRef.initialize + ref = init_fn( get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 ) case Barrier(arrival_count, num_barriers): - ref = utils.BarrierRef.initialize( + init_fn = utils.DialectBarrierRef.initialize if ( + lowering_semantics == LoweringSemantics.Warpgroup + ) else utils.BarrierRef.initialize + ref = init_fn( get_barrier_ptr(num_barriers), num_barriers, arrival_count=arrival_count, @@ -361,6 +370,7 @@ def _launch( block: tuple[int, int, int], scratch_arr, smem_buffers: ShapeTree | Union[ShapeTree], + lowering_semantics: LoweringSemantics, profiler_spec: profiler.ProfilerSpec | None = None, maybe_prof_buffer: ir.Value | None = None, ): @@ -433,7 +443,11 @@ def _launch( with ctx.named_region("Init"): delayed_warp_init = [] smem_ref_tree_thunk = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers, delayed_warp_init + cluster, + dynamic_smem, + smem_buffers, + delayed_warp_init, + lowering_semantics, ) # TODO(apaszke): Skip fences if no 
barriers or TMEM is initialized. # TODO(apaszke): Only initialize cluster barriers before the cluster wait. @@ -465,6 +479,7 @@ def _lower_as_gpu_kernel( in_shapes: tuple[Any, ...], out_shape, smem_scratch_shape: ShapeTree | Union[ShapeTree], + lowering_semantics: LoweringSemantics, module_name: str, kernel_name: str | None = None, prof_spec: profiler.ProfilerSpec | None = None, @@ -526,7 +541,7 @@ def main(token_ptr, buffers): scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) with _launch( token, grid, cluster, block, scratch_arr, smem_scratch_shape, - prof_spec, prof_buffer + lowering_semantics, prof_spec, prof_buffer ) as (_launch_ctx, smem_refs): nonlocal launch_ctx launch_ctx = _launch_ctx @@ -618,7 +633,7 @@ def as_gpu_kernel( module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, kernel_name, prof_spec + thread_semantics, module_name, kernel_name, prof_spec ) ) @@ -701,7 +716,7 @@ def as_torch_gpu_kernel( module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, kernel_name, prof_spec + lowering_semantics, module_name, kernel_name, prof_spec ) ) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 0ee33b4bfa92..f2acd14cac72 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -188,10 +188,18 @@ def _initialize_barrier_op_lowering_rule( for i in range(num_barriers): nvvm.mbarrier_init_shared( - llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i], - lowered_barrier_type), - utils.c(initialize_barrier_op.arrival_count.value, i32), - predicate=ctx.single_thread_per_block_predicate + llvm.getelementptr( + ptr_ty, + initialize_barrier_op.base_pointer, + [], + [i], + lowered_barrier_type, + ), + utils.c( + initialize_barrier_op.arrival_count.value * utils.WARPGROUP_SIZE, + i32, + ), + predicate=ctx.single_thread_per_block_predicate, ) gpu.barrier() @@ -596,7 +604,7 @@ def _mgpu_async_load_op_lowering_rule( ctx: LoweringContext, load_op: mgpu.AsyncLoadOp ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier) + barrier = utils.DialectBarrierRef.from_barrier_memref(load_op.barrier) if inference_utils.has_in_transforms_set(load_op): [transforms] = inference_utils.in_transforms(load_op) @@ -624,7 +632,7 @@ def _mgpu_async_load_op_lowering_rule( src_ref=load_op.source, dst_ref=reinterpret_smem_ref(load_op.destination, transforms), gmem_slice=tuple(gmem_slice), - barrier=barrier, + barrier=barrier.barrier_ref, arrive=False, uniform=True, swizzle=swizzle, @@ -924,14 +932,25 @@ def _mgpu_wgmma_op_lowering_rule( @_register_lowering(mgpu.ArriveExpectTxOp) def _mgpu_arrive_expect_tx_op_lowering_rule( - ctx: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp + _: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp ) -> Sequence[ir.Value]: - - barrier = utils.BarrierRef.from_dialect_barrier_memref(arrive_expect_tx_op.barrier) - barrier.arrive_expect_tx( - arrive_expect_tx_op.expect_tx.value, - ctx.single_thread_per_warpgroup_predicate, + bytes = arrive_expect_tx_op.expect_tx.value + if bytes % utils.WARPGROUP_SIZE: + raise NotImplementedError( + "Only copies of a multiple of 128 bytes are supported" + ) + # We arrive 
uniformly from each thread in the WG, so we need to divide the + # number of bytes by the number of threads in the WG. + # TODO: dasenov - Relax this. We can just select the WG leader and have it + # arrive with the whole transfer size, while everyone else arrives with 0. + # But we should continue using this scheme as it's likely to be faster. + bytes //= utils.WARPGROUP_SIZE + bytes = utils.c(bytes, ir.IntegerType.get_signless(32)) + + barrier = utils.DialectBarrierRef.from_barrier_memref( + arrive_expect_tx_op.barrier ) + nvvm.mbarrier_arrive_expect_tx_shared(barrier.get_ptr(), bytes) return [] @@ -941,7 +960,7 @@ def _mgpu_wait_op_lowering_rule( _: LoweringContext, wait_op: mgpu.WaitOp ) -> Sequence[ir.Value]: - barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier) + barrier = utils.DialectBarrierRef.from_barrier_memref(wait_op.barrier) barrier.wait_parity(wait_op.parity) return [] diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 3c7532dde99d..7957a92a3e0c 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -796,8 +796,67 @@ def get_ptr(self): ptr, self.base_address, [self.offset], [DYNAMIC32], i64 ) - def as_dialect_barrier_memref(self) -> ir.Value: - shape = () if self.num_barriers == 1 else (self.num_barriers,) + +@dataclasses.dataclass(frozen=True) +class DialectBarrierRef: + barrier_ref: BarrierRef + + @staticmethod + def initialize( + address: ir.Value, + num_barriers: int, + arrival_count: int = 1, + ) -> "DialectBarrierRef": + if num_barriers > 32: + raise NotImplementedError("Only up to 32 barriers per group supported") + + barrier_ty = ir.MemRefType.get( + (num_barriers,), ir.Type.parse("!mosaic_gpu.barrier") + ) + dialect.InitializeBarrierOp( + barrier_ty, base_pointer=address, arrival_count=arrival_count + ) + + i32 = ir.IntegerType.get_signless(32) + phases = memref.alloca(ir.MemRefType.get((), i32), [], []) + memref.store(c(0, i32), phases, []) + return DialectBarrierRef( + barrier_ref=BarrierRef(address, c(0, i32), phases, num_barriers) + ) + + def __iter__(self) -> Iterator["DialectBarrierRef"]: + if self.barrier_ref.num_barriers == 1: + yield self + else: + for offset in range(self.barrier_ref.num_barriers): + yield self[offset] + + def __getitem__(self, offset: ir.Value | int) -> "DialectBarrierRef": + return DialectBarrierRef(self.barrier_ref[offset]) + + def wait_parity(self, parity, for_tensor_core=False): + self.barrier_ref.wait_parity(parity, for_tensor_core) + + def wait(self, for_tensor_core: bool = False): + assert self.barrier_ref.phases is not None + self.barrier_ref.wait(for_tensor_core) + + def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: + return self.barrier_ref.update_parities(parities) + + def arrive(self): + self.barrier_ref.arrive() + + def arrive_expect_tx(self, bytes: int | ir.Value): + dialect.ArriveExpectTxOp( + barrier=self.as_barrier_memref(), expect_tx=bytes) + + def get_ptr(self): + return self.barrier_ref.get_ptr() + + def as_barrier_memref(self) -> ir.Value: + num_barriers = self.barrier_ref.num_barriers + shape = () if num_barriers == 1 else (num_barriers,) return ptr_as_memref( self.get_ptr(), ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), @@ -805,8 +864,8 @@ def as_dialect_barrier_memref(self) -> ir.Value: ) @classmethod - def from_dialect_barrier_memref(cls, barrier: ir.Value): - """Creates a BarrierRef from a memref of a dialect barrier.""" + def from_barrier_memref(cls, barrier: 
ir.Value): + """Creates a DialectBarrierRef from a memref of a dialect barrier.""" memref_type = ir.MemRefType(barrier.type) if memref_type.rank > 1 or memref_type.element_type != ir.Type.parse( "!mosaic_gpu.barrier" @@ -817,15 +876,16 @@ def from_dialect_barrier_memref(cls, barrier: ir.Value): ) return cls( - base_address=memref_ptr( - barrier, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE - ), - offset=c(0, ir.IntegerType.get_signless(64)), - phases=None, - num_barriers=(1 if memref_type.rank == 0 else memref_type.shape[0]), + barrier_ref=BarrierRef( + base_address=memref_ptr( + barrier, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE + ), + offset=c(0, ir.IntegerType.get_signless(64)), + phases=None, + num_barriers=(1 if memref_type.rank == 0 else memref_type.shape[0]), + ) ) - @dataclasses.dataclass(frozen=True) class CollectiveBarrierRef: barrier: BarrierRef diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index cbc0ef9703aa..4ff17d5c99cf 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -75,10 +75,6 @@ def MosaicGPU_InitializeBarrierOp : Op { let summary = "Executes an arrive.expect_tx operation on the given barrier."; - let description = [{ - A single thread in the warpgroup will execute an `arrive.expect_tx` - operation on the provided barrier with the provided `expect_tx`. - }]; let arguments = (ins MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier, diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 2d75c42424ef..444830a3d75a 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -666,11 +666,14 @@ def test_initialize_barrier_op_lowering_rule(self): # One nvvm.mbarrier_init_shared is issued per barrier. self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements) - # Each barrier has its count equal to the arrival count. + # Each barrier has its count equal to the arrival count times the + # warpgroup size. 
for op in all_mbarrier_init_shared_ops: count = op.count.owner.opview self.assertIsInstance(count, arith.ConstantOp) - self.assertEqual(count.literal_value, arrival_count) + self.assertEqual( + count.literal_value, arrival_count * mgpu_utils.WARPGROUP_SIZE + ) def test_lowering_vector_op_without_layout_fails(self): shape = (3, 4) @@ -939,7 +942,5 @@ def test_memref_transforms_with_transpose(self): self.assertEqual(strides, [512, 4096, 1, 16]) - - if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 9ffaff121849..8d0c56877bd4 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2690,7 +2690,7 @@ def add( ): del ctx smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() + dialect_barrier = tma_barrier.as_barrier_memref() elt_type = ir.MemRefType(in_gmem_ref.type).element_type memref_bytes = utils.bytewidth(elt_type) * math.prod( @@ -2714,7 +2714,7 @@ def add( ) set_in_transforms(load_op, [test_case.transforms]) - parities = memref.load(tma_barrier.phases, []) + parities = memref.load(tma_barrier.barrier_ref.phases, []) parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) @@ -2767,7 +2767,7 @@ def add( ): del ctx a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() + dialect_barrier = tma_barrier.as_barrier_memref() memref_type = ir.MemRefType(a_gmem_ref.type) shape = memref_type.shape @@ -2799,7 +2799,7 @@ def add( collective=ir.ArrayAttr.get([]), ) - parities = memref.load(tma_barrier.phases, []) + parities = memref.load(tma_barrier.barrier_ref.phases, []) parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) @@ -2894,7 +2894,7 @@ def matmul( ): del ctx lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() + dialect_barrier = tma_barrier.as_barrier_memref() operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape) @@ -2924,7 +2924,7 @@ def matmul( collective=ir.ArrayAttr.get([]), ) - parities = memref.load(tma_barrier.phases, []) + parities = memref.load(tma_barrier.barrier_ref.phases, []) parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index bbd6f57fd653..23e33e68e2d3 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2531,8 +2531,6 @@ def pipeline(*gmem_refs): np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): - self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. 
- blk_m = blk_n = 64 @functools.partial( From c3598cb4394f03f5417ff7d83fed976ff144cda3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 24 Apr 2025 05:25:54 -0700 Subject: [PATCH 0779/1769] [Mosaic TPU] Fix a bug in signed scalar upcasts and add a test PiperOrigin-RevId: 750954967 --- tests/pallas/tpu_pallas_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 3746214eac18..c2a28b72dd24 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1982,6 +1982,23 @@ def kernel(x_ref, w_ref, o_ref): mosaic_nans = jnp.isnan(run(x, w)).sum() self.assertEqual(jax_nans, mosaic_nans) + @parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32]) + def test_scalar_load_upcast(self, in_dtype): + if not jtu.if_cloud_tpu_at_least(2025, 4, 25): + self.skipTest("Needs a newer libTPU") + if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4): + self.skipTest("Triggers an XLA bug") + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[0, 0].astype(o_ref.dtype) + x = jnp.asarray([[-1]], dtype=in_dtype) + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), + )(x) + self.assertEqual(y, x.astype(jnp.int32)) + def test_masked_store(self): shape = (16, 256) mask_shape = (10, 130) From 6a6a13dcfe5f6dbca4155410dd32dfc2d5bf1cec Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Apr 2025 06:24:42 -0700 Subject: [PATCH 0780/1769] Removed a few unused functions found by vulture For posterity, the command line is python -m vulture jax/_src \ --exclude jax/_src/export/serialization_generated.py \ --ignore-names "[A-Za-z]*" \ --ignore-decorators "*" | grep function PiperOrigin-RevId: 750969987 --- jax/_src/export/_export.py | 14 +------------- jax/_src/lax/control_flow/loops.py | 1 + jax/_src/pallas/mosaic/interpret.py | 7 ------- jax/_src/pallas/mosaic/pipeline.py | 15 --------------- jax/_src/shard_map.py | 2 -- 5 files changed, 2 insertions(+), 37 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index e01eca4a62f4..0d7920e9b206 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""JAX APIs for exporting JAX functions for interoperation. 
- -""" +"""JAX APIs for exporting JAX functions for interoperation.""" from __future__ import annotations @@ -1217,16 +1215,6 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding], return tuple(all_in_shardings) -def _hlo_sharding_to_xla_compatible_sharding( - hlo_sharding: HloSharding | None, - mesh: sharding.Mesh) -> sharding.Sharding | None: - if hlo_sharding is None: - return None - return sharding_impls._gspmd_to_named_sharding_via_mesh( - _hlo_sharding_to_gspmd_sharding(hlo_sharding, tuple(mesh.devices.flat)), # type: ignore[arg-type] - mesh) - - def _hlo_sharding_to_gspmd_sharding( hlo_sharding: HloSharding | None, device_assignment: Sequence[jax.Device] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 9bd358c2ae9a..d1220ba3fdb3 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2078,6 +2078,7 @@ def new_cond(*consts_refs_carry): ad.primitive_transposes[while_p] = _while_transpose_error batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom +core.custom_typechecks[while_p] = _while_typecheck mlir.register_lowering(while_p, _while_lowering) state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 448690785986..f7160f5af386 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1480,13 +1480,6 @@ def _get_grid_point( return jnp.array(grid_point, dtype=np.int32) -def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): - start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx) - output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) - squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, - dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) - def _uninitialized_value(shape, dtype, interpret_params): if interpret_params.uninitialized_memory == 'nan': if jnp.issubdtype(dtype, jnp.floating): diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 0f0e4a342fc7..5ea992b5443a 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -20,8 +20,6 @@ import dataclasses import enum import functools -import itertools -import operator from typing import Any, Union import jax @@ -159,19 +157,6 @@ def _grid_size(grid): return size -def _get_indices(step, grid, offsets): - """Get indices for a given step and grid.""" - # TODO(enriqueps): Implement using bitwise ops, avoid div/rem since they are - # expensive. 
- extended_grid = grid + (1,) - strides = tuple( - itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] - indices = tuple( - lax.div(lax.rem(step, a), b) - for a, b in zip(strides[:-1], strides[1:]) - ) - return tuple(a + b for a, b in zip(indices, offsets, strict=True)) - class BufferType(enum.Enum): """Buffer type for the arguments to an emitted pipeline.""" diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 3496f77078b9..dcbdb0896916 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -687,8 +687,6 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, return out_avals, effs core.custom_typechecks[shard_map_p] = _shard_map_typecheck -def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: - return set(mesh.axis_names) - {n for ns in names.values() for n in ns} def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) From 0e26cc49c49d1fbf02246dd777fbda31506600e4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Apr 2025 07:00:50 -0700 Subject: [PATCH 0781/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e7b3e6d645d1f3343115d1c5dd8d104976676152. PiperOrigin-RevId: 750979665 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 912ea661a8b9..c7d0d5cb7c3f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "99b7c3bf05c3877c70ad587439b7481889810564" -XLA_SHA256 = "148505b7fbbab60879608b43e7d038a7e8c97ddd6e2c6f45c11aca37e348b6a9" +XLA_COMMIT = "e7b3e6d645d1f3343115d1c5dd8d104976676152" +XLA_SHA256 = "4719cec5489231a3a90840d4ff5ec7bac4ab5dde44757ecf4d283f9dab485a0f" def repo(): tf_http_archive( From bcd59940a71b1e7ef1e5fca42801a1d4769e9b19 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 24 Apr 2025 10:14:01 -0400 Subject: [PATCH 0782/1769] Update FFI tutorial to new shard_map. --- docs/ffi.ipynb | 9 ++++----- docs/ffi.md | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index f74ae9d58a78..aafe9d56e82b 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -730,13 +730,13 @@ "source": [ "This clearly (to us!) 
isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given.\n", "\n", - "To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n", + "To generate better partitioning logic, we can use {func}`~jax.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n", "That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes.\n", "Specifically, let's add support for \"batch partitioning\", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding.\n", "\n", "### Using `shard_map`\n", "\n", - "If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately:" + "If you are using manual sharding control via {func}`~jax.shard_map`, any FFI calls in your program should already partition appropriately:" ] }, { @@ -746,9 +746,8 @@ "outputs": [], "source": [ "from functools import partial\n", - "from jax.experimental.shard_map import shard_map\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P(\"x\", None), out_specs=P(\"x\", None))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P(\"x\", None), out_specs=P(\"x\", None))\n", "def rms_norm_shmap(x):\n", " return rms_norm(x)\n", "\n", @@ -781,7 +780,7 @@ "source": [ "### Using `custom partitioning`\n", "\n", - "If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n", + "If you can't use {func}`~jax.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n", "{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n", "We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n", "\n", diff --git a/docs/ffi.md b/docs/ffi.md index 97648c78e118..106b8118f1ab 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -556,19 +556,18 @@ print(hlo.split("\n\n")[-1]) This clearly (to us!) isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given. -To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here. +To generate better partitioning logic, we can use {func}`~jax.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here. That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes. Specifically, let's add support for "batch partitioning", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding. 
### Using `shard_map` -If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately: +If you are using manual sharding control via {func}`~jax.shard_map`, any FFI calls in your program should already partition appropriately: ```{code-cell} ipython3 from functools import partial -from jax.experimental.shard_map import shard_map -@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)) +@partial(jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)) def rms_norm_shmap(x): return rms_norm(x) @@ -587,7 +586,7 @@ assert "all-to-all" in hlo_data_shmap ### Using `custom partitioning` -If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`. +If you can't use {func}`~jax.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`. {func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges. We won't go into too much detail on the caveats here, but the main issues that you should be aware of are: From c11f2d1ca9ab388a0a19deb0ee4fa56209732852 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 08:29:18 -0700 Subject: [PATCH 0783/1769] Change pallas/distributed.{ipynb,md} to use jax.shard_map PiperOrigin-RevId: 751004262 --- docs/pallas/tpu/distributed.ipynb | 35 +++++++++++++++---------------- docs/pallas/tpu/distributed.md | 35 +++++++++++++++---------------- 2 files changed, 34 insertions(+), 36 deletions(-) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index ad047963fbce..f2b9562c0db2 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -8,11 +8,11 @@ "source": [ "# Distributed Computing in Pallas for TPUs\n", "\n", - "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", + "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `jax.shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. 
As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", "\n", "Some recommended readings beforehand:\n", " - [Pallas Pipelining on TPU](pallas_tpu_pipelining)\n", - " - [Collectives with `shard_map`](shard_map_collectives_tutorial)" + " - [Collectives with `jax.shard_map`](shard_map_collectives_tutorial)" ] }, { @@ -47,7 +47,6 @@ "from jax import lax\n", "from jax import numpy as jnp\n", "from jax.experimental import pallas as pl\n", - "from jax.experimental import shard_map\n", "from jax.experimental.pallas import tpu as pltpu\n", "\n", "P = jax.sharding.PartitionSpec\n", @@ -289,12 +288,12 @@ ")\n", "# Wrap the kernel within a shard_map to call.\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " right_permute,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -302,7 +301,7 @@ "perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lambda x: lax.ppermute(x, 'x', perm),\n", " mesh=mesh, in_specs=partition, out_specs=partition)\n", ")(input_arr)\n", @@ -448,18 +447,18 @@ "\n", "# Wrap the kernel within a shard_map to call.\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " all_gather,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False\n", + " check_vma=False\n", " )\n", ")(input_arr)\n", "\n", "# Compare Pallas result to XLA shard_map result.\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lambda x: lax.all_gather(x, 'x'),\n", " mesh=mesh, in_specs=partition, out_specs=partition\n", " )\n", @@ -834,12 +833,12 @@ ")\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " kernel,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "pallas_result = jax.block_until_ready(pallas_result)[0]\n", @@ -850,7 +849,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')\n", " )\n", ")(input_arr)\n", @@ -1175,12 +1174,12 @@ "\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " pallas_reduce_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", " out_specs=P('x', None),\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -1230,7 +1229,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_reduce_sum_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", @@ -1618,12 +1617,12 @@ "\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " pallas_reduce_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", " out_specs=P('x', None),\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -1667,7 +1666,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_reduce_sum_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index deed916ceb62..36528bfbddec 
100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -17,11 +17,11 @@ kernelspec: # Distributed Computing in Pallas for TPUs -In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. +In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `jax.shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. Some recommended readings beforehand: - [Pallas Pipelining on TPU](pallas_tpu_pipelining) - - [Collectives with `shard_map`](shard_map_collectives_tutorial) + - [Collectives with `jax.shard_map`](shard_map_collectives_tutorial) ```{code-cell} ipython3 --- @@ -41,7 +41,6 @@ import jax from jax import lax from jax import numpy as jnp from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu P = jax.sharding.PartitionSpec @@ -251,12 +250,12 @@ right_permute = pl.pallas_call( ) # Wrap the kernel within a shard_map to call. pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( right_permute, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) @@ -264,7 +263,7 @@ pallas_result = jax.jit( perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices)) xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lambda x: lax.ppermute(x, 'x', perm), mesh=mesh, in_specs=partition, out_specs=partition) )(input_arr) @@ -384,18 +383,18 @@ all_gather = pl.pallas_call( # Wrap the kernel within a shard_map to call. pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( all_gather, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False + check_vma=False ) )(input_arr) # Compare Pallas result to XLA shard_map result. 
 xla_result = jax.jit(
-    shard_map.shard_map(
+    jax.shard_map(
         lax_reduce_sum_scatter,
         mesh=mesh,
         in_specs=P(None, 'x'),

From 5782425422cc1bf82b3cd4546854a12d20c0813b Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Thu, 24 Apr 2025 08:52:01 -0700
Subject: [PATCH 0784/1769] Fix shardy sharding rule for SVD decomposition.

For SVD on GPU, we pass empty arrays for U and V when `compute_uv=False`.
In this case, the shardy annotation should be the empty string instead
of `"..."`.

PiperOrigin-RevId: 751011409
---
 jax/_src/lax/linalg.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index a49936373835..848107faf204 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -2794,9 +2794,9 @@ def _column_major_matrix_layout(dim: int) -> tuple[int, ...]:
   return (dim - 2, dim - 1) + tuple(range(dim - 3, -1, -1))
 
 def _sdy_rule_for_aval(letters, num_batch_dims, aval):
-  return " ".join(
-      ("...", *(next(letters) for _ in range(len(aval.shape) - num_batch_dims)))
-  )
+  d = len(aval.shape) - num_batch_dims
+  prefix = "... " if num_batch_dims and d >= 0 else ""
+  return prefix + " ".join(next(letters) for _ in range(d))
 
 def _build_sdy_sharding_rule(num_batch_dims, avals_in, avals_out):
   letters = iter(string.ascii_letters)

From e945221fcbc128879b2791e65333310484487efd Mon Sep 17 00:00:00 2001
From: Michael Whittaker
Date: Mon, 14 Apr 2025 10:44:23 -0700
Subject: [PATCH 0785/1769] Added new multi-controller JAX guide.

This guide was written by Matt and Skye.
---
 .../controller_and_local_devices.png          | Bin 0 -> 132362 bytes
 docs/_static/multi_process/mcjax_overview.png | Bin 0 -> 145901 bytes
 docs/multi_process.md                         | 712 ++++++++++++++----
 3 files changed, 559 insertions(+), 153 deletions(-)
 create mode 100644 docs/_static/multi_process/controller_and_local_devices.png
 create mode 100644 docs/_static/multi_process/mcjax_overview.png

diff --git a/docs/_static/multi_process/controller_and_local_devices.png b/docs/_static/multi_process/controller_and_local_devices.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad74cad6541796967c7d09630208f97fa8fe10e7
GIT binary patch
literal 132362
[... 132362 bytes of base85-encoded PNG data omitted ...]
From e945221fcbc128879b2791e65333310484487efd Mon Sep 17 00:00:00 2001
From: Michael Whittaker
Date: Mon, 14 Apr 2025 10:44:23 -0700
Subject: [PATCH 0785/1769] Added new multi-controller JAX guide.

This guide was written by Matt and Skye.
---
 .../controller_and_local_devices.png          | Bin 0 -> 132362 bytes
 docs/_static/multi_process/mcjax_overview.png | Bin 0 -> 145901 bytes
 docs/multi_process.md                         | 712 ++++++++++++++----
 3 files changed, 559 insertions(+), 153 deletions(-)
 create mode 100644 docs/_static/multi_process/controller_and_local_devices.png
 create mode 100644 docs/_static/multi_process/mcjax_overview.png

diff --git a/docs/_static/multi_process/controller_and_local_devices.png b/docs/_static/multi_process/controller_and_local_devices.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad74cad6541796967c7d09630208f97fa8fe10e7
GIT binary patch
literal 132362
[base85-encoded binary data for the PNG omitted]
z5)e|Y-PB%!FUooYtfC1x7zUxhd>s|m>KC4oPR!JF?WV+}9QSn3@ zo({HT@4e4>L_9H<8k+0<5(1ozc%a{zpD0;=O7?Lyvo_V|zuW%eUn|WGFi?$pY0#h{ zJeq+yXS|B*7hX%^j0;x%DE|8O*|a=Q7$6>?_`UgK`&qtrOq|6T>y38(ZAP+&FTuD% z&xGjRrEu;WTBQ=XxzeX#+~d=7rZ9)5FB4X=wNTdoA6H)))MnIfiv$R+#U((I;%>p+ zrNuosE$*(ty~U+C#flXuP&Bv}3KVyDcj?V{=G;5y-2BR~%p{ZN-D|IB$u!O8L1pKm z!G17g=yEfbnJW&%nc$4>XTXN>)&ZhuYnR)_N)nuD`F|`Ve2_*xQJc)xRR!`4{g_6V zIU%G^cu%=Cl7FSD11I#>GyE}|=s68IE@IDUi$4_?Q6yIpK}>v4F&vo^##ayNDxOAm zq=9Z6N>b;=lg``__ZO3w{+mrv);7r+;|j+Lz64oMzm67eH-$szCvNf#Tu1#^E4&Gk zaIT0o!CI1On)I^Tetc2axA*+NG;VI@%y^U}?+>P#)+3`hcd*@yKh(+f(zv0*i~3-g zHB#?XK&m8II)aTwB}JcSMNVk*`oCujiGJMNs464~5HpVDL__hJHVerC2t8RiY8^R4 z-_H>$-Zaa)l}ji~n@oP$PZWUQGk&wdBP9rJt`X-35*8Fr@28)p%vT?Vsm1=CuRM&U zI^=+7Dfgl)V6D&=?;^c1@js6n(T{1Gzn6tZDbVF^#XSmX#A6h>2#ga>x4yRM;`fK_ zMslHM0E~0D++7N}fJsJw^h?fY^FG(A^_?6F5hGHvs7aYQ(^Sb(kmR=mU{od@@Y3h- zKvaAoJEBD7NLqjeu#tvdggjntFs#6w0NYW4`a1Vd!B*yoTWpe%TrZCKN2Hr=8qZIUS_k~(4RB|CW2nY zRu$T8MkOfUvVE7|EoW_}@uYQ&%x&h~Zd$3!2&oUZWsV)P&B^}SPwW`!+ByWhkk{}- zkiU}Q`n(ZwC1<=`9@{&9=wr+p2QsIIJ_O*@5Nc?hiP|d-25ypJ^!$ve(Iz~@^Z9sm zst^Q0|NLcIA6)_imcYKE0%E(>AoV&TK==!UiQJeY!KCI3sgU?b0sF$=q~_keS-c${ z13x-d)cv*#GcjdSFXh{^vdN#>YcrOx#UqwE`Bc#vJ0fuR21lstsH&VutkyhegG^Le z_eMv%)l90R(x6hK=zTS7TARX--ruA_$iH{9W-B;>;hTA9mE>A)oE)ViKK^Q1oZ^9} zV(JCLHCIZU`%nM-Zuh5UGVN!_-fH*HXUrCiPGj2NpMSb=2B05}k+5!?o)aMM)Rufb zGrS}ncwnEWBni_f7E*a+%MN!$FHBZTG8ya2l4?kfvN`P$*6NZEVvjjjD@RvPs%z|&y_|2bwOKvG`(K zF8=7C!r&UZ{rzJ(62$19Lai3}>3RctJR_p-bq4*&Ml+&ASqSxE)HvlimQmUCsh-DL zhFWp4Bjj+gR)g_sGJ?dlHpXNwOe&@dHPl56ivj#2V8^IP2!4Nz82Ug#=S($504kea z^$>a8u;|O&t;8W_`auaHf55?xLy|BHxr}AGe0y+IxQgWEepF}i`XCoV{k=}kQ3Clue?Gc=!AD~n_~vVsk|wFZo~f) ztA(XQZGLdxaVeVm@-^|!yNKlvx>~4wi!ND$bF}j-mqE?hn5$o_tFA0!SXfM~xEl?Y2i^VM0 zs}0AV#JA|JpgtuE!-l~~I{XRMw=U=x>Nb9iJQnm*MxR3uH9w{P`c#k?a_{g}!bzu} z@IGO{r2%tfC9#dS&731jly5HPc ze<^b1-%8UqyLo$HbSM9(3#lM+T@-w@!w8|M%&pG)5yZW*-GO$PQ{euDD&wIN?p&&y1_}~lgbQXcv;N|Lb z&edC1ucR**7*J*l-H#rD`A=^m**F>~S<`Ks=6$w@XL_Z0F~N?{_};QD9{B>2jq;Xx z5Kaz0MfeJ6{{POJ-1Ouyg=GbPkJ{D!?h$8nr`2bKdRaD@f>dVEOHOU(Tdj_`436~) zlQw-0hj$|xTKX~A3dk;sztO=OEriT!jNPKcvY*QSP&ERq-?TKlPzi^6t&@#}IG2XZ z*`nO{HUfh)%e%)bNo8r)LI8Q+?RK&)L}!t&0L)IY7OvkIQ`oC1oT~aG;`H#ku@ye~ zlOS)Lenq_WXT&4TNiq?kov+-4It~WZ$qf@PCHY!BK?sk9KVqY>ljD0q%Oll=z3aw_ z3>Vw*NvOgrhV5sFtY{5`N&cwa6~#$nnhbfwJ)c%t<10=sC8mC9utVqW&uf!Smw7H{ zNKosS07*`3kc{5o;F*7t_OW6Mf7_;?C{G#3YEK;>)iJ2 z_UgjbFV;Ft1Ns#$u6%QGnr2y1BSqRoHvQeRvxMJFU@k1)iihrL`<5Yf>Z{Dv=s@q> zu3y{vv#SuUc2ey0dj%wzNBuJF=98U+lLbEHF#qDfz3x)Qk{6wg9~K)|EscLQF*MJI zxDL6@;;&2K_M>1)*>FK3WAx`r-j1#vWF3;MB;fqzK1uAC{qU_@9zBq|IdkH2U}qC= zxh8RNrOnFNiKDwuQ{s@^DEaeO_*UWtxtF=t-V%)9q;uD=t0TxGfFvQ*RFCtq?SI`@J9~JDZPBPnt|D(vAjcv0Z1dg_Io>!G~1rY4A!tvls;ud4mw7r!T=2xgP zXYwz~bJG4_hMI2XJGvuT7@qgGwqTcQz`8ThnUDg5!M#|nXxv_v^5>xcO^nzX+l{Yh z*Z(Xj)L&54-ws0l72rzNIEy{>gs{MMB+$%oelaRZ@+?!p`a03;Cr2+$k=Et@b3lTS zv@Gf0>MW|dRH+YT{l_V(`QVMBBN=VVef?#zQxguC@&I%%q5BQzyy^l7K|+}0CVtCX z?-p}+`o-hO(vgb)-oQy#X6V~ zj^aS8rI1M`BCY27T{KtDvNkQU;Tj5Vim4s}z>kJKPCAfKJiD6(Xqgoil0Rg8JcB-H z!_VoUb=8N;z~mrI3?$>%JJ$Lc)n0~)V=q$BR`cdNRR`N$>R*J!gqZ&8V^KqL3fgOl z2tx_vbPU&@0_Z>Yuf61mK0aTk*nW>L)8y5i+B6O1p%y6RNp4jEbtHwktRXxsut&$J z$?ABD8QrW+gMp!S5jqL;pYTEx08j~f&=nbC*3F^R{dOH#U?hkB?dPb@*B@9m<>W^! 
zT1F{og&-<9VRRSwtH>-PfNH00+zt)AVtub5L!t-u{a@=3c1$3)VKmd&du7SYho}&) z1X9kcr0wiolxdTlHovy4V8*h&_3jNwelz&?W9w#%4s!#5e+*nq$h%Klz*1d9Z$1yz zy$rHiC@)&xA>G9ky0>Y>T8mFz*YOsZ_%e9X2QC&RK`js*+H2ON4mHwoZBamF)icE4 z3TjT;*<}1?U2izq$jHZjrDr&cA+z_^p^z7&51ga3T&L8JyaXgYD|z#lk8)UY1YKR6 z8-59FeC147#)ysOR!8M(5{|+1Ju1oxIcvI9+|8#*<0jR(ljRQGBDl={UDom$QJ1m` zB=sV5HO7fS*auQYReOfLKY!1=kQl}OJmKn!R14_EZ6{~4^@O$No5 z9enTn7bN*gewcOQaM#j>ebFuPB%~P1wL-Zi2#91z61;y{cbb`dkN3=XDWO%oIEc_< z>}t(%?)uB_8PiK&leP*sgz;KdmHPw zPH{!&1+-+|M}0fP%y`qEFnDc^D!D}jaJmZc8sSiuf@ty(iSLC@v=s?myINm=w>9d}pF|^2-`m*6matb*bQxYim zA~9+>DUG}y3;LZ{D)SUWfkc_Z*LF(cto@2Z!skIU7AJRrX^@*o1Qy~O;76PoDoQ!& zN|K@RW0TGZw7Q#J`71cuF~cM|D=b@xGqv@rq-~4vLp;HG8z)U|Ow`w=(ybs9Q78O= z%im*+MtVBro_pzFL_#_Dq=yHRzSeKzfIqO``yEJ+YDFJ{A-h>&;|jTT_W6v>Lll-v z+!789n(Zga=w8g((x)wA#xF%#-us|fzo^|%5@KSAhX-XZwsng>f&4Q8gQ=a;v+@J_#wN5acL{1$fb+sF-4Y}(hK;rIUkMmG4W)uvN^ReaSEDMDZi-NZ@K5)Hnt z$z2>gttuC9>@o!ZO}wL)h!~KYypyjiOfC`EugO;D`Ar*=`;UM6Fj9``3PU_e&Ge)G z{%vY7$$o6`HNHY{#*y>N_d0ei1PJZ;drwmP)<%OUF!O4?1FN+m=e;B*(roNG*mkZo z|4LAOH(79t&dih&J6$xNqx=)!!QjdjZcyWYX-XvE95zH#s9!R27affN?={J}Z%P=w zgsoI5fIaaMesaZz)d3<8wlX-)Ks+MZ#&Pb@9~i847v-F{c%-YYFp|^dx9MxbE?)6z z8dY!2Nr|ZwVI?l9u56P84kMrlb$p4lF^F{;wP5MGlP4Boghh7TXh>$r_{pI@s$d|t zG#}qg?~+7vh|6y}Aa4Es4>KYf3hj=1-#J=`6_#+5k6!Z?r|~B&0`q?JSUfE1nnDl- za~&*LBOghE2qObyaeW247?q=b`4+cb(Ugc^CvB{CE9~0SlqIguoI0at9w>qJv(o+z zI`xMx5c*BiEHejT&nPFrK=^@~l~qc|6Z^dzi_5mhTd#viDS0L=e{2|V-eve{J3L{8 z<9=KmlNk><^AkI&-jXjqIx*)#=A{J8CjDb0S1W>d*h(_vaf>~j*ANw!UP=T#V-G$3 zc&2FXBMjiY!?{uMcE%ME0TX|q-!(rOv;-5J{=E(@(_!YpWHrG4*tY(;tk!C0n^u^S zvR`Yp97i;~UfY{EEm(QKjR5BA$9U$7N-}VOB^Hs^H zb;yv6Hl^FhXq!Z(HX6IMXyFlMNHeo3xcc&o_PWtlk-+F+;OHp3n-r?X!{|0N4*sV#XQkvgrU#+<3^WL=NMA}yh2}3ivo5Gjb#d@#yyZd*1 z96=9vt|Avb7z^jr-Qs~eVk}4fPc~Jj`AV(fGUy5SD1)9Akiy47E@x)zzFy_zU(z_G z57E1V1@G)|!Y43UXFhr3H<5?G$J}sBBVXVAJy`pkd*K|UuKG{CaZ~~!h&bYX9MK3) zGlcJEIqC0zZNCR4@82Rq0+C>gt;N*<;%^(2ohHzwe+2~%%;+&2 zM~sLYnWcNBJ9-v*l@eYbTP*inaa&w17A#s{d+>D^Xn9$q9Q}=T=_H*3m><8aTtxp3 zEp5`ZM{lYzZr>j>W#dePjZ{s!eN4*vKyvh5p4YA*ZpZKGaf60(gf7BkPQ>u5W56M0 z7-4LkdOwBHZKdTCyv?@yINq^xV*5B`JZfw>E9Bgj0tv>>R-_>1;}JkUJY2C)BNR8h zLA&gr%7m9IALJ|&$NXpL%oD=rn`c0Z*H_FeO!;sSQ?8oO-Ty~19qln%(mfx9KMp8U zVLft@ox^9;rDh)6rDI{lAU<`SFQy3#^pg@Qpn=jeA%mRzBEmvB5k2*uA490bb?SPV z+ys*68hTvmd65K(@8`q-$$`oR8seO7Kb8)Edo%^v;r)fLLtZQmlu1Tco=nf!qReBQ z);eRZNOl*9N2&7ij)nOY=3z|OIkO#E!F=pUDLtG_`6tX(`%Qyyh%Q3CLU7 z+3?IAJg}}hu-PzFJ7hYP-IpqY?Sh`91IZ766!&N(%DeiBwl(y+en~&5QW=;fr(PU! 
z(^3i2Ra9k6jYfMGejWOa5$>&t)mU9T-2E=oS7DV4Z-hxWnfI;fpNl?u2Bzo35Ozch zBZv15q1mAfL`~!8Re|aE$C;doOdj3(TyRGzjCe&t0WT3>`A(^k(`_{9-^(mq=4d^j z^ZZ%*tf4bjnem-!Wo^+VL7F6y+;bEB?ML1mV=az^m}?WVRqXq|-;Ncv1b2E9b?YVY z{+a844nJPUFzp0nFbi?a$5#at6nolo7I?p7G>#OC2c%(8t{2l>w6Kir=!4k_7?+j_zuo)ERbvlq7;;gc#cE#hv zYN*rARNa5ReJq_OJ8n2fONwvI5M-NWYIAM6`}^M`1{aucp?uH9=LXLG-oRZ83U{tZ zFBMY^vEc_!?nl$$DV+_e^CAwV*;{KHuu~35(aRb zg~%{DXs#^}gECJ81Ob`!mO-fl zk93SYlmX&2!tH?g-KZ^N*S8ZHa>cZ~1bHiB<<-tn@eNc0teLRYal*mzXHXSAhi8*Y z4nlmV1fS{)96ABG>z^_Snqu>=GPDh=pM9qG%tMUsDb1 zB#40dZ+S{UV>Be;F2ew1NQ`$n$mb1KFbQzF-))Y#rl_@)?7QRlGs$_uym96syfaFk z>+tYkg8PVwa^g31^t|DYv0A34I8s5>ET3CRx7=MzSzNVdc!y<`TI6eL=RWh26&P27 z;41`SGlImcQvA&7k-qr_M7~B35O1vuI`p8ZedN)JRG10U(JpM|H40gJ-AdNF>q#u; zm@hZ)xoojcjQ&8sj|IrvDeyhu3zX6~NQj7_Xc3(WRe=c^*F-aP@gnh}7^~u~#qYRG zrcFrMyp1|>$jHx07`bL1!+>#{kw;7w+0Bsb6*T$d4t1qnCt>x8Fp(c+16 z-ed)31SNuumpwkfXDIMD{LLChnMMZ(Lh*e4zPvG%^^QFoCH-WQeuDM!VJl&6F zvoEtUZq9XX)L-*4N*=4~#H5MXSkIJnSmq)l;=JoP<}rg$teT6u9@i$)qbHfsf(>Ld zKN`ARD8Tm!{}VlyR$d9DQhlo%cJWZkNsy=Yx1oGuAJkj^_L-kzF!Zd2(^5R-cgOub zioGUWko8aYk+!@E=eMb122Ziy#WTu44o4^}6LhOotOgd0$Y!-vP7#zsgyoGYU23}CTp@qoUJi2)WY1(u`hCOy!+&qMW8u_q?1qmA z4u1vBKL%iB)U1SPmOW zur9K4skYHUll^XaCNFG1Gk3WM`x&*SWo_=ow?>G(%aiJhfRTlS;v0c>xrN~0{CYTkUulh-xr&gvDniIWl+v_U~ z^dHMCzQ}``zpitgO`fEmh`n^HT+0Lm5XHvE4uL8$V3#}|GZ@^n2EPO7Zm}*BW%~s8 zD!%#UWSXV7Ps?;e>M-5+cD0sfB2(iRVoaenk{WuKadWuM*MXeusF-b#r&uOwLn3CE+>qm`yEitw&6t3OvUc?Q_00*tm2{GMIg33^J^s(VK%uKa%MDgxYlgvAU3>jJgeNkPcxzsLf|VkV`3ayuK`V|ZcXjA0b;*5aoL)tD}c~CRE8oO?V0An#y0G%ezhSMM=4tc{zaqVPvg!Z zS)bl+ULB@rPexCFeK%h@-0$VcaeL@mrzI95C(-fTj!#IHczF0xTU)!Jy@-JD18-RLkN@!M|s-Jrl?utFiu?Vk*gr{0cRIT)0JHfN(2VCF z0GQ>Ng#0ta<%m&ho;RT(JqUQT7#(b=DHZU+Q0{G|f~&RIg&HfXN;;>5l#LQEvWPjiVM!s{ce6-EHaR#h9a1UXEAP z^c{I`my_q8F~iqF)7Fsy`di$zKCs2VWt+RqU}7ip$bgy{gB=$fMcF`=yFX#0pDP7& z-Bfz^E_LS{G?15Y$skc}6wMVvn?J2IuNDy99X`M%(m9gnaV(m8W>MLTezJ_08#LLG z!&z}-U=8#!V}q86idM|b%&fu5d%i~?*PK8zNSvFju&{6p+}i3}PIso@?=3+liOv0w zNLOI1#abxHdkDMh!}UO0@9liM94Eg#0nPf&^_EXWjbUB>$v=$Yz%}S%lfm%!YCrLoj=73Q z_eaWK<0Z&^V6{Pztr)Fm@aQnaVZDqO_?iRK^DxW(H(|1x{;9?pv8L;{8C=oz;LI-H zqkdu0YS>{a+Hr?fYuX7b{-#lwj5k5oGx}6zCXyTa3z5b*=g{DM*KG;A-MD$>@_%zP zX#=H!`_$I~wx!*-fm$0YHkQY)AAMGBM4&gP4D-1}q1Qdss5Mm-Uv`R?X!DPLCHkDS zr6r8kDsDA&l9Kr!PMZHo0`uNwkXIbtoza%-GS9X=(Bbx#mEpEw&4c)rg~E$*j$O`1e4(NFL8vOUa;RQmh; z`oLY(pHWxEh>63rQ@P+lAoN<^LE8)38A#-IZ@^4g(jD^!3wckLSgsgfUjwkZ{fBjQoZ^xQ~Icpc+IVvPV0t4Xc-b}x@mN_;+ak1k_&tEh3 z_ItDe4Zn$Y0nqP(1TOk(UnJuxwr|?eC;hje6-+o^=1T0gqu2+L>wLQeyw1`F6Z0C9 zSN!qZE@PAR#8CrHe-X8_&EdkUGQp$tFx`hhAgwvYXOsB#d!pUEGf3qM{o<^9XG~yoe9QP0abTc7~18i$FPf)f-7ak-ShGzv6#uIanh$?=6XV z8KvXWH={GcS!GR@y38?K_Ut<4br|enCg=HY!(C`FKjZMrD#@F&z|^7_D3R!g;l-QH zGnoK$4mQ8Du^l%tCgkFPIdy?gtro;ra{+H@s-?t37g?zU^}84!z9zQXKY!0j)7HKx zv=y{Eu9Et71Z2Hw)fIOS z4MA-UhFR`h9R%`snL=l-Hy*>&@E7h5F2cX1w`9kowAjRl2%&SkT?DXH+T~vkFjKOr-#6W zeG2-V-lPXGk;bMOI(l~Mv^NpVoU90fZ=uiQaC<9XDBS3#o?`%8hM^)N0+==K&3{Vq zQLa!y&a7~ZaA^Q2;NF=5WIvnJNp8>Z<``B!4Plw15<=}82Rvfx$gnUZZCzbDK00cD zop^LrFUZED8Y*|-OL&X_nfU2SaP$){^=41+pem6AwS;dQvzr>|+pljl89+%hqm~GV zHs7evRN`V*n3#Gv@aV~Ut#OK+nf*;#e+Tek?2wBf>C%&oZcg|(24gjJ(a#t2Ij8zf zP{$HwWhB6>Iv2eldrNQoD=!n_9kpmrj-Xk?V%GB-EwQ=FVH_51-Ee@)cV($YJB&xe(u>N2v8H8zoam$nSHKS^4# zl5~UdAmG#p)OmVBf90bYi<&?j4iHzS`@n#iMpfhJ;M@X;6)vW~2E6Rxhm^K14vp<` z`}!tH4=S^Vr0nR*I3DiGuYcxK5=nilzKUqtqt7xmoThE@DO2dg{HIiuOyOyOXVcYe zELmg7*&75Z6`IRVf5?|uRskAUXiY0cfUqgn%t7cl8?ML4C@z_HnPQlSd??3mNluvdtN z6$J1gC}4@YL`z<9sF?-GzjBZn9OlUgF?$2iqLYY#z^d$4r3dRKHl)qnlEqXBsejW? 
zEQ_xiral0jy|h4rw8$kQJ~6#prD_`b-~8t}>20;ge2Cw-KaQ>W{!HI~^4Nj}La7OA z;HD+mP{-*Ix~Tts4Ho(A^xkiX!t2kVQCE{GvloGG>{U~g7Vbb>6rm6Q@kV^ahR!9@ z9N$-jod1^6I}?p`S?7_464s5MMrplg@~~Jcq?*R)y%(T#8-+U703n`f%v80qR}2St zJHIO=(10ps4Ozsr@)-1rJ#$4Zb;3>q@=l#i8c5hH}QqeF0lv+ z!T|wDE}DZh;P}slcI1hcWna(tQ0dvHNYp24>Y8cSE$oqbc`@}37GUG!$hqC-&E?gN zseF&i`jjG0>XI-b2AH2Nnva_w`5vcnU+Jh`bj{{SQ$&Bo2}O&%WR}_|!j*OLE3Obn#S0kX_2iFNyfk@d=@Nb}ZG;Jm$M! zfnXYH0cMe{Npz-p-+_&SGC{8@p^ZipAe^2z*9ekdTj%nnnnSlU>76`1$?%F(Ep#ks&cFAfKhRzLEqD9Y7;etqRgl$xV`c zC}+g_BT~9<-KrbzGS$q8^*G+HcO&xh9@Gw#7;XQ71Vza!FRztpDkr7miT+q!E9<2z zLP@n!Q0ok9Fa8z`S2VplKF)KE4LPTNaXAK)VZQl*Ga>Xu%M%+uY0fzh7U~&UOvh@e z&fK?j=BbN`(=hk#X>WAMt-tHXvF1%DhU-uhE@G6%CdTU(Id5?CAT4b!fm7g34|F9% zz|HT396>Uy>9eW|XG?qhk4zm9MR9-kJN;b7*U6tr^b!&TW$GI)*9p>!9!R-ie;}uM z7?{7`+8IJ2ISt4j>Q6oM0+bC3D?p$>Y%&*_{e#3?>-5iae7~RFPoJ$-Rr4*U)_YQv zPC=w@CX zL#^ z3?=mNUWNBhi*5I_J3~t%$y#=CH#|P#UhA99qKf@7d*`za6O{q1xzRTGMcB22UOJy! zn^0k&264TrL(g$*;Tyfex+gmI3%^3@g~(o7wPZ|(?x24n$5*vIrfV$6cU3TPXWDIK zs3t~b_r(wfAeIG`of72yYW(;k^qhw~*&rT6rBYrKZ=%9U_LDn5#^v;sK9g)(U1s+u zJS^;SAdo`FXfPSEKIOt5?#-0$0U_`_#>`jahphnUOhjKc_V&<>!2YOQ0s>wxUS5#o z%B7<4AwhCgsFvf4Tu7!+?w1T z*7)mN-vSFhZby24eAS1C`L#Qn&IillMrk*pH%4_YN?>xw7yxeF)(6HT8w~iwB{map zz&drd`II=6$%8?3+2n=gew$rX_#CSXVG(V=&UAiW^gKUh;`pk`vQfO&S!C|(|YK)4aO^SKxR;29W`R;(^%zWZRXOL=2=*(;c%5qyE zh4<+y{cDPkLcsWQuMJsw6!rR_G}irnDo~vP;@P^-eFP(&a4|x=1L7ZeCVrM#=rT-s zn>c33y;4(u3IH)7VTm4zX+iJT{Oc-bgPJd@0VtkcwOS|D2rxyZKr?H>Gt-$7C~B?B z@8SF|Fa07@l?DHf*-Xo_1%kX$N&EYxyZOta;-iEd%p53P+R*!If_(#2^R6;XFPgiz zXM`(mi-h@=Zu%o*tN)TZex4^Di9Egh3tKbk@;{4n%ob4lgL*>RKL&q^7N?;0@mbi| z*fe%Gu#;ICF&Z>mMheaFIer#6d1`)HwLPn+Txc-*mYB0^BJwG4PM9}4MSd^9RwBRS z_m6x>VgSFGO6L4SH@F#(``s)vONFeBDgUyCa^0M6Y<`vR#>&bVST#F9R$pZhVPKWp z==rjxS?WGtgH4w`w>7k7*IM67Lq@`rlKc;Td2AH3F1BX%VK`4eD?_Z#A^OTBv3HM9 z6Jh9VUO9NVX=gbgV-a=Z|KW8=>Ek4VkT`0!<=)zty5VdQVrrEM7aTRFmV+e9N6F%b|KORB}uF>cQ^tPl=cn-*2bOu+il zfn=v{?%nXMv{>pQ`1p}~EtDWB<@WAuwMuB5&kB5W~oca0p{ zl|Te~HkZAxS0@7R3d3+Eg(_V!*gM2T`HM;`2ux0HhJKHRb_H`eoEQ)BCT`P!PifCA z&c7fagdix%N@?Tb(#-q!TO_3coO5F9mqc;WowjES@BmTJ&(6iknauL)iMjb7yhsYG z+X=j{LK2fI&w?zv6~7VwQak$3@| znRO+mAY&$H%u1)GMl>Ev&j>c*LA?q}J93`Beg8p2wSv5VK-0g!O?ie#CF+J0r=U?- z*M>F{`3g)*rF)-@C>93-K)?ODsS2q`kO+`Qu{$nxU*L~MM#~TaxSySaF>QTqKb}q! zWG{hJm=R7OgeHtu**VVrzi|MiwXAms@w91N8pK9HuP>B9S*}r6%U}L>%4KNhc(D zP5@F3dM^P$-B`z{_o@^XDC(__c0$h<sNJv406d)a&wmcw{54NMngK!o`eNXH`-%*aGlU6!3IxcJP zN*$_7TvTl7WBTn|2IsO^d#ipMvWln6Wl~f}K7a!e9T-m}!=ff_W5XC_Cf=s=p_eI{ zy8t3_g+X9Z8%p8zN&kEVyG-4*?iQb}s)$0#mp__luebyPY%<{m;Gs>NGdU5fIE5R6DiPtPqVL|5%R4}aP|FNXloVl%>K!|E0k&T!&Qn=CZ>Ux@Sm%r zWS!;cu_8t*B7w`ncddYG9$^aZ;pbHJ;bkiA`}`9UgD}QfD$eKNOz+MM4aoz% zfVjt&V2k|#qRD?|?{kx2VU$E9BPqFWcZ%D_B&fH^cBydRo11ez%WSRhXCNs2ll6&s zOJ~R|r>x9kW*(8JpqrTZeL`1TTa~I3OSV4-B0p!Jl!$zL_)B_81cCUp2dF-z;HShY zdtjYBe5AP`G9fmtxN3jwJ0pR;omIAKe~DA^2}y??d5Gncw5$c=Gd;u55W5uek!+J! 
zQEBlW#Xmft_f3MiiwDLQu5PmB&g|vLQlZ1TN*xWFUae4j;US9kmgfF~{Fh)+$O(SJ z^)j8+4hiVpH}b3q?CpHUh+quzWcjZdR_SycDeJ<*5#cV)qBL4aWYR^1_&KW7pnuR< zHy=;nE=@Y2BMoSWA9}Z-Y`%?$fo#b;FXRhjkjAM33BPw&PR=w$NXj%mZdU3`0zwzU zKr9Zj`UwfZPnVP%wul?ck2QDS%F3wi!}Wi`M7Ew!amAR@lwH@m5m#TUYipgu&oOn^ z{#Xlb)2t(i2JfOynEZ;mKHbvvb=kW-xTt(1-$4BK&dWhq0{+a9^6~eFEC~R+d>9(3 zcRj7Wu_5Ve*#q(Oj$x z{YV68t7x?glc+#66ck#1a7LJ=`1>7{66q2x7aQfk0-{!nd^J3}j6J*-n`v?C|*C(Q|G@(!EjudT;dAKNn2p%p1ktCLJ>Srxd%TYkp`0im`$J&H+7RZ-Yho z1o}hehbM7&teZ%w$)gf|xyEwo%&pH*Skav4TNPk}-8532e_igdBRiWUJgSijAg5?= zENRMfe5A+wFrOM7){jW~5ED%Z12OfOj~1nSXOLOar_Pg<6}h#dlp9+J0#(Y`~2a4QX;40i}X|cGSYT3 zJYRNuJ3sfB?UYcLJRXt5NZDv3-@{Kl{2SP6ZyO}YnbH2?yzZIuN*w~hzsu1iO{%K{ za;z2DkuE*-%>Si&)TQ?M(3_?awrtG;ZRYjuSC`)7NfGUBFFsFV@m0!dFP7a)eo#2+ zVR1ObA-L2)GV9kE%-U%iJQr#2PfJ(^tHP7Xw)3&i;){|*w2R}OM&MY5002(sSDumIjes6(FNM}*-ApQ#Szm?>Id z-OK3cF&3-#@o^b0XXS@r=@n6gfthzIsij5rBDmpVYWpYU#Yb9_;!3ck-55;)%&>-Z z^lSDN^V58E_Dda9lcbhDA|Zv5i&%)L>H7;9_!s{Wm|T=U%(H~A@h0ECffBA_>X%zR ztvm9)+EoJr!lJr=8NktElGfg)h{HB-bhSkM*3nPbr|_hd6@M(2V9&|2G~7kII77U- zKb`;o$u`+Gc2&*RzzA|QlXlsmre8hYj}be0lI=MXN%uGi~)qd`uYky=L!E}3u+u|FL2ma%U{7iV>W_E`SHKt zs{82mQzA-1D6JYA-0I8*alm}X^+vxjInVAQ73xUk+{oIDMyKYq;u6c=@B5W(sktLT zKH=x{*;$>SQv{QVG0}z$hq@;ddG*+H;S5DqegXWNVSlg^*3sc-epexFVuG3N$5_Yt z;I%VNx*z1a$I91HRj$`&fEAgvXY2!j($mqkc+c;zZLLD6cy?*BRBtE28kfK22km<$ zj)e{0^AIo$;}O)!VSd45CAva))tcf4qGRtp>~wY_HsZ=VRO1z=OC&n;2UzQlb2LmDPHpSa!3ILO;ft{rrZQ8hAfj=#Q|f zw68A;Qm@5Db1cK^v{%oPq4>NbQJ9Y_!67Tp)S8wIbtNO-#NaR8;cJQ5!jEsz#V~O* z!BO4atZ)2}zpPW`pgbX#G5q}Y7S`+tUttn7u}R7Bj4=(^E1?yncQaC6ixz765{0D| zWt-<-Cn0W(SKw1v2>T}_Bs^zIar$FZ)cactnPO!Gt5j^^=o9Najw^Kp)2X zDHT?r3a(JgLr;vP)Cv!@wqcjj1bq$40(h$pYOTX=D56%3LrXDPz*O2m*mQparFzR` zP9s9RSY?ITOLT!hvNy=muxT)!pW=~|P5yK%Tnv1(nOZFlfus{pt;)H;6A*e8ZU-i3 z!SSIf=&Iu!7Ew-ISYt1EOk>b8Uw~KnuioxIG?g}x$wtDSSqaq3c8_Ruuac}C#i4Sj zP!j@ojBqrbHJxfEir_ze;1;3pY&4}WwY0^TjcwKtm1F=%A(;zKDvTaP-Z10Q`AAho zS-&KJ?wmDR;KT`SPGR@J8e1p-3>5@IzlZ%W-^u`1h4A?RN7qSd2Tj#f#Od0d^9eGt zvr%U}<2ya6Hl(V;TYT4FiGeoR=f<$K(50l1^GH2`>lW{?{SwESVl)Ha|nGA@MhYc*??%uAtTm!?h+(~b< zB4m~#CQ=G8;N?5f>QyrDwV15>0~Ohi6bub$2ri#!$4(f|JU)9bDvp8KqrU}aHikWa zL18@u2>s!&v_IYT#z0kIocjA+F=l;rqXPdJr31ecfcS1rx$yf?l;SXmXjI`cEL=vg_a1wnli1tiE+RFRWp%A<6S|9qb@tBp36YM; z2oUIlI+xcpy(kfWS;h{mV=32hOHXg>@FEg$S?g=gy>q?!M*$ArCS30cLf5Tb+hG}P zNZ*^Rt0{4DB;|7wLy1?0Hh*~`LZ#- z3NqoN6%?d=L7;pduE{qpeE`2l|`6d_EQ>Q1GuXVZl-K4%T}-9O<+D7*(ic^?roNP#@9X zBgrQRL$I-lWnVACGc&RLqJ(v8Pr>1Xy;%>#e;BcYq|{x}K1Q@SpoOK9sx9#SW?@wt zA9qtV>U1L$7?Ds~a9(FPTh;N6_IwBpUPOzE&QL8%2d9|mMGnN{nsiy(zjD?AAkU4^ z!USGaezaAA0p7{6f29PxntRL`OZsvZ)y|>FUP16-c-YKoSE8W3Uckthz1pAUgo|uM zSa2ppf|0YQEnfYRBB923WIW^QailuY)xv`%XGJ(gQ%z~zPR<@ArX)5>^rbjf&Q59T z{HHp6kUbNu7=K#7P8R`FDic^(Q6l(OVW37`K&T)|%BRwE{0zvr7@^2*g^b*8Xt35? zkvw(v8D#@Zl!#N#u(-mebn#iFpy&UT*!;oZT@&+rx&& z-YDctz#bt?N3f^Ic9vAK7|Vc(cy?qFt*t~Yu$JOeo+R;t!Bbva?0(5wNUST|@>w_m z+(%$}N5M#;S&=#Ck@joAs$dptcd-T=EYS!zG>bIH!MpX_3iZ?!*!QB?*A~c(TXnRQ z)j0>=1qYHjAwhc-Fd`5?k;LqMj*el${Q8Ti{D;i995`lXW-&awgV;Q?_uecbFuCw7 z+udF7v|gI^E%U3g(z;9p2u zM#|t-$=KM~%*=d3D3mKLBON=>3wQh%9d~K$LcmL@3iWL+0E$Tm7gU`Yq8rb!XhDD_ z<{+`yao^8RjoQPR*-Il0jFw!Ew(eEbLt2PYI$q%3Zto){HNg69hJ*4(QHi6pAP?JE z6esyQpAUH`Ha6WK_HSI}K_{rmRq`pIbnBlK;F-Dfp;Z4U^bOu?WChRS+hH9X*sY3? 
zy^j9=y}U({VOXH>$~-c)_t2;GU*Fa99&4Q*y+)cQIuI>1E?3}s;<(u9A3RIzNAcyK z<}A{l@_51uWvJ%S6&{TPM+i-bx>><(KH}EvpsMi(IsMb^Q{+3>r%8%KN8|qA*&h3a zc``v>EL*d`acEe=d0Ru@wc^ZHA4ac(K9APWz0IYLD6~6n8Hxp%Up7q9+GTyKOK%g*iglox&Ur%0daKmaxzW*jvbv3!OSDyFc^t2=JSM z%Ul@L({ovLKP}LnLWfl6WjlDYtTr0jZos!z6Smp#F{dq|s4n}@t}}L~yeq;d*d0v% zyoIXPsIHY9*Vd*7qp6=8nRfSvwsHgnzWC`k7fmw(Wa*`|vLe#vc*AH00m4YUto2QK z`tSKsKeJECsp(;`l8Y9pQXg+oB`u_pKYZ(_K(NTlGaYEyX3^{MB+;!~>wQ|NO|8AK zXu*cxFtv%ZH5!)tGV##NS6v_$*FUUc7IXmMjx%iNr;ox%c5s1PMX)Hgq#RTjpQ|=6 za;5}$HA&j)N6T!`0EgHwrhN#v2@;sVwK%aa@}Na0Y$T6p6knuI7X^t_c>&uqe4L*P zPdMv1-D1s0bh4hiF320l#THU7YF6H^drMhO)`Q-BQpN-d*0lZ*jVC*&AFpA z9fJOmB1(P7SZ2q4mX&~$5cnL?*!*t!{p7SjL!?YHMqTbFv*S!ubSkd#msuzcQ@Exr zsdrVwPaJzu25C9D1eu1M7HBIaj3ceQ-n8-&*Ee4l!VwfqOa)m> zR`~wKZ=<;4T%3{0-|46=^I85@ny3ED^2thC+!1P^M|BqO<7~M4$1-tXfV{2EN98+! zZ|=JFju6OUaWk%O)b8$;8jK%T?`o{;yU%S%zPlwnSs;;wSFU2=_5gdHt)2JGaB^Fa&?iEdL&&mM%sx)pd{=O`Utkkr()elucnziDYNe0OmLz(7r4 zAm5mJt>dYTmIg@}E*i|P61ll(3ARZ2NSznbI*wZeB7irxFw9zh>KzflbkY9=!b<~s zxS%@9ctCM<0h*bv$BQMmMF@kZOcfW}XoU=X=425)R1K&rHCGNOwc>!WJZ9o@mw}%% z$_=q;qHWANHy6I8sne*tjFRWCF<)#4=cN|)lA0AOWA(g7@q-Nn`WKG=wgqz$D`tLZ z$?J3}`M<;BR9#YqPrjHdn((NF+Gp}G4G49bEJARz%6W-UpP@n+3E(WLBC_EJFyX5C z{<@LMb(C070G^!`p6L2w9?>wbJ*#1NRji6X^cW5zZAl>ZVTsG~$NQq1QY4C+8}f*nYW){8vaBqFzTRmJB>6HFlppUL|G`_K-NU`I zqX0dl7{&Gus3Jm=Ib{b1XqXuY9sq+F_Ik{G8N$6qdKuqy@p?KVh9Ilc#K2D zaowibSU%RzB&l#WeGQ zrkU#>qW6s0DA7O6${T6cml}}IPJOYSZkc7in#t?`dFLIBj3QCB__Ufj)(#Wfr`|hV zsv513e;FR0nf*vnxJ7=cLJ3x7O!``lrjt9 zd4@j-KnC<@isIwqAcCZ6xf>f%V-+bgwk-VpOvDF1*#QPYzKZjNDxtVaQ#?QAAerStS$aZ{4@?L@Yv(FQK1do# zYjyEvN`=WC2}Zlv6%PR z=g~NWD!Q`egm-vaYJ$If$=@vB1-L~0r#Eid(>PZ24=OSz4gKUpc)ml_Ty4jWLdbl5 zCExm=!n^GXJO3gOOS%&i1B~tz8C&r1lT8Z#`lj!7yyZk(sv5Kh8$7{CLt6&ClQ<}* zmw`V(AnIqL+5G1En!Ov68!%>Eu&eDd(ugQpe{M(?tjID`{_-V|qJg(qb^d-pmIH7f z|K%kc_3*yg9Agn_|AIH*OUKEI^kK+k(SaO8J&o^sxH1SzfEgIgxxa+uk=-HNJ1Ih0 zK%*hcXi2ST(yQ>WxJ8N)EB6&wPxzsvvL<{FiKjO&PrTO!6nqX<`1OWzCw`xNblUX< zJ6b`2rPGnYV%%u!+b92E!}#Z&k+U5^;7Y2*focO!?`Du?8^raG6!5Cym-z#1vXq)ZVyP7!ZVul+8%0O9}RuV*Jfbg*XsuaZWmz~g9nlKvCg0~1<|pz21iaAyYXy<*n1$AJzmQW5^Y;tx!rt*>P{_;>!bqk)u z+*e_@OAHt~us-0Ke&HBG^UK?lj)gI^Jd{bkr|Mpswd*C^_Uw3QJk;-Sr{;UR1NqB>Y+OJFA~f3c?)z0TLOC&tpzJKb65PV z7!fCjNl>>XZYOwhe+{)OnzTepdAT#dVUEG{@-jq0h~$@!&O*^IFm7`K+YjIzpAW52 zE7GU-W?LtADH$jBZg*Ka26I_7>b8UdZ&sJVJ|}L{$pJVh4)ggWl1Hk$x$NzX-X>Z*ER|I% zr{v2jrk%8h(pj$Mcrls!yuDrzu$uyT4>j+ikBi zS3Amo10zN5RQ`*S#!t_r;ca;L>ejw0A%)5??V&uxj-DY-8`hnQJsITo5nEAB$Ka!+ z(L`{taUjC7NQbY&?sQ6tXa3~&;SJfU1ePXjTv-E|y|#Nk1m+2qw;}YMzOT;f?!5}L zNR4YyDp)#Kby|6My?@qSx*KPFr~`H^=+GAAbwX5YMTz{G9wKZ&I}qdX5YaJnfqICN zKttDx1`YF!SDAoObO&$l+KE(Z3h1T^zAq>jxBVg~$iBbqAa# zW9<2Z+yWaW`sBf8&m^gnQ0vssJtsY46cN6{p>20ddtt=rAv`Q8!Jm8hwd-9;V*+QJ zr(K7yS$`-GixEh+X zLgNMWzWVxkgEQG&5x}BM09}o-rFMx>W|hC1Tsk{%2bWpfYfKO!SD1#Kk2V5qSY3&(DQ? 
z4Ih`_3Fwg)7nUo4t_yJw-`Ql0NDODXK_PP8o3Y^fFtA5_c@ryzv6ROhG3yIYLq9K^ zzmi3fGX=Vs4Jw8aZsi{WbT2X-EmT1|StgHeB8(pfORfqT#<=ew<0nWIJ znskC6Ux`ob<+Mr0&@--7v=DQ(^bJBr*ZEmhFrb`W&R@YTqKaoio+#)$yFVa6u=9t$ zF(g!9xizY~6YhZ{$F!Vo&1FJyHSL9nhDeHWgxPbKW%y@39^B7E>)1o+Gb;;U6GLfu zJ@1Nsfbj`a-1Rc4*0)m252E2Ah|1aIVN^-lChI2Pmp7WhpWhm^VZ&U&e}=hgw7->n z75SAz2{((%bbT8GKP3w1`xW-UEqNO@^H)PvCr?vGOuus=a{rJR0aohKO zWFcC|9In@VnflMN1V)Vx>Ow0~aeH_dH#zeeQHVCp7nwZEhLa~vhkj&)Z=2`8hQn8^ zx#}CLV64kwYzbGSx0~p?i|D%G=sKzx3K!AWm$RO&Cy7_+ol&qa_3i5NvLA4`GFu-*;3rBw7{`o|#=voNYww9)UL`kLY(bB>TYfTRoj&(5wI@ z`e@emCl3`1#e?9AZv53tX&bM>gqK@Axx)`0F*!kAo=_HD6v&?@E}wIh=uJ8Rx1)EE z(HFMcrBaJRH|(@)<6;m262&mUmJAl0p%Wq&CB<2f&VWs+6>@&h3#1f|eFg^vIXrGc zq*z-CH$NJG0H(!E7vM+CwQ*XOvAiEiCnWe@Cy0vdfSPw;K;bYsG}A{?l znPmRli&G(7wn>4`RSo{$Ff90vqY4>!5v3t_GOk?{+mp4ccu%FsF)ie6I&n}{8*FxV zsJK@~oZx8^OS&Xf74G8rp?b%{)mgl0*FgLeOln#u@iGkBi!w5p(onFEd>StILQLH| zV+y#)^5K7qrjR%-e4+}Hq$sZfq-bW*`hx0=VKK8Qqw|Z()Zi6S(kSPz$<`Sf(Uqw_ zIrc}S6+GI#9@DuwIny&Up-6{X+JwJnzNN-gYM>_a+d%0b_B?%PCU4SRdpOUkv2@+v zI3@ClooV&NqAo@eDk|(Yu>&k002xVz(oek@pw4G5+9Y{7b50Fh!~njbpTE)O z2&^o=i;^gia=ZOvMjkfYSrt)qip-+e-`xG%_}l*?Dh1^OkT$C=0-Cw}>ida@ye@Y@ zrk@|z8W=R)f*IAAit$!VH;}{!sW+TLa0?TBRN*V?C`GWk_KO{h1yLBj%l)~_ws2|p zSRCRi>$sSqr>&Njls1RI{`!_|uB|c5W&1IzTN6bGtuCygjdV8gJ!>8_q>iIQj7ulf ze7#p zC;NjzGw?cmt$8c3;^b!yMvurh)B*L7M7-T2KVKx7N&r?TIGf-g$2Ljf=f6ks{u4s{ zZFZI_%++F;+I!_=k#ukS7%e3ngd9GR#H66QDtAr%OwvI9eb2;%GY(IYos*NCu^*Sg ztc5GJ2nx}WT-!``tg1Zv`Gu``kVg@DPkv2r`F7;JEj`v9vIm#MY=XxgJ2f*O9*;*H zpm9u3G$76Fo*rU}k3sJKpZphXOfhjO#eV*RULKSW|59Kdc}NH=YM!20ZpkG1S$%!5 zOjM8kJ9buccSroal}6VVl?a)3sS4~zTCq=IpTmfePHG}61@#N#VH?}F<@E475pK5{ z7P5s7kNWAX=*T5~4ApU&4LKx5DiZtO12GwvGK0({b-A$@iFOzR0N;$DPm4~Z@ur7} zP#XiOdn^X3iKAP6M2Vw4E+7!kS{D z+R>q$WF*%4%>2TuTBgxM&%D<*5!<$I^tHb}y0vn|J?-h@$*qbMs%CDbHoyl6=EXG^ ze_&-Bzf6m|{3*;zZUB+G1R`veoiZ;*;9F36>1cVe7L_>+ifH8Ub4~w3=?nUUQjrtw zj%~e8LTN?a6Vlx+b;S+1<$Z^Jig#(t$B2VjBe5Vp;=vc%`-rj-;mUj09MPqLaWi(e;bXVx%PQU!|(on z3q_@SoAja!PA%NUqKQ8=dd7=zN7rzGbPdO?2K&PXJ-)ISZA8p-X4o1mknVEWfHSBO zw{?N`qMR?aeCX%Y)S}D(Urn@>WH{^91)(b4bBEieTfsUF1lt9x6-?BN;tnWYspXul z?m!lsz4|ZK*fx{bpc_xD%3KFfEiQu<4Y5R8WcIF7o?vcNl>>Z$3ifk3-*!1at6&XG zBNVSHXazkOQXk7wk4p}9zcNLTNAwJxbw6%F_t%gx-uRN2LZBQy&uriR2uR+u{PMi9 z=y8I-K9(s~I>)9ThwSi%LiYh*o}zu{Bac@_F|O~v^GZaQ6W42@@^>s~>4xVJ+mQdh z7{0ls4-nQFf$G9R(gA#ffFVRc|3enV;Q;JM9Q}immbaUrg7secD2v&{qLTvUb6i9# z1gr{uyn)3bso`;i&qHvFHl(OBo!GxTuP;PZz5V^Y3pVKt6GQu?Fm7))@k38F)CHqt z6O$EuVaM2+;Yu=6NWhG*^?$#G-JzP*-i-(ifK}B3N?}tkraJ~!{f0yvF^4z zPi>|9RgFb?$u39~-`jj}opIDrOoR!E#etTbJBx*Wz4Ezlep@^1-U$K?8)1R3c;6iF zt(-vx&s>94g+jq_Gt9ect#PGC1aRn7sF%NUIprxTLyOg6iIK(MV|S?*ib@+2{mUDx z0TUAwwhVkbA7>c7hd;nEA%OlBPQ6zgkpMt=taeIi17u&@1&tyRw&}pL3NX{NH4*Tt z6AdFO*S~BNR~sHvftjPQ%zO(AGgq)aQioTmtA2_;LO0Irq1O>7y}M6cym9HoK69^Q8ApzJKDBDRw3Q!;ePz>v;l~i8)Q|ZUo<^mKfA}u35TGpA_ezTI z-y5>D>;AT1#%3HXNwgp<8}e6QnMFQpGe71fk{lK%4Ry1dZb|B8HzopCo+K3he1K{n zDjf(9^o=XZcjK^#%1%2H8Al|~1_21~AbjOSle^BoAU{uB7Y`u?+QWiF(Ic^Vt;&V|}jEA^u2C|Belcu-&Ohv@Yw ziMROvbL$Y0YZHm3Ad^H$7H0CpR@?@ym`H9yHr0Hgt4kA{v#?ef#&kednUq)$c5wn%twrx*(lojnmi$;ZkJgDv` z&h5UR)DkbsdthIMlcR!z0!PPF6jz{G=-xY3|J%}V`0$t1ik|u3p*ts!&{xXGd@F$d@!e(0;qLbSabZ}Q zp>R+&PGTkuW6HDk@j2}pZ;cq{v?Z;xyT>&v{odUkUMH@zF_#Dllg8iqSJbAK7(XX_ ze;?f^&P&1`T2Wh&htXoB{Bk>!Sni(?V0#OxR6$e$sGk1HaO_2v=0btqOj+4tC7c z)e&!}E8vF0q8^G*P$~0(|j3)#0vtPR5XK&j+A$Am81iI9D~Y;5;8|6B>sNwQ|0Y1 zvLlE+26G>7-ne6F&+}5N7KGCtx)SbKNz%thrmFIi)ARlBs%t9}!gE7yFti{LOH{>v zEi%^*xT2|uDZuC>y`cE$AQ6wAIkgps^>{wqNpgR-;)Y*EOVDswMys)I3J}*gMIOxzL7!n)LPsXKWhJ!)u)cZ^;lP7d~JU z@G09;o7plJWoPB>R^p5)=yLLiCrV`_$L8IAhZrMT;M!n(GIA^8Bn(GQ%o@8MG|=Qe 
z#UR8TA93qu$Jsj?Z2Kt5vq3#PYQZSjMu49_6J!9lADuDtR|M55iS>C_DK_&4@8Pq@ zTUlyI@H`IGbPgqA9!&j-t@oTbH-s;o3;ugfta;ilLkYcZl&2jp7?yucMjhpG^M_I5 zg-*)%_?Qg{8AryVxepEI|Mk#5LH#_hYiho`^NE#8uPK*WzpRH#==$&`K-RWNyo}hGy%PmW zmnx`tlfxEr&liK%Br11l(@$7M!L9(r!Ci1FzmV9lCN8P)@FuZ-d>~M0(TS$dWLpg$ z>bt)<{LKdD{{NJujIc>XBO{}TeCM&z%OouPWqNu#Vz;z}*o8^np#mp1@Af;x0QNSw zKz!K)=z|ELWT(_(P@Oy?JGk~&In<~oCj9)MjFf~Jh-BQAFRJ`K_m;L{ADO61LnO19!5Trp&?C&)J&Mcknzy@S zBSB&V2=BZY;L~vUOZZj#;;@Nl!!*6*e{_g0K6a&}>el?o^ET`@f5)1Mdzln(*(E%h zEM;EK8%0^aTD+iE=a)5UPCiz_(x^IB5@)fxsS233SAZBnGQx?S$bfBK(wkd+_ldfR zBt(GveycPNEG%h|VJ@3hS=mG-QuD3ogczLw9y&ch4nNNvmZ*+_kyu<)Ld1~HcfIZf z59MAxn1Gzr@>Vyk^w*4Lsa!y^#tBYvth0)?v7#;o2UaOY28E@R{3B4`YjknW=tP!` z4dWCt6>kc}6hqy_xyLUf=~AnTw$i>D|3h80qa9GOfb(3N1?{m!RXVVGlSK7S|D_^m zP5?ko0FW;7C}I>Uu^?O?x^Sl3xd^>ID=ohYGd0`j(lAce-v>{&J$O#g_gQ0Ty$J>6 zwQ>KVj#-v# zmSuc&gkL?K>VKj*Se3{$)O0K3LytyU@x6)7xOE0Xny#X`dRzgT$UvI*B8pB&OW!}D zs2pT57iX&?0LUTNebNgIHDE!#Dh(@zD`6Sx)#$BNe0(cLS^;H%lor$RD#WZ%ulc{F zfE1Da8a?nKwMeUCh~FjDjIi!n+C7>2Vz`~e@Of9?Vt&vjNjq4$GSa_F+;HlI=g@C^ z{5Ez+Zu@ZL|Z0b(&|7k3K6fgzE=!W3Tl~KbZUIzvEUD z4(RF8BId+~`gynCaEZ*^3|u@x7fALPYvOTyuWr=kjkvXg|60It&`f~oSwN9mY)6Es zEdm#|wF}^~AUA(}RG3~~kDERGV$|V|h=9q^_KDh~Q-p%*+6wCI^4=K-pNa*o-wcc( z&rsVolOO6I-rhEw5nK8lW$wwYgn9AZi#r^WAsG%1N#tXh4zRzU4i61m1w@?}DK^yS zyPySw`yDrp6gd$@hzZv&mQuZ=KsUut*DmQOgHs0Z_S2<~O#X((QN>?-7WXP6e2S|3 z4v`_S7+oh;y_8j~$DyKe267G5FOnPqdKNmGP_4Orf5KGofAH{70wCmal{P+6dAs+R z)|4wgw7$iLaGg3C#CgR-7e*2W4W5Y|e(Mn+me9VHb#+AgKqjod@i%%ogVMyEMNB&? zZCfj-6eUO++gDWGA0c17MvN^Y_!H{W6uWo~i+rN@N@eERaoN%JK^yMR1*3?r{z7_c}tTh!5 zdtF7b%3L@Ol!TvlHL;b=YE`0#Gkt3q>1%-O{ac0Hw(0HT<);9-rK2`VFzc&4O6|JxOSgFqv`P3*8Qdp-SyoWg!TxSdYy>wP5 z`LF>G(g`#xv6nU4_z-kA?{cxS5j3eu#{VvA?QnhiUrHCIlj!jJ@cXd_iwq7UJ#}O& ze^ic(A0VpmWPlJ-)X>8ZUZjD1v2Z5Fd&@X-CGWqR;m`mei=*ZJ7k{)`pJZWFG<#gv zc0NQO&6($HkFVUMU+(4I!=k*&;^@ zECxdot%(BkfRl`&zrYZxkw@dMo~S#1eoYNA^}SO< zi@en@o=dt_wbPO$dsWQqA%kLpNdxu#kj>pdB(9oDz~Cr(JQn{hX!qVt9Fj%H(fckJ zmLQtjr{pjt7$3l0e3%8`dy+kF!G$vP1D;J!e4m5;!_VogOo;`D`@}>sDN{(!&8s3G z2Ky&jNX6~aQCLsI9GtQ8E~BBdp7vv2YtwN^ARWH2hvCl1khgK=%Mq7z#D@^SUAHsG z-gE0WN2GsC4?^|oAaB_91RzZIVa^Umx}Tkt|5alrjkhQ%pdXYaI^bTez-UwNZVAWe zKXwl1_3iZ9S`+!hOK7l$Tcyy~r`<2Q?^V0}dc+9ItKU##XooD;CVLRye30JP4_`FzAsh=qeW` z{}SJl|3MnBSMQln1H*CmJWT8*mzG@=Hr!=uEo@_M1Hw>fZVd+M1Zgs-YHi+)n?)&e zeyFYJ`-(TrB{3bJG#rN-;o(}vNG2f6gJPiJ8lHHTgbz*q>JGmv9c1`qWK64rny62C z+AhqMGw{lU-lX7`1CWc98=;M{!h^LqBq}LpZ-erF8a4UPQMIWkKXtgCGg_kBscv&4 zBvU2ZzIUW-ET8Tl?lNNtEwvMqQ%cDgQ>b|FfZAvYW*1!tWTWW%r$}FD8G~}4Je#pI zmo@yBe%<`DET9^fCtt|16jkv3l2;m1zw3!UE%zqS*(?#bNlhxE0NTvQ6w8#g&25(% zTSa%tLOHJ!zPLAik~z;%r^hfxtcUqsN?Wo2ilga&K{ZQVL*FaI^L*)k`R%?Tzd;zG zpyP7+HKQ=_miY|d>2@+ISZ{gYAI=u$S$!czR_@^94j`Bf0aVx)aLk)D9OOVV9+;Y< zZ8XKm!|io37sKB%6fUZ8zv{|&s;O`Yvm8#GDj6Mp?B_d_T_YLkSFkAkfeV{1N)tbG z_lWQ93me;}nYU1r&bys55!{zBMYCONyz?`mu_rS~FSf(b6Qf7PjXU+HMGtY7KN{8J@K7`bj5p5>GUdX`gt0=$-ip$9Uy<$1 z^OxghRP=tdN(tT0=n<$d#4-ROwW>+lM;qcgFS|cTy9MnMul6jn49=D>vtg6T|M>>? 
z-oAX!p>8hRm)0{Dfu)lc6Kr6xfp~Z$6POC@hmDaSx-|vM_0P_0gr>DZ$;TJFh#~=g z1(bzNH2T|PO+R>a=oJIn?L#?q*eY#KLUL`S-HxSc$`(Xg0t)`S7=b-glHTV--9p#o zwRgm~t3ym&WEU2j5ZEp2_|VZIY3@B2!gt6e=`$pm-~}r@Rv80A??+5CLQa~!m0CcI zL8Pl?VQr0e6+UrqtHt&8;C1BD+?h?5iBni-|P zW`5)w3N=ogpYqb>rF&R*Kw6r1pBOE&2wSNuno@|N>-X>G5uXILln{z64$v9sT!6T+ zgX#e5AQq|dk$w6`Zjb5d#dvf0qp~o3dNd2acFAm|s6|QD_#a4{zWS_EaXW9#K9oV~ zmJiEs=<83Xm}lLcXKS54e~x#!Y!tBJ;pEgB;jr-s^El1l!FdoX z*myUl+T2b~PM%&BY*f~kb#xd?aVG3N1Qe)h>-y*hbWFCY6beT=+}L4Qa}!&ZLT85n zMnA|3?7z8I2ou$;v=lYkg&(9pN11);d0N%a6ZHBP{5o*Ju*oIh~1tNRbCqUQjIJ`25lW59~{$`J(JJ#w?+w7*HdQ zXhVA0(7;Q?+>io;B5ge|volLcsiSW3eCAqUz-&6UfbQcVIP85~_u(`aospZXj;zfR zUm$T7hqAKVfaQIUT28=rvT%LZQ6JKXEG?Af$-mP=aEPdqVCaDzq#DRfK(IFs!JvnG zz7%I4r}4*`ND8Z5Nnal)AyNSs5Vq0eN|vVgG6_cO*>cj4XIQoj#&}H^=ZTYtQj_NZ z%T;B*D}9g?X;kIdme~m*R$&Zp1miAkn2>N63RHjhoZTT>^@hfGu3$!_%pM&(NZ9#c zEiCQtd)BF{XukjomhvBsAF87y1t~T*j6lDtA1{73r!@Z7(bnO*V;#`#in{=v*9$(=>p=*lovoWMM{xJU6LLtZ zte#$hcb~CqO}|X!!GhjyS_U+A6vI%MJJGT4#SDDF)+8oL_;cGCEXvNLFv7uW z#}Ww9CDUlIF=T6Lc2v-_#CV36NzTn30||0)P@PJlW;sy-|CG}Bt_&fk{(CUHgs6dq z#^CBy8X;tV2+bka?J=+`zVoJ9aGH-7;ov`QHNCL3G%Z+KGaU^L+y#auEdn?N+6kXe zV$xR6juxmrCnJmIY)kn=N-`j!?;)WKgS?Wtw=$J+W?++vO^nP7W3`2^^~!|#u;tj2 zI`8XO#jW#}X5_LdpW!w+ocpU)&S!5GVg#t*^Iukc^t=+B0;8Ci0nr~jy0p!%2)Ij} z4fVKgyM`no1q&i&pCy>F%lIK8U)G^wiA+r;xrH2j55>nA|MtSM-ogx9T^b#y5%_K| z?^Fwf(C^NTl#KAhD&xYUji;eIgMT@7(fM2CGy~Fw@;s#%~QF}_6QMD zON%-6rLm^W9vUu(eHC56;uC~*$^}VWe~8Rd);@R_r$+OYmOj(4ED0g9a*U2lut<`3 z3dY68?mT==u;e^$B|YlTcLOuU2z(`8%o@3JcR&|Yql;_B`sWlu=I;8V8Hk)nuCVlI zpRsG6kr%YI5)eC zE4fL|=8FMYJVUVfI-o95pr6XL5Kc2r?)>sr+Ov%a9QeF>Mit|V)T=d?mzEZ>xTuFT zjH8!}8`%OzA!k*D94w3G)mj`)t3|B;reSd>RBij{T``jlsP+x zw4cBGq}fo|w+^RHkMRQ8A-*YLm<8S9!X}RTd6X8An#llV;LjV993B!wjnf&1^a+pv zgDTHQDnj3u`R`p8vG%?^#u>h~{SGpxHu5JtA#!GEw4j#Q2;j;pbV0uI$dx*eEV#0D zb^;}@^p)qI97F~OpuX3OHitm4^*R}f>%S+qHk6gIfaTj75IgTE`XtZ%u20uu4LhrR zy=za>ElaC z&KJOaRQe2k`1ENk%oJtij{P;u;RwD-u8OPwH;F~LK9nz`P$G$|&oFh|tkV-QEF6qO zc#J@08Vo>==6shZPm5p{wztBB1Znya%1D*NxHnR}xn zB=RH2VQ8(UUT1AiKi@n*9!jgvv4Xjj*wEs8=)~ZstOn7qrh;9L1Ya}95iXX)J}T)! zE-gs_R|M*MTiF1^RzbnXatJdkm3ZU|c6}|vI$7s+V&yV?M4=egsu^`aU^}b^a>vH- zt+ad1^algVC9uUZxo`dn8Bp5!G0>(z?2<! zk+|60EKss|nlex4?%DeO&&&V31^xaoZtDEuCF$Z}uW)OOTAHN0Bp`#=-vf z7w2=$9J?SNPioKWV=l)`Z_$}moGVyp7|-J0BWzwGJ34ByP_-c=ZFN>K(Wm>)T}6zt zA-6Foo+sgkv6n>Ew^S>$YLiIWn{B}x1XvpeHZCB$T~3SViGPW?d->uz#ZJC|+*_$> zLyU4~rfnu1#4}I8G2w0{d5q+5qY4Y8rcLVOgmW`6AT>AN-ozt@^ErDbDBKBJLL*`| zKE+_P(FGtO($b2i!71TKn{x9-e*E}R76jrtJR+^HD?y09JMEi!d2+QhWdwSiLi*M^ zT_aRi;!2)wzaRAmBrWZTIzTfnIM)x(SGIpP?Ol|pu5U60N`1vi?~c30s6 zmyDG$lio5y7oHs zgA%t^FZxAA)>s8>^sGBN_isQ2g?|))oA_fB%o!yX$hTY6W0UGgedSYmGTMP;O2?4% zmC7ovzpDgS$E(l-vVk`oCMG6HSRP&YHhqM=Z2un)e0FDXS6BBzjcDz9(+0FtN)y<9 zoxG)`V2|y$tJL1h)Qp1A*w`D}wZG#T?YX=1zek~BoqpcH2D|H=z3AOTLBsNGg`~}a z(Lei({wU1)3`|vvy-iI%8=&K54*;_qCUFlvNw7h3d}dARMQ%Il{cfJ#I^sspRAOHp z+t!5$`cE354UrBABUa}8ANZY;tJ3gR^Dt-H?e70-?I7Na>1I2?A2z2tB^Pcz2bx0QOW3o9CRgB13czq|?TS8PS>1 zJ}R8M5Fb@S2dVEx?YqL&Pr8EP(lMVIBs{lo6U|l@s@5I4%1-&rTLU9Qj_LU=x}ufU z+q@nis^*9730JEzfHqM^YG|Kz9WT~wmP*fPFIoA@Jvk+H?*<2R8F7>dU8$p^*U0t_pP{bzRO=` z4dH&r$35>3JBbb8Yhd`N_ac=-zC{8R=9J{QnMsYV^&5;}0;-tD9FN@wdomb! 
zd77s`yvOJqYeTYWB3lEGIH8li1^M;6fnIq`)dg#qj0!-~5g>z*>nr>nUb3OF9~6QX z)L1f(3IRg2NndhS`#Ca`i3=E{6ukX@zn4--7qj}ys-iYjROiAjb1YL+{z0Kl$8Vbh zN<2x%z^o9GsW56XF{$4qA7e5|$|4g^btpGF8GUQS>qk0}Q}_DtST;59?>+afY9s?j5)s7|7i}yH8 za{wfn!4kDS=+U|TvjJn%aT0rBmJWYep`&77r1v@`9&xP5>OmH@QEq8Qf`51L-z@yT zJe)XcSK*>RBEH^h&)_MkdF4T@qJ*(GYtgx4MYag1DbVK3oTmTT zjct+Z(+9U#p>Q?8T>(ES1gC*1|IX{Av)d`;dhIUY`uHw5V0ZNzF_9nCfs5R!lyP!i zjwR~6Cw8m)=F=z6YoD^(;*R%oe-UJ6mYBD7Im~Y!(*Y6wZvUH=Q#^cuoQs2ljbEN( zn%EdCf_-ymGaiA!my{0b8G+>|sspMLMS>%2rh(|~O1j`BN9JRj9`Da#KNw=@IvC|a z(PU&yGKtvgx=|Zu#95?b_>EbqHI*Yr?X3UGRgy-5rjG_b}sT?XT57*2x4B< z+8gV;ql!z0NFZ-S_!jV+^E$(I<&j}`54$V}b}?VTIhw)uVwkm}RXBjrvFu|JVx9JY zB!2S_{J%G10js|r**d}i9Upl($3OH5(~t#3N7>%iNp{%vCRtnr=Kbj!4!^z`5kawW zqYqRuEx#LaIGlO@;ozzr72(Pv(-d<5uI75Pe0nNs@KG^uU@|Dmh ze}6E%Zzv#{Y{;0JTSgr}Kny9>FYJq;wcmYwRgdyvmK-fq(VzjAOqeTzfY_Wh?voQ% zkvsGGD1Z>(Q4mFzS!8X~tD6%b4&)63<50aajefT{$5T1&2Z(|7%7pY~0M9bIB8*c35lv@3(hk7=t}mIJ)~>|CailtN?$KboIMlj{Y>O5$49jBj6_W1)$65 zwxXZe%p21$Et(p?rB;&p!i*MLNZyKRb61uQ5>9e!6*gj67Ie{y*i#fwG(>lY(yrw= z4E3;OM0xg=QP`x{|n4zkpOANN4&grl41aCu6qwQHa!BT*ZJ%hEXpq+?0aH}lyHnl3? zFoHZ_pbB6O%!F@B-wIf4Hg@{xKXcR=P0V8%oseP7xx=g42-al@6tH}``@r9C|BkJ) zlvJG-?02jH;di(y%D4uuhVzZfOf^)JB$_pS>nh8>Hqu9ny|*-A8S}j#LVbiE%agBW9j4{O;3_Rbr#A#7=V&A|8nbr+Gu$Vqw8L~ zoR;kIZ2oxywQiwB3%zOxtY9pZ3Trj#?v6iw5G?Giv6p@50u6&8+X;dIwdvh;C0FQd4MojIq8lAKl20fhD(_6#pXtCbduaf59S%*fyVpp z7~ARFaTB=tWChxR$9I&!SR16^8M^o937OgQlTV4-5T6g~#-{vg!7u5lVRr!7Px3c! zjBNk*Dx-539M4;?!{2nc{Ek3vmR_jPPqyNmPFpUi3ueicfz{d|fzS%~*?FT5=dGx} z%#M)+narOEeFqI3CA3T!!^`|m1Nx*%socIIMk4}HOS05|ClqgJ46zVH&O8+PDUmyw zM~g6}bazwgyb3OUXdGaIII7v#$`HQ0-HRF)tVm&t?^+H%ss9R~kF6y(Vijxpu$!tA?bPKviQHhE5+Mo-6;( z&&LD4vo`>7%LGlgeM=+rKbHEm6&Hw`2J*5vdgOm1s8#45lfffLPMd?_npwC}D(Xiu z4*@Y)NG|!l9wmhmO-{ZVO};q6>b>s$#uY$Dk)Gd)m-MAFge!8vp}+Xaw!(VdAawQ+ zaI_IVwehybcqjfj!_e)6=|5vsGUJNA2JKd14f1$zUSkJSaOu1=UkD_MoR| zxn%TQM@|o&eE9GMsfBA(7M*OZcI$d4f-4p&hok2%gsJOon)6F*3wl?&JRhwfNM2yy zNcmyMv~Nj@mO5z~e2^32OcWvnm?72@saFD+JaS=RsP$jB%sbiZ&2L9b2L)Y&*SEEw z9ym2P^p&7i;Kx-@|IoHB?9Dg-9XEF3%|z3 zhZu5!?w0pE)S_;%Bx!n&)g*Wo1wK(?ZzkWkmv?ehvn~$3gakXUINk?pY4+zLG`?o0 zwB$zv1(E8tCy4iljy>&#N<@h#Qoja@S=4vYYi4X0JO6Uyp-0kjA5&+2^|@DDGaS&g z)k3UOxe4|5*;hOAHC5J?UZs|{ zuNVjOc14m3k0$tuqvW;CyT*~b|?xB`~pDupG-hqH&O7H3JRj% zQHKgO$qV?iqAu)*Ox>imyx0I~Pr98Y={^^GN#6Wzp=CD8wDC-vnFJ}XQ&Kh6OSJ~K z^`xq83tp;sh+{czKT$L?+7WbJAg=JnR_iDRz2q!5r?_(1!jNFWdf&4d?D)MHz%xlS zNk;MnSrq!1zNw*(vGJ)$D3eiZ44XhR`J;h@%?9bmWo1BaGpnJx$KmjDY(0ZoOtWR_ zMb!nRLph&>-RLuA7%WI*A~~I@pfE&z@#7Nzn~L=?$c(VcGyC&m)$UMN(g)c&QbR2o zo4PW-;_>UA3e|#uabd6&d3Wunf$H7M0Ae!g37N?(O@16o!sPOVsS|0!;Z1;5V|9l|%dOdoQw2Fa1yKFD|axtG`l&HfTW@ zGBLq+$2e4XUZC>{tr9J*nxvk7n5GbP37aV<$!2@r=$GKIq)$csy7m0(Y>Wt6mE^ ztWW)%8i|%jM)Q#oTxqa&wcf!&itNLB!!aX3+($gwNCTlc2+59SaBY_}-srxmo=JN~ zr4@rq^RePlW`3z{AHDcWzxnK9oGGrFui4_=p~xJ`IHS#|YqPbSnZ)%M;|}p@#%=}~ z8Ufuph*{dWkY5s!!?;=a0!`HRr7A&yvl3QJcGm9A=DXQWMEMAR&4+k(pFq0L=ksix z)^bm`4akST$UT*#Np1b%MLfAF7D)P5w>%Mnfek2ATP!PPI3Okaa<{L|c+wj6%`xx? 
z3UR1`lmaH%?{;v?&|5Lm_EhH5yx)7xGC$`wm@_J!chVl!Bp0*!xQSZ8OwEQ@hI5uC z5HMJ0tYub^RHm6G2&ked9`}7|?9GBZIg&S{L8 zVlxD>1hgT07Zh-}wey75<)bXjRqK5}M7~p@{+a7vRN==ksdAL*=MEZ{RmfeYOQ=6| zTbn3u1Q^cx;a**G^OyIxMN`{lx`kz|x}YAk4N-uOCziRgG`0@Lwx8iHTJ_!oJVe`w zpT1Zi&nL<2IqDm_#cMb4$7CIVW9-6DZX7!m>Mjp~9FAG69KzJS`mvuzpQ|aFQXIZK zs}Z%0|5qJw&^e}FZM2QbkTq}33tP^DcxCw??e#&+WkN7pq{t@7~Z;dj_}{v)*v%0& z$njKGoM57$$oPnX;N4&27NlzLh#W2HanYyAGR2n@^a0VFwhzK>`cr7U?;iNh`hqo4 z)nmpGTG3et*h&dgm`LoEB=Ya*x89-IUPTDl8g0{TuL!O|9@W!S+O7K!mKPrjCVN6= zMh%Vimx+dM@%$q%C6yLa!)sPVAq5s%$U363PsM47v1QVj;TN5#3bf*@GQzY)C7n&Q z^^**{Kp~v#iOC}=F9LyRLk-N1ukmO!Cp_CPHNe$UaB$*$Xt8w-T*r|E=fbASDh7b z@RP>vq}gHiMvZ>A7@r58+izAr3sTWWhD&LYWscN+Zna%=PG9B6JdGg zN^9=$xbOX2asw)Zl+TtCd*|4QoLCvon!~ZY6Vj}(#~e@QSS8t5%I~QHKza9+oBh!R zy{|rf(u@i<<~6EAwcM)KqxyRWWEgo?A~xwhC)oL&QWfOIJSm_MZ1jUC`xKne5*o%x zju{tx`dIZ`&{wNJIGI<=m=SuH&+K{}{WtG8@$W$0dG}D;*)fhIzgP4KBm=eOdidFx z_k}b@jP|avAAg`Q+e%5FEt~~xYsZA%$n(R3Qjd_!d7cqNhTho1g)53rxdMJV(?h` zr71lH61T`2!4B|PAvY6|ouCbO*QVeD4343;r(7c+m6D0p3Qpq6KCB9k+dSf&XzEiT zlF(`TI%q3raXEI0xiEq+djN2UOl8SL$s zr{CVi+p2`|(~8VyI_{1;$0KfugE~AY5`@H^16=i=TIMr^ORO~wJ+A*DXthsfXvWgy ziDg~aPa|RiD&{5Z;m8%NK#;0=6ic8YI)!Xg$%^?edVeDez%Gq(;SDr{CV_y%!wj%J z&Bu;QcgaOx>PngGvryt06CHz@I@rfFbOHaQSQh>}>x<33n%&OZ9uZc8XU`Z?%uHic z`C|_uFKG3Px<17cZ_=n5x;+(4tX%jGTmuVcRW4%*jM*JD{et01jqoF`{d+kyM^hMz z>o(J+uE5qiVIxN9E#O2BQ+Z)4jg7;+m2jHEjLT)@Kw+XIt~na*v9{-6yq^^s7iSbO z8Q2;}0&eCKFi--c6F4j#4#%-AHE{ehvx)gW7Lk0z+R_Wb69x0X0sCfQH3+YNrAP%k zBlg!Be8%l!o*HXnQ~lkBok59fLBGd`$}e`-eJUt@@m^dMLp}DYl-{bBt@*$h{L-zL z_YwnV=gy6rJ3l2umuCZfoOC0~o%s5TJM9B!3)9sDZGhbPB4CpV>*(7^3$+l4hjh#j zhRNQloZ!r$S8pZUB=UN{#-)K77sFm9WgF9_CZ_azNk+pVqsq|$v6cq|nvULsttXqt*`A2JLrX3?vd0}PZ+x!7q+>AEYXZ`^vA?jUG9V~DCqU?#;QNx}-6KPMoWaL4y>=do8O{T#{OO@He)o}>+TMjc%bC_OU%Q%Gi*j0FZ# z*K8Mr*ER_qM#P@61IM-`6B;M``=7wT8!6$O0d|w0twi_g!{*v3I$JQ4QwbcOQTI-a+sDajwQe(WOzE?%dde3sz#I>ckI86EU_V z+RSL7Hhf1@cM08CRsj}Y*4qs&^K%o`^2$`{zXI>4;~MUAXqgYz*)k;8`VsKTJO^)q zBL>1-H7XY)(tXV~GX}kX)nLcMmjwmBJf0690B>g#i6k>>SqwA{vC#+!Kp`h?BBu=6 zAkHc+RVQ!Wj&HVtu{QE@i<;?14tsl@_qMB`+$Tc5fzO&&u=d-3coEROG zECZp6^woKMVQB!WSo5~8_7qT4Y^g@y%Cdn|jp5k&+$tZ_SF&66-)MU#Wy6}pZcn0L^u3}#YDv}}(n!=3|oU_ur4`B|GKX&;ehGKANcX6L(KkGMY` z!*PV_t;ZO2Kgxd4V+T!c^zD)L#T%PlynB8msL6Q}W-`h@R$O#-sq+wm#ZmX@?s!FO zT+lUHR4y^m(WmFd2QG*D$=n|^*GsVAB|wgfM`Y|^ecH11#9IA?i@Foi4=2k_NfE$~ zMMw`mpcrN}Ukt1H%-L_$z}o32VMJ5wtF=aI*d$*ZT(AtgZLYUJwq9-^a&XIaDA5#m z+@{VJUHOP?ZKz^$3-X7EW@PjE>$os9o4i==x*%;LX}8W8zj~C*BC$+*ev$7MdKC3+ zSsr>QS2MaA-wG3*Xv#S1lgj=)-U90aK>S|ugY@mjAe?F9m^d`9&N%lVc)mYb&#pMT)U2fGNK-tGtt1kBQJg@<;5jx9BOxm zR=P1=J%HTHI0lMVKOV~RB?^;PB0tznZ~#d3yvFVv3~fqm(ws=5bEmP7VKUT|PLjGX zOFgq5ZTjFBm&fFA)m~IF23;H272!&ROIzd&o-3bQXJ_ZuY zTZO(Yq5O`#;1sXU2!y^_0?xg5lPa>>R4zWV%ZL||7JZ@?%^Mp1CD&5DI0`a7J#GE$ z^OAhHWk<(np6>=cU-eIAh*$vBOFpu5vcHsZ|IVt#Me69c zC3v;H9%3(IE@ys|A=b%BUQETV1iUZu4J>>o&Bhg1Xl^M>ZUh>Euk^XAg zH~gdz?n{HTn!R|VK)+k{?Aj&z^U+@8)xG83$ss_I7`YdLK)=`?E%eTCB1s4I3pV{b zCMhP^%_-Ed;AZy?Qr|(npgK?PK0!8zraEgr+}l=Ap#k)UrwXM_J2FW*(01y+-D=xS zH0~LmECO9t`@1p>oEk|M7l}mHmK1Ti2GcW?*nW{WB0E3Fg}6{DR+)Cb_H1)Q&*-%% zcJBT5CH`q){;udK!Y2);p#;`~6U^QGEP?9qO;;|mkzbJ_ek)c#RLs1gI)&|Im0iv>$b~7Wfc|{!6Y%&bGub_xFDPwD`docx|3xNrgW< z^0fMewwb2ea0AGPa%>&Urf58LZ>%F0LJ~r`L0dzLp-*F0L7tiWZYqTkn3G&&?=n#3!#V)F^@!o{4uE{(^Fj0)6Pwd~2=`zGYK z>UfB9taiMwS+EAQ$hb zEZ~~&?kH2rwa*ivas5k`NC8Ita!iNTi324tUBgyAJ;rZV_DGC7;uu^~0?K|_VmsW0 zAB77K`g{yZKZ$=yhY73n6iw=DGg>)f5u2*xzV!|=GEHLpq~AyEP6jJb$}u;4)d+-^ z?j){>AXb%VL3N-R){2QF{*|8}zU8ur5o;ogh|1y(NCfBN7Bp$Bl}E@X+3%_@s{5rZ zu{9Otesz@cF>PCtixXw2=J-Gq@`c6zOGisjw;K{&XH-j{!?IHIxd_KZpx?c$ky>-*V5+vBf$NW 
z&S)~(CAnT^!z_$LfY}Q;M;>jT8Gu&(+!!QU9i|giO9eEt&BBfDfG_+`cT! z)6L=r<%A>_2a*)wD$do|j2UPnfSvRbq?$Oa^gW-yDAy^aOuLa4>IfPx_iY9~k7V_& zEMx-%X9Q6kN{f1azdMGfc$B~7Q9PHqMNWgP>bv2G!$Vq4flLXNWHE>~d6JgbjV*m6 zWYq-4#RPGe8^uQeSgJr3b2`hB>k|Jc{+u6FTXpf#R^QqA_j+Z}v=Z3GlgLFFdE*o& zC?(5AtLkgOS3mRS^m8Y_sv~_kB4@m{!?{p5C5w~;T71GmGl*^(;NTP8yIoe~n=?O& zDjULeNg^*R?PMY$=bRn+(W2!4cZ#V&N&i85){eY5?NK8|u>;k9yPL!z@AR71yBWk* zM2&|MFJ?T>fD4%IQ+l{s1icvQ1)CC$UTIjU;1t^FMT?R}P{OU&)QZq7%upFB98EEg zGvSP#gJ;VdR|lpfe1jFJqz2;qqrf~+4r@jIQQXQONR6bSe8 zHa#;q|I#~+hAu+vU@Uq2OC$OTJy+$G6R|=HN2rcrRxLQjYRkNLQ*<{o-#=VlpZZpB z`c+1CsKIrZ3aj}HxNoQ|G;mad7JGdKOJ%ZJuRf$^9XTCwl2@ST!L^7X`mE@_G4lMZe;lZXXt*lOn3~?pS~j zzg$Vp1WeVhG5dtoygFM=4jXYo5>fN!KunV5NEX@QQ{;eFeO0Dv{QFC^dCG6@+5M&V zkx!$tJpXD3r*aMP<9*=F8QgbcJLmFq)BUuP8i`0BY+@QAf(b!27MTK8<2j$bnK+E?@qXgT=o8HdQ$)G5#(&-qmI!b7px+ip-p&t4S@HwDxrMIH zW|mbQT|b6SHg_D?4=@shQc^c{Jll5KN8$3X6iCSDKU&iTWGJhOHx!oa<#6I&m|k`j z0S_W-9=cY0lTP$}4)Xe#2x!WRL;ZncZY!e zNeX^ zptr7fu~0qM7@PZ~(%q9a$y?{X=_2ddzDcC=g8`pCURaR{j$d zb6B<9y)wkmb2C#fmLDq&`*;`0(>SjP^bc)N?fqian)s=)!v#*l+>6WA2{;Q9c!3Z7 z!qz$e$jUtp-%<~%tjead{x&Ruw8hDihfU3r)Ji@(N1l5@y%uCj?lJR-nNG1P^pf`n)32rfaf@6p{s=2K)98mi#&j~Oz1 zKA1*ZODT++ZuS;YM~SuSN7U%}iJoWGvp0{t-HFMW7>4ONsE|OGE{jhDxsHQ|Sn;9I zcGRi=V(=+DUy3*R8e4b`araEig&bX!bN*a5od@VQu>2ai-Gs5JWSGc%DXFsY6O^na z5oQQ`>Q8dZI;-tj05Y(;@)0R0dd*NrAJtq?F^{r^_49zve$F(WpeZ@GCI==mtqGj? zhK4w|?XjSZ{p&x~&%4PzJ99S@vUT`jhBIv0wXo^Yp>oYib__3b2>l7-aQOt3Tg~%*r+a<|n&Sd4+n|}QQE zi5WbYi&m|c79x>CMj{na*6;^7jX3V0J^7U_YlUT~_h;Ow3>wE^lbi+LU4?$+b(c5> zq=Emon}sfFE*n^I#;$0*z&4F3txccK$<4Iion@Cp_O=Hn3P~Swfj*WI?`o?3V}(cP z2w3q?1XQDn8UfDjbzI}RkZCp6i6&O1mlcDxI<-r?ZO44PO&XU>N`R3w(TZTLnuCU! zS0_~~G!#)JX8PqN&%tFWnbDDGmp{fTJL*-TpleaSFqFbV2q^?5pQW7~xFwq_?=~PQ zK{g1=MG(`Q$OUmk>>GBiLf3ScaQNkK)`~I;qduZbX~CPEw$9I%JViP(RhV>_f%4D) zv7&2M`~!>{hfpZc;8+H2O!0j8$bHjJ9T&1EQOGN&rdJcBt+l(#KI?a{i3SxOI3yCM z+~U1u=t&|4@q*n!PTVHfU%kCpPS+O=$C5Ds_elL8uZ&cMs}+l!eG~{YWJE2FSwgvs z#fNa}FSs3*GsNwbS+zP%^h|z$SkwQBkd5dFNp#t7UCN>YbACopy5@dExB!;$N%#Kq zqdVkJGBF(K%Ydx2x+LQyiTh z?H2^3D;?D#R*o|CcuL^Gj7GmUYgzn(v#avmbI}}?VUPi_kX(7ucD$>tgTwrlLat1;`uS3b2<% zzc3`uyasCh4WxT-+T`7K6{4*q(9MiP>s)F{X3I>7zhHGku5`3Ou>y)DOS@WE)BHPh z^DtDH67*(YDx$;U5oFZ;$HHs!`Ea^{7VjJ(2CPOdcnJX{c-h(9w$oOiyEvFzzIr0* zBIK=%@3`!5;7_)@bZ);2)095Z{0`1JUfR~nk3WZY^YY&Mx_}Q1`L+BF7toj(UQ>i3 z#=B9U#fthEvz?lWUY)9j<{JuBq0wPF67ozUTev7}{{(tIg+~{1fDx4nER%aop1F*l zHW&JyM_Zfyw>V_R(P#H<;FjEq4QqwDBQQ7J-DK`OqY_sMw(kqd11+&ujE6gPKfjK> zy<0+Zca^XCFCkqkRwok&ce&6pjJKI0bKX`^W*;@>`V3$5;ERj`)}{?EMEaltMmDJl zLg(0IL;ohtSem1I|H6#qM$}FFFSEziT^R$FDmu~%BEx>ByqiCqcD=kN0dljNbFT9) zNE%0SDSxCKY`O|bNILHt<$d!VC91eMCI4CA z)30@4itpJsSk}^^(@vROQ_ek#740VAR|7tP?_TY>+3;zSUp2dMm*Ng3ScrA$m zl9wO($GDs0Z=V#GL{DQ&aDY>9zhvzC#wpbWnp{XB2j!nV9$8-i=`>pELUrk*=I)Gf zH#Qkojy9mDtgDsqXt7#KK3uT+%YgaqI9>7oXoUrG;$jrYdL_WOp?gSJ>JMw02n z@Bp4ZXY9%~9p;z2#4vJv*JJ>@oSrRW#M_&iIlXYK&mVEL4%%9ub;yBzL%xXl&EDKK zWrlYtTzL6^ZTaad8avY&!;5f)*#%ncX@zkx_nX>A_uA2`b-El{)WGwpieLpj<@$-& z0s0>Cs3pg{w-+G@l4?46(5|LW=lP@K;D>j?H?kyTUJdfa{x48lbe7nH< z|I3M@OF(4aidk!UV1+K<3*`Eq`9DEBZB`jzuQ@uU81^0I;{np0qB+A2NA<=m}Cd<*m_kUasb8=1|fEw<`kVTq=fo5p7YnVEw;L<|e5gLSGaZ#vN zIN61vt!U(|%Lwv)c=8<0-+}a<{(E?kW`SLggAPZpmU>`Oj3)ZxMnfCOo`k&zye!J)9M_8oFi-69BxcHr}X3lRDU~8q2aWz6@fzka?Kr`o!_3DLjpKimCv9%|Y7vj&-3 zue@%SZ%`~T8G}?pi;wKYM316m%a?Ds{j^_{>FsVQQQBt+BK0Sw&^{J^eEi^fNXV1< zh}Yz1K!ho1G)$UCIdym<-aVCRQ~mOrow}XcjlMTK$N&HRzn#Fv*ZsfPga+nMioDex QFwnnOsyZrFFRdf~7t~Vd%m4rY literal 0 HcmV?d00001 diff --git a/docs/_static/multi_process/mcjax_overview.png b/docs/_static/multi_process/mcjax_overview.png new file mode 100644 index 
0000000000000000000000000000000000000000..dae947ff9df7c62ec93a75d2a6066de0d7f6bea1 GIT binary patch literal 145901 zcmd42WmH_x);Ab}ySoJl!QE*jI01sYySp?FK^rGXfS@6`H|`MJA;I09#@(m$zt4T< z-gnk~oV8}H?mnyQoKsbM@2V}oy(>~#Q3f4_1m(?}H|U>aB~{=?qQohPKu?5vfII5t(-$4Nh z&m^~TfxS#IoOhu$aP+7YO-2PSMdu{L}#CC8ujc%4>A0@J_mzFWm0vB!)wUy_58>I9z`O#b{8 z`YaY42Se@+7PLwDN<#WC8+y$_ixSXIu76_^|HGFga$8Jnxbs;FQS*PeqyQ`ymc;7S z!~Ay#3;&;Yknom=Z;Be9kAu(-u^znvv>3jNe5oOnX|W~)C4ef|7<&0T^L*tQ7C$ih zRXwAsi&C;bCA2LxG?@HfWF-9jd~#_Kk{R>zwP^P`67q-?{O-ehJP7oH6r2{b-hR7k zb)1vogyas}?ooQ;QFBTKZ3ioZY5qQ8T64kE!|^&;1aC_rS2KbOC#IlHKmRWd>JU@J zmF(Vhs%|KbwaWl?g5O0Tqr(IK_WSMwlMx?T?Z>{JkN-zImunA)#`0;Zt^A8BP~g8_ z;i}OZU{?0H{b~>n9k^SfEADSk5ksT<8)?kb&mUxZo^vcx|IY60e`m+VO_t;+s|b|~ z>c2+@g!FJ6gl|J-*{%#G^;J+%c(IwQG&m_NEHs_m##?)B!!WlCfsR1U@%`Dhv83dE zKh~wUGZ>&wCtugHWvhU!V&EzZ@pk6Zc&G4xcZW&5vzvsuYKwwO>MeY0H7Lf_RUW%moqPVE!-RNpso)$1SW};O8Ezr# zluAFJ&=HMy&IlDI5+0+b2qq^d@>-JWin}t% zkI=$~;}H8{%y;uuMbVMLG-Zu;{?l!P>}|&&tGtgdH$|6Yv2?htmHYEQ|0VxlUgDb{ zk^|~YF6Esg0rhKmB0G?d2@bFjewXq7kpenGH5#zSg!lF!#Tm@kVgULL#db8n?^U_53}7{`Q3^K5rWlMSyQIF zejF2ydIf5!s}DXmbdHpkRfTAM$yokByB{8ee2jR%>I&n?nZEPeQ;{JeX-P3k4!hP* z4ZNkrNakNJ{tsnK>oo1c4^$!KGvnKn&{2RH85{kvq^_MSProuVGBTE*Jazm45Z*PG z##p>>)P_=6y>LOnEAw0`v6G8KZw*!qTnoE+h{qBGtwfmBfB*(YDBCiE$(_6|4O1$f z=e}8|WfF~kiVq#Z`!`PUZ?SP3j zi}QpJPuy=89P*3*?-iJhY?r)r4Ik1J@GL0QwhX^T=8zID9qt-_!5WPz#ly*)_%HI3 z(SZ`N7cjBh-4<*rqIeK{ zWahSwYzhN0Sb?rfpH!fW-iQ*CPr6$p%sRbvabnFJoV6?l107EF)D*nb;IIT%s&8oM z9vVVESSWGr+&iryAt6~OeV?)<4HeG+F7WUB(nG!_(v&bemIshS?t}4aTzSce;ror- z1ZqnLQb?0DNHEuhb%}nj>T;*p^=%nn{3Uaq@F*Y$X)#I=R+8>Vvh4VUbAGv{2u~=DjKrbMio~m-)p)YWqV@J zRT2%tuN?nj$Ir#@LO+dffp%@r|G`?Q?%*eJ%G>^{KUn8o(7A_YNDgYMl2*~5e|N=2 z+giCs1g|VroZtMHs#MjTI~5UAd_;{HG5o|(w@HgZ;1}6ZxAue0LK1T-AXa|C)1`GU zd)Z1@@S4nDs;VGmuOR@rjYS7UaqpX~*U_1K~WyYrJoPD!#AC>ga^Y68F(=*{rMo8b*i^^WHXh76}RnC@4 zzg@`=C8xo4V9d$SWg8Pq(!`^)J;dYP-xlAC!q{@wO)m72QEwDBbe1_FADViJAg@!)-COWe|~Bw+Y8Eor=iqPi!eVMRsH>8>cQS4^Os8fx$~;5_uW+6@E_QuZ zUtmf#DL;;v`;f&6^>rJso&9_M4D$(p?O7H5yuauE?Q1V=H>@zy5 z&7@(U#3Dx$(_pp=D6{!`|4cYi29pA{-6LMH<742}X@1IS(i25H#&famT+o$LnMy%V z_xgDP`ykJ1z<(`AjD2jbrk~@4>=S-07Rlkv+XZw_BwE1*u}=K?6^9=YxH`(&Z2zzd z#{0iU zwVOF1#<0hE4FTLRlY0 zD@{}9E%JjzV0?fGiG{L%M`w?~k}thsgQwShF8h4z-r`wHOrsIQ=KuOzNh|ctCnA1P z%g+-gIgcm=-;FRv1Rvq`)f9W^nverFeHi)#xoONaLXyud{0w7LR%BdwT8#M5pG*cL z91%nZvMcBk1nM;xnR#x|s|LJCX;Y~!KXt|Z^-27tx=vb;Zk~(ih54NwvovPuEkwgK zboms-YpTbXKFL|e9U!f$Q=^rT-B8>#xnGO8`x^CxfY;bNA8fVxLb7vL@)3#Eq=VmS zbzOWQlNaG?W8FIbk2OAF&Nob_l!HW<+e+Bp9hW?(x|ryPF+x3cp3uq^;Aj^cSX)gp zyEi-TtgI;Ykc_=GlCBcny)(vd89;pd91s%FxrXXPbOs>fY7L&`dHl-brVcjSZ5h}Y z{L7?&bfKs_c|g0VrfD$Fpk)g>(_*EgQy@%qo+|78olk4g`D+|R1@a<}0HXHGYJFeq{G zwc)*LL1F?;E$`=mA5<;x_XrTQZPr89`@(vtSgZ_%YxbsdeV$?T-SqkXIsr$JGm?8?;dnrd2ZJcVCLYp;tydy7(jWqeU?96wS$T0zD%-!^0T zO}j6Dn}glp$RNn$r(8MVHwn5XymXiX?Jh&Tjd8qlQ3k_Eh?~a_ir9;XB(1;P5<+wZxgOXl_61=)+r_;#fauA+Je& zR*kC&vY01OjjRy+HotB#u8x53LOkt1{9}WQedda7V0_<|C^Xfio$&Dm(pnZJAGR|c z%8*p7eP@6COKv}0U81KGXGy#)Y``{A?xqc;ROIP2g-ZiTcj9GLcf(uYahY?2>;7RM zJ{3G+Ht}4Ro3ujH=+vQz`(b`3D@blTpxk^-UP}LM)-OMMyl>!Bo#C5vgg$0&lka}jG)DhQ@t+HkeT@Ke7J@LcxMw2t`=aX=|Wo*bran;iA|S-8Sz|MPOY{l$yrUKWrD~8O@YFO;w2}KO z=SW}!dyis0f4)MM_zgsmYWkce zkc-eczh3lgToeC~!=(hC8WMo7NZ$IbAc#`2P1Bt$PqV>3dp~Ph9e~CCIdgr-0zM&? zK^;A_dT5Yhc!`|DDG7Jiu4N+{{pL=62rq}3qqj&N{ruk5%&%@5j@(<3Kr!jgb=L3NxS%0NHaJj$xJ4^mOQ9hWdg41)LA)1;|A>cwwd@owQ=I`gvLd=A z2;iSRQ-~dSVYNTgtpk;x&W)@v)*+weVNCLA$Y` zvBX|X9I?K65?dVcvA0k1;{+p;d|jV*cgGxpu+~irpaJ3xDy&8wo{ot&Sc}=zF5` z6nl0)H7HO;WNp6Q>e%MvZ@8>k|13eH>{UgM>FEo>+}g-WVB&X+{cl%ihQ3+R$Upo? 
zX)3#70*1%K&@BfGH)-I34PDw!j<4=kPMXe3>@zDv?NekQaCGGWCP%|t+;~fx-#n#Y zMAxPNfpvUwO=@TOxko~`>C**G?u^fI)4U&gJZ0=MsdZayA`}IBs_ar^d*!af5v-o) z#dM3=WS`(C;c6H#X$cT2JdifK9+!n1u%cmQfpxTBYHd0RFNr08+lK)5%UMJ0#%Hh289D8$(+sitNEs5 zB$cN;yLPiot%(KALgIuqi;2uC;&6cggkN1uaEq0L!%JS>TIZu^ox=cevz1i|w8$#Jw8C5P>U+HOYI4`dlOyhID7#PbW+O^D z2HwM8dPid}a^9^=Tqnku&3}3K^{^YJ_$s;Zdl&w$;i-hhYaJ4aQxE71|vO#Bz)e%<4)QEx{a%^8}4*(;(LeO0maM{CC9`U(?E z6T}2rG$)rYSPnLOVN@NjHUL*bzD_w3l zeGfxGVdxKkkXHK&6G;bLl>cGBjj2&(YF}V{8*&7>wnD1%Wzb)?t83C4i>L2@94%_n za(QOCEX`@w+nPudb%`gL?9|cyrOijMaCsO?(MnY{oNX9C%OcEw>P(4WwglR{giUz* zLVH_^8mqNEYxCcpW4ZJ__!W5=Qxv53!jxv^C%`;Xt~Hbel$rYY;mCW_Y`p==g!wkp&+C8nNn2F7^9)`A4xKwRoB;3 z`;*_&b-wq4>;J_hgsy16Fr!Ca?Ey-{+H88?0nRxcwHnJqWR3&Gau_0N$Sns=3kRnFHgWyqX<@DHncX<7(`w?n2 zPrFqKOq(o=KWW42ca!V|XxDrTYyM00|h`I`+lKQR?Wv6vHtEGkdNW;4}rC0 zSTc=E=PGn^@4K4$p#8%ZBlU+nRG*?;n+_Zf1o<1$8EFDwT%4~l%z2W~Qh&KO-4*>( zDb7;o{T!ANh$bYS_%3q1*HR4`HT=fe85?sBGquVH>j_pIiwH0L>u*^}={_n30dn}T zf?^Api7qA-JKn>d#bM;58BeoeYX9;Xa4!kE-KdTetXiXX==<8mI9ZU@$6#fNRha5C zJP_)BLSSe=h(7+@Cqsz;syx4$G^{W^q2bVq0` zvlu{B=qb*2{jR~nbLeqGw&Lo?DKd$YNQlHLleSg-;DgXwWs#{wnMgd z`?K6mQp>v*uFe`VTBh=vTd8uNz6eV6QfnPl+d7039X=nC(**$09~zH0uCo95wKI}C z&-hFtPokSE(J!>Z`uClA!kW)a8Mjt7sMx)&aL~+$6(b;KZII&Q7D>+PSm#&>l1+WTPoKifa-i%h{b ze115GEZ{N-`#+?yJid{MqYzUJpvV?E8dyLj)KQM^OD;grZWMzr<8$nKLT_dFEa2T` z&@*PW%dA{9#nhX{nqaID^RpCL6clthpQP1q@$=AbPpw2d<5aQWkw?kc78Fwdx#$tWjA-$HlUu=X_6me4($ihY3nvqi7*h_ z3M#VmxP4TUXA+Xk{mlzBKGKYh*k&f#GfrX<(0{}!D%vr8q?g`Ddf1ypFt1ArCBQ1JvoR7ASxkF1mqzbdS+Hom>au0 zx;-}ex0sciF6A}*@N$TFe*TWw+&p#=$3cYbkhJ$R1)1feb8$eZPeNJFC4ZOg{#msn z9yFoC(nuuca=`qh-y~~SQNieE-1m`2)Ba$dw4(ezTPxCC`?{X#SR740q3Lr9K@Q8( zNOrXa;$7B+VjFA1!@aX=iexQyDjo)kPvyYPP3n#955Yn8a6T&67msemKq%nJM1JIt zpi=U5@Kz6}%FM(aBVhN4dPjp2!MCfv5D#Ot%1dcwU*$ufl}+&H7BShJ562n;mq!Vm zKR@l~k{)aDEG@chrM62x1c}F{5sr?#JKbpOQ6$y9;;dGV&nVMuvs3c{cN=*>HFT<{tHd>^}o^k z&(K6x%kg ze$D>Iz&$!q(_b5+gwvsI5sqF;qTOgis%gDsMvLl(%DZ4Z?mDZSRg0zj(3)VN9y!`Z zt*$-`4u8g9Zt_Otb6UYKl!ymq8?^(~tbbe@d!C(P<7ukXaI?a<0LqPbAfWCfB&hD+1;WZr41<=tZTBJ0xZ%)j#g}3$2pqYAT5n&GxZRPW_2G%6 zr5Fbz>Fv*>D|r1TcXC3uwU2DLITw3-zE$W)kT53k9)Bl2L|50#1U2R0{b`}{DPt4! 
z9bV?yi!X_#*aYP>1Kpp;yTQ#Rq*CJXE7B4iv->{A6>&_ctcq#v!(YQD-cp4+X9@aa zgf$-wD#S_J`XN!jm15`e49=CWTNe3dn;Gu{o*Dy$Rp`*I>n_Ab^Esl!HSnaqh+w~&)tBy7X zmoGL!)l&F@!lI6w=N%sn#X;nTY{XSaMl_Rxeedsh$FEC8J@pb@C!`)nFKXL$qN%N&jJKs4 z92MO#vL_mWObwJ*jHK3kg`0;ft=9V@K?^Y1s)l{ne^>!t^j7G;IGw7s&a5jjK};4H zV-rS`5qalDJF_L` z8KvwWTg$}d(a0$PoQI%h+VbS2uQ%|W=XZnneO^vkXI{E2lcHt)hD9_@9*337f-e6y zbsy#rwr{}U_b!k@KpX^i?>yqR930-%hCs4A*%ik#`n}SiwVNctx7u@!B6mO$2}sD+ z(SY**IG>y<$oYiihNr4C0jq0`_=S-+{Ka910{hk$_p3X=n=!#}FVbB{E z-KKTIiF<0WF$I$);EPSysi!52thQPcNly8O(oWx6zOueS7@Ojc_`9iyyw74r zoJKg6=b`y->n*|V-P1*>kNONZbyacrv%jAGa&d?Oz~Ssb-(gg@JnWjZj@@p7V7>+Z zXLU0?U8_=U|E&grVCDLBv!Q$K*8yL>IFGEZKS^yeu?Key1l&r-PQG-dM;Qr5>ciaz zRFkf9><$N-%&NTo7z0@5bb@A}5Vl?0JLB{y?fnY97OCKR<}=2UKZDy%>ZUOIPdHp3)k|GO!n+*TZwTqCX)lQ~N0ffv z?jG|^kA1llgIVRG!-sf1gZCZ*i_RC@F1@rLHce0AO$-gCP*V=!&BjS!2`t^K#ohR3 z9qPMP9b`uJ?7pz=R`w>tGh)@Qh5Y;}4zY8P@SHeM;o(vRv}KQHb-oYX${SEz|F}n0 zcYONMz+afNwq|{uG5mWLQnS>|{FKj|Z{@sR6=*YxtYukGKV3I|i0N36GipC`b03Ek z$irzm&BtZ+?`n69W93XtbsyI2rq!C`PPDo4;Tk~>~HY+McqQO z@olY&1H6>TPro_X+ep|*#VP~ZG#igydzfj-?YYk+wTlVgzpTz1??->P#Zg$H@7MoK zN#M0ot=jYUsD2b2^X&9l$B`Qgq6-K$_Dpy_zL{;OXOWd~)Wu88d)|QT!pJBV7;RLakcejh5MgX$WhHTYx?zKs*!74|A%^mS?2#G$wy$*KCq^RLW9fEyS9yP_YURQadN z@dwq*-5*#nn)t`?5OPd+SQ28z0V9Cl6+5T4#CN_?wkyMMJ!=e0D7YAz0)<<>*Qe04 zwTllT912sk!GF*w*_L4N&5(5g7^o84K8jgg$OOwiZ09EK5-q$>X^5F>MH8F7Lyqng zf}!DsmUWc-lAk5Gg#5bSp(uW|6cZrPT~t~69ei=O*JU8AF6O`w&VaRpQ95uS!AATk?Lw6P#O~lLq z%=VFP3G?@oY$3)2b1w%!xr3GGfa^%8DdfyV+&@Yiupe(I#Z@d&3oCf7))r7rT!@Ry z#%1hnv6+Q%s_mIDJt>&yBday%@a?qER(fZx4jEBL^({j8@6gEj-N_Ia_AY;GC+9@B zmA~J%{6NJaf&yS(21_Jgg+B{IH)mnj37)$bc!UTU*alMU08-Qp~Xh)E%CuFtsa_Y>=Ta%@x!!F28K{#zsB zGGar~0_M^40S%AO*7H*z(CT-nMu&OGZWIIP*{VWf<%t^hb>7ic!NKIZh=Xk%g(K|v zPS&T18yw3m@4)h_dA?WM3}hm!u@WKmZeL2C*Ew`0www#%1Cgb4*jnErdB+@NkJT;rrna5#u{=o-@?{Zf%Lo$i%HH@grX)N zC?1&CQ3(Sta3KC{Uf9>OGJ*&ogEhgLV{SA8*@1#ra`0|TLoFA4XI-j%VrW(B?;=6< z>l&ix;O@0Z@-ri`Y<8vzLLu9IE{sf3)SWb}X2VuYpk9}%uuD^;V{s~LlF@QBeCuSa zuS-jl7E)@>5l{X<!2<^OUx>`o)NSN516p5LnnA zE|^6{=lf`IS27l7;%{sQk19M1 z)*1IblxEFO~5#&Nqg04v3I*oLXLPB%ETgcBD(erfq zSJD6`h;(0HiQ*O+{W6yVP`NjegTXiI4F>q0ppm3!gkGMF8$iOeym)sVk?|bO*tJGT zM;e>J2t}}*o8-pE2dw%ZE%K__P}%65A4a|FAP#1>eHBZ$X$avH=$IvD@%QYl@~Pf5 z&nR(ZfS?$N(Co((9G{{0UDeLlvC@`~98vO235n^E^G%XaY~h^{H41hDM?_?_jH|6+ zjT`D$Bkka(eCB-;02Rnz(Z(B=`x;WcG8n~@iXbduCrP)v4mIS=&e;-jx*u8=uCAKp z_bt&qMC{|DF{??_2YQSU?S5G8;Zrd+B2^&9)qI(Z3d=nLmqu;2>}WT5ZfBe^LRS5wx0xHQuQ6`ODD{H2EAjKJcH06fQ$6_)ConnFOw8FxSa@3PES#;6;(UwU zkAS*{+E*SHTe3`85y8ps#;t@ns$A^q(>)dlZ{tzC;7(P`G7_$CGznwo`d*&9_ z7Pt^bc_)f@YFGIP+1vHuLL-mawLx5{y0Zo?(ghCezDfecOrQSQ#TC*FL`Jt<`KP`? 
(GIT binary patch: base85-encoded literal data omitted — not human-readable)
zZ7hSnPodnYhdP^WFBrf2-w>!pydI;DV`USs;h=A16Z5vNcNrilYv||{0;AY>6Rz5r z`G!O3XK-VC$(Q;GF_Y8)8G=DD3!47tP|C_vG(`gRJ=Wp|jl{^-CUzZ`Lm6xfhPJuc zWS|wMtEyg_B(??BUc985)=H$<>vdpq{u^M3rWkNHewy7Av|B#6Gj8vqjPqPlf}5U` zk)EDn-b>C#CUSH_G927kY<*4?(WuYrW}wla6`I(JoCjvC3sb^Y02l29H;mJEgj=&0Ib(|-p>;^q?#W)N)+3_zJkii~#=qI(G3pP{Mo>%VlO~MsdT6nI zVsWJ`Op^`GRRs3175>1F?8x3S0Q1tDvGpmOk`Q@GR5>4$uUUKz3GcXoS#b!-BHPf! zWp;2^73!nwE>$gwlsFf`A-mxiMRcHYZl0<1H)X7fVq?#>i(n_Q>1}UD%F(TZz6Jz0T}wC`r$kj^@C%-^uoWr-7>%ao*TYO z{Y5a8D2h3Ra2b2|?rUm~f1WKxk?+g?*7fTpnK0nhbMNI<2UU#jIWj*=1Bsb8V&_;=U%G@? zYFeO7Obm4X@((U7p!2pRIJ>!oX7C~81;Cuq%NG*|3;cp2@NzNOCc;Jj08zrI4sf<$ zMD#gex*-)D={!mNUPGY250Fz(mekb53$FJGyD~G`C7+I|{4!@S)VRwa?-D?Y(N{3> z$T~f5o-?qP;dQhBTh~82o6m_GT-3~t+x8eOt;2R1Vc6A&xF_3BLy1eW2yD<4BFtXm zS8YU<{DSJOGz)XKh&IaW&0?=f?aZ(nhVjs0O7o;=*A4~l@Y*dlR0da_B;FNyypB=4 z4{rsvF=YySB3VUQy9Ai;pk}ugbF=EMdPl{^0xc})6ed}8cJ1{(htGDo-6sa+Cq>?+ z;M@PEODejI6sh8cjQH{NMH_ttiJ3H5NE|m|*JzW7=SiESwN6BX#m`Ul1~nSjL(7IS zNvRsN@}1Q8ld26G;WvNbyH6^;8x(ENWSs&Du9fjmY!yIqucLjApw2JISDNF%_xz!VT}6H6fKezX{7H z`nf!-W^X03(6iq6F9Wd4M%#!u&Pb|}#<5zAyf_ubN)w?S63!@yk)**6YSNmr9%@^j z>g# z%XMA+45N}rPYN^2OIkn*TCAxip!1vgpTyrw8CQ4Lu5!f$yf5X-Kql@0$Z<7kD%jK- zwQAv>qsv;L9FhPOFmC*yB1KyCA!{%2rN;vc)At?&&=&BHL5*qV+@=8VbJct1*^IOf z6fC0?i+w-ymh)2*d(n>xdB7{0F=g4Uje{_l0g$>Sc@@m%5)Q^1K1@fQXK*a!km7OK zb@XFRv7Nlg!jHI;EETa@1U`R(%L_Mtu=)ry?d@hu_YaJt8A{UwU-79a3VeBXMlauH zIeyT3J2sq97ex$vZZ=N1U{)pvOm_^h!uC>9MH`3sbP}|(9+GwP2>=&4FujIOU#7*$ z(;U$m&SqfHyLB;T#qFZLd<=4l3v0DvG~Mx%i$)3KgBH`~Q%M=_#=>-?3f@3@hn9fo zg(SSa1twXl<*_LcbjcVhbOD7{TTruKl()~GZadtbkb`@wQSTs|3s-m81wPBD%k8q^ zT2YsO#bFZ!-N@-P&w} zsz{OZXP{&3tLyFon*r@yIrlK0I~k`-nK&+t#)CEZi2=CD2G^zho%4G%r6^X2Y!3*iKWJdsW~_>!lk7lz03WeV70qT5}D0DEUD78AQBbIWgaB0ykvkM zQ1>O$;aQSV|E8eM35^Gh4eTA(za~VaHT=N-{+-hkr-`}xbP?iFAOh;8`^duGAWy2Q zehrU_AdfSv%6Qs`GaeH^uU}H=@p3<7 z1Rrv$m-J^dE`0(YwWm<$9}v*}Fd_x_)e*u?M)MU~xhDM!8PjUfapNc{`QJ#61SoQl zh?T5#ZM6#9Egb&iAx2Eg5VhLYf$ruj?#B2fwR%O@ftxVE=q|Rn3s1>$HeeJ$P~uIN zgzgmi)f#0Fj!hMd7s$~^bl<5oPUJ}K-2VFZ)OVLVz%KlK{${FkHC>>A`zM@GVinh- z$G?K=X6S#drANp1lA6+7ii0yqV4`fDPxE^@vwu3D!hoZ!$bXg88H=V9?9MW$N&oAUYYjtdc9jpjp$_n4+Yu zGSR{fbF${w*qju2+;m7{bQ7Gng%V@sv6RxO=wkb&tosz@4YO}UEbT!w5|Q;WVNj(Z zVUHgTNOe;o!LQ@_2DCUI4-I3}@j|vF<#m(hl%fh@lP#RfvOpaT@mOX9Se~KWa*|qsU$bLDhNQa925a zzc9M-2@Yu=52PH)TPg?gc%GO%Ff$c~t}e6mI3^Qgr9VppNrOKl0Ch z>*yDjC=8tCzLOJid)PSakbS)sxSo4Mz=%q`*uLAt4X%Han_-LO;KQ1qcsld2l;CNK zJx%&1Bl>I7(#2lV#DavG_$AY`DhmO$;EL#zz~$_WBCDuKAPHD)Xh4%RF`?Vp`DAj4 z9oDjjee);6Iq5MB-%LQF;Nr2e`D_r%p_22H2ky-Ix}37S8Qj;S8bhE4OBT|a<|1!F zS0_rangFN~Ei}xL((SyH(*t!By_a)2&aa({z?Y|nvb+eE`w+9kv(hxio@GQM(LydAnEhwkw6C@E7qU0f4TKN_jpH9Pl#V!QXd}m@!cNS`D&S4H% zzlEfMc#)I+_@p4~s@Gb$*uBLXG^j_A;+N}d_N zEZhT`ds3sI$lB4=O=_js02f(QcneaCQN6UzCu1wcN0`CgEK{4smq6EnPr|~aD|Lou zT`v2`e~V@25C8yUH#gkW)Kva}SAJ*~cA>zlEeQ_;WDnP7HhU(^bJvDqZjb%ib8Dws zX-x}gz3eL0uSrAL4&2=Oghcv-{h+vs$v@Bu7XwM3%AB#|$!NX5HY3tf9u5`Ej*VqN znQ(^Yo~<_BS?uG<$tg$nR#V`%))rWp!g^Yvl~@cHS%u-7t|ku2ChRPq|Mj}rLxQYD zczFSDczv%Y$31zU{KM;y%=YdLQ*shXMG|si1h0xxv9K~CCbI;;F7aSD*h>4m3aKFw z>s#k{8}V{`C236ydl1t^sua{dV0u(_P;xS2GZBq|KHr2imqWo>%rvyzh8C8h*a|(X zE`SmWfwHUQ-P~3jT=V<+IDk9rkc~g9p&gF^Tmph-YGyWv+;N&9ZDs7+#T`Ab+l2=G zDOhZHzqW67v6hHh)v130XCdL%s~<~K+-B;g@@>aXHD=;5+`>d(~Bo#CYwwBni?n{i!bya&1PMP>%f;j*k>YTlBwjGz!Opi!s&lT$QpW9GAQrfEyA^> zJ=`rydjto;Z1g2_nIG;qqfodl*VPq9ztyHK^HTpRgBx5lh4G*QmlDLC{}}*R5kc-GW|C?dx-^I zaqxfAfZ{eIld{YJf6&F##dE)Ujos+*kkm8o*^Di6cbxWJ`B+{A92v`6o$rlEZ8hN1 zYuycDuA9k{Onk`LTlD<628NVo#8WKT`BQ;sVjoH-g2w834nABi{_oYEZ_og=Rm5o= zF)dy=cGmvA9Ptan>INL%hzp>&s}GBepVE5V~_5 zy$(Oq^fnOEjd+hwC`+3J(5IyCAeMYLwmTfYU(wb7%Pk<-)xR!;2#;=Xw0diNFdg1a 
zfdL}{_-mK`^JnznprqE9u=^tgMMVS->a2c6C=W=?aYhPLnv$@|#TKfQBTN(XAFNb0 z1?+9)&u$3?W}94Ym-&=}3A2SrF8L>ff)8ZUbMM7R#MFtGB(jQ^(0khL!6ezxkVXM` zkA2Nnz=gHTMu<;V4xf7iAea6yxA#RyEe19+VQc)F+xnOovcV!`>iZ1M3lNf&oKqF< zp!@5y3pgPtX>gPr{Gg7fU9a?B8fWO&R$YyGjnUfs(b_;Xa2H}LMQ16ANo8&<`%;Z$ zrh$n`96Y`#2^c#V&`v_q&u0u9H6NnmL@#rMNzYUy#8GWT4=JmoiaY*dt;*bDsDYKA zm6%yV+TqGwV2l&}?ed>7+VaC>REkTNih_h1n(e5^RLXU!4;EO{;dn^u3~r4gFplBy zgDgK$WYqfVCqamzobIOUC9cQ=`VZFus54ie%)qeDeMTl0dPop6OD>a^76`q-a30P!whY1C~M~oVWTZ~ou2uVpA z@Rv!}kKfIbG*m-kPJAGr%)+|HA716TPvwoJD7ApA;{`dr>$Hp75^g7a$9dKH@8Q|{ zoO>;6jq;;dS@={*%Xj^bctQZgl;rg6$g~`6bG()2s_bd8{Y5dQGXqmgtj9q@7=@qq zqYxV+4%N2W)>z5|2Ny-~e5dnWVkid(p6E7{^Huz?5jS+((fpg6pEdC+aHXHnBhQ5M zN-h{fjO3cB6ES)=BHELh*E-xdzJx2gvAEg2h)9SIg0p(C!MVcTHv#tcXRjX_m6c5J z>vuaT_4OP=VxK+@iYLN&xx4>E2KQRhwd4m&nv1CmF?_c?a3vM=B{@Z$Ss;V)Ama!Z ztn4QyWLPRX+&S$ar2YZM#+;!*jC%1S>>%~U=HT>79Nvjlm>BNvf6XU4O*@vjuSpgC>X&a+den!l^A!c;bL>& zU=7BGDeAzB{@&YdZ|3&EJ+t1%6+ERbFMz=VW?M`=s?^xj)rWM3&aZCY)E z{3;sB*EdQ#d^oZ(uQQ&aaHXOG{!=oGY%fOa%W;iMol(pWDV$$YlB-$$gaTE6uA&fi z1Z7PV6Zm>0an#g)f4}cy+kww?s2jjJR&k&ElWay{`|zSfOWFY z?%?`3(tYe7ok;Z>!%wKY0P#MUvmvFS3R9M!!h0JMxVS~MRKc-27=lNLo*&GPx59g? z%92*v!i%EC&U$E)(+FcnvtZmlYp0k|)$m8xNIW?XIm=b+O6nPrjiLvG6$4ZbL8wy0 zw1fFQ>~P)uel?_FOWkNxR8J6-0TTyX-i>w$S+2j>_Z*9h(cZ6jG_qV|{~E{+o`1Kt8Cq{Lg!PYU_9%)C|*{q+vv!#}a(QWBYs$^uncx$Z0^BeL_fH^(W3*0Me@V zk)D*mRGDCE1N$AHeDD0NslbptPC^_U9Da{8LPeO~-rnzT&;KltmL(gCwU=!2d6)}T z`8l57j!-8Iij*YSL`4IBT8Aj%IO`#rDc=Qmu7{XQO+U6jA+{Yql@=~(Qwt_OYGsnh_xHVZNb!uHEpn9{~9 z0(EM;yBa^7YqZcuY)g%6)X;|om&9wd|BdNci1|?;h9^FBNk}qa9((VzGCA3vOKSM? zv5;_tbQ`sJM)xO>AMbuvAT2FDWGZo3%SI%MB@<{cf^@pS zjx`=HqZ=SBm_0Q=|223pgcQbk9$dl_R3gEN>waF~aDptd!f(%uIjE8e?Mng8$}g|l zq7oElL^)jgwV55tii1y_la-TW?CJ5} zhKNlqBSb`cvB=qjLaW90c$AP>J0<1+hV)!og3I=T;%yBeQke5Dzh##1IM=hhHyd>; zTV}1yiCzqGUJQVItV@t6#AO7V8K}kbyyN-l;!*Ao8F^u4MU9;V_q7a-nj*uwJi`sT zh(?Yjv9nxGHp_b&?i1Ys`(I!d8m#1XtEpw_Ma81by!UK$9`xP7t@#~@?~76eO# z_HP)m$;wsgnnbc*_tjw0MtORS6LS8yD-r$t2|y)l%oQy&jpg}NWaK^(J#SqzOZQ$W zu4u?*|C&Msjy{b9YR!?^!ryZ`@0n)U(8l$qZsX&k=qt0jt))^ru!nV|z_H{>K3Fg8 z_qf&PhQu7OKnNRS*$n-K4-@ZU#YwP?Sdaz-(<@7HQ zU4o+3+(tx}k;!PCT7F=?e167!0R>g>Co{;nbB&w@$2W<%e|lA}koE+o%VxCp{O`>f zkN029%z)Vz_op)CKaoNRvNsR6=mW_Hns<{d-}g0#uAJYow0XGb{(gRhP%ivA-cL_Y zSd_Q<1R{s0Qzmnz-xN_Gj7s-Fe+X^53_FO^B?u3i+7*Sxr6n9ze$vL}Vq4=L4sM5(-RK^S4)%a0L7uY63W5{fHH{8vIo5_j7>`V?sSBD2mib$r~Z4+r?tV9ua z+^}pn8I+&d!7fu(6@JOW1oS6MMTOWgtBXI&U0NCQqp`s}K|daiX*GiNR} zsCdF#dlOT^POwpJ84143$oV>ESQ)P?Y&QCp9k7E6TVOd78moetqG39Kq!KB{fhKBD zw8hBpg9D~)`nNZDp?FqKZg|)k;$Q?YRWe~Fjr9V} z8g<2pAqp7qsn-1Gkn}#dY5$Z48GR_TdVuw6IB-q`D`5?WuuIx3vIO~la@_R6Q64gW zfXYT#F+TYJ!B9P9kJQ!Qgon9?@E%_?P93*{9wwxXWqWxJe})@9+r0vvVqbTs`XT~q zLsLm^gWD=e((WKK%ws4NMLT+AnY)mnY@s(IQAA|)@aAj|XhMV?KC8?>-oObE>`oS` zHGk<~Tq**~4#nm;_Tz}=0AGR^LuvYB1^UIHU4TIi0S(+v|IU2?ptyMHwlfoB&g zibW78pYbXRFRIFkW)`kYUHr5qQ_6Yyh>h_SF*+1alk+bT(=q}2>KaAD(UBawo3-QR z+Qggy3K*$+BqY;tlI|!<1qdl#wdd3(MBs@e+{u?a_nSCexqai~t@GLAPrWrHX{pg=p?T zb!DY4zT33*)0uZajkf&pL5<=j;41&kp;GD?e8 z6HlMW$-PI>_yI_aveE7o))_wc4Z3@!e(u*Jwu5Y0@}`*d=jd0_D(>;iQo;C|*;Lg% zo?n9Ev+#(0%R*ONCptF+o=;D1L(B^}2pKb7T4y( zljlEd2Aw{z49Y5n!L)0-BTIlDY_;`;*kD+*8lunP{Cb?Fowbo<$8R1menq-M^Xk)8 z5k^rsnJ9;&1!Zk;ZR{qTjaSeNw9v>A!Jt+oNmENv z^8(2f9GV`25hNKH3-;RboRB`0r+sa8f$jV#WfxKVzDWD?&1LX!=^vylhXQ}xTV z;r=FRLatI>T8zYFAks&J02uvMM-X z=NWzRb8G6{K7ZY`!I=9F*`J~i^7@S)$lRIF@vLAs$}`Eg;R$|+lE7X-uA2WNG+dQ? 
zk4|n(xw59*`Ep-y#zVssqJtkf)L|)2Q_#^z=rF(+=@*X+-J~z46et z5M3!PFab70l;ebvrrGQS)jW@nUlb9MLRH5reo8{}YAltZlbWI;Ty@~X8j%pO=!dwA zpxPvelIB}RXGK2UZWJo|{6Yb%xqEuhmLxo}7y+z8Sm8p7n z`Educ)8XWi-BSYaIHPI@`xJRP+<8zv>*2R&zi`$m-Ewi|CcbfRvPI3pO=U#ShOvss z?w=lF1bl%3G{;2YF{Lmvvt!L)%SS26$@r;%!p1rNB+%u@9?X`lXyRLzLex@HZq?AM z>Ly0ko1dN@f$EZ-S3II#aU}5k+R@?FmLHQ|Ro?90;?|qn%k36dQ&8rF1|<2GDR@h1 zHyI8=rGFmCB)x$`_lcx~nlL496HE@YZ(dr7dBOt(duv0oiC0^(k&mzM{%d`%2oRNu z%v|u$ICC|ESS8|G7&h(RGIe_mUi9h@xrth;Docf#Y6%|+wuKwJLIErn8mIA=hwxc1FV6j6@x|dK))AREaoPr+f?kyAnl>7RH!tm0!&YmiQmW zs2Pu}Rd2~THguAI;aAb)VYPe|=mW>jJNW(6A}~8C$Voqbg!c4MD}tE7OeF9B6uC$& zBChgP-lw7JePqRiWHU_E!7nG16snz-o?PHk%jp*e(`7MSQtDmoC{JvxRPoo;%s9Xf zMb~i3b0m9v&Mi7ygro@fA=wF5Rj2DBpRTU%Z67wIWnajO^nWA!iE1fH3heLm!B={!+K))L0ajwfO0barT}B8InLx57GnX3 zv+tCfU=4SYgMrj7seaoG!IU7rWAn$#*9>_*k@<{6SkilHjUjYEu~R%x!vIq1XQyEy zINtLwSo&&bdpbYA9-Gqi30HfG@>dnsT8TR2Un#TP6uve{nqILCWS3h#OjD7*nvqr8z&sX#a*--XlAKN3f0RVq-)}i}2oFMdTh8kYtU;u4h60OJYWg z&|}$$5DP?QcslXBp3CE%_e3^-tf~qYlGCMjJmw)T=q0c??^ViHBe|r5R%Z~8HBwbQ z(7}b1@SB^6uv4%5%>Tvit}d8t-AQ(PHs>e>ya~?hE>)V4*88V@^p98N8)X;3RouQ5 z`vAkV212%z%=q{{WDKkg{`Y(n`~05q=hFM|UN_|yoHQgT7DHe(w}&lVCXlDQ?{1W-+xx%>)IKirpaKQ=WDUw?6TJ^uO) zPpjrJ{PWliRfUh$_rpjN62!7eOYp%79?2NUtHeeB{&mDTiY9-a?d=4U92Jr;z5D{x z9tein4L{@pt^}!vk{TKR-cfbAIil+)gIf(ugYq&~RsxdsCB-PHaERaCGKF@HPY35X zy9G%gxuWpSb)MYww|4!`!u}LQRfBFerd;5^Mfpm7+o;5X=o>#|e1tQA(naDYxdU1i z3zL9xZb4Yx>x`{2f2-s^{Jsa4cY8oDDriV!Q!2(Iy}dz^Gu;?_PSVuWF0QP_`JuFa zg|#=Rk%Xb-%Hb?&Ejc?DStShP;Wti%AgQ^6EJ%;edG)2AK@Y~ddlWd6j^RMih09Gb z%DxS5N^KJ?cq!hqvNFUisIcl}CHhuXYAm}{$v4St+BSVrUd-v)c|cEmk3J)2g?{Gu zyE-%QEXwojr>DZf8?bj5SAY0*NM)8fDJ6lCiXu9?QBF~b+L`{uhVdqW;IF6rB?3ZZ zzZfSVU;t|i$qXw$Gpr=-D?Goe^w)|&#xu~1TVP(f6J<-BLqWwQBF+ub`B(R0L1}dL z*sS>>OAQ7c*O%;;JWyxf)-w; zL^y3ISzwBf>J3D{R>n+=d;1yqRSCAh1`pY=+u8cBv~jZ@=klaCFy_&+Jp!ZU3%~QB zltj2+E~N0|X6gi_Wb_x@s8OVaU?u%7@ZA=9h{mj<*2yxEuLarK$Vk-A9kjd}`lrih zk&m~yvKd3|7;j4oZ(#f0~y4z_%q38Mj`S~<(?EED$dEFPE z9j>^h{^VANr-5IK?v6D|I~Vn2Z;+Vs4W^>FQ9-B(yVVTO%V)`5RO^ioQg{1(6+feH z7=i+4njS7R|3Mx#8SceHBD>tnx07=C?V<_A1u ziR!;$E&JVbE!RDXEZ4pOPFDhB&`7k*b|dB*Y<_21W%Jd0w^>Z-*{^q@|81`wOE^EJ z>NA7&h;j03b3LS&J6nr0#LjN?UBN$FZN^hi`I|4VcB_1$#z`04~<-7LMcD^GZLHj(h!%=W&?)AZlp*s}>T;AUTz z(1Fs+F_x${zD(*rH1yA`5Vn;Za?G_HMF#teb@`oA6=kbP3UpDZzdFiYcUitH_7*(M zDc&!#g$c?lGoco-IB%H}z59-0X3oQu@dUFF_glF8xv!G%{w|%`_@%w%2`5zEpp%(K zVfBfKs2?ydDwy5Jg=6)j}2ekQkbfyjX2)ZzUUgo6plql96qv;o{86W)Oi)<5f(UlEKZhBa$r zwoMWOgK1~h-$MBVx2Yj+iS6kja|UB0irV0_>8o*xFS$aA zyXCVqjAgj_<1gY0n-nJ71Bi%7{q)id_Gzty@B$3>8>KO5*tTG5tKi&|yrf0vQDrpr zHk$~OSQ$i*ikHDj0n7+q(Z?BdJNvW;DKC3+h(#jkHp9}{>y=P>o`@g?QpFO03>-pG z(=(Fa;!6%@sg2TA*PF$5%_{Kp_-DEY;)jnM!z9lQex0O;K`Xq%u{6mJ#-ylnZj?{# zFKy(jR<9^28%c526Q<+kAs8@#4k>ckB1}gCj;Y$Ov4`dgc>?_o+2!OX#BJ+^_gZHD z6CM4dZt%r}+}8D%9#MLMFDqorv!H_64iHeSu)$&01{RTn*kuJA4lyM0b zc)%a-_E!_+D8m-GW;=WGIU<2r3@W))AP@Hqp}P){(B1n!3wMX}@h9|zgp|2!KxZJ{ zl2T|A7oKr#3dfmn2qqB`Djp7peG1J+vgZIj6RiDO<(C8>aNSvm2r)ajbUTT-J<++dOu}VPLLr}_f03}iL*S~9D7yr9+%7N%X5_;l*i|R zn0FQ)2v1{_R-iF|Teq}3ehf}b8MPdyXQ1MXoZX-s%-9`^(aqEc_YY4K(;u`ow77?~ z;si6YSy^ZD)5H}cs9+dnS9m9=eeEoMtAz^?M7$lxW-}++SlisOviwrt`amzPt)<2k)-xf{0bkmz zl+!Djs!Ay*>uAJ0EkFgXv<}U9IN7P4aoU}k0mKa_n2u4QoTYpnCHF=dKYA2j1remr z29L|dL**01^#ah_Cv<2yKcO`PXv@nE)inM_L1z^tDi>5cVOjvzAO{_LcUk0`Ygz#g zT6u_kimb`?R&@n5EZ)#qi*aHJ&FQAnAKT>kSR9A)PHPiPFfwzM6@qYI`3I97URt%gcAe>GZ1 z@j&Z>)v4-wY-LWi?dAa2{+qS^Ky9vn!fmB0uY;4A=ZcDnmNNB2tp*d3WH(8>FGhJn zeqWc|!K8{dcr^A5D+7}TWV2c>pwDeVIc+@VOEW(M)5cwu%m^~UKA}Ts5<0l~;zrC( zO-*2OPjUk%MsjXRH-k^l2$4w6iN&S0u#jVFQ`^+b)8P)?pFv+NtBkqAEU&u6$9S0h zf*K-hVuK?v1gyTEEu^5joSPLE9LB7a)LcD#F!ekg;PHHWU 
zrll(PXYSA+PPlMsMk-@jCtB3RJSaX|G#OIz`%?h+WycO3lm4sP8ZO5azezSS!cI&? zlEzX*SG2R{7wKDLLZQb&&;~_Uf|_bO>hRx%uE!xe)5rCE4@^#x7}zlfDu=xon&c(w zfKCjs#pdDVBwuv!mRYnm%HE}g7XKM3Qo-qbXYVPuQ9Qdrn1S=}_Tehc>l5h0gwg7u zH8&WKG z*y3ZeB0Pw-_VEqJ*z@PO?eEc^E(Tlsp!H&+M4Yo5y;KuZ+ql1y{)Qbx#k>#t2J%y|Q<+jPm6)2)IX44q|3bjj5~P4alBzpBWWhLTOQBVxqX zX%j56Qaw07y`Qko*msmoE%C3j>a?`wP3awk0!G5Yi1`?1I zqR8y+6Eq-8JAGh|jxsp4&m!p`&9bBDpB&LBjjM1VRdcnRd=mZ?A)UgXpCb?szEii> zeQr#j*#2HvMe-L3w{uq$NyC1XCi2`g23&3o-&pKu&(XN9pbVky!z%ru{=yX2^+EHv z{{!GR6uf?<_ zIHshj7`xcwv*iN$?QvrW_Ij>ghD5z2Ci9>|P9hB55KwWKkS>7gHM zapgBKy0;yqeMuKQLVnr-;tZ+3tB#SSR2_IpEZ`!~&w#MLYM*4quo1lG84*JRhO>_<#a%0!u)DvcMQupTKe;~s*V({*LV3X21`nI zIuu9f#)jLSPzsWWJ$^ASz!>)B<3QMu)w>!u#sIn|dmB@ynOLV0Gr)1i-Vlq3C}J_q zoN5c}vc$@~u&4<>cisPpgpfW!$J8NX5dX^w*$GKu`YZ@}gXrexZ&ZNF?DY2CPgm;; z!jQ6X3YQqWORq@w>0Nx>#B*eHFV}qHjSXMIZVX$K4=jQxCS8-n5Huln3lUGL5v#%! zab2Jp4M6$_K;~_WAS-G^VrTC66u*GTTo1P>ll^vSZI=mFa#U0tB>&~mGF|&C_s8e< zw&o9nT^xpq!}5Xn@u9!A?~g+!-(Dp~Ve2n`4Eq@!XwPCI?C}(q9EP-i z?bzdX^irDrNm?$_QsV~+kd3A}c;~M$N#rX*ZTQQKeqsYVEXEKmW-0aR!U6VQ`6ZOd zW`tGYy5H3>lB=Mg5Oy6yG4rS7c=O{QZap?hS)Yj3dT>_=LWCf3{a4rpHE?!>-R<*3 zyyl8ki`3jX?1_o#QtBthHIY_%M9&qz=)v)olfaKl3JLoU-Z)hJ=>EB1O<5%B-=#F4 zyJuBiqsFt3#$9FXrRM@EM~!-Y;tKs;aAFqu~|(Vg2#_mNg1W7Gh6d55zm+ zPx{#9VkFnusZ}JMzp0?rr&Ll{DH-gBKg6;lX>=Ze)804YC;dAS|w!oC_r|JwK? zC#O+kZI%dMK~C^Wvjha`#?&H{o(W9`cNOA^YdP46Th}VvVSU0ABc9jcgS=7 zM($!g(sjR;rDf$Ry0QE3L*vRwN!xDkVU1j&49!+EpMZ)HA&wIt{25(ne_%?9aL3}wKzn8=s%+1MYR2(u85|P`a zz?`h~F0<+l_4TIxVyW%7k7u^?h>rH26uONx8hDI?%GnMMN+J|2AAU8 zyze<@eD@b5V~=O<_3XK(NDi|io)A~ihjvGFC1?M(FD2LYGNtM_GR&m#VdLp~@DRrjR5JTsUw*Y&-E zi!%2(647&F{|mws?IW^*3s{hT+(>Njz-OSSI}i>+@zd8NEdBf9OeUQQI-4TM%c~{6 zs5%D`Rs57$LyH{6;o(X0QC)h2V$J$U-bRv3Cp6&(^x$yb!lg00iJmFQ9B0f(C->jq z52t^n0H$E2MLNr^hJQ}4OljXU5wHUg2dO-qnhy+_54@0>ZxRFzTt6jniUU(_c_E(P zymu$Ovpdv3sYz<+jETz_Gr&Fifxqm8IRlF6lDkJXp{%qY>>esWac)N5i!X()MAHeA zt>@VNqtarAQh5RSc~msO{Pr*sSQ$IXXGoAz8!tPvTxs=182)#blAw}Vie9B<*B~LB z3^0}onSbM?G{h|tMi&?$@o9Yr*~A|>Y3T_Ow;m%HP1W)VtL)<#OG%AEo8Bhq+l?AjvGP*+v4*t+n}b?df%PWXy}VpV5I%|=M^-u5bQGs1Jx|0YH^QA z!G*jWgt7*+o`*^#9FI&a6YSeOPO)>ym6xKyhojZ3AF4t{W7m@t0FptoTfw17lR3ziD(F6x1M+GrW$<_Gz-wD)YAU7u)G`>1_K*M|M}=AlC&`_A-xnR)~j7U zG=;E!k$z&5I)yOyVDG~%P6|5KAm*VQ_j||Tc*VC0^hQ&e`HqC~l6^6%;2HR}{h7;` zVBm|?cjSUB!BkD&n8bJv40500)9)rP^1t9iz;>{H{6N^jGyo(%)kq`y#R_IUg2tV$ z1!=O4LAGiCwWD}wc6K^j_e8HWd4km}aHgjL%pCnONVUO;ym^3h@vKvIxHT1R?AspD zHadaxQ0JVo{tHHI;Tq!7=vW=2vqsl5dm80}IdtO1#lG!IycKa6csK@D>Hva^b(TP& z(dr|fhF^n{W+UcqA75O%j$WA&oW9Hx>qq28Q!DoE@adhcCtQat!U$S73 z6E4H^EIT{7awJ3Zu1c!$Z!8(TC4g_&9bTEm)YO!K>T_x352NPdX30aYb<)&g>$5|> zzl<82Q9T=U8p?hIGK}R2w(jU$DOGhbLW3B#vR}CQQ(F=vo#5dlDHS<9eI3J&M&-C18 zwORes2+Yk<@B`D+({rjERZ{Q)8H`AX8AoTWN2|;dIyBJb5=BMcR0j1stAhr#W^eJ9 zEi!eluqjy4_@Dh92O}$P-oh{wYR3fCUxI4Sez#UWBFd%lEoC^&k`Bggd1D&!x|HE( z{xm^t{DjSAgxgzmDp6qq(|$Op!JXl5Xh9AB4$(<;BYS-|(%K$I={zcY9G-8UWhBkC$*xiw`lF?W4w=hvNv;TIK32vqT@=79XS1U&(8hU5D6Q zDJUa@;&_`@=9vV-t-6i8s`MHkh0wjHz4&Kc>%j}n<}_{8Pt#wLuKpi@>z43)m&$$Q z$&1KjVSi;bh*4yjqk3DpjL#=DDR^ukksvb8OKe6Dw~~}Qpe1YRmen^CQE!8oP7|g5 z?C4IOXt;TBqqT1+9<0?eB*~o}yF*IHDs%t$^YjJOQef%q-Lc~#Y6m!e9f{u{ED8B8 z=~8vw2J3B(2pWl>Im|68kPEgPK8xRg2Ta-mn@JGU9XeAN+#`JKX--IuH4IH-%5h*p z#D=kqN`9JUA6}-o=*F)MJvgT3WC%F%`WNSTxwBWMhr&t+(2cByySguS+ zGQ1R3M9=;&BCGrDXhY^g-il9aHCRU zLz=$Ptt0db!h*OTNCtw zo`B!GMxCsN?@#q8e|gKRJ66>!7Thq zmgwXDD#{K=UZiDIU45V>Ld>-7?o*lzf4n7TG$O z5naOell(Fwr>AQ?E4bhd<5pwbW|sJ^ZE6UgAtp|ouH%j_j7~)TP(4-1-zUZek>#kc z)*k-W__{KBTubhIrqftSXIH?JdL%G-w!nG#CdF{o@Y;I3j?r7AYwgow$FcE!_w!ZC zJ7y*xJ-3h%2dt)g{=_tQb|?CX=_NJVuO_LmmFU;a1@I{jpmDb~c3n3yTuhaS$81dQ 
zqJW4;Fe<$0CzG&F#Oz+Lg@BM=^wTOGWTW@4U>1GjmqnOj!jXi4b0VYv{}4y48d`n` zJxzJg=$k{X`EImv`Y0NhW2J1OZ`T7F-(s|p_;Hv{5x%(JyyBnbp(sugmL$qd-s3h%#9OS8who@iKQkhP7m{il~2%9nqbz`!gBzSO5 zwaZl%2=jcxq?yzuG)xPgN=B9znIryvOku!+LajkKJ>(tRQ>nK5nD8#qTHZVw*zz*0O@dsVK7L2c;$7LTQV!zfERGe7QtVlua%L z)aE@Zs{N4nH4rSP9UV@H3&qTVV}~=Y>AjR~I(2xRx0oc*rhaRsvhdliNXhdNf4$A% z-z+OOLJ$1+6g%(+X;+6v-^RE9u*(fO`koep9)VmaXoVTZECy}W(Er%-@A!_HqY=Q# z0YhI!&-6o0-cK6qmsNOLCRR$ro^PcG8dZz#m_-MD{ezivBK<(wrV>qPPHesE6#(Lk z23a*SLIPjKB7R?3UyaMYr}$|1n^BmuO+F24C25F4*9Rg#LQi#`ScLN(%upL;>Tw`g zX&a{)jB=x{`HcTZvdkKab_N9Fyi;p9uM&U@ePS#xnso8b`*q z_sCr;qEDby%0YzM9i02)Uf$dmDIa+(3)kY4>JMpEXWSdL_bnur)3vhCbEBT=AYV`l zc5zP22&xvySM58@oTBiK9Jl9Dt{)W#GcK`E)Qp5^uqe%Ftq%=02?UkZF@UW7GMu4V zYOklsUM!59Q$WW$c|*DH`uYHestH=zsa`kMr#D^~P^vZpRdY2mUT*GZJnZxbzT6dI z;nc6MjMWX6c#b>{qoS0EvWqCc()MwD8CxKlGXM|iaKjuM*J}p2eR!QD z)S%r+Sts|d@u?kFT-yktgzxP+0XVJQKdppq{EbD38xuhOhq_Lr{s8{atH4X4Ut0Oy zQ6;)H5BnO$!>|CgpZ2@^?$F#{LZ}TE7gunnkMp$C2 zki<ZdKdW(O5>;&9zGE!DT z$%6*Umo7F6a(~0?K*R3;%o&LLU<|kVfy?*QhGwyO{LL;=)a8PJq;(D?23EzzSurM3 z)DzLq3u%*+wL4sk3^=8z`!sW-SdX2`Q?0^@om4Vxgox?ZqRI4o`XsxTF5Y-*gjlgG zNmA;bV#2-t*mxvTrLE}DLM%2R&Su9LaYh&s{*kNBI*GLdi)FOSwF1N67*q3w@oh}Ng2I%5|RHgKfIVHYXk9{0Sy5hne zZE_X21$DfNMSPYZ#!C1nId7g>?(CI93>9((P@9QnXevR~dUA1jA>zI%`0^JrnAvZ4 zeALWWZ#BZ%;7tNi^fI%OG`N2bW6k;#(yc9y;+dq3Cu;o*{q<-n(OFCv66HhJnv~2W znZntq-2Z^MdH@A78{~#wf^{)}I=ISsdv`(#$7((ppdA`~pslmAwDU`ZkeKM>S8wSF zD!k`pVat*pR4?ch*Hv~zT7;&VL%Us>l)v35KywEN@j6 zBU{^pv}zBB`Tl|1iMLJF=p1r^1>q>$A|WLJ-=?7kdF_?N$VfUq{;%cr{p5EGWR(Sq zG-h^$btX$1XH^l=ZpM5FBo(}ik$(Hqa(uJ;vX;{~whLn=)8@`~Q+_(A~Zj#PT`|c3$6q&1NzIW}0j1^wJh~pXvB9wOZ#lIM@8f1wI zh``15Nz!4xMOjr~f~Bk#ZpRrkeJ!7k z)kDzu`TrQJ3xks|^r@OZlc{yuTkm(p8rfqTx* zvfC8blRybXuMjvNVhSzV)It*J$E9O0v;2<>}rQ4G)NkJ#eKmP}+#-)Aiohd~T zh63)g!)S;l^tvZ!fp8AZPx?qPzGL_+kvVzZZi)r((oCqBM*ALnp|;%#xWi?$lh z!mr_NxPhsVP9};|m%OY4YL9aTsTl{cy7zpQf7Bb++nbochP5H<*fCP?Oc`Hh`N4D_ zXw~dGv~`ED!}!OFKM{O4O|hceB)-s@*mBZDI%w-_$-PJslH%%#*WwZum&2K1xqOJT zzc`W$u%ir%h8eTbN!`63E?^1vRa-_KEK};P6e-3z+^Qu%0CPBg8shlR)0(Nm=l^=Q zr2q&e+)1vX^6~94kWu|X=8&BBpNny@eT#7|IFAr{*yXFw`Q2Rqrr{N4RQJEl-WenA zoH_RPe<#MmZc~C}R+EC@kQw-Bwc#n&a`tcNHFYi)Erpr#gw@1Sj!_?~m1*2+>5vlk zNZ9tC%h;!oyf;^wKw9$sB^2EsrO=Kx8}MNDThN9Zv`IZ1N~dg}(SuPAtd-kQY#?Rl>xc_HsbXy|H60Iyke`>$idH9yt(7!S zLG_=d)kH5l_J{8Sv4TIHncOvG^>dk@fDNgZbkU!F!`Y+rOfXAf&F08iaIX-e%2ory zMhIrlK_IfYX<{bkdU?_-b)Pw^?~ZHk>&zI&DGxYZH`Er;MBcoR^~-uOkZ*7bHIx$R zME2KKA+1lfv#SgxJc!E0!UXts#<*)#7?%_S|b z>uQM-N)`$!Rb;!gK7SpLKSzX2%MpDELZ}79h~kpT@1Y!)8$96jLu?Y>JNY7_mGRQQ z=~LS|k^`O{7WRyjS_i(OR%^7>v!smv2=PDP^z=w#0%f>)Yp%}ra4%2)5BwZUUkKAb z*D)NucZn9TK6@5 zOzfuy8th?3jWU{)nJvP)XjenJw)|Ei19)8b=e&3FcqeTz7z8G9o%U>14GRq?A;kH9 z;A?DD5>}L7ET)A-S6O8zBRMdbNoLF@K|x!w%cOQdE>LAbZQ;!p7uQsA#5W`-w@j z{lI}I!crzQ6smhVn;p6BC5837kP5Id{+XyqPzeX(AS*XnR?)7H_4T~0{ zb!nNEV~6SA7QfLvW^(z%`3soMIWhDS;1OJxrmrPO)djGm4vR&)lkxgQ6bSq3wm46G z;&Sz0sQI?w&bIP$lF?fD=#EL$_qUTg=p}IU@d^k;C{kdm*kywczdII`Sr8-#osdB^ zcO_TEiSqoaW<;v?`S*mgsN+nUw_Dx$vi2=Ix$p?%Bi!;H4xsxNxt*$3Uy(&$htgkX zM7F$Q$bkx2L#Pw>;k?(4m}!noNxv+v^IkqZJUmcwa7TRr!uMWD)`JGPyHkIqjDg?- zoU%Q~320W6nf(@FdNGKtPNlSCnR7>C&iA2dZCOPr@!yTc#u-Y}KIp0hk>fZmZmIW; zirjx>5$lL`Rn<7DjHm1GpXyANk18Nx&|c5=u~*@|n`JI9>tjOm@hSaUfp``7nTfg4 zdh_^%Skh2cB>B(ewyxX``A!KThTpcpuMvj}PSBsNKc)zRF2_lgXfS|9iSex11a(Q@ zSc!7l2y6mu0B?Xmjp3tA#$P%QQ~i$fp>uOtj__w0@+5ukq$mUNv-sxhlt9nECr02> zmXd1Pt{-&I-(ZBf`!}vKR?W2s6Z{IXdxsY@wSfgWZXdeW)(^sDodq$u5KG&rAo@So zaq*cXW%hgqTZbg>=3Q6@Xb}hhKpkniwRdx9$NsH>6ktZcC&cXb@uxz%HGkXR8o57d zMz_|@jmw&vE%i(Z; zj{(1_3)RUbUIKGUI!YRiH|I|{+U;R9lQ*UQxfYLS76PMs<2rLh6IX?rmA8d7$PbKN 
zem4k?F%{E-K_&~z9>Ih0Xs@d?5Hw@aC~eKTr+-_6_%j>Qhxfsl)_^ z>Oc>jl3sLZMU<~Ts@sMMvO0{}M1Q-l zv8Zm~PNw)K$Ds7R-4nppT1{$p^XPRNNsUK$qq)}*uahvU=uI~`sfMfcgrmrdCDWb2 zQcsf1=e=`*KuaX;saAf4uQog|stUfVqNWr-hol9>sSS2Ln{=51fcGXfiF@ zOXO?1T~6>z{dHuJRDSAB&532WNooc=q<1We8+Ef_?D~iCNyhS~GGbkPd?t(*`OdFo zzaYs}3z~>=Upy3lZN$JD3*iYdmD7#^JcS0Ff;T=(QPtmTz0b#!Z_G7yH{!Nld2wgs zEZeTHZmI@ugzwI_f`euV1D5)sUk;4LBCELC%OyYc2aS&&lP`jga?$?@_H+3ykGf0z6t#9X-roYtxwE z9Y(HL;fR@QQ0CYB;r1;tqRPLK(wl%z0>Yi2>7V*^^P>dfc`b8)N1h+~Tn+p+=dukX zpnqOXZSTg2;_qNwf8#F=yFMR~@DT5wx}diBW(|6$dNhUIM_!QhJC0-)CTx>UK9suS8om~ZF#X3I zcs=cH(hG_@Q~j2n`dc=L-tfJU;jo?sn%Qz8|LSNab@vsIm5s> z@*U;ni(}`~luZX2jennIm3EQhmPdT%GRktV{&^kvnQ1n?vaoZ;lr3cl@y`9Bg&O8U zM-Es^g~f;rR;>z1j|3yWwyG(5{*rI2;INW;G4!CIz)tE&G_l^(mtNu_SzCfFb2&$} z0h%jPc#fLcC_u7#N6q|zko0y73kBn2Cai3%`AXfKTziU>6H?9Gw8-gjuFUbr;iXxe z*=&Ybetl_j{e-X2s&;-md_oNQxxeMKZF#z#iWj|HKCE8j&i`5+HhQ5cf2e74w(@Mr zH|oJEGKx1n8Ax5)k<^kRe9oG=Dr^u5rk)+<1Q}S=hl^pHx|Nnrh7~rUX5N}kc@j<$ z`|sEXvUc1sUQrcfN+`F8+-xd9Z8-csr5uC*m$H*MogBFRq4|YX&-e?f70)6102{nthDZFAJ)5fEDD*a-1Kq6 zed>Xinn2ji`h8sUs8)vyoqR;w_Mj>6o$|3sKrLu8b!#^=H`XmOS}v3_Cvqg6MG(Hy zGR2l@5ACJY_5kzGu1PHUzQwn2{+n5~oub!0Zm!I^RbSNj?_jc_$Ok&Stoa~B_@n6* zDSbU8n*O3KdmwU%wl|(&ZP~Cf;^-9yU?)=JTP|hjAGYRRPe|M7{44={o#hR$=hQp? z31yF+&8cmkpmRns||pD70JT@uYzN?#Li zW#%;i&XI`8iee}Rv^XA}MCuRsa;ZI&3BNefO?M-$WD-B%7stpQ zRL?6KPb7EX^raF+=<`}#c_2{s^6;dRNkpkon_ErqvyfQFq_%b=9E&JQg0&J*`i;(@d(pvVb{yyY~m#nG>c>l4J_g%n$%hohrg8{UMgwQV8u&lgsD^J z81^r(_B2g$+hi{SQSu2B{tubG;iew{Xk1b&p-}X~PRyzdyt|px8{Smbf{GDte|_fT z6KI{A<76B9!%5&e)ot;vDMI$kJI2)-#o@$QF?@CAgMx-ma>oo=2cHfwGIS$=e}V21 z!xCArv1NYj#Y&EDvpK=WO`*&+jc$ZDuNbZLC_Y>}oC_)4@6yN9A_024tO}buq^02@ z!G6&p$~>$2ufSJ63;RSNJx?! z8l32^D$)pd5eXR;I<0$}2?@osdybI4jmh%Fz`z#K%H`CUPu||{ zV(;K+W8)C0rr!$dIXl>&NPbf02Gu`HrWY1KO9?FAhbw$d7s%Bq-A1*a$y*g7Q1NkH z4Y~kB(;|t}OA;irFrv8(iz96Hhzx$TG%L4eV#+PXJullNEyh>>*@L{^P87voI!xg{ z9pZM{z!%jkxw6a?m~yL{L-bnh*2sy3qo|s@^8Knii<|0JhYrBHdb;HmJs%DocH<;Q zz^Y-WHrq2IhjB5RjoaI!(Y^@29jhWaL8QeVN8YHCj%K1(FVo)p{#0S1^c%RGG(0Ho z$kQ!%d%!vFkQx#{tao2&<&C&AQ-%5gEkMovbb;h2HNv-H%?9)gJ%~LiG&%xy?FUEP zOT){91|Xn&WGQQORcyoQXWAS;bo)%f_*=}-JK6tGCz$G>jmfJA82!P_T#tzbU^q4P zUCKZf&2L7*PmbLBpt&81;Gf!|jcFtjX`C)epUs6%$c7d+3Ctf~_B7f^IIxSh&?N zc0P}vXOa&HSFn*ZGcBeW-0h3Tsop3IGU#n2>wg0-9pYYhao5KGU`*r|CXHH-g8e_l zs_zc4=c$X`p@P?r#qV{w5fL4+jK6-UpHa6ceiTa777ufIndDM>B(gm&8KF938GHEZ z%MhCIe%~?rdo%f=F@-k~7M=(rh%o1MU?_AG79f(3)mHsME#?>YKrhl`fa3z0M$*D& z@4L9LPqQM{-BNSNYdx_C28PqB0iDf8%KzO^L+T2uvU+JAGVw>46+SSgC%Pv9YKR#v zd{*}cG!)?WB@QrmA>pjq2d<9E_SrIfBVPHUB?VUu*H-yOcN1U}x>bVf=?iuOq2!D5 z8074~=n#1#@W#uYuu?;F>z+UVJhGfw$kSEP$&Yw@o{+Z@9{EWHd0EC)gYNP#aUYXqfnav zOdlOiZfX(kP#n4u7w)#0i2^S61trhAdH$0PG$B)&?|%s|wUz!C<{1W1`aZ7Hpk(Ee zaOMH>^?c{<8&H8UzGs!v(~O{-6TKFLCM~JRNn4Th^#x)RmH=EiDFv}>Ztm{r=eoBH z3aEamA#P}X6qLfbIsW9z1|8o`iAw+4Dp1YYa4d3Jk>moTQ2Y_1|D2QzQ>Qtdo5P3o zMt=c~eaW%-$L$JR)P#ls!d3KDnI`A_-nC`;8RM3XVf*P~o(5DXlU zXA8Fo#oP>GU8`QM>A2rt$HGpH6?75YH6cRtX1R$yvf=P<(&q(}!)Z>l!WGfjp!`b9 zR2zEhP6^Rs*s4DP$Q-&MDLZPe&42$0%UR!`V5A1~gb>3uZ3>(NS7X6;3GIpwey+=Y6rX z_+>4Dm8A4=W>SrxEF4ENfq+OpdXl{Qsdhn@j6k*p#@b-3`;6Dtu z9`n)4M>!qxBbkX#2co)(9dB(6w5;9K&aWg@3Xby`AwEqT4~oonth>#k?=S5z1(8F& z+{=&uWokac?qfEekAI>DFHre8urDr1J}woqWgiy0fo7wo3~LM-%WTQGC3fSo$EDLs zo`a4o!}on)dm-3#-amJ7vvW~1ZlGMvafj;7L3M{*2HjB!^4e+5QC)PRLJ16~)Nhxym2f2ka3`MJ8Qpj}zi&DB#N>sODB7fFSAdqBskset4 zk+|A1y+zimQyl-0Wa@dZ|B*-BHok&epCI8FhhhS-O~j!l8K5mjV%%$S28oIIj=;t9 zH7q5Ca%XoJ8;=Zy*I1zD)NnhroPrMTw$*%?$n{KfsNQ(&vifJbuT8rko(`X-t1G$> zp-OI^@$l*@!tCtq<^6s1;Go3%#>VzJBvi*7aWQJGLSzB8qT7;)dd04s3We9kJ(oSj zE}vO_X!@GYOOL?9!VC8J5A4GOLRm`q{%QUCnri?#poEit+a!h!BavWK 
zTapIoWYO*>N!hth;$WwknS^61ebb zM2sgjc%_3O9~O}FG=8;(=XVmrh+>pi0I_g!;p_y=j|zmO)qci)ol$1Wvv`I<@9RAK zVD^V5j?nyXiWtr{7Z;BuG`gZSJgSQSDjn`7aj=B)25WMhc$*@b38PJhH6>y=QSzXu zwGZhNKCXEVazCvO?32oGF7R5C!}SF^a*C9ehj4s6t3kb;r5{`enbj4^)vOD;IYjH< z2q_)Gp#Kzh^As4!7{oqDefSewK{cq%a(ncnuK*piJzfGI)Jor?UK` z*jtSZF4F7cTwJ*S;d9X7I$67R<)GVOROqhoPda*P?nV5TtJR?&eKq~@uk48|ob3C< zU2_P$x=;Hg1}Ls7XdcZVvhW#aPc)nAv=iWWt!VzEhhtkc1?7cs1$V8y-y#PW`2v>0 z>bFO#G_Hu3M098^CqjRb3+OW@c;j12@Kc1g?gTX3W!*Fj9Oyc9BVB<+=kKd(rqF7)fb< zjeB(?KCOuW4Hw<}Bd{i+?;M)pW`~L(lE1LB91C;ZSN#>6-CGxMLc%4;7oI<%#zgAB z{l`=u%23MUxC3*IR;jtTkXl|&c5PJo6b;=&ZXl$IUNFp;>W5KGp^VXH+(N2+4Pl%b;<>+bM+e1VlW-$BD z@n&e1u8NY+i~I!X+>fqpcxJm7%~fouO=*kWkTrd%SDLvqqtTwlL+pYU1$3?u*xP2- zbob0r{`baoXjX?ST03VGxY-RlqcwSdI3c-ccsOkGIUm{hx%M_aH9L~mt}-K$C_ ze4es`xw_LR(<$ge7(2$+!-U~%*C?l~CIoBc!OizEDess#@y-YLq652G#jfpI(qmQZoT$DgRa``On8o=r=+>3OXOs%aEJs)d;N00>- z(SH3`3QhZa;&Ash&D$ipoK%S)cp%^4iGW^9zKKzZ`|CyP1q%6Ut|D`idySp)eDpP4 zM%Yoy+w<2feZGIpgC1^cD4Wk5g_x`xBF3B$lB&>S@cl5ugQ$2n?q%je1{1e_dAHO} zL15#H=eh6T>rd_X%d56N{^?F4EchmcDi)kfFVWssT^}S7wCJnOCXKbbnAW=ou1zlI ziTcez-)y?lkebM#ISn6T9#fR9Yk2}u^J#^h0Te!XJzA~Q*p)!YrHZ2`nRbd~nm0Ux z!})AGDZNEb8im_wwZ!X{t_OxzT(RMQkx2UgL!!&V5e!@vKs zj1X*eLT+LH#7-?e8(T3$d5TJ|xuBd^{%4oqv_O}zmnPy{T>E@M?xeI<=JX?dt8eM# zAl5ajf^;|#(;peNQ8ammY&X#M5VYLDa(f%?vb>xI8yh?9@{%huI+})!Em0KHZcav8 zdP@bz!#F?zEF>64Nt)7j&7$`iLQXV1bSW4u&&aJPTk`J~v!d5EH_$mVs7Uag< zBc{Z;2z=P??E<~h9u$eZbxVoS_Hk3)&aMg->hyI9dYlouMFcjLRgr564}jbBVT|ec zsmspMbZ;bK6)lmHV3qD~z0I@P$>5zq^}MNGeo#0fEeknXu{|Wn4M9Lu z;1>QEhpXLge@YhFsmpbbb)oUMuki>ejy8LKo_w3V1Tu60n7z0859%O~}l8iSMh;lEpX%G;4<3}2iJq}Q8{Oc67| znQ>B<0dUhf=WTMd#J$6l)G#gLWX5V)&MG>2PptB0=9IVT0+e~o^LJe!F=ziC;@D7* zT4hqQJapXwCoWiwaCQZh(UkvZ*;^F}MkGuA-m1XOc}fg!vBpgG)m{?tVZTfOqB2HcRLd z_|32!|HB8NNF`Y*9dU<*gq6U#DGP2boFJFilJ|#y#(QLe(mJ1R>%Ugp0UN4YVL&Oi zPJr)}W*Z5HU2PWEXJOZs0GFcFq4_|z=Ygv~NP+wVlmd)+EA-&jc&KLACWWdAh(VcF zSs1pV3k#wXsxDU&uHV{Fn{{0;MSjlWj}wX#UzOMYYcOYYrB-4nO_dV50mJL`r#kYR z4O*+rCK<}At=qnEOVF8_@KOBeM4J5t`RH+#4H~w_$4+QJ^Kz+B(-cW6&E zNR|N~*ZyYMGe$q%69oh~p~4f3-ZDjv*JoxZeysGWHt<&k`0|h!?($ zlsO+_{@j9HIPFoEU<0e-wvhJ2fjswC125l19*&@MuH%ID7UO1aa)dzDUwi9ml5!(j zv@FCxF2%Q|yeHjcw;d8K4Xnq4Y_3)-52>N9VHQ|c_c0fmu)D8M;1=L-#D%O6{b;Zh zKwxU%|LVn++gsuTEowL|;qZY2o(Op#pe@VT8RU- z*ENiQ!N2?_C1yl|7DZ(2FUp|iFk(s6D@^D9N@IGBsZ&ytQEXZcz9b>VIh@kAy5bFs zDbCuMY%$7arpKLsF@cxHupp>Fu0{ExIX0v?!(&X!*(X;pYIK5^0Atgg%M+R>$KT(I zgP~@@_em&h{yH`Sk|r@RQec1L|1u$OmnSAO!%^@GolpDU z;FVW{g~U!{1|F?v)B|>w!Qv)GqMa_oY6+p_jXhi~jwSR;j{%?XAWa)C^|%ZN9k8BJ zbX$OP7ZluEbr*fpCbQRAy4gc69^ph2DRH=dUg(lMWqP8-rl{qvl)6*V-l8Cq1oXpR z+9V-?{X3nIzFT;_ACi(iqG48)p7xh7c81qMpvRXYx{%pA<`2cfLD_*IY~Zdo z)NHDNbITohd9!_#GgC+T(|M#@X@3cwsj)^*(~uuUDV?3EPSQ+@dCiuZn;r4ZGLXs9 zSVmbMKUzN9*Rz_&dXG;j{ai6wmvnR3y}_S{CwK=`wGViTVa$Dso}Ahh4Eye$9-2~N zXQ7kZl@ZyWL$|NfAx+EVJjA=!0ZoiezRZ=~v!xHq$%LyO`ULrnF;TA^?5jB&^Yz!l z-iv;^i?7iPwurSV_NY2ku*e=ZYy=y%ub7-7tL%idQvr*$a{;4^j#~o`9=&6K*dA7+ zRY3wVgli&z@J_^gt|_j@5PkYbq2ZQjCir@%xjyGmsz-Fc=m#NG-02_C875l+kRiM8}G39I;29iO55icN6&)Uffj8gylEkv=5WIKSY)c^6&d6DizJcRROJ z3w;X|+Aay7l)R+N8~Bfvx0L5OQ<(oL1+=iz1Bpw<%W(uObVcVKzJ_zy*t^sx?M@5!p zKK_W4Pzfq}L~(FgNR3Yq(gQQ8uVqO&eIascRV5u2zZhH?w~$8rwq=45)m!COf&^K) z0D`C?=|S_DkgwDX*%#2jo&l=sb2W+3pU=>0Dg_))q&_!u10%5H_z?1B`9Lt`81&Tg zeC;|)3Yq!wpCMyRQkVy#(;%F*TDHUWc(9fxj%atMk+^xDhq|2Z#Df{w`&Wc%W{%&X zoiQhzvR-tD%Y(0LgN+Rs%G>AUG&!mx$_En$cK>w}f{`*M`iuxrxZ8#%ll+_m`2&tf z;pmI1FZV0i#0AS@1`{dAa%Ax?mv?uVp)PNy=b56h!fEd9^nvj?i!UA(vCKUYi^CtdTXQ^Wslt`8(-9u2I*U;q`SjLz{jK`r7$TKl zOUu<+N%4`$vMjCjti#v^!2M{6QINPSD8q=f)d z4$7ZrVl}P|lfTyERwqC@0v+XbP4K#3f3g`i*L=3Fqn!>gb4SxhDsf`l%`#mJ=m@X+VuYc_zRTd+^ 
zN-lMM&8!vB$i=rhsK$gA{D*jW@fX0#VHWJZi0WGF3#^Z4+|=0FF*M|W>otdbo%HNQ zw{$p-2$bT!QByT#zxL!!%O;ZB;y`oP3~Om`9n}O9+RLuQo2=kFhi+ul7Du6(!b&z-y_QTTu(sB8xi!4F+hjynE{7gW|!N0PYINj5rm+Iu?kygQhL-UKM zk`Np8dweRk#k7(2lb@@{^84+wF?uDXNv+}&|GPn^Ok&DS;uza{71CZW)EqVcf`9ps z*f7vr3+0=dVjICsGsE~*iaYX)w=_vfhNNGR7J34s=|a!L+{NpS^MAA*R2JWQ@Oe9( z5hw$F^NvUHhkCDJlzv5pDy*G~Gr0LgFdTgvxffn7S7PO1{KAk{P%^h!;R$*1_48F- zhS$2z7p!c6krQJ~&kc~2U%~CQNdZ2)xHFOSevsfM9 zL?z^G{n&Fu$g#8ji2A96rrqm+O@5Lw)mS64tAj092!5mxcle`tb(T=iu$LusU04D~0)>pa( zR~b(QOe3WoRb!50peLfz$XY-i$(nyIiBS&{8Xqk?Ep8gOC_rHXctZi-_wa@8txE+i zpz_T;;1I@a-1|$&AGNi;LH6E1@3yZHm%h_6kjaM#XMb$za-=XoECaHVA`%Qss7A9s zITjUG?51q%@z4I8F?gp>lgR-erzhxJ+v9e>0yo9}&WQb+Eee3gA#9>Me-0@AmnY&` zU6}V+dDH)|PzX0F08xRMQyw#+IBbok&0YWD(~F}oOfZpNXb1@(WyoJv_zmfdd8&!? zOH1HTYX9>;%b#^s;Q2lS`a}MjW)3MYCl;%=lm3f(IsGN@*REbGwZBh1Y6Evotig$T ze^bV`lpui`T3ZEwDv4Pk^Lg=(R z9$ccW0Jw7A`FdBPb8r6W3IjlHhUy$*1a+f>hAXS!&WSmgub;9nWS-VhtZny!KeC;3%{3xAZsAV`T&4C$ zbcx7C;5-iqtu_t_AY!*v$)Znm6q?cIle4ty{*!ZQ3MC&u?RS!6S{Uwnu$e9-y1O5D z?VoGB9Gx?HQ4%F>M$o#e8;C|HD#?IFU=#S*>WiC(XNYvW!$PA@!M z%L3j!te7~+1WxNvwe-nJ6FGHF?;cyXA8x@0eRuu{#X{`_Zv)}g{EyoD-rfUwY1je| zuOyRppc_y2(FvA1Evg--hZdZc35kFsmKG?#ppt zUx))iuqO{P1PO??-lJJYK0mV50R)iv{DL@oh}Cz+%w~FEw^;8^PdTM=m~}RBb-Q$LM(U zdB2Yuv;{NT;&VtjDZ{X^oN5?p0&8Zr3t6)C7waxgG!5I4?h;5_1DA~03|ldP40aaZ zE3NkqvnrlLYhz>cDsVuEPmzq4S`OkL&;77ejX}WQ$^GEgLnII+NaL6qqVxOAp~p06 zFz6pOuw$+$Acd$z!1-^Ijp|SP%^uVr%NLS~%CoLLK_6UBxXHe4;D(*|kIwV5VZYKc z>@d=zx;a{7yO3x`l8#;~0#`APEd+?=$yjWh5ytj9)w@rLTW+R^Y=y=~#bk*ejNPU$TmhwEWN?Oqn{h6WyQMo(>G_YKU6pws5BziD zf*7;13@!^D*yRq0^p610A4|5x63cs126ef!u^L;>jpL_~zpn*sug^xk_B zBE3mT=%5cEEfDF=fYN&mEp${Ml+bJFN(&(&Nbm5*d!M`B_gm}!2VYiJ);ecp&z?PJ zX7+E->@(Z%Xncv+T`LR8DV!gFwjS-R;j3;Z$U3{HW)OgZC+8Lj9j_60I2DXttKvbNVZV%RfBUQ4tokoa0?Kx;uBBU` z^t()0?(e*&AHSEf+dXkIU)E$DFCTxuwZ{^&i~lx86q_vSieR=|M}}yp z-{vms>G879>Au3>rp7I7Fl&G|TAF+H=UDW^=b@z(FpEz12a4xohL&)T0^%IZMx zGtv1*4?^m$fMc0c{9RtWL|$n)D(LRvH{um3Y&hb2@}yKpHBg3L(CJ5RL7A|-+V$1` zj9A53qt6T7ru7Zhw@8ePFt#tzj5!wXB3wwyA)2%!tEn73)igb(E9t4a(PPA0eZcY; z_U@824g+1#a`~l*i4vx3Dk@npJdSybYH?x4$cRF^Lj$xRWgXs^U zTBR@Aun(ac740p;5%oM8mStZk4iEkk-I|=E*+Qwwt))P-e~k`#7(8#h8`o9QJ|=Fy zd+#ytD?6y8U-TrKMKI_f#)kOV&5znYF#ciaiML9{v`t5>GVzI~z9((YD0Fwx9z4jL zpgSEt?m`dhN_7Uu89xYdzfy>+cmbIb(fDK-$Cl8rCM?A;_t5?eu2^F?%OvJpzjR2Z zQ9k%7yRUNomp8-AV|?Z%wnY>Lyb!hZOdX51-Y9tP@p!R$D>^D?yK#j#gjpxomb=Z1 zVl_oTetet%EuSSUoM@(?E%@r?k?oj(MPZ)>Ke>iE@&TfQPVV-0OiGCWxweXUDIxW> z-OhC|LOtObKj02HmO95QeC9^cue!gYl};(4nbFOe?gpICH;9rVp3v0CbwSN;tg3Y*iBKNd=065MLuxJ&b^3v)0&kvDcWk7 z97zJUf9&aGOf`w%(8qWZx`Sq^CE6Jy_HP`JYl#^UPc31OPW@g#41lS8YoYe6chjLc z9<+*1PZ@l@T?k6A+5_i}&xib(?TYuia~ey_ykFTJeHsGr$A++Iph;6jbr0m8){29C z_3&N)XS#Ze$?2`CKx^GKm2=-16n4fZt%;(p%_U|H%lOj;6ql5YuStL@+*u6YS%J2N zM1izNM?Ad3{R&zGVs`F{acLj#W~alf5}4{m#VF=(D6j}#ZY$G%W~%pHrrc$hymmmTEkU`|@~byWu}^7h5#YC*!QgqAP=s{tWHuO=_-ujtnPLna zXXFp4y`5&;*$Xy7I&&5l(B*ngIlIo*-9?sAK|du&d@2{?OP9V;Es9&3FJ|2ZBt@S* z)Ekm{Zkf@BUt`IL|GkAT=2w03QtNU6Z@c*GaJA3vw8I|QtGmo1S}FW=VQ4!W3q2~< zJ1j#(&37nz>&s`GZ`{t>?E3xahW+Z8t%)mrB5tw_T)oBKetTHdS=OB>cOx|TfbukM zTOvc09SmGZ)$9P;Q6> zpr~W;rtiJ;}9c;c~Gd>*i>T?N7 zw0RJp_YJ}$BF0M?Y10Zx=DZ}Pl2&#mm@MaFij9G%#6u?3K1`WduIM;*@{BEXpUZVN zKyIBuFfuQIqnEXI09&GzArdsQIP2Jpk3wu|M zCP>S%a273phVi~yNN}o`4th*t$py@)3hC4=vxO zbh3k3+#x>WL$NBwlT4N{3eR=MJO|L)X>FA0+2311`IBcg(dc+**l9^wqtv8ex#^I{ z%*3(vbpK#&c~K5Fyup;B6*{M^mZG*$Lw@1;=nL{TCBRabF2q$5q3*smeg&A*f;*iGYagw%a^X_%Pc&b%{~7E4uj)^i@bFaqybP=6I}Tm(WWZAG}eKn0w1eV;Xfe ziR;y{i**N9Tuxcx478R~#T-^7Da47onh_zHo@^K-1- z7(IAc(;JS{1QP__W=KCp=xw!$t%*@81p5A`OEVBrkq#a>_3bmAm zOUU)-+ZeTRVZulAxv17r`nKQHEO*Vp^AQm(0u 
z44NjbORBRM-La1I&9}v4vviDV89UbXu4Is&8xFqtkC&`E2_sVCeWZEK%2Kyh8D0US8(}PPs)1sq^;3Q;a-DNs2V@| zTM5@1Ll(>()OU!dw)||{%ozNb{$Wcw1+b3veYP>~iO2N&y{5Y$rOp5aJHGF~nZV?s z=!OF4xD9`$dK%U{L9aZG=Y6_}Qz}R|z4nAJC)N{xZB4j@=2Z-v&Vd>UU$9hE#?-{= z00|gZuTWj5no?S4Kk;Q@vv}X4{4{VQ^hDI<>&lj)6{Z%_Qcvx2w^nvBe=MOYKDfb9 zPfmQ(Nl2Dm)^5T>Iy(io*gRRQUN>{PriNH~pC|`5G{zk>D?{4&vz%qWX63k;YRlG9 zI>V;nH0(RqjZZoKcCXa-%Qo)XsvOyb z^F~NgfzsP$0D6U~X#n|fklYWziBPdZJon7T_XoAG_lmi5jN}p*>AfM5DI4EA8fCg& z0xo%{xhhWN_N!%mY;ojV zm2W?|p2Dt|h<`~5F822B-hH4nRxDyf?9u@PuBWO^?HBLNRrD6^>ozLGM?(}0c&-?8 z;p)xdk;hgQ>}GWgug1m7)!+_$Hk1fS+mK4Ba1JW{FklT;@ljYQ>kXmi<4Yfm*wDpX z+RO*&lkB63adAlN8yK9n_Z09{!)kfkVNCp)YsdGs_28SpB!cfy6d*dwX1(d__7xJA zC*heV8b&RB@@b26$DVjccZ*|7w1!+PNu)w-Np;a7z41dkaD}bdjsJx8O~(Km18&!uXX1CU>aC%FkhjV#1<|Q`;Dhnsh^H*O@8#61?zOU0(f3GzT(|oWe?3xjrth)Z|7(LxHV<}P|~quK!>iqD#ylfU20MW>glo)vF{d}Q)y46 z%9;LUXk6e>zyuBtK6<~jN^n}GfAaw75#Qz8u(IL>@5^lsf3!__C|BX;@ z)*qP02#yPGFzWdcl>GF506$7?z3E5hH8k!F;U>P$&XX4`)70V9+ZZqt|M6eu(Rf_y z-}2P0IAqX0q|#32ygT3#cc;k>?S$0E>3gGNmN?G{m9%$g0U@6^ic^_YhKlvQM`gK! z(a4?gTPOmKB|$`#3FAA~lzQYsq_0GDoF?*-EkZB8{_t5|VzR$5MC6U1Ua!~^gYL<( zQ9)VwYcw(Wff!^Yv53}3Kk;){@6AjWm?U^CUF&WC4EX+%M3j|qVt~^Yts=1J)UVc} zZ!ugGaXfaYxf=2&aBguF((?0s_EZ#F8^oOxCsG`>lIe?p}!zqaSu4W2N8$S-;B zCEl^77gfmCc~r7~FG1=?ZTzRAnkZ-0+i8q&)!d4%{-gdvTe$w&R5~h1a0a_MT5#VF zgab%FlyqoVaEk?~pBB56n?vIc6I}Q>J=K6;VD($joV-fxPJF_I&lYyr?ju&XunsE> z`u1E7AZJhXLu*6%D=R!5lCA5~O!Nqiw#sK`J>an|P>?%_*zF)ZvmXc+28)Am>V`ly z!!zCGS_^fJB*ZapDB4YF&Ok2*XSRw!CQa0I;76gUiYLCgH^#S@#}xADbXN-oVzi^4 zu9pG^dtO-1&p1XqB{#;PD?#;@rM$18b2I;nJ|QC~IYQRgZgUitoG3UYj2x^4_*no( z>^_sB3cMoszGWqnwqmxHu^Knv(|_VZ+;484dN$t*GH-6B9j~YqGP}Fi%qT5s??Ts@ z>NBU2mE>$d6ob%*6(zn(codxk_4n|9v^y{%E6OC1$S=z=n{=j} zQ-uU3f>HUDiDF$+Ne-%?YCVV08ohitVcqymgK{bBrpLu^Ou2EO$?gV#;Y_wrubWEc z`KD~)Jy6i9m|(>n2Bk=oh!bD2mz_yzc+YSCJR3)tJz`L&Z=QzQekMRM^mOT)mow+|j74 zpg?K$%_TlE3PtpPbg|+_nZsCAzdq`)&D?F>=-Qi5<%~{Swnz+g8ve?Sm-M>F`woa#wfYhEgQr*yP^{NMy*HR^4K1aX7)99bmbFDSyQkkjfuO~=q)^|U450_y zSD*K1A`X$|1R!KtN?Oy8NI!a5!w>1Jiu#LI)E*5x@gIhC#m|Cj??&6Tzg1;Isko|P zDyA+2fc%cKq!t8L7&;eFgcM!WJC_LR%I=6i>3aW-2O9@}5@Oa|R@-L{;!b2|zFq3o zoEjVBEW_?}5A>9L?D(^{X z%nq_gt`-nv1aIS=|Bp}!R8N`h)d#I9K-?AYT6qophb$YBCQYqTNK7BRB6Z{2Zp|am zYq}H2$gL-IC8spJDQ-ec(t&#JjCAlQ^~Dk#yr)77TQPGe5+ERn8ttlLZ>z0hEPP!? zxw~Cj7HDstndF|z_jQ%w$fN*WvtRjTMc@X?OR1kV zmbKK=)Fsc8dGw)!BZN0S^qcI01#~*=qnfbFjw*WQCh`Gqjiu%1-z9ig!y=7r{5swt zDw{}+4&&m7@6U2SR^Ir7YFz0G*1Y2223EGE&PFK>sg`6Z-mO;${yY=(i?`yrTRMZH z0~G-OYX|C*Pa*m#C;-oR{fj!m)-pv9FUzwuu5Lz^`}_DTlj$gXw|eWkvEn83k2MtzM_6Bj$Q=6=m@A;R%4V- zc>Y{~4w0Q^B;~IYJTvO@e4^SYd|SP|rVsz#5W1nY`NGRNvFs!+`KbvMu53$61|oeP zFR}F;896*8xyB9mG&E%M_+H-B8ne`QcsM(sKdzFwG9|+lWqSrxVvu|_GCokq=>J&l zI$WrisFk>J;vj>joQ(kDMe8u^u1e-@h& zXUy>F$wv|x-HR&+SQ~NgxnA!5y)AQ>uX97L<>(6doMQjBeB)zsUgJx(VR9};7y^U{ z>`GQmriehcv!653X$)e#!PQlx%@edB2TK7<{tCqFkJefhZ6j(7>92($4+HfCdI~}+ zdo;#Rsn;8)!L+kDF z%DDi~=b{Sig^C!eNd-;UnqW>G=e?=oK1sS=QFS;LGpFVh>BlXMGLy426qu9tNE zQD2XVi2*>!;fqmj1Yj;q-36tYjj;+|S;u2B4?U6~k^!$B+9TM{)~vP4b~f$X9^|{o zgBTd4%oTZ@W63~_ZZz%9Pg>#N{S`C1QayTa2W$5yo~`DdXooaBln@Hpml=3BYPUR*`SJ|Bwf&aZAC-2VJnS~OIl+r*Vnew-P>cb zN0r8ys3%+Jqe>##RVUEEiy3X(`Cm_e|7?K$#^<#CAfDkqo69~Dw$6`(%lhIRruLa) zk=5LDZV6@JQLoDVordZ4n;@Jg%-p9u9o;Xp8Gk1f)D>X4cUGOIId(XE!*8iw&XiOu z9r7k%nW*e{c-mEY7F)O}6?R;-(CKmhyA2z9{E8FO428Rdt~yMmMIv{oS<(^_Sf`EV z%GAF*v6jS0#j-a)lSULPSfop6XKaZI5#zTl_Vn9Vo%34B@Zl*$B;|mBH&SpiEw>=C zxkjKXmkuZRQM32Qi+oXXlz{1zA78(=-n>hQ{B2YC@W&i8+px9I_o3hC3`&&j4}Nr0 zn_amRM-zAdj(omjvZIiaGL@3S@;!w(3Gv6g_d)1zTVnbbhTJY@c@3X*;o_e9bQ?x? 
z+W4)dAHp^kS06)RFNop4K0HRpD7_x(piY18`RK7Wb2LwGPo97Voo00Ny{B9>p8Cox zF|b&9`DRB-8d~)006pk*C^l{3WZsuuA9vyV|5zQ+iMY%p;Xl#eYN5{C92@MBa+=F) zDQh=G%hcaD3A{%{YduOq@};FU7v^&Xp6anc(F1mCk1||6Q8jCu7ISORWqoaa)e=|W z$q>OjQC|kJ5};I~D0iiRW@&@8Y#})5#Uzc9dZ=1YyD7J1M=X!t8tpBXDaem|M|K$L zs{!<&gE}W-IeI(|I!x(?Lm=LdM_rW?W&npHz2OL952upbFNo=u{BtYF)ImF`Q*WqbKICZe^nZ7@7N`(~B3 zCQP+|hSxm(dZ5{7X~$JO<+!UgehwSG967L_QlNC8;Z%`iBz*>mv+UY}v$)sW{CL{B z6(`uXSK+a!wl-t7aBkGJ2dLOI+1rr7BNRecCVUka#*gFG7fwjlJ-#_sp=W!V9ISVC zGT)O|Ir)&P7(E|y;&mH-J0Uy8{`rjkg%&h}QOwD(IuxhwrNMN~d#X-qIr+Hqbug23 z33i3)hQF4_FDY$WCJppJd$T01y{{KZ@foa144))495__wnd5gwU{$UM zXQnLDfm4b-TA;K>0u*zTu}g$zDXG|h5pcR*EeIW~YL$Ql%$v#b_?nv$%k!8r2S4Cl zmYgVwIEE+Jdav_fezEEpno5Xo^>DI7JL1(j9G3xvgFVHo)-1Dn!`G+v$6VGZq?p68Or8gVNYG z+9MH&NP4V6>UN@WAkbQfS=wXD%<@xoWp-vB+j{=2448>cJvT^r3bUl{CJYgUx3FI!SXP z-^?1lv^7$6m`TunwmuFX&lE_!+E}0!iykl^IrFqfQ@RG#_ta{vOO9~YiTN@#72 zckFraxXH@Lm%1T2UnCjFplAP>h2O_>OaG2Ox1dEWB9KRwi$5>YRC&IJyk|0LLh{55 zPvV(-wt*oZHn(qZYHW6Hkb2H2J$mZLJ=EHI;Odr8(qxFvGt}lT4UqHFz|~Dh`6qWN zGtD>?A00^~y5Y_hwNDpm+spC(s(lQlS1kt#+xRMH5O`Htncuy&%1vbC<{f?ZZA%LqQ|d^8xY{#5PgFZ~e8WYURv zS+7M{dX1rUjn5?$7SSlKjJmAn);pEqV^y=DV24DKS>YtO z*D8Kuj2>kC`==nZ@~!E~f>5FV-BzdP(tA@3)(te?!!r_~B1|Q&0zJU_5*1V@;G$V& z9cX%`;02R(!d^DPd-7oV94NbQ5Aae*$!@|(PI4=iaNNh#an>WVdBaDz zXPn^{qGZr|-6;7xe<-hdD0$t%fnt8V12wRS$avb5 zfQ>kK9dk|{W7A;U=`<%KWG7PaEio;}W@=}ZKvm-t3$JaS!1mNINefyk5o`9IBa;6& zqi7j5{she40@Y$fO#=@yWFLuPe;FTJTS0-=QD{87=*NW@+ zhmjeM)g9M1qhlRBZGm>mxia3nKn{3@7%^50^;FAJb^AelyzRAHFtV94uNNR+qJm$L z|M#PPt;jUHT~yUfgVy36Ij)wWMtL9(Y|1xh(b@IpuwQCs=2K~wWM?r#H~OSY!n@{` zC4FF(YcP6%FA{g2zh{3Cy%3OW%Dpgw`%XOBQ$YuRLbc2q>jTJh+2g+Yg*t-Gg4Jfb zU2>3xbOk0FUdJeX28FNLBFPY9Umz*t#=H<2`Deg!`QuZd$bWK1mDf6Ed)7WaJSxYv zQQOw71uNqN)h@L09x3Nm_!G5@&$}=+xwq>Vz%w1ZQbDNUv2Q_E@;PURSgf4Uar6kk z8jAPK0Kb24Q~_e#XkJpJhqzar^+vAXTrDi+NAib=bR|(ep59zxwP7_7KSl`{0gT|? zqYz#&=I%<_R*|$DVvLXSzw?spPE2&KqfSfsW?YCw)#^b7Z}XWXb?dxqW(T;0LzHLR zhB#z@+q}Qz(UGQL&`&ABz}Rjed(ejs#2Gq82QwC8*<3j+fnpNE&CQqKi~Xx*>hobdk5HOl&DTGChU zq>2NKvyt_jdciSlO?_Hn8XTr*%Uj#(l)$JtjGB31s$G{5Y@!Klxc(=ZR{kW+ob0~H z{5@F4MvmxCHW*d=rfOedsx( zz2qBuOtYu$&xX+$nu~7ik~Epz(QS?Yvy>tSndZuC4$kWy%XG)fA-#D(MGK(GRIKVk zAXVtm@nw6G41rquxYfh&$oQ~RIlF7!`HzX#Dw8*tgn!h!M0O_sXxwfWc8b=4AHF+{ zaw#k4Huz2^ks#fR$QI5&3@f}I!`Hq~{)7a?+>B>z7eZ#-chZuR z>w{L;;kzsqJh>97Eu8lqK}86JjJwn;;-LuSn61}=aJzQ{X`90!b3xkNnhrhtYm_9n zA&@pb_;ev|XG6KOOg9 z{QQZ(yY;&P22D%Px!7PA4&0PPfPBQSO=H1m>@X%x=HjmbAPDCI;7KD~TOq?VVIR5B zP5o2Sj^NfZd3BSW)YdtCSGN5kXp14Qb07p-`S`zbSRjWUP;qwCw~qGtoFDgctKEac zF71dC-gO`{7s8wZYm2LQ8%PCU8cfX1$}5DngJd=zpQ;hHSaGN zj#Hzxtc~ZqCQWU*-3}y|3>o(~rC`V4GP%+*K%(wGM2nl|!QYu&7x?`So_a5Np*&t5p$%?EYK|oeKFc=y(YT6kf^m zUkA3bLiX+nvr7ob5 z*T)x|ThYx2jR3AdRunyhTQp$wQo#ApTmGnjywgn%_$4wb`wZ-rSFYTGCEXI}32vM1 z9bqSrsjiuERg0S-JF(%YnqfgvUr=sEamdw5fZ|qRr{WNvEUVws4(4xQfVANqq`R_u zmIyk&h#cxHNBph*zZRmRu8v6xO*MF{zW9xeB~89ABb)GI_BaP+u7BsOuZT#1TyMa* zXcd`=O@XYKQovdzgGulaT;XK(W?LC=N&T-Xkn*Cym#vSJvVsM#V*y74jc{TYS&51d z$F=LEA%C~DW=eAQyO@40X%T*P6aVN>BnLlQ6e=e!Ui;##oiFcieXk2HDk9R8hf(@I zD6-MQQj2)id$#YBa9*?1>1o;qCQLPrR+Mbfm*W5&;rHZKRJsR*uN{;Gtr1d-7tM(a z)bc7tWPg57{w9ZlSMF(NBJBp0&&2M6d`C(~3_@;llLB&&#tYsXF`x9w@nr39Ige<8 z!y6}3;rB*Lz92dO5dFGD{J?Yck;9I3YlA?(F8c6qir?d#SnSWQIauwMSoA{a`EH# - -## Introduction - -This guide explains how to use JAX in environments such as -GPU clusters and [Cloud TPU](https://cloud.google.com/tpu) pods where -accelerators are spread across multiple CPU hosts or JAX processes. We’ll refer -to these as “multi-process” environments. - -This guide specifically focuses on how to use collective communication -operations (e.g. 
{func}`jax.lax.psum` ) in multi-process settings, although -other communication methods may be useful too depending on your use case (e.g. -RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If you’re not already -familiar with JAX’s collective operations, we recommend starting with the -{doc}`/sharded-computation` section. An important requirement of -multi-process environments in JAX is direct communication links between -accelerators, e.g. the high-speed interconnects for Cloud TPUs or -[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow -collective operations to run across multiple processes’ worth of accelerators -with high performance. - -## Multi-process programming model - -Key concepts: - - * You must run at least one JAX process per host. - * You should initialize the cluster with {func}`jax.distributed.initialize`. - * Each process has a - distinct set of *local* devices it can address. The *global* devices are the set - of all devices across all processes. - * Use standard JAX parallelism APIs like {func}`~jax.jit` (see - {doc}`/sharded-computation` tutorial) and - {func}`~jax.shard_map`. jax.jit only accepts - globally shaped arrays. shard_map allows you to drop to per-device - shape. - * Make sure all processes run the same parallel computations in the same - order. - * Make sure all processes has the same number of local devices. - * Make sure all devices are the same (e.g., all V100, or all H100). - -### Launching JAX processes - -Unlike other distributed systems where a single controller node manages many -worker nodes, JAX uses a “multi-controller” programming model where each JAX -Python process runs independently, sometimes referred to as a {term}`Single -Program, Multiple Data (SPMD)` model. Generally, the same JAX Python -program is run in each process, with only slight differences between each -process’s execution (e.g. different processes will load different input data). -Furthermore, **you must manually run your JAX program on each host!** JAX -doesn’t automatically start multiple processes from a single program invocation. - -(The requirement for multiple processes is why this guide isn’t offered as a -notebook -- we don’t currently have a good way to manage multiple Python -processes from a single notebook.) - -### Initializing the cluster - -To initialize the cluster, you should call {func}`jax.distributed.initialize` at -the start of each process. {func}`jax.distributed.initialize` must be called -early in the program, before any JAX computations are executed. - -The API {func}`jax.distributed.initialize` takes several arguments, namely: - - * `coordinator_address`: the IP address of process 0 in your cluster, together - with a port available on that process. Process 0 will start a JAX service - exposed via that IP address and port, to which the other processes in the - cluster will connect. - * `coordinator_bind_address`: the IP address and port to which the JAX service - on process 0 in your cluster will bind. By default, it will bind to all - available interfaces using the same port as `coordinator_address`. - * `num_processes`: the number of processes in the cluster - * `process_id`: the ID number of this process, in the range `[0 .. - num_processes)`. - * `local_device_ids`: Restricts the visible devices of the current process to - ``local_device_ids``. 
-
-For example on GPU, a typical usage is:
+# Introduction to multi-controller JAX (aka multi-process/multi-host JAX)
+
+
+
+By reading this tutorial, you'll learn how to scale JAX computations to more
+devices than can fit in a single host machine, e.g. when running on a GPU
+cluster, Cloud TPU pod, or multiple CPU-only machines.
+
+The main idea:
+
+- **Run multiple Python processes**, which we sometimes call "controllers." We
+  can run one (or more) process per host machine.
+- **Initialize the cluster with {func}`jax.distributed.initialize`**.
+- **A {class}`jax.Array` can span all processes**, and if each process applies
+  the same JAX function to it, it's like programming against one big device.
+- **Use the same [unified sharding mechanism][unified_sharding]** as in
+  single-controller JAX to control how data is distributed and computation is
+  parallelized. XLA automatically exploits high-speed networking links like TPU
+  ICI or NVLink between hosts when available, and otherwise uses available host
+  networking (e.g. Ethernet, InfiniBand).
+- **All processes (usually) run the same Python script**. You write this Python
+  code almost exactly the same as you would for a single process — just run
+  multiple instances of it and JAX takes care of the rest. In other words,
+  except for array creation, you can write your JAX code as if there were one
+  giant machine with all devices attached to it.
+
+This tutorial assumes you've read [Distributed arrays and automatic
+parallelization][distributed_arrays], which is about single-controller JAX.
+
+```{figure} _static/multi_process/mcjax_overview.png
+:alt: Illustration of a multi-host TPU pod. Each host in the pod is attached via PCI to a board of four TPU chips. The TPU chips themselves are connected via high-speed inter-chip interconnects.
+
+Illustration of a multi-host TPU pod. Each host in the pod (green) is attached
+via PCI to a board of four TPU chips (blue). The TPU chips themselves are
+connected via high-speed inter-chip interconnects (ICI). JAX Python code runs on
+each host, e.g. via ssh. The JAX processes on each host are aware of each other,
+allowing you to orchestrate computation across the entire pod's worth of chips.
+The principle is the same for GPU, CPU, and other platforms with JAX support!
+```
+
+## Toy example
+
+Before we define terms and walk through the details, here's a toy example:
+making a process-spanning {class}`jax.Array` of values and applying
+{mod}`jax.numpy` functions to it. 
```python +# call this file toy.py, to be run in each process simultaneously + import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +import numpy as np + +# in this example, get multi-process parameters from sys.argv +import sys +proc_id = int(sys.argv[1]) +num_procs = int(sys.argv[2]) + +# initialize the distributed system +jax.distributed.initialize('localhost:10000', num_procs, proc_id) + +# this example assumes 8 devices total +assert jax.device_count() == 8 -jax.distributed.initialize(coordinator_address="192.168.0.1:1234", - num_processes=2, - process_id=0) +# make a 2D mesh that refers to devices from all processes +mesh = jax.make_mesh((4, 2), ('i', 'j')) + +# create some toy data +global_data = np.arange(32).reshape((4, 8)) + +# make a process- and device-spanning array from our toy data +sharding = NamedSharding(mesh, P('i', 'j')) +global_array = jax.device_put(global_data, sharding) +assert global_array.shape == global_data.shape + +# each process has different shards of the global array +for shard in global_array.addressable_shards: + print(f"device {shard.device} has local data {shard.data}") + +# apply a simple computation, automatically partitioned +global_result = jnp.sum(jnp.sin(global_array)) +print(f'process={proc_id} got result: {global_result}') ``` -On Cloud TPU, Slurm and Open MPI environments, you can simply call {func}`jax.distributed.initialize()` with no -arguments. Default values for the arguments will be chosen automatically. -When running on GPUs with Slurm and Open MPI, it is assumed that one process is started per GPU, i.e. each process will -be assigned only one visible local device. Otherwise it is assumed that one process is started per host, -i.e. each process will be assigned all local devices. -The Open MPI auto-initialization is only used when the JAX processes are launched via `mpirun`/`mpiexec`. +Here, `mesh` contains devices from all processes. We use it to create +`global_array`, logically a single shared array, stored distributed across +devices from all processes. + +Every process must apply the same operations, in the same order, to +`global_array`. XLA automatically partitions those computations, for example +inserting communication collectives to compute the `jnp.sum` over the full +array. We can print the final result because its value is replicated across +processes. + +We can run this code locally on CPU, e.g. 
using 4 processes and 2 CPU devices +per process: + +```bash +export JAX_NUM_CPU_DEVICES=2 +num_processes=4 + +range=$(seq 0 $(($num_processes - 1))) + +for i in $range; do + python toy.py $i $num_processes > /tmp/toy_$i.out & +done + +wait + +for i in $range; do + echo "=================== process $i output ===================" + cat /tmp/toy_$i.out + echo +done +``` + +Outputs: + +```text +=================== process 0 output =================== +device TFRT_CPU_0 has local data [[0 1 2 3]] +device TFRT_CPU_1 has local data [[4 5 6 7]] +process=0 got result: -0.12398731708526611 + +=================== process 1 output =================== +device TFRT_CPU_131072 has local data [[ 8 9 10 11]] +device TFRT_CPU_131073 has local data [[12 13 14 15]] +process=1 got result: -0.12398731708526611 + +=================== process 2 output =================== +device TFRT_CPU_262144 has local data [[16 17 18 19]] +device TFRT_CPU_262145 has local data [[20 21 22 23]] +process=2 got result: -0.12398731708526611 + +=================== process 3 output =================== +device TFRT_CPU_393216 has local data [[24 25 26 27]] +device TFRT_CPU_393217 has local data [[28 29 30 31]] +process=3 got result: -0.12398731708526611 +``` + +This might not look so different from single-controller JAX code, and in fact, +this is exactly how you'd write the single-controller version of the same +program! (We don't technically need to call {func}`jax.distributed.initialize` +for single-controller, but it doesn't hurt.) Let's run the same code from a +single process: + +```text +JAX_NUM_CPU_DEVICES=8 python toy.py 0 1 +``` + +Outputs: + +```text +device TFRT_CPU_0 has local data [[0 1 2 3]] +device TFRT_CPU_1 has local data [[4 5 6 7]] +device TFRT_CPU_2 has local data [[ 8 9 10 11]] +device TFRT_CPU_3 has local data [[12 13 14 15]] +device TFRT_CPU_4 has local data [[16 17 18 19]] +device TFRT_CPU_5 has local data [[20 21 22 23]] +device TFRT_CPU_6 has local data [[24 25 26 27]] +device TFRT_CPU_7 has local data [[28 29 30 31]] +process=0 got result: -0.12398731708526611 +``` + +The data is sharded across eight devices on one process rather than eight +devices across four processes, but otherwise we're running the same operations +over the same data. + +## Terminology + +It's worth pinning down some terminology. + +We sometimes call each Python process running JAX computations a **controller**, +but the two terms are essentially synonymous. + +Each process has a set of **local devices**, meaning it can transfer data to and +from those devices' memories and run computation on those devices without +involving any other processes. The local devices are usually physically attached +to the process's corresponding host, e.g. via PCI. A device can only be local to +one process; that is, the local device sets are disjoint. A process's local +devices can be queried by evaluating {func}`jax.local_devices()`. We sometimes +use the term **addressable** to mean the same thing as local. + +```{figure} _static/multi_process/controller_and_local_devices.png +:alt: Illustration of how a process/controller and local devices fit into a larger multi-host cluster. The "global devices" are all devices in the cluster. + +Illustration of how a process/controller and local devices fit into a larger +multi-host cluster. The "global devices" are all devices in the cluster. +``` + +The devices across all processes are called the **global devices**. The list of +global devices is queried by {func}`jax.devices()`. 
That list of all devices is
+populated by running {func}`jax.distributed.initialize` on all processes, which
+sets up a simple distributed system connecting the processes.
+
+We often use the terms **global** and **local** to describe process-spanning and
+process-local concepts in general. For example, a "local array" could be a numpy
+array that's only visible to a single process, while a JAX "global array" is
+conceptually visible to all processes.
+
+## Setting up multiple JAX processes
+
+In practice, setting up multiple JAX processes looks a bit different from the
+toy example, which is run from a single host machine. We usually launch each
+process on a separate host, or have multiple hosts with multiple processes each.
+We can do that directly using `ssh`, or with a cluster manager like Slurm or
+Kubernetes. In any case, **you must manually run your JAX program on each
+host!** JAX doesn’t automatically start multiple processes from a single program
+invocation.
+
+However they're launched, the Python processes need to run
+{func}`jax.distributed.initialize`. When using Slurm, Kubernetes, or any Cloud
+TPU deployment, we can run {func}`jax.distributed.initialize` with no arguments
+as they're automatically populated. Initializing the system means we can run
+{func}`jax.devices()` to report all devices across all processes.
+
+```{warning}
+{func}`jax.distributed.initialize` must be called before running
+{func}`jax.devices()`, {func}`jax.local_devices()`, or running any computations
+on devices (e.g. with {mod}`jax.numpy`). Otherwise the JAX process won't be
+aware of any non-local devices. (Using {func}`jax.config` or other
+non-device-accessing functionality is ok.) {func}`jax.distributed.initialize`
+will raise an error if you accidentally call it after accessing any devices.
+```
+
+### GPU Example
+
+We can run multi-controller JAX on a cluster of [GPU machines][gpu_machines].
+For example, after creating four VMs on Google Cloud with two GPUs per VM, we
+can run the following JAX program on every VM. In this example, we provide
+arguments to {func}`jax.distributed.initialize` explicitly. The coordinator
+address, process id, and number of processes are read from the command line.

```python
+# In file gpu_example.py...
+
import jax
+import sys
+
+# Get the coordinator_address, process_id, and num_processes from the command line.
+coord_addr = sys.argv[1]
+proc_id = int(sys.argv[2])
+num_procs = int(sys.argv[3])
+
+# Initialize the GPU machines.
+jax.distributed.initialize(coordinator_address=coord_addr,
+                           num_processes=num_procs,
+                           process_id=proc_id)
+print("process id =", jax.process_index())
+print("global devices =", jax.devices())
+print("local devices =", jax.local_devices())
+```
+
+For example, if the first VM has address `192.168.0.1`, then you would run
+`python3 gpu_example.py 192.168.0.1:8000 0 4` on the first VM, `python3
+gpu_example.py 192.168.0.1:8000 1 4` on the second VM, and so on. After running
+the JAX program on all four VMs, the first process prints the following.
+
+```text
+process id = 0
+global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
+local devices = [CudaDevice(id=0), CudaDevice(id=1)]
+```
+
+The process successfully sees all eight GPUs as global devices, as well as its
+two local devices. Similarly, the second process prints the following. 
+
+```text
+process id = 1
+global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
+local devices = [CudaDevice(id=2), CudaDevice(id=3)]
+```
+
+This VM sees the same global devices, but has a different set of local devices.
+
+### TPU Example
+
+As another example, we can run on [Cloud TPU][cloud_tpu]. After creating a
+`v5litepod-16` (which has 4 host machines), we might want to test that we can
+connect the processes and list all devices:
+```text
+$ TPU_NAME=jax-demo
+$ EXTERNAL_IPS=$(gcloud compute tpus tpu-vm describe $TPU_NAME --zone 'us-central1-a' \
+    | grep externalIp | cut -d: -f2)
+$ cat << EOF > demo.py
+import jax
+jax.distributed.initialize()
+if jax.process_index() == 0:
+  print(jax.devices())
+EOF
+$ echo $EXTERNAL_IPS | xargs -n 1 -P 0 bash -c '
+scp demo.py $0:
+ssh $0 "pip -q install -U jax[tpu]"
+ssh $0 "python demo.py" '
+```
+
+Here we're using `xargs` to run multiple `ssh` commands in parallel, each one
+running the same Python program on one of the TPU host machines. In the Python
+code, we use {func}`jax.process_index()` to print only on one process. Here's
+what it prints:
+
+```text
+[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0)]
+```
+
+Woohoo, look at all those TPU cores!
+
+Once the processes are set up, we can start building global {class}`jax.Array`s
+and running computations. The remaining Python code examples in this tutorial
+are meant to be run on all processes simultaneously, after running
+{func}`jax.distributed.initialize`.
+
+## Meshes, shardings, and computations can span processes and hosts
+
+Programming multiple processes from JAX usually looks just like programming a
+single process, just with more devices! The main exceptions to this are around
+data coming in or out of JAX, e.g. when loading from external data sources.
+We'll first go over the basics of multi-process computations here, which largely
+look the same as their single-process counterparts. The next section goes over
+some data loading fundamentals, i.e. how to create JAX Arrays from non-JAX
+sources.
+
+Recall that a {class}`jax.sharding.Mesh` pairs an array of {class}`jax.Device`s
+with a sequence of names, with one name per array axis. By creating a `Mesh`
+using devices from multiple processes, then using that mesh in a
+{class}`jax.sharding.Sharding`, we can construct {class}`jax.Array`s sharded
+over devices from multiple processes.
+
+Here's an example that directly constructs a `Mesh` using {func}`jax.devices()`
+to get devices from all processes:
+
+```python
+from jax.sharding import Mesh
+mesh = Mesh(jax.devices(), ('a',))
+
+# in this case, the same as
+mesh = jax.make_mesh((jax.device_count(),), ('a',))  # use this in practice
+```
+
+We spell out the direct `Mesh` construction here for illustration, but in
+practice you should probably use the {func}`jax.make_mesh` helper: not only is
+it simpler, it can also automatically choose more performant device orderings.
+By default it includes all devices across processes, just like
+{func}`jax.devices()`.
+
+Once we have a mesh, we can shard arrays over it. There are a few ways to
+efficiently build process-spanning arrays, detailed in the next section, but for
+now we'll stick to `jax.device_put` for simplicity:
+
+```python
+arr = jax.device_put(jnp.ones((32, 32)), NamedSharding(mesh, P('a')))
+if jax.process_index() == 0:
+  jax.debug.visualize_array_sharding(arr)
+```
+
+On process 0, this is printed:
+
+```
+┌───────────────────────┐
+│         TPU 0         │
+├───────────────────────┤
+│         TPU 1         │
+├───────────────────────┤
+│         TPU 4         │
+├───────────────────────┤
+│         TPU 5         │
+├───────────────────────┤
+│         TPU 2         │
+├───────────────────────┤
+│         TPU 3         │
+├───────────────────────┤
+│         TPU 6         │
+├───────────────────────┤
+│         TPU 7         │
+├───────────────────────┤
+│         TPU 8         │
+├───────────────────────┤
+│         TPU 9         │
+├───────────────────────┤
+│        TPU 12         │
+├───────────────────────┤
+│        TPU 13         │
+├───────────────────────┤
+│        TPU 10         │
+├───────────────────────┤
+│        TPU 11         │
+├───────────────────────┤
+│        TPU 14         │
+├───────────────────────┤
+│        TPU 15         │
+└───────────────────────┘
+```
+
+Let's try a slightly more interesting computation!
+
+```python
+mesh = jax.make_mesh((jax.device_count() // 2, 2), ('a', 'b'))
+
+def device_put(x, spec):
+  return jax.device_put(x, NamedSharding(mesh, spec))
+
+# construct global arrays by sharding over the global mesh
+x = device_put(jnp.ones((4096, 2048)), P('a', 'b'))
+y = device_put(jnp.ones((2048, 4096)), P('b', None))
+
+# run a distributed matmul
+z = jax.nn.relu(x @ y)
+
+# inspect the sharding of the result
+if jax.process_index() == 0:
+  jax.debug.visualize_array_sharding(z)
+  print()
+  print(z.sharding)
+```
+
+On process 0, this is printed:
+
+```
+┌───────────────────────┐
+│        TPU 0,1        │
+├───────────────────────┤
+│        TPU 4,5        │
+├───────────────────────┤
+│        TPU 8,9        │
+├───────────────────────┤
+│       TPU 12,13       │
+├───────────────────────┤
+│        TPU 2,3        │
+├───────────────────────┤
+│        TPU 6,7        │
+├───────────────────────┤
+│       TPU 10,11       │
+├───────────────────────┤
+│       TPU 14,15       │
+└───────────────────────┘
+
+NamedSharding(mesh=Mesh('a': 8, 'b': 2), spec=PartitionSpec('a',), memory_kind=device)
+```
+
+Here, just from evaluating `x @ y` on all processes, XLA is automatically
+generating and running a distributed matrix multiplication. The result is
+sharded over the mesh like `P('a', None)`, since in this case the matmul
+included a `psum` over the `'b'` axis.
+
+```{warning}
+When applying JAX computations to process-spanning arrays, to avoid deadlocks
+and hangs, **it's crucial that all processes with participating devices run the
+same computation in the same order**. That's because the computation may
+involve collective communication barriers. If a device over which an array is
+sharded does not join in the collective because its controller didn't issue the
+same computation, the other devices are left waiting.
+For example, if only the first three processes evaluated `x @ y`, while the
+last process evaluated `y @ x`, the computation would likely hang indefinitely.
+This assumption (that computations on process-spanning arrays are run on all
+participating processes in the same order) is mostly unchecked.
+
+So the easiest way to avoid deadlocks in multi-process JAX is to run the same
+Python code on every process, and beware of any control flow that depends on
+{func}`jax.process_index()` and includes communication.
+```
+
+If a process-spanning array is sharded over devices on different processes, it
+is an error to perform operations on the array that require the data to be
+available locally to a process, like printing. For example, if we run `print(z)`
+in the preceding example, we see
+
+```
+RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
```
-On TPU at present calling {func}`jax.distributed.initialize` is optional, but
-recommended since it enables additional checkpointing and health checking features.
+To print the full array value, we must first ensure it's replicated over
+processes (but not necessarily over each process's local devices), e.g. using
+`jax.device_put`. In the above example, we can write at the end:
+
+```
+w = device_put(z, P(None, None))
+if jax.process_index() == 0:
+  print(w)
+```
+
+Be careful not to write the {func}`jax.device_put` under the `if process_index()
+== 0`, because that would lead to a deadlock as only process 0 initiates the
+collective communication and waits indefinitely for the other processes.
+The {mod}`jax.experimental.multihost_utils` module has some functions that
+make it easier to process global {class}`jax.Array`s (e.g.,
+{func}`jax.experimental.multihost_utils.process_allgather`).
+
+Alternatively, to print or otherwise perform Python operations on only
+process-local data, we can access `z.addressable_shards`. Accessing that
+attribute does not require any communication, so any subset of processes can do
+it without needing the others. That attribute is not available under a
+{func}`jax.jit`.
-### Local vs. global devices
+## Making process-spanning arrays from external data
-Before we get to running multi-process computations from your program, it’s
-important to understand the distinction between *local* and *global* devices.
+There are three main ways to create process-spanning {class}`jax.Array`s from
+external data sources (e.g. numpy arrays from a data loader):
-**A process’s *local* devices are those that it can directly address and launch
-computations on.** For example, on a GPU cluster, each host can only launch
-computations on the directly attached GPUs. On a Cloud TPU pod, each host can
-only launch computations on the 8 TPU cores attached directly to that host (see
-the
-[Cloud TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture)
-documentation for more details). You can see a process’s local devices via
-{func}`jax.local_devices()`.
+1.
Create or load the full array on all processes, then shard onto devices using + {func}`jax.device_put`; -**The *global* devices are the devices across all processes.** A computation can -span devices across processes and perform collective operations via the direct -communication links between devices, as long as each process launches the -computation on its local devices. You can see all available global devices via -{func}`jax.devices()`. A process’s local devices are always a subset of the -global devices. +2. Create or load on each process an array representing just the data that will + be locally sharded and stored on that process's devices, then shard onto + devices using {func}`jax.make_array_from_process_local_data`; -### Running multi-process computations +3. Create or load on each process's devices separate arrays, each representing + the data to be stored on that device, then assemble them without any data + movement using {func}`jax.make_array_from_single_device_arrays`. -So how do you actually run a computation involving cross-process communication? -**Use the same parallel evaluation APIs that you would in a single process!** +The latter two are most often used in practice, since it's often too expensive +to materialize the full global data in every process. -For example, {func}`~jax.shard_map` can be used -to run a parallel computation across multiple processes. (If you’re -not already familiar with how to use `shard_map` to run across -multiple devices within a single process, check out the -{doc}`/sharded-computation` tutorial.) Conceptually, this can be -thought of as running a pmap over a single array sharded across hosts, -where each host “sees” only its local shard of the input and output. +The toy example above uses {func}`jax.device_put`. -Here’s an example of multi-process pmap in action: +{func}`jax.make_array_from_process_local_data` is often used for distributed data +loading. It's not as general as {func}`jax.make_array_from_single_device_arrays`, +because it doesn't directly specify which slice of the process-local data goes +on each local device. This is convenient when loading data-parallel batches, +because it doesn't matter exactly which microbatch goes on each device. For +example: ```python -# The following is run in parallel on each host on a GPU cluster or TPU pod slice. ->>> import jax ->>> jax.distributed.initialize() # On GPU, see above for the necessary arguments. ->>> jax.device_count() # total number of accelerator devices in the cluster -32 ->>> jax.local_device_count() # number of accelerator devices attached to this host -8 -# The psum is performed over all mapped devices across the pod slice ->>> xs = jax.numpy.ones(jax.local_device_count()) ->>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs) -ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) -``` - -**It’s very important that all processes run the same cross-process computations -in the same order.** Running the same JAX Python program in each process is -usually sufficient. Some common pitfalls to look out for that may cause -differently-ordered computations despite running the same program: - -* Processes passing differently-shaped inputs to the same parallel function - can cause hangs or incorrect return values. Differently-shaped inputs are - safe so long as they result in identically-shaped per-device data shards - across processes; e.g. 
passing in different leading batch sizes in order to
-  run on different numbers of local devices per process is ok, but having each
-  process pad its batch to a different max example length is not.
-
-* “Last batch” issues where a parallel function is called in a (training)
-  loop, and one or more processes exit the loop earlier than the rest. This
-  will cause the rest to hang waiting for the already-finished processes to
-  start the computation.
-
-* Conditions based on non-deterministic ordering of collections can cause code
-  processes to hang. For example, iterating over
-  `set` on current Python versions or `dict` [before Python 3.7](https://mail.python.org/pipermail/python-dev/2017-December/151283.html)
-  may result in a different ordering on different processes, even with the
-  same insertion order.
+# target (micro)batch size across the whole cluster
+batch_size = 1024
+# how many examples each process should load per batch
+per_process_batch_size = batch_size // jax.process_count()
+# how many examples each device will process per batch
+per_device_batch_size = batch_size // jax.device_count()
+
+# make a data-parallel mesh and sharding
+mesh = jax.make_mesh((jax.device_count(),), ('batch',))
+sharding = NamedSharding(mesh, P('batch'))
+
+# our "data loader". each process loads a different set of "examples".
+process_batch = np.random.rand(per_process_batch_size, 2048, 42)
+
+# assemble a global array containing the per-process batches from all processes
+global_batch = jax.make_array_from_process_local_data(sharding, process_batch)
+
+# sanity check that everything got sharded correctly
+assert global_batch.shape[0] == batch_size
+assert process_batch.shape[0] == per_process_batch_size
+assert global_batch.addressable_shards[0].data.shape[0] == per_device_batch_size
+```
+
+{func}`jax.make_array_from_single_device_arrays` is the most general way to
+build a process-spanning array. It's often used after performing
+{func}`jax.device_put`s to send each device its required data. This is the
+lowest-level option, since all data movement is performed manually (via e.g.
+{func}`jax.device_put`). Here's an example:
+
+```python
+shape = (jax.process_count(), jax.local_device_count())
+mesh = jax.make_mesh(shape, ('i', 'j'))
+sharding = NamedSharding(mesh, P('i', 'j'))
+
+# manually create per-device data equivalent to np.arange(jax.device_count())
+# i.e.
each device will get a single scalar value from 0..N +local_arrays = [ + jax.device_put( + jnp.array([[jax.process_index() * jax.local_device_count() + i]]), + device) + for i, device in enumerate(jax.local_devices()) +] + +# assemble a global array from the local_arrays across all processes +global_array = jax.make_array_from_single_device_arrays( + shape=shape, + sharding=sharding, + arrays=local_arrays) + +# sanity check +assert (np.all( + jax.experimental.multihost_utils.process_allgather(global_array) == + np.arange(jax.device_count()).reshape(global_array.shape))) +``` + +[cloud_tpu]: https://cloud.google.com/tpu?hl=en +[distributed_arrays]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html +[gpu_machines]: https://cloud.google.com/compute/docs/gpus +[unified_sharding]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html From 26daeaea68d8944be9af9c4cf0f7a4bc7f358d61 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 24 Apr 2025 09:12:23 -0700 Subject: [PATCH 0786/1769] #sdy avoid an extra call to `mlir.module_to_bytecode` and use `mlir::Sdy::getMeshAttr` PiperOrigin-RevId: 751017809 --- jax/_src/export/_export.py | 9 +++++---- jaxlib/sdy.cc | 9 +++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 0d7920e9b206..1f58c7e0def6 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1431,11 +1431,12 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) - shardy_enabled = _jax.sdy.lowered_with_shardy( - mlir.module_to_bytecode(submodule)) + submodule_bc = mlir.module_to_bytecode(submodule) + shardy_enabled = _jax.sdy.lowered_with_shardy(submodule_bc) if shardy_enabled: - submodule = ir.Module.parse(_jax.sdy.sdy_round_trip_import_shardings( - mlir.module_to_bytecode(submodule))) + submodule = ir.Module.parse( + _jax.sdy.sdy_round_trip_import_shardings(submodule_bc) + ) with submodule.context: pipeline = passmanager.PassManager.parse( diff --git a/jaxlib/sdy.cc b/jaxlib/sdy.cc index ed908c28acd8..c31d11bac0d0 100644 --- a/jaxlib/sdy.cc +++ b/jaxlib/sdy.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/SymbolTable.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "nanobind/nanobind.h" @@ -125,14 +124,12 @@ void BuildSdySubmodule(nb::module_& m) { mlir::OwningOpRef module = xla::ValueOrThrow(ParseMlirModuleString( absl::string_view(bytecode.c_str(), bytecode.size()), context)); - auto mesh_op = - mlir::SymbolTable::lookupNearestSymbolFrom( - module.get(), mlir::StringAttr::get(&context, "mesh")); - if (!mesh_op) { + auto mesh_attr = mlir::sdy::getMeshAttr(module.get(), "mesh"); + if (!mesh_attr) { return {}; } nb::list mesh_shape; - for (auto axis : mesh_op.getMeshAttr().getAxes()) { + for (auto axis : mesh_attr.getAxes()) { mesh_shape.append( nb::make_tuple(axis.getName().str(), axis.getSize())); } From 1e370050257c304b48b4a6a09b1ad67de654b192 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Apr 2025 09:25:42 -0700 Subject: [PATCH 0787/1769] Refactor LocalMask to inherit from _ComputableMask PiperOrigin-RevId: 751021842 --- .../splash_attention/splash_attention_mask.py | 84 +++++++------------ .../pallas/tpu_splash_attention_mask_test.py | 50 +++++++---- 2 files changed, 63 insertions(+), 71 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index e43f30e7791c..3f7a0d863188 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -408,26 +408,20 @@ def __hash__(self): )) -class LocalMask(Mask): +class LocalMask(_ComputableMask): """Lazy local mask, prevents model from attending to tokens outside window. Attributes: - _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). window_size: Size of the two sides of the local window (None identifes no limit for the given side). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax. - _q_sequence: Important for performance. """ - # TODO(amagni): Transform LocalMask into a _ComputableMask. - - _shape: tuple[int, int] window_size: tuple[int | None, int | None] offset: int - _q_sequence: np.ndarray | None = None def __init__( self, @@ -436,68 +430,50 @@ def __init__( offset: int, shard_count: int = 1, ): - self._shape = shape self.window_size = window_size self.offset = offset - if self.shape[0] % (shard_count * shard_count) != 0: - raise ValueError( - f'Shard count squared ({shard_count * shard_count}) must' - f' divide Q seq_len ({self.shape[0]}) evenly.' 
- ) + def local_mask_function(q_ids, kv_ids): + """Computes the local attention mask for the given slice indices.""" + left_size, right_size = self.window_size - @property - def shape(self) -> tuple[int, int]: - return self._shape + assert q_ids.ndim == 2 + assert kv_ids.ndim == 2 - def __getitem__(self, idx) -> np.ndarray: - if len(idx) != 2: - raise NotImplementedError(f'Unsupported slice: {idx}') - q_slice, kv_slice = idx - if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice): - raise NotImplementedError(f'Unsupported slice: {idx}') + if left_size is None and right_size is None: + return np.ones((q_ids.shape[0], kv_ids.shape[1]), dtype=np.bool_) - q_slice = _fill_slice(q_slice, self.shape[0]) - kv_slice = _fill_slice(kv_slice, self.shape[1]) - - if self._q_sequence is None: - rows = np.arange(q_slice.start, q_slice.stop) - else: - rows = self._q_sequence[q_slice] - - cols = np.arange(kv_slice.start, kv_slice.stop) - - left_size, right_size = self.window_size - - if left_size is None and right_size is None: - return np.ones((rows.shape[0], cols.shape[0]), dtype=np.bool_) - else: - expanded_cols = cols[None, :] - if self.offset != 0: - expanded_rows = rows[:, None] + self.offset + # Avoid the addition when possible to avoid instantiating an actual array. + if offset != 0: + shifted_q_ids = q_ids + self.offset else: - expanded_rows = rows[:, None] - if left_size is not None and right_size is not None: - return (expanded_rows <= expanded_cols + left_size) & ( - expanded_cols - right_size <= expanded_rows - ) + shifted_q_ids = q_ids + + mask = None + if left_size is not None: + mask = shifted_q_ids - left_size <= kv_ids + if right_size is not None: + if mask is None: + mask = shifted_q_ids + right_size >= kv_ids + else: + mask &= shifted_q_ids + right_size >= kv_ids + return mask - elif left_size is not None and right_size is None: - return expanded_rows <= expanded_cols + left_size - else: - assert left_size is None and right_size is not None - return expanded_cols - right_size <= expanded_rows + super().__init__( + shape=shape, + mask_function=local_mask_function, + shard_count=shard_count, + ) def __eq__(self, other: object): if not isinstance(other, type(self)): - return NotImplemented + return False return ( self.shape == other.shape and self.window_size == other.window_size and self.offset == other.offset - and (True if self._q_sequence is None else - np.array_equal(self._q_sequence, other._q_sequence)) + and np.array_equal(self.q_sequence, other.q_sequence) ) def __hash__(self): @@ -506,7 +482,7 @@ def __hash__(self): self.shape, self.window_size, self.offset, - self._q_sequence.tobytes() if self._q_sequence is not None else None, + self.q_sequence.tobytes() if self.q_sequence is not None else None, )) diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index 7c4b53529169..e2e420edee8c 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -1248,7 +1248,8 @@ def test_local_mask(self, is_lazy_mask: bool): mask_info, mask_info_dkv, mask_function = self._process_mask( multi_head, block_shape ) - self.assertIsNone(mask_function) + if is_lazy_mask: + self.assertIsNotNone(mask_function) expected_partial_mask_blocks = self._stack( [ @@ -1292,10 +1293,12 @@ def test_local_mask(self, is_lazy_mask: bool): expected_mask_info = mask_info_lib.MaskInfo( expected_local_data_next, - expected_local_mask_next, + expected_local_mask_next if not is_lazy_mask else None, 
expected_local_block_mask, - expected_partial_mask_blocks, - None, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) expected_local_data_next_dkv = np.array( @@ -1327,10 +1330,14 @@ def test_local_mask(self, is_lazy_mask: bool): expected_mask_info_dkv = mask_info_lib.MaskInfo( expected_local_data_next_dkv, - expected_local_mask_next_dkv, + expected_local_mask_next_dkv if not is_lazy_mask else None, expected_local_block_mask_dkv, - expected_partial_mask_blocks.swapaxes(-1, -2), - None, + expected_partial_mask_blocks.swapaxes(-1, -2) + if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) self._assert_mask_info_match(mask_info, expected_mask_info) @@ -1359,7 +1366,9 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): mask_info, mask_info_dkv, mask_function = self._process_mask( multi_head, block_shape ) - self.assertIsNone(mask_function) + + if is_lazy_mask: + self.assertIsNotNone(mask_function) expected_partial_mask_blocks = self._stack( [ @@ -1400,10 +1409,12 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): expected_mask_info = mask_info_lib.MaskInfo( expected_local_data_next, - expected_local_mask_next, + expected_local_mask_next if not is_lazy_mask else None, expected_local_block_mask, - expected_partial_mask_blocks, - None, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) expected_local_data_next_dkv = np.array( @@ -1432,10 +1443,14 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): expected_mask_info_dkv = mask_info_lib.MaskInfo( expected_local_data_next_dkv, - expected_local_mask_next_dkv, + expected_local_mask_next_dkv if not is_lazy_mask else None, expected_local_block_mask_dkv, - expected_partial_mask_blocks.swapaxes(-1, -2), - None, + expected_partial_mask_blocks.swapaxes(-1, -2) + if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) self._assert_mask_info_match(mask_info, expected_mask_info) @@ -2250,11 +2265,12 @@ def test_huge_mask2(self): multi_head, block_shape ) - self.assertIsNone(mask_function) + self.assertIsNotNone(mask_function) self.assertIsNotNone(mask_info.block_mask) self.assertIsNotNone(mask_info.data_next) - self.assertIsNotNone(mask_info.mask_next) - self.assertIsNotNone(mask_info.partial_mask_blocks) + self.assertIsNone(mask_info.mask_next) + self.assertIsNone(mask_info.partial_mask_blocks) + self.assertIsNotNone(mask_info.q_sequence) def test_process_invalid_mask(self): """Masks with of an all-0 row causes undefined softmax, reject them.""" From d13ac0ac7f1d48a3a3ea9dd272009eaa056027ae Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Apr 2025 10:55:47 -0700 Subject: [PATCH 0788/1769] [pallas:mosaic] Fixed a bug in `pl.debug_print` lowering On TPU we allow the string argument to be a prefix, instead of a format string. PiperOrigin-RevId: 751054611 --- jax/_src/pallas/mosaic/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index da230c89d892..208ae389d2bf 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3696,8 +3696,8 @@ def _debug_print_rule( # Scalar case. 
if is_all_scalars: - primitives.check_debug_print_format(fmt, *args) if has_placeholders: + primitives.check_debug_print_format(fmt, *args) if not all( isinstance(arg.type, ir.IntegerType) and arg.type.width == 32 for arg in args From 3bc8436a9b2e17ff0a1380916387443278136f3d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 11:23:08 -0700 Subject: [PATCH 0789/1769] Add `__str__` to `UnshapedArray` so that whenever we `print(aval)`, we don't see the class name by default. It only shows up when you do: `repr(aval)`. Weak_type in `__str__` is represented as `~int32[5, 4]` (note the tilde at the start) PiperOrigin-RevId: 751066142 --- jax/_src/api.py | 2 +- jax/_src/core.py | 3 +++ jax/_src/tree_util.py | 2 +- tests/core_test.py | 7 +++++++ tests/lax_control_flow_test.py | 2 +- 5 files changed, 13 insertions(+), 3 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 9e13c4438d50..4b608fd64a89 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -358,7 +358,7 @@ def disable_jit(disable: bool = True): ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) # doctest:+ELLIPSIS - Value of y is Tracedwith + Value of y is Tracedwith [5 7 9] Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`, diff --git a/jax/_src/core.py b/jax/_src/core.py index 12f56e7e527b..b073712e5a85 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1688,6 +1688,9 @@ def __repr__(self): return '{}({}{})'.format(self.__class__.__name__, self.str_short(), ", weak_type=True" if self.weak_type else "") + def __str__(self): + return '{}{}'.format("~" if self.weak_type else "", self.str_short()) + _bool = concretization_function_error(bool) _int = concretization_function_error(int, True) _float = concretization_function_error(float, True) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index b73d84b330de..7c7ca96b1e5c 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -529,7 +529,7 @@ class Partial(functools.partial): >>> print_zero() 0 >>> call_func(print_zero) # doctest:+ELLIPSIS - Tracedwith + Traced<~int32[]>with """ def __new__(klass, func, *args, **kw): diff --git a/tests/core_test.py b/tests/core_test.py index 00b3eb1d61d5..e39487035751 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -203,6 +203,13 @@ def test_is_valid_jaxtype(self, dtype): else: self.assertFalse(core.valid_jaxtype(arr)) + def test_str_aval(self): + aval = ShapedArray((8, 2), np.int32) + self.assertEqual(str(aval), "int32[8,2]") + + aval = ShapedArray((8, 2), np.int32, weak_type=True) + self.assertEqual(str(aval), "~int32[8,2]") + @parameterized.named_parameters( (str(i), *spec) for i, spec in enumerate(test_specs)) def test_jit(self, f, args): diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 8876fb7d06be..78026968d2cd 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -1941,7 +1941,7 @@ def plus_one(p, iter_idx): def testScanBodyOutputError(self): with self.assertRaisesRegex( TypeError, - re.escape("scan body output must be a pair, got ShapedArray(float32[]).")): + re.escape("scan body output must be a pair, got float32[].")): lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.)) def testScanMetadataError(self): From 18fcf917cbc57f71517bdf1f9004c83b859366c9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 11:25:48 -0700 Subject: [PATCH 0790/1769] Make inspect_array_sharding inside shard_map work with `check_vma=True` | `check_rep=True`. 
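For context, the pattern this enables looks like the following (adapted from
the test added in this change; assumes a small mesh such as
`mesh = jtu.create_mesh((2,), 'x')` and the usual `jax.numpy` / `PartitionSpec`
imports):

    def f():
      a = jnp.zeros(1000)
      jax.debug.visualize_array_sharding(a)  # previously failed under check_vma=True
      return a

    jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P('x'))()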
Fixes https://github.com/jax-ml/jax/issues/23936 PiperOrigin-RevId: 751067216 --- jax/_src/core.py | 3 +++ jax/_src/debugging.py | 4 +++- tests/debugging_primitives_test.py | 23 +++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index b073712e5a85..dcdfc91344ce 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1894,7 +1894,10 @@ def str_short_aval(shape, dtype, mesh, spec, vma, def get_vma(vma, mesh): if mesh.empty: return vma + axis_env_names = get_axis_env().axis_names() for i in vma: + if i in axis_env_names and i not in mesh._name_to_type: + continue if mesh._name_to_type[i] != AxisType.Manual: raise ValueError( "Axes mentioned in `vma` field of ShapedArray should" diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index b44a4e434027..18178b4efcb0 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -459,6 +459,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes), am.axis_names) elif isinstance(axis_context, sharding_impls.SPMDAxisContext): + mesh = axis_context.mesh devices = axis_context.mesh._flat_devices_tuple else: raise NotImplementedError(type(axis_context)) @@ -470,7 +471,8 @@ def _hlo_sharding_callback(hlo_sharding: xc.HloSharding): if mesh.empty: return callback( sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices)) - pspec = parse_flatten_op_sharding(hlo_sharding, mesh)[0] + pspec = (P() if hlo_sharding.is_manual() else + parse_flatten_op_sharding(hlo_sharding, mesh)[0]) return callback(NamedSharding(mesh, pspec)) if len(devices) == 1: diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index becd18033d6d..bf86c82b9615 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -25,6 +25,7 @@ from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu +from jax.sharding import PartitionSpec as P import jax.numpy as jnp import numpy as np @@ -1120,6 +1121,28 @@ def test_visualize_pmap_sharding(self): """) self.assertEqual(output(), expected) + def test_visualize_sharding_shard_map(self): + mesh = jtu.create_mesh((2,), 'x') + + def f(): + a = jnp.zeros(1000) + debugging.visualize_array_sharding(a) + return a + + with jtu.capture_stdout() as output: + f() # doesn't crash + + with jtu.capture_stdout() as output: + jax.jit(f, out_shardings=jax.NamedSharding(mesh, P('x')))() # doesn't crash + + with jtu.capture_stdout() as output: + jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"))() # doesn't crash + + with jtu.capture_stdout() as output: + jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"), + check_vma=False)() # doesn't crash + + class InspectShardingTest(jtu.JaxTestCase): def test_inspect_sharding_is_called_in_pjit(self): From 4e644eeec3cee6e65e5ca4c6e67c8f1f19603dc5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 12:06:44 -0700 Subject: [PATCH 0791/1769] Remove PositionalSharding from JAX from almost all places except for pjit_test.py and array_test.py so we keep testing it until we fully delete it. 
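For reference, the in-tree replacement pattern used by this change is to build
an equivalent `NamedSharding` over the same device list (the axis name 'x' is
arbitrary), e.g. as in the dispatch.py hunk below:

    # before: jax.device_put(tok, PositionalSharding(devices))
    # after:
    jax.device_put(tok, NamedSharding(Mesh(devices, 'x'), PartitionSpec('x')))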
PiperOrigin-RevId: 751083672 --- docs/jax.sharding.rst | 6 ------ jax/_src/dispatch.py | 5 +++-- jax/_src/interpreters/pxla.py | 5 +++-- jax/_src/layout.py | 2 +- jax/_src/sharding.py | 4 ---- jaxlib/xla_client_test.py | 5 ----- tests/array_test.py | 37 ----------------------------------- tests/memories_test.py | 24 +++-------------------- tests/pjit_test.py | 36 ---------------------------------- 9 files changed, 10 insertions(+), 114 deletions(-) diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst index 954f62b8a52d..12760d62ddb3 100644 --- a/docs/jax.sharding.rst +++ b/docs/jax.sharding.rst @@ -16,15 +16,9 @@ Classes .. autoclass:: NamedSharding :members: :show-inheritance: -.. autoclass:: PositionalSharding - :members: - :show-inheritance: .. autoclass:: PmapSharding :members: :show-inheritance: -.. autoclass:: GSPMDSharding - :members: - :show-inheritance: .. autoclass:: PartitionSpec :members: .. autoclass:: Mesh diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 28ed39a1fe2a..fd96883a53ef 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -51,7 +51,7 @@ from jax._src.sharding import Sharding from jax._src.sharding_impls import ( NamedSharding, SingleDeviceSharding, TransferToMemoryKind, GSPMDSharding, - PositionalSharding, is_single_device_sharding) + is_single_device_sharding) import numpy as np @@ -132,7 +132,8 @@ def get_token_input( # TODO(yueshengys): This might still be buggy in a multi-process SPMD # scenario. Revise the logic later. A distributed shutdown barrier inside # the XLA program may be needed. - return jax.device_put(tok, PositionalSharding(devices)) + return jax.device_put( + tok, NamedSharding(Mesh(devices, 'x'), PartitionSpec('x'))) # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f3d1265511d5..66465bd152a8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -68,7 +68,8 @@ from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources, - SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding) + SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding, + PartitionSpec as P) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, unzip2, HashableFunction, weakref_lru_cache) @@ -1267,7 +1268,7 @@ def _handle_token_bufs(self, token_bufs, sharded_token): for token in token_buf: assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding) token_devices.append(token.sharding._device_assignment[0]) - s = PositionalSharding(token_devices) + s = NamedSharding(Mesh(token_devices, 'x'), P('x')) global_token_array = jax.make_array_from_single_device_arrays( (0,), s, token_buf ) diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 8d4f8acd5327..3675433c43d8 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -105,7 +105,7 @@ def __init__(self, device_local_layout: LayoutOptions = None, raise ValueError( 'Sharding has to be concrete when layout is of type' f' {type(device_local_layout)}. Please pass a' - ' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or' + ' `jax.sharding.NamedSharding` or' ' `jax.sharding.SingleDeviceSharding` to the sharding argument. 
Got' f' sharding {sharding}' ) diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index a9bf62b46473..32373d6e6c39 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -192,10 +192,6 @@ def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool: Two shardings are equivalent if they place the same logical array shards on the same devices. - - For example, a :class:`NamedSharding` may be equivalent - to a :class:`PositionalSharding` if both place the same shards of the array - on the same devices. """ try: return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), diff --git a/jaxlib/xla_client_test.py b/jaxlib/xla_client_test.py index 4bb7f7992e16..ae34751b6cab 100644 --- a/jaxlib/xla_client_test.py +++ b/jaxlib/xla_client_test.py @@ -951,11 +951,6 @@ def SetLayoutsSharded(self): if self.backend.platform != "tpu": raise self.skipTest("mhlo.layout_mode only implemented on TPU") - # Hand-edited version of: - # sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) - # x = jax.device_put(np.ones((1024, 128)), sharding.reshape(4, 2)) - # jax.jit(lambda x, y: x + y, out_shardings=sharding)(x, 1.) - # # This also lightly tests mixed default + user-specified input layouts. module_str = """ module @jit__lambda_ attributes {mhlo.num_partitions = 8 : i32, diff --git a/tests/array_test.py b/tests/array_test.py index 230c4cda336a..97d66566cffa 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -986,10 +986,6 @@ def test_gspmd_sharding_repr(self): # memory kind also appears in the repr but only for TPU. self.assertIn('GSPMDSharding({replicated}', repr(s2)) - def test_positional_sharding_fully_replicated(self): - sharding = PositionalSharding(jax.devices()) - jax.device_put(jnp.array(1), sharding.replicate()) # doesn't crash - @parameterized.named_parameters( ("mesh_x_y", P("x", "y"), (4, 2), (), False), ("mesh_x", P("x"), (4, 2), (1,), False), @@ -1019,17 +1015,6 @@ def test_positional_sharding_op_sharding_lowering( devices_sharding.shard_shape(value_shape)) self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - def test_positional_sharding_aval_compatible(self): - if jax.device_count() < 2: - self.skipTest('Requires >=2 devices') - sharding = PositionalSharding(jax.devices()).reshape(1, jax.device_count()) - x = jax.random.uniform(jax.random.key(42), (256, 20, 1000)) - with self.assertRaisesRegex( - ValueError, - 'Sharding PositionalSharding.*is only valid for values of rank 2, but' - ' was applied to a value of rank 3'): - jax.lax.with_sharding_constraint(x, sharding) - @parameterized.named_parameters( ("2d_mesh_x_y", (4, 2), P("x", "y")), ("2d_mesh_x", (4, 2), P("x")), @@ -1103,21 +1088,6 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): op_shardings.is_op_sharding_replicated( ps._to_xla_hlo_sharding(len(shape)))) - def test_devices_sharding_respects_init_mesh_shape(self): - value_shape = (8, 4) - - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) - - devices_sharding = PositionalSharding(mesh.devices) - - op1 = mps._to_xla_hlo_sharding(len(value_shape)) - op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) - - self.assertEqual(mps.shard_shape(value_shape), - devices_sharding.shard_shape(value_shape)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - def test_pmap_sharding_repr(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') @@ -1125,13 +1095,6 @@ def test_pmap_sharding_repr(self): str(out.sharding) # doesn't 
crash repr(out.sharding) # doesn't crash - def test_positional_sharding_repr(self): - if jax.device_count() < 2: - self.skipTest('Test needs >= 2 devices.') - s = PositionalSharding(jax.devices()).reshape(jax.device_count(), 1) - repr(s) # doesn't crash - str(s) # doesn't crash - def test_pspec_tuple(self): pspec = P('x', 'y', 'z') self.assertEqual(pspec, ('x', 'y', 'z')) diff --git a/tests/memories_test.py b/tests/memories_test.py index 9700acc649e3..302895e29f63 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -31,9 +31,9 @@ import jax.numpy as jnp from jax.ad_checkpoint import Offloadable, remat, Recompute from jax._src.sharding import common_devices_indices_map -from jax._src.sharding_impls import (NamedSharding, PositionalSharding, - SingleDeviceSharding, GSPMDSharding, - TransferToMemoryKind, PartitionSpec as P) +from jax._src.sharding_impls import ( + NamedSharding, SingleDeviceSharding, GSPMDSharding, + TransferToMemoryKind, PartitionSpec as P) from jax.experimental.compute_on import compute_on from jax._src.shard_map import shard_map import numpy as np @@ -66,7 +66,6 @@ def setUp(self): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -75,9 +74,6 @@ def test_canonicalize_memory_kind(self, name): mesh = jtu.create_mesh((1,), "x") ns = NamedSharding(mesh, P("x")) self.assertEqual(ns.memory_kind, self._default_memory_kind) - elif name == "positional_sharding": - ps = PositionalSharding(jax.devices()) - self.assertEqual(ps.memory_kind, self._default_memory_kind) elif name == "single_device_sharding": ss = SingleDeviceSharding(jax.devices()[0]) self.assertEqual(ss.memory_kind, self._default_memory_kind) @@ -88,7 +84,6 @@ def test_canonicalize_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -99,11 +94,6 @@ def test_wrong_memory_kind(self, name): ): mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind="hbm") - elif name == "positional_sharding": - with self.assertRaisesRegex( - ValueError, "Could not find memory addressable by device.*" - ): - PositionalSharding(jax.devices(), memory_kind="gpu_hbm") elif name == "single_device_sharding": with self.assertRaisesRegex( ValueError, @@ -120,7 +110,6 @@ def test_wrong_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -131,8 +120,6 @@ def test_correct_tpu_memory_kind(self, name): if name == "named_sharding": mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) - elif name == "positional_sharding": - PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) elif name == "single_device_sharding": SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host") else: @@ -141,7 +128,6 @@ def test_correct_tpu_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -151,10 +137,6 @@ def test_sharding_eq(self, 
name): s1 = NamedSharding(mesh, P("x")) s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) - elif name == "positional_sharding": - s1 = PositionalSharding(jax.devices()) - s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) - self.assertEqual(s1, s2) elif name == "single_device_sharding": s1 = SingleDeviceSharding(jax.devices()[0]) s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 468877d79be4..0de1ee1b84a2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3738,25 +3738,6 @@ def test_list_in_pspec(self): out = with_sharding_constraint(jnp.arange(8), P(['x'])) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - def test_sharding_preserved_trivial(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") - mesh = jtu.create_mesh((2, 1), ('x', 'y')) - ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - - arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) - - def identity(x): - return x - - out = pjit(identity)(arr) - self.assertIsInstance(out.sharding, NamedSharding) - - out2 = pjit(identity)(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) - def test_wsc_error_on_none(self): with self.assertRaisesRegex( ValueError, @@ -3764,23 +3745,6 @@ def test_wsc_error_on_none(self): ' not allowed'): with_sharding_constraint(jnp.arange(8), None) - def test_sharding_preserved_aot(self): - mesh = jtu.create_mesh((2, 1), ('x', 'y')) - ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - - arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) - - compiled = pjit(lambda x: x * 2).lower(arr).compile() - out = compiled(arr) - self.assertIsInstance(out.sharding, NamedSharding) - - out2 = compiled(arr2) - # The sharding won't be PositionalSharding since the pjit was already - # Compiled which bakes in the output sharding. - self.assertIsInstance(out2.sharding, NamedSharding) - def test_sharding_on_output_with_vmap(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) From 53e8c0f8cf00d66c0831af96733ba9f77d6ac289 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Apr 2025 12:06:46 -0700 Subject: [PATCH 0792/1769] [pallas] Removed the deprecated `jax.experimental.pallas.gpu` module PiperOrigin-RevId: 751083687 --- docs/pallas/CHANGELOG.md | 5 +++++ jax/BUILD | 2 -- jax/experimental/pallas/gpu.py | 24 ------------------------ 3 files changed, 5 insertions(+), 26 deletions(-) delete mode 100644 jax/experimental/pallas/gpu.py diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index db63657626b3..476cc54673a1 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list ## Unreleased +* Removals + + * Removed previously deprecated {mod}`jax.experimental.pallas.gpu`. To use + the Triton backend import {mod}`jax.experimental.pallas.triton`. 
+ * Changes * {func}`jax.experimental.pallas.BlockSpec` now takes in special types in diff --git a/jax/BUILD b/jax/BUILD index d7c48f019096..d2320e1e4456 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -671,7 +671,6 @@ pytype_strict_library( "experimental/pallas/**/*.py", ], exclude = [ - "experimental/pallas/gpu.py", "experimental/pallas/mosaic_gpu.py", "experimental/pallas/ops/gpu/**/*.py", "experimental/pallas/ops/tpu/**/*.py", @@ -784,7 +783,6 @@ pytype_strict_library( pytype_strict_library( name = "pallas_triton", srcs = [ - "experimental/pallas/gpu.py", "experimental/pallas/triton.py", ], visibility = [ diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py deleted file mode 100644 index 0ee84c8453ec..000000000000 --- a/jax/experimental/pallas/gpu.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax._src import deprecations - -deprecations.warn( - "pallas-gpu-triton", - "The ``jax.experimental.pallas.gpu`` submodule is deprecated. " - " Use ``jax.experimental.pallas.triton`` instead.", - stacklevel=1, -) - -from jax.experimental.pallas.triton import * # noqa: F403 From d865c8cf266da710475b26adce70c63ce6bd9635 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Thu, 24 Apr 2025 12:07:53 -0700 Subject: [PATCH 0793/1769] [Mosaic GPU] Adding a deterministic backwards pass to the pallas MGPU kernel. It's implemented as split dq and dkv kernels to have enough SMEM for double-buffering. The compute throughput is better than in the forwards pass, which is expected because there's less vector ops since we're not computing the softmax. 
PiperOrigin-RevId: 751084141 --- jax/_src/pallas/mosaic_gpu/core.py | 2 +- .../pallas/ops/gpu/attention_mgpu.py | 332 +++++++++++++++++- tests/pallas/mgpu_attention_test.py | 65 ++++ 3 files changed, 394 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 5446cdb47add..6fae802b7ff4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -168,7 +168,7 @@ def kernel( body: Callable[..., None], out_shape: object, *, - scratch_shapes: Sequence[pallas_core.ScratchShape] = (), + scratch_shapes: pallas_core.ScratchShapeTree = (), compiler_params: object | None = None, **mesh_kwargs: object, ): diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 6a20b448ca54..a9403ad22786 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -25,7 +25,7 @@ import jax.experimental.pallas.mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np - +from functools import partial @dataclasses.dataclass(frozen=True) class TuningConfig: @@ -33,6 +33,12 @@ class TuningConfig: block_kv: int max_concurrent_steps: int use_schedule_barrier: bool = True + compute_wgs_bwd: int = 1 + + block_q_dkv: int | None = None + block_kv_dkv: int | None = None + block_q_dq: int | None = None + block_kv_dq: int | None = None def __post_init__(self): if self.block_q % 64: @@ -42,9 +48,19 @@ def __post_init__(self): if self.max_concurrent_steps < 2: raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") + backward_blocks = [self.block_q_dkv, self.block_kv_dkv, self.block_q_dq, self.block_kv_dq] + block_is_set = [blk is not None for blk in backward_blocks] + if any(block_is_set) and not all(block_is_set): + raise ValueError( + "Backward block sizes (block_q_dkv, block_kv_dkv, block_q_dq, " + "block_kv_dq) must either all be specified or all be None." 
+ ) -@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) -def attention(q, k, v, config: TuningConfig, save_residuals: bool = False): + @property + def has_backward_blocks(self) -> bool: + return self.block_q_dkv is not None + +def _attention_forward(q, k, v, config: TuningConfig, save_residuals: bool = False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -259,6 +275,314 @@ def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): return out +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +@partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention(q, k, v, config: TuningConfig, save_residuals: bool = False): + return _attention_forward(q, k, v, config, save_residuals) + +def _attention_fwd(q, k, v, config: TuningConfig, save_residuals: bool): + del save_residuals + + out, (lse,) = _attention_forward(q, k, v, config, save_residuals=True) + return out, (q, k, v, out, lse) + +def _attention_bwd(config: TuningConfig, save_residuals: bool, res, do): + del save_residuals + q, k, v, out, lse = res + + if not config.has_backward_blocks: + raise ValueError("Need to specify backward blocks.") + + assert config.block_q_dq is not None + assert config.block_kv_dq is not None + assert config.block_q_dkv is not None + assert config.block_kv_dkv is not None + + batch_size, q_seq_len, num_q_heads, head_dim = q.shape + _, kv_seq_len, num_kv_heads, _ = k.shape + q_heads_per_kv_head = num_q_heads // num_kv_heads + dtype = q.dtype + compute_wgs = config.compute_wgs_bwd + + num_q_tiles, rem = divmod(q_seq_len, config.block_q_dq * compute_wgs) + if rem: + raise NotImplementedError( + f"{q_seq_len=} must be a multiple of {config.block_q_dq=} * {compute_wgs=}") + + num_kv_tiles, rem = divmod(kv_seq_len, config.block_kv_dkv * compute_wgs) + if rem: + raise NotImplementedError( + f"{kv_seq_len=} must be a multiple of {config.block_kv_dkv=} * {compute_wgs=}") + + num_q_tiles_in_dkv, rem = divmod(q_seq_len, config.block_q_dkv) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {config.block_q_dkv=}") + + num_kv_tiles_in_dq, rem = divmod(kv_seq_len, config.block_kv_dq) + if rem: + raise NotImplementedError(f"{kv_seq_len=} must be a multiple of {config.block_kv_dq=}") + + tiling = plgpu.TilingTransform((8, 64)) + swizzle = plgpu.SwizzleTransform(128) + + delta = jnp.einsum('bqhd,bqhd->bhq', out.astype(jnp.float32), do.astype(jnp.float32)) + del out # Not needed anymore. 
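+
+  # The pass is split into two kernels: kernel_dq streams K/V tiles while
+  # accumulating dQ, and kernel_dkv streams Q/dO tiles while accumulating
+  # dK/dV, so that each kernel's double-buffered working set fits in SMEM.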
+ + def kernel_dq(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, dq_ref, + smem_buffers, buffer_barriers, block_q, block_kv): + batch = lax.axis_index("batch") + q_head = lax.axis_index("heads") + wg_idx = lax.axis_index("wg") + kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) + q_smem2, do_smem2, lse_smem2, delta_smem2 = smem_buffers + q_barriers, do_barriers, lse_barriers, delta_barriers = buffer_barriers + def _compute_thread(): + q_smem, do_smem, lse_smem, delta_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx], lse_smem2.at[wg_idx], delta_smem2.at[wg_idx] + q_seq_base = lax.axis_index("q_seq") * (compute_wgs * block_q) + wg_idx * block_q + q_slice = (batch, pl.ds(q_seq_base, block_q), q_head) + plgpu.copy_gmem_to_smem(q_ref.at[q_slice], q_smem, q_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem(do_ref.at[q_slice], do_smem, do_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem( + delta_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + delta_smem, + delta_barriers.at[wg_idx], + ) + plgpu.copy_gmem_to_smem( + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + lse_smem, + lse_barriers.at[wg_idx], + ) + for buffer in buffer_barriers: + plgpu.barrier_wait(buffer.at[wg_idx]) + + delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA_ROW) + lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA_ROW) + dq_acc = plgpu.layout_cast( + jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dq, _, _ = (yield (dq_acc, lse, delta)) + q_smem[...] = dq.astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(q_smem, dq_ref.at[q_slice]) + plgpu.wait_smem_to_gmem(0) + + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): + q_smem, do_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx] + (dq_acc, lse, delta) = carry + + def compute_s(acc_ref): + plgpu.wgmma(acc_ref, q_smem, plgpu.transpose_ref(k_smem, (1, 0))) + return acc_ref[...] + + s = pl.run_scoped(compute_s, plgpu.ACC((block_q, block_kv), jnp.float32)) + s *= math.log2(math.e) + p = jnp.exp2(s - lax.broadcast_in_dim(lse, (block_q, block_kv), [0])) + + # dP + def compute_dp(acc_ref): + plgpu.wgmma(acc_ref, do_smem, plgpu.transpose_ref(v_smem, (1, 0))) + return acc_ref[...] 
+
+      dp = pl.run_scoped(compute_dp, plgpu.ACC((block_q, block_kv), jnp.float32))
+      plgpu.barrier_arrive(v_consumed_barrier)
+
+      # dS
+      ds = p * (dp - lax.broadcast_in_dim(delta, (block_q, block_kv), [0]))
+
+      # dQ
+      def compute_dq(acc_ref):
+        plgpu.wgmma(acc_ref, ds.astype(k_ref.dtype), k_smem)
+
+      dq_acc = pl.run_state(compute_dq)(plgpu.ACC.init(dq_acc))
+      plgpu.barrier_arrive(k_consumed_barrier)
+
+      return (dq_acc, lse, delta)
+
+    pipeline = plgpu.emit_pipeline_warp_specialized(
+        kv_pipeline,
+        grid=(num_kv_tiles_in_dq,),
+        max_concurrent_steps=min([config.max_concurrent_steps, num_q_tiles]),
+        num_compute_wgs=compute_wgs,
+        memory_registers=40,
+        wg_axis="wg",
+        manual_consumed_barriers=True,
+        carry_coroutine=_compute_thread,
+        in_specs=[
+            plgpu.GPUBlockSpec(  # k
+                block_shape=(block_kv, head_dim),
+                index_map=lambda i: (i, 0),
+                transforms=[tiling, swizzle]),
+            plgpu.GPUBlockSpec(  # v
+                block_shape=(block_kv, head_dim),
+                index_map=lambda i: (i, 0),
+                transforms=[tiling, swizzle]),
+        ])
+    k_ref = k_ref.at[batch, :, kv_head, :]
+    v_ref = v_ref.at[batch, :, kv_head, :]
+    pipeline(k_ref, v_ref)
+
+  def kernel_dkv(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref,
+                 dk_ref, dv_ref, smem_buffers, buffer_barriers, block_q: int, block_kv: int):
+    batch = lax.axis_index("batch")
+    q_head = lax.axis_index("heads")
+    wg_idx = lax.axis_index("wg")
+    (k_smem2, v_smem2) = smem_buffers
+    (k_barriers, v_barriers) = buffer_barriers
+
+    def _compute_thread():
+      k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx]
+      kv_seq_base = lax.axis_index("kv_seq") * (compute_wgs * block_kv) + wg_idx * block_kv
+      kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
+      plgpu.copy_gmem_to_smem(
+          k_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)],
+          k_smem,
+          k_barriers.at[wg_idx])
+      plgpu.copy_gmem_to_smem(
+          v_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)],
+          v_smem,
+          v_barriers.at[wg_idx])
+      plgpu.barrier_wait(k_barriers.at[wg_idx])
+      plgpu.barrier_wait(v_barriers.at[wg_idx])
+      dk_acc = plgpu.layout_cast(
+          jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
+      )
+      dv_acc = plgpu.layout_cast(
+          jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
+      )
+      (dk, dv) = (yield (dv_acc, dk_acc))
+      k_smem[...] = dk.astype(dtype)
+      v_smem[...] = dv.astype(dtype)
+
+      plgpu.commit_smem()
+      plgpu.copy_smem_to_gmem(
+          k_smem,
+          dk_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)],
+          commit_group=False)
+      plgpu.copy_smem_to_gmem(
+          v_smem,
+          dv_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)],
+          commit_group=False)
+      plgpu.commit_smem_to_gmem_group()
+      plgpu.wait_smem_to_gmem(0)
+
+    def q_pipeline(_, q_smem, do_smem, lse_smem, delta_smem, q_consumed_barrier, do_consumed_barrier, lse_consumed_barrier, delta_consumed_barrier, carry):
+      k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx]
+      dk_acc, dv_acc = carry
+
+      def _compute_sT(acc_ref):
+        plgpu.wgmma(acc_ref, k_smem, plgpu.transpose_ref(q_smem, (1, 0)))
+        return acc_ref[...]
+      sT = pl.run_scoped(_compute_sT, plgpu.ACC((block_kv, block_q), jnp.float32))
+      sT *= math.log2(math.e)
+
+      lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA_COL)
+      plgpu.barrier_arrive(lse_consumed_barrier)
+      pT = jnp.exp2(sT - lax.broadcast_in_dim(lse, (block_kv, block_q), [1]))
+
+      def _compute(refs):
+        # Combining two WGMMA calls in one block to avoid the unnecessary
+        # synchronization from two `wgmma.wait_group` calls.
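+        # dV accumulates pT @ do, and dpT is formed as v @ do^T in the same
+        # accumulator scope, so a single wait covers both matmuls.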
+ dv_acc_ref, dpT_acc_ref = refs + plgpu.wgmma(dv_acc_ref, pT.astype(dtype), do_smem) # dV + plgpu.wgmma(dpT_acc_ref, v_smem, plgpu.transpose_ref(do_smem, (1, 0))) # dpT + + zeros = plgpu.layout_cast( + jnp.full((block_kv, block_q), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dv_acc, dpT = pl.run_state(_compute)((plgpu.ACC.init(dv_acc), plgpu.ACC.init(zeros))) + plgpu.barrier_arrive(do_consumed_barrier) + + delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA_COL) + plgpu.barrier_arrive(delta_consumed_barrier) + + dsT = pT * (dpT - lax.broadcast_in_dim(delta, (block_kv, block_q), [1])) + + def compute_dk(acc_ref): + plgpu.wgmma(acc_ref, dsT.astype(dtype), q_smem) + + dk_acc = pl.run_state(compute_dk)(plgpu.ACC.init(dk_acc)) + plgpu.barrier_arrive(q_consumed_barrier) + + return (dk_acc, dv_acc) + + pipeline = plgpu.emit_pipeline_warp_specialized( + q_pipeline, + grid=(num_q_tiles_in_dkv,), + max_concurrent_steps=min([config.max_concurrent_steps, num_kv_tiles]), + num_compute_wgs=compute_wgs, + memory_registers=40, + wg_axis="wg", + manual_consumed_barriers=True, + carry_coroutine=_compute_thread, + in_specs=[ + plgpu.GPUBlockSpec( # q + block_shape=(block_q, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.GPUBlockSpec( # do + block_shape=(block_q, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.GPUBlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)), + plgpu.GPUBlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)) + ]) + q_ref = q_ref.at[batch, :, q_head, :] + do_ref = do_ref.at[batch, :, q_head, :] + lse_ref = lse_ref.at[batch, q_head, :] + delta_ref = delta_ref.at[batch, q_head, :] + pipeline(q_ref, do_ref, lse_ref, delta_ref) + + q_scratch = plgpu.SMEM( + (compute_wgs, config.block_q_dq, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + do_scratch = q_scratch + lse_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32) + delta_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32) + dq = plgpu.kernel( + partial(kernel_dq, block_q=config.block_q_dq, block_kv=config.block_kv_dq), + out_shape=q, + scratch_shapes=[ + (q_scratch, do_scratch, lse_scratch, delta_scratch), # type: ignore + (plgpu.Barrier(1, num_barriers=compute_wgs),) * 4 # type: ignore + ], + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), + num_threads=compute_wgs + 1, + thread_name="wg", + )(q, k, v, do, lse, delta) + + k_scratch = plgpu.SMEM( + (compute_wgs, config.block_kv_dkv, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + v_scratch = k_scratch + out_shape_kv = jax.ShapeDtypeStruct( + (batch_size, kv_seq_len, num_q_heads, head_dim), dtype=jnp.float16) + dk, dv = plgpu.kernel( + partial(kernel_dkv, block_q=config.block_q_dkv, block_kv=config.block_kv_dkv), + out_shape=[out_shape_kv, out_shape_kv], + scratch_shapes=[ + (k_scratch, v_scratch), # type: ignore + (plgpu.Barrier(1, num_barriers=compute_wgs),) * 2 # type: ignore + ], + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(batch_size, num_kv_tiles, num_q_heads), + grid_names=("batch", "kv_seq", "heads"), + num_threads=compute_wgs + 1, + thread_name="wg" + )(q, k, v, do, lse, delta) + + if q_heads_per_kv_head > 1: + sum_shape = (*k.shape[:-1], q_heads_per_kv_head, head_dim) + dk = dk.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dk.dtype) + dv = 
dv.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dv.dtype) + + return dq, dk, dv + +attention.defvjp(_attention_fwd, _attention_bwd) + @functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: @@ -427,7 +751,7 @@ def _kernel_entry(): ) @jax.jit def run_function(q, k, v, o, lse): - _, _, _, out, lse = pl.run_state(run)((q, k, v, o, lse)) + *_, out, lse = pl.run_state(run)((q, k, v, o, lse)) return out, lse lse = ( diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index 27588683d0e9..3f0370153d81 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -16,10 +16,12 @@ import os +import contextlib import numpy as np from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.pallas import pallas_call import jax.numpy as jnp # pylint: disable=g-import-not-at-top @@ -47,6 +49,9 @@ def setUp(self): if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_equal("9.0")): self.skipTest("Only works on GPU with capability sm90a") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) @parameterized.product( batch_size=(1, 4), @@ -95,6 +100,66 @@ def test_flash_attention( (lse_ref,) = res_ref[0] np.testing.assert_allclose(lse, lse_ref, atol=2e-3, rtol=1e-3) + @parameterized.product( + batch_size=(3,), + seq_lens=((512, 512), (3584, 4096)), + num_q_and_kv_heads=( + (4, 4), # MHA + (4, 1), # MQA + (6, 3), # GQA + ), + bwd_blocks = ( + (64, 64, 64, 64), + (64, 128, 128, 64), + (128, 128, 128, 128), + ), + head_dim=(64, 128, 256), + ) + def test_bwd_flash_attention( + self, + batch_size, + seq_lens, + num_q_and_kv_heads, + bwd_blocks, + head_dim, + ): + num_q_heads, num_kv_heads = num_q_and_kv_heads + kv_seq_len, q_seq_len = seq_lens + block_q_dq, block_kv_dq, block_q_dkv, block_kv_dkv = bwd_blocks + compute_wgs = 2 if head_dim <= 128 else 1 + k1, k2, k3 = jax.random.split(jax.random.key(42), 3) + q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + + def f(q, k, v): + return attention_mgpu.attention( + q, + k, + v, + attention_mgpu.TuningConfig( + block_q=block_q_dq, block_kv=block_kv_dq, + max_concurrent_steps=2, compute_wgs_bwd=compute_wgs, + block_q_dkv=block_q_dkv, block_kv_dkv=block_kv_dkv, + block_q_dq=block_q_dq, block_kv_dq=block_kv_dq, + ) + ).sum() + + def f_ref(q, k, v): + return attention_mgpu.attention_reference(q, k, v).sum() + + try: + # TODO(pobudzey): Replace with `jtu.check_grads` when it's fixed. 
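+      # argnums=(0, 1, 2) differentiates the scalar loss with respect to all
+      # of q, k, and v at once, so jax.grad returns a (dq, dk, dv) triple.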
+ dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v) + dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v) + + self.assertAllClose(dq, dq_ref, atol=5e-2) + self.assertAllClose(dk, dk_ref, atol=7e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + self.skipTest("Not enough SMEM for this configuration.") if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 3c2b533e8660004ca5cd1b3515e1a349239699ff Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 24 Apr 2025 12:17:19 -0700 Subject: [PATCH 0794/1769] Re-land: "Don't recompute source_info for each tracer during staging". Reverts 492cd3d9313cfd45e8bd63a8f51aa63d92924cd5 PiperOrigin-RevId: 751087903 --- jax/_src/interpreters/ad.py | 24 ++++--- jax/_src/interpreters/partial_eval.py | 99 ++++++++++++++++----------- jax/_src/pjit.py | 7 +- jax/_src/shard_map.py | 7 +- jax/experimental/attrs.py | 3 +- 5 files changed, 82 insertions(+), 58 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 98cda2df4964..0f11e0d72f12 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -89,12 +89,14 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, nzs_in: Sequence[bool], debug_info: core.DebugInfo, *primals, **params): + source_info = source_info_util.current() with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace(debug_info) tangent_trace.tag = _tag linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) tracers = [LinearizeTracer(linearize_trace, p, - tangent_trace.new_arg(get_aval(p).to_tangent_aval())) + tangent_trace.new_arg(get_aval(p).to_tangent_aval(), + source_info)) if nz else p for p, nz in zip(primals, nzs_in)] with core.set_current_trace(linearize_trace, check_leaks=True): @@ -103,7 +105,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, del linearize_trace, ans, tracers nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) - out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] + out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) if attrs_tracked: raise NotImplementedError("TODO: attrs") @@ -172,13 +174,14 @@ def _linearize_jaxpr( lin_trace = LinearizeTrace(primal_trace, tangent_trace) tangent_trace.tag = lin_trace.tag - def new_arg(trace, primal_aval, nz): - primal = primal_trace.new_arg(primal_aval) + def new_arg(trace, primal_aval, nz, source_info): + primal = primal_trace.new_arg(primal_aval, source_info) tangent_aval = primal_aval.to_tangent_aval() - tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval) + tangent = tangent_trace.new_arg(tangent_aval, source_info) if nz else Zero(tangent_aval) return LinearizeTracer(trace, primal, tangent) - tracers = [new_arg(lin_trace, v.aval, nz) + source_info = source_info_util.current() + tracers = [new_arg(lin_trace, v.aval, nz, source_info) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)] with core.set_current_trace(lin_trace, check_leaks=True): @@ -188,7 +191,7 @@ def new_arg(trace, primal_aval, nz): debug_info = jaxpr.jaxpr.debug_info nzs_out = [type(t) is not Zero for t in out_tangents] - out_tangents = 
tuple(tangent_trace.to_jaxpr_tracer(t) + out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t, source_info) for (nz, t) in zip(nzs_out, out_tangents) if nz) tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) tangent_trace.invalidate() @@ -200,7 +203,7 @@ def new_arg(trace, primal_aval, nz): tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used] residuals_and_primals = (*tangent_consts, *out_primals) - residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) + residuals_and_primals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info), residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) primal_trace.invalidate() num_residuals = len(tangent_consts) @@ -212,8 +215,9 @@ def new_arg(trace, primal_aval, nz): def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *, has_aux=False, tag=None): with core.take_current_trace() as parent_trace: + source_info = source_info_util.current() tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info) - tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals] tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) tangent_trace.tag = linearize_trace.tag @@ -234,7 +238,7 @@ def direct_linearize(traceable: lu.WrappedFun, del linearize_trace, ans, tracers out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] - out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) + out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_nz_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) tangent_trace.invalidate() jaxpr, used_consts, _ = pe.dce_jaxpr_consts( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0e888c1591aa..1317a584f6c3 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -41,6 +41,7 @@ JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) +from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure, register_static) @@ -1729,7 +1730,8 @@ def to_jaxpr( invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) - state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x))] + source_info = source_info_util.current() + state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x, source_info))] for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars @@ -1899,51 +1901,51 @@ def invalidate(self): self.frame.constid_to_tracer = {} self.frame.constvar_to_val = {} - def to_jaxpr_tracer(self, x): + def to_jaxpr_tracer(self, x, source_info: SourceInfo): as_local_var = self.frame.tracer_to_var.get(id(x)) if as_local_var is None: if hasattr(x, "dimension_as_value"): # Used for shape_poly._DimExpr with 
core.set_current_trace(self): x = x.dimension_as_value() - return self.to_jaxpr_tracer(x) + return self.to_jaxpr_tracer(x, source_info) else: - return self.new_const(x) + return self.new_const(x, source_info) else: return x - def new_arg(self, aval): - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + def new_arg(self, aval, source_info: SourceInfo): + tracer = DynamicJaxprTracer(self, aval, source_info) self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.invars.append(var) return tracer - def new_const(self, c): + def new_const(self, c, source_info: SourceInfo): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: aval = get_aval(c) if hasattr(aval, "weak_type"): aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, c) + aval = self._lift_tracers_in_aval(aval, source_info) + tracer = self._new_const(aval, c, source_info) return tracer pure = lift = new_const - def _new_const(self, aval, c) -> DynamicJaxprTracer: - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: + tracer = DynamicJaxprTracer(self, aval, source_info) self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.constid_to_tracer[id(c)] = tracer self.frame.constvar_to_val[var] = c return tracer - def _lift_tracers_in_aval(self, aval): + def _lift_tracers_in_aval(self, aval, source_info: SourceInfo): if (not isinstance(aval, DShapedArray) or not any(isinstance(d, Tracer) for d in aval.shape)): return aval - shape = [self.to_jaxpr_tracer(d) if isinstance(d, Tracer) else d + shape = [self.to_jaxpr_tracer(d, source_info) if isinstance(d, Tracer) else d for d in aval.shape] return aval.update(shape=tuple(shape)) @@ -1966,7 +1968,9 @@ def is_const(self, tracer): def process_primitive(self, primitive, tracers, params): if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) - jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + jaxpr_tracers = map(to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) return self.default_process_primitive(primitive, jaxpr_tracers, params) @@ -1989,17 +1993,19 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f: lu.WrappedFun, explicit_tracers, params): + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) if f.in_type is None: f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) assert f.in_type is not None - implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) - in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) + implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers, + source_info) + in_tracers = map(to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f) if params.get('inline', 
False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) - source_info = source_info_util.current() out_tracers: list[Tracer] = [] for aval, _ in out_type: if type(aval) is DShapedArray: @@ -2009,7 +2015,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, aval = aval.update(shape=tuple(get_referent(d) for d in shape)) out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) update_params = call_param_updaters.get(call_primitive) @@ -2022,7 +2028,9 @@ def process_call(self, call_primitive, f: lu.WrappedFun, return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) @@ -2041,10 +2049,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): out_avals = [core.unmapped_aval(axis_size, out_axis, a) if out_axis is not None else a for a, out_axis in zip(reduced_out_avals, out_axes)] - source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_in_axes = (None,) * len(consts) + params['in_axes'] new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, @@ -2062,7 +2069,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, jvp: lu.WrappedFun, tracers, symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] in_tangent_avals = [t.to_tangent_aval() for t in in_avals] fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) @@ -2079,7 +2088,7 @@ def jvp_jaxpr_thunk(*in_zeros): out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, @@ -2088,7 +2097,7 @@ def jvp_jaxpr_thunk(*in_zeros): num_consts=len(consts), symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers @@ -2097,7 +2106,9 @@ def process_custom_vjp_call(self, prim: core.Primitive, fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, out_trees: Callable[[], Sequence[PyTreeDef]], symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = 
partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) @@ -2110,9 +2121,9 @@ def fwd_jaxpr_from_zeros(*zeros): if attrs: raise NotImplementedError return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, # pytype: disable=attribute-error @@ -2122,7 +2133,7 @@ def fwd_jaxpr_from_zeros(*zeros): bwd=bwd, out_trees=out_trees, symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers @@ -2132,7 +2143,9 @@ def process_custom_transpose(self, prim: core.Primitive, # type: ignore[overrid out_types, lin_tree: PyTreeDef, res_tree: PyTreeDef, out_tree: PyTreeDef): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) in_avals_p = [t.aval for t in tracers] @@ -2152,9 +2165,9 @@ def transpose_jaxpr_thunk(): jaxpr, _, consts, () = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, call_consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, call_consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_call_jaxpr, @@ -2162,7 +2175,7 @@ def transpose_jaxpr_thunk(): out_types=out_types, res_tree=res_tree, lin_tree=lin_tree, out_tree=out_tree), closed_call_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers @@ -2216,13 +2229,15 @@ def trace_to_jaxpr_dynamic( keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + source_info = source_info_util.current() + in_tracers = _input_type_to_tracers( + partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] try: with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) + out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) _check_no_returned_refs(fun.debug_info, out_tracers) jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) del fun, in_tracers, out_tracers, ans @@ -2269,12 +2284,14 @@ def trace_to_jaxpr_dynamic2( trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): + source_info = source_info_util.current() in_avals, keep_inputs = unzip2(fun.in_type) - 
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + in_tracers = _input_type_to_tracers( + partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) + out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info) del trace, in_tracers, out_tracers, ans @@ -2449,7 +2466,7 @@ def __hash__(self): def _extract_implicit_args( trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]], - explicit_tracers: Sequence[DynamicJaxprTracer] + explicit_tracers: Sequence[DynamicJaxprTracer], source_info: SourceInfo, ) -> Sequence[DynamicJaxprTracer]: # First, construct a list to represent the full argument list, leaving the # implicit arguments as Nones for now. @@ -2467,8 +2484,8 @@ def _extract_implicit_args( for d1, d2 in zip(aval.shape, tracer.aval.shape): if isinstance(d1, DBIdx): if tracers[d1.val] is None: - tracers[d1.val] = trace.to_jaxpr_tracer(d2) - assert tracers[d1.val] is trace.to_jaxpr_tracer(d2) + tracers[d1.val] = trace.to_jaxpr_tracer(d2, source_info) + assert tracers[d1.val] is trace.to_jaxpr_tracer(d2, source_info) assert all(t is not None for t in tracers) return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore @@ -2616,13 +2633,13 @@ def inline_jaxpr_into_trace( trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - const_tracers = map(trace.new_const, consts) + src = source_info_util.current() + const_tracers = map(partial(trace.new_const, source_info=src), consts) constvars = map(trace.getvar, const_tracers) argvars = map(trace.getvar, arg_tracers) env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], [*constvars, *argvars])) - src = source_info_util.current() for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] outvars = [Var('', v.aval) for v in eqn.outvars] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 34f2ef0487a7..0f9bb4c4a5f7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1926,15 +1926,16 @@ def pjit_staging_rule(trace, *args, **params): else: out = pe.inline_jaxpr_into_trace( trace, jaxpr.jaxpr, jaxpr.consts, *args) - return [trace.to_jaxpr_tracer(x) for x in out] + source_info = source_info_util.current() + return [trace.to_jaxpr_tracer(x, source_info) for x in out] jaxpr = params['jaxpr'] + source_info = source_info_util.current() if config.dynamic_shapes.value: jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( jaxpr, params['out_shardings'], params['out_layouts']) params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) - source_info = source_info_util.current() out_tracers = [] for aval in _out_type(jaxpr): if type(aval) is core.DShapedArray: @@ -1953,7 +1954,7 @@ def pjit_staging_rule(trace, *args, **params): assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(trace.new_const, consts) + consts = map(partial(trace.new_const, source_info=source_info), consts) in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = 
(*params['donated_invars'],) + (False,) * len(consts) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index dcbdb0896916..634e7f2b5701 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -562,7 +562,9 @@ def _shard_map_staging( check_vma: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: - in_tracers = map(trace.to_jaxpr_tracer, in_tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) + in_tracers = map(to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh, auto, check_vma), in_names, in_avals) @@ -577,10 +579,9 @@ def _shard_map_staging( out_avals = map(_check_shapedarray, out_avals_) out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] - source_info = source_info_util.current() out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) + constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 0d40938a85c4..54fd0fe0b02f 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -87,10 +87,11 @@ def _check_append_type_agreement(_, attr, curtype, valtype): def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): frame = trace.frame + source_info = source_info_util.current() def new_tracer(x): aval = core.get_aval(x) - tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) + tracer = pe.DynamicJaxprTracer(trace, aval, source_info) var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) frame.attrs_vars.append(var) frame.tracers.append(tracer) From 89b2193a85a1b2d56dca255d16c14e52b851afac Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 24 Apr 2025 12:56:34 -0700 Subject: [PATCH 0795/1769] Update README.md to remove jax.pmap And do other things... Co-authored-by: Yash Katariya --- README.md | 338 ++++++++++++------------------------------------------ 1 file changed, 73 insertions(+), 265 deletions(-) diff --git a/README.md b/README.md index 00391f314044..cb217b0cf92e 100644 --- a/README.md +++ b/README.md @@ -7,10 +7,9 @@ [![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml) [![PyPI version](https://img.shields.io/pypi/v/jax)](https://pypi.org/project/jax/) -[**Quickstart**](#quickstart-colab-in-the-cloud) -| [**Transformations**](#transformations) +[**Transformations**](#transformations) +| [**Scaling**](#scaling) | [**Install guide**](#installation) -| [**Neural net libraries**](#neural-network-libraries) | [**Change logs**](https://docs.jax.dev/en/latest/changelog.html) | [**Reference docs**](https://docs.jax.dev/en/latest/) @@ -20,42 +19,29 @@ JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. -With its updated version of [Autograd](https://github.com/hips/autograd), JAX can automatically differentiate native Python and NumPy functions. 
It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) -via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, +via [`jax.grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, and the two can be composed arbitrarily to any order. -What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla) -to compile and run your NumPy programs on GPUs and TPUs. Compilation happens -under the hood by default, with library calls getting just-in-time compiled and -executed. But JAX also lets you just-in-time compile your own Python functions -into XLA-optimized kernels using a one-function API, -[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be -composed arbitrarily, so you can express sophisticated algorithms and get -maximal performance without leaving Python. You can even program multiple GPUs -or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and -differentiate through the whole thing. +JAX uses [XLA](https://www.tensorflow.org/xla) +to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. +You can compile your own pure functions with [`jax.jit`](#compilation-with-jit). +Compilation and automatic differentiation can be composed arbitrarily. Dig a little deeper, and you'll see that JAX is really an extensible system for -[composable function transformations](#transformations). Both -[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit) -are instances of such transformations. Others are -[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and -[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) -parallel programming of multiple accelerators, with more to come. +[composable function transformations](#transformations) at [scale](#scaling). This is a research project, not an official Google product. Expect [sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). -Please help by trying it out, [reporting -bugs](https://github.com/jax-ml/jax/issues), and letting us know what you -think! +Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues), +and letting us know what you think! ```python +import jax import jax.numpy as jnp -from jax import grad, jit, vmap def predict(params, inputs): for W, b in params: @@ -67,85 +53,50 @@ def loss(params, inputs, targets): preds = predict(params, inputs) return jnp.sum((preds - targets)**2) -grad_loss = jit(grad(loss)) # compiled gradient evaluation function -perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads +grad_loss = jax.jit(jax.grad(loss)) # compiled gradient evaluation function +perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads ``` ### Contents -* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) * [Transformations](#transformations) +* [Scaling](#scaling) * [Current gotchas](#current-gotchas) * [Installation](#installation) * [Neural net libraries](#neural-network-libraries) * [Citing JAX](#citing-jax) * [Reference documentation](#reference-documentation) -## Quickstart: Colab in the Cloud -Jump right in using a notebook in your browser, connected to a Google Cloud GPU. 
-Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://docs.jax.dev/en/latest/quickstart.html) -- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) - -**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU -Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). - -For a deeper dive into JAX: -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) -- [Common gotchas and sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) -- See the [full list of -notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). - ## Transformations At its core, JAX is an extensible system for transforming numerical functions. -Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and -`pmap`. +Here are three: `jax.grad`, `jax.jit`, and `jax.vmap`. ### Automatic differentiation with `grad` -JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). -The most popular function is -[`grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) -for reverse-mode gradients: +Use [`jax.grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) +to efficiently compute reverse-mode gradients: ```python -from jax import grad +import jax import jax.numpy as jnp -def tanh(x): # Define a function +def tanh(x): y = jnp.exp(-2.0 * x) return (1.0 - y) / (1.0 + y) -grad_tanh = grad(tanh) # Obtain its gradient function -print(grad_tanh(1.0)) # Evaluate it at x = 1.0 +grad_tanh = jax.grad(tanh) +print(grad_tanh(1.0)) # prints 0.4199743 ``` -You can differentiate to any order with `grad`. +You can differentiate to any order with `grad`: ```python -print(grad(grad(grad(tanh)))(1.0)) +print(jax.grad(jax.grad(jax.grad(tanh)))(1.0)) # prints 0.62162673 ``` -For more advanced autodiff, you can use -[`jax.vjp`](https://docs.jax.dev/en/latest/jax.html#jax.vjp) for -reverse-mode vector-Jacobian products and -[`jax.jvp`](https://docs.jax.dev/en/latest/jax.html#jax.jvp) for -forward-mode Jacobian-vector products. The two can be composed arbitrarily with -one another, and with other JAX transformations. Here's one way to compose those -to make a function that efficiently computes [full Hessian -matrices](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html#jax.hessian): - -```python -from jax import jit, jacfwd, jacrev - -def hessian(fun): - return jit(jacfwd(jacrev(fun))) -``` - -As with [Autograd](https://github.com/hips/autograd), you're free to use -differentiation with Python control structures: +You're free to use differentiation with Python control flow: ```python def abs_val(x): @@ -154,229 +105,103 @@ def abs_val(x): else: return -x -abs_val_grad = grad(abs_val) +abs_val_grad = jax.grad(abs_val) print(abs_val_grad(1.0)) # prints 1.0 print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) ``` -See the [reference docs on automatic -differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) -and the [JAX Autodiff +See the [JAX Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) +and the [reference docs on automatic +differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) for more. 
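+
+As a small sketch of composing transformations (reusing the `tanh` defined
+above), forward-over-reverse differentiation yields second derivatives:
+
+```python
+d2_tanh = jax.jacfwd(jax.grad(tanh))
+print(d2_tanh(1.0))  # second derivative of tanh at x = 1.0
+```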
### Compilation with `jit` -You can use XLA to compile your functions end-to-end with +Use XLA to compile your functions end-to-end with [`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), used either as an `@jit` decorator or as a higher-order function. ```python +import jax import jax.numpy as jnp -from jax import jit def slow_f(x): # Element-wise ops see a large benefit from fusion return x * x + x * 2.0 x = jnp.ones((5000, 5000)) -fast_f = jit(slow_f) -%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X -%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX) +fast_f = jax.jit(slow_f) +%timeit -n10 -r3 fast_f(x) +%timeit -n10 -r3 slow_f(x) ``` -You can mix `jit` and `grad` and any other JAX transformation however you like. - -Using `jit` puts constraints on the kind of Python control flow +Using `jax.jit` constrains the kind of Python control flow the function can use; see the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` -[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) is -the vectorizing map. -It has the familiar semantics of mapping a function along array axes, but -instead of keeping the loop on the outside, it pushes the loop down into a -function’s primitive operations for better performance. +[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) maps +a function along array axes. +But instead of just looping over function applications, it pushes the loop down +onto the function’s primitive operations, e.g. turning matrix-vector multiplies into +matrix-matrix multiplies for better performance. Using `vmap` can save you from having to carry around batch dimensions in your -code. For example, consider this simple *unbatched* neural network prediction -function: - -```python -def predict(params, input_vec): - assert input_vec.ndim == 1 - activations = input_vec - for W, b in params: - outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side! - activations = jnp.tanh(outputs) # inputs to the next layer - return outputs # no activation on last layer -``` - -We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the -left side of `activations`, but we’ve written this particular prediction function to -apply only to single input vectors. If we wanted to apply this function to a -batch of inputs at once, semantically we could just write - -```python -from functools import partial -predictions = jnp.stack(list(map(partial(predict, params), input_batch))) -``` - -But pushing one example through the network at a time would be slow! It’s better -to vectorize the computation, so that at every layer we’re doing matrix-matrix -multiplication rather than matrix-vector multiplication. - -The `vmap` function does that transformation for us. That is, if we write - -```python -from jax import vmap -predictions = vmap(partial(predict, params))(input_batch) -# or, alternatively -predictions = vmap(predict, in_axes=(None, 0))(params, input_batch) -``` - -then the `vmap` function will push the outer loop inside the function, and our -machine will end up executing matrix-matrix multiplications exactly as if we’d -done the batching by hand. - -It’s easy enough to manually batch a simple neural network without `vmap`, but -in other cases manual vectorization can be impractical or impossible. 
Take the -problem of efficiently computing per-example gradients: that is, for a fixed set -of parameters, we want to compute the gradient of our loss function evaluated -separately at each example in a batch. With `vmap`, it’s easy: - -```python -per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets) -``` - -Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other -JAX transformation! We use `vmap` with both forward- and reverse-mode automatic -differentiation for fast Jacobian and Hessian matrix calculations in -`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`. - -### SPMD programming with `pmap` - -For parallel programming of multiple accelerators, like multiple GPUs, use -[`pmap`](https://docs.jax.dev/en/latest/jax.html#parallelization-pmap). -With `pmap` you write single-program multiple-data (SPMD) programs, including -fast parallel collective communication operations. Applying `pmap` will mean -that the function you write is compiled by XLA (similarly to `jit`), then -replicated and executed in parallel across devices. - -Here's an example on an 8-GPU machine: +code: ```python -from jax import random, pmap +import jax import jax.numpy as jnp -# Create 8 random 5000 x 6000 matrices, one per GPU -keys = random.split(random.key(0), 8) -mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) +def l1_distance(x, y): + assert x.ndim == y.ndim == 1 # only works on 1D inputs + return jnp.sum(jnp.abs(x - y)) -# Run a local matmul on each device in parallel (no data transfer) -result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000) +def pairwise_distances(dist1D, xs): + return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs) -# Compute the mean on each device in parallel and print the result -print(pmap(jnp.mean)(result)) -# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157] +xs = jax.random.normal(jax.random.key(0), (100, 3)) +dists = pairwise_distances(l1_distance, xs) +dists.shape # (100, 100) ``` -In addition to expressing pure maps, you can use fast [collective communication -operations](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) -between devices: +By composing `jax.vmap` with `jax.grad` and `jax.jit`, we can get efficient +Jacobian matrices, or per-example gradients: ```python -from functools import partial -from jax import lax - -@partial(pmap, axis_name='i') -def normalize(x): - return x / lax.psum(x, 'i') - -print(normalize(jnp.arange(4.))) -# prints [0. 0.16666667 0.33333334 0.5 ] +per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))) ``` -You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more -sophisticated communication patterns. - -It all composes, so you're free to differentiate through parallel computations: +## Scaling -```python -from jax import grad - -@pmap -def f(x): - y = jnp.sin(x) - @pmap - def g(z): - return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum() - return grad(lambda w: jnp.sum(g(w)))(x) - -print(f(x)) -# [[ 0. 
, -0.7170853 ], -# [-3.1085174 , -0.4824318 ], -# [10.366636 , 13.135289 ], -# [ 0.22163185, -0.52112055]] - -print(grad(lambda x: jnp.sum(f(x)))(x)) -# [[ -3.2369726, -1.6356447], -# [ 4.7572474, 11.606951 ], -# [-98.524414 , 42.76499 ], -# [ -1.6007166, -1.2568436]] -``` +To scale your computations across thousands of devices, you can use any +composition of these: +* [**Compiler-based automatic parallelization**](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +where you program as if using a single global machine, and the compiler chooses +how to shard data and partition computation (with some user-provided constraints); +* [**Explicit sharding and automatic partitioning**](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) +where you still have a global view but data shardings are +explicit in JAX types, inspectable using `jax.typeof`; +* [**Manual per-device programming**](https://docs.jax.dev/en/latest/notebooks/shard_map.html) +where you have a per-device view of data +and computation, and can communicate with explicit collectives. -When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the -backward pass of the computation is parallelized just like the forward pass. +See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and +[advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) for more. -See the [SPMD -Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) -and the [SPMD MNIST classifier from scratch -example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) -for more. +| Mode | View? | Explicit sharding? | Explicit Collectives? | +|---|---|---|---| +| Auto | Global | No | No | +| Explicit | Global | Yes | No | +| Manual | Per-device | Yes | Yes | -## Current gotchas +## Gotchas and sharp bits -For a more thorough survey of current gotchas, with examples and explanations, -we highly recommend reading the [Gotchas +See the [Gotchas Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). -Some standouts: - -1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. -1. [In-place mutating updates of - arrays](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://docs.jax.dev/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. -1. [Random numbers are - different](https://docs.jax.dev/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). -1. If you're looking for [convolution - operators](https://docs.jax.dev/en/latest/notebooks/convolutions.html), - they're in the `jax.lax` package. -1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and - [to enable - double-precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) - (64-bit, e.g. 
`float64`) one needs to set the `jax_enable_x64` variable at - startup (or set the environment variable `JAX_ENABLE_X64=True`). - On TPU, JAX uses 32-bit values by default for everything _except_ internal - temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. - Those ops have a `precision` parameter which can be used to approximate 32-bit operations - via three bfloat16 passes, with a cost of possibly slower runtime. - Non-matmul operations on TPU lower to implementations that often emphasize speed over - accuracy, so in practice computations on TPU will be less precise than similar - computations on other backends. -1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars - and NumPy types aren't preserved, namely `np.add(1, np.array([2], - np.float32)).dtype` is `float64` rather than `float32`. -1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://docs.jax.dev/en/latest/control-flow.html). - You'll always get loud errors if something goes wrong. You might have to use - [`jit`'s `static_argnums` - parameter](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), - [structured control flow - primitives](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators) - like - [`lax.scan`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), - or just use `jit` on smaller subfunctions. ## Installation @@ -408,23 +233,6 @@ for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions. - - -## Neural network libraries - -Multiple Google research groups at Google DeepMind and Alphabet develop and share libraries -for training neural networks in JAX. If you want a fully featured library for neural network -training with examples and how-to guides, try -[Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). - -Check out the [JAX Ecosystem section](https://docs.jax.dev/en/latest/#ecosystem) -on the JAX documentation site for a list of JAX-based network libraries, which includes -[Optax](https://github.com/deepmind/optax) for gradient processing and -optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and -[Equinox](https://github.com/patrick-kidger/equinox) for neural networks. -(Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk -[here](https://www.youtube.com/watch?v=iDxJxIyzSiM) for additional details.) - ## Citing JAX To cite this repository: From 6f2f0e181f3b8858553e2ddc3fc1259e44c6bcb7 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 24 Apr 2025 13:05:37 -0700 Subject: [PATCH 0796/1769] use :check: and :x: emojis in readme table --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index cb217b0cf92e..c8bf5ce52f68 100644 --- a/README.md +++ b/README.md @@ -194,9 +194,9 @@ See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and | Mode | View? | Explicit sharding? | Explicit Collectives? 
| |---|---|---|---| -| Auto | Global | No | No | -| Explicit | Global | Yes | No | -| Manual | Per-device | Yes | Yes | +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | ## Gotchas and sharp bits From bc39f86ac2944403a1d3a24a8ebec1d3b1d6bd6b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 13:17:21 -0700 Subject: [PATCH 0797/1769] Update the sharding table to be the same everywhere PiperOrigin-RevId: 751110300 --- docs/notebooks/explicit-sharding.ipynb | 10 +++++----- docs/notebooks/explicit-sharding.md | 10 +++++----- docs/sharded-computation.ipynb | 10 +++++----- docs/sharded-computation.md | 10 +++++----- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index dada1f0db507..f37d0dcdf887 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -527,11 +527,11 @@ "\n", "A summary table:\n", "\n", - "| Mode | Explicit sharding? | Explicit Collectives? |\n", - "|---|---|---|\n", - "| Auto | No | No |\n", - "| Explicit (new) | Yes | No |\n", - "| Manual | Yes | Yes |\n", + "| Mode | View? | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|---|\n", + "| Auto | Global | ❌ | ❌ |\n", + "| Explicit | Global | ✅ | ❌ |\n", + "| Manual | Per-device | ✅ | ✅ |\n", "\n", "The current mesh tells us which sharding mode we're in. We can query it with\n", "`get_abstract_mesh`:" diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index a091060393b6..b374b7d7a668 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -337,11 +337,11 @@ JAX now has three styles of parallelism: A summary table: -| Mode | Explicit sharding? | Explicit Collectives? | -|---|---|---| -| Auto | No | No | -| Explicit (new) | Yes | No | -| Manual | Yes | Yes | +| Mode | View? | Explicit sharding? | Explicit Collectives? | +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | The current mesh tells us which sharding mode we're in. We can query it with `get_abstract_mesh`: diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index f9f33febb094..c9e947d374f0 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -26,11 +26,11 @@ "\n", "A summary table:\n", "\n", - "| Mode | Explicit sharding? | Explicit Collectives? |\n", - "|---|---|---|\n", - "| Auto | No | No |\n", - "| Explicit (new) | Yes | No |\n", - "| Manual | Yes | Yes |\n", + "| Mode | View? | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|---|\n", + "| Auto | Global | ❌ | ❌ |\n", + "| Explicit | Global | ✅ | ❌ |\n", + "| Manual | Per-device | ✅ | ✅ |\n", "\n", "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices." ] diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 60e789a109b2..8af91ef5c306 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -33,11 +33,11 @@ The tutorial covers three modes of parallel computation: A summary table: -| Mode | Explicit sharding? | Explicit Collectives? | -|---|---|---| -| Auto | No | No | -| Explicit (new) | Yes | No | -| Manual | Yes | Yes | +| Mode | View? | Explicit sharding? | Explicit Collectives? 
| +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices. From 079c7d6d5fb5b7718548cfbbc6856aed09f2cd70 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 24 Apr 2025 13:24:02 -0700 Subject: [PATCH 0798/1769] fix: in jax.asarray check default_device instead of default_backend PiperOrigin-RevId: 751112644 --- jax/_src/numpy/lax_numpy.py | 26 +++++++++++++++++--- tests/lax_numpy_test.py | 47 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6e9cb3de3985..a28a6d94e3eb 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5559,6 +5559,26 @@ def canonicalize_device_to_sharding(device: xc.Device | Sharding | None return device +def _get_platform( + device_or_sharding: xc.Device | Sharding | None | str) -> str: + """Get device_or_sharding platform or look up config.default_device.value.""" + if isinstance(device_or_sharding, xc.Device): + return device_or_sharding.platform + elif isinstance(device_or_sharding, Sharding): + return list(device_or_sharding.device_set)[0].platform + elif isinstance(device_or_sharding, str): + return device_or_sharding + elif device_or_sharding is None: + if config.default_device.value is None: + return jax.default_backend() + else: + return _get_platform(config.default_device.value) + else: + raise ValueError(f"`{device_or_sharding = }` was passed to" + "`canonicalize_or_get_default_platform`, only xc.Device," + " Sharding, None or str values are supported.") + + def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: try: dtypes.dtype(x) @@ -5703,11 +5723,11 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, # the buffer protocol but a copy is required. Since array() supports the buffer protocol # via numpy, this is only the case when the default device is not 'cpu' if (copy is False and not isinstance(a, Array) - and jax.default_backend() != 'cpu' + and _get_platform(device) != "cpu" and _supports_buffer_protocol(a)): raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array " - f"on backend={jax.default_backend()!r} with copy=False. " - "Consider using copy=None or copy=True instead.") + f"on platform={_get_platform(device)} with " + "copy=False. Consider using copy=None or copy=True instead.") dtypes.check_user_dtype_supported(dtype, "asarray") if dtype is not None: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 4ef775d17e53..2e4a84c6c99e 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3685,6 +3685,53 @@ def testAsarrayCopy(self, copy): self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False) self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False) + @jtu.sample_product(numpy_array=[True, False]) + def testAsarrayWithCopyFalse(self, numpy_array): + x_jax = jnp.arange(4) + if numpy_array: + x = np.arange(4) + else: + x = make_python_array('l', [0, 1, 2, 3]) + device_error_msg = ('jnp.asarray: cannot convert object of type .* to JAX' + ' Array on platform={} with copy=False. 
Consider using' + ' copy=None or copy=True instead.') + + if jax.default_backend() != 'cpu': + # test accelerator devices - no support for copy=False + expected_platform = jax.local_devices()[0].platform + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=jax.local_devices()[0]) + sharding = SingleDeviceSharding(jax.local_devices()[0]) + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=sharding) + + # test None defaults to default backend - no support for copy=False + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=None) + else: + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + + # test explicit CPU device or default CPU device context managers overwrite the default backend + x = make_python_array('l', [0, 1, 2, 3]) + for device in [jax.local_devices(backend='cpu')[0], + SingleDeviceSharding(jax.local_devices(backend='cpu')[0])]: + self.assertArraysEqual(jnp.asarray(x, copy=False, device=device), + x_jax, check_dtypes=False) + with jax.default_device('cpu'): + self.assertArraysEqual(jnp.asarray(x, copy=False), x_jax, + check_dtypes=False) + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + with jax.default_device(jax.local_devices(backend='cpu')[0]): + self.assertArraysEqual(jnp.asarray(x, copy=False), x_jax, + check_dtypes=False) + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") def testArrayDtypeInference(self): def _check(obj, out_dtype, weak_type): From 24530dc9357541c43c4f002bfc130f6483217a6f Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 24 Apr 2025 16:24:16 -0400 Subject: [PATCH 0799/1769] Edit print_environment_info to print environment variables that start with JAX_. 
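
For reference, a minimal sketch of the resulting behavior (the environment
variable and the final printed line below are illustrative, not captured
output):

```python
import os
os.environ["JAX_ENABLE_X64"] = "True"  # any variable named JAX_* is picked up

import jax
jax.print_environment_info()
# ...jax/jaxlib versions, device info, process count, platform as before, then:
# JAX_ENABLE_X64=True
```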
--- jax/_src/environment_info.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py index 4abfdeaa0f14..e11d1b4f4b31 100644 --- a/jax/_src/environment_info.py +++ b/jax/_src/environment_info.py @@ -14,6 +14,7 @@ from __future__ import annotations +import os import platform import subprocess import sys @@ -48,8 +49,10 @@ def print_environment_info(return_string: bool = False) -> str | None: python: {python_version} device info: {xb.devices()[0].device_kind}-{xb.device_count()}, {xb.local_device_count()} local devices" process_count: {xb.process_count()} - platform: {platform.uname()} -""") + platform: {platform.uname()}""") + for key, value in os.environ.items(): + if key.startswith("JAX_"): + info += f"\n{key}={value}" nvidia_smi = try_nvidia_smi() if nvidia_smi: info += '\n\n$ nvidia-smi\n' + nvidia_smi From d585fafde31aecd28b73675d1c51f66e9939d75a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 13:27:05 -0700 Subject: [PATCH 0800/1769] Set __module__ on NamedSharding, SingleDeviceSharding and PmapSharding so we don't get the internal name for it in errors PiperOrigin-RevId: 751113771 --- jax/_src/named_sharding.py | 1 + jax/_src/sharding_impls.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index fba54f438471..c11e3687ba15 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -250,6 +250,7 @@ def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: return SdyArraySharding(self.mesh.shape_tuple, dim_shardings, self._logical_device_ids) +NamedSharding.__module__ = 'jax.sharding' def get_array_mapping( axis_resources: PartitionSpec | AUTO | UnspecifiedValue diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index faa0d31ee8ca..d5d7725f5381 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -200,6 +200,7 @@ def is_fully_addressable(self) -> bool: return xb.process_index(self._device.client) == self._device.process_index return True +SingleDeviceSharding.__module__ = 'jax.sharding' @util.cache(max_size=4096, trace_context_in_key=False) def pmap_sharding_devices_indices_map( @@ -368,6 +369,7 @@ def shard_shape(self, global_shape: Shape) -> Shape: f'the number of devices={len(self._device_assignment)}') return sharded_shape +PmapSharding.__module__ = 'jax.sharding' def _op_sharding_to_pos_sharding( op_sharding: xc.OpSharding | xc.HloSharding, From 277faad7c0d34897770fa8c07ae1511efe12258a Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Apr 2025 13:38:03 -0700 Subject: [PATCH 0801/1769] Add mosaic tests to optional B200 GPU presubmit. 
PiperOrigin-RevId: 751118618
---
 .github/workflows/bazel_optional_b200.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/bazel_optional_b200.yml b/.github/workflows/bazel_optional_b200.yml
index 6335fbacaf2c..e72cdbda9205 100644
--- a/.github/workflows/bazel_optional_b200.yml
+++ b/.github/workflows/bazel_optional_b200.yml
@@ -59,4 +59,5 @@ jobs:
           --action_env=NCCL_DEBUG=WARN \
           --color=yes \
           //tests:gpu_tests //tests:backend_independent_tests \
-          //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
\ No newline at end of file
+          //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
+          //tests/mosaic:gpu_tests //tests/mosaic:backend_independent_tests
\ No newline at end of file

From cb704db594d406b0aa43416f48e3158f7a3bd41d Mon Sep 17 00:00:00 2001
From: vfdev
Date: Thu, 24 Apr 2025 22:39:00 +0200
Subject: [PATCH 0802/1769] Updated TSAN suppressions files to get traceback
 of crashed process

---
 .github/workflows/tsan-suppressions_3.13.txt | 3 ---
 .github/workflows/tsan-suppressions_3.14.txt | 3 ---
 2 files changed, 6 deletions(-)

diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt
index f4fbf830ddc2..fceb8f5f9a61 100644
--- a/.github/workflows/tsan-suppressions_3.13.txt
+++ b/.github/workflows/tsan-suppressions_3.13.txt
@@ -8,9 +8,6 @@ race:dnnl_sgemm
 # https://github.com/python/cpython/issues/128050
 race:partial_vectorcall_fallback

-# Likely only happens when the process is crashing.
-race:dump_traceback
-
 # https://github.com/python/cpython/issues/128137
 # Fixed in Python 3.14, but not backported to 3.13.
 race:immortalize_interned
diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt
index 6e1d34e6db65..3d0f5518862a 100644
--- a/.github/workflows/tsan-suppressions_3.14.txt
+++ b/.github/workflows/tsan-suppressions_3.14.txt
@@ -8,9 +8,6 @@ race:dnnl_sgemm
 # https://github.com/python/cpython/issues/128050
 race:partial_vectorcall_fallback

-# Likely only happens when the process is crashing.
-race:dump_traceback
-
 # https://github.com/python/cpython/issues/129748
 race:mi_block_set_nextx

From 10901d59deffa6db5f69f7d147ae6e1c6b2aa8ad Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Thu, 24 Apr 2025 21:09:35 +0000
Subject: [PATCH 0803/1769] [readme] include teaser example of shardings

Co-authored-by: Yash Katariya
---
 README.md | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/README.md b/README.md
index c8bf5ce52f68..9fe5bb7e9cb0 100644
--- a/README.md
+++ b/README.md
@@ -189,6 +189,24 @@ explicit in JAX types, inspectable using `jax.typeof`; where you have a
 per-device view of data and computation, and can communicate with explicit
 collectives.

+```python
+from jax.sharding import set_mesh, AxisType, PartitionSpec as P
+mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
+set_mesh(mesh)
+
+# parameters are sharded for FSDP:
+for W, b in params:
+  print(f'{jax.typeof(W)}')  # f32[512@data,512]
+  print(f'{jax.typeof(b)}')  # f32[512]
+
+# shard data for batch parallelism:
+inputs, targets = jax.device_put((inputs, targets), P('data'))
+
+# evaluate gradients, automatically parallelized!
+gradfun = jax.jit(jax.grad(loss))
+param_grads = gradfun(params, (inputs, targets))
+```
+
 See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and
 [advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) for more.
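
The README teaser above assumes `params`, `loss`, `inputs`, and `targets` are
defined elsewhere. A minimal sketch of one way to define them under the same
explicit mesh; `init_params`, the layer sizes, and the initializer are
illustrative and not part of the patch:

```python
import jax
import jax.numpy as jnp
from jax.sharding import set_mesh, AxisType, PartitionSpec as P

mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

def init_params(key, sizes=(512, 512, 512)):
  keys = jax.random.split(key, len(sizes) - 1)
  # Shard each weight matrix over the 'data' axis, FSDP-style.
  return [(jax.device_put(jax.random.normal(k, (m, n)), P('data', None)),
           jnp.zeros(n))
          for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def loss(params, batch):
  inputs, targets = batch
  x = inputs
  for W, b in params:
    x = jnp.tanh(x @ W + b)
  return jnp.mean((x - targets) ** 2)

params = init_params(jax.random.key(0))  # requires 8 devices for this mesh
```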
From 1bc9054b9eecbf576757899fba89d08c986bd3b3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 14:29:44 -0700 Subject: [PATCH 0804/1769] Move sharding table above the code example PiperOrigin-RevId: 751139503 --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 9fe5bb7e9cb0..a5deecef6c37 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,12 @@ explicit in JAX types, inspectable using `jax.typeof`; where you have a per-device view of data and computation, and can communicate with explicit collectives. +| Mode | View? | Explicit sharding? | Explicit Collectives? | +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | + ```python from jax.sharding import set_mesh, AxisType, PartitionSpec as P mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,)) @@ -210,12 +216,6 @@ param_grads = gradfun(params, (inputs, targets)) See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and [advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) for more. -| Mode | View? | Explicit sharding? | Explicit Collectives? | -|---|---|---|---| -| Auto | Global | ❌ | ❌ | -| Explicit | Global | ✅ | ❌ | -| Manual | Per-device | ✅ | ✅ | - ## Gotchas and sharp bits See the [Gotchas From ff3a520693bbf087eb0862048523e2eec811c080 Mon Sep 17 00:00:00 2001 From: Richard Levasseur Date: Thu, 24 Apr 2025 21:44:41 +0000 Subject: [PATCH 0805/1769] fix: make build.py use /usr/bin/env python3 as shebang --- build/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/build.py b/build/build.py index a5d39e559a9d..287400ee8f42 100755 --- a/build/build.py +++ b/build/build.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/env python3 # # Copyright 2018 The JAX Authors. # From 0f52d05387c548684b3c9c9a8a5037a2ce72c854 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Apr 2025 15:45:59 -0700 Subject: [PATCH 0806/1769] Fix typo in Bazel command for Mosaic tests. 
PiperOrigin-RevId: 751166905 --- .github/workflows/bazel_optional_b200.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_optional_b200.yml b/.github/workflows/bazel_optional_b200.yml index e72cdbda9205..1620022965f9 100644 --- a/.github/workflows/bazel_optional_b200.yml +++ b/.github/workflows/bazel_optional_b200.yml @@ -59,5 +59,5 @@ jobs: --action_env=NCCL_DEBUG=WARN \ --color=yes \ //tests:gpu_tests //tests:backend_independent_tests \ - //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ //tests/mosaic:gpu_tests //tests/mosaic:backend_independent_tests \ No newline at end of file From 3e65b335d85701d80ee37009a8a53a293fe1d681 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Apr 2025 16:03:34 -0700 Subject: [PATCH 0807/1769] [Pallas/Fuser] Bugfix for broadcasting, lax.slice_p, and lax.dynamic_slice_p with Element block shapes PiperOrigin-RevId: 751173626 --- jax/_src/pallas/fuser/block_spec.py | 43 +++++++++++++++++++++-------- jax/_src/pallas/mosaic/lowering.py | 11 ++++++-- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 19b813a03f10..d8d01005622a 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -94,6 +94,12 @@ def wrapped(*args): return wrapped +def _block_size(dim: pallas_core.Element | int | None) -> int | None: + if isinstance(dim, pallas_core.Element): + return dim.block_size + return dim + + @dataclasses.dataclass class UsageRuleContext: avals_in: tuple[core.AbstractValue, ...] @@ -420,11 +426,6 @@ def make_kernel_function( invar_usages = util.safe_map(read_usage_env, jaxpr.invars) bs_env, scalar_prefetch_fn_env = block_spec_env - def _block_size(dim: pallas_core.Element | int | None) -> int | None: - if isinstance(dim, pallas_core.Element): - return dim.block_size - return dim - def _remove_nones( shape: tuple[pallas_core.Element | int | None, ...] | None ) -> tuple[int, ...]: @@ -727,7 +728,14 @@ def new_index_map(i, *args): idx = util.tuple_update(idx, i, 0) return idx - new_block_shape = util.tuple_update(block_spec.block_shape, i, 1) + # TODO(wdvi): This is a hack needed since lowering rules require block shape + # to contain either all pl.Element or none + bcast_dim_block_shape = 1 + if isinstance(block_spec.block_shape[i], pallas_core.Element): + bcast_dim_block_shape = pallas_core.Element(1) + new_block_shape = util.tuple_update( + block_spec.block_shape, i, bcast_dim_block_shape + ) return pallas_core.BlockSpec( new_block_shape, functools.partial(new_index_map, i) ) @@ -876,10 +884,13 @@ def _slice_rule( ): if bs is None: continue - assert slice_start % bs == 0, (start_indices, block_spec.block_shape) - assert slice_size % bs == 0, (slice_sizes, block_spec.block_shape) + block_size = _block_size(bs) + assert ( + slice_start % block_size == 0 + ), (start_indices, block_spec.block_shape) + assert slice_size % block_size == 0, (slice_sizes, block_spec.block_shape) offsets = tuple( - slice_start // bs if bs is not None else slice_start + slice_start // _block_size(bs) if bs is not None else slice_start for slice_start, bs in zip(start_indices, block_spec.block_shape) ) @@ -957,7 +968,7 @@ def new_index_map(*args): # We then add these block indices to block indices produced by the index # map. 
block_indices = tuple( - _offset(i, o, s) + _offset(i, o, _block_size(s)) for i, o, s in zip( idx, slice_starts, block_spec.block_shape, strict=True ) @@ -976,6 +987,11 @@ def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension): # divides the block size. block_spec = ctx.out_block_specs[0] block_shape = block_spec.block_shape + is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] + if any(is_element_block): + raise NotImplementedError( + "Concatenation with Element indexing is not yet supported." + ) block_dim = block_shape[dimension] if block_dim is None: block_dim = 1 @@ -1019,6 +1035,11 @@ def _concatenate_rule( dimension: int, ): block_shape = block_spec.block_shape + is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] + if any(is_element_block): + raise NotImplementedError( + "Concatenation with Element indexing is not yet supported." + ) num_blocks = [] block_dim = block_shape[dimension] if block_dim is None: @@ -1093,7 +1114,7 @@ def _broadcast_in_dim_eval_rule( if not eval_ctx.avals_in[0].shape: # pytype: disable=attribute-error # Scalar -> Array broadcast block_spec = eval_ctx.out_block_specs[0] - shape = tuple(s for s in block_spec.block_shape if s is not None) + shape = tuple(_block_size(s) for s in block_spec.block_shape if s is not None) return jax.lax.broadcast_in_dim(x, broadcast_dimensions=(), shape=shape) return x diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 208ae389d2bf..ea7681c31f11 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -764,12 +764,19 @@ def dynamic_shape_replacement_fn( is_element_block = [isinstance(bd, pallas_core.Element) for bd in bm.block_shape] if any(is_element_block): - if not all(is_element_block): + is_element_or_squeezed_block = [ + isinstance(bd, (pallas_core.Element, pallas_core.Squeezed)) + for bd in bm.block_shape + ] + if not all(is_element_or_squeezed_block): raise NotImplementedError( "All block dimensions must be Elements or none of them can be" " Elements." ) - padding = [bd.padding for bd in bm.block_shape] # pytype: disable=attribute-error + padding = [ + bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape + ] pad_low, pad_high = map(list, zip(*padding)) block_params["window_kind"] = ir.Attribute.parse( f"#tpu.element_window<{pad_low},{pad_high}>" From 672ee47b0f9135bc93ed13310a6453a3afb3da13 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 23 Apr 2025 22:14:55 +0000 Subject: [PATCH 0808/1769] [scan] when a carry is read-only, move it to be a const This is like the scan version of #27970, which applied to while_loop. 
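
As a minimal illustration of the pattern this optimizes (a sketch, not from
the original commit message): a carry component that the body returns
unchanged is read-only, so it can be hoisted out of the carry and treated as
a constant of the scan:

```python
import jax
import jax.numpy as jnp

def body(carry, _):
  scale, acc = carry
  return (scale, acc + scale), None  # `scale` is forwarded unchanged

init = (jnp.float32(2.0), jnp.float32(0.0))
(scale, total), _ = jax.lax.scan(body, init, None, length=4)  # total == 8.0
```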
Fixes google/flax#4709 --- jax/_src/interpreters/ad.py | 10 +++--- jax/_src/interpreters/partial_eval.py | 36 ++++++++++++-------- jax/_src/lax/control_flow/loops.py | 27 +++++++++++++-- tests/api_test.py | 4 +-- tests/core_test.py | 6 ++-- tests/lax_control_flow_test.py | 49 ++++++++++++++++++++++++++- 6 files changed, 105 insertions(+), 27 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 98cda2df4964..35fd4fbe3b6c 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -1187,10 +1187,12 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_ arg_names=new_arg_names, result_paths=new_result_paths, ) - new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, - new_invars, new_outvars, jaxpr.jaxpr.eqns, - jaxpr.jaxpr.effects, - new_debug_info) + constvars = jaxpr.jaxpr.constvars + new_effects = pe._renumber_effects( + (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns, + new_effects, new_debug_info) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int], diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0e888c1591aa..fab4fc20de47 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1558,18 +1558,22 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] ) -> ClosedJaxpr: assert len(closed_jaxpr.in_avals) == len(to_move) - new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) - id_map = {id(v): i for i, v in enumerate(new_invars)} - idx_map = {i: id_map[id(v)] for i, v in enumerate(closed_jaxpr.jaxpr.invars)} - new_effs = {e.replace(input_index=idx_map[e.input_index]) - if isinstance(e, effects.JaxprInputEffect) else e - for e in closed_jaxpr.jaxpr.effects} - new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, - closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - new_effs, closed_jaxpr.jaxpr.debug_info) + constvars, invars = closed_jaxpr.jaxpr.constvars, closed_jaxpr.jaxpr.invars + new_invars = _move_to_front(invars, to_move) + new_effs = _renumber_effects( + (*constvars, *new_invars), (*constvars, *invars), closed_jaxpr.jaxpr.effects) + new_jaxpr = Jaxpr(constvars, new_invars, closed_jaxpr.jaxpr.outvars, + closed_jaxpr.jaxpr.eqns, new_effs, + closed_jaxpr.jaxpr.debug_info) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr +def _renumber_effects(new_vars, old_vars, effs): + newvar_idxs = {id(v): i for i, v in enumerate(new_vars)} + old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)} + return {e.replace(input_index=old_to_new[e.input_index]) + if isinstance(e, effects.JaxprInputEffect) else e for e in effs} + def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: return ([elt for elt, move in zip(lst, to_move) if move] + [elt for elt, move in zip(lst, to_move) if not move]) @@ -1589,7 +1593,6 @@ def _move_outvars_to_back(jaxpr, to_move): return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars)) - class DynamicJaxprTracer(core.Tracer): __slots__ = ['aval', '_debug_info'] @@ -1670,16 +1673,19 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"\n Equation: {eqn}\n" "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, 
outvars, eqns, set())}") - invar = eqn.invars[eff.input_index] - if invar in mut_arrays: + eqn_invar = eqn.invars[eff.input_index] + if eqn_invar in mut_arrays: continue - if (input_index := all_vars.get(invar, sentinel)) is sentinel: + if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel: + # TODO(mattjj): ask for forgiveness + dbg = type('Fake', (), {'resolve_result_paths': lambda _: None})() raise ValueError( f"`JaxprInputEffect` {eff} does not have " - f"corresponding input: {invar}." + f"corresponding jaxpr input: {eqn_invar=}." f"\n Equation: {eqn}\n" + f"\n Effects: {eqn.effects}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 39498ad624bc..22ba0c5aa197 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -312,6 +312,9 @@ def _create_jaxpr(init): init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked = rest num_carry = len(init_flat) + num_xs = len(x_avals) + num_ys = len(jaxpr.out_avals) - num_carry + del init_flat _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) @@ -327,22 +330,42 @@ def _create_jaxpr(init): unroll = max(length, 1) if unroll else 1 if unroll < 1: raise ValueError("`unroll` must be a `bool` or a positive `int`.") + if attrs_tracked: in_state = _get_states(attrs_tracked) in_flat = [*in_state, *in_flat] num_carry += len(in_state) + + # If the body forwards an input carry to an output carry, that input is + # read-only and can be moved to be a const. Doing so can lead to efficiency + # wins, e.g. if the scan is inside a cond with a batched predicate. 
+ carry_fwd, _ = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) + move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)] + if any(move_to_const): + jaxpr = pe.prune_closed_jaxpr_outputs( + jaxpr, [not m for m in move_to_const] + [True] * num_ys) + jaxpr = pe.move_binders_to_front( + jaxpr, [False] * len(consts) + move_to_const + [False] * num_xs) + in_flat, new_consts = partition_list(move_to_const + [False] * num_xs, in_flat) + consts = [*new_consts, *consts] + num_carry -= len(new_consts) + out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, linear=(False,) * (len(consts) + len(in_flat)), - unroll=unroll, - _split_transpose=_split_transpose) + unroll=unroll, _split_transpose=_split_transpose) + + if any(move_to_const): + out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) + if attrs_tracked: num_ext = (len(out) - len(in_state) - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked)) out_state, out, out_append = split_list(out, [len(in_state), num_ext]) out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) _set_states(attrs_tracked, out_attrs) + return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): diff --git a/tests/api_test.py b/tests/api_test.py index d80404f25a4e..67d5d93552f6 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6853,13 +6853,13 @@ def body(c, _): self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) # 1 b/c scan doesn't have fwding rule + expected_num_eqns=0) used_outputs[7] = expected_used_inputs[7] = True used_outputs[6] = expected_used_inputs[6] = True self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) + expected_num_eqns=0) # If we use the value at index 3 only, some of the hidden sequence must be # kept but the rest pruned. diff --git a/tests/core_test.py b/tests/core_test.py index 00b3eb1d61d5..004ef4b57435 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -357,15 +357,15 @@ def g_vmap(x): def test_dropvar_avals(self): def f(x): def body(c, _): - return c, None + x1, x2 = c + return (2 * x1, 2 * x2), None (x1, x2), _ = jax.lax.scan(body, (x, x), None, length=1) return [x2] aval = core.ShapedArray((), jnp.dtype('int32')) pval = pe.PartialVal.unknown(aval) jaxpr, _, _ = pe.trace_to_jaxpr_nounits( - lu.wrap_init(f, - debug_info=debug_info("test", f, (0,), {})), + lu.wrap_init(f, debug_info=debug_info("test", f, (0,), {})), [pval], False) dropvar, b = jaxpr.eqns[0].outvars self.assertEqual(dropvar.aval, aval) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 8876fb7d06be..31d47ef9dd4b 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3122,7 +3122,7 @@ def body(c): return x + y jax.linearize(f, 1., 2.) 
# don't crash - def test_readonly_carry_optimization(self): + def test_while_readonly_carry_optimization(self): # https://github.com/google/flax/issues/4700 def foo(w, x, c_max): def while_cond(val): @@ -3204,6 +3204,53 @@ def body_fun(c): outs = jax.lax.while_loop(cond_fun, body_fun, (5., 0., 3.14)) self.assertAllClose(outs, (0., 1., 5.)) + def test_scan_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4709 + def f(x, y): + def g(_, y): + y, _ = jax.lax.scan(lambda y, _: (y, None), y, None, length=1) + return y + return jax.lax.cond(x < 0, g, g, x, y) + xs = jnp.arange(3.) + y = 3. + jax.vmap(f, (0, None), None)(xs, y) # don't crash + + @parameterized.parameters(itertools.product(range(3), repeat=4)) + @jtu.run_on_devices("cpu") + def test_scan_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds, + num_noninplace_fwds): + + num_fwds = num_inplace_fwds + num_noninplace_fwds + num_carry = num_fwds + 4 + num_xs = 2 + num_ys = 3 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def body_fun(c, _): + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds, num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return new_c, [0 for _ in range(num_ys)] + + xs = [jnp.arange(2.) for _ in range(num_xs)] + outs = jax.lax.scan(body_fun, init_vals, xs)[0] + outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] + self.assertAllClose(outs, outs_ref, check_dtypes=False) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 4dc876bb3dc90db717a70ca63f330324cd20b0e2 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Thu, 24 Apr 2025 16:38:27 -0700 Subject: [PATCH 0809/1769] Update `debug_info` tests for `use_direct_linearize`. 
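
The tests below branch on `config.use_direct_linearize.value`; a sketch of
toggling the flag (the string name is assumed from JAX's usual `jax_` prefix
convention, not taken from this patch):

```python
import jax
jax.config.update("jax_use_direct_linearize", True)
```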
PiperOrigin-RevId: 751184927 --- tests/debug_info_test.py | 84 +++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 611b2495949a..b5e875c03676 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -954,20 +954,28 @@ def my_g(u, v): return dict(c=u * v, d=v) return jax.jit(my_g)(y, x)["c"] + if config.use_direct_linearize.value: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", + # TODO(necula): result_paths + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", + # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=result['c']", + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", + # TODO(necula): result_paths + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", + # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", + ] self._check_tracers_and_jaxprs( jax.jit(lambda x, y, res_ct: jax.vjp(my_f, x, y)[1](res_ct)), 2., 3., 0.3, tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", - # TODO(necula): result_paths - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", - # TODO(necula): arg_names - "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", @@ -1379,22 +1387,37 @@ def the_grad(c, as_): _, pullback = jax.vjp(my_f, c, as_) return pullback((c, np.arange(3, dtype=c.dtype))) + if config.use_direct_linearize.value: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", + "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", + "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", + "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", + "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + ] 
self._check_tracers_and_jaxprs( jax.jit(the_grad), c, as_, tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", - "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", - "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=result[0],result[1]", - "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", "traced_for=scan, fun=f, arg_names=c,a, from c", @@ -1687,14 +1710,23 @@ def my_f(x): x = jax.random.uniform(jax.random.key(0), shape=(8, 4)) + if config.use_direct_linearize.value: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=x,, result_paths=result" + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=result" + ] + self._check_tracers_and_jaxprs( jax.jit(jax.hessian(jax.jit(my_f))), x, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", - "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=result", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, from x", From c8ae45bf39c054f70163014d90a62f38fc9a8562 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Apr 2025 17:21:50 -0700 Subject: [PATCH 0810/1769] Fix shard_map's direct linearize after vma has been turned on Co-authored-by: Matthew Johnson PiperOrigin-RevId: 751197878 --- jax/_src/shard_map.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 634e7f2b5701..9870991c66f1 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1288,7 +1288,9 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, @as_hashable_function(closure=linearize_outs_thunk) def fwd_out_names_thunk(): - res_avals, _, _, _, _, _ = linearize_outs_thunk() + res_avals, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) + if f1 is None and f2 is None] out_names = out_names_thunk() if check_vma: res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} @@ -1313,7 +1315,9 @@ def fwd_out_names_thunk(): config._check_vma(check_vma)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() - res_avals_iter = iter(res_avals) + res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) + if f1 is None and f2 is None] + res_avals_iter = iter(res_avals2) res_names = [] for f1, f2 in zip(in_fwd, out_fwd): if f1 is not None: From 
d109cd8283c4d415c196df7a1b8d999e3d5bc4be Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Apr 2025 18:04:23 -0700 Subject: [PATCH 0811/1769] [pallas] Use `*MemorySpace` aliases PiperOrigin-RevId: 751212769 --- .../paged_attention/paged_attention_kernel.py | 12 ++--- .../ops/tpu/ragged_paged_attention/kernel.py | 2 +- .../pallas/ops/tpu/random/philox.py | 4 +- .../pallas/ops/tpu/random/threefry.py | 2 +- tests/pallas/mosaic_gpu_test.py | 10 ++-- tests/pallas/ops_test.py | 6 +-- tests/pallas/pallas_error_handling_test.py | 8 ++-- tests/pallas/tpu_pallas_distributed_test.py | 28 +++++------ .../tpu_pallas_interpret_distributed_test.py | 34 ++++++------- tests/pallas/tpu_pallas_interpret_test.py | 6 +-- tests/pallas/tpu_pallas_pipeline_test.py | 48 +++++++++---------- tests/pallas/tpu_pallas_random_test.py | 14 +++--- tests/pallas/tpu_pallas_test.py | 32 ++++++------- 13 files changed, 103 insertions(+), 103 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 4c03fb01be2b..6280064f29d3 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -542,10 +542,10 @@ def paged_attention( if k_scales_pages is not None and v_scales_pages is not None: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ] scratch_shapes = ( pltpu.VMEM( @@ -589,9 +589,9 @@ def paged_attention( else: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), None, # type: ignore[list-item] - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), None, # type: ignore[list-item] ] scratch_shapes = ( diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index 3500ba3ee9fd..cd5de96ccca7 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -784,7 +784,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ) in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ] out_specs = q_block_spec lm_scratch = pltpu.VMEM( diff --git a/jax/experimental/pallas/ops/tpu/random/philox.py b/jax/experimental/pallas/ops/tpu/random/philox.py index 28e627cfb298..4c43f5c7c2ff 100644 --- a/jax/experimental/pallas/ops/tpu/random/philox.py +++ b/jax/experimental/pallas/ops/tpu/random/philox.py @@ -140,8 +140,8 @@ def kernel(offset_ref, key_ref, out_ref): return pl.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=out_spec, grid=grid_dims, diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py index 5c460d491f48..5fdac5782349 100644 --- 
a/jax/experimental/pallas/ops/tpu/random/threefry.py +++ b/jax/experimental/pallas/ops/tpu/random/threefry.py @@ -79,7 +79,7 @@ def kernel(key_ref, out_ref): block_shape = (1,) * (len(shape)-2) + block_size result = pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs), grid=grid_dims, out_shape=out, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 23e33e68e2d3..41d02ed781e2 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1032,7 +1032,7 @@ def test_load_scalar(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GMEM)], ) def kernel(x_ref, o_ref): o_ref[...] = jnp.broadcast_to(x_ref[10], (128,)) @@ -1062,8 +1062,8 @@ def test_run_scoped_in_cond(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - in_specs=[pl.BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], - out_specs=pl.BlockSpec(memory_space=plgpu.GPUMemorySpace.SMEM), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref_gmem, o_ref): def scoped_kernel(barrier_ref): @@ -1370,8 +1370,8 @@ def kernel(a_ref, b_ref): a = np.zeros((64, 64), dtype=jnp.float32) b = self.pallas_call( kernel, - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), + in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GMEM), input_output_aliases={0: 0}, out_shape=a, )(a) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 709828186480..8c8e59aedc11 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1517,10 +1517,10 @@ def test_binary_scalar(self, f, dtype): @functools.partial( self.pallas_call, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((1,), dtype), ) def kernel(x_ref, y_ref, o_ref): diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index f7ea17852d56..cc0f3f8ba7aa 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -53,9 +53,9 @@ def test_non_singular_stride(self): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) @@ -101,9 +101,9 @@ def test_index_with_f32_verification_error(self): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + 
out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index 966ed13fdad8..aa4488b778a8 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -44,8 +44,8 @@ def setUp(self): self.skipTest('Only works with TPU v5e.') @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('vmem', pltpu.VMEM), + ('hbm', pltpu.ANY), ) def test_basic_remote_vmem_dma(self, mem): # Implements very simple collective permute @@ -126,8 +126,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, )(x) @@ -180,8 +180,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, )(x) @@ -232,8 +232,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, compiler_params=pltpu.TPUCompilerParams(collective_id=0), )(x) @@ -291,7 +291,7 @@ def test_kernel(x_ref, grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 @@ -375,9 +375,9 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 ) @@ -467,11 +467,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 0fd94f1a8049..bd85ded66a73 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -91,9 +91,9 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): num_scalar_prefetch=0, # TPUMemorySpace.ANY will (usually) place the tensor in HBM. 
in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -204,9 +204,9 @@ def _(): num_scalar_prefetch=0, in_specs=[ # TPUMemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. # We allocated one semaphore for a local HBM-VMEM copy, @@ -365,13 +365,13 @@ def _(): num_scalar_prefetch=0, in_specs=[ # Our input lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ # Our output lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -647,11 +647,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -740,7 +740,7 @@ def test_reduce_scatter_sum_with_emit_pipeline_example( inner_block_spec = pl.BlockSpec( index_map=lambda i, j: (i, j), block_shape=inner_block_size, - memory_space=pltpu.TPUMemorySpace.ANY, + memory_space=pltpu.ANY, ) LEFT = 0 @@ -952,11 +952,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1056,10 +1056,10 @@ def run(src_dst_ids): kernel, out_shape=jax.ShapeDtypeStruct((8, 128), input_arr.dtype), in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], interpret=mosaic_interpret.TPUInterpretParams( dma_execution_mode='eager', diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 3f40f3cce0a2..1af4b29d60ff 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -169,7 +169,7 @@ def f(s, x): out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), grid=(iters,), in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), pl.BlockSpec(x.shape, lambda i: (0, 0)), ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), @@ -244,7 
+244,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): y = pl.pallas_call( kernel_without_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, @@ -259,7 +259,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pl.pallas_call( kernel_with_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 1f00b22b6708..0f1cdb5b957f 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -131,8 +131,8 @@ def setUp(self): super().setUp() @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('vmem', pltpu.VMEM), + ('hbm', pltpu.ANY), ) def test_pipeline_matmul(self, memory_space): # TODO(b/358121809): Re-enable this test once the bug is fixed. @@ -178,8 +178,8 @@ def matmul_kernel(x_ref, y_ref, z_ref): np.testing.assert_allclose(out, expected_out) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('vmem', pltpu.VMEM), + ('hbm', pltpu.ANY), ) def test_double_pipeline_matmul(self, memory_space): # TODO(b/358121809): Re-enable this test once the bug is fixed. @@ -236,11 +236,11 @@ def setUp(self): super().setUp() @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -526,11 +526,11 @@ def reference(x, y): ) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_122', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_121', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_122', pltpu.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_121', pltpu.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -769,11 +769,11 @@ def reference(x, y): ) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', 
pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1056,11 +1056,11 @@ def reference(x, y): np.mean(np.abs(out - expected_out)) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index 78a81d168136..aea2d05d57b4 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -117,7 +117,7 @@ def body(key_ref, o_ref): o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) self.assertGreaterEqual(jnp.min(result), 0) @@ -135,7 +135,7 @@ def body(key_ref, o_ref): o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) self.assertGreaterEqual(jnp.min(result), 0) @@ -153,8 +153,8 @@ def body(key_ref, o_ref): expected_key_data.dtype) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=o_shape, )(key) self.assertArraysEqual(result, expected_key_data) @@ -177,7 +177,7 @@ def body(key_ref, o_ref): o_shape = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) result_a = result[0] @@ -211,7 +211,7 @@ def body(key_ref, o_ref): global_key = jax_random.key(0, impl="pallas_tpu") o_shape = jnp.ones((64, 512), dtype=jnp.float32) - key_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM) + key_spec = pl.BlockSpec(memory_space=pltpu.SMEM) out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j)) result_16x128 = pl.pallas_call( make_kernel_body(index_map=lambda i, j: (i, j)), @@ -257,7 +257,7 @@ def body(key_ref, o_ref): # TODO(justinfu): support passing keys into VMEM. 
result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], out_shape=o_shape, )(jax.random.key_data(threefry_key)) jax_result = jax_random.uniform( diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index c2a28b72dd24..5d3c97108a6c 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -841,7 +841,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ), out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), )() @@ -861,7 +861,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ), out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), )() @@ -880,7 +880,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int32), )() @@ -899,7 +899,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((17, 128), jnp.int32), )() @@ -1099,7 +1099,7 @@ def body(sems): y = jax.block_until_ready( self.pallas_call( kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32), )() ) @@ -1122,7 +1122,7 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_specs=[ pl.BlockSpec(memory_space=pl.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], scratch_shapes=[pltpu.SemaphoreType.DMA], ), @@ -1378,7 +1378,7 @@ def body(y_ref, sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32), @@ -1395,9 +1395,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1420,7 +1420,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1443,7 +1443,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x.reshape((16, 128))) @@ -1472,7 +1472,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - 
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x.reshape((3, 16, 128))) @@ -1499,7 +1499,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) @@ -1571,7 +1571,7 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): ], scratch_shapes=[pltpu.SemaphoreType.REGULAR, pltpu.SemaphoreType.DMA], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -2579,8 +2579,8 @@ def kernel(x_ref, o_ref, send_sem, recv_sem): output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), grid=(1,), scratch_shapes=[pltpu.SemaphoreType.DMA] * 2, ) From 966578b6119fcb0ae2dc3e88d816d1da42a3465d Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 24 Apr 2025 19:05:47 -0700 Subject: [PATCH 0812/1769] [Pallas TPU] Introduce a BoundedSlice block shape type * Also add Python pipeline emitter support PiperOrigin-RevId: 751228049 --- jax/_src/pallas/core.py | 74 ++++++++++-- jax/_src/pallas/mosaic/lowering.py | 14 +++ jax/_src/pallas/mosaic/pipeline.py | 84 +++++++++++-- jax/_src/pallas/mosaic_gpu/pipeline.py | 5 +- jax/_src/pallas/pallas_call.py | 22 +++- jax/_src/pallas/triton/lowering.py | 4 +- jax/experimental/pallas/__init__.py | 2 + tests/pallas/pallas_test.py | 29 +++++ tests/pallas/tpu_pallas_pipeline_test.py | 143 ++++++++++++++++++++++- 9 files changed, 343 insertions(+), 34 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 90d35c7949f7..4bc015ab4dad 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -39,6 +39,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge +from jax._src.state import indexing from jax._src.state import types as state_types from jax._src.state.types import TransformedRef import jax.numpy as jnp @@ -359,7 +360,20 @@ class Blocked: def __str__(self): return f"Blocked({self.block_size})" -BlockDim: TypeAlias = Element | Squeezed | Blocked +@dataclasses.dataclass(frozen=True) +class BoundedSlice: +  """Specifies a bounded slice of a dimension. + +  Specifically, the index_map needs to return a `pl.Slice/pl.ds` for this +  dimension. The start and size may be dynamic, as long as the size <= +  block_size.
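+ +  For example (a hypothetical sketch, where `start` stands for any dynamic scalar computed in the index_map): + +    pl.BlockSpec((pl.BoundedSlice(8), 8, 128), lambda i: (pl.ds(start, 8), 0, 0))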
+ """ + block_size: int + + def __repr__(self): + return f"BoundedSlice({self.block_size})" + +BlockDim: TypeAlias = Element | Squeezed | Blocked | BoundedSlice def default_index_map(ndim: int) -> Callable: @@ -372,7 +386,7 @@ def _canonicalize_block_dim(dim: BlockDim | int | None) -> BlockDim: return squeezed case int(): return Blocked(int(dim)) - case Squeezed() | Blocked() | Element(): + case Squeezed() | Blocked() | Element() | BoundedSlice(): return dim case _: # Handle case where the dim is a symbolic dimension so we assume it is @@ -400,6 +414,8 @@ def _get_block_dim_size(dim: BlockDim) -> int: return block_size case Element(): return dim.block_size + case BoundedSlice(block_size): + return block_size case _: raise ValueError(f"Unsupported block shape type: {type(dim)}") @@ -420,7 +436,16 @@ def _get_ref_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: class BlockSpec: """Specifies how an array should be sliced for each invocation of a kernel. - See :ref:`pallas_blockspec` for more details. + The `block_shape` is a sequence of `int | None`s, or `BlockDim` types (e.g. + `pl.Element`, `pl.Squeezed`, `pl.Blocked`, `pl.BoundedSlice`). Each of these + types specifies the size of the block dimension. `None` is used to specify a + dimension that is squeezed out of the kernel. The `BlockDim` types allow for + more fine-grained control over the indexing of the dimension. The `index_map` + needs to return a tuple of the same length as `block_shape`, with each entry + depending on the type of `BlockDim`. + + See :ref:`pallas_blockspec` and the individual `BlockDim` type docstrings for + more details. """ # An internal canonicalized version is in BlockMapping. block_shape: Sequence[BlockDim | int | None] | None = None @@ -437,6 +462,17 @@ def __post_init__(self): " block dimension in `block_shape` instead to enable 'Unblocked'" " indexing." ) + if self.index_map is not None: + old_index_map = self.index_map + @functools.wraps(old_index_map) + def _wrapper_index_map(*args, **kwargs): + indices = old_index_map(*args, **kwargs) + if isinstance(indices, list): + indices = tuple(indices) + if not isinstance(indices, tuple): + indices = (indices,) + return indices + self.index_map = _wrapper_index_map def to_block_mapping( self, @@ -497,14 +533,36 @@ def to_block_mapping( jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( flat_index_map_fun, index_map_avals ) + index_map_out_tree = index_map_out_tree_thunk() + unflat_avals = tree_util.tree_unflatten(index_map_out_tree, out_avals) - if len(out_avals) != len(block_shape): + if len(unflat_avals) != len(block_shape): raise ValueError( f"Index map function {debug.func_src_info} for " f"{origin} must return " f"{len(block_shape)} values to match {block_shape=}. " - f"Currently returning {len(out_avals)} values."
+ f"Currently returning {len(unflat_avals)} values:" ) + # Verify types match + for i, (idx_aval, bd) in enumerate(zip(unflat_avals, block_shape)): + match bd: + case BoundedSlice(): + if not isinstance(idx_aval, indexing.Slice): + raise ValueError( + "index_map returned a value of type" + f" {type(idx_aval)} at position {i} with block dimension" + f" {bd} when it should be pl.Slice" + ) + case Blocked() | Element() | Squeezed() | int(): + if ( + not isinstance(idx_aval, jax_core.ShapedArray) + and not idx_aval.shape + ): + raise ValueError( + "index_map returned a value of type" + f" {type(idx_aval)} at position {i} with block dimension" + f" {bd} when it should be a scalar" + ) for i, ov in enumerate(out_avals): if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: raise ValueError( @@ -525,6 +583,7 @@ def to_block_mapping( block_shape=block_shape, transformed_block_aval=block_aval, # There are no transforms by default index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), + index_map_out_tree=index_map_out_tree, array_shape_dtype=jax.ShapeDtypeStruct( array_aval_shape, array_aval.dtype ), @@ -566,6 +625,7 @@ class BlockMapping: block_shape: tuple[BlockDim, ...] transformed_block_aval: AbstractMemoryRef index_map_jaxpr: jax_core.ClosedJaxpr + index_map_out_tree: tree_util.PyTreeDef array_shape_dtype: jax.ShapeDtypeStruct # The whole array origin: OriginStr transforms: Sequence[MemoryRefTransform] = () @@ -582,10 +642,6 @@ def check_invariants(self) -> None: ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( - self.block_shape, - self.index_map_jaxpr.out_avals, - ) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index ea7681c31f11..f9c85206276a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -717,6 +717,12 @@ def dynamic_shape_replacement_fn( window_params = [] static_grid = None grid = mosaic_grid_mapping.grid + if not grid and any( + not bm.has_trivial_window() for bm in grid_mapping.block_mappings + ): + raise NotImplementedError( + "Non-trivial windowing is not supported for grid-free pallas_call." 
+ ) if grid: for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" @@ -761,6 +767,14 @@ def dynamic_shape_replacement_fn( window_bounds=window_shape, transform_indices=ir.FlatSymbolRefAttr.get(func_name), ) + for bd in bm.block_shape: + if not isinstance( + bd, (pallas_core.Element, pallas_core.Squeezed, pallas_core.Blocked) + ): + raise NotImplementedError( + "Unsupported block dimension type: " + f"{type(bd)} for block shape: {bm.block_shape}" + ) is_element_block = [isinstance(bd, pallas_core.Element) for bd in bm.block_shape] if any(is_element_block): diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 5ea992b5443a..6019840ab178 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -111,7 +111,7 @@ def _round_up_to_nearest_multiple(s: int, multiple: int) -> int: return s - s % multiple + multiple -def _make_ds( +def _make_block_ds( idx: jax.Array | int, size: jax.Array | int ) -> pl.Slice: """Make a DMA slice with mosaic size hints.""" @@ -119,17 +119,41 @@ def _make_ds( assert isinstance(out, pl.Slice) return out - def _make_block_slice( - block_index: jax.Array, block_size: int, size: int, tiling: int + block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int, + tiling: int ) -> pl.Slice | slice: # Computes a slice given a block index and block size. In the default case, # we return slice(block_index * block_size, (block_index + 1) * block_size). # However, if the total size of the ref does not divide block size and we are # selecting the last block, we need to pick the lowest tiling size multiple # that contains the block. + match block_size: + case pl.Blocked(): + block_start = block_size.block_size * block_index + block_size = block_size.block_size + case pl.Element(): + block_start = block_index + block_size = block_size.block_size + case pl.BoundedSlice(): + if not isinstance(block_index, pl.Slice): + raise ValueError( + "Must return a pl.ds from the index_map for a BoundedSlice" + " dimension." + ) + block_start = block_index.start + block_size = block_index.size + return pl.ds(block_start, block_size) + case int(): + # This is the same as Blocked.
+ block_start = block_index * block_size + case None | pl.Squeezed(): + block_start = block_index + block_size = 1 + case _: + raise ValueError(f"Unsupported block dimension type: {block_size}") if size % block_size == 0: - return _make_ds(block_index, block_size) + return pl.ds(block_start, block_size) if block_size % tiling != 0: raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") num_blocks = pl.cdiv(size, block_size) @@ -145,7 +169,7 @@ def _make_block_slice( def _tuples_differ(xs, ys): """Dynamic index-tuple comparison calculation.""" - differences = jax.tree.map(lambda x, y: x != y, xs, ys) + differences = jax.tree.leaves(jax.tree.map(lambda x, y: x != y, xs, ys)) return functools.reduce(lambda x, y: x | y, differences, False) @@ -167,6 +191,26 @@ class BufferType(enum.Enum): MANUAL = 5 +def _get_block_shape(spec: pl.BlockSpec) -> tuple[int, ...]: + """Get the block shape for a given block spec.""" + def _get_dim_size(bd): + match bd: + case pl.Blocked(block_size): + return block_size + case pl.Element(): + return bd.block_size + case pl.BoundedSlice(block_size): + return block_size + case int(): + return bd + case None: + return 1 + case _: + raise ValueError(f"Unsupported block dimension type: {bd}") + if spec.block_shape is None: + raise ValueError("Block shape must be specified.") + block_shape = tuple(_get_dim_size(x) for x in spec.block_shape) + return block_shape @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -236,7 +280,8 @@ def buffer_types() -> type[BufferType]: return BufferType @classmethod - def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: + def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True + ) -> BufferedRef: """Create a BufferedRef. Args: @@ -249,7 +294,7 @@ def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: Returns: Initialized BufferedRef """ - block_shape = tuple(1 if x is None else x for x in spec.block_shape) + block_shape = _get_block_shape(spec) if buffer_type is BufferType.ACCUMULATOR: accum_ref = VMEM(block_shape, dtype) else: @@ -375,9 +420,22 @@ def bind_existing_ref(self, window_ref, indices): def compute_slice(self, grid_indices): """Compute DMA slice from grid indices.""" - block_shape = tuple(1 if x is None else x for x in self.block_shape) + block_shape = [] + for bd in self.block_shape: + if isinstance(bd, (pl.Element, pl.BoundedSlice)): + raise ValueError( + "Element and BoundedSlice block dimensions are not supported." 
+ ) + if bd is None: + block_shape.append(1) + elif isinstance(bd, pl.Blocked): + block_shape.append(bd.block_size) + elif isinstance(bd, int): + block_shape.append(bd) + else: + raise ValueError(f"Unsupported block dimension type: {type(bd)}") indices = self.compute_index(*grid_indices) - return jax.tree.map(_make_ds, indices, block_shape) + return jax.tree.map(_make_block_ds, indices, tuple(block_shape)) def init_slots(self): """Initialize slot indices.""" @@ -444,10 +502,12 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices): raise NotImplementedError("Must use >1D values.") tiling = _make_tiling(src_shape, src_dtype) - block_shape = tuple(1 if b is None else b for b in self.block_shape) block_indices = self.compute_index(*grid_indices) - return jax.tree.map( - _make_block_slice, block_indices, block_shape, src_shape, tiling + return tuple( + _make_block_slice(bi, bs, ss, t) + for bi, bs, ss, t in zip( + block_indices, self.block_shape, src_shape, tiling, strict=True + ) ) def copy_in(self, src_ref, grid_indices): diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index f7becf1ec6da..9b743bb18b37 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -40,8 +40,9 @@ map = util.safe_map zip = util.safe_zip -def _get_block_size(bd: pl.Blocked | pl.Element | pl.Squeezed | int | None - ) -> int: +def _get_block_size( + bd: pl.Blocked | pl.Element | pl.Squeezed | pl.BoundedSlice | int | None, +) -> int: match bd: case int(): return bd diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 1bcf47b9ddee..19469824aa6a 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -221,14 +221,19 @@ def _block_map_function(new_idx, *args): block_mapping.index_map_jaxpr.consts, *drop_last_args, ) + unflat_indices = tree_util.tree_unflatten( + block_mapping.index_map_out_tree, indices) + if not isinstance(unflat_indices, tuple): + unflat_indices = (unflat_indices,) + unflat_indices = list(unflat_indices) if dim is not batching.not_mapped: if isinstance(dim, batching.RaggedAxis): assert for_ragged, "Ragged axis not supported for non-ragged batching." 
stacked_axis = dim.stacked_axis - indices.insert(stacked_axis, new_idx) + unflat_indices.insert(stacked_axis, new_idx) else: - indices.insert(dim, new_idx) - return tuple(indices) + unflat_indices.insert(dim, new_idx) + return tuple(unflat_indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] if for_ragged: @@ -243,11 +248,15 @@ def _block_map_function(new_idx, *args): ) idx_avals = [*idx_avals, i32_aval_memref] + block_mapping_flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(_block_map_function, + debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info), + tree_util.tree_structure(idx_avals)) with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_block_map_function, - debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info), + block_mapping_flat_fn, idx_avals) + new_index_map_out_tree = out_tree_thunk() shape = block_mapping.block_shape if dim is batching.not_mapped: new_block_shape = shape @@ -278,7 +287,8 @@ def _block_map_function(new_idx, *args): jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, array_shape_dtype=new_array_shape_dtype, - index_map_jaxpr=jaxpr) + index_map_jaxpr=jaxpr, + index_map_out_tree=new_index_map_out_tree) def _broadcast_input_output_aliases( diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 592abd1915d0..2cddb623b33f 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -121,10 +121,12 @@ def _eval_index_map( block_indices = lower_jaxpr_to_triton_ir( ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx ) - block_indices = ( + block_indices = tuple( _ensure_ir_value(i, jax_core.ShapedArray((), jnp.int32)) for i in block_indices ) + block_indices = tree_util.tree_unflatten( + block_mapping.index_map_out_tree, block_indices) if block_mapping.pipeline_mode is not None: raise NotImplementedError( "Pipeline mode is not supported in Triton lowering." diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index c05e1645ddbe..1e631ad407fd 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -18,8 +18,10 @@ https://docs.jax.dev/en/latest/pallas.html. """ +from jax._src.pallas.core import BlockDim as BlockDim from jax._src.pallas.core import Blocked as Blocked from jax._src.pallas.core import BlockSpec as BlockSpec +from jax._src.pallas.core import BoundedSlice as BoundedSlice from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import CompilerParams as CompilerParams from jax._src.pallas.core import core_map as core_map diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 128da748b233..9b40033d7307 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1001,6 +1001,35 @@ class PallasCallElementIndexingInterpretTest(PallasCallElementIndexingTest): INTERPRET = True +class PallasCallBoundedSliceIndexingTest(PallasBaseTest): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("Only applicable for TPU") + + def test_block_spec_bounded_slice_static(self): + shape = (16, 8, 128) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] 
+ + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + with self.assertRaisesRegex(NotImplementedError, + "Unsupported block dimension type:"): + _ = self.pallas_call( + kernel, + jax.ShapeDtypeStruct((8, 8, 128), dtype=np.int32), + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), lambda i: (pl.ds(4, 8), 0, 0), + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), lambda i: (0, 0, 0), + ), + )(x) + class ApiErrorTest(PallasBaseTest): def test_pallas_call_kernel_args_mismatch(self): a = np.arange(256, dtype=np.int32) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 0f1cdb5b957f..d718d5cad7ab 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -17,19 +17,18 @@ import functools from absl.testing import absltest from absl.testing import parameterized +import hypothesis as hp +import hypothesis.strategies as hps import jax from jax import lax from jax._src import test_util as jtu from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax._src import shard_map +from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np -import hypothesis as hp -import hypothesis.strategies as hps - hp.settings.register_profile( 'deterministic', @@ -1547,5 +1546,141 @@ def align_up_to(x, y): np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) +class PallasCallBoundedSliceIndexingTest(parameterized.TestCase): + + def test_block_spec_bounded_slice_invalid_index(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + shape = (16, 8, 128) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + def main(refs): + x_ref, y_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + pltpu.emit_pipeline( + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), + lambda i: (0, 0, 0), # first index needs to be a pl.ds + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), + lambda i: (0, 0, 0), + ), + )(x_ref, y_ref) + + @jax.jit + def f(x): + y = jnp.ones((8, 8, 128), dtype=jnp.int32) + _, y = pl.run_state(main)((x, y)) + return y + with self.assertRaisesRegex( + ValueError, + 'Must return a pl.ds from the index_map for a BoundedSlice dimension.' + ): + f.trace(jax.ShapeDtypeStruct(shape, jnp.int32)) + + def test_block_spec_bounded_slice_static(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+') + shape = (16, 8, 128) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] 
+ + def main(refs): + x_ref, y_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + pltpu.emit_pipeline( + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), + lambda i: (pl.ds(4, 8), 0, 0), + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), + lambda i: (0, 0, 0), + ), + )(x_ref, y_ref) + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + + @jax.jit + def f(x): + y = jnp.ones((8, 8, 128), dtype=jnp.int32) + _, y = pl.run_state(main)((x, y)) + return y + + out = f(x) + np.testing.assert_allclose(out, x[4:12]) + + def test_block_spec_bounded_slice_dynamic(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+') + shape = (16, 8, 128) + + slices = jnp.array([[0, 3], [3, 8], [8, 11], [11, 16]], dtype=jnp.int32)[ + ::-1 + ] + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + def main(refs): + x_ref, y_ref, slices_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + + @functools.partial( + pl.run_scoped, slices_smem=pltpu.SMEM(slices.shape, slices.dtype) + ) + def _(slices_smem): + pltpu.sync_copy(slices_ref, slices_smem) + def index_map(i): + return ( + pl.ds(slices_smem[i, 0], slices_smem[i, 1] - slices_smem[i, 0]), + 0, + 0, + ) + block_spec = pl.BlockSpec( + (pl.BoundedSlice(16), 8, 128), + index_map, + ) + pltpu.emit_pipeline( + kernel, + grid=(slices.shape[0],), + in_specs=(block_spec,), + out_specs=block_spec, + )(x_ref, y_ref) + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + + @jax.jit + def f(x, slices): + y = pl.empty_like(x) + _, y, _ = pl.run_state(main)((x, y, slices)) + return y + + out = f(x, slices) + np.testing.assert_allclose(out, x) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 53d5af6bdb5b01ff4141058462145ec7ac420224 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Fri, 25 Apr 2025 00:12:02 -0700 Subject: [PATCH 0813/1769] [JAX:Sparse] Add BCSR benchmarks to sparse_benchmarks.py. PiperOrigin-RevId: 751297447 --- benchmarks/sparse_benchmark.py | 140 ++++++++++++++++++++++++--------- 1 file changed, 103 insertions(+), 37 deletions(-) diff --git a/benchmarks/sparse_benchmark.py b/benchmarks/sparse_benchmark.py index d6328881d5c6..0ffb2aed5125 100644 --- a/benchmarks/sparse_benchmark.py +++ b/benchmarks/sparse_benchmark.py @@ -21,7 +21,13 @@ import jax from jax.experimental import sparse -def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): + +def _sparse_fromdense( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 size = math.prod(shape) @@ -32,7 +38,7 @@ def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): ) mat = jnp.zeros(shape).at[indices].set(data) - f = sparse.BCOO.fromdense + f = sparse.BCSR.fromdense if bcsr else sparse.BCOO.fromdense if compile or jit: # Note: nse must be specified for JIT. 
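+    # (Without a static nse, the output shape of fromdense would depend on the input values, which jit cannot handle.)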
f = jax.jit(partial(f, nse=nse)) @@ -49,22 +55,12 @@ def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): f(mat).block_until_ready() -@google_benchmark.register -def sparse_bcoo_fromdense(state): - return _sparse_bcoo_fromdense(state) - - -@google_benchmark.register -def sparse_bcoo_fromdense_jit(state): - return _sparse_bcoo_fromdense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_fromdense_compile(state): - return _sparse_bcoo_fromdense(state, compile=True) - - -def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): +def _sparse_todense( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 size = math.prod(shape) @@ -74,6 +70,8 @@ def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): rng.choice(size, size=nse, replace=False), shape=shape ) mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape) + if bcsr: + mat = sparse.BCSR.from_bcoo(mat) f = lambda mat: mat.todense() if jit or compile: @@ -91,22 +89,12 @@ def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): f(mat).block_until_ready() -@google_benchmark.register -def sparse_bcoo_todense(state): - return _sparse_bcoo_todense(state) - - -@google_benchmark.register -def sparse_bcoo_todense_jit(state): - return _sparse_bcoo_todense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_todense_compile(state): - return _sparse_bcoo_todense(state, compile=True) - - -def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): +def _sparse_matvec( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 key = jax.random.key(1701) @@ -118,6 +106,9 @@ def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): indices_dtype=jnp.int32, sorted_indices=True, ) + if bcsr: + mat = sparse.BCSR.from_bcoo(mat) + vec = jax.random.uniform(key, shape=(shape[1],), dtype=jnp.float32) f = lambda mat, vec: mat @ vec @@ -136,19 +127,94 @@ def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): f(mat, vec).block_until_ready() +@google_benchmark.register +def sparse_bcoo_fromdense(state): + return _sparse_fromdense(state) + + +@google_benchmark.register +def sparse_bcoo_fromdense_jit(state): + return _sparse_fromdense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_fromdense_compile(state): + return _sparse_fromdense(state, compile=True) + + +@google_benchmark.register +def sparse_bcoo_todense(state): + return _sparse_todense(state) + + +@google_benchmark.register +def sparse_bcoo_todense_jit(state): + return _sparse_todense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_todense_compile(state): + return _sparse_todense(state, compile=True) + + @google_benchmark.register def sparse_bcoo_matvec(state): - return _sparse_bcoo_matvec(state) + return _sparse_matvec(state) @google_benchmark.register def sparse_bcoo_matvec_jit(state): - return _sparse_bcoo_matvec(state, jit=True) + return _sparse_matvec(state, jit=True) @google_benchmark.register def sparse_bcoo_matvec_compile(state): - return _sparse_bcoo_matvec(state, compile=True) + return _sparse_matvec(state, compile=True) + + +@google_benchmark.register +def sparse_bcsr_fromdense(state): + return _sparse_fromdense(state, bcsr=True) + + +@google_benchmark.register +def sparse_bcsr_fromdense_jit(state): + return _sparse_fromdense(state, bcsr=True, jit=True) + + 
+@google_benchmark.register +def sparse_bcsr_fromdense_compile(state): + return _sparse_fromdense(state, bcsr=True, compile=True) + + +@google_benchmark.register +def sparse_bcsr_todense(state): + return _sparse_todense(state, bcsr=True) + + +@google_benchmark.register +def sparse_bcsr_todense_jit(state): + return _sparse_todense(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bcsr_todense_compile(state): + return _sparse_todense(state, bcsr=True, compile=True) + + +@google_benchmark.register +def sparse_bcsr_matvec(state): + return _sparse_matvec(state, bcsr=True) + + +@google_benchmark.register +def sparse_bcsr_matvec_jit(state): + return _sparse_matvec(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bcsr_matvec_compile(state): + return _sparse_matvec(state, bcsr=True, compile=True) if __name__ == "__main__": From d8238ecb817c6fec4b1facec38740c4fbf32b155 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 25 Apr 2025 02:49:06 -0700 Subject: [PATCH 0814/1769] [Mosaic GPU] Fix up the Blackwell code to match changes in the CUDA runtime The CUDA runtime now complains if the kernel does not explicitly deallocate the tensor memory. Also included a drive-by fix for the buggy lowering of vector.extract_strided_slice (the integers are signless and `getSInt` does not support those). PiperOrigin-RevId: 751338675 --- jax/experimental/mosaic/gpu/core.py | 54 ++++++++++++++++---------- jax/experimental/mosaic/gpu/tcgen05.py | 52 ++++++++++++++++++------- jaxlib/mosaic/gpu/passes.cc | 6 +-- tests/mosaic/gpu_test.py | 6 +-- 4 files changed, 77 insertions(+), 41 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index bb0ecff96350..546690c3e9fd 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -18,7 +18,6 @@ import ctypes import dataclasses import enum -import functools import hashlib import math import os @@ -222,11 +221,29 @@ class LoweringSemantics(enum.Enum): Warpgroup = enum.auto() +@dataclasses.dataclass(frozen=True) +class _TMEMAlloc: + addr_ref: ir.Value + num_cols: int + collective: bool + + def alloc(self): + tcgen05.tmem_alloc( + self.addr_ref, self.num_cols, collective=self.collective, exact=False + ) + + def dealloc(self): + addr = memref.load(self.addr_ref, []) + tcgen05.tmem_dealloc( + addr, self.num_cols, collective=self.collective, exact=False + ) + + def _construct_smem_reftree( cluster_shape: tuple[int, int, int], dynamic_smem: ir.Value, smem_buffers: ShapeTree, - delayed_warp_init: list[Callable[[], None]], # Mutated by this function! + tmem_allocs: list[_TMEMAlloc], # Mutated by this function! 
lowering_semantics: LoweringSemantics, dynamic_smem_offset: int = 0, ) -> Callable[[], RefTree]: @@ -261,7 +278,7 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value: cluster_shape, dynamic_smem, m, - delayed_warp_init, + tmem_allocs, lowering_semantics, dynamic_smem_offset, ) @@ -304,12 +321,7 @@ def ref(member_thunks=member_thunks): if layout is None: layout = tcgen05._infer_tmem_layout(shape, collective) num_cols = layout.cols_in_shape(shape) - delayed_warp_init.append( - functools.partial( - tcgen05.tmem_alloc, - addr_ref, num_cols, collective=collective, exact=False, - ) - ) + tmem_allocs.append(_TMEMAlloc(addr_ref, num_cols, collective)) def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout): addr = memref.load(addr_ref, []) return tcgen05.TMEMRef( @@ -441,13 +453,9 @@ def _launch( scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof) with ctx.named_region("Init"): - delayed_warp_init = [] + tmem_allocs: list[_TMEMAlloc] = [] smem_ref_tree_thunk = _construct_smem_reftree( - cluster, - dynamic_smem, - smem_buffers, - delayed_warp_init, - lowering_semantics, + cluster, dynamic_smem, smem_buffers, tmem_allocs, lowering_semantics ) # TODO(apaszke): Skip fences if no barriers or TMEM is initialized. # TODO(apaszke): Only initialize cluster barriers before the cluster wait. @@ -455,17 +463,23 @@ def _launch( if math.prod(cluster) != 1: nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - if delayed_warp_init: + if tmem_allocs: eq = arith.CmpIPredicate.eq is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32)) with utils.when(is_init_warp): - for init in delayed_warp_init: - init() - tcgen05.tmem_relinquish_alloc_permit() + for alloc in tmem_allocs: + alloc.alloc() + if any(alloc.collective for alloc in tmem_allocs): + tcgen05.tmem_relinquish_alloc_permit(collective=True) + if any(not alloc.collective for alloc in tmem_allocs): + tcgen05.tmem_relinquish_alloc_permit(collective=False) gpu.barrier() # Make sure the init is visible to all threads. 
smem_ref_tree = smem_ref_tree_thunk() yield ctx, smem_ref_tree + + for alloc in tmem_allocs: + alloc.dealloc() if prof is not None: prof.finalize(grid=grid, block=block) gpu.terminator() @@ -666,7 +680,7 @@ def prof_kernel(*args): *results, prof_buffer = bind(*args) def dump_profile(prof_buffer): out_file = os.path.join( - os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), + os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp"), f"{time.time_ns()}-trace.json", ) try: diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 53056ce594b2..4f8c5847f74f 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -288,6 +288,19 @@ def commit_arrive( ) +def _alloc_ncols(ncols: int, exact: bool): + if exact: + if ncols.bit_count() != 1 or not 32 <= ncols <= 512: + raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}") + else: + ncols = max(32, 1 << (ncols - 1).bit_length()) + if ncols > 512: + raise ValueError( + f"After rounding up, got {ncols} columns, exceeding the limit of 512" + ) + return ncols + + def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True): if ir.MemRefType.isinstance(tmem_addr.type): ref_ty = ir.MemRefType(tmem_addr.type) @@ -300,15 +313,7 @@ def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: tmem_addr = utils.memref_ptr(tmem_addr, memory_space=3) elif tmem_addr.type != ir.Type.parse("!llvm.ptr<3>"): raise ValueError(f"tmem_addr must be an SMEM pointer or a memref, got: {tmem_addr.type}") - if exact: - if ncols.bit_count() != 1 or not 32 <= ncols <= 512: - raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}") - else: - ncols = max(32, 1 << (ncols - 1).bit_length()) - if ncols > 512: - raise ValueError( - f"After rounding up, got {ncols} columns, exceeding the limit of 512" - ) + ncols = _alloc_ncols(ncols, exact) num_cta = 2 if collective else 1 return llvm.inline_asm( ir.Type.parse("!llvm.void"), @@ -318,11 +323,27 @@ def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: has_side_effects=True, ) -def tmem_relinquish_alloc_permit(): + +def tmem_dealloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True): + if tmem_addr.type != ir.IntegerType.get_signless(32): + raise ValueError(f"tmem_addr must be an i32, got: {tmem_addr.type}") + ncols = _alloc_ncols(ncols, exact) + num_cta = 2 if collective else 1 + return llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [tmem_addr], + f"tcgen05.dealloc.cta_group::{num_cta}.sync.aligned.b32 $0, {ncols};", + "r", + has_side_effects=True, + ) + + +def tmem_relinquish_alloc_permit(collective: bool): + num_cta = 2 if collective else 1 return llvm.inline_asm( ir.Type.parse("!llvm.void"), [], - "tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;", + f"tcgen05.relinquish_alloc_permit.cta_group::{num_cta}.sync.aligned;", "", has_side_effects=True, ) @@ -633,10 +654,13 @@ def _transfer_32xcols(base_addr, cols): cols_per_num = 8 # Here we generate a plan compatible with tcgen05.LAYOUT. assert cols % cols_per_num == 0 total_num = cols // cols_per_num - if total_num <= 32: + assert total_num.bit_count() == 1 + # We artificially lower the instr_num compared to its limits, because higher + # values can lead to register spills.
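+  # (For example, cols=512 gives total_num=64, for which we now pick instr_num=16 rather than the previous 32.)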
+ if total_num <= 16: instr_num = total_num - elif total_num == 64: - instr_num = 32 + elif 32 <= total_num <= 64: + instr_num = 16 else: raise NotImplementedError(total_num) # We transfer 16 lanes at a time, but have 32 to deal with. diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index 9fa6f8df78a8..b5325e97d4ad 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -54,14 +54,14 @@ struct ConvertExtractStridedSlicePattern final return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported"); } int64_t size = - (*op.getSizes().getAsRange().begin()).getSInt(); + (*op.getSizes().getAsRange().begin()).getInt(); if (size < 0) { return rewriter.notifyMatchFailure(op, "size is negative"); } int64_t start = - (*op.getOffsets().getAsRange().begin()).getSInt(); + (*op.getOffsets().getAsRange().begin()).getInt(); int64_t stride = - (*op.getStrides().getAsRange().begin()).getSInt(); + (*op.getStrides().getAsRange().begin()).getInt(); if (stride != 1) { return rewriter.notifyMatchFailure(op, "only stride 1 is supported"); } diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 8d0c56877bd4..f7da896b6ee5 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -947,9 +947,8 @@ def kernel(ctx, input, output, scratch): n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 swizzle=(32, 64, 128,), ) - def test_mma_basic(self, *args, **kwargs): + def test_mma_basic(self, **kwargs): self._basic_mma_test( - *args, **kwargs, k_steps=2, # Reducing to 1 can be helpful while debugging. lhs_transpose_tiles=False, @@ -967,11 +966,10 @@ def test_mma_basic(self, *args, **kwargs): lhs_transpose_tiles=(False, True), rhs_transpose_tiles=(False, True), ) - def test_mma_transposed_tiles(self, *args, **kwargs): + def test_mma_transposed_tiles(self, **kwargs): if not kwargs["lhs_transpose_tiles"] and not kwargs["rhs_transpose_tiles"]: self.skipTest("This is already tested in test_mma_basic") self._basic_mma_test( - *args, **kwargs, k_steps=2, # Reducing to 1 can be helpful while debugging. 
) From 2c85a2587583e6dcd1e6875948d7f0112a42621e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 25 Apr 2025 04:26:13 -0700 Subject: [PATCH 0815/1769] [Mosaic TPU] Allow indexing refs with narrow integers PiperOrigin-RevId: 751362838 --- tests/pallas/tpu_pallas_test.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 5d3c97108a6c..f7cb51e0c224 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1987,7 +1987,7 @@ def test_scalar_load_upcast(self, in_dtype): if not jtu.if_cloud_tpu_at_least(2025, 4, 25): self.skipTest("Needs a newer libTPU") if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4): - self.skipTest("Triggers an XLA bug") + self.skipTest("Triggers an XLA bug") # TODO(b/413602952) def kernel(x_ref, o_ref): o_ref[0, 0] = x_ref[0, 0].astype(o_ref.dtype) x = jnp.asarray([[-1]], dtype=in_dtype) @@ -1999,6 +1999,23 @@ def kernel(x_ref, o_ref): )(x) self.assertEqual(y, x.astype(jnp.int32)) + @parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32]) + def test_scalar_indirect_load(self, in_dtype): + if not jtu.if_cloud_tpu_at_least(2025, 4, 27): + self.skipTest("Needs a newer libTPU") + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[0, x_ref[0, 0].astype(jnp.int32)].astype(o_ref.dtype) + if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4): + self.skipTest("Triggers an XLA bug") # TODO(b/413602952) + x = jnp.asarray([[3, 0, 0, 1]], dtype=in_dtype) + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), + )(x) + self.assertEqual(y, x[0, x[0, 0]].astype(jnp.int32)[None, None]) + def test_masked_store(self): shape = (16, 256) mask_shape = (10, 130) From ddb1b26599d8a6e29b5c1c054167090aeb472416 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 25 Apr 2025 06:15:22 -0700 Subject: [PATCH 0816/1769] [Mosaic TPU] Allow simultaneous column and row shifts It's unclear to me why we ever disallowed this. Probably some slight issue about how we handled vreg arrays. Either way, seems to be all fine now (as it should be). 
PiperOrigin-RevId: 751388746 --- .../tpu/transforms/apply_vector_layout.cc | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index dadf0498db3c..04902c2af6f3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5809,8 +5809,6 @@ FailureOr>> changeOffsets( const VectorType vty, const VectorLayout src, xla::Array vregs, const LayoutOffsets dst_offsets) { const auto &target_shape = ctx.target_shape; - const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), - src.implicit_dim()); int row_diff; if (!src.offsets()[0].has_value()) { @@ -5830,30 +5828,30 @@ FailureOr>> changeOffsets( col_diff = *dst_offsets[1] - *src.offsets()[1]; } + VectorLayout src_after_row_shift(src.bitwidth(), + {dst_offsets[0], src.offsets()[1]}, + src.tiling(), src.implicit_dim()); if (row_diff != 0) { - if (col_diff != 0) { - return emitError(loc, "Not implemented: Row and column offset changes"); - } const SmallVector implicit_shape = src.implicitShape(vty.getShape()); FAILUREOR_ASSIGN_OR_RETURN( vregs, doRowShiftRelayout(builder, loc, vty.getShape(), vregs, src, *dst_offsets[0], ctx.target_shape)); + // Make sure the shape is as expected. + SmallVector current_tiles_shape = + src_after_row_shift.tileArrayImplicitShape(vty.getShape(), + target_shape); + CHECK_EQ(*(current_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2)); } - // Rows are now correctly aligned. Time to offset columns. - // TODO(apaszke, mvoz): Changing an offset might add or remove one vreg. - // Note - this is handled for row shifts via tpu_rotate_with_overflow - SmallVector dst_tiles_shape = - dst.tileArrayImplicitShape(vty.getShape(), target_shape); - CHECK_EQ(*(dst_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2)); - - // TODO(tlongeri): Clean up col_diff and pass the dst offset directly. if (col_diff != 0) { FAILUREOR_ASSIGN_OR_RETURN( vregs, doColumnShiftRelayout(builder, vty.getShape(), std::move(vregs), - src, *dst.offsets()[1], target_shape)); + src_after_row_shift, *dst_offsets[1], + target_shape)); } + VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), + src.implicit_dim()); return std::make_pair(dst, std::move(vregs)); } From 3508e0c6a3af48d2f143a2065838692e5ee32ad5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 25 Apr 2025 06:18:48 -0700 Subject: [PATCH 0817/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ee9ee727b533dbd14698c9eda979a8c83ed86e11. PiperOrigin-RevId: 751389725 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c7d0d5cb7c3f..6240e79aeb22 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "e7b3e6d645d1f3343115d1c5dd8d104976676152" -XLA_SHA256 = "4719cec5489231a3a90840d4ff5ec7bac4ab5dde44757ecf4d283f9dab485a0f" +XLA_COMMIT = "ee9ee727b533dbd14698c9eda979a8c83ed86e11" +XLA_SHA256 = "63ebc70aa209ada6b29faea67c196f9c1237e14bb381b2c014db0405b80881ec" def repo(): tf_http_archive( From 224ce511b5f31301408d25b554b526f49d3a12d5 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 25 Apr 2025 07:10:55 -0700 Subject: [PATCH 0818/1769] [Mosaic TPU] Relax test restrictions after improving DMA support for different tilings PiperOrigin-RevId: 751402693 --- tests/pallas/tpu_pallas_pipeline_test.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index d718d5cad7ab..627302904748 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -1515,9 +1515,6 @@ def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): if dtype == 'int8' and jtu.is_device_tpu_at_least(6): self.skipTest('Not implemented for TPU v6.') - def align_up_to(x, y): - return (x + y - 1) // y * y - hp.assume(bm <= m) hp.assume(bn <= n) hp.assume(bk <= k) @@ -1527,11 +1524,6 @@ def align_up_to(x, y): if not jtu.is_device_tpu_at_least(5): self.skipTest('Only TPU v5+ allowed for int8.') hp.assume(bm >= 32) - # TODO(apaszke): Relax DMA restrictions and remove this. - packing = 4 // jnp.dtype(dtype).itemsize - if packing != 1: - m = align_up_to(m, 8 * packing) - k = align_up_to(k, 8 * packing) k1, k2 = jax.random.split(jax.random.key(seed)) x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) From 0c9ccd50c4b720d14cab4c0fb01260f661967589 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 25 Apr 2025 16:14:58 +0200 Subject: [PATCH 0819/1769] Minor optim in profiler on session stop and export --- jax/_src/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index c787ea4c0223..510b09f16605 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -201,7 +201,7 @@ def stop_trace(): if _profile_state.profile_session is None: raise RuntimeError("No profile started") sess = _profile_state.profile_session - sess.export(sess.stop(), str(_profile_state.log_dir)) + sess.stop_and_export(str(_profile_state.log_dir)) if _profile_state.create_perfetto_trace: abs_filename = _write_perfetto_trace_file(_profile_state.log_dir) if _profile_state.create_perfetto_link: From a891cfabbcd3ab98c3c6bbfd97be19cb78275ade Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Fri, 25 Apr 2025 07:17:50 -0700 Subject: [PATCH 0820/1769] #jax #sdy add a Shardy config to mock_gpu_topology_test. 
PiperOrigin-RevId: 751404670 --- tests/BUILD | 1 + tests/mock_gpu_topology_test.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 5f142c097889..3745c7cf7882 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -378,6 +378,7 @@ jax_multiplatform_test( enable_backends = ["gpu"], enable_configs = [ "gpu_h100", + "gpu_a100_shardy", ], tags = [ "config-cuda-only", diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py index 59c511ae61cf..8e409d6ed331 100644 --- a/tests/mock_gpu_topology_test.py +++ b/tests/mock_gpu_topology_test.py @@ -14,6 +14,7 @@ from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding @@ -49,13 +50,18 @@ def testMockWithSharding(self): f_lowered = f.lower(jnp.arange(16)) hlo = f_lowered.compiler_ir() + hlo_str = str(hlo) mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE - self.assertIn(f'num_partitions = {mocked_count}', str(hlo)) - self.assertIn( - f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"', - str(hlo) - ) + self.assertIn(f'num_partitions = {mocked_count}', hlo_str) + + if config.use_shardy_partitioner.value: + expected_sharding = 'sharding = #sdy.sharding<@mesh, [{"x"}]>' + else: + expected_sharding = ( + f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"' + ) + self.assertIn(expected_sharding, hlo_str) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 23fdd66a2cbc20c7f2b431ae8a8f1585d97b30bf Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 25 Apr 2025 07:24:25 -0700 Subject: [PATCH 0821/1769] [Pallas] Remove leftover debug=True in Pallas tests PiperOrigin-RevId: 751406280 --- tests/pallas/pallas_test.py | 1 - tests/pallas/tpu_pallas_test.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 9b40033d7307..c16c66d4eb52 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -962,7 +962,6 @@ def kernel(x_ref, o_ref): ), out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), out_shape=result_ty, - debug=True, )(x) ref = [] for i in range(15): diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index f7cb51e0c224..908bdab54027 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1548,7 +1548,6 @@ def kernel(y_ref, scratch_ref): out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)), grid=(2,), ), - debug=True, out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.int32), )() expected = jnp.broadcast_to(jnp.arange(2, dtype=jnp.int32)[..., None, None], @@ -2097,7 +2096,6 @@ def _(): pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), ], out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), - debug=True, ) ) )(x, y) From 859fd679ddb6244a2ca249bcda181686bb72860f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 25 Apr 2025 07:24:32 -0700 Subject: [PATCH 0822/1769] [JAX] Remove xla_client_test. 
PiperOrigin-RevId: 751406310 --- jaxlib/BUILD | 67 - jaxlib/custom_calls_testlib.cc | 128 - jaxlib/xla_client_backend_independent_test.py | 212 - jaxlib/xla_client_test.py | 3734 ----------------- 4 files changed, 4141 deletions(-) delete mode 100644 jaxlib/custom_calls_testlib.cc delete mode 100644 jaxlib/xla_client_backend_independent_test.py delete mode 100644 jaxlib/xla_client_test.py diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 7c00e3f5b99d..e4efac1462f9 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -23,7 +23,6 @@ load( "proto_library", "py_deps", "py_library_providing_imports_info", - "py_strict_library", "py_strict_test", "pytype_library", "pytype_strict_library", @@ -1193,72 +1192,6 @@ pytype_strict_library( ]) + [":_jax"], ) -py_strict_test( - name = "xla_client_backend_independent_test", - srcs = ["xla_client_backend_independent_test.py"], - deps = [ - ":xla_client", - ] + py_deps([ - "absl/testing", - "numpy", - "portpicker", - ]), -) - -py_strict_library( - name = "xla_client_test", - testonly = 1, - srcs = ["xla_client_test.py"], - visibility = [":xla_python"], - deps = [ - ":xla_client", - "//jax", - "//jax:test_util", - "//jaxlib", - ] + py_deps([ - "absl/flags", - "absl/logging", - "absl/testing", - "ml_dtypes", - "numpy", - ]), -) - -nanobind_extension( - name = "custom_calls_testlib", - testonly = 1, - srcs = ["custom_calls_testlib.cc"], - deps = [ - "@com_google_absl//absl/status", - "@nanobind", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - ], -) - -py_strict_test( - name = "xla_client_test_cpu", - srcs = ["xla_client_test.py"], - args = ["--backend=cpu"], - env = { - "XLA_FLAGS": "--xla_force_host_platform_device_count=4", - }, - main = "xla_client_test.py", - deps = [ - ":custom_calls_testlib", - ":xla_client", - "//jax", - "//jax:test_util", - "//jaxlib", - ] + py_deps([ - "absl/flags", - "absl/logging", - "absl/testing", - "ml_dtypes", - "numpy", - ]), -) - py_strict_test( name = "pytree_test", srcs = ["pytree_test.py"], diff --git a/jaxlib/custom_calls_testlib.cc b/jaxlib/custom_calls_testlib.cc deleted file mode 100644 index 58f4818a431e..000000000000 --- a/jaxlib/custom_calls_testlib.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright 2024 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "nanobind/nanobind.h" -#include "xla/ffi/api/c_api.h" -#include "xla/ffi/api/ffi.h" - -namespace xla::ffi { -namespace nb = ::nanobind; - -// Implement custom calls as static functions with XLA FFI types in the function -// signature that gives access to the arguments and results buffers together -// with their types and dimensions. See `ffi/api/ffi_test.cc` for more XLA FFI -// examples and features (e.g. binding attributes, custom user-defined structs -// and arbitrary execution context). 
- -static Error AlwaysFail(Result) { - return Error(XLA_FFI_Error_Code_INTERNAL, "Failed intentionally"); -} - -static Error AlwaysSucceed(Result) { return Error::Success(); } - -static Error Subtract(BufferR0 a, BufferR0 b, - Result> out) { - *out->typed_data() = *a.typed_data() - *b.typed_data(); - return Error::Success(); -} - -static Error SubtractCst(BufferR0 a, - Result> out, float cst) { - *out->typed_data() = *a.typed_data() - cst; - return Error::Success(); -} - -// Define XLA FFI handlers from the implementations defined above using explicit -// XLA FFI binding API to describe type signatures of custom calls. - -XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, Ffi::Bind().Ret()); - -XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed, - Ffi::Bind().Ret()); - -XLA_FFI_DEFINE_HANDLER(kSubtract, Subtract, - Ffi::Bind() - .Arg>() - .Arg>() - .Ret>()); - -XLA_FFI_DEFINE_HANDLER(kSubtractCst, SubtractCst, - Ffi::Bind() - .Arg>() - .Ret>() - .Attr("cst")); - -// XLA FFI calls can also be stateful. -struct TestFfiState { - static TypeId id; - explicit TestFfiState(int32_t value) : value(value) {} - int32_t value; -}; -TypeId TestFfiState::id = {}; - -static ErrorOr> StateInstantiate() { - return std::make_unique(42); -} - -static Error StateExecute(TestFfiState* state, - Result> out) { - *out->typed_data() = state->value; - return Error::Success(); -} - -XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate, - Ffi::BindInstantiate()); -XLA_FFI_DEFINE_HANDLER( - kStateExecute, StateExecute, - Ffi::Bind().Ctx>().Ret>()); - -template -static auto BindFunction(T* fn) { - return nb::capsule(reinterpret_cast(fn)); -} - -template -static auto BindTypeId(T* typeId) { - return nb::capsule(reinterpret_cast(typeId)); -} - -// Custom calls registration library that exports function pointers to XLA FFI -// handlers to the python users. -NB_MODULE(custom_calls_testlib, m) { - m.def("registrations", []() { - nb::dict dict; - dict["always_fail"] = BindFunction(kAlwaysFail); - dict["always_succeed"] = BindFunction(kAlwaysSucceed); - dict["subtract_f32"] = BindFunction(kSubtract); - dict["subtract_f32_cst"] = BindFunction(kSubtractCst); - - nb::dict bundle; - bundle["instantiate"] = BindFunction(kStateInstantiate); - bundle["execute"] = BindFunction(kStateExecute); - dict["stateful"] = bundle; - - return dict; - }); - m.def("type_ids", []() { - nb::dict type_ids; - type_ids["test_ffi_state"] = BindTypeId(&TestFfiState::id); - return type_ids; - }); -} - -} // namespace xla::ffi diff --git a/jaxlib/xla_client_backend_independent_test.py b/jaxlib/xla_client_backend_independent_test.py deleted file mode 100644 index 1cd2865bf9a9..000000000000 --- a/jaxlib/xla_client_backend_independent_test.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2017 The JAX Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Backend-independent tests for the Python XLA client.""" - -import unittest - -from absl.testing import absltest -import numpy as np - -from jax.jaxlib import xla_client - -# pylint: disable=g-import-not-at-top -try: - import portpicker -except ImportError: - portpicker = None -# pylint: enable=g-import-not-at-top - -ops = xla_client.ops - - -class ShapeTest(absltest.TestCase): - - def testInvalidShapes(self): - with self.assertRaisesRegex( - xla_client.XlaRuntimeError, "Invalid dimension size" - ): - xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) - - with self.assertRaisesRegex( - RuntimeError, "layout minor_to_major field contains 1 element.*" - ): - xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3]) - - with self.assertRaisesRegex( - RuntimeError, "layout minor_to_major field has out-of-bounds value.*" - ): - xla_client.Shape.array_shape( - xla_client.PrimitiveType.F32, [2, 4], [1, -1] - ) - - -class ComputationPrinting(absltest.TestCase): - - def ExampleComputation(self): - builder = xla_client.XlaBuilder("acomputation") - p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) - ) - x = ops.Mul(p0, p1) - ops.Add(x, x) - return builder.build() - - def testComputationToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.as_hlo_text() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testComputationToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = computation.as_hlo_dot_graph() - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - - def testHloModuleToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.as_hlo_module().to_string() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testHloModuleFromText(self): - hlo_module_text = """HloModule test - add { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x, y) - } - ENTRY entry { - p0 = f32[2,3] parameter(0) - start = f32[2,3] all-reduce-start(p0), to_apply=add - ROOT done = f32[2,3] all-reduce-done(start) - }""" - hlo_module = xla_client._xla.hlo_module_from_text(hlo_module_text) - hlo_text = hlo_module.to_string() - self.assertTrue(hlo_text.startswith("HloModule test")) - - def testHloModuleToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( - computation.as_hlo_module() - ) - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - - -class ComputationHashTest(absltest.TestCase): - - def testHash(self): - builder0 = xla_client.XlaBuilder("computation0") - p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) - ) - ops.Mul(p0, p1) - computation0 = builder0.build() - - builder1 = xla_client.XlaBuilder("computation1") - p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) - ) - ops.Mul(p0, p1) - computation1 = builder1.build() - - self.assertEqual(computation0.hash(), computation1.hash()) - - -class AliasTest(absltest.TestCase): - - def testSetUpAlias(self): - c = xla_client.XlaBuilder(self.id()) - p1 = ops.Parameter( - c, - 0, - xla_client.shape_from_pyval( 
- np.array(1.0, np.float32) - ).with_major_to_minor_layout_if_absent(), - ) - p2 = ops.Parameter( - c, - 1, - xla_client.shape_from_pyval( - np.array(1.0, np.float32) - ).with_major_to_minor_layout_if_absent(), - ) - out = ops.Add(p1, p2) - c.setup_alias([], 0, []) - c.build(out) - - -class ProfilerTest(absltest.TestCase): - - def testTraceMe(self): - # TODO(phawkins): These tests just check that the TraceMe context manager - # acts like a context manager and doesn't explode. Ideally we'd check that - # the profiler saw the traceme too. - with xla_client.profiler.TraceMe("test1"): - pass - with xla_client.profiler.TraceMe("test2", foo=123): - pass - with self.assertRaises(ValueError): - with xla_client.profiler.TraceMe("test3"): - raise ValueError("test") - - @unittest.skipIf(portpicker is None, "Test requires portpicker") - def testStartServer(self): - port = portpicker.pick_unused_port() - server = xla_client.profiler.start_server(port) - del server - - -class HloModuleGroupTest(absltest.TestCase): - - def testHloModuleGroup(self): - builder0 = xla_client.XlaBuilder("computation0") - p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) - ) - root = ops.Mul(p0, p1) - computation0 = builder0.build(root) - - m = computation0.get_hlo_module() - mg_name = "test_module_group" - mg = xla_client._xla.HloModuleGroup(mg_name, [m]) - self.assertEqual(mg.name, mg_name) - - modules = mg.to_modules() - self.assertLen(modules, 1) - self.assertEqual(m.to_string(), modules[0].to_string()) - - -class RunHloPassTest(absltest.TestCase): - - def testHloDCE(self): - b = xla_client.XlaBuilder("acomputation") - p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - b, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)) - ) - root = ops.Mul(p0, p1) - - # Dead instructions - p2 = ops.Parameter(b, 2, xla_client.shape_from_pyval(np.float32(0))) - ops.Add(p2, p2) - - hlo_module = b.build(root).get_hlo_module() - self.assertTrue(xla_client._xla.HloDCE().run(hlo_module)) - - -if __name__ == "__main__": - absltest.main() diff --git a/jaxlib/xla_client_test.py b/jaxlib/xla_client_test.py deleted file mode 100644 index ae34751b6cab..000000000000 --- a/jaxlib/xla_client_test.py +++ /dev/null @@ -1,3734 +0,0 @@ -# Copyright 2017 The JAX Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""Backend-dependent tests for the Python XLA client."""
-
-import collections
-import functools
-import itertools
-import re
-import threading
-import traceback
-from typing import Sequence
-import unittest
-
-from absl import flags
-from absl import logging
-from absl.testing import absltest
-from absl.testing import parameterized
-import ml_dtypes
-import numpy as np
-
-from jax.jaxlib import xla_client
-import jax
-import jax._src.test_util
-
-# pylint: disable=g-import-not-at-top
-try:
-  from jax.jaxlib import custom_calls_testlib
-except ImportError:
-  custom_calls_testlib = None
-
-xla_client._xla.jax_jit.set_thread_local_state_initialization_callback(
-    lambda: None
-)
-
-bfloat16 = ml_dtypes.bfloat16
-float4_e2m1fn = ml_dtypes.float4_e2m1fn
-float8_e3m4 = ml_dtypes.float8_e3m4
-float8_e4m3 = ml_dtypes.float8_e4m3
-float8_e4m3fn = ml_dtypes.float8_e4m3fn
-float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
-float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
-float8_e5m2 = ml_dtypes.float8_e5m2
-float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz
-float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
-ops = xla_client.ops
-
-def xla_computation_to_mlir_module(c: xla_client.XlaComputation) -> bytes:
-  return xla_client._xla.mlir.hlo_to_stablehlo(
-      c.as_serialized_hlo_module_proto())
-
-
-def execute_with_python_values(executable, arguments, backend):  # pylint: disable=invalid-name
-  """Execute on one replica with Python values as arguments and output."""
-
-  def put(arg):  # pylint: disable=invalid-name
-    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
-
-  arguments = [put(arg) for arg in arguments]
-  outputs = executable.execute(arguments)
-  return [np.asarray(x) for x in outputs]
-
-
-# pylint: disable=invalid-name
-def jax_array_convert_to_array(self, dtype=None, copy=None):
-  del copy
-  out, _ = self._single_device_array_to_np_array_did_copy()
-  if dtype is not None:
-    out = out.astype(dtype)
-  return out
-
-
-def jax_array_device(self):
-  return self._sharding._device
-
-
-def jax_array_copy_to_host_async(self):
-  self._copy_single_device_array_to_host_async()
-
-
-Array = xla_client.ArrayImpl
-Array.__array__ = jax_array_convert_to_array
-Array.copy_to_host_async = jax_array_copy_to_host_async
-Array.device = jax_array_device
-xla_client.SingleDeviceSharding.device_set = property(
-    lambda self: {self._device}
-)
-# pylint: enable=invalid-name
-
-
-FLAGS = flags.FLAGS
-
-# We choose to ignore pylint's complaints about complex comprehensions, which we
-# use widely for parameterizing tests.
-# pylint: disable=g-complex-comprehension
-
-_CUSTOM_CALLS_REGISTERED = False
-
-
-# XLA's alignment is 16 bytes at the moment, but it should match what Eigen
-# supports, and that can go up to 128 bytes on hardware with HVX.
-_XLA_CPU_MAX_ALIGNMENT = 128
-
-
-# Minimum possible alignment for XLA.
-_XLA_CPU_MIN_ALIGNMENT = 16
-
-
-# Return a copy of `x` with the given alignment. Does nothing if `x` is already
-# aligned. We do this manually, because numpy doesn't support custom alignment
-# values.
-def _Aligned(x, alignment=_XLA_CPU_MAX_ALIGNMENT):
-  if (x.ctypes.data % alignment) == 0:
-    return x
-
-  # Create temporary buffer with extra space for alignment.
-  assert alignment % x.itemsize == 0
-  extra = alignment // x.itemsize
-  buf = np.empty(x.size + extra, dtype=x.dtype)
-
-  # Create a view of the temporary buffer with such an offset, that the result
-  # buffer is aligned.
- offset = (-buf.ctypes.data % alignment) // x.itemsize - result = buf[offset : offset + x.size].reshape(x.shape) - - # Copy the data to the result buffer and return it. - np.copyto(result, x) - return result - - -# Return an unaligned copy of `x`. The result buffer's memory address is -# guaranteed to not be aligned to `alignment`. This function is useful for -# testing failures. -def _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT): - if (x.ctypes.data % alignment) != 0: - return x - - # Create temporary buffer with extra space. - assert (x.itemsize % alignment) != 0 - offset = 1 - buf = np.empty(x.size + offset, dtype=x.dtype) - - if (buf.ctypes.data % alignment) != 0: - # If the temporary buffer is already unaligned, return it. - result = buf - else: - # Otherwise, create a view of the temporary buffer with an offset. - result = buf[offset : offset + x.size].reshape(x.shape) - assert (result.ctypes.data % alignment) != 0 - - # Copy the data to the result buffer and return it. - np.copyto(result, x) - return result - - -def TestFactory(xla_backend, - cloud_tpu=False, - tfrt_tpu=False, - pjrt_c_api=False, - pathways=False, - pathways_ifrt=False): - tests = [] - - int_dtypes = [np.int32, np.int64, np.uint32, np.uint64] - # TODO(phawkins): test np.float16, where supported. - float_dtypes = [bfloat16, np.float32, np.float64] - complex_dtypes = [np.complex64, np.complex128] - standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] - # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. - # standard_dtypes is only used for BufferProtocolTest so we only test fp8 - # round trip tests. - fp8_dtypes = [ - float8_e3m4, - float8_e4m3, - float8_e4m3fn, - float8_e4m3b11fnuz, - float8_e5m2, - float8_e8m0fnu, - ] - standard_dtypes += fp8_dtypes - # TODO(upwind): testRoundTrip and testLiveBuffers fail for float4_e2m1fn type - # standard_dtypes += [float4_e2m1fn] - dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes - - class ComputationTest(parameterized.TestCase): - """Base class for running an XLA Computation through the local client.""" - - def setUp(self): - super(ComputationTest, self).setUp() - self.backend = xla_backend() - - global _CUSTOM_CALLS_REGISTERED - if self.backend.platform == "cpu" and not _CUSTOM_CALLS_REGISTERED: - for name, fn in custom_calls_testlib.registrations().items(): - xla_client.register_custom_call_target( - name, fn, platform="cpu", api_version=1 - ) - for name, val in custom_calls_testlib.type_ids().items(): - xla_client.register_custom_type_id(name, val, platform="cpu") - _CUSTOM_CALLS_REGISTERED = True - - def _NewComputation(self, name=None): - if name is None: - name = self.id() - return xla_client.XlaBuilder(name) - - def _Execute(self, c, arguments): - compiled_c = self.backend.compile( - xla_computation_to_mlir_module(c.build())) - return execute_with_python_values( - compiled_c, arguments, backend=self.backend) - - def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): - assert expected is not None - results = self._Execute(c, arguments) - self.assertLen(results, len(expected)) - for result, e in zip(results, expected): - # Numpy's comparison methods are a bit too lenient by treating inputs as - # "array-like", meaning that scalar 4 will be happily compared equal to - # [[4]]. We'd like to be more strict so assert shapes as well. 
- self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) - assert_func(result, e) - - def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): - self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, - expected) - - def _ExecuteAndCompareClose(self, - c, - arguments=(), - expected=None, - rtol=1e-4, - atol=0): - self._ExecuteAndAssertWith( - functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), - c, arguments, expected) - - def NumpyArrayF32(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" - return np.array(*args, dtype=np.float32, **kwargs) - - def NumpyArrayF64(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" - return np.array(*args, dtype=np.float64, **kwargs) - - def NumpyArrayS32(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" - return np.array(*args, dtype=np.int32, **kwargs) - - def NumpyArrayBool(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.bool_ dtype.""" - return np.array(*args, dtype=np.bool_, **kwargs) - - class ComputationPrinting(absltest.TestCase): - - def setUp(self): - super(ComputationPrinting, self).setUp() - self.backend = xla_backend() - - def ExampleComputation(self): - builder = xla_client.XlaBuilder("acomputation") - p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) - x = ops.Mul(p0, p1) - ops.Add(x, x) - return builder.build() - - @unittest.skipIf(cloud_tpu or pathways, "not implemented") - def testCompiledHloModuleToHloText(self): - computation = self.ExampleComputation() - executable = self.backend.compile( - xla_computation_to_mlir_module(computation)) - hlo_modules = executable.hlo_modules() - self.assertLen(hlo_modules, 1) - hlo_text = hlo_modules[0].to_string() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - self.assertIn("fusion", hlo_text) - - @unittest.skipIf(cloud_tpu or pathways, "not implemented") - def testCompiledHloModuleAsSerializedProto(self): - computation = self.ExampleComputation() - executable = self.backend.compile( - xla_computation_to_mlir_module(computation)) - hlo_modules = executable.hlo_modules() - self.assertLen(hlo_modules, 1) - hlo_text = hlo_modules[0].to_string() - proto = hlo_modules[0].as_serialized_hlo_module_proto() - hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module() - hlo_text_roundtrip = hlo_module_roundtrip.to_string() - self.assertEqual(hlo_text, hlo_text_roundtrip) - - @unittest.skipIf(cloud_tpu or pathways, "not implemented") - def testStableComputationSerialization(self): - # Ideally we would test identical computations produced in different - # processes. For now we have this limited smoke test. 
- computation = self.ExampleComputation() - ref = computation.as_serialized_hlo_module_proto() - for _ in range(10): - self.assertEqual(computation.as_serialized_hlo_module_proto(), ref) - - # TODO(b/261771737): some version of this should work with pjrt_c_api=True - @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, - "not implemented") - def testFlopEstimate(self): - computation = self.ExampleComputation() - properties = xla_client._xla.hlo_module_cost_analysis( - self.backend, computation.as_hlo_module()) - self.assertEqual(properties["flops"], 8.0) - - def testFingerprint(self): - computation = self.ExampleComputation() - executable = self.backend.compile( - xla_computation_to_mlir_module(computation)) - fingerprint = executable.fingerprint - if ( - self.backend.platform == "tpu" - or self.backend.platform == "gpu" - or self.backend.platform == "cpu" - ) and not (cloud_tpu or pathways or pathways_ifrt): - logging.info("fingerprint: %s", fingerprint) - self.assertNotEmpty(fingerprint) - else: - self.assertIsNone(fingerprint) - - tests.append(ComputationPrinting) - - class ComputationsWithConstantsTest(ComputationTest): - """Tests focusing on Constant ops.""" - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in int_dtypes + float_dtypes) - def testConstantScalarSum(self, dtype): - c = self._NewComputation() - ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14))) - self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)]) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testConstantVectorMul(self, dtype): - c = self._NewComputation() - ops.Mul( - ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)), - ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype))) - self._ExecuteAndCompareClose( - c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testConstantVectorScalarDiv(self, dtype): - c = self._NewComputation() - ops.Div( - ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)), - ops.Constant(c, dtype(2.0))) - self._ExecuteAndCompareClose( - c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testConstantVectorScalarPow(self, dtype): - c = self._NewComputation() - ops.Pow( - ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)), - ops.Constant(c, dtype(2.))) - self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) - - def testIota(self): - c = self._NewComputation() - ops.Iota(c, xla_client.PrimitiveType.F32, 10) - self._ExecuteAndCompareExact( - c, expected=[np.arange(10, dtype=np.float32)]) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in int_dtypes) - def testBroadcastedIota(self, dtype): - c = self._NewComputation() - shape = xla_client.Shape.array_shape( - xla_client.dtype_to_etype(dtype), (2, 3)) - ops.Iota(c, shape, 1) - expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype) - self._ExecuteAndCompareExact(c, expected=[expected]) - - def testBooleanAnd(self): - c = self._NewComputation() - ops.And( - ops.Constant(c, NumpyArrayBool([True, False, True, False])), - ops.Constant(c, NumpyArrayBool([True, True, False, 
False]))) - self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) - - def testBooleanOr(self): - c = self._NewComputation() - ops.Or( - ops.Constant(c, NumpyArrayBool([True, False, True, False])), - ops.Constant(c, NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) - - def testBooleanXor(self): - c = self._NewComputation() - ops.Xor( - ops.Constant(c, NumpyArrayBool([True, False, True, False])), - ops.Constant(c, NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testSum2D(self, dtype): - c = self._NewComputation() - ops.Add( - ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)), - ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype))) - self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) - - def testShiftLeft(self): - c = self._NewComputation() - ops.ShiftLeft( - ops.Constant(c, NumpyArrayS32([3])), - ops.Constant(c, NumpyArrayS32([2]))) - self._ExecuteAndCompareClose(c, expected=[[12]]) - - def testShiftRightArithmetic(self): - c = self._NewComputation() - ops.ShiftRightArithmetic( - ops.Constant(c, NumpyArrayS32([-2])), - ops.Constant(c, NumpyArrayS32([1]))) - self._ExecuteAndCompareClose(c, expected=[[-1]]) - - def testShiftRightLogical(self): - c = self._NewComputation() - ops.ShiftRightLogical( - ops.Constant(c, NumpyArrayS32([-1])), - ops.Constant(c, NumpyArrayS32([1]))) - self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testSum2DWith1DBroadcastDim0(self, dtype): - # sum of a 2D array with a 1D array where the latter is replicated across - # dimension 0 to match the former's shape. - c = self._NewComputation() - ops.Add( - ops.Constant(c, - np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], - dtype=dtype)), - ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), - broadcast_dimensions=(0,)) - self._ExecuteAndCompareClose( - c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testSum2DWith1DBroadcastDim1(self, dtype): - # sum of a 2D array with a 1D array where the latter is replicated across - # dimension 1 to match the former's shape. 
- c = self._NewComputation() - ops.Add( - ops.Constant(c, - np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], - dtype=dtype)), - ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), - broadcast_dimensions=(1,)) - self._ExecuteAndCompareClose( - c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testConstantAxpy(self, dtype): - c = self._NewComputation() - ops.Add( - ops.Mul( - ops.Constant(c, dtype(2)), - ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))), - ops.Constant(c, np.array([100, -100, 200, -200], dtype))) - self._ExecuteAndCompareClose( - c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3) - - def testCustomCall(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - c = self._NewComputation() - ops.CustomCallWithLayout( - c, - b"subtract_f32", - operands=[ - ops.Constant(c, np.float32(1.25)), - ops.Constant(c, np.float32(0.5)) - ], - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.float32), (), ()), - operand_shapes_with_layout=[ - xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), - ], - api_version=xla_client.ops.CustomCallApiVersion - .API_VERSION_TYPED_FFI) - self._ExecuteAndCompareClose(c, expected=[0.75]) - - def testCustomCallWithUnifiedApiUnknownTarget(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - c = self._NewComputation() - - ops.CustomCallWithLayout( - c, - b"not_existing", - operands=[], - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.float32), (), () - ), - operand_shapes_with_layout=[], - api_version=xla_client.ops.CustomCallApiVersion - .API_VERSION_STATUS_RETURNING_UNIFIED, - ) - with self.assertRaisesRegex( - xla_client.XlaRuntimeError, expected_regex="NOT_FOUND" - ): - self._Execute(c, arguments=()) - - def testCustomCallTypedFfiUnknownTarget(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - c = self._NewComputation() - - ops.CustomCallWithLayout( - c, - b"not_existing", - operands=[], - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.float32), (), () - ), - operand_shapes_with_layout=[], - api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, - ) - with self.assertRaises(xla_client.XlaRuntimeError): - self._Execute(c, arguments=()) - - def testCustomCallTypedFfiAlwaysFail(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - c = self._NewComputation() - - ops.CustomCallWithLayout( - c, - b"always_fail", - operands=[], - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.float32), (), () - ), - operand_shapes_with_layout=[], - api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, - ) - - with self.assertRaisesRegex( - Exception, expected_regex="Failed intentionally" - ): - self._Execute(c, arguments=()) - - def testCustomCallTypedFfiAlwaysSucceed(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - c = self._NewComputation() - - ops.CustomCallWithLayout( - c, - b"always_succeed", - operands=[], - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.float32), (), () - ), - operand_shapes_with_layout=[], - api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, - ) - - self._Execute(c, arguments=()) - - def 
testCustomCallTypedFfiSubtract(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - c = self._NewComputation() - - ops.CustomCallWithLayout( - c, - b"subtract_f32_cst", - operands=[ops.Constant(c, np.float32(1.25))], - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.float32), (), () - ), - operand_shapes_with_layout=[ - xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), - ], - opaque=b"{cst = 3.0 : f32}", - api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, - ) - self._ExecuteAndCompareClose(c, expected=[-1.75]) - - def testStatefulCustomCall(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - c = self._NewComputation() - ops.CustomCallWithLayout( - c, - b"stateful", - operands=[], - shape_with_layout=xla_client.Shape.array_shape( - np.dtype(np.int32), (), ()), - operand_shapes_with_layout=[], - api_version=xla_client.ops.CustomCallApiVersion - .API_VERSION_TYPED_FFI) - self._ExecuteAndCompareClose(c, expected=[42]) - - def testCustomCallLookup(self): - if self.backend.platform != "cpu": - self.skipTest("Test requires cpu platform") - - self.assertTrue(_CUSTOM_CALLS_REGISTERED) - xla_client.make_cpu_client() - self.assertContainsSubset( - list(custom_calls_testlib.registrations().keys()), - xla_client.custom_call_targets("Host").keys(), - ) - - tests.append(ComputationsWithConstantsTest) - - class ComputationFromProtoTest(absltest.TestCase): - """Test computation execution from HLO proto.""" - - def setUp(self): - super(ComputationFromProtoTest, self).setUp() - self.backend = xla_backend() - - def testExecuteFromProto(self): - # Build the HLO proto - b = xla_client.XlaBuilder("computation") - ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) - serialized_proto = b.build().as_serialized_hlo_module_proto() - - # Load and execute the proto - c = xla_client.XlaComputation(serialized_proto) - m = xla_computation_to_mlir_module(c) - ans, = execute_with_python_values( - self.backend.compile(m), (), backend=self.backend) - np.testing.assert_equal(ans, np.int32(3)) - - tests.append(ComputationFromProtoTest) - - class ParametersTest(ComputationTest): - """Tests focusing on Parameter ops and argument-passing.""" - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in int_dtypes) - def testScalarTimesVector(self, dtype): - c = self._NewComputation() - arg0 = np.array(3, dtype=dtype) - if np.issubdtype(dtype, np.unsignedinteger): - arg1 = np.array([10, 15, 2, 7], dtype=dtype) - else: - arg1 = np.array([10, 15, -2, 7], dtype=dtype) - p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) - p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) - ops.Mul(p0, p1) - self._ExecuteAndCompareExact( - c, arguments=[arg0, arg1], expected=[arg0 * arg1]) - - # TODO(phawkins): test comparison harness doesn't support bfloat16 - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes if dtype != bfloat16) - def testScalarMinusVectorExplicitNumbering(self, dtype): - # Use explicit numbering and pass parameter_num first. Sub is used since - # it's not commutative and can help catch parameter reversal within the - # computation. 
-      c = self._NewComputation()
-      arg0 = np.array(2.0, dtype=dtype)
-      arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype)
-      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
-      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
-      ops.Sub(p1, p0)
-      self._ExecuteAndCompareClose(
-          c, arguments=[arg0, arg1], expected=[arg1 - arg0])
-
-  tests.append(ParametersTest)
-
-  class LayoutsTest(ComputationTest):
-    """Tests related to getting and setting on-device memory layouts."""
-
-    def _minor_to_major(self, layout: xla_client.PjRtLayout):  # pylint: disable=invalid-name
-      m2m_str = re.search("{([0-9,]*)", str(layout)).group(1)
-      if not m2m_str:
-        return ()
-      return tuple(int(x) for x in m2m_str.split(","))
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testGetArgumentLayouts(self):
-      # Create computation with a few parameters.
-      c = self._NewComputation()
-      param_count = 0
-
-      def MakeArg(shape, dtype):
-        nonlocal param_count
-        shape = xla_client.Shape.array_shape(np.dtype(dtype), shape)
-        param = ops.Parameter(c, param_count, shape)
-        param_count += 1
-        return param
-
-      p0 = MakeArg((2, 3, 4), np.float32)
-      MakeArg((3, 2), np.int32)
-      MakeArg((), np.float64)
-
-      ops.Add(p0, ops.Constant(c, np.ones((2, 3, 4), np.float32)))
-      executable = self.backend.compile(
-          xla_computation_to_mlir_module(c.build()))
-
-      # Test that compiled executable returns plausible layouts.
-      layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts()
-      self.assertLen(layouts, 3)
-      self.assertLen(self._minor_to_major(layouts[0]), 3)
-      self.assertLen(self._minor_to_major(layouts[1]), 2)
-      self.assertEmpty(self._minor_to_major(layouts[2]))
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testGetArgumentLayoutsTupled(self):
-      # Generated with:
-      # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)),
-      #                                    np.int32(42),
-      #                                    np.ones(10))
-      module_str = """
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
-                                 mhlo.num_replicas = 1 : i32} {
-  func.func public @main(
-      %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"},
-      %arg1: tensor<i32> {mhlo.sharding = "{replicated}"},
-      %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"})
-      -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"},
-          tensor<i32> {jax.result_info = "[1]"},
-          tensor<10xf32> {jax.result_info = "[2]"}) {
-    return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor<i32>, tensor<10xf32>
-  }
-}
-"""
-      options = xla_client.CompileOptions()
-      # 'parameter_is_tupled_arguments' causes MLIR untupled arguments to get
-      # turned into HLO tupled arguments.
-      options.parameter_is_tupled_arguments = True
-      executable = self.backend.compile(module_str, compile_options=options)
-
-      # Test that compiled executable returns plausible layouts.
-      layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts()
-      self.assertLen(layouts, 3)
-      self.assertLen(self._minor_to_major(layouts[0]), 3)
-      self.assertEmpty(self._minor_to_major(layouts[1]))
-      self.assertLen(self._minor_to_major(layouts[2]), 1)
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testGetOutputLayouts(self):
-      # Generated with jax.jit(lambda: (np.ones((1024, 128)), np.int32(42),
-      #                                 np.ones(10)))()
-      module_str = """
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
-                                 mhlo.num_replicas = 1 : i32} {
-  func.func public @main() -> (tensor<1024x128xf32> {jax.result_info = "[0]"},
-                               tensor<i32> {jax.result_info = "[1]"},
-                               tensor<10xf32> {jax.result_info = "[2]"}) {
-    %0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x128xf32>
-    %1 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32>
-    %2 = stablehlo.constant dense<42> : tensor<i32>
-    return %0, %2, %1 : tensor<1024x128xf32>, tensor<i32>, tensor<10xf32>
-  }
-}
-"""
-      executable = self.backend.compile(module_str)
-
-      # Test that compiled executable returns plausible layouts.
-      layouts: Sequence[xla_client.Layout] = executable.get_output_layouts()
-      self.assertLen(layouts, 3)
-      self.assertLen(self._minor_to_major(layouts[0]), 2)
-      self.assertEmpty(self._minor_to_major(layouts[1]))
-      self.assertLen(self._minor_to_major(layouts[2]), 1)
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testSetArgumentLayouts(self):
-      # TODO(b/309682374): implement on CPU and GPU
-      if self.backend.platform != "tpu":
-        raise self.skipTest("mhlo.layout_mode only implemented on TPU")
-
-      # Hand-edited version of:
-      # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)),
-      #                                    np.int32(42),
-      #                                    np.ones(10))
-      module_str = """
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
-                                 mhlo.num_replicas = 1 : i32} {
-  func.func public @main(
-      %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}",
-                                     mhlo.layout_mode = "{0,1,2}"},
-      %arg1: tensor<i32> {mhlo.sharding = "{replicated}",
                          mhlo.layout_mode = "{}"},
-      %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}",
-                             mhlo.layout_mode = "{0}"})
-      -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"},
-          tensor<i32> {jax.result_info = "[1]"},
-          tensor<10xf32> {jax.result_info = "[2]"}) {
-    return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor<i32>, tensor<10xf32>
-  }
-}
-  """
-      executable = self.backend.compile(module_str)
-
-      # Check input layouts.
-      input_layouts = executable.get_parameter_layouts()
-      self.assertLen(input_layouts, 3)
-      self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1, 2))
-      self.assertEqual(self._minor_to_major(input_layouts[1]), ())
-      self.assertEqual(self._minor_to_major(input_layouts[2]), (0,))
-
-      # Compile a version with default arg0 layout so we can make sure we
-      # actually set it above.
-      default_executable = self.backend.compile(
-          module_str.replace('"{0,1,2}"', '"default"')
-      )
-      self.assertNotEqual(
-          self._minor_to_major(input_layouts[0]),
-          self._minor_to_major(default_executable.get_parameter_layouts()[0]),
-      )
-
-    @unittest.skipIf(pathways or pathways_ifrt, "not implemented")
-    def testSetArgumentLayoutsLegacy(self):
-      """Tests setting the arg layouts with compile_options (deprecated).
-
-      New code should use the mhlo.layout_mode string attr on parameters.
-      """
-      # Create computation with custom input layouts.
-      c = self._NewComputation()
-      param_count = 0
-
-      def MakeArg(shape, dtype, layout):
-        nonlocal param_count
-        arr = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
-        param = ops.Parameter(c, param_count,
-                              xla_client.shape_from_pyval(arr, layout))
-        param_count += 1
-        shape = xla_client.Shape.array_shape(np.dtype(dtype), shape, layout)
-        return arr, param, shape
-
-      arg0, p0, shape0 = MakeArg((2, 3, 4), np.float32, (1, 2, 0))
-      arg1, p1, shape1 = MakeArg((3, 2), np.int32, (0, 1))
-      arg2, p2, shape2 = MakeArg((), np.float64, ())
-
-      ops.Tuple(c, [
-          ops.Add(p0, ops.Constant(c, np.ones(arg0.shape, arg0.dtype))),
-          ops.Add(p1, ops.Constant(c, np.ones(arg1.shape, arg1.dtype))),
-          ops.Add(p2, ops.Constant(c, np.ones(arg2.shape, arg2.dtype))),
-      ])
-
-      # We also need to set the input layouts in the compile options.
-      options = xla_client.CompileOptions()
-      options.argument_layouts = [shape0, shape1, shape2]
-      executable = self.backend.compile(
-          xla_computation_to_mlir_module(c.build()), compile_options=options)
-
-      # Test that compiled executable has expected layouts.
-      expected_layouts: Sequence[xla_client.Shape] = [shape0, shape1, shape2]
-      actual_layouts: Sequence[xla_client.Layout] = (
-          executable.get_parameter_layouts())
-      self.assertEqual(len(actual_layouts), len(expected_layouts))
-      for actual, expected in zip(actual_layouts, expected_layouts):
-        self.assertEqual(
-            self._minor_to_major(actual),
-            expected.layout().minor_to_major(),
-        )
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testSetOutputLayouts(self):
-      # TODO(b/309682374): implement on CPU and GPU
-      if self.backend.platform != "tpu":
-        raise self.skipTest("mhlo.layout_mode only implemented on TPU")
-
-      # Hand-edited version of:
-      # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)),
-      #                                    np.int32(42),
-      #                                    np.ones(10))
-      module_str = """
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
-                                 mhlo.num_replicas = 1 : i32} {
-  func.func public @main(
-      %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"},
-      %arg1: tensor<i32> {mhlo.sharding = "{replicated}"},
-      %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"})
-      -> (tensor<1024x8x128xf32> {jax.result_info = "[0]",
-                                  mhlo.layout_mode = "{0,1,2}"},
-          tensor<i32> {jax.result_info = "[1]",
-                       mhlo.layout_mode = "{}"},
-          tensor<10xf32> {jax.result_info = "[2]",
-                          mhlo.layout_mode = "{0}"}) {
-    return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor<i32>, tensor<10xf32>
-  }
-}
-  """
-      executable = self.backend.compile(module_str)
-
-      # Check output layouts.
-      output_layouts = executable.get_output_layouts()
-      self.assertLen(output_layouts, 3)
-      self.assertEqual(self._minor_to_major(output_layouts[0]), (0, 1, 2))
-      self.assertEqual(self._minor_to_major(output_layouts[1]), ())
-      self.assertEqual(self._minor_to_major(output_layouts[2]), (0,))
-
-      # Compile a version with default first output layout so we can make sure
-      # we actually set it above.
-      default_executable = self.backend.compile(
-          module_str.replace('"{0,1,2}"', '"default"')
-      )
-      self.assertNotEqual(
-          self._minor_to_major(output_layouts[0]),
-          self._minor_to_major(default_executable.get_output_layouts()[0]),
-      )
-
-    @unittest.skipIf(pathways, "not implemented")
-    def SetLayoutsSharded(self):
-      # TODO(b/309682374): implement on CPU and GPU
-      if self.backend.platform != "tpu":
-        raise self.skipTest("mhlo.layout_mode only implemented on TPU")
-
-      # This also lightly tests mixed default + user-specified input layouts.
-      module_str = """
-module @jit__lambda_ attributes {mhlo.num_partitions = 8 : i32,
-                                 mhlo.num_replicas = 1 : i32} {
-  func.func public @main(
-      %arg0: tensor<1024x128xf32> {mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}",
-                                   mhlo.layout_mode = "{0,1}"},
-      %arg1: tensor<i32> {mhlo.sharding = "{replicated}"})
-      -> (tensor<1024x128xf32> {jax.result_info = "",
-                                mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}",
-                                mhlo.layout_mode = "{0,1}"}) {
-    %0 = stablehlo.convert %arg1 : (tensor<i32>) -> tensor<f32>
-    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1024x128xf32>
-    %2 = stablehlo.add %arg0, %1 : tensor<1024x128xf32>
-    return %2 : tensor<1024x128xf32>
-  }
-}
-  """
-      executable = self.backend.compile(module_str)
-
-      # Check input layouts.
-      input_layouts = executable.get_parameter_layouts()
-      self.assertLen(input_layouts, 2)
-      self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1))
-      self.assertEqual(self._minor_to_major(input_layouts[1]), ())
-
-      # Check output layout.
-      output_layouts = executable.get_output_layouts()
-      self.assertLen(output_layouts, 1)
-      self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1))
-
-      # Compile a version with default layouts so we can make sure we actually
-      # set it above.
-      default_executable = self.backend.compile(
-          module_str.replace('"{0,1}"', '"default"')
-      )
-      self.assertNotEqual(
-          self._minor_to_major(input_layouts[0]),
-          self._minor_to_major(default_executable.get_parameter_layouts()[0]),
-      )
-      self.assertNotEqual(
-          self._minor_to_major(output_layouts[0]),
-          self._minor_to_major(default_executable.get_output_layouts()[0]),
-      )
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testAutoArgumentLayouts(self):
-      # TODO(b/309682374): implement on CPU and GPU
-      if self.backend.platform != "tpu":
-        raise self.skipTest("mhlo.layout_mode only implemented on TPU")
-
-      # Hand-edited version of:
-      # jax.numpy.einsum("...a,ahd->...hd", ...)
-      module_str = """
-module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
-                                 mhlo.num_replicas = 1 : i32} {
-  func.func public @main(
-      %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}",
-                                    mhlo.layout_mode = "auto"},
-      %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}",
-                                     mhlo.layout_mode = "auto"})
-      -> (tensor<1024x8x128xf32> {jax.result_info = ""}) {
-    %0 = stablehlo.dot_general %arg0, %arg1,
-        contracting_dims = [1] x [0],
-        precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>,
-                                          tensor<1024x8x128xf32>)
-        -> tensor<1024x8x128xf32>
-    return %0 : tensor<1024x8x128xf32>
-  }
-}
-"""
-      executable = self.backend.compile(module_str)
-
-      # Check input layouts.
-      input_layouts = executable.get_parameter_layouts()
-      self.assertEqual(self._minor_to_major(input_layouts[0]), (1, 0))
-      self.assertEqual(self._minor_to_major(input_layouts[1]), (2, 0, 1))
-
-      # Compile a version with default layouts so we can make sure the compiler
-      # is actually choosing above.
-      default_executable = self.backend.compile(
-          module_str.replace('"auto"', '"default"')
-      )
-      # We expect the compiler to choose a non-default layout for the second
-      # (1024,8,128) argument.
-      self.assertNotEqual(
-          self._minor_to_major(input_layouts[1]),
-          self._minor_to_major(default_executable.get_parameter_layouts()[1]),
-      )
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testAutoOutputLayouts(self):
-      # TODO(b/309682374): implement on CPU and GPU
-      if self.backend.platform != "tpu":
-        raise self.skipTest("mhlo.layout_mode only implemented on TPU")
-
-      # Generated with jax.numpy.einsum("...a,ahd->...hd", ...)
- module_str = """ -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, - mhlo.num_replicas = 1 : i32} { - func.func public @main( - %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}"}, - %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}) - -> (tensor<1024x8x128xf32> {jax.result_info = "", - mhlo.layout_mode = "auto"}) { - %0 = stablehlo.dot_general %arg0, %arg1, - contracting_dims = [1] x [0], - precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, - tensor<1024x8x128xf32>) - -> tensor<1024x8x128xf32> - return %0 : tensor<1024x8x128xf32> - } -} -""" - executable = self.backend.compile(module_str) - - # Check output layout - output_layout, = executable.get_output_layouts() - self.assertEqual(self._minor_to_major(output_layout), (2, 0, 1)) - - # Compile a version with default layouts so we can make sure the compiler - # is actually choosing above. - default_executable = self.backend.compile( - module_str.replace('"auto"', '"default"') - ) - # We expect the compiler to choose a non-default output layout. - self.assertNotEqual( - self._minor_to_major(output_layout), - self._minor_to_major(default_executable.get_output_layouts()[0]), - ) - - tests.append(LayoutsTest) - - class BufferTest(ComputationTest): - """Tests focusing on execution with Buffers.""" - - def testConstantSum(self): - c = self._NewComputation() - ops.Add( - ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14))) - self._ExecuteAndCompareClose(c, expected=[4.25]) - - def testOneParameterSum(self): - c = self._NewComputation() - ops.Add( - ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), - ops.Constant(c, np.float32(3.14))) - self._ExecuteAndCompareClose( - c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) - - def testTwoParameterSum(self): - c = self._NewComputation() - ops.Add( - ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), - ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.)))) - self._ExecuteAndCompareClose( - c, - arguments=[NumpyArrayF32(1.11), - NumpyArrayF32(3.14)], - expected=[4.25]) - - @unittest.skipIf(cloud_tpu or pathways, "not implemented") - def testCannotCallWithDeletedBuffers(self): - c = self._NewComputation() - ops.Add( - ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), - ops.Constant(c, np.float32(3.14))) - arg = NumpyArrayF32(1.11) - compiled_c = self.backend.compile( - xla_computation_to_mlir_module(c.build())) - arg_buffer = self.backend.buffer_from_pyval(arg) - arg_buffer.delete() - with self.assertRaises(xla_client.XlaRuntimeError): - compiled_c.execute([arg_buffer]) - - def testXlaShapeIndex(self): - a = xla_client.ShapeIndex((1, 2)) - b = xla_client.ShapeIndex((1, 2)) - c = xla_client.ShapeIndex((2, 3)) - self.assertEqual(a, b) - self.assertNotEqual(b, c) - - def testLayout(self): - f32 = xla_client.PrimitiveType.F32 - a = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() - b = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() - c = xla_client.Shape.array_shape(f32, (2, 3), (1, 0)).layout() - self.assertEqual(a.minor_to_major(), (0, 1)) - self.assertEqual(b.minor_to_major(), (0, 1)) - self.assertEqual(c.minor_to_major(), (1, 0)) - self.assertEqual(a, b) - self.assertNotEqual(a, c) - self.assertNotEqual(b, c) - self.assertEqual(hash(a), hash(b)) - self.assertNotEqual(hash(a), hash(c)) - self.assertNotEqual(hash(b), hash(c)) - - def testBlockUntilReadyWorks(self): - arg = np.array([[1., 2.]], np.float32) - arg_buffer = 
self.backend.buffer_from_pyval(arg) - arg_buffer.block_until_ready() - # This test merely checks that nothing goes awry when we call - # block_until_ready(); it's difficult to test anything else. - - def testBlockUntilReadyRaisesOnDeletedBuffer(self): - arg = np.array([[1., 2.]], np.float32) - buffer = self.backend.buffer_from_pyval(arg) - buffer.delete() - with self.assertRaisesRegex( - RuntimeError, - re.escape( - "BlockHostUntilReady() called on deleted or donated buffer")): - buffer.block_until_ready() - - @unittest.skipIf(pathways_ifrt, "not implemented") - def testOnDeviceSizeInBytes(self): - if not isinstance(self.backend, xla_client.Client): - self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.") - arg0 = np.array([]) - arg1 = np.array([[0., 1., 2.]], np.float32) - arg2 = np.array([[3., 4., 5.]], bfloat16) - arg0_buffer = self.backend.buffer_from_pyval(arg0) - arg1_buffer = self.backend.buffer_from_pyval(arg1) - arg2_buffer = self.backend.buffer_from_pyval(arg2) - self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0) - # OnDeviceSizeInBytes varies depending on the platform. Confirm there's - # a reasonable value. - self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0) - self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0) - - def testLiveBuffers(self): - if not isinstance(self.backend, xla_client.Client): - self.skipTest("TPU Driver doesn't support LiveBuffers().") - self.assertEmpty(self.backend.live_buffers()) - arg0 = np.array([]) - arg1 = np.array([[0., 1., 2.]], np.float32) - arg2 = np.array([[3., 4., 5.]], bfloat16) - arg0_buffer = self.backend.buffer_from_pyval(arg0) - arg1_buffer = self.backend.buffer_from_pyval(arg1) - arg2_buffer = self.backend.buffer_from_pyval(arg2) - self.assertLen(self.backend.live_buffers(), 3) - self.assertIs(self.backend.live_buffers()[0], arg2_buffer) - self.assertIs(self.backend.live_buffers()[1], arg1_buffer) - self.assertIs(self.backend.live_buffers()[2], arg0_buffer) - - arg1_buffer.delete() - self.assertLen(self.backend.live_buffers(), 2) - self.assertIs(self.backend.live_buffers()[0], arg2_buffer) - self.assertIs(self.backend.live_buffers()[1], arg0_buffer) - - arg0_buffer.delete() - arg2_buffer.delete() - self.assertEmpty(self.backend.live_buffers()) - - def testCopyToHost(self): - arg0 = np.array([[1., 2.]], np.float32) - arg1 = np.array([[3., 4.]], np.float32) - arg0_buffer = self.backend.buffer_from_pyval(arg0) - arg1_buffer = self.backend.buffer_from_pyval(arg1) - # Prefetch two buffers using copy_to_host_async, and then retrieve their - # values using np.asarray(). - arg0_buffer.copy_to_host_async() - arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. - arg1_buffer.copy_to_host_async() - np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) - np.testing.assert_equal(arg1, np.asarray(arg1_buffer)) - # copy_to_host_async does nothing after np.asarray() is called. - arg0_buffer.copy_to_host_async() - np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) - - def testDevice(self): - x = np.arange(8, dtype=np.int32) - for device in self.backend.local_devices(): - buf = self.backend.buffer_from_pyval(x, device=device) - self.assertEqual(buf.device(), device) - np.testing.assert_equal(x, np.asarray(buf)) - - def testStandardTypes(self): - for dtype in standard_dtypes: - if dtype == np.complex128: - continue - # float8_e8m0fnu is not supported on TPU. - if dtype == float8_e8m0fnu and self.backend.platform == "tpu": - continue - # float8_e4m3b11fnuz not supported on some TPU backends. 
- if ( - dtype - in [ - float8_e3m4, - float8_e4m3, - float8_e4m3fnuz, - float8_e4m3b11fnuz, - float8_e5m2fnuz, - ] - and self.backend.platform == "tpu" - ): - if self.backend.platform_version.find("TPU") == -1: - continue - arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype)) - arr = np.asarray(arr) - self.assertEqual(dtype, type(arr[0])) - - @unittest.skipIf(pathways_ifrt, "not implemented") - def testUnsafeBufferPointer(self): - if not isinstance(self.backend, xla_client.Client): - self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().") - arg0 = np.array([]) - arg1 = np.array([[0., 1., 2.]], np.float32) - arg2 = np.array([[3., 4., 5.]], bfloat16) - arg0_buffer = self.backend.buffer_from_pyval(arg0) - arg1_buffer = self.backend.buffer_from_pyval(arg1) - arg2_buffer = self.backend.buffer_from_pyval(arg2) - self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0) - self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0) - self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0) - - @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt, "not implemented") - def testClone(self): - x = np.array([[3., 4., 5.]], np.float32) - y = self.backend.buffer_from_pyval(x) - z = y.clone() - self.assertNotEqual(id(x), id(y)) - np.testing.assert_array_equal(np.asarray(y), np.asarray(z)) - self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer()) - - tests.append(BufferTest) - - class SingleOpTest(ComputationTest): - """Tests for single ops. - - The goal here is smoke testing - to exercise the most basic functionality of - single XLA ops. As minimal as possible number of additional ops are added - around the op being tested. - """ - - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in float_dtypes) - def testConcatenate(self, dtype): - c = self._NewComputation() - args = ( - ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)), - ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)), - ) - ops.ConcatInDim(c, args, dimension=0) - self._ExecuteAndCompareExact( - c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)]) - - # pyformat: disable - @parameterized.named_parameters({ - "testcase_name": "_{}_{}".format(src_dtype.__name__, - dst_dtype.__name__), - "src_dtype": src_dtype, - "dst_dtype": dst_dtype, - } for src_dtype, dst_dtype in itertools.permutations( - [np.bool_, np.int32, np.int64, np.float32, np.float64], 2)) - # pyformat: enable - def testConvertElementType(self, src_dtype, dst_dtype): - if ((src_dtype in [np.int64, np.float64] or - dst_dtype in [np.int64, np.float64]) and - self.backend.platform == "tpu"): - self.skipTest("TPU doesn't support float64") - c = self._NewComputation() - x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) - ops.ConvertElementType( - ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) - - result = execute_with_python_values( - self.backend.compile(xla_computation_to_mlir_module(c.build())), (), - backend=self.backend) - self.assertLen(result, 1) - expected = np.array(x, dtype=dst_dtype) - - self.assertEqual(result[0].shape, expected.shape) - self.assertEqual(result[0].dtype, expected.dtype) - np.testing.assert_equal(result[0], expected) - - # pyformat: disable - @parameterized.named_parameters( - { - "testcase_name": "_{}_{}".format(src_dtype.__name__, - dst_dtype.__name__), - "src_dtype": src_dtype, - "dst_dtype": dst_dtype, - } - for dtypes in [[np.int32, np.float32], [np.int64, np.float64]] - for src_dtype, dst_dtype in 
itertools.permutations(dtypes, 2))
-    # pyformat: enable
-    def testBitcastConvertType(self, src_dtype, dst_dtype):
-      if (np.float64 in (src_dtype, dst_dtype) and
-          self.backend.platform == "tpu"):
-        self.skipTest("TPU doesn't support float64")
-      c = self._NewComputation()
-      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
-      ops.BitcastConvertType(
-          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
-
-      result = execute_with_python_values(
-          self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
-          backend=self.backend)
-      self.assertLen(result, 1)
-      expected = x.view(dst_dtype)
-
-      self.assertEqual(result[0].shape, expected.shape)
-      self.assertEqual(result[0].dtype, expected.dtype)
-      np.testing.assert_equal(result[0], expected)
-
-    # TODO(b/123523486) implement AllToAll on CPU
-    def DISABLED_testAllToAllOneReplica(self):
-      samples = [
-          NumpyArrayF32([97.0]),
-          NumpyArrayF32([64.0, 117.0]),
-          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
-      ]
-      for lhs in samples[:1]:
-        c = self._NewComputation()
-        ops.AllToAll(ops.Constant(c, lhs), 0, 0)
-        self._ExecuteAndCompareExact(c, expected=[lhs])
-
-    def testCrossReplicaSumOneReplica(self):
-      samples = [
-          NumpyArrayF32(42.0),
-          NumpyArrayF32([97.0]),
-          NumpyArrayF32([64.0, 117.0]),
-          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
-      ]
-      for lhs in samples:
-        c = self._NewComputation()
-        ops.CrossReplicaSum(ops.Constant(c, lhs))
-        self._ExecuteAndCompareExact(c, expected=[lhs])
-
-    def testReplicaId(self):
-      c = self._NewComputation()
-      _ = ops.ReplicaId(c)
-      self._ExecuteAndCompareExact(c, expected=[0])
-
-    def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
-      samples = [
-          NumpyArrayF32(42.0),
-          NumpyArrayF32([97.0]),
-          NumpyArrayF32([64.0, 117.0]),
-          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
-      ]
-      for lhs in samples:
-        c = self._NewComputation()
-        ops.CrossReplicaSum(
-            ops.Constant(c, lhs), xla_client.make_replica_groups([[0]]))
-        self._ExecuteAndCompareExact(c, expected=[lhs])
-
-    # TODO(phawkins): np.dot implementation doesn't support bfloat16
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes if dtype != bfloat16)
-    def testDotMatrixVector(self, dtype):
-      c = self._NewComputation()
-      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
-      rhs = np.array([[10.0], [20.0]], dtype=dtype)
-      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
-      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])
-
-    # TODO(phawkins): np.dot implementation doesn't support bfloat16
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes if dtype != bfloat16)
-    def testDotMatrixMatrix(self, dtype):
-      c = self._NewComputation()
-      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
-      rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype)
-      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
-      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])
-
-    def testDotGeneral(self):
-      c = self._NewComputation()
-      rng = np.random.RandomState(0)
-      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
-      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
-      dimension_numbers = xla_client.make_dot_dimension_numbers(
-          (([2], [1]), ([0], [0])))
-      ops.DotGeneral(
-          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
-      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6)
-
-    def testDotGeneralWithDotDimensionNumbersProto(self):
-      c = self._NewComputation()
-      rng = np.random.RandomState(0)
-      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
-      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
-
-      dimension_numbers = xla_client.DotDimensionNumbers()
-      dimension_numbers.lhs_contracting_dimensions.append(2)
-      dimension_numbers.rhs_contracting_dimensions.append(1)
-      dimension_numbers.lhs_batch_dimensions.append(0)
-      dimension_numbers.rhs_batch_dimensions.append(0)
-
-      ops.DotGeneral(
-          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
-      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6)
-
-    def testDotGeneralWithPrecisionConfig(self):
-      c = self._NewComputation()
-      rng = np.random.RandomState(0)
-      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
-      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
-      dimension_numbers = xla_client.make_dot_dimension_numbers(
-          (([2], [1]), ([0], [0])))
-      config = xla_client.PrecisionConfig()
-      config.operand_precision.append(config.Precision.HIGH)
-      config.operand_precision.append(config.Precision.HIGHEST)
-      ops.DotGeneral(
-          ops.Constant(c, lhs),
-          ops.Constant(c, rhs),
-          dimension_numbers,
-          precision_config=config)
-      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6)
-
-    def testConvGeneralDilatedF32(self):
-      c = self._NewComputation()
-      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
-      lhs = a(1, 1, 2, 3)
-      rhs = a(1, 1, 1, 2) * 10
-      strides = [1, 1]
-      pads = [(1, 0), (0, 1)]
-      lhs_dilation = (2, 1)
-      rhs_dilation = (1, 1)
-      dimension_numbers = xla_client.make_convolution_dimension_numbers(
-          ("NCHW", "OIHW", "NCHW"), 2)
-      ops.ConvGeneralDilated(
-          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
-          lhs_dilation, rhs_dilation, dimension_numbers)
-      result = np.array([[[
-          [0., 0., 0.],
-          [10., 20., 0.],
-          [0., 0., 0.],
-          [40., 50., 0.],
-      ]]])
-      self._ExecuteAndCompareClose(c, expected=[result])
-
-    def testConvGeneralDilatedF32WithPrecisionConfig(self):
-      c = self._NewComputation()
-      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
-      lhs = a(1, 1, 2, 3)
-      rhs = a(1, 1, 1, 2) * 10
-      strides = [1, 1]
-      pads = [(1, 0), (0, 1)]
-      lhs_dilation = (2, 1)
-      rhs_dilation = (1, 1)
-      dimension_numbers = xla_client.make_convolution_dimension_numbers(
-          ("NCHW", "OIHW", "NCHW"), 2)
-      config = xla_client.PrecisionConfig()
-      config.operand_precision.append(config.Precision.HIGHEST)
-      config.operand_precision.append(config.Precision.DEFAULT)
-      ops.ConvGeneralDilated(
-          ops.Constant(c, lhs),
-          ops.Constant(c, rhs),
-          strides,
-          pads,
-          lhs_dilation,
-          rhs_dilation,
-          dimension_numbers,
-          precision_config=config)
-      result = np.array([[[
-          [0., 0., 0.],
-          [10., 20., 0.],
-          [0., 0., 0.],
-          [40., 50., 0.],
-      ]]])
-      self._ExecuteAndCompareClose(c, expected=[result])
-
-    def testConvGeneralDilatedPermutedF32(self):
-      c = self._NewComputation()
-      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
-      lhs = a(1, 1, 2, 3)
-      rhs = a(1, 1, 1, 2) * 10
-      strides = [1, 1]
-      pads = [(1, 0), (0, 1)]
-      lhs_dilation = (2, 1)
-      rhs_dilation = (1, 1)
-
-      dimension_numbers = xla_client.make_convolution_dimension_numbers(
-          ("NHWC", "OIHW", "CWNH"), 2)
-      ops.ConvGeneralDilated(
-          ops.Constant(c, np.transpose(lhs,
-                                       (0, 2, 3, 1))), ops.Constant(c, rhs),
-          strides, pads, lhs_dilation, rhs_dilation, dimension_numbers)
-      result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
-                           [40., 50., 0.]]]])
-      self._ExecuteAndCompareClose(
-          c, expected=[np.transpose(result, (1, 3, 0, 2))])
-
-    def testConvGeneralDilatedGroupedConvolutionF32(self):
-      c = self._NewComputation()
-      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
-      lhs = a(1, 2, 2, 3)
-      rhs = a(2, 1, 1, 2) * 10
-      strides = [1, 1]
-      pads = [(1, 0), (0, 1)]
-      lhs_dilation = (2, 1)
-      rhs_dilation = (1, 1)
-      dimension_numbers = xla_client.make_convolution_dimension_numbers(
-          ("NCHW", "OIHW", "NCHW"), 2)
-      feature_group_count = 2
-      ops.ConvGeneralDilated(
-          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
-          lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count)
-      result = np.array([[[
-          [0., 0., 0.],
-          [10., 20., 0.],
-          [0., 0., 0.],
-          [40., 50., 0.],
-      ], [
-          [0., 0., 0.],
-          [330., 380., 160.],
-          [0., 0., 0.],
-          [480., 530., 220.],
-      ]]])
-      self._ExecuteAndCompareClose(c, expected=[result])
-
-    def testConvGeneralDilatedWindowReversalF32(self):
-      c = self._NewComputation()
-      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
-      lhs = a(1, 1, 2, 3)
-      rhs = a(1, 1, 1, 2) * 10
-      strides = [1, 1]
-      pads = [(1, 0), (0, 1)]
-      lhs_dilation = (2, 1)
-      rhs_dilation = (1, 1)
-      window_reversal = [False, True]
-      dimension_numbers = xla_client.make_convolution_dimension_numbers(
-          ("NCHW", "OIHW", "NCHW"), 2)
-      ops.ConvGeneralDilated(
-          ops.Constant(c, lhs),
-          ops.Constant(c, rhs),
-          strides,
-          pads,
-          lhs_dilation,
-          rhs_dilation,
-          dimension_numbers,
-          window_reversal=window_reversal)
-      result = np.array([[[
-          [0., 0., 0.],
-          [0., 10., 20.],
-          [0., 0., 0.],
-          [30., 40., 50.],
-      ]]])
-      self._ExecuteAndCompareClose(c, expected=[result])
-
-    def testBooleanNot(self):
-      c = self._NewComputation()
-      arr = NumpyArrayBool([True, False, True])
-      ops.Not(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[~arr])
-
-    def testPopulationCount(self):
-      c = self._NewComputation()
-      arr = NumpyArrayS32([3, 0, 1])
-      ops.PopulationCount(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])])
-
-    def testCountLeadingZeros(self):
-      c = self._NewComputation()
-      arr = NumpyArrayS32([0x7FFF, 0x12345678])
-      ops.Clz(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[[17, 3]])
-
-    def testExp(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Exp(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.exp(arr)])
-
-    def testExpWithResultAccuracy(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      accuracy = xla_client.ResultAccuracy()
-      accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT
-      ops.Exp(ops.Constant(c, arr), accuracy)
-      self._ExecuteAndCompareClose(c, expected=[np.exp(arr)])
-
-    def testExpm1(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Expm1(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)])
-
-    def testExpm1WithResultAccuracy(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      accuracy = xla_client.ResultAccuracy()
-      accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT
-      ops.Expm1(ops.Constant(c, arr), accuracy)
-      self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)])
-
-    def testRound(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Round(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.round(arr)])
-
-    def testLog(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Log(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.log(arr)])
-
-    def testLog1p(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Log1p(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)])
-
-    def testNeg(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Neg(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[-arr])
-
-    def testFloor(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Floor(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.floor(arr)])
-
-    def testCeil(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, 12.1])
-      ops.Ceil(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)])
-
-    def testAbs(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
-      ops.Abs(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.abs(arr)])
-
-    def testTanF32(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001])
-      ops.Tan(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.tan(arr)])
-
-    def testTanhF32(self):
-      c = self._NewComputation()
-      arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001])
-      ops.Tanh(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)])
-
-    def testTanhF64(self):
-      if self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support 64bit tanh")
-      c = self._NewComputation()
-      arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001])
-      ops.Tanh(ops.Constant(c, arr))
-      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12)
-
-    def testTranspose(self):
-
-      def _TransposeAndTest(array, permutation):
-        c = self._NewComputation()
-        ops.Transpose(ops.Constant(c, array), permutation)
-        expected = np.transpose(array, permutation)
-        self._ExecuteAndCompareClose(c, expected=[expected])
-
-      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
-      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
-      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
-      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
-
-      arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
-      for permutation in itertools.permutations(range(arr.ndim)):
-        _TransposeAndTest(arr, permutation)
-        _TransposeAndTest(np.asfortranarray(arr), permutation)
-
-    def testEq(self):
-      c = self._NewComputation()
-      ops.Eq(
-          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
-          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
-      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])
-
-    def testNe(self):
-      c = self._NewComputation()
-      ops.Ne(
-          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
-          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
-      self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]])
-
-      ops.Ne(
-          ops.Constant(c, NumpyArrayF32([-2.0, 0.0,
-                                         float("nan"),
-                                         float("nan")])),
-          ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0,
-                                         float("nan")])))
-      self._ExecuteAndAssertWith(
-          np.testing.assert_allclose,
-          c, (),
-          expected=[[True, False, True, True]])
-
-    def testGt(self):
-      c = self._NewComputation()
-      ops.Gt(
-          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
-          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
-      self._ExecuteAndCompareExact(
-          c, expected=[[False, True, True, False, False]])
-
-    def testGe(self):
-      c = self._NewComputation()
-      ops.Ge(
-          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
-          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
-      self._ExecuteAndCompareExact(
-          c, expected=[[True, True, True, False, False]])
-
-    def testLt(self):
-      c = self._NewComputation()
-      ops.Lt(
-          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
-          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
-      self._ExecuteAndCompareExact(
-          c, expected=[[False, False, False, True, True]])
-
-    def testLe(self):
-      c = self._NewComputation()
-      ops.Le(
-          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
-          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
-      self._ExecuteAndCompareExact(
-          c, expected=[[True, False, False, True, True]])
-
-    def testMax(self):
-      c = self._NewComputation()
-      ops.Max(
-          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
-          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
-      self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]])
-
-    def testMaxExplicitBroadcastDim0(self):
-      c = self._NewComputation()
-      ops.Max(
-          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
-          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
-          broadcast_dimensions=(0,))
-      self._ExecuteAndCompareExact(
-          c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]])
-
-    def testMaxExplicitBroadcastDim1(self):
-      c = self._NewComputation()
-      ops.Max(
-          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
-          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
-          broadcast_dimensions=(1,))
-      self._ExecuteAndCompareExact(
-          c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]])
-
-    def testMin(self):
-      c = self._NewComputation()
-      ops.Min(
-          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
-          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
-      self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]])
-
-    def testPad(self):
-      c = self._NewComputation()
-      ops.Pad(
-          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
-          ops.Constant(c, NumpyArrayF32(0.0)),
-          xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)]))
-      self._ExecuteAndCompareClose(
-          c,
-          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
-                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
-
-    def testPadWithPaddingConfig(self):
-      c = self._NewComputation()
-      padding_config = xla_client.PaddingConfig()
-      for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
-        dimension = xla_client.PaddingConfigDimension()
-        dimension.edge_padding_low = lo
-        dimension.edge_padding_high = hi
-        dimension.interior_padding = interior
-        padding_config.dimensions.append(dimension)
-      ops.Pad(
-          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
-          ops.Constant(c, NumpyArrayF32(0.0)), padding_config)
-      self._ExecuteAndCompareClose(
-          c,
-          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
-                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
-
-    def testReshape(self):
-      c = self._NewComputation()
-      ops.Reshape(
-          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
-          new_sizes=[2, 3])
-      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]])
-
-    def testCollapse(self):
-      c = self._NewComputation()
-      ops.Collapse(
-          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
-          dimensions=[1, 2])
-      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]])
-
-    def testRev(self):
-      c = self._NewComputation()
-      ops.Rev(
-          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
-          dimensions=[0, 2])
-      self._ExecuteAndCompareExact(
-          c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]])
-
-    def testReducePrecision(self):
-      c = self._NewComputation()
-      ops.ReducePrecision(
-          ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
-          exponent_bits=8,
-          mantissa_bits=7)
-      self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]])
-
-    def testClampF32(self):
-      c = self._NewComputation()
-      ops.Clamp(
-          ops.Constant(c, NumpyArrayF32(-1)),
-          ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
-          ops.Constant(c, NumpyArrayF32(2)))
-      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])
-
-    def testClampS32(self):
-      c = self._NewComputation()
-      ops.Clamp(
-          ops.Constant(c, NumpyArrayS32(-1)),
-          ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
-          ops.Constant(c, NumpyArrayS32(2)))
-      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])
-
-    def testSelect(self):
-      c = self._NewComputation()
-      ops.Select(
-          ops.Constant(c, NumpyArrayBool([True, False, False, True, False])),
-          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])),
-          ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5])))
-      self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]])
-
-    def testSlice(self):
-      c = self._NewComputation()
-      ops.Slice(
-          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
-          [1, 0], [3, 2], [1, 1])
-      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])
-
-    def testSliceInDim(self):
-      c = self._NewComputation()
-      ops.SliceInDim(
-          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
-          start_index=1,
-          limit_index=2,
-          stride=1,
-          dimno=1)
-      self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]])
-      ops.SliceInDim(
-          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
-          start_index=0,
-          limit_index=3,
-          stride=2,
-          dimno=0)
-      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]])
-
-    def testDynamicSlice(self):
-      c = self._NewComputation()
-      ops.DynamicSlice(
-          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [
-              ops.Constant(c, NumpyArrayS32(1)),
-              ops.Constant(c, NumpyArrayS32(0))
-          ], [2, 2])
-      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])
-
-    def testDynamicUpdateSlice(self):
-      c = self._NewComputation()
-      ops.DynamicUpdateSlice(
-          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
-          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])), [
-              ops.Constant(c, NumpyArrayS32(1)),
-              ops.Constant(c, NumpyArrayS32(1))
-          ])
-      self._ExecuteAndCompareExact(
-          c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]])
-
-    def testTuple(self):
-      c = self._NewComputation()
-      ops.Tuple(c, [
-          ops.Constant(c, np.int32(42)),
-          ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
-          ops.Constant(c, NumpyArrayBool([True, False, False, True]))
-      ])
-      result = execute_with_python_values(
-          self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
-          backend=self.backend)
-      self.assertLen(result, 3)
-      np.testing.assert_equal(result[0], 42)
-      np.testing.assert_allclose(result[1], [1.0, 2.0])
-      np.testing.assert_equal(result[2], [True, False, False, True])
-
-    def testGetTupleElement(self):
-      c = self._NewComputation()
-      ops.GetTupleElement(
-          ops.Tuple(c, [
-              ops.Constant(c, np.int32(42)),
-              ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
-              ops.Constant(c, NumpyArrayBool([True, False, False, True]))
-          ]), 1)
-      self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]])
-
-    def testBroadcast(self):
-      c = self._NewComputation()
-      ops.Broadcast(
-          ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
-      self._ExecuteAndCompareExact(
-          c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]])
-
-    def testBroadcastInDim(self):
-      c = self._NewComputation()
-      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0])
-      self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]])
-      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1])
-      self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]])
-
-    def testRngNormal(self):
-      shape = (2, 3)
-      c = self._NewComputation()
-      ops.RngNormal(
-          ops.Constant(c, NumpyArrayF32(0.)),
-          ops.Constant(c, NumpyArrayF32(1.)),
-          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
-                                             shape))
-      result = execute_with_python_values(
-          self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
-          backend=self.backend)
-      # since the result is random, we just check shape and uniqueness
-      self.assertLen(result, 1)
-      self.assertEqual(result[0].shape, shape)
-      self.assertLen(np.unique(result[0]), np.prod(shape))
-
-    def testRngUniformF32(self):
-      lo, hi = 2., 4.
-      shape = (2, 3)
-      c = self._NewComputation()
-      ops.RngUniform(
-          ops.Constant(c, NumpyArrayF32(lo)),
-          ops.Constant(c, NumpyArrayF32(hi)),
-          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
-                                             shape))
-      result = execute_with_python_values(
-          self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
-          backend=self.backend)
-      # since the result is random, we just check shape, uniqueness, and range
-      self.assertLen(result, 1)
-      self.assertEqual(result[0].shape, shape)
-      self.assertLen(np.unique(result[0]), np.prod(shape))
-      self.assertTrue(np.all(lo <= result[0]))
-      self.assertTrue(np.all(result[0] < hi))
-
-    def testRngUniformS32(self):
-      lo, hi = 2, 4
-      shape = (2, 3)
-      c = self._NewComputation()
-      ops.RngUniform(
-          ops.Constant(c, NumpyArrayS32(lo)),
-          ops.Constant(c, NumpyArrayS32(hi)),
-          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
-                                             shape))
-      result = execute_with_python_values(
-          self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
-          backend=self.backend)
-      # since the result is random, we just check shape, integrality, and range
-      self.assertLen(result, 1)
-      self.assertEqual(result[0].shape, shape)
-      self.assertEqual(result[0].dtype, np.int32)
-      self.assertTrue(np.all(lo <= result[0]))
-      self.assertTrue(np.all(result[0] < hi))
-
-    def testCholesky(self):
-      l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
-                   dtype=np.float32)
-      c = self._NewComputation()
-      ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T))))
-      self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4)
-
-    def testSort(self):
-      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
-      c = self._NewComputation()
-      ops.Sort(c, [ops.Constant(c, keys)], is_stable=True)
-      self._ExecuteAndCompareClose(
-          c,
-          expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)])
-
-    def testSortKeyVal(self):
-      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
-      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
-      c = self._NewComputation()
-      ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
-      result = execute_with_python_values(
-          self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
-          backend=self.backend)
-      self.assertLen(result, 2)
-      np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
-      np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
-
-    def testSortCustomComparator(self):
-      b = self._NewComputation("comparator")
-      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
-      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
-      p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
-      q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
-      ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
-      comparator = b.build()
-
-      keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
-      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
-      c = self._NewComputation()
-      ops.Sort(
-          c, (ops.Constant(c, keys), ops.Constant(c, values)),
-          dimension=1,
-          comparator=comparator)
-      result = execute_with_python_values(
-          self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
-          backend=self.backend)
-      self.assertLen(result, 2)
-      np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
-      np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])
-
-    def testQR(self):
-      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
-                    [10, 63, 166, 310]],
-                   dtype=np.float32)
-      c = self._NewComputation()
-      ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True))
-      q, r = self._Execute(c, ())
-      np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)
-
-    def testEigh(self):
-      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
-                    [10, 63, 166, 310]],
-                   dtype=np.float32)
-      a = (a + a.T) / 2
-
-      c = self._NewComputation()
-      ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True))
-      # TODO(b/129396575): Turn this test back on when it passes without
-      # fastmath.
-      # v, w = self._Execute(c, ())
-      # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)
-
-    def testSVD(self):
-      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
-                    [10, 63, 166, 310]],
-                   dtype=np.float32)
-      c = self._NewComputation()
-      ops.Tuple(c, ops.SVD(ops.Constant(c, a)))
-      u, d, v = self._Execute(c, ())
-      self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)
-
-    def testTriangularSolve(self):
-      a_vals = np.array(
-          [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
-          dtype=np.float32)
-      b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
-                        dtype=np.float32)
-
-      c = self._NewComputation()
-      ops.TriangularSolve(
-          ops.Constant(c, a_vals),
-          ops.Constant(c, b_vals),
-          left_side=False,
-          lower=True,
-          transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE,
-          unit_diagonal=False)
-      self._ExecuteAndCompareClose(
-          c,
-          expected=[
-              np.array([
-                  [0.5, 0.08333334, 0.04629629, 0.03367003],
-                  [2.5, -0.25, -0.1388889, -0.1010101],
-                  [4.5, -0.58333331, -0.32407406, -0.23569024],
-              ],
-                       dtype=np.float32)
-          ],
-          rtol=1e-4)
-
-    def testApproxTopK(self):
-      if self.backend.platform != "tpu":
-        self.skipTest("ApproxTopK is only supported on TPU")
-      k = 10
-      qy_size = 256
-      db_size = 3000
-      feature = 128
-      recall_target = 0.95
-      b = self._NewComputation()
-      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
-      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
-      ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
-      ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
-      ops.Gt(p0, q0)
-      comparator = b.build()
-      qy_shape = [qy_size, feature]
-      db_shape = [feature, db_size]
-      rng = np.random.RandomState(0)
-      qy_arg = rng.randn(*qy_shape).astype(np.float32)
-      db_arg = rng.randn(*db_shape).astype(np.float32)
-      b = self._NewComputation()
-      qy = ops.Parameter(b, 0, xla_client.shape_from_pyval(qy_arg))
-      db = ops.Parameter(b, 1, xla_client.shape_from_pyval(db_arg))
-      scores = ops.Dot(qy, db)
-      iota = ops.Iota(
-          b,
-          xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
-                                       (qy_size, db_size)), 1)
-      init_val = ops.Constant(b, np.float32(-1))
-      init_arg = ops.Constant(b, np.int32(-1))
-      ground_truth = ops.TopK(scores, k=k)
-      approx_topk = ops.ApproxTopK(
-          b, [scores, iota], [init_val, init_arg],
-          top_k=k,
-          reduction_dim=1,
-          comparator=comparator,
-          recall_target=recall_target)
-      ops.Tuple(b, [
-          ops.GetTupleElement(ground_truth, 1),
-          ops.GetTupleElement(approx_topk, 1)
-      ])
-      results = self._Execute(b, [qy_arg, db_arg])
-      ground_truth_docids = [set(x) for x in results[0]]
-      hits = sum(
-          len([x for x in approx_topk_per_q if x in ground_truth_docids[q]])
-          for q, approx_topk_per_q in enumerate(results[1])
-      )
-      self.assertGreater(hits / (qy_size * k), recall_target)
-
-    def testIsConstant(self):
-      c = self._NewComputation()
-      a = ops.Constant(c, np.int32(3))
-      b = ops.Constant(c, np.int32(1))
-      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0)))
-      const_expr = ops.Sub(b, a)
-      non_const_expr = ops.Mul(const_expr, x)
-      self.assertTrue(c.is_constant(const_expr))
-      self.assertFalse(c.is_constant(non_const_expr))
-
-    def testGather(self):
-      a = np.arange(9).astype(np.int32).reshape((3, 3))
-      indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
-      dnums = xla_client.GatherDimensionNumbers()
-      dnums.offset_dims.append(1)
-      dnums.offset_dims.append(2)
-      dnums.start_index_map.append(0)
-      dnums.start_index_map.append(1)
-      dnums.index_vector_dim = 2
-      c = self._NewComputation()
-      ops.Gather(
-          ops.Constant(c, a),
-          ops.Constant(c, indices),
-          dnums,
-          slice_sizes=[1, 1])
-      g, = self._Execute(c, ())
-      expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
-      np.testing.assert_allclose(g, expected, rtol=1e-4)
-
-    def testAllGather(self):
-      a = np.arange(9).astype(np.int32).reshape((3, 3))
-      c = self._NewComputation()
-      ops.AllGather(
-          operand=ops.Constant(c, a),
-          all_gather_dimension=0,
-          shard_count=1,
-          replica_groups=xla_client.make_replica_groups([[0]]),
-          use_global_device_ids=False)
-      [g] = self._Execute(c, ())
-      np.testing.assert_equal(g, a)
-
-    def testFft(self):
-      if self.backend.platform == "tpu":
-        self.skipTest("TPU only supports 1D FFT")
-      shape = [2, 3, 4, 5]
-      rng = np.random.RandomState(0)
-      a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
-      a = a.astype(np.complex64)
-      # FFT
-      c = self._NewComputation()
-      ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:])
-      self._ExecuteAndCompareClose(
-          c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4)
-      # IFFT
-      c = self._NewComputation()
-      ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:])
-      self._ExecuteAndCompareClose(
-          c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4)
-      # RFFT
-      b = rng.randn(*shape).astype(np.float32)
-      c = self._NewComputation()
-      ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:])
-      self._ExecuteAndCompareClose(
-          c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4)
-      # IRFFT
-      c = self._NewComputation()
-      ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8])
-      self._ExecuteAndCompareClose(
-          c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=2e-4
-      )
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes + fp8_dtypes)
-    def testNextAfter(self, dtype):
-      if dtype == float8_e8m0fnu:
-        # TODO(b/409114865): Test fails with Mismatched elements error.
-        self.skipTest("b/409114865: Test fails with Mismatched elements error")
-      if dtype in [float8_e3m4, float8_e4m3] and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float8_e3m4 or float8_e4m3")
-      if dtype == np.float64 and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float64")
-      if dtype == bfloat16 and self.backend.platform == "tpu":
-        self.skipTest("b/371119032: Test fails on TPUs with bfloat16")
-      finfo = ml_dtypes.finfo(dtype)
-      eps = finfo.eps
-      c = self._NewComputation()
-      # Each row is (value, direction, expected), where
-      # 'nextafter(value, direction)' should be 'expected'.
-      data = np.array(
-          [
-              [1, 2, 1 + finfo.eps],
-              [2, 1, 2 - eps],
-              [-0., 1, finfo.smallest_subnormal],
-              [0., -1, -finfo.smallest_subnormal],
-              [-finfo.smallest_subnormal, 1, -0.],
-              [finfo.smallest_subnormal, 1, 2 * finfo.smallest_subnormal],
-              [finfo.smallest_subnormal, -1, 0],
-          ],
-          dtype=dtype,
-      )
-
-      ops.NextAfter(ops.Constant(c, data[:, 0]), ops.Constant(c, data[:, 1]))
-      out, = self._Execute(c, ())
-      np.testing.assert_equal(out, data[:, 2])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testRegularizedIncompleteBeta(self, dtype):
-      x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538],
-                   dtype=dtype)
-      a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606],
-                   dtype=dtype)
-      b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677],
-                   dtype=dtype)
-      c = self._NewComputation()
-      ops.RegularizedIncompleteBeta(
-          ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x))
-      expected = np.array(
-          [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
-      self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2)
-
-  tests.append(SingleOpTest)
-
-  class EmbeddedComputationsTest(ComputationTest):
-    """Tests for XLA graphs with embedded computations (such as maps)."""
-
-    def _CreateConstantComputation(self, in_dtype, out_dtype):
-      """Computation (A) -> B that returns a constant 1 for any input."""
-      c = self._NewComputation("constant_{}_{}_one".format(
-          in_dtype.__name__, out_dtype.__name__))
-      ops.Parameter(
-          c, 0,
-          xla_client.shape_from_pyval(np.array(
-              0, dtype=in_dtype)).with_major_to_minor_layout_if_absent())
-      ops.Constant(c, out_dtype(1))
-      return c.build()
-
-    def _CreateMulBy2Computation(self, dtype):
-      """Computation (dtype) -> dtype that multiplies its parameter by 2."""
-      c = self._NewComputation("mul_f32_by2")
-      ops.Mul(
-          ops.Parameter(
-              c, 0,
-              xla_client.shape_from_pyval(np.array(
-                  0, dtype=dtype)).with_major_to_minor_layout_if_absent()),
-          ops.Constant(c, dtype(2.0)))
-      return c.build()
-
-    def _CreateMulF32ByParamComputation(self):
-      """Computation (f32) -> f32 that multiplies one parameter by the other."""
-      c = self._NewComputation("mul_f32_by_param")
-      ops.Mul(
-          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))),
-          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))))
-      return c.build()
-
-    def _CreateBinaryAddComputation(self, dtype):
-      """Computation (dtype, dtype) -> dtype that adds its two parameters."""
-      c = self._NewComputation("add_param0_by_param1")
-      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
-      shape = shape.with_major_to_minor_layout_if_absent()
-      ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-      return c.build()
-
-    def _CreateBinaryGeComputation(self, dtype):
-      """Computation (dtype, dtype) -> bool that tests param0 >= param1."""
-      c = self._NewComputation("param0_lt_param1")
-      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
-      shape = shape.with_major_to_minor_layout_if_absent()
-      ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-      return c.build()
-
-    def _MakeSample3DArray(self, dtype):
-      return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
-                       [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
-                      dtype=dtype)
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testCall(self, dtype):
-      c = self._NewComputation()
-      ops.Call(
-          c,
-          self._CreateMulBy2Computation(dtype),
-          operands=(ops.Constant(c, dtype(5.0)),))
-      self._ExecuteAndCompareClose(c, expected=[10.0])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__),
-        "in_dtype": in_dtype,
-        "out_dtype": out_dtype,
-    } for in_dtype, out_dtype in [[np.float32, np.int32]])
-    def testMapEachElementToConstant(self, in_dtype, out_dtype):
-      c = self._NewComputation()
-      ops.Map(c,
-              [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))],
-              self._CreateConstantComputation(in_dtype, out_dtype), [0])
-      self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testMapMulBy2(self, dtype):
-      if dtype == np.float64 and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float64")
-      c = self._NewComputation()
-      ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
-              self._CreateMulBy2Computation(dtype), [0])
-      self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testSimpleMapChain(self, dtype):
-      if dtype == np.float64 and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float64")
-      # Chains a map of constant-out with a map of mul-by-2
-      c = self._NewComputation()
-      const = ops.Map(
-          c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
-          self._CreateConstantComputation(dtype, dtype), [0])
-      ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0])
-      self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]])
-
-    # TODO(b/154752816): bfloat16 crashes in evaluator.
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes if dtype != bfloat16)
-    def testDivVectorsWithMap(self, dtype):
-
-      def DivComputation():
-        c = self._NewComputation("div_param0_by_param1")
-        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
-        ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-        return c.build()
-
-      c = self._NewComputation()
-      ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)),
-                  ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))),
-              DivComputation(), [0])
-      self._ExecuteAndCompareClose(
-          c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3)
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testSelectAndScatter(self, dtype):
-      if dtype == np.float64 and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float64")
-      c = self._NewComputation()
-      operand = ops.Constant(
-          c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype))
-      window_dimensions = (2, 1)
-      window_strides = (1, 2)
-      padding = xla_client.window_padding_type_to_pad_values(
-          xla_client.PaddingType.VALID,
-          c.get_shape(operand).dimensions(), window_dimensions, window_strides)
-      ops.SelectAndScatterWithGeneralPadding(
-          operand,
-          select=self._CreateBinaryGeComputation(dtype),
-          window_dimensions=window_dimensions,
-          window_strides=window_strides,
-          padding=padding,
-          source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)),
-          init_value=ops.Constant(c, np.array(1, dtype=dtype)),
-          scatter=self._CreateBinaryAddComputation(dtype))
-      self._ExecuteAndCompareClose(
-          c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3)
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testReduce1DtoScalar(self, dtype):
-      c = self._NewComputation()
-      ops.Reduce(
-          c,
-          operands=[
-              ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))
-          ],
-          init_values=[ops.Constant(c, dtype(0))],
-          computation=self._CreateBinaryAddComputation(dtype),
-          dimensions_to_reduce=[0])
-      self._ExecuteAndCompareClose(c, expected=[10])
-
-    # TODO(phawkins): test comparison harness doesn't support bfloat16
-    @unittest.skipIf(pjrt_c_api, "b/264473047: hangs")
-    @parameterized.named_parameters({
-        "testcase_name": "_{}_dim{}".format(dtype.__name__, dim),
-        "dtype": dtype,
-        "dim": dim,
-    } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2))
-    def testReduce2DTo1D(self, dtype, dim):
-      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
-      c = self._NewComputation()
-      ops.Reduce(
-          c,
-          operands=[ops.Constant(c, input_array)],
-          init_values=[ops.Constant(c, dtype(0))],
-          computation=self._CreateBinaryAddComputation(dtype),
-          dimensions_to_reduce=[dim])
-      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)])
-
-    @unittest.skipIf(pjrt_c_api, "b/264473047: hangs")
-    @parameterized.named_parameters({
-        "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims),
-        "dtype": dtype,
-        "dims": tuple(dims)
-    } for dtype in float_dtypes for dims in itertools.permutations(range(3)))
-    def testReduce3DAllPossibleWaysF32(self, dtype, dims):
-      input_array = self._MakeSample3DArray(dtype)
-      c = self._NewComputation()
-      ops.Reduce(
-          c,
-          operands=[ops.Constant(c, input_array)],
-          init_values=[ops.Constant(c, dtype(0))],
-          computation=self._CreateBinaryAddComputation(dtype),
-          dimensions_to_reduce=dims)
-      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testReduceWindowValidUnitStrides(self, dtype):
-      if dtype == np.float64 and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float64")
-      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
-      c = self._NewComputation()
-      window_dimensions = (2, 1)
-      window_strides = (1, 1)
-      padding = xla_client.window_padding_type_to_pad_values(
-          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
-          window_strides)
-      ops.ReduceWindowWithGeneralPadding(
-          operand=ops.Constant(c, input_array),
-          init_value=ops.Constant(c, dtype(0)),
-          computation=self._CreateBinaryAddComputation(dtype),
-          window_dimensions=window_dimensions,
-          window_strides=window_strides,
-          base_dilations=[],
-          window_dilations=[],
-          padding=padding)
-      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testReduceWindowSameUnitStrides(self, dtype):
-      if dtype == np.float64 and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float64")
-      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
-      c = self._NewComputation()
-      window_dimensions = (2, 1)
-      window_strides = (1, 1)
-      padding = xla_client.window_padding_type_to_pad_values(
-          xla_client.PaddingType.SAME, input_array.shape, window_dimensions,
-          window_strides)
-      ops.ReduceWindowWithGeneralPadding(
-          operand=ops.Constant(c, input_array),
-          init_value=ops.Constant(c, dtype(0)),
-          computation=self._CreateBinaryAddComputation(dtype),
-          window_dimensions=window_dimensions,
-          window_strides=window_strides,
-          base_dilations=[],
-          window_dilations=[],
-          padding=padding)
-      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testReduceWindowValidGeneralStrides(self, dtype):
-      if dtype == np.float64 and self.backend.platform == "tpu":
-        self.skipTest("TPU doesn't support float64")
-      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
-      c = self._NewComputation()
-      window_dimensions = (2, 1)
-      window_strides = (1, 2)
-      padding = xla_client.window_padding_type_to_pad_values(
-          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
-          window_strides)
-      ops.ReduceWindowWithGeneralPadding(
-          operand=ops.Constant(c, input_array),
-          init_value=ops.Constant(c, dtype(0)),
-          computation=self._CreateBinaryAddComputation(dtype),
-          window_dimensions=window_dimensions,
-          window_strides=window_strides,
-          base_dilations=[],
-          window_dilations=[],
-          padding=padding)
-      self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]])
-
-    @unittest.skipIf(pjrt_c_api, "b/264473047: hangs")
-    def testReduceWindowVariadic(self):
-      c = self._NewComputation("reducer")
-      shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32))
-      shape = shape.with_major_to_minor_layout_if_absent()
-      ps = [ops.Parameter(c, i, shape) for i in range(4)]
-      which = ops.Ge(ps[0], ps[2])
-      ops.Tuple(
-          c, [ops.Select(which, ps[0], ps[2]),
-              ops.Select(which, ps[1], ps[3])])
-      reducer = c.build()
-
-      key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32)
-      val_array = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32)
-      c = self._NewComputation()
-      window_dimensions = (2, 1)
-      window_strides = (1, 1)
-      padding = xla_client.window_padding_type_to_pad_values(
-          xla_client.PaddingType.VALID, key_array.shape, window_dimensions,
-          window_strides)
-      ops.ReduceWindowWithGeneralPadding(
-          operands=[ops.Constant(c, key_array),
-                    ops.Constant(c, val_array)],
-          init_values=[
-              ops.Constant(c, np.int32(0)),
-              ops.Constant(c, np.int32(0))
-          ],
-          computation=reducer,
-          window_dimensions=window_dimensions,
-          window_strides=window_strides,
-          base_dilations=[],
-          window_dilations=[],
-          padding=padding)
-      self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]])
-
-    @parameterized.named_parameters({
-        "testcase_name": "_{}".format(dtype.__name__),
-        "dtype": dtype,
-    } for dtype in float_dtypes)
-    def testWhile(self, dtype):
-
-      def LessThan10Cond():
-        c = self._NewComputation("test_lt_10")
-        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
-        ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.)))
-        return c.build()
-
-      cond = LessThan10Cond()
-      body = self._CreateMulBy2Computation(dtype)
-      c = self._NewComputation()
-      init = ops.Constant(c, dtype(1.))
-      ops.While(cond, body, init)
-      self._ExecuteAndCompareClose(c, expected=[16.])
-
-    def testConditionalTrue(self):
-      c = self._NewComputation()
-      pred = ops.Constant(c, np.bool_(True))
-      true_operand = ops.Constant(c, np.float32(3.))
-      true_computation = self._CreateMulBy2Computation(np.float32)
-      false_operand = ops.Constant(c, np.float32(2.))
-      false_computation = self._CreateConstantComputation(
-          np.float32, np.float32)
-      ops.Conditional(pred, true_operand, true_computation, false_operand,
-                      false_computation)
-      self._ExecuteAndCompareClose(c, expected=[6.])
-
-    def testConditionalFalse(self):
-      c = self._NewComputation()
-      pred = ops.Constant(c, np.bool_(False))
-      true_operand = ops.Constant(c, np.float32(3.))
-      true_computation = self._CreateMulBy2Computation(np.float32)
-      false_operand = ops.Constant(c, np.float32(2.))
-      false_computation = self._CreateConstantComputation(
-          np.float32, np.float32)
-      ops.Conditional(pred, true_operand, true_computation, false_operand,
-                      false_computation)
-      self._ExecuteAndCompareClose(c, expected=[1.])
-
-    @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api,
-                     "not implemented")
-    def testInfeedS32Values(self):
-      to_infeed = NumpyArrayS32([1, 2, 3, 4])
-      c = self._NewComputation()
-      ops.GetTupleElement(
-          ops.InfeedWithToken(
-              ops.CreateToken(c),
-              xla_client.shape_from_pyval(
-                  to_infeed[0]).with_major_to_minor_layout_if_absent()), 0)
-      compiled_c = self.backend.compile(
-          xla_computation_to_mlir_module(c.build()))
-      device = self.backend.local_devices()[0]
-      for item in to_infeed:
-        device.transfer_to_infeed(item)
-
-      for item in to_infeed:
-        result, = execute_with_python_values(
-            compiled_c, (), backend=self.backend)
-        self.assertEqual(result, item)
-
-    @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api,
-                     "not implemented")
-    def testInfeedTuple(self):
-      to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]]))
-      c = self._NewComputation()
-      ops.GetTupleElement(
-          ops.InfeedWithToken(
-              ops.CreateToken(c),
-              xla_client.shape_from_pyval(
-                  to_infeed).with_major_to_minor_layout_if_absent()), 0)
-      compiled_c = self.backend.compile(
-          xla_computation_to_mlir_module(c.build()))
-      device = self.backend.local_devices()[0]
-      device.transfer_to_infeed(to_infeed)
-
-      result = execute_with_python_values(
-          compiled_c, (), backend=self.backend)
-      self.assertLen(result, 2)
-      np.testing.assert_equal(result[0], to_infeed[0])
-      np.testing.assert_equal(result[1], to_infeed[1])
-
-    @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api,
-                     "not implemented")
-    def testInfeedThenOutfeedS32(self):
-      to_round_trip = NumpyArrayS32([1, 2, 3, 4])
-      c = self._NewComputation()
-      x_and_token = ops.InfeedWithToken(
-          ops.CreateToken(c),
-          xla_client.shape_from_pyval(
-              to_round_trip[0]).with_major_to_minor_layout_if_absent())
-      x = ops.GetTupleElement(x_and_token, 0)
-      token = ops.GetTupleElement(x_and_token, 1)
-      outfeed_shape = xla_client.shape_from_pyval(
-          to_round_trip[0]).with_major_to_minor_layout_if_absent()
-      ops.OutfeedWithToken(x, token, outfeed_shape)
-      ops.Tuple(c, ())
-
-      compiled_c = self.backend.compile(
-          xla_computation_to_mlir_module(c.build()))
-      device = self.backend.local_devices()[0]
-
-      for want in to_round_trip:
-        execution = threading.Thread(target=lambda: compiled_c.execute([]))
-        execution.start()
-        device.transfer_to_infeed(want)
-        got = device.transfer_from_outfeed(outfeed_shape)
-        execution.join()
-        self.assertEqual(want, got)
-
-    def testScatter(self):
-      a = np.arange(9).astype(np.int32).reshape((3, 3))
-      scatter_indices = np.array([0, 2], dtype=np.int32)
-      updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)
-
-      dnums = xla_client.ScatterDimensionNumbers()
-      dnums.update_window_dims.append(1)
-      dnums.inserted_window_dims.append(0)
-      dnums.scatter_dims_to_operand_dims.append(0)
-      dnums.index_vector_dim = 1
-
-      c = self._NewComputation()
-      ops.Scatter(
-          ops.Constant(c, a), ops.Constant(c, scatter_indices),
-          ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32),
-          dnums)
-      expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]],
-                          dtype=np.int32)
-      self._ExecuteAndCompareClose(c, expected=[expected])
-
-  class DeviceTest(ComputationTest):
-
-    def testDevices(self):
-      self.assertNotEmpty(self.backend.devices())
-
-    def testLocalDevices(self):
-      self.assertNotEmpty(self.backend.local_devices())
-      if self.backend.platform == "cpu":
-        self.assertLen(self.backend.local_devices(), 2)
-
-    def testGetAllDevices(self):
-      # TODO(hyeontaek): Remove this method once we have a unified API for
-      # enumerating devices with different criteria.
-      self.assertNotEmpty(self.backend._get_all_devices())  # pylint: disable=protected-access
-
-    def testPlatform(self):
-      for device in self.backend.local_devices():
-        self.assertEqual(device.platform, self.backend.platform)
-
-    def testCoreCount(self):
-      if self.backend.platform != "gpu":
-        self.skipTest("core_count is only supported on GPU")
-      for device in self.backend.local_devices():
-        self.assertGreater(device.core_count, 0)
-
-    def testLocalHardwareId(self):
-      for device in self.backend.devices():
-        local_hardware_id = device.local_hardware_id
-        if local_hardware_id is not None:
-          self.assertGreaterEqual(local_hardware_id, 0)
-
-    @unittest.skipIf(pathways_ifrt, "not implemented")
-    def testLocalDeviceFromLocalHardwareId(self):
-      for device in self.backend.local_devices():
-        if device.local_hardware_id is not None:
-          lookup_device = self.backend.device_from_local_hardware_id(
-              device.local_hardware_id)
-          self.assertEqual(lookup_device, device)
-
-    @unittest.skipIf(pathways, "not implemented")
-    @unittest.skipIf(pathways_ifrt, "not implemented")
-    def testMemoryStats(self):
-      for device in self.backend.local_devices():
-        stats = device.memory_stats()
-        if (
-            self.backend.platform != "tpu" or not tfrt_tpu
-        ) and self.backend.platform not in ("gpu", "cuda", "rocm"):
-          self.assertIsNone(stats)
-        else:
-          self.assertIsNotNone(stats)
-          # Spot check a few fields
-          self.assertEqual(type(stats["num_allocs"]), int)
-          self.assertGreaterEqual(stats["num_allocs"], 0)
-          self.assertEqual(type(stats["bytes_in_use"]), int)
-          self.assertGreaterEqual(stats["bytes_in_use"], 0)
-          self.assertEqual(type(stats["peak_bytes_in_use"]), int)
-          self.assertGreaterEqual(stats["peak_bytes_in_use"], 0)
-          self.assertEqual(type(stats["largest_alloc_size"]), int)
-          self.assertGreaterEqual(stats["largest_alloc_size"], 0)
-
-    @unittest.skipIf(pathways, "not implemented")
-    def testMemory(self):
-      for device in self.backend.local_devices():
-        for memory in device.addressable_memories():
-          self.assertEqual(memory.process_index, device.process_index)
-          self.assertEqual(memory.platform, device.platform)
-          self.assertIn(device, memory.addressable_by_devices())
-          self.assertEqual(memory, device.memory(memory.kind))
-
-  tests.append(DeviceTest)
-
-  class ErrorTest(ComputationTest):
-
-    def setUp(self):
-      super(ErrorTest, self).setUp()
-      self.f32_scalar_2 = NumpyArrayF32(2.0)
-      self.s32_scalar_2 = NumpyArrayS32(2)
-
-    def testCompileWithWrongElementTypeInLayout(self):
-      c = self._NewComputation()
-      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
-      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
-      c.clear_op_metadata()
-
-      options = xla_client.CompileOptions()
-      options.argument_layouts = [
-          xla_client.Shape.array_shape(np.dtype(np.float32), [])
-      ]
-
-      def TestFun():
-        return self.backend.compile(c.build(), compile_options=options)
-
-      self.assertRaisesRegex(
-          RuntimeError, r".*Invalid argument shape.*"
-          r"expected s32\[\], got f32\[\].*", TestFun)
-
-    def testInvokeWithWrongElementType(self):
-      c = self._NewComputation()
-      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
-      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
-      c.clear_op_metadata()
-
-      def TestFun():
-        return execute_with_python_values(
-            self.backend.compile(xla_computation_to_mlir_module(c.build())),
-            [self.f32_scalar_2], self.backend)
-
-      self.assertRaisesRegex(
-          RuntimeError, r"Invalid argument: Argument does not match.*"
-          r"want s32\[\], got f32\[\].*", TestFun)
-
-  tests.append(EmbeddedComputationsTest)
-
-  class ComputationRootTest(ComputationTest):
-    """Tests related to setting the root of the computation."""
-
-    def testComputationRootDifferentFromLastOp(self):
-      c = self._NewComputation()
-      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
-      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
-      ops.Add(result, ops.Constant(c, np.float32(1.618)))
-
-      arg = NumpyArrayF32(1.0)
-      compiled_c = self.backend.compile(
-          xla_computation_to_mlir_module(c.build(result)))
-      ans, = execute_with_python_values(
-          compiled_c, [arg], backend=self.backend)
-      np.testing.assert_allclose(ans, 4.14)
-
-  tests.append(ComputationRootTest)
-
-  class SetShardingTest(ComputationTest):
-    """Tests related to set OpSharding."""
-
-    def testSetSharding(self):
-      c = self._NewComputation()
-      sharding = xla_client.OpSharding()
-      sharding.type = xla_client.OpSharding.Type.REPLICATED
-      sharding.tile_assignment_dimensions = [1]
-      sharding.tile_assignment_devices = [0]
-      c.set_sharding(sharding)
-      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
-      c.clear_sharding()
-
-      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
-      ops.Add(result, ops.Constant(c, np.float32(1.618)))
-      arg = NumpyArrayF32(1.0)
-      compiled_c = self.backend.compile(
-          xla_computation_to_mlir_module(c.build(result)))
-      ans, = execute_with_python_values(
-          compiled_c, [arg], backend=self.backend)
-      np.testing.assert_allclose(ans, 4.14)
-
-  tests.append(SetShardingTest)
-
-  testcase_shapes = [
-      (),
-      (1,),
-      (2, 3),
-      (2, 0),
-      (0, 7),
-      (4, 1, 2),
-      (2, 1, 3),
-      (2, 4, 1),
-      (3, 1),
-      (1, 3),
-  ]
-
-  def FormatShapeAndDtype(shape, dtype):
-    return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape)))
-
-  class DLPackTest(parameterized.TestCase):
-
-    def setUp(self):
-      super(DLPackTest, self).setUp()
-      self.backend = xla_backend()
-      if self.backend.platform not in ("cpu", "gpu", "cuda", "rocm"):
-        self.skipTest("DLPack requires CPU or GPU")
-      self.cpu_backend = (
-          self.backend
-          if self.backend.platform == "cpu" else xla_client.make_cpu_client())
-      self.gpu_backend = (
-          self.backend
-          if self.backend.platform in ("gpu", "cuda", "rocm")
-          else None
-      )
-
-    def tearDown(self):
-      super().tearDown()
-      del self.backend
-      del self.cpu_backend
-      del self.gpu_backend
-
-    @classmethod
-    def _GetStreamFromDevice(cls, device):
-      try:
-        return device.get_stream_for_external_ready_events()
-      except xla_client.XlaRuntimeError as err:  # type: ignore
-        if "UNIMPLEMENTED" in str(err):
-          return None
-        else:
-          raise
-
-    def _DLPackManagedTensorToBuffer(
-        self, tensor, use_legacy_api, backend=None
-    ):
-      if use_legacy_api:
-        return xla_client._xla.dlpack_managed_tensor_to_buffer(
-            tensor, self.cpu_backend, self.gpu_backend
-        )
-      else:
-        if not backend:
-          backend = self.backend
-        device = backend.local_devices()[0]
-        stream = DLPackTest._GetStreamFromDevice(device)
-        return xla_client._xla.dlpack_managed_tensor_to_buffer(
-            tensor, device, stream
-        )
-
-    # pylint: disable=g-complex-comprehension
-    # pyformat: disable
-    @parameterized.named_parameters(
-        {
-            "testcase_name": "{}_gpu={}{}".format(
-                FormatShapeAndDtype(shape, dtype),
-                gpu,
-                "_legacy" if use_legacy_api else "",
-            ),
-            "dtype": dtype,
-            "shape": shape,
-            "gpu": gpu,
-            "use_legacy_api": use_legacy_api,
-        }
-        for dtype in dlpack_dtypes
-        for shape in testcase_shapes
-        for gpu in [False, True]
-        for use_legacy_api in [False, True]
-    )
-    # pyformat: enable
-    def testRoundTrip(self, dtype, shape, gpu, use_legacy_api):
-      if gpu and self.gpu_backend is None:
-        raise unittest.SkipTest("Test not running with GPU support")
-      backend = self.gpu_backend if gpu else self.cpu_backend
-      if dtype == np.bool_:
-        x = np.random.randint(0, 2, size=shape).astype(np.bool_)
-      else:
-        x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
-      buffer = backend.buffer_from_pyval(x)
-      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)
-      del buffer  # Free "buffer" to make sure dlt retains ownership.
-      self.assertEqual(type(dlt).__name__, "PyCapsule")
-      y = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api, backend)
-      np.testing.assert_array_equal(
-          x.astype(np.uint8) if dtype == np.bool_ else x, np.asarray(y))
-
-    @parameterized.named_parameters(
-        {
-            "testcase_name": "{}".format("_legacy" if use_legacy_api else ""),
-            "use_legacy_api": use_legacy_api,
-        }
-        for use_legacy_api in [False, True]
-    )
-    def testTensorsCanBeConsumedOnceOnly(self, use_legacy_api):
-      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
-      buffer = self.backend.buffer_from_pyval(x)
-      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)
-
-      def ConsumeDLPackTensor():
-        _ = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api)
-
-      ConsumeDLPackTensor()
-      self.assertRaisesRegex(
-          RuntimeError, ".*a DLPack tensor may be consumed at most once.*",
-          ConsumeDLPackTensor)
-
-    @parameterized.named_parameters(
-        {
-            "testcase_name": "{}".format("_legacy" if use_legacy_api else ""),
-            "use_legacy_api": use_legacy_api,
-        }
-        for use_legacy_api in [False, True]
-    )
-    def testNonOwnedDlpackCanBeViewedTwice(self, use_legacy_api):
-      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
-      buffer = self.backend.buffer_from_pyval(x)
-      d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)
-      d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)
-
-      y = self._DLPackManagedTensorToBuffer(d1, use_legacy_api)
-      z = self._DLPackManagedTensorToBuffer(d2, use_legacy_api)
-      del d1, d2
-      np.testing.assert_array_equal(x, np.asarray(buffer))
-      np.testing.assert_array_equal(x, np.asarray(y))
-      np.testing.assert_array_equal(x, np.asarray(z))
-
-    @parameterized.parameters(False, True)
-    def testZeroCopyOnAlignedDlpackTensor(self, use_legacy_api):
-      # Using CPU only, since this test is about CPU memory alignment.
-      if self.backend.platform != "cpu":
-        self.skipTest("Test requires CPU")
-
-      # Create a numpy array that is aligned to XLA requirements.
-      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
-      x = _Aligned(x)
-
-      # Convert it to a DLPack tensor, and then to an XLA buffer.
-      dlpack_tensor = x.__dlpack__()
-      buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api)
-      y = np.array(buffer, copy=False)
-
-      # The input was sufficiently aligned, so input and output should alias.
-      x_ptr = x.__array_interface__["data"][0]
-      y_ptr = y.__array_interface__["data"][0]
-      self.assertEqual(
-          x_ptr,
-          y_ptr,
-          msg=f"Buffers are not aliased ({hex(x_ptr)} != {hex(y_ptr)}).",
-      )
-
-    @parameterized.named_parameters(
-        {
-            "testcase_name": "{}{}".format(
-                "_legacy" if use_legacy_api else "",
-                "_transpose" if transpose else "",
-            ),
-            "use_legacy_api": use_legacy_api,
-            "transpose": transpose,
-        }
-        for use_legacy_api in [False, True]
-        for transpose in [False, True]
-    )
-    def testReturnCopyOnUnalignedDlpackTensor(self, use_legacy_api, transpose):
-      # Using CPU only, since this test is about CPU memory alignment.
-      if self.backend.platform != "cpu":
-        self.skipTest("Test requires CPU")
-
-      if transpose and use_legacy_api:
-        self.skipTest("Non-default layout is not supported in legacy API")
-
-      # Create a numpy array that is not aligned to XLA requirements. XLA's
-      # alignment requirements differ for different hardware, so we use the
-      # smallest possible value. If we make sure the buffer is not aligned to
-      # this value (16 bytes), then it is also not aligned to its multiples (32,
-      # 64 etc.)
-      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
-      x = _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT)
-
-      # Transpose the array to test non-default layout with trivial striding.
-      if transpose:
-        x = x.transpose((0, 2, 1, 3))
-
-      # Convert it to a DLPack tensor, and then to an XLA buffer.
-      dlpack_tensor = x.__dlpack__()
-      buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api)
-      y = np.array(buffer, copy=False)
-
-      # The input was not sufficiently aligned, so input and output should not
-      # alias (output should be a copy of input, and it should be aligned).
-      x_ptr = x.__array_interface__["data"][0]
-      y_ptr = y.__array_interface__["data"][0]
-      self.assertNotEqual(
-          x_ptr,
-          y_ptr,
-          msg=(
-              f"Buffers aliased, but should not be ({hex(x_ptr)} =="
-              f" {hex(y_ptr)})"
-          ),
-      )
-      self.assertEqual(
-          y_ptr % _XLA_CPU_MIN_ALIGNMENT,
-          0,
-          msg="Output buffer not aligned: {hex(y_ptr)}",
-      )
-      np.testing.assert_array_equal(y, x)
-
-  tests.append(DLPackTest)
-
-  class BufferProtocolTest(parameterized.TestCase):
-
-    def setUp(self):
-      super(BufferProtocolTest, self).setUp()
-      self.backend = xla_backend()
-      if self.backend.platform != "cpu":
-        self.skipTest("Test requires CPU")
-
-    # pylint: disable=g-complex-comprehension
-    @parameterized.named_parameters({
-        "testcase_name": FormatShapeAndDtype(shape, dtype),
-        "dtype": dtype,
-        "shape": shape
-    } for dtype in standard_dtypes if dtype != bfloat16
-      for shape in testcase_shapes)
-    def testRoundTrip(self, dtype, shape):
-      x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
-
-      x = _Aligned(x)
-      x_ptr = x.__array_interface__["data"][0]
-      buffer = self.backend.buffer_from_pyval(
-          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
-      y = np.array(buffer, copy=False)
-      y_ptr = y.__array_interface__["data"][0]
-      np.testing.assert_array_equal(x, y)
-
-      # The input was sufficiently aligned, so input and output should alias.
-      self.assertEqual(x_ptr, y_ptr)
-      self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
-
-      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
-      buffer2 = self.backend.buffer_from_pyval(
-          x, host_buffer_semantics=during_call)
-      z = np.array(buffer2, copy=False)
-      self.assertNotEqual(x.__array_interface__["data"][0],
-                          z.__array_interface__["data"][0])
-
-    def testDeleteWithActiveView(self):
-      x = np.random.randn(20, 10)
-      buffer = self.backend.buffer_from_pyval(x)
-      buffer_ptr = buffer.unsafe_buffer_pointer()
-      y = np.array(buffer, copy=False)
-      buffer.delete()
-      # It is still legal to access `y`; the array view must keep it alive.
-      np.testing.assert_array_equal(x, y)
-      self.assertEqual(y.__array_interface__["data"][0], buffer_ptr)
-
-  tests.append(BufferProtocolTest)
-
-  class TracebackTest(absltest.TestCase):
-
-    def setUp(self):
-      super(TracebackTest, self).setUp()
-      self.backend = xla_backend()
-
-    def testNoTracebacksIfDisabled(self):
-      with xla_client.tracebacks(enabled=False):
-        self.assertEqual(None, xla_client.Traceback.get_traceback())
-        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
-        self.assertEqual(None, buffer.traceback)
-
-        b = xla_client.XlaBuilder("computation")
-        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
-        e = self.backend.compile(xla_computation_to_mlir_module(b.build()))
-        self.assertEqual(None, e.traceback)
-
-    def assertIsTracebackContaining(self, tb, function):
-      self.assertIsInstance(tb, xla_client.Traceback)
-      self.assertIn(function, str(tb))
-      self.assertTrue(any(f.function_name == function for f in tb.frames))
-
-    def testTracebacks(self):
-      with xla_client.tracebacks(enabled=True):
-        tb = xla_client.Traceback.get_traceback()
-        self.assertIsTracebackContaining(tb, "testTracebacks")
-
-        # Tracebacks are not implemented on the TPU driver extension's variant
-        # of buffers and executables.
-        if not isinstance(self.backend, xla_client.Client):
-          return
-
-        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
-        self.assertIsTracebackContaining(buffer.traceback, "testTracebacks")
-
-        b = xla_client.XlaBuilder("computation")
-        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
-        e = self.backend.compile(xla_computation_to_mlir_module(b.build()))
-        self.assertIsTracebackContaining(e.traceback, "testTracebacks")
-
-    def testNestedFunction(self):
-
-      def AFunction():
-
-        def AnotherFunction():
-          return xla_client.Traceback.get_traceback()
-
-        return AnotherFunction()
-
-      with xla_client.tracebacks(enabled=True):
-        tb = AFunction()
-        self.assertIsInstance(tb, xla_client.Traceback)
-        frames = tb.frames
-        i = next(
-            i for (i, f) in enumerate(frames) if f.function_name == "AFunction")
-        self.assertEqual(frames[i - 1].function_name, "AnotherFunction")
-        self.assertEqual(frames[i + 1].function_name, "testNestedFunction")
-
-    def testPythonTracebackHasCorrectLineNumbers(self):
-      def B():
-        return xla_client.Traceback.get_traceback()
-
-      def A():
-        return B()
-
-      tb = A().as_python_traceback()
-      for frame, lineno in traceback.walk_tb(tb):
-        if frame.f_code.co_name == "A":
-          line = A.__code__.co_firstlineno
-          self.assertBetween(lineno, line, line + 2)
-        elif frame.f_code.co_name == "B":
-          line = B.__code__.co_firstlineno
-          self.assertBetween(lineno, line, line + 2)
-
-    def testAccessingLocalsDoesNotCrash(self):
-      # https://github.com/google/jax/issues/16027
-      tb = xla_client.Traceback.get_traceback()
-      python_tb = tb.as_python_traceback()
-      for frame, _ in traceback.walk_tb(python_tb):
-        _ = frame.f_locals  # should not crash
-
-    def testTracebackFromFrames(self):
-      def FooFn(x):
-        return x + 1
-
-      def BarFn(y):
-        y = y + 1
-        y = y + 2
-        return y * 2
-
-      frame_foo = xla_client.Frame(
-          __file__,
-          FooFn.__code__.co_name,
-          FooFn.__code__.co_firstlineno,
-          FooFn.__code__.co_firstlineno + 1,
-      )
-      frame_bar = xla_client.Frame(
-          __file__,
-          BarFn.__code__.co_name,
-          BarFn.__code__.co_firstlineno,
-          BarFn.__code__.co_firstlineno + 2,
-      )
-      frames = [frame_foo, frame_bar]
-      tb = xla_client.Traceback.traceback_from_frames(frames)
-
-      with self.subTest("WalkDoesNotError"):
-        for frame, _ in traceback.walk_tb(tb):
-          _ = frame.f_locals  # should not crash
-
-      with self.subTest("TracebackCorrectness"):
-        tb_string = traceback.format_tb(tb)
-        # The traceback should have the format:
-        # File <this file>, line N in BarFn
-        #   y = y + 2
-        # File <this file>, line N in FooFn
-        #   return x + 1
-        self.assertLen(tb_string, len(frames))
-        bar_frame = tb_string[0].split("\n")
-        self.assertEndsWith(bar_frame[0], "BarFn")
-        self.assertEqual(bar_frame[1].strip(), "y = y + 2")
-        foo_frame = tb_string[1].split("\n")
-        self.assertEndsWith(foo_frame[0], "FooFn")
-        self.assertEqual(foo_frame[1].strip(), "return x + 1")
-
-  tests.append(TracebackTest)
-
-  class ClientTest(ComputationTest):
-
-    def setUp(self):
-      super(ClientTest, self).setUp()
-      self.backend = xla_backend()
-
-    def testPlatformVersion(self):
-      version = self.backend.platform_version
-      logging.info("platform_version:\n%s", version)
-      if self.backend.platform == "cpu":
-        self.assertEqual(version, "cpu")
-      elif self.backend.platform in ("gpu", "cuda", "rocm"):
-        # Following is false if not built with --config=cuda
-        if version != "<unknown>":
-          self.assertTrue(
-              re.match(r"^cuda \d{4,}$", version),
-              msg=f"Expected CUDA version string; got {repr(version)}")
-      elif self.backend.platform == "tpu" and not (pathways or pathways_ifrt):
-        self.assertIn("tpu", version.lower())
-        self.assertIn("cl/", version)
-        self.assertIn("Built on ", version)
-
-    @unittest.skipIf(
-        not cloud_tpu and not pjrt_c_api, "PJRT version only exist for plugins"
-    )
-    def testPjRtCApiVersion(self):
-      self.assertGreaterEqual(self.backend.pjrt_c_api_major_version, 0)
-      self.assertGreaterEqual(self.backend.pjrt_c_api_minor_version, 0)
-
-    @unittest.skipUnless(
-        not pjrt_c_api and tfrt_tpu,
-        "Test that attributes are zero for non-plugin tfrt_tpu",
-    )
-    def testStaticTfrtTpuAttributes(self):
-      self.assertEqual(self.backend.pjrt_c_api_major_version, 0)
-      self.assertEqual(self.backend.pjrt_c_api_minor_version, 0)
-      # CL number is defined as -1 when running as test.
- self.assertEqual(self.backend.__getattr__("cl_number"), -1) - - @unittest.skipIf( - cloud_tpu or pjrt_c_api or (not pjrt_c_api and tfrt_tpu), - "PJRT version only exist for plugins", - ) - def testNotExistPjRtCApiVersion(self): - with self.assertRaises(AttributeError): - self.backend.pjrt_c_api_major_version # pylint: disable=pointless-statement - with self.assertRaises(AttributeError): - self.backend.pjrt_c_api_minor_version # pylint: disable=pointless-statement - - @unittest.skipIf(pathways or pathways_ifrt, "has different behavior") - def testPluginProgramDoesNotCompile(self): - program = xla_client.ifrt_programs.make_plugin_program("foobar") - options = xla_client.ifrt_programs.make_plugin_compile_options() - with self.assertRaisesRegex( - xla_client.XlaRuntimeError, "PjRtCompiler requires an HloProgram" - ): - self.backend.compile_ifrt_program(program, options) - - @unittest.skipIf(pathways, "does not work with non-ifrt legacy pathways") - def testHloProgramViaIfrtProgram(self): - c = self._NewComputation() - ops.Iota(c, xla_client.PrimitiveType.F32, 10) - program = xla_client.ifrt_programs.make_hlo_program( - xla_computation_to_mlir_module(c.build()) - ) - options = xla_client.ifrt_programs.make_xla_compile_options( - xla_client.CompileOptions(), [] - ) - - compiled_c = self.backend.compile_ifrt_program(program, options) - results = execute_with_python_values( - compiled_c, arguments=(), backend=self.backend - ) - - self.assertLen(results, 1) - np.testing.assert_equal(results[0], np.arange(10, dtype=np.float32)) - - @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or tfrt_tpu, - "not implemented") - def testExecutableSerialization(self): - if self.backend.platform != "tpu": - self.skipTest("Test requires tpu platform") - - c = self._NewComputation() - ops.Add( - ops.Constant(c, NumpyArrayS32([1, 2])), - ops.Constant(c, NumpyArrayS32([3, 4]))) - - options = xla_client.CompileOptions() - executable = self.backend.compile( - xla_computation_to_mlir_module(c.build()), options) - self.assertLen(executable.hlo_modules(), 1) - - serialized = self.backend.serialize_executable(executable) - deserialized = self.backend.deserialize_executable(serialized, options) - - expected, = execute_with_python_values(executable, (), self.backend) - actual, = execute_with_python_values(deserialized, (), self.backend) - self.assertTrue(np.all(actual == expected)) - - def testCompileOptionsSerialization(self): - options = xla_client.CompileOptions() - executable_build_options = options.executable_build_options - options.num_replicas = 3 - options.num_partitions = 2 - options.profile_version = 1337 - options.compile_portable_executable = True - executable_build_options.num_replicas = 3 - executable_build_options.num_partitions = 2 - deb_opt = executable_build_options.debug_options - deb_opt.xla_cpu_enable_fast_math = True - deb_opt.xla_test_all_input_layouts = True - deb_opt.xla_gpu_kernel_cache_file = "/foo/bar" - deb_opt.xla_gpu_enable_llvm_module_compilation_parallelism = True - deb_opt.xla_gpu_per_fusion_autotune_cache_dir = "/bar/foo/" - deb_opt.xla_gpu_experimental_autotune_cache_mode = ( - xla_client.AutotuneCacheMode.READ - ) - - b = options.SerializeAsString() - restored = xla_client.CompileOptions.ParseFromString(b) - - for name in ("num_replicas", "num_partitions", "profile_version", - "compile_portable_executable"): - self.assertEqual(getattr(options, name), getattr(restored, name), - msg=name) - - for name in ("num_replicas", "num_partitions"): - 
self.assertEqual(getattr(options.executable_build_options, name), - getattr(restored.executable_build_options, name), - msg=name) - - for name in ( - "xla_cpu_enable_fast_math", - "xla_test_all_input_layouts", - "xla_gpu_kernel_cache_file", - "xla_gpu_enable_llvm_module_compilation_parallelism", - "xla_gpu_per_fusion_autotune_cache_dir", - "xla_gpu_experimental_autotune_cache_mode", - ): - self.assertEqual( - getattr(options.executable_build_options.debug_options, name), - getattr(restored.executable_build_options.debug_options, name), - msg=name) - - tests.append(ClientTest) - - # TODO(b/182461453): Add TFRT and cloud TPU implementation of - # ReadDynamicShapes - @unittest.skip("Test fails HLO -> MHLO conversion") - class DynamicReshapeTest(ComputationTest): - """Tests related to DynamicReshape.""" - - def _CompareToPyAndBufferProtocol(self, builder, args, expected_results, - test_fn): - compiled = self.backend.compile( - xla_computation_to_mlir_module(builder.build())) - output_buffers = compiled.execute([ - self.backend.buffer_from_pyval( - arg, device=compiled.local_devices()[0]) for arg in args - ]) - self.assertLen(output_buffers, len(expected_results)) - for buf, expected in zip(output_buffers, expected_results): - to_py_result = np.asarray(buf) - self.assertEqual(expected.shape, to_py_result.shape) - test_fn(expected, to_py_result) - if self.backend.platform == "cpu" and buf.dtype != bfloat16: - mview = memoryview(buf) - self.assertEqual(expected.shape, mview.shape) - test_fn(expected, np.asarray(mview)) - else: - # Buffer protocol expected to fail on non-cpu platforms and bfloat16 - # Note that np.asarray(buf) doesn't throw an exception. To test if the - # error was thrown properly we must use memoryview(buf). - with self.assertRaises(BufferError): - memoryview(buf) - - # 1D reshape of full size, half size, and size of 0. - @unittest.skip("not implemented") - @parameterized.parameters((5), (3), (0)) - def testReshape1D(self, reshape_size): - full_size = 5 - c = self._NewComputation() - arg = np.array(reshape_size, dtype=np.int32) - expected = np.array(range(reshape_size), dtype=np.int32) - p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) - ops.DynamicReshape( - ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size], - [True]) - self._CompareToPyAndBufferProtocol(c, [arg], [expected], - np.testing.assert_equal) - - # 2D reshape with an slice on the minor dimension. We test different types - # where the strides may differ between the host and devices. The reshaped - # physical memory layout is not consecutive, and we test if the program can - # return the correct logical view of the data. 
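The layout subtlety described in the comment above is reproducible with plain NumPy; a minimal sketch, independent of XLA (values illustrative):

    import numpy as np

    x = np.arange(6, dtype=np.int32).reshape(2, 3)
    view = x[:, :2]                        # logical 2x2 slice of the minor dimension
    assert not view.flags["C_CONTIGUOUS"]  # the physical bytes are not consecutive
    np.testing.assert_array_equal(view, [[0, 1], [3, 4]])

The deleted test asks the runtime to return the same logical view even when the device strides differ from the host's.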
- @unittest.skipIf( - cloud_tpu or pathways or tfrt_tpu or pjrt_c_api, - "not implemented") - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in int_dtypes + float_dtypes) - def testReshape2D(self, dtype): - arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) - arg1 = np.array(2, dtype=np.int32) - expected = np.array([[1, 2], [4, 5]], dtype=np.int32) - c = self._NewComputation() - p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) - p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) - ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True]) - self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected], - np.testing.assert_equal) - - @unittest.skipIf(cloud_tpu or pathways or tfrt_tpu, "not implemented") - @parameterized.named_parameters({ - "testcase_name": "_{}".format(dtype.__name__), - "dtype": dtype, - } for dtype in int_dtypes + float_dtypes) - def testDynamicShapeArgs(self, dtype): - full_size = 10 - dynamic_shape_size = 4 - # subcomputation 1 - binary_add_builder = self._NewComputation() - scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype)) - ops.Add( - ops.Parameter(binary_add_builder, 0, scalar_shape), - ops.Parameter(binary_add_builder, 1, scalar_shape)) - # subcomputation 2 - reshape_reduce_builder = self._NewComputation() - dshape = xla_client.Shape.array_shape( - np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True]) - reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape) - ops.Reduce( - reshape_reduce_builder, - operands=[reshape_reduce_p], - init_values=[ops.Constant(reshape_reduce_builder, dtype(0))], - computation=binary_add_builder.build(), - dimensions_to_reduce=[0]) - # main computation: sum(range(full_size)[:dynamic_shape_size]) - c = self._NewComputation() - arg = np.array(dynamic_shape_size, dtype=np.int32) - p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) - reshaped = ops.DynamicReshape( - ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p], - [full_size], [True]) - ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,)) - self._ExecuteAndCompareClose(c, [arg], [dtype(6)]) - - tests.append(DynamicReshapeTest) - - class DeviceAssignmentTest(ComputationTest): - - def testSerialize(self): - shape = (3, 4) - device_assignment = xla_client.DeviceAssignment.create( - np.arange(np.prod(shape)).reshape(*shape)) - self.assertEqual(device_assignment.replica_count(), shape[0]) - self.assertEqual(device_assignment.computation_count(), shape[1]) - serialized = device_assignment.serialize() - self.assertIsInstance(serialized, bytes) - self.assertNotEmpty(serialized) - - tests.append(DeviceAssignmentTest) - - class TokenTest(ComputationTest): - """Tests related to PyToken.""" - - def testExecuteWithToken(self): - c = self._NewComputation() - ops.Mul( - ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), - ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) - compiled_c = self.backend.compile( - xla_computation_to_mlir_module(c.build())) - results, token = compiled_c.execute_with_token([]) - token.block_until_ready() - self.assertLen(results, 1) - np.testing.assert_allclose( - np.asarray(results[0]), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3) - - def testExecuteShardedOnLocalDevicesWithTokens(self): - c = self._NewComputation() - ops.Mul( - ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), - ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) - num_replicas = 1 - options = xla_client.CompileOptions() 
- options.num_replicas = num_replicas - compiled_c = self.backend.compile( - xla_computation_to_mlir_module(c.build()), compile_options=options) - py_results = compiled_c.execute_sharded([], with_tokens=True) - results = py_results.disassemble_into_single_device_arrays() - sharded_token = py_results.consume_token() - sharded_token.block_until_ready() - self.assertLen(results, 1) - self.assertLen(results[0], 1) - np.testing.assert_allclose( - np.asarray(results[0][0]), - np.float32([-3, 6.6, 2.4, -2.1]), - rtol=3e-3) - - tests.append(TokenTest) - - class ExecutePortableTest(ComputationTest): - - @unittest.skip("Test does not work under IFRT") - def testExecutePortable(self): - devices_by_kind = collections.defaultdict(list) - for device in self.backend.devices(): - devices_by_kind[device.device_kind].append(device) - multi_devices = [d for d in devices_by_kind.values() if len(d) > 1] - if not multi_devices: - raise unittest.SkipTest("Test needs multiple identical devices") - devices = multi_devices[0] - - c = self._NewComputation() - args = [ - np.array(3, dtype=np.int32), - np.array([10, 15, -2, 7], dtype=np.int32) - ] - p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(args[0])) - p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(args[1])) - ops.Mul(p0, p1) - options = xla_client.CompileOptions() - options.compile_portable_executable = True - compiled_c = self.backend.compile(c.build(), compile_options=options) - for device in devices: - out, = compiled_c.execute( - [self.backend.buffer_from_pyval(a, device=device) for a in args], - device=device) - np.testing.assert_array_equal(np.asarray(out), args[0] * args[1]) - - tests.append(ExecutePortableTest) - - class ExecuteShardedOverloadTest(ComputationTest): - - def testExecuteShardedOverloadEmptyInput(self): - c = self._NewComputation() - ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)) - options = xla_client.CompileOptions() - options.num_replicas = 1 - compiled_c = self.backend.compile( - xla_computation_to_mlir_module(c.build()), compile_options=options) - - results = compiled_c.execute_sharded( - []).disassemble_into_single_device_arrays() - self.assertLen(results, 1) - self.assertIsInstance(results[0], list) - self.assertLen(results[0], 1) - results[0][0].block_until_ready() - self.assertIsInstance(results[0][0], xla_client.ArrayImpl) - - results = compiled_c.execute_sharded( - [], with_tokens=True).disassemble_into_single_device_arrays() - self.assertLen(results, 1) - self.assertIsInstance(results[0], list) - self.assertLen(results[0], 1) - results[0][0].block_until_ready() - self.assertIsInstance(results[0][0], xla_client.ArrayImpl) - - def testExecuteShardedOverloadBufferInput(self): - arg = np.arange(12, dtype=np.int16).reshape(3, 4) - c = self._NewComputation() - ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) - - options = xla_client.CompileOptions() - options.num_replicas = 1 - compiled_c = self.backend.compile( - xla_computation_to_mlir_module(c.build()), compile_options=options) - - buffer = self.backend.buffer_from_pyval(arg) - - results = compiled_c.execute_sharded( - [[buffer]]).disassemble_into_single_device_arrays() - self.assertLen(results, 1) - self.assertIsInstance(results[0], list) - self.assertLen(results[0], 1) - results[0][0].block_until_ready() - self.assertIsInstance(results[0][0], xla_client.ArrayImpl) - - results = compiled_c.execute_sharded( - [[buffer]], with_tokens=True).disassemble_into_single_device_arrays() - self.assertLen(results, 1) - self.assertIsInstance(results[0], list) - 
self.assertLen(results[0], 1) - results[0][0].block_until_ready() - self.assertIsInstance(results[0][0], xla_client.ArrayImpl) - - tests.append(ExecuteShardedOverloadTest) - - return tests - - -def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): - # Avoid creating a new backend per test (this causes GPU OOM, and is probably - # inefficient). - backend_fn = functools.lru_cache(maxsize=None)(backend_fn) - for klass in TestFactory(backend_fn, **kw): - test = type(test_prefix + klass.__name__, (klass,), {}) - # Clean up the qualified names of the tests to not include the test factory. - test.__qualname__ = test.__name__ - globals_dict[test.__name__] = test - - -backends = { - "cpu": functools.partial(xla_client.make_cpu_client, num_devices=2), -} - -if __name__ == "__main__": - flags.DEFINE_string("backend", "cpu", "Target platform.") - jax.config.parse_flags_with_absl() - # pylint: disable=unnecessary-lambda - InstantiateTests(globals(), lambda: backends[FLAGS.backend]()) - # pylint: enable=unnecessary-lambda - absltest.main() From a86442a531f847ec1900fcaab5f03429e16f28bd Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Fri, 25 Apr 2025 15:06:25 +0000 Subject: [PATCH 0823/1769] Add argnums param to fwd_and_bwd --- jax/_src/api.py | 28 ++++++++++++++++++++-------- tests/api_test.py | 2 +- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4b608fd64a89..d37e1c8cfd01 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -576,8 +576,9 @@ def _check_output_dtype_revderiv(name, holomorphic, x): _check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad") def fwd_and_bwd( - fun: Callable, has_aux: bool = False, jitted: bool = True - ) -> tuple[Callable, Callable]: + fun: Callable, argnums: int | Sequence[int], has_aux: bool = False, + jitted: bool = True, +) -> tuple[Callable, Callable]: """Creates functions ``fwd`` and ``bwd`` corresponding to the forward and backward pass of a given function ``fun``. The forward function ``fwd(*args)`` functionally behaves much like ``y, fun_vjp = jax.vjp(fun, *args)``, but allows @@ -598,7 +599,7 @@ def fwd_and_bwd( ... cot_x, cot_W = f_vjp(cot_out) # not jitted ... cot_x, cot_W = jax.jit(f_vjp)(cot_out) # recompiles on every iteration ... - >>> fwd, bwd = jax.fwd_and_bwd(f) + >>> fwd, bwd = jax.fwd_and_bwd(f, argnums=(0,1)) >>> for i in range(3): ... y, residuals = fwd(x, W) ... cot_x, cot_W = bwd(residuals, cot_out) # jitted, compiles once @@ -606,6 +607,8 @@ def fwd_and_bwd( Args: fun: Function to produce a forward and backward of. + argnums: Integer or sequence of integers. Specifies which positional argument(s) + to differentiate with respect to. has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. @@ -624,13 +627,22 @@ def fwd_and_bwd( ``bwd`` is a function from ``residuals`` and a cotangent vector with the same shape as ``primals_out`` to a tuple of cotangent vectors with the same number - and shapes as ``primals``, representing the vector-Jacobian product of ``fun`` - evaluated at ``primals``. + and shapes as the ``primals`` designated by ``argnums``, representing the + vector-Jacobian product of ``fun`` evaluated at ``primals``. 
""" - def fwd(*args): - return vjp(fun, *args, has_aux=has_aux) # type: ignore + check_callable(fun) + argnums = _ensure_index(argnums) + + def fwd(*args, **kwargs): + dbg = debug_info('fwd_and_bwd', fun, args, kwargs) + f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) + f_partial, dyn_args = argnums_partial( + f, argnums, args, require_static_args_hashable=False) + return _vjp(f_partial, *dyn_args, has_aux=has_aux) # type: ignore def bwd(f_vjp, outgrad): - return f_vjp(outgrad) + g = f_vjp(outgrad) + g = g[0] if isinstance(argnums, int) else g + return g if jitted: fwd = jit(fwd) bwd = jit(bwd) diff --git a/tests/api_test.py b/tests/api_test.py index 7a1218d3790a..ee9826896082 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1636,7 +1636,7 @@ def f(x, W): expected_y, f_vjp = api.vjp(f, x, W) expected_cot_x, expected_cot_W = f_vjp(cot_out) - fwd, bwd = api.fwd_and_bwd(f) + fwd, bwd = api.fwd_and_bwd(f, argnums=(0,1)) y, residuals = fwd(x, W) cot_x, cot_W = bwd(residuals, cot_out) From 2ecfc24dc861781fdadcdb4af50d60dfcb97697c Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Fri, 25 Apr 2025 08:15:46 -0700 Subject: [PATCH 0824/1769] [CI] Correct the optional GPU presubmit to work properly with workflow dispatch PiperOrigin-RevId: 751421931 --- .github/workflows/bazel_optional_b200.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_optional_b200.yml b/.github/workflows/bazel_optional_b200.yml index 1620022965f9..55c56495d6d9 100644 --- a/.github/workflows/bazel_optional_b200.yml +++ b/.github/workflows/bazel_optional_b200.yml @@ -25,7 +25,7 @@ concurrency: cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} jobs: run_tests: - if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} runs-on: linux-x86-a4-224-b200-1gpu container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' name: "Bazel single B200 CUDA tests" From c0389ffa640ecd0ed71c07a60de99922958556cd Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Fri, 25 Apr 2025 08:50:35 -0700 Subject: [PATCH 0825/1769] #jax #sdy pass the value of `use_shardy_partitioner` to `get_compile_options` in `UnloadedPmapExecutable`. Otherwise a Shardy lowered JAX module won't be able to go to HLO since it thinks Shardy is disabled. 
PiperOrigin-RevId: 751435539
---
 jax/_src/interpreters/pxla.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 66465bd152a8..f337c996d3e3 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1064,6 +1064,7 @@ def from_hlo(hlo: ir.Module,
       num_partitions=num_partitions,
       device_assignment=device_assignment,
       use_spmd_partitioning=False,
+      use_shardy_partitioner=config.use_shardy_partitioner.value,
       env_options_overrides=compiler_options,
       detailed_logging=compiler.use_detailed_logging(hlo),
       backend=pci.backend,

From b359ea79ad6bcf284aa8c9d8c94595788a5f9dcd Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Fri, 25 Apr 2025 09:15:17 -0700
Subject: [PATCH 0826/1769] [Mosaic GPU] Wire up the plugin that contains the
 MGPU custom call into Bazel tests

PiperOrigin-RevId: 751445068
---
 jax_plugins/cuda/BUILD.bazel | 14 +++++++-------
 jax_plugins/cuda/__init__.py |  2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel
index 1f4e5a08dcb9..6566cfc62b0c 100644
--- a/jax_plugins/cuda/BUILD.bazel
+++ b/jax_plugins/cuda/BUILD.bazel
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-licenses(["notice"])
-
 load(
-  "//jaxlib:jax.bzl",
-  "if_windows",
-  "py_library_providing_imports_info",
-  "pytype_library",
+    "//jaxlib:jax.bzl",
+    "if_windows",
+    "py_library_providing_imports_info",
+    "pytype_library",
 )
 
+licenses(["notice"])
+
 package(
     default_applicable_licenses = [],
     default_visibility = ["//:__subpackages__"],
@@ -41,7 +41,7 @@ py_library_providing_imports_info(
     ],
     data = if_windows(
         ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
-        ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
+        ["//jaxlib/tools:pjrt_c_api_gpu_plugin.so"],
     ),
     lib_rule = pytype_library,
 )

diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py
index 13293de7181d..4891fbeb3332 100644
--- a/jax_plugins/cuda/__init__.py
+++ b/jax_plugins/cuda/__init__.py
@@ -51,7 +51,7 @@ def _get_library_path():
   runfiles_dir = os.getenv('RUNFILES_DIR', None)
   if runfiles_dir:
     local_path = os.path.join(
-        runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so'
+        runfiles_dir, '__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so'
    )
     if os.path.exists(local_path):

From 10f7344cdfa09bd43efce67d43f8cb044313d835 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 23 Apr 2025 15:29:36 -0700
Subject: [PATCH 0827/1769] jax.numpy: move linspace & friends into
 array_creation

---
 jax/_src/numpy/array_creation.py | 321 +++++++++++++++++++++++++++++-
 jax/_src/numpy/lax_numpy.py      | 322 +------------------------------
 jax/_src/numpy/util.py           |   7 +
 jax/numpy/__init__.py            |   6 +-
 4 files changed, 331 insertions(+), 325 deletions(-)

diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py
index 4f07f94fe8b4..86bcfb2c02f6 100644
--- a/jax/_src/numpy/array_creation.py
+++ b/jax/_src/numpy/array_creation.py
@@ -13,18 +13,23 @@
 # limitations under the License.
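The move in this patch is intended to be behavior-preserving; only the defining module of these functions changes. A quick smoke check, assuming any JAX release spanning the move:

    import jax.numpy as jnp

    print(jnp.linspace(0, 1, 3))   # Array([0. , 0.5, 1. ], dtype=float32)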
import types -from typing import Any +from functools import partial +import operator +from typing import Any, Literal, overload import numpy as np import jax from jax import lax +from jax._src.api import jit from jax._src import core from jax._src import dtypes +from jax._src.lax import lax as lax_internal from jax._src.lib import xla_client as xc +from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike -from jax._src.util import set_module +from jax._src.util import canonicalize_axis, set_module from jax.sharding import Sharding @@ -405,3 +410,315 @@ def full_like(a: ArrayLike | DuckTypedArray, dtype = dtypes.result_type(a) if dtype is None else dtype return jax.device_put( util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: Literal[False] = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int, + endpoint: bool, retstep: Literal[True], + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, *, retstep: Literal[True], + dtype: DTypeLike | None = None, + axis: int = 0, + device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@export +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Return evenly-spaced numbers within an interval. + + JAX implementation of :func:`numpy.linspace`. + + Args: + start: scalar or array of starting values. + stop: scalar or array of stop values. + num: number of values to generate. Default: 50. + endpoint: if True (default) then include the ``stop`` value in the result. + If False, then exclude the ``stop`` value. + retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the + interval between adjacent values in ``result``. + axis: integer axis along which to generate the linspace. Defaults to zero. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: + + - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` + - ``step`` is the interval between adjacent values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List of 5 values between 0 and 10: + + >>> jnp.linspace(0, 10, 5) + Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) + + List of 8 values between 0 and 10, excluding the endpoint: + + >>> jnp.linspace(0, 10, 8, endpoint=False) + Array([0. 
, 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) + + List of values and the step size between them + + >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) + >>> vals + Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) + >>> step + Array(1.25, dtype=float32) + + Multi-dimensional linspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 10]) + >>> jnp.linspace(start, stop, 5) + Array([[ 0. , 5. ], + [ 1.25, 6.25], + [ 2.5 , 7.5 ], + [ 3.75, 8.75], + [ 5. , 10. ]], dtype=float32) + """ + num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") + return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) + +@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device')) +def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Implementation of linspace differentiable in start and stop args.""" + dtypes.check_user_dtype_supported(dtype, "linspace") + if num < 0: + raise ValueError(f"Number of samples, {num}, must be non-negative.") + start, stop = util.ensure_arraylike("linspace", start, stop) + + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) + dtype = dtypes.jax_dtype(dtype) + computation_dtype = dtypes.to_inexact_dtype(dtype) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) + + bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) + broadcast_start = util._broadcast_to(start, bounds_shape) + broadcast_stop = util._broadcast_to(stop, bounds_shape) + axis = len(bounds_shape) + axis + 1 if axis < 0 else axis + bounds_shape.insert(axis, 1) + div = (num - 1) if endpoint else num + if num > 1: + delta: Array = lax.convert_element_type(stop - start, computation_dtype) / jax.numpy.array(div, dtype=computation_dtype) + iota_shape = [1,] * len(bounds_shape) + iota_shape[axis] = div + # This approach recovers the endpoints with float32 arithmetic, + # but can lead to rounding errors for integer outputs. 
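# Illustrative aside, not part of the patch: a concrete trace of the
# interpolation above, for linspace(0.0, 10.0, num=5, endpoint=True):
#   div  = 4
#   step = iota(4) / 4                      -> [0.0, 0.25, 0.5, 0.75]
#   out  = 0.0 * (1 - step) + 10.0 * step   -> [0.0, 2.5, 5.0, 7.5]
# and the exact stop value 10.0 is concatenated on afterwards, which is how
# the endpoints survive float32 arithmetic unchanged.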
+ real_dtype = dtypes.finfo(computation_dtype).dtype + step = lax.iota(real_dtype, div).reshape(iota_shape) / jax.numpy.array(div, real_dtype) + step = step.astype(computation_dtype) + out = (broadcast_start.reshape(bounds_shape) * (1 - step) + + broadcast_stop.reshape(bounds_shape) * step) + + if endpoint: + out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))], + canonicalize_axis(axis, out.ndim)) + + elif num == 1: + delta = jax.numpy.asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) + out = broadcast_start.reshape(bounds_shape) + else: # num == 0 degenerate case, match numpy behavior + empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) + empty_shape.insert(axis, 0) + delta = full((), np.nan, computation_dtype) + out = empty(empty_shape, dtype) + + if dtypes.issubdtype(dtype, np.integer) and not dtypes.issubdtype(out.dtype, np.integer): + out = lax.floor(out) + + sharding = util.canonicalize_device_to_sharding(device) + result = lax_internal._convert_element_type(out, dtype, sharding=sharding) + return (result, delta) if retstep else result + + +@export +def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, base: ArrayLike = 10.0, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate logarithmically-spaced values. + + JAX implementation of :func:`numpy.logspace`. + + Args: + start: scalar or array. Used to specify the start value. The start value is + ``base ** start``. + stop: scalar or array. Used to specify the stop value. The end value is + ``base ** stop``. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + base: scalar or array, optional, default=10. Specifies the base of the logarithm. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the logspace. + + Returns: + An array of logarithm. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 + (``10 ** 2``): + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5) + Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) + + List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 + (``10 ** 2``), excluding endpoint: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5, endpoint=False) + Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) + + List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) + with base 2: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 7, base=2) + Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) + + Multi-dimensional logspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 0]) + >>> base = jnp.array([2, 3]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(start, stop, 5, base=base) + Array([[ 1. , 243. ], + [ 2.378, 61.547], + [ 5.657, 15.588], + [ 13.454, 3.948], + [ 32. , 1. 
]], dtype=float32)
+  """
+  num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace")
+  axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace")
+  return _logspace(start, stop, num, endpoint, base, dtype, axis)
+
+@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
+def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
+              endpoint: bool = True, base: ArrayLike = 10.0,
+              dtype: DTypeLike | None = None, axis: int = 0) -> Array:
+  """Implementation of logspace differentiable in start and stop args."""
+  dtypes.check_user_dtype_supported(dtype, "logspace")
+  if dtype is None:
+    dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop))
+  dtype = dtypes.jax_dtype(dtype)
+  computation_dtype = dtypes.to_inexact_dtype(dtype)
+  start, stop = util.ensure_arraylike("logspace", start, stop)
+  start = start.astype(computation_dtype)
+  stop = stop.astype(computation_dtype)
+  lin = linspace(start, stop, num,
+                 endpoint=endpoint, retstep=False, dtype=None, axis=axis)
+  return lax.convert_element_type(ufuncs.power(base, lin), dtype)
+
+
+@export
+def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True,
+              dtype: DTypeLike | None = None, axis: int = 0) -> Array:
+  """Generate geometrically-spaced values.
+
+  JAX implementation of :func:`numpy.geomspace`.
+
+  Args:
+    start: scalar or array. Specifies the starting values.
+    stop: scalar or array. Specifies the stop values.
+    num: int, optional, default=50. Number of values to generate.
+    endpoint: bool, optional, default=True. If True, then include the ``stop`` value
+      in the result. If False, then exclude the ``stop`` value.
+    dtype: optional. Specifies the dtype of the output.
+    axis: int, optional, default=0. Axis along which to generate the geomspace.
+
+  Returns:
+    An array containing the geometrically-spaced values.
+
+  See also:
+    - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting
+      point and a step value.
+    - :func:`jax.numpy.linspace`: Generate evenly-spaced values.
+    - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values.
+
+  Examples:
+    List 5 geometrically-spaced values between 1 and 16:
+
+    >>> with jnp.printoptions(precision=3, suppress=True):
+    ...   jnp.geomspace(1, 16, 5)
+    Array([ 1.,  2.,  4.,  8., 16.], dtype=float32)
+
+    List 4 geometrically-spaced values between 1 and 16, with ``endpoint=False``:
+
+    >>> with jnp.printoptions(precision=3, suppress=True):
+    ...   jnp.geomspace(1, 16, 4, endpoint=False)
+    Array([1., 2., 4., 8.], dtype=float32)
+
+    Multi-dimensional geomspace:
+
+    >>> start = jnp.array([1, 1000])
+    >>> stop = jnp.array([27, 1])
+    >>> with jnp.printoptions(precision=3, suppress=True):
+    ...   jnp.geomspace(start, stop, 4)
+    Array([[   1., 1000.],
+           [   3.,  100.],
+           [   9.,   10.],
+           [  27.,    1.]], dtype=float32)
+  """
+  num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace")
+  axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace")
+  return _geomspace(start, stop, num, endpoint, dtype, axis)
+
+@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
+def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True,
+               dtype: DTypeLike | None = None, axis: int = 0) -> Array:
+  """Implementation of geomspace differentiable in start and stop args."""
+  dtypes.check_user_dtype_supported(dtype, "geomspace")
+  if dtype is None:
+    dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop))
+  dtype = dtypes.jax_dtype(dtype)
+  computation_dtype = dtypes.to_inexact_dtype(dtype)
+  start, stop = util.ensure_arraylike("geomspace", start, stop)
+  start = start.astype(computation_dtype)
+  stop = stop.astype(computation_dtype)
+
+  sign = ufuncs.sign(start)
+  res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign),
+                        num, endpoint=endpoint, base=10.0,
+                        dtype=computation_dtype, axis=0)
+  axis = canonicalize_axis(axis, res.ndim)
+  if axis != 0:
+    # res = moveaxis(res, 0, axis)
+    res = lax.transpose(res, permutation=(*range(1, axis + 1), 0, *range(axis + 1, res.ndim)))
+  return lax.convert_element_type(res, dtype)
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index a28a6d94e3eb..7421111b23a4 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -54,11 +54,10 @@
 from jax._src.numpy import tensor_contractions
 from jax._src.numpy import ufuncs
 from jax._src.numpy import util
-from jax._src.numpy.array_creation import (empty, empty_like, full,
+from jax._src.numpy.array_creation import (empty, empty_like, full, linspace,
                                            ones, ones_like, zeros, zeros_like)
 from jax._src.numpy.sorting import argsort, sort
 from jax._src.numpy.vectorize import vectorize
-from jax._src.sharding_impls import SingleDeviceSharding
 from jax._src.typing import (
     Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, SupportsShape
 )
@@ -5452,7 +5451,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
       sharding = object.aval.sharding
       sharding = None if sharding.mesh.empty else sharding
     else:
-      sharding = canonicalize_device_to_sharding(device)
+      sharding = util.canonicalize_device_to_sharding(device)
 
     # Use device_put to avoid a copy for ndarray inputs.
     if (not copy and isinstance(object, np.ndarray) and
@@ -5552,13 +5551,6 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
   return out_array
 
 
-def canonicalize_device_to_sharding(device: xc.Device | Sharding | None
-                                    ) -> Sharding | None:
-  if isinstance(device, xc.Device):
-    return SingleDeviceSharding(device)
-  return device
-
-
 def _get_platform(
     device_or_sharding: xc.Device | Sharding | None | str) -> str:
   """Get device_or_sharding platform or look up config.default_device.value."""
@@ -6386,316 +6378,6 @@ def _arange_dynamic(
   return (array(start, dtype=dtype) +
           array(step, dtype=dtype) * lax.iota(dtype, size))
 
-@overload
-def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
-             endpoint: bool = True, retstep: Literal[False] = False,
-             dtype: DTypeLike | None = None,
-             axis: int = 0,
-             *, device: xc.Device | Sharding | None = None) -> Array: ...
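# Illustrative aside, not part of the patch: the sign factor in `_geomspace`
# above lets both endpoints be negative, e.g. for start=-1.0, stop=-1000.0:
#   sign = -1.0
#   -1.0 * logspace(log10(-1.0 / -1.0), log10(-1000.0 / -1.0), 4)
#     = -1.0 * logspace(0.0, 3.0, 4)
#     = [-1., -10., -100., -1000.]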
-@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int, - endpoint: bool, retstep: Literal[True], - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, *, retstep: Literal[True], - dtype: DTypeLike | None = None, - axis: int = 0, - device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... -@export -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: - """Return evenly-spaced numbers within an interval. - - JAX implementation of :func:`numpy.linspace`. - - Args: - start: scalar or array of starting values. - stop: scalar or array of stop values. - num: number of values to generate. Default: 50. - endpoint: if True (default) then include the ``stop`` value in the result. - If False, then exclude the ``stop`` value. - retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the - interval between adjacent values in ``result``. - axis: integer axis along which to generate the linspace. Defaults to zero. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: - - - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` - - ``step`` is the interval between adjacent values. - - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step - - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. - - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. - - Examples: - List of 5 values between 0 and 10: - - >>> jnp.linspace(0, 10, 5) - Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) - - List of 8 values between 0 and 10, excluding the endpoint: - - >>> jnp.linspace(0, 10, 8, endpoint=False) - Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) - - List of values and the step size between them - - >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) - >>> vals - Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) - >>> step - Array(1.25, dtype=float32) - - Multi-dimensional linspace: - - >>> start = jnp.array([0, 5]) - >>> stop = jnp.array([5, 10]) - >>> jnp.linspace(start, stop, 5) - Array([[ 0. , 5. ], - [ 1.25, 6.25], - [ 2.5 , 7.5 ], - [ 3.75, 8.75], - [ 5. , 10. 
]], dtype=float32) - """ - num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") - return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) - -@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device')) -def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: - """Implementation of linspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "linspace") - if num < 0: - raise ValueError(f"Number of samples, {num}, must be non-negative.") - start, stop = util.ensure_arraylike("linspace", start, stop) - - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - - bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) - broadcast_start = broadcast_to(start, bounds_shape) - broadcast_stop = broadcast_to(stop, bounds_shape) - axis = len(bounds_shape) + axis + 1 if axis < 0 else axis - bounds_shape.insert(axis, 1) - div = (num - 1) if endpoint else num - if num > 1: - delta: Array = lax.convert_element_type(stop - start, computation_dtype) / array(div, dtype=computation_dtype) - iota_shape = [1,] * len(bounds_shape) - iota_shape[axis] = div - # This approach recovers the endpoints with float32 arithmetic, - # but can lead to rounding errors for integer outputs. - real_dtype = finfo(computation_dtype).dtype - step = reshape(lax.iota(real_dtype, div), iota_shape) / array(div, real_dtype) - step = step.astype(computation_dtype) - out = (reshape(broadcast_start, bounds_shape) * (1 - step) + - reshape(broadcast_stop, bounds_shape) * step) - - if endpoint: - out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))], - _canonicalize_axis(axis, out.ndim)) - - elif num == 1: - delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) - out = reshape(broadcast_start, bounds_shape) - else: # num == 0 degenerate case, match numpy behavior - empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) - empty_shape.insert(axis, 0) - delta = asarray(np.nan, dtype=computation_dtype) - out = reshape(array([], dtype=dtype), empty_shape) - - if issubdtype(dtype, np.integer) and not issubdtype(out.dtype, np.integer): - out = lax.floor(out) - - sharding = canonicalize_device_to_sharding(device) - result = lax_internal._convert_element_type(out, dtype, sharding=sharding) - return (result, delta) if retstep else result - - -@export -def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, base: ArrayLike = 10.0, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Generate logarithmically-spaced values. - - JAX implementation of :func:`numpy.logspace`. - - Args: - start: scalar or array. Used to specify the start value. The start value is - ``base ** start``. - stop: scalar or array. Used to specify the stop value. The end value is - ``base ** stop``. - num: int, optional, default=50. Number of values to generate. - endpoint: bool, optional, default=True. If True, then include the ``stop`` value - in the result. 
If False, then exclude the ``stop`` value. - base: scalar or array, optional, default=10. Specifies the base of the logarithm. - dtype: optional. Specifies the dtype of the output. - axis: int, optional, default=0. Axis along which to generate the logspace. - - Returns: - An array of logarithm. - - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step value. - - :func:`jax.numpy.linspace`: Generate evenly-spaced values. - - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. - - Examples: - List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 - (``10 ** 2``): - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 5) - Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) - - List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 - (``10 ** 2``), excluding endpoint: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 5, endpoint=False) - Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) - - List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) - with base 2: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 7, base=2) - Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) - - Multi-dimensional logspace: - - >>> start = jnp.array([0, 5]) - >>> stop = jnp.array([5, 0]) - >>> base = jnp.array([2, 3]) - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(start, stop, 5, base=base) - Array([[ 1. , 243. ], - [ 2.378, 61.547], - [ 5.657, 15.588], - [ 13.454, 3.948], - [ 32. , 1. ]], dtype=float32) - """ - num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") - return _logspace(start, stop, num, endpoint, base, dtype, axis) - -@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) -def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, base: ArrayLike = 10.0, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Implementation of logspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "logspace") - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start, stop = util.ensure_arraylike("logspace", start, stop) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - lin = linspace(start, stop, num, - endpoint=endpoint, retstep=False, dtype=None, axis=axis) - return lax.convert_element_type(ufuncs.power(base, lin), dtype) - - -@export -def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Generate geometrically-spaced values. - - JAX implementation of :func:`numpy.geomspace`. - - Args: - start: scalar or array. Specifies the starting values. - stop: scalar or array. Specifies the stop values. - num: int, optional, default=50. Number of values to generate. - endpoint: bool, optional, default=True. If True, then include the ``stop`` value - in the result. If False, then exclude the ``stop`` value. - dtype: optional. Specifies the dtype of the output. - axis: int, optional, default=0. Axis along which to generate the geomspace. - - Returns: - An array containing the geometrically-spaced values. 
- - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step value. - - :func:`jax.numpy.linspace`: Generate evenly-spaced values. - - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. - - Examples: - List 5 geometrically-spaced values between 1 and 16: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(1, 16, 5) - Array([ 1., 2., 4., 8., 16.], dtype=float32) - - List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(1, 16, 4, endpoint=False) - Array([1., 2., 4., 8.], dtype=float32) - - Multi-dimensional geomspace: - - >>> start = jnp.array([1, 1000]) - >>> stop = jnp.array([27, 1]) - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(start, stop, 4) - Array([[ 1., 1000.], - [ 3., 100.], - [ 9., 10.], - [ 27., 1.]], dtype=float32) - """ - num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") - return _geomspace(start, stop, num, endpoint, dtype, axis) - -@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) -def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Implementation of geomspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "geomspace") - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start, stop = util.ensure_arraylike("geomspace", start, stop) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - - sign = ufuncs.sign(start) - res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), - num, endpoint=endpoint, base=10.0, - dtype=computation_dtype, axis=0) - if axis != 0: - res = moveaxis(res, 0, axis) - return lax.convert_element_type(res, dtype) - @export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 8c8575f7c010..367d06065842 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -247,6 +247,13 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]: return promote_shapes(fun_name, *promote_dtypes_inexact(*args)) +def canonicalize_device_to_sharding(device: xc.Device | Sharding | None + ) -> Sharding | None: + if isinstance(device, xc.Device): + return SingleDeviceSharding(device) + return device + + @partial(api.jit, inline=True) def _broadcast_arrays(*args: ArrayLike) -> list[Array]: """Like Numpy's broadcast_arrays but doesn't return views.""" diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index b6cfb1ff06ac..935fbcaa708c 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -93,7 +93,6 @@ fromstring as fromstring, from_dlpack as from_dlpack, gcd as gcd, - geomspace as geomspace, get_printoptions as get_printoptions, gradient as gradient, histogram as histogram, @@ -118,9 +117,7 @@ ix_ as ix_, kron as kron, lcm as lcm, - linspace as linspace, load as load, - logspace as logspace, mask_indices as mask_indices, matrix_transpose as matrix_transpose, meshgrid as meshgrid, @@ -180,6 +177,9 @@ empty_like as empty_like, full as full, full_like as full_like, + geomspace as geomspace, 
+ linspace as linspace, + logspace as logspace, ones as ones, ones_like as ones_like, zeros as zeros, From a48408ab7040fb496dee4d33865a952157f4ae28 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Fri, 25 Apr 2025 16:51:05 +0000 Subject: [PATCH 0828/1769] add retry to certain K8s-related calls to overcome transient failure --- jax/_src/clusters/k8s_cluster.py | 76 +++++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 11 deletions(-) diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index 11f93e36f647..bf3cc2bc5702 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -16,12 +16,50 @@ from contextlib import contextmanager from functools import cache +from itertools import chain +import numpy as np import os import socket +import time import textwrap import warnings from jax._src import clusters +import logging + +logger = logging.getLogger(__name__) + + +def retry( + func=None, + initial_delay=0, + wait=np.logspace(-1, 1, 5) * np.random.rand(5), + exceptions=Exception, +): + def retry_decorator(func): + def retry_driver(*args, **kwargs): + # Retry the function call with exponential backoff + for i, t in enumerate(chain([initial_delay], wait)): + logger.debug( + f"Trying {func.__name__} in {t:.2f} seconds, attempt {i}/{len(wait)}" + ) + time.sleep(t) + try: + return func(*args, **kwargs) + except exceptions as e: + if i == len(wait): + raise RuntimeError('Retry failed with all attempts exhausted') from e + finally: + logger.debug( + f"Finished {func.__name__} after {i+1} attempts" + ) + return retry_driver + + if func is None: + return retry_decorator + else: + return retry_decorator(func) + class K8sCluster(clusters.ClusterEnv): @@ -83,16 +121,16 @@ def _namespace(cls): @classmethod @cache + # in case of latency for core DNS to update pod IP to etcd/API server + @retry(exceptions=ValueError) def _pod(cls): + ip = socket.gethostbyname(os.getenv('HOSTNAME')) with cls._handle_api_exception(): - ip = socket.gethostbyname(os.getenv('HOSTNAME')) - pods = cls._core_api.list_namespaced_pod( + [pod] = cls._core_api.list_namespaced_pod( namespace=cls._namespace(), field_selector=f'status.podIP={ip}' ).items - assert len(pods) == 1, \ - f"Exactly 1 Kubernetes pod should have IP {ip}, got {len(pods)}." 
- return pods[0] + return pod @classmethod @cache @@ -140,10 +178,9 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str: if controller.kind == 'Job': # if job belongs to a jobset if 'jobset.sigs.k8s.io/jobset-name' in job.metadata.labels: - return '{job_name}-0.{subdomain}:{port}'.format( + coordinator_hostname = '{job_name}-0.{subdomain}'.format( job_name=job.metadata.name, - subdomain=job.metadata.labels['jobset.sigs.k8s.io/jobset-name'], - port=cls._coordinator_port + subdomain=job.metadata.labels['jobset.sigs.k8s.io/jobset-name'] ) # if job is standalone else: @@ -197,12 +234,29 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str: raise RuntimeError('\n'.join([textwrap.fill(err_msg)] + fix_msg)) - return '{job_name}-0.{subdomain}:{port}'.format( + coordinator_hostname = '{job_name}-0.{subdomain}'.format( job_name=job.metadata.name, - subdomain=pod.spec.subdomain, - port=cls._coordinator_port + subdomain=pod.spec.subdomain ) + if timeout_secs: + # Ensure host pod is up before trying to communicate + # Retry in case of cached NXDOMAIN DNS failure (30 secs default) + @retry( + initial_delay=0.5, + wait=np.logspace(-1, 1.5, 8) * np.random.rand(8), + exceptions=socket.gaierror + ) + def wait_for_host(hostname): + socket.gethostbyname(hostname) + + wait_for_host(coordinator_hostname) + + return '{hostname}:{port}'.format( + hostname=coordinator_hostname, + port=cls._coordinator_port + ) + else: raise RuntimeError( 'In K8s, cluster automatic bootstrap only supports Job/JobSet.' From bbcb5e57e5cda88d10bed7b222459918553019b1 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 25 Apr 2025 09:56:59 -0700 Subject: [PATCH 0829/1769] Make callbacks work in unrolled loops on TPU. Before this change, we cached the lowering for the body function in scan. This causes problems on TPU when the body includes a callback and the loop is unrolled because we end up with multiple callbacks with the same channel id. Since the channel id must be globally unique, and the id is assigned when lowering, we can't cache in this case (on TPU when there are callbacks in the jaxpr). Since the previous lowering would unconditionally fail when compiling, I don't anticipate any performance regressions. PiperOrigin-RevId: 751459658 --- jax/_src/interpreters/mlir.py | 9 ++++++++- jax/_src/pjit.py | 9 ++++++++- tests/debugging_primitives_test.py | 17 ++++++++++++++++- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a6debad7cbdb..c7a53b1b4260 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2241,10 +2241,17 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack, try: func_op = ctx.cached_primitive_lowerings[key] except KeyError: + num_callbacks = len(ctx.host_callbacks) func_op = lower_jaxpr_to_fun( ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names, result_names=result_names) - ctx.cached_primitive_lowerings[key] = func_op + + # If this Jaxpr includes callbacks, we can't cache the lowering because + # on TPU every callback must have a globally unique channel, but the + # channel gets assigned during lowering. 
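+  # (For example, a fori_loop unrolled on TPU whose body calls
+  # jax.debug.print would otherwise reuse one channel id for each of the
+  # unrolled callbacks; see the test added below.)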
+ has_callbacks = len(ctx.host_callbacks) > num_callbacks + if not has_callbacks or "tpu" not in ctx.platforms: + ctx.cached_primitive_lowerings[key] = func_op else: func_op = lower_jaxpr_to_fun( ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0f9bb4c4a5f7..3e22b06ba02d 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2046,12 +2046,19 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, # TODO(b/228598865): inlined calls cannot have shardings set directly on the # inputs or outputs because they are lost during MLIR->HLO conversion. # using_sharding_annotation=False means we add an identity operation instead. + num_callbacks = len(mod_ctx.host_callbacks) func = mlir.lower_jaxpr_to_fun( mod_ctx, name, jaxpr, effects, ctx.name_stack, arg_shardings=arg_shardings, result_shardings=result_shardings, use_sharding_annotations=False, api_name=api_name, arg_layouts=in_layouts, result_layouts=out_layouts) - mod_ctx.cached_primitive_lowerings[key] = func + + # If this Jaxpr includes callbacks, we can't cache the lowering because + # on TPU every callback must have a globally unique channel, but the + # channel gets assigned during lowering. + has_callbacks = len(mod_ctx.host_callbacks) > num_callbacks + if not has_callbacks or "tpu" not in mod_ctx.platforms: + mod_ctx.cached_primitive_lowerings[key] = func return func diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index bf86c82b9615..1e0408b8d2ba 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -16,7 +16,7 @@ import textwrap import unittest -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax import lax from jax.experimental import pjit @@ -275,6 +275,21 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "[1.23 2.35 0. ]\n") + @parameterized.parameters([False, True]) + def test_debug_print_in_unrolled_loop(self, use_jit): + def body(i, _): + jax.debug.print("{}", i) + if use_jit: + body = jax.jit(body) + @jax.jit + def f(): + return jax.lax.fori_loop(0, 4, body, None, unroll=2) + with jtu.capture_stdout() as output: + f() + jax.effects_barrier() + actual = tuple(sorted(map(int, output().splitlines()))) + self.assertEqual(actual, tuple(range(4))) + @jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintTransformationTest(jtu.JaxTestCase): From 0cedff9f5d65c80f59e831e7a596493807a7ab27 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 25 Apr 2025 10:21:01 -0700 Subject: [PATCH 0830/1769] [XLA:Python] [JAX] Change JAX to use the _profiler module defined in xla/python. Remove uses of the ops module from JAX. JAX no longer uses the XlaBuilder APIs, instead using StableHLO. In passing, move make_tpu_client into xla_bridge so it can more easily import the _profiler module. Expose _jax.approx_top_k_reduction_output_size since we still need that function that we previously obtained via the _ops module. 
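As an illustration, a minimal sketch of the version-guarded call pattern this
change introduces (it mirrors the `jax/_src/lax/ann.py` hunk below;
`reduction_output_size` is a hypothetical wrapper name, and the guard value
331 matches the `_version` bump in `xla_client.py`):

```
# Sketch only: prefer the new _jax binding, falling back to the ops-module
# entry point that older jaxlib versions still provide.
from jax._src.lib import _jax, jaxlib_extension_version, xla_client as xc

def reduction_output_size(input_size, rank, k, recall_target,
                          aggregate_to_topk, input_size_override):
  if jaxlib_extension_version >= 331:
    # New direct binding; returns a 2-tuple whose first element is the
    # reduced dimension size.
    return _jax.approx_top_k_reduction_output_size(
        input_size, rank, k, recall_target, aggregate_to_topk,
        input_size_override)[0]
  # Legacy path via the ops module, removed from jaxlib in this change.
  return xc.ops.ApproxTopKReductionOutputSize(
      input_size, rank, k, recall_target, aggregate_to_topk,
      input_size_override)[0]
```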
PiperOrigin-RevId: 751470399 --- jax/_src/lax/ann.py | 13 ++++++++++--- jax/_src/lib/BUILD | 1 + jax/_src/lib/__init__.py | 2 ++ jax/_src/profiler.py | 27 ++++++++++++++------------- jax/_src/xla_bridge.py | 36 +++++++++++++++++++++++++++--------- jax/experimental/profiler.py | 4 ++-- jax/lib/xla_client.py | 3 ++- jax/lib/xla_extension.py | 5 +++-- jaxlib/BUILD | 4 ++-- jaxlib/_jax/__init__.pyi | 12 ++++++++++-- jaxlib/tools/build_wheel.py | 1 + jaxlib/xla.cc | 12 ++++++++---- jaxlib/xla_client.py | 25 +------------------------ jaxlib/xla_client.pyi | 2 -- tests/xla_bridge_test.py | 5 +++-- 15 files changed, 86 insertions(+), 66 deletions(-) diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index c9a68d84b024..bfcd45fba574 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -82,6 +82,8 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lax import lax +from jax._src.lib import _jax +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func @@ -231,9 +233,14 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, if aggregate_to_topk: dims[reduction_dimension] = k elif core.is_constant_shape((reduction_input_size, k)): - dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( - reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, - reduction_input_size_override)[0] + if jaxlib_extension_version >= 331: + dims[reduction_dimension] = _jax.approx_top_k_reduction_output_size( + reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, + reduction_input_size_override)[0] + else: + dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( # type: ignore # pytype: disable=module-attr + reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, + reduction_input_size_override)[0] else: raise NotImplementedError( "approx_top_k with aggregate_to_topk=False not yet implemented when " diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index dd8ab0557657..20e6f66f6a5a 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -63,5 +63,6 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", + "@xla//xla/python:_profiler", ]), ) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 7c75ac22cbe1..49aeff9c3763 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -93,6 +93,7 @@ def _parse_version(v: str) -> tuple[int, ...]: from jaxlib._jax import pmap_lib as pmap_lib # noqa: F401 from jaxlib._jax import pytree as pytree # noqa: F401 from jaxlib._jax import Device as Device # noqa: F401 + from jaxlib import _profiler as _profiler # noqa: F401 else: import jaxlib.xla_extension as _jax # type: ignore # pytype: disable=import-error # noqa: F401 from jaxlib.xla_extension import guard_lib as guard_lib # type: ignore # pytype: disable=import-error # noqa: F401 @@ -100,6 +101,7 @@ def _parse_version(v: str) -> tuple[int, ...]: from jaxlib.xla_extension import pmap_lib as pmap_lib # type: ignore # pytype: disable=import-error # noqa: F401 from jaxlib.xla_extension import pytree as pytree # type: ignore # pytype: disable=import-error # noqa: F401 from jaxlib.xla_extension import Device as Device # type: ignore # pytype: disable=import-error # noqa: F401 + from jaxlib.xla_extension import 
profiler as _profiler # type: ignore # pytype: disable=import-error # noqa: F401 import jaxlib.xla_client as xla_client # noqa: F401 diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index c787ea4c0223..71ca785b339f 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -32,14 +32,14 @@ traceback_util.register_exclusion(__file__) from jax._src import xla_bridge -from jax._src.lib import xla_client +from jax._src.lib import _profiler -_profiler_server: xla_client.profiler.ProfilerServer | None = None +_profiler_server: _profiler.ProfilerServer | None = None logger = logging.getLogger(__name__) -def start_server(port: int) -> xla_client.profiler.ProfilerServer: +def start_server(port: int) -> _profiler.ProfilerServer: """Starts the profiler server on port `port`. Using the "TensorFlow profiler" feature in `TensorBoard @@ -59,7 +59,7 @@ def start_server(port: int) -> xla_client.profiler.ProfilerServer: # is for start_trace), but I'm putting it here to be safe. xla_bridge.get_backend() - _profiler_server = xla_client.profiler.start_server(port) + _profiler_server = _profiler.start_server(port) return _profiler_server @@ -126,7 +126,7 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, # fail and no TPU operations will be included in the profile. xla_bridge.get_backend() - _profile_state.profile_session = xla_client.profiler.ProfilerSession() + _profile_state.profile_session = _profiler.ProfilerSession() _profile_state.create_perfetto_link = create_perfetto_link _profile_state.create_perfetto_trace = ( create_perfetto_trace or create_perfetto_link) @@ -219,7 +219,7 @@ def stop_and_get_fdo_profile() -> bytes | str: if _profile_state.profile_session is None: raise RuntimeError("No profile started") xspace = _profile_state.profile_session.stop() - fdo_profile = xla_client.profiler.get_fdo_profile(xspace) + fdo_profile = _profiler.get_fdo_profile(xspace) _profile_state.reset() return fdo_profile @@ -257,7 +257,7 @@ def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfett stop_trace() -class TraceAnnotation(xla_client.profiler.TraceMe): +class TraceAnnotation(_profiler.TraceMe): """Context manager that generates a trace event in the profiler. The trace event spans the duration of the code enclosed by the context. @@ -359,7 +359,8 @@ def device_memory_profile(backend: str | None = None) -> bytes: Returns: A byte string containing a binary `pprof`-format protocol buffer. 
""" - return xla_client.heap_profile(xla_bridge.get_backend(backend)) + client = xla_bridge.get_backend(backend) + return gzip.compress(client.heap_profile()) def save_device_memory_profile(filename, backend: str | None = None) -> None: @@ -389,7 +390,7 @@ def __init__(self, retries: int, percentile: int): self.collected_fdo: str | None = None self.called_times: int = 0 self.fdo_profiles: list[Any] = [] - self.current_session: xla_client.profiler.ProfilerSession | None = None + self.current_session: _profiler.ProfilerSession | None = None def consume_fdo_profile(self) -> str | None: if self.collected_fdo is not None: @@ -398,7 +399,7 @@ def consume_fdo_profile(self) -> str | None: if not self.is_enabled() or self.called_times != self.retries: return None - self.collected_fdo = xla_client.profiler.aggregate_profiled_instructions( + self.collected_fdo = _profiler.aggregate_profiled_instructions( self.fdo_profiles, self.percentile ) return self.collected_fdo @@ -422,17 +423,17 @@ def trace(cls, runner: PGLEProfiler | None): or not runner.is_enabled() or runner.is_fdo_consumed()): yield else: - options = xla_client.profiler.ProfileOptions() + options = _profiler.ProfileOptions() options.enable_hlo_proto = True options.raise_error_on_start_failure = True - runner.current_session = xla_client.profiler.ProfilerSession(options) + runner.current_session = _profiler.ProfilerSession(options) try: yield finally: xspace = runner.current_session.stop() runner.fdo_profiles.append( - xla_client.profiler.get_fdo_profile(xspace) + _profiler.get_fdo_profile(xspace) ) runner.current_session = None diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 644c395cb551..4c8373cc1105 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -44,6 +44,7 @@ from jax._src.lib import cuda_versions from jax._src.lib import xla_client from jax._src.lib import _jax +from jax._src.lib import _profiler logger = logging.getLogger(__name__) @@ -126,6 +127,24 @@ def _at_fork(): # Backends +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + +def make_tpu_client( + library_path: str | None = None, options: _NameValueMapping | None = None +): + """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" + if not _jax.pjrt_plugin_loaded('tpu'): + c_api = xla_client.load_pjrt_plugin_dynamically( + "tpu", library_path or "libtpu.so" + ) + _profiler.register_plugin_profiler(c_api) + assert _jax.pjrt_plugin_loaded('tpu') + if not _jax.pjrt_plugin_initialized('tpu'): + _jax.initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _jax.get_c_api_client('tpu', options) + def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: def _log_warning(): @@ -140,7 +159,7 @@ def _log_warning(): t.start() try: - client = xla_client.make_tpu_client( + client = make_tpu_client( get_tpu_library_path(), _options_from_jax_configs("tpu")) finally: @@ -437,12 +456,11 @@ def get_num_nodes_from_gpu_topology(topology: str) -> int: '" x x ' '".') -if hasattr(xla_client, "make_tpu_client"): - # TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, - # and then fail loudly on initialization failure. - register_backend_factory( - 'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300, - fail_quietly=True) +# TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, +# and then fail loudly on initialization failure. 
+register_backend_factory( + 'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300, + fail_quietly=True) def _get_pjrt_plugin_names_and_library_paths( @@ -660,7 +678,7 @@ def factory(): ) if library_path is not None: c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) - xla_client.profiler.register_plugin_profiler(c_api) + _profiler.register_plugin_profiler(c_api) else: assert c_api is not None xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api) @@ -1214,7 +1232,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): "JAX TPU support not installed; cannot generate TPU topology. See" " https://github.com/jax-ml/jax#installation") c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) - xla_client.profiler.register_plugin_profiler(c_api) + _profiler.register_plugin_profiler(c_api) assert xla_client.pjrt_plugin_loaded("tpu") if not xla_client.pjrt_plugin_initialized("tpu"): xla_client.initialize_pjrt_plugin("tpu") diff --git a/jax/experimental/profiler.py b/jax/experimental/profiler.py index 766d20472155..f22fba50092b 100644 --- a/jax/experimental/profiler.py +++ b/jax/experimental/profiler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.lib import xla_client +from jax._src.lib import _profiler def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: @@ -30,4 +30,4 @@ def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: Serialized [ProfiledInstructionsProto](https://github.com/openxla/xla/blob/main/third_party/tsl/tsl/profiler/protobuf/profiled_instructions.proto). """ - return xla_client.profiler.get_profiled_instructions_proto(tensorboard_dir) + return _profiler.get_profiled_instructions_proto(tensorboard_dir) diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 12f48b21f1c3..c81df076a6b2 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gzip as _gzip from jax._src.lib import xla_client as _xc _deprecations = { @@ -87,7 +88,7 @@ "jax.lib.xla_client.heap_profile was deprecated in JAX v0.6.0 and" " will be removed in JAX v0.7.0" ), - _xc.heap_profile, + lambda client: _gzip.compress(client.heap_profile()), ), "mlir_api_version": ( ( diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 452d004b2f6d..6b58a72783c9 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import jax._src.lib from jax._src.lib import _jax _deprecations = { @@ -71,7 +72,7 @@ "pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _jax.pmap_lib), "profiler": ( "jax.lib.xla_extension.profiler is deprecated.", - _jax.profiler, + jax._src.lib._profiler, ), "pytree": ( "jax.lib.xla_extension.pytree is deprecated.", @@ -127,7 +128,7 @@ jax_jit = _jax.jax_jit mlir = _jax.mlir pmap_lib = _jax.pmap_lib - profiler = _jax.profiler + profiler = jax._src.lib._profiler pytree = _jax.pytree else: diff --git a/jaxlib/BUILD b/jaxlib/BUILD index e4efac1462f9..6de18a5588e7 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -150,6 +150,7 @@ pywrap_library( "//jaxlib/mlir/_mlir_libs:_tpu_ext", "//jaxlib/mlir/_mlir_libs:_triton_ext", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", + "@xla//xla/python:_profiler", ], ) @@ -320,6 +321,7 @@ nanobind_pywrap_extension( "@xla//xla:util", "@xla//xla/backends/cpu/collectives:cpu_collectives", "@xla//xla/ffi:ffi_api", + "@xla//xla/hlo/builder/lib:approx_topk_shape", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:mlir_to_hlo", "@xla//xla/pjrt:pjrt_api", @@ -341,9 +343,7 @@ nanobind_pywrap_extension( "@xla//xla/python:logging", "@xla//xla/python:nb_absl_flat_hash_map", "@xla//xla/python:nb_absl_span", - "@xla//xla/python:ops", "@xla//xla/python:pprof_profile_builder", - "@xla//xla/python:profiler", "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:types", "@xla//xla/python:version", diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 19ffbb805404..9559d2862714 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -20,7 +20,7 @@ import enum import inspect import types from typing import Any, ClassVar, TypeVar, overload -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Mapping, Iterator, Sequence import numpy as np @@ -625,7 +625,7 @@ def get_mock_gpu_client( ) -> Client: ... def get_c_api_client( platform_name: str, - options: dict[str, str | int | list[int] | float | bool], + options: Mapping[str, str | int | list[int] | float | bool], distributed_client: DistributedRuntimeClient | None = ..., ) -> Client: ... def get_default_c_api_topology( @@ -1027,3 +1027,11 @@ class TransferServer: def connect(self, address: str) -> TransferConnection: ... def start_transfer_server(client: Client, address: str = "", transport_addresses: list[str] = [], max_num_parallel_copies: int = 0, transfer_size: int = 0) -> TransferServer: ... + +def approx_top_k_reduction_output_size( + input_size: int, + rank: int, + top_k: int, + recall_target: float, + aggregate_to_topk: bool | None = ..., + input_size_override: int | None = ...) -> tuple[int, int]: ... diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index b40ccd6f6870..ec306ff741f4 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -209,6 +209,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): f"{source_file_prefix}jaxlib/weakref_lru_cache.{pyext}", f"{source_file_prefix}jaxlib/weakref_lru_cache.pyi", f"{source_file_prefix}jaxlib/_jax.{pyext}", + f"{source_file_prefix}jaxlib/_profiler.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 8c70f4bc7646..219e220111af 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -106,6 +106,7 @@ limitations under the License. 
#include "jaxlib/sharding.h" #include "jaxlib/traceback.h" #include "jaxlib/xla_compiler.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_api.h" @@ -117,11 +118,9 @@ limitations under the License. #include "xla/python/logging.h" // IWYU pragma: keep #include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/ops.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pprof_profile_builder.h" -#include "xla/python/profiler.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/tsl/platform/status.h" #include "tsl/platform/platform.h" @@ -579,8 +578,6 @@ NB_MODULE(_jax, m) { jax::BuildConfigSubmodule(m); BuildIfrtProgramsSubmodule(m); - BuildProfilerSubmodule(m); - BuildOpsSubmodule(m); BuildPytreeSubmodule(m); jax::BuildGuardSubmodule(m); jax::BuildJaxjitSubmodule(m); @@ -950,6 +947,13 @@ NB_MODULE(_jax, m) { nb::arg("device_list")); m.attr("ifrt_version_number") = JAX_IFRT_VERSION_NUMBER; + + m.def("approx_top_k_reduction_output_size", + xla::ValueOrThrowWrapper(ApproxTopKReductionOutputSize), + nb::arg("input_size"), nb::arg("rank"), nb::arg("top_k"), + nb::arg("recall_target"), nb::arg("aggregate_to_topk") = true, + nb::arg("input_size_override") = -1); + } // NOLINT(readability/fn_size) } // namespace xla diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 9badaf355f0c..34631f328c29 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -20,7 +20,6 @@ from collections.abc import Mapping, Sequence import contextlib import enum -import gzip import inspect import logging import os @@ -45,12 +44,10 @@ # pylint: disable=invalid-sequence-index ifrt_programs = _xla.ifrt_programs -ops = _xla.ops -profiler = _xla.profiler # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 330 +_version = 331 # An internal increasing version number for protecting jaxlib code against # ifrt changes. @@ -156,21 +153,6 @@ def make_c_api_client( return _xla.get_c_api_client(plugin_name, options, distributed_client) -def make_tpu_client( - library_path: str | None = None, options: _NameValueMapping | None = None -): - """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" - if not pjrt_plugin_loaded('tpu'): - c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') - profiler.register_plugin_profiler(c_api) - assert pjrt_plugin_loaded('tpu') - if not pjrt_plugin_initialized('tpu'): - initialize_pjrt_plugin('tpu') - if options is None: - options = {} - return _xla.get_c_api_client('tpu', options) - - def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: """Generates the PjRt GPU plugin options. @@ -959,11 +941,6 @@ def tracebacks(enabled=True): Traceback.enabled = saved -def heap_profile(client: Client) -> bytes: - """Returns a gzipped pprof protocol buffer containing a heap profile.""" - return gzip.compress(client.heap_profile()) - - XlaRuntimeError = _xla.XlaRuntimeError # Perform one last garbage collection of deferred Python references. 
This is diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index 1a6751066e7c..2b78a31fea72 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -41,13 +41,11 @@ from jaxlib._jax import Layout as Layout from jaxlib._jax import LoadedExecutable as LoadedExecutable from jaxlib._jax import Memory as Memory from jaxlib._jax import NamedSharding as NamedSharding -from jaxlib._jax import ops as ops from jaxlib._jax import OpSharding as OpSharding from jaxlib._jax import PjRtLayout as PjRtLayout from jaxlib._jax import PmapSharding as PmapSharding from jaxlib._jax import PrimitiveType as PrimitiveType from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics -from jaxlib._jax import profiler as profiler from jaxlib._jax import Shape as Shape from jaxlib._jax import Sharding as Sharding from jaxlib._jax import SingleDeviceSharding as SingleDeviceSharding diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 97e8765cc096..3f1ac1787d2a 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -23,6 +23,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge as xb +from jax._src.lib import _profiler from jax._src.lib import xla_client as xc config.parse_flags_with_absl() @@ -136,7 +137,7 @@ def test_register_plugin(self): "name1:path1,name2:path2,name3" ) with mock.patch.object( - xc.profiler, "register_plugin_profiler", autospec=True + _profiler, "register_plugin_profiler", autospec=True ): xb.register_pjrt_plugin_factories_from_env() registration = xb._backend_factories["name1"] @@ -174,7 +175,7 @@ def test_register_plugin_with_config(self): ) with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): with mock.patch.object( - xc.profiler, "register_plugin_profiler", autospec=True + _profiler, "register_plugin_profiler", autospec=True ): xb.register_pjrt_plugin_factories_from_env() registration = xb._backend_factories["name1"] From 5da3fe60d1a930e3fb0638d7bb9d9d81a625f1ef Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 25 Apr 2025 10:31:28 -0700 Subject: [PATCH 0831/1769] Raise a better error in `jax.linearize` when vma's don't match between primals and tangents. Before: ``` ValueError: linearized function called on tangent values inconsistent with the original primal values: got tangent aval float64[1]{x} for primal aval float64[1] but expected float64[1] ``` After: ``` ValueError: linearized function called on tangent values inconsistent with the original primal values: Got tangent aval float64[1]{x} for primal aval float64[1] but expected float64[1]. 
This might be fixed by: * applying `jax.lax.pvary(..., ('x',))` to the primal value passed to `jax.linearize`; ``` Fixes https://github.com/jax-ml/jax/issues/28260 PiperOrigin-RevId: 751474552 --- jax/_src/api.py | 25 +++++++++++++++++++++---- tests/shard_map_test.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4b608fd64a89..2812f8a808f8 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1992,10 +1992,27 @@ def fun(*tangents): for primal_aval, tangent_aval in zip(primal_avals, tangent_avals): expected_tangent_aval = primal_aval.to_tangent_aval() if not core.typecompat(expected_tangent_aval, tangent_aval): - raise ValueError("linearized function called on tangent values inconsistent with " - "the original primal values: " - f"got tangent aval {tangent_aval} for primal aval {primal_aval} " - f"but expected {expected_tangent_aval}") + extra_msg = '' + if (isinstance(primal_aval, core.ShapedArray) and + isinstance(tangent_aval, core.ShapedArray) and + primal_aval.vma != tangent_aval.vma): + pvary_applications = [] + if left := tangent_aval.vma - primal_aval.vma: + pvary_applications.append( + f"applying `jax.lax.pvary(..., {tuple(left)})` to the primal" + " value passed to `jax.linearize`") + if left := primal_aval.vma - tangent_aval.vma: + pvary_applications.append( + f"applying `jax.lax.pvary(..., {tuple(left)})` to the tangent" + " value passed to the callable `f_jvp` returned by" + " `jax.linearize`") + extra_msg = " \nThis might be fixed by:\n" + "\n".join( + f" * {d};" for d in pvary_applications) + raise ValueError( + "linearized function called on tangent values inconsistent with " + "the original primal values:\n" + f"Got tangent aval {tangent_aval} for primal aval {primal_aval} " + f"but expected {expected_tangent_aval}.{extra_msg}") tangents_out = eval_jaxpr(jaxpr, consts, *tangents) tangents_out_ = iter(tangents_out) full_out = [pval.get_known() if pval.is_known() else next(tangents_out_) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ecad33b99d8b..380c9e774d05 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3028,6 +3028,34 @@ def f(x): jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x') )(jnp.ones(2,)) # doesn't crash + def test_shmap_linearize_and_linearize_transpose_error(self): + mesh = jtu.create_mesh((2,), ('x',)) + + def f(x): + return jnp.mean(x ** 2) + + def m(p, t): + out_p, fwd = jax.linearize(f, p) + out_t = fwd(t) + bwd = jax.linear_transpose(fwd, p) + return bwd(out_t) + + with self.assertRaisesRegex( + ValueError, + r"applying `jax.lax.pvary\(..., \('x',\)\)` to the primal value passed"): + shard_map(partial(m, jnp.array([1.])), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.ones((2,))) # doesn't crash + + def m2(p, t): + p = jax.lax.pvary(p, 'x') # fixes the issue + out_p, fwd = jax.linearize(f, p) + out_t = fwd(t) + bwd = jax.linear_transpose(fwd, p) + return bwd(out_t) + + shard_map(partial(m2, jnp.array([1.])), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.ones((2,))) # doesn't crash + class FunSpec(NamedTuple): name: str From beb0181ec4665a4bd1ad25cb27ca39812e081b6e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 25 Apr 2025 10:33:39 -0700 Subject: [PATCH 0832/1769] Allow decorator factory pattern for `jax.shard_map` i.e. `@jax.shard_map(in_specs=... out_specs=..., ...)`. Note the experimental API of shard_map will not support this pattern. 
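As a usage sketch (the mesh shape and partition specs here are illustrative,
assuming four available devices, and are not part of the change itself):

```
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((4,), ('i',))

# New: decorator-factory style, with no need for functools.partial.
@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(x):
  return 2 * x

# The direct-call style keeps working as before.
g = jax.shard_map(lambda x: 2 * x, mesh=mesh, in_specs=P('i'),
                  out_specs=P('i'))

y = f(jnp.arange(8.))
```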
PiperOrigin-RevId: 751475329 --- jax/_src/shard_map.py | 9 ++++++--- tests/shard_map_test.py | 15 +++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 9870991c66f1..fce7ed824cb0 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -72,7 +72,7 @@ AxisName = Hashable -def shard_map(f, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), +def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), in_specs: Specs | None = None, mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True): """Map a function over shards of data using a mesh of devices. @@ -120,8 +120,11 @@ def shard_map(f, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), arguments corresponding to those of ``f`` and produces output corresponding to that of ``f``. """ - return _shard_map(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, - axis_names=axis_names, check_vma=check_vma) + kwargs = dict(mesh=mesh, in_specs=in_specs, out_specs=out_specs, + axis_names=axis_names, check_vma=check_vma) + if f is None: + return lambda g: _shard_map(g, **kwargs) + return _shard_map(f, **kwargs) def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, in_specs: Specs, out_specs: Specs | Callable[[], Specs], diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 380c9e774d05..fbcca5b0e394 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1754,7 +1754,7 @@ def bar(x): def test_res_forwarding_optimization(self, jit, remat): mesh = jtu.create_mesh((4,), ('i',)) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x) if jit: @@ -1777,7 +1777,7 @@ def test_res_forwarding_optimization_complex(self, jit, remat): # like the above test, but a different function `f` mesh = jtu.create_mesh((4,), ('i',)) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x.sum()) + x, jax.lax.exp(x) if jit: @@ -1800,7 +1800,7 @@ def test_check_rep_failure_inside_rule(self, jit): mesh = jtu.create_mesh((4,), ('i',)) def loss(w, x): - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P()) def f(x): return jax.lax.psum(((w * x) ** 2).sum(), 'i') return f(x) @@ -1816,8 +1816,8 @@ def test_conv_general_dilated(self): dot = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) - @partial(shard_map, mesh=mesh, in_specs=(P(None, 'i'), P('i', None)), - out_specs=P(None, None)) + @shard_map(mesh=mesh, in_specs=(P(None, 'i'), P('i', None)), + out_specs=P(None, None)) def f(x, y): return lax.psum(dot(x, y), 'i') @@ -2456,8 +2456,7 @@ def test_grad_remat(self): args = [jnp.arange(6.).reshape(3, 2), jnp.arange(6.).reshape(3, 2, 1)] @partial(jax.remat, policy=lambda *_, **__: True) - @partial(shard_map, mesh=mesh, in_specs=(P('j'), P('i')), - out_specs=P('i', 'j')) + @shard_map(mesh=mesh, in_specs=(P('j'), P('i')), out_specs=P('i', 'j')) def f(x, y): return jnp.dot(x, y) jax.grad(lambda x, y: f(x, y).sum())(*args) @@ -2466,7 +2465,7 @@ def test_vmap_grad_shmap_spmd_axis_name_residuals(self): # https://github.com/jax-ml/jax/pull/21032 mesh = jtu.create_mesh((4, 2), ('i', 'j')) - @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P('j')) + @shard_map(mesh=mesh, in_specs=P('j'), 
out_specs=P('j')) def f(x): return jnp.sin(x) From 5d53967b3a4619b21bf8ed2ef4f5b5bbd790be59 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Fri, 25 Apr 2025 10:56:31 -0700 Subject: [PATCH 0833/1769] change build target PiperOrigin-RevId: 751484337 --- tests/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 3745c7cf7882..749819a5bc08 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -378,7 +378,7 @@ jax_multiplatform_test( enable_backends = ["gpu"], enable_configs = [ "gpu_h100", - "gpu_a100_shardy", + "gpu_h100_shardy", ], tags = [ "config-cuda-only", @@ -1324,7 +1324,7 @@ jax_multiplatform_test( "tpu_v2", "tpu_v3_x4", "tpu_v4_x4", - "gpu_a100_shardy", + "gpu_h100_shardy", "tpu_v3_x4_shardy", ], ) From fbb0cbbc6bfe30941076a365a427e12efbb5253b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 25 Apr 2025 11:11:51 -0700 Subject: [PATCH 0834/1769] roll back #28210 because it broke an internal test Reverts 8fab28883cf2070053a6ae0825ed2bbfe2654fb5 PiperOrigin-RevId: 751490028 --- jax/_src/interpreters/ad.py | 10 +++--- jax/_src/interpreters/partial_eval.py | 36 ++++++++------------ jax/_src/lax/control_flow/loops.py | 27 ++------------- tests/api_test.py | 4 +-- tests/core_test.py | 6 ++-- tests/lax_control_flow_test.py | 49 +-------------------------- 6 files changed, 27 insertions(+), 105 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 435e9027f5b3..0f11e0d72f12 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -1191,12 +1191,10 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_ arg_names=new_arg_names, result_paths=new_result_paths, ) - constvars = jaxpr.jaxpr.constvars - new_effects = pe._renumber_effects( - (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars), - jaxpr.jaxpr.effects) - new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns, - new_effects, new_debug_info) + new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, + new_invars, new_outvars, jaxpr.jaxpr.eqns, + jaxpr.jaxpr.effects, + new_debug_info) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int], diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 7203af95fb7b..1317a584f6c3 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1559,22 +1559,18 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] 
) -> ClosedJaxpr: assert len(closed_jaxpr.in_avals) == len(to_move) - constvars, invars = closed_jaxpr.jaxpr.constvars, closed_jaxpr.jaxpr.invars - new_invars = _move_to_front(invars, to_move) - new_effs = _renumber_effects( - (*constvars, *new_invars), (*constvars, *invars), closed_jaxpr.jaxpr.effects) - new_jaxpr = Jaxpr(constvars, new_invars, closed_jaxpr.jaxpr.outvars, - closed_jaxpr.jaxpr.eqns, new_effs, - closed_jaxpr.jaxpr.debug_info) + new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) + id_map = {id(v): i for i, v in enumerate(new_invars)} + idx_map = {i: id_map[id(v)] for i, v in enumerate(closed_jaxpr.jaxpr.invars)} + new_effs = {e.replace(input_index=idx_map[e.input_index]) + if isinstance(e, effects.JaxprInputEffect) else e + for e in closed_jaxpr.jaxpr.effects} + new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, + closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, + new_effs, closed_jaxpr.jaxpr.debug_info) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr -def _renumber_effects(new_vars, old_vars, effs): - newvar_idxs = {id(v): i for i, v in enumerate(new_vars)} - old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)} - return {e.replace(input_index=old_to_new[e.input_index]) - if isinstance(e, effects.JaxprInputEffect) else e for e in effs} - def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: return ([elt for elt, move in zip(lst, to_move) if move] + [elt for elt, move in zip(lst, to_move) if not move]) @@ -1594,6 +1590,7 @@ def _move_outvars_to_back(jaxpr, to_move): return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars)) + class DynamicJaxprTracer(core.Tracer): __slots__ = ['aval', '_debug_info'] @@ -1674,19 +1671,16 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"\n Equation: {eqn}\n" "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") - eqn_invar = eqn.invars[eff.input_index] - if eqn_invar in mut_arrays: + invar = eqn.invars[eff.input_index] + if invar in mut_arrays: continue - if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel: - # TODO(mattjj): ask for forgiveness - dbg = type('Fake', (), {'resolve_result_paths': lambda _: None})() + if (input_index := all_vars.get(invar, sentinel)) is sentinel: raise ValueError( f"`JaxprInputEffect` {eff} does not have " - f"corresponding jaxpr input: {eqn_invar=}." + f"corresponding input: {invar}." 
f"\n Equation: {eqn}\n" - f"\n Effects: {eqn.effects}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore + f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 58ef64add37a..d1220ba3fdb3 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -312,9 +312,6 @@ def _create_jaxpr(init): init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked = rest num_carry = len(init_flat) - num_xs = len(x_avals) - num_ys = len(jaxpr.out_avals) - num_carry - del init_flat _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) @@ -330,42 +327,22 @@ def _create_jaxpr(init): unroll = max(length, 1) if unroll else 1 if unroll < 1: raise ValueError("`unroll` must be a `bool` or a positive `int`.") - if attrs_tracked: in_state = _get_states(attrs_tracked) in_flat = [*in_state, *in_flat] num_carry += len(in_state) - - # If the body forwards an input carry to an output carry, that input is - # read-only and can be moved to be a const. Doing so can lead to efficiency - # wins, e.g. if the scan is inside a cond with a batched predicate. - carry_fwd, _ = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) - move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)] - if any(move_to_const): - jaxpr = pe.prune_closed_jaxpr_outputs( - jaxpr, [not m for m in move_to_const] + [True] * num_ys) - jaxpr = pe.move_binders_to_front( - jaxpr, [False] * len(consts) + move_to_const + [False] * num_xs) - in_flat, new_consts = partition_list(move_to_const + [False] * num_xs, in_flat) - consts = [*new_consts, *consts] - num_carry -= len(new_consts) - out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, linear=(False,) * (len(consts) + len(in_flat)), - unroll=unroll, _split_transpose=_split_transpose) - - if any(move_to_const): - out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) - + unroll=unroll, + _split_transpose=_split_transpose) if attrs_tracked: num_ext = (len(out) - len(in_state) - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked)) out_state, out, out_append = split_list(out, [len(in_state), num_ext]) out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) _set_states(attrs_tracked, out_attrs) - return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): diff --git a/tests/api_test.py b/tests/api_test.py index 7a1218d3790a..daea53d0fb38 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6853,13 +6853,13 @@ def body(c, _): self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=0) + expected_num_eqns=1) # 1 b/c scan doesn't have fwding rule used_outputs[7] = expected_used_inputs[7] = True used_outputs[6] = expected_used_inputs[6] = True self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=0) + expected_num_eqns=1) # If we use the value at index 3 only, some of the hidden sequence must be # kept but the rest pruned. 
diff --git a/tests/core_test.py b/tests/core_test.py index 646705ebf281..e39487035751 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -364,15 +364,15 @@ def g_vmap(x): def test_dropvar_avals(self): def f(x): def body(c, _): - x1, x2 = c - return (2 * x1, 2 * x2), None + return c, None (x1, x2), _ = jax.lax.scan(body, (x, x), None, length=1) return [x2] aval = core.ShapedArray((), jnp.dtype('int32')) pval = pe.PartialVal.unknown(aval) jaxpr, _, _ = pe.trace_to_jaxpr_nounits( - lu.wrap_init(f, debug_info=debug_info("test", f, (0,), {})), + lu.wrap_init(f, + debug_info=debug_info("test", f, (0,), {})), [pval], False) dropvar, b = jaxpr.eqns[0].outvars self.assertEqual(dropvar.aval, aval) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 422ef769e392..78026968d2cd 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3122,7 +3122,7 @@ def body(c): return x + y jax.linearize(f, 1., 2.) # don't crash - def test_while_readonly_carry_optimization(self): + def test_readonly_carry_optimization(self): # https://github.com/google/flax/issues/4700 def foo(w, x, c_max): def while_cond(val): @@ -3204,53 +3204,6 @@ def body_fun(c): outs = jax.lax.while_loop(cond_fun, body_fun, (5., 0., 3.14)) self.assertAllClose(outs, (0., 1., 5.)) - def test_scan_readonly_carry_optimization(self): - # https://github.com/google/flax/issues/4709 - def f(x, y): - def g(_, y): - y, _ = jax.lax.scan(lambda y, _: (y, None), y, None, length=1) - return y - return jax.lax.cond(x < 0, g, g, x, y) - xs = jnp.arange(3.) - y = 3. - jax.vmap(f, (0, None), None)(xs, y) # don't crash - - @parameterized.parameters(itertools.product(range(3), repeat=4)) - @jtu.run_on_devices("cpu") - def test_scan_constification_correctness( - self, - seed, - num_body_consts, - num_inplace_fwds, - num_noninplace_fwds): - - num_fwds = num_inplace_fwds + num_noninplace_fwds - num_carry = num_fwds + 4 - num_xs = 2 - num_ys = 3 - - rng = np.random.RandomState(seed) - perm = rng.permutation(num_carry) - iperm = np.argsort(perm) - - body_consts = [rng.randn(3) for _ in range(num_body_consts)] - init_vals = list(rng.uniform(size=num_carry)) - - def body_fun(c, _): - c = [c[i] for i in iperm] - inplace_fwds, noninplace_fwds, dont_fwd = split_list( - c, [num_inplace_fwds, num_noninplace_fwds]) - dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) - for x in dont_fwd] - new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] - new_c = [new_c_perm[i] for i in perm] - return new_c, [0 for _ in range(num_ys)] - - xs = [jnp.arange(2.) for _ in range(num_xs)] - outs = jax.lax.scan(body_fun, init_vals, xs)[0] - outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] - self.assertAllClose(outs, outs_ref, check_dtypes=False) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 434efefe516494dfbde2c0a88d907e69075ae7f0 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 25 Apr 2025 11:19:18 -0700 Subject: [PATCH 0835/1769] [Mosaic] Support bf16 select and cmp <= TPUv4. 
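The canonicalizer change below legalizes bf16 `select` and `cmpf` on
generations before TPU v5 by upcasting to f32 (per-op minimum versions are
listed in the new `bf16_upcast_min_supported_versions` table). A hedged,
JAX-level illustration of the kind of kernel code this affects (the kernel
and shapes are made up, not taken from the change):

```
# Illustrative only: a Pallas TPU kernel mixing a bf16 compare with a
# select; jnp.where lowers to cmp + select on the bf16 vectors.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def relu_kernel(x_ref, o_ref):
  x = x_ref[...]                       # bf16 values
  o_ref[...] = jnp.where(x > 0, x, 0)  # cmp + select

x = jnp.ones((8, 128), jnp.bfloat16)
y = pl.pallas_call(
    relu_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
```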
PiperOrigin-RevId: 751493322
---
 .../tpu/transforms/canonicalize_mosaic.cc | 77 +++++++++++++------
 1 file changed, 53 insertions(+), 24 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 247f47431745..66c227d62a74 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -15,7 +15,7 @@ limitations under the License.

 #include
 #include
-#include
+#include
 #include
 #include
 #include
 #include
@@ -67,6 +67,9 @@ struct CanonicalizeContext {
   int hardware_generation;
 };

+bool need_elementwise_canonicalization(const CanonicalizeContext &ctx,
+                                       Operation &op);
+
 LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx,
                               tpu::MatmulOp op) {
   ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
@@ -351,21 +354,23 @@ LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx,
     }
   }
   if (should_rewrite_op) {
-    auto result_ty = dyn_cast<VectorType>(op.getResult(0).getType());
-    if (!result_ty) {
+    if (!res_ty) {
       op.emitOpError("Not implemented: Unexpected result type");
       return failure();
     }
-    auto result_element_type = result_ty.getElementType();
-    if (!result_element_type.isF32() && !result_element_type.isBF16()) {
-      op.emitOpError("Not implemented: Unexpected result element type");
-      return failure();
-    }
-    // Do the new op in f32, then truncate to the original element type.
+    // Do the new op in f32, then truncate to the original element type if
+    // needed. For example, result of arith::CmpF is i1 and doesn't need to be
+    // truncated.
+    bool should_truncate = !isa<arith::CmpFOp>(op);
+    auto new_res_ty =
+        VectorType::get(shape, should_truncate ? builder.getF32Type()
+                                               : res_ty.getElementType());
     auto new_op = builder.create(op.getLoc(), op.getName().getIdentifier(),
-                                 new_operands, target_f32_ty);
-    new_op = builder.create(op.getLoc(), res_ty,
-                            new_op->getResult(0));
+                                 new_operands, new_res_ty, op.getAttrs());
+    if (should_truncate) {
+      new_op = builder.create(op.getLoc(), res_ty,
+                              new_op->getResult(0));
+    }
     op.replaceAllUsesWith(new_op);
     op.erase();
   }
@@ -547,6 +552,9 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
       op.getLoc(), cond, op.getTrueValue(), op.getFalseValue());
   op.replaceAllUsesWith(new_op.getResult());
   op.erase();
+  if (need_elementwise_canonicalization(ctx, *new_op.getOperation())) {
+    return canonicalize_elementwise(ctx, *new_op.getOperation());
+  }
   return success();
 }

@@ -688,18 +696,39 @@ const llvm::StringMap &rules() {
   return *rules;
 }

-bool need_elementwise_canonicalization(CanonicalizeContext ctx, Operation &op) {
-  if (isa<arith::DivFOp>(op)) {
-    auto vec_ty = dyn_cast<VectorType>(op.getOperand(0).getType());
-    if (vec_ty && vec_ty.getElementType().isBF16() &&
-        ctx.hardware_generation >= 4) {
-      return false;
-    }
-    return true;
-  }
-  return isa(op);
+const llvm::StringMap<int> &bf16_upcast_min_supported_versions() {
+  constexpr int kAlwaysUpcast = std::numeric_limits<int>::max();
+  static const auto m = new llvm::StringMap<int>{
+      {arith::DivFOp::getOperationName(), 4},
+      {arith::SelectOp::getOperationName(), 5},
+      {arith::CmpFOp::getOperationName(), 5},
+      {arith::MulFOp::getOperationName(), kAlwaysUpcast},
+      {arith::AddFOp::getOperationName(), kAlwaysUpcast},
+      {arith::SubFOp::getOperationName(), kAlwaysUpcast},
+      {arith::MaximumFOp::getOperationName(), kAlwaysUpcast},
+      {arith::MinimumFOp::getOperationName(), kAlwaysUpcast},
+      {math::PowFOp::getOperationName(), kAlwaysUpcast},
+      {math::TanhOp::getOperationName(), kAlwaysUpcast},
+      {math::ExpOp::getOperationName(), kAlwaysUpcast},
+      {math::LogOp::getOperationName(), kAlwaysUpcast},
+  };
+  return *m;
+}
+
+bool need_elementwise_canonicalization(const CanonicalizeContext &ctx,
+                                       Operation &op) {
+  // Only rewrite when the hardware generation is below the minimum supported
+  // version.
+  auto it =
+      bf16_upcast_min_supported_versions().find(op.getName().getStringRef());
+  if (it == bf16_upcast_min_supported_versions().end() ||
+      ctx.hardware_generation >= it->second) {
+    return false;
+  }
+  return llvm::any_of(op.getOperands(), [](Value operand) {
+    auto vty = dyn_cast<VectorType>(operand.getType());
+    return vty && vty.getElementType().isBF16();
+  });
 }

 class MosaicCanonicalizer {

From 58a650b4474da947f7df9a46fb84cc990fd48931 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Fri, 25 Apr 2025 11:14:03 -0700
Subject: [PATCH 0836/1769] Ensure __jax_array__ works properly with JIT disabled.

The roadmap for __jax_array__ involves no longer triggering it during
abstractification, so we need to make sure that the extensibility tests pass
even when JIT is disabled.
---
 jax/_src/numpy/lax_numpy.py           | 53 +++++++++++++-------
 jax/_src/numpy/reductions.py          | 55 +++++++++++++--------------
 jax/_src/numpy/tensor_contractions.py |  8 ++--
 jax/_src/numpy/ufuncs.py              | 36 +++++++++---------
 jax/_src/numpy/util.py                |  1 +
 tests/array_extensibility_test.py     |  5 +++
 6 files changed, 81 insertions(+), 77 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 7421111b23a4..34417adeceb1 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -1352,7 +1352,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
          [11, 8],
          [12, 9]]], dtype=int32)
   """
-  util.check_arraylike("rot90", m)
+  m = util.ensure_arraylike("rot90", m)
   if np.ndim(m) < 2:
     raise ValueError("rot90 requires its first argument to have ndim at least "
                      f"two, but got first argument of shape {np.shape(m)}, "
@@ -1588,6 +1588,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array:
   [[ 71.57 -68.2 ]
    [-36.87  33.69]]
   """
+  z = util.ensure_arraylike('angle', z)
   re = ufuncs.real(z)
   im = ufuncs.imag(z)
   dtype = _dtype(re)
@@ -2073,7 +2074,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
   >>> x.ravel()
   Array([1, 2, 3, 4, 5, 6], dtype=int32)
   """
-  util.check_arraylike("ravel", a)
+  a = util.ensure_arraylike("ravel", a)
   if order == "K":
     raise NotImplementedError("Ravel not implemented for order='K'.")
   return reshape(a, (np.size(a),), order)
@@ -2138,8 +2139,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
   """
   assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
   dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
-  util.check_arraylike("ravel_multi_index", *multi_index)
-  multi_index_arr = [asarray(i) for i in multi_index]
+  multi_index_arr = list(util.ensure_arraylike_tuple("ravel_multi_index", multi_index))
   for index in multi_index_arr:
     if mode == 'raise':
       core.concrete_or_error(array, index,
@@ -2470,7 +2470,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
   >>> a.transpose(0, 3, 2, 1).shape
   (2, 5, 4, 3)
   """
-  util.check_arraylike("swapaxes", a)
+  a = util.ensure_arraylike("swapaxes", a)
   perm = np.arange(np.ndim(a))
   perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
   return lax.transpose(a, list(perm))
@@ -2629,7 +2629,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
             left: ArrayLike | str | None = None,
             right: ArrayLike | str | None = None,
             period: ArrayLike | None = None) -> Array:
-  util.check_arraylike("interp", x, xp, fp)
+  x, xp, fp = util.ensure_arraylike("interp", x, xp, fp)
   if np.shape(xp) != np.shape(fp) or np.ndim(xp) != 1:
     raise ValueError("xp and fp must be one-dimensional arrays of equal size")
   x_arr, xp_arr = util.promote_dtypes_inexact(x, xp)
@@ -3091,6 +3091,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]:
   .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
   """
+  args = util.ensure_arraylike_tuple("broadcast_arrays", args)
   return util._broadcast_arrays(*args)
@@ -3553,7 +3554,7 @@ def fix(x: ArrayLike, out: None = None) -> Array:
        [-0., 0., -3.],
        [-1., 1., 2.]], dtype=float32)
   """
-  util.check_arraylike("fix", x)
+  x = util.ensure_arraylike("fix", x)
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
   zero = _lax_const(x, 0)
@@ -3771,7 +3772,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
     return tuple(zeros(calculated_size, int) for dim in arr.shape)
   flat_indices = reductions.cumsum(
       bincount(reductions.cumsum(mask), length=calculated_size))
-  strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(dtypes.int_)
+  strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(flat_indices.dtype)
   out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape))
   if fill_value is not None:
     fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,)
@@ -6816,11 +6817,11 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
   # TODO(phawkins): remove this annotation after fixing jnp types.
   dx_array: Array
   if x is None:
-    util.check_arraylike('trapezoid', y)
+    y = util.ensure_arraylike('trapezoid', y)
     y_arr, = util.promote_dtypes_inexact(y)
     dx_array = asarray(dx)
   else:
-    util.check_arraylike('trapezoid', y, x)
+    y, x = util.ensure_arraylike('trapezoid', y, x)
     y_arr, x_arr = util.promote_dtypes_inexact(y, x)
     if x_arr.ndim == 1:
       dx_array = diff(x_arr)
@@ -6941,7 +6942,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
          [[5, 0],
           [7, 8]]], dtype=int32)
   """
-  util.check_arraylike("tril", m)
+  m = util.ensure_arraylike("tril", m)
   m_shape = np.shape(m)
   if len(m_shape) < 2:
     raise ValueError("Argument to jax.numpy.tril must be at least 2D")
@@ -7008,7 +7009,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
          [[5, 6],
           [0, 8]]], dtype=int32)
   """
-  util.check_arraylike("triu", m)
+  m = util.ensure_arraylike("triu", m)
   m_shape = np.shape(m)
   if len(m_shape) < 2:
     raise ValueError("Argument to jax.numpy.triu must be at least 2D")
@@ -7065,7 +7066,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int
   >>> jnp.trace(x, offset=1, axis1=1, axis2=2)
   Array([2, 6], dtype=int32)
   """
-  util.check_arraylike("trace", a)
+  a = util.ensure_arraylike("trace", a)
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
@@ -7582,7 +7583,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
   >>> jnp.diagonal(x, offset=-1)
   Array([4, 8], dtype=int32)
   """
-  util.check_arraylike("diagonal", a)
+  a = util.ensure_arraylike("diagonal", a)
   if np.ndim(a) < 2:
     raise ValueError("diagonal requires an array of at least two dimensions.")
@@ -7668,11 +7669,11 @@ def diag(v: ArrayLike, k: int = 0) -> Array:
   >>> jnp.diag(x)
   Array([1, 5, 9], dtype=int32)
   """
+  v = util.ensure_arraylike("diag", v)
   return _diag(v,
ArrayLike | str | None = None, right: ArrayLike | str | None = None, period: ArrayLike | None = None) -> Array: - util.check_arraylike("interp", x, xp, fp) + x, xp, fp = util.ensure_arraylike("interp", x, xp, fp) if np.shape(xp) != np.shape(fp) or np.ndim(xp) != 1: raise ValueError("xp and fp must be one-dimensional arrays of equal size") x_arr, xp_arr = util.promote_dtypes_inexact(x, xp) @@ -3091,6 +3091,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html """ + args = util.ensure_arraylike_tuple("broadcast_arrays", args) return util._broadcast_arrays(*args) @@ -3553,7 +3554,7 @@ def fix(x: ArrayLike, out: None = None) -> Array: [-0., 0., -3.], [-1., 1., 2.]], dtype=float32) """ - util.check_arraylike("fix", x) + x = util.ensure_arraylike("fix", x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") zero = _lax_const(x, 0) @@ -3771,7 +3772,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, return tuple(zeros(calculated_size, int) for dim in arr.shape) flat_indices = reductions.cumsum( bincount(reductions.cumsum(mask), length=calculated_size)) - strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(dtypes.int_) + strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(flat_indices.dtype) out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape)) if fill_value is not None: fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,) @@ -6816,11 +6817,11 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, # TODO(phawkins): remove this annotation after fixing jnp types. dx_array: Array if x is None: - util.check_arraylike('trapezoid', y) + y = util.ensure_arraylike('trapezoid', y) y_arr, = util.promote_dtypes_inexact(y) dx_array = asarray(dx) else: - util.check_arraylike('trapezoid', y, x) + y, x = util.ensure_arraylike('trapezoid', y, x) y_arr, x_arr = util.promote_dtypes_inexact(y, x) if x_arr.ndim == 1: dx_array = diff(x_arr) @@ -6941,7 +6942,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: [[5, 0], [7, 8]]], dtype=int32) """ - util.check_arraylike("tril", m) + m = util.ensure_arraylike("tril", m) m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.tril must be at least 2D") @@ -7008,7 +7009,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: [[5, 6], [0, 8]]], dtype=int32) """ - util.check_arraylike("triu", m) + m = util.ensure_arraylike("triu", m) m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.triu must be at least 2D") @@ -7065,7 +7066,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int >>> jnp.trace(x, offset=1, axis1=1, axis2=2) Array([2, 6], dtype=int32) """ - util.check_arraylike("trace", a) + a = util.ensure_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") @@ -7582,7 +7583,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, >>> jnp.diagonal(x, offset=-1) Array([4, 8], dtype=int32) """ - util.check_arraylike("diagonal", a) + a = util.ensure_arraylike("diagonal", a) if np.ndim(a) < 2: raise ValueError("diagonal requires an array of at least two dimensions.") @@ -7668,11 +7669,11 @@ def diag(v: ArrayLike, k: int = 0) -> Array: >>> jnp.diag(x) Array([1, 5, 9], dtype=int32) """ + v = util.ensure_arraylike("diag", v) return _diag(v, 
operator.index(k)) @partial(jit, static_argnames=('k',)) -def _diag(v, k): - util.check_arraylike("diag", v) +def _diag(v: Array, k: int): v_shape = np.shape(v) if len(v_shape) == 1: zero = lambda x: lax.full_like(x, shape=(), fill_value=0) @@ -8655,12 +8656,12 @@ def nanargmax( """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.") + a = util.ensure_arraylike("nanargmax", a) return _nanargmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims')) -def _nanargmax(a, axis: int | None = None, keepdims: bool = False): - util.check_arraylike("nanargmax", a) +def _nanargmax(a: Array, axis: int | None = None, keepdims: bool = False): if not issubdtype(_dtype(a), np.inexact): return argmax(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) @@ -8716,12 +8717,12 @@ def nanargmin( """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.") + a = util.ensure_arraylike("nanargmin", a) return _nanargmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims')) -def _nanargmin(a, axis: int | None = None, keepdims : bool = False): - util.check_arraylike("nanargmin", a) +def _nanargmin(a: Array, axis: int | None = None, keepdims : bool = False): if not issubdtype(_dtype(a), np.inexact): return argmin(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) @@ -8864,7 +8865,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: >>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3) """ - util.check_arraylike("rollaxis", a) + a = util.ensure_arraylike("rollaxis", a) start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()") a_ndim = np.ndim(a) axis = _canonicalize_axis(axis, a_ndim) @@ -8941,7 +8942,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar raise TypeError('Expected an input array of integer or boolean data type') if bitorder not in ['little', 'big']: raise ValueError("'order' must be either 'little' or 'big'") - arr = lax.gt(arr, _lax_const(a, 0)).astype('uint8') + arr = lax.gt(arr, _lax_const(arr, 0)).astype('uint8') bits = arange(8, dtype='uint8') if bitorder == 'big': bits = bits[::-1] @@ -9101,7 +9102,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: >>> jnp.gcd(x1, x2) Array([ 6, 3, 12], dtype=int32) """ - util.check_arraylike("gcd", x1, x2) + x1, x2 = util.ensure_arraylike("gcd", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), np.integer): raise ValueError("Arguments to jax.numpy.gcd must be integers.") @@ -9148,7 +9149,7 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: >>> jnp.lcm(x1, x2) Array([12, 36, 12], dtype=int32) """ - util.check_arraylike("lcm", x1, x2) + x1, x2 = util.ensure_arraylike("lcm", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2) if not issubdtype(_dtype(x1), np.integer): @@ -9667,9 +9668,9 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', Array([0, 2, 5, 1, 1], dtype=int32) """ if sorter is None: - util.check_arraylike("searchsorted", a, v) + a, v = util.ensure_arraylike("searchsorted", a, v) else: - util.check_arraylike("searchsorted", a, v, sorter) + a, v, sorter = util.ensure_arraylike("searchsorted", a, v, sorter) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. 
" "Expected one of ['left', 'right'].") diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 77d8662fd9a9..d2ae80925597 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -30,7 +30,7 @@ from jax._src import deprecations from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, check_arraylike, ensure_arraylike, + _broadcast_to, ensure_arraylike, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg @@ -54,8 +54,7 @@ def _isscalar(element: Any) -> bool: def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array: # simplified version of jnp.moveaxis() for local use. - check_arraylike("moveaxis", a) - a = lax_internal.asarray(a) + a = ensure_arraylike("moveaxis", a) source = _canonicalize_axis(source, np.ndim(a)) destination = _canonicalize_axis(destination, np.ndim(a)) perm = [i for i in range(np.ndim(a)) if i != source] @@ -83,8 +82,7 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: def check_where(name: str, where: ArrayLike | None) -> Array | None: if where is None: return where - check_arraylike(name, where) - where_arr = lax_internal.asarray(where) + where_arr = ensure_arraylike(name, where) if where_arr.dtype != bool: # Deprecation added 2024-12-05 deprecations.warn( @@ -113,7 +111,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, # exists, passing along all its arguments. if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") - check_arraylike(name, a) + a = ensure_arraylike(name, a) where_ = check_where(name, where_) dtypes.check_user_dtype_supported(dtype, name) axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") @@ -122,7 +120,6 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, raise ValueError(f"reduction operation {name} does not have an identity, so to use a " f"where mask one has to specify 'initial'") - a = a if isinstance(a, Array) else lax_internal.asarray(a) a = preproc(a) if preproc else a pos_dims, dims = _reduction_dims(a, axis) @@ -743,7 +740,7 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") - check_arraylike("logsumexp", a) + a = ensure_arraylike("logsumexp", a) where = check_where("logsumexp", where) a_arr, = promote_dtypes_inexact(a) pos_dims, dims = _reduction_dims(a_arr, axis) @@ -763,7 +760,7 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") - check_arraylike("logsumexp2", a) + a = ensure_arraylike("logsumexp2", a) where = check_where("logsumexp2", where) ln2 = float(np.log(2)) if initial is not None: @@ -873,7 +870,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, upcast_f16_for_computation: bool = True, where: ArrayLike | None = None) -> Array: - check_arraylike("mean", a) + a = ensure_arraylike("mean", a) where = check_where("mean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not 
supported.") @@ -972,7 +969,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: if weights is None: # Treat all weights as 1 - check_arraylike("average", a) + a = ensure_arraylike("average", a) a, = promote_dtypes_inexact(a) avg = mean(a, axis=axis, keepdims=keepdims) if axis is None: @@ -982,7 +979,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, else: weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index] else: - check_arraylike("average", a, weights) + a, weights = ensure_arraylike("average", a, weights) a, weights = promote_dtypes_inexact(a, weights) a_shape = np.shape(a) @@ -1104,14 +1101,14 @@ def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, correction = ddof elif not isinstance(ddof, int) or ddof != 0: raise ValueError("ddof and correction can't be provided simultaneously.") + a = ensure_arraylike("var", a) return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, +def _var(a: Array, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - check_arraylike("var", a) where = check_where("var", where) dtypes.check_user_dtype_supported(dtype, "var") if out is not None: @@ -1242,14 +1239,14 @@ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, correction = ddof elif not isinstance(ddof, int) or ddof != 0: raise ValueError("ddof and correction can't be provided simultaneously.") + a = ensure_arraylike("std", a) return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, +def _std(a: Array, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - check_arraylike("std", a) where = check_where("std", where) dtypes.check_user_dtype_supported(dtype, "std") if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): @@ -1298,12 +1295,12 @@ def ptp(a: ArrayLike, axis: Axis = None, out: None = None, [7], [6]], dtype=int32) """ + a = ensure_arraylike("ptp", a) return _ptp(a, _ensure_optional_axes(axis), out, keepdims) @partial(api.jit, static_argnames=('axis', 'keepdims')) -def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, +def _ptp(a: Array, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: - check_arraylike("ptp", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.") x = amax(a, axis=axis, keepdims=keepdims) @@ -1350,7 +1347,7 @@ def count_nonzero(a: ArrayLike, axis: Axis = None, [1], [3]], dtype=int32) """ - check_arraylike("count_nonzero", a) + a = ensure_arraylike("count_nonzero", a) return sum(lax.ne(a, _lax_const(a, 0)), axis=axis, dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims) @@ -1359,7 +1356,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], init_val: ArrayLike, nan_if_all_nan: bool, axis: Axis = 
None, keepdims: bool = False, where: ArrayLike | None = None, **kwargs) -> Array: - check_arraylike(name, a) + a = ensure_arraylike(name, a) where = check_where(name, where) if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs) @@ -1783,7 +1780,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out >>> jnp.nanmean(x, axis=0, keepdims=True, where=where) Array([[nan, nan, nan, nan]], dtype=float32) """ - check_arraylike("nanmean", a) + a = ensure_arraylike("nanmean", a) where = check_where("nanmean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") @@ -1877,7 +1874,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: [0. ], [4. ]], dtype=float32) """ - check_arraylike("nanvar", a) + a = ensure_arraylike("nanvar", a) where = check_where("nanvar", where) dtypes.check_user_dtype_supported(dtype, "nanvar") if out is not None: @@ -1973,7 +1970,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: >>> jnp.nanstd(x, axis=0, keepdims=True, where=where) Array([[0.5, 0.5, 0. , 0. ]], dtype=float32) """ - check_arraylike("nanstd", a) + a = ensure_arraylike("nanstd", a) where = check_where("nanstd", where) dtypes.check_user_dtype_supported(dtype, "nanstd") if out is not None: @@ -2375,7 +2372,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No >>> jnp.quantile(x, q, method='nearest') Array([2., 4., 7.], dtype=float32) """ - check_arraylike("quantile", a, q) + a, q = ensure_arraylike("quantile", a, q) if overwrite_input or out is not None: raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") @@ -2433,7 +2430,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = >>> jnp.nanquantile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanquantile", a, q) + a, q = ensure_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") @@ -2616,7 +2613,7 @@ def percentile(a: ArrayLike, q: ArrayLike, >>> jnp.percentile(x, q, method='nearest') Array([1., 3., 4.], dtype=float32) """ - check_arraylike("percentile", a, q) + a, q = ensure_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) if not isinstance(interpolation, DeprecatedArg): deprecations.warn( @@ -2676,7 +2673,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, >>> jnp.nanpercentile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanpercentile", a, q) + a, q = ensure_arraylike("nanpercentile", a, q) q, = promote_dtypes_inexact(q) q = q / 100 if not isinstance(interpolation, DeprecatedArg): @@ -2736,7 +2733,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, [4. ], [4.5]], dtype=float32) """ - check_arraylike("median", a) + a = ensure_arraylike("median", a) return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') @@ -2793,7 +2790,7 @@ def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, [5. ], [3. 
]], dtype=float32) """ - check_arraylike("nanmedian", a) + a = ensure_arraylike("nanmedian", a) return nanquantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') diff --git a/jax/_src/numpy/tensor_contractions.py b/jax/_src/numpy/tensor_contractions.py index 850eb90cf1d2..990f17c2b23e 100644 --- a/jax/_src/numpy/tensor_contractions.py +++ b/jax/_src/numpy/tensor_contractions.py @@ -284,7 +284,7 @@ def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 50, 122], [ 38, 92]], dtype=int32) """ - util.check_arraylike("matvec", x1, x2) + x1, x2 = util.ensure_arraylike("matvec", x1, x2) return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2) @@ -326,7 +326,7 @@ def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 40, 46], [ 94, 109]], dtype=int32) """ - util.check_arraylike("matvec", x1, x2) + x1, x2 = util.ensure_arraylike("matvec", x1, x2) return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2) @@ -372,7 +372,7 @@ def vdot( >>> jnp.dot(x, y) Array(0.+14.j, dtype=complex64) """ - util.check_arraylike("vdot", a, b) + a, b = util.ensure_arraylike("vdot", a, b) if dtypes.issubdtype(dtypes.dtype(a, canonicalize=True), np.complexfloating): a = ufuncs.conj(a) return dot(jax.numpy.ravel(a), jax.numpy.ravel(b), precision=precision, @@ -638,6 +638,6 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") - util.check_arraylike("outer", a, b) + a, b = util.ensure_arraylike("outer", a, b) a, b = util.promote_dtypes(a, b) return jax.numpy.ravel(a)[:, None] * jax.numpy.ravel(b)[None, :] diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 07758a87750c..77b1220214ed 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -119,7 +119,7 @@ def fabs(x: ArrayLike, /) -> Array: >>> jnp.fabs(x2) Array([1., 0.], dtype=float32) """ - check_arraylike('fabs', x) + x = ensure_arraylike('fabs', x) if dtypes.issubdtype(dtypes.dtype(x), np.complexfloating): raise TypeError("ufunc 'fabs' does not support complex dtypes") return lax.abs(*promote_args_inexact('fabs', x)) @@ -365,9 +365,9 @@ def floor(x: ArrayLike, /) -> Array: [ 0., -1., 0.], [-5., 2., 1.]], dtype=float32) """ - check_arraylike('floor', x) + x = ensure_arraylike('floor', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): - return lax.asarray(x) + return x return lax.floor(*promote_args_inexact('floor', x)) @@ -404,7 +404,7 @@ def ceil(x: ArrayLike, /) -> Array: [-0., 4., 1.], [ 5., 4., -1.]], dtype=float32) """ - check_arraylike('ceil', x) + x = ensure_arraylike('ceil', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x) return lax.ceil(*promote_args_inexact('ceil', x)) @@ -2343,7 +2343,7 @@ def absolute(x: ArrayLike, /) -> Array: >>> jnp.absolute(x3) Array([17., 5., 5.], dtype=float32) """ - check_arraylike('absolute', x) + x = ensure_arraylike('absolute', x) dt = dtypes.dtype(x) return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) @@ -2386,7 +2386,7 @@ def rint(x: ArrayLike, /) -> Array: >>> jnp.rint(x3) Array([-2.+4.j, 4.-0.j], dtype=complex64) """ - check_arraylike('rint', x) + x = ensure_arraylike('rint', x) dtype = dtypes.dtype(x) if dtype == bool or dtypes.issubdtype(dtype, np.integer): return lax.convert_element_type(x, dtypes.float_) @@ -2995,7 +2995,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> jnp.ldexp(m, 
e) Array([ 2., 3., 5., 11.], dtype=float32) """ - check_arraylike("ldexp", x1, x2) + x1, x2 = ensure_arraylike("ldexp", x1, x2) x1_dtype = dtypes.dtype(x1) x2_dtype = dtypes.dtype(x2) if (dtypes.issubdtype(x1_dtype, np.complexfloating) @@ -3049,7 +3049,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: >>> m * 2 ** e Array([1., 2., 3., 4., 5.], dtype=float32) """ - check_arraylike("frexp", x) + x = ensure_arraylike("frexp", x) x, = promote_dtypes_inexact(x) if dtypes.issubdtype(x.dtype, np.complexfloating): raise TypeError("frexp does not support complex-valued inputs") @@ -3176,7 +3176,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 1., -1., 4.], [ 0., 2., -2.]], dtype=float32) """ - check_arraylike("fmod", x1, x2) + x1, x2 = ensure_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) out = lax.rem(*promote_args_numeric("fmod", x1, x2)) @@ -3229,7 +3229,7 @@ def square(x: ArrayLike, /) -> Array: >>> jnp.square(x2) Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64) """ - check_arraylike("square", x) + x = ensure_arraylike("square", x) x, = promote_dtypes_numeric(x) return lax.square(x) @@ -3343,7 +3343,7 @@ def conjugate(x: ArrayLike, /) -> Array: >>> jnp.conjugate(x) Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64) """ - check_arraylike("conjugate", x) + x = ensure_arraylike("conjugate", x) return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) @@ -3381,7 +3381,7 @@ def imag(val: ArrayLike, /) -> Array: >>> jnp.imag(x) Array([ 3., -1., 0.], dtype=float32) """ - check_arraylike("imag", val) + val = ensure_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) @@ -3413,7 +3413,7 @@ def real(val: ArrayLike, /) -> Array: >>> jnp.real(x) Array([ 3., 4., -0.], dtype=float32) """ - check_arraylike("real", val) + val = ensure_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) @@ -3443,7 +3443,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: >>> jnp.modf(x) (Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32)) """ - check_arraylike("modf", x) + x = ensure_arraylike("modf", x) x, = promote_dtypes_inexact(x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") @@ -3482,7 +3482,7 @@ def isfinite(x: ArrayLike, /) -> Array: >>> jnp.isfinite(3-4j) Array(True, dtype=bool, weak_type=True) """ - check_arraylike("isfinite", x) + x = ensure_arraylike("isfinite", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): return lax.is_finite(x) @@ -3696,7 +3696,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> jnp.heaviside(-3, x2) Array([0., 0., 0.], dtype=float32) """ - check_arraylike("heaviside", x1, x2) + x1, x2 = ensure_arraylike("heaviside", x1, x2) x1, x2 = promote_dtypes_inexact(x1, x2) zero = _lax_const(x1, 0) return _where(lax.lt(x1, zero), zero, @@ -3781,7 +3781,7 @@ def reciprocal(x: ArrayLike, /) -> Array: >>> jnp.reciprocal(x) Array([1. 
, 0.2 , 0.25], dtype=float32) """ - check_arraylike("reciprocal", x) + x = ensure_arraylike("reciprocal", x) x, = promote_dtypes_inexact(x) return lax.integer_pow(x, -1) @@ -3834,7 +3834,7 @@ def sinc(x: ArrayLike, /) -> Array: (d/dx)^4 f(0.0) = 19.48 (d/dx)^5 f(0.0) = 0.00 """ - check_arraylike("sinc", x) + x = ensure_arraylike("sinc", x) x, = promote_dtypes_inexact(x) eq_zero = lax.eq(x, _lax_const(x, 0)) pi_x = lax.mul(_lax_const(x, np.pi), x) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 367d06065842..6302e1a9b54c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -296,6 +296,7 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None # materialize the broadcast forms of scalar arguments. @api.jit def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array: + condition, x, y = ensure_arraylike("where", condition, x, y) if x is None or y is None: raise ValueError("Either both or neither of the x and y arguments should " "be provided to jax.numpy.where, got {} and {}." diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 45f12ac06473..91e2a1d9cf6d 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -573,5 +573,10 @@ def wrap_array(arg): self.assertAllClose(actual, expected, atol=0, rtol=0) +@jtu.with_config(jax_disable_jit=True) +class JaxArrayTestsNoJit(JaxArrayTests): + pass + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 66b6eaa76696fc5188ecabfe6b365d3e182e58a3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 25 Apr 2025 12:04:40 -0700 Subject: [PATCH 0837/1769] Replace auto tracking with manual axis names tracking internally in shard_map. This will make it easier to implement nested shard_map where you enter into manual mode one axis at a time per shard_map. 
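For context, a rough usage sketch (illustrative only, not part of this
change; assumes a 2x2 device mesh): `shard_map`'s `axis_names` argument
names the mesh axes that become manual inside the body, and internally we
now thread exactly those names through as `manual_axes` rather than
tracking the complementary `auto` set. Note that partially-manual
shard_map still has to run under jit (the eager impl raises
NotImplementedError for it).

  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax.sharding import Mesh, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

  def f(x):
    # Only 'x' is manual in here; 'y' is left to the compiler (auto).
    return jax.lax.psum(x, 'x')

  g = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P(),
                    axis_names={'x'})
  y = jax.jit(g)(jnp.arange(8.))

The eventual goal of this refactor is nesting, i.e. entering manual mode
one axis per shard_map: an outer shard_map over {'x'} whose body calls an
inner shard_map over {'y'}.
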
PiperOrigin-RevId: 751509665 --- jax/_src/checkify.py | 8 +- jax/_src/interpreters/pxla.py | 2 - jax/_src/shard_map.py | 275 ++++++++++++++++++---------------- 3 files changed, 146 insertions(+), 139 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index a645f6c71249..c8caffeb7877 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -968,15 +968,15 @@ def shard_map_error_check( new_in_names = (*([{}] * num_error_vals), *in_names) new_vals_in = [*err_vals, *vals_in] in_avals = list(map(core.get_aval, new_vals_in)) - auto = kwargs.get('auto') + manual_axes = kwargs.get('manual_axes') check_vma = kwargs.get('check_vma') for i, v in enumerate(in_avals): if not (sharder := core.shard_aval_handlers.get(type(v))): raise ValueError(f'Unsupported aval type: {type(v)}') - in_avals[i] = sharder(mesh, auto, check_vma, new_in_names[i], v) + in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_names[i], v) - with (jshmap._extend_axis_env(mesh, auto), - mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, auto)), + with (jshmap._extend_axis_env(mesh, manual_axes), + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f337c996d3e3..6c9c6d2aad68 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3005,8 +3005,6 @@ def from_hlo(name: str, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, compiler_options_kvs, pgle_profiler) - orig_out_shardings = out_shardings - if auto_spmd_lowering: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index fce7ed824cb0..60e1991e2d1b 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -153,10 +153,10 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, axis_names = frozenset(axis_names) if not axis_names: axis_names = frozenset(mesh.axis_names) - auto = frozenset(mesh.axis_names) - frozenset(axis_names) - if not auto.issubset(mesh.axis_names): - raise ValueError(f"shard_map requires auto={auto} to be a subset of " - f"mesh.axis_names={mesh.axis_names}") + if not axis_names.issubset(mesh.axis_names): + raise ValueError( + f"jax.shard_map requires axis_names={axis_names} to be a subset of " + f"mesh.axis_names={mesh.axis_names}") # TODO(yashkatariya): Maybe we don't have to be this strict? 
if mesh._any_axis_auto_or_manual and in_specs is None: @@ -166,9 +166,9 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, f" {mesh} has `Auto` axes.\n") if in_specs is not None: - _check_specs(SpecErrorType.input, in_specs, auto) + _check_specs(SpecErrorType.input, in_specs, axis_names) if not callable(out_specs): - _check_specs(SpecErrorType.out, out_specs, auto) + _check_specs(SpecErrorType.out, out_specs, axis_names) @util.wraps(f) @traceback_util.api_boundary @@ -202,7 +202,7 @@ def wrapped(*args): def out_names_thunk(): if callable(out_specs): out_specs_ = out_specs() - _check_specs(SpecErrorType.out, out_specs_, auto) + _check_specs(SpecErrorType.out, out_specs_, axis_names) else: out_specs_ = out_specs dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) @@ -219,7 +219,8 @@ def out_names_thunk(): try: out_flat = shard_map_p.bind( fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_vma=check_vma, auto=auto) + out_names_thunk=out_names_thunk, check_vma=check_vma, + manual_axes=axis_names) except _SpecError as e: fails, = e.args if not callable(out_specs): @@ -268,38 +269,39 @@ def _manual_spec(manual_axes, spec: P) -> P: SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) -def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None: +def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None: if error_type == SpecErrorType.input and specs is None: raise TypeError( "shard_map in_specs argument must be a pytree of " "`jax.sharding.PartitionSpec` instances, but it was None.\n" "Instead of `in_specs=None`, did you mean `in_specs=P()`, " "where `P = jax.sharding.PartitionSpec`?") + def check_spec(p): if not isinstance(p, PartitionSpec): return False for names in p: - if not isinstance(names, tuple): - names = (names,) + names = (names,) if not isinstance(names, tuple) else names for name in names: - if name in auto: + if name is not None and name not in manual_axes: return False return True - if all(check_spec(p) for p in tree_leaves(specs)): return + + if all(check_spec(p) for p in tree_leaves(specs)): + return prefix = 'in' if error_type == SpecErrorType.input else 'out' msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " for key, x in generate_key_paths(specs) if not isinstance(x, P)] if not msgs: for key, p in generate_key_paths(specs): for names in p: - if not isinstance(names, tuple): - names = (names,) + names = (names,) if not isinstance(names, tuple) else names for name in names: - if name in auto: + if name is not None and name not in manual_axes: msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") raise ValueError( - f"shard_map {prefix}_specs argument cannot refer to an axis " - f"marked auto ({auto}), but:\n\n" + f"shard_map {prefix}_specs argument must refer to an axis " + f"marked as manual ({manual_axes}), but:\n\n" + '\n\n'.join(msgs) + '\n\n' f"Check the {prefix}_specs values passed to shard_map.") raise TypeError( @@ -527,13 +529,13 @@ def get_bind_params(self, params): # Staging @util.cache(max_size=256, trace_context_in_key=True) -def _as_manual_mesh(mesh, auto: frozenset): - manual_axes = tuple(set(mesh.axis_names) - auto) +def _as_manual_mesh(mesh, manual_axes: frozenset): + not_manual = set(mesh.axis_names) - manual_axes cur_mesh = get_abstract_mesh() if cur_mesh.empty: cur_mesh = mesh explicit_axes, auto_axes = set(), set() # type: ignore - for a in auto: + for a in not_manual: if cur_mesh._name_to_type[a] == 
AxisType.Auto: auto_axes.add(a) else: @@ -553,9 +555,9 @@ def _as_manual_mesh(mesh, auto: frozenset): axis_types=tuple(new_axis_types)) -def _extend_axis_env(mesh, auto): +def _extend_axis_env(mesh, manual_axes): return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() - if k not in auto]) + if k in manual_axes]) def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, @@ -563,22 +565,22 @@ def _shard_map_staging( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_vma: bool, - auto: frozenset, + manual_axes: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: source_info = source_info_util.current() to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) in_tracers = map(to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, mesh, auto, check_vma), in_names, + in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_names, in_avals) - manual_mesh = _as_manual_mesh(mesh, auto) - with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + manual_mesh = _as_manual_mesh(mesh, manual_axes) + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_vma: out_vma = [v.aval.vma for v in jaxpr.outvars] - _check_reps(mesh, auto, out_names_thunk(), out_vma) + _check_reps(mesh, out_names_thunk(), out_vma) out_avals = map(_check_shapedarray, out_avals_) out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] @@ -587,12 +589,12 @@ def _shard_map_staging( constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_vma=check_vma, auto=auto) + check_vma=check_vma, manual_axes=manual_axes) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, effs, source_info) @@ -606,11 +608,11 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval -def _shard_aval(mesh: Mesh, auto, check_vma, names: AxisNames, +def _shard_aval(mesh: Mesh, manual_axes, check_vma, names: AxisNames, aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.shard_aval_handlers: - return core.shard_aval_handlers[type(aval)](mesh, auto, check_vma, names, - aval) + return core.shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma, + names, aval) raise NotImplementedError(f"Unsupported aval type: {type(aval)}") def _unshard_aval(mesh: Mesh, check_vma, names: AxisNames, @@ -620,12 +622,13 @@ def _unshard_aval(mesh: Mesh, check_vma, names: AxisNames, else: raise NotImplementedError(f"Unsupported aval type: {type(aval)}") -def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_vma, names: AxisNames, - aval: core.AbstractValue) -> core.AbstractValue: +def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, 
check_vma, + names: AxisNames, aval: core.AbstractValue + ) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) - manual_mesh = _as_manual_mesh(mesh, auto) + manual_mesh = _as_manual_mesh(mesh, manual_axes) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) vma = (frozenset({n for ns in names.values() for n in ns}) if check_vma else frozenset()) @@ -669,19 +672,19 @@ def _unshard_shaped_array(mesh: Mesh, check_vma, names: AxisNames, RepType = Any def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_vma, auto): + check_vma, manual_axes): # TODO(mattjj,parkers): check auto for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): if not core.typecompat(v.aval, _shard_aval( - mesh, auto, check_vma, in_name, x.aval)): + mesh, manual_axes, check_vma, in_name, x.aval)): raise core.JaxprTypeError("shard_map argument avals not compatible with " "jaxpr binder avals and in_names") - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): core.check_jaxpr(jaxpr) if check_vma: - out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] + out_rep = [_vma_to_rep(mesh, v.aval.vma) for v in jaxpr.outvars] for rep, dst in zip(out_rep, out_names): - if not _valid_repeats(mesh, auto, rep, dst): + if not _valid_repeats(mesh, rep, dst): raise core.JaxprTypeError( "shard_map can't prove output is sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] @@ -692,14 +695,14 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, core.custom_typechecks[shard_map_p] = _shard_map_typecheck -def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: - return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) +def _valid_repeats(mesh: Mesh, rep: RepType, dst: AxisNames) -> bool: + return rep is None or (set(_unmentioned(mesh, dst))).issubset(rep) # Lowering def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in + ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in ) -> sharding_impls.SdyArraySharding: axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) @@ -707,24 +710,23 @@ def _shardy_shard_map_sharding( ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) - if auto: + if len(manual_axes) < len(mesh.axis_names): for dim_sharding in sdy_sharding.dimension_shardings: dim_sharding.is_open = True return sdy_sharding def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_vma): + ctx, in_nodes, jaxpr, mesh, in_names, out_names, manual_axes, check_vma): + axis_ctx = ctx.module_context.axis_context in_avals_ = [v.aval for v in jaxpr.invars] - if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): # Nested `ManualComputationOp`s cannot refer to axes that are already # manual. So figure out what axes are free thus far. 
- free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes - shardy_manual_axes = free_axes - auto + shardy_manual_axes = frozenset(mesh.axis_names) - axis_ctx.manual_axes else: - shardy_manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) - auto) + shardy_manual_axes = manual_axes + new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) # The order of manual axes should match the order of mesh.axis_names to avoid @@ -733,17 +735,17 @@ def _shard_map_lowering_shardy( if a in shardy_manual_axes] if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): out_nodes, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, dim_var_values=ctx.dim_var_values) return out_nodes in_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), + partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), in_names, ctx.avals_in)).build() out_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), + partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), out_names, ctx.avals_out)).build() output_types = map(mlir.aval_to_ir_type, ctx.avals_out) manual_computation_op = sdy.ManualComputationOp( @@ -752,7 +754,7 @@ def _shard_map_lowering_shardy( ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) block = ir.Block.create_at_start( manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) - with (ir.InsertionPoint(block), _extend_axis_env(mesh, auto), + with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma)): out_nodes_, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, @@ -763,27 +765,26 @@ def _shard_map_lowering_shardy( def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_vma, auto): + check_vma, manual_axes): if config.use_shardy_partitioner.value: return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_vma) + ctx, in_nodes, jaxpr, mesh, in_names, out_names, manual_axes, check_vma) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] - in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, - in_avals_, in_nodes) - manual_axes = frozenset(mesh.axis_names) - auto + in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_names, + ctx.avals_in, in_avals_, in_nodes) new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): out_nodes_, tokens_out = mlir.call_lowering( "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) ctx.set_tokens_out(tokens_out) - return map(partial(_xla_unshard, ctx, mesh, auto), out_names, out_avals_, - 
ctx.avals_out, out_nodes_) + return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_names, + out_avals_, ctx.avals_out, out_nodes_) mlir.register_lowering(shard_map_p, _shard_map_lowering) def _make_scoped_manual_sharding(ctx, mesh, axes): @@ -795,9 +796,9 @@ def _make_scoped_manual_sharding(ctx, mesh, axes): return NamedSharding( mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore -def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, +def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: return x axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) @@ -805,25 +806,27 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() - unspecified = set(range(aval_in.ndim)) if auto else set() + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, unspecified_dims=unspecified) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) + manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh) return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) -def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, +def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: return x axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_out.dtype, dtypes.extended): ns = sharding_impls.physical_sharding(aval_out, ns) aval_out = core.physical_aval(aval_out) - unspecified = set(range(aval_out.ndim)) if auto else set() + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): aval_in = core.physical_aval(aval_in) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) + manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh) sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, @@ -859,12 +862,13 @@ def _vma_to_spec(mesh, vma): def _names_to_vma(names): return {n for ns in names.values() for n in ns} -def _vma_to_rep(mesh, auto, vma): - return frozenset((set(mesh.axis_names) - auto) - vma) +def _vma_to_rep(mesh, vma): + return set(mesh.axis_names) - vma def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_vma, auto): - if auto: raise NotImplementedError + check_vma, manual_axes): + if len(manual_axes) < len(mesh.axis_names): + raise NotImplementedError del prim if isinstance(mesh, AbstractMesh): mesh = get_mesh_from_args(args, mesh) @@ -872,11 +876,12 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, args = map(partial(_unmatch_spec, mesh, check_vma, 
context_mesh=cur_mesh), in_names, args) in_vma = map(_names_to_vma, in_names) - outs, out_vma = _run_shmap(fun, mesh, auto, args, in_vma, check_vma, cur_mesh) + outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma, + cur_mesh) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_vma: - _check_reps(mesh, auto, out_names_thunk(), out_vma) + _check_reps(mesh, out_names_thunk(), out_vma) src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) else: src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) @@ -885,11 +890,11 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, outs) core.EvalTrace.process_shard_map = _shard_map_impl -def _run_shmap(f, mesh, auto, args, vmas, check_vma, context_mesh): - trace = ShardMapTrace(mesh, auto, check_vma, context_mesh) +def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh): + trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh) in_tracers = map(partial(ShardMapTracer, trace), vmas, args) - manual_mesh = _as_manual_mesh(mesh, auto) - with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), + manual_mesh = _as_manual_mesh(mesh, manual_axes) + with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): ans = f.call_wrapped(*in_tracers) outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) @@ -928,9 +933,9 @@ def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] class _SpecError(Exception): pass -def _check_reps(mesh, auto, names, vmas): - reps = [_vma_to_rep(mesh, auto, v) for v in vmas] - fail = [r if not _valid_repeats(mesh, auto, r, n) else no_fail +def _check_reps(mesh, names, vmas): + reps = [_vma_to_rep(mesh, v) for v in vmas] + fail = [r if not _valid_repeats(mesh, r, n) else no_fail for n, r in zip(names, reps)] if any(f is not no_fail for f in fail): raise _RepError(fail) @@ -961,17 +966,17 @@ def _maybe_check_special(outs): raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None class ShardMapTrace(core.Trace): - __slots__ = ("mesh", "auto", "check", "context_mesh") + __slots__ = ("mesh", "manual_axes", "check", "context_mesh") mesh: Mesh - auto: frozenset[AxisName] + manual_axes: frozenset[AxisName] check: bool context_mesh: AbstractMesh - def __init__(self, mesh, auto, check, context_mesh): + def __init__(self, mesh, manual_axes, check, context_mesh): super().__init__() self.mesh = mesh - self.auto = auto + self.manual_axes = manual_axes self.check = check self.context_mesh = context_mesh @@ -1030,8 +1035,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. 
del prim, jvp, symbolic_zeros in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) - out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, - self.check, self.context_mesh) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check, self.context_mesh) return map(partial(ShardMapTracer, self), out_vma, out_vals) def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, @@ -1043,8 +1048,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) - out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, - self.check, self.context_mesh) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check, self.context_mesh) return map(partial(ShardMapTracer, self), out_vma, out_vals) @@ -1065,7 +1070,7 @@ def aval(self): aval = core.get_aval(self.val) out = core.mapped_aval(self._trace.mesh.size, 0, aval) new_sharding = NamedSharding( - _as_manual_mesh(self._trace.mesh, self._trace.auto), + _as_manual_mesh(self._trace.mesh, self._trace.manual_axes), out.sharding.spec) # pytype: disable=attribute-error vma = self.vma if config._check_vma.value else frozenset() return out.update(sharding=new_sharding, vma=vma) @@ -1136,7 +1141,7 @@ def _shard_map_batch( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_vma: bool, - auto: frozenset) -> Sequence[batching.BatchTracer]: + manual_axes: frozenset) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): raise NotImplementedError @@ -1161,7 +1166,7 @@ def new_out_names_thunk(): new_params = dict(mesh=mesh, in_names=new_in_names, out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) with core.set_current_trace(trace.parent_trace): out_vals = prim.bind(fun, *in_vals, **new_params) make_tracer = partial(batching.BatchTracer, trace, @@ -1184,7 +1189,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names): # Autodiff def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_vma, auto): + out_names_thunk, check_vma, manual_axes): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] @@ -1199,7 +1204,7 @@ def new_out_names_thunk(): return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) primal_out, tangent_out = tree_unflatten(out_tree(), result) @@ -1210,18 +1215,18 @@ def new_out_names_thunk(): def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_vma, auto): + out_names_thunk, check_vma, manual_axes): tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - 
in_avals_sharded = map(partial(_shard_aval, mesh, auto, check_vma), + in_avals_sharded = map(partial(_shard_aval, mesh, manual_axes, check_vma), unk_in_names, in_avals) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) f = _promote_scalar_residuals(f) f_known, aux = pe.partial_eval_wrapper_nounits2( f, (*in_knowns,), (*in_avals_sharded,)) - all_names = _all_newly_manual_mesh_names(mesh, auto) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) @as_hashable_function(closure=out_names_thunk) def known_out_names(): @@ -1236,7 +1241,7 @@ def known_out_names(): known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux() @@ -1267,7 +1272,7 @@ def known_out_names(): out_avals_sharded = [v.aval for v in jaxpr.outvars] unk_params = dict(mesh=mesh, in_names=unk_in_names, out_names=unk_out_names, jaxpr=jaxpr, - check_vma=check_vma, auto=auto) + check_vma=check_vma, manual_axes=manual_axes) out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_names, out_avals_sharded) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) @@ -1282,12 +1287,12 @@ def known_out_names(): def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_vma, auto): + out_names_thunk, check_vma, manual_axes): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - all_names = _all_newly_manual_mesh_names(mesh, auto) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) @as_hashable_function(closure=linearize_outs_thunk) def fwd_out_names_thunk(): @@ -1303,7 +1308,8 @@ def fwd_out_names_thunk(): return (*res_names, *out_names) fwd_params = dict( mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_vma=check_vma, auto=auto) + out_names_thunk=fwd_out_names_thunk, check_vma=check_vma, + manual_axes=manual_axes) all_fwd_results = shard_map_p.bind_with_trace( trace.parent_trace, (f_primal, *primals), fwd_params) res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() @@ -1313,8 +1319,8 @@ def fwd_out_names_thunk(): residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto)), + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() @@ -1341,7 +1347,7 @@ def tangent_out_names_thunk(): return tangent_out_names tangent_params = dict( mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_vma=check_vma, auto=auto) + check_vma=check_vma, manual_axes=manual_axes) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here def f_tangent(*args): @@ -1393,26 +1399,27 @@ def fun(*res_and_args): def _unmentioned2(mesh: Mesh, names: AxisNames, - auto: 
frozenset[AxisName]) -> list[AxisName]: + manual_axes: frozenset[AxisName]) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over - # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} | auto - return [n for n in _all_mesh_names_except_spmd(mesh, auto) + # more chips than required in the transpose-no-check-vma case. + name_set = {n for ns in names.values() for n in ns} + return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes) if n not in name_set] def _shard_map_transpose(out_cts, *args, jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_vma, auto): + check_vma, manual_axes): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ - ad.Zero(_shard_aval(mesh, auto, check_vma, ns, x.aval)) + ad.Zero(_shard_aval(mesh, manual_axes, check_vma, ns, x.aval)) if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, manual_axes)))) for ns, x in zip(out_names, out_cts) ] args = tuple(x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, auto, check_vma, ns, x.aval)) + ad.UndefinedPrimal( + _shard_aval(mesh, manual_axes, check_vma, ns, x.aval)) for ns, x in zip(in_names, args)) all_args, in_tree = tree_flatten((out_cts, args)) @@ -1429,7 +1436,7 @@ def fun_trans_callable(out_cts, args): _, in_ct_names = partition_list(in_undef, in_names) in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, ns, x.aval)) if type(x) is ad.Zero else x if check_vma - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, manual_axes))) for ns, x in zip(in_ct_names, in_cts)] res_zeros = [ad_util.zero_from_primal(r) for r in res] return merge_lists(in_undef, res_zeros, in_cts) @@ -1449,7 +1456,7 @@ def new_out_names_thunk(): out_flat = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) except (FloatingPointError, ZeroDivisionError) as e: print("Invalid nan value encountered in the backward pass of a shard_map " "function. 
Calling the de-optimized backward pass.") @@ -1460,7 +1467,7 @@ def new_out_names_thunk(): _ = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: @@ -1476,8 +1483,8 @@ def _partial_eval_jaxpr_custom_rule( ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], list[core.Var]]: jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] - check_vma, auto = eqn.params['check_vma'], eqn.params['auto'] - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes'] + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) num_out_primals = len(jaxpr_known.outvars) - num_res @@ -1487,8 +1494,8 @@ def _partial_eval_jaxpr_custom_rule( out_fwd = [idx_map.get(id(v)) for v in res_vars] which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] mesh = eqn.params['mesh'] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto)), + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) @@ -1503,7 +1510,7 @@ def _partial_eval_jaxpr_custom_rule( for var, w in zip(jaxpr_staged.invars[:num_res], which): if w: rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore - if check_vma else {0: _all_newly_manual_mesh_names(mesh, auto)}) + if check_vma else {0: _all_newly_manual_mesh_names(mesh, manual_axes)}) residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval))) staged_in_res_names.append(rn) if check_vma: @@ -1512,7 +1519,8 @@ def _partial_eval_jaxpr_custom_rule( for var, o in zip(res_vars, out_fwd) if o is None ] else: - out_res_names_known = [{0: _all_newly_manual_mesh_names(mesh, auto)}] * sum(which) + out_res_names_known = [ + {0: _all_newly_manual_mesh_names(mesh, manual_axes)}] * sum(which) params_known, params_staged = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, out_res_names_known, staged_in_res_names, @@ -1588,14 +1596,14 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd( - mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: + mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: axis_env = core.get_axis_env() spmd_names = axis_env.spmd_axis_names - return tuple(name for name in mesh.axis_names if name not in spmd_names and - name not in auto) + return tuple(name for name in mesh.axis_names + if name not in spmd_names and name in manual_axes) def _all_newly_manual_mesh_names( - mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: + mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: axis_env = core.get_axis_env() vmap_spmd_names = set(axis_env.spmd_axis_names) if not (ctx_mesh := get_abstract_mesh()).empty: @@ -1605,7 +1613,8 @@ def _all_newly_manual_mesh_names( # TODO(mattjj): remove this mechanism when we revise mesh 
scopes
   already_manual_names = set(axis_env.axis_sizes)  # may include vmap axis_names
   return tuple(name for name in mesh.axis_names
-               if name not in auto | vmap_spmd_names | already_manual_names)
+               if (name not in vmap_spmd_names | already_manual_names and
+                   name in manual_axes))


 # DCE
@@ -1616,9 +1625,9 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
   if not any(used_outputs) and not pe.has_effects(eqn):
     return [False] * len(eqn.invars), None
   mesh = eqn.params["mesh"]
-  auto = eqn.params["auto"]
+  manual_axes = eqn.params["manual_axes"]
   check_vma = eqn.params["check_vma"]
-  with _extend_axis_env(mesh, auto), config._check_vma(check_vma):
+  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
     jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
   if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
     return used_inputs, None

From f7831007edbb5efc06263331814cfacbde8c7a9b Mon Sep 17 00:00:00 2001
From: Vadym Matsishevskyi
Date: Fri, 25 Apr 2025 13:01:22 -0700
Subject: [PATCH 0838/1769] Fix wrong shard_map import

PiperOrigin-RevId: 751529035
---
 tests/pallas/tpu_pallas_pipeline_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py
index 627302904748..59ac680d3ac3 100644
--- a/tests/pallas/tpu_pallas_pipeline_test.py
+++ b/tests/pallas/tpu_pallas_pipeline_test.py
@@ -24,7 +24,7 @@
 from jax._src import test_util as jtu
 from jax.experimental import mesh_utils
 from jax.experimental import pallas as pl
-from jax.experimental import shard_map
+from jax._src import shard_map
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np

From 1c3f4faedd1fd013a03aa90f47a12ceef9f9b02a Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 25 Apr 2025 13:14:46 -0700
Subject: [PATCH 0839/1769] Add conv_general_dilated sharding rule

This rule only works when rhs is fully replicated or rhs's mesh is empty
(i.e. rhs is a numpy array or jnp.array). In this case, we just forward
the sharding of lhs to the output (after making sure that the out_shape
evenly divides the sharding).

And since reduce_window is the exact same thing as the above case
(i.e. lhs sharded, rhs fully replicated), do the same in its sharding
rule.
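As a rough sketch of what this enables (adapted from the test added in this
patch; the 4x2 explicit mesh setup and imports are elided), a convolution
whose lhs is sharded and whose rhs is a plain numpy array now gets its output
sharding forwarded from lhs instead of raising a ShardingTypeError:

```
arr = jax.device_put(np.zeros((16, 128, 8)), P('x', 'y'))  # lhs is sharded

@jax.jit
def f(x):
  # rhs is an unsharded numpy array, so the lhs sharding P('x', 'y', None)
  # is forwarded to the convolution output.
  return jax.lax.conv_general_dilated(
      x, np.zeros((5, 8, 10)),
      window_strides=(1,), padding='SAME',
      lhs_dilation=(1,), rhs_dilation=(1,),
      dimension_numbers=('NWC', 'WIO', 'NWC'))
```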
Fixes https://github.com/jax-ml/jax/issues/28090 PiperOrigin-RevId: 751534065 --- jax/_src/lax/convolution.py | 50 ++++++++++++++++++--------- jax/_src/lax/windowed_reductions.py | 35 ++++++++++--------- tests/pjit_test.py | 52 +++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 31 deletions(-) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 28d67adb6413..53f9c88369a6 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -53,6 +53,8 @@ class ConvDimensionNumbers(NamedTuple): None, ] +# TODO(yashkatariya): conv_general_dilated should take `out_sharding` argument +# similar to `dot_general` def conv_general_dilated( lhs: Array, rhs: Array, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], @@ -415,6 +417,26 @@ def _conv_general_dilated_shape_rule( return tuple(np.take(out_trans, np.argsort(out_perm))) +def _conv_general_dilated_sharding_rule( + lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding, + lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, + batch_group_count, **unused_kwargs): + # Only allow if rhs is fully replicated and lhs's feature dim is not sharded + if ((rhs.sharding.mesh.empty or rhs.sharding.is_fully_replicated) and + lhs.sharding.spec[dimension_numbers.lhs_spec[1]] is None): + out_shape = _conv_general_dilated_shape_rule( + lhs, rhs, window_strides=window_strides, padding=padding, + lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + batch_group_count=batch_group_count) + return lax.slicing._get_sharding_for_varying_out_shape( + out_shape, lhs, "conv_general_dilated") + # TODO(yashkatariya): In this case, just let the user specify the out_sharding + # via `out_sharding` argument to `conv_general_dilated`. 
+ raise core.ShardingTypeError( + "Please file an issue at https://github.com/jax-ml/jax/issues") + def _conv_general_dilated_dtype_rule( lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, preferred_element_type, **unused_kwargs): @@ -635,6 +657,7 @@ def _conv_general_dilated_batch_rule( conv_general_dilated_p = lax.standard_primitive( _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule, 'conv_general_dilated', + sharding_rule=_conv_general_dilated_sharding_rule, vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated')) ad.defbilinear(conv_general_dilated_p, @@ -713,21 +736,18 @@ def _conv_general_dilated_lower( # TODO(https://github.com/openxla/stablehlo/issues/1268) raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count") if all(core.is_constant_shape(p) for p in padding): - return [ - hlo.convolution( - mlir.aval_to_ir_type(aval_out), - lhs, - rhs, - dimension_numbers=dnums, - feature_group_count=mlir.i64_attr(feature_group_count), - batch_group_count=mlir.i64_attr(batch_group_count), - window_strides=mlir.dense_int_array(window_strides), - padding=mlir.dense_int_elements(padding), - lhs_dilation=mlir.dense_int_array(lhs_dilation), - rhs_dilation=mlir.dense_int_array(rhs_dilation), - window_reversal=window_reversal, - precision_config=lax.precision_attr(precision)) - ] + out = hlo.convolution( + mlir.aval_to_ir_type(aval_out), lhs, rhs, + dimension_numbers=dnums, + feature_group_count=mlir.i64_attr(feature_group_count), + batch_group_count=mlir.i64_attr(batch_group_count), + window_strides=mlir.dense_int_array(window_strides), + padding=mlir.dense_int_elements(padding), + lhs_dilation=mlir.dense_int_array(lhs_dilation), + rhs_dilation=mlir.dense_int_array(rhs_dilation), + window_reversal=window_reversal, + precision_config=lax.precision_attr(precision)) + return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] else: # d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each # spatial dimension. diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index c159dcab8bfa..41ea90804d7b 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -520,21 +520,11 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *, def reduce_window_sharding_rule(operand, window_dimensions, window_strides, padding, base_dilation, window_dilation): - if base_dilation is None: - base_dilation = [1] * operand.ndim - if window_dilation is None: - window_dilation = [1] * operand.ndim - - for spec, wdim, ws, pd, bd, wdil in zip( - operand.sharding.spec, window_dimensions, window_strides, padding, - base_dilation, window_dilation): - if spec is None: - continue - if not (wdim == 1 and ws == 1 and pd == (0, 0) and bd == 1 and wdil == 1): - raise core.ShardingTypeError( - "Only trivial windowing is supported along non-replicated" - f" dimensions. 
Got {operand.sharding.spec=}") - return operand.sharding + out_shape = reduce_window_shape_tuple( + operand.shape, window_dimensions, window_strides, padding, base_dilation, + window_dilation) + return lax.slicing._get_sharding_for_varying_out_shape( + out_shape, operand, 'reduce_window') reduce_window_sum_p = lax.standard_primitive( _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum', @@ -680,8 +670,14 @@ def _select_and_scatter_shape_rule( raise TypeError(msg.format(window_strides, window_dimensions)) return operand.shape +def _select_and_scatter_sharding_rule( + operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr, + scatter_consts, window_dimensions, window_strides, padding): + return operand.sharding + select_and_scatter_p = lax.standard_primitive( _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter', + sharding_rule=_select_and_scatter_sharding_rule, vma_rule=partial(core.standard_vma_rule, 'select_and_scatter')) def _select_and_scatter_lower( @@ -722,7 +718,8 @@ def _select_and_scatter_lower( *scatter.arguments, dim_var_values=ctx.dim_var_values) hlo.return_(mlir.flatten_ir_values(out_nodes)) - return op.results + return [mlir.lower_with_sharding_in_types(ctx, r, aval) + for r, aval in zip(op.results, ctx.avals_out)] mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower) @@ -731,6 +728,11 @@ def _select_and_scatter_add_shape_rule( padding): return operand.shape +def _select_and_scatter_add_sharding_rule( + source, operand, *, select_prim, window_dimensions, window_strides, + padding): + return operand.sharding + def _select_and_scatter_add_jvp( primals, tangents, *, select_prim, window_dimensions, window_strides, padding): @@ -779,6 +781,7 @@ def _select_and_scatter_add_batch_rule( select_and_scatter_add_p = lax.standard_primitive( _select_and_scatter_add_shape_rule, lax._input_dtype, 'select_and_scatter_add', + sharding_rule=_select_and_scatter_add_sharding_rule, vma_rule=partial(core.standard_vma_rule, 'select_and_scatter_add')) ad.primitive_transposes[select_and_scatter_add_p] = \ diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0de1ee1b84a2..9174beeb2525 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7517,6 +7517,58 @@ def f2(x, i, j): return x.at[i].set(x_j) f2(x,i,j) # doesn't crash + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_conv_general_dilated(self, mesh): + arr = jax.device_put(np.zeros((16, 128, 8)), P('x', 'y')) + + @jax.jit + def f(x): + # Conv1D across sharded y-axis: + out = jax.lax.conv_general_dilated( + x, np.zeros((5, 8, 10)), + window_strides=(1,), padding='SAME', feature_group_count=1, + lhs_dilation=(1,), rhs_dilation=(1,), + dimension_numbers=('NWC', 'WIO', 'NWC')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y', None)) + # Max pooling along sharded y-axis. 
+ out2 = jax.lax.reduce_window( + out, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME') + self.assertEqual(out2.aval.sharding.spec, P('x', 'y', None)) + return out2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y', None))) + self.check_wsc_in_lowered(f.lower(arr).as_text()) + + jax.jit(jax.grad(lambda x: f(x).sum()))(arr) # doesn't crash + + with self.assertRaises(core.ShardingTypeError): + arr2 = jax.device_put(np.zeros((16, 128, 8)), P('x', None, 'y')) + f(arr2) + + @parameterized.named_parameters( + ('spec1', P('x', 'y', None)), + ('spec2', P('x', None, 'y')), + ('spec3', P(None, 'x', 'y')), + ('spec4', P(('x', 'y'), None, None)) + ) + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_reduce_window(self, spec, mesh): + arr = jax.device_put(np.zeros((16, 128, 8)), spec) + + @jax.jit + def f(x): + out = jax.lax.reduce_window( + x, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME') + self.assertEqual(out.aval.sharding.spec, spec) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, spec)) + self.check_wsc_in_lowered(f.lower(arr).as_text()) + + jax.jit(jax.grad(lambda x: f(x).sum()))(arr) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 56a336b2ada8305a0ebf7a3938344cee8f296f8d Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 25 Apr 2025 13:34:35 -0700 Subject: [PATCH 0840/1769] [Pallas] Fix index_map equality checks PiperOrigin-RevId: 751540832 --- jax/_src/pallas/core.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 4bc015ab4dad..f68393a7de54 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -432,6 +432,28 @@ def _get_ref_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: if not isinstance(dim, Squeezed) ) + +class _IndexMapFunc: + """Helper class that checks for index_map equality.""" + + def __init__(self, index_map): + self.index_map = index_map + functools.update_wrapper(self, self.index_map) + + def __eq__(self, other: object): + if not isinstance(other, _IndexMapFunc): + return NotImplemented + return self.index_map == other.index_map + + def __call__(self, *args, **kwargs): + out_indices = self.index_map(*args, **kwargs) + if isinstance(out_indices, list): + out_indices = tuple(out_indices) + if not isinstance(out_indices, tuple): + out_indices = (out_indices,) + return out_indices + + @dataclasses.dataclass class BlockSpec: """Specifies how an array should be sliced for each invocation of a kernel. @@ -463,16 +485,7 @@ def __post_init__(self): " indexing." ) if self.index_map is not None: - old_index_map = self.index_map - @functools.wraps(old_index_map) - def _wrapper_index_map(*args, **kwargs): - indices = old_index_map(*args, **kwargs) - if isinstance(indices, list): - indices = tuple(indices) - if not isinstance(indices, tuple): - indices = (indices,) - return indices - self.index_map = _wrapper_index_map + self.index_map = _IndexMapFunc(self.index_map) def to_block_mapping( self, From 7c0e68a5e8fd3e0a91f36472ee97212042ff73bd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 25 Apr 2025 13:41:32 -0700 Subject: [PATCH 0841/1769] Remove `rep` completely from shard_map (rip rep from shmap -- RIP). The error message logic was the last standing thing using `rep`s. 
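For concreteness, a sketch of the kind of program that produces the messages
below (adapted from the test added in this change; assumes an 8-device mesh):

```
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))

@shard_map(mesh=mesh, in_specs=P('x', 'y', 'z'), out_specs=P('x'))
def f(x):
  return x * 2  # output still varies over 'y' and 'z', but out_specs
                # claims it is replicated (unvarying) over them

f(np.arange(16).reshape(4, 2, 2))  # raises the ValueError shown below
```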
Now, the error message talks about varyingness rather than replication Before: ``` ValueError: shard_map applied to the function 'f' was given out_specs which require replication which can't be statically inferred given the mesh: The mesh given has shape (2, 2, 2) with corresponding axis names ('x', 'y', 'z'). out_specs is PartitionSpec('x',) which implies that the corresponding output value is replicated across mesh axes {y,z}, but could only infer replication over {}, which is missing the required axes y,z Check if these output values are meant to be replicated over those mesh axes. If not, consider revising the corresponding out_specs entries. If so, consider disabling the check by passing the check_vma=False argument to `jax.shard_map`. ``` After: ``` ValueError: shard_map applied to the function 'f' was given out_specs which require replication which can't be statically inferred given the mesh: The mesh given has shape (2, 2, 2) with corresponding axis names ('x', 'y', 'z'). out_specs is PartitionSpec('x',) which implies that the corresponding output value is only varying across mesh axes {x} and not {y,z}, but it was inferred to be possibly varying over {x,y,z} Check if these output values are meant to be replicated over those mesh axes. If not, consider revising the corresponding out_specs entries. If so, consider disabling the check by passing the check_vma=False argument to `jax.shard_map`. ``` PiperOrigin-RevId: 751543307 --- jax/_src/shard_map.py | 80 ++++++++++++++++++++--------------------- tests/shard_map_test.py | 13 +++++++ 2 files changed, 51 insertions(+), 42 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 60e1991e2d1b..8313af3c2590 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -233,7 +233,7 @@ def out_names_thunk(): except _RepError as e: fails, = e.args if not callable(out_specs): - msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails) + msg = _inout_vma_error(f, mesh, out_tree(), out_specs, fails) raise ValueError(msg) from None return tree_unflatten(out_tree(), out_flat) return wrapped @@ -432,22 +432,23 @@ def _spec_divisibility_error( f"padding the input and adapting '{fun_name}' appropriately.") return msg -def _inout_rep_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, +def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, fails: list[set | NoFail]) -> str: fun_name = getattr(f, '__name__', str(f)) msgs = [] - for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): + for (spec_key, spec), (fail_key, vma) in _iter_paths(tree, specs, fails): dst = _canonicalize_spec(spec) unmentioned = _unmentioned(mesh, dst) if len(unmentioned) > 1: - need_rep = ','.join(map(str, unmentioned)) - got_rep = ','.join(map(str, rep)) - diff = ','.join(map(str, [n for n in unmentioned if n not in rep])) + need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec)))) + got_vma = ','.join(map(str, order_wrt_mesh(mesh, vma))) + diff = ','.join(map(str, order_wrt_mesh( + mesh, [n for n in unmentioned if n in vma]))) msgs.append( f"* out_specs{keystr(spec_key)} is {spec} which implies that the " - f"corresponding output value is replicated across mesh axes " - f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, " - f"which is missing the required axes {diff}") + f"corresponding output value is only varying across mesh axes " + f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be " + f"possibly varying over {{{got_vma}}}") else: 
need_rep_, = unmentioned msgs.append( @@ -580,7 +581,7 @@ def _shard_map_staging( _check_names(out_names_thunk(), out_avals_) if check_vma: out_vma = [v.aval.vma for v in jaxpr.outvars] - _check_reps(mesh, out_names_thunk(), out_vma) + _check_vmas(mesh, out_names_thunk(), out_vma) out_avals = map(_check_shapedarray, out_avals_) out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] @@ -669,8 +670,6 @@ def _unshard_shaped_array(mesh: Mesh, check_vma, names: AxisNames, # Type-checking -RepType = Any - def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, check_vma, manual_axes): # TODO(mattjj,parkers): check auto @@ -682,9 +681,9 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): core.check_jaxpr(jaxpr) if check_vma: - out_rep = [_vma_to_rep(mesh, v.aval.vma) for v in jaxpr.outvars] - for rep, dst in zip(out_rep, out_names): - if not _valid_repeats(mesh, rep, dst): + out_vma = [v.aval.vma for v in jaxpr.outvars] + for vma, dst in zip(out_vma, out_names): + if not _valid_repeats(mesh, vma, dst): raise core.JaxprTypeError( "shard_map can't prove output is sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] @@ -695,9 +694,11 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, core.custom_typechecks[shard_map_p] = _shard_map_typecheck -def _valid_repeats(mesh: Mesh, rep: RepType, dst: AxisNames) -> bool: - return rep is None or (set(_unmentioned(mesh, dst))).issubset(rep) - +def _valid_repeats(mesh: Mesh, vma: Set[AxisName], names: AxisNames) -> bool: + um = set(_unmentioned(mesh, names)) + if any(u in vma for u in um): + return False + return True # Lowering @@ -729,10 +730,7 @@ def _shard_map_lowering_shardy( new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - # The order of manual axes should match the order of mesh.axis_names to avoid - # non-determinism issues. - manual_axes = [a for a in mesh.axis_names - if a in shardy_manual_axes] + manual_axes = order_wrt_mesh(mesh, shardy_manual_axes) if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. 
with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): @@ -857,13 +855,16 @@ def get_mesh_from_args(args_flat, mesh): return mesh def _vma_to_spec(mesh, vma): - return P(tuple(i for i in mesh.axis_names if i in vma)) + return P(order_wrt_mesh(mesh, vma)) + +def _spec_to_vma(spec): + return _names_to_vma(_canonicalize_spec(spec)) def _names_to_vma(names): return {n for ns in names.values() for n in ns} -def _vma_to_rep(mesh, vma): - return set(mesh.axis_names) - vma +def order_wrt_mesh(mesh, x): + return tuple(a for a in mesh.axis_names if a in x) def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, check_vma, manual_axes): @@ -881,7 +882,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_vma: - _check_reps(mesh, out_names_thunk(), out_vma) + _check_vmas(mesh, out_names_thunk(), out_vma) src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) else: src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) @@ -916,7 +917,7 @@ def _unmatch(mesh, check_vma, src_tup, x): src = _names_to_pspec(dict(src_tup)) if check_vma: used_axes = {i for _, ns in src_tup for i in ns} - dst = P(tuple(i for i in mesh.axis_names if i in used_axes)) + dst = P(order_wrt_mesh(mesh, used_axes)) else: dst = P(mesh.axis_names) check_vma = False @@ -933,10 +934,9 @@ def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] class _SpecError(Exception): pass -def _check_reps(mesh, names, vmas): - reps = [_vma_to_rep(mesh, v) for v in vmas] - fail = [r if not _valid_repeats(mesh, r, n) else no_fail - for n, r in zip(names, reps)] +def _check_vmas(mesh, names, vmas): + fail = [vma if not _valid_repeats(mesh, vma, n) else no_fail + for n, vma in zip(names, vmas)] if any(f is not no_fail for f in fail): raise _RepError(fail) @@ -1233,8 +1233,7 @@ def known_out_names(): _, _, out_knowns, res_avals, _, _ = aux() _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) if check_vma: - res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} - for a in res_avals] + res_names = [{0: order_wrt_mesh(mesh, a.vma)} for a in res_avals] else: res_names = [{0: all_names}] * len(res_avals) return (*out_known_names, *res_names) @@ -1262,7 +1261,7 @@ def known_out_names(): else: if check_vma: res_vma = next(res_avals_iter).vma - res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) + res_names.append({0: order_wrt_mesh(mesh, res_vma)}) else: res_names.append({0: all_names}) unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment] @@ -1301,8 +1300,7 @@ def fwd_out_names_thunk(): if f1 is None and f2 is None] out_names = out_names_thunk() if check_vma: - res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} - for a in res_avals] + res_names = [{0: order_wrt_mesh(mesh, a.vma)} for a in res_avals] else: res_names = [{0: all_names}] * len(res_avals) return (*res_names, *out_names) @@ -1336,7 +1334,7 @@ def fwd_out_names_thunk(): else: if check_vma: res_vma = next(res_avals_iter).vma - res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) + res_names.append({0: order_wrt_mesh(mesh, res_vma)}) else: res_names.append({0: all_names}) new_in_names = (*res_names, *({} for _ in range(len(env))), @@ -1509,15 +1507,13 @@ def _partial_eval_jaxpr_custom_rule( residuals, staged_in_res_names 
= [], [] for var, w in zip(jaxpr_staged.invars[:num_res], which): if w: - rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore + rn = ({0: order_wrt_mesh(mesh, var.aval.vma)} # type: ignore if check_vma else {0: _all_newly_manual_mesh_names(mesh, manual_axes)}) residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval))) staged_in_res_names.append(rn) if check_vma: - out_res_names_known = [ - {0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} - for var, o in zip(res_vars, out_fwd) if o is None - ] + out_res_names_known = [{0: order_wrt_mesh(mesh, var.aval.vma)} # type: ignore + for var, o in zip(res_vars, out_fwd) if o is None] else: out_res_names_known = [ {0: _all_newly_manual_mesh_names(mesh, manual_axes)}] * sum(which) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index fbcca5b0e394..5320c12e75cd 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -943,6 +943,19 @@ def f(key): g = shard_map(f, mesh=mesh, in_specs=(pspec,), out_specs=pspec) _ = g(sharded_rng) # don't crash! + def test_vma_out_specs_error_check(self): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + @shard_map(mesh=mesh, in_specs=P('x', 'y', 'z'), out_specs=P('x')) + def f(x): + return x * 2 + + with self.assertRaisesRegex( + ValueError, + r".*out_specs is PartitionSpec\('x',\) which implies that the.*" + r' output value is only varying across mesh axes \{x\} and not \{y,z\},' + r' but it was inferred to be possibly varying over \{x,y,z\}.*'): + f(np.arange(16).reshape(4, 2, 2)) + def test_functools_partial_rank_error(self): mesh = jtu.create_mesh((4,), ('x',)) From fdf9fd9cf804a02ac0f0a260984b91c128be1273 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 25 Apr 2025 13:42:01 -0700 Subject: [PATCH 0842/1769] Add multiaccelerator H100 tests to optional GPU presubmit. PiperOrigin-RevId: 751543533 --- ..._b200.yml => bazel_optional_h100_b200.yml} | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) rename .github/workflows/{bazel_optional_b200.yml => bazel_optional_h100_b200.yml} (56%) diff --git a/.github/workflows/bazel_optional_b200.yml b/.github/workflows/bazel_optional_h100_b200.yml similarity index 56% rename from .github/workflows/bazel_optional_b200.yml rename to .github/workflows/bazel_optional_h100_b200.yml index 55c56495d6d9..bde033361609 100644 --- a/.github/workflows/bazel_optional_b200.yml +++ b/.github/workflows/bazel_optional_h100_b200.yml @@ -1,4 +1,4 @@ -name: CI - Bazel Optional B200 CUDA tests +name: CI - Bazel Optional H100 and B200 CUDA tests on: # Runs on PR if label "CI Optional GPU Presubmit" is present. 
workflow_dispatch: @@ -36,10 +36,10 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CUDA Tests + - name: Run Bazel single B200 CUDA Tests run: | nvidia-smi - bazel test --config=ci_linux_x86_64_cuda \ + bazel test --config=rbe_linux_x86_64_cuda \ --config=resultstore \ --config=rbe_cache \ --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ @@ -50,6 +50,7 @@ jobs: --test_output=errors \ --test_env=JAX_ACCELERATOR_COUNT=1 \ --test_env=JAX_TESTS_PER_ACCELERATOR=32 \ + --strategy=TestRunner=local \ --local_test_jobs=32 \ --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ --test_tag_filters=-multiaccelerator \ @@ -60,4 +61,38 @@ jobs: --color=yes \ //tests:gpu_tests //tests:backend_independent_tests \ //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + //tests/mosaic:gpu_tests //tests/mosaic:backend_independent_tests + run_multiaccelerator_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: linux-x86-a3-8g-h100-8gpu + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' + name: "Bazel multiple H100 CUDA tests" + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel multiple H100 CUDA Tests + run: | + nvidia-smi + bazel test --config=rbe_linux_x86_64_cuda \ + --config=resultstore \ + --config=rbe_cache \ + --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \ + --repo_env=HERMETIC_PYTHON_VERSION="3.13" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --strategy=TestRunner=local \ + --local_test_jobs=8 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ //tests/mosaic:gpu_tests //tests/mosaic:backend_independent_tests \ No newline at end of file From 69a02eec1392041ed998f522db5ac75a0ce89e3c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 25 Apr 2025 13:45:35 -0700 Subject: [PATCH 0843/1769] Reverts fbb0cbbc6bfe30941076a365a427e12efbb5253b PiperOrigin-RevId: 751544952 --- jax/_src/interpreters/ad.py | 10 +++--- jax/_src/interpreters/partial_eval.py | 36 ++++++++++++-------- jax/_src/lax/control_flow/loops.py | 27 +++++++++++++-- tests/api_test.py | 4 +-- tests/core_test.py | 6 ++-- tests/lax_control_flow_test.py | 49 ++++++++++++++++++++++++++- 6 files changed, 105 insertions(+), 27 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 0f11e0d72f12..435e9027f5b3 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -1191,10 +1191,12 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_ arg_names=new_arg_names, result_paths=new_result_paths, ) - new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, - new_invars, new_outvars, jaxpr.jaxpr.eqns, - jaxpr.jaxpr.effects, - 
new_debug_info) + constvars = jaxpr.jaxpr.constvars + new_effects = pe._renumber_effects( + (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns, + new_effects, new_debug_info) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int], diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 1317a584f6c3..7203af95fb7b 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1559,18 +1559,22 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] ) -> ClosedJaxpr: assert len(closed_jaxpr.in_avals) == len(to_move) - new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) - id_map = {id(v): i for i, v in enumerate(new_invars)} - idx_map = {i: id_map[id(v)] for i, v in enumerate(closed_jaxpr.jaxpr.invars)} - new_effs = {e.replace(input_index=idx_map[e.input_index]) - if isinstance(e, effects.JaxprInputEffect) else e - for e in closed_jaxpr.jaxpr.effects} - new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, - closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - new_effs, closed_jaxpr.jaxpr.debug_info) + constvars, invars = closed_jaxpr.jaxpr.constvars, closed_jaxpr.jaxpr.invars + new_invars = _move_to_front(invars, to_move) + new_effs = _renumber_effects( + (*constvars, *new_invars), (*constvars, *invars), closed_jaxpr.jaxpr.effects) + new_jaxpr = Jaxpr(constvars, new_invars, closed_jaxpr.jaxpr.outvars, + closed_jaxpr.jaxpr.eqns, new_effs, + closed_jaxpr.jaxpr.debug_info) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr +def _renumber_effects(new_vars, old_vars, effs): + newvar_idxs = {id(v): i for i, v in enumerate(new_vars)} + old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)} + return {e.replace(input_index=old_to_new[e.input_index]) + if isinstance(e, effects.JaxprInputEffect) else e for e in effs} + def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: return ([elt for elt, move in zip(lst, to_move) if move] + [elt for elt, move in zip(lst, to_move) if not move]) @@ -1590,7 +1594,6 @@ def _move_outvars_to_back(jaxpr, to_move): return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars)) - class DynamicJaxprTracer(core.Tracer): __slots__ = ['aval', '_debug_info'] @@ -1671,16 +1674,19 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"\n Equation: {eqn}\n" "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") - invar = eqn.invars[eff.input_index] - if invar in mut_arrays: + eqn_invar = eqn.invars[eff.input_index] + if eqn_invar in mut_arrays: continue - if (input_index := all_vars.get(invar, sentinel)) is sentinel: + if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel: + # TODO(mattjj): ask for forgiveness + dbg = type('Fake', (), {'resolve_result_paths': lambda _: None})() raise ValueError( f"`JaxprInputEffect` {eff} does not have " - f"corresponding input: {invar}." + f"corresponding jaxpr input: {eqn_invar=}." 
f"\n Equation: {eqn}\n" + f"\n Effects: {eqn.effects}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d1220ba3fdb3..58ef64add37a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -312,6 +312,9 @@ def _create_jaxpr(init): init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked = rest num_carry = len(init_flat) + num_xs = len(x_avals) + num_ys = len(jaxpr.out_avals) - num_carry + del init_flat _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) @@ -327,22 +330,42 @@ def _create_jaxpr(init): unroll = max(length, 1) if unroll else 1 if unroll < 1: raise ValueError("`unroll` must be a `bool` or a positive `int`.") + if attrs_tracked: in_state = _get_states(attrs_tracked) in_flat = [*in_state, *in_flat] num_carry += len(in_state) + + # If the body forwards an input carry to an output carry, that input is + # read-only and can be moved to be a const. Doing so can lead to efficiency + # wins, e.g. if the scan is inside a cond with a batched predicate. + carry_fwd, _ = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) + move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)] + if any(move_to_const): + jaxpr = pe.prune_closed_jaxpr_outputs( + jaxpr, [not m for m in move_to_const] + [True] * num_ys) + jaxpr = pe.move_binders_to_front( + jaxpr, [False] * len(consts) + move_to_const + [False] * num_xs) + in_flat, new_consts = partition_list(move_to_const + [False] * num_xs, in_flat) + consts = [*new_consts, *consts] + num_carry -= len(new_consts) + out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, linear=(False,) * (len(consts) + len(in_flat)), - unroll=unroll, - _split_transpose=_split_transpose) + unroll=unroll, _split_transpose=_split_transpose) + + if any(move_to_const): + out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) + if attrs_tracked: num_ext = (len(out) - len(in_state) - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked)) out_state, out, out_append = split_list(out, [len(in_state), num_ext]) out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) _set_states(attrs_tracked, out_attrs) + return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): diff --git a/tests/api_test.py b/tests/api_test.py index daea53d0fb38..7a1218d3790a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6853,13 +6853,13 @@ def body(c, _): self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) # 1 b/c scan doesn't have fwding rule + expected_num_eqns=0) used_outputs[7] = expected_used_inputs[7] = True used_outputs[6] = expected_used_inputs[6] = True self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) + expected_num_eqns=0) # If we use the value at index 3 only, some of the hidden sequence must be # kept but the rest pruned. 
diff --git a/tests/core_test.py b/tests/core_test.py index e39487035751..646705ebf281 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -364,15 +364,15 @@ def g_vmap(x): def test_dropvar_avals(self): def f(x): def body(c, _): - return c, None + x1, x2 = c + return (2 * x1, 2 * x2), None (x1, x2), _ = jax.lax.scan(body, (x, x), None, length=1) return [x2] aval = core.ShapedArray((), jnp.dtype('int32')) pval = pe.PartialVal.unknown(aval) jaxpr, _, _ = pe.trace_to_jaxpr_nounits( - lu.wrap_init(f, - debug_info=debug_info("test", f, (0,), {})), + lu.wrap_init(f, debug_info=debug_info("test", f, (0,), {})), [pval], False) dropvar, b = jaxpr.eqns[0].outvars self.assertEqual(dropvar.aval, aval) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 78026968d2cd..422ef769e392 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3122,7 +3122,7 @@ def body(c): return x + y jax.linearize(f, 1., 2.) # don't crash - def test_readonly_carry_optimization(self): + def test_while_readonly_carry_optimization(self): # https://github.com/google/flax/issues/4700 def foo(w, x, c_max): def while_cond(val): @@ -3204,6 +3204,53 @@ def body_fun(c): outs = jax.lax.while_loop(cond_fun, body_fun, (5., 0., 3.14)) self.assertAllClose(outs, (0., 1., 5.)) + def test_scan_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4709 + def f(x, y): + def g(_, y): + y, _ = jax.lax.scan(lambda y, _: (y, None), y, None, length=1) + return y + return jax.lax.cond(x < 0, g, g, x, y) + xs = jnp.arange(3.) + y = 3. + jax.vmap(f, (0, None), None)(xs, y) # don't crash + + @parameterized.parameters(itertools.product(range(3), repeat=4)) + @jtu.run_on_devices("cpu") + def test_scan_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds, + num_noninplace_fwds): + + num_fwds = num_inplace_fwds + num_noninplace_fwds + num_carry = num_fwds + 4 + num_xs = 2 + num_ys = 3 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def body_fun(c, _): + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds, num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return new_c, [0 for _ in range(num_ys)] + + xs = [jnp.arange(2.) 
for _ in range(num_xs)] + outs = jax.lax.scan(body_fun, init_vals, xs)[0] + outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] + self.assertAllClose(outs, outs_ref, check_dtypes=False) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From a09698cc2a1a0685a13aafd7190dd7ed151e1765 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 25 Apr 2025 13:51:31 -0700 Subject: [PATCH 0844/1769] [jax] Switch lapack kernels to builtin FFI attrs decoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` name old cpu/op new cpu/op delta BM_HloModule/jax.issue.26021/process_time 38.2µs ± 2% 36.3µs ± 3% -5.00% (p=0.000 n=73+72) name old time/op new time/op delta BM_HloModule/jax.issue.26021/process_time 38.2µs ± 3% 36.2µs ± 2% -5.24% (p=0.000 n=76+72) ``` Improves benchmark from https://github.com/jax-ml/jax/issues/26021 PiperOrigin-RevId: 751546903 --- jaxlib/cpu/lapack_kernels.cc | 40 +++++++++--------------------------- jaxlib/cpu/lapack_kernels.h | 22 -------------------- jaxlib/ffi_helpers.h | 2 +- 3 files changed, 11 insertions(+), 53 deletions(-) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 3b510708a8bb..2e91bcb34281 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -22,10 +22,9 @@ limitations under the License. #include #include #include -#include #include #include -#include +#include #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" @@ -41,38 +40,19 @@ static_assert(sizeof(jax::lapack_int) == sizeof(int32_t), namespace ffi = xla::ffi; -#define REGISTER_CHAR_ENUM_ATTR_DECODING(type) \ - std::optional xla::ffi::AttrDecoding::Decode( \ - XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \ - if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] { \ - return diagnostic.Emit("Wrong attribute type: expected ") \ - << XLA_FFI_AttrType_SCALAR << " but got" << attr_type; \ - } \ - auto* scalar = reinterpret_cast(attr); \ - if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] { \ - return diagnostic.Emit("Wrong scalar data type: expected ") \ - << XLA_FFI_DataType_U8 << " but got " << scalar->dtype; \ - } \ - auto underlying = \ - *reinterpret_cast*>(scalar->value); \ - return static_cast(underlying); \ - } - -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); - -#undef REGISTER_CHAR_ENUM_ATTR_DECODING +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::svd::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::eig::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::Sort); namespace jax { template -inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { +inline T CastNoOverflow(int64_t value, 
std::string_view source = __FILE__) { auto result = MaybeCastNoOverflow(value, source); if (!result.ok()) { throw std::overflow_error{std::string(result.status().message())}; diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 71ba8b8a5e0c..b3f1f1df758a 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -18,11 +18,9 @@ limitations under the License. #include #include -#include #include #include "absl/status/statusor.h" -#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" @@ -93,26 +91,6 @@ void AssignKernelFn(typename KernelType::FnType* func) { } // namespace jax -#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR) \ - template <> \ - struct xla::ffi::AttrDecoding { \ - using Type = ATTR; \ - static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ - DiagnosticEngine& diagnostic); \ - } - -// XLA needs attributes to have deserialization method specified -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); - -#undef DEFINE_CHAR_ENUM_ATTR_DECODING - namespace jax { using lapack_int = int; diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 634a48fcffc7..7c4dfce81311 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -89,7 +89,7 @@ namespace jax { template inline absl::StatusOr MaybeCastNoOverflow( - std::int64_t value, const std::string& source = __FILE__) { + std::int64_t value, std::string_view source = __FILE__) { if constexpr (sizeof(T) == sizeof(std::int64_t)) { return value; } else { From f84e5483ddcd973b813f8dba84c07af85377ad7b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 25 Apr 2025 14:14:54 -0700 Subject: [PATCH 0845/1769] [XLA:Python] [JAX] Move XlaBuilder bindings out of JAX and into XLA:Python. Stop exposing XlaBuilder, XlaOp and a number of related classes from JAX. 
PiperOrigin-RevId: 751555286 --- jax/lib/xla_client.py | 8 +- jaxlib/BUILD | 1 - jaxlib/_jax/__init__.pyi | 46 ---- jaxlib/_jax/ops.pyi | 466 --------------------------------------- jaxlib/xla_client.py | 407 +--------------------------------- jaxlib/xla_client.pyi | 194 ++-------------- jaxlib/xla_compiler.cc | 123 ----------- tests/infeed_test.py | 57 +++-- tests/pjit_test.py | 7 +- 9 files changed, 68 insertions(+), 1241 deletions(-) delete mode 100644 jaxlib/_jax/ops.pyi diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index c81df076a6b2..faaaf4a425f4 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -15,6 +15,9 @@ import gzip as _gzip from jax._src.lib import xla_client as _xc +def _heap_profile(client): + return _gzip.compress(client.heap_profile()) + _deprecations = { # Finalized 2025-03-25; remove after 2025-06-25 "FftType": ( @@ -88,7 +91,7 @@ "jax.lib.xla_client.heap_profile was deprecated in JAX v0.6.0 and" " will be removed in JAX v0.7.0" ), - lambda client: _gzip.compress(client.heap_profile()), + _heap_profile, ), "mlir_api_version": ( ( @@ -152,7 +155,7 @@ if _typing.TYPE_CHECKING: get_topology_for_devices = _xc.get_topology_for_devices - heap_profile = _xc.heap_profile + heap_profile = _heap_profile mlir_api_version = 58 Client = _xc.Client CompileOptions = _xc.CompileOptions @@ -167,4 +170,5 @@ __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing +del _heap_profile del _xc diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 6de18a5588e7..cd933868b6e3 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1149,7 +1149,6 @@ cc_library( "@xla//xla/ffi", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:c_api", - "@xla//xla/hlo/builder:xla_builder", "@xla//xla/hlo/builder:xla_computation", "@xla//xla/hlo/ir:hlo", "@xla//xla/hlo/ir:hlo_module_group", diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 9559d2862714..fd9b1e068c74 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -30,7 +30,6 @@ from . import ifrt_programs as ifrt_programs from . import ifrt_proxy as ifrt_proxy from . import jax_jit as jax_jit from . import mlir as mlir -from . import ops as ops from . import pmap_lib as pmap_lib from . import profiler as profiler from . import pytree as pytree @@ -42,7 +41,6 @@ hlo_sharding_util = Any _LiteralSlice = Any _Status = Any _Dtype = Any -_XlaOpMetadata = Any ifrt_version_number: int @@ -230,27 +228,6 @@ def hlo_module_cost_analysis( client: Client, module: HloModule ) -> dict[str, float]: ... -class XlaOp: ... - -class XlaBuilder: - def __init__(self, name: str) -> None: ... - def Build(self, root: XlaOp | None = ...) -> XlaComputation: ... - def GetShape(self, __op: XlaOp) -> Shape: ... - build = Build - def clear_op_metadata(self) -> None: ... - get_shape = GetShape - def get_program_shape(self, root: XlaOp | None = ...) -> ProgramShape: ... - def is_constant(self, __op: XlaOp) -> bool: ... - def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ... - def set_sharding(self, sharding: OpSharding_Type) -> None: ... - def clear_sharding(self) -> None: ... - def setup_alias( - self, - __output_index: Sequence[int], - __param_number: int, - __param_index: Sequence[int], - ) -> None: ... - class DeviceAssignment: @staticmethod def create(array: np.ndarray) -> DeviceAssignment: ... @@ -371,23 +348,6 @@ class ExecutableBuildOptions: use_shardy_partitioner: bool def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ... 
-class PrecisionConfig_Precision(enum.IntEnum): - DEFAULT = ... - HIGH = ... - HIGHEST = ... - - -class ResultAccuracy_Mode(enum.IntEnum): - DEFAULT = ... - HIGHEST = ... - TOLERANCE = ... - -class ResultAccuracy: - mode: ResultAccuracy_Mode - atol: float - rtol: float - ulps: int - class OpSharding_Type(enum.IntEnum): REPLICATED = ... MAXIMAL = ... @@ -462,12 +422,6 @@ class HloSharding: def replicate_on_last_tile_dim(self) -> bool: ... def to_proto(self) -> OpSharding: ... -class FftType(enum.IntEnum): - FFT = ... - IFFT = ... - RFFT = ... - IRFFT = ... - # === END xla_compiler.cc class Device: diff --git a/jaxlib/_jax/ops.pyi b/jaxlib/_jax/ops.pyi deleted file mode 100644 index 06a38b9090f6..000000000000 --- a/jaxlib/_jax/ops.pyi +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright 2021 The JAX Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import enum -from typing import Any, overload -from collections.abc import Sequence - -from jaxlib import _jax - -FftType = _jax.FftType -XlaBuilder = _jax.XlaBuilder -XlaComputation = _jax.XlaComputation -XlaOp = _jax.XlaOp -PrecisionConfig_Precision = _jax.PrecisionConfig_Precision -PrimitiveType = _jax.PrimitiveType -Shape = _jax.Shape -ShapeIndex = _jax.ShapeIndex -ResultAccuracy = _jax.ResultAccuracy - -_ChannelHandle = Any -_ConvDimensionNumbers = Any -_DotDimensionNumbers = Any -_Layout = Any -_LiteralSlice = Any -_GatherDimensionNumbers = Any -_PaddingConfig = Any -_ReplicaGroup = Any -_ScatterDimensionNumbers = Any - -class TriangularSolveOptions_Transpose(enum.IntEnum): - TRANSPOSE_INVALID = ... - NO_TRANSPOSE = ... - TRANSPOSE = ... - ADJOINT = ... - -class RandomAlgorithm(enum.IntEnum): - RNG_DEFAULT = ... - RNG_THREE_FRY = ... - RNG_PHILOX = ... - -class CustomCallSchedule(enum.IntEnum): - SCHEDULE_NONE = ... - SCHEDULE_LATEST = ... - SCHEDULE_EARLIEST = ... - -# TODO(b/189822916): Remove this enum when all clients are migrated to the -# status-returning API. -class CustomCallApiVersion(enum.IntEnum): - API_VERSION_ORIGINAL = ... - API_VERSION_STATUS_RETURNING = ... - API_VERSION_STATUS_RETURNING_UNIFIED = ... - API_VERSION_TYPED_FFI = ... - -def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ... -def AllGather( - operand: XlaOp, - all_gather_dimension: int, - shard_count: int, - replica_groups: Sequence[_ReplicaGroup] = ..., - channel_id: _ChannelHandle | None = ..., - shape_with_layout: _Layout | None = ..., - use_global_device_ids: bool | None = ...) -> XlaOp: ... -def AllReduce( - operand: XlaOp, - computation: XlaComputation, - replica_groups: Sequence[_ReplicaGroup] = ..., - channel_id: _ChannelHandle | None = ..., - shape_with_layout: _Layout | None = ...) -> XlaOp: ... 
-def ApproxTopK( - builder: XlaBuilder, - operands: Sequence[XlaOp], - init_values: Sequence[XlaOp], - top_k: int, - reduction_dim: int, - comparator: XlaComputation, - recall_target: float | None, - aggregate_to_topk: bool | None, - reduction_input_size_override: int | None) -> XlaOp: ... -def ApproxTopKFallback( - builder: XlaBuilder, - operands: Sequence[XlaOp], - init_values: Sequence[XlaOp], - top_k: int, - reduction_dim: int, - comparator: XlaComputation, - recall_target: float | None, - aggregate_to_topk: bool | None, - reduction_input_size_override: int | None) -> XlaOp: ... -def ApproxTopKReductionOutputSize( - input_size: int, - rank: int, - top_k: int, - recall_target: float, - aggregate_to_topk: bool | None = ..., - input_size_override: int | None = ...) -> tuple[int, int]: ... -def ReduceScatter( - operand: XlaOp, - computation: XlaComputation, - scatter_dimension: int, - shard_count: int, - replica_groups: Sequence[_ReplicaGroup] = ..., - channel_id: _ChannelHandle | None = ..., - layout: _Layout | None = ..., - use_global_device_ids: bool | None = ...) -> XlaOp: ... -def AllToAll( - operand: XlaOp, - split_dimension: int, - concat_dimension: int, - split_count: int, - replica_groups: Sequence[_ReplicaGroup] = ..., - layout: _Layout | None = ..., - channel_id: _ChannelHandle | None = ...) -> XlaOp: ... -def BitcastConvertType(operand: XlaOp, - new_element_type: PrimitiveType) -> XlaOp: ... -def Broadcast(operand: XlaOp, sizes: Sequence[int]) -> XlaOp: ... -def BroadcastInDim(operand: XlaOp, - shape: Sequence[int], - broadcast_dimensions: Sequence[int]) -> XlaOp: ... -def Call(builder: XlaBuilder, - computation: XlaComputation, - operands: Sequence[XlaOp]) -> XlaOp: ... -def Cholesky(a: XlaOp, lower: bool = ...) -> XlaOp: ... -def Clamp(min: XlaOp, operand: XlaOp, max: XlaOp) -> XlaOp: ... -def Collapse(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... -def CollectivePermute( - operand: XlaOp, - source_target_pairs: Sequence[tuple[int, int]], - channel_id: _ChannelHandle | None = ..., - inplace: bool = ...) -> XlaOp: ... -def ConcatInDim(builder: XlaBuilder, - operands: Sequence[XlaOp], - dimension: int) -> XlaOp: ... -@overload -def Conditional(branch_index: XlaOp, - branch_computations: Sequence[XlaComputation], - branch_operands: Sequence[XlaOp]) -> XlaOp: ... -@overload -def Conditional( - predicate: XlaOp, - true_operand: XlaOp, - true_computation: XlaComputation, - false_operand: XlaOp, - false_computation: XlaComputation) -> XlaOp: ... - -def Constant(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... -def ConstantLiteral(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... -def ConvGeneralDilated( - lhs: XlaOp, - rhs: XlaOp, - window_strides: Sequence[int], - padding: Sequence[tuple[int, int]], - lhs_dilation: Sequence[int], - rhs_dilation: Sequence[int], - dimension_numbers: _ConvDimensionNumbers, - feature_group_count: int = ..., - batch_group_count: int = ..., - precision_config: PrecisionConfig_Precision | None = ..., - preferred_element_type: PrimitiveType | None = ..., - window_reversal: Sequence[bool] | None = ...) -> XlaOp: ... -def ConvertElementType( - operand: XlaOp, - new_element_type: PrimitiveType) -> XlaOp: ... -def CreateToken(builder: XlaBuilder) -> XlaOp: ... -def CrossReplicaSum( - operand: XlaOp, - replica_groups: Sequence[_ReplicaGroup] = ...) -> XlaOp: ... 
-def CustomCall( - builder: XlaBuilder, - call_target_name: bytes, - operands: Sequence[XlaOp], - shape: Shape, - opaque: bytes = ..., - has_side_effect: bool = ..., - schedule: CustomCallSchedule = ..., - api_version: CustomCallApiVersion = ...) -> XlaOp: ... -def CustomCallWithLayout( - builder: XlaBuilder, - call_target_name: bytes, - operands: Sequence[XlaOp], - shape_with_layout: Shape, - operand_shapes_with_layout: Sequence[Shape], - opaque: bytes = ..., - has_side_effect: bool = ..., - schedule: CustomCallSchedule = ..., - api_version: CustomCallApiVersion = ...) -> XlaOp: ... -def CustomCallWithAliasing( - builder: XlaBuilder, - call_target_name: bytes, - operands: Sequence[XlaOp], - shape_with_layout: Shape, - operand_shapes_with_layout: Sequence[Shape], - opaque: bytes = ..., - has_side_effect: bool = ..., - output_operand_aliasing: Sequence[tuple[ShapeIndex, tuple[int, ShapeIndex]]] = ..., - literal: _LiteralSlice = ..., - schedule: CustomCallSchedule = ..., - api_version: CustomCallApiVersion = ...) -> XlaOp: ... -def Dot( - lhs: XlaOp, - rhs: XlaOp, - precision_config: PrecisionConfig_Precision | None = ..., - preferred_element_type: PrimitiveType | None = ...) -> XlaOp: ... -def DotGeneral( - lhs: XlaOp, - rhs: XlaOp, - dimensions_numbers: _DotDimensionNumbers, - precision_config: PrecisionConfig_Precision | None = ..., - preferred_element_type: PrimitiveType | None = ...) -> XlaOp: ... -def DynamicReshape( - operand: XlaOp, - dim_sizes: Sequence[XlaOp], - new_size_bounds: Sequence[int], - dims_are_dynamic: Sequence[bool]) -> XlaOp: ... -def DynamicSlice( - operand: XlaOp, - start_indices: Sequence[XlaOp], - slice_sizes: Sequence[int]) -> XlaOp: ... -def DynamicUpdateSlice( - operand: XlaOp, - update: XlaOp, - start_indices: Sequence[XlaOp]) -> XlaOp: ... -def Eigh( - a: XlaOp, - lower: bool = ..., - max_iter: int = ..., - epsilon: float = ..., - sort_eigenvalues: bool = ...) -> tuple[XlaOp, XlaOp]: ... -def Fft( - operand: XlaOp, - fft_type: FftType, - fft_length: Sequence[int]) -> XlaOp: ... -def Gather( - a: XlaOp, - start_indices: XlaOp, - dimension_numbers: _GatherDimensionNumbers, - slice_sizes: Sequence[int], - indices_are_sorted: bool = ...) -> XlaOp: ... -def GetDimensionSize(operand: XlaOp, index: int) -> XlaOp: ... -def GetTupleElement(tuple_data: XlaOp, index: int) -> XlaOp: ... -def InfeedWithToken( - token: XlaOp, - shape: Shape, - config: str | None = ...) -> XlaOp: ... -@overload -def Iota(builder: XlaBuilder, shape: Shape, iota_dimension: int) -> XlaOp: ... -@overload -def Iota(builder: XlaBuilder, type: PrimitiveType, size: int) -> XlaOp: ... -def LU(a: XlaOp) -> tuple[XlaOp, XlaOp, XlaOp]: ... -def Map( - builder: XlaBuilder, - operands: Sequence[XlaOp], - computation: XlaComputation, - dimensions: Sequence[int], - static_operands: Sequence[XlaOp] = ...) -> XlaOp: ... -def MultiCollectivePermute( - operands: Sequence[XlaOp], - source_target_pairs: Sequence[tuple[int, int]], - channel_id: _ChannelHandle | None = ..., - inplace: bool = ...) -> XlaOp: ... -def NextAfter(__from: XlaOp, to: XlaOp) -> XlaOp: ... -def OutfeedWithToken( - operand: XlaOp, - token: XlaOp, - shape_with_layout: Shape, - outfeed_config: str | None = ...) -> XlaOp: ... -def Pad( - operand: XlaOp, - padding_value: XlaOp, - padding_config: _PaddingConfig) -> XlaOp: ... -def Parameter( - builder: XlaBuilder, - parameter_number: int, - shape: Shape, - name: str = ..., - replicated_at_leaf_buffers: Sequence[bool] = ...) -> XlaOp: ... 
-def ProductOfElementaryHouseholderReflectors(a: XlaOp, taus: XlaOp) -> XlaOp: ... -def QR(a: XlaOp, full_matrices: bool) -> tuple[XlaOp, XlaOp]: ... -def QrDecomposition(a: XlaOp) -> tuple[XlaOp, XlaOp]: ... -def Reduce( - builder: XlaBuilder, - operands: Sequence[XlaOp], - init_values: Sequence[XlaOp], - computation: XlaComputation, - dimensions_to_reduce: Sequence[int]) -> XlaOp: ... -def ReducePrecision( - operand: XlaOp, - exponent_bits: int, - mantissa_bits: int) -> XlaOp: ... -@overload -def ReduceWindowWithGeneralPadding( - operand: XlaOp, - init_value: XlaOp, - computation: XlaComputation, - window_dimensions: Sequence[int], - window_strides: Sequence[int], - base_dilations: Sequence[int], - window_dilations: Sequence[int], - padding: Sequence[tuple[int, int]]) -> XlaOp: ... -@overload -def ReduceWindowWithGeneralPadding( - operands: Sequence[XlaOp], - init_values: Sequence[XlaOp], - computation: XlaComputation, - window_dimensions: Sequence[int], - window_strides: Sequence[int], - base_dilations: Sequence[int], - window_dilations: Sequence[int], - padding: Sequence[tuple[int, int]]) -> XlaOp: ... -def ReplicaId(builder: XlaBuilder) -> XlaOp: ... -def Reshape(operand: XlaOp, new_sizes: Sequence[int]) -> XlaOp: ... -def Rev(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... -def RngBitGenerator( - algorithm: RandomAlgorithm, - initial_state: XlaOp, - shape: Shape) -> XlaOp: ... -def RngNormal(mu: XlaOp, sigma: XlaOp, shape: Shape) -> XlaOp: ... -def RngUniform(a: XlaOp, b: XlaOp, shape: Shape) -> XlaOp: ... -@overload -def Scatter( - input: XlaOp, - scatter_indices: XlaOp, - updates: XlaOp, - update_computation: XlaComputation, - dimension_numbers: _ScatterDimensionNumbers, - indices_are_sorted: bool = ..., - unique_indices: bool = ...) -> XlaOp: ... -@overload -def Scatter( - inputs: Sequence[XlaOp], - scatter_indices: XlaOp, - updates: Sequence[XlaOp], - update_computation: XlaComputation, - dimension_numbers: _ScatterDimensionNumbers, - indices_are_sorted: bool = ..., - unique_indices: bool = ...) -> XlaOp: ... -def Select(pred: XlaOp, on_true: XlaOp, on_false: XlaOp) -> XlaOp: ... -def SelectAndScatterWithGeneralPadding( - operand: XlaOp, - select: XlaComputation, - window_dimensions: Sequence[int], - window_strides: Sequence[int], - padding: Sequence[tuple[int, int]], - source: XlaOp, - init_value: XlaOp, - scatter: XlaComputation) -> XlaOp: ... -def Slice( - operand: XlaOp, - start_indices: Sequence[int], - limit_indices: Sequence[int], - strides: Sequence[int]) -> XlaOp: ... -def SliceInDim( - operand: XlaOp, - start_index: int, - limit_index: int, - stride: int, - dimno: int) -> XlaOp: ... -def Sort( - builder: XlaBuilder, - operands: Sequence[XlaOp], - comparator: XlaComputation | None = ..., - dimension: int = ..., - is_stable: bool = ...) -> XlaOp: ... -def SVD( - a: XlaOp, - max_iter: int = ..., - epsilon: float = ...) -> tuple[XlaOp, XlaOp, XlaOp]: ... -def TopK(input: XlaOp, k: int) -> XlaOp: ... -def Transpose(operand: XlaOp, permutation: Sequence[int]) -> XlaOp: ... -def TriangularSolve( - a: XlaOp, - b: XlaOp, - left_side: bool, - lower: bool, - unit_diagonal: bool, - transpose_a: TriangularSolveOptions_Transpose) -> XlaOp: ... -def Tuple(builder: XlaBuilder, elements: Sequence[XlaOp]) -> XlaOp: ... -def While( - condition: XlaComputation, - body: XlaComputation, - init: XlaOp) -> XlaOp: ... - - -def Igamma(a: XlaOp, x: XlaOp) -> XlaOp: ... -def Igammac(a: XlaOp, x: XlaOp) -> XlaOp: ... -def IgammaGradA(a: XlaOp, x: XlaOp) -> XlaOp: ... 
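Igamma, Igammac, and IgammaGradA above bind XLA's regularized incomplete gamma family. Their user-facing surface in current JAX is jax.scipy.special, so the defining identity P(a, x) + Q(a, x) = 1 can still be sanity-checked without any builder bindings (a quick illustrative sketch, not part of this patch):

import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaincc

a, x = jnp.float32(2.5), jnp.float32(1.7)
# Regularized lower (P) and upper (Q) incomplete gamma functions sum to one.
print(gammainc(a, x) + gammaincc(a, x))  # ~1.0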
-def RandomGammaGrad(a: XlaOp, x: XlaOp) -> XlaOp: ... -def RegularizedIncompleteBeta(a: XlaOp, b: XlaOp, x: XlaOp) -> XlaOp: ... -def Zeta(a: XlaOp, q: XlaOp) -> XlaOp: ... - -def Eq(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Ne(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Ge(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Gt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Lt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Le(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Add(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Sub(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Mul(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Div(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Rem(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Max(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Min(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def And(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Or(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Xor(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def ShiftLeft(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def ShiftRightArithmetic(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def ShiftRightLogical(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Atan2(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Pow(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... -def Complex(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... - -def Not(__arg: XlaOp) -> XlaOp: ... -def PopulationCount(__arg: XlaOp) -> XlaOp: ... -def Clz(__arg: XlaOp) -> XlaOp: ... -def Abs(__arg: XlaOp) -> XlaOp: ... -def Exp(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Expm1(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Floor(__arg: XlaOp) -> XlaOp: ... -def Ceil(__arg: XlaOp) -> XlaOp: ... -def Round(__arg: XlaOp) -> XlaOp: ... -def Log(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Log1p(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Sign(__arg: XlaOp) -> XlaOp: ... -def Cos(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def OptimizationBarrier(__arg: XlaOp) -> XlaOp: ... -def Sin(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Tan(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Tanh(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def IsFinite(__arg: XlaOp) -> XlaOp: ... -def Neg(__arg: XlaOp) -> XlaOp: ... -def Sqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Rsqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Cbrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def Square(__arg: XlaOp) -> XlaOp: ... 
-def Reciprocal(__arg: XlaOp) -> XlaOp: ... -def Erfc(__arg: XlaOp) -> XlaOp: ... -def Erf(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... -def ErfInv(__arg: XlaOp) -> XlaOp: ... -def Lgamma(__arg: XlaOp) -> XlaOp: ... -def Digamma(__arg: XlaOp) -> XlaOp: ... -def BesselI0e(__arg: XlaOp) -> XlaOp: ... -def BesselI1e(__arg: XlaOp) -> XlaOp: ... -def Acos(__arg: XlaOp) -> XlaOp: ... -def Asin(__arg: XlaOp) -> XlaOp: ... -def Atan(__arg: XlaOp) -> XlaOp: ... -def Acosh(__arg: XlaOp) -> XlaOp: ... -def Asinh(__arg: XlaOp) -> XlaOp: ... -def Atanh(__arg: XlaOp) -> XlaOp: ... -def Cosh(__arg: XlaOp) -> XlaOp: ... -def Sinh(__arg: XlaOp) -> XlaOp: ... -def Real(__arg: XlaOp) -> XlaOp: ... -def Imag(__arg: XlaOp) -> XlaOp: ... -def Conj(__arg: XlaOp) -> XlaOp: ... diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 34631f328c29..f63dfbe471dc 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -17,18 +17,14 @@ from __future__ import annotations import atexit -from collections.abc import Mapping, Sequence +from collections.abc import Mapping import contextlib import enum -import inspect import logging import os import threading from typing import Any, Protocol, Union -import ml_dtypes -import numpy as np - from jaxlib import _jax as _xla # Note this module does *not* depend on any Python protocol buffers. The XLA @@ -193,73 +189,8 @@ def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: return options -class OpMetadata: - """Python representation of a xla.OpMetadata protobuf.""" - - __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') - - def __init__(self, op_type='', op_name='', source_file='', source_line=0): - self.op_type = op_type - self.op_name = op_name - self.source_file = source_file - self.source_line = source_line - - -def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): - """Helper for use in source mapping that returns an OpMetadata object.""" - full_filename, lineno = inspect.stack()[skip_frames][1:3] - filename = os.path.basename(full_filename) - return OpMetadata( - op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno - ) - - PrimitiveType = _xla.PrimitiveType -XLA_ELEMENT_TYPE_TO_DTYPE = { - PrimitiveType.PRED: np.dtype('bool'), - PrimitiveType.S4: np.dtype(ml_dtypes.int4), - PrimitiveType.S8: np.dtype('int8'), - PrimitiveType.S16: np.dtype('int16'), - PrimitiveType.S32: np.dtype('int32'), - PrimitiveType.S64: np.dtype('int64'), - PrimitiveType.U4: np.dtype(ml_dtypes.uint4), - PrimitiveType.U8: np.dtype('uint8'), - PrimitiveType.U16: np.dtype('uint16'), - PrimitiveType.U32: np.dtype('uint32'), - PrimitiveType.U64: np.dtype('uint64'), - PrimitiveType.F4E2M1FN: np.dtype(ml_dtypes.float4_e2m1fn), - PrimitiveType.F8E3M4: np.dtype(ml_dtypes.float8_e3m4), - PrimitiveType.F8E4M3: np.dtype(ml_dtypes.float8_e4m3), - PrimitiveType.F8E4M3FN: np.dtype(ml_dtypes.float8_e4m3fn), - PrimitiveType.F8E4M3B11FNUZ: np.dtype(ml_dtypes.float8_e4m3b11fnuz), - PrimitiveType.F8E4M3FNUZ: np.dtype(ml_dtypes.float8_e4m3fnuz), - PrimitiveType.F8E5M2: np.dtype(ml_dtypes.float8_e5m2), - PrimitiveType.F8E5M2FNUZ: np.dtype(ml_dtypes.float8_e5m2fnuz), - PrimitiveType.F8E8M0FNU: np.dtype(ml_dtypes.float8_e8m0fnu), - PrimitiveType.BF16: np.dtype(ml_dtypes.bfloat16), - PrimitiveType.F16: np.dtype('float16'), - PrimitiveType.F32: np.dtype('float32'), - PrimitiveType.F64: np.dtype('float64'), - PrimitiveType.C64: np.dtype('complex64'), - PrimitiveType.C128: np.dtype('complex128'), - PrimitiveType.TUPLE: 
np.dtype(np.object_), - PrimitiveType.TOKEN: np.dtype(np.object_), -} - -# Note the conversion on the key. Numpy has a known issue wherein dtype hashing -# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, -# when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} - - -def dtype_to_etype(dtype): - """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" - return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] - - Shape = _xla.Shape Shape.__doc__ = """ A Shape is an object defined in C++ that duck types like the following class: @@ -342,22 +273,6 @@ def __repr__(self): """ -def shape_from_pyval(pyval, layout: Sequence[int] | None = None): - """Returns a Shape that describes a tuple-tree of Numpy arrays.""" - - def convert(pyval): - if isinstance(pyval, tuple): - if layout is not None: - raise NotImplementedError( - 'shape_from_pyval does not support layouts for tuple shapes' - ) - return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) - else: - return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) - - return convert(pyval) - - DeviceAssignment = _xla.DeviceAssignment DeviceAssignment.__doc__ = """ A DeviceAssignment is a C++ object with the following signature. @@ -409,48 +324,7 @@ def computation_count(): # There are different implementations of Executable for different backends. -class PaddingType(enum.Enum): - VALID = 1 - SAME = 2 - - -def window_padding_type_to_pad_values( - padding_type, lhs_dims, rhs_dims, window_strides -): - """Maps PaddingType or string to pad values (list of pairs of ints).""" - if not isinstance(padding_type, (str, PaddingType)): - msg = 'padding_type must be str or PaddingType, got {}.' - raise TypeError(msg.format(type(padding_type))) - - if isinstance(padding_type, str): - if padding_type.upper() == 'VALID': - padding_type = PaddingType.VALID - elif padding_type.upper() == 'SAME': - padding_type = PaddingType.SAME - else: - msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' 
- raise ValueError(msg.format(padding_type)) - - if padding_type == PaddingType.VALID: - return [(0, 0)] * len(window_strides) - elif padding_type == PaddingType.SAME: - out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) - pad_sizes = [ - max((out_size - 1) * stride + filter_size - in_size, 0) - for out_size, stride, filter_size, in_size in zip( - out_shape, window_strides, rhs_dims, lhs_dims - ) - ] - return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] - else: - msg = 'Unexpected PaddingType value: {}' - raise ValueError(msg.format(padding_type)) - - -XlaBuilder = _xla.XlaBuilder XlaComputation = _xla.XlaComputation -XlaOp = _xla.XlaOp -FftType = _xla.FftType Client = _xla.Client Memory = _xla.Memory Array = _xla.Array @@ -466,7 +340,6 @@ def window_padding_type_to_pad_values( GSPMDSharding = _xla.GSPMDSharding PjRtLayout = _xla.PjRtLayout AutotuneCacheMode = _xla.AutotuneCacheMode -ResultAccuracyMode = _xla.ResultAccuracy_Mode def LoadedExecutable_execute(self, arguments, device=None): @@ -648,284 +521,6 @@ def register_custom_type_id_handler( ) -class PaddingConfigDimension: - """Python representation of a xla.PaddingConfigDimension protobuf.""" - - __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') - - edge_padding_low: int - edge_padding_high: int - interior_padding: int - - def __init__(self): - self.edge_padding_low = 0 - self.edge_padding_high = 0 - self.interior_padding = 0 - - -class PaddingConfig: - """Python representation of a xla.PaddingConfig protobuf.""" - - __slots__ = ('dimensions',) - - def __init__(self): - self.dimensions = [] - - -def make_padding_config( - padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]] -) -> PaddingConfig: - """Create PaddingConfig proto from list of triples of integers. - - Args: - padding_config: either a PaddingConfig or a list of integer triples - (edge_padding_low, edge_padding_high, interior_padding) representing the - configuration of the padding operation. - - Returns: - A `PaddingConfig` object. - """ - if not isinstance(padding_config, PaddingConfig): - triples = padding_config - padding_config = PaddingConfig() - for lo, hi, interior in triples: - dimension = PaddingConfigDimension() - dimension.edge_padding_low = lo - dimension.edge_padding_high = hi - dimension.interior_padding = interior - padding_config.dimensions.append(dimension) - return padding_config - - -class DotDimensionNumbers: - """Python representation of a xla.DotDimensionNumbers protobuf.""" - - __slots__ = ( - 'lhs_contracting_dimensions', - 'rhs_contracting_dimensions', - 'lhs_batch_dimensions', - 'rhs_batch_dimensions', - ) - - def __init__(self): - self.lhs_contracting_dimensions = [] - self.rhs_contracting_dimensions = [] - self.lhs_batch_dimensions = [] - self.rhs_batch_dimensions = [] - - -def make_dot_dimension_numbers( - dimension_numbers: Union[ - DotDimensionNumbers, - tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], - ] -) -> DotDimensionNumbers: - """Builds a DotDimensionNumbers object from a specification. - - Args: - dimension_numbers: either a `DotDimensionNumbers` or a nested tuple - `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of - integers representing the dimensions to treat as contracting dimensions - and batch dimensions on each input operand. - - Returns: - A `DotDimensionNumbers` object. 
- """ - if isinstance(dimension_numbers, (list, tuple)): - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers - dot_dims_proto = DotDimensionNumbers() - dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) - dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) - dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) - dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) - return dot_dims_proto - else: - return dimension_numbers - - -class ConvolutionDimensionNumbers: - """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" - - __slots__ = ( - 'input_batch_dimension', - 'input_feature_dimension', - 'input_spatial_dimensions', - 'kernel_input_feature_dimension', - 'kernel_output_feature_dimension', - 'kernel_spatial_dimensions', - 'output_batch_dimension', - 'output_feature_dimension', - 'output_spatial_dimensions', - ) - - def __init__(self): - self.input_batch_dimension = 0 - self.input_feature_dimension = 0 - self.input_spatial_dimensions = [] - self.kernel_input_feature_dimension = 0 - self.kernel_output_feature_dimension = 0 - self.kernel_spatial_dimensions = [] - self.output_batch_dimension = 0 - self.output_feature_dimension = 0 - self.output_spatial_dimensions = [] - - -def make_convolution_dimension_numbers( - dimension_numbers: Union[ - None, ConvolutionDimensionNumbers, tuple[str, str, str] - ], - num_spatial_dimensions: int, -) -> ConvolutionDimensionNumbers: - """Builds a ConvolutionDimensionNumbers object from a specification. - - Args: - dimension_numbers: optional, either a ConvolutionDimensionNumbers object or - a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length - N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the - output with the character 'N', (2) feature dimensions in lhs and the - output with the character 'C', (3) input and output feature dimensions in - rhs with the characters 'I' and 'O' respectively, and (4) spatial - dimension correspondences between lhs, rhs, and the output using any - distinct characters. For example, to indicate dimension numbers consistent - with the Conv operation with two spatial dimensions, one could use - ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension - numbers consistent with the TensorFlow Conv2D operation, one could use - ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution - dimension specification, window strides are associated with spatial - dimension character labels according to the order in which the labels - appear in the rhs_spec string, so that window_strides[0] is matched with - the dimension corresponding to the first character appearing in rhs_spec - that is not 'I' or 'O'. By default, use the same dimension numbering as - Conv and ConvWithGeneralPadding. - num_spatial_dimensions: the number of spatial dimensions. - - Returns: - A `ConvolutionDimensionNumbers` object. 
- """ - if dimension_numbers is None: - nd = num_spatial_dimensions - dimension_numbers = ConvolutionDimensionNumbers() - dimension_numbers.input_batch_dimension = 0 - dimension_numbers.input_feature_dimension = 1 - dimension_numbers.output_batch_dimension = 0 - dimension_numbers.output_feature_dimension = 1 - dimension_numbers.kernel_output_feature_dimension = 0 - dimension_numbers.kernel_input_feature_dimension = 1 - dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) - elif isinstance(dimension_numbers, tuple): - lhs_spec, rhs_spec, out_spec = dimension_numbers - dimension_numbers = ConvolutionDimensionNumbers() - - dimension_numbers.input_batch_dimension = lhs_spec.index('N') - dimension_numbers.input_feature_dimension = lhs_spec.index('C') - dimension_numbers.output_batch_dimension = out_spec.index('N') - dimension_numbers.output_feature_dimension = out_spec.index('C') - dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') - dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') - - dimension_numbers.kernel_spatial_dimensions.extend( - i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} - ) - dimension_numbers.input_spatial_dimensions.extend( - sorted( - (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(lhs_spec[i]), - ) - ) - dimension_numbers.output_spatial_dimensions.extend( - sorted( - (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(out_spec[i]), - ) - ) - return dimension_numbers - - -class PrecisionConfig: - """Python representation of a xla.PrecisionConfig protobuf.""" - - __slots__ = ('operand_precision',) - - Precision = _xla.PrecisionConfig_Precision - - def __init__(self): - self.operand_precision = [] - - -class ResultAccuracy: - """Python representation of a xla.ResultAccuracy protobuf.""" - - __slots__ = ('mode', 'atol', 'rtol', 'ulps') - - def __init__(self): - self.mode = _xla.ResultAccuracy_Mode.DEFAULT - self.atol = 0.0 - self.rtol = 0.0 - self.ulps = 0 - - -class GatherDimensionNumbers: - """Python representation of a xla.GatherDimensionNumbers protobuf.""" - - __slots__ = ( - 'offset_dims', - 'collapsed_slice_dims', - 'start_index_map', - 'index_vector_dim', - ) - - def __init__(self): - self.offset_dims = [] - self.collapsed_slice_dims = [] - self.start_index_map = [] - self.index_vector_dim = 0 - - -class ScatterDimensionNumbers: - """Python representation of a xla.ScatterDimensionNumbers protobuf.""" - - __slots__ = ( - 'update_window_dims', - 'inserted_window_dims', - 'scatter_dims_to_operand_dims', - 'index_vector_dim', - ) - - def __init__(self): - self.update_window_dims = [] - self.inserted_window_dims = [] - self.scatter_dims_to_operand_dims = [] - self.index_vector_dim = 0 - - -class ReplicaGroup: - """Python representation of a xla.ReplicaGroup protobuf.""" - - __slots__ = ('replica_ids',) - - def __init__(self): - self.replica_ids = [] - - -def _make_replica_group_proto(replica_group): - replica_group_proto = ReplicaGroup() - replica_group_proto.replica_ids.extend(replica_group) - return replica_group_proto - - -def make_replica_groups(replica_groups): - if replica_groups is None: - replica_groups_protos = [] # special value for XLA API - else: - replica_groups = list(replica_groups) - replica_groups_protos = [ - _make_replica_group_proto(group) for group in replica_groups - ] 
- return replica_groups_protos - - Traceback = _xla.Traceback Frame = _xla.Frame diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index 2b78a31fea72..80599e86676b 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -19,9 +19,8 @@ from collections.abc import Callable, Mapping, Sequence import enum from typing import Any, Union -import numpy - from jaxlib import _jax as _xla +from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics from jaxlib._jax import ArrayImpl as ArrayImpl from jaxlib._jax import AutotuneCacheMode as AutotuneCacheMode from jaxlib._jax import Client as Client @@ -31,7 +30,6 @@ from jaxlib._jax import DeviceAssignment as DeviceAssignment from jaxlib._jax import DeviceList as DeviceList from jaxlib._jax import DeviceTopology as DeviceTopology from jaxlib._jax import DistributedRuntimeClient as DistributedRuntimeClient -from jaxlib._jax import FftType as FftType from jaxlib._jax import Frame as Frame from jaxlib._jax import GSPMDSharding as GSPMDSharding from jaxlib._jax import HloSharding as HloSharding @@ -45,31 +43,17 @@ from jaxlib._jax import OpSharding as OpSharding from jaxlib._jax import PjRtLayout as PjRtLayout from jaxlib._jax import PmapSharding as PmapSharding from jaxlib._jax import PrimitiveType as PrimitiveType -from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics from jaxlib._jax import Shape as Shape from jaxlib._jax import Sharding as Sharding from jaxlib._jax import SingleDeviceSharding as SingleDeviceSharding from jaxlib._jax import Traceback as Traceback -from jaxlib._jax import XlaBuilder as XlaBuilder from jaxlib._jax import XlaComputation as XlaComputation -from jaxlib._jax import XlaOp as XlaOp _version: int - _ifrt_version: int -XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] - _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] -def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType: - ... - -def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ... - -def heap_profile(client: Client) -> bytes: - ... - XlaRuntimeError = _xla.XlaRuntimeError def make_cpu_client( @@ -79,9 +63,7 @@ def make_cpu_client( num_nodes: int = ..., collectives: _xla.CpuCollectives | None = ..., num_devices: int | None = ..., -) -> Client: - ... - +) -> Client: ... def make_gpu_client( distributed_client: DistributedRuntimeClient | None = ..., node_id: int = ..., @@ -90,159 +72,33 @@ def make_gpu_client( allowed_devices: set[int] | None = ..., mock: bool | None = ..., mock_gpu_topology: str | None = ..., -) -> Client: - ... - +) -> Client: ... def make_tfrt_tpu_c_api_device_topology( topology_name: str | None = None, **kwargs -) -> DeviceTopology: - ... - -def make_c_api_device_topology(c_api: Any, topology_name: str = '', **kwargs) -> DeviceTopology: - ... - -def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: - ... - -def make_tpu_client( - library_path: str | None, options: _NameValueMapping | None = None -) -> Client: - ... - +) -> DeviceTopology: ... +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: ... +def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: ... def make_c_api_client( plugin_name: str, options: _NameValueMapping | None = None, distributed_client: DistributedRuntimeClient | None = None, -) -> Client: - ... - -def pjrt_plugin_loaded(plugin_name: str) -> bool: - ... - -def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: - ... 
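With shape_from_pyval removed from both xla_client.py and this stub file, callers construct Shapes directly; the test updates later in this patch series do exactly that. A representative before/after sketch (variable names are illustrative):

import numpy as np
from jaxlib import xla_client as xc

y = np.zeros((3, 4), np.float32)
# Before (removed helper; wrapped the array in a one-element tuple shape):
#   shape = xc.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent()
# After: build the array Shape explicitly from the dtype and dimensions.
shape = xc.Shape.array_shape(
    xc.PrimitiveType.F32, y.shape
).with_major_to_minor_layout_if_absent()

Because the replacement is an array shape rather than a one-element tuple shape, the updated tests also stop destructuring the result of transfer_from_outfeed (`y, = ...` becomes `y = ...`).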
- -def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: - ... - -def pjrt_plugin_initialized(plugin_name: str) -> bool: - ... - -def initialize_pjrt_plugin(plugin_name: str) -> None: - ... - -def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: - ... - -class OpMetadata: - - def __init__( - self, - op_type: str | None = ..., - op_name: str | None = ..., - source_file: str | None = ..., - source_line: int | None = ..., - ): - ... - op_type: str | None - op_name: str | None - source_file: str | None - source_line: int | None - -class PaddingConfigDimension: - edge_padding_low: int - edge_padding_high: int - interior_padding: int - -class PaddingConfig: - dimensions: list[PaddingConfigDimension] - -def make_padding_config( - padding_config: PaddingConfig | Sequence[tuple[int, int, int]], -) -> PaddingConfig: - ... - -class PaddingType(enum.Enum): - VALID = 1 - SAME = 2 - -class DotDimensionNumbers: - lhs_contracting_dimensions: list[int] - rhs_contracting_dimensions: list[int] - lhs_batch_dimensions: list[int] - rhs_batch_dimensions: list[int] - -def make_dot_dimension_numbers( - dimension_numbers: ( - DotDimensionNumbers | - tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]] - ), -) -> DotDimensionNumbers: - ... - -class ConvolutionDimensionNumbers: - input_batch_dimension: int - input_feature_dimension: int - input_spatial_dimensions: list[int] - kernel_input_feature_dimension: int - kernel_output_feature_dimension: int - kernel_spatial_dimensions: list[int] - output_batch_dimension: int - output_feature_dimension: int - output_spatial_dimensions: list[int] - -def make_convolution_dimension_numbers( - dimension_numbers: ( - None | ConvolutionDimensionNumbers | tuple[str, str, str] - ), - num_spatial_dimensions: int, -) -> ConvolutionDimensionNumbers: - ... - -class PrecisionConfig: - Precision = _xla.PrecisionConfig_Precision - operand_precision: list[_xla.PrecisionConfig_Precision] - -class ResultAccuracy: - mode: _xla.ResultAccuracy_Mode - atol: float - rtol: float - ulps: int - -class GatherDimensionNumbers: - offset_dims: list[int] - collapsed_slice_dims: list[int] - start_index_map: list[int] - index_vector_dim: int - operand_batching_dims: list[int] - start_indices_batching_dims: list[int] - -class ScatterDimensionNumbers: - update_window_dims: list[int] - inserted_window_dims: list[int] - scatter_dims_to_operand_dims: list[int] - index_vector_dim: int - input_batching_dims: list[int] - scatter_indices_batching_dims: list[int] - -class ReplicaGroup: - replica_ids: list[int] - -def make_replica_groups( - replica_groups: Sequence[Sequence[int]] | None, -) -> list[ReplicaGroup]: - ... - -def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...) -> _xla.WeakrefLRUCache: - ... - +) -> Client: ... +def pjrt_plugin_loaded(plugin_name: str) -> bool: ... +def load_pjrt_plugin_dynamically( + plugin_name: str, library_path: str +) -> Any: ... +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: ... +def pjrt_plugin_initialized(plugin_name: str) -> bool: ... +def initialize_pjrt_plugin(plugin_name: str) -> None: ... +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: ... def batched_copy_array_to_devices_with_sharding( arrays: Sequence[ArrayImpl], devices: Sequence[list[Device]], sharding: Sequence[Any], array_copy_semantics: Sequence[ArrayCopySemantics], ) -> Sequence[ArrayImpl]: ... 
- def batched_device_put( aval: Any, sharding: Any, @@ -252,25 +108,18 @@ def batched_device_put( force_copy: bool = ..., host_buffer_semantics: Any = ..., ) -> ArrayImpl: ... - def reorder_shards( x: ArrayImpl, dst_sharding: Any, array_copy_semantics: ArrayCopySemantics, ) -> ArrayImpl: ... - def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... - def check_and_canonicalize_memory_kind( memory_kind: str | None, device_list: DeviceList ) -> str | None: ... - def array_result_handler( - aval: Any, - sharding: Any, - committed: bool, - _skip_checks: bool = ...) -> Callable: - ... + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... class CustomCallTargetTraits(enum.IntFlag): DEFAULT = 0 @@ -283,21 +132,16 @@ def register_custom_call_target( api_version: int = ..., traits: CustomCallTargetTraits = ..., ) -> None: ... - def register_custom_call_handler( xla_platform_name: str, handler: Any ) -> None: ... - def custom_call_targets(platform: str) -> dict[str, Any]: ... - def register_custom_type_id( type_name: str, type_id: Any, platform: str = ..., ) -> None: ... - def register_custom_type_id_handler(platform: str, handler: Any) -> None: ... - def encode_inspect_sharding_callback(handler: Any) -> bytes: ... register_custom_call_partitioner = _xla.register_custom_call_partitioner diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index add3ba9cfc15..13b903de6c31 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include -#include "absl/base/thread_annotations.h" #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/status/status.h" @@ -32,7 +31,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" @@ -51,7 +49,6 @@ limitations under the License. #include "xla/ffi/api/c_api.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" -#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -71,7 +68,6 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/service/call_inliner.h" @@ -80,7 +76,6 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/name_uniquer.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" @@ -92,60 +87,11 @@ limitations under the License. 
#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -namespace nanobind { -namespace detail { - -template <> -struct type_caster { - public: - NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::OpMetadata, - const_name("xla::OpMetadata")); - - bool from_python(handle h, uint8_t, cleanup_list*) noexcept { - handle op_type = getattr(h, "op_type"); - if (!op_type.is_none()) { - value.set_op_type(cast(op_type)); - } - handle op_name = getattr(h, "op_name"); - if (!op_name.is_none()) { - value.set_op_name(cast(op_name)); - } - handle source_file = getattr(h, "source_file"); - if (!source_file.is_none()) { - value.set_source_file(cast(source_file)); - } - handle source_line = getattr(h, "source_line"); - if (!source_line.is_none()) { - value.set_source_line(cast(source_line)); - } - return true; - } -}; - -} // namespace detail -} // namespace nanobind - namespace xla { namespace { namespace nb = nanobind; -struct Uniquer { - absl::Mutex mu; - NameUniquer name_uniquer ABSL_GUARDED_BY(mu); -}; - -Uniquer* GetUniquer() { - static Uniquer* uniquer = new Uniquer; - return uniquer; -} - -static std::string UniquifyName(const std::string& name) { - Uniquer* uniquer = GetUniquer(); - absl::MutexLock lock(&uniquer->mu); - return uniquer->name_uniquer.GetUniqueName(name); -} - // Converts a computation to a serialized HloModuleProto. absl::StatusOr GetComputationSerializedProto( const XlaComputation& computation) { @@ -945,54 +891,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { return result; })); - nb::class_ xla_op_class(m, "XlaOp"); - - nb::class_(m, "XlaBuilder") - .def("__init__", - [](XlaBuilder* self, const std::string& name) { - new (self) XlaBuilder(UniquifyName(name)); - }) - // TODO(phawkins): delete capitalized names after updating callers. - .def("Build", - xla::ValueOrThrowWrapper( - [](XlaBuilder& builder, std::optional root) { - return root ? builder.Build(*root) : builder.Build(); - }), - "Builds a computation from the contents of the builder.", - nb::arg("root") = std::nullopt) - .def("GetShape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) - .def("build", - xla::ValueOrThrowWrapper( - [](XlaBuilder& builder, std::optional root) { - return root ? builder.Build(*root) : builder.Build(); - }), - "Builds a computation from the contents of the builder.", - nb::arg("root") = std::nullopt) - .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) - .def("get_shape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) - .def( - "get_program_shape", - [](const XlaBuilder& builder, - std::optional root) -> absl::StatusOr { - return root ? 
builder.GetProgramShape(*root) - : builder.GetProgramShape(); - }, - nb::arg("root") = std::nullopt) - .def("is_constant", xla::ValueOrThrowWrapper(&XlaBuilder::IsConstant)) - .def("set_op_metadata", &XlaBuilder::SetOpMetadata) - .def("set_sharding", &XlaBuilder::SetSharding) - .def("clear_sharding", &XlaBuilder::ClearSharding) - .def("set_frontend_attributes", &XlaBuilder::SetFrontendAttributes) - .def("clear_frontend_attributes", &XlaBuilder::ClearFrontendAttributes) - .def("setup_alias", - [](XlaBuilder& builder, const std::vector& output_index, - int64_t param_number, const std::vector& param_index) { - builder.SetUpAlias( - ShapeIndex(output_index.begin(), output_index.end()), - param_number, - ShapeIndex(param_index.begin(), param_index.end())); - }); - // Device assignments nb::class_(m, "DeviceAssignment") .def_static( @@ -1595,27 +1493,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { [](const xla::HloSharding& self) { return self.ToString(); }) .def("to_proto", &xla::HloSharding::ToProto); - nb::class_ frontend_attributes(m, "FrontendAttributes"); - frontend_attributes.def(nb::init<>()) - .def("__setitem__", - [](FrontendAttributes* attr, std::string key, std::string value) { - (*attr->mutable_map())[key] = value; - }); - - nb::enum_(m, "PrecisionConfig_Precision") - .value("DEFAULT", PrecisionConfig::DEFAULT) - .value("HIGH", PrecisionConfig::HIGH) - .value("HIGHEST", PrecisionConfig::HIGHEST); - - nb::enum_(m, "ResultAccuracy_Mode") - .value("DEFAULT", ResultAccuracy::DEFAULT) - .value("HIGHEST", ResultAccuracy::HIGHEST); - - nb::enum_(m, "FftType") - .value("FFT", FftType::FFT) - .value("IFFT", FftType::IFFT) - .value("RFFT", FftType::RFFT) - .value("IRFFT", FftType::IRFFT); // Hlo Module Passes nb::class_ hlo_pass_interface(m, "HloPassInterface"); diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 79d4dc038fc2..08052b041bae 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -36,7 +36,9 @@ def setUp(self): raise SkipTest("infeed not implemented in PJRT C API") super().setUp() - @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. + @jax.numpy_rank_promotion( + "allow" + ) # Test explicitly exercises implicit rank promotion. 
def testInfeed(self): raise SkipTest("skipping temporarily for stackless") @@ -44,13 +46,17 @@ def testInfeed(self): def f(x): token = lax.create_token(x) (y,), token = lax.infeed( - token, shape=(core.ShapedArray((3, 4), jnp.float32),)) + token, shape=(core.ShapedArray((3, 4), jnp.float32),) + ) (z,), _ = lax.infeed( - token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),)) + token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),) + ) return x + y + z x = np.float32(1.5) - y = np.reshape(np.arange(12, dtype=np.float32), (3, 4)) # self.rng().randn(3, 4).astype(np.float32) + y = np.reshape( + np.arange(12, dtype=np.float32), (3, 4) + ) # self.rng().randn(3, 4).astype(np.float32) z = self.rng().randn(3, 1, 1).astype(np.float32) device = jax.local_devices()[0] device.transfer_to_infeed((y,)) @@ -63,8 +69,11 @@ def testInfeedPytree(self): x = np.float32(1.5) y = np.reshape(np.arange(12, dtype=np.int16), (3, 4)) to_infeed = dict(a=x, b=y) - to_infeed_shape = dict(a=core.ShapedArray((), dtype=np.float32), - b=core.ShapedArray((3, 4), dtype=np.int16)) + to_infeed_shape = dict( + a=core.ShapedArray((), dtype=np.float32), + b=core.ShapedArray((3, 4), dtype=np.int16), + ) + @jax.jit def f(x): token = lax.create_token(x) @@ -77,16 +86,18 @@ def f(x): device.transfer_to_infeed(tuple(flat_to_infeed)) self.assertAllClose(f(x), to_infeed) - @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. - @jtu.ignore_warning(category=DeprecationWarning, - message=".*(infeed|outfeed) was deprecated.*") + @jax.numpy_rank_promotion( + "allow" + ) # Test explicitly exercises implicit rank promotion. + @jtu.ignore_warning( + category=DeprecationWarning, message=".*(infeed|outfeed) was deprecated.*" + ) def testInfeedThenOutfeed(self): @jax.jit def f(x): token = lax.create_token(x) - y, token = lax.infeed( - token, shape=core.ShapedArray((3, 4), jnp.float32)) + y, token = lax.infeed(token, shape=core.ShapedArray((3, 4), jnp.float32)) token = lax.outfeed(token, y + np.float32(1)) return x - 1 @@ -96,18 +107,21 @@ def f(x): execution.start() device = jax.local_devices()[0] device.transfer_to_infeed((y,)) - out, = device.transfer_from_outfeed( - xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent()) + out = device.transfer_from_outfeed( + xla_client.Shape.array_shape( + xla_client.PrimitiveType.F32, (3, 4) + ).with_major_to_minor_layout_if_absent() + ) execution.join() self.assertAllClose(out, y + np.float32(1)) - @jtu.ignore_warning(category=DeprecationWarning, - message=".*(infeed|outfeed) was deprecated.*") + @jtu.ignore_warning( + category=DeprecationWarning, message=".*(infeed|outfeed) was deprecated.*" + ) def testInfeedThenOutfeedInALoop(self): def doubler(_, token): - y, token = lax.infeed( - token, shape=core.ShapedArray((3, 4), jnp.float32)) + y, token = lax.infeed(token, shape=core.ShapedArray((3, 4), jnp.float32)) return lax.outfeed(token, y * np.float32(2)) @jax.jit @@ -123,11 +137,14 @@ def f(n): for _ in range(n): x = self.rng().randn(3, 4).astype(np.float32) device.transfer_to_infeed((x,)) - y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,)) - .with_major_to_minor_layout_if_absent()) + y = device.transfer_from_outfeed( + xla_client.Shape.array_shape( + xla_client.PrimitiveType.F32, (3, 4) + ).with_major_to_minor_layout_if_absent() + ) self.assertAllClose(y, x * np.float32(2)) execution.join() -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pjit_test.py 
b/tests/pjit_test.py index 9174beeb2525..03f5dd13c834 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -902,8 +902,11 @@ def _dispatch(): def check_outfeed(x_fn): for didx, d in enumerate(devices): x = x_fn(didx) - y, = d.transfer_from_outfeed( - xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent()) + y = d.transfer_from_outfeed( + xc.Shape.array_shape( + xc.PrimitiveType.F32, x.shape + ).with_major_to_minor_layout_if_absent() + ) self.assertAllClose(x, y, check_dtypes=True) logging.info('Transferring from outfeed for the pjit call') From 5263f0f869dcc0e0c2fc7b4e0ebe0d35b4c00613 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 25 Apr 2025 15:41:02 -0700 Subject: [PATCH 0846/1769] [JAX] [XLA:Python] Move ShapeIndex bindings out of JAX and into XLA. JAX does not use this class any more. PiperOrigin-RevId: 751584043 --- jaxlib/_jax/__init__.pyi | 105 ++++++++++++++++++++++----------------- jaxlib/xla_client.py | 22 -------- jaxlib/xla_compiler.cc | 13 ----- 3 files changed, 59 insertions(+), 81 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index fd9b1e068c74..9c5b493f3d60 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -16,11 +16,11 @@ from __future__ import annotations import builtins +from collections.abc import Callable, Iterator, Mapping, Sequence import enum import inspect import types from typing import Any, ClassVar, TypeVar, overload -from collections.abc import Callable, Mapping, Iterator, Sequence import numpy as np @@ -94,9 +94,12 @@ class Layout: @overload def __init__(self, minor_to_major: tuple[int, ...]): ... @overload - def __init__(self, minor_to_major: tuple[int, ...], - tiling: tuple[tuple[int, ...], ...], - element_size_in_bits: int): ... + def __init__( + self, + minor_to_major: tuple[int, ...], + tiling: tuple[tuple[int, ...], ...], + element_size_in_bits: int, + ): ... def minor_to_major(self) -> tuple[int, ...]: ... def tiling(self) -> Sequence[tuple[int, ...]]: ... def element_size_in_bits(self) -> int: ... @@ -148,13 +151,6 @@ class ProgramShape: def result_shape(self) -> Shape: ... def __repr__(self) -> str: ... -class ShapeIndex: - def __init__(self, indices: list[int]) -> None: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... - def __hash__(self) -> int: ... - def __repr__(self) -> str: ... - class Literal: def __init__(self, shape: Shape) -> None: ... def __repr__(self) -> str: ... @@ -253,7 +249,10 @@ class CompileOptions: env_option_overrides: list[tuple[str, str]] def register_custom_call_target( - fn_name: str, capsule: Any, platform: str, api_version: int = ..., + fn_name: str, + capsule: Any, + platform: str, + api_version: int = ..., ) -> _Status: ... def register_custom_call_partitioner( name: str, @@ -268,7 +267,6 @@ def register_custom_call_as_batch_partitionable( target_name: str, c_api: Any | None = ..., ) -> None: ... - def register_custom_type_id(type_name: str, type_id: Any) -> None: ... class AutotuneCacheMode(enum.IntEnum): @@ -346,7 +344,9 @@ class ExecutableBuildOptions: auto_spmd_partitioning_mesh_shape: list[int] auto_spmd_partitioning_mesh_ids: list[int] use_shardy_partitioner: bool - def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ... + def compilation_environments_from_serialized_proto( + self, serialized_proto: bytes + ) -> None: ... class OpSharding_Type(enum.IntEnum): REPLICATED = ... @@ -402,8 +402,8 @@ class HloSharding: def unknown() -> HloSharding: ... 
@staticmethod def subgroup_with_device_ordering( - tile_assignment: np.ndarray, - subgroup_types: Sequence[OpSharding_Type]) -> HloSharding: ... + tile_assignment: np.ndarray, subgroup_types: Sequence[OpSharding_Type] + ) -> HloSharding: ... def __eq__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... @@ -549,7 +549,6 @@ class MpiCollectives(CpuCollectives): def Finalize(self): ... def make_mpi_collectives() -> MpiCollectives: ... - def get_tfrt_cpu_client( asynchronous: bool = ..., distributed_client: DistributedRuntimeClient | None = ..., @@ -593,7 +592,9 @@ def get_c_api_topology( options: dict[str, str | int | list[int] | float], ) -> DeviceTopology: ... def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: ... -def load_pjrt_plugin(platform_name: str, library_path: str | None, c_api: Any | None) -> _Status: ... +def load_pjrt_plugin( + platform_name: str, library_path: str | None, c_api: Any | None +) -> _Status: ... def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... def initialize_pjrt_plugin(platform_name: str) -> _Status: ... @@ -634,9 +635,7 @@ def batched_copy_array_to_devices_with_sharding( sharding: Sequence[Any], array_copy_semantics: Sequence[ArrayCopySemantics], ) -> Sequence[ArrayImpl]: ... - def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... - def batched_device_put( aval: Any, sharding: Any, @@ -644,13 +643,11 @@ def batched_device_put( devices: list[Device], committed: bool = True, ) -> ArrayImpl: ... - def reorder_shards( x: ArrayImpl, dst_sharding: Any, array_copy_semantics: ArrayCopySemantics, ) -> ArrayImpl: ... - def check_and_canonicalize_memory_kind( memory_kind: str | None, device_list: DeviceList ) -> str | None: ... @@ -724,18 +721,23 @@ def dlpack_managed_tensor_to_buffer( tensor: Any, device: Device, stream: int | None ) -> ArrayImpl: ... @overload -def dlpack_managed_tensor_to_buffer( # Legacy overload +def dlpack_managed_tensor_to_buffer( # Legacy overload tensor: Any, cpu_backend: Client | None = ..., gpu_backend: Client | None = ..., ) -> ArrayImpl: ... - def cuda_array_interface_to_buffer( - cai: dict[str, ( - str | int | None | - tuple[int, ...] | tuple[int, bool] | - list[tuple[str, str]] | - list[tuple[str, str, tuple[int, ...]]]) + cai: dict[ + str, + ( + str + | int + | None + | tuple[int, ...] + | tuple[int, bool] + | list[tuple[str, str]] + | list[tuple[str, str, tuple[int, ...]]] + ), ], gpu_backend: Client | None = ..., device_id: int | None = None, @@ -748,11 +750,13 @@ class Frame: function_name: str function_line_start: int line_num: int - def __init__(self, - file_name: str, - function_name: str, - function_line_start: int, - line_num: int): ... + def __init__( + self, + file_name: str, + function_name: str, + function_line_start: int, + line_num: int, + ): ... def __repr__(self) -> str: ... class Traceback: @@ -790,13 +794,19 @@ class DistributedRuntimeClient: def key_value_try_get_bytes(self, key: str) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... - def key_value_set(self, key: str, value: str, - allow_overwrite: bool = False) -> _Status: ... - def key_value_set_bytes(self, key: str, value: bytes, - allow_overwrite: bool = False) -> _Status: ... + def key_value_set( + self, key: str, value: str, allow_overwrite: bool = False + ) -> _Status: ... 
+ def key_value_set_bytes( + self, key: str, value: bytes, allow_overwrite: bool = False + ) -> _Status: ... def key_value_delete(self, key: str) -> _Status: ... - def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int, - process_ids: list[int] | None = None) -> _Status: ... + def wait_at_barrier( + self, + barrier_id: str, + timeout_in_ms: int, + process_ids: list[int] | None = None, + ) -> _Status: ... def get_live_nodes(self, process_ids: list[int]) -> _Status: ... def get_distributed_runtime_service( @@ -970,22 +980,25 @@ def is_tsan() -> bool: ... def is_sanitized() -> bool: ... class TransferConnection: - def address(self) -> str: ... - def _pull_flat(self, uuid, backend, avals_flat) -> list[Any]: ... class TransferServer: def _await_pull_flat(self, uuid, args: list[ArrayImpl]): ... - def connect(self, address: str) -> TransferConnection: ... -def start_transfer_server(client: Client, address: str = "", transport_addresses: list[str] = [], max_num_parallel_copies: int = 0, transfer_size: int = 0) -> TransferServer: ... - +def start_transfer_server( + client: Client, + address: str = "", + transport_addresses: list[str] = [], + max_num_parallel_copies: int = 0, + transfer_size: int = 0, +) -> TransferServer: ... def approx_top_k_reduction_output_size( input_size: int, rank: int, top_k: int, recall_target: float, aggregate_to_topk: bool | None = ..., - input_size_override: int | None = ...) -> tuple[int, int]: ... + input_size_override: int | None = ..., +) -> tuple[int, int]: ... diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index f63dfbe471dc..4521019fa77a 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -251,28 +251,6 @@ def result_shape(self) -> Shape: def __repr__(self): """ -ShapeIndex = _xla.ShapeIndex -ShapeIndex.__doc__ = """ -A Shape is an object defined in C++ that duck types like the following class: - -class ShapeIndex: - '''Represents an XLA ShapeIndex. - - An index for specifying a particular nested subshape within a shape. Used in - ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through - the Shape tree where each element of ShapeIndex indexes into a tuple (or - nested tuple) within the shape. For a non-nested tuple, an index has a single - element. - ''' - - def __init__(self, List[int]) -> ShapeIndex: - def __eq__(self, other: Shape) -> bool: - def __ne__(self, other: Shape) -> bool: - def __hash__(self): - def __repr__(self): -""" - - DeviceAssignment = _xla.DeviceAssignment DeviceAssignment.__doc__ = """ A DeviceAssignment is a C++ object with the following signature. 
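The docstring deleted above was the only JAX-side description of ShapeIndex. For reference, a pure-Python stand-in with the same duck-typed surface (constructor from a list of indices, value equality, hashing), written only to illustrate the contract that now lives in XLA's own bindings:

class ShapeIndex:
    """Path of tuple indices selecting a nested subshape within a Shape."""

    def __init__(self, indices):
        self._indices = tuple(int(i) for i in indices)

    def __eq__(self, other):
        return isinstance(other, ShapeIndex) and self._indices == other._indices

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash(self._indices)

    def __repr__(self):
        return f"ShapeIndex({list(self._indices)})"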
diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index 13b903de6c31..9b1377a1a9d1 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -647,19 +647,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { .def("result_shape", &ProgramShape::result) .def("__repr__", &ProgramShape::ToString); - nb::class_(m, "ShapeIndex") - .def("__init__", - [](ShapeIndex* self, const std::vector& v) { - new (self) ShapeIndex(v.begin(), v.end()); - }) - .def("__repr__", &ShapeIndex::ToString) - .def("__eq__", [](const ShapeIndex& shape_ind, - const ShapeIndex& other) { return shape_ind == other; }) - .def("__ne__", [](const ShapeIndex& shape_ind, - const ShapeIndex& other) { return shape_ind != other; }) - .def("__hash__", - [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); }); - // Literals nb::class_(m, "Literal") .def(nb::init()) From 351c13e5b6dfec90863ef64f137a28da7426f38a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 26 Apr 2025 05:25:18 -0700 Subject: [PATCH 0847/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4bf2349e4b17d82bfdd9f3ae61631813916d01a2. PiperOrigin-RevId: 751746177 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6240e79aeb22..c9cad29b0956 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ee9ee727b533dbd14698c9eda979a8c83ed86e11" -XLA_SHA256 = "63ebc70aa209ada6b29faea67c196f9c1237e14bb381b2c014db0405b80881ec" +XLA_COMMIT = "4bf2349e4b17d82bfdd9f3ae61631813916d01a2" +XLA_SHA256 = "d11cd6b6de56204b26b1ba705b2b35803297e26a6c81828d4f871ecaafd97ac9" def repo(): tf_http_archive( From 110c65849996c40e18f87f778cea9d1c93f35dee Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 27 Apr 2025 05:55:27 -0700 Subject: [PATCH 0848/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e039a21a6ffa80bab553479d4a62054c08316e6e. PiperOrigin-RevId: 751992751 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c9cad29b0956..420f93b36716 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
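The comment above documents the manual curl | sha256sum refresh step. An equivalent Python sketch, using the commit this patch pins (illustrative only; the canonical workflow is the shell one-liner in the comment):

import hashlib
import urllib.request

XLA_COMMIT = "e039a21a6ffa80bab553479d4a62054c08316e6e"  # pinned by this patch
url = f"https://github.com/openxla/xla/archive/{XLA_COMMIT}.tar.gz"
with urllib.request.urlopen(url) as resp:
    print(hashlib.sha256(resp.read()).hexdigest())  # value for XLA_SHA256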
-XLA_COMMIT = "4bf2349e4b17d82bfdd9f3ae61631813916d01a2" -XLA_SHA256 = "d11cd6b6de56204b26b1ba705b2b35803297e26a6c81828d4f871ecaafd97ac9" +XLA_COMMIT = "e039a21a6ffa80bab553479d4a62054c08316e6e" +XLA_SHA256 = "c1a353cee867f1f75079f9ca8bacfb5831ce6eb740382dd393b084c984184f34" def repo(): tf_http_archive( From d6efd0546819ed08483f2c2ac4566522daccd873 Mon Sep 17 00:00:00 2001 From: Michal Date: Sun, 27 Apr 2025 16:46:14 +0200 Subject: [PATCH 0849/1769] Improving PyTorch data loading notebook Simplifying DataLoader object creation --- .../Neural_Network_and_Data_Loading.ipynb | 272 +++++++++--------- .../Neural_Network_and_Data_Loading.md | 41 +-- 2 files changed, 145 insertions(+), 168 deletions(-) diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index a7ef2a017048..4c9b6c5e48a7 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": { "id": "OksHydJDtbbI" }, @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "id": "-fmWA06xYE7d" }, @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "id": "7APc6tD7TiuZ" }, @@ -136,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "id": "4sW2A5mnXHc5", "outputId": "9d3b29e8-fab3-4ecb-9f63-bc8c092f9006" @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "id": "PpyQxuedXfhp", "outputId": "d5d20211-b6da-44e9-f71e-946f2a9d0fc4" @@ -184,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "id": "oJOOncKMXbwK", "outputId": "31285fab-7667-4871-fcba-28e86adc3fc6" @@ -229,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "id": "6lTI6I4lWdh5" }, @@ -268,21 +268,37 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": { "id": "gEvWt8_u2pqG", "outputId": "2c83a679-9ce5-4c67-bccb-9ea835a8eaf6" }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/pty.py:95: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)\n", - "Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)\n", - "Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)\n", - "Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)\n", - "Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)\n" + "Requirement already satisfied: torch in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (2.4.1)\n", + "Requirement already satisfied: torchvision in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (0.19.1)\n", + "Requirement already satisfied: filelock in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.16.0)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (1.13.2)\n", + "Requirement already satisfied: networkx in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (2024.9.0)\n", + "Requirement already satisfied: setuptools in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (73.0.1)\n", + "Requirement already satisfied: numpy in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torchvision) (1.26.4)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torchvision) (10.4.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from sympy->torch) (1.3.0)\n" ] } ], @@ -292,7 +308,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": { "cellView": "both", "id": "94PjXZ8y3dVF" @@ -301,38 +317,24 @@ "source": [ "import numpy as np\n", "from jax.tree_util import tree_map\n", - "from torch.utils import data\n", + "from torch.utils.data import DataLoader, default_collate\n", "from torchvision.datasets import MNIST\n", "\n", "def numpy_collate(batch):\n", - " return tree_map(np.asarray, data.default_collate(batch))\n", - "\n", - "class NumpyLoader(data.DataLoader):\n", - " def __init__(self, dataset, batch_size=1,\n", - " shuffle=False, sampler=None,\n", - " batch_sampler=None, num_workers=0,\n", - " pin_memory=False, drop_last=False,\n", - " timeout=0, worker_init_fn=None):\n", - " super(self.__class__, self).__init__(dataset,\n", - " batch_size=batch_size,\n", - " shuffle=shuffle,\n", - " sampler=sampler,\n", - " batch_sampler=batch_sampler,\n", - " num_workers=num_workers,\n", - " collate_fn=numpy_collate,\n", 
- " pin_memory=pin_memory,\n", - " drop_last=drop_last,\n", - " timeout=timeout,\n", - " worker_init_fn=worker_init_fn)\n", + " \"\"\"\n", + " Collate function specifies how to combine a list of data samples into a batch.\n", + " default_collate creates pytorch tensors, then tree_map converts them into numpy arrays.\n", + " \"\"\"\n", + " return tree_map(np.asarray, default_collate(batch))\n", "\n", - "class FlattenAndCast(object):\n", - " def __call__(self, pic):\n", - " return np.ravel(np.array(pic, dtype=jnp.float32))" + "def flatten_and_cast(pic):\n", + " \"\"\"Convert PIL image to flat (1-dimensional) numpy array.\"\"\"\n", + " return np.ravel(np.array(pic, dtype=jnp.float32))" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": { "id": "l314jsfP4TN4" }, @@ -341,108 +343,110 @@ "name": "stdout", "output_type": "stream", "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz\n" + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "75806ce83ace4f69b81bbc4251c5573f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "274ed4ab05f34f70b7a5bb6cf427ffd0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to 
/tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d38fa4eabf3c4d4494eb59e078ac94e8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "523ac9565c5f4509a1ee8fdbb1e6d66d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Processing...\n", - "Done!\n" + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" ] } ], "source": [ "# Define our dataset, using torch datasets\n", - "mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())\n", - "training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" + "mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast)\n", + "# Create pytorch data loader with custom collate function\n", + "training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": { "id": "FTNo4beUvb6t", "outputId": "65a9087c-c326-49e5-cbfc-e0839212fa31" @@ -452,27 +456,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:76: UserWarning: train_data has been renamed data\n", " warnings.warn(\"train_data has been renamed data\")\n", - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets\n", - " warnings.warn(\"train_labels has been renamed targets\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:66: UserWarning: train_labels has been renamed targets\n", + " warnings.warn(\"train_labels has been renamed 
targets\")\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:81: UserWarning: test_data has been renamed data\n", " warnings.warn(\"test_data has been renamed data\")\n", - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:71: UserWarning: test_labels has been renamed targets\n", " warnings.warn(\"test_labels has been renamed targets\")\n" ] } @@ -499,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": { "id": "X2DnZo3iYj18", "outputId": "0eba3ca2-24a1-4cba-aaf4-3ac61d0c650e" @@ -509,30 +499,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0 in 55.15 sec\n", - "Training set accuracy 0.9157500267028809\n", - "Test set accuracy 0.9195000529289246\n", - "Epoch 1 in 42.26 sec\n", - "Training set accuracy 0.9372166991233826\n", - "Test set accuracy 0.9384000301361084\n", - "Epoch 2 in 44.37 sec\n", - "Training set accuracy 0.9491666555404663\n", - "Test set accuracy 0.9469000697135925\n", - "Epoch 3 in 41.75 sec\n", - "Training set accuracy 0.9568166732788086\n", - "Test set accuracy 0.9534000158309937\n", - "Epoch 4 in 41.16 sec\n", - "Training set accuracy 0.9631333351135254\n", - "Test set accuracy 0.9577000737190247\n", - "Epoch 5 in 38.89 sec\n", + "Epoch 0 in 5.53 sec\n", + "Training set accuracy 0.9156666994094849\n", + "Test set accuracy 0.9199000000953674\n", + "Epoch 1 in 1.13 sec\n", + "Training set accuracy 0.9370499849319458\n", + "Test set accuracy 0.9383999705314636\n", + "Epoch 2 in 1.12 sec\n", + "Training set accuracy 0.9490833282470703\n", + "Test set accuracy 0.9467999935150146\n", + "Epoch 3 in 1.21 sec\n", + "Training set accuracy 0.9568833708763123\n", + "Test set accuracy 0.9532999992370605\n", + "Epoch 4 in 1.17 sec\n", + "Training set accuracy 0.9631666541099548\n", + "Test set accuracy 0.9574999809265137\n", + "Epoch 5 in 1.17 sec\n", "Training set accuracy 0.9675000309944153\n", - "Test set accuracy 0.9616000652313232\n", - "Epoch 6 in 40.68 sec\n", - "Training set accuracy 0.9708333611488342\n", - "Test set accuracy 0.9650000333786011\n", - "Epoch 7 in 41.50 sec\n", - "Training set accuracy 0.973716676235199\n", - "Test set accuracy 0.9672000408172607\n" + "Test set accuracy 0.9615999460220337\n", + "Epoch 6 in 1.11 sec\n", + "Training set accuracy 0.9709500074386597\n", + "Test set accuracy 0.9652999639511108\n", + "Epoch 7 in 1.17 sec\n", + "Training set accuracy 0.9736999869346619\n", + "Test set accuracy 0.967199981212616\n" ] } ], @@ -576,7 +566,7 @@ "formats": "ipynb,md:myst" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -590,9 +580,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.12.3" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index cd98022e7421..bcc4019d6da0 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -7,7 +7,7 @@ jupytext: format_version: 0.13 jupytext_version: 1.16.4 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 --- @@ -192,41 +192,28 @@ 
JAX is laser-focused on program transformations and accelerator-backed NumPy, so
 
 import numpy as np
 from jax.tree_util import tree_map
-from torch.utils import data
+from torch.utils.data import DataLoader, default_collate
 from torchvision.datasets import MNIST
 
 def numpy_collate(batch):
-  return tree_map(np.asarray, data.default_collate(batch))
-
-class NumpyLoader(data.DataLoader):
-  def __init__(self, dataset, batch_size=1,
-                shuffle=False, sampler=None,
-                batch_sampler=None, num_workers=0,
-                pin_memory=False, drop_last=False,
-                timeout=0, worker_init_fn=None):
-    super(self.__class__, self).__init__(dataset,
-        batch_size=batch_size,
-        shuffle=shuffle,
-        sampler=sampler,
-        batch_sampler=batch_sampler,
-        num_workers=num_workers,
-        collate_fn=numpy_collate,
-        pin_memory=pin_memory,
-        drop_last=drop_last,
-        timeout=timeout,
-        worker_init_fn=worker_init_fn)
-
-class FlattenAndCast(object):
-  def __call__(self, pic):
-    return np.ravel(np.array(pic, dtype=jnp.float32))
+  """
+  Collate function specifies how to combine a list of data samples into a batch.
+  default_collate creates pytorch tensors, then tree_map converts them into numpy arrays.
+  """
+  return tree_map(np.asarray, default_collate(batch))
+
+def flatten_and_cast(pic):
+  """Convert PIL image to flat (1-dimensional) numpy array."""
+  return np.ravel(np.array(pic, dtype=jnp.float32))
```

```{code-cell} ipython3
:id: l314jsfP4TN4

# Define our dataset, using torch datasets
-mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
-training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
+mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast)
+# Create pytorch data loader with custom collate function
+training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate)
```

```{code-cell} ipython3

From 11a2a3b494ccf7bb7e40eab8760f7c43248e6dce Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Mon, 28 Apr 2025 02:38:16 -0700
Subject: [PATCH 0850/1769] [Mosaic GPU] Add a `BroadcastInDim` op in the Mosaic MLIR dialect.

We replace the Pallas lowering for warp-group semantics with this new op.
The new op also handles one extra case: broadcasting to a single minor dimension.

The concrete new functionality we need for now is to be able to expand a wgmma row (64) to a wgmma tile (64,64) by adding a minor dimension.

I tried an alternative way of expressing this using the vector dialect with `vector.shape_cast` for `(64) -> (64,1)` + `vector.broadcast` for `(64,1) -> (64,64)`. That would work, but handling the intermediate state `(64,1)` requires non-trivial changes to `TileLayout`, `FragmentedArray` and layout inference and I wasn't convinced we want those.
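For illustration, a minimal sketch of the Pallas-level pattern this op serves; the kernel context is assumed and omitted, and this snippet is not part of the change itself:

    import jax
    from jax import lax
    import jax.numpy as jnp

    def expand_row(row):
      # Add a minor dimension: f32[64] -> f32[64, 64]. Under warp-group
      # semantics this broadcast now lowers to the new dialect op.
      return lax.broadcast_in_dim(row, shape=(64, 64), broadcast_dimensions=(0,))

    # Shape-only check, no GPU needed:
    print(jax.eval_shape(expand_row, jax.ShapeDtypeStruct((64,), jnp.float32)))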
PiperOrigin-RevId: 752225526 --- jax/_src/pallas/mosaic_gpu/lowering.py | 30 +++-- .../mosaic/gpu/dialect_lowering.py | 31 +++++ .../mosaic/gpu/layout_inference.py | 70 +++++++++++ jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 42 +++++++ jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 23 ++++ tests/mosaic/gpu_dialect_test.py | 110 ++++++++++++++++++ tests/mosaic/gpu_layout_inference_test.py | 44 +++++++ tests/mosaic/gpu_test.py | 61 ++++++++++ 8 files changed, 403 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5801441fdaf8..a955effdc31d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -30,6 +30,7 @@ from jax import api_util from jax import lax from jax._src import core as jax_core +from jax._src import lib as jaxlib from jax._src import linear_util as lu from jax._src import pjit from jax._src import source_info_util @@ -1456,21 +1457,34 @@ def _broadcast_in_dim_lowering_rule( lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Warpgroup) def _broadcast_in_dim_lowering_rule_wg( ctx: LoweringRuleContext, - x: ir.Value, + x, *, broadcast_dimensions, shape, sharding, ): del sharding - if broadcast_dimensions: - raise NotImplementedError + [x_aval] = ctx.avals_in - x = _ensure_ir_value(x, x_aval.dtype) - return vector_dialect.splat( - ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)), - x, - ) + + if not broadcast_dimensions: + # Even though we could implement this case by passing a 0D vector as input + # to mgpu.dialect.BroadcastInDimOp we don't want that. 0D vectors are + # generally problematic and so we avoid them by specializing that case + # directly here. + x = _ensure_ir_value(x, x_aval.dtype) + return vector_dialect.splat( + ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)), + x, + ) + + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. + if jaxlib.version < (0, 6, 1): + raise NotImplementedError() + + mlir_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype) + result_ty = ir.VectorType.get(shape, mlir_type) + return mgpu.dialect.broadcast_in_dim(result_ty, x, broadcast_dimensions) @register_lowering_rule(lax.convert_element_type_p, mgpu.LoweringSemantics.Lane) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index f2acd14cac72..89339bf4fb2f 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -22,6 +22,7 @@ import operator from typing import Any, Sequence, Type, cast +from jax._src import lib as jaxlib from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir @@ -484,6 +485,36 @@ def _mgpu_layout_cast_op_lowering_rule( return [layout_cast_op.x] +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. 
+if jaxlib.version >= (0, 6, 1): + @_register_lowering(mgpu.BroadcastInDimOp) + def _mgpu_broadcast_in_dim_op_lowering_rule( + _: LoweringContext, op: mgpu.BroadcastInDimOp + ) -> Sequence[ir.Value]: + in_ty = ir.VectorType(op.operand.type) + out_ty = ir.VectorType(op.result.type) + if len(in_ty.shape) != 1 or len(out_ty.shape) != 2: + raise NotImplementedError( + "Broadcast in dim with non-trivial broadcast dimensions is not" + f" supported: {op}" + ) + + broadcast_dims = list(op.broadcast_dimensions) + in_layout = inference_utils.in_layouts(op)[0] + operand_fa = _fragmented_array_from_ir(op.operand, in_layout) + + if (operand_fa.layout == fa.WGMMA_ROW_LAYOUT and broadcast_dims == [0]): + out = operand_fa.broadcast_minor(out_ty.shape[1]) + elif (operand_fa.layout == fa.WGMMA_COL_LAYOUT and broadcast_dims == [1]): + out = operand_fa.broadcast_major(out_ty.shape[0]) + else: + raise NotImplementedError( + "Broadcast in dim with non-trivial broadcast dimensions is not" + f" supported: {op}" + ) + return [_fragmented_array_to_ir(out, out_ty)] + + def swizzle_and_transforms_from_transforms_attr( transforms: ir.ArrayAttr, ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]: diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index ee7571b1db72..ae082e4deb9b 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -21,6 +21,7 @@ import math from typing import cast +from jax._src import lib as jaxlib from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -450,6 +451,75 @@ def _infer_layout_cast_op_layout( return [layout_cast_op.new_layout], [layout_cast_op.new_layout] +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. +if jaxlib.version >= (0, 6, 1): + @partial(_add_layout_inference_rule, mgpu.BroadcastInDimOp) + def _infer_broadcast_in_dim_op_layout( + op: mgpu.BroadcastInDimOp, + ) -> OptionalLayouts: + if inference_utils.has_any_layout_set(op): + op_in_layouts = list(inference_utils.in_layouts(op)) + op_out_layouts = list(inference_utils.out_layouts(op)) + return op_in_layouts, op_out_layouts + + in_ty = ir.VectorType(op.operand.type) + out_ty = ir.VectorType(op.result.type) + if len(in_ty.shape) != 1 or len(out_ty.shape) != 2: + raise NotImplementedError( + "Broadcast in dim with non-trivial broadcast dimensions is not" + f" supported: {op}" + ) + + # Find out the layout of the output from the consumers. 
+ user_layouts = set() + for use in cast(ir.OpResult, op.result).uses: + consumer = use.owner + operand = consumer.operands[use.operand_number] + layout = inference_utils.in_layout_for_operand(consumer, operand) + if layout is not None: + user_layouts.add(layout) + if user_layouts: + out_layout = _choose_representative_layout(user_layouts) + + if out_layout is None: + raise ValueError(f"Could not choose a best layout from {user_layouts}") + + if out_layout != layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT): + raise NotImplementedError(f"Unsupported layout: {out_layout}") + + broadcast_dims = list(op.broadcast_dimensions) + if broadcast_dims == [0]: + in_layout = layouts_lib.to_layout_attr(fa.WGMMA_ROW_LAYOUT) + elif broadcast_dims == [1]: + in_layout = layouts_lib.to_layout_attr(fa.WGMMA_COL_LAYOUT) + else: + raise ValueError(f"Invalid broadcast dimensions: {broadcast_dims}") + + return [in_layout], [out_layout] + + # The consumers did not have any layouts set. Find out the layout of the + # input and infer the output layout from it. + in_layout = inference_utils.value_layout(op.operand) + if in_layout is None: + return None + + broadcast_dims = list(op.broadcast_dimensions) + if ( + broadcast_dims == [0] + and in_layout == layouts_lib.to_layout_attr(fa.WGMMA_ROW_LAYOUT) + ) or ( + broadcast_dims == [1] + and in_layout == layouts_lib.to_layout_attr(fa.WGMMA_COL_LAYOUT) + ): + out_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) + return [in_layout], [out_layout] + else: + raise NotImplementedError( + f"Unsupported layout: {in_layout} for broadcast dimensions" + f" {broadcast_dims}" + ) + + @partial(_add_layout_inference_rule, mgpu.WGMMAOp) def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 073697df58ef..9d6085397493 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -407,6 +408,47 @@ llvm::LogicalResult CustomPrimitiveOp::verify() { return llvm::success(); } +llvm::LogicalResult BroadcastInDimOp::verify() { + auto error = [this](auto... 
params) {
+    return emitOpError(llvm::formatv(params...));
+  };
+
+  auto operand_type = mlir::cast<mlir::VectorType>(getOperand().getType());
+  auto result_type = mlir::cast<mlir::VectorType>(getResult().getType());
+
+  if (operand_type.getRank() == 0) {
+    return error("The input vector must have rank > 0.");
+  }
+
+  if (operand_type.getRank() > result_type.getRank()) {
+    return error(
+        "The rank of the input vector must be smaller or equal to the rank "
+        "of the result vector.");
+  }
+
+  if (operand_type.getRank() != getBroadcastDimensions().size()) {
+    return error(
+        "The size of the `broadcast_dimensions` attribute must be equal to "
+        "the rank of the input vector.");
+  }
+  auto dims = llvm::to_vector(getBroadcastDimensions());
+  for (int i = 0; i < dims.size(); ++i) {
+    if (dims[i] < 0 || dims[i] >= result_type.getRank()) {
+      return error(
+          "The values in the `broadcast_dimensions` attribute must be in the "
+          "range [0, result.shape.rank={0}).",
+          result_type.getRank());
+    }
+    if (i > 0 && dims[i] <= dims[i - 1]) {
+      return error(
+          "The values in the `broadcast_dimensions` attribute must be strictly "
+          "increasing.");
+    }
+  }
+
+  return llvm::success();
+}
+
 void MosaicGPUDialect::initialize() {
   addTypes<
 #define GET_TYPEDEF_LIST
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 4ff17d5c99cf..1465f76aa7bf 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -374,6 +374,29 @@ def MosaicGPU_LayoutCastOp : Op {
+def MosaicGPU_BroadcastInDimOp : Op {
+  let summary = "Broadcasts a vector to a new shape.";
+  let description = [{
+    `broadcast_dimensions` must have the same size as the rank of the input
+    vector and for each input dimension, specifies which output dimension it
+    corresponds to.
+  }];
+
+  let arguments = (ins
+    AnyVectorOfAnyRank:$operand,
+
+    // Attributes
+    DenseI64ArrayAttr:$broadcast_dimensions
+  );
+
+  let results = (outs AnyVectorOfAnyRank);
+  let assemblyFormat = [{
+    `(` $operand `:` type($operand) `)` attr-dict `->` type(results)
+  }];
+  let hasVerifier = 1;
+}
+
+
 def MosaicGPU_SliceSMEMOp : Op {
   let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address.";
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
index 444830a3d75a..af8b296fe536 100644
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -605,6 +605,116 @@ def test_tiled_layout_attr_parsing(self):
     parsed_layout = layouts.from_tiled_layout_attr(attr)
     self.assertEqual(layout, parsed_layout)
 
+  def test_broadcast_in_dim_ok(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.VectorType.get([64], ir.F32Type.get()),
+          name="broadcast_in_dim",
+      )(
+          lambda operand: mgpu.dialect.broadcast_in_dim(
+              ir.VectorType.get([64, 64], ir.F32Type.get()),
+              operand,
+              broadcast_dimensions=[0],
+          )
+      )
+
+    self.assertTrue(self.module.operation.verify())
+
+  def test_broadcast_in_dim_no_0d(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.VectorType.get([], ir.F32Type.get()),
+          name="broadcast_in_dim",
+      )(
+          lambda operand: mgpu.dialect.broadcast_in_dim(
+              ir.VectorType.get([64], ir.F32Type.get()),
+              operand,
+              broadcast_dimensions=[],
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"The input vector must have rank > 0",
+    ):
+      self.module.operation.verify()
+
+  def test_broadcast_in_dim_no_input_larger_than_output(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.VectorType.get([64, 64], ir.F32Type.get()),
+          name="broadcast_in_dim",
+      )(
+          lambda operand: mgpu.dialect.broadcast_in_dim(
+              ir.VectorType.get([64], ir.F32Type.get()),
+              operand,
+              broadcast_dimensions=[],
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"rank of the input vector must be smaller",
+    ):
+      self.module.operation.verify()
+
+  def test_broadcast_in_dim_too_many_dims(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.VectorType.get([64], ir.F32Type.get()),
+          name="broadcast_in_dim",
+      )(
+          lambda operand: mgpu.dialect.broadcast_in_dim(
+              ir.VectorType.get([64, 64], ir.F32Type.get()),
+              operand,
+              broadcast_dimensions=[0, 1],
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"size of the `broadcast_dimensions` attribute must be",
+    ):
+      self.module.operation.verify()
+
+  def test_broadcast_in_dim_dim_oob(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.VectorType.get([64], ir.F32Type.get()),
+          name="broadcast_in_dim",
+      )(
+          lambda operand: mgpu.dialect.broadcast_in_dim(
+              ir.VectorType.get([64, 64], ir.F32Type.get()),
+              operand,
+              broadcast_dimensions=[2],
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"must be in the range \[0, result.shape.rank",
+    ):
+      self.module.operation.verify()
+
+  def test_broadcast_in_dim_dim_transpose(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.VectorType.get([64, 64, 64, 64], ir.F32Type.get()),
+          name="broadcast_in_dim",
+      )(
+          lambda operand: mgpu.dialect.broadcast_in_dim(
+              ir.VectorType.get([64, 64, 64, 64], ir.F32Type.get()),
+              operand,
+              broadcast_dimensions=[0, 1, 3, 2],
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"`broadcast_dimensions` attribute must be strictly increasing",
+    ):
+      self.module.operation.verify()
+
 
 class DialectLoweringTest(MosaicGpuTest):
diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py
index 5355adfb2c7b..690e82c66cdc 100644
--- a/tests/mosaic/gpu_layout_inference_test.py
+++ b/tests/mosaic/gpu_layout_inference_test.py
@@ -19,6 +19,7 @@
 from absl.testing import parameterized
 import jax
 from jax._src import config
+from jax._src import lib as jaxlib
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir as mlir_interpreter
 from jax._src.lib.mlir import ir
@@ -235,6 +236,49 @@ def body(x):
     self.assertSequenceEqual(cast.attributes["in_layouts"], [wgmma_layout])
     self.assertSequenceEqual(cast.attributes["out_layouts"], [wgmma_layout])
 
+  @parameterized.parameters(
+      (0, mgpu.WGMMA_ROW_LAYOUT, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT),
+      (1, mgpu.WGMMA_COL_LAYOUT, None, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT),
+      (0, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT),
+      (1, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT),
+  )
+  def test_infer_broadcast_in_dim_layout(
+      self, broadcast_dim, in_cast, out_cast, in_layout, out_layout
+  ):
+    # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1.
+ if jaxlib.version < (0, 6, 1): + self.skipTest("Test requires jaxlib version >= 0.6.1") + + bcast = None + + in_shape = (64,) + out_shape = (64, 64) + + def body(x): + nonlocal bcast + if in_cast is not None: + x = mgpu.dialect.LayoutCastOp(x, layouts.to_layout_attr(in_cast)) + + out_type = ir.VectorType.get(out_shape, ir.F32Type.get()) + bcast = mgpu.dialect.BroadcastInDimOp(out_type, x, [broadcast_dim]) + + if out_cast is not None: + mgpu.dialect.LayoutCastOp( + bcast.result, layouts.to_layout_attr(out_cast) + ) + + with ir.InsertionPoint(self.module.body): + ty = ir.VectorType.get(in_shape, ir.F32Type.get()) + func.FuncOp.from_py_func(ty)(body) + + mgpu.infer_layout(self.module) + self.assertSequenceEqual( + bcast.attributes["in_layouts"], [layouts.to_layout_attr(in_layout)] + ) + self.assertSequenceEqual( + bcast.attributes["out_layouts"], [layouts.to_layout_attr(out_layout)] + ) + def test_infer_layout_traverses_ops_correctly(self): shape = (16, 8) elt_type = ir.BF16Type.get() diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f7da896b6ee5..d166a885efeb 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -27,6 +27,7 @@ from absl.testing import absltest, parameterized import jax from jax._src import config +from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -53,6 +54,7 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core from jax.experimental.mosaic.gpu import launch_context + from jax.experimental.mosaic.gpu import layouts from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu import inference_utils @@ -2848,6 +2850,65 @@ def add( self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y) + @parameterized.parameters( + ((64,), (64, 128), [0]), + ((64,), (128, 64), [1]), + ) + def test_broadcast_in_dim(self, input_shape, output_shape, bcast_dims): + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. 
+ if jaxlib.version < (0, 6, 1): + self.skipTest("Test requires jaxlib version >= 0.6.1") + + element_value = 42.0 + def body(ctx, result_gmem_ref, smem): + del ctx + result_smem_ref = smem[0] + + f32 = ir.F32Type.get() + zero_index = arith.constant(ir.IndexType.get(), 0) + + # Create input in registers + x_type = ir.VectorType.get(input_shape, f32) + c = arith.constant(f32, element_value) + x = vector.splat(x_type, c) + + # Computation + out_type = ir.VectorType.get(output_shape, f32) + expanded = mgpu_dialect.broadcast_in_dim(out_type, x, bcast_dims) + cast = mgpu_dialect.layout_cast( + expanded, layouts.to_layout_attr(fa.WGMMA_LAYOUT) + ) + + # Registers -> SMEM + vector.store(cast, result_smem_ref, [zero_index] * len(output_shape)) + + # SMEM -> GMEM + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + mgpu_dialect.async_store( + source=result_smem_ref, + destination=result_gmem_ref, + indices=[zero_i32] * len(output_shape), + slice_lengths=output_shape, + ) + nvvm.cp_async_bulk_wait_group(0) + utils.warpgroup_barrier() + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(output_shape, dtype), + smem_scratch_shape=[jax.ShapeDtypeStruct(output_shape, dtype)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = np.full(input_shape, element_value, dtype=dtype) + self.assertArraysEqual( + jax.jit(kernel)(), jax.lax.broadcast_in_dim(x, output_shape, bcast_dims) + ) + class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): From 19baf35f4cf91405cbc3560debb67ad1897b3371 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 28 Apr 2025 02:58:52 -0700 Subject: [PATCH 0851/1769] [Mosaic GPU] Handle WGMMA_ROW_LAYOUT in `vector_store` lowering and inference. These are essentially treated as having no transforms and stored untiled (only 2D-tilings are supported by `store_tiled`). 
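In outline, the dispatch can be sketched as follows; this is a simplified model with layouts stubbed out as strings, whereas the real rule in the diff below compares FragmentedArray layout objects:

    # Row/col fragments carry no tiling transforms, so they take the
    # untiled path; store_tiled only supports 2D tilings.
    WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT = "wgmma_row", "wgmma_col"

    def pick_store_path(layout):
      if layout in (WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT):
        return "store_untiled"
      return "store_tiled"

    assert pick_store_path(WGMMA_ROW_LAYOUT) == "store_untiled"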
PiperOrigin-RevId: 752230904 --- .../mosaic/gpu/dialect_lowering.py | 4 +- .../mosaic/gpu/transform_inference.py | 8 +++- tests/mosaic/gpu_test.py | 45 +++++++++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 89339bf4fb2f..bef143465659 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -408,7 +408,9 @@ def _vector_store_op_lowering_rule( fragmented_array.store_tiled( reinterpret_smem_ref(vector_store_op.base, transforms), swizzle ) - elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or + elif (fragmented_array.layout == fa.WGMMA_ROW_LAYOUT or + fragmented_array.layout == fa.WGMMA_COL_LAYOUT or + isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): fragmented_array.store_untiled(vector_store_op.base) else: diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 6026cb216166..e146306a41db 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -179,8 +179,12 @@ def _infer_vector_load_store_transforms( layout_transforms = infer_transforms_for_wgmma_ref( ir.MemRefType(op.base.type) ) - elif (isinstance(layout, fa.WGStridedFragLayout) or - isinstance(layout, fa.WGSplatFragLayout)): + elif ( + layout == fa.WGMMA_ROW_LAYOUT + or layout == fa.WGMMA_COL_LAYOUT + or isinstance(layout, fa.WGStridedFragLayout) + or isinstance(layout, fa.WGSplatFragLayout) + ): layout_transforms = None else: raise NotImplementedError( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index d166a885efeb..16a7ffef9d05 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2909,6 +2909,51 @@ def body(ctx, result_gmem_ref, smem): jax.jit(kernel)(), jax.lax.broadcast_in_dim(x, output_shape, bcast_dims) ) + @parameterized.parameters(fa.WGMMA_ROW_LAYOUT, fa.WGMMA_COL_LAYOUT) + def test_wgmma_row_col_store(self, in_layout): + element_value = 42.0 + shape = (64, ) + def body(ctx, result_gmem_ref, smem): + del ctx + result_smem_ref = smem[0] + + f32 = ir.F32Type.get() + zero_index = arith.constant(ir.IndexType.get(), 0) + + # Create input in registers + x_type = ir.VectorType.get(shape, f32) + c = arith.constant(f32, element_value) + x = vector.splat(x_type, c) + cast = mgpu_dialect.layout_cast(x, layouts.to_layout_attr(in_layout)) + + # Registers -> SMEM + vector.store(cast, result_smem_ref, [zero_index]) + + # SMEM -> GMEM + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + mgpu_dialect.async_store( + source=result_smem_ref, + destination=result_gmem_ref, + indices=[zero_i32], + slice_lengths=shape, + ) + nvvm.cp_async_bulk_wait_group(0) + utils.warpgroup_barrier() + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + smem_scratch_shape=[jax.ShapeDtypeStruct(shape, dtype)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = np.full(shape, element_value, dtype=dtype) + self.assertArraysEqual(jax.jit(kernel)(), x) + class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): From e8e358b7f72bd5d13fdb6de6ec26a7b6b57fa588 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 28 Apr 2025 03:19:26 -0700 Subject: [PATCH 0852/1769] [Mosaic 
GPU] Delete `gpu.binary` after lowering. Its lowering is becoming side-effecting in an upcoming MLIR change, preventing its deletion and causing issues with unregistered symbols. PiperOrigin-RevId: 752236293 --- jaxlib/mosaic/gpu/launch_lowering.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index f3f982f07481..53d4f47e58cc 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -299,6 +299,7 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { init_func->setAttr(mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), mlir::UnitAttr::get(func->getContext())); bool had_launch = false; + mlir::Operation *gpu_binary = nullptr; auto result = getOperation()->walk([&](mlir::gpu::LaunchFuncOp launch) -> mlir::WalkResult { if (had_launch) { @@ -314,6 +315,7 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { << launch.getKernelModuleName(); return mlir::WalkResult::interrupt(); } + gpu_binary = binary.getOperation(); if (binary.getObjects().size() != 1) { binary.emitOpError("Expected exactly one object in the binary."); return mlir::WalkResult::interrupt(); @@ -352,6 +354,13 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { if (!had_launch) { init_func.erase(); } + if (gpu_binary) { + // This deletion is load-bearing: the conversion of `gpu.binary` to + // LLVM is side-effecting, as it creates module constructors and + // destructors which create an assumption that symbols from the MLIR + // runtime are available. + gpu_binary->erase(); + } if (result == mlir::WalkResult::interrupt()) { signalPassFailure(); } From b2f4d78a173fc027d43f27822ce3c0c3fee7c5dd Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Apr 2025 04:06:49 -0700 Subject: [PATCH 0853/1769] [Mosaic GPU] Make sure the tests all pass on B200 (esp. in our CI) This is a bag of small fixes to our tests to make sure we can rely on our CI for B200 testing. The biggest change is the move of the single multi-device test from gpu_test, so that it's not skipped in CI when it deselects multiaccelerator targets. PiperOrigin-RevId: 752248888 --- tests/mosaic/BUILD | 18 +++- tests/mosaic/gpu_test.py | 13 --- tests/mosaic/gpu_test_multidevice.py | 74 +++++++++++++ tests/pallas/mosaic_gpu_test.py | 154 ++++++++++++++------------- 4 files changed, 167 insertions(+), 92 deletions(-) create mode 100644 tests/mosaic/gpu_test_multidevice.py diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 71b2b7d80570..735c9ffd5b42 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -33,14 +33,10 @@ jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], enable_backends = [], - enable_configs = [ - "gpu_h100", - "gpu_h100x2", - ], + enable_configs = ["gpu_h100"], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, shard_count = 16, tags = [ - "multiaccelerator", "noasan", # Times out. 
], deps = [ @@ -48,6 +44,18 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "gpu_test_multidevice", + srcs = ["gpu_test_multidevice.py"], + enable_backends = [], + enable_configs = ["gpu_h100x2"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = ["multiaccelerator"], + deps = [ + "//jax:mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_py_test( name = "gpu_dialect_test", srcs = ["gpu_dialect_test.py"], diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 16a7ffef9d05..c11141fee4d0 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2241,19 +2241,6 @@ def kernel(ctx, src, dst, _): )) jax.block_until_ready(f(x)) - def test_multigpu(self): - if len(jax.devices()) < 2: - self.skipTest("Need at least 2 devices") - def kernel(ctx, src, dst, _): - mgpu.FragmentedArray.load_strided(src).store_untiled(dst) - x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) - f = jax.jit(mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), x, x, () - )) - # Make sure we can invoke the same program on different devices. - for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): - jax.block_until_ready(f(xd)) - class TorchTest(TestCase): diff --git a/tests/mosaic/gpu_test_multidevice.py b/tests/mosaic/gpu_test_multidevice.py new file mode 100644 index 000000000000..114409a5efd8 --- /dev/null +++ b/tests/mosaic/gpu_test_multidevice.py @@ -0,0 +1,74 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from absl.testing import absltest, parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +import jax.numpy as jnp +import numpy as np +try: + import jax._src.lib.mosaic_gpu # noqa: F401 + HAS_MOSAIC_GPU = True +except ImportError: + HAS_MOSAIC_GPU = False +else: + import jax.experimental.mosaic.gpu as mgpu + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension +config.parse_flags_with_absl() + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if not HAS_MOSAIC_GPU: + self.skipTest("jaxlib built without Mosaic GPU") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + super().setUp() + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + if mgpu_dialect is not None: + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + + +class ProfilerTest(TestCase): + + def test_multigpu(self): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + def kernel(ctx, src, dst, _): + mgpu.FragmentedArray.load_strided(src).store_untiled(dst) + x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + f = jax.jit(mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, () + )) + # Make sure we can invoke the same program on different devices. + for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): + jax.block_until_ready(f(xd)) + + +if __name__ == "__main__": + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 41d02ed781e2..63f69738b6f6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -34,6 +34,7 @@ from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives +from jax._src.state import types as state_types from jax.experimental import pallas as pl import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu @@ -766,79 +767,6 @@ def kernel(x_ref, o_ref): x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) np.testing.assert_array_equal(kernel(x), x) - @parameterized.product( - src_memory_space=[plgpu.SMEM, plgpu.GMEM], - layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], - m=[64, 128, 192], - ) - def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layout, m): - self.skip_if_wg_semantics() - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), - in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), - ) - def kernel(x_ref, o_ref): - for i in range(2): - x = plgpu.load( - x_ref, (i,), layout=layout, optimized=src_memory_space != plgpu.GMEM - ) - o_ref[i, ...] 
= x - - x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) - np.testing.assert_array_equal(kernel(x), x) - - @parameterized.product( - src_memory_space=[plgpu.SMEM], - layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], - ) - def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): - self.skip_if_wg_semantics() - - m, k, n = 64, 128, 192 - key1, key2 = jax.random.split(jax.random.key(42), 2) - if layout == plgpu.Layout.WGMMA_ROW: - input_shape = (m,) - broadcast_dim = 0 - expand_dim = 1 - else: - input_shape = (k,) - broadcast_dim = 1 - expand_dim = 0 - a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16) - b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) - def kernel(x_ref, y_ref, o_ref): - x = plgpu.load(x_ref, (), layout=layout) - x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim]) - - def compute(acc_ref): - plgpu.wgmma(acc_ref, x, y_ref) - return acc_ref[...] - - out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) - o_ref[...] = out - f = self.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), - in_specs=( - pl.BlockSpec(memory_space=src_memory_space), - plgpu.GPUBlockSpec( - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), - ), - ), - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), - ) - - out_ref = ( - jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b - ) - np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) - def test_indexing_before_transpose(self): self.skip_if_wg_semantics() @@ -1630,7 +1558,11 @@ def test_smem_aliasing_works(self): def kernel(x_ref, o_ref128, aliased_ref): smem_ref256, _, smem_ref128 = aliased_ref # Ensure that extraction via index works the same as unfolding. - self.assertEqual(smem_ref128, aliased_ref[2]) + smem_ref128_2 = aliased_ref[2] + self.assertIsInstance(smem_ref128, state_types.TransformedRef) + self.assertIsInstance(smem_ref128_2, state_types.TransformedRef) + self.assertIs(smem_ref128.ref, smem_ref128_2.ref) + self.assertEqual(smem_ref128.transforms, smem_ref128_2.transforms) extract_alias_transform, tile_transform = smem_ref128.transforms # Ensure that the transforms provided in the scratch shapes have been # passed correctly. @@ -2038,6 +1970,79 @@ def scope(acc_ref): )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + m=[64, 128, 192], + ) + def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layout, m): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load( + x_ref, (i,), layout=layout, optimized=src_memory_space != plgpu.GMEM + ) + o_ref[i, ...] 
= x + + x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) + np.testing.assert_array_equal(kernel(x), x) + + @parameterized.product( + src_memory_space=[plgpu.SMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + ) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): + self.skip_if_wg_semantics() + + m, k, n = 64, 128, 192 + key1, key2 = jax.random.split(jax.random.key(42), 2) + if layout == plgpu.Layout.WGMMA_ROW: + input_shape = (m,) + broadcast_dim = 0 + expand_dim = 1 + else: + input_shape = (k,) + broadcast_dim = 1 + expand_dim = 0 + a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + def kernel(x_ref, y_ref, o_ref): + x = plgpu.load(x_ref, (), layout=layout) + x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim]) + + def compute(acc_ref): + plgpu.wgmma(acc_ref, x, y_ref) + return acc_ref[...] + + out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) + o_ref[...] = out + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), + in_specs=( + pl.BlockSpec(memory_space=src_memory_space), + plgpu.GPUBlockSpec( + transforms=( + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), + ), + ), + ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + + out_ref = ( + jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b + ) + np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) + class PallasCallSm90AWGTest( PallasCallSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup @@ -2048,6 +2053,7 @@ class PallasCallSm90AWGTest( class PallasCallSm100ATest(PallasSm100ATest): def test_tmem_alloc(self): + self.skip_if_wg_semantics() # TMEM read not wired up in the WG get rule. @functools.partial( self.kernel, From b0251a36b5a19188250233917d4c53f29568c714 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 28 Apr 2025 05:33:15 -0700 Subject: [PATCH 0854/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b77c3dfa079d9d78e40cd7fcd18f821e7c90b8ed. PiperOrigin-RevId: 752269248 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 420f93b36716..0530665390bc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "e039a21a6ffa80bab553479d4a62054c08316e6e" -XLA_SHA256 = "c1a353cee867f1f75079f9ca8bacfb5831ce6eb740382dd393b084c984184f34" +XLA_COMMIT = "b77c3dfa079d9d78e40cd7fcd18f821e7c90b8ed" +XLA_SHA256 = "c527159d433b3301acc6a3d01e504a1718d80c4dfa0a4d38e2cf7529d4fb8162" def repo(): tf_http_archive( From eba64f932118120d049de2b6eb82e5d708ca5f81 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Apr 2025 05:48:46 -0700 Subject: [PATCH 0855/1769] [Pallas:MGPU] Increase shard_coaunt for mgpu_attention_test to avoid timeouts PiperOrigin-RevId: 752272895 --- tests/pallas/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 92ff732e2200..897bc82d4365 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -653,6 +653,7 @@ jax_multiplatform_test( "gpu_h100", ], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 4, deps = [ "//jax:pallas", "//jax:pallas_experimental_gpu_ops", From 814c91704f8489755fc71bf8c4be14f648113b6a Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Mon, 28 Apr 2025 07:00:52 -0700 Subject: [PATCH 0856/1769] [Mosaic GPU] Relax tolerance for WGMMATest.test_narrow_n I'm assuming that the tolerance has been rule-of-thumbed depending on observed numbers. If that's the case, then we should relax it, as changing the tiling used by XLA's reference implementation (e.g. by setting `--xla_gpu_experimental_enable_dynamic_dot_search_space=1` that I'm trying to enable by default) causes the test to fail. This is only a workaround, and the test is still flaky due to this tolerance. One way to fix it would be to set up the inputs in a way that we can guarantee all intermediate results are representable in f16/f32, and then we can check for bitwise equality, rather than relying on tolerances. PiperOrigin-RevId: 752291561 --- tests/mosaic/gpu_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index c11141fee4d0..177f622ff42a 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -892,7 +892,7 @@ def kernel(ctx, rhs, out, smem): ref = jax.lax.dot( x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 ) - np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0) + np.testing.assert_allclose(z, ref, rtol=1e-3, atol=0) class TCGen05Test(TestCase): From 02e8f5eae9436400787d8708513338d17e0a2918 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Apr 2025 07:30:06 -0700 Subject: [PATCH 0857/1769] [Mosaic GPU] Fix a skip condition to avoid problems with nightly jaxlib builds PiperOrigin-RevId: 752299310 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index bef143465659..da154bdd8fa3 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -22,7 +22,6 @@ import operator from typing import Any, Sequence, Type, cast -from jax._src import lib as jaxlib from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir @@ -488,7 +487,7 @@ def _mgpu_layout_cast_op_lowering_rule( # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. 
-if jaxlib.version >= (0, 6, 1): +if hasattr(mgpu, "BroadcastInDimOp"): @_register_lowering(mgpu.BroadcastInDimOp) def _mgpu_broadcast_in_dim_op_lowering_rule( _: LoweringContext, op: mgpu.BroadcastInDimOp From cd3fa6d6e7fc0ccbdfb55d10048342e4a7b2b2e6 Mon Sep 17 00:00:00 2001 From: Matt Bahr Date: Thu, 17 Apr 2025 06:23:09 +0000 Subject: [PATCH 0858/1769] implement hyp2f1 --- jax/_src/scipy/special.py | 254 ++++++++++++++++++++++ jax/scipy/special.py | 1 + tests/lax_scipy_special_functions_test.py | 17 ++ 3 files changed, 272 insertions(+) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index a24736ccfec0..d5c99b6e3b4f 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -2637,6 +2637,260 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: ) +def _hyp2f1_terminal(a, b, c, x): + """ + The Taylor series representation of the 2F1 hypergeometric function + terminates when either a or b is a non-positive integer. See Eq. 4.1 and + Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + # Ensure that between a and b, the negative integer parameter with the greater + # absolute value - that still has a magnitude less than the absolute value of + # c if c is non-positive - is used for the upper limit in the loop. + eps = jnp.finfo(x.dtype).eps * 50 + ib = jnp.round(b) + mask = jnp.logical_and( + b < a, + jnp.logical_and( + jnp.abs(b - ib) < eps, + jnp.logical_not( + jnp.logical_and( + c % 1 == 0, + jnp.logical_and( + c <= 0, + c > b + ) + ) + ) + ) + ) + orig_a = a + a = jnp.where(mask, b, a) + b = jnp.where(mask, orig_a, b) + + a = jnp.abs(a) + + def body(i, state): + serie, term = state + + term *= -(a - i + 1) / (c + i - 1) * (b + i - 1) / i * x + serie += term + + return serie, term + + init = (jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype)) + + return lax.fori_loop(jnp.array(1, dtype=a.dtype), + a + 1, + body, + init)[0] + + +def _hyp2f1_serie(a, b, c, x): + """ + Compute the 2F1 hypergeometric function using the Taylor expansion. + See Eq. 4.1 from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + rtol = jnp.finfo(x.dtype).eps + + def body(state): + serie, k, term = state + + serie += term + term *= (a + k - 1) * (b + k - 1) / (c + k - 1) / k * x + k += 1 + + return serie, k, term + + def cond(state): + serie, k, term = state + + return (k < 250) & (lax.abs(term) > rtol * lax.abs(serie)) + + init = (jnp.array(0, dtype=x.dtype), + jnp.array(1, dtype=x.dtype), + jnp.array(1, dtype=x.dtype)) + + return lax.while_loop(cond, body, init)[0] + + +def _hyp2f1_terminal_or_serie(a, b, c, x): + """ + Check for recurrence relations along with whether or not the series + terminates. True recursion is not possible; however, the recurrence + relation may still be approximated. + See 4.6.1. 
Recurrence Relations from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + eps = jnp.finfo(x.dtype).eps * 50 + + d = c - a - b + + ia = jnp.round(a) + ib = jnp.round(b) + id = jnp.round(d) + + neg_int_a = jnp.logical_and(a <= 0, jnp.abs(a - ia) < eps) + neg_int_b = jnp.logical_and(b <= 0, jnp.abs(b - ib) < eps) + neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b) + not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b) + + index = jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b), + jnp.where(jnp.abs(d - id) >= eps, 0, 1), + jnp.where(neg_int_a_or_b, 2, 0)) + + return lax.select_n(index, + _hyp2f1_serie(a, b, c, x), + _hyp2f1_digamma_transform(a, b, c, x), + _hyp2f1_terminal(a, b, c, x)) + + +def _hyp2f1_digamma_transform(a, b, c, x): + """ + Digamma transformation of the 2F1 hypergeometric function. + See AMS55 #15.3.10, #15.3.11, #15.3.12 + """ + rtol = jnp.finfo(x.dtype).eps + + d = c - a - b + s = 1 - x + rd = jnp.round(d) + + e = jnp.where(rd >= 0, d, -d) + d1 = jnp.where(rd >= 0, d, jnp.array(0, dtype=d.dtype)) + d2 = jnp.where(rd >= 0, jnp.array(0, dtype=d.dtype), d) + ard = jnp.where(rd >= 0, rd, -rd).astype('int32') + + ax = jnp.log(s) + + y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax + y /= gamma(e + 1.0) + + p = (a + d1) * (b + d1) * s / gamma(e + 2.0) + + def cond(state): + _, _, _, _, _, _, q, _, _, t, y = state + + return jnp.logical_and( + t < 250, + jnp.abs(q) >= rtol * jnp.abs(y) + ) + + def body(state): + a, ax, b, d1, e, p, q, r, s, t, y = state + + r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \ + - digamma(b + t + d1) - ax + q = p * r + y += q + p *= s * (a + t + d1) / (t + 1.0) + p *= (b + t + d1) / (t + 1.0 + e) + t += 1.0 + + return a, ax, b, d1, e, p, q, r, s, t, y + + init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s, + jnp.array(1, dtype=x.dtype), y) + _, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init) + + def compute_sum(y): + y1 = jnp.array(1, dtype=x.dtype) + t = jnp.array(0, dtype=x.dtype) + p = jnp.array(1, dtype=x.dtype) + + def for_body(i, state): + a, b, d2, e, p, s, t, y1 = state + + r = 1.0 - e + t + p *= s * (a + t + d2) * (b + t + d2) / r + t += 1.0 + p /= t + y1 += p + + return a, b, d2, e, p, s, t, y1 + + init_val = a, b, d2, e, p, s, t, y1 + y1 = lax.fori_loop(1, ard, for_body, init_val)[-1] + + p = gamma(c) + y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1)) + y *= p / (gamma(a + d2) * gamma(b + d2)) + + y = jnp.where((ard & 1) != 0, -y, y) + q = s ** rd + + return jnp.where(rd > 0, y * q + y1, y + y1 * q) + + return jnp.where( + rd == 0, + y * gamma(c) / (gamma(a) * gamma(b)), + compute_sum(y) + ) + + +@jit +@jnp.vectorize +def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array: + r"""The 2F1 hypergeometric function. + + JAX implementation of :obj:`scipy.special.hyp2f1`. + + .. math:: + + \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!} + + where :math:`(\cdot)_k` is the Pochammer symbol. + + The JAX version only accepts positive and real inputs. Values of + ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may + lead to erroneous results; consider enabling double precision in this case. + + Args: + a: arraylike, real-valued + b: arraylike, real-valued + c: arraylike, real-valued + x: arraylike, real-valued + + Returns: + array of 2F1 values. 
+ """ + # This is backed by https://doi.org/10.48550/arXiv.1407.7786 + a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x) + eps = jnp.finfo(x.dtype).eps * 50 + + d = c - a - b + s = 1 - x + ca = c - a + cb = c - b + + id = jnp.round(d) + ica = jnp.round(ca) + icb = jnp.round(cb) + + neg_int_ca = jnp.logical_and(ca <= 0, jnp.abs(ca - ica) < eps) + neg_int_cb = jnp.logical_and(cb <= 0, jnp.abs(cb - icb) < eps) + neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb) + + index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0, + jnp.where(jnp.logical_or(c == 0, jnp.logical_and(c < 0, c % 1 == 0)), 1, + jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(jnp.abs(d - id) >= eps, s < 0))), 2, + jnp.where(jnp.logical_and(d <= 0, x == 1), 1, + jnp.where(jnp.logical_and(x < 1, b == c), 3, + jnp.where(jnp.logical_and(x < 1, a == c), 4, + jnp.where(x > 1, 1, + jnp.where(x == 1, 5, 6)))))))) + + return lax.select_n(index, + jnp.array(1, dtype=x.dtype), + jnp.array(jnp.inf, dtype=x.dtype), + s ** d * _hyp2f1_terminal_or_serie(ca, cb, c, x), + s ** (-a), + s ** (-b), + gamma(c) * gamma(d) / (gamma(ca) * gamma(cb)), + _hyp2f1_terminal_or_serie(a, b, c, x)) + + def softmax(x: ArrayLike, /, *, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 2ffc65a1abe1..e1330d4b6cf3 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -37,6 +37,7 @@ gammaln as gammaln, gammasgn as gammasgn, hyp1f1 as hyp1f1, + hyp2f1 as hyp2f1, i0 as i0, i0e as i0e, i1 as i1, diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 4b3945a84453..e02581626a53 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -157,6 +157,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t "hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True ), + op_record( + "hyp2f1", 4, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.5, high=30), False + ), op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), op_record("softmax", 1, float_dtypes, jtu.rand_default, True), ] @@ -354,5 +358,18 @@ def testBetaIncBoundaryValues(self): self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol) + def testHyp2f1SpecialCases(self): + dtype = jax.dtypes.canonicalize_dtype(float) + + a_samples = np.array([0, 1, 1, 1, 1, 5, 5, 0.245, 0.45, 0.45, 2, 0.4, 0.32, 4, 4], dtype=dtype) + b_samples = np.array([1, 0, 1, 1, 1, 1, 1, 3, 0.7, 0.7, 1, 0.7, 0.76, 2, 3], dtype=dtype) + c_samples = np.array([1, 1, 0, 1, -1, 3, 3, 3, 0.45, 0.45, 5, 0.3, 0.11, 7, 7], dtype=dtype) + x_samples = np.array([1, 1, 1, 0, 1, 0.5, 1, 0.35, 0.35, 1.5, 1, 0.4, 0.95, 0.95, 0.95], dtype=dtype) + + args_maker = lambda: (a_samples, b_samples, c_samples, x_samples) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.hyp2f1, lsp_special.hyp2f1, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.hyp2f1, args_maker, rtol=rtol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From cb4eeb53a8ad0e802796503e42a6a1eb4867f7ea Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 28 Apr 2025 08:10:45 -0700 Subject: [PATCH 0859/1769] Fix formatting error in matrix exclude strategy `python` needs to be on the first indent level PiperOrigin-RevId: 752312000 --- 
.github/workflows/wheel_tests_nightly_release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index fe8b191c9530..483705aeb0cb 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -102,7 +102,7 @@ jobs: python: "3.10" - tpu-specs: type: "v6e-8" - python: "3.11" + python: "3.11" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" From 7ff68f9de233d02ee28219aa204c351a3333e1d0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 28 Apr 2025 08:40:03 -0700 Subject: [PATCH 0860/1769] Update `build.py` to avoid duplication of bazel options both in bazel command and in `.jax_configure.bazelrc`. This is done to prevent cache invalidation as stated in the bazel logs: ``` WARNING: The following configs were expanded more than once: [clang, mkl_open_source_only, avx_posix, cuda, cuda_libraries_from_stubs, build_cuda_with_nvcc]. For repeatable flags, repeats are counted twice and may lead to unexpected behavior. ``` PiperOrigin-RevId: 752320902 --- build/build.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/build/build.py b/build/build.py index 287400ee8f42..4a7c745ce9d9 100755 --- a/build/build.py +++ b/build/build.py @@ -573,7 +573,6 @@ async def main(): if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda") - wheel_build_command_base.append("--config=cuda_libraries_from_stubs") if args.use_clang: wheel_build_command_base.append( f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" @@ -640,9 +639,6 @@ async def main(): if "ML_WHEEL_GIT_HASH" in option: wheel_git_hash = option.split("=")[-1][:9] - if "cuda" in args.wheels: - wheel_build_command_base.append("--config=cuda_libraries_from_stubs") - with open(".jax_configure.bazelrc", "w") as f: jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule) if not jax_configure_options: @@ -677,7 +673,9 @@ async def main(): ) sys.exit(1) - wheel_build_command = copy.deepcopy(wheel_build_command_base) + wheel_build_command = copy.deepcopy(bazel_command_base) + if "cuda" in args.wheels: + wheel_build_command.append("--config=cuda_libraries_from_stubs") print("\n") logger.info( "Building %s for %s %s...", From 0debc206d7031d2395dea36310f61f47277af711 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 28 Apr 2025 12:17:54 -0400 Subject: [PATCH 0861/1769] Relax constraints in jnp.vectorize for output shapes with default signature. 
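A minimal sketch of the newly allowed pattern, adapted from the test added below (`out_dim` is illustrative):

```
import jax.numpy as jnp

out_dim = 5
# With signature=None, inputs are still mapped as scalars, but the output
# shape is no longer checked against scalar core dimensions.
f = jnp.vectorize(lambda x: x + jnp.linspace(-1.0, 1.0, out_dim))
assert f(0.5).shape == (out_dim,)            # scalar in, vector out
assert f(jnp.ones(3)).shape == (3, out_dim)  # broadcast dims are prepended
```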
--- jax/_src/numpy/vectorize.py | 31 ++++++++++++++++--------------- tests/lax_numpy_vectorize_test.py | 9 +++++++++ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index a60e681427f5..5ea9d697d27d 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -144,18 +144,15 @@ def wrapped(*args): out = func(*args) out_shapes = map(np.shape, out if isinstance(out, tuple) else [out]) - if expected_output_core_dims is None: - output_core_dims = [()] * len(out_shapes) - else: - output_core_dims = expected_output_core_dims - if len(output_core_dims) > 1 and not isinstance(out, tuple): - raise TypeError( - "output must be a tuple when multiple outputs are expected, " - "got: {!r}\n{}".format(out, error_context)) - if len(out_shapes) != len(output_core_dims): - raise TypeError( - 'wrong number of output arguments: expected %r, got %r %s' - % (len(output_core_dims), len(out_shapes), error_context)) + output_core_dims = expected_output_core_dims + if len(output_core_dims) > 1 and not isinstance(out, tuple): + raise TypeError( + "output must be a tuple when multiple outputs are expected, " + "got: {!r}\n{}".format(out, error_context)) + if len(out_shapes) != len(output_core_dims): + raise TypeError( + 'wrong number of output arguments: expected %r, got %r %s' + % (len(output_core_dims), len(out_shapes), error_context)) sizes = dict(dim_sizes) for shape, core_dims in zip(out_shapes, output_core_dims): @@ -215,7 +212,8 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None): ``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If provided, ``pyfunc`` will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By - default, pyfunc is assumed to take scalars arrays as input and output. + default, pyfunc is assumed to take scalar arrays as input, and if + ``signature`` is ``None``, ``pyfunc`` can produce outputs of any shape. Returns: Vectorized version of the given function. @@ -294,8 +292,11 @@ def wrapped(*args, **kwargs): broadcast_shape, dim_sizes = _parse_input_dimensions( args, input_core_dims, error_context) - checked_func = _check_output_dims( - excluded_func, dim_sizes, output_core_dims, error_context) + if output_core_dims is None: + checked_func = excluded_func + else: + checked_func = _check_output_dims( + excluded_func, dim_sizes, output_core_dims, error_context) # Detect implicit rank promotion: if config.numpy_rank_promotion.value != "allow": diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index 985dba484845..8fbd393dc3f8 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -287,6 +287,15 @@ def test_rank_promotion_error(self): with self.assertNoWarnings(): f2(rank2, rank1) + def test_non_scalar_outputs_and_default_signature(self): + def f(x): + self.assertEqual(np.shape(x), ()) + return x + jnp.linspace(-1, 1, out_dim) + + out_dim = 5 + self.assertEqual(jnp.vectorize(f)(0.5).shape, (out_dim,)) + self.assertEqual(jnp.vectorize(f)(jnp.ones(3)).shape, (3, out_dim)) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 69540e400b47f0c14dc0fa01aea2ded666b6587f Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 28 Apr 2025 09:51:48 -0700 Subject: [PATCH 0862/1769] [Mosaic] Remove Python pipeline. 
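The removed flow was opt-in through the `mosaic_use_python_pipeline` config state deleted below; as its help text noted, the same passes otherwise run later within XLA, which is now the only path. A hedged migration note, assuming the flag was set the usual way:

```
# No longer supported; the option is deleted by this change, so drop it:
# jax.config.update("mosaic_use_python_pipeline", True)
```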
PiperOrigin-RevId: 752344826 --- jax/_src/tpu_custom_call.py | 219 +----------------------------------- 1 file changed, 1 insertion(+), 218 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index ff0c28ed13f2..7aa0f87dd525 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -24,8 +24,6 @@ import enum import functools import io -import os -import time from typing import Any import jax @@ -38,7 +36,6 @@ from jax._src.lib import xla_client from jax.interpreters import xla from jaxlib.mlir import ir -from jaxlib.mlir.dialects import stablehlo from jaxlib.mlir.passmanager import PassManager try: @@ -47,16 +44,6 @@ except ImportError: FLAGS = {} -_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state( - name="mosaic_use_python_pipeline", - default=False, - help=( - "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel" - " is called (for Pallas, this happens at JAX lowering time), instead of" - " later within XLA." - ), -) - _MOSAIC_ALLOW_HLO = config.bool_state( name="jax_mosaic_allow_hlo", default=False, @@ -88,12 +75,6 @@ def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: tpu_custom_call_p.multiple_results = True -def get_target_shape(hardware_generation: int) -> tuple[int, int]: - """Returns the target shape for the given hardware generation.""" - del hardware_generation - return (8, 128) - - class MemorySpace(enum.Enum): HBM = enum.auto() VMEM = enum.auto() @@ -305,166 +286,6 @@ def _tpu_custom_call_lowering( platform="tpu") -def _lower_tpu_kernel( - module: ir.Module, - hardware_generation: int, - target_shape: tuple[int, int], - kernel_name: str | None = None, -) -> ir.Module: - """Runs MLIR passes lowering the given module to an MLIR module. - - Uses Python versions of canonicalize-mosaic,infer-memref-layout and - apply-vector-layout. - - Args: - module: The MLIR module to lower. - hardware_generation: The TPU hardware generation to target. - target_shape: The target shape of (sublane_count, lane_count). - - Returns: - An MLIR module implementing the kernel. - """ - try: - module.operation.verify() - except ir.MLIRError as e: - raise ValueError("The compiled module fails MLIR verification") from e - - timestamp = time.time_ns() - dump_cnt = [0] - - def get_dump_file_prefix() -> str: - s = f"{timestamp}-{dump_cnt[0]:04}" - dump_cnt[0] += 1 - return s - - with module.context as ctx, module.operation.location as _: - ctx.append_dialect_registry(mlir.upstream_dialects) - ctx.load_all_available_dialects() - tpu.register_dialect(ctx) - stablehlo.register_dialect(ctx) - dump_mlir(module, "original", get_dump_file_prefix(), kernel_name) - - if _MOSAIC_ALLOW_HLO.value: - # Run dialect conversion: StableHLO -> linalg -> vector. - pipeline = [ - "func.func(stablehlo-legalize-to-linalg)", - "func.func(linalg-vectorization)", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-hlo-conversion", get_dump_file_prefix(), kernel_name) - - sl_cnt, l_cnt = target_shape - # Note: we don't pass the TpuTilingFlags here, since we don't know the - # tiling decisions made by the compiler / what flags are enabled at this - # point, so we assume everything can be tiled up to default tiling. 
- pipeline = [ - "func.func(tpu-infer-memref-layout{" - f" hardware-generation={hardware_generation}" - f" sublane-count={sl_cnt}" - f" lane-count={l_cnt}" - "})" - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-infer-memref-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - "canonicalize", - "cse", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir( - module, - "post-infer-memref-layout-simplify", - get_dump_file_prefix(), - kernel_name, - ) - - try: - on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value - except KeyError: - on_device_checks = False - - if checks := on_device_checks: - checks = set(checks.split(",")) - if checks == {"bounds"}: # We only support one kind of checks now. - pipeline = PassManager.parse( - "builtin.module(func.func(debug-assert-insertion))" - ) - pipeline.run(module.operation) - dump_mlir(module, "post-assert-insertion", get_dump_file_prefix(), kernel_name) - elif checks: - checks.discard("bounds") - raise ValueError( - f"Unrecognized on-device check categories: {', '.join(checks)}" - ) - - # Legacy pipeline always runs in compatibility mode. - compatibility_mode = True - pipeline = [ - ( - f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation} compatibility-mode={compatibility_mode}}})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-canonicalize-mosaic", get_dump_file_prefix(), kernel_name) - - pipeline = [ - ( - "func.func(tpu-infer-vector-layout{" - f" hardware-generation={hardware_generation}" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - "})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-infer-vector-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - ( - "func.func(tpu-relayout-insertion{" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - f" hardware-generation={hardware_generation}" - "})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-relayout-insertion", get_dump_file_prefix(), kernel_name) - - mxu_size = 128 if hardware_generation < 6 else 256 - pipeline = [ - "func.func(tpu-apply-vector-layout{" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - f" hardware-generation={hardware_generation}" - f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}" - f" max-sublanes-in-scratch={sl_cnt * (sl_cnt + 1)}" - "})" - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-apply-vector-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - "canonicalize", - "cse", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir( - module, - "post-apply-vector-layout-simplify", - get_dump_file_prefix(), - kernel_name, - ) - - return module - - def _lower_mosaic_module_to_asm( module: ir.Module, *, @@ -480,27 +301,7 @@ def _lower_mosaic_module_to_asm( needs_layout_passes = not device_type # We'll mutate the module, so clone it with module.context as ctx, module.operation.location as _: - if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: - module = ir.Module.parse( - module.operation.get_asm(binary=True, 
enable_debug_info=True) - ) - module_op = module.operation - some_tpu = jax.devices(backend)[0] - device_kind = some_tpu.device_kind - if not device_kind.startswith("TPU v"): - raise ValueError( - f"Unrecognized TPU device kind: {device_kind}. " - "tpu_custom_call cannot be lowered on a machine without TPUs " - "when mosaic_use_python_pipeline=True.") - hardware_generation = int(device_kind[len("TPU v")]) - target_shape = get_target_shape(hardware_generation) - module = _lower_tpu_kernel( - module, hardware_generation, target_shape=target_shape, kernel_name=kernel_name, - ) - needs_hlo_passes = False - needs_layout_passes = False - else: - module_op = module.operation.clone() + module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True target_version = ( @@ -825,21 +626,3 @@ def apply_kernel(*args): return result[0] if unpack else result return jax.jit(apply_kernel) - - -def dump_mlir( - module: ir.Module, name: str, prefix: str, kernel_name: str | None = None -): - """A helper function to dump mosaic mlir module""" - try: - should_dump = FLAGS["xla_mosaic_dump_to"].value - except KeyError: - return - if should_dump == "sponge": - outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None) - if outdir: - if kernel_name: - name = f"{kernel_name}-{name}" - path = os.path.join(outdir, f"{prefix}-mosaic-dump-{name}-py.txt") - with open(path, "w") as f: - f.write(str(module)) From cf36c8b9be3b27b24fe8d526be127326cdc1cdc0 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 28 Apr 2025 12:57:04 -0400 Subject: [PATCH 0863/1769] Remove unused parameters in emit_python_callback. --- jax/_src/callback.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index a44fb6fb2783..d23389af16eb 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -728,21 +728,6 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined return outputs, token -def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): - if minor_to_major is None: - # Needed for token layouts - layout: np.ndarray = np.zeros((0,), dtype="int64") - else: - layout = np.array(minor_to_major, dtype="int64") - return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get()) - - -def _aval_to_default_layouts(aval): - avals = [core.physical_aval(aval)] - # Row major order is default for `NumPy`. - return [list(range(aval.ndim - 1, -1, -1)) for aval in avals] - - def emit_python_callback( ctx: mlir.LoweringRuleContext, callback, @@ -754,8 +739,6 @@ def emit_python_callback( has_side_effect: bool, partitioned: bool = False, sharding: SdyArrayShardingList | xc.OpSharding | None = None, - operand_layouts: Sequence[Sequence[int] | None] | None = None, - result_layouts: Sequence[Sequence[int] | None] | None = None, ) -> tuple[Sequence[mlir.IrValues], Any, Any]: """Emits MLIR that calls back to a provided Python function. @@ -770,8 +753,6 @@ def emit_python_callback( partitioned: If True, then `callback` is called on local shards only. If False, then `callback` is called on all shards. sharding: The sharding of the callback. - operand_layouts: The layouts of the operands. - result_layouts: The layouts of the results. 
Returns: A tuple of MLIR result values, a new token (if any), and the host callback @@ -792,14 +773,6 @@ def emit_python_callback( backend = ctx.module_context.get_backend() result_shapes = [_aval_to_xla_shape(aval) for aval in result_avals] operand_shapes = [_aval_to_xla_shape(aval) for aval in operand_avals] - # Handling layouts - if operand_layouts is None: - operand_layouts = util.concatenate( - map(_aval_to_default_layouts, operand_avals)) - operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts) - if result_layouts is None: - result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals)) - result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts) # First we apply checks to ensure output shapes and dtypes match the expected # ones. From 5f30087e3692d93c687ed339c9d50d2922f0785b Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 28 Apr 2025 10:12:38 -0700 Subject: [PATCH 0864/1769] [JAX:sparse] Fix bcsr.from_bcoo to use the index_dtype of the input BCOO matrix. PiperOrigin-RevId: 752353455 --- jax/experimental/sparse/bcsr.py | 10 +++++++--- tests/sparse_test.py | 4 +++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 4b01f362bb83..a7f7deb2a93f 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -145,7 +145,7 @@ def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *, def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int], - index_dtype: DTypeLike = jnp.int32) -> tuple[Array, Array]: + index_dtype: DTypeLike) -> tuple[Array, Array]: """Given BCOO (indices), return BCSR (indices, indptr). Note: this assumes that ``indices`` are lexicographically sorted within each batch. @@ -238,7 +238,9 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype): raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.") bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch) - indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape) + indices, indptr = _bcoo_to_bcsr( + bcoo_mat.indices, shape=mat.shape, index_dtype=index_dtype + ) return bcoo_mat.data, indices, indptr @@ -867,7 +869,9 @@ def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR: raise NotImplementedError(f"BSCR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}") if not arr.indices_sorted: arr = arr.sort_indices() - indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape) + indices, indptr = _bcoo_to_bcsr( + arr.indices, shape=arr.shape, index_dtype=arr.indices.dtype + ) return cls((arr.data, indices, indptr), shape=arr.shape) @classmethod diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 219875d4b7d0..71437fd0e028 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -1102,7 +1102,9 @@ def test_bcoo_to_bcsr_round_trip(self, shape, dtype, n_batch): _, bcoo_indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense) - bcoo_to_bcsr = partial(sparse_bcsr._bcoo_to_bcsr, shape=shape) + bcoo_to_bcsr = partial( + sparse_bcsr._bcoo_to_bcsr, shape=shape, index_dtype=bcoo_indices.dtype + ) args_maker_bcoo_to_bcsr = lambda: [bcoo_indices] self._CompileAndCheck(bcoo_to_bcsr, args_maker_bcoo_to_bcsr) From fffdbe693cfe4f577a35f3dc6444660b9cdb608c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 28 Apr 2025 10:14:12 -0700 Subject: [PATCH 0865/1769] Fix warning in mosaic/pipeline.py under Python 3.12. 
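The warning comes from Python 3.12 deprecating bitwise inversion of `bool` (`~True` yields `-2` and now emits a `DeprecationWarning`), which presumably bites here when `init_accumulators` is a trace-time Python bool. A minimal sketch of the difference, assuming a plain bool operand:

```
import jax.numpy as jnp

flag = True
# ~flag                       # DeprecationWarning on Python 3.12; evaluates to -2
pred = jnp.logical_not(flag)  # boolean negation; works for bools and arrays alike
```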
PiperOrigin-RevId: 752354049 --- jax/_src/pallas/mosaic/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 6019840ab178..6865713cd69a 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -949,7 +949,7 @@ def skip_input_copies_when_init_accumulators(schedule) -> Any: def new_pred(original_pred_fn, *a): pred = original_pred_fn(*a) if a[1].is_accumulator or a[1].is_input_output: - pred &= ~a[0].init_accumulators + pred &= jnp.logical_not(a[0].init_accumulators) return pred new_schedule[k] = functools.partial( From ad44895bf9fa49166ee27d30b3272dc72cfe3977 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 28 Apr 2025 13:10:34 -0400 Subject: [PATCH 0866/1769] Handle extended dtypes in FFI lowering. --- jax/_src/ffi.py | 2 +- tests/debugging_primitives_test.py | 7 +++++++ tests/ffi_test.py | 5 +++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index f0c7d761ac2b..b25306c66b42 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -135,7 +135,7 @@ def include_dir() -> str: def _aval_shape(aval: core.AbstractValue) -> Shape: - return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error + return () if aval is core.abstract_token else core.physical_aval(aval).shape # pytype: disable=attribute-error def _convert_layout_for_lowering( diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 1e0408b8d2ba..7985cf841248 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -290,6 +290,13 @@ def f(): actual = tuple(sorted(map(int, output().splitlines()))) self.assertEqual(actual, tuple(range(4))) + def test_debug_print_extended_dtype(self): + def f(k): + jax.debug.print("{}", k) + with jtu.capture_stdout(): + f(jax.random.key(0)) # doesn't crash + jax.effects_barrier() + @jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintTransformationTest(jtu.JaxTestCase): diff --git a/tests/ffi_test.py b/tests/ffi_test.py index fd41314350f3..a66d17622ae5 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -286,6 +286,11 @@ def f(x): def test_extend_import_shim(self): ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True) + def test_extended_dtype_lowering(self): + def f(x): + return jax.ffi.ffi_call("edtype", (), has_side_effect=True)(x) + jax.jit(f).lower(jax.random.key(0)) # doesn't crash + def ffi_call_geqrf(x, _use_extend=False, **kwargs): if jtu.test_device_matches(["cpu"]): From 50b4654645d9681d173e93ab9031e2bcea5f306a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 28 Apr 2025 11:01:56 -0700 Subject: [PATCH 0867/1769] Allow 2x2x2 topologies with v6e PiperOrigin-RevId: 752372944 --- jax/_src/mesh_utils.py | 2 ++ jax/_src/test_util.py | 4 ++-- tests/pjit_test.py | 3 +++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index c135919b14c5..fdb8e10d598e 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -34,6 +34,7 @@ _TPU_V5_LITE = "TPU v5 lite" _TPU_V5E = "TPU v5e" _TPU_V5P = "TPU v5p" +_TPU_V6_LITE = "TPU v6 lite" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. 
@@ -190,6 +191,7 @@ def _v5p_create_device_mesh( _TPU_V3: _tpu_v2_v3_create_device_mesh, _TPU_V5_LITE: _v5e_create_device_mesh, _TPU_V5P: _v5p_create_device_mesh, + _TPU_V6_LITE: _v5e_create_device_mesh, } diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 8cd0b0d7f6f4..c584ffefa4f2 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1415,12 +1415,12 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) -def with_explicit_mesh(sizes, names, axis_types=None): +def with_explicit_mesh(sizes, names, axis_types=None, iota_order=False): axis_types = ((mesh_lib.AxisType.Explicit,) * len(names) if axis_types is None else axis_types) def decorator(fn): def mesh_fn(*args, **kwargs): - mesh = create_mesh(sizes, names, axis_types=axis_types) + mesh = create_mesh(sizes, names, iota_order, axis_types=axis_types) with jax.sharding.use_mesh(mesh): return fn(*args, **kwargs, mesh=mesh) return mesh_fn diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 03f5dd13c834..2c667ba36157 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3618,6 +3618,9 @@ def test_device_put_grad(self): if jtu.is_device_tpu(5, 'e'): self.skipTest('TPU v5e does not support computations that run on a ' 'non-singleton subset of cores.') + if jtu.is_device_tpu(6, 'e'): + self.skipTest('TPU v6e does not support computations that run on a ' + 'non-singleton subset of cores.') def _test(fun, inp, np_inp, in_s): out = fun(inp) From 15d294ca62df2c29ab4d042ab2909b8fe83d8918 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 28 Apr 2025 14:08:58 -0400 Subject: [PATCH 0868/1769] Fix deprecation warnings in cudnn scaled matmul. --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 6766e3992202..8598ca5f8920 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -28,7 +28,7 @@ from jax._src.interpreters import batching from jax._src.lax.lax import ranges_like, remaining from jax._src.typing import DTypeLike -from jax.interpreters import mlir, xla +from jax._src.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P @@ -112,7 +112,7 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type): _scaled_matmul_p = core.Primitive("scaled_matmul") _scaled_matmul_p.multiple_results = True -_scaled_matmul_p.def_impl(partial(xla.apply_primitive, _scaled_matmul_p)) +dispatch.simple_impl(_scaled_matmul_p) _scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract) From 2805d9afbb40196b1cb7234dc4c84762ea271867 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Apr 2025 11:08:59 -0700 Subject: [PATCH 0869/1769] Fix scaled_matmul_stablehlo_test in x64 mode The test is written with the assumption that jax.random.uniform returns float32 when the dtype is unspecified, which is not always true. 
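Concretely, `jax.random.uniform` with no explicit `dtype` uses the canonical float type, which becomes `float64` once `jax_enable_x64` is set; pinning `dtype=jnp.float32` keeps the test's float32 assumption valid in either mode. A short sketch, assuming x64 mode is enabled:

```
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
k = jax.random.key(0)
jax.random.uniform(k, (2,)).dtype                     # float64 under x64
jax.random.uniform(k, (2,), dtype=jnp.float32).dtype  # float32 in either mode
```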
PiperOrigin-RevId: 752375641 --- tests/scaled_matmul_stablehlo_test.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index b53ffcd5b977..d2483966c984 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -736,11 +736,11 @@ def test_dot_general_sharded(self, in_shardings): k1, k2 = jax.random.split(jax.random.key(0), 2) a = cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[1].data_type, ) @@ -771,10 +771,6 @@ def fwd(a, b, is_ref=False): j_train = jax.jit(jax.value_and_grad(partial(fwd), argnums=[0, 1]), in_shardings=input_shardings) - hlo_text = j_train.lower(a, b).compile().as_text() - hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) - ) j_train_ref = jax.jit( jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]), @@ -808,11 +804,11 @@ def test_dot_general_vmap(self, configs): dimension_numbers = (([1], [1]), ([], [])) a = cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[1].data_type, ) From dad87a314c861b33f9225c087a373facae525dc8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Apr 2025 18:21:03 +0000 Subject: [PATCH 0870/1769] Bump actions/setup-python from 5.5.0 to 5.6.0 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.5.0 to 5.6.0. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/8d9ed9ac5c53483de85588cdf95a591a75ab9f55...a26af69be951a213d495a4c3e4e4022e16d87065) --- updated-dependencies: - dependency-name: actions/setup-python dependency-version: 5.6.0 dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 12 ++++++------ .github/workflows/jax-array-api.yml | 2 +- .github/workflows/upstream-nightly.yml | 2 +- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 7bb31f9d0327..09f169548796 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -31,7 +31,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: 3.11 - run: python -m pip install pre-commit @@ -70,7 +70,7 @@ jobs: apt update apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -108,7 +108,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -142,7 +142,7 @@ jobs: apt update apt install -y libssl-dev libsqlite3-dev build-essential - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -168,7 +168,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -201,7 +201,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: 3.12 - name: Install JAX diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 7df4228dd2a3..c062970e6b1e 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -32,7 +32,7 @@ jobs: submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index ba2c750f8a8a..47f0ae0689b2 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -33,7 +33,7 @@ jobs: steps: - uses: 
actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index a2b3aeddc24a..d15a2305d3dc 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 5a435023ffda..b186d4315b02 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' From 934d13c09ba495869ca8afe87bc74d4c1a8cf3eb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Apr 2025 12:03:01 -0700 Subject: [PATCH 0871/1769] [Mosaic TPU] Add support for narrow integer `arith.constant`s in kernels PiperOrigin-RevId: 752396496 --- tests/pallas/tpu_pallas_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 908bdab54027..fccf1e42b932 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1776,6 +1776,23 @@ def reduce(): reduce_value = jnp.sum(jnp.full(shape, x), dtype=dty) np.testing.assert_allclose(z, reduce_value) + def test_sum_in_smem(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 30): + self.skipTest("Needs a newer libTPU") + def kernel(x, out): + a = jnp.array(0, dtype=jnp.int32) + for i in range(4): + for j in range(4): + out[i, j] = a.astype(out.dtype) + a += x[i, j].astype(jnp.int32) + + x = jnp.ones((4, 4), jnp.int16) + spec = pl.BlockSpec(memory_space=pltpu.SMEM) + y = pl.pallas_call(kernel, in_specs=[spec], out_specs=spec, out_shape=x)(x) + np.testing.assert_array_equal( + y, jnp.arange(16, dtype=jnp.int32).reshape(4, 4) + ) + @parameterized.parameters([ dict( m=m, From 1376ae3aceef05f06741d7ea603e076cf453130b Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 28 Apr 2025 12:13:36 -0700 Subject: [PATCH 0872/1769] Add sharding devices to XlaCompileOptions and plumb them through from JAX. This is necessary to support MPMD parallelism in McJAX, since the PjRt-IFRT executable's output shardings can no longer be built with the addressable devices from the PJRT executable, in the case where the executable has no addressable devices. 
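Concretely, `Client.compile` and `Client.deserialize_executable` now take the device list that the executable's shardings should be built against, and the JAX callers below thread it through (guarded on `jaxlib_extension_version`). A hedged sketch of the new calling convention; `mlir_module_str` and `options` are placeholders:

```
import jax
from jax._src.lib import xla_client as xc

backend = jax.devices()[0].client
executable_devices = xc.DeviceList(tuple(jax.devices()))
executable = backend.compile(mlir_module_str, executable_devices, options)
```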
PiperOrigin-RevId: 752400192 --- jax/_src/compilation_cache.py | 11 +- jax/_src/compiler.py | 48 ++++-- jax/_src/interpreters/pxla.py | 9 +- .../jax2tf/tests/sharding_test.py | 11 +- jax/experimental/jax2tf/tests/tf_test_util.py | 5 +- jax/experimental/serialize_executable.py | 8 +- jaxlib/BUILD | 1 + jaxlib/_jax/__init__.pyi | 2 + jaxlib/_jax/ifrt_programs.pyi | 1 + jaxlib/py_client.cc | 139 +++++++++++++++--- jaxlib/py_client.h | 8 +- jaxlib/py_compile_only_client.cc | 32 ++-- jaxlib/py_program.cc | 14 +- jaxlib/xla_client.py | 2 +- tests/compilation_cache_test.py | 37 +++-- 15 files changed, 266 insertions(+), 62 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index e8f7c9f7509c..aa1bd6ab65ba 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -31,6 +31,7 @@ from jax._src import config from jax._src import monitoring from jax._src.compilation_cache_interface import CacheInterface +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lru_cache import LRUCache @@ -207,7 +208,7 @@ def is_executable_in_cache(backend, cache_key: str) -> bool: def get_executable_and_time( - cache_key: str, compile_options, backend + cache_key: str, compile_options, backend, executable_devices ) -> tuple[xla_client.LoadedExecutable | None, int | None]: """Returns the cached executable and its compilation time if present, or None otherwise. @@ -223,8 +224,12 @@ def get_executable_and_time( executable_and_time = decompress_executable(executable_and_time) serialized_executable, compile_time = extract_executable_and_time( executable_and_time) - xla_executable_deserialized = backend.deserialize_executable( - serialized_executable, compile_options) + if jaxlib_extension_version < 332: + xla_executable_deserialized = backend.deserialize_executable( + serialized_executable, compile_options) + else: + xla_executable_deserialized = backend.deserialize_executable( + serialized_executable, executable_devices, compile_options) return xla_executable_deserialized, compile_time diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 3d2ed0ccd050..04f993fed799 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -35,6 +35,7 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -288,6 +289,7 @@ def get_compile_options( def backend_compile( backend: xc.Client, module: ir.Module, + executable_devices: xc.DeviceList, options: xc.CompileOptions, host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: @@ -312,16 +314,26 @@ def backend_compile( ) try: + if jaxlib_extension_version < 332: + if host_callbacks: + return backend.compile( + built_c, compile_options=options, host_callbacks=host_callbacks) # type: ignore + return backend.compile(built_c, compile_options=options) # type: ignore + # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results if host_callbacks: return backend.compile( - built_c, compile_options=options, host_callbacks=host_callbacks + built_c, + executable_devices=executable_devices, # type: ignore + compile_options=options, + host_callbacks=host_callbacks, ) # Some backends don't have `host_callbacks` option yet # TODO(sharadmv): remove this fallback when all backends allow `compile` # to take in `host_callbacks` - return 
backend.compile(built_c, compile_options=options) + return backend.compile( + built_c, executable_devices=executable_devices, compile_options=options) # type: ignore except xc.XlaRuntimeError as e: for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: handler_result = error_handler(e) @@ -357,6 +369,7 @@ def compile_or_get_cached( devices: np.ndarray, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], + executable_devices: xc.DeviceList, pgle_profiler: profiler.PGLEProfiler | None = None, ) -> xc.LoadedExecutable: sym_name = computation.operation.attributes['sym_name'] @@ -385,14 +398,15 @@ def compile_or_get_cached( ) if cache_key is None: - return backend_compile(backend, computation, compile_options, - host_callbacks) + return backend_compile( + backend, computation, executable_devices, compile_options, + host_callbacks) monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( - module_name, cache_key, compile_options, backend) + module_name, cache_key, compile_options, backend, executable_devices) cache_retrieval_time = time.monotonic() - cache_retrieval_start if retrieved_executable is not None: @@ -420,6 +434,7 @@ def compile_or_get_cached( return _compile_and_share_module( backend, computation, + executable_devices, compile_options, host_callbacks, distributed.global_state.client, @@ -432,6 +447,7 @@ def compile_or_get_cached( return _compile_and_write_cache( backend, computation, + executable_devices, compile_options, host_callbacks, module_name, @@ -631,11 +647,13 @@ def _share_fdo_profiles( _share_fdo_profiles.modules_profiles = {} + # The process with the first_process_id should compile the module and write it # to the K-V storage. 
def _compile_and_share_module( backend: xc.Client, computation: ir.Module, + executable_devices: xc.DeviceList, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], global_client: lib._jax.DistributedRuntimeClient, @@ -654,6 +672,7 @@ def _compile_and_share_module( executable = _compile_and_write_cache( backend, computation, + executable_devices, compile_options, host_callbacks, module_name, @@ -673,18 +692,24 @@ def _compile_and_share_module( serialized_executable = compilation_cache.decompress_executable( serialized_executable ) - executable = backend.deserialize_executable( - serialized_executable, compile_options - ) + if jaxlib_extension_version < 332: + executable = backend.deserialize_executable( + serialized_executable, compile_options) # type: ignore + else: + executable = backend.deserialize_executable( + serialized_executable, executable_devices, compile_options) # type: ignore _compile_and_share_module.modules_cache[cache_key] = executable return executable + _compile_and_share_module.modules_cache = {} + def _compile_and_write_cache( backend: xc.Client, computation: ir.Module, + executable_devices: xc.DeviceList, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], module_name: str, @@ -692,7 +717,7 @@ def _compile_and_write_cache( ) -> xc.LoadedExecutable: start_time = time.monotonic() executable = backend_compile( - backend, computation, compile_options, host_callbacks + backend, computation, executable_devices, compile_options, host_callbacks ) compile_time = time.monotonic() - start_time _cache_write( @@ -700,6 +725,7 @@ def _compile_and_write_cache( ) return executable + def _is_executable_in_cache(backend, cache_key) -> bool: """Checks if executable is presented in cache on a given key """ @@ -716,14 +742,14 @@ def _is_executable_in_cache(backend, cache_key) -> bool: def _cache_read( module_name: str, cache_key: str, compile_options: xc.CompileOptions, - backend: xc.Client + backend: xc.Client, executable_devices: xc.DeviceList, ) -> tuple[xc.LoadedExecutable | None, int | None]: """Looks up the `computation` and it's compilation time in the persistent compilation cache repository. """ try: return compilation_cache.get_executable_and_time( - cache_key, compile_options, backend) + cache_key, compile_options, backend, executable_devices) except Exception as ex: if config.raise_persistent_cache_errors.value: raise diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6c9c6d2aad68..b41bcd94165a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1099,9 +1099,14 @@ def from_hlo(hlo: ir.Module, with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec", fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT): + # `executable_devices` contains devices for output shardings of a pmapped + # function. It contains only local devices for correspondence with + # `PmapSharding`s, which also contain only local devices. 
+ executable_devices = _create_da_object( + tuple(local_device_assignment.flat)) compiled = compiler.compile_or_get_cached( pci.backend, hlo, device_assignment, compile_options, - host_callbacks) + host_callbacks, executable_devices) return UnloadedPmapExecutable( compiled=compiled, @@ -2792,7 +2797,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( backend, computation, dev, compile_options, host_callbacks, - pgle_profiler) + da, pgle_profiler) return xla_executable diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index cb28ab9f0dd1..55ccb1328c87 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -33,6 +33,8 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import jaxlib_extension_version +from jax._src.lib import xla_client as xc from jax import lax from jax.experimental import jax2tf from jax.experimental import pjit @@ -109,8 +111,13 @@ def log_jax_hlo(self, f_jax, args: Sequence[Any], *, device_assignment=device_assignment, use_spmd_partitioning=use_spmd_partitioning, ) - jax_optimized_hlo = backend.compile( - jax_hlo, compile_options).hlo_modules()[0].to_string() + if jaxlib_extension_version < 332: + executable = backend.compile( + jax_hlo, compile_options=compile_options) # type: ignore + else: + executable = backend.compile( + jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore + jax_optimized_hlo = executable.hlo_modules()[0].to_string() logging.info("[%s] got JAX optimized HLO for platform %s %s", self._testMethodName, backend.platform, jax_optimized_hlo) diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 32f89e533daf..e87a8af5d15e 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -34,6 +34,7 @@ from jax import export from jax._src import config from jax._src import xla_bridge +from jax._src.lib import xla_client as xc import numpy as np import tensorflow as tf from tensorflow.compiler.xla import xla_data_pb2 @@ -344,7 +345,9 @@ def log_message(extra): tf_hlo) backend = xla_bridge.get_backend() - modules = backend.compile(str(jax_lowered.compiler_ir())).hlo_modules() + device_list = xc.DeviceList(tuple(backend.local_devices())) + modules = backend.compile( + str(jax_lowered.compiler_ir()), device_list).hlo_modules() jax_opt_hlo = modules[0].to_string() logging.info("[%s] JAX OPT HLO\n%s", self._testMethodName, jax_opt_hlo) diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 2d65141a22ea..0a1f32af322c 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -19,6 +19,7 @@ import io import jax +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc @@ -84,7 +85,12 @@ def __init__(self, file, backend): def persistent_load(self, pid): if pid[0] == 'exec': - return self.backend.deserialize_executable(pid[1]) + if jaxlib_extension_version < 332: + return self.backend.deserialize_executable(pid[1]) + return self.backend.deserialize_executable( + pid[1], + executable_devices=xc.DeviceList(tuple(self.backend.devices())) + ) if pid[0] == 'device': return self.devices_by_id[pid[1]] if pid[0] == 
'client': diff --git a/jaxlib/BUILD b/jaxlib/BUILD index cd933868b6e3..50d3d28c05bf 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -831,6 +831,7 @@ cc_library( "@xla//xla/python:nb_numpy", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:types", + "@xla//xla/python:version", "@xla//xla/python/compile_only_ifrt:client", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:attribute_map", diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 9c5b493f3d60..00a3a5d01fec 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -506,6 +506,7 @@ class Client: def compile( self, computation: str | bytes, + executable_devices: DeviceList | Sequence[Device], compile_options: CompileOptions = ..., host_callbacks: Sequence[Any] = ..., ) -> LoadedExecutable: ... @@ -518,6 +519,7 @@ class Client: def deserialize_executable( self, serialized: bytes, + executable_devices: DeviceList | Sequence[Device], options: CompileOptions | None, host_callbacks: Sequence[Any] = ..., ) -> LoadedExecutable: ... diff --git a/jaxlib/_jax/ifrt_programs.pyi b/jaxlib/_jax/ifrt_programs.pyi index 6fcd6525a95b..5e426b070c21 100644 --- a/jaxlib/_jax/ifrt_programs.pyi +++ b/jaxlib/_jax/ifrt_programs.pyi @@ -38,6 +38,7 @@ def make_colocated_python_compile_options() -> CompileOptions: ... def make_xla_compile_options( compile_options: _jax.CompileOptions, + executable_devices: _jax.DeviceList, host_callbacks: Sequence[Any] ) -> CompileOptions: ... diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index e8251939592f..6caf478a4324 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -52,6 +52,7 @@ limitations under the License. #include "jaxlib/nb_class_ptr.h" #include "jaxlib/py_array.h" #include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" #include "jaxlib/py_executable.h" #include "jaxlib/py_host_callback.h" #include "jaxlib/py_memory_space.h" @@ -71,6 +72,7 @@ limitations under the License. #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/hlo/hlo_program.h" @@ -84,6 +86,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/types.h" +#include "xla/python/version.h" #include "xla/service/platform_util.h" // IWYU pragma: keep #include "xla/shape.h" #include "xla/status_macros.h" @@ -361,7 +364,8 @@ namespace { // Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host // callbacks. std::unique_ptr MakeIfrtCompileOptions( - CompileOptions options, std::vector host_callbacks) { + CompileOptions options, ifrt::DeviceListRef executable_devices, + std::vector host_callbacks) { std::vector> ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); @@ -371,14 +375,21 @@ std::unique_ptr MakeIfrtCompileOptions( ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); } +#if JAX_IFRT_VERSION_NUMBER >= 6 + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else return std::make_unique( std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif } // Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and // optional host callbacks. 
std::unique_ptr MakeIfrtDeserializeExecutableOptions(std::optional options, + ifrt::DeviceListRef executable_devices, std::vector host_callbacks) { std::vector> ifrt_loaded_host_callbacks; @@ -389,8 +400,14 @@ MakeIfrtDeserializeExecutableOptions(std::optional options, ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); } +#if JAX_IFRT_VERSION_NUMBER >= 6 + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else return std::make_unique( std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif } } // namespace @@ -447,7 +464,8 @@ PyClient::CompileIfrtProgram( /* static */ absl::StatusOr> PyClient::Compile( nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); @@ -458,12 +476,14 @@ PyClient::CompileIfrtProgram( } return CompileIfrtProgram( client, std::make_unique(module.get()), - MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks))); + MakeIfrtCompileOptions(std::move(options), std::move(executable_devices), + std::move(host_callbacks))); } /* static */ absl::StatusOr> PyClient::Compile( nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); @@ -483,8 +503,14 @@ PyClient::CompileIfrtProgram( client->ifrt_client(), std::move(host_callback)); ifrt_loaded_host_callbacks.push_back(callback); } +#if JAX_IFRT_VERSION_NUMBER >= 6 + auto compile_options = std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else auto compile_options = std::make_unique( std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif return CompileIfrtProgram( client, std::make_unique(module.get()), std::move(compile_options)); @@ -500,12 +526,14 @@ absl::StatusOr PyClient::SerializeExecutable( /* static */ absl::StatusOr> PyClient::DeserializeExecutable(nb_class_ptr client, nb::bytes serialized, + ifrt::DeviceListRef executable_devices, std::optional options, std::vector host_callbacks) { std::unique_ptr ifrt_loaded_executable; std::optional fingerprint; auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( - std::move(options), std::move(host_callbacks)); + std::move(options), std::move(executable_devices), + std::move(host_callbacks)); { nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN( @@ -733,45 +761,96 @@ PyType_Slot PyClient::slots_[] = { .def( "compile", [](nb_class_ptr client, nb::bytes mlir_module, - CompileOptions options, std::vector host_callbacks) { + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); return ValueOrThrow(PyClient::Compile( std::move(client), std::string(mlir_module.c_str(), mlir_module.size()), - std::move(options), std::move(host_callbacks))); + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); }, - nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + 
nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), nb::arg("host_callbacks") = std::vector()) .def( "compile", [](nb_class_ptr client, nb::bytes mlir_module, - CompileOptions options, std::vector host_callbacks) { + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); return ValueOrThrow(PyClient::Compile( std::move(client), std::string(mlir_module.c_str(), mlir_module.size()), - std::move(options), std::move(host_callbacks))); + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); }, - nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), nb::arg("host_callbacks") = std::vector()) .def( "compile", [](nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks) { + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); return ValueOrThrow(PyClient::Compile( - std::move(client), std::move(mlir_module), std::move(options), + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), std::move(host_callbacks))); }, - nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), nb::arg("host_callbacks") = std::vector()) .def( "compile", [](nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks) { + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); return ValueOrThrow(PyClient::Compile( - std::move(client), std::move(mlir_module), std::move(options), + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), std::move(host_callbacks))); }, - nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), nb::arg("host_callbacks") = std::vector()) + // The following two overloads are for users of deprecated APIs who call + // `backend.compile` but do not have visibility to `DeviceList`. 
+ .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + nb::sequence& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + nb::sequence& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) .def("compile_ifrt_program", xla::ValueOrThrowWrapper(PyClient::CompileIfrtProgram)) .def("serialize_executable", @@ -779,14 +858,36 @@ PyType_Slot PyClient::slots_[] = { .def( "deserialize_executable", [](nb_class_ptr client, nb::bytes serialized, + jax::PyDeviceList& py_executable_devices, std::optional options, std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); return ValueOrThrow(PyClient::DeserializeExecutable( - std::move(client), std::move(serialized), std::move(options), + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), std::move(host_callbacks))); }, - nb::arg("serialized"), nb::arg("compile_options").none() = nb::none(), + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none(), nb::arg("host_callbacks") = std::vector()) + // The following overload is for users of deprecated APIs who call + // `deserialize_executable` but do not have visibility to `DeviceList`. + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + nb::sequence& py_executable_devices, + std::optional options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none()) .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) // TODO(zhangqiaorjc): Experimental. .def("defragment", diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h index 3bc7057bc4ab..50529fac5c7e 100644 --- a/jaxlib/py_client.h +++ b/jaxlib/py_client.h @@ -41,6 +41,7 @@ limitations under the License. 
#include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/program.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/shape.h" @@ -168,16 +169,19 @@ class PyClient { static absl::StatusOr> Compile( nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks); + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks); static absl::StatusOr> Compile( nb_class_ptr client, std::string mlir_module, - CompileOptions options, std::vector host_callbacks); + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks); absl::StatusOr SerializeExecutable( const PyLoadedExecutable& executable) const; static absl::StatusOr> DeserializeExecutable( nb_class_ptr client, nanobind::bytes serialized, + ifrt::DeviceListRef executable_devices, std::optional options, std::vector host_callbacks); diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index f9914edac52a..4d53fc6ee832 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -32,15 +32,18 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/nb_class_ptr.h" #include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/status_casters.h" #include "xla/python/compile_only_ifrt/client.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/version.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -68,8 +71,8 @@ class CompileOnlyPyClient : public PyClient { } absl::StatusOr> CompileUnloaded( - absl::string_view mlir_module, CompileOptions options, - std::vector host_callbacks) { + absl::string_view mlir_module, ifrt::DeviceListRef executable_devices, + CompileOptions options, std::vector host_callbacks) { if (!host_callbacks.empty()) { return Unimplemented( "Compiling with host_callbacks not available with compile-only " @@ -88,7 +91,12 @@ class CompileOnlyPyClient : public PyClient { llvm::dyn_cast_or_null(this->ifrt_client()); CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " "CompileOnlyIfRtClient"; +#if JAX_IFRT_VERSION_NUMBER >= 6 + auto xla_options = std::make_unique( + options, std::move(executable_devices)); +#else auto xla_options = std::make_unique(options); +#endif TF_ASSIGN_OR_RETURN(auto executable, PjRtCompile(std::move(options), module.get(), *ifrt_client->topology().description())); @@ -115,17 +123,23 @@ void RegisterCompileOnlyClient(nb::module_& m) { .def( "compile", [](CompileOnlyPyClient& self, nb::bytes mlir_module, - CompileOptions options, std::vector host_callbacks) { + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); return ValueOrThrow(self.CompileUnloaded( absl::string_view(mlir_module.c_str(), mlir_module.size()), - std::move(options), std::move(host_callbacks))); + std::move(executable_devices), std::move(options), + 
std::move(host_callbacks))); }, - nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), nb::arg("host_callbacks") = std::vector()) - .def( - "compile", ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), - nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), - nb::arg("host_callbacks") = std::vector()); + .def("compile", + ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()); } } // namespace xla diff --git a/jaxlib/py_program.cc b/jaxlib/py_program.cc index d01df5e82b1b..8c57bd0515b8 100644 --- a/jaxlib/py_program.cc +++ b/jaxlib/py_program.cc @@ -59,6 +59,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/python/types.h" +#include "xla/python/version.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" @@ -223,7 +224,8 @@ absl::StatusOr> MakeHloProgramFromBytes( } absl::StatusOr> MakeXlaCompileOptions( - CompileOptions options, std::vector host_callbacks) { + CompileOptions options, jax::PyDeviceList& py_executable_devices, + std::vector host_callbacks) { std::vector> ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); @@ -234,8 +236,16 @@ absl::StatusOr> MakeXlaCompileOptions( ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); } +#if JAX_IFRT_VERSION_NUMBER >= 6 + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef executable_devices, + py_executable_devices.ifrt_device_list()); + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else return std::make_unique( std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif } constexpr absl::string_view kColocatedPythonProgramType = @@ -281,7 +291,7 @@ void BuildIfrtProgramsSubmodule(nanobind::module_& m) { ValueOrThrowWrapper(MakePluginProgramFromBytes), nb::arg("data")) .def("make_xla_compile_options", ValueOrThrowWrapper(MakeXlaCompileOptions), nb::arg("options"), - nb::arg("host_callbacks")) + nb::arg("executable_devices"), nb::arg("host_callbacks")) .def("make_colocated_python_compile_options", ValueOrThrowWrapper(MakeColocatedPythonCompileOptions)) .def("make_plugin_compile_options", diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 4521019fa77a..1d7f6cd8f584 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 331 +_version = 332 # An internal increasing version number for protecting jaxlib code against # ifrt changes. 
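For reference, here is a minimal sketch of the new calling convention that the test updates below exercise. This is an editor's illustration, not part of the patch: the identity StableHLO module is only a placeholder, and it assumes a jaxlib build that already includes this change.

```python
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc

backend = xla_bridge.get_backend()
# Executables are now compiled against an explicit, caller-supplied DeviceList.
executable_devices = xc.DeviceList(tuple(backend.local_devices()))

# Placeholder computation: an identity function over a scalar f32.
mlir_module = """
module {
  func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
}
"""
compile_options = xc.CompileOptions()
executable = backend.compile(mlir_module, executable_devices, compile_options)

# Serialization round-trips thread the same device list through.
serialized = backend.serialize_executable(executable)
restored = backend.deserialize_executable(serialized, executable_devices, None)
assert executable.fingerprint == restored.fingerprint
```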
diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 3fcc0ab476bf..d9f8cdddf0f1 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -134,7 +134,7 @@ def test_get_no_executable(self): backend = xla_bridge.get_backend() key = cc.get_cache_key(computation, devices, compile_options, backend) executable, compile_time = cc.get_executable_and_time( - key, compile_options, backend) + key, compile_options, backend, xc.DeviceList(tuple(devices.flat))) self.assertIsNone(executable) self.assertIsNone(compile_time) @@ -145,15 +145,24 @@ def test_diff_executables(self): num_replicas=1, num_partitions=1 ) backend = xla_bridge.get_backend() - executable1 = backend.compile(computation1, compile_options) - executable2 = backend.compile(computation2, compile_options) + executable_devices = xc.DeviceList(tuple(backend.local_devices())) + if jax._src.lib.jaxlib_extension_version < 331: + executable1 = backend.compile(computation1, compile_options) + executable2 = backend.compile(computation2, compile_options) + else: + executable1 = backend.compile( + computation1, executable_devices, compile_options) + executable2 = backend.compile( + computation2, executable_devices, compile_options) cc.put_executable_and_time( "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) cc.put_executable_and_time( "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) self.assertNotEqual( - cc.get_executable_and_time("key1", compile_options, backend)[0], - cc.get_executable_and_time("key2", compile_options, backend)[0] + cc.get_executable_and_time( + "key1", compile_options, backend, executable_devices)[0], + cc.get_executable_and_time( + "key2", compile_options, backend, executable_devices)[0] ) def test_put_executable(self): @@ -167,12 +176,17 @@ def test_put_executable(self): num_replicas=1, num_partitions=1 ) backend = xla_bridge.get_backend() - executable = backend.compile(str(computation), compile_options) + executable_devices = xc.DeviceList(tuple(devices.flat)) + if jax._src.lib.jaxlib_extension_version < 331: + executable = backend.compile(str(computation), compile_options) + else: + executable = backend.compile( + str(computation), executable_devices, compile_options) key = cc.get_cache_key(computation, devices, compile_options, backend) cc.put_executable_and_time( key, "alambda", executable, backend, FAKE_COMPILE_TIME) executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( - key, compile_options, backend) + key, compile_options, backend, executable_devices) inputs_to_executable = ( jnp.array(1, dtype=np.int32), jnp.array(2, dtype=np.int32), @@ -562,8 +576,13 @@ def test_backend_serialization_deserialization(self): .runtime_executable() ) serialized_executable = backend.serialize_executable(executable) - deserialized_executable = backend.deserialize_executable( - serialized_executable, None) + if jax._src.lib.jaxlib_extension_version < 331: + deserialized_executable = backend.deserialize_executable( # type: ignore + serialized_executable, None) + else: + deserialized_executable = backend.deserialize_executable( # type: ignore + serialized_executable, + xc.DeviceList(tuple(jax.local_devices(backend=backend))), None) self.assertEqual( executable.fingerprint, deserialized_executable.fingerprint) From 4c2c6a2b3cbb4b2149b275527a66477784016788 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 28 Apr 2025 12:14:32 -0700 Subject: [PATCH 0873/1769] Add `standard_insert_pvary` support to `reduce`. 
Fixes: https://github.com/jax-ml/jax/issues/28334
PiperOrigin-RevId: 752400496
---
 jax/_src/lax/lax.py     | 2 ++
 tests/shard_map_test.py | 9 +++++++++
 2 files changed, 11 insertions(+)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 6d9176a3e2d8..2034026e197a 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -2964,6 +2964,8 @@ def reduce(operands: Any,
   flat_init_avals = safe_map(core.get_aval, flat_init_values)
   closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
       computation, comp_debug, tuple(flat_init_avals), init_value_tree)
+  flat_operands = core.standard_insert_pvary(*flat_operands)
+  flat_init_values = core.standard_insert_pvary(*flat_init_values)
   out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
                       jaxpr=closed_jaxpr, dimensions=tuple(dimensions))
   return tree_util.tree_unflatten(out_tree, out)
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 5320c12e75cd..ada327d936d6 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -3068,6 +3068,15 @@ def m2(p, t):
     shard_map(partial(m2, jnp.array([1.])), mesh=mesh,
               in_specs=P('x'), out_specs=P('x'))(jnp.ones((2,)))  # doesn't crash

+  @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2)
+  def test_argmax_pvary(self, mesh):
+    @jax.shard_map(in_specs=P('x', 'y'), out_specs=P('x', 'y'))
+    def argmax_impl(x):
+      y = x.argmax(axis=-1, keepdims=True)
+      return y
+
+    argmax_impl(jax.random.normal(jax.random.key(0), (1024, 1024)))  # doesn't crash
+

 class FunSpec(NamedTuple):
   name: str

From 2792c6809cbee7abab505bd374577c23bec11a4c Mon Sep 17 00:00:00 2001
From: Kanglan Tang
Date: Mon, 28 Apr 2025 12:39:54 -0700
Subject: [PATCH 0874/1769] [JAX] Add a python 3.14 requirement lock file and
 update WORKSPACE

PiperOrigin-RevId: 752409245
---
 WORKSPACE                               |  36 ++++--
 build/nonfreethreading-requirements.txt |   3 +-
 build/requirements_lock_3_14.txt        | 141 ++++++++++++++++++++++++
 3 files changed, 169 insertions(+), 11 deletions(-)
 create mode 100644 build/requirements_lock_3_14.txt

diff --git a/WORKSPACE b/WORKSPACE
index 5c093ec2228f..f9c0b3ccfea7 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,70 +1,85 @@
 # The XLA commit is determined by third_party/xla/workspace.bzl.
load("//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") + jax_xla_workspace() # Initialize hermetic Python load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") + python_init_rules() load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") + python_init_repositories( + default_python_version = "system", + local_wheel_dist_folder = "../dist", + local_wheel_inclusion_list = [ + "jax-*", + "jaxlib*", + "jax_cuda*", + "jax-cuda*", + ], + local_wheel_workspaces = ["//jaxlib:jax.bzl"], requirements = { "3.10": "//build:requirements_lock_3_10.txt", "3.11": "//build:requirements_lock_3_11.txt", "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", + "3.14": "//build:requirements_lock_3_14.txt", "3.13-ft": "//build:requirements_lock_3_13_ft.txt", "3.14-ft": "//build:requirements_lock_3_14_ft.txt", }, - local_wheel_inclusion_list = [ - "jax-*", - "jaxlib*", - "jax_cuda*", - "jax-cuda*", - ], - local_wheel_workspaces = ["//jaxlib:jax.bzl"], - local_wheel_dist_folder = "../dist", - default_python_version = "system", ) load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") + python_init_toolchains() load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") + python_init_pip() load("@pypi//:requirements.bzl", "install_deps") + install_deps() # Optional, to facilitate testing against newest versions of Python load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter") + custom_python_interpreter( name = "python_dev", - urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"], strip_prefix = "Python-{version_variant}", + urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"], version = "3.13.0", version_variant = "3.13.0rc2", ) load("@xla//:workspace4.bzl", "xla_workspace4") + xla_workspace4() load("@xla//:workspace3.bzl", "xla_workspace3") + xla_workspace3() load("@xla//:workspace2.bzl", "xla_workspace2") + xla_workspace2() load("@xla//:workspace1.bzl", "xla_workspace1") + xla_workspace1() load("@xla//:workspace0.bzl", "xla_workspace0") + xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") + flatbuffers() load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") + jax_python_wheel_repository( name = "jax_wheel", version_key = "_version", @@ -75,6 +90,7 @@ load( "@xla//third_party/py:python_wheel.bzl", "python_wheel_version_suffix_repository", ) + python_wheel_version_suffix_repository( name = "jax_wheel_version_suffix", ) diff --git a/build/nonfreethreading-requirements.txt b/build/nonfreethreading-requirements.txt index 19b9cb51686d..86f5e64d1973 100644 --- a/build/nonfreethreading-requirements.txt +++ b/build/nonfreethreading-requirements.txt @@ -1,5 +1,6 @@ numpy~=2.0.0; python_version<="3.12" -numpy~=2.1.0; python_version>="3.13" +numpy~=2.1.0; python_version=="3.13" +numpy~=2.2.5; python_version>="3.14" # These packages have not released free-threaded wheels. 
zstandard diff --git a/build/requirements_lock_3_14.txt b/build/requirements_lock_3_14.txt new file mode 100644 index 000000000000..6edcd30ebe16 --- /dev/null +++ b/build/requirements_lock_3_14.txt @@ -0,0 +1,141 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile --output-file=build/requirements_lock_3_14.txt build/requirements.in build/nonfreethreading-requirements.txt build/test-requirements.txt build/gpu-test-requirements.txt +absl-py==2.2.2 + # via -r build/test-requirements.txt +attrs==25.3.0 + # via hypothesis +auditwheel==6.3.0 + # via -r build/test-requirements.txt +build==1.2.2.post1 + # via -r build/test-requirements.txt +cloudpickle==3.1.1 + # via -r build/test-requirements.txt +colorama==0.4.6 + # via -r build/test-requirements.txt +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +etils==1.12.2 + # via -r build/requirements.in +execnet==2.1.1 + # via pytest-xdist +filelock==3.18.0 + # via -r build/test-requirements.txt +flatbuffers==25.2.10 + # via -r build/test-requirements.txt +fonttools==4.57.0 + # via matplotlib +fsspec==2025.3.2 + # via etils +hypothesis==6.131.9 + # via -r build/test-requirements.txt +importlib-resources==6.5.2 + # via etils +iniconfig==2.1.0 + # via pytest +kiwisolver==1.4.8 + # via matplotlib +markdown-it-py==3.0.0 + # via rich +matplotlib==3.10.1 + # via -r build/test-requirements.txt +mdurl==0.1.2 + # via markdown-it-py +ml-dtypes==0.5.1 + # via + # -r build/requirements.in + # tensorstore +mpmath==1.4.0a4 + # via -r build/test-requirements.txt +numpy==2.2.5 + # via + # -r build/nonfreethreading-requirements.txt + # contourpy + # matplotlib + # ml-dtypes + # scipy + # tensorstore +nvidia-cublas-cu12==12.8.4.1 + # via + # -r build/gpu-test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.90 + # via -r build/gpu-test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.93 + # via -r build/gpu-test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.90 + # via -r build/gpu-test-requirements.txt +nvidia-cudnn-cu12==9.8.0.87 + # via -r build/gpu-test-requirements.txt +nvidia-cufft-cu12==11.3.3.83 + # via -r build/gpu-test-requirements.txt +nvidia-cusolver-cu12==11.7.3.90 + # via -r build/gpu-test-requirements.txt +nvidia-cusparse-cu12==12.5.8.93 + # via + # -r build/gpu-test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.26.2.post1 + # via -r build/gpu-test-requirements.txt +nvidia-nvjitlink-cu12==12.8.93 + # via + # -r build/gpu-test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +opt-einsum==3.4.0 + # via -r build/test-requirements.txt +packaging==25.0 + # via + # auditwheel + # build + # matplotlib + # pytest +pillow==11.2.1 + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.5.0 + # via pytest +portpicker==1.6.0 + # via -r build/nonfreethreading-requirements.txt +psutil==7.0.0 + # via portpicker +pyelftools==0.32 + # via auditwheel +pygments==2.19.1 + # via rich +pyparsing==3.2.3 + # via matplotlib +pyproject-hooks==1.2.0 + # via build +pytest==8.3.5 + # via pytest-xdist +pytest-xdist==3.6.1 + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 + # via matplotlib +rich==14.0.0 + # via -r build/test-requirements.txt +scipy==1.15.2 + # via -r build/requirements.in +setuptools==80.0.0 + # via + # -r build/requirements.in + # -r build/test-requirements.txt +six==1.17.0 + # via python-dateutil +sortedcontainers==2.4.0 + # via hypothesis +tensorstore==0.1.74 + # via -r 
build/nonfreethreading-requirements.txt +typing-extensions==4.13.2 + # via etils +wheel==0.45.1 + # via -r build/test-requirements.txt +zipp==3.21.0 + # via etils +zstandard==0.23.0 + # via -r build/nonfreethreading-requirements.txt From 8c1d3c53f547a42877f99b2721545d333d96937a Mon Sep 17 00:00:00 2001 From: gentlelovebear Date: Mon, 28 Apr 2025 13:01:34 -0700 Subject: [PATCH 0875/1769] [jax/docs] - Fix link in quickstart.md --- docs/quickstart.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index d2d9bf8cec41..ec9f3ccd3633 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -58,7 +58,7 @@ print(selu(x)) ``` You'll find a few differences between JAX arrays and NumPy arrays once you begin digging-in; -these are explored in [🔪 JAX - The Sharp Bits 🔪](https:docs.jax.devio/en/latest/notebooks/Common_Gotchas_in_JAX.html). +these are explored in [🔪 JAX - The Sharp Bits 🔪](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). ## Just-in-time compilation with {func}`jax.jit` JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the {func}`jax.jit` function to compile this sequence of operations together using XLA. From 936762f34ea55656454ac62c9408d235321386f1 Mon Sep 17 00:00:00 2001 From: Richard Levasseur Date: Mon, 28 Apr 2025 20:20:40 +0000 Subject: [PATCH 0876/1769] chore: load py_library from rules_python --- jaxlib/jax.bzl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 718401c1477f..632ff2047078 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -22,7 +22,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") -load("@rules_python//python:defs.bzl", "py_test") +load("@rules_python//python:defs.bzl", "py_library", "py_test") load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load("@xla//xla/tsl:tsl.bzl", "transitive_hdrs", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") @@ -144,17 +144,17 @@ jax2tf_deps = [] def pytype_library(name, pytype_srcs = None, **kwargs): _ = pytype_srcs # @unused - native.py_library(name = name, **kwargs) + py_library(name = name, **kwargs) def pytype_strict_library(name, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} - native.py_library(name = name, data = data, **new_kwargs) + py_library(name = name, data = data, **new_kwargs) -py_strict_library = native.py_library -py_strict_test = native.py_test +py_strict_library = py_library +py_strict_test = py_test -def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): +def py_library_providing_imports_info(*, name, lib_rule = py_library, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) 
    new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
    lib_rule(name = name, data = data, **new_kwargs)

From dbf298c1368135ed3db011b91afc53fb6e4a4bfe Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 28 Apr 2025 16:31:27 -0400
Subject: [PATCH 0877/1769] Disable logsumexp test under SciPy 1.15.

I found https://github.com/scipy/scipy/issues/22903 to reproduce in JAX's CI
in the right environment.
---
 tests/lax_scipy_test.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index 20ed169d9405..e0d3528dfa41 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -113,6 +113,11 @@ def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
     if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0):
       self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.")
+    if use_b and scipy_version >= (1, 15) and scipy_version < (1, 15, 3):
+      self.skipTest(
+          "TODO(https://github.com/scipy/scipy/issues/22903): logsumexp with a"
+          " b scale array is buggy in scipy 1.15"
+      )
     if not jtu.test_device_matches(["cpu", "gpu"]):
       rng = jtu.rand_some_inf_and_nan(self.rng())
     else:

From 8545f234308844ba6e9f51f8315f195e5f2f2c58 Mon Sep 17 00:00:00 2001
From: Michael Hudgins
Date: Mon, 28 Apr 2025 14:40:01 -0700
Subject: [PATCH 0878/1769] Change nightly installation instructions and CI to
 use new package index.

The old package index will be deprecated and will soon no longer get new
nightly builds added to it.

Issue https://github.com/jax-ml/jax/issues/5410

PiperOrigin-RevId: 752452919
---
 .github/workflows/cloud-tpu-ci-nightly.yml |  4 ++--
 docs/installation.md                       | 22 +++++++++++++++++-----
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml
index fd799a3f70b5..2346de94377c 100644
--- a/.github/workflows/cloud-tpu-ci-nightly.yml
+++ b/.github/workflows/cloud-tpu-ci-nightly.yml
@@ -88,14 +88,14 @@ jobs:
           elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
             $PYTHON -m uv pip install \
-              --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+              --pre . -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
               libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
               requests

           elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
             # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
             $PYTHON -m uv pip install \
-              --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+              --pre . -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
               libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
               -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
               requests
diff --git a/docs/installation.md b/docs/installation.md
index c9bf3a62942b..1314a2efa0a8 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -281,22 +281,34 @@ Unlike the instructions for installing a JAX release, here we name all of JAX's
 packages explicitly on the command line, so `pip` will upgrade them if a newer
 version is available.

+JAX publishes nightlies, release candidates (RCs), and releases to several non-PyPI [PEP 503](https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/) indexes.
+
+All JAX packages, as well as mirrored copies of their PyPI dependencies, can be reached
+from the index `https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/`. This additional mirroring enables nightly
This additional mirroring enables nightly +installation to use --index (-i) as the install method with pip. + +**Note:** The unified index could return an RC or release as the newest version +even with `--pre` immediately after a release before the newest nightly is +rebuilt. If automation or testing must be done against nightlies or you cannot +use our full index, use the extra index `https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/` +which only contains nightly artifacts. + - CPU only: ```bash -pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` - Google Cloud TPU: ```bash -pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` - NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` - NVIDIA GPU (CUDA 12) legacy: @@ -322,10 +334,10 @@ still be installed directly via the URLs here. For example: ```bash # Install jaxlib on CPU via the wheel archive -pip install "jax[cpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ # Install the jaxlib 0.3.25 CPU wheel directly -pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example ```bash From 836099f3380164b6c2dda3b8e54f5a12725bd81a Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 28 Apr 2025 14:58:34 -0700 Subject: [PATCH 0879/1769] Disable old JAX nightly builds PiperOrigin-RevId: 752458934 --- .github/workflows/cloud-tpu-ci-nightly.yml | 11 +++++++---- .github/workflows/wheel_win_x64.yml | 5 +++-- .github/workflows/windows_ci.yml | 14 ++++++++------ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 2346de94377c..cb9e04d4488d 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -11,10 +11,13 @@ # Github Actions environment). name: CI - Cloud TPU (nightly) -on: - schedule: - - cron: "0 2,14 * * *" # Run at 7am and 7pm PST - workflow_dispatch: # allows triggering the workflow run manually +# Disable the schedule; Slated for removal, the new test workflow is in +# "wheel_tests_nightly_release.yml" +# on: +# schedule: +# - cron: "0 2,14 * * *" # Run at 7am and 7pm PST +# workflow_dispatch: # allows triggering the workflow run manually + # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. 
permissions: diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index d15a2305d3dc..6539a50ce790 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -1,6 +1,7 @@ name: Wheel build - Windows CPU x86_64 -on: - workflow_dispatch: # allows triggering the workflow run manually +# Slated for removal, Windows release/nightly wheels are now built in the internal CI system. +# on: +# workflow_dispatch: # allows triggering the workflow run manually concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index b186d4315b02..7d848391e64d 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -1,10 +1,12 @@ name: CI - Windows CPU -on: - schedule: - - cron: "0 12 * * *" # Daily at 12:00 UTC - workflow_dispatch: # allows triggering the workflow run manually - pull_request: - types: [ labeled ] # allow force-windows-run label +# Disable the schedule; Slated for removal, the new test workflows are in +# "wheel_tests_nightly_release.yml" and "wheel_tests_continuous.yml" +# on: +# schedule: +# - cron: "0 12 * * *" # Daily at 12:00 UTC +# workflow_dispatch: # allows triggering the workflow run manually +# pull_request: +# types: [ labeled ] # allow force-windows-run label concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} From cf072b8b1de736c05f09a8b04341442f2b136f7f Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 28 Apr 2025 15:44:43 -0700 Subject: [PATCH 0880/1769] Small changes for easier testing of kernels with TPU interpret mode. PiperOrigin-RevId: 752474493 --- jax/_src/pallas/helpers.py | 12 +++++++++--- jax/_src/pallas/mosaic/core.py | 2 +- jax/_src/pallas/pallas_call.py | 16 ++++++++-------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 1b2649d4e987..684101e47e9e 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -23,7 +23,11 @@ @jax.named_call def empty( - shape: tuple[int, ...], dtype: jnp.dtype, *, memory_space: Any = None + shape: tuple[int, ...], + dtype: jnp.dtype, + *, + memory_space: Any = None, + interpret: Any = False, ): def _empty_kernel(_): # No-op to leave the out_ref uninitialized @@ -39,6 +43,7 @@ def _empty_kernel(_): in_specs=[], out_specs=pl_core.BlockSpec(memory_space=kernel_memory_space), out_shape=memory_space(shape, dtype), + interpret=interpret, )() @@ -47,8 +52,9 @@ class ArrayLike(Protocol): dtype: jnp.dtype -def empty_like(x: ArrayLike, *, memory_space: Any = None): - return empty(x.shape, x.dtype, memory_space=memory_space) +def empty_like( + x: ArrayLike, *, memory_space: Any = None, interpret: Any = False): + return empty(x.shape, x.dtype, memory_space=memory_space, interpret=interpret) def when(condition): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 18ad3029398e..d31de1c22089 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -224,7 +224,7 @@ def _tensorcore_mesh_discharge_rule( mesh, jaxpr, compiler_params: Any | None, - interpret: bool, + interpret: Any, debug: bool, cost_estimate: pallas_core.CostEstimate | None, name: str, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 19469824aa6a..286e49768cca 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -122,7 +122,7 @@ def 
_pallas_call_jvp_rule( grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -335,7 +335,7 @@ def _batch_with_explicit_loop( mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -433,7 +433,7 @@ def _pallas_call_batching_rule( mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -1032,7 +1032,7 @@ def pallas_call_checkify_rule(error: checkify.Error, enabled_errors, *args: jax_core.Value, jaxpr: jax_core.Jaxpr, - interpret: bool, + interpret: Any, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, out_avals: tuple[jax_core.AbstractValue, ...], @@ -1252,7 +1252,7 @@ def _unsupported_lowering_error(platform: str) -> Exception: def _pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, - interpret: bool, + interpret: Any, backend: Backend | None, **params, ): @@ -1366,7 +1366,7 @@ def _pallas_call_state_discharge_rule( grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -1490,7 +1490,7 @@ def pallas_call( scratch_shapes: ScratchShapeTree = (), input_output_aliases: Mapping[int, int] = {}, debug: bool = False, - interpret: bool = False, + interpret: Any = False, name: str | None = None, compiler_params: ( Mapping[Backend, CompilerParams] | CompilerParams | None @@ -1620,7 +1620,7 @@ def _pallas_call( mesh: pallas_core.Mesh | None = None, input_output_aliases: Mapping[int, int] = {}, debug: bool = False, - interpret: bool = False, + interpret: Any = False, name: str | None = None, compiler_params: ( Mapping[Backend, CompilerParams] | CompilerParams | None From e572dbe54134ee281b523078a0a7006a941c850c Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 28 Apr 2025 15:53:26 -0700 Subject: [PATCH 0881/1769] Deactivate automatic cluster detection in tests. Previously, the tests were detecting a kubernetes environment but failing to import the kubernetes module, which was failing the tests. --- tests/distributed_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/distributed_test.py b/tests/distributed_test.py index 5e47228c1719..ae72143fbe7d 100644 --- a/tests/distributed_test.py +++ b/tests/distributed_test.py @@ -41,7 +41,10 @@ def testInitializeAndShutdown(self): # concurrency to simulate multiple tasks. port = portpicker.pick_unused_port() jax.distributed.initialize( - coordinator_address=f"localhost:{port}", num_processes=1, process_id=0 + coordinator_address=f"localhost:{port}", + num_processes=1, + process_id=0, + cluster_detection_method="deactivate", ) jax.distributed.shutdown() @@ -55,7 +58,10 @@ def task(i): # We can't call the public APIs directly because they use global state. 
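      # Instead, each task constructs its own State instance and drives
      # initialize/shutdown on it directly.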
state = distributed.State() state.initialize( - coordinator_address=f"localhost:{port}", num_processes=n, process_id=i + coordinator_address=f"localhost:{port}", + num_processes=n, + process_id=i, + cluster_detection_method="deactivate", ) state.shutdown() From 01ee64d42ca1be67284731a7455e9c588b7ab5a5 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Mon, 28 Apr 2025 17:26:21 -0700 Subject: [PATCH 0882/1769] Update JAX nightly index usage Details in https://github.com/jax-ml/jax/discussions/28366 PiperOrigin-RevId: 752505159 --- .github/workflows/metal_plugin_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 7948337e9b6c..2135e473d6be 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -39,7 +39,7 @@ jobs: uv pip install -U pip numpy wheel absl-py pytest if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then uv pip install --pre jaxlib \ - -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ fi; cd jax uv pip install . jax-metal From 728048978ce1a3d982d0f34d4b95d0ebdf4b4b30 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 28 Apr 2025 18:02:51 -0700 Subject: [PATCH 0883/1769] Enter into correct mesh context in shard_map in `_partial_eval_jaxpr_custom_rule`. PiperOrigin-RevId: 752516299 --- jax/_src/shard_map.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 8313af3c2590..27f5956f0374 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1482,7 +1482,8 @@ def _partial_eval_jaxpr_custom_rule( list[core.Var]]: jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes'] - with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): + with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) num_out_primals = len(jaxpr_known.outvars) - num_res From 9d4f4aba28233e9ec31f145247d9ad9781f27e13 Mon Sep 17 00:00:00 2001 From: Yan Zhao Date: Tue, 29 Apr 2025 01:34:10 +0000 Subject: [PATCH 0884/1769] SPMD with jax.jit, namedsharding and partition spec Signed-off-by: Yan Zhao use jax.jit and replace the old example Signed-off-by: Yan Zhao --- examples/spmd_mnist_classifier_fromscratch.py | 112 ++++++++++-------- 1 file changed, 64 insertions(+), 48 deletions(-) diff --git a/examples/spmd_mnist_classifier_fromscratch.py b/examples/spmd_mnist_classifier_fromscratch.py index 3698314708c7..234ffac7de4c 100644 --- a/examples/spmd_mnist_classifier_fromscratch.py +++ b/examples/spmd_mnist_classifier_fromscratch.py @@ -12,33 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An MNIST example with single-program multiple-data (SPMD) data parallelism. - -The aim here is to illustrate how to use JAX's `pmap` to express and execute -SPMD programs for data parallelism along a batch dimension, while also -minimizing dependencies by avoiding the use of higher-level layers and -optimizers libraries. 
-""" - - -from functools import partial import time import numpy as np import numpy.random as npr import jax -from jax import jit, grad, pmap +from jax import jit, grad +from jax.sharding import PartitionSpec as P, NamedSharding from jax.scipy.special import logsumexp from jax.tree_util import tree_map -from jax import lax import jax.numpy as jnp -from examples import datasets +import datasets def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): - return [(scale * rng.randn(m, n), scale * rng.randn(n)) - for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] + return [ + (scale * rng.randn(m, n), scale * rng.randn(n)) + for m, n in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + def predict(params, inputs): activations = inputs @@ -50,11 +43,20 @@ def predict(params, inputs): logits = jnp.dot(activations, final_w) + final_b return logits - logsumexp(logits, axis=1, keepdims=True) + def loss(params, batch): inputs, targets = batch preds = predict(params, inputs) return -jnp.mean(jnp.sum(preds * targets, axis=1)) + +def train_step(params, batch): + grads = grad(loss)(params, batch) + return [ + (w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads) + ] + + @jit def accuracy(params, batch): inputs, targets = batch @@ -72,57 +74,71 @@ def accuracy(params, batch): train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] + + num_devices = jax.device_count() + print(f"Using {num_devices} devices") + + if batch_size % num_devices != 0: + batch_size = (batch_size // num_devices) * num_devices + print(f"Adjusting batch size to {batch_size} for divisibility") + num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) - # For this manual SPMD example, we get the number of devices (e.g. GPUs or - # TPU cores) that we're using, and use it to reshape data minibatches. - num_devices = jax.device_count() + devices = np.array(jax.devices()) + mesh = jax.make_mesh((jax.device_count(),), ("batch",)) + + replicated_sharding = NamedSharding(mesh, P()) + data_sharding = NamedSharding(mesh, P("batch")) + def data_stream(): rng = npr.RandomState(0) while True: perm = rng.permutation(num_train) for i in range(num_batches): - batch_idx = perm[i * batch_size:(i + 1) * batch_size] - images, labels = train_images[batch_idx], train_labels[batch_idx] - # For this SPMD example, we reshape the data batch dimension into two - # batch dimensions, one of which is mapped over parallel devices. - batch_size_per_device, ragged = divmod(images.shape[0], num_devices) - if ragged: - msg = "batch size must be divisible by device count, got {} and {}." 
-          raise ValueError(msg.format(batch_size, num_devices))
-        shape_prefix = (num_devices, batch_size_per_device)
-        images = images.reshape(shape_prefix + images.shape[1:])
-        labels = labels.reshape(shape_prefix + labels.shape[1:])
+        batch_idx = perm[i * batch_size : (i + 1) * batch_size]
+        images_np, labels_np = train_images[batch_idx], train_labels[batch_idx]
+
+        current_batch_size = images_np.shape[0]
+        if current_batch_size < batch_size:
+          pad_len = batch_size - current_batch_size
+          images_np = np.concatenate([images_np, images_np[:pad_len]], axis=0)
+          labels_np = np.concatenate([labels_np, labels_np[:pad_len]], axis=0)
+
+        images = jax.device_put(images_np, data_sharding)
+        labels = jax.device_put(labels_np, data_sharding)
         yield images, labels
+
   batches = data_stream()

-  @partial(pmap, axis_name='batch')
-  def spmd_update(params, batch):
-    grads = grad(loss)(params, batch)
-    # We compute the total gradients, summing across the device-mapped axis,
-    # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
-    grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
-    return [(w - step_size * dw, b - step_size * db)
-            for (w, b), (dw, db) in zip(params, grads)]
-
-  # We replicate the parameters so that the constituent arrays have a leading
-  # dimension of size equal to the number of devices we're pmapping over.
-  init_params = init_random_params(param_scale, layer_sizes)
-  replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape)
-  replicated_params = tree_map(replicate_array, init_params)
+  params = init_random_params(param_scale, layer_sizes)
+
+  param_shardings = tree_map(lambda x: replicated_sharding, params)
+  params = jax.device_put(params, param_shardings)
+  batch_shardings = (data_sharding, data_sharding)
+
+  jitted_train_step = jax.jit(
+      train_step,
+      in_shardings=(param_shardings, batch_shardings),
+      out_shardings=param_shardings,
+      donate_argnums=(0,),
+  )

   for epoch in range(num_epochs):
     start_time = time.time()
     for _ in range(num_batches):
-      replicated_params = spmd_update(replicated_params, next(batches))
+      batch = next(batches)
+      with jax.sharding.use_mesh(mesh):
+        params = jitted_train_step(params, batch)
     epoch_time = time.time() - start_time

-    # We evaluate using the jitted `accuracy` function (not using pmap) by
-    # grabbing just one of the replicated parameter values.
-    params = tree_map(lambda x: x[0], replicated_params)
     train_acc = accuracy(params, (train_images, train_labels))
     test_acc = accuracy(params, (test_images, test_labels))
     print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
     print(f"Training set accuracy {train_acc}")
     print(f"Test set accuracy {test_acc}")

From ebdaa25ed436aecd9080f1171f15758ef9573062 Mon Sep 17 00:00:00 2001
From: Changhui Lin
Date: Tue, 29 Apr 2025 00:44:40 -0700
Subject: [PATCH 0885/1769] Add GetAllocatorStats() method for device.
PiperOrigin-RevId: 752620345
---
 tests/gpu_memory_flags_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py
index 308fff257348..87f60dd86f20 100644
--- a/tests/gpu_memory_flags_test.py
+++ b/tests/gpu_memory_flags_test.py
@@ -40,7 +40,7 @@ def test_gpu_memory_allocation(self):
     device = jax.devices()[0]
     mem_stats = device.memory_stats()
     self.assertEqual(mem_stats["pool_bytes"], 0)
-    x = jax.lax.add(1, 2)
+    x = jax.lax.add(1, 2).block_until_ready()

     mem_stats = device.memory_stats()
     if preallocate:

From 66fca354d184639fe66da1c35dfdaa6f5cf542e5 Mon Sep 17 00:00:00 2001
From: Armand Picard
Date: Mon, 28 Apr 2025 18:14:38 +0200
Subject: [PATCH 0886/1769] fix: use backend to call xb.process_count in
 _raise_warnings_or_errors_for_jit_of_pmap

Not passing the backend to `process_count` was causing an exception when
lowering a function for a GPU precompilation device without having a GPU
device attached.
---
 jax/_src/interpreters/pxla.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index b41bcd94165a..75d26360c363 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1890,7 +1890,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
         f"compiling computation `{name}` that requires {nreps} replicas, but "
         f"only {xb.device_count(backend)} XLA devices are available.")

-  if xb.process_count() > 1 and (
+  if xb.process_count(backend) > 1 and (
       nreps > 1 or dispatch.jaxpr_has_primitive(jaxpr, "xla_pmap")
   ):
     raise NotImplementedError(

From faf682ebd1f2a6646d54161208eed30c1743cdfb Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Mon, 28 Apr 2025 14:14:08 -0400
Subject: [PATCH 0887/1769] Show Literal's aval when pretty printing.

Co-authored-by: Yash Katariya
---
 docs/aot.md                | 6 +++++-
 jax/_src/api.py            | 2 +-
 jax/_src/core.py           | 21 ++++++++++++++++++---
 jax/_src/state/indexing.py | 2 +-
 tests/api_test.py          | 9 +++++----
 tests/pjit_test.py         | 4 ++--
 tests/state_test.py        | 4 ++--
 7 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/docs/aot.md b/docs/aot.md
index 8f68c2758148..426e3b06ebf5 100644
--- a/docs/aot.md
+++ b/docs/aot.md
@@ -49,7 +49,11 @@ some other features along the way. An example:

 >>> # Print the specialized, staged-out representation (as Jaxpr IR)
 >>> print(traced.jaxpr)
-{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) }
+{ lambda ; a:i32[] b:i32[].
let + c:i32[] = mul 2:i32 a + d:i32[] = add c b + in (d,) } + >>> lowered = traced.lower() diff --git a/jax/_src/api.py b/jax/_src/api.py index 2812f8a808f8..4a729b223ff1 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2396,7 +2396,7 @@ def make_jaxpr( c:f32[] = sin a _:f32[] = sin b d:f32[] = cos b - e:f32[] = mul 1.0 d + e:f32[] = mul 1.0:f32 d f:f32[] = neg e g:f32[] = mul f c in (g,) } diff --git a/jax/_src/core.py b/jax/_src/core.py index dcdfc91344ce..4dbf3b94bcf4 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -427,6 +427,10 @@ def __init__(self, suffix: str, aval: AbstractValue): def __repr__(self): return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}' + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del print_dtype # unused + return f"{context.var_names[self]}{self.suffix}" + def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]: """Produce distinct variables, printed with the optional suffix.""" @@ -440,6 +444,9 @@ class DropVar(Var): def __init__(self, aval: AbstractValue): super().__init__('', aval) def __repr__(self): return '_' + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del context, print_dtype # unused + return '_' class Literal: __slots__ = ["val", "aval", "hash"] @@ -462,6 +469,14 @@ def __init__(self, val, aval): __hash__ = None # type: ignore + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del context # unused + dtype = getattr(self.aval, 'dtype', None) + if print_dtype and dtype: + return f'{self.val}:{dtypes.short_dtype_name(dtype)}' + else: + return f'{self.val}' + def __repr__(self): if hasattr(self, 'hash'): return f'{self.val}' @@ -3152,9 +3167,9 @@ def suggest_same_var_names(self, self.var_names[for_v] = pp_var(like_v, self) -def pp_var(v: Var | Literal, context: JaxprPpContext) -> str: - if isinstance(v, (Literal, DropVar)): return str(v) - return f"{context.var_names[v]}{v.suffix}" +def pp_var(v: Var | Literal, context: JaxprPpContext, *, + print_literal_dtype: bool = True) -> str: + return v.pretty_print(context, print_dtype=print_literal_dtype) def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str: if isinstance(a, DShapedArray): diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index e6e6b8a5ee25..adca41f82f7c 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -324,5 +324,5 @@ def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: if isinstance(idx, Slice): indices.append(_pp_slice(context, dim, idx)) else: - indices.append(core.pp_var(idx, context)) # type: ignore + indices.append(core.pp_var(idx, context, print_literal_dtype=False)) # type: ignore return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")]) diff --git a/tests/api_test.py b/tests/api_test.py index 7a1218d3790a..8a6a022a5677 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6657,7 +6657,8 @@ def test_const(self): def fun(x): return (x, 1., np.zeros(1, dtype=jnp.float32)) - expected = "{ lambda a:f32[1]; b:f32[]. let in (b, 1.0, a) }" + dtype = "f64" if config.enable_x64.value else "f32" + expected = f"{{ lambda a:f32[1]; b:f32[]. let in (b, 1.0:{dtype}, a) }}" jaxpr = api.make_jaxpr(fun)(jnp.float32(0.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) @@ -6669,9 +6670,9 @@ def f(x): x + 2., lambda xf: xf - x) expected = """{ lambda ; a:f32[]. 
let - b:bool[] = ge a 0.0 - c:f32[] = add a 1.0 - d:f32[] = add a 2.0 + b:bool[] = ge a 0.0:f32 + c:f32[] = add a 1.0:f32 + d:f32[] = add a 2.0:f32 e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b f:f32[] = cond[ branches=( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2c667ba36157..7f08c5af40b3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1274,7 +1274,7 @@ def test_pretty_print_with_constant_pjit_arg(self): b:f32[1] = pjit[ name= jaxpr={ lambda ; a:f32[1] c:f32[]. let b:f32[1] = mul a c in (b,) } - ] a 1.0 + ] a 1.0:f32 in (b,) } """).strip(), ) @@ -1308,7 +1308,7 @@ def test_pretty_print_with_literal_outvar(self): { lambda ; a:f32[1]. let b:i32[] c:f32[1] = pjit[ name= - jaxpr={ lambda ; a:f32[1]. let in (2, a) } + jaxpr={ lambda ; a:f32[1]. let in (2:i32, a) } ] a in (b, c) } """).strip(), diff --git a/tests/state_test.py b/tests/state_test.py index d9bf66eb3f50..a8d6e88659a6 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -361,7 +361,7 @@ def body(x_ref): return [] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) + self.assertIn("a[] <- 2:i32", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x_ref[:, 0] = val @@ -377,7 +377,7 @@ def body(x_ref): return [x] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) + self.assertIn("b:i32[], a[] <- a[], 2:i32", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x = ref_swap(x_ref, (slice(None), 0), val) From 6f09c833729d8433df4c4d40784c9134c70d6d9e Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 29 Apr 2025 03:01:39 -0700 Subject: [PATCH 0888/1769] [Mosaic GPU] Add layout inference and lowering for `vector.MultiDimReduce` PiperOrigin-RevId: 752656409 --- .../mosaic/gpu/dialect_lowering.py | 43 +++++++++ .../mosaic/gpu/layout_inference.py | 82 +++++++++++++++++ tests/mosaic/gpu_layout_inference_test.py | 47 ++++++++++ tests/mosaic/gpu_test.py | 87 +++++++++++++++++++ 4 files changed, 259 insertions(+) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index da154bdd8fa3..9ca5349dd562 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -478,6 +478,49 @@ def _vector_reduction_op_lowering_rule( raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") return [_fragmented_array_to_ir(result, op.result.type)] +@_register_lowering(vector.MultiDimReductionOp) +def _vector_multi_dim_reduction_op_lowering_rule( + ctx: LoweringContext, op: vector.MultiDimReductionOp +) -> Sequence[ir.Value]: + del ctx + + [in_layout, acc_layout] = inference_utils.in_layouts(op) + [out_layout] = inference_utils.out_layouts(op) + if layouts.from_layout_attr(in_layout) != fa.WGMMA_LAYOUT: + raise NotImplementedError(f"Unsupported input layout: {in_layout}") + if layouts.from_layout_attr(out_layout) not in { + fa.WGMMA_ROW_LAYOUT, + fa.WGMMA_COL_LAYOUT, + }: + raise NotImplementedError(f"Unsupported output layout: {out_layout}") + if out_layout != acc_layout: + raise ValueError( + f"Output layout {out_layout} must match the accumulator layout" + f" {acc_layout}" + ) + + element_type = ir.VectorType(op.source.type).element_type + + is_signed = False if ir.IntegerType.isinstance(element_type) else None + 
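+  # Materialize the source and accumulator operands as FragmentedArrays in
+  # the layouts inferred above, so the reduction can be computed on
+  # registers. (is_signed is only meaningful for integer element types;
+  # floating-point arrays pass None.)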
source_fa = _fragmented_array_from_ir(op.source, in_layout, is_signed)
+  acc_fa = _fragmented_array_from_ir(op.acc, acc_layout, is_signed)
+  match vector.CombiningKind[
+      str(op.kind).removeprefix("#vector.kind<").removesuffix(">").upper()
+  ]:
+    case vector.CombiningKind.ADD:
+      result = source_fa.reduce("add", op.reduction_dims[0])
+      result += acc_fa
+    case (
+        vector.CombiningKind.MAXIMUMF
+        | vector.CombiningKind.MAXSI
+        | vector.CombiningKind.MAXUI
+    ):
+      result = source_fa.reduce("max", op.reduction_dims[0])
+      result = result.max(acc_fa)
+    case _:
+      raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
+  return [_fragmented_array_to_ir(result, op.result.type)]
+
+
 @_register_lowering(mgpu.LayoutCastOp)
 def _mgpu_layout_cast_op_lowering_rule(
diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index ae082e4deb9b..7be4abba3e65 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -444,6 +444,88 @@ def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
   return None


+@partial(_add_layout_inference_rule, vector.MultiDimReductionOp)
+def _infer_multi_dim_reduction_op_layout(
+    op: vector.MultiDimReductionOp,
+) -> OptionalLayouts:
+  if inference_utils.has_any_layout_set(op):
+    # At the moment we either have all layouts or none, so if we found some
+    # layouts, just return the same ones.
+    op_in_layouts = list(inference_utils.in_layouts(op))
+    op_out_layouts = list(inference_utils.out_layouts(op))
+    return op_in_layouts, op_out_layouts
+
+  in_ty = ir.VectorType(op.source.type)
+  out_ty = ir.VectorType(op.result.type)
+  if len(in_ty.shape) != 2 or len(out_ty.shape) != 1:
+    raise NotImplementedError(
+        f"Only 2D -> 1D reductions are supported: {op}"
+    )
+
+  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
+  wgmma_row_layout = layouts_lib.to_layout_attr(fa.WGMMA_ROW_LAYOUT)
+  wgmma_col_layout = layouts_lib.to_layout_attr(fa.WGMMA_COL_LAYOUT)
+  reduction_dims = list(op.reduction_dims)
+
+  # Find out the layout of the source.
+  in_layout = inference_utils.value_layout(op.source)
+  if in_layout is not None and in_layout == wgmma_layout:
+    if reduction_dims == [0]:
+      out_layout = wgmma_col_layout
+    elif reduction_dims == [1]:
+      out_layout = wgmma_row_layout
+    else:
+      raise NotImplementedError(
+          f"Invalid reduction dimensions: {reduction_dims}"
+      )
+    return [in_layout, out_layout], [out_layout]
+
+  # The source either has no layout or its layout is not WGMMA so we don't know
+  # yet how to handle it. Find out the layout of the result and see if that is
+  # WGMMA_ROW or WGMMA_COL which would imply the input is WGMMA. We can look at
+  # either the consumers or the acc input (they should have the same layout).
+  out_layouts = set()
+
+  # Get acc layout.
+  acc_layout = inference_utils.value_layout(op.acc)
+  if acc_layout is not None:
+    out_layouts.add(acc_layout)
+
+  # Get user layouts.
+  for use in cast(ir.OpResult, op.result).uses:
+    consumer = use.owner
+    operand = consumer.operands[use.operand_number]
+    layout = inference_utils.in_layout_for_operand(consumer, operand)
+    if layout:
+      out_layouts.add(layout)
+
+  if not out_layouts:
+    # We couldn't find any definitive layouts, so we can't infer anything. 
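+    # Returning None here simply means "unknown": presumably a later
+    # inference pass can annotate this op once one of its neighbors has
+    # been assigned a layout.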
+ return None + + out_layout = _choose_representative_layout(out_layouts) + if out_layout is None: + raise NotImplementedError( + f"Could not choose a best layout from {out_layouts}" + ) + if out_layout != wgmma_row_layout and out_layout != wgmma_col_layout: + # We don't have a layout we can handle in the output, so we can't infer + # anything. + return None + + if (out_layout == wgmma_row_layout and reduction_dims == [1]) or ( + out_layout == wgmma_col_layout and reduction_dims == [0] + ): + in_layout = wgmma_layout + else: + raise NotImplementedError( + f"Unsupported output layout: {out_layout} for reduction dimensions" + f" {reduction_dims}" + ) + + return [in_layout, out_layout], [out_layout] + + @partial(_add_layout_inference_rule, mgpu.LayoutCastOp) def _infer_layout_cast_op_layout( layout_cast_op: mgpu.LayoutCastOp, diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 690e82c66cdc..315ae2659ab6 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -279,6 +279,53 @@ def body(x): bcast.attributes["out_layouts"], [layouts.to_layout_attr(out_layout)] ) + @parameterized.parameters( + (1, mgpu.WGMMA_LAYOUT, None, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, mgpu.WGMMA_LAYOUT, None, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + (1, None, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, None, None, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + (1, None, mgpu.WGMMA_ROW_LAYOUT, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, None, mgpu.WGMMA_COL_LAYOUT, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + (1, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, None, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + ) + def test_infer_multi_reduce_layout( + self, reduce_dim, in_cast, acc_cast, out_cast, in_layout, out_layout + ): + red = None + + in_shape = (64, 64) + out_shape = (64,) + + def body(x, acc): + nonlocal red + if in_cast is not None: + x = mgpu.dialect.LayoutCastOp(x, layouts.to_layout_attr(in_cast)) + if acc_cast is not None: + acc = mgpu.dialect.LayoutCastOp(acc, layouts.to_layout_attr(acc_cast)) + + kind = vector.CombiningKind.MAXIMUMF + red = vector.MultiDimReductionOp(kind, x, acc, [reduce_dim]) + + if out_cast is not None: + mgpu.dialect.LayoutCastOp( + red.result, layouts.to_layout_attr(out_cast) + ) + + with ir.InsertionPoint(self.module.body): + in_ty = ir.VectorType.get(in_shape, ir.F32Type.get()) + acc_ty = ir.VectorType.get(out_shape, ir.F32Type.get()) + func.FuncOp.from_py_func(in_ty, acc_ty)(body) + + mgpu.infer_layout(self.module) + self.assertSequenceEqual( + red.attributes["in_layouts"], + [layouts.to_layout_attr(in_layout), layouts.to_layout_attr(out_layout)], + ) + self.assertSequenceEqual( + red.attributes["out_layouts"], [layouts.to_layout_attr(out_layout)] + ) + def test_infer_layout_traverses_ops_correctly(self): shape = (16, 8) elt_type = ir.BF16Type.get() diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 177f622ff42a..fbe4b855174e 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2896,6 +2896,93 @@ def body(ctx, result_gmem_ref, smem): jax.jit(kernel)(), jax.lax.broadcast_in_dim(x, output_shape, bcast_dims) ) + @parameterized.parameters( + (jnp.float32, 5.0, 2.0, vector.CombiningKind.ADD), + (jnp.float32, 5.0, 2.0, vector.CombiningKind.MAXIMUMF), + 
(jnp.float32, 5.0, 7.0, vector.CombiningKind.MAXIMUMF),
+      (jnp.int32, 5, 2, vector.CombiningKind.MAXSI),
+      (jnp.int32, -5, -2, vector.CombiningKind.MAXSI),
+      (jnp.int32, -2, -5, vector.CombiningKind.MAXSI),
+      (jnp.uint32, 5, 2, vector.CombiningKind.MAXUI),
+      (jnp.uint32, 2, 5, vector.CombiningKind.MAXUI),
+      #
+      # TODO(dasenov): Add tests for wgmma_col_layout output once
+      # fragmented_array.reduce supports that.
+  )
+  def test_vector_multi_dim_reduction(
+      self,
+      dtype,
+      input_value,
+      init_value,
+      kind,
+  ):
+    input_shape = (128, 64)
+    output_shape = (128,)
+    red_dims = [1]
+
+    def body(ctx, result_gmem_ref, smem):
+      del ctx
+      result_smem_ref = smem[0]
+
+      el_type = utils.dtype_to_ir_type(dtype)
+      zero_index = arith.constant(ir.IndexType.get(), 0)
+
+      # Create source in registers
+      source_type = ir.VectorType.get(input_shape, el_type)
+      c = arith.constant(el_type, input_value)
+      source = vector.splat(source_type, c)
+
+      # Create accumulator in registers
+      acc_type = ir.VectorType.get(output_shape, el_type)
+      c = arith.constant(el_type, init_value)
+      acc = vector.splat(acc_type, c)
+
+      # Cast inputs
+      source = mgpu_dialect.layout_cast(
+          source, layouts.to_layout_attr(fa.WGMMA_LAYOUT)
+      )
+      acc_layout = (
+          fa.WGMMA_ROW_LAYOUT if red_dims[0] == 1 else fa.WGMMA_COL_LAYOUT
+      )
+      acc = mgpu_dialect.layout_cast(acc, layouts.to_layout_attr(acc_layout))
+
+      # Computation
+      reduced = vector.multi_reduction(kind, source, acc, red_dims)
+
+      # Registers -> SMEM
+      vector.store(reduced, result_smem_ref, [zero_index] * len(output_shape))
+
+      # SMEM -> GMEM
+      zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
+      mgpu_dialect.async_store(
+          source=result_smem_ref,
+          destination=result_gmem_ref,
+          indices=[zero_i32] * len(output_shape),
+          slice_lengths=output_shape,
+      )
+      nvvm.cp_async_bulk_wait_group(0)
+      utils.warpgroup_barrier()
+
+    kernel = mgpu.as_gpu_kernel(
+        body,
+        grid=(1, 1, 1),
+        block=(128, 1, 1),
+        in_shape=(),
+        out_shape=jax.ShapeDtypeStruct(output_shape, dtype),
+        smem_scratch_shape=[jax.ShapeDtypeStruct(output_shape, dtype)],
+        thread_semantics=mgpu.LoweringSemantics.Warpgroup,
+    )
+
+    source = np.full(input_shape, input_value, dtype=dtype)
+    acc = np.full(output_shape, init_value, dtype=dtype)
+    if kind == vector.CombiningKind.ADD:
+      red = jax.lax.reduce_sum(source, red_dims)
+      red = red + acc
+    else:
+      red = jax.lax.reduce_max(source, red_dims)
+      red = jax.lax.max(red, acc)
+    self.assertArraysEqual(jax.jit(kernel)(), red)
+
   @parameterized.parameters(fa.WGMMA_ROW_LAYOUT, fa.WGMMA_COL_LAYOUT)
   def test_wgmma_row_col_store(self, in_layout):
     element_value = 42.0

From 88f8691a923b5fb0e6ac95667111115b800e652d Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 29 Apr 2025 04:33:17 -0700
Subject: [PATCH 0889/1769] [Mosaic GPU] Add support for packed TMEM

16-bit MMA accumulators use a TMEM layout without packing (i.e. only one
element is located in each 32-bit word), but 16-bit MMA operands need to be
packed (i.e. two elements are packed into each 32-bit word).
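An illustration (not from the change itself; the element order within a
word is an assumption here): with 16-bit elements,

  unpacked: word[i] holds x[i] in its low 16 bits (high 16 bits unused)
  packed:   word[i] holds both x[2*i] and x[2*i+1]

Converting between the register form and the TMEM form is what the
`.pack::16b` / `.unpack::16b` modifiers on the tcgen05 loads and stores
below are for.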
PiperOrigin-RevId: 752679306 --- jax/experimental/mosaic/gpu/core.py | 18 ++- jax/experimental/mosaic/gpu/tcgen05.py | 180 ++++++++++++++----------- tests/mosaic/gpu_test.py | 6 +- 3 files changed, 116 insertions(+), 88 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 546690c3e9fd..86be9825dcef 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -202,12 +202,16 @@ class ClusterBarrier: class TMEM: shape: tuple[int, int] dtype: Any + _: dataclasses.KW_ONLY layout: tcgen05.TMEMLayout | None = None collective: bool = False + packing: int | None = None def __post_init__(self): if self.layout is not None: - self.layout.check_shape(self.shape) + self.layout.check_type(self.shape, utils.dtype_to_ir_type(self.dtype)) + if self.packing is not None: + raise ValueError("Cannot specify both layout and packing") def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: @@ -313,13 +317,15 @@ def ref(member_thunks=member_thunks): collective_dims, cluster_shape, ) - case TMEM(shape, dtype, layout, collective): + case TMEM(shape, dtype, layout=layout, collective=collective, packing=packing): addr_ref = memref.view( ir.MemRefType.get([], i32, memory_space=smem), dynamic_smem, c(dynamic_smem_offset, index), [], ) if layout is None: - layout = tcgen05._infer_tmem_layout(shape, collective) + layout = tcgen05._infer_tmem_layout( + shape, collective, 1 if packing is None else packing + ) num_cols = layout.cols_in_shape(shape) tmem_allocs.append(_TMEMAlloc(addr_ref, num_cols, collective)) def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout): @@ -478,8 +484,10 @@ def _launch( yield ctx, smem_ref_tree - for alloc in tmem_allocs: - alloc.dealloc() + if tmem_allocs: + gpu.barrier() # Make sure everyone is done before we release TMEM. 
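+      # Deallocation frees the columns for reuse, so a thread that raced
+      # ahead could otherwise clobber TMEM that others are still reading.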
+ for alloc in tmem_allocs: + alloc.dealloc() if prof is not None: prof.finalize(grid=grid, block=block) gpu.terminator() diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 4f8c5847f74f..5afce275cca5 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -115,7 +115,7 @@ def mma( raise ValueError( f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" ) - if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)): + if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective, packing=1)): raise ValueError( f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}" ) @@ -348,7 +348,7 @@ def tmem_relinquish_alloc_permit(collective: bool): has_side_effects=True, ) -def _tmem_access_helper(shape, num, packing: int = 1): +def _tmem_access_helper(shape, num): if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: @@ -369,15 +369,10 @@ def _tmem_access_helper(shape, num, packing: int = 1): return num_regs, regs_vector -def tmem_load(tmem_addr, shape, num, packing: int = 1): +def tmem_load(tmem_addr, shape, num, pack: bool): i32 = ir.IntegerType.get_signless(32) - num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing) - if packing == 1: - pack_mod = "" - elif packing == 2: - pack_mod = ".pack::16b" - else: - raise ValueError(f"Unsupported packing: {packing}") + num_out_regs, regs_vector = _tmem_access_helper(shape, num) + pack_mod = ".pack::16b" if pack else "" regs = llvm.inline_asm( ir.Type.parse( "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>" @@ -390,14 +385,9 @@ def tmem_load(tmem_addr, shape, num, packing: int = 1): return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)] -def tmem_store(tmem_addr, shape, num, regs, packing: int = 1): - num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing) - if packing == 1: - pack_mod = "" - elif packing == 2: - pack_mod = ".unpack::16b" - else: - raise ValueError(f"Unsupported packing: {packing}") +def tmem_store(tmem_addr, shape, num, regs, unpack: bool): + num_out_regs, regs_vector = _tmem_access_helper(shape, num) + pack_mod = ".unpack::16b" if unpack else "" llvm.inline_asm( ir.Type.parse("!llvm.void"), [*regs, tmem_addr], @@ -440,6 +430,7 @@ class TMEMLayout: """ elements_in_tile: tuple[int, int] column_tile_stride: int = 1 + packing: int = 1 def __post_init__(self): row_tiling = self.elements_in_tile[0] @@ -449,24 +440,36 @@ def __post_init__(self): ) if row_tiling.bit_count() != 1: raise ValueError(f"Row tiling must be a power of 2, got: {row_tiling}") + if self.elements_in_tile[1] % self.packing: + raise ValueError( + f"Column tiling must be a multiple of packing={self.packing}, got:" + f" {self.elements_in_tile[1]}" + ) - def check_shape(self, shape: tuple[int, ...]): + def check_type(self, shape: tuple[int, ...], dtype: ir.Type): if len(shape) != 2: raise ValueError(f"TMEM can only represent 2D shapes, got {shape}") if any(s % t for s, t in zip(shape, self.elements_in_tile)): raise ValueError( f"{shape} is divisible into tiles of shape {self.elements_in_tile}" ) + if self.packing not in {1, fully_packed := 32 // utils.bitwidth(dtype)}: + raise ValueError( + f"For {utils.bitwidth(dtype)}-bit types, only packing=1 and" + f" packing={fully_packed} are supported, but got: {self.packing}" + ) def cols_in_shape(self, shape: tuple[int, int]): - cols_in_tile = 
self.elements_in_tile[1] + cols_in_tile = self.elements_in_tile[1] // self.packing tiles_in_row = TMEM_ROWS // self.elements_in_tile[0] num_tiles = math.prod(utils.tile_shape(shape, self.elements_in_tile)[:-2]) assert num_tiles % tiles_in_row == 0 return num_tiles // tiles_in_row * cols_in_tile -def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout: +def _infer_tmem_layout( + shape: tuple[int, int], collective: bool, packing: int = 1 +) -> TMEMLayout: if shape[0] > TMEM_ROWS: raise ValueError( "Can only infer TMEM layout for shapes with at most 128 rows, got:" @@ -488,9 +491,11 @@ def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout: f" multiple of 8, got: {shape[1]}" ) if collective and shape[1] == 512: - return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2) + return TMEMLayout( + elements_in_tile=(shape[0], 128), column_tile_stride=2, packing=packing + ) else: - return TMEMLayout(elements_in_tile=(shape[0], 8)) + return TMEMLayout(elements_in_tile=(shape[0], 8), packing=packing) @dataclasses.dataclass(frozen=True) @@ -531,7 +536,7 @@ def from_alloc( ) layout = _infer_tmem_layout(shape, collective) else: - layout.check_shape(shape) + layout.check_type(shape, dtype) # TODO: Do we have to do this?? # warp_idx = utils.warp_idx(sync=False) # tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32))) @@ -575,35 +580,37 @@ def __getitem__(self, *idxs): raise NotImplementedError(f"Unsupported dtype: {self.dtype}") layout = _m128_layout(self.shape) regs_shape = layout.registers_shape(self.shape) - if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): - # load_32xcols returns a 4xN array, but the FA tiling we use here tiles - # columns before rows, and so it is Nx4 (after ignoring all 1 dims). - registers = _load_32xcols( - self.address, self.shape[1], self.dtype - ).T.reshape(regs_shape) - elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2): - if self.shape[1] % 128 != 0: - raise ValueError( - f"TMEM layout {self.layout} is not compatible with shape {self.shape}" - ) - num_column_tiles = self.shape[1] // 128 - column_tile_stride = self.layout.column_tile_stride - num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) - tiles = [] - for col_tile_base in range(num_strided_col_groups): - for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): - tiles.append( - _load_32xcols( - arith.addi(self.address, arith.constant(i32, col_tile * 128)), - cols=128, - dtype=self.dtype, - ) + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if r == TMEM_ROWS: + # load_32xcols returns a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). 
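+        # Hence the transpose before reshaping into the register shape below.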
+ registers = _load_32xcols( + self.address, self.shape[1], self.dtype, packing + ).T.reshape(regs_shape) + case TMEMLayout(elements_in_tile=(r, 128), column_tile_stride=2) if r == TMEM_ROWS: + if self.shape[1] % 128 != 0: + raise ValueError( + f"TMEM layout {self.layout} is not compatible with shape {self.shape}" ) - registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) - else: - raise NotImplementedError( - f"Loads only implemented for refs with standard layout, got: {self.layout}" - ) + num_column_tiles = self.shape[1] // 128 + column_tile_stride = self.layout.column_tile_stride + num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) + tiles = [] + for col_tile_base in range(num_strided_col_groups): + for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): + tiles.append( + _load_32xcols( + arith.addi(self.address, arith.constant(i32, col_tile * 128)), + cols=128, + dtype=self.dtype, + tmem_packing=False, + ) + ) + registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) + case _: + raise NotImplementedError( + f"Loads only implemented for refs with standard layout, got: {self.layout}" + ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) def __setitem__(self, idxs, value): @@ -637,19 +644,21 @@ def __setitem__(self, idxs, value): f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT is" " supported" ) - if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): - # store_32xcols needs a 4xN array, but the FA tiling we use here tiles - # columns before rows, and so it is Nx4 (after ignoring all 1 dims). - _store_32xcols( - self.address, value.registers.T.reshape((4, -1)) - ) - else: # TODO(apaszke): Collective MMA layout - raise NotImplementedError( - f"Stores only implemented for refs with standard layout, got: {self.layout}" - ) + # TODO(apaszke): Collective MMA layout + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if r == TMEM_ROWS: + # store_32xcols needs a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + _store_32xcols( + self.address, value.registers.T.reshape((4, -1)), packing + ) + case _: + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) -def _transfer_32xcols(base_addr, cols): +def _transfer_32xcols(base_addr: ir.Value, cols: int, packing: int): i32 = ir.IntegerType.get_signless(32) cols_per_num = 8 # Here we generate a plan compatible with tcgen05.LAYOUT. 
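   # Concretely, the plan walks the columns in steps of instr_num * 8 columns
   # and yields the TMEM address, size, and register slice for each step.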
assert cols % cols_per_num == 0 @@ -669,17 +678,19 @@ def _transfer_32xcols(base_addr, cols): cols_per_instr = instr_num * cols_per_num for num_step in range(total_num // instr_num): num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num) - addr_row_col = arith.addi(addr_row, utils.c(num_step * cols_per_instr, i32)) + addr_row_col = arith.addi( + addr_row, utils.c(num_step * cols_per_instr // packing, i32) + ) yield addr_row_col, instr_num, lane_step, num_slice -def _store_32xcols(base_addr, vector_regs): +def _store_32xcols(base_addr, vector_regs, tmem_packing): i32 = ir.IntegerType.get_signless(32) assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4 cols = vector_regs.shape[1] * 8 - packing = 64 // utils.bitwidth(vector_regs.flat[0].type) - if packing == 1: + reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + if reg_packing == 1: store_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits regs = np.empty((4, vector_regs.shape[1], 2), dtype=object) c0 = arith.constant(i32, 0) @@ -692,43 +703,51 @@ def _store_32xcols(base_addr, vector_regs): # minor dim traversing columns and major being 8 rows apart. # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b assert regs.shape[-2:] == (2, 2) - elif packing == 2: + assert tmem_packing == 1 + unpack = False + elif reg_packing == 2: store_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits # From a single lane perspective a num tile has 2 registers, 8 rows apart. # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2) + assert 1 <= tmem_packing <= 2 + unpack = tmem_packing == 1 else: - raise NotImplementedError(packing) + raise NotImplementedError(reg_packing) - it = _transfer_32xcols(base_addr, cols) + it = _transfer_32xcols(base_addr, cols, tmem_packing) for addr_row_col, instr_num, lane_step, num_slice in it: regs_slice = regs[lane_step, num_slice].flat - tmem_store(addr_row_col, store_shape, instr_num, regs_slice, packing) + tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack) -def _load_32xcols(base_addr, cols, dtype): +def _load_32xcols(base_addr, cols, dtype, tmem_packing): i32 = ir.IntegerType.get_signless(32) vec_ty = ir.VectorType.get((2,), dtype) - packing = 32 // utils.bitwidth(dtype) - if packing == 1: + reg_packing = 32 // utils.bitwidth(dtype) + if reg_packing == 1: load_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits - elif packing == 2: + assert tmem_packing == 1 + pack = False + elif reg_packing == 2: load_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + assert 1 <= tmem_packing <= 2 + pack = tmem_packing == 1 else: - raise NotImplementedError(packing) + raise NotImplementedError(reg_packing) vector_regs = np.ndarray((4, cols // 8), dtype=object) - it = _transfer_32xcols(base_addr, cols) + it = _transfer_32xcols(base_addr, cols, tmem_packing) c0 = arith.constant(i32, 0) c1 = arith.constant(i32, 1) for addr_row_col, instr_num, lane_step, num_slice in it: - regs = tmem_load(addr_row_col, load_shape, instr_num, packing) + regs = tmem_load(addr_row_col, load_shape, instr_num, pack) row_slice = slice(lane_step * 2, (lane_step + 1) * 2) # This aliases the original array, so updates will be reflected there. 
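    # (Basic slicing of the NumPy object array returns a view, so filling
    # vector_regs_update below also fills that slice of vector_regs.)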
vector_regs_update = vector_regs[row_slice, num_slice] assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num) - if packing == 1: + if reg_packing == 1: regs = [llvm.bitcast(dtype, r) for r in regs] # From a single lane perspective a num tile consists of a 2x2, with the # minor dim traversing columns and major being 8 rows apart. @@ -741,7 +760,7 @@ def _load_32xcols(base_addr, cols, dtype): vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1) vector_regs_update[idx] = vreg else: - assert packing == 2 + assert reg_packing == 2 regs = [llvm.bitcast(vec_ty, r) for r in regs] # From a single lane perspective a num tile has 2 registers, 8 rows apart. # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b @@ -758,6 +777,7 @@ def _m128_layout(shape: tuple[int, ...]): raise ValueError(f"Shape {shape} is not a multiple of 64x8") return LAYOUT + # Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. # The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT) LAYOUT = fa.TiledLayout( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index fbe4b855174e..9278d71cde40 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -903,8 +903,8 @@ def setUp(self): if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") - @parameterized.parameters([jnp.float32, jnp.float16]) - def test_load_store_tmem(self, jax_dtype): + @parameterized.parameters([(jnp.float32, 1), (jnp.float16, 1), (jnp.float16, 2)]) + def test_load_store_tmem(self, jax_dtype, packing): swizzle = 128 in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype) swizzle_elems = swizzle // bytewidth(in_mlir_dtype) @@ -933,7 +933,7 @@ def kernel(ctx, input, output, scratch): scratch_shape = [ jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype), mgpu.TMABarrier(), - mgpu.TMEM(x.shape, jax_dtype), + mgpu.TMEM(x.shape, jax_dtype, packing=packing), ] y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape From 74187e5aa00b9f960c1eb43d4e10c2aa890619cb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 29 Apr 2025 04:41:19 -0700 Subject: [PATCH 0890/1769] Add note to changelog about nightly package switchover PiperOrigin-RevId: 752681193 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index efb515b73283..e8007dd3eec2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Added {func}`jax.lax.axis_size` which returns the size of the mapped axis given its name. +* Changes + * JAX nightly packages are now published to artifact registry. To install + these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation). 
+ ## JAX 0.6.0 (April 16, 2025) * Breaking changes From 55046bbd813e804fa3f167f4ea1e6b66bd196506 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 29 Apr 2025 04:44:18 -0700 Subject: [PATCH 0891/1769] Allow nightly workflows to halt for connection PiperOrigin-RevId: 752681966 --- .github/workflows/bazel_cuda_non_rbe.yml | 4 ++-- .github/workflows/pytest_cpu.yml | 4 ++-- .github/workflows/pytest_cuda.yml | 4 ++-- .github/workflows/pytest_tpu.yml | 4 ++-- .github/workflows/wheel_tests_nightly_release.yml | 10 +++++++++- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index ff1cf9900ce3..8d230848d20a 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -36,9 +36,9 @@ on: type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean + type: string required: false - default: false + default: 'no' jobs: run-tests: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index c952ef9ee1a6..bdce2b684803 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -41,9 +41,9 @@ on: type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean + type: string required: false - default: false + default: 'no' jobs: run-tests: diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 671af873b48d..4df752310ace 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -46,9 +46,9 @@ on: type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean + type: string required: false - default: false + default: 'no' jobs: run-tests: diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 0b56635a8aac..55a0b4cc1a5f 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -66,9 +66,9 @@ on: type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean + type: string required: false - default: false + default: 'no' jobs: run-tests: diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 483705aeb0cb..de89c278c6c3 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -17,13 +17,18 @@ on: gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: true - default: 'gs://jax-nightly-release-transient/nightly/latest' + default: 'gs://jax-nightly-artifacts/latest' type: string download-jax-only-from-gcs: description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" required: true default: '0' type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection? 
(yes/no)' + required: false + default: 'no' + type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -50,6 +55,7 @@ jobs: enable-x64: ${{ matrix.enable-x64 }} download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} run-pytest-cuda: uses: ./.github/workflows/pytest_cuda.yml @@ -70,6 +76,7 @@ jobs: enable-x64: ${{ matrix.enable-x64 }} download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} run-pytest-tpu: uses: ./.github/workflows/pytest_tpu.yml @@ -117,6 +124,7 @@ jobs: libtpu-version-type: ${{ matrix.libtpu-version-type }} download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} verify-release-wheels-install: if: ${{ startsWith(github.ref_name, 'release/')}} From a49adaca4c402eb77c6336d4f2fb118dd5ff4854 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 29 Apr 2025 05:08:41 -0700 Subject: [PATCH 0892/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/38b38564b27b0146abf0f8131983874f1f5d8fe2. PiperOrigin-RevId: 752688631 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 0530665390bc..e95466226c42 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b77c3dfa079d9d78e40cd7fcd18f821e7c90b8ed" -XLA_SHA256 = "c527159d433b3301acc6a3d01e504a1718d80c4dfa0a4d38e2cf7529d4fb8162" +XLA_COMMIT = "38b38564b27b0146abf0f8131983874f1f5d8fe2" +XLA_SHA256 = "e61f726a8ad1faf3d58c61c50d8f36b0bfc9e76ad81e8f7bc1562a3645f33eaa" def repo(): tf_http_archive( From ef015a0d3f8effa3d313c7d692d9241a54a99922 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 29 Apr 2025 06:14:09 -0700 Subject: [PATCH 0893/1769] Re-land "Inline literals while tracing instead of in a separate pass". 
Reverts 9ee7bade3513d8a4919a9194a023107492a5d2dd PiperOrigin-RevId: 752706961 --- jax/_src/core.py | 23 +++--- jax/_src/interpreters/partial_eval.py | 108 +++++++++++--------------- 2 files changed, 60 insertions(+), 71 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index dcdfc91344ce..11d48ca017ee 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -442,34 +442,36 @@ def __init__(self, aval: AbstractValue): def __repr__(self): return '_' class Literal: - __slots__ = ["val", "aval", "hash"] + __slots__ = ["val", "aval"] val: Any aval: AbstractValue - hash: int | None def __init__(self, val, aval): self.val = val self.aval = aval + + @property + def hash(self): try: - self.hash = hash(val) + return hash(self.val) except TypeError: - if type(val) in literalable_types: + if type(self.val) in literalable_types: try: - self.hash = hash((val.item(), val.dtype)) + return hash((self.val.item(), self.val.dtype)) except (TypeError, AttributeError, ValueError): - self.hash = None + return None __hash__ = None # type: ignore def __repr__(self): - if hasattr(self, 'hash'): - return f'{self.val}' - else: - return f'Literal(val={self.val})' + return f'{self.val}' literalable_types: set[type] = set() +def is_literalable(x: Any) -> bool: + return type(x) in dtypes.python_scalar_dtypes or (type(x) in literalable_types and not np.shape(x)) + Atom = Union[Var, Literal] class Primitive: @@ -2067,6 +2069,7 @@ class DShapedArray(UnshapedArray): array_abstraction_level: int = 3 def __init__(self, shape, dtype, weak_type=False): + assert not any(isinstance(d, Literal) for d in shape) self.shape = shape self.dtype = dtype self.weak_type = weak_type diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 7203af95fb7b..1c42f86c3d48 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -199,7 +199,7 @@ def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: if const is None: return tracer else: - if type(const) in core.literalable_types and np.shape(const) == (): + if core.is_literalable(const): return self.new_instantiated_literal(const) else: return self.new_instantiated_const(const) @@ -1651,7 +1651,8 @@ def _origin_msg(self): def get_referent(self): frame = self._trace.frame - val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) + var = frame.tracer_to_var.get(id(self)) + val = frame.constvar_to_val.get(var) if isinstance(var, Var) else None return self if val is None else get_referent(val) core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval @@ -1694,7 +1695,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: class JaxprStackFrame: gensym: Callable[[AbstractValue], Var] - tracer_to_var: dict[TracerId, Var] + tracer_to_var: dict[TracerId, Atom] constid_to_tracer: dict[ConstId, Tracer] constvar_to_val: dict[Var, Any] tracers: list[DynamicJaxprTracer] # hold onto strong refs for all tracers @@ -1732,7 +1733,8 @@ def to_jaxpr( debug_info: core.DebugInfo, ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] + assert len(vars) == len(set(vars)) invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) @@ -1746,14 +1748,15 @@ 
def to_jaxpr( jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, debug_info) jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) + jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], debug_info: core.DebugInfo): # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] + assert len(vars) == len(set(vars)) constvars, constvals = unzip2(self.constvar_to_val.items()) expl_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, @@ -1762,7 +1765,7 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], jaxpr_effects, debug_info) # We can't run check_jaxpr until after we normalize. jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) + jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) jaxpr, out_type = _add_implicit_outputs(jaxpr) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, out_type, constvals @@ -1772,6 +1775,7 @@ def newvar(self, aval): # this aval may have tracers in it, so we replace those with variables new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d for d in aval.shape] + new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape] aval = aval.update(shape=tuple(new_shape)) return self.gensym(aval) @@ -1787,14 +1791,15 @@ def find_progenitors(self, tracer): active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns - if {v for v in eqn.invars if type(v) is Var} & constvars] + const_eqns = [eqn for eqn in self.eqns if any( + v in constvars if type(v) is Var else type(v) is Literal + for v in eqn.invars)] return invar_positions, const_eqns def _const_folding_and_forwarding( jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]: consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) - var_subs: dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined + var_subs: dict[Var, Atom] = {} new_eqns = [] def apply_var_sub(a: Atom) -> Atom: return var_subs.get(a, a) if isinstance(a, Var) else a @@ -1805,14 +1810,20 @@ def apply_var_sub(a: Atom) -> Atom: has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) for eff in eqn.effects) if (eqn.primitive in const_fold_rules and - any(v in consts for v in eqn.invars if isinstance(v, Var)) and + any(v in consts if isinstance(v, Var) + else isinstance(v, Literal) for v in eqn.invars) and not has_input_effect): - consts_in = [consts.get(v) if isinstance(v, Var) else None + consts_in = [consts.get(v) if isinstance(v, Var) else + v.val if isinstance(v, Literal) else None for v in eqn.invars] consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) assert (new_eqn is None) == all(c is not None for c in consts_out) for v, c in zip(eqn.outvars, consts_out): - if c is not None: consts[v] = c + if c is not None: + if core.is_literalable(c): + var_subs[v] = Literal(c, v.aval) 
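+          # Literalable folded constants are substituted directly as Literal
+          # operands at their use sites; everything else stays in `consts`.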
+ else: + consts[v] = c if new_eqn is None: continue else: eqn = new_eqn # if the application trivially maps some inputs to outputs, simplify @@ -1844,54 +1855,26 @@ def apply_var_sub(a: Atom) -> Atom: forwarding_rules: dict[Primitive, ForwardingRule] = {} -def _inline_literals( +def _drop_unused_vars( jaxpr: Jaxpr, constvals: Sequence[Any] ) -> tuple[Jaxpr, list[Any]]: - # This function also prunes unused constants and inserts `dropvar` symbols. - input_effects = {eff for eff in jaxpr.effects - if isinstance(eff, effects.JaxprInputEffect)} - # Don't inline any literal with an input effect - has_input_effect = [any(eff.input_index == i for eff in input_effects) - for i in range(len(constvals))] - lits = {v: Literal(c, v.aval) for v, c, e in zip(jaxpr.constvars, constvals, - has_input_effect) - if type(c) in core.literalable_types and not np.shape(c) and not e} - def lit(a: Atom) -> Literal | None: - return (a if isinstance(a, Literal) else lits.get(a) if isinstance(a, Var) - else None) - newname: Callable[[AbstractValue], Var] = core.gensym() - newvars: dict[Var, Var] = {} - newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) - var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) - lit_or_var = ( - lambda a: a if isinstance(a, Literal) else (lit(a) or var(a)) - ) - dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) - - def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: + def vars(atom: Atom) -> list[Var]: + if isinstance(atom, Literal): + return [] + aval = atom.aval if isinstance(aval, DShapedArray): - return [d for d in aval.shape if isinstance(d, Var)] - return [] - - used = {v for eqn in jaxpr.eqns for atom in eqn.invars - for v in it.chain([atom], vars_in_shape(atom.aval)) - if isinstance(atom, Var)} - used |= {v for outvar in jaxpr.outvars - for v in it.chain([outvar], vars_in_shape(outvar.aval))} - new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] - new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) - if v in used and not lit(v)] - new_invars = [var(v) for v in jaxpr.invars] - new_eqns = [] - for eqn in jaxpr.eqns: - invars = [lit_or_var(x) for x in eqn.invars] - outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] - new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) - new_outvars = [lit_or_var(v) for v in jaxpr.outvars] - effs = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns) - new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, effs, - jaxpr.debug_info) - return new_jaxpr, new_constvals + return [atom] + [d for d in aval.shape if isinstance(d, Var)] + return [atom] + used: set[Var] = {v for atom in jaxpr.outvars for v in vars(atom)} + for eqn in jaxpr.eqns[::-1]: + eqn.outvars = [v if v in used else DropVar(v.aval) for v in eqn.outvars] + used.update(v for atom in eqn.invars for v in vars(atom)) + cvars, constvals = unzip2( + (v, val) for v, val in zip(jaxpr.constvars, constvals) if v in used) + jaxpr._constvars = list(cvars) + jaxpr._effects = make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, + jaxpr.outvars, jaxpr.eqns) + return jaxpr, list(constvals) class DynamicJaxprTrace(core.Trace): @@ -1942,9 +1925,12 @@ def new_const(self, c, source_info: SourceInfo): def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: tracer = DynamicJaxprTracer(self, aval, source_info) self.frame.tracers.append(tracer) - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) - 
self.frame.constid_to_tracer[id(c)] = tracer - self.frame.constvar_to_val[var] = c + if core.is_literalable(c): + self.frame.tracer_to_var[id(tracer)] = Literal(c, aval) + else: + self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) + self.frame.constid_to_tracer[id(c)] = tracer + self.frame.constvar_to_val[var] = c return tracer def _lift_tracers_in_aval(self, aval, source_info: SourceInfo): From 66bd3cef4224fe070b7d0a80715ba916aa4d2835 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 29 Apr 2025 06:41:38 -0700 Subject: [PATCH 0894/1769] Remove deprecated GPU linalg kernels after compatibility period. JAX 0.4.35 was released >180d ago, and these kernels were no longer targeted by that release. PiperOrigin-RevId: 752713750 --- .../cuda_eigh_cusolver_syev.py | 1393 ----------------- jaxlib/gpu/gpu_kernels.cc | 4 - jaxlib/gpu/solver.cc | 247 --- jaxlib/gpu/solver_kernels.cc | 468 ------ jaxlib/gpu/solver_kernels.h | 53 - tests/export_back_compat_test.py | 32 - 6 files changed, 2197 deletions(-) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py index 56479e82f9d9..4b7c37723cc1 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py @@ -17,1399 +17,6 @@ import datetime from numpy import array, float32, complex64 -data_2023_03_17=dict( - # Pasted from the test output (see back_compat_test.py module docstring) - f32_syevj=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevj'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 6.18577063e-01, -8.00570633e-05, -1.96905047e-01, - -8.95753130e-02, 7.24549413e-01, -1.07546024e-01, - -4.77200520e-04, 1.84469908e-01], - [ 4.70708847e-01, 3.31519186e-05, 2.80930042e-01, - -5.84393919e-01, -4.93098050e-01, -2.50211239e-01, - -1.14346610e-03, 2.28566617e-01], - [ 3.22840720e-01, -5.11042356e-01, -3.03526163e-01, - 2.48800799e-01, -3.14544559e-01, 5.54342926e-01, - 1.10838346e-06, 2.72663534e-01], - [ 1.74972475e-01, 4.18093473e-01, -2.66933769e-01, - 5.78716159e-01, -2.97307134e-01, -4.46864694e-01, - 1.09066934e-06, 3.16760242e-01], - [ 2.71042082e-02, 4.29418474e-01, 4.71952170e-01, - 1.10573582e-01, 9.57800150e-02, 4.65731144e-01, - -4.72866714e-01, 3.60856950e-01], - [-1.20763958e-01, -3.84347916e-01, 5.79687178e-01, - 2.87678182e-01, 1.63329691e-01, -2.02215970e-01, - 4.32829827e-01, 4.04953718e-01], - [-2.68632114e-01, 3.63640338e-01, -2.97110289e-01, - -3.32554609e-01, 3.46945561e-02, 2.77071655e-01, - 5.63131213e-01, 4.49050426e-01], - [-4.16500419e-01, -3.15715015e-01, -2.68094122e-01, - -2.19244853e-01, 8.65960941e-02, -2.90307850e-01, - -5.21475971e-01, 4.93147314e-01]], dtype=float32), array([-2.4598812e+01, -2.4345848e-06, -1.2664314e-06, -8.6959182e-07, - -8.2917722e-07, 1.6633214e-06, 2.0499781e-06, 2.7659885e+02], - dtype=float32)), - mlir_module_text=""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf32> - %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.add %1, %2 : tensor<8x8xf32> - %4 = stablehlo.constant dense<2.000000e+00> : 
tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf32> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf32> - %7 = call @tril(%6) : (tensor<8x8xf32>) -> tensor<8x8xf32> - %8 = stablehlo.custom_call @cusolver_syevj(%7) {api_version = 2 : i32, backend_config = "\00\00\00\00\00\00\00\00\01\00\00\00\08\00\00\00M\08\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<8x8xf32>) -> tuple, tensor<8xf32>, tensor, tensor<2125xf32>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<8x8xf32> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<8xf32> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<2125xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<8x8xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<8x8xi1>, tensor<8x8xf32> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FC00000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<8xf32> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<8xi1>, tensor<8xf32> - return %20, %25 : tensor<8x8xf32>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf32> - return %8 : tensor<8x8xf32> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02r\x08\x1d\x85\x03\x17\x116\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x112\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11*\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\t\x1b)\x01\x05\t)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03jB\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x
03\x01\x05\x01\x00\x06\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x08\x00\x00\x00M\x08\x00\x00\x00cusolver_syevj\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f32_syevd=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 3.14863890e-01, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - -4.91220355e-01, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 8.05416584e-01, - 0.00000000e+00, -1.77893345e-03, -2.64500137e-02, - 1.46598322e-04, -5.19353598e-02, -8.64148438e-02], - [ 2.99391806e-01, 2.77544819e-02, 6.73292065e-03, - -6.83086272e-03, -3.54272849e-03, -1.21014733e-02, - -1.32716037e-02, -1.15843862e-03, -8.83520208e-03, - -6.63395738e-03, 1.60171092e-03, -1.01765711e-03, - 1.19860061e-02, -1.33239310e-02, 1.76237477e-03, - 1.27085261e-02, 3.38556734e-03, -8.78101215e-03, - 1.58616400e-03, -7.37631368e-03, 3.81911686e-03, - -5.18379211e-02, -7.22059654e-03, 1.85085051e-02, - 2.94725411e-03, 4.74284729e-03, -1.33781182e-02, - -3.61499190e-03, -5.49228955e-03, -1.05845921e-01, - 1.01772454e-02, 4.47412670e-01, 1.95654288e-01, - 3.94686669e-01, 7.00925171e-01, -9.06614065e-02], - [ 2.83920437e-01, 1.69272088e-02, 6.64264262e-02, - 
-1.18565477e-01, 3.54601629e-02, -1.52457461e-01, - 6.84847543e-03, 1.90414500e-03, -2.76310533e-01, - 3.76881436e-02, 1.22269124e-01, -1.01556584e-01, - -1.90264836e-01, -1.16590485e-01, 6.09031200e-01, - -9.43092555e-02, -3.74726858e-03, -2.33182713e-01, - 1.95203945e-01, -1.20613754e-01, 3.94887812e-02, - -5.88066364e-03, 1.19152360e-01, -1.46030456e-01, - -4.74781469e-02, 2.67041594e-01, -1.22617789e-01, - 5.77996820e-02, 2.58437768e-02, -1.34434626e-01, - -3.28330845e-02, -9.32494774e-02, 1.14714004e-01, - 1.21207587e-01, -2.04871535e-01, -9.49072391e-02], - [ 2.68448830e-01, 2.17946004e-02, -1.94895901e-02, - 3.40374447e-02, 6.18659109e-02, 1.72068894e-01, - -8.02555401e-03, 9.68076065e-02, 4.98391055e-02, - 5.55528253e-02, -3.23998183e-02, -2.63249427e-01, - -4.35045222e-03, 5.20016700e-02, -5.92328422e-02, - 4.31317724e-02, -2.00986061e-02, -2.69871447e-02, - 1.54309347e-01, 1.74670279e-01, -4.97168908e-03, - -4.15510803e-01, -4.33471389e-02, -3.71299796e-02, - 5.26434295e-02, -1.18867345e-01, -2.42547281e-02, - -3.90263759e-02, -2.58720964e-01, -3.92957211e-01, - -1.28192365e-01, 2.77028710e-01, -4.02157485e-01, - -1.77024350e-01, -1.76668167e-01, -9.91534367e-02], - [ 2.52977222e-01, 3.48518007e-02, 7.02044442e-02, - 1.42712081e-02, 4.50692251e-02, 7.16193160e-03, - 1.19931757e-01, 2.32399218e-02, -6.05047755e-02, - 1.06077030e-01, 1.03731848e-01, -1.13200452e-02, - 5.94755262e-03, -2.32813850e-01, 8.72232541e-02, - 8.17264095e-02, 3.30835059e-02, 4.88227099e-01, - 6.14454560e-02, 1.43805355e-01, -7.40422234e-02, - 2.25823849e-01, -3.86487693e-01, 1.30468249e-01, - 3.16427708e-01, -1.19733319e-01, -4.18486483e-02, - -2.74667948e-01, -2.16731444e-01, 2.60375626e-02, - 5.77645637e-02, -7.56322592e-02, 2.28632554e-01, - 2.37157010e-02, -1.40153974e-01, -1.03399649e-01], - [ 2.37505659e-01, 7.01064467e-02, -3.83728333e-02, - 5.06979637e-02, 1.83892641e-02, 4.02548499e-02, - -3.88330072e-02, 3.13181393e-02, -5.75652197e-02, - 7.04995319e-02, -6.92743529e-03, -9.82947052e-02, - -4.91717793e-02, 4.06844541e-02, -1.53035461e-03, - 4.68783826e-02, 5.36918640e-03, -1.67432979e-01, - 1.03467651e-01, 3.48554403e-02, 3.20128165e-02, - 4.70223904e-01, 9.19904634e-02, 6.90946281e-02, - -6.94891065e-02, 3.92344594e-02, -6.30731881e-02, - 2.22810470e-02, -3.87494615e-03, 1.96694940e-01, - -1.92701817e-02, 2.01028123e-01, 1.89283062e-02, - -6.97807550e-01, 2.03354478e-01, -1.07645869e-01], - [ 2.22034067e-01, -1.60748392e-01, 2.42968962e-01, - -3.35482806e-01, -3.41870189e-02, 1.28819138e-01, - 1.24212839e-01, -3.87125909e-02, -5.60933471e-01, - 7.95257688e-02, -3.60307507e-02, 3.67332071e-01, - -5.87672107e-02, 7.33083040e-02, -3.94398779e-01, - -7.60597512e-02, 1.71925854e-02, 1.17799109e-02, - -2.65986789e-02, 1.98394638e-02, -1.35528380e-02, - -3.39059532e-02, 9.92002785e-02, -7.92167559e-02, - 9.19176906e-04, -4.89958897e-02, 5.72972372e-02, - 1.21006947e-02, 4.03640568e-02, -1.18844979e-01, - -2.80744191e-02, -1.74218431e-01, -4.31395955e-02, - -6.09265082e-02, 3.76862884e-02, -1.11892074e-01], - [ 2.06562474e-01, 1.73960440e-02, -2.63249487e-01, - 1.38902217e-01, -4.79032584e-02, -2.24852517e-01, - 4.69521992e-02, -3.35566737e-02, 1.37603536e-01, - -5.11448458e-02, 8.18398222e-02, 1.07205749e-01, - -1.46739393e-01, -1.30916521e-01, -2.28276670e-01, - -7.91462511e-02, 6.24803789e-02, 4.59876209e-02, - 8.15130547e-02, 1.46908918e-02, -2.61019613e-03, - 1.13239333e-01, 2.98404664e-01, -1.80148214e-01, - 1.44556239e-01, -3.98542970e-01, -4.15323582e-03, - 4.42554235e-01, 4.46505845e-02, 
-3.50878686e-02, - -1.36736231e-02, 1.28197059e-01, 1.92225441e-01, - 9.25138816e-02, -2.71676213e-01, -1.16138257e-01], - [ 1.91090912e-01, -3.68523598e-02, -6.60930753e-01, - 3.02158773e-01, 1.77503861e-02, 1.00428194e-01, - -1.10393446e-02, 9.11340117e-03, -7.01573640e-02, - -3.42316413e-03, -7.93189174e-05, 2.59178817e-01, - 1.22925844e-02, 6.14976510e-02, -1.56667307e-01, - -5.03374226e-02, -4.95696850e-02, -1.59401018e-02, - -4.26767953e-02, -5.12050986e-02, -6.04047906e-03, - 5.44762500e-02, -1.07276395e-01, -1.12534806e-01, - -1.20743208e-01, 3.80993217e-01, -2.20808387e-02, - -2.89817184e-01, 3.23761255e-02, -6.17432930e-02, - -3.90686616e-02, -5.96804358e-02, -4.96021062e-02, - 8.57739672e-02, -8.64073634e-02, -1.20384485e-01], - [ 1.75619304e-01, 1.71932317e-02, 4.29833472e-01, - 8.81271958e-02, -3.94745134e-02, -5.61874844e-02, - 7.05854744e-02, 7.86138419e-03, 4.67237175e-01, - -1.88353360e-02, 6.92435876e-02, -3.38627174e-02, - -8.19625556e-02, -4.84902970e-02, -2.62022078e-01, - -1.48765266e-01, 7.19114691e-02, -1.21600203e-01, - 1.18209779e-01, 2.58331411e-02, 4.69931588e-02, - 9.96347591e-02, 2.32059956e-01, -1.78489253e-01, - 1.77511200e-03, 1.59484446e-01, 3.28991674e-02, - -4.70239580e-01, 1.65105104e-01, -2.61324756e-02, - -1.49319443e-04, -8.15570727e-02, 7.44131976e-05, - 8.14437792e-02, -7.25714415e-02, -1.24630690e-01], - [ 1.60147712e-01, -1.10780589e-01, 2.73144871e-01, - 1.10703602e-01, 2.37337053e-02, 4.52041216e-02, - 1.52682560e-02, -3.83009948e-02, 2.30164632e-01, - 2.54375394e-02, -3.03758867e-02, 8.13979190e-03, - 2.33282149e-02, 3.12441736e-02, -1.84844747e-01, - 2.14728359e-02, -5.53616770e-02, -2.22909674e-02, - -9.31906551e-02, -1.01961263e-01, -3.32283713e-02, - 8.18983093e-02, -3.90430242e-01, 1.43959653e-02, - -1.31596243e-02, 4.55893874e-01, -4.22518775e-02, - 5.82709551e-01, -1.36653170e-01, -3.07889320e-02, - -4.67781313e-02, -6.33331314e-02, -5.06754033e-03, - 3.76623571e-02, -6.18892610e-02, -1.28876895e-01], - [ 1.44676119e-01, -2.91557442e-02, 2.55934417e-01, - 5.66692650e-01, 3.84408869e-02, 1.04354315e-01, - -1.37322113e-01, 7.15484237e-03, -1.95520781e-02, - -2.59401686e-02, -9.82144028e-02, 2.44248882e-01, - 1.52861271e-02, 1.99174404e-01, 2.76121795e-01, - 8.94557908e-02, -1.24152258e-01, 6.37411512e-03, - -1.13803938e-01, -3.23315486e-02, -3.17632034e-02, - 2.70075332e-02, 2.75091957e-02, -4.90174480e-02, - -2.08239228e-01, -3.95830333e-01, -5.95310889e-02, - -4.46558185e-03, -7.16161057e-02, -4.99811508e-02, - -1.02262713e-01, -2.79212356e-01, -5.11405505e-02, - 2.62467805e-02, 1.03744328e-01, -1.33123115e-01], - [ 1.29204527e-01, -1.63312718e-01, -1.99243486e-01, - -2.34051406e-01, 3.55675933e-03, 1.56449080e-02, - 9.30304453e-02, -7.26388171e-02, 1.25461653e-01, - 1.20737530e-01, 4.42517921e-02, -4.18601990e-01, - -1.94645032e-01, 1.02710314e-01, -7.12260604e-02, - -6.79927021e-02, -3.08946688e-02, -8.88019353e-02, - -4.35314551e-02, -2.15784147e-01, -1.86102502e-02, - 5.49090989e-02, -3.75167191e-01, -8.20007622e-02, - -2.06737250e-01, -3.52603942e-01, -3.86392660e-02, - -2.84039471e-02, 2.83454835e-01, -2.61564963e-02, - 1.20758023e-02, -2.92337686e-01, -5.17344326e-02, - 3.77417319e-02, 1.23368390e-01, -1.37369320e-01], - [ 1.13732927e-01, -6.13378249e-02, -1.77854180e-01, - -4.99198377e-01, 2.01901477e-02, 1.41450047e-01, - -3.23677920e-02, 9.39797983e-03, 5.04098058e-01, - 1.23931216e-02, -8.47154856e-02, 3.81212860e-01, - 1.21610202e-01, 4.87964153e-02, 2.52459884e-01, - 1.51112108e-02, -4.74717468e-02, -1.84605867e-02, - 
-7.36073852e-02, 3.58235948e-02, -7.69592915e-03, - -7.00120777e-02, 1.28127992e-01, 4.49521616e-02, - -7.93955289e-04, -3.76549661e-02, 1.04670962e-02, - 7.88062997e-03, -2.23614484e-01, -9.32817012e-02, - -4.67354655e-02, -1.74636483e-01, 1.47633761e-01, - -1.42957285e-01, 7.11189136e-02, -1.41615525e-01], - [ 9.82613564e-02, -1.55768439e-01, -1.11842593e-02, - 6.37831986e-02, 5.79317398e-02, 3.34746271e-01, - 3.84975046e-01, -2.11655404e-02, -4.85437140e-02, - -4.50517267e-01, -3.28294598e-02, -2.49714255e-01, - 3.28522325e-01, -1.25372112e-01, 2.82705110e-02, - 1.42169207e-01, -8.04641694e-02, 6.62415996e-02, - -9.59652960e-02, -5.61193414e-02, -4.80792150e-02, - -4.04721648e-02, 2.45707080e-01, 2.35501617e-01, - -4.14447524e-02, 4.34486791e-02, -4.62412462e-02, - 4.26126681e-02, 2.55748153e-01, -7.83308148e-02, - 2.59090564e-03, -3.38329338e-02, 1.78729519e-01, - -3.09782606e-02, -8.34960043e-02, -1.45861730e-01], - [ 8.27897936e-02, -5.39819747e-02, 5.41151650e-02, - -2.87518036e-02, 1.98750496e-02, -1.58728033e-01, - -4.75713938e-01, 1.16178179e-02, -2.98879808e-03, - 2.26475924e-01, 2.46154964e-02, 1.24507852e-01, - 4.07826692e-01, -2.43859500e-01, 1.46053182e-02, - 8.78053382e-02, -7.19747171e-02, -4.02797535e-02, - -8.92022029e-02, -4.73439731e-02, 2.02829354e-02, - -9.01956186e-02, -1.16379023e-01, 1.02566876e-01, - 1.27621949e-01, -3.85584086e-02, -1.85301397e-02, - -1.46384817e-02, 5.42852879e-01, -1.11336805e-01, - -4.69652563e-02, 1.10105053e-01, -3.25540863e-02, - -9.18325037e-02, -1.09285243e-01, -1.50107935e-01], - [ 6.73181787e-02, -1.69579491e-01, -5.90509735e-02, - -8.87142718e-02, -4.61161807e-02, -1.32888526e-01, - -4.28256035e-01, -4.96512838e-02, -1.00748278e-01, - -1.56540096e-01, -1.33985683e-01, -3.31550747e-01, - 3.25447232e-01, 2.73245610e-02, -6.19893037e-02, - -1.48184791e-01, 1.88705355e-01, 1.62340149e-01, - 1.02853999e-01, 3.19841057e-01, -6.06105961e-02, - 1.69779122e-01, 1.54020518e-01, -8.75391066e-02, - -2.06520095e-01, 6.03866279e-02, 1.08508043e-01, - 4.56446186e-02, -2.30992153e-01, 6.16142601e-02, - 5.93037927e-04, -2.22505212e-01, -4.13618460e-02, - 1.47342280e-01, 4.37493399e-02, -1.54354155e-01], - [ 5.18466011e-02, 1.40082181e-01, -2.43853368e-02, - -9.01594944e-03, -2.02037729e-02, -2.15594158e-01, - -1.49669036e-01, -2.02583615e-02, 4.76960652e-03, - -4.28980350e-01, -2.16286242e-01, 2.93388069e-02, - -2.61512101e-01, 4.32281435e-01, 5.15976362e-02, - -2.38068718e-02, 1.35174215e-01, 1.65118262e-01, - 1.18229888e-01, -4.75422740e-02, -1.69874616e-02, - -9.87956077e-02, -6.16191179e-02, 1.92472130e-01, - 4.03664082e-01, 9.86855701e-02, 2.18016505e-02, - 9.58452746e-03, 2.42479756e-01, -9.45590809e-02, - 6.06411323e-02, -1.15035795e-01, -5.60823381e-02, - -1.10115618e-01, 7.84227401e-02, -1.58600360e-01], - [ 3.63750085e-02, 2.90070504e-01, 2.58655623e-02, - -4.51171659e-02, -9.76288766e-02, -7.32196262e-03, - 2.62665208e-02, -1.30719528e-01, -3.34864855e-02, - 1.83281839e-01, -2.03847468e-01, -7.86208585e-02, - 2.39961028e-01, 9.32282284e-02, -1.40201841e-02, - -1.65743440e-01, 2.50046160e-02, 1.87149823e-01, - -1.68221984e-02, -6.99453712e-01, 2.46135090e-02, - 9.76792276e-02, 1.59403309e-01, 1.05807781e-01, - -1.64897703e-02, -3.37719321e-02, 9.97098759e-02, - -5.71760125e-02, -2.09543109e-01, 1.61970984e-02, - 4.49959114e-02, 1.13044158e-01, -1.33089647e-01, - 6.79383874e-02, -1.17107280e-01, -1.62846550e-01], - [ 2.09034402e-02, 9.87452939e-02, 3.10002435e-02, - -3.82550769e-02, 6.49476936e-03, -1.86508909e-01, - -1.58566430e-01, 
1.52609888e-02, 2.44785240e-03, - -1.72963649e-01, 2.82357018e-02, 6.35804012e-02, - -4.01134878e-01, -3.48292142e-01, -9.30772051e-02, - 2.69406252e-02, -1.48355186e-01, 6.67649359e-02, - -1.52495161e-01, -4.16254858e-03, -7.79623985e-02, - -8.69922712e-02, 1.67651065e-02, 4.43452805e-01, - -4.69122916e-01, 1.32700158e-02, 1.84264123e-01, - -4.69396599e-02, -8.76988843e-02, -8.42647329e-02, - 1.80242240e-01, 4.39915545e-02, -3.01284958e-02, - -4.19178084e-02, -6.55100867e-02, -1.67092770e-01], - [ 5.43184578e-03, -8.44964292e-03, 5.85759105e-03, - -7.32589066e-02, -6.53161779e-02, 1.58945680e-01, - -1.98484868e-01, -2.29594544e-01, -3.62942442e-02, - -4.60159145e-02, 4.65791941e-01, -1.32931456e-01, - -1.30874768e-01, 1.82594404e-01, 4.72868867e-02, - 7.68151507e-02, -1.17584936e-01, -7.83182383e-02, - -5.70569098e-01, 5.07849343e-02, -6.92476258e-02, - 1.45652056e-01, 1.57256410e-01, -2.92076059e-02, - 2.85284370e-01, 2.52744146e-02, 2.82830708e-02, - -5.04164398e-02, -1.00659683e-01, 5.86346574e-02, - 1.91001222e-02, 8.99196714e-02, -1.54763028e-01, - 1.01448707e-01, -7.42661506e-02, -1.71338975e-01], - [-1.00397598e-02, 6.89980984e-02, 5.02617331e-03, - -5.32203764e-02, 1.92967560e-02, -5.64105034e-01, - 3.46719325e-01, -7.40835667e-02, -5.14018210e-03, - 9.32325572e-02, 1.93343818e-01, 3.23573984e-02, - 2.21131876e-01, 3.06417048e-01, -8.70961323e-03, - 4.47171003e-01, 8.35162401e-02, 8.83740187e-02, - -8.72178078e-02, 1.18704282e-01, 1.05058528e-01, - -4.56921048e-02, 1.59751941e-02, -3.00876088e-02, - -2.47394085e-01, 4.93424907e-02, -6.64604902e-02, - -3.64027135e-02, -1.82686392e-02, -4.59523462e-02, - -1.26862470e-02, 2.52796169e-02, -4.81151454e-02, - -2.86283679e-02, -2.56162435e-02, -1.75585181e-01], - [-2.55113579e-02, 1.63476765e-02, -6.48622513e-02, - 8.53358284e-02, -1.47179626e-02, -2.74279952e-01, - 3.23813617e-01, 1.18787922e-01, -3.12188938e-02, - 1.27388835e-01, -1.47029653e-01, -6.44396339e-03, - 1.59717619e-01, -8.00469816e-02, 4.15628105e-02, - -3.71895492e-01, -2.58336008e-01, -3.58502686e-01, - -9.30814072e-02, 2.37474293e-01, -1.02323368e-01, - 7.77886510e-02, -2.62345857e-04, 3.05618107e-01, - 2.69323707e-01, -4.94645983e-02, 7.17321262e-02, - 1.81141701e-02, -7.26979673e-02, 3.66130173e-02, - 3.41478437e-02, -1.42837018e-01, -2.29302347e-01, - 9.40499976e-02, 9.85415503e-02, -1.79831386e-01], - [-4.09829244e-02, 2.96095997e-01, 5.72670512e-02, - -1.39296770e-01, -1.60581374e-03, 2.67294142e-02, - 5.13432994e-02, 3.44210893e-01, -4.88008671e-02, - -1.20673403e-01, -4.54095185e-01, -3.60888802e-02, - -3.48375738e-02, -3.80728357e-02, 6.19033575e-02, - 2.85812598e-02, -5.49174994e-02, 8.16437509e-03, - -3.89526159e-01, 1.42197743e-01, -6.57034442e-02, - 9.32944417e-02, -1.29381031e-01, -4.54968363e-01, - -7.63084590e-02, -1.27602285e-02, -3.93663906e-02, - -2.22954508e-02, 9.34363678e-02, 4.61584628e-02, - 1.17300354e-01, 1.84356645e-01, 4.64061499e-02, - 2.61230320e-02, -1.38632745e-01, -1.84077591e-01], - [-5.64545169e-02, -3.65092814e-01, -4.26685773e-02, - 1.75265297e-02, -1.79290678e-03, 7.54252076e-02, - -2.16403184e-03, 1.22491851e-01, 4.61655157e-03, - 9.93698239e-02, -2.86250204e-01, 1.17600495e-02, - -1.76643163e-01, -1.61555171e-01, 4.21675071e-02, - 4.96386349e-01, 2.84064054e-01, -1.88499331e-01, - 5.03461063e-02, -9.29289460e-02, 2.72047639e-01, - 1.54824242e-01, 7.62812719e-02, 9.09931362e-02, - 1.82046860e-01, -1.51961623e-02, 1.57171339e-01, - -2.52939817e-02, -6.88583925e-02, 8.74516144e-02, - 1.06507227e-01, 3.63174151e-03, -2.16592148e-01, - 
1.95526704e-01, -2.63463091e-02, -1.88323811e-01], - [-7.19260871e-02, 1.53307199e-01, 2.98810583e-02, - -1.76042188e-02, 4.68952209e-02, 2.30930567e-01, - -1.91631261e-02, -3.50371659e-01, -1.39247498e-03, - -3.16982158e-02, 3.19441818e-02, 1.38011038e-01, - 1.15297228e-01, 1.21593997e-01, 1.12343794e-02, - -6.25559241e-02, 2.27593221e-02, -1.95765942e-01, - 2.61839062e-01, 1.88924655e-01, 1.47905156e-01, - 3.61047573e-02, -1.53986499e-01, 4.26004231e-02, - -1.01659156e-01, -9.87078920e-02, -1.97795078e-01, - 2.87956242e-02, 2.66166143e-02, 2.03926936e-02, - 6.36121154e-01, 1.17329828e-01, -1.68884546e-02, - 1.05052806e-01, -1.36004210e-01, -1.92570001e-01], - [-8.73977244e-02, 2.91939259e-01, -6.38535023e-02, - 1.23778999e-01, 2.33115517e-02, 8.99281502e-02, - -2.38235518e-02, 2.54457176e-01, -2.92873345e-02, - 1.45903289e-01, 2.51857221e-01, -1.22888424e-01, - 4.71667722e-02, -1.51163086e-01, -6.75680041e-02, - 1.34960130e-01, -5.27166612e-02, 5.85827529e-02, - 6.49949759e-02, -6.27990216e-02, 7.91215152e-02, - -2.11644500e-01, 1.25666901e-01, -2.19153777e-01, - 1.45102561e-01, 9.46507752e-02, 2.63710856e-01, - 1.36273995e-01, -2.85680946e-02, -9.64817554e-02, - 3.51572961e-01, -3.73799771e-01, 7.54300505e-02, - -1.52278930e-01, 2.77134597e-01, -1.96816236e-01], - [-1.02869295e-01, 4.54483837e-01, -3.16920318e-02, - -9.15080402e-03, 4.94015254e-02, 2.09832817e-01, - 9.22076330e-02, -3.92193407e-01, -1.33265834e-03, - 1.03313603e-01, -7.82989189e-02, 8.86598602e-03, - -9.18587223e-02, -1.70766622e-01, 5.54255210e-02, - 2.28601284e-02, 1.81634039e-01, 4.14796174e-02, - 3.81892845e-02, 2.48120666e-01, 1.65915981e-01, - 2.87097245e-02, -2.50649545e-02, 4.36540544e-02, - -5.01171201e-02, 3.54694985e-02, 1.90053612e-01, - 9.52630565e-02, 1.70738876e-01, 3.70882489e-02, - -4.90600616e-01, -9.28841755e-02, -8.13470930e-02, - 8.31348598e-02, 5.93565181e-02, -2.01062426e-01], - [-1.18340865e-01, -6.85950592e-02, 4.95309308e-02, - -1.77844893e-02, -9.69045609e-02, 2.31995173e-02, - -1.06131600e-03, 2.21603140e-01, -6.05566725e-02, - -2.82245725e-01, 2.64784724e-01, 8.62200931e-02, - 1.37575060e-01, 1.50092602e-01, 4.38311473e-02, - -1.27834529e-01, -1.75913945e-02, -2.03415841e-01, - 1.48476526e-01, -7.80855790e-02, 2.29345813e-01, - 3.37421596e-02, -3.02611887e-01, -3.64654101e-02, - -4.98286486e-02, -1.24875009e-01, 5.32554924e-01, - -5.55246398e-02, -8.19649324e-02, 4.32646945e-02, - -1.92818239e-01, 1.91410363e-01, 1.91146538e-01, - -1.30635314e-02, -1.27977282e-01, -2.05308631e-01], - [-1.33812457e-01, 5.83807267e-02, 6.38746191e-03, - -6.32736981e-02, 2.60766506e-01, 1.92557305e-01, - -4.26477045e-02, 5.47973156e-01, 1.53431622e-02, - 2.03396276e-01, 2.18420655e-01, 1.71779748e-02, - -7.09848702e-02, 2.39939511e-01, -2.50959713e-02, - -1.48106590e-01, 1.51656091e-01, 1.71890616e-01, - 7.37760216e-02, 5.53064533e-02, 1.98505912e-02, - 9.67100039e-02, 1.37430176e-01, 2.82746285e-01, - -1.24559112e-01, 1.80215873e-02, -2.68079907e-01, - 9.55012143e-02, 1.30839288e-01, 8.27972442e-02, - -9.96278524e-02, 4.17835526e-02, -4.81917933e-02, - 1.98767141e-01, -6.95911944e-02, -2.09554836e-01], - [-1.49284035e-01, -7.56456144e-03, -8.76261014e-03, - 2.92932428e-02, -8.39372516e-01, 5.67366369e-02, - -2.41059046e-02, 8.43372419e-02, -2.29054149e-02, - 3.72556150e-02, 3.59098194e-03, -3.51436548e-02, - -4.86128107e-02, -4.90781479e-02, -2.96334457e-02, - 2.16081198e-02, -6.04292788e-02, 1.73466746e-02, - 5.54120354e-02, 4.32790630e-02, 1.27067477e-01, - -9.41377804e-02, -1.37587115e-02, 7.06801787e-02, - 
-1.22610051e-02, 2.18931045e-02, -3.70597780e-01, - -1.30672632e-02, -4.53533195e-02, -1.70034133e-02, - -1.13316208e-01, -3.45941707e-02, 1.05737671e-01, - -2.95185428e-02, 2.46357918e-02, -2.13801056e-01], - [-1.64755657e-01, -1.91551998e-01, 1.24477036e-02, - 1.76897332e-01, -1.70191415e-02, 2.34046783e-02, - 6.76611960e-02, -1.21719569e-01, -1.60261299e-02, - 2.84169883e-01, -7.72131458e-02, -4.39732298e-02, - -6.60723150e-02, 8.68341923e-02, 7.35200867e-02, - -1.56345084e-01, 4.99212921e-01, -9.53519195e-02, - -1.69593558e-01, 3.12364921e-02, -4.14223462e-01, - -2.19161183e-01, -7.49167113e-04, 4.25142385e-02, - -2.26298310e-02, 3.90600637e-02, 1.34113848e-01, - -4.32782359e-02, -2.25105719e-03, -8.36708769e-02, - 7.53742829e-02, 1.09890841e-01, 3.47145647e-01, - -1.67040601e-01, -4.17540558e-02, -2.18047246e-01], - [-1.80227250e-01, -3.65751952e-01, 1.95310116e-02, - 3.56181487e-02, -2.47674435e-02, -2.56252866e-02, - 1.70394495e-01, -1.01341322e-01, 6.43750429e-02, - -1.18520278e-02, 7.76712969e-02, 1.21111691e-01, - -7.56260678e-02, -1.32285401e-01, 2.50612080e-01, - -2.70852149e-01, -9.66061503e-02, 4.63890702e-01, - 5.18286489e-02, 1.14975851e-02, 7.05922395e-02, - 7.95801077e-03, 3.40116471e-02, -2.50298321e-01, - -4.72176410e-02, 7.11330771e-02, 7.71585703e-02, - 7.12307394e-02, 1.51480496e-01, 4.94032800e-02, - 9.26278085e-02, 1.93590626e-01, -3.63108933e-01, - -1.36400744e-01, 1.46016315e-01, -2.22293481e-01], - [-1.95698813e-01, 8.16941485e-02, 6.35532150e-03, - -5.50320372e-02, 1.45350844e-01, -7.66825154e-02, - -1.48402769e-02, 8.44644289e-03, -3.05129532e-02, - -3.45072865e-01, 1.88118920e-01, 1.39703169e-01, - 9.01852995e-02, -3.05740625e-01, -7.54492134e-02, - 6.51175901e-02, 2.45817453e-01, -1.89270392e-01, - 1.16880536e-01, -2.26171866e-01, -3.72853994e-01, - 5.43844700e-03, -1.24716990e-01, -1.48458153e-01, - 5.83554097e-02, -8.44632387e-02, -3.41172040e-01, - -5.05601391e-02, -1.60052970e-01, 5.74440435e-02, - -1.45993277e-01, -4.03214097e-02, -2.16732427e-01, - -2.84256153e-02, 1.41579702e-01, -2.26539686e-01], - [-2.11170420e-01, -6.31088763e-02, 8.17671046e-03, - -5.57366088e-02, 6.94130734e-02, 3.52174342e-02, - -6.57851174e-02, -9.82191563e-02, -1.27271414e-02, - 1.43996403e-01, -1.19659491e-01, -5.62400967e-02, - -1.02117673e-01, 1.46197915e-01, -6.46053180e-02, - 2.75428176e-01, -5.38663089e-01, 1.51460487e-02, - 3.81278455e-01, 1.08411210e-02, -4.44346756e-01, - 4.02242467e-02, 9.23668295e-02, -7.21167400e-02, - 3.91138941e-02, 4.99221608e-02, 9.94546860e-02, - -3.87978405e-02, 1.93843860e-02, 8.32882449e-02, - -1.15623131e-01, 8.08125958e-02, 1.40358344e-01, - 1.01261795e-01, -5.90205789e-02, -2.30785877e-01], - [-2.26641983e-01, -1.44536331e-01, 8.91233422e-03, - 5.05167954e-02, 3.87359351e-01, -1.25706807e-01, - -9.50697213e-02, -1.42298609e-01, -7.01352954e-02, - -3.15868692e-03, -1.33074358e-01, -1.18453935e-01, - -7.71054849e-02, -4.75535467e-02, -1.50268868e-01, - -1.44392461e-01, -1.82032049e-01, -1.19762598e-02, - -1.21959276e-01, -6.38470054e-02, 4.80738163e-01, - -1.59658909e-01, 2.71296166e-02, -4.31644246e-02, - 1.02411315e-01, 2.07743910e-03, -2.89108336e-01, - -1.03720047e-01, -2.01758668e-01, -2.16420572e-02, - -1.27163813e-01, -7.36601278e-03, 3.14732850e-01, - -1.12868495e-01, 3.11465543e-02, -2.35032097e-01]], dtype=float32), array([-1.89882166e+03, -1.79985218e-04, -1.70435800e-04, -1.27975552e-04, - -1.24901737e-04, -1.24676313e-04, -1.16428266e-04, -1.06598200e-04, - -1.00050034e-04, -9.61478145e-05, -8.36294785e-05, -6.41566730e-05, - 
-4.51904889e-05, -2.39018827e-05, -1.49146554e-05, -9.43070791e-06,
- -8.04440424e-06, 1.51055592e-05, 2.01099483e-05, 2.64523860e-05,
- 3.25085311e-05, 5.15936626e-05, 5.31896258e-05, 7.24942220e-05,
- 9.04739063e-05, 1.04830775e-04, 1.08393360e-04, 1.37811687e-04,
- 1.49946762e-04, 1.86386926e-04, 1.89535742e-04, 2.40968098e-03,
- 2.56012683e-03, 2.69382820e-03, 3.27441283e-03, 2.52088105e+04],
- dtype=float32)),
- mlir_module_text=r"""
-module @jit__lambda_ {
- func.func public @main() -> (tensor<36x36xf32> {jax.result_info = "[0]"}, tensor<36xf32> {jax.result_info = "[1]"}) {
- %0 = stablehlo.iota dim = 0 : tensor<1296xf32>
- %1 = stablehlo.reshape %0 : (tensor<1296xf32>) -> tensor<36x36xf32>
- %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<36x36xf32>) -> tensor<36x36xf32>
- %3 = stablehlo.add %1, %2 : tensor<36x36xf32>
- %4 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
- %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<36x36xf32>
- %6 = stablehlo.divide %3, %5 : tensor<36x36xf32>
- %7 = call @tril(%6) : (tensor<36x36xf32>) -> tensor<36x36xf32>
- %8 = stablehlo.custom_call @cusolver_syevd(%7) {api_version = 2 : i32, backend_config = "\00\00\00\00\00\00\00\00\01\00\00\00$\00\00\00Y\98\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<36x36xf32>) -> tuple<tensor<36x36xf32>, tensor<36xf32>, tensor<i32>, tensor<39001xf32>>
- %9 = stablehlo.get_tuple_element %8[0] : (tuple<tensor<36x36xf32>, tensor<36xf32>, tensor<i32>, tensor<39001xf32>>) -> tensor<36x36xf32>
- %10 = stablehlo.get_tuple_element %8[1] : (tuple<tensor<36x36xf32>, tensor<36xf32>, tensor<i32>, tensor<39001xf32>>) -> tensor<36xf32>
- %11 = stablehlo.get_tuple_element %8[2] : (tuple<tensor<36x36xf32>, tensor<36xf32>, tensor<i32>, tensor<39001xf32>>) -> tensor<i32>
- %12 = stablehlo.get_tuple_element %8[3] : (tuple<tensor<36x36xf32>, tensor<36xf32>, tensor<i32>, tensor<39001xf32>>) -> tensor<39001xf32>
- %13 = stablehlo.constant dense<0> : tensor<i32>
- %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i32>) -> tensor<i32>
- %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
- %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
- %17 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
- %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<f32>) -> tensor<36x36xf32>
- %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<36x36xi1>
- %20 = stablehlo.select %19, %9, %18 : tensor<36x36xi1>, tensor<36x36xf32>
- %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<i1>) -> tensor<1xi1>
- %22 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
- %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<f32>) -> tensor<36xf32>
- %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<36xi1>
- %25 = stablehlo.select %24, %10, %23 : tensor<36xi1>, tensor<36xf32>
- return %20, %25 : tensor<36x36xf32>, tensor<36xf32>
- }
- func.func private @tril(%arg0: tensor<36x36xf32>) -> tensor<36x36xf32> {
- %0 = stablehlo.iota dim = 0 : tensor<36x36xi32>
- %1 = stablehlo.constant dense<0> : tensor<i32>
- %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<i32>) -> tensor<36x36xi32>
- %3 = stablehlo.add %0, %2 : tensor<36x36xi32>
- %4 = stablehlo.iota dim = 1 : tensor<36x36xi32>
- %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<36x36xi32>, tensor<36x36xi32>) -> tensor<36x36xi1>
- %6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
- %7 =
stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<36x36xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<36x36xi1>, tensor<36x36xf32> - return %8 : tensor<36x36xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x1b\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02v\x08\x1d\x85\x03\x17\x11R\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x11N\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11F\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05\x91\x91\t)\x01\t\x1b)\x01\x05\t)\x03\x91\t\x1d\x01)\x05\x91\x91\x05\x13)\x05\x91\x91\x0f)\x03\t\r)\x03\x94\x85\t\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x82(\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03\x91\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x1
1\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00.\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x89\x8dW\xb7K\x9fM\x9f\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(36, 36) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(1296,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(36, 36) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00$\x00\x00\x00Y\x98\x00\x00\x00cusolver_syevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64_syevj=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevj'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 6.1857700048412179e-01, -7.9870412160195655e-05, - -7.1795133407817180e-02, 7.2651725579187088e-01, - -5.8816812454044016e-04, -1.0752133550364418e-01, - -1.9695247974936425e-01, 1.8446994643771727e-01], - [ 4.7070881487314487e-01, 3.3071017759156432e-05, - -5.9630159401629157e-01, -4.7856902268752244e-01, - -1.4151478943184035e-03, -2.5017522435505674e-01, - 2.8106392345809550e-01, 2.2856669794666581e-01], - [ 3.2284062926217122e-01, -5.1104181032785456e-01, - 2.4098685972870454e-01, -3.2057977627137213e-01, - 6.0128498619340851e-04, 5.5435726441071020e-01, - -3.0349043125069775e-01, 2.7266344945561433e-01], - [ 1.7497244365119549e-01, 4.1809211960021736e-01, - 5.7112844532216078e-01, -3.1146378582869927e-01, - -4.8989605706119613e-04, -4.4689091764000977e-01, - -2.6709076241922963e-01, 3.1676020096456298e-01], - [ 2.7104258040218803e-02, 4.2941995817157164e-01, - 1.1304358388496584e-01, 9.3073375918824142e-02, - -4.7236149166811120e-01, 4.6617552271070906e-01, - 4.7197416944525139e-01, 3.6085695247351168e-01], - 
[-1.2076392757075657e-01, -3.8434927079561992e-01,
- 2.9171425263113138e-01, 1.5624558970245273e-01,
- 4.3260383504376299e-01, -2.0278835428567779e-01,
- 5.7959048064074936e-01, 4.0495370398246017e-01],
- [-2.6863211318173014e-01, 3.6363990709349564e-01,
- -3.3163183889685732e-01, 4.2836063092320187e-02,
- 5.6343802845177837e-01, 2.7652818360156795e-01,
- -2.9700444618985122e-01, 4.4905045549140854e-01],
- [-4.1650029879270561e-01, -3.1571410434740910e-01,
- -2.1714457524599659e-01, 9.1940300282126255e-02,
- -5.2178844473770358e-01, -2.8968513893859849e-01,
- -2.6809045393495168e-01, 4.9314720700035708e-01]]), array([-2.4598804776133605e+01, -2.8026300235964570e-15,
- -1.8958980326674837e-15, 1.5553235693581772e-15,
- 1.6670762548207520e-15, 2.2405283578797194e-15,
- 5.4086800892994285e-15, 2.7659880477613365e+02])),
- mlir_module_text=r"""
-module @jit__lambda_ {
- func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) {
- %0 = stablehlo.iota dim = 0 : tensor<64xf64>
- %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64>
- %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64>
- %3 = stablehlo.add %1, %2 : tensor<8x8xf64>
- %4 = stablehlo.constant dense<2.000000e+00> : tensor<f64>
- %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f64>) -> tensor<8x8xf64>
- %6 = stablehlo.divide %3, %5 : tensor<8x8xf64>
- %7 = call @tril(%6) : (tensor<8x8xf64>) -> tensor<8x8xf64>
- %8 = stablehlo.custom_call @cusolver_syevj(%7) {api_version = 2 : i32, backend_config = "\01\00\00\00\00\00\00\00\01\00\00\00\08\00\00\00M\08\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<8x8xf64>) -> tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<2125xf64>>
- %9 = stablehlo.get_tuple_element %8[0] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<2125xf64>>) -> tensor<8x8xf64>
- %10 = stablehlo.get_tuple_element %8[1] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<2125xf64>>) -> tensor<8xf64>
- %11 = stablehlo.get_tuple_element %8[2] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<2125xf64>>) -> tensor<i32>
- %12 = stablehlo.get_tuple_element %8[3] : (tuple<tensor<8x8xf64>, tensor<8xf64>, tensor<i32>, tensor<2125xf64>>) -> tensor<2125xf64>
- %13 = stablehlo.constant dense<0> : tensor<i32>
- %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i32>) -> tensor<i32>
- %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
- %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<i1>) -> tensor<1x1xi1>
- %17 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
- %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor<f64>) -> tensor<8x8xf64>
- %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1>
- %20 = stablehlo.select %19, %9, %18 : tensor<8x8xi1>, tensor<8x8xf64>
- %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<i1>) -> tensor<1xi1>
- %22 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64>
- %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor<f64>) -> tensor<8xf64>
- %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<8xi1>
- %25 = stablehlo.select %24, %10, %23 : tensor<8xi1>, tensor<8xf64>
- return %20, %25 : tensor<8x8xf64>, tensor<8xf64>
- }
- func.func private @tril(%arg0: tensor<8x8xf64>) -> tensor<8x8xf64> {
- %0 = stablehlo.iota dim = 0 : tensor<8x8xi32>
- %1 = stablehlo.constant dense<0> : tensor<i32>
- %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf64> - return %8 : tensor<8x8xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b/O/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa2\x08\x1d\x85\x03\x17\x116\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x112\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11*\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\t\x1b)\x01\x05\x0b)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03jB\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\
x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x06\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x08\x00\x00\x00M\x08\x00\x00\x00cusolver_syevj\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64_syevd=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-3.1486359056225782e-01, 3.7431364158123925e-02, - 6.1831284766658730e-02, -1.2946991231313536e-02, - 1.9330566993707950e-02, 3.1760201896488226e-03, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 
0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 9.4213470166864710e-01, -8.6414847942068732e-02], - [-2.9939200325938797e-01, 8.3501568928299474e-01, - 4.0680107296867257e-01, -4.6573192775473518e-02, - 6.5422207600829785e-02, 2.2099527094683900e-02, - -1.0242349878775975e-02, 4.0829390183091318e-03, - -1.5827725558444371e-02, -8.6793932713605769e-03, - 1.3047005177451432e-03, -5.3573283556152184e-03, - -1.1723085990292578e-02, -3.4282481604778923e-03, - 1.5300655388654032e-03, 1.3010433879291027e-02, - -7.6245808434662662e-03, 5.9569775610370131e-04, - -5.9294293157650772e-03, -1.9734040942842074e-03, - -1.8628968192927392e-02, -1.3034235399858809e-02, - -5.0097004610369401e-03, 2.4749245795903537e-02, - -5.0644358547264675e-03, 3.0532167800601515e-03, - 2.0824661626164857e-02, -1.5147462161617094e-03, - 1.6322395782111299e-02, -1.1236053191734820e-02, - -1.1821960842042806e-02, 3.8822577430670320e-03, - 7.0724820528586508e-04, 1.9906723944256747e-02, - -1.7030338737863057e-01, -9.0661051391036640e-02], - [-2.8392041595652112e-01, -1.0171687781151459e-01, - -1.1816431661072314e-01, 2.9212172394267638e-01, - 3.3294458108354380e-01, 4.2087881292542445e-01, - -2.2194306321456944e-01, 1.2056157631930936e-01, - -1.0764065526585581e-01, 4.4945129933377570e-02, - -1.1518299700192679e-01, -3.1085391640205563e-02, - 3.1385765542768805e-02, -2.2533661915179113e-02, - 9.3053311217867085e-02, -1.6099650538834706e-01, - -3.8639305088265900e-02, 9.2990366329018387e-03, - 4.6666113341746911e-02, -2.1871647987757620e-01, - 1.7703518610745730e-01, 1.5467613762024190e-01, - -7.2294521250116733e-02, 2.3499877830015681e-01, - -5.6829378083033165e-03, -1.0178485446351725e-01, - 1.7877785721217213e-01, 2.1684187554288339e-01, - 7.7233872499541889e-02, 2.2835265304748494e-02, - 3.1080805156356406e-01, 3.1722234078538948e-02, - -7.8092425763001377e-02, 9.4554636051152510e-02, - -9.6031463624110386e-02, -9.4907254840003452e-02], - [-2.6844882865365438e-01, -2.0201860535424061e-02, - -2.0343029420688158e-01, 1.2815855886454322e-01, - 4.8774092445450092e-02, 1.3232562034943543e-01, - -1.8521836621459195e-01, 9.8747816539597660e-02, - 2.7324903486606195e-01, -7.8737437097193080e-02, - 4.9421661772677816e-02, 7.1493931251323112e-02, - 3.5542595611320515e-01, 1.3920746216059152e-01, - -2.8249741974519734e-02, 6.7932896387190703e-02, - -2.3008512044551552e-01, 5.5015746716542496e-02, - -6.0329018554125865e-03, 8.4249901371007491e-02, - -1.0850059549176212e-01, -2.7052679792044718e-02, - 1.7199248671821082e-01, -2.0779039909219962e-01, - 1.1023999772580403e-01, 4.0228126834019268e-01, - -7.1331569093078903e-02, -2.2546040356632324e-01, - -5.6848723613690040e-02, 2.0039103669806510e-01, - -2.2375524112669190e-01, -6.6955463229343037e-02, - -1.4356710092268696e-01, 2.2907198003730800e-01, - -8.4342913246148038e-02, -9.9153458288970819e-02], - [-2.5297724135078736e-01, -9.7633470097019753e-02, - -2.0613664461051402e-02, -4.6575018452204114e-01, - -4.5475545929408095e-01, -1.6835202228307944e-01, - -2.7411043542686481e-01, 1.4382896244553764e-01, - 1.5533482960243880e-01, -7.7897907011887785e-02, - -5.9104799414908579e-02, -5.1049057176047449e-02, - 5.0937034273965797e-03, -2.9920502980456239e-02, - 7.9164430071644656e-02, 6.5334090456028976e-02, - -2.4594170101813598e-01, 4.0287932953704184e-02, - 1.3071075582032446e-01, -5.6912271071735306e-02, - -1.2680756132856946e-02, 3.5044366466197449e-02, - -5.1780762628180410e-03, 
1.2325979893038844e-01, - -1.3286387357961091e-01, -1.9718715617446650e-01, - -7.0204376770955132e-02, -9.3710658292701816e-03, - 7.6870928390159760e-03, 1.2623341382152653e-01, - 3.4895566103640097e-01, 7.7553659039143241e-02, - -3.4023999296528072e-02, 8.3074702907895745e-02, - -8.5300072672481381e-02, -1.0339966173793817e-01], - [-2.3750565404792034e-01, -8.2181485614283623e-02, - -2.4796576412755008e-02, 2.6469606244089910e-01, - 2.5136155191565374e-01, -8.5932117879471037e-01, - -6.7801327364868255e-02, 2.3630380146045637e-02, - -6.0339530364635997e-02, 2.4318784991642788e-02, - -2.0157980609574723e-02, 1.3969684905577337e-02, - 5.2064373452097072e-02, -1.3504287072787914e-03, - 1.1948855400414819e-02, -7.7684684576308824e-02, - -1.8126869586737940e-02, -3.2895203661275497e-02, - -4.7194795185232655e-03, -6.2526420481870917e-02, - 7.8353014950393762e-02, 4.3021669650274826e-02, - 4.1123834759705602e-02, 2.1527669096626890e-02, - 3.2298969317449348e-02, 2.3438124417394162e-02, - 3.1518151219115144e-02, 8.9704214482948422e-02, - 7.6821260017619769e-03, -8.5409778343425186e-03, - 1.5521001031338759e-02, -1.3290428648657086e-02, - 1.8906628930454021e-02, -1.2782589525387992e-02, - -8.2979044248598546e-02, -1.0764586518690553e-01], - [-2.2203406674505338e-01, -9.0264475102341105e-02, - 9.0740700176499111e-03, 6.9171384437416147e-02, - -1.3111811612891669e-01, -1.8966507957248607e-02, - 4.0414304307463594e-01, -3.2564666059313241e-02, - 5.6086124244845181e-01, -4.0083205571491060e-02, - -2.4505702319715772e-02, 2.8981348567837486e-02, - -1.8028953963325864e-01, 1.2810669493073431e-02, - -3.0205734928244080e-02, 1.3016546116209483e-03, - 4.1180187675978214e-01, 1.8487430939971340e-03, - 2.1878399115523185e-02, -1.2942737544986772e-02, - 3.1612876215063763e-02, 1.9040590265843902e-02, - -2.9853451951736565e-01, -2.1069261774264141e-02, - 1.2756924052704141e-02, 1.0396556130345047e-02, - 2.0982593071380967e-01, 7.2513245350085284e-02, - 2.6961322653924678e-02, 4.4259057451694346e-02, - 1.3245555422671054e-02, -1.1355432725780245e-02, - -1.6423769454471046e-01, 2.1283797622603673e-01, - -7.7771821344734746e-02, -1.1189206863587289e-01], - [-2.0656247944218639e-01, -7.5555152047925872e-02, - 2.1436004480934572e-02, 1.8519822533150174e-01, - -4.7687267679858099e-02, 1.0893715640778658e-01, - 5.4446388557811642e-01, 6.7864355635107079e-02, - 1.8925675037139755e-01, 3.6392773516755073e-02, - -2.4764455183159433e-02, -3.8468294614801751e-02, - -2.8696444635530814e-02, -1.8823021866307067e-02, - 4.8264052464878845e-02, -3.6882747079153497e-02, - -3.0155420938729255e-01, 1.0404831951207047e-02, - 4.4505477004053171e-03, -4.6873846610364103e-02, - 2.4798470273412251e-02, 2.5891733287640804e-02, - 3.5011544817152707e-01, 8.5903050378751358e-02, - -1.6860450574909990e-02, -3.9052038500091160e-02, - -2.9924661599529656e-01, -1.5823886416275893e-03, - 2.8254484941419005e-03, -4.8861168063938747e-03, - 9.7917302635802658e-02, 2.7710576047465570e-02, - 2.3536560145276611e-01, -3.9600571986552502e-01, - -7.4934893198527877e-02, -1.1613827208484023e-01], - [-1.9109089213931946e-01, -8.4666472598656825e-02, - 5.7740802097843921e-02, 1.9626130737187028e-01, - -2.4601756649487860e-01, 8.1511271167717628e-02, - -4.6530930078469529e-01, 6.8795587726048116e-02, - 5.2415554010200038e-02, -1.7332120317563506e-03, - 3.1251731285109323e-02, 1.5521676381926154e-02, - -1.2359815126908288e-01, 2.7460289856811461e-02, - 1.9114633014954776e-02, 2.8966001347205911e-03, - 4.3487864890462036e-01, -2.2957986155413699e-02, - 
-1.5357935266312277e-02, 1.0016152245695723e-02, - -4.5019081491420573e-02, -2.4405778384030734e-02, - -7.4832588748429490e-02, -4.4078616914614753e-02, - 3.0809052034342380e-02, 1.1926634983737788e-01, - -8.1517751909305367e-02, -7.7527914203627396e-02, - -3.7123430398910418e-02, 1.3750979135916276e-02, - -9.7457414231716055e-02, -1.7178991628521816e-02, - 2.1304973749867503e-01, -5.4941011823140218e-01, - -6.7860578570392335e-02, -1.2038447553380759e-01], - [-1.7561930483645249e-01, -8.8342789136092309e-02, - -1.1242590243640400e-02, -1.8652768797207359e-01, - -9.8464009205703876e-02, 1.7256713195193910e-02, - 2.9649268724224581e-01, 5.8780632678962143e-02, - -3.4585362321307522e-01, 7.6907763800451081e-03, - 2.5103268120083535e-02, 2.5393826053803564e-02, - 4.3240349879996420e-01, 3.3310696488693933e-02, - 2.1609140330890370e-02, 1.3951456173138647e-03, - -1.2840968480253712e-01, -3.3248191939129826e-02, - -8.9379099725266672e-04, -1.8994911138723630e-03, - -2.3834826680311980e-02, 4.7502947323282011e-03, - -4.4024121870114297e-01, -6.7327999197165686e-02, - 2.9359383382924452e-02, 9.1479482958182867e-02, - 3.8593300484440007e-01, -4.7958512765110956e-02, - -5.1251961259242168e-02, 1.8636628882937378e-02, - -6.5572564769060912e-02, -2.2887842635462220e-02, - -1.6042006104302377e-02, -3.3250776465128573e-01, - -6.6477273291217359e-02, -1.2463067898277495e-01], - [-1.6014771753358550e-01, -8.3434053708109190e-02, - 1.3638599925185501e-02, -2.4158649874087133e-02, - -1.1124755841847851e-01, 4.2695267715302458e-02, - 1.4866152720116035e-01, 4.9700778378845270e-04, - -3.5326388070491549e-01, -1.5745483283003094e-02, - -8.9738221678782072e-03, 1.0993364411347295e-02, - 1.9527915544397639e-01, 1.3259513825918660e-02, - -3.9339417079053149e-03, -3.7389315402467350e-02, - 3.0825337281314197e-01, 2.9465425388143118e-02, - -1.0086552608467406e-04, -2.1130010935818223e-02, - 2.4746795171351338e-02, 1.2876294127766924e-02, - -1.3542161100061775e-01, 2.3491306500478031e-02, - 2.8381089185132442e-02, 5.0060402655999779e-02, - -4.7990645387633185e-01, 1.7841388064942280e-02, - 3.6163722246352295e-02, 2.2692968040711251e-02, - -1.4881297657765719e-03, -1.1068249840362020e-02, - 4.3250260717661632e-01, 4.5393847466427317e-01, - -6.1116215809998306e-02, -1.2887688243174231e-01], - [-1.4467613023071851e-01, -8.5360329958689612e-02, - 3.6773895176301370e-02, 2.8417567832807769e-04, - -1.4251569175101705e-01, 1.8419541161364662e-02, - 1.4739729008583152e-01, -6.2901931512317516e-02, - -4.3820330673251112e-01, -1.1585923923104585e-01, - -4.6526417840431711e-02, 1.2161556905396271e-02, - -8.3388018002128958e-02, 2.3616237126461999e-02, - -9.1086898933490409e-02, 9.6073985629915787e-02, - 3.0200810799555788e-01, 9.9080289536070815e-02, - 4.9921034650103280e-02, 7.6871969202905246e-02, - -8.3377720121475072e-03, -1.7031625806123534e-02, - 4.5636496936456672e-01, -4.0005637071420394e-02, - -1.9891703100641429e-02, 1.2472945837760744e-02, - 5.9697784009368959e-03, -9.5789228620796370e-03, - 6.8806967828826657e-02, 1.5038487697273856e-01, - 6.8452882565985446e-02, 1.3123694381544091e-02, - -5.6226049096551989e-01, -4.1018946243773058e-02, - -5.6717572380307106e-02, -1.3312308588070965e-01], - [-1.2920454292785150e-01, -6.6253352907543861e-02, - -1.0164436321011842e-01, -1.4433060335444364e-01, - 1.6028176487458967e-01, 3.4584483531135940e-02, - 1.9900533500768001e-02, -5.2164178106233798e-02, - -1.2875710620386896e-01, -1.3038955529948765e-01, - -3.1311992664378889e-02, 2.5299917094429910e-02, - 
-        [... remaining rows of the 36x36 float64 eigenvector matrix elided ...]]),
-       array([... 36 float64 eigenvalues elided; first -1.8988227080038084e+03,
-              last 2.5208822708003838e+04 ...])),
-    mlir_module_text=r"""
-[... StableHLO module text elided: @main builds a symmetrized 36x36 iota
-    matrix, lower-triangularizes it via a private @tril function, invokes the
-    legacy stablehlo.custom_call @cusolver_syevd, and masks the eigenvector
-    and eigenvalue outputs with NaN wherever the returned info status is
-    nonzero ...]
-""",
-    mlir_module_serialized=b"[... serialized StableHLO module bytes elided ...]",
-    xla_call_module_version=4,
-)  # End paste
-)
-
 data_2024_09_30 = {}
 data_2024_09_30["f32"] = dict(

diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc
index 620f9cf45199..c59cc7d8076b 100644
--- a/jaxlib/gpu/gpu_kernels.cc
+++
b/jaxlib/gpu/gpu_kernels.cc @@ -45,17 +45,13 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", CsrlsvqrFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", SyevdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA", SytrdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA", GesvdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvdj_ffi", "CUDA", GesvdjFfi); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 3c76598e5285..e4d6b5d4dedf 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include @@ -67,243 +66,6 @@ nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, #endif // JAX_GPU_CUDA -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -// Returns the workspace size and a descriptor for a syevd operation. -std::pair BuildSyevdDescriptor(const dtype& dtype, bool lower, - int b, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - } - return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})}; -} - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. - -// Returns the workspace size and a descriptor for a syevj_batched operation. 
-std::pair BuildSyevjDescriptor(const dtype& dtype, bool lower, - int batch, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpuSyevjInfo_t params; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - if (batch == 1) { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - } - } else { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - } - } - return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})}; -} - -// Singular value decomposition using QR algorithm: gesvd - -// Returns the workspace size and a descriptor for a gesvd operation. 
-std::pair BuildGesvdDescriptor(const dtype& dtype, int b, int m, - int n, bool compute_uv, - bool full_matrices) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - signed char jobu, jobvt; - if (compute_uv) { - if (full_matrices) { - jobu = jobvt = 'A'; - } else { - jobu = jobvt = 'S'; - } - } else { - jobu = jobvt = 'N'; - } - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - } - return {lwork, - PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})}; -} - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -// Returns the workspace size and a descriptor for a gesvdj operation. -std::pair BuildGesvdjDescriptor(const dtype& dtype, int batch, - int m, int n, bool compute_uv, - int econ) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverEigMode_t jobz = - compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - gesvdjInfo_t params; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); - if (batch <= 1 || m > 32 || n > 32 || econ) { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - } - } else { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::C64: - 
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - } - } - return {lwork, PackDescriptor( - GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})}; -} - -#endif // JAX_GPU_CUDA - // Returns the workspace size and a descriptor for a geqrf operation. std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, int b, int n) { @@ -341,15 +103,10 @@ std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); - dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj); - dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd); dict[JAX_GPU_PREFIX "solver_sytrd"] = EncapsulateFunction(Sytrd); #ifdef JAX_GPU_CUDA dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); - dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); - #endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); @@ -372,13 +129,9 @@ nb::dict Registrations() { NB_MODULE(_solver, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_syevd_descriptor", &BuildSyevdDescriptor); - m.def("build_syevj_descriptor", &BuildSyevjDescriptor); - m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); m.def("build_sytrd_descriptor", &BuildSytrdDescriptor); #ifdef JAX_GPU_CUDA m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor); - m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); #endif // JAX_GPU_CUDA } diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 040b5a137bc6..d054e77d2102 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -15,10 +15,8 @@ limitations under the License. 
#include "jaxlib/gpu/solver_kernels.h" -#include #include #include -#include #include #include "absl/status/status.h" @@ -151,472 +149,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -static absl::Status Syevd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SyevdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - std::int64_t batch = d.batch; - int output_idx = 1; // with static shapes buffers[1] is the first output - if (d.batch == -1) { - // the batch is passed as a second operand - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - (void*)&batch, reinterpret_cast(buffers[1]), - sizeof(batch), gpuMemcpyDeviceToHost, stream))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); - output_idx = 2; - } - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[output_idx], buffers[0], - SizeOfSolverType(d.type) * batch * static_cast(d.n) * - static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - int* info = static_cast(buffers[output_idx + 2]); - void* work = buffers[output_idx + 3]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[output_idx]); - float* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[output_idx]); - double* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[output_idx]); - float* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[output_idx]); - double* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Syevd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Syevd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. 
- -absl::Status Syevj_(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SyevjDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.n) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - gpuSyevjInfo_t params; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - int* info = static_cast(buffers[3]); - void* work = buffers[4]; - if (d.batch == 1) { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - } - } else { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), - d.lwork, info, params, d.batch))); - break; - } - } - } - return absl::OkStatus(); -} - -void Syevj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Syevj_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Singular value decomposition using 
QR algorithm: gesvd - -static absl::Status Gesvd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GesvdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - int* info = static_cast(buffers[5]); - void* work = buffers[6]; - int64_t k = d.jobu == 'A' ? d.m : d.n; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Gesvd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Gesvd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -static absl::Status Gesvdj_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GesvdjDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& 
handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - int* info = static_cast(buffers[5]); - void* work = buffers[6]; - gesvdjInfo_t params; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); - if (d.batch <= 1 || d.m > 32 || d.n > 32 || d.econ) { - int k = std::min(d.m, d.n); - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? 
k : d.n) * d.n; - ++info; - } - break; - } - } - } else { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, - d.batch))); - break; - } - } - } - return absl::OkStatus(); -} - -void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Gesvdj_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#endif // JAX_GPU_CUDA - // sytrd/hetrd: symmetric (Hermitian) tridiagonal reduction static absl::Status Sytrd_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index a68aaf1ca233..c325e746b709 100644 --- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -48,59 +48,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -struct SyevdDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n; // batch may be -1 in which case it is passed as operand. - int lwork; -}; - -void Syevd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. 
- -struct SyevjDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n; - int lwork; -}; - -void Syevj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Singular value decomposition using QR algorithm: gesvd - -struct GesvdDescriptor { - SolverType type; - int batch, m, n; - int lwork; - signed char jobu, jobvt; -}; - -void Gesvd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -struct GesvdjDescriptor { - SolverType type; - int batch, m, n; - int lwork; - gpusolverEigMode_t jobz; - int econ; -}; - -void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); -#endif // JAX_GPU_CUDA - // sytrd/hetrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. struct SytrdDescriptor { SolverType type; diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 937e4165a159..7082b212a5e3 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -320,38 +320,6 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{variant}", - dtype_name=dtype_name, variant=variant) - for dtype_name in ("f32", "f64") - # We use different custom calls for sizes <= 32 - for variant in ["syevj", "syevd"]) - def test_gpu_eigh_solver_syev_legacy(self, dtype_name="f32", variant="syevj"): - if not config.enable_x64.value and dtype_name == "f64": - self.skipTest("Test disabled for x32 mode") - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - if _is_required_cusolver_version_satisfied(11600): - # The underlying problem is that this test assumes the workspace size can be - # queried from an older version of cuSOLVER and then be used in a newer one. - self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized") - data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - # For lax.linalg.eigh - dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] - size = dict(syevj=8, syevd=36)[variant] - rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] - atol = dict(f32=1e-2, f64=1e-10)[dtype_name] - operand = CompatTest.eigh_input((size, size), dtype) - func = lambda: CompatTest.eigh_harness((size, size), dtype) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand), - expect_current_custom_calls=[f"{prefix}solver_syevd_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) From 9fa6294c8a3b51d8a416a3e227be602f8623a7fd Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 29 Apr 2025 06:57:00 -0700 Subject: [PATCH 0895/1769] Support batch_axis being int rather than Sequence[int] in initializers. 
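
The fan-in/fan-out computation already handled a plain int at runtime, so this
appears to be an annotation-only change (the diff below touches only the
signatures). A short usage sketch; the shapes are illustrative and he_normal
stands in for any of the updated initializers:

    import jax
    import jax.numpy as jnp
    from jax.nn import initializers

    # batch_axis=0 marks the leading axis as a batch dimension for the
    # fan-in/fan-out computation; before this change the annotation
    # required a sequence such as (0,).
    init = initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=0)
    w = init(jax.random.key(0), (8, 128, 256), jnp.float32)
    print(w.shape)  # (8, 128, 256)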
PiperOrigin-RevId: 752717970 --- jax/_src/nn/initializers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 287e8f039e1d..6f117eef749f 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -267,7 +267,7 @@ def variance_scaling( Literal["uniform"]), in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_ ) -> Initializer: r""" @@ -352,7 +352,7 @@ def init(key: Array, @export def glorot_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot uniform initializer (aka Xavier uniform initializer). @@ -390,7 +390,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2, @export def glorot_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot normal initializer (aka Xavier normal initializer). @@ -428,7 +428,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2, @export def lecun_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun uniform initializer. @@ -464,7 +464,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2, @export def lecun_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun normal initializer. @@ -500,7 +500,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2, @export def he_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He uniform initializer (aka Kaiming uniform initializer). @@ -538,7 +538,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2, @export def he_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He normal initializer (aka Kaiming normal initializer). From 91485b1be514de015c4404dc102ae1f61d1bd5aa Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Apr 2025 07:53:48 -0700 Subject: [PATCH 0896/1769] [JAX] Remove jaxlib/xla directory. These were temporary forwarding shims, now no longer needed. 
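
Each deleted header was a one-line forwarder; for example, the
jaxlib/xla/nb_class_ptr.h shim removed below contained only
'#include "jaxlib/nb_class_ptr.h"'. Any straggler still using the old
paths migrates by dropping the xla/ path segment, e.g. (assuming the
same pattern held for every shim):

    -#include "jaxlib/xla/py_client.h"
    +#include "jaxlib/py_client.h"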
PiperOrigin-RevId: 752733428 --- jaxlib/xla/BUILD | 123 -------------------------------- jaxlib/xla/nb_class_ptr.h | 21 ------ jaxlib/xla/py_array.h | 21 ------ jaxlib/xla/py_client.h | 21 ------ jaxlib/xla/py_device.h | 21 ------ jaxlib/xla/py_device_list.h | 21 ------ jaxlib/xla/py_executable.h | 21 ------ jaxlib/xla/python_ref_manager.h | 21 ------ jaxlib/xla/pytree.h | 21 ------ jaxlib/xla/sharding.h | 21 ------ jaxlib/xla/traceback.h | 21 ------ jaxlib/xla/xla_client.py | 21 ------ jaxlib/xla/xla_extension.py | 20 ------ 13 files changed, 374 deletions(-) delete mode 100644 jaxlib/xla/BUILD delete mode 100644 jaxlib/xla/nb_class_ptr.h delete mode 100644 jaxlib/xla/py_array.h delete mode 100644 jaxlib/xla/py_client.h delete mode 100644 jaxlib/xla/py_device.h delete mode 100644 jaxlib/xla/py_device_list.h delete mode 100644 jaxlib/xla/py_executable.h delete mode 100644 jaxlib/xla/python_ref_manager.h delete mode 100644 jaxlib/xla/pytree.h delete mode 100644 jaxlib/xla/sharding.h delete mode 100644 jaxlib/xla/traceback.h delete mode 100644 jaxlib/xla/xla_client.py delete mode 100644 jaxlib/xla/xla_extension.py diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD deleted file mode 100644 index e9fb7e791574..000000000000 --- a/jaxlib/xla/BUILD +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load( - "//jaxlib:jax.bzl", - "jax_visibility", - "pytype_library", -) - -licenses(["notice"]) - -package( - default_applicable_licenses = [], - default_visibility = ["//jax:internal"], -) - -package_group( - name = "xla_python", - includes = [ - "//jax:internal", - ], - packages = ["@xla//xla/python/..."], -) - -cc_library( - name = "py_client", - hdrs = [ - "py_array.h", - "py_client.h", - "py_device.h", - "py_device_list.h", - "py_executable.h", - "sharding.h", - ], - copts = ["-fexceptions"], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/py_client"), - deps = [ - "//jaxlib:py_client", - ], -) - -cc_library( - name = "pytree", - hdrs = [ - "pytree.h", - ], - copts = ["-fexceptions"], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/pytree"), - deps = [ - "//jaxlib:pytree", - ], -) - -cc_library( - name = "nb_class_ptr", - hdrs = [ - "nb_class_ptr.h", - ], - copts = ["-fexceptions"], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/nb_class_ptr"), - deps = [ - "//jaxlib:nb_class_ptr", - ], -) - -cc_library( - name = "traceback", - hdrs = [ - "traceback.h", - ], - copts = ["-fexceptions"], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/traceback"), - deps = [ - "//jaxlib:traceback", - ], -) - -cc_library( - name = "python_ref_manager", - hdrs = [ - "python_ref_manager.h", - ], - copts = ["-fexceptions"], - features = ["-use_header_modules"], - visibility = jax_visibility("jaxlib/python_ref_manager"), - deps = [ - "//jaxlib:python_ref_manager", - ], -) - -pytype_library( - name = "xla_client", - srcs = ["xla_client.py"], - visibility = [":xla_python"], - deps = [ - ":xla_extension", - "//jaxlib:xla_client", - ], -) - -pytype_library( - name = "xla_extension", - srcs = ["xla_extension.py"], - visibility = [":xla_python"], - deps = [ - "//jaxlib:_jax", - ], -) diff --git a/jaxlib/xla/nb_class_ptr.h b/jaxlib/xla/nb_class_ptr.h deleted file mode 100644 index 0b539115a1cb..000000000000 --- a/jaxlib/xla/nb_class_ptr.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_NB_CLASS_PTR_H_ -#define JAXLIB_XLA_NB_CLASS_PTR_H_ - -#include "jaxlib/nb_class_ptr.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_NB_CLASS_PTR_H_ diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h deleted file mode 100644 index fee6d2e24f16..000000000000 --- a/jaxlib/xla/py_array.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_PY_ARRAY_H_ -#define JAXLIB_XLA_PY_ARRAY_H_ - -#include "jaxlib/py_array.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_PY_ARRAY_H_ diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h deleted file mode 100644 index b7e90fe5e24c..000000000000 --- a/jaxlib/xla/py_client.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_PY_CLIENT_H_ -#define JAXLIB_XLA_PY_CLIENT_H_ - -#include "jaxlib/py_client.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_PY_CLIENT_H_ diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h deleted file mode 100644 index 2b3beff864ae..000000000000 --- a/jaxlib/xla/py_device.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_PY_DEVICE_H_ -#define JAXLIB_XLA_PY_DEVICE_H_ - -#include "jaxlib/py_device.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_PY_DEVICE_H_ diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h deleted file mode 100644 index 1b75286c3d3d..000000000000 --- a/jaxlib/xla/py_device_list.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef JAXLIB_XLA_PY_DEVICE_LIST_H_ -#define JAXLIB_XLA_PY_DEVICE_LIST_H_ - -#include "jaxlib/py_device_list.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_PY_DEVICE_LIST_H_ diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h deleted file mode 100644 index 5cc0f2d6ac6c..000000000000 --- a/jaxlib/xla/py_executable.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_PY_EXECUTABLE_H_ -#define JAXLIB_XLA_PY_EXECUTABLE_H_ - -#include "jaxlib/py_executable.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_PY_EXECUTABLE_H_ diff --git a/jaxlib/xla/python_ref_manager.h b/jaxlib/xla/python_ref_manager.h deleted file mode 100644 index 09f995c198e2..000000000000 --- a/jaxlib/xla/python_ref_manager.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_PYTHON_REF_MANAGER_H_ -#define JAXLIB_XLA_PYTHON_REF_MANAGER_H_ - -#include "jaxlib/python_ref_manager.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h deleted file mode 100644 index dcb7089674f5..000000000000 --- a/jaxlib/xla/pytree.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef JAXLIB_XLA_PYTREE_H_ -#define JAXLIB_XLA_PYTREE_H_ - -#include "jaxlib/pytree.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_PYTREE_H_ diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h deleted file mode 100644 index f47dd265651a..000000000000 --- a/jaxlib/xla/sharding.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_SHARDING_H_ -#define JAXLIB_XLA_SHARDING_H_ - -#include "jaxlib/sharding.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_SHARDING_H_ diff --git a/jaxlib/xla/traceback.h b/jaxlib/xla/traceback.h deleted file mode 100644 index bb993233850a..000000000000 --- a/jaxlib/xla/traceback.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2025 The JAX Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_XLA_TRACEBACK_H_ -#define JAXLIB_XLA_TRACEBACK_H_ - -#include "jaxlib/traceback.h" // IWYU pragma: keep - -#endif // JAXLIB_XLA_TRACEBACK_H_ diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py deleted file mode 100644 index 4eb4a2d7939f..000000000000 --- a/jaxlib/xla/xla_client.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2017 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An XLA client in Python.""" - -# ruff: noqa: F401 -# ruff: noqa: F403 - -from jaxlib.xla_client import * # pylint: disable=wildcard-import -from jaxlib.xla_client import _xla # pylint: disable=unused-import diff --git a/jaxlib/xla/xla_extension.py b/jaxlib/xla/xla_extension.py deleted file mode 100644 index 798919c01450..000000000000 --- a/jaxlib/xla/xla_extension.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2025 The JAX Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An XLA client in Python.""" - -# ruff: noqa: F401 -# ruff: noqa: F403 - -from jaxlib._jax import * # pylint: disable=wildcard-import From 97ff21e7fa884f31b051203e4eec9ca8dc81587f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Apr 2025 15:56:24 +0000 Subject: [PATCH 0897/1769] Fix test failure in autodidax due to compile() API change. --- docs/autodidax.ipynb | 4 +++- docs/autodidax.md | 4 +++- docs/autodidax.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index b6f12b624f8b..93710e964bea 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2019,7 +2019,9 @@ "\n", " output = io.StringIO()\n", " c.module.operation.print(file=output)\n", - " compiled = xb.get_backend(None).compile(output.getvalue())\n", + " backend = xb.get_backend(None)\n", + " compiled = backend.compile(\n", + " output.getvalue(), backend.devices()[:1])\n", " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", "\n", "def _mlir_dtype(dtype: np.dtype) -> ir.Type:\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 1c375e21227c..ab375aefca27 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1589,7 +1589,9 @@ def xla_callable(hashable_jaxpr: IDHashable, output = io.StringIO() c.module.operation.print(file=output) - compiled = xb.get_backend(None).compile(output.getvalue()) + backend = xb.get_backend(None) + compiled = backend.compile( + output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: diff --git a/docs/autodidax.py b/docs/autodidax.py index 6329234224cb..3390eb286073 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1581,7 +1581,9 @@ def main(*params): output = io.StringIO() c.module.operation.print(file=output) - compiled = xb.get_backend(None).compile(output.getvalue()) + backend = xb.get_backend(None) + compiled = backend.compile( + output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: From c3cad1d9fc197e3b0bcae88f38a9eafe92c51521 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Apr 2025 09:20:38 -0700 Subject: [PATCH 0898/1769] PR #28391: Add version guard in autodidax test. Imported from GitHub PR https://github.com/jax-ml/jax/pull/28391 Copybara import of the project: -- 8f85229ba059f301bfb0640ba9a304dab7a345ca by Peter Hawkins : Add version guard in autodidax test. 
Merging this change closes #28391 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/28391 from hawkinsp:add 8f85229ba059f301bfb0640ba9a304dab7a345ca PiperOrigin-RevId: 752764010 --- docs/autodidax.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/autodidax.py b/docs/autodidax.py index 3390eb286073..9531ef7694c5 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1548,6 +1548,7 @@ def __eq__(self, other): from jax.extend.mlir import ir from jax.extend.mlir.dialects import func from jax.extend.mlir.dialects import stablehlo as hlo +import jax._src.lib from jax._src import xla_bridge as xb class MlirContext(NamedTuple): @@ -1582,8 +1583,11 @@ def main(*params): output = io.StringIO() c.module.operation.print(file=output) backend = xb.get_backend(None) - compiled = backend.compile( - output.getvalue(), backend.devices()[:1]) + if jax._src.lib.version >= (0, 6, 1): + compiled = backend.compile( + output.getvalue(), backend.devices()[:1]) + else: + compiled = backend.compile(output.getvalue()) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: From 2da0c0aee3d5049e8970832c11d5e16e84d26a9d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Apr 2025 11:33:15 -0700 Subject: [PATCH 0899/1769] Update autodidax.ipynb and autodidax.md PiperOrigin-RevId: 752816808 --- docs/autodidax.ipynb | 8 ++++++-- docs/autodidax.md | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 93710e964bea..07c7d7e84ff0 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -1986,6 +1986,7 @@ "from jax.extend.mlir import ir\n", "from jax.extend.mlir.dialects import func\n", "from jax.extend.mlir.dialects import stablehlo as hlo\n", + "import jax._src.lib\n", "from jax._src import xla_bridge as xb\n", "\n", "class MlirContext(NamedTuple):\n", @@ -2020,8 +2021,11 @@ " output = io.StringIO()\n", " c.module.operation.print(file=output)\n", " backend = xb.get_backend(None)\n", - " compiled = backend.compile(\n", - " output.getvalue(), backend.devices()[:1])\n", + " if jax._src.lib.version >= (0, 6, 1):\n", + " compiled = backend.compile(\n", + " output.getvalue(), backend.devices()[:1])\n", + " else:\n", + " compiled = backend.compile(output.getvalue())\n", " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", "\n", "def _mlir_dtype(dtype: np.dtype) -> ir.Type:\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index ab375aefca27..e78aeded41c0 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1556,6 +1556,7 @@ import io from jax.extend.mlir import ir from jax.extend.mlir.dialects import func from jax.extend.mlir.dialects import stablehlo as hlo +import jax._src.lib from jax._src import xla_bridge as xb class MlirContext(NamedTuple): @@ -1590,8 +1591,11 @@ def xla_callable(hashable_jaxpr: IDHashable, output = io.StringIO() c.module.operation.print(file=output) backend = xb.get_backend(None) - compiled = backend.compile( - output.getvalue(), backend.devices()[:1]) + if jax._src.lib.version >= (0, 6, 1): + compiled = backend.compile( + output.getvalue(), backend.devices()[:1]) + else: + compiled = backend.compile(output.getvalue()) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: From 71bf484cf78894d29b233a4b6598483f54715c84 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Apr 2025 12:05:53 -0700 Subject: [PATCH 
0900/1769] Relax test tolerance for PureCallbackTest.test_can_take_grad_of_pure_callback_with_custom_jvp. Fixes a CI failure on TPU v6e: https://github.com/jax-ml/jax/actions/runs/14705775316/job/41265793158 ``` _____ PureCallbackTest.test_can_take_grad_of_pure_callback_with_custom_jvp _____ [gw2] linux -- Python 3.11.12 /usr/bin/python3.11 tests/python_callback_test.py:836: in test_can_take_grad_of_pure_callback_with_custom_jvp np.testing.assert_allclose(out, jnp.cos(2.)) /usr/lib/python3.11/contextlib.py:81: in inner return func(*args, **kwds) E AssertionError: E Not equal to tolerance rtol=1e-07, atol=0 E E Mismatched elements: 1 / 1 (100%) E Max absolute difference among violations: 5.9604645e-08 E Max relative difference among violations: 1.4322983e-07 E ACTUAL: array(-0.4161468, dtype=float32) E DESIRED: array(-0.41614687, dtype=float32) ``` PiperOrigin-RevId: 752830528 --- tests/python_callback_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 13df0c1dd376..17087be35f8b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -833,7 +833,7 @@ def sin_jvp(xs, ts): def f(x): return sin(x) out = f(2.) - np.testing.assert_allclose(out, jnp.cos(2.)) + np.testing.assert_allclose(out, jnp.cos(2.), atol=1e-7) def test_callback_inside_of_cond(self): From 2061be4bafa38137916f3487bf26c3de9a40bbef Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Apr 2025 12:45:19 -0700 Subject: [PATCH 0901/1769] Set PYTHONWARNINGS=error for jax_multiplatform_test. This is already the default for most of our other test environments, but apparently missed here. PiperOrigin-RevId: 752844454 --- jaxlib/jax.bzl | 6 ++++-- tests/sparse_bcoo_bcsr_test.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 632ff2047078..075bf3ed2de0 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -252,6 +252,9 @@ def jax_multiplatform_test( else: fail("Must set a main file to test multiple source files.") + env = dict(env) + env.setdefault("PYTHONWARNINGS", "error") + for backend in ALL_BACKENDS: if shard_count == None or type(shard_count) == type(0): test_shards = shard_count @@ -565,8 +568,7 @@ def jax_py_test( env = {}, **kwargs): env = dict(env) - if "PYTHONWARNINGS" not in env: - env["PYTHONWARNINGS"] = "error" + env.setdefault("PYTHONWARNINGS", "error") deps = kwargs.get("deps", []) test_deps = _get_test_deps(deps, backend_independent = True) kwargs["deps"] = test_deps diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 1224717570d1..feac4882c9a3 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -974,6 +974,7 @@ def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): self.assertEqual(out.nse, expected_nse) @jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available") + @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) def test_bcoo_spdot_general_ad_bug(self): # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) From 475c4764a1267f50b408e41d7e09db5bdbe99c02 Mon Sep 17 00:00:00 2001 From: Richard Levasseur Date: Tue, 29 Apr 2025 20:53:16 +0000 Subject: [PATCH 0902/1769] chore: switch to pypi hub instead of underlying repos --- BUILD.bazel | 6 +++--- jaxlib/jax.bzl | 44 ++++++++++++++++++++-------------------- jaxlib/tools/BUILD.bazel | 36 
++++++++++++++++---------------- 3 files changed, 43 insertions(+), 43 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 906ae83796b8..ddf59c0290d7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -71,9 +71,9 @@ py_binary( srcs = ["build_wheel.py"], deps = [ "//jaxlib/tools:build_utils", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 632ff2047078..60478ddeaeaf 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -67,21 +67,21 @@ PLATFORM_TAGS_DICT = { _GPU_PYPI_WHEEL_DEPS = [ "//:jax_wheel_with_internal_test_util", - "@pypi_jaxlib//:pkg", - "@pypi_jax_cuda12_plugin//:pkg", - "@pypi_jax_cuda12_pjrt//:pkg", + "@pypi//jaxlib", + "@pypi//jax_cuda12_plugin", + "@pypi//jax_cuda12_pjrt", ] _CPU_PYPI_WHEEL_DEPS = [ "//:jax_wheel_with_internal_test_util", - "@pypi_jaxlib//:pkg", + "@pypi//jaxlib", ] # TODO(vam): remove this once zstandard builds against Python >3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"): return [] - return ["@pypi_zstandard//:pkg"] + return ["@pypi//zstandard"] def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]): if HERMETIC_PYTHON_VERSION in excluded_py_versions: @@ -89,26 +89,26 @@ def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]): return [package] _py_deps = { - "absl/logging": ["@pypi_absl_py//:pkg"], - "absl/testing": ["@pypi_absl_py//:pkg"], - "absl/flags": ["@pypi_absl_py//:pkg"], - "cloudpickle": get_optional_dep("@pypi_cloudpickle//:pkg"), - "colorama": get_optional_dep("@pypi_colorama//:pkg"), - "epath": get_optional_dep("@pypi_etils//:pkg"), # etils.epath - "filelock": get_optional_dep("@pypi_filelock//:pkg"), - "flatbuffers": ["@pypi_flatbuffers//:pkg"], - "hypothesis": ["@pypi_hypothesis//:pkg"], + "absl/logging": ["@pypi//absl_py"], + "absl/testing": ["@pypi//absl_py"], + "absl/flags": ["@pypi//absl_py"], + "cloudpickle": get_optional_dep("@pypi//cloudpickle"), + "colorama": get_optional_dep("@pypi//colorama"), + "epath": get_optional_dep("@pypi//etils"), # etils.epath + "filelock": get_optional_dep("@pypi//filelock"), + "flatbuffers": ["@pypi//flatbuffers"], + "hypothesis": ["@pypi//hypothesis"], "magma": [], - "matplotlib": get_optional_dep("@pypi_matplotlib//:pkg"), + "matplotlib": get_optional_dep("@pypi//matplotlib"), "mpmath": [], - "opt_einsum": ["@pypi_opt_einsum//:pkg"], - "pil": get_optional_dep("@pypi_pillow//:pkg"), - "portpicker": get_optional_dep("@pypi_portpicker//:pkg"), - "ml_dtypes": ["@pypi_ml_dtypes//:pkg"], - "numpy": ["@pypi_numpy//:pkg"], - "scipy": ["@pypi_scipy//:pkg"], + "opt_einsum": ["@pypi//opt_einsum"], + "pil": get_optional_dep("@pypi//pillow"), + "portpicker": get_optional_dep("@pypi//portpicker"), + "ml_dtypes": ["@pypi//ml_dtypes"], + "numpy": ["@pypi//numpy"], + "scipy": ["@pypi//scipy"], "tensorflow_core": [], - "tensorstore": get_optional_dep("@pypi_tensorstore//:pkg"), + "tensorstore": get_optional_dep("@pypi//tensorstore"), "torch": [], "zstandard": get_zstandard(), } diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 78db77b8521d..219096836ffc 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -75,9 +75,9 @@ py_binary( deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) @@ -133,9 +133,9 @@ py_binary( 
deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) @@ -162,9 +162,9 @@ py_binary( deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) @@ -250,9 +250,9 @@ py_binary( deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) @@ -309,9 +309,9 @@ py_binary( deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) @@ -399,9 +399,9 @@ py_binary( deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) From 58b4bca9d85d285cc5ff3cad5799f5b867956782 Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Tue, 29 Apr 2025 21:32:24 +0000 Subject: [PATCH 0903/1769] init --- jax/_src/cudnn/fused_attention_stablehlo.py | 6 ++- tests/fused_attention_stablehlo_test.py | 45 +++++++++++++++++---- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 463665f6fdca..39d97255fa4a 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -407,7 +407,7 @@ def _dot_product_attention_fwd( variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=False or return_residual) if return_residual: - return outputs + return tuple(outputs) else: return outputs[0] @@ -427,7 +427,7 @@ def _dot_product_attention_fwd_rule( res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, outputs[1], outputs[0]) if return_residual: - return outputs, res + return tuple(outputs), res else: return outputs[0], res @@ -436,6 +436,8 @@ def _dot_product_attention_bwd_rule( sliding_window_length, is_training, return_residual, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, activation, fwd_output) = res + if return_residual: + grad_output = grad_output[0] grads = _dot_product_attention_bwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, activation, fwd_output, grad_output, scale=scale, seed=seed, diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 084ea3b3b0ae..64e0f4377462 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -741,7 +741,7 @@ def generate_segment_mask(segment_ids, dtype): @jtu.run_on_devices("cuda") def test_sdpa_residual(self): - k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) query = jax.random.normal( k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) key = jax.random.normal( @@ -750,14 +750,43 @@ def test_sdpa_residual(self): k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) grad = jax.random.normal( k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad_stat = jax.random.normal( + k5, (4, 4, 1024), dtype=jnp.float32) - jitted_sdpa_inference = jax.jit( - 
partial(
-        dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK,
-        dropout_rate=0, return_residual=True),
-    )
-    outs = jitted_sdpa_inference(query, key, value)
-    assert len(outs) == 2
+    devices = np.array(jax.local_devices()[:2])
+    with Mesh(devices, ("dp")) as mesh:
+      qkv_spec = PartitionSpec("dp", None, None, None)
+      stat_spec = PartitionSpec("dp", None, None)
+      qkv_sharding = NamedSharding(mesh, qkv_spec)
+      stat_sharding = NamedSharding(mesh, stat_spec)
+
+      query = jax.device_put(query, qkv_sharding)
+      key = jax.device_put(key, qkv_sharding)
+      value = jax.device_put(value, qkv_sharding)
+      grad = jax.device_put(grad, qkv_sharding)
+      grad_stat = jax.device_put(grad_stat, stat_sharding)
+
+      jitted_sdpa_inference = jax.jit(
+        partial(
+          dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK,
+          dropout_rate=0, return_residual=True),
+        in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding),
+        out_shardings=(qkv_sharding, stat_sharding)
+      )
+
+      outs = jitted_sdpa_inference(query, key, value)
+      assert len(outs) == 2
+
+      def train(query, key, value, grads):
+        outs, grad_fn = jax.vjp(partial(
+          dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK,
+          dropout_rate=0, return_residual=True), query, key, value)
+        return outs, grad_fn(grads)
+      jitted_sdpa_train = jax.jit(train,
+        in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, (qkv_sharding, stat_sharding)),
+        out_shardings=((qkv_sharding, stat_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)))
+      outs = jitted_sdpa_train(query, key, value, (grad, grad_stat))
+      assert len(outs) == 2
 
   @jtu.run_on_devices("cuda")
   def test_layouts(self):
 

From 7772acf44d47723161c3c53eb0f552cfacb01d80 Mon Sep 17 00:00:00 2001
From: Ionel Gog
Date: Tue, 29 Apr 2025 15:55:28 -0700
Subject: [PATCH 0904/1769] [JAX] Modify `BatchedCopyToDeviceWithSharding` to
 use dst shardings when devices and memory kinds are the same.
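
A small sketch of what this enables (assumes at least two devices; it
mirrors the new test added below):

```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ("x", "y"))
x = jax.device_put(np.arange(16).reshape(4, 4),
                   NamedSharding(mesh, P("x", None)))
# Same devices and memory kind; the destination sharding differs only by
# an extra size-1 mesh axis, so the buffer is reused instead of copied.
big_mesh = Mesh(np.array(jax.devices()[:2]).reshape(1, 2, 1),
                ("replicas", "x", "y"))
y = jax.device_put(x, NamedSharding(big_mesh, P("x", None)), may_alias=True)
assert y.sharding == NamedSharding(big_mesh, P("x", None))
```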
+ results[i] = + PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_sharding, py_array.py_client(), traceback, + tsl::FormRef(ifrt_array_ptr), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } continue; } @@ -1212,7 +1226,6 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( } } - auto traceback = Traceback::Get(); for (auto& [i, ifrt_array] : ifrt_arrays) { const auto& py_array = py_arrays[i]; absl::Span shape_span = py_array.shape(); diff --git a/tests/api_test.py b/tests/api_test.py index 8a6a022a5677..1149cf78e626 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -511,6 +511,27 @@ def test_device_put_aliasing(self): may_alias=False, donate=False) self.assertNotEqual(id(arr), id(out)) + def test_device_put_aliasing_with_diff_compatible_sharding(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + mesh = jax.sharding.Mesh( + np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y") + ) + x = jax.device_put( + np.arange(16).reshape((4, 4)), + jax.NamedSharding(mesh, P("x", None)), + ) + expanded_mesh = jax.sharding.Mesh( + np.array(jax.devices()[:2]).reshape((1, 2, 1)), ("replicas", "x", "y") + ) + dst_sharding = jax.NamedSharding(expanded_mesh, P("x", None)) + # No transfer should happen because the array is aliased to compatible + # sharding that only has a mesh with an additional dimension of size 1. + with jax.transfer_guard_device_to_device("disallow_explicit"): + res = jax.device_put(x, dst_sharding, may_alias=True) + self.assertEqual(dst_sharding, res.sharding) + @parameterized.named_parameters( ("argnums", "donate_argnums", 0), ("argnames", "donate_argnames", 'x'), From e403f8fcb4881786e84d24e30123397878ee64a5 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Tue, 29 Apr 2025 16:01:52 -0700 Subject: [PATCH 0905/1769] Use 32-bit launch ids in `PyLoadedExecutable` `xla::ExecuteOptions::launch_id` is `int32_t`, so it makes sense to align `PyLoadedExecutable` with it. Since signed int overflow is UB in C++, the atomic that keeps track of the next launch id is defined as unsigned (so that overflow does not trigger UB), and then a cast is inserted on a newly minted launch id. The current implementation may trigger UB because `options.launch_id = GetNextLaunchId()` implicitly casts `int64_t` to `int32_t`, causing signed int overflow. PiperOrigin-RevId: 752916687 --- jaxlib/py_executable.cc | 6 ++++-- jaxlib/py_executable.h | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index 16cd512a4007..9141111ae168 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -416,8 +417,9 @@ std::optional> PyLoadedExecutable::GetOutputShardings() return ifrt_loaded_executable_->GetOutputShardings(); } -int64_t PyLoadedExecutable::GetNextLaunchId() { - return next_launch_id_.fetch_add(1, std::memory_order_relaxed); +int32_t PyLoadedExecutable::GetNextLaunchId() { + return absl::bit_cast( + next_launch_id_.fetch_add(1, std::memory_order_relaxed)); } void PyLoadedExecutable::KeepAlive(nb::object obj) { diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h index b6a39c6968b8..5c7f57301b82 100644 --- a/jaxlib/py_executable.h +++ b/jaxlib/py_executable.h @@ -215,7 +215,7 @@ class PyLoadedExecutable { const ifrt::ExecuteOptions& options() const { return options_; } // Returns a unique launch ID to use for the next execution. - int64_t GetNextLaunchId(); + int32_t GetNextLaunchId(); const std::optional& fingerprint() const { return fingerprint_; } @@ -235,7 +235,7 @@ class PyLoadedExecutable { std::optional fingerprint_; // Launch ID to use for the next execution. - std::atomic next_launch_id_; + std::atomic next_launch_id_; // The options to pass to `executable_.Execute`. ifrt::ExecuteOptions options_; From 7cde3774153df9ea159b3294e9c6d47c6189e65e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Apr 2025 17:10:18 -0700 Subject: [PATCH 0906/1769] Round-robin pytest tests across GPUs. PiperOrigin-RevId: 752938886 --- conftest.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/conftest.py b/conftest.py index fed4564bbc1c..fa0e6de94346 100644 --- a/conftest.py +++ b/conftest.py @@ -21,6 +21,7 @@ def add_imports(doctest_namespace): import jax import numpy + doctest_namespace["jax"] = jax doctest_namespace["lax"] = jax.lax doctest_namespace["jnp"] = jax.numpy @@ -29,8 +30,8 @@ def add_imports(doctest_namespace): # A pytest hook that runs immediately before test collection (i.e. when pytest # loads all the test cases to run). When running parallel tests via xdist on -# Cloud TPU, we use this hook to set the env vars needed to run multiple test -# processes across different TPU chips. +# GPU or Cloud TPU, we use this hook to set the env vars needed to run multiple +# test processes across different chips. # # It's important that the hook runs before test collection, since jax tests end # up initializing the TPU runtime on import (e.g. to query supported test @@ -43,17 +44,31 @@ def add_imports(doctest_namespace): # https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result # for details. # -# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an +# For TPU, the env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an # effect. We do this to minimize any effect on non-TPU tests, and as a pointer # in test code to this "magic" hook. TPU tests should not specify more xdist # workers than the number of TPU chips. +# +# For GPU, the env var JAX_ENABLE_CUDA_XDIST must be set equal to the number of +# CUDA devices. Test processes will be assigned in round robin fashion across +# the devices. 
def pytest_collection() -> None: - if not os.environ.get("JAX_ENABLE_TPU_XDIST", None): - return - # When running as an xdist worker, will be something like "gw0" - xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") - if not xdist_worker_name.startswith("gw"): - return - xdist_worker_number = int(xdist_worker_name[len("gw"):]) - os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) - os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") + if os.environ.get("JAX_ENABLE_TPU_XDIST", None): + # When running as an xdist worker, will be something like "gw0" + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) + os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") + + elif num_cuda_devices := os.environ.get("JAX_ENABLE_CUDA_XDIST", None): + num_cuda_devices = int(num_cuda_devices) + # When running as an xdist worker, will be something like "gw0" + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault( + "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) + ) From 806190d05ff50c292e342c83b79d6b77002f025e Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 29 Apr 2025 19:51:05 -0700 Subject: [PATCH 0907/1769] [Pallas TPU] Add better support for BoundedSlice and other BlockDim types to emit_pipeline and fusions PiperOrigin-RevId: 752979105 --- jax/_src/pallas/fuser/block_spec.py | 179 +++++++++++++++++++++------- jax/_src/pallas/mosaic/pipeline.py | 166 +++++++++++++++++--------- 2 files changed, 245 insertions(+), 100 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index d8d01005622a..2795b1e52f6a 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -34,6 +34,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core from jax._src.pallas.fuser import fuser_utils +from jax._src.state import indexing import jax.numpy as jnp import numpy as np @@ -95,9 +96,17 @@ def wrapped(*args): def _block_size(dim: pallas_core.Element | int | None) -> int | None: - if isinstance(dim, pallas_core.Element): - return dim.block_size - return dim + match dim: + case ( + pallas_core.Element() + | pallas_core.BoundedSlice() + | pallas_core.Blocked() + ): + return dim.block_size + case pallas_core.Squeezed() | None: + return None + case _: + return dim # pytype: disable=bad-return-type @dataclasses.dataclass @@ -176,10 +185,13 @@ def get_out_block_indices(self): _illegal = object() + class _SpEnv(threading.local): + def __init__(self): self.scalar_prefetch = None + _sp_env = _SpEnv() @@ -427,7 +439,7 @@ def make_kernel_function( bs_env, scalar_prefetch_fn_env = block_spec_env def _remove_nones( - shape: tuple[pallas_core.Element | int | None, ...] | None + shape: tuple[pallas_core.BlockDim | int | None, ...] 
| None, ) -> tuple[int, ...]: assert shape is not None new_shape = tuple(_block_size(s) for s in shape) @@ -860,10 +872,75 @@ def new_index_map(*args): def _slice_eval_rule(ctx, x, **params): del params out_block_shape = ctx.out_block_specs[0].block_shape - assert len(x.shape) == sum(1 for bs in out_block_shape if bs is not None) + assert len(x.shape) == sum( + 1 + for bs in out_block_shape + if not (bs is None or isinstance(bs, pallas_core.Squeezed)) + ) return x +def _offset_indexer( + bs: pallas_core.BlockDim | int | None, + indexer, + slice_start, + slice_size, +): + # Short-circuit if the slice start is just at zero. + print('BS', bs, indexer, slice_start, slice_size) + if isinstance(slice_start, int) and slice_start == 0: + return indexer + match bs: + case None | pallas_core.Squeezed(): + return indexer + slice_start + case pallas_core.Element(block_size): + _maybe_static_check( + slice_start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + slice_size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + return indexer + slice_start + case int() | pallas_core.Blocked(): + block_size = _block_size(bs) + _maybe_static_check( + slice_start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + slice_size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + # indexer is a block index so we need to offset it by the block offset. + return indexer + slice_start // block_size + case pallas_core.BoundedSlice(block_size): + assert isinstance(indexer, indexing.Slice) + _maybe_static_check( + indexer.start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + indexer.size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + return indexing.ds(indexer.start + slice_start, indexer.size) + case _: + raise ValueError(f'Unsupported block size {bs}') + + +def _maybe_static_check(pred: bool, msg: str): + # Tries to emit a static error if possible, otherwise falls back to runtime. + from jax.experimental import checkify + + if isinstance(pred, jax.Array): + checkify.check(pred, msg, debug=True) + else: + if not pred: + raise ValueError(msg) + + @register_pull_block_spec_rule(lax.slice_p) def _slice_rule( ctx: PullRuleContext, @@ -879,28 +956,42 @@ def _slice_rule( slice_sizes = tuple( int(end - start) for start, end in zip(start_indices, limit_indices) ) + # Do some basic checks for bs, slice_start, slice_size in zip( block_spec.block_shape, start_indices, slice_sizes ): - if bs is None: - continue - block_size = _block_size(bs) - assert ( - slice_start % block_size == 0 - ), (start_indices, block_spec.block_shape) - assert slice_size % block_size == 0, (slice_sizes, block_spec.block_shape) - offsets = tuple( - slice_start // _block_size(bs) if bs is not None else slice_start - for slice_start, bs in zip(start_indices, block_spec.block_shape) - ) - - def _offset(x, i): - return x + i if i != 0 else x + match bs: + case None | pallas_core.Squeezed(): + continue + case pallas_core.BoundedSlice() | pallas_core.Element(): + block_size = _block_size(bs) + # Require that block_size no bigger than the slice. 
+ if block_size > slice_size: + raise ValueError( + f'Block size {block_size} is larger than the slice size' + f' {slice_size}' + ) + case _: + block_size = _block_size(bs) + assert slice_start % block_size == 0, ( + start_indices, + block_spec.block_shape, + ) + assert slice_size % block_size == 0, ( + slice_sizes, + block_spec.block_shape, + ) def new_index_map(*args): idx = block_spec.index_map(*args) assert len(idx) == len(block_spec.block_shape) - return tuple(_offset(i, o) for i, o in zip(idx, offsets)) + idx = tuple( + _offset_indexer(bs, i, start, size) + for bs, i, start, size in zip( + block_spec.block_shape, idx, start_indices, slice_sizes, strict=True + ) + ) + return idx return [pallas_core.BlockSpec(block_spec.block_shape, new_index_map)] @@ -917,20 +1008,6 @@ def _dynamic_slice_usage_rule(ctx, used_out: set[Usage], **params): return [set()] * len(ctx.avals_in) -def _offset(x, i, s): - from jax.experimental import checkify - - if s is not None: - pred = i % s == 0 - if isinstance(pred, jax.Array): - checkify.check(i % s == 0, 'Invalid index', debug=True) - else: - if not pred: - raise ValueError('Invalid index') - offset = jax.lax.div(i, s) if s is not None else i - return x + offset - - @register_eval_rule(lax.dynamic_slice_p) def _dynamic_slice_eval_rule(ctx, x, *args, **params): del ctx, params @@ -944,7 +1021,6 @@ def _dynamic_slice_rule( *, slice_sizes: tuple[int, ...], ): - del slice_sizes def new_index_map(*args): slice_starts = ctx.scalar_prefetch_fn() @@ -966,11 +1042,15 @@ def new_index_map(*args): # multiples of the block sizes. The indices of the block that correspond to # the slice are then given by (i // b_l, j // b_m, k // b_n). # We then add these block indices to block indices produced by the index - # map. + # map + print('BLOCK SHAPE', block_spec.block_shape) + print('INDEXER', idx) + print('SLICE starts', slice_starts) + print('SLICE sizes', slice_sizes) block_indices = tuple( - _offset(i, o, _block_size(s)) - for i, o, s in zip( - idx, slice_starts, block_spec.block_shape, strict=True + _offset_indexer(s, i, start, size) + for i, s, start, size in zip( + idx, block_spec.block_shape, slice_starts, slice_sizes, strict=True ) ) return block_indices @@ -990,7 +1070,7 @@ def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension): is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] if any(is_element_block): raise NotImplementedError( - "Concatenation with Element indexing is not yet supported." + 'Concatenation with Element indexing is not yet supported.' ) block_dim = block_shape[dimension] if block_dim is None: @@ -1038,11 +1118,11 @@ def _concatenate_rule( is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] if any(is_element_block): raise NotImplementedError( - "Concatenation with Element indexing is not yet supported." + 'Concatenation with Element indexing is not yet supported.' 
) num_blocks = [] block_dim = block_shape[dimension] - if block_dim is None: + if block_dim is None or isinstance(block_dim, pallas_core.Squeezed): block_dim = 1 if block_dim == sum(aval.shape[dimension] for aval in ctx.avals_in): # pytype: disable=attribute-error # Handle special case if the block contains all of the concatenated @@ -1114,7 +1194,9 @@ def _broadcast_in_dim_eval_rule( if not eval_ctx.avals_in[0].shape: # pytype: disable=attribute-error # Scalar -> Array broadcast block_spec = eval_ctx.out_block_specs[0] - shape = tuple(_block_size(s) for s in block_spec.block_shape if s is not None) + shape = tuple( + _block_size(s) for s in block_spec.block_shape if s is not None + ) return jax.lax.broadcast_in_dim(x, broadcast_dimensions=(), shape=shape) return x @@ -1149,10 +1231,17 @@ def _transpose_eval_rule( ): block_spec = eval_ctx.out_block_specs[0] block_shape = block_spec.block_shape - block_shape_no_nones = tuple(bs for bs in block_shape if bs is not None) + block_shape_no_nones = tuple( + bs + for bs in block_shape + if not (bs is None or isinstance(bs, pallas_core.Squeezed)) + ) block_dims_iter = iter(range(len(block_shape_no_nones))) expanded_block_dims = [ - None if bs is None else next(block_dims_iter) for bs in block_shape + None + if (bs is None or isinstance(bs, pallas_core.Squeezed)) + else next(block_dims_iter) + for bs in block_shape ] assert next(block_dims_iter, None) is None permuted_block_dims = [expanded_block_dims[p] for p in permutation] diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 6865713cd69a..dd83dab3c3c5 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -38,7 +38,6 @@ SMEM = tpu_core.TPUMemorySpace.SMEM VMEM = tpu_core.TPUMemorySpace.VMEM -DMA = tpu_core.SemaphoreType.DMA REF = pallas_core.MemoryRef GridDimensionSemantics = tpu_core.GridDimensionSemantics PARALLEL = tpu_core.PARALLEL @@ -104,8 +103,10 @@ def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1]) -def _round_up_to_nearest_multiple(s: int, multiple: int) -> int: - if s % multiple == 0: +def _round_up_to_nearest_multiple( + s: int | jax.Array, multiple: int +) -> int | jax.Array: + if isinstance(s, int) and s % multiple == 0: return s # Subtract off the remainder, then add multiple return s - s % multiple + multiple @@ -119,10 +120,51 @@ def _make_block_ds( assert isinstance(out, pl.Slice) return out +def _create_blocked_slice(block_index: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int): + block_start = block_size * block_index + if (dim_rem := dim_size % block_size) == 0: + return pl.ds(block_start, block_size) + if block_size % tiling != 0: + raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") + num_blocks = pl.cdiv(dim_size, block_size) + is_last = block_index == num_blocks - 1 + rounded_size = jnp.where( + is_last, + _round_up_to_nearest_multiple(dim_rem % block_size, tiling), + block_size, + ) + rounded_size = pl.multiple_of(rounded_size, tiling) + return pl.ds(block_index * block_size, rounded_size) + +def _create_bounded_slice(slice_start: jax.Array | int, + slice_size: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int): + if block_size % tiling != 0: + raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") + # We assume by construction that slice_size <= block_size. 
We also assume + # that the slice_start is already aligned to the tiling. + + # If we are out of bound, we need to round the slice size down to the nearest + # multiple of the tiling. + is_oob = slice_start + slice_size > dim_size + remaining = dim_size - slice_start + rounded_size = jnp.where( + is_oob, + _round_up_to_nearest_multiple(remaining, tiling), + slice_size, + ) + rounded_size = pl.multiple_of(rounded_size, tiling) + return pl.ds(slice_start, rounded_size) + def _make_block_slice( block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int, tiling: int -) -> pl.Slice | slice: +) -> pl.Slice | slice | int | jax.Array: # Computes a slice given a block index and block size. In the default case, # we return slice(block_index * block_size, (block_index + 1) * block_size). # However, if the total size of the ref does not divide block size and we are @@ -130,41 +172,30 @@ def _make_block_slice( # that contains the block. match block_size: case pl.Blocked(): - block_start = block_size.block_size * block_index - block_size = block_size.block_size + return _create_blocked_slice(block_index, block_size.block_size, size, tiling) + case int(): + return _create_blocked_slice(block_index, block_size, size, tiling) case pl.Element(): block_start = block_index block_size = block_size.block_size - case pl.BoundedSlice(): + return _create_bounded_slice( + block_start, block_size, block_size, size, tiling + ) + case pl.BoundedSlice(block_size): if not isinstance(block_index, pl.Slice): raise ValueError( "Must return a pl.ds from the index_map for a BoundedSlice" " dimension." ) - block_start = block_index.start - block_size = block_index.size - return pl.ds(block_start, block_size) - case int(): - # This is same as Blocked. - block_start = block_index * block_size + slice_start = block_index.start + slice_size = block_index.size + return _create_bounded_slice( + slice_start, slice_size, block_size, size, tiling + ) case None | pl.Squeezed(): - block_start = block_index - block_size = 1 + return block_index case _: raise ValueError(f"Unsupported block dimension type: {block_size}") - if size % block_size == 0: - return pl.ds(block_start, block_size) - if block_size % tiling != 0: - raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") - num_blocks = pl.cdiv(size, block_size) - is_last = block_index == num_blocks - 1 - rounded_size = jnp.where( - is_last, - _round_up_to_nearest_multiple(size % block_size, tiling), - block_size, - ) - rounded_size = pl.multiple_of(rounded_size, tiling) - return pl.ds(block_index * block_size, rounded_size) def _tuples_differ(xs, ys): @@ -181,7 +212,6 @@ def _grid_size(grid): return size - class BufferType(enum.Enum): """Buffer type for the arguments to an emitted pipeline.""" INPUT = 1 @@ -197,20 +227,20 @@ def _get_dim_size(bd): match bd: case pl.Blocked(block_size): return block_size - case pl.Element(): - return bd.block_size + case pl.Element(block_size): + return block_size case pl.BoundedSlice(block_size): return block_size case int(): return bd - case None: - return 1 + case None | pl.Squeezed(): + return None case _: raise ValueError(f"Unsupported block dimension type: {bd}") if spec.block_shape is None: raise ValueError("Block shape must be specified.") - block_shape = tuple(_get_dim_size(x) for x in spec.block_shape) - return block_shape + block_shape_nones = tuple(_get_dim_size(x) for x in spec.block_shape) + return tuple(x for x in block_shape_nones if x is not None) @tree_util.register_pytree_node_class 
@dataclasses.dataclass(frozen=True) @@ -369,7 +399,10 @@ def memory_space(self): @property def current_ref(self): buffer_slice = tuple( - 0 if x is None else slice(None) for x in self.block_shape) + slice(None) + for x in self.block_shape + if not (x is None or isinstance(x, pl.Squeezed)) + ) assert not (self.window_ref is None or isinstance(self.window_ref, REF)) if self.memory_space == VMEM: return self.window_ref.at[buffer_slice] @@ -420,22 +453,29 @@ def bind_existing_ref(self, window_ref, indices): def compute_slice(self, grid_indices): """Compute DMA slice from grid indices.""" - block_shape = [] - for bd in self.block_shape: - if isinstance(bd, (pl.Element, pl.BoundedSlice)): - raise ValueError( - "Element and BoundedSlice block dimensions are not supported." - ) - if bd is None: - block_shape.append(1) - elif isinstance(bd, pl.Blocked): - block_shape.append(bd.block_size) - elif isinstance(bd, int): - block_shape.append(bd) - else: - raise ValueError(f"Unsupported block dimension type: {type(bd)}") indices = self.compute_index(*grid_indices) - return jax.tree.map(_make_block_ds, indices, tuple(block_shape)) + assert len(self.block_shape) == len(indices) + indexer = [] + for bd, idx in zip(self.block_shape, indices, strict=True): + match bd: + case None | pl.Squeezed(): + # Dimension is squeezed out so we don't do anything. + indexer.append(idx) + case pl.Element(): + raise ValueError( + "Element block dimensions are not supported." + ) + case pl.BoundedSlice(): + raise ValueError( + "BoundedSlice block dimensions are not supported." + ) + case pl.Blocked(block_size): + indexer.append(_make_block_ds(idx, block_size)) + case int(): + indexer.append(_make_block_ds(idx, bd)) + case _: + raise ValueError(f"Unsupported block dimension type: {type(bd)}") + return tuple(indexer) def init_slots(self): """Initialize slot indices.""" @@ -520,7 +560,11 @@ def copy_in(self, src_ref, grid_indices): self.swap[0] = True next_slot = self.next_slot_index src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) - dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) + dst_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(src_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( src_ref.at[src_slice], self.window_ref.at[(next_slot, *dst_slice)], @@ -537,7 +581,11 @@ def copy_out(self, dst_ref, grid_indices): self.swap[0] = True slot = self.current_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) - src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + src_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(dst_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( self.window_ref.at[(slot, *src_slice)], dst_ref.at[dst_slice], @@ -551,7 +599,11 @@ def wait_in(self, src_ref, grid_indices): assert not (self.window_ref is None or isinstance(self.window_ref, REF)) assert self.sem_recvs is not None src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) - dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) + dst_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(src_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) current_slot = self.current_slot_index tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter @@ -570,7 +622,11 @@ def wait_out(self, dst_ref, grid_indices): # In a double buffer, previous slot is the same as next slot. 
prev_slot = self.next_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) - src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + src_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(dst_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( self.window_ref.at[(prev_slot, *src_slice)], # nb: doesn't matter dst_ref.at[dst_slice], # only dst shape is important From a0d295cb983eb75313271532c2c4c17aa6e8da3b Mon Sep 17 00:00:00 2001 From: Yan Zhao Date: Wed, 30 Apr 2025 00:39:30 +0000 Subject: [PATCH 0908/1769] Use sharding in types and reshard Signed-off-by: Yan Zhao --- examples/spmd_mnist_classifier_fromscratch.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/examples/spmd_mnist_classifier_fromscratch.py b/examples/spmd_mnist_classifier_fromscratch.py index 234ffac7de4c..c5c85b2aff37 100644 --- a/examples/spmd_mnist_classifier_fromscratch.py +++ b/examples/spmd_mnist_classifier_fromscratch.py @@ -12,16 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import time +from jax import NamedSharding import numpy as np import numpy.random as npr - import jax from jax import jit, grad -from jax.sharding import PartitionSpec as P, NamedSharding +from jax.experimental.shard import reshard +from jax.sharding import ( + PartitionSpec as P, + AxisType, +) from jax.scipy.special import logsumexp -from jax.tree_util import tree_map import jax.numpy as jnp import datasets @@ -50,6 +54,7 @@ def loss(params, batch): return -jnp.mean(jnp.sum(preds * targets, axis=1)) +@partial(jax.jit, donate_argnums=0) def train_step(params, batch): grads = grad(loss)(params, batch) return [ @@ -86,7 +91,9 @@ def accuracy(params, batch): num_batches = num_complete_batches + bool(leftover) devices = np.array(jax.devices()) - mesh = jax.make_mesh((jax.device_count(),), ("batch",)) + mesh = jax.make_mesh( + (jax.device_count(),), ("batch",), axis_types=(AxisType.Explicit,) + ) replicated_sharding = NamedSharding(mesh, P()) data_sharding = NamedSharding(mesh, P("batch")) @@ -112,16 +119,7 @@ def data_stream(): batches = data_stream() params = init_random_params(param_scale, layer_sizes) - - param_shardings = tree_map(lambda x: replicated_sharding, params) - params = jax.device_put(params, param_shardings) - batch_shardings = (data_sharding, data_sharding) - - jitted_train_step = jax.jit( - train_step, - out_shardings=param_shardings, - donate_argnums=(0,), - ) + replicated_params = jax.device_put(params, replicated_sharding) for epoch in range(num_epochs): start_time = time.time() @@ -129,11 +127,19 @@ def data_stream(): print(f"Batch no {i+1} of {num_batches}") batch = next(batches) with jax.sharding.use_mesh(mesh): - params = jitted_train_step(params, batch) + replicated_params = train_step(replicated_params, batch) epoch_time = time.time() - start_time - train_acc = accuracy(params, (train_images, train_labels)) - test_acc = accuracy(params, (test_images, test_labels)) + # Reshard train_images, train_labels, test_images, test_labels + sharded_train_images = reshard(train_images, data_sharding) + sharded_train_labels = reshard(train_labels, data_sharding) + sharded_test_images = reshard(test_images, data_sharding) + sharded_test_labels = reshard(test_labels, data_sharding) + + train_acc = accuracy( + replicated_params, (sharded_train_images, sharded_train_labels) + ) + test_acc = 
accuracy(replicated_params, (sharded_test_images, sharded_test_labels)) print(f"Epoch {epoch} in {epoch_time:0.2f} sec") print(f"Training set accuracy {train_acc}") print(f"Test set accuracy {test_acc}") @@ -141,4 +147,4 @@ def data_stream(): if epoch < num_epochs - 1: batches = data_stream() print(f"Batch no {0} of {num_batches}") - params = jitted_train_step(params, next(batches)) + replicated_params = train_step(replicated_params, next(batches)) From 6cb8353a92536b9db395bed8f9d3d7ee1c892595 Mon Sep 17 00:00:00 2001 From: chaser Date: Wed, 9 Apr 2025 18:44:13 +0000 Subject: [PATCH 0909/1769] Added additional stream annotation tests --- tests/memories_test.py | 78 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 302895e29f63..e0a42d4a0146 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -34,6 +34,7 @@ from jax._src.sharding_impls import ( NamedSharding, SingleDeviceSharding, GSPMDSharding, TransferToMemoryKind, PartitionSpec as P) +from jax._src.xla_metadata import set_xla_metadata from jax.experimental.compute_on import compute_on from jax._src.shard_map import shard_map import numpy as np @@ -1684,6 +1685,74 @@ def peer_forward(x, experts, indices, scores): class StreamAnnotationTest(jtu.JaxTestCase): + def test_stream_annotation_single_instruction(self): + # E2E test for fix https://github.com/openxla/xla/pull/24269 + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Stream annotation is only supported on GPU.") + + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + np_inp = np.ones((8,)) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np_inp, s) + + @compute_on('gpu_stream:1') + @jax.jit + def g(x, y): + return x + y + + @jax.jit + def f(x, y): + return g(x, y) + + compiled_f = jax.jit(f).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('wrapped_add', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 2) + + def test_streamed_gemm_overlap(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Stream annotation is only supported on GPU.") + + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + + @compute_on('gpu_stream:1') + @jax.jit + def g(x, y): + return x @ y + + @compute_on('gpu_stream:2') + @jax.jit + def h(x, y): + return x @ y + + @jax.jit + @functools.partial( + jax.shard_map, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x')) + def f(x, y): + with set_xla_metadata(_scheduling_group_id="1"): + a = g(x, y) + b = h(y, x) + return a + b + + np_input = np.ones((1024, 512)) + + arr1 = jax.device_put(np_input, s) + arr2 = jax.device_put(np_input, s) + + compiled_f = jax.jit(f).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('call-start.1', compiled_text) + self.assertIn('_xla_stream_annotation="2"', compiled_text) + self.assertIn('_scheduling_group_id="1"', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 1024) + def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") @@ -1694,10 +1763,6 @@ def test_stream_annotation_inside_shmap(self): arr1 = jax.device_put(np_inp, s) 
arr2 = jax.device_put(np_inp, s) - # Makes sure the compute wrapped here is fusible. - # This is a workaround for limitations in XLA. - # 1) Compute-on boxes contain a single instruction cannot work. - # 2) Compute-on boxes contain tiny matmul cannot work. @compute_on('gpu_stream:1') @jax.jit def g(x, y): @@ -1715,9 +1780,7 @@ def f(x, y): compiled_f = jax.jit( shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), - out_specs=P('x'))).lower(arr1, arr2).compile( - {"xla_gpu_experimental_stream_annotation": True} - ) + out_specs=P('x'))).lower(arr1, arr2).compile() compiled_text = compiled_f.as_text() self.assertIn('call-start', compiled_text) self.assertIn('_xla_stream_annotation="1"', compiled_text) @@ -1725,7 +1788,6 @@ def f(x, y): self.assertIn('_xla_stream_annotation="2"', compiled_text) self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 11) - class ActivationOffloadingTest(jtu.JaxTestCase): def setUp(self): From 603f73016cdd6c576b819524ea270972439ff5c5 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Tue, 29 Apr 2025 23:55:24 -0700 Subject: [PATCH 0910/1769] Automated Code Change PiperOrigin-RevId: 753039040 --- jax/_src/core.py | 2 +- jax/_src/interpreters/mlir.py | 2 +- jax/_src/lax/lax.py | 8 ++++---- jax/_src/pallas/triton/pallas_call_registration.py | 2 +- jax/_src/pjit.py | 10 +++++----- jax/_src/tpu_custom_call.py | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ab2232925596..4bc64b85b81d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2047,7 +2047,7 @@ def standard_insert_pvary(*args): if not args: return args in_vma = [frozenset() if (aval := get_aval(a)) is abstract_token - else aval.vma for a in args] + else aval.vma for a in args] # pytype: disable=attribute-error out_vma = frozenset.union(*in_vma) return [pvary(arg, tuple(n for n in out_vma if n not in src)) if out_vma - src else arg for arg, src in zip(args, in_vma)] diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c7a53b1b4260..ccbe84bcfa4e 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1097,7 +1097,7 @@ class UnconstrainedVariants(NamedTuple): def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants: us = contains_unconstrained(s) - unconstrained_dims = ({i for i, p in enumerate(s.spec) + unconstrained_dims = ({i for i, p in enumerate(s.spec) # pytype: disable=attribute-error if p is PartitionSpec.UNCONSTRAINED} if us else None) return UnconstrainedVariants( contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval), diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 2034026e197a..f94f28be32c5 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5172,7 +5172,7 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch) rhs_group = () if isinstance(dimension_numbers, RaggedDotDimensionNumbers): - rhs_group = tuple(dimension_numbers.rhs_group_dimensions) + rhs_group = tuple(dimension_numbers.rhs_group_dimensions) # pytype: disable=attribute-error rhs_contract_or_batch_or_group = tuple( sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group) ) @@ -6017,7 +6017,7 @@ def grad_x_dims(): unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: raise unimplemented('grad_x_dims', mode) - return dims, unsorted_axes + return dims, unsorted_axes # pytype: disable=name-error def 
grad_y_dims(): match mode: @@ -6036,7 +6036,7 @@ def grad_y_dims(): ) case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: raise unimplemented('grad_y_dims', mode) - return dims, unsorted_axes + return dims, unsorted_axes # pytype: disable=name-error def _ragged_dot_grad(lhs, rhs, dims_fn, aval): dims, unsorted_axes = dims_fn() @@ -6238,7 +6238,7 @@ def expand(x, dim, gs, *axes): lhs, rhs, dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, - ) + ) # pytype: disable=bad-return-type def _ragged_dot_general_lower( diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index b692bd43a0fa..e111cef0f924 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -40,7 +40,7 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]: def avals_to_layouts(avals): - return [list(reversed(range(aval.ndim))) for aval in avals] + return [list(reversed(range(aval.ndim))) for aval in avals] # pytype: disable=attribute-error def pallas_call_lowering( diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 3e22b06ba02d..579d3dd0e10e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -750,16 +750,16 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo, for i, x in enumerate(explicit_args): avals.append(core.shaped_abstractify(x)) except OverflowError: - arg_path = f"argument path is {dbg.arg_names[i]}" + arg_path = f"argument path is {dbg.arg_names[i]}" # pytype: disable=name-error raise OverflowError( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from None except TypeError: - arg_description = f"path {dbg.arg_names[i]}" + arg_description = f"path {dbg.arg_names[i]}" # pytype: disable=name-error raise TypeError( f"Error interpreting argument to {fun} as an abstract array." 
- f" The problematic value is of type {type(x)} and was passed to" + f" The problematic value is of type {type(x)} and was passed to" # pytype: disable=name-error f" the function at {arg_description}.\n" "This typically means that a jit-wrapped function was called with a non-array" " argument, and this argument was not marked as static using the" @@ -2035,8 +2035,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext): num_devices = axis_ctx.mesh.size key = (pjit_p, name, jaxpr, effects, num_devices, - pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), - pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), + pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), # pytype: disable=wrong-arg-types + pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), # pytype: disable=wrong-arg-types in_layouts, out_layouts, api_name) func = mod_ctx.cached_primitive_lowerings.get(key, None) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 7aa0f87dd525..6039979df37b 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -222,7 +222,7 @@ def _tpu_custom_call_abstract_eval(*_, out_avals, **__): def _avals_to_layouts(avals) -> Sequence[Sequence[int]]: - return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] + return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] # pytype: disable=attribute-error def _tpu_custom_call_lowering( From 70567d466768b10954094aa1de7c26cf38069dc1 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Wed, 30 Apr 2025 01:56:10 -0700 Subject: [PATCH 0911/1769] [JAX:Sparse] Implement CSR sparse kernel PiperOrigin-RevId: 753071768 --- jaxlib/BUILD | 2 + jaxlib/cpu/BUILD | 33 +++++ jaxlib/cpu/_sparse/__init__.pyi | 15 +++ jaxlib/cpu/cpu_kernels.cc | 3 + jaxlib/cpu/sparse.cc | 37 ++++++ jaxlib/cpu/sparse_kernels.cc | 215 ++++++++++++++++++++++++++++++++ jaxlib/cpu/sparse_kernels.h | 27 ++++ jaxlib/cpu_sparse.py | 27 ++++ jaxlib/xla_client.py | 2 +- 9 files changed, 360 insertions(+), 1 deletion(-) create mode 100644 jaxlib/cpu/_sparse/__init__.pyi create mode 100644 jaxlib/cpu/sparse.cc create mode 100644 jaxlib/cpu/sparse_kernels.cc create mode 100644 jaxlib/cpu/sparse_kernels.h create mode 100644 jaxlib/cpu_sparse.py diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 50d3d28c05bf..e84203322f42 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -52,6 +52,7 @@ package_group( py_library_providing_imports_info( name = "jaxlib", srcs = [ + "cpu_sparse.py", "gpu_common_utils.py", "gpu_linalg.py", "gpu_prng.py", @@ -76,6 +77,7 @@ py_library_providing_imports_info( "//jaxlib:_jax", "//jaxlib:xla_client", "//jaxlib/cpu:_lapack", + "//jaxlib/cpu:_sparse", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 76934df6c37b..cbcddd9713f0 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -85,9 +85,42 @@ cc_library( deps = [ ":lapack_kernels", ":lapack_kernels_using_lapack", + ":sparse_kernels", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, ) + +cc_library( + name = "sparse_kernels", + srcs = ["sparse_kernels.cc"], + hdrs = ["sparse_kernels.h"], + deps = [ + "@eigen_archive//:eigen3", + "@xla//xla/ffi/api:ffi", + ], +) + +nanobind_extension( + name = "_sparse", + srcs = ["sparse.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + enable_stub_generation = False, + 
features = ["-use_header_modules"], + module_name = "_sparse", + pytype_srcs = [ + "_sparse/__init__.pyi", + ], + deps = [ + ":sparse_kernels", + "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/base", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) diff --git a/jaxlib/cpu/_sparse/__init__.pyi b/jaxlib/cpu/_sparse/__init__.pyi new file mode 100644 index 000000000000..a82f83b267b7 --- /dev/null +++ b/jaxlib/cpu/_sparse/__init__.pyi @@ -0,0 +1,15 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def registrations() -> dict: ... diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index a118c20a4490..4361e42827ea 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/cpu/sparse_kernels.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_target_registry.h" @@ -110,6 +111,8 @@ JAX_CPU_REGISTER_HANDLER(lapack_dgtsv_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgtsv_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgtsv_ffi); +JAX_CPU_REGISTER_HANDLER(cpu_csr_sparse_dense_ffi); + #undef JAX_CPU_REGISTER_HANDLER } // namespace diff --git a/jaxlib/cpu/sparse.cc b/jaxlib/cpu/sparse.cc new file mode 100644 index 000000000000..15f5c0f1984f --- /dev/null +++ b/jaxlib/cpu/sparse.cc @@ -0,0 +1,37 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/sparse_kernels.h" +#include "jaxlib/kernel_nanobind_helpers.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +nb::dict Registrations() { + nb::dict dict; + + dict["cpu_csr_sparse_dense_ffi"] = + EncapsulateFunction(cpu_csr_sparse_dense_ffi); + + return dict; +} + +NB_MODULE(_sparse, m) { m.def("registrations", &Registrations); } + +} // namespace +} // namespace jax diff --git a/jaxlib/cpu/sparse_kernels.cc b/jaxlib/cpu/sparse_kernels.cc new file mode 100644 index 000000000000..8000abca65cc --- /dev/null +++ b/jaxlib/cpu/sparse_kernels.cc @@ -0,0 +1,215 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/cpu/sparse_kernels.h"
+
+#include <algorithm>
+#include <complex>
+#include <cstdint>
+#include <vector>
+
+#include "Eigen/Core"
+#include "Eigen/SparseCore"
+#include "xla/ffi/api/ffi.h"
+
+namespace ffi = xla::ffi;
+
+namespace jax {
+
+template <typename ElementType, typename StorageType>
+using SparseMatrixType =
+    Eigen::SparseMatrix<ElementType, Eigen::RowMajor, StorageType>;
+template <typename ElementType>
+using DenseMatrixType =
+    Eigen::Matrix<ElementType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+
+template <typename MatrixType>
+using InputMap = Eigen::Map<const MatrixType>;
+template <typename MatrixType>
+using OutputMap = Eigen::Map<MatrixType>;
+
+template <typename ElementType, typename StorageType>
+static ffi::Future CsrSparseDenseKernelImpl(
+    const InputMap<SparseMatrixType<ElementType, StorageType>>& lhs_matrix,
+    const InputMap<DenseMatrixType<ElementType>>& rhs_matrix,
+    OutputMap<DenseMatrixType<ElementType>>& out_matrix,
+    ffi::ThreadPool& thread_pool) {
+  // Rule of thumb to give each task at least 100k cycles to hide the cost of
+  // task scheduling.
+  // TODO(willfroom) Do we want to make this configurable?
+  constexpr int64_t kTargetCyclesPerTask = 100'000;
+  // Based on AVX (CPI 0.5 -> 2 IPC)
+  constexpr int64_t kScalarProductsPerCycle = 2 * 32 / sizeof(ElementType);
+  constexpr int64_t kTaskSize = kTargetCyclesPerTask * kScalarProductsPerCycle;
+
+  if (lhs_matrix.nonZeros() * rhs_matrix.cols() <= kTaskSize ||
+      thread_pool.num_threads() == 0) {
+    out_matrix.noalias() = lhs_matrix * rhs_matrix;
+
+    ffi::Promise promise;
+    promise.SetAvailable();
+    return ffi::Future(promise);
+  } else {
+    std::vector<int64_t> batch_sizes;
+    {
+      int64_t running_batch_nnz = 0;
+      int64_t running_number_rows = 0;
+      for (int row = 0; row < lhs_matrix.rows(); ++row) {
+        int64_t row_nnz = lhs_matrix.outerIndexPtr()[row + 1] -
+                          lhs_matrix.outerIndexPtr()[row];
+        // If there are no non-zero elements in a row, the task still needs to
+        // write out a zero row, so we give each row a non-zero contribution to
+        // avoid the pathological case of a task having to write many rows
+        // where there is a large block of zero inputs.
+        running_batch_nnz += std::max(row_nnz, static_cast<int64_t>(1));
+        running_number_rows++;
+        if (running_batch_nnz * rhs_matrix.cols() > kTaskSize) {
+          batch_sizes.push_back(running_number_rows);
+          running_batch_nnz = 0;
+          running_number_rows = 0;
+        } else if (row == lhs_matrix.rows() - 1 && running_number_rows > 0) {
+          batch_sizes.push_back(running_number_rows);
+        }
+      }
+    }
+
+    ffi::CountDownPromise promise(batch_sizes.size());
+    ffi::Future future(promise);
+    int64_t batch_start = 0;
+    for (int64_t size : batch_sizes) {
+      thread_pool.Schedule([out_matrix, lhs_matrix, rhs_matrix, batch_start,
+                            size, promise]() mutable {
+        out_matrix.middleRows(batch_start, size).noalias() =
+            lhs_matrix.middleRows(batch_start, size) * rhs_matrix;
+        promise.CountDown();
+      });
+      batch_start += size;
+    }
+    return future;
+  }
+}
+
+template <typename ElementType, typename StorageType>
+static ffi::Future CsrSparseDenseKernelTypedDispatch(
+    ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies,
+    ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs,
+    ffi::Result<ffi::AnyBuffer> out, ffi::ThreadPool thread_pool) {
+  ffi::Span<const int64_t> rhs_shape = rhs.dimensions();
+  ffi::Span<const int64_t> out_shape = out->dimensions();
+
+  InputMap<SparseMatrixType<ElementType, StorageType>> lhs_matrix(
+      out_shape[0], rhs_shape[0], lhs_data.element_count(),
+      lhs_outer_indicies.reinterpret_data<StorageType>(),
+      lhs_inner_indicies.reinterpret_data<StorageType>(),
+      lhs_data.reinterpret_data<ElementType>());
+
+  InputMap<DenseMatrixType<ElementType>> rhs_matrix(
+      rhs.reinterpret_data<ElementType>(), rhs_shape[0],
+      rhs_shape.size() > 1 ? rhs_shape[1] : 1);
+  OutputMap<DenseMatrixType<ElementType>> out_matrix(
+      out->reinterpret_data<ElementType>(), lhs_matrix.rows(),
+      rhs_matrix.cols());
+
+  return CsrSparseDenseKernelImpl(
+      lhs_matrix, rhs_matrix, out_matrix, thread_pool);
+}
+
+template <typename ElementType>
+static ffi::Future CsrSparseDenseKernelTypedDispatch(
+    ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies,
+    ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs,
+    ffi::Result<ffi::AnyBuffer> out, ffi::ThreadPool thread_pool) {
+  if (lhs_outer_indicies.element_type() != lhs_inner_indicies.element_type()) {
+    ffi::Promise promise;
+    promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument,
+                                "Sparse index type mismatch"));
+    return ffi::Future(promise);
+  }
+
+  switch (lhs_outer_indicies.element_type()) {
+    case ffi::DataType::S32:
+      return CsrSparseDenseKernelTypedDispatch<ElementType, int32_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ffi::DataType::S64:
+      return CsrSparseDenseKernelTypedDispatch<ElementType, int64_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    default:
+      ffi::Promise promise;
+      promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument,
+                                  "Invalid index data type"));
+      return ffi::Future(promise);
+  }
+}
+
+static ffi::Future CsrSparseDenseKernelDispatch(
+    ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies,
+    ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs,
+    ffi::Result<ffi::AnyBuffer> out, ffi::ThreadPool thread_pool) {
+  if (lhs_data.element_type() != rhs.element_type() ||
+      lhs_data.element_type() != out->element_type()) {
+    ffi::Promise promise;
+    promise.SetError(
+        ffi::Error(ffi::ErrorCode::kInvalidArgument, "Element type mismatch"));
+    return ffi::Future(promise);
+  }
+
+  switch (lhs_data.element_type()) {
+    case ffi::DataType::S32:
+      return CsrSparseDenseKernelTypedDispatch<int32_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ffi::DataType::S64:
+      return CsrSparseDenseKernelTypedDispatch<int64_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ffi::DataType::F32:
+      return CsrSparseDenseKernelTypedDispatch<float>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ffi::DataType::F64:
+      return CsrSparseDenseKernelTypedDispatch<double>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ffi::DataType::C64:
+      return CsrSparseDenseKernelTypedDispatch<std::complex<float>>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ffi::DataType::C128:
+      return CsrSparseDenseKernelTypedDispatch<std::complex<double>>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    default:
+      ffi::Promise promise;
+      promise.SetError(
+          ffi::Error(ffi::ErrorCode::kInvalidArgument, "Invalid data type"));
+      return ffi::Future(promise);
+  }
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi,
+                              CsrSparseDenseKernelDispatch,
+                              (ffi::Ffi::Bind()
+                                   .Arg<ffi::AnyBuffer>(/*lhs_data*/)
+                                   .Arg<ffi::AnyBuffer>(
+                                       /*lhs_outer_indicies*/)
+                                   .Arg<ffi::AnyBuffer>(
+                                       /*lhs_inner_indicies*/)
+                                   .Arg<ffi::AnyBuffer>(/*rhs*/)
+                                   .Ret<ffi::AnyBuffer>(/*out*/)
+                                   .Ctx<ffi::ThreadPool>(/*thread_pool*/)));
+
+}  // namespace jax
diff --git a/jaxlib/cpu/sparse_kernels.h b/jaxlib/cpu/sparse_kernels.h
new file mode 100644
index 000000000000..856b1da9d36c
--- /dev/null
+++ b/jaxlib/cpu/sparse_kernels.h
@@ -0,0 +1,27 @@
+/* Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_
+#define THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_
+
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+
+XLA_FFI_DECLARE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi);
+
+}  // namespace jax
+
+#endif  // THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_
diff --git a/jaxlib/cpu_sparse.py b/jaxlib/cpu_sparse.py
new file mode 100644
index 000000000000..ed43b3ee0f92
--- /dev/null
+++ b/jaxlib/cpu_sparse.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from .cpu import _sparse
+
+
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  api_version = 1
+  return {
+      "cpu": [
+          (name, value, api_version)
+          for name, value in _sparse.registrations().items()
+      ]
+  }
diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py
index 1d7f6cd8f584..82ae7855a8ee 100644
--- a/jaxlib/xla_client.py
+++ b/jaxlib/xla_client.py
@@ -43,7 +43,7 @@
 
 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
-_version = 332
+_version = 333
 
 # An internal increasing version number for protecting jaxlib code against
 # ifrt changes.
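A quick way to smoke-test the new kernel from JAX is through the public
`jax.ffi` API. This is only a sketch: the `csr_matmul` helper and its argument
names are illustrative assumptions, and it presumes a jaxlib build that
registers the `cpu_csr_sparse_dense_ffi` target via the `registrations()` hook
above.

    import jax
    import jax.numpy as jnp
    import numpy as np
    import scipy.sparse

    def csr_matmul(data, indptr, indices, rhs, *, out_rows):
      # Buffer order matches the FFI handler: lhs_data, lhs_outer_indicies
      # (the CSR indptr), lhs_inner_indicies (the CSR column indices), rhs.
      out_type = jax.ShapeDtypeStruct((out_rows, rhs.shape[1]), rhs.dtype)
      return jax.ffi.ffi_call("cpu_csr_sparse_dense_ffi", out_type)(
          data, indptr, indices, rhs)

    a = scipy.sparse.random(8, 4, density=0.25, format="csr", dtype=np.float32)
    rhs = jnp.ones((4, 2), dtype=jnp.float32)
    out = csr_matmul(jnp.asarray(a.data), jnp.asarray(a.indptr),
                     jnp.asarray(a.indices), rhs, out_rows=8)
    np.testing.assert_allclose(out, a @ np.asarray(rhs), rtol=1e-5)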
From cac35e0cb7d53a802bda9b1fb0f3f3223ce1e33c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 30 Apr 2025 05:58:39 -0700 Subject: [PATCH 0912/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/7bf1170008bdd80e25f38daecc9413961b36d9c3. PiperOrigin-RevId: 753134804 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e95466226c42..fbd81d22a568 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "38b38564b27b0146abf0f8131983874f1f5d8fe2" -XLA_SHA256 = "e61f726a8ad1faf3d58c61c50d8f36b0bfc9e76ad81e8f7bc1562a3645f33eaa" +XLA_COMMIT = "7bf1170008bdd80e25f38daecc9413961b36d9c3" +XLA_SHA256 = "dcb6d27aabd985090e9df8b776a76f3881362cd502e3c580ffc6d03ee4524fe0" def repo(): tf_http_archive( From ed96800a4bbd65531d8521deb9045dd454ec0db1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 30 Apr 2025 08:39:26 -0700 Subject: [PATCH 0913/1769] [Mosaic GPU] Check the `is_signed` on FragmentedArray, not on its mlir_dtype The mlir_dtype for integer types is always signless! This meant that we always kept going down the unsigned path. This wasn't caught by our tests, because the comparison tests were disabled for MGPU... PiperOrigin-RevId: 753182317 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/experimental/mosaic/gpu/fragmented_array.py | 2 +- jax/experimental/mosaic/gpu/layout_inference.py | 2 +- tests/pallas/mosaic_gpu_test.py | 14 ++++++++++++++ tests/pallas/ops_test.py | 11 +++++------ 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a955effdc31d..b4f33ef3e7a4 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1713,7 +1713,7 @@ def _comparison_lowering_rule_wg( x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out) if jnp.issubdtype(x_aval, jnp.signedinteger): return arith_dialect.cmpi(si_pred, x, y) - elif jnp.issubdtype(x_aval, jnp.integer) or jnp.issubdtype(x_aval, jnp.bool): + elif jnp.issubdtype(x_aval, jnp.unsignedinteger) or jnp.issubdtype(x_aval, jnp.bool): return arith_dialect.cmpi(ui_pred, x, y) elif jnp.issubdtype(x_aval, jnp.floating): return arith_dialect.cmpf(f_pred, x, y) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 76f7d549cf55..fbcef6e6ecb8 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1167,7 +1167,7 @@ def _compare(self, other, *, f_pred, si_pred, ui_pred): if ir.FloatType.isinstance(self.mlir_dtype): pred = functools.partial(arith.cmpf, f_pred) elif ir.IntegerType.isinstance(self.mlir_dtype): - if ir.IntegerType(self.mlir_dtype).is_signed: + if self.is_signed: pred = functools.partial(arith.cmpi, si_pred) else: pred = functools.partial(arith.cmpi, ui_pred) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 7be4abba3e65..feae31f0e9f6 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -712,7 +712,7 @@ def update_default_vector_size_from_vector(v: ir.Value): max_vec_size_for_v 
= ( np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE ) - desired_vec_size = 8 // utils.bytewidth(v.type.element_type) + desired_vec_size = 64 // utils.bitwidth(v.type.element_type) default_vector_size = min( default_vector_size, max_vec_size_for_v, desired_vec_size ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 63f69738b6f6..349a1dffb274 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3117,6 +3117,20 @@ def do_wgmma(acc_ref): np.testing.assert_allclose(kernel(x, x), x @ x) + def test_debug_bug(self): + dtype = jnp.float16 + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], dtype), + ) + def kernel(o_ref): + kv_step = jnp.asarray(0) + @pl.when(kv_step < -2) + def dp(): + pl.debug_print("foo") + o_ref[...] = jnp.zeros_like(o_ref) + kernel() + # TODO(apaszke): Clusters and multicast diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 8c8e59aedc11..bf60408e96b7 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1292,7 +1292,6 @@ def kernel(x_ref, y_ref, o_ref): ) ) def test_comparison(self, fn, dtype): - self.skip_if_mosaic_gpu() if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: self.skipTest("Not implemented on GPU.") @@ -1302,16 +1301,16 @@ def test_comparison(self, fn, dtype): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), ) def kernel(x_ref, y_ref, o_ref): - o_ref[:] = fn(x_ref[...], y_ref[...]) + o_ref[:] = fn(x_ref[...], y_ref[...]).astype(jnp.int32) - x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) + x = jnp.tile(jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype), 16) + y = jnp.tile(jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype), 16) out = kernel(x, y) expected = fn(x, y) - self.assertArraysEqual(out, expected) + self.assertArraysEqual(out != 0, expected) @parameterized.named_parameters( (f"{fn.__name__}_{dtype.__name__}", fn, dtype) From 74580607b5dcd470a8bf9b386a0ac8413dd1aae3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 30 Apr 2025 09:03:33 -0700 Subject: [PATCH 0914/1769] Don't allow UNCONSTRAINED in in_shardings of jit when the user passes those arguments. This was never supported in GSPMD and shardy errors out. If you are broken by this change, just replace `UNCONSTRAINED` with `None` since the behavior will be the exact same. 
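For example, the migration described above looks like this (a sketch; `mesh`
and `f` are assumed to be in scope):

    # Previously accepted, now raises
    # ValueError: Unconstrained dims are not allowed ...
    # jax.jit(f, in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x')))

    # Drop-in replacement with identical behavior:
    jax.jit(f, in_shardings=NamedSharding(mesh, P(None, 'x')))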
PiperOrigin-RevId: 753190393 --- jax/_src/pjit.py | 3 ++- jax/_src/sharding_impls.py | 12 +++++++++--- tests/pjit_test.py | 8 ++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 579d3dd0e10e..03c755b451e7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -481,7 +481,8 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, out_layouts, out_shardings = _split_layout_and_sharding(out_shardings) in_shardings = prepare_axis_resources(in_shardings, 'in_shardings') - out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') + out_shardings = prepare_axis_resources(out_shardings, 'out_shardings', + allow_unconstrained_dims=True) user_specified_in_shardings = (in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue)) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index d5d7725f5381..ec5fd4c512bf 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -693,14 +693,20 @@ def prepare_axis_resources(axis_resources, arg_name, if isinstance(entry, PmapSharding): raise ValueError(f'One of {what} got sharding {entry} which is not ' 'allowed.') + if (not allow_unconstrained_dims and isinstance(entry, NamedSharding) and + PartitionSpec.UNCONSTRAINED in entry.spec): + raise ValueError( + f'Unconstrained dims are not allowed when passed to {arg_name}:' + f' {entry}') new_entries.append(entry) else: if not isinstance(entry, PartitionSpec): raise TypeError(f"{what} are expected to be " f"PartitionSpec instances or None, but got {entry}") - for e in entry: - if e is PartitionSpec.UNCONSTRAINED and not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") + if not allow_unconstrained_dims and PartitionSpec.UNCONSTRAINED in entry: + raise ValueError( + f'Unconstrained dims are not allowed when passed to {arg_name}:' + f' {entry}') _check_unique_resources(entry, arg_name) new_entries.append(entry) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7f08c5af40b3..8d56f96d684d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4468,6 +4468,14 @@ def f(x): self.assertLen(traced.in_avals[0], 1) self.assertLen(traced.in_avals[1], 0) # empty kwarg + def test_in_out_shardings_unconstrained_error(self): + mesh = jtu.create_mesh((1,), ('x',)) + + with self.assertRaisesRegex( + ValueError, "Unconstrained dims are not allowed"): + jax.jit(lambda x: x, + in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x'))) + def test_empty_io_callback_under_shard_map(self): if config.use_shardy_partitioner.value: self.skipTest("TODO(b/384938613): Failing under shardy.") From eb312d3728e80247d997fbd84c609e700d15bbe3 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Wed, 30 Apr 2025 10:00:04 -0700 Subject: [PATCH 0915/1769] fix jax test to wait for async result to complete. PiperOrigin-RevId: 753209846 --- tests/checkify_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index c619fc8e915c..e7ae4d0468fd 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -1389,7 +1389,7 @@ def f(x): with self.assertRaisesRegex(_jax.XlaRuntimeError, "x needs to be positive"): - f(-1.) 
+ f(-1.).block_until_ready() if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 3a81bd427c89361e2fbb1199111f8aa4598825e6 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 30 Apr 2025 10:01:22 -0700 Subject: [PATCH 0916/1769] [Mosaic TPU] Add support for passing in single-element inputs through ANY memory space PiperOrigin-RevId: 753210325 --- tests/pallas/tpu_pallas_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index fccf1e42b932..a70aa19bda4d 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1776,6 +1776,25 @@ def reduce(): reduce_value = jnp.sum(jnp.full(shape, x), dtype=dty) np.testing.assert_allclose(z, reduce_value) + def test_scalar_any_input(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Needs a newer TPU") + if not jtu.if_cloud_tpu_at_least(2025, 5, 1): + self.skipTest("Needs a newer libTPU") + def kernel(src, dst, sem): + pltpu.async_copy(src, dst, sem).wait() + + def run(src): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(src.shape, jnp.float32), + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + )(src) + x = jnp.full((1,), 3.1415, dtype=jnp.float32) + np.testing.assert_array_equal(run(x), x) + def test_sum_in_smem(self): if not jtu.if_cloud_tpu_at_least(2025, 4, 30): self.skipTest("Needs a newer libTPU") From b094d40f20f3e045fb2d63006eda25c96acad302 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 30 Apr 2025 10:08:21 -0700 Subject: [PATCH 0917/1769] Add a matrix strategy to test jax head with latest jaxlib/CUDA plugins release to the Bazel CUDA non-rbe job PiperOrigin-RevId: 753213675 --- .github/workflows/bazel_cuda_non_rbe.yml | 28 +++++++++++++++----- .github/workflows/wheel_tests_continuous.yml | 4 ++- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 8d230848d20a..72878ad7aacb 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -29,6 +29,11 @@ on: type: string required: true default: "0" + jaxlib-version: + description: "Which jaxlib version to test? (head/pypi_latest)" + type: string + required: true + default: "head" gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: true @@ -55,7 +60,7 @@ jobs: # Enable writing to the Bazel remote cache bucket. JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1" - name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + name: "Bazel single accelerator and multi-accelerator CUDA tests (jaxlib version=${{ inputs.jaxlib-version }}, ${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -77,12 +82,21 @@ jobs: # fails. Instead, we verify the outcome in the next step so that we can print a more # informative error message. 
continue-on-error: true - run: >- - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + run: | + mkdir -p $(pwd)/dist + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + + if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then + PYTHON=python${{ inputs.python }} + $PYTHON -m pip download jaxlib jax-cuda12-pjrt jax-cuda12-plugin --dest $(pwd)/dist/ + else + echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}" + exit 1 + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 175fc2f22d4a..95d83a92c776 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -149,12 +149,14 @@ jobs: # Python values need to match the matrix stategy in the build artifacts job above runner: ["linux-x86-g2-48-l4-4gpu",] python: ["3.10",] + jaxlib-version: ["head", "pypi_latest"] enable-x64: [1, 0] - name: "Bazel CUDA Non-RBE (JAX artifacts version = ${{ format('{0}', 'head') }})" + name: "Bazel CUDA Non-RBE (jax version = ${{ format('{0}', 'head') }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + jaxlib-version: ${{ matrix.jaxlib-version }} # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} From 978c1306b18272d815344fba0041f6fbb11ab370 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 30 Apr 2025 10:08:27 -0700 Subject: [PATCH 0918/1769] Migrate JAX Array API workflow and NumPy/SciPy nightly wheels workflow to use JAX's new self-hosted runners These new runners are more powerful so we should see significant boosts in the run times. 
PiperOrigin-RevId: 753213725 --- .github/workflows/jax-array-api.yml | 15 +++++++-------- .github/workflows/upstream-nightly.yml | 5 +++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index c062970e6b1e..825d3ada9a0b 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -15,11 +15,15 @@ concurrency: jobs: build: - runs-on: ubuntu-latest + runs-on: linux-x86-n2-16 + container: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest strategy: matrix: python-version: [3.11] + env: + PYTHON: "python${{ matrix.python-version }}" + steps: - name: Checkout jax uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -31,18 +35,13 @@ jobs: ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02 submodules: 'true' path: 'array-api-tests' - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install uv~=0.5.30 - uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt + $PYTHON -m uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt - name: Run the test suite env: ARRAY_API_TESTS_MODULE: jax.numpy JAX_ENABLE_X64: 'true' run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt + $PYTHON -m pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 47f0ae0689b2..349ddf0d96a3 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -22,7 +22,8 @@ on: jobs: upstream-dev: - runs-on: ubuntu-latest + runs-on: linux-x86-n2-64 + container: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 permissions: contents: read issues: write # for failed-build-issue @@ -66,7 +67,7 @@ jobs: echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" - pytest -n 2 --tb=short --maxfail=20 tests examples + pytest -n auto --tb=short --maxfail=20 tests examples - name: Notify failed build uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0 if: failure() && github.event.pull_request == null From dc22eb40bf65cd92769f02a645e78606d31d7255 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Wed, 30 Apr 2025 10:47:07 -0700 Subject: [PATCH 0919/1769] Optionally plumb executable devices through deserialization (and use backend.devices()) by default). 
PiperOrigin-RevId: 753229392 --- jax/experimental/serialize_executable.py | 37 +++++++++++++++++++----- tests/aot_test.py | 14 +++++++++ tests/pgle_test.py | 3 +- 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 0a1f32af322c..6f5062d4ce99 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -21,6 +21,7 @@ import jax from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc +from typing import Sequence def serialize(compiled: jax.stages.Compiled): @@ -44,14 +45,27 @@ def serialize(compiled: jax.stages.Compiled): def deserialize_and_load(serialized, in_tree, out_tree, - backend: str | xc.Client | None = None): + backend: str | xc.Client | None = None, + execution_devices: Sequence[xc.Device] | None = None): """Constructs a jax.stages.Compiled from a serialized executable.""" if backend is None or isinstance(backend, str): backend = jax.devices(backend)[0].client + if execution_devices is None: + execution_devices = backend.devices() + else: + device_backend = execution_devices[0].client + if device_backend != backend: + raise ValueError( + 'Execution devices belong to a client other than `backend`. Got ' + f'backend client: {(backend.platform, backend.platform_version)} and ' + 'execution devices client: ' + f'{(device_backend.platform, device_backend.platform_version)}') + (unloaded_executable, args_info_flat, - no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load() + no_kwargs) = _JaxPjrtUnpickler( + io.BytesIO(serialized), backend, execution_devices).load() args_info = in_tree.unflatten(args_info_flat) @@ -78,19 +92,28 @@ def persistent_id(self, obj): class _JaxPjrtUnpickler(pickle.Unpickler): - def __init__(self, file, backend): + def __init__(self, file, backend, execution_devices=None): super().__init__(file) self.backend = backend - self.devices_by_id = {d.id: d for d in backend.devices()} + if execution_devices is None: + execution_devices = backend.devices() + else: + device_backend = execution_devices[0].client + if device_backend != backend: + raise ValueError( + 'Execution devices belong to a client other than `backend`. 
Got ' + f'backend client: {(backend.platform, backend.platform_version)} ' + 'and execution devices client: ' + f'{(device_backend.platform, device_backend.platform_version)}') + self.devices_by_id = {d.id: d for d in execution_devices} + self.execution_devices = xc.DeviceList(tuple(execution_devices)) def persistent_load(self, pid): if pid[0] == 'exec': if jaxlib_extension_version < 332: return self.backend.deserialize_executable(pid[1]) return self.backend.deserialize_executable( - pid[1], - executable_devices=xc.DeviceList(tuple(self.backend.devices())) - ) + pid[1], executable_devices=self.execution_devices) if pid[0] == 'device': return self.devices_by_id[pid[1]] if pid[0] == 'client': diff --git a/tests/aot_test.py b/tests/aot_test.py index daaeb8417d33..623c6aaed0cc 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -126,6 +126,20 @@ def my_function(x): hlo = lowered.as_text("hlo") self.assertNotRegex(hlo, r"sine.*metadata=.*source_file=.*") + @jtu.run_on_devices('gpu', 'tpu') + def test_mismatched_backends_raises(self): + @jax.jit + def f(x): + return x * 2 + + x = jnp.arange(1) + f_lowered = f.lower(x) + serialized, in_tree, out_tree = serialize(f_lowered.compile()) + with self.assertRaisesRegex( + ValueError, + 'Execution devices belong to a client other than `backend`'): + deserialize_and_load(serialized, in_tree, out_tree, backend='cpu', + execution_devices=jax.devices()[:1]) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 2787de4c6e17..8814250ea066 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -202,7 +202,8 @@ def f(x): f_lowered = f.lower(x) serialized, in_tree, out_tree = serialize(f_lowered.compile()) - compiled = deserialize_and_load(serialized, in_tree, out_tree) + compiled = deserialize_and_load( + serialized, in_tree, out_tree, execution_devices=jax.devices()[:1]) with config.pgle_profiling_runs(1), config.enable_pgle(True): # Run 1 From 9c19d99775a562c91f649843b812e00507021b46 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 30 Apr 2025 12:13:04 -0700 Subject: [PATCH 0920/1769] Add a bazel target for jax.experimental.shard. PiperOrigin-RevId: 753263386 --- jax/BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jax/BUILD b/jax/BUILD index d2320e1e4456..05aa6134628d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1171,6 +1171,12 @@ pytype_library( deps = [":jax"], ) +pytype_library( + name = "experimental_shard", + srcs = ["experimental/shard.py"], + deps = [":jax"], +) + pytype_library( name = "experimental_sparse", srcs = glob( From 614e975a71f347bff354483e2ab46c6b80beb0e6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 30 Apr 2025 12:35:24 -0700 Subject: [PATCH 0921/1769] Fix some flaky tests in CI. * Increase sharding of scipy_spatial_test on CPU and GPU. * Tag lax_scipy_special_functions_test as noasan on all platforms and nomsan on CPU PiperOrigin-RevId: 753271096 --- tests/BUILD | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 749819a5bc08..9af33da83e5d 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -585,17 +585,17 @@ jax_multiplatform_test( name = "lax_scipy_special_functions_test", srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { - "gpu": ["noasan"], # Times out. "cpu": [ - "noasan", - "notsan", - ], # Times out. + "nomsan", # Times out. + "notsan", # Times out. 
+ ], }, shard_count = { "cpu": 20, "gpu": 30, "tpu": 20, }, + tags = ["noasan"], # Times out under asan. deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) @@ -992,6 +992,10 @@ jax_multiplatform_test( jax_multiplatform_test( name = "scipy_spatial_test", srcs = ["scipy_spatial_test.py"], + shard_count = { + "cpu": 4, + "gpu": 4, + }, deps = py_deps("scipy"), ) From 88391ce07b32144e8231d007a9e866a06ac09eca Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 30 Apr 2025 13:00:24 -0700 Subject: [PATCH 0922/1769] A JAX callback that operates directly on device buffers. This experimental callback mechanism `buffer_callback` provides an interface for simple interoperability between JAX and Python libraries like PyTorch, Cupy, Warp, and even Numpy, that most effectively operate in-place on device. The existing JAX callback APIs (`pure_callback`, `io_callback`, etc.) require explicit copies on CPU, and expensive device-to-host (and back again) copies on GPU. The `buffer_callback` API instead transports buffers using a custom type that supports zero-copy access to the underlying device memory via the standard `__array__` (on CPU), `__cuda_array_interface__` (on GPU), or `__dlpack__` interfaces. The callback also accepts the outputs as a Pytree of mutable buffers so that no copies are required on return. An API like this is useful for prototyping kernels that are available (or easier to write!) in other frameworks. Related feature requests have come up in the past (e.g. https://github.com/jax-ml/jax/issues/20701), and some previous literature exists (e.g. https://github.com/jax-ml/jax/pull/21003, https://github.com/openxla/xla/pull/23243), but nothing has been implemented yet. We would like to eventually remove this API in favor of a more coherent mutable JAX Array interface, but it seems useful to land this as a temporary measure until something like that exists. **API** A callback wrapped using `buffer_callback` must have the following signature: ```python def callback( ctx: ExecutionContext, out: Pytree[Buffer], *args: Pytree[Buffer], **kwargs: Any, # static attributes ) -> None: ... ``` where the `ExecutionContext` provides access to metadata such as the XLA CUDA stream, and the Pytree of `Buffer`s in `out` are mutable. From JAX, this callback is executed using: ```python out = buffer_callback(callback, result_types)(*args, **kwargs) ``` where `result_types` is a Pytree of `ShapeDtypeStruct`s specifying the types of the outputs. 
PiperOrigin-RevId: 753279846 --- jax/BUILD | 22 ++ jax/_src/buffer_callback.py | 261 ++++++++++++++++++ jax/_src/lib/__init__.py | 4 +- jax/experimental/buffer_callback.py | 27 ++ jaxlib/BUILD | 54 ++++ jaxlib/_jax/__init__.pyi | 1 + jaxlib/_jax/ffi.pyi | 48 ++++ jaxlib/cuda/BUILD | 2 + jaxlib/cuda/cuda_plugin_extension.cc | 2 + jaxlib/dlpack.cc | 199 +------------- jaxlib/dlpack_support.cc | 223 ++++++++++++++++ jaxlib/dlpack_support.h | 30 +++ jaxlib/ffi.cc | 386 +++++++++++++++++++++++++++ jaxlib/ffi.h | 153 +++++++++++ jaxlib/gpu/py_client_gpu.cc | 24 ++ jaxlib/gpu/py_client_gpu.h | 2 +- jaxlib/jax.bzl | 1 + jaxlib/py_client_cpu.cc | 17 ++ jaxlib/rocm/BUILD | 2 + jaxlib/rocm/rocm_plugin_extension.cc | 2 + jaxlib/xla.cc | 2 + jaxlib/xla_client.py | 2 +- tests/BUILD | 12 + tests/buffer_callback_test.py | 180 +++++++++++++ 24 files changed, 1455 insertions(+), 201 deletions(-) create mode 100644 jax/_src/buffer_callback.py create mode 100644 jax/experimental/buffer_callback.py create mode 100644 jaxlib/_jax/ffi.pyi create mode 100644 jaxlib/dlpack_support.cc create mode 100644 jaxlib/dlpack_support.h create mode 100644 jaxlib/ffi.cc create mode 100644 jaxlib/ffi.h create mode 100644 tests/buffer_callback_test.py diff --git a/jax/BUILD b/jax/BUILD index 05aa6134628d..dae375938fdd 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -18,6 +18,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", + "buffer_callback_internal_users", "if_building_jaxlib", "jax_export_file_visibility", "jax_extend_internal_users", @@ -134,6 +135,12 @@ package_group( packages = serialize_executable_internal_users, ) +package_group( + name = "buffer_callback_users", + includes = [":internal"], + packages = buffer_callback_internal_users, +) + # JAX-private test utilities. py_library( # This build target is required in order to use private test utilities in jax._src.test_util, @@ -230,6 +237,7 @@ py_library_providing_imports_info( "_src/api.py", "_src/array.py", "_src/blocked_sampler.py", + "_src/buffer_callback.py", "_src/callback.py", "_src/checkify.py", "_src/custom_batching.py", @@ -1155,6 +1163,9 @@ py_library_providing_imports_info( "experimental/*.py", "example_libraries/*.py", ], + [ + "experimental/buffer_callback.py", + ], ), visibility = ["//visibility:public"], deps = [ @@ -1325,3 +1336,14 @@ pytype_library( "//jax/extend:ifrt_programs", ] + py_deps("numpy") + py_deps("cloudpickle"), ) + +pytype_library( + name = "experimental_buffer_callback", + srcs = [ + "experimental/buffer_callback.py", + ], + visibility = [":buffer_callback_users"], + deps = [ + ":jax", + ], +) diff --git a/jax/_src/buffer_callback.py b/jax/_src/buffer_callback.py new file mode 100644 index 000000000000..916ac57da2ad --- /dev/null +++ b/jax/_src/buffer_callback.py @@ -0,0 +1,261 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections.abc import Callable, Sequence +import functools +from typing import Any + +import numpy as np + +from jax._src import core +from jax._src import dispatch +from jax._src import effects +from jax._src import ffi +from jax._src import tree_util +from jax._src import util +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.lib import jaxlib_extension_version + +export = util.set_module("jax.experimental.buffer_callback") + +if jaxlib_extension_version >= 334: + from jax._src.lib import ffi as ffi_lib + + Buffer = export(ffi_lib.Buffer) + ExecutionStage = export(ffi_lib.ExecutionStage) + ExecutionContext = export(ffi_lib.ExecutionContext) + + +def buffer_callback( + callback: Callable[..., None], + result_shape_dtypes: object, + *, + has_side_effect: bool = False, + vmap_method: str | None = None, + input_output_aliases: dict[int, int] | None = None, +): + """An experimental callback that operates in place on device buffers. + + Only supported on CPU and GPU backends. + + Note that the plan is for this to eventually be replaced by a consolidated + callback API built using JAX mutable arrays, but for now this provides a + mechanism for prototyping computational kernels using other Python libraries + including Numpy, PyTorch, Cupy, and others. + + Let's start with a simple example: + + >>> def py_add_one_inplace(ctx, out, x): + ... np.asarray(out)[...] = np.asarray(x) + 1 + ... + >>> x = jnp.array(41, dtype=jnp.int32) + >>> out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + >>> add_one = buffer_callback(py_add_one_inplace, out_type) + >>> add_one(x) # doctest: +SKIP + Array(42, dtype=int32) + + In this example, we're executing a numpy computation via JAX, and this could + have been implemented using :func:`jax.pure_callback`, but in this case, the + output is being populated in-place. This means that JAX doesn't need to copy + the output arrays upon returning from the callback. Note that even though the + callback function operates on mutable buffers, JAX still sees this as an + operation that consumes and produces regular immutable JAX arrays. + + Unlike the other JAX callback APIs, ``buffer_callback`` requires that the + user-defined Python function have the following signature: + + .. code-block:: python + + def callback(ctx: ExecutionContext, out, *args) -> None: + ... + + where ``ctx`` is an instance of + :class:`~jax.experimental.buffer_callback.ExecutionContext`, which mainly + provides access to XLA's computation stream when running on GPU, ``out`` is a + pytree of mutable :class:`~jax.experimental.buffer_callback.Buffer` objects, + and the ``args`` arguments have the same pytree structure as the inputs, but + each leaf is :class:`~jax.experimental.buffer_callback.Buffer`. This callback + should not return any values, and it should overwrite the ``out`` buffers in + place to output values back to JAX. + + It's important to note that this Python function can't really be called + except via ```buffer_callback`` itself, because it's not (yet!) possible to + construct mutable JAX buffers directly in Python. + + The bespoke :class:`~jax.experimental.buffer_callback.Buffer` type is an + array-like object that supports the ``__array__`` protocol on CPU, the + ``__cuda_array_interface__`` protocol on GPU, and the ``__dlpack__`` protocol + on both CPU and GPU. + + Args: + callback: A Python function with the signature and behavior described above. 
+ result_shape_dtypes: A pytree whose leaves have ``shape`` and ``dtype`` + attributes, with a structure that matches the expected output of the + callback function at runtime. :class:`jax.ShapeDtypeStruct` is often used + to define leaf values. + has_side_effect: Whether the callback has side effects. + vmap_method: A string specifying how the callback transforms under + :func:`~jax.vmap` as described in the docs for :func:`~jax.pure_callback`. + input_output_aliases: a dictionary mapping the index of some inputs to + the index of the output that aliases them. These indices are in the + flattened inputs and outputs. + + Returns: + A new callable that accepts :class:`jax.Array` inputs (and pytrees thereof), + and pytree of :class:`jax.Array` objects whose structure matches that + of ``result_shape_dtypes``. + + See Also: + - :func:`jax.pure_callback`: callback designed for pure host functions. + - :func:`jax.experimental.io_callback`: callback designed for impure host + functions. + - :func:`jax.debug.callback`: callback designed for general-purpose + debugging. + - :func:`jax.debug.print`: callback designed for printing. + """ + flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) + flat_result_avals = tuple( + core.ShapedArray(x.shape, x.dtype) for x in flat_shape_dtypes + ) + + def wrapped_callback(*args, **kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + + in_avals = [core.get_aval(x) for x in flat_args] + static_input_output_aliases: tuple[tuple[int, int], ...] = () + if input_output_aliases is not None: + for i_idx, o_idx in sorted(input_output_aliases.items()): + i_idx, o_idx = int(i_idx), int(o_idx) + if i_idx >= len(args): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with input index {i_idx} outside the range [0, " + f"{len(args)}).") + if o_idx >= len(flat_result_avals): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with output index {o_idx} outside the range [0, " + f"{len(flat_result_avals)}).") + in_aval = in_avals[i_idx] + out_aval = flat_result_avals[o_idx] + if not ffi._check_compatible_avals(in_aval, out_aval): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to an input with abstract value {in_aval} and an " + f"output with a different abstract value {out_aval}.") + static_input_output_aliases += ((i_idx, o_idx),) + + out_flat = buffer_callback_p.bind( + *flat_args, + callback=callback, + result_avals=flat_result_avals, + in_tree=in_tree, + out_tree=out_tree, + vmap_method=vmap_method, + has_side_effect=has_side_effect, + input_output_aliases=static_input_output_aliases, + ) + return tree_util.tree_unflatten(out_tree, out_flat) + + return wrapped_callback + + +buffer_callback_p = core.Primitive("buffer_callback") +buffer_callback_p.multiple_results = True +dispatch.prim_requires_devices_during_lowering.add(buffer_callback_p) +dispatch.simple_impl(buffer_callback_p) + + +class BufferCallbackEffect(effects.Effect): + def __str__(self): + return "BufferCallback" + +_BufferCallbackEffect = BufferCallbackEffect() +effects.lowerable_effects.add_type(BufferCallbackEffect) +effects.control_flow_allowed_effects.add_type(BufferCallbackEffect) + + +@buffer_callback_p.def_effectful_abstract_eval +def _buffer_callback_abstract_eval( + *args, + result_avals: tuple[core.ShapedArray, ...], + has_side_effect: bool, + **_, +): + del args + effects = {_BufferCallbackEffect} if has_side_effect else 
core.no_effects + return result_avals, effects + + +def _buffer_callback_jvp_rule(*args, **kwargs): + del args, kwargs + raise ValueError( + "Buffer callbacks do not support JVP. " + "Please use `jax.custom_jvp` to use callbacks while taking gradients.") +ad.primitive_jvps[buffer_callback_p] = _buffer_callback_jvp_rule + + +def _buffer_callback_transpose_rule(*args, **kwargs): + del args, kwargs + raise ValueError( + "Buffer callbacks do not support transpose. " + "Please use `jax.custom_vjp` to use callbacks while taking gradients.") +ad.primitive_transposes[buffer_callback_p] = _buffer_callback_transpose_rule + +batching.primitive_batchers[buffer_callback_p] = functools.partial( + ffi.ffi_batching_rule, buffer_callback_p +) + + +def _buffer_callback_lowering( + ctx: mlir.LoweringRuleContext, + *args: Any, + callback, + in_tree: Any, + out_tree: Any, + has_side_effect: bool, + input_output_aliases: Sequence[tuple[int, int]], + **_, +): + + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for buffer_callback") + platform = ctx.module_context.platforms[0] + target_name = { + "cpu": "xla_buffer_python_cpu_callback", + "cuda": "xla_buffer_python_gpu_callback", + "rocm": "xla_buffer_python_gpu_callback", + }.get(platform) + if target_name is None: + raise ValueError(f"`buffer_callback` not supported on {platform} backend.") + + def wrapped_callback(exec_ctx, *args: Any): + args_in, args_out = util.split_list(args, [in_tree.num_leaves]) + py_args_in, py_kwargs_in = tree_util.tree_unflatten(in_tree, args_in) + py_args_out = tree_util.tree_unflatten(out_tree, args_out) + if callback(exec_ctx, py_args_out, *py_args_in, **py_kwargs_in) is not None: + raise ValueError("buffer_callback callback must not return any values.") + return () + + ctx.module_context.add_host_callback(wrapped_callback) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + rule = ffi.ffi_lowering( + target_name, + has_side_effect=has_side_effect, + operand_output_aliases=dict(input_output_aliases), + ) + return rule(ctx, *args, index=index) +mlir.register_lowering(buffer_callback_p, _buffer_callback_lowering) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 49aeff9c3763..e3718d57ccd6 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -103,7 +103,6 @@ def _parse_version(v: str) -> tuple[int, ...]: from jaxlib.xla_extension import Device as Device # type: ignore # pytype: disable=import-error # noqa: F401 from jaxlib.xla_extension import profiler as _profiler # type: ignore # pytype: disable=import-error # noqa: F401 - import jaxlib.xla_client as xla_client # noqa: F401 # Jaxlib code is split between the Jax and the XLA repositories. @@ -113,6 +112,9 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_extension_version: int = getattr(xla_client, '_version', 0) ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) +if jaxlib_extension_version >= 334: + from jaxlib._jax import ffi as ffi # noqa: F401 + import jaxlib.weakref_lru_cache as weakref_lru_cache # noqa: F401 # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 diff --git a/jax/experimental/buffer_callback.py b/jax/experimental/buffer_callback.py new file mode 100644 index 000000000000..6c8514340af0 --- /dev/null +++ b/jax/experimental/buffer_callback.py @@ -0,0 +1,27 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.lib import jaxlib_extension_version as _jaxlib_extension_version + +if _jaxlib_extension_version >= 334: + from jax._src.buffer_callback import ( + Buffer as Buffer, + ExecutionContext as ExecutionContext, + ExecutionStage as ExecutionStage, + buffer_callback as buffer_callback, + ) + +from jax._src.buffer_callback import buffer_callback as buffer_callback + +del _jaxlib_extension_version diff --git a/jaxlib/BUILD b/jaxlib/BUILD index e84203322f42..29f274ad9a03 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -286,6 +286,7 @@ nanobind_pywrap_extension( ":config", ":custom_call_sharding", ":dlpack", + ":ffi", ":guard_lib", ":ifrt_proxy", ":jax_jit", @@ -473,6 +474,24 @@ cc_library( ], ) +cc_library( + name = "dlpack_support", + srcs = ["dlpack_support.cc"], + hdrs = ["dlpack_support.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@dlpack", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + ], +) + cc_library( name = "dlpack", srcs = ["dlpack.cc"], @@ -484,6 +503,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":dlpack_support", ":nb_class_ptr", ":py_client", ":python_ref_manager", @@ -515,6 +535,38 @@ cc_library( ], ) +cc_library( + name = "ffi", + srcs = ["ffi.cc"], + hdrs = ["ffi.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack_support", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:statusor", + ], +) + cc_library( name = "guard_lib", srcs = ["guard_lib.cc"], @@ -868,6 +920,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":ffi", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", @@ -875,6 +928,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@dlpack", "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 00a3a5d01fec..49e430069744 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -25,6 +25,7 @@ from typing import Any, ClassVar, TypeVar, overload import numpy as np from . import config as config +from . import ffi as ffi from . import guard_lib as guard_lib from . import ifrt_programs as ifrt_programs from . 
import ifrt_proxy as ifrt_proxy diff --git a/jaxlib/_jax/ffi.pyi b/jaxlib/_jax/ffi.pyi new file mode 100644 index 000000000000..efaad46329f9 --- /dev/null +++ b/jaxlib/_jax/ffi.pyi @@ -0,0 +1,48 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +from typing import Any + +class Buffer: + @property + def dtype(self) -> Any: ... + @property + def ndim(self) -> int: ... + @property + def shape(self) -> tuple[int, ...]: ... + @property + def writeable(self) -> bool: ... + def __array__(self, dtype: Any = None, copy: bool | None = None) -> Any: ... + def __cuda_array_interface__(self) -> Any: ... + def __dlpack__( + self, + stream: Any = None, + max_version: Any = None, + dl_device: Any = None, + copy: Any = None, + ) -> Any: ... + def __dlpack_device__(self) -> tuple[int, int]: ... + +class ExecutionStage(enum.IntEnum): + INSTANTIATE = ... + PREPARE = ... + INITIALIZE = ... + EXECUTE = ... + +class ExecutionContext: + def stage(self) -> ExecutionStage: ... + def device_ordinal(self) -> int: ... + def stream(self) -> int: ... diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index be7ac6116d2f..3895b067c87f 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -623,6 +623,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", + "//jaxlib:ffi", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", @@ -633,6 +634,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@dlpack", "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index e753025ac714..d7500b711e48 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -51,6 +51,8 @@ nb::dict FfiRegistrations() { jax::EncapsulateFfiHandler(jax::cuda::kXlaFfiPythonGpuCallback); dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + dict["xla_buffer_python_gpu_callback"] = + jax::EncapsulateFfiHandler(jax::cuda::kXlaBufferPythonGpuCallback); return dict; } diff --git a/jaxlib/dlpack.cc b/jaxlib/dlpack.cc index ca11f665550f..c58eac81b9a7 100644 --- a/jaxlib/dlpack.cc +++ b/jaxlib/dlpack.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "jaxlib/python_ref_manager.h" #include "jaxlib/traceback.h" #include "jaxlib/util.h" +#include "jaxlib/dlpack_support.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" @@ -94,204 +95,6 @@ void DLPackTensorDeleter(DLManagedTensor* t) { } } -absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { - switch (type) { - case S8: - return DLDataType{kDLInt, 8, 1}; - case S16: - return DLDataType{kDLInt, 16, 1}; - case S32: - return DLDataType{kDLInt, 32, 1}; - case S64: - return DLDataType{kDLInt, 64, 1}; - case U8: - return DLDataType{kDLUInt, 8, 1}; - case U16: - return DLDataType{kDLUInt, 16, 1}; - case U32: - return DLDataType{kDLUInt, 32, 1}; - case U64: - return DLDataType{kDLUInt, 64, 1}; - case F4E2M1FN: - return DLDataType{kDLFloat4_e2m1fn, 4, 1}; - case F8E3M4: - return DLDataType{kDLFloat8_e3m4, 8, 1}; - case F8E4M3: - return DLDataType{kDLFloat8_e4m3, 8, 1}; - case F8E4M3B11FNUZ: - return DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}; - case F8E4M3FN: - return DLDataType{kDLFloat8_e4m3fn, 8, 1}; - case F8E4M3FNUZ: - return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; - case F8E5M2: - return DLDataType{kDLFloat8_e5m2, 8, 1}; - case F8E5M2FNUZ: - return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; - case F8E8M0FNU: - return DLDataType{kDLFloat8_e8m0fnu, 8, 1}; - case BF16: - return DLDataType{kDLBfloat, 16, 1}; - case F16: - return DLDataType{kDLFloat, 16, 1}; - case F32: - return DLDataType{kDLFloat, 32, 1}; - case F64: - return DLDataType{kDLFloat, 64, 1}; - case PRED: - return DLDataType{kDLBool, 8, 1}; - case C64: - return DLDataType{kDLComplex, 64, 1}; - case C128: - return DLDataType{kDLComplex, 128, 1}; - default: - return Unimplemented("XLA type %s has no DLPack equivalent", - PrimitiveType_Name(type)); - } -} - -absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { - if (type.lanes != 1) { - return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", - type.lanes); - } - switch (type.code) { - case kDLBool: - switch (type.bits) { - case 8: - return PRED; - default: - return Unimplemented( - "Only 8-bit DLPack booleans are supported, got %d bits", - type.bits); - } - case kDLInt: - switch (type.bits) { - case 8: - return S8; - case 16: - return S16; - case 32: - return S32; - case 64: - return S64; - default: - return Unimplemented( - "Invalid or unsupported DLPack integer width: %d bits", - type.bits); - } - case kDLUInt: - switch (type.bits) { - case 8: - return U8; - case 16: - return U16; - case 32: - return U32; - case 64: - return U64; - default: - return Unimplemented( - "Invalid or unsupported DLPack unsigned integer width: %d bits", - type.bits); - } - case kDLFloat4_e2m1fn: - if (type.bits == 4) { - return F4E2M1FN; - } - return Unimplemented( - "Invalid or unsupported DLPack float4_e2m1fn width: %d bits", - type.bits); - case kDLFloat8_e3m4: - if (type.bits == 8) { - return F8E3M4; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e3m4 width: %d bits", - type.bits); - case kDLFloat8_e4m3: - if (type.bits == 8) { - return F8E4M3; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e4m3 width: %d bits", - type.bits); - case kDLFloat8_e4m3b11fnuz: - if (type.bits == 8) { - return F8E4M3B11FNUZ; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e4m3b11fnuz width: %d bits", - type.bits); - case kDLFloat8_e4m3fn: - if (type.bits == 8) { - return F8E4M3FN; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e4m3fn width: %d bits", - type.bits); - 
case kDLFloat8_e4m3fnuz: - if (type.bits == 8) { - return F8E4M3FNUZ; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e4m3fnuz width: %d bits", - type.bits); - case kDLFloat8_e5m2: - if (type.bits == 8) { - return F8E5M2; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e5m2 width: %d bits", - type.bits); - case kDLFloat8_e5m2fnuz: - if (type.bits == 8) { - return F8E5M2FNUZ; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e5m2fnuz width: %d bits", - type.bits); - case kDLFloat8_e8m0fnu: - if (type.bits == 8) { - return F8E8M0FNU; - } - return Unimplemented( - "Invalid or unsupported DLPack float8_e8m0fnu width: %d bits", - type.bits); - case kDLBfloat: - if (type.bits == 16) { - return BF16; - } - return Unimplemented( - "Invalid or unsupported DLPack bfloat width: %d bits", type.bits); - case kDLFloat: - switch (type.bits) { - case 16: - return F16; - case 32: - return F32; - case 64: - return F64; - default: - return Unimplemented( - "Invalid or unsupported DLPack float width: %d bits", type.bits); - } - case kDLComplex: - switch (type.bits) { - case 64: - return C64; - case 128: - return C128; - default: - return Unimplemented( - "Invalid or unsupported DLPack complex width: %d bits", - type.bits); - } - default: - return Unimplemented("Unknown or invalid DLPack type code %d", type.code); - } -} - absl::StatusOr> StridesToLayout( absl::Span dims, absl::Span strides) { CHECK_EQ(dims.size(), strides.size()); diff --git a/jaxlib/dlpack_support.cc b/jaxlib/dlpack_support.cc new file mode 100644 index 000000000000..9e851842ed14 --- /dev/null +++ b/jaxlib/dlpack_support.cc @@ -0,0 +1,223 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/dlpack_support.h" + +#include "absl/status/statusor.h" +#include "include/dlpack/dlpack.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { + switch (type) { + case S8: + return DLDataType{kDLInt, 8, 1}; + case S16: + return DLDataType{kDLInt, 16, 1}; + case S32: + return DLDataType{kDLInt, 32, 1}; + case S64: + return DLDataType{kDLInt, 64, 1}; + case U8: + return DLDataType{kDLUInt, 8, 1}; + case U16: + return DLDataType{kDLUInt, 16, 1}; + case U32: + return DLDataType{kDLUInt, 32, 1}; + case U64: + return DLDataType{kDLUInt, 64, 1}; + case F4E2M1FN: + return DLDataType{kDLFloat4_e2m1fn, 4, 1}; + case F8E3M4: + return DLDataType{kDLFloat8_e3m4, 8, 1}; + case F8E4M3: + return DLDataType{kDLFloat8_e4m3, 8, 1}; + case F8E4M3B11FNUZ: + return DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}; + case F8E4M3FN: + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + case F8E4M3FNUZ: + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + case F8E5M2: + return DLDataType{kDLFloat8_e5m2, 8, 1}; + case F8E5M2FNUZ: + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + case F8E8M0FNU: + return DLDataType{kDLFloat8_e8m0fnu, 8, 1}; + case BF16: + return DLDataType{kDLBfloat, 16, 1}; + case F16: + return DLDataType{kDLFloat, 16, 1}; + case F32: + return DLDataType{kDLFloat, 32, 1}; + case F64: + return DLDataType{kDLFloat, 64, 1}; + case PRED: + return DLDataType{kDLBool, 8, 1}; + case C64: + return DLDataType{kDLComplex, 64, 1}; + case C128: + return DLDataType{kDLComplex, 128, 1}; + default: + return Unimplemented("XLA type %s has no DLPack equivalent", + PrimitiveType_Name(type)); + } +} + +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", + type.lanes); + } + switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return PRED; + default: + return Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return S8; + case 16: + return S16; + case 32: + return S32; + case 64: + return S64; + default: + return Unimplemented( + "Invalid or unsupported DLPack integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return U8; + case 16: + return U16; + case 32: + return U32; + case 64: + return U64; + default: + return Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat4_e2m1fn: + if (type.bits == 4) { + return F4E2M1FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float4_e2m1fn width: %d bits", + type.bits); + case kDLFloat8_e3m4: + if (type.bits == 8) { + return F8E3M4; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e3m4 width: %d bits", + type.bits); + case kDLFloat8_e4m3: + if (type.bits == 8) { + return F8E4M3; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3 width: %d bits", + type.bits); + case kDLFloat8_e4m3b11fnuz: + if (type.bits == 8) { + return F8E4M3B11FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3b11fnuz width: %d bits", + type.bits); + case kDLFloat8_e4m3fn: + if (type.bits == 8) { + return F8E4M3FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fn width: %d bits", + type.bits); + case kDLFloat8_e4m3fnuz: 
+ if (type.bits == 8) { + return F8E4M3FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fnuz width: %d bits", + type.bits); + case kDLFloat8_e5m2: + if (type.bits == 8) { + return F8E5M2; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2 width: %d bits", + type.bits); + case kDLFloat8_e5m2fnuz: + if (type.bits == 8) { + return F8E5M2FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2fnuz width: %d bits", + type.bits); + case kDLFloat8_e8m0fnu: + if (type.bits == 8) { + return F8E8M0FNU; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e8m0fnu width: %d bits", + type.bits); + case kDLBfloat: + if (type.bits == 16) { + return BF16; + } + return Unimplemented( + "Invalid or unsupported DLPack bfloat width: %d bits", type.bits); + case kDLFloat: + switch (type.bits) { + case 16: + return F16; + case 32: + return F32; + case 64: + return F64; + default: + return Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return C64; + case 128: + return C128; + default: + return Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return Unimplemented("Unknown or invalid DLPack type code %d", type.code); + } +} + +} // namespace xla diff --git a/jaxlib/dlpack_support.h b/jaxlib/dlpack_support.h new file mode 100644 index 000000000000..25e862353bab --- /dev/null +++ b/jaxlib/dlpack_support.h @@ -0,0 +1,30 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_DLPACK_SUPPORT_H_ +#define JAXLIB_XLA_DLPACK_SUPPORT_H_ + +#include "absl/status/statusor.h" +#include "include/dlpack/dlpack.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type); +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type); + +} // namespace xla + +#endif // JAXLIB_XLA_DLPACK_SUPPORT_H_ diff --git a/jaxlib/ffi.cc b/jaxlib/ffi.cc new file mode 100644 index 000000000000..1bf9a5a3150a --- /dev/null +++ b/jaxlib/ffi.cc @@ -0,0 +1,386 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/ffi.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "jaxlib/dlpack_support.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +namespace { +const char* const kDlTensorCapsuleName = "dltensor"; +const char* const kDlTensorVersionedCapsuleName = "dltensor_versioned"; + +template +struct DLPackTensor { + std::vector shape; + ManagedTensor tensor; +}; + +template +void DLPackTensorDeleter(ManagedTensor* t) { + if (t) { + delete static_cast*>(t->manager_ctx); + } +} + +xla::PrimitiveType PrimitiveTypeForFfiDataType(ffi::DataType dtype) { + switch (dtype) { + case ffi::DataType::INVALID: + return xla::PrimitiveType::PRIMITIVE_TYPE_INVALID; + case ffi::PRED: + return xla::PrimitiveType::PRED; + case ffi::S1: + return xla::PrimitiveType::S1; + case ffi::S2: + return xla::PrimitiveType::S2; + case ffi::S4: + return xla::PrimitiveType::S4; + case ffi::S8: + return xla::PrimitiveType::S8; + case ffi::S16: + return xla::PrimitiveType::S16; + case ffi::S32: + return xla::PrimitiveType::S32; + case ffi::S64: + return xla::PrimitiveType::S64; + case ffi::U1: + return xla::PrimitiveType::U1; + case ffi::U2: + return xla::PrimitiveType::U2; + case ffi::U4: + return xla::PrimitiveType::U4; + case ffi::U8: + return xla::PrimitiveType::U8; + case ffi::U16: + return xla::PrimitiveType::U16; + case ffi::U32: + return xla::PrimitiveType::U32; + case ffi::U64: + return xla::PrimitiveType::U64; + case ffi::F16: + return xla::PrimitiveType::F16; + case ffi::F32: + return xla::PrimitiveType::F32; + case ffi::F64: + return xla::PrimitiveType::F64; + case ffi::BF16: + return xla::PrimitiveType::BF16; + case ffi::C64: + return xla::PrimitiveType::C64; + case ffi::C128: + return xla::PrimitiveType::C128; + case ffi::TOKEN: + return xla::PrimitiveType::TOKEN; + case ffi::F8E5M2: + return xla::PrimitiveType::F8E5M2; + case ffi::F8E4M3: + return xla::PrimitiveType::F8E4M3; + case ffi::F8E4M3FN: + return xla::PrimitiveType::F8E4M3FN; + case ffi::F8E4M3B11FNUZ: + return xla::PrimitiveType::F8E4M3B11FNUZ; + case ffi::F8E5M2FNUZ: + return xla::PrimitiveType::F8E5M2FNUZ; + case ffi::F8E4M3FNUZ: + return xla::PrimitiveType::F8E4M3FNUZ; + case ffi::F8E3M4: + return xla::PrimitiveType::F8E3M4; + case ffi::F4E2M1FN: + return xla::PrimitiveType::F4E2M1FN; + case ffi::F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; + } +} +} // namespace + +PyFfiContext::PyFfiContext(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage) + : api_(api), ctx_(ctx), stage_(stage) {} + +PyFfiContext::Stage PyFfiContext::stage() const { + return static_cast(stage_); +} + +absl::StatusOr PyFfiContext::stream() const { + XLA_FFI_Stream_Get_Args args; + args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.ctx = ctx_; + args.stream = nullptr; + if (XLA_FFI_Error* error = api_->XLA_FFI_Stream_Get(&args)) { + return ffi::TakeStatus(error); + } + return 
absl::bit_cast(args.stream); +} + +absl::StatusOr PyFfiContext::device_ordinal() const { + XLA_FFI_DeviceOrdinal_Get_Args args; + args.struct_size = XLA_FFI_DeviceOrdinal_Get_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.device_ordinal = 0; + if (XLA_FFI_Error* error = api_->XLA_FFI_DeviceOrdinal_Get(&args)) { + return ffi::TakeStatus(error); + } + return args.device_ordinal; +} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + void* data, ffi::Span dimensions, + ffi::DataType element_type, bool writeable) + : device_type_(device_type), + device_ordinal_(device_ordinal), + data_(data), + dimensions_(dimensions.begin(), dimensions.size()), + element_type_(PrimitiveTypeForFfiDataType(element_type)), + writeable_(writeable) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf.untyped_data(), + buf.dimensions(), buf.element_type(), + /*writeable=*/false) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf->untyped_data(), + buf->dimensions(), buf->element_type(), + /*writeable=*/true) {} + +absl::StatusOr PyFfiAnyBuffer::dtype() const { + return xla::PrimitiveTypeToNbDtype(element_type_); +} + +size_t PyFfiAnyBuffer::ndim() const { return dimensions_.size(); } + +nb::tuple PyFfiAnyBuffer::shape() const { + return xla::SpanToNbTuple(dimensions_); +} + +bool PyFfiAnyBuffer::writeable() const { return writeable_; } + +absl::StatusOr PyFfiAnyBuffer::NumpyArray() const { + if (device_type_ != kDLCPU) { + return absl::UnimplementedError( + "Buffer.__array__ is only supported on CPU."); + } + + TF_ASSIGN_OR_RETURN(auto dtype, this->dtype()); + xla::nb_numpy_ndarray array(dtype, dimensions_, /* strides= */ std::nullopt, + data_, nb::cast(this)); + + // TODO(danfm): We don't seem to be allowed to set this flag like this + // because the array doesn't own its data. + // array.attr("flags").attr("writeable") = nb::bool_(writeable_); + + return array; +} + +absl::StatusOr PyFfiAnyBuffer::CudaArrayInterface() const { + if (device_type_ != kDLCUDA) { + return absl::UnimplementedError( + "Buffer.__cuda_array_interface__ is only supported on CUDA."); + } + + nb::dict result; + result["shape"] = xla::SpanToNbTuple(dimensions_); + TF_ASSIGN_OR_RETURN(result["typestr"], + TypeDescriptorForPrimitiveType(element_type_)); + result["data"] = nb::make_tuple( + nb::int_(absl::bit_cast(data_)), !writeable_); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr PyFfiAnyBuffer::DLPack() const { + auto pack = std::make_unique>(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + + DLTensor& dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. 
+ nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +absl::StatusOr PyFfiAnyBuffer::DLPackVersioned() const { + auto pack = std::make_unique>(); + pack->tensor.version = + DLPackVersion{DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION}; + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + pack->tensor.flags = writeable_ ? 0 : DLPACK_FLAG_BITMASK_READ_ONLY; + + DLTensor& dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal(PyCapsule_New( + &pack.release()->tensor, kDlTensorVersionedCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensorVersioned* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorVersionedCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +nb::tuple PyFfiAnyBuffer::DLPackDevice() const { + return nb::make_tuple(static_cast(device_type_), device_ordinal_); +} + +void BuildFfiSubmodule(nb::module_& m) { + tsl::ImportNumpy(); + + nb::module_ ffi_module = + m.def_submodule("ffi", "Python bindings for the XLA FFI."); + + nb::class_ buffer(ffi_module, "Buffer"); + buffer.def_prop_ro("dtype", xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::dtype)); + buffer.def_prop_ro("ndim", &PyFfiAnyBuffer::ndim); + buffer.def_prop_ro("shape", &PyFfiAnyBuffer::shape); + buffer.def_prop_ro("writeable", &PyFfiAnyBuffer::writeable); + buffer.def( + "__array__", + [](PyFfiAnyBuffer self, nb::object dtype, nb::object copy) { + if (!dtype.is_none()) { + throw nb::value_error( + "dtype parameter is not supported by Buffer.__array__."); + } + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__array__ with copy=True is not supported."); + } + return xla::ValueOrThrow(self.NumpyArray()); + }, + nb::arg("dtype") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def_prop_ro( + "__cuda_array_interface__", + xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::CudaArrayInterface)); + buffer.def( + "__dlpack__", + [](PyFfiAnyBuffer self, nb::object stream, nb::object max_version, + nb::object dl_device, nb::object copy) { + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__dlpack__ with copy=True is not supported."); + } + + // Fall back on the non-versioned API if unsupported by the requested + // max_version. 
+ nb::tuple max_version_tuple; + int64_t max_version_major; + if (!nb::try_cast(max_version, max_version_tuple) || + max_version_tuple.size() < 2 || + !nb::try_cast(max_version_tuple[0], max_version_major) || + max_version_major < 1) { + return xla::ValueOrThrow(self.DLPack()); + } + + // TODO(danfm): Handle other optional inputs. + return xla::ValueOrThrow(self.DLPackVersioned()); + }, + nb::arg("stream") = nb::none(), nb::arg("max_version") = nb::none(), + nb::arg("dl_device") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def("__dlpack_device__", &PyFfiAnyBuffer::DLPackDevice); + + nb::enum_(ffi_module, "ExecutionStage") + .value("INSTANTIATE", PyFfiContext::Stage::kInstantiate) + .value("PREPARE", PyFfiContext::Stage::kPrepare) + .value("INITIALIZE", PyFfiContext::Stage::kInitialize) + .value("EXECUTE", PyFfiContext::Stage::kExecute) + .export_values(); + + nb::class_ context(ffi_module, "ExecutionContext"); + context.def_prop_ro("stage", &PyFfiContext::stage); + context.def_prop_ro("device_ordinal", + xla::ValueOrThrowWrapper(&PyFfiContext::device_ordinal)); + context.def_prop_ro("stream", + xla::ValueOrThrowWrapper(&PyFfiContext::stream)); +} + +} // namespace jax diff --git a/jaxlib/ffi.h b/jaxlib/ffi.h new file mode 100644 index 000000000000..393ce79ddabc --- /dev/null +++ b/jaxlib/ffi.h @@ -0,0 +1,153 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_FFI_H_ +#define JAXLIB_XLA_FFI_H_ + +#include + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +// Wrapper class for XLA FFI execution context. +// +// This class provides a Python interface to the XLA FFI execution context, +// exposing metadata such as the execution stage, device ordinal, and stream. +class PyFfiContext { + public: + enum class Stage { + kInstantiate, + kPrepare, + kInitialize, + kExecute, + }; + + PyFfiContext(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage); + Stage stage() const; + absl::StatusOr stream() const; + absl::StatusOr device_ordinal() const; + + private: + const XLA_FFI_Api* api_; + XLA_FFI_ExecutionContext* ctx_; + XLA_FFI_ExecutionStage stage_; +}; + +// Wrapper class for XLA FFI AnyBuffer. +// +// This class provides a Python interface to the XLA FFI `AnyBuffer` class. +// From Python, this object looks like an array (with `.dtype` and `.shape` +// attributes), but it also provides methods zero-copy conversions to standard +// transport formats: `__array__`, `__cuda_array_interface__`, and `__dlpack__`. 
+class PyFfiAnyBuffer { + public: + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, void* data, + ffi::Span dimensions, + ffi::DataType element_type, bool writeable); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf); + + absl::StatusOr dtype() const; + size_t ndim() const; + nb::tuple shape() const; + bool writeable() const; + + absl::StatusOr NumpyArray() const; + absl::StatusOr CudaArrayInterface() const; + absl::StatusOr DLPack() const; + absl::StatusOr DLPackVersioned() const; + nb::tuple DLPackDevice() const; + + private: + DLDeviceType device_type_; + int32_t device_ordinal_; + void* data_; + absl::Span dimensions_; + xla::PrimitiveType element_type_; + bool writeable_; +}; + +template +ffi::Error XlaBufferCallback(int32_t device_ordinal, const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + xla::FfiLoadedHostCallbacks* callbacks, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = + nb::steal(PyTuple_New(1 + args.size() + rets.size())); + + jax::PyFfiContext py_ctx(api, ctx, XLA_FFI_ExecutionStage_EXECUTE); + PyTuple_SET_ITEM(nb_args.ptr(), 0, nb::cast(py_ctx).release().ptr()); + + size_t offset = 1; + for (size_t i = 0; i < args.size(); ++i, ++offset) { + auto arg = args.get(i); + if (arg.has_error()) { + return arg.error(); + } + jax::PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, arg.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + for (size_t i = 0; i < rets.size(); ++i, ++offset) { + auto ret = rets.get(i); + if (ret.has_error()) { + return ret.error(); + } + jax::PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, ret.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + xla::EnterHostCallback(); + try { + callback(*nb::borrow(nb_args)); + } catch (nb::python_error& e) { + return ffi::Error::Internal( + absl::StrFormat("Error when calling buffer callback: %s", e.what())); + } + xla::LeaveHostCallback(); + + return ffi::Error::Success(); +} + +void BuildFfiSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_FFI_H_ diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 580e0130c3a8..570e3135b1c9 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -33,8 +33,10 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "include/dlpack/dlpack.h" #include "nanobind/nanobind.h" #include "jaxlib/gpu/vendor.h" +#include "jaxlib/ffi.h" #include "xla/ffi/api/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/pjrt/host_callback.h" @@ -225,5 +227,27 @@ XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, kXlaFfiPythonGpuCallback}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaBufferPythonGpuCallback, +#ifdef JAX_GPU_CUDA + (jax::XlaBufferCallback), +#else + (jax::XlaBufferCallback), +#endif + xla::ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_buffer_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaBufferPythonGpuCallback); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index 4d48858ad278..b389dd393443 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ #define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ - #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" @@ -24,6 +23,7 @@ namespace jax { namespace JAX_GPU_NAMESPACE { XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaBufferPythonGpuCallback); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 17e69abd10f6..ff4720748b0f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -50,6 +50,7 @@ pallas_tpu_internal_users = [] pallas_fuser_users = [] mosaic_extension_deps = [] serialize_executable_internal_users = [] +buffer_callback_internal_users = [] jax_internal_export_back_compat_test_util_visibility = [] jax_internal_test_harnesses_visibility = [] diff --git a/jaxlib/py_client_cpu.cc b/jaxlib/py_client_cpu.cc index 647f33c59900..e6778a69a9c3 100644 --- a/jaxlib/py_client_cpu.cc +++ b/jaxlib/py_client_cpu.cc @@ -32,7 +32,9 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "include/dlpack/dlpack.h" #include "nanobind/nanobind.h" +#include "jaxlib/ffi.h" #include "xla/ffi/api/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/pjrt/host_callback.h" @@ -182,4 +184,19 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_partitioned_python_cpu_callback", "HOST", {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, kXlaFfiPythonCpuCallback}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaBufferPythonCpuCallback, + (jax::XlaBufferCallback), + ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_buffer_python_cpu_callback", + "HOST", kXlaBufferPythonCpuCallback); + } // namespace xla diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 94d75d9c19ae..8fc988137ff7 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -525,6 +525,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":hip_vendor", + "//jaxlib:ffi", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", @@ -535,6 +536,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@dlpack", "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index d893c7fb7fe2..74013ed0de68 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -75,6 +75,8 @@ nb::dict FfiRegistrations() { jax::EncapsulateFfiHandler(jax::hip::kXlaFfiPythonGpuCallback); dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + dict["xla_buffer_python_gpu_callback"] = + jax::EncapsulateFfiHandler(jax::hip::kXlaBufferPythonGpuCallback); return dict; } diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 219e220111af..8b47223591dd 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -47,6 +47,7 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/ffi.h" #include "jaxlib/ifrt_proxy.h" #include "jaxlib/py_client.h" #include "jaxlib/py_program.h" @@ -587,6 +588,7 @@ NB_MODULE(_jax, m) { BuildMlirSubmodule(m); BuildSdySubmodule(m); BuildCustomCallShardingPybindAPI(m); + jax::BuildFfiSubmodule(m); #if defined(__linux__) aux::RegisterTransferServerTypes(m); #endif // defined(__linux__) diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 82ae7855a8ee..39c9922adbff 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 333 +_version = 334 # An internal increasing version number for protecting jaxlib code against # ifrt changes. 
diff --git a/tests/BUILD b/tests/BUILD index 9af33da83e5d..50fd9e8fe837 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -114,6 +114,18 @@ jax_multiplatform_test( }, ) +jax_multiplatform_test( + name = "buffer_callback_test", + srcs = ["buffer_callback_test.py"], + enable_backends = [ + "cpu", + "gpu", + ], + deps = [ + "//jax:experimental_buffer_callback", + ], +) + jax_py_test( name = "config_test", srcs = ["config_test.py"], diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py new file mode 100644 index 000000000000..647b5407567b --- /dev/null +++ b/tests/buffer_callback_test.py @@ -0,0 +1,180 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu +from jax._src.lib import jaxlib_extension_version +from jax.experimental import buffer_callback + +jax.config.parse_flags_with_absl() + + +class BufferCallbackTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if jaxlib_extension_version < 334: + self.skipTest( + "Requires a version of jaxlib with buffer callback support." + ) + + @parameterized.parameters(jtu.dtypes.all) + @jtu.run_on_devices("cpu") + def test_numpy(self, dtype): + def callback(ctx, out, arg): + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, "XLA FFI GPU context is not available" + ): + ctx.stream + + self.assertEqual(ctx.stage, buffer_callback.ExecutionStage.EXECUTE) + self.assertEqual(ctx.device_ordinal, 0) + self.assertEqual(arg.shape, shape) + self.assertEqual(arg.dtype, dtype) + self.assertEqual(out.shape, shape) + self.assertEqual(out.dtype, dtype) + + self.assertFalse(arg.writeable) + self.assertTrue(out.writeable) + + x = np.asarray(arg) + self.assertArraysEqual(x, data) + + y = np.asarray(out) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + y[...] = x + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype) + ) + self.assertArraysEqual(fun(data), data) + + @parameterized.parameters(jtu.dtypes.all) + @jtu.run_on_devices("cpu") + def test_dlpack(self, dtype): + if dtype == jnp.bfloat16: + self.skipTest("Numpy's DLPack implementation does not support bfloat16") + + def callback(ctx, out, arg): + del ctx # unused + + x = np.from_dlpack(arg) + self.assertArraysEqual(x, data) + + y = np.from_dlpack(out) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype) + ) + + # We can't actually test the output because numpy doesn't support writable + # DLPack tensors. 
+ fun(data) + + @parameterized.parameters(jtu.dtypes.all) + @jtu.run_on_devices("cuda") + def test_cuda_array_interface(self, dtype): + def callback(ctx, out, arg): + ctx.stream # doesn't crash + + self.assertEqual(ctx.stage, buffer_callback.ExecutionStage.EXECUTE) + self.assertEqual(arg.shape, shape) + self.assertEqual(arg.dtype, dtype) + self.assertEqual(out.shape, shape) + self.assertEqual(out.dtype, dtype) + + obj = arg.__cuda_array_interface__ + self.assertEqual(obj["shape"], data.shape) + self.assertEqual(obj["typestr"], data.dtype.str) + + obj = out.__cuda_array_interface__ + self.assertEqual(obj["shape"], data.shape) + self.assertEqual(obj["typestr"], data.dtype.str) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype) + ) + fun(data) + + @parameterized.parameters([ + "sequential", "sequential_unrolled", "expand_dims", "broadcast_all" + ]) + @jtu.run_on_devices("cpu") + def test_batching(self, vmap_method): + def callback(ctx, out, *args): + del ctx # unused + x = np.asarray(args[0]) + y = np.asarray(args[1]) + z = np.asarray(out) + z[...] = x + z[...] += y + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + x = rng(shape, jnp.float32) + y = rng(shape, jnp.float32) + fun = buffer_callback.buffer_callback( + callback, + jax.ShapeDtypeStruct(x.shape[1:], x.dtype), + vmap_method=vmap_method, + ) + self.assertArraysEqual(jax.vmap(fun)(x, y), x + y) + + @jtu.run_on_devices("cpu") + def test_input_output_aliases(self): + def callback(ctx, out, arg): + del ctx # unused + x = np.asarray(arg) + y = np.asarray(out) + self.assertEqual(x.ctypes.data, y.ctypes.data) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, jnp.float32) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype), + input_output_aliases={0: 0}, + ) + fun(data) + + def test_side_effect(self): + def callback(*_): + nonlocal called + called = True + + called = False + fun = buffer_callback.buffer_callback(callback, (), has_side_effect=True) + fun() + self.assertTrue(called) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From ede6bfe912fe13d39096dee0832d81cf7eef27ad Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Wed, 30 Apr 2025 13:38:46 -0700 Subject: [PATCH 0923/1769] [CI] Reenable old nightly cloud TPU for workflow_dispatch The workflow file became invalid when it lost all of its "on:" conditions. Reenable workflow_dispatch to prevent notifications for invalid workflow PiperOrigin-RevId: 753294372 --- .github/workflows/cloud-tpu-ci-nightly.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index cb9e04d4488d..5cc2aebe3cd0 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -13,10 +13,10 @@ name: CI - Cloud TPU (nightly) # Disable the schedule; Slated for removal, the new test workflow is in # "wheel_tests_nightly_release.yml" -# on: +on: # schedule: # - cron: "0 2,14 * * *" # Run at 7am and 7pm PST -# workflow_dispatch: # allows triggering the workflow run manually + workflow_dispatch: # allows triggering the workflow run manually # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. 
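Distilled from the buffer_callback tests in the preceding patch, the user-facing calling convention is: the callback receives an execution context first, then the writable output buffer(s), then the read-only input buffer(s), and writes its result in place. A minimal CPU sketch under those assumptions; `double_it` is a hypothetical example, not code from the patch:

    import numpy as np
    import jax
    from jax.experimental import buffer_callback

    def double_it(ctx, out, arg):
        # On CPU the buffers support the NumPy array interface; `out` is
        # writable, `arg` is not, and `ctx` exposes stage/device/stream info.
        np.asarray(out)[...] = 2 * np.asarray(arg)

    fun = buffer_callback.buffer_callback(
        double_it, jax.ShapeDtypeStruct((3, 4), np.float32))
    y = fun(np.ones((3, 4), np.float32))  # every entry is 2.0
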
From 22f8e684ebd3671b15982c30df21c623ec78a03a Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Sun, 27 Apr 2025 23:38:35 -0700 Subject: [PATCH 0924/1769] [JAX][Docs] Add an example of device compute on host memory space --- docs/notebooks/host-offloading.ipynb | 42 ++++++++++++++++++++++++++++ docs/notebooks/host-offloading.md | 22 +++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/docs/notebooks/host-offloading.ipynb b/docs/notebooks/host-offloading.ipynb index a14953b12850..9c806c2d56e7 100644 --- a/docs/notebooks/host-offloading.ipynb +++ b/docs/notebooks/host-offloading.ipynb @@ -166,6 +166,48 @@ "print(\"Result value of H2D: \\n\", out_dev)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "iYXC5ix384XP" + }, + "source": [ + "Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cmM6tJTS84XQ", + "outputId": "40c353a1-fb55-44bc-bac9-dffc09852f49" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D and add 1 in device memory: \n", + " [[1. 2. 3. 4.]\n", + " [5. 6. 7. 8.]]\n" + ] + } + ], + "source": [ + "# Instead of the lambda function, you can define add_func to explicitly\n", + "# move data to device before computation\n", + "def add_func(x): # Move data to device and add one\n", + " x = jax.device_put(x, s_dev)\n", + " return x + 1\n", + "\n", + "f = jax.jit(add_func, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D and add 1 in device memory: \\n\", out_dev)" + ] + }, { "cell_type": "markdown", "metadata": { diff --git a/docs/notebooks/host-offloading.md b/docs/notebooks/host-offloading.md index 96f59ee7f46e..7e113d40a4b3 100644 --- a/docs/notebooks/host-offloading.md +++ b/docs/notebooks/host-offloading.md @@ -120,6 +120,28 @@ out_dev = f(arr_host) print("Result value of H2D: \n", out_dev) ``` ++++ {"id": "iYXC5ix384XP"} + +Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance. 
+ +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: cmM6tJTS84XQ +outputId: 40c353a1-fb55-44bc-bac9-dffc09852f49 +--- +# Instead of the lambda function, you can define add_func to explicitly +# move data to device before computation +def add_func(x): # Move data to device and add one + x = jax.device_put(x, s_dev) + return x + 1 + +f = jax.jit(add_func, out_shardings=s_dev) +out_dev = f(arr_host) +print("Result value of H2D and add 1 in device memory: \n", out_dev) +``` + +++ {"id": "EbE-eBrJTBuS"} #### Host Output Sharding From 51e973fd7b2b62f7d66312a9c79b8bed08f06852 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 30 Apr 2025 14:06:27 -0700 Subject: [PATCH 0925/1769] Remove old Windows nightly workflows These workflows have been replaced by "wheel_tests_nightly_release.yml" and "wheel_tests_continuous.yml" PiperOrigin-RevId: 753305525 --- .github/workflows/wheel_win_x64.yml | 65 ------------------------- .github/workflows/windows_ci.yml | 75 ----------------------------- 2 files changed, 140 deletions(-) delete mode 100644 .github/workflows/wheel_win_x64.yml delete mode 100644 .github/workflows/windows_ci.yml diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml deleted file mode 100644 index 6539a50ce790..000000000000 --- a/.github/workflows/wheel_win_x64.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: Wheel build - Windows CPU x86_64 -# Slated for removal, Windows release/nightly wheels are now built in the internal CI system. -# on: -# workflow_dispatch: # allows triggering the workflow run manually - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -env: - DISTUTILS_USE_SDK: 1 - MSSdk: 1 - -jobs: - win-wheels: - strategy: - fail-fast: false # Don't stop all wheel builds if one has a test failure. 
- matrix: - os: [windows-2019-32core] - arch: [AMD64] - pyver: ['3.10', '3.11', '3.12', '3.13'] - name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build - runs-on: ${{ matrix.os }} - - steps: - - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: ${{ matrix.pyver }} - cache: 'pip' - - - name: Build wheels - env: - BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC" - JAXLIB_RELEASE: true - run: | - python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt ` - --upgrade numpy==2.0.0 scipy==1.13.1 - "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py build --wheels=jaxlib ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang ` - --verbose - - - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 - with: - name: wheels-${{ matrix.os }}-${{ matrix.pyver }} - path: ${{ github.workspace }}\dist\*.whl - retention-days: 5 - - - name: Run tests - env: - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib ` - -e ${{ github.workspace }} - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml deleted file mode 100644 index 7d848391e64d..000000000000 --- a/.github/workflows/windows_ci.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: CI - Windows CPU -# Disable the schedule; Slated for removal, the new test workflows are in -# "wheel_tests_nightly_release.yml" and "wheel_tests_continuous.yml" -# on: -# schedule: -# - cron: "0 12 * * *" # Daily at 12:00 UTC -# workflow_dispatch: # allows triggering the workflow run manually -# pull_request: -# types: [ labeled ] # allow force-windows-run label - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -env: - DISTUTILS_USE_SDK: 1 - MSSdk: 1 - -jobs: - win-wheels: - if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}} - strategy: - fail-fast: true - matrix: - os: [windows-2019-32core] - arch: [AMD64] - pyver: ['3.10'] - name: Windows CI build - runs-on: ${{ matrix.os }} - - steps: - - - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - path: jax - - - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: ${{ matrix.pyver }} - cache: 'pip' - - - name: Build wheels - env: - BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC" - JAXLIB_NIGHTLY: true # Tag the wheels as dev versions - run: | - cd jax - python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt --upgrade numpy==2.0.0 scipy==1.13.1 - "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py build --wheels=jaxlib ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang ` - --verbose - - - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 - with: - name: wheels - path: ${{ 
github.workspace }}\jax\dist\*.whl - retention-days: 5 - - - name: Run tests - env: - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - cd jax - python -m uv pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib ` - -e ${{ github.workspace }}\jax - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples From fbd17dede47a9c683255d41e7852316defaa3e4d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 30 Apr 2025 14:20:30 -0700 Subject: [PATCH 0926/1769] Add 3.13 and 3.13t to tpu matrix strategy and run only a single build for oldest supported libtpu The following builds are run for TPU in the nightly/release workflow: * v4-8: Python 3.10 * v5e-8: Python 3.13, Python 3.13t * v6e-8: Python 3.11, Python 3.12 The following builds are run for TPU in the continuous workflow: * v4-8: Python 3.10, libtpu=nightly * v5e-8: Python 3.10, libtpu=nightly, oldest_supported_libtpu * v6e-8: Python 3.10, libtpu=nightly PiperOrigin-RevId: 753310757 --- .github/workflows/wheel_tests_continuous.yml | 14 ++++++++-- .../workflows/wheel_tests_nightly_release.yml | 27 ++++++++++++++----- build/collect-profile-requirements.txt | 3 ++- ci/run_pytest_tpu.sh | 4 ++- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 95d83a92c776..a2ee224c38e1 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -174,8 +174,18 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] + libtpu-version-type: ["nightly", "oldest_supported_libtpu"] + exclude: + # Run a single config for oldest_supported_libtpu + - libtpu-version-type: "oldest_supported_libtpu" + tpu-specs: + type: "v4-8" + - libtpu-version-type: "oldest_supported_libtpu" + tpu-specs: + type: "v6e-8" name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" with: runner: ${{ matrix.tpu-specs.runner }} @@ -183,5 +193,5 @@ jobs: tpu-type: ${{ matrix.tpu-specs.type }} python: ${{ matrix.python }} run-full-tpu-test-suite: "1" - libtpu-version-type: "nightly" + libtpu-version-type: ${{ matrix.libtpu-version-type }} gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} \ No newline at end of file diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index de89c278c6c3..8d597b84f735 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -83,37 +83,50 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for - # profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302) - python: ["3.10", "3.11", "3.12"] + python: ["3.10", "3.11", "3.12", "3.13", "3.13-nogil"] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] - 
libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] + libtpu-version-type: ["pypi_latest", "nightly"] exclude: + # Exclude nightly for releases - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'oldest_supported_libtpu' }} + # Exclude pypi_latest for nightly releases - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8 and v6e-8 + # Run a single Python version for v4-8 - tpu-specs: type: "v4-8" python: "3.10" - tpu-specs: type: "v4-8" python: "3.11" + - tpu-specs: + type: "v4-8" + python: "3.12" + - tpu-specs: + type: "v4-8" + python: "3.13-nogil" + # Run Python versions in between min and max for v6e-8 - tpu-specs: type: "v6e-8" python: "3.10" - tpu-specs: type: "v6e-8" - python: "3.11" + python: "3.13" + - tpu-specs: + type: "v6e-8" + python: "3.13-nogil" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" python: "3.11" + - tpu-specs: + type: "v5e-8" + python: "3.12" + name: "Pytest TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.tpu-specs.runner }} diff --git a/build/collect-profile-requirements.txt b/build/collect-profile-requirements.txt index a7d57dd2c4ef..e58558fd29a6 100644 --- a/build/collect-profile-requirements.txt +++ b/build/collect-profile-requirements.txt @@ -1,4 +1,5 @@ -tensorflow +# TF hasn't released 3.13 wheels yet (b/402590302) +tensorflow; python_version<"3.13" tensorboard-plugin-profile<=2.19.0 # Needed for the profile plugin to work without error protobuf diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 5d8aa9ed648f..ef5a8cbef943 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -41,7 +41,9 @@ echo "Installed packages:" "$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' "$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' "$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' -strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' +# Free-threaded builds use "-nogil" as the suffix for the binary and "t" for its +# dist-packages path +strings /usr/local/lib/"${JAXCI_PYTHON//-nogil/t}"/dist-packages/libtpu/libtpu.so | grep 'Built on' "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' # Set up all common test environment variables From a68dbd74adc455f8eb24921a35903e6f736fc4d3 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 3 Apr 2025 01:00:40 +0000 Subject: [PATCH 0927/1769] initial commit of boxes/lists Co-authored-by: Dougal Maclaurin --- jax/_src/api.py | 25 +- jax/_src/core.py | 2 +- jax/_src/interpreters/partial_eval.py | 14 +- jax/_src/lax/control_flow/loops.py | 40 ++- jax/_src/pjit.py | 227 +++++++++++---- jax/_src/util.py | 3 +- jax/experimental/attrs.py | 132 ++++++++- tests/attrs_test.py | 384 +++++++++++++++++++++++++- 8 files changed, 752 insertions(+), 75 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4a729b223ff1..167f9c67d9ec 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1851,8 +1851,11 @@ def jvp( def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): """Variant of jvp() that takes an lu.WrappedFun.""" - ps_flat, tree_def = tree_flatten(primals) - ts_flat, tree_def_2 = tree_flatten(tangents) + primals_, (), 
primal_box_data = pjit._flatten_boxes(fun.debug_info, primals, {}) + tangents_, (), tangent_box_data = pjit._flatten_boxes(fun.debug_info, tangents, {}) + fun = pjit._handle_boxes(fun, fun.debug_info) + ps_flat, tree_def = tree_flatten(primals_) + ts_flat, tree_def_2 = tree_flatten(tangents_) if tree_def != tree_def_2: raise TypeError("primal and tangent arguments to jax.jvp must have the same tree " f"structure; primals have tree structure {tree_def} whereas tangents have " @@ -1873,9 +1876,27 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def) out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat) out_tree = out_tree() + if primal_box_data or tangent_box_data: + assert primal_box_data and tangent_box_data + box_treedef, out_tree = out_tree.children() + box_out_flat, out_primals = split_list(out_primals, [box_treedef.num_leaves]) + box_dot_out_flat, out_tangents = split_list(out_tangents, [box_treedef.num_leaves]) + box_out = tree_unflatten(box_treedef, box_out_flat) + box_dot_out = tree_unflatten(box_treedef, box_dot_out_flat) + for (i, kind), b in zip(primal_box_data, box_out): + if kind is pe.BoxAttr: + primals[i].set(tree_unflatten(b.treedef, b.leaves)) + else: + assert False + for (i, kind), b in zip(tangent_box_data, box_dot_out): + if kind is pe.BoxAttr: + tangents[i].set(tree_unflatten(b.treedef, b.leaves)) + else: + assert False return (tree_unflatten(out_tree, out_primals), tree_unflatten(out_tree, out_tangents)) else: + if primal_box_data or tangent_box_data: raise NotImplementedError flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, tree_def) jvp_fun, aux = ad.jvp(flat_fun, has_aux=True) out_primals, out_tangents = jvp_fun.call_wrapped(ps_flat, ts_flat) diff --git a/jax/_src/core.py b/jax/_src/core.py index 4bc64b85b81d..2160f0af824f 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3403,7 +3403,7 @@ def __eq__(self, other): else: return False -def get_opaque_trace_state(convention): +def get_opaque_trace_state(convention=None): del convention return OpaqueTraceState(trace_ctx.trace._weakref) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 1c42f86c3d48..990b0c51d175 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -65,6 +65,8 @@ def identity(x): return x # Attrs flavors, see jax/experimental/attrs.py ReadWrite = type('ReadWrite', (), {})() Append = type('Append', (), {})() +BoxAttr = type('BoxAttr', (), {})() +ListAttr = type('ListAttr', (), {})() def _update_annotation_known( f: lu.WrappedFun, @@ -1724,8 +1726,8 @@ def __init__(self, debug_info: core.DebugInfo): def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def reset_states(self): - reset_states(self.attrs_tracked, self.attrs_inits) + def reset_states(self, trace): + reset_states(trace, self.attrs_tracked, self.attrs_inits) def to_jaxpr( self, trace: DynamicJaxprTrace, @@ -1889,6 +1891,8 @@ def invalidate(self): self.frame.tracers = [] self.frame.constid_to_tracer = {} self.frame.constvar_to_val = {} + self.frame.attrs_tracked = [] + self.frame.attrs_inits = [] def to_jaxpr_tracer(self, x, source_info: SourceInfo): as_local_var = self.frame.tracer_to_var.get(id(x)) @@ -2234,7 +2238,7 @@ def trace_to_jaxpr_dynamic( jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) del fun, in_tracers, out_tracers, ans finally: - trace.frame.reset_states() + trace.frame.reset_states(trace) del trace 
config.enable_checks.value and core.check_jaxpr(jaxpr) @@ -2297,8 +2301,8 @@ def trace_to_jaxpr_dynamic2( AttrsTracked = list[tuple[Any, str, AttrKind]] AttrStates = list -def reset_states(attrs_tracked: AttrsTracked, init_vals: AttrStates) -> None: - for ((obj, attr, _), val) in zip(attrs_tracked, init_vals): +def reset_states(trace, attrs_tracked: AttrsTracked, init_vals: AttrStates) -> None: + for ((obj, attr, kind), val) in zip(attrs_tracked, init_vals): setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr) def get_states(attrs_tracked: AttrsTracked) -> list[PyTree]: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 58ef64add37a..da3fa6fa1019 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -288,11 +288,12 @@ def _create_jaxpr(init): raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) if attrs_tracked: - appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked] + appends_out = [k for _, t, (_, _, k) in attrs_tracked + for k in [k in (pe.Append, pe.ListAttr)] * t.num_leaves] jaxpr = pe.move_outvars_to_back( jaxpr, appends_out + [False] * (len(jaxpr.out_avals) - len(appends_out))) num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) - in attrs_tracked if kind is pe.ReadWrite) + in attrs_tracked if kind in (pe.ReadWrite, pe.BoxAttr)) _, carry_avals_out, _ = split_list( jaxpr.out_avals, [num_attr_carry, out_tree_children[0].num_leaves]) else: @@ -361,7 +362,9 @@ def _create_jaxpr(init): if attrs_tracked: num_ext = (len(out) - len(in_state) - - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked)) + - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked) + - sum(t.num_leaves for _, t, (_, _, k) in attrs_tracked + if k is pe.ListAttr)) out_state, out, out_append = split_list(out, [len(in_state), num_ext]) out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) _set_states(attrs_tracked, out_attrs) @@ -378,6 +381,13 @@ def _set_states(attrs_tracked, vals): elif kind is pe.Append: val, = leaves jax_extendattr(obj, attr, val.reshape(-1, *val.shape[2:])) + elif kind is pe.BoxAttr: + val = tree_unflatten(treedef, leaves) + obj.set(val) + elif kind is pe.ListAttr: + for leaves_ in zip(*leaves): + for item in tree_unflatten(treedef, leaves_): + obj.append(item) else: assert False @@ -392,15 +402,30 @@ def _get_states(attrs_tracked): vals.extend(leaves) elif kind is pe.Append: pass + elif kind is pe.BoxAttr: + tree = obj.get() + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.ListAttr: + pass else: assert False return vals def _merge_attrs_out(attrs_tracked, out_state, out_append): + # merge out_state & out_append back into attrs_tracked order out_state_, out_append_ = iter(out_state), iter(out_append) - out_attrs = [item for _, out_tree, (_, _, k) in attrs_tracked for item in - (itertools.islice(out_state_, out_tree.num_leaves) - if k is pe.ReadWrite else [next(out_append_)])] + out_attrs = [] + for _, out_tree, (_, _, k) in attrs_tracked: + if k in (pe.ReadWrite, pe.BoxAttr): + out_attrs.extend(itertools.islice(out_state_, out_tree.num_leaves)) + elif k is pe.Append: + out_attrs.append(next(out_append_)) + elif k is pe.ListAttr: + out_attrs.extend(itertools.islice(out_append_, out_tree.num_leaves)) + else: + assert False assert next(out_state_, None) is next(out_append_, None) is None return out_attrs @@ -847,7 +872,8 @@ def _scan_transpose(cts, *args, reverse, length, 
num_consts, # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a, e]) jaxpr_trans, attrs_tracked = _transpose_scan_jaxpr( jaxpr, num_ires, num_consts - num_ires, num_eres, ct_ys_is_zeros) - appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked] + appends_out = [k for _, t, (_, _, k) in attrs_tracked + for k in [k in (pe.Append, pe.ListAttr)] * t.num_leaves] jaxpr_trans = pe.move_outvars_to_back( jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out))) num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 03c755b451e7..bd4768c193fb 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -20,6 +20,7 @@ import dataclasses from functools import partial import inspect +import itertools import logging import weakref from typing import NamedTuple, Any, Union, cast @@ -74,10 +75,10 @@ from jax._src.tree_util import ( tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, - PyTreeDef, none_leaf_registry as none_lr, tree_map) + PyTreeDef, none_leaf_registry as none_lr, tree_map, tree_flatten_with_path) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, tuple_insert, - distributed_debug_log, split_list, weakref_lru_cache, + distributed_debug_log, split_list, split_list_checked, weakref_lru_cache, merge_lists, subs_list, fun_name, fun_qual_name) map, unsafe_map = safe_map, map @@ -219,49 +220,38 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None dispatch.maybe_recursive_nan_check(e, fun, args, kwargs) + if p.box_data: + box_treedef, out_tree = p.out_tree.children() + box_flat, out_flat = split_list_checked(out_flat, [box_treedef.num_leaves, out_tree.num_leaves]) + box_out = tree_unflatten(box_treedef, box_flat) + leaves = tree_leaves((args, kwargs)) + for (i, kind), b in zip(p.box_data, box_out): + if kind is pe.BoxAttr: + leaves[i].set(tree_unflatten(b.treedef, b.leaves)) + elif kind is pe.ListAttr: + for item in tree_unflatten(b.treedef, b.leaves): + leaves[i].append(item) + else: + assert False + else: + out_tree = p.out_tree + if p.attrs_tracked: num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked) final_states, out_flat = split_list(out_flat, [num_states_out]) _set_states(p.attrs_tracked, final_states) - outs = tree_unflatten(p.out_tree, out_flat) - return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], - p.attrs_tracked, compiled, profiler) + outs = tree_unflatten(out_tree, out_flat) + return (outs, out_flat, out_tree, args_flat, p.params['jaxpr'], + p.attrs_tracked, p.box_data, compiled, profiler) -def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr, jax_extendattr - valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) - for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): - if kind is pe.ReadWrite: - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) - elif kind is pe.Append: - del treedef - val, = leaves - jax_extendattr(obj, attr, val) - -def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr, dne_sentinel - vals = [] - for treedef, _, (obj, attr, kind) in attrs_tracked: - if kind is pe.ReadWrite: - tree = jax_getattr(obj, attr) if hasattr(obj, attr) 
else dne_sentinel - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) - elif kind is pe.Append: - pass - else: - assert False - return vals - def _need_to_rebuild_with_fdo(pgle_profiler): return (pgle_profiler is not None and pgle_profiler.is_enabled() and not pgle_profiler.is_fdo_consumed()) def _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, effects, + executable, out_tree, args_flat, out_flat, attrs_tracked, box_data, effects, consts, abstracted_axes, pgle_profiler ) -> pxla.MeshExecutableFastpathData | None: out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) @@ -278,6 +268,7 @@ def _get_fastpath_data( and abstracted_axes is None # no attr state effects and not attrs_tracked + and not box_data # no ref state effects and not any(isinstance(e, RefEffect) for e in effects) # no prng reuse checking @@ -336,13 +327,12 @@ def cache_miss(*args, **kwargs): raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " "`jit`, but 'no_tracing' is set") - (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable, - pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) + (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, box_data, + executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) maybe_fastpath_data = _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, - jaxpr.consts, jit_info.abstracted_axes, - pgle_profiler) + executable, out_tree, args_flat, out_flat, attrs_tracked, box_data, + jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler) return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) @@ -557,6 +547,7 @@ class PjitParams(NamedTuple): arg_names: tuple[str, ...] 
num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]] + box_data: list def _infer_params_impl( @@ -585,8 +576,10 @@ def _infer_params_impl( f = lu.wrap_init(fun, debug_info=dbg) f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) del args - f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs) + del kwargs + dyn_args, dyn_kwargs, box_data = _flatten_boxes(dbg, dyn_args, dyn_kwargs) + f = _handle_boxes(f, dbg) explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) flat_fun, out_tree = flatten_fun(f, in_tree) flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args) @@ -623,6 +616,8 @@ def _infer_params_impl( assert in_avals is None in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) in_avals = tuple(a for a, e in in_type if e) + elif box_data: + in_type = in_avals = tuple(core.shaped_abstractify(x) for x in explicit_args) # type: ignore else: in_type = in_avals # type: ignore assert in_avals is not None @@ -656,7 +651,7 @@ def _infer_params_impl( args_flat = [*implicit_args, *explicit_args] num_attrs_in = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) - in attrs_tracked if kind is pe.ReadWrite) + in attrs_tracked if kind in (pe.ReadWrite, pe.BoxAttr)) num_extra_args = len(implicit_args) + num_attrs_in + len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat @@ -679,7 +674,7 @@ def _infer_params_impl( ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), donated_invars, dbg.arg_names, len(consts), - attrs_tracked), args_flat + attrs_tracked, box_data), args_flat class InferParamsCacheEntry: @@ -725,7 +720,9 @@ def _infer_params_internal( static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, signature=ji.fun_signature) - if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache + from jax.experimental.attrs import Box, List + any_boxes = any(isinstance(x, (Box, List)) for x in tree_leaves((args, kwargs))) + if config.dynamic_shapes.value or any_boxes: # don't use the cache p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=None) return p, p.consts + args_flat @@ -739,7 +736,7 @@ def _infer_params_internal( if entry.pjit_params is None: p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) - if p.attrs_tracked: # if attrs, don't populate the cache + if p.attrs_tracked or p.box_data: # if attrs/boxes, don't populate cache return p, p.consts + args_flat entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs @@ -1310,10 +1307,13 @@ def arg_type_to_str(at): for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)): t_name = t[0].__name__ if t == ot: continue + + # TODO(mattjj): explain box cache misses + if t_name == '_handle_boxes': continue + if t[0] != ot[0]: unavailable(f"fun_transforms[{i}] transform", t, ot) continue - if t_name == "flatten_fun": explain_in_tree_diff(t[1][0], ot[1][0]) continue @@ -1414,8 +1414,7 @@ def p_one_diff(diff: Sequence[str]): for d in smallest_diffs: p_one_diff(d) - done() - return + return done() @partial(lu.cache, explain=explain_tracing_cache_miss) @@ -1509,7 +1508,7 @@ def _attr_cache_index( cases = seen_attrs_get(fun, in_type) for i, records in enumerate(cases): for obj, attr, kind, treedef, avals in records: - if kind is pe.ReadWrite: + if kind in (pe.ReadWrite, pe.BoxAttr): val = getattr(obj, attr, 
dne_sentinel) vals, treedef_ = tree_flatten(val) avals_ = map(core.shaped_abstractify, vals) @@ -1690,7 +1689,6 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) resolved_in_shardings: list[PjitSharding] = [] - assert len(args) == len(pjit_in_shardings) for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does # not allow None as the sharding. @@ -1857,8 +1855,8 @@ def call_impl_cache_miss(*args_, **kwargs_): ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) fastpath_data = _get_fastpath_data( - compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, - jaxpr.consts, None, pgle_profiler) + compiled, tree_structure(out_flat), args, out_flat, [], [], + jaxpr.effects, jaxpr.consts, None, pgle_profiler) return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) f = _get_jaxpr_as_fun( @@ -3170,3 +3168,132 @@ def get_unconstrained_dims(sharding: NamedSharding): assert sharding.spec is not None return {i for i, axes in enumerate(sharding.spec) if axes is PartitionSpec.UNCONSTRAINED} + +# -------------------- attrs etc -------------------- + +def _set_states(attrs_tracked, vals): + from jax.experimental.attrs import jax_setattr, jax_extendattr + valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) + for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): + if kind is pe.ReadWrite: + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) + elif kind is pe.Append: + del treedef + val, = leaves + jax_extendattr(obj, attr, val) + elif kind is pe.BoxAttr: + val = tree_unflatten(treedef, leaves) + obj.set(val) + elif kind is pe.ListAttr: + for item in tree_unflatten(treedef, leaves): + obj.append(item) + else: + assert False + +def _get_states(attrs_tracked): + from jax.experimental.attrs import jax_getattr, dne_sentinel + vals = [] + for treedef, _, (obj, attr, kind) in attrs_tracked: + if kind is pe.ReadWrite: + tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.Append: + pass + elif kind is pe.BoxAttr: + tree = obj.get() # not getattr! 
+ leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.ListAttr: + pass + else: + assert False + return vals + +def static(): + return dataclasses.field(metadata=dict(static=True)) + +@tree_util.register_dataclass +@dataclasses.dataclass +class BoxTree: + leaves: list + treedef: PyTreeDef = static() + +@tree_util.register_dataclass +@dataclasses.dataclass +class ListTree: + leaves: list + treedef: PyTreeDef | None = static() + +def _flatten_boxes(dbg, args, kwargs): + from jax.experimental.attrs import Box, List + box_data = [] + id_first_occurrences = {} + idxs = itertools.count() + def visit(x): + i = next(idxs) + if (isinstance(x, (Box, List)) and + (dup_idx := id_first_occurrences.setdefault(id(x), i)) != i): + type_name = type(x).__name__ + raise ValueError( + f"a {type_name} instance can't be passed as an argument more than " + f"once, but when tracing {dbg.func_src_info} for {dbg.traced_for}, " + f"the object {x} appeared at both arguments " + f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}" + if dbg else + f"at both flat index {dup_idx} and flat index {i}") + if type(x) is Box: + leaves, treedef = tree_flatten(x._val) + ty = tuple(core.shaped_abstractify(l) for l in leaves) + box_data.append((i, pe.BoxAttr)) + return BoxTree(leaves, treedef) + elif type(x) is List: + box_data.append((i, pe.ListAttr)) + return ListTree([], None) + else: + return x + args, kwargs = tree_map(visit, (args, kwargs)) + return args, kwargs, box_data + +@lu.transformation2 +def _handle_boxes(f, dbg, *args, **kwargs): + from jax.experimental.attrs import Box, List + new_args = [] + arg_mutables = [] + def visit(x): + if type(x) is BoxTree: + box = Box(tree_unflatten(x.treedef, x.leaves)) + arg_mutables.append(box) + return box + elif type(x) is ListTree: + lst = List() + arg_mutables.append(lst) + return lst + else: + return x + args, kwargs = tree_map(visit, (args, kwargs), + is_leaf=lambda x: isinstance(x, (BoxTree, ListTree))) + out = f(*args, **kwargs) + for path, leaf in tree_flatten_with_path(out)[0]: + if isinstance(leaf, (Box, List)): + type_name = type(leaf).__name__ + raise ValueError( + f"a {type_name} instance can't be returned from a transformed " + f"function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " + f"the object {leaf} appeared at result{keystr(path)}") + if not arg_mutables: + return out + extra_outs = [] + for mutable in arg_mutables: + if type(mutable) is Box: + leaves, treedef = tree_flatten(mutable._val) + extra_outs.append(BoxTree(leaves, treedef)) + elif type(mutable) is List: + leaves, treedef = tree_flatten(mutable._val) + extra_outs.append(ListTree(leaves, treedef)) + else: + assert False + return extra_outs, out diff --git a/jax/_src/util.py b/jax/_src/util.py index e551c654b005..ece1808ab8d8 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -145,6 +145,7 @@ def subvals(lst, replace): def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: args = list(args) + assert all(n >= 0 for n in ns) lists = [] for n in ns: lists.append(args[:n]) @@ -154,7 +155,7 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: args = list(args) - assert sum(ns) == len(args) + assert sum(ns) == len(args) and all(n >= 0 for n in ns) lists = [] for n in ns: lists.append(args[:n]) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 54fd0fe0b02f..d483ac076a0f 100644 --- 
a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -35,7 +35,8 @@ Array = Any JaxVal = Any -Pytree = Any +PyTree = Any +PyTreeDef = Any ReadWrite = pe.ReadWrite Append = pe.Append @@ -43,11 +44,11 @@ register = api_util.register_class_with_attrs dne_sentinel = pe.dne_sentinel -def jax_getattr(obj: Any, attr: str) -> Pytree: +def jax_getattr(obj: Any, attr: str) -> PyTree: with core.take_current_trace() as t: return t.process_getattr(obj, attr) -def jax_setattr(obj: Any, attr: str, val: Pytree) -> None: +def jax_setattr(obj: Any, attr: str, val: PyTree) -> None: with core.take_current_trace() as t: return t.process_setattr(obj, attr, val) @@ -85,7 +86,8 @@ def _check_append_type_agreement(_, attr, curtype, valtype): f"{expected.str_short()}, but appendattr got value of type " f"{valtype.str_short()} which has trailing shape {got.str_short()}.") -def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): +def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str, + kind: pe.AttrKind): frame = trace.frame source_info = source_info_util.current() @@ -100,22 +102,22 @@ def new_tracer(x): if (obj, attr, Append) in frame.attrs_tracked: raise TypeError(f"can't read/write to append-only attr {attr}") - if (obj, attr, ReadWrite) not in frame.attrs_tracked: + if (obj, attr, kind) not in frame.attrs_tracked: init_val = getattr(obj, attr, dne_sentinel) frame.attrs_inits.append(init_val) init_vals, init_tree = tree_flatten(init_val) tracers = map(new_tracer, init_vals) setattr(obj, attr, tree_unflatten(init_tree, tracers)) - frame.attrs_tracked.append((obj, attr, ReadWrite)) + frame.attrs_tracked.append((obj, attr, kind)) pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked def _getattr_staging(trace, obj, attr): - trace._ensure_tracked(obj, attr) + trace._ensure_tracked(obj, attr, ReadWrite) return getattr(obj, attr) pe.DynamicJaxprTrace.process_getattr = _getattr_staging def _setattr_staging(trace, obj, attr, val): - trace._ensure_tracked(obj, attr) + trace._ensure_tracked(obj, attr, ReadWrite) setattr(obj, attr, val) pe.DynamicJaxprTrace.process_setattr = _setattr_staging @@ -291,3 +293,117 @@ def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}): args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) return args_ct, dict(zip(in_attrs, in_attr_bars)) return f_vjp + + +class Box: + _val: PyTree + _tag: core.OpaqueTraceState + def __init__(self, val): + self._val = val + self._tag = core.get_opaque_trace_state() + def get(self): + with core.take_current_trace() as t: + return t.process_box_get(self) + def set(self, val): + with core.take_current_trace() as t: + return t.process_box_set(self, val) + +def _box_get_impl(trace, box): + return box._val +core.EvalTrace.process_box_get = _box_get_impl + +def _box_set_impl(trace, box, val): + box._val = val +core.EvalTrace.process_box_set = _box_set_impl + +def _is_local(trace, box): + is_arg = box._tag._trace_ref() is trace + if is_arg: assert box._tag._trace_ref() is trace + return is_arg + +def _box_get_staging(trace, box): + if not _is_local(trace, box): + trace._ensure_tracked(box, '_val', pe.BoxAttr) + return box._val +pe.DynamicJaxprTrace.process_box_get = _box_get_staging + +def _box_set_staging(trace, box, val): + if not _is_local(trace, box): + trace._ensure_tracked(box, '_val', pe.BoxAttr) + box._val = val +pe.DynamicJaxprTrace.process_box_set = _box_set_staging + +def _box_get_jvp(trace, box): + return box._val +ad.JVPTrace.process_box_get = _box_get_jvp + +def 
_box_set_jvp(trace, box, val): + primal, tangent = trace.to_primal_tangent_pair(val) + if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): + raise Exception + if isinstance(tangent, ad.Zero): + box._val = primal + else: + box._val = ad.JVPTracer(trace, primal, tangent) +ad.JVPTrace.process_box_set = _box_set_jvp + +def _box_get_linearize(trace, box): + return box._val +ad.LinearizeTrace.process_box_get = _box_get_linearize + +def _box_set_linearize(trace, box, val): + primal, tangent = trace.to_primal_tangent_pair(val) + if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): + raise Exception + if isinstance(tangent, ad.Zero): + box._val = primal + else: + raise NotImplementedError # TODO + box._val = ad.LinearizeTracer(trace, primal, tangent) +ad.LinearizeTrace.process_box_set = _box_set_linearize + + +class List: + _val: PyTree + _tag: core.OpaqueTraceState + frozen: bool + def __init__(self, val=None): + self._val = [] if val is None else val + self._tag = core.get_opaque_trace_state() + self.frozen = False + def append(self, val): + with core.take_current_trace() as t: + return t.process_list_append(self, val) + def freeze(self): + with core.take_current_trace() as t: + return t.process_list_freeze(self) + +def _list_append_impl(trace, lst, val): + if lst.frozen: + raise Exception("can't append to an already-frozen List") + lst._val.append(val) +core.EvalTrace.process_list_append = _list_append_impl + +def _list_freeze_impl(trace, lst): + lst.frozen = True + return lst._val +core.EvalTrace.process_list_freeze = _list_freeze_impl + +def _list_append_staging(trace, lst, val): + if not _is_local(trace, lst): + _ensure_list_tracked(trace, lst) + return _list_append_impl(trace, lst, val) +pe.DynamicJaxprTrace.process_list_append = _list_append_staging + +def _ensure_list_tracked(trace, lst): + frame = trace.frame + if (lst, '_val', pe.ListAttr) not in frame.attrs_tracked: + frame.attrs_inits.append(lst._val) + frame.attrs_tracked.append((lst, '_val', pe.ListAttr)) + lst._val = [] + +def _list_freeze_staging(trace, lst): + if not _is_local(trace, lst): + raise Exception("can only freeze a local List") + return _list_freeze_impl(trace, lst) +pe.DynamicJaxprTrace.process_list_freeze = _list_freeze_staging diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 8cf64790311b..6e1381ffbc88 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -29,7 +29,8 @@ from jax._src.util import safe_zip, safe_map from jax.experimental import attrs -from jax.experimental.attrs import jax_setattr, jax_getattr, jax_appendattr +from jax.experimental.attrs import ( + jax_setattr, jax_getattr, jax_appendattr, Box, List) config.parse_flags_with_absl() @@ -922,5 +923,386 @@ def f_ref(x, y, z, w): check_dtypes=False) +class BoxTest(jtu.JaxTestCase): + + def test_jit_arg(self): + @jax.jit + def f(box, x): + assert tracing_ok + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f(box1, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f(box2, 2.) + self.assertAllClose(box2.get(), 4.0) + + def test_jit_arg_in_pytree(self): + @jax.jit + def f(dct, x): + assert tracing_ok + box = dct['box'] + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f({'box': box1, 'a': 1.0}, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f({'box': box2, 'a': 2.0}, 2.) + self.assertAllClose(box2.get(), 4.0) + + tracing_ok = True + box3 = Box(3) # int, dtype changed + f({'box': box3, 'a': 2.0}, 2.) 
+ self.assertAllClose(box3.get(), 5.0) + + def test_jit_closure(self): + @jax.jit + def f(x): + box.set(box.get() + x) + + box = Box(1.0) + f(2.0) + self.assertAllClose(box.get(), 3.0) + + @jax.jit + def g(x): + f(x) + + g(3.0) + self.assertAllClose(box.get(), 6.0) + + def test_jit_closure_nested(self): + @jax.jit + def h(x): + box = Box(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + + @parameterized.parameters([False, True]) + def test_jvp_closure_stop_gradient(self, jit): + box = Box(1.0) + + def f(x): + y = 2 * x + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + y, y_dot = jax.jvp(f, (1.0,), (1.0,)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 2.0) + self.assertAllClose(box.get(), 3.0) + + @parameterized.parameters([False, True]) + def test_jvp_arg(self, jit): + def f(box, x): + box.set(box.get() + x) + return x + + if jit: + f = jax.jit(f) + + box = Box(5.0) + box_dot = Box(1.0) + y, y_dot = jax.jvp(f, (box, 2.), (box_dot, 1.)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 1.0) + self.assertAllClose(box.get(), 7.0) + self.assertAllClose(box_dot.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing(self, jit): + box = Box(0.0) + + @jax.custom_vjp + def foo(x): + return x + def foo_fwd(x): + return foo(x), None + def foo_bwd(_, g): + box.set(g) + return g, + foo.defvjp(foo_fwd, foo_bwd) + + def f(x): + x = 2 * x + x = foo(x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(f)(1.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_grad_closrue_stop_gradient(self, jit): + box = Box(0.0) + + def f(x): + y = x * 2 + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + g = jax.grad(f)(1.0) + self.assertAllClose(g, 2.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_scan_basic(self, jit): + box = Box(1.0) + + def double_it_10(): + def body(_, __): + box.set(box.get() * 2) + return None, None + _, _ = jax.lax.scan(body, None, None, length=10) + + if jit: + double_it_10 = jax.jit(double_it_10) + + double_it_10() + self.assertAllClose(box.get(), 1024., check_dtypes=False) + + def test_error_passing_multiple_times_to_jit(self): + + @jax.jit + def f(box1, box2): + ... + + b = Box(1.0) + with self.assertRaisesRegex(ValueError, "a Box instance can't be passed"): + f(b, b) + + def test_error_returning_from_jit(self): + @jax.jit + def f(): + return {'a': Box(1.0)} + + with self.assertRaisesRegex(ValueError, "a Box instance can't be returned"): + f() + + +class ListTest(jtu.JaxTestCase): + + def test_eager(self): + lst = List() + lst.append(1.0) + lst.append(2.0) + lst.append(3.0) + self.assertAllClose(lst.freeze(), [1., 2., 3.]) + + def test_jit_arg(self): + @jax.jit + def f(lst, x): + assert tracing_ok + lst.append(1.0) + lst.append(2.0) + lst.append({'c': x + 3.0}) + + + tracing_ok = True + lst1 = List() + f(lst1, 0) + self.assertAllClose(lst1.freeze(), [1., 2., {'c': 3.}]) + + tracing_ok = False + lst2 = List() + lst2.append(0.) 
+ f(lst2, 1) + self.assertAllClose(lst2.freeze(), [0., 1., 2., {'c': 4.}]) + + def test_jit_closure(self): + lst = List() + + @jax.jit + def f(x): + assert tracing_ok + lst.append(1.0) + lst.append({'a': 2.0}) + lst.append(x + 3.0) + + tracing_ok = True + f(1) + self.assertAllClose(lst._val, [1., {'a': 2.}, 4.]) + + tracing_ok = False + f(2) + self.assertAllClose(lst.freeze(), [1., {'a': 2.}, 4., 1., {'a': 2.0}, 5.0]) + + def test_jit_closure_nested(self): + lst = List() + + @jax.jit + def h(x): + lst.append(x) + + @jax.jit + def k(x): + lst.append(x) + + k(1.0) + k(2.0) + + h(0.0) + self.assertAllClose(lst.freeze(), [0., 1., 2.]) + + @parameterized.parameters([False, True]) + def test_scan_basic(self, jit): + lst = List() + + def f(): + def body(_, x): + lst.append(2 * x) + lst.append(2 * x + 1) + return (), () + (), () = jax.lax.scan(body, (), jnp.arange(3.)) + + if jit: + f = jax.jit(f) + + f() + + self.assertAllClose(lst.freeze(), [0., 1., 2., 3., 4., 5.]) + + @parameterized.parameters([False, True]) + def test_scan_basic_hetero(self, jit): + lst = List() + + def f(): + def body(_, x): + lst.append(2 * x) + lst.append({'a': (2 * x + 1, 2 * x + 2)}) + return (), () + (), () = jax.lax.scan(body, (), jnp.arange(3.)) + + if jit: + f = jax.jit(f) + + f() + + expected = [ + 0., + {'a': (1., 2.)}, + 2., + {'a': (3., 4.)}, + 4., + {'a': (5., 6.)}, + ] + self.assertAllClose(lst.freeze(), expected) + + @parameterized.parameters([False, True]) + def test_freeze_basic(self, jit): + + def f(): + lst = List() + lst.append(1.) + lst.append(2.) + return lst.freeze() + + if jit: + f = jax.jit(f) + + lst = f() + self.assertAllClose(lst, [1., 2.]) + + def test_freeze_nonlocal_list(self): + lst = List() + + @jax.jit + def f(): + lst.freeze() + + with self.assertRaisesRegex(Exception, "can only freeze a local List"): + f() + + def test_freeze_nonlocal_list_nested(self): + @jax.jit + def f(): + lst = List() + + @jax.jit + def g(): + lst.freeze() + + g() + + with self.assertRaisesRegex(Exception, "can only freeze a local List"): + f() + + @parameterized.parameters([False, True]) + def test_append_after_freeze(self, jit): + def f(): + lst = List() + lst.append(1.) + lst.append(2.) + val = lst.freeze() + with self.assertRaisesRegex(Exception, "can't append"): + lst.append(3.) + + if jit: + f = jax.jit(f) + + f() + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing(self, jit): + lst = List() + + @jax.custom_vjp + def foo(x): + return x + def foo_fwd(x): + return foo(x), None + def foo_bwd(_, g): + lst.append(g) + return g, + foo.defvjp(foo_fwd, foo_bwd) + + def f(x): + x = 2 * x + x = foo(x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(f)(1.0) + self.assertAllClose(lst.freeze(), [2.0]) + + def test_error_passing_multiple_times_to_jit(self): + @jax.jit + def f(lst1, lst2): + ... + + b = List([]) + with self.assertRaisesRegex(ValueError, "a List instance can't be passed"): + f(b, b) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 2d8b568893ccd0c64e9e829d7dc253d913a19ebc Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 30 Apr 2025 15:30:27 -0700 Subject: [PATCH 0928/1769] Fix bugs wrt partial-auto when there are multiple levels of nesting. 
The changes are: * If the mesh passed to shard_map doesn't match the context mesh (if present), error out * Whenever we trace a jaxpr in shard_map: * the avals passed via `_shard_aval` should union the current manual axes on mesh with the newly manual axes specified on shard_map's `axis_name` argument * The mesh we enter into when in `use_abstract_mesh` should also union the current manual axes with the newly manual axes PiperOrigin-RevId: 753336465 --- jax/_src/debugging.py | 2 +- jax/_src/lax/parallel.py | 4 +- jax/_src/shard_map.py | 48 ++++++++----- jax/experimental/shard_map.py | 4 +- tests/shard_map_test.py | 123 ++++++++++++++++++++++++---------- 5 files changed, 124 insertions(+), 57 deletions(-) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 18178b4efcb0..29cbb01511e9 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -141,7 +141,7 @@ def f(): return jax.lax.cond(idx == 0, lambda: debug_callback_p.bind(*args, **params), lambda: []) - return jax.shard_map(f, mesh=axis_context.mesh, in_specs=(), out_specs=[])() + return jax.shard_map(f, in_specs=(), out_specs=[])() def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params): axis_context = ctx.module_context.axis_context diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 2c71565d92e4..6df8690f1123 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1925,8 +1925,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): def f(): return axis_index_p.bind(axis_name=axis_name) return mlir.lower_fun( - lambda: [jax.shard_map(f, mesh=axis_context.mesh, check_vma=False, - in_specs=(), out_specs=P())()])(ctx)[0] + lambda: [jax.shard_map(f, check_vma=False, in_specs=(), + out_specs=P())()])(ctx)[0] nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 27f5956f0374..af5b277d1a43 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -128,7 +128,8 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, in_specs: Specs, out_specs: Specs | Callable[[], Specs], - axis_names: Set[AxisName], check_vma: bool): + axis_names: Set[AxisName], check_vma: bool, + _skip_mesh_check: bool = False): if not callable(f): raise TypeError("shard_map requires a callable for its first argument, " f"but got {f} of type {type(f)}.") @@ -140,6 +141,14 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, "The context mesh cannot be empty. 
Either use" " `jax.sharding.use_mesh(mesh)` to enter into a mesh context or pass" " a mesh to `shard_map` via the `mesh` keyword argument.") + else: + ctx_mesh = get_abstract_mesh() + if (not _skip_mesh_check and not ctx_mesh.empty and + mesh.abstract_mesh != ctx_mesh): + raise ValueError( + f"The context mesh {ctx_mesh} should match the mesh passed to" + f" shard_map {mesh}") + if not isinstance(mesh, (Mesh, AbstractMesh)): raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " "`jax.sharding.AbstractMesh` instance for its " @@ -540,7 +549,7 @@ def _as_manual_mesh(mesh, manual_axes: frozenset): if cur_mesh._name_to_type[a] == AxisType.Auto: auto_axes.add(a) else: - assert cur_mesh._name_to_type[a] == AxisType.Explicit + assert cur_mesh._name_to_type[a] == AxisType.Explicit, cur_mesh._name_to_type[a] explicit_axes.add(a) new_axis_types = [] @@ -558,7 +567,7 @@ def _as_manual_mesh(mesh, manual_axes: frozenset): def _extend_axis_env(mesh, manual_axes): return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() - if k in manual_axes]) + if k in manual_axes]) def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, @@ -571,11 +580,11 @@ def _shard_map_staging( source_info = source_info_util.current() to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) in_tracers = map(to_jaxpr_tracer, in_tracers) + inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_names, in_avals) - manual_mesh = _as_manual_mesh(mesh, manual_axes) - with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) @@ -590,7 +599,7 @@ def _shard_map_staging( constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, @@ -629,10 +638,11 @@ def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma, assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) - manual_mesh = _as_manual_mesh(mesh, manual_axes) + manual_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) vma = (frozenset({n for ns in names.values() for n in ns}) if check_vma else frozenset()) + vma = vma | aval.vma return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array @@ -695,7 +705,7 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, def _valid_repeats(mesh: Mesh, vma: Set[AxisName], names: AxisNames) -> bool: - um = set(_unmentioned(mesh, names)) + um = set(_unmentioned(mesh, names)) - set(mesh.manual_axes) if any(u in vma for u in um): return False return True @@ -808,8 +818,10 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, if len(manual_axes) 
< len(mesh.axis_names) else set()) sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, unspecified_dims=unspecified) - manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh) - return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) + manual_proto = pxla.manual_proto( + aval_in, manual_axes | set(mesh.manual_axes), mesh) + return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, + unspecified) def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in, aval_out, x): @@ -824,8 +836,10 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, if len(manual_axes) < len(mesh.axis_names) else set()) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): aval_in = core.physical_aval(aval_in) - manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh) - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) + manual_proto = pxla.manual_proto( + aval_in, manual_axes | set(mesh.manual_axes), mesh) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, + unspecified_dims=unspecified) shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, unspecified) @@ -894,9 +908,9 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh): trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh) in_tracers = map(partial(ShardMapTracer, trace), vmas, args) - manual_mesh = _as_manual_mesh(mesh, manual_axes) + inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes), - use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): + use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): ans = f.call_wrapped(*in_tracers) outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) return outs, out_vma @@ -1318,7 +1332,7 @@ def fwd_out_names_thunk(): args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] with (_extend_axis_env(mesh, manual_axes), - use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))), config._check_vma(check_vma)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() @@ -1483,7 +1497,7 @@ def _partial_eval_jaxpr_custom_rule( jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes'] with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), - use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))): + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) num_out_primals = len(jaxpr_known.outvars) - num_res @@ -1494,7 +1508,7 @@ def _partial_eval_jaxpr_custom_rule( which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] mesh = eqn.params['mesh'] with (_extend_axis_env(mesh, manual_axes), - use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))), config._check_vma(check_vma)): jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * 
num_out_primals + which) jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index afb61159f55b..8f9548bdce1b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -77,6 +77,6 @@ def shard_map( .. _SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html """ axis_names = frozenset(mesh.axis_names) - auto - return jshmap.shard_map( + return jshmap._shard_map( f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, - check_vma=check_rep, axis_names=axis_names) + check_vma=check_rep, axis_names=axis_names, _skip_mesh_check=True) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ada327d936d6..f1496e7e1e18 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -438,19 +438,17 @@ def test_replication_checker_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = np.arange(8 * 8.).reshape(8, 8) - def f(x): - return 2 * x def g(x): - return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(lambda x: x * 2, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P(None, 'y'))(x) with self.assertRaisesRegex(ValueError, 'statically inferred'): jax.jit(g)(x) - def f2(x): - return jax.lax.psum(x, 'x') def g2(x): - return shard_map(f2, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) - _ = jax.jit(g2)(x) # doesn't crash + return shard_map(lambda x: jax.lax.psum(x, 'x'), mesh=mesh, + in_specs=P('x', 'y'), out_specs=P(None, 'y'))(x) + jax.jit(g2)(x) # doesn't crash def test_process_env_traces(self): mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) @@ -2235,23 +2233,83 @@ def f(x): with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"): f(v) - def test_nested_partial_auto(self): + def test_partial_auto_mismatch_mesh_error(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) def g(x): return x * x def h(x): - return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) + return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_vma=False, axis_names=frozenset({'i'}))(x) + return shard_map(h, mesh=mesh, in_specs=P('i', None), + out_specs=P('i', None), check_vma=False, + axis_names=frozenset({'i'}))(x) + with self.assertRaisesRegex( + ValueError, r"context mesh.*should match the mesh passed to shard_map"): + self.assertAllClose(v*v, f(v), check_dtypes=False) + + def test_nested_partial_auto(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*v, f(v), check_dtypes=False) + + def g(x): + return x * x + + def h(x): + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) + + @jax.jit + def f(x): + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x) + + with jax.sharding.use_mesh(mesh): + self.assertAllClose(v*v, f(v), check_dtypes=False) + + @parameterized.named_parameters( + ('0', 'x', 'y', {'x'}, {'x', 'y'}), + ('1', None, 'y', frozenset(), {'y'}), + ('2', 'x', None, {'x'}, {'x'}), + ('3', None, None, frozenset(), frozenset()), + ) + def test_nested_partial_auto_1d(self, dim1, dim2, outer_vma, 
inner_vma): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(dim1, dim2))) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().auto_axes, ('z',)) + self.assertEqual(x.aval.vma, inner_vma) + out = x * x + self.assertEqual(out.aval.vma, inner_vma) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, outer_vma) + out = shard_map(g, in_specs=P(None, dim2), + out_specs=P(None, dim2), axis_names={'y'})(x) + self.assertEqual(out.aval.vma, outer_vma) + return out + + @jax.jit + def f(x): + return shard_map(h, in_specs=P(dim1, None), + out_specs=P(dim1, None), axis_names={'x'})(x) + + with jax.sharding.use_mesh(mesh): + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) def test_grad_nested_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) @@ -2262,22 +2320,19 @@ def g(x): def h(x): # auto: 'j', manual: 'i' - return shard_map(g, mesh=mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): # auto: 'i', 'j' - return shard_map(h, mesh=mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_vma=False, - axis_names=frozenset({'i'}))(x).sum() + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False) + with jax.sharding.use_mesh(mesh): + out = jax.grad(f)(v) + self.assertAllClose(out, v * 2, check_dtypes=False) def test_grad_nested_partial_auto_with_residuals(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) @@ -2286,21 +2341,18 @@ def g(x): return x * x * x def h(x): - return shard_map(g, mesh=mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh=mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_vma=False, - axis_names=frozenset({'i'}))(x).sum() + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False) + with jax.sharding.use_mesh(mesh): + out = jax.grad(f)(v) + self.assertAllClose(out, v * v * 3, check_dtypes=False) def test_axis_size_1_partial_auto(self): mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) @@ -2367,10 +2419,11 @@ def test_partial_auto_axis_index(self): @partial(jax.jit, out_shardings=out_sharding) def f(): return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1), - mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), + in_specs=P('i', None), out_specs=P('i', None), check_vma=False, axis_names=frozenset({'i'}))() - self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1)) + with jax.sharding.use_mesh(mesh): + self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1)) def test_partial_auto_axis_index_degenerated_axis(self): mesh = jtu.create_mesh((1, 2), ('i', 'j')) @@ -2432,11 +2485,11 @@ def g(x): @jax.jit def f(x): - return shard_map(g, - mesh=mesh, 
in_specs=P('i'), out_specs=None, + return shard_map(g, mesh=mesh, in_specs=P('i'), out_specs=None, check_vma=False, axis_names=frozenset({'i'}))(x) - y = f(x) # don't crash + with jax.sharding.use_mesh(mesh): + f(x) # don't crash def test_partial_auto_of_random_keys(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) From c9fab1e0a11788863b7a9bbdadcc550a938fbd2c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 30 Apr 2025 15:36:28 -0700 Subject: [PATCH 0929/1769] Create the `build_jax` flag to control whether `//jax` should be built. 1) `:build_jax=true` (default value): `//jax` will be built. 2) `:build_jax=false`: `//jax` will not be built. It is assumed that the pre-built `jax` wheel is available in the `dist` folder. 3) `:build_jax=wheel`: `jax` wheel will be built as a `py_import` rule attribute. The `py_import` rule unpacks the wheel and provides its content as a `py_library`. PiperOrigin-RevId: 753338643 --- BUILD.bazel | 14 ---- ci/run_bazel_test_cuda_non_rbe.sh | 2 + docs/developer.md | 27 +++---- jax/BUILD | 32 ++++++++ jaxlib/jax.bzl | 121 ++++++++++++++++++++++-------- 5 files changed, 139 insertions(+), 57 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index ddf59c0290d7..4cc8d6d3f63c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -20,7 +20,6 @@ load( "//jaxlib:jax.bzl", "jax_source_package", "jax_wheel", - "py_deps", "pytype_test", "wheel_sources", ) @@ -119,22 +118,10 @@ genrule( tools = ["@bazel_tools//tools/zip:zipper"], ) -COMMON_DEPS = py_deps([ - "absl/testing", - "numpy", - "ml_dtypes", - "scipy", - "opt_einsum", - "hypothesis", - "cloudpickle", - "flatbuffers", -]) - py_import( name = "jax_py_import", wheel = ":jax_wheel", wheel_deps = [":wheel_additives"], - deps = COMMON_DEPS, ) # This target is used to add more sources to the jax wheel. 
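As a usage sketch (hypothetical invocation; the flag combination mirrors the CI script changed below, and the test targets are the ones already used in `docs/developer.md`), tests can run against pre-built wheels from the `dist` folder by disabling both source builds:

```
bazel test --//jax:build_jaxlib=false --//jax:build_jax=false \
    //tests:cpu_tests //tests:backend_independent_tests
```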
@@ -144,7 +131,6 @@ py_import( name = "jax_wheel_with_internal_test_util", wheel = "@pypi_jax//:whl", wheel_deps = [":wheel_additives"], - deps = COMMON_DEPS, ) pytype_test( diff --git a/ci/run_bazel_test_cuda_non_rbe.sh b/ci/run_bazel_test_cuda_non_rbe.sh index 176efd3444c9..ce3a7562fea4 100755 --- a/ci/run_bazel_test_cuda_non_rbe.sh +++ b/ci/run_bazel_test_cuda_non_rbe.sh @@ -76,6 +76,7 @@ bazel test --config=ci_linux_x86_64_cuda \ --config=rbe_cache \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ --//jax:build_jaxlib=false \ + --//jax:build_jax=false \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ --test_output=errors \ @@ -102,6 +103,7 @@ bazel test --config=ci_linux_x86_64_cuda \ --config=rbe_cache \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ --//jax:build_jaxlib=false \ + --//jax:build_jax=false \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ --jobs=8 \ diff --git a/docs/developer.md b/docs/developer.md index 9edeaeac83f8..cfb3f16cf649 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -466,16 +466,16 @@ requirements = { Then you can build and test different combinations of stuff without changing anything in your environment: ``` -# To build with scenario1 dependendencies: +# To build with scenario1 dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 -# To build with scenario2 dependendencies: +# To build with scenario2 dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2 -# To build with default dependendencies: +# To build with default dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13 -# To build with scenario1 dependendencies and custom Python 3.13 interpreter: +# To build with scenario1 dependencies and custom Python 3.13 interpreter: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz" @@ -526,27 +526,28 @@ bazel test //tests:cpu_tests //tests:backend_independent_tests `//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. -To use the preinstalled `jax` and `jaxlib` instead of building them you first -need to make them available in the hermetic Python. To install the specific -versions of `jax` and `jaxlib` within hermetic Python run (using `jax >= 0.4.26` -and `jaxlib >= 0.4.26` as an example): +You need to configure `cuda` to run `gpu` tests: +``` +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only +``` + +To use a preinstalled `jaxlib` instead of building it you first need to +make it available in the hermetic Python. 
To install a specific version of +`jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): ``` -echo -e "\njax >= 0.4.26" >> build/requirements.in echo -e "\njaxlib >= 0.4.26" >> build/requirements.in python build/build.py requirements_update ``` -Alternatively, to install `jax` and `jaxlib` from the local wheels -(assuming Python 3.12): +Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): ``` -echo -e "\n$(realpath jax-0.4.26-py3-none-any.whl)" >> build/requirements.in echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in python build/build.py requirements_update --python_version=3.12 ``` -Once you have `jax` and `jaxlib` installed hermetically, run: +Once you have `jaxlib` installed hermetically, run: ``` bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests diff --git a/jax/BUILD b/jax/BUILD index dae375938fdd..14d205c0e210 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -70,6 +70,38 @@ config_setting( }, ) +# The flag controls whether jax should be built by Bazel. +# If ":build_jax=true", then jax will be built. +# If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel +# is available in the "dist" folder. +# If ":build_jax=wheel", then jax wheel will be built as a py_import rule attribute. +# The py_import rule unpacks the wheel and provides its content as a py_library. +string_flag( + name = "build_jax", + build_setting_default = "true", + values = [ + "true", + "false", + "wheel", + ], +) + +config_setting( + name = "disable_jaxlib_and_jax_build", + flag_values = { + ":build_jaxlib": "false", + ":build_jax": "false", + }, +) + +config_setting( + name = "enable_jaxlib_and_jax_py_import", + flag_values = { + ":build_jaxlib": "wheel", + ":build_jax": "wheel", + }, +) + exports_files([ "LICENSE", "version.py", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index ff4720748b0f..f8c89c0e0401 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -66,18 +66,6 @@ PLATFORM_TAGS_DICT = { ("Windows", "AMD64"): ("win", "amd64"), } -_GPU_PYPI_WHEEL_DEPS = [ - "//:jax_wheel_with_internal_test_util", - "@pypi//jaxlib", - "@pypi//jax_cuda12_plugin", - "@pypi//jax_cuda12_pjrt", -] - -_CPU_PYPI_WHEEL_DEPS = [ - "//:jax_wheel_with_internal_test_util", - "@pypi//jaxlib", -] - # TODO(vam): remove this once zstandard builds against Python >3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"): @@ -167,8 +155,14 @@ ALL_BACKENDS = ["cpu", "gpu", "tpu"] def if_building_jaxlib( if_building, - if_not_building = _GPU_PYPI_WHEEL_DEPS, - if_not_building_for_cpu = _CPU_PYPI_WHEEL_DEPS): + if_not_building = [ + "@pypi//jaxlib", + "@pypi//jax_cuda12_plugin", + "@pypi//jax_cuda12_pjrt", + ], + if_not_building_for_cpu = [ + "@pypi//jaxlib", + ]): """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. @@ -188,40 +182,106 @@ def if_building_jaxlib( }) def _get_test_deps(deps, backend_independent): + """Returns the test deps for the given backend. + + Args: + deps: the full list of test dependencies + backend_independent: whether the test is backend independent + + Returns: + A list of test deps for the given backend. + For CPU builds: + If --//jax:enable_jaxlib_build=true, returns pypi test deps. + If --//jax:enable_jaxlib_build=false, returns jaxlib pypi wheel dep and pypi test deps. 
+      If --//jax:enable_jaxlib_build=wheel, returns jaxlib py_import dep and pypi test deps.
+      For GPU builds:
+      If --//jax:enable_jaxlib_build=true, returns pypi test deps and gpu build deps.
+      If --//jax:enable_jaxlib_build=false, returns jaxlib, jax-cuda-plugin,
+      jax-cuda-pjrt pypi wheel deps and pypi test deps.
+      If --//jax:enable_jaxlib_build=wheel, returns jaxlib,
+      jax-cuda-plugin, jax-cuda-pjrt py_import deps and pypi test deps.
+    """
     gpu_build_deps = [
         "//jaxlib/cuda:gpu_only_test_deps",
         "//jaxlib/rocm:gpu_only_test_deps",
         "//jax_plugins:gpu_plugin_only_test_deps",
     ]
+    pypi_test_deps = [d for d in deps if d.startswith("@pypi//")]
 
     gpu_py_imports = [
-        "//:jax_py_import",
         "//jaxlib/tools:jaxlib_py_import",
         "//jaxlib/tools:jax_cuda_plugin_py_import",
         "//jaxlib/tools:jax_cuda_pjrt_py_import",
-    ]
+    ] + pypi_test_deps
     cpu_py_imports = [
-        "//:jax_py_import",
         "//jaxlib/tools:jaxlib_py_import",
-    ]
+    ] + pypi_test_deps
+    jaxlib_pypi_wheel_deps = [
+        "@pypi//jaxlib",
+    ] + pypi_test_deps
 
     if backend_independent:
-        jaxlib_build_deps = deps
-        gpu_pypi_wheel_deps = _CPU_PYPI_WHEEL_DEPS
+        test_deps = pypi_test_deps
+        gpu_pypi_wheel_deps = jaxlib_pypi_wheel_deps
         gpu_py_import_deps = cpu_py_imports
     else:
-        jaxlib_build_deps = gpu_build_deps + deps
-        gpu_pypi_wheel_deps = _GPU_PYPI_WHEEL_DEPS
+        test_deps = gpu_build_deps + pypi_test_deps
+        gpu_pypi_wheel_deps = jaxlib_pypi_wheel_deps + [
+            "@pypi//jax_cuda12_plugin",
+            "@pypi//jax_cuda12_pjrt",
+        ]
         gpu_py_import_deps = gpu_py_imports
 
     return select({
-        "//jax:enable_jaxlib_build": jaxlib_build_deps,
-        "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": _CPU_PYPI_WHEEL_DEPS,
+        "//jax:enable_jaxlib_build": test_deps,
+        "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": jaxlib_pypi_wheel_deps,
         "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps,
         "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports,
         "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_import_deps,
     })
 
+def _get_jax_test_deps(deps):
+    """Returns the jax build deps, pypi jax wheel dep, or jax py_import dep, depending on the build mode.
+
+    Args:
+      deps: the full list of test dependencies
+
+    Returns:
+      A list of jax test deps.
+
+      If --//jax:enable_jax_build=true, returns jax build deps.
+      If --//jax:enable_jax_build=false, returns jax pypi wheel dep and transitive pypi test deps.
+      If --//jax:enable_jax_build=wheel, returns jax py_import dep and transitive pypi test deps.
+    """
+    jax_build_deps = [d for d in deps if not d.startswith("@pypi//")]
+
+    # A lot of tests don't have explicit dependencies on absl/testing, numpy, etc. But the tests
+    # transitively depend on them via //jax. So we need to make sure that these dependencies are
+    # included in the test when JAX is built from source.
+    # TODO(ybaturina): Add individual dependencies for each test and remove this block.
+    jax_transitive_pypi_test_deps = {k: "true" for k in py_deps([
+        "absl/testing",
+        "numpy",
+        "ml_dtypes",
+        "scipy",
+        "opt_einsum",
+        "hypothesis",
+        "cloudpickle",
+        "flatbuffers",
+    ])}
+
+    # Remove the pypi deps that are already provided by _get_test_deps().
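+    # Otherwise the same @pypi// dep would end up listed twice in the final
+    # test deps.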
+    for d in deps:
+        if d.startswith("@pypi//") and jax_transitive_pypi_test_deps.get(d):
+            jax_transitive_pypi_test_deps.pop(d)
+    return select({
+        "//jax:disable_jaxlib_and_jax_build": ["//:jax_wheel_with_internal_test_util"] +
+                                              jax_transitive_pypi_test_deps.keys(),
+        "//jax:enable_jaxlib_and_jax_py_import": ["//:jax_py_import"] +
+                                                 jax_transitive_pypi_test_deps.keys(),
+        "//conditions:default": jax_build_deps + jax_transitive_pypi_test_deps.keys(),
+    })
+
 # buildifier: disable=function-docstring
 def jax_multiplatform_test(
         name,
@@ -275,10 +335,11 @@ def jax_multiplatform_test(
         srcs = srcs,
         args = test_args,
         env = env,
-        deps = _get_test_deps([
-            "//jax",
-            "//jax:test_util",
-        ] + deps, backend_independent = False),
+        deps = _get_test_deps(deps, backend_independent = False) +
+               _get_jax_test_deps([
+                   "//jax",
+                   "//jax:test_util",
+               ] + deps),
         data = data,
         shard_count = test_shards,
         tags = test_tags,
@@ -571,13 +632,13 @@ def jax_py_test(
     env = dict(env)
     env.setdefault("PYTHONWARNINGS", "error")
     deps = kwargs.get("deps", [])
-    test_deps = _get_test_deps(deps, backend_independent = True)
+    test_deps = _get_test_deps(deps, backend_independent = True) + _get_jax_test_deps(deps)
     kwargs["deps"] = test_deps
     py_test(name = name, env = env, **kwargs)
 
 def pytype_test(name, **kwargs):
     deps = kwargs.get("deps", [])
-    test_deps = _get_test_deps(deps, backend_independent = True)
+    test_deps = _get_test_deps(deps, backend_independent = True) + _get_jax_test_deps(deps)
     kwargs["deps"] = test_deps
     native.py_test(name = name, **kwargs)

From 76f399caf8daf23c0ff69c895f71c4dcafde7a1c Mon Sep 17 00:00:00 2001
From: Emily Fertig
Date: Wed, 30 Apr 2025 15:41:53 -0700
Subject: [PATCH 0930/1769] [JAX] Add `process_indices` to `DeviceList`.

This attribute returns the set of indices of the processes that have
devices in the DeviceList.

PiperOrigin-RevId: 753340580
---
 jax/_src/mesh.py           |  4 ---
 jax/_src/named_sharding.py |  6 +---
 jax/_src/sharding.py       | 12 ++-----
 jaxlib/_jax/__init__.pyi   |  2 ++
 jaxlib/py_device_list.cc   | 70 ++++++++++++++++++++++++++------------
 jaxlib/py_device_list.h    |  6 ++++
 6 files changed, 60 insertions(+), 40 deletions(-)

diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index b96bd2f832dc..d9183f8805d7 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -366,10 +366,6 @@ def empty(self):
   def is_multi_process(self):
     return self.devices.size != len(self.local_devices)
 
-  @functools.cached_property
-  def _process_indices(self):
-    return {d.process_index for d in self._flat_devices_tuple}
-
   @property
   def local_mesh(self):
     return self._local_mesh(xb.process_index())
diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py
index c11e3687ba15..d0a49b6b081c 100644
--- a/jax/_src/named_sharding.py
+++ b/jax/_src/named_sharding.py
@@ -27,7 +27,6 @@
 from jax._src import mesh as mesh_lib
 from jax._src.partition_spec import PartitionSpec
 from jax._src import sharding as JSharding
-from jax._src import xla_bridge as xb
 import numpy as np
 
 Shape = tuple[int, ...]
@@ -192,10 +191,7 @@ def is_fully_addressable(self) -> bool:
     # Speed up `is_fully_addressable` since there is a high chance that the
     # mesh across multiple NamedSharding objects will be the same.
if config.enable_empty_arrays.value: - client = self._internal_device_list[0].client # type: ignore - return (len(self.mesh._process_indices) == 1 and - next(iter(self.mesh._process_indices)) == - xb.process_index(client)) + return self._internal_device_list.is_fully_addressable # type: ignore return not self.mesh.is_multi_process @property diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index 32373d6e6c39..f4b342deafcc 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -36,11 +36,8 @@ def _addressable_devices_indices_map( global_map = sharding.devices_indices_map(global_shape) if sharding.is_fully_addressable: return global_map - if hasattr(sharding, '_internal_device_list'): - return {d: global_map[d] - for d in sharding._internal_device_list.addressable_device_list} - return {d: ind for d, ind in global_map.items() - if d.process_index == d.client.process_index()} + return {d: global_map[d] + for d in sharding._internal_device_list.addressable_device_list} # type: ignore @cache(max_size=4096, trace_context_in_key=False) def common_devices_indices_map( @@ -174,10 +171,7 @@ def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: def _addressable_device_assignment(self) -> XLADeviceAssignment: if self.is_fully_addressable: return self._device_assignment - if hasattr(self, '_internal_device_list'): - return tuple(self._internal_device_list.addressable_device_list) - return tuple(d for d in self._device_assignment - if d.process_index == d.client.process_index()) + return tuple(self._internal_device_list.addressable_device_list) # type: ignore def shard_shape(self, global_shape: Shape) -> Shape: """Returns the shape of the data on each device. diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 49e430069744..895d56e852e1 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -868,6 +868,8 @@ class DeviceList: @property def addressable_device_list(self) -> DeviceList: ... @property + def process_indices(self) -> set[int]: ... + @property def default_memory_kind(self) -> str | None: ... @property def memory_kinds(self) -> tuple[str, ...]: ... diff --git a/jaxlib/py_device_list.cc b/jaxlib/py_device_list.cc index 3bf5480c5363..c80602e14862 100644 --- a/jaxlib/py_device_list.cc +++ b/jaxlib/py_device_list.cc @@ -20,16 +20,19 @@ limitations under the License. #include #include #include +#include #include #include #include #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "nanobind/make_iterator.h" #include "nanobind/nanobind.h" +#include "nanobind/stl/set.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/nb_class_ptr.h" @@ -256,31 +259,28 @@ nb::tuple PyDeviceList::Dump() const { return AsTuple(); } bool PyDeviceList::IsFullyAddressable() { if (!is_fully_addressable_.has_value()) { - is_fully_addressable_ = true; - switch (device_list_.index()) { - case 0: { - const int process_index = py_client_ ? 
py_client_->process_index() : 0; - for (const xla::ifrt::Device* device : - std::get<0>(device_list_)->devices()) { - if (device->ProcessIndex() != process_index) { - is_fully_addressable_ = false; - break; - } + ProcessIndices(); + CHECK(process_indices_.has_value()); + if (process_indices_->size() > 1) { + is_fully_addressable_ = false; + } else { + CHECK_EQ(process_indices_->size(), 1); + int process_index; + switch (device_list_.index()) { + case 0: { + process_index = py_client_ ? py_client_->process_index() : 0; + break; } - break; - } - case 1: { - for (nb::handle device : std::get<1>(device_list_)) { - if (nb::cast(device.attr("process_index")) != - nb::cast(device.attr("client").attr("process_index")())) { - is_fully_addressable_ = false; - break; - } + case 1: { + process_index = + nb::cast(std::get<1>(device_list_)[0].attr("client").attr( + "process_index")()); + break; } - break; + default: + throw nb::value_error("Unrecognized DeviceList type"); } - default: - throw nb::value_error("Unrecognized DeviceList type"); + is_fully_addressable_ = *process_indices_->begin() == process_index; } } return *is_fully_addressable_; @@ -332,6 +332,30 @@ bool PyDeviceList::IsFullyAddressable() { return *self->addressable_device_list_; } +const std::set& PyDeviceList::ProcessIndices() { + if (!process_indices_.has_value()) { + process_indices_ = std::set{}; + switch (device_list_.index()) { + case 0: { + for (const xla::ifrt::Device* device : + std::get<0>(device_list_)->devices()) { + process_indices_->insert(device->ProcessIndex()); + } + break; + } + case 1: { + for (nb::handle device : std::get<1>(device_list_)) { + process_indices_->insert(nb::cast(device.attr("process_index"))); + } + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *process_indices_; +} + void PyDeviceList::PopulateMemoryKindInfo() { if (device_list_.index() == 1) { // Handle Python duck-type devices in a separate function for readability. @@ -448,6 +472,8 @@ void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { nb::lock_self()) .def_prop_ro("addressable_device_list", &PyDeviceList::AddressableDeviceList) + .def_prop_ro("process_indices", &PyDeviceList::ProcessIndices, + nb::lock_self()) // `xla::ValueOrThrowWrapper` does not work with // `def_prop_ro()`. Manually convert an error into an exception. .def_prop_ro("default_memory_kind", diff --git a/jaxlib/py_device_list.h b/jaxlib/py_device_list.h index 5caba6f3dec7..19c646dfc99b 100644 --- a/jaxlib/py_device_list.h +++ b/jaxlib/py_device_list.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include @@ -103,6 +104,9 @@ class PyDeviceList { // Requires the self lock or GIL is held. bool IsFullyAddressable(); + // Requires the self lock or GIL. + const std::set& ProcessIndices(); + // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and // non-empty. xla::nb_class_ptr py_client_; @@ -122,6 +126,8 @@ class PyDeviceList { std::optional is_fully_addressable_; // Populated on demand. Guarded by the object's self lock. std::optional> addressable_device_list_; + // Populated on demand. Guarded by the object's self lock. + std::optional> process_indices_; struct MemoryKindInfo { nanobind::object default_memory_kind; From 6bd9e3deb4310d0b6f0b75fb3a94c40dc35e7bb0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 30 Apr 2025 17:29:56 -0700 Subject: [PATCH 0931/1769] Add `buffer_callback.py` to `jax` wheel content. 
PiperOrigin-RevId: 753374761 --- BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/BUILD.bazel b/BUILD.bazel index 4cc8d6d3f63c..82e3b4ab5c00 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -33,6 +33,7 @@ wheel_sources( "//jax:experimental", "//jax:experimental_colocated_python", "//jax:experimental_sparse", + "//jax:experimental_buffer_callback", "//jax:lax_reference", "//jax:pallas_experimental_gpu_ops", "//jax:pallas_gpu_ops", From 8bcd20d601350feb3ffc08c1bdf3e482e6af013c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 30 Apr 2025 18:32:07 -0700 Subject: [PATCH 0932/1769] Fix pytype failure in experimental.shard_map.shard_map PiperOrigin-RevId: 753388918 --- jax/_src/shard_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index af5b277d1a43..960f8167f417 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -129,7 +129,7 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, in_specs: Specs, out_specs: Specs | Callable[[], Specs], axis_names: Set[AxisName], check_vma: bool, - _skip_mesh_check: bool = False): + _skip_mesh_check: bool = False) -> Callable: if not callable(f): raise TypeError("shard_map requires a callable for its first argument, " f"but got {f} of type {type(f)}.") From c3574fb5b0f3686a335d207024a90b56ca8211b7 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Wed, 30 Apr 2025 20:43:02 -0700 Subject: [PATCH 0933/1769] Make `is_cloud_tpu_older_than` error more verbose (if CPU platform is encountered). PiperOrigin-RevId: 753419532 --- jax/_src/pallas/mosaic/lowering.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f9c85206276a..6a4441d43f17 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -39,6 +39,7 @@ from jax._src import source_info_util from jax._src import state from jax._src import traceback_util +from jax._src import xla_bridge from jax._src.cloud_tpu_init import is_cloud_tpu_older_than from jax._src.export._export import export from jax._src.interpreters import mlir @@ -664,8 +665,10 @@ def lower_jaxpr_to_module( ) -> tuple[Module, tuple[Any, ...]]: # NOTE: We should bump this periodically if is_cloud_tpu_older_than(2025, 1, 10): + platform_version = xla_bridge.get_backend().platform_version raise RuntimeError( - "Pallas TPU requires a libTPU version that's at most a month old" + "Pallas TPU requires a libtpu version that's at most a month old. Found" + f" version string:\n{platform_version}" ) debug_info = jaxpr.debug_info _mosaic_lowering_dynamic_shape_env = None From b43ce494bb7e8df3dbd06ca1a4cb56b2a3a1cc20 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 30 Apr 2025 23:31:36 -0700 Subject: [PATCH 0934/1769] Automated Code Change PiperOrigin-RevId: 753458817 --- jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 66c227d62a74..9a9e594f928e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -15,6 +15,7 @@ limitations under the License. 
 #include
 #include
+#include
 #include
 #include
 #include

From 9e28d739c2546a038072e1bc6916fa2d830d2abd Mon Sep 17 00:00:00 2001
From: Will Froom
Date: Thu, 1 May 2025 00:57:02 -0700
Subject: [PATCH 0935/1769] [JAX:Sparse] Use new sparse kernel in BCSR

PiperOrigin-RevId: 753483539
---
 jax/_src/lib/__init__.py              |  7 ++
 jax/experimental/sparse/_lowerings.py | 31 +++++++++
 jax/experimental/sparse/bcsr.py       | 98 ++++++++++++++++++++++++++-
 jaxlib/tools/build_wheel.py           |  3 +-
 jaxlib/xla_client.py                  |  2 +-
 5 files changed, 137 insertions(+), 4 deletions(-)

diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py
index e3718d57ccd6..5cdcaf400c8a 100644
--- a/jax/_src/lib/__init__.py
+++ b/jax/_src/lib/__init__.py
@@ -115,6 +115,13 @@ def _parse_version(v: str) -> tuple[int, ...]:
 if jaxlib_extension_version >= 334:
   from jaxlib._jax import ffi as ffi  # noqa: F401
 
+if jaxlib_extension_version >= 335:
+  import jaxlib.cpu_sparse as cpu_sparse  # noqa: F401
+
+  has_cpu_sparse = True
+else:
+  has_cpu_sparse = False
+
 import jaxlib.weakref_lru_cache as weakref_lru_cache  # noqa: F401
 
 # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882
diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py
index 76e74d13ed69..c2c25db2c561 100644
--- a/jax/experimental/sparse/_lowerings.py
+++ b/jax/experimental/sparse/_lowerings.py
@@ -25,6 +25,7 @@
 from jax._src import ffi
 from jax._src.interpreters import mlir
 from jax._src.lib import gpu_sparse
+from jax._src.lib import has_cpu_sparse
 import numpy as np
 
 if hasattr(gpu_sparse, "registrations"):
@@ -34,6 +35,16 @@
           name, value, platform=platform, api_version=api_version
      )
 
+if has_cpu_sparse:
+  from jax._src.lib import cpu_sparse
+
+  if hasattr(cpu_sparse, "registrations"):
+    for platform, targets in cpu_sparse.registrations().items():
+      for name, value, api_version in targets:
+        ffi.register_ffi_target(
+            name, value, platform=platform, api_version=api_version
+        )
+
 def _get_module(target_name_prefix: str) -> Any:
   if target_name_prefix == "cu":
     return gpu_sparse._cusparse
@@ -272,6 +283,26 @@ def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape,
     partial(_csr_spmm_gpu_lowering, target_name_prefix='hip'),
     platform='rocm')
 
+
+if has_cpu_sparse:
+  def _csr_spmm_cpu_lowering(ctx, data, outer_indices, inner_indices, rhs):
+    rule = ffi.ffi_lowering("cpu_csr_sparse_dense_ffi")
+    return rule(ctx, data, outer_indices, inner_indices, rhs)
+
+
+  # _csr_spmm_cpu_lowering can handle both matrix-matrix and matrix-vector
+  # multiplication.
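+  # For that reason the same lowering is registered below for both the matvec
+  # (csr_spmv_p) and matmat (csr_spmm_p) primitives.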
+ mlir.register_lowering( + csr_spmv_p, + _csr_spmm_cpu_lowering, + platform="cpu", + ) + mlir.register_lowering( + csr_spmm_p, + _csr_spmm_cpu_lowering, + platform="cpu", + ) + def coo_todense_gpu_lowering(ctx, data, row, col, *, shape, target_name_prefix): data_aval, row_aval, _ = ctx.avals_in nnz, = data_aval.shape diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index a7f7deb2a93f..c7b056c5adfa 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -31,8 +31,8 @@ from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo from jax.experimental.sparse.util import ( - nfold_vmap, _count_stored_elements, - _csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape) + nfold_vmap, _count_stored_elements, _csr_to_coo, + SparseEfficiencyWarning, CuSparseEfficiencyWarning, SparseInfo, Shape) from jax._src.util import split_list, safe_zip from jax._src import api_util @@ -695,6 +695,95 @@ def _bcsr_dot_general_gpu_lowering( shape=lhs_spinfo.shape, transpose=False, target_name_prefix=target_name_prefix) + +def _bcsr_dot_general_cpu_lowering( + # csr_matvec_lowering, csr_matmat_lowering, + ctx, + lhs_data, + lhs_indices, + lhs_indptr, + rhs, + *, + dimension_numbers, + preferred_element_type, + lhs_spinfo: SparseInfo, +): + + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in + props = _validate_bcsr( + lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, lhs_spinfo.shape + ) + + use_default_lowering = False + dtype = lhs_data_aval.dtype + if lhs_batch or rhs_batch: + # TODO(willfroom): Add support for batched matrices. + use_default_lowering = True + elif lhs_data_aval.dtype != rhs_aval.dtype: + use_default_lowering = True + elif ( + preferred_element_type is not None + and preferred_element_type != lhs_data_aval.dtype + ): + use_default_lowering = True + elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]: + # only matmat / matvec supported + use_default_lowering = True + elif props.n_batch or props.n_dense: + # batch and dense dimensions in BCSR not supported + use_default_lowering = True + elif list(lhs_contract) != [1] or list(rhs_contract) != [0]: + # TODO(willfroom): Add support for non-canonical dots. + use_default_lowering = True + elif lhs_indices_aval.dtype != lhs_indptr_aval.dtype: + warnings.warn( + "bcsr_dot_general cpu lowering not available, " + f" {lhs_indices_aval.dtype=} and {lhs_indptr_aval.dtype=} do not match." + " Falling back to default implementation.", + SparseEfficiencyWarning, + ) + use_default_lowering = True + elif lhs_indices_aval.dtype not in [np.int32, np.int64]: + use_default_lowering = True + warnings.warn( + "bcsr_dot_general cpu lowering not available for" + f" {lhs_indices_aval.dtype=}. Falling back to default implementation.", + SparseEfficiencyWarning, + ) + elif dtype not in [ + np.int32, + np.int64, + np.float32, + np.float64, + np.complex64, + np.complex128, + ]: + # This would be supported if not for the dtype. + warnings.warn( + "bcsr_dot_general cpu lowering not available " + f"for {dtype=}. 
Falling back to default implementation.",
+        SparseEfficiencyWarning,
+    )
+    use_default_lowering = True
+
+  if use_default_lowering:
+    return _bcsr_dot_general_default_lowering(
+        ctx,
+        lhs_data,
+        lhs_indices,
+        lhs_indptr,
+        rhs,
+        dimension_numbers=dimension_numbers,
+        preferred_element_type=preferred_element_type,
+        lhs_spinfo=lhs_spinfo,
+    )
+
+  return _lowerings._csr_spmm_cpu_lowering(
+      ctx, lhs_data, lhs_indptr, lhs_indices, rhs
+  )
+
+
 _bcsr_dot_general_default_lowering = mlir.lower_fun(
     _bcsr_dot_general_impl, multiple_results=False)
 mlir.register_lowering(
@@ -713,6 +802,11 @@ def _bcsr_dot_general_gpu_lowering(
     platform='rocm')
 
 
+if _lowerings.has_cpu_sparse:
+  mlir.register_lowering(
+      bcsr_dot_general_p, _bcsr_dot_general_cpu_lowering, platform="cpu"
+  )
+
 #----------------------------------------------------------------------
 # BCOO functions that maybe should be primitives?
 
diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py
index ec306ff741f4..56d3bc27c488 100644
--- a/jaxlib/tools/build_wheel.py
+++ b/jaxlib/tools/build_wheel.py
@@ -190,6 +190,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
       dst_dir=jaxlib_dir,
       src_files=[
           f"{source_file_prefix}jaxlib/cpu_feature_guard.{pyext}",
+          f"{source_file_prefix}jaxlib/cpu_sparse.py",
           f"{source_file_prefix}jaxlib/utils.{pyext}",
           f"{source_file_prefix}jaxlib/jax_common.dll"
           if build_utils.is_windows()
@@ -221,6 +222,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
       dst_dir=jaxlib_dir / "cpu",
       src_files=[
           f"{source_file_prefix}jaxlib/cpu/_lapack.{pyext}",
+          f"{source_file_prefix}jaxlib/cpu/_sparse.{pyext}",
       ],
   )
 
@@ -328,7 +330,6 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
       ],
   )
 
-
 mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs"
 copy_files(
     dst_dir=mlir_libs_dir,
diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py
index 39c9922adbff..def2318ae75b 100644
--- a/jaxlib/xla_client.py
+++ b/jaxlib/xla_client.py
@@ -43,7 +43,7 @@
 
 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
-_version = 334
+_version = 335
 
 # An internal increasing version number for protecting jaxlib code against
 # ifrt changes.

From 68e5aa80f8006c739cd634ebc2e847b4517d8783 Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Thu, 1 May 2025 03:13:36 -0700
Subject: [PATCH 0936/1769] [pallas:mgpu] inline_mgpu() handles ref transform types.

The transforms passed via the RefType are replayed on the raw ref of the
input, and we check that the same transforms were produced.

PiperOrigin-RevId: 753519740
---
 jax/_src/pallas/mosaic_gpu/primitives.py | 98 +++++++++++++++++-------
 tests/pallas/mosaic_gpu_test.py          | 67 ++++++++++++----
 2 files changed, 123 insertions(+), 42 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 7e1c7254ddfb..d0330e60e961 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -19,6 +19,7 @@
 from collections.abc import Sequence, Callable
 import dataclasses
 import enum
+import functools
 import itertools
 import math
 from typing import Any, Literal
@@ -1476,7 +1477,17 @@ class GPUShapeDtypeStruct:
 
 @dataclasses.dataclass(frozen=True)
-class RefType:
-  ...
+class RefType:
+  transforms: tuple[gpu_core.MemoryRefTransform, ...] = ()
+
+
+def _undo_transforms(
+    raw_ref: pallas_core.AbstractMemoryRef,
+    memory_transforms: Sequence[gpu_core.MemoryRefTransform],
+):
+  """Extract the `Transform`s that reverse the `MemoryRefTransform`s."""
+  tmp_ref = state_types.TransformedRef(raw_ref, transforms=())
+  tmp_ref = functools.reduce(lambda r, t: t.undo(r), reversed(memory_transforms), tmp_ref)
+  return tmp_ref.transforms
 
 
 def inline_mgpu(arg_types=(), return_type=None):
@@ -1540,30 +1551,30 @@ def wrapper(*args):
 
     # Strip the transforms from the refs since they will be recorded in
     # the types.
-    raw_refs_flat_args = []
+    ref_transforms = []
+    raw_flat_args = []
     for a, t in zip(flat_args, flat_arg_types):
-      def traced_ty(ty):
-        return isinstance(a, jax_core.Tracer) and isinstance(a.aval, ty)
-
-      if isinstance(t, ParameterizedLayout) and traced_ty(jax_core.ShapedArray):
-        raw_refs_flat_args.append(a)
-      elif isinstance(t, RefType) and traced_ty(_Ref):
-        ref, transforms = a, ()
-        if isinstance(a, state_types.TransformedRef):
-          ref, transforms = ref.ref, ref.transforms
-
-        raw_refs_flat_args.append(ref)
-        if transforms:
-          raise NotImplementedError("Transformed refs (or types) are not supported.")
+      if isinstance(a, state_types.TransformedRef) and isinstance(t, RefType):
+        raw_flat_args.append(a.ref)
+        ref_transforms.append(a.transforms)
+      elif isinstance(aval := jax_core.get_aval(a), jax_core.ShapedArray) and isinstance(t, (ParameterizedLayout, Layout)):
+        raw_flat_args.append(a)
+        ref_transforms.append(None)
+      elif isinstance(aval, state.AbstractRef) and isinstance(t, RefType):
+        raw_flat_args.append(a)
+        ref_transforms.append(())
       else:
        raise ValueError(f"Mismatched type: {a, t}")
 
+    flat_ref_transforms, pytree_ref_transforms = jax.tree.flatten(ref_transforms)
     flat_ret = inline_mgpu_p.bind(
-        *flat_args,
-        args_treedef=treedef,
+        *raw_flat_args,
+        *flat_ref_transforms,
+        flat_arg_types=flat_arg_types,
         flat_ret_ty=flat_ret_ty,
         pytree_ret_ty=pytree_ret_ty,
-        flat_arg_types=flat_arg_types,
+        pytree_args=treedef,
+        pytree_ref_transforms=pytree_ref_transforms,
        mgpu_fn=f,
     )
     return jax.tree.unflatten(pytree_ret_ty, flat_ret)
@@ -1574,18 +1585,20 @@ def traced_ty(ty):
 
 @inline_mgpu_p.def_effectful_abstract_eval
 def _inline_mgpu_abstract_eval(
-    *flat_args,
-    args_treedef,
+    *flat_args_and_transforms,
     flat_arg_types,
     flat_ret_ty,
+    pytree_args,
+    pytree_ref_transforms,
     pytree_ret_ty,
     mgpu_fn,
 ):
-  del args_treedef, flat_arg_types, pytree_ret_ty, mgpu_fn  # Unused.
+  del flat_arg_types, pytree_ret_ty, pytree_ref_transforms, mgpu_fn  # Unused.
   aval_return = tuple(
      jax_core.ShapedArray(x.shape, x.dtype) for x in flat_ret_ty
  )
   # TODO(cperivol): Let the user set the effects.
+  flat_args = flat_args_and_transforms[:pytree_args.num_leaves]
   return aval_return, {
       gpu_core._wgmma_pipeline_effect,
       gpu_core._memory_effect,
@@ -1609,8 +1622,18 @@ def _type_check_mgpu(v, ty):
       pass
     case (GPUShapeDtypeStruct(), mgpu.FragmentedArray()):
       mlir_dtype = mgpu_utils.dtype_to_ir_type(ty.dtype)
-      if v.mlir_dtype != mlir_dtype or ty.shape != v.shape or v.layout != ty.layout.to_mgpu():
-        raise ValueError(f"Array type mismatch at {v} != {ty}.")
+      if v.mlir_dtype != mlir_dtype:
+        raise ValueError(
+            f"Array dtype mismatch: expected {mlir_dtype} got {v.mlir_dtype}."
+        )
+      if ty.shape != v.shape:
+        raise ValueError(
+            f"Array shape mismatch: expected {ty.shape} got {v.shape}."
+        )
+      if v.layout != ty.layout.to_mgpu():
+        raise ValueError(
+            f"Array layout mismatch: expected {ty.layout.to_mgpu()} got {v.layout}."
+ ) case (Layout() , mgpu.FragmentedArray()) | (ParameterizedLayout(), mgpu.FragmentedArray()): if ty.to_mgpu() != v.layout: raise ValueError(f"Unexpected layout for {v} (expected: {ty})") @@ -1621,17 +1644,40 @@ def _type_check_mgpu(v, ty): @lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Lane) def _inline_mgpu_lowering_rule( ctx: lowering.LoweringRuleContext, - *flat_args, + *flat_args_and_transforms, mgpu_fn: Callable[..., Any], flat_arg_types, flat_ret_ty, + pytree_args, + pytree_ref_transforms, pytree_ret_ty, - args_treedef, ): + flat_args = flat_args_and_transforms[:pytree_args.num_leaves] + flat_arg_avals = ctx.avals_in[:pytree_args.num_leaves] + ref_transforms = pytree_ref_transforms.unflatten(flat_args_and_transforms[pytree_args.num_leaves:]) for a, t in zip(flat_args, flat_arg_types): _type_check_mgpu(a, t) - args = jax.tree.unflatten(args_treedef, flat_args) + flat_transformed = [] + for a, aval, t, transforms in zip( + flat_args, flat_arg_avals, flat_arg_types, ref_transforms, strict=True + ): + if not isinstance(t, RefType): + flat_transformed.append(a) + assert transforms is None + continue + assert isinstance(aval, pallas_core.AbstractMemoryRef) + a, user_transforms = lowering._handle_transforms(a, transforms, handle_transposes=False) + # Transforms that do not originate from a MemoryRefTransform are + # applied implicitly (eg by emit-pipeline) and therefore we do not + # expect the user to pass them to the type. The transforms not + # passed by the user here will be discharged. + ty_transforms = _undo_transforms(aval, t.transforms) + if ty_transforms != tuple(user_transforms): + raise ValueError(f"Transform mismatch: got {user_transforms}, expected {ty_transforms}") + flat_transformed.append(a) + + args = jax.tree.unflatten(pytree_args, flat_transformed) ret = mgpu_fn(ctx.launch_ctx, *args) ret_leaves, ret_tree = jax.tree.flatten( ret, is_leaf=lambda x: isinstance(x, mgpu.FragmentedArray) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 349a1dffb274..eac5e735294a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -374,14 +374,32 @@ def kernel(o_ref): ) def test_inline_mgpu(self): - dtype = jnp.bfloat16 + dtype = jnp.dtype(jnp.bfloat16) self.skip_if_wg_semantics() + shape = (128, 128) + tile = (64, 128 // dtype.itemsize) + tiled_shape = mgpu.tile_shape(shape, tile) + tiled_shape_t = list(tiled_shape) + tiled_shape_t[0], tiled_shape_t[1] = tiled_shape_t[1], tiled_shape_t[0] + + key = jax.random.key(0) + x = (jax.random.uniform(key, (2, *shape)) * 42).astype(dtype) + + transforms = ( + plgpu.TilingTransform(tile), + plgpu.TransposeTransform((0, 2, 1, 3, 4)), + plgpu.SwizzleTransform(128), + ) @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + out_shape=jax.ShapeDtypeStruct(shape, dtype), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ - plgpu.SMEM((128, 128), dtype), + plgpu.SMEM( + x.shape, + dtype, + transforms=transforms, + ), plgpu.Barrier(num_arrivals=1), ], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -389,31 +407,48 @@ def test_inline_mgpu(self): def kernel(x_ref, o_ref, smem_ref, barrier): plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier) plgpu.barrier_wait(barrier) - layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) + # Add an indexer at the end. 
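+      # (The SMEM scratch mirrors x, which has a leading dimension of 2, so
+      # slicing off index 0 replays the transforms through an indexer.)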
+ sliced_smem_ref = smem_ref.at[0] @plgpu.inline_mgpu( - arg_types=(plgpu.RefType(),), + arg_types=(plgpu.RefType(( + plgpu.TilingTransform(tile), + plgpu.TransposeTransform((1, 0, 2, 3)), + plgpu.SwizzleTransform(128), + )),), return_type=plgpu.GPUShapeDtypeStruct( - (128, 128), dtype, layout=layout + shape, dtype, layout=plgpu.Layout.WGMMA ), ) def foo(ctx, smem_ref): del ctx - x = mgpu.FragmentedArray.load_strided(smem_ref) + assert smem_ref.type.shape == tiled_shape_t, (smem_ref.type, tiled_shape_t) + x = mgpu.FragmentedArray.load_tiled(smem_ref, swizzle=128) y = mgpu.FragmentedArray.splat( mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout ) return (x + y) - arr = foo(smem_ref) - @plgpu.inline_mgpu(arg_types=(layout, plgpu.RefType())) - def store(ctx, arr, o_ref): - del ctx - arr.store_untiled(o_ref) - store(arr, o_ref) + arr = foo(sliced_smem_ref) + @plgpu.inline_mgpu(arg_types=(plgpu.Layout.WGMMA, plgpu.RefType(transforms), plgpu.RefType())) + def store(ctx, arr, smem_ref, o_ref): + sliced_smem_ref = mgpu.memref_slice(smem_ref, (0,)) + arr.store_tiled(sliced_smem_ref, swizzle=128) + mgpu.commit_shared() + ctx.async_copy( + src_ref=sliced_smem_ref, + dst_ref=o_ref, + swizzle=128, + gmem_transform=( + mgpu.TileTransform(tile), + mgpu.TransposeTransform((1, 0, 2, 3)), + ), + ) + ctx.await_async_copy(0) - key = jax.random.key(0) - x = (jax.random.uniform(key, (128, 128)) * 42).astype(dtype) - np.testing.assert_array_equal(kernel(x), x + 1) + # This time we slice inside the inline_mgpu body. + store(arr, smem_ref, o_ref) + + np.testing.assert_array_equal(kernel(x), x[0] + 1) @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_smem_to_gmem(self, indexer): From 2ba2e61e84e864b16583dfc8762f986b35af5b26 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 1 May 2025 04:18:07 -0700 Subject: [PATCH 0937/1769] [pallas] Handle no-op broadcasts in broadcast_in_dim mosaic lowering. This is currently handled in the forwarding rule for broadcast_in_dim, so this case is never seen by Pallas, but I'm making some changes to tracing in https://github.com/jax-ml/jax/pull/28396 which mean that sometimes no-op broadcasts will hit this lowering rule. It seems reasonable to add support for this case here. PiperOrigin-RevId: 753536365 --- jax/_src/pallas/mosaic/lowering.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 6a4441d43f17..a920f0363470 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1902,6 +1902,8 @@ def _broadcast_in_dim_lowering_rule( del sharding (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out + if aval_in.shape == shape: + return val if jnp.issubdtype(aval_in.dtype, jnp.bool_): # Direct broadcasts for bools are not supported in Mosaic due to booleans From c47bc83e579b5e8c5f75123bfdaf7b0b32413983 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 1 May 2025 05:53:58 -0700 Subject: [PATCH 0938/1769] Skip buffer_callback tests on TPU. PiperOrigin-RevId: 753558159 --- tests/buffer_callback_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py index 647b5407567b..de5870cdf47b 100644 --- a/tests/buffer_callback_test.py +++ b/tests/buffer_callback_test.py @@ -33,6 +33,8 @@ def setUp(self): self.skipTest( "Requires a version of jaxlib with buffer callback support." 
) + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU.") @parameterized.parameters(jtu.dtypes.all) @jtu.run_on_devices("cpu") From 0e1c6a4be0ad8654da7bccb2e7d3196e9b4454f7 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 1 May 2025 06:03:05 -0700 Subject: [PATCH 0939/1769] [Mosaic GPU] Add support for passing in tcgen05.mma LHS in TMEM PiperOrigin-RevId: 753560393 --- jax/experimental/mosaic/gpu/tcgen05.py | 108 +++++++++++++++++-------- tests/mosaic/gpu_test.py | 68 ++++++++++++++++ 2 files changed, 142 insertions(+), 34 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 5afce275cca5..31d37cc44513 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -75,7 +75,7 @@ def create_instr_descriptor( def mma( d: TMEMRef, - a: ir.Value, + a: ir.Value | TMEMRef, b: ir.Value, *, a_swizzle: int = 128, @@ -95,12 +95,22 @@ def mma( num_cta = 2 if collective else 1 # Step 1. Establish the shape and element type of the operation. - if not ir.MemRefType.isinstance(a.type): - raise ValueError(f"A must be a memref, got {a.type}") if not ir.MemRefType.isinstance(b.type): raise ValueError(f"B must be a memref, got: {b.type}") (k, n), element_type = mma_utils.tiled_memref_shape(b) - (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + if isinstance(a, TMEMRef): + m, k2 = a.shape + element_type2 = a.dtype + if collective: + raise NotImplementedError("Collective not supported for TMEMRef") + if a.layout != (expected_layout := _infer_tmem_layout(a.shape, collective, packing=2)): + raise ValueError( + f"A layout mismatch: expected {expected_layout}, got {a.layout}" + ) + else: + if not ir.MemRefType.isinstance(a.type): + raise ValueError(f"A must be a memref, got {a.type}") + (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) if k != k2: raise ValueError( "MMA requires A and B to have the same contraction dimension (K)," @@ -132,6 +142,8 @@ def mma( "MMA with element type f16 only supports accumulators of type f32" f" or f16, but got: {d.dtype}" ) + else: + raise NotImplementedError(f"Unsupported element type: {element_type}") # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, # instructions must be issued in groups of the same width as the swizzle. @@ -153,22 +165,27 @@ def mma( m_groups = m // m_group_elems k_groups = k // k_group_elems n_groups = n // n_group_elems - # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA. - wgmma_element_type = ( + # TODO(apaszke): Require users to bitcast input refs to tf32 before MMA. + mma_element_type = ( ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type ) # Step 3. Compute the operand descriptors. 
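+  # When A lives in TMEM there is no SMEM descriptor to build: the TMEM
+  # address is passed to the MMA instruction directly (see a_mk below).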
- ( - (a_desc_base, a_k_instr_stride), - (a_m_group_stride, a_k_group_stride), - a_fastest, - ) = mma_utils.create_descriptor( - a, - swizzle=swizzle, - group_size=(m_group_elems, k_group_elems), - logical_k_major=False, - ) + if not isinstance(a, TMEMRef): + ( + (a_desc_base, a_k_instr_stride), + (a_m_group_stride, a_k_group_stride), + a_fastest, + ) = mma_utils.create_descriptor( + a, + swizzle=swizzle, + group_size=(m_group_elems, k_group_elems), + logical_k_major=False, + ) + else: + a_fastest = mma_utils.Dim.K + a_k_instr_stride = None + a_m_group_stride = a_k_group_stride = a_desc_base = None ( (b_desc_base, b_k_instr_stride), (b_n_group_stride, b_k_group_stride), @@ -184,8 +201,11 @@ def mma( true = arith.constant(ir.IntegerType.get_signless(1), 1) n_collective_group_elems = n_group_elems * num_cta for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): - a_offset = mi * a_m_group_stride + ki * a_k_group_stride - a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) + if isinstance(a, TMEMRef): + a_mk = a.slice(slice(None), utils.ds(ki * k_group_elems, k_group_elems)).address + else: + a_offset = mi * a_m_group_stride + ki * a_k_group_stride + a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) b_offset = ni * b_n_group_stride + ki * b_k_group_stride b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64)) if m_groups != 1: @@ -207,17 +227,17 @@ def mma( b_k_stride=b_k_instr_stride, accumulate=acc, swizzle=swizzle, - element_type=wgmma_element_type, + element_type=mma_element_type, ) def _do_mma( d_addr: ir.Value, - a_desc: ir.Value, + a_desc_or_addr: ir.Value, # TMEM address if a_k_stride is None b_desc: ir.Value, a_transpose: bool, b_transpose: bool, - a_k_stride: int, + a_k_stride: int | None, b_k_stride: int, m: int, n: int, @@ -228,10 +248,13 @@ def _do_mma( collective: bool, ): i1 = ir.IntegerType.get_signless(1) + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) - kn_tiling = swizzle // utils.bytewidth(element_type) - instr_k = 32 // utils.bytewidth(element_type) - if a_k_stride % 16 or b_k_stride % 16: + elem_bytewidth = utils.bytewidth(element_type) + kn_tiling = swizzle // elem_bytewidth + instr_k = 32 // elem_bytewidth + packing = 4 // elem_bytewidth + if (a_k_stride is not None and a_k_stride % 16) or b_k_stride % 16: raise ValueError if ir.F16Type.isinstance(element_type) or ir.BF16Type.isinstance(element_type): @@ -243,16 +266,27 @@ def _do_mma( i_desc = create_instr_descriptor( m * num_cta, n * num_cta, d_type, element_type, a_transpose, b_transpose ) + a_in_tmem = a_k_stride is None + a_ptx = "[$1]" if a_in_tmem else "$1" + a_ptx_constraint = "r" if a_in_tmem else "l" + assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64) for _ in range(kn_tiling // instr_k): llvm.inline_asm( ir.Type.parse("!llvm.void"), - [d_addr, a_desc, b_desc, i_desc, accumulate], - f"tcgen05.mma.cta_group::{num_cta}.kind::{kind} [$0], $1, $2, $3, $4;", - "r,l,l,r,b", + [d_addr, a_desc_or_addr, b_desc, i_desc, accumulate], + f"tcgen05.mma.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, $3, $4;", + f"r,{a_ptx_constraint},l,r,b", has_side_effects=True, ) accumulate = arith.constant(i1, 1) - a_desc = arith.addi(a_desc, arith.constant(i64, a_k_stride >> 4)) + if not a_in_tmem: + a_desc_or_addr = arith.addi( + a_desc_or_addr, arith.constant(i64, a_k_stride >> 4) + ) + else: + a_desc_or_addr = arith.addi( + a_desc_or_addr, arith.constant(i32, instr_k // packing) + ) 
b_desc = arith.addi(b_desc, arith.constant(i64, b_k_stride >> 4)) @@ -543,14 +577,18 @@ def from_alloc( return cls(tmem_addr, shape, dtype, layout) def slice(self, *idxs): + i32 = ir.IntegerType.get_signless(32) base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) if any(is_squeezed): raise ValueError("TMEM can only be sliced, not indexed") - if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): - raise NotImplementedError( - "Slicing only implemented for refs with standard layout, got:" - f" {self.layout}" - ) + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if r == TMEM_ROWS: + pass + case _: + raise NotImplementedError( + "Slicing only implemented for refs with standard layout, got:" + f" {self.layout}" + ) if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS: raise NotImplementedError("TMEM cannot be sliced along rows") if slice_shape[1] % 8: @@ -559,7 +597,9 @@ def slice(self, *idxs): ) col_idx = base_idx[1] if not isinstance(col_idx, ir.Value): - col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx) + col_idx = arith.constant(i32, col_idx) + if packing != 1: + col_idx = arith.divui(col_idx, arith.constant(i32, packing)) return TMEMRef( address=arith.addi(self.address, col_idx), shape=tuple(slice_shape), diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 9278d71cde40..8232d61df550 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1069,6 +1069,74 @@ def kernel(ctx, lhs, rhs, out, scratch): rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + @parameterized.product( + in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 + out_jax_dtype=(jnp.float16, jnp.float32,), + m=(128,), # TODO(apaszke): 64, 192, 256 + n=(64, 128, 256), # TODO(apaszke): 192, other non-power-of-2 + ) + def test_mma_lhs_tmem(self, m, n, in_jax_dtype, out_jax_dtype): + swizzle = 128 + k_steps = 2 # Reducing to 1 can be helpful while debugging. 
+ if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: + self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + lhs_tiling = rhs_tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, out, scratch): + lhs_smem, rhs_smem, barriers, acc, lhs_tmem = scratch + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(lhs_tiling), + barrier=barriers[0], + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(rhs_tiling), + barrier=barriers[1], + ) + barriers[0].wait() + barriers[1].wait() + lhs_tmem[:] = fa.FragmentedArray.load_tiled( + lhs_smem, swizzle, layout=tcgen05.LAYOUT + ) + tcgen05.commit_tmem() + with mgpu.single_thread(): + tcgen05.mma( + acc, lhs_tmem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, + ) + tcgen05.commit_arrive(barriers[2]) + barriers[2].wait(for_tensor_core=True) + acc[:].store_untiled(out, optimized=False) + + x_shape = (m, k) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y_shape = (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype), + mgpu.TMABarrier(3), + mgpu.TMEM((128, n), out_jax_dtype), + mgpu.TMEM((128, k), in_jax_dtype, packing=2), + ] + z = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = x32 @ y32 + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), From c0da649dedd5edb2a900050da13556023e9912fe Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 1 May 2025 06:09:34 -0700 Subject: [PATCH 0940/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/15565b8da6d85e9faec669cb22878a0e44cca4ee. PiperOrigin-RevId: 753562330 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fbd81d22a568..b409fd17957a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "7bf1170008bdd80e25f38daecc9413961b36d9c3" -XLA_SHA256 = "dcb6d27aabd985090e9df8b776a76f3881362cd502e3c580ffc6d03ee4524fe0" +XLA_COMMIT = "15565b8da6d85e9faec669cb22878a0e44cca4ee" +XLA_SHA256 = "360d260d4da982da900d783d3a2705a5fe9133f0e130c0436485bf4477d60ff0" def repo(): tf_http_archive( From 37b87b52ad1dd32106769eeebea6f9eaa00ed933 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 1 May 2025 06:55:41 -0700 Subject: [PATCH 0941/1769] Fix ASAN/MSAN/TSAN failures for buffer callback. I don't expect that we actually need the device ordinal to be defined on the execution context, but we can add it back in statically (it's already decoded in the handler) if necessary. 
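For illustration, here is the `ExecutionContext` surface a buffer callback can still rely on after this change — a sketch condensed from the updated test in `tests/buffer_callback_test.py` below:

```
def callback(ctx, out, arg):
  # Still available on the execution context:
  assert ctx.stage == buffer_callback.ExecutionStage.EXECUTE
  _ = ctx.stream  # platform stream handle, exposed as an int
  # ctx.device_ordinal is intentionally gone; the ordinal is already
  # decoded in the handler and can be re-exposed statically if needed.
```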
PiperOrigin-RevId: 753573598 --- jaxlib/_jax/ffi.pyi | 1 - jaxlib/ffi.cc | 13 ------------- jaxlib/ffi.h | 1 - tests/buffer_callback_test.py | 1 - 4 files changed, 16 deletions(-) diff --git a/jaxlib/_jax/ffi.pyi b/jaxlib/_jax/ffi.pyi index efaad46329f9..b92575e77c96 100644 --- a/jaxlib/_jax/ffi.pyi +++ b/jaxlib/_jax/ffi.pyi @@ -44,5 +44,4 @@ class ExecutionStage(enum.IntEnum): class ExecutionContext: def stage(self) -> ExecutionStage: ... - def device_ordinal(self) -> int: ... def stream(self) -> int: ... diff --git a/jaxlib/ffi.cc b/jaxlib/ffi.cc index 1bf9a5a3150a..790a9876dd10 100644 --- a/jaxlib/ffi.cc +++ b/jaxlib/ffi.cc @@ -152,17 +152,6 @@ absl::StatusOr PyFfiContext::stream() const { return absl::bit_cast(args.stream); } -absl::StatusOr PyFfiContext::device_ordinal() const { - XLA_FFI_DeviceOrdinal_Get_Args args; - args.struct_size = XLA_FFI_DeviceOrdinal_Get_Args_STRUCT_SIZE; - args.extension_start = nullptr; - args.device_ordinal = 0; - if (XLA_FFI_Error* error = api_->XLA_FFI_DeviceOrdinal_Get(&args)) { - return ffi::TakeStatus(error); - } - return args.device_ordinal; -} - PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, void* data, ffi::Span dimensions, ffi::DataType element_type, bool writeable) @@ -377,8 +366,6 @@ void BuildFfiSubmodule(nb::module_& m) { nb::class_ context(ffi_module, "ExecutionContext"); context.def_prop_ro("stage", &PyFfiContext::stage); - context.def_prop_ro("device_ordinal", - xla::ValueOrThrowWrapper(&PyFfiContext::device_ordinal)); context.def_prop_ro("stream", xla::ValueOrThrowWrapper(&PyFfiContext::stream)); } diff --git a/jaxlib/ffi.h b/jaxlib/ffi.h index 393ce79ddabc..e2045a0f513c 100644 --- a/jaxlib/ffi.h +++ b/jaxlib/ffi.h @@ -54,7 +54,6 @@ class PyFfiContext { XLA_FFI_ExecutionStage stage); Stage stage() const; absl::StatusOr stream() const; - absl::StatusOr device_ordinal() const; private: const XLA_FFI_Api* api_; diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py index de5870cdf47b..06f0bcf7c3bc 100644 --- a/tests/buffer_callback_test.py +++ b/tests/buffer_callback_test.py @@ -46,7 +46,6 @@ def callback(ctx, out, arg): ctx.stream self.assertEqual(ctx.stage, buffer_callback.ExecutionStage.EXECUTE) - self.assertEqual(ctx.device_ordinal, 0) self.assertEqual(arg.shape, shape) self.assertEqual(arg.dtype, dtype) self.assertEqual(out.shape, shape) From e189cd4650a36796629aa5ee97ddfa95f27033e7 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 1 May 2025 08:26:23 -0700 Subject: [PATCH 0942/1769] [Mosaic GPU] Add a debug_print for TMEM refs This is quite helpful while trying to debug the load/store routines. 
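A minimal usage sketch (assuming a kernel body that already holds a `TMEMRef` named `tmem`, committed with `tcgen05.commit_tmem()`, as in the new test below):

```
# Print the first 8 columns of the TMEM ref, one entry per lane:
tmem.slice(slice(None), slice(0, 8))._debug_print()
# Emits lines of the form "[lane, col]: value", e.g. "[1, 2]: 130.0000".
```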
PiperOrigin-RevId: 753599963 --- jax/experimental/mosaic/gpu/tcgen05.py | 24 ++++++++++++ jax/experimental/mosaic/gpu/utils.py | 11 ++++++ tests/mosaic/gpu_test.py | 51 +++++++++++++++++++++++++- 3 files changed, 85 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 31d37cc44513..e1f12a0f95cd 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -697,6 +697,30 @@ def __setitem__(self, idxs, value): f"Stores only implemented for refs with standard layout, got: {self.layout}" ) + def _debug_print(self): + i32 = ir.IntegerType.get_signless(32) + num_cols = self.layout.cols_in_shape(self.shape) + lane = arith.remui(utils.thread_idx(), arith.constant(i32, utils.WARPGROUP_SIZE)) + for c in range(num_cols): + val = llvm.inline_asm( + i32, + [arith.addi(self.address, arith.constant(i32, c))], + "tcgen05.ld.sync.aligned.32x32b.x1.b32 {$0}, [$1];", + "=r,r", + ) + dtype_bitwidth = utils.bitwidth(self.dtype) + full_packing = 32 // dtype_bitwidth + if self.layout.packing == 1: + if dtype_bitwidth < 32: + val = arith.trunci(ir.IntegerType.get_signless(dtype_bitwidth), val) + val = utils.bitcast(val, self.dtype) + elif self.layout.packing == full_packing: + val = utils.bitcast(val, ir.VectorType.get((full_packing,), self.dtype)) + else: + raise NotImplementedError(f"Unsupported packing: {self.layout.packing}") + # TODO(apaszke): Make this print logical, not physical location. + utils.debug_print(f"[{{}}, {c}]: {{}}", lane, val, uniform=False) + def _transfer_32xcols(base_addr: ir.Value, cols: int, packing: int): i32 = ir.IntegerType.get_signless(32) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 7957a92a3e0c..844f5d34c32a 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1280,6 +1280,11 @@ def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value): def bitcast(x: ir.Value, new_type: ir.Type): if x.type == new_type: return x + if (x_bw := bitwidth(x.type)) != (new_bw := bitwidth(new_type)): + raise ValueError( + f"Can't bitcast {x.type} (of bitwidth {x_bw}) to {new_type} (of" + f" bitwidth {new_bw})" + ) if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): new_type = ir.IntegerType(new_type) x_ty = ir.VectorType(x.type) @@ -1299,6 +1304,12 @@ def bitcast(x: ir.Value, new_type: ir.Type): if bitwidth(x_ty) != bitwidth(new_ty): raise ValueError(f"Can't bitcast {x.type} to {new_type}") return vector.bitcast(new_type, x) + if ir.IntegerType.isinstance(x.type) and ir.FloatType.isinstance(new_type): + return arith.bitcast(new_type, x) + if ir.FloatType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): + return arith.bitcast(new_type, x) + if ir.FloatType.isinstance(x.type) and ir.FloatType.isinstance(new_type): + return arith.bitcast(new_type, x) raise ValueError(f"Can't bitcast {x.type} to {new_type}") diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 8232d61df550..2431edb0cf28 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -41,9 +41,10 @@ import jax.numpy as jnp import numpy as np try: - import jax._src.lib.mosaic_gpu # noqa: F401 + import jax._src.lib.mosaic_gpu as mosaic_gpu_lib # noqa: F401 HAS_MOSAIC_GPU = True except ImportError: + mosaic_gpu_lib = None HAS_MOSAIC_GPU = False class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok @@ -241,6 +242,16 @@ def setUp(self): 
self.enter_context(self.context) self.enter_context(ir.Location.unknown()) + @contextlib.contextmanager + def capture_stdout(self): + if mosaic_gpu_lib is None: + raise ValueError("Running tests but missing Mosaic GPU extension") + with jtu.capture_stdout() as stdout: + yield stdout + # We need to cudaDeviceSynchronize to make sure printfs are flushed. + mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() + + class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest): @@ -940,6 +951,44 @@ def kernel(ctx, input, output, scratch): )(x) np.testing.assert_array_equal(x, y) + @parameterized.parameters([ + (jnp.float32, 1, "130.0000"), + (jnp.float16, 1, "130.0000"), + (jnp.float16, 2, "[132.000000,133.000000]"), + ]) + @jtu.thread_unsafe_test() + def test_tmem_debug_print(self, jax_dtype, packing, expected): + swizzle = 128 + in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + tiling = (8, swizzle_elems) + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy( + src_ref=input, + dst_ref=smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barrier, + ) + barrier.wait() + tmem[:] = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT) + tcgen05.commit_tmem() + tmem.slice(slice(None), slice(0, 8))._debug_print() + + x = jnp.arange(128 * 128, dtype=jax_dtype).reshape(128, 128) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype, packing=packing), + ] + with self.capture_stdout() as stdout: + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x).block_until_ready() + self.assertIn("[1, 2]: " + expected, stdout()) + @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), From fc7d1b4ec1d5b680f906358f83ec8a13212fd4ff Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Thu, 1 May 2025 16:22:52 +0000 Subject: [PATCH 0943/1769] add numpy as test deps --- jax/BUILD | 2 +- jax/_src/clusters/k8s_cluster.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index d2320e1e4456..8e0af54c6885 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1143,7 +1143,7 @@ pytype_strict_library( ":traceback_util", ":util", "//jax/_src/lib", - ], + ] + py_deps("numpy"), ) # Public JAX libraries below this point. diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index bf3cc2bc5702..b40b39cade34 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -17,6 +17,7 @@ from contextlib import contextmanager from functools import cache from itertools import chain +import logging import numpy as np import os import socket @@ -25,7 +26,6 @@ import warnings from jax._src import clusters -import logging logger = logging.getLogger(__name__) From 9107c6344a6fb3fd16b1595d38d0a4bc0d166855 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 1 May 2025 10:16:31 -0700 Subject: [PATCH 0944/1769] Rename with_dll_constraint to with_layout_constraint for the upcoming Layout API rename! 
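For downstream code the migration is mechanical — a sketch, with `custom_dll` standing in for an existing `DeviceLocalLayout`:

```
# Before:
from jax.experimental.layout import with_dll_constraint
y = with_dll_constraint(y, custom_dll)

# After:
from jax.experimental.layout import with_layout_constraint
y = with_layout_constraint(y, custom_dll)
```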
PiperOrigin-RevId: 753636449 --- jax/_src/pjit.py | 38 +++++++++---------- .../jax2tf/tests/primitives_test.py | 2 +- jax/experimental/layout.py | 2 +- tests/layout_test.py | 6 +-- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 03c755b451e7..cc2aa874bea8 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -3120,49 +3120,49 @@ def use_explicit_axes(*axes): with mesh_lib.use_abstract_mesh(new_mesh): yield -# -------------------- with_dll_constraint -------------------- +# -------------------- with_layout_constraint -------------------- -def with_dll_constraint(x, layouts): +def with_layout_constraint(x, layouts): x_flat, tree = tree_flatten(x) - layouts_flat = tuple(flatten_axes("with_dll_constraint layouts", tree, + layouts_flat = tuple(flatten_axes("with_layout_constraint layouts", tree, layouts)) if any(not isinstance(l, DeviceLocalLayout) for l in layouts_flat): raise ValueError( - 'layouts passed to `with_dll_constraint` must be of type' + 'layouts passed to `with_layout_constraint` must be of type' f' `DeviceLocalLayout`. Got {[type(l) for l in layouts_flat]}') check_aval_layout_compatibility( layouts_flat, x_flat, ("",) * len(layouts_flat), - "with_dll_constraint arguments") - outs = [dll_constraint_p.bind(xf, layout=l) + "with_layout_constraint arguments") + outs = [layout_constraint_p.bind(xf, layout=l) for xf, l in zip(x_flat, layouts_flat)] return tree_unflatten(tree, outs) -dll_constraint_p = core.Primitive('dll_constraint') -dll_constraint_p.def_abstract_eval(lambda x, **_: x) -ad.deflinear2(dll_constraint_p, - lambda ct, _, **params: (dll_constraint_p.bind(ct, **params),)) +layout_constraint_p = core.Primitive('layout_constraint') +layout_constraint_p.def_abstract_eval(lambda x, **_: x) +ad.deflinear2(layout_constraint_p, + lambda ct, _, **params: (layout_constraint_p.bind(ct, **params),)) -def _dll_constraint_impl(x, *, layout): +def _layout_constraint_impl(x, *, layout): if not isinstance(x, xc.ArrayImpl): raise ValueError( - 'with_dll_constraint in eager mode can only be applied to' + 'with_layout_constraint in eager mode can only be applied to' f' jax.Arrays. 
Got {type(x)}') if x.layout.device_local_layout == layout: # type: ignore return x return api.jit(_identity_fn, out_shardings=Layout(layout, x.sharding))(x) -dll_constraint_p.def_impl(_dll_constraint_impl) +layout_constraint_p.def_impl(_layout_constraint_impl) -def _dll_constraint_hlo_lowering(ctx, x_node, *, layout): +def _layout_constraint_hlo_lowering(ctx, x_node, *, layout): aval, = ctx.avals_in out_aval, = ctx.avals_out return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, aval)] -mlir.register_lowering(dll_constraint_p, - _dll_constraint_hlo_lowering) +mlir.register_lowering(layout_constraint_p, + _layout_constraint_hlo_lowering) -def _dll_constraint_batcher(axis_data, vals_in, dims_in, layout): +def _layout_constraint_batcher(axis_data, vals_in, dims_in, layout): raise NotImplementedError -batching.fancy_primitive_batchers[dll_constraint_p] = _dll_constraint_batcher -batching.skippable_batchers[dll_constraint_p] = lambda _: () +batching.fancy_primitive_batchers[layout_constraint_p] = _layout_constraint_batcher +batching.skippable_batchers[layout_constraint_p] = lambda _: () # -------------------- helpers -------------------- diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 74e4ddc8136d..f6ce4435e6a2 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -178,7 +178,7 @@ def test_primitive_coverage(self): continue if p.name == "sharding_constraint": continue - if p.name == "dll_constraint": + if p.name == "layout_constraint": continue if p.name == "mesh_cast": continue diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index aa114a2803e8..e98cfbc68104 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -17,5 +17,5 @@ Layout as Layout, ) from jax._src.pjit import ( - with_dll_constraint as with_dll_constraint, + with_layout_constraint as with_layout_constraint, ) diff --git a/tests/layout_test.py b/tests/layout_test.py index ae10013a5f60..c15816d7794a 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -23,7 +23,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax.experimental.layout import (with_dll_constraint, Layout, +from jax.experimental.layout import (with_layout_constraint, Layout, DeviceLocalLayout as DLL) from jax.experimental.compute_on import compute_on @@ -745,7 +745,7 @@ def f(x): self.assertArraysEqual(out, np_inp * 2) self.assertEqual(out.layout, out_layout) - def test_with_dll_constraint(self): + def test_with_layout_constraint(self): if not jtu.test_device_matches(['tpu']): self.skipTest('Only works for TPU') mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -761,7 +761,7 @@ def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, # the layout of `y` would be the transpose of `arr`. 
- y = with_dll_constraint(y, custom_dll) + y = with_layout_constraint(y, custom_dll) return y * 2 f(arr) # doesn't crash From dcdc25ba8e5b23da59e5365b2ccf3c85d408828c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 1 May 2025 10:23:17 -0700 Subject: [PATCH 0945/1769] [mosaic_gpu] Use `jtu` helpers instead of `get_sass` PiperOrigin-RevId: 753639002 --- tests/mosaic/gpu_test.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 2431edb0cf28..19048b3a4307 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -20,7 +20,6 @@ import itertools import math import operator -import os import re import unittest @@ -88,20 +87,6 @@ def mlir_sum(elems): return total -@contextlib.contextmanager -def get_sass(): - prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) - os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" - try: - with jtu.capture_stdout() as output: - yield output - finally: - if prev_dump is not None: - os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump - else: - del os.environ["MOSAIC_GPU_DUMP_SASS"] - - def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): index = ir.IndexType.get() thread_id = gpu.thread_id(gpu.Dimension.x) @@ -2430,6 +2415,7 @@ def kernel(ctx, dst, _): num_col_tiles=[1, 2, 3], row_tiling=[8, 64], ) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. def test_copy_tiled(self, dtype, swizzle, num_col_tiles, row_tiling): mlir_dtype = utils.dtype_to_ir_type(dtype) bw = bytewidth(mlir_dtype) @@ -2455,7 +2441,7 @@ def kernel(ctx, in_, out, smems): .transpose(0, 2, 1, 3) ) - with get_sass() as sass: + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), expected, expected, [expected, expected, mgpu.TMABarrier()], @@ -2554,6 +2540,7 @@ def kernel(ctx, in_, out, smems): (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5), (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2), ) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. def test_upcast_to_wgmma( self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg ): @@ -2597,7 +2584,7 @@ def tile(x, tiling): f = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()], ) - with get_sass() as sass: + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: yt_kernel = f(xt) np.testing.assert_array_equal(yt_kernel, yt) self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg) From 48001a24cb74f311b51d8bcf0891437069db6b95 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 1 May 2025 10:27:53 -0700 Subject: [PATCH 0946/1769] Make literal's dtype print with an empty shape so that it's consistent. So `1.0:f32` -> `1.0:f32[]` PiperOrigin-RevId: 753640777 --- docs/aot.md | 2 +- jax/_src/api.py | 2 +- jax/_src/core.py | 2 +- tests/api_test.py | 8 ++++---- tests/pjit_test.py | 4 ++-- tests/state_test.py | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/aot.md b/docs/aot.md index a4422f88e6b0..1870f8c55093 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -50,7 +50,7 @@ some other features along the way. An example: >>> # Print the specialized, staged-out representation (as Jaxpr IR) >>> print(traced.jaxpr) { lambda ; a:i32[] b:i32[]. 
let - c:i32[] = mul 2:i32 a + c:i32[] = mul 2:i32[] a d:i32[] = add c b in (d,) } diff --git a/jax/_src/api.py b/jax/_src/api.py index 4a729b223ff1..61d685ef40f1 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2396,7 +2396,7 @@ def make_jaxpr( c:f32[] = sin a _:f32[] = sin b d:f32[] = cos b - e:f32[] = mul 1.0:f32 d + e:f32[] = mul 1.0:f32[] d f:f32[] = neg e g:f32[] = mul f c in (g,) } diff --git a/jax/_src/core.py b/jax/_src/core.py index 4bc64b85b81d..63d7eb35d6f9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -475,7 +475,7 @@ def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): del context # unused dtype = getattr(self.aval, 'dtype', None) if print_dtype and dtype: - return f'{self.val}:{dtypes.short_dtype_name(dtype)}' + return f'{self.val}:{self.aval.str_short(short_dtypes=True)}' else: return f'{self.val}' diff --git a/tests/api_test.py b/tests/api_test.py index 1149cf78e626..610719518a03 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6679,7 +6679,7 @@ def fun(x): return (x, 1., np.zeros(1, dtype=jnp.float32)) dtype = "f64" if config.enable_x64.value else "f32" - expected = f"{{ lambda a:f32[1]; b:f32[]. let in (b, 1.0:{dtype}, a) }}" + expected = f"{{ lambda a:f32[1]; b:f32[]. let in (b, 1.0:{dtype}[], a) }}" jaxpr = api.make_jaxpr(fun)(jnp.float32(0.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) @@ -6691,9 +6691,9 @@ def f(x): x + 2., lambda xf: xf - x) expected = """{ lambda ; a:f32[]. let - b:bool[] = ge a 0.0:f32 - c:f32[] = add a 1.0:f32 - d:f32[] = add a 2.0:f32 + b:bool[] = ge a 0.0:f32[] + c:f32[] = add a 1.0:f32[] + d:f32[] = add a 2.0:f32[] e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b f:f32[] = cond[ branches=( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8d56f96d684d..443880fc502b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1274,7 +1274,7 @@ def test_pretty_print_with_constant_pjit_arg(self): b:f32[1] = pjit[ name= jaxpr={ lambda ; a:f32[1] c:f32[]. let b:f32[1] = mul a c in (b,) } - ] a 1.0:f32 + ] a 1.0:f32[] in (b,) } """).strip(), ) @@ -1308,7 +1308,7 @@ def test_pretty_print_with_literal_outvar(self): { lambda ; a:f32[1]. let b:i32[] c:f32[1] = pjit[ name= - jaxpr={ lambda ; a:f32[1]. let in (2:i32, a) } + jaxpr={ lambda ; a:f32[1]. let in (2:i32[], a) } ] a in (b, c) } """).strip(), diff --git a/tests/state_test.py b/tests/state_test.py index a8d6e88659a6..9bbc68101443 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -361,7 +361,7 @@ def body(x_ref): return [] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("a[] <- 2:i32", jaxpr.pretty_print(use_color=False)) + self.assertIn("a[] <- 2:i32[]", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x_ref[:, 0] = val @@ -377,7 +377,7 @@ def body(x_ref): return [x] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("b:i32[], a[] <- a[], 2:i32", jaxpr.pretty_print(use_color=False)) + self.assertIn("b:i32[], a[] <- a[], 2:i32[]", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x = ref_swap(x_ref, (slice(None), 0), val) From b96817ab1bb26f7b5ddb97a540a01b668bf7ad82 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Thu, 1 May 2025 13:36:34 -0400 Subject: [PATCH 0947/1769] [Mosaic GPU] Deallocate TMEM using a single warp See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-ops-execution-granularity. 
--- jax/experimental/mosaic/gpu/core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 86be9825dcef..7b0fdb688708 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -486,8 +486,9 @@ def _launch( if tmem_allocs: gpu.barrier() # Make sure everyone is done before we release TMEM. - for alloc in tmem_allocs: - alloc.dealloc() + with utils.when(is_init_warp): + for alloc in tmem_allocs: + alloc.dealloc() if prof is not None: prof.finalize(grid=grid, block=block) gpu.terminator() From 42d76fb3145816030f7691dc40e57709698cf08f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 30 Apr 2025 10:08:43 -0700 Subject: [PATCH 0948/1769] jax._src.util: improve type annotations --- jax/_src/checkify.py | 2 +- jax/_src/export/_export.py | 2 +- jax/_src/export/shape_poly.py | 2 +- jax/_src/nn/functions.py | 2 +- jax/_src/numpy/reductions.py | 2 +- jax/_src/pallas/fuser/block_spec.py | 4 +-- jax/_src/pallas/mosaic/pipeline.py | 2 +- jax/_src/sharding_impls.py | 3 ++- jax/_src/state/primitives.py | 2 +- jax/_src/util.py | 41 ++++++++++++++++++----------- 10 files changed, 37 insertions(+), 25 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index c8caffeb7877..5a6456762db7 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -976,7 +976,7 @@ def shard_map_error_check( in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_names[i], v) with (jshmap._extend_axis_env(mesh, manual_axes), - mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), # type: ignore[arg-type] config._check_vma(check_vma)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 1f58c7e0def6..d5a328bb8e05 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1408,7 +1408,7 @@ def pp_arg_dim(dim_idx: int | None) -> str: # it would be ambiguous whether we should continue tracing with a result # of type `f32[c]` or `f32[d]`. 
shape_constraints.check_statically(synthetic_eval) - exported_dim_values = [synthetic_eval.evaluate(solution[var]) + exported_dim_values = [synthetic_eval.evaluate(solution[var]) # type: ignore[arg-type] for var in exported_dim_vars] out_avals = tuple( core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 405592cadd2b..31371cf345a1 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -2021,7 +2021,7 @@ def compute_dim_vars_from_arg_shapes( } synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) - return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) + return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) # type: ignore[arg-type] def _solve_dim_equations( eqns: list[_DimEquation], diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index f16deb41a69b..3f7647758003 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -661,7 +661,7 @@ def _one_hot(x: Array, num_classes: int, *, "The error arose in jax.nn.one_hot argument `num_classes`.") dtype = dtypes.canonicalize_dtype(dtype) try: - output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) + output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) # type: ignore[arg-type] except TypeError: axis_size = lax.axis_size(axis) if num_classes != axis_size: diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index d2ae80925597..9cb543d5d869 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -988,7 +988,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, if axis is None: pass - elif isinstance(axis, tuple): + elif isinstance(axis, Sequence): axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis) else: axis = _canonicalize_axis(axis, a_ndim) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 2795b1e52f6a..9e7e18c590dd 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -745,7 +745,7 @@ def new_index_map(i, *args): bcast_dim_block_shape = 1 if isinstance(block_spec.block_shape[i], pallas_core.Element): bcast_dim_block_shape = pallas_core.Element(1) - new_block_shape = util.tuple_update( + new_block_shape = util.tuple_update( # pytype: disable=wrong-arg-types block_spec.block_shape, i, bcast_dim_block_shape ) return pallas_core.BlockSpec( @@ -1128,7 +1128,7 @@ def _concatenate_rule( # Handle special case if the block contains all of the concatenated # array. 
new_shapes = [ - util.tuple_update( + util.tuple_update( # pytype: disable=wrong-arg-types block_spec.block_shape, dimension, aval.shape[dimension] # pytype: disable=attribute-error ) for aval in ctx.avals_in diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index dd83dab3c3c5..df7be297c9e8 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -1171,7 +1171,7 @@ def _partition_grid( offsets = jax_util.tuple_update( (0,) * len(grid), partition_dimension, grid_offset ) - return new_grid, offsets + return new_grid, offsets # type: ignore[return-value] def emit_pipeline( diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index ec5fd4c512bf..6e86911e63b0 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1252,7 +1252,8 @@ def logical_sharding(logical_shape, dtype, phys_sharding) -> jsharding.Sharding: @util.cache() def create_mesh_pspec_sharding( - mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, + mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, + pspec: PartitionSpec | None, memory_kind: str | None = None) -> NamedSharding: if pspec is None: pspec = PartitionSpec() diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 1237da57f217..dbcc67df18cb 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -537,7 +537,7 @@ def _batch_indexer( idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, bcast_dims) else: - idx = batching.moveaxis(idx, dim, 0) + idx = batching.moveaxis(idx, dim, 0) # type: ignore[arg-type] new_indices.append(idx) else: if ref_dim is not batching.not_mapped: diff --git a/jax/_src/util.py b/jax/_src/util.py index e551c654b005..d009111f1f34 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -21,7 +21,7 @@ import itertools as it import logging import operator -from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast) +from typing import (Any, Generic, SupportsIndex, TypeVar, overload, TYPE_CHECKING, cast) import weakref import numpy as np @@ -34,6 +34,9 @@ Seq = Sequence +# TODO(jakevdp): fix import cycles and import Array. 
+Array = Any + T = TypeVar("T") T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -137,13 +140,15 @@ def unzip3(xyzs: Iterable[tuple[T1, T2, T3]] zs.append(z) return tuple(xs), tuple(ys), tuple(zs) -def subvals(lst, replace): +def subvals(lst: Sequence[T], replace: Iterable[tuple[int, T]]) -> tuple[T, ...]: + """Substitute values within a list.""" lst = list(lst) for i, v in replace: lst[i] = v return tuple(lst) def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: + """Split list into sublists of the specified sizes.""" args = list(args) lists = [] for n in ns: @@ -153,6 +158,7 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: return lists def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: + """Split list into sublists of the specified sizes.""" args = list(args) assert sum(ns) == len(args) lists = [] @@ -162,8 +168,9 @@ def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: return lists def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]: + """Partition a list into two based on a mask.""" assert len(bs) == len(l) - lists = [], [] # type: ignore + lists: tuple[list[T], list[T]] = ([], []) for b, x in zip(bs, l): lists[b].append(x) return lists @@ -172,6 +179,7 @@ def merge_lists(bs: Sequence[bool], l0: Sequence[T1], l1: Sequence[T2] ) -> list[T1 | T2]: + """Merge the elements of two lists based on a mask.""" assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0) i0, i1 = iter(l0), iter(l1) out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs] @@ -200,7 +208,7 @@ def subs_list2( assert next(base_, sentinel) is sentinel return out -def split_dict(dct, names): +def split_dict(dct: dict[T1, T2], names: Sequence[T1]) -> list[T2]: dct = dict(dct) lst = [dct.pop(name) for name in names] assert not dct @@ -244,7 +252,10 @@ def curry(f): toposort = partial(jaxlib_utils.topological_sort, "parents") -def split_merge(predicate, xs): +def split_merge( + predicate: Callable[[T], bool], + xs: Sequence[T] +) -> tuple[list[T], list[T], Callable[[Sequence[T], Sequence[T]], list[T]]]: sides = list(map(predicate, xs)) lhs = [x for x, s in zip(xs, sides) if s] rhs = [x for x, s in zip(xs, sides) if not s] @@ -349,10 +360,10 @@ def __hash__(self): def __eq__(self, other): return self.val == other.val -def wrap_name(name, transform_name): +def wrap_name(name: str, transform_name: str) -> str: return transform_name + '(' + name + ')' -def fun_name(fun: Callable): +def fun_name(fun: Callable) -> str: name = getattr(fun, "__name__", None) if name is not None: return name @@ -361,7 +372,7 @@ def fun_name(fun: Callable): else: return "" -def fun_qual_name(fun: Callable): +def fun_qual_name(fun: Callable) -> str: qual_name = getattr(fun, "__qualname__", None) if qual_name is not None: return qual_name @@ -369,7 +380,7 @@ def fun_qual_name(fun: Callable): return fun_qual_name(fun.func) return fun_name(fun) -def canonicalize_axis(axis, num_dims) -> int: +def canonicalize_axis(axis: SupportsIndex, num_dims: int) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" axis = operator.index(axis) if not -num_dims <= axis < num_dims: @@ -378,7 +389,7 @@ def canonicalize_axis(axis, num_dims) -> int: axis = axis + num_dims return axis -def moveaxis(x, src, dst): +def moveaxis(x: Array, src: int | Sequence[int], dst: int | Sequence[int]) -> Array: if src == dst: return x if isinstance(src, int): @@ -392,7 +403,7 @@ def moveaxis(x, src, dst): perm.insert(d, s) return 
x.transpose(perm) -def ceil_of_ratio(x, y): +def ceil_of_ratio(x: int, y: int) -> int: return -(-x // y) @@ -429,15 +440,15 @@ def wrapper(fun: T) -> T: def assert_unreachable(x): raise AssertionError(f"Unhandled case: {type(x).__name__}") -def tuple_insert(t, idx, val): +def tuple_insert(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]: assert 0 <= idx <= len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx:] -def tuple_delete(t, idx): +def tuple_delete(t: tuple[T, ...], idx: int) -> tuple[T, ...]: assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + t[idx + 1:] -def tuple_update(t, idx, val): +def tuple_update(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]: assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] @@ -578,7 +589,7 @@ def __eq__(self, other): return self.x == other.x if self.hash is not None else self.x is other.x -def _original_func(f): +def _original_func(f: Callable) -> Callable: if isinstance(f, property): return cast(property, f).fget elif isinstance(f, functools.cached_property): From 4768e0038f0fb2cef227c97c71c332372e3ebb06 Mon Sep 17 00:00:00 2001 From: Luke Tsekouras Date: Thu, 1 May 2025 12:25:04 -0700 Subject: [PATCH 0949/1769] Fix bug in all_leaves when is_leaf is specified This issue occurs when some of the leaves have custom `__eq__` methods defined on them, which either result in errors when compared to some other types (see http://cl/753579906), or result in return values that cannot have their truthiness evaluated, e.g.: ``` import jax.tree_util as jtu import numpy as np jtu.all_leaves( [[np.asarray([1, 2])]], is_leaf=lambda x: jtu.all_leaves([x]), ) ``` ``` ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() ``` This fix avoids equality issues by using the `is` operator instead of `==`, and introduces tests for the case where `is_leaf` is provided. 
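With the identity-based comparison (see the `tree_util` diff below), the snippet above should now return `False` — the nested list entry is not itself a leaf — instead of raising:

```
jtu.all_leaves(
    [[np.asarray([1, 2])]],
    is_leaf=lambda x: jtu.all_leaves([x]),
)  # -> False, rather than ValueError
```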
PiperOrigin-RevId: 753684035 --- jax/_src/tree_util.py | 7 +++++-- tests/tree_util_test.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 7c7ca96b1e5c..e2e97c90f120 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -202,8 +202,11 @@ def all_leaves(iterable: Iterable[Any], if is_leaf is None: return pytree.all_leaves(default_registry, iterable) else: - lst = list(iterable) - return lst == tree_leaves(lst, is_leaf) + items = list(iterable) + leaves = tree_leaves(items, is_leaf) + return len(leaves) == len(items) and all( + item is leaf for item, leaf in zip(items, leaves, strict=True) + ) _Children = TypeVar("_Children", bound=Iterable[Any]) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index e5e649d43d8a..0df811d9da28 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -552,6 +552,18 @@ def testAllLeavesWithTrees(self, tree): def testAllLeavesWithLeaves(self, leaf): self.assertTrue(tree_util.all_leaves([leaf])) + @parameterized.parameters(*TREES) + def testAllLeavesWithTreesAndCustomIsLeaf(self, tree): + def is_leaf(t): + return tree_util.all_leaves([t]) + self.assertFalse(tree_util.all_leaves([tree], is_leaf=is_leaf)) + + @parameterized.parameters(*LEAVES) + def testAllLeavesWithLeavesAndCustomIsLeaf(self, leaf): + def is_leaf(t): + return tree_util.all_leaves([t]) + self.assertTrue(tree_util.all_leaves([leaf], is_leaf=is_leaf)) + @parameterized.parameters(*TREES) def testCompose(self, tree): treedef = tree_util.tree_structure(tree) From 879eb41c3695a5232111c7150e226402c7a4aa0f Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Thu, 1 May 2025 12:58:00 -0700 Subject: [PATCH 0950/1769] Propagate function name when recording elapsed time event. Useful for filtering events by function name or differentiating between events. PiperOrigin-RevId: 753695215 --- jax/_src/dispatch.py | 8 ++++++-- tests/compilation_cache_test.py | 3 ++- tests/monitoring_test.py | 11 ++++++----- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index fd96883a53ef..64991c2fd3e2 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -191,8 +191,12 @@ def __exit__(self, exc_type, exc_value, traceback): logger.log(log_priority, self.fmt.format( fun_name=self.fun_name, elapsed_time=elapsed_time)) if self.event is not None: - record_event_duration_secs(self.event, elapsed_time) - record_event_time_span(self.event, self.start_time, end_time) + record_event_duration_secs( + self.event, elapsed_time, fun_name=self.fun_name + ) + record_event_time_span( + self.event, self.start_time, end_time, fun_name=self.fun_name + ) log_elapsed_time = LogElapsedTimeContextManager diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index d9f8cdddf0f1..1ba6b1221a88 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -358,7 +358,8 @@ def test_cache_saving_metric(self): config.persistent_cache_min_entry_size_bytes(0), ): durations = Counter() # Map metric name to time duration. 
-      def append_metric_duration(metric, duration):
+      def append_metric_duration(metric, duration, **kwargs):
+        del kwargs
         durations[metric] += duration

       with jtu.register_event_duration_listener(append_metric_duration):
diff --git a/tests/monitoring_test.py b/tests/monitoring_test.py
index 89c7148a2a42..5ef5c5d928ba 100644
--- a/tests/monitoring_test.py
+++ b/tests/monitoring_test.py
@@ -49,7 +49,8 @@ def increment_event_counter(event):
   def test_record_event_durations(self):
     durations = {}  # Map event names to frequency.
-    def increment_event_duration(event, duration):
+    def increment_event_duration(event, duration, **kwargs):
+      del kwargs
       if event not in durations:
         durations[event] = 0.
       durations[event] += duration
@@ -88,7 +89,7 @@ def test_record_scalar(self):
   def test_unregister_exist_callback_success(self):
     original_duration_listeners = jax_src_monitoring.get_event_duration_listeners()
-    callback = lambda event, durations: None
+    callback = lambda event, durations, **kwargs: None
     self.assertNotIn(callback, original_duration_listeners)
     monitoring.register_event_duration_secs_listener(callback)
     self.assertIn(callback, jax_src_monitoring.get_event_duration_listeners())
@@ -102,7 +103,7 @@ def test_unregister_exist_callback_success(self):
     jax_src_monitoring.get_event_duration_listeners())
   def test_unregister_not_exist_callback_fail(self):
-    callback = lambda event, durations: None
+    callback = lambda event, durations, **kwargs: None
     self.assertNotIn(callback,
                      jax_src_monitoring.get_event_duration_listeners())
@@ -112,7 +113,7 @@ def test_unregister_not_exist_callback_fail(self):
   def test_unregister_callback_index_in_range_success(self):
     original_duration_listeners = jax_src_monitoring.get_event_duration_listeners()
-    callback = lambda event, durations: None
+    callback = lambda event, durations, **kwargs: None
     self.assertNotIn(callback, original_duration_listeners)
     monitoring.register_event_duration_secs_listener(callback)
     self.assertIn(callback, jax_src_monitoring.get_event_duration_listeners())
@@ -138,7 +139,7 @@ def test_unregister_callback_index_out_of_range_fail(self):
   def test_get_event_duration_listeners_returns_a_copy(self):
     original_duration_listeners = jax_src_monitoring.get_event_duration_listeners()
-    callback = lambda event, durations: None
+    callback = lambda event, durations, **kwargs: None
     original_duration_listeners.append(callback)

From b59c0e4818e2c71ba519cbe893e81a2d33cfbc3e Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 1 May 2025 12:58:50 -0700
Subject: [PATCH 0951/1769] Initial commit for `smap` i.e. shard_map1D. The
 signature is `smap(f, in_axes, out_axes, axis_name)`.

The API semantics are as follows:

* `smap` only allows going into `Manual` mode one mesh axis at a time via the `axis_name` argument.

* mesh needs to be present in the context via `use_mesh` or `set_mesh`.

* If in_axes or out_axes contains `None`, it means that the corresponding input (or output) is **replicated**. This is similar to `vmap` where `None` means an unmapped input.

* If the context mesh is in full explicit mode, `in_axes` can be inferred from the arguments. But how do we tell `smap` to do that? We **can't** use `None` because `None` means replicated in `smap`. So we introduce a singleton called `Infer` which, when passed to `smap`, tells it to infer the in_axes (in_specs) from the arguments! For example: `smap(f, in_axes=Infer, out_axes=0, axis_name='x')`.
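A sketch of these semantics, condensed from the new tests in `tests/shard_map_test.py` (`partial` is `functools.partial`; `h` and `x` are placeholders, and a mesh with the named axes is assumed to be installed in the context, e.g. via `jax.sharding.use_mesh`):

```
# `None` marks a replicated input/output; an int gives the mapped axis.
@partial(smap, in_axes=(None, 0), out_axes=(None, 0), axis_name='data')
def f(x, y):
  return x * 2, y * 2

# Under a fully explicit mesh, `Infer` asks smap to derive in_axes from
# the arguments' shardings:
out = smap(h, in_axes=Infer, out_axes=0, axis_name='x')(x)
```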
You always have the option of specifying `in_axes` and not infer even in full explicit mode :) PiperOrigin-RevId: 753695446 --- jax/_src/interpreters/pxla.py | 3 +- jax/_src/shard_map.py | 54 +++++++++++++++++++++- tests/shard_map_test.py | 87 ++++++++++++++++++++++++++++++++++- 3 files changed, 140 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 75d26360c363..a7782063491c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -84,6 +84,7 @@ class WeakRefList(list): xe = xc._xla unsafe_map, map = map, safe_map # type: ignore +zip, unsafe_zip = safe_zip, zip # type: ignore logger = logging.getLogger(__name__) @@ -1321,7 +1322,7 @@ def __call__(self, *args): out_ = [] for i, o in zip(self.mut.out_mut, out): if i is not None: - args[i]._buf._replace_with(o) + args[i]._buf._replace_with(o) # type: ignore else: out_.append(o) return out_ diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 960f8167f417..43d31f18bb1e 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -41,7 +41,7 @@ from jax._src.core import pvary from jax._src.core import Tracer, typeof from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, - get_abstract_mesh) + get_abstract_mesh, get_concrete_mesh) from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy @@ -126,6 +126,54 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), return lambda g: _shard_map(g, **kwargs) return _shard_map(f, **kwargs) +def _axes_to_pspec(axis_name, axis): + if axis is None: + return P() + return P(*[None] * axis + [axis_name]) + +class InferFromArgs: + + def __repr__(self): + return "jax.sharding.Infer" + + def __reduce__(self): + return (_get_default_infer, ()) + +Infer = InferFromArgs() + +def _get_default_infer(): + return Infer + +# TODO(yashkatariya): We need a singleton which users can provide to `in_axes` +# to tell smap to infer in_specs from args when mesh is fully explicit. 
+def smap(f, in_axes, out_axes, axis_name: AxisName): + if isinstance(axis_name, (list, tuple)): + raise TypeError( + f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}") + if (in_axes is not None and in_axes is not Infer and + not isinstance(in_axes, (int, tuple))): + raise TypeError( + "smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of" + " entries corresponding to the positional arguments passed to the" + f" function, but got {in_axes}.") + if (in_axes is not Infer and + not all(isinstance(l, int) for l in tree_leaves(in_axes))): + raise TypeError( + "smap in_axes must be an int, None, jax.sharding.Infer, or (nested)" + f" container with those types as leaves, but got {in_axes}.") + if not all(isinstance(l, int) for l in tree_leaves(out_axes)): + raise TypeError("smap out_axes must be an int, None, or (nested) container " + f"with those types as leaves, but got {out_axes}.") + + in_specs = (None if in_axes is Infer else + tree_map(partial(_axes_to_pspec, axis_name), in_axes, + is_leaf=lambda x: x is None)) + out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes, + is_leaf=lambda x: x is None) + return shard_map(f, axis_names={axis_name}, in_specs=in_specs, + out_specs=out_specs) + + def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, in_specs: Specs, out_specs: Specs | Callable[[], Specs], axis_names: Set[AxisName], check_vma: bool, @@ -172,7 +220,7 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, raise TypeError( "shard_map in_specs argument must be a pytree of" " `jax.sharding.PartitionSpec` instances, but it was None when mesh" - f" {mesh} has `Auto` axes.\n") + f" has `Auto` axes {mesh}") if in_specs is not None: _check_specs(SpecErrorType.input, in_specs, axis_names) @@ -886,6 +934,8 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, raise NotImplementedError del prim if isinstance(mesh, AbstractMesh): + concrete_mesh = get_concrete_mesh() + mesh = concrete_mesh if concrete_mesh is not None else mesh mesh = get_mesh_from_args(args, mesh) cur_mesh = get_abstract_mesh() args = map(partial(_unmatch_spec, mesh, check_vma, context_mesh=cur_mesh), diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f1496e7e1e18..a4c420959710 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -36,7 +36,7 @@ from jax._src import config from jax._src import core from jax._src import prng -from jax._src.shard_map import shard_map +from jax._src.shard_map import shard_map, smap, Infer from jax._src import test_util as jtu from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists @@ -3130,6 +3130,91 @@ def argmax_impl(x): argmax_impl(jax.random.normal(jax.random.key(0), (1024, 1024))) # doesn't crash + def test_smap(self): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().auto_axes, ('z',)) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * x + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, {'x'}) + out = smap(g, in_axes=0, out_axes=0, axis_name='y')(x) + self.assertEqual(out.aval.vma, {'x'}) + return out 
+ + @jax.jit + def f(x): + return smap(h, in_axes=0, out_axes=0, axis_name='x')(x) + + with jax.sharding.use_mesh(mesh): + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) + + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_smap_explicit(self, mesh): + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, P('x', 'y')) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().explicit_axes, ('z',)) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * x + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().explicit_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, {'x'}) + out = smap(g, in_axes=0, out_axes=0, axis_name='y')(x) + self.assertEqual(out.aval.vma, {'x'}) + return out + + @jax.jit + def f(x): + return smap(h, in_axes=Infer, out_axes=0, axis_name='x')(x) + + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) + + @jtu.with_explicit_mesh((2,), ('x',), axis_types=(AxisType.Auto,)) + def test_smap_replicated(self, mesh): + @partial(smap, in_axes=None, out_axes=None, axis_name='x') + def f(x): + return x * 2 + out = f(np.arange(8)) + self.assertArraysEqual(out, np.arange(8) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + @jtu.with_explicit_mesh((2,), ('data',), axis_types=(AxisType.Auto,)) + def test_smap_replicated_sharded(self, mesh): + @partial(smap, in_axes=(None, 0), out_axes=(None, 0), axis_name='data') + def f(x, y): + return x * 2, y * 2 + + out1, out2 = f(np.arange(8), np.arange(8)) + self.assertArraysEqual(out1, np.arange(8) * 2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P())) + self.assertArraysEqual(out2, np.arange(8) * 2) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('data'))) + + @partial(smap, in_axes=(None, 0), out_axes=0, axis_name='data') + def g(x, y): + return x + y + + out = g(np.arange(4), np.arange(8)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('data'))) + class FunSpec(NamedTuple): name: str From 02c2768d30e36d555b67c16830045f6a0e14f6bc Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 1 May 2025 13:25:17 -0700 Subject: [PATCH 0952/1769] Wait for async result to populate the exception. 
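Because dispatch is asynchronous, the callback's exception may only surface once the result is consumed, so the tests below now block on the output — roughly:

```
with self.assertRaisesRegex(Exception, "Unsupported primitive type"):
  jax.jit(f)(x).block_until_ready()
```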
PiperOrigin-RevId: 753705149 --- tests/python_callback_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 17087be35f8b..4aac07992ca8 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -606,7 +606,7 @@ def f(x): with self.assertRaisesRegex( Exception, "Unsupported primitive type" ): - _ = jax.jit(f)(x) + _ = jax.jit(f)(x).block_until_ready() @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): @@ -629,7 +629,7 @@ def f(): with self.assertRaisesRegex( Exception, "Unsupported primitive type" ): - _ = jax.jit(f)() + _ = jax.jit(f)().block_until_ready() class PureCallbackTest(jtu.JaxTestCase): From 5a81db007105f41cf2130498750a3f76382469d2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 1 May 2025 13:39:27 -0700 Subject: [PATCH 0953/1769] Integrate LLVM at llvm/llvm-project@7752e0a10b25 Updates LLVM usage to match [7752e0a10b25](https://github.com/llvm/llvm-project/commit/7752e0a10b25) PiperOrigin-RevId: 753710403 --- jax/experimental/mosaic/gpu/core.py | 5 +++-- jax/experimental/mosaic/gpu/dialect_lowering.py | 3 ++- jax/experimental/mosaic/gpu/launch_context.py | 6 +++--- jax/experimental/mosaic/gpu/utils.py | 8 ++++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 86be9825dcef..21f5278829d5 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -271,7 +271,8 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value: ) smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") barrier_base_ptr = llvm.getelementptr( - smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8 + smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8, + llvm.GEPNoWrapFlags.none ) dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES return barrier_base_ptr @@ -550,7 +551,7 @@ def main(token_ptr, buffers): token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): - ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) + ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty, llvm.GEPNoWrapFlags.none)) arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) in_refs = arg_refs[:len(in_ref_tys)] out_refs = arg_refs[len(in_ref_tys):] diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 9ca5349dd562..5edccafec236 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -194,6 +194,7 @@ def _initialize_barrier_op_lowering_rule( [], [i], lowered_barrier_type, + llvm.GEPNoWrapFlags.none, ), utils.c( initialize_barrier_op.arrival_count.value * utils.WARPGROUP_SIZE, @@ -206,7 +207,7 @@ def _initialize_barrier_op_lowering_rule( barrier_base_ptr = llvm.getelementptr( ir.Type.parse("!llvm.ptr"), - initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type) + initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type, llvm.GEPNoWrapFlags.none) return utils.ptr_as_memref( barrier_base_ptr, initialize_barrier_op.barriers_ref.type), diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 64cdedc779c8..2ec9047402a4 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ 
b/jax/experimental/mosaic/gpu/launch_context.py @@ -294,13 +294,13 @@ def _alloc_scratch( self.next_scratch_offset += size def host_init_wrapped(host_ptr): host_init( - llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8, llvm.GEPNoWrapFlags.none) ) self.host_scratch_init.append(host_init_wrapped) # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): # There is no way to create an insertion point after an operation... gep = llvm.GEPOp( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8, llvm.GEPNoWrapFlags.none ) gep.move_after(self.gmem_scratch_ptr.owner) return device_init(gep.result) @@ -339,7 +339,7 @@ def init_tma_desc(host_ptr): alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... base_ptr = llvm.getelementptr( - ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, llvm.GEPNoWrapFlags.none, ) rank = ref_ty.rank assert rank * 2 == len(sizes_and_strides) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 844f5d34c32a..8bea56abf485 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -99,7 +99,7 @@ def pack_array(values): ptr_ty = ir.Type.parse("!llvm.ptr") arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty) for i, v in enumerate(values): - elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty) + elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty, llvm.GEPNoWrapFlags.none) llvm.store(v, elem_ptr) return arr_ptr @@ -721,7 +721,7 @@ def initialize(address: ir.Value, num_barriers: int, arrival_count: int = 1) -> with single_thread(scope=ThreadSubset.BLOCK): for i in range(num_barriers): nvvm.mbarrier_init_shared( - llvm.getelementptr(ptr, address, [], [i], i64), + llvm.getelementptr(ptr, address, [], [i], i64, llvm.GEPNoWrapFlags.none), c(arrival_count, i32), ) return BarrierRef(address, c(0, i32), phases, num_barriers) @@ -793,7 +793,7 @@ def get_ptr(self): i64 = ir.IntegerType.get_signless(64) DYNAMIC32 = -2147483648 return llvm.getelementptr( - ptr, self.base_address, [self.offset], [DYNAMIC32], i64 + ptr, self.base_address, [self.offset], [DYNAMIC32], i64, llvm.GEPNoWrapFlags.none ) @@ -1241,7 +1241,7 @@ def getelementptr( ) -> ir.Value: static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices] dyn_indices = [i for i in indices if not isinstance(i, int)] - return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype) + return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype, llvm.GEPNoWrapFlags.none) def dyn_dot(x, y): From 42572338ac26616ba368a3d4deae9608c620a11a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 1 May 2025 13:34:43 -0700 Subject: [PATCH 0954/1769] Update shard_map.md's API specification section --- docs/notebooks/shard_map.ipynb | 13 ++++++++----- docs/notebooks/shard_map.md | 13 ++++++++----- jax/_src/shard_map.py | 8 ++++---- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index e8128c4133f7..d04b7583a4a0 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -827,17 +827,20 @@ "Specs = PyTree[PartitionSpec]\n", "\n", "def shard_map(\n", - " f: Callable, mesh: Mesh, in_specs: Specs, out_specs: 
Specs,\n",
-    "    auto: collections.abc.Set[AxisName] = frozenset([]),\n",
+    "    f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,\n",
+    "    in_specs: Specs | None = None,\n",
+    "    axis_names: collections.abc.Set[AxisName] = set(),\n",
     "    check_vma: bool = True,\n",
     ") -> Callable:\n",
     "  ...\n",
     "```\n",
     "where:\n",
     "* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;\n",
-    "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n",
-    "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n",
-    "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n",
+    "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; if `None`, the mesh is inferred from the\n",
+    "surrounding context, which can be set via the `jax.sharding.use_mesh` context manager;\n",
+    "* `in_specs` are `PartitionSpec`s which can mention each axis name from `mesh` at most once to express slicing/unconcatenation of inputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy). If `None`, all mesh axes must be of type `Explicit`, in which case the in_specs are inferred from the argument types;\n",
+    "* `out_specs` are `PartitionSpec`s which can mention each axis name from `mesh` at most once to express concatenation of outputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy);\n",
+    "* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat as manual in the body. If empty, `f` is manual over all axes of the mesh.\n",
     "* `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n",
     "\n",
     "The shapes of the arguments passed to `f` have the same ranks as the arguments\n",
diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md
index 43069110301d..bf139b48d6f3 100644
--- a/docs/notebooks/shard_map.md
+++ b/docs/notebooks/shard_map.md
@@ -554,17 +554,20 @@ from jax.sharding import Mesh
 Specs = PyTree[PartitionSpec]
 
 def shard_map(
-    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
-    auto: collections.abc.Set[AxisName] = frozenset([]),
+    f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,
+    in_specs: Specs | None = None,
+    axis_names: collections.abc.Set[AxisName] = set(),
     check_vma: bool = True,
 ) -> Callable:
   ...
```
 where:
 * communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;
-* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;
-* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;
-* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;
+* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; if `None`, the mesh is inferred from the
+surrounding context, which can be set via the `jax.sharding.use_mesh` context manager;
+* `in_specs` are `PartitionSpec`s which can mention each axis name from `mesh` at most once to express slicing/unconcatenation of inputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy). If `None`, all mesh axes must be of type `Explicit`, in which case the in_specs are inferred from the argument types;
+* `out_specs` are `PartitionSpec`s which can mention each axis name from `mesh` at most once to express concatenation of outputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy);
+* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat as manual in the body. If empty, `f` is manual over all axes of the mesh.
 * `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).
 
 The shapes of the arguments passed to `f` have the same ranks as the arguments
diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py
index 43d31f18bb1e..939dbeddf3d7 100644
--- a/jax/_src/shard_map.py
+++ b/jax/_src/shard_map.py
@@ -97,8 +97,8 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
     the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a
     ``mesh`` axis name at a position expresses sharding the corresponding
     argument array axis along that positional axis; not mentioning an axis
-    name expresses replication. If ``None``, all mesh axes must be in explicit
-    mode, in which case the in_specs are inferred from the argument types.
+    name expresses replication. If ``None``, all mesh axes must be of type
+    ``Explicit``, in which case the in_specs are inferred from the argument types.
   out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a
     tree structure that is a tree prefix of the output of ``f``. Each
     ``PartitionSpec`` represents how the corresponding output shards should be
@@ -107,8 +107,8 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
     corresponding positional axis; not mentioning a ``mesh`` axis name
     expresses a promise that the output values are equal along that mesh axis,
     and that rather than concatenating only a single value should be produced.
-  axis_names: (optional, default None) set of axis names from ``mesh`` over
-    which the function ``f`` is manual. If ``None``, ``f``, is manual
+  axis_names: (optional, default set()) set of axis names from ``mesh`` over
+    which the function ``f`` is manual. If empty, ``f`` is manual
     over all mesh axes.
   check_vma: (optional) boolean (default True) representing whether to enable
     additional validity checks and automatic differentiation optimizations.

From cbf79e0a294102d3af305e7ce0a1e6aa35eb6239 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 1 May 2025 14:28:53 -0700
Subject: [PATCH 0955/1769] Enable usage of mirrored `.tar` redistributions in Bazel RBE GPU jobs.

Extracting `.tar` files is 10 times faster than extracting `.tar.xz`
files. Enabling `.tar` file usage in RBE jobs therefore saves at least
one minute of execution time in every Bazel RBE GPU job.

PiperOrigin-RevId: 753730448
---
 .bazelrc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.bazelrc b/.bazelrc
index 755572f21355..b6bb79cc8d0d 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -329,6 +329,8 @@ build:rbe_linux_x86_64 --config=ci_linux_x86_64
 build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base
 build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda
 build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1
+# Speed up CUDA repo creation by downloading ".tar" dists from the mirror.
+build:rbe_linux_x86_64_cuda --repo_env=USE_CUDA_TAR_ARCHIVE_FILES=1
 
 # RBE configs for Windows
 # Set the remote worker pool
From 9fad6055040ca769c76821d9c9b851a91143599a Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 30 Apr 2025 15:00:56 +0000
Subject: [PATCH 0956/1769] Update requirements files:

* add jaxlib, jax-cuda12-plugin, jax-cuda12-pjrt, and libtpu to the
  requirements files. This allows testing against the latest jaxlib
  release without building jaxlib.
* remove gpu-test-requirements.txt, and instead just use the
  jax-cuda12-plugin[with-cuda] extra. This avoids the need for us to
  redundantly list the set of packages.
* move several packages used when building JAX wheels from
  test-requirements.txt to requirements.in.
* relax the version constraint on NumPy for Python 3.14, where we want the
  newest version.
* remove version constraints on matplotlib, which existed to work around
  NumPy 2 upgrade problems.
* regenerate the lock files for Python 3.10-3.13.
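
A quick way to check that an environment resolved from these lock files
actually picked up the CUDA plugin wheels (a hypothetical sanity check,
not part of this change; it assumes a machine with a CUDA-capable GPU):

    import jax

    # With jax-cuda12-plugin[with-cuda] installed, the default backend
    # should be the GPU rather than the CPU fallback.
    print(jax.default_backend())  # expected: 'gpu'
    print(jax.devices())          # expected: [CudaDevice(id=0), ...]
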
---
 build/BUILD.bazel                       |   1 -
 build/freethreading-requirements.txt    |   3 +-
 build/gpu-test-requirements.txt         |  13 ---
 build/nonfreethreading-requirements.txt |   2 +-
 build/requirements.in                   |  17 ++++
 build/requirements_lock_3_10.txt        |  99 +++++++++++++++------
 build/requirements_lock_3_11.txt        |  99 +++++++++++++++------
 build/requirements_lock_3_12.txt        |  99 +++++++++++++++------
 build/requirements_lock_3_13.txt        | 101 ++++++++++++++++------
 build/requirements_lock_3_13_ft.txt     | 104 +++++++++++++++++------
 build/test-requirements.txt             |   7 +-
 11 files changed, 380 insertions(+), 165 deletions(-)
 delete mode 100644 build/gpu-test-requirements.txt

diff --git a/build/BUILD.bazel b/build/BUILD.bazel
index 539c156d3ac4..761cf02ad624 100644
--- a/build/BUILD.bazel
+++ b/build/BUILD.bazel
@@ -22,7 +22,6 @@ licenses(["notice"])
 COMMON_REQUIREMENTS = [
     "requirements.in",
     "test-requirements.txt",
-    "gpu-test-requirements.txt",
 ]
 
 # It isn't possible to constraint based on free-threaded vs non-free threaded
diff --git a/build/freethreading-requirements.txt b/build/freethreading-requirements.txt
index 2bbaf1fe8443..cc302cffdd0c 100644
--- a/build/freethreading-requirements.txt
+++ b/build/freethreading-requirements.txt
@@ -1,2 +1,3 @@
 # Under free-threading, we need an up-to-date numpy at least for the moment.
-numpy~=2.2.5
+numpy~=2.2.5; python_version=="3.13"
+numpy>=2.2.5; python_version>="3.14"
diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt
deleted file mode 100644
index d0dda5cf526c..000000000000
--- a/build/gpu-test-requirements.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-# NVIDIA CUDA dependencies
-# Note that the wheels are downloaded only when the targets in bazel command
-# contain dependencies on these wheels.
-nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux"
-nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux"
-nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux"
-nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux"
-nvidia-cudnn-cu12>=9.8,<10.0 ; sys_platform == "linux"
-nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux"
-nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux"
-nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux"
-nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux"
-nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux"
diff --git a/build/nonfreethreading-requirements.txt b/build/nonfreethreading-requirements.txt
index 86f5e64d1973..f8171559a142 100644
--- a/build/nonfreethreading-requirements.txt
+++ b/build/nonfreethreading-requirements.txt
@@ -1,6 +1,6 @@
 numpy~=2.0.0; python_version<="3.12"
 numpy~=2.1.0; python_version=="3.13"
-numpy~=2.2.5; python_version>="3.14"
+numpy>=2.2.5; python_version>="3.14"
 
 # These packages have not released free-threaded wheels.
 zstandard
diff --git a/build/requirements.in b/build/requirements.in
index 108c5f7492b0..d2fc3a60a708 100644
--- a/build/requirements.in
+++ b/build/requirements.in
@@ -6,4 +6,21 @@ scipy>=1.15.2; python_version>="3.13"
 ml_dtypes>=0.4.0
 etils[epath]
+opt-einsum
+
+# Needed to build wheels
+build
 setuptools
+wheel
+
+# JAX's own libraries. We include these in the requirements so you can
+# bazel test without building jaxlib and without manually updating
+# the requirements files.
+jaxlib
+
+# The with-cuda extra also includes NVIDIA's pip packages.
+jax-cuda12-plugin[with-cuda] +jax-cuda12-pjrt + +# TPU dependencies +libtpu ; sys_platform == "linux" and platform_machine == "x86_64" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 51d09c6638bb..45820e38f195 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -19,7 +19,7 @@ auditwheel==6.1.0 \ build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 @@ -160,6 +160,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.0 \ + --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ + --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.0 \ + --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ + --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ + --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ + --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ + --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ + --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ + --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ + --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ + --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ + --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a + # via -r build/requirements.in +jaxlib==0.6.0 \ + --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ + --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ + --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ + --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ + --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ + --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ + --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ + --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ + --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ + --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ + --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ + --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ + --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ + --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ + --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ + --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ + 
--hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ + --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 + # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ @@ -266,11 +304,14 @@ kiwisolver==1.4.5 \ --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.8.4 ; python_version == "3.10" \ +matplotlib==3.8.4 \ --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ @@ -331,6 +372,7 @@ ml-dtypes==0.5.1 \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via # -r build/requirements.in + # jaxlib # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ @@ -385,73 +427,74 @@ numpy==2.0.0 ; python_version <= "3.12" \ # via # -r build/nonfreethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # opt-einsum # scipy # tensorstore -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/gpu-test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/gpu-test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/gpu-test-requirements.txt -nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ 
--hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/gpu-test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/gpu-test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/gpu-test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/gpu-test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -625,7 +668,9 @@ scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -670,7 +715,7 @@ typing-extensions==4.12.0rc1 \ wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r 
build/test-requirements.txt + # via -r build/requirements.in zipp==3.18.2 \ --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e @@ -728,6 +773,4 @@ zstandard==0.22.0 \ setuptools==76.0.0 \ --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 00e9af7ea2dc..e2140583c7e0 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -19,7 +19,7 @@ auditwheel==6.1.0 \ build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 @@ -154,6 +154,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.0 \ + --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ + --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.0 \ + --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ + --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ + --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ + --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ + --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ + --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ + --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ + --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ + --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ + --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a + # via -r build/requirements.in +jaxlib==0.6.0 \ + --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ + --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ + --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ + --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ + --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ + --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ + --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ + --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ + --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ + --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ + --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ + 
--hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ + --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ + --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ + --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ + --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ + --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ + --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 + # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ @@ -260,11 +298,14 @@ kiwisolver==1.4.5 \ --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.0 ; python_version >= "3.11" \ +matplotlib==3.9.0 \ --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ @@ -326,6 +367,7 @@ ml-dtypes==0.5.1 \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via # -r build/requirements.in + # jaxlib # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ @@ -380,73 +422,74 @@ numpy==2.0.0 ; python_version <= "3.12" \ # via # -r build/nonfreethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # opt-einsum # scipy # tensorstore -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/gpu-test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/gpu-test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ 
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/gpu-test-requirements.txt -nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/gpu-test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/gpu-test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/gpu-test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/gpu-test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -620,7 +663,9 @@ scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib six==1.16.0 \ 
--hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -659,7 +704,7 @@ typing-extensions==4.12.0rc1 \ wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.18.2 \ --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e @@ -717,6 +762,4 @@ zstandard==0.22.0 \ setuptools==76.0.0 \ --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 3bf4f29bfac8..7482f6b2bad9 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -19,7 +19,7 @@ auditwheel==6.1.0 \ build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 @@ -154,6 +154,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.0 \ + --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ + --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.0 \ + --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ + --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ + --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ + --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ + --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ + --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ + --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ + --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ + --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ + --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a + # via -r build/requirements.in +jaxlib==0.6.0 \ + --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ + --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ + --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ + --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ + --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ + --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ + 
--hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ + --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ + --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ + --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ + --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ + --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ + --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ + --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ + --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ + --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ + --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ + --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 + # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ @@ -260,11 +298,14 @@ kiwisolver==1.4.5 \ --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.0 ; python_version >= "3.11" \ +matplotlib==3.9.0 \ --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ @@ -326,6 +367,7 @@ ml-dtypes==0.5.1 \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via # -r build/requirements.in + # jaxlib # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ @@ -380,73 +422,74 @@ numpy==2.0.0 ; python_version <= "3.12" \ # via # -r build/nonfreethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # opt-einsum # scipy # tensorstore -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/gpu-test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ 
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/gpu-test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/gpu-test-requirements.txt -nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/gpu-test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/gpu-test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/gpu-test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/gpu-test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ 
--hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -620,7 +663,9 @@ scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -659,7 +704,7 @@ typing-extensions==4.12.0rc1 \ wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.18.2 \ --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e @@ -717,6 +762,4 @@ zstandard==0.22.0 \ setuptools==76.0.0 \ --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index d0508fc3e8bc..83cccfa84e4b 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -19,7 +19,7 @@ auditwheel==6.1.0 \ build==1.2.2.post1 \ --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 @@ -181,6 +181,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.0 \ + --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ + --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.0 \ + --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ + --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ + --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ + --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ + --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ + --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ + --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ + --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ + --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ + --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a + # via -r build/requirements.in +jaxlib==0.6.0 \ + 
--hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ + --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ + --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ + --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ + --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ + --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ + --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ + --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ + --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ + --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ + --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ + --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ + --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ + --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ + --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ + --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ + --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ + --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 + # via -r build/requirements.in kiwisolver==1.4.7 \ --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ --hash=sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95 \ @@ -297,11 +335,14 @@ kiwisolver==1.4.7 \ --hash=sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d \ --hash=sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052 # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.2 ; python_version >= "3.11" \ +matplotlib==3.9.2 \ --hash=sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21 \ --hash=sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5 \ --hash=sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697 \ @@ -374,12 +415,13 @@ ml-dtypes==0.5.1 \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via # -r build/requirements.in + # jaxlib # tensorstore mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.1.2 ; python_version >= "3.13" \ +numpy==2.1.2 ; python_version == "3.13" \ --hash=sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8 \ --hash=sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466 \ --hash=sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35 \ @@ -436,72 +478,73 @@ numpy==2.1.2 ; python_version >= "3.13" \ # via # -r build/nonfreethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # scipy # tensorstore 
-nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/gpu-test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/gpu-test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/gpu-test-requirements.txt -nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/gpu-test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/gpu-test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/gpu-test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r 
build/gpu-test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 @@ -693,7 +736,9 @@ scipy==1.15.2 ; python_version >= "3.13" \ --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -732,7 +777,7 @@ typing-extensions==4.12.2 \ wheel==0.44.0 \ --hash=sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f \ --hash=sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.20.2 \ --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \ --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29 @@ -841,6 +886,4 @@ zstandard==0.23.0 \ setuptools==76.0.0 \ --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 7fce0eef6a8a..5dd300f224e5 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -19,7 +19,7 @@ auditwheel==6.2.0 \ build==1.2.2.post1 \ --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.1.0 \ --hash=sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b \ --hash=sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e @@ -172,6 +172,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.0 \ + --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ + --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.0 \ + 
--hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ + --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ + --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ + --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ + --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ + --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ + --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ + --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ + --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ + --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a + # via -r build/requirements.in +jaxlib==0.6.0 \ + --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ + --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ + --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ + --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ + --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ + --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ + --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ + --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ + --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ + --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ + --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ + --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ + --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ + --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ + --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ + --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ + --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ + --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 + # via -r build/requirements.in kiwisolver==1.4.8 \ --hash=sha256:01c3d31902c7db5fb6182832713d3b4122ad9317c2c5877d0539227d96bb2e50 \ --hash=sha256:034d2c891f76bd3edbdb3ea11140d8510dca675443da7304205a2eaa45d8334c \ @@ -254,11 +292,14 @@ kiwisolver==1.4.8 \ --hash=sha256:ed33ca2002a779a2e20eeb06aea7721b6e47f2d4b8a8ece979d8ba9e2a167e34 \ --hash=sha256:fc2ace710ba7c1dfd1a3b42530b62b9ceed115f19a1656adefce7b1782a37794 # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.10.0 ; python_version >= "3.11" \ +matplotlib==3.10.0 \ --hash=sha256:01d2b19f13aeec2e759414d3bfe19ddfb16b13a1250add08d46d5ff6f9be83c6 \ --hash=sha256:12eaf48463b472c3c0f8dbacdbf906e573013df81a0ab82f0616ea4b11281908 \ --hash=sha256:2c5829a5a1dd5a71f0e31e6e8bb449bc0ee9dbfb05ad28fc0c6b55101b3a4be6 \ @@ -323,12 +364,14 @@ ml-dtypes==0.5.1 \ 
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.2.5 \ +numpy==2.2.5 ; python_version == "3.13" \ --hash=sha256:0255732338c4fdd00996c0421884ea8a3651eea555c3a56b84892b66f696eb70 \ --hash=sha256:02f226baeefa68f7d579e213d0f3493496397d8f1cff5e2b222af274c86a552a \ --hash=sha256:059b51b658f4414fff78c6d7b1b4e18283ab5fa56d270ff212d5ba0c561846f4 \ @@ -387,71 +430,72 @@ numpy==2.2.5 \ # via # -r build/freethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/gpu-test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/gpu-test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/gpu-test-requirements.txt -nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via -r build/gpu-test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/gpu-test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ 
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/gpu-test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/gpu-test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/gpu-test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.2 \ --hash=sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759 \ --hash=sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f @@ -618,7 +662,9 @@ scipy==1.15.2 ; python_version >= "3.13" \ --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 @@ -634,7 +680,7 @@ typing-extensions==4.12.2 \ wheel==0.45.1 \ --hash=sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729 \ --hash=sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.21.0 \ --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 @@ -644,6 +690,4 @@ zipp==3.21.0 \ setuptools==70.3.0 \ --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/test-requirements.txt b/build/test-requirements.txt index f0b315771cbb..ef23b10ddf88 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,5 +1,4 @@ absl-py -build cloudpickle colorama>=0.4.4 filelock @@ -10,12 
+9,8 @@ pillow>=10.4.0 # TODO(kanglan): Remove once psutil from portpicker supports python 3.13t portpicker; python_version<"3.13" pytest-xdist -wheel rich -setuptools # matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement # below. -matplotlib~=3.8.4; python_version=="3.10" -matplotlib; python_version>="3.11" -opt-einsum +matplotlib auditwheel \ No newline at end of file From 23df262ec57c5644a1d09aef9f157690256cf53e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 1 May 2025 14:35:07 -0700 Subject: [PATCH 0957/1769] [state] Slightly restructured `_is_trivial_indexer` PiperOrigin-RevId: 753732387 --- jax/_src/state/discharge.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 615fa862bf31..bc6a20a0a76e 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -208,14 +208,13 @@ def _eval_jaxpr_discharge_state( return out_vals + ref_vals def _is_trivial_indexer(indexer: indexing.NDIndexer): + """Returns whether the indexer selects the entire shape.""" for s, idx in zip(indexer.shape, indexer.indices): if not isinstance(idx, indexing.Slice): return False - if not isinstance(idx.start, int): + if idx.is_dynamic_start or idx.is_dynamic_size: return False - if idx.start: - return False - if idx.size != s: + if idx.start != 0 or idx.size != s: return False return True From d985f3662d19c59d803be75c68fb73ec4b30932a Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 1 May 2025 14:49:25 -0700 Subject: [PATCH 0958/1769] remove spurious assertion PiperOrigin-RevId: 753737826 --- jax/_src/util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/util.py b/jax/_src/util.py index e979745e73e5..1227b15c5ead 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -150,7 +150,6 @@ def subvals(lst: Sequence[T], replace: Iterable[tuple[int, T]]) -> tuple[T, ...] def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: """Split list into sublists of the specified sizes.""" args = list(args) - assert all(n >= 0 for n in ns) lists = [] for n in ns: lists.append(args[:n]) From 11d33824763fcab66c8a9738c491c3355829337a Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 1 May 2025 21:32:46 +0000 Subject: [PATCH 0959/1769] [attrs] replace list.freeze() with list.get(), only allow locals --- jax/_src/pjit.py | 1 + jax/_src/util.py | 1 - jax/experimental/attrs.py | 27 ++++++------------ tests/attrs_test.py | 59 ++++++++++++++++++++++++++------------- 4 files changed, 50 insertions(+), 38 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 28a7f1361c67..7c716910eaeb 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -3270,6 +3270,7 @@ def visit(x): return box elif type(x) is ListTree: lst = List() + lst._is_arg = True arg_mutables.append(lst) return lst else: diff --git a/jax/_src/util.py b/jax/_src/util.py index e979745e73e5..1227b15c5ead 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -150,7 +150,6 @@ def subvals(lst: Sequence[T], replace: Iterable[tuple[int, T]]) -> tuple[T, ...] 
def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: """Split list into sublists of the specified sizes.""" args = list(args) - assert all(n >= 0 for n in ns) lists = [] for n in ns: lists.append(args[:n]) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index d483ac076a0f..db738ee6368d 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -366,29 +366,26 @@ def _box_set_linearize(trace, box, val): class List: _val: PyTree _tag: core.OpaqueTraceState - frozen: bool + _is_arg: bool def __init__(self, val=None): - self._val = [] if val is None else val + self._val = [] if val is None else val[:] self._tag = core.get_opaque_trace_state() - self.frozen = False + self._is_arg = False def append(self, val): with core.take_current_trace() as t: return t.process_list_append(self, val) - def freeze(self): + def get(self): with core.take_current_trace() as t: - return t.process_list_freeze(self) + if _is_local(t, self) and not self._is_arg: + return self._val[:] # defensive copy in case caller erroneously mutates + raise Exception("can't read the value of a List that was not created in " + "this scope") +AppendList = List def _list_append_impl(trace, lst, val): - if lst.frozen: - raise Exception("can't append to an already-frozen List") lst._val.append(val) core.EvalTrace.process_list_append = _list_append_impl -def _list_freeze_impl(trace, lst): - lst.frozen = True - return lst._val -core.EvalTrace.process_list_freeze = _list_freeze_impl - def _list_append_staging(trace, lst, val): if not _is_local(trace, lst): _ensure_list_tracked(trace, lst) @@ -401,9 +398,3 @@ def _ensure_list_tracked(trace, lst): frame.attrs_inits.append(lst._val) frame.attrs_tracked.append((lst, '_val', pe.ListAttr)) lst._val = [] - -def _list_freeze_staging(trace, lst): - if not _is_local(trace, lst): - raise Exception("can only freeze a local List") - return _list_freeze_impl(trace, lst) -pe.DynamicJaxprTrace.process_list_freeze = _list_freeze_staging diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 6e1381ffbc88..0d3a85d0e694 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -1113,7 +1113,7 @@ def test_eager(self): lst.append(1.0) lst.append(2.0) lst.append(3.0) - self.assertAllClose(lst.freeze(), [1., 2., 3.]) + self.assertAllClose(lst.get(), [1., 2., 3.]) def test_jit_arg(self): @jax.jit @@ -1127,13 +1127,13 @@ def f(lst, x): tracing_ok = True lst1 = List() f(lst1, 0) - self.assertAllClose(lst1.freeze(), [1., 2., {'c': 3.}]) + self.assertAllClose(lst1.get(), [1., 2., {'c': 3.}]) tracing_ok = False lst2 = List() lst2.append(0.) 
f(lst2, 1) - self.assertAllClose(lst2.freeze(), [0., 1., 2., {'c': 4.}]) + self.assertAllClose(lst2.get(), [0., 1., 2., {'c': 4.}]) def test_jit_closure(self): lst = List() @@ -1151,7 +1151,7 @@ def f(x): tracing_ok = False f(2) - self.assertAllClose(lst.freeze(), [1., {'a': 2.}, 4., 1., {'a': 2.0}, 5.0]) + self.assertAllClose(lst.get(), [1., {'a': 2.}, 4., 1., {'a': 2.0}, 5.0]) def test_jit_closure_nested(self): lst = List() @@ -1168,7 +1168,7 @@ def k(x): k(2.0) h(0.0) - self.assertAllClose(lst.freeze(), [0., 1., 2.]) + self.assertAllClose(lst.get(), [0., 1., 2.]) @parameterized.parameters([False, True]) def test_scan_basic(self, jit): @@ -1186,7 +1186,7 @@ def body(_, x): f() - self.assertAllClose(lst.freeze(), [0., 1., 2., 3., 4., 5.]) + self.assertAllClose(lst.get(), [0., 1., 2., 3., 4., 5.]) @parameterized.parameters([False, True]) def test_scan_basic_hetero(self, jit): @@ -1212,16 +1212,16 @@ def body(_, x): 4., {'a': (5., 6.)}, ] - self.assertAllClose(lst.freeze(), expected) + self.assertAllClose(lst.get(), expected) @parameterized.parameters([False, True]) - def test_freeze_basic(self, jit): + def test_get_basic(self, jit): def f(): lst = List() lst.append(1.) lst.append(2.) - return lst.freeze() + return lst.get() if jit: f = jax.jit(f) @@ -1234,9 +1234,9 @@ def test_freeze_nonlocal_list(self): @jax.jit def f(): - lst.freeze() + lst.get() - with self.assertRaisesRegex(Exception, "can only freeze a local List"): + with self.assertRaisesRegex(Exception, "can't read the value"): f() def test_freeze_nonlocal_list_nested(self): @@ -1246,27 +1246,48 @@ def f(): @jax.jit def g(): - lst.freeze() + lst.get() g() - with self.assertRaisesRegex(Exception, "can only freeze a local List"): + with self.assertRaisesRegex(Exception, "can't read the value"): f() @parameterized.parameters([False, True]) - def test_append_after_freeze(self, jit): + def test_append_after_get(self, jit): def f(): lst = List() lst.append(1.) lst.append(2.) - val = lst.freeze() - with self.assertRaisesRegex(Exception, "can't append"): - lst.append(3.) + val = lst.get() + lst.append(3.) + return lst.get() if jit: f = jax.jit(f) - f() + lst = f() + self.assertAllClose(lst, [1., 2., 3.]) + + def test_get_on_nonlocal_list_closure(self): + lst = List() + + @jax.jit + def f(): + lst.append(1.) + lst.append(2.) + with self.assertRaisesRegex(Exception, "can't read"): + val = lst.get() + + def test_get_on_nonlocal_list_arg(self): + lst = List() + + @jax.jit + def f(lst): + lst.append(1.) + lst.append(2.) + with self.assertRaisesRegex(Exception, "can't read"): + val = lst.get() @parameterized.parameters([False, True]) def test_custom_vjp_plumbing(self, jit): @@ -1292,7 +1313,7 @@ def f(x): f = jax.jit(f) jax.grad(f)(1.0) - self.assertAllClose(lst.freeze(), [2.0]) + self.assertAllClose(lst.get(), [2.0]) def test_error_passing_multiple_times_to_jit(self): @jax.jit From f72524979580b86d950d93c458d6a94cc18bb3e8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 1 May 2025 15:01:43 -0700 Subject: [PATCH 0960/1769] [mosaic_gpu] Added support for using `cf.assert` in Mosaic GPU kernels The XLA GPU runtime does not yet handle device assertions well and will hang if the assert is triggered. However, the assertion output still appears in stderr, so I think having `cf.assert` support is still useful. 
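
For illustration, here is a condensed sketch of the usage exercised by
the new test in this change (abridged from the tests/mosaic/gpu_test.py
hunk below; the `mgpu` import path is an assumption based on that test
module's conventions, and the shapes and message string follow the test
and are purely illustrative):

  import jax
  import jax.numpy as jnp
  import jax.experimental.mosaic.gpu as mgpu
  from jax._src.lib.mlir.dialects import cf

  def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None:
    del ctx, out  # Unused.
    x = mgpu.FragmentedArray.load_strided(x_ref)
    # cf.assert_ takes an i1 condition and a message; the message is
    # written to stderr if the condition is false on device.
    cond = x.reduce_sum(*scratch) != 42.0
    cf.assert_(cond.registers.item(), "OOOPS")

  f = mgpu.as_gpu_kernel(
      kernel, grid=(1, 1, 1), block=(128, 1, 1),
      in_shape=(jax.ShapeDtypeStruct((128,), jnp.float32),),
      out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
      smem_scratch_shape=(jax.ShapeDtypeStruct((4,), jnp.float32),),
  )
  # The sum here is 128, so the condition holds and the assertion does
  # not fire (a deliberately true condition, per the TODO in the test,
  # since a triggered assert currently hangs the runtime).
  f(jnp.ones((128,), jnp.float32))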
PiperOrigin-RevId: 753742121 --- jax/BUILD | 1 + jax/_src/lib/BUILD | 1 + jax/_src/lib/mlir/dialects/__init__.py | 6 ++++ jaxlib/BUILD | 1 + jaxlib/mlir/BUILD.bazel | 35 +++++++++++++------ jaxlib/mlir/_mlir_libs/BUILD.bazel | 3 +- .../mlir/_mlir_libs/register_jax_dialects.cc | 2 ++ jaxlib/mosaic/gpu/BUILD | 1 + jaxlib/mosaic/gpu/custom_call.cc | 13 +++---- jaxlib/tools/build_wheel.py | 2 ++ tests/mosaic/gpu_test.py | 30 +++++++++++++++- 11 files changed, 76 insertions(+), 19 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 14d205c0e210..7643c96bd0e0 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -866,6 +866,7 @@ py_library_providing_imports_info( "//jax/_src/lib", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", + "//jaxlib/mlir:control_flow_dialect", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 20e6f66f6a5a..4bbc861432aa 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -52,6 +52,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", + "//jaxlib/mlir:control_flow_dialect", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:ir", "//jaxlib/mlir:math_dialect", diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index be5317824c36..5584afee2116 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -55,3 +55,9 @@ # Alias that is set up to abstract away the transition from MHLO to StableHLO. from jaxlib.mlir.dialects import stablehlo as hlo + +from jax._src import lib +if lib.version >= (0, 6, 1): + from jaxlib.mlir.dialects import cf +else: + cf = None # type: ignore[no-redef] diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 29f274ad9a03..5165f5cf2520 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -82,6 +82,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", + "//jaxlib/mlir:control_flow_dialect", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index de7b017355fc..c7231c557e78 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -65,7 +65,7 @@ symlink_inputs( name = "func_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:FuncPyFiles"], + "dialects": ["@llvm-project//mlir/python:FuncPyFiles"], }}, deps = [ ":core", @@ -78,7 +78,7 @@ symlink_inputs( name = "vector_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"], }}, deps = [ ":core", @@ -91,7 +91,7 @@ symlink_inputs( name = "math_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"], }}, deps = [ ":core", @@ -104,7 +104,7 @@ symlink_inputs( name = "arithmetic_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"], }}, deps = [ ":core", @@ -117,7 +117,20 @@ symlink_inputs( name = "memref_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:MemRefOpsPyFiles"], + "dialects": 
["@llvm-project//mlir/python:MemRefOpsPyFiles"], + }}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "control_flow_dialect", + rule = py_library, + symlinked_inputs = {"srcs": { + "dialects": ["@llvm-project//mlir/python:ControlFlowOpsPyFiles"], }}, deps = [ ":core", @@ -130,7 +143,7 @@ symlink_inputs( name = "scf_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:SCFPyFiles"], + "dialects": ["@llvm-project//mlir/python:SCFPyFiles"], }}, deps = [ ":core", @@ -143,7 +156,7 @@ symlink_inputs( name = "builtin_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"], }}, deps = [ ":core", @@ -157,7 +170,7 @@ symlink_inputs( name = "chlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@stablehlo//:chlo_ops_py_files"], + "dialects": ["@stablehlo//:chlo_ops_py_files"], }}, deps = [ ":core", @@ -171,7 +184,7 @@ symlink_inputs( name = "sparse_tensor_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"], }}, deps = [ ":core", @@ -186,7 +199,7 @@ symlink_inputs( name = "mhlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"], + "dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"], }}, deps = [ ":core", @@ -228,7 +241,7 @@ symlink_inputs( name = "stablehlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@stablehlo//:stablehlo_ops_py_files"], + "dialects": ["@stablehlo//:stablehlo_ops_py_files"], }}, deps = [ ":core", diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 25f2162685b9..6e54c9be83f5 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -208,6 +208,7 @@ nanobind_pywrap_extension( deps = [ "//jaxlib/mosaic/gpu:mlir_capi", "@llvm-project//mlir:CAPIArith", + "@llvm-project//mlir:CAPICF", "@llvm-project//mlir:CAPIGPU", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:CAPILLVM", @@ -297,4 +298,4 @@ nanobind_pywrap_extension( "@nanobind", "@stablehlo//:stablehlo_capi", ], -) \ No newline at end of file +) diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index b8432bf615c9..3c2604640a19 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep +#include "mlir-c/Dialect/ControlFlow.h" #include "mlir-c/Dialect/Func.h" // IWYU pragma: keep #include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep #include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep @@ -50,6 +51,7 @@ NB_MODULE(register_jax_dialects, m) { REGISTER_DIALECT(scf); REGISTER_DIALECT(vector); // For Mosaic GPU + REGISTER_DIALECT(cf); REGISTER_DIALECT(gpu); REGISTER_DIALECT(nvgpu); REGISTER_DIALECT(nvvm); diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 0eb24781379e..be83fd4f6b18 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -155,6 +155,7 @@ cc_library( "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:ExecutionEngine", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index a933b72ad55a..5d812f483de4 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -60,6 +60,7 @@ limitations under the License. #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" @@ -228,12 +229,12 @@ mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes, void InitContext(mlir::MLIRContext* context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); mlir::registerConvertNVVMToLLVMInterface(registry); mlir::registerConvertComplexToLLVMInterface(registry); mlir::registerConvertMemRefToLLVMInterface(registry); diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 56d3bc27c488..0c29a7ae6ea3 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -272,6 +272,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): f"{source_file_prefix}jaxlib/mlir/dialects/_arith_enum_gen.py", f"{source_file_prefix}jaxlib/mlir/dialects/_arith_ops_gen.py", f"{source_file_prefix}jaxlib/mlir/dialects/_builtin_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_cf_ops_gen.py", f"{source_file_prefix}jaxlib/mlir/dialects/_chlo_ops_gen.py", f"{source_file_prefix}jaxlib/mlir/dialects/_func_ops_gen.py", f"{source_file_prefix}jaxlib/mlir/dialects/_math_ops_gen.py", @@ -296,6 +297,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_ops_gen.py", f"{source_file_prefix}jaxlib/mlir/dialects/arith.py", f"{source_file_prefix}jaxlib/mlir/dialects/builtin.py", + f"{source_file_prefix}jaxlib/mlir/dialects/cf.py", f"{source_file_prefix}jaxlib/mlir/dialects/chlo.py", f"{source_file_prefix}jaxlib/mlir/dialects/func.py", f"{source_file_prefix}jaxlib/mlir/dialects/math.py", diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 19048b3a4307..c8e204433a37 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -32,6 +32,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir import passmanager from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import cf from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector 
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member @@ -237,7 +238,6 @@ def capture_stdout(self): mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() - class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest): def setUp(self): @@ -3320,6 +3320,34 @@ def test_parse_indices_oob(self, indices): with self.assertRaisesRegex(IndexError, "out of bounds"): utils.parse_indices(indices, (2, 3, 4)) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. + def test_assert(self): + if cf is None: + self.skipTest("``cf`` is not available") + + def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None: + del ctx, out # Unused. + # TODO(b/408271232): Use a False condition once the bug is fixed. + x = mgpu.FragmentedArray.load_strided(x_ref) + cond = x.reduce_sum(*scratch) != 42.0 + cf.assert_(cond.registers.item(), "OOOPS") + + f = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax.ShapeDtypeStruct((128,), jnp.float32),), + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + smem_scratch_shape=(jax.ShapeDtypeStruct((4,), jnp.float32),), + ) + + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: + f(jnp.ones((128,), jnp.float32)) + + # SASS doesn't seem to include the assertion message, so we are just + # checking that __assertfail appears in the symbol table for the kernel. + self.assertIn("__assertfail", sass()) + class SerializationTest(absltest.TestCase): From 4cd63a7c81114cc39122a48dd970753e38bb96b5 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 1 May 2025 17:18:45 -0700 Subject: [PATCH 0961/1769] Add out_sharding on broadcast_to(). PiperOrigin-RevId: 753786660 --- jax/_src/numpy/lax_numpy.py | 6 ++++-- tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 34417adeceb1..1908529e078c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -65,6 +65,7 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding +from jax._src.sharding_impls import (NamedSharding, PartitionSpec as P) from jax.tree_util import tree_flatten, tree_map import numpy as np @@ -3096,7 +3097,8 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: @export -def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: +def broadcast_to(array: ArrayLike, shape: DimSize | Shape, + *, out_sharding: NamedSharding | P | None = None) -> Array: """Broadcast an array to a specified shape. JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style @@ -3130,7 +3132,7 @@ def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: .. 
_NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html """ - return util._broadcast_to(array, shape) + return util._broadcast_to(array, shape, sharding=out_sharding) def _split(op: str, ary: ArrayLike, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 443880fc502b..dcbfa99d8a27 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7041,6 +7041,20 @@ def iota(): out = iota() self.assertEqual(out.sharding, yz_sharding) + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_broadcast_to(self, mesh): + x = np.arange(24).reshape((1, 24)) + x = jax.device_put(x, P(None, ('y', 'z'))) + + @jax.jit + def f(x): + out = jnp.broadcast_to(x, (8, 24), out_sharding=P('x', ('y', 'z'))) + self.assertEqual(out.aval.sharding.spec, P('x', ('y', 'z'))) + return out + + out = f(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', ('y', 'z')))) + @jtu.with_explicit_mesh((2,), ('x',)) def test_cumsum(self, mesh): np_inp = np.arange(16).reshape(8, 2) From 3b3c1b3e8b0114a95709ab71025f9181c7d25c33 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 1 May 2025 17:58:18 -0700 Subject: [PATCH 0962/1769] Add out_sharding to `jnp.ravel` and `jnp.reshape` and `jnp.dot` which just forward out_sharding to their lax variants. PiperOrigin-RevId: 753797017 --- jax/BUILD | 1 + jax/_src/basearray.pyi | 11 ++-- jax/_src/lax/lax.py | 4 -- jax/_src/numpy/array_methods.py | 7 ++- jax/_src/numpy/lax_numpy.py | 15 ++--- jax/_src/numpy/tensor_contractions.py | 9 ++- jax/numpy/__init__.pyi | 7 ++- tests/pjit_test.py | 81 ++++++++++++++++++++------- 8 files changed, 91 insertions(+), 44 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 7643c96bd0e0..f1a45a48ec1b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -440,6 +440,7 @@ pytype_strict_library( srcs = ["_src/basearray.py"], pytype_srcs = ["_src/basearray.pyi"], deps = [ + ":named_sharding", ":partition_spec", ":sharding", ":util", diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 8bf68f622051..3026b3bd6ab9 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -17,7 +17,8 @@ from types import ModuleType from typing import Any, Protocol, runtime_checkable, Union import numpy as np -from jax._src.partition_spec import PartitionSpec +from jax._src.partition_spec import PartitionSpec as P +from jax._src.named_sharding import NamedSharding from jax._src.sharding import Sharding @@ -183,12 +184,14 @@ class Array(metaclass=abc.ABCMeta): promote_integers: bool = True) -> Array: ... def ptp(self, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: ... - def ravel(self, order: str = 'C') -> Array: ... + def ravel(self, order: str = 'C', + out_sharding: NamedSharding | P | None = ...) -> Array: ... @property def real(self) -> Array: ... def repeat(self, repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) -> Array: ... - def reshape(self, *args: Any, order: str = "C") -> Array: ... + def reshape(self, *args: Any, order: str = "C", + out_sharding: NamedSharding | P | None = ...) -> Array: ... def round(self, decimals: int = 0, out: None = None) -> Array: ... def searchsorted(self, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ... 
@@ -282,7 +285,7 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None, - out_spec: Sharding | PartitionSpec | None = None) -> Array: ... + out_spec: Sharding | P | None = None) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f94f28be32c5..214fc9650505 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2507,10 +2507,6 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. """ - if out_sharding is not None and not isinstance(out_sharding, NamedSharding): - raise NotImplementedError( - '`out_sharding` argument of `dot_general` only supports NamedSharding ' - 'instances. Please file a bug if this is not enough for your use case.') out_sharding = canonicalize_sharding(out_sharding, 'dot_general') (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers cdims = (api_util._ensure_index_tuple(lhs_contract), diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 3783e09bf694..3e65e5d83100 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -300,7 +300,8 @@ def _repeat(self: Array, repeats: ArrayLike, axis: int | None = None, *, """ return lax_numpy.repeat(self, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length) -def _reshape(self: Array, *args: Any, order: str = "C") -> Array: +def _reshape(self: Array, *args: Any, order: str = "C", out_sharding=None + ) -> Array: """Returns an array containing the same data with a new shape. Refer to :func:`jax.numpy.reshape` for full documentation. @@ -308,10 +309,10 @@ def _reshape(self: Array, *args: Any, order: str = "C") -> Array: __tracebackhide__ = True newshape = _compute_newshape(self, args[0] if len(args) == 1 else args) if order == "C": - return lax.reshape(self, newshape, None) + return lax.reshape(self, newshape, None, out_sharding=out_sharding) elif order == "F": dims = list(range(self.ndim)[::-1]) - return lax.reshape(self, newshape[::-1], dims).T + return lax.reshape(self, newshape[::-1], dims, out_sharding=out_sharding).T elif order == "A": raise NotImplementedError("np.reshape order=A is not implemented.") else: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1908529e078c..4b355ce60436 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1946,7 +1946,7 @@ def isrealobj(x: Any) -> bool: @export def reshape( a: ArrayLike, shape: DimSize | Shape, order: str = "C", *, - copy: bool | None = None) -> Array: + copy: bool | None = None, out_sharding=None) -> Array: """Return a reshaped copy of an array. 
JAX implementation of :func:`numpy.reshape`, implemented in terms of @@ -2020,16 +2020,17 @@ def reshape( util.check_arraylike("reshape", a) try: - # forward to method for ndarrays - return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] + if out_sharding is None: + # forward to method for ndarrays + return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] except AttributeError: pass - return asarray(a).reshape(shape, order=order) + return asarray(a).reshape(shape, order=order, out_sharding=out_sharding) @export -@partial(jit, static_argnames=('order',), inline=True) -def ravel(a: ArrayLike, order: str = "C") -> Array: +@partial(jit, static_argnames=('order', 'out_sharding'), inline=True) +def ravel(a: ArrayLike, order: str = "C", *, out_sharding=None) -> Array: """Flatten array into a 1-dimensional shape. JAX implementation of :func:`numpy.ravel`, implemented in terms of @@ -2078,7 +2079,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: a = util.ensure_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") - return reshape(a, (np.size(a),), order) + return reshape(a, (np.size(a),), order, out_sharding=out_sharding) @export diff --git a/jax/_src/numpy/tensor_contractions.py b/jax/_src/numpy/tensor_contractions.py index 990f17c2b23e..979f68e28f6d 100644 --- a/jax/_src/numpy/tensor_contractions.py +++ b/jax/_src/numpy/tensor_contractions.py @@ -36,10 +36,12 @@ export = set_module('jax.numpy') @export -@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) +@partial(jit, static_argnames=('precision', 'preferred_element_type', 'out_sharding'), + inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + out_sharding=None) -> Array: """Compute the dot product of two arrays. JAX implementation of :func:`numpy.dot`. @@ -119,7 +121,8 @@ def dot(a: ArrayLike, b: ArrayLike, *, contract_dims = ((a_ndim - 1,), (b_ndim - 2,)) result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=out_sharding) return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index df6454c9a1f1..1e8e900f1e04 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -350,7 +350,8 @@ def divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: DTypeLike | None = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... double: Any def dsplit( ary: ArrayLike, indices_or_sections: int | ArrayLike @@ -798,7 +799,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = .. r_: _RClass def rad2deg(x: ArrayLike, /) -> Array: ... def radians(x: ArrayLike, /) -> Array: ... -def ravel(a: ArrayLike, order: str = ...) -> Array: ... +def ravel(a: ArrayLike, order: str = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = ..., order: str = ...) -> Array: ... 
def real(x: ArrayLike, /) -> Array: ... @@ -809,6 +811,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, total_repeat_length: int | None = ...) -> Array: ... def reshape( a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ..., + out_sharding: NamedSharding | P | None = ..., ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... diff --git a/tests/pjit_test.py b/tests/pjit_test.py index dcbfa99d8a27..0bd570f37004 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5511,36 +5511,43 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None, None))) self.assertArraysEqual(out, np_inp.reshape(2, 32, 1)) - @parameterized.named_parameters( - ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False), - ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), - ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True), - ('4', (1, 4, 1, 6, 1), (1, 4, 6), - P(None, 'x', None, None, None), P(None, 'x', None), False), - ('5', (4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), - ('6', (1024, 4096), (1024, 2048, 2, 1, 1, 1, 1), - P('x', None), P('x', None, None, None, None, None, None), False), - ('7', (1024, 4096, 32), (1024, 2048, 2, 1, 1, 32), - P('x', None, None), P('x', None, None, None, None, None), False), - ('8', (1024, 4096), (1024, 1, 1, 4096), - P('x', None), P('x', None, None, None), False), - ('9', (1024, 4096), (1024, 1, 1, 4096), - P(None, 'x'), P(None, None, None, 'x'), False), - ('10', (1024, 2048, 2, 1, 1, 1), (1024, 4096), - P('x', None, None, None, None, None), P('x', None), False), - ('11', (1024, 2048, 2, 1, 1, 1), (1024, 4096), - P(None, 'x', None, None, None, None), P(None, 'x'), False), + @parameterized.parameters( + (src_shape, dst_shape, src_spec, dst_spec, use_sharding_arg, fun) + for fun in [jnp.reshape, jax.lax.reshape] + for src_shape, dst_shape, src_spec, dst_spec, use_sharding_arg in [ + ((16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), + False), + ((8, 2, 1), (1, 16, 1), P('x', None, None), + P(None, 'x', None), True), + ((8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), + True), + ((1, 4, 1, 6, 1), (1, 4, 6), + P(None, 'x', None, None, None), P(None, 'x', None), False), + ((4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + ((1024, 4096), (1024, 2048, 2, 1, 1, 1, 1), + P('x', None), P('x', None, None, None, None, None, None), False), + ((1024, 4096, 32), (1024, 2048, 2, 1, 1, 32), + P('x', None, None), P('x', None, None, None, None, None), False), + ((1024, 4096), (1024, 1, 1, 4096), + P('x', None), P('x', None, None, None), False), + ((1024, 4096), (1024, 1, 1, 4096), + P(None, 'x'), P(None, None, None, 'x'), False), + ((1024, 2048, 2, 1, 1, 1), (1024, 4096), + P('x', None, None, None, None, None), P('x', None), False), + ((1024, 2048, 2, 1, 1, 1), (1024, 4096), + P(None, 'x', None, None, None, None), P(None, 'x'), False), + ] ) @jtu.with_explicit_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, - use_sharding_arg, mesh): + use_sharding_arg, fun, mesh): np_inp = np.arange(math.prod(src_shape), dtype=np.float32).reshape(src_shape) arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec)) @partial(jax.jit, static_argnums=1) def f(x, new_sharding): - y = lax.reshape(x, dst_shape, out_sharding=new_sharding) + y = fun(x, dst_shape, out_sharding=new_sharding) self.assertEqual(y.aval.sharding.spec, dst_spec) self.assertEqual(y.shape, dst_shape) y = y * 2 @@ -7597,6 +7604,38 @@ def f(x): 
jax.jit(jax.grad(lambda x: f(x).sum()))(arr) # doesn't crash + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_dot(self, mesh): + np_inp1 = np.arange(16).reshape(8, 2) + np_inp2 = np.arange(16).reshape(2, 8) + arr1 = jax.device_put(np_inp1, P('x', 'y')) + arr2 = jax.device_put(np_inp2, P('x', 'y')) + + @jax.jit + def f(x, y): + out = jnp.dot(x, y, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np.dot(np_inp1, np_inp2)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_ravel(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @jax.jit + def f(x): + out = jnp.ravel(x, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x')) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.ravel(np_inp)) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 6e5802315b241a2eef2c85e2d71b1d8c510cb9f5 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 1 May 2025 18:02:26 -0700 Subject: [PATCH 0963/1769] Remove prints PiperOrigin-RevId: 753797982 --- jax/_src/pallas/fuser/block_spec.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 9e7e18c590dd..96b3a54635f6 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -887,7 +887,6 @@ def _offset_indexer( slice_size, ): # Short-circuit if the slice start is just at zero. - print('BS', bs, indexer, slice_start, slice_size) if isinstance(slice_start, int) and slice_start == 0: return indexer match bs: @@ -1043,10 +1042,6 @@ def new_index_map(*args): # the slice are then given by (i // b_l, j // b_m, k // b_n). 
# We then add these block indices to block indices produced by the index # map - print('BLOCK SHAPE', block_spec.block_shape) - print('INDEXER', idx) - print('SLICE starts', slice_starts) - print('SLICE sizes', slice_sizes) block_indices = tuple( _offset_indexer(s, i, start, size) for i, s, start, size in zip( From 0f169a2e4cc5a27555a7e1a886745abe7cebc95b Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 1 May 2025 21:06:35 -0700 Subject: [PATCH 0964/1769] [Pallas/Fuser] Ignore ops that have no_block_spec being pulled PiperOrigin-RevId: 753840368 --- jax/_src/pallas/fuser/block_spec.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 96b3a54635f6..38afb1d7d34a 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -35,6 +35,7 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas.fuser import fuser_utils from jax._src.state import indexing +from jax._src.state import primitives as state_primitives import jax.numpy as jnp import numpy as np @@ -326,7 +327,7 @@ def _pull_block_spec( def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any: if isinstance(atom, core.Literal): return pallas_core.no_block_spec - return env[atom] + return env.get(atom, pallas_core.no_block_spec) def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): if isinstance(atom, core.Literal): @@ -335,9 +336,11 @@ def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): for i, eqn in reversed(list(enumerate(jaxpr.eqns))): eqn_out_block_specs = tuple(util.safe_map(_read_block_spec, eqn.outvars)) + if all(bs is pallas_core.no_block_spec for bs in eqn_out_block_specs): + continue rule = pull_block_spec_rules.get(eqn.primitive, None) if not rule: - raise NotImplementedError(eqn.primitive) + raise NotImplementedError(eqn.primitive, eqn_out_block_specs) ctx = PullRuleContext( avals_in=tuple(v.aval for v in eqn.invars), avals_out=tuple(v.aval for v in eqn.outvars), @@ -475,7 +478,7 @@ def sds_like(x): def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any: if isinstance(atom, core.Literal): return pallas_core.no_block_spec - return bs_env[atom] + return bs_env.get(atom, pallas_core.no_block_spec) def kernel_fn(program_ids, scalar_prefetch, *args, **kwargs): def _check_args(prefix, path, x, y, usage): @@ -801,11 +804,18 @@ def _eval_function(_, x, y): return [l_block_spec, r_block_spec] +def register_default_eval_rule(prim: core.Primitive): + def default_rule(ctx, *args, **params): + assert all(bs is pallas_core.no_block_spec for bs in ctx.out_block_specs) + return prim.bind(*args, **params) + register_eval_rule(prim)(default_rule) + def register_binop_rule(prim: core.Primitive): register_pull_block_spec_rule(prim)(functools.partial(_binop_pull_rule, prim)) register_usage_rule(prim)(functools.partial(_binop_usage_rule, prim)) register_eval_rule(prim)(functools.partial(_binop_eval_rule, prim)) +register_default_eval_rule(state_primitives.get_p) register_binop_rule(lax.mul_p) register_binop_rule(lax.add_p) From 6951cb92ce1aef51b80f0f94da9db17361dffdbf Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 1 May 2025 22:23:14 -0700 Subject: [PATCH 0965/1769] [Pallas Fuser] Add basic sublane/lane reshape fusion PiperOrigin-RevId: 753859510 --- jax/_src/pallas/fuser/block_spec.py | 78 +++++++++++++++++++++++++++ tests/pallas/fuser_block_spec_test.py | 29 ++++++++++ 2 files changed, 107 
insertions(+) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 38afb1d7d34a..ba2c182014b5 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -1331,6 +1331,84 @@ def _iota_pull_rule( return [] +def _pattern_match_sublanes_to_lanes_reshape( + aval_in: core.ShapedArray, + aval_out: core.ShapedArray, +) -> bool: + # Pattern matches a reshape of the form (..., n/l, l) -> (..., n * l) + # where l is a multiple of 128 n/l is a multiple of packing. + + *leading_in, second_to_last_dim, last_dim = aval_in.shape + *leading_out, last_dim_out = aval_out.shape + if leading_in != leading_out: + return False + assert last_dim_out == second_to_last_dim * last_dim + if last_dim % 128 != 0: + return False + return True + + +@register_pull_block_spec_rule(lax.reshape_p) +def _reshape_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + dimensions: tuple[int, ...] | None, + new_sizes: tuple[int, ...], + sharding: jax.sharding.Sharding, +): + del sharding, new_sizes + if dimensions is not None: + raise NotImplementedError('reshape with None dimensions not supported yet') + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + aval_out = ctx.avals_out[0] + assert isinstance(aval_out, core.ShapedArray) + if _pattern_match_sublanes_to_lanes_reshape(aval_in, aval_out): + block_shape = tuple(block_spec.block_shape) + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + last_dim = _block_size(block_shape[-1]) + if last_dim % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + # We can now reshape last dim from d -> (d/128, 128) + new_block_shape = block_shape[:1] + (last_dim // 128, 128) + + def new_index_map(*args): + idx = block_spec.index_map(*args) + return *idx, 0 + + return [pallas_core.BlockSpec(new_block_shape, new_index_map)] + raise NotImplementedError(f'reshape not supported yet: {aval_in}, {aval_out}') + + +@register_eval_rule(lax.reshape_p) +def _reshape_eval_rule( + eval_ctx: KernelEvalContext, x, *, dimensions, new_sizes, sharding +): + del sharding, dimensions, new_sizes + out_shape_nones = tuple( + _block_size(s) for s in eval_ctx.out_block_specs[0].block_shape + ) + out_shape = tuple(s for s in out_shape_nones if s is not None) + # Because we have restricted the pull block spec rule, we can just apply a + # basic reshape here. 
+ orig_dtype = x.dtype + if jnp.issubdtype(orig_dtype, jnp.integer): + x = x.astype(jnp.int32) + elif jnp.issubdtype(orig_dtype, jnp.floating): + x = x.astype(jnp.float32) + x = x.reshape(out_shape) + return x.astype(orig_dtype) + + +# Higher order primitives + + @register_usage_rule(pjit.pjit_p) def _jit_usage_rule( ctx, used_out: list[set[Usage]], *, jaxpr: core.ClosedJaxpr, **_ diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index ac82cd5f1b35..665cdfb1dd6b 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -732,6 +732,35 @@ def f(): x_block, ) + def test_basic_reshape(self): + + def f(x): + return x.reshape((512, 2048)) + + in_type = jax.ShapeDtypeStruct((512, 16, 128), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 1024), lambda i, j, k: (i, k)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertEmpty(value_block_specs) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + x = jnp.arange((256 * 1024), dtype=jnp.float32).reshape((256, 8, 128)) + y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) + np.testing.assert_array_equal(y, x.reshape((256, 1024))) + class PullBlockSpecHOPTest(jtu.JaxTestCase): From 8bf790ec5ec037fee7bc080e364110c2106bc173 Mon Sep 17 00:00:00 2001 From: Sannidhya Chauhan Date: Thu, 1 May 2025 23:45:14 -0700 Subject: [PATCH 0966/1769] Added options into the profiler. PiperOrigin-RevId: 753880517 --- jax/_src/profiler.py | 41 +++++++++++++++++++++++++++++++---------- jax/profiler.py | 21 +++++++++++---------- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 8fdf9953653f..6b58b2ba6326 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -39,6 +39,10 @@ logger = logging.getLogger(__name__) +class ProfileOptions(_profiler.ProfileOptions): + """Profiler Options to configure the collectors for the profiler.""" + + def start_server(port: int) -> _profiler.ProfilerServer: """Starts the profiler server on port `port`. @@ -89,12 +93,17 @@ def reset(self): _profile_state = _ProfileState() -def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, - create_perfetto_trace: bool = False) -> None: +def start_trace( + log_dir: os.PathLike | str, + create_perfetto_link: bool = False, + create_perfetto_trace: bool = False, + profiler_options: ProfileOptions | None = None, +) -> None: """Starts a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python - functions and JAX on-device operations. Use :func:`stop_trace` to end the trace + functions and JAX on-device operations. Use :func:`stop_trace` to end the + trace and save the results to ``log_dir``. The resulting trace can be viewed with TensorBoard. Note that TensorBoard @@ -113,8 +122,8 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, ``perfetto_trace.json.gz`` file that is compatible for upload with the Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. 
This could be useful if you - want to generate a Perfetto-compatible trace without blocking the - process. + want to generate a Perfetto-compatible trace without blocking the process. + profiler_options: Profiler options to configure the profiler for collection. """ with _profile_state.lock: if _profile_state.profile_session is not None: @@ -126,7 +135,12 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, # fail and no TPU operations will be included in the profile. xla_bridge.get_backend() - _profile_state.profile_session = _profiler.ProfilerSession() + if profiler_options is None: + _profile_state.profile_session = _profiler.ProfilerSession() + else: + _profile_state.profile_session = _profiler.ProfilerSession( + profiler_options + ) _profile_state.create_perfetto_link = create_perfetto_link _profile_state.create_perfetto_trace = ( create_perfetto_trace or create_perfetto_link) @@ -225,7 +239,12 @@ def stop_and_get_fdo_profile() -> bytes | str: @contextmanager -def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfetto_trace=False): +def trace( + log_dir: os.PathLike | str, + create_perfetto_link=False, + create_perfetto_trace=False, + profiler_options: ProfileOptions | None = None, +): """Context manager to take a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python @@ -247,10 +266,12 @@ def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfett ``perfetto_trace.json.gz`` file that is compatible for upload with the Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. This could be useful if you - want to generate a Perfetto-compatible trace without blocking the - process. + want to generate a Perfetto-compatible trace without blocking the process. + profiler_options: Profiler options to configure the profiler for collection. 
""" - start_trace(log_dir, create_perfetto_link, create_perfetto_trace) + start_trace( + log_dir, create_perfetto_link, create_perfetto_trace, profiler_options + ) try: yield finally: diff --git a/jax/profiler.py b/jax/profiler.py index 77157dc02a13..31f3ea186d79 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -16,14 +16,15 @@ # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.profiler import ( - StepTraceAnnotation as StepTraceAnnotation, - TraceAnnotation as TraceAnnotation, - device_memory_profile as device_memory_profile, - save_device_memory_profile as save_device_memory_profile, - start_server as start_server, - stop_server as stop_server, - start_trace as start_trace, - stop_trace as stop_trace, - trace as trace, - annotate_function as annotate_function, + ProfileOptions as ProfileOptions, + StepTraceAnnotation as StepTraceAnnotation, + TraceAnnotation as TraceAnnotation, + annotate_function as annotate_function, + device_memory_profile as device_memory_profile, + save_device_memory_profile as save_device_memory_profile, + start_server as start_server, + start_trace as start_trace, + stop_server as stop_server, + stop_trace as stop_trace, + trace as trace, ) From 52c99642c6b60031b94d42c8622a6700672d7315 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 2 May 2025 00:06:56 -0700 Subject: [PATCH 0967/1769] [Mosaic TPU] Add simplify pass after canonicalize-mosaic PiperOrigin-RevId: 753886330 --- .../mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 9a9e594f928e..7f241deb550c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -590,15 +590,6 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero); op.replaceAllUsesWith(new_op.getResult()); op.erase(); - // We briefly trigger canonicalization here to potentially fuse the rounding - // ops into the newly created tpu.fptosi. - { - PatternRewriter rewriter(new_op.getContext()); - rewriter.setInsertionPoint(new_op); - // We don't care if the canonicalization pattern matched or not. - (void)tpu::FPToSIOp::canonicalize(new_op, rewriter); - new_op = nullptr; // Canonicalization may have erased the op! - } return success(); } Value x = op.getIn(); From ff1672b3278450ab853cd7c78b2721edb63e9c15 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 2 May 2025 01:35:30 -0700 Subject: [PATCH 0968/1769] [Mosaic GPU] Refactor how gmem scratch and TMA descriptors are initialized This puts all the relevant functionality in a single place - the new `Scratch` class. It also ensures that ops are created on-demand. This allows us to run passes such as `canonicalize` without worrying that they will remove the dead code that was previously created eagerly. 
PiperOrigin-RevId: 753908639 --- jax/_src/pallas/mosaic_gpu/lowering.py | 4 +- jax/experimental/mosaic/gpu/core.py | 51 ++------ jax/experimental/mosaic/gpu/launch_context.py | 109 ++++++++++++++++-- 3 files changed, 110 insertions(+), 54 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b4f33ef3e7a4..1dd5a8f68d3a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -803,7 +803,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): # Each range is 2 events, each event is 4 bytes. prof_spec = mgpu_profiler.ProfilerSpec(params.profile_space * 2 * 4) prof_ctx = ProfilerContext(params.profile_dir, prof_spec) - module, new_out_shapes, _, launch_ctx, scratch_arr = ( + module, new_out_shapes, _, launch_ctx = ( mgpu_core._lower_as_gpu_kernel( body, grid=tuple(map(operator.mul, parallel_grid, cluster)), @@ -825,7 +825,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mgpu.infer_transforms(module) # pytype: disable=attribute-error mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - mgpu_core._initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() if gmem_scratch_shapes: new_out_shapes = new_out_shapes[:-len(gmem_scratch_shapes)] diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 21f5278829d5..9d46c6803186 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -387,7 +387,6 @@ def _launch( grid: tuple[int, int, int], cluster: tuple[int, int, int], block: tuple[int, int, int], - scratch_arr, smem_buffers: ShapeTree | Union[ShapeTree], lowering_semantics: LoweringSemantics, profiler_spec: profiler.ProfilerSpec | None = None, @@ -456,9 +455,9 @@ def _launch( else: prof = None - ptr_ty = ir.Type.parse("!llvm.ptr") - scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof) + ctx = launch_context.LaunchContext( + launch_context.Scratch(launch_op), cluster, prof + ) with ctx.named_region("Init"): tmem_allocs: list[_TMEMAlloc] = [] smem_ref_tree_thunk = _construct_smem_reftree( @@ -510,7 +509,6 @@ def _lower_as_gpu_kernel( ptr_ty = ir.Type.parse("!llvm.ptr") token_ty = ir.Type.parse("!gpu.async.token") i32 = ir.IntegerType.get_signless(32) - i64 = ir.IntegerType.get_signless(64) def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: return ir.MemRefType.get(shape.shape, utils.dtype_to_ir_type(shape.dtype)) @@ -536,7 +534,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: kernel_name = getattr(body, "__name__", "anonymous") # These are needed as nonlocal below. 
- launch_ctx, scratch_arr = None, None + launch_ctx = None with ir.InsertionPoint(module.body): _declare_runtime_functions() global_scratch = llvm.GlobalOp( @@ -547,7 +545,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: ) @func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}") def main(token_ptr, buffers): - nonlocal launch_ctx, scratch_arr + nonlocal launch_ctx token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): @@ -556,14 +554,8 @@ def main(token_ptr, buffers): in_refs = arg_refs[:len(in_ref_tys)] out_refs = arg_refs[len(in_ref_tys):] prof_buffer = out_refs.pop() if prof_spec is not None else None - empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") - scratch_alloc = llvm.AllocaOp( - ptr_ty, c(1, i64), empty_arr_ty, - alignment=launch_context.TMA_DESCRIPTOR_ALIGNMENT - ) - scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) with _launch( - token, grid, cluster, block, scratch_arr, smem_scratch_shape, + token, grid, cluster, block, smem_scratch_shape, lowering_semantics, prof_spec, prof_buffer ) as (_launch_ctx, smem_refs): nonlocal launch_ctx @@ -575,7 +567,7 @@ def main(token_ptr, buffers): sym_tab.insert(global_scratch) module.operation.verify() - return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr + return module, out_shape, unwrap_output_tuple, launch_ctx def _run_serde_pass( @@ -602,27 +594,6 @@ def _run_serde_pass( return module -def _initialize_scratch( - launch_ctx : launch_context.LaunchContext, - scratch_arr: ir.Value, - ): - """ - Allocates and initializes the host buffer right before the launch. This needs - to be done after all TMA descriptors have been recorded by the launch context. - Only then we know what the scratch contains. - - When using the Mosaic GPU dialect, the necessary information is known only - after the lowering passes have run. 
- """ - with ir.InsertionPoint(scratch_arr.owner): - gmem_scratch_bytes = launch_ctx.next_scratch_offset - scratch_alloc_op = scratch_arr.owner.opview.addr.owner.opview - scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") - scratch_alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty) - scratch_arr.set_type(scratch_arr_ty) - for init_callback in launch_ctx.host_scratch_init: - init_callback(scratch_alloc_op.result) - def _declare_runtime_functions(): """Declares the runtime functions that can be used by the generated code.""" ptr_ty = ir.Type.parse("!llvm.ptr") @@ -653,7 +624,7 @@ def as_gpu_kernel( elif not isinstance(in_shape, tuple): in_shape = (in_shape,) - module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( + module, out_shape, unwrap_output_tuple, launch_ctx = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, thread_semantics, module_name, kernel_name, prof_spec @@ -667,7 +638,7 @@ def as_gpu_kernel( transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - _initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() module.operation.verify() expected_arg_treedef = jax.tree.structure(in_shape) @@ -736,7 +707,7 @@ def as_torch_gpu_kernel( flat_out_types, out_treedef = jax.tree.flatten(out_shape) expected_arg_treedef = jax.tree.structure(in_shape) - module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( + module, out_shape, unwrap_output_tuple, launch_ctx = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, lowering_semantics, module_name, kernel_name, prof_spec @@ -750,7 +721,7 @@ def as_torch_gpu_kernel( transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - _initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() module.operation.verify() # Get our hands on the compilation and unload functions diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 2ec9047402a4..fbefd027b53e 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -25,6 +25,7 @@ from jax._src import lib as jaxlib from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import func from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm @@ -231,16 +232,100 @@ def batch(self, leading_rank: int) -> MemRefTransform: ReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] +class Scratch: + """Manages ops handling the GMEM scratch that contains the TMA descriptors. + + TMA descriptors are created on the host and then copied to GMEM. So there + needs to be some code on the host to allocate and initialize the TMA + descriptors. However, we only know what descriptors we need after we have + lowered the entire kernel. This class helps manage everything needed to + correctly allocate and initialize the scratch. + + To help reconcile the needs of kernels that use the dialect lowering with + those that use MGPU APIs directly, this class only creates the relevant ops + lazily. Eager creation would make them appear dead before dialect lowering + and MLIR's DCE would remove them. 
+ + During the lowering, we collect information about how many bytes are needed + and also how each descriptor should be initialized on the host. At the end + of the lowering, the finalize_size() method should be called to add the + necessary code on the host to allocate and initialize all descriptors. + """ + def __init__(self, gpu_launch_op: gpu.LaunchOp): + self.next_offset: int = 0 + self.host_init: list[Callable[[ir.Value], None]] = [] + self._alloc_op = None + self._load_op = None + self._scratch_ptr = None + + # Ideally, we would store the gpu.launch op directly. However, it gets + # invalidated by passes like "canonicalize". Thus we store the module and + # find the gpu.launch op from there when needed. + op = gpu_launch_op + while op.name != "builtin.module": + op = op.parent.opview + assert op is not None + self._module_op = op + + def _find_gpu_launch_op(self, block: ir.Block) -> ir.OpView | None: + for op in block: + if op.name == "gpu.launch": + return op + for region in op.regions: + for block in region: + child_op = self._find_gpu_launch_op(block) + if child_op is not None: + return child_op + return None + + def _create_ops_if_none(self): + if self._alloc_op is not None: + return + + gpu_launch_op = self._find_gpu_launch_op(self._module_op.body) + assert gpu_launch_op is not None + ptr_ty = ir.Type.parse("!llvm.ptr") + with ir.InsertionPoint(gpu_launch_op): + empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") + i64 = ir.IntegerType.get_signless(64) + self._alloc_op = llvm.AllocaOp( + ptr_ty, c(1, i64), empty_arr_ty, + alignment=TMA_DESCRIPTOR_ALIGNMENT + ) + self._load_op = llvm.LoadOp(empty_arr_ty, self._alloc_op) + + with ir.InsertionPoint.at_block_begin(gpu_launch_op.body.blocks[0]): + self._scratch_ptr = builtin.unrealized_conversion_cast( + [ptr_ty], [self._load_op] + ) + + def device_ptr(self) -> ir.Value: + self._create_ops_if_none() + return self._scratch_ptr + + def finalize_size(self): + """ + Allocates and initializes the host buffer. This needs to be done after + lowering, i.e. after all TMA descriptors have been recorded. Only then we + know what the scratch contains. 
+ """ + if self.next_offset == 0: + return + assert self._alloc_op is not None + with ir.InsertionPoint(self._load_op): + gmem_scratch_bytes = self.next_offset + scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") + self._alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty) + self._load_op.result.set_type(scratch_arr_ty) + for init_callback in self.host_init: + init_callback(self._alloc_op.result) + + @dataclasses.dataclass() class LaunchContext: - launch_op: gpu.LaunchOp - gmem_scratch_ptr: ir.Value + scratch: Scratch cluster_size: tuple[int, int, int] profiler: OnDeviceProfiler | None = None - next_scratch_offset: int = 0 - host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( - default_factory=list, init=False - ) tma_descriptors: dict[ tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], ir.Value, @@ -288,21 +373,21 @@ def _alloc_scratch( ptr_ty = ir.Type.parse("!llvm.ptr") if alignment is None: alignment = size - if self.next_scratch_offset % alignment: + if self.scratch.next_offset % alignment: raise NotImplementedError # TODO(apaszke): Pad to match alignment - alloc_base = self.next_scratch_offset - self.next_scratch_offset += size + alloc_base = self.scratch.next_offset + self.scratch.next_offset += size def host_init_wrapped(host_ptr): host_init( llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8, llvm.GEPNoWrapFlags.none) ) - self.host_scratch_init.append(host_init_wrapped) + self.scratch.host_init.append(host_init_wrapped) # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): # There is no way to create an insertion point after an operation... gep = llvm.GEPOp( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8, llvm.GEPNoWrapFlags.none + ptr_ty, self.scratch.device_ptr(), [], [alloc_base], i8, llvm.GEPNoWrapFlags.none ) - gep.move_after(self.gmem_scratch_ptr.owner) + gep.move_after(self.scratch.device_ptr().owner) return device_init(gep.result) def _get_tma_desc( From e64049f44fcc11fbae9a3b05c15bb1708a8d5afa Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 2 May 2025 06:06:42 -0700 Subject: [PATCH 0969/1769] [jaxlib] Pack/unpack subbyte types to/from numpy arrays to support int2, uint2, int4, uint4, float4_e2m1fn subbyte types in CPU/GPU callbacks. 
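Concretely, host callbacks can now round-trip packed subbyte data. The following condenses the tests updated below (assumptions: a jaxlib with extension version >= 336, per the version bump in this change, and a CPU or GPU backend, since the TPU runtime still aborts for int2/uint2/float4_e2m1fn per the test skips):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def f(x):
        # The runtime unpacks the packed int4 operand into a byte-per-element
        # numpy array for the host function, then repacks the int4 result.
        return jax.pure_callback(
            lambda v: v, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

    x = jnp.arange(8, dtype=jnp.int32).astype(jnp.int4)
    np.testing.assert_array_equal(jax.jit(f)(x), x)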
Reverts a099b285307508efad12a015d6f6d9d13ae49077 PiperOrigin-RevId: 753974925 --- jaxlib/BUILD | 1 + jaxlib/cuda/BUILD | 1 + jaxlib/gpu/py_client_gpu.cc | 105 ++++++++++++++++++++++------------ jaxlib/py_client_cpu.cc | 91 +++++++++++++++++++++-------- jaxlib/rocm/BUILD | 1 + jaxlib/xla_client.py | 2 +- tests/python_callback_test.py | 91 +++++++++++++++++------------ 7 files changed, 195 insertions(+), 97 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5165f5cf2520..8dabee20038c 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -933,6 +933,7 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 3895b067c87f..c872433ce04a 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -639,6 +639,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 570e3135b1c9..0afa3a9bf1d5 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -45,6 +45,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -83,8 +84,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -93,12 +93,19 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, host_input_buffers[i] = nullptr; continue; } - void* buf = new char[arg->size_bytes()]; - host_input_buffers[i] = buf; + size_t size_bytes = arg->size_bytes(); + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + size_bytes = arg->element_count() * bits_per_element / 8; + } + host_input_buffers[i] = new char[size_bytes]; // TODO(b/238441608): Use pinned memory here to speed up the transfer. 
auto gpu_res = - gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), - gpuMemcpyDeviceToHost, stream); + gpuMemcpyAsync(host_input_buffers[i], arg.value().untyped_data(), + size_bytes, gpuMemcpyDeviceToHost, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) @@ -114,9 +121,6 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -124,6 +128,24 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + // We pass in data using default numpy layout i.e., std::nullopt. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + auto size_bytes = arg->element_count() * bits_per_element / 8; + auto buffer = xla::UnpackIntN( + bits_per_element, static_cast(host_input_buffers[i]), + size_bytes); + delete[] static_cast(host_input_buffers[i]); + host_input_buffers[i] = buffer.release(); + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); @@ -148,8 +170,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. 
- if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -170,32 +191,46 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - continue; + + const void* data = array.data(); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + void* temp = new char[size_bytes]; + temp_buffers.push_back(temp); + plan->Execute(data, temp); + data = temp; } - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), temp); - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes, gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/py_client_cpu.cc b/jaxlib/py_client_cpu.cc index e6778a69a9c3..1943244b51be 100644 --- a/jaxlib/py_client_cpu.cc +++ b/jaxlib/py_client_cpu.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -81,8 +82,7 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -98,9 +98,21 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + std::unique_ptr buffer; + const void* data = arg->untyped_data(); + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + size_t size_bytes = arg->element_count() * bits_per_element / 8; + buffer = xla::UnpackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + } // We pass in data using default numpy layout i.e., std::nullopt. - auto array = - nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); + auto array = nb_numpy_ndarray(dtype, dims, std::nullopt, data); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -121,9 +133,8 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. 
- if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -143,26 +154,56 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); - continue; + + const void* data = array.data(); + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer + // supplied by FFI directly. + buffer = std::make_unique(size_bytes); + plan->Execute(data, buffer.get()); + data = buffer.get(); + } else { + plan->Execute(data, ret->untyped_data()); + data = ret->untyped_data(); + } } - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions_size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; + } + + // Copy data to output buffer if haven't already or modified the data to + // write back. 
+ if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, size_bytes); } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 8fc988137ff7..75406174dd93 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -541,6 +541,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index def2318ae75b..6d03f4530631 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 335 +_version = 336 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 4aac07992ca8..9f7336548d12 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -31,6 +31,7 @@ from jax.experimental import io_callback from jax.experimental import pjit from jax._src.shard_map import shard_map +from jax._src.lib import jaxlib_extension_version import jax.numpy as jnp from jax.sharding import Mesh import numpy as np @@ -585,8 +586,15 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_operands(self, dtype: str): + if jaxlib_extension_version < 336: + self.skipTest("Requires jaxlib_extension_version >= 336.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(x): return x def f(x): @@ -597,19 +605,17 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)(x).block_until_ready() + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_results(self, dtype: str): + if jaxlib_extension_version < 336: + self.skipTest("Requires jaxlib_extension_version >= 336.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(): return np.arange(8, dtype=dtype) @@ -620,16 +626,43 @@ def f(): ) return y - # TODO(b/395428868): Remove this check once we support subbyte types. 
- if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)().block_until_ready() + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_non_default_stride_subbyte_results(self, dtype: str): + if jaxlib_extension_version < 336: + self.skipTest("Requires jaxlib_extension_version >= 336.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) class PureCallbackTest(jtu.JaxTestCase): @@ -1088,20 +1121,6 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - class IOCallbackTest(jtu.JaxTestCase): From c8f7e39f4116686514ae83b340551ecdc145a341 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 2 May 2025 06:51:07 -0700 Subject: [PATCH 0970/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8f46fb3d56b8d9b049d3158cd12838c24bc95307. PiperOrigin-RevId: 753985467 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b409fd17957a..4b4a60c182b9 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "15565b8da6d85e9faec669cb22878a0e44cca4ee" -XLA_SHA256 = "360d260d4da982da900d783d3a2705a5fe9133f0e130c0436485bf4477d60ff0" +XLA_COMMIT = "8f46fb3d56b8d9b049d3158cd12838c24bc95307" +XLA_SHA256 = "786eefd5383f1264dcb1de71bdc1091b9fdb38a52ec859514e20015458206a78" def repo(): tf_http_archive( From b26d855324f8378bed329c30bc92966a8876c763 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 2 May 2025 10:17:26 -0400 Subject: [PATCH 0971/1769] Skip Pallas tests that are failing in CI on L4 GPUs. 
--- tests/pallas/ops_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index bf60408e96b7..bc29e741f119 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1893,6 +1893,15 @@ def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): > (256 * 256) * 2 ): self.skipTest("Shared memory size limit exceeded") + if (jax.local_devices()[0].device_kind == "NVIDIA L4" and + dtype == jnp.float32 and + lhs_and_rhs_shape in [ + ((128, 16), (128, 256)), + ((16, 128), (128, 256)), + ((16, 256), (256, 128)), + ((256, 16), (256, 128)), + ]): + self.skipTest("Shared memory size limit exceeded") if min(*lhs_shape, *rhs_shape) < 16: self.skipTest("All dimensions of lhs and rhs must be >= 16") if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape): From f4626f42283ba76bf0e986d263ba60ab4d930fe5 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 2 May 2025 07:26:34 -0700 Subject: [PATCH 0972/1769] Update ManualComputation round-tripping to not have round-tripping attrs on the CallOps. This is needed as sometimes the op has `stablehlo.token` types, and during MLIR<->HLO round-tripping, MLIR type conversion converts this to `mhlo.token` and "accidentally" discards the unregistered attrs (the `frontend_attributes`) - note MLIR has no guarantees about preserving these, so we can't submit a fix to MLIR. This happens because StableHLO->HLO conversion still does StableHLO->MHLO conversion, but this should be removed soon. To unblock ourselves instead of waiting for StableHLO to remove the intermediate MHLO pass, we can just move the frontend attrs to the GlobalToLocal/LocalToGlobal custom calls. PiperOrigin-RevId: 753995222 --- ...ardy_sharding_ops_with_different_meshes.py | 79 +++++++++++++++++++ tests/BUILD | 4 + tests/export_back_compat_test.py | 11 ++- 3 files changed, 91 insertions(+), 3 deletions(-) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py b/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py index b54234d11cca..89ba7b0a8790 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py @@ -54,4 +54,83 @@ 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.8.8\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x97q\x13\x019\x0f\x07\x0b\x0b+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0f\x0b\x17\x0f\x0b\x1b\x0b\x0f\x0b\x17\x13\x039\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x8f\x13\x0b\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02v\x03\x1d\x1f!\x1f\x05\x11\x05\x13\x03\t\x0b\r\x05\x0f\x15\x17\x19\x1b\x05\x15\x11\x03\x00\x03\x03\x11\x13\x05\x17\x05\x19\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x05!\x17\x07r\x10\x1b\x1d%\'\x05#\x17\x07j\x10\x1f\x1d+\x03\x05%\x03\x05\x05[/_\x05\'\x1d35\x05)\x17\x07n\x10\x15\x03\x03\x05e\x03\x01\x1d+\x1d-\x0b\x03\x05\x01\x1d/\x03\x03G\r\x01#\r\x03\x03M\r\x03O;\x1d1\x1d3\x1d5#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03]=\x1d7\x1d9\x1d;\x1d=\r\x07g=ikm=\x1d?\x1dA\x1dC\x1dE\x1dG\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xb9\x05\x01Q\x03\t\x01\x07\x04\xa7\x03\x01\t\x05P\x03\x03\x07\x04]\x03\x0b\x17\x03\x0f)\x00\x03G1-\x05\x03\x07\x03\x01\x03F\x01\x07\x03\x05\x03\x03\x0bG\x017\t\x03\x05\x03\x05\x03F\x01\x0b\x03\x07\x03\x07\x07\x04\x03\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tF#\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00r\x0bI7-3)+7\x13+#\x0f\x0b!Ae\x03Q\x1d\x05;=\x13%)=\x1f9i3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<[\\"a\\"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00jit(func)/jit(main)/ppermute\x00x\x00mhlo.sharding\x00jit(func)/jit(main)/sharding_constraint\x00\x00#sdy.sharding_per_value<[<@mesh, [{\\"a\\"}, {}]>]>\x00xla.sdy.manual_computation_body\x00jax.result_info\x00main\x00public\x00xla.sdy.sharding\x00{devices=[2,1]<=[2]}\x00Sharding\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.in_shardings\x00xla.sdy.manual_axes\x00#sdy\x00xla.sdy.out_shardings\x00xla.sdy.LocalToGlobalShape\x00\x08a\x11\x05;\x01\x0bEIKQS\x11?;a9A999\x11?;c9A999\x03C\x11?;o9A999\x0b9U9C;\x05WY', xla_call_module_version=9, nr_devices=2, +) + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_14 = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['Sharding', 'xla.sdy.GlobalToLocalShape', 'xla.sdy.LocalToGlobalShape'], + serialized_date=datetime.date(2025, 4, 14), + inputs=(array([[0., 1., 2., 3.], + [4., 5., 6., 7.]], dtype=float32),), + expected_outputs=(array([[4., 5., 6., 7.], + [0., 1., 2., 3.]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":1017:8 to :54) +#loc4 = loc("third_party/py/absl/testing/absltest.py":2872:19 to :56) +#loc5 = loc("third_party/py/absl/testing/absltest.py":2908:35 to 2910:3) +#loc6 = loc("third_party/py/absl/testing/absltest.py":2449:6 to :34) +#loc7 = loc("third_party/py/absl/app.py":404:13 to :23) +#loc8 = loc("third_party/py/absl/app.py":484:6 to :27) +#loc9 = loc("third_party/py/absl/testing/absltest.py":2451:4 to :31) +#loc10 = loc("third_party/py/absl/testing/absltest.py":2333:2 to :38) +#loc11 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":1021:2 to :47) +#loc12 = loc("third_party/py/jax/tests/export_back_compat_test.py":1008:13 to :30) +#loc15 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes"(#loc3)) +#loc16 = loc("_run_and_get_tests_result"(#loc4)) +#loc17 = loc("run_tests"(#loc5)) +#loc18 = loc("_run_in_app..main_function"(#loc6)) +#loc19 = loc("_run_main"(#loc7)) +#loc20 = loc("run"(#loc8)) +#loc21 = loc("_run_in_app"(#loc9)) +#loc22 = loc("main"(#loc10)) +#loc23 = loc(""(#loc11)) +#loc24 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func"(#loc12)) +#loc26 = loc(callsite(#loc22 at #loc23)) +#loc28 = loc(callsite(#loc21 at #loc26)) +#loc30 = loc(callsite(#loc20 at #loc28)) +#loc32 = loc(callsite(#loc19 at #loc30)) +#loc34 = loc(callsite(#loc18 at #loc32)) +#loc36 = loc(callsite(#loc17 at #loc34)) +#loc38 = loc(callsite(#loc16 at #loc36)) +#loc40 = loc(callsite(#loc15 at #loc38)) +#loc43 = loc(callsite(#loc24 at #loc40)) +#loc46 = loc("jit(func)/jit(main)/shard_map"(#loc43)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=2]>}"}, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4xf32> loc("x")) -> (tensor<2x4xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}, mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc45) + %1 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>", xla.sdy.manual_axes = "#sdy"}} : (tensor<2x4xf32>) -> tensor<1x4xf32> loc(#loc46) + %2 = call @xla.sdy.manual_computation_body(%1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc46) + %3 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%2) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : (tensor<1x4xf32>) -> tensor<2x4xf32> loc(#loc46) + return %3 : tensor<2x4xf32> loc(#loc) + } loc(#loc) + func.func @xla.sdy.manual_computation_body(%arg0: tensor<1x4xf32> loc("jit(func)/jit(main)/shard_map"(#loc43))) -> tensor<1x4xf32> { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc47) + return %0 : tensor<1x4xf32> loc(#loc46) + } loc(#loc46) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":1007:10 to :73) +#loc13 = loc("third_party/py/jax/tests/export_back_compat_test.py":1006:15 to :46) +#loc14 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func"(#loc2)) +#loc25 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func..shard_map_func"(#loc13)) +#loc27 = loc(callsite(#loc21 at #loc22)) +#loc29 = loc(callsite(#loc20 at #loc27)) +#loc31 = loc(callsite(#loc19 at #loc29)) +#loc33 = loc(callsite(#loc18 at #loc31)) +#loc35 = loc(callsite(#loc17 at 
#loc33)) +#loc37 = loc(callsite(#loc16 at #loc35)) +#loc39 = loc(callsite(#loc15 at #loc37)) +#loc41 = loc(callsite(#loc24 at #loc39)) +#loc42 = loc(callsite(#loc14 at #loc40)) +#loc44 = loc(callsite(#loc25 at #loc41)) +#loc45 = loc("jit(func)/jit(main)/sharding_constraint"(#loc42)) +#loc47 = loc("jit(func)/jit(main)/ppermute"(#loc44)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.5\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x1a\x02\xe7\x13\x01\xa7\x0f\x0b\x0b\x0b\x07\x0f\x0b\x0f\x0f\x0f\x0f\x0f\x0f\x0b\x0f\x0f\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x1f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\'\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x1f\x13\x13\x13\x03A\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x8f\x13\x0b\x0b\x0b\x0b\x1b\x0b#\x1b\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02\xce\x06\x1d9;\x05\x11\x05\x13\x05\x15\x1f\x1d\r=\x05\x17\x15\x11C\x1d?A\x1dEG\x1dKM\x1dQS\x1dWY\x05\x19\x1d]_\x1dce\x1dik\x03\t%\'\x03)/135\x05\x1b\x11\x03\x00\x03\x03+-\x05\x1d\x05\x1f\x05!\x11\x01\t\x05#\x11\x01\x05\x05%\x05\'\x15\x0b\x0f-\x05\x07\xc2\x0f\x1b=\x05)-\x05\x07\xe6\x0f\x11m\x15\x13I\x05+-\x07\x07\xe2,\'q\x15\x15O\x05--\x07\tr-Gz-\x07\x15\x17U\x05/-\x07\x07F&\rE\x15\x19[\x051-\x1b\x07R\x06\x1b/\x15\x1da\x053-\x1b\x07\x92\x07\r7\x15\x1fg\x055-\x07\x07N&\t?\x15!m\x057-\x07\x07v$\x05M\x1doq\x059-\x05\x07\xf6\x0f\x05_\x1duw\x05;\x15y\x7f\x1d{}\x05=-\x05\x07\xba\x0f\x1f]\x15\x0b\x81\x15\x11\x83\x15\x13\x85\x15\x15\x87\x15\x17\x89\x15\x19\x8b\x15\x1d\x8d\x15\x1f!\x1d\x91\t\x05?\x03\x05\x03\xd3\x95\xd7\x05A\x1d\x99\x9b\x05C\x15\x9d\x0f\x1d\r\x9f-\x05\x07\xbe\x0f\x15\x93\x03\x03\x03\xdd\x03\x03\x03\xe1\x03\x03\x03\xe3\x03\x01\x1dE\x1dG\x0b\x03\x1dI\x1dK\x1dM\x1dO\x05\x03\x1dQ\x03\x03\xbd\r\x01#\r\x03\x03\xc3\r\x03\xc5\xc7\x1dS\x1dU\x1d7\x1dW#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03\xd5\xa9\x1dY\x1d[\x1d]\x05\x01\r\x05\xb5\xa9\xaf\xb1\x1d_\r\x07\xb5\xa9\xaf\xb1\xb9\xa9\r\x05\xaf\xb1\xb9\xa9\x1da\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xbd\x05\x01Q\t#\x01\x07\x04\xab\x03\x01\t\x05P\t\x03\x07\x04a\x03\x0b\x17\x03\x0f\x8f\x00\x03G\x97\x93\x05\x03\x07\x03\x01\x03G\x01\xa1\x07\x03\x05\x03\x03\x0bG\x01\xa3\t\x03\x05\x03\x05\x03G\x01\xa5\x0b\x03\x07\x03\x07\x07\x04\t\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tFs\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00.\x12c77\x13+#\x0f\x0f!-+A/)\x03aQ\x1d\x05\xcd;\x13\x0b\x19\t\x15G\x155\x81=\x13%)9\x1f97\x9dQi3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00third_party/py/absl/testing/absltest.py\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func\x00third_party/py/absl/app.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = 
#sdy.mesh<["a"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes\x00_run_and_get_tests_result\x00run_tests\x00_run_in_app..main_function\x00_run_main\x00run\x00_run_in_app\x00main\x00\x00jit(func)/jit(main)/ppermute\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func..shard_map_func\x00x\x00mhlo.sharding\x00jit(func)/jit(main)/sharding_constraint\x00#sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>\x00\x00xla.sdy.manual_axes\x00#sdy\x00xla.sdy.manual_computation_body\x00xla.sdy.in_shardings\x00xla.sdy.out_shardings\x00jax.result_info\x00result\x00public\x00xla.sdy.sharding\x00{devices=[2,1]<=[2]}\x00Sharding\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.LocalToGlobalShape\x00\x08a\x11\x05o\x01\x0b\xbb\xbf\xc1\xc9\xcb\x11\xad\xab\xd9\xa7\xdb\xa7\xa7\xa7\x11\xad\xab\xdf\xa7\xb7\xa7\xa7\xa7\x03\xb3\x11\xad\xab\xe5\xa7\xb7\xa7\xa7\xa7\x0b\xa7\xcd\xa7\xb3\xab\x05\xcf\xd1', + xla_call_module_version=9, + nr_devices=2, ) # End paste diff --git a/tests/BUILD b/tests/BUILD index 50fd9e8fe837..3ac17767ad13 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1589,6 +1589,10 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_back_compat_test", srcs = ["export_back_compat_test.py"], + # TODO(b/415285434): enable once we have backwards compatibility support with GSPMD checkpoints. + # enable_configs = [ + # "tpu_v3_x4_shardy", + # ], tags = [], deps = [ "//jax:internal_export_back_compat_test_data", diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 7082b212a5e3..be87b4e3e5b3 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -976,9 +976,14 @@ def shard_map_func(x): # b: f32[2, 4] x = jax.lax.with_sharding_constraint(x, NS(old_mesh, P('a', None))) return shard_map_func(x) - data = self.load_testdata(shardy_sharding_ops_with_different_meshes.data_2025_02_12) - with Mesh(devices, axis_names=('x')): - self.run_one_test(func, data) + data = [ + shardy_sharding_ops_with_different_meshes.data_2025_02_12, + shardy_sharding_ops_with_different_meshes.data_2025_04_14, + ] + + for d in data: + with Mesh(devices, axis_names=('x')): + self.run_one_test(func, self.load_testdata(d)) if __name__ == "__main__": From 2a90dae6220e35774ea9f17f65f45d50f907aea9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 2 May 2025 07:30:40 -0700 Subject: [PATCH 0973/1769] DOC: hide config options TOC on front page --- docs/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/index.rst b/docs/index.rst index 35906f1a5534..5a43be427041 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -167,6 +167,7 @@ maintains an up-to-date list. glossary .. toctree:: + :hidden: :maxdepth: 2 config_options From bf5d1c001e243fe58aebf999e27d0d198bedc482 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 2 May 2025 07:36:03 -0700 Subject: [PATCH 0974/1769] Mark `out_sharding` as a keyword argument in the `pyi` files too for `jnp.ravel`, `jnp.reshape` and `jnp.dot` PiperOrigin-RevId: 753997769 --- jax/_src/basearray.pyi | 2 +- jax/_src/numpy/array_methods.py | 4 ++-- jax/numpy/__init__.pyi | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 3026b3bd6ab9..cf64afdacfe3 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -184,7 +184,7 @@ class Array(metaclass=abc.ABCMeta): promote_integers: bool = True) -> Array: ... 
def ptp(self, axis: Axis = None, out: None = None,
          keepdims: bool = False) -> Array: ...
-  def ravel(self, order: str = 'C',
+  def ravel(self, order: str = 'C', *,
            out_sharding: NamedSharding | P | None = ...) -> Array: ...
  @property
  def real(self) -> Array: ...
diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py
index 3e65e5d83100..a3dbc0f9f6c6 100644
--- a/jax/_src/numpy/array_methods.py
+++ b/jax/_src/numpy/array_methods.py
@@ -197,12 +197,12 @@ def _dot(self: Array, b: ArrayLike, *, precision: lax_internal.PrecisionLike = N
   """
   return tensor_contractions.dot(self, b, precision=precision,
                                  preferred_element_type=preferred_element_type)

-def _flatten(self: Array, order: str = "C") -> Array:
+def _flatten(self: Array, order: str = "C", *, out_sharding=None) -> Array:
   """Flatten array into a 1-dimensional shape.

   Refer to :func:`jax.numpy.ravel` for the full documentation.
   """
-  return lax_numpy.ravel(self, order=order)
+  return lax_numpy.ravel(self, order=order, out_sharding=out_sharding)

 def _imag_property(self: Array) -> Array:
   """Return the imaginary part of the array."""
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index 1e8e900f1e04..a8de717a0d07 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -799,7 +799,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..
 r_: _RClass
 def rad2deg(x: ArrayLike, /) -> Array: ...
 def radians(x: ArrayLike, /) -> Array: ...
-def ravel(a: ArrayLike, order: str = ...,
+def ravel(a: ArrayLike, order: str = ..., *,
           out_sharding: NamedSharding | P | None = ...) -> Array: ...
 def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
                       mode: str = ..., order: str = ...) -> Array: ...

From 06161ef87e2826d422782e391cb6d0107c918626 Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Fri, 2 May 2025 09:29:03 -0700
Subject: [PATCH 0975/1769] [pallas:mgpu] Add select op support to divisibility
 inference.

PiperOrigin-RevId: 754030929
---
 jax/_src/pallas/mosaic_gpu/core.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 6fae802b7ff4..23660200fd48 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -207,6 +207,10 @@ def _is_known_divisible(value, divisor, fuel=10) -> bool:
     case arith_dialect.MulIOp():
       return (_is_known_divisible(value.owner.operands[0], divisor, fuel // 2) or
               _is_known_divisible(value.owner.operands[1], divisor, (fuel + 1)// 2))
+    case arith_dialect.SelectOp():
+      return (_is_known_divisible(value.owner.operands[1], divisor, fuel // 2) and
+              _is_known_divisible(value.owner.operands[2], divisor, (fuel + 1)// 2))
+

   return False

From 832f3bac556acd19a1ec6f94d1781c854da2fcfc Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 2 May 2025 09:44:33 -0700
Subject: [PATCH 0976/1769] fix breakage from #28318 due to colliding kwargs

PiperOrigin-RevId: 754036048
---
 jax/_src/pjit.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 7c716910eaeb..40ea97344469 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -3258,9 +3258,13 @@ def visit(x):
   args, kwargs = tree_map(visit, (args, kwargs))
   return args, kwargs, box_data

+# TODO(mattjj): because _handle_boxes's caller passes arguments splatted, the
+# names of its first two parameters must not collide with user-supplied kwargs.
+# Using obscure names is a temporary workaround; revise!
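+# (Illustration, not from any real test: if a user calls a jitted function
+# with a keyword argument literally named `f` or `dbg`, e.g. jit(g)(x, f=1.0),
+# that kwarg is forwarded here by name and would collide with the old
+# parameter names, raising a "got multiple values" TypeError.)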
@lu.transformation2
-def _handle_boxes(f, dbg, *args, **kwargs):
+def _handle_boxes(__f, __dbg, *args, **kwargs):
   from jax.experimental.attrs import Box, List
+  f, dbg = __f, __dbg
   new_args = []
   arg_mutables = []
   def visit(x):

From 5efbaa8d39ace35f0cbeb4661c1a77dbda813076 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 2 May 2025 10:56:12 -0700
Subject: [PATCH 0977/1769] Explicitly fail on pallas w/ disable_jit until we
 can decide if/how we want to support it.

PiperOrigin-RevId: 754061572
---
 jax/_src/pallas/pallas_call.py |  6 ++++++
 tests/pallas/pallas_test.py    | 25 +++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 286e49768cca..def8efd472c6 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -77,6 +77,12 @@ def _pallas_call_impl(*args, **params):

   # Call the lowering path
+  if config.disable_jit.value:
+    raise NotImplementedError(
+        "pallas_call not supported with disable_jit. Consider invoking under a"
+        " local context of `jax.disable_jit(False)`."
+    )
+
   @partial(jax.jit, inline=True)
   def _jit_run(*args):
     return pallas_call_p.bind(*args, **params)
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index c16c66d4eb52..725b3adb4388 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -1261,6 +1261,31 @@ def dot_general_kernel(x_ref, y_ref, o_ref):
     ):
       dot_general_kernel(x, y)

+  def test_jax_disable_jit(self):
+    def add_vectors_kernel(x_ref, y_ref, o_ref):
+      x, y = x_ref[...], y_ref[...]
+      o_ref[...] = x + y
+
+    @jax.jit
+    def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
+      return self.pallas_call(
+          add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
+      )(x, y)
+
+    # Prove kernel works fine without disable_jit.
+    add_vectors(jnp.arange(8), jnp.arange(8))
+
+    with self.assertRaisesRegex(
+        NotImplementedError, "pallas_call not supported with disable_jit."
+    ):
+      with jax.disable_jit():
+        add_vectors(jnp.arange(8.0), jnp.arange(8.0))
+
+    with jax.disable_jit():
+      # We instructed the user to do this, so this should not raise an error.
+      with jax.disable_jit(False):
+        add_vectors(jnp.arange(8.0), jnp.arange(8.0))
+

 class ApiErrorInterpretTest(ApiErrorTest):
   INTERPRET = True

From c7f3d1c0edf20f9d29161da601aecf2fbbac0ee2 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 2 May 2025 11:58:54 -0700
Subject: [PATCH 0978/1769] fix breakage from #28318

when using a chex.dataclass(mappable_dataclass=False), we can't tree_map
over those instances, even though they might be passed as args to jitted
functions.

PiperOrigin-RevId: 754084506
---
 jax/_src/pjit.py    | 14 +++++++++++---
 tests/attrs_test.py | 15 ++++++++-------
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 40ea97344469..53cebb951f6a 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -578,8 +578,16 @@ def _infer_params_impl(
   del args
   f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs)
   del kwargs
-  dyn_args, dyn_kwargs, box_data = _flatten_boxes(dbg, dyn_args, dyn_kwargs)
-  f = _handle_boxes(f, dbg)
+
+
+  # TODO(mattjj,dougalm): refine this implementation of box-handling...
+  from jax.experimental.attrs import Box, List
+  if any(isinstance(x, (Box, List)) for x in tree_leaves((dyn_args, dyn_kwargs))):
+    dyn_args, dyn_kwargs, box_data = _flatten_boxes(dbg, dyn_args, dyn_kwargs)
+    f = _handle_boxes(f, dbg)
+  else:
+    box_data = []
+
   explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
   flat_fun, out_tree = flatten_fun(f, in_tree)
   flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args)
@@ -1414,7 +1422,7 @@ def p_one_diff(diff: Sequence[str]):
     for d in smallest_diffs:
       p_one_diff(d)

-  return done()
+  done()


 @partial(lu.cache, explain=explain_tracing_cache_miss)
diff --git a/tests/attrs_test.py b/tests/attrs_test.py
index 0d3a85d0e694..c48b377b076b 100644
--- a/tests/attrs_test.py
+++ b/tests/attrs_test.py
@@ -1097,13 +1097,14 @@ def f(box1, box2):
     with self.assertRaisesRegex(ValueError, "a Box instance can't be passed"):
       f(b, b)

-  def test_error_returning_from_jit(self):
-    @jax.jit
-    def f():
-      return {'a': Box(1.0)}
-
-    with self.assertRaisesRegex(ValueError, "a Box instance can't be returned"):
-      f()
+  # TODO(mattjj): re-enable this test
+  # def test_error_returning_from_jit(self):
+  #   @jax.jit
+  #   def f():
+  #     return {'a': Box(1.0)}

+  #   with self.assertRaisesRegex(ValueError, "a Box instance can't be returned"):
+  #     f()


 class ListTest(jtu.JaxTestCase):

From 653085d4a948d8e47fa8966da9355c447b7e4531 Mon Sep 17 00:00:00 2001
From: Richard Levasseur
Date: Fri, 2 May 2025 19:19:39 +0000
Subject: [PATCH 0979/1769] chore: disable legacy external runfiles

---
 .bazelrc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.bazelrc b/.bazelrc
index b6bb79cc8d0d..9ec02b94f03b 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -33,6 +33,8 @@ build --output_filter=DONT_MATCH_ANYTHING
 build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
 build --copt=-DNB_DOMAIN=jax

+build --legacy_external_runfiles=false
+
 # #############################################################################
 # Platform Specific configs below. These are automatically picked up by Bazel
 # depending on the platform that is running the build.

From 5a3605d3077dfe25c5382df99f13a7b10a6e796f Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Fri, 2 May 2025 13:21:17 -0700
Subject: [PATCH 0980/1769] [pallas:mgpu] Allow WGMMA accumulators to be cond
 yielded values.

PiperOrigin-RevId: 754112101
---
 jax/_src/pallas/mosaic_gpu/lowering.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 1dd5a8f68d3a..6b0c0760c73c 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -2465,7 +2465,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
 def _yielded_values(outs, avals):
   ret = []
   for out, aval in zip(outs, avals):
-    if isinstance(out, mgpu.FragmentedArray):
+    if isinstance(out, (mgpu.WGMMAAccumulator, mgpu.FragmentedArray)):
       ret.append(out)
     else:
       ret.append(_ensure_ir_value(out, aval.dtype))

From 22375ab5ea105db93879b95189ad089b490feb0a Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 2 May 2025 13:24:01 -0700
Subject: [PATCH 0981/1769] Add Bazel CPU tests with `py_import` dependency to
 continuous tests.

This is needed to collect statistics and identify failures that are detected
by pytest but not by Bazel, or vice versa.

These tests don't require pre-built wheels downloaded from GCS. Instead they
build the wheels as transitive dependencies of the test targets and unpack
them using `py_import`.
Execution example - https://github.com/jax-ml/jax/actions/runs/14800773538/job/41558800743 PiperOrigin-RevId: 754113096 --- .github/workflows/bazel_cpu_py_import_rbe.yml | 63 +++++++++++++++++ .github/workflows/wheel_tests_continuous.yml | 22 +++++- ci/run_bazel_test_cpu_py_import_rbe.sh | 69 +++++++++++++++++++ 3 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/bazel_cpu_py_import_rbe.yml create mode 100755 ci/run_bazel_test_cpu_py_import_rbe.sh diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml new file mode 100644 index 000000000000..14d6b95b4347 --- /dev/null +++ b/.github/workflows/bazel_cpu_py_import_rbe.yml @@ -0,0 +1,63 @@ +# CI - Bazel CPU tests with py_import (RBE) +# +# This workflow runs the Bazel CPU tests with py_import dependency. It can only be triggered by +# other workflows via `workflow_call`. It is used by the `CI - Wheel Tests (Continuous)` workflows +# to run the Bazel CPU tests. +# +# It consists of the following job: +# run-tests: +# - Executes the `run_bazel_test_cpu_py_import_rbe.sh` script, which performs the following actions: +# - Runs the Bazel CPU tests with py_import dependency. +name: CI - Bazel CPU tests with py_import (RBE) +permissions: + contents: read + +on: + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + python: + description: "Which python version to test?" + type: string + required: true + default: "3.12" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + required: true + default: "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + required: false + default: 'no' + +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash + shell: bash + runs-on: ${{ inputs.runner }} + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} + JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} + + name: "Bazel CPU tests with py_import (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CPU tests with py_import (RBE) + timeout-minutes: 60 + run: ./ci/run_bazel_test_cpu_py_import_rbe.sh diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index a2ee224c38e1..207075fd0340 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -9,13 +9,17 @@ # that was built in the previous step and runs CPU tests. # 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and # uploads them to a GCS bucket. -# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA +# 4. run-bazel-test-cpu-py-import: Calls the `bazel_cpu_py_import_rbe.yml` workflow which +# runs Bazel CPU tests with py_import on RBE. +# 5. 
run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA # artifacts that were built in the previous steps and runs the CUDA tests. -# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib +# 6. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib # and CUDA artifacts that were built in the previous steps and runs the # CUDA tests using Bazel. name: CI - Wheel Tests (Continuous) +permissions: + contents: read on: schedule: @@ -136,6 +140,20 @@ jobs: # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + run-bazel-test-cpu-py-import: + uses: ./.github/workflows/bazel_cpu_py_import_rbe.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-16", "linux-arm64-t2a-48"] + python: ["3.10",] + enable-x64: [1, 0] + name: "Bazel CPU tests with ${{ format('{0}', 'py_import') }}" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + run-bazel-test-cuda: # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we diff --git a/ci/run_bazel_test_cpu_py_import_rbe.sh b/ci/run_bazel_test_cpu_py_import_rbe.sh new file mode 100755 index 000000000000..9a17397c47ff --- /dev/null +++ b/ci/run_bazel_test_cpu_py_import_rbe.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel CPU tests with py_import on RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel CPU tests with RBE. +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +echo "Running CPU tests..." +# When running on Mac or Linux Aarch64, we build the test targets on RBE +# and run the tests locally. These platforms do not have native RBE support so +# we RBE cross-compile them on remote Linux x86 machines. 
+if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then
+  bazel test --config=rbe_cross_compile_${os}_${arch} \
+        --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
+        --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
+        --test_env=JAX_NUM_GENERATED_CASES=25 \
+        --test_env=JAX_SKIP_SLOW_TESTS=true \
+        --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
+        --test_output=errors \
+        --color=yes \
+        --strategy=TestRunner=local \
+        --//jax:build_jaxlib=wheel \
+        --//jax:build_jax=wheel \
+        //tests:cpu_tests //tests:backend_independent_tests
+else
+  bazel test --config=rbe_${os}_${arch} \
+        --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
+        --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
+        --test_env=JAX_NUM_GENERATED_CASES=25 \
+        --test_env=JAX_SKIP_SLOW_TESTS=true \
+        --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
+        --test_output=errors \
+        --color=yes \
+        --//jax:build_jaxlib=wheel \
+        --//jax:build_jax=wheel \
+        //tests:cpu_tests //tests:backend_independent_tests
+fi
\ No newline at end of file

From b93ffa7041bcecc02fa15f06f16c40eb29574c4d Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 2 May 2025 14:09:43 -0700
Subject: [PATCH 0982/1769] Put the Box, List handling inside _flatten_boxes

PiperOrigin-RevId: 754128416
---
 jax/_src/pjit.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 53cebb951f6a..e087eec3b3d8 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -579,14 +579,9 @@ def _infer_params_impl(
   f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs)
   del kwargs

-
-  # TODO(mattjj,dougalm): refine this implementation of box-handling...
-  from jax.experimental.attrs import Box, List
-  if any(isinstance(x, (Box, List)) for x in tree_leaves((dyn_args, dyn_kwargs))):
-    dyn_args, dyn_kwargs, box_data = _flatten_boxes(dbg, dyn_args, dyn_kwargs)
+  dyn_args, dyn_kwargs, box_data = _flatten_boxes(dbg, dyn_args, dyn_kwargs)
+  if box_data:
     f = _handle_boxes(f, dbg)
-  else:
-    box_data = []

   explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
   flat_fun, out_tree = flatten_fun(f, in_tree)
@@ -3238,6 +3233,9 @@ class ListTree:
 def _flatten_boxes(dbg, args, kwargs):
   from jax.experimental.attrs import Box, List

+  # TODO(mattjj,dougalm): refine this implementation of box-handling...
+  if all(not isinstance(x, (Box, List)) for x in tree_leaves((args, kwargs))):
+    return args, kwargs, []
   box_data = []
   id_first_occurrences = {}
   idxs = itertools.count()

From 9fef4001660a4ec8695479ab89caac1aed7a1baa Mon Sep 17 00:00:00 2001
From: Hyeontaek Lim
Date: Fri, 2 May 2025 14:31:30 -0700
Subject: [PATCH 0983/1769] [IFRT] Introduce `xla::ifrt::ShardingRef` type
 alias

`xla::ifrt::ShardingRef` is a type alias for `absl_nonnull
std::shared_ptr<const Sharding>`, which is a prevailing type for storing a
reference to an IFRT `Sharding` object. Using this alias throughout the API
and user code makes typing more concise and applies non-nullness uniformly.

This change is *mostly* trivial refactoring. The places that used an
(`absl_nullable`) `std::shared_ptr<const Sharding>` will need to spell out the
full type or use `std::optional`. We do not expect any runtime performance or
behavior change (other than any potential performance improvement enabled by
extensive use of `absl_nonnull`).
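For reference, a minimal sketch of the alias and one converted declaration
(the alias definition below is paraphrased from the description above, not
quoted verbatim from the source):

    // ShardingRef: a non-null shared pointer to an immutable Sharding.
    using ShardingRef = absl_nonnull std::shared_ptr<const Sharding>;

    // Converter signatures such as the ones in this change become:
    absl::StatusOr<xla::ifrt::ShardingRef> GetIfrtHloSharding(
        nanobind::handle sharding, const xla::ifrt::Shape& shape);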
PiperOrigin-RevId: 754135729
---
 jaxlib/py_array.cc         |  4 ++--
 jaxlib/py_program.cc       |  6 +++---
 jaxlib/py_values.cc        |  9 ++++-----
 jaxlib/to_ifrt_sharding.cc | 14 +++++++-------
 jaxlib/to_ifrt_sharding.h  | 15 +++++++--------
 5 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc
index 6277e1dfc702..4c5baab58683 100644
--- a/jaxlib/py_array.cc
+++ b/jaxlib/py_array.cc
@@ -213,7 +213,7 @@ tsl::RCReference<ifrt::Array> CreateIfRtArrayFromSingleDeviceShardedPyArrays(
     throw nb::value_error(ifrt_dtype.status().ToString().c_str());
  }

-  absl::StatusOr<std::shared_ptr<const ifrt::Sharding>> ifrt_sharding =
+  absl::StatusOr<ifrt::ShardingRef> ifrt_sharding =
       sharding.type().is(jax::PmapSharding::type())
           ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape),
                                          std::move(shapes))
@@ -1330,7 +1330,7 @@ absl::StatusOr<PyArray> PyArray::ReorderShards(
   }

   TF_ASSIGN_OR_RETURN(
-      std::shared_ptr<const xla::ifrt::Sharding> dst_ifrt_sharding,
+      xla::ifrt::ShardingRef dst_ifrt_sharding,
      GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(),
                                   ifrt_array_ptr->shape()));

diff --git a/jaxlib/py_program.cc b/jaxlib/py_program.cc
index 8c57bd0515b8..40bfd3497ebd 100644
--- a/jaxlib/py_program.cc
+++ b/jaxlib/py_program.cc
@@ -139,10 +139,10 @@ ifrt::MemoryKind GetIfrtMemoryKind(nb::handle sharding) {

 // Makes `ifrt::Sharding` from a JAX Sharding. It requires the number of shape
 // dimensions, which may become necessary when building an HLO sharding.
-absl::StatusOr<std::shared_ptr<const ifrt::Sharding>> GetIfrtSharding(
-    nb::handle sharding, int64_t num_dimensions) {
+absl::StatusOr<ifrt::ShardingRef> GetIfrtSharding(nb::handle sharding,
+                                                  int64_t num_dimensions) {
   auto ifrt_memory_kind = GetIfrtMemoryKind(sharding);
-  std::shared_ptr<const ifrt::Sharding> ifrt_sharding;
+  ifrt::ShardingRef ifrt_sharding;
   if (sharding.type().is(jax::SingleDeviceSharding::type())) {
     TF_ASSIGN_OR_RETURN(auto ifrt_device_list,
                         nb::cast(sharding)
diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc
index b14c1f22708b..b0ac2171eb8d 100644
--- a/jaxlib/py_values.cc
+++ b/jaxlib/py_values.cc
@@ -165,7 +165,7 @@ MakeSingleDeviceIfrtArrayFromShard(
   } else {
     auto host_buffer_shard = std::get<ifrt::Client::HostBuffer>(
        std::move(shard.ifrt_array_or_host_buffer));
-    std::shared_ptr<const ifrt::Sharding> ifrt_sharding =
+    ifrt::ShardingRef ifrt_sharding =
         ifrt::SingleDeviceSharding::Create(ifrt_device, ifrt_memory_kind);
     return ifrt_client->MakeArrayFromHostBuffer(
         host_buffer_shard.data, host_buffer_shard.dtype,
@@ -182,8 +182,7 @@
 // Expected to be called without holding GIL.
 absl::StatusOr<tsl::RCReference<ifrt::Array>> MakeIfrtArrayFromShardsInBatch(
     ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape,
-    std::shared_ptr<const ifrt::Sharding> ifrt_sharding,
-    absl::Span shards,
+    ifrt::ShardingRef ifrt_sharding, absl::Span shards,
     tsl::RCReference<ifrt::UserContext> user_context) {
   absl::InlinedVector<
      std::pair<absl::InlinedVector<int64_t, 1>, ifrt::Client::HostBuffer>, 1>
@@ -224,7 +223,7 @@ absl::StatusOr<tsl::RCReference<ifrt::Array>> MakeIfrtArrayFromShardsInBatch(

 absl::StatusOr<tsl::RCReference<ifrt::Array>> MakeIfrtArrayFromShardsWithAssembly(
     ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape,
-    std::shared_ptr<const ifrt::Sharding> ifrt_sharding,
+    ifrt::ShardingRef ifrt_sharding,
     ifrt::DeviceList* ifrt_addressable_device_list,
     ifrt::MemoryKind ifrt_memory_kind, absl::Span shards,
     tsl::RCReference<ifrt::UserContext> user_context) {
@@ -975,7 +974,7 @@ absl::StatusOr<DevicePutResult> DevicePutWithSharding(
     shard_fns.push_back(std::move(shard));
   }

-  std::shared_ptr<const ifrt::Sharding> ifrt_sharding;
+  ifrt::ShardingRef ifrt_sharding;
   if (is_pmap_sharding) {
     CHECK(!shard_fns.empty());
     // IFRT Sharding will be determined once we discover the shard shape.
diff --git a/jaxlib/to_ifrt_sharding.cc b/jaxlib/to_ifrt_sharding.cc
index f42b13ae4f1c..2bb6e121893f 100644
--- a/jaxlib/to_ifrt_sharding.cc
+++ b/jaxlib/to_ifrt_sharding.cc
@@ -92,7 +92,7 @@ xla::ifrt::MemoryKind GetMemoryKind(nb::handle sharding) {
 }

 // Converts a JAX Sharding into `xla::ifrt::HloSharding`.
-absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>> GetIfrtHloSharding(
+absl::StatusOr<xla::ifrt::ShardingRef> GetIfrtHloSharding(
     nb::handle sharding, const xla::ifrt::Shape& shape) {
   TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list,
                       GetIfrtDeviceList(sharding));
@@ -104,9 +104,9 @@ absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>> GetIfrtHloSharding(
 }

 // Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`.
-absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
-GetIfrtConcreteEvenSharding(nb::handle sharding, xla::ifrt::DType dtype,
-                            const xla::ifrt::Shape& shape) {
+absl::StatusOr<xla::ifrt::ShardingRef> GetIfrtConcreteEvenSharding(
+    nb::handle sharding, xla::ifrt::DType dtype,
+    const xla::ifrt::Shape& shape) {
   TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list,
                       GetIfrtDeviceList(sharding));
   xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr());
@@ -127,9 +127,9 @@ GetIfrtConcreteEvenSharding(nb::handle sharding, xla::ifrt::DType dtype,
 }

 // Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`.
-absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
-GetIfrtConcreteSharding(nb::handle sharding, const xla::ifrt::Shape& shape,
-                        std::vector<xla::ifrt::Shape> shard_shapes) {
+absl::StatusOr<xla::ifrt::ShardingRef> GetIfrtConcreteSharding(
+    nb::handle sharding, const xla::ifrt::Shape& shape,
+    std::vector<xla::ifrt::Shape> shard_shapes) {
   TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list,
                       GetIfrtDeviceList(sharding));
   xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr());
diff --git a/jaxlib/to_ifrt_sharding.h b/jaxlib/to_ifrt_sharding.h
index 6d97f61330a0..911a7caea368 100644
--- a/jaxlib/to_ifrt_sharding.h
+++ b/jaxlib/to_ifrt_sharding.h
@@ -43,19 +43,18 @@ absl::StatusOr<xla::ifrt::DeviceListRef> GetIfrtDeviceList(
 xla::ifrt::MemoryKind GetMemoryKind(nanobind::handle sharding);

 // Converts a JAX Sharding into `xla::ifrt::HloSharding`.
-absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>> GetIfrtHloSharding(
+absl::StatusOr<xla::ifrt::ShardingRef> GetIfrtHloSharding(
     nanobind::handle sharding, const xla::ifrt::Shape& shape);

 // Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`.
-absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
-GetIfrtConcreteEvenSharding(nanobind::handle sharding, xla::ifrt::DType dtype,
-                            const xla::ifrt::Shape& shape);
+absl::StatusOr<xla::ifrt::ShardingRef> GetIfrtConcreteEvenSharding(
+    nanobind::handle sharding, xla::ifrt::DType dtype,
+    const xla::ifrt::Shape& shape);

 // Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`.
-absl::StatusOr<std::shared_ptr<const xla::ifrt::Sharding>>
-GetIfrtConcreteSharding(nanobind::handle sharding,
-                        const xla::ifrt::Shape& shape,
-                        std::vector<xla::ifrt::Shape> shard_shapes);
+absl::StatusOr<xla::ifrt::ShardingRef> GetIfrtConcreteSharding(
+    nanobind::handle sharding, const xla::ifrt::Shape& shape,
+    std::vector<xla::ifrt::Shape> shard_shapes);

 }  // namespace xla

From 43b784c3c7b961134bbdb2fdaa56a168540926cc Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Fri, 2 May 2025 14:39:43 -0700
Subject: [PATCH 0984/1769] [Pallas][Mosaic GPU] Add support for the Blackwell
 MMA instruction.
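A minimal kernel-body sketch of the new primitive, mirroring the test added
below (acc_tmem is a TMEM accumulator, a_smem/b_smem are tiled and swizzled
SMEM refs, and barrier_ref is a Barrier created with for_tensor_core=True):

    plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, barrier_ref, accumulate=False)
    plgpu.barrier_wait(barrier_ref)  # wait for the TensorCore to finish
    scratch_smem[...] = acc_tmem[...].astype(dtype)  # then copy the result out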
PiperOrigin-RevId: 754138441
---
 jax/_src/pallas/mosaic_gpu/core.py       |  15 +-
 jax/_src/pallas/mosaic_gpu/lowering.py   |   8 +-
 jax/_src/pallas/mosaic_gpu/primitives.py | 183 ++++++++++++++++++++++-
 jax/experimental/pallas/mosaic_gpu.py    |   1 +
 tests/pallas/mosaic_gpu_test.py          |  52 +++++++
 5 files changed, 252 insertions(+), 7 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 23660200fd48..9b293ebb51f1 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -765,6 +765,7 @@ class BarrierType(dtypes.ExtendedDType):
   name: ClassVar[str] = "barrier"

   num_arrivals: int
+  for_tensor_core: bool

   def __str__(self):
     return self.name
@@ -783,12 +784,24 @@ def __str__(self):

 @dataclasses.dataclass(frozen=True)
 class Barrier:
+  """Describes a barrier Ref.
+
+  Attributes:
+    num_arrivals: The number of arrivals that will be recorded by this barrier.
+    num_barriers: The number of barriers that will be created. Individual
+      barriers can be accessed by indexing into the barrier Ref.
+    for_tensor_core: Whether this barrier is used for synchronizing with
+      the tensor core. This should be set to True when waiting on Blackwell
+      (TC Gen 5) asynchronous matmul instructions.
+  """
   num_arrivals: int
   num_barriers: int = 1
+  for_tensor_core: bool = dataclasses.field(default=False, kw_only=True)

   def get_ref_aval(self) -> AbstractMemoryRef:
     aval = jax_core.ShapedArray(
-        [self.num_barriers], BarrierType(self.num_arrivals)
+        [self.num_barriers], BarrierType(self.num_arrivals,
+                                         for_tensor_core=self.for_tensor_core)
     )
     return AbstractMemoryRef(aval, SMEM)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 6b0c0760c73c..0c0444bc7026 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -219,10 +219,11 @@ def _run_scoped_resource_estimator(
   for v in jaxpr.invars:
     aval = v.aval
     if isinstance(aval.dtype, gpu_core.BarrierType):
+      multiplier = 1 if aval.dtype.for_tensor_core else ctx.arrival_multiplier
       rs += Resources(
           barrier_counts=collections.Counter([
               mgpu.Barrier(
-                  aval.dtype.num_arrivals * ctx.arrival_multiplier, *aval.shape
+                  aval.dtype.num_arrivals * multiplier, *aval.shape
               )
           ])
       )
@@ -2102,11 +2103,12 @@ def _run_scoped_lowering_rule(
       input_refs.append(acc)
       should_discharge.append(True)
     elif isinstance(aval.dtype, gpu_core.BarrierType):
+      multiplier = (1 if aval.dtype.for_tensor_core else
+                    ctx.estimator_ctx.arrival_multiplier)
       barrier_ref = alloc_stack.enter_context(
           ctx.module_ctx.reserve_barrier(
               mgpu.Barrier(
-                  aval.dtype.num_arrivals
-                  * ctx.estimator_ctx.arrival_multiplier,
+                  aval.dtype.num_arrivals * multiplier,
                   *aval.shape,
               )
           )
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index d0330e60e961..0384e41144cd 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -34,6 +34,7 @@
 from jax._src.lib.mlir.dialects import arith as arith_dialect
 from jax._src.lib.mlir.dialects import llvm as llvm_dialect
 from jax._src.lib.mlir.dialects import memref as memref_dialect
+from jax._src.lib.mlir.dialects import gpu as gpu_dialect
 from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core
@@ -44,6 +45,7 @@
 from jax._src.state import primitives as state_primitives
 from jax.experimental.mosaic import gpu as mgpu
 from jax.experimental.mosaic.gpu import 
utils as mgpu_utils +from jax.experimental.mosaic.gpu import tcgen05 import jax.numpy as jnp @@ -618,6 +620,8 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: def _barrier_arrive_abstract_eval(barrier, *args, **params): del args, params # Unused. _check_ref(barrier, "barrier", gpu_core.SMEM) + if getattr(barrier.inner_aval.dtype, "for_tensor_core", False): + raise ValueError("Cannot arrive on a tensor core barrier.") return (), {gpu_core._memory_effect} @@ -706,12 +710,14 @@ def _barrier_wait_lowering( *flat_transforms, transforms_treedef, ): - del ctx # Unused. + barrier_aval = ctx.avals_in[0] transforms = transforms_treedef.unflatten(flat_transforms) indexer = _extract_barrier_indexer(transforms) + for_tensor_core = getattr( + barrier_aval.inner_aval.dtype, "for_tensor_core", False) if indexer is not None: barrier = barrier.__getitem__(*map(lowering._as_index, indexer.indices)) - barrier.wait() + barrier.wait(for_tensor_core=for_tensor_core) return () @@ -722,7 +728,7 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None: ) flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms) barrier_wait_p.bind( - barrier, *flat_transforms, transforms_treedef=transforms_treedef + barrier, *flat_transforms, transforms_treedef=transforms_treedef, ) @@ -1112,6 +1118,177 @@ def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): ) +# MMA for TensorCore gen 5. +tcgen05_mma_p = jax_core.Primitive("tcgen05_mma") +tcgen05_mma_p.multiple_results = True + +def tcgen05_mma(acc: _Ref, + a: _Ref, + b: _Ref, + barrier: _Ref, + accumulate: bool | jax.Array = True): + """Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell). + + Args: + acc: The accumulator. Must be a TMEM Ref. + a: The left-hand side. Must be a TMEM/SMEM Ref. + b: The right-hand side. Must be an SMEM Ref. + barrier: Barrier Ref for synchronizing with the tensor core. Should have + for_tensor_core set to True. + accumulate: Whether to accumulate into acc or overwrite it. + """ + acc_m, acc_n = acc.shape + lhs_m, lhs_k = a.shape + rhs_k, rhs_n = b.shape + if acc_m != lhs_m: + raise ValueError( + f"Accumulator and LHS have incompatible shapes. Accumulator: {acc.shape}. LHS: {a.shape}.") + if acc_n != rhs_n: + raise ValueError( + f"Accumulator and RHS have incompatible shapes. Accumulator: {acc.shape}. RHS: {b.shape}.") + if lhs_k != rhs_k: + raise ValueError( + f"LHS and RHS have incompatible shapes. LHS: {a.shape}. 
RHS: {b.shape}.") + + if isinstance(a, pallas_core.TransformedRef): + a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms) + a = a.ref + else: + a_transforms_leaves, a_transforms_tree = [], None + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None + + tcgen05_mma_p.bind(acc, a, b, barrier, accumulate, + *a_transforms_leaves, *b_transforms_leaves, + a_transforms_tree=a_transforms_tree, + b_transforms_tree=b_transforms_tree, + collective=False) + +@tcgen05_mma_p.def_abstract_eval +def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, + *transforms_leaves, + a_transforms_tree, b_transforms_tree, + collective): + del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree) + if collective: + raise NotImplementedError("Collective MMA not yet implemented.") + + if acc.memory_space != gpu_core.GPUMemorySpace.TMEM: + raise ValueError("Accumulator must be a TMEM Ref.") + if a.memory_space != gpu_core.GPUMemorySpace.SMEM: + raise ValueError("LHS must be an SMEM Ref. TMEM not yet supported.") + if b.memory_space != gpu_core.GPUMemorySpace.SMEM: + raise ValueError("RHS must be an SMEM Ref.") + + for_tensor_core = getattr( + barrier.inner_aval.dtype, "for_tensor_core", False) + if not for_tensor_core: + raise ValueError("MMA barrier must have for_tensor_core set to True.") + + return [] + +@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS) +@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS) +def _tcgen05_mma_lowering( + ctx: lowering.LoweringRuleContext, + acc: tcgen05.TMEMRef, + a_ref, + b_ref, + barrier_ref: mgpu.BarrierRef, + accumulate: bool | ir.Value, + *transforms_leaves, + a_transforms_tree, + b_transforms_tree, + collective: bool, +): + _, a_aval, b_aval, *_ = ctx.avals_in + lhs_swizzle: int = 128 + lhs_transpose: bool = False + if a_transforms_tree is not None: + a_transforms_leaves, b_transforms_leaves = util.split_list( + transforms_leaves, [a_transforms_tree.num_leaves] + ) + + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a_ref, a_transforms = lowering._handle_transforms( + a_ref, a_transforms, handle_transposes=False, handle_reshapes=True + ) + match a_transforms: + case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(lhs_tiling)): + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(lhs_tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True + case _: + raise NotImplementedError( + f"Unsupported transforms: {a_transforms}." + ) + swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize + if lhs_tiling != (8, swizzle_elems): + raise ValueError("MMA lhs tiling does not fit swizzle. " + f"{lhs_tiling=} expected={(8, swizzle_elems)}") + else: + b_transforms_leaves = transforms_leaves # type: ignore + + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b_ref, b_transforms = lowering._handle_transforms( + b_ref, b_transforms, handle_transposes=False, handle_reshapes=True + ) + match b_transforms: + case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): + rhs_transpose = False + case ( + gpu_core.UnswizzleRef(rhs_swizzle), + gpu_core.UntileRef(rhs_tiling), + gpu_core.TransposeRef((1, 0)), + ): + rhs_transpose = True + case _: + raise NotImplementedError( + f"Unsupported transforms: {b_transforms}." 
+ ) + + swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize + if rhs_swizzle != lhs_swizzle: + raise ValueError("MMA rhs swizzle must match lhs swizzle." + f" {lhs_swizzle=} {rhs_swizzle=}") + if rhs_tiling != (8, swizzle_elems): + raise ValueError("MMA rhs tiling does not fit swizzle" + f" {rhs_tiling=} expected={(8, swizzle_elems)}") + if lhs_transpose or rhs_transpose: + raise NotImplementedError("Lowering does not yet support transpose") + if isinstance(accumulate, bool): + accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1)) + + predicate = ctx.module_ctx.single_lane_predicate + if collective: + index = ir.IndexType.get() + is_leader_block = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(gpu_dialect.Dimension.x), mgpu.c(0, index)) + predicate = arith_dialect.andi(predicate, is_leader_block) + with mgpu.when(predicate): + tcgen05.mma( + acc, + a_ref, + b_ref, + a_swizzle=rhs_swizzle, + b_swizzle=lhs_swizzle, + accumulate=accumulate, + collective=collective, + ) + tcgen05.commit_arrive(barrier_ref, + collective=collective, + ctx=ctx.launch_ctx) + return [] + class Layout(enum.Enum): #: [m, n] matrix, where m % 64 == 0 == n % 8. WGMMA = enum.auto() diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 63d0019fb99f..ecbe28bbcc20 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -55,6 +55,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait +from jax._src.pallas.mosaic_gpu.primitives import tcgen05_mma as tcgen05_mma from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index eac5e735294a..b2325fb62dee 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1705,6 +1705,7 @@ def test_missing_primitive_lowerings_are_tracked(self): mgpu_primitives.inline_mgpu_p, mgpu_primitives.broadcasted_iota_p, mgpu_primitives.load_p, + mgpu_primitives.tcgen05_mma_p, lax.slice_p, pallas_core.core_map_p, } @@ -2110,6 +2111,57 @@ def kernel(y_ref, tmem_ref, smem_ref): # Test that this runs without errors. jax.block_until_ready(kernel()) + @parameterized.parameters( + ((128, 128), 128, jnp.float16), + # Test bfloat16 + ((128, 128), 128, jnp.bfloat16), + # Test additional swizzles. + ((128, 128), 64, jnp.float16), + ((128, 128), 32, jnp.float16), + ) + def test_simple_matmul(self, shape, swizzle, dtype): + # Test a matmul with a single block. + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): + plgpu.tcgen05_mma(acc_tmem, + a_smem, + b_smem, + barrier_ref, + accumulate=False) + plgpu.barrier_wait(barrier_ref) + scratch_smem[...] 
= acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_ref) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.TMEM(shape, jnp.float32), + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.Barrier(num_arrivals=1, for_tensor_core=True), + ] + f = self.pallas_call( + kernel, + in_specs=( + plgpu.GPUBlockSpec(transforms=transforms, + memory_space=plgpu.SMEM), + plgpu.GPUBlockSpec(transforms=transforms, + memory_space=plgpu.SMEM), + ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + class PallasCallSm100AWGTest( PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From c15f26a1748bae2f5f94cc632416f097196aac4b Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 3 May 2025 05:56:45 -0700 Subject: [PATCH 0985/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3b7465acc3d89452742d3ac8dabe53e9caabe260. PiperOrigin-RevId: 754335061 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4b4a60c182b9..b2324714160d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "8f46fb3d56b8d9b049d3158cd12838c24bc95307" -XLA_SHA256 = "786eefd5383f1264dcb1de71bdc1091b9fdb38a52ec859514e20015458206a78" +XLA_COMMIT = "3b7465acc3d89452742d3ac8dabe53e9caabe260" +XLA_SHA256 = "4026e5e64a0cae402d0d17aefa171d4e050d62703ff646252c9aba805c6c06b4" def repo(): tf_http_archive( From a96e175c76e148083830ad4f0cac41386cf25f3f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 3 May 2025 18:37:46 -0700 Subject: [PATCH 0986/1769] Fix with_sharding_constraint when a prng key is passed to it. The problem was that the input aval's rank and the mlir input node's rank did not match. Converting the input aval to physical aval fixes that. 
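A minimal repro, distilled from the new test below, that previously crashed
with the rank mismatch:

    mesh = jtu.create_mesh((2,), 'x')

    @jax.jit
    def f(x):
      return lax.with_sharding_constraint(x, NamedSharding(mesh, P()))

    f(jax.random.key(0))  # the key's physical aval carries extra trailing dims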
PiperOrigin-RevId: 754467370 --- jax/_src/pjit.py | 12 +++++++----- tests/pjit_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index e087eec3b3d8..79948cc1da06 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2822,20 +2822,22 @@ def _sharding_constraint_impl(x, sharding, layout, context_mesh, def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, context_mesh, unconstrained_dims): - aval, = ctx.avals_in + in_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_ctx = ctx.module_context.axis_context + if dtypes.issubdtype(in_aval.dtype, dtypes.extended): + in_aval = core.physical_aval(in_aval) if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and axis_ctx.manual_axes): - sharding = mlir.add_manual_axes(axis_ctx, sharding, aval.ndim) + sharding = mlir.add_manual_axes(axis_ctx, sharding, in_aval.ndim) if config.use_shardy_partitioner.value: - sharding = sharding._to_sdy_sharding(aval.ndim) + sharding = sharding._to_sdy_sharding(in_aval.ndim) else: - sharding = sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + sharding = sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() out = mlir.wrap_with_sharding_op( ctx, x_node, out_aval, sharding, unspecified_dims=unconstrained_dims) if layout is not None: - out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval) + out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, in_aval) return [out] mlir.register_lowering(sharding_constraint_p, _sharding_constraint_hlo_lowering) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0bd570f37004..68bdeafd9c2b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4934,6 +4934,38 @@ def g(x): else: self.assertIn("unspecified_dims=[0]", lowered_text) + def test_prng_key_wsc(self): + mesh = jtu.create_mesh((2,), 'x') + + @jax.jit + def f(x): + y = lax.with_sharding_constraint(x, NamedSharding(mesh, P())) + return y.T + f(jax.random.key(0)) # doesn't crash + + @jax.jit + def g(x): + return lax.with_sharding_constraint(x, NamedSharding(mesh, P())) + g(jax.random.key(1)) # doesn't crash + + def test_prng_key_wsc_multi_axes_sharding(self): + input_shape = (8, 4) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + spec = P('x', 'y') + + seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) + + @jax.jit + def make_keys(seeds): + make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl) + return lax.with_sharding_constraint( + make_key(seeds), NamedSharding(mesh, P('x', 'y'))) + + out = make_keys(seeds) + self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key)) + self.assertEqual(out.shape, input_shape) + jax.random.key_data(out) # doesn't crash + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From b5db71bd77ac5e8527a6e0b6c566cd00dabe0abd Mon Sep 17 00:00:00 2001 From: Kirill Bobyrev <3352968+kirillbobyrev@users.noreply.github.com> Date: Sat, 3 May 2025 23:16:16 -0700 Subject: [PATCH 0987/1769] Fix `dlpack` docs formatting --- jax/_src/dlpack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index a0b1db608ad0..40a69d1e0390 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -240,7 +240,7 @@ def from_dlpack(external_array, device transfer or copy was requested. Args: - external_array: An array object that has ``__dlpack__` and + external_array: An array object that has ``__dlpack__`` and ``__dlpack_device__`` methods. 
device: The (optional) :py:class:`Device`, representing the device on which the returned array should be placed. If given, then the result is From 55fb40f51e382553efd183e2f428bf320ad53540 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 4 May 2025 06:14:52 -0700 Subject: [PATCH 0988/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8a838a92071ebd506036576824be50a120a5905e. PiperOrigin-RevId: 754608791 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b2324714160d..405c89442f8a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "3b7465acc3d89452742d3ac8dabe53e9caabe260" -XLA_SHA256 = "4026e5e64a0cae402d0d17aefa171d4e050d62703ff646252c9aba805c6c06b4" +XLA_COMMIT = "8a838a92071ebd506036576824be50a120a5905e" +XLA_SHA256 = "8d2793920dfa93d7c8b9c15f11eaeaef3029872ea8f31aac347fb1d09894d2a6" def repo(): tf_http_archive( From 6fa051431196dcdde2c5c59ea26e1acfc064df33 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 4 May 2025 10:08:42 -0400 Subject: [PATCH 0989/1769] Improve TSAN CI build. * don't use requirements lock patch which often breaks, instead copy wheels to the dist/ which bazel will pick up. * update tsan suppressions, remove some suppressions that haven't been needed in some time, and disable a suppression that should be fixed already. --- .../workflows/requirements_lock_3_13_ft.patch | 87 ------------------- .github/workflows/tsan-suppressions_3.13.txt | 31 ------- .github/workflows/tsan-suppressions_3.14.txt | 10 +-- .github/workflows/tsan.yaml | 44 ++++------ WORKSPACE | 2 + build/requirements_lock_3_13_ft.txt | 2 +- 6 files changed, 26 insertions(+), 150 deletions(-) delete mode 100644 .github/workflows/requirements_lock_3_13_ft.patch diff --git a/.github/workflows/requirements_lock_3_13_ft.patch b/.github/workflows/requirements_lock_3_13_ft.patch deleted file mode 100644 index 7e45fe2b3e26..000000000000 --- a/.github/workflows/requirements_lock_3_13_ft.patch +++ /dev/null @@ -1,87 +0,0 @@ -diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt -index 7fce0eef6..06e2cc5d4 100644 ---- a/build/requirements_lock_3_13_ft.txt -+++ b/build/requirements_lock_3_13_ft.txt -@@ -4,6 +4,12 @@ - # - # bazel run //build:requirements_ft.update - # -+ -+--pre -+--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple -+numpy -+ -+ - absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff -@@ -328,68 +334,7 @@ mpmath==1.3.0 \ - --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ - --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c - # via -r build/test-requirements.txt --numpy==2.2.5 \ -- --hash=sha256:0255732338c4fdd00996c0421884ea8a3651eea555c3a56b84892b66f696eb70 \ -- --hash=sha256:02f226baeefa68f7d579e213d0f3493496397d8f1cff5e2b222af274c86a552a \ -- --hash=sha256:059b51b658f4414fff78c6d7b1b4e18283ab5fa56d270ff212d5ba0c561846f4 \ -- --hash=sha256:0bcb1d057b7571334139129b7f941588f69ce7c4ed15a9d6162b2ea54ded700c \ -- 
--hash=sha256:0cd48122a6b7eab8f06404805b1bd5856200e3ed6f8a1b9a194f9d9054631beb \ -- --hash=sha256:19f4718c9012e3baea91a7dba661dcab2451cda2550678dc30d53acb91a7290f \ -- --hash=sha256:1a161c2c79ab30fe4501d5a2bbfe8b162490757cf90b7f05be8b80bc02f7bb8e \ -- --hash=sha256:1f4a922da1729f4c40932b2af4fe84909c7a6e167e6e99f71838ce3a29f3fe26 \ -- --hash=sha256:261a1ef047751bb02f29dfe337230b5882b54521ca121fc7f62668133cb119c9 \ -- --hash=sha256:262d23f383170f99cd9191a7c85b9a50970fe9069b2f8ab5d786eca8a675d60b \ -- --hash=sha256:2ba321813a00e508d5421104464510cc962a6f791aa2fca1c97b1e65027da80d \ -- --hash=sha256:2c1a1c6ccce4022383583a6ded7bbcda22fc635eb4eb1e0a053336425ed36dfa \ -- --hash=sha256:352d330048c055ea6db701130abc48a21bec690a8d38f8284e00fab256dc1376 \ -- --hash=sha256:369e0d4647c17c9363244f3468f2227d557a74b6781cb62ce57cf3ef5cc7c610 \ -- --hash=sha256:36ab5b23915887543441efd0417e6a3baa08634308894316f446027611b53bf1 \ -- --hash=sha256:37e32e985f03c06206582a7323ef926b4e78bdaa6915095ef08070471865b906 \ -- --hash=sha256:3a801fef99668f309b88640e28d261991bfad9617c27beda4a3aec4f217ea073 \ -- --hash=sha256:3d14b17b9be5f9c9301f43d2e2a4886a33b53f4e6fdf9ca2f4cc60aeeee76372 \ -- --hash=sha256:422cc684f17bc963da5f59a31530b3936f57c95a29743056ef7a7903a5dbdf88 \ -- --hash=sha256:4520caa3807c1ceb005d125a75e715567806fed67e315cea619d5ec6e75a4191 \ -- --hash=sha256:47834cde750d3c9f4e52c6ca28a7361859fcaf52695c7dc3cc1a720b8922683e \ -- --hash=sha256:47f9ed103af0bc63182609044b0490747e03bd20a67e391192dde119bf43d52f \ -- --hash=sha256:498815b96f67dc347e03b719ef49c772589fb74b8ee9ea2c37feae915ad6ebda \ -- --hash=sha256:54088a5a147ab71a8e7fdfd8c3601972751ded0739c6b696ad9cb0343e21ab73 \ -- --hash=sha256:55f09e00d4dccd76b179c0f18a44f041e5332fd0e022886ba1c0bbf3ea4a18d0 \ -- --hash=sha256:5a0ac90e46fdb5649ab6369d1ab6104bfe5854ab19b645bf5cda0127a13034ae \ -- --hash=sha256:6411f744f7f20081b1b4e7112e0f4c9c5b08f94b9f086e6f0adf3645f85d3a4d \ -- --hash=sha256:6413d48a9be53e183eb06495d8e3b006ef8f87c324af68241bbe7a39e8ff54c3 \ -- --hash=sha256:7451f92eddf8503c9b8aa4fe6aa7e87fd51a29c2cfc5f7dbd72efde6c65acf57 \ -- --hash=sha256:8b4c0773b6ada798f51f0f8e30c054d32304ccc6e9c5d93d46cb26f3d385ab19 \ -- --hash=sha256:8dfa94b6a4374e7851bbb6f35e6ded2120b752b063e6acdd3157e4d2bb922eba \ -- --hash=sha256:97c8425d4e26437e65e1d189d22dff4a079b747ff9c2788057bfb8114ce1e133 \ -- --hash=sha256:9d75f338f5f79ee23548b03d801d28a505198297534f62416391857ea0479571 \ -- --hash=sha256:9de6832228f617c9ef45d948ec1cd8949c482238d68b2477e6f642c33a7b0a54 \ -- --hash=sha256:a4cbdef3ddf777423060c6f81b5694bad2dc9675f110c4b2a60dc0181543fac7 \ -- --hash=sha256:a9c0d994680cd991b1cb772e8b297340085466a6fe964bc9d4e80f5e2f43c291 \ -- --hash=sha256:aa70fdbdc3b169d69e8c59e65c07a1c9351ceb438e627f0fdcd471015cd956be \ -- --hash=sha256:abe38cd8381245a7f49967a6010e77dbf3680bd3627c0fe4362dd693b404c7f8 \ -- --hash=sha256:b13f04968b46ad705f7c8a80122a42ae8f620536ea38cf4bdd374302926424dd \ -- --hash=sha256:b4ea7e1cff6784e58fe281ce7e7f05036b3e1c89c6f922a6bfbc0a7e8768adbe \ -- --hash=sha256:b6f91524d31b34f4a5fee24f5bc16dcd1491b668798b6d85585d836c1e633a6a \ -- --hash=sha256:c26843fd58f65da9491165072da2cccc372530681de481ef670dcc8e27cfb066 \ -- --hash=sha256:c42365005c7a6c42436a54d28c43fe0e01ca11eb2ac3cefe796c25a5f98e5e9b \ -- --hash=sha256:c8b82a55ef86a2d8e81b63da85e55f5537d2157165be1cb2ce7cfa57b6aef38b \ -- --hash=sha256:ced69262a8278547e63409b2653b372bf4baff0870c57efa76c5703fd6543282 \ -- --hash=sha256:d2e3bdadaba0e040d1e7ab39db73e0afe2c74ae277f5614dad53eadbecbbb169 \ -- 
--hash=sha256:d403c84991b5ad291d3809bace5e85f4bbf44a04bdc9a88ed2bb1807b3360bb8 \ -- --hash=sha256:d7543263084a85fbc09c704b515395398d31d6395518446237eac219eab9e55e \ -- --hash=sha256:d8882a829fd779f0f43998e931c466802a77ca1ee0fe25a3abe50278616b1471 \ -- --hash=sha256:e4f0b035d9d0ed519c813ee23e0a733db81ec37d2e9503afbb6e54ccfdee0fa7 \ -- --hash=sha256:e8b025c351b9f0e8b5436cf28a07fa4ac0204d67b38f01433ac7f9b870fa38c6 \ -- --hash=sha256:eb7fd5b184e5d277afa9ec0ad5e4eb562ecff541e7f60e69ee69c8d59e9aeaba \ -- --hash=sha256:ec31367fd6a255dc8de4772bd1658c3e926d8e860a0b6e922b615e532d320ddc \ -- --hash=sha256:ee461a4eaab4f165b68780a6a1af95fb23a29932be7569b9fab666c407969051 \ -- --hash=sha256:f5045039100ed58fa817a6227a356240ea1b9a1bc141018864c306c1a16d4175 -- # via -- # -r build/freethreading-requirements.txt -- # contourpy -- # matplotlib -- # ml-dtypes -- # scipy -+ - nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt index fceb8f5f9a61..e82699036e92 100644 --- a/.github/workflows/tsan-suppressions_3.13.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -37,41 +37,10 @@ race:scal_k_ race:gemm_beta race:gemm_oncopy - - -# Races below this point are likely fixed. -# TODO(phawkins): remove these if they don't show up in CI again. - -# https://github.com/python/cpython/issues/128100 -# race:ensure_nonmanaged_dict - -# https://github.com/python/cpython/issues/128657 -# race:py_digest_by_name - -# https://github.com/python/cpython/issues/128714 -# race:func_get_annotations - -# https://github.com/python/cpython/issues/129533 -# race:PyGC_Disable -# race:PyGC_Enable - -# https://github.com/python/cpython/issues/128133 -# race:bytes_hash - -# https://github.com/python/cpython/issues/130571 -# race:_PyObject_GetMethod - -# https://github.com/python/cpython/issues/130547 -# race:split_keys_entry_added - # https://github.com/python/cpython/issues/132245 race:split_keys_entry_added race_top:dict_dict_merge -# https://github.com/python/cpython/issues/129547 -# Maybe fixed? -# race:type_get_annotations - # https://github.com/python/cpython/issues/132013 # Fixed on 3.14 and not backported to 3.13 race_top:frozenset_hash \ No newline at end of file diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt index 3d0f5518862a..ec4d81c987d0 100644 --- a/.github/workflows/tsan-suppressions_3.14.txt +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -11,12 +11,6 @@ race:partial_vectorcall_fallback # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx -# https://github.com/python/cpython/issues/128130 -race_top:run_eval_code_obj - -# https://github.com/python/cpython/issues/132214 -race_top:update_one_slot - # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. 
race:heevd_ffi race:gesdd_ffi @@ -24,3 +18,7 @@ race:dscal_k_ race:scal_k_ race:gemm_beta race:gemm_oncopy + +# https://github.com/python/cpython/issues/132214 +# Should be fixed +# race_top:update_one_slot diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 8336d86120d1..596bc425bfeb 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -44,18 +44,13 @@ jobs: DEBIAN_FRONTEND: noninteractive run: | apt update - apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ + apt install -q -y clang-18 libstdc++-14-dev build-essential libssl-dev \ zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev file zip - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: python/cpython - path: cpython - ref: ${{ matrix.github_branch }} - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: numpy/numpy @@ -81,6 +76,13 @@ jobs: ./python-tsan.tgz key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' + with: + repository: python/cpython + path: cpython + ref: ${{ matrix.github_branch }} + - name: Build TSAN CPython ${{ matrix.python-version }} if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' run: | @@ -180,9 +182,11 @@ jobs: - name: Build Scipy wheel if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + env: + DEBIAN_FRONTEND: noninteractive run: | # Install scipy dependencies: - apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + apt install -q -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends cd scipy @@ -254,7 +258,9 @@ jobs: JAX_ENABLE_X64: true JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 + DEBIAN_FRONTEND: noninteractive run: | + set -x cd jax export PYTHON_SHA256=($(sha256sum ${GITHUB_WORKSPACE}/python-tsan.tgz)) @@ -272,31 +278,19 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - if [ "${{ matrix.python-version }}" == "3.13" ]; then - # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/${{ matrix.requirements_lock_name }}.patch - cat .github/workflows/${{ matrix.requirements_lock_name }}.patch - git apply .github/workflows/${{ matrix.requirements_lock_name }}.patch || exit 1 - - # Display the content for debugging in logs - cat build/${{ matrix.requirements_lock_name }}.txt | head -15 - # Check the patch - cat build/${{ matrix.requirements_lock_name }}.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" - if [ "$?" == "1" ]; then echo "Could not find the patch in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi - cat build/${{ matrix.requirements_lock_name }}.txt | grep -E "(numpy==)" - if [ "$?" 
== "0" ]; then "Found original numpy dependency in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi - - else + mkdir -p dist + cp -v ${GITHUB_WORKSPACE}/wheelhouse/numpy/*.whl dist/ + cp -v ${GITHUB_WORKSPACE}/wheelhouse/scipy/*.whl dist/ + if [ "${{ matrix.python-version }}" == "3.14" ]; then # Patch build/requirements_lock_3_14_ft.txt to use TSAN instrumented NumPy and Scipy - sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt # We should install jpeg dev package to be able to build Pillow from source: - apt-get install -y libjpeg-dev --no-install-recommends + apt install -q -y libjpeg-dev --no-install-recommends # Install scipy runtime dependencies (in case we restore scipy wheel from cache): - apt-get install -y libopenblas-dev liblapack-dev --no-install-recommends + apt install -q -y libopenblas-dev liblapack-dev --no-install-recommends fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" diff --git a/WORKSPACE b/WORKSPACE index f9c0b3ccfea7..903085714e65 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,8 @@ python_init_repositories( default_python_version = "system", local_wheel_dist_folder = "../dist", local_wheel_inclusion_list = [ + "numpy*", + "scipy*", "jax-*", "jaxlib*", "jax_cuda*", diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 5dd300f224e5..13bb7126d62a 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -690,4 +690,4 @@ zipp==3.21.0 \ setuptools==70.3.0 \ --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc - # via -r build/requirements.in + # via -r build/requirements.in \ No newline at end of file From f5d621f42424e2aed971e3d8bde6bcbf96682744 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Sun, 4 May 2025 11:08:05 -0700 Subject: [PATCH 0990/1769] Add block_until_ready to buffer_callback tests where needed. PiperOrigin-RevId: 754658087 --- tests/buffer_callback_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py index 06f0bcf7c3bc..acc414bf7fd4 100644 --- a/tests/buffer_callback_test.py +++ b/tests/buffer_callback_test.py @@ -95,7 +95,7 @@ def callback(ctx, out, arg): # We can't actually test the output because numpy doesn't support writable # DLPack tensors. 
- fun(data) + jax.block_until_ready(fun(data)) @parameterized.parameters(jtu.dtypes.all) @jtu.run_on_devices("cuda") @@ -123,7 +123,7 @@ def callback(ctx, out, arg): fun = buffer_callback.buffer_callback( callback, jax.ShapeDtypeStruct(data.shape, data.dtype) ) - fun(data) + jax.block_until_ready(fun(data)) @parameterized.parameters([ "sequential", "sequential_unrolled", "expand_dims", "broadcast_all" @@ -164,7 +164,7 @@ def callback(ctx, out, arg): callback, jax.ShapeDtypeStruct(data.shape, data.dtype), input_output_aliases={0: 0}, ) - fun(data) + jax.block_until_ready(fun(data)) def test_side_effect(self): def callback(*_): @@ -172,8 +172,9 @@ def callback(*_): called = True called = False - fun = buffer_callback.buffer_callback(callback, (), has_side_effect=True) - fun() + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct((), jnp.float32), has_side_effect=True) + jax.block_until_ready(fun()) self.assertTrue(called) From d58a73463f286595957c5f84f4697d9479d5e95a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 4 May 2025 13:08:40 -0700 Subject: [PATCH 0991/1769] Block yt_kernel in test_upcast_to_wgmma before the assert PiperOrigin-RevId: 754679843 --- tests/mosaic/gpu_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index c8e204433a37..990398708a94 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2586,6 +2586,7 @@ def tile(x, tiling): ) with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: yt_kernel = f(xt) + jax.block_until_ready(yt_kernel) np.testing.assert_array_equal(yt_kernel, yt) self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg) From e1ca6ababaa4fc58226a655b6f46b3a06a96dc90 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 4 May 2025 18:16:32 -0700 Subject: [PATCH 0992/1769] [XLA:Python] [JAX] Move HloPass bindings to XLA/Python. JAX does not use these. PiperOrigin-RevId: 754733948 --- jaxlib/BUILD | 6 ----- jaxlib/_jax/__init__.pyi | 19 -------------- jaxlib/xla_compiler.cc | 53 ---------------------------------------- 3 files changed, 78 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 8dabee20038c..62621302b8f3 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1210,12 +1210,7 @@ cc_library( "@xla//xla/ffi/api:c_api", "@xla//xla/hlo/builder:xla_computation", "@xla//xla/hlo/ir:hlo", - "@xla//xla/hlo/ir:hlo_module_group", "@xla//xla/hlo/parser:hlo_parser", - "@xla//xla/hlo/pass:hlo_pass", - "@xla//xla/hlo/transforms/simplifiers:flatten_call_graph", - "@xla//xla/hlo/transforms/simplifiers:hlo_dce", - "@xla//xla/hlo/transforms/simplifiers:tuple_simplifier", "@xla//xla/pjrt:compile_options_proto_cc", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:pjrt_executable", @@ -1224,7 +1219,6 @@ cc_library( "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", "@xla//xla/python:types", - "@xla//xla/service:call_inliner", "@xla//xla/service:computation_placer", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:hlo_graph_dumper", diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 895d56e852e1..c0c0dc77ccb7 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -944,25 +944,6 @@ def pjit( cache: PjitFunctionCache | None = ..., ) -> PjitFunction: ... -class HloPassInterface: - @property - def name(self) -> str: ... - def is_pass_pipeline(self) -> bool: ... - def run(self, module: HloModule) -> bool: ... 
-  def run_on_module_group(self, module_group: HloModuleGroup) -> bool: ...
-
-class HloDCE(HloPassInterface):
-  def __init__(self) -> None: ...
-
-class CallInliner(HloPassInterface):
-  def __init__(self) -> None: ...
-
-class FlattenCallGraph(HloPassInterface):
-  def __init__(self) -> None: ...
-
-class TupleSimplifer(HloPassInterface):
-  def __init__(self) -> None: ...
-
 class WeakrefLRUCacheInfo:
   @property
   def hits(self) -> int: ...
diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc
index 9b1377a1a9d1..8803da924381 100644
--- a/jaxlib/xla_compiler.cc
+++ b/jaxlib/xla_compiler.cc
@@ -52,14 +52,9 @@ limitations under the License.
 #include "xla/hlo/builder/xla_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_module_group.h"
 #include "xla/hlo/ir/hlo_print_options.h"
 #include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/hlo/parser/hlo_parser.h"
-#include "xla/hlo/pass/hlo_pass_interface.h"
-#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h"
-#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
-#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
 #include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
@@ -70,7 +65,6 @@ limitations under the License.
 #include "xla/python/nb_absl_span.h"  // IWYU pragma: keep
 #include "xla/python/nb_numpy.h"
 #include "xla/python/types.h"
-#include "xla/service/call_inliner.h"
 #include "xla/service/computation_placer.h"
 #include "xla/service/custom_call_target_registry.h"
 #include "xla/service/hlo.pb.h"
@@ -820,32 +814,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
         return param_shardings;
       });
 
-  nb::class_<HloModuleGroup> hlo_module_group_class(m, "HloModuleGroup");
-  hlo_module_group_class
-      .def("__init__",
-           [](HloModuleGroup* self, const std::string& name,
-              const std::vector<std::shared_ptr<HloModule>>& hlo_modules) {
-             std::vector<std::unique_ptr<HloModule>> modules;
-             modules.reserve(hlo_modules.size());
-             for (const auto& m : hlo_modules) {
-               modules.push_back(m->Clone(/*suffix=*/""));
-             }
-             new (self) HloModuleGroup(name, std::move(modules));
-           })
-      .def_prop_ro("name", &HloModuleGroup::name)
-      .def("to_string", &HloModuleGroup::ToString)
-      .def("to_modules",
-           [](HloModuleGroup& m) -> std::vector<std::shared_ptr<HloModule>> {
-             std::vector<std::unique_ptr<HloModule>> modules =
-                 m.ConsumeModules();
-             std::vector<std::shared_ptr<HloModule>> shared_modules;
-             shared_modules.reserve(modules.size());
-             for (auto& module : modules) {
-               shared_modules.push_back(std::move(module));
-             }
-             return shared_modules;
-           });
-
   m.def("hlo_module_to_dot_graph",
         [](const HloModule& hlo_module) -> std::string {
           return xla::ValueOrThrow(RenderGraph(
@@ -1479,26 +1447,5 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
       .def("__repr__",
           [](const xla::HloSharding& self) { return self.ToString(); })
       .def("to_proto", &xla::HloSharding::ToProto);
-
-
-  // Hlo Module Passes
-  nb::class_<HloPassInterface> hlo_pass_interface(m, "HloPassInterface");
-  hlo_pass_interface.def_prop_ro("name", &HloPassInterface::name)
-      .def("is_pass_pipeline", &HloPassInterface::IsPassPipeline)
-      .def("run",
-           [](HloPassInterface& pass, HloModule* module) -> bool {
-             return xla::ValueOrThrow(pass.Run(module));
-           })
-      .def("run_on_module_group",
-           [](HloPassInterface& pass, HloModuleGroup* module_group) -> bool {
-             return xla::ValueOrThrow(pass.RunOnModuleGroup(module_group));
-           });
-
-  nb::class_<HloDCE, HloPassInterface>(m, "HloDCE").def(nb::init<>());
-  nb::class_<CallInliner, HloPassInterface>(m, "CallInliner").def(nb::init<>());
-  nb::class_<FlattenCallGraph, HloPassInterface>(m, "FlattenCallGraph")
-      .def(nb::init<>());
-  nb::class_<TupleSimplifier, HloPassInterface>(m, "TupleSimplifier")
-      .def(nb::init<>());
 }  // NOLINT(readability/fn_size)
 }  // namespace xla
From 820e3dbd2c17900c1bd455d0b0f166615207696e Mon Sep 17 00:00:00 2001
From: Jacob Burnim
Date: Mon, 5 May 2025 02:05:49 -0700
Subject: [PATCH 0993/1769] Increase tolerance in mgpu_attention_test.

PiperOrigin-RevId: 754844152
---
 tests/pallas/mgpu_attention_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py
index 3f0370153d81..50f9a455c9a2 100644
--- a/tests/pallas/mgpu_attention_test.py
+++ b/tests/pallas/mgpu_attention_test.py
@@ -153,7 +153,7 @@ def f_ref(q, k, v):
     dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v)
     dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v)
 
-    self.assertAllClose(dq, dq_ref, atol=5e-2)
+    self.assertAllClose(dq, dq_ref, atol=7e-2)
     self.assertAllClose(dk, dk_ref, atol=7e-2)
     self.assertAllClose(dv, dv_ref, atol=5e-2)
 
From 30f38eaa9e7466fc4dd2dff4361ed8df6033a4a5 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 5 May 2025 02:06:28 -0700
Subject: [PATCH 0994/1769] [Mosaic GPU] Improve checks in utils.ptr_to_memref

We previously didn't check the strides and offset that appear in the desired
memref_type, and simply assumed that the offset is 0 and the memref is
intended to be contiguous. Now, we check that the offset is 0 (if you intend
it to not be zero, then perhaps just shift the pointer?) and use the strides
inferred from the type.

I haven't found any bugs related to this, but it just feels like a footgun
waiting to fire.

PiperOrigin-RevId: 754844324
---
 jax/experimental/mosaic/gpu/utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 8bea56abf485..d89108cf451d 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -64,6 +64,9 @@ def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int:
 
 def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None):
+  strides, offset = memref_ty.get_strides_and_offset()
+  if offset != 0:
+    raise ValueError("Non-zero offset is not supported for ptr_as_memref")
   i64 = ir.IntegerType.get_signless(64)
   rank = len(memref_ty.shape)
   ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>"
@@ -84,7 +87,7 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None =
     desc = llvm.InsertValueOp(
         desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i]
     )
-  for i, s in enumerate(get_contiguous_strides(memref_ty.shape)):
+  for i, s in enumerate(strides):
     desc = llvm.InsertValueOp(
         desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i]
     )
From b4d8b78eb7bd7e00aaa9d3614e7b14014f330d91 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 5 May 2025 04:04:59 -0700
Subject: [PATCH 0995/1769] [Mosaic GPU] A bag of fixes

This also includes a bag of fixes to make our B200 CI green:
* Fixed a type error in tcgen05.
* Fixed the Mosaic profiler to estimate the profiling overhead instead of
  assuming a reasonable value (it was ok on H100, but B200 has lower
  overheads).
* Added some skips for cuDNN attention tests that are broken at the moment PiperOrigin-RevId: 754875555 --- jax/experimental/mosaic/gpu/layout_inference.py | 3 +-- jax/experimental/mosaic/gpu/profiler.py | 14 ++++++++++++-- jax/experimental/mosaic/gpu/tcgen05.py | 2 +- tests/mosaic/BUILD | 5 ++++- tests/pallas/BUILD | 1 + 5 files changed, 19 insertions(+), 6 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index feae31f0e9f6..b39dc933ce9d 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -21,7 +21,6 @@ import math from typing import cast -from jax._src import lib as jaxlib from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -534,7 +533,7 @@ def _infer_layout_cast_op_layout( # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. -if jaxlib.version >= (0, 6, 1): +if hasattr(mgpu, "BroadcastInDimOp"): @partial(_add_layout_inference_rule, mgpu.BroadcastInDimOp) def _infer_broadcast_in_dim_op_layout( op: mgpu.BroadcastInDimOp, diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 5b278468b98c..a048903428ae 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -263,10 +263,20 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): if np.any(entries_used > self.entries_per_warpgroup - 2): raise RuntimeError("Insufficient space to capture a full trace") traces = entries[..., 3:] + + # Estimate the overhead of profiling. + time_events = traces[:, :, 1::2] + valid_times_mask = np.arange(traces.shape[-1])[1::2] < (entries_used[..., None] - 3) + # 12 cycles is a ballpark estimate for H100 + profiling_overhead = (time_events[:, :, 1:] - time_events[:, :, :-1]).min( + where=valid_times_mask[:, :, 1:], initial=12 + ) + profiling_overhead = max(0, profiling_overhead - 1) + unintern = {v: k for k, v in self.interned_names.items()} events = [] for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block): - valid_entries = entries_used[block_idx, wg_idx] - 3 + valid_entries = (entries_used[block_idx, wg_idx] - 3) local_clock_offset = None assert valid_entries % 2 == 0, valid_entries start_time = start_times[block_idx, wg_idx] @@ -278,7 +288,7 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): if local_clock_offset is None: local_clock_offset = time time -= local_clock_offset - time -= i * 6 # Account for the overhead of profiling. + time -= (i // 2) * profiling_overhead # Account for the overhead of profiling. 
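+          # Note: events are stored as (tag, time) entry pairs, so (i // 2) is
+          # the number of records emitted before this one; each record is
+          # assumed to delay the clock by `profiling_overhead` cycles, as
+          # estimated above from the smallest gap observed between
+          # consecutive timestamps.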
           if time < 0: break  # Detect a timer wraparound
           name_id = tag
diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index e1f12a0f95cd..4726805f5b76 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -643,7 +643,7 @@ def __getitem__(self, *idxs):
             arith.addi(self.address, arith.constant(i32, col_tile * 128)),
             cols=128,
             dtype=self.dtype,
-            tmem_packing=False,
+            tmem_packing=1,
         )
     )
     registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape)
diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD
index 735c9ffd5b42..ea8be497faa4 100644
--- a/tests/mosaic/BUILD
+++ b/tests/mosaic/BUILD
@@ -104,7 +104,10 @@ jax_multiplatform_test(
     enable_backends = [],
     enable_configs = ["gpu_h100"],
     main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py",
-    tags = ["notap"],
+    tags = [
+        "manual",
+        "notap",
+    ],
     deps = [
         "//jax:mosaic_gpu",
     ] + py_deps("numpy"),
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index 897bc82d4365..ed7e96ad1d02 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -151,6 +151,7 @@ jax_multiplatform_test(
         "JAX_PALLAS_USE_MOSAIC_GPU": "1",
         "JAX_PALLAS_VERBOSE_ERRORS": "0",
     },
+    shard_count = 16,
     tags = [
         "noasan",  # Times out.
         "nomsan",  # Times out.
From 073034cacfee094998e6f9cac5698f3b9496e93d Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Mon, 5 May 2025 19:36:46 +0800
Subject: [PATCH 0996/1769] Update sharded-computation.md

---
 docs/sharded-computation.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md
index 8af91ef5c306..16a5dc8cfa08 100644
--- a/docs/sharded-computation.md
+++ b/docs/sharded-computation.md
@@ -147,7 +147,7 @@ print(result)
 
 +++ {"id": "Q4N5mrr9i_ki"}
 
-The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.
+The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `4`, the second on `1` and `5`, and so on.
 
 ## 2. Explicit sharding
 
From 63003bb5789539d0c90af6447daef1c99b4733da Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Mon, 5 May 2025 05:11:46 -0700
Subject: [PATCH 0997/1769] [Mosaic GPU] Run a standard mlir cse pass before
 Mosaic layout inference.

There are cases where layout inference fails with unused `vector.load` ops.
This CL adds a pass to remove these. The unused ops are the result of
lowering expressions like `o[...] = a[...] + b[...]` where the lowering goes
through `swap` and always returns the old value of `o` even if it's not used.

PiperOrigin-RevId: 754892752
---
 jax/_src/pallas/mosaic_gpu/lowering.py |  5 +++++
 jax/experimental/mosaic/gpu/core.py    | 10 ++++++++++
 2 files changed, 15 insertions(+)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 0c0444bc7026..84fec5156d4d 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -820,6 +820,11 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
   )
 
   if lowering_semantics == mgpu.LoweringSemantics.Warpgroup:
+    # We need to run CSE first in order to remove dead code for which layout
+    # inference does not work.
+    pm = mlir.passmanager.PassManager.parse("builtin.module(cse)", module.context)
+    pm.run(module.operation)
+
     # Run Python lowering passes. The remaining passes will be run in C++ in
     # jax/jaxlib/mosaic/gpu/custom_call.cc
     mgpu.infer_layout(module)  # pytype: disable=attribute-error
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index 97f31cbdb32d..dd8996dc0f6f 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -633,6 +633,11 @@ def as_gpu_kernel(
   )
 
   if thread_semantics == LoweringSemantics.Warpgroup and dialect is not None:
+    # We need to run CSE first in order to remove dead code for which layout
+    # inference does not work.
+    pm = mlir.passmanager.PassManager.parse("builtin.module(cse)", module.context)
+    pm.run(module.operation)
+
     # Run Python lowering passes. The remaining passes will be run in C++ in
     # jax/jaxlib/mosaic/gpu/custom_call.cc
     layout_inference.infer_layout(module)  # pytype: disable=attribute-error
@@ -716,6 +721,11 @@ def as_torch_gpu_kernel(
   )
 
   if lowering_semantics == LoweringSemantics.Warpgroup and dialect is not None:
+    # We need to run CSE first in order to remove dead code for which layout
+    # inference does not work.
+    pm = mlir.passmanager.PassManager.parse("builtin.module(cse)", module.context)
+    pm.run(module.operation)
+
     # Run Python lowering passes. The remaining passes will be run in C++ in
     # jax/jaxlib/mosaic/gpu/custom_call.cc
     layout_inference.infer_layout(module)  # pytype: disable=attribute-error
From 405c79b03a39347228d6dfd8ead7f71f0114ecbc Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Mon, 5 May 2025 05:32:44 -0700
Subject: [PATCH 0998/1769] [Mosaic GPU] Use warpgroup semantics for the
 FlashAttention 3 pipelined kernel.

As a result we can remove all transforms and layout casts.

PiperOrigin-RevId: 754897809
---
 .../pallas/ops/gpu/attention_mgpu.py | 32 +++++++-------------
 1 file changed, 12 insertions(+), 20 deletions(-)

diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index a9403ad22786..4d43d6045fee 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -613,9 +613,6 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residual
   if rem:
     raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
 
-  tiling = plgpu.TilingTransform((8, 64))
-  swizzle = plgpu.SwizzleTransform(128)
-
   def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped):
     batch = lax.axis_index("batch")
     wg_idx = lax.axis_index("wg")
@@ -633,15 +630,9 @@ def perform_schedule_barrier():
     def _compute_thread():
       qo_smem = qo_smem2.at[wg_idx]
       lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None
-      m_i = plgpu.layout_cast(
-          jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW,
-      )
-      l_i = plgpu.layout_cast(
-          jnp.full((block_q,), 0, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW,
-      )
-      acc = plgpu.layout_cast(
-          jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
-      )
+      m_i = jnp.full((block_q,), -jnp.inf, dtype=jnp.float32)
+      l_i = jnp.full((block_q,), 0, dtype=jnp.float32)
+      acc = jnp.full((block_q, head_dim), 0, dtype=jnp.float32)
       # Q is not pipelined, so we load in with a manual DMA.
plgpu.copy_gmem_to_smem( q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], @@ -712,12 +703,10 @@ def compute_pv(acc_ref): in_specs=[ plgpu.GPUBlockSpec( # k block_shape=(block_kv, head_dim), - index_map=lambda i: (i, 0), - transforms=[tiling, swizzle]), + index_map=lambda i: (i, 0)), plgpu.GPUBlockSpec( # v block_shape=(block_kv, head_dim), - index_map=lambda i: (i, 0), - transforms=[tiling, swizzle]), + index_map=lambda i: (i, 0)), ], out_specs=[], ) @@ -732,13 +721,16 @@ def compute_pv(acc_ref): ) def run(refs): q_ref, k_ref, v_ref, out_ref, lse_ref = refs - @pl.core_map(mesh, - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) + + @pl.core_map( + mesh, + compiler_params=plgpu.GPUCompilerParams( + approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup + ), + ) def _kernel_entry(): qo_scratch = plgpu.SMEM( (compute_wgs, block_q, head_dim), jnp.float16, - transforms=(tiling, swizzle), ) scratch = [qo_scratch, None] if save_residuals: From c75c3fae9a1542650a2d9d908531211692ddac13 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 5 May 2025 05:54:50 -0700 Subject: [PATCH 0999/1769] [Mosaic GPU] Enable `pallas/ops_test:test_comparison_scalar` on Mosaic GPU (except for bool) PiperOrigin-RevId: 754903306 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 +++--- .../mosaic/gpu/transform_inference.py | 11 +++++++++++ tests/pallas/ops_test.py | 17 +++++++++-------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 84fec5156d4d..cacca8d3d1d9 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1354,8 +1354,9 @@ def _swap_lowering_rule( def _swap_lowering_rule_wg( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): - if not ir.VectorType.isinstance(value.type): - raise TypeError(f"Can only store vectors (got {value}).") + shape = ctx.avals_out[0].shape + if shape and not ir.VectorType.isinstance(value.type): + raise TypeError(f"Can only store scalars or vectors (got {value}).") if not ir.MemRefType.isinstance(x_smem.type): raise TypeError(f"Can only store to references (got {x_smem}).") @@ -1368,7 +1369,6 @@ def _swap_lowering_rule_wg( "Transforms are not yet implemented for warpgroup semantics" ) - shape = ctx.avals_out[0].shape ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) if shape: zero_index = arith_dialect.constant(ir.IndexType.get(), 0) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index e146306a41db..d4810845f2bc 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -195,6 +195,17 @@ def _infer_vector_load_store_transforms( return None if transforms is None else ([transforms], []) +@partial(_add_transform_inference_rule, memref.StoreOp) +def _infer_memref_store_transforms(op: memref.StoreOp) -> OptionalTransforms: + # memref.store is only used for scalar operations, so there are no transforms. 
+ ref_shape = ir.MemRefType(op.memref.type).shape + if ref_shape != [] and ref_shape != [1]: + raise NotImplementedError( + f"Only scalar memrefs are supported, got {ref_shape}" + ) + + return None + @partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: transforms = None diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index bc29e741f119..cf6536df2344 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1320,8 +1320,6 @@ def kernel(x_ref, y_ref, o_ref): ) ) def test_comparison_scalar(self, fn, dtype): - self.skip_if_mosaic_gpu() - if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") @@ -1331,6 +1329,9 @@ def test_comparison_scalar(self, fn, dtype): ): self.skipTest("Only works on GPUs with capability >= sm80") + if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, in_specs=( @@ -1338,17 +1339,17 @@ def test_comparison_scalar(self, fn, dtype): pl.BlockSpec(memory_space=smem_on_tpu()), ), out_specs=pl.BlockSpec(memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), ) def kernel(x_ref, y_ref, o_ref): - for i in range(8): - o_ref[i] = fn(x_ref[i], y_ref[i]) + for i in range(128): + o_ref[i] = fn(x_ref[i], y_ref[i]).astype(jnp.int32) - x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) + x = jnp.tile(jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype), 16) + y = jnp.tile(jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype), 16) out = kernel(x, y) expected = fn(x, y) - self.assertArraysEqual(out, expected) + self.assertArraysEqual(out != 0, expected) def test_isnan(self): self.skip_if_mosaic_gpu() From 842170be05cf57f9088d569f5d9f65350468016e Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 5 May 2025 06:29:50 -0700 Subject: [PATCH 1000/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/03d995ffbc7653bcc4fe0d477330422f139c634b. PiperOrigin-RevId: 754912573 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 405c89442f8a..db143127ffe6 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "8a838a92071ebd506036576824be50a120a5905e" -XLA_SHA256 = "8d2793920dfa93d7c8b9c15f11eaeaef3029872ea8f31aac347fb1d09894d2a6" +XLA_COMMIT = "03d995ffbc7653bcc4fe0d477330422f139c634b" +XLA_SHA256 = "a9c9245e4e9971f57e252ab023ed26304c3cb0ffc1164bf30231930a3096bc3f" def repo(): tf_http_archive( From 0e105c8dfac5073fd81460011b4f3d3387559659 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Fri, 2 May 2025 12:13:14 -0400 Subject: [PATCH 1001/1769] [Mosaic GPU] Print instead of warning when skipping flash_attention test https://github.com/jax-ml/jax/pull/28403 made all warnings trigger test failures which is not appropriate in this case. 
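For context, the escalation referenced above is the standard pytest mechanism
that turns warnings into test failures. A minimal sketch of the same policy
expressed directly in Python (the exact JAX test configuration may differ):

```python
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("error")  # escalate every warning to an exception
    try:
        warnings.warn("skipping this example")  # now raises, failing a test
    except UserWarning as exc:
        # Under such a policy, a skip message must use print() instead.
        print(f"warning escalated to error: {exc}")
```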
--- jax/experimental/mosaic/gpu/examples/flash_attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 57f30b8603c8..071a4dec81fd 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -17,7 +17,6 @@ import dataclasses import enum import itertools -import warnings import jax from jax import random @@ -601,7 +600,7 @@ def ref(q, k, v): if __name__ == "__main__": if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_equal("9.0")): - warnings.warn( + print( "Mosaic GPU Flash Attention requires compute capability 9.0a to run, " "skipping.") exit(0) From 108b288bcbc44e09866f99ddd811514301334b0a Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Mon, 5 May 2025 14:54:20 +0000 Subject: [PATCH 1002/1769] jax-cuda12-plugin: require nvidia-cublas-cu12<12.9 This avoids a bug in cuDNN when used with cuBLAS 12.9 --- jax_plugins/cuda/plugin_setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index b9220cd29283..4b442f3aa7de 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -53,7 +53,8 @@ def has_ext_modules(self): install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], extras_require={ 'with-cuda': [ - "nvidia-cublas-cu12>=12.1.3.1", + # cudnn has a bug with mxfp8 with multiple GPUs per process and cublas 12.9 + "nvidia-cublas-cu12>=12.1.3.1,<12.9", "nvidia-cuda-cupti-cu12>=12.1.105", "nvidia-cuda-nvcc-cu12>=12.6.85", "nvidia-cuda-runtime-cu12>=12.1.105", From 6c7e0220a77e133811e8658de9a61f85b7af256b Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 5 May 2025 09:28:00 -0700 Subject: [PATCH 1003/1769] Call block_until_ready for testQrInvalidDtypeCPU PiperOrigin-RevId: 754965091 --- tests/linalg_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 1670f1ee4abd..74259c300cf7 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1061,7 +1061,7 @@ def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): else: err, msg = Exception, "Unsupported dtype" with self.assertRaisesRegex(err, msg): - jnp.linalg.qr(arr) + jax.block_until_ready(jnp.linalg.qr(arr)) @jtu.sample_product( shape=[(10, 4, 5), (5, 3, 3), (7, 6, 4)], From 5a5fd86958c10f8b928523af9585893fa1e5054a Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 5 May 2025 09:46:16 -0700 Subject: [PATCH 1004/1769] Call block_until_ready for test_assert PiperOrigin-RevId: 754971091 --- tests/mosaic/gpu_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 990398708a94..eb5bb355ffe1 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3343,7 +3343,7 @@ def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None: ) with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: - f(jnp.ones((128,), jnp.float32)) + jax.block_until_ready(f(jnp.ones((128,), jnp.float32))) # SASS doesn't seem to include the assertion message, so we are just # checking that __assertfail appears in the symbol table for the kernel. 
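The `block_until_ready` additions in patches 0990, 0991, 1003, and 1004 above
all guard against the same hazard: JAX dispatches work asynchronously, so a
test that merely calls a compiled function can finish before the backend
raises an error or runs a side-effecting callback. A minimal sketch of the
idiom, using only the public JAX API:

```python
import jax
import jax.numpy as jnp

def run_and_check(x):
    # Dispatch is asynchronous: without an explicit barrier, a runtime
    # failure (or a callback's side effect) may surface only after the
    # caller has already moved on.
    y = jax.jit(jnp.sin)(x)
    return jax.block_until_ready(y)  # force completion before asserting

out = run_and_check(jnp.ones((8,)))
assert out.shape == (8,)
```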
From 06988e612dad47b6829d0fec426eedc2d8c8fa84 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 5 May 2025 09:51:33 -0700
Subject: [PATCH 1005/1769] Initial f32, bf16 TPU SublaneShuffleOp

Add a TPU SublaneShuffleOp, which allows us to swap sublanes across vregs.
The contract is sublane_shuffle(lhs, rhs, pattern) -> out, where the pattern
specifies, for each output sublane, the sublane of the combined [lhs, rhs]
input that it should be pulled from.

PiperOrigin-RevId: 754972962
---
 jaxlib/mosaic/dialect/tpu/tpu.td     | 32 ++++++++++++++++++++
 jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 44 ++++++++++++++++++++++++++++
 jaxlib/mosaic/dialect/tpu/util.cc    |  5 +++-
 3 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index e574889626aa..8b6e005e1719 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -850,6 +850,38 @@ def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> {
   let results = (outs AnyVectorOfNonZeroRank:$output);
 }
 
+def TPU_SublaneShuffleOp : TPU_Op<"sublane_shuffle", [SameOperandsAndResultType]> {
+  // This op takes 2 physical vregs and a pattern, applies the pattern,
+  // and returns the result as 1 vreg.
+  //
+  // The pattern is a list of integers, where the integer value is the
+  // index of the sublane in the *combined input* [lhs, rhs], and the
+  // position of the integer in the list is the index of the sublane
+  // in the *output* vreg.
+  //
+  // The pattern size must match the operand/result sublane count.
+  //
+  // Example:
+  // %0 = tpu.sublane_shuffle %a, %b,
+  //      [0, 1, 2, 3, 4, 5, 6, 7]       // Result is %a
+  // %1 = tpu.sublane_shuffle %a, %b,
+  //      [8, 9, 10, 11, 12, 13, 14, 15] // Result is %b
+  // %2 = tpu.sublane_shuffle %a, %b,
+  //      [7, 6, 5, 4, 11, 10, 9, 8]     // Result uses high half of a
+  //                                     // and low half of b, reversed.
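+  //
+  // In general, for a sublane count of N (N = 8 in the examples above),
+  // entry pattern[i] selects the source of output sublane i: lhs sublane
+  // pattern[i] when pattern[i] < N, and rhs sublane (pattern[i] - N)
+  // otherwise.
+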
+ let arguments = (ins + TPU_Vreg:$lhs, + TPU_Vreg:$rhs, + DenseI32ArrayAttr:$pattern + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $lhs `,` $rhs `,` $pattern attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; + + let hasVerifier = 1; +} + def TPU_LogOp : TPU_Op<"log"> { let arguments = (ins Variadic:$inputs, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 341ead8431b4..bbc5be3d125d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1270,6 +1270,50 @@ LogicalResult AssumeMultipleOp::verify() { return success(); } +LogicalResult SublaneShuffleOp::verify() { + auto lhs = getLhs(); + auto rhs = getRhs(); + auto result = getResult(); + auto lhs_ty = dyn_cast(lhs.getType()); + auto rhs_ty = dyn_cast(rhs.getType()); + auto result_ty = dyn_cast(result.getType()); + + if (!lhs_ty || !rhs_ty || !result_ty) { + return emitOpError("Expected operands and result to be vector types"); + } + + if (lhs_ty.getShape() != rhs_ty.getShape() || + lhs_ty.getShape() != result_ty.getShape()) { + return emitOpError("Expected lhs, rhs, and result shapes to match"); + } + if (lhs_ty.getElementType() != rhs_ty.getElementType() || + lhs_ty.getElementType() != result_ty.getElementType()) { + return emitOpError("Expected lhs, rhs, and result element types to match"); + } + + auto pattern = getPattern(); + auto shape = result_ty.getShape(); + if (shape.size() < 2 || shape.size() > 3) { + return emitOpError("Vreg rank should be 2 or 3"); + } + auto sublane_count = shape[0]; + + if (pattern.size() != sublane_count) { + return emitOpError("Expected pattern size (") + << pattern.size() << ") to match result/operand sublanes (" + << sublane_count << ")"; + } + + int64_t total_input_sublanes = sublane_count * 2; + for (int32_t idx : pattern) { + if (idx < 0 || idx >= total_input_sublanes) { + return emitOpError("Pattern index ") << idx << " out of bounds [0, " + << (total_input_sublanes - 1) << "]"; + } + } + return success(); +} + } // namespace tpu } // namespace mlir diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 141f52ec125b..bb42c678bbf6 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -210,7 +210,10 @@ FailureOr> getOutLayouts( FAILUREOR_ASSIGN_OR_RETURN(const SmallVector out_layouts, getLayoutArrayFromAttr(op.getAttr("out_layout"))); if (out_layouts.size() != op.getNumResults()) { - return op.emitOpError("out_layout size does not match number of results"); + return op.emitOpError("out_layout size does not match number of results") + << " results: " << op.getNumResults() + << " vs layout size: " << out_layouts.size() << " for " + << op.getName(); } for (const auto [l, res] : llvm::zip_equal(out_layouts, op.getResults())) { if (!layoutIsValidForValue(l, res, target_shape)) { From a2d67a6a00e26d0a30460db8c73decb6c2519fb7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 5 May 2025 10:04:15 -0700 Subject: [PATCH 1006/1769] create a mosaic_gpu_support cc_library and include it as part of the //third_party/py/jax/jaxlib/cuda:cuda_gpu_kernels PiperOrigin-RevId: 754978294 --- jaxlib/cuda/BUILD | 1 + jaxlib/mosaic/gpu/BUILD | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index c872433ce04a..2cc1476b637e 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -474,6 +474,7 @@ cc_library( ":cusolver_kernels_ffi", ":cusparse_kernels", 
":triton_kernels", + "//jaxlib/mosaic/gpu:mosaic_gpu_support", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_target_registry", diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index be83fd4f6b18..bd4c86e97ad7 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -23,7 +23,18 @@ package( py_library( name = "mosaic_gpu", data = [":libmosaic_gpu_runtime.so"], - deps = [":_mosaic_gpu_ext"], + deps = [ + ":_mosaic_gpu_ext", + ":mosaic_gpu_support", + ], +) + +cc_library( + name = "mosaic_gpu_support", + deps = [ + ":custom_call", + ":runtime", + ], ) cc_library( From 268c6842cc2c272028c816c1d1ce0346d40be724 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 5 May 2025 10:44:02 -0700 Subject: [PATCH 1007/1769] [JAX] In McJAX, support JIT compilation/execution on only the devices attached to a single host, with other hosts not participating. PiperOrigin-RevId: 754994333 --- jax/_src/dispatch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 64991c2fd3e2..ebab2120c4d0 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -471,9 +471,15 @@ def _device_put_sharding_impl(x, aval, device, copy): if not s.is_fully_addressable: if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types): - # TODO(emilyaf): Remove this condition when jit works when a sharding - # has no local devices. - if not config.enable_empty_arrays.value: + # If all hosts participate in the sharding, assert that the input is the + # same on all hosts. If some hosts have no addressable devices in the + # sharding, bypass the check, since we can't easily distinguish between + # these two cases: (1) the sharding contains the same subset of global + # devices on all hosts (and hosts with no addressable devices in the + # sharding do not transfer data) or (2) the sharding contains a + # different subset of devices on each host. For (1), the input should be + # the same on all hosts, but for (2) it need not be. + if jax.process_count() == len(s._internal_device_list.process_indices): # pytype: disable=attribute-error multihost_utils.assert_equal( x, fail_message=( f"{type(x)} passed to device_put is not the same on each" From 13ca7002b0f5197b7c7e2a8d11a6a9490d4a9ef0 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 5 May 2025 11:12:18 -0700 Subject: [PATCH 1008/1769] Add workflow for testing JAX against NumPy nightly builds. This workflow builds `jax`, `jaxlib` at HEAD and `ml_dtypes` HEAD(compiled with NumPy nightly). It then runs the JAX tests against the NumPy nightly wheel. PiperOrigin-RevId: 755005016 --- .github/workflows/numpy_nightly.yml | 86 +++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 .github/workflows/numpy_nightly.yml diff --git a/.github/workflows/numpy_nightly.yml b/.github/workflows/numpy_nightly.yml new file mode 100644 index 000000000000..51876a7eb71d --- /dev/null +++ b/.github/workflows/numpy_nightly.yml @@ -0,0 +1,86 @@ + +name: CI - jaxlib head with NumPy nightly +# This workflow is used to build and test against NumPy nightly releases. We build ml_dtypes from +# HEAD using the NumPy nightly ABI, then build jaxlib at head, and then finally run tests against +# NumPy nightly. + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' 
+ type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + schedule: + - cron: "0 */3 * * *" # Run once every 3 hours + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} + +jobs: + test-nightly-numpy: + defaults: + run: + shell: bash + runs-on: "linux-x86-n2-64" + strategy: + matrix: + python: ["3.13",] + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + name: "CI - jaxlib head with NumPy nightly" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_PYTHON: "python${{ matrix.python }}" + JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 + JAXCI_CLONE_MAIN_XLA: 1 + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Checkout ml_dtypes + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + repository: jax-ml/ml_dtypes + ref: main + path: ml_dtypes + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Install numpy & scipy development versions + run: | + "$JAXCI_PYTHON" -m uv pip install \ + --system \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \ + --no-deps \ + --pre \ + --upgrade \ + numpy \ + scipy + "$JAXCI_PYTHON" -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes with NumPy nightly + run: | + pushd ml_dtypes + git submodule init + git submodule update + "$JAXCI_PYTHON" -m uv pip install . --no-build-isolation + popd + - name: Build jax at HEAD + run: ./ci/build_artifacts.sh jax + - name: Build jaxlib at HEAD + run: ./ci/build_artifacts.sh jaxlib + - name: Install test dependencies + run: $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + - name: Run Pytest CPU tests + timeout-minutes: 30 + run: ./ci/run_pytest_cpu.sh \ No newline at end of file From 2c52ae3569f313fa263046b203a1af37fb07ec99 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 29 Apr 2025 13:23:51 +0000 Subject: [PATCH 1009/1769] Fixed deadlock in MakeShardFn on static var assignment under free-threading Description: - Fixed deadlock in MakeShardFn on static var assignment under free-threading using nb::ft_mutex Fixes #28385 --- jaxlib/BUILD | 1 + jaxlib/py_values.cc | 21 ++++++++++++--------- jaxlib/sharding.cc | 23 ++++++----------------- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 62621302b8f3..9278e996d1fb 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -898,6 +898,7 @@ cc_library( "@xla//xla/python/pjrt_ifrt", "@xla//xla/python/pjrt_ifrt:pjrt_dtype", "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/python:safe_static_init", "@xla//xla/service:platform_util", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/framework:allocator", diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc index b0ac2171eb8d..5225b67ee93b 100644 --- a/jaxlib/py_values.cc +++ b/jaxlib/py_values.cc @@ -60,6 +60,7 @@ limitations under the License. 
#include "xla/python/ifrt/user_context.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/safe_static_init.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/tsl/concurrency/ref_count.h" @@ -591,9 +592,11 @@ absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, ifrt::Device* to_device, ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options) { - static const absl::flat_hash_map* const handlers = [] { - auto p = new absl::flat_hash_map(); + using PyObjectDeviceHandlerMap = absl::flat_hash_map; + + auto init_fn = [](){ + std::unique_ptr p = std::make_unique(); + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); // Python scalar types. static_assert(sizeof(bool) == 1, "Conversion code assumes bool is 1 byte"); @@ -660,20 +663,20 @@ absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, static_assert(sizeof(int) == sizeof(int32_t), "int must be the same size as int32_t"); (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; - return p; - }(); + }; + const PyObjectDeviceHandlerMap& handlers = xla::SafeStaticInit(init_fn); if (arg.type().ptr() == PyArray::type().ptr()) { auto array = nb::borrow(arg); return HandlePyArray(arg, client, to_device, to_memory_kind, options); } - auto res = handlers->find(arg.type().ptr()); - if (res == handlers->end()) { + auto res = handlers.find(arg.type().ptr()); + if (res == handlers.end()) { for (auto base_class : arg.type().attr("__mro__")) { - res = handlers->find(base_class.ptr()); - if (res != handlers->end()) { + res = handlers.find(base_class.ptr()); + if (res != handlers.end()) { return res->second(arg, client, to_device, to_memory_kind, options); } } diff --git a/jaxlib/sharding.cc b/jaxlib/sharding.cc index 2d8a88a6509d..fa19e1434a90 100644 --- a/jaxlib/sharding.cc +++ b/jaxlib/sharding.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/nb_numpy.h" +#include "xla/python/safe_static_init.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -240,24 +241,12 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, // TODO(phawkins): this leaks a reference to the check_pspec function. // A better way to fix this would be to move PartitionSpec and this check into // C++. 
-  nb::object* check_pspec = []() {
-    static absl::Mutex mu;
-    static nb::object* output = nullptr;
-    {
-      absl::MutexLock lock(&mu);
-      if (output) {
-        return output;
-      }
-    }
+  auto init_fn = [](){
     nb::module_ si = nb::module_::import_("jax._src.named_sharding");
-    nb::object attr = si.attr("check_pspec");
-    absl::MutexLock lock(&mu);
-    if (!output) {
-      output = new nb::object(attr);
-    }
-    return output;
-  }();
-  (*check_pspec)(mesh_, spec_);
+    return std::make_unique<nb::object>(si.attr("check_pspec"));
+  };
+  nb::object& check_pspec = xla::SafeStaticInit<nb::object>(init_fn);
+  check_pspec(mesh_, spec_);
 }
 
 /*static*/ PyObject* NamedSharding::type_ = nullptr;
From cf976283cb8d8aa44c69989b247431c5103c43fb Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 5 May 2025 14:28:55 -0700
Subject: [PATCH 1010/1769] Replace C++ safe_zip with Python implementation
 using 3.10 strict keyword

PiperOrigin-RevId: 755075208
---
 jax/_src/util.py   | 45 +++++++++++-----------
 jaxlib/utils.cc    | 94 ----------------------------------------------
 tests/util_test.py | 18 ++++-----
 3 files changed, 33 insertions(+), 124 deletions(-)

diff --git a/jax/_src/util.py b/jax/_src/util.py
index 1227b15c5ead..94b6380b13d4 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -42,29 +42,32 @@ T2 = TypeVar("T2")
 T3 = TypeVar("T3")
 
+# safe_zip cannot yet be fully annotated, so we use a strategy similar
+# to that used for builtins.zip in python/typeshed. This supports
+# return types matching input types for up to three arguments.
+@overload
+def safe_zip(__arg1: Iterable[T1], /) -> list[tuple[T1]]: ...
+@overload
+def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], /) -> list[tuple[T1, T2]]: ...
+@overload
+def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> list[tuple[T1, T2, T3]]: ...
+@overload
+def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> list[tuple[Any, ...]]: ...
+
+def safe_zip(*args):
+  """
+  Like builtin :func:`zip`, but with additional safety checks.
 
-if TYPE_CHECKING:
-  # safe_zip cannot yet be fully annotated, so we use a strategy similar
-  # to that used for builtins.zip in python/typeshed. This supports
-  # return types matching input types for up to three arguments.
-  @overload
-  def safe_zip(__arg1: Iterable[T1]) -> list[tuple[T1]]: ...
-  @overload
-  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[tuple[T1, T2]]: ...
-  @overload
-  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[tuple[T1, T2, T3]]: ...
-  @overload
-  def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[tuple[Any, ...]]: ...
+  The differences from :func:`zip` are:
 
-  def safe_zip(*args):
-    args = list(map(list, args))
-    n = len(args[0])
-    for arg in args[1:]:
-      assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
-    return list(zip(*args))
-
-else:
-  safe_zip = jaxlib_utils.safe_zip
+  - :func:`safe_zip` checks that at least one argument is provided.
+  - :func:`safe_zip` checks that all arguments have the same length.
+  - :func:`safe_zip` returns an eagerly-evaluated list instead of a
+    lazily-evaluated iterator.
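+
+  A short example of the intended behavior::
+
+    >>> safe_zip([1, 2], ['a', 'b'])
+    [(1, 'a'), (2, 'b')]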
+ """ + if not args: + raise TypeError("safe_zip requires at least 1 argument.") + return list(zip(*args, strict=True)) if TYPE_CHECKING: diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index e5bb45e999da..1cf6798010ed 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -204,98 +204,6 @@ PyMethodDef foreach_def = { "ignoring the return values and returns None. The iterables must all have " "the same lengths."}; -// A variant of zip(...) that: -// a) returns a list instead of an iterator, and -// b) checks that the input iterables are of equal length. -// TODO(phawkins): consider replacing this function with -// list(zip(..., strict=True)) once TensorFlow 2.13 is released, which should -// resolve an incompatibility with strict=True and jax2tf. -PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { - if (nargs < 1) { - PyErr_SetString(PyExc_TypeError, "safe_zip requires at least 1 argument"); - return nullptr; - } - absl::InlinedVector iterators; - iterators.reserve(nargs); - for (Py_ssize_t i = 0; i < nargs; ++i) { - PyObject* it = PyObject_GetIter(args[i]); - if (!it) return nullptr; - iterators.push_back(nb::steal(it)); - } - - // Try to use a length hint to estimate how large a list to allocate. - Py_ssize_t length_hint = PyObject_LengthHint(args[0], 2); - if (PyErr_Occurred()) { - PyErr_Clear(); - } - if (length_hint < 0) { - length_hint = 2; - } - - nb::list list = nb::steal(PyList_New(length_hint)); - int n = 0; // Current true size of the list - - while (true) { - nb::object tuple; - nb::object v = nb::steal(PyIter_Next(iterators[0].ptr())); - if (PyErr_Occurred()) return nullptr; - - if (v.ptr()) { - tuple = nb::steal(PyTuple_New(nargs)); - if (!tuple.ptr()) return nullptr; - - PyTuple_SET_ITEM(tuple.ptr(), 0, v.release().ptr()); - for (size_t i = 1; i < iterators.size(); ++i) { - v = nb::steal(PyIter_Next(iterators[i].ptr())); - if (PyErr_Occurred()) return nullptr; - if (!v.ptr()) { - PyErr_Format(PyExc_ValueError, - "safe_zip() argument %u is shorter than argument 1", - i + 1); - return nullptr; - } - PyTuple_SET_ITEM(tuple.ptr(), i, v.release().ptr()); - } - } else { - // No more elements should be left. Checks the other iterators are - // exhausted. - for (size_t i = 1; i < iterators.size(); ++i) { - v = nb::steal(PyIter_Next(iterators[i].ptr())); - if (PyErr_Occurred()) return nullptr; - if (v.ptr()) { - PyErr_Format(PyExc_ValueError, - "safe_zip() argument %u is longer than argument 1", - i + 1); - return nullptr; - } - } - - // If the length hint was too large, truncate the list to the true size. 
- if (n < length_hint) { - if (PyList_SetSlice(list.ptr(), n, length_hint, nullptr) < 0) { - return nullptr; - } - } - return list.release().ptr(); - } - - if (n < length_hint) { - PyList_SET_ITEM(list.ptr(), n, tuple.release().ptr()); - } else { - if (PyList_Append(list.ptr(), tuple.ptr()) < 0) { - return nullptr; - } - tuple = nb::object(); - } - ++n; - } -} - -PyMethodDef safe_zip_def = { - "safe_zip", - reinterpret_cast(SafeZip), - METH_FASTCALL, -}; nb::list TopologicalSort(nb::str parents_attr, nb::iterable end_nodes_iterable) { @@ -368,8 +276,6 @@ NB_MODULE(utils, m) { PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr())); m.attr("foreach") = nb::steal( PyCFunction_NewEx(&foreach_def, /*self=*/nullptr, module_name.ptr())); - m.attr("safe_zip") = nb::steal( - PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr())); m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"), nb::arg("end_nodes"), diff --git a/tests/util_test.py b/tests/util_test.py index 53414dae977f..90506117af8f 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -175,28 +175,28 @@ def test_safe_zip(self): ) def test_safe_zip_errors(self): - with self.assertRaisesRegex( - TypeError, "safe_zip requires at least 1 argument" + with self.assertRaisesWithLiteralMatch( + TypeError, "safe_zip requires at least 1 argument." ): util.safe_zip() - with self.assertRaisesRegex( + with self.assertRaisesWithLiteralMatch( TypeError, "'function' object is not iterable" ): util.safe_zip(lambda x: x) - with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + with self.assertRaisesWithLiteralMatch( + ValueError, "zip() argument 2 is longer than argument 1" ): util.safe_zip(range(3), range(4)) - with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is shorter than argument 1" + with self.assertRaisesWithLiteralMatch( + ValueError, "zip() argument 2 is shorter than argument 1" ): util.safe_zip(range(7), range(2)) - with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + with self.assertRaisesWithLiteralMatch( + ValueError, "zip() argument 2 is longer than argument 1" ): util.safe_zip((), range(3)) From 6d1b5271a115007162e9f98561d6b118aa66382c Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 5 May 2025 14:46:16 -0700 Subject: [PATCH 1011/1769] Support `np.asarray` for JAX arrays with fully replicated shardings when some hosts have no local shards. Only hosts with local shards can fetch np arrays. 
PiperOrigin-RevId: 755081626 --- jax/_src/array.py | 3 ++- jaxlib/py_array.cc | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index f2b070c8221d..422fa5086e62 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -636,7 +636,8 @@ def _value(self) -> np.ndarray: self._check_if_deleted() if self._npy_value is None: - if self.is_fully_replicated: + if (self.is_fully_replicated and + self.sharding._internal_device_list.addressable_device_list): # type: ignore npy_value, did_copy = self._single_device_array_to_np_array_did_copy() npy_value.flags.writeable = False if did_copy: diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 4c5baab58683..a977d5f1a554 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -1528,6 +1528,9 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { absl::Span> buffers = array->pjrt_buffers(); + if (buffers.empty()) { + return InvalidArgument("Array has no buffers."); + } PjRtBuffer& buffer = *buffers.front(); if (!buffer.IsOnCpu()) { return InvalidArgument( From 8f17a552528da9a2f66587933756c33a463d27ce Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Mon, 5 May 2025 14:50:46 -0700 Subject: [PATCH 1012/1769] [docs] Remove tensorflow from tensorboard installation instructions PiperOrigin-RevId: 755083199 --- docs/profiling.md | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/docs/profiling.md b/docs/profiling.md index ac992b3a05da..c33e79c1dc0c 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -79,24 +79,22 @@ GPU and TPU. The end result looks something like this: ### Installation -The TensorBoard profiler is only available with the version of TensorBoard -bundled with TensorFlow. - +The TensorBoard profiler is available as a plugin to TensorBoard ```shell -pip install tensorflow tensorboard-plugin-profile +pip install tensorboard tensorboard-plugin-profile ``` -If you already have TensorFlow installed, you only need to install the +If you already have TensorBoard installed, you only need to install the `tensorboard-plugin-profile` pip package. Be careful to only install one version of TensorFlow or TensorBoard, otherwise you may encounter the "duplicate plugins" error described {ref}`below `. See for more information on installing TensorBoard. -Nightly version of TensorBoard profiler requires nightly tensorflow and -tensorboard +Profiling with the nightly version of TensorBoard requires the nightly +tensorboard profiler plugin ```shell -pip install tf-nightly tb-nightly tbp-nightly +pip install tb-nightly tbp-nightly ``` ### Programmatic capture @@ -156,7 +154,7 @@ example. You can specify a different port with the `--port` flag. See Then, either select "Profile" in the upper-right dropdown menu, or go directly to . Available traces appear in the "Runs" dropdown menu on the left. Select the run you're interested in, and then under -"Tools", select `trace_viewer`. You should now see a timeline of the +"Tools", select `trace_viewer`. You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events to see more details at the bottom. 
See [these TensorFlow docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
@@ -308,8 +306,8 @@ replace, so it may be necessary to uninstall everything and reinstall a
single version:

```shell
-pip uninstall tensorflow tf-nightly tensorboard tb-nightly
+pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile tbp-nightly
-pip install tensorflow
+pip install tensorboard tensorboard-plugin-profile
```

## Nsight

From 47b13c77b3a1e15ace8dacdba1dda5fc388d7835 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Tue, 6 May 2025 07:44:03 +0800
Subject: [PATCH 1013/1769] Update sharded-computation.ipynb

---
 docs/sharded-computation.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb
index c9e947d374f0..1bae4014b5a8 100644
--- a/docs/sharded-computation.ipynb
+++ b/docs/sharded-computation.ipynb
@@ -380,7 +380,7 @@
    "id": "Q4N5mrr9i_ki"
   },
   "source": [
-    "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
+    "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `4`, the second on `1` and `5`, and so on.\n",
    "\n",
    "## 2. Explicit sharding\n",
    "\n",

From d9aee1521761b066b6a0f1695d3dd604c06f8cbe Mon Sep 17 00:00:00 2001
From: Zac Cranko
Date: Mon, 5 May 2025 18:05:17 -0700
Subject: [PATCH 1014/1769] PR #28157: Warn on excessive captured constants

Imported from GitHub PR https://github.com/jax-ml/jax/pull/28157

One of the most common ways users end up with extended lowering times is by unintentionally capturing, rather than tracing, a large number of constants, e.g. capturing the model weights as part of the computation. To address this we introduce two new flags:

- `jax_captured_constants_warn_bytes` defaults to `2 * 10 ** 9` (2GB). The total number of bytes of captured constants before a warning is issued. (Note that the maximum binary size XLA can serialize for the compilation cache is 2GB.)
- `jax_captured_constants_report_frames` defaults to `0`. If a warning is issued, how many stack frames to report for each constant. A value of `0` (the default) means the report is not generated; `-1` reports all frames.

Both messages are issued using `warnings.warn`.

The envisioned workflow for debugging captured constants is as follows:

1. The user is alerted to the problem by the initial warning (issued by default). This looks something like:
```
UserWarning: A large amount of constants were captured during lowering (125.00GB total). If this is intentional, disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1. To obtain a report of where these constants were encountered, set JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1.
```

2. The user sets `JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1` to obtain debugging information that locates the source of the constants. Setting `JAX_CAPTURED_CONSTANTS_REPORT_FRAMES` to a small positive integer reports only a suffix of the captured frames. Upon rerunning the code, the report generated will be something like this:

```
UserWarning: A large amount of constants were captured during lowering (125.00GB total). If this is intentional, disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1. The subsequent report may be disabled by setting JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=0.
Largest 5 allocation(s):
Constant , float32[1439,721,720], 5.00GB captured at:
/home/user/project/main.py:193 ()
/home/user/project/main.py:156 (run_export)
/home/user/project/main.py:147 (run_forward_exported)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:98 (tree_map_over_nonscalars)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:97 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:661 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:266 (inverse_transform)

Constant , float32[1439,721,720], 5.00GB captured at:
/home/user/project/main.py:193 ()
/home/user/project/main.py:156 (run_export)
/home/user/project/main.py:147 (run_forward_exported)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:98 (tree_map_over_nonscalars)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:97 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:661 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:266 (inverse_transform)
```

and so forth. The report is hard-coded to only report the top 5 (ordered by `nbytes`) largest constants.

3. The user can then set a breakpoint at the last line listed, and inspect the operands to find the problematic constant by the shape and type we report.

The reason for the two-stage process is that it is very easy and cheap to check the number of bytes each time we lower a function. However, generating the report involves a traversal of the Jaxpr, and so is not.

For Googlers [b/403532544](http://b/403532544#comment26)
Merging this change closes #28157

PiperOrigin-RevId: 755142222
---
 jax/BUILD | 1 +
 jax/_src/config.py | 22 +++++++++++
 jax/_src/interpreters/mlir.py | 72 +++++++++++++++++++++++++++++------
 jax/_src/jaxpr_util.py | 32 ++++++++++++++++
 jax/_src/util.py | 10 +++++
 tests/jax_jit_test.py | 19 +++++++++
 6 files changed, 144 insertions(+), 12 deletions(-)

diff --git a/jax/BUILD b/jax/BUILD
index 6358cce78281..99cf508faf11 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -650,6 +650,7 @@ pytype_strict_library(
         ":core",
         ":dtypes",
         ":effects",
+        ":jaxpr_util",
         ":layout",
         ":mesh",
         ":op_shardings",
diff --git a/jax/_src/config.py b/jax/_src/config.py
index 2abb457ef115..e79993958349 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -970,6 +970,28 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
           'to disable any debuggers while leak checking is enabled.'))
 checking_leaks = functools.partial(check_tracer_leaks, True)

+
+captured_constants_warn_bytes = int_state(
+    name='jax_captured_constants_warn_bytes',
+    default=2 * 10 ** 9,
+    help=('The number of bytes of parameters that may be captured as constants '
+          'before a warning is issued. Defaults to approximately 2GB. '
+          'Set to -1 to disable issuing a warning.'
+          )
+)
+
+captured_constants_report_frames = int_state(
+    name='jax_captured_constants_report_frames',
+    default=0,
+    help=('The number of stack frames reported for each captured constant '
+          'indicating the file and operation where the constant was captured. '
+          'Set to -1 to print the complete set of frames, or 0 to disable. '
+          'N.b. the report is only generated if the total amount of captured '
+          'constants exceeds `jax_captured_constants_warn_bytes`, as it is expensive '
+          'to generate the report.'
+ ) +) + debug_nans = bool_state( name='jax_debug_nans', default=False, diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index ccbe84bcfa4e..e9deb8d3fff9 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -16,11 +16,12 @@ from __future__ import annotations import collections -import contextlib from collections.abc import Callable, Iterable, Iterator, Sequence +import contextlib import dataclasses import functools from functools import partial +import heapq import io import itertools import operator @@ -31,14 +32,13 @@ from typing import Any, NamedTuple, Protocol, Union, cast as type_cast import warnings -import numpy as np - from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import effects as effects_lib +from jax._src import jaxpr_util from jax._src import linear_util as lu from jax._src import path from jax._src import sharding_impls @@ -48,19 +48,20 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout -from jax._src.partition_spec import PartitionSpec -from jax._src.mesh import AxisType -from jax._src.sharding import Sharding as JSharding -from jax._src.sharding_impls import (AUTO, NamedSharding, - modify_sdy_sharding_wrt_axis_types, - SdyArraySharding, SdyArrayShardingList) -from jax._src.util import foreach -from jax._src.lib import xla_client as xc from jax._src.lib import _jax +from jax._src.lib import xla_client as xc from jax._src.lib.mlir import dialects, ir, passmanager -from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir import register_jax_dialects +from jax._src.lib.mlir.dialects import func as func_dialect, hlo +from jax._src.mesh import AxisType +from jax._src.partition_spec import PartitionSpec +from jax._src.sharding import Sharding as JSharding +from jax._src.sharding_impls import ( AUTO, NamedSharding, + SdyArraySharding, SdyArrayShardingList, + modify_sdy_sharding_wrt_axis_types) from jax._src.state.types import AbstractRef +from jax._src.util import foreach +import numpy as np # mypy: ignore-errors @@ -1104,6 +1105,51 @@ def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants: unconstrained_dims=unconstrained_dims) +def check_jaxpr_constants(closed_jaxpr: core.ClosedJaxpr): + """Check if a JAXPR contains an excessive amount of constants, if so, report where they were captured""" + if (threshold := config.captured_constants_warn_bytes.value) == -1: + return + + # need the unaesthetic getter here as some of the consts in the test suite are arbitrary objects + total_iter, nbytes_iter = itertools.tee( + map(lambda c: getattr(c, "nbytes", 0), closed_jaxpr.consts) + ) + + if (total_bytes := sum(total_iter)) < threshold: + return + + message = ( + "A large amount of constants were captured during lowering" + f" ({util.pprint_bytes(total_bytes)} total). If this is intentional," + " disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1. " + ) + + if not (num_frames := config.captured_constants_report_frames.value): + message += ( + "To obtain a report of where these constants were encountered, " + "set JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1." 
+    )
+    warnings.warn(message)
+    return
+
+  message += (
+      "The subsequent report may be disabled by setting JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=0.\n\n"
+      f"Largest {min(5, len(closed_jaxpr.consts))} allocation(s):\n"
+  )
+  try:
+    nbytes_var_const = zip(nbytes_iter, closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts)
+    for nbytes, var, const in heapq.nlargest(5, nbytes_var_const, key=operator.itemgetter(0)):
+      message += f"  Constant {type(const)}, {var.aval.str_short()}, {util.pprint_bytes(nbytes)} captured at:\n"
+
+      for eqn in jaxpr_util.eqns_using_var(closed_jaxpr.jaxpr, var):
+        call_frame_source_info = source_info_util.summarize(eqn.source_info, num_frames)
+        message += " " * 2 + call_frame_source_info.replace("\n", "\n" + " " * 2) + "\n\n"
+
+    warnings.warn(message)
+  except Exception as exc:
+    warnings.warn(message + f" Exception raised while generating report: {exc}")
+
+
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
@@ -1441,6 +1487,8 @@ def lower_jaxpr_to_fun(
     MLIR func op
   """
   util.test_event("lower_jaxpr_to_fun", name)
+  check_jaxpr_constants(jaxpr)
+
   # The first dimension variable may be the platform index
   num_dim_vars = len(ctx.shape_poly_state.dim_vars)
   dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py
index ab72634d3bdf..a6c93c8c120c 100644
--- a/jax/_src/jaxpr_util.py
+++ b/jax/_src/jaxpr_util.py
@@ -209,3 +209,35 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
       for _, eqn in all_eqns(jaxpr)
   )
   return _pprof_profile(d)
+
+def eqns_using_var_with_invar_index(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[tuple[core.JaxprEqn, int]]:
+  """Find all the equations which use invar and the positional index of its binder"""
+  for eqn in jaxpr.eqns:
+    for invar_index, eqn_var in enumerate(eqn.invars):
+      if eqn_var == invar:
+        yield eqn, invar_index
+        break  # we found the var, no need to keep looking in this eqn
+
+def jaxpr_and_binder_in_params(params, index: int) -> Iterator[tuple[core.Jaxpr, core.Var]]:
+  for val in params.values():
+    vals = val if isinstance(val, tuple) else (val,)
+    for v in vals:
+      if isinstance(v, core.Jaxpr):
+        if index >= len(v.invars):
+          raise RuntimeError(f"Failed to find index {index} in jaxpr.invars while building report")
+        yield v, v.invars[index]
+      elif isinstance(v, core.ClosedJaxpr):
+        if index >= len(v.jaxpr.invars):
+          raise RuntimeError(f"Failed to find index {index} in jaxpr.invars while building report")
+        yield v.jaxpr, v.jaxpr.invars[index]
+
+def eqns_using_var(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[core.JaxprEqn]:
+  """Find the leaf equations using a variable"""
+  # The complexity of this call is because the invar might originate from a nested jaxpr
+  for eqn, invar_index in eqns_using_var_with_invar_index(jaxpr, invar):
+    if (child_jaxprs_and_vars := tuple(jaxpr_and_binder_in_params(eqn.params, invar_index))):
+      for (jaxpr, invar) in child_jaxprs_and_vars:
+        yield from eqns_using_var(jaxpr, invar)
+    else:
+      # if the previous condition fails, there is no deeper jaxpr to explore =(
+      yield eqn
diff --git a/jax/_src/util.py b/jax/_src/util.py
index 94b6380b13d4..34f748544d6d 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -20,6 +20,7 @@
 from functools import partial
 import itertools as it
 import logging
+import math
 import operator
 from typing import (Any, Generic, SupportsIndex, TypeVar, overload, TYPE_CHECKING, cast)
 import weakref
@@ -675,3 +676,12 @@ def test_event(name: str, *args) 
-> None: if hasattr(jaxlib_utils, "Mutex"): Mutex = jaxlib_utils.Mutex + + +def pprint_bytes(num_bytes: int | float) -> str: + prefixes = ("", "K", "M", "G", "T") + if num_bytes <= 0: + return "0.00B" + exponent = min(math.floor(math.log(num_bytes, 1000)), len(prefixes) - 1) + scaled_value = num_bytes / (1000**exponent) + return f"{scaled_value:.2f}{prefixes[exponent]}B" diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index 5946d557d4ba..cbf7c710f0e8 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -227,6 +227,25 @@ def fn(x): self.assertArraysEqual(v1, v1_expected) self.assertArraysEqual(v2, v2_expected) + def test_check_for_large_number_of_constants(self): + y = jnp.ones((128, 128)) + x = jnp.zeros((128,)) + + def jit_maker(): # need to ensure we lower at each test + def func(x): + return x @ y + return jax.jit(func) + + with self.assertWarnsRegex(UserWarning, "A large amount of constants were captured during lowering"): + with config.captured_constants_warn_bytes(y.nbytes): + jit_maker()(x) + + with self.assertNoWarnings(): + with config.captured_constants_warn_bytes(y.nbytes + 1): + jit_maker()(x) + + with config.captured_constants_warn_bytes(-1): + jit_maker()(x) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 3193ed39646f6c7e2f898801648b9c034225b962 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 5 May 2025 18:26:10 -0700 Subject: [PATCH 1015/1769] Move attrs.py code to `_src/attrs.py` so that we can remove experimental imports in pjit.py which are in the middle of tracing code and can affect performance. PiperOrigin-RevId: 755147927 --- jax/BUILD | 1 + jax/_src/attrs.py | 400 +++++++++++++++++++++++++++++ jax/_src/lax/control_flow/loops.py | 3 +- jax/_src/pjit.py | 9 +- jax/experimental/attrs.py | 393 +--------------------------- tests/attrs_test.py | 2 +- 6 files changed, 412 insertions(+), 396 deletions(-) create mode 100644 jax/_src/attrs.py diff --git a/jax/BUILD b/jax/BUILD index 99cf508faf11..30e071b81a55 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -268,6 +268,7 @@ py_library_providing_imports_info( "_src/ad_checkpoint.py", "_src/api.py", "_src/array.py", + "_src/attrs.py", "_src/blocked_sampler.py", "_src/buffer_callback.py", "_src/callback.py", diff --git a/jax/_src/attrs.py b/jax/_src/attrs.py new file mode 100644 index 000000000000..db738ee6368d --- /dev/null +++ b/jax/_src/attrs.py @@ -0,0 +1,400 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Any, Callable + +import jax +from jax._src import core +from jax._src import source_info_util +from jax._src import api_util +from jax._src import linear_util as lu +from jax._src.ad_util import (Zero) +from jax._src.api_util import flatten_fun_nokwargs +from jax._src.interpreters import ad +from jax._src.interpreters import partial_eval as pe +from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, + treedef_tuple) +from jax._src.util import unzip2, safe_map, safe_zip, split_list +from jax._src.dtypes import dtype, float0 + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +Array = Any +JaxVal = Any +PyTree = Any +PyTreeDef = Any + +ReadWrite = pe.ReadWrite +Append = pe.Append + +register = api_util.register_class_with_attrs +dne_sentinel = pe.dne_sentinel + +def jax_getattr(obj: Any, attr: str) -> PyTree: + with core.take_current_trace() as t: + return t.process_getattr(obj, attr) + +def jax_setattr(obj: Any, attr: str, val: PyTree) -> None: + with core.take_current_trace() as t: + return t.process_setattr(obj, attr, val) + +def jax_appendattr(obj: Any, attr: str, val: Array) -> None: + return jax_extendattr(obj, attr, jax.numpy.expand_dims(val, 0)) + +def jax_extendattr(obj: Any, attr: str, val: Array) -> None: + with core.take_current_trace() as t: + return t.process_extendattr(obj, attr, val) + +def _getattr_impl(_, obj, attr): + return getattr(obj, attr) +core.EvalTrace.process_getattr = _getattr_impl + +def _setattr_impl(_, obj, attr, val): + setattr(obj, attr, val) +core.EvalTrace.process_setattr = _setattr_impl + +def _extendattr_impl(_, obj, attr, val): + cur = getattr(obj, attr, dne_sentinel) + if cur is dne_sentinel: + new = val + else: + _check_append_type_agreement(obj, attr, core.typeof(cur), core.typeof(val)) + new = jax.numpy.concatenate([cur, val]) + setattr(obj, attr, new) +core.EvalTrace.process_extendattr = _extendattr_impl + +def _check_append_type_agreement(_, attr, curtype, valtype): + expected = core.mapped_aval(curtype.shape[0], 0, curtype) + got = core.mapped_aval(valtype.shape[0], 0, valtype) + if not core.typematch(expected, got): + raise TypeError( + f"can only append to attr {attr} with values of trailing shape " + f"{expected.str_short()}, but appendattr got value of type " + f"{valtype.str_short()} which has trailing shape {got.str_short()}.") + +def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str, + kind: pe.AttrKind): + frame = trace.frame + source_info = source_info_util.current() + + def new_tracer(x): + aval = core.get_aval(x) + tracer = pe.DynamicJaxprTracer(trace, aval, source_info) + var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) + frame.attrs_vars.append(var) + frame.tracers.append(tracer) + return tracer + + if (obj, attr, Append) in frame.attrs_tracked: + raise TypeError(f"can't read/write to append-only attr {attr}") + + if (obj, attr, kind) not in frame.attrs_tracked: + init_val = getattr(obj, attr, dne_sentinel) + frame.attrs_inits.append(init_val) + init_vals, init_tree = tree_flatten(init_val) + tracers = map(new_tracer, init_vals) + setattr(obj, attr, tree_unflatten(init_tree, tracers)) + frame.attrs_tracked.append((obj, attr, kind)) +pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked + +def _getattr_staging(trace, obj, attr): + trace._ensure_tracked(obj, attr, ReadWrite) + return getattr(obj, attr) +pe.DynamicJaxprTrace.process_getattr = _getattr_staging + +def _setattr_staging(trace, obj, attr, val): + 
trace._ensure_tracked(obj, attr, ReadWrite)
+  setattr(obj, attr, val)
+pe.DynamicJaxprTrace.process_setattr = _setattr_staging
+
+def _extendattr_staging(trace, obj, attr, val):
+  frame = trace.frame
+
+  if (obj, attr, ReadWrite) in frame.attrs_tracked:
+    raise TypeError(f"can't append to read/write-only attr {attr}")
+
+  first_write = (obj, attr, Append) not in frame.attrs_tracked
+  init_val = getattr(obj, attr, dne_sentinel)
+  if init_val is not dne_sentinel:
+    _check_append_type_agreement(obj, attr, core.typeof(init_val), core.typeof(val))
+  if first_write:
+    frame.attrs_inits.append(init_val)
+    frame.attrs_tracked.append((obj, attr, Append))
+    tracer = val
+  else:
+    assert init_val is not dne_sentinel
+    with core.set_current_trace(trace):
+      tracer = jax.numpy.concatenate([init_val, val])
+  setattr(obj, attr, tracer)
+pe.DynamicJaxprTrace.process_extendattr = _extendattr_staging
+
+
+def jvp(f, primals, tangents, attr_tangents):
+  attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents)
+  attr_primals = tuple(jax_getattr(o, a) for o, a in attrs)
+  primals_flat, in_tree = tree_flatten((attr_primals, *primals))
+  tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents))
+  if in_tree != in_tree_: raise Exception
+  dbg = api_util.debug_info("attrs_jvp", f, primals, {})
+  f_, out_tree = flatten_fun_nokwargs(
+      _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree)
+  out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
+      primals_flat, tangents_flat)
+  out_primals = tree_unflatten(out_tree(), out_primals_flat)
+  out_tangents = tree_unflatten(out_tree(), out_tangents_flat)
+  return out_primals, out_tangents, tangent_attrs_out
+
+@lu.transformation2
+def _set_attrs(f, attrs, attr_vals, *args):
+  for (o, a), x in zip(attrs, attr_vals):
+    jax_setattr(o, a, x)
+  return f(*args)
+
+def _jvp(fun: lu.WrappedFun):
+  return jvpfun2(jvp_subtrace2(fun))
+
+@lu.transformation2
+def jvpfun2(f, primals, tangents):
+  tag = core.TraceTag()
+  tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
+              and dtype(t) == float0 else t for t in tangents]
+  ctx = source_info_util.transform_name_stack('jvp')
+  with ctx:
+    out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents)
+  return out_primals, out_tangents, tangent_attrs_out
+
+@lu.transformation2
+def jvp_subtrace2(f, tag, primals, tangents):
+  with core.take_current_trace() as parent_trace:
+    trace = ad.JVPTrace(parent_trace, tag)
+    tag.attrs_tracked = []  # attrs written to
+    in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
+                  for x, t in zip(primals, tangents)]
+    with core.set_current_trace(trace):
+      ans = f(*in_tracers)
+    out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
+    tangent_attrs_out = []
+    for (obj, name) in tag.attrs_tracked:
+      primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name))
+      jax_setattr(obj, name, primal)
+      if type(tangent) is not ad.Zero:
+        tangent_attrs_out.append((obj, name, tangent))
+    del tag.attrs_tracked
+  return out_primals, out_tangents, tangent_attrs_out
+
+def _setattr_jvp(trace, obj, attr, maybe_tracer):
+  primal, tangent = trace.to_primal_tangent_pair(maybe_tracer)
+  if isinstance(tangent, ad.Zero):
+    return setattr(obj, attr, primal)
+  if (obj, attr) not in trace.tag.attrs_tracked:
+    trace.tag.attrs_tracked.append((obj, attr))
+  return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent))
+ad.JVPTrace.process_setattr = _setattr_jvp
+
+def _getattr_jvp(trace, obj, attr):
+  return 
getattr(obj, attr) +ad.JVPTrace.process_getattr = _getattr_jvp + +ad.LinearizeTrace.process_setattr = _setattr_jvp +ad.LinearizeTrace.process_getattr = _getattr_jvp + +def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []): + attr_primals = [jax_getattr(o, a) for o, a in attrs] + attr_avals = [core.get_aval(p) for p in attr_primals] + primals_flat, in_tree = tree_flatten(primals) + tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) + dbg = api_util.debug_info("attrs linearize", f, primals, {}) + f_, out_tree = flatten_fun_nokwargs( + _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) + primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( + f_, *attr_primals, *primals_flat) + f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), + attrs, attrs_out) + return tree_unflatten(out_tree(), primal_out), f_lin + +def _linearize(traceable: lu.WrappedFun, *primals): + jvpfun, attrs = _split_attrs(_jvp(traceable)) + in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) + + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) + for p in primals)) + _, in_tree = tree_flatten((primals, primals)) + jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) + jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) + out_primals_pvals, out_tangents_pvals, out_tangent_attr_pvals = \ + tree_unflatten(out_tree(), out_pvals) + out_primals_consts = [pval.get_known() for pval in out_primals_pvals] + return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], + jaxpr, consts, attrs()) + +@lu.transformation_with_aux2 +def _split_attrs(f, store, *args, **kwargs): + primals, tangents, tangent_attrs = f(*args, **kwargs) + attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) + store.store(attrs) + return primals, tangents, tangent_attr_vals + +def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): + in_tree, out_tree = io_tree + def f_lin(*tangents, attr_tangents): + if set(attr_tangents) - set(in_attrs): raise Exception + tangents_, in_tree_ = tree_flatten(tangents) + assert in_tree == in_tree_ + attr_tangents_ = [attr_tangents.get(a, ad.Zero(aval)) + for a, aval in zip(in_attrs, attr_avals)] + out = core.eval_jaxpr(jaxpr, consts, *attr_tangents_, *tangents_) + out_ = iter(out) + out = [p.get_known() if p.is_known() else next(out_) for p in out_pvals] + assert next(out_, None) is None + tangents_out, attr_tangents_out = split_list(out, [len(out)-len(out_attrs)]) + out_ct = tree_unflatten(out_tree, tangents_out) + return out_ct, dict(zip(out_attrs, attr_tangents_out)) + return f_lin + + +def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): + attr_primals = [jax_getattr(o, a) for o, a in attrs] + primals_flat, in_tree = tree_flatten(primals) + tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) + dbg = api_util.debug_info("attrs vjp", f, primals, {}) + f_, out_tree = flatten_fun_nokwargs( + _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) + primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( + f_, *attr_primals, *primals_flat) + attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval() + for o, a in attrs_out] + f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), + attrs, attrs_out) + return tree_unflatten(out_tree(), primal_out), f_vjp + +def _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): + in_tree, out_tree = io_tree + 
dummies = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] + def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}): + out_cts, out_tree_ = tree_flatten(out_ct) + assert out_tree == out_tree_ + attr_cts = [attr_cotangents.get(a, ad.Zero(aval)) + for a, aval in zip(out_attrs, attr_avals)] + out = ad.backward_pass(jaxpr, (), consts, dummies, (*out_cts, *attr_cts)) + in_attr_bars, arg_cts = split_list(out, [len(in_attrs)]) + args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) + return args_ct, dict(zip(in_attrs, in_attr_bars)) + return f_vjp + + +class Box: + _val: PyTree + _tag: core.OpaqueTraceState + def __init__(self, val): + self._val = val + self._tag = core.get_opaque_trace_state() + def get(self): + with core.take_current_trace() as t: + return t.process_box_get(self) + def set(self, val): + with core.take_current_trace() as t: + return t.process_box_set(self, val) + +def _box_get_impl(trace, box): + return box._val +core.EvalTrace.process_box_get = _box_get_impl + +def _box_set_impl(trace, box, val): + box._val = val +core.EvalTrace.process_box_set = _box_set_impl + +def _is_local(trace, box): + is_arg = box._tag._trace_ref() is trace + if is_arg: assert box._tag._trace_ref() is trace + return is_arg + +def _box_get_staging(trace, box): + if not _is_local(trace, box): + trace._ensure_tracked(box, '_val', pe.BoxAttr) + return box._val +pe.DynamicJaxprTrace.process_box_get = _box_get_staging + +def _box_set_staging(trace, box, val): + if not _is_local(trace, box): + trace._ensure_tracked(box, '_val', pe.BoxAttr) + box._val = val +pe.DynamicJaxprTrace.process_box_set = _box_set_staging + +def _box_get_jvp(trace, box): + return box._val +ad.JVPTrace.process_box_get = _box_get_jvp + +def _box_set_jvp(trace, box, val): + primal, tangent = trace.to_primal_tangent_pair(val) + if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): + raise Exception + if isinstance(tangent, ad.Zero): + box._val = primal + else: + box._val = ad.JVPTracer(trace, primal, tangent) +ad.JVPTrace.process_box_set = _box_set_jvp + +def _box_get_linearize(trace, box): + return box._val +ad.LinearizeTrace.process_box_get = _box_get_linearize + +def _box_set_linearize(trace, box, val): + primal, tangent = trace.to_primal_tangent_pair(val) + if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): + raise Exception + if isinstance(tangent, ad.Zero): + box._val = primal + else: + raise NotImplementedError # TODO + box._val = ad.LinearizeTracer(trace, primal, tangent) +ad.LinearizeTrace.process_box_set = _box_set_linearize + + +class List: + _val: PyTree + _tag: core.OpaqueTraceState + _is_arg: bool + def __init__(self, val=None): + self._val = [] if val is None else val[:] + self._tag = core.get_opaque_trace_state() + self._is_arg = False + def append(self, val): + with core.take_current_trace() as t: + return t.process_list_append(self, val) + def get(self): + with core.take_current_trace() as t: + if _is_local(t, self) and not self._is_arg: + return self._val[:] # defensive copy in case caller erroneously mutates + raise Exception("can't read the value of a List that was not created in " + "this scope") +AppendList = List + +def _list_append_impl(trace, lst, val): + lst._val.append(val) +core.EvalTrace.process_list_append = _list_append_impl + +def _list_append_staging(trace, lst, val): + if not _is_local(trace, lst): + _ensure_list_tracked(trace, lst) + return _list_append_impl(trace, lst, val) +pe.DynamicJaxprTrace.process_list_append = _list_append_staging + +def 
_ensure_list_tracked(trace, lst): + frame = trace.frame + if (lst, '_val', pe.ListAttr) not in frame.attrs_tracked: + frame.attrs_inits.append(lst._val) + frame.attrs_tracked.append((lst, '_val', pe.ListAttr)) + lst._val = [] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index da3fa6fa1019..05e9c010dc51 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -59,6 +59,7 @@ from jax._src.traceback_util import api_boundary from jax._src.tree_util import equality_errors from jax._src.typing import Array +from jax._src.attrs import jax_setattr, jax_getattr, jax_extendattr from jax._src.util import ( merge_lists, partition_list, safe_map, safe_zip, split_list, split_list_checked, unzip2, weakref_lru_cache,) @@ -372,7 +373,6 @@ def _create_jaxpr(init): return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr, jax_extendattr valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked]) for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): if kind is pe.ReadWrite: @@ -392,7 +392,6 @@ def _set_states(attrs_tracked, vals): assert False def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr vals = [] for treedef, _, (obj, attr, kind) in attrs_tracked: if kind is pe.ReadWrite: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 79948cc1da06..81d8982184c1 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -80,6 +80,8 @@ HashableFunction, safe_map, safe_zip, wraps, tuple_insert, distributed_debug_log, split_list, split_list_checked, weakref_lru_cache, merge_lists, subs_list, fun_name, fun_qual_name) +from jax._src.attrs import (Box, List, dne_sentinel, jax_setattr, jax_getattr, + jax_extendattr) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -723,7 +725,6 @@ def _infer_params_internal( static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, signature=ji.fun_signature) - from jax.experimental.attrs import Box, List any_boxes = any(isinstance(x, (Box, List)) for x in tree_leaves((args, kwargs))) if config.dynamic_shapes.value or any_boxes: # don't use the cache p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg, @@ -1507,7 +1508,6 @@ def _attr_cache_index( fun: lu.WrappedFun, in_type: core.InputType | tuple[core.AbstractValue, ...] 
) -> int: - from jax.experimental.attrs import dne_sentinel cases = seen_attrs_get(fun, in_type) for i, records in enumerate(cases): for obj, attr, kind, treedef, avals in records: @@ -1521,7 +1521,6 @@ def _attr_cache_index( return len(cases) def _attr_cachedata_update(fun, in_type, i, attrs_tracked): - from jax.experimental.attrs import dne_sentinel leaves = lambda obj, attr: tree_leaves(getattr(obj, attr, dne_sentinel)) records = [(obj, attr, kind, init_tree, map(core.typeof, leaves(obj, attr))) for init_tree, _, (obj, attr, kind) in attrs_tracked] @@ -3177,7 +3176,6 @@ def get_unconstrained_dims(sharding: NamedSharding): # -------------------- attrs etc -------------------- def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr, jax_extendattr valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): if kind is pe.ReadWrite: @@ -3197,7 +3195,6 @@ def _set_states(attrs_tracked, vals): assert False def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr, dne_sentinel vals = [] for treedef, _, (obj, attr, kind) in attrs_tracked: if kind is pe.ReadWrite: @@ -3234,7 +3231,6 @@ class ListTree: treedef: PyTreeDef | None = static() def _flatten_boxes(dbg, args, kwargs): - from jax.experimental.attrs import Box, List # TODO(mattjj,dougalm): refine this implementation of box-handling... if all(not isinstance(x, (Box, List)) for x in tree_leaves((args, kwargs))): return args, kwargs, [] @@ -3271,7 +3267,6 @@ def visit(x): # Using obscure names is a temporary workaround; revise! @lu.transformation2 def _handle_boxes(__f, __dbg, *args, **kwargs): - from jax.experimental.attrs import Box, List f, dbg = __f, __dbg new_args = [] arg_mutables = [] diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index db738ee6368d..8984b2159d82 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -12,389 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import annotations - -from typing import Any, Callable - -import jax -from jax._src import core -from jax._src import source_info_util -from jax._src import api_util -from jax._src import linear_util as lu -from jax._src.ad_util import (Zero) -from jax._src.api_util import flatten_fun_nokwargs -from jax._src.interpreters import ad -from jax._src.interpreters import partial_eval as pe -from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, - treedef_tuple) -from jax._src.util import unzip2, safe_map, safe_zip, split_list -from jax._src.dtypes import dtype, float0 - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip - -Array = Any -JaxVal = Any -PyTree = Any -PyTreeDef = Any - -ReadWrite = pe.ReadWrite -Append = pe.Append - -register = api_util.register_class_with_attrs -dne_sentinel = pe.dne_sentinel - -def jax_getattr(obj: Any, attr: str) -> PyTree: - with core.take_current_trace() as t: - return t.process_getattr(obj, attr) - -def jax_setattr(obj: Any, attr: str, val: PyTree) -> None: - with core.take_current_trace() as t: - return t.process_setattr(obj, attr, val) - -def jax_appendattr(obj: Any, attr: str, val: Array) -> None: - return jax_extendattr(obj, attr, jax.numpy.expand_dims(val, 0)) - -def jax_extendattr(obj: Any, attr: str, val: Array) -> None: - with core.take_current_trace() as t: - return t.process_extendattr(obj, attr, val) - -def _getattr_impl(_, obj, attr): - return getattr(obj, attr) -core.EvalTrace.process_getattr = _getattr_impl - -def _setattr_impl(_, obj, attr, val): - setattr(obj, attr, val) -core.EvalTrace.process_setattr = _setattr_impl - -def _extendattr_impl(_, obj, attr, val): - cur = getattr(obj, attr, dne_sentinel) - if cur is dne_sentinel: - new = val - else: - _check_append_type_agreement(obj, attr, core.typeof(cur), core.typeof(val)) - new = jax.numpy.concatenate([cur, val]) - setattr(obj, attr, new) -core.EvalTrace.process_extendattr = _extendattr_impl - -def _check_append_type_agreement(_, attr, curtype, valtype): - expected = core.mapped_aval(curtype.shape[0], 0, curtype) - got = core.mapped_aval(valtype.shape[0], 0, valtype) - if not core.typematch(expected, got): - raise TypeError( - f"can only append to attr {attr} with values of trailing shape " - f"{expected.str_short()}, but appendattr got value of type " - f"{valtype.str_short()} which has trailing shape {got.str_short()}.") - -def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str, - kind: pe.AttrKind): - frame = trace.frame - source_info = source_info_util.current() - - def new_tracer(x): - aval = core.get_aval(x) - tracer = pe.DynamicJaxprTracer(trace, aval, source_info) - var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) - frame.attrs_vars.append(var) - frame.tracers.append(tracer) - return tracer - - if (obj, attr, Append) in frame.attrs_tracked: - raise TypeError(f"can't read/write to append-only attr {attr}") - - if (obj, attr, kind) not in frame.attrs_tracked: - init_val = getattr(obj, attr, dne_sentinel) - frame.attrs_inits.append(init_val) - init_vals, init_tree = tree_flatten(init_val) - tracers = map(new_tracer, init_vals) - setattr(obj, attr, tree_unflatten(init_tree, tracers)) - frame.attrs_tracked.append((obj, attr, kind)) -pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked - -def _getattr_staging(trace, obj, attr): - trace._ensure_tracked(obj, attr, ReadWrite) - return getattr(obj, attr) -pe.DynamicJaxprTrace.process_getattr = _getattr_staging - -def _setattr_staging(trace, obj, attr, val): - 
trace._ensure_tracked(obj, attr, ReadWrite) - setattr(obj, attr, val) -pe.DynamicJaxprTrace.process_setattr = _setattr_staging - -def _extendattr_staging(trace, obj, attr, val): - frame = trace.frame - - if (obj, attr, ReadWrite) in frame.attrs_tracked: - raise TypeError("can't append to read/write-only attr {attr}") - - first_write = (obj, attr, Append) not in frame.attrs_tracked - init_val = getattr(obj, attr, dne_sentinel) - if init_val is not dne_sentinel: - _check_append_type_agreement(obj, attr, core.typeof(init_val), core.typeof(val)) - if first_write: - frame.attrs_inits.append(init_val) - frame.attrs_tracked.append((obj, attr, Append)) - tracer = val - else: - assert init_val is not dne_sentinel - with core.set_current_trace(trace): - tracer = jax.numpy.concatenate([init_val, val]) - setattr(obj, attr, tracer) -pe.DynamicJaxprTrace.process_extendattr = _extendattr_staging - - -def jvp(f, primals, tangents, attr_tangents): - attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents) - attr_primals = tuple(jax_getattr(o, a) for o, a in attrs) - primals_flat, in_tree = tree_flatten((attr_primals, *primals)) - tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents)) - if in_tree != in_tree_: raise Exception - dbg = api_util.debug_info("attrs_jvp", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree) - out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped( - primals_flat, tangents_flat) - out_primals = tree_unflatten(out_tree(), out_primals_flat) - out_tangents = tree_unflatten(out_tree(), out_tangents_flat) - return out_primals, out_tangents, tangent_attrs_out - -@lu.transformation2 -def _set_attrs(f, attrs, attr_vals, *args): - for (o, a), x in zip(attrs, attr_vals): - jax_setattr(o, a, x) - return f(*args) - -def _jvp(fun: lu.WrappedFun): - return jvpfun2(jvp_subtrace2(fun)) - -@lu.transformation2 -def jvpfun2(f, primals, tangents): - tag = core.TraceTag() - tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) - and dtype(t) == float0 else t for t in tangents] - ctx = source_info_util.transform_name_stack('jvp') - with ctx: - out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents) - return out_primals, out_tangents, tangent_attrs_out - -@lu.transformation2 -def jvp_subtrace2(f, tag, primals, tangents): - with core.take_current_trace() as parent_trace: - trace = ad.JVPTrace(parent_trace, tag) - tag.attrs_tracked = [] # attrs written to - in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x - for x, t in zip(primals, tangents)] - with core.set_current_trace(trace): - ans = f(*in_tracers) - out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) - tangent_attrs_out = [] - for (obj, name) in tag.attrs_tracked: - primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name)) - jax_setattr(obj, name, primal) - if type(tangent) is not ad.Zero: - tangent_attrs_out.append((obj, name, tangent)) - del tag.attrs_tracked - return out_primals, out_tangents, tangent_attrs_out - -def _setattr_jvp(trace, obj, attr, maybe_tracer): - primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) - if isinstance(tangent, ad.Zero): - return setattr(obj, attr, primal) - if (obj, attr) not in trace.tag.attrs_tracked: - trace.tag.attrs_tracked.append((obj, attr)) - return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent)) -ad.JVPTrace.process_setattr = _setattr_jvp - -def _getattr_jvp(trace, obj, attr): - return 
getattr(obj, attr) -ad.JVPTrace.process_getattr = _getattr_jvp - -ad.LinearizeTrace.process_setattr = _setattr_jvp -ad.LinearizeTrace.process_getattr = _getattr_jvp - -def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []): - attr_primals = [jax_getattr(o, a) for o, a in attrs] - attr_avals = [core.get_aval(p) for p in attr_primals] - primals_flat, in_tree = tree_flatten(primals) - tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) - dbg = api_util.debug_info("attrs linearize", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) - primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( - f_, *attr_primals, *primals_flat) - f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), - attrs, attrs_out) - return tree_unflatten(out_tree(), primal_out), f_lin - -def _linearize(traceable: lu.WrappedFun, *primals): - jvpfun, attrs = _split_attrs(_jvp(traceable)) - in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) - for p in primals)) - _, in_tree = tree_flatten((primals, primals)) - jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) - jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) - out_primals_pvals, out_tangents_pvals, out_tangent_attr_pvals = \ - tree_unflatten(out_tree(), out_pvals) - out_primals_consts = [pval.get_known() for pval in out_primals_pvals] - return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], - jaxpr, consts, attrs()) - -@lu.transformation_with_aux2 -def _split_attrs(f, store, *args, **kwargs): - primals, tangents, tangent_attrs = f(*args, **kwargs) - attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) - store.store(attrs) - return primals, tangents, tangent_attr_vals - -def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): - in_tree, out_tree = io_tree - def f_lin(*tangents, attr_tangents): - if set(attr_tangents) - set(in_attrs): raise Exception - tangents_, in_tree_ = tree_flatten(tangents) - assert in_tree == in_tree_ - attr_tangents_ = [attr_tangents.get(a, ad.Zero(aval)) - for a, aval in zip(in_attrs, attr_avals)] - out = core.eval_jaxpr(jaxpr, consts, *attr_tangents_, *tangents_) - out_ = iter(out) - out = [p.get_known() if p.is_known() else next(out_) for p in out_pvals] - assert next(out_, None) is None - tangents_out, attr_tangents_out = split_list(out, [len(out)-len(out_attrs)]) - out_ct = tree_unflatten(out_tree, tangents_out) - return out_ct, dict(zip(out_attrs, attr_tangents_out)) - return f_lin - - -def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): - attr_primals = [jax_getattr(o, a) for o, a in attrs] - primals_flat, in_tree = tree_flatten(primals) - tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) - dbg = api_util.debug_info("attrs vjp", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) - primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( - f_, *attr_primals, *primals_flat) - attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval() - for o, a in attrs_out] - f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), - attrs, attrs_out) - return tree_unflatten(out_tree(), primal_out), f_vjp - -def _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): - in_tree, out_tree = io_tree - 
dummies = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] - def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}): - out_cts, out_tree_ = tree_flatten(out_ct) - assert out_tree == out_tree_ - attr_cts = [attr_cotangents.get(a, ad.Zero(aval)) - for a, aval in zip(out_attrs, attr_avals)] - out = ad.backward_pass(jaxpr, (), consts, dummies, (*out_cts, *attr_cts)) - in_attr_bars, arg_cts = split_list(out, [len(in_attrs)]) - args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) - return args_ct, dict(zip(in_attrs, in_attr_bars)) - return f_vjp - - -class Box: - _val: PyTree - _tag: core.OpaqueTraceState - def __init__(self, val): - self._val = val - self._tag = core.get_opaque_trace_state() - def get(self): - with core.take_current_trace() as t: - return t.process_box_get(self) - def set(self, val): - with core.take_current_trace() as t: - return t.process_box_set(self, val) - -def _box_get_impl(trace, box): - return box._val -core.EvalTrace.process_box_get = _box_get_impl - -def _box_set_impl(trace, box, val): - box._val = val -core.EvalTrace.process_box_set = _box_set_impl - -def _is_local(trace, box): - is_arg = box._tag._trace_ref() is trace - if is_arg: assert box._tag._trace_ref() is trace - return is_arg - -def _box_get_staging(trace, box): - if not _is_local(trace, box): - trace._ensure_tracked(box, '_val', pe.BoxAttr) - return box._val -pe.DynamicJaxprTrace.process_box_get = _box_get_staging - -def _box_set_staging(trace, box, val): - if not _is_local(trace, box): - trace._ensure_tracked(box, '_val', pe.BoxAttr) - box._val = val -pe.DynamicJaxprTrace.process_box_set = _box_set_staging - -def _box_get_jvp(trace, box): - return box._val -ad.JVPTrace.process_box_get = _box_get_jvp - -def _box_set_jvp(trace, box, val): - primal, tangent = trace.to_primal_tangent_pair(val) - if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): - raise Exception - if isinstance(tangent, ad.Zero): - box._val = primal - else: - box._val = ad.JVPTracer(trace, primal, tangent) -ad.JVPTrace.process_box_set = _box_set_jvp - -def _box_get_linearize(trace, box): - return box._val -ad.LinearizeTrace.process_box_get = _box_get_linearize - -def _box_set_linearize(trace, box, val): - primal, tangent = trace.to_primal_tangent_pair(val) - if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): - raise Exception - if isinstance(tangent, ad.Zero): - box._val = primal - else: - raise NotImplementedError # TODO - box._val = ad.LinearizeTracer(trace, primal, tangent) -ad.LinearizeTrace.process_box_set = _box_set_linearize - - -class List: - _val: PyTree - _tag: core.OpaqueTraceState - _is_arg: bool - def __init__(self, val=None): - self._val = [] if val is None else val[:] - self._tag = core.get_opaque_trace_state() - self._is_arg = False - def append(self, val): - with core.take_current_trace() as t: - return t.process_list_append(self, val) - def get(self): - with core.take_current_trace() as t: - if _is_local(t, self) and not self._is_arg: - return self._val[:] # defensive copy in case caller erroneously mutates - raise Exception("can't read the value of a List that was not created in " - "this scope") -AppendList = List - -def _list_append_impl(trace, lst, val): - lst._val.append(val) -core.EvalTrace.process_list_append = _list_append_impl - -def _list_append_staging(trace, lst, val): - if not _is_local(trace, lst): - _ensure_list_tracked(trace, lst) - return _list_append_impl(trace, lst, val) -pe.DynamicJaxprTrace.process_list_append = _list_append_staging - -def 
_ensure_list_tracked(trace, lst): - frame = trace.frame - if (lst, '_val', pe.ListAttr) not in frame.attrs_tracked: - frame.attrs_inits.append(lst._val) - frame.attrs_tracked.append((lst, '_val', pe.ListAttr)) - lst._val = [] +from jax._src.attrs import ( + jax_setattr as jax_setattr, + jax_getattr as jax_getattr, + jax_appendattr as jax_appendattr, + Box as Box, + List as List, +) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index c48b377b076b..60a3753a7ba5 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -28,7 +28,7 @@ from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map -from jax.experimental import attrs +from jax._src import attrs from jax.experimental.attrs import ( jax_setattr, jax_getattr, jax_appendattr, Box, List) From 766e68c4813a30e29b4fcefaa3253a42d0e197be Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 5 May 2025 18:36:57 -0700 Subject: [PATCH 1016/1769] Remove obsolete DCHECK on num_computations_, now that executables may have no local devices. PiperOrigin-RevId: 755150351 --- jaxlib/py_executable.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index 9141111ae168..2902a2bd7be0 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -327,7 +327,6 @@ std::vector PyExecuteResults::ConsumeWithHandlers( std::vector outputs; auto ifrt_arrays = Consume(); auto traceback = Traceback::Get(); - DCHECK_GT(num_computations_, 0); int num_output_buffers = ifrt_arrays.size(); outputs.reserve(num_output_buffers); if (out_handlers.size() != num_output_buffers) { From cd513f253dff58ee4263267a8be1b4b09172840a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 6 May 2025 02:49:22 -0700 Subject: [PATCH 1017/1769] [Pallas:MGPU] Perform a warpgroup barrier before and after every SMEM write This allows us to guarantee the single-thread sequential semantics that we want to see in Pallas, even if it sometimes goes a little overboard with the barriers. However, there are situations when both are necessary! We barrier before we overwrite memory to ensure that all the warps are done reading from it before we do so. Conversely, we barrier after the store to make sure its effects are visible by reads issued from all other warps in the same Pallas thread (i.e. the warpgroup). I hope this should not lead to significant performance problems, since we generally only write from registers to SMEM once in the whole kernel (in the epilogue), and we usually had to perform a warpgroup barrier there too (as well as the async proxy fence). PiperOrigin-RevId: 755285818 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 +++--- jax/experimental/mosaic/gpu/dialect_lowering.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index cacca8d3d1d9..7ad912508ec4 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1299,6 +1299,7 @@ def _swap_lowering_rule( x_smem, transforms = _handle_transforms( x_smem, transforms, handle_transposes=not transposed_value ) + mgpu.warpgroup_barrier() # Make sure reads have completed before we write. 
match transforms: case ( gpu_core.UnswizzleRef(swizzle), @@ -1328,7 +1329,6 @@ def _swap_lowering_rule( layout=value.layout, ) value.store_tiled(x_smem, swizzle=swizzle) - return old_value case (): match value.layout: case mgpu.TiledLayout(): @@ -1339,15 +1339,15 @@ def _swap_lowering_rule( optimized=False, ) value.store_untiled(x_smem, optimized=False) - return old_value case _: old_value = mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(v_aval.dtype) ) value.store_untiled(x_smem) - return old_value case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") + mgpu.warpgroup_barrier() # Make sure the writes have completed. + return old_value @register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Warpgroup) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 5edccafec236..0cc94261e847 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -37,6 +37,7 @@ from jax._src.lib.mlir.dialects import vector from jax._src.util import safe_zip from jax.experimental.mosaic.gpu import layouts as layouts_lib +from jax.experimental.mosaic.gpu import utils as mgpu_utils import numpy as np from . import fragmented_array as fa @@ -399,6 +400,7 @@ def _vector_store_op_lowering_rule( vector_store_op.valueToStore, to_store_layout ) + mgpu_utils.warpgroup_barrier() # Make sure the reads have completed. if fragmented_array.layout == fa.WGMMA_LAYOUT: swizzle, transforms = swizzle_and_transforms_from_transforms_attr( inference_utils.in_transforms(vector_store_op)[0] @@ -417,6 +419,7 @@ def _vector_store_op_lowering_rule( raise ValueError( f"{vector_store_op} has an unsupported layout: {to_store_layout}" ) + mgpu_utils.warpgroup_barrier() # Make sure the writes have completed. return [] From 66fc5e88d9fbaffde90c659cbddcf85b60bc428d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 6 May 2025 03:53:25 -0700 Subject: [PATCH 1018/1769] [Mosaic GPU] Add support for arbitrary reductions of tiled layouts This significantly generalizes our ability to perform reductions, to the point where pretty much all tiled layouts can be handled out of the box. The code is slightly longer than the few special cases we've implemented in the past, but overall is much more general. This also includes a hypothesis test that verifies that we always return the right answers, even for randomly sampled layouts. The test could still be improved in that we'll skip all the cases where we fail to synthesize a load/store for the layout, but it's already caught a number of problems in the initial implementation. 
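To give a rough sense of the bookkeeping this relies on, here is a standalone sketch (plain Python mirroring the tiling arithmetic, not the Mosaic GPU API; the helper names are ours): tiles successively split the trailing dims of a shape, and pushing a zero stride through the same arithmetic marks every tiled dim that originates from a given logical dim, i.e. exactly the dims a reduction over that logical dim has to collapse.

```python
def tile_shape(tiles, shape):
  # Each tile refines the trailing dims of the current shape.
  for tile in tiles:
    untiled, tiled = shape[:-len(tile)], shape[-len(tile):]
    assert all(d % t == 0 for d, t in zip(tiled, tile))
    shape = (*untiled, *(d // t for d, t in zip(tiled, tile)), *tile)
  return shape

def tile_dimension(tiles, rank, dim):
  # A zero stride in position `dim` survives multiplication, so it tags
  # every tiled dim derived from that logical dim.
  strides = [1] * rank
  strides[dim] = 0
  for tile in tiles:
    untiled, tiled = strides[:-len(tile)], strides[-len(tile):]
    strides = [*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled]
  return tuple(s == 0 for s in strides)

print(tile_shape(((8, 8), (4, 1)), (16, 16)))  # (2, 2, 2, 8, 4, 1)
print(tile_dimension(((8, 8), (4, 1)), 2, 0))  # (True, False, True, False, True, False)
```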
PiperOrigin-RevId: 755303138 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- .../mosaic/gpu/dialect_lowering.py | 2 +- .../mosaic/gpu/fragmented_array.py | 365 ++++++++++++------ jax/experimental/mosaic/gpu/utils.py | 11 + tests/mosaic/BUILD | 2 +- tests/mosaic/gpu_test.py | 125 +++++- 6 files changed, 387 insertions(+), 120 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 7ad912508ec4..1d42c86efdbc 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1870,7 +1870,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError("No support for axes yet") scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype) with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]: - return x.reduce_sum(scratch) + return x.reduce("add", axes, scratch) case mgpu.WGMMA_LAYOUT: if axes != (x_aval.ndim - 1,): raise NotImplementedError diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 0cc94261e847..c1506bde32ea 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -472,7 +472,7 @@ def _vector_reduction_op_lowering_rule( ir.MemRefType.get([4], element_type, memory_space=smem), arith.constant(None, op.attributes["offset"]), ) - result = a.reduce_sum(scratch) + result = a.reduce("add", range(len(a.shape)), scratch) case ( "#vector.kind" | "#vector.kind" | "#vector.kind" ): diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index fbcef6e6ecb8..62c5903f475f 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -30,7 +30,6 @@ from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import math as mlir_math from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector import numpy as np @@ -68,15 +67,15 @@ class Tiling: def __post_init__(self): if not self.tiles: return - tiled_rank = len(self.tiles[0]) + last_tile_rank = len(self.tiles[0]) for tile in self.tiles: - if len(tile) > tiled_rank: - raise ValueError("Only the first tile can refer to value dimensions") + if len(tile) > last_tile_rank: + raise ValueError("Tiles must have a decreasing rank") if not tile: raise ValueError("Tiles must not be empty") if any(d <= 0 for d in tile): raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") - tiled_rank += len(tile) + last_tile_rank = len(tile) def __str__(self): return f"Tiling({''.join(map(str, self.tiles))})" @@ -118,6 +117,33 @@ def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) return strides + def tile_dimension(self, dim: int) -> tuple[bool, ...]: + """Result is True whenever the tiled dim originated from the given input dim.""" + tiling_rank = len(self.tiles[0]) + if dim < 0 or dim >= tiling_rank: + raise ValueError(f"Invalid dimension {dim} for tiling {self}") + strides = [1] * tiling_rank + strides[dim] = 0 + return tuple(s == 0 for s in self.tile_strides(tuple(strides))) + + def remove_dimension(self, dim: int) -> "Tiling": + """Returns a tiling with the given dimension removed.""" + tiling_rank = len(self.tiles[0]) + if dim < 0 or dim >= tiling_rank: + raise ValueError(f"Invalid dimension {dim} for tiling {self}") + 
dim_in_tile = dim + tiles = [] + last_tile_rank = len(self.tiles[0]) + for t in self.tiles: + assert last_tile_rank >= len(t) + dim_in_tile -= last_tile_rank - len(t) + if dim_in_tile >= 0: + t = t[:dim_in_tile] + t[dim_in_tile + 1:] + if not t: # If this tile is empty, all other tiles will be empty too. + break + tiles.append(t) + return Tiling(tuple(tiles)) + def tile_nested_shape_strides( self, shape: tuple[tuple[int, ...], ...], @@ -281,7 +307,7 @@ def __post_init__(self): for d in self.lane_dims ) if lane_dims_prod != WARP_SIZE: - raise ValueError + raise ValueError("The product of lane dims does not equal the warp size") @functools.cached_property def partitioned_lane_dims(self) -> tuple[int, ...]: @@ -405,6 +431,36 @@ def warp_indices(self) -> tuple[ir.Value, ...]: indices[self.warp_dim] = warp_idx return tuple(indices) + def remove_dimension(self, dim: int) -> TiledLayout: + if dim < 0 or dim >= len(self.tiling.tiles[0]): + raise ValueError(f"Dimension {dim} is out of range for {self.tiling}") + new_tiling = self.tiling.remove_dimension(dim) + tiled_shape = self.tiled_tiling_shape + removed_dim = self.tiling.tile_dimension(dim) + dim_offsets = np.cumsum(removed_dim[::-1])[::-1].tolist() + if removed_dim[self.vector_dim]: + new_tiling = Tiling((*new_tiling.tiles, (1,))) + new_vector_dim = -1 + dim_offsets = [o - 1 for o in dim_offsets] # We inserted an extra dim. + else: + new_vector_dim = self.vector_dim + dim_offsets[self.vector_dim] + def replace_tiled_dim(d: int | Replicated, size: int): + if isinstance(d, Replicated): + return d + elif removed_dim[d]: + return Replicated(size) + else: + return d + dim_offsets[d] + return TiledLayout( + new_tiling, + replace_tiled_dim(self.warp_dim, WARPS_IN_WARPGROUP), + tuple( + d if isinstance(d, Replicated) else replace_tiled_dim(d, tiled_shape[d]) + for d in self.lane_dims + ), + new_vector_dim, + ) + def _tiled_wgmma_layout(shape: tuple[int, ...]): """Returns the tiled layout relevant for WGMMA operations. @@ -539,9 +595,9 @@ def linear_thread_idxs(self): vector_dim=-1, ) WGMMA_ROW_LAYOUT = TiledLayout( - Tiling(((64,), (16,), (8,), (1,))), - warp_dim=-4, - lane_dims=(-2, Replicated(4)), + Tiling(((64,), (16,), (8,), (1,), (1,))), + warp_dim=-5, + lane_dims=(-3, Replicated(4)), vector_dim=-1, ) @@ -1543,74 +1599,26 @@ def upcast_to_bf16(reg, high): _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) - # NOTE: scratch can be reused immediately once this function returns. 
-  def reduce_sum(self, scratch: ir.Value | None = None):
-    if isinstance(self.layout, WGSplatFragLayout):
-      [reg] = self.registers.flat
-      if ir.FloatType.isinstance(self.mlir_dtype):
-        op = mulf
-      elif ir.IntegerType.isinstance(self.mlir_dtype):
-        op = arith.muli
-      else:
-        raise NotImplementedError(self.mlir_dtype)
-      return FragmentedArray.splat(
-          op(reg, utils.c(math.prod(self.shape), self.mlir_dtype)),
-          (),
-          is_signed=self.is_signed,
-      )
-
-    if not isinstance(self.layout, WGStridedFragLayout):
-      raise NotImplementedError(f"Unsupported layout {self.layout}")
-
-    if scratch is None:
-      raise ValueError("scratch must be provided")
-
-    if ir.FloatType.isinstance(self.mlir_dtype):
-      op = addf
-    elif ir.IntegerType.isinstance(self.mlir_dtype):
-      op = arith.addi
-    else:
-      raise NotImplementedError(self.mlir_dtype)
-
-    result = c(0, self.mlir_dtype)
-    for reg in self.registers:
-      result = op(
-          result,
-          vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
-      )
-    scratch_ty = ir.MemRefType(scratch.type)
-    if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
-      raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
-
-    index = ir.IndexType.get()
-    warp_result = utils.warp_tree_reduce(result, op, 32)
-    warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
-    memref.store(warp_result, scratch, [warp_id])
-    utils.warpgroup_barrier()
-    zero_index = c(0, index)
-    with mgpu.single_thread(scope=mgpu.ThreadSubset.WARPGROUP):
-      scratch_vec = vector.load(
-          ir.VectorType.get((4,), self.mlir_dtype),
-          scratch,
-          [zero_index],
-      )
-      scratch_sum = vector.reduction(
-          self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec
-      )
-      memref.store(scratch_sum, scratch, [zero_index])
-    utils.warpgroup_barrier()
-    result = memref.load(scratch, [zero_index])
-    utils.warpgroup_barrier()  # Make sure everyone is done using scratch.
-    return FragmentedArray.splat(result, (), is_signed=self.is_signed)
-
-  def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis):
+  def reduce(
+      self,
+      op: str | Callable[[ir.Value, ir.Value], ir.Value],
+      axis: int | Sequence[int],
+      scratch: ir.Value | None = None,
+  ):
+    i32 = ir.IntegerType.get_signless(32)
+    if isinstance(axis, int):
+      axis = (axis,)
+    splat_op = None
     if isinstance(op, str):
       match op:
         case "add":
+          reduced_elems = math.prod(self.shape[a] for a in axis)
           if ir.FloatType.isinstance(self.mlir_dtype):
             op = addf
+            splat_op = lambda x: arith.mulf(x, c(reduced_elems, x.type))
           elif ir.IntegerType.isinstance(self.mlir_dtype):
             op = arith.addi
+            splat_op = lambda x: arith.muli(x, c(reduced_elems, x.type))
           else:
             raise NotImplementedError(self.mlir_dtype)
         case "max":
@@ -1622,54 +1630,176 @@ def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis):
             op = arith.maxsi if self.is_signed else arith.maxui
           else:
             raise NotImplementedError(self.mlir_dtype)
+          splat_op = lambda x: x
         case _:
           raise ValueError(f"Unrecognized reduction operator: {op}")
-    if self.layout != WGMMA_LAYOUT:
-      raise NotImplementedError(self.layout)
-    if axis != 1:
+    match self.layout:
+      case WGStridedFragLayout(shape=_, vec_size=vec_size):
+        if set(axis) != set(range(len(self.shape))):
+          raise NotImplementedError(
+              "Warpgroup strided layout only supports reductions along all axes"
+          )
+        # We reinterpret the data as a tiled layout. We're reducing it all anyway.
+        layout = TiledLayout(
+            tiling=Tiling(((128 * vec_size,), (32 * vec_size,), (vec_size,))),
+            warp_dim=-3,
+            lane_dims=(-2,),
+            vector_dim=-1,
+        )
+        return FragmentedArray(
+            _registers=self.registers.reshape(
+                layout.registers_shape((math.prod(self.shape),))
+            ),
+            _layout=layout,
+            _is_signed=self.is_signed,
+        ).reduce(op, 0, scratch)
+      case WGSplatFragLayout():
+        if splat_op is None:
+          raise NotImplementedError(
+              "Splat reductions only supported when the operator is a string"
+          )
+        assert not self.registers.shape
+        return FragmentedArray(
+            _registers=np.asarray(
+                splat_op(self.registers.item()), dtype=object
+            ),
+            _layout=WGSplatFragLayout(
+                tuple(d for a, d in enumerate(self.shape) if a not in axis)
+            ),
+            _is_signed=self.is_signed,
+        )
+      case TiledLayout():
+        pass
+      case _:
+        raise NotImplementedError(self.layout)
+    if len(self.layout.base_tile_shape) != len(self.shape):
       raise NotImplementedError
+    if isinstance(axis, int):
+      axis = (axis,)
+    layout = self.layout
+    tiled_tiling_shape = layout.tiled_tiling_shape
+    reduced_dims = layout.tiling.tile_dimension(axis[0])
+    for a in axis[1:]:
+      reduced_dims = [
+          r or d for r, d in zip(reduced_dims, layout.tiling.tile_dimension(a), strict=True)
+      ]
+    regs_shape = self.registers.shape
+    reduced_shape = tuple(
+        d if r else 1 for r, d in zip(reduced_dims, regs_shape, strict=True)
+    )
+    remaining_shape = tuple(
+        1 if r else d for r, d in zip(reduced_dims, regs_shape)
+    )
+    out_regs = np.empty(remaining_shape, dtype=object)
     index = ir.IndexType.get()
-    i32 = ir.IntegerType.get_signless(32)
-    row_tile_dim = self.registers.shape[0]
-    row_subtile_dim = self.registers.shape[4]
-    new_regs = np.empty((row_tile_dim, 1, row_subtile_dim, 1, 1), dtype=object)
-    assert self.registers.shape[-1] == 1
-    for row_tile, row_subtile in np.ndindex(row_tile_dim, row_subtile_dim):
-      # Reduce the registers owned by the current thread over n tiles
-      reg_index = [0] * self.registers.ndim
-      reg_index[0] = row_tile
-      reg_index[4] = row_subtile
-      thread_result_vec = self.registers[tuple(reg_index)]
-      for n_tile in range(1, self.registers.shape[1]):
-        reg_index[1] = n_tile
-        thread_result_vec = op(
-            thread_result_vec, self.registers[tuple(reg_index)]
+    for out_idx in np.ndindex(remaining_shape):
+      out_reg = None
+      for red_idx in np.ndindex(reduced_shape):
+        src_idx = tuple(o + r for o, r in zip(out_idx, red_idx))
+        if out_reg is None:
+          out_reg = self.registers[src_idx]
+        else:
+          out_reg = op(out_reg, self.registers[src_idx])
+      # Reduce within the vector dimension, if necessary.
+      if reduced_dims[layout.vector_dim]:
+        [vec_len] = ir.VectorType(out_reg.type).shape
+        scalar_out_reg = None
+        for i in range(vec_len):
+          scalar = vector.extractelement(out_reg, position=c(i, index))
+          scalar_out_reg = (
+              scalar if scalar_out_reg is None else op(scalar_out_reg, scalar)
+          )
+        out_reg = vector.splat(
+            ir.VectorType.get((1,), out_reg.type.element_type), scalar_out_reg
         )
-
-      thread_result = vector.extractelement(thread_result_vec, position=c(0, index))
-      for i in range(1, self.layout.vector_length):
-        thread_result = op(
-            thread_result,
-            vector.extractelement(thread_result_vec, position=c(i, index)),
+      # Reduce across warp lanes, if necessary (using warp shuffles).
+      if any(reduced_dims[d] for d in layout.partitioned_lane_dims):
+        if utils.bitwidth(out_reg.type) > 32:
+          raise NotImplementedError  # Need to implement wide shfl_bfly.
+        lane_stride = 1
+        for d in layout.lane_dims[::-1]:  # Iterate minor-to-major
+          if isinstance(d, Replicated):
+            lane_stride *= d.times
+          elif not reduced_dims[d]:
+            lane_stride *= tiled_tiling_shape[d]
+          else:
+            assert lane_stride.bit_count() == 1
+            reduction_size = tiled_tiling_shape[d]
+            while reduction_size > 1:
+              other_out_reg = utils.shfl_bfly(out_reg, lane_stride)
+              out_reg = op(out_reg, other_out_reg)
+              lane_stride *= 2
+              reduction_size //= 2
+        assert lane_stride == WARP_SIZE, lane_stride
+      # Reduce across warps in the warpgroup, if necessary.
+      if (
+          not isinstance(layout.warp_dim, Replicated)
+          and reduced_dims[layout.warp_dim]
+      ):
+        if scratch is None:
+          raise ValueError(
+              "scratch must be provided when cross-warp reduction is required"
+          )
+        [vec_len] = ir.VectorType(out_reg.type).shape
+        scratch_ty = ir.MemRefType(scratch.type)
+        if scratch_ty.rank != 1:
+          raise ValueError(f"Expected rank 1 for scratch, got {scratch_ty.rank}")
+        if scratch_ty.element_type != self.mlir_dtype:
+          raise ValueError(
+              f"Expected element type {self.mlir_dtype} for scratch, got"
+              f" {scratch_ty.element_type}"
+          )
+        # TODO(apaszke): All lanes that replicate data can share the same scratch.
+        # For now we treat the complete reduction as a special case.
+        reduces_all_dims = set(axis) == set(range(len(self.shape)))
+        unique_lanes = 1 if reduces_all_dims else 32
+        if scratch_ty.shape[0] < WARPS_IN_WARPGROUP * unique_lanes * vec_len:
+          raise ValueError("Insufficient scratch space for cross-warp reduction")
+        if scratch_ty.get_strides_and_offset()[0] != [1]:
+          raise ValueError("Expected scratch to be contiguous")
+        thread_idx = utils.thread_idx()
+        if reduces_all_dims:
+          lane_idx = c(0, i32)
+        else:
+          lane_idx = arith.remui(thread_idx, c(WARP_SIZE, i32))
+        warp_idx = arith.divui(
+            arith.remui(thread_idx, c(WARPGROUP_SIZE, i32)), c(WARP_SIZE, i32)
         )
-
-      # Do a shuffle to reduce in groups of 4 consecutive threads.
-      result = thread_result
-      for i in (1, 2):
-        other_result = nvvm.shfl_sync(
-            result.type,
-            c(0xFFFFFFFF, i32),
-            result,
-            c(i, i32),
-            c(0x1F, i32),
-            nvvm.ShflKind.bfly,
+        spill_base = arith.muli(lane_idx, c(WARPS_IN_WARPGROUP, i32))
+        store_idx = arith.index_cast(index, arith.addi(spill_base, warp_idx))
+        vector.store(
+            out_reg, scratch, [arith.muli(store_idx, c(vec_len, index))]
        )
-        result = op(result, other_result)
-      new_regs[row_tile, :, row_subtile] = vector.splat(
-          ir.VectorType.get((1,), self.mlir_dtype), result
-      )
+        utils.warpgroup_barrier()
+        scratch_vec = vector.load(
+            ir.VectorType.get((WARPS_IN_WARPGROUP * vec_len,), self.mlir_dtype),
+            scratch,
+            [arith.muli(arith.index_cast(index, spill_base), c(vec_len, index))],
+        )
+        out_reg = None
+        for w in range(WARPS_IN_WARPGROUP):
+          part = utils.vector_slice(scratch_vec, slice(w * vec_len, (w + 1) * vec_len))
+          out_reg = part if out_reg is None else op(out_reg, part)
+        utils.warpgroup_barrier()  # Make sure everyone is done using scratch.
+      out_regs[out_idx] = out_reg
+    # Infer the output layout and reshape the registers accordingly.
+    reduced_logical_shape = list(self.shape)
+    for a in sorted(axis, reverse=True):
+      del reduced_logical_shape[a]
+    if not reduced_logical_shape:  # Complete reduction results in a splat.
+ reduced_layout = WGSplatFragLayout(()) + assert out_regs.size == 1 + out_reg = out_regs.flat[0] + assert ir.VectorType(out_reg.type).shape == [1] + out_reg = vector.extractelement(out_reg, position=c(0, index)) + out_regs = np.asarray(out_reg, dtype=object) + else: + reduced_layout = layout + for a in sorted(axis, reverse=True): + reduced_layout = reduced_layout.remove_dimension(a) + out_regs = out_regs.reshape(reduced_layout.registers_shape(reduced_logical_shape)) return FragmentedArray( - _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed + _registers=out_regs, _layout=reduced_layout, _is_signed=self.is_signed ) def broadcast(self, shape): @@ -1726,6 +1856,8 @@ def broadcast_minor(self, n): ) def broadcast_major(self, m): + if self.layout != WGMMA_COL_LAYOUT: + raise NotImplementedError if m % 64: raise ValueError("Number of rows must be divisible by 64") reg_shape = WGMMA_LAYOUT.registers_shape((m, self.shape[0])) @@ -1777,7 +1909,6 @@ def foreach( if create_array: new_regs[reg_idx] = val - if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) @@ -1825,12 +1956,22 @@ def load_untiled( ) def _store_untiled_splat(self, ref: ir.Value): + if math.prod(self.shape) == 1: + c0 = c(0, ir.IndexType.get()) + memref.store( + self.registers.flat[0], ref, [c0] * len(ir.MemRefType(ref.type).shape) + ) + return + vec_size = 64 // mgpu.bitwidth(self.mlir_dtype) if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: vec_size = 1 if np.prod(self.shape) % WARPGROUP_SIZE * vec_size: - raise ValueError(self.shape, WARPGROUP_SIZE, vec_size) + raise NotImplementedError( + "Arrays with the splat layout can only be stored when they have a" + f" single element or a multiple of {WARPGROUP_SIZE} elements" + ) fa = FragmentedArray.splat( self.registers.flat[0], diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index d89108cf451d..bd11c3a07544 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1254,13 +1254,24 @@ def dyn_dot(x, y): def shfl_bfly(x: ir.Value, distance: int | ir.Value): i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() if isinstance(distance, int): distance = c(distance, i32) if (result_type := x.type) != i32: + if (x_bitwidth := bitwidth(x.type)) < 32: # Pad to 32-bits if necessary. 
+      x = bitcast(x, ir.IntegerType.get_signless(x_bitwidth))
+      empty32 = llvm.mlir_undef(ir.VectorType.get((32 // x_bitwidth,), x.type))
+      x = vector.insertelement(x, empty32, position=c(0, index))
+    elif x_bitwidth != 32:
+      raise ValueError(f"Unsupported bitwidth {x_bitwidth}")
     x = bitcast(x, i32)
   y = nvvm.shfl_sync(
       i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly,
   )
+  if (x_bitwidth := bitwidth(result_type)) < 32:
+    bits_ty = ir.IntegerType.get_signless(x_bitwidth)
+    y_vec = bitcast(y, ir.VectorType.get((32 // x_bitwidth,), bits_ty))
+    y = vector.extractelement(y_vec, position=c(0, index))
   return bitcast(y, result_type)
diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD
index ea8be497faa4..24acb1b9a3f2 100644
--- a/tests/mosaic/BUILD
+++ b/tests/mosaic/BUILD
@@ -41,7 +41,7 @@ jax_multiplatform_test(
     ],
     deps = [
         "//jax:mosaic_gpu",
-    ] + py_deps("absl/testing") + py_deps("numpy"),
+    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
 )
 
 jax_multiplatform_test(
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index eb5bb355ffe1..50d60cae0080 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -63,6 +63,12 @@ class Dimension(enum.IntEnum):  # Just to make parameterized tests expand ok
 from jax._src.lib.mlir.dialects import gpu
 from jax._src.lib.mlir.dialects import llvm
 Dimension = gpu.Dimension
+try:
+  import hypothesis as hp
+  import hypothesis.strategies as hps
+  jtu.setup_hypothesis()
+except ImportError:
+  hp = hps = None
 
 # ruff: noqa: F405
@@ -1999,7 +2005,7 @@ def kernel(ctx, src, dst, scratch):
       src = mgpu.FragmentedArray.load_strided(
           src, is_signed=utils.is_signed(dtype)
       )
-      acc = src.reduce_sum(scratch).broadcast((m,))
+      acc = src.reduce("add", (0, 1), scratch).broadcast((m,))
       acc.store_untiled(dst, optimized=False)
 
     in_shape = jax.ShapeDtypeStruct((m, n), dtype)
@@ -2019,15 +2025,19 @@ def kernel(ctx, src, dst, scratch):
       dtype=[jnp.float32, jnp.int32],
       m=[128],
       n=[32, 64],
+      reduce_both=[False, True],
   )
-  def test_splat_reduce_sum(self, dtype, m, n):
+  def test_splat_reduce_sum(self, dtype, m, n, reduce_both):
     def kernel(ctx, dst, _):
       src = mgpu.FragmentedArray.splat(
           utils.c(1, utils.dtype_to_ir_type(dtype)),
           (m, n),
           is_signed=utils.is_signed(dtype),
       )
-      acc = src.reduce_sum().broadcast((m,))
+      if reduce_both:
+        acc = src.reduce("add", (0, 1)).broadcast((m,))
+      else:
+        acc = src.reduce("add", 1)
       acc.store_untiled(dst, optimized=False)
 
     kernel_fn = mgpu.as_gpu_kernel(
@@ -2038,7 +2048,8 @@ def kernel(ctx, dst, _):
         out_shape=jax.ShapeDtypeStruct((m,), dtype),
         smem_scratch_shape=(),
     )
-    np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), m * n * 1.0))
+    result = m * n if reduce_both else n
+    np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), result, dtype))
 
   @parameterized.product(
       op=(arith.addf, arith.maximumf),
@@ -3330,7 +3341,7 @@ def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None:
       del ctx, out  # Unused.
       # TODO(b/408271232): Use a False condition once the bug is fixed.
x = mgpu.FragmentedArray.load_strided(x_ref) - cond = x.reduce_sum(*scratch) != 42.0 + cond = x.reduce("add", 0, *scratch) != 42.0 cf.assert_(cond.registers.item(), "OOOPS") f = mgpu.as_gpu_kernel( @@ -3364,5 +3375,109 @@ def test_pass_is_registered(self): pipeline.run(module.operation) +if hp is not None: + @hps.composite + def tiled_layouts(draw, initial_tile, vector_transfer: bool = False): + assert all(t.bit_count() == 1 for t in initial_tile) + assert math.prod(initial_tile) >= 128 + tiles = [initial_tile] + dim_offset = len(initial_tile) + warp_dim = fa.Replicated(4) + if draw(hps.booleans()): + warp_dim = draw( + hps.sampled_from( + [i for i, t in enumerate(tiles[-1]) if t % 4 == 0] + ) + ) + warp_tile = list(tiles[-1]) + warp_tile[warp_dim] //= 4 + warp_dim += dim_offset + tiles.append(warp_tile) + dim_offset += len(tiles[-1]) + lane_dims = [fa.Replicated(2) if draw(hps.booleans()) else None for _ in range(5)] + for i, dim in enumerate(lane_dims): + if isinstance(dim, fa.Replicated): + continue + lane_dim = draw(hps.sampled_from( + [i for i, t in enumerate(tiles[-1]) if t % 2 == 0] + )) + lane_tile = list(tiles[-1]) + lane_tile[lane_dim] //= 2 + lane_dims[i] = dim_offset + lane_dim + tiles.append(lane_tile) + dim_offset += len(lane_tile) + # Permute lane dims so that they don't always partition the data in order. + lane_dims = draw(hps.permutations(lane_dims)) + if vector_transfer: + min_vector_dim = len(tiles[-1]) - 1 + else: + min_vector_dim = 0 + vector_dim = draw(hps.integers(min_vector_dim, len(tiles[-1]) - 1)) + vector_size = 2 ** draw( + hps.integers(0, tiles[-1][vector_dim].bit_length() - 1) + ) + vector_tile = list(tiles[-1]) + assert vector_tile[vector_dim] % vector_size == 0 + vector_tile[vector_dim] //= vector_size + tiles.append(vector_tile) + dim_offset += len(vector_tile) + vector_dim += dim_offset + dim_offset += len(vector_tile) # This is the remainder after tiling! 
+ + if not isinstance(warp_dim, fa.Replicated): + warp_dim = warp_dim - dim_offset + lane_dims = tuple( + d if isinstance(d, fa.Replicated) else d - dim_offset + for d in lane_dims + ) + vector_dim = vector_dim - dim_offset + return fa.TiledLayout( + tiling=fa.Tiling(tuple(map(tuple, tiles))), + warp_dim=warp_dim, + lane_dims=lane_dims, + vector_dim=vector_dim, + ) + + class HypothesisTest(TestCase): + + def test_reduce(self): + @hps.composite + def strategy(draw): + rank = draw(hps.integers(2, 3)) + initial_tile = tuple( + draw(hps.sampled_from([1, 2, 4, 8, 16, 32, 64, 128])) + for _ in range(rank) + ) + hp.assume(128 <= math.prod(initial_tile) < 128 * 32) + shape = tuple(t * draw(hps.integers(1, 5)) for t in initial_tile) + hp.assume(math.prod(shape) <= 128 * 128) + layout = draw(tiled_layouts(initial_tile, vector_transfer=True)) + reduced_dims = draw(hps.sets(hps.integers(0, rank - 1), min_size=1)) + return shape, layout, tuple(reduced_dims) + + @hp.given(strategy()) + def run(args): + shape, layout, reduced_dims = args + out_shape = list(shape) + for d in sorted(reduced_dims, reverse=True): + del out_shape[d] + def kernel(ctx, src, dst, scratch): + arr = fa.FragmentedArray.load_untiled(src, layout=layout, optimized=False) + arr.reduce("max", reduced_dims, scratch).store_untiled(dst, optimized=False) + x = jax.random.normal(jax.random.key(1234), shape, jnp.float32) + out_type = jax.ShapeDtypeStruct(out_shape, jnp.float32) + scratch_type = jax.ShapeDtypeStruct((2048,), jnp.float32) + hp.assume(layout.vector_length <= 16) # Otherwise we run out of scratch + try: + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_type, scratch_type + )(x) + except NotImplementedError: + hp.assume(False) + return + np.testing.assert_array_equal(result, x.max(reduced_dims)) + run() + + if __name__ == "__main__": absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) From 70977f08e2ab654e030ea89a6bfda604f286f489 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 6 May 2025 04:58:08 -0700 Subject: [PATCH 1019/1769] [Mosaic GPU] Add tsl::TraceMe to annotate compilation cache misses This should make it clear if a captured profile contains compilation overhead or not. PiperOrigin-RevId: 755321326 --- jaxlib/mosaic/gpu/BUILD | 1 + jaxlib/mosaic/gpu/custom_call.cc | 3 +++ 2 files changed, 4 insertions(+) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index bd4c86e97ad7..5fab85d2b77c 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -198,6 +198,7 @@ cc_library( "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", + "@tsl//tsl/profiler/lib:traceme", "@xla//xla/ffi", "@xla//xla/ffi:ffi_api", "@xla//xla/service:custom_call_status", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 5d812f483de4..ca95080d5669 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -99,6 +99,7 @@ limitations under the License. 
#include "xla/ffi/ffi_api.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "tsl/profiler/lib/traceme.h" namespace { @@ -406,6 +407,7 @@ absl::StatusOr get_nvshmem_llvm_lib_path() { absl::StatusOr, bool>> Compile( mlir::ModuleOp module) { + tsl::profiler::TraceMe trace("Compile"); auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); if (!sm_and_ptx_isa.ok()) { return sm_and_ptx_isa.status(); @@ -577,6 +579,7 @@ absl::StatusOr CachedCompileAndInit( absl::MutexLock lock(mutex); // We released the reader lock, another thread might have initialized it. if (cache->find(key) == cache->end()) { + tsl::profiler::TraceMe trace("Compilation cache miss"); auto compiled = CompileAndInit(module); if (!compiled.ok()) { return compiled.status(); From 72e1a7d20ec89a237a76bfb76c0135d70e7737f6 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 6 May 2025 06:08:13 -0700 Subject: [PATCH 1020/1769] Enable batch sharding tests for Cholesky and triangular solve on GPU. These are now supported on GPU when using shardy. Also fix some TODOs by re-enabling HLO checks. PiperOrigin-RevId: 755340978 --- tests/BUILD | 4 +++ tests/linalg_sharding_test.py | 68 ++++++++++++++++------------------- 2 files changed, 35 insertions(+), 37 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 3ac17767ad13..20a37ba746de 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -737,6 +737,10 @@ jax_multiplatform_test( "gpu_p100x2_shardy", "gpu_p100x2_pjrt_c_api", ], + shard_count = { + "cpu": 10, + "gpu": 10, + }, tags = [ "multiaccelerator", ], diff --git a/tests/linalg_sharding_test.py b/tests/linalg_sharding_test.py index d8e1e6a16871..2f190cdc5ad6 100644 --- a/tests/linalg_sharding_test.py +++ b/tests/linalg_sharding_test.py @@ -14,7 +14,7 @@ import functools -from absl.testing import absltest +from absl.testing import absltest, parameterized import numpy as np import jax @@ -32,12 +32,6 @@ CPU_ONLY_FUN_AND_SHAPES = [ - # These functions are supported on GPU, but partitioning support will - # require updates to GSPMD, since they are lowered directly to HLO ops - # instead of custom calls on GPU. - (lax.linalg.cholesky, ((6, 6),)), - (lax.linalg.triangular_solve, ((6, 6), (4, 6))), - # The GPU kernel for this function still uses an opaque descriptor to # encode the input shapes so it is not partitionable. # TODO(danfm): Update the kernel and enable this test on GPU. 
@@ -49,11 +43,13 @@ ] CPU_AND_GPU_FUN_AND_SHAPES = [ + (lax.linalg.cholesky, ((6, 6),)), (lax.linalg.eig, ((6, 6),)), (lax.linalg.eigh, ((6, 6),)), (lax.linalg.lu, ((10, 6),)), (lax.linalg.qr, ((6, 6),)), (lax.linalg.svd, ((10, 6),)), + (lax.linalg.triangular_solve, ((6, 6), (4, 6))), (lax.linalg.tridiagonal, ((6, 6),)), ] @@ -68,9 +64,15 @@ def setUp(self): self.skipTest("Requires multiple devices") def get_fun_and_shapes(self, fun_and_shapes, grad=False): - if (jtu.test_device_matches(["gpu"]) - and fun_and_shapes not in CPU_AND_GPU_FUN_AND_SHAPES): - self.skipTest(f"{fun_and_shapes[0].__name__} not supported on GPU") + if jtu.test_device_matches(["gpu"]): + if fun_and_shapes not in CPU_AND_GPU_FUN_AND_SHAPES: + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} not supported on GPU.") + if (fun_and_shapes[0] in (lax.linalg.cholesky, lax.linalg.triangular_solve) + and not config.use_shardy_partitioner.value): + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} only supported on GPU " + "when shardy is enabled.") if not grad: return fun_and_shapes @@ -79,10 +81,10 @@ def get_fun_and_shapes(self, fun_and_shapes, grad=False): self.skipTest(f"{fun.__name__} does not support differentation") if jtu.test_device_matches(["gpu"]) and fun in ( lax.linalg.eig, lax.linalg.lu, lax.linalg.qr - ): + ) and not config.use_shardy_partitioner.value: self.skipTest( f"JVP of {fun.__name__} uses triangular solve on GPU, which doesn't " - "support batch partitioning yet") + "support batch partitioning unless shardy is enabled.") if fun == lax.linalg.eig: fun = functools.partial( @@ -107,9 +109,8 @@ def arg_maker(shape): return x return tuple(arg_maker(shape) for shape in shapes) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding(self, fun_and_shapes, dtype): @@ -124,20 +125,17 @@ def test_batch_axis_sharding(self, fun_and_shapes, dtype): expected = fun(*args) actual = fun_jit(*args_sharded) self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # self.assertNotIn("all-", fun_jit.lower(*args_sharded).compile().as_text()) + self.assertNotIn("all-", fun_jit.lower(*args_sharded).compile().as_text()) vmap_fun = jax.vmap(fun) vmap_fun_jit = jax.jit(vmap_fun) actual = vmap_fun_jit(*args_sharded) self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. 
- # self.assertNotIn( - # "all-", vmap_fun_jit.lower(*args_sharded).compile().as_text()) + self.assertNotIn( + "all-", vmap_fun_jit.lower(*args_sharded).compile().as_text()) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_non_batch_axis_sharding(self, fun_and_shapes, dtype): @@ -155,9 +153,8 @@ def test_non_batch_axis_sharding(self, fun_and_shapes, dtype): self.assertIn( "all-gather", fun_jit.lower(*args_sharded).compile().as_text()) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding_jvp(self, fun_and_shapes, dtype): @@ -181,14 +178,12 @@ def jvp_fun(primals, tangents): (primals_sharded, tangents), ]: _, actual = jvp_fun_jit(*args) - self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile() - # self.assertNotIn("all-", hlo.as_text()) - - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + self.assertAllClose(actual, expected, atol={np.float64: 1e-12}) + hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile() + self.assertNotIn("all-", hlo.as_text()) + + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype): @@ -205,9 +200,8 @@ def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype): expected = vjp_fun(tangents) actual = vjp_fun_jit(tangents_sharded) self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # hlo = vjp_fun_jit.lower(tangents_sharded).compile() - # self.assertNotIn("all-", hlo.as_text()) + hlo = vjp_fun_jit.lower(tangents_sharded).compile() + self.assertNotIn("all-", hlo.as_text()) if __name__ == "__main__": From db376303d0007d032abd184c0f127514d1b1cfdd Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 6 May 2025 06:43:12 -0700 Subject: [PATCH 1021/1769] [Mosaic-GPU] [3/3] Add support for communication primitives in MGPU lowering This is a rebase and reupload of @nvcastet's PR #26675. 
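For orientation, the usage this enables looks roughly as follows (a sketch
only: it assumes a two-device JAX mesh with axis name "x", that remote_ref and
the Pallas semaphore operations are exported under the names shown, and that
logical device ids are passed, which is what the new lowerings accept):

    # Inside a Pallas Mosaic GPU kernel (sketch):
    peer = lax.axis_index("x") ^ 1  # logical id of the partner device
    # View the symmetric output buffer on the peer device and write into it.
    plgpu.remote_ref(o_ref, peer)[...] = x_ref[...]
    # Signal the peer that the data has landed, then wait for its signal.
    pl.semaphore_signal(sem_ref, device_id=peer)
    pl.semaphore_wait(sem_ref)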
Co-authored-by: Nicolas Castet PiperOrigin-RevId: 755350802 --- build/requirements.in | 3 + build/requirements_lock_3_10.txt | 4 + build/requirements_lock_3_11.txt | 4 + build/requirements_lock_3_12.txt | 4 + build/requirements_lock_3_13.txt | 4 + build/requirements_lock_3_13_ft.txt | 6 +- jax/BUILD | 15 ++ jax/_src/pallas/mosaic_gpu/BUILD | 3 + jax/_src/pallas/mosaic_gpu/core.py | 49 +++- jax/_src/pallas/mosaic_gpu/lowering.py | 193 +++++++++++-- .../mosaic_gpu/pallas_call_registration.py | 9 +- jax/_src/pallas/mosaic_gpu/primitives.py | 24 +- jax/_src/test_multiprocess.py | 254 ++++++++++++++++++ jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/core.py | 46 +++- jax/experimental/mosaic/gpu/launch_context.py | 33 +++ jax/experimental/pallas/mosaic_gpu.py | 2 + jax_plugins/cuda/plugin_setup.py | 3 + jaxlib/jax.bzl | 1 + jaxlib/mosaic/gpu/BUILD | 17 +- jaxlib/mosaic/gpu/custom_call.cc | 12 +- .../gpu/{mosaic_gpu_comm.h => nvshmem.h} | 34 ++- jaxlib/mosaic/gpu/runtime.cc | 20 +- pyproject.toml | 1 + tests/pallas/BUILD | 21 ++ tests/pallas/gpu_pallas_distributed_test.py | 91 +++++++ tests/pallas/mosaic_gpu_test.py | 4 + 27 files changed, 774 insertions(+), 84 deletions(-) create mode 100644 jax/_src/test_multiprocess.py rename jaxlib/mosaic/gpu/{mosaic_gpu_comm.h => nvshmem.h} (76%) create mode 100644 tests/pallas/gpu_pallas_distributed_test.py diff --git a/build/requirements.in b/build/requirements.in index d2fc3a60a708..8b8af9d6b591 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -24,3 +24,6 @@ jax-cuda12-pjrt # TPU dependencies libtpu ; sys_platform == "linux" and platform_machine == "x86_64" + +# For Mosaic GPU collectives +nvidia-nvshmem-cu12>=3.2.5 ; sys_platform == "linux" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 45820e38f195..c4ca6088e4bf 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -491,6 +491,10 @@ nvidia-nvjitlink-cu12==12.8.61 \ # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via -r build/requirements.in opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index e2140583c7e0..1f667115af04 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -486,6 +486,10 @@ nvidia-nvjitlink-cu12==12.8.61 \ # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via -r build/requirements.in opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 7482f6b2bad9..20ca67a3e921 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -486,6 +486,10 @@ nvidia-nvjitlink-cu12==12.8.61 \ # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # 
nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via -r build/requirements.in opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 83cccfa84e4b..804373b03899 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -541,6 +541,10 @@ nvidia-nvjitlink-cu12==12.8.61 \ # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via -r build/requirements.in opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 13bb7126d62a..c7a1c882fc73 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -492,6 +492,10 @@ nvidia-nvjitlink-cu12==12.8.61 \ # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via -r build/requirements.in opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac @@ -690,4 +694,4 @@ zipp==3.21.0 \ setuptools==70.3.0 \ --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc - # via -r build/requirements.in \ No newline at end of file + # via -r build/requirements.in diff --git a/jax/BUILD b/jax/BUILD index 30e071b81a55..749c72b45aba 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -233,6 +233,20 @@ py_library( ) + py_deps("numpy"), ) +py_library( + name = "test_multiprocess", + srcs = ["_src/test_multiprocess.py"], + visibility = [":internal"], + deps = if_building_jaxlib( + if_building = [ + ":jax", + ":test_util", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ), +) + py_library( name = "internal_export_back_compat_test_util", srcs = ["_src/internal_test_util/export_back_compat_test_util.py"], @@ -858,6 +872,7 @@ pytype_strict_library( py_library_providing_imports_info( name = "mosaic_gpu", srcs = glob(["experimental/mosaic/gpu/*.py"]), + data = py_deps("libnvshmem_device"), visibility = [ ":mosaic_gpu_users", ], diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 35ce282234d2..2652be7a7c9a 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -48,6 +48,7 @@ pytype_strict_library( "//jax:core", "//jax:mlir", "//jax:mosaic_gpu", + "//jax:sharding_impls", "//jax/_src/pallas", ] + py_deps("numpy"), ) @@ -59,11 +60,13 @@ pytype_strict_library( ":core", "//jax", "//jax:core", + "//jax:mesh", "//jax:mlir", "//jax:mosaic_gpu", "//jax:pallas", 
"//jax:partial_eval", "//jax:source_info_util", + "//jax:tree_util", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 9b293ebb51f1..808759edf35c 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -44,6 +44,7 @@ from jaxlib.mlir import ir +_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef AbstractMemoryRef = pallas_core.AbstractMemoryRef DimensionSemantics = Literal["parallel", "sequential"] @@ -144,7 +145,7 @@ def __call__(self, shape: tuple[int, ...]): def get_array_aval(self) -> jax_core.ShapedArray: return self(()).get_array_aval() - def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: + def get_ref_aval(self) -> _Ref: return self(()).get_ref_aval() @@ -218,7 +219,7 @@ def _is_known_divisible(value, divisor, fuel=10) -> bool: class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () - def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: + def get_ref_aval(self) -> _Ref: aval = jax_core.ShapedArray(self.shape, self.dtype) for t in self.transforms: aval = t(aval) @@ -262,9 +263,7 @@ def _ref_group_size(refs: _GPUMemoryRefTree) -> int: return size -def flatten_ref_union( - ref_union: AbstractRefUnion, -) -> tuple[pallas_core.AbstractMemoryRef | state_types.TransformedRef, ...]: +def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]: """Flattens a union of trees of references into a tuple of references. This is the moral equivalent of `jax.tree.leaves` for aliased references. @@ -567,6 +566,46 @@ def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: return pp.text(f"{{transpose({list(self.permutation)})}}") +@tree_util.register_pytree_node_class +@dataclasses.dataclass +class PeerMemRef(state_types.Transform): + device_id: Any + device_id_type: pallas_primitives.DeviceIdType + + def transform_shape(self, shape): + return shape + + def transform_dtype(self, dtype): + return dtype + + def untransform_index( + self, idxs: tuple[Index, ...] 
+ ) -> tuple[tuple[Index, ...], state_types.Transform]: + return idxs, self + + def tree_flatten(self): + return (self.device_id,), (self.device_id_type,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + return cls(arrays[0], metadata[0]) + + +def remote_ref( + ref: _Ref, + device_id: jax.typing.ArrayLike, + device_id_type: pallas_primitives.DeviceIdType = pallas_primitives.DeviceIdType.MESH, +) -> pallas_core.TransformedRef: + """Translate memref to a symmetric memref on a peer device.""" + if not isinstance(ref, pallas_core.TransformedRef): + if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): + raise TypeError("ref must be a reference") + ref = pallas_core.TransformedRef(ref, transforms=()) + return pallas_core.TransformedRef( + ref.ref, (*ref.transforms, PeerMemRef(device_id, device_id_type)), + ) + + def transform_ref( ref: pallas_core.TransformedRef, transform: state_types.Transform diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 1d42c86efdbc..77216ea99446 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -32,14 +32,17 @@ from jax._src import core as jax_core from jax._src import lib as jaxlib from jax._src import linear_util as lu +from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import source_info_util +from jax._src import tree_util from jax._src import util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect +from jax._src.lib.mlir.dialects import llvm as llvm_dialect from jax._src.lib.mlir.dialects import math as math_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect @@ -59,8 +62,8 @@ import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import profiler as mgpu_profiler -from jax.experimental.mosaic.gpu import utils as mgpu_utils from jax.experimental.mosaic.gpu import tcgen05 +from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp import numpy as np @@ -306,6 +309,7 @@ class ModuleContext: squashed_dims: tuple[int, ...] 
lowering_semantics: mgpu.LoweringSemantics primitive_semantics: gpu_core.PrimitiveSemantics + mesh: mesh_lib.Mesh | None warp_axis_name: str | None = None @property @@ -554,7 +558,8 @@ def index_map(*indices): def lower_pipelined_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, - mesh: pallas_core.Mesh | None, + gpu_mesh: pallas_core.Mesh | None, + jax_mesh: mesh_lib.Mesh | None, jaxpr: jax_core.Jaxpr, params: gpu_core.GPUCompilerParams, cost_estimate: pallas_core.CostEstimate | None, @@ -578,10 +583,10 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if mesh: - assert isinstance(mesh, gpu_core.GPUMesh) - block = (128 * (mesh.num_threads or 1), 1, 1) - grid = mesh.grid + if gpu_mesh: + assert isinstance(gpu_mesh, gpu_core.GPUMesh) + block = (128 * (gpu_mesh.num_threads or 1), 1, 1) + grid = gpu_mesh.grid else: block = (128, 1, 1) grid = grid_mapping.grid @@ -682,16 +687,17 @@ def body_fn(indices, *refs): assert not new_consts axis_names = ( - _AxisNames(mesh.grid_names, mesh.cluster_names, mesh.thread_name) - if mesh is not None + _AxisNames(gpu_mesh.grid_names, gpu_mesh.cluster_names, gpu_mesh.thread_name) + if gpu_mesh is not None else _AxisNames(grid_mapping.grid_names or ()) ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( - parallel_grid, + jax_mesh, axis_names, + parallel_grid, block, - mesh.cluster if mesh is not None else (), + gpu_mesh.cluster if gpu_mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], [ @@ -705,8 +711,9 @@ def body_fn(indices, *refs): def lower_jaxpr_to_module( - grid: Sequence[int], + jax_mesh: mesh_lib.Mesh | None, axis_names: _AxisNames, + grid: Sequence[int], block: Sequence[int], cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], @@ -736,6 +743,11 @@ def lower_jaxpr_to_module( def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): *buffers_gmem, (runtime_smem, runtime_barriers, runtime_tmem) = buffers + if gmem_scratch_shapes: + in_buffers, _, out_scratch_buffers = util.split_list( + buffers_gmem, [len(in_shapes), len(gmem_scratch_shapes)] + ) + buffers_gmem = in_buffers + out_scratch_buffers grouped_barriers = collections.defaultdict(list) for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): @@ -772,6 +784,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): squashed_dims=squashed_dims, lowering_semantics=lowering_semantics, primitive_semantics=gpu_core.PrimitiveSemantics.Warpgroup, + mesh=jax_mesh, ) del runtime_smem, grouped_barriers, runtime_barriers @@ -810,7 +823,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): grid=tuple(map(operator.mul, parallel_grid, cluster)), cluster=cluster, block=block, - in_shapes=in_shapes, + in_shapes=(*in_shapes, *gmem_scratch_shapes), out_shape=(*out_shapes, *gmem_scratch_shapes), smem_scratch_shape=scratch_buffers, lowering_semantics=lowering_semantics, @@ -1150,11 +1163,13 @@ def _extract_aliased_ref( def _handle_transforms( + ctx: LoweringRuleContext, ref: ir.Value, transforms: Sequence[gpu_core.Transform], *, handle_transposes=True, handle_reshapes=True, + allow_peer_refs=False, ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: # Before we handle other transforms, we resolve any possible leading aliasing # transform. 
@@ -1172,6 +1187,7 @@ def _bubble_up(untransform_fn, data): new_transforms = list(reversed(new_transforms_rev)) return data + peer_device_id = None for t in transforms: match t: case indexing.NDIndexer(): @@ -1192,9 +1208,23 @@ def _bubble_up(untransform_fn, data): lambda t, p: t.untransform_reshape(dtype, p), # pylint: disable=cell-var-from-loop shape) transformed_ref = mgpu.memref_reshape(transformed_ref, shape) + case gpu_core.PeerMemRef(device_id, device_id_type): + if device_id_type != primitives.DeviceIdType.LOGICAL: + raise NotImplementedError( + "Only logical device IDs are supported for peer memrefs." + ) + peer_device_id = device_id case _: new_transforms.append(t) - + if peer_device_id is not None: + if not allow_peer_refs: + raise NotImplementedError( + "Peer device references are not allowed in the lowering of this" + " primitive." + ) + transformed_ref = ctx.launch_ctx.to_remote( + transformed_ref, _ensure_ir_value(peer_device_id, jnp.int32) + ) return transformed_ref, new_transforms @@ -1237,11 +1267,16 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_transforms(x_ref, transforms) + x_smem, transforms = _handle_transforms( + ctx, x_ref, transforms, allow_peer_refs=True + ) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (8, (swizzle * 8) // pallas_utils.dtype_bitwidth(x_aval.dtype)): + if tiling != ( + 8, + (swizzle * 8) // pallas_utils.dtype_bitwidth(x_aval.dtype), + ): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle @@ -1268,7 +1303,9 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_transforms(x_smem, transforms) + x_smem, transforms = _handle_transforms( + ctx, x_smem, transforms, allow_peer_refs=True + ) if transforms: raise NotImplementedError( @@ -1297,7 +1334,7 @@ def _swap_lowering_rule( transforms = jax.tree.unflatten(tree, leaves) transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT x_smem, transforms = _handle_transforms( - x_smem, transforms, handle_transposes=not transposed_value + ctx, x_smem, transforms, handle_transposes=not transposed_value, allow_peer_refs=True ) mgpu.warpgroup_barrier() # Make sure reads have completed before we write. match transforms: @@ -1363,7 +1400,7 @@ def _swap_lowering_rule_wg( x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_transforms(x_smem, transforms) + x_smem, transforms = _handle_transforms(ctx, x_smem, transforms, allow_peer_refs=True) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" @@ -1977,17 +2014,37 @@ def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): @register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): - axis_names = ctx.module_ctx.axis_names - if not axis_names: + gpu_axis_names = ctx.module_ctx.axis_names + jax_axis_names = getattr(ctx.module_ctx.mesh, "axis_names", ()) + if gpu_axis_names is None and not jax_axis_names: raise LookupError( "No axis names are available. 
Make sure you are using `pl.core_map`" - " with a `plgpu.GPUMesh`." + " with a `plgpu.GPUMesh` or an appropriate JAX device mesh." ) - if axis_name not in axis_names: + if axis_name not in itertools.chain((gpu_axis_names or ()), jax_axis_names): raise LookupError( - f"Unknown axis {axis_name}, available axes: {[*axis_names]}" + f"Axis {axis_name} does not refer to a GPU mesh axis (available axes:" + f" {[*gpu_axis_names]}) or a JAX mesh axis (available axes:" + f" {[*jax_axis_names]})" + ) + if axis_name in jax_axis_names: + jax_mesh = ctx.module_ctx.mesh + assert jax_mesh is not None + device_id = ctx.launch_ctx.device_id() + jax_mesh_shape = jax_mesh.axis_sizes + axis_index = jax_axis_names.index(axis_name) + i32 = ir.IntegerType.get_signless(32) + axis_size = _ir_constant(jax_mesh_shape[axis_index], i32) + minor_divisor = _ir_constant( + np.prod(jax_mesh_shape[axis_index + 1 :], dtype=np.int32), i32 ) + return arith_dialect.remsi(arith_dialect.divsi(device_id, minor_divisor), axis_size) + # We already checked that the axis is in scope and it wasn't a JAX mesh axis. + assert gpu_axis_names is not None + + # We only deal with GPU axes from now on. + axis_names = gpu_axis_names if axis_names.wg is not None and axis_name == axis_names.wg: return mgpu.warpgroup_idx(sync=True) @@ -2806,3 +2863,93 @@ def _ensure_idx_fa(x): shape=root_shape, int_indexer_shape=(), ) + + +@register_lowering_rule(primitives.semaphore_read_p, mgpu.LoweringSemantics.Lane) +def _semaphore_read_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): + sem, transforms = tree_util.tree_unflatten(args_tree, args) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError(f"Unhandled transforms for semaphore_read: {transforms}") + sem_ptr = mgpu.utils.memref_ptr(sem) + i32_ty = ir.IntegerType.get_signless(32) + return llvm_dialect.inline_asm( + i32_ty, [sem_ptr], "ld.acquire.sys.u32 $0,[$1];", "=r,l", has_side_effects=True, + ) + + +@register_lowering_rule(primitives.semaphore_signal_p, mgpu.LoweringSemantics.Lane) +def _semaphore_signal_lowering_rule( + ctx: LoweringRuleContext, + *args, + args_tree, + device_id_type, +): + i32 = ir.IntegerType.get_signless(32) + sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( + args_tree, args + ) + if core_index is not None: + raise NotImplementedError( + "Mosaic GPU backend does not support the concept of cores, but" + " core_index is specified" + ) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError(f"Unhandled transforms for semaphore_signal: {transforms}") + sem_ptr = mgpu.utils.memref_ptr(sem) + if device_id is not None: + if device_id_type != primitives.DeviceIdType.LOGICAL: + raise NotImplementedError( + f"Unsupported device id type: {device_id_type}" + ) + sem_ptr = ctx.launch_ctx.to_remote( + sem_ptr, _ensure_ir_value(device_id, jnp.int32) + ) + # TODO(apaszke): Narrow the scope from .sys to .gpu when the semaphore is local. 
+ val = _ir_constant(value, i32) + pred = ctx.module_ctx.single_wg_lane_predicate + llvm_dialect.inline_asm( + i32, + [sem_ptr, val, pred], + "@$3 atom.add.release.sys.global.u32 $0, [$1], $2;", + "=r,l,r,b", + has_side_effects=True, + ) + return () + + +@register_lowering_rule(primitives.semaphore_wait_p, mgpu.LoweringSemantics.Lane) +def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): + sem, transforms, value = tree_util.tree_unflatten(args_tree, args) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError( + f"Unhandled transforms for semaphore_wait: {transforms}" + ) + + sem_ptr = mgpu.utils.memref_ptr(sem) + i32_ty = ir.IntegerType.get_signless(32) + ne_pred = arith_dialect.CmpIPredicate.ne + zero_const = mgpu.utils.c(0, i32_ty) + val = _ir_constant(value, i32_ty) + + with mgpu.single_thread(scope=mgpu.ThreadSubset.WARPGROUP): + # Create the while loop for busy waiting + while_op = scf_dialect.WhileOp([i32_ty], [zero_const]) + before_block = while_op.before.blocks.append(i32_ty) + with ir.InsertionPoint.at_block_begin(before_block): + old_val = llvm_dialect.inline_asm( + i32_ty, + [sem_ptr, val, zero_const], + "atom.acquire.sys.global.cas.b32 $0, [$1], $2, $3;", + "=r,l,r,r", + has_side_effects=True, + ) + comparison = arith_dialect.cmpi(ne_pred, old_val, val) + scf_dialect.condition(comparison, before_block.arguments) + after_block = while_op.after.blocks.append(i32_ty) + with ir.InsertionPoint.at_block_begin(after_block): + scf_dialect.yield_(after_block.arguments) + mgpu_utils.warpgroup_barrier() + return () diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 2e2cb976df8f..72e6f96c125a 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -25,6 +25,7 @@ import jax from jax import lax from jax._src import core as jax_core +from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -66,8 +67,14 @@ def pallas_call_lowering( else: params = gpu_core.GPUCompilerParams() + jax_mesh = None + axis_context = ctx.module_context.axis_context + if axis_context is not None: + if isinstance(axis_context, sharding_impls.SPMDAxisContext): + jax_mesh = axis_context.mesh + lowering_result = lowering.lower_pipelined_jaxpr_to_module( - grid_mapping, mesh, jaxpr, params, cost_estimate + grid_mapping, mesh, jax_mesh, jaxpr, params, cost_estimate ) if debug: print(f"\nThe Mosaic GPU module for pallas_call {debug_info.func_src_info}:") diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 0384e41144cd..8d0e0c82671d 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -88,7 +88,7 @@ def _load_p_lowering_rule( x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(args_tree, leaves) - x_ref, transforms = lowering._handle_transforms(x_ref, transforms) + x_ref, transforms = lowering._handle_transforms(ctx, x_ref, transforms) if layout is not None: layout = layout.to_mgpu() @@ -259,7 +259,7 @@ def _copy_smem_to_gmem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - src, src_transforms = lowering._handle_transforms(src, src_transforms, 
handle_transposes=False) + src, src_transforms = lowering._handle_transforms(ctx, src, src_transforms, handle_transposes=False) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: ctx.launch_ctx.async_copy( @@ -477,7 +477,7 @@ def _copy_gmem_to_smem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - dst, dst_transforms = lowering._handle_transforms(dst, dst_transforms, handle_transposes=False) + dst, dst_transforms = lowering._handle_transforms(ctx, dst, dst_transforms, handle_transposes=False) copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms) barrier_indexer = _extract_barrier_indexer( barrier_transforms_treedef.unflatten(flat_barrier_transforms) @@ -923,7 +923,7 @@ def _wgmma_lowering( ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) a, a_transforms = lowering._handle_transforms( - a, a_transforms, handle_transposes=False, handle_reshapes=False + ctx, a, a_transforms, handle_transposes=False, handle_reshapes=False ) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): @@ -950,7 +950,7 @@ def _wgmma_lowering( b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) b, b_transforms = lowering._handle_transforms( - b, b_transforms, handle_transposes=False, handle_reshapes=False + ctx, b, b_transforms, handle_transposes=False, handle_reshapes=False ) match b_transforms: @@ -1010,14 +1010,12 @@ def _wgmma_warpgroup_lowering( a_transforms_tree, b_transforms_tree, ): - del ctx # Unused. - if a_transforms_tree is not None: a_transforms_leaves, b_transforms_leaves = util.split_list( transforms_leaves, [a_transforms_tree.num_leaves] ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_transforms(a, a_transforms) + a, a_transforms = lowering._handle_transforms(ctx, a, a_transforms) match a_transforms: case (gpu_core.TransposeRef((1, 0)),): a = mgpu.memref_transpose(a, (1, 0)) @@ -1032,7 +1030,7 @@ def _wgmma_warpgroup_lowering( if b_transforms_tree is not None: b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) - b, b_transforms = lowering._handle_transforms(b, b_transforms) + b, b_transforms = lowering._handle_transforms(ctx, b, b_transforms) match b_transforms: case (gpu_core.TransposeRef((1, 0)),): b = mgpu.memref_transpose(b, (1, 0)) @@ -1215,7 +1213,7 @@ def _tcgen05_mma_lowering( a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) a_ref, a_transforms = lowering._handle_transforms( - a_ref, a_transforms, handle_transposes=False, handle_reshapes=True + ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True ) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(lhs_tiling)): @@ -1239,7 +1237,7 @@ def _tcgen05_mma_lowering( b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) b_ref, b_transforms = lowering._handle_transforms( - b_ref, b_transforms, handle_transposes=False, handle_reshapes=True + ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True ) match b_transforms: case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): @@ -1535,7 +1533,7 @@ def _jaxpr_call_lowering_rule( # We ignore other transforms here, because they are already embedded # in the jaxpr. 
ref, _ = lowering._handle_transforms( - ref, transforms, handle_reshapes=False, handle_transposes=False + ctx, ref, transforms, handle_reshapes=False, handle_transposes=False ) args.append(ref) program_ids = program_ids_treedef.unflatten(flat_program_ids) @@ -1844,7 +1842,7 @@ def _inline_mgpu_lowering_rule( assert transforms is None continue assert isinstance(aval, pallas_core.AbstractMemoryRef) - a, user_transforms = lowering._handle_transforms(a, transforms, handle_transposes=False) + a, user_transforms = lowering._handle_transforms(ctx, a, transforms, handle_transposes=False) # Transforms that do not originate from a MemoryRefTransform are # applied implicitly (eg by emit-pipeline) and therefore we do not # expect the user to pass them to the type. The transforms not diff --git a/jax/_src/test_multiprocess.py b/jax/_src/test_multiprocess.py new file mode 100644 index 000000000000..8a5b6ee1df00 --- /dev/null +++ b/jax/_src/test_multiprocess.py @@ -0,0 +1,254 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper for running multi-process tests.""" + +import os +import pathlib +import re +import signal +import subprocess +import time + +from absl import app +from absl import flags +import jax +from jax import config +from jax._src import distributed +try: + import portpicker +except ImportError: + portpicker = None + +from absl.testing import absltest +from jax._src import test_util as jtu + + +_NUM_PROCESSES = flags.DEFINE_integer( + "num_processes", None, "Number of processes to use." +) + +_GPUS_PER_PROCESS = flags.DEFINE_integer( + "gpus_per_process", + 0, + "Number of GPUs per worker process.", +) + +_MULTIPROCESS_TEST_WORKER_ID = flags.DEFINE_integer( + "multiprocess_test_worker_id", + -1, + "Worker id. Set by main test process; should not be set by users.", +) + +_MULTIPROCESS_TEST_CONTROLLER_ADDRESS = flags.DEFINE_string( + "multiprocess_test_controller_address", + "", + "Address of the JAX controller. Set by the main test process; should not be" + " set by users.", +) + + +expect_failures_with_regex = None + + +def main(): + config.config_with_absl() + app.run(_main) + + +class GracefulKiller: + """Add a signal handler that sets a flag if SIGINT or SIGTERM are caught.""" + + # From https://stackoverflow.com/a/31464349 + kill_now = False + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, sig_num, unused_stack_frame): + print(f"Caught signal: {signal.Signals(sig_num).name} ({sig_num})") + self.kill_now = True + + +def _main(argv): + if _MULTIPROCESS_TEST_WORKER_ID.value >= 0: + jax.distributed.initialize( + _MULTIPROCESS_TEST_CONTROLLER_ADDRESS.value, + num_processes=_NUM_PROCESSES.value, + process_id=_MULTIPROCESS_TEST_WORKER_ID.value, + initialization_timeout=10, + ) + absltest.main(testLoader=jtu.JaxTestLoader()) + + if not argv[0].endswith(".py"): # Skip the interpreter path if present. 
+ argv = argv[1:] + + num_processes = _NUM_PROCESSES.value + if num_processes is None: + raise ValueError("num_processes must be set") + gpus_per_process = _GPUS_PER_PROCESS.value + if portpicker is None: + jax_port = 9876 + else: + jax_port = portpicker.pick_unused_port() + subprocesses = [] + output_filenames = [] + output_files = [] + for i in range(num_processes): + env = os.environ.copy() + + args = [ + "/proc/self/exe", + *argv, + f"--num_processes={num_processes}", + f"--multiprocess_test_worker_id={i}", + f"--multiprocess_test_controller_address=localhost:{jax_port}", + "--logtostderr", + ] + + if gpus_per_process > 0: + gpus = range(i * gpus_per_process, (i + 1) * gpus_per_process) + env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpus)) + + undeclared_outputs = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp") + stdout_name = f"{undeclared_outputs}/jax_{i}_stdout.log" + stderr_name = f"{undeclared_outputs}/jax_{i}_stderr.log" + stdout = open(stdout_name, "wb") + stderr = open(stderr_name, "wb") + print(f"Launching process {i}:") + print(f" stdout: {stdout_name}") + print(f" stderr: {stderr_name}") + proc = subprocess.Popen(args, env=env, stdout=stdout, stderr=stderr) + subprocesses.append(proc) + output_filenames.append((stdout_name, stderr_name)) + output_files.append((stdout, stderr)) + + print(" All launched, running ".center(80, "="), flush=True) + + # Wait for all the children to finish or for a SIGTERM from bazel. If we get + # SIGTERM, we still want to collect their logs, so kill them and continue. + killer = GracefulKiller() + running_procs = dict(enumerate(subprocesses)) + while not killer.kill_now and running_procs: + time.sleep(0.1) + for i, proc in list(running_procs.items()): + if proc.poll() is not None: + print(f"Process {i} finished.", flush=True) + running_procs.pop(i) + if killer.kill_now and running_procs: + print("Caught termination, terminating remaining children.", flush=True) + + # Send a SIGTERM to each child process, to let it know it should terminate. + for i, proc in running_procs.items(): + proc.terminate() + print(f"Process {i} terminated.", flush=True) + + # We give the child process(es) a few seconds for their own cleanup, and + # keep the rest (up to 15s) for copying the children logs into our own. + time.sleep(5) + + # Send a SIGKILL (a "hard" kill) to each child process. This is CRITICAL: + # without it, this process may end up waiting a long time on the proc.wait() + # below, and never get to saving the children logs, making test timeouts + # very hard to debug. + for i, proc in running_procs.items(): + proc.kill() + print(f"Process {i} killed.") + print("Killed all child processes.", flush=True) + + retvals = [] + stdouts = [] + stderrs = [] + for proc, fds, (stdout, stderr) in zip( + subprocesses, output_files, output_filenames + ): + retvals.append(proc.wait()) + for fd in fds: + fd.close() + stdouts.append(pathlib.Path(stdout).read_text(errors="replace")) + stderrs.append(pathlib.Path(stderr).read_text(errors="replace")) + + print(" All finished ".center(80, "="), flush=True) + + print(" Summary ".center(80, "=")) + for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)): + m = re.search(r"Ran \d+ tests? in [\d.]+s\n\n.*", stderr, re.MULTILINE) + result = m.group().replace("\n\n", "; ") if m else "Test crashed?" 
+ print( + f"Process {i}, ret: {retval}, len(stdout): {len(stdout)}, " + f"len(stderr): {len(stderr)}; {result}" + ) + + print(" Detailed logs ".center(80, "=")) + for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)): + print(f" Process {i}: return code: {retval} ".center(80, "=")) + if stdout: + print(f" Process {i} stdout ".center(80, "-")) + print(stdout) + if stderr: + print(f" Process {i} stderr ".center(80, "-")) + print(stderr) + + print(" Done detailed logs ".center(80, "="), flush=True) + for i, (retval, stderr) in enumerate(zip(retvals, stderrs)): + if retval != 0: + if expect_failures_with_regex is not None: + assert re.search( + expect_failures_with_regex, stderr + ), f"process {i} failed, expected regex: {expect_failures_with_regex}" + else: + assert retval == 0, f"process {i} failed, return value: {retval}" + + +class MultiProcessTest(absltest.TestCase): + + def setUp(self): + """Start tests together.""" + super().setUp() + assert jax.process_count() == _NUM_PROCESSES.value, ( + jax.process_count(), + _NUM_PROCESSES.value, + ) + # Make sure all processes are at the same test case. + client = distributed.global_state.client + try: + client.wait_at_barrier(self._testMethodName + "_start", 10000) + except jax.errors.JaxRuntimeError as e: + msg, *_ = e.args + if msg.startswith("DEADLINE_EXCEEDED"): + raise RuntimeError( + f"Init or some test executed earlier than {self._testMethodName} " + "failed. Check logs from earlier tests to debug further. We " + "recommend debugging that specific failed test with " + "`--test_filter` before running the full test suite again." + ) from e + + def tearDown(self): + """End tests together.""" + client = distributed.global_state.client + # Ensure a shared fate for tests where a subset of processes run different + # test assertions (i.e. some processes may pass and some processes fail - + # but the overall test should fail). + try: + client.wait_at_barrier(self._testMethodName + "_end", 10000) + except jax.errors.JaxRuntimeError as e: + msg, *_ = e.args + if msg.startswith("DEADLINE_EXCEEDED"): + raise RuntimeError( + f"Test {self._testMethodName} failed in another process. We " + "recommend debugging that specific failed test with " + "`--test_filter` before running the full test suite again." + ) from e + super().tearDown() diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 16b667690de4..074890d1816d 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -27,6 +27,7 @@ TMEM as TMEM, Union as Union, as_gpu_kernel as as_gpu_kernel, + supports_cross_device_collectives as supports_cross_device_collectives, ) from .launch_context import ( diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index dd8996dc0f6f..28fef9d0e96d 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -40,6 +40,7 @@ from jaxlib.mlir.dialects import nvvm import numpy as np + # mypy: ignore-errors from . 
import dialect_lowering @@ -82,15 +83,33 @@ # Set this so that the custom call can find it os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) -if os.environ.get("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH") is None: - try: - from nvidia import nvshmem - except ImportError: - pass - else: - os.environ["MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"] = ( - os.path.join(nvshmem.__path__[0], 'lib/libnvshmem_device.bc') + +try: + from nvidia import nvshmem +except ImportError: + pass +else: + if os.environ.get("MOSAIC_GPU_NVSHMEM_BC_PATH") is None: + os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] = os.path.join( + nvshmem.__path__[0], "lib/libnvshmem_device.bc" ) + if os.environ.get("MOSAIC_GPU_NVSHMEM_SO_PATH") is None: + os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] = os.path.join( + nvshmem.__path__[0], "lib/libnvshmem_host.so.3" + ) + + +def supports_cross_device_collectives(): + try: + nvshmem_bc_path = os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] + except KeyError: + return False + xla_flags = os.environ.get("XLA_FLAGS", "") + return ( + os.path.exists(nvshmem_bc_path) + and "--xla_gpu_experimental_enable_nvshmem" in xla_flags + ) + mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True @@ -389,6 +408,7 @@ def _launch( block: tuple[int, int, int], smem_buffers: ShapeTree | Union[ShapeTree], lowering_semantics: LoweringSemantics, + module: ir.Module, profiler_spec: profiler.ProfilerSpec | None = None, maybe_prof_buffer: ir.Value | None = None, ): @@ -456,7 +476,7 @@ def _launch( prof = None ctx = launch_context.LaunchContext( - launch_context.Scratch(launch_op), cluster, prof + module, launch_context.Scratch(launch_op), cluster, prof ) with ctx.named_region("Init"): tmem_allocs: list[_TMEMAlloc] = [] @@ -557,7 +577,7 @@ def main(token_ptr, buffers): prof_buffer = out_refs.pop() if prof_spec is not None else None with _launch( token, grid, cluster, block, smem_scratch_shape, - lowering_semantics, prof_spec, prof_buffer + lowering_semantics, module, prof_spec, prof_buffer ) as (_launch_ctx, smem_refs): nonlocal launch_ctx launch_ctx = _launch_ctx @@ -647,6 +667,9 @@ def as_gpu_kernel( launch_ctx.scratch.finalize_size() module.operation.verify() + if launch_ctx.is_device_collective and not supports_cross_device_collectives(): + raise RuntimeError("Kernel is a cross-device collective but no support is available.") + expected_arg_treedef = jax.tree.structure(in_shape) def _check_args(*args): arg_treedef = jax.tree.structure(args) @@ -735,6 +758,9 @@ def as_torch_gpu_kernel( launch_ctx.scratch.finalize_size() module.operation.verify() + if launch_ctx.is_device_collective: + raise RuntimeError("Kernel is a cross-device collective but no support is available.") + # Get our hands on the compilation and unload functions try: import jax_plugins.xla_cuda12 as cuda_plugin diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index fbefd027b53e..d169c448a80e 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -323,6 +323,7 @@ def finalize_size(self): @dataclasses.dataclass() class LaunchContext: + module: ir.Module scratch: Scratch cluster_size: tuple[int, int, int] profiler: OnDeviceProfiler | None = None @@ -330,6 +331,7 @@ class LaunchContext: tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], ir.Value, ] = dataclasses.field(default_factory=dict, init=False) + is_device_collective: bool = False @contextlib.contextmanager def named_region(self, *args, 
**kwargs):
@@ -845,3 +847,34 @@ def await_async_copy(
   ):
     nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
     utils.warpgroup_barrier()
+
+  def _ensure_nvshmem_decls(self):
+    if self.is_device_collective:
+      return
+    self.is_device_collective = True
+    with ir.InsertionPoint(self.module.body):
+      nvshmem_my_pe_type = ir.TypeAttr.get(ir.Type.parse("!llvm.func<i32 ()>"))
+      llvm.LLVMFuncOp(
+          "nvshmem_my_pe", nvshmem_my_pe_type, sym_visibility="private"
+      )
+      nvshmem_ptr_type = ir.TypeAttr.get(
+          ir.Type.parse("!llvm.func<ptr (ptr, i32)>")
+      )
+      llvm.LLVMFuncOp("nvshmem_ptr", nvshmem_ptr_type, sym_visibility="private")
+
+  def to_remote(self, ref: ir.Value, peer: ir.Value):
+    self._ensure_nvshmem_decls()
+    if ir.MemRefType.isinstance(ref.type):
+      return utils.ptr_as_memref(
+          self.to_remote(utils.memref_ptr(ref), peer), ref.type
+      )
+    if ref.type != ir.Type.parse("!llvm.ptr"):
+      raise ValueError(f"Unsupported type for to_remote: {ref.type}")
+    if peer.type != ir.IntegerType.get_signless(32):
+      raise ValueError(f"peer index must be an i32, got {peer.type}")
+    return llvm.call(ref.type, [ref, peer], [], [], callee="nvshmem_ptr")
+
+  def device_id(self) -> ir.Value:
+    self._ensure_nvshmem_decls()
+    i32 = ir.IntegerType.get_signless(32)
+    return llvm.call(i32, [], [], [], callee="nvshmem_my_pe")
diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py
index ecbe28bbcc20..dd1bd3aba4bd 100644
--- a/jax/experimental/pallas/mosaic_gpu.py
+++ b/jax/experimental/pallas/mosaic_gpu.py
@@ -24,7 +24,9 @@ from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
 from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace
 from jax._src.pallas.mosaic_gpu.core import kernel as kernel
+from jax._src.pallas.mosaic_gpu.core import PeerMemRef as PeerMemRef
 from jax._src.pallas.mosaic_gpu.core import RefUnion as RefUnion
+from jax._src.pallas.mosaic_gpu.core import remote_ref as remote_ref
 from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType
 from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
 from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py
index 4b442f3aa7de..acd82702b357 100644
--- a/jax_plugins/cuda/plugin_setup.py
+++ b/jax_plugins/cuda/plugin_setup.py
@@ -71,6 +71,9 @@ def has_ext_modules(self):
             # Until NVIDIA add version constraints, add a version constraint
             # here.
             "nvidia-nvjitlink-cu12>=12.1.105",
+            # NVSHMEM is used by Mosaic GPU collectives and can be used by XLA to
+            # speed up collectives too.
+            "nvidia-nvshmem-cu12>=3.2.5",
         ],
     },
     url="https://github.com/jax-ml/jax",
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index f8c89c0e0401..a8fe2b50344b 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -100,6 +100,7 @@ _py_deps = {
     "tensorstore": get_optional_dep("@pypi//tensorstore"),
     "torch": [],
     "zstandard": get_zstandard(),
+    "libnvshmem_device": ["@pypi//nvidia_nvshmem_cu12"],
 }

 def all_py_deps(excluded = []):
diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index 5fab85d2b77c..b694258fed1e 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -123,17 +123,22 @@ cc_library(
     name = "runtime",
     srcs = ["runtime.cc"],
     # Linker may prune these symbols if they are not explicitly exported.
- linkopts = ["-Wl,--export-dynamic-symbol='mosaic_gpu_*'"], + linkopts = [ + "-Wl,--export-dynamic-symbol='mosaic_gpu_*'", + "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", + "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", + "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", + ], deps = [ - ":mosaic_gpu_comm", + ":nvshmem", "@local_config_cuda//cuda:cuda_headers", ], alwayslink = True, ) cc_library( - name = "mosaic_gpu_comm", - hdrs = ["mosaic_gpu_comm.h"], + name = "nvshmem", + hdrs = ["nvshmem.h"], deps = [ "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cudart", @@ -144,7 +149,7 @@ cc_library( name = "custom_call", srcs = ["custom_call.cc"], deps = [ - ":mosaic_gpu_comm", + ":nvshmem", ":passes", ":target", "//jaxlib/cuda:cuda_vendor", @@ -236,7 +241,7 @@ cc_binary( "notap", ], deps = [ - "//jaxlib/mosaic/gpu:mosaic_gpu_comm", + ":nvshmem", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cudart", ], diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index ca95080d5669..7c93d54aff9e 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -91,7 +91,7 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" -#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h" +#include "jaxlib/mosaic/gpu/nvshmem.h" #include "jaxlib/mosaic/gpu/passes.h" #include "jaxlib/mosaic/gpu/serde.h" #include "jaxlib/mosaic/gpu/target.h" @@ -400,8 +400,9 @@ bool is_nvshmem_used(mlir::ModuleOp module) { } absl::StatusOr get_nvshmem_llvm_lib_path() { - const char * nvshmem_path_ptr = getenv("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"); - if (!nvshmem_path_ptr) return absl::InternalError("Failed to get MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"); + const char* nvshmem_path_ptr = getenv("MOSAIC_GPU_NVSHMEM_BC_PATH"); + if (!nvshmem_path_ptr) + return absl::InternalError("Failed to get MOSAIC_GPU_NVSHMEM_BC_PATH"); return nvshmem_path_ptr; } @@ -418,6 +419,11 @@ absl::StatusOr, bool>> Compile( std::string nvshmem_path = ""; if (is_comm_used) { TF_ASSIGN_OR_RETURN(nvshmem_path, get_nvshmem_llvm_lib_path()); + if (!mosaic::gpu::NvshmemApi::Default(/*assert_ok=*/false).is_loaded()) { + return absl::InternalError( + "Failed to load the NVSHMEM library. Make sure it is installed (e.g. " + "`pip install nvidia-nvshmem-cu12`)."); + } } DumpCompilationOutput(module, sm, ptx_isa, nvshmem_path); auto passes = GetPassPipeline( diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_comm.h b/jaxlib/mosaic/gpu/nvshmem.h similarity index 76% rename from jaxlib/mosaic/gpu/mosaic_gpu_comm.h rename to jaxlib/mosaic/gpu/nvshmem.h index 1aa15f9307c7..dbd11aa1d373 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_comm.h +++ b/jaxlib/mosaic/gpu/nvshmem.h @@ -17,14 +17,15 @@ limitations under the License. #define JAXLIB_MOSAIC_GPU_COMM_H_ #include -#include + #include +#include +#include #include "third_party/gpus/cuda/include/cuda.h" #include "cuda_runtime_api.h" #define NVSHMEM_SUCCESS 0 -#define NVSHMEM_LIB_SONAME "libnvshmem_host.so.3" namespace mosaic { namespace gpu { @@ -33,19 +34,22 @@ namespace gpu { FnName = reinterpret_cast(dlsym(library, #FnName)); \ if (!FnName) { \ fprintf(stderr, #FnName " not available in this library."); \ - abort(); \ } class NvshmemApi { public: // Returns a default NvshmemApi for a current process. 
// NvshmemApi follows the Singleton design pattern - static NvshmemApi& Default() { + static NvshmemApi& Default(bool assert_ok = true) { static NvshmemApi instance; + if (assert_ok && !instance.is_loaded()) { + fprintf(stderr, "Failed to load the NVSHMEM library.\n"); + abort(); + } return instance; } - int cumodule_int(CUmodule module) { + int cumodule_init(CUmodule module) { std::lock_guard lock(mutex_); return nvshmemx_cumodule_init(module); } @@ -54,28 +58,32 @@ class NvshmemApi { nvshmemx_barrier_all_on_stream(stream); } + bool is_loaded() { + return nvshmemx_init_status != nullptr && nvshmemx_init_status() == 2; + } + NvshmemApi(NvshmemApi const&) = delete; void operator=(NvshmemApi const&) = delete; private: NvshmemApi() { - const char* env_value = getenv("NVSHMEM_LIBRARY_PATH"); + const char* env_value = getenv("MOSAIC_GPU_NVSHMEM_SO_PATH"); const char* libnvshmem_path = - env_value && *env_value != 0 ? env_value : NVSHMEM_LIB_SONAME; + env_value && *env_value != 0 ? env_value : nullptr; void* library = dlopen(libnvshmem_path, RTLD_LAZY); if (library == nullptr) { - fprintf(stderr, "Failed to open %s library: %s", libnvshmem_path, dlerror()); - abort(); + fprintf(stderr, "Failed to open library (from %s): %s", + libnvshmem_path ? libnvshmem_path : "", dlerror()); } - // Initialize supported NVSHMEM host API - NVSHMEM_SET_FN(nvshmemx_cumodule_init) NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) + NVSHMEM_SET_FN(nvshmemx_cumodule_init) + NVSHMEM_SET_FN(nvshmemx_init_status) } - // Dlopened NVSHMEM API - int (*nvshmemx_cumodule_init)(CUmodule); int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); + int (*nvshmemx_cumodule_init)(CUmodule); + int (*nvshmemx_init_status)(); std::mutex mutex_; }; diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index d03aa51b3124..cb48a20dc3d5 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -17,9 +17,8 @@ limitations under the License. #include #include -#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h" #include "third_party/gpus/cuda/include/cuda.h" - +#include "jaxlib/mosaic/gpu/nvshmem.h" extern "C" { @@ -177,13 +176,16 @@ void* mosaic_gpu_module_load(void *data) { abort(); } - CUdeviceptr ptr = 0; - size_t size = 0; - // Check if module contains NVSHMEM globals implying NVSHMEM state needs to set - if (cuModuleGetGlobal(&ptr, &size, module, "nvshmemi_device_lib_version_d") == CUDA_SUCCESS) { - if (mosaic::gpu::NvshmemApi::Default().cumodule_int(module) != NVSHMEM_SUCCESS) { - fprintf(stderr, "nvshmemx_cumodule_init failed.\n"); - abort(); + { // Set the NVSHMEM state if it's used by the module. 
+ CUdeviceptr ptr = 0; + size_t size = 0; + if (cuModuleGetGlobal(&ptr, &size, module, + "nvshmemi_device_lib_version_d") == CUDA_SUCCESS) { + if (mosaic::gpu::NvshmemApi::Default().cumodule_init(module) != + NVSHMEM_SUCCESS) { + fprintf(stderr, "nvshmemx_cumodule_init failed.\n"); + abort(); + } } } diff --git a/pyproject.toml b/pyproject.toml index 03e4ec0c9ffe..03cc78a6dcbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ module = [ "numpy.*", "opt_einsum.*", "optax.*", + "portpicker.*", "pygments.*", "pytest.*", "rich.*", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index ed7e96ad1d02..3769da27a1eb 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -363,6 +363,27 @@ jax_multiplatform_test( ], ) +jax_multiplatform_test( + name = "gpu_pallas_distributed_test", + srcs = ["gpu_pallas_distributed_test.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = ["gpu_h100x2"], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + }, + tags = ["multiaccelerator"], + deps = [ + "//jax:extend", + "//jax:pallas_mosaic_gpu", + "//jax:test_multiprocess", + ] + py_deps("portpicker"), +) + jax_multiplatform_test( name = "tpu_ops_test", srcs = [ diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py new file mode 100644 index 000000000000..882475406b28 --- /dev/null +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -0,0 +1,91 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for distributed pallas GPU operations.""" + +import functools +import jax +from jax import lax +from jax._src import test_util as jtu +from jax._src import test_multiprocess as jt_multiprocess +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import mosaic_gpu as plgpu +import jax.experimental.mosaic.gpu as mgpu +import jax.numpy as jnp +import numpy as np + + +P = jax.sharding.PartitionSpec +partial = functools.partial + + +class PallasCallRemoteDMATest(jt_multiprocess.MultiProcessTest): + + def setUp(self): + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + if not mgpu.supports_cross_device_collectives(): + self.skipTest("NVSHMEM library unavailable.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + super().setUp() + + def test_basic_remote_dma(self): + if jax.process_count() < 2: + self.skipTest("Test requires multiple processes.") + if jax.process_index() > 2: + return # Only 2 processes needed. + def kernel(x_ref, y_ref, ready_sem, recv_sem): + other_dev_id = 1 - lax.axis_index('x') + y_ref[...] = x_ref[...] 
+      pl.semaphore_signal(ready_sem, device_id=other_dev_id,
+                          device_id_type=pl.DeviceIdType.LOGICAL)
+      pl.semaphore_wait(ready_sem)
+      neighbor_ptr = plgpu.remote_ref(
+          y_ref, other_dev_id, device_id_type=pl.DeviceIdType.LOGICAL
+      )
+      neighbor_ptr[...] = x_ref[...]
+      pl.semaphore_signal(recv_sem, device_id=other_dev_id,
+                          device_id_type=pl.DeviceIdType.LOGICAL)
+      pl.semaphore_wait(recv_sem)
+
+    x = jnp.arange(2 * 8 * 128.0).reshape((2 * 8, 128))
+    def body(x):
+      return pl.pallas_call(
+          kernel,
+          in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+          out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+          scratch_shapes=[
+              plgpu.SemaphoreType.REGULAR,
+              plgpu.SemaphoreType.REGULAR,
+          ],
+      )(x)
+
+    devices = jax.devices()[:2]
+    mesh = jax.sharding.Mesh(devices, ['x'])
+    y = jax.jit(
+        shard_map.shard_map(
+            body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False,
+        )
+    )(x)
+
+    expected = x[8:] if jax.process_index() == 0 else x[:8]
+    np.testing.assert_allclose(y.addressable_shards[0].data, expected)
+
+
+if __name__ == '__main__':
+  jt_multiprocess.main()
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index b2325fb62dee..dcafa9b7277e 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -34,6 +34,7 @@ from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering
 from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
 from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives
+from jax._src.pallas import primitives as pallas_primitives
 from jax._src.state import types as state_types
 from jax.experimental import pallas as pl
 import jax.experimental.mosaic.gpu as mgpu
@@ -1708,6 +1709,9 @@ def test_missing_primitive_lowerings_are_tracked(self):
       mgpu_primitives.tcgen05_mma_p,
       lax.slice_p,
       pallas_core.core_map_p,
+      pallas_primitives.semaphore_signal_p,
+      pallas_primitives.semaphore_wait_p,
+      pallas_primitives.semaphore_read_p,
     }
     self.assertSetEqual(actual_missing_primitives, expected_missing_primitives)

From b3c49b0ecc3f24e13e54f582d8d845afa52d2a58 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 6 May 2025 07:02:33 -0700
Subject: [PATCH 1022/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/a6ac2e4648e217653453aab56320af252fb992b7.

PiperOrigin-RevId: 755356129
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index db143127ffe6..28071970c894 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 #    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 #    and update XLA_SHA256 with the result.

-XLA_COMMIT = "03d995ffbc7653bcc4fe0d477330422f139c634b"
-XLA_SHA256 = "a9c9245e4e9971f57e252ab023ed26304c3cb0ffc1164bf30231930a3096bc3f"
+XLA_COMMIT = "a6ac2e4648e217653453aab56320af252fb992b7"
+XLA_SHA256 = "fc49f1ba52cccae1397d15401920f10e0cc6b1959abf9da0e2d827c489d9923d"

 def repo():
     tf_http_archive(

From 9c65f62d98701f1935b487f98cf414f399e439e7 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 6 May 2025 08:08:52 -0700
Subject: [PATCH 1023/1769] Fix an x64 error in fused_attention_stablehlo

The code assumed that `jnp.arange(x)` would return an `int32` array when
`x` is an integer, but the return type actually depends on whether x64
mode is enabled. The fix is to specify the desired dtype explicitly.
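For illustration, here is a minimal sketch of the pitfall described above. It
is not part of the patch, and it assumes a stock JAX installation where the
`jax_enable_x64` flag controls the default integer width:

    import jax
    import jax.numpy as jnp

    # With x64 mode disabled (the default), an integer arange is int32.
    assert jnp.arange(3).dtype == jnp.int32

    # With x64 mode enabled, the very same call produces int64 instead,
    # which breaks code that assumed int32.
    jax.config.update("jax_enable_x64", True)
    assert jnp.arange(3).dtype == jnp.int64

    # Pinning the dtype makes the result independent of the mode, which is
    # what the fix below does for the offset computation.
    assert jnp.arange(3, dtype=jnp.int32).dtype == jnp.int32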
PiperOrigin-RevId: 755376676 --- jax/_src/cudnn/fused_attention_stablehlo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 39d97255fa4a..46df84e08e0f 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -478,7 +478,7 @@ def _cu_offset(offsets, max_seq): batch = offsets.shape[0] offsets = jnp.where( offsets >= 0, - offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis], + offsets + (jnp.arange(batch, dtype=offsets.dtype) * max_seq)[..., jnp.newaxis], offsets, ) return offsets From 958ea15a86c44c226624253a8c2ea6deecefa5a4 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 6 May 2025 08:57:33 -0700 Subject: [PATCH 1024/1769] Enable command buffer support for buffer callbacks. This enables `command_buffer_compatible=True` for buffer callbacks on GPU. PiperOrigin-RevId: 755391950 --- jax/_src/buffer_callback.py | 10 ++++++++++ jaxlib/cuda/cuda_plugin_extension.cc | 3 +++ jaxlib/gpu/py_client_gpu.cc | 23 ++++++++++++++++++++++- jaxlib/gpu/py_client_gpu.h | 1 + jaxlib/rocm/rocm_plugin_extension.cc | 3 +++ jaxlib/xla_client.py | 2 +- tests/buffer_callback_test.py | 17 ++++++++++++++--- 7 files changed, 54 insertions(+), 5 deletions(-) diff --git a/jax/_src/buffer_callback.py b/jax/_src/buffer_callback.py index 916ac57da2ad..739fdb4c408d 100644 --- a/jax/_src/buffer_callback.py +++ b/jax/_src/buffer_callback.py @@ -46,6 +46,7 @@ def buffer_callback( has_side_effect: bool = False, vmap_method: str | None = None, input_output_aliases: dict[int, int] | None = None, + command_buffer_compatible: bool = False, ): """An experimental callback that operates in place on device buffers. @@ -112,6 +113,10 @@ def callback(ctx: ExecutionContext, out, *args) -> None: input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. + command_buffer_compatible: if ``True``, the callback will be traced into + the command buffer. This means that the Python code should only be + executed once, and then the operations will be replayed for every + subsequent call. 
Returns: A new callable that accepts :class:`jax.Array` inputs (and pytrees thereof), @@ -167,6 +172,7 @@ def wrapped_callback(*args, **kwargs): vmap_method=vmap_method, has_side_effect=has_side_effect, input_output_aliases=static_input_output_aliases, + command_buffer_compatible=command_buffer_compatible, ) return tree_util.tree_unflatten(out_tree, out_flat) @@ -228,6 +234,7 @@ def _buffer_callback_lowering( out_tree: Any, has_side_effect: bool, input_output_aliases: Sequence[tuple[int, int]], + command_buffer_compatible: bool, **_, ): @@ -242,6 +249,9 @@ def _buffer_callback_lowering( if target_name is None: raise ValueError(f"`buffer_callback` not supported on {platform} backend.") + if command_buffer_compatible and platform in ("cuda", "rocm"): + target_name += "_cmd_buffer" + def wrapped_callback(exec_ctx, *args: Any): args_in, args_out = util.split_list(args, [in_tree.num_leaves]) py_args_in, py_kwargs_in = tree_util.tree_unflatten(in_tree, args_in) diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index d7500b711e48..383bbf7731aa 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -53,6 +53,9 @@ nb::dict FfiRegistrations() { dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; dict["xla_buffer_python_gpu_callback"] = jax::EncapsulateFfiHandler(jax::cuda::kXlaBufferPythonGpuCallback); + dict["xla_buffer_python_gpu_callback_cmd_buffer"] = + jax::EncapsulateFfiHandler( + jax::cuda::kXlaBufferPythonGpuCallbackCmdBuffer); return dict; } diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 0afa3a9bf1d5..cb618890023b 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -35,8 +35,8 @@ limitations under the License. 
#include "absl/types/span.h" #include "include/dlpack/dlpack.h" #include "nanobind/nanobind.h" -#include "jaxlib/gpu/vendor.h" #include "jaxlib/ffi.h" +#include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/pjrt/host_callback.h" @@ -279,10 +279,31 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .RemainingArgs() .RemainingRets()); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaBufferPythonGpuCallbackCmdBuffer, +#ifdef JAX_GPU_CUDA + (jax::XlaBufferCallback), +#else + (jax::XlaBufferCallback), +#endif + xla::ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets(), + {ffi::Traits::kCmdBufferCompatible}); + XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "xla_buffer_python_gpu_callback", absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), kXlaBufferPythonGpuCallback); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_buffer_python_gpu_callback_cmd_buffer", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaBufferPythonGpuCallbackCmdBuffer); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index b389dd393443..0df0891ceae5 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -24,6 +24,7 @@ namespace JAX_GPU_NAMESPACE { XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaBufferPythonGpuCallback); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaBufferPythonGpuCallbackCmdBuffer); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 74013ed0de68..37ae638a47fc 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -77,6 +77,9 @@ nb::dict FfiRegistrations() { dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; dict["xla_buffer_python_gpu_callback"] = jax::EncapsulateFfiHandler(jax::hip::kXlaBufferPythonGpuCallback); + dict["xla_buffer_python_gpu_callback_cmd_buffer"] = + jax::EncapsulateFfiHandler( + jax::hip::kXlaBufferPythonGpuCallbackCmdBuffer); return dict; } diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 6d03f4530631..a77a8226d944 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 336 +_version = 337 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py index acc414bf7fd4..e77ee4af687f 100644 --- a/tests/buffer_callback_test.py +++ b/tests/buffer_callback_test.py @@ -97,9 +97,14 @@ def callback(ctx, out, arg): # DLPack tensors. 
      jax.block_until_ready(fun(data))

-  @parameterized.parameters(jtu.dtypes.all)
+  @parameterized.product(
+      dtype=jtu.dtypes.all, command_buffer_compatible=[True, False]
+  )
   @jtu.run_on_devices("cuda")
-  def test_cuda_array_interface(self, dtype):
+  def test_cuda_array_interface(self, dtype, command_buffer_compatible):
+    if command_buffer_compatible and jaxlib_extension_version < 337:
+      self.skipTest("Requires jaxlib extension version of at least 337.")
+
     def callback(ctx, out, arg):
       ctx.stream  # doesn't crash

@@ -121,8 +126,14 @@ def callback(ctx, out, arg):
     shape = (3, 4)
     data = rng(shape, dtype)
     fun = buffer_callback.buffer_callback(
-        callback, jax.ShapeDtypeStruct(data.shape, data.dtype)
+        callback, jax.ShapeDtypeStruct(data.shape, data.dtype),
+        command_buffer_compatible=command_buffer_compatible,
     )
+
+    # TODO: There's an XLA:GPU/CUDA bug that causes a segfault when
+    # instantiating an empty CUDA graph. Once that bug is fixed or worked
+    # around, add a test that checks that the Python callback is only executed
+    # once.
     jax.block_until_ready(fun(data))

   @parameterized.parameters([

From 579a1573cc82265c2b3245f2ee9ffd33f55fcd63 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Tue, 6 May 2025 09:03:30 -0700
Subject: [PATCH 1025/1769] [pallas:mosaic] Handle more types in `ir_constant`

PiperOrigin-RevId: 755394119
---
 jax/_src/pallas/mosaic/lowering.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index a920f0363470..e3e3fdeeb32b 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -311,11 +311,9 @@ def ir_constant(x, mlir_type=None):
     x = np.array(x, np.float32)
   if not mlir_type:
     mlir_type = _dtype_to_ir_type(x.dtype)
-  if isinstance(x, int) or np.issubdtype(x.dtype, np.integer):
+  if isinstance(x, int) or jnp.issubdtype(x.dtype, np.integer):
     return arith.constant(mlir_type, ir.IntegerAttr.get(mlir_type, int(x)))
-  elif isinstance(x, float) or x.dtype == np.float32:
-    return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x)))
-  elif x.dtype == jnp.bfloat16:
+  elif isinstance(x, float) or jnp.issubdtype(x.dtype, jnp.floating):
     return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x)))
   elif x.dtype == jnp.bool_:
     return arith.constant(mlir_type, ir.BoolAttr.get(bool(x)))

From cf9b265c0b7512e43f26c69761a0cdcab5d12112 Mon Sep 17 00:00:00 2001
From: Jacob Burnim
Date: Tue, 6 May 2025 11:27:50 -0700
Subject: [PATCH 1026/1769] Avoid an unlucky seed for some random categorical
 tests.

PiperOrigin-RevId: 755450699
---
 jax/_src/internal_test_util/test_harnesses.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py
index b557434ac7f3..7445b9cfcb6f 100644
--- a/jax/_src/internal_test_util/test_harnesses.py
+++ b/jax/_src/internal_test_util/test_harnesses.py
@@ -2744,7 +2744,8 @@ def wrap_and_split():
       "random_categorical",
       f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{axis=}",
       lambda x, axis: jax.random.categorical(
-          jax.random.key(42), x, axis),
+          # TODO(b/416027995): Change this key back to 42.
+ jax.random.key(1337), x, axis), [RandArg(shape, dtype), StaticArg(axis)], dtype=dtype, From 8c113bd1775c1cd8f62d05ebb0daa4df1d00c5a0 Mon Sep 17 00:00:00 2001 From: chaserileyroberts Date: Mon, 5 May 2025 17:39:34 -0700 Subject: [PATCH 1027/1769] Generalized test, removed specific call to config --- tests/multiprocess_gpu_test.py | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index fe9922148ab4..20a2b9ba972b 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -106,29 +106,17 @@ def test_distributed_jax_visible_devices(self): env["JAX_PORT"] = str(port) env["NUM_TASKS"] = str(num_tasks) env["TASK"] = str(task) - visible_devices = ",".join( - str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task)) - - if jtu.is_device_rocm(): - program = ( - 'import jax, os; ' - f'jax.config.update("jax_rocm_visible_devices", "{visible_devices}"); ' - 'jax.distributed.initialize(' - 'f\'localhost:{os.environ["JAX_PORT"]}\', ' - 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' - 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' - ) - else: - program = ( - 'import jax, os; ' - f'jax.config.update("jax_cuda_visible_devices", "{visible_devices}"); ' - 'jax.distributed.initialize(' - 'f\'localhost:{os.environ["JAX_PORT"]}\', ' - 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' - 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' - ) + visible_devices = [ + (task * num_gpus_per_task) + i for i in range(num_gpus_per_task) + ] + program = ( + 'import jax, os; ' + 'jax.distributed.initialize(' + 'f\'localhost:{os.environ["JAX_PORT"]}\', ' + f'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"]), {visible_devices}); ' + 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' + 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' + ) args = [sys.executable, "-c", program] proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) From 439f0f14318ff970204cb2f0f093641468a2e992 Mon Sep 17 00:00:00 2001 From: Alina Sbirlea Date: Tue, 6 May 2025 12:40:14 -0700 Subject: [PATCH 1028/1769] Integrate LLVM at llvm/llvm-project@2d287f51eff2 Updates LLVM usage to match [2d287f51eff2](https://github.com/llvm/llvm-project/commit/2d287f51eff2) PiperOrigin-RevId: 755479536 --- jaxlib/mosaic/dialect/tpu/transforms/communication.cc | 8 ++++++-- jaxlib/mosaic/gpu/launch_lowering.cc | 11 ++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc index 7e99dd15611b..dfe42111916c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc @@ -111,8 +111,12 @@ struct LogicalToPhysicalDeviceIdPass {total_devices}, IntegerType::get(func.getContext(), 32), TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}), MemorySpaceAttr::get(func.getContext(), MemorySpace::smem)); - func.insertArgument(func.getNumArguments(), device_assignment_type, - nullptr, 
UnknownLoc::get(func.getContext())); + + if (failed(func.insertArgument(func.getNumArguments(), + device_assignment_type, nullptr, + UnknownLoc::get(func.getContext())))) { + return signalPassFailure(); + } auto device_assignment_arg = func.getArgument(func.getNumArguments() - 1); func.walk([device_assignment_arg](Operation *some_op) { if (auto op = dyn_cast(some_op)) { diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 53d4f47e58cc..44362e825345 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -238,7 +238,7 @@ mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, cluster = as_32bit(launch.getClusterSizeOperandValues()); } else { cluster.x = cluster.y = cluster.z = builder.create( - launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); + launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); } mlir::Value stream = launch.getAsyncObject(); builder.create( @@ -337,15 +337,16 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { launch.getDynamicSharedMemorySize(), cluster_shape); // Add a new function argument for the kernel handle. - func.insertArgument(0, ptr_ty, - mlir::DictionaryAttr::get(func.getContext()), - mlir::UnknownLoc::get(func.getContext())); + if (failed(func.insertArgument( + 0, ptr_ty, mlir::DictionaryAttr::get(func.getContext()), + mlir::UnknownLoc::get(func.getContext())))) { + return mlir::WalkResult::interrupt(); + } mlir::Value kernel_handle = func.getArgument(0); if (launchPreloadedKernel(func, launch, kernel_handle).failed()) { return mlir::WalkResult::interrupt(); } launch.erase(); - // TODO(apaszke): Generate a destructor function. // builder.CreateCall(getModuleUnloadFn(), {moduleObject}); From 35b25c914a62582968e44f5d02dcdccee667072c Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Tue, 6 May 2025 16:36:49 -0700 Subject: [PATCH 1029/1769] Use `LoadedExecutableRef` instead of `std::unique_ptr` PiperOrigin-RevId: 755573315 --- jaxlib/py_client.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index 6caf478a4324..2cc84f0bf86c 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -446,7 +446,7 @@ PyClient::CompileIfrtProgram( } } - std::unique_ptr ifrt_loaded_executable; + ifrt::LoadedExecutableRef ifrt_loaded_executable; std::optional fingerprint; { nb::gil_scoped_release gil_release; @@ -529,7 +529,7 @@ PyClient::DeserializeExecutable(nb_class_ptr client, ifrt::DeviceListRef executable_devices, std::optional options, std::vector host_callbacks) { - std::unique_ptr ifrt_loaded_executable; + ifrt::LoadedExecutableRef ifrt_loaded_executable; std::optional fingerprint; auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( std::move(options), std::move(executable_devices), From 51cd3c3b33a608909233f31fef228c285743b33f Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Tue, 6 May 2025 22:21:12 -0700 Subject: [PATCH 1030/1769] Temporarily roll back changes for new LLVM version Reverts 439f0f14318ff970204cb2f0f093641468a2e992 PiperOrigin-RevId: 755680460 --- jaxlib/mosaic/dialect/tpu/transforms/communication.cc | 8 ++------ jaxlib/mosaic/gpu/launch_lowering.cc | 11 +++++------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc index dfe42111916c..7e99dd15611b 100644 --- 
a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc @@ -111,12 +111,8 @@ struct LogicalToPhysicalDeviceIdPass {total_devices}, IntegerType::get(func.getContext(), 32), TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}), MemorySpaceAttr::get(func.getContext(), MemorySpace::smem)); - - if (failed(func.insertArgument(func.getNumArguments(), - device_assignment_type, nullptr, - UnknownLoc::get(func.getContext())))) { - return signalPassFailure(); - } + func.insertArgument(func.getNumArguments(), device_assignment_type, + nullptr, UnknownLoc::get(func.getContext())); auto device_assignment_arg = func.getArgument(func.getNumArguments() - 1); func.walk([device_assignment_arg](Operation *some_op) { if (auto op = dyn_cast(some_op)) { diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 44362e825345..53d4f47e58cc 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -238,7 +238,7 @@ mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, cluster = as_32bit(launch.getClusterSizeOperandValues()); } else { cluster.x = cluster.y = cluster.z = builder.create( - launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); + launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); } mlir::Value stream = launch.getAsyncObject(); builder.create( @@ -337,16 +337,15 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { launch.getDynamicSharedMemorySize(), cluster_shape); // Add a new function argument for the kernel handle. - if (failed(func.insertArgument( - 0, ptr_ty, mlir::DictionaryAttr::get(func.getContext()), - mlir::UnknownLoc::get(func.getContext())))) { - return mlir::WalkResult::interrupt(); - } + func.insertArgument(0, ptr_ty, + mlir::DictionaryAttr::get(func.getContext()), + mlir::UnknownLoc::get(func.getContext())); mlir::Value kernel_handle = func.getArgument(0); if (launchPreloadedKernel(func, launch, kernel_handle).failed()) { return mlir::WalkResult::interrupt(); } launch.erase(); + // TODO(apaszke): Generate a destructor function. // builder.CreateCall(getModuleUnloadFn(), {moduleObject}); From 2ebb756056396e825c32b493ea242cd6c9db0765 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 6 May 2025 23:46:45 -0700 Subject: [PATCH 1031/1769] [Mosaic GPU] Extract duplicated code into a `_transforms_from_uses` function. 
PiperOrigin-RevId: 755706701 --- .../mosaic/gpu/transform_inference.py | 57 ++++++------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index d4810845f2bc..b7d06822e0e9 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -92,6 +92,18 @@ def _resolve_transforms( return transforms +def _transforms_from_uses(op: ir.OpView) -> ir.Attribute | None: + transforms = None + + for result_use in cast(ir.OpResult, op.result).uses: + consumer = result_use.owner + op_user = consumer.operands[result_use.operand_number] + user_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + transforms = _resolve_transforms(transforms, user_transforms) + return transforms + def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: if len(ref_ty.shape) != 2: raise ValueError(f"Expected a 2D memref, got {ref_ty}") @@ -208,17 +220,7 @@ def _infer_memref_store_transforms(op: memref.StoreOp) -> OptionalTransforms: @partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: - transforms = None - uses = cast(ir.OpResult, op.result).uses - - for op_operand_use in uses: - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - out_transforms = inference_utils.in_transforms_for_operand( - consumer, op_user - ) - transforms = _resolve_transforms(transforms, out_transforms) - + transforms = _transforms_from_uses(op) return None if transforms is None else ([], [transforms]) @@ -245,15 +247,7 @@ def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: raise NotImplementedError( "memref view with in_transforms aren't yet supported" ) - uses = cast(ir.OpResult, op.result).uses - - for op_operand_use in uses: - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - out_transforms = inference_utils.in_transforms_for_operand( - consumer, op_user - ) - transforms = _resolve_transforms(transforms, out_transforms) + transforms = _transforms_from_uses(op) # TODO(bchetioui): do we actually need to assign a transform to the input of # the view op? Presumably, it'll only be used to access scratch memory. 
@@ -292,16 +286,7 @@ def _get_tile_and_swizzle_transforms( def _infer_memref_subview_transforms( op: memref.SubViewOp, ) -> OptionalTransforms: - transforms = None - - for result_use in cast(ir.OpResult, op.result).uses: - consumer = result_use.owner - op_user = consumer.operands[result_use.operand_number] - user_transforms = inference_utils.in_transforms_for_operand( - consumer, op_user - ) - transforms = _resolve_transforms(transforms, user_transforms) - + transforms = _transforms_from_uses(op) in_transforms = inference_utils.value_transforms(op.source) transforms = _resolve_transforms(transforms, in_transforms) @@ -353,17 +338,7 @@ def _infer_memref_transpose_transforms( out_strides, _ = ir.MemRefType(op.result.type).get_strides_and_offset() transpose = in_strides != out_strides - users = list(op.result.uses) - if len(users) != 1: - raise NotImplementedError( - f"Only memref.transpose with a single use are supported, got {op}" - ) - - op_operand_use = users[0] - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - out_transforms = inference_utils.in_transforms_for_operand(consumer, op_user) - + out_transforms = _transforms_from_uses(op) in_transforms = [] if not transpose: in_transforms = out_transforms From 698404f4ad942efa7d36505904a462ae3887a129 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 7 May 2025 01:55:16 -0700 Subject: [PATCH 1032/1769] [Mosaic GPU] Implement a trivial pass-through transform inference for `memref.cast` PiperOrigin-RevId: 755747438 --- .../mosaic/gpu/transform_inference.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index b7d06822e0e9..b032992f88d4 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -363,6 +363,23 @@ def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms: raise NotImplementedError("Non-scalar memref.load transforms") +@partial(_add_transform_inference_rule, memref.CastOp) +def _infer_memref_cast_transforms( + op: memref.CastOp, +) -> OptionalTransforms: + if inference_utils.has_in_transforms_set( + op + ) and inference_utils.has_out_transforms_set(op): + return inference_utils.in_transforms(op), inference_utils.out_transforms(op) + + transforms = _transforms_from_uses(op) + in_transforms = inference_utils.value_transforms(op.source) + transforms = _resolve_transforms(transforms, in_transforms) + if transforms is None: + return None + return [transforms], [transforms] + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( From 011ec37867498c41f1c15726e66adaf02a952306 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 7 May 2025 04:39:09 -0700 Subject: [PATCH 1033/1769] [pallas:mosaic] Do not use *Op classes for creating MLIR ops unless necessary In most cases we just want the result of an op, so we could use e.g. arith.mulf instead of arith.MulFOp. 
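To make the distinction concrete, here is a rough, self-contained sketch of
the two builder styles in the MLIR Python bindings. It is illustrative only
(not part of the patch) and assumes the `jaxlib.mlir` bindings used by this
file:

    from jaxlib.mlir import ir
    from jaxlib.mlir.dialects import arith

    with ir.Context(), ir.Location.unknown():
      module = ir.Module.create()
      with ir.InsertionPoint(module.body):
        f32 = ir.F32Type.get()
        one = arith.constant(f32, ir.FloatAttr.get(f32, 1.0))

        # Op-class form: constructs the op object, and the result value
        # still has to be unwrapped from it.
        mul_op = arith.MulFOp(one, one)
        out = mul_op.result

        # Wrapper form: the lower-case builder returns the result ir.Value
        # directly, which is what most lowering rules actually want.
        out = arith.mulf(one, one)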
PiperOrigin-RevId: 755796318 --- jax/_src/pallas/mosaic/lowering.py | 42 ++++++++++++++++-------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index e3e3fdeeb32b..8adb1b8cc1fc 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -166,7 +166,6 @@ class LoweringContext: block_shapes: list[tuple[int | pallas_core.Squeezed, ...]] name_stack: source_info_util.NameStack mesh_context: MeshContext | None - replace = dataclasses.replace traceback_caches: mlir.TracebackCaches for_verification: bool forward_compatible: bool @@ -174,6 +173,8 @@ class LoweringContext: [tuple[jax.DimSize, ...]], tuple[int, ...] ] + replace = dataclasses.replace + @property def grid_rank(self): return len(self.grid_sizes) @@ -199,6 +200,7 @@ class LoweringRuleContext: avals_in: Sequence[jax_core.AbstractValue] avals_out: Sequence[jax_core.AbstractValue] block_shapes: Sequence[tuple[int | pallas_core.Squeezed, ...] | None] + replace = dataclasses.replace @property @@ -1126,9 +1128,9 @@ def write_env(var: jax_core.Var, val): current_name_stack, name_stack) current_name_stack = name_stack for _ in popped: - tpu.TraceStopOp() + tpu.trace_stop() for name in pushed: - tpu.TraceStartOp(message=name, level=10) + tpu.trace_start(message=name, level=10) try: ans = lowering_rules[eqn.primitive]( @@ -1163,7 +1165,7 @@ def write_env(var: jax_core.Var, val): popped, pushed = _compute_name_stack_updates( current_name_stack, initial_name_stack) for _ in popped: - tpu.TraceStopOp() + tpu.trace_stop() assert len(pushed) == 0 outvals = map(read_env, jaxpr.outvars) @@ -1608,13 +1610,13 @@ def _maybe_cast_load_to_bool( out_aval, is_kernel_boundary=True, ) - vector_zeros = arith.ConstantOp( + vector_zeros = arith.constant( load_vector_type, ir.DenseElementsAttr.get_splat(load_vector_type, const_zero) ) return arith.cmpi(predicate, val, vector_zeros) else: # Scalar case. 
- const_zero = arith.ConstantOp(load_scalar_type, const_zero) + const_zero = arith.constant(load_scalar_type, const_zero) return arith.cmpi(predicate, val, const_zero) @@ -1687,7 +1689,7 @@ def _masked_swap_lowering_rule( result = memref.load(ref, starts) result = _maybe_cast_load_to_bool(ctx, val_aval, result) val = _maybe_cast_store_to_memref_type(ctx, val_aval, val) - memref.StoreOp(val, ref, starts) + memref.store(val, ref, starts) return result if not is_vmem_store: @@ -1740,9 +1742,9 @@ def _masked_swap_lowering_rule( if need_stride: if mask is not None: raise NotImplementedError("masked swap with strided store") - tpu.StridedStoreOp(val, ref, starts, strides) + tpu.strided_store(val, ref, starts, strides) else: - tpu.VectorStoreOp(val, ref, starts, [], mask=mask) + tpu.vector_store(val, ref, starts, [], mask=mask) return result @@ -1805,7 +1807,7 @@ def _proxy_fun(val, *, axes): ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) identity = ir.DenseElementsAttr.get_splat(out_type, val) - acc = arith.ConstantOp(out_type, identity) + acc = arith.constant(out_type, identity) return vector.multi_reduction(kind, x, acc, axes) return _lowering_rule @@ -2074,12 +2076,12 @@ def _dot_general_lowering_rule( else: raise NotImplementedError(f"Unsupported {preferred_element_type=}") - acc = arith.ConstantOp( + acc = arith.constant( red_type, ir.DenseElementsAttr.get_splat(red_type, val) ) - red = vector.MultiDimReductionOp( + red = vector.multi_reduction( ir.Attribute.parse("#vector.kind"), - arith.MulFOp(x, y), + arith.mulf(x, y), acc, [1] ) @@ -2101,7 +2103,7 @@ def _dot_general_lowering_rule( ) else: raise NotImplementedError(f"Unsupported dot precision: {precision}") - out_tile = arith.ConstantOp( + out_tile = arith.constant( out_type, ir.DenseElementsAttr.get_splat(out_type, val) ) return tpu.matmul( @@ -2945,7 +2947,7 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): ctx.lowering_context.dynamic_shape_replacement_fn, out_aval ) scalar_minus_one = ir.IntegerAttr.get(out_scalar_type, -1) - minus_one = arith.ConstantOp( + minus_one = arith.constant( out_type, ir.DenseElementsAttr.get_splat(out_type, scalar_minus_one) ) return arith.xori(x, minus_one) @@ -3064,7 +3066,7 @@ def _run_body(i, args): iv = for_op.induction_variable inner_args = for_op.inner_iter_args inner_out = _run_body(iv, inner_args) - scf.YieldOp(inner_out) + scf.yield_(inner_out) return for_op.results @@ -3236,10 +3238,10 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): ) else: out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args) - scf.YieldOp(out) + scf.yield_(out) with ir.InsertionPoint(if_op.else_block): out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args) - scf.YieldOp(out) + scf.yield_(out) return if_op.results @@ -3501,7 +3503,7 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): block_shapes=(*ctx.block_shapes, *block_shapes) ) out = jaxpr_subcomp(ctx, jaxpr, *consts, *args) - tpu.YieldOp(out) + tpu.yield_(out) return region.results @@ -3981,7 +3983,7 @@ def _pad(val): pad = vector.broadcast(pad_vec_type, padding_value) else: scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value) - pad = arith.ConstantOp( + pad = arith.constant( pad_vec_type, ir.DenseElementsAttr.get_splat( pad_vec_type, From 0651fdc8e500a9987b5cf3c0273957a3361b701a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 7 May 2025 05:06:47 -0700 Subject: [PATCH 1034/1769] Fix an import path to properly detect the CUDA plugin in bazel tests 
PiperOrigin-RevId: 755803774 --- jax/_src/numpy/lax_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b355ce60436..c40716daee4b 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -71,7 +71,7 @@ export = set_module('jax.numpy') -for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']: +for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: try: cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' From b60dc08f039acddc21942f0990cd08bba7ab69ce Mon Sep 17 00:00:00 2001 From: Nicolas Perez Date: Wed, 7 May 2025 05:07:11 -0700 Subject: [PATCH 1035/1769] Block on svd result to fix race condition in svd_test. PiperOrigin-RevId: 755803901 --- tests/svd_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/svd_test.py b/tests/svd_test.py index 97f8176f8f94..4225db038d72 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -189,7 +189,9 @@ def testSingularValues(self, m, n, log_cond, full_matrices): osp_linalg_fn = functools.partial( osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv) - actual_s = svd.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) + actual_s = svd.svd( + a, full_matrices=full_matrices, compute_uv=compute_uv + ).block_until_ready() expected_s = osp_linalg_fn(a) From 1fabde7f616a7882aae73cccd946433c75253393 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 7 May 2025 06:08:31 -0700 Subject: [PATCH 1036/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4b45c0f0dc2fb80be6036c36a3e40f3cd1b478c9. PiperOrigin-RevId: 755821476 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 28071970c894..ed77b40bb4a7 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a6ac2e4648e217653453aab56320af252fb992b7" -XLA_SHA256 = "fc49f1ba52cccae1397d15401920f10e0cc6b1959abf9da0e2d827c489d9923d" +XLA_COMMIT = "4b45c0f0dc2fb80be6036c36a3e40f3cd1b478c9" +XLA_SHA256 = "64292fbcceebe0ee03a97a0a0edf667067642de028a2e7d7d1175641e91f8925" def repo(): tf_http_archive( From 17b02ecb1b0ec5ca109cd997a82759490b22531e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 7 May 2025 06:49:08 -0700 Subject: [PATCH 1037/1769] Make PartitionSpec not inherit from a tuple at runtime. For type checkers, it's still a tuple. PiperOrigin-RevId: 755833267 --- CHANGELOG.md | 1 + jax/_src/partition_spec.py | 55 ++++++++++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e8007dd3eec2..a03eb80eb973 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes * JAX nightly packages are now published to artifact registry. To install these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation). + * `jax.sharding.PartitionSpec` no longer inherits from a tuple. 
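+    It still supports tuple-style usage (indexing, iteration, `len`, equality
+    and concatenation with tuples), but `isinstance(spec, tuple)` is now
+    `False` at runtime.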
## JAX 0.6.0 (April 16, 2025) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index bf6a90060bc8..158e361482b7 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +from typing import TYPE_CHECKING class UnconstrainedSingleton: @@ -43,7 +44,7 @@ def _canonicalize_partition(partition): return partition -class PartitionSpec(tuple): +class PartitionSpecImpl: """Tuple describing how to partition an array across a mesh of devices. Each element is either ``None``, a string, or a tuple of strings. @@ -52,38 +53,64 @@ class PartitionSpec(tuple): This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ + __slots__ = ("_partitions",) + __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. UNCONSTRAINED = _UNCONSTRAINED_PARTITION def __init__(self, *partitions): - pass - - def __new__(cls, *partitions): - partitions = tuple(_canonicalize_partition(p) for p in partitions) - return tuple.__new__(PartitionSpec, partitions) + self._partitions = tuple(_canonicalize_partition(p) for p in partitions) def __repr__(self): - return f"PartitionSpec{tuple.__repr__(self)}" + return f"PartitionSpec{self._partitions!r}" def __reduce__(self): - return (PartitionSpec, tuple(self)) + return (PartitionSpec, self._partitions) + + def __getitem__(self, i): + return self._partitions[i] + + def __iter__(self): + return iter(self._partitions) + + def __len__(self): + return len(self._partitions) def __eq__(self, other): - if not isinstance(other, tuple): + if not isinstance(other, (PartitionSpec, tuple)): return False other = tuple(_canonicalize_partition(o) for o in other) - return super().__eq__(other) + return self._partitions == other def __hash__(self): - return super().__hash__() + return hash(self._partitions) + + def __add__(self, other): + if not isinstance(other, (tuple, PartitionSpec)): + return NotImplemented + return PartitionSpec(*self, *other) + + def __radd__(self, other): + if not isinstance(other, (tuple, PartitionSpec)): + return NotImplemented + return PartitionSpec(*other, *self) def index(self, value): - value = _canonicalize_partition(value) - return super().index(value) + return self._partitions.index(_canonicalize_partition(value)) + + def count(self, value): + return self._partitions.count(_canonicalize_partition(value)) def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: - out = [None if p is _UNCONSTRAINED_PARTITION else p for p in self] + out = [None if p is _UNCONSTRAINED_PARTITION else p + for p in self._partitions] if len(out) < ndim: out.extend([None] * (ndim - len(out))) return PartitionSpec(*out) + +if TYPE_CHECKING: + class PartitionSpec(PartitionSpecImpl, tuple): # type: ignore + ... 
+else: + PartitionSpec = PartitionSpecImpl From 043b4a73d82ea1120e178342e30220ea7ecc5ceb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 7 May 2025 06:56:31 -0700 Subject: [PATCH 1038/1769] Exclude Mosaic GPU from the experimental target PiperOrigin-RevId: 755835530 --- jax/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/BUILD b/jax/BUILD index 749c72b45aba..16bc9de6935e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1216,6 +1216,7 @@ py_library_providing_imports_info( ], [ "experimental/buffer_callback.py", + "experimental/mosaic/gpu/*.py", ], ), visibility = ["//visibility:public"], From a8583db5f9240283cd2d22e75d8f98bdc3982322 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 7 May 2025 07:09:52 -0700 Subject: [PATCH 1039/1769] [pallas:mosaic] Fixed the type of `dimension_semantics` Previously the type annotation was malformed. PiperOrigin-RevId: 755839296 --- jax/_src/pallas/mosaic/core.py | 22 ++++++++++++---------- jax/_src/pallas/mosaic/lowering.py | 8 ++------ 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index d31de1c22089..c04fc6f155b9 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -65,6 +65,17 @@ class KernelType(enum.Enum): SC_VECTOR_SUBCORE = 2 +class GridDimensionSemantics(enum.Enum): + PARALLEL = "parallel" + ARBITRARY = "arbitrary" + +PARALLEL = GridDimensionSemantics.PARALLEL +ARBITRARY = GridDimensionSemantics.ARBITRARY + + +DimensionSemantics = Literal["parallel", "arbitrary"] | GridDimensionSemantics + + @dataclasses.dataclass(frozen=True) class TPUCompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. @@ -88,9 +99,7 @@ class TPUCompilerParams(pallas_core.CompilerParams): disable_bounds_checks: Disable bounds checks in the kernel. """ BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu" - dimension_semantics: ( - Sequence[Literal["parallel", "arbitrary"] | GridDimensionSemantics] | None - ) = None + dimension_semantics: Sequence[DimensionSemantics] | None = None allow_input_fusion: Sequence[bool] | None = None vmem_limit_bytes: int | None = None collective_id: int | None = None @@ -271,10 +280,3 @@ def _convert_semaphore_type_to_aval( pallas_core._out_shape_to_aval_mapping[SemaphoreType] = ( _convert_semaphore_type_to_aval ) - - -class GridDimensionSemantics(enum.Enum): - PARALLEL = "parallel" - ARBITRARY = "arbitrary" -PARALLEL = GridDimensionSemantics.PARALLEL -ARBITRARY = GridDimensionSemantics.ARBITRARY diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 8adb1b8cc1fc..9915228168e9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -402,9 +402,7 @@ def __init__( self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, - dimension_semantics: ( - Sequence[str | tpu_core.GridDimensionSemantics, ...] | None - ), + dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, mesh: mesh_lib.Mesh | None, dynamic_shape_replacement_fn: Callable[ [tuple[jax.DimSize, ...]], tuple[int, ...] @@ -656,9 +654,7 @@ def lower_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, *, - dimension_semantics: ( - Sequence[str | tpu_core.GridDimensionSemantics, None, ...] 
| None
-    ),
+    dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None,
     mesh: mesh_lib.Mesh | None = None,
     for_verification: bool = False,
     dynamic_shape_replacement_enabled: bool = False,

From 63b51b5a242b02213117f3c65dfa495cd400886c Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov" 
Date: Wed, 7 May 2025 07:27:57 -0700
Subject: [PATCH 1040/1769] [Mosaic GPU] Do not shortcut the transform
 computation for `memref.cast`

As it turns out, the result of `in_transforms`/`out_transforms` cannot be
returned directly as it has the wrong type (one extra wrapping with
`ir.ArrayAttr`). I will remove this for now and consider refactoring the whole
file in the future to make the types clearer.

PiperOrigin-RevId: 755844780
---
 jax/experimental/mosaic/gpu/transform_inference.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py
index b032992f88d4..1d97b3f0fa63 100644
--- a/jax/experimental/mosaic/gpu/transform_inference.py
+++ b/jax/experimental/mosaic/gpu/transform_inference.py
@@ -367,11 +367,6 @@ def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms:
 def _infer_memref_cast_transforms(
     op: memref.CastOp,
 ) -> OptionalTransforms:
-  if inference_utils.has_in_transforms_set(
-      op
-  ) and inference_utils.has_out_transforms_set(op):
-    return inference_utils.in_transforms(op), inference_utils.out_transforms(op)
-
   transforms = _transforms_from_uses(op)
   in_transforms = inference_utils.value_transforms(op.source)
   transforms = _resolve_transforms(transforms, in_transforms)

From 1aee6dc6c0919e0150a1e73e92ffa06ff489cf6d Mon Sep 17 00:00:00 2001
From: Matthias Kramm 
Date: Wed, 7 May 2025 07:42:39 -0700
Subject: [PATCH 1041/1769] When registering plugins, allow postponing the
 creation of the options dict.

PiperOrigin-RevId: 755849074
---
 jax/_src/xla_bridge.py   | 10 +++++++---
 tests/xla_bridge_test.py | 24 ++++++++++++++++++++++--
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 4c8373cc1105..72a16d5fbe5c 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -614,6 +614,8 @@ def _options_from_jax_configs(plugin_name):
   return options
 
+OptionsDict = Mapping[str, str | int | list[int] | float | bool]
+
 
 # TODO(b/261345120): decide on a public name and expose a public method which is
 # an alias of this method.
@@ -622,7 +624,7 @@ def register_plugin(
     *,
     priority: int = 400,
     library_path: str | None = None,
-    options: Mapping[str, str | int | list[int] | float | bool] | None = None,
+    options: OptionsDict | Callable[[], OptionsDict] | None = None,
     c_api: Any | None = None,
 ) -> Any:
   """Registers a backend factory for the PJRT plugin.
@@ -633,7 +635,9 @@ def register_plugin(
     plugin_name: the name of the plugin.
    priority: the priority this plugin should be registered in jax backends.
      Default to be 400.
    library_path: Optional. The full path to the .so file of the plugin. The
      plugin needs to provide either the library_path or the c_api.
-   options: Optional. It is used when creating a PJRT plugin client. Can be a
+   options: Optional. It is used when creating a PJRT plugin client. Can be a
+     callable, in which case it will be invoked at plugin initialization
+     time, and will be expected to return an option dictionary.
    c_api: Optional. The plugin can provide a PJRT C API to be registered.
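+
+   For example, a minimal sketch (the plugin name, library path, and option
+   values here are hypothetical):
+
+     register_plugin(
+         "my_plugin",
+         library_path="/path/to/pjrt_plugin.so",
+         # Built lazily: the callable only runs when the backend factory does.
+         options=lambda: {"num_partitions": 2},
+     )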
""" def factory(): @@ -641,7 +645,7 @@ def factory(): xla_client.initialize_pjrt_plugin(plugin_name) updated_options = {} if options is not None: - updated_options.update(options) + updated_options.update(options() if callable(options) else options) updated_options.update(_options_from_jax_configs(plugin_name)) if distributed.global_state.client is None: return xla_client.make_c_api_client(plugin_name, updated_options, None) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 3f1ac1787d2a..5a6bf80a469d 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -143,7 +143,9 @@ def test_register_plugin(self): registration = xb._backend_factories["name1"] with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object( - xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + xc, + "pjrt_plugin_initialized", + autospec=True, ): with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True): registration.factory() @@ -181,7 +183,9 @@ def test_register_plugin_with_config(self): registration = xb._backend_factories["name1"] with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object( - xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + xc, + "pjrt_plugin_initialized", + autospec=True, ): with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True): registration.factory() @@ -203,6 +207,22 @@ def test_register_plugin_with_config(self): mock_make.assert_called_once_with("name1", options, None) + def test_register_plugin_with_lazy_config(self): + options = {"bar": "baz"} + + def f(): + return options + + with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): + with mock.patch.object( + _profiler, "register_plugin_profiler", autospec=True + ): + xb.register_plugin("foo", options=f, library_path="/dev/null") + with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: + with mock.patch.object(xc, "pjrt_plugin_initialized", autospec=True): + xb._backend_factories["foo"].factory() + mock_make.assert_called_once_with("foo", options, None) + class GetBackendTest(jtu.JaxTestCase): From ed62be8a5421056e55db1c1b217fbd4799441357 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 7 May 2025 07:49:13 -0700 Subject: [PATCH 1042/1769] [Mosaic GPU] Run `canonicalize` instead of `cse` before the lowering. PiperOrigin-RevId: 755851405 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 +++--- jax/experimental/mosaic/gpu/core.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 77216ea99446..a6bdf76206d3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -833,9 +833,9 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ) if lowering_semantics == mgpu.LoweringSemantics.Warpgroup: - # We need to run CSE first in orderto remove dead-code for which layout - # inference does not work. - pm = mlir.passmanager.PassManager.parse("builtin.module(cse)", module.context) + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) pm.run(module.operation) # Run Python lowering passes. 
The remaining passes will be run in C++ in diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 28fef9d0e96d..c20c5252a27f 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -653,9 +653,9 @@ def as_gpu_kernel( ) if thread_semantics == LoweringSemantics.Warpgroup and dialect is not None: - # We need to run CSE first in orderto remove dead-code for which layout - # inference does not work. - pm = mlir.passmanager.PassManager.parse("builtin.module(cse)", module.context) + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) pm.run(module.operation) # Run Python lowering passes. The remaining passes will be run in C++ in @@ -744,9 +744,9 @@ def as_torch_gpu_kernel( ) if lowering_semantics == LoweringSemantics.Warpgroup and dialect is not None: - # We need to run CSE first in orderto remove dead-code for which layout - # inference does not work. - pm = mlir.passmanager.PassManager.parse("builtin.module(cse)", module.context) + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) pm.run(module.operation) # Run Python lowering passes. The remaining passes will be run in C++ in From 482ae677f3358c7935a4518828b796fba1c1e214 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 7 May 2025 08:26:22 -0700 Subject: [PATCH 1043/1769] [pallas:mosaic] Added a `register_lowering` decorator This change prepares the switch to `pltpu.KernelType`-specific lowering rules. PiperOrigin-RevId: 755864420 --- jax/_src/pallas/mosaic/lowering.py | 426 +++++++++-------------------- 1 file changed, 133 insertions(+), 293 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 9915228168e9..f372f8f9b472 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -20,7 +20,7 @@ import dataclasses import functools import string -from typing import Any, Hashable +from typing import Any, Hashable, TypeVar import jax from jax import api_util @@ -41,9 +41,11 @@ from jax._src import traceback_util from jax._src import xla_bridge from jax._src.cloud_tpu_init import is_cloud_tpu_older_than +from jax._src.export import shape_poly from jax._src.export._export import export from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.lax import control_flow from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop from jax._src.lib import version as jaxlib_version @@ -325,6 +327,22 @@ def ir_constant(x, mlir_type=None): lowering_rules = {} skip_mlir_conversions = set() + +T = TypeVar("T") + + +def register_lowering_rule( + prim: jax_core.Primitive, *, ensure_mlir_values: bool = True +) -> Callable[[T], T]: + def decorator(rule: T) -> T: + lowering_rules[prim] = rule + if not ensure_mlir_values: + skip_mlir_conversions.add(prim) + return rule + + return decorator + + def _get_aval_physical_dtype_shape(aval): dtype_physical_shape = jax_core.physical_aval(aval).shape[ len(aval.shape) : @@ -1185,6 +1203,7 @@ def _ensure_mlir_value(val, aval): ) +@register_lowering_rule(state_primitives.get_p, ensure_mlir_values=False) def _get_lowering_rule( ctx: LoweringRuleContext, ref, *idx, tree, ): @@ -1201,10 +1220,7 @@ def _get_lowering_rule( return 
_load_lowering_rule(ctx, *args_flat, args_tree=args_tree) -lowering_rules[state_primitives.get_p] = _get_lowering_rule -skip_mlir_conversions.add(state_primitives.get_p) - - +@register_lowering_rule(state_primitives.swap_p, ensure_mlir_values=False) def _swap_lowering_rule( ctx: LoweringRuleContext, ref, @@ -1226,9 +1242,6 @@ def _swap_lowering_rule( ) return _masked_swap_lowering_rule(ctx, *args_flat, args_tree=args_tree) -lowering_rules[state_primitives.swap_p] = _swap_lowering_rule -skip_mlir_conversions.add(state_primitives.swap_p) - def _make_index(s): if isinstance(s, (int, np.ndarray)): @@ -1461,6 +1474,8 @@ class KeyScalarBundle: key_shape: tuple[int, ...] scalars: list[ir.OpResult] + +@register_lowering_rule(primitives.load_p, ensure_mlir_values=False) def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ref, transforms, mask, _ = args_tree.unflatten(args_flat) ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in) @@ -1574,10 +1589,6 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree return KeyScalarBundle(scalars=load_ops, key_shape=tuple(ref_block_shape)) -lowering_rules[primitives.load_p] = _load_lowering_rule -skip_mlir_conversions.add(primitives.load_p) - - def _maybe_cast_load_to_bool( ctx, out_aval, val: ir.Value ) -> tuple[ir.Value, jnp.dtype]: @@ -1630,6 +1641,7 @@ def _maybe_cast_store_to_memref_type( return arith.extui(int_out_type, val) +@register_lowering_rule(primitives.swap_p, ensure_mlir_values=False) def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): @@ -1744,10 +1756,7 @@ def _masked_swap_lowering_rule( return result -lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule -skip_mlir_conversions.add(primitives.swap_p) - - +@register_lowering_rule(primitives.multiple_of_p) def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): del ctx for multiple in values: @@ -1755,9 +1764,6 @@ def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): return val -lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule - - def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity): def _lowering_rule(ctx: LoweringRuleContext, x, *, axes): (x_aval,) = ctx.avals_in @@ -1819,7 +1825,7 @@ def _proxy_fun(val, *, axes): } _reduce_max_lowering_rule = reduce_lowering_rule( jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY) -lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule +register_lowering_rule(lax.reduce_max_p)(_reduce_max_lowering_rule) REDUCE_MIN_KINDS = { @@ -1833,7 +1839,7 @@ def _proxy_fun(val, *, axes): } _reduce_min_lowering_rule = reduce_lowering_rule( jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY) -lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule +register_lowering_rule(lax.reduce_min_p)(_reduce_min_lowering_rule) REDUCE_SUM_KINDS = { @@ -1847,9 +1853,10 @@ def _proxy_fun(val, *, axes): } _reduce_sum_lowering_rule = reduce_lowering_rule( jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY) -lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule +register_lowering_rule(lax.reduce_sum_p)(_reduce_sum_lowering_rule) +@register_lowering_rule(lax.reduce_and_p) def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _proxy_reduce(arg, *, axes): # Mosaic currently only supports float reductions, so we cast the boolean @@ -1862,9 +1869,8 @@ def _proxy_reduce(arg, *, axes): _proxy_reduce, multiple_results=False) return proxy_lowering(ctx, x, axes=axes) 
-lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule - +@register_lowering_rule(lax.reduce_or_p) def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _proxy_reduce(arg, *, axes): # Mosaic currently only supports float reductions, so we cast the boolean @@ -1877,9 +1883,8 @@ def _proxy_reduce(arg, *, axes): _proxy_reduce, multiple_results=False) return proxy_lowering(ctx, x, axes=axes) -lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule - +@register_lowering_rule(state_primitives.broadcast_to_p) def _broadcast_to_lowering_rule( ctx: LoweringRuleContext, x, shape: Sequence[int] ): @@ -1889,9 +1894,7 @@ def _broadcast_to_lowering_rule( ) -lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule - - +@register_lowering_rule(lax.broadcast_in_dim_p) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding ): @@ -1932,9 +1935,6 @@ def _proxy_fun(val, *, shape, broadcast_dimensions): return vector.broadcast(out_type, val) -lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule - - def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape): """Converts a jax dot dimension numbers to a tpu dot dimension numbers. @@ -2000,6 +2000,7 @@ def format_dims(dims): return ir.Attribute.parse(tpu_dim_numbers_str) +@register_lowering_rule(lax.dot_general_p) def _dot_general_lowering_rule( ctx: LoweringRuleContext, x, @@ -2112,8 +2113,6 @@ def _dot_general_lowering_rule( ) -lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule - def _convert_helper(x, *, to_dtype): # Helper function for dtype conversion from_dtype = x.dtype @@ -2144,6 +2143,8 @@ def _convert_helper(x, *, to_dtype): return x.astype(to_dtype) raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") + +@register_lowering_rule(lax.convert_element_type_p) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -2194,9 +2195,7 @@ def _convert_element_type_lowering_rule( multiple_results=False)(ctx, x) -lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule - - +@register_lowering_rule(lax.reshape_p) def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding): if dimensions is not None: @@ -2220,9 +2219,7 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ) -lowering_rules[lax.reshape_p] = _reshape_lowering_rule - - +@register_lowering_rule(lax.squeeze_p) def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): del dimensions # Unused. 
(aval_in,) = ctx.avals_in @@ -2243,9 +2240,7 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): ) -lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule - - +@register_lowering_rule(lax.concatenate_p) def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] @@ -2253,9 +2248,7 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): return tpu.concatenate(out_type, xs, dimension=dimension) -lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule - - +@register_lowering_rule(lax.split_p) def _split_lowering_rule( ctx: LoweringRuleContext, x, *, sizes, axis ): @@ -2280,9 +2273,8 @@ def _split_lowering_rule( starts[axis] += size return outs -lowering_rules[lax.split_p] = _split_lowering_rule - +@register_lowering_rule(lax.iota_p) def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): if len(shape) == 1: @@ -2299,9 +2291,7 @@ def _1d_iota_helper(dtype, shape, dimension, sharding): return tpu.iota(out_type, dimension=dimension) -lowering_rules[lax.iota_p] = _iota_lowering_rule - - +@register_lowering_rule(lax.gather_p) def _gather_lowering_rule( ctx: LoweringRuleContext, x, @@ -2367,9 +2357,7 @@ def _gather_lowering_rule( raise NotImplementedError("Unsupported gather") -lowering_rules[lax.gather_p] = _gather_lowering_rule - - +@register_lowering_rule(lax.transpose_p) def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): if permutation != (1, 0): raise NotImplementedError @@ -2379,9 +2367,6 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): return vector.transpose(out_type, x, permutation) -lowering_rules[lax.transpose_p] = _transpose_lowering_rule - - def _bcast(x, y, x_aval, y_aval, out_aval): x_dtype = x_aval.dtype y_dtype = y_aval.dtype @@ -2411,6 +2396,8 @@ def _bcast(x, y, x_aval, y_aval, out_aval): return x, y +@register_lowering_rule(lax.add_p, ensure_mlir_values=False) +@register_lowering_rule(ad_util.add_any_p, ensure_mlir_values=False) def _add_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2421,12 +2408,6 @@ def _add_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.add_p] = _add_lowering_rule -skip_mlir_conversions.add(lax.add_p) -lowering_rules[ad_util.add_any_p] = _add_lowering_rule -skip_mlir_conversions.add(ad_util.add_any_p) - - class FoldingError(Exception): pass @@ -2457,6 +2438,7 @@ def _fold(x, fuel): return None +@register_lowering_rule(lax.max_p, ensure_mlir_values=False) def _max_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2469,10 +2451,7 @@ def _max_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.max_p] = _max_lowering_rule -skip_mlir_conversions.add(lax.max_p) - - +@register_lowering_rule(lax.min_p, ensure_mlir_values=False) def _min_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2485,10 +2464,7 @@ def _min_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.min_p] = _min_lowering_rule -skip_mlir_conversions.add(lax.min_p) - - 
+@register_lowering_rule(lax.sub_p, ensure_mlir_values=False) def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2499,10 +2475,7 @@ def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.sub_p] = _sub_lowering_rule -skip_mlir_conversions.add(lax.sub_p) - - +@register_lowering_rule(lax.mul_p, ensure_mlir_values=False) def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2513,10 +2486,7 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.mul_p] = _mul_lowering_rule -skip_mlir_conversions.add(lax.mul_p) - - +@register_lowering_rule(lax.div_p, ensure_mlir_values=False) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2529,10 +2499,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.div_p] = _div_lowering_rule -skip_mlir_conversions.add(lax.div_p) - - +@register_lowering_rule(lax.rem_p, ensure_mlir_values=False) def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2545,10 +2512,7 @@ def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.rem_p] = _rem_lowering_rule -skip_mlir_conversions.add(lax.rem_p) - - +@register_lowering_rule(lax.abs_p) def _abs_lowering_rule(ctx: LoweringRuleContext, x): (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): @@ -2558,9 +2522,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.abs_p] = _abs_lowering_rule - - +@register_lowering_rule(lax.neg_p, ensure_mlir_values=False) def _neg_lowering_rule(ctx: LoweringRuleContext, x): (x_aval,) = ctx.avals_in new_ctx = ctx.replace( @@ -2570,64 +2532,49 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x): return _sub_lowering_rule(new_ctx, np.array(0, dtype=x_aval.dtype), x) -lowering_rules[lax.neg_p] = _neg_lowering_rule -skip_mlir_conversions.add(lax.neg_p) - - +@register_lowering_rule(lax.sign_p) def _sign_lowering_rule(ctx: LoweringRuleContext, x): return lower_fun( pallas_utils.sign_lowering_helper, multiple_results=False, )(ctx, x) -lowering_rules[lax.sign_p] = _sign_lowering_rule - - +@register_lowering_rule(lax.nextafter_p) def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): return lower_fun( pallas_utils.nextafter_lowering_helper, multiple_results=False, )(ctx, x, y) -lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule - - +@register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.rsqrt(x) -lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule - - +@register_lowering_rule(lax.sqrt_p) def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.sqrt(x) -lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule - - +@register_lowering_rule(lax.square_p) def _square_lowering_rule(ctx: LoweringRuleContext, 
x): if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer): return arith.muli(x, x) return arith.mulf(x, x) -lowering_rules[lax.square_p] = _square_lowering_rule - - +@register_lowering_rule(lax.exp_p) def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.exp(x) -lowering_rules[lax.exp_p] = _exp_lowering_rule - - +@register_lowering_rule(lax.pow_p, ensure_mlir_values=False) def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): # jax accepts float base (x) and integer/float exponent (y), and integer # exponent is casted to float. @@ -2642,18 +2589,13 @@ def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): return math.powf(x, y) -lowering_rules[lax.pow_p] = _pow_lowering_rule -skip_mlir_conversions.add(lax.pow_p) - - +@register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): return lower_fun(lax_internal._integer_pow, multiple_results=False)( ctx, x, y=y) -lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule - - +@register_lowering_rule(lax.exp2_p, ensure_mlir_values=False) def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here. @@ -2665,10 +2607,7 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): )(ctx, x) -lowering_rules[lax.exp2_p] = _exp2_lowering_rule -skip_mlir_conversions.add(lax.exp2_p) - - +@register_lowering_rule(lax.logistic_p) def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") @@ -2686,63 +2625,49 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return arith.divf(one, denom) -lowering_rules[lax.logistic_p] = _logistic_lowering_rule - - +@register_lowering_rule(lax.sin_p) def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.sin(x) -lowering_rules[lax.sin_p] = _sin_lowering_rule - - +@register_lowering_rule(lax.cos_p) def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.cos(x) -lowering_rules[lax.cos_p] = _cos_lowering_rule - - +@register_lowering_rule(lax.tan_p) def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.tan(x) -lowering_rules[lax.tan_p] = _tan_lowering_rule - - +@register_lowering_rule(lax.tanh_p) def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.tanh(x) -lowering_rules[lax.tanh_p] = _tanh_lowering_rule - - +@register_lowering_rule(lax.log_p) def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.log(x) -lowering_rules[lax.log_p] = _log_lowering_rule - - +@register_lowering_rule(lax.log1p_p) def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") return math.log1p(x) -lowering_rules[lax.log1p_p] = _log1p_lowering_rule - - +@register_lowering_rule(lax.round_p) def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): if rounding_method == 0: return math.round(x) @@ -2752,37 
+2677,28 @@ def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): raise NotImplementedError(f"Unsupported rounding method: {rounding_method}") -lowering_rules[lax.round_p] = _round_lowering_rule - - +@register_lowering_rule(lax.ceil_p) def _ceil_lowering_rule(ctx: LoweringRuleContext, x): return math.ceil(x) -lowering_rules[lax.ceil_p] = _ceil_lowering_rule - - +@register_lowering_rule(lax.floor_p) def _floor_lowering_rule(ctx: LoweringRuleContext, x): return math.floor(x) -lowering_rules[lax.floor_p] = _floor_lowering_rule - - +@register_lowering_rule(lax.clz_p) def _clz_lowering_rule(ctx: LoweringRuleContext, x): return math.ctlz(x) -lowering_rules[lax.clz_p] = _clz_lowering_rule - +@register_lowering_rule(lax.population_count_p) def _population_count_lowering_rule(ctx: LoweringRuleContext, x): aval_out = ctx.avals_out[0] if aval_out.shape == (): raise ValueError("Population count is not supported on scalars") return math.ctpop(x) -lowering_rules[lax.population_count_p] = _population_count_lowering_rule - # Mapping for signed integer comparisons. _cmpsi_lowering_types = { @@ -2888,23 +2804,17 @@ def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y): raise NotImplementedError(f"Unsupported dtype in cmp: {dtype}") -lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p) -lowering_rules[lax.ne_p] = functools.partial(_cmp_lowering_rule, lax.ne_p) -lowering_rules[lax.lt_p] = functools.partial(_cmp_lowering_rule, lax.lt_p) -lowering_rules[lax.le_p] = functools.partial(_cmp_lowering_rule, lax.le_p) -lowering_rules[lax.gt_p] = functools.partial(_cmp_lowering_rule, lax.gt_p) -lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p) +for prim in [lax.eq_p, lax.ne_p, lax.lt_p, lax.le_p, lax.gt_p, lax.ge_p]: + register_lowering_rule(prim)(functools.partial(_cmp_lowering_rule, prim)) +@register_lowering_rule(lax.and_p, ensure_mlir_values=False) def _and_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.andi(x, y) -lowering_rules[lax.and_p] = _and_lowering_rule -skip_mlir_conversions.add(lax.and_p) - - +@register_lowering_rule(lax.is_finite_p) def _is_finite_lowering_rule(ctx: LoweringRuleContext, x): out_aval, = ctx.avals_out out_type = aval_to_ir_type( @@ -2913,18 +2823,13 @@ def _is_finite_lowering_rule(ctx: LoweringRuleContext, x): return _not_lowering_rule(ctx, tpu.weird(out_type, x)) -lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule - - +@register_lowering_rule(lax.or_p, ensure_mlir_values=False) def _or_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.ori(x, y) -lowering_rules[lax.or_p] = _or_lowering_rule -skip_mlir_conversions.add(lax.or_p) - - +@register_lowering_rule(lax.not_p) def _not_lowering_rule(ctx: LoweringRuleContext, x): # The primitive not_p is lowered to # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not @@ -2949,8 +2854,7 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): return arith.xori(x, minus_one) -lowering_rules[lax.not_p] = _not_lowering_rule - +@register_lowering_rule(lax.select_n_p) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): if len(args) > 1: raise NotImplementedError("select_n only supported with <= 2 arguments") @@ -2970,22 +2874,18 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): return arith.select(pred, y, x) -lowering_rules[lax.select_n_p] = _select_n_lowering_rule - - def _clamp(min, 
operand, max): res = jnp.maximum(operand, min) return jnp.minimum(res, max) +@register_lowering_rule(lax.clamp_p) def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max): """Compute minimum_p(maximum_p(min, operand), max).""" return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max) -lowering_rules[lax.clamp_p] = _clamp_lowering_rule - - +@register_lowering_rule(for_loop.for_p) def _for_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3017,9 +2917,6 @@ def _for_lowering_rule( return args -lowering_rules[for_loop.for_p] = _for_lowering_rule - - def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: int | ir.Value, num_steps: int | ir.Value, consts, *args, @@ -3066,6 +2963,7 @@ def _run_body(i, args): return for_op.results +@register_lowering_rule(lax.scan_p, ensure_mlir_values=False) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3110,8 +3008,6 @@ def _scan_lowering_rule( mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))), *out] return out -lowering_rules[lax.scan_p] = _scan_lowering_rule -skip_mlir_conversions.add(lax.scan_p) def _lower_while_via_fori( @@ -3141,6 +3037,7 @@ def _lower_while_via_fori( return [ub, ub, *for_out] +@register_lowering_rule(lax.while_p) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3201,8 +3098,7 @@ def _while_lowering_rule( return list(while_op.results) -lowering_rules[lax.while_p] = _while_lowering_rule - +@register_lowering_rule(lax.cond_p) def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): index, *args = args constant_index = _fold_and_get_constant_value(index) @@ -3241,22 +3137,18 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): return if_op.results -lowering_rules[lax.cond_p] = _cond_lowering_rule - - +@register_lowering_rule(pjit.pjit_p) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args) -lowering_rules[pjit.pjit_p] = _pjit_lowering_rule - - +@register_lowering_rule(pjit.mesh_cast_p) def _mesh_cast_lowering_rule(ctx, x, dst_sharding): return x -lowering_rules[pjit.mesh_cast_p] = _mesh_cast_lowering_rule +@register_lowering_rule(custom_derivatives.custom_jvp_call_p) def _custom_jvp_call_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3273,19 +3165,14 @@ def _custom_jvp_call_lowering_rule( return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args) -lowering_rules[custom_derivatives.custom_jvp_call_p] = ( - _custom_jvp_call_lowering_rule) - - +@register_lowering_rule(debugging.debug_callback_p) def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): del ctx, args, kwargs # No-op debug callbacks in Mosaic for now return [] -lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule - - +@register_lowering_rule(primitives.program_id_p) def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): if ctx.lowering_context.user_grid_indices is None: @@ -3299,8 +3186,9 @@ def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): f" length: {length}" ) return ctx.lowering_context.user_grid_indices[axis] -lowering_rules[primitives.program_id_p] = _program_id_lowering_rule + +@register_lowering_rule(primitives.num_programs_p) def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int): mapped_axes = set(ctx.lowering_context.mapped_dims) seen_user_axes = 0 @@ -3314,9 +3202,9 @@ def 
_num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int): f" length: {len(ctx.lowering_context.grid_rank)}" ) return tpu.iteration_bound(i) -lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule +@register_lowering_rule(tpu_primitives.repeat_p) def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): (out_aval,) = ctx.avals_out return tpu.repeat( @@ -3329,9 +3217,7 @@ def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): ) -lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule - - +@register_lowering_rule(tpu_primitives.roll_p) def _roll_lowering_rule( ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): @@ -3348,9 +3234,7 @@ def _roll_lowering_rule( ) -lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule - - +@register_lowering_rule(lax.slice_p) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -3367,62 +3251,45 @@ def _slice_lowering_rule( ) -lowering_rules[lax.slice_p] = _slice_lowering_rule - - +@register_lowering_rule(lax.xor_p, ensure_mlir_values=False) def _xor_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.xori(x, y) -lowering_rules[lax.xor_p] = _xor_lowering_rule -skip_mlir_conversions.add(lax.xor_p) - - +@register_lowering_rule(lax.shift_left_p, ensure_mlir_values=False) def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shli(x, d) -lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule -skip_mlir_conversions.add(lax.shift_left_p) - - +@register_lowering_rule(lax.shift_right_arithmetic_p, ensure_mlir_values=False) def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shrsi(x, d) -lowering_rules[lax.shift_right_arithmetic_p] = _shift_right_arithmetic_lowering_rule -skip_mlir_conversions.add(lax.shift_right_arithmetic_p) - - -def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): +@register_lowering_rule(lax.shift_right_logical_p, ensure_mlir_values=False) +def _shift_right_logical_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shrui(x, d) -lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules -skip_mlir_conversions.add(lax.shift_right_logical_p) - - +@register_lowering_rule(lax.erf_inv_p) def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): return lower_fun( pallas_utils.erf_inv_lowering_helper, multiple_results=False, )(ctx, x) -lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule - - +@register_lowering_rule(primitives.reciprocal_p) def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx): if not isinstance(x.type.element_type, ir.F32Type): raise ValueError("Only float32 is supported.") return tpu.reciprocal(x, approx=approx) -lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule - +@register_lowering_rule(tpu_primitives.bitcast_p) def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): del ty (out_aval,) = ctx.avals_out @@ -3433,8 +3300,8 @@ def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): x, ) -lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule +@register_lowering_rule(lax.bitcast_convert_type_p) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype): (in_aval, ) = ctx.avals_in @@ -3449,7 +3316,6 @@ 
def _bitcast_convert_type_lowering_rule( ), x, ) -lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value( @@ -3481,6 +3347,7 @@ def _alloc_value( raise NotImplementedError(f"Cannot allocate {type(aval)}.") +@register_lowering_rule(primitives.run_scoped_p) def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): out_type = [ aval_to_ir_type(ctx.lowering_context.dynamic_shape_replacement_fn, aval) @@ -3503,8 +3370,6 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): return region.results -lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule - def _device_id_to_logical( ctx: LoweringRuleContext, device_id, device_id_type: primitives.DeviceIdType): @@ -3528,6 +3393,7 @@ def _device_id_to_logical( raise NotImplementedError(f"Unsupported device id type: {device_id_type}") +@register_lowering_rule(primitives.semaphore_read_p) def _semaphore_read_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3550,8 +3416,7 @@ def _semaphore_read_lowering_rule( return tpu.sem_read(sem) -lowering_rules[primitives.semaphore_read_p] = _semaphore_read_lowering_rule - +@register_lowering_rule(primitives.semaphore_signal_p) def _semaphore_signal_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3569,19 +3434,16 @@ def _semaphore_signal_lowering_rule( return [] -lowering_rules[primitives.semaphore_signal_p] = ( - _semaphore_signal_lowering_rule) - - +@register_lowering_rule(primitives.semaphore_wait_p) def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) tpu.sem_wait(sem, value) return [] -lowering_rules[primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule +@register_lowering_rule(tpu_primitives.dma_start_p) def _dma_start_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3634,9 +3496,7 @@ def _dma_start_lowering_rule( return [] -lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule - - +@register_lowering_rule(tpu_primitives.dma_wait_p) def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, device_id_type: primitives.DeviceIdType): del device_id_type @@ -3666,8 +3526,8 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, tpu.wait_dma2(sem, src, dst) return [] -lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule +@register_lowering_rule(lax.axis_index_p) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names if grid_names and axis_name in grid_names: @@ -3686,24 +3546,23 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32) ) return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size) -lowering_rules[lax.axis_index_p] = _axis_index_rule + +@register_lowering_rule(tpu_primitives.get_barrier_semaphore_p) def _get_barrier_semaphore_rule(ctx: LoweringRuleContext): memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) return tpu.sem_barrier(memref_type) -lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule +@register_lowering_rule(tpu_primitives.delay_p) def _delay_rule(ctx: LoweringRuleContext, nanos: int): tpu.delay(nanos) return [] -lowering_rules[tpu_primitives.delay_p] = _delay_rule - - 
+@register_lowering_rule(primitives.debug_print_p) def _debug_print_rule( ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool ): @@ -3730,7 +3589,7 @@ def _debug_print_rule( " remove placeholders from the format string." ) - # TPU expects $0, $1 etc as placeholders. + # TPU expects $0, $1 etc as placeholders. fmt = "".join( f"{text}${idx}" for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt)) @@ -3775,9 +3634,7 @@ def _debug_print_rule( return () -lowering_rules[primitives.debug_print_p] = _debug_print_rule - - +@register_lowering_rule(tpu_primitives.prng_seed_p) def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): del ctx # In the KeyScalarBundle case we unpack the bundle and set the seed with @@ -3793,9 +3650,9 @@ def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): raise ValueError(f"All seed data must be scalar integers. Got {seed_types}") tpu.prng_set_seed_32(seeds) return [] -lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule +@register_lowering_rule(tpu_primitives.prng_random_bits_p) def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): if len(shape) <= 1: # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp. @@ -3805,15 +3662,15 @@ def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): ctx.lowering_context.dynamic_shape_replacement_fn, out_aval ) return tpu.prng_random_bits(out_type) -lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule +@register_lowering_rule(prng.random_seed_p) def random_seed_lowering(ctx, seeds, *, impl): seed_lowering = lower_fun(impl.seed, multiple_results=False) return seed_lowering(ctx, seeds) -lowering_rules[prng.random_seed_p] = random_seed_lowering +@register_lowering_rule(prng.random_bits_p) def random_bits_lowering(ctx, keys, *, bit_width, shape): assert bit_width == 32, "Only 32-bit PRNG supported." aval, = ctx.avals_in @@ -3826,17 +3683,17 @@ def new_lowering(key, bit_width, shape): _proxy_fn = new_lowering bits_lowering = lower_fun(_proxy_fn, multiple_results=False) return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape) -lowering_rules[prng.random_bits_p] = random_bits_lowering +@register_lowering_rule(prng.random_fold_in_p) def random_fold_in_lowering(ctx, keys, msgs): keys_aval, _ = ctx.avals_in impl = keys_aval.dtype._impl fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False) return fold_in_lowering(ctx, keys, msgs) -lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering +@register_lowering_rule(prng.random_unwrap_p) def random_unwrap_lowering(ctx, key): keys_aval = ctx.avals_in[0] impl = keys_aval.dtype._impl @@ -3846,9 +3703,9 @@ def random_unwrap_lowering(ctx, key): "key_data not support for Pallas PRNG keys. Use" " split_pallas_seed instead." ) -lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering +@register_lowering_rule(prng.random_wrap_p) def random_wrap_lowering(ctx, key_data, *, impl): del ctx if not pl_random.is_pallas_impl(impl): @@ -3858,27 +3715,22 @@ def random_wrap_lowering(ctx, key_data, *, impl): " wrap_pallas_seed instead." 
) -lowering_rules[prng.random_wrap_p] = random_wrap_lowering - +@register_lowering_rule(tpu_primitives.split_key_p) def _split_key_lowering_rule( ctx: LoweringRuleContext, key_data: KeyScalarBundle ): return key_data.scalars -lowering_rules[tpu_primitives.split_key_p] = _split_key_lowering_rule - - +@register_lowering_rule(tpu_primitives.join_key_p) def _join_key_lowering_rule(ctx: LoweringRuleContext, *scalars, impl): if not pl_random.is_pallas_impl(impl): return ValueError(f"Can only join Pallas keys. Got impl={impl}") return KeyScalarBundle(scalars=scalars, key_shape=impl.key_shape) -lowering_rules[tpu_primitives.join_key_p] = _join_key_lowering_rule - - +@register_lowering_rule(checkify.check_p) def _checkify_lowering_rule( ctx: LoweringRuleContext, *err_args, err_tree, debug): if not tpu_core.runtime_assert_enabled(): @@ -3914,8 +3766,9 @@ def _checkify_lowering_rule( operands=(not_pred,), attributes=attrs) return [] -lowering_rules[checkify.check_p] = _checkify_lowering_rule + +@register_lowering_rule(prng.threefry2x32_p) def _threefry2x32_lowering(ctx, k1, k2, m1, m2): def _lower_fun(k1, k2, m1, m2): with jax.named_scope("threefry2x32"): @@ -3926,9 +3779,7 @@ def _lower_fun(k1, k2, m1, m2): return threefry_lowering(ctx, k1, k2, m1, m2) -lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering - - +@register_lowering_rule(prng.iota_2x32_shape_p) def _iota_2x32_shape_lowering(ctx, *, shape): total_elements = np.prod(shape) if total_elements > np.iinfo(jnp.int32).max: @@ -3950,9 +3801,7 @@ def _lower_fun(shape): return iota_lowering(ctx, shape=shape) -lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering - - +@register_lowering_rule(lax.pad_p) def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): operand, padding_value = args padding_config = kwargs["padding_config"] @@ -4014,9 +3863,7 @@ def _pad(val): return operand -lowering_rules[lax.pad_p] = _pad_lowering_rule - - +@register_lowering_rule(control_flow.platform_index_p) def _platform_index_lowering( ctx: mlir.LoweringRuleContext, *, @@ -4036,16 +3883,9 @@ def _platform_index_lowering( ) -lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering - - -def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim): +@register_lowering_rule(shape_poly.dim_as_value_p) +def _dim_as_value_lowering(ctx: LoweringRuleContext, *, dim): placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0] return ir_constant( placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")) ) - - -import jax._src.export.shape_poly as shape_poly - -lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering From d7bc08472cf9b84ff3e62759aa7ba6065c7ebd8a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 24 Apr 2025 15:18:13 -0400 Subject: [PATCH 1044/1769] Move const folding and forwarding into tracing --- jax/_src/interpreters/partial_eval.py | 150 +++++++++++++------------- jax/_src/lax/lax.py | 15 +-- jax/_src/pjit.py | 13 ++- tests/pjit_test.py | 14 ++- tests/shard_map_test.py | 4 +- 5 files changed, 107 insertions(+), 89 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 990b0c51d175..64226a789cde 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1375,15 +1375,15 @@ def _closed_jaxpr_partial_eval_custom_cached( def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]: # Compute which inputs are just forwarded to outputs. 
- fwds: dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars)) + fwds: dict[Var, Atom] = dict(zip(jaxpr.invars, jaxpr.invars)) for eqn in jaxpr.eqns: if eqn.primitive in forwarding_rules: eqn = eqn.replace(invars=[a if type(a) is Literal else fwds.get(a, a) # type: ignore for a in eqn.invars]) - fwd_vars, _ = forwarding_rules[eqn.primitive](eqn) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): - if v_new is not None: - fwds[v_orig] = v_new + fwd_idx, _ = forwarding_rules[eqn.primitive](eqn) + for v_orig, idx in zip(eqn.outvars, fwd_idx): + if idx is not None: + fwds[v_orig] = eqn.invars[idx] idxs: dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)} return [None if type(v) is Literal else idxs.get(fwds.get(v)) # type: ignore for v in jaxpr.outvars] @@ -1749,7 +1749,6 @@ def to_jaxpr( jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, debug_info) - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) @@ -1766,7 +1765,6 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns, jaxpr_effects, debug_info) # We can't run check_jaxpr until after we normalize. - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) jaxpr, out_type = _add_implicit_outputs(jaxpr) config.enable_checks.value and core.check_jaxpr(jaxpr) @@ -1783,7 +1781,7 @@ def newvar(self, aval): def find_progenitors(self, tracer): var = self.tracer_to_var.get(id(tracer)) - if not var: + if not var or isinstance(var, Literal): return None, None active_vars = {var} for eqn in self.eqns[::-1]: @@ -1798,51 +1796,6 @@ def find_progenitors(self, tracer): for v in eqn.invars)] return invar_positions, const_eqns -def _const_folding_and_forwarding( - jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]: - consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) - var_subs: dict[Var, Atom] = {} - new_eqns = [] - def apply_var_sub(a: Atom) -> Atom: - return var_subs.get(a, a) if isinstance(a, Var) else a - for eqn in jaxpr.eqns: - # always apply invar substitutions - eqn = eqn.replace(invars=[apply_var_sub(v) for v in eqn.invars]) - # if any inputs are constants and we have a constant-folding rule, apply it - has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) - for eff in eqn.effects) - if (eqn.primitive in const_fold_rules and - any(v in consts if isinstance(v, Var) - else isinstance(v, Literal) for v in eqn.invars) and - not has_input_effect): - consts_in = [consts.get(v) if isinstance(v, Var) else - v.val if isinstance(v, Literal) else None - for v in eqn.invars] - consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) - assert (new_eqn is None) == all(c is not None for c in consts_out) - for v, c in zip(eqn.outvars, consts_out): - if c is not None: - if core.is_literalable(c): - var_subs[v] = Literal(c, v.aval) - else: - consts[v] = c - if new_eqn is None: continue - else: eqn = new_eqn - # if the application trivially maps some inputs to outputs, simplify - if eqn.primitive in forwarding_rules and not has_input_effect: - fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): 
- if v_new is not None: var_subs[v_orig] = v_new - if new_eqn is None: continue - else: eqn = new_eqn - new_eqns.append(eqn) - new_constvars, new_constvals = unzip2(consts.items()) - new_outvars = [apply_var_sub(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) - return new_jaxpr, new_constvals ConstFoldRule = Callable[ [list[Union[Any, None]], JaxprEqn], @@ -1852,7 +1805,7 @@ def apply_var_sub(a: Atom) -> Atom: ForwardingRule = Callable[ [JaxprEqn], - tuple[list[Union[Var, None]], Union[JaxprEqn, None]] + tuple[list[Union[int, None]], Union[JaxprEqn, None]] ] forwarding_rules: dict[Primitive, ForwardingRule] = {} @@ -1937,6 +1890,13 @@ def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: self.frame.constvar_to_val[var] = c return tracer + def get_const(self, tracer) -> Any: + var = self.frame.tracer_to_var.get(id(tracer)) + if isinstance(var, Literal): + return var.val + elif var is not None: + return self.frame.constvar_to_val.get(var) + def _lift_tracers_in_aval(self, aval, source_info: SourceInfo): if (not isinstance(aval, DShapedArray) or not any(isinstance(d, Tracer) for d in aval.shape)): @@ -1958,9 +1918,6 @@ def makevar(self, tracer): var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) return var - def is_const(self, tracer): - return self.frame.tracer_to_var.get(id(tracer)) is None - def process_primitive(self, primitive, tracers, params): if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) @@ -1973,7 +1930,7 @@ def process_primitive(self, primitive, tracers, params): def default_process_primitive(self, primitive, tracers, params): avals = [t.aval for t in tracers] - out_avals, effects = primitive.abstract_eval(*avals, **params) + out_avals, effs = primitive.abstract_eval(*avals, **params) if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: raise ValueError(f"{primitive}.abstract_eval() method should return " f"a tuple or a list iff {primitive}.multiple_results.") @@ -1982,9 +1939,29 @@ def default_process_primitive(self, primitive, tracers, params): out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, - source_info) - self.frame.add_eqn(eqn) + eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effs, source_info) + no_input_effects = not any(isinstance(e, effects.JaxprInputEffect) + for e in eqn.effects) + + # Constant folding + if no_input_effects and primitive in const_fold_rules: + consts_in = map(self.get_const, tracers) + if any(c is not None for c in consts_in): + consts_out, eqn = const_fold_rules[primitive](consts_in, eqn) + assert (eqn is None) == all(c is not None for c in consts_out) + for i, c in enumerate(consts_out): + if c is not None: + out_tracers[i] = self.new_const(c, source_info) + + # Input-to-output tracer forwarding + if eqn is not None and no_input_effects and primitive in forwarding_rules: + in_fwd, eqn = forwarding_rules[primitive](eqn) + for out_idx, in_idx in enumerate(in_fwd): + if in_idx is not None: + out_tracers[out_idx] = tracers[in_idx] + + if eqn is not None: + self.frame.add_eqn(eqn) return out_tracers if primitive.multiple_results 
else out_tracers.pop() def process_call(self, call_primitive, f: lu.WrappedFun, @@ -2633,26 +2610,53 @@ def inline_jaxpr_into_trace( const_tracers = map(partial(trace.new_const, source_info=src), consts) constvars = map(trace.getvar, const_tracers) argvars = map(trace.getvar, arg_tracers) - env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], - [*constvars, *argvars])) + const_env: dict[Var, Any] = { + v: c for v, c in zip(constvars, consts) if not isinstance(v, Literal)} + env: dict[Var, Atom] = dict(zip([*jaxpr.constvars, *jaxpr.invars], + [*constvars, *argvars])) for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] - outvars = [Var('', v.aval) for v in eqn.outvars] + orig_outvars = eqn.outvars + outvars = [Var('', v.aval) for v in orig_outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) - trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) - foreach(env.setdefault, eqn.outvars, outvars) - - tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], - [*consts, *arg_tracers])) - def new_tracer(atom): + eqn = eqn.replace(invars, outvars, source_info=src_) + foreach(env.setdefault, orig_outvars, outvars) + + # We must re-run constant folding when inlining because some jaxpr inputs + # may be consts in the outer scope. + eqn_: JaxprEqn | None = eqn + inp_eff = any(isinstance(e, effects.JaxprInputEffect) for e in eqn.effects) + if eqn.primitive in const_fold_rules and not inp_eff: + consts_in = [v.val if isinstance(v, Literal) else const_env.get(v) + for v in invars] + if any(c is not None for c in consts_in): + consts_out, eqn_ = const_fold_rules[eqn.primitive](consts_in, eqn) + assert (eqn_ is None) == all(c is not None for c in consts_out) + for v, c in zip(orig_outvars, consts_out): + if c is not None: + if core.is_literalable(c): + env[v] = Literal(c, v.aval) + else: + const_env[v] = c + if eqn_ is not None: + trace.frame.add_eqn(eqn_) + + tracer_env: dict[Var, Any] = const_env + tracer_env.update( + {v: t for v, t in zip(argvars, arg_tracers) if not isinstance(v, Literal)} + ) + def maybe_new_tracer(atom): + if isinstance(atom, Literal): + return atom.val + if atom in tracer_env: + return tracer_env[atom] tracer = tracer_env[atom] = DynamicJaxprTracer(trace, atom.aval, src) trace.frame.tracers.append(tracer) - trace.frame.tracer_to_var[id(tracer)] = env[atom] + trace.frame.tracer_to_var[id(tracer)] = atom return tracer - return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env - else new_tracer(x) for x in jaxpr.outvars] + return [maybe_new_tracer(x if isinstance(x, Literal) else env[x]) for x in jaxpr.outvars] # TODO(mattjj,dougalm): this special handling is to avoid round-tripping the # jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 214fc9650505..6c248565174c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2705,6 +2705,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. """ + # TODO(dfm): Re-write this as a "reshard" when only the sharding changes. 
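With folding and forwarding now performed by the tracer itself (see the `default_process_primitive` hunk above), the effect is directly observable from the public API. An illustrative snippet, not part of the patch, whose exact printout may vary by JAX version:

```python
import jax
import jax.numpy as jnp

def f(x):
  # A no-op broadcast: same shape, identity dimension mapping.
  return jax.lax.broadcast_in_dim(x, x.shape, tuple(range(x.ndim)))

# The broadcast_in_dim forwarding rule below reports that output 0 is just
# input 0, so the traced jaxpr is expected to contain no equations at all.
print(jax.make_jaxpr(f)(jnp.ones((3,))))
```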
out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim') if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array) and out_sharding is None): @@ -4850,11 +4851,11 @@ def _convert_elt_type_folding_rule(consts, eqn): def _convert_elt_type_fwd_rule(eqn): v, = eqn.invars - if (not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended) and + if (v.aval.dtype == eqn.params['new_dtype'] and + v.aval.weak_type == eqn.params['weak_type'] and not dtypes.issubdtype(v.aval.dtype, dtypes.extended) and - v.aval.dtype == eqn.params['new_dtype'] and - v.aval.weak_type == eqn.params['weak_type']): - return [v], None + (eqn.params['sharding'] is None or eqn.params['sharding'] == v.aval.sharding)): + return [0], None else: return [None], eqn @@ -6447,8 +6448,10 @@ def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape, def _broadcast_in_dim_fwd_rule(eqn): v, *dyn = eqn.invars - if not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape): - return [v], None + if (not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape) + and (eqn.params['sharding'] is None or + eqn.params['sharding'] == v.aval.sharding)): + return [0], None else: return [None], eqn diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 81d8982184c1..268f533b971b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1955,7 +1955,7 @@ def pjit_staging_rule(trace, *args, **params): assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(partial(trace.new_const, source_info=source_info), consts) + consts = [trace.new_const(c, source_info) for c in consts] in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) @@ -1975,8 +1975,8 @@ def _pjit_forwarding(jaxpr, out_shardings, out_layouts): for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) - out_shardings = [o for o, k in zip(out_shardings, keep) if k] - out_layouts = [o for o, k in zip(out_layouts , keep) if k] + out_shardings = tuple(o for o, k in zip(out_shardings, keep) if k) + out_layouts = tuple(o for o, k in zip(out_layouts , keep) if k) return jaxpr, in_fwd, out_shardings, out_layouts def pjit_forwarding_rule(eqn): @@ -1985,11 +1985,10 @@ def pjit_forwarding_rule(eqn): jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts']) new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None] - new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=(*out_shardings,), - out_layouts=(*out_layouts,)) + new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts) new_eqn = eqn.replace(params=new_params, outvars=new_outvars) - fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd] - return fwd_vars, new_eqn + return in_fwd, new_eqn # TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule. 
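Note the convention change running through `_convert_elt_type_fwd_rule`, `_broadcast_in_dim_fwd_rule`, and `pjit_forwarding_rule` above, matching the new `ForwardingRule` type: a rule now returns, for each output, the *index* of the forwarded input (or `None`) rather than a `Var`, which lets the tracer substitute input tracers for output tracers directly. A schematic rule under the new convention, using a made-up `my_noop_p` primitive:

```python
def _my_noop_fwd_rule(eqn):
  # Output 0 is exactly input 0: report index 0 and drop the equation.
  # (Return [None], eqn instead when the op must be kept.)
  return [0], None

forwarding_rules[my_noop_p] = _my_noop_fwd_rule
```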
pe.forwarding_rules[pjit_p] = pjit_forwarding_rule diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 68bdeafd9c2b..e8d8d46455ba 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5150,7 +5150,7 @@ def f(x, y): self.assertEqual(out[1].sharding, arr2.sharding) jaxpr = jitted_grad.trace(arr1, arr2).jaxpr - bwd_jaxpr = jaxpr.eqns[1] + bwd_jaxpr = jaxpr.eqns[-1] expected_spec = [('broadcast_in_dim', P('x', None)), ('dot_general', P('x', None)), ('transpose', P(None, 'x')), @@ -7668,6 +7668,18 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.ravel(np_inp)) + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_broadcast_forwarding(self, mesh): + arr = jax.device_put(np.zeros(()), P()) + + def f(x): + out = jax.lax.full_like(x, 1.0) + self.assertEqual(jax.typeof(out).sharding, jax.typeof(x).sharding) + return out + + f(arr) + jax.jit(f)(arr) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index a4c420959710..4d3b265bd869 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1778,7 +1778,7 @@ def f(x): x = jnp.arange(16.) jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) - e1, _, e2 = jaxpr.eqns + e1, *_, e2 = jaxpr.eqns self.assertLen(e1.outvars, 1) # only primal output self.assertLen(e2.invars, 2) # res and cotangent inputs self.assertEqual(sum(e1.outvars[0] is v for v in e2.invars), 1) @@ -1801,7 +1801,7 @@ def f(x): x = jnp.arange(16.) jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) - e1, _, e2 = jaxpr.eqns + e1, *_, e2 = jaxpr.eqns self.assertLen(e1.outvars, 2) # one primal and one res output self.assertLen(e2.invars, 4) # two res and two cotangent inputs self.assertEqual(sum(e1.outvars[-1] is v for v in e2.invars), 1) From c93b051ff851b4bb7cc9a73102459d6b1ce3ffdb Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Wed, 7 May 2025 09:30:45 -0700 Subject: [PATCH 1045/1769] [JAX] Add a test for multiprocess shard_map in McJAX with non-participating hosts. Update handling of device memories in PyDeviceList to support this. PiperOrigin-RevId: 755887258 --- jaxlib/py_device_list.cc | 34 ++++++++++------------------------ 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/jaxlib/py_device_list.cc b/jaxlib/py_device_list.cc index c80602e14862..c5004dc57330 100644 --- a/jaxlib/py_device_list.cc +++ b/jaxlib/py_device_list.cc @@ -366,21 +366,14 @@ void PyDeviceList::PopulateMemoryKindInfo() { throw nb::value_error("Unrecognized DeviceList type"); } MemoryKindInfo info; - xla::ifrt::Device* addressable_device = nullptr; - const int process_index = py_client_ ? py_client_->process_index() : 0; - for (xla::ifrt::Device* device : std::get<0>(device_list_)->devices()) { - if (device->ProcessIndex() == process_index) { - addressable_device = device; - break; - } - } - if (addressable_device == nullptr) { + if (std::get<0>(device_list_)->size() == 0) { info.default_memory_kind = nb::none(); memory_kind_info_ = std::move(info); return; } + xla::ifrt::Device* device = std::get<0>(device_list_)->devices()[0]; - auto default_memory = addressable_device->DefaultMemory(); + auto default_memory = device->DefaultMemory(); if (!default_memory.ok()) { // Cache the error. 
memory_kind_info_ = default_memory.status(); @@ -388,9 +381,9 @@ void PyDeviceList::PopulateMemoryKindInfo() { } info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind()); nb::tuple memory_kinds = - nb::steal(PyTuple_New(addressable_device->Memories().size())); - for (size_t i = 0; i < addressable_device->Memories().size(); ++i) { - auto* memory = addressable_device->Memories()[i]; + nb::steal(PyTuple_New(device->Memories().size())); + for (size_t i = 0; i < device->Memories().size(); ++i) { + auto* memory = device->Memories()[i]; nb::str s = nb::str(memory->Kind().memory_kind()->data(), memory->Kind().memory_kind()->size()); PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); @@ -402,24 +395,17 @@ void PyDeviceList::PopulateMemoryKindInfo() { void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { MemoryKindInfo info; try { - nb::handle addressable_device; - for (nb::handle device : std::get<1>(device_list_)) { - if (nb::cast(device.attr("process_index")) == - nb::cast(device.attr("client").attr("process_index")())) { - addressable_device = device; - break; - } - } - if (!addressable_device) { + if (std::get<1>(device_list_).size() == 0) { info.default_memory_kind = nb::none(); // info.memory_kinds is default-initialized to an empty tuple. memory_kind_info_ = std::move(info); return; } - auto default_memory = addressable_device.attr("default_memory")(); + nb::handle device = std::get<1>(device_list_)[0]; + auto default_memory = device.attr("default_memory")(); info.default_memory_kind = default_memory.attr("kind"); info.memory_kinds = nb::tuple( - nb::object(addressable_device.attr("addressable_memories")())); + nb::object(device.attr("addressable_memories")())); memory_kind_info_ = std::move(info); } catch (nb::python_error& e) { // Cache the error. From f7f57d1dbd5db9eb64ed698533996df8acf44b1d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 7 May 2025 09:50:55 -0700 Subject: [PATCH 1046/1769] jax.scipy.signal.istft: support array input for window --- jax/_src/dtypes.py | 6 ++++++ jax/_src/scipy/signal.py | 6 +++--- tests/scipy_signal_test.py | 7 ++++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index d1e5b7bf430b..ae3516ea671c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -279,6 +279,12 @@ def to_inexact_dtype(dtype: DTypeLike) -> DType: return _dtype_to_inexact.get(dtype_, dtype_) +def to_floating_dtype(dtype: DTypeLike) -> DType: + """Promotes a dtype to a non-complex floating dtype.""" + dtype_ = np.dtype(dtype) + return finfo(_dtype_to_inexact.get(dtype_, dtype_)).dtype + + def to_complex_dtype(dtype: DTypeLike) -> DType: ftype = to_inexact_dtype(dtype) if ftype in [np.dtype('float64'), np.dtype('complex128')]: diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index d950cd2ea395..565909e8a6d1 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -1071,7 +1071,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). nfft: Number of FFT points used in the STFT. If ``None`` (default), the value is determined from the size of ``Zxx``. - input_onesided: If Tru` (default), interpret the input as a one-sided STFT + input_onesided: If True (default), interpret the input as a one-sided STFT (positive frequencies only). If False, interpret the input as a two-sided STFT. 
   boundary: If True (default), it is assumed that the input signal was
     extended at its boundaries by ``stft``. If `False`, the input signal is
     assumed to have been truncated at the boundaries by `stft`.
@@ -1108,7 +1108,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
     raise ValueError('Must specify differing time and frequency axes!')
 
   Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype(
-      np.result_type(Zxx, np.complex64)))
+      dtypes.to_complex_dtype(Zxx.dtype)))
 
   n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
                else Zxx.shape[freq_axis])
@@ -1147,7 +1147,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
   xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg_int, :]
 
   # Get window as array
-  if window == 'hann':
+  if isinstance(window, str) and window == 'hann':
     # Implement the default case without scipy
     win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2
     win = win.astype(xsubs.dtype)
diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py
index 11923257a9dd..7ff3c87435c7 100644
--- a/tests/scipy_signal_test.py
+++ b/tests/scipy_signal_test.py
@@ -388,7 +388,7 @@ def osp_fun(x):
       ],
       dtype=default_dtypes,
       fs=[1.0, 16000.0],
-      window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
+      window=['boxcar', 'triang', 'blackman', 'hamming', 'hann', 'USE_ARRAY'],
       onesided=[False, True],
       boundary=[False, True],
   )
@@ -399,6 +399,11 @@ def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
     new_freq_len = (shape[freqaxis] - 1) * 2
     shape = shape[:freqaxis] + (new_freq_len ,) + shape[freqaxis + 1:]
 
+    if window == 'USE_ARRAY':
+      # ensure dtype matches the expected dtype of `xsubs` within the implementation.
+      window = np.ones(nperseg, dtype=(
+          dtypes.to_floating_dtype(dtype) if onesided else dtypes.to_complex_dtype(dtype)))
+
     kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                 nfft=nfft, input_onesided=onesided, boundary=boundary,
                 time_axis=timeaxis, freq_axis=freqaxis)

From 4a407bbd6e4467ece13094d634f4202a90a5d107 Mon Sep 17 00:00:00 2001
From: Yu-Hang Tang
Date: Wed, 7 May 2025 17:13:51 +0000
Subject: [PATCH 1047/1769] add docs for multi-process run in Kubernetes

Co-authored-by: Michael Whittaker
---
 .github/workflows/k8s.yaml       |  2 +-
 docs/multi_process.md            | 85 ++++++++++++++++++++++++++++++++
 jax/_src/clusters/k8s_cluster.py |  4 +-
 3 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml
index 470a899a187e..5756b1afbbd2 100644
--- a/.github/workflows/k8s.yaml
+++ b/.github/workflows/k8s.yaml
@@ -1,4 +1,4 @@
-name: Distributed run using K8s Jobset
+name: Multi-process run using K8s

 on:
   push:
diff --git a/docs/multi_process.md b/docs/multi_process.md
index 8ecc51cc2557..f8c2566ca872 100644
--- a/docs/multi_process.md
+++ b/docs/multi_process.md
@@ -307,6 +307,83 @@ what it prints:

     Woohoo, look at all those TPU cores!

+### Kubernetes Example
+
+Running multi-controller JAX on a Kubernetes cluster is almost identical in spirit to the GPU and TPU examples above: every pod runs the same Python program, JAX discovers its peers, and the cluster behaves like one giant machine.
+
+1. **Container image** - start from a JAX-enabled image, e.g. one of the public JAX AI images on Google Artifact Registry ([TPU][google-artifact-tpu] / [GPU][google-artifact-gpu]) or NVIDIA ([NGC][nvidia-ngc] / [JAX-Toolbox][nvidia-jax-toolbox]).
+
+2.
**Workload type** - use either a [JobSet][k8s-jobset] or an [indexed Job][k8s-indexed-job]. Each replica corresponds to one JAX process. + +3. **Service Account** - JAX needs permission to list the pods that belong to the job so that processes discover their peers. A minimal RBAC setup is provided in [examples/k8s/svc-acct.yaml][rbac-svc-acct]. + +Below is a [minimal JobSet][minimal-jobset] that launches two replicas. Replace the placeholders - +image, GPU count, and any private registry secrets - with values that match your environment. + +```yaml +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 2 + completions: 2 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa # kubectl apply -f svc-acct.yaml + restartPolicy: Never + imagePullSecrets: + # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/ + - name: null + containers: + - name: main + image: null # e.g. ghcr.io/nvidia/jax:jax + imagePullPolicy: Always + resources: + limits: + cpu: 1 + # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/ + nvidia.com/gpu: null + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) +``` + +Apply the manifest and watch the pods complete: + +```bash +$ kubectl apply -f example.yaml +$ kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jaxjob +NAME READY STATUS RESTARTS AGE +jaxjob-workers-0-0-xpx8l 0/1 Completed 0 8m32s +jaxjob-workers-0-1-ddkq8 0/1 Completed 0 8m32s +``` + +When the job finishes, inspect the logs to confirm that every process saw all accelerators: + +```bash +$ kubectl logs -l jobset.sigs.k8s.io/jobset-name=jaxjob +[CudaDevice(id=0), CudaDevice(id=1)] +[CudaDevice(id=0)] +[CudaDevice(id=0), CudaDevice(id=1)] +[CudaDevice(id=1)] +``` + +Every pod should have the same set of global devices and a different set of local devices. At this point, you can replace the inline script with your real JAX program. + Once the processes are set up, we can start building global {class}`jax.Array`s and running computations. 
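As a concrete starting point, here is a minimal sketch, not part of the original page, of what every process might run next; it assumes `jax.distributed.initialize()` has already been called, as in the pod command above:

```python
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# One global array, sharded across every device in the cluster.
mesh = jax.make_mesh((jax.device_count(),), ('i',))
sharding = NamedSharding(mesh, P('i'))
global_shape = (jax.device_count(),)

# Each process materializes only the shards that live on its own devices.
x = jax.make_array_from_callback(
    global_shape, sharding, lambda idx: np.arange(global_shape[0])[idx])

total = jax.jit(lambda a: a.sum())(x)  # one logical computation, whole cluster
print(total)  # fully replicated result: identical on every process
```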
The remaining Python code examples in this tutorial are meant to be run on all processes simultaneously, after running @@ -580,3 +657,11 @@ assert (np.all( [distributed_arrays]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html [gpu_machines]: https://cloud.google.com/compute/docs/gpus [unified_sharding]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html +[google-artifact-tpu]: https://console.cloud.google.com/artifacts/docker/cloud-tpu-images/us/jax-ai-image/tpu +[google-artifact-gpu]: https://console.cloud.google.com/artifacts/docker/deeplearning-images/us-central1/jax-ai-image/gpu +[nvidia-ngc]: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax +[nvidia-jax-toolbox]: https://github.com/NVIDIA/JAX-Toolbox +[k8s-jobset]: https://github.com/kubernetes-sigs/jobset +[k8s-indexed-job]: https://kubernetes.io/docs/concepts/workloads/controllers/job/#parallel-jobs +[rbac-svc-acct]: https://github.com/jax-ml/jax/blob/main/examples/k8s/svc-acct.yaml +[minimal-jobset]: https://github.com/jax-ml/jax/blob/main/examples/k8s/example.yaml diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index 11f93e36f647..9520b947d9c5 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -69,7 +69,9 @@ def _handle_api_exception(cls): "this job does not have the permission for pod introspection. Please " "either grant the default SA permission to read pod info, or create a " "dedicated service account with the permission and associated with " - "the job. For more details, see .", + "the job. For an example on setting up the service account, see the " + "example/k8s directory in the JAX repo. For more details, please refer to " + "https://docs.jax.dev/en/latest/multi_process.html#kubernetes-example", width=80 )) raise RuntimeError('\n'.join(err_msg)) from e From 59f07c3bcb374b197f25137e09fa74c25c959efb Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 7 May 2025 10:33:59 -0700 Subject: [PATCH 1048/1769] Clean up `LoadedExecutable::Delete` and `LoadedExecutable::IsDeleted` This API was added for symmetry with IFRT arrays, but is not used by anyone. Not exposing the "deleted" state for executables simplifies IFRT implementations. Users who want to delete executables (such as `PyLoadedExecutable` in JAX) can always just drop the reference to the executable. With this, `LoadedExecutable` is completely immutable once constructed. This also helps with some cases where users want to share the same executable across many places and want to ensure that executables are never modified across these shared entities. IFRT Proxy's server-side logic now returns an `UNIMPLEMENTED` error for `Delete` and false for `IsDeleted`. This will not cause any disruption for existing workloads because no one is actually calling these methods outside tests. 
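On the JAX side, the intended replacement for an explicit delete is ordinary reference dropping; a hedged sketch using the public AOT API (not part of this change):

```python
import jax

compiled = jax.jit(lambda x: x + 1).lower(1.0).compile()  # holds a loaded executable
print(compiled(1.0))
del compiled  # dropping the last Python reference releases the executable
```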
PiperOrigin-RevId: 755912978 --- jaxlib/py_client.cc | 14 +++++--------- jaxlib/py_executable.h | 7 +------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index 2cc84f0bf86c..ecd412ddbb99 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -191,9 +191,7 @@ nb::list PyClient::LiveExecutables() { nb::ft_lock_guard lock(executables_mutex_); nb::list executables; for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { - if (!exec->is_deleted()) { - executables.append(nb::find(exec)); - } + executables.append(nb::find(exec)); } return executables; } @@ -621,12 +619,10 @@ absl::StatusOr PyClient::HeapProfile() { for (PyLoadedExecutable* executable = executables_; executable; executable = executable->next_) { - if (!executable->is_deleted()) { - HeapProfileKey key{ - executable->traceback() ? executable->traceback()->get() : nullptr, - executable->SizeOfGeneratedCodeInBytes(), nullptr}; - ++entries[key]; - } + HeapProfileKey key{ + executable->traceback() ? executable->traceback()->get() : nullptr, + executable->SizeOfGeneratedCodeInBytes(), nullptr}; + ++entries[key]; } PprofProfileBuilder builder; diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h index 5c7f57301b82..9f8034ed2675 100644 --- a/jaxlib/py_executable.h +++ b/jaxlib/py_executable.h @@ -162,12 +162,7 @@ class PyLoadedExecutable { return ifrt_loaded_executable_->GetCostAnalysis(); } - void Delete() { - // TODO(hyeontaek): Return absl::Status. - TF_CHECK_OK(ifrt_loaded_executable_->Delete().Await()); - } - - bool is_deleted() { return ifrt_loaded_executable_->IsDeleted(); } + void Delete() {} // Will be deleted. // Takes args indexed by argid then deviceid, transposes them, and passes to // PjRtExecutable::Execute. The result is similarly transposed back into the From 63d2f7d3956cec18aa214b0c2416681a0325f1f2 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Wed, 7 May 2025 10:37:55 -0700 Subject: [PATCH 1049/1769] Declare `tpu.vector_load`, mirroring `tpu.vector_store`. At the moment this op does not have any lowerings defined and should not be used. PiperOrigin-RevId: 755914724 --- jaxlib/mosaic/BUILD | 15 ++ jaxlib/mosaic/dialect/tpu/tpu.td | 27 ++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 31 +++ .../dialect/tpu/tpu_ops_verification_test.cc | 246 ++++++++++++++++++ 4 files changed, 319 insertions(+) create mode 100644 jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index a4123d0654bf..a212f7afb8bd 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -232,6 +232,21 @@ cc_test( ], ) +cc_test( + name = "tpu_ops_verification_test", + srcs = ["dialect/tpu/tpu_ops_verification_test.cc"], + deps = [ + ":tpu_dialect", + "//testing/base/public:gunit_main", + "@com_google_absl//absl/status", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", + "@xla//xla/mlir/utils:error_util", + ], +) + filegroup( name = "extension_srcs", srcs = [ diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 8b6e005e1719..c2d35f6f694a 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -245,6 +245,33 @@ def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperand let hasVerifier = 1; } +// tpu.vector_load loads a vector from memory into a register. +// +// base : Memref to load from. +// indices: Scalar indices into base. 
indices must be of the same rank as the +// base memref shape. +// strides: The stride to use for calculating the address of subsequent +// elements. If left unspecified, the stride is implicitly 1 along +// each dimension. Otherwise the stride must match the rank of the +// memref shape. +// mask : Elementwise vector mask. Must be broadcastable to the shape of the +// result vector. Depending on the core type, this may be a dynamic +// (lane) mask consumed from a register or a static (sublane) mask +// that must be the result of arith.constant. +def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask // Elementwise mask. + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `,` type($result) `,` type($mask) + }]; + let hasVerifier = 1; +} + def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> { let arguments = (ins AnyMemRef:$base, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index bbc5be3d125d..6c2e6b700bba 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -508,6 +509,36 @@ LogicalResult VectorStoreOp::verify() { return success(); } +LogicalResult VectorLoadOp::verify() { + const MemRefType ref_ty = getBase().getType(); + if (!getStrides().empty()) { + if (llvm::size(getStrides()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " strides."; + } + return emitError("Not implemented: general vector load with strides."); + } + const VectorType value_ty = getResult().getType(); + + if (value_ty.getElementType() != ref_ty.getElementType()) { + return emitOpError("Expected base and result element type to match."); + } + if (llvm::size(getIndices()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " indices."; + } + if (getMask()) { + if (value_ty.getElementTypeBitWidth() != 32) { + return emitError( + "Not implemented: masked load with non-32-bit element type"); + } + if (vector::isBroadcastableTo(getMask().getType(), value_ty) != + vector::BroadcastableToResult::Success) { + return emitOpError( + "Expected mask shape to be broadcastable to result shape."); + } + } + return success(); +} + LogicalResult ReinterpretCastOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc new file mode 100644 index 000000000000..e92403c21ad0 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc @@ -0,0 +1,246 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include "absl/status/status.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "xla/mlir/utils/error_util.h" + +namespace mlir::tpu { +namespace { + +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::status::StatusIs; + +class TpuOpsVerificationTest : public ::testing::Test { + protected: + TpuOpsVerificationTest() + : context_([]() { + DialectRegistry registry; + registry + .insert(); + return registry; + }()), + builder_(UnknownLoc::get(&context_), &context_) { + context_.loadAllAvailableDialects(); + context_.printOpOnDiagnostic(true); + } + ~TpuOpsVerificationTest() { + for (int i = ops_.size() - 1; i >= 0; --i) { + ops_[i]->erase(); + } + } + + template + OpTy Create(Args&&... args) { + OpTy op = builder_.create(std::forward(args)...); + ops_.push_back(op.getOperation()); + return op; + } + + template + absl::Status VerifyOp(OpTy op) { + BaseScopedDiagnosticHandler diag(&context_); + if (op.verify().succeeded()) { + return absl::OkStatus(); + } + return diag.ConsumeStatus(); + } + + ImplicitLocOpBuilder& builder() { return builder_; } + + private: + MLIRContext context_; + ImplicitLocOpBuilder builder_; + std::vector ops_; +}; + +TEST_F(TpuOpsVerificationTest, VectorLoadVerificationWorks) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadRankOfStridesDoesNotMatchBaseMemrefRank) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({1, 1, 1, 1}), + /*mask=*/nullptr); + ASSERT_THAT(VerifyOp(vl), StatusIs(_, HasSubstr("Expected 1 strides."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadStridesFeatureNotImplemented) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({1}), + /*mask=*/nullptr); + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr("Not implemented: general vector load with strides."))); +} + +TEST_F(TpuOpsVerificationTest, 
VectorLoadBaseAndResultTypesDoNotMatch) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getF32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs(_, + HasSubstr("Expected base and result element type to match."))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadRankOfIndicesDoesNotMatchBaseMemrefRank) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT(VerifyOp(vl), StatusIs(_, HasSubstr("Expected 1 indices."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadValidMaskSucceeds) { + auto c0 = Create(0); + auto memref = Create( + MemRefType::get({8, 128}, builder().getI32Type())); + auto mask = Create( + /*result=*/VectorType::get({8, 1}, builder().getI32Type()), + /*value=*/dyn_cast( + builder().getDenseI32ArrayAttr({1, 1, 1, 1, 1, 1, 1, 1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadMaskInvalidResultBitWidth) { + auto c0 = Create(0); + auto memref = Create( + MemRefType::get({8, 128}, builder().getI64Type())); + auto mask = Create( + /*result=*/VectorType::get({8, 1}, builder().getI32Type()), + /*value=*/dyn_cast( + builder().getDenseI32ArrayAttr({1, 1, 1, 1, 1, 1, 1, 1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI64Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Not implemented: masked load with non-32-bit element type"))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadMaskNotBroadcastableToResultShapeInvalidMinor) { + auto c0 = Create(0); + auto memref = Create( + MemRefType::get({8, 128}, builder().getI32Type())); + auto mask = Create( + /*result=*/VectorType::get({8, 2}, builder().getI32Type()), + /*value=*/dyn_cast(builder().getDenseI32ArrayAttr({1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Expected mask shape to be broadcastable to result shape."))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadMaskNotBroadcastableToResultShapeInvalidMajor) { + auto c0 = Create(0); + auto memref = Create( + MemRefType::get({8, 128}, builder().getI32Type())); + auto mask = Create( + /*result=*/VectorType::get({5, 1}, builder().getI32Type()), + /*value=*/dyn_cast(builder().getDenseI32ArrayAttr({1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_THAT( + VerifyOp(vl), + 
StatusIs( + _, HasSubstr( + "Expected mask shape to be broadcastable to result shape."))); +} + +} // namespace +} // namespace mlir::tpu From 4cc4bd3ca1efd787aaaadc23cd8f5ecdfe8a1c4a Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 7 May 2025 10:50:25 -0700 Subject: [PATCH 1050/1769] Use `ArrayRef` instead of `tsl::RCReference` PiperOrigin-RevId: 755919798 --- jaxlib/pjit.cc | 18 +++++++-------- jaxlib/pmap_lib.cc | 6 ++--- jaxlib/py_array.cc | 44 ++++++++++++++++++------------------ jaxlib/py_array.h | 21 ++++++++--------- jaxlib/py_executable.cc | 31 ++++++++++++------------- jaxlib/py_executable.h | 6 ++--- jaxlib/py_socket_transfer.cc | 11 ++++----- jaxlib/py_values.cc | 33 ++++++++++++--------------- jaxlib/py_values.h | 4 ++-- 9 files changed, 81 insertions(+), 93 deletions(-) diff --git a/jaxlib/pjit.cc b/jaxlib/pjit.cc index 8c8800e80706..804352161597 100644 --- a/jaxlib/pjit.cc +++ b/jaxlib/pjit.cc @@ -419,11 +419,10 @@ PjitFunction::~PjitFunction() { executables_ = nullptr; } -void CallShardArgFallback( - nb::handle arg, nb::handle sharding, nb::handle layout, - const nb::callable& fallback, - std::vector>& num_args_arrays, - std::vector& keep_alive_objects) { +void CallShardArgFallback(nb::handle arg, nb::handle sharding, + nb::handle layout, const nb::callable& fallback, + std::vector& num_args_arrays, + std::vector& keep_alive_objects) { tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); auto py_array_or_bufs = fallback(arg, sharding, layout); auto py_array = nb::cast(py_array_or_bufs); @@ -433,8 +432,7 @@ void CallShardArgFallback( // Prepares the input PjRtBuffers from the python arguments. This is equivalent // to shard_args() in pxla.py but for only a few supported cases. -absl::StatusOr>> -PrepareIfrtInputs( +absl::StatusOr> PrepareIfrtInputs( const xla::PyLoadedExecutable& executable, absl::Span flat_dynamic_args, absl::Span flat_dynamic_arg_signatures, @@ -449,12 +447,12 @@ PrepareIfrtInputs( executable.ifrt_loaded_executable()->num_devices(); int num_args = flat_dynamic_args.size(); - std::vector> num_args_arrays; + std::vector num_args_arrays; num_args_arrays.reserve(num_args); struct CopyGroup { std::vector indices; - std::vector> arrays; + std::vector arrays; }; absl::flat_hash_map, CopyGroup> @@ -760,7 +758,7 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, tsl::Env::Default()->GetCurrentThreadId(); // A vector of [num_outputs]. - std::vector> output_arrays; + std::vector output_arrays; { nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(auto result, diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc index c29c2d1eb2b5..e18dd8b4637a 100644 --- a/jaxlib/pmap_lib.cc +++ b/jaxlib/pmap_lib.cc @@ -110,7 +110,7 @@ struct ResultSpec { struct ShardArgResult { // Points to the on-device array. // ifrt_array->sharding().num_shards() == `num_devices`. - tsl::RCReference ifrt_array; + xla::ifrt::ArrayRef ifrt_array; // The Python argument will be always be copied to `owning_sda`. nb::object owning_sda; }; @@ -615,7 +615,7 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, const int num_args = flat_dynamic_args.size(); // We need [num_args] for the `Execute` call below. - std::vector> num_args_arrays(num_args); + std::vector num_args_arrays(num_args); for (int i = 0; i < num_args; ++i) { TF_ASSIGN_OR_RETURN( ShardArgResult sharded_arg, @@ -634,7 +634,7 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, tsl::Env::Default()->GetCurrentThreadId(); // A vector of [num_outputs]. 
- std::vector> output_arrays; + std::vector output_arrays; { nb::gil_scoped_release gil_release; auto ifrt_executable = cache_entry.executable->ifrt_executable(); diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index a977d5f1a554..1222d410bad8 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -146,12 +146,12 @@ absl::StatusOr XlaDynamicShape(ifrt::Array* ifrt_array, return &scratch.value(); } -tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( +ifrt::ArrayRef CreateIfRtArrayFromSingleDeviceShardedPyArrays( nb_dtype dtype, absl::Span shape, absl::Span py_arrays, const nb::object& sharding) { const ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); - std::vector> ifrt_arrays; + std::vector ifrt_arrays; ifrt_arrays.reserve(py_arrays.size()); absl::InlinedVector devices; devices.reserve(py_arrays.size()); @@ -225,7 +225,7 @@ tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( // TODO(emilyaf): Always use `ifrt_dtype` once tokens are handled correctly. ifrt::DType array_dtype = ifrt_arrays.empty() ? ifrt_dtype.value() : ifrt_arrays[0]->dtype(); - absl::StatusOr> ifrt_array = + absl::StatusOr ifrt_array = device->client()->AssembleArrayFromSingleDeviceArrays( array_dtype, ifrt::Shape(shape), *std::move(ifrt_sharding), absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, @@ -458,7 +458,7 @@ PyArray_Storage::PyArray_Storage( nb::object aval, bool weak_type, xla::nb_dtype dtype, std::vector shape, nb::object sharding, bool committed, nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, xla::PjRtFuture<> result_status) + ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status) : aval(std::move(aval)), weak_type(weak_type), dtype(std::move(dtype)), @@ -515,7 +515,7 @@ void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, PyArray PyArray::MakeFromSingleDeviceArray( nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, bool weak_type, bool committed, + ifrt::ArrayRef ifrt_array, bool weak_type, bool committed, xla::PjRtFuture<> result_status) { if (!llvm::isa(ifrt_array->sharding())) { throw XlaRuntimeError( @@ -547,8 +547,8 @@ PyArray PyArray::MakeFromSingleDeviceArray( PyArray PyArray::MakeFromIfrtArrayAndSharding( nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, nb::object sharding, - bool weak_type, bool committed, bool skip_checks) { + ifrt::ArrayRef ifrt_array, nb::object sharding, bool weak_type, + bool committed, bool skip_checks) { auto shape_span = ifrt_array->shape().dims(); ShapedArrayCacheKey key; key.dtype = ifrt_array->dtype(); @@ -590,7 +590,7 @@ PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { } PyArray PyArrayResultHandler::Call(nb_class_ptr py_client, - tsl::RCReference ifrt_array, + ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status) const { return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, std::move(py_client), Traceback::Get(), std::move(ifrt_array), @@ -606,8 +606,8 @@ PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, std::vector shape, nb::object sharding, nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, bool committed, - bool skip_checks, xla::PjRtFuture<> result_status) { + ifrt::ArrayRef ifrt_array, bool committed, bool skip_checks, + xla::PjRtFuture<> result_status) { auto* self = PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); m_ptr = self; @@ -636,7 +636,7 @@ nb::object PyArray::CheckAndRearrange(const 
absl::Span py_arrays, return this->attr("_check_and_rearrange")(py_arrays, sharding, aval); } -void PyArray::SetIfrtArray(tsl::RCReference ifrt_array) { +void PyArray::SetIfrtArray(ifrt::ArrayRef ifrt_array) { GetStorage().ifrt_array = std::move(ifrt_array); } @@ -683,7 +683,7 @@ nb::object PyArray::arrays() { absl::Status PyArray::set_arrays(nb::object obj) { if (obj.is_none()) { - SetIfrtArray(tsl::RCReference()); + SetIfrtArray(ifrt::ArrayRef()); py_arrays().clear(); return absl::OkStatus(); } @@ -697,9 +697,9 @@ absl::Status PyArray::set_arrays(nb::object obj) { if (list.size() == 0) return absl::OkStatus(); - SetIfrtArray(tsl::RCReference()); + SetIfrtArray(ifrt::ArrayRef()); py_arrays().clear(); - std::vector> ifrt_arrays; + std::vector ifrt_arrays; ifrt_arrays.reserve(list.size()); absl::InlinedVector devices; devices.reserve(list.size()); @@ -1074,7 +1074,7 @@ absl::Status PyArray::Delete() { // buffer has been deleted or a request must be processed via RPC, // especially as this deletion is done per array. ifrt_array()->Delete(); - SetIfrtArray(tsl::RCReference()); + SetIfrtArray(ifrt::ArrayRef()); } return absl::OkStatus(); } @@ -1090,7 +1090,7 @@ bool PyArray::IsDeleted() const { PyArray PyArray::Clone() const { auto array = tsl::FormRef(ifrt_array()); auto* ifrt_client = py_client()->ifrt_client(); - tsl::RCReference out = + ifrt::ArrayRef out = ifrt_client ->CopyArrays(absl::MakeSpan(&array, 1), /*devices=*/std::nullopt, /*memory_kind=*/std::nullopt, @@ -1149,7 +1149,7 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( // kinds. The grouping is enforced by `ifrt::Client::CopyArrays()`. struct Batch { std::vector indexes; - std::vector> ifrt_arrays; + std::vector ifrt_arrays; }; absl::flat_hash_map batches; @@ -1206,7 +1206,7 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); } - std::vector>> ifrt_arrays; + std::vector> ifrt_arrays; { GlobalPyRefManager()->CollectGarbage(); nb::gil_scoped_release gil_release; @@ -1271,7 +1271,7 @@ absl::StatusOr PyArray::BatchedDevicePut( (!force_copy && (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); - std::vector> ifrt_arrays; + std::vector ifrt_arrays; absl::InlinedVector devices; devices.reserve(n_devices); @@ -1334,7 +1334,7 @@ absl::StatusOr PyArray::ReorderShards( GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(), ifrt_array_ptr->shape())); - tsl::RCReference new_ifrt_array; + xla::ifrt::ArrayRef new_ifrt_array; { nb::gil_scoped_release gil_release; @@ -1399,7 +1399,7 @@ absl::StatusOr PyArray::ReorderShards( /*mappings=*/std::move(mappings), }; DCHECK_OK(plan.Validate()); - std::vector> input; + std::vector input; input.push_back(tsl::FormRef(ifrt_array_ptr)); TF_ASSIGN_OR_RETURN( auto remapped, @@ -1717,7 +1717,7 @@ absl::StatusOr> PyHostValue::AsNumPyArray( PrimitiveTypeToNbDtype(shape->element_type())); // Objects that must be kept alive while the array is alive. 
struct Hold { - tsl::RCReference buffer; + ifrt::ArrayRef buffer; std::unique_ptr external_reference_hold; }; auto hold = std::make_unique(); diff --git a/jaxlib/py_array.h b/jaxlib/py_array.h index ddb09bc41771..bf1208c11da5 100644 --- a/jaxlib/py_array.h +++ b/jaxlib/py_array.h @@ -94,8 +94,7 @@ struct PyArray_Storage { std::vector shape, nanobind::object sharding, bool committed, nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, - xla::PjRtFuture<> result_status); + ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status); ~PyArray_Storage(); nanobind::handle AsHandle(); @@ -111,7 +110,7 @@ struct PyArray_Storage { nb_class_ptr py_client; std::optional traceback; - tsl::RCReference ifrt_array; + ifrt::ArrayRef ifrt_array; nanobind::object fully_replicated_array = nanobind::none(); // optional field, used only in python @@ -153,20 +152,19 @@ class PyArray : public nanobind::object { PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, std::vector shape, nanobind::object sharding, nb_class_ptr py_client, - std::optional traceback, - tsl::RCReference ifrt_array, bool committed, - bool skip_checks, + std::optional traceback, ifrt::ArrayRef ifrt_array, + bool committed, bool skip_checks, xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); static PyArray MakeFromSingleDeviceArray( nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, bool weak_type, bool committed, + ifrt::ArrayRef ifrt_array, bool weak_type, bool committed, xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); static PyArray MakeFromIfrtArrayAndSharding( nb_class_ptr py_client, std::optional traceback, - tsl::RCReference ifrt_array, nanobind::object sharding, - bool weak_type, bool committed, bool skip_checks); + ifrt::ArrayRef ifrt_array, nanobind::object sharding, bool weak_type, + bool committed, bool skip_checks); static absl::Status RegisterTypes(nanobind::module_& m); @@ -325,7 +323,7 @@ class PyArray : public nanobind::object { nanobind::object sharding, nanobind::object aval); - void SetIfrtArray(tsl::RCReference ifrt_array); + void SetIfrtArray(ifrt::ArrayRef ifrt_array); Storage& GetStorage(); const Storage& GetStorage() const; @@ -341,8 +339,7 @@ class PyArrayResultHandler { PyArray Call(absl::Span py_arrays) const; PyArray Call(PyArray py_array) const; - PyArray Call(nb_class_ptr py_client, - tsl::RCReference ifrt_array, + PyArray Call(nb_class_ptr py_client, ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const; private: diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index 2902a2bd7be0..d79b236b9241 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -140,8 +140,7 @@ static int GetNumDevices(const ExecuteShardedArg& arg) { return std::get>(arg).size(); } } -static tsl::RCReference GetIfRtArray( - const ExecuteShardedArg& arg) { +static ifrt::ArrayRef GetIfRtArray(const ExecuteShardedArg& arg) { if (std::holds_alternative(arg)) { return tsl::FormRef(std::get(arg).ifrt_array()); } @@ -151,7 +150,7 @@ static tsl::RCReference GetIfRtArray( // insufficient information about the shape (a dummy shape is used). This // should be removed if possible and only be used in the context where the // shape information is unused. 
- std::vector> ifrt_arrays; + std::vector ifrt_arrays; ifrt_arrays.reserve(arg_vector.size()); absl::InlinedVector devices; devices.reserve(arg_vector.size()); @@ -177,11 +176,11 @@ static tsl::RCReference GetIfRtArray( return *ifrt_array; } -void PopulateExecuteShardedResults( - const nb_class_ptr& client, - std::vector> ifrt_arrays, - const PjRtFuture<>& result_status, int num_computations, - std::vector>& outputs) { +void PopulateExecuteShardedResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + const PjRtFuture<>& result_status, + int num_computations, + std::vector>& outputs) { auto traceback = Traceback::Get(); DCHECK_GT(num_computations, 0); int num_output_buffers = ifrt_arrays.size(); @@ -206,7 +205,7 @@ absl::StatusOr ExecuteShardedOnLocalDevicesInternal( ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, std::optional>>& returned_futures) { - std::vector> output_arrays; + std::vector output_arrays; std::unique_ptr> returned_future; int num_computations = ifrt_loaded_executable->addressable_devices().size(); PjRtFuture<> result_status; @@ -224,7 +223,7 @@ absl::StatusOr ExecuteShardedOnLocalDevicesInternal( })); } } - std::vector> arg_arrays(args.size()); + std::vector arg_arrays(args.size()); absl::c_transform(args, arg_arrays.begin(), [&](const ExecuteShardedArg& arg) mutable { return GetIfRtArray(arg); @@ -257,10 +256,10 @@ absl::StatusOr ExecuteShardedOnLocalDevicesInternal( } // namespace -PyExecuteResults::PyExecuteResults( - const nb_class_ptr& client, - std::vector> ifrt_arrays, - int num_computations, PyShardedToken token, PjRtFuture<> result_status) +PyExecuteResults::PyExecuteResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + int num_computations, PyShardedToken token, + PjRtFuture<> result_status) : client_(client), ifrt_arrays_(std::move(ifrt_arrays)), num_computations_(num_computations), @@ -273,7 +272,7 @@ void PyExecuteResults::CheckNotDisassembled() const { } } -std::vector> PyExecuteResults::Consume() { +std::vector PyExecuteResults::Consume() { CheckNotDisassembled(); is_exploded_ = true; return std::move(ifrt_arrays_); @@ -306,7 +305,7 @@ PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { ifrt_arrays_.size()) .c_str()); } - std::vector> ifrt_arrays; + std::vector ifrt_arrays; ifrt_arrays.reserve(ifrt_arrays_.size() - n); for (size_t i = n; i < ifrt_arrays_.size(); ++i) { ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h index 9f8034ed2675..7e329410c763 100644 --- a/jaxlib/py_executable.h +++ b/jaxlib/py_executable.h @@ -89,7 +89,7 @@ class PyShardedToken { class PyExecuteResults { public: PyExecuteResults(const nb_class_ptr& client, - std::vector> ifrt_arrays, + std::vector ifrt_arrays, int num_computations, PyShardedToken token, PjRtFuture<> result_status = PjRtFuture<>()); @@ -102,7 +102,7 @@ class PyExecuteResults { std::vector> out_handlers); - std::vector> Consume(); + std::vector Consume(); PyShardedToken ConsumeToken(); @@ -117,7 +117,7 @@ class PyExecuteResults { bool is_exploded_ = false; bool token_consumed_ = false; nb_class_ptr client_; - std::vector> ifrt_arrays_; + std::vector ifrt_arrays_; int num_computations_; PyShardedToken token_; // Only set if the computation has tokens. 
diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc index a0bd943333ee..89900b02bd93 100644 --- a/jaxlib/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -112,7 +112,7 @@ absl::StatusOr MemorySpaceFromSharding( class IfrtArrayEntry : public PullTable::Entry { public: struct BufferRef { - tsl::RCReference arr; + xla::ifrt::ArrayRef arr; xla::PjRtBuffer* buffer; size_t buf_size; }; @@ -158,7 +158,7 @@ class IfrtArrayEntry : public PullTable::Entry { }; absl::StatusOr> CreatePullEntry( - const std::vector>& arrs, + const std::vector& arrs, std::shared_ptr state, size_t xfer_size) { std::vector refs; for (auto& arr : arrs) { @@ -228,8 +228,7 @@ class PyTransferServer { server_->Connect(xla::ValueOrThrow(SocketAddress::Parse(saddr)))); } - void AwaitPull(uint64_t uuid, - const std::vector>& arrs) { + void AwaitPull(uint64_t uuid, const std::vector& arrs) { server_->AwaitPull(uuid, xla::ValueOrThrow(CreatePullEntry( arrs, premapped_copier_, xfer_size_))); } @@ -260,7 +259,7 @@ absl::StatusOr ArraySpecFromShapeDtypeStruct( } struct BufferSource { - tsl::RCReference arr; + xla::ifrt::ArrayRef arr; xla::PjRtBuffer* buffer; }; @@ -373,7 +372,7 @@ void RegisterTransferServerTypes(nanobind::module_& m) { .def("_await_pull_flat", [](PyTransferServer& self, uint64_t uuid, std::vector inputs) { - std::vector> arrs; + std::vector arrs; arrs.reserve(inputs.size()); for (const xla::PyArray& input : inputs) { arrs.push_back(tsl::FormRef(input.ifrt_array())); diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc index 5225b67ee93b..81f6523d3e14 100644 --- a/jaxlib/py_values.cc +++ b/jaxlib/py_values.cc @@ -81,7 +81,7 @@ namespace { // Prepared data for creating a single shard of an array. Holds a single-device // IFRT array or a host buffer. struct Shard { - explicit Shard(tsl::RCReference ifrt_array, bool weak_type) + explicit Shard(ifrt::ArrayRef ifrt_array, bool weak_type) : ifrt_array_or_host_buffer(std::move(ifrt_array)), weak_type(weak_type), // host_buffer_semantics is not meaningful when @@ -101,14 +101,13 @@ struct Shard { Shard& operator=(Shard&&) noexcept = default; bool is_ifrt_array() const { - return std::holds_alternative>( - ifrt_array_or_host_buffer); + return std::holds_alternative(ifrt_array_or_host_buffer); } ifrt::DType ifrt_dtype() const; const ifrt::Shape& ifrt_shape() const; // Points to the on-device array or on-host buffer. - std::variant, ifrt::Client::HostBuffer> + std::variant ifrt_array_or_host_buffer; bool weak_type; ifrt::Client::HostBufferSemantics host_buffer_semantics; @@ -155,13 +154,12 @@ using DevicePutHandler = std::function( // buffer, and be not applied when reusing an existing IFRT array. // // Expected to be called without holding GIL. -absl::StatusOr> -MakeSingleDeviceIfrtArrayFromShard( +absl::StatusOr MakeSingleDeviceIfrtArrayFromShard( xla::ifrt::Client* ifrt_client, xla::ifrt::Device* ifrt_device, xla::ifrt::MemoryKind ifrt_memory_kind, Shard& shard, tsl::RCReference user_context) { - if (auto* ifrt_array = std::get_if>( - &shard.ifrt_array_or_host_buffer)) { + if (auto* ifrt_array = + std::get_if(&shard.ifrt_array_or_host_buffer)) { return std::move(*ifrt_array); } else { auto host_buffer_shard = std::get( @@ -181,7 +179,7 @@ MakeSingleDeviceIfrtArrayFromShard( // path). `shards` will be consumed. // // Expected to be called without holding GIL. 
-absl::StatusOr> MakeIfrtArrayFromShardsInBatch( +absl::StatusOr MakeIfrtArrayFromShardsInBatch( ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, ifrt::ShardingRef ifrt_sharding, absl::Span shards, tsl::RCReference user_context) { @@ -221,8 +219,7 @@ absl::StatusOr> MakeIfrtArrayFromShardsInBatch( // `shards` will be consumed. // // Expected to be called without holding GIL. -absl::StatusOr> -MakeIfrtArrayFromShardsWithAssembly( +absl::StatusOr MakeIfrtArrayFromShardsWithAssembly( ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, ifrt::ShardingRef ifrt_sharding, ifrt::DeviceList* ifrt_addressable_device_list, @@ -230,10 +227,10 @@ MakeIfrtArrayFromShardsWithAssembly( tsl::RCReference user_context) { absl::Span ifrt_addressable_devices = ifrt_addressable_device_list->devices(); - std::vector> ifrt_array_shards; + std::vector ifrt_array_shards; ifrt_array_shards.reserve(shards.size()); for (int64_t i = 0; i < shards.size(); ++i) { - TF_ASSIGN_OR_RETURN(tsl::RCReference ifrt_array_shard, + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array_shard, MakeSingleDeviceIfrtArrayFromShard( ifrt_client, ifrt_addressable_devices[i], ifrt_memory_kind, shards[i], user_context)); @@ -569,8 +566,7 @@ absl::StatusOr HandlePyArray(nb::handle obj, ifrt::Client* client, ifrt::DType Shard::ifrt_dtype() const { if (is_ifrt_array()) { - return std::get>(ifrt_array_or_host_buffer) - ->dtype(); + return std::get(ifrt_array_or_host_buffer)->dtype(); } else { return std::get(ifrt_array_or_host_buffer).dtype; } @@ -578,8 +574,7 @@ ifrt::DType Shard::ifrt_dtype() const { const ifrt::Shape& Shard::ifrt_shape() const { if (is_ifrt_array()) { - return std::get>(ifrt_array_or_host_buffer) - ->shape(); + return std::get(ifrt_array_or_host_buffer)->shape(); } else { return std::get(ifrt_array_or_host_buffer).shape; } @@ -916,7 +911,7 @@ absl::StatusOr DevicePutWithDevice( nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fn)()); - TF_ASSIGN_OR_RETURN(tsl::RCReference ifrt_array, + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array, MakeSingleDeviceIfrtArrayFromShard( ifrt_client, ifrt_device, ifrt_memory_kind, shard, std::move(ifrt_user_context))); @@ -1024,7 +1019,7 @@ absl::StatusOr DevicePutWithSharding( /*is_fully_replicated=*/false); } - tsl::RCReference ifrt_array; + ifrt::ArrayRef ifrt_array; if (should_batch) { TF_ASSIGN_OR_RETURN(ifrt_array, MakeIfrtArrayFromShardsInBatch( diff --git a/jaxlib/py_values.h b/jaxlib/py_values.h index 40b186fb7fc0..64a83aa66ab9 100644 --- a/jaxlib/py_values.h +++ b/jaxlib/py_values.h @@ -38,7 +38,7 @@ limitations under the License. namespace xla { struct DevicePutResult { - DevicePutResult(tsl::RCReference ifrt_array, bool weak_type) + DevicePutResult(ifrt::ArrayRef ifrt_array, bool weak_type) : ifrt_array(std::move(ifrt_array)), weak_type(weak_type) {} // Disallow copy. `DevicePutResult` is expected to be consumed by one user. @@ -48,7 +48,7 @@ struct DevicePutResult { DevicePutResult& operator=(DevicePutResult&&) noexcept = default; // Points to the on-device array. 
- tsl::RCReference ifrt_array; + ifrt::ArrayRef ifrt_array; bool weak_type; }; From e8e586e698d36fa5bab363c93b251f9c21c68e7f Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 7 May 2025 11:03:00 -0700 Subject: [PATCH 1051/1769] expose mutable_array in experimental PiperOrigin-RevId: 755925216 --- jax/experimental/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 6c37635df1b0..1b4f7efedbe7 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -32,3 +32,7 @@ from jax._src.earray import ( EArray as EArray ) +from jax._src.core import ( + mutable_array as mutable_array, + MutableArray as MutableArray, +) From a4adf32f9a9bf0ad80e4ae6ad6cbb6c00cc1dae2 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 7 May 2025 11:19:22 -0700 Subject: [PATCH 1052/1769] Replace `std::shared_ptr` with `xla::ifrt::LoadedExecutableRef` PiperOrigin-RevId: 755932233 --- jaxlib/py_compile_only_client.cc | 4 ++-- jaxlib/py_executable.cc | 2 +- jaxlib/py_executable.h | 13 ++++++------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index 4d53fc6ee832..0fa2f4b48fd7 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -70,7 +70,7 @@ class CompileOnlyPyClient : public PyClient { return client; } - absl::StatusOr> CompileUnloaded( + absl::StatusOr CompileUnloaded( absl::string_view mlir_module, ifrt::DeviceListRef executable_devices, CompileOptions options, std::vector host_callbacks) { if (!host_callbacks.empty()) { @@ -102,7 +102,7 @@ class CompileOnlyPyClient : public PyClient { *ifrt_client->topology().description())); TF_ASSIGN_OR_RETURN(auto ifrt_executable, ifrt::PjRtExecutable::Create(std::move(executable))); - return std::shared_ptr(std::move(ifrt_executable)); + return ifrt::ExecutableRef(std::move(ifrt_executable)); } private: diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index d79b236b9241..a9a7ebea2d0b 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -84,7 +84,7 @@ absl::Status PyShardedToken::Await() { PyLoadedExecutable::PyLoadedExecutable( nb_class_ptr client, - std::shared_ptr ifrt_loaded_executable, + ifrt::LoadedExecutableRef ifrt_loaded_executable, std::optional traceback, std::optional fingerprint) : client_(std::move(client)), diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h index 7e329410c763..0e7762730cfb 100644 --- a/jaxlib/py_executable.h +++ b/jaxlib/py_executable.h @@ -131,11 +131,10 @@ using ExecuteShardedArg = std::variant>; // b) to add Python-specific functionality. 
class PyLoadedExecutable { public: - PyLoadedExecutable( - nb_class_ptr client, - std::shared_ptr ifrt_loaded_executable, - std::optional traceback, - std::optional fingerprint); + PyLoadedExecutable(nb_class_ptr client, + ifrt::LoadedExecutableRef ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint); ~PyLoadedExecutable(); nb_class_ptr client() const { return client_; } @@ -143,7 +142,7 @@ class PyLoadedExecutable { return ifrt_loaded_executable_.get(); } - std::shared_ptr shared_ifrt_loaded_executable() { + ifrt::LoadedExecutableRef shared_ifrt_loaded_executable() { return ifrt_loaded_executable_; } @@ -221,7 +220,7 @@ class PyLoadedExecutable { friend class PyClient; nb_class_ptr client_; - std::shared_ptr ifrt_loaded_executable_; + ifrt::LoadedExecutableRef ifrt_loaded_executable_; std::optional traceback_; // Identical executables (i.e. representing the same program) will have the From 6259cfe31742e8985f9522beb1d64d7903002950 Mon Sep 17 00:00:00 2001 From: Jaswanth Sreeram Date: Wed, 7 May 2025 11:26:27 -0700 Subject: [PATCH 1053/1769] Fold transposes feeding the RHS of a matmul into a change in the dot dimension numbers. PiperOrigin-RevId: 755935066 --- .../tpu/transforms/canonicalize_mosaic.cc | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 7f241deb550c..7a5e7ba0c10b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -205,6 +205,37 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, } } + // Attempt to canonicalize matmul(x, transpose(y)) to a matmul with the + // dimension numbers changed which will later be lowered into a more efficient + // operation that fuses the transpose into the matmul. + auto transpose_op = + dyn_cast_if_present(rhs.getDefiningOp()); + auto dimension_numbers = op.getDimensionNumbers(); + if (transpose_op && transpose_op->hasOneUse() && + dimension_numbers->getRhsContractingDims().size() == 1 && + dimension_numbers->getRhsNonContractingDims().size() == 1) { + auto rhs_non_contracting_dim = + dimension_numbers->getRhsNonContractingDims()[0]; + auto rhs_contracting_dim = dimension_numbers->getRhsContractingDims()[0]; + auto permutation = transpose_op.getPermutation(); + if (permutation[rhs_contracting_dim] == rhs_non_contracting_dim && + permutation[rhs_non_contracting_dim] == rhs_contracting_dim && + std::all_of(dimension_numbers->getRhsBatchDims().begin(), + dimension_numbers->getRhsBatchDims().end(), + [&](long batch_dim) { + return permutation[batch_dim] == batch_dim; + })) { + if (auto transpose_op_vector_operand = + dyn_cast>(transpose_op.getOperand())) { + // The transpose is DCE'ed away at a later point. + rhs = transpose_op_vector_operand; + transpose_rhs = !transpose_rhs; + } else { + return op->emitOpError("Unexpected operand type for transpose op."); + } + } + } + auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) { auto precision_attr = op.getPrecisionAttr(); From 3ea7a5dac2f1188e79b6c3306fb07d74485d310f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 7 May 2025 13:00:51 -0700 Subject: [PATCH 1054/1769] [pallas:mosaic] Use `cf.assert` directly in the lowering rule for `checkify.check_p` We now bundle the `cf` dialect with jaxlib and register it in the `ir.Context`, so `cf.assert` can be used directly. 
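A minimal sketch of the call-site change, using only names grounded in the diff below (`not_pred` stands in for the `i1` predicate the lowering rule already holds; requires jaxlib >= 0.6.1):

```python
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import cf

def emit_assert(not_pred: ir.Value, msg: str) -> None:
  # Before: the op had to be created by name, which only worked with
  # allow_unregistered_dialects=True on the ir.Context:
  #   ir.Operation.create("cf.assert", operands=(not_pred,),
  #                       attributes={"msg": ir.StringAttr.get(msg)})
  # After: the bundled, registered cf dialect exposes the builder directly.
  cf.assert_(not_pred, msg)
```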
PiperOrigin-RevId: 755972714 --- jax/_src/pallas/mosaic/lowering.py | 23 +++++++------------ .../pallas/mosaic/pallas_call_registration.py | 3 +-- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f372f8f9b472..6aae3a07c764 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -51,6 +51,7 @@ from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import cf from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import math from jax._src.lib.mlir.dialects import memref @@ -160,7 +161,6 @@ def to_placeholder(self, dim_expr: Any) -> ir.Value: @dataclasses.dataclass class LoweringContext: - ir_context: ir.Context grid_sizes: tuple[int, ...] # Includes both user and vmap axes. grid_names: tuple[Hashable, ...] | None mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. @@ -668,7 +668,6 @@ def err_details(): def lower_jaxpr_to_module( lowering_context: mlir.LoweringRuleContext, - ctx: ir.Context, grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, *, @@ -720,7 +719,6 @@ def dynamic_shape_replacement_fn( sym_tab = ir.SymbolTable(m.operation) func_op = lower_jaxpr_to_func( - ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping, name="main", @@ -755,7 +753,6 @@ def dynamic_shape_replacement_fn( continue mlir_func = lower_jaxpr_to_transform_func( - ctx, bm.index_map_jaxpr.jaxpr, bm.block_aval, name=func_name, @@ -902,7 +899,6 @@ def dynamic_shape_replacement_fn( def lower_jaxpr_to_transform_func( - ctx: ir.Context, jaxpr: jax_core.Jaxpr, aval: jax_core.AbstractValue, *, @@ -937,7 +933,6 @@ def body_func(*args): else: mesh_context = None lowering_context = LoweringContext( - ctx, mosaic_grid_mapping.grid, mosaic_grid_mapping.grid_names, mosaic_grid_mapping.mapped_dims, @@ -970,7 +965,6 @@ def body_func(*args): def lower_jaxpr_to_func( - ctx: ir.Context, jaxpr: jax_core.Jaxpr, *, mosaic_grid_mapping: MosaicGridMapping, @@ -1009,7 +1003,6 @@ def body_func(*args): else: mesh_context = None lowering_context = LoweringContext( - ctx, mosaic_grid_mapping.grid, mosaic_grid_mapping.grid_names, mosaic_grid_mapping.mapped_dims, @@ -3742,9 +3735,12 @@ def _checkify_lowering_rule( "--jax_pallas_enable_runtime_assert " "or functionalize with checkify.check.") - assert ctx.lowering_context.ir_context.allow_unregistered_dialects, ( - "allow_unregistered_dialects must be set to True for " - "runtime assert check.") + if cf is None: + # TODO(slebedev): Remove once the minimal jaxlib version is 0.6.1. + raise ValueError( + "cf dialect is not available. Make sure you have jaxlib 0.6.1 or later." 
+ ) + error = jax.tree.unflatten(err_tree, err_args) assert len(error._pred) == 1 assert len(error._metadata) == 1 @@ -3761,10 +3757,7 @@ def _checkify_lowering_rule( out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool')) minus_one = ir_constant(-1, out_scalar_type) not_pred = arith.xori(pred, minus_one) - attrs = {"msg": ir.StringAttr.get(exception.fmt_string)} - ir.Operation.create("cf.assert", - operands=(not_pred,), - attributes=attrs) + cf.assert_(not_pred, exception.fmt_string) return [] diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index c1d1a8029c5f..5de917d077ce 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -142,12 +142,11 @@ def pallas_call_tpu_lowering_rule( tpu.register_dialect(mlir_ctx) def lower_module(for_verification: bool): - if for_verification or tpu_core.runtime_assert_enabled(): + if for_verification: mlir_ctx.allow_unregistered_dialects = True with mlir_ctx, ir.Location.unknown(mlir_ctx): return lowering.lower_jaxpr_to_module( ctx, - mlir_ctx, grid_mapping, jaxpr, dimension_semantics=mosaic_params.dimension_semantics, From 261141615dd3fe2b15f7c886acb290b6e9de09ea Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 7 May 2025 14:16:30 -0700 Subject: [PATCH 1055/1769] jnp.packbits: fix handling of negative entries --- jax/_src/numpy/lax_numpy.py | 2 +- tests/lax_numpy_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c40716daee4b..b662f1d6f7ed 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8945,7 +8945,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar raise TypeError('Expected an input array of integer or boolean data type') if bitorder not in ['little', 'big']: raise ValueError("'order' must be either 'little' or 'big'") - arr = lax.gt(arr, _lax_const(arr, 0)).astype('uint8') + arr = lax.ne(arr, _lax_const(arr, 0)).astype('uint8') bits = arange(8, dtype='uint8') if bitorder == 'big': bits = bits[::-1] diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 2e4a84c6c99e..875024617b5f 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4661,7 +4661,7 @@ def testRollaxis(self, shape, dtype, start, axis): self._CompileAndCheck(jnp_op, args_maker) @jtu.sample_product( - dtype=[np.uint8, np.bool_], + dtype=int_dtypes + unsigned_dtypes + bool_dtypes, bitorder=['big', 'little'], shape=[(1, 2, 3, 4)], axis=[None, 0, 1, -2, -1], From 412de97e2b2962324ec30b0c801559d19d070e34 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 7 May 2025 14:24:46 -0700 Subject: [PATCH 1056/1769] Clean up `PyLoadedExecutable::Delete` PiperOrigin-RevId: 756006422 --- jaxlib/_jax/__init__.pyi | 1 - jaxlib/py_executable.h | 2 -- jaxlib/xla.cc | 1 - 3 files changed, 4 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index c0c0dc77ccb7..c6c8a67e4653 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -678,7 +678,6 @@ class LoadedExecutable: client: Client def local_devices(self) -> list[Device]: ... def size_of_generated_code_in_bytes(self) -> int: ... - def delete(self) -> None: ... def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ... 
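# Worked example for the `jnp.packbits` fix in patch 1055 above: NumPy
# binarizes by comparing against zero, so any nonzero entry (including a
# negative one) packs as a set bit, while the old `lax.gt(arr, 0)` dropped
# negatives. Expected values below are computed by hand
# (0b10101001 == 169, 0b00100001 == 33):
import numpy as np
import jax.numpy as jnp

bits = [-1, 0, 2, 0, -3, 0, 0, 1]
print(np.packbits(np.array(bits)))                    # [169]
print(jnp.packbits(jnp.array(bits, dtype=jnp.int8)))  # [169] after the fix
                                                      # (previously [33])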
def execute_with_token( self, arguments: Sequence[ArrayImpl] diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h index 0e7762730cfb..6354edaf9a3e 100644 --- a/jaxlib/py_executable.h +++ b/jaxlib/py_executable.h @@ -161,8 +161,6 @@ class PyLoadedExecutable { return ifrt_loaded_executable_->GetCostAnalysis(); } - void Delete() {} // Will be deleted. - // Takes args indexed by argid then deviceid, transposes them, and passes to // PjRtExecutable::Execute. The result is similarly transposed back into the // argid,deviceid format. diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 8b47223591dd..fa83afc2dc1f 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -511,7 +511,6 @@ NB_MODULE(_jax, m) { .def( "get_compiled_memory_stats", xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) - .def("delete", &PyLoadedExecutable::Delete) .def("execute_sharded", xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), nb::arg("arguments"), nb::arg("with_tokens") = false) From dd375cbd336ea49339c349ad350adef66cbd6176 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 7 May 2025 15:50:40 -0700 Subject: [PATCH 1057/1769] Deprecate parsing of __jax_array__ during abstractification. Going forward, objects defining __jax_array__ should define pytree lowering if they want to be compatible with JAX transformations. --- jax/_src/core.py | 13 +++++++++++++ jax/_src/deprecations.py | 1 + jax/_src/numpy/lax_numpy.py | 2 +- tests/api_test.py | 24 +++++++++++++++++++++++- tests/array_extensibility_test.py | 3 +++ 5 files changed, 41 insertions(+), 2 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 780d2c487173..2ee3136e6c65 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -34,6 +34,7 @@ import numpy as np +from jax._src import deprecations from jax._src import dtypes from jax._src import config from jax._src import effects @@ -1554,6 +1555,12 @@ def shaped_abstractify(x): if isinstance(x, AbstractValue): return x if hasattr(x, '__jax_array__'): + deprecations.warn( + 'jax-abstract-dunder-array', + ('Triggering of __jax_array__() during abstractification is deprecated.' + ' To avoid this error, either explicitly convert your object using' + ' jax.numpy.array(), or register your object as a pytree.'), + stacklevel=6) return shaped_abstractify(x.__jax_array__()) if hasattr(x, 'dtype'): aval = ShapedArray(np.shape(x), x.dtype, @@ -1578,6 +1585,12 @@ def get_aval(x): if (aval_fn := pytype_aval_mappings.get(t)): return aval_fn(x) if hasattr(x, '__jax_array__'): + deprecations.warn( + 'jax-abstract-dunder-array', + ('Triggering of __jax_array__() during abstractification is deprecated.' 
+ ' To avoid this error, either explicitly convert your object using' + ' jax.numpy.array(), or register your object as a pytree.'), + stacklevel=6) return get_aval(x.__jax_array__()) raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type") diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 329491b1e8a8..4e5e22745658 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -135,3 +135,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-trimzeros-not-1d-array') register('jax-scipy-special-sph-harm') register('jax-jit-positional-args') +register('jax-abstract-dunder-array') diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c40716daee4b..be3ac14647a8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1235,7 +1235,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: [2, 5], [3, 6]], dtype=int32) """ - util.check_arraylike("permute_dims", a) + a = util.ensure_arraylike("permute_dims", a) return lax.transpose(a, axes) diff --git a/tests/api_test.py b/tests/api_test.py index 610719518a03..53afed80cb70 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3959,6 +3959,9 @@ def test_default_device(self): def test_dunder_jax_array(self): # https://github.com/jax-ml/jax/pull/4725 + @partial(jax.tree_util.register_dataclass, + data_fields=['jax_val'], + meta_fields=[]) class AlexArray: def __init__(self, jax_val): self.jax_val = jax_val @@ -3968,10 +3971,16 @@ def __jax_array__(self): shape = property(lambda self: self.jax_val.shape) x = AlexArray(jnp.array([1., 2., 3.])) + + y = jax.jit(lambda x: x)(x) + self.assertIsInstance(x, AlexArray) + self.assertArraysEqual(jnp.asarray(x), jnp.asarray(y)) + y = jnp.sin(x) self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.]))) y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x) - self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.]))) + self.assertIsInstance(y, AlexArray) + self.assertAllClose(jnp.asarray(y), jnp.cos(jnp.array([1., 2., 3.]))) x = AlexArray(jnp.array([[1., 2., 3.]])) y = api.pmap(jnp.sin)(x) @@ -3989,6 +3998,19 @@ def __jax_array__(self): a2 = jnp.array(((x, x), [x, x])) self.assertAllClose(np.array(((1, 1), (1, 1))), a2) + def test_dunder_jax_array_warnings(self): + class AlexArray: + def __init__(self, jax_val): + self.jax_val = jax_val + def __jax_array__(self): + return self.jax_val + + f = jax.jit(lambda x: x) + a = AlexArray(jnp.arange(4)) + msg = r"Triggering of __jax_array__\(\) during abstractification is deprecated." + with self.assertDeprecationWarnsOrRaises('jax-abstract-dunder-array', msg): + f(a) + @jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe def test_eval_shape_weak_type(self): # https://github.com/jax-ml/jax/issues/23302 diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 91e2a1d9cf6d..36726659c2f9 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -29,6 +29,9 @@ config.parse_flags_with_absl() +@functools.partial(jax.tree_util.register_dataclass, + data_fields=['x'], + meta_fields=[]) class JaxArrayWrapper: """Class that provides a __jax_array__ method.""" x: ArrayLike From cfab66d540a7ebbeb37b815b5f250691b0ca98d7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 7 May 2025 16:30:03 -0700 Subject: [PATCH 1058/1769] [Mosaic] Move pallas + mosaic over to tpu transpose from vector. Keep vector around for compat reasons, as legacy. 
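The user-visible behavior is unchanged; the Pallas rule just gates which dialect op it emits, as the `lowering.py` hunk below shows. A sketch of that gate (a fragment, not standalone code: `ctx`, `x`, `permutation`, and `out_type` are the rule's existing locals, and the helpers are the ones `lowering.py` already imports):

```python
# Forward-compatible lowerings and older Cloud TPU runtimes keep getting the
# legacy vector-dialect op; newer targets get the new TPU-dialect op.
if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 5, 8):
  return vector.transpose(out_type, x, permutation)
else:
  return tpu.transpose(out_type, x, permutation)
```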
PiperOrigin-RevId: 756051773 --- jax/_src/pallas/mosaic/lowering.py | 5 ++- jaxlib/mosaic/dialect/tpu/tpu.td | 23 ++++++++++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 35 +++++++++++++++++++ .../tpu/transforms/apply_vector_layout.cc | 6 ++-- .../tpu/transforms/canonicalize_mosaic.cc | 16 +++++++-- .../tpu/transforms/infer_vector_layout.cc | 6 ++-- 6 files changed, 82 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 6aae3a07c764..d7e26ec3b342 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2357,7 +2357,10 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) - return vector.transpose(out_type, x, permutation) + if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 5, 8): + return vector.transpose(out_type, x, permutation) + else: + return tpu.transpose(out_type, x, permutation) def _bcast(x, y, x_aval, y_aval, out_aval): diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index c2d35f6f694a..7f295b4ec09b 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -909,6 +909,29 @@ def TPU_SublaneShuffleOp : TPU_Op<"sublane_shuffle", [SameOperandsAndResultType] let hasVerifier = 1; } +def TPU_TransposeOp : TPU_Op<"transpose", [Pure]> { + let summary = "tpu transpose operation"; + let arguments = (ins AnyVectorOfAnyRank:$vector, + DenseI64ArrayAttr:$permutation); + let results = (outs AnyVectorOfAnyRank:$result); + + let builders = [ + OpBuilder<(ins "Value":$vector, "ArrayRef":$permutation)> + ]; + let assemblyFormat = [{ + $vector `,` $permutation attr-dict `:` type($vector) `->` type($result) + }]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return ::llvm::cast(getVector().getType()); + } + VectorType getResultVectorType() { + return ::llvm::cast(getResult().getType()); + } + }]; + let hasVerifier = 1; +} + def TPU_LogOp : TPU_Op<"log"> { let arguments = (ins Variadic:$inputs, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 6c2e6b700bba..134db412042d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -324,6 +324,41 @@ LogicalResult MemRefReshapeOp::verify() { return success(); } +LogicalResult TransposeOp::verify() { + auto source_type = getSourceVectorType(); + auto permutation = getPermutation(); + auto output_type = getResultVectorType(); + auto input_shape = source_type.getShape(); + auto output_shape = output_type.getShape(); + if (source_type.getElementType() != output_type.getElementType()) { + return emitOpError("Expected input and output element types to match"); + } + if (permutation.size() != source_type.getRank()) { + return emitOpError("Expected permutation rank to match input rank"); + } + if (permutation.size() != output_type.getRank()) { + return emitOpError("Expected permutation rank to match output rank"); + } + std::vector seen_dims(source_type.getRank(), false); + for (int64_t dim : permutation) { + if (dim < 0 || dim >= source_type.getRank()) { + return emitOpError("Permutation element out of bounds: ") << dim; + } + if (seen_dims[dim]) { + return emitOpError("Permutation element repeated: ") << dim; + } + seen_dims[dim] = true; + } + for (int i = 0; i < source_type.getRank(); ++i) { + if (input_shape[permutation[i]] != output_shape[i]) { + return 
emitOpError( + "Expected input shape permuted by the given permutation to match the " + "output shape"); + } + } + return success(); +} + LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, PatternRewriter &rewriter) { auto src_ty = op.getInput().getType(); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 04902c2af6f3..f8e18070e5e7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4652,7 +4652,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: Unsupported 2D layouts"); } ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto transpose_op = cast(op); + auto transpose_op = cast(op); VectorType src_ty = transpose_op.getSourceVectorType(); VectorType dst_ty = transpose_op.getResultVectorType(); const int64_t rank = src_ty.getRank(); @@ -4735,7 +4735,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const Value src_tile = assemble(builder, tile_ty_in, layout_in, src_tile_vregs, ctx.target_shape); auto new_transpose_op = - builder.create(tile_ty_out, src_tile, minor_perm); + builder.create(tile_ty_out, src_tile, minor_perm); new_transpose_op->setAttr("out_layout", builder.getAttr(layout_out)); auto unroll_vectors_op = builder.create( @@ -4871,7 +4871,7 @@ const llvm::StringMap &rules() { vector_extract_strided_slice_rule}, {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, {vector::StoreOp::getOperationName(), vector_store_rule}, - {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; + {tpu::TransposeOp::getOperationName(), vector_transpose_rule}}; for (const auto &[name, rule] : mlir::tpu::extensions::rules()) { rules->insert({name, rule}); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 7a5e7ba0c10b..110550127ca5 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -209,7 +209,7 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, // dimension numbers changed which will later be lowered into a more efficient // operation that fuses the transpose into the matmul. 
auto transpose_op = - dyn_cast_if_present(rhs.getDefiningOp()); + dyn_cast_if_present(rhs.getDefiningOp()); auto dimension_numbers = op.getDimensionNumbers(); if (transpose_op && transpose_op->hasOneUse() && dimension_numbers->getRhsContractingDims().size() == 1 && @@ -259,7 +259,7 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, const SmallVector perm_vec = SmallVector(perm.begin(), perm.end()); - lhs = builder.create( + lhs = builder.create( lhs_ty_transposed, lhs, DenseI64ArrayAttr::get(builder.getContext(), perm_vec)); } @@ -703,6 +703,17 @@ LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx, return success(); } +LogicalResult canonicalize_vector_transpose(const CanonicalizeContext &ctx, + Operation &raw_op) { + auto op = cast(raw_op); + ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + auto new_op = builder.create(op.getType(), op.getVector(), + op.getPermutation()); + op.replaceAllUsesWith(new_op.getResult()); + op.erase(); + return success(); +} + using canonicalize_rule_type = std::function; @@ -713,6 +724,7 @@ const llvm::StringMap &rules() { {vector::ExtractOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}, + {vector::TransposeOp::getOperationName(), canonicalize_vector_transpose}, {arith::SelectOp::getOperationName(), canonicalize_select}, {arith::FPToSIOp::getOperationName(), canonicalize_fptosi}, {tpu::RepeatOp::getOperationName(), canonicalize_repeat}}; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index c81701d9a398..2e4c1c9c48a9 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -320,7 +320,7 @@ class VectorLayoutInferer { if (inferStore(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { + } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } @@ -1622,7 +1622,7 @@ class VectorLayoutInferer { return success(); } - LogicalResult infer(vector::TransposeOp op) { + LogicalResult infer(tpu::TransposeOp op) { auto permutation = op.getPermutation(); TPU_CHECK_OP(permutation.size() > 1, "Vector and scalar transpose should be a no-op and removed"); @@ -1910,7 +1910,7 @@ class VectorLayoutInferer { continue; } } - if (auto transpose = dyn_cast(operand.getOwner())) { + if (auto transpose = dyn_cast(operand.getOwner())) { auto perm = transpose.getPermutation(); auto rank = perm.size(); // Only permutations that actually swap the last two dims need it. From 292b468af2bc454df2906663a1972ce285201579 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 7 May 2025 16:35:11 -0700 Subject: [PATCH 1059/1769] Get rid of `HloModuleProto` from `CompiledMemoryStats` `HloModuleProto` can be obtained from `LoadedExecutable::GetHloModuleProtos()`, so it's wasteful to duplicate this information. 
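On the Python side the rename shows up on the memory-stats object; a hedged migration sketch (assuming `Compiled.memory_analysis()` surfaces the `CompiledMemoryStats` fields declared in the `.pyi` hunk below):

```python
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x @ x.T).lower(jnp.ones((8, 8))).compile()
stats = compiled.memory_analysis()
# Before: a whole HloProto was duplicated onto the stats object and callers
# parsed the buffer assignment out of stats.serialized_hlo_proto.
# After: only the serialized BufferAssignmentProto is carried.
ba_bytes = stats.serialized_buffer_assignment_proto
```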
PiperOrigin-RevId: 756053432 --- jaxlib/_jax/__init__.pyi | 2 +- jaxlib/xla.cc | 22 +++++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index c6c8a67e4653..8c02bb4ba722 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -328,7 +328,7 @@ class CompiledMemoryStats: host_output_size_in_bytes: int host_alias_size_in_bytes: int host_temp_size_in_bytes: int - serialized_hlo_proto: bytes + serialized_buffer_assignment_proto: bytes def __str__(self) -> str: ... class ExecutableBuildOptions: diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index fa83afc2dc1f..0d3d8f6e1b29 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -487,10 +487,26 @@ NB_MODULE(_jax, m) { &CompiledMemoryStats::host_alias_size_in_bytes) .def_rw("host_temp_size_in_bytes", &CompiledMemoryStats::host_temp_size_in_bytes) - .def_prop_ro("serialized_hlo_proto", + .def_prop_ro("serialized_buffer_assignment_proto", [](const CompiledMemoryStats& cms) -> nb::bytes { - return nb::bytes(cms.serialized_hlo_proto.data(), - cms.serialized_hlo_proto.size()); +#if JAX_IFRT_VERSION_NUMBER >= 7 + if (cms.buffer_assignment.has_value()) { + std::string s = + cms.buffer_assignment->SerializeAsString(); + return nb::bytes(s.data(), s.size()); + } else { + return nb::bytes(); + } +#else + xla::HloProto hlo; + if (!cms.serialized_hlo_proto.empty() && + hlo.ParseFromString(cms.serialized_hlo_proto)) { + std::string s = + hlo.buffer_assignment().SerializeAsString(); + return nb::bytes(s.data(), s.size()); + } + return nb::bytes(); +#endif }) .def("__str__", &CompiledMemoryStats::DebugString);
From 0a68794ebb566d898eb1de28959e91f12faa3167 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 7 May 2025 16:57:07 -0700 Subject: [PATCH 1060/1769] Disable specific Triton fused_attention backwards test on A100. PiperOrigin-RevId: 756060415 --- tests/pallas/gpu_ops_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index a33760cbfa86..3c352afe3382 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -267,6 +267,17 @@ def test_fused_attention_bwd( causal, use_segment_ids, ): + if jtu.is_cuda_compute_capability_equal("8.0") and all([ + block_sizes["block_q"] == 128, + batch_size == 2, + num_heads == 2, + head_dim == 128, + causal, + not use_segment_ids + ]): + # TODO(b/416306534) + self.skipTest("Precision issues after CUDA 12.8.1 upgrade") + k1, k2, k3 = random.split(random.key(0), 3) q = random.normal( k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
From 48140bcbe7358f745e53bc3d1f87ad0465fbb9dd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 7 May 2025 17:05:51 -0700 Subject: [PATCH 1061/1769] Add basic support of `unreduced` to sharding-in-types! We cannot lower it right now, but it at least shows up in types. The API to specify unreduced is via `PartitionSpec`. For example: `PartitionSpec('x', 'y', None, unreduced='z')` or `PartitionSpec('x', unreduced=('y', 'z'))`. In types/jaxpr, unreduced will show up as: `f32[8@x,2]{U:y}` But we support unreduced only in dot_general and nary ops (add, mul, etc) as of this change (the support will be expanded in following changes): * **dot general** only allows unreduced when contracting dims are sharded. And the unreduced axes specified by the user need to match the sharding of the contracting dims. In all other cases, an error is raised.
An example of how unreduced can be specified: `jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))` * **nary ops** can propagate unreduced (add, mul, etc). If all ops aren't unreduced across the same mesh axes, an error is raised. PiperOrigin-RevId: 756063074 --- jax/_src/core.py | 37 +++++++++----- jax/_src/interpreters/batching.py | 2 + jax/_src/lax/lax.py | 36 +++++++++++-- jax/_src/named_sharding.py | 32 ++++++++++-- jax/_src/numpy/einsum.py | 7 +-- jax/_src/partition_spec.py | 85 ++++++++++++++++++++++++++----- tests/array_test.py | 63 +++++++++++++++++++++++ tests/pjit_test.py | 84 +++++++++++++++++++++++++++++- 8 files changed, 308 insertions(+), 38 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 780d2c487173..da9efdb0eacc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1825,11 +1825,12 @@ def get_cur_mesh_sharding(spec=None): return NamedSharding(mesh_lib.get_abstract_mesh(), spec) def _make_lengths_same(sharding, ndim): - if ndim > len(sharding.spec): - return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) - if ndim < len(sharding.spec): - assert all(s is None for s in sharding.spec[ndim:]), (ndim, sharding.spec) - return sharding.with_spec(sharding.spec[:ndim]) + pspec = sharding.spec + if ndim > len(pspec): + return sharding.with_spec(pspec._normalized_spec_for_aval(ndim)) + if ndim < len(pspec): + assert all(s is None for s in pspec[ndim:]), (ndim, pspec) + return sharding.with_spec(P(*pspec[:ndim], unreduced=pspec.unreduced)) assert False, "unreachable" # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with @@ -1841,11 +1842,11 @@ def modify_spec_for_auto_manual(spec, mesh) -> P: new_spec.append(s) else: temp_s = s[0] if isinstance(s, tuple) else s - new_spec.append( - None - if mesh._name_to_type[temp_s] in (AxisType.Auto, AxisType.Manual) - else s) - return P(*new_spec) + new_spec.append(s if mesh._name_to_type[temp_s] == AxisType.Explicit + else None) + new_unreduced = tuple(u for u in spec.unreduced + if mesh._name_to_type[u] == AxisType.Explicit) + return P(*new_spec, unreduced=new_unreduced) def _maybe_modify_sharding(sharding, ndim): if len(sharding.spec) == 0 or all(s is None for s in sharding.spec): @@ -1905,8 +1906,8 @@ def str_short_aval(shape, dtype, mesh, spec, vma, dt_str = dt_str.replace('void', 'float0') shapestr = _get_shape_sharding_str(shape, spec) mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else '' - vma = f"{{{','.join(i for i in vma)}}}" if vma else '' - return f'{dt_str}[{shapestr}]{vma}{mesh_axes}' + vma_ur = _vma_ur_str(vma, spec.unreduced) + return f'{dt_str}[{shapestr}]{vma_ur}{mesh_axes}' def get_vma(vma, mesh): if mesh.empty: @@ -2000,6 +2001,18 @@ def _get_shape_sharding_str(shape, spec): out.append(f"{s1}@{s2}") return ','.join(out) +def _create_str(x, prefix): + x_str = f"{','.join(i for i in x)}" + x_str = x_str if len(x) == 1 else f"({x_str})" + return f"{prefix}:{x_str}" + +def _vma_ur_str(vma, unreduced): + if not vma and not unreduced: + return '' + vma_str = _create_str(vma, 'V') if vma else '' + ur_str = _create_str(unreduced, 'U') if unreduced else '' + sep = ', ' if vma and unreduced else '' + return f"{{{vma_str}{sep}{ur_str}}}" def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 1c6e00861448..0fbe54a30672 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -458,6 
+458,8 @@ class AxisData: def get_sharding_for_vmap(axis_data, orig_sharding, axis): val = axis_data.explicit_mesh_axis + # TODO(yashkatariya): Preserve unreduced here using + # `orig_sharding.spec.with_partitions` new_spec = P(*tuple_insert(orig_sharding.spec, axis, val)) return NamedSharding(orig_sharding.mesh, new_spec) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 6c248565174c..f6b4c1be102f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3954,7 +3954,7 @@ def broadcasting_sharding_rule(name, *avals): result_specs = [None] * len(shapes[0]) for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))): - if all(s == ss[0] for s in ss[1:]): + if all(ss[0] == s for s in ss[1:]): # if all dimension shardings are same, the resulting dimension sharding is # the same. result_specs[i] = ss[0] @@ -3974,7 +3974,20 @@ def broadcasting_sharding_rule(name, *avals): raise core.ShardingTypeError( f'{name} got incompatible shardings for broadcasting: ' f'{", ".join(map(str, map(tuple, specs)))}.') - return NamedSharding(mesh, P(*result_specs)) + + unreduced = [a.sharding.spec.unreduced for a in avals if a.shape] + # TODO(yashkatariya): Relax this restriction to allow + # `f32[8]{R:x} * f32[8]{U:x} -> f32[8]{U:x}` for example and maybe more cases. + if unreduced: + if not all(unreduced[0] == u for u in unreduced[1:]): + raise core.ShardingTypeError( + 'All arrays must be unreduced along the same mesh axes. Got' + f' {", ".join(map(str, map(tuple, unreduced)))}') + result_unreduced = unreduced[0] + else: + result_unreduced = None + + return NamedSharding(mesh, P(*result_specs, unreduced=result_unreduced)) def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=True): @@ -5190,11 +5203,26 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) + rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) + if out_sharding is not None: assert isinstance(out_sharding, NamedSharding) + if out_sharding.spec.unreduced: + if lhs_contracting_spec != rhs_contracting_spec: + raise core.ShardingTypeError( + 'lhs and rhs contracting dims should be sharded identically when' + ' out_sharding provided to dot_general mentions unreduced_axes.' + f' Got {out_sharding=}, {lhs_contracting_spec=},' + f' {rhs_contracting_spec=}') + if out_sharding.spec.unreduced != lhs_contracting_spec: + raise core.ShardingTypeError( + "out_sharding's unreduced axes should be equal to the contracting" + f' specs. 
Got unreduced axes={out_sharding.spec.unreduced} and' + f' contracting spec={lhs_contracting_spec}') return out_sharding - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch) rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch) msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " @@ -5202,8 +5230,6 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, f"{rhs_batch_spec}.") _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg) - lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) - rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) msg = ("dot_general requires contracting dimensions to have consistent " f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index d0a49b6b081c..3dfcbd29fc96 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -25,6 +25,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib +from jax._src.mesh import AxisType from jax._src.partition_spec import PartitionSpec from jax._src import sharding as JSharding import numpy as np @@ -410,6 +411,7 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): def check_pspec(mesh, spec, _manual_axes=frozenset()): _check_unique_resources(spec, "NamedSharding spec", mesh) _check_mesh_resource_axis(mesh, spec) + _check_mesh_unreduced(mesh, spec) class DuplicateSpecError(Exception): def __init__(self, message, mesh, pspec): @@ -443,14 +445,13 @@ def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None f' for {mesh_lib.show_axes(multiple_uses)}'), mesh=mesh, pspec=pspec) - def _check_mesh_resource_axis(mesh, pspec): for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: continue p = p if isinstance(p, tuple) else (p,) for r in p: - if r not in mesh.shape: + if r not in mesh.axis_names: raise ValueError( f"Resource axis: {r} of {pspec} " f"is not found in mesh: {tuple(mesh.shape.keys())}.") @@ -459,9 +460,34 @@ def _check_mesh_resource_axis(mesh, pspec): 'AxisTypes should be the same in a tuple subset of PartitionSpec:' f' {pspec}. Got subset {p} with axis' f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})') - if (mesh_lib.AxisType.Auto not in mesh._axis_types_dict and + if (AxisType.Auto not in mesh._axis_types_dict and PartitionSpec.UNCONSTRAINED in pspec): raise ValueError( f'{pspec} cannot contain' ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh' f' axis_types: {mesh._axis_types_dict}') + +def _check_mesh_unreduced(mesh, pspec): + counts = {} + duplicate = False + for u in pspec.unreduced: + if u not in mesh.axis_names: + raise ValueError( + f'Unreduced axes {u} is not found in {mesh.axis_names=}. ' + f'Got {pspec=}') + count = counts.get(u, 0) + if count > 0: + duplicate = True + counts[u] = count + 1 + if duplicate: + multiple_uses = [r for r, c in counts.items() if c > 1] + raise ValueError( + f'Unreduced axes in {pspec} has duplicate entries which is not allowed.' + f' Got {mesh_lib.show_axes(multiple_uses)}') + + for u in pspec.unreduced: + if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual): + raise ValueError( + 'Unreduced axes can only refer to mesh axes that is of type' + f' `Explicit`. 
Got unreduced axes: {pspec.unreduced} and' + f' mesh: {mesh}') diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 21333a9e7a0d..3f657082e1d4 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -25,7 +25,7 @@ from jax._src.lax import lax from jax._src.lax.lax import PrecisionLike from jax._src.numpy import util -from jax._src.sharding_impls import canonicalize_sharding, NamedSharding, PartitionSpec as P +from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import partition_list, set_module, unzip2 @@ -422,7 +422,8 @@ def _einsum( " instances. Please file a bug if this is not enough for your use case.") dtypes.check_user_dtype_supported(preferred_element_type, "einsum") if preferred_element_type is None: - preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True) + preferred_element_type, output_weak_type = dtypes.result_type( + *operands, return_weak_type_flag=True) else: output_weak_type = False @@ -557,7 +558,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): spec = out_sharding.spec inverse_spec = tuple(spec[result_names.index(name)] for name in names) dot_general_out_sharding = NamedSharding( - out_sharding.mesh, P(*inverse_spec)) + out_sharding.mesh, spec.with_partitions(inverse_spec)) else: dot_general_out_sharding = out_sharding # type: ignore dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch)) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 158e361482b7..fcea21934bfb 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any class UnconstrainedSingleton: @@ -43,6 +43,25 @@ def _canonicalize_partition(partition): return tuple(partition) return partition +def _check(partitions, unreduced): + us = set(unreduced) + for p in partitions: + p = p if isinstance(p, tuple) else (p,) + for r in p: + if r in us: + raise ValueError( + "partitions cannot overlap with unreduced axes passed to" + f" PartitionSpec. Got partitions: {partitions} and unreduced axes:" + f" {unreduced}") + if None in unreduced: + raise ValueError( + "unreduced cannot contain None. All elements in unreduced should refer" + " to the mesh axes.") + +def unpicke_pspec(partitions, unreduced): + return PartitionSpec(*partitions, unreduced=unreduced) + +AxisName = Any class PartitionSpecImpl: """Tuple describing how to partition an array across a mesh of devices. @@ -53,20 +72,34 @@ class PartitionSpecImpl: This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ - __slots__ = ("_partitions",) + __slots__ = ("_partitions", "_unreduced") __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. UNCONSTRAINED = _UNCONSTRAINED_PARTITION - def __init__(self, *partitions): + def __init__(self, *partitions, + unreduced: tuple[AxisName, ...] 
| AxisName | None = None): self._partitions = tuple(_canonicalize_partition(p) for p in partitions) + self._unreduced = ( + () if unreduced is None else tuple(unreduced) + if isinstance(unreduced, (list, tuple)) else (unreduced,)) + _check(self._partitions, self._unreduced) + + @property + def unreduced(self): + return self._unreduced def __repr__(self): - return f"PartitionSpec{self._partitions!r}" + pr = repr(self._partitions)[1:-1] + if not self._unreduced: + return f"PartitionSpec({pr})" + ur_str = f"unreduced={self._unreduced!r}" + pr = '' if not pr else f"{pr} " if pr.endswith(',') else f"{pr}, " + return (f"PartitionSpec({pr}{ur_str})") def __reduce__(self): - return (PartitionSpec, self._partitions) + return (unpicke_pspec, (self._partitions, self._unreduced)) def __getitem__(self, i): return self._partitions[i] @@ -80,20 +113,42 @@ def __len__(self): def __eq__(self, other): if not isinstance(other, (PartitionSpec, tuple)): return False - other = tuple(_canonicalize_partition(o) for o in other) - return self._partitions == other + other_p = tuple(_canonicalize_partition(o) for o in other) + if isinstance(other, PartitionSpec): + return (self._partitions == other_p and + self._unreduced == other._unreduced) + else: + if self._unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__eq__` of PartitionSpec.") + return self._partitions == other_p def __hash__(self): - return hash(self._partitions) + return hash((self._partitions, self._unreduced)) def __add__(self, other): if not isinstance(other, (tuple, PartitionSpec)): - return NotImplemented - return PartitionSpec(*self, *other) + raise NotImplementedError + if isinstance(other, PartitionSpec): + return PartitionSpec( + *self, *other, + unreduced=(*self._unreduced, *other._unreduced)) + else: + if self._unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__add__` of PartitionSpec.") + return PartitionSpec(*self, *other) def __radd__(self, other): - if not isinstance(other, (tuple, PartitionSpec)): - return NotImplemented + if not isinstance(other, tuple): + raise NotImplementedError + # other will always be a tuple. 
+ if self._unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__radd__` of PartitionSpec.") return PartitionSpec(*other, *self) def index(self, value): @@ -102,12 +157,16 @@ def index(self, value): def count(self, value): return self._partitions.count(_canonicalize_partition(value)) + def with_partitions(self, new_partitions): + return PartitionSpec(*new_partitions, unreduced=self._unreduced) + def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: out = [None if p is _UNCONSTRAINED_PARTITION else p for p in self._partitions] if len(out) < ndim: out.extend([None] * (ndim - len(out))) - return PartitionSpec(*out) + return self.with_partitions(out) + if TYPE_CHECKING: class PartitionSpec(PartitionSpecImpl, tuple): # type: ignore diff --git a/tests/array_test.py b/tests/array_test.py index 97d66566cffa..b951f7f6b4cd 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1403,6 +1403,69 @@ def test_memory_kind_with_abstract_mesh(self): ValueError, 'Got invalid memory kind'): NamedSharding(abstract_mesh, P(), memory_kind='weird_device') + def test_pspec_unreduced(self): + pspec1 = P('a', 'b', None, unreduced=('c',)) + self.assertEqual(repr(pspec1), + "PartitionSpec('a', 'b', None, unreduced=('c',))") + + pspec2 = P('a', 'b', None, unreduced=('c',)) + self.assertEqual(pspec1, pspec2) + + pspec3 = P('a', 'b', None, unreduced=('d',)) + self.assertNotEqual(pspec1, pspec3) + + out = P('x', unreduced=('z',)) + P('a', unreduced='b') + self.assertEqual(out, P('x', 'a', unreduced=('z', 'b'))) + + pspec4 = P('x', unreduced='y') + self.assertEqual(repr(pspec4), + "PartitionSpec('x', unreduced=('y',))") + + pspec5 = P(None, None, unreduced='x') + self.assertEqual(repr(pspec5), + "PartitionSpec(None, None, unreduced=('x',))") + + pspec6 = P(None, unreduced='x') + self.assertEqual(repr(pspec6), "PartitionSpec(None, unreduced=('x',))") + + pspec7 = P(unreduced='x') + self.assertEqual(repr(pspec7), "PartitionSpec(unreduced=('x',))") + + with self.assertRaisesRegex( + TypeError, 'unreduced in `__add__` of PartitionSpec'): + P('x', unreduced=('z',)) + (None,) * 2 + + with self.assertRaisesRegex( + TypeError, "unreduced in `__radd__` of PartitionSpec"): + (None,) * 2 + P('x', unreduced='y') + + with self.assertRaisesRegex( + ValueError, "partitions cannot overlap with unreduced"): + P('x', 'y', unreduced='x') + + with self.assertRaisesRegex( + ValueError, "partitions cannot overlap with unreduced"): + P('x', None, 'y', unreduced=('z', 'y')) + + def test_named_sharding_unreduced_error(self): + mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z')) + + with self.assertRaisesRegex( + ValueError, "Unreduced axes.*not found in mesh.*"): + NamedSharding(mesh, P('x', unreduced='a')) + + with self.assertRaisesRegex( + ValueError, "Unreduced.*has duplicate entries"): + NamedSharding(mesh, P('x', unreduced=('y', 'y'))) + + with self.assertRaisesRegex( + ValueError, "Unreduced axes can only refer to mesh axes.*Explicit"): + NamedSharding(mesh, P('x', unreduced=('y', 'z'))) + + with self.assertRaisesRegex( + ValueError, "unreduced cannot contain None.*"): + NamedSharding(mesh, P('x', unreduced=('y', None))) + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e8d8d46455ba..3215881a2c14 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7677,8 +7677,88 @@ def f(x): self.assertEqual(jax.typeof(out).sharding, jax.typeof(x).sharding) return 
out - f(arr) - jax.jit(f)(arr) + f(arr) # doesn't crash + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_unreduced_basic(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + a = jax.device_put(np_inp, P('x', 'y')) + b = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def f(x, y, a, b): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y')) + self.assertEqual(m1.aval.sharding.spec, P('x', None, unreduced='y')) + + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced='y')) + self.assertEqual(m2.aval.sharding.spec, P('x', None, unreduced='y')) + + s = m1 + m2 # unreduced + self.assertEqual(s.aval.sharding.spec, P('x', None, unreduced='y')) + + out = reshard(s, P('x')) # reduce + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + f.trace(x, y, a, b) # doesn't crash + + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_dot_general_unreduced_error(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + # Case 1 + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def f(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='z')) + with self.assertRaisesRegex( + core.ShardingTypeError, + "unreduced axes should be equal to the contracting specs"): + f.trace(x, y) + + # Case 2 + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P(None, None)) + @jax.jit + def g(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y')) + with self.assertRaisesRegex( + core.ShardingTypeError, + "lhs and rhs contracting dims should be sharded identically"): + g.trace(x, y) + + # Case 3 + x = jax.device_put(np_inp, P('x', None)) + y = jax.device_put(np_inp.T, P(None, None)) + + @jax.jit + def h(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y')) + with self.assertRaisesRegex( + core.ShardingTypeError, + "unreduced axes should be equal to the contracting specs"): + h.trace(x, y) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_add_unreduced_error(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def f(x, y): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y')) + m2 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x')) + return m1 + m2 + + with self.assertRaisesRegex( + core.ShardingTypeError, + "arrays must be unreduced along the same mesh axes"): + f.trace(x, y) @jtu.pytest_mark_if_available('multiaccelerator') From 5634ca126e776edd1bac2880b8162bd9b2cf791d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 7 May 2025 17:25:21 -0700 Subject: [PATCH 1062/1769] Rename `out_shardings -> out_sharding` in the `auto_axes` API and `in_shardings -> in_sharding` in `explicit_axes` API. 
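For callers, only the keyword spelling changes. A minimal sketch, mirroring the tests updated below (with `partial`, `P`, `auto_axes`, and `explicit_axes` in scope as they are there):

  @partial(auto_axes, axes='x', out_sharding=P('x', None))   # was: out_shardings=
  def h(y):
    return y * 2

  @partial(explicit_axes, axes='y', in_sharding=P('x', 'y'))  # was: in_shardings=
  def g(y):
    return y + 1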
PiperOrigin-RevId: 756069295 --- docs/notebooks/explicit-sharding.ipynb | 10 ++++----- docs/notebooks/explicit-sharding.md | 10 ++++----- jax/_src/numpy/indexing.py | 4 ++-- jax/_src/ops/scatter.py | 2 +- jax/_src/pjit.py | 30 +++++++++++++------------- jax/_src/random.py | 6 +++--- tests/pjit_test.py | 30 +++++++++++++------------- 7 files changed, 46 insertions(+), 46 deletions(-) diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index f37d0dcdf887..e1bee4b99fb5 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -397,7 +397,7 @@ " which the split/merged axes are sharded as None then we shard the\n", " resulting split/merged axes as None and the other axes according to their\n", " corresponding input axis shardings. In all other cases we throw an error\n", - " and require the user to provide an `out_shardings` argument." + " and require the user to provide an `out_sharding` argument." ] }, { @@ -494,7 +494,7 @@ " print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n", " return x + y\n", "\n", - "result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P(\"X\", None))\n", + "result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P(\"X\", None))\n", "print(f\"Result type: {jax.typeof(result)}\")" ] }, @@ -637,7 +637,7 @@ " x = jnp.sin(arr1)\n", " print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n", "\n", - " z = g(x, out_shardings=P(\"X\", \"Y\"))\n", + " z = g(x, out_sharding=P(\"X\", \"Y\"))\n", "\n", " print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n", " return z + 1\n", @@ -681,7 +681,7 @@ " print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n", " x = jnp.sin(arr1)\n", "\n", - " z = explicit_g(x, in_shardings=P(\"X\", \"Y\"))\n", + " z = explicit_g(x, in_sharding=P(\"X\", \"Y\"))\n", "\n", " return z + 1\n", "\n", @@ -778,7 +778,7 @@ " compare_shardings(x)\n", " return x\n", "\n", - "check_in_auto_context(my_array, out_shardings=P(\"X\"))" + "check_in_auto_context(my_array, out_sharding=P(\"X\"))" ] }, { diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index b374b7d7a668..1402bca2415f 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -239,7 +239,7 @@ Here are some example sharding rules: which the split/merged axes are sharded as None then we shard the resulting split/merged axes as None and the other axes according to their corresponding input axis shardings. In all other cases we throw an error - and require the user to provide an `out_shardings` argument. + and require the user to provide an `out_sharding` argument. +++ {"id": "jZMp6w48Xmd7"} @@ -308,7 +308,7 @@ def add_with_out_sharding_kwarg(x, y): print(f"We're in auto-sharding mode here. 
This is the current mesh: {get_abstract_mesh()}") return x + y -result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P("X", None)) +result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P("X", None)) print(f"Result type: {jax.typeof(result)}") ``` @@ -390,7 +390,7 @@ def f(arr1): x = jnp.sin(arr1) print(f'x.sharding: {jax.typeof(x)}', end='\n\n') - z = g(x, out_shardings=P("X", "Y")) + z = g(x, out_sharding=P("X", "Y")) print(f'z.sharding: {jax.typeof(z)}', end="\n\n") return z + 1 @@ -423,7 +423,7 @@ def f(arr1): print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n') x = jnp.sin(arr1) - z = explicit_g(x, in_shardings=P("X", "Y")) + z = explicit_g(x, in_sharding=P("X", "Y")) return z + 1 @@ -469,7 +469,7 @@ def check_in_auto_context(x): compare_shardings(x) return x -check_in_auto_context(my_array, out_shardings=P("X")) +check_in_auto_context(my_array, out_sharding=P("X")) ``` +++ {"id": "MRFccsi5X8so"} diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 044b5175a46a..17fbccd7ac9d 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -608,7 +608,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None, internal_ds = partial(lax.dynamic_slice, slice_sizes=slice_sizes, allow_negative_indices=allow_negative_indices) if out_sharding is not None: - arr = auto_axes(internal_ds, out_shardings=out_sharding)(arr, start_indices) + arr = auto_axes(internal_ds, out_sharding=out_sharding)(arr, start_indices) else: arr = internal_ds(arr, start_indices) if int_indices: @@ -646,7 +646,7 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, fill_value=fill_value) if out_sharding is not None: - return auto_axes(internal_gather, out_shardings=out_sharding + return auto_axes(internal_gather, out_sharding=out_sharding )(arr, dynamic_idx) return internal_gather(arr, dynamic_idx) diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index fcb3759c5cae..4db79557c3cc 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -87,7 +87,7 @@ def _scatter_update(x: ArrayLike, idx: Index, y: ArrayLike, scatter_op: Callable unique_indices=unique_indices, mode=mode, normalize_indices=normalize_indices) if out_sharding is not None: - return auto_axes(internal_scatter, out_shardings=out_sharding + return auto_axes(internal_scatter, out_sharding=out_sharding )(x, y, dynamic_idx) return internal_scatter(x, y, dynamic_idx) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 268f533b971b..f141d6c6237b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -3067,24 +3067,24 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None, return mesh_to_use.update_axis_types({a: axis_type for a in axes}) def auto_axes(fun, *, axes: str | tuple[str, ...] 
| None = None, - out_shardings=None): + out_sharding=None): def decorator(*args, **kwargs): - if out_shardings is None: - if "out_shardings" in kwargs: - _out_shardings = kwargs.pop("out_shardings") + if out_sharding is None: + if "out_sharding" in kwargs: + _out_sharding = kwargs.pop("out_sharding") else: - raise TypeError("Missing required keyword argument: 'out_shardings'") + raise TypeError("Missing required keyword argument: 'out_sharding'") else: - _out_shardings = out_shardings + _out_sharding = out_sharding new_mesh = _get_new_mesh( - axes, mesh_lib.AxisType.Auto, 'auto_axes', shardings=_out_shardings, + axes, mesh_lib.AxisType.Auto, 'auto_axes', shardings=_out_sharding, error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( core.get_aval(a).sharding.spec, new_mesh), args) args = mesh_cast(args, in_specs) out = fun(*args, **kwargs) - return mesh_cast(out, _out_shardings) + return mesh_cast(out, _out_sharding) return decorator @contextlib.contextmanager @@ -3095,19 +3095,19 @@ def use_auto_axes(*axes): def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None, - in_shardings=None): + in_sharding=None): def decorator(*args, **kwargs): - if in_shardings is None: - if "in_shardings" in kwargs: - _in_shardings = kwargs.pop("in_shardings") + if in_sharding is None: + if "in_sharding" in kwargs: + _in_sharding = kwargs.pop("in_sharding") else: - raise TypeError("Missing required keyword argument: 'in_shardings'") + raise TypeError("Missing required keyword argument: 'in_sharding'") else: - _in_shardings = in_shardings + _in_sharding = in_sharding new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes', error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): - args = mesh_cast(args, _in_shardings) + args = mesh_cast(args, _in_sharding) out = fun(*args, **kwargs) out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual( core.get_aval(o).sharding.spec, mesh_lib.get_abstract_mesh()), out) diff --git a/jax/_src/random.py b/jax/_src/random.py index ef02719de146..6f139dd9665c 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -349,12 +349,12 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) -def maybe_auto_axes(f, out_shardings, **hoist_kwargs): +def maybe_auto_axes(f, out_sharding, **hoist_kwargs): f_ = partial(f, **hoist_kwargs) - if out_shardings is None: + if out_sharding is None: return f_ else: - return auto_axes(f_, out_shardings=out_shardings) + return auto_axes(f_, out_sharding=out_sharding) def bits(key: ArrayLike, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3215881a2c14..effdc4a4ddbc 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5137,7 +5137,7 @@ def f(x, y): ValueError, 'PartitionSpec passed to einsum cannot contain axis names that are of' ' type Auto or Manual'): - auto_axes(f, out_shardings=P())(arr1, arr2) + auto_axes(f, out_sharding=P())(arr1, arr2) out = jax.grad(f, argnums=(0, 1))(arr1, arr2) self.assertEqual(out[0].sharding, arr1.sharding) @@ -6371,7 +6371,7 @@ def test_auto_mode_mix(self, mesh): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(auto_axes, axes='x', out_shardings=P('x', None)) + @partial(auto_axes, axes='x', out_sharding=P('x', None)) def h(y): self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) @@ -6433,7 +6433,7 @@ def test_full_user_mode(self, mesh): arr = 
jax.device_put(np_inp, s) # No axes specified means full visible mode. - @partial(explicit_axes, in_shardings=P('x', 'y')) + @partial(explicit_axes, in_sharding=P('x', 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) @@ -6563,7 +6563,7 @@ def test_mix_to_full_user_mode(self, mesh): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(explicit_axes, axes='y', in_shardings=P('x', 'y')) + @partial(explicit_axes, axes='y', in_sharding=P('x', 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) @@ -6589,7 +6589,7 @@ def test_full_auto_to_partial_user(self, mesh): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(explicit_axes, axes='y', in_shardings=P(None, 'y')) + @partial(explicit_axes, axes='y', in_sharding=P(None, 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) @@ -6703,7 +6703,7 @@ def test_auto_axes_top_level(self): arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) - @partial(auto_axes, out_shardings=P('x', None)) + @partial(auto_axes, out_sharding=P('x', None)) def auto_matmul(arr1, arr2): return arr1 @ arr2 @@ -6725,7 +6725,7 @@ def test_explicit_axes_top_level(self): arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) - @partial(explicit_axes, in_shardings=(P('x', None), P('x', None))) + @partial(explicit_axes, in_sharding=(P('x', None), P('x', None))) def jax_matmul(arr1, arr2): out = arr1 @ arr2 self.assertEqual(out.aval.sharding.spec, P('x', None)) @@ -6774,7 +6774,7 @@ def f(x): self.assertEqual(a.aval.sharding.spec, P(None, None)) return a - hf = auto_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y')) + hf = auto_axes(f, axes=('x', 'y'), out_sharding=P('x', 'y')) out = hf(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) @@ -6793,7 +6793,7 @@ def f(x): self.assertEqual(z.aval.sharding.spec, P('x', 'y')) return z - hf = explicit_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y')) + hf = explicit_axes(f, axes=('x', 'y'), in_sharding=P('x', 'y')) out = hf(arr) # doesn't crash self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) @@ -7171,7 +7171,7 @@ def test_wsc_pspec_use_mesh(self, sharded_inp): def test_axes_api_error_manual_to_auto_explicit(self, mesh): def g(x): return auto_axes(lambda a: a * 2, axes=('x', 'y'), - out_shardings=P('x', 'y'))(x) + out_sharding=P('x', 'y'))(x) with self.assertRaisesRegex( NotImplementedError, "Going from `Manual`.*to.*`Auto`.*`Explicit`"): @@ -7185,7 +7185,7 @@ def f(x): self.assertTrue(x.aval.sharding.mesh._are_all_axes_auto) return x * 2 - out = auto_axes(f, out_shardings=P('x'))(np.arange(8)) + out = auto_axes(f, out_sharding=P('x'))(np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) @@ -7268,7 +7268,7 @@ def test_auto_axes_computation_follows_data(self): def f(x): return x * 2 - out = auto_axes(f, out_shardings=s)(arr) + out = auto_axes(f, out_sharding=s)(arr) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, arr * 2) @@ -7322,7 +7322,7 @@ def test_auto_axes_late_bind(self, mesh): def f(x): return x * 2 - out = f(np.arange(8), out_shardings=P('x')) + out = f(np.arange(8), out_sharding=P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) @@ -7332,7 +7332,7 @@ def 
test_explicit_axes_late_bind(self, mesh): def f(x): return x * 2 - out = f(np.arange(8), in_shardings=P('x')) + out = f(np.arange(8), in_sharding=P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) @@ -7471,7 +7471,7 @@ def test_auto_axes_no_context_mesh(self): arr = jax.device_put(np_inp, s) @partial(auto_axes, axes='x', - out_shardings=NamedSharding(mesh, P('x', 'y'))) + out_sharding=NamedSharding(mesh, P('x', 'y'))) def h(y): self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) From eada2988c2f2a24d96de462f58205cb6815d13c4 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Wed, 7 May 2025 17:54:47 -0700 Subject: [PATCH 1063/1769] [Mosaic] Add explicit control over core parallelization strategy This CL introduces the dimension semantic `core_parallel`. It allows the user to control which dimension is parallelized across cores as opposed to leaving it to Mosaic to choose the best out of `parallel` dimensions. PiperOrigin-RevId: 756076809 --- jaxlib/mosaic/dialect/tpu/tpu.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 7f295b4ec09b..fb72b6948d9d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -82,7 +82,8 @@ def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]> def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [ I32EnumAttrCase<"parallel", 0>, - I32EnumAttrCase<"arbitrary", 1> + I32EnumAttrCase<"arbitrary", 1>, + I32EnumAttrCase<"core_parallel", 2> ]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::tpu"; From 4ab810b4835908d9e06a4fe55b6d37b986ed9cf1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 7 May 2025 18:54:07 -0700 Subject: [PATCH 1064/1769] Fix `with_sharding_constraint` with a scalar input PiperOrigin-RevId: 756092441 --- jax/_src/pjit.py | 14 ++++++-------- tests/pjit_test.py | 7 +++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f141d6c6237b..f87207e0a796 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1578,9 +1578,8 @@ def check_aval_layout_compatibility( if l is None or isinstance(l, AutoLayout): continue name_str = f' with pytree key path {name}' if name else '' - shape = aval.shape try: - l.check_compatible_aval(shape) + l.check_compatible_aval(aval.shape) except ValueError as e: raise ValueError( f'One of {what_aval}{name_str} is incompatible with its layout ' @@ -2717,7 +2716,7 @@ def with_sharding_constraint(x, shardings): .. 
_Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) - + x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] layouts, shardings = _split_layout_and_sharding(shardings) user_shardings = prepare_axis_resources( @@ -2753,13 +2752,11 @@ def with_sharding_constraint(x, shardings): for s in shardings_flat] pjit_check_aval_sharding( - shardings_flat, x_flat, ("",) * len(shardings_flat), + shardings_flat, x_avals_flat, ("",) * len(shardings_flat), "with_sharding_constraint arguments", allow_uneven_sharding=True) - check_shardings_are_auto(shardings_flat) - - check_aval_layout_compatibility(user_layouts_flat, x_flat, + check_aval_layout_compatibility(user_layouts_flat, x_avals_flat, ("",) * len(user_layouts_flat), "with_sharding_constraint arguments") @@ -3125,6 +3122,7 @@ def use_explicit_axes(*axes): def with_layout_constraint(x, layouts): x_flat, tree = tree_flatten(x) + x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] layouts_flat = tuple(flatten_axes("with_layout_constraint layouts", tree, layouts)) if any(not isinstance(l, DeviceLocalLayout) for l in layouts_flat): @@ -3132,7 +3130,7 @@ def with_layout_constraint(x, layouts): 'layouts passed to `with_layout_constraint` must be of type' f' `DeviceLocalLayout`. Got {[type(l) for l in layouts_flat]}') check_aval_layout_compatibility( - layouts_flat, x_flat, ("",) * len(layouts_flat), + layouts_flat, x_avals_flat, ("",) * len(layouts_flat), "with_layout_constraint arguments") outs = [layout_constraint_p.bind(xf, layout=l) for xf, l in zip(x_flat, layouts_flat)] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index effdc4a4ddbc..373e3ceff7a8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4272,6 +4272,13 @@ def make_keys(seeds): else: self.assertIn('unspecified_dims=[0,1,2]', lowered_text) + def test_wsc_with_scalar(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P()) + out = jax.lax.with_sharding_constraint(1., s) + self.assertArraysEqual(out, 1.) 
+ self.assertEqual(out.sharding, s) + def test_jit_partially_specified_shardings(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) From 8137c37e324c9cb5c8f991a16d78310b6e37bd05 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 7 May 2025 21:41:54 -0700 Subject: [PATCH 1065/1769] Reverts 6d1b5271a115007162e9f98561d6b118aa66382c PiperOrigin-RevId: 756139245 --- jax/_src/array.py | 3 +-- jaxlib/py_array.cc | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 422fa5086e62..f2b070c8221d 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -636,8 +636,7 @@ def _value(self) -> np.ndarray: self._check_if_deleted() if self._npy_value is None: - if (self.is_fully_replicated and - self.sharding._internal_device_list.addressable_device_list): # type: ignore + if self.is_fully_replicated: npy_value, did_copy = self._single_device_array_to_np_array_did_copy() npy_value.flags.writeable = False if did_copy: diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 1222d410bad8..022c7a831c92 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -1528,9 +1528,6 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { absl::Span> buffers = array->pjrt_buffers(); - if (buffers.empty()) { - return InvalidArgument("Array has no buffers."); - } PjRtBuffer& buffer = *buffers.front(); if (!buffer.IsOnCpu()) { return InvalidArgument( From 93eb725e7a9c3ac1c990945967917a2e736efe6e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 8 May 2025 05:52:54 -0400 Subject: [PATCH 1066/1769] Fix docs build by constraining snowballstemmer version. --- docs/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 5d49222bbb42..1fd706ab01a5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,7 @@ absl-py ipython>=8.8.0 # 8.7.0 has ipython3 lexer error pydata-sphinx-theme==0.14.4 # v0.15 breaks sidebar toggling +snowballstemmer<3.0.0 # v3.0.0 incompatible with older sphinx; missing stemmer sphinx>=7.3.2,<8.0 # 7.3.0 breaks sphinx-book-theme; 8.0 breaks myst-nb 1.1 sphinx-book-theme==1.1.1 # v1.1.2 requires pydata-sphinx-theme v0.15 sphinx-copybutton>=0.5.0 From dbfc93fd53f30f1bda1ca70dd7452212d15e8cc1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 8 May 2025 06:17:20 -0700 Subject: [PATCH 1067/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f08f9dc30cb4c0c638c244ba59a09722dcbedad5. PiperOrigin-RevId: 756282772 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ed77b40bb4a7..b98430d24198 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4b45c0f0dc2fb80be6036c36a3e40f3cd1b478c9" -XLA_SHA256 = "64292fbcceebe0ee03a97a0a0edf667067642de028a2e7d7d1175641e91f8925" +XLA_COMMIT = "f08f9dc30cb4c0c638c244ba59a09722dcbedad5" +XLA_SHA256 = "8259f57fd5a475557f97f5365612646167d8c7c2d849202923a6fda88a218de2" def repo(): tf_http_archive( From 7e8c74985a78ff1dec528aa0f1936129dd7a31cf Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 8 May 2025 09:32:54 -0400 Subject: [PATCH 1068/1769] Add a pretty printing rule for custom_jvp. 
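With this rule, a `custom_jvp_call` equation prints the callee and JVP function names inline instead of dumping the raw `jvp_jaxpr_fun` parameter. For example, this is the expected output in the test added below:

  { lambda ; a:f32[1]. let
      b:f32[1] = custom_jvp_call[
        name=f
        call_jaxpr={ lambda ; c:f32[1]. let d:f32[1] = add c 1.0:f32[] in (d,) }
        jvp=f_jvp
        symbolic_zeros=False
      ] a
    in (b,) }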
--- jax/_src/custom_derivatives.py | 15 +++++++++++++++ tests/api_test.py | 27 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index a8f136477bd9..dc8fc90e3d1f 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -487,6 +487,21 @@ def dce_jvp_jaxpr_thunk(*in_zeros): pe.dce_rules[custom_jvp_call_p] = _custom_jvp_call_dce +def _custom_jvp_call_pp_rule(eqn: core.JaxprEqn, + context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + if not params["num_consts"]: + params.pop("num_consts") + params["jvp"] = params.pop("jvp_jaxpr_fun").debug_info.func_name + names = sorted(params) + params["name"] = params["call_jaxpr"].jaxpr.debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings, + params=["name"] + names) + + +core.pp_eqn_rules[custom_jvp_call_p] = _custom_jvp_call_pp_rule + ### VJPs @custom_api_util.register_custom_decorator_type diff --git a/tests/api_test.py b/tests/api_test.py index 610719518a03..86c0c8ceeee3 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -34,6 +34,7 @@ import re import subprocess import sys +import textwrap import traceback import types from typing import NamedTuple @@ -8418,6 +8419,32 @@ def f_jvp(x, t): x = jnp.arange(3.0) jax.jvp(jax.vmap(jax.jit(f)), (x,), (x,)) # doesn't crash + def test_pretty_print(self): + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(primals, tangents): + return f(*primals), tangents[0] + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(f)(x) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_jvp_call[ + name=f + call_jaxpr={ lambda ; c:f32[1]. let d:f32[1] = add c 1.0:f32[] in (d,) } + jvp=f_jvp + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + + class CustomVJPTest(jtu.JaxTestCase): From 9f64ddd380381d511889d91e620701f82a15d24c Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Thu, 8 May 2025 07:01:27 -0700 Subject: [PATCH 1069/1769] [xla::PyClient] Update PyClient to use xla::ifrt::CompileAndLoad. 
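From Python, the new name is a drop-in replacement for the `compile` overloads, which are kept. A rough usage sketch, where `backend`, `mlir_module`, `devices`, and `options` stand in for caller-provided values:

  executable = backend.compile_and_load(mlir_module, devices, options)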
PiperOrigin-RevId: 756294938 --- jaxlib/py_client.cc | 140 ++++++++++++++++++++++++++++++++++++------- jaxlib/py_client.h | 14 ++--- jaxlib/xla_client.py | 2 +- 3 files changed, 127 insertions(+), 29 deletions(-) diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index ecd412ddbb99..0a99d94f81cc 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -411,7 +411,7 @@ MakeIfrtDeserializeExecutableOptions(std::optional options, } // namespace /* static */ absl::StatusOr> -PyClient::CompileIfrtProgram( +PyClient::CompileAndLoadIfrtProgram( nb_class_ptr client, std::unique_ptr ifrt_program, std::unique_ptr ifrt_options) { auto* pjrt_compatible_client = @@ -448,9 +448,10 @@ PyClient::CompileIfrtProgram( std::optional fingerprint; { nb::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(ifrt_loaded_executable, - client->ifrt_client_->GetDefaultCompiler()->Compile( - std::move(ifrt_program), std::move(ifrt_options))); + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->CompileAndLoad( + std::move(ifrt_program), std::move(ifrt_options))); TF_RETURN_IF_ERROR(ifrt_loaded_executable->GetReadyFuture().Await()); TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); } @@ -460,10 +461,11 @@ PyClient::CompileIfrtProgram( std::move(traceback), std::move(fingerprint)); } -/* static */ absl::StatusOr> PyClient::Compile( - nb_class_ptr client, std::string mlir_module, - ifrt::DeviceListRef executable_devices, CompileOptions options, - std::vector host_callbacks) { +/* static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, + CompileOptions options, + std::vector host_callbacks) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); @@ -472,16 +474,17 @@ PyClient::CompileIfrtProgram( // export it before going to HLO while preserving Shardy ops and attrs. 
TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); } - return CompileIfrtProgram( + return CompileAndLoadIfrtProgram( client, std::make_unique(module.get()), MakeIfrtCompileOptions(std::move(options), std::move(executable_devices), std::move(host_callbacks))); } -/* static */ absl::StatusOr> PyClient::Compile( - nb_class_ptr client, std::string mlir_module, - ifrt::DeviceListRef executable_devices, CompileOptions options, - std::vector host_callbacks) { +/* static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, + CompileOptions options, + std::vector host_callbacks) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); @@ -509,7 +512,7 @@ PyClient::CompileIfrtProgram( auto compile_options = std::make_unique( std::move(options), std::move(ifrt_loaded_host_callbacks)); #endif - return CompileIfrtProgram( + return CompileAndLoadIfrtProgram( client, std::make_unique(module.get()), std::move(compile_options)); } @@ -761,7 +764,7 @@ PyType_Slot PyClient::slots_[] = { std::vector host_callbacks) { ifrt::DeviceListRef executable_devices = ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::Compile( + return ValueOrThrow(PyClient::CompileAndLoad( std::move(client), std::string(mlir_module.c_str(), mlir_module.size()), std::move(executable_devices), std::move(options), @@ -777,7 +780,7 @@ PyType_Slot PyClient::slots_[] = { std::vector host_callbacks) { ifrt::DeviceListRef executable_devices = ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::Compile( + return ValueOrThrow(PyClient::CompileAndLoad( std::move(client), std::string(mlir_module.c_str(), mlir_module.size()), std::move(executable_devices), std::move(options), @@ -793,7 +796,7 @@ PyType_Slot PyClient::slots_[] = { std::vector host_callbacks) { ifrt::DeviceListRef executable_devices = ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::Compile( + return ValueOrThrow(PyClient::CompileAndLoad( std::move(client), std::move(mlir_module), std::move(executable_devices), std::move(options), std::move(host_callbacks))); @@ -808,7 +811,7 @@ PyType_Slot PyClient::slots_[] = { std::vector host_callbacks) { ifrt::DeviceListRef executable_devices = ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::Compile( + return ValueOrThrow(PyClient::CompileAndLoad( std::move(client), std::move(mlir_module), std::move(executable_devices), std::move(options), std::move(host_callbacks))); @@ -825,7 +828,7 @@ PyType_Slot PyClient::slots_[] = { ifrt::DeviceListRef executable_devices = ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) .ifrt_device_list()); - return ValueOrThrow(PyClient::Compile( + return ValueOrThrow(PyClient::CompileAndLoad( std::move(client), std::string(mlir_module.c_str(), mlir_module.size()), std::move(executable_devices), std::move(options), @@ -840,7 +843,100 @@ PyType_Slot PyClient::slots_[] = { ifrt::DeviceListRef executable_devices = ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) .ifrt_device_list()); - return ValueOrThrow(PyClient::Compile( + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = 
CompileOptions()) + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + // The following two overloads are for users of deprecated APIs who call + // `backend.compile` but do not have visibility to `DeviceList`. 
+ .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + nb::sequence& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + nb::sequence& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( std::move(client), std::move(mlir_module), std::move(executable_devices), std::move(options), std::vector())); @@ -848,7 +944,9 @@ PyType_Slot PyClient::slots_[] = { nb::arg("computation"), nb::arg("executable_devices"), nb::arg("compile_options") = CompileOptions()) .def("compile_ifrt_program", - xla::ValueOrThrowWrapper(PyClient::CompileIfrtProgram)) + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) + .def("compile_and_load_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) .def("serialize_executable", xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) .def( diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h index 50529fac5c7e..7f70fa4f111b 100644 --- a/jaxlib/py_client.h +++ b/jaxlib/py_client.h @@ -162,17 +162,17 @@ class PyClient { ifrt::Device* device, bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics); - static absl::StatusOr> CompileIfrtProgram( - nb_class_ptr client, - std::unique_ptr ifrt_program, - std::unique_ptr ifrt_options); + static absl::StatusOr> + CompileAndLoadIfrtProgram(nb_class_ptr client, + std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options); - static absl::StatusOr> Compile( + static absl::StatusOr> CompileAndLoad( nb_class_ptr client, std::string mlir_module, ifrt::DeviceListRef executable_devices, CompileOptions options, std::vector host_callbacks); - static absl::StatusOr> Compile( + static absl::StatusOr> CompileAndLoad( nb_class_ptr client, std::string mlir_module, ifrt::DeviceListRef executable_devices, CompileOptions options, std::vector host_callbacks); @@ -193,7 +193,7 @@ class PyClient { // program through `send_channel_ids` and the results correspond to Recv ops // through `recv_channel_ids`. It returns the host callback as an opaque // object whose reference will keep the Python callback alive. The host - // callback can be passed to `PyClient::Compile` or + // callback can be passed to `PyClient::CompileAndLoad` or // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the // XLA computation can trigger the execution of this host callback. // `serializer` is a function that takes `callable` as an argument and returns diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index a77a8226d944..725d05a2dace 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. 
-_version = 337
+_version = 338

 # An internal increasing version number for protecting jaxlib code against
 # ifrt changes.

From 5c77f249b575481d234ae4c9a8afa733d2457ce3 Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Thu, 8 May 2025 07:09:49 -0700
Subject: [PATCH 1070/1769] Record event start time in
 `dispatch.LogElapsedTimeContextManager`.

We want to collect timings for various stages of the JAX compilation process, but `log_compiles` logs everything, and it would be nice to filter for top-level functions (or some top-n level).

We pass `fun_name` to the event callbacks, which is necessary for labeling, but it is a) not necessarily unique and b) something we would need to know ahead of time in order to filter.

Even with a single end-time event, we could solve this by recording metrics for the entire run and doing some post-processing at the end. A separate start event just makes this easier (in a non-free-threading world, anyway), since we can process events as we go and throw out events nested beyond the level(s) we care about.

PiperOrigin-RevId: 756297365
---
 jax/_src/dispatch.py     | 6 +++++-
 tests/monitoring_test.py | 4 ++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index ebab2120c4d0..409b22c849e3 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -46,7 +46,7 @@ from jax._src.layout import DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.mesh import AbstractMesh, Mesh
-from jax._src.monitoring import record_event_duration_secs, record_event_time_span
+from jax._src.monitoring import record_scalar, record_event_duration_secs, record_event_time_span
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
@@ -179,6 +179,10 @@ def __init__(self, fmt: str, fun_name: str, event: str | None = None):

   def __enter__(self):
     self.start_time = time.time()
" + "Pass inplace=False to instead return an updated array.") arr, ind_arr, _ = util.ensure_arraylike("put", a, ind, v) ind_arr = ind_arr.ravel() v_arr = lax_numpy.ravel(v) if not arr.size or not ind_arr.size or not v_arr.size: return arr v_arr = lax_numpy._tile_to_size(v_arr, len(ind_arr)) - if inplace: - raise ValueError( - "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. " - "Pass inplace=False to instead return an updated array.") if mode is None: scatter_mode = "drop" elif mode == "clip": From 1e1f1e0d7aaff0c48e1dce0eff56837848848fa9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 8 May 2025 08:17:03 -0700 Subject: [PATCH 1072/1769] jnp.linalg.matrix_power: support non-float inputs --- jax/_src/numpy/linalg.py | 3 +-- tests/linalg_test.py | 7 +++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 146bbbda0213..0e20e5b2a416 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -367,8 +367,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: Array([[ 5.5 , -2.5 ], [-3.75, 1.75]], dtype=float32) """ - a = ensure_arraylike("jnp.linalg.matrix_power", a) - arr, = promote_dtypes_inexact(a) + arr = ensure_arraylike("jnp.linalg.matrix_power", a) if arr.ndim < 2: raise TypeError("{}-dimensional array given. Array must be at least " diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 74259c300cf7..033ca989c8e7 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1230,6 +1230,13 @@ def testMatrixPower(self, shape, dtype, n): self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker, rtol=1e-3) + def testMatrixPowerBool(self): + # Regression test for https://github.com/jax-ml/jax/issues/28603 + mat = np.array([[True,True], [False,True]]) + np_result = np.linalg.matrix_power(mat, 2) + jnp_result = jnp.linalg.matrix_power(mat, 2) + self.assertArraysEqual(np_result, jnp_result) + @jtu.sample_product( shape=[(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50), (3, 4, 5), (2, 3, 4, 5)], From 66fdee2db2a4896da2bcdc4740f14c1a6230c3bf Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 8 May 2025 08:37:13 -0700 Subject: [PATCH 1073/1769] Upgrade Mac CI builds to run on the Sequoia pool and Apple Clang 17 PiperOrigin-RevId: 756325464 --- .bazelrc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.bazelrc b/.bazelrc index 9ec02b94f03b..4780b0ba37c9 100644 --- a/.bazelrc +++ b/.bazelrc @@ -259,10 +259,6 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 -# Clang 19 requires `-Wno-error=c23-extensions` but this flag is not supported -# on Apple Clang in XCode 16.0 so we suppress unknown warning option errors -# on Mac CI builds. -build:ci_darwin_arm64 --copt=-Wno-unknown-warning-option build:ci_darwin_arm64 --config=macos_cache_push build:ci_darwin_arm64 --verbose_failures=true build:ci_darwin_arm64 --color=yes From f7e11660e9a66e2764781997dc1086863d4af672 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 8 May 2025 08:57:46 -0700 Subject: [PATCH 1074/1769] [Pallas][Mosaic GPU] Expand TMEM support. Allow: - Multiple TMEM allocs - Expand support to 16-bit dtypes. - Allow storing to TMEM from SMEM. 
PiperOrigin-RevId: 756332369 --- jax/_src/pallas/mosaic_gpu/core.py | 35 +++++++++-- jax/_src/pallas/mosaic_gpu/lowering.py | 78 +++++++++++++++++------- jax/_src/pallas/mosaic_gpu/primitives.py | 31 +++++++++- jax/experimental/mosaic/gpu/__init__.py | 4 ++ jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 63 ++++++++++++++----- 6 files changed, 168 insertions(+), 44 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 808759edf35c..f193ce7d2743 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -124,10 +124,13 @@ def __call__( self, shape: tuple[int, ...], dtype: jnp.dtype, + *, transforms: Sequence[MemoryRefTransform] = (), + packed: bool | None = None ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. - return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) + return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms, + packed=packed) class SemaphoreType(enum.Enum): @@ -219,13 +222,27 @@ def _is_known_divisible(value, divisor, fuel=10) -> bool: class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () + # Whether to allow TMEM packing for sub 4-byte dtypes. + packed: bool | None = dataclasses.field(default=None, kw_only=True) + + def __post_init__(self): + if self.packed is not None and self.memory_space != GPUMemorySpace.TMEM: + raise ValueError("Packed option is only supported for TMEM.") + def get_ref_aval(self) -> _Ref: aval = jax_core.ShapedArray(self.shape, self.dtype) for t in self.transforms: aval = t(aval) - ref = pallas_core.TransformedRef( - AbstractMemoryRef(aval, memory_space=self.memory_space), () - ) + if self.memory_space == GPUMemorySpace.TMEM: + ref = pallas_core.TransformedRef( + AbstractTMEMRef(aval, + memory_space=self.memory_space, + packed=self.packed), () + ) + else: + ref = pallas_core.TransformedRef( + AbstractMemoryRef(aval, memory_space=self.memory_space), () + ) for t in reversed(self.transforms): ref = t.undo(ref) if not ref.transforms: @@ -918,6 +935,16 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: memory_space=ref.memory_space, # pytype: disable=attribute-error ) +class AbstractTMEMRef(AbstractMemoryRef): + __slots__ = ["inner_aval", "memory_space", "packed"] + + def __init__(self, inner_aval, memory_space, packed): + super().__init__(inner_aval, memory_space) + self.packed = packed + + def __repr__(self) -> str: + return f'TMEM({self.inner_aval.str_short()},packed={self.packed})' + _WARPGROUP_AXIS_NAME = object() diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a6bdf76206d3..29ec0d16d2fb 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -241,15 +241,20 @@ def _run_scoped_resource_estimator( ) ) elif aval.memory_space == gpu_core.TMEM: - if aval.dtype.itemsize != 4: - raise ValueError("TMEM only supports 32-bit types.") if len(aval.shape) != 2: - raise ValueError("TMEM allocations must be 2D.") + raise ValueError(f"TMEM allocations must be 2D. Got {aval.shape}") if aval.shape[0] % tcgen05.TMEM_ROWS != 0: - raise ValueError("TMEM shape[0] must be a multiple of 128.") - if aval.shape[1] % 8 != 0: - raise ValueError("TMEM shape[1] must be a multiple of 8.") - rs += Resources(tmem_scratch_cols=aval.shape[1]) + raise ValueError( + f"TMEM shape[0] must be a multiple of 128. 
Got {aval.shape[0]}.") + if aval.packed: + packing = 4 // aval.dtype.itemsize + else: + packing = 1 + layout = tcgen05._infer_tmem_layout( + aval.shape, collective=False, packing=packing) + cols_used = layout.cols_in_shape(aval.shape) + cols_used = tcgen05._alloc_ncols(cols_used, exact=False) + rs += Resources(tmem_scratch_cols=cols_used) elif aval.memory_space == gpu_core.SMEM: rs += Resources( smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize @@ -346,20 +351,30 @@ def reserve_barrier( def alloc_tmem( self, struct: jax.ShapeDtypeStruct, - layout: tcgen05.TMEMLayout | None = None + *, + layout: tcgen05.TMEMLayout | None = None, + collective: bool = False, + packed: bool = False, + exact_cols: bool = False ) -> ir.Value: - if self.tmem_used_cols > 0: - raise NotImplementedError( - "Multiple TMEM allocations are not implemented.") + if packed: + packing = 4 // struct.dtype.itemsize + else: + packing = 1 if layout is None: - layout = tcgen05._infer_tmem_layout(struct.shape, collective=False) - cols_used = np.prod(struct.shape) // tcgen05.TMEM_ROWS + layout = tcgen05._infer_tmem_layout( + struct.shape, collective, packing=packing) + unpadded_cols_used = layout.cols_in_shape(struct.shape) + cols_used = tcgen05._alloc_ncols(unpadded_cols_used, exact_cols) + + off = arith_dialect.addi(self.tmem_base_ptr, + _i32_constant(self.tmem_used_cols)) + tmem_ref = tcgen05.TMEMRef( + address=off, + shape=struct.shape, + dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), + layout=layout) self.tmem_used_cols += cols_used - off = self.tmem_base_ptr - tmem_ref = tcgen05.TMEMRef(address=off, - shape=struct.shape, - dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), - layout=layout) yield tmem_ref self.tmem_used_cols -= cols_used @@ -610,6 +625,8 @@ def lower_pipelined_jaxpr_to_module( def ref_for_aval(aval: jax_core.AbstractValue): if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): return gpu_core.WGMMAAccumulatorRef(aval.shape, aval.dtype) + elif isinstance(aval, gpu_core.AbstractTMEMRef): + return gpu_core.TMEM(aval.shape, aval.dtype, packed=aval.packed) elif isinstance(aval, pallas_core.AbstractMemoryRef): return pallas_core.MemoryRef(aval.shape, aval.dtype, aval.memory_space) else: @@ -1324,17 +1341,32 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): @register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Lane) def _swap_lowering_rule( - ctx: LoweringRuleContext, x_smem, value, *leaves, tree + ctx: LoweringRuleContext, x_ref, value, *leaves, tree ): if not isinstance(value, mgpu.FragmentedArray): raise TypeError(f"Can only store arrays (got {value}).") - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only store to references (got {x_smem}).") + + if isinstance(x_ref, tcgen05.TMEMRef): + transforms = jax.tree.unflatten(tree, leaves) + match transforms: + case (indexer,) if isinstance(indexer, indexing.NDIndexer): + if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape): + raise NotImplementedError( + "Only trivial indexing is supported for TMEM refs.") + case _: + raise NotImplementedError( + "Only a single indexing transform is supported for TMEM refs.") + old_value = x_ref[:] + x_ref[:] = value + return old_value + + if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): + raise TypeError(f"Can only store to references (got {x_ref}).") v_aval = ctx.avals_in[1] transforms = jax.tree.unflatten(tree, leaves) transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT x_smem, 
transforms = _handle_transforms( - ctx, x_smem, transforms, handle_transposes=not transposed_value, allow_peer_refs=True + ctx, x_ref, transforms, handle_transposes=not transposed_value, allow_peer_refs=True ) mgpu.warpgroup_barrier() # Make sure reads have completed before we write. match transforms: @@ -2201,6 +2233,8 @@ def _run_scoped_lowering_rule( input_ref = alloc_stack.enter_context( ctx.module_ctx.alloc_tmem( jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + packed=aval.packed, + exact_cols=False, ) ) input_refs.append(input_ref) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 8d0e0c82671d..1081f052f4c1 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1177,8 +1177,9 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, if acc.memory_space != gpu_core.GPUMemorySpace.TMEM: raise ValueError("Accumulator must be a TMEM Ref.") - if a.memory_space != gpu_core.GPUMemorySpace.SMEM: - raise ValueError("LHS must be an SMEM Ref. TMEM not yet supported.") + if a.memory_space not in (gpu_core.GPUMemorySpace.SMEM, + gpu_core.GPUMemorySpace.TMEM): + raise ValueError("LHS must be a TMEM/SMEM Ref.") if b.memory_space != gpu_core.GPUMemorySpace.SMEM: raise ValueError("RHS must be an SMEM Ref.") @@ -1287,6 +1288,27 @@ def _tcgen05_mma_lowering( ctx=ctx.launch_ctx) return [] + +commit_tmem_p = jax_core.Primitive("commit_tmem") +commit_tmem_p.multiple_results = True + + +@commit_tmem_p.def_effectful_abstract_eval +def _commit_tmem_abstract_eval(): + return (), {gpu_core._memory_effect} + + +@lowering.register_lowering_rule(commit_tmem_p, mgpu.LoweringSemantics.Lane) +def _commit_tmem_lowering(_): + tcgen05.commit_tmem() + return () + + +def commit_tmem(): + """Commits all writes to TMEM, making them visible to loads and MMA.""" + commit_tmem_p.bind() + + class Layout(enum.Enum): #: [m, n] matrix, where m % 64 == 0 == n % 8. 
WGMMA = enum.auto() @@ -1299,6 +1321,8 @@ class Layout(enum.Enum): WG_SPLAT = enum.auto() WG_STRIDED = enum.auto() + TCGEN05 = enum.auto() + def __call__(self, *args, **kwargs) -> ParameterizedLayout: return ParameterizedLayout(self, args, kwargs) @@ -1324,6 +1348,9 @@ def check_no_args(): return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter case Layout.WG_STRIDED: return mgpu.WGStridedFragLayout(*args, **kwargs) # pytype: disable=missing-parameter + case Layout.TCGEN05: + check_no_args() + return mgpu.TCGEN05_LAYOUT @dataclasses.dataclass(frozen=True) class ParameterizedLayout: diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 074890d1816d..82155f86d9ea 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -101,3 +101,7 @@ WGMMAAccumulator as WGMMAAccumulator, wgmma as wgmma, ) + +from .tcgen05 import ( + LAYOUT as TCGEN05_LAYOUT, # noqa: F401 +) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index dd1bd3aba4bd..7b300b8cfbfa 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -58,6 +58,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait from jax._src.pallas.mosaic_gpu.primitives import tcgen05_mma as tcgen05_mma +from jax._src.pallas.mosaic_gpu.primitives import commit_tmem as commit_tmem from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index dcafa9b7277e..ed45c3b9f1e6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1707,6 +1707,7 @@ def test_missing_primitive_lowerings_are_tracked(self): mgpu_primitives.broadcasted_iota_p, mgpu_primitives.load_p, mgpu_primitives.tcgen05_mma_p, + mgpu_primitives.commit_tmem_p, lax.slice_p, pallas_core.core_map_p, pallas_primitives.semaphore_signal_p, @@ -2092,38 +2093,56 @@ class PallasCallSm90AWGTest( class PallasCallSm100ATest(PallasSm100ATest): - def test_tmem_alloc(self): + def test_tmem(self): self.skip_if_wg_semantics() # TMEM read not wired up in the WG get rule. - + swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ) @functools.partial( self.kernel, out_shape=jnp.zeros((128, 128), jnp.float32), scratch_shapes=[ plgpu.TMEM((128, 128), jnp.float32), - plgpu.SMEM((128, 128), jnp.float32), + plgpu.TMEM((128, 128), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.Barrier(num_arrivals=1), ], num_threads=1, thread_name="x", ) - def kernel(y_ref, tmem_ref, smem_ref): - # Issue a write so the TMEM load is not DCE'd. - smem_ref[...] = tmem_ref[...] + def kernel(x_ref, y_ref, tmem_ref, tmem_ref2, smem_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + # Exercise TMEM by roundtripping SMEM -> TMEM -> TMEM -> SMEM. + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + tmem_ref[...] = x_val + 1 + plgpu.commit_tmem() + tmem_ref2[...] = tmem_ref[...] + plgpu.commit_tmem() + smem_ref[...] = tmem_ref2[...] plgpu.commit_smem() plgpu.copy_smem_to_gmem(smem_ref, y_ref) plgpu.wait_smem_to_gmem(0) - # Test that this runs without errors. 
- jax.block_until_ready(kernel()) + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, x + 1) @parameterized.parameters( - ((128, 128), 128, jnp.float16), + ((128, 128), 128, jnp.float16, False), + # Test LHS in TMEM. + ((128, 128), 128, jnp.float16, True), # Test bfloat16 - ((128, 128), 128, jnp.bfloat16), + ((128, 128), 128, jnp.bfloat16, False), # Test additional swizzles. - ((128, 128), 64, jnp.float16), - ((128, 128), 32, jnp.float16), + ((128, 128), 64, jnp.float16, False), + ((128, 128), 32, jnp.float16, False), ) - def test_simple_matmul(self, shape, swizzle, dtype): + def test_simple_matmul(self, shape, swizzle, dtype, lhs_tmem=False): + self.skip_if_wg_semantics() # Test a matmul with a single block. swizzle_elems = swizzle // jnp.dtype(dtype).itemsize transforms = ( @@ -2131,9 +2150,16 @@ def test_simple_matmul(self, shape, swizzle, dtype): plgpu.SwizzleTransform(swizzle), ) - def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, + a_tmem_ref): + if lhs_tmem: + lhs_ref = a_tmem_ref + lhs_ref[...] = plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05) + plgpu.commit_tmem() + else: + lhs_ref = a_smem plgpu.tcgen05_mma(acc_tmem, - a_smem, + lhs_ref, b_smem, barrier_ref, accumulate=False) @@ -2144,10 +2170,15 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): plgpu.wait_smem_to_gmem(0) scratch_shapes = [ - plgpu.TMEM(shape, jnp.float32), + plgpu.TMEM(shape, jnp.float32, packed=False), plgpu.SMEM(shape, dtype, transforms=transforms), plgpu.Barrier(num_arrivals=1, for_tensor_core=True), ] + if lhs_tmem: + scratch_shapes.append(plgpu.TMEM(shape, dtype, packed=True)) + else: + scratch_shapes.append(None) + f = self.pallas_call( kernel, in_specs=( From 8c987bfb0afd36710038a96430ae95f473503e08 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 8 May 2025 08:58:13 -0700 Subject: [PATCH 1075/1769] Fix handling of final style primitives in pallas cost estimate. The pallas cost estimate logic operates on jaxprs, so it shouldn't call `get_bind_params`, which is used to lift staged out final style primitives. For example, calling `get_bind_params` on a `CallPrimitive` loses the inner jaxpr which would be needed for estimating the cost. Currently, only initial style higher-order primitives have cost rules, so `get_bind_params` is a no-op in all the covered cases, but I'm working on converting `custom_vjp` to an initial style primitive, which led me to this change. As written, I don't anticipate any change in behavior with this update. 
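For illustration only, here is a minimal, hedged sketch of the jaxpr-walking
pattern at play; `walk` is a hypothetical helper, not part of this change:

    # Sketch: read eqn.params directly rather than via get_bind_params.
    # For initial-style higher-order primitives (e.g. scan), eqn.params
    # already carries everything a cost rule needs, including inner jaxprs.
    import jax
    import jax.numpy as jnp

    def walk(fn, *args):
      closed = jax.make_jaxpr(fn)(*args)
      for eqn in closed.jaxpr.eqns:
        print(eqn.primitive.name, sorted(eqn.params))

    walk(lambda x: jnp.sin(x) + 1.0, jnp.ones(4))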
PiperOrigin-RevId: 756332549
---
 jax/_src/pallas/cost_estimate.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py
index 73db4a2e2d4a..3b82d3095f64 100644
--- a/jax/_src/pallas/cost_estimate.py
+++ b/jax/_src/pallas/cost_estimate.py
@@ -64,12 +64,11 @@ def cost_estimate_jaxpr(
   total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0)
 
   for eqn in jaxpr.eqns:
-    _, bind_params = eqn.primitive.get_bind_params(eqn.params)
     rule = _cost_rules.get(eqn.primitive, None)
     if rule is not None:
       context = Context(avals_in=[v.aval for v in eqn.invars],
                         avals_out=[v.aval for v in eqn.outvars])
-      op_cost = rule(context, **bind_params)
+      op_cost = rule(context, **eqn.params)
       total_cost = total_cost + op_cost
   return pallas_core.CostEstimate(
       flops=total_cost.flops,

From 93d02349aab18dbff83569b62281801f4b509908 Mon Sep 17 00:00:00 2001
From: Yash Katariya 
Date: Thu, 8 May 2025 10:38:52 -0700
Subject: [PATCH 1076/1769] Add `out_sharding` to `jnp.repeat`. Drop into auto
 mode if out_sharding is provided.

out_sharding is required in cases where axis is None or the input is sharded
on the `axis` we are going to repeat on. If the input is not sharded on the
repeat axis, forward the input sharding to the output.

Fixes https://github.com/jax-ml/jax/issues/28538

PiperOrigin-RevId: 756372112
---
 jax/_src/basearray.pyi          |  3 ++-
 jax/_src/numpy/array_methods.py |  7 +++++--
 jax/_src/numpy/lax_numpy.py     | 32 ++++++++++++++++++++++++++++++--
 jax/numpy/__init__.pyi          |  3 ++-
 tests/pjit_test.py              | 13 +++++++++++++
 5 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi
index cf64afdacfe3..a98cc012031e 100644
--- a/jax/_src/basearray.pyi
+++ b/jax/_src/basearray.pyi
@@ -189,7 +189,8 @@ class Array(metaclass=abc.ABCMeta):
   @property
   def real(self) -> Array: ...
   def repeat(self, repeats: ArrayLike, axis: int | None = None, *,
-             total_repeat_length: int | None = None) -> Array: ...
+             total_repeat_length: int | None = None,
+             out_sharding: NamedSharding | P | None = None) -> Array: ...
   def reshape(self, *args: Any, order: str = "C",
               out_sharding: NamedSharding | P | None = ...) -> Array: ...
   def round(self, decimals: int = 0, out: None = None) -> Array: ...
diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py
index a3dbc0f9f6c6..73f916537af0 100644
--- a/jax/_src/numpy/array_methods.py
+++ b/jax/_src/numpy/array_methods.py
@@ -293,12 +293,15 @@ def _real_property(self: Array) -> Array:
   return ufuncs.real(self)
 
 def _repeat(self: Array, repeats: ArrayLike, axis: int | None = None, *,
-            total_repeat_length: int | None = None) -> Array:
+            total_repeat_length: int | None = None,
+            out_sharding: NamedSharding | PartitionSpec | None = None) -> Array:
   """Construct an array from repeated elements.
 
   Refer to :func:`jax.numpy.repeat` for the full documentation.
""" - return lax_numpy.repeat(self, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length) + return lax_numpy.repeat(self, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length, + out_sharding=out_sharding) def _reshape(self: Array, *args: Any, order: str = "C", out_sharding=None ) -> Array: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b662f1d6f7ed..2a1ec7227439 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -65,7 +65,9 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax._src.sharding_impls import (NamedSharding, PartitionSpec as P) +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P +from jax._src.mesh import get_abstract_mesh +from jax._src.pjit import auto_axes from jax.tree_util import tree_flatten, tree_map import numpy as np @@ -6630,7 +6632,8 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @export def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, - total_repeat_length: int | None = None) -> Array: + total_repeat_length: int | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Construct an array from repeated elements. JAX implementation of :func:`numpy.repeat`. @@ -6694,6 +6697,31 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32) """ + if out_sharding is not None: + return auto_axes( + partial(_repeat, axis=axis, total_repeat_length=total_repeat_length), + out_sharding=out_sharding)(a, repeats) + ctx_mesh = get_abstract_mesh() + if ctx_mesh._are_all_axes_explicit: + aval = core.typeof(a) + if axis is None or aval.sharding.spec[axis] is not None: + raise ValueError( + "Please pass sharding to `jnp.repeat` via `out_sharding` parameter.") + assert axis is not None and aval.sharding.spec[axis] is None + out_sharding = (NamedSharding(ctx_mesh, P()) + if aval.sharding.mesh.empty else aval.sharding) + return auto_axes( + partial(_repeat, axis=axis, total_repeat_length=total_repeat_length), + out_sharding=out_sharding)(a, repeats) + try: + return _repeat(a, repeats, axis=axis, + total_repeat_length=total_repeat_length) + except core.ShardingTypeError as e: + raise ValueError( + "Please pass sharding to `jnp.repeat` via `out_sharding` parameter.") + +def _repeat(a: ArrayLike, repeats: ArrayLike, *, axis: int | None = None, + total_repeat_length: int | None = None) -> Array: if core.is_dim(repeats): util.check_arraylike("repeat", a) else: diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index a8de717a0d07..c52ce2628cda 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -808,7 +808,8 @@ def reciprocal(x: ArrayLike, /) -> Array: ... register_jax_array_methods: Any def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, - total_repeat_length: int | None = ...) -> Array: ... + total_repeat_length: int | None = ..., + out_sharding: NamedSharding | P | None = None) -> Array: ... 
def reshape( a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ..., out_sharding: NamedSharding | P | None = ..., diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 373e3ceff7a8..0c43df76cc5f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7569,6 +7569,19 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_jnp_repeat(self, mesh): + out = jnp.repeat(np.eye(3), np.array((2,2,2,)) - 1, axis=0) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None))) + + a = jnp.eye(3) + out = jnp.repeat(a, np.array((2,2,2,)) - 1, axis=0) + self.assertEqual(out.sharding, a.sharding) + + a = jax.device_put(jnp.eye(4), P('x')) + out = jnp.repeat(a, np.array((2,2,2,2)) - 1, axis=0, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @jtu.with_explicit_mesh((2,), ('x',)) def test_scatter_gather(self, mesh): x = np.random.uniform(size=(mesh.size * 2, 3)) From bf830823246b4d604438285baf89f0c6ec6e8738 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 8 May 2025 10:39:13 -0700 Subject: [PATCH 1077/1769] [Pallas][Mosaic GPU] Add collective support to Blackwell/tcgen05 MMA. This exposes the collective argument to tcgen05_mma which allows paired CTAs to collaborate on matmuls across blocks. PiperOrigin-RevId: 756372254 --- jax/_src/pallas/mosaic_gpu/core.py | 21 ++++--- jax/_src/pallas/mosaic_gpu/lowering.py | 3 +- jax/_src/pallas/mosaic_gpu/primitives.py | 52 +++++++++++++--- tests/pallas/mosaic_gpu_test.py | 75 ++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index f193ce7d2743..6be8b3c4a8a5 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -126,11 +126,12 @@ def __call__( dtype: jnp.dtype, *, transforms: Sequence[MemoryRefTransform] = (), - packed: bool | None = None + packed: bool | None = None, + collective: bool | None = None ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms, - packed=packed) + packed=packed, collective=collective) class SemaphoreType(enum.Enum): @@ -224,10 +225,14 @@ class GPUMemoryRef(pallas_core.MemoryRef): # Whether to allow TMEM packing for sub 4-byte dtypes. 
packed: bool | None = dataclasses.field(default=None, kw_only=True) + collective: bool | None = dataclasses.field(default=None, kw_only=True) def __post_init__(self): - if self.packed is not None and self.memory_space != GPUMemorySpace.TMEM: - raise ValueError("Packed option is only supported for TMEM.") + if self.memory_space != GPUMemorySpace.TMEM: + if self.packed is not None: + raise ValueError("Packed option is only supported for TMEM.") + if self.collective is not None: + raise ValueError("Collective option is only supported for TMEM.") def get_ref_aval(self) -> _Ref: aval = jax_core.ShapedArray(self.shape, self.dtype) @@ -237,7 +242,8 @@ def get_ref_aval(self) -> _Ref: ref = pallas_core.TransformedRef( AbstractTMEMRef(aval, memory_space=self.memory_space, - packed=self.packed), () + packed=self.packed, + collective=self.collective), () ) else: ref = pallas_core.TransformedRef( @@ -936,11 +942,12 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: ) class AbstractTMEMRef(AbstractMemoryRef): - __slots__ = ["inner_aval", "memory_space", "packed"] + __slots__ = ["inner_aval", "memory_space", "packed", "collective"] - def __init__(self, inner_aval, memory_space, packed): + def __init__(self, inner_aval, memory_space, packed, collective): super().__init__(inner_aval, memory_space) self.packed = packed + self.collective = collective def __repr__(self) -> str: return f'TMEM({self.inner_aval.str_short()},packed={self.packed})' diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 29ec0d16d2fb..55aca65c6a19 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -251,7 +251,7 @@ def _run_scoped_resource_estimator( else: packing = 1 layout = tcgen05._infer_tmem_layout( - aval.shape, collective=False, packing=packing) + aval.shape, collective=aval.collective, packing=packing) cols_used = layout.cols_in_shape(aval.shape) cols_used = tcgen05._alloc_ncols(cols_used, exact=False) rs += Resources(tmem_scratch_cols=cols_used) @@ -2235,6 +2235,7 @@ def _run_scoped_lowering_rule( jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), packed=aval.packed, exact_cols=False, + collective=aval.collective, ) ) input_refs.append(input_ref) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1081f052f4c1..c5cf257070c4 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1124,9 +1124,20 @@ def tcgen05_mma(acc: _Ref, a: _Ref, b: _Ref, barrier: _Ref, - accumulate: bool | jax.Array = True): + accumulate: bool | jax.Array = True, + collective_axis: str | None = None): """Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell). + If run in collective mode, `acc`, `a` (LHS), and `b` (RHS) should correspond + to half of the total inputs to the MMA, where `acc` and `a` (LHS) are split + in half along the rows and `b` (RHS) is split along the columns like so: + + ----------- ----------- ----------- + | ACC1 | | LHS1 | | | | + ----------- += ----------- @ |RHS1|RHS2| + | ACC2 | | LHS2 | | | | + ----------- ----------- ----------- + Args: acc: The accumulator. Must be a TMEM Ref. a: The left-hand side. Must be a TMEM/SMEM Ref. @@ -1134,10 +1145,15 @@ def tcgen05_mma(acc: _Ref, barrier: Barrier Ref for synchronizing with the tensor core. Should have for_tensor_core set to True. accumulate: Whether to accumulate into acc or overwrite it. 
+ collective_axis: The name of the cluster axis along which to perform + a collective MMA. The cluster axis should have a size of exactly 2, + and must be on the minormost cluster axis. """ acc_m, acc_n = acc.shape lhs_m, lhs_k = a.shape rhs_k, rhs_n = b.shape + if collective_axis is not None: + acc_n /= 2 if acc_m != lhs_m: raise ValueError( f"Accumulator and LHS have incompatible shapes. Accumulator: {acc.shape}. LHS: {a.shape}.") @@ -1164,16 +1180,14 @@ def tcgen05_mma(acc: _Ref, *a_transforms_leaves, *b_transforms_leaves, a_transforms_tree=a_transforms_tree, b_transforms_tree=b_transforms_tree, - collective=False) + collective_axis=collective_axis) @tcgen05_mma_p.def_abstract_eval def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, *transforms_leaves, a_transforms_tree, b_transforms_tree, - collective): + collective_axis): del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree) - if collective: - raise NotImplementedError("Collective MMA not yet implemented.") if acc.memory_space != gpu_core.GPUMemorySpace.TMEM: raise ValueError("Accumulator must be a TMEM Ref.") @@ -1183,6 +1197,14 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, if b.memory_space != gpu_core.GPUMemorySpace.SMEM: raise ValueError("RHS must be an SMEM Ref.") + if collective_axis is not None: + if not acc.collective: + raise ValueError( + "Accumulator Ref must be collective if collective_axis is set.") + if a.memory_space == gpu_core.GPUMemorySpace.TMEM and not a.collective: + raise ValueError( + "LHS TMEM Ref must be collective if collective_axis is set.") + for_tensor_core = getattr( barrier.inner_aval.dtype, "for_tensor_core", False) if not for_tensor_core: @@ -1202,7 +1224,7 @@ def _tcgen05_mma_lowering( *transforms_leaves, a_transforms_tree, b_transforms_tree, - collective: bool, + collective_axis, ): _, a_aval, b_aval, *_ = ctx.avals_in lhs_swizzle: int = 128 @@ -1267,12 +1289,26 @@ def _tcgen05_mma_lowering( accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1)) predicate = ctx.module_ctx.single_lane_predicate - if collective: + collective = False + if collective_axis is not None: + cluster_axis = lowering._resolve_cluster_axis( + ctx.module_ctx.axis_names, collective_axis) + if cluster_axis != gpu_dialect.Dimension(0): + # Note: resolve_cluster_axis checks if axis_names exists. + assert ctx.module_ctx.axis_names is not None + if len(ctx.module_ctx.axis_names.cluster) <= 1: + raise ValueError("No cluster axes found.") + minormost_cluster_axis = ctx.module_ctx.axis_names.cluster[0] + raise ValueError( + "Can only perform collective MMA along minormost cluster axis. " + f"Got {collective_axis}, expected {minormost_cluster_axis}.") index = ir.IndexType.get() is_leader_block = arith_dialect.cmpi( arith_dialect.CmpIPredicate.eq, - ctx.launch_ctx.cluster_idx(gpu_dialect.Dimension.x), mgpu.c(0, index)) + ctx.launch_ctx.cluster_idx(cluster_axis), mgpu.c(0, index)) predicate = arith_dialect.andi(predicate, is_leader_block) + collective = True + with mgpu.when(predicate): tcgen05.mma( acc, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index ed45c3b9f1e6..6746674250d0 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2197,6 +2197,81 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) + @parameterized.parameters( + ((256, 256), (256, 256), 128, jnp.float16), + # Test additional shape combinations. 
+ ((256, 128), (128, 128), 128, jnp.float16), + ((256, 64), (64, 256), 128, jnp.float16), + # Test bfloat16. + ((256, 256), (256, 256), 128, jnp.bfloat16), + # Test additional swizzles. + ((256, 256), (256, 256), 64, jnp.float16), + ((256, 256), (256, 256), 32, jnp.float16), + ) + def test_simple_collective_matmul(self, lhs_shape, rhs_shape, swizzle, dtype): + self.skip_if_wg_semantics() + # Test a collective (paired CTA) matmul on a single block. + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + acc_shape = (lhs_shape[0], rhs_shape[1]) + _acc_shape = (lhs_shape[0] // 2, rhs_shape[1]) + _lhs_shape = (lhs_shape[0] // 2, lhs_shape[1]) + _rhs_shape = (rhs_shape[0], rhs_shape[1] // 2) + + def kernel(a_gmem, b_gmem, out_gmem): + cluster_idx = lax.axis_index("x") + slice_lhs = pl.ds(cluster_idx * _lhs_shape[0], _lhs_shape[0]) + slice_rhs = pl.ds(cluster_idx * _rhs_shape[1], _rhs_shape[1]) + + @functools.partial(pl.run_scoped, + a_smem=plgpu.SMEM(_lhs_shape, dtype, transforms=transforms), + b_smem=plgpu.SMEM(_rhs_shape, dtype, transforms=transforms), + acc_tmem=plgpu.TMEM(_acc_shape, jnp.float32, collective=True), + scratch_smem=plgpu.SMEM(_acc_shape, dtype, transforms=transforms), + tma_barrier=plgpu.Barrier(num_arrivals=1), + mma_barrier=plgpu.Barrier(num_arrivals=1, for_tensor_core=True), + cluster_barrier=plgpu.ClusterBarrier(collective_axes=("x",)), + ) + def _scoped(a_smem, b_smem, + acc_tmem, scratch_smem, tma_barrier, mma_barrier, cluster_barrier): + plgpu.copy_gmem_to_smem(a_gmem.at[slice_lhs, :], a_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + plgpu.copy_gmem_to_smem(b_gmem.at[:, slice_rhs], b_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) + + plgpu.tcgen05_mma(acc_tmem, + a_smem, + b_smem, + mma_barrier, + accumulate=False, + collective_axis="x") + plgpu.barrier_wait(mma_barrier) + scratch_smem[...] = acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_gmem.at[slice_lhs, :]) + plgpu.wait_smem_to_gmem(0) + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(acc_shape, dtype), + grid=(1,), + grid_names=("_",), + cluster=(2,), + cluster_names=("x",), + ) + x = jax.random.uniform(jax.random.key(0), shape=lhs_shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=rhs_shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + class PallasCallSm100AWGTest( PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From f15ad5a3521ed9a966e37a3abcc3e9002bd80ba3 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 8 May 2025 10:54:40 -0700 Subject: [PATCH 1078/1769] [Pallas][Mosaic GPU] Refactor carry_coroutine in warp specialized pipeline to be a callback instead of a coroutine. 
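For context, a minimal before/after sketch of the user-facing change; `init`
and `use` are illustrative placeholders, not real API:

    # Before: a coroutine that yields the initial carry and receives the
    # final carry back through the same yield expression.
    def carry_coroutine():
      final_carry = yield init()
      use(final_carry)

    # After: a callback that receives the pipeline as a plain function;
    # calling it with the initial carry runs the loop and returns the
    # final carry.
    def compute_context(pipeline):
      final_carry = pipeline(init())
      use(final_carry)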
PiperOrigin-RevId: 756378224
---
 jax/_src/pallas/mosaic_gpu/pipeline.py | 90 +++++++++++++------
 .../pallas/ops/gpu/attention_mgpu.py   | 18 ++--
 tests/pallas/mosaic_gpu_test.py        | 13 ++-
 3 files changed, 78 insertions(+), 43 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py
index 9b743bb18b37..426f314bc3a1 100644
--- a/jax/_src/pallas/mosaic_gpu/pipeline.py
+++ b/jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+from typing import Protocol, TypeVar
 from collections.abc import Callable, Sequence
 import dataclasses
 import functools
@@ -39,6 +40,7 @@
 map = util.safe_map
 zip = util.safe_zip
 
+T = TypeVar('T')
 
 def _get_block_size(
     bd: pl.Blocked | pl.Element | pl.Squeezed | pl.BoundedSlice | int | None,
@@ -378,6 +380,34 @@ def do_fetch():
   return pipeline
 
 
+class ComputeContext(Protocol):
+  """Protocol for a compute context for the warp specialized pipeline.
+
+  The ComputeContext is run exclusively in the compute thread and allows
+  the user to set up a prologue to initialize a pipeline carry and an epilogue
+  to consume the final carry.
+
+  All values allocated in the ComputeContext will only be allocated in the
+  compute thread and not the memory thread. This can potentially reduce
+  register pressure if certain values are only consumed by the compute threads.
+
+  Usage will usually follow this structure:
+
+  ```
+  def compute_context(pipeline):
+    # Perform prologue work and compute the initial carry.
+    initial_carry = ...
+    # Run the pipeline.
+    final_carry = pipeline(*initial_carry)
+    # Perform epilogue work using the final carry.
+    do_work(final_carry)
+  ```
+
+  """
+  def __call__(self, pipeline: Callable[[T], T]) -> None:
+    ...
+
+
 def emit_pipeline_warp_specialized(
     body: Callable[..., None],
     *,
@@ -389,7 +419,7 @@ def emit_pipeline_warp_specialized(
     wg_axis: str,
     num_compute_wgs: int,
     manual_consumed_barriers: bool = False,
-    carry_coroutine: Any | None = None,
+    compute_context: ComputeContext | None = None,
     memory_thread_idx: int | None = None,
 ):
   """Creates a function to emit a warp-specialized pipeline.
@@ -402,7 +432,7 @@ def emit_pipeline_warp_specialized(
   def body(indices, *input_refs, *output_refs, [consumed_barriers]) -> None:
   ```
 
-  or with carries enabled (enabled via the ``carry_coroutine`` argument),
+  or with carries enabled (enabled via the ``compute_context`` argument),
   where the body returns the next carry:
 
   ```
@@ -425,11 +455,15 @@ def body(
     manual_consumed_barriers: If True, consumed barriers will be passed into
       the body function after the output refs. There will be one barrier per
       input and will be passed in the same order.
-    carry_coroutine: If specified, enables carries in the pipeline.
-      The signature of the body function will be modified such that the last
-      argument will be the current carry and it must return the next carry.
-      The coroutine itself should yield the initial carry, and the
-      yield statement will return the final value of the carry.
+    compute_context: If specified, enables carries in the pipeline and allows
+      a user-specified prologue/epilogue that is only executed in the compute
+      thread. The signature of the pipeline body function will be modified
+      such that the last argument will be the current carry and it must
+      return the next carry.
+      The compute_context itself should follow the signature of `ComputeContext`
+      and take a pipeline function as its sole argument.
Calling the pipeline with the initial carry will run the pipeline and return the final carry.
     memory_thread_idx: The index of the memory thread. If not specified,
       defaults to the last thread.
   """
@@ -443,7 +477,7 @@ def body(
     # thread is the last thread.
     raise NotImplementedError("Memory thread must be the last thread.")
 
-  has_carry = carry_coroutine is not None
+  has_carry = compute_context is not None
 
   # Trace the index maps to determine if they depend on the grid.
   # Grid-independent values will not be multiple-buffered.
@@ -622,25 +656,27 @@ def compute_loop_body(step, carry):
       ]
 
   if has_carry:
-    _carry = carry_coroutine()
-    try:
-      carry_init = next(_carry)
-    except StopIteration:
-      raise ValueError("carry_coroutine must yield the initial carry.")  # pylint: disable=raise-missing-from
+    last_indices = None
+    def pipeline_callback(user_init_carry):
+      nonlocal last_indices
+      if last_indices is not None:
+        raise ValueError(
+            "Cannot call pipeline more than once in `compute_context`")
+      init_loop_carry = (init_indices, last_store_slices, user_init_carry)
+      last_indices, _, final_body_carry = lax.fori_loop(0,
+                                                        num_steps,
+                                                        compute_loop_body,
+                                                        init_loop_carry)
+      return final_body_carry
+    compute_context(pipeline_callback)
+    if last_indices is None:
+      raise ValueError("Pipeline was not called in `compute_context`")
   else:
-    _carry = None
-    carry_init = None
-  init_loop_carry = (init_indices, last_store_slices, carry_init)
-  last_indices, _, final_body_carry = lax.fori_loop(0,
-                                                    num_steps,
-                                                    compute_loop_body,
-                                                    init_loop_carry)
-  if has_carry:
-    try:
-      _carry.send(final_body_carry)  # pytype: disable=attribute-error
-      raise ValueError("carry_coroutine must only yield once.")
-    except StopIteration:
-      pass
+    assert compute_context is None
+    last_indices, _, _ = lax.fori_loop(
+        0, num_steps, compute_loop_body,
+        (init_indices, last_store_slices, None)
+    )
 
   # Handle index_invariant outputs after the loop. They are not
   # written in the main pipeline loop.
diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 4d43d6045fee..e7a9898bf9f3 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -336,7 +336,7 @@ def kernel_dq(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, dq_ref,
     kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
     q_smem2, do_smem2, lse_smem2, delta_smem2 = smem_buffers
     q_barriers, do_barriers, lse_barriers, delta_barriers = buffer_barriers
-    def _compute_thread():
+    def _compute_thread(pipeline_callback):
       q_smem, do_smem, lse_smem, delta_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx], lse_smem2.at[wg_idx], delta_smem2.at[wg_idx]
       q_seq_base = lax.axis_index("q_seq") * (compute_wgs * block_q) + wg_idx * block_q
       q_slice = (batch, pl.ds(q_seq_base, block_q), q_head)
@@ -360,7 +360,7 @@ def _compute_thread():
       dq_acc = plgpu.layout_cast(
           jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
       )
-      dq, _, _ = (yield (dq_acc, lse, delta))
+      dq, _, _ = pipeline_callback((dq_acc, lse, delta))
       q_smem[...]
= dq.astype(dtype) plgpu.commit_smem() plgpu.copy_smem_to_gmem(q_smem, dq_ref.at[q_slice]) @@ -406,7 +406,7 @@ def compute_dq(acc_ref): memory_registers=40, wg_axis="wg", manual_consumed_barriers=True, - carry_coroutine=_compute_thread, + compute_context=_compute_thread, in_specs=[ plgpu.GPUBlockSpec( # k block_shape=(block_kv, head_dim), @@ -429,7 +429,7 @@ def kernel_dkv(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, (k_smem2, v_smem2) = smem_buffers (k_barriers, v_barriers) = buffer_barriers - def _compute_thread(): + def _compute_thread(pipeline_callback): k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx] kv_seq_base = lax.axis_index("kv_seq") * (compute_wgs * block_kv) + wg_idx * block_kv kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) @@ -449,7 +449,7 @@ def _compute_thread(): dv_acc = plgpu.layout_cast( jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, ) - (dk, dv) = (yield (dv_acc, dk_acc)) + (dk, dv) = pipeline_callback((dv_acc, dk_acc)) k_smem[...] = dk.astype(dtype) v_smem[...] = dv.astype(dtype) @@ -513,7 +513,7 @@ def compute_dk(acc_ref): memory_registers=40, wg_axis="wg", manual_consumed_barriers=True, - carry_coroutine=_compute_thread, + compute_context=_compute_thread, in_specs=[ plgpu.GPUBlockSpec( # q block_shape=(block_q, head_dim), @@ -627,7 +627,7 @@ def perform_schedule_barrier(): plgpu.barrier_arrive(schedule_barrier) plgpu.barrier_wait(schedule_barrier) - def _compute_thread(): + def _compute_thread(pipeline_callback): qo_smem = qo_smem2.at[wg_idx] lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None m_i = jnp.full((block_q,), -jnp.inf, dtype=jnp.float32) @@ -641,7 +641,7 @@ def _compute_thread(): ) plgpu.barrier_wait(q_barriers.at[wg_idx]) pl.when(wg_idx == 1)(perform_schedule_barrier) - final_carry = (yield (acc, m_i, l_i)) + final_carry = pipeline_callback((acc, m_i, l_i)) pl.when(wg_idx == 0)(perform_schedule_barrier) acc, m_i, l_i = final_carry acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) @@ -699,7 +699,7 @@ def compute_pv(acc_ref): memory_registers=40, wg_axis="wg", manual_consumed_barriers=True, - carry_coroutine=_compute_thread, + compute_context=_compute_thread, in_specs=[ plgpu.GPUBlockSpec( # k block_shape=(block_kv, head_dim), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6746674250d0..71b4b491f7e4 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2749,16 +2749,14 @@ def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): thread_name="wg", ) def kernel(x_gmem, acc_gmem, acc_smem): - def _compute_thread(): + def _compute_thread(pipeline_fn): # Cast the init value to the same layout as x_smem, so the pipeline loop # carry has a constant signature. o_acc = plgpu.layout_cast( jnp.full((blk_m, blk_n,), 0, dtype=jnp.float32), plgpu.Layout.WG_STRIDED((blk_m, blk_n), vec_size=2)) - carry_init = (o_acc,) # Pass control to the pipeline emitter and return the final carry. - final_carry = (yield carry_init) - o_final, = final_carry + o_final = pipeline_fn(o_acc) # Note that both compute WGs are doing identical work so the potential # race condition on the store here won't affect the result. acc_smem[...] = o_final @@ -2767,9 +2765,8 @@ def _compute_thread(): plgpu.wait_smem_to_gmem(0) def tiled_acc_kernel(_, x_smem, carry): - o_carry, = carry - new_carry = x_smem[...] + o_carry - return (new_carry,) + new_carry = x_smem[...] 
+ carry
+      return new_carry
 
     pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
         tiled_acc_kernel,
@@ -2778,7 +2775,7 @@ def tiled_acc_kernel(_, x_smem, carry):
         num_compute_wgs=num_compute_wgs,
         memory_registers=40,
         wg_axis="wg",
-        carry_coroutine=_compute_thread,
+        compute_context=_compute_thread,
         in_specs=[
             pl.BlockSpec(
                 block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j)

From 098624dcf81ad65c7d41dc35efcbc3feda010c7f Mon Sep 17 00:00:00 2001
From: Yash Katariya 
Date: Thu, 8 May 2025 11:10:18 -0700
Subject: [PATCH 1079/1769] Make the type checker match the runtime behavior
 of PartitionSpec not inheriting from a tuple.

PiperOrigin-RevId: 756385141
---
 jax/_src/partition_spec.py | 11 ++---------
 jax/_src/pjit.py           |  2 +-
 jax/_src/sharding_impls.py |  2 +-
 3 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py
index fcea21934bfb..3542a232ba35 100644
--- a/jax/_src/partition_spec.py
+++ b/jax/_src/partition_spec.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 
 class UnconstrainedSingleton:
@@ -63,7 +63,7 @@ def unpicke_pspec(partitions, unreduced):
 AxisName = Any
 
 
-class PartitionSpecImpl:
+class PartitionSpec:
   """Tuple describing how to partition an array across a mesh of devices.
 
   Each element is either ``None``, a string, or a tuple of strings.
@@ -166,10 +166,3 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec:
     if len(out) < ndim:
       out.extend([None] * (ndim - len(out)))
     return self.with_partitions(out)
-
-
-if TYPE_CHECKING:
-  class PartitionSpec(PartitionSpecImpl, tuple):  # type: ignore
-    ...
-else:
-  PartitionSpec = PartitionSpecImpl
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index f87207e0a796..ca77b659a08d 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2141,7 +2141,7 @@ def _insert_axis_partitions(spec, dim, val):
   too_short = dim - len(spec)
   if too_short > 0:
     spec += (None,) * too_short
-  new_partitions = tuple_insert(spec, dim, val)
+  new_partitions = tuple_insert(spec, dim, val)  # type: ignore
   return PartitionSpec(*new_partitions)
 
 def _pjit_batcher_for_sharding(
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 6e86911e63b0..2394e9e18f38 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -1244,7 +1244,7 @@ def logical_sharding(logical_shape, dtype, phys_sharding
                      ) -> jsharding.Sharding:
       phys_spec = (*phys_sharding.spec,
                    *[None] * (len(phys_shape) - len(phys_sharding.spec)))
     else:
-      phys_spec = phys_sharding.spec
+      phys_spec = phys_sharding.spec  # type: ignore
     return phys_sharding.with_spec(phys_spec[:-elt_aval.ndim])
   else:
     return get_logical_gspmd_sharding(logical_shape, dtype, phys_sharding)

From 08385f51d0975d0a46661e95a6a22969c48145fc Mon Sep 17 00:00:00 2001
From: Daniel Suo 
Date: Thu, 8 May 2025 11:10:35 -0700
Subject: [PATCH 1080/1769] [jaxlib] Add compile_and_load,
 compile_and_load_ifrt_program to xla_client stub.

PiperOrigin-RevId: 756385283
---
 jaxlib/_jax/__init__.pyi | 12 ++++++++++++
 jaxlib/xla_client.py     |  2 +-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi
index 8c02bb4ba722..1582930dd44d 100644
--- a/jaxlib/_jax/__init__.pyi
+++ b/jaxlib/_jax/__init__.pyi
@@ -511,11 +511,23 @@ class Client:
       compile_options: CompileOptions = ...,
       host_callbacks: Sequence[Any] = ...,
   ) -> LoadedExecutable: ...
+ def compile_and_load( + self, + computation: str | bytes, + executable_devices: DeviceList | Sequence[Device], + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... def compile_ifrt_program( self, program: ifrt_programs.Program, program_options: ifrt_programs.CompileOptions, ) -> LoadedExecutable: ... + def compile_and_load_ifrt_program( + self, + program: ifrt_programs.Program, + program_options: ifrt_programs.CompileOptions, + ) -> LoadedExecutable: ... def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... def deserialize_executable( self, diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 725d05a2dace..449dfa653286 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 338 +_version = 339 # An internal increasing version number for protecting jaxlib code against # ifrt changes. From 6c66e938ee362a5ee562d7917bc607779d0a6f64 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Thu, 8 May 2025 11:17:41 -0700 Subject: [PATCH 1081/1769] [Mosaic][SC] Expose control over the number of active SC cores On SparseCore, the `core_parallel` dimension semantic allows the user to set the number of active cores to `all` or `1` based on the dimension size. PiperOrigin-RevId: 756388100 --- jax/_src/tpu_custom_call.py | 86 ++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 2 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 6039979df37b..0f099ed45cac 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -125,6 +125,7 @@ class CustomCallBackendConfig: internal_scratch_in_bytes: int | None output_memory_spaces: tuple[MemorySpace | None, ...] | None disable_bounds_checks: bool + active_core_count: int | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -212,6 +213,8 @@ def to_json(self) -> bytes: if i + 1 != len(self.flags): config.write(b",") config.write(b"]") + if self.device_type == "sparsecore" and self.active_core_count == 1: + config.write(b', "megachip_parallelism_config": {"cores": ["0"]}') config.write(b"}") return config.getvalue() @@ -355,14 +358,89 @@ def assign_device_type_based_on_core_type(op: ir.Operation) -> ir.WalkResult: ) if tensorcore_func_found and sparsecore_func_found: raise ValueError( - "A single Mosaic kernel cannot contain both " - "TensorCore and SparseCore functions." + "A single Mosaic kernel cannot contain both TensorCore and SparseCore" + " functions." ) if sparsecore_func_found: return "sparsecore" return None +def _get_active_core_count(module: ir.Module) -> int | None: + + def get_core_parallel_dim_size( + dim_semantics: ir.ArrayAttr, + iter_bounds: ir.DenseI64ArrayAttr, + other_subkernel_core_dim_size: int | None = None) -> int | None: + + if len(iter_bounds) != len(dim_semantics): + raise ValueError( + "The iteration bounds and dimension semantics attributes must have" + " the same number of elements." 
+      )
+
+    subkernel_core_dim_size = None
+
+    for dim_idx, (dim_size, dim_sem) in enumerate(
+        zip(iter_bounds, dim_semantics)
+    ):
+      if str(dim_sem) != "#tpu.dimension_semantics<core_parallel>":
+        continue
+
+      if ir.ShapedType.is_dynamic_size(dim_size):
+        raise ValueError(
+            "The iteration bound corresponding to the core-parallel dimension "
+            f"{dim_idx} must be statically known."
+        )
+      if subkernel_core_dim_size is not None:
+        raise ValueError(
+            "A single Mosaic subkernel cannot contain multiple core sharding "
+            "dimensions."
+        )
+      if (
+          other_subkernel_core_dim_size is not None
+          and other_subkernel_core_dim_size != dim_size
+      ):
+        raise ValueError(
+            "The iteration bound corresponding to the core-parallel dimension "
+            "must be the same across all subkernels."
+        )
+      subkernel_core_dim_size = dim_size
+
+    return subkernel_core_dim_size
+
+  core_parallel_dim_size = None
+
+  for op in module.body.operations:
+    if op.operation.name != "func.func":
+      continue
+
+    if (
+        "iteration_bounds" not in op.attributes
+        or "dimension_semantics" not in op.attributes
+    ):
+      continue
+
+    try:
+      iter_bounds = ir.DenseI64ArrayAttr(op.attributes["iteration_bounds"])
+    except ValueError as e:
+      e.add_note("The iteration bounds attribute must be an array.")
+      raise
+    try:
+      dim_semantics = ir.ArrayAttr(op.attributes["dimension_semantics"])
+    except ValueError as e:
+      e.add_note("The dimension semantics attribute must be an array.")
+      raise
+
+    core_parallel_dim_size = get_core_parallel_dim_size(
+        dim_semantics=dim_semantics,
+        iter_bounds=iter_bounds,
+        other_subkernel_core_dim_size=core_parallel_dim_size,
+    )
+
+  return core_parallel_dim_size
+
+
 def _lower_to_custom_call_config(
     module: ir.Module,
     *,
@@ -392,6 +470,7 @@ def _lower_to_custom_call_config(
         kernel_name=kernel_name,
         ir_version=ir_version,
     )
+    active_core_count = _get_active_core_count(module)
     return _lowered_to_custom_call_config(
         lowered_module_asm,
         vmem_limit_bytes=vmem_limit_bytes,
@@ -408,6 +487,7 @@ def _lower_to_custom_call_config(
         needs_layout_passes=needs_layout_passes,
         output_memory_spaces=output_memory_spaces,
         disable_bounds_checks=disable_bounds_checks,
+        active_core_count=active_core_count,
     )
 
 
@@ -428,6 +508,7 @@ def _lowered_to_custom_call_config(
     device_type: str | None,
     output_memory_spaces: tuple[MemorySpace | None, ...] | None = None,
     disable_bounds_checks: bool = False,
+    active_core_count: int | None = None,
 ):
   if has_custom_barrier:
     if collective_id is None:
@@ -459,6 +540,7 @@ def _lowered_to_custom_call_config(
       internal_scratch_in_bytes,
       output_memory_spaces,
       disable_bounds_checks,
+      active_core_count=active_core_count,
   )
   return config

From 515f81bfa3e454ccc2f2b757de18e31d4bcb2ae4 Mon Sep 17 00:00:00 2001
From: Michael Whittaker 
Date: Thu, 8 May 2025 11:45:36 -0700
Subject: [PATCH 1082/1769] Shut down `PreemptionSyncManager` when
 `jax.distributed.shutdown()` is called.

PiperOrigin-RevId: 756398914
---
 jax/_src/distributed.py  | 7 +++++--
 jaxlib/_jax/__init__.pyi | 1 +
 jaxlib/xla.cc            | 6 +++++-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index ef8c48a61293..dad445b8e539 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -156,14 +156,17 @@ def initialize(self,
       self.slice_index = slice_index
 
   def shutdown(self):
+    if self.preemption_sync_manager:
+      # It's important to shut down the preemption sync manager before the
+      # client because the preemption sync manager depends on the client.
+ self.preemption_sync_manager.shutdown() + self.preemption_sync_manager = None if self.client: self.client.shutdown() self.client = None if self.service: self.service.shutdown() self.service = None - if self.preemption_sync_manager: - self.preemption_sync_manager = None def initialize_preemption_sync_manager(self): if self.preemption_sync_manager is not None: diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 1582930dd44d..8eb9b5f8173d 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -847,6 +847,7 @@ def get_distributed_runtime_client( class PreemptionSyncManager: def initialize(self, client: DistributedRuntimeClient) -> _Status: ... def reached_sync_point(self, step_counter: int) -> bool: ... + def shutdown(self) -> None: ... def create_preemption_sync_manager() -> PreemptionSyncManager: ... def collect_garbage() -> None: ... diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 0d3d8f6e1b29..adf6f3c98297 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -627,7 +627,11 @@ NB_MODULE(_jax, m) { .def("reached_sync_point", [](tsl::PreemptionSyncManager& manager, int step_counter) { return manager.ReachedSyncPoint(step_counter); - }); + }) + .def("shutdown", [](tsl::PreemptionSyncManager& manager) { + nb::gil_scoped_release gil_release; + manager.Shutdown(); + }); m.def("create_preemption_sync_manager", []() { return tsl::CreatePreemptionSyncManager(); }); From 44a30b2090291fa9f8531c98d088082c4f934b82 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 8 May 2025 12:07:46 -0700 Subject: [PATCH 1083/1769] [Pallas][Mosaic GPU] Add transpose support to tcgen05_mma PiperOrigin-RevId: 756407270 --- jax/_src/pallas/mosaic_gpu/primitives.py | 18 ++++++++----- tests/pallas/mosaic_gpu_test.py | 32 ++++++++++++++++-------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index c5cf257070c4..40eccca7c711 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1227,7 +1227,7 @@ def _tcgen05_mma_lowering( collective_axis, ): _, a_aval, b_aval, *_ = ctx.avals_in - lhs_swizzle: int = 128 + lhs_swizzle: int | None = None lhs_transpose: bool = False if a_transforms_tree is not None: a_transforms_leaves, b_transforms_leaves = util.split_list( @@ -1277,14 +1277,20 @@ def _tcgen05_mma_lowering( ) swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize - if rhs_swizzle != lhs_swizzle: + if lhs_swizzle is None: + lhs_swizzle = rhs_swizzle + elif rhs_swizzle != lhs_swizzle: raise ValueError("MMA rhs swizzle must match lhs swizzle." 
f" {lhs_swizzle=} {rhs_swizzle=}") if rhs_tiling != (8, swizzle_elems): raise ValueError("MMA rhs tiling does not fit swizzle" f" {rhs_tiling=} expected={(8, swizzle_elems)}") - if lhs_transpose or rhs_transpose: - raise NotImplementedError("Lowering does not yet support transpose") + if lhs_transpose: + if isinstance(a_ref, tcgen05.TMEMRef): + raise ValueError("TMEM transpose not allowed.") + a_ref = mgpu.memref_transpose(a_ref, (1, 0, 3, 2)) + if rhs_transpose: + b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2)) if isinstance(accumulate, bool): accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1)) @@ -1314,8 +1320,8 @@ def _tcgen05_mma_lowering( acc, a_ref, b_ref, - a_swizzle=rhs_swizzle, - b_swizzle=lhs_swizzle, + a_swizzle=lhs_swizzle, + b_swizzle=rhs_swizzle, accumulate=accumulate, collective=collective, ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 71b4b491f7e4..a5719b01f4ad 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2131,18 +2131,20 @@ def kernel(x_ref, y_ref, tmem_ref, tmem_ref2, smem_ref, barrier_ref): x_result = jax.block_until_ready(kernel(x)) np.testing.assert_array_equal(x_result, x + 1) - @parameterized.parameters( - ((128, 128), 128, jnp.float16, False), - # Test LHS in TMEM. - ((128, 128), 128, jnp.float16, True), - # Test bfloat16 - ((128, 128), 128, jnp.bfloat16, False), - # Test additional swizzles. - ((128, 128), 64, jnp.float16, False), - ((128, 128), 32, jnp.float16, False), - ) - def test_simple_matmul(self, shape, swizzle, dtype, lhs_tmem=False): + @parameterized.product(shape=[(128, 128)], + swizzle=[128, 64, 32], + dtype=[jnp.float16, jnp.bfloat16], + lhs_tmem=[False, True], + transpose_rhs=[False, True], + transpose_lhs=[False, True]) + def test_simple_matmul(self, shape, swizzle, + dtype=jnp.float16, + lhs_tmem=False, + transpose_lhs=False, + transpose_rhs=False): self.skip_if_wg_semantics() + if transpose_lhs and lhs_tmem: + self.skipTest("TMEM transpose not supported.") # Test a matmul with a single block. swizzle_elems = swizzle // jnp.dtype(dtype).itemsize transforms = ( @@ -2152,6 +2154,10 @@ def test_simple_matmul(self, shape, swizzle, dtype, lhs_tmem=False): def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, a_tmem_ref): + if transpose_lhs: + a_smem = plgpu.transpose_ref(a_smem, (1, 0)) + if transpose_rhs: + b_smem = plgpu.transpose_ref(b_smem, (1, 0)) if lhs_tmem: lhs_ref = a_tmem_ref lhs_ref[...] = plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05) @@ -2194,6 +2200,10 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) result = f(x, y) + if transpose_lhs: + x = jnp.transpose(x, (1, 0)) + if transpose_rhs: + y = jnp.transpose(y, (1, 0)) expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) From d2284bf89314d447d7270156ae2cda40395b521c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 8 May 2025 12:53:19 -0700 Subject: [PATCH 1084/1769] Allow unreduced propagation only for `add` right now. All nary ops cannot forward unreduced as is. `mul` is an example since it's not linear when both inputs are unreduced. `mul` can forward unreduced when one of the inputs is replicated or a constant and the other is unreduced. 
PiperOrigin-RevId: 756423798 --- jax/_src/lax/lax.py | 44 ++++++++++++++++++++++---------------- jax/_src/lax/linalg.py | 3 ++- jax/_src/lax/utils.py | 10 ++++++--- jax/_src/partition_spec.py | 3 +++ 4 files changed, 37 insertions(+), 23 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f6b4c1be102f..b41e78899ba9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3974,23 +3974,10 @@ def broadcasting_sharding_rule(name, *avals): raise core.ShardingTypeError( f'{name} got incompatible shardings for broadcasting: ' f'{", ".join(map(str, map(tuple, specs)))}.') - - unreduced = [a.sharding.spec.unreduced for a in avals if a.shape] - # TODO(yashkatariya): Relax this restriction to allow - # `f32[8]{R:x} * f32[8]{U:x} -> f32[8]{U:x}` for example and maybe more cases. - if unreduced: - if not all(unreduced[0] == u for u in unreduced[1:]): - raise core.ShardingTypeError( - 'All arrays must be unreduced along the same mesh axes. Got' - f' {", ".join(map(str, map(tuple, unreduced)))}') - result_unreduced = unreduced[0] - else: - result_unreduced = None - - return NamedSharding(mesh, P(*result_specs, unreduced=result_unreduced)) + return NamedSharding(mesh, P(*result_specs)) def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, - require_same_dtypes=True): + require_same_dtypes=True, unreduced_rule=None): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, allow_extended_dtype=allow_extended_dtype, require_same=require_same_dtypes) @@ -3998,7 +3985,8 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, sharding_rule = partial(broadcasting_sharding_rule, name) prim = standard_primitive( shape_rule, dtype_rule, name, sharding_rule=sharding_rule, - vma_rule=partial(core.standard_vma_rule, name)) + vma_rule=partial(core.standard_vma_rule, name), + unreduced_rule=unreduced_rule) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -4586,8 +4574,22 @@ def _add_transpose(t, x, y): else: return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)] -# TODO(slebedev): Why does mypy fail to infer the type here? -add_p: Primitive = standard_naryop([_num, _num], 'add') +def _add_unreduced(out_sharding, *avals): + unreduced = [a.sharding.spec.unreduced for a in avals if a.shape] + # TODO(yashkatariya): Relax this restriction to allow + # `f32[8]{R:x} + f32[8]{U:x} -> f32[8]{U:x}` for example and maybe more cases. + if unreduced: + if not all(unreduced[0] == u for u in unreduced[1:]): + raise core.ShardingTypeError( + 'All arrays must be unreduced along the same mesh axes. 
Got' + f' {", ".join(map(str, map(tuple, unreduced)))}') + res_unreduced = unreduced[0] + else: + res_unreduced = None + return out_sharding.with_spec(out_sharding.spec.with_unreduced(res_unreduced)) + +add_p: Primitive = naryop(_input_dtype, [_num, _num], 'add', + unreduced_rule=_add_unreduced) ad.primitive_jvps[add_p] = _add_jvp ad.primitive_transposes[add_p] = _add_transpose mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add)) @@ -4897,7 +4899,8 @@ def _convert_element_type_bind_with_trace(trace, args, params): _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, _convert_element_type_sharding_rule, - partial(core.standard_vma_rule, convert_element_type_p.name))) + partial(core.standard_vma_rule, convert_element_type_p.name), + None)) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule @@ -5202,6 +5205,9 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, raise core.ShardingTypeError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') + if lhs.sharding.spec.unreduced or rhs.sharding.spec.unreduced: + raise NotImplementedError( + 'Please file an issue at https://github.com/jax-ml/jax/issues') (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 848107faf204..dd86c22432d8 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -780,7 +780,8 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, sharding_rule, - partial(core.standard_vma_rule, name))) + partial(core.standard_vma_rule, name), + None)) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 9e033cadd933..a850b2965338 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -38,13 +38,14 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None, sharding_rule=None, vma_rule=None): + weak_type_rule=None, sharding_rule=None, vma_rule=None, + unreduced_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, sharding_rule, vma_rule)) + weak_type_rule, sharding_rule, vma_rule, unreduced_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level @@ -103,7 +104,8 @@ def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, return out_shapes, out_dtypes, out_shardings def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - sharding_rule, vma_rule, *avals, **kwargs): + sharding_rule, vma_rule, unreduced_rule, + *avals, **kwargs): for a in avals: if isinstance(a, state.AbstractRef): raise ValueError( @@ -125,6 +127,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, prim, shape_rule, dtype_rule, sharding_rule, False, *avals, 
**kwargs) out_vma = vma_rule(*avals, **kwargs) + if unreduced_rule is not None: + out_sharding = unreduced_rule(out_sharding, *avals, **kwargs) out_aval = core.ShapedArray( out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, vma=out_vma) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 3542a232ba35..040db35ccb2b 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -160,6 +160,9 @@ def count(self, value): def with_partitions(self, new_partitions): return PartitionSpec(*new_partitions, unreduced=self._unreduced) + def with_unreduced(self, new_unreduced): + return PartitionSpec(*self._partitions, unreduced=new_unreduced) + def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: out = [None if p is _UNCONSTRAINED_PARTITION else p for p in self._partitions] From 05d4ff422b521dae7353f344afe29aa39a87717a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 8 May 2025 13:25:22 -0700 Subject: [PATCH 1085/1769] Remove __div__ and __rdiv__ from jax.Array These are leftover from Python 2, and no longer used in Python 3 --- jax/_src/numpy/array_methods.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 73f916537af0..b29b95219325 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -971,8 +971,6 @@ def max(self, values: ArrayLike, *, "rsub": _defer_to_unrecognized_arg("-", ufuncs.subtract, swap=True), "mul": _defer_to_unrecognized_arg("*", ufuncs.multiply), "rmul": _defer_to_unrecognized_arg("*", ufuncs.multiply, swap=True), - "div": _defer_to_unrecognized_arg("/", ufuncs.divide), - "rdiv": _defer_to_unrecognized_arg("/", ufuncs.divide, swap=True), "truediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide), "rtruediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide, swap=True), "floordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide), From 3712656b636614527a5cbaa5bf25fec619d8c681 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Thu, 8 May 2025 13:44:37 -0700 Subject: [PATCH 1086/1769] Fix the PyTest TPU jobs on the Continuous Wheel Tests workflow. 
PiperOrigin-RevId: 756443164 --- tests/xla_bridge_test.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 5a6bf80a469d..3306cb64aced 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -210,18 +210,23 @@ def test_register_plugin_with_config(self): def test_register_plugin_with_lazy_config(self): options = {"bar": "baz"} - def f(): + def getopts(): return options + def make_c_api_client(plugin_name, new_options, *args, **kwargs): + self.assertContainsSubset(new_options, options) + with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): with mock.patch.object( _profiler, "register_plugin_profiler", autospec=True ): - xb.register_plugin("foo", options=f, library_path="/dev/null") - with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: + xb.register_plugin("foo", options=getopts, library_path="/dev/null") + with mock.patch.object( + xc, "make_c_api_client", autospec=True, wraps=make_c_api_client + ) as mock_make: with mock.patch.object(xc, "pjrt_plugin_initialized", autospec=True): xb._backend_factories["foo"].factory() - mock_make.assert_called_once_with("foo", options, None) + mock_make.assert_called_once() class GetBackendTest(jtu.JaxTestCase): From 4786d122433ff9aa5525d901d0797070c975fb3d Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Thu, 8 May 2025 14:10:24 -0700 Subject: [PATCH 1087/1769] Remove type annotation of get_gpu_client PiperOrigin-RevId: 756454023 --- jaxlib/_jax/__init__.pyi | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 8eb9b5f8173d..6f4f952be9c3 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -572,17 +572,6 @@ def get_tfrt_cpu_client( collectives: CpuCollectives | None = ..., num_devices: int | None = ..., ) -> Client: ... -def get_gpu_client( - asynchronous: bool = ..., - allocator_config: GpuAllocatorConfig = ..., - distributed_client: DistributedRuntimeClient | None = ..., - node_id: int = ..., - num_nodes: int = ..., - allowed_devices: Any | None = ..., - platform_name: str | None = ..., - mock: bool | None = ..., - mock_gpu_topology: str | None = ..., -) -> Client: ... 
def get_mock_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., From 8683b76dad586124c32ee33504d93a3c4751ded2 Mon Sep 17 00:00:00 2001 From: Alina Sbirlea Date: Thu, 8 May 2025 16:38:50 -0700 Subject: [PATCH 1088/1769] Integrate LLVM at llvm/llvm-project@2d287f51eff2 Updates LLVM usage to match [2d287f51eff2](https://github.com/llvm/llvm-project/commit/2d287f51eff2) PiperOrigin-RevId: 756508479 --- jaxlib/mosaic/dialect/tpu/transforms/communication.cc | 8 ++++++-- jaxlib/mosaic/gpu/launch_lowering.cc | 11 ++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc index 7e99dd15611b..dfe42111916c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc @@ -111,8 +111,12 @@ struct LogicalToPhysicalDeviceIdPass {total_devices}, IntegerType::get(func.getContext(), 32), TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}), MemorySpaceAttr::get(func.getContext(), MemorySpace::smem)); - func.insertArgument(func.getNumArguments(), device_assignment_type, - nullptr, UnknownLoc::get(func.getContext())); + + if (failed(func.insertArgument(func.getNumArguments(), + device_assignment_type, nullptr, + UnknownLoc::get(func.getContext())))) { + return signalPassFailure(); + } auto device_assignment_arg = func.getArgument(func.getNumArguments() - 1); func.walk([device_assignment_arg](Operation *some_op) { if (auto op = dyn_cast(some_op)) { diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 53d4f47e58cc..44362e825345 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -238,7 +238,7 @@ mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, cluster = as_32bit(launch.getClusterSizeOperandValues()); } else { cluster.x = cluster.y = cluster.z = builder.create( - launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); + launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); } mlir::Value stream = launch.getAsyncObject(); builder.create( @@ -337,15 +337,16 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { launch.getDynamicSharedMemorySize(), cluster_shape); // Add a new function argument for the kernel handle. - func.insertArgument(0, ptr_ty, - mlir::DictionaryAttr::get(func.getContext()), - mlir::UnknownLoc::get(func.getContext())); + if (failed(func.insertArgument( + 0, ptr_ty, mlir::DictionaryAttr::get(func.getContext()), + mlir::UnknownLoc::get(func.getContext())))) { + return mlir::WalkResult::interrupt(); + } mlir::Value kernel_handle = func.getArgument(0); if (launchPreloadedKernel(func, launch, kernel_handle).failed()) { return mlir::WalkResult::interrupt(); } launch.erase(); - // TODO(apaszke): Generate a destructor function. // builder.CreateCall(getModuleUnloadFn(), {moduleObject}); From a9c49ac085bc8c9635a5c2785c32d26bfa624ab9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 8 May 2025 17:00:48 -0700 Subject: [PATCH 1089/1769] Simplify add's unreduced rule. Only propagate unreduced if both lhs and rhs are unreduced. 
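A rough sketch of the resulting semantics, mirroring the tests changed below
(illustrative only; assumes an explicit mesh with axes 'x' and 'y' and inputs
sharded to match):

    m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))
    m2 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))
    ok = m1 + m2   # both unreduced along 'y': the sum stays unreduced

    m3 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x'))
    m1 + m3        # ShardingTypeError: reshard m1 before calling add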
PiperOrigin-RevId: 756515463 --- jax/_src/lax/lax.py | 26 +++++++++++++++++--------- tests/pjit_test.py | 22 +++++++++++++++++----- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b41e78899ba9..74297fc57b43 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4574,16 +4574,24 @@ def _add_transpose(t, x, y): else: return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)] -def _add_unreduced(out_sharding, *avals): - unreduced = [a.sharding.spec.unreduced for a in avals if a.shape] - # TODO(yashkatariya): Relax this restriction to allow - # `f32[8]{R:x} + f32[8]{U:x} -> f32[8]{U:x}` for example and maybe more cases. - if unreduced: - if not all(unreduced[0] == u for u in unreduced[1:]): +def _add_unreduced(out_sharding, x, y): + x_ur, y_ur = x.sharding.spec.unreduced, y.sharding.spec.unreduced + if x_ur and y_ur: + if x_ur != y_ur: raise core.ShardingTypeError( - 'All arrays must be unreduced along the same mesh axes. Got' - f' {", ".join(map(str, map(tuple, unreduced)))}') - res_unreduced = unreduced[0] + 'lhs and rhs to `add` must be unreduced along the same mesh axes. ' + f'Got lhs={x_ur}, rhs={y_ur}') + res_unreduced = x_ur + elif x_ur or y_ur: + if x_ur and not y_ur: + lhs_str, rhs_str = 'lhs', 'rhs' + else: + assert not x_ur and y_ur + lhs_str, rhs_str = 'rhs', 'lhs' + raise core.ShardingTypeError( + f'{lhs_str} is unreduced while {rhs_str} is not. `add` operation does' + ' not allow this because there will be implicit communication. Please' + f' reduce {lhs_str} via `reshard` before calling `add`.') else: res_unreduced = None return out_sharding.with_spec(out_sharding.spec.with_unreduced(res_unreduced)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0c43df76cc5f..e4b91aa17644 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7763,22 +7763,34 @@ def h(x, y): "unreduced axes should be equal to the contracting specs"): h.trace(x, y) - @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_add_unreduced_error(self, mesh): np_inp = np.arange(16).reshape(8, 2) x = jax.device_put(np_inp, P('x', 'y')) y = jax.device_put(np_inp.T, P('y', None)) + a = jax.device_put(np_inp, P('x', 'z')) + b = jax.device_put(np_inp.T, P('z', None)) @jax.jit - def f(x, y): + def f(x, y, a, b): m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y')) - m2 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x')) + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced='z')) return m1 + m2 with self.assertRaisesRegex( core.ShardingTypeError, - "arrays must be unreduced along the same mesh axes"): - f.trace(x, y) + "lhs and rhs to `add` must be unreduced along the same mesh axes"): + f.trace(x, y, a, b) + + @jax.jit + def g(x, y): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y')) + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x')) + return m1 + m2 + + with self.assertRaisesRegex( + core.ShardingTypeError, "lhs is unreduced while rhs is not"): + g.trace(x, y) @jtu.pytest_mark_if_available('multiaccelerator') From da8f62b4742901a9388b6f91ab8c4740868bb32a Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 8 May 2025 17:35:32 -0700 Subject: [PATCH 1090/1769] Fix empty string handling for cloud_tpu_cluster PiperOrigin-RevId: 756525786 --- jax/_src/clusters/cloud_tpu_cluster.py | 61 ++++++++++++++------------ 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/jax/_src/clusters/cloud_tpu_cluster.py 
b/jax/_src/clusters/cloud_tpu_cluster.py index c8aa765c181c..4807a7194c5b 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Optional import logging import os import re @@ -54,24 +55,26 @@ def get_metadata(key): raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries") return api_resp.text, api_resp.status_code -def get_tpu_env_value(key): - def get_tpu_env_value_from_metadata(key): - tpu_env_data = get_metadata('tpu-env')[0] - key_value_pairs = tpu_env_data.split('\n') - for key_value_pair in key_value_pairs: - # Typical line is MEGASCALE_NUM_SLICES: '2' - if ':' in key_value_pair: - row_key, value = re.split(':', key_value_pair, 1) - row_key = row_key.strip() - if row_key == key: - return value.strip().strip("'") - return None - +def get_tpu_env_value_from_metadata(key) -> Optional[str]: + metadata_value = None + tpu_env_data = get_metadata('tpu-env')[0] + key_value_pairs = tpu_env_data.split('\n') + for key_value_pair in key_value_pairs: + # Typical line is MEGASCALE_NUM_SLICES: '2' + if ':' in key_value_pair: + row_key, value = re.split(':', key_value_pair, 1) + row_key = row_key.strip() + if row_key == key: + metadata_value = value.strip().strip("'") + return metadata_value + +def get_tpu_env_value(key) -> Optional[str]: + # First try to get the value from the environment. value = os.environ.get(key, None) - return value if value is not None else get_tpu_env_value_from_metadata(key) - -def has_megascale_address(): - return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None + if value is None: + # If not found, try to get it from the metadata. + value = get_tpu_env_value_from_metadata(key) + return value class BaseTpuCluster(clusters.ClusterEnv): @@ -94,12 +97,11 @@ def is_env_present(cls) -> bool: @classmethod def get_coordinator_address(cls, timeout_secs: int | None) -> str: - if has_megascale_address(): - # For both GCE via QueuedResources and GKE via JobSet, the - # Megascale coordinator address is set as the host with process id = 0, - # so can be used as the jax distributed system coordinator. - coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') - else: + # For both GCE via QueuedResources and GKE via JobSet, the + # Megascale coordinator address is set as the host with process id = 0, + # so can be used as the jax distributed system coordinator. + coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') + if not coordinator_address: # For both GCE (QueuedResources and TPUVM create) and GKE via Job API, # the workers lists are sorted by process ID so the first one can # be used as the jax distributed system coordinator. 
@@ -149,17 +151,18 @@ def get_process_id(cls) -> int: @staticmethod def _get_num_slices() -> int: - if has_megascale_address(): - return int(get_tpu_env_value('MEGASCALE_NUM_SLICES')) - else: + num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES') + if not num_slices: return 1 + return int(num_slices) # type: ignore + @staticmethod def _get_slice_id() -> int: - if has_megascale_address(): - return int(get_tpu_env_value('MEGASCALE_SLICE_ID')) - else: + slice_id = get_tpu_env_value('MEGASCALE_SLICE_ID') + if not slice_id: return 0 + return int(slice_id) # type: ignore @staticmethod def _get_process_id_in_slice() -> int: From 56bbaa36816185da6934ba7b292c3c6c9bfb6029 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 8 May 2025 17:57:56 -0700 Subject: [PATCH 1091/1769] Add host offloading docs to public website PiperOrigin-RevId: 756532133 --- docs/advanced_guide.rst | 1 + docs/conf.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index db2e83ae2720..1cc48b8959dd 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -14,6 +14,7 @@ operations. notebooks/Distributed_arrays_and_automatic_parallelization notebooks/explicit-sharding notebooks/shard_map + notebooks/host-offloading multi_process distributed_data_loading diff --git a/docs/conf.py b/docs/conf.py index addf0cf50676..a7a52c9db38c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -143,7 +143,6 @@ def _do_not_evaluate_in_jax( 'autodidax2_part1.md', 'sharded-computation.md', 'ffi.ipynb', - 'notebooks/host-offloading.ipynb', ] # The name of the Pygments (syntax highlighting) style to use. From f8a3f06129d783a75692979b95d6229633b39227 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 8 May 2025 19:00:27 -0700 Subject: [PATCH 1092/1769] Add supports_pinned_allocator to allow debugging pinning issues. 
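A hypothetical Python call site for the new flag (the nanobind binding below
defines the defaults; the module that exposes start_transfer_server is not
shown in this patch):

    server = start_transfer_server(
        client,
        transport_addresses=["1.2.3.4:1234"],
        max_num_parallel_copies=8,
        transfer_size=256 * 1024 * 1024,
        supports_pinned_allocator=True,  # opt in; dual pinning unconfirmed
    )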
PiperOrigin-RevId: 756548068
---
 jaxlib/py_socket_transfer.cc | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc
index 89900b02bd93..491e90d778cf 100644
--- a/jaxlib/py_socket_transfer.cc
+++ b/jaxlib/py_socket_transfer.cc
@@ -197,7 +197,8 @@ class PyTransferServer {
   PyTransferServer() = default;
   absl::Status Start(xla::ifrt::Client* client, size_t max_num_parallel_copies,
                      size_t xfer_size, const SocketAddress& addr,
-                     const std::vector<SocketAddress>& transport_addresses) {
+                     const std::vector<SocketAddress>& transport_addresses,
+                     bool supports_pinned_allocator) {
     std::shared_ptr<BulkTransportFactory> factory;
     if (transport_addresses.empty()) {
       factory = BulkTransportFactory::CreateLocal();
@@ -207,8 +208,16 @@ class PyTransferServer {
       SlabAllocator uallocator(xla::ValueOrThrow(MapPjrtMemory(
                                    client, tmp->data(), tmp->size(), tmp)),
                                xfer_size);
+      std::optional<SlabAllocator> pinned_allocator;
+      if (supports_pinned_allocator) {
+        auto tmp = xla::ValueOrThrow(
+            AllocateNetworkPinnedMemory(xfer_size * max_num_parallel_copies));
+        pinned_allocator.emplace(xla::ValueOrThrow(MapPjrtMemory(
+                                     client, tmp->data(), tmp->size(), tmp)),
+                                 xfer_size);
+      }
       factory = xla::ValueOrThrow(CreateSocketBulkTransportFactory(
-          transport_addresses, std::nullopt, uallocator));
+          transport_addresses, pinned_allocator, uallocator));
     }

     server_ = std::make_shared();
@@ -387,8 +396,8 @@ void RegisterTransferServerTypes(nanobind::module_& m) {
      "start_transfer_server",
      [](xla::nb_class_ptr<xla::PyClient> py_client, std::string address,
         std::vector<std::string> transport_addresses_str,
-        size_t max_num_parallel_copies,
-        size_t transfer_size) -> PyTransferServer {
+        size_t max_num_parallel_copies, size_t transfer_size,
+        bool supports_pinned_allocator) -> PyTransferServer {
        PyTransferServer result;
        std::vector<SocketAddress> transport_addresses;
        transport_addresses.reserve(transport_addresses_str.size());
@@ -399,13 +408,15 @@ void RegisterTransferServerTypes(nanobind::module_& m) {
        xla::ThrowIfError(result.Start(
            py_client->ifrt_client(), max_num_parallel_copies, transfer_size,
            xla::ValueOrThrow(SocketAddress::Parse(address)),
-           transport_addresses));
+           transport_addresses, supports_pinned_allocator));
        return result;
      },
      nb::arg("client"), nb::arg("address") = SocketAddress().ToString(),
      nb::arg("transport_addresses") = std::vector<std::string>(),
      nb::arg("max_num_parallel_copies") = 8,
-     nb::arg("transfer_size") = 256 * 1024 * 1024);
+     nb::arg("transfer_size") = 256 * 1024 * 1024,
+     // Dual pinning not confirmed to be supported.
+     nb::arg("supports_pinned_allocator") = false);
 }

 }  // namespace aux

From 60212a390f9e86877943ef82bfe2fb6596eb32fd Mon Sep 17 00:00:00 2001
From: Blake Hechtman
Date: Thu, 8 May 2025 19:28:44 -0700
Subject: [PATCH 1093/1769] [JAX:ANN] Make a test more reliable by increasing
 the sample size

PiperOrigin-RevId: 756554539
---
 tests/ann_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/ann_test.py b/tests/ann_test.py
index 1d704c725c61..18bb51bec93b 100644
--- a/tests/ann_test.py
+++ b/tests/ann_test.py
@@ -179,7 +179,7 @@ def approx_max_k(qy, db):

   def test_vmap_after(self):
-    batch = 4
+    batch = 8
     qy_size = 128
     db_size = 1024
     feature_dim = 32

From e48fe360c66b88904159349bb6c6f429a89bb5dc Mon Sep 17 00:00:00 2001
From: Olli Lupton
Date: Fri, 9 May 2025 10:21:57 +0000
Subject: [PATCH 1094/1769] pytest: use importlib mode by default

Otherwise the default (prepend) mode will add `<repo>/tests` and `<repo>` to
the start of `sys.path`.
This is (has always been) fragile, because `<repo>` has `jax/` and `jaxlib/`
subdirectories, but it recently broke as `setuptools` 80 has a change in
behaviour for editable installations, with the result that if `jaxlib` is
installed editable then `import jaxlib` with `<repo>` in `sys.path` will try
to import from `<repo>/jaxlib` instead of the editable install location, and
fail.
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 03cc78a6dcbb..cf8002ffe610 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,7 @@ doctest_optionflags = [
   "NUMBER",
   "NORMALIZE_WHITESPACE"
 ]
-addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'"
+addopts = "--doctest-glob='*.rst' --ignore='examples/ffi' --import-mode=importlib"

 [tool.ruff]
 preview = true

From 244cb362bcd2ae95c8d0da6d45f2ba6a103012cf Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Fri, 9 May 2025 04:23:11 -0700
Subject: [PATCH 1095/1769] [Mosaic GPU] Use explicit load/store methods
 instead of __getitem__/__setitem__

We pretty much never use slicing in those methods and I want to add the
ability to load in other layouts than the default one (which means we will
need extra non-index arguments).

PiperOrigin-RevId: 756707034
---
 jax/_src/pallas/mosaic_gpu/lowering.py        |  2 +-
 .../mosaic/gpu/examples/matmul_blackwell.py   |  2 +-
 jax/experimental/mosaic/gpu/tcgen05.py        | 47 ++++++++-----
 tests/mosaic/gpu_test.py                      | 18 +++----
 4 files changed, 31 insertions(+), 38 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 55aca65c6a19..d50e39d5c3db 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1276,7 +1276,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree):
     if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape):
       raise NotImplementedError(
           "Only trivial indexing is supported for TMEM refs.")
-    return x_ref[:]
+    return x_ref.load()

   if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref):
     raise TypeError(f"Can only load from references (got {x_ref}).")
diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
index 6af394d00138..03363c1e365f 100644
--- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
+++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
@@ -162,7 +162,7 @@ def _mma_body(ki, accumulate):
     gpu.barrier()
   mma_done_barrier.wait(for_tensor_core=True)

-  acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
+  acc.load().astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
   mgpu.commit_shared()
   ctx.async_copy(
       src_ref=d_smem,
diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index 4726805f5b76..f4ea8e289f01 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -35,6 +35,15 @@
 TMEM_ROWS = 128
 TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46

+# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN.
+# The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT) +LAYOUT = fa.TiledLayout( + fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), + warp_dim=-8, + lane_dims=(-4, -3), + vector_dim=-1, +) + def create_instr_descriptor( m: int, @@ -582,7 +591,9 @@ def slice(self, *idxs): if any(is_squeezed): raise ValueError("TMEM can only be sliced, not indexed") match self.layout: - case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if r == TMEM_ROWS: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if ( + r == TMEM_ROWS + ): pass case _: raise NotImplementedError( @@ -607,18 +618,17 @@ def slice(self, *idxs): dtype=self.dtype, ) - def __getitem__(self, *idxs): + def load(self, layout: fa.TiledLayout = LAYOUT): i32 = ir.IntegerType.get_signless(32) - base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) - if any(is_squeezed): - raise ValueError("TMEM loads only support slicing") - if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: - raise NotImplementedError("Slicing of TMEM not impelmented yet") if self.shape[1] % 8: raise NotImplementedError if utils.bitwidth(self.dtype) not in {16, 32}: raise NotImplementedError(f"Unsupported dtype: {self.dtype}") - layout = _m128_layout(self.shape) + if layout != LAYOUT: + raise ValueError( + "TMEM loads can only produce results in the tcgen05 layout" + f" ({LAYOUT}), but got: {layout}" + ) regs_shape = layout.registers_shape(self.shape) match self.layout: case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if r == TMEM_ROWS: @@ -653,16 +663,7 @@ def __getitem__(self, *idxs): ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) - def __setitem__(self, idxs, value): - if not isinstance(idxs, tuple): - idxs = (idxs,) - base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) - if any(is_squeezed): - raise ValueError( - "TMEM stores don't support integer indexing (only slices allowed)" - ) - if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: - raise NotImplementedError("Slicing parts of TMEM not implemented yet") + def store(self, value): if self.shape[1] % 8: raise NotImplementedError if utils.bitwidth(self.dtype) not in {16, 32}: @@ -842,16 +843,6 @@ def _m128_layout(shape: tuple[int, ...]): return LAYOUT -# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. 
-# The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT) -LAYOUT = fa.TiledLayout( - fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), - vector_dim=-1, -) - - def commit_tmem(): void = ir.Type.parse("!llvm.void") llvm.inline_asm( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 50d60cae0080..8c26f64bd203 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -922,9 +922,9 @@ def kernel(ctx, input, output, scratch): barrier=barrier, ) barrier.wait() - tmem[:] = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT) + tmem.store(fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT)) tcgen05.commit_tmem() - tmem[:].store_tiled(smem, swizzle) + tmem.load().store_tiled(smem, swizzle) mgpu.commit_shared() ctx.async_copy( src_ref=smem, dst_ref=output, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling), @@ -964,7 +964,7 @@ def kernel(ctx, input, output, scratch): barrier=barrier, ) barrier.wait() - tmem[:] = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT) + tmem.store(fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT)) tcgen05.commit_tmem() tmem.slice(slice(None), slice(0, 8))._debug_print() @@ -1075,7 +1075,7 @@ def kernel(ctx, lhs, rhs, out, scratch): ) tcgen05.commit_arrive(barriers[2]) barriers[2].wait(for_tensor_core=True) - acc[:].store_untiled(out, optimized=False) + acc.load().store_untiled(out, optimized=False) x_shape = (k, m) if lhs_transpose else (m, k) x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) @@ -1144,8 +1144,10 @@ def kernel(ctx, lhs, rhs, out, scratch): ) barriers[0].wait() barriers[1].wait() - lhs_tmem[:] = fa.FragmentedArray.load_tiled( - lhs_smem, swizzle, layout=tcgen05.LAYOUT + lhs_tmem.store( + fa.FragmentedArray.load_tiled( + lhs_smem, swizzle, layout=tcgen05.LAYOUT + ) ) tcgen05.commit_tmem() with mgpu.single_thread(): @@ -1154,7 +1156,7 @@ def kernel(ctx, lhs, rhs, out, scratch): ) tcgen05.commit_arrive(barriers[2]) barriers[2].wait(for_tensor_core=True) - acc[:].store_untiled(out, optimized=False) + acc.load().store_untiled(out, optimized=False) x_shape = (m, k) x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) @@ -1246,7 +1248,7 @@ def kernel(ctx, lhs, rhs, out, scratch): tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) barriers[2].wait(for_tensor_core=True) m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) - acc[:].store_untiled(memref_slice(out, m_slice), optimized=False) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) in_finfo = jnp.finfo(in_jax_dtype) exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant From 9428fa04c6853bb44868c789b458e28517334cac Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 9 May 2025 06:14:18 -0700 Subject: [PATCH 1096/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0ca232a612deb80c64bc0e6a55f3b9bbd198b27f. PiperOrigin-RevId: 756737174 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b98430d24198..e0a7fc8db91a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "f08f9dc30cb4c0c638c244ba59a09722dcbedad5" -XLA_SHA256 = "8259f57fd5a475557f97f5365612646167d8c7c2d849202923a6fda88a218de2" +XLA_COMMIT = "0ca232a612deb80c64bc0e6a55f3b9bbd198b27f" +XLA_SHA256 = "3605a12ccb161443e3893f46e0cd05f96b271692d277185a258781989ea44c29" def repo(): tf_http_archive( From 60892e691fcad26f8e81900030fe1d498fff9f16 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 9 May 2025 06:51:20 -0700 Subject: [PATCH 1097/1769] [Pallas/TPU] * Ensure we wrap scheduler initialize/finalize with grid_env * Add support for basic swaps/gets to pull_block_spec PiperOrigin-RevId: 756747431 --- jax/_src/pallas/fuser/block_spec.py | 168 ++++++++++++++++++++++++++ jax/_src/pallas/mosaic/pipeline.py | 6 +- tests/pallas/fuser_block_spec_test.py | 95 +++++++++++++++ 3 files changed, 267 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index ba2c182014b5..3d4df549949c 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -29,6 +29,7 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import pjit +from jax._src import state from jax._src import tree_util from jax._src import util from jax._src.interpreters import partial_eval as pe @@ -451,6 +452,8 @@ def _remove_nones( _no_aval = object() def _get_block_aval(bs, aval): + if isinstance(aval, state.AbstractRef): + return aval if bs is pallas_core.no_block_spec or bs is None: return _no_aval return aval.update(shape=_remove_nones(bs.block_shape)) # pytype: disable=attribute-error @@ -1065,6 +1068,171 @@ def new_index_map(*args): len(ctx.avals_in) - 1 ) +@register_pull_block_spec_rule(state_primitives.swap_p) +def _swap_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **kwargs, +): + del ctx, kwargs + # The output and val block spec are the same. + return [block_spec, block_spec] + +@register_eval_rule(state_primitives.swap_p) +def _swap_eval_rule( + ctx: KernelEvalContext, + ref, + val, + *idx, + tree +): + indexers = tree_util.tree_unflatten(tree, idx) + ref_aval, _ = ctx.avals_in[:2] + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:]) + assert hasattr(ref_aval, 'shape') + if len(indexers) > 1: + raise NotImplementedError('swap not supported yet') + indexer_aval = indexers_avals[0] + for idx_aval, size in zip(indexer_aval.indices, ref_aval.shape, strict=True): + if not isinstance(idx_aval, indexing.Slice): + raise NotImplementedError('swap not supported yet') + if not isinstance(idx_aval.start, int): + raise NotImplementedError('swap not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('swap not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('swap not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('swap not supported yet') + if idx_aval.size != size: + raise NotImplementedError('swap not supported yet') + # We have a pure slice so now we can just re-index the ref according to the + # block indices. 
+  block_spec = ctx.out_block_specs[0]
+  block_idx = ctx.get_out_block_indices()[0]
+
+  def _slice(i, b):
+    if not isinstance(b, int):
+      raise NotImplementedError('swap not supported yet')
+    return i if b is None else indexing.ds(i * b, b)
+
+  indexer = tuple(
+      _slice(i, b) for i, b in zip(block_idx, block_spec.block_shape,
+                                   strict=True)
+  )
+  return ref.swap(val, idx=indexer)
+
+
+@register_pull_block_spec_rule(state_primitives.get_p)
+def _get_pull_rule(
+    ctx: PullRuleContext,
+    block_spec: pallas_core.BlockSpec,
+    *,
+    tree
+):
+  ref_aval = ctx.avals_in[0]
+  assert hasattr(ref_aval, 'shape')
+  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
+  if len(indexers_avals) > 1:
+    raise NotImplementedError('get not supported yet')
+  indexer_aval = indexers_avals[0]
+  block_shape_iter = iter(block_spec.block_shape)
+  block_shape = []
+  if not all(
+      isinstance(bd, (int, pallas_core.Blocked, pallas_core.Squeezed,
+                      type(None)))
+      for bd in block_spec.block_shape
+  ):
+    raise NotImplementedError('get not supported yet')
+  for idx_aval, size in zip(indexer_aval.indices, ref_aval.shape, strict=True):
+    if not isinstance(idx_aval, indexing.Slice):
+      assert hasattr(idx_aval, 'shape') and not idx_aval.shape
+      block_shape.append(pallas_core.Squeezed())
+      continue
+    if not isinstance(idx_aval.start, int):
+      raise NotImplementedError('get not supported yet')
+    if not isinstance(idx_aval.size, int):
+      raise NotImplementedError('get not supported yet')
+    if idx_aval.stride != 1:
+      raise NotImplementedError('get not supported yet')
+    if idx_aval.start != 0:
+      raise NotImplementedError('get not supported yet')
+    if idx_aval.size != size:
+      raise NotImplementedError('get not supported yet')
+    bd = next(block_shape_iter)
+    block_shape.append(_block_size(bd))
+  assert next(block_shape_iter, None) is None
+
+  def new_index_map(*args):
+    idx = block_spec.index_map(*args)
+    idx_iter = iter(idx)
+    indices = tuple(
+        0
+        if (bd is None or isinstance(bd, pallas_core.Squeezed))
+        else next(idx_iter)
+        for bd in block_shape
+    )
+    assert next(idx_iter, None) is None
+    return indices
+
+  block_spec = pallas_core.BlockSpec(block_shape, new_index_map)
+  return [block_spec] + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1)
+
+
+@register_eval_rule(state_primitives.get_p)
+def _get_eval_rule(
+    ctx: KernelEvalContext,
+    ref,
+    *idx,
+    tree
+):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  ref_aval = ctx.avals_in[0]
+  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
+  ref_block_spec = ctx.in_block_specs[0]
+  assert hasattr(ref_aval, 'shape')
+  if len(indexers) > 1:
+    raise NotImplementedError('get not supported yet')
+  indexer = indexers[0]
+  indexer_aval = indexers_avals[0]
+  block_indexer = []
+
+  def _slice(i, b):
+    match b:
+      case int():
+        return indexing.ds(i * b, b)
+      case pallas_core.Blocked(bs):
+        return indexing.ds(i * bs, bs)
+      case pallas_core.Squeezed() | None:
+        return i
+      case _:
+        raise NotImplementedError('get not supported yet')
+
+  if ref_block_spec is pallas_core.no_block_spec:
+    # Short-circuit if the ref is not blocked.
+ return state_primitives.get_p.bind(ref, *idx, tree=tree) + block_idx_iter = iter(ctx.get_out_block_indices()[0]) + for idx_aval, size, idx, bd in zip( + indexer_aval.indices, + ref_aval.shape, + indexer.indices, + ref_block_spec.block_shape, + strict=True, + ): + if not isinstance(idx_aval, indexing.Slice): + assert hasattr(idx_aval, 'shape') and not idx_aval.shape, idx_aval + assert bd is None or isinstance(bd, pallas_core.Squeezed) + block_indexer.append(idx) + continue + if not isinstance(idx_aval.start, int): + raise NotImplementedError('get not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('get not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('get not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('get not supported yet') + if idx_aval.size != size: + raise NotImplementedError('get not supported yet') + bidx = next(block_idx_iter) + block_indexer.append(_slice(bidx, bd)) + assert next(block_idx_iter, None) is None + return ref.get(idx=tuple(block_indexer)) @register_eval_rule(lax.concatenate_p) def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension): diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index df7be297c9e8..4ef22179260b 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -1369,7 +1369,8 @@ def _(): initial_indices = (0,) * len(grid) scheduler = make_scheduler(0, initial_indices) brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) - map_brefs(scheduler.initialize, brefs, refs, schedule) + with scheduler.grid_env(): + map_brefs(scheduler.initialize, brefs, refs, schedule) # pipeline loop next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices) @@ -1378,7 +1379,8 @@ def _(): final_indices = _prev_index(next_indices, grid) scheduler = make_scheduler(num_steps - 1, final_indices) brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) - map_brefs(scheduler.finalize, brefs, refs, schedule) + with scheduler.grid_env(): + map_brefs(scheduler.finalize, brefs, refs, schedule) return pipeline diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index 665cdfb1dd6b..f7e70ec1d708 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -761,6 +761,101 @@ def f(x): y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) np.testing.assert_array_equal(y, x.reshape((256, 1024))) + def test_basic_swap(self): + value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 + x = jnp.zeros((256, 512), dtype=jnp.int32) + def outer(refs): + ref, y_ref = refs + def f(x): + return ref.swap(x) + in_type = jax.ShapeDtypeStruct((512, 1024), jnp.int32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertLen(new_values, 1) # Captures Ref + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertLen(value_block_specs, 1) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1)) + + y_ref[...] 
= kernel_fn((0, 1, 1), scalar_prefetch_values, (ref,), x) + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[:256, 512:1024]) + + def test_basic_get(self): + value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 + def outer(refs): + ref, y_ref = refs + def f(): + return ref.get() + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (), _ = ( + block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + ) + y_ref[...] = kernel_fn((0, 1, 1), ()) + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[:256, 512:1024]) + + def test_get_with_squeezed_block_spec(self): + value = jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) * 2 + def outer(refs): + ref, y_ref = refs + def f(): + return ref.get() + + block_spec = pl.BlockSpec((pl.Squeezed(), 256, 512), lambda i, j, k: (j, i, k)) + kernel_fn, (), _ = ( + block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + ) + y_ref[...] = kernel_fn((0, 3, 1), ()) + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[3, :256, 512:1024]) + + def test_get_with_squeezed_indexer(self): + value = jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) * 2 + def outer(refs): + ref, y_ref = refs + def f(): + return ref[3] + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (), _ = ( + block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + ) + y_ref[...] = kernel_fn((0, 2, 1), ()) + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[3, :256, 512:1024]) + class PullBlockSpecHOPTest(jtu.JaxTestCase): From d551839d1531fd3affcd82e936f87f971850ca09 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 9 May 2025 09:09:20 -0700 Subject: [PATCH 1098/1769] Split custom_* tests out of api_test into new target. The `tests/api_test.py` file is getting a bit unwieldy, and this seems like a reasonable place for a split. These customization APIs have a large testing surface area, and are conceptually different enough from the rest of the API tests that it seems defensible for them to live in a separate target. 
PiperOrigin-RevId: 756790365 --- tests/BUILD | 10 + tests/api_test.py | 4581 ------------------------------------- tests/custom_api_test.py | 4625 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 4635 insertions(+), 4581 deletions(-) create mode 100644 tests/custom_api_test.py diff --git a/tests/BUILD b/tests/BUILD index 20a37ba746de..4c6369bee5de 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -41,6 +41,15 @@ jax_multiplatform_test( ], ) +jax_multiplatform_test( + name = "custom_api_test", + srcs = ["custom_api_test.py"], + shard_count = 10, + deps = [ + "//jax:experimental", + ], +) + jax_multiplatform_test( name = "debug_info_test", srcs = ["debug_info_test.py"], @@ -1697,6 +1706,7 @@ jax_py_test( exports_files( [ "api_test.py", + "custom_api_test.py", "array_test.py", "cache_key_test.py", "colocated_python_test.py", diff --git a/tests/api_test.py b/tests/api_test.py index 48ce56ee935c..6e55e732151d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -16,7 +16,6 @@ import collections import collections.abc -from collections.abc import Callable import concurrent.futures from contextlib import contextmanager import copy @@ -34,7 +33,6 @@ import re import subprocess import sys -import textwrap import traceback import types from typing import NamedTuple @@ -44,7 +42,6 @@ from absl import logging from absl.testing import absltest, parameterized import jax -from jax import custom_derivatives as custom_derivatives_public from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit from jax import lax from jax import tree_util @@ -52,8 +49,6 @@ from jax._src import array from jax._src import config from jax._src import core -from jax._src import custom_derivatives -from jax._src import deprecations from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge @@ -67,10 +62,6 @@ from jax._src.lib import _jax import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint -import jax.custom_batching -import jax.custom_derivatives -import jax.custom_transpose -import jax.experimental.custom_dce from jax.errors import (UnexpectedTracerError, TracerIntegerConversionError, ConcretizationTypeError, TracerBoolConversionError) from jax.experimental import pjit @@ -7062,4578 +7053,6 @@ def f(x1, x2): self.assert_dce_result(jaxpr, [True, False], [True, True], 5) -class CustomJVPTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - x = 3. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(api.jvp(f, (x,), (1.,)), - (jnp.sin(x), 2 * jnp.cos(x))) - self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) - - def test_invariance(self): - @jax.custom_jvp - def f(x): - return jnp.cos(2 * x) / 2. - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return (f(x), 3 * g) - f.defjvp(f_jvp) - def f2(x): - y, _ = api.jvp(f, (x,), (x,)) - return y - def f3(x): - y, _ = api.jvp(f2, (x,), (x,)) - return y - x = 1. 
- self.assertAllClose(api.jvp(f, (x,), (x,)), - api.jvp(f2, (x,), (x,)), - check_dtypes=False) - self.assertAllClose(api.jvp(f, (x,), (x,)), - api.jvp(f3, (x,), (x,)), - check_dtypes=False) - - def test_python_control_flow(self): - @jax.custom_jvp - def f(x): - if x > 0: - return jnp.sin(x) - else: - return jnp.cos(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - if x > 0: - return f(x), 2 * g - else: - return f(x), 3 * g - f.defjvp(f_jvp) - x = 2. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(f(-x), jnp.cos(-x)) - self.assertAllClose(api.jvp(f, (x,), (1.,)), - (jnp.sin(x), 2.), - check_dtypes=False) - self.assertAllClose(api.jvp(f, (-x,), (1.,)), - (jnp.cos(-x), 3.), - check_dtypes=False) - self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False) - self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False) - - def test_vmap(self): - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - assert jnp.ndim(x) == jnp.ndim(g) == 0 - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - x = jnp.arange(3.) - xx = jnp.arange(6.).reshape(2, 3) - - # vmap of f - self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) - self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) - - # vmap of jvp of f - self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x), - (jnp.sin(x), 2 * jnp.cos(x) * x)) - self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) - - # jvp of vmap of f - self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)), - (jnp.sin(x), 2 * jnp.cos(x) * x)) - self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) - - # vmap of jvp of vmap of f - self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) - - def test_jit(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - x = 3. - - # jit - self.assertAllClose(api.jit(f)(x), jnp.sin(x)) - self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) - - # jit of jvp - self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=False) - - # jvp of jit - self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=False) - - def test_pytrees(self): - @jax.custom_jvp - def f(x): - return {'b': jnp.sin(x['a'])} - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']} - f.defjvp(f_jvp) - x = {'a': 3.} - self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) - self.assertAllClose(api.jvp(f, (x,), (x,)), - ({'b': jnp.sin(x['a'])}, - {'b': 2 * jnp.cos(x['a']) * x['a']}), - check_dtypes=False) - - def test_kwargs(self): - # from https://github.com/jax-ml/jax/issues/1938 - @jax.custom_jvp - def my_fun(x, y, c=1.): - return c * (x + y) - def my_jvp(primals, tangents): - x, y, c = primals - t_x, t_y, t_c = tangents - return my_fun(x, y, c), t_c - my_fun.defjvp(my_jvp) - f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() - f(10., 5.) 
# doesn't crash - api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash - - def test_initial_style(self): - @jax.custom_jvp - def f(x): - return 3 * x - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(foo)(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(foo))(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(foo))(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(foo))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(api.jit(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(api.grad(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(api.grad(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap(self): - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.vmap(foo)(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.vmap(api.jit(foo))(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.vmap(foo))(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap_with_collective(self): - - @jax.custom_jvp - def f(x): - return lax.psum(x, 'foo') - - @f.defjvp - def f_jvp(xs, ts): - x, = xs - t, = ts - return lax.psum(x, 'foo'), t - - def g(x): - jaxpr = api.make_jaxpr(f)(x) - return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] - - v = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), - out_axes=None)(jnp.arange(4.), 2.) - self.assertAllClose(v, 8.) - - def test_closed_over_tracers_error_message(self): - def f(x): - @jax.custom_jvp - def g(y): - return x + y - def g_jvp(primals, tangents): - return g(x), 2 * primals[0] - g.defjvp(g_jvp) - return g(1.) 
- - self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) - self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) - - def test_nondiff_arg(self): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def app(f, x): - return f(x) - def app_jvp(f, primals, tangents): - (x,), (t,) = primals, tangents - return app(f, x), 3 * t - app.defjvp(app_jvp) - - ans = app(lambda x: 2 * x, 1) - expected = 2 - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,)) - expected = (2., 3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_jit_tracer(self): - # This test would pass with "final-style" JIT tracing, but that was - # misleading: it doesn't work with "initial-style" staging, i.e. control - # flow primitives like jax.lax.scan or even pjit. The behavior isn't very - # useful either: instead of using nondiff_argnums here, a user can just pass - # such inputs as ordinary arguments, and ignore the corresponding tangents. - # Then nondiff_argnums can be reserved for (1) non jaxtype data (like a - # string- or callable-valued argument which parameterizes the function or - # rule) or (2) static data (e.g. integers which parameterize shapes). - raise unittest.SkipTest("behavior no longer supported") - - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_jvp(x, primals, tangents): - (y,), (t_y,) = primals, tangents - return f(x, y), 5 * t_y - f.defjvp(f_jvp) - - @jit - def g(x, y): - return f(x, y) - - ans = api.jvp(lambda y: g(2., y), (3.,), (1.,)) - expected = (6., 5.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_vmap_tracer(self): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_jvp(x, primals, tangents): - (y,), (t_y,) = primals, tangents - return f(x, y), 5 * t_y - f.defjvp(f_jvp) - - g = jax.vmap(f) - - ans = api.jvp(lambda y: g(jnp.array([2.]), y), - (jnp.array([3.]),), (jnp.array([1.]),)) - expected = (jnp.array([6.]), jnp.array([5.])) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_hiding_jvp_tracer(self): - def f(x): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def g(h, x): - return h(x) - @g.defjvp - def g_jvp(h, primals, tangents): - x, = primals - t, = tangents - return g(h, x), 2. * t - h = lambda y: x + y # capture x - return g(h, x) - - with self.assertRaises(UnexpectedTracerError): - api.jvp(f, (2.,), (1.,)) - - def test_vmap_axes(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_pmap(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_missing_jvp_rule_error_message(self): - @jax.custom_jvp - def foo(x): - return x ** 2 - - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: foo(2)) - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: api.jvp(foo, (2.,), (1.,))) - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: api.grad(foo)(2.)) - - def test_jvp_rule_inconsistent_pytree_structures_error_message(self): - @jax.custom_jvp - def f(x): - return (x**2,) - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), [2 * x * t, x] - - f(2.) 
# doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule foo_jvp for function f " - "must produce primal and tangent outputs " - "with equal container (pytree) structures, but got " - "{} and {} respectively.".format( - jax.tree.structure((1,)), - jax.tree.structure([1, 2])) - ), - lambda: api.jvp(f, (2.,), (1.,))) - - def test_primal_tangent_aval_disagreement_error_message(self): - @jax.custom_jvp - def f(x): - return x ** 2 - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), jnp.reshape(t, (1,)) - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule must produce primal and tangent outputs " - "with corresponding shapes and dtypes. " - "Expected float32[] (tangent type of float32[]) but got float32[1]."), - lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) - - - def test_jvp_rule_doesnt_return_pair_error_message(self): - # https://github.com/jax-ml/jax/issues/2516 - - @jax.custom_jvp - def f(x): - return x ** 2 - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return t - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule foo_jvp for function f " - "must produce a pair (list or tuple of length two) " - "representing primal and tangent outputs, but got 1.0"), - lambda: api.jvp(f, (2.,), (1.,))) - - def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self): - # https://github.com/lucidrains/flash-attention-jax/issues/7 - - def scan_apply(f, x): - y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) - return y - - @jax.custom_jvp - def f(x): - return x - - @f.defjvp - def f_jvp(primals, tangents): - (x,), (xdot,) = primals, tangents - return (x, x), (xdot, xdot) - - x = jnp.float32(1.) - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule f_jvp for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal in value to the output of the " - "custom_jvp-decorated function f, and in particular of the " - "same container/pytree structure), but instead the JVP rule " - "output's first element had container/pytree structure:\n" - " (float32[], float32[])\n" - "while the custom_jvp-decorated function f had output " - "container/pytree structure:\n" - " float32[]." 
- ), - lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) - - @f.defjvp - def f_jvp2(primals, tangents): - (x,), (xdot,) = primals, tangents - return jnp.zeros((3, *x.shape), x.dtype), xdot - - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule f_jvp2 for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal in value to the output of the " - "custom_jvp-decorated function f, and in particular " - "with leaves of the same shape/dtype), but instead the JVP rule " - "output's first element had shapes/dtypes of:\n" - " float32[3]\n" - "while the custom_jvp-decorated function f had output shapes/dtypes" - " of:\n" - " float32[]" - ), - lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) - - def test_multiple_rule_invocations(self): - @jax.custom_jvp - def expit(x): - return 1 / (1 + lax.exp(-x)) - - @expit.defjvp - def _expit_jvp(primals, tangents): - (x,), (t,) = primals, tangents - ans = expit(x) - t_out = t * ans * (1 - ans) - return ans, t_out - - def scanned_fun(c, _): - return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None - - def foo(x): - zero = jnp.zeros_like(x) - c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) - return c[-1] - - # just make sure these don't crash - foo(3.) - grad(foo)(3.) - grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.)) - - def test_hard_stuff(self): - arr = jnp.ones((5, 2, 2)) - api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash - - def test_hard_stuff2(self): - @jax.custom_jvp - def f(x): - return np.zeros(x.shape, x.dtype) - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), t - - # don't crash - jax.jit(jax.vmap(f))(jnp.arange(3.)) - jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) - jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) - jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) - jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)) - - def test_hard_stuff3(self): - @jax.custom_jvp - def relu(x): - return jnp.maximum(x, 0) - - @relu.defjvp - def _relu_jvp(primals, tangents): - x, = primals - t, = tangents - return relu(x), lax.select(x > 0, t, lax.full_like(t, 0)) - - def scanned_fun(c, _): - return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None - - def f(x): - zero = jnp.zeros_like(x) - c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) - return c[-1] - - # don't crash - jax.jit(jax.vmap(f))(jnp.arange(3.)) - jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) - jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) - jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) - jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),)) - - def test_eval_shape(self): - @jax.custom_jvp - def expit(x): - return 1 / (1 + lax.exp(-x)) - - @expit.defjvp - def _expit_jvp(primals, tangents): - (x,), (t,) = primals, tangents - ans = expit(x) - t_out = t * ans * (1 - ans) - return ans, t_out - - # don't crash - api.eval_shape(expit, jnp.ones((2, 3))) - api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) - - def test_jaxpr_zeros(self): - # from https://github.com/jax-ml/jax/issues/2657 - @jax.custom_jvp - def f(A, b): - return A @ b - - def f_jvp(primals, tangents): - A, b = primals - dA, db = tangents - z = f(A, b) - dz = A @ db + dA @ b - return z, dz - - f.defjvp(f_jvp) - - def experiment(theta): - def step(q, _): - z = f(jnp.eye(3), jnp.ones(3) * theta) - q += z[0] - return q, q 
- - q = 0. - q, _ = lax.scan(step, q, None, 4) - return q - - grad(experiment)(1.) # doesn't crash - - def test_linear_in_scan(self): - @jax.custom_jvp - def f(x): - return -x - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - return f(x), f(x_dot) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(foo)(3.) - expected = -1. - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_jvps_first_rule_is_none(self): - # https://github.com/jax-ml/jax/issues/3389 - @jax.custom_jvp - def f(x, y): - return x ** 2 * y - - f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot) - ans = grad(f, 1)(2., 3.) # doesn't crash - expected = 12. - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_concurrent_initial_style(self): - # https://github.com/jax-ml/jax/issues/3843 - def unroll(param, sequence): - def scan_f(prev_state, inputs): - return prev_state, jax.nn.sigmoid(param * inputs) - return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1]) - - def run(): - return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0])) - - expected = run() - - # we just don't want this to crash - n_workers = 2 - with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e: - futures = [] - for _ in range(n_workers): - futures.append(e.submit(run)) - results = [f.result() for f in futures] - for ans in results: - self.assertAllClose(ans, expected) - - def test_nondiff_argnums_vmap_tracer(self): - # https://github.com/jax-ml/jax/issues/3964 - @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) - def sample(shape, param, seed): - return jax.random.uniform(key=seed, shape=shape, minval=param) - - @sample.defjvp - def sample_jvp(shape, seed, primals, tangents): - param, = primals - dparam, = tangents - dparam = jnp.broadcast_to(dparam, shape) - samples = sample(shape, param, seed) - return samples, samples * dparam # dummy jvp for proof of concept - - # check these don't crash - jax.vmap(lambda seed: sample((2,3), 1., seed))( - jax.random.split(jax.random.key(1), 10)) - jax.jvp(lambda x: sample((2, 3), x, jax.random.key(1)), - (1.,), (1.,)) - - def test_fun_with_nested_calls_2(self): - def call(f, *args): - f = jax.custom_jvp(f) - f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents))) - return f(*args) - - def fun_with_nested_calls_2(x): - def bar(y): - def baz(w): - q = call(lambda x: y, x) - q = q + call(lambda: y) - q = q + call(lambda y: w + y, y) - q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q - return q - return api.jit(baz)(x) - return call(bar, x) - - # test these don't crash - self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.), - fun_with_nested_calls_2(3.)) - api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) - - def test_closure_with_vmap(self): - # https://github.com/jax-ml/jax/issues/3822 - alpha = np.float32(2.) 
-
-    def sample(seed):
-      @jax.custom_jvp
-      def f(alpha):
-        return jax.random.gamma(seed, alpha, shape=[])
-
-      @f.defjvp
-      def f_jvp(primal, tangent):
-        alpha = primal
-        dalpha = tangent
-        sample = f(alpha)
-        partial_alpha = lax.random_gamma_grad(alpha, sample)
-        return sample, partial_alpha * dalpha
-      return f(alpha)
-
-    api.vmap(sample)(jax.random.split(jax.random.key(1), 3))  # don't crash
-
-  def test_closure_with_vmap2(self):
-    # https://github.com/jax-ml/jax/issues/8783
-    def h(z):
-      def f(x):
-        @jax.custom_jvp
-        def g(y):
-          return x * y
-
-        # NOTE: rule closes over vmap tracer
-        @g.defjvp
-        def g_jvp(primals, tangents):
-          (y,), (ydot,) = primals, tangents
-          return x * y, x * ydot
-
-        return g(z)  # NOTE: no vmapped arg
-
-      return jax.vmap(f)(jnp.arange(3., dtype='float32'))
-
-    primals, tangents = jax.jvp(h, (jnp.float32(1.),), (jnp.float32(2.),))
-    self.assertAllClose(primals , jnp.arange(3., dtype='float32'))
-    self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32'))
-
-  def test_float0(self):
-    scalar_float0 = jnp.zeros((), dtype=float0)
-    @jax.custom_jvp
-    def f(x, y):
-      return x, y
-    def f_jvp(primals, _):
-      x, y = primals
-      return (x, y), (2., custom_derivatives_public.zero_from_primal(y))
-    f.defjvp(f_jvp)
-
-    primals = (2., 3)
-    tangents = (np.ones(()), scalar_float0)
-    expected_tangents = (2., scalar_float0)
-    self.assertAllClose(api.jvp(f, primals, tangents),
-                        (primals, expected_tangents))
-
-  def test_float0_initial_style(self):
-    scalar_float0 = jnp.zeros((), dtype=float0)
-    @jax.custom_jvp
-    def f(x, y):
-      return x, y
-    def f_jvp(primals, _):
-      x, y = primals
-      return (x, y), (2., custom_derivatives_public.zero_from_primal(y))
-    f.defjvp(f_jvp)
-
-    def foo(x, y):
-      out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1)
-      return out
-
-    primals = (2., 3)
-    tangents = (np.ones(()), scalar_float0)
-    expected_tangents = (2., scalar_float0)
-
-    self.assertAllClose(api.jvp(foo, primals, tangents),
-                        (primals, expected_tangents))
-
-  def test_remat(self):
-    @jax.custom_jvp
-    def f(x):
-      return jnp.sin(x)
-    def f_jvp(primals, tangents):
-      x, = primals
-      g, = tangents
-      return f(x), 2 * jnp.cos(x) * g
-    f.defjvp(f_jvp)
-
-    @jax.remat
-    def g(x):
-      return f(f(x))
-
-    ans = g(2.)
-    expected = np.sin(np.sin(2.))
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(g)(2.)
-    expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_remat_higher_order(self):
-    @jax.custom_jvp
-    def f(x):
-      return jnp.sin(x)
-    def f_jvp(primals, tangents):
-      x, = primals
-      g, = tangents
-      return f(x), 2 * jnp.cos(x) * g
-    f.defjvp(f_jvp)
-
-    def g(x):
-      return f(f(x))
-
-    ans = api.grad(api.grad(new_checkpoint(g)))(2.)
-    expected = api.grad(api.grad(g))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(new_checkpoint(api.grad(g)))(2.)
-    expected = api.grad(api.grad(g))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(api.grad(api.grad(new_checkpoint(g))))(2.)
-    expected = api.grad(api.grad(api.grad(g)))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_initial_style_vmap_2(self):
-    # This is like test_initial_style_vmap except the primal function closes
-    # over an array constant.
-    y = jnp.arange(1., 4.)
-
-    @jax.custom_jvp
-    def f(x):
-      assert jnp.ndim(x) == 0
-      return 3 * x * jnp.sum(y)
-    def f_jvp(primals, tangents):
-      x, = primals
-      g, = tangents
-      return f(x), 2 * g
-    f.defjvp(f_jvp)
-
-    def foo(x):
-      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
-      return out
-
-    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3))
-    expected = 2. * jnp.ones(3)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3))
-    expected = 2. * jnp.ones(3)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3))
-    expected = 2. * jnp.ones(3)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
-    expected = 2. * jnp.ones(3)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
-    expected = 2. * jnp.ones(3)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_custom_jvp_vmap_broadcasting_interaction(self):
-    # https://github.com/jax-ml/jax/issues/6452
-    def f2(y, z):
-      v1 = z
-      v2 = jnp.sum(y) + z
-      return jnp.logaddexp(v1, v2)
-
-    def f1(y, z):
-      v = api.vmap(lambda _y: f2(_y, z))(y)
-      return jnp.sum(v)
-
-    y = jnp.ones((3, 2))
-    f = lambda z: f1(y, z)
-    z = 0.1
-    val, g = api.value_and_grad(f)(z)
-    self.assertEqual(val.shape, ())
-    self.assertEqual(g.shape, ())
-
-  def test_custom_jvp_vmap_broadcasting_interaction_2(self):
-    # https://github.com/jax-ml/jax/issues/5849
-    @jax.custom_jvp
-    def transform(box, R):
-      if jnp.isscalar(box) or box.size == 1:
-        return R * box
-      elif box.ndim == 2:
-        return jnp.einsum('ij,j->i', box, R)
-      raise ValueError()
-
-    @transform.defjvp
-    def transform_jvp(primals, tangents):
-      box, R = primals
-      dbox, dR = tangents
-      return (transform(box, R), dR + transform(dbox, R))
-
-    def periodic_general(box):
-      def displacement_fn(Ra, Rb, **kwargs):
-        _box = kwargs.get('box', box)
-        return transform(_box, Ra - Rb)
-
-      return displacement_fn
-
-    N = 250
-
-    scalar_box = 1.0
-    displacement = periodic_general(scalar_box)
-
-    key = jax.random.key(0)
-    R = jax.random.uniform(key, (N, 2))
-
-    def energy_fn(box):
-      d = partial(displacement, box=box)
-      d = api.vmap(api.vmap(d, (None, 0)), (0, None))
-      return jnp.sum(d(R, R) ** 2)
-
-    self.assertEqual(grad(energy_fn)(scalar_box).shape, ())
-
-  def test_custom_jvp_implicit_broadcasting(self):
-    # https://github.com/jax-ml/jax/issues/6357
-    if config.enable_x64.value:
-      raise unittest.SkipTest("test only applies when x64 is disabled")
-
-    @jax.custom_jvp
-    def projection_unit_simplex(x: jax.Array) -> jax.Array:
-      """Projection onto the unit simplex."""
-      s = 1.0
-      n_features = x.shape[0]
-      u = jnp.sort(x)[::-1]
-      cssv = jnp.cumsum(u) - s
-      ind = jnp.arange(n_features, dtype=x.dtype) + 1
-      cond = u - cssv / ind > 0
-      idx = jnp.count_nonzero(cond)
-      threshold = cssv[idx - 1] / idx.astype(x.dtype)
-      return jax.nn.relu(x - threshold)
-
-
-    @projection_unit_simplex.defjvp
-    def projection_unit_simplex_jvp(primals, tangents):
-      x, = primals
-      x_dot, = tangents
-      primal_out = projection_unit_simplex(x)
-      supp = (primal_out > 0).astype(x_dot.dtype)
-      card = jnp.count_nonzero(supp).astype(x_dot.dtype)
-      tangent_out = supp * x_dot - (jnp.dot(supp, x_dot) / card) * supp
-      return primal_out, tangent_out
-
-    rng = self.rng()
-    x = rng.rand(5).astype(np.float32)
-
-    J_rev = jax.jacrev(projection_unit_simplex)(x)
-    J_fwd = jax.jacfwd(projection_unit_simplex)(x)
-
-    p = projection_unit_simplex(x)
-    support = (p > 0).astype(jnp.float32)
-    cardinality = jnp.count_nonzero(support).astype(support.dtype)
-    J_true = jnp.diag(support) - jnp.outer(support, support) / cardinality
-    self.assertAllClose(J_true, J_fwd)
-    self.assertAllClose(J_true, J_rev)
-
-    proj = jax.vmap(projection_unit_simplex)
-
-    def fun(X):
-      return jnp.sum(proj(X) ** 2)
-
-    rng = self.rng()
-    X = rng.rand(4, 5).astype(np.float32)
-    U = rng.rand(4, 5)
-    U /= np.sqrt(np.sum(U ** 2))
-    U = U.astype(np.float32)
-
-    eps = 1e-3
-    dir_deriv_num = (fun(X + eps * U) - fun(X - eps * U)) / (2 * eps)
-    dir_deriv = jnp.vdot(jax.grad(fun)(X), U)
-    self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3)
-
-  def test_vmap_inside_defjvp(self):
-    # https://github.com/jax-ml/jax/issues/3201
-    seed = 47
-    key = jax.random.key(seed)
-    mat = jax.random.normal(key, (2, 3))
-
-    @jax.custom_jvp
-    def f(mat, aux):
-      num_rows, num_cols = mat.shape
-      return jnp.ones((num_rows, 1)) / num_cols
-
-    @f.defjvp
-    def f_jvp(primals, tangents):
-      mat, aux = primals
-      vec, _ = tangents
-      output = f(*primals)
-      num_rows, num_cols = mat.shape
-      size = num_rows * num_cols
-      # -----
-      bd_mat = mat.reshape(1, 1, num_rows, num_cols)
-      bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols))
-      bd_mat = bd_mat.reshape(size, num_rows, num_cols)
-      # -----
-      rowsum = jnp.sum(mat, axis=1, keepdims=True)
-      colsum = jnp.sum(mat, axis=0, keepdims=True)
-      bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows))
-      bd_colsum = jnp.tile(colsum, reps=(num_cols, 1))
-      # -----
-      bd_vec = vec.reshape(size, 1)
-      # -----
-      def operate(mx, val):
-        buf = 0
-        for i in range(2):
-          buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i)
-        buf = jnp.matmul(bd_rowsum, buf)
-        return buf * val[None, :]
-      # -----
-      # Vertorizing will raise shape error
-      bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec)
-      # -----
-      bd_buf = bd_buf / aux
-      jvp = jnp.sum(bd_buf, axis=0)
-      jvp = jnp.mean(jvp, axis=1, keepdims=True)
-      # -----
-      # JVP ends successfully, but still raise an error
-      return (output, jvp)
-
-    jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5)  # doesn't crash
-
-  def test_custom_jvp_unbroadcasting(self):
-    # https://github.com/jax-ml/jax/issues/3056
-    a = jnp.array([1., 1.])
-
-    @jax.custom_jvp
-    def f(x):
-      return a * x
-
-    @f.defjvp
-    def f_jvp(primals, tangents):
-      x, = primals
-      dx, = tangents
-      return a * x, a * dx
-
-    shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape
-    self.assertEqual(shape, ())
-
-  def test_maybe_perturbed_internal_helper_function(self):
-    # This is a unit test for an internal API. We include it so as not to
-    # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of
-    # this helper function, see https://github.com/jax-ml/jax/issues/6415.
-    def f(x):
-      def g(y, _):
-        z = y * x
-        self.assertTrue(custom_derivatives._maybe_perturbed(z))
-        return y, None
-      g(1, None)
-      return lax.scan(g, 1, xs=None, length=1)[0]
-
-    jax.jvp(f, (1.0,), (1.0,))  # assertions inside f
-
-  def test_maybe_perturbed_int_regression(self):
-    # see https://github.com/jax-ml/jax/discussions/9951
-
-    @jax.jit
-    def f():
-      x = jnp.array(1)
-      _, aux_args = custom_derivatives.closure_convert(lambda: x)
-      self.assertEmpty(aux_args)
-    f()
-
-  def test_sinc_constant_function_batching(self):
-    # https://github.com/jax-ml/jax/pull/10756
-    batch_data = jnp.arange(15.).reshape(5, 3)
-
-    @jax.vmap
-    def f(x):
-      return jax.lax.map(jnp.sinc, x)
-    g = lambda param: f(param * batch_data).sum()
-
-    @jax.vmap
-    def f_ref(x):
-      return jnp.stack([jnp.sinc(x_) for x_ in x])
-    g_ref = lambda param: f_ref(param * batch_data).sum()
-
-    grad     = jax.grad(g    )(0.1)  # doesn't crash
-    grad_ref = jax.grad(g_ref)(0.1)
-    self.assertAllClose(grad, grad_ref, check_dtypes=False)
-
-  @parameterized.named_parameters(
-      ('jit_vmap', True, True),
-      ('jit', True, False),
-      ('vmap', False, True),
-      ('', False, False),
-  )
-  def test_symbolic_zero_custom_jvp(self, maybe_jit, maybe_vmap):
-    def f(static_scalar, static_array, dyn_scalar, dyn_array):
-      out1 = static_scalar + dyn_scalar
-      out2 = static_array + dyn_array
-      return out1, out2
-
-    def _pack(x):
-      return lax.broadcast(x, (1,))
-
-    def _unpack(x):
-      (x,) = x
-      return x
-
-    def _vmap(fun):
-      def _fun(*args):
-        args = jax.tree.map(_pack, args)
-        out = jax.vmap(fun)(*args)
-        out = jax.tree.map(_unpack, out)
-        return out
-      return _fun
-
-    f = jax.custom_jvp(f)
-
-    @partial(f.defjvp, symbolic_zeros=True)
-    def f_jvp(primals, tangents):
-      static_scalar, *_ = primals
-      t_static, t_static_arr, t_dyn_scalar, t_dyn_array = tangents
-      self.assertIs(type(t_static) , custom_derivatives_public.SymbolicZero)
-      self.assertIs(type(t_static_arr), custom_derivatives_public.SymbolicZero)
-      self.assertEqual(t_static.shape, ())
-      self.assertEqual(t_static_arr.shape, (2,))
-      return f(*primals), (static_scalar + 90, t_dyn_array + 91)
-
-    def g(dyn_scalar, dyn_array):
-      if maybe_vmap:
-        f_ = _vmap(f)
-      else:
-        f_ = f
-      return f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array)
-
-    def run(primal_ins, tangent_ins):
-      return jax.jvp(g, primal_ins, tangent_ins)
-
-    if maybe_jit:
-      run = jax.jit(run)
-
-    primal_ins = (4., jnp.array([5., 6.]))
-    tangent_ins = (7., jnp.array([8., 9.]))
-    primal_outs, tangent_outs = run(primal_ins, tangent_ins)
-    primal_out1, primal_out2 = primal_outs
-    tangent_out1, tangent_out2 = tangent_outs
-    scalar_type = jax.Array if maybe_jit or maybe_vmap else float
-    self.assertIsInstance(primal_out1, scalar_type)
-    self.assertAllClose(primal_out1, 5.)
-    self.assertIsInstance(tangent_out1, scalar_type)
-    self.assertAllClose(tangent_out1, 91.)
-    self.assertIsInstance(primal_out2, jax.Array)
-    self.assertArraysAllClose(primal_out2, jnp.array([7., 9.]))
-    self.assertIsInstance(tangent_out2, jax.Array)
-    self.assertArraysAllClose(tangent_out2, jnp.array([99., 100.]))
-
-  def test_symbolic_zero_custom_jvp_vmap_output(self):
-    @jax.custom_jvp
-    def f(x, y):
-      return x * y
-
-    @partial(f.defjvp, symbolic_zeros=True)
-    def f_jvp(primals, tangents):
-      x, y = primals
-      x_dot, y_dot = tangents
-      self.assertIs(type(y_dot), custom_derivatives_public.SymbolicZero)
-      return f(x, y), y_dot
-
-    jax.grad(lambda x, y: jax.vmap(f)(x, y).sum())(jnp.ones(3), jnp.ones(3))
-
-  def test_symbolic_zeros_memoization_caching(self):
-    # Tests multiple zero patterns for partial_eval._memoize, and also tests
-    # that we're okay with stores being occupied with equal values.
-
-    @jax.custom_jvp
-    def f(x, y):
-      return x * y
-
-    @partial(f.defjvp, symbolic_zeros=True)
-    def f_jvp(primals, tangents):
-      x, y = primals
-      x_dot, y_dot = tangents
-      return f(x, y), y_dot
-
-    f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.))
-    _ = jax.linearize(f_, 2., 3.)
-    _ = jax.linearize(lambda x: f_(x, 3.), 2.)  # don't crash!
-
-  def test_symbolic_zeros_under_jit(self):
-    # https://github.com/jax-ml/jax/issues/14833
-    Zero = jax.custom_derivatives.SymbolicZero
-
-    @jax.custom_jvp
-    def f(x, y):
-      return x * y
-
-    @partial(f.defjvp, symbolic_zeros=True)
-    def fjvp(primals, tangents):
-      x, y = primals
-      tx, ty = tangents
-      assert type(tx) is not Zero or type(ty) is not Zero
-      return f(x, y), (
-          ty if type(tx) is Zero else
-          tx if type(ty) is Zero else
-          tx + ty)
-
-    jax.jacfwd(jax.jit(f))(0.1, 0.2)  # don't crash
-
-  def test_custom_jvp_functools_partial(self):
-    def fun(x, y, a):
-      return x + y * a
-
-    fun_wrapped = functools.partial(fun, a = 0.1)
-
-    def jvp_fn(primals, tangents):
-      return jax.jvp(fun_wrapped, primals, tangents)
-
-    fn = jax.custom_jvp(fun_wrapped)
-    fn.defjvp(jvp_fn)
-
-    self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0)))
-
-  def test_run_rules_more_than_once(self):
-    # https://github.com/jax-ml/jax/issues/16614
-
-    @jax.custom_jvp
-    def f(x, y):
-      return x
-
-    @partial(f.defjvp, symbolic_zeros=True)
-    def f_jvp(primals, tangents):
-      x, _ = primals
-      x_dot, _ = tangents
-      return x, x_dot
-
-    def body(x_y, _):
-      x, y = x_y
-      return (f(x, y), x), None
-
-    @jax.grad
-    def g(x):
-      (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2)
-      return out
-
-    g(1.)  # doesn't crash
-
-  def test_dce(self):
-    @jax.custom_jvp
-    def f(x, y):
-      return jnp.sin(x), x + jnp.cos(y)
-
-    @f.defjvp
-    def f_jvp(primals, tangents):
-      x, y = primals
-      dx, dy = tangents
-      return f(x, y), (2.0 * jnp.cos(x) * dx, 1.5 * dx - 0.5 * jnp.sin(y) * dy)
-
-    def check_jaxpr(jaxpr, used_outs, includes, excludes):
-      dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs)
-      if not dce_jaxpr.eqns:
-        assert not includes
-        return
-      call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"]
-      for prim in includes:
-        assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns)
-      for prim in excludes:
-        assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns)
-
-    x, y = 0.1, -1.3
-    jaxpr = jax.make_jaxpr(f)(x, y).jaxpr
-    check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], [])
-    check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p])
-    check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p])
-    check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p])
-
-    def dce_jaxpr_as_fun(jaxpr, used_outs):
-      jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs)
-      fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_))
-      return lambda *args: fun(*args)[0]
-
-    f0 = dce_jaxpr_as_fun(jaxpr, [True, False])
-    f1 = dce_jaxpr_as_fun(jaxpr, [False, True])
-    self.assertAllClose(
-        api.jvp(f0, (x, y), (1.0, 0.0)), (f0(x, y), 2.0 * jnp.cos(x)))
-    self.assertAllClose(
-        api.jvp(f0, (x, y), (0.0, 1.0)), (f0(x, y), 0.0))
-    self.assertAllClose(
-        api.jvp(f1, (x, y), (1.0, 0.0)), (f1(x, y), 1.5))
-    self.assertAllClose(
-        api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y)))
-
-  def test_resolve_kwargs_error_message(self):
-    @jax.custom_jvp
-    def f(x, y, *, z=None):
-      return jnp.sin(x), x + jnp.cos(y)
-
-    @f.defjvp
-    def f_jvp(primals, tangents):
-      self.fail("should not be executed")
-
-    with self.assertRaisesRegex(
-        TypeError,
-        r"The input arguments to the custom_jvp-decorated function f(.*)\n"
-        r"missing a required argument: 'y'"
-    ):
-      f(0.5)
-
-    with self.assertRaisesRegex(
-        TypeError,
-        r"The input arguments to the custom_jvp-decorated function f(.*)\n"
-        "The following keyword arguments could not be resolved to positions: z"
-    ):
-      f(0.5, 0.1, z=1.0)
-
-  def test_symbolic_zero_custom_jvp_vmap_doesnt_instantiate(self):
-    @jax.custom_jvp
-    def f(x, y):
-      return y
-
-    def f_jvp(primals, tangents):
-      (x, y), (x_dot, y_dot) = primals, tangents
-      assert type(y_dot) is custom_derivatives_public.SymbolicZero
-      return y, y_dot
-
-    f.defjvp(f_jvp, symbolic_zeros=True)
-
-    def g(x):
-      return f(x, f(x, 1.))
-
-    jax.jvp(jax.vmap(g), (jnp.ones(3),), (jnp.ones(3),))  # don't crash
-
-  def test_symbolic_zero_under_vmap_of_jit(self):
-    # https://github.com/jax-ml/jax/issues/28144
-    @jax.custom_jvp
-    def f(x):
-      return x + 1
-
-    @f.defjvp
-    def f_jvp(x, t):
-      (x,) = x
-      (t,) = t
-      z = custom_derivatives_public.zero_from_primal(x, symbolic_zeros=True)
-      return f(x), z
-
-    x = jnp.arange(3.0)
-    jax.jvp(jax.vmap(jax.jit(f)), (x,), (x,))  # doesn't crash
-
-  def test_pretty_print(self):
-    @jax.custom_jvp
-    def f(x):
-      return x + 1
-
-    @f.defjvp
-    def f_jvp(primals, tangents):
-      return f(*primals), tangents[0]
-
-    x = jnp.array([4.2], dtype=jnp.float32)
-    jaxpr = jax.make_jaxpr(f)(x)
-    actual = jaxpr.pretty_print(use_color=False)
-    expected = textwrap.dedent(
-        """
-        { lambda ; a:f32[1]. let
-            b:f32[1] = custom_jvp_call[
-              name=f
-              call_jaxpr={ lambda ; c:f32[1]. let d:f32[1] = add c 1.0:f32[] in (d,) }
-              jvp=f_jvp
-              symbolic_zeros=False
-            ] a
-          in (b,) }
-        """).strip()
-    self.assertEqual(actual, expected)
-
-
-
-class CustomVJPTest(jtu.JaxTestCase):
-
-  def test_basic(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.sin(x)
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    x = 3.
-    self.assertAllClose(f(x), jnp.sin(x))
-    self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
-    self.assertAllClose(api.value_and_grad(f)(x),
-                        (jnp.sin(x), 2 * jnp.cos(x)))
-
-  def test_invariance(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.cos(2 * x) / 2.
-    def f_fwd(x):
-      return (f(x), x)
-    def f_rev(x, g):
-      return (g * 3,)
-    f.defvjp(f_fwd, f_rev)
-    def f2(x):
-      y, _ = api.value_and_grad(f)(x)
-      return y
-    def f3(x):
-      y, _ = api.value_and_grad(f2)(x)
-      return y
-    x = 1.
-    self.assertAllClose(f(x), f2(x), check_dtypes=False)
-    self.assertAllClose(f(x), f3(x), check_dtypes=False)
-    self.assertAllClose(api.grad(f)(x), api.grad(f2)(x),
-                        check_dtypes=False)
-    self.assertAllClose(api.grad(f)(x), api.grad(f3)(x),
-                        check_dtypes=False)
-
-  def test_python_control_flow(self):
-    @jax.custom_vjp
-    def f(x):
-      if x > 0:
-        return jnp.sin(x)
-      else:
-        return jnp.cos(x)
-    def f_fwd(x):
-      if x > 0:
-        return f(x), x
-      else:
-        return f(x), x
-    def f_rev(x, g):
-      if x > 0:
-        return (2 * g,)
-      else:
-        return (3 * g,)
-    f.defvjp(f_fwd, f_rev)
-    x = 2.
-    self.assertAllClose(f(x), jnp.sin(x))
-    self.assertAllClose(f(-x), jnp.cos(-x))
-    self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.),
-                        check_dtypes=False)
-    self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.),
-                        check_dtypes=False)
-
-  def test_vmap(self):
-    @jax.custom_vjp
-    def f(x):
-      assert jnp.ndim(x) == 0
-      return jnp.sin(x)
-    def f_fwd(x):
-      assert jnp.ndim(x) == 0
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    x = jnp.arange(3.)
-    xx = jnp.arange(6.).reshape(2, 3)
-
-    # vmap of f
-    self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
-    self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))
-
-    # vmap of grad of f
-    self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x))
-    self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
-                        (jnp.sin(x), 2 * jnp.cos(x)))
-    self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx))
-    self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
-                        (jnp.sin(xx), 2 * jnp.cos(xx)))
-
-    # grad of vmap of f
-    self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
-                        2 * jnp.cos(x))
-    self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
-                        2 * jnp.cos(xx))
-
-    # vmap of grad of vmap of f
-    self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
-                        2 * jnp.cos(xx))
-
-  def test_jit(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.sin(x)
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    x = 3.
-
-    # jit
-    self.assertAllClose(api.jit(f)(x), jnp.sin(x))
-    self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))
-
-    # jit of grad
-    self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x),
-                        check_dtypes=False)
-
-    # grad of jit
-    self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x),
-                        check_dtypes=False)
-
-  def test_pytrees(self):
-    @jax.custom_vjp
-    def f(x):
-      return {'b': jnp.sin(x['a'])}
-    def f_fwd(x):
-      return f(x), {'r': jnp.cos(x['a'])}
-    def f_bwd(res, g):
-      cos_x = res['r']
-      return ({'a': 2 * cos_x * g['b']},)
-    f.defvjp(f_fwd, f_bwd)
-    x = {'a': 3.}
-    self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
-    self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
-                        {'a': 2 * jnp.cos(x['a'])})
-
-  def test_jvp_error(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.sin(x)
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    self.assertRaisesRegex(
-        TypeError,
-        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
-        lambda: api.jvp(f, (3.,), (1.,)))
-    self.assertRaisesRegex(
-        TypeError,
-        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
-        lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)))
-    self.assertRaisesRegex(
-        TypeError,
-        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
-        lambda: api.jvp(jit(f), (3.,), (1.,)))
-
-  def test_kwargs(self):
-    # from https://github.com/jax-ml/jax/issues/1938
-    @jax.custom_vjp
-    def my_fun(x, y, c=1.):
-      return c * (x + y)
-    my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None),
-                  lambda _, g: (g, g, g))
-    f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
-    f(10., 5.)  # doesn't crash
-    api.grad(f)(10., 5.)  # doesn't crash
-
-  def test_initial_style(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.sin(x)
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    def foo(x):
-      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
-      return out
-
-    ans = api.grad(foo)(3.)
-    expected = 2. * jnp.cos(3.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(api.grad(foo))(3.)
-    expected = -2. * jnp.sin(3.)
-    self.assertAllClose(ans, expected)
-
-  def test_initial_style_vmap(self):
-    @jax.custom_vjp
-    def f(x):
-      assert jnp.ndim(x) == 0
-      return 3 * x
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    def foo(x):
-      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
-      return out
-
-    ans = api.vmap(foo)(jnp.arange(3.))
-    expected = 3. * jnp.arange(3.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
-    expected = 2. * jnp.cos(jnp.arange(3.))
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_nondiff_arg(self):
-    @partial(jax.custom_vjp, nondiff_argnums=(0,))
-    def app(f, x):
-      return f(x)
-    def app_fwd(f, x):
-      return app(f, x), jnp.cos(x)
-    def app_rev(f, cos_x, g):
-      return (cos_x * g,)
-    app.defvjp(app_fwd, app_rev)
-
-    ans = app(lambda x: 2 * x, 1)
-    expected = 2
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.)
-    expected = (2., jnp.cos(1.))
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_closed_over_jit_tracer(self):
-    # See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer.
-    raise unittest.SkipTest("behavior no longer supported")
-
-    # This test is similar to test_nondiff_arg_tracer except it uses lexical
-    # closure rather than the nondiff_argnums mechanism. We decided to disallow
-    # tracers in nondiff_argnums to greatly simplify bookkeeping while still
-    # supporting the cases for which it is necessary.
-    def outer(x):
-      @jax.custom_vjp
-      def f(y):
-        return x * y
-      def f_fwd(y):
-        return f(y), jnp.cos(y)
-      def f_rev(cos_y, g):
-        return (cos_y * g,)
-      f.defvjp(f_fwd, f_rev)
-      return f
-
-    @jit
-    def g(x, y):
-      return outer(x)(y)
-
-    ans = g(2, 3.)
-    expected = 6.
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(g, 1)(2., 3.)
-    expected = jnp.cos(3.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_closed_over_vmap_tracer(self):
-    def outer(x):
-      @jax.custom_vjp
-      def f(y):
-        return x * y
-      def f_fwd(y):
-        return f(y), jnp.cos(y)
-      def f_rev(cos_y, g):
-        return (cos_y * g,)
-      f.defvjp(f_fwd, f_rev)
-      return f
-
-    @api.vmap
-    def g(x):
-      return outer(x)(3.)
-
-    ans = g(np.arange(3.))
-    expected = np.arange(3.) * 3
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_closed_over_tracer3(self):
-    def outer(x):
-      @jax.custom_vjp
-      def f(y):
-        return x * y
-      def f_fwd(y):
-        return f(y), (x, jnp.cos(y))
-      def f_rev(res, g):
-        x, cos_y = res
-        return (cos_y * g * x,)
-      f.defvjp(f_fwd, f_rev)
-      return api.grad(f)
-
-    @api.vmap
-    def g(x):
-      return outer(x)(3.)
-
-    ans = g(np.arange(3.))
-    expected = np.cos(3.) * np.arange(3.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_nondiff_arg_tracer_error(self):
-    # This is similar to the old (now skipped) test_nondiff_arg_tracer, except
-    # we're testing for the error message that usage pattern now raises.
-
-    @partial(jax.custom_vjp, nondiff_argnums=(0,))
-    def f(x, y):
-      return x * y
-    def f_fwd(x, y):
-      return f(x, y), jnp.cos(y)
-    def f_rev(x, cos_y, g):
-      return (cos_y * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    @jit
-    def g(x, y):
-      return f(x, y)
-
-    with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"):
-      _ = g(2, 3.)
-    with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"):
-      _ = api.grad(g, 1)(2., 3.)
-
-  def test_vmap_axes(self):
-    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test
-
-  def test_pmap(self):
-    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test
-
-  def test_missing_vjp_rule_error(self):
-    @jax.custom_vjp
-    def foo(x):
-      return x ** 2
-
-    self.assertRaisesRegex(
-        AttributeError,
-        r"No VJP defined for custom_vjp function foo using defvjp.",
-        lambda: foo(2))
-    self.assertRaisesRegex(
-        AttributeError,
-        r"No VJP defined for custom_vjp function foo using defvjp.",
-        lambda: api.grad(foo)(2.))
-
-  def test_vjp_rule_inconsistent_pytree_structures_error(self):
-    @jax.custom_vjp
-    def f(x):
-      return x
-
-    def foo_fwd(x):
-      return x, None
-
-    def foo_bwd(_, g):
-      return (g, g)
-
-    f.defvjp(foo_fwd, foo_bwd)
-
-    f(2)  # doesn't crash
-    self.assertRaisesRegex(
-        TypeError,
-        re.escape(
-            "Custom VJP bwd rule must produce an output with the same container "
-            "(pytree) structure as the args tuple of the primal function, "
-            "and in particular must produce a tuple of length equal to the "
-            "number of arguments to the primal function, but got bwd output "
-            "structure {} for primal input structure {}.".format(
-                jax.tree.structure((1, 1)),
-                jax.tree.structure((1,)))
-        ),
-        lambda: api.grad(f)(2.))
-
-  def test_vjp_bwd_returns_non_tuple_error(self):
-    @jax.custom_vjp
-    def f(x):
-      return x
-
-    def foo_fwd(x):
-      return x, None
-
-    def foo_bwd(_, g):
-      return 2. * g  # Should be a tuple
-
-    f.defvjp(foo_fwd, foo_bwd)
-    with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"):
-      api.grad(f)(3.)
-
-  def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self):
-    # https://github.com/lucidrains/flash-attention-jax/issues/7
-
-    def scan_apply(f, x):
-      y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1)
-      return y
-
-    @jax.custom_vjp
-    def f(x):
-      return x
-
-    def f_fwd(x):
-      return (x, x), None
-
-    def f_bwd(_, y_bar):
-      return (y_bar,)
-
-    f.defvjp(f_fwd, f_bwd)
-
-    self.assertRaisesRegex(
-        TypeError,
-        re.escape(
-            "Custom VJP fwd rule f_fwd for function f must produce a pair "
-            "(list or tuple of length two) where the first element represents "
-            "the primal output (equal to the output of the "
-            "custom_vjp-decorated function f) and the second element "
-            "represents residuals (i.e. values stored from the forward "
-            "pass for use on the backward pass), but instead the fwd rule "
-            "output's first element had container/pytree structure:\n"
-            "  (float32[], float32[])\n"
-            "while the custom_vjp-decorated function f had output "
-            "container/pytree structure:\n"
-            "  float32[]."
-        ),
-        lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.)))
-
-    def f_fwd2(x):
-      return jnp.zeros((3, *x.shape), x.dtype), None
-
-    def f_bwd2(_, y_bar):
-      return (y_bar,)
-
-    f.defvjp(f_fwd2, f_bwd2)
-
-    self.assertRaisesRegex(
-        TypeError,
-        re.escape(
-            "Custom VJP fwd rule f_fwd2 for function f must produce a pair "
-            "(list or tuple of length two) where the first element represents "
-            "the primal output (equal to the output of the "
-            "custom_vjp-decorated function f) and the second element "
-            "represents residuals (i.e. values stored from the forward "
-            "pass for use on the backward pass), but instead the fwd rule "
-            "output's first element had shapes/dtypes of:\n"
-            "  float32[3]\n"
-            "while the custom_vjp-decorated function f had output "
-            "shapes/dtypes of:\n"
-            "  float32[]"
-        ),
-        lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.)))
-
-  def test_issue2511(self):
-    arr = jnp.ones((5, 2, 2))
-    foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)
-    api.jit(foo)(arr)  # doesn't crash
-
-  def test_lowering_out_of_traces(self):
-    # https://github.com/jax-ml/jax/issues/2578
-
-    class F(collections.namedtuple("F", ["a"])):
-      def __call__(self, x):
-        return jax.nn.relu(self.a) * x
-
-    @jax.jit
-    def g(f, x):
-      return f(x)
-
-    jax.grad(g, argnums=(1,))(F(2.0), 0.)  # doesn't crash
-
-  def test_clip_gradient(self):
-    # https://github.com/jax-ml/jax/issues/2784
-    @jax.custom_vjp
-    def _clip_gradient(lo, hi, x):
-      return x  # identity function when not differentiating
-
-    def clip_gradient_fwd(lo, hi, x):
-      return x, (lo, hi,)
-
-    def clip_gradient_bwd(res, g):
-      lo, hi = res
-      return (None, None, jnp.clip(g, lo, hi),)
-
-    _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
-
-    def clip_gradient(x):
-      lo = -0.1
-      hi = x + 0.1
-      return _clip_gradient(lo, hi, x)
-
-    g = jax.grad(clip_gradient)(0.1)  # doesn't crash
-    self.assertAllClose(g, jnp.array(0.2))
-
-  def test_nestable_vjp(self):
-    # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved.
-    def f(x):
-      return x ** 2
-
-    @jax.custom_vjp
-    def g(x):
-      return f(x)
-
-    def g_fwd(x):
-      y, f_vjp = api.vjp(f, x)
-      return y, f_vjp
-
-    def g_bwd(f_vjp, y_bar):
-      return f_vjp(y_bar)
-
-    g.defvjp(g_fwd, g_bwd)
-
-    # Check that VJP can be nested in simple situations. For this to pass,
-    # vjp has to return a PyTree.
-    _, g_vjp = api.vjp(g, 1.0)
-    y, = g_vjp(1.0)
-    self.assertAllClose(y, jnp.array(2.0))
-
-    # Check that VJP can be nested in complex situations. For this to pass,
-    # vjp can't treat the closed-over tracer x as a static argument.
-    @jit
-    def z(x):
-      _, g_vjp = api.vjp(g, x)
-      return g_vjp
-    y, = z(1.0)(3.0)
-    self.assertAllClose(y, jnp.array(6.0))
-
-  def test_initial_style_vmap_2(self):
-    # https://github.com/jax-ml/jax/issues/4173
-    x = jnp.ones((10, 3))
-
-    # Create the custom function
-    @jax.custom_vjp
-    def custom_fun(x):
-      return x.sum()
-
-    def forward(x):
-      return x.sum(), (jnp.ones_like(x),)
-
-    def backward(res, g):
-      return g * res[0],
-
-    custom_fun.defvjp(forward, backward)
-
-    def train_fun(x):
-
-      def summed_fun(x):
-        return api.vmap(custom_fun)(x).sum()
-
-      return api.grad(summed_fun)(x)
-
-    def scan_body(carry, inputs):
-      x = carry
-      return carry, train_fun(x)
-
-    scan_range = jnp.arange(4)
-    lax.scan(scan_body, x, scan_range)  # don't crash
-
-  def test_initial_style_vmap_3(self):
-    # This is like test_initial_style_vmap except the primal function closes
-    # over an array constant.
-    y = jnp.arange(1., 4.)
-
-    @jax.custom_vjp
-    def f(x):
-      assert jnp.ndim(x) == 0
-      return 3 * x * jnp.sum(y)
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    def foo(x):
-      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
-      return out
-
-    ans = api.vmap(foo)(jnp.arange(3.))
-    expected = 3. * jnp.arange(3.) * 6
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
-    expected = 2. * jnp.cos(jnp.arange(3.))
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_initial_style_vmap_with_collective(self):
-
-    @jax.custom_vjp
-    def f(x):
-      return lax.psum(x, 'foo')
-
-    def f_fwd(x):
-      return lax.psum(x, 'foo'), None
-
-    def f_bwd(res, dx):
-      return dx
-    f.defvjp(f_fwd, f_bwd)
-
-    def g(x):
-      jaxpr = api.make_jaxpr(f)(x)
-      return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0]
-
-    out = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None),
-                   out_axes=None)(jnp.arange(4.), 2.)
-    self.assertAllClose(out, 8.)
-
-  def test_bwd_closes_over_tracer(self):
-    def f(y):
-      @jax.custom_vjp
-      def f(x):
-        return 2. * jnp.sin(x)
-
-      def fwd(x):
-        return f(x), ()
-
-      def bwd(_, g):
-        return (2. * jnp.cos(y) * g,)  # capture!
-
-      f.defvjp(fwd, bwd)
-
-      return jax.grad(f)(1.)
-
-    ans = jax.jit(f)(2.)
-    self.assertAllClose(ans, 2. * jnp.cos(2.))
-
-    ans = jax.vmap(f)(jnp.arange(3.))
-    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
-
-    ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
-    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
-
-    ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
-    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
-
-    ans = jax.grad(f)(4.)
-    self.assertAllClose(ans, -2. * jnp.sin(4.))
-
-  def test_fwd_closes_over_tracer(self):
-    def f(y):
-      @jax.custom_vjp
-      def f(x):
-        return 2. * jnp.sin(x)
-
-      def fwd(x):
-        return f(x), y
-
-      def bwd(y, g):
-        return (2. * jnp.cos(y) * g,)  # capture!
-
-      f.defvjp(fwd, bwd)
-
-      return jax.grad(f)(1.)
-
-    ans = jax.jit(f)(2.)
-    self.assertAllClose(ans, 2. * jnp.cos(2.))
-
-    ans = jax.vmap(f)(jnp.arange(3.))
-    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
-
-    ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
-    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
-
-    ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
-    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
-
-    ans = jax.grad(f)(4.)
-    self.assertAllClose(ans, -2. * jnp.sin(4.))
-
-  def test_float0(self):
-    @jax.custom_vjp
-    def f(x, _):
-      return x
-    def f_fwd(x, _):
-      # we need a defined (non-float0) tangent to trigger the rule
-      return x, (2., 1)
-    def f_rev(*_):
-      return (2., 1)
-    f.defvjp(f_fwd, f_rev)
-
-    x = 2.
-    y = 3
-    self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y),
-                     (2., np.zeros(shape=(), dtype=float0)))
-
-  def test_float0_initial_style(self):
-    @jax.custom_vjp
-    def f(x):
-      return x
-    def f_fwd(x):
-      return x, (2., x)
-    def f_rev(*_):
-      return ((2., jnp.zeros(shape=(), dtype=float0)),)
-    f.defvjp(f_fwd, f_rev)
-
-    def foo(x, y):
-      out, _ = lax.scan(lambda c, _: (f(c), None), (x, y), None, length=1)
-      return out[0]
-
-    x = 2.
-    y = 3
-    self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y),
-                     (2., np.zeros(shape=(), dtype=float0)))
-
-  def test_remat(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.sin(x)
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    @jax.remat
-    def g(x):
-      return f(f(x))
-
-    ans = g(2.)
-    expected = np.sin(np.sin(2.))
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(g)(2.)
-    expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_remat_higher_order(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.sin(x)
-    def f_fwd(x):
-      return f(x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (2 * cos_x * g,)
-    f.defvjp(f_fwd, f_rev)
-
-    def g(x):
-      return f(f(x))
-
-    ans = api.grad(api.grad(jax.remat(g)))(2.)
-    expected = api.grad(api.grad(g))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(jax.remat(api.grad(g)))(2.)
-    expected = api.grad(api.grad(g))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-    ans = api.grad(api.grad(api.grad(jax.remat(g))))(2.)
-    expected = api.grad(api.grad(api.grad(g)))(2.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_bwd_nones(self):
-    @jax.custom_vjp
-    def f(x, y):
-      return x * jnp.sin(y)
-    def f_fwd(x, y):
-      return f(x, y), jnp.cos(y)
-    def f_rev(cos, g):
-      return (None, 2 * cos * g)
-    f.defvjp(f_fwd, f_rev)
-
-    ans = api.grad(lambda x: f(x, x))(3.)
-    expected = 2 * jnp.cos(3.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_bwd_nones_vmap(self):
-    @jax.custom_vjp
-    def f(x, y):
-      return x * jnp.sin(y)
-    def f_fwd(x, y):
-      return f(x, y), jnp.cos(y)
-    def f_rev(cos, g):
-      return (None, 2 * cos * g)
-    f.defvjp(f_fwd, f_rev)
-
-    ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.))
-    expected = 2 * jnp.cos(jnp.arange(3.))
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_bwd_nones_pytree(self):
-    @jax.custom_vjp
-    def f(xs, y):
-      x1, x2 = xs
-      return x1 * x2 * jnp.sin(y)
-    def f_fwd(xs, y):
-      return f(xs, y), jnp.cos(y)
-    def f_rev(cos, g):
-      return (None, 2 * cos * g)
-    f.defvjp(f_fwd, f_rev)
-
-    ans = api.grad(lambda x: f((x, x), x))(3.)
-    expected = 2 * jnp.cos(3.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_custom_vjp_closure_4521(self):
-    # https://github.com/jax-ml/jax/issues/4521
-    @jax.custom_vjp
-    def g(x, y):
-      return None
-    def g_fwd(x, y):
-      return None, y
-    def g_bwd(residuals, z_bar):
-      assert False
-
-    g.defvjp(g_fwd, g_bwd)
-
-    def f(xs, y):
-      v_g = api.vmap(g, in_axes=(0, None), out_axes=None)
-      v_g(xs, y)
-
-    def scan_body(xs, _):
-      y = jnp.zeros(1)
-      _, vjp_f = api.vjp(f, xs, y)
-      vjp_f(None)
-      return xs, None
-
-    lax.scan(scan_body, jnp.ones(5), None, 100)  # doesn't crash
-
-  def test_float0_bwd_none(self):
-    @jax.custom_vjp
-    def f(i, x):
-      return jnp.sin(x)
-    def f_fwd(i, x):
-      return f(i, x), jnp.cos(x)
-    def f_rev(cos_x, g):
-      return (None, 2 * cos_x * g)
-    f.defvjp(f_fwd, f_rev)
-
-    ans = api.grad(f, 1)(jnp.array([1, 2]), 3.)  # doesn't crash
-    expected = 2 * jnp.cos(3.)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def test_custom_gradient(self):
-    @jax.custom_gradient
-    def f(x):
-      return x ** 2, lambda g: (g * x,)
-
-    self.assertAllClose(f(3.), 9., check_dtypes=False)
-    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
-    self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)
-
-  def test_custom_gradient_2(self):
-    @jax.custom_gradient
-    def f(x, y):
-      return x * y, lambda g: (y, x)
-
-    self.assertAllClose(f(3., 4.), 12., check_dtypes=False)
-    self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.),
-                        check_dtypes=False)
-
-  def test_custom_gradient_3(self):
-    @jax.custom_gradient
-    def f(x):
-      vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),)
-      return jnp.sum(jnp.sin(x)), vjp
-
-    self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))),
-                        check_dtypes=False)
-    self.assertAllClose(
-        api.grad(f)(jnp.arange(3.)),
-        api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.),
-        check_dtypes=False)
-
-  def test_custom_gradient_can_return_singleton_value_in_vjp(self):
-    @jax.custom_gradient
-    def f(x):
-      return x ** 2, lambda g: g * x
-
-    self.assertAllClose(f(3.), 9., check_dtypes=False)
-    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
-    self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)
-
-  def test_closure_convert(self):
-    def cos_after(fn, x):
-      converted_fn, aux_args = jax.closure_convert(fn, x)
-      self.assertLessEqual(len(aux_args), 1)
-      return _cos_after(converted_fn, x, *aux_args)
-
-    @partial(jax.custom_vjp, nondiff_argnums=(0,))
-    def _cos_after(fn, x, *args):
-      return jnp.cos(fn(x, *args))
-
-    def fwd(fn, x, *args):
-      y = _cos_after(fn, x, *args)
-      return y, (x, args)
-
-    def rev(fn, res, g):
-      x, args = res
-      x_bar = 17. * x
-      args_bars = [42. * a for a in args]
-      return (x_bar, *args_bars)
-
-    _cos_after.defvjp(fwd, rev)
-
-    def dist(c, x):
-      return jnp.sum((x - c) ** 2.)
-
-    def solve(c, x):
-      def closure(x):
-        return dist(c, x)
-      return cos_after(closure, x)
-
-    c, x = 2. * jnp.ones(2), jnp.ones(2)
-    expected = jnp.cos(dist(c, x))
-    self.assertAllClose(solve(c, x), expected, check_dtypes=False)
-    g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x)
-    self.assertAllClose(g_c, 42. * c, check_dtypes=False)
-    self.assertAllClose(g_x, 17. * x, check_dtypes=False)
-
-  def test_closure_convert_mixed_consts(self):
-    # Like test_closure_convert, but close over values that
-    # participate in AD as well as values that do not.
-    # See https://github.com/jax-ml/jax/issues/6415
-
-    def cos_after(fn, x):
-      converted_fn, aux_args = jax.closure_convert(fn, x)
-      self.assertLessEqual(len(aux_args), 1)
-      return _cos_after(converted_fn, x, *aux_args)
-
-    @partial(jax.custom_vjp, nondiff_argnums=(0,))
-    def _cos_after(fn, x, *args):
-      return jnp.cos(fn(x, *args))
-
-    def fwd(fn, x, *args):
-      y = _cos_after(fn, x, *args)
-      return y, (x, args)
-
-    def rev(fn, res, g):
-      x, args = res
-      x_bar = 17. * x
-      args_bars = [42. * a for a in args]
-      return (x_bar, *args_bars)
-
-    _cos_after.defvjp(fwd, rev)
-
-    def dist(c, s, x):
-      return jnp.sum(s * (x - c) ** 2.)
-
-    def solve(c, s, x):
-      def closure(x):
-        return dist(c, s, x)
-      return cos_after(closure, x)
-
-    c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2)
-    expected = jnp.cos(dist(c, s, x))
-    self.assertAllClose(solve(c, s, x), expected, check_dtypes=False)
-    g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x)
-    self.assertAllClose(g_c, 42. * c, check_dtypes=False)
-    self.assertAllClose(g_x, 17. * x, check_dtypes=False)
-
-  def test_closure_convert_pytree_mismatch(self):
-    # See https://github.com/jax-ml/jax/issues/23588
-    def f(x, z):
-      return z * x
-
-    x, z = 2.0, 3.0
-    _, vjp = api.vjp(f, x, z)
-    vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x)
-    vjp_pure(x, *vjp_aux_args)
-    with self.assertRaisesRegex(
-        TypeError, "The inputs to the closure produced by closure_convert"):
-      vjp_pure(x, vjp_aux_args)
-
-  def test_float0_cotangents_automatically_handled(self):
-    @jax.custom_vjp
-    def f(x, y):
-      return x
-
-    def f_fwd(x, y):
-      return x, None
-
-    def f_bwd(_, zbar):
-      return (0., 1)
-
-    f.defvjp(f_fwd, f_bwd)
-
-    jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1)  # doesn't crash
-
-  def test_custom_vjp_scan_batching_edge_case(self):
-    # https://github.com/jax-ml/jax/issues/5832
-    @jax.custom_vjp
-    def mul(x, coeff): return x * coeff
-    def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff)
-    def mul_bwd(res, g):
-      x, coeff = res
-      g_x = g * coeff
-      g_coeff = (x * g).sum()
-      return g_x, g_coeff
-    mul.defvjp(mul_fwd, mul_bwd)
-
-    def scan_over_mul(x, coeff):
-      def f_(x, t):
-        return mul(x, coeff), None
-      y, _ = jax.lax.scan(f_, x, jnp.arange(3))
-      return y
-
-    key = jax.random.key(0)
-    key1, key2 = jax.random.split(key, 2)
-    x_batch = jax.random.normal(key1, (3, 2))
-    covector_batch = jax.random.normal(key2, (3, 2))
-    coeff = jnp.array(1., dtype=x_batch.dtype)
-
-    batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0)
-    res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff)
-    vjp_fun(covector_batch)  # doesn't crash
-
-    jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2,
-                    modes=['rev'])
-
-  def test_closure_with_vmap2(self):
-    # https://github.com/jax-ml/jax/issues/8783
-    def h(z):
-      def f(x):
-        @jax.custom_vjp
-        def g(y):
-          return x * y
-
-        def g_fwd(y):
-          return x * y, (x, x * y, y)
-        def g_rev(res, w_bar):
-          x, *_ = res
-          return (x * w_bar,)
-        g.defvjp(g_fwd, g_rev)
-
-        return g(z)
-
-      return jax.vmap(f)(jnp.arange(3., dtype='float32')).sum()
-
-    jtu.check_grads(h, (jnp.float32(3.14),), order=1, modes=['rev'])
-
-  def test_pytrees_not_required_to_contain_nones(self):
-    class A(list):
-      pass
-
-    def unflatten(_, children):
-      assert children[0] is not None
-      return A(children)
-
-    tree_util.register_pytree_node(A, lambda x: (x, None), unflatten)
-
-    @jax.custom_vjp
-    def f(x):
-      return x[0]
-    def f_fwd(x):
-      return x[0], None
-    def f_bwd(_, g):
-      return A([g]),
-    f.defvjp(f_fwd, f_bwd)
-
-    jax.grad(f)(A([1.]))  # doesn't crash
-
-  def test_vmap_vjp_called_twice(self):
-    # https://github.com/jax-ml/jax/pull/14728
-    @jax.custom_vjp
-    def f(x):
-      return x
-    f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,))
-
-    _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.]))
-    f_vjp(jnp.array([3.]))
-    f_vjp(jnp.array([3.]))  # doesn't crash
-
-  def test_symbolic_zero_custom_vjp_basic(self):
-    ZERO = custom_derivatives_public.SymbolicZero
-
-    @jax.custom_vjp
-    def f(x, y, z):
-      return x, x
-
-    def fwd(x, y, z):
-      self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal)
-      self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal)
-      self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal)
-      self.assertTrue(x.perturbed)
-      self.assertFalse(y.perturbed)
-      self.assertFalse(z.perturbed)
-      return (x.value, x.value), None
-
-    def fwd_all(x, y, z):
-      self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal)
-      self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal)
-      self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal)
-      self.assertTrue(x.perturbed)
-      self.assertTrue(y.perturbed)
-      self.assertTrue(z.perturbed)
-      return (x.value, x.value), None
-
-    def bwd_all(_, g):
-      x1, x2 = g
-      self.assertFalse(type(x1) is ZERO)
-      self.assertFalse(type(x2) is ZERO)
-      return x1, x1, x2
-
-    def bwd_fst(_, g):
-      x1, x2 = g
-      self.assertFalse(type(x1) is ZERO)
-      self.assertIs(type(x2), ZERO)
-      return x1, x1, x2
-
-    def bwd_snd(_, g):
-      x1, x2 = g
-      self.assertIs(type(x1), ZERO)
-      self.assertFalse(type(x2) is ZERO)
-      return x1, x1, x2
-
-    x, y, z = 4., 5., 6.
-    i = np.array(7, np.int32)
-    zero = np.array(0.)
-
-    f.defvjp(fwd, bwd_all, symbolic_zeros=True)
-    h = jax.jit(f)
-    jax.jacrev(h)(x, y, z)
-    jax.jacrev(lambda x: h(x, y, z))(x)
-    jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i)
-
-    f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True)
-    fst_f = lambda *xs: f(*xs)[0]
-    _, vjp = jax.vjp(fst_f, x, y, z)
-    _, _, gz = vjp(x)
-    self.assertArraysAllClose(gz, zero)
-
-    f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True)
-    snd_f = lambda *xs: f(*xs)[1]
-    _, vjp = jax.vjp(snd_f, x, y, z)
-    gx, gy, _ = vjp(x)
-    self.assertArraysAllClose(gx, zero)
-    self.assertArraysAllClose(gy, zero)
-
-    f.defvjp(fwd, bwd_snd, symbolic_zeros=True)
-    _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x)
-    gx, = vjp(x)
-    self.assertArraysAllClose(gx, zero)
-
-  def test_symbolic_zero_custom_vjp_bwd_shape_error(self):
-    @jax.custom_vjp
-    def f(x, y, z):
-      return x, y, z
-
-    def fwd(x, y, z):
-      return f(x.value, y.value, z.value), None
-
-    def bwd(_, gs):
-      x_bar, y_bar, z_bar = gs
-      return y_bar, x_bar, z_bar  # swapped!
-
-    f.defvjp(fwd, bwd, symbolic_zeros=True)
-
-    with self.assertRaisesRegex(
-        ValueError,
-        r'Consider just returning a None here'):
-      jax.grad(lambda x, y, z: f(x, y, z)[2].sum())(
-          jnp.ones(1), jnp.ones(2), jnp.ones(3))
-
-  @parameterized.named_parameters(
-      ('jit_vmap', True, True),
-      ('jit', True, False),
-      ('vmap', False, True),
-      ('', False, False),
-  )
-  def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap):
-    # below:
-    # * static_scalar will be static in and out
-    # * static_array will be static in, but dynamic out
-    # * dyn_scalar and dyn_array will be dynamic in and out
-
-    ZERO = custom_derivatives_public.SymbolicZero
-
-    def f(static_scalar, static_array, dyn_scalar, dyn_array):
-      out1 = static_scalar + dyn_scalar
-      out2 = static_array + dyn_array
-      return static_scalar, static_array, out1, out2
-
-    def _pack(x):
-      return lax.broadcast(x, (1,))
-
-    def _unpack(x):
-      (x,) = x
-      return x
-
-    def _vmap(fun):
-      def _fun(*args):
-        args = jax.tree.map(_pack, args)
-        out = jax.vmap(fun)(*args)
-        out = jax.tree.map(_unpack, out)
-        return out
-      return _fun
-
-    f = jax.custom_vjp(f)
-
-    def fwd(*args):
-      xs, pert = [x.value for x in args], [x.perturbed for x in args]
-      self.assertFalse(pert[0])
-      self.assertFalse(pert[1])
-      self.assertTrue(pert[2])
-      self.assertTrue(pert[3])
-      return f(*xs), xs
-
-    def bwd(res, g):
-      static_scalar, *_ = res
-      t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g
-      self.assertIs(type(t_static), ZERO)
-      self.assertFalse(type(t_static_arr) is ZERO)
-      self.assertFalse(type(t_dyn_scalar) is ZERO)
-      self.assertFalse(type(t_dyn_array) is ZERO)
-      self.assertEqual(t_static.shape, ())
-      self.assertEqual(t_static_arr.shape, (2,))
-      return (static_scalar + 90,
-              t_static_arr + 91,
-              t_dyn_scalar + 92,
-              t_dyn_array + 93)
-
-    f.defvjp(fwd, bwd, symbolic_zeros=True)
-
-    def g(dyn_scalar, dyn_array):
-      if maybe_vmap:
-        f_ = _vmap(f)
-      else:
-        f_ = f
-      outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array)
-      return outs[1:]
-
-    def run(primal_ins, cotangent_outs):
-      primal_outs, vjp = jax.vjp(g, *primal_ins)
-      cotangent_ins = vjp(cotangent_outs)
-      return primal_outs, cotangent_ins
-
-    if maybe_jit:
-      run = jax.jit(run)
-
-    scalar_type = jax.Array if maybe_jit or maybe_vmap else float
-    primal_ins = (4., jnp.array([5., 6.]))
-    cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.]))
-    primal_outs, cotangent_ins = run(primal_ins, cotangent_outs)
-
-    primal_out1, primal_out2, primal_out3 = primal_outs
-    self.assertIsInstance(primal_out1, jax.Array)
-    self.assertAllClose(primal_out1, jnp.array([2., 3.]))
-    self.assertIsInstance(primal_out2, scalar_type)
-    self.assertAllClose(primal_out2, 5.)
-    self.assertIsInstance(primal_out3, jax.Array)
-    self.assertAllClose(primal_out3, jnp.array([7., 9.]))
-
-    ct_in1, ct_in2 = cotangent_ins
-    self.assertIsInstance(ct_in1, scalar_type)
-    self.assertAllClose(ct_in1, 99.)
-    self.assertIsInstance(ct_in2, jax.Array)
-    self.assertArraysAllClose(ct_in2, jnp.array([101., 102.]))
-
-  def test_symbolic_zero_custom_vjp_vmap_output(self):
-    @jax.custom_vjp
-    def f(x, y):
-      return x, y
-
-    def fwd(x, y):
-      self.assertTrue(x.perturbed)
-      self.assertFalse(y.perturbed)
-      return f(x.value, y.value), None
-
-    def bwd(_, g):
-      _, ct_y = g
-      self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero)
-      return g
-
-    f.defvjp(fwd, bwd, symbolic_zeros=True)
-    jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3))
-
-  def test_symbolic_zero_custom_vjp_custom_pytree(self):
-    tree_values = custom_derivatives_public.custom_vjp_primal_tree_values
-
-    @tree_util.register_pytree_node_class
-    class Box:
-      def __init__(self_, strict, val):
-        if strict:
-          # make sure we aren't getting special arguments that should only
-          # come up when symbolic_zeros is True
-          self.assertFalse(hasattr(val, 'perturbed'))
-        self_.strict = strict
-        self_.x = val
-
-      def tree_flatten(self_):
-        return [self_.x], self_.strict
-
-      @classmethod
-      def tree_unflatten(cls, strict, xs):
-        x, = xs
-        return cls(strict, x)
-
-    x, y = Box(False, jnp.array(72.)), jnp.array(73.)
-
-    @jax.custom_vjp
-    def f(box, y):
-      return box.x * y
-
-    def fwd0(box, y):
-      self.assertTrue(box.x.perturbed)
-      self.assertFalse(y.perturbed)
-      box, y = map(tree_values, [box, y])
-      return f(box, y), (box, y)
-
-    def bwd0(res, g):
-      box, y = res
-      return y * g, box.x * g
-
-    def fwd1(box, y):
-      self.assertFalse(box.x.perturbed)
-      self.assertTrue(y.perturbed)
-      box, y = map(tree_values, [box, y])
-      return f(box, y), (box, y)
-
-    def bwd1(res, g):
-      box, y = res
-      return y * g, box.x * g
-
-    f.defvjp(fwd0, bwd0, symbolic_zeros=True)
-    jax.grad(f, argnums=0)(x, y)
-    f.defvjp(fwd1, bwd1, symbolic_zeros=True)
-    jax.grad(f, argnums=1)(x, y)
-
-    def fwd_strict(box, y):
-      return f(box, y), (box, y)
-
-    def bwd_strict(res, g):
-      box, y = res
-      return y * g, box.x * g
-
-    f.defvjp(fwd_strict, bwd_strict)
-    jax.grad(f)(x, y)
-
-  def test_symbolic_zeros_memoization_caching(self):
-    # Tests multiple zero patterns for partial_eval._memoize, and also tests
-    # that we're okay with stores being occupied with equal values.
-    @jax.custom_vjp
-    def f(x, y):
-      return x * y
-
-    def f_fwd(x, y):
-      return x.value, None
-
-    def f_bwd(_, z_bar):
-      return z_bar, None
-
-    f.defvjp(f_fwd, f_bwd, symbolic_zeros=True)
-
-    f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.))
-    _ = jax.linearize(f_, 2., 3.)
-    _ = jax.linearize(lambda x: f_(x, 3.), 2.)  # don't crash!
-
-  def test_run_rules_more_than_once(self):
-    # https://github.com/jax-ml/jax/issues/16614
-
-    @jax.custom_vjp
-    def f(x, y):
-      return x + y
-
-    def f_fwd(x, y):
-      if y.perturbed:
-        res = None
-      else:
-        res = []
-      return x.value + y.value, res
-
-    def f_bwd(res, ct):
-      return ct, ct
-
-    f.defvjp(f_fwd, f_bwd, symbolic_zeros=True)
-
-    def body(x_y, _):
-      x, y = x_y
-      return (f(x, y), x), None
-
-    @jax.grad
-    def g(x):
-      (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2)
-      return out
-
-    g(1.)  # doesn't crash
-
-  def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self):
-    # https://github.com/jax-ml/jax/issues/8356
-    @jax.custom_vjp
-    def f(x):
-      return x[0]
-
-    def f_fwd(x):
-      return f(x), None
-
-    def f_bwd(_, z_bar):
-      return (z_bar, (None, None)),
-
-    f.defvjp(f_fwd, f_bwd)
-
-    jax.grad(f)((1.0, (2.0, 3.0)))  # don't crash
-
-  def test_pytree_nones_returned_by_bwd(self):
-    @jax.custom_vjp
-    def f(x):
-      return x[0]
-
-    def f_fwd(x):
-      return f(x), None
-
-    def f_bwd(_, z_bar):
-      return (z_bar, (None, None)),
-
-    f.defvjp(f_fwd, f_bwd)
-
-    jax.grad(f)((1.0, (2.0, None)))  # don't crash
-
-  def test_bwd_rule_shape_mismatch(self):
-    @jax.custom_vjp
-    def foo(x, y):
-      return x
-
-    def foo_fwd(x, y):
-      return x, None
-
-    def foo_bwd(_, g):
-      return jnp.zeros(3), jnp.zeros(3)
-
-    foo.defvjp(foo_fwd, foo_bwd)
-
-    with self.assertRaisesRegex(
-        ValueError,
-        r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'):
-      jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4))
-
-  def test_bwd_rule_shape_mismatch_disable(self):
-    # TODO(mattjj): remove this test when the config option is removed
-    @jax.custom_vjp
-    def foo(x, y):
-      return x
-
-    def foo_fwd(x, y):
-      return x, None
-
-    def foo_bwd(_, g):
-      return jnp.zeros(3), jnp.zeros(3)
-
-    foo.defvjp(foo_fwd, foo_bwd)
-
-    with config.custom_vjp_disable_shape_check(True):
-      jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4))
-
-  def test_bwd_rule_can_produce_list_or_tuple(self):
-    @jax.custom_vjp
-    def f(x, y):
-      return x * y
-
-    def f_fwd(x, y):
-      return f(x, y), (x, y)
-
-    def f_bwd(xy, g):
-      x, y = xy
-      return [g * y, x * g]  # list, not tuple
-
-    f.defvjp(f_fwd, f_bwd)
-
-    jax.grad(f)(1., 2.)  # don't crash
-
-  def test_optimize_remat(self):
-    def fun(x):
-      # This array is included to make sure that we handle consts appropriately
-      return np.array([1.0])*x
-
-    def fwd(x):
-      return np.array([2.0])*x*x/np.array([1.0]), (x,)
-
-    x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
-        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
-        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
-
-    self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x)  # Shouldn't hit custom DCE
-    self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x)  # Should be DCEed
-
-  def test_optimize_remat_vmap(self):
-    def fun(x):
-      return (np.array([1.0])*x)[0]
-    def fwd(x):
-      return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,)
-    x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
-        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
-        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
-    self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x)
-    self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x)
-
-  def test_optimize_remat_cond(self):
-    def fun(x):
-      return x
-    def fwd(x):
-      return x*x, (x,)
-
-    x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
-        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
-        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
-
-    def g(x):
-      return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x)
-
-    self.assertAllClose(jax.jit(g)(x)[0], x*x)
-    self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x)
-
-  def test_optimize_remat_jvp(self):
-    def fun(x):
-      return x**2
-    def fwd_(x):
-      return x*x, (x,)
-
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
-        fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}),
-        fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {}))
-    calc = jax.jvp(fwd, (3.2,), (1.0,))
-    expected = jax.jvp(fwd_, (3.2,), (1.0,))
-    self.assertAllClose(calc, expected)
-
-    @jax.jit
-    def g(x, t):
-      (y, r), (y_dot, r_dot) = jax.jvp(fwd, (x,), (t,))
-      return y, y_dot
-    calc = g(3.2, 1.0)
-    expected = jax.jvp(fun, (3.2,), (1.0,))
-    self.assertAllClose(calc, expected)
-
-  def test_optimize_remat_gh21303(self):
-    @jax.custom_vjp
-    def f(x):
-      return jnp.tan(x)
-
-    def f_fwd(x):
-      return jnp.sin(x), (x,)
-
-    def f_bwd(res, g):
-      x, = res
-      cos_x = jnp.cos(x)
-      return (cos_x * g,)
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-
-    def temp(x):
-      out = jax.remat(f)(x)
-      out = out ** 2
-      return out
-
-    v, g = jax.value_and_grad(temp)(3.2)
-    self.assertAllClose(v, jnp.tan(3.2)**2)
-
-  def test_optimize_remat_multiple_args(self):
-    def f_(x, y):
-      return jnp.sin(x) * y
-
-    @jax.custom_vjp
-    def f(x, y):
-      return f_(x, y)
-
-    def f_fwd(x, y):
-      return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(res, g):
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-    x, y = 3.2, 1.0
-    self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y))
-
-  def test_optimize_remat_kwargs(self):
-    @jax.custom_vjp
-    def f(x, y):
-      return jnp.sin(x) * y
-
-    def f_fwd(x, y, *, keyword=False):
-      del keyword
-      return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(res, g):
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-    x, y = 3.2, 1.0
-    jax.grad(f)(x, y)  # Doesn't error
-
-  def test_optimize_remat_custom_vmap(self):
-    # See https://github.com/jax-ml/jax/pull/23000
-    @jax.custom_vjp
-    def f(x, y):
return jnp.sin(x) * y - - @jax.custom_batching.custom_vmap - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - @f_fwd.def_vmap - def f_fwd_vmap(_, in_batched, x, y): - # Insert a new const here to test the optimize_remat batching rule. - out = np.array([2.0])*f(x, y) - out_batched = (True, (True, True, True)) - return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) - jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error - - def test_dce(self): - @jax.custom_vjp - def f(x, y): - return jnp.sin(x), x + jnp.cos(y) - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(y)) - - def f_bwd(res, cts): - cos_x, sin_y = res - ct_a, ct_b = cts - return 2.0 * cos_x * ct_a + 1.5 * ct_b, -0.5 * sin_y * ct_b - - f.defvjp(f_fwd, f_bwd) - - def check_jaxpr(jaxpr, used_outs, includes, excludes): - dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) - if not dce_jaxpr.eqns: - assert not includes - return - call_jaxpr = dce_jaxpr.eqns[0].params["fun_jaxpr"] - for prim in includes: - assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) - for prim in excludes: - assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) - - x, y = 0.1, -1.3 - jaxpr = jax.make_jaxpr(f)(x, y).jaxpr - check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) - check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) - check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) - check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) - - def dce_jaxpr_as_fun(jaxpr, used_outs): - jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) - fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) - return lambda *args: fun(*args)[0] - - f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) - f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) - self.assertAllClose( - api.grad(f0, argnums=(0, 1))(x, y), (2.0 * jnp.cos(x), 0.0)) - self.assertAllClose( - api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) - - def test_resolve_kwargs_error_message(self): - @jax.custom_vjp - def f(x, y, *, z=None): - return jnp.sin(x), x + jnp.cos(y) - - def f_fwd(x, y): - self.fail("should not be executed") - - def f_bwd(res, cts): - self.fail("should not be executed") - - f.defvjp(f_fwd, f_bwd) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vjp-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vjp-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -def transpose_unary(f, x_example): - def transposed(y): - x, = api.linear_transpose(f, x_example)(y) - return x - return transposed - - -# This class wraps jax.custom_transpose.custom_transpose in order to pass in a -# particular tree of output type on each call. Otherwise it forwards -# all attribute access. -class _custom_transpose: - def __init__(self, out_types, fun): - self.out_types = out_types - self.fun = jax.custom_transpose.custom_transpose(fun) - - def __getattr__(self, name): - return getattr(self.fun, name) - - def __call__(self, *args): - return self.fun(self.out_types, *args) - - -# This function is meant to be used as a decorator that delegates to -# custom_transpose but makes it easy to specify output argument types -# by example. If used directly as a decorator (i.e.
not invoked with -# example arguments), assumes a scalar-valued function. -# -# TODO(frostig): remove this (and its uses) once custom_transpose offers -# an option of inferring output types. -def custom_transpose(example_out): - if isinstance(example_out, Callable): - out_type = core.get_aval(0.).to_tangent_aval() - return _custom_transpose(out_type, example_out) - return partial( - _custom_transpose, - jax.tree.map( - lambda x: core.get_aval(x).to_tangent_aval(), example_out)) - - -class CustomTransposeTest(jtu.JaxTestCase): - - def test_linear_call(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / r - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_linear_call_incorrect_transpose(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / (2. * r) # nb: not the true transpose - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_linear_call_transpose_transpose_transpose(self): - def fn(r, x): return x / r - def tp(r, t): return t / (2. * r) # nb: untrue transpose - def f_(x, y): - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f = lambda x: f_(x, y) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - self.assertAllClose(ft(x), x + tp(y, x)) - self.assertAllClose(f(x), ftt(x)) - self.assertAllClose(ft(x), fttt(x)) - - def test_linear_call_scalar_to_vector(self): - def f(c, x): - def fn(_, x): - return [x, x] - - def tp(_, t): - t1, t2 = t - return t1 + t2 - - return jax.custom_derivatives.linear_call(fn, tp, (), c * x) - - def f_ref(c, x): - return [c * x, c * x] - - c, x = 2., 3. - t = [4., 5.] - self.assertAllClose(f(c, x), f_ref(c, x)) - self.assertAllClose(transpose_unary(partial(f, c), x)(t), - transpose_unary(partial(f_ref, c), x)(t)) - - def test_linear_call_nested(self): - # identity function with an untrue transpose of 0 - def id_(x): - def f(_, x): return x - def t(_, t): return 0. - return jax.custom_derivatives.linear_call(f, t, (), x) - - # identity function with an untrue transpose of 7, and where both - # forward and transpose have custom transpositions that should - # never end up invoked. - def f(x): - def f_(_, x): return id_(x) - def t_(_, t): return id_(7.) - return jax.custom_derivatives.linear_call(f_, t_, (), x) - - x = 5. - id_t = transpose_unary(id_, x) - id_tt = transpose_unary(id_t, x) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - - self.assertAllClose(id_(x), x) - self.assertAllClose(id_t(x), 0.) - self.assertAllClose(id_tt(x), x) - - self.assertAllClose(f(x), x) - self.assertAllClose(ft(x), 7.) - self.assertAllClose(ftt(x), x) - self.assertAllClose(fttt(x), 7.) 
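The linear_call tests above all exercise the same contract: jax.custom_derivatives.linear_call(fn, tp, residual_args, linear_args) applies fn on the forward pass and substitutes tp, without checking it, wherever the computation is transposed, which is why the "incorrect transpose" tests can observe a doubled divisor. A correct tp satisfies the defining property of a transpose, <ct, fn(r, x)> == <tp(r, ct), x>. A minimal self-contained sketch of that property (the names scale, scale_t, and f are illustrative, not taken from this patch):

    import jax
    import jax.custom_derivatives
    import jax.numpy as jnp

    def scale(r, x):
      return x / r        # linear in x for a fixed residual r

    def scale_t(r, ct):
      return ct / r       # the true transpose of x -> x / r

    def f(x):
      # forward pass runs `scale`; transposition is routed through `scale_t`
      return jax.custom_derivatives.linear_call(scale, scale_t, 3., x)

    x = jnp.ones(2) * 6.
    y = f(x)
    ct, = jax.linear_transpose(f, x)(y)   # applies scale_t under the hood
    # inner-product identity that a correct transpose must satisfy:
    assert jnp.allclose(jnp.vdot(y, f(x)), jnp.vdot(ct, x))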
- - def test_linear_call_jit(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / r - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f1 = lambda x: f(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - jax.jit(transpose_unary(f1, x))(x)) - - def test_linear_call_type_mismatch(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return None - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f1 = lambda x: f(x, y) - with self.assertRaisesRegex(TypeError, "transpose output pytree"): - transpose_unary(f1, x)(x) - - def test_linear_call_recursion(self): - def f(x): - def fn(_, x): return x - def tp(_, t): return f(t) - return jax.custom_derivatives.linear_call(fn, tp, None, x) - jax.jit(f)(0.1) - - def test_linear_call_grad(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / r - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.array(6.) - y = jnp.array(3.) - self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_ref)(x, y)) - - def test_basic(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - - return x + fn(y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_incorrect_transpose(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / (2. * r) # nb: not the true transpose - - return x + fn(y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_transpose_transpose_transpose(self): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @custom_transpose(jnp.ones(2)) - def tp(r, t): return t / (2. * r) # nb: untrue transpose - - fn.def_transpose(tp) - tp.def_transpose(fn) - - def f_(x, y): - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f = lambda x: f_(x, y) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - self.assertAllClose(ft(x), x + tp(y, x)) - self.assertAllClose(f(x), ftt(x)) - self.assertAllClose(ft(x), fttt(x)) - - def test_scalar_to_vector(self): - def f(c, x): - @custom_transpose([0., 0.]) - def fn(_, x): - return [x, x] - - @fn.def_transpose - def tp(_, t): - t1, t2 = t - return t1 + t2 - - return fn((), c * x) - - def f_ref(c, x): - return [c * x, c * x] - - c, x = 2., 3. - t = [4., 5.] - self.assertAllClose(f(c, x), f_ref(c, x)) - self.assertAllClose(transpose_unary(partial(f, c), x)(t), - transpose_unary(partial(f_ref, c), x)(t)) - - def test_nested(self): - # identity function with an untrue transpose of 0 - def id_(x): - f = custom_transpose(lambda _, x: x) - t = custom_transpose(lambda _, t: 0.) 
- f.def_transpose(t) - t.def_transpose(f) - return f((), x) - - # identity function with an untrue transpose of 7, and where both - # forward and transpose have custom transpositions that should - # never end up invoked. - def f(x): - f_ = custom_transpose(lambda _, x: id_(x)) - t_ = custom_transpose(lambda _, t: id_(7.)) - f_.def_transpose(t_) - t_.def_transpose(f_) - return f_((), x) - - x = 5. - id_t = transpose_unary(id_, x) - id_tt = transpose_unary(id_t, x) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - - self.assertAllClose(id_(x), x) - self.assertAllClose(id_t(x), 0.) - self.assertAllClose(id_tt(x), x) - - self.assertAllClose(f(x), x) - self.assertAllClose(ft(x), 7.) - self.assertAllClose(ftt(x), x) - self.assertAllClose(fttt(x), 7.) - - def test_one_degree(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - @f.def_transpose - def ft(_, z): return 3. * z - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(3., T(T(f))(1.)) - self.assertAllClose(3., T(T(T(f)))(1.)) - self.assertAllClose(3., T(T(T(T(f))))(1.)) # ... - - def test_two_degrees(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - - @f.def_transpose - @custom_transpose - def ft(_, z): return 3. * z - - @ft.def_transpose - def ftt(_, z): return 7. * z - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(7., T(T(f))(1.)) - self.assertAllClose(7., T(T(T(f)))(1.)) - self.assertAllClose(7., T(T(T(T(f))))(1.)) # ... - - def test_symmetric(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - @custom_transpose - def g(_, z): return 3. * z - - f.def_transpose(g) - g.def_transpose(f) - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(2., T(T(f))(1.)) - self.assertAllClose(3., T(T(T(f)))(1.)) - self.assertAllClose(2., T(T(T(T(f))))(1.)) # ... - - def test_recursive(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(c, z): return c * z - - @f.def_transpose - def ft(c, z): return f(c + 1., z) - - g = partial(f, 1.) - self.assertAllClose(1., g(1.)) - self.assertAllClose(2., T(g)(1.)) - self.assertAllClose(3., T(T(g))(1.)) - self.assertAllClose(4., T(T(T(g)))(1.)) - self.assertAllClose(5., T(T(T(T(g))))(1.)) # ... - - def test_jvp_lin(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, tx = 6., 3., 1. - g = lambda x: f(x, y) - g_ref = lambda x: f_ref(x, y) - self.assertAllClose(api.jvp(g, [x], [tx]), api.jvp(g_ref, [x], [tx])) - - def test_jvp_res(self): - raise unittest.SkipTest('unimplemented') # TODO(frostig) - - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, ty = 6., 3., 1. 
- g = lambda y: f(x, y) - g_ref = lambda y: f_ref(x, y) - self.assertAllClose(api.jvp(g, [y], [ty]), api.jvp(g_ref, [y], [ty])) - - def test_jvp_both(self): - raise unittest.SkipTest('unimplemented') # TODO(frostig) - - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, tx, ty = 6., 3., 1., 1. - self.assertAllClose(api.jvp(f, [x, y], [tx, ty]), - api.jvp(f_ref, [x, y], [tx, ty])) - - def test_make_jaxpr(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - - jaxpr = api.make_jaxpr(f_)(x) - self.assertIn('custom_transpose_call', str(jaxpr)) - - jaxpr_t = api.make_jaxpr(f_t)(x) - self.assertNotIn('custom_transpose_call', str(jaxpr_t)) - - def test_jit(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = jax.jit(f_) - g_t = transpose_unary(g_, x) - self.assertAllClose(f_(x), jax.jit(f_)(x)) - self.assertAllClose(f_t(x), jax.jit(f_t)(x)) - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_jit_recursive(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * fn(r, t) - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = jax.jit(f_) - g_t = transpose_unary(g_, x) - self.assertAllClose(f_(x), jax.jit(f_)(x)) - self.assertAllClose(f_t(x), jax.jit(f_t)(x)) - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_jit_signature_deprecation(self): - fun = lambda x: x - if deprecations.is_accelerated('jax-jit-positional-args'): - with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'): - jax.jit(fun=fun) - with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'): - jax.jit(fun, None) - else: - with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'): - jax.jit(fun=fun) - with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'): - jax.jit(fun, None) - - def test_cond(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) - - i = 7. - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. 
- - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = partial(cond_wrap(f_), i) - g_t = transpose_unary(g_, x) - - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_cond_recursive(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * fn(r, t) - - return x + fn(y, x) - - def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) - - i = 7. - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = partial(cond_wrap(f_), i) - g_t = transpose_unary(g_, x) - - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_compose_custom_jvp(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - dx, = tangents - return f(x), g(x, dx) - - @custom_transpose - def g(x, dx): - return jnp.cos(x) * dx - - @g.def_transpose - def gt(x, t): - return jnp.cos(x) * t - - with config.use_direct_linearize(True): - self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) - - -class CustomDceTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - jnp.exp(x) if used_outs[0] else None, - jnp.sqrt(x) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), jnp.exp(x)) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), jnp.sqrt(x)) - - def test_recursive(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.exp(x), 10 * jnp.sqrt(x) - - @f.def_dce - def f_dce(used_outs, x): - return [2 * v if used else None for used, v in zip(used_outs, f(x))] - - x = 1.1234 - expected = f(x) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), 2 * expected[1]) - - def test_multiple_rounds(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y, z): - return jnp.sin(x), jnp.sin(y), jnp.sin(z) - - @f.def_dce - def rule(used_outs, x, y, z): - patterns.append(used_outs) - outs = [ - jnp.cos(v) if used else None for used, v in zip(used_outs, (x, y, z)) - ] - return outs - - patterns = [] - x, y, z = jnp.array(1.), jnp.array(2.), jnp.array(3.) - jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr - new_jaxpr, used_ins = pe.dce_jaxpr(jaxpr, [True, False, True]) - assert used_ins == [True, False, True] - new_jaxpr, used_ins = pe.dce_jaxpr(new_jaxpr, [True, False]) - assert used_ins == [True, False] - assert patterns == [(True, False, True), (True, False, False)], patterns - - def test_batching(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y): - return jnp.sin(x), jnp.sin(y) - - @f.def_dce - def rule(used_outs, x, y): - return ( - jnp.cos(x) if used_outs[0] else None, - jnp.cos(y) if used_outs[1] else None, - ) - - x = jnp.linspace(-0.1, 0.2, 5) - y = jnp.linspace(3.0, 4.0, 5) - self.assertAllClose(jax.vmap(f)(x, y), f(x, y)) - self.assertAllClose( - jax.jit(lambda *args: jax.vmap(f)(*args)[0])(x, y), jnp.cos(x) - ) - self.assertAllClose( - jax.vmap(jax.jit(lambda *args: f(*args)[0]))(x, y), jnp.cos(x) - ) - self.assertAllClose( - jax.jit(lambda *args: jax.vmap(f)(*args)[1])(x, y), jnp.cos(y) - ) - self.assertAllClose( - jax.vmap(jax.jit(lambda *args: f(*args)[1]))(x, y), jnp.cos(y) - ) - - def test_composes_with_custom_vjp(self): - # custom_dce must be the "outer" decorator (for now!) 
because custom_vjp - # doesn't pass through DCE. - @jax.experimental.custom_dce.custom_dce - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y, x * jnp.sin(y) - - @f.def_dce - def f_dce_rule(used_outs, x, y): - return ( - jnp.cos(x) * y if used_outs[0] else None, - x * jnp.cos(y) if used_outs[1] else None, - ) - - def f_fwd(x, y): - return f(x, y), (x, jnp.cos(x), jnp.sin(x), y, jnp.cos(y), jnp.sin(y)) - - def f_bwd(res, g): - ga, gb = g - x, cos_x, sin_x, y, cos_y, sin_y = res - return (cos_x * ga * y + sin_y * gb, sin_x * ga + x * cos_y * gb) - - f.defvjp(f_fwd, f_bwd) - - x, y = jnp.array(1.), jnp.array(2.) - self.assertAllClose(jax.jit(lambda *args: f(*args)[0])(x, y), - jnp.cos(x) * y) - jax.grad(lambda *args: f(*args)[0])(x, y) # Doesn't crash. - - def test_can_optimize_remat(self): - @jax.custom_vjp - def f(x): - return jnp.tan(x) - - @jax.experimental.custom_dce.custom_dce - def f_fwd(x): - return jnp.sin(x), (x,) - - @f_fwd.def_dce - def f_dce_rule(used_outs, x): - used_prim, used_res = used_outs - used_res, = used_res - if not used_res: - return f(x), None - prim, res = f_fwd(x) - return prim if used_prim else None, res - - def f_bwd(res, g): - x, = res - cos_x = jnp.cos(x) - return (cos_x * g,) - - f.defvjp(f_fwd, f_bwd) - - def temp(x): - out = jax.remat(f)(x) - out = out ** 2 - return out - - v, g = jax.value_and_grad(temp)(3.2) - self.assertAllClose(v, jnp.tan(3.2)**2) - - def test_static_argnums(self): - @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) - def g(f, x): - return f(x), 10 * f(x) - - @g.def_dce - def g_dce(f, used_outs, x): # note: static_argnums are always passed first - self.assertTrue(callable(f)) - return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] - - x = 1.1234 - f = lambda x: jnp.exp(x) - expected = g(f, x) - self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) - - def test_shape_mismatch_error(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.stack((x, x)), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - jnp.exp(x) if used_outs[0] else None, - x.astype(jnp.int32) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* same shapes/dtypes .* output\[0\]', - ): - jax.jit(lambda x: f(x)[0])(x) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* same shapes/dtypes .* output\[1\]', - ): - jax.jit(lambda x: f(x)[1])(x) - - def test_missing_output_error(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return None, None - - x = jnp.array(1.1234) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* produce values for all .* output\[0\]', - ): - jax.jit(lambda x: f(x)[0])(x) - - def test_consts(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return np.eye(1) * jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None, - jnp.sqrt(x) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - expected = rule([True, True], x) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) - - def test_resolve_kwargs_error_message(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y, *, z=None): - return jnp.sin(x) * y, x * jnp.sin(y) - -
@f.def_dce - def f_dce_rule(used_outs, x, y): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_dce-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_dce-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomVmapTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, jnp.cos(xs)) - - @jax.numpy_dtype_promotion('standard') - def test_closure(self): - z = jnp.array([2., 1., 3.]) - - @jax.custom_batching.custom_vmap - def f(x): return z + jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, *args): - self.assertEqual(len(in_batched), 1) - self.assertEqual(len(args), 1) - xs, = args - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return z + jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x) - self.assertAllClose(y, z + jnp.sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, z + jnp.cos(xs)) - - def test_rule_multi_output(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x), jnp.cos(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2) - - x, xs = jnp.array(1.), jnp.arange(3) - y1, y2 = f(x) - self.assertAllClose(y1, jnp.sin(x)) - self.assertAllClose(y2, jnp.cos(x)) - ys1, ys2 = api.vmap(f)(xs) - self.assertAllClose(ys1, jnp.cos(xs)) - self.assertAllClose(ys2, jnp.sin(xs)) - - def test_nary(self): - @jax.custom_batching.custom_vmap - def f(x, y): return jnp.sin(x) + y ** 2. - - @f.def_vmap - def rule(axis_size, in_batched, xs, ys): - self.assertEqual(in_batched, [True, True]) - self.assertEqual(axis_size, 3) - self.assertEqual(axis_size, xs.shape[0]) - self.assertEqual(axis_size, ys.shape[0]) - return jnp.cos(xs) + ys ** 2., True - - xs, ys = jnp.arange(3.0), jnp.arange(3.0) - zs = api.vmap(f)(xs, ys) - self.assertAllClose(zs, jnp.cos(xs) + ys ** 2.) 
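Every rule above follows the signature contract that test_rule_input_signature pins down: rule(axis_size, in_batched, *args) -> (outs, out_batched), where in_batched holds one boolean per argument, batched arguments arrive with the mapped axis moved to the front, and out_batched must mirror the output pytree. A minimal self-contained sketch (normalize is an illustrative name, not part of this patch):

    import jax
    import jax.custom_batching
    import jax.numpy as jnp

    @jax.custom_batching.custom_vmap
    def normalize(x):
      return x / jnp.maximum(jnp.linalg.norm(x), 1.0)

    @normalize.def_vmap
    def normalize_vmap(axis_size, in_batched, xs):
      x_batched, = in_batched           # [True] here: xs carries the batch axis
      assert axis_size == xs.shape[0]   # batch axis is always at the front
      norms = jnp.linalg.norm(xs, axis=1, keepdims=True)
      return xs / jnp.maximum(norms, 1.0), x_batched

    xs = jnp.arange(6.).reshape(3, 2)
    ys = jax.vmap(normalize)(xs)   # dispatches to normalize_vmap, not per-row tracing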
- - def test_nary_mixed_batching(self): - @jax.custom_batching.custom_vmap - def vector_dot(u, v): - self.assertEqual(u.ndim, 1) - self.assertEqual(v.ndim, 1) - return u @ v - - size = 4 - vlen = 3 - in_batched_log = [] - - @vector_dot.def_vmap - def vector_dot_vmap_rule(axis_size, in_batched, u, v): - in_batched_log.append(in_batched) - self.assertEqual(axis_size, size) - u_batched, v_batched = in_batched - if u_batched: - self.assertEqual(u.ndim, 2) - self.assertEqual(u.shape[0], size) - else: - self.assertEqual(u.ndim, 1) - self.assertEqual(u.shape[0], vlen) - if v_batched: - self.assertEqual(v.ndim, 2) - self.assertEqual(v.shape[0], size) - else: - self.assertEqual(v.ndim, 1) - self.assertEqual(v.shape[0], vlen) - if u_batched and v_batched: - out = jnp.sum(u * v, axis=1) - else: - out = u @ v if u_batched else v @ u - return out, u_batched or v_batched - - f = vector_dot - v = lambda *shape: jnp.ones(shape) - - y = api.vmap(f, in_axes=(0, None))(v(4, 3), v(3)) - self.assertAllClose(y, v(4, 3) @ v(3)) - y = api.vmap(f, in_axes=(1, None))(v(3, 4), v(3)) - self.assertAllClose(y, v(3, 4).T @ v(3)) - y = api.vmap(f, in_axes=(None, 0))(v(3), v(4, 3)) - self.assertAllClose(y, v(3) @ v(4, 3).T) - y = api.vmap(f, in_axes=(0, 0))(v(4, 3), v(4, 3)) - self.assertAllClose(y, jnp.sum(v(4, 3) * v(4, 3), axis=1)) - self.assertEqual(in_batched_log[0], [True, False]) - self.assertEqual(in_batched_log[1], [True, False]) - self.assertEqual(in_batched_log[2], [False, True]) - self.assertEqual(in_batched_log[3], [True, True]) - - def test_rule_input_signature(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - rule_args = [] - - @f.def_vmap - def rule(axis_size, in_batched, xs): - rule_args.append((axis_size, in_batched)) - return jnp.cos(xs), in_batched[0] - - xs = jnp.arange(3) - _ = api.vmap(f)(xs) - (axis_size, in_batched), = rule_args - self.assertIs(type(axis_size), int) - self.assertIs(type(in_batched), list) - self.assertEqual(len(in_batched), 1) - - def test_rule_output_vs_batching_output_mismatch(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def test_rule_abc(axis_size, in_batched, xs): - return [jnp.sin(xs), jnp.cos(xs)], in_batched - - xs = jnp.arange(3) - self.assertRaisesRegex( - ValueError, - 'structure of output value and output batching specification ' - r'returned by custom vmap rule \(test_rule_abc\) do not match.*', - lambda: api.vmap(f)(xs)) - - def test_rule_vs_call_output_mismatch(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def test_rule_abc2(axis_size, in_batched, xs): - return [jnp.sin(xs)], in_batched - - xs = jnp.arange(3) - self.assertRaisesRegex( - ValueError, - r'structure of output returned by custom vmap rule \(test_rule_abc2\) ' - r'does not match that of original custom-vmapped function.*', - lambda: api.vmap(f)(xs)) - - def test_jvp_basic(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - x, tx = jnp.array(1.), jnp.array(2.) - xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. 
- - y, ty = f_jvp(x, tx) - self.assertAllClose(y, jnp.sin(x)) - self.assertAllClose(ty, jnp.cos(x) * tx) - - ys, tys = api.vmap(f_jvp)(xs, txs) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * txs) - - ys, tys = api.jvp(api.vmap(f), [xs], [txs]) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * txs) - - @jax.numpy_dtype_promotion('standard') - def test_jvp_closure(self): - z = jnp.array([2., 1., 3.]) - def bcast(x): return z + x - z - - @jax.custom_batching.custom_vmap - def f(x): return z + jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True]) - return z + jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - x, tx = jnp.array(1.), jnp.array(2.) - xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. - - y, ty = f_jvp(x, tx) - self.assertAllClose(y, z + jnp.sin(x)) - self.assertAllClose(ty, bcast(jnp.cos(x)) * tx) - - ys, tys = api.vmap(f_jvp)(xs, txs) - self.assertAllClose(ys, z + jnp.cos(xs)) - self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) - - ys, tys = api.jvp(api.vmap(f), [xs], [txs]) - self.assertAllClose(ys, z + jnp.cos(xs)) - self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) - - def test_jvp_nary(self): - @jax.custom_batching.custom_vmap - def f(x, y): return jnp.sin(x) + y - - @f.def_vmap - def rule(axis_size, in_batched, xs, ys): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True, True]) - return jnp.cos(xs) + ys, True - - f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty]) - - x, y, tx, ty = jnp.arange(4.) - xs, ys, txs, tys = 4. + jnp.arange(3. * 4).reshape((4, 3)) - - zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys) - self.assertAllClose(zs, jnp.cos(xs) + ys) - self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) - - zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys]) - self.assertAllClose(zs, jnp.cos(xs) + ys) - self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) - - def test_jvp_extra_batched_tangents(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - txs = 2. + jnp.arange(3.) - x = jnp.array(1, dtype=txs.dtype) - y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs) - self.assertAllClose(y, jnp.cos(x)) - self.assertAllClose(tys, -jnp.sin(x) * txs) - - def test_jacfwd(self): - # jacfwd is another way to exercise extra-batched tangents - - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - x = jnp.arange(3.) + .72 - j = api.jacfwd(f)(x) - self.assertAllClose(j, -jnp.diag(jnp.sin(x))) - - def test_jvp_extra_batched_primals(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - xs = jnp.arange(3.) 
- tx = jnp.array(4, dtype=xs.dtype) - ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * tx) - - def test_jvp_extra_batched_primals_with_linear_vmap_rule(self): - # When a function is linear, its Jacobian is constant. JAX's JVP - # of linear functions takes advantage of this: when mapping over a - # batch of primals relative to a fixed (i.e. symbolically - # replicated) tangent, output tangents remain replicated as well - # (i.e. JAX will not broadcast them). This is true in general, and - # this test checks that vmapped JVPs continue to behave this way - # when custom_vmap is involved and the custom vmap rule is linear. - - @jax.custom_batching.custom_vmap - def f_linear(x): return 7. * x - - @f_linear.def_vmap - def linear_rule(axis_size, in_batched, xs): - return 11. * xs, in_batched[0] - - @jax.custom_batching.custom_vmap - def f_nonlinear(x): return jnp.sin(x) - - @f_nonlinear.def_vmap - def nonlinear_rule(axis_size, in_batched, xs): - return jnp.cos(xs), in_batched[0] - - f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx]) - f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx]) - xs = jnp.arange(3.) - tx = jnp.array(4., dtype=xs.dtype) - - # doesn't err - _ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx) - - # does err - self.assertRaisesRegex( - ValueError, "at vmap out_axes", - lambda: api.vmap( - f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)) - - def test_jvp_dataflow_violation(self): - # The jvp-of-custom-vmap machinery should not assume the standard - # dataflow constraint on the JVP of the custom vmap rule (primal - # outputs independent of tangent inputs). Both jvp and vmap are - # "forward" transformations under which, at present, we don't - # enforce the JVP dependence diagram. Because output primals can - # depend on input tangents, extra-batched input tangents can - # create batched output primals, as this test checks. - - @jax.custom_jvp - def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x) - - @cos_with_invalid_dataflow_jvp.defjvp - def invalid_dataflow_jvp(x, tx): - [x], [tx] = x, tx - return jnp.cos(x * tx), tx - - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - return cos_with_invalid_dataflow_jvp(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - txs = 2. + jnp.arange(3.) 
- x = jnp.array(1, dtype=txs.dtype) - - # doesn't err - ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs) - self.assertAllClose(ys, jnp.cos(x * txs)) - self.assertAllClose(tys, txs) - - # does err - self.assertRaisesRegex( - ValueError, "at vmap out_axes", - lambda: api.vmap( - f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)) - - def test_tree(self): - tree_sin = partial(jax.tree.map, jnp.sin) - tree_cos = partial(jax.tree.map, jnp.cos) - - x, xs = jnp.array(1.), jnp.arange(3) - x = (x, [x + 1, x + 2], [x + 3], x + 4) - xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4) - in_batched_ref = jax.tree.map(lambda _: True, x) - - @jax.custom_batching.custom_vmap - def f(xs): return tree_sin(xs) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [in_batched_ref]) - sz, = {z.shape[0] for z in jax.tree.leaves(xs)} - self.assertEqual(axis_size, sz) - return tree_cos(xs), in_batched[0] - - y = f(x) - self.assertAllClose(y, tree_sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, tree_cos(xs)) - - def test_tree_with_nones(self): - tree_sin = partial(jax.tree.map, jnp.sin) - tree_cos = partial(jax.tree.map, jnp.cos) - - x, xs = jnp.array(1.), jnp.arange(3) - x = (x, [x + 1, None], [x + 3], None) - xs = (xs, [xs + 1, None], [xs + 3], None) - in_batched_ref = jax.tree.map(lambda _: True, x) - - @jax.custom_batching.custom_vmap - def f(xs): return tree_sin(xs) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [in_batched_ref]) - sz, = {z.shape[0] for z in jax.tree.leaves(xs)} - self.assertEqual(axis_size, sz) - return tree_cos(xs), in_batched[0] - - y = f(x) - self.assertAllClose(y, tree_sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, tree_cos(xs)) - - def test_jit(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [True]) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), in_batched[0] - - x, xs = jnp.array(1.), jnp.arange(3) - self.assertAllClose(f(x), jit(f)(x)) - self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs)) - self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) - - def test_sequential_vmap_basic(self): - @jax.custom_batching.sequential_vmap - def f(x): - return x + 1. - - def vmap_ref(xs): - return lax.map(f, xs) - - xs = jnp.arange(3.) - jaxpr = api.make_jaxpr(api.vmap(f))(xs) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - def test_sequential_vmap_nary_same_batching(self): - @jax.custom_batching.sequential_vmap - def f(x, y): - return x + y - - def vmap_ref(xs, ys): - return lax.map(lambda args: f(*args), (xs, ys)) - - xs, ys = jnp.arange(3.), 4. + jnp.arange(3.) - jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - def test_sequential_vmap_nary_mixed_batching(self): - @jax.custom_batching.sequential_vmap - def f(x, y): - return x + y - - def vmap_ref(xs, y): - return lax.map(lambda x: f(x, y), xs) - - xs, y = jnp.arange(3.), 4. 
- jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - @parameterized.named_parameters( - ("1", 1), - ("4", 4), - ("8", 8), - ("16", 16), - ) - def test_batch_map_basic(self, batch_size: int): - def f(x): - self.assertEqual(x.shape, ()) - return x**2 - - x = np.arange(16) - y = jax.lax.map(f, x, batch_size=batch_size) - - np.testing.assert_array_equal(y, x**2) - - @parameterized.named_parameters( - ("1", 1), - ("4", 4), - ("8", 8), - ("16", 16), - ) - def test_batch_map_pytrees(self, batch_size: int): - f = lambda x: {'b': x['a'] ** 2} - inputs = {'a': np.arange(16)} - expected = np.arange(16) ** 2 - - outputs = jax.lax.map(f, inputs, batch_size=batch_size) - self.assertAllClose(outputs['b'], expected) - - outputs = jax.lax.map( - f, inputs, batch_size=batch_size - ) - self.assertAllClose(outputs['b'], expected) - - def test_batch_divides_axis(self): - def f(t): - x, a = t - self.assertEqual(x.shape, (4,)) - return (x + a)**2 - - x = jax.random.randint(jax.random.key(0), (16, 4), -10, 10) - a = jax.random.randint(jax.random.key(1), (16, 4), -10, 10) - - @jax.jit - def g(x, a): - return jax.lax.map(f, (x, a), batch_size=8) - - y = g(x, a) - - self.assertAllClose(y, (x + a)**2) - - def test_undefined_rule(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - with self.assertRaisesRegex( - AttributeError, "No batching rule defined for custom_vmap function f"): - f(0.5) - - def test_kwargs(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x=x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(x=xs) - self.assertAllClose(ys, jnp.cos(xs)) - - def test_partial_eval_raises(self): - @jax.custom_batching.custom_vmap - def f(x): - return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - del axis_size # unused - return jnp.cos(xs), in_batched[0] - - with self.assertRaisesRegex( - ValueError, - "Linearization failed to produce known values for all output primals", - ): - jax.grad(f)(0.5) - - def test_compose_custom_vjp(self): - @jax.custom_vjp - @jax.custom_batching.custom_vmap - def f(x, y): - return jnp.sin(x) * y - - @f.def_vmap - def f_vmap_rule(axis_size, in_batched, xs, ys): - return jnp.cos(xs) * ys, True - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd) - - xs = jnp.linspace(0, 1, 5) - ys = jnp.linspace(-0.1, 0.1, 5) - self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) - jax.grad(f)(xs[0], ys[0]) # Doesn't crash. - - def test_compose_custom_vjp_bwd_rule(self): - # This tests the case where both the forward and backward rules are wrapped - # in custom_vmap. - @jax.custom_batching.sequential_vmap - def fun_fwd(x, y): - return jnp.sin(x) * y, (x, y) - - @jax.custom_batching.sequential_vmap - def fun_bwd(res, ct): - x, y = res - return x * ct, y * ct - - fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) - fun.defvjp(fun_fwd, fun_bwd) - - xs = jnp.linspace(0, 1, 5) - y = jnp.array(0.5, dtype=xs.dtype) - f = jax.vmap(jax.jit(fun), in_axes=(0, None)) - out, f_vjp = jax.vjp(f, xs, y) - f_vjp(out) # Doesn't crash.
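The test above leans on jax.custom_batching.sequential_vmap, which is custom_vmap with a canned rule: under vmap the function is applied one batch element at a time with lax.map (a scan) rather than being vectorized, as test_sequential_vmap_basic checks by comparing jaxprs. A minimal sketch of the observable behavior:

    import jax
    import jax.custom_batching
    import jax.numpy as jnp

    @jax.custom_batching.sequential_vmap
    def step(x):
      return jnp.sin(x) + 1.0

    xs = jnp.arange(4.)
    ys = jax.vmap(step)(xs)   # lowers to lax.map(step, xs), i.e. a scan over xs
    assert jnp.allclose(ys, jnp.sin(xs) + 1.0)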
- - def test_resolve_kwargs_error_message(self): - @jax.custom_batching.custom_vmap - def f(x, y, *, z=None): - return jnp.sin(x) * y - - @f.def_vmap - def f_vmap_rule(axis_size, in_batched, xs, ys): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vmap-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vmap-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomApiTest(jtu.JaxTestCase): - """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" - - def test_method_forwarding(self): - @jax.custom_batching.custom_vmap - @jax.custom_jvp - @jax.custom_transpose.custom_transpose - def f(x): return 2. * x - - # none of these err: - @f.def_vmap - def f_batch(sz, b, xs): return 2. * xs - @f.defjvp - def f_jvp(x, tx): return 2. * x, 2. * tx - @f.def_transpose - def f_transpose(x): return 2. * x - - def test_def_method_forwarding_all_permutations(self): - for wraps in it.permutations([ - jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): - f = lambda x: x + 1. - for wrap in wraps: - f = wrap(f) - for methods in it.permutations(['defjvp', 'def_vmap', 'def_transpose']): - for method in methods: - self.assertIsInstance(getattr(f, method), Callable) - - for decorators in it.permutations([ - jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): - f = lambda x: x + 1. - for decorator in decorators: - f = decorator(f) - for methods in it.permutations(['defvjp', 'def_vmap', 'def_transpose']): - for method in methods: - self.assertIsInstance(getattr(f, method), Callable) - - class BufferDonationTest(jtu.BufferDonationTestCase): @jtu.device_supports_buffer_donation() diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py new file mode 100644 index 000000000000..72c14634a9c8 --- /dev/null +++ b/tests/custom_api_test.py @@ -0,0 +1,4625 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import collections +from collections.abc import Callable +import concurrent.futures +import functools +from functools import partial +import itertools as it +import re +import unittest +import textwrap + +from absl.testing import absltest, parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax import float0, grad, jit +from jax import lax +from jax import tree_util +from jax.ad_checkpoint import checkpoint as new_checkpoint +import jax.custom_batching +import jax.custom_derivatives +import jax.custom_transpose +import jax.experimental.custom_dce +from jax.errors import UnexpectedTracerError + +from jax._src import api +from jax._src import api_util +from jax._src import config +from jax._src import core +from jax._src import custom_derivatives +from jax._src import deprecations +from jax._src import test_util as jtu +from jax._src.interpreters import partial_eval as pe + +config.parse_flags_with_absl() + + +class CustomJVPTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = 3. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(api.jvp(f, (x,), (1.,)), + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) + + def test_invariance(self): + @jax.custom_jvp + def f(x): + return jnp.cos(2 * x) / 2. + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return (f(x), 3 * g) + f.defjvp(f_jvp) + def f2(x): + y, _ = api.jvp(f, (x,), (x,)) + return y + def f3(x): + y, _ = api.jvp(f2, (x,), (x,)) + return y + x = 1. + self.assertAllClose(api.jvp(f, (x,), (x,)), + api.jvp(f2, (x,), (x,)), + check_dtypes=False) + self.assertAllClose(api.jvp(f, (x,), (x,)), + api.jvp(f3, (x,), (x,)), + check_dtypes=False) + + def test_python_control_flow(self): + @jax.custom_jvp + def f(x): + if x > 0: + return jnp.sin(x) + else: + return jnp.cos(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + if x > 0: + return f(x), 2 * g + else: + return f(x), 3 * g + f.defjvp(f_jvp) + x = 2. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) + self.assertAllClose(api.jvp(f, (x,), (1.,)), + (jnp.sin(x), 2.), + check_dtypes=False) + self.assertAllClose(api.jvp(f, (-x,), (1.,)), + (jnp.cos(-x), 3.), + check_dtypes=False) + self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False) + self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False) + + def test_vmap(self): + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + assert jnp.ndim(x) == jnp.ndim(g) == 0 + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = jnp.arange(3.) 
+ xx = jnp.arange(6.).reshape(2, 3) + + # vmap of f + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) + + # vmap of jvp of f + self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x), + (jnp.sin(x), 2 * jnp.cos(x) * x)) + self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + # jvp of vmap of f + self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)), + (jnp.sin(x), 2 * jnp.cos(x) * x)) + self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + # vmap of jvp of vmap of f + self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + def test_jit(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = 3. + + # jit + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + + # jit of jvp + self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x), + (jnp.sin(x), 2 * jnp.cos(x) * x), + check_dtypes=False) + + # jvp of jit + self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)), + (jnp.sin(x), 2 * jnp.cos(x) * x), + check_dtypes=False) + + def test_pytrees(self): + @jax.custom_jvp + def f(x): + return {'b': jnp.sin(x['a'])} + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']} + f.defjvp(f_jvp) + x = {'a': 3.} + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) + self.assertAllClose(api.jvp(f, (x,), (x,)), + ({'b': jnp.sin(x['a'])}, + {'b': 2 * jnp.cos(x['a']) * x['a']}), + check_dtypes=False) + + def test_kwargs(self): + # from https://github.com/jax-ml/jax/issues/1938 + @jax.custom_jvp + def my_fun(x, y, c=1.): + return c * (x + y) + def my_jvp(primals, tangents): + x, y, c = primals + t_x, t_y, t_c = tangents + return my_fun(x, y, c), t_c + my_fun.defjvp(my_jvp) + f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() + f(10., 5.) # doesn't crash + api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash + + def test_initial_style(self): + @jax.custom_jvp + def f(x): + return 3 * x + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(foo))(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(foo))(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(foo))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.jit(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(api.grad(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(api.grad(foo)))(3.) + expected = 0. 
+ self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap(self): + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.vmap(api.jit(foo))(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.vmap(foo))(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_with_collective(self): + + @jax.custom_jvp + def f(x): + return lax.psum(x, 'foo') + + @f.defjvp + def f_jvp(xs, ts): + x, = xs + t, = ts + return lax.psum(x, 'foo'), t + + def g(x): + jaxpr = api.make_jaxpr(f)(x) + return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] + + v = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), + out_axes=None)(jnp.arange(4.), 2.) + self.assertAllClose(v, 8.) + + def test_closed_over_tracers_error_message(self): + def f(x): + @jax.custom_jvp + def g(y): + return x + y + def g_jvp(primals, tangents): + return g(x), 2 * primals[0] + g.defjvp(g_jvp) + return g(1.) + + self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) + + def test_nondiff_arg(self): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def app(f, x): + return f(x) + def app_jvp(f, primals, tangents): + (x,), (t,) = primals, tangents + return app(f, x), 3 * t + app.defjvp(app_jvp) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,)) + expected = (2., 3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_jit_tracer(self): + # This test would pass with "final-style" JIT tracing, but that was + # misleading: it doesn't work with "initial-style" staging, i.e. control + # flow primitives like jax.lax.scan or even pjit. The behavior isn't very + # useful either: instead of using nondiff_argnums here, a user can just pass + # such inputs as ordinary arguments, and ignore the corresponding tangents. + # Then nondiff_argnums can be reserved for (1) non jaxtype data (like a + # string- or callable-valued argument which parameterizes the function or + # rule) or (2) static data (e.g. integers which parameterize shapes). 
+ raise unittest.SkipTest("behavior no longer supported") + + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_jvp(x, primals, tangents): + (y,), (t_y,) = primals, tangents + return f(x, y), 5 * t_y + f.defjvp(f_jvp) + + @jit + def g(x, y): + return f(x, y) + + ans = api.jvp(lambda y: g(2., y), (3.,), (1.,)) + expected = (6., 5.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_vmap_tracer(self): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_jvp(x, primals, tangents): + (y,), (t_y,) = primals, tangents + return f(x, y), 5 * t_y + f.defjvp(f_jvp) + + g = jax.vmap(f) + + ans = api.jvp(lambda y: g(jnp.array([2.]), y), + (jnp.array([3.]),), (jnp.array([1.]),)) + expected = (jnp.array([6.]), jnp.array([5.])) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_hiding_jvp_tracer(self): + def f(x): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def g(h, x): + return h(x) + @g.defjvp + def g_jvp(h, primals, tangents): + x, = primals + t, = tangents + return g(h, x), 2. * t + h = lambda y: x + y # capture x + return g(h, x) + + with self.assertRaises(UnexpectedTracerError): + api.jvp(f, (2.,), (1.,)) + + def test_vmap_axes(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_pmap(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_missing_jvp_rule_error_message(self): + @jax.custom_jvp + def foo(x): + return x ** 2 + + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: foo(2)) + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: api.jvp(foo, (2.,), (1.,))) + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: api.grad(foo)(2.)) + + def test_jvp_rule_inconsistent_pytree_structures_error_message(self): + @jax.custom_jvp + def f(x): + return (x**2,) + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), [2 * x * t, x] + + f(2.) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule foo_jvp for function f " + "must produce primal and tangent outputs " + "with equal container (pytree) structures, but got " + "{} and {} respectively.".format( + jax.tree.structure((1,)), + jax.tree.structure([1, 2])) + ), + lambda: api.jvp(f, (2.,), (1.,))) + + def test_primal_tangent_aval_disagreement_error_message(self): + @jax.custom_jvp + def f(x): + return x ** 2 + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), jnp.reshape(t, (1,)) + + f(2.) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule must produce primal and tangent outputs " + "with corresponding shapes and dtypes. " + "Expected float32[] (tangent type of float32[]) but got float32[1]."), + lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) + + + def test_jvp_rule_doesnt_return_pair_error_message(self): + # https://github.com/jax-ml/jax/issues/2516 + + @jax.custom_jvp + def f(x): + return x ** 2 + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return t + + f(2.) 
# doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule foo_jvp for function f " + "must produce a pair (list or tuple of length two) " + "representing primal and tangent outputs, but got 1.0"), + lambda: api.jvp(f, (2.,), (1.,))) + + def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_jvp + def f(x): + return x + + @f.defjvp + def f_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return (x, x), (xdot, xdot) + + x = jnp.float32(1.) + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular of the " + "same container/pytree structure), but instead the JVP rule " + "output's first element had container/pytree structure:\n" + " (float32[], float32[])\n" + "while the custom_jvp-decorated function f had output " + "container/pytree structure:\n" + " float32[]." + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) + + @f.defjvp + def f_jvp2(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.zeros((3, *x.shape), x.dtype), xdot + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular " + "with leaves of the same shape/dtype), but instead the JVP rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_jvp-decorated function f had output shapes/dtypes" + " of:\n" + " float32[]" + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) + + def test_multiple_rule_invocations(self): + @jax.custom_jvp + def expit(x): + return 1 / (1 + lax.exp(-x)) + + @expit.defjvp + def _expit_jvp(primals, tangents): + (x,), (t,) = primals, tangents + ans = expit(x) + t_out = t * ans * (1 - ans) + return ans, t_out + + def scanned_fun(c, _): + return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None + + def foo(x): + zero = jnp.zeros_like(x) + c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) + return c[-1] + + # just make sure these don't crash + foo(3.) + grad(foo)(3.) 
+ grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.)) + + def test_hard_stuff(self): + arr = jnp.ones((5, 2, 2)) + api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash + + def test_hard_stuff2(self): + @jax.custom_jvp + def f(x): + return np.zeros(x.shape, x.dtype) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), t + + # don't crash + jax.jit(jax.vmap(f))(jnp.arange(3.)) + jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) + jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) + jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) + jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)) + + def test_hard_stuff3(self): + @jax.custom_jvp + def relu(x): + return jnp.maximum(x, 0) + + @relu.defjvp + def _relu_jvp(primals, tangents): + x, = primals + t, = tangents + return relu(x), lax.select(x > 0, t, lax.full_like(t, 0)) + + def scanned_fun(c, _): + return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None + + def f(x): + zero = jnp.zeros_like(x) + c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) + return c[-1] + + # don't crash + jax.jit(jax.vmap(f))(jnp.arange(3.)) + jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) + jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) + jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) + jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),)) + + def test_eval_shape(self): + @jax.custom_jvp + def expit(x): + return 1 / (1 + lax.exp(-x)) + + @expit.defjvp + def _expit_jvp(primals, tangents): + (x,), (t,) = primals, tangents + ans = expit(x) + t_out = t * ans * (1 - ans) + return ans, t_out + + # don't crash + api.eval_shape(expit, jnp.ones((2, 3))) + api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) + + def test_jaxpr_zeros(self): + # from https://github.com/jax-ml/jax/issues/2657 + @jax.custom_jvp + def f(A, b): + return A @ b + + def f_jvp(primals, tangents): + A, b = primals + dA, db = tangents + z = f(A, b) + dz = A @ db + dA @ b + return z, dz + + f.defjvp(f_jvp) + + def experiment(theta): + def step(q, _): + z = f(jnp.eye(3), jnp.ones(3) * theta) + q += z[0] + return q, q + + q = 0. + q, _ = lax.scan(step, q, None, 4) + return q + + grad(experiment)(1.) # doesn't crash + + def test_linear_in_scan(self): + @jax.custom_jvp + def f(x): + return -x + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + x_dot, = tangents + return f(x), f(x_dot) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = -1. + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_jvps_first_rule_is_none(self): + # https://github.com/jax-ml/jax/issues/3389 + @jax.custom_jvp + def f(x, y): + return x ** 2 * y + + f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot) + ans = grad(f, 1)(2., 3.) # doesn't crash + expected = 12. 
+ self.assertAllClose(ans, expected, check_dtypes=False) + + def test_concurrent_initial_style(self): + # https://github.com/jax-ml/jax/issues/3843 + def unroll(param, sequence): + def scan_f(prev_state, inputs): + return prev_state, jax.nn.sigmoid(param * inputs) + return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1]) + + def run(): + return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0])) + + expected = run() + + # we just don't want this to crash + n_workers = 2 + with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e: + futures = [] + for _ in range(n_workers): + futures.append(e.submit(run)) + results = [f.result() for f in futures] + for ans in results: + self.assertAllClose(ans, expected) + + def test_nondiff_argnums_vmap_tracer(self): + # https://github.com/jax-ml/jax/issues/3964 + @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) + def sample(shape, param, seed): + return jax.random.uniform(key=seed, shape=shape, minval=param) + + @sample.defjvp + def sample_jvp(shape, seed, primals, tangents): + param, = primals + dparam, = tangents + dparam = jnp.broadcast_to(dparam, shape) + samples = sample(shape, param, seed) + return samples, samples * dparam # dummy jvp for proof of concept + + # check these don't crash + jax.vmap(lambda seed: sample((2,3), 1., seed))( + jax.random.split(jax.random.key(1), 10)) + jax.jvp(lambda x: sample((2, 3), x, jax.random.key(1)), + (1.,), (1.,)) + + def test_fun_with_nested_calls_2(self): + def call(f, *args): + f = jax.custom_jvp(f) + f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents))) + return f(*args) + + def fun_with_nested_calls_2(x): + def bar(y): + def baz(w): + q = call(lambda x: y, x) + q = q + call(lambda: y) + q = q + call(lambda y: w + y, y) + q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q + return q + return api.jit(baz)(x) + return call(bar, x) + + # test these don't crash + self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.), + fun_with_nested_calls_2(3.)) + api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) + + def test_closure_with_vmap(self): + # https://github.com/jax-ml/jax/issues/3822 + alpha = np.float32(2.) 
+ + def sample(seed): + @jax.custom_jvp + def f(alpha): + return jax.random.gamma(seed, alpha, shape=[]) + + @f.defjvp + def f_jvp(primal, tangent): + alpha = primal + dalpha = tangent + sample = f(alpha) + partial_alpha = lax.random_gamma_grad(alpha, sample) + return sample, partial_alpha * dalpha + return f(alpha) + + api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash + + def test_closure_with_vmap2(self): + # https://github.com/jax-ml/jax/issues/8783 + def h(z): + def f(x): + @jax.custom_jvp + def g(y): + return x * y + + # NOTE: rule closes over vmap tracer + @g.defjvp + def g_jvp(primals, tangents): + (y,), (ydot,) = primals, tangents + return x * y, x * ydot + + return g(z) # NOTE: no vmapped arg + + return jax.vmap(f)(jnp.arange(3., dtype='float32')) + + primals, tangents = jax.jvp(h, (jnp.float32(1.),), (jnp.float32(2.),)) + self.assertAllClose(primals , jnp.arange(3., dtype='float32')) + self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) + + def test_float0(self): + scalar_float0 = jnp.zeros((), dtype=float0) + @jax.custom_jvp + def f(x, y): + return x, y + def f_jvp(primals, _): + x, y = primals + return (x, y), (2., jax.custom_derivatives.zero_from_primal(y)) + f.defjvp(f_jvp) + + primals = (2., 3) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + self.assertAllClose(api.jvp(f, primals, tangents), + (primals, expected_tangents)) + + def test_float0_initial_style(self): + scalar_float0 = jnp.zeros((), dtype=float0) + @jax.custom_jvp + def f(x, y): + return x, y + def f_jvp(primals, _): + x, y = primals + return (x, y), (2., jax.custom_derivatives.zero_from_primal(y)) + f.defjvp(f_jvp) + + def foo(x, y): + out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1) + return out + + primals = (2., 3) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + + self.assertAllClose(api.jvp(foo, primals, tangents), + (primals, expected_tangents)) + + def test_remat(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + @jax.remat + def g(x): + return f(f(x)) + + ans = g(2.) + expected = np.sin(np.sin(2.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g)(2.) + expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_higher_order(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + def g(x): + return f(f(x)) + + ans = api.grad(api.grad(new_checkpoint(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(new_checkpoint(api.grad(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.grad(new_checkpoint(g))))(2.) + expected = api.grad(api.grad(api.grad(g)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_2(self): + # This is like test_initial_style_vmap except the primal function closes + # over an array constant. + y = jnp.arange(1., 4.) 
+ + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x * jnp.sum(y) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_jvp_vmap_broadcasting_interaction(self): + # https://github.com/jax-ml/jax/issues/6452 + def f2(y, z): + v1 = z + v2 = jnp.sum(y) + z + return jnp.logaddexp(v1, v2) + + def f1(y, z): + v = api.vmap(lambda _y: f2(_y, z))(y) + return jnp.sum(v) + + y = jnp.ones((3, 2)) + f = lambda z: f1(y, z) + z = 0.1 + val, g = api.value_and_grad(f)(z) + self.assertEqual(val.shape, ()) + self.assertEqual(g.shape, ()) + + def test_custom_jvp_vmap_broadcasting_interaction_2(self): + # https://github.com/jax-ml/jax/issues/5849 + @jax.custom_jvp + def transform(box, R): + if jnp.isscalar(box) or box.size == 1: + return R * box + elif box.ndim == 2: + return jnp.einsum('ij,j->i', box, R) + raise ValueError() + + @transform.defjvp + def transform_jvp(primals, tangents): + box, R = primals + dbox, dR = tangents + return (transform(box, R), dR + transform(dbox, R)) + + def periodic_general(box): + def displacement_fn(Ra, Rb, **kwargs): + _box = kwargs.get('box', box) + return transform(_box, Ra - Rb) + + return displacement_fn + + N = 250 + + scalar_box = 1.0 + displacement = periodic_general(scalar_box) + + key = jax.random.key(0) + R = jax.random.uniform(key, (N, 2)) + + def energy_fn(box): + d = partial(displacement, box=box) + d = api.vmap(api.vmap(d, (None, 0)), (0, None)) + return jnp.sum(d(R, R) ** 2) + + self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) + + def test_custom_jvp_implicit_broadcasting(self): + # https://github.com/jax-ml/jax/issues/6357 + if config.enable_x64.value: + raise unittest.SkipTest("test only applies when x64 is disabled") + + @jax.custom_jvp + def projection_unit_simplex(x: jax.Array) -> jax.Array: + """Projection onto the unit simplex.""" + s = 1.0 + n_features = x.shape[0] + u = jnp.sort(x)[::-1] + cssv = jnp.cumsum(u) - s + ind = jnp.arange(n_features, dtype=x.dtype) + 1 + cond = u - cssv / ind > 0 + idx = jnp.count_nonzero(cond) + threshold = cssv[idx - 1] / idx.astype(x.dtype) + return jax.nn.relu(x - threshold) + + + @projection_unit_simplex.defjvp + def projection_unit_simplex_jvp(primals, tangents): + x, = primals + x_dot, = tangents + primal_out = projection_unit_simplex(x) + supp = (primal_out > 0).astype(x_dot.dtype) + card = jnp.count_nonzero(supp).astype(x_dot.dtype) + tangent_out = supp * x_dot - (jnp.dot(supp, x_dot) / card) * supp + return primal_out, tangent_out + + rng = self.rng() + x = rng.rand(5).astype(np.float32) + + J_rev = jax.jacrev(projection_unit_simplex)(x) + 
J_fwd = jax.jacfwd(projection_unit_simplex)(x)
+
+    p = projection_unit_simplex(x)
+    support = (p > 0).astype(jnp.float32)
+    cardinality = jnp.count_nonzero(support).astype(support.dtype)
+    J_true = jnp.diag(support) - jnp.outer(support, support) / cardinality
+    self.assertAllClose(J_true, J_fwd)
+    self.assertAllClose(J_true, J_rev)
+
+    proj = jax.vmap(projection_unit_simplex)
+
+    def fun(X):
+      return jnp.sum(proj(X) ** 2)
+
+    rng = self.rng()
+    X = rng.rand(4, 5).astype(np.float32)
+    U = rng.rand(4, 5)
+    U /= np.sqrt(np.sum(U ** 2))
+    U = U.astype(np.float32)
+
+    eps = 1e-3
+    dir_deriv_num = (fun(X + eps * U) - fun(X - eps * U)) / (2 * eps)
+    dir_deriv = jnp.vdot(jax.grad(fun)(X), U)
+    self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3)
+
+  def test_vmap_inside_defjvp(self):
+    # https://github.com/jax-ml/jax/issues/3201
+    seed = 47
+    key = jax.random.key(seed)
+    mat = jax.random.normal(key, (2, 3))
+
+    @jax.custom_jvp
+    def f(mat, aux):
+      num_rows, num_cols = mat.shape
+      return jnp.ones((num_rows, 1)) / num_cols
+
+    @f.defjvp
+    def f_jvp(primals, tangents):
+      mat, aux = primals
+      vec, _ = tangents
+      output = f(*primals)
+      num_rows, num_cols = mat.shape
+      size = num_rows * num_cols
+      # -----
+      bd_mat = mat.reshape(1, 1, num_rows, num_cols)
+      bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols))
+      bd_mat = bd_mat.reshape(size, num_rows, num_cols)
+      # -----
+      rowsum = jnp.sum(mat, axis=1, keepdims=True)
+      colsum = jnp.sum(mat, axis=0, keepdims=True)
+      bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows))
+      bd_colsum = jnp.tile(colsum, reps=(num_cols, 1))
+      # -----
+      bd_vec = vec.reshape(size, 1)
+      # -----
+      def operate(mx, val):
+        buf = 0
+        for i in range(2):
+          buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i)
+        buf = jnp.matmul(bd_rowsum, buf)
+        return buf * val[None, :]
+      # -----
+      # Vectorizing will raise a shape error
+      bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec)
+      # -----
+      bd_buf = bd_buf / aux
+      jvp = jnp.sum(bd_buf, axis=0)
+      jvp = jnp.mean(jvp, axis=1, keepdims=True)
+      # -----
+      # The JVP ends successfully, but still raises an error
+      return (output, jvp)
+
+    jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5)  # doesn't crash
+
+  def test_custom_jvp_unbroadcasting(self):
+    # https://github.com/jax-ml/jax/issues/3056
+    a = jnp.array([1., 1.])
+
+    @jax.custom_jvp
+    def f(x):
+      return a * x
+
+    @f.defjvp
+    def f_jvp(primals, tangents):
+      x, = primals
+      dx, = tangents
+      return a * x, a * dx
+
+    shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape
+    self.assertEqual(shape, ())
+
+  def test_maybe_perturbed_internal_helper_function(self):
+    # This is a unit test for an internal API. We include it so as not to
+    # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of
+    # this helper function, see https://github.com/jax-ml/jax/issues/6415.
+ def f(x): + def g(y, _): + z = y * x + self.assertTrue(custom_derivatives._maybe_perturbed(z)) + return y, None + g(1, None) + return lax.scan(g, 1, xs=None, length=1)[0] + + jax.jvp(f, (1.0,), (1.0,)) # assertions inside f + + def test_maybe_perturbed_int_regression(self): + # see https://github.com/jax-ml/jax/discussions/9951 + + @jax.jit + def f(): + x = jnp.array(1) + _, aux_args = custom_derivatives.closure_convert(lambda: x) + self.assertEmpty(aux_args) + f() + + def test_sinc_constant_function_batching(self): + # https://github.com/jax-ml/jax/pull/10756 + batch_data = jnp.arange(15.).reshape(5, 3) + + @jax.vmap + def f(x): + return jax.lax.map(jnp.sinc, x) + g = lambda param: f(param * batch_data).sum() + + @jax.vmap + def f_ref(x): + return jnp.stack([jnp.sinc(x_) for x_ in x]) + g_ref = lambda param: f_ref(param * batch_data).sum() + + grad = jax.grad(g )(0.1) # doesn't crash + grad_ref = jax.grad(g_ref)(0.1) + self.assertAllClose(grad, grad_ref, check_dtypes=False) + + @parameterized.named_parameters( + ('jit_vmap', True, True), + ('jit', True, False), + ('vmap', False, True), + ('', False, False), + ) + def test_symbolic_zero_custom_jvp(self, maybe_jit, maybe_vmap): + def f(static_scalar, static_array, dyn_scalar, dyn_array): + out1 = static_scalar + dyn_scalar + out2 = static_array + dyn_array + return out1, out2 + + def _pack(x): + return lax.broadcast(x, (1,)) + + def _unpack(x): + (x,) = x + return x + + def _vmap(fun): + def _fun(*args): + args = jax.tree.map(_pack, args) + out = jax.vmap(fun)(*args) + out = jax.tree.map(_unpack, out) + return out + return _fun + + f = jax.custom_jvp(f) + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + static_scalar, *_ = primals + t_static, t_static_arr, t_dyn_scalar, t_dyn_array = tangents + self.assertIs(type(t_static) , jax.custom_derivatives.SymbolicZero) + self.assertIs(type(t_static_arr), jax.custom_derivatives.SymbolicZero) + self.assertEqual(t_static.shape, ()) + self.assertEqual(t_static_arr.shape, (2,)) + return f(*primals), (static_scalar + 90, t_dyn_array + 91) + + def g(dyn_scalar, dyn_array): + if maybe_vmap: + f_ = _vmap(f) + else: + f_ = f + return f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) + + def run(primal_ins, tangent_ins): + return jax.jvp(g, primal_ins, tangent_ins) + + if maybe_jit: + run = jax.jit(run) + + primal_ins = (4., jnp.array([5., 6.])) + tangent_ins = (7., jnp.array([8., 9.])) + primal_outs, tangent_outs = run(primal_ins, tangent_ins) + primal_out1, primal_out2 = primal_outs + tangent_out1, tangent_out2 = tangent_outs + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + self.assertIsInstance(primal_out1, scalar_type) + self.assertAllClose(primal_out1, 5.) + self.assertIsInstance(tangent_out1, scalar_type) + self.assertAllClose(tangent_out1, 91.) 
+ self.assertIsInstance(primal_out2, jax.Array) + self.assertArraysAllClose(primal_out2, jnp.array([7., 9.])) + self.assertIsInstance(tangent_out2, jax.Array) + self.assertArraysAllClose(tangent_out2, jnp.array([99., 100.])) + + def test_symbolic_zero_custom_jvp_vmap_output(self): + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, y = primals + x_dot, y_dot = tangents + self.assertIs(type(y_dot), jax.custom_derivatives.SymbolicZero) + return f(x, y), y_dot + + jax.grad(lambda x, y: jax.vmap(f)(x, y).sum())(jnp.ones(3), jnp.ones(3)) + + def test_symbolic_zeros_memoization_caching(self): + # Tests multiple zero patterns for partial_eval._memoize, and also tests + # that we're okay with stores being occupied with equal values. + + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, y = primals + x_dot, y_dot = tangents + return f(x, y), y_dot + + f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) + _ = jax.linearize(f_, 2., 3.) + _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! + + def test_symbolic_zeros_under_jit(self): + # https://github.com/jax-ml/jax/issues/14833 + Zero = jax.custom_derivatives.SymbolicZero + + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def fjvp(primals, tangents): + x, y = primals + tx, ty = tangents + assert type(tx) is not Zero or type(ty) is not Zero + return f(x, y), ( + ty if type(tx) is Zero else + tx if type(ty) is Zero else + tx + ty) + + jax.jacfwd(jax.jit(f))(0.1, 0.2) # don't crash + + def test_custom_jvp_functools_partial(self): + def fun(x, y, a): + return x + y * a + + fun_wrapped = functools.partial(fun, a = 0.1) + + def jvp_fn(primals, tangents): + return jax.jvp(fun_wrapped, primals, tangents) + + fn = jax.custom_jvp(fun_wrapped) + fn.defjvp(jvp_fn) + + self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) + + def test_run_rules_more_than_once(self): + # https://github.com/jax-ml/jax/issues/16614 + + @jax.custom_jvp + def f(x, y): + return x + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, _ = primals + x_dot, _ = tangents + return x, x_dot + + def body(x_y, _): + x, y = x_y + return (f(x, y), x), None + + @jax.grad + def g(x): + (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) + return out + + g(1.) 
# doesn't crash + + def test_dce(self): + @jax.custom_jvp + def f(x, y): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + x, y = primals + dx, dy = tangents + return f(x, y), (2.0 * jnp.cos(x) * dx, 1.5 * dx - 0.5 * jnp.sin(y) * dy) + + def check_jaxpr(jaxpr, used_outs, includes, excludes): + dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) + if not dce_jaxpr.eqns: + assert not includes + return + call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] + for prim in includes: + assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) + for prim in excludes: + assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) + + x, y = 0.1, -1.3 + jaxpr = jax.make_jaxpr(f)(x, y).jaxpr + check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) + check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) + check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) + check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) + + def dce_jaxpr_as_fun(jaxpr, used_outs): + jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) + fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) + return lambda *args: fun(*args)[0] + + f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) + f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) + self.assertAllClose( + api.jvp(f0, (x, y), (1.0, 0.0)), (f0(x, y), 2.0 * jnp.cos(x))) + self.assertAllClose( + api.jvp(f0, (x, y), (0.0, 1.0)), (f0(x, y), 0.0)) + self.assertAllClose( + api.jvp(f1, (x, y), (1.0, 0.0)), (f1(x, y), 1.5)) + self.assertAllClose( + api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y))) + + def test_resolve_kwargs_error_message(self): + @jax.custom_jvp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + def test_symbolic_zero_custom_jvp_vmap_doesnt_instantiate(self): + @jax.custom_jvp + def f(x, y): + return y + + def f_jvp(primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + assert type(y_dot) is jax.custom_derivatives.SymbolicZero + return y, y_dot + + f.defjvp(f_jvp, symbolic_zeros=True) + + def g(x): + return f(x, f(x, 1.)) + + jax.jvp(jax.vmap(g), (jnp.ones(3),), (jnp.ones(3),)) # don't crash + + def test_symbolic_zero_under_vmap_of_jit(self): + # https://github.com/jax-ml/jax/issues/28144 + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(x, t): + (x,) = x + (t,) = t + z = jax.custom_derivatives.zero_from_primal(x, symbolic_zeros=True) + return f(x), z + + x = jnp.arange(3.0) + jax.jvp(jax.vmap(jax.jit(f)), (x,), (x,)) # doesn't crash + + def test_pretty_print(self): + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(primals, tangents): + return f(*primals), tangents[0] + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(f)(x) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_jvp_call[ + name=f + call_jaxpr={ lambda ; c:f32[1]. 
let d:f32[1] = add c 1.0:f32[] in (d,) } + jvp=f_jvp + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + + + +class CustomVJPTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = 3. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) + self.assertAllClose(api.value_and_grad(f)(x), + (jnp.sin(x), 2 * jnp.cos(x))) + + def test_invariance(self): + @jax.custom_vjp + def f(x): + return jnp.cos(2 * x) / 2. + def f_fwd(x): + return (f(x), x) + def f_rev(x, g): + return (g * 3,) + f.defvjp(f_fwd, f_rev) + def f2(x): + y, _ = api.value_and_grad(f)(x) + return y + def f3(x): + y, _ = api.value_and_grad(f2)(x) + return y + x = 1. + self.assertAllClose(f(x), f2(x), check_dtypes=False) + self.assertAllClose(f(x), f3(x), check_dtypes=False) + self.assertAllClose(api.grad(f)(x), api.grad(f2)(x), + check_dtypes=False) + self.assertAllClose(api.grad(f)(x), api.grad(f3)(x), + check_dtypes=False) + + def test_python_control_flow(self): + @jax.custom_vjp + def f(x): + if x > 0: + return jnp.sin(x) + else: + return jnp.cos(x) + def f_fwd(x): + if x > 0: + return f(x), x + else: + return f(x), x + def f_rev(x, g): + if x > 0: + return (2 * g,) + else: + return (3 * g,) + f.defvjp(f_fwd, f_rev) + x = 2. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) + self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.), + check_dtypes=False) + self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.), + check_dtypes=False) + + def test_vmap(self): + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return jnp.sin(x) + def f_fwd(x): + assert jnp.ndim(x) == 0 + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = jnp.arange(3.) + xx = jnp.arange(6.).reshape(2, 3) + + # vmap of f + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) + + # vmap of grad of f + self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x)) + self.assertAllClose(api.vmap(api.value_and_grad(f))(x), + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx)) + self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx))) + + # grad of vmap of f + self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x), + 2 * jnp.cos(x)) + self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx), + 2 * jnp.cos(xx)) + + # vmap of grad of vmap of f + self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx), + 2 * jnp.cos(xx)) + + def test_jit(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = 3. 
+ + # jit + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + + # jit of grad + self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x), + check_dtypes=False) + + # grad of jit + self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x), + check_dtypes=False) + + def test_pytrees(self): + @jax.custom_vjp + def f(x): + return {'b': jnp.sin(x['a'])} + def f_fwd(x): + return f(x), {'r': jnp.cos(x['a'])} + def f_bwd(res, g): + cos_x = res['r'] + return ({'a': 2 * cos_x * g['b']},) + f.defvjp(f_fwd, f_bwd) + x = {'a': 3.} + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) + self.assertAllClose(api.grad(lambda x: f(x)['b'])(x), + {'a': 2 * jnp.cos(x['a'])}) + + def test_jvp_error(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))) + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(jit(f), (3.,), (1.,))) + + def test_kwargs(self): + # from https://github.com/jax-ml/jax/issues/1938 + @jax.custom_vjp + def my_fun(x, y, c=1.): + return c * (x + y) + my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None), + lambda _, g: (g, g, g)) + f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() + f(10., 5.) # doesn't crash + api.grad(f)(10., 5.) # doesn't crash + + def test_initial_style(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = 2. * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(foo))(3.) + expected = -2. * jnp.sin(3.) + self.assertAllClose(ans, expected) + + def test_initial_style_vmap(self): + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.arange(3.)) + expected = 3. * jnp.arange(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) + expected = 2. * jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg(self): + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def app(f, x): + return f(x) + def app_fwd(f, x): + return app(f, x), jnp.cos(x) + def app_rev(f, cos_x, g): + return (cos_x * g,) + app.defvjp(app_fwd, app_rev) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) + expected = (2., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_jit_tracer(self): + # See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer. 
+ raise unittest.SkipTest("behavior no longer supported") + + # This test is similar to test_nondiff_arg_tracer except it uses lexical + # closure rather than the nondiff_argnums mechanism. We decided to disallow + # tracers in nondiff_argnums to greatly simplify bookkeeping while still + # supporting the cases for which it is necessary. + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), jnp.cos(y) + def f_rev(cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + return f + + @jit + def g(x, y): + return outer(x)(y) + + ans = g(2, 3.) + expected = 6. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g, 1)(2., 3.) + expected = jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_vmap_tracer(self): + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), jnp.cos(y) + def f_rev(cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + return f + + @api.vmap + def g(x): + return outer(x)(3.) + + ans = g(np.arange(3.)) + expected = np.arange(3.) * 3 + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_tracer3(self): + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), (x, jnp.cos(y)) + def f_rev(res, g): + x, cos_y = res + return (cos_y * g * x,) + f.defvjp(f_fwd, f_rev) + return api.grad(f) + + @api.vmap + def g(x): + return outer(x)(3.) + + ans = g(np.arange(3.)) + expected = np.cos(3.) * np.arange(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_tracer_error(self): + # This is similar to the old (now skipped) test_nondiff_arg_tracer, except + # we're testing for the error message that usage pattern now raises. + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(x, cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + + @jit + def g(x, y): + return f(x, y) + + with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): + _ = g(2, 3.) + with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): + _ = api.grad(g, 1)(2., 3.) 
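+
+  # As a reference for the tracer-error tests above: a minimal sketch (names
+  # like `f2` are hypothetical, not part of this change) of the supported
+  # pattern, where the value that would have been a traced nondiff argument
+  # is passed as an ordinary argument and its cotangent is returned as None:
+  #
+  #   @jax.custom_vjp
+  #   def f2(x, y):
+  #     return x * y
+  #   def f2_fwd(x, y):
+  #     return f2(x, y), jnp.cos(y)
+  #   def f2_bwd(cos_y, g):
+  #     return (None, cos_y * g)  # None: no cotangent for x
+  #   f2.defvjp(f2_fwd, f2_bwd)
+  #   jax.grad(jax.jit(f2), 1)(2., 3.)  # no UnexpectedTracerError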
+ + def test_vmap_axes(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_pmap(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_missing_vjp_rule_error(self): + @jax.custom_vjp + def foo(x): + return x ** 2 + + self.assertRaisesRegex( + AttributeError, + r"No VJP defined for custom_vjp function foo using defvjp.", + lambda: foo(2)) + self.assertRaisesRegex( + AttributeError, + r"No VJP defined for custom_vjp function foo using defvjp.", + lambda: api.grad(foo)(2.)) + + def test_vjp_rule_inconsistent_pytree_structures_error(self): + @jax.custom_vjp + def f(x): + return x + + def foo_fwd(x): + return x, None + + def foo_bwd(_, g): + return (g, g) + + f.defvjp(foo_fwd, foo_bwd) + + f(2) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP bwd rule must produce an output with the same container " + "(pytree) structure as the args tuple of the primal function, " + "and in particular must produce a tuple of length equal to the " + "number of arguments to the primal function, but got bwd output " + "structure {} for primal input structure {}.".format( + jax.tree.structure((1, 1)), + jax.tree.structure((1,))) + ), + lambda: api.grad(f)(2.)) + + def test_vjp_bwd_returns_non_tuple_error(self): + @jax.custom_vjp + def f(x): + return x + + def foo_fwd(x): + return x, None + + def foo_bwd(_, g): + return 2. * g # Should be a tuple + + f.defvjp(foo_fwd, foo_bwd) + with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"): + api.grad(f)(3.) + + def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_vjp + def f(x): + return x + + def f_fwd(x): + return (x, x), None + + def f_bwd(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd, f_bwd) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP fwd rule f_fwd for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal to the output of the " + "custom_vjp-decorated function f) and the second element " + "represents residuals (i.e. values stored from the forward " + "pass for use on the backward pass), but instead the fwd rule " + "output's first element had container/pytree structure:\n" + " (float32[], float32[])\n" + "while the custom_vjp-decorated function f had output " + "container/pytree structure:\n" + " float32[]." + ), + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + def f_fwd2(x): + return jnp.zeros((3, *x.shape), x.dtype), None + + def f_bwd2(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd2, f_bwd2) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP fwd rule f_fwd2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal to the output of the " + "custom_vjp-decorated function f) and the second element " + "represents residuals (i.e. 
values stored from the forward " + "pass for use on the backward pass), but instead the fwd rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_vjp-decorated function f had output " + "shapes/dtypes of:\n" + " float32[]" + ), + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + def test_issue2511(self): + arr = jnp.ones((5, 2, 2)) + foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x) + api.jit(foo)(arr) # doesn't crash + + def test_lowering_out_of_traces(self): + # https://github.com/jax-ml/jax/issues/2578 + + class F(collections.namedtuple("F", ["a"])): + def __call__(self, x): + return jax.nn.relu(self.a) * x + + @jax.jit + def g(f, x): + return f(x) + + jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash + + def test_clip_gradient(self): + # https://github.com/jax-ml/jax/issues/2784 + @jax.custom_vjp + def _clip_gradient(lo, hi, x): + return x # identity function when not differentiating + + def clip_gradient_fwd(lo, hi, x): + return x, (lo, hi,) + + def clip_gradient_bwd(res, g): + lo, hi = res + return (None, None, jnp.clip(g, lo, hi),) + + _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) + + def clip_gradient(x): + lo = -0.1 + hi = x + 0.1 + return _clip_gradient(lo, hi, x) + + g = jax.grad(clip_gradient)(0.1) # doesn't crash + self.assertAllClose(g, jnp.array(0.2)) + + def test_nestable_vjp(self): + # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. + def f(x): + return x ** 2 + + @jax.custom_vjp + def g(x): + return f(x) + + def g_fwd(x): + y, f_vjp = api.vjp(f, x) + return y, f_vjp + + def g_bwd(f_vjp, y_bar): + return f_vjp(y_bar) + + g.defvjp(g_fwd, g_bwd) + + # Check that VJP can be nested in simple situations. For this to pass, + # vjp has to return a PyTree. + _, g_vjp = api.vjp(g, 1.0) + y, = g_vjp(1.0) + self.assertAllClose(y, jnp.array(2.0)) + + # Check that VJP can be nested in complex situations. For this to pass, + # vjp can't treat the closed-over tracer x as a static argument. + @jit + def z(x): + _, g_vjp = api.vjp(g, x) + return g_vjp + y, = z(1.0)(3.0) + self.assertAllClose(y, jnp.array(6.0)) + + def test_initial_style_vmap_2(self): + # https://github.com/jax-ml/jax/issues/4173 + x = jnp.ones((10, 3)) + + # Create the custom function + @jax.custom_vjp + def custom_fun(x): + return x.sum() + + def forward(x): + return x.sum(), (jnp.ones_like(x),) + + def backward(res, g): + return g * res[0], + + custom_fun.defvjp(forward, backward) + + def train_fun(x): + + def summed_fun(x): + return api.vmap(custom_fun)(x).sum() + + return api.grad(summed_fun)(x) + + def scan_body(carry, inputs): + x = carry + return carry, train_fun(x) + + scan_range = jnp.arange(4) + lax.scan(scan_body, x, scan_range) # don't crash + + def test_initial_style_vmap_3(self): + # This is like test_initial_style_vmap except the primal function closes + # over an array constant. + y = jnp.arange(1., 4.) + + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x * jnp.sum(y) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.arange(3.)) + expected = 3. * jnp.arange(3.) * 6 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) + expected = 2. 
* jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_with_collective(self): + + @jax.custom_vjp + def f(x): + return lax.psum(x, 'foo') + + def f_fwd(x): + return lax.psum(x, 'foo'), None + + def f_bwd(res, dx): + return dx + f.defvjp(f_fwd, f_bwd) + + def g(x): + jaxpr = api.make_jaxpr(f)(x) + return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] + + out = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), + out_axes=None)(jnp.arange(4.), 2.) + self.assertAllClose(out, 8.) + + def test_bwd_closes_over_tracer(self): + def f(y): + @jax.custom_vjp + def f(x): + return 2. * jnp.sin(x) + + def fwd(x): + return f(x), () + + def bwd(_, g): + return (2. * jnp.cos(y) * g,) # capture! + + f.defvjp(fwd, bwd) + + return jax.grad(f)(1.) + + ans = jax.jit(f)(2.) + self.assertAllClose(ans, 2. * jnp.cos(2.)) + + ans = jax.vmap(f)(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.grad(f)(4.) + self.assertAllClose(ans, -2. * jnp.sin(4.)) + + def test_fwd_closes_over_tracer(self): + def f(y): + @jax.custom_vjp + def f(x): + return 2. * jnp.sin(x) + + def fwd(x): + return f(x), y + + def bwd(y, g): + return (2. * jnp.cos(y) * g,) # capture! + + f.defvjp(fwd, bwd) + + return jax.grad(f)(1.) + + ans = jax.jit(f)(2.) + self.assertAllClose(ans, 2. * jnp.cos(2.)) + + ans = jax.vmap(f)(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.grad(f)(4.) + self.assertAllClose(ans, -2. * jnp.sin(4.)) + + def test_float0(self): + @jax.custom_vjp + def f(x, _): + return x + def f_fwd(x, _): + # we need a defined (non-float0) tangent to trigger the rule + return x, (2., 1) + def f_rev(*_): + return (2., 1) + f.defvjp(f_fwd, f_rev) + + x = 2. + y = 3 + self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y), + (2., np.zeros(shape=(), dtype=float0))) + + def test_float0_initial_style(self): + @jax.custom_vjp + def f(x): + return x + def f_fwd(x): + return x, (2., x) + def f_rev(*_): + return ((2., jnp.zeros(shape=(), dtype=float0)),) + f.defvjp(f_fwd, f_rev) + + def foo(x, y): + out, _ = lax.scan(lambda c, _: (f(c), None), (x, y), None, length=1) + return out[0] + + x = 2. + y = 3 + self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y), + (2., np.zeros(shape=(), dtype=float0))) + + def test_remat(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + @jax.remat + def g(x): + return f(f(x)) + + ans = g(2.) + expected = np.sin(np.sin(2.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g)(2.) + expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_higher_order(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def g(x): + return f(f(x)) + + ans = api.grad(api.grad(jax.remat(g)))(2.) 
+ expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(jax.remat(api.grad(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.grad(jax.remat(g))))(2.) + expected = api.grad(api.grad(api.grad(g)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones(self): + @jax.custom_vjp + def f(x, y): + return x * jnp.sin(y) + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: f(x, x))(3.) + expected = 2 * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones_vmap(self): + @jax.custom_vjp + def f(x, y): + return x * jnp.sin(y) + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.)) + expected = 2 * jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones_pytree(self): + @jax.custom_vjp + def f(xs, y): + x1, x2 = xs + return x1 * x2 * jnp.sin(y) + def f_fwd(xs, y): + return f(xs, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: f((x, x), x))(3.) + expected = 2 * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_vjp_closure_4521(self): + # https://github.com/jax-ml/jax/issues/4521 + @jax.custom_vjp + def g(x, y): + return None + def g_fwd(x, y): + return None, y + def g_bwd(residuals, z_bar): + assert False + + g.defvjp(g_fwd, g_bwd) + + def f(xs, y): + v_g = api.vmap(g, in_axes=(0, None), out_axes=None) + v_g(xs, y) + + def scan_body(xs, _): + y = jnp.zeros(1) + _, vjp_f = api.vjp(f, xs, y) + vjp_f(None) + return xs, None + + lax.scan(scan_body, jnp.ones(5), None, 100) # doesn't crash + + def test_float0_bwd_none(self): + @jax.custom_vjp + def f(i, x): + return jnp.sin(x) + def f_fwd(i, x): + return f(i, x), jnp.cos(x) + def f_rev(cos_x, g): + return (None, 2 * cos_x * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(f, 1)(jnp.array([1, 2]), 3.) # doesn't crash + expected = 2 * jnp.cos(3.) 
+ self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_gradient(self): + @jax.custom_gradient + def f(x): + return x ** 2, lambda g: (g * x,) + + self.assertAllClose(f(3.), 9., check_dtypes=False) + self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) + self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) + + def test_custom_gradient_2(self): + @jax.custom_gradient + def f(x, y): + return x * y, lambda g: (y, x) + + self.assertAllClose(f(3., 4.), 12., check_dtypes=False) + self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.), + check_dtypes=False) + + def test_custom_gradient_3(self): + @jax.custom_gradient + def f(x): + vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) + return jnp.sum(jnp.sin(x)), vjp + + self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))), + check_dtypes=False) + self.assertAllClose( + api.grad(f)(jnp.arange(3.)), + api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.), + check_dtypes=False) + + def test_custom_gradient_can_return_singleton_value_in_vjp(self): + @jax.custom_gradient + def f(x): + return x ** 2, lambda g: g * x + + self.assertAllClose(f(3.), 9., check_dtypes=False) + self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) + self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) + + def test_closure_convert(self): + def cos_after(fn, x): + converted_fn, aux_args = jax.closure_convert(fn, x) + self.assertLessEqual(len(aux_args), 1) + return _cos_after(converted_fn, x, *aux_args) + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def _cos_after(fn, x, *args): + return jnp.cos(fn(x, *args)) + + def fwd(fn, x, *args): + y = _cos_after(fn, x, *args) + return y, (x, args) + + def rev(fn, res, g): + x, args = res + x_bar = 17. * x + args_bars = [42. * a for a in args] + return (x_bar, *args_bars) + + _cos_after.defvjp(fwd, rev) + + def dist(c, x): + return jnp.sum((x - c) ** 2.) + + def solve(c, x): + def closure(x): + return dist(c, x) + return cos_after(closure, x) + + c, x = 2. * jnp.ones(2), jnp.ones(2) + expected = jnp.cos(dist(c, x)) + self.assertAllClose(solve(c, x), expected, check_dtypes=False) + g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x) + self.assertAllClose(g_c, 42. * c, check_dtypes=False) + self.assertAllClose(g_x, 17. * x, check_dtypes=False) + + def test_closure_convert_mixed_consts(self): + # Like test_closure_convert, but close over values that + # participate in AD as well as values that do not. + # See https://github.com/jax-ml/jax/issues/6415 + + def cos_after(fn, x): + converted_fn, aux_args = jax.closure_convert(fn, x) + self.assertLessEqual(len(aux_args), 1) + return _cos_after(converted_fn, x, *aux_args) + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def _cos_after(fn, x, *args): + return jnp.cos(fn(x, *args)) + + def fwd(fn, x, *args): + y = _cos_after(fn, x, *args) + return y, (x, args) + + def rev(fn, res, g): + x, args = res + x_bar = 17. * x + args_bars = [42. * a for a in args] + return (x_bar, *args_bars) + + _cos_after.defvjp(fwd, rev) + + def dist(c, s, x): + return jnp.sum(s * (x - c) ** 2.) + + def solve(c, s, x): + def closure(x): + return dist(c, s, x) + return cos_after(closure, x) + + c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2) + expected = jnp.cos(dist(c, s, x)) + self.assertAllClose(solve(c, s, x), expected, check_dtypes=False) + g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x) + self.assertAllClose(g_c, 42. * c, check_dtypes=False) + self.assertAllClose(g_x, 17. 
* x, check_dtypes=False) + + def test_closure_convert_pytree_mismatch(self): + # See https://github.com/jax-ml/jax/issues/23588 + def f(x, z): + return z * x + + x, z = 2.0, 3.0 + _, vjp = api.vjp(f, x, z) + vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x) + vjp_pure(x, *vjp_aux_args) + with self.assertRaisesRegex( + TypeError, "The inputs to the closure produced by closure_convert"): + vjp_pure(x, vjp_aux_args) + + def test_float0_cotangents_automatically_handled(self): + @jax.custom_vjp + def f(x, y): + return x + + def f_fwd(x, y): + return x, None + + def f_bwd(_, zbar): + return (0., 1) + + f.defvjp(f_fwd, f_bwd) + + jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash + + def test_custom_vjp_scan_batching_edge_case(self): + # https://github.com/jax-ml/jax/issues/5832 + @jax.custom_vjp + def mul(x, coeff): return x * coeff + def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) + def mul_bwd(res, g): + x, coeff = res + g_x = g * coeff + g_coeff = (x * g).sum() + return g_x, g_coeff + mul.defvjp(mul_fwd, mul_bwd) + + def scan_over_mul(x, coeff): + def f_(x, t): + return mul(x, coeff), None + y, _ = jax.lax.scan(f_, x, jnp.arange(3)) + return y + + key = jax.random.key(0) + key1, key2 = jax.random.split(key, 2) + x_batch = jax.random.normal(key1, (3, 2)) + covector_batch = jax.random.normal(key2, (3, 2)) + coeff = jnp.array(1., dtype=x_batch.dtype) + + batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0) + res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff) + vjp_fun(covector_batch) # doesn't crash + + jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2, + modes=['rev']) + + def test_closure_with_vmap2(self): + # https://github.com/jax-ml/jax/issues/8783 + def h(z): + def f(x): + @jax.custom_vjp + def g(y): + return x * y + + def g_fwd(y): + return x * y, (x, x * y, y) + def g_rev(res, w_bar): + x, *_ = res + return (x * w_bar,) + g.defvjp(g_fwd, g_rev) + + return g(z) + + return jax.vmap(f)(jnp.arange(3., dtype='float32')).sum() + + jtu.check_grads(h, (jnp.float32(3.14),), order=1, modes=['rev']) + + def test_pytrees_not_required_to_contain_nones(self): + class A(list): + pass + + def unflatten(_, children): + assert children[0] is not None + return A(children) + + tree_util.register_pytree_node(A, lambda x: (x, None), unflatten) + + @jax.custom_vjp + def f(x): + return x[0] + def f_fwd(x): + return x[0], None + def f_bwd(_, g): + return A([g]), + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(A([1.])) # doesn't crash + + def test_vmap_vjp_called_twice(self): + # https://github.com/jax-ml/jax/pull/14728 + @jax.custom_vjp + def f(x): + return x + f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,)) + + _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.])) + f_vjp(jnp.array([3.])) + f_vjp(jnp.array([3.])) # doesn't crash + + def test_symbolic_zero_custom_vjp_basic(self): + ZERO = jax.custom_derivatives.SymbolicZero + + @jax.custom_vjp + def f(x, y, z): + return x, x + + def fwd(x, y, z): + self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal) + self.assertTrue(x.perturbed) + self.assertFalse(y.perturbed) + self.assertFalse(z.perturbed) + return (x.value, x.value), None + + def fwd_all(x, y, z): + self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(z, 
jax.custom_derivatives.CustomVJPPrimal) + self.assertTrue(x.perturbed) + self.assertTrue(y.perturbed) + self.assertTrue(z.perturbed) + return (x.value, x.value), None + + def bwd_all(_, g): + x1, x2 = g + self.assertFalse(type(x1) is ZERO) + self.assertFalse(type(x2) is ZERO) + return x1, x1, x2 + + def bwd_fst(_, g): + x1, x2 = g + self.assertFalse(type(x1) is ZERO) + self.assertIs(type(x2), ZERO) + return x1, x1, x2 + + def bwd_snd(_, g): + x1, x2 = g + self.assertIs(type(x1), ZERO) + self.assertFalse(type(x2) is ZERO) + return x1, x1, x2 + + x, y, z = 4., 5., 6. + i = np.array(7, np.int32) + zero = np.array(0.) + + f.defvjp(fwd, bwd_all, symbolic_zeros=True) + h = jax.jit(f) + jax.jacrev(h)(x, y, z) + jax.jacrev(lambda x: h(x, y, z))(x) + jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i) + + f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True) + fst_f = lambda *xs: f(*xs)[0] + _, vjp = jax.vjp(fst_f, x, y, z) + _, _, gz = vjp(x) + self.assertArraysAllClose(gz, zero) + + f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True) + snd_f = lambda *xs: f(*xs)[1] + _, vjp = jax.vjp(snd_f, x, y, z) + gx, gy, _ = vjp(x) + self.assertArraysAllClose(gx, zero) + self.assertArraysAllClose(gy, zero) + + f.defvjp(fwd, bwd_snd, symbolic_zeros=True) + _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x) + gx, = vjp(x) + self.assertArraysAllClose(gx, zero) + + def test_symbolic_zero_custom_vjp_bwd_shape_error(self): + @jax.custom_vjp + def f(x, y, z): + return x, y, z + + def fwd(x, y, z): + return f(x.value, y.value, z.value), None + + def bwd(_, gs): + x_bar, y_bar, z_bar = gs + return y_bar, x_bar, z_bar # swapped! + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + with self.assertRaisesRegex( + ValueError, + r'Consider just returning a None here'): + jax.grad(lambda x, y, z: f(x, y, z)[2].sum())( + jnp.ones(1), jnp.ones(2), jnp.ones(3)) + + @parameterized.named_parameters( + ('jit_vmap', True, True), + ('jit', True, False), + ('vmap', False, True), + ('', False, False), + ) + def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap): + # below: + # * static_scalar will be static in and out + # * static_array will be static in, but dynamic out + # * dyn_scalar and dyn_array will be dynamic in and out + + ZERO = jax.custom_derivatives.SymbolicZero + + def f(static_scalar, static_array, dyn_scalar, dyn_array): + out1 = static_scalar + dyn_scalar + out2 = static_array + dyn_array + return static_scalar, static_array, out1, out2 + + def _pack(x): + return lax.broadcast(x, (1,)) + + def _unpack(x): + (x,) = x + return x + + def _vmap(fun): + def _fun(*args): + args = jax.tree.map(_pack, args) + out = jax.vmap(fun)(*args) + out = jax.tree.map(_unpack, out) + return out + return _fun + + f = jax.custom_vjp(f) + + def fwd(*args): + xs, pert = [x.value for x in args], [x.perturbed for x in args] + self.assertFalse(pert[0]) + self.assertFalse(pert[1]) + self.assertTrue(pert[2]) + self.assertTrue(pert[3]) + return f(*xs), xs + + def bwd(res, g): + static_scalar, *_ = res + t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g + self.assertIs(type(t_static), ZERO) + self.assertFalse(type(t_static_arr) is ZERO) + self.assertFalse(type(t_dyn_scalar) is ZERO) + self.assertFalse(type(t_dyn_array) is ZERO) + self.assertEqual(t_static.shape, ()) + self.assertEqual(t_static_arr.shape, (2,)) + return (static_scalar + 90, + t_static_arr + 91, + t_dyn_scalar + 92, + t_dyn_array + 93) + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + def g(dyn_scalar, dyn_array): + if maybe_vmap: + f_ = _vmap(f) + else: + f_ = f + outs = f_(1., 
jnp.array([2., 3.]), dyn_scalar, dyn_array) + return outs[1:] + + def run(primal_ins, cotangent_outs): + primal_outs, vjp = jax.vjp(g, *primal_ins) + cotangent_ins = vjp(cotangent_outs) + return primal_outs, cotangent_ins + + if maybe_jit: + run = jax.jit(run) + + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + primal_ins = (4., jnp.array([5., 6.])) + cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.])) + primal_outs, cotangent_ins = run(primal_ins, cotangent_outs) + + primal_out1, primal_out2, primal_out3 = primal_outs + self.assertIsInstance(primal_out1, jax.Array) + self.assertAllClose(primal_out1, jnp.array([2., 3.])) + self.assertIsInstance(primal_out2, scalar_type) + self.assertAllClose(primal_out2, 5.) + self.assertIsInstance(primal_out3, jax.Array) + self.assertAllClose(primal_out3, jnp.array([7., 9.])) + + ct_in1, ct_in2 = cotangent_ins + self.assertIsInstance(ct_in1, scalar_type) + self.assertAllClose(ct_in1, 99.) + self.assertIsInstance(ct_in2, jax.Array) + self.assertArraysAllClose(ct_in2, jnp.array([101., 102.])) + + def test_symbolic_zero_custom_vjp_vmap_output(self): + @jax.custom_vjp + def f(x, y): + return x, y + + def fwd(x, y): + self.assertTrue(x.perturbed) + self.assertFalse(y.perturbed) + return f(x.value, y.value), None + + def bwd(_, g): + _, ct_y = g + self.assertIs(type(ct_y), jax.custom_derivatives.SymbolicZero) + return g + + f.defvjp(fwd, bwd, symbolic_zeros=True) + jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3)) + + def test_symbolic_zero_custom_vjp_custom_pytree(self): + tree_values = jax.custom_derivatives.custom_vjp_primal_tree_values + + @tree_util.register_pytree_node_class + class Box: + def __init__(self_, strict, val): + if strict: + # make sure we aren't getting special arguments that should only + # come up when symbolic_zeros is True + self.assertFalse(hasattr(val, 'perturbed')) + self_.strict = strict + self_.x = val + + def tree_flatten(self_): + return [self_.x], self_.strict + + @classmethod + def tree_unflatten(cls, strict, xs): + x, = xs + return cls(strict, x) + + x, y = Box(False, jnp.array(72.)), jnp.array(73.) + + @jax.custom_vjp + def f(box, y): + return box.x * y + + def fwd0(box, y): + self.assertTrue(box.x.perturbed) + self.assertFalse(y.perturbed) + box, y = map(tree_values, [box, y]) + return f(box, y), (box, y) + + def bwd0(res, g): + box, y = res + return y * g, box.x * g + + def fwd1(box, y): + self.assertFalse(box.x.perturbed) + self.assertTrue(y.perturbed) + box, y = map(tree_values, [box, y]) + return f(box, y), (box, y) + + def bwd1(res, g): + box, y = res + return y * g, box.x * g + + f.defvjp(fwd0, bwd0, symbolic_zeros=True) + jax.grad(f, argnums=0)(x, y) + f.defvjp(fwd1, bwd1, symbolic_zeros=True) + jax.grad(f, argnums=1)(x, y) + + def fwd_strict(box, y): + return f(box, y), (box, y) + + def bwd_strict(res, g): + box, y = res + return y * g, box.x * g + + f.defvjp(fwd_strict, bwd_strict) + jax.grad(f)(x, y) + + def test_symbolic_zeros_memoization_caching(self): + # Tests multiple zero patterns for partial_eval._memoize, and also tests + # that we're okay with stores being occupied with equal values. + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return x.value, None + + def f_bwd(_, z_bar): + return z_bar, None + + f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) + + f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) + _ = jax.linearize(f_, 2., 3.) + _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! 
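+
+  # A compact reference for the symbolic-zeros protocol exercised above (an
+  # illustrative sketch, not one of the original tests): with
+  # symbolic_zeros=True the fwd rule receives CustomVJPPrimal wrappers
+  # exposing .value and .perturbed, and the bwd rule may see SymbolicZero
+  # cotangents for undifferentiated outputs.
+  def _symbolic_zeros_sketch(self):
+    @jax.custom_vjp
+    def f(x, y):
+      return x * y
+
+    def f_fwd(x, y):
+      # Unwrap the primals; under jax.grad(f) below y.perturbed is False,
+      # so the cotangent returned for y is simply discarded.
+      return f(x.value, y.value), (x.value, y.value)
+
+    def f_bwd(res, g):
+      x, y = res
+      return y * g, x * g
+
+    f.defvjp(f_fwd, f_bwd, symbolic_zeros=True)
+    return jax.grad(f)(2., 3.)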
+ + def test_run_rules_more_than_once(self): + # https://github.com/jax-ml/jax/issues/16614 + + @jax.custom_vjp + def f(x, y): + return x + y + + def f_fwd(x, y): + if y.perturbed: + res = None + else: + res = [] + return x.value + y.value, res + + def f_bwd(res, ct): + return ct, ct + + f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) + + def body(x_y, _): + x, y = x_y + return (f(x, y), x), None + + @jax.grad + def g(x): + (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) + return out + + g(1.) # doesn't crash + + def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): + # https://github.com/jax-ml/jax/issues/8356 + @jax.custom_vjp + def f(x): + return x[0] + + def f_fwd(x): + return f(x), None + + def f_bwd(_, z_bar): + return (z_bar, (None, None)), + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)((1.0, (2.0, 3.0))) # don't crash + + def test_pytree_nones_returned_by_bwd(self): + @jax.custom_vjp + def f(x): + return x[0] + + def f_fwd(x): + return f(x), None + + def f_bwd(_, z_bar): + return (z_bar, (None, None)), + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)((1.0, (2.0, None))) # don't crash + + def test_bwd_rule_shape_mismatch(self): + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with self.assertRaisesRegex( + ValueError, + r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'): + jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + + def test_bwd_rule_shape_mismatch_disable(self): + # TODO(mattjj): remove this test when the config option is removed + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with config.custom_vjp_disable_shape_check(True): + jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + + def test_bwd_rule_can_produce_list_or_tuple(self): + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return f(x, y), (x, y) + + def f_bwd(xy, g): + x, y = xy + return [g * y, x * g] # list, not tuple + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(1., 2.) 
# don't crash + + def test_optimize_remat(self): + def fun(x): + # This array is included to make sure that we handle consts appropriately + return np.array([1.0])*x + + def fwd(x): + return np.array([2.0])*x*x/np.array([1.0]), (x,) + + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + + self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE + self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed + + def test_optimize_remat_vmap(self): + def fun(x): + return (np.array([1.0])*x)[0] + def fwd(x): + return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) + self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) + + def test_optimize_remat_cond(self): + def fun(x): + return x + def fwd(x): + return x*x, (x,) + + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + + def g(x): + return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) + + self.assertAllClose(jax.jit(g)(x)[0], x*x) + self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x) + + def test_optimize_remat_jvp(self): + def fun(x): + return x**2 + def fwd_(x): + return x*x, (x,) + + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), + fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) + calc = jax.jvp(fwd, (3.2,), (1.0,)) + expected = jax.jvp(fwd_, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + @jax.jit + def g(x, t): + (y, r), (y_dot, r_dot) = jax.jvp(fwd, (x,), (t,)) + return y, y_dot + calc = g(3.2, 1.0) + expected = jax.jvp(fun, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + def test_optimize_remat_gh21303(self): + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + def f_fwd(x): + return jnp.sin(x), (x,) + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + v, g = jax.value_and_grad(temp)(3.2) + self.assertAllClose(v, jnp.tan(3.2)**2) + + def test_optimize_remat_multiple_args(self): + def f_(x, y): + return jnp.sin(x) * y + + @jax.custom_vjp + def f(x, y): + return f_(x, y) + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) + + def test_optimize_remat_kwargs(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + def f_fwd(x, y, *, keyword=False): + del keyword + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + jax.grad(f)(x, y) # Doesn't error + + def test_optimize_remat_custom_vmap(self): + # See https://github.com/jax-ml/jax/pull/23000 + @jax.custom_vjp + def f(x, y): + 
return jnp.sin(x) * y + + @jax.custom_batching.custom_vmap + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + @f_fwd.def_vmap + def f_fwd_vmap(_, in_batched, x, y): + # Insert a new const here to test the optimize_remat batching rule. + out = np.array([2.0])*f(x, y) + out_batched = (True, (True, True, True)) + return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) + jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error + + def test_dce(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(y)) + + def f_bwd(res, cts): + cos_x, sin_y = res + ct_a, ct_b = cts + return 2.0 * cos_x * ct_a + 1.5 * ct_b, -0.5 * sin_y * ct_b + + f.defvjp(f_fwd, f_bwd) + + def check_jaxpr(jaxpr, used_outs, includes, excludes): + dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) + if not dce_jaxpr.eqns: + assert not includes + return + call_jaxpr = dce_jaxpr.eqns[0].params["fun_jaxpr"] + for prim in includes: + assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) + for prim in excludes: + assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) + + x, y = 0.1, -1.3 + jaxpr = jax.make_jaxpr(f)(x, y).jaxpr + check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) + check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) + check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) + check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) + + def dce_jaxpr_as_fun(jaxpr, used_outs): + jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) + fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) + return lambda *args: fun(*args)[0] + + f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) + f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) + self.assertAllClose( + api.grad(f0, argnums=(0, 1))(x, y), (2.0 * jnp.cos(x), 0.0)) + self.assertAllClose( + api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) + + def test_resolve_kwargs_error_message(self): + @jax.custom_vjp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + self.fail("should not be executed") + + def f_bwd(res, cts): + self.fail("should not be executed") + + f.defvjp(f_fwd, f_bwd) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + +def transpose_unary(f, x_example): + def transposed(y): + x, = api.linear_transpose(f, x_example)(y) + return x + return transposed + + +# This class wraps jax.custom_transpose.custom_transpose in order to pass in a +# particular tree of output type on each call. Otherwise it forwards +# all attribute access. +class _custom_transpose: + def __init__(self, out_types, fun): + self.out_types = out_types + self.fun = jax.custom_transpose.custom_transpose(fun) + + def __getattr__(self, name): + return getattr(self.fun, name) + + def __call__(self, *args): + return self.fun(self.out_types, *args) + + +# This function is meant to be used as a decorator that delegates to +# custom_transpose but makes it easy to specify output argument types +# by example. If used directly a decorator (i.e. 
not invoked with +# example arguments), assumes a scalar-valued function. +# +# TODO(frostig): remove this (and its uses) once custom_transpose offers +# an option of inferring output types. +def custom_transpose(example_out): + if isinstance(example_out, Callable): + out_type = core.get_aval(0.).to_tangent_aval() + return _custom_transpose(out_type, example_out) + return partial( + _custom_transpose, + jax.tree.map( + lambda x: core.get_aval(x).to_tangent_aval(), example_out)) + + +class CustomTransposeTest(jtu.JaxTestCase): + + def test_linear_call(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_linear_call_incorrect_transpose(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / (2. * r) # nb: not the true transpose + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_linear_call_transpose_transpose_transpose(self): + def fn(r, x): return x / r + def tp(r, t): return t / (2. * r) # nb: untrue transpose + def f_(x, y): + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f = lambda x: f_(x, y) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + self.assertAllClose(ft(x), x + tp(y, x)) + self.assertAllClose(f(x), ftt(x)) + self.assertAllClose(ft(x), fttt(x)) + + def test_linear_call_scalar_to_vector(self): + def f(c, x): + def fn(_, x): + return [x, x] + + def tp(_, t): + t1, t2 = t + return t1 + t2 + + return jax.custom_derivatives.linear_call(fn, tp, (), c * x) + + def f_ref(c, x): + return [c * x, c * x] + + c, x = 2., 3. + t = [4., 5.] + self.assertAllClose(f(c, x), f_ref(c, x)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) + + def test_linear_call_nested(self): + # identity function with an untrue transpose of 0 + def id_(x): + def f(_, x): return x + def t(_, t): return 0. + return jax.custom_derivatives.linear_call(f, t, (), x) + + # identity function with an untrue transpose of 7, and where both + # forward and transpose have custom transpositions that should + # never end up invoked. + def f(x): + def f_(_, x): return id_(x) + def t_(_, t): return id_(7.) + return jax.custom_derivatives.linear_call(f_, t_, (), x) + + x = 5. + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + + self.assertAllClose(id_(x), x) + self.assertAllClose(id_t(x), 0.) + self.assertAllClose(id_tt(x), x) + + self.assertAllClose(f(x), x) + self.assertAllClose(ft(x), 7.) + self.assertAllClose(ftt(x), x) + self.assertAllClose(fttt(x), 7.) 
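+
+  # The contract exercised by the linear_call tests, stated directly:
+  # linear_call(fn, tp, r, x) promises that tp transposes fn in its linear
+  # argument, i.e. <t, fn(r, x)> == <tp(r, t), x> for every cotangent t.
+  # A small numerical check of that identity (an illustrative sketch, not
+  # one of the original tests; division by r is self-adjoint):
+  def _linear_call_transpose_identity_sketch(self):
+    fn = lambda r, x: x / r
+    tp = lambda r, t: t / r
+    r = jnp.ones(2) * 3.
+    x = jnp.ones(2) * 6.
+    t = jnp.ones(2) * 2.
+    self.assertAllClose(jnp.vdot(t, fn(r, x)), jnp.vdot(tp(r, t), x))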
+ + def test_linear_call_jit(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f1 = lambda x: f(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + jax.jit(transpose_unary(f1, x))(x)) + + def test_linear_call_type_mismatch(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return None + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f1 = lambda x: f(x, y) + with self.assertRaisesRegex(TypeError, "transpose output pytree"): + transpose_unary(f1, x)(x) + + def test_linear_call_recursion(self): + def f(x): + def fn(_, x): return x + def tp(_, t): return f(t) + return jax.custom_derivatives.linear_call(fn, tp, None, x) + jax.jit(f)(0.1) + + def test_linear_call_grad(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.array(6.) + y = jnp.array(3.) + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_ref)(x, y)) + + def test_basic(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_incorrect_transpose(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / (2. * r) # nb: not the true transpose + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_transpose_transpose_transpose(self): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @custom_transpose(jnp.ones(2)) + def tp(r, t): return t / (2. * r) # nb: untrue transpose + + fn.def_transpose(tp) + tp.def_transpose(fn) + + def f_(x, y): + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f = lambda x: f_(x, y) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + self.assertAllClose(ft(x), x + tp(y, x)) + self.assertAllClose(f(x), ftt(x)) + self.assertAllClose(ft(x), fttt(x)) + + def test_scalar_to_vector(self): + def f(c, x): + @custom_transpose([0., 0.]) + def fn(_, x): + return [x, x] + + @fn.def_transpose + def tp(_, t): + t1, t2 = t + return t1 + t2 + + return fn((), c * x) + + def f_ref(c, x): + return [c * x, c * x] + + c, x = 2., 3. + t = [4., 5.] + self.assertAllClose(f(c, x), f_ref(c, x)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) + + def test_nested(self): + # identity function with an untrue transpose of 0 + def id_(x): + f = custom_transpose(lambda _, x: x) + t = custom_transpose(lambda _, t: 0.) 
+ f.def_transpose(t) + t.def_transpose(f) + return f((), x) + + # identity function with an untrue transpose of 7, and where both + # forward and transpose have custom transpositions that should + # never end up invoked. + def f(x): + f_ = custom_transpose(lambda _, x: id_(x)) + t_ = custom_transpose(lambda _, t: id_(7.)) + f_.def_transpose(t_) + t_.def_transpose(f_) + return f_((), x) + + x = 5. + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + + self.assertAllClose(id_(x), x) + self.assertAllClose(id_t(x), 0.) + self.assertAllClose(id_tt(x), x) + + self.assertAllClose(f(x), x) + self.assertAllClose(ft(x), 7.) + self.assertAllClose(ftt(x), x) + self.assertAllClose(fttt(x), 7.) + + def test_one_degree(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + @f.def_transpose + def ft(_, z): return 3. * z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(3., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(3., T(T(T(T(f))))(1.)) # ... + + def test_two_degrees(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + + @f.def_transpose + @custom_transpose + def ft(_, z): return 3. * z + + @ft.def_transpose + def ftt(_, z): return 7. * z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(7., T(T(f))(1.)) + self.assertAllClose(7., T(T(T(f)))(1.)) + self.assertAllClose(7., T(T(T(T(f))))(1.)) # ... + + def test_symmetric(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + @custom_transpose + def g(_, z): return 3. * z + + f.def_transpose(g) + g.def_transpose(f) + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(2., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(2., T(T(T(T(f))))(1.)) # ... + + def test_recursive(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(c, z): return c * z + + @f.def_transpose + def ft(c, z): return f(c + 1., z) + + g = partial(f, 1.) + self.assertAllClose(1., g(1.)) + self.assertAllClose(2., T(g)(1.)) + self.assertAllClose(3., T(T(g))(1.)) + self.assertAllClose(4., T(T(T(g)))(1.)) + self.assertAllClose(5., T(T(T(T(g))))(1.)) # ... + + def test_jvp_lin(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, tx = 6., 3., 1. + g = lambda x: f(x, y) + g_ref = lambda x: f_ref(x, y) + self.assertAllClose(api.jvp(g, [x], [tx]), api.jvp(g_ref, [x], [tx])) + + def test_jvp_res(self): + raise unittest.SkipTest('unimplemented') # TODO(frostig) + + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, ty = 6., 3., 1. 
+ g = lambda y: f(x, y) + g_ref = lambda y: f_ref(x, y) + self.assertAllClose(api.jvp(g, [y], [ty]), api.jvp(g_ref, [y], [ty])) + + def test_jvp_both(self): + raise unittest.SkipTest('unimplemented') # TODO(frostig) + + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, tx, ty = 6., 3., 1., 1. + self.assertAllClose(api.jvp(f, [x, y], [tx, ty]), + api.jvp(f_ref, [x, y], [tx, ty])) + + def test_make_jaxpr(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + + jaxpr = api.make_jaxpr(f_)(x) + self.assertIn('custom_transpose_call', str(jaxpr)) + + jaxpr_t = api.make_jaxpr(f_t)(x) + self.assertNotIn('custom_transpose_call', str(jaxpr_t)) + + def test_jit(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_jit_recursive(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * fn(r, t) + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_jit_signature_deprecation(self): + fun = lambda x: x + if deprecations.is_accelerated('jax-jit-positional-args'): + with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'): + jax.jit(fun=fun) + with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'): + jax.jit(fun, None) + else: + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'): + jax.jit(fun=fun) + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'): + jax.jit(fun, None) + + def test_cond(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + def cond_wrap(f): + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) + + i = 7. + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. 
+ + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = partial(cond_wrap(f_), i) + g_t = transpose_unary(g_, x) + + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_cond_recursive(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * fn(r, t) + + return x + fn(y, x) + + def cond_wrap(f): + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) + + i = 7. + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = partial(cond_wrap(f_), i) + g_t = transpose_unary(g_, x) + + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_compose_custom_jvp(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return f(x), g(x, dx) + + @custom_transpose + def g(x, dx): + return jnp.cos(x) * dx + + @g.def_transpose + def gt(x, t): + return jnp.cos(x) * t + + with config.use_direct_linearize(True): + self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) + + +class CustomDceTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + jnp.exp(x) if used_outs[0] else None, + jnp.sqrt(x) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), jnp.exp(x)) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), jnp.sqrt(x)) + + def test_recursive(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.exp(x), 10 * jnp.sqrt(x) + + @f.def_dce + def f_dce(used_outs, x): + return [2 * v if used else None for used, v in zip(used_outs, f(x))] + + x = 1.1234 + expected = f(x) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), 2 * expected[1]) + + def test_multiple_rounds(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, z): + return jnp.sin(x), jnp.sin(y), jnp.sin(z) + + @f.def_dce + def rule(used_outs, x, y, z): + patterns.append(used_outs) + outs = [ + jnp.cos(v) if used else None for used, v in zip(used_outs, (x, y, z)) + ] + return outs + + patterns = [] + x, y, z = jnp.array(1.), jnp.array(2.), jnp.array(3.) + jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr + new_jaxpr, used_ins = pe.dce_jaxpr(jaxpr, [True, False, True]) + assert used_ins == [True, False, True] + new_jaxpr, used_ins = pe.dce_jaxpr(new_jaxpr, [True, False]) + assert used_ins == [True, False] + assert patterns == [(True, False, True), (True, False, False)], patterns + + def test_batching(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y): + return jnp.sin(x), jnp.sin(y) + + @f.def_dce + def rule(used_outs, x, y): + return ( + jnp.cos(x) if used_outs[0] else None, + jnp.cos(y) if used_outs[1] else None, + ) + + x = jnp.linspace(-0.1, 0.2, 5) + y = jnp.linspace(3.0, 4.0, 5) + self.assertAllClose(jax.vmap(f)(x, y), f(x, y)) + self.assertAllClose( + jax.jit(lambda *args: jax.vmap(f)(*args)[0])(x, y), jnp.cos(x) + ) + self.assertAllClose( + jax.vmap(jax.jit(lambda *args: f(*args)[0]))(x, y), jnp.cos(x) + ) + self.assertAllClose( + jax.jit(lambda *args: jax.vmap(f)(*args)[1])(x, y), jnp.cos(y) + ) + self.assertAllClose( + jax.vmap(jax.jit(lambda *args: f(*args)[1]))(x, y), jnp.cos(y) + ) + + def test_composes_with_custom_vjp(self): + # custom_dce must be the "outer" decorator (for now!) 
because custom_vjp + # doesn't pass through DCE. + @jax.experimental.custom_dce.custom_dce + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y, x * jnp.sin(y) + + @f.def_dce + def f_dce_rule(used_outs, x, y): + return ( + jnp.cos(x) * y if used_outs[0] else None, + x * jnp.cos(y) if used_outs[1] else None, + ) + + def f_fwd(x, y): + return f(x, y), (x, jnp.cos(x), jnp.sin(x), y, jnp.cos(y), jnp.sin(y)) + + def f_bwd(res, g): + ga, gb = g + x, cos_x, sin_x, y, cos_y, sin_y = res + return (cos_x * ga * y + sin_y * gb, sin_x * ga + x * cos_y * gb) + + f.defvjp(f_fwd, f_bwd) + + x, y = jnp.array(1.), jnp.array(2.) + self.assertAllClose(jax.jit(lambda *args: f(*args)[0])(x, y), + jnp.cos(x) * y) + jax.grad(lambda *args: f(*args)[0])(x, y) # Doesn't crash. + + def test_can_optimize_remat(self): + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + @jax.experimental.custom_dce.custom_dce + def f_fwd(x): + return jnp.sin(x), (x,) + + @f_fwd.def_dce + def f_dce_rule(used_outs, x): + used_prim, used_res = used_outs + used_res, = used_res + if not used_res: + return f(x), None + prim, res = f_fwd(x) + return prim if used_prim else None, res + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd) + + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + v, g = jax.value_and_grad(temp)(3.2) + self.assertAllClose(v, jnp.tan(3.2)**2) + + def test_static_argnums(self): + @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) + def g(f, x): + return f(x), 10 * f(x) + + @g.def_dce + def g_dce(f, used_outs, x): # note: static_argnums are always passes first + self.assertTrue(callable(f)) + return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] + + x = 1.1234 + f = lambda x: jnp.exp(x) + expected = g(f, x) + self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) + + def test_shape_mismatch_error(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.stack((x, x)), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + jnp.exp(x) if used_outs[0] else None, + x.astype(jnp.int32) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* same shapes/dtypes .* output\[0\]', + ): + jax.jit(lambda x: f(x)[0])(x) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* same shapes/dtypes .* output\[1\]', + ): + jax.jit(lambda x: f(x)[1])(x) + + def test_missing_output_error(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return None, None + + x = jnp.array(1.1234) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* produce values for all .* output\[0\]', + ): + jax.jit(lambda x: f(x)[0])(x) + + def test_consts(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return np.eye(1) * jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None, + jnp.sqrt(x) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + expected = rule([True, True], x) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) + + def test_resolve_kwargs_error_message(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, *, z=None): + return jnp.sin(x) * y, x * jnp.sin(y) + + 
@f.def_dce + def f_dce_rule(used_outs, x, y): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + +class CustomVmapTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, jnp.cos(xs)) + + @jax.numpy_dtype_promotion('standard') + def test_closure(self): + z = jnp.array([2., 1., 3.]) + + @jax.custom_batching.custom_vmap + def f(x): return z + jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, *args): + self.assertEqual(len(in_batched), 1) + self.assertEqual(len(args), 1) + xs, = args + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return z + jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x) + self.assertAllClose(y, z + jnp.sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, z + jnp.cos(xs)) + + def test_rule_multi_output(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x), jnp.cos(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2) + + x, xs = jnp.array(1.), jnp.arange(3) + y1, y2 = f(x) + self.assertAllClose(y1, jnp.sin(x)) + self.assertAllClose(y2, jnp.cos(x)) + ys1, ys2 = api.vmap(f)(xs) + self.assertAllClose(ys1, jnp.cos(xs)) + self.assertAllClose(ys2, jnp.sin(xs)) + + def test_nary(self): + @jax.custom_batching.custom_vmap + def f(x, y): return jnp.sin(x) + y ** 2. + + @f.def_vmap + def rule(axis_size, in_batched, xs, ys): + self.assertEqual(in_batched, [True, True]) + self.assertEqual(axis_size, 3) + self.assertEqual(axis_size, xs.shape[0]) + self.assertEqual(axis_size, ys.shape[0]) + return jnp.cos(xs) + ys ** 2., True + + xs, ys = jnp.arange(3.0), jnp.arange(3.0) + zs = api.vmap(f)(xs, ys) + self.assertAllClose(zs, jnp.cos(xs) + ys ** 2.) 
+ + def test_nary_mixed_batching(self): + @jax.custom_batching.custom_vmap + def vector_dot(u, v): + self.assertEqual(u.ndim, 1) + self.assertEqual(v.ndim, 1) + return u @ v + + size = 4 + vlen = 3 + in_batched_log = [] + + @vector_dot.def_vmap + def vector_dot_vmap_rule(axis_size, in_batched, u, v): + in_batched_log.append(in_batched) + self.assertEqual(axis_size, size) + u_batched, v_batched = in_batched + if u_batched: + self.assertEqual(u.ndim, 2) + self.assertEqual(u.shape[0], size) + else: + self.assertEqual(u.ndim, 1) + self.assertEqual(u.shape[0], vlen) + if v_batched: + self.assertEqual(v.ndim, 2) + self.assertEqual(v.shape[0], size) + else: + self.assertEqual(v.ndim, 1) + self.assertEqual(v.shape[0], vlen) + if u_batched and v_batched: + out = jnp.sum(u * v, axis=1) + else: + out = u @ v if u_batched else v @ u + return out, u_batched or v_batched + + f = vector_dot + v = lambda *shape: jnp.ones(shape) + + y = api.vmap(f, in_axes=(0, None))(v(4, 3), v(3)) + self.assertAllClose(y, v(4, 3) @ v(3)) + y = api.vmap(f, in_axes=(1, None))(v(3, 4), v(3)) + self.assertAllClose(y, v(3, 4).T @ v(3)) + y = api.vmap(f, in_axes=(None, 0))(v(3), v(4, 3)) + self.assertAllClose(y, v(3) @ v(4, 3).T) + y = api.vmap(f, in_axes=(0, 0))(v(4, 3), v(4, 3)) + self.assertAllClose(y, jnp.sum(v(4, 3) * v(4, 3), axis=1)) + self.assertEqual(in_batched_log[0], [True, False]) + self.assertEqual(in_batched_log[1], [True, False]) + self.assertEqual(in_batched_log[2], [False, True]) + self.assertEqual(in_batched_log[3], [True, True]) + + def test_rule_input_signature(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + rule_args = [] + + @f.def_vmap + def rule(axis_size, in_batched, xs): + rule_args.append((axis_size, in_batched)) + return jnp.cos(xs), in_batched[0] + + xs = jnp.arange(3) + _ = api.vmap(f)(xs) + (axis_size, in_batched), = rule_args + self.assertIs(type(axis_size), int) + self.assertIs(type(in_batched), list) + self.assertEqual(len(in_batched), 1) + + def test_rule_output_vs_batching_output_mismatch(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def test_rule_abc(axis_size, in_batched, xs): + return [jnp.sin(xs), jnp.cos(xs)], in_batched + + xs = jnp.arange(3) + self.assertRaisesRegex( + ValueError, + 'structure of output value and output batching specification ' + r'returned by custom vmap rule \(test_rule_abc\) do not match.*', + lambda: api.vmap(f)(xs)) + + def test_rule_vs_call_output_mismatch(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def test_rule_abc2(axis_size, in_batched, xs): + return [jnp.sin(xs)], in_batched + + xs = jnp.arange(3) + self.assertRaisesRegex( + ValueError, + r'structure of output returned by custom vmap rule \(test_rule_abc2\) ' + r'does not match that of original custom-vmapped function.*', + lambda: api.vmap(f)(xs)) + + def test_jvp_basic(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + x, tx = jnp.array(1.), jnp.array(2.) + xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. 
+ + y, ty = f_jvp(x, tx) + self.assertAllClose(y, jnp.sin(x)) + self.assertAllClose(ty, jnp.cos(x) * tx) + + ys, tys = api.vmap(f_jvp)(xs, txs) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * txs) + + ys, tys = api.jvp(api.vmap(f), [xs], [txs]) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * txs) + + @jax.numpy_dtype_promotion('standard') + def test_jvp_closure(self): + z = jnp.array([2., 1., 3.]) + def bcast(x): return z + x - z + + @jax.custom_batching.custom_vmap + def f(x): return z + jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True]) + return z + jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + x, tx = jnp.array(1.), jnp.array(2.) + xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. + + y, ty = f_jvp(x, tx) + self.assertAllClose(y, z + jnp.sin(x)) + self.assertAllClose(ty, bcast(jnp.cos(x)) * tx) + + ys, tys = api.vmap(f_jvp)(xs, txs) + self.assertAllClose(ys, z + jnp.cos(xs)) + self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) + + ys, tys = api.jvp(api.vmap(f), [xs], [txs]) + self.assertAllClose(ys, z + jnp.cos(xs)) + self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) + + def test_jvp_nary(self): + @jax.custom_batching.custom_vmap + def f(x, y): return jnp.sin(x) + y + + @f.def_vmap + def rule(axis_size, in_batched, xs, ys): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True, True]) + return jnp.cos(xs) + ys, True + + f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty]) + + x, y, tx, ty = jnp.arange(4.) + xs, ys, txs, tys = 4. + jnp.arange(3. * 4).reshape((4, 3)) + + zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys) + self.assertAllClose(zs, jnp.cos(xs) + ys) + self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) + + zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys]) + self.assertAllClose(zs, jnp.cos(xs) + ys) + self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) + + def test_jvp_extra_batched_tangents(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + txs = 2. + jnp.arange(3.) + x = jnp.array(1, dtype=txs.dtype) + y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs) + self.assertAllClose(y, jnp.cos(x)) + self.assertAllClose(tys, -jnp.sin(x) * txs) + + def test_jacfwd(self): + # jacfwd is another way to exercise extra-batched tangents + + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + x = jnp.arange(3.) + .72 + j = api.jacfwd(f)(x) + self.assertAllClose(j, -jnp.diag(jnp.sin(x))) + + def test_jvp_extra_batched_primals(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + xs = jnp.arange(3.) 
+ tx = jnp.array(4, dtype=xs.dtype) + ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * tx) + + def test_jvp_extra_batched_primals_with_linear_vmap_rule(self): + # When a function is linear, its Jacobian is constant. JAX's JVP + # of linear functions takes advantage of this: when mapping over a + # batch of primals relative to a fixed (i.e. symbolically + # replicated) tangent, output tangents remain replicated as well + # (i.e. JAX will not broadcast them). This is true in general, and + # this test checks that vmapped JVPs continue to behave this way + # when custom_vmap is involved and the custom vmap rule is linear. + + @jax.custom_batching.custom_vmap + def f_linear(x): return 7. * x + + @f_linear.def_vmap + def linear_rule(axis_size, in_batched, xs): + return 11. * xs, in_batched[0] + + @jax.custom_batching.custom_vmap + def f_nonlinear(x): return jnp.sin(x) + + @f_nonlinear.def_vmap + def nonlinear_rule(axis_size, in_batched, xs): + return jnp.cos(xs), in_batched[0] + + f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx]) + f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx]) + xs = jnp.arange(3.) + tx = jnp.array(4., dtype=xs.dtype) + + # doesn't err + _ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx) + + # does err + self.assertRaisesRegex( + ValueError, "at vmap out_axes", + lambda: api.vmap( + f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)) + + def test_jvp_dataflow_violation(self): + # The jvp-of-custom-vmap machinery should not assume the standard + # dataflow constraint on the JVP of the custom vmap rule (primal + # outputs independent of tangent inputs). Both jvp and vmap are + # "forward" transformations under which, at present, we don't + # enforce the JVP dependence diagram. Because output primals can + # depend on input tangents, extra-batched input tangents can + # create batched output primals, as this test checks. + + @jax.custom_jvp + def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x) + + @cos_with_invalid_dataflow_jvp.defjvp + def invalid_dataflow_jvp(x, tx): + [x], [tx] = x, tx + return jnp.cos(x * tx), tx + + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + return cos_with_invalid_dataflow_jvp(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + txs = 2. + jnp.arange(3.) 
+ x = jnp.array(1, dtype=txs.dtype) + + # doesn't err + ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs) + self.assertAllClose(ys, jnp.cos(x * txs)) + self.assertAllClose(tys, txs) + + # does err + self.assertRaisesRegex( + ValueError, "at vmap out_axes", + lambda: api.vmap( + f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)) + + def test_tree(self): + tree_sin = partial(jax.tree.map, jnp.sin) + tree_cos = partial(jax.tree.map, jnp.cos) + + x, xs = jnp.array(1.), jnp.arange(3) + x = (x, [x + 1, x + 2], [x + 3], x + 4) + xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4) + in_batched_ref = jax.tree.map(lambda _: True, x) + + @jax.custom_batching.custom_vmap + def f(xs): return tree_sin(xs) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [in_batched_ref]) + sz, = {z.shape[0] for z in jax.tree.leaves(xs)} + self.assertEqual(axis_size, sz) + return tree_cos(xs), in_batched[0] + + y = f(x) + self.assertAllClose(y, tree_sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, tree_cos(xs)) + + def test_tree_with_nones(self): + tree_sin = partial(jax.tree.map, jnp.sin) + tree_cos = partial(jax.tree.map, jnp.cos) + + x, xs = jnp.array(1.), jnp.arange(3) + x = (x, [x + 1, None], [x + 3], None) + xs = (xs, [xs + 1, None], [xs + 3], None) + in_batched_ref = jax.tree.map(lambda _: True, x) + + @jax.custom_batching.custom_vmap + def f(xs): return tree_sin(xs) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [in_batched_ref]) + sz, = {z.shape[0] for z in jax.tree.leaves(xs)} + self.assertEqual(axis_size, sz) + return tree_cos(xs), in_batched[0] + + y = f(x) + self.assertAllClose(y, tree_sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, tree_cos(xs)) + + def test_jit(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [True]) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), in_batched[0] + + x, xs = jnp.array(1.), jnp.arange(3) + self.assertAllClose(f(x), jit(f)(x)) + self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs)) + self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) + + def test_sequential_vmap_basic(self): + @jax.custom_batching.sequential_vmap + def f(x): + return x + 1. + + def vmap_ref(xs): + return lax.map(f, xs) + + xs = jnp.arange(3.) + jaxpr = api.make_jaxpr(api.vmap(f))(xs) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_same_batching(self): + @jax.custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, ys): + return lax.map(lambda args: f(*args), (xs, ys)) + + xs, ys = jnp.arange(3.), 4. + jnp.arange(3.) + jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_mixed_batching(self): + @jax.custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, y): + return lax.map(lambda x: f(x, y), xs) + + xs, y = jnp.arange(3.), 4. 
+ jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + @parameterized.named_parameters( + ("1", 1), + ("8", 4), + ("12", 8), + ("16", 16), + ) + def test_batch_map_basic(self, batch_size: int): + def f(x): + self.assertEqual(x.shape, ()) + return x**2 + + x = np.arange(16) + y = jax.lax.map(f, x, batch_size=batch_size) + + np.testing.assert_array_equal(y, x**2) + + @parameterized.named_parameters( + ("1", 1), + ("8", 4), + ("12", 8), + ("16", 16), + ) + def test_batch_map_pytrees(self, batch_size: int): + f = lambda x: {'b': x['a'] ** 2} + inputs = {'a': np.arange(16)} + expected = np.arange(16) ** 2 + + outputs = jax.lax.map(f, inputs, batch_size=batch_size) + self.assertAllClose(outputs['b'], expected) + + outputs = jax.lax.map( + f, inputs, batch_size=batch_size + ) + self.assertAllClose(outputs['b'], expected) + + def test_batch_divides_axis(self): + def f(t): + x, a = t + self.assertEqual(x.shape, (4,)) + return (x + a)**2 + + x = jax.random.randint(jax.random.key(0), (16, 4), -10, 10) + a = jax.random.randint(jax.random.key(1), (16, 4), -10, 10) + + @jax.jit + def g(x, a): + return jax.lax.map(f, (x, a), batch_size=8) + + y = g(x, a) + + self.assertAllClose(y, (x + a)**2) + + def test_undefined_rule(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + with self.assertRaisesRegex( + AttributeError, "No batching rule defined for custom_vmap function f"): + f(0.5) + + def test_kwargs(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x=x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(x=xs) + self.assertAllClose(ys, jnp.cos(xs)) + + def test_partial_eval_raises(self): + @jax.custom_batching.custom_vmap + def f(x): + return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + del axis_size # unused + return jnp.cos(xs), in_batched[0] + + with self.assertRaisesRegex( + ValueError, + "Linearization failed to produce known values for all output primals", + ): + jax.grad(f)(0.5) + + def test_compose_custom_vjp(self): + @jax.custom_vjp + @jax.custom_batching.custom_vmap + def f(x, y): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + return jnp.cos(xs) * ys, True + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd) + + xs = jnp.linspace(0, 1, 5) + ys = jnp.linspace(-0.1, 0.1, 5) + self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) + jax.grad(f)(xs[0], ys[0]) # Doesn't crash. + + def test_compose_custom_vjp_bwd_rule(self): + # This tests the case where both the forward and backward rules are wrapped + # in custom_vmap. + @jax.custom_batching.sequential_vmap + def fun_fwd(x, y): + return jnp.sin(x) * y, (x, y) + + @jax.custom_batching.sequential_vmap + def fun_bwd(res, ct): + x, y = res + return x * ct, y * ct + + fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) + fun.defvjp(fun_fwd, fun_bwd) + + xs = jnp.linspace(0, 1, 5) + y = jnp.array(0.5, dtype=xs.dtype) + f = jax.vmap(jax.jit(fun), in_axes=(0, None)) + out, f_vjp = jax.vjp(f, xs, y) + f_vjp(out) # Doesn't crash. 
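+
+  # For intuition about the batch_size tests above: lax.map with
+  # batch_size=k behaves roughly like reshaping into blocks of k, mapping
+  # vmap(f) over the blocks, and reshaping back. A sketch of that
+  # equivalence, assuming k divides the leading axis evenly (illustrative,
+  # not one of the original tests):
+  def _batch_size_reference_sketch(self):
+    f = lambda x: x ** 2
+    xs = jnp.arange(16)
+    k = 4
+    blocks = xs.reshape(xs.shape[0] // k, k)
+    reference = jax.lax.map(jax.vmap(f), blocks).reshape(xs.shape)
+    self.assertAllClose(reference, jax.lax.map(f, xs, batch_size=k))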
+ + def test_resolve_kwargs_error_message(self): + @jax.custom_batching.custom_vmap + def f(x, y, *, z=None): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + +class CustomApiTest(jtu.JaxTestCase): + """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" + + def test_method_forwarding(self): + @jax.custom_batching.custom_vmap + @jax.custom_jvp + @jax.custom_transpose.custom_transpose + def f(x): return 2. * x + + # none of these err: + @f.def_vmap + def f_batch(sz, b, xs): return 2. * xs + @f.defjvp + def f_jvp(x, tx): return 2. * x, 2. * tx + @f.def_transpose + def f_transpose(x): return 2. * x + + def test_def_method_forwarding_all_permutations(self): + for wraps in it.permutations([ + jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): + f = lambda x: x + 1. + for wrap in wraps: + f = wrap(f) + for methods in it.permutations(['defjvp', 'def_vmap', 'def_transpose']): + for method in methods: + self.assertIsInstance(getattr(f, method), Callable) + + for decorators in it.permutations([ + jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): + f = lambda x: x + 1. + for decorator in decorators: + f = decorator(f) + for methods in it.permutations(['defvjp', 'def_vmap', 'def_transpose']): + for method in methods: + self.assertIsInstance(getattr(f, method), Callable) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) From e194d532a46903b8ca5d811ab4d63c6bb401c367 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 9 May 2025 09:49:03 -0700 Subject: [PATCH 1099/1769] Add `.update` to ShapeDtypeStruct PiperOrigin-RevId: 756804171 --- jax/_src/api.py | 18 ++++++++++++++++++ tests/pjit_test.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/jax/_src/api.py b/jax/_src/api.py index 379f6d8d7c93..5c8a86c035e6 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2870,6 +2870,24 @@ def __hash__(self): # https://github.com/jax-ml/jax/issues/8182 return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) + def update(self, **kwargs): + if 'sharding' in kwargs: + s = kwargs['sharding'] + if self._dll is not None and isinstance(s, Sharding): + raise ValueError( + f"You are updating ShapeDtypeStruct with a {type(s)} when the" + f" original ShapeDtypeStruct had a concrete layout {self.layout}." + " This might lead to bugs. 
If you want to do this, create a new" + " ShapeDtypeStruct via the constructor.") + sharding = s + else: + sharding = self.layout + return ShapeDtypeStruct( + shape=kwargs.pop('shape', self.shape), + dtype=kwargs.pop('dtype', self.dtype), + sharding=sharding, + weak_type=kwargs.pop('weak_type', self.weak_type)) + def _sds_aval_mapping(x): aval = ShapedArray( x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e4b91aa17644..55c38f6c2ef8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -55,6 +55,7 @@ SingleDeviceSharding, parse_flatten_op_sharding) from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes, use_auto_axes, use_explicit_axes, reshard) +from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src.named_sharding import DuplicateSpecError from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType @@ -4973,6 +4974,40 @@ def make_keys(seeds): self.assertEqual(out.shape, input_shape) jax.random.key_data(out) # doesn't crash + def test_sds_update(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + s1 = jax.ShapeDtypeStruct((2, 2), jnp.int32) + s1_u = s1.update(shape=(4, 2), dtype=np.float32) + self.assertEqual(s1_u.shape, (4, 2)) + self.assertEqual(s1_u.dtype, np.float32) + self.assertFalse(s1_u.weak_type) + + s2 = jax.ShapeDtypeStruct((2, 2), jnp.int32) + s2_u = s2.update(shape=(4, 2), weak_type=True) + self.assertEqual(s2_u.shape, (4, 2)) + self.assertEqual(s2_u.dtype, np.int32) + self.assertTrue(s2_u.weak_type) + + s3 = jax.ShapeDtypeStruct((2, 2), jnp.int32, + sharding=NamedSharding(mesh, P())) + s3_u = s3.update(sharding=NamedSharding(mesh, P('x'))) + self.assertEqual(s3_u.sharding, NamedSharding(mesh, P('x'))) + + s32_u = s3.update(shape=(4, 2)) + self.assertEqual(s32_u.shape, (4, 2)) + self.assertEqual(s32_u.sharding, NamedSharding(mesh, P())) + + sh = NamedSharding(mesh, P()) + s4 = jax.ShapeDtypeStruct((2, 2), jnp.int32, + sharding=Layout(DLL((0, 1)), sh)) + new_layout = Layout(DLL((1, 0)), NamedSharding(mesh, P('x'))) + s4_u = s4.update(sharding=new_layout) + self.assertEqual(s4_u.sharding, new_layout.sharding) + self.assertEqual(s4_u.layout, new_layout) + + with self.assertRaisesRegex(ValueError, "updating ShapeDtypeStruct"): + s4.update(sharding=NamedSharding(mesh, P('x'))) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From 7caebde796b60cf7ff2763b5caa2f4e78b40ca12 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 9 May 2025 10:42:52 -0400 Subject: [PATCH 1100/1769] Use cython from pypi in tsan CI build. Cython 3.1 was released, which means we no longer need a prerelease of cython for free-threaded builds. 
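
In short, the free-threaded jobs previously had to bootstrap a nightly wheel
before building numpy/scipy, roughly:

  python3 -m uv pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple cython

and can now install the release from PyPI together with the other build
dependencies, as the workflow diff below does:

  python3 -m uv pip install cython pythran pybind11 meson-python ninja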
--- .github/workflows/tsan.yaml | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 596bc425bfeb..882e140b91ad 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -31,7 +31,7 @@ jobs: requirements_lock_name: "requirements_lock_3_13_ft" - name-prefix: "with 3.14" python-version: "3.14" - github_branch: "main" + github_branch: "3.14" requirements_lock_name: "requirements_lock_3_14_ft" defaults: run: @@ -133,9 +133,6 @@ jobs: python3 -m pip install uv~=0.5.30 - # Install Cython same as in numpy CI: https://github.com/numpy/numpy/blob/9ead596ce4f8df0189f9ba3d54937e22e2785a5e/.github/workflows/linux.yml#L75C21-L75C96 - python3 -m uv pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple cython - python3 -m uv pip install -r requirements/build_requirements.txt CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized @@ -204,11 +201,8 @@ jobs: python3 -m pip install uv~=0.5.30 - # Install Cython same as in numpy CI: https://github.com/numpy/numpy/blob/9ead596ce4f8df0189f9ba3d54937e22e2785a5e/.github/workflows/linux.yml#L75C21-L75C96 - python3 -m uv pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple cython - python3 -m uv pip install -U --pre numpy --extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/ - python3 -m uv pip install pythran pybind11 meson-python ninja + python3 -m uv pip install cython pythran pybind11 meson-python ninja python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" @@ -216,8 +210,6 @@ jobs: export CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -vvv . --no-build-isolation --no-deps -Csetup-args=-Dbuildtype=debugoptimized - python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" - # Create simple index and copy the wheel mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/scipy @@ -266,7 +258,6 @@ jobs: export PYTHON_SHA256=($(sha256sum ${GITHUB_WORKSPACE}/python-tsan.tgz)) echo "Python sha256: ${PYTHON_SHA256}" - python3 -VV python3 build/build.py build --configure_only \ --python_version=${{ matrix.python-version }}-ft \ --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ From 9eb8c9eacbda9f395c2561889af1866b36be7a23 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 9 May 2025 10:21:23 -0700 Subject: [PATCH 1101/1769] [Mosaic] Move sitofp lowering to Mosaic. Also, the compatibility check in fptosi is likely wrong - we should check if both src and dst bit widths < 32, not just dst. Correct it while I'm here. PiperOrigin-RevId: 756816853 --- jax/_src/pallas/mosaic/lowering.py | 9 ++- .../tpu/transforms/canonicalize_mosaic.cc | 80 ++++++++++++++----- jaxlib/mosaic/dialect/tpu/util.h | 10 +++ 3 files changed, 76 insertions(+), 23 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index d7e26ec3b342..cdc64bbf96f4 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2177,11 +2177,14 @@ def _convert_element_type_lowering_rule( elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits: # This case triggers when casting signed to unsigned or vice versa. return x - # TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer. 
   elif _from(floating) and _to(signed):
     return arith.fptosi(out_type, x)
-  elif _from(signed) and _to(floating) and both_32bit:
-    return arith.sitofp(out_type, x)
+  elif _from(signed) and _to(floating):
+    if (
+        not (ctx.forward_compatible or is_cloud_tpu_older_than(2025, 5, 12))
+        or both_32bit
+    ):
+      return arith.sitofp(out_type, x)
   elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
     return arith.extui(out_type, x)
   return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 110550127ca5..258fec8caff1 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -44,13 +44,13 @@ limitations under the License.
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Region.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
+#include "jaxlib/mosaic/dialect/tpu/util.h"
 #include "jaxlib/mosaic/dialect/tpu/vreg_util.h"

 namespace mlir::tpu {
@@ -601,14 +601,10 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
     return op.emitOpError("Vector/scalar mismatch between input and output");
   }
   bool is_vector = static_cast<bool>(src_vty);
-  unsigned src_bitwidth, dst_bitwidth;
-  if (is_vector) {
-    src_bitwidth = src_vty.getElementTypeBitWidth();
-    dst_bitwidth = dst_vty.getElementTypeBitWidth();
-  } else {
-    src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth();
-    dst_bitwidth = op.getType().getIntOrFloatBitWidth();
-  }
+  FAILUREOR_ASSIGN_OR_RETURN(const unsigned src_bitwidth,
+                             getElementTypeBitwidth(op.getIn().getType()));
+  FAILUREOR_ASSIGN_OR_RETURN(const unsigned dst_bitwidth,
+                             getElementTypeBitwidth(op.getType()));
   if (dst_bitwidth > 32) {
     return op.emitOpError("Target bitwidth too large");
   }
@@ -623,6 +619,14 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
     op.erase();
     return success();
   }
+
+  if ((src_bitwidth < 32 || dst_bitwidth < 32) && !ctx.compatibility_mode) {
+    return op.emitOpError(
+        "On this target float-to-integer conversions can only happen on "
+        "32-bit values. Enable compatibility mode or upcast to float32, cast "
+        "to int32 and truncate to desired bitwidth.");
+  }
+
   Value x = op.getIn();
   // Upcast the input to f32.
   if (src_bitwidth < 32) {
@@ -634,11 +638,6 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
     }
   }
   if (dst_bitwidth < 32) {
-    if (!ctx.compatibility_mode) {
-      return op.emitOpError(
-          "On this target only float-to-integer conversions can only happen on "
-          "32-bit values. Enable compatibility mode or upcast to float32.");
-    }
     // Need to clip values to match XLA
     auto clip = [&](Value x, Value low, Value high) {
       x = builder.create<arith::MaximumFOp>(x, low);
@@ -666,12 +665,6 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
     x = builder.create<arith::FPToSIOp>(builder.getI32Type(), x);
   }
   if (dst_bitwidth < 32) {
-    if (!ctx.compatibility_mode) {
-      return op.emitOpError(
-          "On this target only float-to-integer conversions can only happen on "
-          "32-bit values. Enable compatibility mode or cast to int32 and "
-          "truncate later.");
-    }
     x = builder.create<arith::TruncIOp>(op.getType(), x);
   }
   op.replaceAllUsesWith(x);
   op.erase();
   return success();
 }

+LogicalResult canonicalize_sitofp(const CanonicalizeContext &ctx,
+                                  Operation &raw_op) {
+  auto op = cast<arith::FPToSIOp>(raw_op);
+  ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
+  auto src_vty = dyn_cast<VectorType>(op.getIn().getType());
+  auto dst_vty = dyn_cast<VectorType>(op.getType());
+  if (static_cast<bool>(src_vty) != static_cast<bool>(dst_vty)) {
+    return op.emitOpError("Vector/scalar mismatch between input and output");
+  }
+  bool is_vector = static_cast<bool>(src_vty);
+  FAILUREOR_ASSIGN_OR_RETURN(const unsigned src_bitwidth,
+                             getElementTypeBitwidth(op.getIn().getType()));
+  FAILUREOR_ASSIGN_OR_RETURN(const unsigned dst_bitwidth,
+                             getElementTypeBitwidth(op.getType()));
+
+  if ((src_bitwidth < 32 || dst_bitwidth < 32) && !ctx.compatibility_mode) {
+    return op.emitOpError(
+        "On this target integer-to-float conversions can only happen on "
+        "32-bit values. Enable compatibility mode or upcast to int32, cast to "
+        "float32 and truncate to desired bitwidth.");
+  }
+
+  // Canonicalize (intX -> floatY) to (intX -> int32 -> float32 -> floatY).
+  Value x = op.getIn();
+  if (src_bitwidth < 32) {
+    if (is_vector) {
+      x = builder.create<arith::ExtSIOp>(
+          VectorType::get(src_vty.getShape(), builder.getI32Type()), x);
+    } else {
+      x = builder.create<arith::ExtSIOp>(builder.getI32Type(), x);
+    }
+  }
+  if (is_vector) {
+    x = builder.create<arith::SIToFPOp>(
+        VectorType::get(src_vty.getShape(), builder.getF32Type()), x);
+  } else {
+    x = builder.create<arith::SIToFPOp>(builder.getF32Type(), x);
+  }
+  if (dst_bitwidth < 32) {
+    x = builder.create<arith::TruncFOp>(op.getType(), x);
+  }
+  op.replaceAllUsesWith(x);
+  op.erase();
+  return success();
+}
+
 LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
                                   Operation &raw_op) {
   auto op = dyn_cast<tpu::RepeatOp>(raw_op);
@@ -727,6 +766,7 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
       {vector::TransposeOp::getOperationName(), canonicalize_vector_transpose},
       {arith::SelectOp::getOperationName(), canonicalize_select},
       {arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
+      {arith::SIToFPOp::getOperationName(), canonicalize_sitofp},
       {tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
   return *rules;
 }
diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h
index eed0df14f707..b9aea1b087dc 100644
--- a/jaxlib/mosaic/dialect/tpu/util.h
+++ b/jaxlib/mosaic/dialect/tpu/util.h
@@ -180,6 +180,16 @@ FailureOr<int8_t> getTypeBitwidth(Type ty) {
          << ty;
 }

+// Returns the bitwidth of the element type. The function works for both
+// scalar and vector types.
+template <bool adjust_bool = false>
+inline FailureOr<int8_t> getElementTypeBitwidth(Type ty) {
+  if (auto vty = dyn_cast<VectorType>(ty)) {
+    return getTypeBitwidth<adjust_bool>(vty.getElementType());
+  }
+  return getTypeBitwidth<adjust_bool>(ty);
+}
+
 template <typename T>
 ArrayRef<std::remove_const_t<T>> toArrayRef(absl::Span<T> span) {
   return ArrayRef<std::remove_const_t<T>>(span.data(), span.size());

From d0858a2ce60b94d50445368711208efa74f2379d Mon Sep 17 00:00:00 2001
From: Junwhan Ahn
Date: Fri, 9 May 2025 11:03:28 -0700
Subject: [PATCH 1102/1769] Use `ValueRef` instead of `tsl::RCReference<ifrt::Value>`

PiperOrigin-RevId: 756832576
---
 jaxlib/util.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jaxlib/util.cc b/jaxlib/util.cc
index 814886b9a4d3..a014afa5bebe 100644
--- a/jaxlib/util.cc
+++ b/jaxlib/util.cc
@@ -62,7 +62,7 @@ absl::Status AwaitBuffersReady(absl::Span<ifrt::Array* const> ifrt_arrays) {
   if (ifrt_arrays.size() == 1) {
     future = ifrt_arrays[0]->GetReadyFuture();
   } else {
-    std::vector<tsl::RCReference<ifrt::Value>> values;
+    std::vector<ifrt::ValueRef> values;
     values.reserve(ifrt_arrays.size());
     for (ifrt::Array* const ifrt_array : ifrt_arrays) {
       values.push_back(tsl::FormRef(ifrt_array));

From 701af068593bbf44e540f1641a86697d27afadd2 Mon Sep 17 00:00:00 2001
From: Emily Fertig
Date: Fri, 9 May 2025 11:48:39 -0700
Subject: [PATCH 1103/1769] Reverts 8137c37e324c9cb5c8f991a16d78310b6e37bd05

PiperOrigin-RevId: 756850393
---
 jax/_src/array.py  | 3 ++-
 jaxlib/py_array.cc | 3 +++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index f2b070c8221d..422fa5086e62 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -636,7 +636,8 @@ def _value(self) -> np.ndarray:
     self._check_if_deleted()

     if self._npy_value is None:
-      if self.is_fully_replicated:
+      if (self.is_fully_replicated and
+          self.sharding._internal_device_list.addressable_device_list):  # type: ignore
         npy_value, did_copy = self._single_device_array_to_np_array_did_copy()
         npy_value.flags.writeable = False
         if did_copy:
diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc
index 022c7a831c92..1222d410bad8 100644
--- a/jaxlib/py_array.cc
+++ b/jaxlib/py_array.cc
@@ -1528,6 +1528,9 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) {
     absl::Span<const std::shared_ptr<PjRtBuffer>> buffers =
         array->pjrt_buffers();
+    if (buffers.empty()) {
+      return InvalidArgument("Array has no buffers.");
+    }
     PjRtBuffer& buffer = *buffers.front();
     if (!buffer.IsOnCpu()) {
       return InvalidArgument(

From 0605bc4a5dbe688489b38ea0c8ea89c24e1a7327 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 9 May 2025 11:53:35 -0700
Subject: [PATCH 1104/1769] Make `jax.ShapeDtypeStruct` immutable.

It was always supposed to be immutable since `jax.Array` is immutable and
`ShapeDtypeStruct` is a duck of `jax.Array`, but immutability was never
enforced.

**If you are broken by this change, just update your code to use sds.update(...)**

PiperOrigin-RevId: 756852248
---
 CHANGELOG.md    |  2 ++
 jax/_src/api.py | 12 ++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a03eb80eb973..518e854b5bb1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * JAX nightly packages are now published to artifact registry. To install these
     packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation).
   * `jax.sharding.PartitionSpec` no longer inherits from a tuple.
+  * `jax.ShapeDtypeStruct` is immutable now. Please use the `.update` method to
+    update your `ShapeDtypeStruct` instead of doing in-place updates.
 ## JAX 0.6.0 (April 16, 2025)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index 5c8a86c035e6..9b4ce0fa0f24 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2870,6 +2870,17 @@ def __hash__(self):
     # https://github.com/jax-ml/jax/issues/8182
     return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))

+  def __setattr__(self, name, value):
+    if hasattr(self, name):
+      if getattr(self, name) == value:
+        # This can to happen if two threads race, for example if two threads
+        # are trying to hash the same SDS instance.
+        return
+      raise RuntimeError(
+          f"Cannot reassign attributes ({name}) of immutable ShapeDtypeStruct"
+          " objects")
+    super().__setattr__(name, value)
+
   def update(self, **kwargs):
     if 'sharding' in kwargs:
       s = kwargs['sharding']
@@ -2888,6 +2899,7 @@ def update(self, **kwargs):
         sharding=sharding,
         weak_type=kwargs.pop('weak_type', self.weak_type))

+
 def _sds_aval_mapping(x):
   aval = ShapedArray(
       x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),

From 0d5771ccaae3e82704c315e8c45852779578172f Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 9 May 2025 19:28:09 +0000
Subject: [PATCH 1105/1769] Disable profiler tests under Python 3.14 if
 multithreaded.

These are currently thread-unsafe due to
https://github.com/python/cpython/issues/132817
---
 tests/profiler_test.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index 215e363e446d..d577f1c24c49 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -61,6 +61,12 @@ class ProfilerTest(unittest.TestCase):
   # check functional correctness.

   def setUp(self):
+    if sys.version_info >= (3, 14) and jtu.TEST_NUM_THREADS.value > 1:
+      # TODO(phawkins): try reenabling these after
+      # https://github.com/python/cpython/issues/132817 is fixed. Simply
+      # installing the profiler hook is unsafe if there are multiple threads.
+      self.skipTest("Profiler tests are not thread-safe under Python 3.14")
+
     super().setUp()
     self.worker_start = threading.Event()
     self.profile_done = False

From 6e97afdafb770a6e49be2fed055480be44d82b6b Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Fri, 9 May 2025 12:56:34 -0700
Subject: [PATCH 1106/1769] [Mosaic] Fix typo: FPToSI > SIToFP.

PiperOrigin-RevId: 756874608
---
 jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 258fec8caff1..71e48539f4dd 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -674,7 +674,7 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,

 LogicalResult canonicalize_sitofp(const CanonicalizeContext &ctx,
                                   Operation &raw_op) {
-  auto op = cast<arith::FPToSIOp>(raw_op);
+  auto op = cast<arith::SIToFPOp>(raw_op);
   ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
   auto src_vty = dyn_cast<VectorType>(op.getIn().getType());
   auto dst_vty = dyn_cast<VectorType>(op.getType());

From 4aee38103e2f5795ee3f55d4d6b2b59b37ffc6a1 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 9 May 2025 14:14:22 -0700
Subject: [PATCH 1107/1769] Disallow unreduced inputs for all primitives except
 those that implement the unreduced rule. Currently that's only `add`.
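For illustration, a minimal sketch of the guard this adds (simplified from the
`call_sharding_rule` change in jax/_src/lax/utils.py below; the function name
here is illustrative and the error text is abbreviated):

    # Hedged sketch: primitives opt in by providing an unreduced rule;
    # everything else now rejects inputs whose specs carry unreduced axes.
    def apply_unreduced_rule(prim, unreduced_rule, out_sharding, *avals, **kwargs):
      if unreduced_rule is not None:
        # e.g. `add` decides how unreduced axes propagate to the output.
        return unreduced_rule(out_sharding, *avals, **kwargs)
      if any(a.sharding.spec.unreduced for a in avals):
        raise NotImplementedError(
            f'unreduced rule for {prim.name} is not implemented.')
      return out_sharding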
PiperOrigin-RevId: 756902404 --- jax/_src/api.py | 7 +++---- jax/_src/lax/lax.py | 3 --- jax/_src/lax/utils.py | 29 ++++++++++++++++++----------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 9b4ce0fa0f24..a03933573e5d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2866,14 +2866,13 @@ def __eq__(self, other): (other.shape, other.dtype, other.sharding, other.layout, other.weak_type)) def __hash__(self): - # TODO(frostig): avoid the conversion from dict by addressing - # https://github.com/jax-ml/jax/issues/8182 - return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) + return hash((self.shape, self.dtype, self.sharding, self.layout, + self.weak_type)) def __setattr__(self, name, value): if hasattr(self, name): if getattr(self, name) == value: - # This can to happen if two threads race, for example if two threads + # This can happen if two threads race, for example if two threads # are trying to hash the same SDS instance. return raise RuntimeError( diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 74297fc57b43..a49c27d06eee 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5213,9 +5213,6 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, raise core.ShardingTypeError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') - if lhs.sharding.spec.unreduced or rhs.sharding.spec.unreduced: - raise NotImplementedError( - 'Please file an issue at https://github.com/jax-ml/jax/issues') (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index a850b2965338..97a2687bbb67 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -67,7 +67,7 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: return mesh_lib.empty_abstract_mesh if m is None else m -def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): +def call_sharding_rule(prim, sh_rule, unreduced_rule, num_out, *avals, **kwargs): cur_mesh = mesh_lib.get_abstract_mesh() aval_mesh = _get_abstract_mesh_from_avals(avals) if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and @@ -75,22 +75,30 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh s = NamedSharding(aval_mesh, P()) return s if num_out is None else [s] * num_out - if rule is None: + if sh_rule is None: raise core.ShardingTypeError( - f'sharding rule for {prim.name} is not implemented. Please file a' - ' bug at https://github.com/jax-ml/jax/issues. You can work around' + f'sharding rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues. You can work around' ' this error by dropping that operation into full auto sharding' ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`') - return rule(*avals, **kwargs) + out_sharding = sh_rule(*avals, **kwargs) + if unreduced_rule is not None: + out_sharding = unreduced_rule(out_sharding, *avals, **kwargs) + else: + if any(a.sharding.spec.unreduced for a in avals): + raise NotImplementedError( + f'unreduced rule for {prim.name} is not implemented. 
Please file an' + ' issue at https://github.com/jax-ml/jax/issues') + return out_sharding def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, - multi_out, *avals, **kwargs): + unreduced_rule, multi_out, *avals, **kwargs): out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) num_out = len(out_shapes) if multi_out else None try: out_shardings = call_sharding_rule( - prim, sharding_rule, num_out, *avals, **kwargs) + prim, sharding_rule, unreduced_rule, num_out, *avals, **kwargs) except DuplicateSpecError as e: if multi_out: raise @@ -124,11 +132,9 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, if least_specialized is core.ShapedArray: core.check_avals_context_mesh(avals, prim.name) out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( - prim, shape_rule, dtype_rule, sharding_rule, False, + prim, shape_rule, dtype_rule, sharding_rule, unreduced_rule, False, *avals, **kwargs) out_vma = vma_rule(*avals, **kwargs) - if unreduced_rule is not None: - out_sharding = unreduced_rule(out_sharding, *avals, **kwargs) out_aval = core.ShapedArray( out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, vma=out_vma) @@ -154,7 +160,8 @@ def standard_multi_result_abstract_eval( if least_specialized is core.ShapedArray: core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( - prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) + prim, shape_rule, dtype_rule, sharding_rule, None, True, + *avals, **kwargs) out_vmas = vma_rule(*avals, **kwargs) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) From b95ba89b24a16ca2c61c77581c7264c2dffffcbb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 9 May 2025 14:41:26 -0700 Subject: [PATCH 1108/1769] JEP 28661: the __jax_array__ protocol --- docs/jep/28661-jax-array-protocol.md | 214 +++++++++++++++++++++++++++ docs/jep/index.rst | 1 + 2 files changed, 215 insertions(+) create mode 100644 docs/jep/28661-jax-array-protocol.md diff --git a/docs/jep/28661-jax-array-protocol.md b/docs/jep/28661-jax-array-protocol.md new file mode 100644 index 000000000000..e05d69d2822d --- /dev/null +++ b/docs/jep/28661-jax-array-protocol.md @@ -0,0 +1,214 @@ +# JEP 28661: Supporting the `__jax_array__` protocol + +[@jakevdp](http://github.com/jakevdp), *May 2025* + +An occasional user request is for the ability to define custom array-like objects that +work with jax APIs. JAX currently has a partial implementation of a mechanism that does +this via a `__jax_array__` method defined on the custom object. This was never intended +to be a load-bearing public API (see the discussion at {jax-issue}`#4725`), but has +become essential to packages like Keras and flax, which explicitly document the ability +to use their custom array objects with jax functions. This JEP proposes a design for +full, documented support of the `__jax_array__` protocol. + +## Levels of array extensibility +Requests for extensibility of JAX arrays come in a few flavors: + +### Level 1 Extensibility: polymorphic inputs +What I’ll call "Level 1" extensibility is the desire that JAX APIs accept polymorphic inputs. +That is, a user desires behavior like this: + +```python +class CustomArray: + data: numpy.ndarray + ... 
+
+x = CustomArray(np.arange(5))
+result = jnp.sin(x)  # Converts `x` to JAX array and returns a JAX array
+```
+
+Under this extensibility model, JAX functions would accept CustomArray objects as inputs,
+implicitly converting them to `jax.Array` objects for the sake of computation.
+This is similar to the functionality offered by NumPy via the `__array__` method, and in
+JAX (in many but not all cases) via the `__jax_array__` method.
+
+This is the mode of extensibility that has been requested by the maintainers of `flax.nnx`
+and others. The current implementation is also used by JAX internally for the case of
+symbolic dimensions.
+
+### Level 2 Extensibility: polymorphic outputs
+What I’ll call "Level 2" extensibility is the desire that JAX APIs should not only accept
+polymorphic inputs, but also wrap outputs to match the class of the input.
+That is, a user desires behavior like this:
+
+```python
+class CustomArray:
+  data: numpy.ndarray
+  ...
+
+x = CustomArray(np.arange(5))
+result = jnp.sin(x)  # returns a new CustomArray
+```
+
+Under this extensibility model, JAX functions would not only accept custom objects
+as inputs, but have some protocol to determine how to correctly re-wrap outputs with
+the same class. In NumPy, this sort of functionality is offered in varying degrees by
+the special `__array_ufunc__`, `__array_wrap__`, and `__array_function__` protocols,
+which allow user-defined objects to customize how NumPy API functions operate on
+arbitrary inputs and map input types to outputs.
+JAX does not currently have any equivalent to these interfaces in NumPy.
+
+This is the mode of extensibility that has been requested by the maintainers of `keras`,
+among others.
+
+### Level 3 Extensibility: subclassing `Array`
+
+What I’ll call "Level 3" extensibility is the desire that the JAX array object itself
+could be subclassable. NumPy provides some APIs that allow this
+(see [Subclassing ndarray](https://numpy.org/devdocs/user/basics.subclassing.html)) but
+this sort of approach would take some extra thought in JAX due to the need for
+representing array objects abstractly via tracing.
+
+This mode of extensibility has occasionally been requested by users who want to add
+special metadata to JAX arrays, such as units of measurement.
+
+## Synopsis
+
+For the sake of this proposal, we will stick with the simplest, level 1 extensibility
+model. The proposed interface is the one currently non-uniformly supported by a number
+of JAX APIs, the `__jax_array__` method. Its usage looks something like this:
+
+```python
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+class CustomArray:
+  data: np.ndarray
+
+  def __init__(self, data: np.ndarray):
+    self.data = data
+
+  def __jax_array__(self) -> jax.Array:
+    return jnp.asarray(self.data)
+
+arr = CustomArray(np.arange(5))
+result = jnp.multiply(arr, 2)
+print(repr(result))
+# Array([0, 2, 4, 6, 8], dtype=int32)
+```
+
+We may revisit other extensibility levels in the future.
+
+## Design challenges
+
+JAX presents some interesting design challenges related to this kind of extensibility,
+which have not been fully explored previously. We’ll discuss them in turn here:
+
+### Priority of `__jax_array__` vs. PyTree flattening
+JAX already has a supported mechanism for registering custom objects, namely pytree
+registration (see [Extending pytrees](https://docs.jax.dev/en/latest/pytrees.html#extending-pytrees)).
+If we also support `__jax_array__`, which one should take precedence?
+
+To put this more concretely, what should be the result of this code?
+
+```python
+@jax.jit
+def f(x):
+  print("is JAX array:", isinstance(x, jax.Array))
+
+f(CustomArray(...))
+```
+
+If we choose to prioritize `__jax_array__` at the JIT boundary, then the output of this
+function would be:
+```
+is JAX array: True
+```
+That is, at the JIT boundary, the `CustomArray` object would be converted via
+`__jax_array__`, and its shape and dtype would be used to construct a standard JAX
+tracer for the function.
+
+If we choose to prioritize pytree flattening at the JIT boundary, then the output of
+this function would be:
+```
+is JAX array: False
+```
+That is, at the JIT boundary, the `CustomArray` object is flattened, and then unflattened
+before being passed to the JIT-compiled function for tracing. If `CustomArray` has been
+registered as a pytree, it will generally contain traced arrays as its attributes, and
+when `x` is passed to any JAX API that supports `__jax_array__`, these traced attributes
+will be converted to a single traced array according to the logic specified in the method.
+
+There are deeper consequences here for how other transformations like vmap and grad work
+when encountering custom objects: for example, if we prioritize pytree flattening, vmap
+would operate over the dimensions of the flattened contents of the custom object, while
+if we prioritize `__jax_array__`, vmap would operate over the converted array dimensions.
+
+This also has consequences when it comes to JIT invariance: consider a function like this:
+```python
+def f(x):
+  if isinstance(x, CustomArray):
+    return x.custom_method()
+  else:
+    # do something else
+    ...
+
+result1 = f(x)
+result2 = jax.jit(f)(x)
+```
+If `jit` consumes `x` via pytree flattening, the results should agree for a well-specified
+flattening rule. If `jit` consumes `x` via `__jax_array__`, the results will differ because
+`x` is no longer a `CustomArray` within the JIT-compiled version of the function.
+
+#### Synopsis
+As of JAX v0.6.0, transformations prioritize `__jax_array__` when it is available. This status
+quo can lead to confusion around the lack of JIT invariance, and the current implementation in practice
+leads to subtle bugs in the case of automatic differentiation, where the forward and backward pass
+do not treat inputs consistently.
+
+Because the pytree extensibility mechanism already exists for the case of customizing
+transformations, it seems most straightforward if transformations act only via this
+mechanism: that is, **we propose to remove `__jax_array__` parsing during abstractification.**
+This approach will preserve object identity through transformations, and give the user the
+most possible flexibility. If the user wants to opt in to array conversion semantics, that
+is always possible by explicitly casting their input via `jnp.asarray`, which will trigger the
+`__jax_array__` protocol.
+
+### Which APIs should support `__jax_array__`?
+JAX has a number of different levels of API, from the level of explicit primitive binding
+(e.g. `jax.lax.add_p.bind(x, y)`) to the `jax.lax` APIs (e.g. `jax.lax.add(x, y)`) to the
+`jax.numpy` APIs (e.g. `jax.numpy.add(x, y)`). Which of these API categories should handle
+implicit conversion via `__jax_array__`?
+
+In order to limit the scope of the change and the required testing, I propose that `__jax_array__`
+only be explicitly supported in `jax.numpy` APIs: after all, it is inspired by the `__array__`
+protocol which is supported by the NumPy package.
We could always expand this in the future to +`jax.lax` APIs if needed. + +This is in line with the current state of the package, where `__jax_array__` handling is mainly +within the input validation utilities used by `jax.numpy` APIs. + +## Implementation +With these design choices in mind, we plan to implement this as follows: + +- **Adding runtime support to `jax.numpy`**: This is likely the easiest part, as most + `jax.numpy` functions use a common internal utility (`ensure_arraylike`) to validate + inputs and convert them to array. This utility already supports `__jax_array__`, and + so most jax.numpy APIs are already compliant. +- **Adding test coverage**: To ensure compliance across the APIs, we should add a new + test scaffold that calls every `jax.numpy` API with custom inputs and validates correct + behavior. +- **Deprecating `__jax_array__` during abstractification**: Currently JAX's abstractification + pass, used in `jit` and other transformations, does parse the `__jax_array__` protocol, + and this is not the behavior we want long-term. We need to deprecate this behavior, and + ensure that downstream packages that rely on it can move toward pytree registration or + explicit array conversion where necessary. +- **Adding type annotations**: the type interface for jax.numpy functions is in + `jax/numpy/__init__.pyi`, and we’ll need to change each input type from `ArrayLike` to + `ArrayLike | SupportsJAXArray`, where the latter is a protocol with a `__jax_array__` + method. We cannot add this directly to the `ArrayLike` definition, because `ArrayLike` + is used in contexts where `__jax_array__` should not be supported. +- **Documentation**: once the above support is added, we should add a documentation section + on array extensibility that outlines exactly what to expect regarding the `__jax_array__` + protocol, with examples of how it can be used in conjunction with pytree registration + in order to effectively work with user-defined types. diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 1c4ecbb3411f..2ba85a5f4a8d 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -52,6 +52,7 @@ Then create a pull request that adds a file named 17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose> 18137: Scope of JAX NumPy & SciPy Wrappers <18137-numpy-scipy-scope> 25516: Effort-based versioning <25516-effver> + 28661: Supporting the `__jax_array__` protocol <28661-jax-array-protocol> Several early JEPs were converted in hindsight from other documentation, From 0aef447539f8fc1d2b71bc8afd9dbf35ad98d606 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 9 May 2025 14:48:34 -0700 Subject: [PATCH 1109/1769] Fix typo in FusedAttentionTest PiperOrigin-RevId: 756914581 --- tests/pallas/gpu_ops_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index 3c352afe3382..1b758cdd0a58 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -268,7 +268,7 @@ def test_fused_attention_bwd( use_segment_ids, ): if jtu.is_cuda_compute_capability_equal("8.0") and all([ - block_sizes["block_q"] == 128, + dict(block_sizes)["block_q"] == 128, batch_size == 2, num_heads == 2, head_dim == 128, From 35e2657be8308917c7fa407be5a0b53192134890 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 9 May 2025 19:16:24 -0700 Subject: [PATCH 1110/1769] [Pallas] Allow more int casting tests. 
PiperOrigin-RevId: 756989842 --- tests/pallas/ops_test.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index cf6536df2344..61ebc19e018f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -606,10 +606,17 @@ def test_cast_from_32bit(self, from_dtype, to_dtype, data): if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): - if to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"}: + if to_dtype in {"int2", "uint2"}: self.skipTest("Not supported on this TPU generation") if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18): self.skipTest("Test requires libtpu from 2025/1/18 or later") + if to_dtype in { + "int4", + "uint4", + "int8", + "uint8", + } and not jtu.if_cloud_tpu_at_least(2025, 5, 15): + self.skipTest("Test requires libtpu from 2025/5/15 or later") if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: # Currently only casts between 32-bit types and to bf16 are supported. if to_dtype not in {"int32", "uint32", "float32", "bfloat16"}: @@ -673,18 +680,7 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): if jtu.is_device_tpu(version=4): allowed_v4_cats = {("int16", "int32"): (2025, 1, 18)} if ( - from_dtype - in { - "int16", - "int8", - "uint16", - "uint8", - "int4", - "uint4", - "int2", - "uint2", - } - or to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"} + from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"} ) and (from_dtype, to_dtype) not in allowed_v4_cats: self.skipTest("Not supported on this TPU generation") if minimum_libtpu_date := allowed_v4_cats.get((from_dtype, to_dtype), None): @@ -692,6 +688,12 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): self.skipTest("Test requires a newer libtpu") if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18): self.skipTest("Test requires libtpu from 2025/1/18 or later") + if ( + to_dtype in {"int4", "uint4", "int8", "uint8"} + and from_dtype in {"int4", "uint4", "int8", "uint8"} + and not jtu.if_cloud_tpu_at_least(2025, 5, 15) + ): + self.skipTest("Test requires libtpu from 2025/5/15 or later") if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: self.skipTest("Not supported on this TPU generation") if jtu.test_device_matches(["gpu"]) and ( From 55a9de3c8a689bc48a2772a2d9dc359b0ecca51e Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 10 May 2025 04:01:35 +0000 Subject: [PATCH 1111/1769] [si_vjp] fix bugs around symbolic zeros * fix leaking of internal symbolic zeros in returned cotangents * fix a bug around symbolic zero output tangents --- jax/_src/api.py | 12 +++++++----- tests/api_test.py | 5 +++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index a03933573e5d..154ff5132a39 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2199,7 +2199,8 @@ def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, fun = lu.wrap_init(f, debug_info=dbg) primals_flat, in_tree = tree_flatten(primals) fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primals_flat, _, jaxpr, residuals = ad.linearize(fun, *primals_flat) + out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(fun, *primals_flat) + out_known = [pval.is_known() for pval in out_pvals] primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w)) id_map = {id(x): i for 
i, x in enumerate(primals_filt)} opaque_residuals = [] @@ -2207,7 +2208,7 @@ def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore for r in residuals] f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, - out_tree(), jaxpr), opaque_residuals) + out_tree(), out_known, jaxpr), opaque_residuals) if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) @@ -2232,8 +2233,8 @@ def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, out_primals = tree_unflatten(out_tree(), out_primals_flat) return out_primals, f_vjp -def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, jaxpr, - opaque_residuals, ct, *saved_primals): +def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known, + jaxpr, opaque_residuals, ct, *saved_primals): primals_filtered, filtered_tree_ = tree_flatten(saved_primals) if filtered_tree != filtered_tree_: raise ValueError( @@ -2253,8 +2254,9 @@ def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, jaxpr, dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] cts_flat, out_tree_ = tree_flatten(ct) assert out_tree_ == out_tree + cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k] arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat) - return tree_unflatten(in_tree, arg_cts) + return tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) @dataclasses.dataclass(frozen=True) class RSpec: diff --git a/tests/api_test.py b/tests/api_test.py index 6e55e732151d..15966c678d87 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7323,6 +7323,11 @@ def f2(x, w): self.assertAllClose(x_grad, 2. * y_grad @ w.T) self.assertAllClose(w_grad, 2. 
* x.T @ y_grad) + def test_doesnt_leak_symbolic_zeros(self): + _, vjp = api.si_vjp(lambda x: 1., [False], 3.14) + ans, = vjp(1.0) + self.assertIsInstance(ans, jax.Array) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 254d64e3f8e884b0fa3a21a547f1438a66414967 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 9 May 2025 17:19:27 +0000 Subject: [PATCH 1112/1769] [shard-map] start adding systematic smap tests --- tests/shard_map_test.py | 88 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 4d3b265bd869..5b417bb4b87a 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3649,6 +3649,94 @@ def f(x): self.assertAllClose(f(jnp.arange(8.)), jnp.array([1., 5., 9., 13.])) +def smap_ref(f, in_axes, out_axes, axis_name, axis_size): + del axis_name # no collectives + def smapped(*args): + split_args = zip(*[split_arg(x, d, axis_size) for x, d in zip(args, in_axes)]) + split_result = [f(*xs) for xs in split_args] + return concat_result(split_result, out_axes) + return smapped + +def split_arg(x, d, axis_size): + if d is None: + x = np.tile(x, [axis_size] + [1] * (x.ndim - 1)) + return np.split(x, axis_size, d or 0) + +def concat_result(results, out_axes): + if not isinstance(results[0], (list, tuple)): + return results[0] if out_axes is None else np.concatenate(results, out_axes) + return [res[0] if d is None else np.concatenate(res, d) + for res, d in zip(zip(*results), out_axes)] + +def sample_smap() -> Chooser: + spec = yield fun_specs + mesh_shape = yield mesh_shapes + axis_names = ('i', 'j', 'k', 'l')[:len(mesh_shape)] + mesh = SimpleNamespace(shape=dict(zip(axis_names, mesh_shape)), + axis_names=axis_names) + axis_name = yield axis_names + body_in_types = yield (tys for tys in it.product(input_shapes, repeat=spec.num_inputs) + if not spec.valid_types or spec.valid_types(*tys)) + in_axes = yield from sample_in_axes(body_in_types) + out_rep = spec.out_rep(*[ax is None for ax in in_axes]) + body_out_type = jax.eval_shape(spec.fun, *body_in_types) + out_axes = yield from sample_out_axes(out_rep, body_out_type) + in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short() + for t in body_in_types) + ')' + name = f'{spec.name}_{mesh.shape}_{in_axes}_{out_axes}_{axis_name}_{in_str}' + in_types = [ty.update(shape=dilate_axis(ty.shape, d, mesh.shape[axis_name])) + for ty, d in zip(body_in_types, in_axes)] + args = [np.arange(ty.size, dtype=ty.dtype).reshape(ty.shape) / ty.size + for ty in in_types] + return name, spec, mesh.shape, in_axes, out_axes, axis_name, args + +def sample_in_axes(body_in_types) -> Chooser: + in_axes = [] + for ty in body_in_types: + in_axes.append((yield [None, *range(ty.ndim)])) + return tuple(in_axes) + +def sample_out_axes(out_rep, body_out_type) -> Chooser: + if not isinstance(body_out_type, (list, tuple)): + out_axes = yield [None] * out_rep + list(range(body_out_type.ndim)) + else: + out_axes_ = [] + for ty, r in zip(body_out_type, out_rep): + out_axes_.append((yield [None] * r + list(range(ty.ndim)))) + out_axes = tuple(out_axes_) + return out_axes + +def dilate_axis(shape: tuple[int, ...], i: int | None, size: int) -> tuple[int, ...]: + if i is None: + return shape + shp = list(shape) + shp[i] *= size + return tuple(shp) + +class SmapSystematicTest(jtu.JaxTestCase): + + @staticmethod + def make_mesh(mesh_shape): + return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) + + 
@parameterized.parameters( + sample(jtu.NUM_GENERATED_CASES.value, sample_smap)) + def test_against_ref(self, fun_spec, mesh_shape, in_axes, out_axes, axis_name, args): + fun = fun_spec.fun + mesh = self.make_mesh(mesh_shape) + args = map(jnp.array, args) + + with jax.sharding.use_mesh(mesh): + fun_ = smap(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) + out = jax.jit(fun_)(*args) + + fun_ref = smap_ref(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, + axis_size=mesh_shape[axis_name]) + expected = fun_ref(*args) + + self.assertAllClose(out, expected, check_dtypes=False) + + @jtu.with_config(jax_use_shardy_partitioner=True) # TODO(phawkins): enable this test unconditionally once shardy is the default. @unittest.skipIf(sdy is None, "shardy is not enabled") From 6ec18d4987929154babde1837dfb2bd2728205d2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 10 May 2025 06:12:39 -0700 Subject: [PATCH 1113/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/80924f3d144737d14758d8a92b236d90c8ec8cb9. PiperOrigin-RevId: 757132575 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e0a7fc8db91a..30550618d568 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0ca232a612deb80c64bc0e6a55f3b9bbd198b27f" -XLA_SHA256 = "3605a12ccb161443e3893f46e0cd05f96b271692d277185a258781989ea44c29" +XLA_COMMIT = "80924f3d144737d14758d8a92b236d90c8ec8cb9" +XLA_SHA256 = "6bff089f5ff767a31d6e8f4de95738340b7325a715624aaedd7f5ccd06754a9d" def repo(): tf_http_archive( From 2016e59607cbcde8efda3ff90571c84b11f9c4bb Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 10 May 2025 09:50:56 -0700 Subject: [PATCH 1114/1769] Automated Code Change PiperOrigin-RevId: 757170420 --- jaxlib/sharding.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/sharding.cc b/jaxlib/sharding.cc index fa19e1434a90..0514946a729b 100644 --- a/jaxlib/sharding.cc +++ b/jaxlib/sharding.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -29,7 +30,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep From b4eb48cd74179d29351b51b24d052cac579f1cf7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 11 May 2025 05:56:23 -0700 Subject: [PATCH 1115/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/633c9abd097a2cf20884d29da51cc53b6e7144b5. PiperOrigin-RevId: 757401878 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 30550618d568..5b1f2e5db99c 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "80924f3d144737d14758d8a92b236d90c8ec8cb9" -XLA_SHA256 = "6bff089f5ff767a31d6e8f4de95738340b7325a715624aaedd7f5ccd06754a9d" +XLA_COMMIT = "633c9abd097a2cf20884d29da51cc53b6e7144b5" +XLA_SHA256 = "2c8280edf20af0c16bf952373ce44db2fe42d701b8bcd9b8487484860b15afbd" def repo(): tf_http_archive( From caf10dfc4503b98ea815b5c44c08d75f32fa83c7 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Sun, 11 May 2025 23:15:12 -0700 Subject: [PATCH 1116/1769] [Mosaic GPU] Add support for 8-bit MMA on Blackwell PiperOrigin-RevId: 757608204 --- jax/experimental/mosaic/gpu/launch_context.py | 25 ++++++-- jax/experimental/mosaic/gpu/tcgen05.py | 57 +++++++++++++++---- tests/mosaic/gpu_test.py | 10 +++- 3 files changed, 71 insertions(+), 21 deletions(-) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index d169c448a80e..02ed3859d8c2 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -456,12 +456,19 @@ def init_tma_desc(host_ptr): tma_dtype = 3 elif bitwidth == 64: tma_dtype = 4 + else: + raise ValueError(f"Unsupported integer bitwidth: {bitwidth}") elif ir.F16Type.isinstance(ref_ty.element_type): tma_dtype = 5 elif ir.F32Type.isinstance(ref_ty.element_type): tma_dtype = 6 elif ir.BF16Type.isinstance(ref_ty.element_type): tma_dtype = 7 + # We treat 8 bit floats as 8 bit integers + elif ir.Float8E5M2Type.isinstance(ref_ty.element_type): + tma_dtype = 1 + elif ir.Float8E4M3Type.isinstance(ref_ty.element_type): + tma_dtype = 1 else: raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}") dtype_or_bitwidth = c(tma_dtype, i64) @@ -584,12 +591,18 @@ def async_copy( " multiple of 16 bytes" ) - if reduction_op is not None and jaxlib.version < (0, 5, 4): - raise ValueError("TMA with reduction is only supported with jaxlib >= 0.5.4") - if reduction_op is not None and not isinstance(gmem_ref_ty.element_type, ir.FloatType): - raise ValueError("TMA with reduction is only supported with float dtype") - if reduction_op is not None and reduction_op != "add": - raise ValueError("TMA with reduction is only supported with add operation") + if reduction_op is not None: + if not any( + t.isinstance(gmem_ref_ty.element_type) + for t in (ir.F32Type, ir.BF16Type, ir.F16Type) + ): + raise ValueError( + "TMA with reduction is only supported with f32, f16 and bf16" + ) + if reduction_op != "add": + raise ValueError( + "TMA with reduction is only supported with add operation" + ) # NOTE: TMA supports OOB indices, so we skip the check. 
base_indices, slice_shape, is_squeezed = utils.parse_indices( diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index f4ea8e289f01..730761cb7eff 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -56,17 +56,42 @@ def create_instr_descriptor( f32 = ir.F32Type.get() bf16 = ir.BF16Type.get() f16 = ir.F16Type.get() - if input_dtype not in {f16, bf16}: - raise NotImplementedError("Only float16 and bfloat16 inputs supported") if acc_dtype not in {f32, f16}: raise NotImplementedError("Only float32 and float16 accumulators supported") + if utils.bitwidth(input_dtype) == 16: + if input_dtype not in {f16, bf16}: + raise NotImplementedError( + "The only supported 16-bit input types are float16 and bfloat16, got" + f" {input_dtype}" + ) + desc = 0 + desc |= (acc_dtype == f32) << 4 # D dtype, bits 4-5 + # Bit 6 is reserved + desc |= (input_dtype == bf16) << 7 # A dtype, bits 7-9 + desc |= (input_dtype == bf16) << 10 # B dtype, bits 10-12 + return _finish_instr_descriptor(desc, m, n, transpose_a, transpose_b) + elif utils.bitwidth(input_dtype) == 8: + desc = 0 + desc |= (acc_dtype == f32) << 4 # D dtype, bits 4-5 + # Bit 6 is reserved + if input_dtype == ir.Float8E4M3Type.get(): + input_dtype_enum = 0 + elif input_dtype == ir.Float8E5M2Type.get(): + input_dtype_enum = 1 + else: + raise NotImplementedError(f"Unsupported input dtype: {input_dtype}") + desc |= input_dtype_enum << 7 # A dtype, bits 7-9 + desc |= input_dtype_enum << 10 # B dtype, bits 10-12 + return _finish_instr_descriptor(desc, m, n, transpose_a, transpose_b) + else: + raise NotImplementedError(f"Unsupported input dtype: {input_dtype}") - desc = 0 + +def _finish_instr_descriptor( + desc: int, m: int, n: int, transpose_a: bool, transpose_b: bool, +): # We ignore sparsity in bits 0-3 - desc |= (acc_dtype == f32) << 4 # D dtype, bits 4-5 - # Bit 6 is reserved - desc |= (input_dtype == bf16) << 7 # A dtype, bits 7-9 - desc |= (input_dtype == bf16) << 10 # B dtype, bits 10-12 + # A, B and D types are set by the caller # We ignore negate bits 13-14 desc |= transpose_a << 15 # Transpose A desc |= transpose_b << 16 # Transpose B @@ -139,20 +164,24 @@ def mma( f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}" ) f32 = ir.F32Type.get() + f16 = ir.F16Type.get() if element_type == f32 or element_type == ir.BF16Type.get(): if d.dtype != f32: raise ValueError( f"MMA with element type {element_type} only supports accumulators" f" of type f32, but got: {d.dtype}" ) - elif element_type == ir.F16Type.get(): - if d.dtype != element_type and d.dtype != f32: + elif any( + t.isinstance(element_type) + for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3Type} + ): + if d.dtype != f16 and d.dtype != f32: raise ValueError( - "MMA with element type f16 only supports accumulators of type f32" - f" or f16, but got: {d.dtype}" + f"MMA with element type {element_type} only supports accumulators of" + f" type f32 or f16, but got: {d.dtype}" ) else: - raise NotImplementedError(f"Unsupported element type: {element_type}") + raise NotImplementedError(f"Unsupported element type: {element_type}", type(element_type)) # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, # instructions must be issued in groups of the same width as the swizzle. 
@@ -268,6 +297,10 @@ def _do_mma( if ir.F16Type.isinstance(element_type) or ir.BF16Type.isinstance(element_type): kind = "f16" + elif ir.Float8E5M2Type.isinstance(element_type): + kind = "f8f6f4" + elif ir.Float8E4M3Type.isinstance(element_type): + kind = "f8f6f4" else: raise NotImplementedError(f"Unsupported input element type: {element_type}") diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 8c26f64bd203..b9350b0c995b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -983,13 +983,15 @@ def kernel(ctx, input, output, scratch): @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 + in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3), # TODO(apaszke): f32 out_jax_dtype=(jnp.float16, jnp.float32,), m=(128,), # TODO(apaszke): 64, 192, 256 n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 swizzle=(32, 64, 128,), ) def test_mma_basic(self, **kwargs): + if kwargs["n"] * jnp.dtype(kwargs["in_jax_dtype"]).itemsize < kwargs["swizzle"]: + self.skipTest("swizzle too large for input") self._basic_mma_test( **kwargs, k_steps=2, # Reducing to 1 can be helpful while debugging. @@ -1029,8 +1031,10 @@ def _basic_mma_test( rhs_transpose_tiles, lhs_transpose_tiles, ): - if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: - self.skipTest("Only f16 input is supported for f16 output.") + if out_jax_dtype != jnp.float32 and ( + in_jax_dtype == jnp.float32 or in_jax_dtype == jnp.bfloat16 + ): + self.skipTest("Only f32 output is supported for f32 and bf16 input.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) swizzle_elems = swizzle // bytewidth(in_mlir_dtype) From 863f762de1644d9fc341f098d34e990d71444f85 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 12 May 2025 00:26:14 -0700 Subject: [PATCH 1117/1769] [Pallas:MGPU] Update one more lowering rule to the load/store rename in TMEMRef PiperOrigin-RevId: 757628207 --- jax/_src/pallas/mosaic_gpu/lowering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index d50e39d5c3db..2ef504e518f6 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1356,8 +1356,8 @@ def _swap_lowering_rule( case _: raise NotImplementedError( "Only a single indexing transform is supported for TMEM refs.") - old_value = x_ref[:] - x_ref[:] = value + old_value = x_ref.load(layout=value.layout) + x_ref.store(value) return old_value if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): From f6bce253fc6729556a636540659d338729c8711d Mon Sep 17 00:00:00 2001 From: Aleksei Rechinskii Date: Mon, 12 May 2025 08:13:28 +0000 Subject: [PATCH 1118/1769] Fix debug rule in .bazelrc For recursive config definitions Bazel requires use of a single token notation `--config=value` --- .bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index 4780b0ba37c9..53676637c839 100644 --- a/.bazelrc +++ b/.bazelrc @@ -418,7 +418,7 @@ build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base ############################################################################# build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3" -build:debug --config debug_symbols -c fastbuild +build:debug --config=debug_symbols -c fastbuild # Load `.jax_configure.bazelrc` file written by build.py try-import 
%workspace%/.jax_configure.bazelrc

From eb2fe9715d849d6b3d63d5345b96b6865c68b3e2 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 12 May 2025 01:37:36 -0700
Subject: [PATCH 1119/1769] [Mosaic GPU] Add an additional WG barrier before
 copy_gmem_to_smem

This is necessary to ensure that all SMEM reads issued from the current WG
have completed before we schedule the copy (that acts as an SMEM write)!

PiperOrigin-RevId: 757647993
---
 jax/_src/pallas/mosaic_gpu/primitives.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 40eccca7c711..f199b7b245c6 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -449,9 +449,6 @@ def _copy_gmem_to_smem_pp_eqn(

 @lowering.register_lowering_rule(
     copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane)
-@lowering.register_lowering_rule(
-    copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane,
-    primitive_semantics=gpu_core.PrimitiveSemantics.Warp)
 @lowering.register_lowering_rule(
     copy_gmem_to_smem_p, mgpu.LoweringSemantics.Warpgroup
 )
@@ -465,6 +462,7 @@ def _copy_gmem_to_smem_lowering(
     dst_transforms_treedef,
     barrier_transforms_treedef,
     collective_axes,
+    warpgroup_sync: bool = True,
 ):
   flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
       util.split_list(
@@ -509,6 +507,8 @@ def _copy_gmem_to_smem_lowering(
     # arrive with the whole transfer size, while everyone else arrives with 0.
     # But we should continue using this scheme as it's likely to be faster.
     bytes //= WARPGROUP_SIZE
+  if warpgroup_sync:
+    mgpu.warpgroup_barrier()  # Make sure all reads have completed.
   barrier.arrive_expect_tx(bytes)
   ctx.launch_ctx.async_copy(
       src_ref=src,
@@ -541,6 +541,12 @@ def _copy_gmem_to_smem_lowering(
   )
   return ()

+lowering.register_lowering_rule(
+    copy_gmem_to_smem_p,
+    mgpu.LoweringSemantics.Lane,
+    primitive_semantics=gpu_core.PrimitiveSemantics.Warp,
+)(functools.partial(_copy_gmem_to_smem_lowering, warpgroup_sync=False))
+

 def copy_gmem_to_smem(
     src: _Ref,

From 0c5109027947163b937a3a0df3e95e8d1d6197a3 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 12 May 2025 05:56:56 -0700
Subject: [PATCH 1120/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/6ad6ae3dafa9868708e54de10e3aeafb081a71f2.

PiperOrigin-RevId: 757728274
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 5b1f2e5db99c..94fbee2d8b83 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "633c9abd097a2cf20884d29da51cc53b6e7144b5"
-XLA_SHA256 = "2c8280edf20af0c16bf952373ce44db2fe42d701b8bcd9b8487484860b15afbd"
+XLA_COMMIT = "6ad6ae3dafa9868708e54de10e3aeafb081a71f2"
+XLA_SHA256 = "405c2787b2fa2a467f4b2179cbfb2fd25f282f55ff43b0571ac20e4e56b8c26c"

 def repo():
     tf_http_archive(

From bd8765d3832ced4cc5f0482ef65cb19a16d22dad Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 12 May 2025 06:15:01 -0700
Subject: [PATCH 1121/1769] Add collective_axes to run_scoped

Our current allocation scheme on GPU is unsafe in the presence of multiple
threads that might take diverging control paths. We work around this problem
using our favorite trick and simply forbid this!
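A minimal usage sketch of what this looks like in a kernel (hedged: it assumes
the usual `pl`/`plgpu` import aliases and a kernel with two Pallas threads
along a "wg" axis; the exact semantics are spelled out below):

    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def body(x_ref, o_ref):
      def scoped(smem_ref):
        ...  # every thread along "wg" receives the same SMEM buffer
      pl.run_scoped(scoped, plgpu.SMEM((128,), jnp.float32),
                    collective_axes="wg")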
With this change, `run_scoped(..., collective_axes="wg")` means that the same allocation will be returned in all programs that only differ in the `wg` axis. What's more, this call is a user promise that the allocation is a collective that will be executed by all threads along that axis. Only executing it on a subset is undefined behavior and in our current Mosaic GPU implementation might lead to deadlocks due to barriers. Note that nothing changes for single-threaded kernels, where run_scoped is always allowed. PiperOrigin-RevId: 757734362 --- jax/_src/pallas/hlo_interpreter.py | 10 +++- jax/_src/pallas/mosaic/interpret.py | 4 ++ jax/_src/pallas/mosaic/lowering.py | 4 +- jax/_src/pallas/mosaic_gpu/core.py | 5 +- jax/_src/pallas/mosaic_gpu/lowering.py | 48 +++++++++++++++++-- jax/_src/pallas/mosaic_gpu/pipeline.py | 1 + jax/_src/pallas/primitives.py | 38 ++++++++++++--- .../pallas/ops/gpu/attention_mgpu.py | 2 + 8 files changed, 97 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index f3d2c46ad9a9..755df2cd8ceb 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -312,9 +312,15 @@ def rule(interpreter, *args, **params): lax.while_p, 'body_jaxpr', 'cond_jaxpr') _eval_jaxpr_hop_rules[lax.cond_p] = make_hop_rule(lax.cond_p, 'branches') def _run_scoped_physicalize_rule( - interpreter, *consts, jaxpr: jax_core.Jaxpr): + interpreter, *consts, jaxpr: jax_core.Jaxpr, collective_axes): + if collective_axes: + raise NotImplementedError( + "run_scoped interpret rule does not support collective axes" + ) physical_jaxpr, physical_consts = interpreter(jaxpr, consts) - return primitives.run_scoped_p.bind(*physical_consts, jaxpr=physical_jaxpr) + return primitives.run_scoped_p.bind( + *physical_consts, jaxpr=physical_jaxpr, collective_axes=collective_axes + ) _eval_jaxpr_hop_rules[primitives.run_scoped_p] = _run_scoped_physicalize_rule diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index f7160f5af386..c0e52f54e6f3 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1169,6 +1169,10 @@ def f(*args, jaxpr): out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) elif prim is primitives.run_scoped_p: + if eqn.params['collective_axes']: + raise NotImplementedError( + 'run_scoped_p with collective axes is not supported' + ) # Allocate a buffer or semaphore for each element of # eqn.params['jaxpr'].invars . 
allocs = [] diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index cdc64bbf96f4..b2a1f356ad0b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3347,7 +3347,9 @@ def _alloc_value( @register_lowering_rule(primitives.run_scoped_p) -def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): +def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr, collective_axes): + if collective_axes: + raise NotImplementedError("run_scoped lowering does not support collective axes") out_type = [ aval_to_ir_type(ctx.lowering_context.dynamic_shape_replacement_fn, aval) for aval in ctx.avals_out diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 6be8b3c4a8a5..21c78720e812 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -182,13 +182,16 @@ def kernel( def wrapper(*operands): def stateful(operand_and_out_refs): operand_refs, out_refs = operand_and_out_refs + mesh = GPUMesh(**mesh_kwargs) + thread_name = mesh.thread_name if mesh.thread_name is not None else () def cmap_body(): pallas_primitives.run_scoped( lambda *scratch_refs: body(*operand_refs, *out_refs, *scratch_refs), *scratch_shapes, + collective_axes=thread_name, ) pallas_core.core_map( - GPUMesh(**mesh_kwargs), compiler_params=compiler_params + mesh, compiler_params=compiler_params )(cmap_body) _, outs = state_discharge.run_state(stateful)( (operands, jax.tree.map(jnp.zeros_like, out_shape)) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 2ef504e518f6..b501693bf627 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -215,8 +215,11 @@ def _while_resource_estimator( @_register_resource_estimator(primitives.run_scoped_p) def _run_scoped_resource_estimator( - ctx: ResourceEstimatorContext, *consts, jaxpr: jax_core.Jaxpr + ctx: ResourceEstimatorContext, *consts, jaxpr: jax_core.Jaxpr, collective_axes ) -> int: + # NOTE: This rule assumes that the allocation happens collectively, although + # it can't be checked here due to limited context. We check this in the actual + # lowering rule. del consts # Unused. rs = Resources() for v in jaxpr.invars: @@ -298,7 +301,7 @@ def __iter__(self) -> Iterable[Hashable]: @dataclasses.dataclass class ModuleContext: name: str - axis_names: _AxisNames | None + axis_names: _AxisNames program_ids: Sequence[ir.Value] | None approx_math: bool single_wg_lane_predicate: ir.Value | None @@ -602,9 +605,13 @@ def lower_pipelined_jaxpr_to_module( assert isinstance(gpu_mesh, gpu_core.GPUMesh) block = (128 * (gpu_mesh.num_threads or 1), 1, 1) grid = gpu_mesh.grid + thread_axis = ( + gpu_mesh.thread_name if gpu_mesh.thread_name is not None else () + ) else: block = (128, 1, 1) grid = grid_mapping.grid + thread_axis = () if params.dimension_semantics is None: which_parallel = [True] * len(grid) @@ -659,6 +666,7 @@ def pipeline_fn(*refs): ref_for_aval(aval) if aval is not sem_placeholder else aval for aval in scratch_avals ], + collective_axes=thread_axis, # scratch_refs are shared across threads ) return () # ``wrap_init`` does not support functions returning None. 
@@ -1937,6 +1945,14 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): case mgpu.WGStridedFragLayout(): if set(axes) != set(range(x_aval.ndim)): raise NotImplementedError("No support for axes yet") + # To relax the restriction below, you need to ensure sufficient + # synchronization with other places that use `scratch_view` (which at the + # time of writing is only `run_scoped`). + if ctx.module_ctx.axis_names.wg is not None: + raise NotImplementedError( + "No support for reduce_sum over all axes and multiple Pallas" + " threads" + ) scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype) with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]: return x.reduce("add", axes, scratch) @@ -2178,14 +2194,28 @@ def _debug_print_lowering_rule_wg( @register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Warpgroup) def _run_scoped_lowering_rule( - ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr + ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr, collective_axes ): input_refs = [] should_discharge = [] + wg_axis = ctx.module_ctx.axis_names.wg + is_multithreaded = wg_axis is not None + is_thread_collective = is_multithreaded and collective_axes == (wg_axis,) + # Make sure everyone has exited previous scoped allocations. Note that we + # don't synchronize when we exit the allocation, but only when we might want + # to reuse its memory again. + if is_multithreaded and is_thread_collective: + gpu_dialect.barrier() with contextlib.ExitStack() as alloc_stack: for v in jaxpr.invars: aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + if collective_axes: + raise ValueError( + "WGMMA accumulators can only be allocated non-collectively. Hint:" + " remove collective_axes from run_scoped. If other allocations" + " are performed as well, split the run_scoped into two." + ) dtype = mlir.dtype_to_ir_type(aval.dtype) if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) @@ -2196,7 +2226,17 @@ def _run_scoped_lowering_rule( nvvm_dialect.wgmma_fence_aligned() input_refs.append(acc) should_discharge.append(True) - elif isinstance(aval.dtype, gpu_core.BarrierType): + continue + # All other allocations must be made collectively across all threads. + if is_multithreaded and not is_thread_collective: + raise NotImplementedError( + "Only thread-collective allocations are supported in multithreaded" + " kernels. Hint: add" + f" collective_axes={ctx.module_ctx.axis_names.wg} to your" + " run_scoped if you intend all threads to share the same" + f" allocation (currently collective_axes={collective_axes})." 
+ ) + if isinstance(aval.dtype, gpu_core.BarrierType): multiplier = (1 if aval.dtype.for_tensor_core else ctx.estimator_ctx.arrival_multiplier) barrier_ref = alloc_stack.enter_context( diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 426f314bc3a1..db5fb4fb316a 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -558,6 +558,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): out_smem_refs=out_smem_refs, in_smem_barrier_refs=in_smem_barriers, consumed_barrier_refs=consumed_barriers, + collective_axes=wg_axis, ) def scoped_pipeline( diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 986a62571010..5038ac6e5171 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -19,6 +19,7 @@ import enum import functools import string +from collections.abc import Hashable from typing import Any, Callable import jax @@ -878,13 +879,25 @@ def wrap_with_transforms(f, transforms, *args): run_scoped_p.multiple_results = True -def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: +def run_scoped( + f: Callable[..., Any], + *types: Any, + collective_axes: Hashable | tuple[Hashable, ...] = (), + **kw_types: Any, +) -> Any: """Calls the function with allocated references and returns the result. The positional and keyword arguments describe which reference types to allocate for each argument. Each backend has its own set of reference types in addition to :class:`jax.experimental.pallas.MemoryRef`. + + When `collective_axes` is specified, the same allocation will be returned for + all programs that only differ in their program ids along the collective axes. + It is an error not to call the same `run_scoped` in all programs along that + axis. """ + if not isinstance(collective_axes, tuple): + collective_axes = (collective_axes,) flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, @@ -908,13 +921,13 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: # are not in the invars of an operation so we just put them all # there. jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals) - out = run_scoped_p.bind(*consts, jaxpr=jaxpr) + out = run_scoped_p.bind(*consts, jaxpr=jaxpr, collective_axes=collective_axes) return tree_util.tree_unflatten(out_tree_thunk(), out) @run_scoped_p.def_effectful_abstract_eval -def _run_scoped_abstract_eval(*args, jaxpr): - del args +def _run_scoped_abstract_eval(*args, jaxpr, collective_axes): + del args, collective_axes # jaxpr will have effects for its inputs (Refs that are allocated) and for # constvars (closed over Refs). The effects for the allocated Refs are local # to the jaxpr and shouldn't propagate out. @@ -935,8 +948,12 @@ def _run_scoped_discharge_rule( out_avals, *args_flat, jaxpr, - **_): + collective_axes): del out_avals + if collective_axes: + raise NotImplementedError( + "run_scoped discharge does not support collective_axes yet." + ) num_consts = len(args_flat) # discharge_state only discharges invars, not consts, so in order to # discharge the requested refs we need to move them to the invar set. @@ -956,7 +973,9 @@ def _run_scoped_discharge_rule( # Run_scoped discharged the external variables but the scoped ones # are not discharged. 
- out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body) + out = run_scoped_p.bind( + *args_flat, jaxpr=discharged_body, collective_axes=collective_axes + ) # Order of outputs: # (1) return values, (2) closed refs, (3) scoped refs. return_values = out[:num_return_values] @@ -975,7 +994,12 @@ def _run_scoped_discharge_rule( @functools.partial(mlir.register_lowering, run_scoped_p) -def _run_scoped_lowering_rule(ctx, *args, jaxpr): +def _run_scoped_lowering_rule(ctx, *args, jaxpr, collective_axes): + if collective_axes: + raise ValueError( + "run_scoped lowering outside of Pallas does not support" + " collective_axes." + ) jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr) num_return_values = len(jaxpr_noconst.outvars) discharged_body, new_consts = state_discharge.discharge_state( diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index e7a9898bf9f3..3256953cd332 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -245,6 +245,7 @@ def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): ), (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2, plgpu.Barrier(num_arrivals=compute_wgs), + collective_axes="wg", ) num_q_tiles, rem = divmod(q_seq_len, block_q * 2) @@ -740,6 +741,7 @@ def _kernel_entry(): scratch, plgpu.Barrier(1, num_barriers=compute_wgs), plgpu.Barrier(num_arrivals=compute_wgs), + collective_axes="wg", ) @jax.jit def run_function(q, k, v, o, lse): From 59992087a68dfeb67fea4b8c285ee34c65bab7cd Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 May 2025 06:20:24 -0700 Subject: [PATCH 1122/1769] Call block_until_ready for testAutodiffCache PiperOrigin-RevId: 757735827 --- tests/pjit_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 55c38f6c2ef8..22a4d4f70f8c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -671,7 +671,7 @@ def testAutodiffCache(self): jax.grad(f)(x) # Warm up the cache. with jtu.count_pjit_cpp_cache_miss() as count: - jax.grad(f)(x) + jax.block_until_ready(jax.grad(f)(x)) self.assertEqual(count(), 0) # no cache miss i.e. cache hit @jtu.with_mesh([('x', 2), ('y', 1)]) From f5c63053eb2ee59edfd669feb95cf81626b4cd48 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 12 May 2025 06:39:57 -0700 Subject: [PATCH 1123/1769] [Mosaic GPU] Use f8e4m3fn in place of f8e4m3 PTX docs are a bit confusing because the type is called e4m3, but [its description](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) indicates that it is actually e4m3fn (no infs, limited NaNs). 
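For illustration, the difference is observable via the ml_dtypes package
that JAX depends on (a sketch; the printed values follow the ml_dtypes
definitions of these formats):

```
import ml_dtypes

# float8_e4m3fn ("fn" = finite + NaN only): no infinities, and only one
# NaN bit pattern per sign; the reclaimed encodings extend the max to 448.
print(ml_dtypes.finfo(ml_dtypes.float8_e4m3fn).max)  # 448.0

# float8_e5m2, by contrast, keeps IEEE-style infinities and NaNs.
print(ml_dtypes.finfo(ml_dtypes.float8_e5m2).max)  # 57344.0
```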
PiperOrigin-RevId: 757741649 --- jax/experimental/mosaic/gpu/launch_context.py | 2 +- jax/experimental/mosaic/gpu/tcgen05.py | 6 +++--- tests/mosaic/gpu_test.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 02ed3859d8c2..eccb363f7537 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -467,7 +467,7 @@ def init_tma_desc(host_ptr): # We treat 8 bit floats as 8 bit integers elif ir.Float8E5M2Type.isinstance(ref_ty.element_type): tma_dtype = 1 - elif ir.Float8E4M3Type.isinstance(ref_ty.element_type): + elif ir.Float8E4M3FNType.isinstance(ref_ty.element_type): tma_dtype = 1 else: raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}") diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 730761cb7eff..c46e24b9ada2 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -74,7 +74,7 @@ def create_instr_descriptor( desc = 0 desc |= (acc_dtype == f32) << 4 # D dtype, bits 4-5 # Bit 6 is reserved - if input_dtype == ir.Float8E4M3Type.get(): + if input_dtype == ir.Float8E4M3FNType.get(): input_dtype_enum = 0 elif input_dtype == ir.Float8E5M2Type.get(): input_dtype_enum = 1 @@ -173,7 +173,7 @@ def mma( ) elif any( t.isinstance(element_type) - for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3Type} + for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType} ): if d.dtype != f16 and d.dtype != f32: raise ValueError( @@ -299,7 +299,7 @@ def _do_mma( kind = "f16" elif ir.Float8E5M2Type.isinstance(element_type): kind = "f8f6f4" - elif ir.Float8E4M3Type.isinstance(element_type): + elif ir.Float8E4M3FNType.isinstance(element_type): kind = "f8f6f4" else: raise NotImplementedError(f"Unsupported input element type: {element_type}") diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index b9350b0c995b..e02acf8cce13 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -983,7 +983,7 @@ def kernel(ctx, input, output, scratch): @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3), # TODO(apaszke): f32 + in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn), # TODO(apaszke): f32 out_jax_dtype=(jnp.float16, jnp.float32,), m=(128,), # TODO(apaszke): 64, 192, 256 n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 From d99778cfe61b256479d3102ffa4a667e6f97a815 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 12 May 2025 06:56:21 -0700 Subject: [PATCH 1124/1769] [Mosaic GPU] Add support for TMEM loads/stores with the 32x32b shape This should be useful for kernels such as FlashAttention since row-wise reductions can be performed entirely without any communication with other threads. PiperOrigin-RevId: 757746207 --- jax/experimental/mosaic/gpu/tcgen05.py | 274 +++++++++++++++++++------ jax/experimental/mosaic/gpu/utils.py | 1 + tests/mosaic/gpu_test.py | 27 ++- 3 files changed, 234 insertions(+), 68 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index c46e24b9ada2..89a58e4788f5 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -43,6 +43,20 @@ lane_dims=(-4, -3), vector_dim=-1, ) +# A layout resembling the logical organization of TMEM. 
The 128 rows in a tile +# are assigned to 128 lanes in the warpgroup. Useful when the result needs to be +# processed in registers and then stored back into TMEM. Should not be used if +# the result is to be written back to SMEM, as there is no good way to store it +# without bank conflicts. +# +# We use a vector_dim of 2, to be able to make sure that the vectors are always +# a multiple of 32-bits, even when the data is 16-bits. +TMEM_NATIVE_LAYOUT = fa.TiledLayout( + fa.Tiling(((128, 2), (32, 2))), + warp_dim=-4, + lane_dims=(-2,), + vector_dim=-1, +) def create_instr_descriptor( @@ -428,6 +442,8 @@ def _tmem_access_helper(shape, num): if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: + case "32x32b": + num_regs = 1 case "16x128b": num_regs = 2 case "16x256b": @@ -657,43 +673,60 @@ def load(self, layout: fa.TiledLayout = LAYOUT): raise NotImplementedError if utils.bitwidth(self.dtype) not in {16, 32}: raise NotImplementedError(f"Unsupported dtype: {self.dtype}") - if layout != LAYOUT: + if layout == LAYOUT: + regs_shape = layout.registers_shape(self.shape) + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if ( + r == TMEM_ROWS + ): + # load_32xcols returns a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + registers = _load_32xcols( + self.address, self.shape[1], self.dtype, packing + ).T.reshape(regs_shape) + case TMEMLayout(elements_in_tile=(r, 128), column_tile_stride=2) if r == TMEM_ROWS: + if self.shape[1] % 128 != 0: + raise ValueError( + f"TMEM layout {self.layout} is not compatible with shape {self.shape}" + ) + num_column_tiles = self.shape[1] // 128 + column_tile_stride = self.layout.column_tile_stride + num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) + tiles = [] + for col_tile_base in range(num_strided_col_groups): + for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): + tiles.append( + _load_32xcols( + arith.addi(self.address, arith.constant(i32, col_tile * 128)), + cols=128, + dtype=self.dtype, + tmem_packing=1, + ) + ) + registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) + case _: + raise NotImplementedError( + f"Loads only implemented for refs with standard layout, got: {self.layout}" + ) + elif layout == TMEM_NATIVE_LAYOUT: + regs_shape = layout.registers_shape(self.shape) + match self.layout: + case TMEMLayout(elements_in_tile=(r, c), packing=packing) if ( + r == TMEM_ROWS and c % 2 == 0 + ): + registers = _load_32xcols_native( + self.address, self.shape[1], self.dtype, packing + ).reshape(regs_shape) + case _: + raise NotImplementedError( + "Loads only implemented for refs with standard layout, got:" + f" {self.layout}" + ) + else: raise ValueError( - "TMEM loads can only produce results in the tcgen05 layout" - f" ({LAYOUT}), but got: {layout}" + "TMEM loads can only produce results in the tcgen05 layouts" + f" ({LAYOUT} and {TMEM_NATIVE_LAYOUT}), but got: {layout}" ) - regs_shape = layout.registers_shape(self.shape) - match self.layout: - case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if r == TMEM_ROWS: - # load_32xcols returns a 4xN array, but the FA tiling we use here tiles - # columns before rows, and so it is Nx4 (after ignoring all 1 dims). 
- registers = _load_32xcols( - self.address, self.shape[1], self.dtype, packing - ).T.reshape(regs_shape) - case TMEMLayout(elements_in_tile=(r, 128), column_tile_stride=2) if r == TMEM_ROWS: - if self.shape[1] % 128 != 0: - raise ValueError( - f"TMEM layout {self.layout} is not compatible with shape {self.shape}" - ) - num_column_tiles = self.shape[1] // 128 - column_tile_stride = self.layout.column_tile_stride - num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) - tiles = [] - for col_tile_base in range(num_strided_col_groups): - for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): - tiles.append( - _load_32xcols( - arith.addi(self.address, arith.constant(i32, col_tile * 128)), - cols=128, - dtype=self.dtype, - tmem_packing=1, - ) - ) - registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) - case _: - raise NotImplementedError( - f"Loads only implemented for refs with standard layout, got: {self.layout}" - ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) def store(self, value): @@ -713,23 +746,39 @@ def store(self, value): f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype" f" {self.dtype}" ) - if value.layout != LAYOUT: + if value.layout == LAYOUT: + # TODO(apaszke): Collective MMA layout + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if ( + r == TMEM_ROWS + ): + # store_32xcols needs a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + _store_32xcols( + self.address, value.registers.T.reshape((4, -1)), packing + ) + case _: + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) + elif value.layout == TMEM_NATIVE_LAYOUT: + # TODO(apaszke): Collective MMA layout + match self.layout: + case TMEMLayout(elements_in_tile=(r, c), packing=packing) if ( + r == TMEM_ROWS and c % 2 == 0 + ): + _store_32xcols_native( + self.address, value.registers.reshape(-1), packing + ) + case _: + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) + else: raise ValueError( - f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT is" - " supported" + f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT and" + " tcgen05.TMEM_NATIVE_LAYOUT are supported" ) - # TODO(apaszke): Collective MMA layout - match self.layout: - case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if r == TMEM_ROWS: - # store_32xcols needs a 4xN array, but the FA tiling we use here tiles - # columns before rows, and so it is Nx4 (after ignoring all 1 dims). - _store_32xcols( - self.address, value.registers.T.reshape((4, -1)), packing - ) - case _: - raise NotImplementedError( - f"Stores only implemented for refs with standard layout, got: {self.layout}" - ) def _debug_print(self): i32 = ir.IntegerType.get_signless(32) @@ -756,28 +805,43 @@ def _debug_print(self): utils.debug_print(f"[{{}}, {c}]: {{}}", lane, val, uniform=False) -def _transfer_32xcols(base_addr: ir.Value, cols: int, packing: int): +def _transfer_32xcols( + base_addr: ir.Value, + cols: int, + atom_shape: tuple[int, int], + tmem_packing: int, + reg_packing: int, +): + """Generates a sequence of parameters for a given TMEM read or write. + + Arguments: + base_addr: The base address of the TMEM region. + cols: The number of logical columns to transfer. 
+ atom_shape: The logical shape of the tile written by the warp in a single + TMEM transfer. + tmem_packing: Packing degree in TMEM. When packing is 1, but the data is + 16-bit, we expect that each transfer actually involves double the number + of physical columns. + reg_packing: The number of elements that fit in a single 32-bit register. + """ i32 = ir.IntegerType.get_signless(32) - cols_per_num = 8 # Here we generate a plan compatible with tcgen05.LAYOUT. - assert cols % cols_per_num == 0 - total_num = cols // cols_per_num + atom_rows, atom_cols = atom_shape + assert cols % atom_cols == 0 + total_num = cols // atom_cols assert total_num.bit_count() == 1 + regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing) # We artificially lower the instr_num compared to its limits, because higher # values can lead to register spills.. - if total_num <= 16: - instr_num = total_num - elif 32 <= total_num <= 64: - instr_num = 16 - else: - raise NotImplementedError(total_num) - # We transfer 16 lanes at a time, but have 32 to deal with. - for lane_step in range(2): - addr_row = arith.addi(base_addr, utils.c((lane_step * 16) << 16, i32)) - cols_per_instr = instr_num * cols_per_num + instr_num = min(total_num, 64 // regs_per_instr) + assert 32 % atom_rows == 0 + num_row_steps = 32 // atom_rows + for lane_step in range(num_row_steps): + addr_row = arith.addi(base_addr, utils.c((lane_step * atom_rows) << 16, i32)) + cols_per_instr = instr_num * atom_cols for num_step in range(total_num // instr_num): num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num) addr_row_col = arith.addi( - addr_row, utils.c(num_step * cols_per_instr // packing, i32) + addr_row, utils.c(num_step * cols_per_instr // tmem_packing, i32) ) yield addr_row_col, instr_num, lane_step, num_slice @@ -813,12 +877,44 @@ def _store_32xcols(base_addr, vector_regs, tmem_packing): else: raise NotImplementedError(reg_packing) - it = _transfer_32xcols(base_addr, cols, tmem_packing) + it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing) for addr_row_col, instr_num, lane_step, num_slice in it: regs_slice = regs[lane_step, num_slice].flat tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack) +def _store_32xcols_native(base_addr, vector_regs, tmem_packing): + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 1 + cols = len(vector_regs) * TMEM_NATIVE_LAYOUT.vector_length + + reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + store_shape = "32x32b" + if reg_packing == 1: + store_atom_shape = (32, 1) + regs = [None] * (len(vector_regs) * 2) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in enumerate(vector_regs): + regs[2 * idx] = llvm.extractelement(vreg, c0) + regs[2 * idx + 1] = llvm.extractelement(vreg, c1) + assert tmem_packing == 1 + unpack = False + elif reg_packing == 2: + store_atom_shape = (32, 2) + regs = vector_regs + assert 1 <= tmem_packing <= 2 + unpack = tmem_packing == 1 + else: + raise NotImplementedError(reg_packing) + + it = _transfer_32xcols(base_addr, cols, store_atom_shape, tmem_packing, reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + assert lane_step == 0 + regs_slice = regs[num_slice] + tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack) + + def _load_32xcols(base_addr, cols, dtype, tmem_packing): i32 = ir.IntegerType.get_signless(32) vec_ty = ir.VectorType.get((2,), dtype) @@ -836,7 +932,7 @@ def _load_32xcols(base_addr, cols, dtype, tmem_packing): 
vector_regs = np.ndarray((4, cols // 8), dtype=object) - it = _transfer_32xcols(base_addr, cols, tmem_packing) + it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing) c0 = arith.constant(i32, 0) c1 = arith.constant(i32, 1) for addr_row_col, instr_num, lane_step, num_slice in it: @@ -868,6 +964,50 @@ def _load_32xcols(base_addr, cols, dtype, tmem_packing): return vector_regs +def _load_32xcols_native(base_addr, cols, dtype, tmem_packing): + i32 = ir.IntegerType.get_signless(32) + vec_ty = ir.VectorType.get((2,), dtype) + reg_packing = 32 // utils.bitwidth(dtype) + load_shape = "32x32b" + if reg_packing == 1: + load_atom_shape = (32, 1) + assert tmem_packing == 1 + pack = False + elif reg_packing == 2: + load_atom_shape = (32, 2) + assert 1 <= tmem_packing <= 2 + pack = tmem_packing == 1 + else: + raise NotImplementedError(reg_packing) + + it = _transfer_32xcols(base_addr, cols, load_atom_shape, tmem_packing, reg_packing) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + regs = [None] * (cols // reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + assert lane_step == 0, lane_step + instr_regs = tmem_load(addr_row_col, load_shape, instr_num, pack) + if reg_packing == 1: + regs[num_slice] = [llvm.bitcast(dtype, r) for r in instr_regs] + else: + assert reg_packing == 2 + regs[num_slice] = [llvm.bitcast(vec_ty, r) for r in instr_regs] + + if reg_packing == 1: + vector_regs = np.ndarray((cols // 2,), dtype=object) + undef = llvm.mlir_undef(vec_ty) + for idx in range(vector_regs.size): + high_undef = llvm.insertelement(undef, regs[2 * idx], c0) + vreg = llvm.insertelement(high_undef, regs[2 * idx + 1], c1) + vector_regs[idx] = vreg + else: + assert reg_packing == 2 + vector_regs = np.asarray(regs, dtype=object) + + assert vector_regs.shape == (cols // TMEM_NATIVE_LAYOUT.vector_length,) + return vector_regs + + def _m128_layout(shape: tuple[int, ...]): if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index bd11c3a07544..bf0b06ccb9c9 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -40,6 +40,7 @@ # mypy: ignore-errors +WARP_SIZE: int = 32 WARPGROUP_SIZE: int = 128 DYNAMIC = -9223372036854775808 DYNAMIC32 = -2147483648 diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e02acf8cce13..03ded0ac446c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -906,7 +906,7 @@ def setUp(self): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") @parameterized.parameters([(jnp.float32, 1), (jnp.float16, 1), (jnp.float16, 2)]) - def test_load_store_tmem(self, jax_dtype, packing): + def test_load_store_tmem_swizzle(self, jax_dtype, packing): swizzle = 128 in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype) swizzle_elems = swizzle // bytewidth(in_mlir_dtype) @@ -942,6 +942,31 @@ def kernel(ctx, input, output, scratch): )(x) np.testing.assert_array_equal(x, y) + @parameterized.parameters([(jnp.float32, 1), (jnp.float16, 1), (jnp.float16, 2)]) + def test_load_store_tmem_native(self, jax_dtype, packing): + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy(src_ref=input, dst_ref=smem, barrier=barrier) + barrier.wait() + tmem.store(fa.FragmentedArray.load_untiled(smem, layout=tcgen05.TMEM_NATIVE_LAYOUT, optimized=False)) + tcgen05.commit_tmem() + tmem.load(tcgen05.TMEM_NATIVE_LAYOUT).store_untiled(smem, optimized=False) + 
mgpu.commit_shared() + ctx.async_copy(src_ref=smem, dst_ref=output) + ctx.await_async_copy(0) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(x.shape, jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype, packing=packing), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(x, y) + @parameterized.parameters([ (jnp.float32, 1, "130.0000"), (jnp.float16, 1, "130.0000"), From fc12df095e01e630fde9c9520a740b92977779a1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 12 May 2025 07:25:38 -0700 Subject: [PATCH 1125/1769] [pallas:mosaic_gpu] Slightly generalized `MosaicGridMapping` PiperOrigin-RevId: 757755328 --- jax/_src/pallas/mosaic/lowering.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b2a1f356ad0b..1ea5a048a17e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -425,6 +425,7 @@ def __init__( dynamic_shape_replacement_fn: Callable[ [tuple[jax.DimSize, ...]], tuple[int, ...] ], + arg_type_fn: Callable[..., ir.Type], ): self.grid = grid_mapping.grid self.grid_names = grid_mapping.grid_names @@ -464,17 +465,17 @@ def __init__( operand_avals = in_avals[grid_mapping.slice_block_ops] scratch_avals = in_avals[grid_mapping.slice_scratch_ops] self.scalar_prefetch_types, _ = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, None) + arg_type_fn(dynamic_shape_replacement_fn, aval, None) for aval in scalar_prefetch_avals ]) self.scalar_prefetch_block_shapes = tuple( aval.shape for aval in scalar_prefetch_avals) self.operand_types, self.operand_block_shapes = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, block_mapping) + arg_type_fn(dynamic_shape_replacement_fn, aval, block_mapping) for aval, block_mapping in zip(operand_avals, self.block_mappings) ]) self.scratch_types, _ = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, None) + arg_type_fn(dynamic_shape_replacement_fn, aval, None) for aval in scratch_avals ]) self.scratch_block_shapes = tuple( @@ -482,7 +483,7 @@ def __init__( for aval in scratch_avals ) self.grid_types, _ = unzip2([ - _get_arg_type( + arg_type_fn( dynamic_shape_replacement_fn, pallas_core.index_map_grid_aval, None, @@ -710,6 +711,7 @@ def dynamic_shape_replacement_fn( dimension_semantics, mesh, dynamic_shape_replacement_fn, + arg_type_fn=_get_arg_type, ) mosaic_grid_mapping.maybe_compress_grid() m = ir.Module.create() From e65b317a4cef945eec0ed2378442df743fa1ee31 Mon Sep 17 00:00:00 2001 From: David Marttila Date: Fri, 9 May 2025 21:56:42 +0100 Subject: [PATCH 1126/1769] Speed up `scipy.signal.stft` by using `lax.dynamic_slice_in_dim` for windowing --- jax/_src/scipy/signal.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 565909e8a6d1..f8c2563027f5 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -566,13 +566,9 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array], result = x[..., np.newaxis] else: step = nperseg - noverlap - batch_shape = list(batch_shape) - x = x.reshape((math.prod(batch_shape), signal_length, 1)) - result = jax.lax.conv_general_dilated_patches( - x, (nperseg,), (step,), - 'VALID', - dimension_numbers=('NTC', 'OIT', 'NTC')) - result = result.reshape(*batch_shape, *result.shape[-2:]) + starts = 
jnp.arange(signal_length - nperseg + 1, step=step)
+    slice_func = partial(jax.lax.dynamic_slice_in_dim,
+                         operand=x, slice_size=nperseg, axis=-1)
+    result = jax.vmap(slice_func, out_axes=-2)(start_index=starts)

   # Detrend each data segment individually
   result = detrend_func(result)

From f2121a72fc97c555eda6f519dba28dfe883e62cb Mon Sep 17 00:00:00 2001
From: George Necula
Date: Thu, 8 May 2025 14:39:59 +0300
Subject: [PATCH 1127/1769] [platform_dependent] Ensure that platform_dependent only lowers for intended platforms

Fixes: #28594

Currently `lax.platform_dependent` allows specifying code that behaves
differently when lowered on different platforms. However, this function
operates in a confusing way, in that it will create a branch on the
platform, but will lower all branches for the **current** lowering
platforms. For example, in the following code:

```
lax.platform_dependent(x, cpu=for_cpu, tpu=for_tpu)
```

If we lower for CPU, we lower both `for_cpu` and `for_tpu` for CPU (!),
but only the branch corresponding to `for_cpu` will actually run.

This is a problem if, e.g., `for_tpu` does not have a lowering for CPU.
We will get an error during lowering. Instead there should be no error
during lowering, because that branch is not actually needed. We add a new
test `test_platform_dependent_with_primitive_with_lowering_error` to
demonstrate this.

The solution implemented here is Solution A from #28594: we add a
`branches_platforms` param to the `cond` primitive, which is propagated by
all transformations. This param is used only for the conditionals arising
from `lax.platform_dependent`. During lowering we drop the branches
corresponding to the platforms that are not needed.
---
 jax/_src/checkify.py                      |   5 +-
 jax/_src/interpreters/mlir.py             |   8 +-
 jax/_src/lax/control_flow/__init__.py     |   1 +
 jax/_src/lax/control_flow/conditionals.py | 176 ++++++++++++++--------
 jax/_src/pallas/mosaic/lowering.py        |  12 +-
 jax/_src/pallas/mosaic_gpu/lowering.py    |   5 +-
 jax/experimental/jax2tf/jax2tf.py         |   5 +-
 tests/lax_control_flow_test.py            |  79 ++++++++--
 8 files changed, 201 insertions(+), 90 deletions(-)

diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 5a6456762db7..144cbaf5cd21 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -759,7 +759,8 @@ def jaxpr_to_checkify_jaxpr(
   out_tree, error_effects = metadata()
   return checked_jaxpr, out_tree, error_effects

-def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
+def cond_error_check(error: Error, enabled_errors, index, *ops,
+                     branches, **params):
   # Get the error-effects out of all branches so the cond can be called with
   # a merged error with all these effects.
err_vals, err_tree = jtu.tree_flatten(error) @@ -780,7 +781,7 @@ def get_error_effects_from_jaxpr(jxpr): err_and_outs = lax.cond_p.bind( index, *err_vals, *ops, - branches=tuple(new_branches)) + branches=tuple(new_branches), **params) # we need to merge metadata across out_trees (a tuple) err0, out = tree_unflatten(out_trees[0], err_and_outs) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e9deb8d3fff9..f6ef5787ccbf 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2080,6 +2080,11 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None return ('tpu',) return () +def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]: + """The lowering platforms for the current eqn""" + return tuple((_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or + ctx.platforms or ctx.module_context.platforms)) + def lower_per_platform(ctx: LoweringRuleContext, description: str, @@ -2122,8 +2127,7 @@ def lower_per_platform(ctx: LoweringRuleContext, rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ - platforms: Sequence[str] = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or - ctx.platforms or ctx.module_context.platforms) + platforms: Sequence[str] = _platforms_for_eqn(ctx) # Special case the common case (single-platform lowering) if len(platforms) == 1: rule = platform_rules.get(platforms[0], default_rule) diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index f89e4d53a476..44ee94e14ca2 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -34,6 +34,7 @@ while_p as while_p, ) from jax._src.lax.control_flow.conditionals import ( + BranchesPlatforms as BranchesPlatforms, cond as cond, cond_p as cond_p, switch as switch, diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 99fa72421ea1..d875989921d0 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -46,6 +46,7 @@ from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.traceback_util import api_boundary +from jax._src.typing import ArrayLike from jax._src.util import safe_map, split_list, partition_list, unzip2 from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -127,9 +128,17 @@ def switch(index, branches, *operands): lo = np.array(0, np.int32) hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) + return _switch_internal(index, branches, operands, + branches_platforms=None) + +def _switch_internal( + index: ArrayLike, + branches: Sequence[Callable], + operands: Sequence[ArrayLike], *, + branches_platforms: BranchesPlatforms | None): if (config.disable_jit.value and core.is_concrete(index)): - return branches[int(index)](*operands) + return branches[int(index)](*operands) # type: ignore dbgs = [api_util.debug_info("switch", branch, operands, {}) for branch in branches] @@ -159,7 +168,10 @@ def switch(index, branches, *operands): raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') jaxprs = [replace_jaxpr_effects(jaxpr, joined_effects) for jaxpr in jaxprs] - out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) + params = dict(branches=tuple(jaxprs)) + if branches_platforms is not None: + params["branches_platforms"] = branches_platforms + out = cond_p.bind(index, *consts, *ops, **params) out_ = iter(out) all_inputs = [*consts, *ops] @@ 
-464,7 +476,7 @@ def _bcast_select_n(pred, *cases): pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) return lax.select_n(pred, *cases) -def _cond_batching_rule(axis_data, args, dims, branches): +def _cond_batching_rule(axis_data, args, dims, *, branches, **params): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -480,6 +492,9 @@ def _cond_batching_rule(axis_data, args, dims, branches): if index_dim is not batching.not_mapped: + assert "branches_platforms" not in params, ( + "The index of a cond with branches_platforms should be a " + "platform_index and should never be mapped") # Convert to a lax.select. While we could get away with not broadcasting # some operands yet, because all outputs must be broadcast together anyway # for the select we broadcast the input operands for simplicity and leave @@ -518,10 +533,11 @@ def _cond_batching_rule(axis_data, args, dims, branches): for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] - out = cond_p.bind(index, *ops, branches=branches_batched) + out = cond_p.bind(index, *ops, branches=branches_batched, + **params) return out, out_dims -def _cond_jvp(primals, tangents, branches): +def _cond_jvp(primals, tangents, *, branches, **params): nonzeros = [type(t) is not ad_util.Zero for t in tangents] index_nz, *ops_nz = nonzeros @@ -538,14 +554,15 @@ def _cond_jvp(primals, tangents, branches): _, *ops_dot = tangents ops_dot = _prune_zeros(ops_dot) - out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) + out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp, + **params) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents -def _cond_partial_eval(trace, *tracers, branches): +def _cond_partial_eval(trace, *tracers, branches, **params): in_unknowns = [t.pval[0] is not None for t in tracers] index_uk, *ops_uk = in_unknowns if any(isinstance(eff, RefEffect) for branch in branches for eff in @@ -556,7 +573,7 @@ def _cond_partial_eval(trace, *tracers, branches): if index_uk: # When the branch index is unknown, we stage out the whole cond. 
# TODO(mattjj): remove this path when old remat is removed - params = dict(branches=branches) + params = dict(branches=branches, **params) return trace.default_process_primitive(cond_p, tracers, params) branches_out_uks = [] @@ -586,7 +603,8 @@ def _cond_partial_eval(trace, *tracers, branches): for j in branches_known[1:]) in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()] - out_consts_res = cond_p.bind(*in_consts, branches=branches_known) + out_consts_res = cond_p.bind(*in_consts, branches=branches_known, + **params) out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res]) index_tracer = trace.instantiate_const(tracers[0]) @@ -595,7 +613,7 @@ def _cond_partial_eval(trace, *tracers, branches): res_tracers = map(trace.new_instantiated_const, res) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in branches_unknown[0].out_avals] - params = dict(branches=branches_unknown) + params = dict(branches=branches_unknown, **params) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) eqn = pe.new_eqn_recipe( @@ -608,6 +626,7 @@ def _cond_partial_eval(trace, *tracers, branches): def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): index_uk, *ops_uk = unks_in branches = eqn.params['branches'] + eqn_rest_params = dict(k_v for k_v in eqn.params.items() if k_v[0] != 'branches') # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -664,7 +683,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar out_binders_known, _ = partition_list(unks_out, eqn.outvars) - params_known = dict(branches=branches_known) + params_known = dict(branches=branches_known, **eqn_rest_params) effects_known = _join_cond_effects(branches_known) eqn_known = pe.new_jaxpr_eqn( ins_known, [*out_binders_known, *res_binders], cond_p, params_known, @@ -672,7 +691,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the staged eqn. 
_, out_binders_staged = partition_list(inst_out, eqn.outvars) - params_staged = dict(branches=branches_staged) + params_staged = dict(branches=branches_staged, **eqn_rest_params) effects_staged = _join_cond_effects(branches_staged) eqn_staged = pe.new_jaxpr_eqn( [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged, @@ -818,7 +837,7 @@ def transposed(*args): debug_info=jaxpr.jaxpr.debug_info), res_avals + jaxpr.out_avals) -def _cond_transpose(cts, *args, branches): +def _cond_transpose(cts, *args, branches, **params): index, *ops = args assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] @@ -838,7 +857,8 @@ def _cond_transpose(cts, *args, branches): res = ops[:num_res] cts = map(ad.instantiate_zeros, cts) - out = cond_p.bind(index, *res, *cts, branches=branches_trans) + out = cond_p.bind(index, *res, *cts, branches=branches_trans, + **params) assert all(map(core.typecheck, lin_in_avals, out)) out_iter = iter(out) @@ -846,7 +866,8 @@ def _cond_transpose(cts, *args, branches): assert next(out_iter, None) is None return [None] + out -def _cond_typecheck(bind_time, *in_atoms, branches): +def _cond_typecheck(bind_time, *in_atoms, branches, **params): + del params if not bind_time: _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] @@ -900,6 +921,16 @@ def _cond_typecheck(bind_time, *in_atoms, branches): f'called with operands of type {_avals_short(op_avals)}') return jaxpr0.out_avals, joined_effects + +BranchesPlatforms = tuple[tuple[str, ...] | None, ...] +# cond_p takes an optional branches_platforms param of type `BranchesPlatforms` +# when it is a `platform_dependent` conditional. +# In that case, `branches_platforms` is a tuple as long +# as `branches` and for each branch it specifies the lowering platforms it +# corresponds to. The last element, corresponding to the last branch, +# can be `None` to represent a default match-all-lowering-platforms. +# The index argument of a `platform_dependent` cond is always a +# `platform_index` primitive. cond_p = core.Primitive('cond') cond_p.multiple_results = True cond_p.skip_canonicalization = True @@ -915,7 +946,39 @@ def _cond_typecheck(bind_time, *in_atoms, branches): pe.dce_rules[cond_p] = _cond_dce_rule batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule -def _cond_lowering(ctx, index, *args, branches): +def _cond_lowering(ctx, index, *args, branches, + **params): + if (branches_platforms := params.get("branches_platforms", None)) is not None: + branches_kept: list[core.ClosedJaxpr] = [] + index_to_kept_index: dict[int, int] = {} + for p in mlir._platforms_for_eqn(ctx): + # Each `p` must appear in exactly one branches_platforms, or in the + # last default branch. Otherwise, platform_index lowering would have + # failed already. 
+ for b_idx, b_platforms in enumerate(branches_platforms): + if b_platforms is None or p in b_platforms: + if b_idx not in index_to_kept_index: + index_to_kept_index[b_idx] = len(branches_kept) + branches_kept.append(branches[b_idx]) + break + else: + assert False, p + + # Compute the new index into branches_keep + i32_type = ir.RankedTensorType.get([], mlir.dtype_to_ir_type(dtypes.dtype(np.int32))) + kept_index_case_op = hlo.CaseOp([i32_type], + index=index, + num_branches=len(branches)) + for i in range(len(branches)): + branch = kept_index_case_op.regions[i].blocks.append() + with ir.InsertionPoint(branch): + kept_i = np.int32(index_to_kept_index.get(i, 0)) + hlo.return_([mlir.ir_constant(kept_i)]) + + index = kept_index_case_op + branches = branches_kept + assert branches, "platform_index lowering should have failed first" + joined_effects = core.join_effects(*(branch.effects for branch in branches)) ordered_effects = list(effects.ordered_effects.filter_in(joined_effects)) num_tokens = len(ordered_effects) @@ -952,7 +1015,8 @@ def _cond_lowering(ctx, index, *args, branches): mlir.register_lowering(cond_p, _cond_lowering) @register_partial_discharge_rule(cond_p) -def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches): +def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, + branches, **params): assert not should_discharge[0], "Can't discharge the index." discharged_branches = tuple( discharge_state(branch.jaxpr, (), should_discharge=should_discharge[1:])[0] @@ -981,7 +1045,8 @@ def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *ar if fwd is None]), ()) for branch in discharged_branches ) - out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches) + out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches, + **params) out_vals, out_ref_vals_no_fwd = util.split_list(out_vals_no_fwd, [len(out_avals)]) # Insert forwarded values into reference outputs ref_val_no_fwd_iter = iter(out_ref_vals_no_fwd) @@ -1046,50 +1111,41 @@ def other_platforms_code(*args): ... The value ``per_platform[execution_platform](*args)``. """ # Join identical branches - platform_branches: list[tuple[list[str], Callable]] = [] + branches_platforms_list: list[tuple[list[str], Callable]] = [] for pname, pbranch in per_platform.items(): + if not callable(pbranch): + raise TypeError(f"lax.platform_dependent: the '{pname}' branch must " + "be a callable.") if pname == "gpu": raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.") - for ps, b in platform_branches: + for ps, b in branches_platforms_list: if b == pbranch: ps.append(pname) break else: - platform_branches.append(([pname], pbranch)) - - platforms_lists, branches = util.unzip2(platform_branches) - platform_index = platform_index_p.bind( - platforms=tuple(tuple(ps) for ps in platforms_lists), - has_default=(default is not None)) + branches_platforms_list.append(([pname], pbranch)) + platforms_lists, branches = util.unzip2(branches_platforms_list) + branches_platforms: BranchesPlatforms = tuple(tuple(ps) for ps in platforms_lists) if default is not None: + if not callable(default): + raise TypeError("lax.platform_dependent: the 'default' branch must " + "be a callable.") branches = branches + (default,) - # Use a switch, to get the proper transformation rules for free. Since - # platform index has no dependence on the input data, it won't be vectorized - # under vmap. 
- # If the switch and the platform_index_p above are in the same compilation - # unit then constant-folding will remove the unnecessary branches. However, - # if we run in eager mode the switch below cannot be constant-folded and - # the compilation may fail if some of the branches contain custom calls not - # recognized on the compilation platform. Detect eager mode and keep only the - # needed branch. - try: - # Note/TODO(mvoz): This actually rarely seems to concretize - we could look into - # core.ensure_compile_time_eval to get better single-branch selection. - platform_index_concrete = core.concrete_or_error(operator.index, platform_index) - except core.ConcretizationTypeError: - return switch(platform_index, branches, *args) - else: - assert 0 <= platform_index_concrete < len(branches) - return branches[platform_index_concrete](*args) + branches_platforms = branches_platforms + (None,) # type: ignore + platform_index = platform_index_p.bind(platforms=branches_platforms) + + if core.is_concrete(platform_index): + return branches[int(platform_index)](*args) + return _switch_internal(platform_index, branches, args, + branches_platforms=branches_platforms) + # A primitive to compute the index of a platform into a list of platforms. # Args: -# platforms: Sequence[Sequence[str]]: a sequence of sequences of platform -# names. If the current lowering platform is in one of the inner sequences -# returns the index of that inner sequence in the outer sequence. -# has_default: if True, and if the lowering platform is not found in -# `platforms` then return `len(platforms)`. Otherwise, raise an error. +# platforms: BranchesPlatforms. If the current lowering +# platform is in one of the inner tuples returns the index of that inner +# tuple in the outer tuple. 
platform_index_p = core.Primitive("platform_index")
 platform_index_p.multiple_results = False
 platform_index_p.def_impl(functools.partial(dispatch.apply_primitive,
                                             platform_index_p))
@@ -1101,25 +1157,25 @@ def _platform_index_aval(*_, **__):

 def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
                              *,
-                             platforms: Sequence[Sequence[str]],
-                             has_default: bool):
-  def lower_constant(
-      ctx: mlir.LoweringRuleContext, *, i: int
-  ) -> Sequence[ir.Value]:
+                             platforms: BranchesPlatforms):
+  def lower_constant(ctx: mlir.LoweringRuleContext, *,
+                     i: int) -> Sequence[ir.Value]:
     v = mlir.ir_constant(np.int32(i))
-    assert isinstance(v, ir.Value), v
     return [v]
+
   platform_rules: dict[str, mlir.LoweringRule] = {}
+  default_rule = None
   for i, ps in enumerate(platforms):
     rule = partial(lower_constant, i=i)
-    for p in ps:
-      platform_rules[p] = rule
+    if ps is None:
+      default_rule = rule
+    else:
+      for p in ps:
+        platform_rules[p] = rule

-  default_rule = (
-      partial(lower_constant, i=len(platforms)) if has_default else None)
   return mlir.lower_per_platform(
       ctx,
-      f"platform_index(platforms={platforms}, has_default={has_default})",
+      f"platform_index(platforms={platforms})",
       platform_rules, default_rule, effects.no_effects)

 mlir.register_lowering(platform_index_p, _platform_index_lowering)
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 1ea5a048a17e..bba49c75f9df 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -47,7 +47,7 @@
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import control_flow
 from jax._src.lax import lax as lax_internal
-from jax._src.lax.control_flow import for_loop
+from jax._src.lax.control_flow import for_loop, BranchesPlatforms
 from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
@@ -3100,7 +3100,7 @@ def _while_lowering_rule(

 @register_lowering_rule(lax.cond_p)
-def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
+def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, **params):
   index, *args = args

   constant_index = _fold_and_get_constant_value(index)
@@ -3870,17 +3870,13 @@ def _pad(val):

 def _platform_index_lowering(
     ctx: mlir.LoweringRuleContext,
     *,
-    platforms: Sequence[Sequence[str]],
-    has_default: bool,
+    platforms: BranchesPlatforms,
 ):
   for i, ps in enumerate(platforms):
     # note - slightly odd structure here, as platforms is a seq[seq[str]]
-    if "mosaic" in ps:
+    if ps is None or "mosaic" in ps:
       return ir_constant(i)

-  if has_default:
-    return ir_constant(len(platforms))
-
   raise NotImplementedError(
       "No mosaic or default platform indexing rule found."
) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b501693bf627..9ead4f16c1a6 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2598,7 +2598,10 @@ def _while_lowering_rule( @register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) @register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) -def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): +def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches, + **params): + if params: + raise NotImplementedError("platform_dependent cond") index_aval, *_arg_avals = ctx.avals_in def _yielded_values(outs, avals): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 4c2f35a95c57..786e021e2ff0 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3062,8 +3062,11 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: def _cond( - index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr] + index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], + **params ) -> Sequence[TfVal]: + if params: + raise NotImplementedError("jax2tf conversion for platform_dependent") # tf.cond needs lambdas with no arguments. branches_tf = [ partial(_interpret_jaxpr, jaxpr, *operands, diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 422ef769e392..d32d761ee1fa 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -37,8 +37,10 @@ from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp +from jax._src import dispatch from jax._src.lax import control_flow as lax_control_flow from jax._src.lax.control_flow import for_loop +from jax._src.interpreters import batching from jax._src.interpreters import mlir jax.config.parse_flags_with_absl() @@ -137,6 +139,36 @@ def scan_reference(f, init, xs): lambda ctx, x: mlir.hlo.CustomCallOp( [x.type], [x], call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results) +batching.primitive_batchers[prim_non_existent_custom_call] = ( + lambda batched_args, batch_dims: (prim_non_existent_custom_call.bind(batched_args[0]), + batch_dims[0])) + +# A JAX primitive that triggers error when lowering on unintended platforms +prim_with_lowering_error = core.Primitive("__testing_prim_with_lowering_error") +prim_with_lowering_error.def_abstract_eval(lambda x_aval, **_: x_aval) +def prim_with_lowering_error_lowering(platform: str, + ctx: mlir.LoweringRuleContext, x, *, + only_on: str): + if platform != only_on: + raise ValueError(f"prim_with_lowering_error with only_on={only_on} lowered for {platform}") + return mlir.hlo.SineOp(x).results +def prim_with_lowering_error_batch_rule(batched_args, batch_dims, **params): + xs, = batched_args + xs_bdim, = batch_dims + return prim_with_lowering_error.bind(xs, **params), xs_bdim + +batching.primitive_batchers[prim_with_lowering_error] = prim_with_lowering_error_batch_rule + +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "cpu"), + platform="cpu") +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "tpu"), + platform="tpu") +prim_with_lowering_error.def_impl(partial(dispatch.apply_primitive, + prim_with_lowering_error)) class 
LaxControlFlowTest(jtu.JaxTestCase): @@ -1378,7 +1410,7 @@ def f(x): @parameterized.named_parameters( {"testcase_name": f"_{name}", "cond": cond} for cond, name in COND_IMPLS) - def testCondGrad2(self, cond): + def testCondGrad2(self, cond=cond_with_new_checkpoint): def f_ref(x): z = jnp.array([1., 2.], x.dtype) * x if x[0] < 2 else jnp.sin(x) return z.sum() @@ -2905,18 +2937,13 @@ def f(x): x = np.arange(3, dtype=np.float32) lowered = jax.jit(f).lower(x) stablehlo = lowered.as_text() - self.assertIn("stablehlo.case", stablehlo) - self.assertIn("stablehlo.sine", stablehlo) - self.assertIn("stablehlo.cosine", stablehlo) - - # The HLO has been canonicalized and contains only the branch we need - hlo = lowered.as_text("hlo") + # The StableHLO contains only the branch we need if jtu.device_under_test() == "cpu": - self.assertIn(" sine", hlo) - self.assertNotIn(" cosine", hlo) + self.assertIn("stablehlo.sine", stablehlo) + self.assertNotIn("stablehlo.cosine", stablehlo) else: - self.assertNotIn(" sine", hlo) - self.assertIn(" cosine", hlo) + self.assertNotIn("stablehlo.sine", stablehlo) + self.assertIn("stablehlo.cosine", stablehlo) def test_platform_dependent_with_non_existent_custom_call(self): if not jtu.test_device_matches(["cpu"]): @@ -2939,8 +2966,7 @@ def f(x): x = np.arange(3, dtype=np.float32) hlo = str(jax.jit(f).lower(x).compiler_ir()) - occurrences = re.findall(prim_non_existent_custom_call.name, hlo) - self.assertLen(occurrences, 3) + self.assertNotIn(prim_non_existent_custom_call.name, hlo) res_eager = f(x) self.assertAllClose(res_eager, 3. * np.sin(x)) @@ -2956,6 +2982,26 @@ def f(x): res_grad = jax.grad(f)(1.) self.assertAllClose(res_grad, 3. * np.cos(1.)) + def test_platform_dependent_with_primitive_with_lowering_error(self): + if not jtu.test_device_matches(["cpu", "tpu"]): + self.skipTest("Only for CPU and TPU") + + def f(x): + return lax.platform_dependent( + x, + # Check that we only lower on the intended platform + cpu=lambda x: prim_with_lowering_error.bind(x, only_on="cpu"), + tpu=lambda x: prim_with_lowering_error.bind(x, only_on="tpu")) + + self.assertAllClose(np.sin(1.), f(1.)) # Eager + self.assertAllClose(np.sin(1.), jax.jit(f)(1.)) + self.assertAllClose(np.sin(1.), lax.cond(True, f, lambda x: x, 1.)) + self.assertAllClose(1., lax.cond(False, f, lambda x: x, 1.)) + self.assertAllClose((0., np.sin(np.arange(8.))), + lax.scan(lambda carry, x: (carry, f(x)), + 0., np.arange(8.))) + self.assertAllClose(np.sin(np.arange(8.)), jax.vmap(f)(np.arange(8.))) + def test_platform_dependent_multiple_identical_branches(self): x = np.arange(3, dtype=np.float32) def f(x): @@ -2965,13 +3011,14 @@ def f(x): tpu=jnp.sin, default=lambda x: x) res = f(x) + on_cpu_tpu = jtu.device_under_test() in ["cpu", "tpu"] self.assertAllClose( res, - np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x) - # We only lower the common branches once + np.sin(x) if on_cpu_tpu else x) + stablehlo = jax.jit(f).lower(x).as_text() sines = re.findall(r"stablehlo.sine", stablehlo) - self.assertEqual(1, len(sines)) + self.assertEqual(1 if on_cpu_tpu else 0, len(sines)) def test_platform_dependent_no_default(self): ctx = contextlib.ExitStack() From 1be91c5dd634dc1772a5508f0da73dd97ec1de30 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 12 May 2025 09:05:55 -0700 Subject: [PATCH 1128/1769] Revert pytest: use importlib mode by default --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cf8002ffe610..03cc78a6dcbb 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ doctest_optionflags = [ "NUMBER", "NORMALIZE_WHITESPACE" ] -addopts = "--doctest-glob='*.rst' --ignore='examples/ffi' --import-mode=importlib" +addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'" [tool.ruff] preview = true From 3b9865af0c418634811284953f76e70c347319e9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 12 May 2025 10:35:04 -0700 Subject: [PATCH 1129/1769] [array API] update test suite to most recent version --- .github/workflows/jax-array-api.yml | 2 +- tests/array_api_skips.txt | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 825d3ada9a0b..6419cb730b71 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -32,7 +32,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02 + ref: 'c847143beb8d769bde5dbcc063fe19ed7acc2f9b' # Latest commit as of 2025-05-12 submodules: 'true' path: 'array-api-tests' - name: Install dependencies diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 7534cf6f8acd..7781b93e7820 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -2,6 +2,7 @@ # finfo return type misalignment (https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Test suite attempts in-place mutation: array_api_tests/test_array_object.py::test_setitem @@ -28,6 +29,8 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +# Array API expects default value for axis argument. +array_api_tests/test_indexing_functions.py::test_take_along_axis # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted From f6b9f7d7e32272ff10d447fb7c986edbe3fd3ef8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 May 2025 13:05:55 -0700 Subject: [PATCH 1130/1769] remove :custom_call and :runtime from mosaic_gpu since they are in :mosaic_gpu_support now. PiperOrigin-RevId: 757881025 --- jaxlib/mosaic/gpu/BUILD | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index b694258fed1e..115d0c47cc52 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -23,10 +23,7 @@ package( py_library( name = "mosaic_gpu", data = [":libmosaic_gpu_runtime.so"], - deps = [ - ":_mosaic_gpu_ext", - ":mosaic_gpu_support", - ], + deps = [":_mosaic_gpu_ext"], ) cc_library( From 4dfcbc2c934456bbda866b8882d57864210f333d Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 May 2025 13:21:33 -0700 Subject: [PATCH 1131/1769] [Mosaic] Support squeezing tiled memrefs to 1d shapes. Previously, squeezing to a 1D memref failed w/ verification errors, as we would always use the old layout. If we are squeezing from a source to a 1D shape, we need to modify the tile dimensions when we emit the result layout, as the removed dimensions should not be included in the new tiling. 
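For illustration, a minimal sketch (not taken from this change) of the kind of Pallas TPU kernel this unblocks; the kernel, shapes, and names below are hypothetical, and whether a given program reaches this Mosaic path depends on lowering details:

```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # Dropping the leading size-1 dimension squeezes the underlying tiled
  # memref from (1, 128) down to the 1D shape (128,), the case that
  # previously failed verification.
  o_ref[...] = x_ref[0] * 2

x = jnp.arange(128, dtype=jnp.float32).reshape(1, 128)
y = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct((128,), jnp.float32))(x)
```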
PiperOrigin-RevId: 757887087 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 212 +++++++++++++++++++-------- 1 file changed, 147 insertions(+), 65 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 134db412042d..934088e91506 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_format.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -164,53 +166,115 @@ LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op, return success(); } +// Computes the dimensions that were squeezed from the source shape to match the +// target shape. Returns the dimensions in increasing order. +FailureOr> computeSqueezedDimsChecked( + Operation *op, ArrayRef source_shape, + ArrayRef target_shape) { + SmallVector squeezed; + int source_index = source_shape.size() - 1; + int target_index = target_shape.size() - 1; + + while (source_index >= 0 || target_index >= 0) { + int64_t target_dim = (target_index >= 0) ? target_shape[target_index] : -1; + if (source_index < 0) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + int64_t source_dim = source_shape[source_index]; + if (source_dim == target_dim) { + source_index--; + target_index--; + } else { + if (source_dim != 1) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + squeezed.push_back(source_index); + source_index--; + } + } + + if (source_index != -1 || target_index != -1) { + op->emitError() << "Shape mismatch after traversal. Source shape: " + << shapeToString(source_shape) + << ", target shape: " << shapeToString(target_shape); + return failure(); + } + std::reverse(squeezed.begin(), squeezed.end()); + return squeezed; +} + LogicalResult MemRefSqueezeOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); - // Source and target attributes may be different before propagation is done by - // the canonicalizer, so we allow this when attributes are "unset" in the - // target type. 
+ if (target_type.getMemorySpace() != nullptr && target_type.getMemorySpace() != source_type.getMemorySpace()) { - emitOpError("Memory spaces do not match."); - return failure(); + return emitOpError("Memory spaces do not match."); } + if (target_type.getElementType() != source_type.getElementType()) { - this->emitOpError("Element types don't match."); - return failure(); - } - if (!HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem) && - source_type.getRank() > 1 && target_type.getRank() == 1) { - return emitError("Not implemented: squeeze memref to 1d."); + return emitOpError("Element types don't match."); } + auto source_shape = source_type.getShape(); auto target_shape = target_type.getShape(); - int source_index = source_shape.size() - 1; - int target_index = target_shape.size() - 1; - auto error_msg = llvm::formatv( - "Target shape is not valid. " - "Source type: {0}. Target type: {1}.", - source_type, target_type); - while (source_index >= 0 || target_index >= 0) { - int target_dim = target_index < 0 ? -1 : target_shape[target_index]; - if (source_index < 0) { - // We have run out of source shape but target shape still remains. - emitOpError(error_msg); - return failure(); + auto squeezed_or = + computeSqueezedDimsChecked(*this, source_shape, target_shape); + if (failed(squeezed_or)) { + return failure(); + } + + auto erase_layout_op = getInput().getDefiningOp(); + if (!erase_layout_op) { + return success(); + } + + auto layout_ref = erase_layout_op.getOperand(); + MemRefType layout_ty = getMemRefType(layout_ref); + auto layout_attr = dyn_cast(layout_ty.getLayout()); + if (!layout_attr) { + return emitOpError( + "Input from EraseLayoutOp is expected to have a TiledLayoutAttr."); + } + auto &squeezed = squeezed_or.value(); + if (squeezed.empty() && source_shape != target_shape) { + return failure(); + } + + auto tiles = layout_attr.getTiles(); + if (tiles.size() == 1) { + auto tile = layout_attr.getTiles().front(); + auto tile_dims = tile.dimensions(); + int first_tiled = source_shape.size() - tile_dims.size(); + for (int dim : squeezed) { + if (dim >= first_tiled) { + int tile_idx = dim - first_tiled; + if (tile_idx < 0 || tile_idx >= static_cast(tile_dims.size())) { + return emitOpError() << "Internal error: tile index out of bounds."; + } + if (tile_dims[tile_idx] != 1) { + return emitOpError() + << "All tiled squeezed dimensions must be of size 1."; + } + } } - int source_dim = source_shape[source_index]; - if (source_dim == target_dim) { - source_index--; - target_index--; - } else { - // Only the source dim can be 1 here. - if (source_dim != 1) { - this->emitOpError(error_msg); - return failure(); - } - source_index--; + } else { + auto first_tile = tiles.front(); + for (int dim : squeezed) { + int first_tiled = source_shape.size() - first_tile.dimensions().size(); + if (dim >= first_tiled) { + return emitOpError() << "When multiple tiles are present, no tiled " + "dimensions can be squeezed."; + } } } + return success(); } @@ -222,42 +286,60 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op, if (!erase_layout) { return failure(); } - // Push layout erasure through squeezing. It is important we see the layout - // for lowering and don't make it hard for other ops to query it. 
+ auto layout_ref = erase_layout.getOperand(); - MemRefType layout_ty = layout_ref.getType(); + MemRefType layout_ty = getMemRefType(layout_ref); + auto layout_attr = dyn_cast(layout_ty.getLayout()); + if (!layout_attr) { + return failure(); + } + auto source_shape = source_type.getShape(); auto target_shape = target_type.getShape(); - int source_index = source_shape.size() - 1; - int target_index = target_shape.size() - 1; - auto old_layout = dyn_cast(layout_ty.getLayout()); - auto target_strides = old_layout.getTileStrides(); - SmallVector tile_strides(target_strides.begin(), - target_strides.end()); - // We want to remove all strides that correspond to squeezed dimensions and - // update the corresponding output layout. - while (source_index >= 0 || target_index >= 0) { - int target_dim = target_index < 0 ? -1 : target_shape[target_index]; - int source_dim = source_shape[source_index]; - if (source_dim == target_dim) { - source_index--; - target_index--; - } else { - // Source index must be 1 here (otherwise verification will have failed). - // We are safe to mutate the strides vector here because we are looping - // backwards. - tile_strides.erase(tile_strides.begin() + source_index); - source_index--; + auto squeezed_or = computeSqueezedDimsChecked(op, source_shape, target_shape); + if (failed(squeezed_or)) { + return failure(); + } + auto &squeezed = squeezed_or.value(); + if (squeezed.empty() && source_shape != target_shape) { + return failure(); + } + + SmallVector tile_strides = + llvm::to_vector(layout_attr.getTileStrides()); + for (int i = squeezed.size() - 1; i >= 0; --i) { + tile_strides.erase(tile_strides.begin() + squeezed[i]); + } + + tpu::TiledLayoutAttr new_layout; + bool target_is_1d = target_shape.size() == 1; + auto tiles = layout_attr.getTiles(); + if (target_is_1d && tiles.size() == 1) { + auto tile_dims = llvm::to_vector(tiles.front().dimensions()); + int first_tiled = source_shape.size() - tile_dims.size(); + for (int i = squeezed.size() - 1; i >= 0; --i) { + int dim = squeezed[i]; + if (dim >= first_tiled) { + int tile_idx = dim - first_tiled; + if (tile_idx < 0 || tile_idx >= static_cast(tile_dims.size())) { + return op.emitError() << "Internal error: tile index out of bounds."; + } + tile_dims.erase(tile_dims.begin() + tile_idx); + } } + new_layout = tpu::TiledLayoutAttr::get( + op.getContext(), {xla::Tile(tile_dims)}, tile_strides); + } else { + new_layout = tpu::TiledLayoutAttr::get( + op.getContext(), layout_attr.getTiles(), tile_strides); } - auto new_layout = tpu::TiledLayoutAttr::get( - source_type.getContext(), old_layout.getTiles(), tile_strides); - auto new_result_type = MemRefType::get(op.getResult().getType().getShape(), - layout_ty.getElementType(), new_layout, - layout_ty.getMemorySpace()); - auto squeeze = rewriter.create(op.getLoc(), new_result_type, - layout_ref); - rewriter.replaceOpWithNewOp(op, op.getType(), squeeze); + + auto new_ty = MemRefType::get(target_shape, layout_ty.getElementType(), + new_layout, layout_ty.getMemorySpace()); + + auto new_squeeze = + rewriter.create(op.getLoc(), new_ty, layout_ref); + rewriter.replaceOpWithNewOp(op, target_type, new_squeeze); return success(); } From 167d6bc765d05f2f49c9aedf1e1c94500d6eefde Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 May 2025 14:00:17 -0700 Subject: [PATCH 1132/1769] Move the existing mask handling code to the relayout fn, invoke it from the existing tpu relayout rule. 
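For context, a hedged sketch of a user-level kernel that exercises the packing path that now lives in the relayout function: a mask produced by a 32-bit comparison and then consumed at 16-bit width. The kernel is illustrative only and is not part of this change:

```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, y_ref, o_ref):
  # The comparison yields a mask laid out for 32-bit data; selecting
  # between bf16 values consumes it at 16-bit width, which requires
  # packing the mask vregs during relayout.
  mask = x_ref[...] > 0.0
  o_ref[...] = jnp.where(mask, y_ref[...], -y_ref[...])

x = jnp.ones((8, 128), jnp.float32)
y = jnp.ones((8, 128), jnp.bfloat16)
out = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.bfloat16))(x, y)
```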
PiperOrigin-RevId: 757902288 --- jaxlib/mosaic/dialect/tpu/tpu.td | 1 + jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 48 +++ .../tpu/transforms/apply_vector_layout.cc | 276 ++++++++++-------- jaxlib/mosaic/dialect/tpu/vreg_util.cc | 13 + jaxlib/mosaic/dialect/tpu/vreg_util.h | 9 + 5 files changed, 219 insertions(+), 128 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index fb72b6948d9d..226fc6285192 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -441,6 +441,7 @@ def TPU_RelayoutOp : TPU_Op<"relayout", [SameOperandsAndResultType]> { let arguments = (ins AnyType:$input); let results = (outs AnyType:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; } def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 934088e91506..b5e68bf08370 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -343,6 +343,54 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op, return success(); } +LogicalResult RelayoutOp::verify() { + auto in_layout_array_attr = + getOperation()->getAttrOfType("in_layout"); + if (!in_layout_array_attr || in_layout_array_attr.empty()) { + return emitOpError("missing or empty 'in_layout' attribute"); + } + if (in_layout_array_attr.size() != 1) { + return emitOpError( + "'in_layout' attribute must be an array containing a single " + "VectorLayoutAttr"); + } + auto src_vla = dyn_cast(in_layout_array_attr[0]); + if (!src_vla) { + return emitOpError("'in_layout' attribute is not a VectorLayoutAttr"); + } + + auto out_layout_array_attr = + getOperation()->getAttrOfType("out_layout"); + if (!out_layout_array_attr || out_layout_array_attr.empty()) { + return emitOpError("missing or empty 'out_layout' attribute"); + } + if (out_layout_array_attr.size() != 1) { + return emitOpError( + "'out_layout' attribute must be an array containing a single " + "VectorLayoutAttr"); + } + auto dst_vla = dyn_cast(out_layout_array_attr[0]); + if (!dst_vla) { + return emitOpError("'out_layout' attribute is not a VectorLayoutAttr"); + } + + VectorType input_type = cast(getInput().getType()); + VectorType output_type = cast(getOutput().getType()); + + if (input_type.getShape() != output_type.getShape()) { + return emitOpError("input and output shapes must match"); + } + if (input_type.getElementType() != output_type.getElementType()) { + // Allow i1 to i1 even if bitwidth in layout changes. 
+ if (!(input_type.getElementType().isInteger(1) && + output_type.getElementType().isInteger(1))) { + return emitOpError( + "input and output element types must match for non-mask relayouts"); + } + } + return success(); +} + LogicalResult MemRefReshapeOp::verify() { auto src_ty = getMemRefType(getInput()); auto tgt_ty = getType(); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index f8e18070e5e7..b8ba61e7c914 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2101,74 +2101,6 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, return success(); } -LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(op.getNumOperands(), 1); - TPU_ASSERT_EQ_OP(op.getNumResults(), 1); - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in[0].has_value()); - TPU_ASSERT_OP(layouts_out[0].has_value()); - const auto& in_layout = *layouts_in[0]; - const auto& out_layout = *layouts_out[0]; - auto realyout_op = cast(op); - auto in_bitwidth = in_layout.bitwidth(); - auto out_bitwidth = out_layout.bitwidth(); - auto vty = cast(realyout_op.getType()); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - if (in_layout == out_layout) { - realyout_op.replaceAllUsesWith(realyout_op.getInput()); - realyout_op.erase(); - return success(); - } - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array vals, - disassemble(builder, in_layout, - cast>(realyout_op.getInput()), - ctx.target_shape, - /*use_implicit_shape=*/true)); - // Packing vector masks from 32-bit to 16-bit. - if (vty.getElementType() == builder.getI1Type() && in_bitwidth == 32 && - out_bitwidth == 16 && - in_layout.tiling()[0] == in_layout.packing() * ctx.target_shape[0] && - in_layout.tiling()[1] == ctx.target_shape[1] && - in_layout.tiling() == out_layout.tiling() && - in_layout.offsets() == out_layout.offsets() && - in_layout.implicit_dim() == out_layout.implicit_dim()) { - std::vector vmsks_shape(vals.dimensions().begin(), - vals.dimensions().end()); - *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2); - xla::Array out_vmsks(vmsks_shape, nullptr); - SmallVector val_idx; - Value default_val = - getFullLikeVector(builder, cast>(*vals.begin()), - IntegerAttr::get(builder.getI1Type(), 0)); - out_vmsks.Each([&](absl::Span idx, Value *v) { - val_idx.assign(idx.begin(), idx.end()); - // TODO(jevinjiang): can be simplified when offset is replicated. - *(val_idx.end() - 1) *= 2; - Value low_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1) - ? vals(val_idx) - : default_val; - *(val_idx.end() - 1) += 1; - Value high_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1) - ? 
vals(val_idx) - : default_val; - const VectorType mask_ty = getNativeVregOrVmaskType( - builder.getI1Type(), in_bitwidth / 2, ctx.target_shape); - *v = builder.create(mask_ty, low_part, high_part); - }); - const RollVectorsOp rolled_op = - assemble(builder, vty, out_layout, out_vmsks, ctx.target_shape, - /*use_implicit_shape=*/true); - op.replaceAllUsesWith(rolled_op); - op.erase(); - return success(); - } - return op.emitOpError("Not implemented: unsupported layout change"); -} - Value createSubelementMask(OpBuilder &builder, const Location loc, const int bitwidth, const int64_t from, const int64_t to, @@ -4827,60 +4759,6 @@ LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op, return success(); } -const llvm::StringMap &rules() { - static const llvm::StringMap *rules = [] { - static auto rules = new llvm::StringMap{ - {arith::ConstantOp::getOperationName(), arith_constant_rule}, - {arith::ExtFOp::getOperationName(), arith_extf_rule}, - {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, - {arith::ExtUIOp::getOperationName(), arith_extui_rule}, - {arith::TruncFOp::getOperationName(), arith_truncf_rule}, - {arith::TruncIOp::getOperationName(), arith_trunci_rule}, - {func::ReturnOp::getOperationName(), func_return_rule}, - {scf::ForOp::getOperationName(), scf_for_rule}, - {scf::WhileOp::getOperationName(), scf_while_rule}, - {scf::ConditionOp::getOperationName(), scf_condition_rule}, - {scf::IfOp::getOperationName(), scf_if_rule}, - {scf::YieldOp::getOperationName(), yield_rule}, - {tpu::YieldOp::getOperationName(), yield_rule}, - {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, - {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, - {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, - {tpu::IotaOp::getOperationName(), tpu_iota_rule}, - {tpu::GatherOp::getOperationName(), tpu_gather_rule}, - {tpu::DynamicGatherOp::getOperationName(), tpu_dynamic_gather_rule}, - {tpu::LoadOp::getOperationName(), tpu_load_rule}, - {tpu::StoreOp::getOperationName(), tpu_store_rule}, - {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, - {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, - {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, - {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, - {tpu::RegionOp::getOperationName(), tpu_region_rule}, - {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, - {tpu::TraceOp::getOperationName(), tpu_trace_rule}, - {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, - {tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule}, - {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule}, - {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule}, - {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, - {vector::ExtractOp::getOperationName(), vector_extract_rule}, - {vector::LoadOp::getOperationName(), vector_load_rule}, - {vector::MultiDimReductionOp::getOperationName(), - vector_multi_reduction_rule}, - {vector::ExtractStridedSliceOp::getOperationName(), - vector_extract_strided_slice_rule}, - {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, - {vector::StoreOp::getOperationName(), vector_store_rule}, - {tpu::TransposeOp::getOperationName(), vector_transpose_rule}}; - - for (const auto &[name, rule] : mlir::tpu::extensions::rules()) { - rules->insert({name, rule}); - } - return rules; - }(); - return *rules; -} - // Determines whether we should handle bank conflict for the given stride and // 
max_sublane_offset. // @@ -6773,12 +6651,20 @@ FailureOr> relayout(RewriteContext &ctx, VectorLayout src, VectorLayout dst) { const auto target_shape = ctx.target_shape; + VectorType vty = v.getType(); const int8_t bitwidth = src.bitwidth(); - if (bitwidth != dst.bitwidth()) { + const bool is_mask = vty.getElementTypeBitWidth() == 1; + const bool is_mask_pack = + is_mask && bitwidth == 32 && dst.bitwidth() == 16 && + src.tiling()[0] == src.packing() * target_shape[0] && + src.tiling()[1] == target_shape[1] && src.tiling() == dst.tiling() && + src.offsets() == dst.offsets() && + src.implicit_dim() == dst.implicit_dim(); + + if (bitwidth != dst.bitwidth() && !is_mask_pack) { return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); } - VectorType vty = v.getType(); - const bool is_mask = vty.getElementTypeBitWidth() == 1; + { // Replication imposes a replication constraint on the *logical* value of // the vector: When moving along a replicated axis, all elements must be @@ -6812,6 +6698,38 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); + if (is_mask_pack) { + std::vector vmsks_shape(src_tiles.dimensions().begin(), + src_tiles.dimensions().end()); + *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2); + xla::Array out_vmsks(vmsks_shape, nullptr); + SmallVector val_idx; + Value default_val = getFullVector( + builder, v.getLoc(), + cast>(*src_tiles.begin()).getType(), + IntegerAttr::get(builder.getI1Type(), 0)); + out_vmsks.Each([&](absl::Span idx, Value *v_slot_in_array) { + val_idx.assign(idx.begin(), idx.end()); + *(val_idx.end() - 1) *= 2; + Value low_part = + *(val_idx.end() - 1) < *(src_tiles.dimensions().end() - 1) + ? src_tiles(val_idx) + : default_val; + *(val_idx.end() - 1) += 1; + Value high_part = + *(val_idx.end() - 1) < *(src_tiles.dimensions().end() - 1) + ? src_tiles(val_idx) + : default_val; + const VectorType mask_ty = getNativeVregOrVmaskType( + builder.getI1Type(), bitwidth / 2, target_shape); + *v_slot_in_array = + builder.create(v.getLoc(), mask_ty, low_part, high_part); + }); + return assemble(builder, vty, dst, out_vmsks, target_shape, + /*use_implicit_shape=*/true) + .getResult(); + } + if (is_mask) { auto new_tile_ty = getNativeVregOrVmaskType( builder.getIntegerType(bitwidth), bitwidth, target_shape); @@ -6823,6 +6741,7 @@ FailureOr> relayout(RewriteContext &ctx, } auto assemble_with_mask_check = [&](xla::Array &tiles, bool use_implicit_shape = false) { + if (is_mask) { auto zeros_tile = builder.create( tiles.begin()->getLoc(), @@ -6941,9 +6860,110 @@ FailureOr> relayout(RewriteContext &ctx, changeOffsets(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), dst.offsets())); - CHECK_EQ(src, dst); // At this point we've should be done. 
- return assemble_with_mask_check(src_tiles, - /*use_implicit_shape=*/true); + CHECK_EQ(src, dst); + return assemble_with_mask_check(src_tiles, /*use_implicit_shape=*/true); +} + +LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto tpu_relayout_op = cast(op); + auto input_val = dyn_cast>(tpu_relayout_op.getInput()); + + auto in_layout_array_attr = + tpu_relayout_op->getAttrOfType("in_layout"); + if (!in_layout_array_attr || in_layout_array_attr.empty()) { + return tpu_relayout_op.emitOpError( + "missing or empty 'in_layout' attribute"); + } + auto src_vla = dyn_cast(in_layout_array_attr[0]); + if (!src_vla) { + return tpu_relayout_op.emitOpError( + "'in_layout' attribute is not a VectorLayoutAttr"); + } + VectorLayout src_layout = src_vla.getLayout().value(); + + auto out_layout_array_attr = + tpu_relayout_op->getAttrOfType("out_layout"); + if (!out_layout_array_attr || out_layout_array_attr.empty()) { + return tpu_relayout_op.emitOpError( + "missing or empty 'out_layout' attribute"); + } + auto dst_vla = dyn_cast(out_layout_array_attr[0]); + if (!dst_vla) { + return tpu_relayout_op.emitOpError( + "'out_layout' attribute is not a VectorLayoutAttr"); + } + VectorLayout dst_layout = dst_vla.getLayout().value(); + + if (src_layout == dst_layout) { + tpu_relayout_op.replaceAllUsesWith(tpu_relayout_op.getInput()); + tpu_relayout_op.erase(); + return success(); + } + + OpBuilder builder(&op); + FAILUREOR_ASSIGN_OR_RETURN( + TypedValue new_v, + relayout(ctx, builder, input_val, src_layout, dst_layout)); + + tpu_relayout_op.replaceAllUsesWith(new_v); + tpu_relayout_op.erase(); + return success(); +} + +const llvm::StringMap &rules() { + static const llvm::StringMap *rules = [] { + static auto rules = new llvm::StringMap{ + {arith::ConstantOp::getOperationName(), arith_constant_rule}, + {arith::ExtFOp::getOperationName(), arith_extf_rule}, + {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, + {arith::ExtUIOp::getOperationName(), arith_extui_rule}, + {arith::TruncFOp::getOperationName(), arith_truncf_rule}, + {arith::TruncIOp::getOperationName(), arith_trunci_rule}, + {func::ReturnOp::getOperationName(), func_return_rule}, + {scf::ForOp::getOperationName(), scf_for_rule}, + {scf::WhileOp::getOperationName(), scf_while_rule}, + {scf::ConditionOp::getOperationName(), scf_condition_rule}, + {scf::IfOp::getOperationName(), scf_if_rule}, + {scf::YieldOp::getOperationName(), yield_rule}, + {tpu::YieldOp::getOperationName(), yield_rule}, + {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, + {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, + {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, + {tpu::IotaOp::getOperationName(), tpu_iota_rule}, + {tpu::GatherOp::getOperationName(), tpu_gather_rule}, + {tpu::DynamicGatherOp::getOperationName(), tpu_dynamic_gather_rule}, + {tpu::LoadOp::getOperationName(), tpu_load_rule}, + {tpu::StoreOp::getOperationName(), tpu_store_rule}, + {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, + {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, + {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, + {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, + {tpu::RegionOp::getOperationName(), tpu_region_rule}, + {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, + {tpu::TraceOp::getOperationName(), tpu_trace_rule}, + {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, + 
{tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule}, + {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule}, + {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule}, + {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, + {vector::ExtractOp::getOperationName(), vector_extract_rule}, + {vector::LoadOp::getOperationName(), vector_load_rule}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_reduction_rule}, + {vector::ExtractStridedSliceOp::getOperationName(), + vector_extract_strided_slice_rule}, + {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, + {vector::StoreOp::getOperationName(), vector_store_rule}, + {tpu::TransposeOp::getOperationName(), vector_transpose_rule}}; + + for (const auto &[name, rule] : mlir::tpu::extensions::rules()) { + rules->insert({name, rule}); + } + return rules; + }(); + return *rules; } // TODO(apaszke): Implement a debug mode that inserts additional assertions. diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc index 72e0bf7f0caf..237bbe5cc722 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc @@ -79,6 +79,19 @@ TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder, return getFullVector(builder, vec.getType(), value); } +TypedValue<VectorType> getFullVector(OpBuilder &builder, Location loc, + VectorType vty, Attribute value) { + return cast<TypedValue<VectorType>>( + builder.create<arith::ConstantOp>(loc, DenseElementsAttr::get(vty, value)) + .getResult()); +} + +TypedValue<VectorType> getFullLikeVector(OpBuilder &builder, Location loc, + TypedValue<VectorType> vec, + Attribute value) { + return getFullVector(builder, loc, vec.getType(), value); +} + TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder, VectorType vty) { return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType())); diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.h b/jaxlib/mosaic/dialect/tpu/vreg_util.h index 8c2967e776c7..8833390ef87b 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.h +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.h @@ -50,6 +50,15 @@ TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder, TypedValue<VectorType> vec, Attribute value); +// Same as above, but takes a `loc` as input, in case of an OpBuilder. +TypedValue<VectorType> getFullVector(OpBuilder &builder, Location loc, + VectorType vty, Attribute value); + +// Same as above, but takes a `vec` as input. +TypedValue<VectorType> getFullLikeVector(OpBuilder &builder, Location loc, + TypedValue<VectorType> vec, + Attribute value); + // Creates a vmask with false flags to bottom (dim = 0) // or right (dim = 1) where the flag count corresponds to the (dim_size - // padding). From 1a3d3e37b2a71a1ebe9c0bdc3ffb4d95ac4e0bd5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 12 May 2025 14:08:12 -0700 Subject: [PATCH 1133/1769] Default `in_axes` of `smap` to `Infer`. This matches the behavior of `jax.shard_map`, where `in_specs` is optional. If the `axis_name` that `smap` is going manual over is not of type `Explicit`, we error out; in that case, providing `in_axes` is compulsory. This also allows us to not expose `Infer` as a public API! Added some more tests and fixed some bugs too.
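A hedged sketch of the new default, loosely following the tests in this change (mesh construction and the internal `smap` import may differ in other setups):

```
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.shard_map import smap
from jax.sharding import AxisType, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
with jax.sharding.use_mesh(mesh):
  arr = jax.device_put(np.arange(4.0), P('x'))
  # 'x' is an Explicit mesh axis, so in_axes may be omitted and is
  # inferred from the argument shardings; out_axes is still required.
  out = jax.jit(smap(jnp.sin, out_axes=0, axis_name='x'))(arr)
```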
PiperOrigin-RevId: 757905314 --- jax/_src/shard_map.py | 28 ++++++++++++------ tests/shard_map_test.py | 64 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 12 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 939dbeddf3d7..e9f2d3b6072e 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -146,7 +146,7 @@ def _get_default_infer(): # TODO(yashkatariya): We need a singleton which users can provide to `in_axes` # to tell smap to infer in_specs from args when mesh is fully explicit. -def smap(f, in_axes, out_axes, axis_name: AxisName): +def smap(f, /, *, in_axes=Infer, out_axes, axis_name: AxisName): if isinstance(axis_name, (list, tuple)): raise TypeError( f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}") @@ -164,6 +164,15 @@ def smap(f, in_axes, out_axes, axis_name: AxisName): if not all(isinstance(l, int) for l in tree_leaves(out_axes)): raise TypeError("smap out_axes must be an int, None, or (nested) container " f"with those types as leaves, but got {out_axes}.") + mesh = get_abstract_mesh() + if mesh.empty: + raise ValueError( + "The context mesh cannot be empty. Use" + " `jax.sharding.use_mesh(mesh)` to enter into a mesh context.") + if mesh._name_to_type[axis_name] != AxisType.Explicit and in_axes is Infer: + raise TypeError( + f"in_axes was not specified when {axis_name=} was of type" + f" {mesh._name_to_type[axis_name]}.") in_specs = (None if in_axes is Infer else tree_map(partial(_axes_to_pspec, axis_name), in_axes, @@ -215,12 +224,13 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, f"jax.shard_map requires axis_names={axis_names} to be a subset of " f"mesh.axis_names={mesh.axis_names}") - # TODO(yashkatariya): Maybe we don't have to be this strict? - if mesh._any_axis_auto_or_manual and in_specs is None: + if (in_specs is None and + not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): raise TypeError( "shard_map in_specs argument must be a pytree of" - " `jax.sharding.PartitionSpec` instances, but it was None when mesh" - f" has `Auto` axes {mesh}") + " `jax.sharding.PartitionSpec` instances, but it was `None` when" + f" {axis_names=} are of type" + f" {', '.join(str(mesh._name_to_type[a]) for a in axis_names)}") if in_specs is not None: _check_specs(SpecErrorType.input, in_specs, axis_names) @@ -242,9 +252,8 @@ def wrapped(*args): e, *_ = prefix_errors(in_specs, args) raise e('shard_map in_specs') from None - # TODO(yashkatariya): Relax this and convert only `None`s in `in_specs_flat` - # and accept the other specs as is. 
- if mesh._are_all_axes_explicit and in_specs is None: + if (in_specs is None and + all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): arg_s = [typeof(a).sharding for a in args_flat] assert all(i is None for i in in_specs_flat), in_specs_flat in_specs_flat = [_manual_spec(axis_names, s.spec) for s in arg_s] @@ -597,7 +606,8 @@ def _as_manual_mesh(mesh, manual_axes: frozenset): if cur_mesh._name_to_type[a] == AxisType.Auto: auto_axes.add(a) else: - assert cur_mesh._name_to_type[a] == AxisType.Explicit, cur_mesh._name_to_type[a] + assert cur_mesh._name_to_type[a] == AxisType.Explicit, ( + a, cur_mesh._name_to_type[a]) explicit_axes.add(a) new_axis_types = [] diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 4d3b265bd869..1abef7b06323 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -36,7 +36,7 @@ from jax._src import config from jax._src import core from jax._src import prng -from jax._src.shard_map import shard_map, smap, Infer +from jax._src.shard_map import shard_map, smap from jax._src import test_util as jtu from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists @@ -971,7 +971,7 @@ def test_in_specs_none_error(self): def f(x): return x - with self.assertRaisesRegex(TypeError, "but it was None"): + with self.assertRaisesRegex(TypeError, "but it was `None`"): shard_map(f, mesh=mesh, in_specs=None, out_specs=P())(3.) # TODO(mattjj): enable this test once we fix the tree_map(f, None, 3.0) bug @@ -3182,7 +3182,7 @@ def h(x): @jax.jit def f(x): - return smap(h, in_axes=Infer, out_axes=0, axis_name='x')(x) + return smap(h, out_axes=0, axis_name='x')(x) out = f(arr) self.assertArraysEqual(out, np_inp * np_inp) @@ -3215,6 +3215,64 @@ def g(x, y): out = g(np.arange(4), np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('data'))) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=(AxisType.Auto,)) + def test_smap_auto_error(self, mesh): + with self.assertRaisesRegex(TypeError, "in_axes was not specified"): + smap(lambda x: x * 2, out_axes=0, axis_name='x') + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit(self, mesh): + def f(x): + self.assertEqual(x.aval.vma, {'x'}) + return x * 2 + + arr = jax.device_put(np.arange(4), P('x')) + out = jax.jit(smap(f, out_axes=0, axis_name='x'))(arr) + self.assertArraysEqual(out, np.arange(4) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def g(x): + self.assertEqual(x.aval.vma, {'y'}) + return x * 2 + + arr = jax.device_put(np.arange(4), P('y')) + out = jax.jit(smap(g, in_axes=0, out_axes=0, axis_name='y'))(arr) + self.assertArraysEqual(out, np.arange(4) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest(self, mesh): + def g(b): + self.assertEqual(b.aval.vma, {'x', 'y'}) + return jnp.sin(b) + + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + return smap(g, in_axes=1, out_axes=1, axis_name='x')(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + jax.jit(smap(f, in_axes=0, out_axes=0, axis_name='y'))(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest_inner_none(self, mesh): + def g(b): + self.assertEqual(b.aval.vma, {'y'}) + return jnp.sin(b) + + def f(a): + 
self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + # Going manual over explicit axis `x` but in_axes is Infer and since + # input has no sharding, it will default to None. + return smap(g, out_axes=1, axis_name='x')(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + jax.jit(smap(f, in_axes=0, out_axes=0, axis_name='y'))(arr) # doesn't crash + class FunSpec(NamedTuple): name: str From 6fc9e17a9e7d46dad326bd595ad77218fa5389e5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 12 May 2025 14:26:03 -0700 Subject: [PATCH 1134/1769] Do all mesh checks at shard_map **call time** instead of at construction time. **Why do this change?** * If a `smap`/`shard_map` is constructed NOT under a mesh context but called under a mesh context, we will error at construction time. After this change, we won't. * If a `smap`/`shard_map` is nested but constructed at the top level, it will be bound with the mesh available at construction time instead of at call time. This is not ideal since while nesting one axis at a time, manualness of a mesh changes for the nested shard_map call. So we need to look at the mesh at call time. For example: ``` @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit, AxisType.Auto)) def test_smap_auto_explicit_nest_mesh_call_time(self, mesh): @partial(smap, in_axes=1, out_axes=1, axis_name='x') def g(b): return jnp.sin(b) @partial(smap, in_axes=0, out_axes=0, axis_name='y') def f(a): self.assertEqual(a.aval.vma, {'y'}) b = a * 2 return g(b) arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) jax.jit(f)(arr) # doesn't crash ``` In the above example, before this change, `g` would be bound with the mesh whose axis_types were `Explict, Auto` but since `g` is being used inside a `smap` i.e. it's nested, it needs to be bound with the mesh at call time which would have axis_types `Explicit, Manual` for the computation to be correct. One minor point regarding this change is that since `axis_name` or `in_specs/out_specs` refer to mesh axis names, the shard_map would need to be called with the correct mesh. Before this errored out at construction time but now it'll error out at call time. PiperOrigin-RevId: 757912835 --- jax/_src/shard_map.py | 117 ++++++++++++++++++++-------------------- tests/shard_map_test.py | 18 ++++++- 2 files changed, 76 insertions(+), 59 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index e9f2d3b6072e..b772a3de239e 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -164,82 +164,30 @@ def smap(f, /, *, in_axes=Infer, out_axes, axis_name: AxisName): if not all(isinstance(l, int) for l in tree_leaves(out_axes)): raise TypeError("smap out_axes must be an int, None, or (nested) container " f"with those types as leaves, but got {out_axes}.") - mesh = get_abstract_mesh() - if mesh.empty: - raise ValueError( - "The context mesh cannot be empty. 
Use" - " `jax.sharding.use_mesh(mesh)` to enter into a mesh context.") - if mesh._name_to_type[axis_name] != AxisType.Explicit and in_axes is Infer: - raise TypeError( - f"in_axes was not specified when {axis_name=} was of type" - f" {mesh._name_to_type[axis_name]}.") in_specs = (None if in_axes is Infer else tree_map(partial(_axes_to_pspec, axis_name), in_axes, is_leaf=lambda x: x is None)) out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes, is_leaf=lambda x: x is None) - return shard_map(f, axis_names={axis_name}, in_specs=in_specs, - out_specs=out_specs) + return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs, + axis_names={axis_name}, check_vma=True, _smap=True) def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, in_specs: Specs, out_specs: Specs | Callable[[], Specs], axis_names: Set[AxisName], check_vma: bool, - _skip_mesh_check: bool = False) -> Callable: + _skip_mesh_check: bool = False, _smap: bool = False) -> Callable: if not callable(f): raise TypeError("shard_map requires a callable for its first argument, " f"but got {f} of type {type(f)}.") - if mesh is None: - mesh = get_abstract_mesh() - if mesh.empty: - raise ValueError( - "The context mesh cannot be empty. Either use" - " `jax.sharding.use_mesh(mesh)` to enter into a mesh context or pass" - " a mesh to `shard_map` via the `mesh` keyword argument.") - else: - ctx_mesh = get_abstract_mesh() - if (not _skip_mesh_check and not ctx_mesh.empty and - mesh.abstract_mesh != ctx_mesh): - raise ValueError( - f"The context mesh {ctx_mesh} should match the mesh passed to" - f" shard_map {mesh}") - - if not isinstance(mesh, (Mesh, AbstractMesh)): - raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " - "`jax.sharding.AbstractMesh` instance for its " - f"second argument, but got {mesh} of type {type(mesh)}.") - - if not isinstance(axis_names, (frozenset, set)): - raise TypeError( - "`axis_names` argument of shard_map should be of type `frozenset` or" - f" `set`. Got type: {type(axis_names)}") - if isinstance(axis_names, set): - axis_names = frozenset(axis_names) - if not axis_names: - axis_names = frozenset(mesh.axis_names) - if not axis_names.issubset(mesh.axis_names): - raise ValueError( - f"jax.shard_map requires axis_names={axis_names} to be a subset of " - f"mesh.axis_names={mesh.axis_names}") - - if (in_specs is None and - not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): - raise TypeError( - "shard_map in_specs argument must be a pytree of" - " `jax.sharding.PartitionSpec` instances, but it was `None` when" - f" {axis_names=} are of type" - f" {', '.join(str(mesh._name_to_type[a]) for a in axis_names)}") - - if in_specs is not None: - _check_specs(SpecErrorType.input, in_specs, axis_names) - if not callable(out_specs): - _check_specs(SpecErrorType.out, out_specs, axis_names) - @util.wraps(f) @traceback_util.api_boundary def wrapped(*args): + nonlocal mesh, axis_names + mesh, axis_names = _shmap_checks(mesh, axis_names, in_specs, out_specs, + _skip_mesh_check, _smap) fun = lu.wrap_init( f, debug_info=api_util.debug_info("shard_map", f, args, {})) args_flat, in_tree = tree_flatten(args) @@ -305,6 +253,59 @@ def out_names_thunk(): return wrapped +def _shmap_checks(mesh, axis_names, in_specs, out_specs, _skip_mesh_check, + _smap): + if mesh is None: + mesh = get_abstract_mesh() + if mesh.empty: + raise ValueError( + "The context mesh cannot be empty. 
Use" + " `jax.sharding.use_mesh(mesh)` to enter into a mesh context") + else: + ctx_mesh = get_abstract_mesh() + if (not _skip_mesh_check and not ctx_mesh.empty and + mesh.abstract_mesh != ctx_mesh): + raise ValueError( + f"The context mesh {ctx_mesh} should match the mesh passed to" + f" shard_map {mesh}") + + if not isinstance(mesh, (Mesh, AbstractMesh)): + raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " + "`jax.sharding.AbstractMesh` instance for its " + f"second argument, but got {mesh} of type {type(mesh)}.") + + if not isinstance(axis_names, (frozenset, set)): + raise TypeError( + "`axis_names` argument of shard_map should be of type `frozenset` or" + f" `set`. Got type: {type(axis_names)}") + if isinstance(axis_names, set): + axis_names = frozenset(axis_names) + if not axis_names: + axis_names = frozenset(mesh.axis_names) + if not axis_names.issubset(mesh.axis_names): + raise ValueError( + f"jax.shard_map requires axis_names={axis_names} to be a subset of " + f"mesh.axis_names={mesh.axis_names}") + + if (in_specs is None and + not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): + axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names) + if _smap: + msg = (f"in_axes was not specified when axis_name={axis_names} was of" + f" type {axis_types}") + else: + msg = ("shard_map in_specs argument must be a pytree of" + " `jax.sharding.PartitionSpec` instances, but it was `None` when" + f" {axis_names=} are of type {axis_types}") + raise TypeError(msg) + + if in_specs is not None: + _check_specs(SpecErrorType.input, in_specs, axis_names) + if not callable(out_specs): + _check_specs(SpecErrorType.out, out_specs, axis_names) + return mesh, axis_names + + # Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 1abef7b06323..00d437aadb08 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3218,7 +3218,7 @@ def g(x, y): @jtu.with_explicit_mesh((2,), ('x',), axis_types=(AxisType.Auto,)) def test_smap_auto_error(self, mesh): with self.assertRaisesRegex(TypeError, "in_axes was not specified"): - smap(lambda x: x * 2, out_axes=0, axis_name='x') + smap(lambda x: x * 2, out_axes=0, axis_name='x')(np.arange(4)) @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit, AxisType.Auto)) @@ -3273,6 +3273,22 @@ def f(a): arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) jax.jit(smap(f, in_axes=0, out_axes=0, axis_name='y'))(arr) # doesn't crash + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest_mesh_call_time(self, mesh): + @partial(smap, in_axes=1, out_axes=1, axis_name='x') + def g(b): + return jnp.sin(b) + + @partial(smap, in_axes=0, out_axes=0, axis_name='y') + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + return g(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + jax.jit(f)(arr) # doesn't crash + class FunSpec(NamedTuple): name: str From bc3a3f0b24e3cb3c8c329053be12e3311b9ef2ff Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 12 May 2025 14:59:26 -0700 Subject: [PATCH 1135/1769] [pallas:mosaic] Allowed registering lowering per `pltpu.KernelType` PiperOrigin-RevId: 757925207 --- jax/_src/pallas/mosaic/lowering.py | 58 ++++++++++++++----- 
.../pallas/mosaic/pallas_call_registration.py | 1 + jax/_src/pallas/mosaic/verification.py | 6 +- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 1ea5a048a17e..776ac1cb8143 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -15,7 +15,7 @@ """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable, Collection, Sequence import contextlib import dataclasses import functools @@ -168,6 +168,7 @@ class LoweringContext: block_shapes: list[tuple[int | pallas_core.Squeezed, ...]] name_stack: source_info_util.NameStack mesh_context: MeshContext | None + kernel_type: tpu_core.KernelType traceback_caches: mlir.TracebackCaches for_verification: bool forward_compatible: bool @@ -324,7 +325,7 @@ def ir_constant(x, mlir_type=None): raise NotImplementedError(x.dtype) -lowering_rules = {} +lowering_rules = {kernel_type: {} for kernel_type in tpu_core.KernelType} skip_mlir_conversions = set() @@ -332,10 +333,14 @@ def ir_constant(x, mlir_type=None): def register_lowering_rule( - prim: jax_core.Primitive, *, ensure_mlir_values: bool = True + prim: jax_core.Primitive, + *, + kernel_types: Collection[tpu_core.KernelType] = (tpu_core.KernelType.TC,), + ensure_mlir_values: bool = True, ) -> Callable[[T], T]: def decorator(rule: T) -> T: - lowering_rules[prim] = rule + for kernel_type in kernel_types: + lowering_rules[kernel_type][prim] = rule if not ensure_mlir_values: skip_mlir_conversions.add(prim) return rule @@ -673,6 +678,7 @@ def lower_jaxpr_to_module( jaxpr: jax_core.Jaxpr, *, dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, + kernel_type: tpu_core.KernelType, mesh: mesh_lib.Mesh | None = None, for_verification: bool = False, dynamic_shape_replacement_enabled: bool = False, @@ -724,6 +730,7 @@ def dynamic_shape_replacement_fn( jaxpr, mosaic_grid_mapping=mosaic_grid_mapping, name="main", + kernel_type=kernel_type, for_verification=for_verification, forward_compatible=lowering_context.is_forward_compat(), dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, @@ -759,6 +766,7 @@ def dynamic_shape_replacement_fn( bm.block_aval, name=func_name, mosaic_grid_mapping=mosaic_grid_mapping, + kernel_type=kernel_type, for_verification=for_verification, forward_compatible=lowering_context.is_forward_compat(), dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, @@ -906,8 +914,9 @@ def lower_jaxpr_to_transform_func( *, name: str, mosaic_grid_mapping: MosaicGridMapping, + kernel_type: tpu_core.KernelType, for_verification: bool, - forward_compatible: bool, + forward_compatible: bool, dynamic_shape_replacement_fn: ( Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None ) = None, @@ -942,6 +951,7 @@ def body_func(*args): arg_block_shapes, source_info_util.NameStack(), mesh_context=mesh_context, + kernel_type=kernel_type, traceback_caches=mlir.TracebackCaches(), for_verification=for_verification, forward_compatible=forward_compatible, @@ -966,11 +976,19 @@ def body_func(*args): return body.func_op +lower_jaxpr_to_func_fns = {} + + +def register_jaxpr_to_func(kernel_type: tpu_core.KernelType): + lower_jaxpr_to_func_fns[kernel_type] = lower_jaxpr_to_func + + def lower_jaxpr_to_func( jaxpr: jax_core.Jaxpr, *, mosaic_grid_mapping: MosaicGridMapping, name: str, + kernel_type: tpu_core.KernelType, for_verification: bool, 
forward_compatible: bool, dynamic_shape_replacement_fn: ( @@ -1012,6 +1030,7 @@ def body_func(*args): arg_block_shapes, source_info_util.NameStack(), mesh_context=mesh_context, + kernel_type=kernel_type, traceback_caches=mlir.TracebackCaches(), for_verification=for_verification, forward_compatible=forward_compatible, @@ -1119,7 +1138,7 @@ def write_env(var: jax_core.Var, val): loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info) with (source_info_util.user_context(eqn.source_info.traceback), loc, eqn.ctx.manager): - if eqn.primitive in lowering_rules: + if eqn.primitive in lowering_rules[ctx.kernel_type]: if eqn.primitive not in skip_mlir_conversions: invals = [_ensure_mlir_value(x, v.aval) for x, v in zip(invals, eqn.invars)] @@ -1142,7 +1161,7 @@ def write_env(var: jax_core.Var, val): tpu.trace_start(message=name, level=10) try: - ans = lowering_rules[eqn.primitive]( + ans = lowering_rules[ctx.kernel_type][eqn.primitive]( rule_context, *invals, **eqn.params ) except LoweringException: @@ -1162,9 +1181,10 @@ def write_env(var: jax_core.Var, val): raise new_error from e else: raise NotImplementedError( - "Unimplemented primitive in Pallas TPU lowering: " - f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/jax-ml/jax/issues.") + "Unimplemented primitive in Pallas TPU lowering for" + f" {ctx.kernel_type}: {eqn.primitive.name}. Please file an issue on" + " https://github.com/jax-ml/jax/issues." + ) if eqn.primitive.multiple_results: foreach(write_env, eqn.outvars, ans) else: @@ -1889,7 +1909,9 @@ def _broadcast_to_lowering_rule( ) -@register_lowering_rule(lax.broadcast_in_dim_p) +@register_lowering_rule( + lax.broadcast_in_dim_p, kernel_types=[*tpu_core.KernelType] +) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding ): @@ -2139,7 +2161,9 @@ def _convert_helper(x, *, to_dtype): raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") -@register_lowering_rule(lax.convert_element_type_p) +@register_lowering_rule( + lax.convert_element_type_p, kernel_types=[*tpu_core.KernelType] +) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -2397,7 +2421,9 @@ def _bcast(x, y, x_aval, y_aval, out_aval): return x, y -@register_lowering_rule(lax.add_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.add_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) @register_lowering_rule(ad_util.add_any_p, ensure_mlir_values=False) def _add_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) @@ -2806,7 +2832,9 @@ def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y): for prim in [lax.eq_p, lax.ne_p, lax.lt_p, lax.le_p, lax.gt_p, lax.ge_p]: - register_lowering_rule(prim)(functools.partial(_cmp_lowering_rule, prim)) + register_lowering_rule(prim, kernel_types=[*tpu_core.KernelType])( + functools.partial(_cmp_lowering_rule, prim) + ) @register_lowering_rule(lax.and_p, ensure_mlir_values=False) @@ -3530,7 +3558,7 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, return [] -@register_lowering_rule(lax.axis_index_p) +@register_lowering_rule(lax.axis_index_p, kernel_types=[*tpu_core.KernelType]) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names if grid_names and axis_name in grid_names: diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py 
b/jax/_src/pallas/mosaic/pallas_call_registration.py index 5de917d077ce..74253e809a35 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -150,6 +150,7 @@ def lower_module(for_verification: bool): grid_mapping, jaxpr, dimension_semantics=mosaic_params.dimension_semantics, + kernel_type=mosaic_params.kernel_type, mesh=jax_mesh, for_verification=for_verification, dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index 08ff58770804..f45f36a473e9 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -596,11 +596,10 @@ def _assume_abstract_eval(x, y): assert jax_core.typematch(x, y) return x +@lowering.register_lowering_rule(assume_p) def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y): return y if ctx.lowering_context.for_verification else x -lowering.lowering_rules[assume_p] = _assume_lowering # type: ignore - def assume(normally, *, when_verifying): return assume_p.bind(normally, when_verifying) @@ -613,6 +612,7 @@ def _pretend_abstract_eval(*_, **params): del params # Unused. return () +@lowering.register_lowering_rule(pretend_p) def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): if ctx.lowering_context.for_verification: (base_read_refs, transforms) = tree_util.tree_unflatten(tree, flat_args) @@ -631,8 +631,6 @@ def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): ir.Operation.create("verification.pretend", operands=read_refs) return () -lowering.lowering_rules[pretend_p] = _pretend_lowering # type: ignore - def pretend(read_refs): refs, transforms = unzip2( primitives._get_ref_and_transforms(r) for r in read_refs From e43432128b3be8e4e94a82c3c6cb6a24ca44863d Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 May 2025 15:45:48 -0700 Subject: [PATCH 1136/1769] Support pltpu.roll on sublanes when not all lanes are used. PiperOrigin-RevId: 757942183 --- .../tpu/transforms/apply_vector_layout.cc | 123 +++++++++++++++++- .../tpu/transforms/infer_vector_layout.cc | 21 ++- tests/pallas/tpu_pallas_test.py | 22 +++- 3 files changed, 157 insertions(+), 9 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index b8ba61e7c914..d625e8bf4d6f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" @@ -2141,16 +2142,41 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, if (layout_in != layout) { return op.emitOpError("Not implemented: unsupported layout for input"); } - if (layout_out != layout) { + LayoutOffsets expected_offsets_out = layout_in.offsets(); + auto shift = getIntConst(amount, /*silent=*/true); + const bool has_static_shift = succeeded(shift); + int rotated_tiled_dim = op.getDimension() - (op.getType().getRank() - 2); + bool has_padding_along_rotation = + (rotated_tiled_dim == 0 || rotated_tiled_dim == 1) && + op.getType().getShape()[op.getDimension()] % + layout.tiling()[rotated_tiled_dim] != + 0; + if (has_static_shift && has_padding_along_rotation) { + // We checked above that there are no implicit dims. + const int64_t dim_size = op.getType().getShape()[op.getDimension()]; + // TODO(b/337384645): Currently we assume {0, 0} offsets in the input + // layout. Relax this assumption. + expected_offsets_out[rotated_tiled_dim] = + (dim_size - (shift.value() % dim_size)) % + layout.tiling()[rotated_tiled_dim]; + } + if (layout_out.bitwidth() != layout.bitwidth() || + layout_out.offsets() != expected_offsets_out || + layout_out.tiling() != layout.tiling() || + layout_out.implicit_dim() != layout.implicit_dim()) { return op.emitOpError("Not implemented: unsupported layout for output"); } auto vty = op.getResult().getType(); if (vty.getRank() < 2) { return op.emitOpError("Not implemented: unsupported 1D shape"); } - if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 || - *(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) { - return op.emitOpError("Not implemented: unsupported unaliged shape"); + // TODO(b/411170715): Allow sublane rotation once the bug is fixed. + // TODO(b/337384645): Support non-zero stride. + if (has_padding_along_rotation && + (!has_static_shift || + (rotated_tiled_dim == 0 || + (rotated_tiled_dim == 1 && op.getStride().value_or(0) != 0)))) { + return op.emitOpError("Not implemented: unsupported unaligned shape"); } ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); @@ -2277,6 +2303,88 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, return concatenate(chunks, axis); }; + // Applies lazy rotation (see go/pltpu-roll for details). + auto lazyRotate = [&](const xla::Array &vregs, int64_t shift, + int axis) { + const int tiling_dim = axis - (vregs.num_dimensions() - 2); + const int64_t tile_size = ctx.target_shape[tiling_dim]; + const int64_t input_size = vty.getShape()[axis]; + const int64_t normalized_shift = shift % input_size; + const int64_t start_idx = input_size - normalized_shift; + const int64_t start_vreg_idx = start_idx / tile_size; + const int64_t valid_amount = input_size % tile_size; + + // We start with the following: + // + // vregs: + // +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 XXX| + // +------+ +------+ +------+ + // + // where XXX is the padding and ░░░ is the prefix of the same size as the + // padding. 
+ + // After concatenation: + // + // concat: + // +------+ +------+ +------+ +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 XXX| + // +------+ +------+ +------+ +------+ +------+ +------+ + auto concat = concatenate({vregs, vregs}, axis); + auto chunks = split(concat, axis); + int64_t original_num_chunks = chunks.size() / 2; + + Value rotate_amount = mlirI32Const(valid_amount); + SmallVector<Value> low = {mlirIndexConst(0), mlirIndexConst(0)}; + low[tiling_dim] = mlirIndexConst(valid_amount); + auto mask = builder.create<tpu::CreateMaskOp>( + VectorType::get(ctx.target_shape, builder.getI1Type()), low, + /*high=*/ + ArrayRef<Value>{mlirIndexConst(ctx.target_shape[0]), + mlirIndexConst(ctx.target_shape[1])}); + // overwrite padding in the last vreg with valid data from the first vreg, + // yielding: + // + // +------+ +------+ +------+ +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 ░░░| + // +------+ +------+ +------+ +------+ +------+ +------+ + chunks.back().Each([&](absl::Span<const int64_t> idxs, Value *v) { + *v = builder.create<arith::SelectOp>( + mask, + builder.create<tpu::DynamicRotateOp>( + res_vreg_ty, chunks.front()(idxs), rotate_amount, tiling_dim, + nullptr, nullptr), + *v); + }); + // rotate the vregs starting from the middle vreg and then blend the vregs + // to overwrite the padding, yielding: + // + // +------+ +------+ +---+ +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 | |░░░ 0 | | 1 | | 2 ░░░| + // +------+ +------+ +---+ +------+ +------+ +------+ + for (int64_t i = original_num_chunks; i < chunks.size(); ++i) { + chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) { + *v = builder.create<tpu::DynamicRotateOp>( + res_vreg_ty, *v, rotate_amount, tiling_dim, nullptr, nullptr); + }); + } + for (int64_t i = original_num_chunks - 1; i < chunks.size() - 1; ++i) { + chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) { + *v = builder.create<arith::SelectOp>(mask, chunks[i + 1](idxs), *v); + }); + } + SmallVector<int64_t> result_dimensions = + layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape); + // assemble the result + xla::Array<Value> result(result_dimensions); + SmallVector<int64_t> starts(result.num_dimensions(), 0); + for (int64_t i = 0; i < result_dimensions[axis]; ++i) { + starts[axis] = i; + result.UpdateSlice(chunks[i + start_vreg_idx], starts); + } + return result; + }; + std::function<xla::Array<Value>(const xla::Array<Value> &, Value, int, int)> rotate; rotate = [&](const xla::Array<Value> &vregs, Value shift, int axis, @@ -2290,6 +2398,9 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, if (auto shift_cst = getIntConst(shift, /*silent=*/true); succeeded(shift_cst)) { int64_t static_shift = shift_cst.value(); + if (has_padding_along_rotation) { + return lazyRotate(vregs, static_shift, axis); + } if (tiling_dim >= 0) { shift = mlirI32Const(static_shift % ctx.target_shape[tiling_dim]); static_shift /= ctx.target_shape[tiling_dim]; @@ -2379,7 +2490,9 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, return result; }; - xla::Array<Value> out_tiles(in_tiles.dimensions()); + SmallVector<int64_t> out_dimensions = + layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape); + xla::Array<Value> out_tiles(out_dimensions); const auto dim = op.getDimension(); amount = modI(amount, vty.getDimSize(dim)); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 2e4c1c9c48a9..f42cfb139a37 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -757,9 +757,28 @@ class
VectorLayoutInferer { if (op.getType().getRank() < 2) { NYI("Unsupported 1D shape"); } + // TODO(b/337384645): Currently we assume {0, 0} offsets in the input + // layout. Relax this assumption. auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), ImplicitDim::kNone); - setLayout(op, {layout, kNoLayout}, layout); + // Calculate the offsets for the output layout. + LayoutOffsets offsets_out = layout.offsets(); + // We assume there are no implicit dims. + int tiling_dim = op.getDimension() - (op.getType().getRank() - 2); + if (auto amount = op.getAmount().getDefiningOp<arith::ConstantOp>(); + amount && (tiling_dim == 0 || tiling_dim == 1)) { + if (auto integer_attr = dyn_cast<IntegerAttr>(amount.getValue())) { + const int64_t tile_size = layout.tiling()[tiling_dim]; + const int64_t dim_size = op.getType().getShape()[op.getDimension()]; + const int64_t shift = integer_attr.getValue().getSExtValue(); + if (dim_size % tile_size != 0) { + offsets_out[tiling_dim] = (dim_size - (shift % dim_size)) % tile_size; + } + } + } + auto out_layout = VectorLayout(bitwidth, offsets_out, + nativeTiling(bitwidth), ImplicitDim::kNone); + setLayout(op, {layout, kNoLayout}, out_layout); return success(); } diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index a70aa19bda4d..83f21bca7fc1 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2895,9 +2895,9 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32)) - @only_passes_in_interpret() - def test_roll_partial(self): - """b/337384645""" + def test_roll_partial_with_static_shift(self): + if not jtu.if_cloud_tpu_at_least(2025, 5, 15): + self.skipTest('Needs a newer libtpu') x = np.arange(8192, dtype=jnp.float32).reshape(128, 64) def kernel(x_ref, out_ref): @@ -2908,6 +2908,22 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.roll(x, 3, 1)) + def test_roll_partial_with_dynamic_shift(self): + if not jtu.if_cloud_tpu_at_least(2025, 5, 15): + self.skipTest('Needs a newer libtpu') + if self.INTERPRET: + self.skipTest('Test only applies to non-interpret mode.') + x = np.arange(8192, dtype=jnp.float32).reshape(128, 64) + + def kernel(x_ref, out_ref): + amount = x_ref[0, 0].astype(jnp.int32) + out_ref[...] = pltpu.roll(x_ref[...], amount, 1) + + with self.assertRaisesRegex(Exception, 'unsupported unaligned shape'): + _ = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float32) + )(x) + @only_passes_in_interpret() def test_retiling1(self): """b/352626602""" From 189ba3a99959245724105217a72902cc50e38b14 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 May 2025 16:43:16 -0700 Subject: [PATCH 1137/1769] Add `block_until_ready()` to FAQ code snippet.
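For reference, a minimal sketch (not part of the diff below; names are illustrative) of the pitfall this change guards against: JAX dispatches work asynchronously, so timing `jax.device_put` without `block_until_ready()` measures only the time to enqueue the transfer, not the transfer itself.

    import time
    import numpy as np
    import jax

    x_np = np.ones((1000, 1000), dtype=np.float32)

    t0 = time.perf_counter()
    jax.device_put(x_np)  # returns as soon as the copy is enqueued
    enqueue_s = time.perf_counter() - t0

    t0 = time.perf_counter()
    jax.device_put(x_np).block_until_ready()  # waits for the copy to finish
    transfer_s = time.perf_counter() - t0

    print(f"enqueue only: {enqueue_s:.6f}s, full transfer: {transfer_s:.6f}s")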
PiperOrigin-RevId: 757963427 --- docs/faq.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/faq.rst b/docs/faq.rst index f5d43d25afb6..25d1d9ffab57 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -422,7 +422,6 @@ for comparing JAX versus NumPy, making use of IPython's convenient `%time and %timeit magics`_:: import numpy as np - import jax.numpy as jnp import jax def f(x): # function we're benchmarking (works in both NumPy & JAX) @@ -431,7 +430,9 @@ for comparing JAX versus NumPy, making use of IPython's convenient x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype %timeit f(x_np) # measure NumPy runtime - %time x_jax = jax.device_put(x_np) # measure JAX device transfer time + # measure JAX device transfer time + %time x_jax = jax.device_put(x_np).block_until_ready() + f_jit = jax.jit(f) %time f_jit(x_jax).block_until_ready() # measure JAX compilation time %timeit f_jit(x_jax).block_until_ready() # measure JAX runtime From 74938be8456050fec5032682f1463ea36983e1de Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 7 May 2025 14:03:14 -0400 Subject: [PATCH 1138/1769] Consolidate initial/final style custom_vjp primitives into one. --- jax/_src/checkify.py | 22 ++-- jax/_src/custom_derivatives.py | 148 +++++++++----------------- jax/_src/interpreters/partial_eval.py | 82 +++++++------- jax/_src/pallas/cost_estimate.py | 6 +- jax/custom_derivatives.py | 1 - jax/experimental/jax2tf/jax2tf.py | 8 +- jax/extend/core/primitives.py | 1 - tests/custom_api_test.py | 2 +- 8 files changed, 111 insertions(+), 159 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 144cbaf5cd21..aa9bfe9529ce 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -1079,17 +1079,17 @@ def jvp(*xs): return [*primal_errs, *out_primals, *tangent_errs, *out_tangents] return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info) -def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, - fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk, num_consts, - bwd: lu.WrappedFun, out_trees, - symbolic_zeros: bool): +def custom_vjp_call_rule(in_err, enabled_errors, *in_vals, + call_jaxpr: core.ClosedJaxpr, + fwd_jaxpr_thunk, num_consts, + bwd: lu.WrappedFun, out_trees, + symbolic_zeros: bool): err_vals, err_tree = jtu.tree_flatten(in_err) num_errs = err_tree.num_leaves checkified_fun = lu.wrap_init( - functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, - fun_jaxpr.consts, enabled_errors, err_tree), - debug_info=fun_jaxpr.jaxpr.debug_info) + functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr, + call_jaxpr.consts, enabled_errors, err_tree), + debug_info=call_jaxpr.jaxpr.debug_info) checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk( checkified_fun) @@ -1097,13 +1097,13 @@ def checkified_fwd(*args): # TODO(lenamartens, sharadmv): why not checkify here?
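    # Note: the positional args interleave primal values with their
    # symbolic-zero flags (custom_vjp's fwd calling convention), hence the
    # even/odd split below.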
xs, zeros = args[::2], args[1::2] xs, zeros = xs[num_errs:], zeros[num_errs:] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*zeros) xs_without_consts = xs[num_consts:] return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr checkified_fwd_wrapped = lu.wrap_init(checkified_fwd, - debug_info=fun_jaxpr.jaxpr.debug_info) + debug_info=fwd_jaxpr_thunk.debug_info) bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)), debug_info=bwd.debug_info) checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped) @@ -1118,7 +1118,7 @@ def checkified_fwd(*args): else: out_err, out_vals = in_err, all_outs return out_err, out_vals -error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule +error_checks[custom_derivatives.custom_vjp_call_p] = custom_vjp_call_rule def check_discharge_rule(error, enabled_errors, *args, err_tree, debug): diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index dc8fc90e3d1f..dcd893f44123 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -425,16 +425,14 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun, return call_jaxpr.out_avals, call_jaxpr.effects core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck -def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_fun, - num_consts, symbolic_zeros): - del jvp_jaxpr_fun, num_consts, symbolic_zeros +def _custom_jvp_vjp_call_lowering(ctx, *args, call_jaxpr, **_): consts = mlir._ir_consts(call_jaxpr.consts) out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr, ctx.name_stack, ctx.tokens_in, consts, *args, dim_var_values=ctx.dim_var_values) ctx.set_tokens_out(tokens) return out -mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation) +mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering) # If a (multi)linear function is defined with a custom jvp, then # custom_jvp_call_ can appear in jaxprs to be transposed. 
Since it's already @@ -936,8 +934,8 @@ def _temporary_dtype_exception(a, a_) -> bool: def _temporary_shape_exception(a, a_) -> bool: return config.custom_vjp_disable_shape_check.value -class CustomVJPCallPrimitive(core.CallPrimitive): - initial_style: core.Primitive +class CustomVJPCallPrimitive(core.Primitive): + multiple_results = True def bind(self, *args, **params): return self._true_bind(*args, **params) @@ -946,107 +944,70 @@ def bind_with_trace(self, trace, args, params): fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) -custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') + def impl(self, fun, fwd, bwd, *args): + raise NotImplementedError + + def get_bind_params(self, params): + new_params = dict(params) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + num_consts: int = new_params.pop('num_consts') + fwd_jaxpr_thunk = new_params.pop('fwd_jaxpr_thunk') + fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + fwd = lift_fwd(num_consts, fwd_jaxpr_thunk) + const_avals, _ = split_list(call_jaxpr.in_avals, [num_consts]) + bwd = _handle_consts_in_bwd(new_params.pop('bwd'), const_avals) + return [fun, fwd, bwd], new_params + +def lift_fwd(num_consts: int, fwd_jaxpr_thunk: lu.WrappedFun) -> lu.WrappedFun: + def fwd(*args): + vals, zeros = args[::2], args[1::2] + assert len(vals) == len(zeros) + _, primals = split_list(vals, [num_consts]) + const_zeros, in_zeros = split_list(zeros, [num_consts]) + if any(const_zeros): + raise ad.CustomVJPException() + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*in_zeros) + return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *primals) + return lu.wrap_init(fwd, debug_info=fwd_jaxpr_thunk.debug_info) -def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): - return core.jaxpr_as_fun(fun_jaxpr)(*args) +@lu.transformation2 +def _handle_consts_in_bwd(f, const_avals, *args): + return [Zero(a) for a in const_avals] + list(f(*args)) -def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): - disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fun_jaxpr.effects) +custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') +mlir.register_lowering(custom_vjp_call_p, _custom_jvp_vjp_call_lowering) + +def _custom_vjp_call_typecheck(_, *in_avals, call_jaxpr, **kwargs): + del in_avals, kwargs + disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in( + call_jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `custom_vjp`: {disallowed_effects}') - return fun_jaxpr.out_avals, fun_jaxpr.effects - -custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr') -custom_vjp_call_jaxpr_p.multiple_results = True -custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl) -custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval) -CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p - -mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun( - _custom_vjp_call_jaxpr_impl, multiple_results=True)) - -def _custom_vjp_call_jaxpr_jvp( - primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: lu.WrappedFun, - out_trees: Callable[[], Sequence[PyTreeDef]], - symbolic_zeros: bool): - _, args = split_list(primals, [num_consts]) - consts_dot, args_dot = split_list(tangents, [num_consts]) - if any(type(t) 
is not Zero for t in consts_dot): - raise ad.CustomVJPException() - zeros = [type(t) is not Zero for t in args_dot] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers! - _, res_tree = out_trees() - res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - args_dot = map(ad.instantiate_zeros, args_dot) - tangents_out = ad.custom_lin_p.bind( - *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - return primals_out, tangents_out -ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp - -def _custom_vjp_call_jaxpr_vmap( - axis_data, args, in_dims, *, - fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: lu.WrappedFun, - out_trees: Callable, symbolic_zeros: bool): - args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 - else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] - _, args_batched = split_list(in_batched, [num_consts]) - batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_data, in_batched, False) - out_dims1 = [0 if b else not_mapped for b in out_batched] - out_dims2 = [] - - @pe._memoize - def batched_fwd_jaxpr_thunk(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers - batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_data, args_batched, False) - out_dims2.append([0 if b else not_mapped for b in out_batched]) - return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts - - fwd_args_batched = [0 if b else not_mapped for b in args_batched] - fwd_out_dims = lambda: out_dims2[0] - tag = core.TraceTag() - batched_bwd = batching.batch_custom_vjp_bwd( - bwd, tag, axis_data, fwd_out_dims, fwd_args_batched) - - batched_outs = custom_vjp_call_jaxpr_p.bind( - *args, fun_jaxpr=batched_fun_jaxpr, - fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd, - num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - out_dims = out_dims2[0] if out_dims2 else out_dims1 - return batched_outs, out_dims -batching.fancy_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap + return call_jaxpr.out_avals, call_jaxpr.effects +core.custom_typechecks[custom_vjp_call_p] = _custom_vjp_call_typecheck -def _custom_vjp_call_jaxpr_dce( +def _custom_vjp_call_dce( used_outs: Sequence[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: if not any(used_outs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None - fun_jaxpr: core.ClosedJaxpr = eqn.params["fun_jaxpr"] + call_jaxpr: core.ClosedJaxpr = eqn.params["call_jaxpr"] fwd_jaxpr_thunk = eqn.params["fwd_jaxpr_thunk"] bwd: lu.WrappedFun = eqn.params["bwd"] out_trees: Callable[[], Sequence[PyTreeDef]] = eqn.params["out_trees"] symbolic_zeros: bool = eqn.params["symbolic_zeros"] - dce_fun_jaxpr: core.ClosedJaxpr + dce_call_jaxpr: core.ClosedJaxpr used_ins: Sequence[bool] - dce_fun_jaxpr, used_ins = _cached_closed_call_dce_instantiate( - fun_jaxpr, tuple(used_outs)) + dce_call_jaxpr, used_ins = _cached_closed_call_dce_instantiate( + call_jaxpr, tuple(used_outs)) assert all(used_ins) + @partial(lu.wrap_init, debug_info=fwd_jaxpr_thunk.debug_info) @pe._memoize def 
dce_fwd_jaxpr_thunk(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) + fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk.call_wrapped(*zeros)) _, res_tree = out_trees() num_res = res_tree.num_leaves dce_fwd_jaxpr, _ = _cached_closed_call_dce_instantiate( @@ -1058,7 +1019,7 @@ def dce_bwd(*args): res, cts = split_list(args, [res_tree.num_leaves]) cts_ = iter(cts) all_cts = [] - for used, aval in zip(used_outs, fun_jaxpr.out_avals): + for used, aval in zip(used_outs, call_jaxpr.out_avals): if used: all_cts.append(next(cts_)) else: @@ -1075,17 +1036,15 @@ def dce_bwd(*args): outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] new_params = dict( eqn.params, - fun_jaxpr=dce_fun_jaxpr, + call_jaxpr=dce_call_jaxpr, fwd_jaxpr_thunk=dce_fwd_jaxpr_thunk, bwd=dce_bwd_wrapped, ) new_eqn = pe.new_jaxpr_eqn( - eqn.invars, outvars, eqn.primitive, new_params, dce_fun_jaxpr.effects, + eqn.invars, outvars, eqn.primitive, new_params, dce_call_jaxpr.effects, eqn.source_info, eqn.ctx) return list(used_ins), new_eqn -pe.dce_rules[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_dce - -xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) +pe.dce_rules[custom_vjp_call_p] = _custom_vjp_call_dce batching.primitive_batchers[ad.custom_lin_p] = ad.raise_custom_vjp_error_on_jvp mlir.register_lowering(ad.custom_lin_p, ad.raise_custom_vjp_error_on_jvp) @@ -1586,7 +1545,6 @@ def jvp(primals, tangents): # TODO(mattjj): remove these stubs, which exist to avoid breaking internal users custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr") - # The following is a helper for optimizing the behavior of custom_vjp when used # under remat. This is really only useful when the `fwd` function to custom_vjp # executes a black box kernel. Otherwise, DCE will perform this optimization diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 64226a789cde..5866b0c5f8eb 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -434,49 +434,45 @@ def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, symboli if all(t.is_known() for t in tracers): vals = [t.pval[1] for t in tracers] with core.set_current_trace(self.parent_trace): - return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - else: - # TODO(mattjj): remove non-ad users of partial eval, then drop this case. - # We stage out the whole thing, i.e. no nontrivial partial evaluation. - tracers = map(self.instantiate_const_abstracted, tracers) - # Because we instantiate all tracers, in_knowns is all False. 
- in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) - f = trace_to_subjaxpr_nounits(f, self, True, f.debug_info) - f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,)) - with core.set_current_trace(self.parent_trace): - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - out_knowns, out_avals, jaxpr, env = aux() - out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - res_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.to_jaxpr_tracer, env) - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_avals] - closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) - - @_memoize - def fwd_jaxpr_thunk(*zeros): - fwd_ = _interleave_fun(fwd, zeros) - fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True, fwd_.debug_info) - fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,)) - out_flat = fwd_.call_wrapped() - out_knowns, out_avals, jaxpr, env = aux() - _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) - return converted_jaxpr, (*res, *env) + return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) + + tracers = map(self.instantiate_const, tracers) + in_knowns = (False,) * len(tracers) + in_avals = tuple(t.aval for t in tracers) + f_ = trace_to_subjaxpr_nounits2(f, self.tag, f.debug_info, True) + f_, aux = partial_eval_wrapper_nounits(f_, in_knowns, in_avals) + params = dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros) + res = prim.bind_with_trace(self.parent_trace, (f_, fwd, bwd), params) + out_knowns, out_avals, jaxpr, env = aux() + assert not any(out_knowns) + res_tracers = map(self.instantiate_const, map(self.new_const, res)) + env_tracers = map(self.to_jaxpr_tracer, env) + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_avals] + closed_jaxpr = close_jaxpr(convert_constvars_jaxpr(jaxpr)) + + @partial(lu.wrap_init, debug_info=fwd.debug_info) + @_memoize + def fwd_jaxpr_thunk(*zeros): + fwd_ = _interleave_fun(fwd, zeros) + fwd_jaxpr, _, consts, () = trace_to_jaxpr_dynamic(fwd_, in_avals) + return fwd_jaxpr, consts name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) + params = dict( + call_jaxpr=closed_jaxpr, + fwd_jaxpr_thunk=fwd_jaxpr_thunk, + num_consts=len(res) + len(env), + bwd=bwd, + out_trees=out_trees, + symbolic_zeros=symbolic_zeros + ) eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers), - out_tracers, prim.initial_style, - dict(fun_jaxpr=closed_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_thunk, - num_consts=len(res) + len(env), - bwd=bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros), - jaxpr.effects, source) + out_tracers, prim, params, jaxpr.effects, source) for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) + return out_tracers def partition_pvals( pvals: list[PartialVal] @@ -2050,6 +2046,7 @@ def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + @partial(lu.wrap_init, debug_info=jvp.debug_info) @_memoize def jvp_jaxpr_thunk(*in_zeros): for store in jvp.stores: store and store.reset() @@ -2065,8 +2062,7 @@ def jvp_jaxpr_thunk(*in_zeros): outvars = map(self.makevar, out_tracers) eqn = 
new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, - jvp_jaxpr_fun=lu.wrap_init(jvp_jaxpr_thunk, - debug_info=jvp.debug_info), + jvp_jaxpr_fun=jvp_jaxpr_thunk, num_consts=len(consts), symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, @@ -2086,6 +2082,7 @@ def process_custom_vjp_call(self, prim: core.Primitive, fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + @partial(lu.wrap_init, debug_info=fwd.debug_info) @_memoize def fwd_jaxpr_from_zeros(*zeros): for store in fwd.stores: store and store.reset() @@ -2098,9 +2095,8 @@ def fwd_jaxpr_from_zeros(*zeros): invars = map(self.getvar, tracers) constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, - prim.initial_style, # pytype: disable=attribute-error - dict(fun_jaxpr=closed_fun_jaxpr, + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, + dict(call_jaxpr=closed_fun_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, num_consts=len(consts), bwd=bwd, out_trees=out_trees, diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index 3b82d3095f64..93bcf5348b24 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -238,15 +238,15 @@ def _pjit_cost_rule(ctx, *, jaxpr: jax_core.ClosedJaxpr, **_): ) register_cost_rule(pjit.pjit_p, _pjit_cost_rule) -def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): +def _custom_vjp_rule(ctx, *, call_jaxpr: jax_core.ClosedJaxpr, **_): del ctx - inner_cost = cost_estimate_jaxpr(fun_jaxpr) + inner_cost = cost_estimate_jaxpr(call_jaxpr) return CostEstimate( flops=inner_cost.flops, transcendentals=inner_cost.transcendentals, bytes_accessed=inner_cost.bytes_accessed, ) -register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) +register_cost_rule(custom_derivatives.custom_vjp_call_p, _custom_vjp_rule) def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2): inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr)) diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 3628ae4aaa6e..b768b687dfad 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -26,7 +26,6 @@ custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, custom_vjp as custom_vjp, custom_vjp_call_p as custom_vjp_call_p, - custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, custom_vjp_primal_tree_values as custom_vjp_primal_tree_values, CustomVJPPrimal as CustomVJPPrimal, linear_call as linear_call, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 786e021e2ff0..536bf1f201f0 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3461,14 +3461,14 @@ def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr, tf_impl[custom_derivatives.custom_jvp_call_p] = _custom_jvp_call -def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr, - **_) -> Sequence[TfVal]: +def _custom_vjp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr, + **_) -> Sequence[TfVal]: # TODO(necula): ensure that there is no AD transformation in scope - return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_vjp", + return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_vjp", fresh_constant_cache=False) -tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr +tf_impl[custom_derivatives.custom_vjp_call_p] = 
_custom_vjp_call def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index 60d8cd24a949..30350dace637 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -26,7 +26,6 @@ custom_jvp_call_p as custom_jvp_call_p, custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, custom_vjp_call_p as custom_vjp_call_p, - custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, ) from jax._src.dispatch import device_put_p as device_put_p diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 72c14634a9c8..73dc2fbefcaa 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -3065,7 +3065,7 @@ def check_jaxpr(jaxpr, used_outs, includes, excludes): if not dce_jaxpr.eqns: assert not includes return - call_jaxpr = dce_jaxpr.eqns[0].params["fun_jaxpr"] + call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] for prim in includes: assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) for prim in excludes: From 91de2e39c161e3f26e8d58ab7071a189903563f8 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 13 May 2025 02:45:38 -0700 Subject: [PATCH 1139/1769] Update tridiagonal solve kernels on GPU to properly use the FFI. This fixes https://github.com/jax-ml/jax/issues/28544 by using the batched algorithms directly when possible. It also adds complex dtype and batch partitioning support to tridiagonal solves on GPU. PiperOrigin-RevId: 758129745 --- jax/_src/lax/linalg.py | 52 ++++++----- jaxlib/cuda/BUILD | 2 + jaxlib/gpu/sparse.cc | 2 + jaxlib/gpu/sparse_kernels.cc | 166 ++++++++++++++++++++++++++++++++++ jaxlib/gpu/sparse_kernels.h | 1 + jaxlib/gpu/vendor.h | 39 +++++++- jaxlib/gpu_sparse.py | 9 ++ jaxlib/rocm/BUILD | 2 + jaxlib/xla_client.py | 2 +- tests/linalg_sharding_test.py | 22 +++-- tests/linalg_test.py | 10 +- 11 files changed, 270 insertions(+), 37 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index dd86c22432d8..857b115b06d8 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -48,7 +48,7 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version +from jax._src.lib import version as jaxlib_version, jaxlib_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2530,29 +2530,33 @@ def _tridiagonal_solve_shape_rule(dl_shape, d_shape, du_shape, b_shape, **_): return b_shape def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, target_name_prefix): - _, _, _, b_aval = ctx.avals_in - *batch_dims, m, n = b_aval.shape - batch_size = math.prod(batch_dims) - - mod = gpu_sparse._cusparse if target_name_prefix == "cu" else gpu_sparse._hipsparse - assert mod is not None - opaque = mod.build_gtsv2_descriptor(batch_size, m, n, m) - if b_aval.dtype == np.float32: - buffer_size = mod.gtsv2_f32_buffer_size(m, n, m) - target_name = "sparse_gtsv2_f32_ffi" - elif b_aval.dtype == np.float64: - buffer_size = mod.gtsv2_f64_buffer_size(m, n, m) - target_name = "sparse_gtsv2_f64_ffi" - else: - raise NotImplementedError( - "tridiagonal_solve is only implemented for float32 and float64 on GPU.") - - buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) - sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) - rule = _linalg_ffi_lowering( - f"{target_name_prefix}{target_name}", operand_output_aliases={3: 0}, - batch_partitionable=False) - 
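    # The pre-FFI kernel also returns the workspace buffer appended to
+    # avals_out above as a second output; [:1] keeps only the solution.
+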
return rule(sub_ctx, dl, d, du, b, opaque=opaque)[:1] + if jaxlib_extension_version < 340: + _, _, _, b_aval = ctx.avals_in + *batch_dims, m, n = b_aval.shape + batch_size = math.prod(batch_dims) + mod = gpu_sparse._cusparse if target_name_prefix == "cu" else gpu_sparse._hipsparse + assert mod is not None + opaque = mod.build_gtsv2_descriptor(batch_size, m, n, m) + if b_aval.dtype == np.float32: + buffer_size = mod.gtsv2_f32_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f32_ffi" + elif b_aval.dtype == np.float64: + buffer_size = mod.gtsv2_f64_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f64_ffi" + else: + raise NotImplementedError( + "tridiagonal_solve is only implemented for float32 and float64 on GPU.") + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = _linalg_ffi_lowering( + f"{target_name_prefix}{target_name}", operand_output_aliases={3: 0}, + batch_partitionable=False) + return rule(sub_ctx, dl, d, du, b, opaque=opaque)[:1] + + target_name = f"{target_name_prefix}sparse_gtsv2_ffi" + rule = _linalg_ffi_lowering(target_name, operand_output_aliases={3: 0}) + return rule(ctx, dl, d, du, b) def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs): del kwargs # unused diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 2cc1476b637e..eabb3157ecca 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -259,11 +259,13 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", + "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 592c0f454a55..0190ba776de5 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -614,6 +614,8 @@ nb::dict Registrations() { EncapsulateFfiHandler(gtsv2_f32_ffi); dict[JAX_GPU_PREFIX "sparse_gtsv2_f64_ffi"] = EncapsulateFfiHandler(gtsv2_f64_ffi); + dict[JAX_GPU_PREFIX "sparse_gtsv2_ffi"] = EncapsulateFfiHandler(kGtsv2); + // TODO(tomhennigan): Add support for gtsv2 complex 32/64. return dict; } diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index a9c08317e066..363321e3ca8b 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -16,20 +16,29 @@ limitations under the License. #include "jaxlib/gpu/sparse_kernels.h" #include +#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/ffi_helpers.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" +#define JAX_FFI_RETURN_IF_GPU_ERROR(...) 
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) + +namespace ffi = ::xla::ffi; + namespace jax { template <> @@ -641,5 +650,162 @@ void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque, } } +template <typename T, typename BufferSizeF, typename KernelF> +ffi::Error Gtsv2Impl(BufferSizeF getBufferSize, KernelF kernel, int64_t batch, + int64_t rows, int64_t cols, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer dl, + ffi::AnyBuffer d, ffi::AnyBuffer du, ffi::AnyBuffer b, + ffi::Result<ffi::AnyBuffer> out) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols)); + + FFI_ASSIGN_OR_RETURN(auto handle, SparseHandlePool::Borrow(stream)); + size_t buffer_size_in_bytes; + JAX_FFI_RETURN_IF_GPU_ERROR(getBufferSize(handle.get(), m, n, nullptr, + nullptr, nullptr, nullptr, m, + &buffer_size_in_bytes)); + auto maybe_workspace = scratch.Allocate(buffer_size_in_bytes); + if (!maybe_workspace.has_value()) { + return ffi::Error::Internal("Unable to allocate workspace for gtsv2"); + } + void* workspace = maybe_workspace.value(); + + auto dl_data = static_cast<T*>(dl.untyped_data()); + auto d_data = static_cast<T*>(d.untyped_data()); + auto du_data = static_cast<T*>(du.untyped_data()); + auto b_data = static_cast<T*>(b.untyped_data()); + auto out_data = static_cast<T*>(out->untyped_data()); + if (b_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, b_data, b.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + for (int64_t i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(kernel(handle.get(), m, n, dl_data, d_data, + du_data, out_data, m, workspace)); + dl_data += m; + d_data += m; + du_data += m; + out_data += m * n; + } + return ffi::Error::Success(); +} + +template <typename T, typename BufferSizeF, typename KernelF> +ffi::Error Gtsv2BatchedImpl(BufferSizeF getBufferSize, KernelF kernel, + int64_t batch, int64_t rows, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer dl, + ffi::AnyBuffer d, ffi::AnyBuffer du, + ffi::AnyBuffer b, ffi::Result<ffi::AnyBuffer> out) { + FFI_ASSIGN_OR_RETURN(auto batch_count, MaybeCastNoOverflow<int>(batch)); + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows)); + + FFI_ASSIGN_OR_RETURN(auto handle, SparseHandlePool::Borrow(stream)); + size_t buffer_size_in_bytes; + JAX_FFI_RETURN_IF_GPU_ERROR(getBufferSize(handle.get(), m, nullptr, nullptr, + nullptr, nullptr, batch_count, m, + &buffer_size_in_bytes)); + auto maybe_workspace = scratch.Allocate(buffer_size_in_bytes); + if (!maybe_workspace.has_value()) { + return ffi::Error::Internal("Unable to allocate workspace for gtsv2"); + } + void* workspace = maybe_workspace.value(); + + auto dl_data = static_cast<T*>(dl.untyped_data()); + auto d_data = static_cast<T*>(d.untyped_data()); + auto du_data = static_cast<T*>(du.untyped_data()); + auto b_data = static_cast<T*>(b.untyped_data()); + auto out_data = static_cast<T*>(out->untyped_data()); + if (b_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, b_data, b.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + JAX_FFI_RETURN_IF_GPU_ERROR(kernel(handle.get(), m, dl_data, d_data, du_data, + out_data, batch_count, m, workspace)); + return ffi::Error::Success(); +} + +ffi::Error Gtsv2(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer dl, ffi::AnyBuffer d, ffi::AnyBuffer du, + ffi::AnyBuffer b, ffi::Result<ffi::AnyBuffer> out) { + auto dataType = dl.element_type(); + if (dataType != d.element_type() || dataType != du.element_type() || + dataType != b.element_type() || dataType != out->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to
gtsv2 must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(b.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gtsv2")); + FFI_RETURN_IF_ERROR( + CheckShape(dl.dimensions(), {batch, rows}, "dl", "gtsv2")); + FFI_RETURN_IF_ERROR(CheckShape(d.dimensions(), {batch, rows}, "d", "gtsv2")); + FFI_RETURN_IF_ERROR( + CheckShape(du.dimensions(), {batch, rows}, "du", "gtsv2")); + if (batch > 1 && cols == 1) { + switch (dataType) { + case ffi::F32: + return Gtsv2BatchedImpl<float>( + gpusparseSgtsv2StridedBatch_bufferSizeExt, + gpusparseSgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::F64: + return Gtsv2BatchedImpl<double>( + gpusparseDgtsv2StridedBatch_bufferSizeExt, + gpusparseDgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::C64: + return Gtsv2BatchedImpl<gpuComplex>( + gpusparseCgtsv2StridedBatch_bufferSizeExt, + gpusparseCgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::C128: + return Gtsv2BatchedImpl<gpuDoubleComplex>( + gpusparseZgtsv2StridedBatch_bufferSizeExt, + gpusparseZgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + default: + break; + } + + } else { + switch (dataType) { + case ffi::F32: + return Gtsv2Impl<float>(gpusparseSgtsv2_bufferSizeExt, gpusparseSgtsv2, + batch, rows, cols, stream, scratch, dl, d, du, + b, out); + case ffi::F64: + return Gtsv2Impl<double>(gpusparseDgtsv2_bufferSizeExt, gpusparseDgtsv2, + batch, rows, cols, stream, scratch, dl, d, du, + b, out); + case ffi::C64: + return Gtsv2Impl<gpuComplex>(gpusparseCgtsv2_bufferSizeExt, + gpusparseCgtsv2, batch, rows, cols, stream, + scratch, dl, d, du, b, out); + case ffi::C128: + return Gtsv2Impl<gpuDoubleComplex>(gpusparseZgtsv2_bufferSizeExt, + gpusparseZgtsv2, batch, rows, cols, + stream, scratch, dl, d, du, b, out); + default: + break; + } + } + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gtsv2", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kGtsv2, Gtsv2, + ffi::Ffi::Bind() + .Ctx<ffi::PlatformStream<gpuStream_t>>() + .Ctx<ffi::ScratchAllocator>() + .Arg<ffi::AnyBuffer>() // dl + .Arg<ffi::AnyBuffer>() // d + .Arg<ffi::AnyBuffer>() // du + .Arg<ffi::AnyBuffer>() // b + .Ret<ffi::AnyBuffer>() // out +); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index d735c320307c..3b365872f591 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -157,6 +157,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatvecFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatmatFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f32_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f64_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGtsv2); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 5deb8d4c650a..b96552f81bd1 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -152,7 +152,8 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUDNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS #define GPUDNN_WGRAD_MODE_ADD CUDNN_WGRAD_MODE_ADD #define GPUDNN_RNN_ALGO_STANDARD CUDNN_RNN_ALGO_STANDARD -#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED +#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED \ + CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED #define GPUDNN_RNN_PADDED_IO_ENABLED CUDNN_RNN_PADDED_IO_ENABLED #define GPUDNN_DEFAULT_MATH CUDNN_DEFAULT_MATH #define GPUDNN_FMA_MATH CUDNN_FMA_MATH @@ -289,10 +290,28 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusparseSpMM_bufferSize
cusparseSpMM_bufferSize #define gpusparseSpMV cusparseSpMV #define gpusparseSpMV_bufferSize cusparseSpMV_bufferSize + #define gpusparseSgtsv2 cusparseSgtsv2 #define gpusparseDgtsv2 cusparseDgtsv2 +#define gpusparseCgtsv2 cusparseCgtsv2 +#define gpusparseZgtsv2 cusparseZgtsv2 #define gpusparseSgtsv2_bufferSizeExt cusparseSgtsv2_bufferSizeExt #define gpusparseDgtsv2_bufferSizeExt cusparseDgtsv2_bufferSizeExt +#define gpusparseCgtsv2_bufferSizeExt cusparseCgtsv2_bufferSizeExt +#define gpusparseZgtsv2_bufferSizeExt cusparseZgtsv2_bufferSizeExt + +#define gpusparseSgtsv2StridedBatch_bufferSizeExt \ + cusparseSgtsv2StridedBatch_bufferSizeExt +#define gpusparseDgtsv2StridedBatch_bufferSizeExt \ + cusparseDgtsv2StridedBatch_bufferSizeExt +#define gpusparseCgtsv2StridedBatch_bufferSizeExt \ + cusparseCgtsv2StridedBatch_bufferSizeExt +#define gpusparseZgtsv2StridedBatch_bufferSizeExt \ + cusparseZgtsv2StridedBatch_bufferSizeExt +#define gpusparseSgtsv2StridedBatch cusparseSgtsv2StridedBatch +#define gpusparseDgtsv2StridedBatch cusparseDgtsv2StridedBatch +#define gpusparseCgtsv2StridedBatch cusparseCgtsv2StridedBatch +#define gpusparseZgtsv2StridedBatch cusparseZgtsv2StridedBatch #define GPUSPARSE_INDEX_16U CUSPARSE_INDEX_16U #define GPUSPARSE_INDEX_32I CUSPARSE_INDEX_32I @@ -636,10 +655,28 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusparseSpMM_bufferSize hipsparseSpMM_bufferSize #define gpusparseSpMV hipsparseSpMV #define gpusparseSpMV_bufferSize hipsparseSpMV_bufferSize + #define gpusparseSgtsv2 hipsparseSgtsv2 #define gpusparseDgtsv2 hipsparseDgtsv2 +#define gpusparseCgtsv2 hipsparseCgtsv2 +#define gpusparseZgtsv2 hipsparseZgtsv2 #define gpusparseSgtsv2_bufferSizeExt hipsparseSgtsv2_bufferSizeExt #define gpusparseDgtsv2_bufferSizeExt hipsparseDgtsv2_bufferSizeExt +#define gpusparseCgtsv2_bufferSizeExt hipsparseCgtsv2_bufferSizeExt +#define gpusparseZgtsv2_bufferSizeExt hipsparseZgtsv2_bufferSizeExt + +#define gpusparseSgtsv2StridedBatch_bufferSizeExt \ + hipsparseSgtsv2StridedBatch_bufferSizeExt +#define gpusparseDgtsv2StridedBatch_bufferSizeExt \ + hipsparseDgtsv2StridedBatch_bufferSizeExt +#define gpusparseCgtsv2StridedBatch_bufferSizeExt \ + hipsparseCgtsv2StridedBatch_bufferSizeExt +#define gpusparseZgtsv2StridedBatch_bufferSizeExt \ + hipsparseZgtsv2StridedBatch_bufferSizeExt +#define gpusparseSgtsv2StridedBatch hipsparseSgtsv2StridedBatch +#define gpusparseDgtsv2StridedBatch hipsparseDgtsv2StridedBatch +#define gpusparseCgtsv2StridedBatch hipsparseCgtsv2StridedBatch +#define gpusparseZgtsv2StridedBatch hipsparseZgtsv2StridedBatch #define GPUSPARSE_INDEX_16U HIPSPARSE_INDEX_16U #define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index af03eb6e6a8a..bf1dc6f64ec1 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -35,3 +35,12 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: for name, value in module.registrations().items() ) return registrations # pytype: disable=bad-return-type + +def batch_partitionable_targets() -> list[str]: + targets: list[str] = [] + for module in [_cusparse, _hipsparse]: + if module: + targets.extend( + name for name in module.registrations() if name.endswith("gtsv2_ffi") + ) + return targets diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 75406174dd93..d0468d72d1b3 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -244,11 +244,13 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", + "//jaxlib:ffi_helpers", 
"//jaxlib:kernel_helpers", "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 449dfa653286..6aaae11c139d 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 339 +_version = 340 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/linalg_sharding_test.py b/tests/linalg_sharding_test.py index 2f190cdc5ad6..e68e94e16494 100644 --- a/tests/linalg_sharding_test.py +++ b/tests/linalg_sharding_test.py @@ -22,6 +22,7 @@ from jax import lax from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import jaxlib_extension_version from jax.sharding import PartitionSpec as P config.parse_flags_with_absl() @@ -31,13 +32,8 @@ complex_types = jtu.dtypes.complex +# These functions are only supported on CPU. CPU_ONLY_FUN_AND_SHAPES = [ - # The GPU kernel for this function still uses an opaque descriptor to - # encode the input shapes so it is not partitionable. - # TODO(danfm): Update the kernel and enable this test on GPU. - (lax.linalg.tridiagonal_solve, ((6,), (6,), (6,), (6, 4))), - - # These functions are only supported on CPU. (lax.linalg.hessenberg, ((6, 6),)), (lax.linalg.schur, ((6, 6),)), ] @@ -51,6 +47,7 @@ (lax.linalg.svd, ((10, 6),)), (lax.linalg.triangular_solve, ((6, 6), (4, 6))), (lax.linalg.tridiagonal, ((6, 6),)), + (lax.linalg.tridiagonal_solve, ((6,), (6,), (6,), (6, 4))), ] ALL_FUN_AND_SHAPES = CPU_ONLY_FUN_AND_SHAPES + CPU_AND_GPU_FUN_AND_SHAPES @@ -73,6 +70,11 @@ def get_fun_and_shapes(self, fun_and_shapes, grad=False): self.skipTest( f"Partitioning {fun_and_shapes[0].__name__} only supported on GPU " "when shardy is enabled.") + if (fun_and_shapes[0] == lax.linalg.tridiagonal_solve and + jaxlib_extension_version < 340): + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} on GPU, requires a " + "more recent jaxlib version.") if not grad: return fun_and_shapes @@ -178,7 +180,9 @@ def jvp_fun(primals, tangents): (primals_sharded, tangents), ]: _, actual = jvp_fun_jit(*args) - self.assertAllClose(actual, expected, atol={np.float64: 1e-12}) + self.assertAllClose(actual, expected, rtol={ + np.float32: 1e-4, np.float64: 1e-11, np.complex64: 1e-4, + np.complex128: 1e-11}) hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile() self.assertNotIn("all-", hlo.as_text()) @@ -199,7 +203,9 @@ def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype): vjp_fun_jit = jax.jit(vjp_fun) expected = vjp_fun(tangents) actual = vjp_fun_jit(tangents_sharded) - self.assertAllClose(actual, expected) + self.assertAllClose(actual, expected, rtol={ + np.float32: 1e-4, np.float64: 1e-11, np.complex64: 1e-4, + np.complex128: 1e-11}) hlo = vjp_fun_jit.lower(tangents_sharded).compile() self.assertNotIn("all-", hlo.as_text()) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 033ca989c8e7..a9f81ec04560 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -33,6 +33,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge 
+from jax._src.lib import jaxlib_extension_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -2202,7 +2203,7 @@ def testSelect(self, dtype): @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)], dtype=float_types + complex_types) def test_tridiagonal_solve(self, shape, dtype): - if dtype not in float_types and jtu.test_device_matches(["gpu"]): + if dtype not in float_types and jtu.test_device_matches(["gpu"]) and jaxlib_extension_version < 340: self.skipTest("Data type not supported on GPU") rng = self.rng() d = 1.0 + jtu.rand_positive(rng)(shape, dtype) @@ -2217,7 +2218,10 @@ def build_tri(dl, d, du): build_tri = jax.vmap(build_tri) a = build_tri(dl, d, du) - self.assertAllClose(a @ x, b, atol=5e-5, rtol=1e-4) + with jax.default_matmul_precision("float32"): + self.assertAllClose(a @ x, b, atol={ + np.float32: 1e-3, np.float64: 1e-10, np.complex64: 1e-3, + np.complex128: 1e-10}) def test_tridiagonal_solve_endpoints(self): # tridagonal_solve shouldn't depend on the endpoints being explicitly zero. @@ -2238,7 +2242,7 @@ def test_tridiagonal_solve_endpoints(self): @jtu.sample_product(shape=[(3,), (3, 4)], dtype=float_types + complex_types) def test_tridiagonal_solve_grad(self, shape, dtype): - if dtype not in float_types and jtu.test_device_matches(["gpu"]): + if dtype not in float_types and jtu.test_device_matches(["gpu"]) and jaxlib_extension_version < 340: self.skipTest("Data type not supported on GPU") rng = self.rng() d = 1.0 + jtu.rand_positive(rng)(shape, dtype) From 3667353dce35e9bf1789b9948e796d41c82fef05 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 13 May 2025 03:23:54 -0700 Subject: [PATCH 1140/1769] [pallas:mosaic_gpu] Removed debug prints from `emit_pipeline_warp_specialized` PiperOrigin-RevId: 758140625 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index db5fb4fb316a..41ddf993df3a 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -663,13 +663,11 @@ def pipeline_callback(user_init_carry): if last_indices is not None: raise ValueError( "Cannot call pipeline more than once in `compute_context`") - print("[DEBUG] user_init_carry: ", user_init_carry) init_loop_carry = (init_indices, last_store_slices, user_init_carry) last_indices, _, final_body_carry = lax.fori_loop(0, num_steps, compute_loop_body, init_loop_carry) - print("[DEBUG] final_body_carry: ", final_body_carry) return final_body_carry compute_context(pipeline_callback) if last_indices is None: From 0d6ad8aee0fe786109e2e37b2b0e92eed712c950 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 13 May 2025 03:38:09 -0700 Subject: [PATCH 1141/1769] [pallas:mosaic_gpu] Nuked a debug test added by accident PiperOrigin-RevId: 758144215 --- tests/pallas/mosaic_gpu_test.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index a5719b01f4ad..873854266782 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3321,20 +3321,6 @@ def do_wgmma(acc_ref): np.testing.assert_allclose(kernel(x, x), x @ x) - def test_debug_bug(self): - dtype = jnp.float16 - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - ) - def kernel(o_ref): - kv_step = jnp.asarray(0) - @pl.when(kv_step < -2) - def dp(): - pl.debug_print("foo") - o_ref[...] 
= jnp.zeros_like(o_ref)
-    kernel()
-
   # TODO(apaszke): Clusters and multicast

From f73726e958ec3701fa635d69da11e4e29dd83d17 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 13 May 2025 05:26:09 -0700
Subject: [PATCH 1142/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/c9b736e8c217529795badfe9cf7730b5d0f38242.

PiperOrigin-RevId: 758173411
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 94fbee2d8b83..a4a7c8cb7dcd 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 #    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 #    and update XLA_SHA256 with the result.

-XLA_COMMIT = "6ad6ae3dafa9868708e54de10e3aeafb081a71f2"
-XLA_SHA256 = "405c2787b2fa2a467f4b2179cbfb2fd25f282f55ff43b0571ac20e4e56b8c26c"
+XLA_COMMIT = "c9b736e8c217529795badfe9cf7730b5d0f38242"
+XLA_SHA256 = "d5fb6fb909a81838e4fa5f5dc755365a10132bc4acec6ee18df815d3cda6df2f"

 def repo():
     tf_http_archive(

From 66476a0b4c20753d79b0960bc476167589691e67 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Tue, 13 May 2025 07:11:52 -0700
Subject: [PATCH 1143/1769] [Mosaic GPU] Add layout inference and lowering for
 `scf.WhileOp` and enable tests.

PiperOrigin-RevId: 758205399
---
 .../mosaic/gpu/dialect_lowering.py        | 153 ++++++++++++++----
 .../mosaic/gpu/layout_inference.py        |  78 +++++++--
 tests/mosaic/gpu_layout_inference_test.py |  53 ++++++
 tests/pallas/mosaic_gpu_test.py           |  21 +--
 4 files changed, 243 insertions(+), 62 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index c1506bde32ea..320ae32607e9 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -1123,6 +1123,54 @@ def _unflatten_ir_values(
   return result


+def _move_scf_block_to_block_with_flattened_arguments(
+    ctx: LoweringContext,
+    old_block: ir.Block,
+    new_block: ir.Block,
+    last_op_type: type[ir.OpView],
+    args_template: Sequence[_VectorTemplate | None],
+    *new_leading_args: Sequence[ir.Value],
+) -> Sequence[_VectorTemplate | None]:
+  """Moves the operations from `old_block` to `new_block`.
+
+  The input arguments to the block, if any, are flattened using the provided
+  `args_template`, except for any new_leading_args which are simply prepended
+  to the flattened arguments and must not be part of the template.
+
+  The last operation of the old block must be of type `last_op_type` which
+  is expected to be either a `scf.YieldOp` or a `scf.ConditionOp`. This
+  operation is recreated with flattened output arguments.
+ """ + out_template = None + with ir.InsertionPoint(new_block): + new_carry = _unflatten_ir_values(new_block.arguments[len(new_leading_args):], args_template) + new_args = new_leading_args + tuple(new_carry) + for old_arg, new_arg in zip(old_block.arguments, new_args, strict=True): + old_arg.replace_all_uses_with(new_arg) + for op in [*old_block]: + if not isinstance(op, last_op_type): + mgpu.private_operation_remove_from_parent(op) + mgpu.private_block_append_owned_operation(new_block, op) + ctx.lower_op(op) + else: + assert out_template is None + layouts = ( + inference_utils.in_layouts(op) + if inference_utils.has_in_layouts_set(op) + else [] + ) + if isinstance(op, scf.YieldOp): + flat_operands, out_template = _flatten_ir_values(op.operands, layouts) + scf.yield_(flat_operands) + elif isinstance(op, scf.ConditionOp): + flat_carry, out_template = _flatten_ir_values(op.args, layouts) + scf.condition(op.condition, flat_carry) + else: + raise NotImplementedError(f"Unsupported op type: {op}") + op.erase() + assert out_template is not None + return out_template + @_register_lowering(scf.ForOp) def _for_op_lowering_rule( ctx: LoweringContext, for_op: scf.ForOp @@ -1145,33 +1193,78 @@ def _for_op_lowering_rule( for_op.step, flat_init_args, ) - with ir.InsertionPoint(new_for_op.body): - recreated_carry = _unflatten_ir_values( - new_for_op.body.arguments[1:], args_template - ) - ops_to_lower = [] - for op in [*for_op.body]: - if op == yield_op: - continue - mgpu.private_operation_remove_from_parent(op) - mgpu.private_block_append_owned_operation(new_for_op.body, op) - ops_to_lower.append(op) - new_args = (new_for_op.induction_variable, *recreated_carry) - for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True): - old_carry.replace_all_uses_with(new_carry) - - for op in ops_to_lower: - with ir.InsertionPoint(op): - ctx.lower_op(op) - with ir.InsertionPoint(new_for_op.body): - flat_operands, _ = _flatten_ir_values(yield_op.operands, in_layouts) - yield_op.erase() - scf.yield_(flat_operands) + _move_scf_block_to_block_with_flattened_arguments( + ctx, + for_op.body, + new_for_op.body, + scf.YieldOp, + args_template, + new_for_op.induction_variable, + ) return _unflatten_ir_values(new_for_op.results, args_template) +@_register_lowering(scf.WhileOp) +def _while_op_lowering_rule( + ctx: LoweringContext, while_op: scf.WhileOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(while_op): + return _traverse_op_lowering_rule(ctx, while_op) + + before_block = while_op.before.blocks[0] + after_block = while_op.after.blocks[0] + condition_op = before_block.operations[len(before_block.operations) - 1] + yield_op = after_block.operations[len(after_block.operations) - 1] + + in_layouts = inference_utils.in_layouts(while_op) + out_layouts = inference_utils.out_layouts(while_op) + + if in_layouts: + yield_layouts = inference_utils.in_layouts(yield_op) + if in_layouts != yield_layouts: + raise ValueError( + f"Input layouts {in_layouts} do not match yield layouts" + f" {yield_layouts}" + ) + + if out_layouts: + condition_layouts = inference_utils.in_layouts(condition_op) + if out_layouts != condition_layouts: + raise ValueError( + f"Output layouts {out_layouts} do not match condition layouts" + f" {condition_layouts}" + ) + + flat_inits, inits_template = _flatten_ir_values(while_op.inits, in_layouts) + result_types = _infer_flat_result_types(while_op, out_layouts) + new_while_op = scf.WhileOp(result_types, flat_inits) + + # Before block + init_types = [v.type for v in 
flat_inits] + new_before_block = new_while_op.before.blocks.append(*init_types) + results_template = _move_scf_block_to_block_with_flattened_arguments( + ctx, + before_block, + new_before_block, + scf.ConditionOp, + inits_template, + ) + + # After block + new_after_block = new_while_op.after.blocks.append(*result_types) + _move_scf_block_to_block_with_flattened_arguments( + ctx, + after_block, + new_after_block, + scf.YieldOp, + results_template, + ) + + return _unflatten_ir_values(new_while_op.results, results_template) + + def _infer_flat_result_types( op: ir.OpView, out_layouts: Sequence[ir.Attribute] ) -> Sequence[ir.Type]: @@ -1221,19 +1314,9 @@ def _index_switch_op_lowering_rule( ): [block] = region.blocks new_block = new_region.blocks.append() - with ir.InsertionPoint(new_block): - for op in [*block]: - if not isinstance(op, scf.YieldOp): - mgpu.private_operation_remove_from_parent(op) - mgpu.private_block_append_owned_operation(new_block, op) - ctx.lower_op(op) - continue - if inference_utils.in_layouts(op) != out_layouts: - raise ValueError("Layout mismatch") - flat_results, results_template = _flatten_ir_values( - op.operands, out_layouts - ) - scf.yield_(flat_results) + results_template = _move_scf_block_to_block_with_flattened_arguments( + ctx, block, new_block, scf.YieldOp, [] + ) return _unflatten_ir_values(new_switch_op.results, results_template) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index b39dc933ce9d..c010bf181bce 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -336,38 +336,61 @@ def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: return [], [layout] -@partial(_add_layout_inference_rule, scf.YieldOp) -def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: +def _layouts_from_values(values: Sequence[ir.Value]) -> list[ir.Attribute] | None: layouts = [] - for result in op.results_: - if not ir.VectorType.isinstance(result.type): + for value in values: + if not ir.VectorType.isinstance(value.type): continue - if (layout := inference_utils.value_layout(result)) is not None: + if (layout := inference_utils.value_layout(value)) is not None: if layouts_lib.is_splat_fragmented_layout(layout): return None layouts.append(layout) else: # Not all layouts could be inferred for vector ops. Return for now. 
return None + return layouts +@partial(_add_layout_inference_rule, scf.YieldOp) +def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: + layouts = _layouts_from_values(op.results_) + if layouts is None: + return None return (layouts, []) +@partial(_add_layout_inference_rule, scf.ConditionOp) +def _infer_condition_op_layout(op: scf.ConditionOp) -> OptionalLayouts: + layouts = _layouts_from_values(op.args) + if layouts is None: + return None + return (layouts, []) + + +def _last_op(region: ir.Region, expected_op_type: type[ir.OpView]): + [block] = region.blocks + last_op = block.operations[len(block.operations) - 1] + assert isinstance(last_op, expected_op_type) + return last_op + + +def _infer_from_op(op: ir.OpView) -> list[ir.Attribute] | None: + if not inference_utils.has_in_layouts_set(op): + return None + in_layouts = list(inference_utils.in_layouts(op)) + if any( + layouts_lib.is_splat_fragmented_layout(layout) + for layout in in_layouts + ): + return None + return in_layouts + + def _infer_from_yield_ops(op: ir.Operation) -> list[ir.Attribute] | None: candidates = [] for region in op.regions: - [block] = region.blocks - yield_op = block.operations[len(block.operations) - 1] - assert isinstance(yield_op, scf.YieldOp) - if not inference_utils.has_in_layouts_set(yield_op): - continue - yield_layouts = inference_utils.in_layouts(yield_op) - if any( - layouts_lib.is_splat_fragmented_layout(layout) - for layout in yield_layouts - ): - continue - candidates.append(yield_layouts) + yield_layouts = _infer_from_op(_last_op(region, scf.YieldOp)) + if yield_layouts is not None: + candidates.append(yield_layouts) if not candidates: return None return [_choose_representative_layout(set(c)) for c in zip(*candidates)] @@ -382,6 +405,27 @@ def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: return None +@partial(_add_layout_inference_rule, scf.WhileOp) +def _infer_while_op_layout(op: scf.WhileOp) -> OptionalLayouts: + # TODO(dasenov): we don't attempt to propagate from outside for the moment. + + # Note that the inputs or results do not necessarily contain vector types. If + # there is no vector type, the corresponding layouts (in_layouts or + # out_layouts) should be an empty list. 
+ + yield_op = _last_op(op.after, scf.YieldOp) + needs_in_layouts = inference_utils.should_have_layout(yield_op) + in_layouts = _infer_from_op(yield_op) if needs_in_layouts else [] + + condition_op = _last_op(op.before, scf.ConditionOp) + needs_out_layouts = inference_utils.should_have_layout(condition_op) + out_layouts = _infer_from_op(condition_op) if needs_out_layouts else [] + + if in_layouts is None or out_layouts is None: + return None + return in_layouts, out_layouts + + @partial(_add_layout_inference_rule, scf.IfOp) def _infer_if_op_layout(op: scf.IfOp) -> OptionalLayouts: if layouts := _infer_from_yield_ops(op): diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 315ae2659ab6..038766542f3b 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -429,6 +429,59 @@ def body(lower_bound, upper_bound, step, a, b, c): self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout]) self.assertSequenceEqual(for_op.attributes["out_layouts"], [wgmma_layout]) + @parameterized.parameters( + ((), None, (), None), + ((64, 32), mgpu.WGMMA_LAYOUT, (), None), + ((), None, (64, 32), mgpu.WGMMA_LAYOUT), + ((64,), mgpu.WGMMA_ROW_LAYOUT, (64, 32), mgpu.WGMMA_LAYOUT), + ) + def test_infer_while_op_layouts( + self, init_shape, init_layout, result_shape, result_layout + ): + if init_shape: + in_type = ir.VectorType.get(init_shape, ir.F32Type.get()) + else: + in_type = ir.F32Type.get() + + if result_shape: + out_type = ir.VectorType.get(result_shape, ir.F32Type.get()) + else: + out_type = ir.F32Type.get() + + while_op = condition_op = yield_op = None + + def body(condition, init, result): + nonlocal while_op, condition_op, yield_op + while_op = scf.WhileOp([out_type], [init]) + before_block = while_op.before.blocks.append(init.type) + with ir.InsertionPoint(before_block): + condition_op = scf.ConditionOp(condition, [result]) + + after_block = while_op.after.blocks.append(out_type) + with ir.InsertionPoint(after_block): + yield_op = scf.YieldOp([init]) + + with ir.InsertionPoint(self.module.body): + i1 = ir.IntegerType.get_signless(1) + func.FuncOp.from_py_func(i1, in_type, out_type)(body) + + [f] = self.module.body.operations + f_layouts = [] + if init_layout: + f_layouts.append(layouts.to_layout_attr(init_layout)) + if result_layout: + f_layouts.append(layouts.to_layout_attr(result_layout)) + if f_layouts: + f.attributes["in_layouts"] = ir.ArrayAttr.get(f_layouts) + + mgpu.infer_layout(self.module) + + if init_layout or result_layout: + init_layouts = [layouts.to_layout_attr(init_layout)] if init_layout else [] + result_layouts = [layouts.to_layout_attr(result_layout)] if result_layout else [] + self.assertSequenceEqual(while_op.attributes["in_layouts"], init_layouts) + self.assertSequenceEqual(while_op.attributes["out_layouts"], result_layouts) + def test_infer_layout_has_no_layout_for_non_vector_types(self): shape = (32, 4) elt_ty = ir.BF16Type.get() diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 873854266782..9d259097f90a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1218,8 +1218,6 @@ def body(idx, _): np.testing.assert_array_equal(kernel(x, y), x + y) def test_while_loop(self): - self.skip_if_wg_semantics() - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) @@ -1242,8 +1240,6 @@ def body(acc): ) def test_while_loop_layout_mismatch(self): - self.skip_if_wg_semantics() # 
while and conditional are not yet supported.
-
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
     )
@@ -1261,8 +1257,17 @@ def body(acc):

       _ = jax.lax.while_loop(cond, body, o_ref[...])

-    with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"):
-      kernel()
+    if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup:
+      with self.assertRaisesRegex(
+          NotImplementedError,
+          "Cannot convert from WGStridedFragLayout.* to TiledLayout",
+      ):
+        kernel()
+    else:
+      with self.assertRaisesRegex(
+          ValueError, "has layout .*, when it should be"
+      ):
+        kernel()

   def test_cond(self):
     @functools.partial(
@@ -1722,10 +1727,6 @@ class PallasCallSm90ATest(PallasSm90ATest):

   @parameterized.parameters(False, True)
   def test_fori_loop_accumulator(self, force_while):
-    if force_while:
-      # Layout inference and lowering for 'while' are not yet implemented for
-      # warpgroup semantics.
-      self.skip_if_wg_semantics()
     if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
       transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
     else:

From 123022cae08d83c4d53ac77481b5c2391f003794 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 13 May 2025 08:55:33 -0700
Subject: [PATCH 1144/1769] [Mosaic] Make `tpu.relayout` an explicit operation,
 merge in existing behavior, stop calling relayout() in apply.

This change should reduce complexity and make it easier to see what happened
in a graph.

Note - there are still cases where certain relayout() calls are not ops yet;
those will be migrated in the future. Specifically, see the note in the CL
around force_relayout.

Added helper methods to generate full-like vectors.

Followup for subsequent CLs: Simplify the Relayout rule, maybe break it up
into smaller sub-relayouts with nice names.
Followup for subsequent CLs: Unify transpose in here PiperOrigin-RevId: 758240646 --- .../tpu/transforms/apply_vector_layout.cc | 41 +++++-------------- .../tpu/transforms/relayout_insertion.cc | 30 +++++++++++++- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index d625e8bf4d6f..656be0e677b0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6811,6 +6811,7 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); + if (is_mask_pack) { std::vector vmsks_shape(src_tiles.dimensions().begin(), src_tiles.dimensions().end()); @@ -6855,6 +6856,7 @@ FailureOr> relayout(RewriteContext &ctx, auto assemble_with_mask_check = [&](xla::Array &tiles, bool use_implicit_shape = false) { + if (is_mask) { auto zeros_tile = builder.create( tiles.begin()->getLoc(), @@ -6985,34 +6987,18 @@ LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op, auto in_layout_array_attr = tpu_relayout_op->getAttrOfType("in_layout"); - if (!in_layout_array_attr || in_layout_array_attr.empty()) { - return tpu_relayout_op.emitOpError( - "missing or empty 'in_layout' attribute"); - } auto src_vla = dyn_cast(in_layout_array_attr[0]); - if (!src_vla) { - return tpu_relayout_op.emitOpError( - "'in_layout' attribute is not a VectorLayoutAttr"); - } VectorLayout src_layout = src_vla.getLayout().value(); auto out_layout_array_attr = tpu_relayout_op->getAttrOfType("out_layout"); - if (!out_layout_array_attr || out_layout_array_attr.empty()) { - return tpu_relayout_op.emitOpError( - "missing or empty 'out_layout' attribute"); - } auto dst_vla = dyn_cast(out_layout_array_attr[0]); - if (!dst_vla) { - return tpu_relayout_op.emitOpError( - "'out_layout' attribute is not a VectorLayoutAttr"); - } VectorLayout dst_layout = dst_vla.getLayout().value(); if (src_layout == dst_layout) { - tpu_relayout_op.replaceAllUsesWith(tpu_relayout_op.getInput()); - tpu_relayout_op.erase(); - return success(); + return op.emitError( + "Source and destination layouts are the same - did you forget to run " + "relayout-insertion-pass?"); } OpBuilder builder(&op); @@ -7079,9 +7065,6 @@ const llvm::StringMap &rules() { return *rules; } -// TODO(apaszke): Implement a debug mode that inserts additional assertions. -// For example, we should verify that ops that were supposed to generate -// replicated outputs satisfy that requirement. LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // When an operation does not have any operands, the layout_in tuple is empty. 
// If one of the operands is not of vector type, the corresponding entry in @@ -7117,14 +7100,11 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); - if (*lo == *li) { - continue; + if (*lo != *li) { + return op.emitError( + "Invariant violation: Input layout does not match output layout - " + "did you forget to run relayout-insertion?"); } - OpBuilder builder(&op); - FAILUREOR_ASSIGN_OR_RETURN( - Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, - /*dst=*/*li)); - op.setOperand(idx, new_v); } } @@ -7132,7 +7112,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // support for offsets outside of the first tile. When support is more broad, // any op without support should check it within their own rule. if (!isa(op)) { + vector::ExtractStridedSliceOp, vector::ShapeCastOp, tpu::RelayoutOp>( + op)) { for (const Layout &layout : layouts_in) { if (layout && layout->offsets()[1].has_value() && layout->offsets()[1].value() >= layout->tiling()[1]) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index 6ddf8bd5ce66..178b97876b49 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -119,7 +119,26 @@ FailureOr> relayout( dst_bitwidth_layout); return cast>(cmp_op.getResult()); } - return v; + // Fall through to generic relayout. + auto relayout_op = + builder.create(v.getLoc(), v.getType(), v); + setLayout(relayout_op, src, dst); + + return cast>(relayout_op.getResult()); +} + +LogicalResult insertRelayout(Operation &op, int hardware_generation, + std::array target_shape); + +LogicalResult insertRelayoutBlock(Block &block, int hardware_generation, + const std::array target_shape) { + // We'll be modifying the block, so use early increment. + for (Operation &op : make_early_inc_range(block)) { + if (failed(insertRelayout(op, hardware_generation, target_shape))) { + return failure(); + } + } + return success(); } // TODO(jevinjiang): make relayout to an op so we don't need decide when to @@ -167,6 +186,15 @@ LogicalResult insertRelayout(Operation &op, int hardware_generation, /*dst=*/*li, hardware_generation, target_shape)); op.setOperand(idx, new_v); } + + for (auto ®ion : op.getRegions()) { + for (auto &block : region.getBlocks()) { + if (failed( + insertRelayoutBlock(block, hardware_generation, target_shape))) { + return failure(); + } + } + } return success(); } From df66c2fdc538a5b0d8e7d052a96ceaa6258a9da5 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 13 May 2025 11:20:07 -0400 Subject: [PATCH 1145/1769] Don't instantiate zeros passed to custom_lin_p. --- jax/_src/interpreters/ad.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 435e9027f5b3..45705382efa0 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -565,12 +565,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! 
+ in_zeros = [type(t) is Zero for t in tangents_in] + nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z] with core.set_current_trace(self.parent_trace): - tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( - *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros) return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) def process_custom_transpose(self, prim, call, tracers, **params): @@ -734,11 +734,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - tangents_in_zeros = map(instantiate_zeros, tangents_in) + in_zeros = [type(t) is Zero for t in tangents_in] + nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z] with core.set_current_trace(self.tangent_trace): tangents_out = custom_lin_p.bind( - *res, *tangents_in_zeros, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros) tangent_nzs_out = [type(t) is not Zero for t in tangents_out] return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out) @@ -1223,7 +1224,7 @@ def raise_custom_vjp_error_on_jvp(*_, **__): def _custom_lin_transpose(cts_out, *invals, num_res, bwd: lu.WrappedFun, out_avals, - symbolic_zeros): + symbolic_zeros, in_zeros): res, _ = split_list(invals, [num_res]) if symbolic_zeros: cts_out = map(replace_internal_symbolic_zeros, cts_out) @@ -1231,7 +1232,8 @@ def _custom_lin_transpose(cts_out, *invals, num_res, cts_out = map(instantiate_zeros, cts_out) cts_in = bwd.call_wrapped(*res, *cts_out) cts_in = map(replace_rule_output_symbolic_zeros, cts_in) - return [None] * num_res + list(cts_in) + nz_cts_in, _ = partition_list(in_zeros, cts_in) + return [None] * num_res + nz_cts_in primitive_transposes[custom_lin_p] = _custom_lin_transpose From 60a37718365e68190a76972723cdf73f581cd251 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Tue, 13 May 2025 10:43:29 -0700 Subject: [PATCH 1146/1769] Fix xla_bridge_test: Do "assert subset" the right way around. PiperOrigin-RevId: 758283311 --- tests/xla_bridge_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 3306cb64aced..5c7472492fd5 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -214,7 +214,8 @@ def getopts(): return options def make_c_api_client(plugin_name, new_options, *args, **kwargs): - self.assertContainsSubset(new_options, options) + for k in options: + self.assertEqual(new_options[k], options[k]) with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): with mock.patch.object( From 71692fcbedf162c586ec8847d34938151ffa1f79 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 13 May 2025 13:51:31 -0400 Subject: [PATCH 1147/1769] Add a pretty printing rule for custom_vjp. 
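
With this rule, a `custom_vjp_call` equation prints its `name` together with
the `fwd` and `bwd` function names instead of an opaque params dict. Roughly
(the exact form is pinned down by the new test_pretty_print below):

    { lambda ; a:f32[1]. let
        b:f32[1] = custom_vjp_call[
          name=f
          bwd=f_bwd
          call_jaxpr={ lambda ; c:f32[1]. let d:f32[1] = add c 1.0:f32[] in (d,) }
          fwd=f_fwd
          symbolic_zeros=False
        ] a
      in (b,) }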
--- jax/_src/custom_derivatives.py | 17 +++++++++++++++++ tests/custom_api_test.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index dcd893f44123..7b81c4e86889 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1046,6 +1046,23 @@ def dce_bwd(*args): return list(used_ins), new_eqn pe.dce_rules[custom_vjp_call_p] = _custom_vjp_call_dce + +def _custom_vjp_call_pp_rule(eqn: core.JaxprEqn, + context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + if not params["num_consts"]: + params.pop("num_consts") + params.pop("out_trees") + params["fwd"] = params.pop("fwd_jaxpr_thunk").debug_info.func_name + params["bwd"] = params.pop("bwd").debug_info.func_name + names = sorted(params) + params["name"] = params["call_jaxpr"].jaxpr.debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings, + params=["name"] + names) + +core.pp_eqn_rules[custom_vjp_call_p] = _custom_vjp_call_pp_rule + batching.primitive_batchers[ad.custom_lin_p] = ad.raise_custom_vjp_error_on_jvp mlir.register_lowering(ad.custom_lin_p, ad.raise_custom_vjp_error_on_jvp) diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 73dc2fbefcaa..9d10b40c6030 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -3117,6 +3117,35 @@ def f_bwd(res, cts): ): f(0.5, 0.1, z=1.0) + def test_pretty_print(self): + @jax.custom_vjp + def f(x): + return x + 1 + + def f_fwd(x): + return f(x), () + + def f_bwd(_, g): + return g + f.defvjp(f_fwd, f_bwd) + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(f)(x) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_vjp_call[ + name=f + bwd=f_bwd + call_jaxpr={ lambda ; c:f32[1]. let d:f32[1] = add c 1.0:f32[] in (d,) } + fwd=f_fwd + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + def transpose_unary(f, x_example): def transposed(y): From 78f89b8bb4742fd27ba0aa155b4f6f31c0a1d8db Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 13 May 2025 10:51:19 -0700 Subject: [PATCH 1148/1769] [Pallas] Allow f8 casting tests on TPUv5-. 
PiperOrigin-RevId: 758287085
---
 tests/pallas/ops_test.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 61ebc19e018f..9bb6d31d15e1 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -594,10 +594,16 @@ def kernel(x_ref, y_ref):
   def test_cast_from_32bit(self, from_dtype, to_dtype, data):
     sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu
     if to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}:
-      if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5:
+      if not jtu.test_device_matches(["tpu"]):
         self.skipTest("Not supported on this hardware")
-      if not jtu.if_cloud_tpu_at_least(2025, 3, 8):
+      if jtu.get_tpu_version() >= 5 and not jtu.if_cloud_tpu_at_least(
+          2025, 3, 8
+      ):
         self.skipTest("Test requires libtpu from 2025/3/8 or later")
+      if jtu.get_tpu_version() < 5 and not jtu.if_cloud_tpu_at_least(
+          2025, 5, 15
+      ):
+        self.skipTest("Test requires libtpu from 2025/5/15 or later")
     if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}:
       if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
           2025, 4, 1
@@ -721,10 +727,16 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize):
         "float8_e5m2",
         "float8_e4m3fn",
     } or to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}:
-      if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5:
+      if not jtu.test_device_matches(["tpu"]):
         self.skipTest("Not supported on this hardware")
-      if not jtu.if_cloud_tpu_at_least(2025, 3, 9):
+      if jtu.get_tpu_version() >= 5 and not jtu.if_cloud_tpu_at_least(
+          2025, 3, 9
+      ):
         self.skipTest("Test requires libtpu from 2025/3/9 or later")
+      if jtu.get_tpu_version() < 5 and not jtu.if_cloud_tpu_at_least(
+          2025, 5, 15
+      ):
+        self.skipTest("Test requires libtpu from 2025/5/15 or later")
     if from_dtype == "int2" and to_dtype == "bool":
       self.skipTest(
          "TODO(b/343490729): XLA compare(s2, s2) yields wrong results"

From ef1b3e9231a929ed509628c080e48d516a5e173d Mon Sep 17 00:00:00 2001
From: Hyeontaek Lim
Date: Tue, 13 May 2025 12:06:48 -0700
Subject: [PATCH 1149/1769] [JAX] Make fully replicated shardings avoid
 materializing the same host buffer repeatedly

This change recognizes fully replicated shardings (based on the JAX sharding)
and materializes only one host buffer. This saves the cost of repeatedly
materializing the same host buffer for multiple devices, and streamlines the
creation of a multi-device IFRT array.

Clean up `JAX_IFRT_VERSION_NUMBER < 2` since we are well past it.
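
A minimal sketch of the new fast path, mirroring the api_test cases added
below (it assumes a two-device `mesh`; `jtu` is jax._src.test_util):

    # A fully replicated sharding now materializes a single host buffer.
    sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
    with jtu.count_internal_device_puts() as counts:
      jax.device_put(np.arange(8), sharding)
    # counts() == {"device_put_with_sharding": 1,
    #              "device_put_fully_replicated": 1}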
PiperOrigin-RevId: 758320113 --- jax/_src/test_util.py | 15 ++++++ jaxlib/_jax/__init__.pyi | 2 + jaxlib/py_values.cc | 108 +++++++++++++++++++++++++++++---------- jaxlib/py_values.h | 22 +++++++- jaxlib/xla.cc | 4 ++ jaxlib/xla_client.py | 2 +- tests/api_test.py | 85 ++++++++++++++++++++++++++++++ 7 files changed, 210 insertions(+), 28 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c584ffefa4f2..bb1ef6595ec3 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -54,6 +54,7 @@ from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 @@ -354,6 +355,20 @@ def assert_num_jit_and_pmap_compilations(times): raise AssertionError(f"Expected exactly {times} XLA compilations, " f"but executed {count()}") +@contextmanager +def count_internal_device_puts(): + if jaxlib_extension_version >= 341: + before = jax._src.lib._jax.get_internal_device_put_info() + counts = {} + try: + yield lambda: counts + finally: + if jaxlib_extension_version >= 341: + after = jax._src.lib._jax.get_internal_device_put_info() + for k, v in after.items(): + diff = v - before.get(k, 0) + if diff != 0: + counts[k] = diff def jaxlib_version() -> tuple[int, ...]: return _jaxlib.version diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 6f4f952be9c3..c9c25e172161 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -989,3 +989,5 @@ def approx_top_k_reduction_output_size( aggregate_to_topk: bool | None = ..., input_size_override: int | None = ..., ) -> tuple[int, int]: ... + +def get_internal_device_put_info() -> dict[str, int]: ... diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc index 81f6523d3e14..6ea5c272eea3 100644 --- a/jaxlib/py_values.cc +++ b/jaxlib/py_values.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -78,6 +79,12 @@ namespace xla { namespace { +// Gets the thread-local instance. +static DevicePutInfo& GetDevicePutInfo() { + thread_local DevicePutInfo device_put_info; + return device_put_info; +} + // Prepared data for creating a single shard of an array. Holds a single-device // IFRT array or a host buffer. struct Shard { @@ -147,6 +154,27 @@ using DevicePutHandler = std::function( nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options)>; +// Shared logic that makes an IFRT array (either single-device or multi-device) +// from a fully-replicated `shard` that is created from a host buffer (not from +// an existing IFRT array). `shard` will be consumed. +// +// `user_context` will be used for a new IFRT array created. +// +// Expected to be called without holding GIL. 
+absl::StatusOr> +MakeIfrtArrayFromFullyReplicatedShard( + ifrt::Client* ifrt_client, ifrt::ShardingRef ifrt_sharding, Shard& shard, + tsl::RCReference user_context) { + auto host_buffer_shard = std::get( + std::move(shard.ifrt_array_or_host_buffer)); + return ifrt_client->MakeArrayFromHostBuffer( + host_buffer_shard.data, host_buffer_shard.dtype, + std::move(host_buffer_shard.shape), + std::move(host_buffer_shard.byte_strides), std::move(ifrt_sharding), + shard.host_buffer_semantics, std::move(host_buffer_shard.on_done), + std::move(user_context)); +} + // Shared logic that makes a single-device IFRT array from a `shard`. `shard` // will be consumed. // @@ -161,18 +189,11 @@ absl::StatusOr MakeSingleDeviceIfrtArrayFromShard( if (auto* ifrt_array = std::get_if(&shard.ifrt_array_or_host_buffer)) { return std::move(*ifrt_array); - } else { - auto host_buffer_shard = std::get( - std::move(shard.ifrt_array_or_host_buffer)); - ifrt::ShardingRef ifrt_sharding = - ifrt::SingleDeviceSharding::Create(ifrt_device, ifrt_memory_kind); - return ifrt_client->MakeArrayFromHostBuffer( - host_buffer_shard.data, host_buffer_shard.dtype, - std::move(host_buffer_shard.shape), - std::move(host_buffer_shard.byte_strides), std::move(ifrt_sharding), - shard.host_buffer_semantics, std::move(host_buffer_shard.on_done), - std::move(user_context)); } + ifrt::ShardingRef ifrt_sharding = + ifrt::SingleDeviceSharding::Create(ifrt_device, ifrt_memory_kind); + return MakeIfrtArrayFromFullyReplicatedShard( + ifrt_client, std::move(ifrt_sharding), shard, std::move(user_context)); } // Makes an IFRT Array from `shards` using a batched array creation API (fast @@ -587,10 +608,12 @@ absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, ifrt::Device* to_device, ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options) { - using PyObjectDeviceHandlerMap = absl::flat_hash_map; + using PyObjectDeviceHandlerMap = + absl::flat_hash_map; - auto init_fn = [](){ - std::unique_ptr p = std::make_unique(); + auto init_fn = []() { + std::unique_ptr p = + std::make_unique(); const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); // Python scalar types. @@ -660,7 +683,8 @@ absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; return p; }; - const PyObjectDeviceHandlerMap& handlers = xla::SafeStaticInit(init_fn); + const PyObjectDeviceHandlerMap& handlers = + xla::SafeStaticInit(init_fn); if (arg.type().ptr() == PyArray::type().ptr()) { auto array = nb::borrow(arg); @@ -895,6 +919,7 @@ absl::StatusOr DevicePutWithDevice( ifrt::Device* ifrt_device, ifrt::MemoryKind ifrt_memory_kind, const DevicePutOptions& options) { tsl::profiler::TraceMe traceme("DevicePut"); + ++GetDevicePutInfo().device_put_with_device; if (!ifrt_device->IsAddressable()) { return InvalidArgument("Cannot copy array to non-addressable device: %s", @@ -924,6 +949,7 @@ absl::StatusOr DevicePutWithSharding( absl::Span shape, nanobind::handle sharding, const DevicePutOptions& options) { tsl::profiler::TraceMe traceme("DevicePutWithSharding"); + ++GetDevicePutInfo().device_put_with_sharding; TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef ifrt_device_list, GetIfrtDeviceList(sharding)); @@ -973,12 +999,19 @@ absl::StatusOr DevicePutWithSharding( } ifrt::ShardingRef ifrt_sharding; + bool is_fully_replicated; if (is_pmap_sharding) { CHECK(!shard_fns.empty()); // IFRT Sharding will be determined once we discover the shard shape. 
+ is_fully_replicated = false; } else { TF_ASSIGN_OR_RETURN(ifrt_sharding, GetIfrtHloSharding(sharding, ifrt_shape)); + // Fully-replicated shardings enable additional optimizations of using a + // single host buffer. + // TODO(hyeontaek): Enable a similar optimization for partially replicated + // cases to reduce the number of host buffers to obtain. + is_fully_replicated = ifrt_sharding->IsFullyReplicated(); } tsl::RCReference ifrt_user_context = ifrt_client->CreateUserContext(); @@ -988,12 +1021,6 @@ absl::StatusOr DevicePutWithSharding( // Whether to build an IFRT array from host buffers as a single batch. We do // not batch any shard is already an IFRT array. bool should_batch = true; -#if JAX_IFRT_VERSION_NUMBER < 2 - // PjRt-IFRT would fail `xla::ifrt::Client::MakeArrayFromHostBuffer()` invoked - // by `xla::ifrt::ClientMakeArraysFromHostBufferShards()` for a fully - // replicated sharding if the sharding has any non-addressable device. - should_batch = false; -#endif std::vector shards; shards.reserve(shard_fns.size()); @@ -1004,7 +1031,15 @@ absl::StatusOr DevicePutWithSharding( should_batch = false; } shards.push_back(std::move(shard)); + if (should_batch && is_fully_replicated) { + // We need only one host buffer for a fully-replicated array. + break; + } } + // While we have finished calling `shard_fns`, we cannot destroy them until we + // make a call to IFRT array creation. Destroying `shard_fns` would release + // host buffers prematurely and can cause the array creation API to see + // garbage data. // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are // supported. @@ -1021,12 +1056,22 @@ absl::StatusOr DevicePutWithSharding( ifrt::ArrayRef ifrt_array; if (should_batch) { - TF_ASSIGN_OR_RETURN(ifrt_array, - MakeIfrtArrayFromShardsInBatch( - ifrt_client, ifrt_dtype, std::move(ifrt_shape), - std::move(ifrt_sharding), absl::MakeSpan(shards), - std::move(ifrt_user_context))); + if (is_fully_replicated && shards.size() == 1) { + ++GetDevicePutInfo().device_put_fully_replicated; + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromFullyReplicatedShard( + ifrt_client, std::move(ifrt_sharding), shards.front(), + std::move(ifrt_user_context))); + } else { + ++GetDevicePutInfo().device_put_batched; + TF_ASSIGN_OR_RETURN(ifrt_array, + MakeIfrtArrayFromShardsInBatch( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), absl::MakeSpan(shards), + std::move(ifrt_user_context))); + } } else { + ++GetDevicePutInfo().device_put_assembled; TF_ASSIGN_OR_RETURN( ifrt_array, MakeIfrtArrayFromShardsWithAssembly( ifrt_client, ifrt_dtype, std::move(ifrt_shape), @@ -1038,4 +1083,15 @@ absl::StatusOr DevicePutWithSharding( return DevicePutResult(std::move(ifrt_array), weak_type); } +std::unordered_map DevicePutInfo::GetInfo() { + const DevicePutInfo& info = GetDevicePutInfo(); + return std::unordered_map({ + {"device_put_with_device", info.device_put_with_device}, + {"device_put_with_sharding", info.device_put_with_sharding}, + {"device_put_fully_replicated", info.device_put_fully_replicated}, + {"device_put_batched", info.device_put_batched}, + {"device_put_assembled", info.device_put_assembled}, + }); +} + } // namespace xla diff --git a/jaxlib/py_values.h b/jaxlib/py_values.h index 64a83aa66ab9..d74cf9668a99 100644 --- a/jaxlib/py_values.h +++ b/jaxlib/py_values.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include #include +#include #include #include "absl/container/inlined_vector.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" #include "xla/python/nb_numpy.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/xla_data.pb.h" namespace xla { @@ -136,6 +136,26 @@ H AbslHashValue(H h, const xla::PyArgSignature& s) { return h; } +// Tracks the number of DevicePut calls and subcases. For testing. +struct DevicePutInfo { + // DevicePutWithDevice call count. + int device_put_with_device = 0; + + // DevicePutWithSharding call count. + int device_put_with_sharding = 0; + + // DevicePutWithSharding with a fully replicated sharding. + int device_put_fully_replicated = 0; + // DevicePutWithSharding that made a batched array creation call. + int device_put_batched = 0; + // DevicePutWithSharding that made per-shard creation calls followed by an + // assembly call. + int device_put_assembled = 0; + + // Returns a map of the counters for the current thread. + static std::unordered_map GetInfo(); +}; + } // namespace xla #endif // JAXLIB_PY_VALUES_H_ diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index adf6f3c98297..4020e061b3f4 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -45,6 +45,7 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/ffi.h" @@ -975,6 +976,9 @@ NB_MODULE(_jax, m) { nb::arg("recall_target"), nb::arg("aggregate_to_topk") = true, nb::arg("input_size_override") = -1); + m.def("get_internal_device_put_info", + []() { return DevicePutInfo::GetInfo(); }); + } // NOLINT(readability/fn_size) } // namespace xla diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 6aaae11c139d..69e168de9c2d 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 340 +_version = 341 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/api_test.py b/tests/api_test.py index 15966c678d87..5f775b46fb16 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -60,6 +60,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lib import _jax +from jax._src.lib import jaxlib_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint from jax.errors import (UnexpectedTracerError, TracerIntegerConversionError, @@ -1972,6 +1973,90 @@ def test_device_put_sharding_mismatched_tree_different_leaf_count(self): ): jax.device_put((x, y, z), device=(s1, s2)) + def test_internal_device_put_with_device(self): + if jaxlib_extension_version < 341: + raise unittest.SkipTest( + "Test requires jaxlib extension version >= 341 for tracking low-level" + " DevicePut calls") + + # Hitting the cache for a single-device jitted execution while using a numpy + # array calls internal `DevicePutWithDevice`. 
+ f = jax.jit(lambda x: x + 1) + f(np.arange(8)) + + with jtu.count_internal_device_puts() as counts: + f(np.arange(8)) + self.assertEqual(counts(), {"device_put_with_device": 1}) + + def test_internal_device_put_fully_replicated(self): + if jaxlib_extension_version < 341: + raise unittest.SkipTest( + "Test requires jaxlib extension version >= 341 for tracking low-level" + " DevicePut calls") + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from a numpy array with a fully-replicated sharding + # calls internal `DevicePutWithSharding`, taking the fully-replicated sub + # case. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P()) + + with jtu.count_internal_device_puts() as counts: + jax.device_put(np.arange(8), sharding) + self.assertEqual( + counts(), + {"device_put_with_sharding": 1, "device_put_fully_replicated": 1}, + ) + + def test_internal_device_put_batched(self): + if jaxlib_extension_version < 341: + raise unittest.SkipTest( + "Test requires jaxlib extension version >= 341 for tracking low-level" + " DevicePut calls") + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from a numpy array with a non-fully-replicated sharding + # calls internal `DevicePutWithSharding`, performing batched creation of a + # multi-shard array. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P("x")) + + with jtu.count_internal_device_puts() as counts: + jax.device_put(np.arange(8), sharding) + self.assertEqual( + counts(), {"device_put_with_sharding": 1, "device_put_batched": 1} + ) + + def test_internal_device_put_assembled(self): + if jaxlib_extension_version < 341: + raise unittest.SkipTest( + "Test requires jaxlib extension version >= 341 for tracking low-level" + " DevicePut calls") + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from per-device JAX arrays calls internal + # `DevicePutWithSharding`, performing per-shard array adoption followed by + # assembly. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P("x")) + + arr = np.arange(8) + per_device_arrs = { + # Use uncommitted arrays that are not aligned with the destination + # sharding so that we trigger `BatchedDevicePut`. + index: jnp.array(arr[index]) + for _, index in sharding.devices_indices_map(arr.shape).items() + } + data_callback = lambda index: per_device_arrs[index] + with jtu.count_internal_device_puts() as counts: + jax.make_array_from_callback(arr.shape, sharding, data_callback) + self.assertEqual( + counts(), {"device_put_with_sharding": 1, "device_put_assembled": 1} + ) + def test_device_put_custom_type_not_accepting_none_leaves(self): class CustomNode(list): From 0e4f213e9c84b1f59a9d1be84c63cd1b5e4dfd2b Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 13 May 2025 12:31:17 -0700 Subject: [PATCH 1150/1769] Use DmaCopyChunk::Make because directly assigning the struct fully constrains the implementation. 
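
Illustrative before/after (field names as in DmaCopyChunk; local variable
names are shortened for illustration):

    // Before: field-by-field assignment couples every call site to the
    // struct's exact layout and default-constructibility.
    DmaCopyChunk blob;
    blob.arr = std::move(arr);
    blob.buffer = buffer;
    blob.buffer_id = bid;
    blob.offset = offset;
    blob.size = size;

    // After: construction goes through the factory, leaving the struct's
    // internals free to evolve.
    DmaCopyChunk blob = DmaCopyChunk::Make(std::move(arr), buffer, bid,
                                           offset, size);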
PiperOrigin-RevId: 758330465 --- jaxlib/py_socket_transfer.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc index 491e90d778cf..ed2a4f4c204a 100644 --- a/jaxlib/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -127,12 +127,9 @@ class IfrtArrayEntry : public PullTable::Entry { auto req_id = base_req_id; ++base_req_id; for (size_t i = 0; i * xfer_size_ < arrs_[bid].buf_size; ++i) { - DmaCopyChunk blob; - blob.arr = std::move(arrs_[bid].arr); - blob.buffer = arrs_[bid].buffer; - blob.buffer_id = bid; - blob.offset = i * xfer_size_; - blob.size = std::min(xfer_size_, arrs_[bid].buf_size - blob.offset); + DmaCopyChunk blob = DmaCopyChunk::Make( + std::move(arrs_[bid].arr), arrs_[bid].buffer, bid, i * xfer_size_, + std::min(xfer_size_, arrs_[bid].buf_size - i * xfer_size_)); bool is_largest = blob.size + blob.offset == arrs_[bid].buf_size; state_->ScheduleCopy( blob, [req_id, state, copier_state = state_, is_largest]( From 725d0f64addd67b242691012854a582d2f14ce6c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 May 2025 12:31:52 -0700 Subject: [PATCH 1151/1769] Add direct `pypi` dependencies to the JAX test targets. PiperOrigin-RevId: 758330650 --- jaxlib/jax.bzl | 7 +- tests/BUILD | 579 +++++++++++++++++++++++++++++++++++++++------ tests/mosaic/BUILD | 27 ++- tests/pallas/BUILD | 185 ++++++++++++--- 4 files changed, 674 insertions(+), 124 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index a8fe2b50344b..e739b681a029 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -256,18 +256,13 @@ def _get_jax_test_deps(deps): """ jax_build_deps = [d for d in deps if not d.startswith("@pypi//")] - # A lot of tests don't have explicit dependencies on absl/testing, numpy, etc. But the tests + # A lot of tests don't have explicit dependencies on scipy, ml_dtypes, etc. But the tests # transitively depends on them via //jax. So we need to make sure that these dependencies are # included in the test when JAX is built from source. - # TODO(ybaturina): Add individual dependencies for each test and remove this block. 
jax_transitive_pypi_test_deps = {k: "true" for k in py_deps([ - "absl/testing", - "numpy", "ml_dtypes", "scipy", "opt_einsum", - "hypothesis", - "cloudpickle", "flatbuffers", ])} diff --git a/tests/BUILD b/tests/BUILD index 4c6369bee5de..e70d4593e8fc 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -38,7 +38,10 @@ jax_multiplatform_test( shard_count = 10, deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -61,12 +64,16 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("numpy"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "device_test", srcs = ["device_test.py"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -75,12 +82,16 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ] + py_deps("absl/testing"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "api_util_test", srcs = ["api_util_test.py"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -98,7 +109,10 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ] + py_deps("absl/testing"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -112,7 +126,11 @@ jax_multiplatform_test( "gpu_h100x2", ], tags = ["multiaccelerator"], - deps = py_deps("tensorflow_core"), + deps = py_deps([ + "absl/testing", + "numpy", + "tensorflow_core", + ]), ) jax_multiplatform_test( @@ -121,6 +139,10 @@ jax_multiplatform_test( shard_count = { "gpu": 5, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -132,7 +154,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental_buffer_callback", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -151,11 +176,19 @@ jax_multiplatform_test( "cpu": 5, "gpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -164,14 +197,20 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ] + py_deps("portpicker"), + ] + py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_multiplatform_test( name = "distributed_test", srcs = ["distributed_test.py"], enable_backends = ["gpu"], - deps = py_deps("portpicker"), + deps = py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_py_test( @@ -184,12 +223,19 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ] + py_deps("portpicker"), + ] + py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_multiplatform_test( name = "dtypes_test", srcs = ["dtypes_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -199,12 +245,13 @@ jax_multiplatform_test( enable_configs = [ "cpu", ], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "extend_test", srcs = ["extend_test.py"], - deps = ["//jax:extend"], + deps = ["//jax:extend"] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -214,7 +261,10 @@ jax_multiplatform_test( "gpu_h100x2", ], # TODO(dfm): Remove after removal of jex.ffi imports. 
- deps = ["//jax:extend"], + deps = ["//jax:extend"] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -231,11 +281,19 @@ jax_multiplatform_test( "cpu": 20, "gpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "generated_fun_test", srcs = ["generated_fun_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -246,6 +304,7 @@ jax_multiplatform_test( "XLA_PYTHON_CLIENT_PREALLOCATE": "0", }, main = "gpu_memory_flags_test.py", + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -255,6 +314,7 @@ jax_multiplatform_test( env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "1", }, + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -269,7 +329,12 @@ jax_multiplatform_test( }, deps = [ "//jax:experimental_sparse", - ] + py_deps("matplotlib"), + ] + py_deps([ + "matplotlib", + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -280,6 +345,11 @@ jax_multiplatform_test( "gpu": 10, "tpu": 15, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_py_test( @@ -288,7 +358,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -306,7 +376,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -329,7 +402,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -345,7 +421,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -359,7 +438,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -375,7 +457,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -390,7 +475,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -406,7 +494,7 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -422,7 +510,10 @@ jax_multiplatform_test( deps = [ "//jax:experimental", "//jax:internal_test_util", - ], + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -431,7 +522,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ] + py_deps("numpy"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -443,18 +537,31 @@ jax_multiplatform_test( "tpu": 8, }, tags = ["noasan"], # Linking TF causes a linker OOM. 
- deps = py_deps("pil") + py_deps("tensorflow_core"), + deps = py_deps([ + "pil", + "tensorflow_core", + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -464,7 +571,10 @@ jax_py_test( "//jax:test_util", "//jax/experimental/jax2tf", "//jax/tools:jax_to_ir", - ] + py_deps("tensorflow_core"), + ] + py_deps([ + "tensorflow_core", + "absl/testing", + ]), ) jax_py_test( @@ -474,7 +584,7 @@ jax_py_test( "//jax", "//jax:jaxpr_util", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -487,7 +597,10 @@ jax_multiplatform_test( deps = [ "//jax:jet", "//jax:stax", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -498,16 +611,28 @@ jax_multiplatform_test( "gpu": 30, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -526,6 +651,10 @@ jax_multiplatform_test( "noasan", # Test times out on all backends "test_cpu_thunks", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -536,6 +665,10 @@ jax_multiplatform_test( "gpu": 30, "tpu": 40, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -546,6 +679,10 @@ jax_multiplatform_test( "gpu": 20, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -556,11 +693,19 @@ jax_multiplatform_test( "gpu": 10, "tpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -571,11 +716,19 @@ jax_multiplatform_test( "gpu": 5, "tpu": 5, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_vectorize_test", srcs = ["lax_numpy_vectorize_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -586,7 +739,11 @@ jax_multiplatform_test( "gpu": 20, "tpu": 8, }, - deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -600,6 +757,11 @@ jax_multiplatform_test( "gpu": 5, "tpu": 5, }, + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -617,7 +779,11 @@ jax_multiplatform_test( "tpu": 20, }, tags = ["noasan"], # Times out under asan. 
- deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -630,7 +796,10 @@ jax_multiplatform_test( }, deps = [ "//jax:internal_test_util", - ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -648,7 +817,11 @@ jax_multiplatform_test( deps = [ "//jax:internal_test_util", "//jax:lax_reference", - ] + py_deps("numpy") + py_deps("mpmath"), + ] + py_deps([ + "numpy", + "absl/testing", + "mpmath", + ]), ) jax_multiplatform_test( @@ -659,7 +832,10 @@ jax_multiplatform_test( deps = [ "//jax:internal_test_util", "//jax:lax_reference", - ] + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -670,6 +846,10 @@ jax_multiplatform_test( "gpu": 30, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -680,7 +860,10 @@ jax_multiplatform_test( "gpu": 40, "tpu": 40, }, - deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), + deps = ["//jax:internal_test_util"] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -691,7 +874,10 @@ jax_multiplatform_test( "gpu": 40, "tpu": 40, }, - deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), + deps = ["//jax:internal_test_util"] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_py_test( @@ -702,7 +888,7 @@ jax_py_test( deps = [ "//jax:internal_test_util", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -713,7 +899,7 @@ jax_py_test( deps = [ "//jax:internal_test_util", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -733,6 +919,11 @@ jax_multiplatform_test( "gpu": 40, "tpu": 40, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -753,24 +944,36 @@ jax_multiplatform_test( tags = [ "multiaccelerator", ], + deps = py_deps([ + "absl/testing", + ]), ) jax_multiplatform_test( name = "magma_linalg_test", srcs = ["magma_linalg_test.py"], enable_backends = ["gpu"], - deps = py_deps("magma"), + deps = py_deps([ + "magma", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -779,7 +982,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -789,12 +992,17 @@ jax_multiplatform_test( "tpu_v3_x4", "gpu_h100x2", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -813,12 +1021,20 @@ jax_multiplatform_test( "tpu": 10, "gpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "optimizers_test", srcs = ["optimizers_test.py"], - deps = ["//jax:optimizers"], + deps = ["//jax:optimizers"] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -826,7 +1042,11 @@ jax_multiplatform_test( srcs = ["pickle_test.py"], deps = [ "//jax:experimental", - ] + py_deps("cloudpickle") + py_deps("numpy"), + ] + py_deps([ + 
"cloudpickle", + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -850,7 +1070,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:internal_test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -868,12 +1091,21 @@ jax_multiplatform_test( # in this case there's not a good place to do it, see b/197635968#comment19 # for details. tags = ["nomsan"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], enable_backends = ["cpu"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -892,7 +1124,7 @@ jax_multiplatform_test( ], deps = [ "//jax:profiler", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -907,7 +1139,10 @@ jax_multiplatform_test( "nomsan", # TODO(b/355237462): msan false-positives in torch? "not_build:arm", ], - deps = py_deps("torch"), + deps = py_deps([ + "torch", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -921,11 +1156,19 @@ jax_multiplatform_test( ], }, shard_count = 8, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -951,6 +1194,11 @@ jax_multiplatform_test( "tpu": 40, }, tags = ["noasan"], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) # TODO(b/199564969): remove once we always enable_custom_prng @@ -959,6 +1207,7 @@ jax_multiplatform_test( srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], main = "random_test.py", + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -972,21 +1221,41 @@ jax_multiplatform_test( ], # Times out on TPU with asan/tsan/msan. 
}, shard_count = 12, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_interpolate_test", srcs = ["scipy_interpolate_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_ndimage_test", srcs = ["scipy_ndimage_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_optimize_test", srcs = ["scipy_optimize_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1012,6 +1281,11 @@ jax_multiplatform_test( "gpu": 40, "tpu": 50, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1021,7 +1295,11 @@ jax_multiplatform_test( "cpu": 4, "gpu": 4, }, - deps = py_deps("scipy"), + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1039,6 +1317,11 @@ jax_multiplatform_test( "noasan", "notsan", ], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1071,7 +1354,11 @@ jax_multiplatform_test( deps = [ "//jax:experimental_sparse", "//jax:sparse_test_util", - ] + py_deps("scipy"), + ] + py_deps([ + "scipy", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1108,7 +1395,10 @@ jax_multiplatform_test( deps = [ "//jax:experimental_sparse", "//jax:sparse_test_util", - ] + py_deps("scipy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1133,12 +1423,16 @@ jax_multiplatform_test( deps = [ "//jax:experimental_sparse", "//jax:sparse_test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "stack_test", srcs = ["stack_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1149,33 +1443,50 @@ jax_multiplatform_test( "gpu": 2, "tpu": 4, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "error_check_test", srcs = ["error_check_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "jax_numpy_error_test", srcs = ["jax_numpy_error_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], - deps = ["//jax:stax"], + deps = ["//jax:stax"] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "linear_search_test", srcs = ["third_party/scipy/line_search_test.py"], main = "third_party/scipy/line_search_test.py", + deps = py_deps([ + "absl/testing", + "scipy", + ]), ) jax_multiplatform_test( name = "blocked_sampler_test", srcs = ["blocked_sampler_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1184,7 +1495,11 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + "cloudpickle", + ]), ) pytype_test( @@ -1193,7 +1508,11 @@ pytype_test( deps = [ "//jax", "//jax:test_util", - ], + "//jax:typing", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1202,7 +1521,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1211,7 +1530,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1230,7 +1549,9 @@ jax_py_test( "//jax", "//jax:compiler", "//jax:test_util", - ] + py_deps("absl/logging"), + ] + py_deps([ + "absl/logging", + ]), ) jax_py_test( @@ -1240,7 +1561,10 @@ jax_py_test( "//jax", "//jax:lru_cache", 
"//jax:test_util", - ] + py_deps("filelock"), + ] + py_deps([ + "filelock", + "absl/logging", + ]), ) jax_multiplatform_test( @@ -1249,7 +1573,10 @@ jax_multiplatform_test( deps = [ "//jax:compilation_cache_internal", "//jax:compiler", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1258,7 +1585,7 @@ jax_multiplatform_test( deps = [ "//jax:cache_key", "//jax:compiler", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1267,18 +1594,27 @@ jax_multiplatform_test( shard_count = { "cpu": 10, }, - deps = ["//jax:ode"], + deps = ["//jax:ode"] + py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "roofline_test", srcs = ["roofline_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1286,7 +1622,10 @@ jax_multiplatform_test( srcs = ["x64_context_test.py"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1297,6 +1636,10 @@ jax_multiplatform_test( "gpu": 5, "tpu": 10, }, + deps = py_deps([ + "numpy", + "absl/testing", + ]), ) jax_py_test( @@ -1306,17 +1649,26 @@ jax_py_test( "//jax", "//jax:mesh_utils", "//jax:test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "transfer_guard_test", srcs = ["transfer_guard_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "cloudpickle", + ]), ) jax_multiplatform_test( name = "garbage_collection_guard_test", srcs = ["garbage_collection_guard_test.py"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -1342,6 +1694,10 @@ jax_multiplatform_test( "tpu_v4_x4", ], tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1356,6 +1712,10 @@ jax_multiplatform_test( "gpu_h100_shardy", "tpu_v3_x4_shardy", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1374,7 +1734,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1390,6 +1753,10 @@ jax_multiplatform_test( "tpu_v3_x4", "tpu_v4_x4", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1411,12 +1778,20 @@ jax_multiplatform_test( "gpu": 2, "tpu": 2, }, - deps = py_deps("hypothesis"), + deps = py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( name = "mutable_array_test", srcs = ["mutable_array_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1425,6 +1800,10 @@ jax_multiplatform_test( shard_count = { "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1445,7 +1824,7 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1469,12 +1848,16 @@ jax_multiplatform_test( deps = [ "//jax:experimental", "//jax:tree_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "clear_backends_test", srcs = ["clear_backends_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1482,7 +1865,10 @@ jax_multiplatform_test( srcs = ["attrs_test.py"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) 
jax_multiplatform_test( @@ -1491,7 +1877,10 @@ jax_multiplatform_test( deps = [ "//jax:experimental_colocated_python", "//jax/extend:ifrt_programs", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1504,7 +1893,10 @@ jax_multiplatform_test( shard_count = 15, deps = [ "//jax:rnn", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1514,7 +1906,7 @@ jax_py_test( "//jax", "//jax:mosaic", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1523,7 +1915,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1532,12 +1924,13 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "logging_test", srcs = ["logging_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1550,6 +1943,10 @@ jax_multiplatform_test( "tpu_v3_x4", ], tags = [], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1574,7 +1971,10 @@ jax_multiplatform_test( ], deps = [ "//jax:internal_test_harnesses", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1596,7 +1996,10 @@ jax_multiplatform_test( ], deps = [ "//jax:internal_test_harnesses", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1610,7 +2013,10 @@ jax_multiplatform_test( deps = [ "//jax:internal_export_back_compat_test_data", "//jax:internal_export_back_compat_test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1618,12 +2024,16 @@ jax_multiplatform_test( srcs = ["fused_attention_stablehlo_test.py"], enable_backends = ["gpu"], tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "xla_metadata_test", srcs = ["xla_metadata_test.py"], - deps = ["//jax:experimental"], + deps = ["//jax:experimental"] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1637,7 +2047,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1646,7 +2059,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1656,7 +2069,7 @@ jax_py_test( "//jax", "//jax:source_mapper", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1665,12 +2078,19 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "string_array_test", srcs = ["string_array_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1682,6 +2102,7 @@ jax_multiplatform_test( "gpu_h100", ], tags = ["multiaccelerator"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1691,6 +2112,10 @@ jax_multiplatform_test( shard_count = { "gpu": 4, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1700,7 +2125,7 @@ jax_py_test( "//jax", "//jax:experimental", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) exports_files( diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 24acb1b9a3f2..75e1df335f6f 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -41,7 +41,11 @@ jax_multiplatform_test( ], deps = [ "//jax:mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) 
jax_multiplatform_test( @@ -53,7 +57,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -83,7 +90,10 @@ jax_py_test( "//jax", "//jax:mosaic_gpu", "//jax:test_util", - ] + py_deps("absl/testing"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -95,7 +105,11 @@ jax_multiplatform_test( deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -110,7 +124,10 @@ jax_multiplatform_test( ], deps = [ "//jax:mosaic_gpu", - ] + py_deps("numpy"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 3769da27a1eb..49a05ee487f0 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -54,7 +54,10 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -68,7 +71,10 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -88,7 +94,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -124,7 +133,11 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -162,7 +175,11 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_mosaic_gpu", # build_cleaner: keep "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -182,7 +199,11 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -202,7 +223,10 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -221,7 +245,10 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_mosaic_gpu", # build_cleaner: keep - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -240,7 +267,7 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_mosaic_gpu", # build_cleaner: keep "//jax:pallas_tpu_ops", # build_cleaner: keep - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -253,7 +280,10 @@ jax_py_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep "//jax:test_util", - ] + jax_gpu_support_deps, + ] + jax_gpu_support_deps + py_deps([ + 
"absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -273,7 +303,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -294,7 +327,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -307,7 +343,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_tpu", "//jax/_src/pallas/mosaic:random", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -321,7 +360,11 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -360,7 +403,10 @@ jax_multiplatform_test( "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -398,7 +444,11 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -415,7 +465,10 @@ jax_multiplatform_test( "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -436,7 +489,11 @@ jax_multiplatform_test( "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -449,7 +506,10 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -464,7 +524,10 @@ jax_multiplatform_test( deps = [ "//jax:extend", "//jax:pallas_tpu", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -481,7 +544,10 @@ jax_multiplatform_test( "//jax:pallas_tpu", "//jax:pallas_tpu_ops", "//jax/_src/pallas/mosaic:random", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -494,7 +560,10 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -507,7 +576,10 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -525,7 +597,10 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -543,7 +618,10 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -560,7 +638,11 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + 
]), ) jax_multiplatform_test( @@ -575,7 +657,10 @@ jax_multiplatform_test( "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) # This test doesn't need a TPU; it only tests numpy-using helpers. @@ -588,7 +673,11 @@ jax_py_test( "//jax", "//jax:pallas_tpu_ops", "//jax:test_util", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -606,7 +695,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -629,7 +721,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -647,7 +742,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -663,7 +761,10 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -680,7 +781,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_experimental_gpu_ops", "//jax:pallas_mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -701,7 +805,10 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax/_src/pallas/fuser", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -722,7 +829,10 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_fuser", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -758,5 +868,8 @@ jax_multiplatform_test( "//jax:pallas_tpu", "//jax:pallas_tpu_ops", "//jax/_src/pallas/fuser", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) From a98f533cae69ef6f7a8d74d6785d6cddd3bffaaf Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 May 2025 13:08:43 -0700 Subject: [PATCH 1152/1769] Testing: change jax.extend.ffi to jax.ffi PiperOrigin-RevId: 758344491 --- docs/ffi/CMakeLists.txt | 2 +- examples/ffi/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ffi/CMakeLists.txt b/docs/ffi/CMakeLists.txt index 9d3e9df7d3bf..b7f1af5c1a1b 100644 --- a/docs/ffi/CMakeLists.txt +++ b/docs/ffi/CMakeLists.txt @@ -4,7 +4,7 @@ project(rms_norm LANGUAGES CXX) find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" - "-c" "from jax.extend import ffi; print(ffi.include_dir())" + "-c" "from jax import ffi; print(ffi.include_dir())" OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) message(STATUS "XLA include directory: ${XLA_DIR}") diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index ea7670b81ccc..a7b8869a64b5 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -6,7 +6,7 @@ option(JAX_FFI_EXAMPLE_ENABLE_CUDA "Enable CUDA support" OFF) find_package(Python 
3.10 REQUIRED COMPONENTS Interpreter Development.Module) execute_process( COMMAND "${Python_EXECUTABLE}" - "-c" "from jax.extend import ffi; print(ffi.include_dir())" + "-c" "from jax import ffi; print(ffi.include_dir())" OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) message(STATUS "XLA include directory: ${XLA_DIR}") From d89f2a56d2d771a77638ffe1c53f6b05a00ca4c2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 May 2025 13:54:07 -0700 Subject: [PATCH 1153/1769] Fix a stale type annotation. PiperOrigin-RevId: 758362952 --- jaxlib/_jax/__init__.pyi | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index c9c25e172161..1d7f3042e8a3 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -645,7 +645,9 @@ def batched_device_put( sharding: Any, shards: Sequence[Any], devices: list[Device], - committed: bool = True, + committed: bool = ..., + force_copy: bool = ..., + host_buffer_semantics: Any = ..., ) -> ArrayImpl: ... def reorder_shards( x: ArrayImpl, From b5a595571bb21ae5e8a93f9f0150cf289be4f423 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 May 2025 14:07:33 -0700 Subject: [PATCH 1154/1769] jax.numpy: make type stubs consistent with runtime --- jax/numpy/__init__.pyi | 63 +++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index c52ce2628cda..4db407861f34 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -253,7 +253,8 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | _core.Tracer] ) -> tuple[int | _core.Tracer, ...]: ... -def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ... +def broadcast_to(array: ArrayLike, shape: DimSize | Shape, *, + out_sharding: NamedSharding | P | None = None) -> Array: ... c_: _CClass can_cast = _np.can_cast def cbrt(x: ArrayLike, /) -> Array: ... @@ -267,6 +268,7 @@ def clip( /, min: ArrayLike | None = ..., max: ArrayLike | None = ..., + *, a: ArrayLike | DeprecatedArg | None = ..., a_min: ArrayLike | DeprecatedArg | None = ..., a_max: ArrayLike | DeprecatedArg | None = ... @@ -278,7 +280,7 @@ complex128: Any complex64: Any complex_: Any complexfloating = _np.complexfloating -def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., +def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., *, size: int | None = ..., fill_value: ArrayLike = ..., out: None = ...) -> Array: ... def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ... def concatenate( @@ -314,9 +316,9 @@ def cross( axis: int | None = ..., ) -> Array: ... csingle: Any -def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... -def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... 
def cumulative_prod(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., @@ -371,7 +373,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -385,7 +386,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -397,7 +397,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = ..., _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -422,7 +421,7 @@ def einsum_path( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... -def empty(shape: Any, dtype: DTypeLike | None = ..., +def empty(shape: Any, dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ...) -> Array: ... def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., @@ -579,17 +578,17 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = . def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... -def iscomplex(m: ArrayLike) -> Array: ... +def iscomplex(x: ArrayLike) -> Array: ... def iscomplexobj(x: Any) -> builtins.bool: ... def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... -def isin(element: ArrayLike, test_elements: ArrayLike, - assume_unique: builtins.bool = ..., invert: builtins.bool = ..., method: str = ...) -> Array: ... +def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: builtins.bool = ..., + invert: builtins.bool = ..., *, method: str = ...) -> Array: ... def isinf(x: ArrayLike, /) -> Array: ... def isnan(x: ArrayLike, /) -> Array: ... def isneginf(x: ArrayLike, /) -> Array: ... def isposinf(x: ArrayLike, /) -> Array: ... -def isreal(m: ArrayLike) -> Array: ... +def isreal(x: ArrayLike) -> Array: ... def isrealobj(x: Any) -> builtins.bool: ... def isscalar(element: Any) -> builtins.bool: ... def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> builtins.bool: ... @@ -644,7 +643,7 @@ def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... def mask_indices( - n: int, mask_func: Callable, k: int = ... + n: int, mask_func: Callable, k: int = ..., *, size: int | None = ... ) -> tuple[Array, ...]: ... def matmul( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., @@ -655,7 +654,7 @@ def max(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ... 
-def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., @@ -689,14 +688,14 @@ def nanargmin( out: None = ..., keepdims: builtins.bool | None = ..., ) -> Array: ... -def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... -def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., where: ArrayLike | None = ...) -> Array: ... @@ -710,21 +709,21 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... -def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... -def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., - ddof: int = ..., keepdims: builtins.bool = ..., +def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., + out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., where: ArrayLike | None = ...) -> Array: ... -def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... @@ -740,7 +739,7 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... number = _np.number object_ = _np.object_ ogrid: _Ogrid -def ones(shape: Any, dtype: DTypeLike | None = ..., +def ones(shape: Any, dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ...) -> Array: ... def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., @@ -782,7 +781,7 @@ def positive(x: ArrayLike, /) -> Array: ... def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... 
printoptions = _np.printoptions -def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ..., promote_integers: builtins.bool = ...) -> Array: ... @@ -805,7 +804,6 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = ..., order: str = ...) -> Array: ... def real(x: ArrayLike, /) -> Array: ... def reciprocal(x: ArrayLike, /) -> Array: ... -register_jax_array_methods: Any def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, total_repeat_length: int | None = ..., @@ -844,7 +842,8 @@ def setdiff1d( size: int | None = ..., fill_value: ArrayLike | None = ..., ) -> Array: ... -def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... +def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., *, + size: int | None = ..., fill_value: ArrayLike | None = ...) -> Array: ... def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... @@ -882,14 +881,14 @@ def stack( out: None = ..., dtype: DTypeLike | None = ..., ) -> Array: ... -def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... subtract: BinaryUfunc def sum( a: ArrayLike, axis: _Axis = ..., - dtype: DTypeLike = ..., + dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., @@ -927,7 +926,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ... def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ..., axis: int = ...) -> Array: ... def tri( - N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike = ... + N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike | None = ... ) -> Array: ... def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( @@ -970,7 +969,7 @@ class _UniqueInverseResult(NamedTuple): def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: builtins.bool = ..., return_counts: builtins.bool = ..., axis: int | None = ..., *, equal_nan: builtins.bool = ..., size: int | None = ..., - fill_value: ArrayLike | None = ... + fill_value: ArrayLike | None = ..., sorted: bool = ..., ): ... def unique_all(x: ArrayLike, /, *, size: int | None = ..., fill_value: ArrayLike | None = ...) -> _UniqueAllResult: ... @@ -994,7 +993,7 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = ..., def vander( x: ArrayLike, N: int | None = ..., increasing: builtins.bool = ... ) -> Array: ... -def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def vdot( @@ -1029,7 +1028,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = ..., fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> Array | tuple[Array, ...]: ... 
-def zeros(shape: Any, dtype: DTypeLike | None = ..., +def zeros(shape: Any, dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ...) -> Array: ... def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., From 6971e125a904ec6af6336c823154f0eda2d7f86a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 May 2025 14:18:03 -0700 Subject: [PATCH 1155/1769] [doc] mention jaxlib in the API compatibility doc --- docs/api_compatibility.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 9dca1fc08f50..dda86e2e5d31 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -59,6 +59,11 @@ Any API or import path prefixed with an underscore is explicitly private, and may change without warning between JAX releases. We are working to move all private APIs into `jax._src` to make these expectations more clear. +### jaxlib +Any import path in the `jaxlib` package is considered private, and may change +without warning between releases. Some APIs defined in `jaxlib` have public +aliases in the `jax` package. + ### Legacy internal APIs In addition, there are several legacy modules that currently expose some private APIs without an underscore, including: From b2549761e1cdee37b4d990bb28e97fdab7a99fe7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 May 2025 14:32:48 -0700 Subject: [PATCH 1156/1769] Avoid duplication of Bazel dependencies between //jax/_src/lib and //jaxlib. Put all dependencies on //jaxlib and make //jax/_src/lib a pure forwarding rule. PiperOrigin-RevId: 758378774 --- jax/_src/lib/BUILD | 25 --------------------- jaxlib/BUILD | 56 +++++++++++++++++++++++++++++----------------- 2 files changed, 35 insertions(+), 46 deletions(-) diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 4bbc861432aa..e0b5ea607501 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -40,30 +40,5 @@ py_library_providing_imports_info( "//jax:version", ] + if_building_jaxlib([ "//jaxlib", - "//jaxlib/mosaic/python:gpu_dialect", - "//jaxlib/mosaic/python:tpu_dialect", - "//jaxlib:cpu_feature_guard", - "//jaxlib:utils", - "//jaxlib:weakref_lru_cache", - "//jaxlib:xla_client", - "//jaxlib:_jax", - "//jaxlib/triton", - "//jaxlib/mlir/_mlir_libs:register_jax_dialects", - "//jaxlib/mlir:arithmetic_dialect", - "//jaxlib/mlir:builtin_dialect", - "//jaxlib/mlir:chlo_dialect", - "//jaxlib/mlir:control_flow_dialect", - "//jaxlib/mlir:func_dialect", - "//jaxlib/mlir:ir", - "//jaxlib/mlir:math_dialect", - "//jaxlib/mlir:memref_dialect", - "//jaxlib/mlir:mhlo_dialect", - "//jaxlib/mlir:pass_manager", - "//jaxlib/mlir:scf_dialect", - "//jaxlib/mlir:sdy_dialect", - "//jaxlib/mlir:sparse_tensor_dialect", - "//jaxlib/mlir:stablehlo_dialect", - "//jaxlib/mlir:vector_dialect", - "@xla//xla/python:_profiler", ]), ) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index e0fb2699a25e..add6dbd7d92a 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -22,7 +22,6 @@ load( "nanobind_extension", "proto_library", "py_deps", - "py_library_providing_imports_info", "py_strict_test", "pytype_library", "pytype_strict_library", @@ -49,33 +48,17 @@ package_group( ], ) -py_library_providing_imports_info( +pytype_strict_library( name = "jaxlib", - srcs = [ - "cpu_sparse.py", - "gpu_common_utils.py", - "gpu_linalg.py", - "gpu_prng.py", - "gpu_rnn.py", - "gpu_solver.py", - "gpu_sparse.py", - "gpu_triton.py", - "hlo_helpers.py", - "init.py", - "lapack.py", - "plugin_support.py", - "xla_client.py", - ":version", - ], data = 
[":ffi_headers"], - lib_rule = pytype_library, deps = [ + ":_jax", ":cpu_feature_guard", ":jax", + ":jaxlib_files", ":utils", ":weakref_lru_cache", - "//jaxlib:_jax", - "//jaxlib:xla_client", + ":xla_client", "//jaxlib/cpu:_lapack", "//jaxlib/cpu:_sparse", "//jaxlib/mlir", @@ -98,8 +81,39 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", + "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mosaic", + "//jaxlib/mosaic/python:gpu_dialect", + "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib/triton", + "@xla//xla/python:_profiler", + ], +) + +pytype_library( + name = "jaxlib_files", + srcs = [ + "cpu_sparse.py", + "gpu_common_utils.py", + "gpu_linalg.py", + "gpu_prng.py", + "gpu_rnn.py", + "gpu_solver.py", + "gpu_sparse.py", + "gpu_triton.py", + "hlo_helpers.py", + "init.py", + "lapack.py", + "plugin_support.py", + "xla_client.py", + ":version", + ], + deps = [ + ":_jax", + "//jaxlib/cpu:_lapack", + "//jaxlib/cpu:_sparse", + "//jaxlib/mlir:ir", + "//jaxlib/mlir:stablehlo_dialect", ], ) From 7887298da20d00af7db58a524beeeebdef4f776a Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 13 May 2025 15:27:19 -0700 Subject: [PATCH 1157/1769] [Mosaic TPU] Fold reshape (..., M, N, 128) -> (..., M, N * 128) with any tiling and dtype to load. Now we should support decent number of reshape that match with this pattern. And it is much more efficient to fold reshape (retiling + re-pack) into load. Take bf16(256, 8, 128) -> bf16(256, 8 * 128) as example, this cl emits 148 bundles and compared to before 630 bundles - about 4.2x speedup. PiperOrigin-RevId: 758398675 --- jaxlib/mosaic/dialect/tpu/tpu.td | 2 + jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 3 +- .../tpu/transforms/canonicalize_mosaic.cc | 172 +++++++++++++++++- jaxlib/mosaic/dialect/tpu/util.cc | 23 +++ jaxlib/mosaic/dialect/tpu/util.h | 31 +++- tests/pallas/tpu_ops_test.py | 61 ++++++- 6 files changed, 277 insertions(+), 15 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 226fc6285192..29ce9c84de07 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -1006,6 +1006,8 @@ def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::Func let options = [ Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, Option<"compatibility_mode", "compatibility-mode", "bool", /*default=*/"1", "">, + Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, + Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 2afaf08f29ed..798386b92744 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -74,7 +74,8 @@ std::unique_ptr> createInferMemRefLayoutPass( const TpuTilingFlags &tpu_tiling_flags = {}); std::unique_ptr> createCanonicalizeMosaicPass( - int hardware_generation = -1, bool compatibility_mode = true); + int hardware_generation = -1, bool compatibility_mode = true, + std::array target_shape = {8, 128}); std::unique_ptr> createInferVectorLayoutPass( int hardware_generation = -1, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 71e48539f4dd..645e6d615722 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ 
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <algorithm>
+#include <array>
 #include <cstdint>
 #include <functional>
 #include <memory>
@@ -66,6 +67,8 @@ struct CanonicalizeContext {
   bool compatibility_mode;
 
   int hardware_generation;
+
+  std::array<int64_t, 2> target_shape;
 };
 
 bool need_elementwise_canonicalization(const CanonicalizeContext &ctx,
@@ -753,6 +756,149 @@ LogicalResult canonicalize_vector_transpose(const CanonicalizeContext &ctx,
   return success();
 }
 
+LogicalResult canonicalize_reshape(const CanonicalizeContext &ctx,
+                                   Operation &raw_op) {
+  auto op = cast<vector::ShapeCastOp>(raw_op);
+  // We can canonicalize some reshape(load(x)) -> strided load + ALU ops.
+  auto src = op.getSource();
+  auto src_ty = src.getType();
+  auto tgt_ty = op.getType();
+  if (auto load_op = src.getDefiningOp<vector::LoadOp>()) {
+    // Pattern match (..., M, N, 128) -> (..., M, N * 128).
+    // This reshape can be folded into the load for any dtype and tiling
+    // as long as the minormost dim is 128 and N is aligned to packing. The
+    // pseudo code is:
+    // ```
+    // src_ref: (M, N, 128) with src_ty
+    //
+    // def load_to_reshape(src_ref):
+    //   b_ref = src_ref.bitcast(i32)  # i32[M, N / packing, 128]
+    //   r_ref = b_ref.reshape(M * N / packing, 128)
+    //   chunks = []
+    //   for i in range(N / packing):
+    //     v = r_ref[i::N / packing, :]  # i32[M, 128]
+    //     for j in range(packing):
+    //       chunk = v >> (j * bitwidth)
+    //       chunks.append(chunk)
+    //   res = concat(chunks, axis=-1)  # i32[M, N * 128]
+    //   # int_src_ty refers to int type with the same bitwidth as src_ty.
+    //   res = res.astype(int_src_ty)  # Trigger i32 -> int_src_ty packing.
+    //   return bitcast(res, src_ty)  # src_ty[M, N * 128]
+    // ```
+    // TODO(jevinjiang): we can extend this to support folding more dims to last
+    // dim not just last 2 dims.
+    auto bitwidth = src_ty.getElementTypeBitWidth();
+    auto packing = 32 / bitwidth;
+    if (packing <= 0) {
+      return op.emitOpError("Unsupported bitwidth = ") << bitwidth;
+    }
+    // Memref bitcast is not supported if HW generation is below 4. We don't
+    // return failure because we will rely on vector reshape.
+    if ((ctx.hardware_generation < 4 && packing > 1) ||
+        (ctx.hardware_generation == 4 && packing > 2)) {
+      return success();
+    }
+    auto ref = load_op.getBase();
+    auto indices = load_op.getIndices();
+    auto ref_shape = ref.getType().getShape();
+    auto src_shape = src_ty.getShape();
+    auto tgt_shape = tgt_ty.getShape();
+    int ref_rank = ref_shape.size();
+    int src_rank = src_shape.size();
+    int tgt_rank = tgt_shape.size();
+    if (ref_rank != src_rank) {
+      return op.emitOpError("Loaded vector rank and memref rank mismatch");
+    }
+    // Check the memref's eligibility.
+    if (!isContiguousMemref(ref) || ref_rank <= 2 ||
+        // TODO(jevinjiang): add support for partial load on last 2 dims where
+        // last 2 indices are not necessarily 0 or load shape is not full.
+        getIntConst(indices[ref_rank - 1]) != 0 ||
+        getIntConst(indices[ref_rank - 2]) != 0 ||
+        ref_shape[ref_rank - 1] != src_shape[src_rank - 1] ||
+        ref_shape[ref_rank - 2] != src_shape[src_rank - 2]) {
+      return success();
+    }
+    // Check the reshape's eligibility.
+    if (src_rank != tgt_rank + 1 || src_shape[src_rank - 2] % packing != 0 ||
+        src_shape[src_rank - 1] != ctx.target_shape[1] ||
+        src_shape[src_rank - 2] * src_shape[src_rank - 1] !=
+            tgt_shape[tgt_rank - 1]) {
+      return success();
+    }
+    // At this point, the pattern is matched.
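+    // For illustration, tracing the bf16(256, 8, 128) -> bf16(256, 8 * 128)
+    // example from the change description: bitwidth = 16 so packing = 2, the
+    // ref is bitcast to i32(256, 4, 128) and reshaped to i32(1024, 128), the
+    // stride is 4, and each of the 4 strided loads yields an i32(256, 128)
+    // chunk that is unpacked into 2 shifted copies; the resulting 8 chunks
+    // concatenate into i32(256, 1024) before being packed back to bf16.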
+    ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
+    auto loc = op.getLoc();
+    // First, we bitcast and reshape src ref from (..., M, N, 128) to
+    // i32(..., M * N / packing, 128).
+    SmallVector<int64_t> bitcast_shape(ref_shape);
+    // TODO(jevinjiang): once we have memref pad op, we can use ceiling
+    // division to ref_shape[ref_rank - 2] and packing to get sublane_cnt.
+    CHECK_EQ(ref_shape[ref_rank - 2] % packing, 0);
+    auto i32_2nd_minor_size = ref_shape[ref_rank - 2] / packing;
+    bitcast_shape[ref_rank - 2] = i32_2nd_minor_size;
+    auto i32_ref = builder.create<tpu::MemRefBitcastOp>(
+        MemRefType::get(bitcast_shape, builder.getI32Type()), ref);
+
+    SmallVector<int64_t> reshape_shape(ref_shape.begin(),
+                                       ref_shape.begin() + tgt_rank);
+    reshape_shape[tgt_rank - 1] = ctx.target_shape[1];
+    reshape_shape[tgt_rank - 2] = ref_shape[ref_rank - 3] * i32_2nd_minor_size;
+    auto reshape_ref = builder.create<tpu::MemRefReshapeOp>(
+        MemRefType::get(reshape_shape, builder.getI32Type()), i32_ref);
+
+    // We also need to transform the indices while transforming the memref.
+    SmallVector<Value> new_indices(indices.begin(), indices.begin() + tgt_rank);
+    new_indices[tgt_rank - 1] = IdxConst(0, builder, loc);
+    new_indices[tgt_rank - 2] = builder.create<arith::MulIOp>(
+        builder.getIndexType(), indices[ref_rank - 3],
+        IdxConst(i32_2nd_minor_size, builder, loc));
+    // Then, we strided load the bitcasted ref by stride (N / packing).
+    int stride = i32_2nd_minor_size;
+    // Expect to hold src_shape[src_rank - 2] number of chunks which have the
+    // shape (..., src_shape[src_rank - 3], 128) and wait to be concatenated
+    // along the last dim.
+    SmallVector<Value> chunks(src_shape[src_rank - 2]);
+    SmallVector<int64_t> chunk_shape(tgt_shape);
+    chunk_shape[tgt_rank - 1] = ctx.target_shape[1];
+    SmallVector<int32_t> strides(tgt_rank, 1);
+    strides[tgt_rank - 2] = stride;
+    auto tgt_2nd_minor_idx = new_indices[tgt_rank - 2];
+    for (int i = 0; i < stride; ++i) {
+      new_indices[tgt_rank - 2] = builder.create<arith::AddIOp>(
+          builder.getIndexType(), tgt_2nd_minor_idx, IdxConst(i, builder, loc));
+      auto chunk = builder.create<tpu::StridedLoadOp>(
+          VectorType::get(chunk_shape, builder.getI32Type()), reshape_ref,
+          new_indices, strides);
+      for (int j = 0; j < packing; ++j) {
+        int idx = i * packing + j;
+        chunks[idx] = builder.create<arith::ShRUIOp>(
+            chunk.getType(), chunk,
+            I32Const(j * bitwidth, chunk_shape, builder, loc));
+      }
+    }
+    // Concatenate the chunks along the last dim to get i32(..., M, N * 128).
+    CHECK_GT(chunks.size(), 0);
+    Value i32_tgt = chunks[0];
+    if (chunks.size() > 1) {
+      i32_tgt = builder.create<tpu::ConcatenateOp>(
+          VectorType::get(tgt_shape, builder.getI32Type()), chunks,
+          /*dimension=*/tgt_rank - 1);
+    }
+    Value tgt = i32_tgt;
+    // Convert to target dtype.
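+    // (Truncating the i32 lanes to the source bitwidth re-packs the values,
+    // and the bitcast below reinterprets the result as the target dtype.)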
+    if (packing > 1) {
+      tgt = builder.create<arith::TruncIOp>(
+          VectorType::get(tgt_shape, builder.getIntegerType(bitwidth)),
+          i32_tgt);
+    }
+    tgt = builder.create<tpu::BitcastOp>(tgt_ty, tgt);
+    op.replaceAllUsesWith(tgt);
+    op.erase();
+  }
+  return success();
+}
+
 using canonicalize_rule_type = std::function<LogicalResult(
     const CanonicalizeContext &ctx, Operation &op)>;
@@ -764,6 +910,7 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
       {vector::MultiDimReductionOp::getOperationName(),
        canonicalize_multi_dim_reduction},
       {vector::TransposeOp::getOperationName(), canonicalize_vector_transpose},
+      {vector::ShapeCastOp::getOperationName(), canonicalize_reshape},
      {arith::SelectOp::getOperationName(), canonicalize_select},
       {arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
       {arith::SIToFPOp::getOperationName(), canonicalize_sitofp},
@@ -808,12 +955,15 @@ bool need_elementwise_canonicalization(const CanonicalizeContext &ctx,
 
 class MosaicCanonicalizer {
  public:
-  MosaicCanonicalizer(int hardware_generation, bool compatibility_mode)
+  MosaicCanonicalizer(int hardware_generation, bool compatibility_mode,
+                      std::array<int64_t, 2> target_shape)
       : hardware_generation_(hardware_generation),
-        compatibility_mode_(compatibility_mode) {}
+        compatibility_mode_(compatibility_mode),
+        target_shape_(target_shape) {}
 
   int hardware_generation_;
   bool compatibility_mode_;
+  std::array<int64_t, 2> target_shape_;
 
   LogicalResult canonicalize(func::FuncOp op) {
     if (!op.getBody().hasOneBlock()) {
@@ -834,7 +984,8 @@ class MosaicCanonicalizer {
   }
 
   LogicalResult canonicalizeOp(Operation &any_op) {
-    CanonicalizeContext ctx({compatibility_mode_, hardware_generation_});
+    CanonicalizeContext ctx(
+        {compatibility_mode_, hardware_generation_, target_shape_});
     // We must iterate over the op first, because canonicalization can cause
     // us to .erase() an op, and accessing getRegions on it after is not sound.
     // Invariant - top level ops with regions may never be invalidated.
@@ -859,14 +1010,18 @@ class MosaicCanonicalizer {
 
 struct CanonicalizeMosaicPass
     : public impl::CanonicalizeMosaicPassBase<CanonicalizeMosaicPass> {
-  CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p)
+  CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p,
+                         std::array<int64_t, 2> target_shape)
       : compatibility_mode_(compatibility_mode_p) {
     this->hardware_generation = hardware_generation_p;
+    this->sublane_count = target_shape[0];
+    this->lane_count = target_shape[1];
   }
 
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-    MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_);
+    MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_,
+                            {sublane_count, lane_count});
     if (vlc.canonicalize(func).failed()) {
       signalPassFailure();
    }
@@ -878,9 +1033,10 @@ struct CanonicalizeMosaicPass
 
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
-    int hardware_generation, bool compatibility_mode) {
-  return std::make_unique<CanonicalizeMosaicPass>(hardware_generation,
-                                                  compatibility_mode);
+    int hardware_generation, bool compatibility_mode,
+    std::array<int64_t, 2> target_shape) {
+  return std::make_unique<CanonicalizeMosaicPass>(
+      hardware_generation, compatibility_mode, target_shape);
 }
 
 }  // namespace mlir::tpu
diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc
index bb42c678bbf6..b562f81ad534 100644
--- a/jaxlib/mosaic/dialect/tpu/util.cc
+++ b/jaxlib/mosaic/dialect/tpu/util.cc
@@ -27,7 +27,9 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" @@ -159,6 +161,17 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, *(tiled_layout.getTileStrides().end() - 2) == 1; } +bool isContiguousMemref(TypedValue memref) { + auto memref_ty = getMemRefType(memref); + if (auto tiled_layout = + dyn_cast(memref_ty.getLayout())) { + auto contiguous_tile_strides = ComputeTileStrides( + memref_ty, tiled_layout.getTiles().front().dimensions()); + return contiguous_tile_strides == tiled_layout.getTileStrides(); + } + return true; +} + bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) { auto memory_space = dyn_cast_or_null(ty.getMemorySpace()); @@ -278,4 +291,14 @@ void setLayout(Operation *op, ArrayRef in, ArrayRef out) { setInLayout(op, in); setOutLayout(op, out); } + +std::optional getIntConst(Value v) { + if (auto const_op = v.getDefiningOp()) { + if (auto cst_attr = dyn_cast(const_op.getValue())) { + return cst_attr.getValue().getSExtValue(); + } + } + return std::nullopt; +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index b9aea1b087dc..000cb4411e62 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -195,11 +195,6 @@ ArrayRef> toArrayRef(absl::Span span) { return ArrayRef>(span.data(), span.size()); } -inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder, - Location loc) { - return builder.create(loc, builder.getIndexType(), - builder.getIndexAttr(idx)); -} // Debug only util. template @@ -242,6 +237,8 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, const std::array &target_shape, bool allow_minormost_padding = false); +bool isContiguousMemref(TypedValue memref); + // Determines whether the given MemRefType has the given memory space. bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space); @@ -264,6 +261,30 @@ void setLayout(Operation *op, Layout in, Layout out); void setLayout(Operation *op, ArrayRef in, Layout out); void setLayout(Operation *op, Layout in, ArrayRef out); void setLayout(Operation *op, ArrayRef in, ArrayRef out); + +// Helper functions to create constants. +inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder, + Location loc) { + return builder.create(loc, builder.getIndexType(), + builder.getIndexAttr(idx)); +} + +inline arith::ConstantOp I32Const(int32_t value, OpBuilder &builder, + Location loc) { + return builder.create(loc, builder.getI32Type(), + builder.getI32IntegerAttr(value)); +} + +inline arith::ConstantOp I32Const(int32_t value, ArrayRef shape, + OpBuilder &builder, Location loc) { + return builder.create( + loc, DenseElementsAttr::get( + VectorType::get(shape, builder.getI32Type()), + builder.getIntegerAttr(builder.getI32Type(), value))); +} + +// TODO(jevinjiang): consolidate this with getIntConst in apply-vector-layout. 
+std::optional<int64_t> getIntConst(Value v);
 
 }  // namespace mlir::tpu
 
 #endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_
diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py
index 1fb0bc24701b..de87126ebd3f 100644
--- a/tests/pallas/tpu_ops_test.py
+++ b/tests/pallas/tpu_ops_test.py
@@ -38,16 +38,36 @@
 jax.config.parse_flags_with_absl()
 jtu.setup_hypothesis(max_examples=100)
 
-_JAX_DTYPES = (
+_JAX_DTYPES_NO_BOOL = (
     jnp.float32,
     jnp.bfloat16,
     jnp.int32,
     jnp.int16,
     jnp.int8,
+    jnp.int4,
+    jnp.float8_e5m2,
+)
+
+_JAX_DTYPES = (
+    *_JAX_DTYPES_NO_BOOL,
     jnp.bool_,
 )
 
+
+def rand(
+    shape: tuple[int, ...], dtype: np.dtype | jnp.dtype, seed: int = 1234
+) -> np.ndarray:
+  """A helper function to generate random data for testing."""
+  rng = np.random.Generator(np.random.Philox(counter=0, key=seed))
+  if jnp.issubdtype(dtype, jnp.floating):
+    return rng.normal(size=shape).astype(dtype)
+  if jnp.issubdtype(dtype, jnp.integer):
+    return rng.integers(
+        jnp.iinfo(dtype).min, jnp.iinfo(dtype).max, shape, dtype=np.int32
+    ).astype(dtype)
+  raise NotImplementedError(f"Unsupported random data generation for {dtype=}")
+
+
 class PallasBaseTest(jtu.JaxTestCase):
   INTERPRET = False
 
@@ -511,6 +531,45 @@ def kernel(src, tgt):
         output[tuple(slice(0, d) for d in src_shape)], x
     )
 
+  # TODO(jevinjiang): we need to support strided load for bool.
+  @parameterized.product(dtype=_JAX_DTYPES_NO_BOOL)
+  @hp.given(
+      slice_start=hps.integers(0, 3),
+      slice_size=hps.integers(1, 3),
+      m=hps.integers(1, 32),
+      # Need to make sure the 2nd minor has no padding.
+      n=hps.sampled_from([1, 2, 4, 8, 16, 24, 32]),
+  )
+  @hp.settings(max_examples=20)  # 20 examples for each dtype.
+  def test_load_to_reshape(self, dtype, slice_start, slice_size, m, n):
+    if not jtu.if_cloud_tpu_at_least(2025, 5, 15):
+      self.skipTest("Requires libtpu built after 2025-05-15")
+    bitwidth = pallas_utils.dtype_bitwidth(dtype)
+    if jtu.get_tpu_version() < 4 and bitwidth != 32:
+      self.skipTest("Requires TPUv4+ for non-32-bit types")
+    if jtu.get_tpu_version() == 4 and bitwidth <= 8:
+      self.skipTest("Int8 is not supported on this target")
+    packing = 32 // bitwidth
+    n *= packing
+    slices = (
+        slice(slice_start, slice_start + slice_size),
+        slice(slice_start, slice_start + m),
+        slice(None),
+        slice(None),
+    )
+    inp_shape = (8, 64, n, 128)
+    out_shape = (slice_size, m, n * 128)
+
+    def kernel(inp_ref, out_ref):
+      inp = inp_ref[slices]
+      out_ref[...] = inp.reshape(out_shape)
+
+    inp = rand(inp_shape, dtype, seed=1234)
+    run = pl.pallas_call(kernel, jax.ShapeDtypeStruct(out_shape, dtype))
+    output = run(inp)
+    expected = inp[slices].reshape(out_shape)
+    np.testing.assert_array_equal(output, expected)
+
 
 @jtu.thread_unsafe_test_class()  # hypothesis is not thread safe
 class OpsInterpretTest(OpsTest):

From 49e3432ba005a69e9287f7ee60216d7f390e0eae Mon Sep 17 00:00:00 2001
From: Kanglan Tang
Date: Tue, 13 May 2025 15:50:39 -0700
Subject: [PATCH 1158/1769] Pass RULES_PYTHON_REPO_DEBUG value to the docker container env
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Set RULES_PYTHON_REPO_DEBUG environment variable to 0 by default. If set
to 1, repository rules will print debug information about what they're
doing. This is needed to help us debug python 3.14 builds.
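A minimal Python sketch of the defaulting behavior (illustrative only, not
part of the patch): the CI script below forwards the variable with a
`${RULES_PYTHON_REPO_DEBUG:-0}` shell expansion, which is equivalent to

    import os

    # Fall back to "0" (quiet) unless the caller explicitly exported "1",
    # mirroring the `${RULES_PYTHON_REPO_DEBUG:-0}` default in the script.
    debug = os.environ.get("RULES_PYTHON_REPO_DEBUG", "0")
    if debug == "1":
        print("repository rules will log what they are doing")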
PiperOrigin-RevId: 758407455 --- ci/utilities/run_docker_container.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci/utilities/run_docker_container.sh b/ci/utilities/run_docker_container.sh index b12566182331..e0a4592cdf6f 100755 --- a/ci/utilities/run_docker_container.sh +++ b/ci/utilities/run_docker_container.sh @@ -56,6 +56,8 @@ if ! docker container inspect jax >/dev/null 2>&1 ; then # variables to the container. JAXCI_TEMP_ENVFILE_DIR=$(mktemp) env | grep -e "JAXCI_" -e "JAX_" -e "JAXLIB_" > "$JAXCI_TEMP_ENVFILE_DIR" + # TODO(kanglan): Remove this once the rules python debug is done. + echo "RULES_PYTHON_REPO_DEBUG=${RULES_PYTHON_REPO_DEBUG:-0}" >> "$JAXCI_TEMP_ENVFILE_DIR" # On Windows, convert MSYS Linux-like paths to Windows paths. if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then From c7a4e34bea51921b1d359be37ef0b5d59de944ab Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 13 May 2025 16:14:03 -0700 Subject: [PATCH 1159/1769] Deprecate the no-op custom_jvp_call_jaxpr_p import stub. The `custom_jvp_call_jaxpr_p` primitive has not been used for a long time, and the existing object is just an import stub. Let's try to clean up some @mattjj TODOs! Sadly, since this lives in the public API, I think we need to do a full deprecation cycle, so let's at least get that started! PiperOrigin-RevId: 758415957 --- CHANGELOG.md | 4 ++++ jax/custom_derivatives.py | 21 ++++++++++++++++++++- jax/extend/core/primitives.py | 1 - 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 518e854b5bb1..a0c30132c169 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.ShapeDtypeStruct` is immutable now. Please use `.update` method to update your `ShapeDtypeStruct` instead of doing in-place updates. +* Deprecations + * `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated, and will be + removed in JAX v0.7.0. 
+ ## JAX 0.6.0 (April 16, 2025) * Breaking changes diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index b768b687dfad..6674046dd8e8 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -23,7 +23,7 @@ custom_gradient as custom_gradient, custom_jvp as custom_jvp, custom_jvp_call_p as custom_jvp_call_p, - custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, + custom_jvp_call_jaxpr_p as _custom_jvp_call_jaxpr_p, custom_vjp as custom_vjp, custom_vjp_call_p as custom_vjp_call_p, custom_vjp_primal_tree_values as custom_vjp_primal_tree_values, @@ -36,3 +36,22 @@ SymbolicZero as SymbolicZero, zero_from_primal as zero_from_primal ) + +_deprecations = { + # Added May 12, 2025 + "custom_jvp_call_jaxpr_p": ( + ("jax.custom_derivatives.custom_jvp_call_jaxpr_p is deprecated, use " + "jax.extend.core.primitives.custom_jvp_call_p instead."), + _custom_jvp_call_jaxpr_p, + ), +} + +import typing +if typing.TYPE_CHECKING: + custom_jvp_call_jaxpr_p = _custom_jvp_call_jaxpr_p +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _custom_jvp_call_jaxpr_p diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index 30350dace637..515dd3e11dcf 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -24,7 +24,6 @@ from jax._src.custom_derivatives import ( custom_jvp_call_p as custom_jvp_call_p, - custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, custom_vjp_call_p as custom_vjp_call_p, ) From e2c4a7f53fb95ae30ef2ff74fcb7107e05a94aa9 Mon Sep 17 00:00:00 2001 From: Keith Rush Date: Tue, 13 May 2025 16:15:29 -0700 Subject: [PATCH 1160/1769] Global find/replace in jax for s/divisble/divisible. Spotted this in an error message, then saw there were a few other places with this typo. PiperOrigin-RevId: 758416484 --- jax/_src/lax/slicing.py | 5 +++-- jaxlib/mosaic/gpu/runtime.cc | 2 +- tests/pallas/pallas_jumble_test.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 9f4645dca975..ad8a2cf0b315 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1353,9 +1353,10 @@ def _get_sharding_for_varying_out_shape(out_shape, operand, name): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): raise core.ShardingTypeError( - f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" + f"{name} on sharded dims where out dim ({out_sh}) is not divisible by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" - f" ({op_spec}) is not implemented.") + f" ({op_spec}) is not implemented." + ) # TODO(yashkatariya): Returning operand.sharding as is may or may not move # data. So think about how to avoid it which might include creating a new # mesh? 
For example: diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index cb48a20dc3d5..da7b0159d7b2 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -115,7 +115,7 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, if (tma_stride_i % 16 != 0 || tma_stride_i >= static_cast(1) << 40) { fprintf(stderr, - "Byte strides must be divisble by 16 and less than 2**40, but " + "Byte strides must be divisible by 16 and less than 2**40, but " "got %ld (item stride = %ld, item size = %ld) at index %ld\n", tma_stride_i, strides[rank - 1], elem_bytewidth, rank - i - 2); abort(); diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py index 509ef08a987f..0a2994a84a8f 100644 --- a/tests/pallas/pallas_jumble_test.py +++ b/tests/pallas/pallas_jumble_test.py @@ -354,7 +354,7 @@ def invoke_kernel(x): with self.assertRaisesRegex( ValueError, - "Ragged input shape must be evenly divisble by the grid" # noqa: W605 + "Ragged input shape must be evenly divisible by the grid" # noqa: W605 " size at the ragged dimension 2", ): jax.vmap( From 9b2043e4f2b434d640a9460fc730a7b7300d9d82 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 13 May 2025 17:15:23 -0700 Subject: [PATCH 1161/1769] Prepare to make DmaCopyChunk movable. PiperOrigin-RevId: 758436699 --- jaxlib/py_socket_transfer.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc index ed2a4f4c204a..114e3c14874d 100644 --- a/jaxlib/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -132,9 +132,9 @@ class IfrtArrayEntry : public PullTable::Entry { std::min(xfer_size_, arrs_[bid].buf_size - i * xfer_size_)); bool is_largest = blob.size + blob.offset == arrs_[bid].buf_size; state_->ScheduleCopy( - blob, [req_id, state, copier_state = state_, is_largest]( - PremappedCopierState* copier_state_ptr, void* buf, - const DmaCopyChunk& chunk) { + std::move(blob), [req_id, state, copier_state = state_, is_largest]( + PremappedCopierState* copier_state_ptr, + void* buf, const DmaCopyChunk& chunk) { state->Send( req_id, buf, chunk.offset, chunk.size, is_largest, [copier_state, buf]() { copier_state->ReturnBuffer(buf); }); From dfb5a93f3740c5384125a4f7791e4f8ccbb8575b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 May 2025 21:30:19 -0700 Subject: [PATCH 1162/1769] allow accelerated deprecations from jax.tree_util PiperOrigin-RevId: 758500294 --- jax/tree_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/tree_util.py b/jax/tree_util.py index 3d24c457b3f8..9f42284144ec 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -90,5 +90,5 @@ else: from jax._src.deprecations import deprecation_getattr __getattr__ = deprecation_getattr(__name__, _deprecations) - del deprecation_getattr, _deprecations + del deprecation_getattr del _typing From 3d64dd7895fe346646c04e031522c77d319f16fd Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 May 2025 04:21:21 -0700 Subject: [PATCH 1163/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9c00cd639a2076210ecc68e381ad242b95d46be4. 
PiperOrigin-RevId: 758618886
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index a4a7c8cb7dcd..e116decd5c4b 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 #    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 #    and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "c9b736e8c217529795badfe9cf7730b5d0f38242"
-XLA_SHA256 = "d5fb6fb909a81838e4fa5f5dc755365a10132bc4acec6ee18df815d3cda6df2f"
+XLA_COMMIT = "9c00cd639a2076210ecc68e381ad242b95d46be4"
+XLA_SHA256 = "4987bf859992058d6c4bfc5e1b9300998a3c1aa20dc6c74e880b5b473668a606"
 
 def repo():
     tf_http_archive(

From 67e1d5c2c8baf4face571e637deddd363cd864b9 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 14 May 2025 04:37:39 -0700
Subject: [PATCH 1164/1769] Simplify if_building_jaxlib macro.

I don't believe the GPU case of this macro ever matters, so this can be a
condition strictly about how we're building jaxlib.

No behavioral changes intended.

PiperOrigin-RevId: 758624138
---
 jax/BUILD      | 20 +++++++++++++++-----
 jaxlib/jax.bzl | 38 +++++++++++++++-----------------------
 2 files changed, 30 insertions(+), 28 deletions(-)

diff --git a/jax/BUILD b/jax/BUILD
index 16bc9de6935e..5f1cf1670729 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -64,12 +64,26 @@ string_flag(
 )
 
 config_setting(
-    name = "enable_jaxlib_build",
+    name = "config_build_jaxlib_true",
     flag_values = {
         ":build_jaxlib": "true",
     },
 )
 
+config_setting(
+    name = "config_build_jaxlib_false",
+    flag_values = {
+        ":build_jaxlib": "false",
+    },
+)
+
+config_setting(
+    name = "config_build_jaxlib_wheel",
+    flag_values = {
+        ":build_jaxlib": "wheel",
+    },
+)
+
 # The flag controls whether jax should be built by Bazel.
 # If ":build_jax=true", then jax will be built.
 # If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel
@@ -212,7 +226,6 @@ py_library(
             ":jax",
         ],
         if_not_building = [],
-        if_not_building_for_cpu = [],
     ) + py_deps("numpy"),
 )
 
@@ -229,7 +242,6 @@ py_library(
             "//jax/_src/lib",
         ],
         if_not_building = [],
-        if_not_building_for_cpu = [],
     ) + py_deps("numpy"),
 )
 
@@ -243,7 +255,6 @@ py_library(
             ":test_util",
         ],
         if_not_building = [],
-        if_not_building_for_cpu = [],
     ),
 )
 
@@ -259,7 +270,6 @@ py_library(
             ":test_util",
         ],
         if_not_building = [],
-        if_not_building_for_cpu = [],
     ) + py_deps("numpy"),
 )
 
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index e739b681a029..a48c44f406f2 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -158,28 +158,20 @@ def if_building_jaxlib(
         if_building,
         if_not_building = [
             "@pypi//jaxlib",
-            "@pypi//jax_cuda12_plugin",
-            "@pypi//jax_cuda12_pjrt",
-        ],
-        if_not_building_for_cpu = [
-            "@pypi//jaxlib",
         ]):
-    """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.
+    """Adds jaxlib wheels as dependencies instead of depending on sources.
 
     This allows us to test prebuilt versions of jaxlib wheels against the
    rest of the JAX codebase.
 
    Args:
      if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels
-      if_not_building: the wheels to depend on including gpu-specific plugins in case of
-                       gpu-enabled builds
-      if_not_building_for_cpu: the wheels to depend on in case of cpu-only builds
+      if_not_building: the wheels to depend on if we are not depending directly on //jaxlib.
""" return select({ - "//jax:enable_jaxlib_build": if_building, - "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, - "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, - "//conditions:default": [], + "//jax:config_build_jaxlib_true": if_building, + "//jax:config_build_jaxlib_false": if_not_building, + "//jax:config_build_jaxlib_wheel": [], }) def _get_test_deps(deps, backend_independent): @@ -192,14 +184,14 @@ def _get_test_deps(deps, backend_independent): Returns: A list of test deps for the given backend. For CPU builds: - If --//jax:enable_jaxlib_build=true, returns pypi test deps. - If --//jax:enable_jaxlib_build=false, returns jaxlib pypi wheel dep and pypi test deps. - If --//jax:enable_jaxlib_build=wheel, returns jaxlib py_import dep and pypi test deps. + If --//jax:build_jaxlib=true, returns pypi test deps. + If --//jax:build_jaxlib=false, returns jaxlib pypi wheel dep and pypi test deps. + If --//jax:build_jaxlib=wheel, returns jaxlib py_import dep and pypi test deps. For GPU builds: - If --//jax:enable_jaxlib_build=true, returns pypi test deps and gpu build deps. - If --//jax:enable_jaxlib_build=false, returns jaxlib, jax-cuda-plugin, + If --//jax:build_jaxlib=true, returns pypi test deps and gpu build deps. + If --//jax:build_jaxlib=false, returns jaxlib, jax-cuda-plugin, jax-cuda-pjrt pypi wheel deps and pypi test deps. - If --//jax:enable_jaxlib_build=wheel, returns jaxlib, + If --//jax:build_jaxlib=wheel, returns jaxlib, jax-cuda-plugin, jax-cuda-pjrt py_import deps and pypi test deps. """ gpu_build_deps = [ @@ -234,7 +226,7 @@ def _get_test_deps(deps, backend_independent): gpu_py_import_deps = gpu_py_imports return select({ - "//jax:enable_jaxlib_build": test_deps, + "//jax:config_build_jaxlib_true": test_deps, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": jaxlib_pypi_wheel_deps, "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps, "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports, @@ -250,9 +242,9 @@ def _get_jax_test_deps(deps): Returns: A list of jax test deps. - If --//jax:enable_jax_build=true, returns jax build deps. - If --//jax:enable_jax_build=false, returns jax pypi wheel dep and transitive pypi test deps. - If --//jax:enable_jax_build=wheel, returns jax py_import dep and transitive pypi test deps. + If --//jax:build_jax=true, returns jax build deps. + If --//jax:build_jax=false, returns jax pypi wheel dep and transitive pypi test deps. + If --//jax:build_jax=wheel, returns jax py_import dep and transitive pypi test deps. 
""" jax_build_deps = [d for d in deps if not d.startswith("@pypi//")] From 08a0485a49378a8ba688cc06a9da925cafdb4c42 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 14 May 2025 13:47:35 +0200 Subject: [PATCH 1165/1769] Removed few tsan cpython suppressions as fixed --- .github/workflows/tsan-suppressions_3.13.txt | 4 ---- .github/workflows/tsan-suppressions_3.14.txt | 4 ---- 2 files changed, 8 deletions(-) diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt index e82699036e92..aec94dfef004 100644 --- a/.github/workflows/tsan-suppressions_3.13.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -40,7 +40,3 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/132245 race:split_keys_entry_added race_top:dict_dict_merge - -# https://github.com/python/cpython/issues/132013 -# Fixed on 3.14 and not backported to 3.13 -race_top:frozenset_hash \ No newline at end of file diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt index ec4d81c987d0..ec5102502a2b 100644 --- a/.github/workflows/tsan-suppressions_3.14.txt +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -18,7 +18,3 @@ race:dscal_k_ race:scal_k_ race:gemm_beta race:gemm_oncopy - -# https://github.com/python/cpython/issues/132214 -# Should be fixed -# race_top:update_one_slot From 50970c6bd21bb82e49b4c5c862ee65f7f0d450cd Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 May 2025 05:49:58 -0700 Subject: [PATCH 1166/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c7dbe1e216d132e7b4042907b8ba9454535edc70. PiperOrigin-RevId: 758642769 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e116decd5c4b..ca75a0be471c 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9c00cd639a2076210ecc68e381ad242b95d46be4" -XLA_SHA256 = "4987bf859992058d6c4bfc5e1b9300998a3c1aa20dc6c74e880b5b473668a606" +XLA_COMMIT = "c7dbe1e216d132e7b4042907b8ba9454535edc70" +XLA_SHA256 = "fea578d5f53daec13f9e84c615a2e2ba783b1886c471ff55897505259de6b137" def repo(): tf_http_archive( From 7f5b6e7d02656d3b5f116fe557fcdc0c365f88ee Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 14 May 2025 06:41:48 -0700 Subject: [PATCH 1167/1769] [Pallas/Mosaic GPU] Optimize the construction of output buffers for `plgpu.kernel`. Previously to this change, we would zero initialize all buffers---which was a big overhead. Instead, we now generate an empty custom call in order to only generate an allocation. 
PiperOrigin-RevId: 758659576 --- jax/_src/pallas/mosaic_gpu/core.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 21c78720e812..e4fe1a7035a7 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -33,6 +33,7 @@ from jax._src import tree_util from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives as pallas_primitives import jax._src.pallas.utils as pallas_utils from jax._src.state import discharge as state_discharge @@ -194,12 +195,24 @@ def cmap_body(): mesh, compiler_params=compiler_params )(cmap_body) _, outs = state_discharge.run_state(stateful)( - (operands, jax.tree.map(jnp.zeros_like, out_shape)) + (operands, empty_like(out_shape)) ) return outs[0] if unwrap_out else outs return wrapper +def empty_like(shape): + return pallas_call.pallas_call( + lambda *_: None, + out_shape=shape, + out_specs=jax.tree.map( + lambda _: pallas_core.BlockSpec(memory_space=GPUMemorySpace.GMEM), + shape, + ), + backend="mosaic_gpu", + )() + + def _is_known_divisible(value, divisor, fuel=10) -> bool: """Returns True if the value is statically known to be divisible by the divisor.""" if fuel < 0: From af66ca9cfe740e00b9f37b0458aa431187b6b65b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 14 May 2025 06:44:58 -0700 Subject: [PATCH 1168/1769] [pallas:mosaic_gpu] Do not do unnecessary `commit_smem_to_gmem_group` in `emit_pipeline` If the loop does no copies, there is nothing to commit. PiperOrigin-RevId: 758660634 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 41ddf993df3a..4966eec7fa6d 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -326,7 +326,8 @@ def loop_body(step, carry): predicate=lax.bitwise_or(slices_changed, is_last_step), ) - gpu_primitives.commit_smem_to_gmem_group() + if copies_out_in_loop: + gpu_primitives.commit_smem_to_gmem_group() fetch_step = step + (max_concurrent_steps - delay_release) fetch_slot = lax.rem(fetch_step, max_concurrent_steps) @@ -367,6 +368,7 @@ def do_fetch(): # loop. This is the only place where we store them. if not copies_out_in_loop: gpu_primitives.commit_smem() + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: From 85a68fcd437c95780063d196a9c944d4c7fab896 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 14 May 2025 07:15:39 -0700 Subject: [PATCH 1169/1769] [Pallas/Mosaic GPU] Allow `_is_known_divisible` to always return `True` when the divisor is `1`. 
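For intuition, a hedged Python sketch of the shortcut (the real check walks
the defining MLIR operations with a recursion fuel, as the surrounding diff
shows; this standalone version is illustrative only):

    def is_known_divisible(value, divisor: int, fuel: int = 10) -> bool:
      # Everything divides evenly by 1, so no IR inspection is needed.
      if divisor == 1:
        return True
      # Only claim divisibility when it is statically certain.
      if isinstance(value, int):
        return value % divisor == 0
      return False  # unknown at trace time: stay conservative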
PiperOrigin-RevId: 758670548 --- jax/_src/pallas/mosaic_gpu/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index e4fe1a7035a7..4a5cfbf3517f 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -215,6 +215,8 @@ def empty_like(shape): def _is_known_divisible(value, divisor, fuel=10) -> bool: """Returns True if the value is statically known to be divisible by the divisor.""" + if divisor == 1: + return True if fuel < 0: return False if not isinstance(value.owner, ir.Operation): From abc79f3d590360a48245d051c62fdc23351a90d6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 14 May 2025 07:22:32 -0700 Subject: [PATCH 1170/1769] Replace gsutil command with gcloud storage commands. GCP recommends using `gcloud storage` instead of `gsutil` https://cloud.google.com/storage/docs/gsutil PiperOrigin-RevId: 758672654 --- .github/workflows/bazel_cuda_non_rbe.yml | 8 ++++---- .github/workflows/build_artifacts.yml | 4 ++-- .github/workflows/pytest_cpu.yml | 10 +++++----- .github/workflows/pytest_cuda.yml | 8 ++++---- .github/workflows/pytest_tpu.yml | 4 ++-- .github/workflows/wheel_tests_nightly_release.yml | 8 ++++---- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 72878ad7aacb..677d8d869a22 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -84,12 +84,12 @@ jobs: continue-on-error: true run: | mkdir -p $(pwd)/dist - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then PYTHON=python${{ inputs.python }} $PYTHON -m pip download jaxlib jax-cuda12-pjrt jax-cuda12-plugin --dest $(pwd)/dist/ diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 37a791784506..1b534ee3b6fc 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -136,13 +136,13 @@ jobs: - name: Upload artifacts to a GCS bucket (non-Windows runs) if: >- ${{ inputs.upload_artifacts_to_gcs && !contains(inputs.runner, 'windows-x86') }} - run: gsutil -m cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ + run: gcloud storage cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ # Set shell to cmd to avoid path errors when using gcloud commands on Windows - name: Upload artifacts to a GCS bucket (Windows runs) if: >- ${{ inputs.upload_artifacts_to_gcs && contains(inputs.runner, 'windows-x86') }} shell: cmd - run: gsutil -m cp -r "dist/*.whl" "${{ 
inputs.gcs_upload_uri }}"/ + run: gcloud storage cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ - name: Store the GCS upload URI as an output id: store-gcs-upload-uri if: ${{ inputs.upload_artifacts_to_gcs }} diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index bdce2b684803..fc4633110667 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -96,12 +96,12 @@ jobs: if: ${{ !contains(inputs.runner, 'windows-x86') }} run: | mkdir -p $(pwd)/dist - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then echo "JAX only release. Only downloading the jax wheel from the release bucket." else - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ fi - name: Download wheels from GCS (Windows runs) id: download-wheel-artifacts-w @@ -113,14 +113,14 @@ jobs: shell: cmd run: | mkdir dist - @REM Use `call` so that we can run sequential gsutil commands on Windows + @REM Use `call` so that we can run sequential gcloud storage commands on Windows @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ + call gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ if "${{ inputs.download-jax-only-from-gcs }}"=="1" ( echo "JAX only release. Only downloading the jax wheel from the release bucket." ) else ( - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ + call gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ ) - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 4df752310ace..af034ab09991 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -93,7 +93,7 @@ jobs: continue-on-error: true run: | mkdir -p $(pwd)/dist - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ # Do not download the jaxlib and CUDA plugin artifacts if we are testing a jax only # release. @@ -104,9 +104,9 @@ jobs: # required dependency of jax so that gets installed automatically. 
echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=jax_cuda_pypi">> $GITHUB_ENV else - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 55a0b4cc1a5f..2d4d2925bd2f 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -114,11 +114,11 @@ jobs: continue-on-error: true run: | mkdir -p $(pwd)/dist - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then echo "JAX only release. Only downloading the jax wheel from the release bucket." else - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 8d597b84f735..3e616a894d13 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -171,15 +171,15 @@ jobs: python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-" - gsutil -m cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null) echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then - gsutil -m cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ - gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ - gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/ jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) From 3142bc3b464dec165008257c62d93afa4e97d919 Mon Sep 17 00:00:00 2001 From: 
Yash Katariya Date: Wed, 14 May 2025 07:45:22 -0700 Subject: [PATCH 1171/1769] Rename `SdyArraySharding -> SdyArray` and `SdyDimSharding -> SdyDim` since these are not `Sharding` from a JAX POV. PiperOrigin-RevId: 758679923 --- jax/_src/callback.py | 36 +++++++++++++++++------------------ jax/_src/debugging.py | 10 +++++----- jax/_src/interpreters/mlir.py | 14 +++++++------- jax/_src/named_sharding.py | 24 +++++++++++------------ jax/_src/shard_map.py | 6 +++--- jax/_src/sharding_impls.py | 32 +++++++++++++++---------------- tests/array_test.py | 20 +++++++++---------- tests/pjit_test.py | 14 +++++++------- 8 files changed, 78 insertions(+), 78 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index d23389af16eb..06b36ce5c880 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -40,7 +40,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList, SingleDeviceSharding +from jax._src.sharding_impls import SdyArray, SdyArrayList, SingleDeviceSharding from jax._src.typing import DeprecatedArg import numpy as np @@ -155,11 +155,11 @@ def _callback_op_sharding( ) if config.use_shardy_partitioner.value: assert len(avals_out) == 1 - op_sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( + op_sharding = sharding_impls.SdyArrayList([ + sharding_impls.SdyArray( mesh_shape=(), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_open=False) + sharding_impls.SdyDim(axes=[], is_open=False) ] * avals_out[0].ndim, logical_device_ids=())]) else: @@ -197,8 +197,8 @@ def _callback_op_sharding( # number of result ops. If there are no result ops, we need 1 shardy # annotation. num_sdy_shardings = max(1, len(avals_out)) - op_sharding = sharding_impls.SdyArrayShardingList(num_sdy_shardings * [ - sharding_impls.SdyArraySharding( + op_sharding = sharding_impls.SdyArrayList(num_sdy_shardings * [ + sharding_impls.SdyArray( mesh_shape=(), dimension_shardings=[], logical_device_ids=(device_index,))]) @@ -590,7 +590,7 @@ def send_to_host( operand: Any, name: str, *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> ir.Value: channel_handle = hlo.ChannelHandle.get(channel, mlir.SEND_TO_HOST_TYPE) send_op = hlo.SendOp([operand], token, channel_handle, @@ -606,10 +606,10 @@ def send_to_host( # we need to create an equivalent sharding with no dimensions. If there # are multiple shardings, just grab the first one since all these # shardings should be the same. 
- assert isinstance(sharding, SdyArrayShardingList) + assert isinstance(sharding, SdyArrayList) assert len(sharding.shardings) >= 1 - sharding = SdyArrayShardingList([ - SdyArraySharding( + sharding = SdyArrayList([ + SdyArray( mesh_shape=(), dimension_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(send_op, sharding) @@ -622,7 +622,7 @@ def receive_from_host( out_aval: core.ShapedArray, name: str, *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[ir.Value, ir.Value]: channel_handle = hlo.ChannelHandle.get(channel, mlir.RECV_FROM_HOST_TYPE) recv_op = hlo.RecvOp([mlir.aval_to_ir_type(out_aval), @@ -634,7 +634,7 @@ def receive_from_host( _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) if sharding is not None: if config.use_shardy_partitioner.value: - assert isinstance(sharding, SdyArrayShardingList) + assert isinstance(sharding, SdyArrayList) assert len(sharding.shardings) >= 1 # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the # number of shardings to match the number of results, but JAX only sees @@ -642,9 +642,9 @@ def receive_from_host( # Note that even if a function returns N results, we will end up with N # `RecvOp`s, so we only need to get the first sharding. All shardings are # the same anyways, operating on the same single device ID. - sharding = SdyArrayShardingList([ + sharding = SdyArrayList([ sharding.shardings[0], - SdyArraySharding( + SdyArray( mesh_shape=(), dimension_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(recv_op, sharding) @@ -683,7 +683,7 @@ def _emit_tpu_python_callback( result_avals: Sequence[core.ShapedArray], result_shapes: Sequence[xc.Shape], *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[Sequence[ir.Value], Any]: token = token or hlo.create_token() _wrapped_callback = callback @@ -738,7 +738,7 @@ def emit_python_callback( *, has_side_effect: bool, partitioned: bool = False, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[Sequence[mlir.IrValues], Any, Any]: """Emits MLIR that calls back to a provided Python function. @@ -836,12 +836,12 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function config.use_shardy_partitioner.value and sharding is not None and len(ctx.avals_out) > 0 - and isinstance(sharding, sharding_impls.SdyArrayShardingList) + and isinstance(sharding, sharding_impls.SdyArrayList) ): # Add a sharding annotation for the token if we have at least one # output. Otherwise, the single shardy annotation required of all ops # (even those without any results) can annotate the token. - sharding = sharding_impls.SdyArrayShardingList( + sharding = sharding_impls.SdyArrayList( [*sharding.shardings, sharding.shardings[-1]] ) ctx = dataclasses.replace( diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 29cbb01511e9..63abcbef331e 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -165,11 +165,11 @@ def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params) # program has per-device semantics, so we run the callback on each device. 
if config.use_shardy_partitioner.value: assert len(ctx.avals_out) == 1 - sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( + sharding = sharding_impls.SdyArrayList([ + sharding_impls.SdyArray( mesh_shape=(), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_open=False) + sharding_impls.SdyDim(axes=[], is_open=False) ] * ctx.avals_out[0].ndim, logical_device_ids=())]) else: @@ -182,8 +182,8 @@ def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params) # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). if config.use_shardy_partitioner.value: - sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( + sharding = sharding_impls.SdyArrayList([ + sharding_impls.SdyArray( mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))]) else: sharding = xc.OpSharding() diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f6ef5787ccbf..94418f0b958b 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -57,7 +57,7 @@ from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import ( AUTO, NamedSharding, - SdyArraySharding, SdyArrayShardingList, + SdyArray, SdyArrayList, modify_sdy_sharding_wrt_axis_types) from jax._src.state.types import AbstractRef from jax._src.util import foreach @@ -1034,7 +1034,7 @@ def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): def _to_physical_op_sharding( ctx: ModuleContext, aval: core.AbstractValue, sharding: JSharding | AUTO | None, -) -> xc.OpSharding | SdyArraySharding | None: +) -> xc.OpSharding | SdyArray | None: if sharding is None: return None if all_unconstrained(sharding, aval): @@ -1839,10 +1839,10 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) if config.use_shardy_partitioner.value: physical_ndim = core.physical_aval(aval).ndim - s = SdyArraySharding( + s = SdyArray( mesh_shape=None, dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_open=i < aval.ndim) + sharding_impls.SdyDim(axes=[], is_open=i < aval.ndim) for i in range(physical_ndim) ]) return wrap_with_sharding_op(ctx, val, aval, s) @@ -2665,7 +2665,7 @@ def _wrap_with_spmd_op(name: str, ctx: LoweringRuleContext, x: ir.Value, aval_out: core.AbstractValue, - sharding: xc.OpSharding | SdyArraySharding, + sharding: xc.OpSharding | SdyArray, unspecified_dims: set[int] | None = None, has_side_effect: bool = False, allow_shardy_lowering: bool = False): @@ -2730,7 +2730,7 @@ def lower_with_sharding_in_types(ctx, op, aval, sharding_proto=None): return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims) -def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList): +def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList): if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) else: @@ -2738,7 +2738,7 @@ def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardi def get_sharding_attr( - sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList + sharding: xc.OpSharding | SdyArray | SdyArrayList ) -> ir.Attribute: if config.use_shardy_partitioner.value: return sharding.build() # type: ignore diff --git a/jax/_src/named_sharding.py 
b/jax/_src/named_sharding.py index 3dfcbd29fc96..45ae1c124a22 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -41,10 +41,10 @@ class AUTO: def __init__(self, mesh: mesh_lib.Mesh): self.mesh = mesh - def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_open=True) + def _to_sdy_sharding(self, ndim: int) -> SdyArray: + dim_shardings = [SdyDim(axes=[], is_open=True) for _ in range(ndim)] - return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) + return SdyArray(self.mesh.shape_tuple, dim_shardings) class UnspecifiedValue: def __repr__(self): @@ -232,8 +232,8 @@ def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_open=False) + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + dim_shardings = [SdyDim(axes=[], is_open=False) for _ in range(num_dimensions)] for i, dim_spec in enumerate(self.spec): if dim_spec is PartitionSpec.UNCONSTRAINED: @@ -244,8 +244,8 @@ def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: else: dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,) dim_shardings[i].axes = dim_spec - return SdyArraySharding(self.mesh.shape_tuple, dim_shardings, - self._logical_device_ids) + return SdyArray(self.mesh.shape_tuple, dim_shardings, + self._logical_device_ids) NamedSharding.__module__ = 'jax.sharding' @@ -264,7 +264,7 @@ def get_array_mapping( return d @dataclasses.dataclass -class SdyDimSharding: +class SdyDim: axes: Sequence[str] is_open: bool priority: int | None = None @@ -275,7 +275,7 @@ def build(self) -> sdy.DimensionShardingAttr: is_closed=not self.is_open, priority=self.priority) def __repr__(self): - return f'SdyDimSharding({self._custom_repr()})' + return f'SdyDim({self._custom_repr()})' def _custom_repr(self): axes_repr = ', '.join(f"'{a}'" for a in self.axes) @@ -287,9 +287,9 @@ def _custom_repr(self): @dataclasses.dataclass -class SdyArraySharding: +class SdyArray: mesh_shape: tuple[tuple[str, int], ...] | None - dimension_shardings: Sequence[SdyDimSharding] + dimension_shardings: Sequence[SdyDim] logical_device_ids: tuple[int, ...] | None = None replicated_axes: tuple[str, ...] 
= () @@ -314,7 +314,7 @@ def __repr__(self): if self.logical_device_ids is not None else '') rar = (f', replicated_axes={self.replicated_axes}' if self.replicated_axes else '') - return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})" + return f"SdyArray([{dim_sharding_repr}]{device_id_repr}{rar})" @cache(max_size=4096, trace_context_in_key=False) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index b772a3de239e..3eb46da890f2 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -773,7 +773,7 @@ def _valid_repeats(mesh: Mesh, vma: Set[AxisName], names: AxisNames) -> bool: def _shardy_shard_map_sharding( ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in -) -> sharding_impls.SdyArraySharding: +) -> sharding_impls.SdyArray: axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): @@ -808,10 +808,10 @@ def _shard_map_lowering_shardy( dim_var_values=ctx.dim_var_values) return out_nodes - in_shardings = sharding_impls.SdyArrayShardingList(map( + in_shardings = sharding_impls.SdyArrayList(map( partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), in_names, ctx.avals_in)).build() - out_shardings = sharding_impls.SdyArrayShardingList(map( + out_shardings = sharding_impls.SdyArrayList(map( partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), out_names, ctx.avals_out)).build() output_types = map(mlir.aval_to_ir_type, ctx.avals_out) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2394e9e18f38..f7f0ebd2cc26 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -35,7 +35,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 - SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, + SdyArray, SdyDim, UnspecifiedValue, AUTO, _check_unique_resources, NamedSharding, UNSPECIFIED, ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping, array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding) @@ -87,8 +87,8 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int] @dataclasses.dataclass -class SdyArrayShardingList: - shardings: Sequence[SdyArraySharding] +class SdyArrayList: + shardings: Sequence[SdyArray] def build(self) -> sdy.TensorShardingPerValueAttr: return sdy.TensorShardingPerValueAttr.get( @@ -97,12 +97,12 @@ def build(self) -> sdy.TensorShardingPerValueAttr: # TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra # parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)` -def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): +def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh): if mesh._any_axis_auto: dim_shardings, used_axes = [], [] # type: ignore for d in sdy_sharding.dimension_shardings: # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? 
- dim_shardings.append(SdyDimSharding(axes=[], is_open=True) + dim_shardings.append(SdyDim(axes=[], is_open=True) if not d.axes and not d.is_open else d) used_axes.extend(d.axes) remaining_axes = set(mesh.axis_names) - set(used_axes) @@ -111,8 +111,8 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): remaining_axes = [n for n in mesh.axis_names if n in remaining_axes] replicated_axes = tuple(r for r in remaining_axes if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) - return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings, - sdy_sharding.logical_device_ids, replicated_axes) + return SdyArray(sdy_sharding.mesh_shape, dim_shardings, + sdy_sharding.logical_device_ids, replicated_axes) return sdy_sharding @@ -185,10 +185,10 @@ def _device_assignment(self) -> XLADeviceAssignment: def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return replicated_hlo_sharding - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - sdy_dim_sharding = [SdyDimSharding(axes=[], is_open=False) + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + sdy_dim_sharding = [SdyDim(axes=[], is_open=False) for _ in range(num_dimensions)] - return SdyArraySharding(None, sdy_dim_sharding) + return SdyArray(None, sdy_dim_sharding) @property def is_fully_replicated(self) -> bool: @@ -330,8 +330,8 @@ def with_memory_kind(self, kind: str): def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: raise NotImplementedError("pmap doesn't use OpSharding.") - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - raise NotImplementedError("pmap doesn't use SdyArraySharding.") + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + raise NotImplementedError("pmap doesn't use SdyArray.") @functools.cached_property def is_fully_replicated(self) -> bool: @@ -540,9 +540,9 @@ def _device_assignment(self) -> XLADeviceAssignment: def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions) - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: raise NotImplementedError( - "PositionalSharding can't be converted to an SdyArraySharding.") + "PositionalSharding can't be converted to an SdyArray.") @functools.cached_property def is_fully_addressable(self) -> bool: @@ -657,9 +657,9 @@ def _device_assignment(self) -> XLADeviceAssignment: def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return self._hlo_sharding - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: raise NotImplementedError( - "GSPMDSharding can't be converted to SdyArraySharding.") + "GSPMDSharding can't be converted to SdyArray.") @functools.cached_property def is_fully_replicated(self) -> bool: diff --git a/tests/array_test.py b/tests/array_test.py index b951f7f6b4cd..0bab37c07bff 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -35,8 +35,8 @@ from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, - NamedSharding, GSPMDSharding, PositionalSharding, SdyDimSharding, - SdyArraySharding) + NamedSharding, GSPMDSharding, PositionalSharding, SdyDim, + SdyArray) from jax.experimental.pjit import pjit from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P 
@@ -1476,12 +1476,12 @@ def test_long_axis_names(self): sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( sdy_sharding, - SdyArraySharding( + SdyArray( mesh.shape_tuple, - [SdyDimSharding( + [SdyDim( ('sequence', 'data'), False), - SdyDimSharding(('model',), False), - SdyDimSharding([], False)])) + SdyDim(('model',), False), + SdyDim([], False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( @@ -1496,11 +1496,11 @@ def test_unconstrained(self): sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( sdy_sharding, - SdyArraySharding( + SdyArray( mesh.shape_tuple, - [SdyDimSharding([], False), - SdyDimSharding([], True), - SdyDimSharding(('x',), False)])) + [SdyDim([], False), + SdyDim([], True), + SdyDim(('x',), False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 22a4d4f70f8c..9b658dd8c604 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -8593,26 +8593,26 @@ def f(x, y): self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str) def test_array_sharding_repr_with_priority(self): - sharding = sharding_impls.SdyArraySharding( + sharding = sharding_impls.SdyArray( mesh_shape=(('data', 4), ('model', 8), ('expert', 2)), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_open=False), - sharding_impls.SdyDimSharding(axes=['model'], is_open=True, priority=2)]) - self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])") + sharding_impls.SdyDim(axes=['data', 'expert'], is_open=False), + sharding_impls.SdyDim(axes=['model'], is_open=True, priority=2)]) + self.assertEqual(repr(sharding), "SdyArray([{'data', 'expert'}, {'model', ?}p2])") def test_array_sharding_repr_with_logical_ids(self): abstract_mesh = jax.sharding.AbstractMesh((4, 8, 2), ('x', 'y', 'z')) ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None), _logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3]) self.assertEqual(repr(ns._to_sdy_sharding(4)), - "SdyArraySharding([{'x', 'y'}, {'z'}, {?}, {}], " + "SdyArray([{'x', 'y'}, {'z'}, {?}, {}], " "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])") def test_dimension_sharding_repr(self): - dim_sharding = sharding_impls.SdyDimSharding( + dim_sharding = sharding_impls.SdyDim( axes=['data', 'model'], is_open=True, priority=2) self.assertEqual(repr(dim_sharding), - "SdyDimSharding({'data', 'model', ?}p2)") + "SdyDim({'data', 'model', ?}p2)") def test_tensor_dialect(self): # While this doesn't emit any `mlir::TensorDialect` ops, some pass in the From 0af3e8205307db0e12259930de0ce63d50f18f2e Mon Sep 17 00:00:00 2001 From: Pakize Sanal Date: Fri, 2 May 2025 14:39:28 -0500 Subject: [PATCH 1172/1769] Skip CSR matmat/matvec float tests on ROCm <6.4 (NaN issue with beta==0). Cherry-picked from ROCm fork (commit 5dbfa9d1bcb3629658c2ec9addf45ac389f17305). Added TODOs to remove this check when ROCm 6.4+ is the minimum supported version. 
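A hedged sketch of how such a guard reads in a test (illustrative helper
only; `get_rocm_version` and the `jtu.is_device_rocm` check are the ones
added in the diff below):

    def skip_if_rocm_nan_bug(test_case):
      # ROCm < 6.4 propagated NaNs in CSR matmat/matvec when beta == 0,
      # so float and complex dtypes cannot be validated there.
      if jtu.is_device_rocm() and get_rocm_version() < (6, 4):
        test_case.skipTest("ROCm <6.4 bug: NaN propagation when beta==0")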
---
 tests/sparse_test.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/sparse_test.py b/tests/sparse_test.py
index 71437fd0e028..97a156f9f6f5 100644
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -16,6 +16,8 @@
 from functools import partial
 import itertools
 import math
+import os
+from pathlib import Path
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -42,6 +44,15 @@
 import numpy as np
 import scipy.sparse
 
+def get_rocm_version():
+  rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
+  version_path = Path(rocm_path) / ".info" / "version"
+  if not version_path.exists():
+    raise FileNotFoundError(f"Expected ROCm version file at {version_path}")
+  version_str = version_path.read_text().strip()
+  major, minor, *_ = version_str.split(".")
+  return int(major), int(minor)
+
 jax.config.parse_flags_with_absl()
 
 all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
@@ -208,6 +219,14 @@ def test_csr_fromdense(self, shape, dtype):
       transpose=[True, False],
   )
   def test_csr_matvec(self, shape, dtype, transpose):
+    if (
+        jtu.is_device_rocm() and
+        get_rocm_version() < (6, 4) and
+        dtype in (jtu.dtypes.floating + jtu.dtypes.complex)
+    ):
+      # TODO: Remove this check when ROCm 6.4+ is the minimum supported version
+      self.skipTest("ROCm <6.4 bug: NaN propagation when beta==0 (fixed in ROCm 6.4.0)")
+
     op = lambda M: M.T if transpose else M
 
     v_rng = jtu.rand_default(self.rng())
@@ -228,6 +247,14 @@ def test_csr_matvec(self, shape, dtype, transpose):
       transpose=[True, False],
   )
   def test_csr_matmat(self, shape, dtype, transpose):
+    if (
+        jtu.is_device_rocm() and
+        get_rocm_version() < (6, 4) and
+        dtype in (jtu.dtypes.floating + jtu.dtypes.complex)
+    ):
+      # TODO: Remove this check when ROCm 6.4+ is the minimum supported version
+      self.skipTest("ROCm <6.4 bug: NaN propagation when beta==0 (fixed in ROCm 6.4.0)")
+
     op = lambda M: M.T if transpose else M
 
     B_rng = jtu.rand_default(self.rng())

From 728102394f8c23e848c7cdf3f39ec7866a71f785 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Wed, 14 May 2025 09:07:27 -0700
Subject: [PATCH 1173/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/caef87114bc75bd89f56b8562d32cd7b8887319f.

PiperOrigin-RevId: 758707765
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index ca75a0be471c..4a8a44139a67 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 #    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 #    and update XLA_SHA256 with the result.
-XLA_COMMIT = "c7dbe1e216d132e7b4042907b8ba9454535edc70" -XLA_SHA256 = "fea578d5f53daec13f9e84c615a2e2ba783b1886c471ff55897505259de6b137" +XLA_COMMIT = "caef87114bc75bd89f56b8562d32cd7b8887319f" +XLA_SHA256 = "ae3e9cd59cfcacdd8dbdf7e808b699fc728228654980b0766730a54fcaee0201" def repo(): tf_http_archive( From 74db09f8bf500cf5430f25a89a0ed68ac6e7c4da Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Wed, 14 May 2025 09:28:11 -0700 Subject: [PATCH 1174/1769] [ifrt] Refactor away from deprecated constructors PiperOrigin-RevId: 758715355 --- jaxlib/xla_compiler.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index 8803da924381..272cd2d409a8 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -488,7 +488,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { LayoutProto result; nb::bytes serialized = nb::cast(t[0]); result.ParseFromArray(serialized.c_str(), serialized.size()); - new (self) Layout(Layout::CreateFromProto(result)); + new (self) Layout(ValueOrThrow(Layout::FromProto(result))); }); nb::class_ shape_class(m, "Shape"); From 72d3ee6c413ac83577527f1b8ca8f02bd8d46661 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 May 2025 09:39:21 -0700 Subject: [PATCH 1175/1769] support accelerated deprecation for jax.extend.ffi --- tests/ffi_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/ffi_test.py b/tests/ffi_test.py index a66d17622ae5..77b0f823e125 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -28,6 +28,7 @@ from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import test_util as jtu from jax._src.interpreters import mlir @@ -284,6 +285,8 @@ def f(x): @jtu.run_on_devices("gpu", "cpu") @jtu.ignore_warning(category=DeprecationWarning) def test_extend_import_shim(self): + if deprecations.is_accelerated_attribute(jex.ffi, "ffi_call"): + self.skipTest("FFI call deprecation is accelerated.") ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True) def test_extended_dtype_lowering(self): From 0b70c2323df506811fb4d5290df72ca5ee328c55 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 14 May 2025 09:40:09 -0700 Subject: [PATCH 1176/1769] Lower unreduced to shardy. GSPMD doesn't support unreduced. 
PiperOrigin-RevId: 758719692 --- jax/_src/callback.py | 8 +++--- jax/_src/debugging.py | 4 +-- jax/_src/interpreters/mlir.py | 2 +- jax/_src/named_sharding.py | 51 +++++++++++++++++++++++++++++------ jax/_src/shard_map.py | 2 +- jax/_src/sharding_impls.py | 26 +++--------------- tests/array_test.py | 8 +++--- tests/pjit_test.py | 8 ++++-- 8 files changed, 64 insertions(+), 45 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 06b36ce5c880..bc233b634f3c 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -158,7 +158,7 @@ def _callback_op_sharding( op_sharding = sharding_impls.SdyArrayList([ sharding_impls.SdyArray( mesh_shape=(), - dimension_shardings=[ + dim_shardings=[ sharding_impls.SdyDim(axes=[], is_open=False) ] * avals_out[0].ndim, logical_device_ids=())]) @@ -200,7 +200,7 @@ def _callback_op_sharding( op_sharding = sharding_impls.SdyArrayList(num_sdy_shardings * [ sharding_impls.SdyArray( mesh_shape=(), - dimension_shardings=[], + dim_shardings=[], logical_device_ids=(device_index,))]) else: op_sharding = xc.OpSharding() # type: ignore[assignment] @@ -610,7 +610,7 @@ def send_to_host( assert len(sharding.shardings) >= 1 sharding = SdyArrayList([ SdyArray( - mesh_shape=(), dimension_shardings=[], + mesh_shape=(), dim_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(send_op, sharding) return send_op.result @@ -645,7 +645,7 @@ def receive_from_host( sharding = SdyArrayList([ sharding.shardings[0], SdyArray( - mesh_shape=(), dimension_shardings=[], + mesh_shape=(), dim_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(recv_op, sharding) # Token should be at the end of the results diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 63abcbef331e..c2febf752b92 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -168,7 +168,7 @@ def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params) sharding = sharding_impls.SdyArrayList([ sharding_impls.SdyArray( mesh_shape=(), - dimension_shardings=[ + dim_shardings=[ sharding_impls.SdyDim(axes=[], is_open=False) ] * ctx.avals_out[0].ndim, logical_device_ids=())]) @@ -184,7 +184,7 @@ def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params) if config.use_shardy_partitioner.value: sharding = sharding_impls.SdyArrayList([ sharding_impls.SdyArray( - mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))]) + mesh_shape=(), dim_shardings=[], logical_device_ids=(0,))]) else: sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MAXIMAL diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 94418f0b958b..0256057b8b09 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1841,7 +1841,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: physical_ndim = core.physical_aval(aval).ndim s = SdyArray( mesh_shape=None, - dimension_shardings=[ + dim_shardings=[ sharding_impls.SdyDim(axes=[], is_open=i < aval.ndim) for i in range(physical_ndim) ]) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 45ae1c124a22..faf0b2a9f2b2 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -44,7 +44,8 @@ def __init__(self, mesh: mesh_lib.Mesh): def _to_sdy_sharding(self, ndim: int) -> SdyArray: dim_shardings = [SdyDim(axes=[], is_open=True) for _ in range(ndim)] - return SdyArray(self.mesh.shape_tuple, dim_shardings) + return 
SdyArray(mesh_shape=self.mesh.shape_tuple, + dim_shardings=dim_shardings) class UnspecifiedValue: def __repr__(self): @@ -244,8 +245,10 @@ def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: else: dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,) dim_shardings[i].axes = dim_spec - return SdyArray(self.mesh.shape_tuple, dim_shardings, - self._logical_device_ids) + return SdyArray(mesh_shape=self.mesh.shape_tuple, + dim_shardings=dim_shardings, + logical_device_ids=self._logical_device_ids, + unreduced_axes=self.spec.unreduced) NamedSharding.__module__ = 'jax.sharding' @@ -285,13 +288,21 @@ def _custom_repr(self): priority_repr = '' if self.priority is None else f'p{self.priority}' return f'{{{axes_repr}{open_repr}}}{priority_repr}' +def _get_axes(axes, mesh_shape): + if not axes: + return () + assert mesh_shape is not None + # Sort wrt mesh axis names so order is deterministic and doesn't hang in + # McJAX. + return tuple(n for n, _ in mesh_shape if n in axes) -@dataclasses.dataclass +@dataclasses.dataclass(kw_only=True) class SdyArray: mesh_shape: tuple[tuple[str, int], ...] | None - dimension_shardings: Sequence[SdyDim] + dim_shardings: Sequence[SdyDim] logical_device_ids: tuple[int, ...] | None = None replicated_axes: tuple[str, ...] = () + unreduced_axes: tuple[str, ...] = () def build(self) -> sdy.TensorShardingAttr: if self.mesh_shape is None: @@ -302,14 +313,18 @@ def build(self) -> sdy.TensorShardingAttr: mesh_attr = sdy.MeshAttr.get( [sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape], ldi) + + replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape) + unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape) return sdy.TensorShardingAttr.get( mesh_attr, - [dim_sharding.build() for dim_sharding in self.dimension_shardings], - replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes]) + [dim_sharding.build() for dim_sharding in self.dim_shardings], + replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], + unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) def __repr__(self): dim_sharding_repr = ', '.join( - d._custom_repr() for d in self.dimension_shardings) + d._custom_repr() for d in self.dim_shardings) device_id_repr = (f', device_ids={self.logical_device_ids}' if self.logical_device_ids is not None else '') rar = (f', replicated_axes={self.replicated_axes}' @@ -317,6 +332,26 @@ def __repr__(self): return f"SdyArray([{dim_sharding_repr}]{device_id_repr}{rar})" +# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra +# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)` +def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh): + if mesh._any_axis_auto: + dim_shardings, used_axes = [], [] # type: ignore + for d in sdy_sharding.dim_shardings: + # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? 
+ dim_shardings.append(SdyDim(axes=[], is_open=True) + if not d.axes and not d.is_open else d) + used_axes.extend(d.axes) + remaining_axes = set(mesh.axis_names) - set(used_axes) + replicated_axes = tuple(r for r in remaining_axes + if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) + return SdyArray(mesh_shape=sdy_sharding.mesh_shape, + dim_shardings=dim_shardings, + logical_device_ids=sdy_sharding.logical_device_ids, + replicated_axes=replicated_axes) + return sdy_sharding + + @cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( self, num_dimensions: int) -> xc.HloSharding: diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 3eb46da890f2..abcc2ca0acf1 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -781,7 +781,7 @@ def _shardy_shard_map_sharding( aval_in = core.physical_aval(aval_in) sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) if len(manual_axes) < len(mesh.axis_names): - for dim_sharding in sdy_sharding.dimension_shardings: + for dim_sharding in sdy_sharding.dim_shardings: dim_sharding.is_open = True return sdy_sharding diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index f7f0ebd2cc26..982af82c5c4d 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -38,7 +38,8 @@ SdyArray, SdyDim, UnspecifiedValue, AUTO, _check_unique_resources, NamedSharding, UNSPECIFIED, ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping, - array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding) + array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding, + modify_sdy_sharding_wrt_axis_types) from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec @@ -95,27 +96,6 @@ def build(self) -> sdy.TensorShardingPerValueAttr: [sharding.build() for sharding in self.shardings]) -# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra -# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)` -def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh): - if mesh._any_axis_auto: - dim_shardings, used_axes = [], [] # type: ignore - for d in sdy_sharding.dimension_shardings: - # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? - dim_shardings.append(SdyDim(axes=[], is_open=True) - if not d.axes and not d.is_open else d) - used_axes.extend(d.axes) - remaining_axes = set(mesh.axis_names) - set(used_axes) - # Sort wrt mesh axis names so order is deterministic and doesn't hang in - # McJAX. 
- remaining_axes = [n for n in mesh.axis_names if n in remaining_axes] - replicated_axes = tuple(r for r in remaining_axes - if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) - return SdyArray(sdy_sharding.mesh_shape, dim_shardings, - sdy_sharding.logical_device_ids, replicated_axes) - return sdy_sharding - - replicated_hlo_sharding = xc.HloSharding.replicate() @@ -188,7 +168,7 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: sdy_dim_sharding = [SdyDim(axes=[], is_open=False) for _ in range(num_dimensions)] - return SdyArray(None, sdy_dim_sharding) + return SdyArray(mesh_shape=None, dim_shardings=sdy_dim_sharding) @property def is_fully_replicated(self) -> bool: diff --git a/tests/array_test.py b/tests/array_test.py index 0bab37c07bff..1691c3acc749 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1477,8 +1477,8 @@ def test_long_axis_names(self): self.assertEqual( sdy_sharding, SdyArray( - mesh.shape_tuple, - [SdyDim( + mesh_shape=mesh.shape_tuple, + dim_shardings=[SdyDim( ('sequence', 'data'), False), SdyDim(('model',), False), SdyDim([], False)])) @@ -1497,8 +1497,8 @@ def test_unconstrained(self): self.assertEqual( sdy_sharding, SdyArray( - mesh.shape_tuple, - [SdyDim([], False), + mesh_shape=mesh.shape_tuple, + dim_shardings=[SdyDim([], False), SdyDim([], True), SdyDim(('x',), False)])) with ir.Context() as ctx: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9b658dd8c604..e374b0b15a7b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7735,6 +7735,7 @@ def f(x): f(arr) # doesn't crash jax.jit(f)(arr) # doesn't crash + @config.use_shardy_partitioner(True) @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_unreduced_basic(self, mesh): np_inp = np.arange(16).reshape(8, 2) @@ -7758,7 +7759,10 @@ def f(x, y, a, b): self.assertEqual(out.aval.sharding.spec, P('x', None)) return out - f.trace(x, y, a, b) # doesn't crash + traced = f.trace(x, y, a, b) + lowered_text = traced.lower().as_text() + self.assertIn('unreduced={"y"}', lowered_text) + self.assertTrue(lowered_text.count('unreduced={"y"}') == 3) @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_dot_general_unreduced_error(self, mesh): @@ -8595,7 +8599,7 @@ def f(x, y): def test_array_sharding_repr_with_priority(self): sharding = sharding_impls.SdyArray( mesh_shape=(('data', 4), ('model', 8), ('expert', 2)), - dimension_shardings=[ + dim_shardings=[ sharding_impls.SdyDim(axes=['data', 'expert'], is_open=False), sharding_impls.SdyDim(axes=['model'], is_open=True, priority=2)]) self.assertEqual(repr(sharding), "SdyArray([{'data', 'expert'}, {'model', ?}p2])") From b5d4e2176773cdb5d44c4d68dc1263e26c99a9cd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 14 May 2025 10:06:11 -0700 Subject: [PATCH 1177/1769] Make device_put work for python scalars when the sharding is not fully addressable i.e. `device_put(1, global_sharding)`. This already works for numpy arrays. 
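A minimal sketch of the call this enables. The interesting case is multi-process execution, where a global sharding is not fully addressable from any single host; with a single process the sharding is fully addressable and the call already worked:

    import jax
    import numpy as np
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jax.make_mesh((jax.device_count(),), ('x',))
    global_sharding = NamedSharding(mesh, P())        # fully replicated

    x = jax.device_put(1, global_sharding)            # python scalar: fixed here
    y = jax.device_put(np.int32(1), global_sharding)  # already worked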
Fixes: https://github.com/jax-ml/jax/discussions/14578#discussioncomment-13145332 PiperOrigin-RevId: 758730649 --- jax/_src/dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 409b22c849e3..1ab560fb58c5 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -474,7 +474,7 @@ def _device_put_sharding_impl(x, aval, device, copy): if not s.is_fully_addressable: if ((isinstance(x, array.ArrayImpl) and not x._committed) or - type(x) in array_types): + type(x) in array_types or type(x) in dtypes.python_scalar_dtypes): # If all hosts participate in the sharding, assert that the input is the # same on all hosts. If some hosts have no addressable devices in the # sharding, bypass the check, since we can't easily distinguish between From 5abc510b434884639b9c3f48ba379172d443dba6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 14 May 2025 10:35:35 -0700 Subject: [PATCH 1178/1769] Partial in `repeats` if `total_repeat_length is None` since repeats is expected to be a constant. This fixes an error in explicit sharding mode where we were converting repeats to a Tracer instead of it being a concrete value. PiperOrigin-RevId: 758743044 --- jax/_src/numpy/lax_numpy.py | 24 ++++++++++++++++-------- tests/pjit_test.py | 6 ++++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 266aad4954ba..0bd287dadd51 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6698,9 +6698,8 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, [3, 3, 4, 4, 4, 4, 4]], dtype=int32) """ if out_sharding is not None: - return auto_axes( - partial(_repeat, axis=axis, total_repeat_length=total_repeat_length), - out_sharding=out_sharding)(a, repeats) + return _auto_repeat(_repeat, a, repeats, axis, total_repeat_length, + out_sharding) ctx_mesh = get_abstract_mesh() if ctx_mesh._are_all_axes_explicit: aval = core.typeof(a) @@ -6710,17 +6709,26 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, assert axis is not None and aval.sharding.spec[axis] is None out_sharding = (NamedSharding(ctx_mesh, P()) if aval.sharding.mesh.empty else aval.sharding) - return auto_axes( - partial(_repeat, axis=axis, total_repeat_length=total_repeat_length), - out_sharding=out_sharding)(a, repeats) + return _auto_repeat(_repeat, a, repeats, axis, total_repeat_length, + out_sharding) try: - return _repeat(a, repeats, axis=axis, + return _repeat(a, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length) except core.ShardingTypeError as e: raise ValueError( "Please pass sharding to `jnp.repeat` via `out_sharding` parameter.") -def _repeat(a: ArrayLike, repeats: ArrayLike, *, axis: int | None = None, +def _auto_repeat(fun, a, repeats, axis, total_repeat_length, out_sharding): + if total_repeat_length is None: + return auto_axes(partial(fun, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length), + out_sharding=out_sharding)(a) + else: + return auto_axes( + partial(fun, axis=axis, total_repeat_length=total_repeat_length), + out_sharding=out_sharding)(a, repeats=repeats) + +def _repeat(a: ArrayLike, *, repeats: ArrayLike, axis: int | None = None, total_repeat_length: int | None = None) -> Array: if core.is_dim(repeats): util.check_arraylike("repeat", a) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e374b0b15a7b..d1c8ec7f050d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ 
-7617,6 +7617,12 @@ def test_jnp_repeat(self, mesh): out = jnp.repeat(a, np.array((2,2,2,2)) - 1, axis=0, out_sharding=P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + a = jax.device_put(jnp.eye(16).reshape(16, 16), P('x')) + @jax.jit + def f(x): + return jnp.repeat(x, 3, axis=-1) + f(a) + @jtu.with_explicit_mesh((2,), ('x',)) def test_scatter_gather(self, mesh): x = np.random.uniform(size=(mesh.size * 2, 3)) From a5361315fd874cf4d5e278e20a9ab2de521d8df4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 May 2025 11:08:48 -0700 Subject: [PATCH 1179/1769] Warn when __jax_array__ is seen in xla dtype canonicalization --- jax/BUILD | 1 + jax/_src/interpreters/xla.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/jax/BUILD b/jax/BUILD index 5f1cf1670729..6afcdf2892f4 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1180,6 +1180,7 @@ pytype_strict_library( ":abstract_arrays", ":config", ":core", + ":deprecations", ":dtypes", ":sharding_impls", ":source_info_util", diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 7fbb22923e0f..73a57f935f5d 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -23,6 +23,7 @@ import numpy as np from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types from jax._src.util import safe_zip, safe_map @@ -100,6 +101,12 @@ def canonicalize_dtype(x): handler = canonicalize_dtype_handlers.get(typ) if handler: return handler(x) if hasattr(x, '__jax_array__'): + deprecations.warn( + 'jax-abstract-dunder-array', + ('Triggering of __jax_array__() during abstractification is deprecated.' + ' To avoid this error, either explicitly convert your object using' + ' jax.numpy.array(), or register your object as a pytree.'), + stacklevel=6) return canonicalize_dtype(x.__jax_array__()) raise InvalidInputException( f"Argument '{x}' of type {type(x)} is not a valid JAX type.") From e968cb6b04d31ac7a045f119de027ee17fc16e1e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 14 May 2025 18:36:10 +0000 Subject: [PATCH 1180/1769] Add requirements needed for building wheels under Python 3.14t. Also allow ml_dtypes as a local wheel override. --- WORKSPACE | 2 ++ build/requirements_lock_3_14_ft.txt | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/WORKSPACE b/WORKSPACE index 903085714e65..f389afe2263f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,8 @@ python_init_repositories( default_python_version = "system", local_wheel_dist_folder = "../dist", local_wheel_inclusion_list = [ + "ml_dtypes*", + "ml-dtypes*", "numpy*", "scipy*", "jax-*", diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt index e50305f4fa48..6eedf149f5fa 100644 --- a/build/requirements_lock_3_14_ft.txt +++ b/build/requirements_lock_3_14_ft.txt @@ -19,3 +19,9 @@ flatbuffers==24.12.23 ml-dtypes==0.5.1 opt-einsum==3.4.0 + +build==1.2.2.post1 +setuptools==80.0.0 +wheel==0.45.1 +pyproject-hooks==1.2.0 +packaging==25.0 From 183425cf05da44cb5496db3046409af812747e5d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 14 May 2025 19:02:00 +0000 Subject: [PATCH 1181/1769] Move definition of pjrt_c_api_gpu_plugin.so into jax_plugins/{cuda,rocm} This is simpler and should work better for a bazel submodule. 
--- jax_plugins/cuda/BUILD.bazel | 21 ++++++++--- jax_plugins/cuda/__init__.py | 2 +- .../cuda}/gpu_version_script.lds | 0 jax_plugins/rocm/BUILD.bazel | 20 ++++++++--- jax_plugins/rocm/__init__.py | 2 +- jax_plugins/rocm/gpu_version_script.lds | 9 +++++ jaxlib/tools/BUILD.bazel | 35 ++++--------------- jaxlib/tools/build_gpu_plugin_wheel.py | 4 +-- 8 files changed, 52 insertions(+), 41 deletions(-) rename {jaxlib/tools => jax_plugins/cuda}/gpu_version_script.lds (100%) create mode 100644 jax_plugins/rocm/gpu_version_script.lds diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index 6566cfc62b0c..c3c20f536cff 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -34,15 +34,28 @@ exports_files([ "setup.py", ]) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":gpu_version_script.lds", + "//jaxlib/mosaic/gpu:custom_call", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/service:gpu_plugin", + "@xla//xla/stream_executor:cuda_platform", + ], +) + py_library_providing_imports_info( name = "cuda_plugin", srcs = [ "__init__.py", ], - data = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["//jaxlib/tools:pjrt_c_api_gpu_plugin.so"], - ), + data = [":pjrt_c_api_gpu_plugin.so"], lib_rule = pytype_library, ) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 4891fbeb3332..1be29326c95f 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -51,7 +51,7 @@ def _get_library_path(): runfiles_dir = os.getenv('RUNFILES_DIR', None) if runfiles_dir: local_path = os.path.join( - runfiles_dir, '__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so' + runfiles_dir, '__main__/jax_plugins/cuda/pjrt_c_api_gpu_plugin.so' ) if os.path.exists(local_path): diff --git a/jaxlib/tools/gpu_version_script.lds b/jax_plugins/cuda/gpu_version_script.lds similarity index 100% rename from jaxlib/tools/gpu_version_script.lds rename to jax_plugins/cuda/gpu_version_script.lds diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel index 6e265bcd18cf..15e9e627830e 100644 --- a/jax_plugins/rocm/BUILD.bazel +++ b/jax_plugins/rocm/BUILD.bazel @@ -34,14 +34,26 @@ exports_files([ "setup.py", ]) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":gpu_version_script.lds", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/service:gpu_plugin", + "@xla//xla/stream_executor:rocm_platform", + ], +) + py_library_providing_imports_info( name = "rocm_plugin", srcs = [ "__init__.py", ], - data = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), + data = [":pjrt_c_api_gpu_plugin.so"], lib_rule = pytype_library, ) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 0b1b077acfcd..cf2a625fa783 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -51,7 +51,7 @@ def _get_library_path(): runfiles_dir = os.getenv('RUNFILES_DIR', None) if runfiles_dir: local_path = pathlib.Path( - os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so') + os.path.join(runfiles_dir, '__main__/jax_plugins/rocm/pjrt_c_api_gpu_plugin.so') ) if local_path.exists(): diff --git a/jax_plugins/rocm/gpu_version_script.lds 
b/jax_plugins/rocm/gpu_version_script.lds new file mode 100644 index 000000000000..cbac4549bde3 --- /dev/null +++ b/jax_plugins/rocm/gpu_version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + }; + + local: + *; +}; diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 219096836ffc..22bae26a4420 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -64,10 +64,10 @@ py_binary( "LICENSE.txt", "//jaxlib", "//jaxlib:README.md", + "//jaxlib:_jax", "//jaxlib:jaxlib_binaries", "//jaxlib:setup.py", "//jaxlib:xla_client.py", - "//jaxlib:_jax", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", @@ -90,35 +90,15 @@ jax_py_test( ], ) -cc_binary( - name = "pjrt_c_api_gpu_plugin.so", - linkopts = [ - "-Wl,--version-script,$(location :gpu_version_script.lds)", - "-Wl,--no-undefined", - ], - linkshared = True, - deps = [ - ":gpu_version_script.lds", - "@xla//xla/pjrt/c:pjrt_c_api_gpu", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", - "@xla//xla/service:gpu_plugin", - ] + if_cuda([ - "//jaxlib/mosaic/gpu:custom_call", - "@xla//xla/stream_executor:cuda_platform", - ]) + if_rocm([ - "@xla//xla/stream_executor:rocm_platform", - ]), -) - py_binary( name = "build_gpu_plugin_wheel", srcs = ["build_gpu_plugin_wheel.py"], data = [ "LICENSE.txt", - ":pjrt_c_api_gpu_plugin.so", ] + if_cuda([ "//jaxlib:version", "//jaxlib/cuda:cuda_gpu_support", + "//jax_plugins/cuda:pjrt_c_api_gpu_plugin.so", "//jax_plugins/cuda:pyproject.toml", "//jax_plugins/cuda:setup.py", "//jax_plugins/cuda:__init__.py", @@ -126,6 +106,7 @@ py_binary( ]) + if_rocm([ "//jaxlib:version", "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:pjrt_c_api_gpu_plugin.so", "//jax_plugins/rocm:pyproject.toml", "//jax_plugins/rocm:setup.py", "//jax_plugins/rocm:__init__.py", @@ -387,10 +368,6 @@ jax_wheel( ) # JAX PJRT wheel targets. -pytype_strict_library( - name = "pjrt_c_api_gpu_plugin_so", - data = [":pjrt_c_api_gpu_plugin.so"], -) py_binary( name = "build_gpu_plugin_wheel_tool", @@ -407,12 +384,12 @@ py_binary( wheel_sources( name = "jax_pjrt_sources", - data_srcs = [ - ":pjrt_c_api_gpu_plugin_so", - ] + if_cuda([ + data_srcs = if_cuda([ + "//jax_plugins/cuda:cuda_plugin", "//jaxlib/cuda:cuda_gpu_support", "@local_config_cuda//cuda:cuda-nvvm", ]) + if_rocm([ + "//jax_plugins/rocm:rocm_plugin", "//jaxlib/rocm:rocm_gpu_support", ]), py_srcs = [ diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 337bedab4591..68e08d89338e 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -120,7 +120,7 @@ def prepare_cuda_plugin_wheel( ], ) copy_files( - f"{source_file_prefix}jaxlib/tools/pjrt_c_api_gpu_plugin.so", + f"{source_file_prefix}jax_plugins/cuda/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_cuda_plugin.so", ) @@ -158,7 +158,7 @@ def prepare_rocm_plugin_wheel( ], ) copy_files( - f"{source_file_prefix}jaxlib/tools/pjrt_c_api_gpu_plugin.so", + f"{source_file_prefix}jax_plugins/rocm/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_rocm_plugin.so", ) From f0e00a6658709067951446868894ea300b365c8b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 14 May 2025 13:35:43 -0400 Subject: [PATCH 1182/1769] Add a linearization rule for scan. 
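A small end-to-end sketch of what the rule enables: jax.linearize of a scanned function under direct linearization. The flag spelling mirrors the test added below; it currently lives on the internal config module:

    import jax
    from jax._src import config

    def g(x):
        def f(c, _):
            return c + 1.0, None
        return jax.lax.scan(f, x, length=2)[0]

    with config.use_direct_linearize(True):
        y, g_lin = jax.linearize(g, 1.0)
        print(y, g_lin(1.0))  # 3.0 1.0 -- carry after two steps, unit tangent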
--- jax/_src/lax/control_flow/loops.py | 103 +++++++++++++++++++++++++++++ tests/lax_control_flow_test.py | 13 ++++ 2 files changed, 116 insertions(+) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 05e9c010dc51..c85a23b6b199 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -683,6 +683,108 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out +def _scan_linearization(nzs, *primals_in, reverse: bool, length: int, + num_consts: int, num_carry: int, + jaxpr: core.ClosedJaxpr, linear: Sequence[bool], + unroll: int, _split_transpose: bool): + const_nz, init_nz, xs_nz = split_list(nzs, [num_consts, num_carry]) + carry_nz = init_nz + for _ in range(1 + num_carry): + nzs = const_nz + carry_nz + xs_nz + primal_jaxpr, num_res, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) + carry_nz_out = nzs_out[:num_carry] + if carry_nz_out == carry_nz: + break + else: + carry_nz = _map(operator.or_, carry_nz, carry_nz_out) + else: + assert False, "Fixpoint not reached" + + # The linearize_jaxpr function produces primal_jaxpr with num_res residuals + # output at the front, and tangent_jaxpr with num_res residuals input at the + # back. We could move all the residuals to the back and treat them as + # extensive outputs, but this would be wasteful for residuals that are + # loop invariant, or forwarded extensive inputs. + + # First, for residuals that are forwarded constants, we move those to the + # front in the tangent_jaxpr to treat them as intensive inputs. + in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) + primal_jaxpr, tangent_jaxpr, intensive_res, in_fwd = _const_to_intensive_res_forwarding( + primal_jaxpr, tangent_jaxpr, num_res, num_consts, primals_in, in_fwd) + num_intensive_res = len(intensive_res) + num_res -= num_intensive_res + + # After pruning the intensive residuals, the rest get moved to the back and + # handled as extensive outputs from the primal. + num_out = len(nzs_out) + primal_jaxpr = pe.move_outvars_to_back( + primal_jaxpr, [True] * num_res + [False] * num_out) + in_fwd = in_fwd[num_res:] + in_fwd[:num_res] + + # Then, any residuals or other extensive outputs that are forwarded extensive + # inputs, we remove them from the primal jaxpr, and manually forward them. 
+ in_fwd = [in_idx if out_idx >= num_carry and in_idx is not None and + in_idx >= num_consts + num_carry else None + for out_idx, in_idx in enumerate(in_fwd)] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, + [i is None for i in in_fwd]) + + out = scan_p.bind(*primals_in, jaxpr=primal_jaxpr, reverse=reverse, + length=length, num_consts=num_consts, num_carry=num_carry, + linear=linear, unroll=unroll, _split_transpose=_split_transpose) + out_ = iter(out) + all_out = [next(out_) if f is None else _maybe_put(primals_in[f]) for f in in_fwd] + assert next(out_, None) is None + primals_out, extensive_res = split_list(all_out, [len(all_out) - num_res]) + res = [*intensive_res, *extensive_res] + + def tangent_fun(res, *tangents): + intensive_res, extensive_res = split_list(res, [num_intensive_res]) + nz_tangents = [ad.instantiate_zeros(x) for nz, x in zip(nzs, tangents) if nz] + tangent_linear = ( + (False,) * len(intensive_res) + + (True,) * len(nz_tangents) + + (False,) * len(extensive_res) + ) + tangent_num_consts = len(intensive_res) + sum(nzs[:num_consts]) + tangent_num_carry = sum(nzs[num_consts:num_consts + num_carry]) + nz_tangents_out = scan_p.bind(*intensive_res, *nz_tangents, *extensive_res, + jaxpr=tangent_jaxpr, + reverse=reverse, length=length, + num_consts=tangent_num_consts, + num_carry=tangent_num_carry, + linear=tangent_linear, unroll=unroll, + _split_transpose=_split_transpose) + tangent_avals_out = [v.aval.to_tangent_aval() for v in jaxpr.jaxpr.outvars] + nz_tangents_out_ = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval) + for aval, nz in zip(tangent_avals_out, nzs_out)] + assert next(nz_tangents_out_, None) is None + return tangents_out + + return primals_out, nzs_out, res, tangent_fun + +def _const_to_intensive_res_forwarding( + primal_jaxpr: core.ClosedJaxpr, + tangent_jaxpr: core.ClosedJaxpr, + num_res: int, + num_consts: int, + primals_in: Sequence[Any], + in_fwd: list[int | None] +) -> tuple[core.ClosedJaxpr, core.ClosedJaxpr, list[Any], list[int | None]]: + const_to_res = [in_idx if in_idx is not None and in_idx < num_consts else None + for in_idx in in_fwd[:num_res]] + new_in_fwd = [f for c, f in zip(const_to_res, in_fwd[:num_res]) if c is None] + new_in_fwd += in_fwd[num_res:] + intensive_res = [primals_in[f] for f in const_to_res if f is not None] + num_out = len(primal_jaxpr.out_avals) - num_res + primal_jaxpr = pe.prune_closed_jaxpr_outputs( + primal_jaxpr, [i is None for i in const_to_res] + [True] * num_out) + num_nz = len(tangent_jaxpr.in_avals) - num_res + tangent_jaxpr = pe.move_binders_to_front( + tangent_jaxpr, [False] * num_nz + [i is not None for i in const_to_res]) + return primal_jaxpr, tangent_jaxpr, intensive_res, new_in_fwd + def _scan_partial_eval(trace, *tracers, reverse: bool, length: int, num_consts: int, num_carry: int, jaxpr: core.ClosedJaxpr, linear: Sequence[bool], @@ -1385,6 +1487,7 @@ def arrange_jaxpr_args_for_wrapped(args): scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp ad.primitive_transposes[scan_p] = _scan_transpose +ad.primitive_linearizations[scan_p] = _scan_linearization pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d32d761ee1fa..54dff47fea32 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -28,6 +28,7 @@ import jax from jax._src 
import core +from jax._src import config from jax import dtypes from jax import lax from jax import random @@ -3298,6 +3299,18 @@ def body_fun(c, _): outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] self.assertAllClose(outs, outs_ref, check_dtypes=False) + def test_scan_diff_of_print(self): + # ref: https://github.com/jax-ml/jax/issues/28738 + def f(c, _): + jax.debug.print("c = {c}", c=c, ordered=True) + return c + 1, None + def g(x): + return jax.lax.scan(f, x, length=2)[0] + with config.use_direct_linearize(True): + jaxpr = jax.make_jaxpr(jax.value_and_grad(g))(1.0) + eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"] + self.assertIn("debug_callback", [e.primitive.name for e in eqn_jaxpr.eqns]) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 4102cf09ec9b30651f12c22f8537f007d6f7e129 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 14 May 2025 20:25:01 +0000 Subject: [PATCH 1183/1769] Simplify the Bazel logic used to add dependencies to tests. A lot of this logic was confusing phrased as conditions over both CPU and GPU build flags. But we can decompose it: * dependencies we add for CPU tests, and * additional dependencies we add for GPU tests. While we are here, also add the necessary pypi dependency for TPU tests. --- jax/BUILD | 13 ++-- jax_plugins/cuda/BUILD.bazel | 33 ---------- jax_plugins/rocm/BUILD.bazel | 1 - jaxlib/jax.bzl | 114 ++++++++++++----------------------- 4 files changed, 48 insertions(+), 113 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 6afcdf2892f4..3f73a4b9e68f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -101,17 +101,22 @@ string_flag( ) config_setting( - name = "disable_jaxlib_and_jax_build", + name = "config_build_jax_true", + flag_values = { + ":build_jax": "true", + }, +) + +config_setting( + name = "config_build_jax_false", flag_values = { - ":build_jaxlib": "false", ":build_jax": "false", }, ) config_setting( - name = "enable_jaxlib_and_jax_py_import", + name = "config_build_jax_wheel", flag_values = { - ":build_jaxlib": "wheel", ":build_jax": "wheel", }, ) diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index c3c20f536cff..7070bf6bc495 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -14,7 +14,6 @@ load( "//jaxlib:jax.bzl", - "if_windows", "py_library_providing_imports_info", "pytype_library", ) @@ -58,35 +57,3 @@ py_library_providing_imports_info( data = [":pjrt_c_api_gpu_plugin.so"], lib_rule = pytype_library, ) - -config_setting( - name = "disable_jaxlib_for_cpu_build", - flag_values = { - "//jax:build_jaxlib": "false", - "@local_config_cuda//:enable_cuda": "False", - }, -) - -config_setting( - name = "disable_jaxlib_for_cuda12_build", - flag_values = { - "//jax:build_jaxlib": "false", - "@local_config_cuda//:enable_cuda": "True", - }, -) - -config_setting( - name = "enable_py_import_for_cpu_build", - flag_values = { - "//jax:build_jaxlib": "wheel", - "@local_config_cuda//:enable_cuda": "False", - }, -) - -config_setting( - name = "enable_py_import_for_cuda12_build", - flag_values = { - "//jax:build_jaxlib": "wheel", - "@local_config_cuda//:enable_cuda": "True", - }, -) diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel index 15e9e627830e..7ee0726e7960 100644 --- a/jax_plugins/rocm/BUILD.bazel +++ b/jax_plugins/rocm/BUILD.bazel @@ -16,7 +16,6 @@ licenses(["notice"]) load( "//jaxlib:jax.bzl", - "if_windows", "py_library_providing_imports_info", "pytype_library", ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl 
index a48c44f406f2..1d4e24720c2e 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -167,70 +167,36 @@ def if_building_jaxlib( if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels if_not_building: the wheels to depend on if we are not depending directly on //jaxlib. """ - return select({ "//jax:config_build_jaxlib_true": if_building, "//jax:config_build_jaxlib_false": if_not_building, "//jax:config_build_jaxlib_wheel": [], }) -def _get_test_deps(deps, backend_independent): - """Returns the test deps for the given backend. - - Args: - deps: the full list of test dependencies - backend_independent: whether the test is backend independent +def _cpu_test_deps(): + """Returns the test depencies needed for a CPU-only JAX test.""" + return select({ + "//jax:config_build_jaxlib_true": [], + "//jax:config_build_jaxlib_false": ["@pypi//jaxlib"], + "//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"], + }) - Returns: - A list of test deps for the given backend. - For CPU builds: - If --//jax:build_jaxlib=true, returns pypi test deps. - If --//jax:build_jaxlib=false, returns jaxlib pypi wheel dep and pypi test deps. - If --//jax:build_jaxlib=wheel, returns jaxlib py_import dep and pypi test deps. - For GPU builds: - If --//jax:build_jaxlib=true, returns pypi test deps and gpu build deps. - If --//jax:build_jaxlib=false, returns jaxlib, jax-cuda-plugin, - jax-cuda-pjrt pypi wheel deps and pypi test deps. - If --//jax:build_jaxlib=wheel, returns jaxlib, - jax-cuda-plugin, jax-cuda-pjrt py_import deps and pypi test deps. - """ - gpu_build_deps = [ - "//jaxlib/cuda:gpu_only_test_deps", - "//jaxlib/rocm:gpu_only_test_deps", - "//jax_plugins:gpu_plugin_only_test_deps", - ] - pypi_test_deps = [d for d in deps if d.startswith("@pypi//")] - - gpu_py_imports = [ - "//jaxlib/tools:jaxlib_py_import", - "//jaxlib/tools:jax_cuda_plugin_py_import", - "//jaxlib/tools:jax_cuda_pjrt_py_import", - ] + pypi_test_deps - cpu_py_imports = [ - "//jaxlib/tools:jaxlib_py_import", - ] + pypi_test_deps - jaxlib_pypi_wheel_deps = [ - "@pypi//jaxlib", - ] + pypi_test_deps - - if backend_independent: - test_deps = pypi_test_deps - gpu_pypi_wheel_deps = jaxlib_pypi_wheel_deps - gpu_py_import_deps = cpu_py_imports - else: - test_deps = gpu_build_deps + pypi_test_deps - gpu_pypi_wheel_deps = jaxlib_pypi_wheel_deps + [ +def _gpu_test_deps(): + """Returns the additional dependencies needed for a GPU test.""" + return select({ + "//jax:config_build_jaxlib_true": [ + "//jaxlib/cuda:gpu_only_test_deps", + "//jaxlib/rocm:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ], + "//jax:config_build_jaxlib_false": [ "@pypi//jax_cuda12_plugin", "@pypi//jax_cuda12_pjrt", - ] - gpu_py_import_deps = gpu_py_imports - - return select({ - "//jax:config_build_jaxlib_true": test_deps, - "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": jaxlib_pypi_wheel_deps, - "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps, - "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports, - "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_import_deps, + ], + "//jax:config_build_jaxlib_wheel": [ + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ], }) def _get_jax_test_deps(deps): @@ -246,28 +212,23 @@ def _get_jax_test_deps(deps): If --//jax:build_jax=false, returns jax pypi wheel dep and transitive pypi test deps. If --//jax:build_jax=wheel, returns jax py_import dep and transitive pypi test deps. 
""" - jax_build_deps = [d for d in deps if not d.startswith("@pypi//")] + non_pypi_deps = [d for d in deps if not d.startswith("@pypi//")] # A lot of tests don't have explicit dependencies on scipy, ml_dtypes, etc. But the tests # transitively depends on them via //jax. So we need to make sure that these dependencies are # included in the test when JAX is built from source. - jax_transitive_pypi_test_deps = {k: "true" for k in py_deps([ + pypi_deps = depset([d for d in deps if d.startswith("@pypi//")]) + pypi_deps = depset(py_deps([ "ml_dtypes", "scipy", "opt_einsum", "flatbuffers", - ])} + ]), transitive = [pypi_deps]).to_list() - # Remove the pypi deps that are already provided by _get_test_deps(). - for d in deps: - if d.startswith("@pypi//") and jax_transitive_pypi_test_deps.get(d): - jax_transitive_pypi_test_deps.pop(d) - return select({ - "//jax:disable_jaxlib_and_jax_build": ["//:jax_wheel_with_internal_test_util"] + - jax_transitive_pypi_test_deps.keys(), - "//jax:enable_jaxlib_and_jax_py_import": ["//:jax_py_import"] + - jax_transitive_pypi_test_deps.keys(), - "//conditions:default": jax_build_deps + jax_transitive_pypi_test_deps.keys(), + return pypi_deps + select({ + "//jax:config_build_jax_false": ["//:jax_wheel_with_internal_test_util"], + "//jax:config_build_jax_wheel": ["//:jax_py_import"], + "//jax:config_build_jax_true": non_pypi_deps, }) # buildifier: disable=function-docstring @@ -316,18 +277,21 @@ def jax_multiplatform_test( test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): test_tags.append("manual") + test_deps = _cpu_test_deps() + _get_jax_test_deps([ + "//jax", + "//jax:test_util", + ] + deps) if backend == "gpu": + test_deps += _gpu_test_deps() test_tags += tf_cuda_tests_tags() + elif backend == "tpu": + test_deps += ["@pypi//libtpu"] native.py_test( name = name + "_" + backend, srcs = srcs, args = test_args, env = env, - deps = _get_test_deps(deps, backend_independent = False) + - _get_jax_test_deps([ - "//jax", - "//jax:test_util", - ] + deps), + deps = test_deps, data = data, shard_count = test_shards, tags = test_tags, @@ -620,13 +584,13 @@ def jax_py_test( env = dict(env) env.setdefault("PYTHONWARNINGS", "error") deps = kwargs.get("deps", []) - test_deps = _get_test_deps(deps, backend_independent = True) + _get_jax_test_deps(deps) + test_deps = _cpu_test_deps() + _get_jax_test_deps(deps) kwargs["deps"] = test_deps py_test(name = name, env = env, **kwargs) def pytype_test(name, **kwargs): deps = kwargs.get("deps", []) - test_deps = _get_test_deps(deps, backend_independent = True) + _get_jax_test_deps(deps) + test_deps = _cpu_test_deps() + _get_jax_test_deps(deps) kwargs["deps"] = test_deps native.py_test(name = name, **kwargs) From 5a6957fa75db31c9f63df8788f12f47b0d060f44 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 14 May 2025 13:50:02 -0700 Subject: [PATCH 1184/1769] Add use_raw_buffers which allows switching the implementation to hold references to raw buffers instead of PjRtBuffers. This fixes an issue where the buffers can be deleted before the transfer is complete, but introduces another problem where if they are donated it will now silently read from donated arrays. Once the underlying runtime exposes usage holds properly, this new codepath should take a usage hold and the old pjrtbuffer path should be removed. 
PiperOrigin-RevId: 758819621 --- jaxlib/BUILD | 1 + jaxlib/py_socket_transfer.cc | 64 +++++++++++++++++++++++++++++++----- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index add6dbd7d92a..7047ddc3edd6 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1033,6 +1033,7 @@ cc_library( "@xla//xla/pjrt:status_casters", "@xla//xla/python:nb_numpy", "@xla//xla/python:types", + "@xla//xla/python:version", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", "@xla//xla/python/pjrt_ifrt:pjrt_dtype", diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc index 114e3c14874d..fde63df8da47 100644 --- a/jaxlib/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -60,6 +60,7 @@ limitations under the License. #include "xla/python/transfer/streaming_ifrt.h" #include "xla/python/transfer/transfer_socket.pb.h" #include "xla/python/types.h" +#include "xla/python/version.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -109,6 +110,7 @@ absl::StatusOr MemorySpaceFromSharding( } } +#if JAX_IFRT_VERSION_NUMBER < 8 class IfrtArrayEntry : public PullTable::Entry { public: struct BufferRef { @@ -153,10 +155,48 @@ class IfrtArrayEntry : public PullTable::Entry { std::shared_ptr state_; size_t xfer_size_; }; +#endif -absl::StatusOr> CreatePullEntry( +absl::StatusOr> CreatePullEntry( const std::vector& arrs, - std::shared_ptr state, size_t xfer_size) { + std::shared_ptr state, size_t xfer_size, + bool use_raw_buffers) { +#if JAX_IFRT_VERSION_NUMBER >= 8 + if (use_raw_buffers) { + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, + pjrt_buf->GetOnDeviceSizeInBytes()); + TF_ASSIGN_OR_RETURN( + auto raw_buffer, + xla::PjRtRawBuffer::CreateRawAliasOfBuffer(pjrt_buf.get())); + refs.push_back( + {pjrt_buf->GetReadyFuture(), std::move(raw_buffer), buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); + } + + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); + refs.push_back({pjrt_buf, buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); +#else std::vector refs; for (auto& arr : arrs) { auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); @@ -170,6 +210,7 @@ absl::StatusOr> CreatePullEntry( } } return tsl::MakeRef(std::move(refs), state, xfer_size); +#endif } class PyTransferServerConnection { @@ -195,7 +236,8 @@ class PyTransferServer { absl::Status Start(xla::ifrt::Client* client, size_t max_num_parallel_copies, size_t xfer_size, const SocketAddress& addr, const std::vector& transport_addresses, - bool supports_pinned_allocator) { + bool supports_pinned_allocator, bool use_raw_buffers) { + use_raw_buffers_ = use_raw_buffers; std::shared_ptr factory; if (transport_addresses.empty()) { factory = BulkTransportFactory::CreateLocal(); @@ -235,8 +277,9 @@ class PyTransferServer { } void AwaitPull(uint64_t uuid, const std::vector& arrs) { - server_->AwaitPull(uuid, xla::ValueOrThrow(CreatePullEntry( - arrs, 
premapped_copier_, xfer_size_))); + server_->AwaitPull( + uuid, xla::ValueOrThrow(CreatePullEntry(arrs, premapped_copier_, + xfer_size_, use_raw_buffers_))); } size_t xfer_size() { return xfer_size_; } @@ -249,6 +292,7 @@ class PyTransferServer { std::shared_ptr server_; std::shared_ptr premapped_copier_; size_t xfer_size_; + bool use_raw_buffers_ = false; }; absl::StatusOr ArraySpecFromShapeDtypeStruct( @@ -394,7 +438,8 @@ void RegisterTransferServerTypes(nanobind::module_& m) { [](xla::nb_class_ptr py_client, std::string address, std::vector transport_addresses_str, size_t max_num_parallel_copies, size_t transfer_size, - bool supports_pinned_allocator) -> PyTransferServer { + bool supports_pinned_allocator, + bool use_raw_buffers) -> PyTransferServer { PyTransferServer result; std::vector transport_addresses; transport_addresses.reserve(transport_addresses_str.size()); @@ -405,7 +450,7 @@ void RegisterTransferServerTypes(nanobind::module_& m) { xla::ThrowIfError(result.Start( py_client->ifrt_client(), max_num_parallel_copies, transfer_size, xla::ValueOrThrow(SocketAddress::Parse(address)), - transport_addresses, supports_pinned_allocator)); + transport_addresses, supports_pinned_allocator, use_raw_buffers)); return result; }, nb::arg("client"), nb::arg("address") = SocketAddress().ToString(), @@ -413,7 +458,10 @@ void RegisterTransferServerTypes(nanobind::module_& m) { nb::arg("max_num_parallel_copies") = 8, nb::arg("transfer_size") = 256 * 1024 * 1024, // Dual pinning not confirmed to be supported. - nb::arg("supports_pinned_allocator") = false); + nb::arg("supports_pinned_allocator") = false, + // Technically unsafe (because a future donation won't wait for the + // transfer to complete). + nb::arg("use_raw_buffers") = false); } } // namespace aux From 011639cf3621c52c00ffc1a24abf7f4dacf19966 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 14 May 2025 21:28:15 +0000 Subject: [PATCH 1185/1769] Reenable CUDA version checks from Python. These had been accidentally broken at some point in the plugin switchover.. --- CHANGELOG.md | 2 + jax/_src/xla_bridge.py | 136 -------------------------------- jax_plugins/cuda/__init__.py | 148 +++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 136 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0c30132c169..9fd4e50304d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. given its name. * Changes + * Additional checking for the versions of CUDA package dependencies was + reenabled, having been accidentally disabled in a previous release. * JAX nightly packages are now published to artifact registry. To install these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation). * `jax.sharding.PartitionSpec` no longer inherits from a tuple. 
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 72a16d5fbe5c..ce0c36fdcca4 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -31,7 +31,6 @@ import pkgutil import platform as py_platform import threading -import traceback from typing import Any, Sequence, Union import warnings @@ -311,141 +310,6 @@ def _check_cuda_compute_capability(devices_to_check): ) -def _check_cuda_versions(raise_on_first_error: bool = False, - debug: bool = False): - assert cuda_versions is not None - results: list[dict[str, Any]] = [] - - def _make_msg(name: str, - runtime_version: int, - build_version: int, - min_supported: int, - debug_msg: bool = False): - if debug_msg: - return (f"Package: {name}\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}") - if min_supported: - req_str = (f"The local installation version must be no lower than " - f"{min_supported}.") - else: - req_str = ("The local installation must be the same version as " - "the version against which JAX was built.") - msg = (f"Outdated {name} installation found.\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}\n" - f"{req_str}") - return msg - - def _version_check(name: str, - get_version, - get_build_version, - scale_for_comparison: int = 1, - min_supported_version: int = 0): - """Checks the runtime CUDA component version against the JAX one. - - Args: - name: Of the CUDA component. - get_version: A function to get the local runtime version of the component. - get_build_version: A function to get the build version of the component. - scale_for_comparison: For rounding down a version to ignore patch/minor. - min_supported_version: An absolute minimum version required. Must be - passed without rounding down. - - Raises: - RuntimeError: If the component is not found, or is of unsupported version, - and if raising the error is not deferred till later. - """ - - build_version = get_build_version() - try: - version = get_version() - except Exception as e: - err_msg = f"Unable to load {name}. Is it installed?" 
- if raise_on_first_error: - raise RuntimeError(err_msg) from e - err_msg += f"\n{traceback.format_exc()}" - results.append({"name": name, "installed": False, "msg": err_msg}) - return - - if not min_supported_version: - min_supported_version = build_version // scale_for_comparison - passed = min_supported_version <= version - - if not passed or debug: - msg = _make_msg(name=name, - runtime_version=version, - build_version=build_version, - min_supported=min_supported_version, - debug_msg=passed) - if not passed and raise_on_first_error: - raise RuntimeError(msg) - else: - record = {"name": name, - "installed": True, - "msg": msg, - "passed": passed, - "build_version": build_version, - "version": version, - "minimum_supported": min_supported_version} - results.append(record) - - _version_check("CUDA", cuda_versions.cuda_runtime_get_version, - cuda_versions.cuda_runtime_build_version, - scale_for_comparison=10, - min_supported_version=12010) - _version_check( - "cuDNN", - cuda_versions.cudnn_get_version, - cuda_versions.cudnn_build_version, - # NVIDIA promise both backwards and forwards compatibility for cuDNN patch - # versions: - # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat - scale_for_comparison=100, - min_supported_version=9100 - ) - _version_check("cuFFT", cuda_versions.cufft_get_version, - cuda_versions.cufft_build_version, - # Ignore patch versions. - scale_for_comparison=100) - _version_check("cuSOLVER", cuda_versions.cusolver_get_version, - cuda_versions.cusolver_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=11400) - _version_check("cuPTI", cuda_versions.cupti_get_version, - cuda_versions.cupti_build_version, - min_supported_version=18) - _version_check("cuBLAS", cuda_versions.cublas_get_version, - cuda_versions.cublas_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=120100) - _version_check("cuSPARSE", cuda_versions.cusparse_get_version, - cuda_versions.cusparse_build_version, - # Ignore patch versions. 
- scale_for_comparison=100, - min_supported_version=12100) - - errors = [] - debug_results = [] - for result in results: - message: str = result['msg'] - if not result['installed'] or not result['passed']: - errors.append(message) - else: - debug_results.append(message) - - join_str = f'\n{"-" * 50}\n' - if debug_results: - print(f'CUDA components status (debug):\n' - f'{join_str.join(debug_results)}') - if errors: - raise RuntimeError(f'Unable to use CUDA because of the ' - f'following issues with CUDA components:\n' - f'{join_str.join(errors)}') def get_num_nodes_from_gpu_topology(topology: str) -> int: try: diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 1be29326c95f..9df7fc69ff1a 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -17,6 +17,8 @@ import logging import os import pathlib +import traceback +from typing import Any from jax._src.lib import triton from jax._src.lib import xla_client @@ -29,8 +31,12 @@ cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' ) + cuda_versions = importlib.import_module( + f'{pkg_name}._versions' + ) except ImportError: cuda_plugin_extension = None + cuda_versions = None else: break @@ -76,11 +82,153 @@ def _get_library_path(): return None +def _check_cuda_versions(raise_on_first_error: bool = False, + debug: bool = False): + assert cuda_versions is not None + results: list[dict[str, Any]] = [] + + def _make_msg(name: str, + runtime_version: int, + build_version: int, + min_supported: int, + debug_msg: bool = False): + if debug_msg: + return (f"Package: {name}\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}") + if min_supported: + req_str = (f"The local installation version must be no lower than " + f"{min_supported}.") + else: + req_str = ("The local installation must be the same version as " + "the version against which JAX was built.") + msg = (f"Outdated {name} installation found.\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}\n" + f"{req_str}") + return msg + + def _version_check(name: str, + get_version, + get_build_version, + scale_for_comparison: int = 1, + min_supported_version: int = 0): + """Checks the runtime CUDA component version against the JAX one. + + Args: + name: Of the CUDA component. + get_version: A function to get the local runtime version of the component. + get_build_version: A function to get the build version of the component. + scale_for_comparison: For rounding down a version to ignore patch/minor. + min_supported_version: An absolute minimum version required. Must be + passed without rounding down. + + Raises: + RuntimeError: If the component is not found, or is of unsupported version, + and if raising the error is not deferred till later. + """ + + build_version = get_build_version() + try: + version = get_version() + except Exception as e: + err_msg = f"Unable to load {name}. Is it installed?" 
+ if raise_on_first_error: + raise RuntimeError(err_msg) from e + err_msg += f"\n{traceback.format_exc()}" + results.append({"name": name, "installed": False, "msg": err_msg}) + return + + if not min_supported_version: + min_supported_version = build_version // scale_for_comparison + passed = min_supported_version <= version + + if not passed or debug: + msg = _make_msg(name=name, + runtime_version=version, + build_version=build_version, + min_supported=min_supported_version, + debug_msg=passed) + if not passed and raise_on_first_error: + raise RuntimeError(msg) + else: + record = {"name": name, + "installed": True, + "msg": msg, + "passed": passed, + "build_version": build_version, + "version": version, + "minimum_supported": min_supported_version} + results.append(record) + + _version_check("CUDA", cuda_versions.cuda_runtime_get_version, + cuda_versions.cuda_runtime_build_version, + scale_for_comparison=10, + min_supported_version=12010) + _version_check( + "cuDNN", + cuda_versions.cudnn_get_version, + cuda_versions.cudnn_build_version, + # NVIDIA promise both backwards and forwards compatibility for cuDNN patch + # versions: + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/developer/forward-compatibility.html#cudnn-api-compatibility + scale_for_comparison=100, + ) + _version_check("cuFFT", cuda_versions.cufft_get_version, + cuda_versions.cufft_build_version, + # Ignore patch versions. + scale_for_comparison=100) + _version_check("cuSOLVER", cuda_versions.cusolver_get_version, + cuda_versions.cusolver_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=11400) + _version_check("cuPTI", cuda_versions.cupti_get_version, + cuda_versions.cupti_build_version, + min_supported_version=18) + _version_check("cuBLAS", cuda_versions.cublas_get_version, + cuda_versions.cublas_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=120100) + _version_check("cuSPARSE", cuda_versions.cusparse_get_version, + cuda_versions.cusparse_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=12100) + + errors = [] + debug_results = [] + for result in results: + message: str = result['msg'] + if not result['installed'] or not result['passed']: + errors.append(message) + else: + debug_results.append(message) + + join_str = f'\n{"-" * 50}\n' + if debug_results: + print(f'CUDA components status (debug):\n' + f'{join_str.join(debug_results)}') + if errors: + raise RuntimeError(f'Unable to use CUDA because of the ' + f'following issues with CUDA components:\n' + f'{join_str.join(errors)}') + + def initialize(): path = _get_library_path() if path is None: return + if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): + _check_cuda_versions(raise_on_first_error=True) + else: + print('Skipped CUDA versions constraints check due to the ' + 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') + options = xla_client.generate_pjrt_gpu_plugin_options() c_api = xb.register_plugin( 'cuda', priority=500, library_path=str(path), options=options From bf4fda96a936aab5adc6430763419c3bdb3d9495 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 14 May 2025 17:31:51 -0700 Subject: [PATCH 1186/1769] Add numpy and absl/testing dep to custom_api_test. 
Fixes https://github.com/jax-ml/jax/actions/runs/15031061909/job/42243435305 PiperOrigin-RevId: 758898284 --- tests/BUILD | 5 ++++- tests/pallas/BUILD | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index e70d4593e8fc..1d1f3c28b239 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -50,7 +50,10 @@ jax_multiplatform_test( shard_count = 10, deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 49a05ee487f0..21ab7ea1a482 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -427,7 +427,11 @@ jax_multiplatform_test( "//jax:extend", "//jax:pallas_mosaic_gpu", "//jax:test_multiprocess", - ] + py_deps("portpicker"), + ] + py_deps([ + "portpicker", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( From 9a1535e210dcb29f2b0f7494a773ed9aee5b9049 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 14 May 2025 17:57:12 -0700 Subject: [PATCH 1187/1769] [JAX] Fix unhashable slice in api_test `slice` is not hashable before Python 3.12. This change mitigates it by converting it into a hash value. PiperOrigin-RevId: 758905560 --- tests/api_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/api_test.py b/tests/api_test.py index 1fd192a52525..3fe3d6fa7514 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -54,6 +54,7 @@ from jax._src import xla_bridge from jax._src import debugging from jax._src import pjit as pjit_lib +from jax._src import sharding_impls from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import ad as ad_internal from jax._src.interpreters import mlir @@ -2047,10 +2048,12 @@ def test_internal_device_put_assembled(self): per_device_arrs = { # Use uncommitted arrays that are not aligned with the destination # sharding so that we trigger `BatchedDevicePut`. - index: jnp.array(arr[index]) + sharding_impls.hashed_index(index): jnp.array(arr[index]) for _, index in sharding.devices_indices_map(arr.shape).items() } - data_callback = lambda index: per_device_arrs[index] + data_callback = lambda index: per_device_arrs[ + sharding_impls.hashed_index(index) + ] with jtu.count_internal_device_puts() as counts: jax.make_array_from_callback(arr.shape, sharding, data_callback) self.assertEqual( From ec72f173cf98d95d1537b1f3b9f6720e2a032203 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 14 May 2025 18:36:56 -0700 Subject: [PATCH 1188/1769] Fix CI build failure on Mac. We must not depend on the nvidia_nvshmem_cu12 pip package directly since it does not exist on Windows and Mac platforms. 
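As a side note on the unhashable-slice fix in PATCH 1187 above: the sketch
below is illustrative only and is not code from either patch. `hashed_index`
stands in for the real helper in `jax._src.sharding_impls`; only the idea —
hash the slice fields, since `slice` itself is unhashable before Python
3.12 — is taken from that change.

```
# Illustrative reimplementation; the real helper lives in
# jax._src.sharding_impls and may differ in detail.
def hashed_index(index):
  # (start, stop, step) entries are plain ints or None, so this tuple is
  # hashable on every Python version, unlike the slice objects themselves.
  return tuple((s.start, s.stop, s.step) for s in index)

idx = (slice(0, 2), slice(None))
per_device = {hashed_index(idx): "shard for rows 0:2"}
assert per_device[hashed_index((slice(0, 2), slice(None)))] == "shard for rows 0:2"
```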
PiperOrigin-RevId: 758917499 --- jax/BUILD | 1 - jaxlib/jax.bzl | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 3f73a4b9e68f..4018bff873bd 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -887,7 +887,6 @@ pytype_strict_library( py_library_providing_imports_info( name = "mosaic_gpu", srcs = glob(["experimental/mosaic/gpu/*.py"]), - data = py_deps("libnvshmem_device"), visibility = [ ":mosaic_gpu_users", ], diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 1d4e24720c2e..a8dc67eb3804 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -100,7 +100,6 @@ _py_deps = { "tensorstore": get_optional_dep("@pypi//tensorstore"), "torch": [], "zstandard": get_zstandard(), - "libnvshmem_device": ["@pypi//nvidia_nvshmem_cu12"], } def all_py_deps(excluded = []): @@ -188,14 +187,17 @@ def _gpu_test_deps(): "//jaxlib/cuda:gpu_only_test_deps", "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", + "@pypi//nvidia_nvshmem_cu12", ], "//jax:config_build_jaxlib_false": [ "@pypi//jax_cuda12_plugin", "@pypi//jax_cuda12_pjrt", + "@pypi//nvidia_nvshmem_cu12", ], "//jax:config_build_jaxlib_wheel": [ "//jaxlib/tools:jax_cuda_plugin_py_import", "//jaxlib/tools:jax_cuda_pjrt_py_import", + "@pypi//nvidia_nvshmem_cu12", ], }) From 7cbdc3c2defe2bb1fdec55b7157a0d2b43fcd27b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 15 May 2025 00:54:31 -0700 Subject: [PATCH 1189/1769] [pallas] Do not emit verbose lowering errors by default The errors are too verbose and mostly not very useful. PiperOrigin-RevId: 759025165 --- jax/_src/pallas/pallas_call.py | 2 +- tests/pallas/BUILD | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index def8efd472c6..2d27bd3cc485 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1235,7 +1235,7 @@ def _trace_kernel_to_jaxpr( _PALLAS_VERBOSE_ERRORS = config.bool_flag( "jax_pallas_verbose_errors", - default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True), + default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", False), help=( "If True, print verbose error messages for Pallas kernels." ), diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 21ab7ea1a482..109af4213b81 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -162,7 +162,6 @@ jax_multiplatform_test( ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", - "JAX_PALLAS_VERBOSE_ERRORS": "0", }, shard_count = 16, tags = [ @@ -239,9 +238,6 @@ jax_multiplatform_test( "gpu_h100_x32", "gpu_h100", ], - env = { - "JAX_PALLAS_VERBOSE_ERRORS": "0", - }, deps = [ "//jax:pallas", "//jax:pallas_mosaic_gpu", # build_cleaner: keep From 0a0368bb2add564122617bb68caf731942d41279 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 15 May 2025 03:36:51 -0700 Subject: [PATCH 1190/1769] #sdy Properly handle token types in JAX and `ManualComputationOp`. We weren't handling them correctly meaning you couldn't use a `shard_map`/`ManualComputationOp` which has callbacks inside. 
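For concreteness, the kind of program this unblocks looks roughly like the
sketch below. It is not part of the patch: it assumes at least one available
device, and the `shard_map` import path may differ across JAX versions.

```
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map  # path may vary by version
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('i',))

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(x):
  # jax.debug.print lowers to a token-carrying callback, which previously
  # failed inside ManualComputationOp under Shardy.
  jax.debug.print("shard has value {x}", x=x)
  return jnp.cos(x)

jax.jit(f)(jnp.arange(2 * len(jax.devices()), dtype=jnp.float32))
```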
PiperOrigin-RevId: 759072597 --- jax/_src/callback.py | 6 ++- jax/_src/debugging.py | 8 ++-- jax/_src/shard_map.py | 74 +++++++++++++++++++++++++++-------- tests/pjit_test.py | 2 - tests/python_callback_test.py | 2 - tests/shard_map_test.py | 19 ++++++++- 6 files changed, 84 insertions(+), 27 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index bc233b634f3c..7fcccac14950 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -154,13 +154,15 @@ def _callback_op_sharding( " computations" ) if config.use_shardy_partitioner.value: - assert len(avals_out) == 1 + ndim = 0 + if avals_out and isinstance(avals_out[0], core.ShapedArray): + ndim = avals_out[0].ndim op_sharding = sharding_impls.SdyArrayList([ sharding_impls.SdyArray( mesh_shape=(), dim_shardings=[ sharding_impls.SdyDim(axes=[], is_open=False) - ] * avals_out[0].ndim, + ] * ndim, logical_device_ids=())]) else: op_sharding = xc.OpSharding() # type: ignore[assignment] diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index c2febf752b92..e931a6edb9b3 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -164,14 +164,16 @@ def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params) # If we have fully manual sharding during lowering, that means the JAX # program has per-device semantics, so we run the callback on each device. if config.use_shardy_partitioner.value: - assert len(ctx.avals_out) == 1 + ndim = 0 + if ctx.avals_out and isinstance(ctx.avals_out[0], core.ShapedArray): + ndim = ctx.avals_out[0].ndim sharding = sharding_impls.SdyArrayList([ sharding_impls.SdyArray( mesh_shape=(), dim_shardings=[ sharding_impls.SdyDim(axes=[], is_open=False) - ] * ctx.avals_out[0].ndim, - logical_device_ids=())]) + ] * ndim, + logical_device_ids=(0,))]) else: sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MANUAL diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index abcc2ca0acf1..ac529a667d04 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -44,7 +44,7 @@ get_abstract_mesh, get_concrete_mesh) from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import sdy +from jax._src.lib.mlir.dialects import hlo, sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, merge_lists, split_list, subs_list2) @@ -786,6 +786,13 @@ def _shardy_shard_map_sharding( return sdy_sharding +def _shardy_shard_map_token_sharding( + ctx: mlir.LoweringRuleContext, mesh + ) -> ir.Attribute: + ns = _make_scoped_manual_sharding(ctx, mesh, {}) + return ns._to_sdy_sharding(0) + + def _shard_map_lowering_shardy( ctx, in_nodes, jaxpr, mesh, in_names, out_names, manual_axes, check_vma): axis_ctx = ctx.module_context.axis_context @@ -799,36 +806,70 @@ def _shard_map_lowering_shardy( new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()] + num_tokens = len(tokens) manual_axes = order_wrt_mesh(mesh, shardy_manual_axes) if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. 
with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): - out_nodes, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, + args = (*ctx.dim_var_values, *tokens, *in_nodes) + out_nodes, tokens_out = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, + mlir.TokenSet(zip(ctx.tokens_in.effects(), in_nodes[:num_tokens])), + (), *args[num_tokens:], dim_var_values=ctx.dim_var_values) - return out_nodes + num_tokens = len(tokens_out.effects()) + tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip( + ctx.tokens_in.effects(), out_nodes[:num_tokens]))) + ctx.set_tokens_out(tokens_out) + return out_nodes[num_tokens:] - in_shardings = sharding_impls.SdyArrayList(map( + in_shardings = list(map( partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), - in_names, ctx.avals_in)).build() - out_shardings = sharding_impls.SdyArrayList(map( + in_names, ctx.avals_in)) + num_dim_vars = len(ctx.dim_var_values) + in_shardings = ([_shardy_shard_map_token_sharding(ctx, mesh)] + * (num_tokens + num_dim_vars) + in_shardings) + in_shardings = sharding_impls.SdyArrayList(in_shardings).build() + + out_shardings = list(map( partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), - out_names, ctx.avals_out)).build() - output_types = map(mlir.aval_to_ir_type, ctx.avals_out) + out_names, ctx.avals_out)) + out_shardings = [ + _shardy_shard_map_token_sharding(ctx, mesh)] * num_tokens + out_shardings + out_shardings = sharding_impls.SdyArrayList(out_shardings).build() + + output_types = ([hlo.TokenType.get()] * num_tokens + + list(map(mlir.aval_to_ir_type, ctx.avals_out))) + + args = (*ctx.dim_var_values, *tokens, *in_nodes) manual_computation_op = sdy.ManualComputationOp( - output_types, in_nodes, in_shardings, out_shardings, + output_types, + mlir.flatten_ir_values(args), + in_shardings, out_shardings, sdy.ManualAxesAttr.get( ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) block = ir.Block.create_at_start( - manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) + manual_computation_op.body, + (*(i if isinstance(i, ir.Type) else i.type for i in ctx.dim_var_values), + *([hlo.TokenType.get()] * num_tokens), + *map(mlir.aval_to_ir_type, in_avals_))) with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma)): - out_nodes_, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, + out_nodes_, tokens_out = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, + mlir.TokenSet(zip( + ctx.tokens_in.effects(), block.arguments[:num_tokens])), + (), *block.arguments[num_tokens+num_dim_vars:], dim_var_values=ctx.dim_var_values) - sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) + sdy.ReturnOp([ir.Value(x) for x in (*[v for _, v in tokens_out.items()], + *out_nodes_)]) + num_tokens = len(tokens_out.effects()) + tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip( + ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens]))) + ctx.set_tokens_out(tokens_out) - return manual_computation_op.results + return manual_computation_op.results[num_tokens:] def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, @@ -846,7 +887,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): out_nodes_, tokens_out = mlir.call_lowering( "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, - out_avals_, ctx.tokens_in, *in_nodes_, 
dim_var_values=ctx.dim_var_values, + out_avals_, ctx.tokens_in, *in_nodes_, + dim_var_values=ctx.dim_var_values, arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) ctx.set_tokens_out(tokens_out) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d1c8ec7f050d..7d87705e20bd 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4485,8 +4485,6 @@ def test_in_out_shardings_unconstrained_error(self): in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x'))) def test_empty_io_callback_under_shard_map(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/384938613): Failing under shardy.") mesh = jtu.create_mesh((4,), 'i') def empty_callback(x): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 9f7336548d12..9a3b26530044 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1364,8 +1364,6 @@ def f_base(i, x): self.assertEqual(_collected, expected) def test_can_shard_io_callback_manually(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/384938613): Failing under shardy.") mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 00d437aadb08..2fdc846a356b 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -869,8 +869,6 @@ def test_shmap_abstract_mesh_errors(self): @jtu.run_on_devices('cpu', 'gpu', 'tpu') @jtu.thread_unsafe_test() def test_debug_print_jit(self, jit): - if config.use_shardy_partitioner.value: - self.skipTest('TODO(b/384938613): Failing under shardy') mesh = Mesh(jax.devices(), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) @@ -892,6 +890,23 @@ def f(x): for i in range(len(jax.devices())): self.assertIn(f'instance {i} has value', output()) + @jtu.run_on_devices('cpu', 'gpu', 'tpu') + @jtu.thread_unsafe_test() + def test_debug_print_jit_partial_auto(self): + mesh = jtu.create_mesh((2,2), ('x', 'y')) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'), + axis_names=frozenset({'x'})) + def f(x): + idx = jax.lax.axis_index('x') + jax.debug.print("instance {i} has value x={x}", i=idx, x=x) + y = jnp.cos(x) + return y + + f = jax.jit(f) + x = jnp.arange(2 * len(jax.devices())) + f(x) # don't crash! + def test_debug_print_eager(self): mesh = Mesh(jax.devices(), ('i',)) From 9c3b7f07568d8caa0f5a09538b2bb56a2295203c Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 15 May 2025 04:57:47 -0700 Subject: [PATCH 1191/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ab7cea20271d8a24a7309e09fc5af486dde8e155. PiperOrigin-RevId: 759095567 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4a8a44139a67..f15afc93cefb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "caef87114bc75bd89f56b8562d32cd7b8887319f" -XLA_SHA256 = "ae3e9cd59cfcacdd8dbdf7e808b699fc728228654980b0766730a54fcaee0201" +XLA_COMMIT = "ab7cea20271d8a24a7309e09fc5af486dde8e155" +XLA_SHA256 = "4c5b4fb5401b26f140100c308e549ff6fd6d11daabc6b0340fe353b25a4b0725" def repo(): tf_http_archive( From afdf51d797da6b851b80d40fda856040e75c641e Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 15 May 2025 07:14:25 -0700 Subject: [PATCH 1192/1769] #sdy Fix incorrect sharding on a token during a callback. The "add a token" part of the `callback` primitive's MLIR lowering was incorrectly adding a ranked sharding by using the sharding of a ranked tensor. So instead create an unranked sharding explicitly PiperOrigin-RevId: 759135477 --- jax/_src/callback.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 7fcccac14950..5b5ec593a550 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -40,7 +40,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding_impls import SdyArray, SdyArrayList, SingleDeviceSharding +from jax._src.sharding_impls import SdyArray, SdyArrayList, SdyDim, SingleDeviceSharding from jax._src.typing import DeprecatedArg import numpy as np @@ -157,11 +157,11 @@ def _callback_op_sharding( ndim = 0 if avals_out and isinstance(avals_out[0], core.ShapedArray): ndim = avals_out[0].ndim - op_sharding = sharding_impls.SdyArrayList([ - sharding_impls.SdyArray( + op_sharding = SdyArrayList([ + SdyArray( mesh_shape=(), dim_shardings=[ - sharding_impls.SdyDim(axes=[], is_open=False) + SdyDim(axes=[], is_open=False) ] * ndim, logical_device_ids=())]) else: @@ -199,8 +199,8 @@ def _callback_op_sharding( # number of result ops. If there are no result ops, we need 1 shardy # annotation. num_sdy_shardings = max(1, len(avals_out)) - op_sharding = sharding_impls.SdyArrayList(num_sdy_shardings * [ - sharding_impls.SdyArray( + op_sharding = SdyArrayList(num_sdy_shardings * [ + SdyArray( mesh_shape=(), dim_shardings=[], logical_device_ids=(device_index,))]) @@ -838,14 +838,17 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function config.use_shardy_partitioner.value and sharding is not None and len(ctx.avals_out) > 0 - and isinstance(sharding, sharding_impls.SdyArrayList) + and isinstance(sharding, SdyArrayList) ): # Add a sharding annotation for the token if we have at least one # output. Otherwise, the single shardy annotation required of all ops # (even those without any results) can annotate the token. 
-      sharding = sharding_impls.SdyArrayList(
-          [*sharding.shardings, sharding.shardings[-1]]
-      )
+      sharding = SdyArrayList([
+          SdyArray(
+              mesh_shape=(),
+              dim_shardings=[],
+              logical_device_ids=()),
+          *sharding.shardings])
     ctx = dataclasses.replace(
         ctx,
         avals_in=[core.abstract_token, *ctx.avals_in],

From 0984dc8bbcb9406a86111c700cb7cbbb3faedbe8 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Wed, 14 May 2025 21:44:09 +0000
Subject: [PATCH 1193/1769] improve error message when custom_vjp bwd rule
 produces wrong shape/dtype

---
 jax/_src/core.py                          | 20 ++++++++++++++++++++
 jax/_src/custom_derivatives.py            |  3 ++-
 jax/_src/lax/control_flow/common.py       | 20 --------------------
 jax/_src/lax/control_flow/conditionals.py |  6 +++---
 jax/_src/lax/control_flow/loops.py        |  4 ++--
 5 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/jax/_src/core.py b/jax/_src/core.py
index e004263abe71..c730e1c289ae 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -2777,6 +2777,26 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool:
   else:
     return False
 
+def aval_mismatch_extra(a1: AbstractValue, a2: AbstractValue) -> str:
+  assert not typematch(a1, a2)
+  if isinstance(a1, ShapedArray) and isinstance(a2, ShapedArray):
+    mismatches = []
+    if a1.dtype != a2.dtype:
+      mismatches.append('the dtypes do not match')
+    if a1.shape != a2.shape:
+      mismatches.append('the shapes do not match')
+    if a1.vma != a2.vma:
+      mismatches.append('the varying manual axes do not match')
+    # TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch
+
+    if len(mismatches) == 0:
+      return ''
+    elif len(mismatches) == 1:
+      return ', so ' + mismatches[0]
+    else:
+      return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1]
+  return ''
+
 class JaxprTypeError(TypeError): pass
 
 custom_typechecks: dict[Primitive, Callable] = {}
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 7b81c4e86889..9b28595e1835 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -917,7 +917,8 @@ def append(x, d):
             "shape/dtypes as the args tuple of the primal function, but at "
             f"output{keystr(kp)} the bwd rule produced an output of "
             f"shape/dtype {a_.str_short()} corresponding "
-            f"to an input of shape/dtype {a.str_short()}.")
+            f"to an input of shape/dtype {a.str_short()}"
+            f"{core.aval_mismatch_extra(a, a_)}")
         raise ValueError(msg)
       results.append(ct)
     return results
diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index 87dbcd8d3f32..b75cbf6ac708 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -260,23 +260,3 @@ def _show_diff(array1, array2):
 def _avals_short(avals):
   to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
   return ' '.join(map(to_str, avals))
-
-def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
-  assert not core.typematch(a1, a2)
-  if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray):
-    mismatches = []
-    if a1.dtype != a2.dtype:
-      mismatches.append('the dtypes do not match')
-    if a1.shape != a2.shape:
-      mismatches.append('the shapes do not match')
-    if a1.vma != a2.vma:
-      mismatches.append('the varying manual axes do not match')
-    # TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch
-
-    if len(mismatches) == 0:
-      return ''
-    elif len(mismatches) == 1:
-      return ', so ' + mismatches[0]
-    else:
-      return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1]
-  return ''
diff
--git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index d875989921d0..741636c47e31 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -53,8 +53,8 @@ import numpy as np from jax._src.lax.control_flow.common import ( - _avals_short, _typecheck_param, _aval_mismatch_extra, - _initial_style_jaxprs_with_common_consts, _make_closed_jaxpr, _prune_zeros) + _avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts, + _make_closed_jaxpr, _prune_zeros) map, unsafe_map = safe_map, map @@ -351,7 +351,7 @@ def _check_branch_outputs( if not all(map(core.typematch, out_avals1, out_avals2)): diffs = [f'the output of {name1}{component(p)} has type {a1.str_short()}' f' but the corresponding output of {name2} has type ' - f'{a2.str_short()}{_aval_mismatch_extra(a1, a2)}' + f'{a2.str_short()}{core.aval_mismatch_extra(a1, a2)}' for p, a1, a2 in zip(paths, out_avals1, out_avals2) if not core.typematch(a1, a2)] if len(diffs) == 0: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index c85a23b6b199..7efe3294fdca 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -51,7 +51,7 @@ from jax._src.lax.control_flow.common import ( _avals_short, _initial_style_jaxpr, _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, - _typecheck_param, _aval_mismatch_extra) + _typecheck_param) from jax._src.lax.other import logaddexp from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -478,7 +478,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): if not all(_map(core.typematch, in_avals, out_avals)): diffs = [f'{component(path)} has type {in_aval.str_short()}' ' but the corresponding output carry component has type ' - f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' + f'{out_aval.str_short()}{core.aval_mismatch_extra(in_aval, out_aval)}' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)] From 594e1d233a48c7b1f5eb6372e9baf014cc21f794 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 15 May 2025 13:34:47 -0400 Subject: [PATCH 1194/1769] Workaround a crash on aarch64 due to a NumPy bug. See: https://github.com/numpy/numpy/issues/28843 --- .bazelrc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.bazelrc b/.bazelrc index 53676637c839..367138b2d026 100644 --- a/.bazelrc +++ b/.bazelrc @@ -244,6 +244,10 @@ build:ci_linux_aarch64_base --config=clang --verbose_failures=true build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" build:ci_linux_aarch64_base --color=yes +# Workaround for https://github.com/numpy/numpy/issues/28843 +# TODO(phawkins): remove this after upgrading to NumPy 2.2.6. 
+build:ci_linux_aarch64_base --test_env=OMP_NUM_THREADS=8 + build:ci_linux_aarch64 --config=ci_linux_aarch64_base build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" From 2618a231bead6de7611c9e2ae4f408c184a743f2 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 15 May 2025 14:21:03 -0700 Subject: [PATCH 1195/1769] [pre-commit] bump pre-commit version to v5.0.0 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3fcfdb54bada..a6697076404f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0 + rev: cef0300fd0fc4d2a87a85fa2093c6b283ea36f4b # frozen: v5.0.0 hooks: - id: check-ast - id: check-merge-conflict From 8ff60c3422e3ac93e5c2de31e4f7ea7fbd90ba92 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 15 May 2025 14:11:50 -0700 Subject: [PATCH 1196/1769] [pre-commit] update mypy to v1.15.0 --- .pre-commit-config.yaml | 2 +- jax/_src/checkify.py | 2 +- jax/_src/export/_export.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a6697076404f..46deb8eb4879 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: ruff - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'bbc3dc1f890007061f18f17e2334f216ea9e5df7' # frozen: v1.14.1 + rev: 'f40886d54c729f533f864ed6ce584e920feb0af7' # frozen: v1.15.0 hooks: - id: mypy files: (jax/|tests/typing_test\.py) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index aa9bfe9529ce..f26b4222b23b 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -264,7 +264,7 @@ def _get_batched_exception(self) -> BatchedError | None: cur_effect = None for error_effect, code in self._code.items(): if self._pred[error_effect][idx]: # type: ignore - if min_code is None or code[idx] < min_code: + if min_code is None or code[idx] < min_code: # type: ignore[index] min_code = code[idx] # type: ignore cur_effect = error_effect diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index d5a328bb8e05..c0ca1e108590 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -760,8 +760,8 @@ def export_sharding(s: LoweringSharding, elif cur_mesh.shape_tuple != sharding.mesh.shape_tuple: raise ValueError( "Mesh for all inputs/outputs should be equal. 
Got one mesh " - f"{cur_mesh} on an array {cur_arg._aval} at " - f"{shape_poly.args_kwargs_path_to_str(cur_k_path)} and another mesh: " + f"{cur_mesh} on an array {cur_arg._aval} at " # type: ignore[union-attr] + f"{shape_poly.args_kwargs_path_to_str(cur_k_path)} and another mesh: " # type: ignore[arg-type] f"{sharding.mesh}' on a tensor {arg._aval} at " f"{shape_poly.args_kwargs_path_to_str(k_path)}") if cur_mesh and isinstance(cur_mesh, mesh_lib.Mesh): From a043de0e7ee4cfe3c793b16b43d4626efdc55fc1 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 8 May 2025 20:51:35 +0000 Subject: [PATCH 1197/1769] [hijax] landing prototype pieces shouldn't affect existing behaviors, or trace time The main implementation ideas: * each Trace is tagged with a `requires_low: bool` * each Jaxpr * is tagged with an `is_high: bool`, default False but set True while tracing if any hijax primitives are encountered * includes an `mut_types: dict[Var, HijaxType]` indicating final types for type-changing mutable hijax types * each AbstractValue is tagged by a `mutable: bool` which is read to populate `mut_types` * each Primitive * has an `is_high(**params) -> bool` method (depends on params for HOPs) * has a `to_lojax(*args, **params)` method taking and returning hijaxtypes-wrapping-lowtracers * in `Primitive.bind`, we check if `prim.is_high(**params) and trace.requires_low`, and if so we call `prim.to_lojax` Co-authored-by: Dougal Maclaurin --- jax/_src/ad_util.py | 5 + jax/_src/api.py | 24 ++- jax/_src/core.py | 36 +++- jax/_src/interpreters/ad.py | 6 +- jax/_src/interpreters/partial_eval.py | 68 ++++-- jax/_src/pjit.py | 37 +++- tests/attrs_test.py | 293 ++++++++++++++++++++++++++ 7 files changed, 434 insertions(+), 35 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 8cfd7b214338..4e9616e48375 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,6 +31,9 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + ty = core.typeof(x) + if hasattr(ty, 'vspace_add'): # TODO(mattjj,dougalm): revise away hasattr + return ty.vspace_add(x, y) x, y = core.standard_insert_pvary(x, y) return add_jaxvals_p.bind(x, y) @@ -48,6 +51,8 @@ def add_abstract(x, y): return x def zeros_like_aval(aval: core.AbstractValue) -> Array: + if hasattr(aval, 'vspace_zero'): # TODO(mattjj,dougalm): revise away hasattr + return aval.vspace_zero() return aval_zeros_likers[type(aval)](aval) aval_zeros_likers: dict[type, Callable[[Any], Array]] = {} diff --git a/jax/_src/api.py b/jax/_src/api.py index 154ff5132a39..33802e494304 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -540,17 +540,18 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " f"but got {aval.dtype.name}.") - if (dtypes.issubdtype(aval.dtype, dtypes.extended) or - dtypes.issubdtype(aval.dtype, np.integer) or - dtypes.issubdtype(aval.dtype, np.bool_)): - if not allow_int: - raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype " - f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. 
" - "If you want to use Boolean- or integer-valued inputs, use vjp " - "or set allow_int to True.") - elif not dtypes.issubdtype(aval.dtype, np.inexact): - raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " - f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") + if isinstance(aval, ShapedArray): + if (dtypes.issubdtype(aval.dtype, dtypes.extended) or + dtypes.issubdtype(aval.dtype, np.integer) or + dtypes.issubdtype(aval.dtype, np.bool_)): + if not allow_int: + raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype " + f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " + "If you want to use Boolean- or integer-valued inputs, use vjp " + "or set allow_int to True.") + elif not dtypes.issubdtype(aval.dtype, np.inexact): + raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " + f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad") def _check_output_dtype_revderiv(name, holomorphic, x): @@ -1873,6 +1874,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): f"structure; primals have tree structure {tree_def} whereas tangents have " f"tree structure {tree_def_2}.") for p, t in zip(ps_flat, ts_flat): + if not isinstance(core.typeof(p), ShapedArray): continue if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t): raise TypeError("primal and tangent arguments to jax.jvp do not match; " "dtypes must be equal, or in case of int/bool primal dtype " diff --git a/jax/_src/core.py b/jax/_src/core.py index e004263abe71..12e2763d7202 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -88,7 +88,7 @@ class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', - '_effects', '_debug_info'] + '_effects', '_debug_info', '_is_high', '_mut_types'] _constvars: list[Var] _invars: list[Var] @@ -96,6 +96,8 @@ class Jaxpr: _eqns: list[JaxprEqn] _effects: Effects _debug_info: DebugInfo + _is_high: bool + _mut_types: dict[Var, Any] @property def constvars(self) -> list[Var]: @@ -121,6 +123,14 @@ def effects(self) -> Effects: def debug_info(self) -> DebugInfo: return self._debug_info + @property + def is_high(self) -> bool: + return self._is_high + + @property + def mut_types(self) -> dict[Var, Any]: + return self._mut_types + def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], outvars: Sequence[Atom], eqns: Sequence[JaxprEqn], effects: Effects = no_effects, @@ -128,6 +138,8 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # compatibility we have to allow calls when the debug_info # is missing. 
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment] + is_high: bool = False, + mut_types: dict | None = None, ): """ Args: @@ -152,6 +164,8 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # TODO(necula): re-enable these safety checks # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) + self._is_high = is_high + self._mut_types = mut_types or {} def __str__(self): return str(self.pretty_print()) @@ -178,6 +192,8 @@ def replace(self, **kwargs): eqns=kwargs.pop("eqns", self.eqns), effects=kwargs.pop("effects", self.effects), debug_info=kwargs.pop("debug_info", self.debug_info), + is_high=kwargs.pop("is_high", self.is_high), + mut_types=kwargs.pop("mut_types", self.mut_types), ) if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") @@ -517,7 +533,7 @@ def _true_bind(self, *args, **params): for arg in args: if isinstance(arg, Tracer) and not arg._trace.is_valid(): raise escaped_tracer_error(arg) - # TODO: figure out how to handle function arguments + # TODO: figure out how to handle function arguments for this assert # assert (not config.enable_checks.value or # all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args @@ -525,6 +541,10 @@ def _true_bind(self, *args, **params): # is called frequently and it's slightly faster to avoid using a context # manager object. prev_trace = trace_ctx.trace + + if self.is_high(**params) and prev_trace.requires_low: + return self.to_lojax(*args, **params) # type: ignore + trace_ctx.set_trace(eval_trace) try: return self.bind_with_trace(prev_trace, args, params) @@ -561,6 +581,9 @@ def abstract_eval(self, *args, **params): def get_bind_params(self, params): return [], params + def is_high(self, **params) -> bool: + return False + def _effect_free_abstract_eval(abstract_eval): def abstract_eval_(*args, **kwargs): @@ -627,12 +650,13 @@ def check_avals_context_mesh(avals, prim_name): TracerType = TypeVar('TracerType', bound='Tracer') class Trace(Generic[TracerType]): - __slots__ = ("__weakref__", "_invalidated", "_weakref") + __slots__ = ("__weakref__", "_invalidated", "_weakref", "requires_low") def __init__(self): self._invalidated = False # We frequently need a weakref to a trace, so let's precompute one. 
self._weakref = weakref.ref(self) + self.requires_low = True def process_primitive(self, primitive, tracers, params): raise NotImplementedError("must override") @@ -1445,6 +1469,8 @@ def definitely_equal(x, y): class AbstractValue: __slots__: list[str] = [] + is_high = False + mutable = False def to_tangent_aval(self): raise NotImplementedError("must override") @@ -1948,6 +1974,10 @@ def __init__(self, shape, dtype, weak_type=False, *, sharding=None, self.sharding = get_sharding(sharding, self.shape) self.vma = get_vma(vma, self.sharding.mesh) + def lower_val(self, val): return [val] + def raise_val(self, val): return val + def lo_ty(self): return [self] + def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: shape = self.shape diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 45705382efa0..090022c9b6a4 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -73,6 +73,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) + and isinstance(core.typeof(t), core.ShapedArray) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) @@ -475,6 +476,7 @@ def __init__(self, parent_trace, tag): super().__init__() self.tag = tag self.parent_trace = parent_trace + self.requires_low = False def to_primal_tangent_pair(self, val): if isinstance(val, JVPTracer) and val._trace.tag is self.tag: @@ -606,7 +608,8 @@ def process_custom_transpose(self, prim, call, tracers, **params): return map(partial(maybe_jvp_tracer, self), ps_out, ts_out) def maybe_jvp_tracer(trace, primal, tangent): - if type(tangent) is Zero or dtype(tangent) == float0: + if (type(tangent) is Zero or + core.typeof(tangent) is core.ShapedArray and dtype(tangent) == float0): return primal else: return JVPTracer(trace, primal, tangent) @@ -641,6 +644,7 @@ def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = get_aval(primal).strip_weak_type() tangent_aval = get_aval(tangent).strip_weak_type() + if not isinstance(primal_aval, core.ShapedArray): return # TODO(mattjj,dougalm) assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5866b0c5f8eb..9e875f43d831 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -155,6 +155,7 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t self.name_stack = name_stack self.tag = tag self.parent_trace = parent_trace + self.requires_low = False def to_jaxpr_tracer(self, x): if isinstance(x, JaxprTracer) and x._trace.tag is self.tag: @@ -899,9 +900,8 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr: raise NotImplementedError config.enable_checks.value and core.check_jaxpr(jaxpr) env_vars, invars = split_list(jaxpr.invars, [num_env_vars]) - converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars, - invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, - effects=jaxpr.effects, 
debug_info=jaxpr.debug_info) + converted_jaxpr = jaxpr.replace(constvars=jaxpr.constvars + env_vars, + invars=invars) config.enable_checks.value and core.check_jaxpr(converted_jaxpr) return converted_jaxpr @@ -1173,6 +1173,7 @@ def has_effects(effects) -> bool: out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns) out_inst = map(op.or_, out_inst, ensure_out_inst) + ins_known, _ = partition_list(in_unknowns, jaxpr.invars) outs_known, _ = partition_list(out_unknowns, jaxpr.outvars) ref_res_is_input = [r in ins_known for r in residual_refs] @@ -1181,8 +1182,14 @@ def has_effects(effects) -> bool: known_outvars = [*outs_known, *residuals] known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res, known_outvars, known_eqns) - jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars, - known_eqns, known_effects, jaxpr.debug_info) + known_mut, staged_mut, ins_known_ = {}, {}, set(ins_known) # type: ignore + for v, t in jaxpr.mut_types.items(): + [staged_mut, known_mut][v in ins_known_][v] = t + + # TODO(mattjj,necula): debug info should be updated here + jaxpr_known = jaxpr.replace( + invars=ins_known_and_ref_res, outvars=known_outvars, + eqns=known_eqns, effects=known_effects, mut_types=known_mut) config.enable_checks.value and core.check_jaxpr(jaxpr_known) _, ins_staged = partition_list(in_inst, jaxpr.invars) @@ -1190,9 +1197,10 @@ def has_effects(effects) -> bool: staged_invars = [*residuals, *non_input_res_refs, *ins_staged] staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars, outs_staged, staged_eqns) - jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars, - outs_staged, staged_eqns, staged_effects, - jaxpr.debug_info) + # TODO(mattjj,necula): debug info should be updated here + jaxpr_staged = jaxpr.replace( + invars=staged_invars, outvars=outs_staged, eqns=staged_eqns, + effects=staged_effects, mut_types=staged_mut) config.enable_checks.value and core.check_jaxpr(jaxpr_staged) return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals), @@ -1483,7 +1491,8 @@ def write(x: Atom, b: bool) -> None: jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, jaxpr.debug_info.filter_arg_names(used_inputs), jaxpr.debug_info.filter_result_paths(used_outputs)) - new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) + new_jaxpr = jaxpr.replace(invars=invars, outvars=outvars, eqns=eqns, + effects=jaxpr_effects, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(new_jaxpr) return new_jaxpr, used_inputs @@ -1561,9 +1570,8 @@ def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] 
new_invars = _move_to_front(invars, to_move) new_effs = _renumber_effects( (*constvars, *new_invars), (*constvars, *invars), closed_jaxpr.jaxpr.effects) - new_jaxpr = Jaxpr(constvars, new_invars, closed_jaxpr.jaxpr.outvars, - closed_jaxpr.jaxpr.eqns, new_effs, - closed_jaxpr.jaxpr.debug_info) + new_jaxpr = closed_jaxpr.jaxpr.replace( + constvars=constvars, invars=new_invars, effects=new_effs) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr @@ -1704,6 +1712,7 @@ class JaxprStackFrame: attrs_inits: list attrs_vars: list[Var] debug_info: core.DebugInfo + is_high: bool def __init__(self, debug_info: core.DebugInfo): self.gensym = core.gensym() @@ -1718,6 +1727,7 @@ def __init__(self, debug_info: core.DebugInfo): self.attrs_inits = [] self.attrs_vars = [] self.debug_info = debug_info + self.is_high = False def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) @@ -1743,8 +1753,9 @@ def to_jaxpr( outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) + mut_types = {v: v.aval for v in invars if v.aval.mutable} if self.is_high else {} jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, - debug_info) + debug_info, self.is_high, mut_types) jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) @@ -1831,8 +1842,9 @@ def vars(atom: Atom) -> list[Var]: class DynamicJaxprTrace(core.Trace): __slots__ = ("frame", "tag") - def __init__(self, debug_info: core.DebugInfo): + def __init__(self, debug_info: core.DebugInfo, lower=False): super().__init__() + self.requires_low = lower self.frame = JaxprStackFrame(debug_info) def invalidate(self): @@ -2193,10 +2205,11 @@ def trace_to_jaxpr_dynamic( in_avals: Sequence[AbstractValue], *, keep_inputs: list[bool] | None = None, + lower: bool = False, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - trace = DynamicJaxprTrace(fun.debug_info) + trace = DynamicJaxprTrace(fun.debug_info, lower=lower) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): source_info = source_info_util.current() in_tracers = _input_type_to_tracers( @@ -2418,8 +2431,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]: kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars) out_type = tuple(zip(out_avals, kept_outs)) - new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, - jaxpr.effects, jaxpr.debug_info) + new_jaxpr = jaxpr.replace(outvars=outvars) config.enable_checks.value and core.check_jaxpr(jaxpr) return new_jaxpr, out_type @@ -2663,3 +2675,25 @@ def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, lis _, jaxpr = f.f.closure return convert_constvars_jaxpr(jaxpr), [] return jaxpr, consts + + +@weakref_lru_cache +def lower_jaxpr(hi_jaxpr): + in_avals = [lo_ty for t in hi_jaxpr.in_avals for lo_ty in t.lo_ty()] + f = lu.wrap_init(partial(lower_traceable, hi_jaxpr), + debug_info=hi_jaxpr.jaxpr.debug_info) + lo_jaxpr, _, consts, () = trace_to_jaxpr_dynamic(f, in_avals, lower=True) + return core.ClosedJaxpr(lo_jaxpr, consts) + +def lower_traceable(jaxpr, *lo_args): + lo_args_ = iter(lo_args) + hi_args = 
[t.raise_val(*it.islice(lo_args_, len(t.lo_ty()))) + for t in jaxpr.in_avals] + assert (problem := next(lo_args_, None)) is None + hi_outs = core.jaxpr_as_fun(jaxpr)(*hi_args) + in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} + mut_outs = [lo_val for v, ty in jaxpr.jaxpr.mut_types.items() + for lo_val in ty.get(hi_args[in_idx[v]])] + lo_outs = [lo_val for t, hi_val in zip(jaxpr.out_avals, hi_outs) + for lo_val in t.lower_val(hi_val)] + return mut_outs + lo_outs diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ca77b659a08d..5edd74fe74ef 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -20,7 +20,7 @@ import dataclasses from functools import partial import inspect -import itertools +import itertools as it import logging import weakref from typing import NamedTuple, Any, Union, cast @@ -188,7 +188,8 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - if core.trace_state_clean() and not config.debug_key_reuse.value: + if (core.trace_state_clean() and not config.debug_key_reuse.value + and not p.params['jaxpr'].jaxpr.is_high): args_flat = map(core.full_lower, args_flat) core.check_eval_args(args_flat) out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) @@ -1592,6 +1593,36 @@ def check_aval_layout_compatibility( pjit_p.multiple_results = True pjit_p.skip_canonicalization = True +def _is_high(jaxpr, **_) -> bool: + return jaxpr.jaxpr.is_high +pjit_p.is_high = _is_high # type: ignore + +def _to_lojax(*hi_args, jaxpr, out_shardings, out_layouts, **params): + num_mut = [len(ty.lo_ty()) for ty in jaxpr.jaxpr.mut_types.values()] + out_shardings = (UNSPECIFIED,) * sum(num_mut) + out_shardings + out_layouts = (None,) * sum(num_mut) + out_layouts + + lo_args = [lo_val for t, hi_val in zip(jaxpr.in_avals, hi_args) + for lo_val in t.lower_val(hi_val)] + lo_jaxpr = pe.lower_jaxpr(jaxpr) + all_outs = pjit_p.bind(*lo_args, jaxpr=lo_jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts, **params) + out_mut, lo_outs = split_list(all_outs, [sum(num_mut)]) + + out_mut_ = iter(out_mut) + in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} + for var, ty in jaxpr.jaxpr.mut_types.items(): + ty.set(hi_args[in_idx[var]], *it.islice(out_mut_, len(ty.lo_ty()))) + assert next(out_mut_, None) is None + + lo_outs_ = iter(lo_outs) + hi_outs = [t.raise_val(*it.islice(lo_outs_, len(t.lo_ty()))) + for t in jaxpr.out_avals] + assert next(lo_outs_, None) is None + + return hi_outs +pjit_p.to_lojax = _to_lojax + def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): # If device or backend is set, return the default layout. 
This is because you # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu' @@ -3233,7 +3264,7 @@ def _flatten_boxes(dbg, args, kwargs): return args, kwargs, [] box_data = [] id_first_occurrences = {} - idxs = itertools.count() + idxs = it.count() def visit(x): i = next(idxs) if (isinstance(x, (Box, List)) and diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 60a3753a7ba5..b6cef7fec4dc 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -15,7 +15,9 @@ from __future__ import annotations from dataclasses import dataclass +from functools import partial import itertools as it +import unittest from absl.testing import absltest from absl.testing import parameterized @@ -25,6 +27,10 @@ import jax.numpy as jnp from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src.interpreters import ad +from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map @@ -1326,5 +1332,292 @@ def f(lst1, lst2): f(b, b) +class HiPrimitive(core.Primitive): + def __init__(self, name): + self.name = name + ad.primitive_jvps[self] = self.jvp + ad.primitive_transposes[self] = self.transpose + pe.custom_staging_rules[self] = self.staging + + def staging(self, trace, *args, **kwargs): + trace.frame.is_high = True + return trace.default_process_primitive(self, args, kwargs) + + def is_high(self, **params): + return True + + def abstract_eval(self, *arg_avals, **params): + assert False, "must override" + + def to_lojax(self, *lotypes_wrapped_in_hitypes, **params): + assert False, "must override" + + def jvp(self, primals, tangents, **params): + assert False, "must override" + + def transpose(self, *args, **params): + assert False # TODO + + +class HijaxTest(jtu.JaxTestCase): + + def test_custom_types_and_primitive(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + + @dataclass(frozen=True) + class MyArray: + arr: jax.Array # always f32 + + @dataclass(frozen=True) + class MyTy(core.AbstractValue): + mutable = False + + def to_tangent_aval(self): + return MyTy() + def str_short(self, short_dtypes=False): + return 'MyTy' + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + def lower_val(self, hi_val: MyArray) -> list[jax.Array]: + return [hi_val.arr] + def raise_val(self, val) -> MyArray: + return MyArray(val) + + def __eq__(self, other): return isinstance(other, MyTy) + + def vspace_zero(self): + return MyArray(jnp.zeros((), 'float32')) + def vspace_add(self, x, y): + return add(x, y) + + def strip_weak_type(self): return self + def normalize(self): return self + core.pytype_aval_mappings[MyArray] = lambda _: MyTy() + + class ToMy(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, lo_aval): + return MyTy(), set() + + def to_lojax(_, lo): + return MyArray(lo) + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return to(x), to(x_dot) + + def transpose(self, out_bar, _): + return from_(out_bar), + + class FromMy(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_aval): + return hi_aval.lo_ty()[0], set() + + def to_lojax(_, hi): + return hi.arr + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return from_(x), from_(x_dot) + + def transpose(self, out_bar, _): + return to(out_bar), + + def to(x): return to_p.bind(x) + to_p = ToMy('to_my') + + def from_(x): return from_p.bind(x) + from_p = FromMy('from_my') + + def mul(x, y): return mul_p.bind(x, y) + def 
add(x, y): return add_p.bind(x, y) + + class MyMul(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr * hi_y.arr) + + def jvp(_, primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + return mul(x, y), add(mul(x, y_dot), mul(x_dot, y)) + + def transpose(self, out_bar, x, y): + assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y) + if ad.is_undefined_primal(x): + return mul(out_bar, y), None + else: + return None, mul(x, out_bar) + + class MyAdd(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr + hi_y.arr) + + def jvp(_, primals, tangents): + assert False # TODO + + def transpose(self, out_bar, x, y): + return out_bar, out_bar + + mul_p = MyMul('my_mul') + add_p = MyAdd('my_add') + + + @jax.jit + def f(x): + return to(from_(x)) + + # test basic to/from jit + a = MyArray(jnp.ones(())) + b = f(a) # don't crash + self.assertIsInstance(b, MyArray) + self.assertAllClose(b.arr, jnp.ones(())) + + # test basic to/from autodiff + b, b_dot = jax.jvp(f, (a,), (a,)) + self.assertIsInstance(b, MyArray) + self.assertIsInstance(b_dot, MyArray) + + # test mul jit and backward pass + + @jax.jit + def f(x): + return mul(x, x) + + b, f_vjp = jax.vjp(f, a) + self.assertIn('MyTy', str(f_vjp)) + a_grad, = f_vjp(b) + self.assertIsInstance(a_grad, MyArray) + self.assertAllClose(a_grad.arr, 2.0, check_dtypes=False) + + def test_box_autodiff(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + class BoxTy(core.AbstractValue): + mutable = True + + def to_tangent_aval(self): + # NOTE not really used, for some reason we had to write it anyway + return core.ShapedArray((), dtypes.float0) + + def str_short(self, short_dtypes=False): + return 'BoxTy' + + def lower_val(self, box): + return [box._val] + + def raise_val(self, val): + return Box(val) # we're gonna mutate this + + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + + def get(self, box): + return [box._val] + + def set(self, box, val): + box._val = val + + class Box: + def __init__(self, val): + self._val = val + ty = BoxTy() + core.pytype_aval_mappings[Box] = lambda b: b.ty + + + class BoxSet(HiPrimitive): + multiple_results = True + def is_high(self) -> bool: return True + + def abstract_eval(*_, **__): + return [], set() + + def to_lojax(_, box, val): + box._val = val + return [] + + def jvp(_, primals, tangents): + assert False # TODO + + def transpose(_, *args): + assert False # TODO + box_set_p = BoxSet('box_set') + + class BoxGet(HiPrimitive): + def is_high(self) -> bool: return True + + def abstract_eval(*_, **__): + return jnp.dtype('float32'), set() + + def to_lojax(_, box): + return box._val + + def jvp(_, primals, tangents): + assert False # TODO + + def transpose(_, *args): + assert False # TODO + box_get_p = BoxGet('box_get') + + class StashTangents(HiPrimitive): + def is_high(self): + return True + + def abstract_eval(_, box_aval, x_aval): + del box_aval + return x_aval, set() + + def to_lojax(_, box, x): + assert False # TODO + + def jvp(_, primals, tangents): + box, x = primals + _, x_dot = tangents + box_set(box, x_dot) + return x, x_dot + + def transpose(self, *args): + assert False # TODO + stash_tangents_p = StashTangents('stash_tangents') + + def box_set(box, val): + box_set_p.bind(box, 
val)
+
+ def box_get(box):
+ return box_get_p.bind(box)
+
+ def stash_tangents(box, x):
+ return stash_tangents_p.bind(box, x)
+
+ @jax.jit
+ def f(box, x):
+ box_set(box, x)
+
+ box = Box(0.0)
+ f(box, 1.)
+ self.assertAllClose(box_get(box), 1.0, check_dtypes=False)
+
+ @jax.jit
+ def f(box, x):
+ x = stash_tangents(box, x)
+ return x
+
+ box = Box(0.0)
+ jax.jvp(partial(f, box), (3.,), (5.,))
+ self.assertAllClose(box_get(box), 5.0, check_dtypes=False)
+
 
 if __name__ == '__main__':
 absltest.main(testLoader=jtu.JaxTestLoader())

From 8f919a17f7f19555a911cc76ce4c9722a5228208 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 16 May 2025 01:08:02 -0700
Subject: [PATCH 1198/1769] [JAX] Add ticks around `inputs` to clarify the
 error message.

Currently it looks like this.

```
ValueError: Pytree for `in_specs` and inputs do not match. There are 1 mismatches, including:
* `in_specs` is a tuple of length 1 but inputs is a tuple of length 4, so the lengths do not match
```

PiperOrigin-RevId: 759499528
---
 jax/_src/pallas/core.py | 2 +-
 tests/pallas/pallas_test.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index f68393a7de54..709bb4640241 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -1077,7 +1077,7 @@ def get_grid_mapping(
 if in_specs_tree != in_tree:
 raise ValueError(
 pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
- "inputs", in_tree))
+ "`inputs`", in_tree))
 else:
 flat_in_specs = [no_block_spec] * len(in_avals)

diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 725b3adb4388..1114153b16c2 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -1081,10 +1081,10 @@ def test_pallas_call_in_specs_mismatch_inputs(self):
 pl.BlockSpec((4,), lambda: 0)])
 with self.assertRaisesRegex(
 ValueError,
- re.compile("Pytree for `in_specs` and inputs do not match. "
+ re.compile("Pytree for `in_specs` and `inputs` do not match. "
 "There are 1 mismatches, including:"
 ".* at \\[1\\], `in_specs` is a pytree leaf but "
- "inputs is a.*", re.DOTALL)):
+ "`inputs` is a.*", re.DOTALL)):
 f(a, dict(a=a))

 def test_pallas_call_index_map_wrong_number_of_arguments(self):

From 921ddd545c54b3d898d8629ae328af55ef6fe3e1 Mon Sep 17 00:00:00 2001
From: vfdev
Date: Fri, 16 May 2025 10:50:47 +0200
Subject: [PATCH 1199/1769] Added dict_dict_merge/split_keys_entry_added
 suppressions to the 3.14 TSAN CI job

---
 .github/workflows/tsan-suppressions_3.14.txt | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt
index ec5102502a2b..384560128cfc 100644
--- a/.github/workflows/tsan-suppressions_3.14.txt
+++ b/.github/workflows/tsan-suppressions_3.14.txt
@@ -18,3 +18,7 @@ race:dscal_k_
 race:scal_k_
 race:gemm_beta
 race:gemm_oncopy
+
+# https://github.com/python/cpython/issues/132245
+race:split_keys_entry_added
+race_top:dict_dict_merge

From 5567b58ec4c33ddc5976231938c9cb5ddbe6ab03 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 16 May 2025 04:53:33 -0700
Subject: [PATCH 1200/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/5fee96f09a42daa80283dde9fb7090ba90d9d07a.
PiperOrigin-RevId: 759564260
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index f15afc93cefb..589413130c8e 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "ab7cea20271d8a24a7309e09fc5af486dde8e155"
-XLA_SHA256 = "4c5b4fb5401b26f140100c308e549ff6fd6d11daabc6b0340fe353b25a4b0725"
+XLA_COMMIT = "5fee96f09a42daa80283dde9fb7090ba90d9d07a"
+XLA_SHA256 = "8d5d109185dc4383b7589504e7da769f9ec57c360d2a4810db9b2b407f7a9fa4"

 def repo():
     tf_http_archive(

From 3662851375bf6c26c1a92a1ef3afeea8427e0e0f Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 16 May 2025 10:05:05 -0400
Subject: [PATCH 1201/1769] Add OMP_NUM_THREADS workaround to more aarch64 CI
 configurations.

---
 .bazelrc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.bazelrc b/.bazelrc
index 367138b2d026..79df03863b02 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -383,6 +383,10 @@ build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/
 build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64
 build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base

+# Workaround for https://github.com/numpy/numpy/issues/28843
+# TODO(phawkins): remove this after upgrading to NumPy 2.2.6.
+build:rbe_cross_compile_linux_aarch64 --test_env=OMP_NUM_THREADS=8
+
 # Mac x86
 build:cross_compile_darwin_x86_64 --config=cross_compile_base
 build:cross_compile_darwin_x86_64 --config=nonccl

From 73be65e1bc35f754db3b994f7cf4cd2556269646 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Fri, 16 May 2025 07:34:08 -0700
Subject: [PATCH 1202/1769] [mosaic_gpu] Added support for attaching source
 information to the PTX

The implementation currently forces O=0 due to a suspected bug in the NVPTX
backend.

To get source information:

* Set MOSAIC_GPU_LINE_INFO=1
* Run with --jax_include_full_tracebacks_in_locations=true

PiperOrigin-RevId: 759608368
---
 jaxlib/mosaic/gpu/BUILD | 9 ++++++---
 jaxlib/mosaic/gpu/custom_call.cc | 23 +++++++++++++++++++----
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index 115d0c47cc52..66f13bdac7f5 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -149,8 +149,6 @@ cc_library(
 ":nvshmem",
 ":passes",
 ":target",
- "//jaxlib/cuda:cuda_vendor",
- "//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
 "@com_google_absl//absl/base",
 "@com_google_absl//absl/base:core_headers",
 "@com_google_absl//absl/cleanup",
@@ -200,11 +198,16 @@ cc_library(
 "@llvm-project//mlir:UBToLLVM",
 "@llvm-project//mlir:VectorDialect",
 "@llvm-project//mlir:VectorToLLVM",
- "@tsl//tsl/profiler/lib:traceme",
+ "//jaxlib/cuda:cuda_vendor",
+ "//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
 "@xla//xla/ffi",
 "@xla//xla/ffi:ffi_api",
 "@xla//xla/service:custom_call_status",
 "@xla//xla/service:custom_call_target_registry",
+ "@tsl//tsl/profiler/lib:traceme",
 # TODO(slebedev): Remove once enable-line-info is merged into the upstream
 # ensure-debug-info-scope-on-llvm-func pass in MLIR.
+ "@triton//:TritonLLVMIR", ], alwayslink = True, ) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 7c93d54aff9e..27175c3773e6 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -100,6 +100,7 @@ limitations under the License. #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" #include "tsl/profiler/lib/traceme.h" +#include "triton/Target/LLVMIR/Passes.h" namespace { @@ -174,8 +175,10 @@ mlir::FailureOr GetPassPipeline( mosaic::gpu::registerConvertGpuToLLVMPass(); mosaic::gpu::registerByvalInsertionPass(); mlir::arith::registerArithExpandOpsPass(); + mlir::registerLLVMDIScopePass(); return true; }); + bool emit_line_info = getenv("MOSAIC_GPU_LINE_INFO") != nullptr; return mlir::parsePassPipeline(absl::StrCat( R"( builtin.module( @@ -188,23 +191,35 @@ mlir::FailureOr GetPassPipeline( convert-scf-to-cf, convert-nvvm-to-llvm, expand-strided-metadata, - nvvm-attach-target{O=3 chip=)", - sm, R"( fast=false features=+)", ptx_isa, + nvvm-attach-target{)", + // TODO(slebedev): Always use O=3 once + // https://github.com/llvm/llvm-project/pull/140146 is merged. + emit_line_info ? "O=0" : "O=3", " chip=", sm, " fast=false features=+", + ptx_isa, R"( ftz=false module= triple=nvptx64-nvidia-cuda}, lower-affine, convert-arith-to-llvm{index-bitwidth=0}, convert-index-to-llvm{index-bitwidth=64}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, - gpu.module(strip-debuginfo), + )", + emit_line_info ? "" : "gpu.module(strip-debuginfo),", + R"( gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}), gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}), gpu.module(cse), gpu.module(mosaic-byval-insertion), gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, + )", + // TODO(slebedev): Switch to the ensure-debug-info-scope-on-llvm-func + // pass in MLIR once Triton upstreams its changes. + emit_line_info ? "enable-line-info," : "", + R"( gpu-module-to-binary{format=)" + - mlir::gpu::stringifyCompilationTarget(target).str() + (!nvshmem_path.empty() ? R"( l=)" + nvshmem_path : "") + R"(}, + mlir::gpu::stringifyCompilationTarget(target).str() + + (!nvshmem_path.empty() ? R"( l=)" + nvshmem_path : "") + + (emit_line_info ? " opts=-lineinfo" : "") + R"(}, convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, From 6523226e701ee08ed9e65f28c74dd01c8a49728f Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Fri, 16 May 2025 07:40:52 -0700 Subject: [PATCH 1203/1769] Strip leading zeros from ML_WHEEL_GIT_HASH. They end up being stripped by setuptools, which leads to a mismatch between expected and actual wheel names, which is fatal, as Bazel is expecting only a particular name, not to mention other issues. 
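As a rough illustration (a minimal sketch using the `packaging` library,
assuming it implements the same PEP 440 normalization that setuptools applies
when writing wheel names; the hashes below are made up):

```
>>> from packaging.version import Version
>>> # A purely numeric local-version segment is treated as an integer,
>>> # so its leading zeros disappear when the version is rendered.
>>> str(Version("1.0+001234567"))
'1.0+1234567'
>>> # An alphanumeric segment is kept verbatim, leading zeros and all.
>>> str(Version("1.0+0ab123456"))
'1.0+0ab123456'
```

A 9-character hash prefix that happens to contain only digits hits the first
case, so the wheel name setuptools produces no longer matches the name Bazel
computed. The normalization rules are specified in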
https://peps.python.org/pep-0440/ PiperOrigin-RevId: 759610274 --- build/build.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/build/build.py b/build/build.py index 4a7c745ce9d9..b65f7a49dd8f 100755 --- a/build/build.py +++ b/build/build.py @@ -637,7 +637,10 @@ async def main(): if "ML_WHEEL_BUILD_DATE" in option: wheel_build_date = option.split("=")[-1].replace("-", "") if "ML_WHEEL_GIT_HASH" in option: - wheel_git_hash = option.split("=")[-1][:9] + # Strip leading zeros as they end up being stripped by setuptools, + # which leads to a mismatch between expected and actual wheel names + # https://peps.python.org/pep-0440/ + wheel_git_hash = option.split("=")[-1][:9].lstrip('0') with open(".jax_configure.bazelrc", "w") as f: jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule) From 997978143b57e9e6a237cd16c596e4e74424c90b Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 16 May 2025 08:27:46 -0700 Subject: [PATCH 1204/1769] [Pallas/Mosaic GPU] Fix `get_p` lowering to handle `RefUnion`s correctly. In the case of transformed refs, it's possible that a transform ends up changing the dtype of the reference considered. This typically happens when extracting an aliased ref out of a `RefUnion`, but it could also happen if we were to handle `RefBitcast`s. As a result, querying the dtype of refs by querying their `aval` is not a safe operation. PiperOrigin-RevId: 759624846 --- jax/_src/pallas/mosaic_gpu/lowering.py | 29 ++++++++---------- jax/_src/pallas/mosaic_gpu/primitives.py | 29 +++++++++++------- tests/pallas/mosaic_gpu_test.py | 39 ++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 27 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 9ead4f16c1a6..b611ea4c17f9 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -812,7 +812,6 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mesh=jax_mesh, ) del runtime_smem, grouped_barriers, runtime_barriers - _ = lower_jaxpr_to_mosaic_gpu( module_ctx, launch_ctx, jaxpr, buffers_gmem, consts ) @@ -1288,8 +1287,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): raise TypeError(f"Can only load from references (got {x_ref}).") - - x_aval = ctx.avals_in[0] + dtype = ctx.avals_out[0].dtype transforms = jax.tree.unflatten(tree, leaves) x_smem, transforms = _handle_transforms( @@ -1300,21 +1298,21 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != ( 8, - (swizzle * 8) // pallas_utils.dtype_bitwidth(x_aval.dtype), + (swizzle * 8) // pallas_utils.dtype_bitwidth(dtype), ): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + x_smem, is_signed=mgpu_utils.is_signed(dtype), swizzle=swizzle ) case (): # Handle scalar indexing. 
if not ctx.avals_out[0].shape: - is_signed = mgpu_utils.is_signed(x_aval.dtype) + is_signed = mgpu_utils.is_signed(dtype) val = memref_dialect.load(x_smem, []) return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed) return mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + x_smem, is_signed=mgpu_utils.is_signed(dtype) ) case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") @@ -1325,12 +1323,11 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") - x_aval = ctx.avals_in[0] - transforms = jax.tree.unflatten(tree, leaves) x_smem, transforms = _handle_transforms( ctx, x_smem, transforms, allow_peer_refs=True ) + mlir_dtype = ir.MemRefType(x_smem.type).element_type if transforms: raise NotImplementedError( @@ -1338,7 +1335,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): ) shape = ctx.avals_out[0].shape - ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) + ty = ir.VectorType.get(shape, mlir_dtype) if shape: zero_index = arith_dialect.constant(ir.IndexType.get(), 0) indices = [zero_index for _ in range(len(shape))] @@ -1374,7 +1371,8 @@ def _swap_lowering_rule( transforms = jax.tree.unflatten(tree, leaves) transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT x_smem, transforms = _handle_transforms( - ctx, x_ref, transforms, handle_transposes=not transposed_value, allow_peer_refs=True + ctx, x_ref, transforms, handle_transposes=not transposed_value, + allow_peer_refs=True ) mgpu.warpgroup_barrier() # Make sure reads have completed before we write. match transforms: @@ -1437,16 +1435,15 @@ def _swap_lowering_rule_wg( if not ir.MemRefType.isinstance(x_smem.type): raise TypeError(f"Can only store to references (got {x_smem}).") - x_aval = ctx.avals_in[0] - transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_transforms(ctx, x_smem, transforms, allow_peer_refs=True) + x_smem, transforms = _handle_transforms( + ctx, x_smem, transforms, allow_peer_refs=True) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" ) - - ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) + x_mlir_dtype = ir.MemRefType(x_smem.type).element_type + ty = ir.VectorType.get(shape, x_mlir_dtype) if shape: zero_index = arith_dialect.constant(ir.IndexType.get(), 0) indices = [zero_index for _ in range(len(shape))] diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index f199b7b245c6..0e9319972949 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -85,7 +85,7 @@ def _load_p_lowering_rule( if not isinstance(x_ref, ir.Value) or not ir.MemRefType.isinstance(x_ref.type): raise TypeError(f"Can only load from references (got {x_ref}).") - x_aval = ctx.avals_in[0] + out_aval = ctx.avals_out[0] transforms = jax.tree.unflatten(args_tree, leaves) x_ref, transforms = lowering._handle_transforms(ctx, x_ref, transforms) @@ -93,10 +93,10 @@ def _load_p_lowering_rule( if layout is not None: layout = layout.to_mgpu() - is_signed = mgpu_utils.is_signed(x_aval.dtype) + is_signed = mgpu_utils.is_signed(out_aval.dtype) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (8, swizzle // x_aval.dtype.itemsize): + if 
tiling != (8, swizzle // out_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_ref, @@ -106,8 +106,8 @@ def _load_p_lowering_rule( ) case (): # Handle scalar indexing. - if not ctx.avals_out[0].shape: - is_signed = mgpu_utils.is_signed(x_aval.dtype) + if not out_aval.shape: + is_signed = mgpu_utils.is_signed(out_aval.dtype) val = memref_dialect.load(x_ref, []) return mgpu.FragmentedArray.splat( val, shape=(), layout=layout, is_signed=is_signed @@ -259,7 +259,9 @@ def _copy_smem_to_gmem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - src, src_transforms = lowering._handle_transforms(ctx, src, src_transforms, handle_transposes=False) + src, src_transforms = lowering._handle_transforms( + ctx, src, src_transforms, handle_transposes=False + ) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: ctx.launch_ctx.async_copy( @@ -475,7 +477,9 @@ def _copy_gmem_to_smem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - dst, dst_transforms = lowering._handle_transforms(ctx, dst, dst_transforms, handle_transposes=False) + dst, dst_transforms = lowering._handle_transforms( + ctx, dst, dst_transforms, handle_transposes=False + ) copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms) barrier_indexer = _extract_barrier_indexer( barrier_transforms_treedef.unflatten(flat_barrier_transforms) @@ -921,7 +925,6 @@ def _wgmma_lowering( a_transforms_tree, b_transforms_tree, ): - _, a_aval, *_ = ctx.avals_in lhs_swizzle: int | None = None if a_transforms_tree is not None: a_transforms_leaves, b_transforms_leaves = util.split_list( @@ -942,7 +945,8 @@ def _wgmma_lowering( lhs_transpose = True case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") - swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize + a_mlir_dtype = ir.MemRefType(a.type).element_type + swizzle_elems = lhs_swizzle // mgpu_utils.bytewidth(a_mlir_dtype) if tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") else: @@ -991,7 +995,8 @@ def _wgmma_lowering( raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.") if lhs_swizzle is not None: - swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize + b_mlir_dtype = ir.MemRefType(b.type).element_type + swizzle_elems = rhs_swizzle // mgpu_utils.bytewidth(b_mlir_dtype) if rhs_swizzle != lhs_swizzle: raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle") if rhs_tiling != (8, swizzle_elems): @@ -1917,7 +1922,9 @@ def _inline_mgpu_lowering_rule( assert transforms is None continue assert isinstance(aval, pallas_core.AbstractMemoryRef) - a, user_transforms = lowering._handle_transforms(ctx, a, transforms, handle_transposes=False) + a, user_transforms = lowering._handle_transforms( + ctx, a, transforms, handle_transposes=False + ) # Transforms that do not originate from a MemoryRefTransform are # applied implicitly (eg by emit-pipeline) and therefore we do not # expect the user to pass them to the type. 
The transforms not diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 9d259097f90a..40539463e007 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1671,6 +1671,45 @@ def unpack_i4_as_i8(x): test_as_i8 = jax.lax.convert_element_type(kernel(x), new_dtype=jnp.int8) np.testing.assert_array_equal(test_as_i8[:256], unpack_i4_as_i8(x)) + def test_smem_aliasing_works_for_quantization(self): + self.skip_if_wg_semantics() + shape = (64, 256) + large_ty, small_ty = jnp.bfloat16, jnp.uint4 + large_swizzle = plgpu.SwizzleTransform(64 * jnp.finfo(large_ty).bits // 8) + small_swizzle = plgpu.SwizzleTransform(64 * jnp.iinfo(small_ty).bits // 8) + tiling = plgpu.TilingTransform((8, 64)) + + def kernel(x_gmem, o_gmem): + return pl.run_scoped( + functools.partial(scoped_kernel, x_gmem, o_gmem), + plgpu.RefUnion( + plgpu.SMEM(shape, large_ty, transforms=(tiling, large_swizzle)), + plgpu.SMEM(shape, small_ty, transforms=(tiling, small_swizzle)) + ), + plgpu.Barrier(1, num_barriers=1), + ) + + def scoped_kernel(x_gmem, o_gmem, aliased_ref, barrier): + ref_large_ty, ref_small_ty = aliased_ref + plgpu.copy_gmem_to_smem(x_gmem, ref_small_ty, barrier=barrier) + plgpu.barrier_wait(barrier) + ref_large_ty[...] = ref_small_ty[...].astype(ref_large_ty.dtype) * 3 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(ref_large_ty, o_gmem) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, large_ty), + grid=(1, 1), + ) + key = jax.random.key(42) + x = jax.random.randint(key, shape, 0, 4).astype(small_ty) + expected = x * 3 + np.testing.assert_array_equal(kernel_fn(x), expected) + def test_assigning_to_ref_union_raises(self): @functools.partial( self.pallas_call, From abda2872ce6562807012ad927534d06b8a515f9c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 16 May 2025 10:02:58 -0700 Subject: [PATCH 1205/1769] [Mosaic GPU] Skip tests that need stdout capture when using pytest The prints all go to the stdout captured by pytest instead of being intercepted by `jtu`. 
PiperOrigin-RevId: 759656514 --- tests/mosaic/gpu_test.py | 3 +++ tests/pallas/mosaic_gpu_test.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 03ded0ac446c..62f377f031a0 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -20,6 +20,7 @@ import itertools import math import operator +import sys import re import unittest @@ -236,6 +237,8 @@ def setUp(self): @contextlib.contextmanager def capture_stdout(self): + if "pytest" in sys.modules: + self.skipTest("pytest interacts badly with GPU stdout capture") if mosaic_gpu_lib is None: raise ValueError("Running tests but missing Mosaic GPU extension") with jtu.capture_stdout() as stdout: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 40539463e007..dc35f03843f7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -19,6 +19,7 @@ import operator import os import re +import sys import tempfile from typing import ClassVar @@ -106,6 +107,8 @@ def pallas_call(self, *args, **kwargs): @contextlib.contextmanager def capture_stdout(self): + if "pytest" in sys.modules: + self.skipTest("pytest interacts badly with GPU stdout capture") if mosaic_gpu_lib is None: raise ValueError("Running tests but missing Mosaic GPU extension") with jtu.capture_stdout() as stdout: From 6dad5b1b6718b61f5ef863b8ddbae3c2c8ec4ff4 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 16 May 2025 10:06:04 -0700 Subject: [PATCH 1206/1769] quick fix for internal test broken by #28781 PiperOrigin-RevId: 759658054 --- jax/_src/interpreters/ad.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 090022c9b6a4..29af03416a76 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -609,7 +609,8 @@ def process_custom_transpose(self, prim, call, tracers, **params): def maybe_jvp_tracer(trace, primal, tangent): if (type(tangent) is Zero or - core.typeof(tangent) is core.ShapedArray and dtype(tangent) == float0): + isinstance(core.typeof(tangent), core.ShapedArray) + and dtype(tangent) == float0): return primal else: return JVPTracer(trace, primal, tangent) From 87c84380537959d9df6da7a430f7c32a5db7a3c3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 15 May 2025 14:22:11 -0700 Subject: [PATCH 1207/1769] [pre-commit] bump ruff to v0.11.9 --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46deb8eb4879..71b3d51caaa7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 8983acb92ee4b01924893632cf90af926fa608f0 # frozen: v0.7.0 + rev: 24e02b24b8ab2b7c76225602d13fa60e12d114e6 # frozen: v0.11.9 hooks: - id: ruff diff --git a/pyproject.toml b/pyproject.toml index 03cc78a6dcbb..83b85b0271f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,8 @@ ignore = [ "C901", # Local variable is assigned to but never used "F841", + # Class could be dataclass or namedtuple + "B903", # Raise with from clause inside except block "B904", # Zip without explicit strict parameter From 3454bd275cb481aabb67f7d40dfa07875f201129 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 16 May 2025 15:40:38 -0400 Subject: [PATCH 1208/1769] Fix test failures in GPU CI. 
--- BUILD.bazel | 1 + tests/BUILD | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/BUILD.bazel b/BUILD.bazel index 82e3b4ab5c00..59fd949b7ad8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -113,6 +113,7 @@ genrule( "//jax:experimental/pallas/ops/tpu/random/threefry.py", "//jax/experimental/mosaic/gpu/examples:flash_attention.py", "//jax/experimental/mosaic/gpu/examples:matmul.py", + "//jax:test_multiprocess", ], outs = ["wheel_additives.zip"], cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", diff --git a/tests/BUILD b/tests/BUILD index 1d1f3c28b239..c51d40715d15 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -400,6 +400,7 @@ jax_multiplatform_test( ], shard_count = { "cpu": 5, + "gpu": 5, "tpu": 5, }, tags = ["multiaccelerator"], @@ -1839,7 +1840,7 @@ jax_multiplatform_test( ], shard_count = { "cpu": 50, - "gpu": 10, + "gpu": 20, "tpu": 50, }, tags = [ From d47fc8d928905516a0d4dbfda6f70060068c779d Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 16 May 2025 13:10:11 -0700 Subject: [PATCH 1209/1769] [Mosaic TPU][NFC] Consolidate `getIntConst`. PiperOrigin-RevId: 759728429 --- .../tpu/transforms/apply_vector_layout.cc | 60 ++++++++----------- jaxlib/mosaic/dialect/tpu/util.h | 1 - 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 656be0e677b0..ba1dfc95c66c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -209,23 +209,18 @@ bool incrementIndex(const MutableArrayRef idx, return false; } -FailureOr getIntConst(Value v, bool silent = false) { - if (auto constant_op = v.getDefiningOp()) { - if (auto integer_attr = dyn_cast(constant_op.getValue())) { - return integer_attr.getValue().getSExtValue(); - } - } - if (silent) { - return failure(); +FailureOr expectIntConst(Value v) { + if (auto cst = getIntConst(v)) { + return cst.value(); } return emitError(v.getLoc(), "Expected an integer constant"); } -FailureOr> getIntConstsFromOperandRange( - ValueRange vals, bool silent = false) { +FailureOr> expectIntConstsFromOperandRange( + ValueRange vals) { SmallVector res(vals.size()); for (int i = 0; i < vals.size(); ++i) { - FAILUREOR_ASSIGN_OR_RETURN(res[i], getIntConst(vals[i], silent)); + FAILUREOR_ASSIGN_OR_RETURN(res[i], expectIntConst(vals[i])); } return res; } @@ -265,7 +260,7 @@ FailureOr>> sliceRef( Value c0 = nullptr; SmallVector indices_within_slice(indices.size() - tiling.size(), 0); for (auto tiled_idx : indices.take_back(tiling.size())) { - if (auto cst = getIntConst(tiled_idx, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(tiled_idx)) { indices_within_slice.push_back(*cst); if (!c0) { c0 = builder.create(i32, @@ -1548,7 +1543,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, } FAILUREOR_ASSIGN_OR_RETURN( const SmallVector indices, - getIntConstsFromOperandRange(load_op.getIndices())); + expectIntConstsFromOperandRange(load_op.getIndices())); TPU_ASSERT_EQ_OP(indices.size(), 2); if (indices[1] % ctx.target_shape[1] != 0) { return op.emitOpError("Not implemented: Lane index is not a multiple of ") @@ -1606,8 +1601,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, if (strides[rank - 1] != 1) { return op.emitOpError("Not Implemented: Stride on last dim is not 1"); } - auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true); - if (failed(last_idx)) { + auto 
last_idx = getIntConst(indices[rank - 1]); + if (!last_idx.has_value()) { return op.emitOpError("Not Implemented: Dynamic index on last dim"); } else if (last_idx.value() != 0) { return op.emitOpError("Not Implemented: Index on last dim is not 0"); @@ -1975,7 +1970,7 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op, tpu::StoreOp store_op = cast(op); FAILUREOR_ASSIGN_OR_RETURN( const SmallVector indices, - getIntConstsFromOperandRange(store_op.getIndices())); + expectIntConstsFromOperandRange(store_op.getIndices())); TPU_ASSERT_EQ_OP(indices.size(), 2); if (indices[1] % ctx.target_shape[1] != 0) { return op.emitOpError("Not implemented: Lane index is not a multiple of ") @@ -2143,15 +2138,14 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, return op.emitOpError("Not implemented: unsupported layout for input"); } LayoutOffsets expected_offsets_out = layout_in.offsets(); - auto shift = getIntConst(amount, /*silent=*/true); - const bool has_static_shift = succeeded(shift); + auto shift = getIntConst(amount); int rotated_tiled_dim = op.getDimension() - (op.getType().getRank() - 2); bool has_padding_along_rotation = (rotated_tiled_dim == 0 || rotated_tiled_dim == 1) && op.getType().getShape()[op.getDimension()] % layout.tiling()[rotated_tiled_dim] != 0; - if (has_static_shift && has_padding_along_rotation) { + if (shift.has_value() && has_padding_along_rotation) { // We checked above that there are no implicit dims. const int64_t dim_size = op.getType().getShape()[op.getDimension()]; // TODO(b/337384645): Currently we assume {0, 0} offsets in the input @@ -2173,7 +2167,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, // TODO(b/411170715): Allow sublane rotation once the bug is fixed. // TODO(b/337384645): Support non-zero stride. 
if (has_padding_along_rotation && - (!has_static_shift || + (!shift.has_value() || (rotated_tiled_dim == 0 || (rotated_tiled_dim == 1 && op.getStride().value_or(0) != 0)))) { return op.emitOpError("Not implemented: unsupported unaligned shape"); @@ -2200,19 +2194,19 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, builder.getIntegerAttr(builder.getIndexType(), d)); }; auto modI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return mlirI32Const(cst.value() % d); } return builder.create(v, mlirI32Const(d)); }; auto divI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return mlirI32Const(cst.value() / d); } return builder.create(v, mlirI32Const(d)); }; auto addI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return mlirI32Const(cst.value() + d); } return builder.create(v, mlirI32Const(d)); @@ -2239,8 +2233,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, auto getVmaskByPaddingEnd = [&](Value padding, int dim, int stride = 0) { CHECK(dim == 0 || dim == 1); Value padding_vreg; - if (auto padding_cst = getIntConst(padding, /*silent=*/true); - succeeded(padding_cst)) { + if (auto padding_cst = getIntConst(padding)) { CHECK_GE(padding_cst.value(), 0); CHECK_LE(padding_cst.value(), ctx.target_shape[dim]); padding_vreg = builder.create(DenseElementsAttr::get( @@ -2269,8 +2262,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, // and blend the data from contiguous vregs to emulate circular rotation. auto rotateOnTilingDim = [&](const xla::Array &vregs, const Value &shift, int axis, int stride = 0) { - if (auto shift_cst = getIntConst(shift, /*silent=*/true); - succeeded(shift_cst)) { + if (auto shift_cst = getIntConst(shift)) { if (shift_cst.value() == 0 && stride == 0) { return vregs; } @@ -2395,8 +2387,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0)); SmallVector, 4> chunks; // Handle rotation with static shift. - if (auto shift_cst = getIntConst(shift, /*silent=*/true); - succeeded(shift_cst)) { + if (auto shift_cst = getIntConst(shift)) { int64_t static_shift = shift_cst.value(); if (has_padding_along_rotation) { return lazyRotate(vregs, static_shift, axis); @@ -2519,8 +2510,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, vty.getDimSize(dim)); // After applying stride, we expect all shifts in a vreg are less or // equal to the vreg's lane count for now. 
- if (auto base_amount_cst = getIntConst(base_amount, /*silent=*/true);
- succeeded(base_amount_cst)) {
+ if (auto base_amount_cst = getIntConst(base_amount)) {
 int64_t static_base_amount = base_amount_cst.value();
 auto max_shift_in_vreg = static_base_amount % ctx.target_shape[1] +
 (ctx.target_shape[0] - 1) * stride;
@@ -3163,7 +3153,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
 bool must_support_unaligned_dynamic_index = false;
 if (load_op.getIndices().size() > 1) {
 auto second_minor_idx = load_op.getIndices().take_back(2)[0];
- if (failed(getIntConst(second_minor_idx, /*silent=*/true)) &&
+ if (!getIntConst(second_minor_idx).has_value() &&
 !isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) {
 must_support_unaligned_dynamic_index = true;
 }
@@ -3196,7 +3186,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
 }

 auto add_idx = [&](const Value &v, int64_t d) -> Value {
- if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
+ if (auto cst = getIntConst(v)) {
 return IdxConst(cst.value() + d, builder, op.getLoc());
 }
 return builder.create(v, IdxConst(d, builder, op.getLoc()));
@@ -4476,7 +4466,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
 bool must_support_unaligned_dynamic_index = false;
 if (store_op.getIndices().size() > 1) {
 auto second_minor_idx = store_op.getIndices().take_back(2)[0];
- if (failed(getIntConst(second_minor_idx, /*silent=*/true)) &&
+ if (!getIntConst(second_minor_idx).has_value() &&
 !isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) {
 must_support_unaligned_dynamic_index = true;
 }
@@ -4507,7 +4497,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
 }

 auto add_idx = [&](const Value &v, int64_t d) -> Value {
- if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
+ if (auto cst = getIntConst(v)) {
 return IdxConst(cst.value() + d, builder, op.getLoc());
 }
 return builder.create(v, IdxConst(d, builder, op.getLoc()));
diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h
index 000cb4411e62..ac83d95b715e 100644
--- a/jaxlib/mosaic/dialect/tpu/util.h
+++ b/jaxlib/mosaic/dialect/tpu/util.h
@@ -283,7 +283,6 @@ inline arith::ConstantOp I32Const(int32_t value, ArrayRef shape,
 builder.getIntegerAttr(builder.getI32Type(), value)));
 }

-// TODO(jevinjiang): consolidate this with getIntConst in apply-vector-layout.
 std::optional getIntConst(Value v);

 } // namespace mlir::tpu

From 7412adec21c534f8e4bcc627552f28d162decc86 Mon Sep 17 00:00:00 2001
From: Junwhan Ahn
Date: Fri, 16 May 2025 13:30:22 -0700
Subject: [PATCH 1210/1769] Keep the serialized version of
 `BufferAssignmentProto` in `CompiledMemoryStats` to reduce its overheads

Most users of `CompiledMemoryStats` do not use this field. So it is cheaper in
terms of both CPU and RAM to keep it as a serialized string rather than a
proto. If this continues to be a problem, we can consider introducing a
separate executable API for buffer assignment.
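A caller that still wants the proto can parse it on demand. Roughly (a
sketch; the import path below is hypothetical, standing in for wherever the
generated module for `BufferAssignmentProto` lives in a given build):

```
# Hypothetical import path; substitute the real generated proto module.
from xla.service import buffer_assignment_pb2

def buffer_assignment(stats):
  # The executable now hands back only the serialized bytes, so the
  # deserialization cost is paid here, by the rare caller that needs it.
  proto = buffer_assignment_pb2.BufferAssignmentProto()
  proto.ParseFromString(stats.serialized_buffer_assignment_proto)
  return proto
```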
PiperOrigin-RevId: 759735092 --- jaxlib/xla.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 4020e061b3f4..3412766de6bd 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -490,7 +490,10 @@ NB_MODULE(_jax, m) { &CompiledMemoryStats::host_temp_size_in_bytes) .def_prop_ro("serialized_buffer_assignment_proto", [](const CompiledMemoryStats& cms) -> nb::bytes { -#if JAX_IFRT_VERSION_NUMBER >= 7 +#if JAX_IFRT_VERSION_NUMBER >= 9 + const std::string& s = cms.serialized_buffer_assignment; + return nb::bytes(s.data(), s.size()); +#elif JAX_IFRT_VERSION_NUMBER >= 7 if (cms.buffer_assignment.has_value()) { std::string s = cms.buffer_assignment->SerializeAsString(); From 1f592543337c8e542e5ebb801c89b083928f6009 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 15 May 2025 23:04:39 +0000 Subject: [PATCH 1211/1769] [hijax] type-changing boxes with pytree contents --- jax/_src/core.py | 15 +- jax/_src/interpreters/partial_eval.py | 15 +- jax/_src/pjit.py | 38 ++- tests/BUILD | 12 + tests/attrs_test.py | 293 ----------------- tests/hijax_test.py | 453 ++++++++++++++++++++++++++ 6 files changed, 512 insertions(+), 314 deletions(-) create mode 100644 tests/hijax_test.py diff --git a/jax/_src/core.py b/jax/_src/core.py index ab97c7ff1c2c..e49173c3df45 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -88,7 +88,7 @@ class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', - '_effects', '_debug_info', '_is_high', '_mut_types'] + '_effects', '_debug_info', '_is_high', '_final_typechange_env'] _constvars: list[Var] _invars: list[Var] @@ -97,7 +97,7 @@ class Jaxpr: _effects: Effects _debug_info: DebugInfo _is_high: bool - _mut_types: dict[Var, Any] + _final_typechange_env: dict[Var, Any] @property def constvars(self) -> list[Var]: @@ -128,8 +128,8 @@ def is_high(self) -> bool: return self._is_high @property - def mut_types(self) -> dict[Var, Any]: - return self._mut_types + def final_typechange_env(self) -> dict[Var, Any]: + return self._final_typechange_env def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], outvars: Sequence[Atom], eqns: Sequence[JaxprEqn], @@ -139,7 +139,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # is missing. 
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment] is_high: bool = False, - mut_types: dict | None = None, + final_typechange_env: dict | None = None, ): """ Args: @@ -165,7 +165,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) self._is_high = is_high - self._mut_types = mut_types or {} + self._final_typechange_env = final_typechange_env or {} def __str__(self): return str(self.pretty_print()) @@ -193,7 +193,8 @@ def replace(self, **kwargs): effects=kwargs.pop("effects", self.effects), debug_info=kwargs.pop("debug_info", self.debug_info), is_high=kwargs.pop("is_high", self.is_high), - mut_types=kwargs.pop("mut_types", self.mut_types), + final_typechange_env=kwargs.pop("final_typechange_env", + self.final_typechange_env), ) if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 9e875f43d831..f77db5443a86 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1183,13 +1183,13 @@ def has_effects(effects) -> bool: known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res, known_outvars, known_eqns) known_mut, staged_mut, ins_known_ = {}, {}, set(ins_known) # type: ignore - for v, t in jaxpr.mut_types.items(): + for v, t in jaxpr.final_typechange_env.items(): [staged_mut, known_mut][v in ins_known_][v] = t # TODO(mattjj,necula): debug info should be updated here jaxpr_known = jaxpr.replace( invars=ins_known_and_ref_res, outvars=known_outvars, - eqns=known_eqns, effects=known_effects, mut_types=known_mut) + eqns=known_eqns, effects=known_effects, final_typechange_env=known_mut) config.enable_checks.value and core.check_jaxpr(jaxpr_known) _, ins_staged = partition_list(in_inst, jaxpr.invars) @@ -1200,7 +1200,7 @@ def has_effects(effects) -> bool: # TODO(mattjj,necula): debug info should be updated here jaxpr_staged = jaxpr.replace( invars=staged_invars, outvars=outs_staged, eqns=staged_eqns, - effects=staged_effects, mut_types=staged_mut) + effects=staged_effects, final_typechange_env=staged_mut) config.enable_checks.value and core.check_jaxpr(jaxpr_staged) return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals), @@ -1713,6 +1713,7 @@ class JaxprStackFrame: attrs_vars: list[Var] debug_info: core.DebugInfo is_high: bool + final_typechange_env: dict def __init__(self, debug_info: core.DebugInfo): self.gensym = core.gensym() @@ -1728,6 +1729,7 @@ def __init__(self, debug_info: core.DebugInfo): self.attrs_vars = [] self.debug_info = debug_info self.is_high = False + self.final_typechange_env = {} def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) @@ -1753,9 +1755,8 @@ def to_jaxpr( outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) - mut_types = {v: v.aval for v in invars if v.aval.mutable} if self.is_high else {} jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, - debug_info, self.is_high, mut_types) + debug_info, self.is_high, self.final_typechange_env) jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, 
self.attrs_tracked) @@ -1872,6 +1873,8 @@ def new_arg(self, aval, source_info: SourceInfo): self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.invars.append(var) + if aval.mutable: + self.frame.final_typechange_env[var] = aval return tracer def new_const(self, c, source_info: SourceInfo): @@ -2692,7 +2695,7 @@ def lower_traceable(jaxpr, *lo_args): assert (problem := next(lo_args_, None)) is None hi_outs = core.jaxpr_as_fun(jaxpr)(*hi_args) in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} - mut_outs = [lo_val for v, ty in jaxpr.jaxpr.mut_types.items() + mut_outs = [lo_val for v, ty in jaxpr.jaxpr.final_typechange_env.items() for lo_val in ty.get(hi_args[in_idx[v]])] lo_outs = [lo_val for t, hi_val in zip(jaxpr.out_avals, hi_outs) for lo_val in t.lower_val(hi_val)] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 5edd74fe74ef..10e7e697e706 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1597,21 +1597,18 @@ def _is_high(jaxpr, **_) -> bool: return jaxpr.jaxpr.is_high pjit_p.is_high = _is_high # type: ignore -def _to_lojax(*hi_args, jaxpr, out_shardings, out_layouts, **params): - num_mut = [len(ty.lo_ty()) for ty in jaxpr.jaxpr.mut_types.values()] - out_shardings = (UNSPECIFIED,) * sum(num_mut) + out_shardings - out_layouts = (None,) * sum(num_mut) + out_layouts +def _to_lojax( *hi_args, jaxpr, **params): + params, num_mutants = _lojax_expand_params(jaxpr, **params) lo_args = [lo_val for t, hi_val in zip(jaxpr.in_avals, hi_args) for lo_val in t.lower_val(hi_val)] lo_jaxpr = pe.lower_jaxpr(jaxpr) - all_outs = pjit_p.bind(*lo_args, jaxpr=lo_jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts, **params) - out_mut, lo_outs = split_list(all_outs, [sum(num_mut)]) + all_outs = pjit_p.bind(*lo_args, jaxpr=lo_jaxpr, **params) + out_mut, lo_outs = split_list(all_outs, [num_mutants]) out_mut_ = iter(out_mut) in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} - for var, ty in jaxpr.jaxpr.mut_types.items(): + for var, ty in jaxpr.jaxpr.final_typechange_env.items(): ty.set(hi_args[in_idx[var]], *it.islice(out_mut_, len(ty.lo_ty()))) assert next(out_mut_, None) is None @@ -1623,6 +1620,31 @@ def _to_lojax(*hi_args, jaxpr, out_shardings, out_layouts, **params): return hi_outs pjit_p.to_lojax = _to_lojax +def _lojax_expand_params( + hi_jaxpr, *, donated_invars, in_shardings, in_layouts, out_shardings, + out_layouts, **params): + # some pjit params match the length of hi_jaxpr.invars/outvars, so when + # lowering we must expand them to match their number of lojax types + def expand(hi_tys, xs): + return tuple(y for hi, x in zip(hi_tys, xs) for y in (x,) * len(hi.lo_ty())) + donated_invars = expand(hi_jaxpr.in_avals , donated_invars) + in_shardings = expand(hi_jaxpr.in_avals , in_shardings ) + in_layouts = expand(hi_jaxpr.in_avals , in_layouts ) + out_shardings = expand(hi_jaxpr.out_avals, out_shardings ) + out_layouts = expand(hi_jaxpr.out_avals, out_layouts ) + + # also, the lo_jaxpr has pure outputs corresponding to mutable hi_jaxpr types + num_mutants = sum(len(hi_ty.lo_ty()) for hi_ty in + hi_jaxpr.jaxpr.final_typechange_env.values()) + out_shardings = (UNSPECIFIED,) * num_mutants + out_shardings + out_layouts = (None,) * num_mutants + out_layouts + + new_params = dict(params, donated_invars=donated_invars, + in_shardings=in_shardings, in_layouts=in_layouts, + out_shardings=out_shardings, out_layouts=out_layouts) + return new_params, num_mutants + + def _resolve_in_layouts(args, jit_in_layouts, 
resolved_in_shardings, in_avals): # If device or backend is set, return the default layout. This is because you # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu' diff --git a/tests/BUILD b/tests/BUILD index c51d40715d15..35695c75c79f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1875,6 +1875,18 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "hijax_test", + srcs = ["hijax_test.py"], + deps = [ + "//jax:experimental", + ] + py_deps([ + "numpy", + "absl/testing", + ]), +) + + jax_multiplatform_test( name = "colocated_python_test", srcs = ["colocated_python_test.py"], diff --git a/tests/attrs_test.py b/tests/attrs_test.py index b6cef7fec4dc..60a3753a7ba5 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -15,9 +15,7 @@ from __future__ import annotations from dataclasses import dataclass -from functools import partial import itertools as it -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -27,10 +25,6 @@ import jax.numpy as jnp from jax._src import config -from jax._src import core -from jax._src import dtypes -from jax._src.interpreters import ad -from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map @@ -1332,292 +1326,5 @@ def f(lst1, lst2): f(b, b) -class HiPrimitive(core.Primitive): - def __init__(self, name): - self.name = name - ad.primitive_jvps[self] = self.jvp - ad.primitive_transposes[self] = self.transpose - pe.custom_staging_rules[self] = self.staging - - def staging(self, trace, *args, **kwargs): - trace.frame.is_high = True - return trace.default_process_primitive(self, args, kwargs) - - def is_high(self, **params): - return True - - def abstract_eval(self, *arg_avals, **params): - assert False, "must override" - - def to_lojax(self, *lotypes_wrapped_in_hitypes, **params): - assert False, "must override" - - def jvp(self, primals, tangents, **params): - assert False, "must override" - - def transpose(self, *args, **params): - assert False # TODO - - -class HijaxTest(jtu.JaxTestCase): - - def test_custom_types_and_primitive(self): - if config.enable_x64.value: raise unittest.SkipTest("no x64") - - @dataclass(frozen=True) - class MyArray: - arr: jax.Array # always f32 - - @dataclass(frozen=True) - class MyTy(core.AbstractValue): - mutable = False - - def to_tangent_aval(self): - return MyTy() - def str_short(self, short_dtypes=False): - return 'MyTy' - def lo_ty(self): - return [core.ShapedArray((), jnp.dtype('float32'))] - def lower_val(self, hi_val: MyArray) -> list[jax.Array]: - return [hi_val.arr] - def raise_val(self, val) -> MyArray: - return MyArray(val) - - def __eq__(self, other): return isinstance(other, MyTy) - - def vspace_zero(self): - return MyArray(jnp.zeros((), 'float32')) - def vspace_add(self, x, y): - return add(x, y) - - def strip_weak_type(self): return self - def normalize(self): return self - core.pytype_aval_mappings[MyArray] = lambda _: MyTy() - - class ToMy(HiPrimitive): - def is_high(self): return True - - def abstract_eval(_, lo_aval): - return MyTy(), set() - - def to_lojax(_, lo): - return MyArray(lo) - - def jvp(_, primals, tangents): - x, x_dot = *primals, *tangents - return to(x), to(x_dot) - - def transpose(self, out_bar, _): - return from_(out_bar), - - class FromMy(HiPrimitive): - def is_high(self): return True - - def abstract_eval(_, hi_aval): - return hi_aval.lo_ty()[0], set() - - def to_lojax(_, hi): - return hi.arr - - def jvp(_, primals, tangents): - x, x_dot = 
*primals, *tangents - return from_(x), from_(x_dot) - - def transpose(self, out_bar, _): - return to(out_bar), - - def to(x): return to_p.bind(x) - to_p = ToMy('to_my') - - def from_(x): return from_p.bind(x) - from_p = FromMy('from_my') - - def mul(x, y): return mul_p.bind(x, y) - def add(x, y): return add_p.bind(x, y) - - class MyMul(HiPrimitive): - def is_high(self): return True - - def abstract_eval(_, hi_x, hi_y): - if hi_x != hi_y: raise Exception - return hi_x, set() - - def to_lojax(_, hi_x, hi_y): - return MyArray(hi_x.arr * hi_y.arr) - - def jvp(_, primals, tangents): - (x, y), (x_dot, y_dot) = primals, tangents - return mul(x, y), add(mul(x, y_dot), mul(x_dot, y)) - - def transpose(self, out_bar, x, y): - assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y) - if ad.is_undefined_primal(x): - return mul(out_bar, y), None - else: - return None, mul(x, out_bar) - - class MyAdd(HiPrimitive): - def is_high(self): return True - - def abstract_eval(_, hi_x, hi_y): - if hi_x != hi_y: raise Exception - return hi_x, set() - - def to_lojax(_, hi_x, hi_y): - return MyArray(hi_x.arr + hi_y.arr) - - def jvp(_, primals, tangents): - assert False # TODO - - def transpose(self, out_bar, x, y): - return out_bar, out_bar - - mul_p = MyMul('my_mul') - add_p = MyAdd('my_add') - - - @jax.jit - def f(x): - return to(from_(x)) - - # test basic to/from jit - a = MyArray(jnp.ones(())) - b = f(a) # don't crash - self.assertIsInstance(b, MyArray) - self.assertAllClose(b.arr, jnp.ones(())) - - # test basic to/from autodiff - b, b_dot = jax.jvp(f, (a,), (a,)) - self.assertIsInstance(b, MyArray) - self.assertIsInstance(b_dot, MyArray) - - # test mul jit and backward pass - - @jax.jit - def f(x): - return mul(x, x) - - b, f_vjp = jax.vjp(f, a) - self.assertIn('MyTy', str(f_vjp)) - a_grad, = f_vjp(b) - self.assertIsInstance(a_grad, MyArray) - self.assertAllClose(a_grad.arr, 2.0, check_dtypes=False) - - def test_box_autodiff(self): - if config.enable_x64.value: raise unittest.SkipTest("no x64") - class BoxTy(core.AbstractValue): - mutable = True - - def to_tangent_aval(self): - # NOTE not really used, for some reason we had to write it anyway - return core.ShapedArray((), dtypes.float0) - - def str_short(self, short_dtypes=False): - return 'BoxTy' - - def lower_val(self, box): - return [box._val] - - def raise_val(self, val): - return Box(val) # we're gonna mutate this - - def lo_ty(self): - return [core.ShapedArray((), jnp.dtype('float32'))] - - def get(self, box): - return [box._val] - - def set(self, box, val): - box._val = val - - class Box: - def __init__(self, val): - self._val = val - ty = BoxTy() - core.pytype_aval_mappings[Box] = lambda b: b.ty - - - class BoxSet(HiPrimitive): - multiple_results = True - def is_high(self) -> bool: return True - - def abstract_eval(*_, **__): - return [], set() - - def to_lojax(_, box, val): - box._val = val - return [] - - def jvp(_, primals, tangents): - assert False # TODO - - def transpose(_, *args): - assert False # TODO - box_set_p = BoxSet('box_set') - - class BoxGet(HiPrimitive): - def is_high(self) -> bool: return True - - def abstract_eval(*_, **__): - return jnp.dtype('float32'), set() - - def to_lojax(_, box): - return box._val - - def jvp(_, primals, tangents): - assert False # TODO - - def transpose(_, *args): - assert False # TODO - box_get_p = BoxGet('box_get') - - class StashTangents(HiPrimitive): - def is_high(self): - return True - - def abstract_eval(_, box_aval, x_aval): - del box_aval - return x_aval, set() - - def to_lojax(_, box, x): - 
assert False # TODO - - def jvp(_, primals, tangents): - box, x = primals - _, x_dot = tangents - box_set(box, x_dot) - return x, x_dot - - def transpose(self, *args): - assert False # TODO - stash_tangents_p = StashTangents('stash_tangents') - - def box_set(box, val): - box_set_p.bind(box, val) - - def box_get(box): - return box_get_p.bind(box) - - def stash_tangents(box, x): - return stash_tangents_p.bind(box, x) - - @jax.jit - def f(box, x): - box_set(box, x) - - box = Box(0.0) - f(box, 1.) - self.assertAllClose(box_get(box), 1.0, check_dtypes=False) - - @jax.jit - def f(box, x): - x = stash_tangents(box, x) - return x - - box = Box(0.0) - jax.jvp(partial(f, box), (3.,), (5.,)) - self.assertAllClose(box_get(box), 5.0, check_dtypes=False) - - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/hijax_test.py b/tests/hijax_test.py new file mode 100644 index 000000000000..21034d164d28 --- /dev/null +++ b/tests/hijax_test.py @@ -0,0 +1,453 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +import itertools as it +import unittest + +from absl.testing import absltest + +import jax +import jax.numpy as jnp + +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src.interpreters import ad +from jax._src.interpreters import partial_eval as pe +from jax._src import test_util as jtu +from jax._src.util import safe_zip, safe_map + +config.parse_flags_with_absl() + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + + +# TODO(mattjj,dougalm): move HiPrimitive, Box, etc out of tests and into library +class HiPrimitive(core.Primitive): + def __init__(self, name): + self.name = name + ad.primitive_jvps[self] = self.jvp + ad.primitive_transposes[self] = self.transpose + pe.custom_staging_rules[self] = self.staging + + def staging(self, trace, *args, **kwargs): + trace.frame.is_high = True + return trace.default_process_primitive(self, args, kwargs) + + def is_high(self, **params): + return True + + def abstract_eval(self, *arg_avals, **params): + assert False, "must override" + + def to_lojax(self, *lotypes_wrapped_in_hitypes, **params): + assert False, "must override" + + def jvp(self, primals, tangents, **params): + assert False, "must override" + + def transpose(self, *args, **params): + assert False # TODO + + +class BoxTy(core.AbstractValue): + mutable = True + + def __init__(self, leaf_avals, treedef): + self._leaf_avals = leaf_avals # hijax avals + self._treedef = treedef + + # aval interface: hashability and str_short + def __hash__(self): + return hash((self._leaf_avals, self._treedef)) + + def __eq__(self, other): + return (isinstance(other, BoxTy) and self._leaf_avals == other._leaf_avals + and self._treedef == other._treedef) + + def str_short(self, short_dtypes=False): + return 'BoxTy' + + # hijax interface: lower val, raise val, and low type + def 
lo_ty(self): + return [lo_aval for hi_aval in self._leaf_avals for lo_aval in hi_aval.lo_ty()] + + def lower_val(self, box): + leaf_vals, treedef = jax.tree.flatten(box._val) + assert treedef == self._treedef + return [lo_val for hi_aval, hi_val in zip(self._leaf_avals, leaf_vals) + for lo_val in hi_aval.lower_val(hi_val)] + + def raise_val(self, *lo_vals): + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) + for hi_ty in self._leaf_avals] + assert next(lo_vals_, None) is None + return Box(jax.tree.unflatten(self._treedef, hi_vals)) # will be mutated + + # mutable interface: get/set + def get(self, box): + leaf_vals, treedef = jax.tree.flatten(box._val) + assert treedef == self._treedef + return [lo_val for hi_ty, hi_val in zip(self._leaf_avals, leaf_vals) + for lo_val in hi_ty.lower_val(hi_val)] + + def set(self, box, *lo_vals): + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) + for hi_ty in self._leaf_avals] + assert next(lo_vals_, None) is None + box._val = jax.tree.unflatten(self._treedef, hi_vals) + + # TODO placeholder thing + def to_tangent_aval(self): + return core.ShapedArray((), dtypes.float0) # TODO revise placeholder + +class Box: # noqa: F811 + def __init__(self, val): + self._val = val + + @property + def ty(self): + leaves, treedef = jax.tree.flatten(self._val) + leaf_avals = tuple(map(core.typeof, leaves)) + return BoxTy(leaf_avals, treedef) +core.pytype_aval_mappings[Box] = lambda b: b.ty + + +class BoxSet(HiPrimitive): + multiple_results = True + + def is_high(self, *, treedef) -> bool: return True + + def staging(self, trace, box, *leaves, treedef): + super().staging(trace, box, *leaves, treedef=treedef) + avals = tuple(t.aval for t in leaves) + trace.frame.final_typechange_env[trace.getvar(box)] = BoxTy(avals, treedef) + + def abstract_eval(self, box_ty, *leaf_avals, treedef): + return [], set() # TODO better typechecking... 
+ + def to_lojax(_, box, *leaves, treedef): + box._val = jax.tree.unflatten(treedef, leaves) + return [] + + def jvp(_, primals, tangents, *, treedef): + assert False # TODO + + def transpose(_, *args, treedef): + assert False # TODO +box_set_p = BoxSet('box_set') + +def box_set(box, val): + leaves, treedef = jax.tree.flatten(val) + box_set_p.bind(box, *leaves, treedef=treedef) + + +class BoxGet(HiPrimitive): + multiple_results = True + + def is_high(self) -> bool: return True + + def abstract_eval(self, box_ty): + return box_ty._leaf_avals, set() + + def to_lojax(_, box): + return jax.tree.leaves(box._val) + + def jvp(_, primals, tangents): + assert False # TODO + + def transpose(_, *args): + assert False # TODO +box_get_p = BoxGet('box_get') + +def box_get(box): + leaf_vals = box_get_p.bind(box) + return jax.tree.unflatten(core.typeof(box)._treedef, leaf_vals) + + +class HijaxTest(jtu.JaxTestCase): + + def test_custom_types_and_primitive(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + + @dataclass(frozen=True) + class MyArray: + arr: jax.Array # always f32 + + @dataclass(frozen=True) + class MyTy(core.AbstractValue): + mutable = False + + def to_tangent_aval(self): + return MyTy() + def str_short(self, short_dtypes=False): + return 'MyTy' + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + def lower_val(self, hi_val: MyArray) -> list[jax.Array]: + return [hi_val.arr] + def raise_val(self, val) -> MyArray: + return MyArray(val) + + def __eq__(self, other): return isinstance(other, MyTy) + + def vspace_zero(self): + return MyArray(jnp.zeros((), 'float32')) + def vspace_add(self, x, y): + return add(x, y) + core.pytype_aval_mappings[MyArray] = lambda _: MyTy() + + class ToMy(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, lo_aval): + return MyTy(), set() + + def to_lojax(_, lo): + return MyArray(lo) + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return to(x), to(x_dot) + + def transpose(self, out_bar, _): + return from_(out_bar), + + class FromMy(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_aval): + return hi_aval.lo_ty()[0], set() + + def to_lojax(_, hi): + return hi.arr + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return from_(x), from_(x_dot) + + def transpose(self, out_bar, _): + return to(out_bar), + + def to(x): return to_p.bind(x) + to_p = ToMy('to_my') + + def from_(x): return from_p.bind(x) + from_p = FromMy('from_my') + + def mul(x, y): return mul_p.bind(x, y) + def add(x, y): return add_p.bind(x, y) + + class MyMul(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr * hi_y.arr) + + def jvp(_, primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + return mul(x, y), add(mul(x, y_dot), mul(x_dot, y)) + + def transpose(self, out_bar, x, y): + assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y) + if ad.is_undefined_primal(x): + return mul(out_bar, y), None + else: + return None, mul(x, out_bar) + + class MyAdd(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr + hi_y.arr) + + def jvp(_, primals, tangents): + assert False # TODO + + def transpose(self, out_bar, x, y): + return out_bar, out_bar + + mul_p = MyMul('my_mul') + 
add_p = MyAdd('my_add') + + + @jax.jit + def f(x): + return to(from_(x)) + + # test basic to/from jit + a = MyArray(jnp.ones(())) + b = f(a) # don't crash + self.assertIsInstance(b, MyArray) + self.assertAllClose(b.arr, jnp.ones(())) + + # test basic to/from autodiff + b, b_dot = jax.jvp(f, (a,), (a,)) + self.assertIsInstance(b, MyArray) + self.assertIsInstance(b_dot, MyArray) + + # test mul jit and backward pass + + @jax.jit + def f(x): + return mul(x, x) + + b, f_vjp = jax.vjp(f, a) + self.assertIn('MyTy', str(f_vjp)) + a_grad, = f_vjp(b) + self.assertIsInstance(a_grad, MyArray) + self.assertAllClose(a_grad.arr, 2.0, check_dtypes=False) + + def test_box_autodiff(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + + class StashTangents(HiPrimitive): + def is_high(self): + return True + + def abstract_eval(_, box_aval, x_aval): + del box_aval + return x_aval, set() + + def to_lojax(_, box, x): + assert False # TODO + + def jvp(_, primals, tangents): + box, x = primals + _, x_dot = tangents + box_set(box, x_dot) + return x, x_dot + + def transpose(self, *args): + assert False # TODO + stash_tangents_p = StashTangents('stash_tangents') + + def stash_tangents(box, x): + return stash_tangents_p.bind(box, x) + + @jax.jit + def f(box, x): + box_set(box, x) + + box = Box(0.0) + f(box, 1.) + self.assertAllClose(box_get(box), 1.0, check_dtypes=False) + + @jax.jit + def f(box, x): + x = stash_tangents(box, x) + return x + + box = Box(0.0) + jax.jvp(partial(f, box), (3.,), (5.,)) + self.assertAllClose(box_get(box), 5.0, check_dtypes=False) + + def test_type_changing_box(self): + box = Box(jnp.arange(1)) + box_set(box, jnp.arange(2)) + self.assertLen(box._val, 2) + + @jax.jit + def f(box, x): + box_set(box, x) + + f(box, jnp.arange(3)) + self.assertLen(box._val, 3) + f(box, jnp.arange(4)) + self.assertLen(box._val, 4) + + def test_pytree_box(self): + box = Box(None) + + @jax.jit + def f(box, x): + assert tracing_ok + val = box_get(box) + if val is None: + box_set(box, x) + else: + box_set(box, [x, x]) + + tracing_ok = True + f(box, 1.0) + self.assertAllClose(box_get(box), 1.0, check_dtypes=False) + f(box, 2.0) + self.assertAllClose(box_get(box), [2.0, 2.0], check_dtypes=False) + f(box, 3.0) + self.assertAllClose(box_get(box), [3.0, 3.0], check_dtypes=False) + tracing_ok = False + f(box, 4.0) + self.assertAllClose(box_get(box), [4.0, 4.0], check_dtypes=False) + + def test_pytree_of_hijaxtypes_box(self): + + @dataclass(frozen=True) + class MyArray: + arr: jax.Array # always f32 + + @dataclass(frozen=True) + class MyTy(core.AbstractValue): + mutable = False + + def to_tangent_aval(self): + return MyTy() + def str_short(self, short_dtypes=False): + return 'MyTy' + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + def lower_val(self, hi_val: MyArray) -> list[jax.Array]: + return [hi_val.arr] + def raise_val(self, val) -> MyArray: + return MyArray(val) + + def __eq__(self, other): return isinstance(other, MyTy) + + core.pytype_aval_mappings[MyArray] = lambda _: MyTy() + + box = Box([MyArray(jnp.float32(1)), + MyArray(jnp.float32(2))]) + + @jax.jit + def f(box): + a, b = box_get(box) + box_set(box, [b, a]) + + f(box) + val = box_get(box) + self.assertIsInstance(val, list) + self.assertLen(val, 2) + b_, a_ = val + self.assertIsInstance(a_, MyArray) + self.assertIsInstance(b_, MyArray) + self.assertAllClose(a_.arr, 1, check_dtypes=False) + self.assertAllClose(b_.arr, 2, check_dtypes=False) + + +if __name__ == '__main__': + 
absltest.main(testLoader=jtu.JaxTestLoader()) From 4097840c58599971cde03c8d5044b7336bcad085 Mon Sep 17 00:00:00 2001 From: Zhuo Peng Date: Fri, 16 May 2025 17:05:19 -0700 Subject: [PATCH 1212/1769] Added a test case to guard JAX ad, jax2tf in jax.lax.scan. PiperOrigin-RevId: 759808188 --- jax/experimental/jax2tf/tests/jax2tf_test.py | 48 ++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 3052b532cb97..db608adc3dde 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1694,6 +1694,54 @@ def f_jax(x): "Unsupported precision in dot_general"): jax2tf.convert(f_jax, native_serialization=False)(x) + def test_jvp_through_loop(self): + # Context: b/388929258 + + num_actions = 512 + + def tf_preprocessor(features): + features["num_c_actions"] = tf.constant(256, tf.int32) + return features + + def postprocessor(prob, features): + actions = jnp.arange(num_actions, dtype=jnp.int32) + r = actions // features["num_c_actions"] + c = actions - r * features["num_c_actions"] + rr = jnp.array([0.12, 0.3])[r] * prob + rc = (jnp.arange(256) * 0.7)[c] * prob + return rr, rc + + def loop_step(features, params): + features = jax2tf.call_tf(tf_preprocessor)(features) + odds = features["f1"] @ params["w1"] + features["f2"] @ params["w2"] + prob = jax.nn.sigmoid(odds) + rr, rc = postprocessor(prob, features) + new_f1 = jnp.mean(rr, keepdims=True) + new_f2 = jnp.mean(rc, keepdims=True) + return new_f1, new_f2 + + def loop(init_features, params): + def body(carry, unused_x): + f1, f2 = carry + return loop_step({"f1": f1, "f2": f2}, params), None + + (rr, rc), _ = jax.lax.scan( + body, (init_features["f1"], init_features["f2"]), length=10 + ) + return rr, rc + + def loss(features, params): + rr, rc = loop(features, params) + return jnp.mean((rr - rc) ** 2) + + jax.grad(loss, argnums=(1,))( + {"f1": jnp.array([0.5]), "f2": jnp.array([0.7])}, + { + "w1": jnp.ones((1, num_actions)) * 0.01, + "w2": jnp.ones((1, num_actions)) * 0.01, + }, + ) + @jtu.with_config(jax_enable_custom_prng=True) class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): From 008f2210f7bec54a9906b74c91384e2b04de8d1f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 17 May 2025 08:43:04 -0400 Subject: [PATCH 1213/1769] Increase some test shardings to work around CI timeouts. --- tests/BUILD | 2 +- tests/pallas/BUILD | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index c417a63404a9..1b33f292d503 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -2002,7 +2002,7 @@ jax_multiplatform_test( ], shard_count = { "cpu": 40, - "gpu": 20, + "gpu": 30, "tpu": 20, }, tags = [ diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 109af4213b81..86c3fe187b79 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -120,7 +120,7 @@ jax_multiplatform_test( ], shard_count = { "cpu": 16, - "gpu": 16, + "gpu": 32, "tpu": 16, }, tags = [ From 89a76088ebc3e78172e1a37cf5186936ef7f6d8b Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 17 May 2025 05:48:54 -0700 Subject: [PATCH 1214/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/582b15b6146c3c20ecc88cc2fc7dfddbe3e63dc0. 
PiperOrigin-RevId: 759975157 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 589413130c8e..678c8660b546 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "5fee96f09a42daa80283dde9fb7090ba90d9d07a" -XLA_SHA256 = "8d5d109185dc4383b7589504e7da769f9ec57c360d2a4810db9b2b407f7a9fa4" +XLA_COMMIT = "582b15b6146c3c20ecc88cc2fc7dfddbe3e63dc0" +XLA_SHA256 = "60ab686b2bd9cc1d58c4e04c21a3ae7dcfd51d1d83c30e9f85b2ef4d4f7a4f96" def repo(): tf_http_archive( From 819476df4eae0826049257fa0e5d6b4ab1e0cb1b Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 18 May 2025 05:55:36 -0700 Subject: [PATCH 1215/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d4ff3a507f46f8102065b598456a99dc17758394. PiperOrigin-RevId: 760260591 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 678c8660b546..e105491ae6af 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "582b15b6146c3c20ecc88cc2fc7dfddbe3e63dc0" -XLA_SHA256 = "60ab686b2bd9cc1d58c4e04c21a3ae7dcfd51d1d83c30e9f85b2ef4d4f7a4f96" +XLA_COMMIT = "d4ff3a507f46f8102065b598456a99dc17758394" +XLA_SHA256 = "fd491423023f7e2cc6f4da75bc95331be3d4f6ba47edef1095f7be1279cd5c27" def repo(): tf_http_archive( From 0a3e8dc9b7dc636b4deba44adc8ca8174d2f562b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 19 May 2025 02:16:07 -0700 Subject: [PATCH 1216/1769] [pallas] Generalized `empty_like` to accept a pytree of shapes/dtypes PiperOrigin-RevId: 760529300 --- jax/_src/pallas/helpers.py | 54 +++++++++++++++--------------- jax/_src/pallas/mosaic_gpu/core.py | 18 ++-------- 2 files changed, 30 insertions(+), 42 deletions(-) diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 684101e47e9e..6b274c0b6cce 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -13,10 +13,7 @@ # limitations under the License. 
"""Pallas helper functions.""" -from typing import Any, Protocol - import jax -import jax.numpy as jnp from jax._src.pallas import pallas_call from jax._src.pallas import core as pl_core @@ -24,39 +21,42 @@ @jax.named_call def empty( shape: tuple[int, ...], - dtype: jnp.dtype, + dtype: jax.typing.DTypeLike, *, - memory_space: Any = None, - interpret: Any = False, + memory_space: object | None = None, + interpret: bool = False, + backend: pl_core.Backend | None = None, ): - def _empty_kernel(_): - # No-op to leave the out_ref uninitialized - pass + return empty_like( + jax.ShapeDtypeStruct(shape, dtype), + memory_space=memory_space, + interpret=interpret, + backend=backend, + ) + +@jax.named_call +def empty_like( + x: object, + *, + memory_space: object | None = None, + interpret: bool = False, + backend: pl_core.Backend | None = None, +): if memory_space is None: - kernel_memory_space = pl_core.MemorySpace.ANY - memory_space = jax.ShapeDtypeStruct - else: - kernel_memory_space = memory_space + memory_space = pl_core.MemorySpace.ANY return pallas_call.pallas_call( - _empty_kernel, - in_specs=[], - out_specs=pl_core.BlockSpec(memory_space=kernel_memory_space), - out_shape=memory_space(shape, dtype), + # No-op to leave the out_ref uninitialized + lambda *_: None, + out_specs=jax.tree.map( + lambda _: pl_core.BlockSpec(memory_space=memory_space), x + ), + out_shape=x, interpret=interpret, + backend=backend, )() -class ArrayLike(Protocol): - shape: tuple[int, ...] - dtype: jnp.dtype - - -def empty_like( - x: ArrayLike, *, memory_space: Any = None, interpret: Any = False): - return empty(x.shape, x.dtype, memory_space=memory_space, interpret=interpret) - - def when(condition): def _wrapped(f): if isinstance(condition, bool): diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 4a5cfbf3517f..3d43e2c0ab61 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -33,7 +33,7 @@ from jax._src import tree_util from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import primitives as pallas_primitives import jax._src.pallas.utils as pallas_utils from jax._src.state import discharge as state_discharge @@ -175,7 +175,7 @@ def kernel( out_shape: object, *, scratch_shapes: pallas_core.ScratchShapeTree = (), - compiler_params: object | None = None, + compiler_params: pallas_core.CompilerParams | None = None, **mesh_kwargs: object, ): if unwrap_out := not isinstance(out_shape, (tuple, list)): @@ -195,24 +195,12 @@ def cmap_body(): mesh, compiler_params=compiler_params )(cmap_body) _, outs = state_discharge.run_state(stateful)( - (operands, empty_like(out_shape)) + (operands, pallas_helpers.empty_like(out_shape, backend="mosaic_gpu")) ) return outs[0] if unwrap_out else outs return wrapper -def empty_like(shape): - return pallas_call.pallas_call( - lambda *_: None, - out_shape=shape, - out_specs=jax.tree.map( - lambda _: pallas_core.BlockSpec(memory_space=GPUMemorySpace.GMEM), - shape, - ), - backend="mosaic_gpu", - )() - - def _is_known_divisible(value, divisor, fuel=10) -> bool: """Returns True if the value is statically known to be divisible by the divisor.""" if divisor == 1: From 88105e90e03dc52055a57f2d84628bb563a053e9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 19 May 2025 02:55:28 -0700 Subject: [PATCH 1217/1769] [pallas:mosaic_gpu] Added support 
for unrolling to `lax.fori_loop` lowering We currently require that `unroll` divides `length` for simplicity. This restriction can be lifted later if/when necessary. PiperOrigin-RevId: 760540276 --- jax/_src/pallas/mosaic_gpu/lowering.py | 60 ++++++++++++++++---------- tests/pallas/mosaic_gpu_test.py | 18 +++++++- 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b611ea4c17f9..73623de43e39 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2370,11 +2370,11 @@ def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: ir.Value, - length: ir.Value | int, + length: int | ir.Value, consts, *args, has_loop_index: bool, - unroll: bool = False, + unroll: int | None = None, ): _consts_avals, arg_avals = util.split_list(ctx.avals_in, [len(consts)]) arg_avals = arg_avals[has_loop_index:] @@ -2395,22 +2395,42 @@ def as_values(vals, avals): return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] def loop(loop_index, body_args): - if has_loop_index: - loop_index = arith_dialect.addi(loop_index, start) - jaxpr_args = [*consts, loop_index, *body_args] - else: - jaxpr_args = [*consts, *body_args] - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args - ) + outs = body_args + if unroll is not None: + loop_index = arith_dialect.muli( + loop_index, _ir_constant(unroll, start.type) + ) + loop_index = arith_dialect.addi(loop_index, start) + for step in range(unroll or 1): + if has_loop_index: + loop_index = arith_dialect.addi( + loop_index, _ir_constant(step, start.type) + ) + jaxpr_args = [*consts, loop_index, *outs] + else: + jaxpr_args = [*consts, *outs] + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args + ) return as_values(outs, out_avals) - if unroll: - assert isinstance(length, int) - outs = as_values(args, arg_avals) - for i in range(length): - outs = loop(_ir_constant(i, start.type), outs) - return outs + if unroll is not None: + if not isinstance(length, int): + raise NotImplementedError( + "``length`` must be an integer when ``unroll` is specified, got" + f" {length}" + ) + if length % unroll: + # TODO(slebedev): Emit an epilogue taking care of the remaining steps. + raise NotImplementedError( + f"``unroll`` must divide ``length``, got {unroll=} and {length=}" + ) + if unroll == length: + # Special-case: the loop is fully unrolled. + return loop(_ir_constant(0, start.type), as_values(args, arg_avals)) + return mgpu.fori( + _ir_constant(length // unroll, start.type), as_values(args, arg_avals) + )(loop).results else: if not isinstance(length, ir.Value): length = _ir_constant(length, start.type) @@ -2432,11 +2452,7 @@ def _scan_lowering_rule( _split_transpose: bool, ): # Can only handle fori_loop-like scans. 
- if ( - (num_extensive := len(args) - num_consts - num_carry) - or reverse - or not (unroll == 1 or unroll == length) - ): + if (num_extensive := len(args) - num_consts - num_carry) or reverse: raise NotImplementedError del linear, num_extensive, reverse @@ -2465,7 +2481,7 @@ def _scan_lowering_rule( consts, *args, has_loop_index=has_loop_index, - unroll=unroll == length, + unroll=unroll, ) if has_loop_index: # Need to return the final loop index value if the outer scan expects diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index dc35f03843f7..d71593dc9078 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1147,11 +1147,27 @@ def test_fori_loop_array(self, force_while): ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. - o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...]) + o_ref[...] = _fori_loop( + force_while, 2, 4, lambda i, x: x + i, x_ref[...] + ) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), x + 2 + 3) + @parameterized.product(unroll=[1, 2]) + def test_fori_loop_array_unrolled(self, unroll): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) + ) + def kernel(x_ref, o_ref): + # Equivalent to x_ref[...] + 2 + 3 + 4 + 5. + o_ref[...] = lax.fori_loop( + 2, 6, lambda i, x: x + i, x_ref[...], unroll=unroll + ) + + x = jnp.arange(256, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x + 2 + 3 + 4 + 5) + @parameterized.product(force_while=[False, True]) def test_fori_loop_scalar(self, force_while): @functools.partial( From 097e755b22400f1d6ea633610b70e03d311d2e09 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 May 2025 03:38:08 -0700 Subject: [PATCH 1218/1769] [Mosaic GPU] Add support for fp8 types in WGMMA PiperOrigin-RevId: 760551961 --- jax/experimental/mosaic/gpu/tcgen05.py | 2 +- jax/experimental/mosaic/gpu/wgmma.py | 25 ++++++++++++++++++------- tests/mosaic/gpu_test.py | 26 +++++++++++++++++++++----- 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 89a58e4788f5..c4f670527a0c 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -195,7 +195,7 @@ def mma( f" type f32 or f16, but got: {d.dtype}" ) else: - raise NotImplementedError(f"Unsupported element type: {element_type}", type(element_type)) + raise NotImplementedError(f"Unsupported element type: {element_type}") # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, # instructions must be issued in groups of the same width as the swizzle. 
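For context on this patch, the wgmma.py diff that follows derives the PTX
mnemonic from the input element type instead of hard-coding 16-bit inputs.
A minimal illustrative sketch of the resulting naming scheme (the helper
below is hypothetical and not part of the patch; the format string and the
`k_instr = 32 // bytewidth` rule are taken from the diff):

def _wgmma_mnemonic_sketch(el_ty: str, n: int, out_ty: str = "f32") -> str:
    # Byte widths of the supported input element types.
    bytewidth = {"f16": 2, "bf16": 2, "tf32": 4, "e5m2": 1, "e4m3": 1}[el_ty]
    k_instr = 32 // bytewidth  # k16 for 16-bit inputs, k32 for the fp8 types
    return (f"wgmma.mma_async.sync.aligned."
            f"m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty}")

# _wgmma_mnemonic_sketch("e5m2", 128)
#   -> 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2'
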
diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 8baa16d8a7e9..3637778c371b 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -85,10 +85,11 @@ def tree_unflatten(cls, aux, value): def _supported_wgmma_types(dtype, abtype) -> bool: input_types_are = lambda ty: ty.isinstance(abtype) + f16_acc_types = (ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType) if ir.F32Type.isinstance(dtype): - return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type)) + return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, *f16_acc_types)) elif ir.F16Type.isinstance(dtype): - return input_types_are(ir.F16Type) + return any(input_types_are(ty) for ty in f16_acc_types) else: return False @@ -187,8 +188,12 @@ def take_regs(n): b_desc_reg, use_out_reg = take_regs(2) imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...). assert next(reg_count) == len(reg_constraints_list) - el_ty = element_type k_instr = 32 // bytewidth(element_type) + el_ty = str(element_type) + if ir.Float8E5M2Type.isinstance(element_type): + el_ty = "e5m2" + elif ir.Float8E4M3FNType.isinstance(element_type): + el_ty = "e4m3" wgmma_instr = ( f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} " f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};" @@ -291,18 +296,24 @@ def wgmma( f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}" ) f32 = ir.F32Type.get() + f16 = ir.F16Type.get() if element_type == f32 or element_type == ir.BF16Type.get(): if acc.value.mlir_dtype != f32: raise ValueError( f"WGMMA with element type {element_type} only supports accumulators" f" of type f32, but got: {acc.value.mlir_dtype}" ) - elif element_type == ir.F16Type.get(): - if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32: + elif any( + t.isinstance(element_type) + for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType} + ): + if acc.value.mlir_dtype != f16 and acc.value.mlir_dtype != f32: raise ValueError( - "WGMMA with element type f16 only supports accumulators of type f32" - f" or f16, but got: {acc.value.mlir_dtype}" + f"WGMMA with element type {element_type} only supports accumulators " + f"of type f32 or f16, but got: {acc.value.mlir_dtype}" ) + else: + raise NotImplementedError(f"Unsupported element type: {element_type}") # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, # instructions must be issued in groups of the same width as the swizzle. 
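For reference, a condensed sketch of the accumulator rules that the updated
`_supported_wgmma_types` above encodes (using the MLIR type classes from the
diff; this is a summary for readability, not the actual validation code):

from jax._src.lib.mlir import ir

def _acc_dtype_ok(in_ty, acc_ty) -> bool:
    # f16 and the fp8 inputs may accumulate in f16 or f32; tf32, f32 and
    # bf16 inputs require an f32 accumulator.
    f16_acc_inputs = (ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType)
    if any(t.isinstance(in_ty) for t in f16_acc_inputs):
        return ir.F32Type.isinstance(acc_ty) or ir.F16Type.isinstance(acc_ty)
    return ir.F32Type.isinstance(acc_ty)
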
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 62f377f031a0..08b424731f19 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -653,7 +653,13 @@ def setUp(self): @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type), + in_mlir_dtype_cls=( + ir.F16Type, + ir.BF16Type, + ir.F32Type, + ir.Float8E5M2Type, + ir.Float8E4M3FNType, + ), m=(64, 128, 192), n=(64, 128, 192), k_steps=(1, 2), @@ -675,8 +681,8 @@ def test_wgmma_basic( rhs_tiling_kind, lhs_tiling_kind, ): - if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: - self.skipTest("Only f16 input is supported for f16 output.") + if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls in {ir.F32Type, ir.BF16Type}: + self.skipTest(f"{in_mlir_dtype_cls.get()} does not support f16 output.") if swizzle != 128 and lhs_transpose and lhs_tiling_kind == "large": self.skipTest("Transpose only supported in 128B swizzled WGMMA") if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose: @@ -686,10 +692,10 @@ def test_wgmma_basic( in_mlir_dtype = in_mlir_dtype_cls.get() out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) + if (lhs_transpose or not rhs_transpose) and bytewidth(in_mlir_dtype) != 2: + self.skipTest("Transpose only supported in 16-bit WGMMA") if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead in_jax_dtype = jnp.float32 - if lhs_transpose or not rhs_transpose: - self.skipTest("Transpose only supported in 16-bit WGMMA") exponent_bits, mantissa_bits = 8, 10 # Use tf32 elif bytewidth(in_mlir_dtype) == 2: if n % 64 != 0: @@ -702,10 +708,18 @@ def test_wgmma_basic( exponent_bits, mantissa_bits = 8, 7 else: raise NotImplementedError(in_mlir_dtype) + elif in_mlir_dtype_cls == ir.Float8E5M2Type: + in_jax_dtype = jnp.float8_e5m2 + exponent_bits, mantissa_bits = 5, 2 + elif in_mlir_dtype_cls == ir.Float8E4M3FNType: + in_jax_dtype = jnp.float8_e4m3fn + exponent_bits, mantissa_bits = 4, 3 else: raise NotImplementedError(in_mlir_dtype) nk_tile = swizzle // bytewidth(in_mlir_dtype) k = nk_tile * k_steps + if n % nk_tile: + self.skipTest("tiling does not divide N") assert m % 64 == 0 and n % nk_tile == 0 small_rhs_tile = rhs_tiling_kind != "large" @@ -781,6 +795,8 @@ def quantize(x): x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6 + if utils.bitwidth(in_mlir_dtype) == 8: + atol = 3e-2 np.testing.assert_allclose(z, ref, atol=atol) # TODO(apaszke): Add support for f32 From 849e20c29e504650619bb827c1aed1c10f4d4a6f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 May 2025 05:04:35 -0700 Subject: [PATCH 1219/1769] [Pallas:MGPU] Fix a test that assumes x32 mode by specifying dtype explicitly PiperOrigin-RevId: 760576351 --- tests/pallas/gpu_pallas_distributed_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index 882475406b28..3da39ba925c7 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -62,7 +62,7 @@ def kernel(x_ref, y_ref, ready_sem, recv_sem): device_id_type=pl.DeviceIdType.LOGICAL) pl.semaphore_wait(recv_sem) - x = jnp.arange(2 * 8 * 128.0).reshape((2 * 8, 128)) + x = jnp.arange(2 * 8 * 128.0, dtype=jnp.float32).reshape((2 * 8, 128)) def body(x): return pl.pallas_call( 
kernel, From 8f7f3e10be5755b2b3296830ba024dbb381462d0 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 May 2025 05:14:21 -0700 Subject: [PATCH 1220/1769] [Mosaic GPU] Ignore singleton tiling dims when constructing nested shape They don't contribute anything, but they can trigger some NotImplemented errors down the line. PiperOrigin-RevId: 760579238 --- .../mosaic/gpu/fragmented_array.py | 2 ++ tests/mosaic/gpu_test.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 62c5903f475f..acbcf7c2cda1 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2166,10 +2166,12 @@ def transfer_tiled2( raise ValueError() nested_ref_shape = tuple( (ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank]) + if ref_ty.shape[i + ref_logical_rank] != 1 else (ref_ty.shape[i],) for i in range(ref_logical_rank) ) nested_ref_strides = tuple( (ref_strides[i], ref_strides[i + ref_logical_rank]) + if ref_ty.shape[i + ref_logical_rank] != 1 else (ref_strides[i],) for i in range(ref_logical_rank) ) tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 08b424731f19..39bd8aa77331 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1469,6 +1469,25 @@ def kernel(ctx, src, dst, smem): y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) + def test_tma_with_1d_tiling(self): + swizzle = 128 + dtype = jnp.float16 + shape = (64, 128) + tiling = (1, swizzle // jnp.dtype(dtype).itemsize) + def kernel(ctx, dst, smem): + iota_tensor(*shape, dtype=dtype).store_tiled(smem, swizzle=swizzle) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = jax.ShapeDtypeStruct(utils.tile_shape(shape, tiling), dtype) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), x, smem)() + np.testing.assert_array_equal(y, x) + @parameterized.named_parameters( ( f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", From 091b3ecefd4d512882cc3a7abad4ad60b3f0f5cb Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 19 May 2025 05:19:39 -0700 Subject: [PATCH 1221/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/821e786cf9e478c8d866610203a80acaa70539b9. PiperOrigin-RevId: 760580406 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e105491ae6af..d8f1faaad803 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "d4ff3a507f46f8102065b598456a99dc17758394" -XLA_SHA256 = "fd491423023f7e2cc6f4da75bc95331be3d4f6ba47edef1095f7be1279cd5c27" +XLA_COMMIT = "821e786cf9e478c8d866610203a80acaa70539b9" +XLA_SHA256 = "42d7cb180a65ea9e8589805941ce05612df987a5d00a98381fd548dc1dd31211" def repo(): tf_http_archive( From 0f5e952975a15435a4e71395c45f4f3998ab79d9 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 19 May 2025 05:50:34 -0700 Subject: [PATCH 1222/1769] #sdy fix `shard_map` lowering if there is only one device. This was wrong when the shmap contains callbacks, which cause tokens to be created. I was calling `jaxpr_subcomp` incorrectly. PiperOrigin-RevId: 760588882 --- jax/_src/shard_map.py | 10 +++------- tests/python_callback_test.py | 13 +++++++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index ac529a667d04..010773f74d0c 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -812,17 +812,13 @@ def _shard_map_lowering_shardy( if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): - args = (*ctx.dim_var_values, *tokens, *in_nodes) out_nodes, tokens_out = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, - mlir.TokenSet(zip(ctx.tokens_in.effects(), in_nodes[:num_tokens])), - (), *args[num_tokens:], + mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)), + (), *in_nodes, dim_var_values=ctx.dim_var_values) - num_tokens = len(tokens_out.effects()) - tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip( - ctx.tokens_in.effects(), out_nodes[:num_tokens]))) ctx.set_tokens_out(tokens_out) - return out_nodes[num_tokens:] + return out_nodes in_shardings = list(map( partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 9a3b26530044..26664faa6faf 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1363,9 +1363,18 @@ def f_base(i, x): jax.effects_barrier() self.assertEqual(_collected, expected) - def test_can_shard_io_callback_manually(self): + @parameterized.named_parameters( + dict(testcase_name='multi_device', + single_device=False), + dict(testcase_name='single_device', + single_device=True) + ) + def test_can_shard_io_callback_manually(self, single_device: bool): - mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) + devices = jax.devices() + if single_device: + devices = devices[:1] + mesh = Mesh(np.array(devices), axis_names=('x',)) spec = jax.sharding.PartitionSpec('x') sharding = jax.sharding.NamedSharding(mesh, spec) From 168f771c93cc8773fdc53e924defe8bae2b07c6d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 May 2025 05:52:21 -0700 Subject: [PATCH 1223/1769] Add missing test skips and improve compatibility with older jaxlib versions PiperOrigin-RevId: 760589385 --- jax/_src/distributed.py | 4 +++- tests/fused_attention_stablehlo_test.py | 3 +++ tests/multiprocess_gpu_test.py | 6 ++++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index dad445b8e539..ae1baf8052c0 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -159,7 +159,9 @@ def shutdown(self): if self.preemption_sync_manager: # It's important to shut down the preemption sync manager before the # client because the preemption sync manager depends on the client. 
- self.preemption_sync_manager.shutdown() + # TODO: Delete hasattr check once 0.6.1 is the minimum jaxlib version + if hasattr(self.preemption_sync_manager, "shutdown"): + self.preemption_sync_manager.shutdown() self.preemption_sync_manager = None if self.client: self.client.shutdown() diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 64e0f4377462..925fc2ed4825 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -503,6 +503,9 @@ def test_sdpa_broadcast_bias_and_dbias(self): ) @jtu.run_on_devices("cuda") def test_sdpa_dbias(self, batch_size: int): + # TODO: Delete once 0.6.0 is no longer supported. + if jtu.jaxlib_version() == (0, 6, 0): + self.skipTest("jaxlib 0.6.0 has a bug") if jax.device_count() < 4: self.skipTest("Requires more than 4 devices.") # cuDNN only supports dbias when batch size is 1. If the batch size is diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index 20a2b9ba972b..c2ec44916745 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -82,9 +82,11 @@ def test_gpu_distributed_initialize(self): try: for proc in subprocesses: - out, _ = proc.communicate() + out, err = proc.communicate() self.assertEqual(proc.returncode, 0) - self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') + self.assertEqual( + out, f"{num_gpus_per_task},{num_gpus}", msg=f"Process failed:\n\n{err}", + ) finally: for proc in subprocesses: proc.kill() From b2418681d7b55dc62bbc7a1646ea3d7b1d1da2ee Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 May 2025 06:31:23 -0700 Subject: [PATCH 1224/1769] [Mosaic GPU] A handful of minor bug fixes PiperOrigin-RevId: 760601235 --- jax/_src/pallas/mosaic_gpu/core.py | 4 ++-- jax/experimental/mosaic/gpu/fragmented_array.py | 2 +- jax/experimental/mosaic/gpu/launch_context.py | 12 +++++++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 3d43e2c0ab61..d13977ac4fbf 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -627,7 +627,7 @@ def remote_ref( ) -> pallas_core.TransformedRef: """Translate memref to a symmetric memref on a peer device.""" if not isinstance(ref, pallas_core.TransformedRef): - if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): + if not isinstance(jax_core.get_aval(ref), state_types.AbstractRef): raise TypeError("ref must be a reference") ref = pallas_core.TransformedRef(ref, transforms=()) return pallas_core.TransformedRef( @@ -640,7 +640,7 @@ def transform_ref( transform: state_types.Transform ) -> pallas_core.TransformedRef: if not isinstance(ref, pallas_core.TransformedRef): - if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): + if not isinstance(jax_core.get_aval(ref), state_types.AbstractRef): raise TypeError("ref must be a reference") ref = pallas_core.TransformedRef(ref, transforms=()) return pallas_core.TransformedRef( diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index acbcf7c2cda1..62b43b4de737 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1993,7 +1993,7 @@ def _store_untiled_wg_strided(self, ref: ir.Value): idxs = ([i] for i in self.layout.linear_thread_idxs()) except NotImplementedError: ref_ = ref - idxs = self.layout.thread_idxs() + idxs = 
self.layout.thread_idxs(self.shape) ref_shape = tuple(ref_ty.shape) if ref_shape != self.shape: raise ValueError((ref_shape, self.shape)) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index eccb363f7537..175dc8b0ac74 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -878,8 +878,18 @@ def _ensure_nvshmem_decls(self): def to_remote(self, ref: ir.Value, peer: ir.Value): self._ensure_nvshmem_decls() if ir.MemRefType.isinstance(ref.type): + # We replace the offset in the ref type by 0, because memref_ptr always + # folds the offset into the pointer. + ref_ty = ir.MemRefType(ref.type) + strides, _ = ref_ty.get_strides_and_offset() + result_type = ir.MemRefType.get( + ref_ty.shape, + ref_ty.element_type, + ir.StridedLayoutAttr.get(0, strides), + ref_ty.memory_space, + ) return utils.ptr_as_memref( - self.to_remote(utils.memref_ptr(ref), peer), ref.type + self.to_remote(utils.memref_ptr(ref), peer), result_type ) if ref.type != ir.Type.parse("!llvm.ptr"): raise ValueError(f"Unsupported type for to_remote: {ref.type}") From 169ad4ae521872a85d5c5a93102f3c31663dd497 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 19 May 2025 06:40:26 -0700 Subject: [PATCH 1225/1769] Add version guards on unreduced lowering to shardy PiperOrigin-RevId: 760603883 --- jax/_src/named_sharding.py | 17 ++++++++++++----- jaxlib/xla_client.py | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index faf0b2a9f2b2..ae99236a6cdc 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -23,6 +23,7 @@ from jax._src import config from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType @@ -316,11 +317,17 @@ def build(self) -> sdy.TensorShardingAttr: replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape) unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape) - return sdy.TensorShardingAttr.get( - mesh_attr, - [dim_sharding.build() for dim_sharding in self.dim_shardings], - replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], - unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) + if jaxlib_extension_version >= 342: + return sdy.TensorShardingAttr.get( + mesh_attr, + [dim_sharding.build() for dim_sharding in self.dim_shardings], + replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], + unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) + else: + return sdy.TensorShardingAttr.get( + mesh_attr, + [dim_sharding.build() for dim_sharding in self.dim_shardings], + replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes]) def __repr__(self): dim_sharding_repr = ', '.join( diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 69e168de9c2d..8f8c829ee6c7 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 341 +_version = 342 # An internal increasing version number for protecting jaxlib code against # ifrt changes. 
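The branch above exists because older jaxlib releases ship sdy bindings whose
`TensorShardingAttr.get` presumably does not accept the new `unreduced_axes`
keyword, so passing it unconditionally would raise on those versions; the
`_version` bump in xla_client.py is what lets `jaxlib_extension_version` tell
the two apart. A condensed sketch of the guard (names as in the diff,
simplified):

if jaxlib_extension_version >= 342:  # this jaxlib understands unreduced_axes
    attr = sdy.TensorShardingAttr.get(
        mesh_attr, dim_shardings,
        replicated_axes=replicated, unreduced_axes=unreduced)
else:                                # older jaxlib: omit the new keyword
    attr = sdy.TensorShardingAttr.get(
        mesh_attr, dim_shardings, replicated_axes=replicated)
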
From d62c10f1b5075b3aa5b3327c71ed43d133fc975c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 19 May 2025 06:44:11 -0700 Subject: [PATCH 1226/1769] Disable cusolver version check. This appears to be failing for some reason, and it's safest just to disable it for now. --- jax_plugins/cuda/__init__.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 9df7fc69ff1a..02bcbcf16dbc 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -180,11 +180,14 @@ def _version_check(name: str, cuda_versions.cufft_build_version, # Ignore patch versions. scale_for_comparison=100) - _version_check("cuSOLVER", cuda_versions.cusolver_get_version, - cuda_versions.cusolver_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=11400) + # TODO(phawkins): for some reason this check fails with a cusolver internal + # error when fetching the version. This may be a path error from our stubs. + # Figure out what's happening here and reenable. + # _version_check("cuSOLVER", cuda_versions.cusolver_get_version, + # cuda_versions.cusolver_build_version, + # # Ignore patch versions. + # scale_for_comparison=100, + # min_supported_version=11400) _version_check("cuPTI", cuda_versions.cupti_get_version, cuda_versions.cupti_build_version, min_supported_version=18) From 18ff6caa4f767701dd7cca3a1333d9b99465e045 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 19 May 2025 14:07:28 +0000 Subject: [PATCH 1227/1769] Bump NumPy version to 2.2.6 for 3.13+ Python versions. Remove workaround for NumPy 2.2 bug on aarch64. Also add a linux system constraint to GPU wheels. --- .bazelrc | 8 -- build/freethreading-requirements.txt | 4 +- build/nonfreethreading-requirements.txt | 2 +- build/requirements.in | 4 +- build/requirements_lock_3_10.txt | 4 +- build/requirements_lock_3_11.txt | 4 +- build/requirements_lock_3_12.txt | 4 +- build/requirements_lock_3_13.txt | 4 +- build/requirements_lock_3_13_ft.txt | 116 ++++++++++++------------ build/requirements_lock_3_14.txt | 2 +- 10 files changed, 72 insertions(+), 80 deletions(-) diff --git a/.bazelrc b/.bazelrc index 79df03863b02..53676637c839 100644 --- a/.bazelrc +++ b/.bazelrc @@ -244,10 +244,6 @@ build:ci_linux_aarch64_base --config=clang --verbose_failures=true build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" build:ci_linux_aarch64_base --color=yes -# Workaround for https://github.com/numpy/numpy/issues/28843 -# TODO(phawkins): remove this after upgrading to NumPy 2.2.6. -build:ci_linux_aarch64_base --test_env=OMP_NUM_THREADS=8 - build:ci_linux_aarch64 --config=ci_linux_aarch64_base build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" @@ -383,10 +379,6 @@ build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/ build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64 build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base -# Workaround for https://github.com/numpy/numpy/issues/28843 -# TODO(phawkins): remove this after upgrading to NumPy 2.2.6. 
-build:rbe_cross_compile_linux_aarch64 --test_env=OMP_NUM_THREADS=8 - # Mac x86 build:cross_compile_darwin_x86_64 --config=cross_compile_base build:cross_compile_darwin_x86_64 --config=nonccl diff --git a/build/freethreading-requirements.txt b/build/freethreading-requirements.txt index cc302cffdd0c..467578870ee9 100644 --- a/build/freethreading-requirements.txt +++ b/build/freethreading-requirements.txt @@ -1,3 +1,3 @@ # Under free-threading, we need an up-to-date numpy at least for the moment. -numpy~=2.2.5; python_version=="3.13" -numpy>=2.2.5; python_version>="3.14" +numpy~=2.2.6; python_version=="3.13" +numpy>=2.2.6; python_version>="3.14" diff --git a/build/nonfreethreading-requirements.txt b/build/nonfreethreading-requirements.txt index f8171559a142..8bd139bf99ac 100644 --- a/build/nonfreethreading-requirements.txt +++ b/build/nonfreethreading-requirements.txt @@ -1,6 +1,6 @@ numpy~=2.0.0; python_version<="3.12" numpy~=2.1.0; python_version=="3.13" -numpy>=2.2.5; python_version>="3.14" +numpy>=2.2.6; python_version>="3.14" # These packages have not released free-threaded wheels. zstandard diff --git a/build/requirements.in b/build/requirements.in index 8b8af9d6b591..c5ce2ea279bd 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -19,8 +19,8 @@ wheel jaxlib # The with-cuda extra also includes NVIDIA's pip packages. -jax-cuda12-plugin[with-cuda] -jax-cuda12-pjrt +jax-cuda12-plugin[with-cuda] ; sys_platform == "linux" +jax-cuda12-pjrt ; sys_platform == "linux" # TPU dependencies libtpu ; sys_platform == "linux" and platform_machine == "x86_64" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index c4ca6088e4bf..a4c6b1bf2b77 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -160,13 +160,13 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 \ +jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 \ +jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 1f667115af04..0633e733414b 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -154,13 +154,13 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 \ +jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 \ +jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ 
--hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 20ca67a3e921..1ab77a6ec36e 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -154,13 +154,13 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 \ +jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 \ +jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 804373b03899..c20068b732e6 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -181,13 +181,13 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 \ +jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 \ +jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index c7a1c882fc73..3795343df0cb 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -172,13 +172,13 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 \ +jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 \ +jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ @@ -371,62 +371,62 @@ mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r 
build/test-requirements.txt -numpy==2.2.5 ; python_version == "3.13" \ - --hash=sha256:0255732338c4fdd00996c0421884ea8a3651eea555c3a56b84892b66f696eb70 \ - --hash=sha256:02f226baeefa68f7d579e213d0f3493496397d8f1cff5e2b222af274c86a552a \ - --hash=sha256:059b51b658f4414fff78c6d7b1b4e18283ab5fa56d270ff212d5ba0c561846f4 \ - --hash=sha256:0bcb1d057b7571334139129b7f941588f69ce7c4ed15a9d6162b2ea54ded700c \ - --hash=sha256:0cd48122a6b7eab8f06404805b1bd5856200e3ed6f8a1b9a194f9d9054631beb \ - --hash=sha256:19f4718c9012e3baea91a7dba661dcab2451cda2550678dc30d53acb91a7290f \ - --hash=sha256:1a161c2c79ab30fe4501d5a2bbfe8b162490757cf90b7f05be8b80bc02f7bb8e \ - --hash=sha256:1f4a922da1729f4c40932b2af4fe84909c7a6e167e6e99f71838ce3a29f3fe26 \ - --hash=sha256:261a1ef047751bb02f29dfe337230b5882b54521ca121fc7f62668133cb119c9 \ - --hash=sha256:262d23f383170f99cd9191a7c85b9a50970fe9069b2f8ab5d786eca8a675d60b \ - --hash=sha256:2ba321813a00e508d5421104464510cc962a6f791aa2fca1c97b1e65027da80d \ - --hash=sha256:2c1a1c6ccce4022383583a6ded7bbcda22fc635eb4eb1e0a053336425ed36dfa \ - --hash=sha256:352d330048c055ea6db701130abc48a21bec690a8d38f8284e00fab256dc1376 \ - --hash=sha256:369e0d4647c17c9363244f3468f2227d557a74b6781cb62ce57cf3ef5cc7c610 \ - --hash=sha256:36ab5b23915887543441efd0417e6a3baa08634308894316f446027611b53bf1 \ - --hash=sha256:37e32e985f03c06206582a7323ef926b4e78bdaa6915095ef08070471865b906 \ - --hash=sha256:3a801fef99668f309b88640e28d261991bfad9617c27beda4a3aec4f217ea073 \ - --hash=sha256:3d14b17b9be5f9c9301f43d2e2a4886a33b53f4e6fdf9ca2f4cc60aeeee76372 \ - --hash=sha256:422cc684f17bc963da5f59a31530b3936f57c95a29743056ef7a7903a5dbdf88 \ - --hash=sha256:4520caa3807c1ceb005d125a75e715567806fed67e315cea619d5ec6e75a4191 \ - --hash=sha256:47834cde750d3c9f4e52c6ca28a7361859fcaf52695c7dc3cc1a720b8922683e \ - --hash=sha256:47f9ed103af0bc63182609044b0490747e03bd20a67e391192dde119bf43d52f \ - --hash=sha256:498815b96f67dc347e03b719ef49c772589fb74b8ee9ea2c37feae915ad6ebda \ - --hash=sha256:54088a5a147ab71a8e7fdfd8c3601972751ded0739c6b696ad9cb0343e21ab73 \ - --hash=sha256:55f09e00d4dccd76b179c0f18a44f041e5332fd0e022886ba1c0bbf3ea4a18d0 \ - --hash=sha256:5a0ac90e46fdb5649ab6369d1ab6104bfe5854ab19b645bf5cda0127a13034ae \ - --hash=sha256:6411f744f7f20081b1b4e7112e0f4c9c5b08f94b9f086e6f0adf3645f85d3a4d \ - --hash=sha256:6413d48a9be53e183eb06495d8e3b006ef8f87c324af68241bbe7a39e8ff54c3 \ - --hash=sha256:7451f92eddf8503c9b8aa4fe6aa7e87fd51a29c2cfc5f7dbd72efde6c65acf57 \ - --hash=sha256:8b4c0773b6ada798f51f0f8e30c054d32304ccc6e9c5d93d46cb26f3d385ab19 \ - --hash=sha256:8dfa94b6a4374e7851bbb6f35e6ded2120b752b063e6acdd3157e4d2bb922eba \ - --hash=sha256:97c8425d4e26437e65e1d189d22dff4a079b747ff9c2788057bfb8114ce1e133 \ - --hash=sha256:9d75f338f5f79ee23548b03d801d28a505198297534f62416391857ea0479571 \ - --hash=sha256:9de6832228f617c9ef45d948ec1cd8949c482238d68b2477e6f642c33a7b0a54 \ - --hash=sha256:a4cbdef3ddf777423060c6f81b5694bad2dc9675f110c4b2a60dc0181543fac7 \ - --hash=sha256:a9c0d994680cd991b1cb772e8b297340085466a6fe964bc9d4e80f5e2f43c291 \ - --hash=sha256:aa70fdbdc3b169d69e8c59e65c07a1c9351ceb438e627f0fdcd471015cd956be \ - --hash=sha256:abe38cd8381245a7f49967a6010e77dbf3680bd3627c0fe4362dd693b404c7f8 \ - --hash=sha256:b13f04968b46ad705f7c8a80122a42ae8f620536ea38cf4bdd374302926424dd \ - --hash=sha256:b4ea7e1cff6784e58fe281ce7e7f05036b3e1c89c6f922a6bfbc0a7e8768adbe \ - --hash=sha256:b6f91524d31b34f4a5fee24f5bc16dcd1491b668798b6d85585d836c1e633a6a \ - --hash=sha256:c26843fd58f65da9491165072da2cccc372530681de481ef670dcc8e27cfb066 
\ - --hash=sha256:c42365005c7a6c42436a54d28c43fe0e01ca11eb2ac3cefe796c25a5f98e5e9b \ - --hash=sha256:c8b82a55ef86a2d8e81b63da85e55f5537d2157165be1cb2ce7cfa57b6aef38b \ - --hash=sha256:ced69262a8278547e63409b2653b372bf4baff0870c57efa76c5703fd6543282 \ - --hash=sha256:d2e3bdadaba0e040d1e7ab39db73e0afe2c74ae277f5614dad53eadbecbbb169 \ - --hash=sha256:d403c84991b5ad291d3809bace5e85f4bbf44a04bdc9a88ed2bb1807b3360bb8 \ - --hash=sha256:d7543263084a85fbc09c704b515395398d31d6395518446237eac219eab9e55e \ - --hash=sha256:d8882a829fd779f0f43998e931c466802a77ca1ee0fe25a3abe50278616b1471 \ - --hash=sha256:e4f0b035d9d0ed519c813ee23e0a733db81ec37d2e9503afbb6e54ccfdee0fa7 \ - --hash=sha256:e8b025c351b9f0e8b5436cf28a07fa4ac0204d67b38f01433ac7f9b870fa38c6 \ - --hash=sha256:eb7fd5b184e5d277afa9ec0ad5e4eb562ecff541e7f60e69ee69c8d59e9aeaba \ - --hash=sha256:ec31367fd6a255dc8de4772bd1658c3e926d8e860a0b6e922b615e532d320ddc \ - --hash=sha256:ee461a4eaab4f165b68780a6a1af95fb23a29932be7569b9fab666c407969051 \ - --hash=sha256:f5045039100ed58fa817a6227a356240ea1b9a1bc141018864c306c1a16d4175 +numpy==2.2.6 ; python_version == "3.13" \ + --hash=sha256:038613e9fb8c72b0a41f025a7e4c3f0b7a1b5d768ece4796b674c8f3fe13efff \ + --hash=sha256:0678000bb9ac1475cd454c6b8c799206af8107e310843532b04d49649c717a47 \ + --hash=sha256:0811bb762109d9708cca4d0b13c4f67146e3c3b7cf8d34018c722adb2d957c84 \ + --hash=sha256:0b605b275d7bd0c640cad4e5d30fa701a8d59302e127e5f79138ad62762c3e3d \ + --hash=sha256:0bca768cd85ae743b2affdc762d617eddf3bcf8724435498a1e80132d04879e6 \ + --hash=sha256:1bc23a79bfabc5d056d106f9befb8d50c31ced2fbc70eedb8155aec74a45798f \ + --hash=sha256:287cc3162b6f01463ccd86be154f284d0893d2b3ed7292439ea97eafa8170e0b \ + --hash=sha256:37c0ca431f82cd5fa716eca9506aefcabc247fb27ba69c5062a6d3ade8cf8f49 \ + --hash=sha256:37e990a01ae6ec7fe7fa1c26c55ecb672dd98b19c3d0e1d1f326fa13cb38d163 \ + --hash=sha256:389d771b1623ec92636b0786bc4ae56abafad4a4c513d36a55dce14bd9ce8571 \ + --hash=sha256:3d70692235e759f260c3d837193090014aebdf026dfd167834bcba43e30c2a42 \ + --hash=sha256:41c5a21f4a04fa86436124d388f6ed60a9343a6f767fced1a8a71c3fbca038ff \ + --hash=sha256:481b49095335f8eed42e39e8041327c05b0f6f4780488f61286ed3c01368d491 \ + --hash=sha256:4eeaae00d789f66c7a25ac5f34b71a7035bb474e679f410e5e1a94deb24cf2d4 \ + --hash=sha256:55a4d33fa519660d69614a9fad433be87e5252f4b03850642f88993f7b2ca566 \ + --hash=sha256:5a6429d4be8ca66d889b7cf70f536a397dc45ba6faeb5f8c5427935d9592e9cf \ + --hash=sha256:5bd4fc3ac8926b3819797a7c0e2631eb889b4118a9898c84f585a54d475b7e40 \ + --hash=sha256:5beb72339d9d4fa36522fc63802f469b13cdbe4fdab4a288f0c441b74272ebfd \ + --hash=sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06 \ + --hash=sha256:71594f7c51a18e728451bb50cc60a3ce4e6538822731b2933209a1f3614e9282 \ + --hash=sha256:74d4531beb257d2c3f4b261bfb0fc09e0f9ebb8842d82a7b4209415896adc680 \ + --hash=sha256:7befc596a7dc9da8a337f79802ee8adb30a552a94f792b9c9d18c840055907db \ + --hash=sha256:894b3a42502226a1cac872f840030665f33326fc3dac8e57c607905773cdcde3 \ + --hash=sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90 \ + --hash=sha256:8e9ace4a37db23421249ed236fdcdd457d671e25146786dfc96835cd951aa7c1 \ + --hash=sha256:8fc377d995680230e83241d8a96def29f204b5782f371c532579b4f20607a289 \ + --hash=sha256:9551a499bf125c1d4f9e250377c1ee2eddd02e01eac6644c080162c0c51778ab \ + --hash=sha256:b0544343a702fa80c95ad5d3d608ea3599dd54d4632df855e4c8d24eb6ecfa1c \ + --hash=sha256:b093dd74e50a8cba3e873868d9e93a85b78e0daf2e98c6797566ad8044e8363d \ + 
--hash=sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb \ + --hash=sha256:b4f13750ce79751586ae2eb824ba7e1e8dba64784086c98cdbbcc6a42112ce0d \ + --hash=sha256:b64d8d4d17135e00c8e346e0a738deb17e754230d7e0810ac5012750bbd85a5a \ + --hash=sha256:ba10f8411898fc418a521833e014a77d3ca01c15b0c6cdcce6a0d2897e6dbbdf \ + --hash=sha256:bd48227a919f1bafbdda0583705e547892342c26fb127219d60a5c36882609d1 \ + --hash=sha256:c1f9540be57940698ed329904db803cf7a402f3fc200bfe599334c9bd84a40b2 \ + --hash=sha256:c820a93b0255bc360f53eca31a0e676fd1101f673dda8da93454a12e23fc5f7a \ + --hash=sha256:ce47521a4754c8f4593837384bd3424880629f718d87c5d44f8ed763edd63543 \ + --hash=sha256:d042d24c90c41b54fd506da306759e06e568864df8ec17ccc17e9e884634fd00 \ + --hash=sha256:de749064336d37e340f640b05f24e9e3dd678c57318c7289d222a8a2f543e90c \ + --hash=sha256:e1dda9c7e08dc141e0247a5b8f49cf05984955246a327d4c48bda16821947b2f \ + --hash=sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd \ + --hash=sha256:e3143e4451880bed956e706a3220b4e5cf6172ef05fcc397f6f36a550b1dd868 \ + --hash=sha256:e8213002e427c69c45a52bbd94163084025f533a55a59d6f9c5b820774ef3303 \ + --hash=sha256:efd28d4e9cd7d7a8d39074a4d44c63eda73401580c5c76acda2ce969e0a38e83 \ + --hash=sha256:f0fd6321b839904e15c46e0d257fdd101dd7f530fe03fd6359c1ea63738703f3 \ + --hash=sha256:f1372f041402e37e5e633e586f62aa53de2eac8d98cbfb822806ce4bbefcb74d \ + --hash=sha256:f2618db89be1b4e05f7a1a847a9c1c0abd63e63a1607d892dd54668dd92faf87 \ + --hash=sha256:f447e6acb680fd307f40d3da4852208af94afdfab89cf850986c3ca00562f4fa \ + --hash=sha256:f92729c95468a2f4f15e9bb94c432a9229d0d50de67304399627a943201baa2f \ + --hash=sha256:f9f1adb22318e121c5c69a09142811a201ef17ab257a1e66ca3025065b7f53ae \ + --hash=sha256:fc0c5673685c508a142ca65209b4e79ed6740a4ed6b2267dbba90f34b0b3cfda \ + --hash=sha256:fc7b73d02efb0e18c000e9ad8b83480dfcd5dfd11065997ed4c6747470ae8915 \ + --hash=sha256:fd83c01228a688733f1ded5201c678f0c53ecc1006ffbc404db9f7a899ac6249 \ + --hash=sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de \ + --hash=sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8 # via # -r build/freethreading-requirements.txt # contourpy diff --git a/build/requirements_lock_3_14.txt b/build/requirements_lock_3_14.txt index 6edcd30ebe16..157dca5adbab 100644 --- a/build/requirements_lock_3_14.txt +++ b/build/requirements_lock_3_14.txt @@ -48,7 +48,7 @@ ml-dtypes==0.5.1 # tensorstore mpmath==1.4.0a4 # via -r build/test-requirements.txt -numpy==2.2.5 +numpy==2.2.6 # via # -r build/nonfreethreading-requirements.txt # contourpy From ec9e71e1e71e015bc6cdccf825a554da41f5368d Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 19 May 2025 07:28:21 -0700 Subject: [PATCH 1228/1769] [Mosaic GPU] Adding an optional causal masking to the manual pipelining example. Testing shows the runtime is about half the flops with causal masking vs without. 
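This is consistent with the shape of the mask: causal attention only keeps
scores for positions j <= i, i.e. the lower triangle of the QK logits, so
roughly half of the matmul FLOPs are skipped. A minimal sketch of the masking
semantics, matching the attention_reference change below (q_len and kv_len
are illustrative stand-ins for the real sequence sizes):

    # Row i may only attend to columns j <= i.
    mask = jnp.arange(q_len)[:, None] >= jnp.arange(kv_len)[None, :]
    logits = jnp.where(mask, logits, -jnp.inf)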
Tests pass if we revert to cuda 12.0 with `--//third_party/gpus/cuda:by_exception_only_cuda_version_override=12_0` PiperOrigin-RevId: 760616647 --- .../pallas/ops/gpu/attention_mgpu.py | 94 +++++++++++++++---- tests/pallas/BUILD | 2 +- tests/pallas/mgpu_attention_test.py | 22 ++++- 3 files changed, 95 insertions(+), 23 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 3256953cd332..a100aa96faba 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -33,6 +33,7 @@ class TuningConfig: block_kv: int max_concurrent_steps: int use_schedule_barrier: bool = True + causal: bool = False compute_wgs_bwd: int = 1 block_q_dkv: int | None = None @@ -84,6 +85,8 @@ def _attention_forward(q, k, v, config: TuningConfig, save_residuals: bool = Fal config.max_concurrent_steps, kv_seq_len // config.block_kv ) block_q, block_kv = config.block_q, config.block_kv + if kv_seq_len % block_kv: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") @@ -97,6 +100,12 @@ def perform_schedule_barrier(): plgpu.barrier_arrive(schedule_barrier) plgpu.barrier_wait(schedule_barrier) + if config.causal: + block_q_end = (lax.axis_index("q_seq") + 1) * (2 * block_q) + block_max_kv_steps = pl.cdiv(block_q_end, jnp.array(block_kv, jnp.int32)) + else: + block_max_kv_steps = kv_seq_len // block_kv + @pl.when(wg_idx < 2) def _compute_wg(): plgpu.set_max_registers(232, action="increase") @@ -104,6 +113,11 @@ def _compute_wg(): lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q + if config.causal: + kv_steps = pl.cdiv(q_seq_base + block_q, jnp.array(block_kv, jnp.int32)) + else: + kv_steps = block_max_kv_steps + plgpu.copy_gmem_to_smem( q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], qo_smem, @@ -121,12 +135,14 @@ def _compute_wg(): jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, ) - plgpu.barrier_wait(k_barriers.at[0]) + @pl.when(kv_steps > 0) + def _(): + plgpu.barrier_wait(k_barriers.at[0]) pl.when(wg_idx == 1)(perform_schedule_barrier) - def kv_loop(kv_step, carry): + def kv_loop(kv_step, carry, causal: bool = False): acc, m_i, l_i = carry - slot = lax.rem(kv_step, max_concurrent_steps) + slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) # QK def compute_qk(acc_ref): @@ -136,6 +152,12 @@ def compute_qk(acc_ref): qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32)) plgpu.barrier_arrive(k_consumed_barriers.at[slot]) + if causal: + q_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 0, layout=plgpu.Layout.WGMMA) + kv_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 1, layout=plgpu.Layout.WGMMA) + mask = (q_ids + q_seq_base) >= (kv_ids + kv_step * block_kv) + qk = jnp.where(mask, qk, -jnp.inf) + # Softmax # We keep m scaled by log2e to use FMA instructions when computing p. 
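      # (exp(x) == exp2(x * log2(e)) and exp2 is the native hardware
      # instruction, so keeping m pre-scaled by log2(e) lets the scale and
      # subtract feeding exp2 fuse into a single FMA.)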
log2e = math.log2(math.e) @@ -166,18 +188,35 @@ def compute_pv(acc_ref): plgpu.wgmma(acc_ref, p16, v_smem.at[slot]) wait_step = kv_step + 1 - wait_slot = lax.rem(wait_step, max_concurrent_steps) - @pl.when(wait_step < kv_seq_len // block_kv) + wait_slot = lax.rem(wait_step, jnp.array(max_concurrent_steps, kv_step.dtype)) + @pl.when(wait_step < kv_steps) def _wait(): plgpu.barrier_wait(k_barriers.at[wait_slot]) acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc)) plgpu.barrier_arrive(v_consumed_barriers.at[slot]) return acc, m_i, l_i - if kv_seq_len % block_kv: - raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") - acc, m_i, l_i = lax.fori_loop( - 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) - ) + + if not config.causal: + acc, m_i, l_i = lax.fori_loop(0, block_max_kv_steps, kv_loop, (acc, m_i, l_i)) + else: + def epilogue_kv_loop(kv_step, _): + # This loop makes sure that all the pipelined KV data is processed, even + # if one compute wg finishes early like with causal masking. + slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) + plgpu.barrier_arrive(k_consumed_barriers.at[slot]) + plgpu.barrier_arrive(v_consumed_barriers.at[slot]) + perform_schedule_barrier() + perform_schedule_barrier() + + causal_kv_loop = functools.partial(kv_loop, causal=True) + full_kv_steps = lax.div(q_seq_base, jnp.array(block_kv, jnp.int32)) + # With causal masking, the KV loop unrolling is split in 3 sections: + # 1. A fast path where no causal mask is needed. + acc, m_i, l_i = lax.fori_loop(0, full_kv_steps, kv_loop, (acc, m_i, l_i)) + # 2. Causal masking. + acc, m_i, l_i = lax.fori_loop(full_kv_steps, kv_steps, causal_kv_loop, (acc, m_i, l_i)) + # 3. Epilogue to flush the data pipeline. + lax.fori_loop(kv_steps, block_max_kv_steps, epilogue_kv_loop, None) pl.when(wg_idx == 0)(perform_schedule_barrier) # TODO(apaszke): Invert and multiply to avoid expensive divisions. 
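# (Context for the _memory_wg hunk below: the memory warpgroup keeps
# prefetching all block_max_kv_steps KV blocks, so a compute warpgroup whose
# causal work ends early runs epilogue_kv_loop to keep arriving on the
# consumed (and schedule) barriers; producer and consumer step counts stay
# matched and the pipeline drains instead of deadlocking.)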
@@ -208,13 +247,13 @@ def _memory_wg(): def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps - tma_slot = lax.rem(kv_step, max_concurrent_steps) + tma_slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) - lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) + lax.fori_loop(0, block_max_kv_steps - max_concurrent_steps, kv_loop, None) def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): compute_wgs = 2 @@ -291,6 +330,9 @@ def _attention_bwd(config: TuningConfig, save_residuals: bool, res, do): del save_residuals q, k, v, out, lse = res + if config.causal: + raise NotImplementedError("Causal attention not supported in the backwards pass yet.") + if not config.has_backward_blocks: raise ValueError("Need to specify backward blocks.") @@ -586,6 +628,8 @@ def compute_dk(acc_ref): @functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False): + if config.causal: + raise NotImplementedError("Causal attention is not supported with the pipeline emitter yet.") if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -762,15 +806,21 @@ def run_function(q, k, v, o, lse): return out -@functools.partial(jax.jit, static_argnames=["save_residuals"]) -def attention_reference(q, k, v, save_residuals=False): +@functools.partial(jax.jit, static_argnames=["causal", "save_residuals"]) +def attention_reference(q, k, v, causal=False, save_residuals=False): batch_size, q_seq_len, num_q_heads, head_dim = q.shape - num_kv_heads = k.shape[2] + kv_seq_len, num_kv_heads = k.shape[1], k.shape[2] q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) q_reshaped = q.reshape( batch_size, q_seq_len, num_kv_heads, num_q_heads // num_kv_heads, head_dim ) logits = jnp.einsum("bqHhc,bkHc->bqHhk", q_reshaped, k) + + if causal: + mask = jnp.arange(q_seq_len)[:, None] >= jnp.arange(kv_seq_len)[None, :] + mask = jnp.broadcast_to(mask[:, None, None, :], logits.shape) + logits = jnp.where(mask, logits, -jnp.inf) + m = logits.max(axis=-1, keepdims=True) unnormalized = jnp.exp(logits - m) l = unnormalized.sum(axis=-1, keepdims=True) @@ -798,11 +848,13 @@ def main(unused_argv): schedule_barrier_opts = (True,) problem_it = itertools.product( - (1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts) - for batch_size, seq_len, head_dim, use_schedule_barrier in problem_it: + (1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts, (False, True)) + for batch_size, seq_len, head_dim, use_schedule_barrier, causal in problem_it: + if causal and use_pipeline_emitter: + continue q_seq_len = kv_seq_len = seq_len print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}" - f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} ====") + f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} {causal=:} ====") k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, 
head_dim), jnp.float16) @@ -810,11 +862,11 @@ def main(unused_argv): block_q = 64 best = None for block_kv in (256, 128, 64): - config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier) + config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier, causal=causal) try: out, runtime_ms = profiler.measure(functools.partial(attention_impl, config=config))(q, k, v) if seq_len < 32768: - out_ref = attention_reference(q, k, v) + out_ref = attention_reference(q, k, v, causal=causal) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) except ValueError as e: if "exceeds available shared memory" in e.args[0]: @@ -824,6 +876,8 @@ def main(unused_argv): matmul_flops = ( 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size ) + if causal: + matmul_flops //= 2 peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS optimal_time = matmul_flops / peak_flops * 1e6 # us achieved_tc_util = optimal_time / runtime_us * 100 diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 86c3fe187b79..6690fb2dac62 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -776,7 +776,7 @@ jax_multiplatform_test( "gpu_h100", ], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, - shard_count = 4, + shard_count = 8, deps = [ "//jax:pallas", "//jax:pallas_experimental_gpu_ops", diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index 50f9a455c9a2..f86793174c16 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -21,6 +21,7 @@ from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import cuda_versions from jax._src.pallas import pallas_call import jax.numpy as jnp @@ -63,11 +64,13 @@ def setUp(self): (4, 4), ), # MHA head_dim=(64, 128, 256), + blocks=((64, 64), (64, 128), (128, 64)), attention_impl=( attention_mgpu.attention, attention_mgpu.attention_with_pipeline_emitter, ), save_residuals=(True,), + causal=(True, False,), ) def test_flash_attention( self, @@ -76,10 +79,24 @@ def test_flash_attention( kv_seq_len, num_q_and_kv_heads, head_dim, + blocks, attention_impl, save_residuals, + causal, ): + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. 
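+    # (i.e. CUDA runtimes in [12.8.0, 12.9.1), where ptxas miscompiles the
+    # causal kernel.)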
+ if causal and (cuda_runtime_version >= 12080 and cuda_runtime_version < 12091): + self.skipTest("Skipping because of ptxas miscompilation.") + + if causal and attention_impl == attention_mgpu.attention_with_pipeline_emitter: + self.skipTest("Pipeline emitter does not support causal attention.") + + if head_dim >= 256 and max(blocks) >= 128: + self.skipTest("Head dim too large for block sizes.") + num_q_heads, num_kv_heads = num_q_and_kv_heads + block_q, block_kv = blocks k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) @@ -89,11 +106,12 @@ def test_flash_attention( k, v, attention_mgpu.TuningConfig( - block_q=64, block_kv=64, max_concurrent_steps=2 + block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, causal=causal ), save_residuals=save_residuals, ) - out_ref, *res_ref = attention_mgpu.attention_reference(q, k, v, save_residuals=save_residuals) + out_ref, *res_ref = attention_mgpu.attention_reference( + q, k, v, causal=causal, save_residuals=save_residuals) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) if save_residuals: (lse,) = res[0] From ea5a47d44a69f782d153fbecd5cf6999af77caec Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 May 2025 07:40:16 -0700 Subject: [PATCH 1229/1769] Fix the CAS loop in semaphore_wait lowering The loop incorrectly assumed that we're waiting for the semaphore value to be equal to the wait value, so that we can reset it to 0. This, however, is not how semaphores work. The CAS loop should wait until the value is _at least_ the wait value, and should update its expectation at every step in case the swap failed. The attached test deadlocks with the original implementation, but works fine with the new one. PiperOrigin-RevId: 760619094 --- jax/_src/pallas/mosaic_gpu/lowering.py | 19 ++++++++----- tests/pallas/gpu_pallas_distributed_test.py | 30 +++++++++++++++++++-- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 73623de43e39..bd304e8b6745 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2999,6 +2999,11 @@ def _semaphore_signal_lowering_rule( ) # TODO(apaszke): Narrow the scope from .sys to .gpu when the semaphore is local. val = _ir_constant(value, i32) + # We only signal the semaphore from a single lane, which does not guarantee + # anything about the state of the other three warps in the warpgroup (they + # might still be e.g. reading memory that someone will overwrite once they + # receive a signal). 
+ mgpu.utils.warpgroup_barrier() pred = ctx.module_ctx.single_wg_lane_predicate llvm_dialect.inline_asm( i32, @@ -3022,23 +3027,25 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_ptr = mgpu.utils.memref_ptr(sem) i32_ty = ir.IntegerType.get_signless(32) ne_pred = arith_dialect.CmpIPredicate.ne - zero_const = mgpu.utils.c(0, i32_ty) val = _ir_constant(value, i32_ty) with mgpu.single_thread(scope=mgpu.ThreadSubset.WARPGROUP): # Create the while loop for busy waiting - while_op = scf_dialect.WhileOp([i32_ty], [zero_const]) + while_op = scf_dialect.WhileOp([i32_ty], [val]) before_block = while_op.before.blocks.append(i32_ty) with ir.InsertionPoint.at_block_begin(before_block): - old_val = llvm_dialect.inline_asm( + [expected_in_memory] = before_block.arguments + new_val = arith_dialect.subi(expected_in_memory, val) + in_memory = llvm_dialect.inline_asm( i32_ty, - [sem_ptr, val, zero_const], + [sem_ptr, expected_in_memory, new_val], "atom.acquire.sys.global.cas.b32 $0, [$1], $2, $3;", "=r,l,r,r", has_side_effects=True, ) - comparison = arith_dialect.cmpi(ne_pred, old_val, val) - scf_dialect.condition(comparison, before_block.arguments) + comparison = arith_dialect.cmpi(ne_pred, in_memory, expected_in_memory) + new_expected_in_memory = arith_dialect.maxui(in_memory, val) + scf_dialect.condition(comparison, [new_expected_in_memory]) after_block = while_op.after.blocks.append(i32_ty) with ir.InsertionPoint.at_block_begin(after_block): scf_dialect.yield_(after_block.arguments) diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index 3da39ba925c7..d862e6b9b819 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -44,8 +44,6 @@ def setUp(self): super().setUp() def test_basic_remote_dma(self): - if jax.process_count() < 2: - self.skipTest("Test requires multiple processes.") if jax.process_index() > 2: return # Only 2 processes needed. def kernel(x_ref, y_ref, ready_sem, recv_sem): @@ -86,6 +84,34 @@ def body(x): expected = x[8:] if jax.process_index() == 0 else x[:8] np.testing.assert_allclose(y.addressable_shards[0].data, expected) + def test_wait_twice(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, sem): + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 2, device_id=other_dev_id, + device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_wait(sem) + pl.semaphore_wait(sem) + y_ref[...] = jnp.ones_like(y_ref) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + shard_map.shard_map( + kernel_call, mesh, in_specs=(), out_specs=P(None), check_rep=False, + ) + )() + np.testing.assert_allclose(y, jnp.ones_like(y)) + if __name__ == '__main__': jt_multiprocess.main() From 08b63899693000da4b098984e080e4d994d6e622 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 19 May 2025 10:51:32 -0400 Subject: [PATCH 1230/1769] Fix scipy nightly CI failures caused by toeplitz. 
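For context, the "v1.17+ batching semantics" that the test helper
osp_linalg_toeplitz emulates on older SciPy releases can be sketched as
follows; this is an illustrative reimplementation (toeplitz_ref is a made-up
name), not SciPy's actual code:

    import numpy as np

    def toeplitz_ref(c, r=None):
        c = np.atleast_1d(c)
        r = np.conjugate(c) if r is None else np.atleast_1d(r)
        i = np.arange(c.shape[-1])[:, None]
        j = np.arange(r.shape[-1])[None, :]
        k = i - j  # k >= 0 reads from the first column c, k < 0 from row r
        return np.where(k >= 0, c[..., np.maximum(k, 0)],
                        r[..., np.maximum(-k, 0)])

Leading dimensions of c and r are treated as batch dimensions over output
matrices.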
--- tests/linalg_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index a9f81ec04560..cba3dbb7189d 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -70,7 +70,9 @@ def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]: def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: """scipy.linalg.toeplitz with v1.17+ batching semantics.""" - if scipy_version >= (1, 17, 0): + # TODO(dfm,jakevdp): Remove dev check after upstream PR is merged: + # https://github.com/scipy/scipy/issues/21466. + if scipy_version >= (1, 17, 0) and "dev0" not in scipy.version.version: return scipy.linalg.toeplitz(c, r) elif r is None: c = np.atleast_1d(c) From 3143214ff69c62523ed1a6baf47fd89dc40bfef9 Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Sat, 17 May 2025 17:43:12 +0000 Subject: [PATCH 1231/1769] shard map pbroadcast --- jax/_src/lax/parallel.py | 19 ++++++++++++++++--- tests/shard_map_test.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 6df8690f1123..a9abf8f12939 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1122,14 +1122,27 @@ def _pbroadcast_lowering(ctx, x, *, axis_name, source): def source_to_front(group): return [group[source]] + list(group[:source]) + list(group[source + 1:]) replica_groups = [source_to_front(group) for group in replica_groups] - channel = ctx.module_context.new_channel() + is_spmd = isinstance( + ctx.module_context.axis_context, + (SPMDAxisContext, ShardingContext), + ) + if is_spmd: + # We want to emit the collective-broadcast with global device IDs and a unique + # channel ID, as otherwise it interprets the devices as replicas instead + # of partitions - and XLA is configured with only a single replica. 
+ channel = ctx.module_context.new_channel() + channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE) + other_args = dict(channel_handle=channel_handle) + else: + other_args = {} return hlo.CollectiveBroadcastOp( - x, replica_groups=_replica_groups_hlo(replica_groups)).results + x, replica_groups=_replica_groups_hlo(replica_groups), **other_args + ).results pbroadcast_p = core.Primitive('pbroadcast') pbroadcast_p.def_abstract_eval(_raise_to_shaped_abstract_eval) ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule) -mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering) +mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering, platform='gpu') batching.fancy_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher batching.skippable_batchers[pbroadcast_p] = partial(_names_in_param, 'axis_name') diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 2fdc846a356b..1bebba095896 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -290,6 +290,39 @@ def fwd(a): c = fwd(a) assert (c == jnp.reshape(a.T, (1, 64))).all() + @parameterized.named_parameters( + dict( + testcase_name='_partial_replicated', replicate_on_axes='x', + ), + dict( + testcase_name='_fully_replicated', + replicate_on_axes=('x', 'y'), + ), + ) + @jtu.run_on_devices("gpu") + def test_pbroadcast(self, replicate_on_axes): + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + sharded_axes = set(mesh.axis_names) - set(replicate_on_axes) + sharded_axes = None if not sharded_axes else list(sharded_axes) + in_out_sharding = jax.sharding.NamedSharding(mesh, P(sharded_axes, None)) + a = jax.device_put(jnp.arange(16).reshape((4, 4)), in_out_sharding) + + @jax.jit + @partial( + shard_map, + mesh=mesh, + in_specs=(in_out_sharding.spec,), + out_specs=in_out_sharding.spec, + check_vma=False, + ) + def fwd(x): + axis_index = lax.axis_index(replicate_on_axes) + x = jnp.where(axis_index == 0, x + 1, x) + return lax.pbroadcast(x, replicate_on_axes, source=0) + + c = fwd(a) # Don't crash + self.assertAllClose(c, a + 1) + def test_all_to_all_with_axis_index_groups(self): mesh = jtu.create_mesh((4,), ('x',)) a = jax.device_put( From ff11a56e338be89e2ae0654e084da90b44d8f43a Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Mon, 19 May 2025 08:51:30 -0700 Subject: [PATCH 1232/1769] [CI] Rollback move to gcloud cp for tpu job PiperOrigin-RevId: 760640560 --- .github/workflows/pytest_tpu.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 2d4d2925bd2f..55a0b4cc1a5f 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -114,11 +114,11 @@ jobs: continue-on-error: true run: | mkdir -p $(pwd)/dist - gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then echo "JAX only release. Only downloading the jax wheel from the release bucket." 
else - gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' From 62d59ace3087877347f606405bbcbc7196fe5201 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 19 May 2025 11:32:55 -0700 Subject: [PATCH 1233/1769] [Pallas] Fix 1D Iota PiperOrigin-RevId: 760705572 --- jax/_src/pallas/mosaic/lowering.py | 10 ++++++---- tests/pallas/ops_test.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 9af3cf1e3c0a..a67a71dd40f7 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2302,11 +2302,13 @@ def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, if len(shape) == 1: if dimension != 0: raise ValueError("Dimension must be 0 for 1D iota.") - def _1d_iota_helper(dtype, shape, dimension, sharding): - iota_2d = lax.iota_p.bind(dtype, (1,) + shape, dimension, sharding) + def _1d_iota_helper(): + iota_2d = lax.iota_p.bind(dtype=dtype, + shape=(1,) + shape, + dimension=1, + sharding=sharding) return iota_2d[0] - return lower_fun(_1d_iota_helper, multiple_results=False)( - ctx, dtype, shape, dimension, sharding) + return lower_fun(_1d_iota_helper, multiple_results=False)(ctx) out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 9bb6d31d15e1..3baa26e5efd7 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1554,7 +1554,7 @@ def kernel(x_ref, y_ref, o_ref): def test_iota(self, shape, dtype, dimension): self.skip_if_mosaic_gpu() - if jtu.test_device_matches(["tpu"]): + if jtu.test_device_matches(["tpu"]) and dtype != jnp.int32: self.skipTest("Only 32-bit integer iota supported") f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) From 00996c698b185f68b46521a3ba2ba958a8091246 Mon Sep 17 00:00:00 2001 From: Richard Levasseur Date: Mon, 19 May 2025 15:08:08 -0700 Subject: [PATCH 1234/1769] cleanup: load python rules from rules_python --- build/BUILD.bazel | 1 + jaxlib/mlir/BUILD.bazel | 1 + jaxlib/mlir/_mlir_libs/BUILD.bazel | 1 + 3 files changed, 3 insertions(+) diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 761cf02ad624..a3d347d9209a 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -15,6 +15,7 @@ load("@python//:defs.bzl", "compile_pip_requirements") load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") +load("@rules_python//python:py_library.bzl", "py_library") load("//jaxlib:jax.bzl", "all_py_deps") licenses(["notice"]) diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index c7231c557e78..3cc3003c8daa 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@rules_python//python:py_library.bzl", "py_library") load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs") package( diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 6e54c9be83f5..2f0736c43f11 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:py_library.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_windows", From ef1adc349693d75bdd9ac1de050bdcf3f2aa401b Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Mon, 19 May 2025 17:22:06 -0700 Subject: [PATCH 1235/1769] Strip leading zeros local version in all the necessary places, for wheel name. PiperOrigin-RevId: 760830953 --- build/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/build.py b/build/build.py index b65f7a49dd8f..d059251552eb 100755 --- a/build/build.py +++ b/build/build.py @@ -640,7 +640,7 @@ async def main(): # Strip leading zeros as they end up being stripped by setuptools, # which leads to a mismatch between expected and actual wheel names # https://peps.python.org/pep-0440/ - wheel_git_hash = option.split("=")[-1][:9].lstrip('0') + wheel_git_hash = option.split("=")[-1].lstrip('0')[:9] with open(".jax_configure.bazelrc", "w") as f: jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule) From a17810d1afd77add4fa7bc827d897346ca67a240 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Mon, 19 May 2025 19:52:02 -0700 Subject: [PATCH 1236/1769] Propagate use_shardy_partitioner to XlaCallModule op. PiperOrigin-RevId: 760876735 --- jax/experimental/jax2tf/README.md | 2 ++ jax/experimental/jax2tf/jax2tf.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index cb1c97bc7b7c..06cc5c86a109 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -1007,6 +1007,8 @@ We list here a history of the serialization version numbers: available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). This is the only supported version as of 27th of March, 2024. + * Version 10 propagate the `jax.config.use_shardy_partitioner` value to + XlaCallModule. ## Known issues diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 536bf1f201f0..3c34a26af982 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -944,6 +944,11 @@ def _convert_value(val, aval): if DisabledSafetyCheck.platform() in exported.disabled_safety_checks: call_module_attrs["platforms"] = () # No platform checking + if version >= 10: + call_module_attrs["use_shardy_partitioner"] = ( + config.use_shardy_partitioner.value + ) + if logging.vlog_is_on(3): # We already logged the MLIR module when we exported it. logging.vlog(3, "XlaCallModule %s", str(call_module_attrs)) From 30339b08e1f88d7e9808589b67fdd40550324c77 Mon Sep 17 00:00:00 2001 From: Jaswanth Sreeram Date: Mon, 19 May 2025 21:37:32 -0700 Subject: [PATCH 1237/1769] [Mosaic] Add support for currently unsupported reshapes for 32-bit datatypes with native tiling and adds tests for those cases. 
The cases supported are (k % 128 == 0 in the below): - (q, m, n, k) -> (q, m, n * k) - (p, q, m, n, k) -> (p, q * m * n * k) - (q, m, n, k) -> (q, m, 1, n * k) (in 2 steps, first to n*k then add unit dim) - (q, m, n, k) -> (q * m, n * k) - (q * m, n, k) -> (q, m, n * k) - (q * m, n * k) -> (q, m, n, k) - (q, m, n * k) -> (q * m, n, k) PiperOrigin-RevId: 760904758 --- .../tpu/transforms/apply_vector_layout.cc | 208 ++++++++++++---- .../tpu/transforms/infer_vector_layout.cc | 26 +- jaxlib/mosaic/dialect/tpu/util.cc | 12 + jaxlib/mosaic/dialect/tpu/util.h | 6 + tests/pallas/tpu_pallas_test.py | 225 ++++++++++++++++++ 5 files changed, 433 insertions(+), 44 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index ba1dfc95c66c..5ddff9d9ee53 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4301,6 +4301,43 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, return success(); } +// Copy one sublane from a vreg to another vreg. +// +// Arguments: +// src_vreg: The source vreg to copy a sublane from. +// src_sl_idx: The sublane index in src_vreg to copy from. +// dst_vreg: The base vreg to copy the sublane into. May be null. +// dst_sl_idx: The sublane index in the result. +// +// Returns: +// A new dst_vreg with the copied sublane. +Value copyOneSublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, + Value dst_vreg, int dst_sl_idx, + const std::array target_shape) { + src_vreg = builder.create( + src_vreg.getLoc(), src_vreg, + /*amount=*/(dst_sl_idx - src_sl_idx + target_shape[0]) % target_shape[0], + /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr); + if (dst_vreg) { + auto boundIdxConst = + std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc()); + const int bitwidth = + cast(src_vreg.getType()).getElementTypeBitWidth(); + CHECK_EQ(bitwidth, + cast(dst_vreg.getType()).getElementTypeBitWidth()); + const VectorType vmask_ty = + getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); + auto sublanes_mask = builder.create( + src_vreg.getLoc(), vmask_ty, + ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)}, + ValueRange{boundIdxConst(dst_sl_idx + 1), + boundIdxConst(target_shape[1])}); + src_vreg = builder.create(src_vreg.getLoc(), sublanes_mask, + src_vreg, dst_vreg); + } + return src_vreg; +} + LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -4397,6 +4434,132 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, dst_vregs_local.Reshape( layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape)); return dst_vregs_local; + } else if ( + // Lower shape_casts for 32-bit types where the minor dimension both + // before and after the shape cast is a multiple of 128. We allow + // folding or unfolding multiple number of minor dimensions and folding + // or unfolding some number of leading dimensions. 
For example (given + // k % 128 == 0 in the following): + // (q, m, n, k) -> (q, m, n * k) + // (p, q, m, n, k) -> (p, q * m * n * k) + // (q, m, n, k) -> (q, m, 1, n * k) (in 2 steps, first to fold n, k then + // to add the unit dimension) + // (q, m, n, k) -> (q * m, n * k) + // (q * m, n, k) -> (q, m, n * k) + // (q * m, n * k) -> (q, m, n, k) + // (q, m, n * k) -> (q * m, n, k) + dst_shape.size() > 1 && src_shape.size() > 1 && + (mlir::tpu::canFoldMinorDimsToSize(src_shape, dst_shape.back()) || + mlir::tpu::canFoldMinorDimsToSize(dst_shape, src_shape.back())) && + dst_shape.back() % ctx.target_shape[1] == 0 && + src_shape.back() % ctx.target_shape[1] == 0 && + layout_in.offsets() == LayoutOffsets{0, 0} && + layout_in.hasNativeTiling(ctx.target_shape) && + layout_in.bitwidth() == 32 && + layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && + layout_out == layout_in) { + auto target_sublanes = ctx.target_shape[0]; + auto target_lanes = ctx.target_shape[1]; + xla::Array dst_vregs( + layout_out.tileArrayShape(false, false, dst_shape, ctx.target_shape)); + + auto to_linear_index = [&](absl::Span indices, + absl::Span bounds) { + CHECK_EQ(indices.size(), bounds.size()); + int linear_index = 0; + int multiplier = 1; + for (int i = indices.size() - 1; i >= 0; --i) { + linear_index += multiplier * indices[i]; + multiplier *= bounds[i]; + } + return linear_index; + }; + auto from_linear_index = [&](int linear_index, + absl::Span bounds) { + SmallVector indices(bounds.size(), 0); + int64_t divisor = std::accumulate(bounds.begin(), bounds.end(), 1, + std::multiplies()); + CHECK_GT(divisor, 0); + int64_t remainder = linear_index % divisor; + for (int i = 0; i < bounds.size(); ++i) { + int64_t radix = bounds[i]; + CHECK_GT(radix, 0); + divisor /= radix; + CHECK_GT(divisor, 0); + indices[i] = remainder / divisor; + remainder = remainder % divisor; + } + return indices; + }; + // Gather sublanes from src_vregs via rotating and selecting each relevant + // sublane from the source, into the destination vreg. + // Args: + // * src_sublane_indices: the mixed-radix indices of the sublanes to + // gather in the order they should be gathered. + // * src_vregs: the vregs to gather from. + // Returns: + // * a vreg with the gathered sublanes. + auto gather_sublanes = [target_sublanes]( + RewriteContext &ctx, Operation &op, + SmallVector> + src_sublane_indices, + const xla::Array &src_vregs) { + ImplicitLocOpBuilder builder(op.getLoc(), &op); + Value dst_vreg = getZerosVector( + builder, cast(src_vregs.begin()->getType())); + for (int sublane_number = 0; + sublane_number < src_sublane_indices.size(); ++sublane_number) { + SmallVector src_vreg_index = + src_sublane_indices[sublane_number]; + src_vreg_index[src_vreg_index.size() - 2] /= target_sublanes; + Value src_vreg = src_vregs(src_vreg_index); + int sublane_within_src_vreg = + src_sublane_indices[sublane_number] + [src_sublane_indices[sublane_number].size() - + 2] % + target_sublanes; + dst_vreg = copyOneSublane(builder, src_vreg, sublane_within_src_vreg, + dst_vreg, sublane_number, ctx.target_shape); + } + return dst_vreg; + }; + SmallVector dst_shape_in_sublanes(dst_shape); + dst_shape_in_sublanes[dst_shape.size() - 1] = + dst_shape[dst_shape.size() - 1] / target_lanes; + SmallVector src_shape_in_sublanes(src_shape); + src_shape_in_sublanes[src_shape.size() - 1] = + src_shape[src_shape.size() - 1] / target_lanes; + // The algorithm operates on 1 destination vreg at a time: + // 1. 
For each destination vreg, compute the linear index of each sublane + // within it + // 2. Map the destination sublane linear index to a source sublane linear + // index + // 3. convert that to a mixed-radix index into the source shape + // 4. Gather from those source sublane indices. + SmallVector indices; + dst_vregs.Each([&](absl::Span dst_vreg_indices, + Value *dst_vreg) { + indices.assign(dst_vreg_indices.begin(), dst_vreg_indices.end()); + indices[indices.size() - 2] *= target_sublanes; + int sublane_offset = to_linear_index(indices, dst_shape_in_sublanes); + + // Only move non-padding sublanes to the destination vreg. + int num_non_padding_sublanes = std::min( + dst_shape_in_sublanes[dst_shape_in_sublanes.size() - 2] - + dst_vreg_indices[dst_vreg_indices.size() - 2] * target_sublanes, + target_sublanes); + CHECK_EQ(dst_shape.back() % target_lanes, 0); + int stride_in_sublanes = dst_shape.back() / target_lanes; + SmallVector> gathered_sublanes( + num_non_padding_sublanes); + for (int i = 0; i < gathered_sublanes.size(); ++i) { + gathered_sublanes[i] = + from_linear_index(sublane_offset, src_shape_in_sublanes); + sublane_offset += stride_in_sublanes; + } + *dst_vreg = gather_sublanes(ctx, op, gathered_sublanes, src_vregs); + }); + return dst_vregs; } else { return shape_cast_op.emitOpError( "Not implemented: Unsupported vector.shape_cast: ") @@ -5262,45 +5425,6 @@ xla::Array retileToReducedSublanes( return dst_vreg_array; } - -// Copy one sublane from a vreg to another vreg. -// -// Arguments: -// src_vreg: The source vreg to copy a sublane from. -// src_sl_idx: The sublane index in src_vreg to copy from. -// dst_vreg: The base vreg to copy the sublane into. May be null. -// dst_sl_idx: The sublane index in the result. -// -// Returns: -// A new dst_vreg with the copied sublane. 
-Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx,
-                       Value dst_vreg, int dst_sl_idx,
-                       const std::array<int64_t, 2> target_shape) {
-  src_vreg = builder.create<tpu::RotateOp>(
-      src_vreg.getLoc(), src_vreg,
-      /*amount=*/(dst_sl_idx - src_sl_idx + target_shape[0]) % target_shape[0],
-      /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr);
-  if (dst_vreg) {
-    auto boundIdxConst =
-        std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc());
-    const int bitwidth =
-        cast<VectorType>(src_vreg.getType()).getElementTypeBitWidth();
-    CHECK_EQ(bitwidth,
-             cast<VectorType>(dst_vreg.getType()).getElementTypeBitWidth());
-    const VectorType vmask_ty =
-        getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape);
-    auto sublanes_mask = builder.create<tpu::CreateMaskOp>(
-        src_vreg.getLoc(), vmask_ty,
-        ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)},
-        ValueRange{boundIdxConst(dst_sl_idx + 1),
-                   boundIdxConst(target_shape[1])});
-    src_vreg = builder.create<arith::SelectOp>(src_vreg.getLoc(), sublanes_mask,
-                                               src_vreg, dst_vreg);
-  }
-  return src_vreg;
-}
-
-
 void rotateVregs(OpBuilder &builder, xla::Array<Value> &vregs,
                  const int64_t amount, const int dimension) {
   if (amount != 0) {
@@ -6714,9 +6838,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
       for (int tile_idx = 0; tile_idx < tiles_per_vreg; ++tile_idx) {
         int tile_off = tile_idx * sublanes_per_tile;
         *tile =
-            copy_one_sublane(builder, vregs(src_idx),
-                             tile_off + src.offsets()[0].value_or(dst_sl_idx),
-                             *tile, tile_off + dst_sl_idx, target_shape);
+            copyOneSublane(builder, vregs(src_idx),
+                           tile_off + src.offsets()[0].value_or(dst_sl_idx),
+                           *tile, tile_off + dst_sl_idx, target_shape);
       }
     }
   });
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index f42cfb139a37..9c4a7b4c397d 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -60,7 +60,6 @@ using ImplicitDim = VectorLayout::ImplicitDim;
 
 static constexpr int kLayoutLog = 10;
 
-
 bool is_fully_replicated(const Layout &layout) {
   static LayoutOffsets replicated_offsets = {std::nullopt, std::nullopt};
   return layout.has_value() && layout->offsets() == replicated_offsets;
@@ -1520,7 +1519,30 @@ class VectorLayoutInferer {
           native_tiling, ImplicitDim::kNone));
       return success();
     }
-    op.emitOpError("unsupported shape cast");
+
+    // Shape cast (..., m, n, k * target_shape_[1]) -> (..., m, n * k *
+    // target_shape_[1]) for 32-bit types. We allow multiple major or minor
+    // dimensions to be folded or unfolded.
+    if (kNativeBitwidth == bitwidth && res_shape.size() >= 2 &&
+        src_shape.size() >= 2 && src_shape.back() % native_tiling[1] == 0 &&
+        res_shape.back() % native_tiling[1] == 0 &&
+        (mlir::tpu::canFoldMinorDimsToSize(src_shape, res_shape.back()) ||
+         mlir::tpu::canFoldMinorDimsToSize(res_shape, src_shape.back()))) {
+      // TODO(jsreeram): Add support for picking space-efficient tilings for
+      // small 2nd minor dim shapes.
+      // Example 1: (4, 2, 1024) -> (4, 2048) If we infer src and tgt layout to
+      // be (1, 128), it is a no-op because essentially we just shuffle the
+      // VREGs in the VREG array.
+      // Example 2: (4, 256) -> (1, 1024) is actually a sublane
+      // shuffle inside each vreg from [0, 1, 2, 3, 4, ..., 7] to [0, 4, 1, 5, ...]
+ setLayout(op, + VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, + ImplicitDim::kNone), + VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, + ImplicitDim::kNone)); + return success(); + } + op.emitOpError("infer-vector-layout: unsupported shape cast"); return failure(); } diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index b562f81ad534..02598bd16f9a 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -301,4 +301,16 @@ std::optional getIntConst(Value v) { return std::nullopt; } +bool canFoldMinorDimsToSize(ArrayRef shape, int64_t target_size) { + CHECK_GE(shape.size(), 2); + int64_t product = shape.back(); + for (int i = shape.size() - 2; i >= 1; --i) { + product *= shape[i]; + if (product >= target_size) { + break; + } + } + return product == target_size; +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index ac83d95b715e..2a7325ee7b24 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -284,6 +284,12 @@ inline arith::ConstantOp I32Const(int32_t value, ArrayRef shape, } std::optional getIntConst(Value v); + +// Returns true if the product of up to `shape.size() - 1` minor-most dimensions +// in `shape` equals `target_size`. The major-most dimension is not considered. +// Precondition: `shape` has at least 2 dimensions. +bool canFoldMinorDimsToSize(ArrayRef shape, int64_t target_size); + } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 83f21bca7fc1..aac249251e2b 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -3022,6 +3022,231 @@ def kernel(x_ref, out_ref): out, np.zeros((8, 8, 2, 128), dtype=jnp.float32) ) + # (q, m, n) -> (q, m * n) where n % 128 == 0 + @parameterized.parameters( + (32, 16, 512, jnp.float32), + (24, 1, 512, jnp.uint32), + (3, 3, 256, jnp.uint32), + (9, 15, 256, jnp.float32), + (3, 2, 256, jnp.float32), + ) + def test_reshape_two_minor_dims_to_R2(self, q, m, n, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1] * x_ref.shape[2] + ) + + x = np.arange(q * m * n, dtype=dtype).reshape(q, m, n) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m * n])) + + # (q, m, n, k) -> (q, m, n * k) where k % 128 == 0 + @parameterized.parameters( + (3, 8, 17, 512, jnp.float32), + (1, 8, 9, 256, jnp.float32), + (1, 8, 3, 256, jnp.uint32), + (10, 1, 4, 256, jnp.uint32), + (1, 2, 2, 256, jnp.float32), + ) + def test_reshape_two_minor_dims_to_R3(self, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] 
= x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + + # (p, q, m, n, k) -> (p, q * m * n * k) where k % 128 == 0 + @parameterized.parameters( + (5, 3, 8, 17, 512, jnp.float32), + (6, 1, 8, 9, 256, jnp.float32), + (16, 1, 8, 3, 256, jnp.uint32), + (3, 2, 1, 4, 256, jnp.uint32), + (1, 7, 2, 2, 256, jnp.float32), + ) + def test_reshape_four_minor_dims_to_R2(self, p, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], + x_ref.shape[1] * x_ref.shape[2] * x_ref.shape[3] * x_ref.shape[4], + ) + + x = np.arange(p * q * m * n * k, dtype=dtype).reshape(p, q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((p, q * m * n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([p, q * m * n * k])) + + # (q, m, n, k) -> (q, m, 1, n * k) where k % 128 == 0 + def test_reshape_two_minor_dims_preserve_rank(self): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = ( + x_ref[...] + .reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + .reshape( + x_ref.shape[0], 1, x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + ) + + q, m, n, k = 10, 1, 4, 256 + x = np.arange(q * m * n * k, dtype=jnp.float32).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, 1, n * k), jnp.float32), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, 1, n * k])) + + # (q, m, n, k) -> (q * m, n * k) where k % 128 == 0 + @parameterized.parameters( + (3, 8, 17, 512, jnp.float32), + (1, 8, 9, 256, jnp.float32), + (1, 8, 3, 256, jnp.uint32), + (10, 1, 4, 256, jnp.uint32), + (1, 2, 2, 256, jnp.float32), + ) + def test_reshape_fold_two_leading_dims_and_two_minor_dims_R4_to_R2( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0] * x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q * m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q * m, n * k])) + + # (q * m, n, k) -> (q, m, n * k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 3, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_unfold_leading_dim_and_fold_two_minor_dims_R3_to_R3( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] 
= x_ref[...].reshape( + q, + m, + x_ref.shape[1] * x_ref.shape[2], + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q * m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + + # (q * m, n * k) -> (q, m, n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 3, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_unfold_leading_and_minor_dims_R2_to_R4( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q * m, n * k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n, k])) + + # (q, m, n * k) -> (q * m, n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 8, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_fold_leading_dims_and_unfold_minor_dim( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q * m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n * k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q * m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q * m, n, k])) + + # (q, m, n, k) -> (q, m * n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 8, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_fold_middle_dims(self, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m * n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m * n, k])) + + # (q, m * n, k) -> (q, m, n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 8, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_unfold_middle_dims(self, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m * n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n, k])) + class MiscellaneousInterpretTest(MiscellaneousTest): INTERPRET: bool = True From ddbc64aa3cdc1106bcb23b469448dbbd4e953b08 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 19 May 2025 21:55:27 -0700 Subject: [PATCH 1238/1769] Automated Code Change PiperOrigin-RevId: 760909620 --- jaxlib/py_array.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 1222d410bad8..37932a2aed45 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -1655,9 +1655,9 @@ void PyArray_bf_releasebuffer(PyObject*, Py_buffer* buffer) { // Returns if shape has a major-to-minor layout. 
bool HasMajorToMinorLayout(const xla::Shape& shape) { if (shape.has_layout()) { - for (int i = 0; i < shape.layout().minor_to_major_size(); ++i) { + for (int i = 0; i < shape.layout().minor_to_major().size(); ++i) { if (shape.layout().minor_to_major(i) != - shape.layout().minor_to_major_size() - 1 - i) { + shape.layout().minor_to_major().size() - 1 - i) { return false; } } From fc786d7422812d48637d03a47baf2b5b3bf15738 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 19 May 2025 23:26:58 -0700 Subject: [PATCH 1239/1769] [Mosaic] Use native bf16 ops for tanh, exp and log on TPUv6+. Replace `needs_cast` condition during canonicalization with `need_elementwise_canonicalization`. PiperOrigin-RevId: 760934277 --- .../tpu/transforms/canonicalize_mosaic.cc | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 645e6d615722..368bfc596732 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -359,12 +359,7 @@ LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx, auto element_type = ty.getElementType(); // There's an annoying hodgepodge of elementwise ops that need to be // rewritten to f32 on later hardware. - // TODO(mvoz): Look into (1) what it would take to support these ops - // natively on later hardware, and (2) how to better organize this list. - bool needs_cast = ctx.hardware_generation <= 5 || isa(op) || - isa(op) || isa(op) || - isa(op); - if (needs_cast && element_type.isBF16()) { + if (element_type.isBF16()) { if (ctx.compatibility_mode) { auto target_f32 = builder.create(op.getLoc(), target_f32_ty, operand) @@ -918,21 +913,22 @@ const llvm::StringMap &rules() { return *rules; } -const llvm::StringMap &bf16_upcast_min_supported_versions() { +const llvm::StringMap &bf16_ops_min_supported_versions() { constexpr int kAlwaysUpcast = std::numeric_limits::max(); static const auto m = new llvm::StringMap{ {arith::DivFOp::getOperationName(), 4}, {arith::SelectOp::getOperationName(), 5}, {arith::CmpFOp::getOperationName(), 5}, - {arith::MulFOp::getOperationName(), kAlwaysUpcast}, - {arith::AddFOp::getOperationName(), kAlwaysUpcast}, - {arith::SubFOp::getOperationName(), kAlwaysUpcast}, - {arith::MaximumFOp::getOperationName(), kAlwaysUpcast}, - {arith::MinimumFOp::getOperationName(), kAlwaysUpcast}, + {arith::MulFOp::getOperationName(), 6}, + {arith::AddFOp::getOperationName(), 6}, + {arith::SubFOp::getOperationName(), 6}, + {arith::MaximumFOp::getOperationName(), 6}, + {arith::MinimumFOp::getOperationName(), 6}, {math::PowFOp::getOperationName(), kAlwaysUpcast}, - {math::TanhOp::getOperationName(), kAlwaysUpcast}, - {math::ExpOp::getOperationName(), kAlwaysUpcast}, - {math::LogOp::getOperationName(), kAlwaysUpcast}, + {math::TanhOp::getOperationName(), 6}, + {math::ExpOp::getOperationName(), 6}, + {math::Exp2Op::getOperationName(), 6}, + {math::LogOp::getOperationName(), 6}, }; return *m; } @@ -941,9 +937,8 @@ bool need_elementwise_canonicalization(const CanonicalizeContext &ctx, Operation &op) { // Only rewrite when the hardware generation is below the minimum supported // version. 
- auto it = - bf16_upcast_min_supported_versions().find(op.getName().getStringRef()); - if (it == bf16_upcast_min_supported_versions().end() || + auto it = bf16_ops_min_supported_versions().find(op.getName().getStringRef()); + if (it == bf16_ops_min_supported_versions().end() || ctx.hardware_generation >= it->second) { return false; } From 0842cc6f386a20aa20ed20691fb78a43f6c4a307 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 May 2025 00:44:45 -0700 Subject: [PATCH 1240/1769] Automated Code Change PiperOrigin-RevId: 760956866 --- jaxlib/BUILD | 1 + jaxlib/xla_compiler.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 7047ddc3edd6..363218f4a9f3 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1231,6 +1231,7 @@ cc_library( "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:pjrt_executable", "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/proto:compile_options_proto_cc", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index 272cd2d409a8..73007530c27b 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -61,6 +61,7 @@ limitations under the License. #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/proto/compile_options.pb.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_numpy.h" From 7025c2310b576030875327275c761cfb64c3720e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 20 May 2025 02:09:29 -0700 Subject: [PATCH 1241/1769] [pallas] Pulled `runtime_assert_enabled` from `pltpu` to `pl` PiperOrigin-RevId: 760983479 --- jax/_src/pallas/core.py | 17 ++++++++++++ jax/_src/pallas/mosaic/core.py | 16 ----------- jax/_src/pallas/mosaic/lowering.py | 35 +++++++++++------------- jax/_src/pallas/mosaic_gpu/lowering.py | 37 ++++++++++++++++++++++++++ jax/experimental/pallas/__init__.py | 2 ++ jax/experimental/pallas/tpu.py | 4 +-- tests/pallas/mosaic_gpu_test.py | 20 +++++++++++++- 7 files changed, 93 insertions(+), 38 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 709bb4640241..fe755d61a310 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -44,6 +44,23 @@ from jax._src.state.types import TransformedRef import jax.numpy as jnp +# TODO(slebedev): Rename to --jax_pallas_debug_assertions. +_ENABLE_RUNTIME_ASSERT = config.bool_state( + "jax_pallas_enable_runtime_assert", + default=False, + help=( + "If set, enables runtime assertions in the kernel via checkify.check." + " Otherwise, runtime asserts will be ignored unless functionalized" + " using checkify.checkify." 
+ ), +) + + +def runtime_assert_enabled() -> bool: + """Returns whether runtime asserts are enabled.""" + return _ENABLE_RUNTIME_ASSERT.value + + class DynamicGridDim: def __repr__(self): return "DynamicGridDim" diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index c04fc6f155b9..49ff632f5c14 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -23,7 +23,6 @@ from typing import Any, ClassVar, Literal import jax -from jax._src import config from jax._src import core as jax_core from jax._src import util from jax._src.pallas import core as pallas_core @@ -48,16 +47,6 @@ _out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping split_list = util.split_list -_ENABLE_RUNTIME_ASSERT = config.bool_state( - "jax_pallas_enable_runtime_assert", - default=False, - help=( - "If set, enables runtime assertions in the kernel via checkify.check." - " Otherwise, runtime asserts will be ignored unless functionalized" - " using checkify.checkify." - ), -) - class KernelType(enum.Enum): TC = 0 @@ -221,11 +210,6 @@ def create_tensorcore_mesh( ) -def runtime_assert_enabled() -> bool: - """Returns whether runtime asserts are enabled.""" - return _ENABLE_RUNTIME_ASSERT.value - - def _tensorcore_mesh_discharge_rule( in_avals, out_avals, diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index a67a71dd40f7..a9fbf8dcd982 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -235,7 +235,7 @@ def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space) return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>") -def _dtype_to_ir_type(dtype: jnp.dtype, +def _dtype_to_ir_type(dtype: jax.typing.DTypeLike, is_kernel_boundary: bool = False) -> ir.Type: if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): if jnp.issubdtype(dtype, tpu_core.dma_semaphore): @@ -246,11 +246,11 @@ def _dtype_to_ir_type(dtype: jnp.dtype, return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError - if is_kernel_boundary and jnp.issubdtype(dtype, jnp.dtype('bool')): + if is_kernel_boundary and jnp.issubdtype(dtype, jnp.bool): dtype = BOOL_MEMREF_TYPE # TODO(justinfu): Remove after mosaic supports unsigned types. # This conversion makes mosaic interpret all unsigned types as signed types. - type = mlir.dtype_to_ir_type(dtype) + type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) if isinstance(type, ir.IntegerType): return ir.IntegerType.get_signless(type.width) else: @@ -3766,14 +3766,15 @@ def _join_key_lowering_rule(ctx: LoweringRuleContext, *scalars, impl): @register_lowering_rule(checkify.check_p) def _checkify_lowering_rule( ctx: LoweringRuleContext, *err_args, err_tree, debug): - if not tpu_core.runtime_assert_enabled(): + if not pallas_core.runtime_assert_enabled(): if debug: return [] else: - raise LoweringException("Non-debug check must be functionalized. " - "Enable runtime asserts with " - "--jax_pallas_enable_runtime_assert " - "or functionalize with checkify.check.") + raise LoweringException( + "Non-debug check must be functionalized. Enable runtime asserts via" + " ``pl.enable_runtime_assert`` or --jax_pallas_enable_runtime_assert" + " or, alternatively, functionalize with ``checkify.check``." + ) if cf is None: # TODO(slebedev): Remove once the minimal jaxlib version is 0.6.1. 
@@ -3782,20 +3783,16 @@ def _checkify_lowering_rule( ) error = jax.tree.unflatten(err_tree, err_args) - assert len(error._pred) == 1 - assert len(error._metadata) == 1 - assert len(error._payload) == 1 - pred = list(error._pred.items())[0][1] - metadata = list(error._metadata.items())[0] - payload = list(error._payload.items())[0][1] - exception_tree = metadata[1] + [pred] = error._pred.values() + [exception_tree] = error._metadata.values() + [payload] = error._payload.values() exception = jax.tree.unflatten(exception_tree, payload) assert isinstance(exception, checkify.FailedCheckError) + assert isinstance(exception, checkify.FailedCheckError) - # check_p has an inverted predicate compared to assert, - # so we need to compute not(pred) here. - out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool')) - minus_one = ir_constant(-1, out_scalar_type) + # check_p has an inverted predicate compared to assert, so we need to compute + # ``not pred`` here. + minus_one = ir_constant(-1, _dtype_to_ir_type(jnp.bool)) not_pred = arith.xori(pred, minus_one) cf.assert_(not_pred, exception.fmt_string) return [] diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index bd304e8b6745..751a2bae2ed0 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -29,6 +29,7 @@ import jax from jax import api_util from jax import lax +from jax._src import checkify from jax._src import core as jax_core from jax._src import lib as jaxlib from jax._src import linear_util as lu @@ -41,6 +42,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import cf as cf_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import llvm as llvm_dialect from jax._src.lib.mlir.dialects import math as math_dialect @@ -3051,3 +3053,38 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): scf_dialect.yield_(after_block.arguments) mgpu_utils.warpgroup_barrier() return () + + +@register_lowering_rule(checkify.check_p, mgpu.LoweringSemantics.Lane) +def _checkify_lowering_rule( + ctx: LoweringRuleContext, *err_args, err_tree, debug +): + if not pallas_core.runtime_assert_enabled(): + if debug: + return [] + else: + raise LoweringError( + "Non-debug check must be functionalized. Enable runtime asserts via" + " ``pl.enable_runtime_assert`` or --jax_pallas_enable_runtime_assert" + " or, alternatively, functionalize with ``checkify.check``." + ) + + if cf_dialect is None: + # TODO(slebedev): Remove once the minimal jaxlib version is 0.6.1. + raise ValueError( + "cf dialect is not available. Make sure you have jaxlib 0.6.1 or later." + ) + + error = jax.tree.unflatten(err_tree, err_args) + [pred] = error._pred.values() + [exception_tree] = error._metadata.values() + [payload] = error._payload.values() + exception = jax.tree.unflatten(exception_tree, payload) + assert isinstance(exception, checkify.FailedCheckError) + + # check_p has an inverted predicate compared to assert, so we need to compute + # ``not pred`` here. 
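+  # XOR against an all-ones constant is a bitwise NOT, which for the boolean
+  # predicate is exactly the logical negation we need.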
+ minus_one = _ir_constant(-1, mgpu_utils.dtype_to_ir_type(jnp.bool)) + not_pred = arith_dialect.xori(pred.registers.item(), minus_one) + cf_dialect.assert_(not_pred, exception.fmt_string) + return [] diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 1e631ad407fd..406d6e965322 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -18,6 +18,7 @@ https://docs.jax.dev/en/latest/pallas.html. """ +from jax._src.pallas.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 from jax._src.pallas.core import BlockDim as BlockDim from jax._src.pallas.core import Blocked as Blocked from jax._src.pallas.core import BlockSpec as BlockSpec @@ -32,6 +33,7 @@ from jax._src.pallas.core import MemoryRef as MemoryRef from jax._src.pallas.core import MemorySpace as MemorySpace from jax._src.pallas.core import no_block_spec as no_block_spec +from jax._src.pallas.core import runtime_assert_enabled as runtime_assert_enabled from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.core import Squeezed as Squeezed from jax._src.pallas.core import squeezed as squeezed diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 5ed6968c673e..c8e2ba131a9b 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -25,8 +25,6 @@ from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams -from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled -from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core @@ -53,6 +51,8 @@ # Those primitives got moved to Pallas core. Keeping the updated imports # here for backward compatibility. 
from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.core import runtime_assert_enabled as runtime_assert_enabled +from jax._src.pallas.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import semaphore_read as semaphore_read from jax._src.pallas.primitives import semaphore_signal as semaphore_signal diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d71593dc9078..ba7f2d74bbb1 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -28,14 +28,15 @@ import jax from jax import export from jax import lax +from jax._src import checkify from jax._src import test_util as jtu from jax._src.pallas import core as pallas_core from jax._src.pallas import pallas_call +from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives -from jax._src.pallas import primitives as pallas_primitives from jax._src.state import types as state_types from jax.experimental import pallas as pl import jax.experimental.mosaic.gpu as mgpu @@ -995,6 +996,22 @@ def kernel(x_ref, o_ref): self.assertIn("x: [1, 0, 43, 23]: 6871\n", output()) + def test_check(self): + self.skip_if_wg_semantics() + + self.enter_context(pallas_core._ENABLE_RUNTIME_ASSERT(True)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + checkify.check(_sum_same_dtype(x_ref[...]) > 0, "x.sum() is negative") + o_ref[...] = x_ref[...] + + x = jnp.arange(256, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x) + def test_load_scalar(self): @functools.partial( self.pallas_call, @@ -1776,6 +1793,7 @@ def test_missing_primitive_lowerings_are_tracked(self): pallas_primitives.semaphore_signal_p, pallas_primitives.semaphore_wait_p, pallas_primitives.semaphore_read_p, + checkify.check_p, } self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) From e5e9be55950d1358a93cd3eed10a91e6f5ae0168 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 20 May 2025 03:35:09 -0700 Subject: [PATCH 1242/1769] [pallas:mosaic_gpu] Added `plgpu.nd_loop` This is a generalization of `lax.fori_loop` which partitions the flat iteration space across the given axes, and is useful for writing persistent kernels. 
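
A sketch of the intended call shape, mirroring the test added below (the
grid extents and the "sm" axis name here are illustrative):

    def body(idx, carry):
      i, j = idx    # one multi-dimensional grid index per loop step
      return carry  # no loop-carried state in this sketch

    plgpu.nd_loop((4, 33), body, None, collective_axes="sm")
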
PiperOrigin-RevId: 761007326 --- jax/BUILD | 1 + jax/_src/pallas/mosaic_gpu/BUILD | 6 ++ jax/_src/pallas/mosaic_gpu/helpers.py | 86 +++++++++++++++++++++++++++ jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 29 +++++++++ 5 files changed, 123 insertions(+) create mode 100644 jax/_src/pallas/mosaic_gpu/helpers.py diff --git a/jax/BUILD b/jax/BUILD index 4018bff873bd..61b5a99dfe31 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -877,6 +877,7 @@ pytype_strict_library( deps = [ ":mosaic_gpu", "//jax/_src/pallas/mosaic_gpu:core", + "//jax/_src/pallas/mosaic_gpu:helpers", "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep "//jax/_src/pallas/mosaic_gpu:pipeline", "//jax/_src/pallas/mosaic_gpu:primitives", diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 2652be7a7c9a..74b44fb8f991 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -123,3 +123,9 @@ pytype_strict_library( "//jax/_src/pallas", ], ) + +pytype_strict_library( + name = "helpers", + srcs = ["helpers.py"], + deps = ["//jax"], +) diff --git a/jax/_src/pallas/mosaic_gpu/helpers.py b/jax/_src/pallas/mosaic_gpu/helpers.py new file mode 100644 index 000000000000..54c4910059d5 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/helpers.py @@ -0,0 +1,86 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for Pallas Mosaic GPU kernels.""" + +from collections.abc import Callable, Hashable, Sequence +import math +from typing import TypeVar + +import jax +from jax import lax + +_T = TypeVar("_T") + + +def nd_loop( + grid: Sequence[int], + body: Callable[[Sequence[jax.Array], _T], _T], + init_val: _T, + *, + collective_axes: Sequence[Hashable] | Hashable, +) -> _T: + """A loop over a multi-dimensional grid partitioned along the given axes. + + For example, if ``collective_axes`` is ``"x"`` with :func:`lax.axis_size` + equal to 4 and the grid is (2, 3), the implementation would produce the + following iteration order + + loop step index axis index + + 0 (0, 0) 0 + 1 (0, 1) 1 + 2 (0, 2) 2 + 3 (1, 0) 3 + 4 (1, 1) 0 + 5 (1, 2) 1 + + which comes from partitioning the flat iteration space into chunks in an + interleaved fashion wrt the ``"x"`` axis index. + + Note that in the example the total number of loop steps is not divisible + by the axis size of ``"x"``, and thus for some ``"x"`` axis indices the + loop will do one iteration less. + + axis index indices + + 0 (0, 0), (1, 1) + 1 (0, 1), (1, 2) + 2 (0, 2) + 3 (1, 0) + + See also: + - :func:`jax.lax.fori_loop`: A single-dimensional indexed loop. + """ + axis_index = lax.axis_index(collective_axes) + axis_size = lax.axis_size(collective_axes) + grid_size = math.prod(grid) + + def wrapper(step, carry): + step = step * axis_size + axis_index + # The loop below is conceptually ``jnp.unravel_index``, but it uses + # ``lax`` APIs instead of ``jax.numpy`` to minimize the number of + # primitives used. 
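+    # Remainders are taken minor-to-major (last grid axis first), so the
+    # accumulated digits are reversed below to restore row-major order.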
+ index = [] + for grid_dim in reversed(grid): + grid_dim = lax.convert_element_type(grid_dim, step.dtype) + index.append(lax.rem(step, grid_dim)) + step = lax.div(step, grid_dim) + index.reverse() + return body(tuple(index), carry) + + upper = lax.div(grid_size, axis_size) + lax.convert_element_type( + axis_index < grid_size % axis_size, axis_index.dtype + ) + return lax.fori_loop(0, upper, wrapper, init_val) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 7b300b8cfbfa..a7d8c3e34223 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -38,6 +38,7 @@ from jax._src.pallas.mosaic_gpu.core import WarpMesh as WarpMesh from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef +from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index ba7f2d74bbb1..68aeea5a03e8 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1762,6 +1762,35 @@ def kernel(x_ref, o_ref128, aliased_ref): with self.assertRaisesRegex(ValueError, "can't be assigned to"): kernel(jnp.arange(128).astype(jnp.float32)) + @parameterized.parameters(1, 2, 3) + def test_nd_loop(self, sm_steps): + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((sm_steps, 132, 128), jnp.int32), + grid=(132,), + grid_names=("sm",), + ) + def kernel(o_ref): + def body(idx, _): + assert len(idx) == 3 + # We need to use `mode="clip"`, because the indices are not static. + flat_idx = jnp.ravel_multi_index(idx, (sm_steps, 4, 33), mode="clip") + sm_step = lax.div( + flat_idx, lax.convert_element_type(lax.axis_size("sm"), jnp.int32) + ) + o_ref[sm_step, lax.axis_index("sm")] = lax.broadcast( + flat_idx, o_ref.shape[-1:] + ) + + plgpu.nd_loop((sm_steps, 4, 33), body, None, collective_axes="sm") + + result = kernel() + for sm_step in range(sm_steps): + np.testing.assert_array_equal( + result[sm_step], + jnp.tile((132 * sm_step + jnp.arange(132))[:, None], 128), + ) + class PallasCallWGTest( PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From e9cdbaca8dec2ed39bbe592372f203fce428027a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 20 May 2025 06:39:03 -0400 Subject: [PATCH 1243/1769] Update scipy.signal.welch tests to be compatible with upstream dev version. 
--- jax/_src/third_party/scipy/signal_helper.py | 2 +- tests/scipy_signal_test.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index 4a021675804d..ad7bdfbef62a 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -57,7 +57,7 @@ def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | N win = get_window(window, nperseg_int) win = jnp.array(win, dtype=dtype) else: - win = jnp.asarray(window) + win = jnp.asarray(window, dtype=dtype) nperseg_int = win.size if nperseg is None else int(nperseg) if win.ndim != 1: raise ValueError('window must be 1-D') diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 7ff3c87435c7..b1c5d9c98fed 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -357,12 +357,11 @@ def testWelchWithDefaultStepArgsAgainstNumpy( if use_nperseg: kwargs['nperseg'] = nperseg if use_window: - kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg), - dtype=dtypes.to_complex_dtype(dtype)) + kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg)) if use_noverlap: kwargs['noverlap'] = noverlap - @jtu.ignore_warning(message="nperseg = 256 is greater than") + @jtu.ignore_warning(message="nperseg") def osp_fun(x): freqs, Pxx = osp_signal.welch(x, **kwargs) return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype)) From a022864b048b8c62e19b96ceebaec851edbc24ec Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 May 2025 04:29:15 -0700 Subject: [PATCH 1244/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4a914a285377e90c464a0b7fad9b5cbcfeeb27a9. PiperOrigin-RevId: 761022928 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d8f1faaad803..d91873fdfad8 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "821e786cf9e478c8d866610203a80acaa70539b9" -XLA_SHA256 = "42d7cb180a65ea9e8589805941ce05612df987a5d00a98381fd548dc1dd31211" +XLA_COMMIT = "4a914a285377e90c464a0b7fad9b5cbcfeeb27a9" +XLA_SHA256 = "3df562d5b67db755d88c469e45ae27ba9e1387f5d79cf25fba17c8c8ea74cfe8" def repo(): tf_http_archive( From e77df81f14f44bfe5ed1bf47b989673a13f3ceee Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 20 May 2025 05:24:13 -0700 Subject: [PATCH 1245/1769] [Mosaic GPU] Implement reshapes from and to refs with empty shapes PiperOrigin-RevId: 761038562 --- jax/experimental/mosaic/gpu/utils.py | 32 +++++++++++++++++++++++++++- tests/mosaic/gpu_test.py | 6 ++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index bf0b06ccb9c9..9eedc3402579 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -511,12 +511,42 @@ def memref_reshape(ref: ir.Value, shape: tuple[int, ...]) -> ir.Value: f" allowed) {shape}" ) - return _reshape(ref, list(ref_ty.shape), list(shape)) + src_shape = list(ref_ty.shape) + dst_shape = list(shape) + if src_shape == dst_shape: + return ref + if not src_shape: + _, offset = ref_ty.get_strides_and_offset() + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(0)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get(ir.AffineMap.get_identity(len(dst_shape))) + else: + new_layout = ir.StridedLayoutAttr.get(offset, [1] * len(dst_shape)) + result_ty = ir.MemRefType.get(dst_shape, ref_ty.element_type, new_layout, ref_ty.memory_space) + return memref.expand_shape(result_ty, ref, [], [], dst_shape) + if not dst_shape: + _, offset = ref_ty.get_strides_and_offset() + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + contig_strided_1d = ir.Attribute.parse("strided<[1]>") + if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d: + new_layout = ir.AffineMapAttr.get(ir.AffineMap.get_identity(0)) + else: + new_layout = ir.StridedLayoutAttr.get(offset, []) + result_ty = ir.MemRefType.get((), ref_ty.element_type, new_layout, ref_ty.memory_space) + return memref.collapse_shape(result_ty, ref, []) + return _reshape(ref, src_shape, dst_shape) def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value: ref_ty = ir.MemRefType(ref.type) new_shape = list(ref_ty.shape) + if dim < 0: + raise ValueError(f"Dimension {dim} is negative") + if dim + fold_rank > len(new_shape): + raise ValueError( + f"Folding {fold_rank} dimensions starting from {dim} is out of bounds" + f" for shape {new_shape}" + ) new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])] identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) contig_strided_1d = ir.Attribute.parse("strided<[1]>") diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 39bd8aa77331..80e67b20e1ef 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -389,17 +389,19 @@ def kernel(ctx, inp, out, _): ("add_1s", (5, 1, 2), (1, 1, 5, 1, 1, 2, 1, 1)), ("fold", (1, 5, 2, 1,), (1, 10, 1)), ("un", (1, 10, 1), (1, 5, 2, 1,)), + ("to_scalar", (1, 1, 1), ()), + ("from_scalar", (), (1, 1, 1)), ) def test_reshape(self, inp_shape, out_shape): def kernel(ctx, inp, out, _): copy(memref_reshape(inp, out_shape), out) - x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(*inp_shape) + x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(inp_shape) out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) 
     y = mgpu.as_gpu_kernel(
         kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
     )(x)
-    np.testing.assert_array_equal(y, x.reshape(*out_shape))
+    np.testing.assert_array_equal(y, x.reshape(out_shape))
 
   @parameterized.named_parameters([
       ("packed", (4, 4, 4), (16, 4, 1), 1, 2, False),

From b8383df255fb9b056146911807a9a786cc4110e6 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 20 May 2025 05:42:06 -0700
Subject: [PATCH 1246/1769] Add Python 3.13 support to traceback code under
 PLATFORM_GOOGLE.

This build option uses CPython internals, which changed under Python 3.13.

PiperOrigin-RevId: 761042927
---
 jaxlib/traceback.cc | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc
index 3eba5288335c..48edc584c94f 100644
--- a/jaxlib/traceback.cc
+++ b/jaxlib/traceback.cc
@@ -68,17 +68,28 @@ Traceback::Traceback() {
 #else  // PY_VERSION_HEX < 0x030b0000
 
 #ifdef PLATFORM_GOOGLE
-  // This code is equivalent to the version using public APIs, but it saves us
-  // an allocation of one object per stack frame. However, this is definitely
-  // violating the API contract of CPython, so we only use this where we can be
-  // confident we know exactly which CPython we are using (internal to Google).
-  // Feel free to turn this on if you like, but it might break at any time!
+// This code is equivalent to the version using public APIs, but it saves us
+// an allocation of one object per stack frame. However, this is definitely
+// violating the API contract of CPython, so we only use this where we can be
+// confident we know exactly which CPython we are using (internal to Google).
+// Feel free to turn this on if you like, but it might break at any time!
+#if PY_VERSION_HEX < 0x030d0000
   for (_PyInterpreterFrame* f = thread_state->cframe->current_frame;
        f != nullptr; f = f->previous) {
     if (_PyFrame_IsIncomplete(f)) continue;
     frames_.emplace_back(f->f_code,
                          _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT));
   }
+#else  // PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame* f = thread_state->current_frame; f != nullptr;
+       f = f->previous) {
+    if (_PyFrame_IsIncomplete(f)) continue;
+    Py_INCREF(f->f_executable);
+    frames_.emplace_back(reinterpret_cast<PyCodeObject*>(f->f_executable),
+                         _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT));
+  }
+#endif  // PY_VERSION_HEX < 0x030d0000
+
 #else  // PLATFORM_GOOGLE
   PyFrameObject* next;
   for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state);

From 414db00e807a4faeeba953fd4a61b4f5fe3d864c Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 20 May 2025 05:50:14 -0700
Subject: [PATCH 1247/1769] Only run relevant GPU tests in the optional GPU CI

This aims to prepare those tests to be used as a presubmit. We can later
add a more complete nightly configuration for testing.
PiperOrigin-RevId: 761044847
---
 .../workflows/bazel_optional_h100_b200.yml | 32 +++++++++++-------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/.github/workflows/bazel_optional_h100_b200.yml b/.github/workflows/bazel_optional_h100_b200.yml
index bde033361609..0c73b238505e 100644
--- a/.github/workflows/bazel_optional_h100_b200.yml
+++ b/.github/workflows/bazel_optional_h100_b200.yml
@@ -49,19 +49,23 @@ jobs:
           --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
           --test_output=errors \
           --test_env=JAX_ACCELERATOR_COUNT=1 \
-          --test_env=JAX_TESTS_PER_ACCELERATOR=32 \
+          --test_env=JAX_TESTS_PER_ACCELERATOR=8 \
           --strategy=TestRunner=local \
-          --local_test_jobs=32 \
-          --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
-          --test_tag_filters=-multiaccelerator \
+          --local_test_jobs=8 \
+          --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \
           --test_env=TF_CPP_MIN_LOG_LEVEL=0 \
           --test_env=JAX_SKIP_SLOW_TESTS=true \
           --action_env=JAX_ENABLE_X64="1" \
           --action_env=NCCL_DEBUG=WARN \
+          --flaky_test_attempts=1 \
+          --test_timeout=420 \
           --color=yes \
-          //tests:gpu_tests //tests:backend_independent_tests \
-          //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \
-          //tests/mosaic:gpu_tests //tests/mosaic:backend_independent_tests
+          //tests:cudnn_fusion_test_gpu \
+          //tests:scaled_matmul_stablehlo_test_gpu \
+          //tests:fused_attention_stablehlo_test_gpu \
+          //tests:nn_test_gpu \
+          //tests/pallas:gpu_tests \
+          //tests/mosaic:gpu_tests
   run_multiaccelerator_tests:
     if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }}
     runs-on: linux-x86-a3-8g-h100-8gpu
@@ -86,13 +90,19 @@ jobs:
           --test_output=errors \
           --strategy=TestRunner=local \
           --local_test_jobs=8 \
-          --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
+          --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \
           --test_tag_filters=multiaccelerator \
           --test_env=TF_CPP_MIN_LOG_LEVEL=0 \
           --test_env=JAX_SKIP_SLOW_TESTS=true \
           --action_env=JAX_ENABLE_X64="1" \
           --action_env=NCCL_DEBUG=WARN \
+          --flaky_test_attempts=1 \
           --color=yes \
-          //tests:gpu_tests //tests:backend_independent_tests \
-          //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \
-          //tests/mosaic:gpu_tests //tests/mosaic:backend_independent_tests
\ No newline at end of file
+          //tests/mosaic:gpu_tests \
+          //tests/pallas:gpu_tests \
+          //tests:array_interoperability_test_gpu \
+          //tests:cudnn_fusion_test_gpu \
+          //tests:fused_attention_stablehlo_test_gpu \
+          //tests:gpu_tests \
+          //tests:python_callback_test_gpu \
+          //tests:ragged_collective_test_gpu
\ No newline at end of file

From f2188786c225c7d16d8a7effd852470b2ad1b229 Mon Sep 17 00:00:00 2001
From: Michael Hudgins
Date: Tue, 20 May 2025 05:59:50 -0700
Subject: [PATCH 1248/1769] [CI] Increase CUDA pytest timeout.
PiperOrigin-RevId: 761047639 --- .github/workflows/pytest_cuda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index af034ab09991..a20be5b1dbcf 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -123,5 +123,5 @@ jobs: with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CUDA tests - timeout-minutes: 60 + timeout-minutes: 120 run: ./ci/run_pytest_cuda.sh \ No newline at end of file From 3c2cb8026e4ba3df1ad9e943c90da0cf4cfed3f8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 20 May 2025 08:03:00 -0700 Subject: [PATCH 1249/1769] [mosaic_gpu] Removed `uniform=` from `async_copy` It is redundant in the presence of `predicate=`. PiperOrigin-RevId: 761085680 --- .../mosaic/gpu/dialect_lowering.py | 2 - .../mosaic/gpu/examples/flash_attention.py | 4 +- .../mosaic/gpu/examples/matmul.py | 2 +- .../mosaic/gpu/examples/matmul_blackwell.py | 2 +- jax/experimental/mosaic/gpu/launch_context.py | 130 ++++++++---------- 5 files changed, 65 insertions(+), 75 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 320ae32607e9..20138bbe6fd4 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -713,7 +713,6 @@ def _mgpu_async_load_op_lowering_rule( gmem_slice=tuple(gmem_slice), barrier=barrier.barrier_ref, arrive=False, - uniform=True, swizzle=swizzle, gmem_transform=transforms, predicate=ctx.single_thread_per_warpgroup_predicate, @@ -755,7 +754,6 @@ def _mgpu_async_store_op_lowering_rule( gmem_slice=tuple(gmem_slice), swizzle=swizzle, gmem_transform=transforms, - uniform=True, predicate=ctx.single_thread_per_warpgroup_predicate, arrive=store_op.commit_group, ) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 071a4dec81fd..78ef1faddc59 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -309,7 +309,7 @@ def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): gmem_slice=(kv_head_idx, ds(kv_seq_base, blocks.kv)), gmem_transform=transform, barrier=barrier, - uniform=False, + predicate=None, swizzle=128, ) def start_k_copy(slot, kv_seq_base): @@ -403,7 +403,7 @@ def kv_copy_init(slot, kv_seq_base): gmem_transform=t, barrier=barriers[slot], arrive=False, - uniform=False, + predicate=None, swizzle=128, ) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index a5dd29e0dc4d..5c8363fa8b27 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -206,7 +206,7 @@ def fetch(slot, ki): rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) txcount = lhs_tma_tile_bytes + rhs_tma_tile_bytes common_copy_args = dict( - swizzle=swizzle, barrier=barrier, arrive=False, uniform=False, + swizzle=swizzle, barrier=barrier, arrive=False, predicate=None, ) with single_thread(): barrier.arrive_expect_tx(txcount) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 03363c1e365f..f771c8bc1ef1 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -114,7 +114,7 @@ def 
_tma_body(ki, _): swizzle=swizzle, barrier=full_barrier, arrive=False, - uniform=False, + predicate=None, collective=gpu.Dimension.x, partitioned=0, # Non-contracting dim is always 0. ) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 175dc8b0ac74..aaae007a67f0 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -321,6 +321,10 @@ def finalize_size(self): init_callback(self._alloc_op.result) +class _DefaultPredicate: + pass + + @dataclasses.dataclass() class LaunchContext: module: ir.Module @@ -506,12 +510,10 @@ def async_copy( barrier: utils.BarrierRef | None = None, swizzle: int | None = None, arrive: bool | None = None, - uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, partitioned: int | None = None, - predicate: ( - ir.Value | None - ) = None, # Should select 0 or 1 threads from the WG. + # Should select 0 or 1 threads from the WG. + predicate: ir.Value | None | _DefaultPredicate = _DefaultPredicate(), reduction_op: ReductionOp | None = None, ): """Initiates an async copy between GMEM and SMEM. @@ -553,8 +555,8 @@ def async_copy( f"Expected same element type, got {element_type} and" f" {dst_ref_ty.element_type}" ) - if predicate is not None and not uniform: - raise ValueError("Predicate can only be defined when uniform is True") + if isinstance(predicate, _DefaultPredicate): + predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP) if not isinstance(gmem_transform, tuple): gmem_transform = (gmem_transform,) @@ -756,13 +758,6 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) ] - uniform_ctx = ( - functools.partial( - utils.single_thread, scope=utils.ThreadSubset.WARPGROUP) - if uniform and predicate is None - else contextlib.nullcontext - ) - if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" @@ -792,68 +787,65 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): np.prod(slice_shape) * element_bitwidth * collective_size // 8, i32 ) barrier_ptr = barrier.get_ptr() - with uniform_ctx(): - assert reduction_op is None - if collective_size > 1 and partitioned is not None: - if predicate is None: - predicate = c(1, ir.IntegerType.get_signless(1)) - if arrive: - first_block = arith.cmpi( - arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index), - ) - arrive_predicate = arith.andi(predicate, first_block) - nvvm.mbarrier_arrive_expect_tx_shared( - barrier_ptr, transfer_bytes, predicate=arrive_predicate - ) - rank = len(slice_shape) - idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) - llvm.inline_asm( - ir.Type.parse("!llvm.void"), - [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices], - f""" - {{ - .reg .b32 mapped_addr; - @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0; - @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2 - [$1], [$2, {{{idx_operands}}}], [mapped_addr]; - }} - """, - "b,r,l,r" + ",r" * rank, - has_side_effects=True, + assert reduction_op is None + if collective_size > 1 and partitioned is not None: + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + if arrive: + first_block = arith.cmpi( + arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index), ) - else: - if arrive: - nvvm.mbarrier_arrive_expect_tx_shared( - barrier_ptr, 
transfer_bytes, predicate=predicate - ) - nvvm.cp_async_bulk_tensor_shared_cluster_global( - smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], - multicast_mask=multicast_mask, predicate=predicate + arrive_predicate = arith.andi(predicate, first_block) + nvvm.mbarrier_arrive_expect_tx_shared( + barrier_ptr, transfer_bytes, predicate=arrive_predicate ) - else: - assert multicast_mask is None - if reduction_op is not None: - with uniform_ctx(): - if predicate is None: - predicate = c(1, ir.IntegerType.get_signless(1)) - rank = len(slice_shape) - idx_operands = ",".join(f"${i}" for i in range(3, 3 + rank)) - llvm.inline_asm( + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) + llvm.inline_asm( ir.Type.parse("!llvm.void"), - [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices], - f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", - "b,r,l" + ",r" * rank, + [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices], + f""" + {{ + .reg .b32 mapped_addr; + @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0; + @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2 + [$1], [$2, {{{idx_operands}}}], [mapped_addr]; + }} + """, + "b,r,l,r" + ",r" * rank, has_side_effects=True, - ) - if arrive: - nvvm.cp_async_bulk_commit_group() + ) else: - with uniform_ctx(): - nvvm.cp_async_bulk_tensor_global_shared_cta( - tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate + if arrive: + nvvm.mbarrier_arrive_expect_tx_shared( + barrier_ptr, transfer_bytes, predicate=predicate ) - if arrive: - nvvm.cp_async_bulk_commit_group() + nvvm.cp_async_bulk_tensor_shared_cluster_global( + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], + multicast_mask=multicast_mask, predicate=predicate + ) + else: + assert multicast_mask is None + if reduction_op is not None: + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(3, 3 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices], + f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", + "b,r,l" + ",r" * rank, + has_side_effects=True, + ) + if arrive: + nvvm.cp_async_bulk_commit_group() + else: + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate + ) + if arrive: + nvvm.cp_async_bulk_commit_group() def await_async_copy( self, allow_groups: int, await_read_only: bool = False From d1511c1d781cf0b3992ee9e62ae0773deb9cb833 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 20 May 2025 08:12:30 -0700 Subject: [PATCH 1250/1769] [pallas:mosaic_gpu] Removed the `GPU*` prefix from Mosaic GPU-specific types These APIs are always used qualified, e.g. `plgpu.GPUCompilerParams`, so the prefix is redundant. 
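
At a call site the rename looks like this (a sketch; the old names remain
importable as deprecated aliases for now):

    # Before:
    params = plgpu.GPUCompilerParams(approx_math=True)
    # After:
    params = plgpu.CompilerParams(approx_math=True)
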
PiperOrigin-RevId: 761088896 --- docs/jax.experimental.pallas.mosaic_gpu.rst | 6 +- docs/pallas/gpu/reference.md | 9 +- jax/_src/pallas/mosaic_gpu/core.py | 40 ++--- jax/_src/pallas/mosaic_gpu/lowering.py | 14 +- .../mosaic_gpu/pallas_call_registration.py | 4 +- jax/_src/pallas/mosaic_gpu/primitives.py | 12 +- jax/_src/pallas/pallas_call.py | 2 +- jax/experimental/pallas/mosaic_gpu.py | 26 +-- .../pallas/ops/gpu/attention_mgpu.py | 26 +-- tests/pallas/mosaic_gpu_test.py | 151 ++++++++++-------- tests/pallas/ops_test.py | 2 +- 11 files changed, 157 insertions(+), 135 deletions(-) diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 2d3452609c75..4191dde74df7 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -10,9 +10,9 @@ Classes :toctree: _autosummary Barrier - GPUBlockSpec - GPUCompilerParams - GPUMemorySpace + BlockSpec + CompilerParams + MemorySpace Layout SwizzleTransform TilingTransform diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md index 0db31e11b459..7b4a1e6e9c7d 100644 --- a/docs/pallas/gpu/reference.md +++ b/docs/pallas/gpu/reference.md @@ -225,17 +225,20 @@ def body(..., scratch_ref): There are two ways in which references are allocated and each has a way to select the desired transforms: -**1. Using `GPUBlockSpec`** +**1. Using `plgpu.BlockSpec`** ```python transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128)) f = pl.pallas_call( - in_specs=plgpu.GPUBlockSpec(in_block_shape, in_index_map, transforms=transforms), - out_specs=plgpu.GPUBlockSpec(out_block_shape, out_index_map, transforms=transforms), + in_specs=plgpu.BlockSpec(in_block_shape, in_index_map, transforms=transforms), + out_specs=plgpu.BlockSpec(out_block_shape, out_index_map, transforms=transforms), ... ) ``` +Note that unlike `plgpu.BlockSpec`, `pl.BlockSpec` does *not* allow specifying +transforms. + **2. Specifying the `transforms` argument on the allocated `SMEM`** ```python diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index d13977ac4fbf..08e47cec4b2e 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -71,7 +71,7 @@ def _slices(d): @dataclasses.dataclass(frozen=True, kw_only=True) -class GPUCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Mosaic GPU compiler parameters. Attributes: @@ -108,7 +108,7 @@ def __post_init__(self): ) -class GPUMemorySpace(enum.Enum): +class MemorySpace(enum.Enum): #: Global memory. GMEM = "gmem" #: Shared memory. 
@@ -145,7 +145,7 @@ def __call__(self, shape: tuple[int, ...]): dtype = pallas_core.BarrierSemaphore() else: dtype = pallas_core.Semaphore() - return pallas_core.MemoryRef(shape, dtype, GPUMemorySpace.GMEM) + return pallas_core.MemoryRef(shape, dtype, MemorySpace.GMEM) def get_array_aval(self) -> jax_core.ShapedArray: return self(()).get_array_aval() @@ -183,7 +183,7 @@ def kernel( def wrapper(*operands): def stateful(operand_and_out_refs): operand_refs, out_refs = operand_and_out_refs - mesh = GPUMesh(**mesh_kwargs) + mesh = Mesh(**mesh_kwargs) thread_name = mesh.thread_name if mesh.thread_name is not None else () def cmap_body(): pallas_primitives.run_scoped( @@ -234,7 +234,7 @@ class GPUMemoryRef(pallas_core.MemoryRef): collective: bool | None = dataclasses.field(default=None, kw_only=True) def __post_init__(self): - if self.memory_space != GPUMemorySpace.TMEM: + if self.memory_space != MemorySpace.TMEM: if self.packed is not None: raise ValueError("Packed option is only supported for TMEM.") if self.collective is not None: @@ -244,7 +244,7 @@ def get_ref_aval(self) -> _Ref: aval = jax_core.ShapedArray(self.shape, self.dtype) for t in self.transforms: aval = t(aval) - if self.memory_space == GPUMemorySpace.TMEM: + if self.memory_space == MemorySpace.TMEM: ref = pallas_core.TransformedRef( AbstractTMEMRef(aval, memory_space=self.memory_space, @@ -785,7 +785,7 @@ def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: @dataclasses.dataclass -class GPUBlockSpec(pallas_core.BlockSpec): +class BlockSpec(pallas_core.BlockSpec): transforms: Sequence[MemoryRefTransform] = () def to_block_mapping( @@ -817,10 +817,10 @@ def to_block_mapping( ) -GMEM = GPUMemorySpace.GMEM -SMEM = GPUMemorySpace.SMEM -TMEM = GPUMemorySpace.TMEM -REGS = GPUMemorySpace.REGS +GMEM = MemorySpace.GMEM +SMEM = MemorySpace.SMEM +TMEM = MemorySpace.TMEM +REGS = MemorySpace.REGS class barrier_dtype(dtypes.extended): @@ -903,7 +903,7 @@ def get_ref_aval(self) -> AbstractMemoryRef: "Preinitialized WGMMAAccumulatorRef only supported in pl.run_state." 
) return WGMMAAbstractAccumulatorRef( - jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS + jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), MemorySpace.REGS ) @staticmethod @@ -913,7 +913,7 @@ def init(array): def _wgmma_ref_type_mapping(ref: WGMMAAccumulatorRef): aval = WGMMAAbstractAccumulatorRef( - jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), GPUMemorySpace.REGS + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), MemorySpace.REGS ) return aval, ref._init state_types._ref_type_aval_mappings[WGMMAAccumulatorRef] = _wgmma_ref_type_mapping @@ -962,7 +962,7 @@ def __repr__(self) -> str: _WARPGROUP_AXIS_NAME = object() @dataclasses.dataclass(frozen=True, kw_only=True) -class GPUMesh: +class Mesh: grid: Sequence[int] = () grid_names: Sequence[str] = () cluster: Sequence[int] = () @@ -1049,15 +1049,15 @@ def _gpu_mesh_discharge_rule( cost_estimate, name, ): - if not isinstance(mesh, GPUMesh): - raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}") - if compiler_params and not isinstance(compiler_params, GPUCompilerParams): + if not isinstance(mesh, Mesh): + raise TypeError(f"Mesh must be a `plgpu.Mesh`, got {type(mesh)}") + if compiler_params and not isinstance(compiler_params, CompilerParams): raise TypeError( - "Compiler params must be a GPUCompilerParams, got" + "Compiler params must be a `plgpu.CompilerParams`, got" f" {type(compiler_params)}" ) if not compiler_params: - compiler_params = GPUCompilerParams() + compiler_params = CompilerParams() return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, @@ -1073,7 +1073,7 @@ def _gpu_mesh_discharge_rule( ) -pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule +pallas_core._core_map_mesh_rules[Mesh] = _gpu_mesh_discharge_rule class MemoryEffect(jax_core.Effect): diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 751a2bae2ed0..db959b1dbc24 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -568,7 +568,7 @@ def index_map(*indices): ) return eval_index_map(*new_indices) - return gpu_core.GPUBlockSpec( + return gpu_core.BlockSpec( bm.block_shape, index_map, memory_space=bm.transformed_block_aval.memory_space, @@ -581,7 +581,7 @@ def lower_pipelined_jaxpr_to_module( gpu_mesh: pallas_core.Mesh | None, jax_mesh: mesh_lib.Mesh | None, jaxpr: jax_core.Jaxpr, - params: gpu_core.GPUCompilerParams, + params: gpu_core.CompilerParams, cost_estimate: pallas_core.CostEstimate | None, ) -> LoweringResult: del cost_estimate # Unused. @@ -604,7 +604,7 @@ def lower_pipelined_jaxpr_to_module( ) if gpu_mesh: - assert isinstance(gpu_mesh, gpu_core.GPUMesh) + assert isinstance(gpu_mesh, gpu_core.Mesh) block = (128 * (gpu_mesh.num_threads or 1), 1, 1) grid = gpu_mesh.grid thread_axis = ( @@ -649,7 +649,7 @@ def ref_for_aval(aval: jax_core.AbstractValue): aval = v.aval if (isinstance(aval, pallas_core.AbstractMemoryRef) and jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype)): - if aval.memory_space != gpu_core.GPUMemorySpace.GMEM: + if aval.memory_space != gpu_core.MemorySpace.GMEM: raise ValueError( "Only GMEM memory space is supported for semaphores in Mosaic GPU." 
) @@ -747,7 +747,7 @@ def lower_jaxpr_to_module( out_shapes: Sequence[jax.ShapeDtypeStruct], gmem_scratch_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, - params: gpu_core.GPUCompilerParams, + params: gpu_core.CompilerParams, consts=(), ) -> LoweringResult: debug_info = jaxpr.debug_info @@ -2048,7 +2048,7 @@ def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): if not axis_names: raise LookupError( "No axis names are available. Make sure you are using `pl.core_map`" - " with a `plgpu.GPUMesh`." + " with a `plgpu.Mesh`." ) if not axis_names or axis_name not in axis_names.cluster: raise LookupError( @@ -2066,7 +2066,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): if gpu_axis_names is None and not jax_axis_names: raise LookupError( "No axis names are available. Make sure you are using `pl.core_map`" - " with a `plgpu.GPUMesh` or an appropriate JAX device mesh." + " with a `plgpu.Mesh` or an appropriate JAX device mesh." ) if axis_name not in itertools.chain((gpu_axis_names or ()), jax_axis_names): raise LookupError( diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 72e6f96c125a..ef1ba37f0f5c 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -63,9 +63,9 @@ def pallas_call_lowering( mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error if "mosaic_gpu" in compiler_params: - params = cast(gpu_core.GPUCompilerParams, compiler_params["mosaic_gpu"]) + params = cast(gpu_core.CompilerParams, compiler_params["mosaic_gpu"]) else: - params = gpu_core.GPUCompilerParams() + params = gpu_core.CompilerParams() jax_mesh = None axis_context = ctx.module_context.axis_context diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 0e9319972949..af9d4138cfbb 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -56,7 +56,7 @@ def _check_ref( - aval: object, name: str, memory_space: gpu_core.GPUMemorySpace + aval: object, name: str, memory_space: gpu_core.MemorySpace ) -> None: if not isinstance(aval, state_types.AbstractRef): raise TypeError(f"{name} must be a reference, got {aval}") @@ -1200,19 +1200,19 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, collective_axis): del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree) - if acc.memory_space != gpu_core.GPUMemorySpace.TMEM: + if acc.memory_space != gpu_core.MemorySpace.TMEM: raise ValueError("Accumulator must be a TMEM Ref.") - if a.memory_space not in (gpu_core.GPUMemorySpace.SMEM, - gpu_core.GPUMemorySpace.TMEM): + if a.memory_space not in (gpu_core.MemorySpace.SMEM, + gpu_core.MemorySpace.TMEM): raise ValueError("LHS must be a TMEM/SMEM Ref.") - if b.memory_space != gpu_core.GPUMemorySpace.SMEM: + if b.memory_space != gpu_core.MemorySpace.SMEM: raise ValueError("RHS must be an SMEM Ref.") if collective_axis is not None: if not acc.collective: raise ValueError( "Accumulator Ref must be collective if collective_axis is set.") - if a.memory_space == gpu_core.GPUMemorySpace.TMEM and not a.collective: + if a.memory_space == gpu_core.MemorySpace.TMEM and not a.collective: raise ValueError( "LHS TMEM Ref must be collective if collective_axis is set.") diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 2d27bd3cc485..964709b4c915 100644 --- 
a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1551,7 +1551,7 @@ def pallas_call( backend-specific dataclass (:class:`jax.experimental.pallas.tpu.TPUCompilerParams`, :class:`jax.experimental.pallas.triton.TritonCompilerParams`, - :class:`jax.experimental.pallas.mosaic_gpu.GPUCompilerParams`) or a dict + :class:`jax.experimental.pallas.mosaic_gpu.CompilerParams`) or a dict mapping backend name to the corresponding platform-specific dataclass. backend: Optional string literal one of ``"mosaic_tpu"``, ``"triton"`` or ``"mosaic_gpu"`` determining the backend to be used. None means let Pallas diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index a7d8c3e34223..cc0e185e296a 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -19,10 +19,10 @@ from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier -from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec -from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh -from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace +from jax._src.pallas.mosaic_gpu.core import BlockSpec as BlockSpec +from jax._src.pallas.mosaic_gpu.core import CompilerParams as CompilerParams +from jax._src.pallas.mosaic_gpu.core import Mesh as Mesh +from jax._src.pallas.mosaic_gpu.core import MemorySpace as MemorySpace from jax._src.pallas.mosaic_gpu.core import kernel as kernel from jax._src.pallas.mosaic_gpu.core import PeerMemRef as PeerMemRef from jax._src.pallas.mosaic_gpu.core import RefUnion as RefUnion @@ -63,9 +63,15 @@ from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics -#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. -GMEM = GPUMemorySpace.GMEM -#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. -SMEM = GPUMemorySpace.SMEM -#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.TMEM`. -TMEM = GPUMemorySpace.TMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM`. +GMEM = MemorySpace.GMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM`. +SMEM = MemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.TMEM`. +TMEM = MemorySpace.TMEM + +# TODO(slebedev): Deprecate and remove these aliases. 
+GPUBlockSpec = BlockSpec +GPUCompilerParams = CompilerParams +GPUMemorySpace = MemorySpace +GPUMesh = Mesh diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index a100aa96faba..6da468f2cc3e 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -306,7 +306,7 @@ def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): grid_names=("batch", "q_seq", "heads"), num_threads=3, thread_name="wg", - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + compiler_params=plgpu.CompilerParams(approx_math=True), )(q, k, v) if save_residuals: @@ -451,11 +451,11 @@ def compute_dq(acc_ref): manual_consumed_barriers=True, compute_context=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( # k + plgpu.BlockSpec( # k block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), transforms=[tiling, swizzle]), - plgpu.GPUBlockSpec( # v + plgpu.BlockSpec( # v block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), transforms=[tiling, swizzle]), @@ -558,16 +558,16 @@ def compute_dk(acc_ref): manual_consumed_barriers=True, compute_context=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( # q + plgpu.BlockSpec( # q block_shape=(block_q, head_dim), index_map=lambda i: (i, 0), transforms=[tiling, swizzle]), - plgpu.GPUBlockSpec( # do + plgpu.BlockSpec( # do block_shape=(block_q, head_dim), index_map=lambda i: (i, 0), transforms=[tiling, swizzle]), - plgpu.GPUBlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)), - plgpu.GPUBlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)) + plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)), + plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)) ]) q_ref = q_ref.at[batch, :, q_head, :] do_ref = do_ref.at[batch, :, q_head, :] @@ -589,7 +589,7 @@ def compute_dk(acc_ref): (q_scratch, do_scratch, lse_scratch, delta_scratch), # type: ignore (plgpu.Barrier(1, num_barriers=compute_wgs),) * 4 # type: ignore ], - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + compiler_params=plgpu.CompilerParams(approx_math=True), grid=(batch_size, num_q_tiles, num_q_heads), grid_names=("batch", "q_seq", "heads"), num_threads=compute_wgs + 1, @@ -610,7 +610,7 @@ def compute_dk(acc_ref): (k_scratch, v_scratch), # type: ignore (plgpu.Barrier(1, num_barriers=compute_wgs),) * 2 # type: ignore ], - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + compiler_params=plgpu.CompilerParams(approx_math=True), grid=(batch_size, num_kv_tiles, num_q_heads), grid_names=("batch", "kv_seq", "heads"), num_threads=compute_wgs + 1, @@ -746,10 +746,10 @@ def compute_pv(acc_ref): manual_consumed_barriers=True, compute_context=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( # k + plgpu.BlockSpec( # k block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0)), - plgpu.GPUBlockSpec( # v + plgpu.BlockSpec( # v block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0)), ], @@ -758,7 +758,7 @@ def compute_pv(acc_ref): k_ref = k_ref.at[batch, :, kv_head, :] v_ref = v_ref.at[batch, :, kv_head, :] pipeline(k_ref, v_ref) - mesh = plgpu.GPUMesh( + mesh = plgpu.Mesh( grid=(batch_size, num_q_tiles, num_q_heads), grid_names=("batch", "q_seq", "heads"), num_threads=3, @@ -769,7 +769,7 @@ def run(refs): @pl.core_map( mesh, - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ), ) diff --git a/tests/pallas/mosaic_gpu_test.py 
b/tests/pallas/mosaic_gpu_test.py index 68aeea5a03e8..b10bc0f390b0 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -94,14 +94,14 @@ def skip_if_wg_semantics(self): def kernel(self, *args, **kwargs): compiler_params = dataclasses.replace( - kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + kwargs.pop("compiler_params", plgpu.CompilerParams()), lowering_semantics=self.LOWERING_SEMANTICS, ) return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) def pallas_call(self, *args, **kwargs): compiler_params = dataclasses.replace( - kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + kwargs.pop("compiler_params", plgpu.CompilerParams()), lowering_semantics=self.LOWERING_SEMANTICS, ) return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) @@ -153,7 +153,7 @@ def test_unary_op(self, op, approx_math): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), + compiler_params=plgpu.CompilerParams(approx_math=approx_math), ) def kernel(x_ref, o_ref): o_ref[...] = op(x_ref[...]) @@ -296,12 +296,13 @@ def kernel(x_ref, o_ref, scratch_ref): @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16]) def test_add_one_grid_pipelined(self, max_concurrent_steps): + @functools.partial( self.pallas_call, in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=max_concurrent_steps, ), @@ -314,11 +315,12 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_pipelined_program_id(self): + @functools.partial( self.pallas_call, out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=2, ), @@ -339,7 +341,7 @@ def test_add_one_grid_pipelined_sequential_invariant_output(self): in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)), out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=2, ), @@ -634,7 +636,7 @@ def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): grid=(4, 4), out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32), in_specs=( - plgpu.GPUBlockSpec( + plgpu.BlockSpec( block_shape=(128, 128), index_map=lambda i, j: (i, j), memory_space=plgpu.SMEM, @@ -645,7 +647,7 @@ def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): ), ), out_specs=( - plgpu.GPUBlockSpec( + plgpu.BlockSpec( block_shape=(64, 32), index_map=lambda i, j: (i, j), memory_space=plgpu.SMEM, @@ -696,7 +698,7 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.wait_smem_to_gmem(0) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( + out_spec = plgpu.BlockSpec( transforms=( plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), @@ -727,7 +729,7 @@ def body(tmp_ref): pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - 
out_spec = plgpu.GPUBlockSpec(transforms=ts, memory_space=plgpu.SMEM) + out_spec = plgpu.BlockSpec(transforms=ts, memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), @@ -767,7 +769,7 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( + out_spec = plgpu.BlockSpec( transforms=( plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), @@ -797,7 +799,7 @@ def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -818,7 +820,7 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec(memory_space=plgpu.SMEM) + out_spec = plgpu.BlockSpec(memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), @@ -920,7 +922,7 @@ def test_print_wgmma_tiled_layout(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=[ - plgpu.GPUBlockSpec( + plgpu.BlockSpec( transforms=( plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), @@ -1013,10 +1015,11 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x) def test_load_scalar(self): + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GMEM)], + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], ) def kernel(x_ref, o_ref): o_ref[...] 
= jnp.broadcast_to(x_ref[10], (128,)) @@ -1135,7 +1138,7 @@ def kernel(o_ref): def test_swizzled_blockspec_shapes(self): self.skip_if_wg_semantics() - spec = plgpu.GPUBlockSpec( + spec = plgpu.BlockSpec( (128, 64), lambda *i: i, transforms=( @@ -1344,7 +1347,7 @@ def test_tile_slicing(self): self.skip_if_wg_semantics() shape = (256, 128) - block_spec = plgpu.GPUBlockSpec( + block_spec = plgpu.BlockSpec( transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) @functools.partial( @@ -1375,8 +1378,8 @@ def kernel(a_ref, b_ref): a = np.zeros((64, 64), dtype=jnp.float32) b = self.pallas_call( kernel, - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GMEM)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GMEM), + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), input_output_aliases={0: 0}, out_shape=a, )(a) @@ -1395,7 +1398,7 @@ def rotate(src, dst): dst[lower, left] = src[lower, right] x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) - spec = plgpu.GPUBlockSpec( + spec = plgpu.BlockSpec( transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) @@ -1472,7 +1475,7 @@ def kernel(x_ref, o_ref): y = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( profile_space=16, profile_dir=tmpdir ), )(x) @@ -1836,11 +1839,12 @@ def test_fori_loop_accumulator(self, force_while): transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) else: transforms = () + @functools.partial( self.pallas_call, - in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)], + in_specs=[plgpu.BlockSpec((64, 64), transforms=transforms)], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), - out_specs=plgpu.GPUBlockSpec((64, 64)), + out_specs=plgpu.BlockSpec((64, 64)), ) def kernel(i_ref, o_ref): def scope(acc_ref): @@ -1907,7 +1911,7 @@ def _epilogue(): ) if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: - lhs_spec = plgpu.GPUBlockSpec( + lhs_spec = plgpu.BlockSpec( lhs_spec.block_shape, lhs_spec.index_map, transforms=( @@ -1915,7 +1919,7 @@ def _epilogue(): plgpu.SwizzleTransform(128), ), ) - rhs_spec = plgpu.GPUBlockSpec( + rhs_spec = plgpu.BlockSpec( rhs_spec.block_shape, rhs_spec.index_map, transforms=( @@ -1923,7 +1927,7 @@ def _epilogue(): plgpu.SwizzleTransform(128), ), ) - out_spec = plgpu.GPUBlockSpec( + out_spec = plgpu.BlockSpec( out_spec.block_shape, out_spec.index_map, transforms=( @@ -1939,7 +1943,7 @@ def _epilogue(): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], grid=(grid_m, grid_n, grid_k), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "parallel", "sequential"], max_concurrent_steps=2, delay_release=1, @@ -1980,7 +1984,7 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( + plgpu.BlockSpec( (64, 128), lambda i, j: (i, j), transforms=( @@ -1988,13 +1992,13 @@ def scope(acc_ref): plgpu.SwizzleTransform(128), ), ), - plgpu.GPUBlockSpec( + plgpu.BlockSpec( b_shape, lambda *i: i, transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)), ), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i), + out_specs=plgpu.BlockSpec((64, 192), lambda *i: i), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), grid=(1, 1), 
)(a, b) @@ -2019,8 +2023,8 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec(transforms=transforms), - plgpu.GPUBlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) @@ -2044,9 +2048,9 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec(transforms=transforms), - plgpu.GPUBlockSpec(transforms=transforms), - plgpu.GPUBlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), )(a, b, i) @@ -2073,8 +2077,8 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec(transforms=transforms), - plgpu.GPUBlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) @@ -2104,14 +2108,10 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), lambda *ij: ij, transforms=transforms - ), - plgpu.GPUBlockSpec( - (128, 128), lambda *ij: ij, transforms=transforms - ), + plgpu.BlockSpec((64, 128), lambda *ij: ij, transforms=transforms), + plgpu.BlockSpec((128, 128), lambda *ij: ij, transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 128), lambda *ij: ij), + out_specs=plgpu.BlockSpec((64, 128), lambda *ij: ij), out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), grid=(1, 1), )(a, b) @@ -2129,7 +2129,7 @@ def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layo self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -2175,14 +2175,14 @@ def compute(acc_ref): out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), in_specs=( pl.BlockSpec(memory_space=src_memory_space), - plgpu.GPUBlockSpec( + plgpu.BlockSpec( transforms=( plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ), ), - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), ) out_ref = ( @@ -2294,12 +2294,10 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, f = self.pallas_call( kernel, in_specs=( - plgpu.GPUBlockSpec(transforms=transforms, - memory_space=plgpu.SMEM), - plgpu.GPUBlockSpec(transforms=transforms, - memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), ), - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GMEM), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), out_shape=jax.ShapeDtypeStruct(shape, dtype), scratch_shapes=scratch_shapes, ) @@ -2511,12 +2509,12 @@ def kernel(x_gmem, o_gmem): plgpu.emit_pipeline( kernel_body, in_specs=[ - plgpu.GPUBlockSpec( + plgpu.BlockSpec( (64, 64), lambda i: (0, i), transforms=transforms ) ], out_specs=[ - plgpu.GPUBlockSpec( + plgpu.BlockSpec( (64, 64), lambda i: (0, i), transforms=transforms ) ], @@ -2705,10 +2703,10 @@ def kernel_body(_, a_smem, b_smem): plgpu.emit_pipeline( kernel_body, in_specs=[ - plgpu.GPUBlockSpec( + plgpu.BlockSpec( (tile_m, tile_k), lambda k: (pid_m, k), 
transforms=transforms ), - plgpu.GPUBlockSpec( + plgpu.BlockSpec( (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms ), ], @@ -2729,7 +2727,7 @@ def kernel_body(_, a_smem, b_smem): pl.BlockSpec(memory_space=plgpu.GMEM), pl.BlockSpec(memory_space=plgpu.GMEM), ], - out_specs=plgpu.GPUBlockSpec( + out_specs=plgpu.BlockSpec( (tile_m, tile_n), lambda m, n: (m, n), transforms=transforms ), out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), @@ -2794,7 +2792,7 @@ def body(*gmem_refs): jax.ShapeDtypeStruct((m, n), jnp.float16), jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float16), ), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + compiler_params=plgpu.CompilerParams(approx_math=True), grid=(1,), grid_names=("_",), num_threads=3, @@ -2839,7 +2837,7 @@ def pipeline(*gmem_refs): kernel = self.kernel( pipeline, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + compiler_params=plgpu.CompilerParams(approx_math=True), grid=(1,), grid_names=("_",), num_threads=num_compute_wgs + 1, @@ -2858,7 +2856,7 @@ def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): scratch_shapes=[ plgpu.SMEM((blk_m, blk_n), jnp.float32), ], - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + compiler_params=plgpu.CompilerParams(approx_math=True), grid=(1,), grid_names=("_",), num_threads=num_compute_wgs + 1, @@ -3176,11 +3174,12 @@ class CoreMapWGTest( class PrettyPrintingTest(PallasTest): def test_load(self): + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -3228,8 +3227,8 @@ def test_wgmma(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), in_specs=[ - plgpu.GPUBlockSpec(transforms=transforms), - plgpu.GPUBlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], ) def kernel(a_ref, b_ref, o_ref): @@ -3348,9 +3347,13 @@ def kernel(l_ref, r_ref, o_ref): def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] 
r = lax.axis_index("rows") - block = plgpu.GPUBlockSpec( - (row_block, col_block), lambda c: (r, c), - transforms=(plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)), + block = plgpu.BlockSpec( + (row_block, col_block), + lambda c: (r, c), + transforms=( + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(64), + ), ) plgpu.emit_pipeline( compute, @@ -3420,9 +3423,19 @@ def do_wgmma(acc_ref): plgpu.emit_pipeline( compute, grid=(l_ref.shape[1] // k_block,), - in_specs=[plgpu.GPUBlockSpec((m_block, k_block), lambda k: (m, k), transforms=lo_transforms), - plgpu.GPUBlockSpec((k_block, n_block), lambda k: (k, n), transforms=r_transforms)], - out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)], + in_specs=[ + plgpu.BlockSpec( + (m_block, k_block), lambda k: (m, k), transforms=lo_transforms + ), + plgpu.BlockSpec( + (k_block, n_block), lambda k: (k, n), transforms=r_transforms + ), + ], + out_specs=[ + plgpu.BlockSpec( + (m_block, n_block), lambda k: (m, n), transforms=lo_transforms + ) + ], )(l_ref, r_ref, o_ref) np.testing.assert_allclose(kernel(x, x), x @ x) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 3baa26e5efd7..bcda0ca9f71e 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -292,7 +292,7 @@ def setUp(self): def pallas_call(cls, *args, **kwargs): if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: assert plgpu_mgpu is not None - compiler_params = plgpu_mgpu.GPUCompilerParams( + compiler_params = plgpu_mgpu.CompilerParams( lowering_semantics=plgpu_mgpu.LoweringSemantics.Warpgroup ) kwargs["compiler_params"] = compiler_params From 06448864abd6e8187e5b4d9b1ff08ab14fe3b8e0 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 20 May 2025 08:40:37 -0700 Subject: [PATCH 1251/1769] Rename backend.compile to backend.compile_and_load. Part of a larger refactor. Today, `compile` returns a loaded executable, i.e., it fuses the compile and load functions. Eventually, `compile` should return an unloaded executable and `load` should return a loaded executable; the default jit path will still return a loaded executable.
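For illustration, a sketch of the intended end state (the unloaded-executable API is hypothetical and does not exist yet):

    # Today: compilation and loading are fused.
    loaded = backend.compile_and_load(
        computation, executable_devices, compile_options)

    # Eventually (hypothetical): compile stops at an unloaded executable
    # that can be serialized or loaded onto devices as a separate step.
    unloaded = backend.compile(computation, executable_devices, compile_options)
    loaded = unloaded.load()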
PiperOrigin-RevId: 761098001 --- jax/_src/compiler.py | 52 +++++++++++++++++++++++++++++++--------- jaxlib/_jax/__init__.pyi | 11 +++++++++ jaxlib/xla_client.py | 1 + jaxlib/xla_client.pyi | 1 + 4 files changed, 54 insertions(+), 11 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 04f993fed799..e8ef647a1312 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -292,6 +292,19 @@ def backend_compile( executable_devices: xc.DeviceList, options: xc.CompileOptions, host_callbacks: Sequence[Any], +) -> xc.LoadedExecutable: + return backend_compile_and_load( + backend, module, executable_devices, options, host_callbacks + ) + + +@profiler.annotate_function +def backend_compile_and_load( + backend: xc.Client, + module: ir.Module, + executable_devices: xc.DeviceList, + options: xc.CompileOptions, + host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: sym_name = module.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value @@ -322,18 +335,35 @@ def backend_compile( # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results - if host_callbacks: + elif jaxlib_extension_version < 342 or isinstance(backend, xc.CompileOnlyPyClient): + if host_callbacks: + return backend.compile( + built_c, + executable_devices=executable_devices, # type: ignore + compile_options=options, + host_callbacks=host_callbacks, + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` return backend.compile( + built_c, executable_devices=executable_devices, compile_options=options) # type: ignore + else: + if host_callbacks: + return backend.compile_and_load( + built_c, + executable_devices=executable_devices, + compile_options=options, + host_callbacks=host_callbacks, + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile_and_load( built_c, - executable_devices=executable_devices, # type: ignore + executable_devices=executable_devices, compile_options=options, - host_callbacks=host_callbacks, ) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` - return backend.compile( - built_c, executable_devices=executable_devices, compile_options=options) # type: ignore except xc.XlaRuntimeError as e: for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: handler_result = error_handler(e) @@ -398,7 +428,7 @@ def compile_or_get_cached( ) if cache_key is None: - return backend_compile( + return backend_compile_and_load( backend, computation, executable_devices, compile_options, host_callbacks) @@ -426,7 +456,7 @@ def compile_or_get_cached( config.share_binary_between_hosts.value and is_multi_process and distributed.global_state.client is not None - # Host callbacks are currently baked into the HLO module so we cant share + # Host callbacks are currently baked into the HLO module so we can't share # them. 
and len(host_callbacks) == 0 ): @@ -716,7 +746,7 @@ def _compile_and_write_cache( cache_key: str, ) -> xc.LoadedExecutable: start_time = time.monotonic() - executable = backend_compile( + executable = backend_compile_and_load( backend, computation, executable_devices, compile_options, host_callbacks ) compile_time = time.monotonic() - start_time diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 1d7f3042e8a3..000c05acacad 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -551,6 +551,17 @@ class Client: ) -> PjRtLayout: ... def __getattr__(self, name: str) -> Any: ... + +class CompileOnlyPyClient(Client): + def compile( + self, + computation: str | bytes, + executable_devices: DeviceList | Sequence[Device], + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + + class CpuCollectives: ... def make_gloo_tcp_collectives( diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 8f8c829ee6c7..b1bbc464610e 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -304,6 +304,7 @@ def computation_count(): XlaComputation = _xla.XlaComputation Client = _xla.Client +CompileOnlyPyClient = _xla.CompileOnlyPyClient Memory = _xla.Memory Array = _xla.Array ArrayImpl = _xla.ArrayImpl diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index 80599e86676b..fce114f45474 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -24,6 +24,7 @@ from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics from jaxlib._jax import ArrayImpl as ArrayImpl from jaxlib._jax import AutotuneCacheMode as AutotuneCacheMode from jaxlib._jax import Client as Client +from jaxlib._jax import CompileOnlyPyClient as CompileOnlyPyClient from jaxlib._jax import CompileOptions as CompileOptions from jaxlib._jax import Device as Device from jaxlib._jax import DeviceAssignment as DeviceAssignment From fa823ae629cb1361ac63b5072f83e679abcf7d69 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 20 May 2025 09:11:28 -0700 Subject: [PATCH 1252/1769] Add a TSAN suppression for https://github.com/python/cpython/issues/132214 under Python 3.13. 
--- .github/workflows/tsan-suppressions_3.13.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt index aec94dfef004..a929a8c44728 100644 --- a/.github/workflows/tsan-suppressions_3.13.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -40,3 +40,6 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/132245 race:split_keys_entry_added race_top:dict_dict_merge + +# https://github.com/python/cpython/issues/132214 +race:type_update_dict From 496cbd07cea5134e9ab83a72fc33941acc6149b8 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Tue, 20 May 2025 09:26:35 -0700 Subject: [PATCH 1253/1769] [Mosaic GPU] Replace `core_map` + `run_state` with `plgpu.kernel` for simpler code. Slightly more efficient because we don't initialize the outputs now. PiperOrigin-RevId: 761113870 --- .../pallas/ops/gpu/attention_mgpu.py | 60 +++++++------------ 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 6da468f2cc3e..c7e9f95e3f99 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -658,10 +658,9 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residual if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): + def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, smem_buffers, q_barriers, schedule_barrier): batch = lax.axis_index("batch") wg_idx = lax.axis_index("wg") - smem_buffers, q_barriers, schedule_barrier = scoped qo_smem2, lse_smem2 = smem_buffers q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") @@ -758,46 +757,31 @@ def compute_pv(acc_ref): k_ref = k_ref.at[batch, :, kv_head, :] v_ref = v_ref.at[batch, :, kv_head, :] pipeline(k_ref, v_ref) - mesh = plgpu.Mesh( + + out_shape = [q, None] + if save_residuals: + out_shape[1] = jax.ShapeDtypeStruct((batch_size, num_q_heads, q_seq_len), jnp.float32) + + qo_scratch = plgpu.SMEM((compute_wgs, block_q, head_dim), jnp.float16) + smem_scratch = [qo_scratch, None] + if save_residuals: + smem_scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) + + out, lse = plgpu.kernel( + fa3_kernel, grid=(batch_size, num_q_tiles, num_q_heads), grid_names=("batch", "q_seq", "heads"), num_threads=3, thread_name="wg", - ) - def run(refs): - q_ref, k_ref, v_ref, out_ref, lse_ref = refs - - @pl.core_map( - mesh, - compiler_params=plgpu.CompilerParams( - approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup - ), - ) - def _kernel_entry(): - qo_scratch = plgpu.SMEM( - (compute_wgs, block_q, head_dim), jnp.float16, - ) - scratch = [qo_scratch, None] - if save_residuals: - scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) - pl.run_scoped( - lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), - scratch, - plgpu.Barrier(1, num_barriers=compute_wgs), - plgpu.Barrier(num_arrivals=compute_wgs), - collective_axes="wg", - ) - @jax.jit - def run_function(q, k, v, o, lse): - *_, out, lse = pl.run_state(run)((q, k, v, o, lse)) - return out, lse - - lse = ( - jnp.full((batch_size,
num_q_heads, q_seq_len), -jnp.inf, dtype=jnp.float32) - if save_residuals - else None - ) - out, lse = run_function(q, k, v, jnp.full_like(q, jnp.inf), lse) + out_shape=out_shape, + scratch_shapes=( + tuple(smem_scratch), # type: ignore + plgpu.Barrier(1, num_barriers=compute_wgs), # type: ignore + plgpu.Barrier(num_arrivals=compute_wgs),), # type: ignore + compiler_params=plgpu.CompilerParams( + approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup, + ), + )(q, k, v) if save_residuals: assert lse is not None From 4a3ce2b2dc75bdb02b52687c24dfb7d182278be5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 20 May 2025 10:29:31 -0700 Subject: [PATCH 1254/1769] [pallas:mosaic_gpu] Use `MemorySpace` aliases This just makes the corresponding conditions a bit easier to read. PiperOrigin-RevId: 761137840 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index db959b1dbc24..eb5fc136082e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -649,7 +649,7 @@ def ref_for_aval(aval: jax_core.AbstractValue): aval = v.aval if (isinstance(aval, pallas_core.AbstractMemoryRef) and jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype)): - if aval.memory_space != gpu_core.MemorySpace.GMEM: + if aval.memory_space != gpu_core.GMEM: raise ValueError( "Only GMEM memory space is supported for semaphores in Mosaic GPU." ) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index af9d4138cfbb..53c890932e38 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1200,19 +1200,18 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, collective_axis): del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree) - if acc.memory_space != gpu_core.MemorySpace.TMEM: + if acc.memory_space != gpu_core.TMEM: raise ValueError("Accumulator must be a TMEM Ref.") - if a.memory_space not in (gpu_core.MemorySpace.SMEM, - gpu_core.MemorySpace.TMEM): + if a.memory_space not in (gpu_core.SMEM, gpu_core.TMEM): raise ValueError("LHS must be a TMEM/SMEM Ref.") - if b.memory_space != gpu_core.MemorySpace.SMEM: + if b.memory_space != gpu_core.SMEM: raise ValueError("RHS must be an SMEM Ref.") if collective_axis is not None: if not acc.collective: raise ValueError( "Accumulator Ref must be collective if collective_axis is set.") - if a.memory_space == gpu_core.MemorySpace.TMEM and not a.collective: + if a.memory_space == gpu_core.TMEM and not a.collective: raise ValueError( "LHS TMEM Ref must be collective if collective_axis is set.") From 86680a9b1282d00398f3d3d3a56336c0452a76b8 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 20 May 2025 10:31:19 -0700 Subject: [PATCH 1255/1769] Include Pallas TPU random ops in JAX wheel. Since the `pallas/ops/tpu/random` directory was missing an `__init__.py` file, it was inadvertently excluded from the released JAX distribution. I don't see any reason why this submodule shouldn't be included, so let's fix that! To deal with the fact that these files weren't included in the distribution, we were also monkey patching them into the wheel when testing, but that's no longer needed.
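As a quick sanity check, the following should now work against an installed wheel rather than a source checkout (a sketch; the submodules are the ones previously listed in BUILD.bazel):

    # These imports used to fail from a released wheel because the package
    # directory had no __init__.py and was dropped from the distribution.
    from jax.experimental.pallas.ops.tpu.random import philox
    from jax.experimental.pallas.ops.tpu.random import threefry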
PiperOrigin-RevId: 761138525 --- BUILD.bazel | 3 --- jax/experimental/pallas/ops/tpu/random/__init__.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 jax/experimental/pallas/ops/tpu/random/__init__.py diff --git a/BUILD.bazel b/BUILD.bazel index 59fd949b7ad8..887f28d4583e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -108,9 +108,6 @@ genrule( "//jax:internal_test_harnesses", "//jax:internal_test_util", "//jax:internal_export_back_compat_test_data", - "//jax:experimental/pallas/ops/tpu/random/philox.py", - "//jax:experimental/pallas/ops/tpu/random/prng_utils.py", - "//jax:experimental/pallas/ops/tpu/random/threefry.py", "//jax/experimental/mosaic/gpu/examples:flash_attention.py", "//jax/experimental/mosaic/gpu/examples:matmul.py", "//jax:test_multiprocess", diff --git a/jax/experimental/pallas/ops/tpu/random/__init__.py b/jax/experimental/pallas/ops/tpu/random/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/random/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== From 8cfabd7d545e601fdd111d36acc2142e048d11b5 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 20 May 2025 11:00:41 -0700 Subject: [PATCH 1256/1769] Migrate users of backend.compile to backend.compile_and_load. 
PiperOrigin-RevId: 761150621 --- jax/experimental/jax2tf/tests/sharding_test.py | 2 +- jax/experimental/jax2tf/tests/tf_test_util.py | 2 +- tests/compilation_cache_test.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 55ccb1328c87..5fc45df218cd 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -115,7 +115,7 @@ def log_jax_hlo(self, f_jax, args: Sequence[Any], *, executable = backend.compile( jax_hlo, compile_options=compile_options) # type: ignore else: - executable = backend.compile( + executable = backend.compile_and_load( jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore jax_optimized_hlo = executable.hlo_modules()[0].to_string() logging.info("[%s] got JAX optimized HLO for platform %s %s", diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index e87a8af5d15e..faecf9f0f09e 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -346,7 +346,7 @@ def log_message(extra): backend = xla_bridge.get_backend() device_list = xc.DeviceList(tuple(backend.local_devices())) - modules = backend.compile( + modules = backend.compile_and_load( str(jax_lowered.compiler_ir()), device_list).hlo_modules() jax_opt_hlo = modules[0].to_string() logging.info("[%s] JAX OPT HLO\n%s", self._testMethodName, diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 1ba6b1221a88..5a76d732bd76 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -150,9 +150,9 @@ def test_diff_executables(self): executable1 = backend.compile(computation1, compile_options) executable2 = backend.compile(computation2, compile_options) else: - executable1 = backend.compile( + executable1 = backend.compile_and_load( computation1, executable_devices, compile_options) - executable2 = backend.compile( + executable2 = backend.compile_and_load( computation2, executable_devices, compile_options) cc.put_executable_and_time( "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) @@ -180,7 +180,7 @@ def test_put_executable(self): if jax._src.lib.jaxlib_extension_version < 331: executable = backend.compile(str(computation), compile_options) else: - executable = backend.compile( + executable = backend.compile_and_load( str(computation), executable_devices, compile_options) key = cc.get_cache_key(computation, devices, compile_options, backend) cc.put_executable_and_time( @@ -251,7 +251,7 @@ def test_enable_compilation_cache(self): g = jit(lambda x: x * 3) g(2) cache = cc._get_cache(backend) - self.assertIsNotNone(cache) # Cache should be initalized + self.assertIsNotNone(cache) # Cache should be initialized def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value From 683a215a154eebc94a3a980e2b80453fb2392553 Mon Sep 17 00:00:00 2001 From: Alexandre Boulgakov Date: Tue, 20 May 2025 12:20:41 -0700 Subject: [PATCH 1257/1769] [Mosaic] Tweak `tpu.log` verification on SC. 
PiperOrigin-RevId: 761182341 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 30 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index b5e68bf08370..3733bf5d4465 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -52,15 +53,15 @@ LogicalResult UnrollVectorsOp::canonicalize(UnrollVectorsOp op, RollVectorsOp roll_op = dyn_cast_or_null(op.getOperand().getDefiningOp()); if (!roll_op) { - return failure(); + return failure(); } if (roll_op.getNumOperands() != op.getNumResults()) { - return failure(); + return failure(); } for (auto [v1, v2] : llvm::zip(roll_op.getOperandTypes(), op.getResultTypes())) { if (v1 != v2) { - return failure(); + return failure(); } } rewriter.replaceOp(op, roll_op.getOperands()); @@ -499,8 +500,7 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, } auto layout_ref = erase_layout_op.getOperand(); auto layout_ty = layout_ref.getType(); - auto layout = - dyn_cast(layout_ty.getLayout()); + auto layout = dyn_cast(layout_ty.getLayout()); CHECK(!layout.getTiles().empty()); auto tile = layout.getTiles().front().dimensions(); auto new_tile_strides = ComputeTileStrides(dst_ty, tile); @@ -594,8 +594,8 @@ LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op, if (tile[0] * src_bitwidth % tgt_bitwidth != 0) { return failure(); } - SmallVector new_tiles = - {xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; + SmallVector new_tiles = { + xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; if (tgt_bitwidth < 32) { new_tiles.push_back(xla::Tile({32 / tgt_bitwidth, 1})); } @@ -1325,11 +1325,21 @@ LogicalResult LogOp::verify() { return failure(); } CoreType logging_core_type = logging_core_type_maybe->value_or(CoreType::kTc); - if ((logging_core_type == CoreType::kScScalarSubcore || - logging_core_type == CoreType::kScVectorSubcore) && - getFormattedAttr() != nullptr && getFormattedAttr().getValue()) { + bool is_sc_core = logging_core_type == CoreType::kScScalarSubcore || + logging_core_type == CoreType::kScVectorSubcore; + if (is_sc_core && getFormattedAttr() != nullptr && + getFormattedAttr().getValue()) { return emitOpError("Formatted logging is not supported on SC"); } + if (is_sc_core && getInputs().size() > 1) { + return emitOpError("SC logging only supports 0 or 1 inputs"); + } + if (is_sc_core && getInputs().size() == 1) { + Type input_type = getInputs().front().getType(); + if (!llvm::isa(input_type)) { + return emitOpError("SC logging only supports memrefs or scalars"); + } + } switch (logging_core_type) { case CoreType::kTc: case CoreType::kScScalarSubcore: From 0144ec1edf538a1e19dc7736cdf403861975087b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 20 May 2025 12:55:08 -0700 Subject: [PATCH 1258/1769] Remove tsan suppression from python/cpython#129748 This bug is marked as fixed upstream. 
--- .github/workflows/tsan-suppressions_3.13.txt | 3 --- .github/workflows/tsan-suppressions_3.14.txt | 3 --- .github/workflows/tsan.yaml | 1 - 3 files changed, 7 deletions(-) diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt index a929a8c44728..483e3f0b3c2a 100644 --- a/.github/workflows/tsan-suppressions_3.13.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -23,9 +23,6 @@ race_top:PyMember_GetOne race_top:new_reference race:_Py_IsOwnedByCurrentThread -# https://github.com/python/cpython/issues/129748 -race:mi_block_set_nextx - # https://github.com/python/cpython/issues/128130 race_top:run_eval_code_obj diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt index 384560128cfc..008b61933a0b 100644 --- a/.github/workflows/tsan-suppressions_3.14.txt +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -8,9 +8,6 @@ race:dnnl_sgemm # https://github.com/python/cpython/issues/128050 race:partial_vectorcall_fallback -# https://github.com/python/cpython/issues/129748 -race:mi_block_set_nextx - # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi race:gesdd_ffi diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 882e140b91ad..ce4130c31a30 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -14,7 +14,6 @@ on: paths: - '**/workflows/tsan.yaml' - '**/workflows/tsan-suppressions*.txt' - - '**/workflows/requirements_lock_3_13_ft.patch' jobs: tsan: From e896282219481c3c6edbd1334c186c7bfbdbdae6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 20 May 2025 12:58:45 -0700 Subject: [PATCH 1259/1769] Remove pspec -> names conversion during `shard_map_p.bind` and instead preserve partition specs everywhere internally. **This is because spec -> names canonicalization gets rid of unreduced axes present on PartitionSpecs and we want to preserve that**. We can thread 2 new parameters called `in_unreduced` and `out_unreduced` and keep `in_names`, `out_names` but that doesn't buy us anything except for more lines added and complexity :) It's better to just use pspecs everywhere. It's a net reduction in lines of code too! PiperOrigin-RevId: 761196531 --- jax/_src/checkify.py | 15 +- jax/_src/dispatch.py | 10 +- jax/_src/interpreters/pxla.py | 9 +- jax/_src/pjit.py | 14 +- jax/_src/shard_map.py | 472 ++++++++++++++++------------------ tests/shard_map_test.py | 16 +- 6 files changed, 258 insertions(+), 278 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index f26b4222b23b..0061c9c63f7b 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -53,6 +53,7 @@ from jax._src.tree_util import tree_map from jax._src.tree_util import tree_unflatten from jax._src.typing import Array +from jax._src.partition_spec import PartitionSpec as P from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, unzip3, weakref_lru_cache, HashableWrapper, foreach) @@ -958,7 +959,7 @@ def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params): def shard_map_error_check( error: Error, enabled_errors, *vals_in, - jaxpr: core.Jaxpr, in_names, out_names, **kwargs + jaxpr: core.Jaxpr, in_specs, out_specs, **kwargs ): if (mesh := kwargs.get('mesh')) is None: raise ValueError('Mesh must be provided for shard_map with checkify.') @@ -966,7 +967,7 @@ def shard_map_error_check( err_vals, err_tree = jtu.tree_flatten(error) num_error_vals = len(err_vals) # Replicated sharding for in errors. 
- new_in_names = (*([{}] * num_error_vals), *in_names) + new_in_specs = (*([P()] * num_error_vals), *in_specs) new_vals_in = [*err_vals, *vals_in] in_avals = list(map(core.get_aval, new_vals_in)) manual_axes = kwargs.get('manual_axes') @@ -974,7 +975,7 @@ def shard_map_error_check( for i, v in enumerate(in_avals): if not (sharder := core.shard_aval_handlers.get(type(v))): raise ValueError(f'Unsupported aval type: {type(v)}') - in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_names[i], v) + in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_specs[i], v) with (jshmap._extend_axis_env(mesh, manual_axes), mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), # type: ignore[arg-type] @@ -983,7 +984,7 @@ def shard_map_error_check( checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals ) - num_out_error_vals = out_tree.num_leaves - len(out_names) + num_out_error_vals = out_tree.num_leaves - len(out_specs) def expand_errors_leading_dim(*xs): outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs) @@ -1001,15 +1002,15 @@ def expand_errors_leading_dim(*xs): # Update shard_map params to account for extra error values. # Use fully sharded partitioning for out errors. - new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names) + new_out_specs = (*([P(mesh.axis_names)] * num_out_error_vals), *out_specs) subfun = lu.hashable_partial( lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info), checked_jaxpr.jaxpr, checked_jaxpr.consts ) new_params = dict( jaxpr=checked_jaxpr.jaxpr, - in_names=new_in_names, - out_names=new_out_names, + in_specs=new_in_specs, + out_specs=new_out_specs, **kwargs, ) _, new_params = jshmap.shard_map_p.get_bind_params(new_params) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 1ab560fb58c5..8f553ea884d7 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -264,14 +264,12 @@ def get_intermediate_shardings( out.extend((i, source_info) for i in eqn.params['in_shardings']) out.extend((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: - if isinstance(eqn.params['mesh'], AbstractMesh): + mesh = eqn.params['mesh'] + if isinstance(mesh, AbstractMesh): continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - def _names_to_pspec(names): - ndmin = max(names) + 1 if names else 0 - return PartitionSpec(*(names.get(i) for i in range(ndmin))) - out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) - for names in [*eqn.params['in_names'], *eqn.params['out_names']]) + out.extend((NamedSharding(mesh, spec), source_info) + for spec in [*eqn.params['in_specs'], *eqn.params['out_specs']]) elif eqn.primitive is device_put_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) out.extend((s, source_info) for s in eqn.params['devices'] diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a7782063491c..4a21fae59e52 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -72,7 +72,8 @@ PartitionSpec as P) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, - unzip2, HashableFunction, weakref_lru_cache) + unzip2, HashableFunction, weakref_lru_cache, + tuple_insert) from jax._src.state.types import AbstractRef, RefEffect @@ -3339,6 +3340,12 @@ def check_array_xla_sharding_layout_match( "compiled with. 
" f"Here are {num_mismatch_str}:\n{str_errors}") +def batch_spec(spec, dim, val): + too_short = dim - len(spec) + if too_short > 0: + spec += (None,) * too_short + new_partitions = tuple_insert(spec, dim, val) # type: ignore + return PartitionSpec(*new_partitions) def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 10e7e697e706..de01f4c05983 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -77,7 +77,7 @@ treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, PyTreeDef, none_leaf_registry as none_lr, tree_map, tree_flatten_with_path) from jax._src.util import ( - HashableFunction, safe_map, safe_zip, wraps, tuple_insert, + HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, split_list_checked, weakref_lru_cache, merge_lists, subs_list, fun_name, fun_qual_name) from jax._src.attrs import (Box, List, dne_sentinel, jax_setattr, jax_getattr, @@ -2190,12 +2190,6 @@ def _pjit_batcher(axis_data, vals_in, batching.fancy_primitive_batchers[pjit_p] = _pjit_batcher batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule -def _insert_axis_partitions(spec, dim, val): - too_short = dim - len(spec) - if too_short > 0: - spec += (None,) * too_short - new_partitions = tuple_insert(spec, dim, val) # type: ignore - return PartitionSpec(*new_partitions) def _pjit_batcher_for_sharding( s: Sharding | UnspecifiedValue, @@ -2209,7 +2203,7 @@ def _pjit_batcher_for_sharding( return s if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): return NamedSharding( - s.mesh, _insert_axis_partitions(s.spec, dim, PartitionSpec.UNCONSTRAINED)) + s.mesh, pxla.batch_spec(s.spec, dim, PartitionSpec.UNCONSTRAINED)) new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) tad.insert(dim, 1) # type: ignore @@ -2221,7 +2215,7 @@ def _pjit_batcher_for_sharding( else: if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): return NamedSharding( - s.mesh, _insert_axis_partitions(s.spec, dim, spmd_axis_name)) + s.mesh, pxla.batch_spec(s.spec, dim, spmd_axis_name)) if isinstance(s, NamedSharding): mesh = s.mesh if mesh is None or mesh.empty: @@ -2234,7 +2228,7 @@ def _pjit_batcher_for_sharding( f' manager scope{s!r}') spec = parse_flatten_op_sharding(hlo_s, mesh)[0] return NamedSharding( - mesh, _insert_axis_partitions(spec, dim, spmd_axis_name)) + mesh, pxla.batch_spec(spec, dim, spmd_axis_name)) def _pjit_jvp(primals_in, tangents_in, diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 010773f74d0c..72e1420b0b2b 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -47,7 +47,8 @@ from jax._src.lib.mlir.dialects import hlo, sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2) + merge_lists, split_list, subs_list2, + fun_name as util_fun_name) from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -209,11 +210,11 @@ def wrapped(*args): dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) if s is not None) fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) - _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) - in_names_flat = tuple(map(_canonicalize_spec, 
in_specs_flat)) + _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, + args_flat) @memoize - def out_names_thunk(): + def out_specs_thunk(): if callable(out_specs): out_specs_ = out_specs() _check_specs(SpecErrorType.out, out_specs_, axis_names) @@ -225,15 +226,15 @@ def out_names_thunk(): except ValueError: e, *_ = prefix_errors(out_specs_, dummy) raise e('shard_map out_specs') from None - return tuple(map(_canonicalize_spec, out_specs_flat)) + return tuple(out_specs_flat) if check_vma: - fun = _implicit_pvary_on_output(fun, out_names_thunk) + fun = _implicit_pvary_on_output(fun, out_specs_thunk) try: out_flat = shard_map_p.bind( - fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_vma=check_vma, + fun, *args_flat, mesh=mesh, in_specs=in_specs_flat, + out_specs_thunk=out_specs_thunk, check_vma=check_vma, manual_axes=axis_names) except _SpecError as e: fails, = e.args @@ -305,16 +306,6 @@ def _shmap_checks(mesh, axis_names, in_specs, out_specs, _skip_mesh_check, _check_specs(SpecErrorType.out, out_specs, axis_names) return mesh, axis_names - -# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs -AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable -def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: - if isinstance(spec, PartitionSpec): - return {i: names if isinstance(names, tuple) else (names,) - for i, names in enumerate(spec) if names is not None} - else: - return spec - def _manual_spec(manual_axes, spec: P) -> P: out = [] # type: ignore for s in spec: @@ -391,7 +382,7 @@ def _check_specs_vs_args( fail = _expand_fail(in_tree, dyn_argnums, fail) msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) raise ValueError(msg) - in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) + in_names_flat = tuple(map(_spec_to_names, in_specs_flat)) fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) for d, ns in names.items()) else no_fail for a, names in zip(in_avals, in_names_flat)] @@ -411,7 +402,7 @@ def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], def _spec_rank_error( error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, fails: list[core.ShapedArray | NoFail]) -> str: - fun_name = getattr(f, '__name__', str(f)) + fun_name = util_fun_name(f) if error_type == SpecErrorType.input: prefix, base = 'in', 'args' ba = _try_infer_args(f, tree) @@ -472,7 +463,7 @@ def _spec_divisibility_error( extra = (f", where args{arg_key} is the index " f"{arg_key.idx - len(ba.signature.parameters) + 1} component " f"of {fun_name}'s varargs parameter '{param.name}',") - names = _canonicalize_spec(spec) + names = _spec_to_names(spec) for d, ns in names.items(): if aval.shape[d] % prod(mesh.shape[n] for n in ns): axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" @@ -504,8 +495,7 @@ def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, fun_name = getattr(f, '__name__', str(f)) msgs = [] for (spec_key, spec), (fail_key, vma) in _iter_paths(tree, specs, fails): - dst = _canonicalize_spec(spec) - unmentioned = _unmentioned(mesh, dst) + unmentioned = _unmentioned(mesh, spec) if len(unmentioned) > 1: need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec)))) got_vma = ','.join(map(str, order_wrt_mesh(mesh, vma))) @@ -536,9 +526,9 @@ def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, "check_vma=False argument to `jax.shard_map`.") return msg -def 
_unmentioned(mesh: Mesh | AbstractMesh, names: AxisNames) -> list[AxisName]: - name_set = {n for ns in names.values() for n in ns} - return [n for n in mesh.axis_names if n not in name_set] +def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]: + vma_set = _spec_to_vma(spec) + return [n for n in mesh.axis_names if n not in vma_set] def _try_infer_args(f, tree): @@ -563,10 +553,10 @@ def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] # Primitive @lu.transformation2 -def _implicit_pvary_on_output(f, out_names_thunk, *args, **kwargs): +def _implicit_pvary_on_output(f, out_specs_thunk, *args, **kwargs): out_flat = f(*args, **kwargs) - return [pvary(o, tuple(_names_to_vma(n) - typeof(o).vma)) - for o, n in zip(out_flat, out_names_thunk())] + return [pvary(o, tuple(_spec_to_vma(sp) - typeof(o).vma)) + for o, sp in zip(out_flat, out_specs_thunk())] JaxType = Any MaybeTracer = Union[JaxType, Tracer] @@ -588,8 +578,8 @@ def get_bind_params(self, params): subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ()) - axes = new_params.pop('out_names') - new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes) + axes = new_params.pop('out_specs') + new_params['out_specs_thunk'] = HashableFunction(lambda: axes, closure=axes) return [subfun], new_params shard_map_p = ShardMapPrimitive('shard_map') @@ -631,38 +621,35 @@ def _extend_axis_env(mesh, manual_axes): def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, in_tracers: Sequence[Any], *, mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_vma: bool, - manual_axes: frozenset, + in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: source_info = source_info_util.current() to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) in_tracers = map(to_jaxpr_tracer, in_tracers) inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_names, + in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_specs, in_avals) with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) - _check_names(out_names_thunk(), out_avals_) + _check_names(out_specs_thunk(), out_avals_) if check_vma: out_vma = [v.aval.vma for v in jaxpr.outvars] - _check_vmas(mesh, out_names_thunk(), out_vma) + _check_vmas(mesh, out_specs_thunk(), out_vma) out_avals = map(_check_shapedarray, out_avals_) - out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, names, aval)) - for names, aval in zip(out_names_thunk(), out_avals)] + out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, spec, aval)) + for spec, aval in zip(out_specs_thunk(), out_avals)] out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] invars = map(trace.getvar, in_tracers) constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) - in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore + in_specs_staged = (P(),) * len(consts) + tuple(in_specs) # type: ignore with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) - params = 
dict(mesh=mesh, in_names=in_names_staged, - out_names=tuple(out_names_thunk()), jaxpr=jaxpr, + params = dict(mesh=mesh, in_specs=in_specs_staged, + out_specs=tuple(out_specs_thunk()), jaxpr=jaxpr, check_vma=check_vma, manual_axes=manual_axes) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, @@ -673,44 +660,48 @@ def _shard_map_staging( # TODO add underscore version, for direct-linearize to consume +def _spec_to_names(spec: PartitionSpec): + return {i: names if isinstance(names, tuple) else (names,) + for i, names in enumerate(spec) if names is not None} + def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval -def _shard_aval(mesh: Mesh, manual_axes, check_vma, names: AxisNames, +def _shard_aval(mesh: Mesh, manual_axes, check_vma, spec, aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.shard_aval_handlers: return core.shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma, - names, aval) + spec, aval) raise NotImplementedError(f"Unsupported aval type: {type(aval)}") -def _unshard_aval(mesh: Mesh, check_vma, names: AxisNames, +def _unshard_aval(mesh: Mesh, check_vma, spec, aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.unshard_aval_handlers: - return core.unshard_aval_handlers[type(aval)](mesh, check_vma, names, aval) + return core.unshard_aval_handlers[type(aval)](mesh, check_vma, spec, aval) else: raise NotImplementedError(f"Unsupported aval type: {type(aval)}") def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma, - names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: + spec, aval: core.AbstractValue) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) + names = _spec_to_names(spec) new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) manual_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - vma = (frozenset({n for ns in names.values() for n in ns}) - if check_vma else frozenset()) + vma = _spec_to_vma(spec) if check_vma else frozenset() vma = vma | aval.vma return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array -def _unshard_shaped_array(mesh: Mesh, check_vma, names: AxisNames, - aval: core.AbstractValue,) -> core.AbstractValue: +def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.AbstractValue + ) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) + names = _spec_to_names(spec) new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) - names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) + names_spec = spec._normalized_spec_for_aval(aval.ndim) if aval.ndim == 0: out_spec = P() else: @@ -739,32 +730,32 @@ def _unshard_shaped_array(mesh: Mesh, check_vma, names: AxisNames, # Type-checking -def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, +def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs, check_vma, manual_axes): # TODO(mattjj,parkers): check auto - for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): + for v, x, in_spec in zip(jaxpr.invars, in_atoms, in_specs): if not core.typecompat(v.aval, _shard_aval( - mesh, manual_axes, check_vma, in_name, x.aval)): + mesh, manual_axes, 
check_vma, in_spec, x.aval)): raise core.JaxprTypeError("shard_map argument avals not compatible with " - "jaxpr binder avals and in_names") + "jaxpr binder avals and in_specs") with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): core.check_jaxpr(jaxpr) if check_vma: out_vma = [v.aval.vma for v in jaxpr.outvars] - for vma, dst in zip(out_vma, out_names): - if not _valid_repeats(mesh, vma, dst): + for vma, out_spec in zip(out_vma, out_specs): + if not _valid_repeats(mesh, vma, out_spec): raise core.JaxprTypeError( "shard_map can't prove output is sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] - out_avals = map(partial(_unshard_aval, mesh, check_vma), out_names, + out_avals = map(partial(_unshard_aval, mesh, check_vma), out_specs, out_avals_sharded) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) return out_avals, effs core.custom_typechecks[shard_map_p] = _shard_map_typecheck -def _valid_repeats(mesh: Mesh, vma: Set[AxisName], names: AxisNames) -> bool: - um = set(_unmentioned(mesh, names)) - set(mesh.manual_axes) +def _valid_repeats(mesh: Mesh, vma: Set[AxisName], spec) -> bool: + um = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes) if any(u in vma for u in um): return False return True @@ -772,10 +763,9 @@ def _valid_repeats(mesh: Mesh, vma: Set[AxisName], names: AxisNames) -> bool: # Lowering def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in + ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in ) -> sharding_impls.SdyArray: - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) + ns = _make_scoped_manual_sharding(ctx, mesh, spec) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) @@ -789,12 +779,12 @@ def _shardy_shard_map_sharding( def _shardy_shard_map_token_sharding( ctx: mlir.LoweringRuleContext, mesh ) -> ir.Attribute: - ns = _make_scoped_manual_sharding(ctx, mesh, {}) + ns = _make_scoped_manual_sharding(ctx, mesh, P()) return ns._to_sdy_sharding(0) def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, manual_axes, check_vma): + ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma): axis_ctx = ctx.module_context.axis_context in_avals_ = [v.aval for v in jaxpr.invars] if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): @@ -820,17 +810,17 @@ def _shard_map_lowering_shardy( ctx.set_tokens_out(tokens_out) return out_nodes - in_shardings = list(map( - partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), - in_names, ctx.avals_in)) + in_shardings = list( + map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), + in_specs, ctx.avals_in)) num_dim_vars = len(ctx.dim_var_values) in_shardings = ([_shardy_shard_map_token_sharding(ctx, mesh)] * (num_tokens + num_dim_vars) + in_shardings) in_shardings = sharding_impls.SdyArrayList(in_shardings).build() - out_shardings = list(map( - partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), - out_names, ctx.avals_out)) + out_shardings = list( + map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), + out_specs, ctx.avals_out)) out_shardings = [ _shardy_shard_map_token_sharding(ctx, mesh)] * num_tokens + out_shardings out_shardings = sharding_impls.SdyArrayList(out_shardings).build() @@ -868,15 +858,15 @@ def _shard_map_lowering_shardy( return 
manual_computation_op.results[num_tokens:] -def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, +def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_specs, out_specs, check_vma, manual_axes): if config.use_shardy_partitioner.value: return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, manual_axes, check_vma) + ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] - in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_names, + in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs, ctx.avals_in, in_avals_, in_nodes) new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) @@ -885,28 +875,26 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, - arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), - result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) + arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_), + result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_)) ctx.set_tokens_out(tokens_out) - return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_names, + return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs, out_avals_, ctx.avals_out, out_nodes_) mlir.register_lowering(shard_map_p, _shard_map_lowering) -def _make_scoped_manual_sharding(ctx, mesh, axes): +def _make_scoped_manual_sharding(ctx, mesh, spec): axis_ctx = ctx.module_context.axis_context mesh = mesh.abstract_mesh if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): mesh = mesh.update_axis_types( {a: AxisType.Manual for a in axis_ctx.manual_axes}) - return NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore + return NamedSharding(mesh, spec) -def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, +def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in, aval_out, x): if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) + ns = _make_scoped_manual_sharding(ctx, mesh, spec) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) @@ -920,12 +908,11 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) -def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, +def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in, aval_out, x): if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) + ns = _make_scoped_manual_sharding(ctx, mesh, spec) if dtypes.issubdtype(aval_out.dtype, dtypes.extended): ns = sharding_impls.physical_sharding(aval_out, ns) aval_out = core.physical_aval(aval_out) @@ -941,8 +928,9 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, unspecified) -def 
_pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: +def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str: if isinstance(aval, core.ShapedArray): + names = _spec_to_names(spec) return str(map(names.get, range(aval.ndim))) return '' @@ -969,15 +957,13 @@ def _vma_to_spec(mesh, vma): return P(order_wrt_mesh(mesh, vma)) def _spec_to_vma(spec): - return _names_to_vma(_canonicalize_spec(spec)) - -def _names_to_vma(names): - return {n for ns in names.values() for n in ns} + return frozenset(p for s in spec if s is not None + for p in (s if isinstance(s, tuple) else (s,))) def order_wrt_mesh(mesh, x): return tuple(a for a in mesh.axis_names if a in x) -def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, +def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs, out_specs_thunk, check_vma, manual_axes): if len(manual_axes) < len(mesh.axis_names): raise NotImplementedError @@ -988,18 +974,18 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, mesh = get_mesh_from_args(args, mesh) cur_mesh = get_abstract_mesh() args = map(partial(_unmatch_spec, mesh, check_vma, context_mesh=cur_mesh), - in_names, args) - in_vma = map(_names_to_vma, in_names) + in_specs, args) + in_vma = map(_spec_to_vma, in_specs) outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma, cur_mesh) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] - _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types + _check_names(out_specs_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_vma: - _check_vmas(mesh, out_names_thunk(), out_vma) + _check_vmas(mesh, out_specs_thunk(), out_vma) src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) else: src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) - dst_pspecs = map(_names_to_pspec, out_names_thunk()) + dst_pspecs = out_specs_thunk() return map(partial(_match_spec, mesh, check_vma), src_pspecs, dst_pspecs, outs) core.EvalTrace.process_shard_map = _shard_map_impl @@ -1014,42 +1000,35 @@ def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh): outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) return outs, out_vma -def _names_to_pspec(names: AxisNames) -> PartitionSpec: - ndmin = max(names) + 1 if names else 0 - unpack = lambda t: t[0] if t is not None and len(t) == 1 else t - return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) -def _unmatch_spec(mesh: Mesh, check_vma, src: AxisNames, x: JaxType, - context_mesh) -> JaxType: +def _unmatch_spec(mesh: Mesh, check_vma, in_spec, x: JaxType, context_mesh + ) -> JaxType: with (core.eval_context(), jax.disable_jit(False), use_abstract_mesh(context_mesh)): - return jax.jit(HashablePartial(_unmatch, mesh, check_vma, - tuple(src.items())))(x) + return jax.jit(HashablePartial(_unmatch, mesh, check_vma, in_spec))(x) -def _unmatch(mesh, check_vma, src_tup, x): - src = _names_to_pspec(dict(src_tup)) +def _unmatch(mesh, check_vma, in_spec, x): if check_vma: - used_axes = {i for _, ns in src_tup for i in ns} + used_axes = _spec_to_vma(in_spec) dst = P(order_wrt_mesh(mesh, used_axes)) else: dst = P(mesh.axis_names) check_vma = False - return shard_map(_add_singleton, mesh=mesh, in_specs=(src,), out_specs=dst, - check_vma=check_vma)(x) + return shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,), + out_specs=dst, check_vma=check_vma)(x) -def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] - ) -> None: - fail = [a if n 
and not max(n) < a.ndim else no_fail - for n, a in zip(names, avals)] +def _check_names(specs, avals: Sequence[core.ShapedArray]) -> None: + fail = [a if sp and len(sp) > a.ndim else no_fail + for sp, a in zip(specs, avals)] if any(f is not no_fail for f in fail): raise _SpecError(fail) class _SpecError(Exception): pass -def _check_vmas(mesh, names, vmas): - fail = [vma if not _valid_repeats(mesh, vma, n) else no_fail - for n, vma in zip(names, vmas)] +def _check_vmas(mesh, specs, vmas): + fail = [vma if not _valid_repeats(mesh, vma, sp) else no_fail + for sp, vma in zip(specs, vmas)] if any(f is not no_fail for f in fail): raise _RepError(fail) @@ -1099,7 +1078,7 @@ def to_val_vma_pair(self, val): elif isinstance(val, Tracer): raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: - val_ = _unmatch_spec(self.mesh, self.check, {}, val, self.context_mesh) + val_ = _unmatch_spec(self.mesh, self.check, P(), val, self.context_mesh) return val_, frozenset() def process_primitive(self, prim, tracers, params): @@ -1205,6 +1184,7 @@ def __str__(self) -> str: return '\n'.join( f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) + __repr__ = __str__ # for debuggers, like `p x` def _prim_applier(prim, check_vma, params_tup, mesh, in_specs, out_specs, *args): @@ -1251,34 +1231,33 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): def _shard_map_batch( trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_vma: bool, - manual_axes: frozenset) -> Sequence[batching.BatchTracer]: + in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset + ) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): raise NotImplementedError - new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(in_names, in_dims)] spmd_axis_name = trace.axis_data.spmd_name if spmd_axis_name is not None: - used = {n for names in in_names for ns in names.values() for n in ns} + used = {n for spec in in_specs for n in _spec_to_vma(spec)} if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") - new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(new_in_names, in_dims)] + new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_axis_name) + for sp, d in zip(in_specs, in_dims)] new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) new_axis_data = batching.AxisData(trace.axis_data.name, new_size, trace.axis_data.spmd_name, None) else: + new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(in_specs, in_dims)] new_axis_data = trace.axis_data fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) - new_params = dict(mesh=mesh, in_names=new_in_names, - out_names_thunk=new_out_names_thunk, check_vma=check_vma, + @as_hashable_function(closure=out_specs_thunk) + def 
new_out_specs_thunk(): + return _batch_out_specs(spmd_axis_name, out_dims(), out_specs_thunk()) + + new_params = dict(mesh=mesh, in_specs=new_in_specs, + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, manual_axes=manual_axes) with core.set_current_trace(trace.parent_trace): out_vals = prim.bind(fun, *in_vals, **new_params) @@ -1287,36 +1266,36 @@ def new_out_names_thunk(): return map(make_tracer, out_vals, out_dims()) batching.BatchTrace.process_shard_map = _shard_map_batch -def _batch_out_names(spmd_axis_name, dims, out_names): - out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(out_names, dims)] - if spmd_axis_name is not None: - used = {n for names in out_names for ns in names.values() for n in ns} - if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: +def _batch_out_specs(spmd_name, dims, out_specs): + if spmd_name is not None: + used = {n for spec in out_specs for n in _spec_to_vma(spec)} + if not config.disable_vmap_shmap_error.value and set(spmd_name) & used: raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") - out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(out_names_, dims)] - return out_names_ + return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_name) + for sp, d in zip(out_specs, dims)] + else: + return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(out_specs, dims)] # Autodiff -def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_vma, manual_axes): +def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_specs, + out_specs_thunk, check_vma, manual_axes): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] args, in_tree = tree_flatten((primals, tangents)) f_jvp = ad.jvp_subtrace(f, trace.tag) f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) - tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] + tangent_in_specs = [sp for sp, nz in zip(in_specs, which_nz) if nz] - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - out_ax = out_names_thunk() + @as_hashable_function(closure=out_specs_thunk) + def new_out_specs_thunk(): + out_ax = out_specs_thunk() return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) - params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), - out_names_thunk=new_out_names_thunk, check_vma=check_vma, + params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, manual_axes=manual_axes) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) @@ -1327,32 +1306,32 @@ def new_out_names_thunk(): ad.JVPTrace.process_shard_map = _shard_map_jvp def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, - f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_vma, manual_axes): + f: lu.WrappedFun, tracers, mesh, in_specs, + out_specs_thunk, check_vma, manual_axes): tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) - unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) + unk_in_specs, known_in_specs = 
pe.partition_list(in_knowns, in_specs) in_avals_sharded = map(partial(_shard_aval, mesh, manual_axes, check_vma), - unk_in_names, in_avals) + unk_in_specs, in_avals) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) f = _promote_scalar_residuals(f) f_known, aux = pe.partial_eval_wrapper_nounits2( f, (*in_knowns,), (*in_avals_sharded,)) all_names = _all_newly_manual_mesh_names(mesh, manual_axes) - @as_hashable_function(closure=out_names_thunk) - def known_out_names(): + @as_hashable_function(closure=out_specs_thunk) + def known_out_specs(): _, _, out_knowns, res_avals, _, _ = aux() - _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) + _, out_known_specs = pe.partition_list(out_knowns, out_specs_thunk()) if check_vma: - res_names = [{0: order_wrt_mesh(mesh, a.vma)} for a in res_avals] + res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] else: - res_names = [{0: all_names}] * len(res_avals) - return (*out_known_names, *res_names) + res_specs = [P(all_names)] * len(res_avals) + return (*out_known_specs, *res_specs) - known_params = dict(mesh=mesh, in_names=(*known_in_names,), - out_names_thunk=known_out_names, check_vma=check_vma, + known_params = dict(mesh=mesh, in_specs=(*known_in_specs,), + out_specs_thunk=known_out_specs, check_vma=check_vma, manual_axes=manual_axes) out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) @@ -1360,32 +1339,32 @@ def known_out_names(): num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) assert not jaxpr.constvars - unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk()) - known_out_names_ = known_out_names() + unk_out_specs, _ = pe.partition_list(out_knowns, out_specs_thunk()) + known_out_specs_ = known_out_specs() res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) # TODO make res_avals be the full set, not just the non-fwd ones res_avals_iter = iter(res_avals) - res_names = [] + res_specs = [] for f1, f2 in zip(in_fwd, out_fwd): if f1 is not None: - res_names.append(known_in_names[f1]) + res_specs.append(known_in_specs[f1]) elif f2 is not None: - res_names.append(known_out_names_[f2]) + res_specs.append(known_out_specs_[f2]) else: if check_vma: res_vma = next(res_avals_iter).vma - res_names.append({0: order_wrt_mesh(mesh, res_vma)}) + res_specs.append(P(order_wrt_mesh(mesh, res_vma))) else: - res_names.append({0: all_names}) - unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment] + res_specs.append(P(all_names)) + unk_in_specs = (*res_specs,) + (P(),) * len(env) + (*unk_in_specs,) # type: ignore[assignment] const_tracers = map(trace.new_instantiated_const, res) env_tracers = map(trace.to_jaxpr_tracer, env) unk_arg_tracers = [t for t in tracers if not t.is_known()] out_avals_sharded = [v.aval for v in jaxpr.outvars] - unk_params = dict(mesh=mesh, in_names=unk_in_names, - out_names=unk_out_names, jaxpr=jaxpr, + unk_params = dict(mesh=mesh, in_specs=unk_in_specs, + out_specs=unk_out_specs, jaxpr=jaxpr, check_vma=check_vma, manual_axes=manual_axes) - out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_names, + out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_specs, out_avals_sharded) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in out_avals] @@ -1398,8 +1377,8 @@ def known_out_names(): pe.JaxprTrace.process_shard_map = _shard_map_partial_eval def 
_shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, - tracers, mesh, in_names, - out_names_thunk, check_vma, manual_axes): + tracers, mesh, in_specs, out_specs_thunk, check_vma, + manual_axes): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) @@ -1407,19 +1386,19 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, all_names = _all_newly_manual_mesh_names(mesh, manual_axes) @as_hashable_function(closure=linearize_outs_thunk) - def fwd_out_names_thunk(): + def fwd_out_specs_thunk(): res_avals, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) if f1 is None and f2 is None] - out_names = out_names_thunk() + out_specs = out_specs_thunk() if check_vma: - res_names = [{0: order_wrt_mesh(mesh, a.vma)} for a in res_avals] + res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] else: - res_names = [{0: all_names}] * len(res_avals) - return (*res_names, *out_names) + res_specs = [P(all_names)] * len(res_avals) + return (*res_specs, *out_specs) fwd_params = dict( - mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_vma=check_vma, + mesh=mesh, in_specs=in_specs, + out_specs_thunk=fwd_out_specs_thunk, check_vma=check_vma, manual_axes=manual_axes) all_fwd_results = shard_map_p.bind_with_trace( trace.parent_trace, (f_primal, *primals), fwd_params) @@ -1434,30 +1413,31 @@ def fwd_out_names_thunk(): use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))), config._check_vma(check_vma)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) - out_names = out_names_thunk() + out_specs = out_specs_thunk() res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) if f1 is None and f2 is None] res_avals_iter = iter(res_avals2) - res_names = [] + res_specs = [] for f1, f2 in zip(in_fwd, out_fwd): if f1 is not None: - res_names.append(in_names[f1]) + res_specs.append(in_specs[f1]) elif f2 is not None: - res_names.append(out_names[f2]) + res_specs.append(out_specs[f2]) else: if check_vma: res_vma = next(res_avals_iter).vma - res_names.append({0: order_wrt_mesh(mesh, res_vma)}) + res_specs.append(P(order_wrt_mesh(mesh, res_vma))) else: - res_names.append({0: all_names}) - new_in_names = (*res_names, *({} for _ in range(len(env))), - *(ax for ax, nz in zip(in_names, nzs_in) if nz)) - tangent_out_names = tuple(ax for ax, nz in zip(out_names_thunk(), nzs_out) if nz) - @as_hashable_function(closure=tangent_out_names) - def tangent_out_names_thunk(): - return tangent_out_names + res_specs.append(P(all_names)) + new_in_specs = (*res_specs, *(P(),) * len(env), + *(ax for ax, nz in zip(in_specs, nzs_in) if nz)) + tangent_out_specs = tuple(ax for ax, nz in zip(out_specs_thunk(), nzs_out) + if nz) + @as_hashable_function(closure=tangent_out_specs) + def tangent_out_specs_thunk(): + return tangent_out_specs tangent_params = dict( - mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, + mesh=mesh, in_specs=new_in_specs, out_specs_thunk=tangent_out_specs_thunk, check_vma=check_vma, manual_axes=manual_axes) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here @@ -1509,29 +1489,29 @@ def fun(*res_and_args): return jaxpr -def _unmentioned2(mesh: Mesh, names: AxisNames, - manual_axes: frozenset[AxisName]) -> list[AxisName]: +def _unmentioned2(mesh: Mesh, spec, 
manual_axes: frozenset[AxisName] + ) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-vma case. - name_set = {n for ns in names.values() for n in ns} + name_set = _spec_to_vma(spec) return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes) if n not in name_set] def _shard_map_transpose(out_cts, *args, - jaxpr: core.Jaxpr, mesh, in_names, out_names, + jaxpr: core.Jaxpr, mesh, in_specs, out_specs, check_vma, manual_axes): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ - ad.Zero(_shard_aval(mesh, manual_axes, check_vma, ns, x.aval)) + ad.Zero(_shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, manual_axes)))) - for ns, x in zip(out_names, out_cts) + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes)))) + for sp, x in zip(out_specs, out_cts) ] args = tuple(x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal( - _shard_aval(mesh, manual_axes, check_vma, ns, x.aval)) - for ns, x in zip(in_names, args)) + _shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) + for sp, x in zip(in_specs, args)) all_args, in_tree = tree_flatten((out_cts, args)) def fun_trans_callable(out_cts, args): @@ -1544,11 +1524,11 @@ def fun_trans_callable(out_cts, args): in_cts = ad.backward_pass( jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts )[len(res_reshaped):] - _, in_ct_names = partition_list(in_undef, in_names) - in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, ns, x.aval)) + _, in_ct_specs = partition_list(in_undef, in_specs) + in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, sp, x.aval)) if type(x) is ad.Zero else x if check_vma - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, manual_axes))) - for ns, x in zip(in_ct_names, in_cts)] + else jax.lax.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes))) + for sp, x in zip(in_ct_specs, in_cts)] res_zeros = [ad_util.zero_from_primal(r) for r in res] return merge_lists(in_undef, res_zeros, in_cts) @@ -1556,17 +1536,17 @@ def fun_trans_callable(out_cts, args): fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) - new_in_names = \ - [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \ - [n for n, x in zip(in_names, args) if type(x) is not ad.UndefinedPrimal] + new_in_specs = ( + [n for n, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] + + [n for n, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal]) - def new_out_names_thunk(): - return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) + def new_out_specs_thunk(): + return tuple(sp for sp, nz in zip(in_specs, nz_arg_cts()) if nz) try: out_flat = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_vma=check_vma, + fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, manual_axes=manual_axes) except (FloatingPointError, ZeroDivisionError) as e: print("Invalid nan value encountered in the backward pass of a shard_map " @@ -1576,8 +1556,8 @@ def new_out_names_thunk(): # in eager mode so that output of shmap are not manual. 
with jax.disable_jit(True): _ = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_vma=check_vma, + fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, manual_axes=manual_axes) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None @@ -1618,22 +1598,22 @@ def _partial_eval_jaxpr_custom_rule( _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) newvar = core.gensym() - residuals, staged_in_res_names = [], [] + residuals, staged_in_res_specs = [], [] for var, w in zip(jaxpr_staged.invars[:num_res], which): if w: - rn = ({0: order_wrt_mesh(mesh, var.aval.vma)} # type: ignore - if check_vma else {0: _all_newly_manual_mesh_names(mesh, manual_axes)}) + rn = (P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore + if check_vma else P(_all_newly_manual_mesh_names(mesh, manual_axes))) residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval))) - staged_in_res_names.append(rn) + staged_in_res_specs.append(rn) if check_vma: - out_res_names_known = [{0: order_wrt_mesh(mesh, var.aval.vma)} # type: ignore + out_res_specs_known = [P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore for var, o in zip(res_vars, out_fwd) if o is None] else: - out_res_names_known = [ - {0: _all_newly_manual_mesh_names(mesh, manual_axes)}] * sum(which) + out_res_specs_known = [ + P(_all_newly_manual_mesh_names(mesh, manual_axes))] * sum(which) params_known, params_staged = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, - out_res_names_known, staged_in_res_names, + out_res_specs_known, staged_in_res_specs, dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1681,27 +1661,27 @@ def staged(*args): return jaxpr_known, jaxpr_staged def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, - in_fwd, out_fwd, out_res_names_known, staged_in_res_names, + in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs, params_known, params_staged): # prune inputs to jaxpr_known according to unks_in - in_names_known, _ = partition_list(unks_in, params_known['in_names']) - _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + out_res_names_known - assert len(out_names_known) == len(params_known['jaxpr'].outvars) - new_params_known = dict(params_known, in_names=tuple(in_names_known), - out_names=tuple(out_names_known)) + in_specs_known, _ = partition_list(unks_in, params_known['in_specs']) + _, out_specs_known = partition_list(kept_outs_known, params_known['out_specs']) + out_specs_known = out_specs_known + out_res_specs_known + assert len(out_specs_known) == len(params_known['jaxpr'].outvars) + new_params_known = dict(params_known, in_specs=tuple(in_specs_known), + out_specs=tuple(out_specs_known)) # added num_res new inputs to jaxpr_staged, pruning according to inst_in - _, in_names_staged = partition_list(inst_in, params_staged['in_names']) - iter_staged = iter(staged_in_res_names) - res_names = [in_names_known[f1] if f1 is not None else - out_names_known[f2] if f2 is not None else + _, in_specs_staged = partition_list(inst_in, params_staged['in_specs']) + iter_staged = iter(staged_in_res_specs) + res_specs = [in_specs_known[f1] 
if f1 is not None else + out_specs_known[f2] if f2 is not None else next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)] - in_names_staged = res_names + in_names_staged - _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) - new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), - out_names=tuple(out_names_staged)) + in_specs_staged = res_specs + in_specs_staged + _, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs']) + new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged), + out_specs=tuple(out_specs_staged)) return new_params_known, new_params_staged # TODO(mattjj): remove this mechanism when we revise mesh scopes @@ -1742,10 +1722,10 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: return used_inputs, None else: - _, in_names = partition_list(used_inputs, eqn.params['in_names']) - _, out_names = partition_list(used_outputs, eqn.params['out_names']) - new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names), - out_names=tuple(out_names)) + _, in_specs = partition_list(used_inputs, eqn.params['in_specs']) + _, out_specs = partition_list(used_outputs, eqn.params['out_specs']) + new_params = dict(eqn.params, jaxpr=jaxpr, in_specs=tuple(in_specs), + out_specs=tuple(out_specs)) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 1bebba095896..9b4ca76c3bc5 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -736,10 +736,10 @@ def f(x): x = jnp.arange(4 * 4).reshape(4, 4) jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name='y'))(x).jaxpr e, = jaxpr.eqns - self.assertIn('in_names', e.params) - self.assertEqual(e.params['in_names'], ({0: ('y',), 1: ('x',)},)) - self.assertIn('out_names', e.params) - self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) + self.assertIn('in_specs', e.params) + self.assertEqual(e.params['in_specs'], (P('y', 'x'),)) + self.assertIn('out_specs', e.params) + self.assertEqual(e.params['out_specs'], (P('y', 'x'),)) def test_vmap_of_grad_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -771,10 +771,10 @@ def f(x): x = jnp.arange(4 * 4).reshape(4, 4) jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name=('x', 'y')))(x).jaxpr e, = jaxpr.eqns - self.assertIn('in_names', e.params) - self.assertEqual(e.params['in_names'], ({0: ('x', 'y',)},)) - self.assertIn('out_names', e.params) - self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + self.assertIn('in_specs', e.params) + self.assertEqual(e.params['in_specs'][0], P(('x', 'y'))) + self.assertIn('out_specs', e.params) + self.assertEqual(e.params['out_specs'][0], P(('x', 'y'))) def test_nested_vmap_with_capture_spmd_axis_name(self): self.skipTest('https://github.com/jax-ml/jax/issues/23476') From 048db94ed914fd656818f21efd70d7356326a893 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Tue, 20 May 2025 13:19:37 -0700 Subject: [PATCH 1260/1769] [JAX] Enable tfrt gpu in jax multi platform test PiperOrigin-RevId: 761204561 --- tests/BUILD | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index 1b33f292d503..aa777080fd92 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -121,6 +121,10 @@ jax_py_test( jax_multiplatform_test( name = "array_interoperability_test", srcs = 
["array_interoperability_test.py"], + disable_configs = [ + "gpu_h100_tfrt", # TODO(b/411472145): Re-enable once fixed. + "gpu_h100x2_tfrt", + ], enable_backends = [ "cpu", "gpu", @@ -128,7 +132,9 @@ jax_multiplatform_test( enable_configs = [ "gpu_h100x2", ], - tags = ["multiaccelerator"], + tags = [ + "multiaccelerator", + ], deps = py_deps([ "absl/testing", "numpy", @@ -1134,6 +1140,10 @@ jax_multiplatform_test( jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], + disable_configs = [ + "gpu_h100_tfrt", # TODO(b/411472145): Re-enable once fixed. + "gpu_h100x2_tfrt", + ], enable_backends = [ "cpu", "gpu", From cabda1054ed01f7378b557bd1f482f34f5cf49e0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 20 May 2025 13:40:09 -0700 Subject: [PATCH 1261/1769] [chex] remove stale try/except import chex now requires jax>=0.4.27, so the previous backward-compatibility is no longer necessary. PiperOrigin-RevId: 761212887 --- jax/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/BUILD b/jax/BUILD index 61b5a99dfe31..18e670e0269c 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -548,6 +548,7 @@ pytype_strict_library( pytype_strict_library( name = "compiler", srcs = ["_src/compiler.py"], + visibility = [":internal"] + jax_visibility("compiler"), deps = [ ":cache_key", ":compilation_cache_internal", From 11cf85deb0776f800144e829d4ba8e38eb9d76fc Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 20 May 2025 13:52:13 -0700 Subject: [PATCH 1262/1769] [ragged-paged-attn] Apply kv mask to filter out NaNs Expected small regression as we insert kv masking logic which needs to unpack/pack. PiperOrigin-RevId: 761217303 --- .../ops/tpu/ragged_paged_attention/kernel.py | 55 ++++++++++------- .../pallas/tpu_ragged_paged_attention_test.py | 61 +++++++++++++------ 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index cd5de96ccca7..df47674a59a9 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -39,19 +39,16 @@ def __init__( vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sem, page_indices_ref, # i32[max_num_seqs, pages_per_seq] - offset, # [seq_idx, kv_pages_start] + metadata, # [seq_idx, start_page_idx, end_page_idx] ): self._vmem_buf = vmem_buf - seq_id, kv_pages_start = offset - pages_per_seq = page_indices_ref.shape[1] + seq_id, start_page_idx, end_page_idx = metadata self._async_copies = [] # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert # a bunch of if-ops. Check the performance when we have benchmarking setup. 
for i in range(vmem_buf.shape[0]): - page_idx = kv_pages_start + i - page_idx = jax.lax.select( - page_idx < pages_per_seq, page_idx, pages_per_seq - 1 - ) + page_idx = start_page_idx + i + page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0) self._async_copies.append( pltpu.make_async_copy( pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], @@ -298,6 +295,7 @@ def ragged_paged_attention_kernel( if mask_value is None: mask_value = DEFAULT_MASK_VALUE num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape + pages_per_seq = page_indices_ref.shape[-1] num_seqs = num_seqs_ref[0] _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( kv_bufs.shape @@ -318,7 +316,11 @@ def ragged_paged_attention_kernel( def create_kv_async_copy_descriptors( heads_blk_idx, seq_idx, kv_blk_idx, buf_idx ): - offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) + start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk + end_kv_page_idx = jnp.minimum( + pages_per_seq, cdiv(kv_lens_ref[seq_idx], page_size) + ) + metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx) heads_start = heads_blk_idx * num_combined_kv_heads_per_blk async_copy_kv = MultiPageAsyncCopyDescriptor( kv_pages_hbm_ref.at[ @@ -327,7 +329,7 @@ def create_kv_async_copy_descriptors( kv_bufs.at[buf_idx], sems.at[buf_idx], page_indices_ref, - offset, + metadata, ) return async_copy_kv @@ -423,18 +425,22 @@ def flash_attention( num_q_per_blk * num_q_heads_per_kv_head, head_dim, ) - assert k.shape == ( - num_kv_per_blk, - head_dim, - ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" - assert v.shape == (num_kv_per_blk, head_dim) - assert head_m_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, + assert ( + k.shape + == v.shape + == ( + num_kv_per_blk, + head_dim, + ) ) - assert head_l_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, + assert k.dtype == v.dtype + assert ( + head_m_ref.shape + == head_l_ref.shape + == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) ) assert head_acc_ref.shape == ( num_q_per_blk, @@ -448,6 +454,13 @@ def masked_store(ref, val, start, end, group=1): mask = jnp.logical_and(iota >= start, iota < end) pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) + # kv lens will be contracting dim, we should mask out the NaNs. + kv_mask = ( + lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start + ) + k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype) + v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype) + qk = ( jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) * sm_scale @@ -709,7 +722,7 @@ def ragged_paged_attention( Args: q: concatenated all sequences' queries. - kv_pages: paged K cache. Normally in HBM. + kv_pages: paged KV cache. Normally in HBM. kv_lens: padded kv lengths. Only the first num_seqs values are valid. page_indices: the first index indicates which page to use in the kv cache for each sequence. Only the first num_seqs values are valid. 
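
[Illustrative sketch, not from the patch itself: it shows why the kv mask
introduced above is needed. Once pages are padded with NaNs, the kv dimension
is contracted in the attention matmul, and 0 * NaN is still NaN, so zero
attention weights alone cannot neutralize the padding. Shapes and names below
are toy values, not the kernel's.]

    import jax.numpy as jnp
    from jax import lax

    kv_len, padded_len, head_dim = 5, 8, 4
    # A V page whose tail rows are NaN padding, as in the test below.
    v = jnp.full((padded_len, head_dim), jnp.nan).at[:kv_len].set(1.0)
    # Attention weights that are already zero on the padded rows.
    w = jnp.ones((padded_len,)).at[kv_len:].set(0.0)

    w @ v                             # all NaN: the contraction is poisoned
    mask = lax.broadcasted_iota(jnp.int32, v.shape, 0) < kv_len
    w @ jnp.where(mask, v, 0)         # finite: padding zeroed before matmul
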
diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index f86d54575519..4265445c69c7 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -19,6 +19,7 @@ import jax from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( + cdiv, dynamic_validate_inputs, ragged_paged_attention, ref_ragged_paged_attention, @@ -29,13 +30,8 @@ jax.config.parse_flags_with_absl() -def ceil_div(x, a): - assert a != 0 - return (x + a - 1) // a - - @jtu.with_config(jax_numpy_dtype_promotion="standard") -class PagedAttentionKernelTest(jtu.JaxTestCase): +class RaggedPagedAttentionKernelTest(jtu.JaxTestCase): def _test_ragged_paged_attention( self, @@ -66,29 +62,56 @@ def _test_ragged_paged_attention( max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens) max_num_seq = max(len(seq_lens), max_num_seq) max_kv_len = max(kv_lens) - pages_per_seq = ceil_div(max_kv_len, page_size) + pages_per_seq = cdiv(max_kv_len, page_size) num_q_heads, num_kv_heads = num_heads - cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) - kv_lens = jnp.array(kv_lens, dtype=jnp.int32) - cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) - kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) prng_key = jax.random.key(1234) - k0, k1, k2 = jax.random.split(prng_key, 3) + k0, k1 = jax.random.split(prng_key, 2) q = jax.random.normal( k0, (max_num_batched_tokens, num_q_heads, head_dim), dtype=dtype, ) - kv_pages = jax.random.normal( - k1, - (num_pages, page_size, num_kv_heads * 2, head_dim), - dtype=dtype, + page_cnt = 0 + page_indices_list = [] + kv_pages_list = [] + for kv_len in kv_lens: + kv = jax.random.normal( + k1, + (kv_len, num_kv_heads * 2, head_dim), + dtype=dtype, + ) + kv = jnp.pad( + kv, + ((0, cdiv(kv_len, page_size) * page_size - kv_len), (0, 0), (0, 0)), + constant_values=jnp.nan, + ).reshape(-1, page_size, num_kv_heads * 2, head_dim) + indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32) + indices = jnp.pad( + indices, + ((0, pages_per_seq - indices.shape[0]),), + constant_values=jnp.nan, + ) + page_indices_list.append(indices) + page_cnt += kv.shape[0] + kv_pages_list.append(kv) + + kv_pages = jnp.concatenate(kv_pages_list, axis=0) + kv_pages = jnp.pad( + kv_pages, + ((0, num_pages - kv_pages.shape[0]), (0, 0), (0, 0), (0, 0)), + constant_values=jnp.nan, ) - page_indices = jax.random.randint( - k2, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + page_indices = jnp.stack(page_indices_list, axis=0) + page_indices = jnp.pad( + page_indices, + ((0, max_num_seq - page_indices.shape[0]), (0, 0)), + constant_values=jnp.nan, ) - + cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) + cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) + kv_lens = jnp.array(kv_lens, dtype=jnp.int32) + kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) dynamic_validate_inputs( From 4ff6eb25f21e7a8296bcb6f8adbc5ffec4e8ae6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Tue, 20 May 2025 14:24:26 -0700 Subject: [PATCH 1263/1769] [Mosaic:TPU] Enforce that tpu.dynamic_gather operands and result have the same shape PiperOrigin-RevId: 761230458 --- jax/_src/pallas/mosaic/lowering.py | 4 ++-- jaxlib/mosaic/dialect/tpu/tpu.td | 13 +++++++++++-- .../tpu/transforms/apply_vector_layout.cc | 16 ++++++++++------ 3 files 
changed, 23 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index a9fbf8dcd982..eb5e6df7b381 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2369,7 +2369,7 @@ def _gather_lowering_rule( operand_batching_dims=(1,), start_indices_batching_dims=(1,), ): - return tpu.dynamic_gather(out_type, x, recovered_indices, 0) + return tpu.dynamic_gather(x, recovered_indices, 0) if dimension_numbers == lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(1,), @@ -2377,7 +2377,7 @@ def _gather_lowering_rule( operand_batching_dims=(0,), start_indices_batching_dims=(0,), ): - return tpu.dynamic_gather(out_type, x, recovered_indices, 1) + return tpu.dynamic_gather(x, recovered_indices, 1) raise NotImplementedError("Unsupported gather") diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 29ce9c84de07..b6ae1e52e822 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -466,10 +466,19 @@ def TPU_GatherOp : TPU_Op<"gather", [Pure]> { }]; } -def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> { +def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, SameOperandsAndResultShape, AllTypesMatch<["source", "output"]>]> { + let description = [{ + Gathers elements from `source` using `indices`. + + Given a shape `N0 x N1 x ...`, `output[i0, i1, ...]` is given by + `input[j0, j1, ...]` where `jn = indices[i0, i1, ...] mod Ni` for + `n = dimension` and `jn = in` otherwise. + + Similar to `np.take_along_axis`, except that OOB indices wrap. + }]; let arguments = (ins AnyVectorOfNonZeroRank:$source, - AnyVectorOfNonZeroRank:$indices, + VectorOfNonZeroRankOf<[AnyInteger]>:$indices, I32Attr:$dimension ); let results = (outs AnyVectorOfNonZeroRank:$output); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 5ddff9d9ee53..fa14c8ef9238 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3537,12 +3537,13 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, std::array{false, true}) { // Lane broadcast TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1); TPU_ASSERT_OP(offsets_in[1].has_value()); + VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); const int64_t offset = *offsets_in[1]; const int64_t lane_offset = offset % ctx.target_shape[1]; const int64_t tile_offset = offset / ctx.target_shape[1]; Value lane_offset_cst = getFullVector( - builder, getNativeVregType(builder.getI32Type(), ctx.target_shape), - builder.getI32IntegerAttr(lane_offset)); + builder, i32_vreg_ty, builder.getI32IntegerAttr(lane_offset)); DenseI32ArrayAttr sublane_pattern; if (num_tiles != 1) { SmallVector pattern; @@ -3555,7 +3556,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, sublane_pattern = builder.getDenseI32ArrayAttr(pattern); } src_tiles.Each([&](const absl::Span src_idx, - Value *const src_tile) { + Value *const src_vreg) { SmallVector dst_starts(dst_tiles_implicit_shape.size()); SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { @@ -3567,10 +3568,13 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, dst_limits[i] = dst_starts[i] + 1; } } - Value res_vreg = builder.create( - 
broadcast_op.getLoc(), src_tile->getType(), *src_tile, - lane_offset_cst, + Value src_vreg_i32 = + builder.create(i32_vreg_ty, *src_vreg); + Value res_vreg_i32 = builder.create( + broadcast_op.getLoc(), i32_vreg_ty, src_vreg_i32, lane_offset_cst, /*dimension=*/1); + Value res_vreg = builder.create( + src_vreg->getType(), res_vreg_i32); if (num_tiles != 1) { res_vreg = builder.create( broadcast_op.getLoc(), res_vreg.getType(), res_vreg, From 5d3134e9fa3d40f0a24ce08a6225d507d65b634c Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 May 2025 14:27:29 -0700 Subject: [PATCH 1264/1769] Make sure tests with `--build_jaxlib=false` depend on NVIDIA CUDA wheels hermetically. PiperOrigin-RevId: 761231648 --- jaxlib/jax.bzl | 4 ++-- jaxlib/tools/BUILD.bazel | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index a8dc67eb3804..eceb38e35aab 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -190,8 +190,8 @@ def _gpu_test_deps(): "@pypi//nvidia_nvshmem_cu12", ], "//jax:config_build_jaxlib_false": [ - "@pypi//jax_cuda12_plugin", - "@pypi//jax_cuda12_pjrt", + "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", + "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", "@pypi//nvidia_nvshmem_cu12", ], "//jax:config_build_jaxlib_wheel": [ diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 22bae26a4420..d6a5f94dfd4b 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -487,6 +487,19 @@ py_import( wheel_deps = if_cuda([":nvidia_wheel_deps"]), ) +# The targets below are used for GPU tests with `--//jax:build_jaxlib=false`. +py_import( + name = "pypi_jax_cuda_plugin_with_cuda_deps", + wheel = "@pypi_jax_cuda12_plugin//:whl", + wheel_deps = if_cuda([":nvidia_wheel_deps"]), +) + +py_import( + name = "pypi_jax_cuda_pjrt_with_cuda_deps", + wheel = "@pypi_jax_cuda12_pjrt//:whl", + wheel_deps = if_cuda([":nvidia_wheel_deps"]), +) + # Wheel tests. 
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) From 12e07c963286c9fa7eb7d8651ce968537be1ab8a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 20 May 2025 14:43:50 -0700 Subject: [PATCH 1265/1769] Reverts 06448864abd6e8187e5b4d9b1ff08ab14fe3b8e0 PiperOrigin-RevId: 761237485 --- jax/_src/compiler.py | 52 +++++++++------------------------------- jaxlib/_jax/__init__.pyi | 11 --------- jaxlib/xla_client.py | 1 - jaxlib/xla_client.pyi | 1 - 4 files changed, 11 insertions(+), 54 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index e8ef647a1312..04f993fed799 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -292,19 +292,6 @@ def backend_compile( executable_devices: xc.DeviceList, options: xc.CompileOptions, host_callbacks: Sequence[Any], -) -> xc.LoadedExecutable: - return backend_compile_and_load( - backend, module, executable_devices, options, host_callbacks - ) - - -@profiler.annotate_function -def backend_compile_and_load( - backend: xc.Client, - module: ir.Module, - executable_devices: xc.DeviceList, - options: xc.CompileOptions, - host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: sym_name = module.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value @@ -335,35 +322,18 @@ def backend_compile_and_load( # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results - elif jaxlib_extension_version < 342 or isinstance(backend, xc.CompileOnlyPyClient): - if host_callbacks: - return backend.compile( - built_c, - executable_devices=executable_devices, # type: ignore - compile_options=options, - host_callbacks=host_callbacks, - ) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` + if host_callbacks: return backend.compile( - built_c, executable_devices=executable_devices, compile_options=options) # type: ignore - else: - if host_callbacks: - return backend.compile_and_load( - built_c, - executable_devices=executable_devices, - compile_options=options, - host_callbacks=host_callbacks, - ) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` - return backend.compile_and_load( built_c, - executable_devices=executable_devices, + executable_devices=executable_devices, # type: ignore compile_options=options, + host_callbacks=host_callbacks, ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile( + built_c, executable_devices=executable_devices, compile_options=options) # type: ignore except xc.XlaRuntimeError as e: for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: handler_result = error_handler(e) @@ -428,7 +398,7 @@ def compile_or_get_cached( ) if cache_key is None: - return backend_compile_and_load( + return backend_compile( backend, computation, executable_devices, compile_options, host_callbacks) @@ -456,7 +426,7 @@ def compile_or_get_cached( config.share_binary_between_hosts.value and is_multi_process and distributed.global_state.client is not None - # Host callbacks are currently baked into the HLO module so we can't share + # Host callbacks are currently baked into the HLO module so we cant share # them. 
and len(host_callbacks) == 0 ): @@ -746,7 +716,7 @@ def _compile_and_write_cache( cache_key: str, ) -> xc.LoadedExecutable: start_time = time.monotonic() - executable = backend_compile_and_load( + executable = backend_compile( backend, computation, executable_devices, compile_options, host_callbacks ) compile_time = time.monotonic() - start_time diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 000c05acacad..1d7f3042e8a3 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -551,17 +551,6 @@ class Client: ) -> PjRtLayout: ... def __getattr__(self, name: str) -> Any: ... - -class CompileOnlyPyClient(Client): - def compile( - self, - computation: str | bytes, - executable_devices: DeviceList | Sequence[Device], - compile_options: CompileOptions = ..., - host_callbacks: Sequence[Any] = ..., - ) -> LoadedExecutable: ... - - class CpuCollectives: ... def make_gloo_tcp_collectives( diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index b1bbc464610e..8f8c829ee6c7 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -304,7 +304,6 @@ def computation_count(): XlaComputation = _xla.XlaComputation Client = _xla.Client -CompileOnlyPyClient = _xla.CompileOnlyPyClient Memory = _xla.Memory Array = _xla.Array ArrayImpl = _xla.ArrayImpl diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index fce114f45474..80599e86676b 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -24,7 +24,6 @@ from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics from jaxlib._jax import ArrayImpl as ArrayImpl from jaxlib._jax import AutotuneCacheMode as AutotuneCacheMode from jaxlib._jax import Client as Client -from jaxlib._jax import CompileOnlyPyClient as CompileOnlyPyClient from jaxlib._jax import CompileOptions as CompileOptions from jaxlib._jax import Device as Device from jaxlib._jax import DeviceAssignment as DeviceAssignment From bd436fb1f84a8162d83675e7d35cea1f01412be8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 May 2025 15:20:19 -0700 Subject: [PATCH 1266/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ea0d903efdff87b261e1d68c59011138ee13a9ac. PiperOrigin-RevId: 761250461 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d91873fdfad8..3c17ef07aacc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4a914a285377e90c464a0b7fad9b5cbcfeeb27a9" -XLA_SHA256 = "3df562d5b67db755d88c469e45ae27ba9e1387f5d79cf25fba17c8c8ea74cfe8" +XLA_COMMIT = "ea0d903efdff87b261e1d68c59011138ee13a9ac" +XLA_SHA256 = "619d17a8f03f2bdd36c801b7d55180f21f8fed63623da224ddbb099346384e84" def repo(): tf_http_archive( From b96a63904faf82b9404c8b551cff9c1c6bbb962f Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 May 2025 16:14:40 -0700 Subject: [PATCH 1267/1769] [JAX] Fix float typo in a code example in the sharded-computation doc. 
PiperOrigin-RevId: 761268857 --- docs/sharded-computation.ipynb | 4 ++-- docs/sharded-computation.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 1bae4014b5a8..72cc2d193bfd 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -7,7 +7,7 @@ "(sharded-computation)=\n", "# Introduction to parallel programming\n", "\n", - "\n", + "\n", "\n", "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n", "\n", @@ -495,7 +495,7 @@ "id": "c09acf7d", "metadata": {}, "source": [ - "We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n", + "We should read the type `int32[4@X, 2]` as \"a 4-by-2 array of 32-bit ints whose first dimension\n", "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", "axes\"\n", "\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 16a5dc8cfa08..89ffbc07da38 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -14,7 +14,7 @@ kernelspec: (sharded-computation)= # Introduction to parallel programming - + This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs. @@ -193,7 +193,7 @@ print(f"replicated_array type: {jax.typeof(replicated_array)}") print(f"sharded_array type: {jax.typeof(sharded_array)}") ``` -We should read the type `f32[4@X, 2]` as "a 4-by-2 array of 32-bit floats whose first dimension +We should read the type `int32[4@X, 2]` as "a 4-by-2 array of 32-bit ints whose first dimension is sharded along mesh axis 'X'. The array is replicated along all other mesh axes" From a66f5086dfbfec2f45ccb617f08128574194cf0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 20 May 2025 17:07:22 -0700 Subject: [PATCH 1268/1769] Do not construct an error string for FoldingError. This just wastes time since the error will never be read. 
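To illustrate with a minimal sketch (not code from this change): the exception
is pure control flow, so any message formatted at raise time is discarded
unread by the caller.

    class FoldingError(Exception):
      pass

    def _fold(x, fuel):
      if fuel <= 0:
        raise FoldingError()  # no error string is built just to be thrown away
      return x  # (the actual constant-folding logic is elided in this sketch)

    def _fold_and_get_constant_value(x):
      try:
        return _fold(x, 10)
      except FoldingError:
        return None  # the message, if there were one, is never read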
PiperOrigin-RevId: 761285843 --- jax/_src/pallas/mosaic/lowering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index eb5e6df7b381..919c548adc9b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2444,7 +2444,7 @@ class FoldingError(Exception): def _fold_and_get_constant_value(x): def _fold(x, fuel): if fuel <= 0: - raise FoldingError("Folding depth exceeded") + raise FoldingError() op_name = getattr(x.owner, "name", None) binop_folds = { "arith.maxsi": max, @@ -2459,7 +2459,7 @@ def _fold(x, fuel): raise ValueError(f"Unsupported constant type: {x.type}") if op_name in binop_folds: return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands) - raise FoldingError(f"Folding not supported for {x.owner}") + raise FoldingError() try: return _fold(x, 10) From 862791a91e2a8c87fe4c2c46ae9fd6a9e01adbec Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 20 May 2025 18:15:10 -0700 Subject: [PATCH 1269/1769] Reland the limit on the number of OpenBLAS threads. This was previously removed in https://github.com/jax-ml/jax/commit/18ff6caa4f767701dd7cca3a1333d9b99465e045, and that promptly broke our CI again. I am guessing the problem is actually too few threads, not a NumPy deadlock as I originally guessed. --- .bazelrc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.bazelrc b/.bazelrc index 53676637c839..8906234c9061 100644 --- a/.bazelrc +++ b/.bazelrc @@ -244,6 +244,9 @@ build:ci_linux_aarch64_base --config=clang --verbose_failures=true build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" build:ci_linux_aarch64_base --color=yes +# This appears to help avoid a timeout in CI for linalg_test. +build:ci_linux_aarch64_base --test_env=OMP_NUM_THREADS=8 + build:ci_linux_aarch64 --config=ci_linux_aarch64_base build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" @@ -379,6 +382,9 @@ build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/ build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64 build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base +# Avoids a timeout in linalg_test on ARM. +build:rbe_cross_compile_linux_aarch64 --test_env=OMP_NUM_THREADS=8 + # Mac x86 build:cross_compile_darwin_x86_64 --config=cross_compile_base build:cross_compile_darwin_x86_64 --config=nonccl From eef1f6cf9af4fabd47a70b71f78bf1339c9a36cd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 20 May 2025 19:35:59 -0700 Subject: [PATCH 1270/1769] Support passing PartitionSpecs to ShapeDtypeStruct when there is a mesh in context. 
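A rough usage sketch (mesh construction assumed; this one needs two devices).
With a concrete mesh installed via `jax.sharding.use_mesh`, a PartitionSpec is
accepted and converted to NamedSharding(mesh, spec); without a mesh in context
a TypeError is raised:

    import jax
    import numpy as np
    from jax.sharding import PartitionSpec as P

    mesh = jax.make_mesh((2,), ('x',))
    with jax.sharding.use_mesh(mesh):
      # Equivalent to sharding=jax.sharding.NamedSharding(mesh, P('x')):
      sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=P('x'))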
PiperOrigin-RevId: 761322712 --- jax/_src/api.py | 19 ++++++++++++++++--- tests/api_test.py | 7 ------- tests/pjit_test.py | 19 +++++++++++++++++++ 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 33802e494304..059db1c92c98 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2826,9 +2826,10 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False): if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) - if sharding is not None and not isinstance(sharding, (Sharding, Layout)): + if sharding is not None and not isinstance(sharding, (Sharding, Layout, P)): raise ValueError( - "sharding should be an instance of `jax.sharding.Sharding` or" + "sharding should be an instance of `jax.sharding.Sharding`, " + "`jax.sharding.PartitionSpec` or" f" `jax.experimental.layout.Layout`. Got {sharding} of type" f" {type(sharding)}.") if (isinstance(sharding, Layout) and @@ -2836,7 +2837,19 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False): raise TypeError( "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" f" layout in a `ShapeDtypeStruct`. Got {sharding}") - self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding + if isinstance(sharding, Layout): + self.sharding = sharding.sharding + elif isinstance(sharding, P): + # TODO(yashkatariya): Should this be abstract mesh? + cur_mesh = get_concrete_mesh() + if cur_mesh is None: + raise TypeError( + "When specifying PartitionSpec to `ShapeDtypeStruct`, the context" + " mesh cannot be empty. Please use `jax.sharding.use_mesh` to set" + " the mesh context.") + self.sharding = NamedSharding(cur_mesh, sharding) + else: + self.sharding = sharding self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None self.weak_type = weak_type diff --git a/tests/api_test.py b/tests/api_test.py index 3fe3d6fa7514..9963e2603588 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4525,13 +4525,6 @@ def foo(x): with self.assertRaisesRegex(TypeError, "applied to foo"): f_vjp(1.0, 1.0) - def test_shapedtypestruct_sharding_error(self): - with self.assertRaisesRegex( - ValueError, - "sharding should be an instance of `jax.sharding.Sharding`."): - jax.ShapeDtypeStruct((8, 2), np.float32, - sharding=jax.sharding.PartitionSpec('x')) - def test_make_jaxpr_weakref(self): class Foo(NamedTuple): x: int diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7d87705e20bd..88ec58ec37cd 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5006,6 +5006,25 @@ def test_sds_update(self): with self.assertRaisesRegex(ValueError, "updating ShapeDtypeStruct"): s4.update(sharding=NamedSharding(mesh, P('x'))) + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_sds_pspec_input(self, mesh): + inp = jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x')) + lowered = jax.jit(lambda x: x * 2).lower(inp) + self.assertIn('num_partitions = 2', lowered.as_text()) + + np_inp = np.arange(4, dtype=np.float32).reshape(2, 2) + arr = jax.device_put(np_inp, P('x')) + out = lowered.compile()(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_sds_pspec_no_mesh_ctx_error(self): + with self.assertRaisesRegex( + TypeError, + 'When specifying PartitionSpec to `ShapeDtypeStruct`, the context mesh' + ' cannot be empty'): + 
jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x')) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From d8804aae654bf902fb1cc7c92e22ad65b1c59e6b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 21 May 2025 02:07:33 -0700 Subject: [PATCH 1271/1769] Add missing test skips (for too old jaxlib) + bump minimum libtpu version PiperOrigin-RevId: 761429155 --- .github/workflows/cloud-tpu-ci-nightly.yml | 2 +- .github/workflows/pytest_tpu.yml | 2 +- tests/pallas/gpu_pallas_distributed_test.py | 2 ++ tests/pjit_test.py | 3 +++ 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 5cc2aebe3cd0..061b399132e2 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -44,7 +44,7 @@ jobs: jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: - LIBTPU_OLDEST_VERSION_DATE: 20241205 + LIBTPU_OLDEST_VERSION_DATE: 20250226 PYTHON: python${{ matrix.python-version }} runs-on: ${{ matrix.tpu.runner }} container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 55a0b4cc1a5f..8ecccbe274e9 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -82,7 +82,7 @@ jobs: # End Presubmit Naming Check github-tpu-presubmits env: - LIBTPU_OLDEST_VERSION_DATE: 20241205 + LIBTPU_OLDEST_VERSION_DATE: 20250226 JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}" diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index d862e6b9b819..81433b8c5067 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -34,6 +34,8 @@ class PallasCallRemoteDMATest(jt_multiprocess.MultiProcessTest): def setUp(self): + if jtu.jaxlib_version() < (0, 6, 1): + self.skipTest("Test requires jaxlib >= 0.6.1") if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 88ec58ec37cd..523601691a97 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -63,6 +63,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import _jax +from jax._src.lib import jaxlib_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -7761,6 +7762,8 @@ def f(x): @config.use_shardy_partitioner(True) @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_unreduced_basic(self, mesh): + if jaxlib_extension_version < 342: + self.skipTest("Test requires a newer jaxlib") np_inp = np.arange(16).reshape(8, 2) x = jax.device_put(np_inp, P('x', 'y')) y = jax.device_put(np_inp.T, P('y', None)) From 4d6e39e4dd38e60dcd4b999284c54030a4076cea Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 21 May 2025 02:52:37 -0700 Subject: [PATCH 1272/1769] Disable the newly added tfrt targets that never worked PiperOrigin-RevId: 761442340 --- tests/BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index aa777080fd92..2418c8224869 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -460,6 +460,9 @@ jax_multiplatform_test( backend_tags = 
{ "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + disable_configs = [ + "gpu_h100x2_tfrt", # TODO(b/419192167): Doesn't work + ], enable_backends = ["gpu"], tags = [ "config-cuda-only", @@ -1844,6 +1847,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], + disable_configs = [ + "gpu_h100x2_tfrt", # TODO(b/419192167): Doesn't work + ], enable_configs = [ "gpu_p100x2_shardy", "tpu_v3_x4_shardy", From 0f844335c706aceba87390468395bfbf168bb7c8 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 21 May 2025 03:11:01 -0700 Subject: [PATCH 1273/1769] Disable tests for the paged attention kernel I ran the TPU race checker on it and it did report a number of races that were uncovered by recent Mosaic compiler changes. PiperOrigin-RevId: 761447182 --- tests/pallas/tpu_paged_attention_kernel_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pallas/tpu_paged_attention_kernel_test.py b/tests/pallas/tpu_paged_attention_kernel_test.py index ac24fea1b45a..9886e7943f6f 100644 --- a/tests/pallas/tpu_paged_attention_kernel_test.py +++ b/tests/pallas/tpu_paged_attention_kernel_test.py @@ -265,6 +265,8 @@ def test_paged_attention( attn_logits_soft_cap, are_kv_quantized, ): + # TODO(mvoz, skyewm): Re-enable this test once the data race is fixed. + self.skipTest("This kernel has data races that need to be fixed.") if not jtu.is_device_tpu_at_least(4): self.skipTest("Only supports TPU generation 4 or above") if jtu.is_device_tpu(version=4) and are_kv_quantized: From 482b8129e6997cfbd66d068c0d06d7ccec634086 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 21 May 2025 05:07:41 -0700 Subject: [PATCH 1274/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/47a47140b5f7a33d86901316d9c569de179bae08. PiperOrigin-RevId: 761478240 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3c17ef07aacc..8062f071f777 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ea0d903efdff87b261e1d68c59011138ee13a9ac" -XLA_SHA256 = "619d17a8f03f2bdd36c801b7d55180f21f8fed63623da224ddbb099346384e84" +XLA_COMMIT = "47a47140b5f7a33d86901316d9c569de179bae08" +XLA_SHA256 = "c6c55e0e426b2f4a78f1b5459baab29a61d2435c089628e83db92ef8f890a02f" def repo(): tf_http_archive( From 2f0ca797f0a90801714cd5093f32983275db3296 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 21 May 2025 05:25:11 -0700 Subject: [PATCH 1275/1769] [pallas:mosaic_gpu] Removed `GPU*` aliases PiperOrigin-RevId: 761482778 --- jax/experimental/pallas/mosaic_gpu.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index cc0e185e296a..8c7870412403 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -69,9 +69,3 @@ SMEM = MemorySpace.SMEM #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.TMEM`. TMEM = MemorySpace.TMEM - -# TODO(slebedev): Deprecate and remove these aliases. 
-GPUBlockSpec = BlockSpec -GPUCompilerParams = CompilerParams -GPUMemorySpace = MemorySpace -GPUMesh = Mesh From 8653a78b80eb7bd6dce2a9c78bce6af949dd53d2 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 21 May 2025 05:52:13 -0700 Subject: [PATCH 1276/1769] [Mosaic GPU] Support `s8xs8->s32` WGMMA. PiperOrigin-RevId: 761489756 --- .../mosaic/gpu/fragmented_array.py | 4 +- jax/experimental/mosaic/gpu/wgmma.py | 52 ++++++++-- tests/mosaic/gpu_test.py | 97 ++++++++++++++++++- 3 files changed, 136 insertions(+), 17 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 62b43b4de737..f69d3f33fe7c 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2551,7 +2551,7 @@ def _repack(regs_it, reg_ty): for array in arrays: reg_ty = array.registers.flat[0].type dtype = array.mlir_dtype - if ir.F32Type.isinstance(dtype): + if ir.F32Type.isinstance(dtype) or dtype == i32: if ir.VectorType.isinstance(reg_ty): [vec_len] = ir.VectorType(reg_ty).shape array_regs = [ # pylint: disable=g-complex-comprehension @@ -2561,7 +2561,7 @@ def _repack(regs_it, reg_ty): ] else: array_regs = list(array.registers.flat) - reg_constraint = "f" + reg_constraint = "r" if dtype == i32 else "f" elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): if not ir.VectorType.isinstance(reg_ty): raise NotImplementedError(array.mlir_dtype) diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 3637778c371b..abbd517fb37d 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -63,7 +63,10 @@ def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None): f32 = ir.F32Type.get() if dtype is None: dtype = f32 - zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + if ir.IntegerType.isinstance(dtype): + zero = arith.constant(dtype, ir.IntegerAttr.get(dtype, 0)) + else: + zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) return cls( _value=fa.FragmentedArray.splat( zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed @@ -90,6 +93,8 @@ def _supported_wgmma_types(dtype, abtype) -> bool: return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, *f16_acc_types)) elif ir.F16Type.isinstance(dtype): return any(input_types_are(ty) for ty in f16_acc_types) + elif ir.IntegerType.get_signless(32).isinstance(dtype): + return input_types_are(ir.IntegerType.get_signless(8)) else: return False @@ -135,7 +140,7 @@ def wgmma_m64( if a_transpose is None: raise ValueError - if ir.F32Type.isinstance(out_ty): + if ir.F32Type.isinstance(out_ty) or out_ty == i32: num_acc_regs = n // 2 out_ty_field = out_ty acc_regs = [ # pylint: disable=g-complex-comprehension @@ -143,8 +148,9 @@ def wgmma_m64( for reg in acc.flat for pos in range(2) ] - to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape) - acc_constraint = "f" + to_acc_vec_regs = functools.partial( + _as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape) + acc_constraint = "r" if ir.IntegerType.isinstance(out_ty) else "f" elif ir.F16Type.isinstance(out_ty): num_acc_regs = n // 4 out_ty_field = i32 @@ -153,9 +159,15 @@ def wgmma_m64( to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape) acc_constraint = "r" else: - raise ValueError(f"WGMMA instruciton only supports f32 and f16 out (got {out_ty})") + raise ValueError( + f"WGMMA instruction only 
supports f32, f16 and s32 out (got {out_ty})") - num_imm_regs = 4 if supports_transpose else 2 + if supports_transpose: + num_imm_regs = 4 + elif out_ty == i32: + num_imm_regs = 0 + else: + num_imm_regs = 2 if a_in_regs: a_reg_constraints = ["r"] * 4 # 4x f16x2 registers @@ -172,7 +184,6 @@ def wgmma_m64( + ["n"] * (1 + num_imm_regs) # literal constants ) reg_constraints = ",".join(reg_constraints_list) - reg_count = itertools.count() def take_regs(n): @@ -186,7 +197,8 @@ def take_regs(n): else: a_regs, = take_regs(1) b_desc_reg, use_out_reg = take_regs(2) - imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...). + # Immediate regs (scale, ...). + imm_regs = "".join(f", {r}" for r in take_regs(num_imm_regs)) assert next(reg_count) == len(reg_constraints_list) k_instr = 32 // bytewidth(element_type) el_ty = str(element_type) @@ -194,9 +206,19 @@ def take_regs(n): el_ty = "e5m2" elif ir.Float8E4M3FNType.isinstance(element_type): el_ty = "e4m3" + elif ir.IntegerType.get_signless(8).isinstance(element_type): + # TODO(bchetioui): add u8 support in the future. Currently we always assume + # that 8-bit integers are s8, and we would need to change the signature of + # `wgmma` to indicate whether the input should be treated as signed or not. + el_ty = "s8" + + out_ty_str = str(out_ty) + if out_ty == i32: + out_ty_str = "s32" + wgmma_instr = ( - f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} " - f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};" + f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty_str}.{el_ty}.{el_ty} " + f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p{imm_regs};" ) ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n" @@ -297,6 +319,8 @@ def wgmma( ) f32 = ir.F32Type.get() f16 = ir.F16Type.get() + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) if element_type == f32 or element_type == ir.BF16Type.get(): if acc.value.mlir_dtype != f32: raise ValueError( @@ -312,6 +336,14 @@ def wgmma( f"WGMMA with element type {element_type} only supports accumulators " f"of type f32 or f16, but got: {acc.value.mlir_dtype}" ) + elif element_type == i8: + if a_in_regs and not a.is_signed: + raise NotImplementedError("WGMMA with lhs of type u8") + if acc.value.mlir_dtype != i32 or not acc.value.is_signed: + raise ValueError( + f"WGMMA with element type {element_type} only supports accumulators " + f"of type s32, but got: {acc.value.mlir_dtype}" + ) else: raise NotImplementedError(f"Unsupported element type: {element_type}") diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 80e67b20e1ef..4e0544d1758e 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -645,6 +645,19 @@ def kernel(ctx, in_, out, smem): np.testing.assert_array_equal(iota, expected) +class I8Type: + """A type that represents a 8-bit signed integer. + + This is a workaround to bypass the fact that we don't have a proper 8-bit + integer type class available in MLIR, and can't instantiate types without a + MLIR context. 
+ """ + + @staticmethod + def get(): # pylint: disable=no-method-argument + return ir.IntegerType.get_signless(8) + + class WGMMATest(TestCase): def setUp(self): @@ -670,7 +683,67 @@ def setUp(self): rhs_tiling_kind=("large", "small", "small+no_transpose"), lhs_tiling_kind=("large", "small", "small+no_transpose"), ) - def test_wgmma_basic( + def test_wgmma_basic_float( + self, + lhs_transpose, + rhs_transpose, + in_mlir_dtype_cls, + m, + n, + k_steps, + swizzle, + jax_out_dtype, + rhs_tiling_kind, + lhs_tiling_kind, + ): + self._test_wgmma_basic( + m, + n, + k_steps, + in_mlir_dtype_cls, + lhs_transpose, + rhs_transpose, + swizzle, + jax_out_dtype, + rhs_tiling_kind, + lhs_tiling_kind, + ) + + @parameterized.product( + in_mlir_dtype_cls=(I8Type,), + m=(64, 128, 192), + n=(64, 128, 192), + k_steps=(1, 2), + swizzle=(32, 64, 128), + jax_out_dtype=(jnp.int32,), + rhs_tiling_kind=("large", "small", "small+no_transpose"), + lhs_tiling_kind=("large", "small"), + ) + def test_wgmma_basic_int( + self, + in_mlir_dtype_cls, + m, + n, + k_steps, + swizzle, + jax_out_dtype, + rhs_tiling_kind, + lhs_tiling_kind, + ): + self._test_wgmma_basic( + m, + n, + k_steps, + in_mlir_dtype_cls, + lhs_transpose=False, + rhs_transpose=True, + swizzle=swizzle, + jax_out_dtype=jax_out_dtype, + rhs_tiling_kind=rhs_tiling_kind, + lhs_tiling_kind=lhs_tiling_kind, + ) + + def _test_wgmma_basic( self, m, n, @@ -683,6 +756,10 @@ def test_wgmma_basic( rhs_tiling_kind, lhs_tiling_kind, ): + if jax_out_dtype == jnp.int32 and in_mlir_dtype_cls != I8Type: + self.skipTest("s32 accumulator only supported with s8 inputs") + if jax_out_dtype != jnp.int32 and in_mlir_dtype_cls == I8Type: + self.skipTest("s8 inputs only supported with s32 accumulator") if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls in {ir.F32Type, ir.BF16Type}: self.skipTest(f"{in_mlir_dtype_cls.get()} does not support f16 output.") if swizzle != 128 and lhs_transpose and lhs_tiling_kind == "large": @@ -716,6 +793,9 @@ def test_wgmma_basic( elif in_mlir_dtype_cls == ir.Float8E4M3FNType: in_jax_dtype = jnp.float8_e4m3fn exponent_bits, mantissa_bits = 4, 3 + elif in_mlir_dtype_cls == I8Type: + in_jax_dtype = jnp.int8 + exponent_bits = mantissa_bits = None else: raise NotImplementedError(in_mlir_dtype) nk_tile = swizzle // bytewidth(in_mlir_dtype) @@ -755,7 +835,8 @@ def kernel(ctx, lhs, rhs, out, scratch): ) for i in range(2): barriers[i].wait() - init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype) + is_signed = True if ir.IntegerType.isinstance(in_mlir_dtype) else None + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype, is_signed=is_signed) if lhs_transpose: perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2) lhs_smem = memref_transpose(lhs_smem, perm) @@ -772,9 +853,13 @@ def quantize(x): return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + if in_mlir_dtype_cls == I8Type: + x = self.prng.integers(-128, 127, x_shape).astype(in_jax_dtype) + y = self.prng.integers(-128, 127, y_shape).astype(in_jax_dtype) + else: + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype) if transpose_rhs_tiles: rhs_tiling_t = rhs_tiling[::-1] if 
rhs_transpose else rhs_tiling @@ -797,7 +882,9 @@ def quantize(x): x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6 - if utils.bitwidth(in_mlir_dtype) == 8: + if ir.IntegerType.isinstance(in_mlir_dtype) and ir.IntegerType.isinstance(out_mlir_dtype): + atol = 0 + elif utils.bitwidth(in_mlir_dtype) == 8: atol = 3e-2 np.testing.assert_allclose(z, ref, atol=atol) From 5f764b55d82595141837a7a141a650625b4f7679 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 21 May 2025 05:57:03 -0700 Subject: [PATCH 1277/1769] [Pallas/Mosaic GPU] Loosen tiling requirements for `get` and `swap`. We now allow arbitrary 2D tilings where the minor dimension fits the associated swizzle. PiperOrigin-RevId: 761490845 --- jax/_src/pallas/mosaic_gpu/lowering.py | 24 +++++++++++++++------ tests/pallas/mosaic_gpu_test.py | 30 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index eb5fc136082e..b6960c479558 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1298,11 +1298,14 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != ( - 8, - (swizzle * 8) // pallas_utils.dtype_bitwidth(dtype), - ): - raise NotImplementedError("Tiling does not fit swizzle") + if len(tiling) != 2: + raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}") + expected_minor_tiling = swizzle * 8 // pallas_utils.dtype_bitwidth(dtype) + if tiling[-1] != expected_minor_tiling: + raise NotImplementedError( + "Minor tiling dimension does not fit swizzle: " + f" expected {expected_minor_tiling}, got {tiling[-1]}" + ) return mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(dtype), swizzle=swizzle ) @@ -1383,8 +1386,15 @@ def _swap_lowering_rule( gpu_core.UntileRef(tiling), *maybe_transpose, ): - if tiling != (8, swizzle // v_aval.dtype.itemsize): - raise NotImplementedError("Tiling does not fit swizzle") + if len(tiling) != 2: + raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}") + bw = pallas_utils.dtype_bitwidth(v_aval.dtype) + expected_minor_tiling = swizzle * 8 // bw + if tiling[-1] != expected_minor_tiling: + raise NotImplementedError( + "Minor tiling dimension does not fit swizzle: " + f" expected {expected_minor_tiling}, got {tiling[-1]}" + ) if transposed_value != bool(maybe_transpose): raise ValueError( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b10bc0f390b0..aedd79e23194 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -998,6 +998,36 @@ def kernel(x_ref, o_ref): self.assertIn("x: [1, 0, 43, 23]: 6871\n", output()) + @parameterized.parameters( + (plgpu.TilingTransform((1, 32)), plgpu.SwizzleTransform(128)), + (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), + (), + ) + def test_get_swap_with_transforms(self, *transforms): + self.skip_if_wg_semantics() + + shape = (128, 128) + + @functools.partial( + self.pallas_call, + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), + scratch_shapes=[ + plgpu.SMEM(shape, jnp.int32, transforms=tuple(transforms)), + 
plgpu.Barrier(num_arrivals=1), + ] + ) + def kernel(x_ref, o_ref, scratch_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, scratch_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + scratch_ref[...] = scratch_ref[...] * 2 + plgpu.copy_smem_to_gmem(scratch_ref, o_ref) + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(math.prod(shape), dtype=jnp.int32).reshape(shape) + np.testing.assert_array_equal(kernel(x), x * 2) + def test_check(self): self.skip_if_wg_semantics() From ebac0505b3a079734e76bc70b443cb57e68d8516 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 21 May 2025 07:47:11 -0700 Subject: [PATCH 1278/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c74abfc3ecfdc31901acca65efa29ffda3ed84cc. PiperOrigin-RevId: 761521044 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8062f071f777..521e07bb6d70 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "47a47140b5f7a33d86901316d9c569de179bae08" -XLA_SHA256 = "c6c55e0e426b2f4a78f1b5459baab29a61d2435c089628e83db92ef8f890a02f" +XLA_COMMIT = "c74abfc3ecfdc31901acca65efa29ffda3ed84cc" +XLA_SHA256 = "e367e84d64730cfb94c58dd75477239183dbe74a73ea57bde3c87abcfcaffb3a" def repo(): tf_http_archive( From 382506f1705db9c9ac348b9783497e310feef6a5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 21 May 2025 08:01:15 -0700 Subject: [PATCH 1279/1769] Prepare for jax release v0.6.1 --- CHANGELOG.md | 2 +- jax/version.py | 2 +- setup.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fd4e50304d0..939177e01311 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## Unreleased +## JAX 0.6.1 * New features: * Added {func}`jax.lax.axis_size` which returns the size of the mapped axis diff --git a/jax/version.py b/jax/version.py index 9301848b0cfb..acbfb7577e49 100644 --- a/jax/version.py +++ b/jax/version.py @@ -152,7 +152,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.6.0" +_minimum_jaxlib_version = "0.6.1" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 823354adb70d..85ba1fdb4e96 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.6.0' +_current_jaxlib_version = '0.6.1' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.6.0' -_libtpu_version = '0.0.13.*' +_libtpu_version = '0.0.15.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From 99a0e678c8b6747efa3ed10e1ed4d28fecde525a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 21 May 2025 10:49:23 -0400 Subject: [PATCH 1280/1769] Try using uv for installing packages on Read the Docs. 
--- .readthedocs.yml | 17 ++++++++++------- docs/conf.py | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 3b7ba275a0d6..0ac20301cee2 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,15 +6,23 @@ version: 2 build: - os: "ubuntu-22.04" + os: "ubuntu-24.04" tools: - python: "3.10" + python: "3.12" jobs: post_checkout: # Skip building PRs unless tagged with the "documentation" label. - | [ "${READTHEDOCS_VERSION_TYPE}" != "external" ] && echo "Building latest" && exit 0 (curl -sL https://api.github.com/repos/jax-ml/jax/issues/${READTHEDOCS_VERSION}/labels | grep -q "https://api.github.com/repos/jax-ml/jax/labels/documentation") && echo "Building PR with label" || exit 183 + create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + - uv venv $READTHEDOCS_VIRTUALENV_PATH + - UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv pip install -r docs/requirements.txt + install: + - "true" # skip # Build documentation in the docs/ directory with Sphinx sphinx: @@ -24,8 +32,3 @@ sphinx: # Optionally build your docs in additional formats such as PDF and ePub formats: - htmlzip - -# Optionally set the version of Python and requirements required to build your docs -python: - install: - - requirements: docs/requirements.txt diff --git a/docs/conf.py b/docs/conf.py index a7a52c9db38c..a44177407344 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,11 +38,11 @@ from typing import ForwardRef def _do_not_evaluate_in_jax( - self, globalns, *args, _evaluate=ForwardRef._evaluate, + self, globalns, *args, _evaluate=ForwardRef._evaluate, **kwargs, ): if globalns.get('__name__', '').startswith('jax'): return self - return _evaluate(self, globalns, *args) + return _evaluate(self, globalns, *args, **kwargs) ForwardRef._evaluate = _do_not_evaluate_in_jax From eb1220374c066d429a20488a3557a813f078a734 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 21 May 2025 08:28:46 -0700 Subject: [PATCH 1281/1769] Skip oldest supported numpy presubmit on release branches PiperOrigin-RevId: 761534198 --- .github/workflows/oldest_supported_numpy.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml index 80e0cb154ecd..06e7cf6230df 100644 --- a/.github/workflows/oldest_supported_numpy.yml +++ b/.github/workflows/oldest_supported_numpy.yml @@ -10,7 +10,6 @@ on: push: branches: - main - - 'release/**' # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. @@ -24,7 +23,7 @@ concurrency: jobs: test-oldest-supported-numpy: - if: github.event.repository.fork == false + if: "github.event.repository.fork == false && !startsWith(github.head_ref, 'release/')" defaults: run: shell: bash From 6d2dc34bb05e4f4fca3488b2f7ad587762ead497 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 21 May 2025 10:43:24 -0400 Subject: [PATCH 1282/1769] Skip generating source links for re-exported numpy functions in docs. 
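The check (sketched below with assumed paths) relies on pathlib raising
ValueError when the resolved source file is not under the jax package tree,
which is exactly the re-export case:

    from pathlib import Path

    jax_root = Path("/env/site-packages/jax")                # assumed path
    src = Path("/env/site-packages/numpy/_core/numeric.py")  # assumed path
    try:
      rel = src.relative_to(jax_root)  # raises for files outside jax_root
    except ValueError:
      rel = None  # re-exported symbol: emit no source link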
--- docs/conf.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index a44177407344..dd7533aecf83 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,6 +29,7 @@ import inspect import operator import os +from pathlib import Path import sys sys.path.insert(0, os.path.abspath('..')) @@ -354,7 +355,11 @@ def linkcode_resolve(domain, info): source, linenum = inspect.getsourcelines(obj) except: return None - filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) + try: + filename = Path(filename).relative_to(Path(jax.__file__).parent) + except ValueError: + # Source file is not a relative to jax; this must be a re-exported function. + return None lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}" From c892cda147b894df7553ca1cbb730648a9ce98bb Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 21 May 2025 10:03:07 -0700 Subject: [PATCH 1283/1769] Revert "jax-cuda12-plugin: require nvidia-cublas-cu12<12.9" --- jax_plugins/cuda/plugin_setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index acd82702b357..c8b70408471c 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -53,8 +53,7 @@ def has_ext_modules(self): install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], extras_require={ 'with-cuda': [ - # cudnn has a bug with mxfp8 with multiple GPUs per process and cublas 12.9 - "nvidia-cublas-cu12>=12.1.3.1,<12.9", + "nvidia-cublas-cu12>=12.1.3.1", "nvidia-cuda-cupti-cu12>=12.1.105", "nvidia-cuda-nvcc-cu12>=12.6.85", "nvidia-cuda-runtime-cu12>=12.1.105", From 8e4f3b5dab9f88a8218b7ec79369210bb95305c6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 21 May 2025 10:08:42 -0700 Subject: [PATCH 1284/1769] Introduce the flag `add_pypi_cuda_wheel_deps` that controls if the tests depend on NVIDIA CUDA wheels hermetically. The flag is enabled by default. To disable the dependency, pass `add_pypi_cuda_wheel_deps=False` in the Bazel options. PiperOrigin-RevId: 761568590 --- jaxlib/jax.bzl | 7 +++++++ jaxlib/tools/BUILD.bazel | 25 ++++++++++++++++++++----- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index eceb38e35aab..678d92bc434a 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -637,3 +637,10 @@ def wheel_sources( ":{}_data".format(name), ":{}_hdrs".format(name), ] + static_srcs) + +def if_pypi_cuda_wheel_deps(if_true, if_false = []): + """ select() on whether we're adding pypi CUDA wheel deps. 
""" + return select({ + "//jaxlib/tools:pypi_cuda_wheel_deps": if_true, + "//conditions:default": if_false, + }) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index d6a5f94dfd4b..433747e2bb8d 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -15,7 +15,7 @@ # JAX is Autograd and XLA load("@bazel_skylib//lib:selects.bzl", "selects") -load("@bazel_skylib//rules:common_settings.bzl", "string_flag") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( @@ -29,6 +29,7 @@ load( load( "//jaxlib:jax.bzl", "PLATFORM_TAGS_DICT", + "if_pypi_cuda_wheel_deps", "jax_py_test", "jax_wheel", "pytype_strict_library", @@ -470,6 +471,20 @@ filegroup( ], ) +# The flag configures whether to add the pypi NVIDIA CUDA deps to py_import. +bool_flag( + name = "add_pypi_cuda_wheel_deps", + build_setting_default = True, +) + +config_setting( + name = "pypi_cuda_wheel_deps", + flag_values = { + ":add_pypi_cuda_wheel_deps": "True", + "@local_config_cuda//:enable_cuda": "True", + }, +) + py_import( name = "jaxlib_py_import", wheel = ":jaxlib_wheel", @@ -478,26 +493,26 @@ py_import( py_import( name = "jax_cuda_plugin_py_import", wheel = ":jax_cuda_plugin_wheel", - wheel_deps = if_cuda([":nvidia_wheel_deps"]), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), ) py_import( name = "jax_cuda_pjrt_py_import", wheel = ":jax_cuda_pjrt_wheel", - wheel_deps = if_cuda([":nvidia_wheel_deps"]), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), ) # The targets below are used for GPU tests with `--//jax:build_jaxlib=false`. py_import( name = "pypi_jax_cuda_plugin_with_cuda_deps", wheel = "@pypi_jax_cuda12_plugin//:whl", - wheel_deps = if_cuda([":nvidia_wheel_deps"]), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), ) py_import( name = "pypi_jax_cuda_pjrt_with_cuda_deps", wheel = "@pypi_jax_cuda12_pjrt//:whl", - wheel_deps = if_cuda([":nvidia_wheel_deps"]), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), ) # Wheel tests. 
From f227b13613fe8d1a9c263006b40d31484b24f4c7 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Wed, 21 May 2025 10:32:09 -0700 Subject: [PATCH 1285/1769] Adjust triton dialect lowering rounding mode to allow upcasting fp8 types Fix: https://github.com/jax-ml/jax/issues/28416 PiperOrigin-RevId: 761577943 --- jax/_src/pallas/triton/lowering.py | 7 ++- tests/pallas/BUILD | 18 +++++++ tests/pallas/triton_pallas_test.py | 77 ++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 4 deletions(-) create mode 100644 tests/pallas/triton_pallas_test.py diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 2cddb623b33f..bd70dc8d470c 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1572,11 +1572,10 @@ def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: src_element_type = ir.FloatType(_element_type(src.type)) dst_element_type = ir.FloatType(_element_type(dst_type)) if src_element_type.width == 8 or dst_element_type.width == 8: - return tt_dialect.fp_to_fp( - dst_type, - src, - rounding=tt_dialect.RoundingMode.RTNE, + rounding = ( + tt_dialect.RoundingMode.RTNE if src_element_type.width > 8 else None ) + return tt_dialect.fp_to_fp(dst_type, src, rounding=rounding) if src_element_type.width > dst_element_type.width: return arith_dialect.truncf(dst_type, src) elif src_element_type.width < dst_element_type.width: diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 6690fb2dac62..d7df261a1ca9 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -748,6 +748,24 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "triton_pallas_test", + srcs = [ + "triton_pallas_test.py", + ], + enable_backends = ["cpu"], + enable_configs = [ + "gpu_h100_x32", + ], + shard_count = 1, + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", + ] + py_deps([ + "absl/testing", + ]), +) + jax_multiplatform_test( name = "mgpu_attention_run", srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"], diff --git a/tests/pallas/triton_pallas_test.py b/tests/pallas/triton_pallas_test.py new file mode 100644 index 000000000000..4e2b10e72eb1 --- /dev/null +++ b/tests/pallas/triton_pallas_test.py @@ -0,0 +1,77 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Test the Triton dialect lowering for a variety of atomic operations.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax.experimental import pallas as pl +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + + super().setUp() + _trace_kernel_to_jaxpr.cache_clear() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +DTYPE_LIST = [jnp.float32, jnp.float16, jnp.bfloat16, + jnp.float8_e4m3fn, jnp.float8_e5m2] + + +class TritonPallasTest(PallasBaseTest): + INTERPRET = False + + @parameterized.product(src_dtype=DTYPE_LIST, dst_dtype=DTYPE_LIST) + def test_fp_dtype_cast(self, src_dtype, dst_dtype): + if src_dtype == dst_dtype: + self.skipTest("No need to test the same dtype") + if dtypes.bit_width(src_dtype) == 8 and dtypes.bit_width(dst_dtype) == 8: + self.skipTest("Not casting between 8-bit types") + + def body(x_ref, y_ref): + y_ref[...] = x_ref[...].astype(dst_dtype) + + x = 10 * jax.random.normal(jax.random.key(0), (64, 64), dtype=src_dtype) + y = self.pallas_call(body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, 0))], + out_specs=pl.BlockSpec((64, 64), lambda i: (0, 0)), + out_shape=jax.ShapeDtypeStruct((64, 64), dst_dtype), + grid=(1,), + )(x) + self.assertEqual(y.dtype, dst_dtype) + self.assertArraysEqual(y, x.astype(dst_dtype)) + +if __name__ == "__main__": + absltest.main() From 3179e5d84a624f8ac456a2588191f92813ddf2f6 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 21 May 2025 10:33:21 -0700 Subject: [PATCH 1286/1769] [Mosaic] Add faster implementation for s8->bf16 and s4->bf16 on TPUv6+. PiperOrigin-RevId: 761578503 --- jaxlib/mosaic/dialect/tpu/tpu.td | 8 +++++ .../tpu/transforms/apply_vector_layout.cc | 35 +++++++++++++++++++ .../tpu/transforms/canonicalize_mosaic.cc | 12 +++++++ .../tpu/transforms/infer_vector_layout.cc | 6 ++++ 4 files changed, 61 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index b6ae1e52e822..505478b9ad72 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -510,6 +510,14 @@ def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> { let hasCanonicalizeMethod = 1; } +// Internal operation. All arith.sitofp operations that change the bitwidth +// must be canonicalized to this operation. 
+def TPU_SIToFPOp : TPU_Op<"sitofp", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyVectorOfAnyRank:$in, TPU_RoundingModeEnum:$rounding_mode); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($output) }]; +} + def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> { let parameters = (ins ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index fa14c8ef9238..4200551ff450 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1027,6 +1027,40 @@ LogicalResult tpu_fptosi_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Unsupported FPToSI conversion"); } +LogicalResult tpu_sitofp_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + TPU_ASSERT_EQ_OP(layouts_in.size(), 1); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_EQ_OP(layouts_out.size(), 1); + TPU_ASSERT_OP(layouts_out.front().has_value()); + auto &layout_in = *layouts_in.front(); + auto &layout_out = *layouts_out.front(); + if (layout_in.bitwidth() == layout_out.bitwidth()) { + return elementwise_op_rule(ctx, op, layouts_in, layouts_out); + } else if (layout_in.bitwidth() < layout_out.bitwidth()) { + auto sitofp_op = cast(op); + switch (sitofp_op.getRoundingMode()) { + case tpu::RoundingMode::kToNearestEven: { + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array vregs, + ext_op_rule_impl(ctx, builder, sitofp_op, layout_in, layout_out)); + sitofp_op.replaceAllUsesWith(assemble(builder, sitofp_op.getType(), + layout_out, std::move(vregs), + ctx.target_shape) + .getResult()); + sitofp_op.erase(); + return success(); + } + case tpu::RoundingMode::kTowardsZero: + return op.emitOpError( + "Not implemented: SIToFP with rounding mode kTowardsZero"); + } + } + return op.emitOpError("Unsupported SIToFP conversion"); +} + LogicalResult func_return_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -7164,6 +7198,7 @@ const llvm::StringMap &rules() { {tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule}, {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule}, {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule}, + {tpu::SIToFPOp::getOperationName(), tpu_sitofp_rule}, {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, {vector::ExtractOp::getOperationName(), vector_extract_rule}, {vector::LoadOp::getOperationName(), vector_load_rule}, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 368bfc596732..1d8ea1299f04 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -685,6 +685,18 @@ LogicalResult canonicalize_sitofp(const CanonicalizeContext &ctx, FAILUREOR_ASSIGN_OR_RETURN(const unsigned dst_bitwidth, getElementTypeBitwidth(op.getType())); + // We have low-level optimized code for s8->bf16 and s4->bf16 casts on v6. 
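+  // In that case, rewrite arith.sitofp to tpu.sitofp (round-to-nearest-even).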
+ if (ctx.hardware_generation >= 6 && is_vector && + (src_vty.getElementType().isSignlessInteger(8) || + src_vty.getElementType().isSignlessInteger(4)) && + dst_vty.getElementType().isBF16()) { + auto new_op = builder.create( + op.getType(), op.getIn(), tpu::RoundingMode::kToNearestEven); + op.replaceAllUsesWith(new_op.getResult()); + op.erase(); + return success(); + } + if ((src_bitwidth < 32 || dst_bitwidth < 32) && !ctx.compatibility_mode) { return op.emitOpError( "On this target integer-to-float conversions can only happen on " diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 9c4a7b4c397d..f01d0b4c5888 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -154,6 +154,12 @@ class VectorLayoutInferer { if (inferExt(&any_op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op); + op && op.getIn().getType().getElementTypeBitWidth() < + op.getType().getElementTypeBitWidth()) { + if (inferExt(&any_op).failed()) { + return failure(); + } } else if (isa(any_op)) { if (inferTrunc(&any_op).failed()) { return failure(); From 06c323ccc9a73f25b869555cf7b4822bbf35971e Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Wed, 21 May 2025 10:48:10 -0700 Subject: [PATCH 1287/1769] Show exactly what the copy paths/patterns are when copying wheels. This will make it easier to track down unexpected path mismatches in the future. PiperOrigin-RevId: 761584888 --- build/tools/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/build/tools/utils.py b/build/tools/utils.py index 4bf871067501..7ed7f74d07a5 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -293,9 +293,12 @@ def copy_dir_recursively(src, dst): logging.info("Editable wheel path: %s" % dst) -def copy_individual_files(src, dst, regex): +def copy_individual_files(src: str, dst: str, glob_pattern: str): os.makedirs(dst, exist_ok=True) - for f in glob.glob(os.path.join(src, regex)): + logging.debug( + f"Copying files matching pattern {glob_pattern!r} from {src!r} to {dst!r}" + ) + for f in glob.glob(os.path.join(src, glob_pattern)): dst_file = os.path.join(dst, os.path.basename(f)) if os.path.exists(dst_file): os.remove(dst_file) From 08c2a36e280b3380ed715301a0d6e396bccd9310 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 21 May 2025 11:54:30 -0700 Subject: [PATCH 1288/1769] [ragged-paged-attn] Use select for initialization in flash attention. PiperOrigin-RevId: 761612390 --- .../ops/tpu/ragged_paged_attention/kernel.py | 34 +++++-------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index df47674a59a9..d9d952d5a378 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -454,6 +454,11 @@ def masked_store(ref, val, start, end, group=1): mask = jnp.logical_and(iota >= start, iota < end) pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) + def load_with_init(ref, init_val): + return jnp.where( + kv_blk_idx == 0, jnp.full_like(ref, init_val), ref[...] + ) + # kv lens will be contracting dim, we should mask out the NaNs. 
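  # (Otherwise the padded KV rows would propagate NaNs through the dot product.)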
kv_mask = ( lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start @@ -468,29 +473,6 @@ def masked_store(ref, val, start, end, group=1): store_start = jnp.maximum(q_start - q_len_start, 0) store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) - @pl.when(kv_blk_idx == 0) - def init_scratch_ref(): - masked_store( - head_m_ref, - jnp.full_like(head_m_ref, -jnp.inf), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - masked_store( - head_l_ref, - jnp.zeros_like(head_l_ref), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - masked_store( - head_acc_ref, - jnp.zeros_like(head_acc_ref), - store_start, - store_end, - ) - row_ids = ( (kv_len - q_len) + q_len_start @@ -522,8 +504,8 @@ def init_scratch_ref(): l_curr = jnp.broadcast_to( s_curr.sum(axis=1, keepdims=True), lm_store_shape ) - m_prev = head_m_ref[...] - l_prev = head_l_ref[...] + m_prev = load_with_init(head_m_ref, -jnp.inf) + l_prev = load_with_init(head_l_ref, 0.0) m_next = jnp.maximum(m_prev, m_curr) masked_store( head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head @@ -552,7 +534,7 @@ def broadcast_to_shape(arr, shape): [arr for _ in range(shape[1] // arr.shape[1])], axis=1 ) - o_curr = head_acc_ref[...].reshape(-1, head_dim) + o_curr = load_with_init(head_acc_ref, 0.0).reshape(-1, head_dim) l_alpha = broadcast_to_shape(l_alpha, qkv.shape) beta = broadcast_to_shape(beta, qkv.shape) l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) From 2e070cade219ddb0033ec3ca2438bd8f74853218 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 21 May 2025 14:44:42 -0700 Subject: [PATCH 1289/1769] [Mosaic:TPU][Relayout] Support minor to 2nd minor implicit dimension for unpacked types and native tiling on TPUv5 PiperOrigin-RevId: 761676578 --- .../tpu/transforms/apply_vector_layout.cc | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 4200551ff450..4d1323d68057 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -441,6 +441,121 @@ FailureOr maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder, .getResult(); } +// Transpose the 2nd minor dimension of the implicit shape. +// +// Shape of (..., N, 1) becomes (..., 1, N) +FailureOr> transposeSingletonMinorDimension( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + xla::Array vregs, const ArrayRef ishape, + VectorLayout layout, const int64_t new_minor_offset) { + if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) { + // Note: For non-native tilings it is probably better to retile first, to + // to make the most out of each lane rotate (they are expensive). 
+ return emitError(loc, "Not implemented: Unsupported bitwidth or tiling"); + } + auto create_index_const = [&](const int64_t idx) { + return builder.create(loc, idx); + }; + auto create_i32_vreg_const = [&](const int64_t val) { + return I32Const(val, ctx.target_shape, builder, loc); + }; + if (layout.offsets()[1].has_value()) { + // Replicate minor dimension + // TODO(tlongeri): Move into its own function (it will be needed for + // relayout) and make this a precondition of this function, so that we have + // "building block" functions with minimal overlap + vregs.Each([&](const absl::Span idxs, Value *vreg) { + *vreg = builder.create( + loc, vreg->getType(), *vreg, + create_i32_vreg_const(*layout.offsets()[1]), 1); + }); + layout = + VectorLayout(layout.bitwidth(), {layout.offsets()[0], std::nullopt}, + layout.tiling(), VectorLayout::ImplicitDim::kNone); + } + if (!layout.offsets()[0].has_value()) { + return vregs; + } + const int64_t old_2nd_minor_offset = *layout.offsets()[0]; + SmallVector new_ishape(ishape); + CHECK_EQ(new_ishape.back(), 1); + std::iter_swap(new_ishape.end() - 2, new_ishape.end() - 1); + // new_layout is only to get the new vreg array shape, the implicit dim is + // irrelevant (since we already have the implicit shape): + const VectorLayout new_layout( + layout.bitwidth(), {std::nullopt, new_minor_offset}, layout.tiling(), + VectorLayout::ImplicitDim::kNone); + xla::Array new_vregs(new_layout.tileArrayShape( + /*src_is_implicit=*/true, /*res_is_implicit=*/true, new_ishape, + ctx.target_shape)); + VectorType iota_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); + // Preallocate an indices vector to avoid repeated allocations: + SmallVector old_idxs; + new_vregs.Each([&](const absl::Span new_idxs, + Value *new_vreg) { + const int64_t uncorrected_shape_start = + ctx.target_shape[1] * new_idxs.back() - new_minor_offset; + // The start and end of the data contained by new_vreg in the implicit shape + const int64_t shape_start = std::max(uncorrected_shape_start, 0); + const int64_t shape_end = std::min( + uncorrected_shape_start + ctx.target_shape[1], new_ishape.back()); + old_idxs.assign(new_idxs.begin(), new_idxs.end()); + CHECK_EQ(*(old_idxs.end() - 2), 0); + old_idxs.back() = 0; + *new_vreg = nullptr; + VectorType vmask_ty = + getNativeVregOrVmaskType(builder.getI1Type(), 32, ctx.target_shape); + int64_t shape_offset = shape_start; + // The data in the new vreg is composed of data from multiple of the old + // vregs, so iterate over them until the new vreg is full + while (shape_offset < shape_end) { + // Find the vreg that contains the data at shape_offset + *(old_idxs.end() - 2) = + (shape_offset + old_2nd_minor_offset) / ctx.target_shape[0]; + const int64_t old_sublane_offset = + (shape_offset + old_2nd_minor_offset) % ctx.target_shape[0]; + const int64_t new_lane_offset = + (shape_offset + new_minor_offset) % ctx.target_shape[1]; + // We will blend in all the relevant data contained by the old vreg + const int64_t data_size = + std::min(ctx.target_shape[0] - old_sublane_offset, + ctx.target_shape[1] - new_lane_offset); + // [ a a a a a a a a ] [ . . a b c . . . ] + // [ b b b b b b b b ] => [ . . a b c . . . ] + // [ c c c c c c c c ] [ . . a b c . . . ] + // [ . . . . . . . . ] [ . . a b c . . . ] + // Every lane has all the data, so at each sublane we can just pick out + // the element that we want using a sublane shuffle. 
+      Value vreg = vregs(old_idxs);
+      Value iota_vreg = builder.create<tpu::IotaOp>(
+          loc, iota_vreg_ty,
+          /*dimension=*/builder.getI32IntegerAttr(1));
+      iota_vreg = builder.create<arith::AddIOp>(
+          loc, iota_vreg,
+          create_i32_vreg_const(old_sublane_offset - new_lane_offset));
+      vreg = builder.create<tpu::DynamicGatherOp>(loc, vreg.getType(), vreg,
+                                                  iota_vreg, 0);
+      // Now, blend the transposed data into new_vreg
+      if (*new_vreg == nullptr) {
+        *new_vreg = vreg;
+      } else {
+        Value mask = builder.create<tpu::CreateMaskOp>(
+            loc, vmask_ty,
+            ArrayRef<Value>{create_index_const(0),
+                            create_index_const(new_lane_offset)},
+            ArrayRef<Value>{create_index_const(ctx.target_shape[0]),
+                            create_index_const(new_lane_offset + data_size)});
+        *new_vreg = builder.create<arith::SelectOp>(loc, mask, vreg, *new_vreg);
+      }
+      shape_offset += data_size;
+      ++*(old_idxs.end() - 2);
+    }
+    CHECK(*new_vreg != nullptr);
+  });
+  return new_vregs;
+}
+
 // Insert a minor dimension to the implicit shape. The original minor dimension
 // becomes the new second minor dimension, laid out across sublanes.
 //
@@ -6904,6 +7019,19 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
                                       dst.offsets()));
     return std::make_pair(dst, std::move(dst_vregs));
   }
+  if (src.implicit_dim() == VectorLayout::ImplicitDim::kMinor &&
+      dst_implicit_dim == VectorLayout::ImplicitDim::kSecondMinor &&
+      src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) {
+    const int64_t dst_minor_offset = dst_offset_hints[1].value_or(0);
+    FAILUREOR_ASSIGN_OR_RETURN(
+        xla::Array<Value> dst_vregs,
+        transposeSingletonMinorDimension(ctx, builder, loc, vregs,
+                                         src.implicitShape(vty.getShape()),
+                                         src, dst_minor_offset));
+    VectorLayout dst(src.bitwidth(), {std::nullopt, dst_minor_offset},
+                     src.tiling(), VectorLayout::ImplicitDim::kSecondMinor);
+    return std::make_pair(dst, std::move(dst_vregs));
+  }
   return emitError(loc,
                    "Not implemented: Unsupported implicit dim change: from ")
          << src << " to " << dst_implicit_dim;

From 7e9c7e69427628bce16a39701710e3d781b18468 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 21 May 2025 14:23:57 -0700
Subject: [PATCH 1290/1769] Update version numbers after 0.6.1 release.

---
 CHANGELOG.md                                 | 4 +++-
 jax/experimental/jax2tf/tests/jax2tf_test.py | 2 ++
 jax/version.py                               | 2 +-
 setup.py                                     | 2 +-
 4 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 939177e01311..1e866fae6af5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,7 +14,9 @@ Remember to align the itemized text with the first line of an item within a list
 When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
--> -## JAX 0.6.1 +## Unreleased + +## JAX 0.6.1 (May 21, 2025) * New features: * Added {func}`jax.lax.axis_size` which returns the size of the mapped axis diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index db608adc3dde..ece88841fdc5 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -48,6 +48,7 @@ config.parse_flags_with_absl() +@unittest.skip("Failing after jax 0.6.1 release") class Jax2TfTest(tf_test_util.JaxToTfTestCase): def setUp(self): @@ -1782,6 +1783,7 @@ def func(): jax_result = func() self.assertEqual(tf_result, jax_result) +@unittest.skip("Failing after jax 0.6.1 release") class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): # Use a separate test case with the default jax_serialization_version def setUp(self): diff --git a/jax/version.py b/jax/version.py index acbfb7577e49..e15af7ab50fc 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.6.1" +_version = "0.6.2" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index 85ba1fdb4e96..4c5c86f588c3 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ _current_jaxlib_version = '0.6.1' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.6.0' +_latest_jaxlib_version_on_pypi = '0.6.1' _libtpu_version = '0.0.15.*' From a1d28dc2df6c8545f63b93466ef54442e56bd00b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 21 May 2025 15:19:56 -0700 Subject: [PATCH 1291/1769] Move _src/stages.py to its own build target Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This refactor required moving the definitions of a few private utilities from pjit and pxla, because these files are part of the larger jax build target. 
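
The practical effect on call sites can be sketched as follows (an
illustrative before/after; the exception name is taken from the code being
moved, the surrounding lines are hypothetical):

    # before: reaching into the interpreter internals
    from jax._src.interpreters import pxla
    raise pxla.DeviceAssignmentMismatchError(fails)

    # after: importing from the leaf stages module
    from jax._src import stages
    raise stages.DeviceAssignmentMismatchError(fails)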
PiperOrigin-RevId: 761689391 --- jax/BUILD | 21 +++++- jax/_src/dispatch.py | 9 +-- jax/_src/interpreters/pxla.py | 78 +++------------------ jax/_src/pjit.py | 48 +------------ jax/_src/stages.py | 126 +++++++++++++++++++++++++++++++--- 5 files changed, 150 insertions(+), 132 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 18e670e0269c..4218cd3f0a77 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -326,7 +326,6 @@ py_library_providing_imports_info( "_src/shard_alike.py", "_src/shard_map.py", "_src/sourcemap.py", - "_src/stages.py", "_src/tree.py", ] + glob( [ @@ -415,6 +414,7 @@ py_library_providing_imports_info( ":sharding_impls", ":sharding_specs", ":source_info_util", + ":stages", ":traceback_util", ":tree_util", ":typing", @@ -1001,6 +1001,25 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "stages", + srcs = ["_src/stages.py"], + deps = [ + ":config", + ":core", + ":layout", + ":mlir", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + "//jax/_src/lib", + ], +) + pytype_strict_library( name = "compute_on", srcs = ["_src/compute_on.py"], diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 8f553ea884d7..9a11ffa104a8 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -24,7 +24,7 @@ import logging import threading import time -from typing import Any, Callable, NamedTuple +from typing import Any, Callable import jax from jax._src import api @@ -34,7 +34,6 @@ from jax._src import core from jax._src import dtypes from jax._src import lib -from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.abstract_arrays import array_types @@ -52,6 +51,7 @@ from jax._src.sharding_impls import ( NamedSharding, SingleDeviceSharding, TransferToMemoryKind, GSPMDSharding, is_single_device_sharding) +from jax._src.stages import SourceInfo import numpy as np @@ -240,11 +240,6 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: return False -class SourceInfo(NamedTuple): - source_info: source_info_util.SourceInfo - eqn_name: str - - @util.weakref_lru_cache def get_intermediate_shardings( jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 4a21fae59e52..d0a22cd784b4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -15,7 +15,6 @@ from __future__ import annotations -import enum import collections from collections import namedtuple from collections.abc import Callable, Sequence, Iterable @@ -1660,67 +1659,10 @@ def check_if_any_auto( return True return False -class MismatchType(enum.Enum): - ARG_SHARDING = 0 - OUT_SHARDING = 1 - SHARDING_INSIDE_COMPUTATION = 2 - CONTEXT_DEVICES = 3 - IN_SHARDING = 4 - - def __str__(self): - if self.name == 'IN_SHARDING': - return 'explicit input sharding' - elif self.name == 'OUT_SHARDING': - return 'explicit output sharding' - elif self.name == 'CONTEXT_DEVICES': - return 'context mesh' - return f'{self.name}' - - -@dataclasses.dataclass -class DeviceAssignmentMismatch: - da: Sequence[xc.Device] - m_type: MismatchType - source_info: dispatch.SourceInfo | None - - @property - def device_ids(self) -> Sequence[int]: - return [d.id for d in self.da] - - @property - def platform(self) -> str: - return self.da[0].platform.upper() - - def _maybe_api_name(self, api_name) -> str: - return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else "" - - @property - def 
source_info_str(self): - return ( - "" if self.source_info is None - else f" at {source_info_util.summarize(self.source_info.source_info)}" - ) - - @property - def _dev_ids_plat_str(self): - return f"device ids {self.device_ids} on platform {self.platform}" - - def m_type_str(self, api_name): - return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}' - if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) - - def _str(self, api_name): - return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with " - f"{self._dev_ids_plat_str}{self.source_info_str}") - - -class DeviceAssignmentMismatchError(Exception): - pass - ShardingInfo = tuple[ Union[JSharding, UnspecifiedValue, AUTO], - MismatchType, + stages.MismatchType, Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports ] @@ -1752,14 +1694,14 @@ def _get_and_check_device_assignment( else sh._device_assignment) if not devices: if first_sharding_info[0] != arr_device_assignment: - raise DeviceAssignmentMismatchError([ - DeviceAssignmentMismatch(*first_sharding_info), - DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) + raise stages.DeviceAssignmentMismatchError([ + stages.DeviceAssignmentMismatch(*first_sharding_info), + stages.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) else: if devices != arr_device_assignment: - raise DeviceAssignmentMismatchError([ - DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None), - DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) + raise stages.DeviceAssignmentMismatchError([ + stages.DeviceAssignmentMismatch(devices, stages.MismatchType.CONTEXT_DEVICES, None), + stages.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) if first_sharding_info is None and devices: final_device_assignment = devices elif first_sharding_info is None: @@ -2283,9 +2225,9 @@ def lower_sharding_computation( unique_out_shardings = util.stable_unique(out_shardings) backend, device_assignment = _get_and_check_device_assignment( it.chain( - ((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings), - ((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings), - ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) + ((i, stages.MismatchType.ARG_SHARDING, None) for i in unique_in_shardings), + ((o, stages.MismatchType.OUT_SHARDING, None) for o in unique_out_shardings), + ((js, stages.MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) for js, source_info in unique_intermediate_shardings)), devices_from_context) unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index de01f4c05983..0503e58b2e45 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -96,48 +96,6 @@ logger = logging.getLogger(__name__) -def _find_arg_mismatch(arg_list, fails, fun_name): - mismatched_args_msg = [] - def mismatch(err): - for name, inp_da, aval in arg_list: - if err.m_type == pxla.MismatchType.ARG_SHARDING and err.da == inp_da: - mismatched_args_msg.append( - f"argument {name} of {fun_name} with shape {aval.str_short()} and " - f"{err._dev_ids_plat_str}") - break - first_err, second_err = fails - mismatch(first_err) - mismatch(second_err) - return mismatched_args_msg - - -def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, - arg_names): - arg_list = [] - if arg_names is None: - arg_names = [''] * len(args_flat) - for a, n in zip(args_flat, arg_names): - da = 
(a.sharding._device_assignment - if getattr(a, 'sharding', None) is not None else None) - arg_list.append((n, da, core.shaped_abstractify(a))) - - mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) - - if len(mismatched_args_msg) == 2: - first, second = mismatched_args_msg # pytype: disable=bad-unpacking - extra_msg = f" Got {first} and {second}" - elif len(mismatched_args_msg) == 1: - first, second = fails - # Choose the failure left which is not already covered by ARG_SHARDING. - left = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first - extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}" - else: - first, second = fails - extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}" - msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}") - return msg - - class PjitInfo(NamedTuple): """Things that we know about a jit instance before it is called. @@ -197,10 +155,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): out_flat = pjit_p.bind(*args_flat, **p.params) compiled = None profiler = None - except pxla.DeviceAssignmentMismatchError as e: + except stages.DeviceAssignmentMismatchError as e: fails, = e.args fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) - msg = _device_assignment_mismatch_error( + msg = stages._device_assignment_mismatch_error( fun_name, fails, args_flat, 'jit', p.arg_names) raise ValueError(msg) from None except xla.InvalidInputException as e: @@ -1740,7 +1698,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if isinstance(arg_s, PmapSharding): continue if getattr(a, '_committed', True): - committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) + committed_arg_shardings.append((arg_s, stages.MismatchType.ARG_SHARDING, None)) resolved_in_shardings: list[PjitSharding] = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 3c5d710f3bdc..d92d1ccb2aa3 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,18 +30,20 @@ """ from __future__ import annotations +import dataclasses +import enum import functools from collections.abc import Sequence from dataclasses import dataclass from typing import Any, NamedTuple, Protocol, Union, runtime_checkable -import jax - from jax._src import core from jax._src import config +from jax._src import sharding as sharding_lib from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.layout import Layout @@ -79,7 +81,7 @@ def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: """Optionally constructs a fast c++ dispatcher.""" return None - def input_shardings(self) -> Sequence[jax.sharding.Sharding]: + def input_shardings(self) -> Sequence[sharding_lib.Sharding]: """Flat sequence of input shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, @@ -88,7 +90,7 @@ def input_shardings(self) -> Sequence[jax.sharding.Sharding]: raise NotImplementedError( "compiled executable carries no input sharding information") - def output_shardings(self) -> Sequence[jax.sharding.Sharding]: + def output_shardings(self) -> Sequence[sharding_lib.Sharding]: """Flat sequence of output shardings. May raise ``NotImplementedError`` if unavailable, e.g. 
based on backend, @@ -310,8 +312,8 @@ def dtype(self): @dataclass(frozen=True) class OutInfo: shape: tuple[int, ...] - dtype: jax.typing.DTypeLike - sharding: jax.sharding.Sharding | None = None + dtype: typing.DTypeLike + sharding: sharding_lib.Sharding | None = None class Stage: @@ -689,9 +691,6 @@ def out_info(self): def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, _private_parameters: mlir.LoweringParameters | None = None): """Lower to compiler input, returning a ``Lowered`` instance.""" - from jax._src.interpreters import pxla - from jax._src import pjit - if _private_parameters is None: _private_parameters = mlir.LoweringParameters() new_callable = functools.partial( @@ -699,9 +698,9 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, lowering_parameters=_private_parameters) try: lowering = new_callable() - except pxla.DeviceAssignmentMismatchError as e: + except DeviceAssignmentMismatchError as e: fails, = e.args - msg = pjit._device_assignment_mismatch_error( + msg = _device_assignment_mismatch_error( self.fun_name, fails, self._args_flat, 'jit', self._arg_names) raise ValueError(msg) from None return Lowered(lowering, self.args_info, self._out_tree) @@ -745,3 +744,108 @@ def lower(self, *args, **kwargs) -> Lowered: A ``Lowered`` instance representing the lowering. """ raise NotImplementedError + + +class MismatchType(enum.Enum): + ARG_SHARDING = 0 + OUT_SHARDING = 1 + SHARDING_INSIDE_COMPUTATION = 2 + CONTEXT_DEVICES = 3 + IN_SHARDING = 4 + + def __str__(self): + if self.name == 'IN_SHARDING': + return 'explicit input sharding' + elif self.name == 'OUT_SHARDING': + return 'explicit output sharding' + elif self.name == 'CONTEXT_DEVICES': + return 'context mesh' + return f'{self.name}' + + +class SourceInfo(NamedTuple): + source_info: source_info_util.SourceInfo + eqn_name: str + + +@dataclasses.dataclass +class DeviceAssignmentMismatch: + da: Sequence[xc.Device] + m_type: MismatchType + source_info: SourceInfo | None + + @property + def device_ids(self) -> Sequence[int]: + return [d.id for d in self.da] + + @property + def platform(self) -> str: + return self.da[0].platform.upper() + + def _maybe_api_name(self, api_name) -> str: + return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else "" + + @property + def source_info_str(self): + return ( + "" if self.source_info is None + else f" at {source_info_util.summarize(self.source_info.source_info)}" + ) + + @property + def _dev_ids_plat_str(self): + return f"device ids {self.device_ids} on platform {self.platform}" + + def m_type_str(self, api_name): + return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}' + if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) + + def _str(self, api_name): + return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with " + f"{self._dev_ids_plat_str}{self.source_info_str}") + + +class DeviceAssignmentMismatchError(Exception): + pass + + +def _find_arg_mismatch(arg_list, fails, fun_name): + mismatched_args_msg = [] + def mismatch(err): + for name, inp_da, aval in arg_list: + if err.m_type == MismatchType.ARG_SHARDING and err.da == inp_da: + mismatched_args_msg.append( + f"argument {name} of {fun_name} with shape {aval.str_short()} and " + f"{err._dev_ids_plat_str}") + break + first_err, second_err = fails + mismatch(first_err) + mismatch(second_err) + return mismatched_args_msg + + +def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, + arg_names): + arg_list = 
[] + if arg_names is None: + arg_names = [''] * len(args_flat) + for a, n in zip(args_flat, arg_names): + da = (a.sharding._device_assignment + if getattr(a, 'sharding', None) is not None else None) + arg_list.append((n, da, core.shaped_abstractify(a))) + + mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) + + if len(mismatched_args_msg) == 2: + first, second = mismatched_args_msg # pytype: disable=bad-unpacking + extra_msg = f" Got {first} and {second}" + elif len(mismatched_args_msg) == 1: + first, second = fails + # Choose the failure left which is not already covered by ARG_SHARDING. + left = second if first.m_type == MismatchType.ARG_SHARDING else first + extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}" + else: + first, second = fails + extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}" + msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}") + return msg From efc70a06e29c6f4a6f18d86e261a2fe1546572a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 21 May 2025 15:31:09 -0700 Subject: [PATCH 1292/1769] [Mosaic:TPU][Relayout] Remove minor implicit dimension for 32-bit native tiling PiperOrigin-RevId: 761692972 --- .../dialect/tpu/transforms/apply_vector_layout.cc | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 4d1323d68057..99134b46315f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6887,7 +6887,7 @@ FailureOr>> changeTiling( FailureOr>> changeImplicitDim( RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, - const VectorLayout src, xla::Array vregs, + VectorLayout src, xla::Array vregs, const VectorLayout::ImplicitDim dst_implicit_dim, const LayoutOffsets dst_offset_hints) { const auto &target_shape = ctx.target_shape; @@ -7032,6 +7032,18 @@ FailureOr>> changeImplicitDim( src.tiling(), VectorLayout::ImplicitDim::kSecondMinor); return std::make_pair(dst, std::move(dst_vregs)); } + if (src.implicit_dim() == VectorLayout::ImplicitDim::kMinor && + dst_implicit_dim == VectorLayout::ImplicitDim::kNone && + src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) { + FAILUREOR_ASSIGN_OR_RETURN( + std::tie(src, vregs), + changeImplicitDim(ctx, builder, loc, vty, src, std::move(vregs), + VectorLayout::ImplicitDim::kSecondMinor, + dst_offset_hints)); + return changeImplicitDim(ctx, builder, loc, vty, src, std::move(vregs), + VectorLayout::ImplicitDim::kNone, + dst_offset_hints); + } return emitError(loc, "Not implemented: Unsupported implicit dim change: from ") << src << " to " << dst_implicit_dim; From 61a9bd2b3d58640452c9f3c8514736cff9ac4cfe Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 21 May 2025 16:14:20 -0700 Subject: [PATCH 1293/1769] Allow eval_shape to propagate shardings if the aval has shardings in full explicit mode PiperOrigin-RevId: 761708753 --- jax/_src/pjit.py | 16 +++++++++++----- tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0503e58b2e45..6340d96a55ee 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -350,11 +350,17 @@ def jit_lower(jit_func, *args, **kwargs): @api_boundary def jit_eval_shape(jit_func, *args, **kwargs): p, _ = _infer_params(jit_func._fun, 
jit_func._jit_info, args, kwargs) - out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] - # TODO(yashkatariya): Add `Layout` to SDS. - out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, - weak_type=x.weak_type) - for x, s in zip(p.params['jaxpr'].out_avals, out_s)] + out_shardings = [None if isinstance(s, UnspecifiedValue) else s + for s in p.params['out_shardings']] + out = [] + for a, out_s in zip(p.params['jaxpr'].out_avals, out_shardings): + if out_s is None: + s = a.sharding if a.sharding.mesh._are_all_axes_explicit else out_s + else: + s = out_s + # TODO(yashkatariya): Add `Layout` to SDS. + out.append(api.ShapeDtypeStruct(a.shape, a.dtype, sharding=s, + weak_type=a.weak_type)) return tree_unflatten(p.out_tree, out) def jit_evict_fn(self): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 523601691a97..aa2e0af2a57c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7857,6 +7857,20 @@ def g(x, y): core.ShardingTypeError, "lhs is unreduced while rhs is not"): g.trace(x, y) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_eval_shape(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @jax.jit + def f(x): + return x * 2 + + out = jax.eval_shape(f, arr) + self.assertIsInstance(out, jax.ShapeDtypeStruct) + self.assertEqual(out.sharding, + NamedSharding(mesh.abstract_mesh, P('x', 'y'))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From f9a1475e1a286a8d21ec8685aea9f92b2c40d941 Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Wed, 21 May 2025 15:43:49 -0700 Subject: [PATCH 1294/1769] Expose tree prefix broadcasting as a public API in tree utils. --- CHANGELOG.md | 3 +++ docs/jax.tree.rst | 1 + docs/jax.tree_util.rst | 1 + jax/_src/tree.py | 31 +++++++++++++++++++++++++++++++ jax/_src/tree_util.py | 35 ++++++++++++++++++++++++++++++----- jax/tree.py | 1 + jax/tree_util.py | 1 + tests/tree_util_test.py | 40 ++++++++++++++++++++++++++++++++++++++++ 8 files changed, 108 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e866fae6af5..b34bf36997af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* New features: + * Added {func}`jax.tree.broadcast` which implements a pytree prefix broadcasting helper. + ## JAX 0.6.1 (May 21, 2025) * New features: diff --git a/docs/jax.tree.rst b/docs/jax.tree.rst index e65c77c757c1..1a0ddaec86d0 100644 --- a/docs/jax.tree.rst +++ b/docs/jax.tree.rst @@ -12,6 +12,7 @@ List of Functions :toctree: _autosummary all + broadcast flatten flatten_with_path leaves diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index 73fd1f376e9f..c89b777ca548 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -38,6 +38,7 @@ These APIs are now accessed via :mod:`jax.tree`. :toctree: _autosummary tree_all + tree_broadcast tree_flatten tree_leaves tree_map diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 70d75a126804..9a3e001d902b 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -378,3 +378,34 @@ def map_with_path( - :func:`jax.tree_util.register_pytree_with_keys` """ return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf) + + +def broadcast(prefix_tree: Any, full_tree: Any, + is_leaf: Callable[[Any], bool] | None = None + ) -> list[Any]: + """Broadcasts a tree prefix into the full structure of a given tree. 
+ + Args: + prefix_tree: a pytree that is a tree prefix of full_tree. + full_tree: a pytree with the structure to broadcast the prefix leaves into. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, with true stopping the + traversal and the whole subtree being treated as a leaf, and false + indicating the flattening should traverse the current object. + + Returns: + A pytree matching the structure of full_tree where the leaves of prefix_tree have been + broadcasted into the leaves of each corresponding subtree. + + Examples: + >>> import jax + >>> prefix = (1, 2, 3) + >>> full = (0, {'a': 0, 'b': 0}, (0, 0)) + >>> jax.tree.broadcast(prefix, full) + (1, {'a': 2, 'b': 2}, (3, 3)) + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.structure` + """ + return tree_util.tree_broadcast(prefix_tree, full_tree, is_leaf=is_leaf) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index e2e97c90f120..6edbbfd62d12 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -560,17 +560,42 @@ def __new__(klass, func, *args, **kw): ) -# broadcast_prefix is not exported. +@export +def tree_broadcast(prefix_tree: Any, full_tree: Any, + is_leaf: Callable[[Any], bool] | None = None + ) -> list[Any]: + """Alias of :func:`jax.tree.broadcast`.""" + broadcast_leaves = broadcast_prefix(prefix_tree, full_tree, is_leaf=is_leaf) + return tree_structure(full_tree).unflatten(broadcast_leaves) + + +# broadcast_prefix is not exported def broadcast_prefix(prefix_tree: Any, full_tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[Any]: - # If prefix_tree is not a tree prefix of full_tree, this code can raise a - # ValueError; use prefix_errors to find disagreements and raise more precise - # error messages. + """Broadcasts tree prefix leaves into the full set of leaves for a given full tree. + + Args: + prefix_tree: a pytree that is a tree prefix of full_tree. + full_tree: a pytree with the structure to broadcast the prefix leaves into. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, with true stopping the + traversal and the whole subtree being treated as a leaf, and false + indicating the flattening should traverse the current object. + + Returns: + A list of leaves matching the expected count for the full tree, + with the leaf of each prefix tree being duplicated to match the count of + its corresponding subtree. 
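+
+  Examples:
+    >>> broadcast_prefix((1, 2), ({'a': 0}, [0, 0]))
+    [1, 2, 2]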
+ """ result = [] num_leaves = lambda t: tree_structure(t).num_leaves add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree)) - tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) + try: + tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) + except ValueError: + e, *_ = prefix_errors(prefix_tree, full_tree) + raise e('broadcast_prefix prefix_tree') from None return result diff --git a/jax/tree.py b/jax/tree.py index 270c34fe9647..03ca503f3a41 100644 --- a/jax/tree.py +++ b/jax/tree.py @@ -19,6 +19,7 @@ from jax._src.tree import ( all as all, + broadcast as broadcast, flatten_with_path as flatten_with_path, flatten as flatten, leaves_with_path as leaves_with_path, diff --git a/jax/tree_util.py b/jax/tree_util.py index 9f42284144ec..b35890dfc887 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -58,6 +58,7 @@ register_pytree_with_keys as register_pytree_with_keys, register_static as register_static, tree_all as tree_all, + tree_broadcast as tree_broadcast, tree_flatten_with_path as tree_flatten_with_path, tree_flatten as tree_flatten, tree_leaves_with_path as tree_leaves_with_path, diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 0df811d9da28..8d4cd5854e7d 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -627,6 +627,39 @@ def testTransposeWithCustomObject(self): FlatCache({"a": [3, 4], "b": [5, 6]})) self.assertEqual(expected, actual) + @parameterized.parameters(*TREES) + def testBroadcast(self, tree): + if isinstance(tree, FlatCache): + # The tree_map construction below fails for FlatCache, because + # the cached metadata becomes out of sync. + self.skipTest("Test does not work properly for FlatCache.") + def make_inner(x): + return [x, x, x] + nested = tree_util.tree_map(make_inner, tree) + actual = tree_util.tree_broadcast(tree, nested) + self.assertEqual(actual, nested) + + def testBroadcastSimple(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + actual = tree_util.tree_broadcast(prefix, full) + expected = (1, {'a': 2, 'b': 2}, (3, 3)) + self.assertEqual(actual, expected) + + def testBroadcastError(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + prefix = (1, 2) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + prefix = (1, {'a': 0}) + full = (0, {'a': 0, 'b': 0}) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)]) def testStringRepresentation(self, tree, correct_string): """Checks that the string representation of a tree works.""" @@ -1444,6 +1477,13 @@ def test_tree_transpose(self): tree_util.tree_transpose(outer_treedef, inner_treedef, obj) ) + def test_tree_broadcast(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + actual = jax.tree.broadcast(prefix, full) + expected = (1, {'a': 2, 'b': 2}, (3, 3)) + self.assertEqual(actual, expected) + def test_tree_unflatten(self): leaves, treedef = jax.tree.flatten([1, 2, (3, 4)]) self.assertEqual( From 8da86ea0a3128d2cff517251f40a5b3e285f4d7b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 21 May 2025 16:53:35 -0700 Subject: [PATCH 1295/1769] Move _src/interpreters/ad.py to its own BUILD rule. 
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This required moving some internal utilities out of dispatch.py, which is part of the main JAX build rule. I chose api_util.py because they seem to fit there. PiperOrigin-RevId: 761722054 --- jax/BUILD | 19 +++++++++++++++++- jax/_src/api.py | 7 ++++--- jax/_src/api_util.py | 38 +++++++++++++++++++++++++++++++++++ jax/_src/dispatch.py | 40 ++----------------------------------- jax/_src/interpreters/ad.py | 18 +++++------------ jax/_src/lax/parallel.py | 7 +++++++ jax/_src/pjit.py | 8 ++++---- jax/_src/shard_map.py | 4 ++-- jax/extend/BUILD | 1 + 9 files changed, 81 insertions(+), 61 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 4218cd3f0a77..e431a3c5056c 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -316,7 +316,6 @@ py_library_providing_imports_info( "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", - "_src/interpreters/ad.py", "_src/interpreters/batching.py", "_src/interpreters/pxla.py", "_src/pjit.py", @@ -381,6 +380,7 @@ py_library_providing_imports_info( visibility = ["//visibility:public"], deps = [ ":abstract_arrays", + ":ad", ":ad_util", ":api_util", ":basearray", @@ -671,6 +671,23 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "ad", + srcs = ["_src/interpreters/ad.py"], + deps = [ + ":ad_util", + ":api_util", + ":config", + ":core", + ":dtypes", + ":mesh", + ":partial_eval", + ":source_info_util", + ":tree_util", + ":util", + ], +) + pytype_strict_library( name = "mlir", srcs = ["_src/interpreters/mlir.py"], diff --git a/jax/_src/api.py b/jax/_src/api.py index 059db1c92c98..3ff103997dc7 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -37,6 +37,7 @@ import numpy as np from contextlib import contextmanager +from jax._src import api_util from jax._src import deprecations from jax._src import linear_util as lu from jax._src import stages @@ -113,14 +114,14 @@ def _nan_check_posthook(fun, args, kwargs, output): try: dispatch.check_special(pjit.pjit_p.name, buffers) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: assert config.debug_nans.value or config.debug_infs.value if hasattr(fun, '_fun'): f = fun._fun if getattr(f, '_apply_primitive', False): raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None # compiled_fun can only raise in this case - dispatch.maybe_recursive_nan_check(e, f, args, kwargs) + api_util.maybe_recursive_nan_check(e, f, args, kwargs) raise AssertionError("Unreachable") from e else: # TODO(emilyaf): Shouldn't need this fallback. 
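
For context, the helpers being moved here implement JAX's `jax_debug_nans`
re-run machinery. A minimal sketch of the user-visible behavior (standard
config flag; the example values are illustrative):

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_debug_nans", True)
    # The jitted call produces a nan, so JAX re-runs the function un-jitted
    # (maybe_recursive_nan_check above) to point the error at its source.
    jax.jit(jnp.log)(-1.0)  # raises FloatingPointError
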
@@ -1707,7 +1708,7 @@ def cache_miss(*args, **kwargs): out = execute(*p.flat_args) else: out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.') out_tree, out_flat = p.out_tree, out diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 163bade2065c..2e7ba551c624 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -767,3 +767,41 @@ def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> Non f"array reference of type {a.str_short()} was both closed over and " f"passed as the argument " f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}") + +class InternalFloatingPointError(Exception): + name: str + ty: str + + def __init__(self, name: str, ty: str): + self.name = name + self.ty = ty + +def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs, +) -> None: # always raises an exception + print("Invalid nan value encountered in the output of a jax.jit " + "function. Calling the de-optimized version.") + try: + _ = fun(*args, **kwargs) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + _raise_no_nan_in_deoptimized(e) + + +def _raise_no_nan_in_deoptimized(e) -> None: + msg = (f"{str(e)}. Because " + "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " + "de-optimized function (i.e., the function as if the `jit` " + "decorator were removed) was called in an attempt to get a more " + "precise error message. However, the de-optimized function did not " + "produce invalid values during its execution. This behavior can " + "result from `jit` optimizations causing the invalid value to be " + "produced. It may also arise from having nan/inf literals as " + "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " + "\n\n" + "It may be possible to avoid the invalid value by removing the " + "`jit` decorator, at the cost of losing optimizations. " + "\n\n" + "If you see this error, consider opening a bug report at " + "https://github.com/jax-ml/jax.") + raise FloatingPointError(msg) from None diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 9a11ffa104a8..d1ea7439cb0c 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -24,7 +24,7 @@ import logging import threading import time -from typing import Any, Callable +from typing import Any import jax from jax._src import api @@ -42,6 +42,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla +from jax._src.api_util import InternalFloatingPointError from jax._src.layout import DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.mesh import AbstractMesh, Mesh @@ -341,43 +342,6 @@ class CopySemantics(enum.Enum): COPY = enum.auto() DONATE = enum.auto() -class InternalFloatingPointError(Exception): - name: str - ty: str - - def __init__(self, name: str, ty: str): - self.name = name - self.ty = ty - -def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs, -) -> None: # always raises an exception - print("Invalid nan value encountered in the output of a jax.jit " - "function. 
Calling the de-optimized version.") - try: - _ = fun(*args, **kwargs) - except (FloatingPointError, ZeroDivisionError) as e2: - raise e2 from None - else: - _raise_no_nan_in_deoptimized(e) - -def _raise_no_nan_in_deoptimized(e) -> None: - msg = (f"{str(e)}. Because " - "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " - "de-optimized function (i.e., the function as if the `jit` " - "decorator were removed) was called in an attempt to get a more " - "precise error message. However, the de-optimized function did not " - "produce invalid values during its execution. This behavior can " - "result from `jit` optimizations causing the invalid value to be " - "produced. It may also arise from having nan/inf literals as " - "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " - "\n\n" - "It may be possible to avoid the invalid value by removing the " - "`jit` decorator, at the cost of losing optimizations. " - "\n\n" - "If you see this error, consider opening a bug report at " - "https://github.com/jax-ml/jax.") - raise FloatingPointError(msg) from None - def _identity_fn(x): return x diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 29af03416a76..9366b91f8022 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -21,12 +21,12 @@ from functools import partial from typing import Any +from jax._src import api_util from jax._src import config -from jax._src import dispatch from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax.tree_util import (tree_flatten, tree_unflatten, - register_pytree_node, Partial, PyTreeDef) +from jax._src.tree_util import (tree_flatten, tree_unflatten, + register_pytree_node, Partial, PyTreeDef) from jax._src import mesh as mesh_lib from jax._src import core from jax._src import source_info_util @@ -1125,7 +1125,7 @@ def out_axes_thunk(): try: out_flat = primitive.bind(fun, *all_args, **new_params) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: print("Invalid nan value encountered in the backward pass of a jax.jit " "function. Calling the de-optimized backward pass.") try: @@ -1135,7 +1135,7 @@ def out_axes_thunk(): else: # If control reaches this line, we got a NaN on the output of `compiled` # but not `fun.call_wrapped` on the same arguments. Let's tell the user. - dispatch._raise_no_nan_in_deoptimized(e) + api_util._raise_no_nan_in_deoptimized(e) arg_cts = tree_unflatten(out_tree(), out_flat) # The freevars are being fanned out (not mapped). 
During transpose the @@ -1266,11 +1266,3 @@ def __init__(self): # TODO(mattjj): remove this vestigial dict reducing_transposes: dict[core.Primitive, Callable] = {} - -########################### pvary ################################## - -def _pvary_transpose_rule(cts, *_, axes, axis_index_groups): - from jax._src.lax import parallel as lax_parallel - return lax_parallel.psum_invariant_p.bind( - *cts, axes=axes, axis_index_groups=axis_index_groups) -deflinear2(core.pvary_p, _pvary_transpose_rule) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index a9abf8f12939..bf27261a2c8e 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -2059,3 +2059,10 @@ def _psum_invariant_transpose_rule(cts, *args, axes, axis_index_groups): del args return core.pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule) + +########################### pvary ################################## + +def _pvary_transpose_rule(cts, *_, axes, axis_index_groups): + return psum_invariant_p.bind( + *cts, axes=axes, axis_index_groups=axis_index_groups) +ad.deflinear2(core.pvary_p, _pvary_transpose_rule) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6340d96a55ee..0624dad88a2b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -176,10 +176,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): f"Argument '{name}' of shape {aval.str_short()} of type" f' {type(arg)} is not a valid JAX type.') from e raise AssertionError("Unreachable") from e - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: if getattr(fun, '_apply_primitive', False): raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None - dispatch.maybe_recursive_nan_check(e, fun, args, kwargs) + api_util.maybe_recursive_nan_check(e, fun, args, kwargs) if p.box_data: box_treedef, out_tree = p.out_tree.children() @@ -2562,7 +2562,7 @@ def prune_type(ty, xs, maybe_zeros): keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: print("Invalid nan value encountered in the backward pass of a jax.jit " "function. Calling the de-optimized backward pass.") try: @@ -2572,7 +2572,7 @@ def prune_type(ty, xs, maybe_zeros): else: # If control reaches this line, we got a NaN on the output of `compiled` # but not `fun.call_wrapped` on the same arguments. Let's tell the user. 
- dispatch._raise_no_nan_in_deoptimized(e) + api_util._raise_no_nan_in_deoptimized(e) if attrs_tracked: final_states, nz_cts_out = split_list(nz_cts_out, [num_attr_outs]) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 72e1420b0b2b..66df2505100c 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1054,7 +1054,7 @@ def _maybe_check_special(outs): for s in getattr(leaf, 'addressable_shards', [])] try: dispatch.check_special('shard_map', bufs) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None class ShardMapTrace(core.Trace): @@ -1562,7 +1562,7 @@ def new_out_specs_thunk(): except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: - dispatch._raise_no_nan_in_deoptimized(e) + api_util._raise_no_nan_in_deoptimized(e) return tree_unflatten(out_tree(), out_flat) ad.primitive_transposes[shard_map_p] = _shard_map_transpose diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 59958c1da389..06fb8e671120 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -43,6 +43,7 @@ py_library_providing_imports_info( deps = [ "//jax", "//jax:abstract_arrays", + "//jax:ad", "//jax:ad_util", "//jax:core", ], From 62c46ff976db981912d2182760d4f2abffb972fa Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 21 May 2025 17:23:52 -0700 Subject: [PATCH 1296/1769] Add initial support aligned jnp.swapaxes on major/minor dims Next steps: - non-tile aligned - Clean up fn and utilize it for general changeTiling PiperOrigin-RevId: 761731600 --- jax/_src/pallas/mosaic/lowering.py | 4 +- .../tpu/transforms/apply_vector_layout.cc | 308 +++++++++++++++++- .../tpu/transforms/infer_vector_layout.cc | 28 +- tests/pallas/ops_test.py | 48 ++- 4 files changed, 375 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 919c548adc9b..dd0b9ba4b4a7 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2383,7 +2383,9 @@ def _gather_lowering_rule( @register_lowering_rule(lax.transpose_p) def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): - if permutation != (1, 0): + minormost_transpose = (1, 0) + untiled_tiled_swap = (1, 0, 2) + if permutation not in (minormost_transpose, untiled_tiled_swap): raise NotImplementedError out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 99134b46315f..6502a9c6682e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5013,11 +5013,315 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, ctx.target_shape)); ArrayRef permutation = transpose_op.getPermutation(); const auto tile_perm = permutation.take_back(2); + + // Major minor pemute if (tile_perm != ArrayRef{rank - 2, rank - 1} && tile_perm != ArrayRef{rank - 1, rank - 2}) { - return transpose_op->emitOpError( - "Not implemented: Unsupported permutation"); + // This is a 3 stage algorithm that uses combinations and shuffles + // to do a transposition of an 8x8 block of sublanes. + // In the following algorithm description, A, B, ..., H represent 8 + // distinct input vregs that form an 8x8 block of data + // to be transposed. 
In our notation, B2 identifies the third + // sublane (2) of the second vreg (B)". + // + // + // If we think of each starting input vreg as a row in an 8x8 block of + // elements: + // A: A0 A1 A2 A3 A4 A5 A6 A7 + // B: B0 B1 B2 B3 B4 B5 B6 B7 + // ... + // H: H0 H1 H2 H3 H4 H5 H6 H7 + // + // The goal is to transpose this block, so the output vregs are: + // out0: A0 B0 C0 D0 E0 F0 G0 H0 + // out1: A1 B1 C1 D1 E1 F1 G1 H1 + // ... + // out7: A7 B7 C7 D7 E7 F7 G7 H7 + // + // Stage 1: Operates on pairs of input vregs (e.g., A and B). + // + // Input to Stage 1 (example pair A, B): + // A: A0 A1 A2 A3 A4 A5 A6 A7 + // B: B0 B1 B2 B3 B4 B5 B6 B7 + // + // Step 1.1: Combine low/high halves. + // combine_low(A, B) -> CL_AB: [A0 A1 A2 A3 | B0 B1 B2 B3] (8 elements) + // combine_high(A, B) -> CH_AB: [A4 A5 A6 A7 | B4 B5 B6 B7] (8 elements) + // (Notation: '|' separates the 4 elements from A and 4 from B) + // + // Step 1.2: Shuffle. + // The shuffle pattern for the low part (applied to CL_AB using + // `shuffle(CL_AB, CH_AB, pattern)`) is {0, 4, 1, 5, 2, 6, 3, 7}. + // The shuffle pattern for the high part (applied to CH_AB using + // `shuffle(CL_AB, CH_AB, pattern)`) is {8, 12, 9, 13, 10, 14, 11, 15}. + // (Indices 0-7 in shuffle refer to CL_AB, 8-15 to CH_AB). + // This results in: + // s1_AB_0: A0 B0 A1 B1 A2 B2 A3 B3 (from shuffling CL_AB elements) + // s1_AB_1: A4 B4 A5 B5 A6 B6 A7 B7 (from shuffling CH_AB elements) + // + // Output of Stage 1 / Input to Stage 2 (example for A,B,C,D processing): + // s1_vregs[0] (from A,B): A0 B0 A1 B1 A2 B2 A3 B3 + // s1_vregs[1] (from A,B): A4 B4 A5 B5 A6 B6 A7 B7 + // s1_vregs[2] (from C,D): C0 D0 C1 D1 C2 D2 C3 D3 + // s1_vregs[3] (from C,D): C4 D4 C5 D5 C6 D6 C7 D7 + // ... (and so on for E,F,G,H into s1_vregs[4-7]) + + // Stage 2: Operates on groups of 4 vregs from Stage 1 output. + // (e.g., s1_vregs[0], s1_vregs[1], s1_vregs[2], s1_vregs[3]) + // + // Input to Stage 2 (example processing s1_vregs[0] and s1_vregs[2]): + // X = s1_vregs[0] = [A0 B0 A1 B1 | A2 B2 A3 B3] + // Y = s1_vregs[2] = [C0 D0 C1 D1 | C2 D2 C3 D3] + // + // Step 2.1: Combine low/high halves. + // combine_low(X, Y) -> CL_XY: [A0 B0 A1 B1 | C0 D0 C1 D1] + // combine_high(X, Y) -> CH_XY: [A2 B2 A3 B3 | C2 D2 C3 D3] + // + // (Similarly for s1_vregs[1] and s1_vregs[3], let them be X' and Y') + // combine_low(X', Y') -> CL_X'Y': [A4 B4 A5 B5 | C4 D4 C5 D5] + // combine_high(X', Y') -> CH_X'Y': [A6 B6 A7 B7 | C6 D6 C7 D7] + // + // Step 2.2: Shuffle. + // The shuffle pattern for the low part (e.g., applied to CL_XY) is {0, 1, + // 4, 5, 2, 3, 6, 7}. The shuffle pattern for the high part (e.g., applied + // to CH_XY, effectively) is {8, 9, 12, 13, 10, 11, 14, 15}. + // + // This results in (for the first group of 4 input vregs A,B,C,D): + // s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1 (from shuffling CL_XY elements) + // s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3 (from shuffling CH_XY elements) + // s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5 (from shuffling CL_X'Y' elements) + // s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7 (from shuffling CH_X'Y' elements) + // + // Output of Stage 2 / Input to Stage 3: + // s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1 + // s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3 + // s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5 + // s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7 + // s2_vregs[4]: E0 F0 G0 H0 E1 F1 G1 H1 (from E,F,G,H processing) + // s2_vregs[5]: E2 F2 G2 H2 E3 F3 G3 H3 + // s2_vregs[6]: E4 F4 G4 H4 E5 F5 G5 H5 + // s2_vregs[7]: E6 F6 G6 H6 E7 F7 G7 H7 + + // Stage 3: Combine results from Stage 2. 
+    // No shuffle needed after combine.
+    // Input to Stage 3 (example for the first two rows of the final
+    // transpose):
+    // L = s2_vregs[0] = [A0 B0 C0 D0 | A1 B1 C1 D1]
+    // R = s2_vregs[4] = [E0 F0 G0 H0 | E1 F1 G1 H1]
+    //
+    // Step 3.1: Combine low/high halves.
+    // combine_low(L, R)  -> [A0 B0 C0 D0 | E0 F0 G0 H0] ->
+    //   Final out0: A0 B0 C0 D0 E0 F0 G0 H0
+    // combine_high(L, R) -> [A1 B1 C1 D1 | E1 F1 G1 H1] ->
+    //   Final out1: A1 B1 C1 D1 E1 F1 G1 H1
+    // ... and so on for other pairs from Stage 2 output
+    // (e.g. L=s2_vregs[1], R=s2_vregs[5]).
+    //
+    // This results in the correctly transposed 8x8 block.
+
+    constexpr int64_t kMajorDimOriginalIdx = 0;
+    constexpr int64_t kSecondMinorDimOriginalIdx = 1;
+    constexpr int64_t kMinorMostDimOriginalIdx = 2;
+
+    auto vec_shape = src_ty.getShape();
+    auto major_dim_size = vec_shape[kMajorDimOriginalIdx];
+    auto second_minor_dim_size = vec_shape[kSecondMinorDimOriginalIdx];
+
+    if (layout_in.offsets() != LayoutOffsets{0, 0}) {
+      return transpose_op.emitOpError("Not implemented: Layout with offset.");
+    }
+    if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+      return transpose_op.emitOpError(
+          "Not implemented: Layout with implicit dimension.");
+    }
+
+    auto sublane_count = ctx.target_shape[0];
+    if (second_minor_dim_size % sublane_count != 0 ||
+        major_dim_size % sublane_count != 0) {
+      return transpose_op.emitOpError(
+          "Not implemented: Swapping major and second minor dimensions must "
+          "result in dimension sizes that are multiples of sublane_count.");
+    }
+
+    if (!layout_in.hasNativeTiling(ctx.target_shape)) {
+      return transpose_op.emitOpError(
+          "Not implemented: Expected native input tiling.");
+    }
+    if (layout_in != layout_out) {
+      return transpose_op.emitOpError(
+          "Not implemented: Expected same input and output layouts.");
+    }
+    xla::Array<Value> dst_vregs(
+        layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape));
+
+    if (layout_in.bitwidth() != 32) {
+      return transpose_op.emitOpError(
+          "Not implemented: Major-second-minor transpose only supported for "
+          "32-bit vectors. Also, input must be a vector type.");
+    }
+    if (ctx.target_shape[0] != 8) {
+      return transpose_op.emitOpError(
+          "Not implemented: Major-second-minor transpose expects 8 sublanes.");
+    }
+
+    auto vreg_dimensions = src_vregs.dimensions();
+    // Note(mvoz): Slice is a weird word here. This is used for constructing
+    // the output vregs - the reason we divide here is that we multiply it
+    // back later on to get the correct index into src_vregs, but the reason
+    // we cannot just resolve that in our outer loop is the nature of a
+    // transpose - this dim value goes unmultiplied into the output vregs.
+    // effectively, our indexing:
+    // {major_dim_slice_idx * sublane_count, second_minor_dim_slice_idx,
+    //  minor_most_dim_slice_idx} becomes {second_minor_dim_slice_idx *
+    //  sublane_count, major_dim_slice_idx, minor_most_dim_slice_idx}
+    auto num_slices_in_major_dim =
+        vreg_dimensions[kMajorDimOriginalIdx] / sublane_count;
+    auto num_slices_in_second_minor_dim =
+        vreg_dimensions[kSecondMinorDimOriginalIdx];
+    auto num_slices_in_minor_most_dim =
+        vreg_dimensions[kMinorMostDimOriginalIdx];
+
+    auto shuffle = [&](Value lhs_vreg, Value rhs_vreg,
+                       ArrayRef<int32_t> pattern) {
+      auto lhs_vreg_type = lhs_vreg.getType();
+      auto pattern_attr = builder.getDenseI32ArrayAttr(pattern);
+      return builder
+          .create<tpu::SublaneShuffleOp>(transpose_op.getLoc(), lhs_vreg_type,
+                                         lhs_vreg, rhs_vreg, pattern_attr)
+          .getResult();
+    };
+
+    static constexpr std::array<int32_t, 8> combine_low_pattern = {
+        0, 1, 2, 3, 8, 9, 10, 11};
+    static constexpr std::array<int32_t, 8> combine_high_pattern = {
+        4, 5, 6, 7, 12, 13, 14, 15};
+
+    auto combine_low = [&](Value lhs_vreg, Value rhs_vreg) {
+      return shuffle(lhs_vreg, rhs_vreg, combine_low_pattern);
+    };
+    auto combine_high = [&](Value lhs_vreg, Value rhs_vreg) {
+      return shuffle(lhs_vreg, rhs_vreg, combine_high_pattern);
+    };
+
+    // Shuffle patterns for Stage 1
+    // Input to shuffle: (combine_low_val, combine_high_val)
+    // combine_low_val has A0-A3, B0-B3. Indices 0-7 for shuffle.
+    // combine_high_val has A4-A7, B4-B7. Indices 8-15 for shuffle.
+    static constexpr std::array<int32_t, 8> permute_pattern_stage1_low_arr = {
+        0, 4, 1, 5, 2, 6, 3, 7};  // Selects from combine_low_val to make
+                                  // A0B0A1B1A2B2A3B3
+    static constexpr std::array<int32_t, 8> permute_pattern_stage1_high_arr = {
+        8, 12, 9, 13, 10, 14, 11, 15};  // Selects from combine_high_val to
+                                        // make A4B4A5B5A6B6A7B7
+
+    // Shuffle patterns for Stage 2
+    // Input to shuffle: (CL_XY, CH_XY) from Step 2.1 in comments.
+    // CL_XY has A0B0A1B1C0D0C1D1. Indices 0-7 for shuffle.
+    // CH_XY has A2B2A3B3C2D2C3D3. Indices 8-15 for shuffle.
+    static constexpr std::array<int32_t, 8> permute_pattern_stage2_low_arr = {
+        0, 1, 4, 5, 2, 3, 6, 7};  // Selects from CL_XY to make
+                                  // A0B0C0D0A1B1C1D1
+    static constexpr std::array<int32_t, 8> permute_pattern_stage2_high_arr = {
+        8, 9, 12, 13, 10, 11, 14, 15};  // Selects from CH_XY to make
+                                        // A2B2C2D2A3B3C3D3
+
+    for (int major_dim_slice_idx = 0;
+         major_dim_slice_idx < num_slices_in_major_dim; ++major_dim_slice_idx) {
+      for (int second_minor_dim_slice_idx = 0;
+           second_minor_dim_slice_idx < num_slices_in_second_minor_dim;
+           ++second_minor_dim_slice_idx) {
+        for (int minor_most_dim_slice_idx = 0;
+             minor_most_dim_slice_idx < num_slices_in_minor_most_dim;
+             ++minor_most_dim_slice_idx) {
+          // STAGE 1!
+          std::array<Value, 8>
+              stage1_output_vregs;  // Stores s1_vregs from comments
+          constexpr int num_pairs_stage1 =
+              4;  // Processes 4 pairs of vregs (A,B), (C,D), (E,F), (G,H)
+
+          for (int i = 0; i < num_pairs_stage1; ++i) {
+            Value first_vreg = src_vregs(
+                {(2 * i) + (sublane_count * major_dim_slice_idx),
+                 second_minor_dim_slice_idx, minor_most_dim_slice_idx});
+            Value second_vreg = src_vregs(
+                {(2 * i) + (sublane_count * major_dim_slice_idx) + 1,
+                 second_minor_dim_slice_idx, minor_most_dim_slice_idx});
+
+            auto combined_low_val = combine_low(first_vreg, second_vreg);
+            auto combined_high_val = combine_high(first_vreg, second_vreg);
+
+            stage1_output_vregs[2 * i] =
+                shuffle(combined_low_val, combined_high_val,
+                        permute_pattern_stage1_low_arr);
+            stage1_output_vregs[2 * i + 1] =
+                shuffle(combined_low_val, combined_high_val,
+                        permute_pattern_stage1_high_arr);
+          }
+
+          // STAGE 2!
+          std::array<Value, 8>
+              stage2_output_vregs;  // Stores s2_vregs from comments
+          constexpr int num_pairs_stage2 =
+              4;  // Processes 4 pairs of vregs from stage1_output_vregs
+
+          for (int i = 0; i < num_pairs_stage2; ++i) {
+            // Determine the indices for the input pair from
+            // stage1_output_vregs. The 4 pairs processed in this stage are:
+            // i=0: (s1_vregs[0], s1_vregs[2])
+            // i=1: (s1_vregs[1], s1_vregs[3])
+            // i=2: (s1_vregs[4], s1_vregs[6])
+            // i=3: (s1_vregs[5], s1_vregs[7])
+            int s1_lhs_idx = (i / 2) * 4 + (i % 2);
+            int s1_rhs_idx = s1_lhs_idx + 2;
+
+            Value s1_lhs_vreg = stage1_output_vregs[s1_lhs_idx];
+            Value s1_rhs_vreg = stage1_output_vregs[s1_rhs_idx];
+
+            auto combined_low_val = combine_low(s1_lhs_vreg, s1_rhs_vreg);
+            auto combined_high_val = combine_high(s1_lhs_vreg, s1_rhs_vreg);
+
+            // Determine the output indices for stage2_output_vregs.
+            // Each pair from Stage 1 produces a pair of vregs for Stage 2.
+            // Results are stored pair-wise:
+            // i=0 -> s2_vregs[0], s2_vregs[1]
+            // i=1 -> s2_vregs[2], s2_vregs[3]
+            // i=2 -> s2_vregs[4], s2_vregs[5]
+            // i=3 -> s2_vregs[6], s2_vregs[7]
+            int s2_out_idx_base = 2 * i;
+
+            stage2_output_vregs[s2_out_idx_base] =
+                shuffle(combined_low_val, combined_high_val,
+                        permute_pattern_stage2_low_arr);
+            stage2_output_vregs[s2_out_idx_base + 1] =
+                shuffle(combined_low_val, combined_high_val,
+                        permute_pattern_stage2_high_arr);
+          }
+
+          // STAGE 3! Combine results from stage 2.
+          std::array<int64_t, 3> output_idx_parts{
+              second_minor_dim_slice_idx * sublane_count, major_dim_slice_idx,
+              minor_most_dim_slice_idx};
+
+          constexpr int num_final_combines =
+              4;  // Corresponds to s2_vregs[0]..s2_vregs[3] pairing with
+                  // s2_vregs[4]..s2_vregs[7]
+          for (int i = 0; i < num_final_combines; ++i) {
+            Value lhs = stage2_output_vregs[i];      // e.g., s2_ABCD_0
+            Value rhs = stage2_output_vregs[i + 4];  // e.g., s2_EFGH_0
+            auto final_combined_low = combine_low(lhs, rhs);
+            auto final_combined_high = combine_high(lhs, rhs);
+
+            dst_vregs(output_idx_parts) = final_combined_low;
+            output_idx_parts[0] += 1;
+            dst_vregs(output_idx_parts) = final_combined_high;
+            output_idx_parts[0] += 1;
+          }
+        }
+      }
+    }
+    auto assembled =
+        assemble(builder, dst_ty, layout_out, dst_vregs, ctx.target_shape);
+    transpose_op.getOperation()->replaceAllUsesWith(assembled);
+    transpose_op.erase();
+    return success();
+  }
+
   {
     SmallVector<int64_t> p(permutation);
     p[rank - 2] = rank - 2;
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index f01d0b4c5888..976e31cb55f4 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -1680,17 +1680,27 @@ class VectorLayoutInferer {
     auto src_ty = op.getSourceVectorType();
     TPU_CHECK_OP(permutation.size() == src_ty.getRank(),
                  "Transpose permutation has incorrect rank");
-    for (auto dim : permutation.drop_back(2)) {
-      TPU_CHECK_OP(dim < src_ty.getRank() - 2,
-                   "Unsupported transpose permutation - minor dims into major");
-    }
-    for (auto dim : permutation.take_back(2)) {
-      TPU_CHECK_OP(dim >= src_ty.getRank() - 2,
-                   "Unsupported transpose permutation - major dims into minor");
+    bool untiled_tiled_swap = false;
+    // TODO(mvoz): Expand to more general cases. b/419268277
+    if (permutation.size() == 3 && permutation[0] == 1 && permutation[1] == 0) {
+      untiled_tiled_swap = true;
+    } else {
+      for (auto dim : permutation.drop_back(2)) {
+        TPU_CHECK_OP(dim < src_ty.getRank() - 2,
+                     "Unsupported transpose permutation - minor dims into "
+                     "major > 3 dimensions");
+      }
+      for (auto dim : permutation.take_back(2)) {
+        TPU_CHECK_OP(dim >= src_ty.getRank() - 2,
+                     "Unsupported transpose permutation - major dims into "
+                     "minor > 3 dimensions");
+      }
     }
     Layout required_layout = some_layout;
-    // Require native tiling if we're going to use the XLU.
-    if (permutation[permutation.size() - 1] == permutation.size() - 2) {
+    // Require native tiling if we're going to use the XLU, or doing a
+    // major/minor permute.
+    if (untiled_tiled_swap ||
+        permutation[permutation.size() - 1] == permutation.size() - 2) {
       auto native_tiling = nativeTiling(layout.bitwidth());
       required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0},
                                      native_tiling, ImplicitDim::kNone);
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index bcda0ca9f71e..3e777ac7ea2c 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -300,7 +300,7 @@ def pallas_call(cls, *args, **kwargs):
     return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs)
 
   def skip_if_mosaic_gpu(self):
-    if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu:
+    if jtu.test_device_matches(["gpu"]) and use_mosaic_gpu:
       self.skipTest("TODO: Mosaic GPU does not support this yet")
 
 
@@ -2569,6 +2569,52 @@ def kernel(x_ref, out_ref):
     )(x)
     np.testing.assert_array_equal(out, np.diagonal(x))
 
+  @parameterized.product(
+      # Skip some steps to just run fewer cases
+      # TODO(mvoz): Hypothesis?
b/419268277 + if (permutation.size() == 3 && permutation[0] == 1 && permutation[1] == 0) { + untiled_tiled_swap = true; + } else { + for (auto dim : permutation.drop_back(2)) { + TPU_CHECK_OP(dim < src_ty.getRank() - 2, + "Unsupported transpose permutation - minor dims into " + "major > 3 dimensions"); + } + for (auto dim : permutation.take_back(2)) { + TPU_CHECK_OP(dim >= src_ty.getRank() - 2, + "Unsupported transpose permutation - major dims into " + "minor > 3 dimensions"); + } } Layout required_layout = some_layout; - // Require native tiling if we're going to use the XLU. - if (permutation[permutation.size() - 1] == permutation.size() - 2) { + // Require native tiling if we're going to use the XLU, or doing a + // major/minor permute. + if (untiled_tiled_swap || + permutation[permutation.size() - 1] == permutation.size() - 2) { auto native_tiling = nativeTiling(layout.bitwidth()); required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0}, native_tiling, ImplicitDim::kNone); diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index bcda0ca9f71e..3e777ac7ea2c 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -300,7 +300,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) def skip_if_mosaic_gpu(self): - if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: + if jtu.test_device_matches(["gpu"]) and use_mosaic_gpu: self.skipTest("TODO: Mosaic GPU does not support this yet") @@ -2569,6 +2569,52 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.diagonal(x)) + @parameterized.product( + # Skip some steps to just run less cases + # TODO(mvoz): Hypothesis? + x_dim_size=tuple(8 * i for i in range(1, 5)), + y_dim_size=tuple(8 * i for i in range(1, 5)), + z_dim_size=tuple(128 * i for i in range(1, 3)), + dtype=(jnp.float32,), + ) + def test_jnp_swapaxes_major_minor( + self, x_dim_size, y_dim_size, z_dim_size, dtype + ): + if jtu.test_device_matches(["gpu"]): + if any( + not is_power_of_two(x) for x in [x_dim_size, y_dim_size, z_dim_size] + ): + self.skipTest( + "the Pallas Triton lowering currently requires that all operations" + " have array arguments and results whose size is a power of 2." + f" Encountered an array of shape ({x_dim_size}, {y_dim_size}," + f" {z_dim_size})" + ) + if x_dim_size * y_dim_size * z_dim_size * 4 > 32768: + self.skipTest( + "Mosaic GPU kernel exceeds available shared memory" + f" smem_bytes={x_dim_size * y_dim_size * z_dim_size * 4} > 32768" + ) + self.skip_if_mosaic_gpu() + if not jtu.if_cloud_tpu_at_least(2025, 5, 22): + self.skipTest("Requires libtpu built after 2025-5-22") + + x = jnp.arange(x_dim_size * y_dim_size * z_dim_size, dtype=dtype).reshape( + (x_dim_size, y_dim_size, z_dim_size) + ) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.swapaxes(x_ref[...], 0, 1) + + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct( + (y_dim_size, x_dim_size, z_dim_size), dtype + ), + )(x) + expected = jnp.swapaxes(x, 0, 1) + np.testing.assert_array_equal(out, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True From 0169f32fa2ea3166bcef7e113c9e3158195a9db3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 21 May 2025 17:43:52 -0700 Subject: [PATCH 1297/1769] remove jaxlib_extension_version, ifrt_version and jaxlib.__version_info__ guards after 0.6.1 release. 
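Most of the removed guards follow the same shape. For example, in
jax/_src/lib/__init__.py the conditional import

    if jaxlib_extension_version >= 334:
      from jaxlib._jax import ffi as ffi  # noqa: F401

becomes an unconditional `from jaxlib._jax import ffi as ffi`, since the
minimum supported jaxlib now satisfies all of the removed version checks.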
PiperOrigin-RevId: 761737523 --- docs/autodidax.ipynb | 7 +-- docs/autodidax.md | 7 +-- docs/autodidax.py | 7 +-- jax/_src/buffer_callback.py | 12 ++-- jax/_src/compilation_cache.py | 9 +-- jax/_src/compiler.py | 15 +---- jax/_src/lax/ann.py | 13 +--- jax/_src/lax/linalg.py | 26 +------- jax/_src/lib/__init__.py | 36 +++-------- jax/_src/lib/mlir/dialects/__init__.py | 5 +- jax/_src/named_sharding.py | 17 ++--- jax/_src/pallas/mosaic_gpu/lowering.py | 6 -- jax/_src/profiler.py | 2 +- jax/_src/test_util.py | 15 ++--- jax/experimental/buffer_callback.py | 19 ++---- .../jax2tf/tests/sharding_test.py | 9 +-- jax/experimental/serialize_executable.py | 3 - jaxlib/py_client.cc | 15 ----- jaxlib/py_compile_only_client.cc | 4 -- jaxlib/py_program.cc | 5 -- jaxlib/py_socket_transfer.cc | 63 ------------------- jaxlib/util.cc | 2 - jaxlib/xla.cc | 19 ------ tests/api_test.py | 18 ------ tests/buffer_callback_test.py | 7 +-- tests/compilation_cache_test.py | 29 +++------ tests/fused_attention_stablehlo_test.py | 3 - tests/linalg_sharding_test.py | 4 +- tests/linalg_test.py | 5 +- tests/mosaic/gpu_layout_inference_test.py | 6 -- tests/mosaic/gpu_test.py | 5 -- tests/pallas/gpu_pallas_distributed_test.py | 2 - tests/pjit_test.py | 3 - tests/python_callback_test.py | 7 --- 34 files changed, 59 insertions(+), 346 deletions(-) diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 07c7d7e84ff0..16d4da37b3f2 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -1986,7 +1986,6 @@ "from jax.extend.mlir import ir\n", "from jax.extend.mlir.dialects import func\n", "from jax.extend.mlir.dialects import stablehlo as hlo\n", - "import jax._src.lib\n", "from jax._src import xla_bridge as xb\n", "\n", "class MlirContext(NamedTuple):\n", @@ -2021,11 +2020,7 @@ " output = io.StringIO()\n", " c.module.operation.print(file=output)\n", " backend = xb.get_backend(None)\n", - " if jax._src.lib.version >= (0, 6, 1):\n", - " compiled = backend.compile(\n", - " output.getvalue(), backend.devices()[:1])\n", - " else:\n", - " compiled = backend.compile(output.getvalue())\n", + " compiled = backend.compile(output.getvalue(), backend.devices()[:1])\n", " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", "\n", "def _mlir_dtype(dtype: np.dtype) -> ir.Type:\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index e78aeded41c0..870ee20f0f9a 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1556,7 +1556,6 @@ import io from jax.extend.mlir import ir from jax.extend.mlir.dialects import func from jax.extend.mlir.dialects import stablehlo as hlo -import jax._src.lib from jax._src import xla_bridge as xb class MlirContext(NamedTuple): @@ -1591,11 +1590,7 @@ def xla_callable(hashable_jaxpr: IDHashable, output = io.StringIO() c.module.operation.print(file=output) backend = xb.get_backend(None) - if jax._src.lib.version >= (0, 6, 1): - compiled = backend.compile( - output.getvalue(), backend.devices()[:1]) - else: - compiled = backend.compile(output.getvalue()) + compiled = backend.compile(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: diff --git a/docs/autodidax.py b/docs/autodidax.py index 9531ef7694c5..b0dbf9f73d9f 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1548,7 +1548,6 @@ def __eq__(self, other): from jax.extend.mlir import ir from jax.extend.mlir.dialects import func from jax.extend.mlir.dialects import stablehlo as hlo -import jax._src.lib from 
jax._src import xla_bridge as xb class MlirContext(NamedTuple): @@ -1583,11 +1582,7 @@ def main(*params): output = io.StringIO() c.module.operation.print(file=output) backend = xb.get_backend(None) - if jax._src.lib.version >= (0, 6, 1): - compiled = backend.compile( - output.getvalue(), backend.devices()[:1]) - else: - compiled = backend.compile(output.getvalue()) + compiled = backend.compile(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: diff --git a/jax/_src/buffer_callback.py b/jax/_src/buffer_callback.py index 739fdb4c408d..a1dfb5c2ff18 100644 --- a/jax/_src/buffer_callback.py +++ b/jax/_src/buffer_callback.py @@ -27,16 +27,12 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lib import jaxlib_extension_version +from jax._src.lib import ffi as ffi_lib export = util.set_module("jax.experimental.buffer_callback") - -if jaxlib_extension_version >= 334: - from jax._src.lib import ffi as ffi_lib - - Buffer = export(ffi_lib.Buffer) - ExecutionStage = export(ffi_lib.ExecutionStage) - ExecutionContext = export(ffi_lib.ExecutionContext) +Buffer = export(ffi_lib.Buffer) +ExecutionStage = export(ffi_lib.ExecutionStage) +ExecutionContext = export(ffi_lib.ExecutionContext) def buffer_callback( diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index aa1bd6ab65ba..058670642f41 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -31,7 +31,6 @@ from jax._src import config from jax._src import monitoring from jax._src.compilation_cache_interface import CacheInterface -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lru_cache import LRUCache @@ -224,12 +223,8 @@ def get_executable_and_time( executable_and_time = decompress_executable(executable_and_time) serialized_executable, compile_time = extract_executable_and_time( executable_and_time) - if jaxlib_extension_version < 332: - xla_executable_deserialized = backend.deserialize_executable( - serialized_executable, compile_options) - else: - xla_executable_deserialized = backend.deserialize_executable( - serialized_executable, executable_devices, compile_options) + xla_executable_deserialized = backend.deserialize_executable( + serialized_executable, executable_devices, compile_options) return xla_executable_deserialized, compile_time diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 04f993fed799..e1b8e7c35697 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -35,7 +35,6 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc -from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -314,12 +313,6 @@ def backend_compile( ) try: - if jaxlib_extension_version < 332: - if host_callbacks: - return backend.compile( - built_c, compile_options=options, host_callbacks=host_callbacks) # type: ignore - return backend.compile(built_c, compile_options=options) # type: ignore - # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results if host_callbacks: @@ -692,12 +685,8 @@ def _compile_and_share_module( serialized_executable = compilation_cache.decompress_executable( serialized_executable ) - if jaxlib_extension_version < 332: - executable = 
backend.deserialize_executable( - serialized_executable, compile_options) # type: ignore - else: - executable = backend.deserialize_executable( - serialized_executable, executable_devices, compile_options) # type: ignore + executable = backend.deserialize_executable( + serialized_executable, executable_devices, compile_options) # type: ignore _compile_and_share_module.modules_cache[cache_key] = executable return executable diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index bfcd45fba574..61d383ee29c2 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -83,8 +83,6 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src.interpreters import mlir from jax._src.lax import lax from jax._src.lib import _jax -from jax._src.lib import jaxlib_extension_version -from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import hlo @@ -233,14 +231,9 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, if aggregate_to_topk: dims[reduction_dimension] = k elif core.is_constant_shape((reduction_input_size, k)): - if jaxlib_extension_version >= 331: - dims[reduction_dimension] = _jax.approx_top_k_reduction_output_size( - reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, - reduction_input_size_override)[0] - else: - dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( # type: ignore # pytype: disable=module-attr - reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, - reduction_input_size_override)[0] + dims[reduction_dimension] = _jax.approx_top_k_reduction_output_size( + reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, + reduction_input_size_override)[0] else: raise NotImplementedError( "approx_top_k with aggregate_to_topk=False not yet implemented when " diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 857b115b06d8..2fda4a90369d 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -48,7 +48,7 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version, jaxlib_extension_version +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2530,30 +2530,6 @@ def _tridiagonal_solve_shape_rule(dl_shape, d_shape, du_shape, b_shape, **_): return b_shape def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, target_name_prefix): - if jaxlib_extension_version < 340: - _, _, _, b_aval = ctx.avals_in - *batch_dims, m, n = b_aval.shape - batch_size = math.prod(batch_dims) - mod = gpu_sparse._cusparse if target_name_prefix == "cu" else gpu_sparse._hipsparse - assert mod is not None - opaque = mod.build_gtsv2_descriptor(batch_size, m, n, m) - if b_aval.dtype == np.float32: - buffer_size = mod.gtsv2_f32_buffer_size(m, n, m) - target_name = "sparse_gtsv2_f32_ffi" - elif b_aval.dtype == np.float64: - buffer_size = mod.gtsv2_f64_buffer_size(m, n, m) - target_name = "sparse_gtsv2_f64_ffi" - else: - raise NotImplementedError( - "tridiagonal_solve is only implemented for float32 and float64 on GPU.") - - buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) - sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) - rule = _linalg_ffi_lowering( - f"{target_name_prefix}{target_name}", operand_output_aliases={3: 0}, - 
batch_partitionable=False) - return rule(sub_ctx, dl, d, du, b, opaque=opaque)[:1] - target_name = f"{target_name_prefix}sparse_gtsv2_ffi" rule = _linalg_ffi_lowering(target_name, operand_output_aliases={3: 0}) return rule(ctx, dl, d, du, b) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 5cdcaf400c8a..8de05061ec99 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -85,23 +85,13 @@ def _parse_version(v: str) -> tuple[int, ...]: import jaxlib.lapack as lapack # noqa: F401 import jaxlib.utils as utils # noqa: F401 - -if version >= (0, 6, 1): - import jaxlib._jax as _jax # noqa: F401 - from jaxlib._jax import guard_lib as guard_lib # noqa: F401 - from jaxlib._jax import jax_jit as jax_jit # noqa: F401 - from jaxlib._jax import pmap_lib as pmap_lib # noqa: F401 - from jaxlib._jax import pytree as pytree # noqa: F401 - from jaxlib._jax import Device as Device # noqa: F401 - from jaxlib import _profiler as _profiler # noqa: F401 -else: - import jaxlib.xla_extension as _jax # type: ignore # pytype: disable=import-error # noqa: F401 - from jaxlib.xla_extension import guard_lib as guard_lib # type: ignore # pytype: disable=import-error # noqa: F401 - from jaxlib.xla_extension import jax_jit as jax_jit # type: ignore # pytype: disable=import-error # noqa: F401 - from jaxlib.xla_extension import pmap_lib as pmap_lib # type: ignore # pytype: disable=import-error # noqa: F401 - from jaxlib.xla_extension import pytree as pytree # type: ignore # pytype: disable=import-error # noqa: F401 - from jaxlib.xla_extension import Device as Device # type: ignore # pytype: disable=import-error # noqa: F401 - from jaxlib.xla_extension import profiler as _profiler # type: ignore # pytype: disable=import-error # noqa: F401 +import jaxlib._jax as _jax # noqa: F401 +from jaxlib._jax import guard_lib as guard_lib # noqa: F401 +from jaxlib._jax import jax_jit as jax_jit # noqa: F401 +from jaxlib._jax import pmap_lib as pmap_lib # noqa: F401 +from jaxlib._jax import pytree as pytree # noqa: F401 +from jaxlib._jax import Device as Device # noqa: F401 +from jaxlib import _profiler as _profiler # noqa: F401 import jaxlib.xla_client as xla_client # noqa: F401 @@ -112,15 +102,9 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_extension_version: int = getattr(xla_client, '_version', 0) ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) -if jaxlib_extension_version >= 334: - from jaxlib._jax import ffi as ffi # noqa: F401 - -if jaxlib_extension_version >= 335: - import jaxlib.cpu_sparse as cpu_sparse # noqa: F401 - - has_cpu_sparse = True -else: - has_cpu_sparse = False +from jaxlib._jax import ffi as ffi # noqa: F401 +import jaxlib.cpu_sparse as cpu_sparse # noqa: F401 +has_cpu_sparse = True import jaxlib.weakref_lru_cache as weakref_lru_cache # noqa: F401 diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index 5584afee2116..b49154e7936a 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -57,7 +57,4 @@ from jaxlib.mlir.dialects import stablehlo as hlo from jax._src import lib -if lib.version >= (0, 6, 1): - from jaxlib.mlir.dialects import cf -else: - cf = None # type: ignore[no-redef] +from jaxlib.mlir.dialects import cf diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index ae99236a6cdc..faf0b2a9f2b2 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -23,7 +23,6 @@ from jax._src import config from jax._src.util import use_cpp_class, 
cache, use_cpp_method from jax._src.lib import xla_client as xc -from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType @@ -317,17 +316,11 @@ def build(self) -> sdy.TensorShardingAttr: replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape) unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape) - if jaxlib_extension_version >= 342: - return sdy.TensorShardingAttr.get( - mesh_attr, - [dim_sharding.build() for dim_sharding in self.dim_shardings], - replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], - unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) - else: - return sdy.TensorShardingAttr.get( - mesh_attr, - [dim_sharding.build() for dim_sharding in self.dim_shardings], - replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes]) + return sdy.TensorShardingAttr.get( + mesh_attr, + [dim_sharding.build() for dim_sharding in self.dim_shardings], + replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], + unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) def __repr__(self): dim_sharding_repr = ', '.join( diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b6960c479558..e212f9770a94 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -31,7 +31,6 @@ from jax import lax from jax._src import checkify from jax._src import core as jax_core -from jax._src import lib as jaxlib from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import pjit @@ -1569,11 +1568,6 @@ def _broadcast_in_dim_lowering_rule_wg( ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)), x, ) - - # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. 
- if jaxlib.version < (0, 6, 1): - raise NotImplementedError() - mlir_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype) result_ty = ir.VectorType.get(shape, mlir_type) return mgpu.dialect.broadcast_in_dim(result_ty, x, broadcast_dimensions) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 6b58b2ba6326..424e2b81035f 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -215,7 +215,7 @@ def stop_trace(): if _profile_state.profile_session is None: raise RuntimeError("No profile started") sess = _profile_state.profile_session - sess.stop_and_export(str(_profile_state.log_dir)) + sess.stop_and_export(str(_profile_state.log_dir)) # type: ignore if _profile_state.create_perfetto_trace: abs_filename = _write_perfetto_trace_file(_profile_state.log_dir) if _profile_state.create_perfetto_link: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index bb1ef6595ec3..f6810b533b31 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -54,7 +54,6 @@ from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir -from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 @@ -357,18 +356,16 @@ def assert_num_jit_and_pmap_compilations(times): @contextmanager def count_internal_device_puts(): - if jaxlib_extension_version >= 341: - before = jax._src.lib._jax.get_internal_device_put_info() + before = jax._src.lib._jax.get_internal_device_put_info() counts = {} try: yield lambda: counts finally: - if jaxlib_extension_version >= 341: - after = jax._src.lib._jax.get_internal_device_put_info() - for k, v in after.items(): - diff = v - before.get(k, 0) - if diff != 0: - counts[k] = diff + after = jax._src.lib._jax.get_internal_device_put_info() + for k, v in after.items(): + diff = v - before.get(k, 0) + if diff != 0: + counts[k] = diff def jaxlib_version() -> tuple[int, ...]: return _jaxlib.version diff --git a/jax/experimental/buffer_callback.py b/jax/experimental/buffer_callback.py index 6c8514340af0..f919cfa10208 100644 --- a/jax/experimental/buffer_callback.py +++ b/jax/experimental/buffer_callback.py @@ -12,16 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jax._src.lib import jaxlib_extension_version as _jaxlib_extension_version - -if _jaxlib_extension_version >= 334: - from jax._src.buffer_callback import ( - Buffer as Buffer, - ExecutionContext as ExecutionContext, - ExecutionStage as ExecutionStage, - buffer_callback as buffer_callback, - ) - -from jax._src.buffer_callback import buffer_callback as buffer_callback - -del _jaxlib_extension_version +from jax._src.buffer_callback import ( + Buffer as Buffer, + ExecutionContext as ExecutionContext, + ExecutionStage as ExecutionStage, + buffer_callback as buffer_callback, +) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 5fc45df218cd..fa15522cbe90 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -33,7 +33,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax import lax from jax.experimental import jax2tf @@ -111,12 +110,8 @@ def log_jax_hlo(self, f_jax, args: Sequence[Any], *, device_assignment=device_assignment, use_spmd_partitioning=use_spmd_partitioning, ) - if jaxlib_extension_version < 332: - executable = backend.compile( - jax_hlo, compile_options=compile_options) # type: ignore - else: - executable = backend.compile_and_load( - jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore + executable = backend.compile_and_load( + jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore jax_optimized_hlo = executable.hlo_modules()[0].to_string() logging.info("[%s] got JAX optimized HLO for platform %s %s", self._testMethodName, backend.platform, jax_optimized_hlo) diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 6f5062d4ce99..7c112f56ef42 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -19,7 +19,6 @@ import io import jax -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from typing import Sequence @@ -110,8 +109,6 @@ def __init__(self, file, backend, execution_devices=None): def persistent_load(self, pid): if pid[0] == 'exec': - if jaxlib_extension_version < 332: - return self.backend.deserialize_executable(pid[1]) return self.backend.deserialize_executable( pid[1], executable_devices=self.execution_devices) if pid[0] == 'device': diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index 0a99d94f81cc..842bdfecad3d 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -373,14 +373,9 @@ std::unique_ptr MakeIfrtCompileOptions( ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); } -#if JAX_IFRT_VERSION_NUMBER >= 6 return std::make_unique( std::move(options), std::move(executable_devices), std::move(ifrt_loaded_host_callbacks)); -#else - return std::make_unique( - std::move(options), std::move(ifrt_loaded_host_callbacks)); -#endif } // Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and @@ -398,14 +393,9 @@ MakeIfrtDeserializeExecutableOptions(std::optional options, ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); } -#if JAX_IFRT_VERSION_NUMBER >= 6 return std::make_unique( std::move(options), std::move(executable_devices), std::move(ifrt_loaded_host_callbacks)); -#else - return std::make_unique( - 
std::move(options), std::move(ifrt_loaded_host_callbacks)); -#endif } } // namespace @@ -504,14 +494,9 @@ PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, client->ifrt_client(), std::move(host_callback)); ifrt_loaded_host_callbacks.push_back(callback); } -#if JAX_IFRT_VERSION_NUMBER >= 6 auto compile_options = std::make_unique( std::move(options), std::move(executable_devices), std::move(ifrt_loaded_host_callbacks)); -#else - auto compile_options = std::make_unique( - std::move(options), std::move(ifrt_loaded_host_callbacks)); -#endif return CompileAndLoadIfrtProgram( client, std::make_unique(module.get()), std::move(compile_options)); diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index 0fa2f4b48fd7..2de896d80bef 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -91,12 +91,8 @@ class CompileOnlyPyClient : public PyClient { llvm::dyn_cast_or_null(this->ifrt_client()); CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " "CompileOnlyIfRtClient"; -#if JAX_IFRT_VERSION_NUMBER >= 6 auto xla_options = std::make_unique( options, std::move(executable_devices)); -#else - auto xla_options = std::make_unique(options); -#endif TF_ASSIGN_OR_RETURN(auto executable, PjRtCompile(std::move(options), module.get(), *ifrt_client->topology().description())); diff --git a/jaxlib/py_program.cc b/jaxlib/py_program.cc index 40bfd3497ebd..ee2d3eef9973 100644 --- a/jaxlib/py_program.cc +++ b/jaxlib/py_program.cc @@ -236,16 +236,11 @@ absl::StatusOr> MakeXlaCompileOptions( ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); } -#if JAX_IFRT_VERSION_NUMBER >= 6 TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef executable_devices, py_executable_devices.ifrt_device_list()); return std::make_unique( std::move(options), std::move(executable_devices), std::move(ifrt_loaded_host_callbacks)); -#else - return std::make_unique( - std::move(options), std::move(ifrt_loaded_host_callbacks)); -#endif } constexpr absl::string_view kColocatedPythonProgramType = diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc index fde63df8da47..8086196b9df8 100644 --- a/jaxlib/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -110,58 +110,10 @@ absl::StatusOr MemorySpaceFromSharding( } } -#if JAX_IFRT_VERSION_NUMBER < 8 -class IfrtArrayEntry : public PullTable::Entry { - public: - struct BufferRef { - xla::ifrt::ArrayRef arr; - xla::PjRtBuffer* buffer; - size_t buf_size; - }; - explicit IfrtArrayEntry(std::vector arrs, - std::shared_ptr state, - size_t xfer_size) - : arrs_(std::move(arrs)), state_(state), xfer_size_(xfer_size) {} - bool Handle(tsl::RCReference state, - const SocketTransferPullRequest& req, - size_t base_req_id) override { - for (uint64_t bid : req.buffer_ids()) { - auto req_id = base_req_id; - ++base_req_id; - for (size_t i = 0; i * xfer_size_ < arrs_[bid].buf_size; ++i) { - DmaCopyChunk blob = DmaCopyChunk::Make( - std::move(arrs_[bid].arr), arrs_[bid].buffer, bid, i * xfer_size_, - std::min(xfer_size_, arrs_[bid].buf_size - i * xfer_size_)); - bool is_largest = blob.size + blob.offset == arrs_[bid].buf_size; - state_->ScheduleCopy( - std::move(blob), [req_id, state, copier_state = state_, is_largest]( - PremappedCopierState* copier_state_ptr, - void* buf, const DmaCopyChunk& chunk) { - state->Send( - req_id, buf, chunk.offset, chunk.size, is_largest, - [copier_state, buf]() { copier_state->ReturnBuffer(buf); }); - }); - } - } - - num_consumed_bufs_ += 
req.buffer_ids().size(); - return num_consumed_bufs_ == arrs_.size(); - } - - private: - absl::Mutex mu_; - size_t num_consumed_bufs_ = 0; - std::vector arrs_; - std::shared_ptr state_; - size_t xfer_size_; -}; -#endif - absl::StatusOr> CreatePullEntry( const std::vector& arrs, std::shared_ptr state, size_t xfer_size, bool use_raw_buffers) { -#if JAX_IFRT_VERSION_NUMBER >= 8 if (use_raw_buffers) { std::vector refs; for (auto& arr : arrs) { @@ -196,21 +148,6 @@ absl::StatusOr> CreatePullEntry( } } return tsl::MakeRef(std::move(refs), state, xfer_size); -#else - std::vector refs; - for (auto& arr : arrs) { - auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); - if (pjrt_arr == nullptr) { - return absl::InvalidArgumentError( - "Cannot remote transfer non-pjrt arrays."); - } - for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { - TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); - refs.push_back({arr, pjrt_buf.get(), buf_size}); - } - } - return tsl::MakeRef(std::move(refs), state, xfer_size); -#endif } class PyTransferServerConnection { diff --git a/jaxlib/util.cc b/jaxlib/util.cc index a014afa5bebe..a8d45749f4d1 100644 --- a/jaxlib/util.cc +++ b/jaxlib/util.cc @@ -36,7 +36,6 @@ limitations under the License. namespace xla { void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future) { -#if JAX_IFRT_VERSION_NUMBER >= 5 future.BlockUntilReady([](tsl::AsyncValue* value) { auto state = std::make_shared(); value->AndThen([state]() { state->Notify(); }); @@ -50,7 +49,6 @@ void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future) { } } }); -#endif } absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 3412766de6bd..d97c6868a04b 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -490,27 +490,8 @@ NB_MODULE(_jax, m) { &CompiledMemoryStats::host_temp_size_in_bytes) .def_prop_ro("serialized_buffer_assignment_proto", [](const CompiledMemoryStats& cms) -> nb::bytes { -#if JAX_IFRT_VERSION_NUMBER >= 9 const std::string& s = cms.serialized_buffer_assignment; return nb::bytes(s.data(), s.size()); -#elif JAX_IFRT_VERSION_NUMBER >= 7 - if (cms.buffer_assignment.has_value()) { - std::string s = - cms.buffer_assignment->SerializeAsString(); - return nb::bytes(s.data(), s.size()); - } else { - return nb::bytes(); - } -#else - xla::HloProto hlo; - if (!cms.serialized_hlo_proto.empty() && - hlo.ParseFromString(cms.serialized_hlo_proto)) { - std::string s = - hlo.buffer_assignment().SerializeAsString(); - return nb::bytes(s.data(), s.size()); - } - return nb::bytes(); -#endif }) .def("__str__", &CompiledMemoryStats::DebugString); diff --git a/tests/api_test.py b/tests/api_test.py index 9963e2603588..f5b74e1e10d6 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -61,7 +61,6 @@ from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lib import _jax -from jax._src.lib import jaxlib_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint from jax.errors import (UnexpectedTracerError, TracerIntegerConversionError, @@ -1975,11 +1974,6 @@ def test_device_put_sharding_mismatched_tree_different_leaf_count(self): jax.device_put((x, y, z), device=(s1, s2)) def test_internal_device_put_with_device(self): - if jaxlib_extension_version < 341: - raise unittest.SkipTest( - "Test requires jaxlib extension version >= 341 for tracking low-level" - " DevicePut calls") - # Hitting the cache for a single-device 
jitted execution while using a numpy # array calls internal `DevicePutWithDevice`. f = jax.jit(lambda x: x + 1) @@ -1990,10 +1984,6 @@ def test_internal_device_put_with_device(self): self.assertEqual(counts(), {"device_put_with_device": 1}) def test_internal_device_put_fully_replicated(self): - if jaxlib_extension_version < 341: - raise unittest.SkipTest( - "Test requires jaxlib extension version >= 341 for tracking low-level" - " DevicePut calls") if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices") @@ -2011,10 +2001,6 @@ def test_internal_device_put_fully_replicated(self): ) def test_internal_device_put_batched(self): - if jaxlib_extension_version < 341: - raise unittest.SkipTest( - "Test requires jaxlib extension version >= 341 for tracking low-level" - " DevicePut calls") if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices") @@ -2031,10 +2017,6 @@ def test_internal_device_put_batched(self): ) def test_internal_device_put_assembled(self): - if jaxlib_extension_version < 341: - raise unittest.SkipTest( - "Test requires jaxlib extension version >= 341 for tracking low-level" - " DevicePut calls") if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices") diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py index e77ee4af687f..8bef4135f5d5 100644 --- a/tests/buffer_callback_test.py +++ b/tests/buffer_callback_test.py @@ -19,7 +19,6 @@ import jax import jax.numpy as jnp from jax._src import test_util as jtu -from jax._src.lib import jaxlib_extension_version from jax.experimental import buffer_callback jax.config.parse_flags_with_absl() @@ -29,10 +28,6 @@ class BufferCallbackTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if jaxlib_extension_version < 334: - self.skipTest( - "Requires a version of jaxlib with buffer callback support." 
- ) if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU.") @@ -102,7 +97,7 @@ def callback(ctx, out, arg): ) @jtu.run_on_devices("cuda") def test_cuda_array_interface(self, dtype, command_buffer_compatible): - if command_buffer_compatible and jaxlib_extension_version < 337: + if command_buffer_compatible: self.skipTest("Requires jaxlib extension version of at least 337.") def callback(ctx, out, arg): diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 5a76d732bd76..3f1bb7fab4b1 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -146,14 +146,10 @@ def test_diff_executables(self): ) backend = xla_bridge.get_backend() executable_devices = xc.DeviceList(tuple(backend.local_devices())) - if jax._src.lib.jaxlib_extension_version < 331: - executable1 = backend.compile(computation1, compile_options) - executable2 = backend.compile(computation2, compile_options) - else: - executable1 = backend.compile_and_load( - computation1, executable_devices, compile_options) - executable2 = backend.compile_and_load( - computation2, executable_devices, compile_options) + executable1 = backend.compile_and_load( + computation1, executable_devices, compile_options) + executable2 = backend.compile_and_load( + computation2, executable_devices, compile_options) cc.put_executable_and_time( "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) cc.put_executable_and_time( @@ -177,11 +173,8 @@ def test_put_executable(self): ) backend = xla_bridge.get_backend() executable_devices = xc.DeviceList(tuple(devices.flat)) - if jax._src.lib.jaxlib_extension_version < 331: - executable = backend.compile(str(computation), compile_options) - else: - executable = backend.compile_and_load( - str(computation), executable_devices, compile_options) + executable = backend.compile_and_load( + str(computation), executable_devices, compile_options) key = cc.get_cache_key(computation, devices, compile_options, backend) cc.put_executable_and_time( key, "alambda", executable, backend, FAKE_COMPILE_TIME) @@ -577,13 +570,9 @@ def test_backend_serialization_deserialization(self): .runtime_executable() ) serialized_executable = backend.serialize_executable(executable) - if jax._src.lib.jaxlib_extension_version < 331: - deserialized_executable = backend.deserialize_executable( # type: ignore - serialized_executable, None) - else: - deserialized_executable = backend.deserialize_executable( # type: ignore - serialized_executable, - xc.DeviceList(tuple(jax.local_devices(backend=backend))), None) + deserialized_executable = backend.deserialize_executable( # type: ignore + serialized_executable, + xc.DeviceList(tuple(jax.local_devices(backend=backend))), None) self.assertEqual( executable.fingerprint, deserialized_executable.fingerprint) diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 925fc2ed4825..64e0f4377462 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -503,9 +503,6 @@ def test_sdpa_broadcast_bias_and_dbias(self): ) @jtu.run_on_devices("cuda") def test_sdpa_dbias(self, batch_size: int): - # TODO: Delete once 0.6.0 is no longer supported. - if jtu.jaxlib_version() == (0, 6, 0): - self.skipTest("jaxlib 0.6.0 has a bug") if jax.device_count() < 4: self.skipTest("Requires more than 4 devices.") # cuDNN only supports dbias when batch size is 1. 
If the batch size is diff --git a/tests/linalg_sharding_test.py b/tests/linalg_sharding_test.py index e68e94e16494..5d7b3b8a637b 100644 --- a/tests/linalg_sharding_test.py +++ b/tests/linalg_sharding_test.py @@ -22,7 +22,6 @@ from jax import lax from jax._src import config from jax._src import test_util as jtu -from jax._src.lib import jaxlib_extension_version from jax.sharding import PartitionSpec as P config.parse_flags_with_absl() @@ -70,8 +69,7 @@ def get_fun_and_shapes(self, fun_and_shapes, grad=False): self.skipTest( f"Partitioning {fun_and_shapes[0].__name__} only supported on GPU " "when shardy is enabled.") - if (fun_and_shapes[0] == lax.linalg.tridiagonal_solve and - jaxlib_extension_version < 340): + if fun_and_shapes[0] == lax.linalg.tridiagonal_solve: self.skipTest( f"Partitioning {fun_and_shapes[0].__name__} on GPU, requires a " "more recent jaxlib version.") diff --git a/tests/linalg_test.py b/tests/linalg_test.py index cba3dbb7189d..99cb66c92857 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -33,7 +33,6 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.lib import jaxlib_extension_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -2205,7 +2204,7 @@ def testSelect(self, dtype): @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)], dtype=float_types + complex_types) def test_tridiagonal_solve(self, shape, dtype): - if dtype not in float_types and jtu.test_device_matches(["gpu"]) and jaxlib_extension_version < 340: + if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") rng = self.rng() d = 1.0 + jtu.rand_positive(rng)(shape, dtype) @@ -2244,7 +2243,7 @@ def test_tridiagonal_solve_endpoints(self): @jtu.sample_product(shape=[(3,), (3, 4)], dtype=float_types + complex_types) def test_tridiagonal_solve_grad(self, shape, dtype): - if dtype not in float_types and jtu.test_device_matches(["gpu"]) and jaxlib_extension_version < 340: + if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") rng = self.rng() d = 1.0 + jtu.rand_positive(rng)(shape, dtype) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 038766542f3b..cdc840b0a6f1 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -19,7 +19,6 @@ from absl.testing import parameterized import jax from jax._src import config -from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir @@ -245,12 +244,7 @@ def body(x): def test_infer_broadcast_in_dim_layout( self, broadcast_dim, in_cast, out_cast, in_layout, out_layout ): - # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. 
- if jaxlib.version < (0, 6, 1): - self.skipTest("Test requires jaxlib version >= 0.6.1") - bcast = None - in_shape = (64,) out_shape = (64, 64) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 4e0544d1758e..e7fc9723347b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -27,7 +27,6 @@ from absl.testing import absltest, parameterized import jax from jax._src import config -from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -3116,10 +3115,6 @@ def add( ((64,), (128, 64), [1]), ) def test_broadcast_in_dim(self, input_shape, output_shape, bcast_dims): - # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. - if jaxlib.version < (0, 6, 1): - self.skipTest("Test requires jaxlib version >= 0.6.1") - element_value = 42.0 def body(ctx, result_gmem_ref, smem): del ctx diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index 81433b8c5067..d862e6b9b819 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -34,8 +34,6 @@ class PallasCallRemoteDMATest(jt_multiprocess.MultiProcessTest): def setUp(self): - if jtu.jaxlib_version() < (0, 6, 1): - self.skipTest("Test requires jaxlib >= 0.6.1") if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index aa2e0af2a57c..5d616c43ce54 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -63,7 +63,6 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import _jax -from jax._src.lib import jaxlib_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -7762,8 +7761,6 @@ def f(x): @config.use_shardy_partitioner(True) @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_unreduced_basic(self, mesh): - if jaxlib_extension_version < 342: - self.skipTest("Test requires a newer jaxlib") np_inp = np.arange(16).reshape(8, 2) x = jax.device_put(np_inp, P('x', 'y')) y = jax.device_put(np_inp.T, P('y', None)) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 26664faa6faf..eef45b3b412b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -31,7 +31,6 @@ from jax.experimental import io_callback from jax.experimental import pjit from jax._src.shard_map import shard_map -from jax._src.lib import jaxlib_extension_version import jax.numpy as jnp from jax.sharding import Mesh import numpy as np @@ -588,8 +587,6 @@ def fun(x): @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version < 336: - self.skipTest("Requires jaxlib_extension_version >= 336.") if "2" in dtype and jtu.test_device_matches(["tpu"]): self.skipTest( "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" @@ -609,8 +606,6 @@ def f(x): @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 336: - self.skipTest("Requires jaxlib_extension_version >= 336.") if "2" in dtype and jtu.test_device_matches(["tpu"]): self.skipTest( "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" @@ -630,8 +625,6 @@ def f(): @parameterized.parameters("int2", "int4", "uint2", 
"uint4", "float4_e2m1fn") def test_non_default_stride_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 336: - self.skipTest("Requires jaxlib_extension_version >= 336.") if "2" in dtype and jtu.test_device_matches(["tpu"]): self.skipTest( "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" From 29e6647577ad98b2d1a35cccdfe9da8dfa0f1f24 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 21 May 2025 18:05:10 -0700 Subject: [PATCH 1298/1769] Add out_sharding to the function returned by `jax.nn.initializers.he_normal` and other APIs implementing the `Initializer` protocol. Currently it takes `key, shape, dtype` and now we added an optional out_sharding parameter to it. PiperOrigin-RevId: 761742909 --- jax/_src/nn/initializers.py | 51 ++++++++++++++++++++++++++----------- tests/pjit_test.py | 22 ++++++++++++++++ 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 6f117eef749f..855729fa16ff 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -30,6 +30,7 @@ from jax import random from jax._src import core from jax._src import dtypes +from jax._src.sharding_impls import canonicalize_sharding from jax._src.typing import Array, ArrayLike from jax._src.util import set_module @@ -48,7 +49,8 @@ class Initializer(Protocol): def __call__(self, key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: + dtype: DTypeLikeInexact = jnp.float_, + out_sharding=None) -> Array: raise NotImplementedError @export @@ -100,9 +102,12 @@ def constant(value: ArrayLike, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) - return jnp.full(shape, value, dtype=dtype) + out_sharding = canonicalize_sharding( + out_sharding, 'nn.initializers.constant') + return jnp.full(shape, value, dtype=dtype, device=out_sharding) return init @export @@ -126,9 +131,11 @@ def uniform(scale: RealNumeric = 1e-2, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) - return random.uniform(key, shape, dtype) * jnp.array(scale, dtype) + return random.uniform(key, shape, dtype, + out_sharding=out_sharding) * jnp.array(scale, dtype) return init @export @@ -152,9 +159,11 @@ def normal(stddev: RealNumeric = 1e-2, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) - return random.normal(key, shape, dtype) * jnp.array(stddev, dtype) + return random.normal(key, shape, dtype, + out_sharding=out_sharding) * jnp.array(stddev, dtype) return init @export @@ -189,10 +198,12 @@ def truncated_normal(stddev: RealNumeric = 1e-2, def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) return random.truncated_normal( - key, lower, upper, shape, dtype) * jnp.array(stddev, dtype) + key, lower, upper, shape, dtype, + out_sharding=out_sharding) * jnp.array(stddev, dtype) return init @export @@ -315,7 +326,8 @@ def variance_scaling( def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + 
out_sharding=None) -> Array: shape = core.canonicalize_shape(shape) dtype = dtypes.canonicalize_dtype(dtype) fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis) @@ -332,16 +344,19 @@ def init(key: Array, if jnp.issubdtype(dtype, jnp.floating): # constant is stddev of standard normal truncated to (-2, 2) stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) - return random.truncated_normal(key, -2, 2, shape, dtype) * stddev + return random.truncated_normal(key, -2, 2, shape, dtype, + out_sharding=out_sharding) * stddev else: # constant is stddev of complex standard normal truncated to 2 stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype) return _complex_truncated_normal(key, 2, shape, dtype) * stddev elif distribution == "normal": - return random.normal(key, shape, dtype) * jnp.sqrt(variance) + return random.normal(key, shape, dtype, + out_sharding=out_sharding) * jnp.sqrt(variance) elif distribution == "uniform": if jnp.issubdtype(dtype, jnp.floating): - return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) + return random.uniform(key, shape, dtype, -1, + out_sharding=out_sharding) * jnp.sqrt(3 * variance) else: return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance) else: @@ -601,7 +616,10 @@ def orthogonal(scale: RealNumeric = 1.0, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: + if out_sharding is not None: + raise NotImplementedError dtype = dtypes.canonicalize_dtype(dtype) if len(shape) < 2: raise ValueError("orthogonal initializer requires at least a 2D shape") @@ -651,7 +669,10 @@ def delta_orthogonal( """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: + if out_sharding is not None: + raise NotImplementedError dtype = dtypes.canonicalize_dtype(dtype) if len(shape) not in [3, 4, 5]: raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D " diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5d616c43ce54..024901b746a8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7868,6 +7868,28 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh.abstract_mesh, P('x', 'y'))) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_he_normal(self, mesh): + init = jax.nn.initializers.he_normal(in_axis=0, out_axis=1) + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_nn_uniform(self, mesh): + init = jax.nn.initializers.uniform() + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_nn_constant(self, mesh): + init = jax.nn.initializers.constant(-7) + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertArraysEqual(out, jnp.full((8, 2), -7, dtype=jnp.float32)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From c19bf667cb9c0fe208d6583c7f81b3968dba4cba Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 21 May 2025 18:14:13 -0700 Subject: [PATCH 1299/1769] Add visibility hook for //jax:stages PiperOrigin-RevId: 761745409 --- 
jax/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/BUILD b/jax/BUILD index e431a3c5056c..e820bd06fe89 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1021,6 +1021,7 @@ pytype_strict_library( pytype_strict_library( name = "stages", srcs = ["_src/stages.py"], + visibility = [":internal"] + jax_visibility("stages"), deps = [ ":config", ":core", From 27ee82bac71387cfd97e4548819c51afefd53783 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Thu, 22 May 2025 09:47:47 +0530 Subject: [PATCH 1300/1769] Fix `jax.grad()` documentation in quickstart.md The documentation wrongly mentions that the `sum_logistic` function was jitted in the preceding example - which is not true. Fixed the phrasing to be more accurate. --- docs/quickstart.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index ec9f3ccd3633..40c50dba3dbd 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -121,7 +121,7 @@ print(first_finite_differences(sum_logistic, x_small)) ``` The {func}`~jax.grad` and {func}`~jax.jit` transformations compose and can be mixed arbitrarily. -In the above example we jitted `sum_logistic` and then took its derivative. We can go further: +For instance, while the `sum_logistic` function was differentiated directly in the previous example, it could also be JIT-compiled, and these operations can be combined. We can go further: ```{code-cell} print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0)) From bddb877c217e045f1210f0c831018ee0a54078a9 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 22 May 2025 04:56:02 +0000 Subject: [PATCH 1301/1769] fix custom_vjp optimize_remat=True with collectives --- jax/_src/core.py | 2 -- jax/_src/custom_derivatives.py | 12 +++++------- tests/shard_map_test.py | 26 ++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index e49173c3df45..b20b85a43b6e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2745,10 +2745,8 @@ def __lt__(self, other): @dataclass(frozen=True) class NamedAxisEffect(effects.Effect): """A side-effect introducing a new named axis into the current scope.""" - name: AxisName - effects.control_flow_allowed_effects.add_type(NamedAxisEffect) effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect) effects.lowerable_effects.add_type(NamedAxisEffect) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 9b28595e1835..d76d145fd0a6 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1619,21 +1619,19 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: prim_tree, res_tree = out_trees() num_res = res_tree.num_leaves - if fwd_jaxpr.effects: + disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fwd_jaxpr.effects) + if disallowed_effects: raise NotImplementedError( "remat optimization for custom_vjp does not support forward " - f"functions with side effects, but {fwd_name} has the following " - f"effects: {fwd_jaxpr.effects}") + f"functions with these side effects: {disallowed_effects}") @pe._memoize def fun_jaxpr_thunk(): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) return jaxpr, consts - out_flat = remat_opt_p.bind(*consts, *args_flat, - num_consts=len(consts), - num_res=num_res, - fwd_jaxpr=fwd_jaxpr, + out_flat = remat_opt_p.bind(*consts, *args_flat, num_consts=len(consts), + num_res=num_res, fwd_jaxpr=fwd_jaxpr, 
fun_jaxpr_thunk=fun_jaxpr_thunk) res, out_flat = split_list(out_flat, [num_res]) out_tree = treedef_tuple((prim_tree, res_tree)) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 9b4ca76c3bc5..5fbace3c98e1 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -621,6 +621,32 @@ def f(): x = f() self.assertAllClose(x, jnp.arange(4), check_dtypes=False) + def test_optimize_remat(self): + mesh = jtu.create_mesh((4,), 'x') + + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + def f_fwd(x): + return jax.lax.psum(x, 'x'), (x,) + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + + @jax.jit + @jax.shard_map(mesh=mesh, in_specs=P(), out_specs=P()) + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + jax.grad(lambda x: temp(x).sum())(jnp.arange(4.)) + def test_remat_basic(self): # this tests remat-of-shmap mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) From 421aa54b269bb9ec0e003fd2736b0c0f2edb4759 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 May 2025 02:04:34 -0700 Subject: [PATCH 1302/1769] Fix skip condition in Pallas Triton tests It should not run in TPU configs at all. PiperOrigin-RevId: 761865943 --- tests/pallas/triton_pallas_test.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/pallas/triton_pallas_test.py b/tests/pallas/triton_pallas_test.py index 4e2b10e72eb1..fe13716705de 100644 --- a/tests/pallas/triton_pallas_test.py +++ b/tests/pallas/triton_pallas_test.py @@ -33,11 +33,14 @@ class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): - if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: - self.skipTest("On CPU the test works only in interpret mode") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on GPU with capability >= sm90") + if jtu.test_device_matches(["cpu"]): + if not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + elif jtu.test_device_matches(["gpu"]): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on GPU with capability >= sm90") + else: + self.skipTest("Test only works on CPU and GPU") super().setUp() _trace_kernel_to_jaxpr.cache_clear() From 012f9b2677cbe84053dd5a9a71ab1aba349a8973 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 22 May 2025 02:24:36 -0700 Subject: [PATCH 1303/1769] [Pallas] Simulate multiple cores per device when interpreting kernels on CPU. The `TPUInterpreterParams` are extended with a field `num_cores_per_device` that specifies how many cores should be simulated per (TPU) device. Per-core devices are mapped along grid dimensions with _parallel_ dimension semantics. (I.e. the body of the grid loop is considered to execute on the next device-local core if the multi-dimensional loop index has changed along a _parallel_ grid dimension from one loop iteration to the next.) Each per-device core is identified by a `local_core_id`. Globally, each core is identified by a `global_core_id` (= `device_id` * `num_cores_per_device` + `local_core_id`). There is one vector clock per core (previously, one per device), and each vector clock has as many entries as there are cores in total (= `num_device` * `num_cores_per_device`). 
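For example, with `num_cores_per_device = 2`, local cores 0 and 1 of device 3
get global core IDs 3 * 2 + 0 = 6 and 3 * 2 + 1 = 7, and with 4 devices every
vector clock has 4 * 2 = 8 entries.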
PiperOrigin-RevId: 761872013 --- jax/_src/pallas/mosaic/interpret.py | 851 +++++++++++++++++----- tests/pallas/tpu_pallas_interpret_test.py | 59 ++ 2 files changed, 713 insertions(+), 197 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index c0e52f54e6f3..4258e52e6541 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -99,6 +99,8 @@ class TPUInterpretParams: intended for inspecting the randomization of coordinates along grid dimensions with 'parallel' semantics. Default: None. + num_cores_per_device: The number of cores per device. + Default: 1. """ dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" detect_races: bool = False @@ -106,6 +108,7 @@ class TPUInterpretParams: uninitialized_memory: Literal["nan", "zero"] = "nan" random_seed: int | None = None grid_point_recorder: Callable[[tuple[jnp.int32, ...]], None] | None = None + num_cores_per_device: int = 1 VectorClock = np.ndarray @@ -115,11 +118,12 @@ class TPUInterpretParams: # of DMAs. # # Instead, we use approximate vector clocks of fixed size. We assign each DMA -# a virtual device ID in the range [num_devices + 1, NUM_VIRTUAL_DEVICES] -- +# a virtual core ID in the range +# [num_devices*num_cores_per_device + 1, NUM_VIRTUAL_CORES], # and each operation of a DMA increments the corresponding coordinate in its -# vector clock. (So the "virtual" part of a vector clock is effectively -# counting, for each virtual device, the number of DMAs that happened-before -# the vector clock and were assigned to that virtual device.) +# vector clock. (So the "virtual" part of a vector clock is effectively +# counting, for each virtual core, the number of DMAs that happened-before +# the vector clock and were assigned to that virtual core.) # # If two approximate clocks are unordered, then their corresponding events are # not ordered by the happens-before relation. So this approximation will not @@ -128,11 +132,11 @@ class TPUInterpretParams: # clocks are ordered, and we will treat the corresponding events as ordered # by the happens-before relation, but the corresponding events are not # actually ordered. 
-NUM_VIRTUAL_DEVICES = 32 +NUM_VIRTUAL_CORES = 32 -def make_vector_clock(num_devices: int) -> VectorClock: - del num_devices - return np.zeros(NUM_VIRTUAL_DEVICES, dtype=np.int32) +def make_vector_clock(_: int) -> VectorClock: + del _ + return np.zeros(NUM_VIRTUAL_CORES, dtype=np.int32) def copy_vector_clock(x: VectorClock) -> VectorClock: if x is None: @@ -140,7 +144,7 @@ def copy_vector_clock(x: VectorClock) -> VectorClock: return x.copy() def update_vector_clock(x: VectorClock, y: VectorClock): - x[:] = np.maximum(x, y) + x[:] = np.maximum(x[:], y[:]) def lt(x: VectorClock, y: VectorClock) -> bool: return bool((x <= y).all() & (x < y).any()) @@ -148,11 +152,17 @@ def lt(x: VectorClock, y: VectorClock) -> bool: def ordered(x: VectorClock, y: VectorClock) -> bool: return lt(x, y) | lt(y, x) -def inc_vector_clock(x: VectorClock, device_id: int): - if device_id >= len(x): - raise ValueError(f'device_id={device_id} is out of range for x={x}') - assert device_id < len(x) - x[device_id] += 1 +def inc_vector_clock(x: VectorClock, global_core_id: int): + if global_core_id >= len(x): + raise ValueError(f'device_id={global_core_id} is out of range for x={x}') + assert global_core_id < len(x) + x[global_core_id] += 1 + +def _get_global_core_id(device_id, local_core_id): + """Computes the global core ID from the given device and local core ID.""" + device_id = int(device_id) + local_core_id = int(local_core_id) + return device_id * _get_shared_memory().num_cores_per_device + local_core_id class Semaphore: @@ -165,45 +175,45 @@ def __init__(self, semaphore_id=None): # easier to do when we're using single integer device IDs.) self.cv = threading.Condition() - self.counts = np.zeros(shared_memory.num_devices, dtype=np.int32) + self.counts = np.zeros(shared_memory.num_cores, dtype=np.int32) self.interpret_params = shared_memory.interpret_params if self.interpret_params.detect_races: # We associate a vector clock with each count in self.counts. Whenever # self.counts[i] is signaled, self.clocks[i] is updated with the vector - # clock of the signaling device. Whenever device i successfully waits on - # self.counts[i], the vector clock of device i is updated with + # clock of the signaling core. Whenever core i successfully waits on + # self.counts[i], the vector clock of core i is updated with # self.clocks[i]. # # TODO(jburnim): Model happens-before more precisely for the case where # semaphores are over-signaled. - self.clocks = [None] * shared_memory.num_devices + self.clocks = [None] * shared_memory.num_cores - def signal(self, inc, device_id, clock): - """Signal the semaphore on `device_id` by `inc`. + def signal(self, inc, global_core_id, clock): + """Signal the semaphore on `(device_id, core_id)` by `inc`. Args: inc: A positive integer. The amount by which to increment the semaphore on the target device. - device_id: The ID of the target device. + global_core_id: The ID of the target core. clock: The vector clock of the signaling device at the time of the signal. 
""" - device_id = int(device_id) + global_core_id = int(global_core_id) with self.cv: - self.counts[device_id] += inc + self.counts[global_core_id] += inc if self.interpret_params.detect_races: - if self.clocks[device_id] is None: - self.clocks[device_id] = copy_vector_clock(clock) + if self.clocks[global_core_id] is None: + self.clocks[global_core_id] = copy_vector_clock(clock) else: - update_vector_clock(self.clocks[device_id], clock) + update_vector_clock(self.clocks[global_core_id], clock) self.cv.notify_all() - def read(self, device_id): + def read(self, global_core_id): with self.cv: - return self.counts[device_id] + return self.counts[global_core_id] - def wait(self, value, device_id, *, is_dma=False): - device_id = int(device_id) + def wait(self, value, global_core_id, *, is_dma=False): + global_core_id = int(global_core_id) shared_memory = _get_shared_memory() # TODO(jburnim): @@ -214,14 +224,14 @@ def wait(self, value, device_id, *, is_dma=False): # Simple implementation for non-DMA semaphores. if not is_dma or (self.interpret_params.dma_execution_mode == "eager"): with self.cv: - while self.counts[device_id] < value: + while self.counts[global_core_id] < value: self.cv.wait() - self.counts[device_id] -= value + self.counts[global_core_id] -= value if self.interpret_params.detect_races: - clock = copy_vector_clock(self.clocks[device_id]) + clock = copy_vector_clock(self.clocks[global_core_id]) if self.interpret_params.detect_races: with shared_memory.lock: - update_vector_clock(shared_memory.clocks[device_id], clock) + update_vector_clock(shared_memory.clocks[global_core_id], clock) return # For DMA semaphores (when dma_execution_mode=='on_wait'), while our count @@ -235,15 +245,15 @@ def wait(self, value, device_id, *, is_dma=False): while True: clock = None with self.cv: - if self.counts[device_id] >= value: - self.counts[device_id] -= value + if self.counts[global_core_id] >= value: + self.counts[global_core_id] -= value if self.interpret_params.detect_races: - clock = copy_vector_clock(self.clocks[device_id]) + clock = copy_vector_clock(self.clocks[global_core_id]) else: return if clock is not None: with shared_memory.lock: - update_vector_clock(shared_memory.clocks[device_id], clock) + update_vector_clock(shared_memory.clocks[global_core_id], clock) return with shared_memory.lock: @@ -258,25 +268,32 @@ def wait(self, value, device_id, *, is_dma=False): with dma.lock: if dma.virtual_device_id is None: dma.virtual_device_id = np.random.randint( - shared_memory.num_devices, NUM_VIRTUAL_DEVICES) + shared_memory.num_devices, NUM_VIRTUAL_CORES) if dma.state == DmaState.STARTED: # Do the read. 
if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) dma.data = get(dma.src_device_id, + dma.src_local_core_id, dma.src_memory_space, dma.src_buffer_id, dma.src_transforms, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.src_sem is not None: data_size = dma.data.itemsize * dma.data.size dma.src_sem.signal( - data_size, device_id=dma.src_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.src_device_id, dma.src_local_core_id + ), + clock=dma.clock, + ) dma.state = DmaState.READ if dma.src_sem is self: @@ -290,18 +307,25 @@ def wait(self, value, device_id, *, is_dma=False): if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) store(dma.dst_device_id, + dma.dst_local_core_id, dma.dst_memory_space, dma.dst_buffer_id, dma.dst_transforms, dma.data, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) data_size = dma.data.itemsize * dma.data.size dma.dst_sem.signal( - data_size, device_id=dma.dst_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.dst_device_id, dma.dst_local_core_id + ), + clock=dma.clock, + ) dma.data = None dma.state = DmaState.COMPLETED @@ -317,10 +341,12 @@ class DMA: id: int src_device_id: int + src_local_core_id: int src_memory_space: int src_buffer_id: int src_transforms: tuple[Any, ...] dst_device_id: int + dst_local_core_id: int dst_memory_space: int dst_buffer_id: int dst_transforms: tuple[Any, ...] 
@@ -339,13 +365,14 @@ class DMA: @dataclasses.dataclass class RaceDetectionState: - num_devices: int + num_cores: int + - # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + # (memory_space, buffer_id, device_id, local_core_id) -> [(device_id, local_core_id, VectorClock, range)] reads: dict = dataclasses.field( default_factory=lambda: collections.defaultdict(list)) - # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + # (memory_space, buffer_id, device_id, local_core_id) -> [(device_id, local_core_id, VectorClock, range)] writes: dict = dataclasses.field( default_factory=lambda: collections.defaultdict(list)) @@ -387,7 +414,10 @@ def ranges_overlap(range1: tuple[slice | int, ...], return all(slices_overlap(r1, r2) for r1, r2 in itertools.zip_longest(range1, range2, fillvalue=slice(None))) -def check_read(device_id, clock, buffer_key, rnge, source_info=None): + +def check_read( + device_id, local_core_id, clock, buffer_key, rnge, source_info=None +): if source_info is not None: user_frame = source_info_util.summarize(source_info) else: @@ -396,24 +426,36 @@ def check_read(device_id, clock, buffer_key, rnge, source_info=None): with races.lock: writes = races.writes[buffer_key] num_writes = len(writes) - races.reads[buffer_key].append((device_id, clock, rnge, user_frame)) + races.reads[buffer_key].append( + (device_id, local_core_id, clock, rnge, user_frame) + ) for i in range(num_writes): - write_device_id, write_clock, write_range, write_frame = writes[i] + ( + write_device_id, + write_local_core_id, + write_clock, + write_range, + write_frame, + ) = writes[i] if ordered(write_clock, clock): continue if not ranges_overlap(rnge, write_range): continue # TODO(jburnim): When printing device IDs for reads/writes, distinguish # between real device IDs vs. DMA IDs. - print('RACE DETECTED\n' - f' read of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + print( + f'RACE DETECTED\n read of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n write of' + f' {buffer_key}[{write_range}] from {write_device_id},' + f' {write_local_core_id} {write_frame}' + ) with races.lock: races.races_found = True return -def check_write(device_id, clock, buffer_key, rnge, source_info=None): + +def check_write(device_id, local_core_id, clock, buffer_key, rnge, source_info=None): if source_info is not None: user_frame = source_info_util.summarize(source_info) else: @@ -424,37 +466,50 @@ def check_write(device_id, clock, buffer_key, rnge, source_info=None): reads = races.reads[buffer_key] num_writes = len(writes) num_reads = len(reads) - races.writes[buffer_key].append((device_id, clock, rnge, user_frame)) + races.writes[buffer_key].append((device_id, local_core_id, clock, rnge, user_frame)) # TODO(jburnim): For performance, we should also probably remove any # conflicting reads and writes that happened-before the current write. for i in range(num_writes): - write_device_id, write_clock, write_range, write_frame = writes[i] + ( + write_device_id, + write_local_core_id, + write_clock, + write_range, + write_frame, + ) = writes[i] if ordered(write_clock, clock): continue if not ranges_overlap(rnge, write_range): continue # TODO(jburnim): When printing device IDs for reads/writes, distinguish # between real device IDs vs. DMA IDs. 
- print('RACE DETECTED\n' - f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + print( + f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n write of' + f' {buffer_key}[{write_range}] from {write_device_id},' + f' {write_local_core_id}, {write_frame}' + ) with races.lock: races.races_found = True break for i in range(num_reads): - read_device_id, read_clock, read_range, read_frame = reads[i] + read_device_id, read_local_core_id, read_clock, read_range, read_frame = ( + reads[i] + ) if ordered(read_clock, clock): continue if not ranges_overlap(rnge, read_range): continue # TODO(jburnim): When printing device IDs for reads/writes, distinguish # between real device IDs vs. DMA IDs. - print('RACE DETECTED\n' - f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' read of {buffer_key}[{read_range}] from {read_device_id}, {read_frame}') + print( + f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n read of {buffer_key}[{read_range}]' + f' from {read_device_id}, {read_local_core_id}, {read_frame}' + ) with races.lock: races.races_found = True return @@ -464,13 +519,14 @@ def check_write(device_id, clock, buffer_key, rnge, source_info=None): class SharedMemory: interpret_params: TPUInterpretParams num_devices: int + num_cores_per_device: int clocks: list[VectorClock] barrier: threading.Barrier clean_up_barrier: threading.Barrier - # (memory_space, buffer_id, device_id) -> NumPy array + # (memory_space, buffer_id, device_id, local_core_id) -> NumPy array # TODO(jburnim): Handle Megacore. - mem: dict[tuple[int, int, int], np.ndarray] = dataclasses.field( + mem: dict[tuple[str, int, int, int], np.ndarray] = dataclasses.field( default_factory=dict) # semaphore_id -> Semaphore @@ -478,15 +534,18 @@ class SharedMemory: # (semaphore_id, device_id) # -> list of DMAs that will signal the semaphore on the given device + # TODO(jburnim): Fix uses of `dmas_by_sem` to align with the two lines of + # documentation above, i.e. index `dmas_by_sem` with + # `(semaphore_id, device_id)` (currently indexed with `semaphore_id only). dmas_by_sem: dict[tuple[int, int], list[DMA]] = dataclasses.field( default_factory=lambda: collections.defaultdict(list)) lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) - # device_id -> next buffer ID - next_buffer_id: dict[int, int] = dataclasses.field( + # (device_id, local_core_id) -> next buffer ID + next_buffer_id: dict[tuple[int, int], int] = dataclasses.field( default_factory=lambda: collections.defaultdict(lambda: 100)) - # device_id -> next semaphore ID + # global_core_id -> next semaphore ID next_semaphore_id: dict[int, int] = dataclasses.field( default_factory=lambda: collections.defaultdict(lambda: 2000)) @@ -494,6 +553,10 @@ class SharedMemory: deallocated_bytes: int = 0 + @property + def num_cores(self) -> int: + return self.num_devices * self.num_cores_per_device + # TODO(jburnim): Do we want to support multiple instances of SharedMemory? # Maybe for running multiple distinct interpreted computations in parallel? 
@@ -510,34 +573,54 @@ def _clear_shared_memory(): with _shared_memory_init_lock: _shared_memory = None -def _initialize_shared_memory(device_id, num_devices, *, interpret_params): + +def _initialize_shared_memory( + device_id, num_devices, num_cores_per_device, *, interpret_params +): global _shared_memory del device_id num_devices = int(num_devices) + num_cores_per_device = int(num_cores_per_device) + num_cores = num_devices * num_cores_per_device with _shared_memory_init_lock: if _shared_memory is None: _shared_memory = SharedMemory( interpret_params=interpret_params, num_devices=num_devices, - clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], + num_cores_per_device=num_cores_per_device, + clocks=[make_vector_clock(num_cores) for _ in range(num_cores)], barrier=threading.Barrier( num_devices, action=_update_clocks_for_global_barrier), clean_up_barrier=threading.Barrier( num_devices, action=_clear_shared_memory)) - assert _shared_memory.num_devices == num_devices + assert _shared_memory.num_cores == num_cores global races - races = RaceDetectionState(num_devices=num_devices) + races = RaceDetectionState(num_cores=num_cores) -def _update_clocks_for_global_barrier(): +def _update_clocks(low_global_core_id, high_global_core_id): + """Synchronizes the vector clocks for the cores with ids in the range between the two arguments.""" shared_memory = _get_shared_memory() + # Despite only updating the vector clocks for some cores, we still need to + # hold the global lock to ensure that no other devices are concurrently + # accessing the same vector clocks. with shared_memory.lock: - # Set the vector clock for device 0 to the max over all device clocks. - for c in shared_memory.clocks[1:]: - update_vector_clock(shared_memory.clocks[0], c) - # Set all other device vector clocks to the max over all the clocks. - for c in shared_memory.clocks[1:]: - update_vector_clock(c, shared_memory.clocks[0]) + for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]: + update_vector_clock(shared_memory.clocks[low_global_core_id], c) + for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]: + update_vector_clock(c, shared_memory.clocks[low_global_core_id]) + +def _update_clocks_for_device_barrier(device_id): + """Synchronizes the vector clocks for the cores on the given device.""" + shared_memory = _get_shared_memory() + low_core_id = device_id * shared_memory.num_cores_per_device + high_core_id = (device_id + 1) * shared_memory.num_cores_per_device + _update_clocks(low_core_id, high_core_id) + +def _update_clocks_for_global_barrier(): + """Synchronizes all vector clocks.""" + shared_memory = _get_shared_memory() + _update_clocks(0, shared_memory.num_cores) def _barrier(device_id): device_id = int(device_id) @@ -564,30 +647,80 @@ def _validate(device_id): f'Semaphore {sem.id} has non-zero count for {device_id} at ' f'kernel exit: {sem.counts[device_id]}') -def _allocate_buffer(device_id, memory_space, val): +def _allocate_buffer( + device_id: Array, + local_core_id: Array | None, + memory_space: Array, + val: Array, +): + """Allocates a memory buffer on the device with id `device_id` and core with id `local_core_id`. + + Args: + device_id: Singleton array holding the device id where the buffer will be + allocated. + local_core_id: None or singleton array holding the core id where the buffer + will be allocated. If None, a buffer will be allocated on each cores on + the device. 
+    memory_space: Singleton array indicating the memory space to allocate the
+      buffer in. If the corresponding memory space is "any" (i.e. HBM), at most
+      one buffer will be allocated and it will belong to (local) core id 0.
+    val: Array of values to initialize the allocated buffer with.
+
+  Returns:
+    Integer id for the allocated buffer.
+  """
   device_id = int(device_id)
-  memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
+  memory_space_str = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
+  del memory_space
   val = np.array(val)
 
   shared_memory = _get_shared_memory()
+
+  if local_core_id is None:
+    local_core_id_int = 0
+    local_core_ids = tuple(range(shared_memory.num_cores_per_device))
+  else:
+    local_core_id_int = int(local_core_id)
+    local_core_ids = (local_core_id_int,)
+  del local_core_id
+
+  local_core_id_to_buffer_id = {}
   with shared_memory.lock:
-    buffer_id = shared_memory.next_buffer_id[device_id]
-    shared_memory.next_buffer_id[device_id] = buffer_id + 1
-    # TODO(jburnim): Add options for initializing memory (e.g., with NaNs,
-    # with zeros, or with the buffer ID).
-    shared_memory.mem[(memory_space, buffer_id, device_id)] = val
+    for lci in local_core_ids:
+      buffer_id = shared_memory.next_buffer_id[(device_id, lci)]
+      shared_memory.next_buffer_id[(device_id, lci)] = buffer_id + 1
+      if lci == 0 or memory_space_str != 'any':
+        # If allocating in HBM, only actually allocate a buffer for local core
+        # id 0.
+        # TODO(jburnim): Add options for initializing memory (e.g., with NaNs,
+        # with zeros, or with the buffer ID).
+        shared_memory.mem[(memory_space_str, buffer_id, device_id, lci)] = val
+
+      local_core_id_to_buffer_id[lci] = buffer_id
+
+  # The buffer ids should always be kept in sync across all cores.
+  assert all(
+      buffer_id == local_core_id_to_buffer_id[local_core_id_int]
+      for buffer_id in local_core_id_to_buffer_id.values()
+  )
 
   # TODO(jburnim): Raise an error if buffer_id is too big for int16.
-  return np.int16(buffer_id)
+  return np.int16(local_core_id_to_buffer_id[local_core_id_int])
 
-def _deallocate_buffer(device_id, memory_space, buffer_id):
+def _deallocate_buffer(device_id, local_core_id, memory_space, buffer_id):
   device_id = int(device_id)
+  local_core_id = int(local_core_id)
   memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
   buffer_id = int(buffer_id)
 
+  if memory_space == 'any':
+    local_core_id = 0
+
   shared_memory = _get_shared_memory()
   with shared_memory.lock:
-    buff = shared_memory.mem.pop((memory_space, buffer_id, device_id))
+    buff = shared_memory.mem.pop(
+        (memory_space, buffer_id, device_id, local_core_id)
+    )
     shared_memory.deallocated_bytes += buff.size * buff.itemsize
     del buff
 
@@ -600,26 +733,80 @@ def _deallocate_buffer(device_id, memory_space, buffer_id):
   # why arrays are not getting freed without this.
   gc.collect()
 
-def _allocate_semaphores(device_id, shape):
+
+def _allocate_semaphores(
+    device_id: Array, local_core_id: Array | None, shape: Array
+):
+  """Allocates semaphores on the device with id `device_id` and core with id `local_core_id`.
+
+  The number of semaphores allocated is given by the product of the entries in
+  `shape`.
+
+  Since for each semaphore id there is really only one global `Semaphore`
+  object, 'allocation' of semaphores per device and core here means that the
+  internal counter of semaphore ids that is held by `SharedMemory` is
+  incremented for the given device and core (or for all cores on the device if
+  argument `local_core_id` is None, see below).
+ + Args: + device_id: Singleton array holding the id for the device where the + semaphores will be allocated. + local_core_id: None or singleton array holding the id for the core where the + semaphores will be allocated. If None, semaphores will be allocated on all + cores on the device. + shape: Shape of the semaphore array to allocate. + + Returns: + Array of semaphore ids. + """ device_id = int(device_id) shape = tuple(map(int, shape)) num_semaphores = math.prod(shape) shared_memory = _get_shared_memory() + + if local_core_id is None: + local_core_id_int = 0 + global_core_ids = tuple( + _get_global_core_id(device_id, core_id) + for core_id in range(shared_memory.num_cores_per_device) + ) + else: + local_core_id_int = int(local_core_id) + global_core_ids = (_get_global_core_id(device_id, local_core_id_int),) + del local_core_id + + global_core_id_to_semaphore_id = {} with shared_memory.lock: - semaphore_id = shared_memory.next_semaphore_id[device_id] - shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores - for i in range(semaphore_id, semaphore_id + num_semaphores): - if i not in shared_memory.sem: - shared_memory.sem[i] = Semaphore(i) + for gci in global_core_ids: + semaphore_id = shared_memory.next_semaphore_id[gci] + shared_memory.next_semaphore_id[gci] = ( + semaphore_id + num_semaphores + ) + + # Ensure that only one global `Semaphore` object is allocated for each + # `semaphore_id`. + for i in range(semaphore_id, semaphore_id + num_semaphores): + if i not in shared_memory.sem: + shared_memory.sem[i] = Semaphore(i) + + global_core_id_to_semaphore_id[gci] = semaphore_id + + global_core_id = _get_global_core_id(device_id, local_core_id_int) + # The semaphore ids should always be kept in sync across all cores. + assert all( + semaphore_id == global_core_id_to_semaphore_id[global_core_id] + for semaphore_id in global_core_id_to_semaphore_id.values() + ) # NOTE: For now, we use a relatively uncommon datatype (int16) for # semaphore (and buffer) IDs, so these values are more easily identifiable # in kernels. # # TODO(jburnim): Raise an error if any IDs are too big for int16. 
- return np.int16( - range(semaphore_id, semaphore_id + num_semaphores) + semaphore_id = global_core_id_to_semaphore_id[global_core_id] + return np.arange( + semaphore_id, semaphore_id + num_semaphores, dtype=np.int16 ).reshape(shape) @@ -693,24 +880,48 @@ def _to_range(transforms) -> tuple[slice | int, ...]: ret, tuple(_transform_slice_or_index(i) for i in transform.indices)) return ret -def get(device_id, memory_space, buffer_id, transforms, *, - src_device_id=None, clock=None, source_info=None): +def _to_int(x : int | Array | None) -> int | None: + """Converts a value to an integer, or returns None if the value is None.""" + if x is None: + return None + return int(x) + +def get( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + *, + src_device_id=None, + src_local_core_id=None, + clock=None, + source_info=None, +): device_id = int(device_id) + local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) try: transforms = jax.tree.map(int, transforms) except: raise ValueError('Advanced indexers are not supported on TPU') + src_device_id = _to_int(src_device_id) + src_local_core_id = _to_int(src_local_core_id) + + local_core_id_for_buffer = 0 if memory_space == 'any' else local_core_id + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: read_range = _to_range(transforms) if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) if clock is None: - clock = copy_vector_clock(shared_memory.clocks[device_id]) - buffer = shared_memory.mem[(memory_space, buffer_id, device_id)] + clock = copy_vector_clock(shared_memory.clocks[global_core_id]) + buffer = shared_memory.mem[ + (memory_space, buffer_id, device_id, local_core_id_for_buffer) + ] ret = buffer[read_range].copy() if transforms: # TODO(jburnim): Instead of using NDIndexer, do the computation ourselves @@ -718,20 +929,43 @@ def get(device_id, memory_space, buffer_id, transforms, *, expected_shape = transforms[-1].get_indexer_shape() if expected_shape != ret.shape[:len(expected_shape)]: raise ValueError( - f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): ' - f'reading [{read_range}] but bufer has shape {buffer.shape} .') + 'Out-of-bounds read of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}):' + f' reading [{read_range}] but bufer has shape {buffer.shape} .' 
+ ) if shared_memory.interpret_params.detect_races: if src_device_id is None: src_device_id = device_id - check_read(src_device_id, clock, (memory_space, buffer_id, device_id), - read_range, source_info=source_info) + if src_local_core_id is None: + src_local_core_id = local_core_id + check_read( + src_device_id, + src_local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + read_range, + source_info=source_info, + ) return ret -def store(device_id, memory_space, buffer_id, transforms, val, *, - src_device_id=None, clock=None, source_info=None): + +def store( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + val, + *, + src_device_id=None, + src_local_core_id=None, + clock=None, + source_info=None, +): device_id = int(device_id) + local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) try: @@ -739,38 +973,67 @@ def store(device_id, memory_space, buffer_id, transforms, val, *, except: raise ValueError('Advanced indexers are not supported on TPU') val = np.array(val) + src_device_id = _to_int(src_device_id) + src_local_core_id = _to_int(src_local_core_id) + + local_core_id_for_buffer = 0 if memory_space == 'any' else local_core_id + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) if clock is None: - clock = copy_vector_clock(shared_memory.clocks[device_id]) + clock = copy_vector_clock(shared_memory.clocks[global_core_id]) - buff = shared_memory.mem[(memory_space, buffer_id, device_id)] + buff = shared_memory.mem[ + (memory_space, buffer_id, device_id, local_core_id_for_buffer) + ] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. write_range = _to_range(transforms) # TODO(jburnim): Better error message if this raises? in_bounds_shape = buff[write_range].shape if in_bounds_shape != val.shape: raise ValueError( - f'Out-of-bounds write of ({device_id} {memory_space} {buffer_id}): ' - f'writing [{write_range}] but buffer has shape {buff.shape} .') + 'Out-of-bounds write of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}): writing' + f' [{write_range}] but buffer has shape {buff.shape} .' 
+ ) buff[write_range] = val if shared_memory.interpret_params.detect_races: if src_device_id is None: src_device_id = device_id - check_write(src_device_id, clock, (memory_space, buffer_id, device_id), - write_range, source_info=source_info) + if src_local_core_id is None: + src_local_core_id = local_core_id + check_write( + src_device_id, + src_local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + write_range, + source_info=source_info, + ) + -def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, - source_info=None): +def swap( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + val, + mask, + *, + source_info=None, +): device_id = int(device_id) + local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) try: transforms = jax.tree.map(int, transforms) + # jax.debug.print(f'swap: {transforms}') except: raise ValueError('Advanced indexers are not supported on TPU') val = np.array(val) @@ -778,12 +1041,17 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, if mask is not None: assert mask.shape == val.shape + local_core_id_for_buffer = 0 if memory_space == 'any' else local_core_id + global_core_id = _get_global_core_id(device_id, local_core_id) + shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) - buff = shared_memory.mem[(memory_space, buffer_id, device_id)] + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) + clock = copy_vector_clock(shared_memory.clocks[global_core_id]) + buff = shared_memory.mem[ + (memory_space, buffer_id, device_id, local_core_id_for_buffer) + ] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. read_write_range = _to_range(transforms) # TODO(jburnim): Better error message if this raises? @@ -792,8 +1060,11 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, if mask is None: if in_bounds_shape != val.shape: raise ValueError( - f'Out-of-bounds swap of ({device_id} {memory_space} {buffer_id}): ' - f'swapping [{read_write_range}] but buffer has shape {buff.shape} .') + 'Out-of-bounds swap of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}):' + f' swapping [{read_write_range}] but buffer has shape' + f' {buff.shape} .' + ) buff[read_write_range] = val return raw_result.copy() @@ -804,8 +1075,10 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, # TODO(jburnim): Include indices of out-of-bounds locations where mask # is True. raise ValueError( - f'Out-of-bounds masked swap of ({device_id} {memory_space} {buffer_id}): ' - f'swapping [{read_write_range}] but buffer has shape {buff.shape} . ') + 'Out-of-bounds masked swap of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}): swapping' + f' [{read_write_range}] but buffer has shape {buff.shape} . 
' + ) in_bounds_idx = tuple(slice(i) for i in in_bounds_shape) result = val.copy() @@ -815,8 +1088,14 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, mask[in_bounds_idx], val[in_bounds_idx], raw_result) if shared_memory.interpret_params.detect_races: - check_write(device_id, clock, (memory_space, buffer_id, device_id), - read_write_range, source_info=source_info) + check_write( + device_id, + local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + read_write_range, + source_info=source_info, + ) return result def execute_dma(dma): @@ -828,17 +1107,19 @@ def execute_dma(dma): if dma.virtual_device_id is None: # See comment in Semaphore.wait . dma.virtual_device_id = np.random.randint( - shared_memory.num_devices, NUM_VIRTUAL_DEVICES) + shared_memory.num_cores, NUM_VIRTUAL_CORES) # Do the read. if shared_memory.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) dma.data = get(dma.src_device_id, + dma.src_local_core_id, dma.src_memory_space, dma.src_buffer_id, dma.src_transforms, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) data_size = dma.data.itemsize * dma.data.size @@ -847,19 +1128,26 @@ def execute_dma(dma): inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.src_sem is not None: dma.src_sem.signal( - data_size, device_id=dma.src_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.src_device_id, dma.src_local_core_id + ), + clock=dma.clock, + ) dma.state = DmaState.READ # Do the write. if shared_memory.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) store(dma.dst_device_id, + dma.dst_local_core_id, dma.dst_memory_space, dma.dst_buffer_id, dma.dst_transforms, dma.data, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) # Signal the receive semaphore. 
@@ -867,7 +1155,12 @@ def execute_dma(dma): inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.dst_sem is not None: dma.dst_sem.signal( - data_size, device_id=dma.dst_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.dst_device_id, dma.dst_local_core_id + ), + clock=dma.clock, + ) dma.data = None dma.state = DmaState.COMPLETED @@ -879,11 +1172,24 @@ def print_memory(device_id): with shared_memory.lock: print(shared_memory.mem) -def dma_start(device_id, src_memory_space, src_id, src_transforms, - dst_memory_space, dst_id, dst_transforms, - dst_sem_id, src_sem_id, dst_device_id, - source_info=None): + +def dma_start( + device_id, + src_local_core_id, + src_memory_space, + src_id, + src_transforms, + dst_memory_space, + dst_id, + dst_transforms, + dst_sem_id, + src_sem_id, + dst_device_id, + source_info=None, +): device_id = int(device_id) + src_local_core_id = int(src_local_core_id) + src_global_core_id = _get_global_core_id(device_id, src_local_core_id) src_memory_space, src_id = int(src_memory_space), int(src_id) src_transforms = jax.tree.map(int, src_transforms) dst_memory_space, dst_id = int(dst_memory_space), int(dst_id) @@ -902,15 +1208,25 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, clock = None if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) + inc_vector_clock( + shared_memory.clocks[src_global_core_id], src_global_core_id + ) + clock = copy_vector_clock(shared_memory.clocks[src_global_core_id]) dma_id = shared_memory.next_dma_id shared_memory.next_dma_id += 1 dma = DMA( dma_id, - device_id, src_memory_space, src_id, src_transforms, - dst_device_id, dst_memory_space, dst_id, dst_transforms, + device_id, + src_local_core_id, + src_memory_space, + src_id, + src_transforms, + dst_device_id, + src_local_core_id, # Same core on destination device as on source. 
+ dst_memory_space, + dst_id, + dst_transforms, src_sem, dst_sem, clock=clock, @@ -926,52 +1242,61 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, assert shared_memory.interpret_params.dma_execution_mode == 'eager' execute_dma(dma) -def dma_wait(device_id, sem_id, size): +def dma_wait(device_id, local_core_id, sem_id, size): device_id = int(device_id) + local_core_id = int(local_core_id) sem_id = int(sem_id) size = int(size) + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) sem = shared_memory.sem[sem_id] - sem.wait(size, device_id, is_dma=True) + sem.wait(size, global_core_id, is_dma=True) -def semaphore_signal(device_id, sem_id, inc, target_device_id, - target_core_index): +def semaphore_signal(device_id, local_core_id, sem_id, inc, target_device_id, + target_local_core_id): device_id = int(device_id) + local_core_id = int(local_core_id) sem_id = int(sem_id) inc = int(inc) + src_global_core_id = _get_global_core_id(device_id, local_core_id) if target_device_id is None: target_device_id = device_id else: target_device_id = int(target_device_id) - if target_core_index is not None: - if int(target_core_index) != 0: - raise NotImplementedError('semaphore_signal with target_core_index != 0') + if target_local_core_id is None: + target_local_core_id = 0 shared_memory = _get_shared_memory() with shared_memory.lock: clock = None if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) + inc_vector_clock( + shared_memory.clocks[src_global_core_id], src_global_core_id + ) + clock = copy_vector_clock(shared_memory.clocks[src_global_core_id]) sem = shared_memory.sem[sem_id] - sem.signal(inc, target_device_id, clock) + sem.signal( + inc, _get_global_core_id(target_device_id, target_local_core_id), clock + ) -def semaphore_wait(device_id, sem_id, value): +def semaphore_wait(device_id, local_core_id, sem_id, value): device_id = int(device_id) + local_core_id = int(local_core_id) sem_id = int(sem_id) value = int(value) + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) sem = shared_memory.sem[sem_id] - sem.wait(value, device_id) + sem.wait(value, global_core_id) def _compute_transformed_shape_and_dtype(shape, dtype, transforms): for transform in transforms: @@ -1022,7 +1347,10 @@ class Placeholder: shape: tuple[int, ...] dtype: jnp.dtype -def _interpret_jaxpr(jaxpr, *args, mesh, compiler_params, interpret_params): + +def _interpret_jaxpr( + jaxpr, *args, mesh, local_core_id, compiler_params, interpret_params +): env = {} def read(var): @@ -1054,8 +1382,12 @@ def write(var, value): # - Handle other higher-order primitives? # - Megacore. 
_interpret = functools.partial( - _interpret_jaxpr, mesh=mesh, compiler_params=compiler_params, - interpret_params=interpret_params) + _interpret_jaxpr, + mesh=mesh, + local_core_id=local_core_id, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) for eqn in jaxpr.eqns: with source_info_util.user_context( eqn.source_info.traceback, name_stack=eqn.source_info.name_stack): @@ -1065,7 +1397,9 @@ def write(var, value): # not need to do any reads if `interpret_params.skip_floating_point_ops` # is True. If this is the case, we want to avoid materializing the read # array into the jaxpr when this function is traced. - deferred_invals = functools.partial(jax._src.util.safe_map, read, eqn.invars) + deferred_invals = functools.partial( + jax._src.util.safe_map, read, eqn.invars + ) if prim is primitives.load_p: (ref, transforms, mask, _) = jax.tree.unflatten( @@ -1076,6 +1410,7 @@ def write(var, value): functools.partial(get, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], ref, transforms, @@ -1088,6 +1423,7 @@ def write(var, value): functools.partial(swap, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], ref, transforms, @@ -1174,7 +1510,8 @@ def f(*args, jaxpr): 'run_scoped_p with collective axes is not supported' ) # Allocate a buffer or semaphore for each element of - # eqn.params['jaxpr'].invars . + # eqn.params['jaxpr'].invars. It is assumed that each core + # runs the same sequence of `run_scoped`s. allocs = [] for v in eqn.params['jaxpr'].invars: if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: @@ -1182,6 +1519,7 @@ def f(*args, jaxpr): _allocate_semaphores, jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), device_id, + local_core_id, v.aval.shape, ordered=True)) else: @@ -1189,6 +1527,7 @@ def f(*args, jaxpr): _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], _uninitialized_value( v.aval.shape, v.aval.dtype, interpret_params), @@ -1211,6 +1550,7 @@ def f(*args, jaxpr): _deallocate_buffer, None, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], a, ordered=True) @@ -1221,6 +1561,7 @@ def f(*args, jaxpr): functools.partial(get, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], invals[0], jax.tree.unflatten(eqn.params['tree'], invals[1:]), @@ -1232,6 +1573,7 @@ def f(*args, jaxpr): functools.partial(swap, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], invals[0], jax.tree.unflatten(eqn.params['tree'], invals[2:]), @@ -1259,6 +1601,7 @@ def f(*args, jaxpr): functools.partial(dma_start, source_info=eqn.source_info), (), device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], src, src_transforms, TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], @@ -1287,6 +1630,7 @@ def f(*args, jaxpr): dma_wait, (), device_id, + local_core_id, state_discharge.transform_array(dst_sem, dst_sem_transforms), math.prod(read_shape) * read_dtype.itemsize, ordered=True) @@ -1309,6 +1653,7 @@ def f(*args, jaxpr): semaphore_signal, (), device_id, + local_core_id, state_discharge.transform_array(sem, sem_transforms), inc, 
target_device_id, @@ -1323,6 +1668,7 @@ def f(*args, jaxpr): semaphore_wait, (), device_id, + local_core_id, state_discharge.transform_array(sem, sem_transforms), value, ordered=True) @@ -1358,8 +1704,15 @@ def _compute_start_indices( block_mapping, loop_idx, *args, mesh, compiler_params, interpret_params): jaxpr = block_mapping.index_map_jaxpr block_indices = _interpret_jaxpr( - jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, mesh=mesh, - compiler_params=compiler_params, interpret_params=interpret_params) + jaxpr.jaxpr, + *jaxpr.consts, + *loop_idx, + *args, + mesh=mesh, + local_core_id=0, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) def _get_start_index(i, b): match b: case pallas_core.Squeezed(): @@ -1397,12 +1750,12 @@ def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) - def _get_parallel_dim_semantics( - compiler_params: dict[str, Any], grid: tuple[int, ...] + compiler_params: dict[str, Any], num_dimensions_in_grid: int, ) -> tuple[bool, ...]: - """Returns a tuple of booleans indicating whether the corresponding dimension in `grid` is parallel.""" + """Returns a tuple of booleans indicating whether the corresponding dimension in the grid is parallel.""" mosaic_params = _get_mosaic_params(compiler_params) if mosaic_params.dimension_semantics is None: - return (False,) * len(grid) + return (False,) * num_dimensions_in_grid return tuple(ds == 'parallel' for ds in mosaic_params.dimension_semantics) _GridPointCoordinatesPerDim = tuple[Array, ...] @@ -1432,7 +1785,7 @@ def _get_randomized_grid_coordinates( dimensions. """ parallel_semantics_per_dim = _get_parallel_dim_semantics( - compiler_params, grid + compiler_params, len(grid) ) key = jax.random.key(random_seed or 0) @@ -1484,6 +1837,23 @@ def _get_grid_point( return jnp.array(grid_point, dtype=np.int32) +def _get_next_local_core_id( + local_core_id: int, + parallel_semantics_per_dim: tuple[bool, ...], + grid_point: Array, + next_grid_point: Array, + interpret_params: TPUInterpretParams, +) -> int: + delta = next_grid_point - grid_point + assert delta.shape == (len(parallel_semantics_per_dim),) + parallel_semantics_per_dim = jnp.array(parallel_semantics_per_dim) + deltas_along_parallel_dims = jnp.where(parallel_semantics_per_dim, delta, 0) + return jax.lax.cond( + jnp.any(deltas_along_parallel_dims), + lambda: (local_core_id + 1) % interpret_params.num_cores_per_device, + lambda: local_core_id, + ) + def _uninitialized_value(shape, dtype, interpret_params): if interpret_params.uninitialized_memory == 'nan': if jnp.issubdtype(dtype, jnp.floating): @@ -1562,6 +1932,7 @@ def interpret_pallas_call( (), device_id, num_devices, + interpret_params.num_cores_per_device, ordered=True) # Pad input arguments. 
@@ -1591,6 +1962,7 @@ def interpret_pallas_call( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, + None, # local_core_id TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], input_args[i], ordered=True)) @@ -1613,14 +1985,19 @@ def interpret_pallas_call( bm.array_shape_dtype.dtype, interpret_params) padded_val = _pad_to_block_dimension( - out_val, output_block_shapes[i], interpret_params) - output_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - padded_val, - ordered=True)) + out_val, output_block_shapes[i], interpret_params + ) + output_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + padded_val, + ordered=True, + ) + ) output_buffer_shapes.append(padded_val.shape) output_vals.append(out_val) @@ -1630,25 +2007,34 @@ def interpret_pallas_call( for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): assert var.aval.shape == val.shape assert var.aval.dtype == val.dtype - scalar_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], - val, - ordered=True)) + scalar_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], + val, + ordered=True, + ) + ) + kernel_buffer_ids = scalar_buffer_ids.copy() for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): output_idx = i - grid_mapping.num_inputs is_input = i < grid_mapping.num_inputs is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - kernel_buffer_ids.append(callback.io_callback( - _allocate_semaphores, - jax.ShapeDtypeStruct(var.aval.shape, jnp.int16), - device_id, - var.aval.shape, - ordered=True)) + kernel_buffer_ids.append( + callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(var.aval.shape, jnp.int16), + device_id, + None, # local_core_id + var.aval.shape, + ordered=True, + ) + ) elif _is_any(var.aval.memory_space): # Use the already-allocated HBM input or output buffer. 
# @@ -1661,14 +2047,19 @@ def interpret_pallas_call( if is_output: kernel_buffer_ids.append(output_buffer_ids[output_idx]) else: - kernel_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - _uninitialized_value( - var.aval.shape, var.aval.dtype, interpret_params), - ordered=True)) + kernel_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id, + TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + _uninitialized_value( + var.aval.shape, var.aval.dtype, interpret_params + ), + ordered=True, + ) + ) if _get_mosaic_params(compiler_params).collective_id is None: # The kernel doesn't specify its own barrier semaphore, so we do a global @@ -1687,6 +2078,9 @@ def interpret_pallas_call( # Base case is always one iteration when grid is () num_iterations = 1 + parallel_semantics_per_dim = _get_parallel_dim_semantics( + compiler_params, len(grid) + ) randomized_grid_coordinates = _get_randomized_grid_coordinates( grid, compiler_params, interpret_params.random_seed # type: ignore[arg-type] ) @@ -1703,18 +2097,38 @@ def _get_local_grid_env(loop_idx): def body( carry: tuple[ - jnp.int32, tuple[jnp.int32, ...], list[jnp.ndarray], list[jnp.ndarray] + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + jnp.int32, + jnp.int32, + list[jnp.ndarray], + list[jnp.ndarray], ], - ): + ) -> tuple[ + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + jnp.int32, + jnp.int32, + list[jnp.ndarray], + list[jnp.ndarray], + ]: """Performs a single iteration of `jaxpr` in the device grid. Execution of `jaxpr` is preceded by reading kernel input buffers and followed by writing kernel output buffers. Args: - carry: (iteration_idx, loop_idx, prev_start_indices, cur_start_indices). + carry: (iteration_idx, loop_idx, grid_point, prev_local_core_id, + cur_local_core_id, prev_start_indices, cur_start_indices). - iteration_idx is the interation index. - loop_idx are the program ids for each grid axis. + - grid_point is the grid point for the current loop iteration. + - prev_local_core_id is the (device-local) core id from the previous + loop iteration. + - cur_local_core_id is the (device-local) core id for the current loop + iteration. - prev_start_indices is a rank-1 array that contains the start indices for the slices of inputs and outputs processed in the previous loop iteration. @@ -1729,9 +2143,16 @@ def body( Returns: The carry for the next iteration. 
""" - iteration_idx, loop_idx, prev_start_indices, cur_start_indices = carry + ( + iteration_idx, + loop_idx, + grid_point, + prev_local_core_id, + cur_local_core_id, + prev_start_indices, + cur_start_indices, + ) = carry if interpret_params.grid_point_recorder is not None: - grid_point = _get_grid_point(loop_idx, randomized_grid_coordinates) callback.io_callback(interpret_params.grid_point_recorder, (), grid_point) with pallas_core.grid_env(_get_local_grid_env(loop_idx)): @@ -1739,6 +2160,13 @@ def body( next_grid_point = _get_grid_point( next_loop_idx, randomized_grid_coordinates ) + next_local_core_id = _get_next_local_core_id( + cur_local_core_id, + parallel_semantics_per_dim, + grid_point, + next_grid_point, + interpret_params, + ) next_start_indices = [ _compute_start_indices( bm, @@ -1774,6 +2202,7 @@ def _store_slice_to_kernel_input(index, input_var): get, jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), device_id, + cur_local_core_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], input_buffer_ids[index], (transform,), @@ -1785,6 +2214,7 @@ def _store_slice_to_kernel_input(index, input_var): store, (), device_id, + cur_local_core_id, TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], input_ids[index], (), @@ -1799,6 +2229,7 @@ def _store_slice_to_kernel_input(index, input_var): assert len(prev_start_indices[j].shape) == 1 jax.lax.cond( (iteration_idx == 0) + | (cur_local_core_id != prev_local_core_id) | jax.lax.reduce_or( cur_start_indices[j] != prev_start_indices[j], axes=(0,) ), @@ -1807,9 +2238,14 @@ def _store_slice_to_kernel_input(index, input_var): ) # Invoke the kernel. - _interpret_jaxpr(jaxpr, *kernel_buffer_ids, mesh=mesh, - compiler_params=compiler_params, - interpret_params=interpret_params) + _interpret_jaxpr( + jaxpr, + *kernel_buffer_ids, + mesh=mesh, + local_core_id=cur_local_core_id, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) # Copy from the kernel buffers to slices of the output in HBM. def _store_to_output_buffer(index, output_var): @@ -1819,6 +2255,7 @@ def _store_to_output_buffer(index, output_var): get, output_var.aval, device_id, + cur_local_core_id, TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], kernel_output_ids[j], (), @@ -1842,6 +2279,7 @@ def _store_to_output_buffer(index, output_var): store, (), device_id, + cur_local_core_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], output_buffer_ids[index], (transform,), @@ -1856,6 +2294,7 @@ def _store_to_output_buffer(index, output_var): assert len(next_start_indices[num_inputs + j].shape) == 1 jax.lax.cond( (iteration_idx + 1 == num_iterations) + | (cur_local_core_id != next_local_core_id) | jax.lax.reduce_or( cur_start_indices[num_inputs + j] != next_start_indices[num_inputs + j], @@ -1865,7 +2304,15 @@ def _store_to_output_buffer(index, output_var): lambda: None, ) - return iteration_idx + 1, next_loop_idx, cur_start_indices, next_start_indices + return ( + iteration_idx + 1, + next_loop_idx, + next_grid_point, + cur_local_core_id, + next_local_core_id, + cur_start_indices, + next_start_indices, + ) initial_loop_idx = (jnp.int32(0),) * len(grid) initial_grid_point = _get_grid_point( @@ -1884,16 +2331,25 @@ def _store_to_output_buffer(index, output_var): for bm in grid_mapping.block_mappings ] # TODO(jburnim): Handle parallel grid dimensions + megacore. 
+ callback.io_callback( + _update_clocks_for_device_barrier, (), device_id, ordered=True + ) _ = lax.while_loop( lambda carry: carry[0] < num_iterations, body, ( jnp.int32(0), initial_loop_idx, + initial_grid_point, + jnp.int32(0), # Previous core id is ignored on the first iteration. + jnp.int32(0), # Current core id is set to 0 for the first iteration. initial_start_indices, # Previous start indices are ignored on the first iteration. initial_start_indices, ), ) + callback.io_callback( + _update_clocks_for_device_barrier, (), device_id, ordered=True + ) # Read the output from the allocated output buffers. ret = [ @@ -1903,6 +2359,7 @@ def _store_to_output_buffer(index, output_var): get, val, device_id, + 0, # local_core_id TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], output_buffer_id, (indexing.NDIndexer.from_indices_shape( diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 1af4b29d60ff..28c63dc3bd9b 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -521,6 +521,65 @@ def alloc(x_vmem_ref, y_vmem_ref, sem): y = f(x) np.testing.assert_array_equal(y, x + 1) + def test_two_cores_along_parallel_dimension_with_race(self): + def kernel(x_ref, o_ref, vmem_ref): + vmem_ref[...] = x_ref[...] + o_ref[...] = x_ref[...] + vmem_ref[...] + + x = jnp.ones((8, 128), jnp.float32) + y = pl.pallas_call( + kernel, + grid=(2,), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + scratch_shapes=[ + pltpu.VMEM(x.shape, x.dtype), + ], + interpret=mosaic_interpret.TPUInterpretParams( + num_cores_per_device=2, + detect_races=True, + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel',), + ), + )(x) + self.assertTrue(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + + def test_two_cores_along_parallel_dimension_no_race(self): + def kernel(x_ref, o_ref, vmem_ref): + vmem_ref[...] = x_ref[...] + o_ref[...] = x_ref[...] + vmem_ref[...] + + x = jnp.ones((16, 128), jnp.float32) + y = pl.pallas_call( + kernel, + grid=(2,), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + out_specs=pl.BlockSpec( + (8, 128), + lambda i: (i, 0), + ), + in_specs=[ + pl.BlockSpec( + (8, 128), + lambda i: (i, 0), + ), + ], + scratch_shapes=[ + pltpu.VMEM((8, 128), x.dtype), + ], + interpret=mosaic_interpret.TPUInterpretParams( + num_cores_per_device=2, + detect_races=True, + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel',) + ), + )(x) + self.assertFalse(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From c670a2803897b06f86952a705ab93709bab1fb6f Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 22 May 2025 02:36:51 -0700 Subject: [PATCH 1304/1769] [Pallas/Mosaic GPU] Add a Pallas:MGPU implementation of `ragged_dot`. 
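
For reference, ragged dot contracts each contiguous group of `lhs` rows with
that group's matrix in `rhs`: for `lhs` of shape (m, k), `rhs` of shape
(g, k, n), and `group_sizes` of shape (g,) summing to m, the result has shape
(m, n). A minimal sketch of these semantics in plain JAX (illustrative only,
assuming concrete `group_sizes`; see also `jax.lax.ragged_dot`):

    import jax.numpy as jnp

    def ragged_dot_reference(lhs, rhs, group_sizes):
      out, start = [], 0
      for g in range(rhs.shape[0]):
        size = int(group_sizes[g])
        # Rows [start, start + size) of lhs all belong to group g.
        out.append(lhs[start:start + size] @ rhs[g])
        start += size
      return jnp.concatenate(out, axis=0)
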
PiperOrigin-RevId: 761875827 --- .../pallas/ops/gpu/ragged_dot_mgpu.py | 327 ++++++++++++++++++ tests/pallas/BUILD | 37 ++ tests/pallas/mgpu_ragged_dot_test.py | 114 ++++++ 3 files changed, 478 insertions(+) create mode 100644 jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py create mode 100644 tests/pallas/mgpu_ragged_dot_test.py diff --git a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py new file mode 100644 index 000000000000..6d295a36f435 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py @@ -0,0 +1,327 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ragged dot Pallas-Mosaic-GPU implementation.""" + +import dataclasses +import functools +import itertools +import math +import jax +from jax import lax +from jax import numpy as jnp +from jax import random +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.pallas import mosaic_gpu as plgpu +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class GroupInfo: + """Information regarding the group being processed in a block.""" + + group_id: jax.Array + block: jax.Array + block_start: jax.Array + actual_start: jax.Array + actual_end: jax.Array + start_within_block: jax.Array + actual_size: jax.Array + + @classmethod + def create(cls, group_lengths, tile, tid): + """Get the group info for the current block.""" + + tile = jnp.int32(tile) + group_boundaries = [group_lengths[i] for i in range(group_lengths.shape[0])] + + # We usually only have very few groups, so we unroll the loop processing + # them. Normally we'd break out of the loop early, once we'd have found our + # boundary, but we can't do that when unrolling, so we rely on many selects + # to mask out the epilogue of the loop. + group_end = group_start = block = group = end = jnp.array( + 0, dtype=jnp.int32 + ) + + for i, b in enumerate(group_boundaries): + # Start/end are inclusive + start = end + end = start + b + final = end - 1 + start_block = lax.div(start, tile) + final_block = lax.div(final, tile) + block_end = final_block + 1 + tid_begin = start_block + i + tid_end = block_end + i + # How many blocks after is our block? 
+ this_is_group = (tid_begin <= tid) & (tid < tid_end) + block = lax.select(this_is_group, tid - tid_begin + start_block, block) + group = lax.select(this_is_group, jnp.int32(i), group) + group_start = lax.select(this_is_group, start, group_start) + group_end = lax.select(this_is_group, end, group_end) + + block_start = block * tile + actual_start = jnp.maximum(group_start, block_start) + actual_end = jnp.minimum(group_end, block_start + tile) + start_within_block = actual_start - block_start + actual_size = actual_end - actual_start + return cls( + group_id=group, + block=block, + block_start=block_start, + actual_start=actual_start, + actual_end=actual_end, + start_within_block=start_within_block, + actual_size=actual_size, + ) + + +def _find_swizzle(dim_size_bits: int, what: str): + for swizzle_bytes in (128, 64, 32, 16): + if dim_size_bits % (swizzle_bytes * 8) == 0: + return swizzle_bytes + raise ValueError( + f"No valid out swizzle for {what}: its minor dimension has" + f" {dim_size_bits} bits, which is not a multiple of 128" + ) + + +def ragged_dot( + lhs, # (M, K) + rhs, # (G, K, N) + *, + group_sizes, # (G,) + block_m: int, + block_n: int, + block_k: int, + max_concurrent_steps: int, + grid_block_n: int, +) -> jax.Array: + if lhs.dtype != rhs.dtype: + raise NotImplementedError( + f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}" + ) + + elem_bits = jnp.finfo(lhs.dtype).bits + swizzle = _find_swizzle(elem_bits * block_k, "lhs") + swizzle_elems = swizzle * 8 // elem_bits + + m, k = lhs.shape + g, k2, n = rhs.shape + + if group_sizes.shape[0] != g: + raise ValueError( + f"Expected group_sizes to have shape {g} but got {group_sizes.shape}" + ) + + if k != k2: + raise ValueError(f"lhs.shape={k} must match rhs.shape={k2}") + + if k % block_k != 0: + raise ValueError(f"k={k} must be a multiple of block_k={block_k}") + + def body(rows_per_expert_gmem, lhs_gmem, rhs_gmem, o_gmem): + grid = ( + grid_block_n, + pl.cdiv(m, block_m) + g - 1, + pl.cdiv(n, grid_block_n * block_n), + ) + + @functools.partial( + plgpu.nd_loop, grid, init_val=None, collective_axes="sm" + ) + def mn_loop(idx, _): # pylint: disable=unused-variable + block_ni, mi, remainder_ni = idx + ni = block_ni * pl.cdiv(n, block_n * grid_block_n) + remainder_ni + group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi) + + def acc_scope(acc_ref): + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + plgpu.emit_pipeline( + lambda _, lhs_smem, rhs_smem: plgpu.wgmma(acc_ref, lhs_smem, rhs_smem), + grid=(k // block_k,), + in_specs=[ + plgpu.BlockSpec( + (block_m, block_k), + lambda k: (group_info.block, k), + transforms=transforms, + ), + plgpu.BlockSpec( + (block_k, block_n), lambda k: (k, ni), transforms=transforms + ), + ], + max_concurrent_steps=max_concurrent_steps, + delay_release=1, + )(lhs_gmem, rhs_gmem.at[group_info.group_id]) + return acc_ref[...] + + acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n))) + + store_transforms = ( + plgpu.TilingTransform((1, swizzle_elems)), + plgpu.SwizzleTransform(swizzle) + ) + @functools.partial( + pl.run_scoped, + o_smem=plgpu.SMEM( + (block_m, block_n), + dtype=o_gmem.dtype, + transforms=store_transforms, + ) + ) + def store_scope(o_smem): # pylint: disable=unused-variable + o_smem[...] 
= acc.astype(o_smem.dtype) + plgpu.commit_smem() + + smem_start = group_info.start_within_block + remaining_rows = min(block_m, m) + # TMA descriptors need to be generated with static tile sizes along each + # axis, but we do not know at compile time how many rows we will need to + # store. We only know that the number of rows to store is bounded by + # min(block_m, m). + # + # In order to work around that, we construct a logarithmic ladder of + # TMA descriptors, where each descriptor can store 2**i rows for some + # i between 0 and log2(min(block_m, m)). This allows storing any + # number of rows we will need to store, so long as this number of rows + # is between `1` and `min(block_m, m)`. + # + # E.g., imagine we have block_m = 8, m = 16. The loop below will be + # unrolled into 4 iterations, where the first one will generate a TMA + # descriptor that can store 8 rows, the second one will generate a TMA + # descriptor that can store 4 rows, etc. all the way to 1 row. + # + # At run time, we finally know the actual number of rows we need to + # store as we go through the unrolled loop iterations. Let's imagine + # that we need to store 5 rows. + # + # The first unrolled iteration will check whether we can store 8 rows. + # Since we only need to store 5 rows, we won't store anything then. + # + # The second unrolled iteration will check whether we can store 4 rows. + # We're able to store 4 rows, and are left with a single remaining row. + # + # The fourth unrolled iteration will store the single remaining row, and + # we end up with a storing scheme as follows for our 5 rows: + # + # ----------------------------------------------------------- + # 0 | | + # 1 | | + # 2 | Store 4 rows | + # 3 | | + # ----------------------------------------------------------- + # 4 | Store 1 row | + # ----------------------------------------------------------- + while remaining_rows > 0: + const_rows_len = 1 << int(math.log2(remaining_rows)) + remaining_rows //= 2 + + @pl.when(group_info.actual_size & const_rows_len != 0) + def _(): + o_smem_slice = o_smem.at[pl.ds(smem_start, const_rows_len)] + o_gref_slice = o_gmem.at[ + pl.ds(group_info.block_start + smem_start, const_rows_len), + pl.ds(ni * block_n, block_n), + ] + plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice) + + smem_start += group_info.actual_size & const_rows_len + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + # There are 132 SMs on a H100 SXM GPU. 
+ num_sms = 132 + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((m, n), lhs.dtype), + grid=(num_sms,), + grid_names=("sm",), + ) + return kernel(group_sizes, lhs, rhs) + + +def main(unused_argv): + m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16 + kx, ky, kz = random.split(random.key(1234), num=3) + + lhs = jax.random.normal(kx, (m, k), jnp.float16) + rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16) + group_boundaries = jax.lax.sort( + jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32) + ) + group_starts = lax.concatenate( + [jnp.array([0], dtype=jnp.int32), group_boundaries], 0 + ) + group_ends = lax.concatenate( + [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0 + ) + group_sizes = group_ends - group_starts + assert group_sizes.shape == (num_groups,) + + block_m = block_n = (64, 128, 192) + block_k = (64,) + max_concurrent_steps = (2, 4, 5, 6) + grid_block_n = (1, 2, 4, 8, 16) + configs = itertools.product( + block_m, block_n, block_k, max_concurrent_steps, grid_block_n + ) + names = ( + "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n" + ) + best_runtime = float("inf") + best_kwargs = {} + for config in configs: + kwargs = dict(zip(names, config)) + if n % (kwargs["grid_block_n"] * kwargs["block_n"]): + continue + try: + f = functools.partial(ragged_dot, group_sizes=group_sizes, **kwargs) + _, runtime = profiler.measure(f, mode="cupti")(lhs, rhs) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + runtime = float("inf") + # Enable this to get more detailed information. + else: + print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000)) + if runtime < best_runtime: # pytype: disable=unsupported-operands + best_runtime = runtime + best_kwargs = kwargs + if not best_kwargs: + raise ValueError("No valid configuration found") + + ref, ref_runtime = profiler.measure(jax.lax.ragged_dot)( + lhs, rhs, group_sizes=group_sizes + ) + result = ragged_dot(lhs, rhs, group_sizes=group_sizes, **best_kwargs) + np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3) + + tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12 + ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + print( + "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items()) + ) + print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") + print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index d7df261a1ca9..c45e52b1fe88 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -805,6 +805,43 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "mgpu_ragged_dot_run", + srcs = ["//jax/experimental/pallas/ops/gpu:ragged_dot_mgpu.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100_x32", + "gpu_h100", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "mgpu_ragged_dot_test", + srcs = ["mgpu_ragged_dot_test.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100", + ], + shard_count = 12, + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + 
jax_multiplatform_test(
    name = "fuser_block_spec_test",
    srcs = [
diff --git a/tests/pallas/mgpu_ragged_dot_test.py b/tests/pallas/mgpu_ragged_dot_test.py
new file mode 100644
index 000000000000..e9137df1298a
--- /dev/null
+++ b/tests/pallas/mgpu_ragged_dot_test.py
@@ -0,0 +1,114 @@
+# Copyright 2025 The JAX Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test different parameterizations of our Mosaic GPU ragged dot kernel."""
+
+import contextlib
+import os
+
+from absl.testing import absltest, parameterized  # pylint: disable=g-multiple-import
+from jax import random
+from jax._src import config
+from jax._src import test_util as jtu
+from jax._src.pallas import pallas_call
+import jax.numpy as jnp
+import numpy as np
+
+# pylint: disable=g-import-not-at-top
+try:
+  # We only import this to see if Mosaic is available.
+  import jax.experimental.mosaic.gpu  # noqa: F401
+except ImportError:
+  ragged_dot_mgpu = None
+else:
+  from jax.experimental.pallas.ops.gpu import ragged_dot_mgpu
+
+
+config.parse_flags_with_absl()
+
+
+@jtu.with_config(jax_traceback_filtering="off")
+class RaggedDotTestCase(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    if ragged_dot_mgpu is None:
+      self.skipTest("Mosaic GPU not available.")
+    if (not jtu.test_device_matches(["cuda"]) or
+        not jtu.is_cuda_compute_capability_equal("9.0")):
+      self.skipTest("Only works on GPU with capability sm90a")
+    context_stack = contextlib.ExitStack()
+    context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True))
+    self.addCleanup(context_stack.close)
+
+  @parameterized.product(
+      block_m=(64, 128, 192),
+      block_n=(64, 128, 192),
+      block_k=(64, 128),
+      grid_block_n=(2, 4),
+      max_concurrent_steps=(2, 4),
+      num_groups=(1, 3, 16),
+  )
+  def test_ragged_dot(
+      self,
+      block_m,
+      block_n,
+      block_k,
+      grid_block_n,
+      max_concurrent_steps,
+      num_groups,
+  ):
+    dtype = jnp.float16
+    lhs_smem_size = block_m * block_k * max_concurrent_steps * 2
+    rhs_smem_size = block_k * block_n * max_concurrent_steps * 2
+    # H100 SMEM limit is 228kB.
+ if lhs_smem_size + rhs_smem_size > 228_000: + self.skipTest("This configuration requires too much SMEM.") + + m, k, n = 16 * 1024, 2048, 16 * 1024 + kx, ky, kz = random.split(random.key(1234), num=3) + + lhs = jax.random.normal(kx, (m, k), dtype) + rhs = jax.random.normal(ky, (num_groups, k, n), dtype) + group_boundaries = jax.lax.sort( + jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32) + ) + group_starts = jax.lax.concatenate( + [jnp.array([0], dtype=jnp.int32), group_boundaries], 0 + ) + group_ends = jax.lax.concatenate( + [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0 + ) + group_sizes = group_ends - group_starts + assert group_sizes.shape == (num_groups,) + + out = ragged_dot_mgpu.ragged_dot( + lhs, + rhs, + group_sizes=group_sizes, + block_m=block_m, + block_n=block_n, + block_k=block_k, + max_concurrent_steps=max_concurrent_steps, + grid_block_n=grid_block_n, + ) + out_ref = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes) + np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0" + ) + absltest.main(testLoader=jtu.JaxTestLoader()) From 925e705186ccc5412213c42901099511915096a8 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 May 2025 03:14:05 -0700 Subject: [PATCH 1305/1769] [Pallas:MGPU] Make semaphores compatible with plgpu.kernel and the profiler The previous code was overly specific to pl.pallas_call and did not work with plgpu.kernel at all. Now, semaphores can be allocated using run_scoped, which also has the interesting side effect of the allocations being collective within each program. For a persistent kernel that means that a program/block can communicate with programs/blocks on other devices that have the same program ID, but it is currently impossible to e.g. synchronize all programs in a grid. We don't have a use case for it now, so we can add it later. PiperOrigin-RevId: 761885876 --- jax/_src/pallas/mosaic_gpu/lowering.py | 124 ++++++++++-------- .../mosaic_gpu/pallas_call_registration.py | 14 +- jax/experimental/mosaic/gpu/utils.py | 16 ++- tests/pallas/mosaic_gpu_test.py | 50 ++++++- 4 files changed, 138 insertions(+), 66 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e212f9770a94..2f81ecb969ac 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -105,6 +105,7 @@ class Resources: barrier_counts: collections.Counter[AnyBarrier] = dataclasses.field( default_factory=collections.Counter ) + gmem_semaphores: int = 0 def __post_init__(self): object.__setattr__( @@ -132,6 +133,7 @@ def __add__(self, other: Resources) -> Resources: smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols, barrier_counts=self.barrier_counts + other.barrier_counts, + gmem_semaphores=self.gmem_semaphores + other.gmem_semaphores, ) def __or__(self, other: Resources) -> Resources: @@ -143,6 +145,7 @@ def __or__(self, other: Resources) -> Resources: self.tmem_scratch_cols, other.tmem_scratch_cols ), barrier_counts=self.barrier_counts | other.barrier_counts, + gmem_semaphores=max(self.gmem_semaphores, other.gmem_semaphores), ) @@ -266,6 +269,8 @@ def _run_scoped_resource_estimator( elif aval.memory_space == gpu_core.REGS: # Don't need to allocate anything. 
pass + elif aval.memory_space == gpu_core.GMEM and jnp.issubdtype(aval.dtype, pallas_core.semaphore): + rs += Resources(gmem_semaphores=math.prod(aval.shape)) else: raise NotImplementedError( f"Unsupported memory space: {aval.memory_space}") @@ -312,6 +317,8 @@ class ModuleContext: tmem_requested_cols: int tmem_used_cols: int tmem_base_ptr: ir.Value + gmem_used_semaphores: int + gmem_semaphore_base_ptr: ir.Value | None runtime_barriers: MutableMapping[AnyBarrier, MutableSequence[AnyBarrierRef]] name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches @@ -351,6 +358,21 @@ def reserve_barrier( yield barrier available.append(barrier) + @contextlib.contextmanager + def reserve_semaphores( + self, shape: tuple[int, ...] + ): + allocated_sems = math.prod(shape) + ref = mgpu.memref_slice( + self.gmem_semaphore_base_ptr, + mgpu.ds(self.gmem_used_semaphores, allocated_sems), + ) + ref = mgpu.memref_reshape(ref, shape) + self.gmem_used_semaphores += allocated_sems + yield ref + # TODO: In debug mode verify the values of all semaphores are again 0 + self.gmem_used_semaphores -= allocated_sems + @contextlib.contextmanager def alloc_tmem( self, @@ -640,42 +662,15 @@ def ref_for_aval(aval: jax_core.AbstractValue): else: return gpu_core.SMEM(aval.shape, aval.dtype) - sem_placeholder = None - semaphore_ref_avals = [] - scratch_avals = [] - # Need to unzip semaphores - for v in jaxpr.invars[grid_mapping.slice_scratch_ops]: - aval = v.aval - if (isinstance(aval, pallas_core.AbstractMemoryRef) and - jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype)): - if aval.memory_space != gpu_core.GMEM: - raise ValueError( - "Only GMEM memory space is supported for semaphores in Mosaic GPU." - ) - semaphore_ref_avals.append(aval) - scratch_avals.append(sem_placeholder) - else: - scratch_avals.append(aval) - def pipeline_fn(*refs): - sem_refs = [] - if semaphore_ref_avals: - refs, sem_refs = util.split_list(refs, [-len(semaphore_ref_avals)]) primitives.run_scoped( - functools.partial(scoped_pipeline_fn, *refs, sem_refs=sem_refs), - scratch_refs=[ - ref_for_aval(aval) if aval is not sem_placeholder else aval - for aval in scratch_avals - ], + functools.partial(scoped_pipeline_fn, *refs), + scratch_refs=[ref_for_aval(v.aval) for v in jaxpr.invars[grid_mapping.slice_scratch_ops]], collective_axes=thread_axis, # scratch_refs are shared across threads ) return () # ``wrap_init`` does not support functions returning None. 
- def scoped_pipeline_fn(*refs, sem_refs, scratch_refs): - sem_refs_it = iter(sem_refs) - scratch_refs = [ - next(sem_refs_it) if r is sem_placeholder else r for r in scratch_refs - ] + def scoped_pipeline_fn(*refs, scratch_refs): def body_fn(indices, *refs): program_ids_template = util.merge_lists( which_parallel, indices, [None] * sum(which_parallel) @@ -708,7 +703,7 @@ def body_fn(indices, *refs): bm.array_shape_dtype.shape, bm.array_shape_dtype.dtype ).get_ref_aval() for bm in block_mappings - ] + semaphore_ref_avals, + ], ) assert not new_consts @@ -726,10 +721,6 @@ def body_fn(indices, *refs): gpu_mesh.cluster if gpu_mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], - [ - jax.ShapeDtypeStruct(r.shape, np.dtype(np.int32)) - for r in semaphore_ref_avals - ], new_jaxpr, params, new_consts, @@ -744,7 +735,6 @@ def lower_jaxpr_to_module( cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], out_shapes: Sequence[jax.ShapeDtypeStruct], - gmem_scratch_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, params: gpu_core.CompilerParams, consts=(), @@ -767,13 +757,31 @@ def lower_jaxpr_to_module( squashed_dims = grid[:-2] parallel_grid = (math.prod(grid[:-2]), *grid[-2:]) + rs = _estimate_resources( + ResourceEstimatorContext( + axis_names=axis_names, lowering_semantics=lowering_semantics + ), + jaxpr, + ) + def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): *buffers_gmem, (runtime_smem, runtime_barriers, runtime_tmem) = buffers - if gmem_scratch_shapes: - in_buffers, _, out_scratch_buffers = util.split_list( - buffers_gmem, [len(in_shapes), len(gmem_scratch_shapes)] + gmem_semaphores = None + if rs.gmem_semaphores: + # Extract the semaphores local to the current block. + index = ir.IndexType.get() + block_idx = arith_dialect.index_castui(index, mgpu_utils.block_idx()) + gmem_semaphores = mgpu.memref_slice( + buffers_gmem[-1], + mgpu.ds( + arith_dialect.muli( + block_idx, arith_dialect.constant(index, rs.gmem_semaphores) + ), + rs.gmem_semaphores, + ), ) - buffers_gmem = in_buffers + out_scratch_buffers + # The semaphore buffer is an aliased input/output, so we need to skip it twice. + buffers_gmem = buffers_gmem[:len(in_shapes)] + buffers_gmem[-len(out_shapes) - 1:-1] grouped_barriers = collections.defaultdict(list) for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): @@ -804,6 +812,8 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): tmem_requested_cols=tmem_cols, tmem_used_cols=0, tmem_base_ptr=runtime_tmem.address if runtime_tmem else None, + gmem_used_semaphores=0, + gmem_semaphore_base_ptr=gmem_semaphores, runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), @@ -817,13 +827,6 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): module_ctx, launch_ctx, jaxpr, buffers_gmem, consts ) - rs = _estimate_resources( - ResourceEstimatorContext( - axis_names=axis_names, lowering_semantics=lowering_semantics - ), - jaxpr, - ) - scratch_buffers = [ jax.ShapeDtypeStruct(shape=[rs.smem_scratch_bytes], dtype=np.int8), rs.barriers, @@ -842,14 +845,24 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): # Each range is 2 events, each event is 4 bytes. 
prof_spec = mgpu_profiler.ProfilerSpec(params.profile_space * 2 * 4) prof_ctx = ProfilerContext(params.profile_dir, prof_spec) + mgpu_grid = tuple(map(operator.mul, parallel_grid, cluster)) + semaphores_shape = () + if rs.gmem_semaphores: + semaphores_shape = ( + jax.ShapeDtypeStruct( + shape=(math.prod(mgpu_grid) * rs.gmem_semaphores,), dtype=np.int32 + ), + ) + # NOTE: new_out_shapes has out_shapes, then semaphores_shape and + # optionally the profiler buffer. module, new_out_shapes, _, launch_ctx = ( mgpu_core._lower_as_gpu_kernel( body, - grid=tuple(map(operator.mul, parallel_grid, cluster)), + grid=mgpu_grid, cluster=cluster, block=block, - in_shapes=(*in_shapes, *gmem_scratch_shapes), - out_shape=(*out_shapes, *gmem_scratch_shapes), + in_shapes=(*in_shapes, *semaphores_shape), + out_shape=(*out_shapes, *semaphores_shape), smem_scratch_shape=scratch_buffers, lowering_semantics=lowering_semantics, module_name=mlir.sanitize_name(debug_info.func_name), @@ -871,11 +884,8 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): launch_ctx.scratch.finalize_size() - if gmem_scratch_shapes: - new_out_shapes = new_out_shapes[:-len(gmem_scratch_shapes)] - return LoweringResult( - module, parallel_grid, block, new_out_shapes, prof_ctx, tuple(gmem_scratch_shapes) + module, parallel_grid, block, new_out_shapes, prof_ctx, semaphores_shape ) @@ -2283,6 +2293,12 @@ def _run_scoped_lowering_rule( ) input_refs.append(input_ref) should_discharge.append(False) + elif aval.memory_space == gpu_core.GMEM and jnp.issubdtype(aval.dtype, pallas_core.semaphore): + input_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_semaphores(aval.shape) + ) + input_refs.append(input_ref) + should_discharge.append(False) else: raise ValueError(f"Can't convert to ref: {aval}") diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index ef1ba37f0f5c..a14ccbb7daa9 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -85,12 +85,16 @@ def pallas_call_lowering( new_avals_out = list(map(_as_shaped_array, lowering_result.new_out_shapes)) scratch_args = () if lowering_result.gmem_scratch_shapes: + # The new_out_shapes contain the original outputs first, followed by the + # GMEM scratch shapes, and optionally the profiler buffer. input_output_aliases += tuple( - (len(new_avals_in) + i, len(new_avals_out) + i) + (len(ctx.avals_in) + i, len(ctx.avals_out) + i) for i in range(len(lowering_result.gmem_scratch_shapes)) ) + # The GMEM scratch is an aliased kernel input/output. new_avals_in.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) - new_avals_out.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + # We guarantee zero-initialization of the GMEM scratch at the moment, which + # is important for semaphores. def zero_init_gmem_scratch(): return [lax.zeros_like_array(s) for s in lowering_result.gmem_scratch_shapes] scratch_args = mlir.lower_fun( @@ -100,12 +104,10 @@ def zero_init_gmem_scratch(): ctx.replace(avals_in=new_avals_in, avals_out=new_avals_out), *args, *scratch_args, module=module, - out_types=(*lowering_result.new_out_shapes, *lowering_result.gmem_scratch_shapes), + out_types=lowering_result.new_out_shapes, input_output_aliases=input_output_aliases, use_custom_barrier=False, # False until we add get_barrier_semaphore() feature ) - if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. 
- outs = outs[:-len(lowering_result.gmem_scratch_shapes)] if (prof_ctx := lowering_result.profiler_context) is not None: *outs, prof_buffer = outs if (dump_path := prof_ctx.dump_path) == "sponge": @@ -133,6 +135,8 @@ def do_callback(prof_buffer): mlir.lower_fun(do_callback, multiple_results=True)( ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer ) + if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. + outs = outs[:-len(lowering_result.gmem_scratch_shapes)] return outs diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 9eedc3402579..1e20675f7909 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -226,15 +226,19 @@ def when(cond): scf.yield_([]) -def thread_idx(): +def _3d_to_1d_idx(dim_idx_fn, dim_size_fn): i32 = ir.IntegerType.get_signless(32) as_i32 = lambda x: arith.index_cast(i32, x) - tidx = as_i32(gpu.thread_id(gpu.Dimension.x)) - stride = as_i32(gpu.block_dim(gpu.Dimension.x)) + idx = as_i32(dim_idx_fn(gpu.Dimension.x)) + stride = as_i32(dim_size_fn(gpu.Dimension.x)) for dim in (gpu.Dimension.y, gpu.Dimension.z): - tidx = arith.addi(tidx, arith.muli(as_i32(gpu.thread_id(dim)), stride)) - stride = arith.muli(stride, as_i32(gpu.block_dim(dim))) - return tidx + idx = arith.addi(idx, arith.muli(as_i32(dim_idx_fn(dim)), stride)) + stride = arith.muli(stride, as_i32(dim_size_fn(dim))) + return idx + + +thread_idx = functools.partial(_3d_to_1d_idx, gpu.thread_id, gpu.block_dim) +block_idx = functools.partial(_3d_to_1d_idx, gpu.block_id, gpu.grid_dim) def _warp_bcast(val, lane_idx=0): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index aedd79e23194..7173639b879f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3394,7 +3394,10 @@ def compute(_, l_smem, r_smem, o_smem): np.testing.assert_allclose(kernel(x, x), x + x) - def test_semaphore_lowering(self): + +class SemaphoreTest(PallasTest): + + def test_lowering(self): # This is a smoke test until we add support for lowering of semaphore ops. def body(i_ref1, i_ref2, o_ref, sem_ref): del i_ref2 # Only here to have a different number of inputs and outputs. @@ -3420,6 +3423,51 @@ def body(i_ref1, i_ref2, o_ref, sem_ref): text, ) + def test_basic(self): + def body(o_ref, sem_ref): + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + pl.semaphore_signal(sem_ref) + o_ref[...] = jnp.ones_like(o_ref) + pl.semaphore_wait(sem_ref) + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + grid=(2,), + grid_names=("x",), + ) + text = jax.jit(kernel).lower().as_text() + np.testing.assert_array_equal(kernel(), jnp.ones((128,), jnp.float32)) + # The semaphore array is scaled up by the grid size. + self.assertIn( + r"(tensor<128xf32>, tensor<2xi32>) -> (tensor<128xf32>, tensor<2xi32>)", + text, + ) + + def test_with_profiler(self): + # Dealing with profiler and semaphores together is tricky because they both + # add extra outputs to the HLO op. + def body(o_ref, sem_ref): + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + with jax.named_scope("output"): + o_ref[...] 
= jnp.ones_like(o_ref) + with tempfile.TemporaryDirectory() as tmp_dir: + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + grid=(2,), + grid_names=("x",), + compiler_params=plgpu.CompilerParams(profile_space=32, profile_dir=tmp_dir), + ) + text = jax.jit(kernel).lower().as_text() + np.testing.assert_array_equal(kernel(), jnp.ones((128,), jnp.float32)) + self.assertIn( + r"(tensor<128xf32>, tensor<2xi32>) ->" + r" (tensor<128xf32>, tensor<2xi32>, tensor<512xui32>)", + text, + ) + class ExamplesWGTest( ExamplesTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From 2520f6b22b7d6ab5a0594d4132daf9348a416426 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 21 May 2025 15:43:19 +0000 Subject: [PATCH 1306/1769] Add extra cuBLAS/cuDNN version checks. --- jax_plugins/cuda/__init__.py | 38 +++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 02bcbcf16dbc..f1e3c55811dc 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -114,7 +114,7 @@ def _version_check(name: str, get_version, get_build_version, scale_for_comparison: int = 1, - min_supported_version: int = 0): + min_supported_version: int = 0) -> int | None: """Checks the runtime CUDA component version against the JAX one. Args: @@ -125,6 +125,8 @@ def _version_check(name: str, min_supported_version: An absolute minimum version required. Must be passed without rounding down. + Returns: the runtime version, or None if the component is not found. + Raises: RuntimeError: If the component is not found, or is of unsupported version, and if raising the error is not deferred till later. @@ -162,12 +164,13 @@ def _version_check(name: str, "version": version, "minimum_supported": min_supported_version} results.append(record) + return version _version_check("CUDA", cuda_versions.cuda_runtime_get_version, cuda_versions.cuda_runtime_build_version, scale_for_comparison=10, min_supported_version=12010) - _version_check( + cudnn_version = _version_check( "cuDNN", cuda_versions.cudnn_get_version, cuda_versions.cudnn_build_version, @@ -191,7 +194,7 @@ def _version_check(name: str, _version_check("cuPTI", cuda_versions.cupti_get_version, cuda_versions.cupti_build_version, min_supported_version=18) - _version_check("cuBLAS", cuda_versions.cublas_get_version, + cublas_version = _version_check("cuBLAS", cuda_versions.cublas_get_version, cuda_versions.cublas_build_version, # Ignore patch versions. scale_for_comparison=100, @@ -202,6 +205,35 @@ def _version_check(name: str, scale_for_comparison=100, min_supported_version=12100) + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-1 + if (cudnn_version is not None and cudnn_version == 91000 + and cuda_versions.cudnn_build_version() != 91000): + msg = ("cuDNN 9.10.0 had a binary backward-compatibility issue due to reordered enum " + f"values affecting block-scale datatypes. Found runtime version {cudnn_version} " + f"and build version {cuda_versions.cudnn_build_version()}. 
Please upgrade to " + "9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + # xb.local_device_count() cannot safely be called at this point + if xb.CUDA_VISIBLE_DEVICES.value == "all": + local_device_count = cuda_versions.cuda_device_count() + else: + local_device_count = len(xb.CUDA_VISIBLE_DEVICES.value.split(",")) + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-0 + if (cudnn_version is not None and cudnn_version < 91001 + and cublas_version is not None and cublas_version >= 120900 + and local_device_count > 1): + msg = (f"cuDNN < 9.10.0 ({cudnn_version} found) had an issue that caused some multi-GPU " + "matmuls, in which the same finalized execution plan is used across different " + f"GPUs, to be functionally incorrect when run with cublasLt >= 12.9 ({cublas_version} " + "found). Please upgrade to 9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + errors = [] debug_results = [] for result in results: From 223220167324ddc4957de78eb2c56870bd714ee5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 22 May 2025 05:00:23 -0700 Subject: [PATCH 1307/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f040466fc5fa5052b1f4e89fb2b166bb2c9d656a. PiperOrigin-RevId: 761914597 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 521e07bb6d70..dda1d3c36cf8 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "c74abfc3ecfdc31901acca65efa29ffda3ed84cc" -XLA_SHA256 = "e367e84d64730cfb94c58dd75477239183dbe74a73ea57bde3c87abcfcaffb3a" +XLA_COMMIT = "f040466fc5fa5052b1f4e89fb2b166bb2c9d656a" +XLA_SHA256 = "52ecb43ba06e2d5e8cf84fad5c43104c8f5be92de4bb5e5c8a936e6351093c6e" def repo(): tf_http_archive( From 7014bde5a57931054ea623ce5a0d3dc85090ce4e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 May 2025 06:36:00 -0700 Subject: [PATCH 1308/1769] [Pallas:MGPU] Add a first prototype of an all_gather collective matmul kernel It's not very optimized at the moment and is unlikely to outperform the baseline of raw all_gather + matmul, but it computes the right numbers. We are already aware of a few places that could be optimized and we'll start rolling them out soon. PiperOrigin-RevId: 761939624 --- jax/_src/pallas/core.py | 2 +- .../pallas/ops/gpu/collective_matmul_mgpu.py | 178 ++++++++++++++++++ tests/pallas/BUILD | 29 +++ tests/pallas/mgpu_collective_matmul_test.py | 134 +++++++++++++ 4 files changed, 342 insertions(+), 1 deletion(-) create mode 100644 jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py create mode 100644 tests/pallas/mgpu_collective_matmul_test.py diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index fe755d61a310..13c634eb395f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -971,7 +971,7 @@ def _convert_block_spec_to_block_mapping( class ScratchShape(Protocol): def get_array_aval(self) -> jax_core.AbstractValue: ... - def get_ref_aval(self) -> state.AbstractRef: + def get_ref_aval(self) -> state.AbstractRef | TransformedRef: ... 
diff --git a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py new file mode 100644 index 000000000000..a6c372f2cee7 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py @@ -0,0 +1,178 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collective matmul kernel implemented using Mosaic GPU.""" + +import functools +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import mosaic_gpu as plgpu +import jax.numpy as jnp + + +def _find_swizzle(dim_size_bits: int, what: str): + for swizzle_bytes in (128, 64, 32, 16): + if dim_size_bits % (swizzle_bytes * 8) == 0: + return swizzle_bytes + raise ValueError( + f"No valid out swizzle for {what}: its minor dimension has" + f" {dim_size_bits} bits, which is not a multiple of 128" + ) + + +# TODO(apaszke): Add grid tiling +def all_gather_lhs_matmul( + lhs: jax.Array, + rhs: jax.Array, + axis_name, + *, + block_m: int, + block_n: int, + block_k: int, + max_concurrent_steps: int, +) -> jax.Array: + if (num_devices := jax.device_count()) != jax.process_count(): + raise ValueError("The kernel only supports one device per process") + if (axis_size := lax.axis_size(axis_name)) != num_devices: + raise ValueError("The kernel can only work over all devices in a Mesh.") + if max_concurrent_steps < 2: + raise ValueError("max_concurrent_steps must be >= 2") + + num_sms = 132 # There are 132 SMs on a H100 SXM GPU. + + m_shard, k = lhs.shape + k2, n_shard = rhs.shape + if k != k2: + raise ValueError( + f"lhs and rhs must have the same contraction size, got {k} and {k2}." + ) + if (element_type := lhs.dtype) != rhs.dtype: + raise ValueError( + f"lhs and rhs must have the same element type, got {element_type} and" + f" {rhs.dtype}." 
+ ) + if k % block_k != 0: + raise NotImplementedError(f"k={k} must be a multiple of block_k={block_k}") + if m_shard % block_m != 0: + raise NotImplementedError(f"m_shard={m_shard} must be a multiple of block_m={block_m}") + if n_shard % block_n != 0: + raise NotImplementedError(f"n_shard={n_shard} must be a multiple of block_n={block_n}") + if n_shard != block_n: + raise NotImplementedError( + f"n_shard={n_shard} must be equal to block_n={block_n}" + ) + + swizzle = min( + _find_swizzle(block_k * jnp.finfo(element_type).bits, "lhs"), + _find_swizzle(block_n * jnp.finfo(element_type).bits, "rhs"), + ) + transforms = ( + plgpu.TilingTransform((8, swizzle // jnp.dtype(element_type).itemsize)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, capacity_sem, received_sem): + sm_id = lax.axis_index('sm') + scratch_ref = scratch_ref.at[sm_id] + + dev_id = lax.axis_index(axis_name) + send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size) + recv_dev_id = lax.rem(dev_id + 1, axis_size) + # NOTE: Technically we should signal the recv_dev_id (and our signal would + # be received from send_dev_id), but if everyone signals in a ring after a + # barrier then it's equivalent to a local signal. + pl.semaphore_signal(capacity_sem) + send_scratch_ref = plgpu.remote_ref( + scratch_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL + ) + + def m_loop(mi, _): + mi = mi * lax.axis_size('sm') + sm_id + m_tile_slice = pl.ds(mi * block_m, block_m) + + # For some reason ptxas spills if we unroll the loop over k + copy_block = 32 + def k_copy_loop(ki, _): + k_slice = pl.ds(ki * copy_block, copy_block) + scratch_ref[0, :, k_slice] = lhs_ref[m_tile_slice, k_slice] + jax.lax.fori_loop(0, k // copy_block, k_copy_loop, None) + + def device_loop(device_offset, _): + # Loop invariant: scratch_ref.at[scratch_slot] is ready to be used + # We're double buffering the scratch space. At each step, we read from + # scratch_ref.at[scratch_slot] and write to scratch_ref.at[next_scratch_slot] + # located on the send_dev_id. We swap the slots after completing a step, + # which lets us overlap the copy with compute. + scratch_slot = lax.rem(device_offset, 2) + next_scratch_slot = 1 - scratch_slot + + @functools.partial( + pl.run_scoped, + acc_ref=plgpu.ACC((block_m, block_n)), + out_smem=plgpu.SMEM((block_m, block_n), jnp.float16, transforms=transforms), + ) + def _(acc_ref, out_smem): + pl.semaphore_wait(capacity_sem) + @functools.partial( + plgpu.emit_pipeline, + grid=(k // block_k,), + in_specs=[ + plgpu.BlockSpec((block_m, block_k), lambda k: (0, k), transforms=transforms), + plgpu.BlockSpec((block_k, block_n), lambda k: (k, 0), transforms=transforms), + ], + max_concurrent_steps=max_concurrent_steps, + delay_release=1, + ) + def k_loop(idxs, lhs_smem, rhs_smem): + (ki,) = idxs + plgpu.wgmma(acc_ref, lhs_smem, rhs_smem) + k_slice = pl.ds(ki * block_k, block_k) + # TODO(apaszke): No need to send on the last step + # TODO(apaszke): Use an async copy. This is uncoalesced. + send_scratch_ref[next_scratch_slot, :, k_slice] = lhs_smem[...] + k_loop(scratch_ref.at[scratch_slot], rhs_ref) + # TODO(apaszke): Both of those semaphores perform a .sys release. + # This is very expensive and we should only do a single .sys fence. + pl.semaphore_signal(capacity_sem, device_id=recv_dev_id, device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_signal(received_sem, device_id=send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL) + # Make sure all TMAs have read SMEM before we overwrite it. 
+ plgpu.wait_smem_to_gmem(0, wait_read_only=True) + out_smem[...] = acc_ref[...].astype(out_smem.dtype) + plgpu.commit_smem() + device_m_slice = pl.ds( + lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m + ) + plgpu.copy_smem_to_gmem( + out_smem, out_ref.at[device_m_slice].at[m_tile_slice] + ) + # Wait for the next scratch to arrive --- see the loop invariant. + pl.semaphore_wait(received_sem) + jax.lax.fori_loop(0, num_devices, device_loop, None) + grid_size = m_shard // block_m + m_steps = grid_size // num_sms + jnp.int32(sm_id < grid_size % num_sms) + # TODO(apaszke): Use the ND-loop helper. + jax.lax.fori_loop(0, m_steps, m_loop, None) + + result, _ = plgpu.kernel( + kernel_body, + out_shape=[jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), jnp.float16), + jax.ShapeDtypeStruct((num_sms, 2, block_m, k), jnp.float16)], + scratch_shapes=[ + plgpu.SemaphoreType.REGULAR, plgpu.SemaphoreType.REGULAR, + ], + grid=(num_sms,), + grid_names=('sm',), + )(lhs, rhs) + return result diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index c45e52b1fe88..cf0d46639559 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -842,6 +842,35 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "mgpu_collective_matmul_test", + srcs = ["mgpu_collective_matmul_test.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + ], + env = { + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + }, + shard_count = 4, + tags = [ + "manual", + "multiaccelerator", + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + "//jax:test_multiprocess", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "fuser_block_spec_test", srcs = [ diff --git a/tests/pallas/mgpu_collective_matmul_test.py b/tests/pallas/mgpu_collective_matmul_test.py new file mode 100644 index 000000000000..bbc50d39d7f6 --- /dev/null +++ b/tests/pallas/mgpu_collective_matmul_test.py @@ -0,0 +1,134 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test different parameterizations of our Mosaic GPU collective matmul.""" + +import contextlib +import functools +import os + +from absl.testing import parameterized # pylint: disable=g-multiple-import +import jax +from jax import lax +from jax import random +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental.mosaic import gpu as mgpu +from jax.experimental.pallas.ops.gpu import collective_matmul_mgpu +import jax.numpy as jnp +import numpy as np + + +P = jax.sharding.PartitionSpec + + +@jtu.with_config(jax_traceback_filtering="off") +class CollectiveMatmulTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if collective_matmul_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") + if not mgpu.supports_cross_device_collectives(): + self.skipTest("NVSHMEM library unavailable.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) + + @parameterized.product( + m_shard=(1024, 8192), + n_shard=(64, 128, 192), + k=(256, 8192), + block_m=(64, 128, 192), + block_n=(64, 128, 192), + block_k=(64, 128), + max_concurrent_steps=(2, 4), + ) + def test_all_gather_lhs_matmul( + self, + m_shard, + n_shard, + k, + block_m, + block_n, + block_k, + max_concurrent_steps, + ): + num_devices = jax.device_count() + dtype = jnp.float16 + lhs_smem_size = block_m * block_k * max_concurrent_steps * 2 + rhs_smem_size = block_k * block_n * max_concurrent_steps * 2 + # H100 SMEM limit is 228kB. + if lhs_smem_size + rhs_smem_size > 228_000: + self.skipTest("This configuration requires too much SMEM.") + if n_shard != block_n: + self.skipTest("n_shard must be equal to block_n for now.") + if n_shard % block_n: + self.skipTest("n_shard must be divisble by block_n for now.") + if m_shard % block_m: + self.skipTest("m_shard must be divisible by block_m for now.") + + k1, k2 = random.split(random.key(1234), num=2) + lhs = random.normal(k1, (num_devices * m_shard, k), dtype) + rhs = random.normal(k2, (k, num_devices * n_shard), dtype) + + mesh = jax.sharding.Mesh(jax.devices(), ["x"]) + lhs = jax.device_put(lhs, jax.sharding.NamedSharding(mesh, P("x", None))) + rhs = jax.device_put(rhs, jax.sharding.NamedSharding(mesh, P(None, "x"))) + + def run(body): + out = jax.jit( + jax.shard_map( + body, + mesh=mesh, + in_specs=(P("x", None), P(None, "x")), + out_specs=P(None, "x"), + check_vma=False, + ) + )(lhs, rhs) + # Gather output, for NumPy comparison on the host. 
+ out = jax.shard_map( + lambda x: lax.all_gather(x, "x", axis=1, tiled=True), + mesh=mesh, + in_specs=P(None, "x"), + out_specs=P(None), + check_vma=False, + )(out) + return out + + out = run( + functools.partial( + collective_matmul_mgpu.all_gather_lhs_matmul, + axis_name="x", + block_m=block_m, + block_n=block_n, + block_k=block_k, + max_concurrent_steps=max_concurrent_steps, + ) + ) + ref_out = run(lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y) + np.testing.assert_allclose(out, ref_out) + + +if __name__ == "__main__": + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0" + ) + jt_multiprocess.main() From 4c83b08980f33f208d55667135e9fbe2cc3e8938 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 May 2025 07:12:01 -0700 Subject: [PATCH 1309/1769] [Pallas:MGPU] Disable the GPU distributed test when the platform allocator is used And, if we're only running this test, try to override the flag. PiperOrigin-RevId: 761950583 --- .github/workflows/bazel_optional_h100_b200.yml | 2 +- jax/experimental/mosaic/gpu/core.py | 7 +++++++ tests/pallas/gpu_pallas_distributed_test.py | 9 +++++++++ tests/pallas/mgpu_collective_matmul_test.py | 7 +++++++ 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/.github/workflows/bazel_optional_h100_b200.yml b/.github/workflows/bazel_optional_h100_b200.yml index 0c73b238505e..7381ce6d80bf 100644 --- a/.github/workflows/bazel_optional_h100_b200.yml +++ b/.github/workflows/bazel_optional_h100_b200.yml @@ -48,6 +48,7 @@ jobs: --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ --test_output=errors \ + --test_tag_filters=-multiaccelerator \ --test_env=JAX_ACCELERATOR_COUNT=1 \ --test_env=JAX_TESTS_PER_ACCELERATOR=8 \ --strategy=TestRunner=local \ @@ -102,7 +103,6 @@ jobs: //tests/pallas:gpu_tests \ //tests:array_interoperability_test_gpu \ //tests:cudnn_fusion_test_gpu \ - //tests:fused_attention_stablehlo_test_gpu //tests:fused_attention_stablehlo_test_gpu \ //tests:gpu_tests \ //tests:python_callback_test_gpu \ diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index c20c5252a27f..48a877f8c67a 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -102,8 +102,15 @@ def supports_cross_device_collectives(): try: nvshmem_bc_path = os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] + nvshmem_so_path = os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] except KeyError: return False + try: + # This both ensures that the file exists, and it populates the dlopen cache + # helping XLA find the library even if the RPATH is not exactly right... 
+    ctypes.CDLL(nvshmem_so_path)
+  except OSError:
+    return False
   xla_flags = os.environ.get("XLA_FLAGS", "")
   return (
       os.path.exists(nvshmem_bc_path)
diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py
index d862e6b9b819..3aeee352ff6d 100644
--- a/tests/pallas/gpu_pallas_distributed_test.py
+++ b/tests/pallas/gpu_pallas_distributed_test.py
@@ -15,6 +15,8 @@
 """Tests for distributed pallas GPU operations."""
 
 import functools
+import os
+
 import jax
 from jax import lax
 from jax._src import test_util as jtu
@@ -41,6 +43,8 @@ def setUp(self):
       self.skipTest("NVSHMEM library unavailable.")
     if jax.process_count() == 1:
       self.skipTest("Test requires multiple processes.")
+    if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform":
+      self.skipTest("NVSHMEM doesn't work with the platform allocator.")
     super().setUp()
 
   def test_basic_remote_dma(self):
@@ -114,4 +118,9 @@ def kernel(y_ref, sem):
 
 
 if __name__ == '__main__':
+  # This test doesn't work with the platform allocator, so we override it
+  # if it's run alone. If it's part of a larger test suite and the platform
+  # allocator is used, setUp will skip the test.
+  os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.01'
+  os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'default'
   jt_multiprocess.main()
diff --git a/tests/pallas/mgpu_collective_matmul_test.py b/tests/pallas/mgpu_collective_matmul_test.py
index bbc50d39d7f6..e0ced79801d8 100644
--- a/tests/pallas/mgpu_collective_matmul_test.py
+++ b/tests/pallas/mgpu_collective_matmul_test.py
@@ -48,6 +48,8 @@ def setUp(self):
       self.skipTest("NVSHMEM library unavailable.")
     if jax.process_count() == 1:
       self.skipTest("Test requires multiple processes.")
+    if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform":
+      self.skipTest("NVSHMEM doesn't work with the platform allocator.")
     context_stack = contextlib.ExitStack()
     context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True))
     self.addCleanup(context_stack.close)
@@ -128,6 +130,11 @@ def run(body):
 
 
 if __name__ == "__main__":
+  # This test doesn't work with the platform allocator, so we override it
+  # if it's run alone. If it's part of a larger test suite and the platform
+  # allocator is used, setUp will skip the test.
+  os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.01"
+  os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "default"
   os.environ["XLA_FLAGS"] = (
       os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"
   )

From be0ed4ac78f720567f947960852935b9494e590a Mon Sep 17 00:00:00 2001
From: Jacob Burnim
Date: Thu, 22 May 2025 07:20:20 -0700
Subject: [PATCH 1310/1769] Prevent tests of TPU interpret mode from being run in parallel.
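
TPU interpret mode keeps its simulated device state (e.g. the race-detection
results read through `mosaic_interpret.races.races_found`) in process-global
structures, so test cases executing concurrently in the same process could
interfere with each other. A minimal sketch of the mechanism used below,
assuming `jtu` is JAX's internal `test_util` module:

    @jtu.thread_unsafe_test_class()  # runs this class's tests serially
    class InterpretTest(jtu.JaxTestCase):
      ...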
---
 tests/pallas/tpu_pallas_interpret_distributed_test.py | 1 +
 tests/pallas/tpu_pallas_interpret_test.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py
index bd85ded66a73..a029b8094aa1 100644
--- a/tests/pallas/tpu_pallas_interpret_distributed_test.py
+++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py
@@ -38,6 +38,7 @@
 P = jax.sharding.PartitionSpec
 
 
+@jtu.thread_unsafe_test_class()
 class InterpretDistributedTest(jtu.JaxTestCase):
   def setUp(self):
     super().setUp()
diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py
index 28c63dc3bd9b..9d6188cbc0cc 100644
--- a/tests/pallas/tpu_pallas_interpret_test.py
+++ b/tests/pallas/tpu_pallas_interpret_test.py
@@ -82,6 +82,7 @@ def grid_points(self):
     return self._grid_points
 
 
+@jtu.thread_unsafe_test_class()
 class InterpretTest(jtu.JaxTestCase):
 
   def setUp(self):
From 125f8170331886437fe829db46e0439a71e45ad4 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Thu, 22 May 2025 07:57:25 -0700
Subject: [PATCH 1311/1769] Bump the oldest libtpu version again

The previous bump was wrong because there was apparently a 3-day window in
which we were not publishing nightlies. This time, I checked that a package
was actually released on that day.

PiperOrigin-RevId: 761963834
---
 .github/workflows/cloud-tpu-ci-nightly.yml | 2 +-
 .github/workflows/pytest_tpu.yml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml
index 061b399132e2..c7394a498dd6 100644
--- a/.github/workflows/cloud-tpu-ci-nightly.yml
+++ b/.github/workflows/cloud-tpu-ci-nightly.yml
@@ -44,7 +44,7 @@ jobs:
           jaxlib-version: "pypi_latest"
     name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
     env:
-      LIBTPU_OLDEST_VERSION_DATE: 20250226
+      LIBTPU_OLDEST_VERSION_DATE: 20250228
       PYTHON: python${{ matrix.python-version }}
     runs-on: ${{ matrix.tpu.runner }}
     container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml
index 8ecccbe274e9..d1af90283001 100644
--- a/.github/workflows/pytest_tpu.yml
+++ b/.github/workflows/pytest_tpu.yml
@@ -82,7 +82,7 @@ jobs:
 # End Presubmit Naming Check github-tpu-presubmits
     env:
-      LIBTPU_OLDEST_VERSION_DATE: 20250226
+      LIBTPU_OLDEST_VERSION_DATE: 20250228
       JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
       JAXCI_PYTHON: "python${{ inputs.python }}"
       JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
From 2302a2e179c9fb184414e36418a7f303ad5ff37d Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 22 May 2025 08:25:57 -0700
Subject: [PATCH 1312/1769] Fix `test_make_array_from_single_device_arrays` in
 random_test.py for TPU 7x.

This is because `create_mesh` returns a different order of devices while
`arrays` are being created in `jax.devices()` order.
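
For illustration, a sketch of the ordering constraint being fixed
(`iota_order=True` makes the test mesh follow `jax.devices()` order, matching
how the per-device arrays are built):

    devices = jax.devices()
    arrays = [jax.device_put(keys[i:i + 1], d) for i, d in enumerate(devices)]
    # The sharding's device order must line up with `arrays`:
    mesh = jtu.create_mesh((len(devices),), ('x',), iota_order=True)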
PiperOrigin-RevId: 761975228 --- tests/random_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/random_test.py b/tests/random_test.py index d75f3a9c5e2e..86a1622240dc 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -991,7 +991,7 @@ def callback(index): def test_make_array_from_single_device_arrays(self): devices = jax.devices() shape = (len(devices),) - mesh = jtu.create_mesh((len(devices),), ('x',)) + mesh = jtu.create_mesh((len(devices),), ('x',), iota_order=True) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) keys = random.split(random.key(0), len(devices)) arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)] From 3a667ce9169f2d9d6a7f94cdb34b1910da0c4bb3 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 22 May 2025 08:37:40 -0700 Subject: [PATCH 1313/1769] [pallas:mosaic] Enabled `scan` lowering rule for all kernel types PiperOrigin-RevId: 761979336 --- jax/_src/pallas/mosaic/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index dd0b9ba4b4a7..d1222c16ac96 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2996,7 +2996,7 @@ def _run_body(i, args): return for_op.results -@register_lowering_rule(lax.scan_p, ensure_mlir_values=False) +@register_lowering_rule(lax.scan_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, From 44d2cc927989b8a6a4681ba49c49f6ac8efba910 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 May 2025 08:54:00 -0700 Subject: [PATCH 1314/1769] [Pallas:MGPU] (Slightly) Clean up the collective matmul test by using explicit sharding Not a huge difference, but it's a bit nicer. 
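For context, the explicit-sharding idiom the test switches to looks roughly
like this (a sketch built from the APIs used in the diff below; the shapes
and the zeros input are illustrative assumptions):

    import jax
    from jax.experimental import shard
    from jax.sharding import AxisType, PartitionSpec as P

    # With an Explicit mesh axis installed via use_mesh, shardings become
    # part of array types, so inputs can be resharded directly instead of
    # being placed with device_put and threaded through shard_map in_specs.
    mesh = jax.make_mesh((jax.device_count(),), ("x",),
                         axis_types=(AxisType.Explicit,))
    with jax.sharding.use_mesh(mesh):
      x = shard.reshard(jax.numpy.zeros((8, 8)), P("x", None))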
PiperOrigin-RevId: 761984854 --- tests/pallas/BUILD | 1 + tests/pallas/mgpu_collective_matmul_test.py | 27 +++++++++------------ 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index cf0d46639559..e4a308c2f10b 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -864,6 +864,7 @@ jax_multiplatform_test( "notap", ], deps = [ + "//jax:experimental", "//jax:pallas", "//jax:pallas_experimental_gpu_ops", "//jax:pallas_mosaic_gpu", diff --git a/tests/pallas/mgpu_collective_matmul_test.py b/tests/pallas/mgpu_collective_matmul_test.py index e0ced79801d8..386162b1992c 100644 --- a/tests/pallas/mgpu_collective_matmul_test.py +++ b/tests/pallas/mgpu_collective_matmul_test.py @@ -27,6 +27,7 @@ from jax._src.pallas import pallas_call from jax.experimental.mosaic import gpu as mgpu from jax.experimental.pallas.ops.gpu import collective_matmul_mgpu +from jax.experimental import shard import jax.numpy as jnp import numpy as np @@ -51,8 +52,13 @@ def setUp(self): if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": self.skipTest("NVSHMEM doesn't work with the platform allocator.") context_stack = contextlib.ExitStack() - context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) self.addCleanup(context_stack.close) + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + num_devices = jax.device_count() + mesh = jax.make_mesh( + (num_devices,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + context_stack.enter_context(jax.sharding.use_mesh(mesh)) @parameterized.product( m_shard=(1024, 8192), @@ -90,28 +96,17 @@ def test_all_gather_lhs_matmul( k1, k2 = random.split(random.key(1234), num=2) lhs = random.normal(k1, (num_devices * m_shard, k), dtype) rhs = random.normal(k2, (k, num_devices * n_shard), dtype) - - mesh = jax.sharding.Mesh(jax.devices(), ["x"]) - lhs = jax.device_put(lhs, jax.sharding.NamedSharding(mesh, P("x", None))) - rhs = jax.device_put(rhs, jax.sharding.NamedSharding(mesh, P(None, "x"))) + lhs = shard.reshard(lhs, P("x", None)) + rhs = shard.reshard(rhs, P(None, "x")) def run(body): out = jax.jit( - jax.shard_map( - body, - mesh=mesh, - in_specs=(P("x", None), P(None, "x")), - out_specs=P(None, "x"), - check_vma=False, - ) + jax.shard_map(body, out_specs=P(None, "x"), check_vma=False) )(lhs, rhs) # Gather output, for NumPy comparison on the host. out = jax.shard_map( lambda x: lax.all_gather(x, "x", axis=1, tiled=True), - mesh=mesh, - in_specs=P(None, "x"), - out_specs=P(None), - check_vma=False, + out_specs=P(None), check_vma=False, )(out) return out From 35a45a601c461b611cf28257cc79d05239d7ec04 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 22 May 2025 09:35:00 -0700 Subject: [PATCH 1315/1769] Reduce input ranges tested for hyp2f1. The `z` input to this function only supports input in the range 0 <= z < 1, so most of the tested samples were explicitly `inf`. The other parameters can take arbitrary values, but the dynamic range required explodes for values larger than ~1, so the old tests were quite numerically unstable. We could test some specific carefully chosen larger values, but I don't think this brings much benefit. 
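To make the domain restriction concrete: 2F1 is defined by the Gauss series
sum_n [(a)_n (b)_n / (c)_n] z^n / n!, which converges only for |z| < 1, so
sampling z in [0.1, 0.9] keeps every test point inside the convergence disk.
A pure-Python sketch (illustrative only, not part of the patch):

    import math

    def hyp2f1_series(a, b, c, z, terms=200):
        # Direct summation of the Gauss series; valid only for |z| < 1.
        total, coeff = 0.0, 1.0
        for n in range(terms):
            total += coeff
            coeff *= (a + n) * (b + n) / ((c + n) * (n + 1)) * z
        return total

    # 2F1(1, 1; 2; z) = -log(1 - z) / z, a classical closed form.
    assert math.isclose(hyp2f1_series(1., 1., 2., 0.5),
                        -math.log(0.5) / 0.5, rel_tol=1e-9)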
PiperOrigin-RevId: 762000716 --- tests/lax_scipy_special_functions_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index e02581626a53..995854dae348 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -159,7 +159,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t ), op_record( "hyp2f1", 4, float_dtypes, - functools.partial(jtu.rand_uniform, low=0.5, high=30), False + functools.partial(jtu.rand_uniform, low=0.1, high=0.9), False ), op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), op_record("softmax", 1, float_dtypes, jtu.rand_default, True), From 0e77a1617343886387369aa46b21cf8afbc7870c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 May 2025 09:54:35 -0700 Subject: [PATCH 1316/1769] [Mosaic GPU] Don't require MOSAIC_GPU_NVSHMEM_SO_PATH to be set PiperOrigin-RevId: 762008659 --- jax/experimental/mosaic/gpu/core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 48a877f8c67a..5e1ed6b88412 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -102,15 +102,15 @@ def supports_cross_device_collectives(): try: nvshmem_bc_path = os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] - nvshmem_so_path = os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] except KeyError: return False - try: - # This both ensures that the file exists, and it populates the dlopen cache - # helping XLA find the library even if the RPATH is not exactly right... - ctypes.CDLL(nvshmem_so_path) - except OSError: - return False + if nvshmem_so_path := os.environ.get("MOSAIC_GPU_NVSHMEM_SO_PATH", ""): + try: + # This both ensures that the file exists, and it populates the dlopen + # cache, helping XLA find the library even if the RPATH is not right... + ctypes.CDLL(nvshmem_so_path) + except OSError: + return False xla_flags = os.environ.get("XLA_FLAGS", "") return ( os.path.exists(nvshmem_bc_path) From 0dc70b93f2e13fae5b097837760bd621e746dae7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 22 May 2025 09:59:16 -0700 Subject: [PATCH 1317/1769] Update the trove classifiers for JAX packages. * Don't tag packages as alpha. * Tag free-threading as supported. * Update some python version lists. 
--- jax_plugins/cuda/plugin_setup.py | 4 +++- jax_plugins/cuda/setup.py | 3 ++- jaxlib/setup.py | 2 ++ setup.py | 2 ++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index c8b70408471c..fc467824fe5f 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -78,10 +78,12 @@ def has_ext_modules(self): url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ - "Development Status :: 3 - Alpha", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ package_name: [ diff --git a/jax_plugins/cuda/setup.py b/jax_plugins/cuda/setup.py index 1ce555978dac..b2c89285e7fd 100644 --- a/jax_plugins/cuda/setup.py +++ b/jax_plugins/cuda/setup.py @@ -51,8 +51,9 @@ def load_version_module(pkg_path): url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ - "Development Status :: 3 - Alpha", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ package_name: ["xla_cuda_plugin.so"], diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 8d7933953851..30e81c9ad671 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -68,10 +68,12 @@ def has_ext_modules(self): url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ 'jaxlib': [ diff --git a/setup.py b/setup.py index 4c5c86f588c3..ef78b8f6e7ff 100644 --- a/setup.py +++ b/setup.py @@ -118,10 +118,12 @@ def load_version_module(pkg_path): url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], zip_safe=False, ) From 437e32bfddabe6c4b5ff5682ca944cb6a4cbaf89 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 22 May 2025 10:22:03 -0700 Subject: [PATCH 1318/1769] jax.random: thread mode parameter through categorical and choice --- jax/_src/random.py | 32 ++++++++++++++++++++++++++------ tests/random_lax_test.py | 10 ++++++---- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 6f139dd9665c..60dad3a82021 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -633,7 +633,8 @@ def choice(key: ArrayLike, shape: Shape = (), replace: bool = True, p: RealArray | None = None, - axis: int = 0) -> Array: + axis: int = 0, + mode: str | None = None) -> Array: """Generates a random sample from a given array. .. warning:: @@ -656,6 +657,12 @@ def choice(key: ArrayLike, entries in a. axis: int, optional. The axis along which the selection is performed. The default, 0, selects by row. 
+ mode: optional, "high" or "low" for how many bits to use in the gumbel sampler + when `p is None` and `replace = False`. The default is determined by the + ``use_high_dynamic_range_gumbel`` config, which defaults to "low". With mode="low", + in float32 sampling will be biased for choices with probability less than about + 1E-7; with mode="high" this limit is pushed down to about 1E-14. mode="high" + approximately doubles the cost of sampling. Returns: An array of shape `shape` containing samples from `a`. @@ -701,7 +708,7 @@ def choice(key: ArrayLike, ind = jnp.searchsorted(p_cuml, r).astype(int) else: # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/ - g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr) + g = gumbel(key, (n_inputs,), dtype=p_arr.dtype, mode=mode) + jnp.log(p_arr) ind = lax.top_k(g, k=n_draws)[1].astype(int) result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) @@ -940,7 +947,8 @@ def bernoulli(key: ArrayLike, mode: optional, "high" or "low" for how many bits to use when sampling. default='low'. Set to "high" for correct sampling at small values of `p`. When sampling in float32, bernoulli samples with mode='low' produce - incorrect results for p < ~1E-7. + incorrect results for p < ~1E-7. mode="high" approximately doubles the + cost of sampling. Returns: A random array with boolean dtype and shape given by ``shape`` if ``shape`` @@ -1544,7 +1552,7 @@ def poisson(key: ArrayLike, def gumbel(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, - mode: str | None =None) -> Array: + mode: str | None = None) -> Array: """Sample Gumbel random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1559,6 +1567,11 @@ def gumbel(key: ArrayLike, dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32). mode: optional, "high" or "low" for how many bits to use when sampling. + The default is determined by the ``use_high_dynamic_range_gumbel`` config, + which defaults to "low". When drawing float32 samples, with mode="low" the + uniform resolution is such that the largest possible gumbel logit is ~16; + with mode="high" this is increased to ~32, at approximately double the + computational cost. Returns: A random array with the specified shape and dtype. @@ -1599,6 +1612,7 @@ def categorical( axis: int = -1, shape: Shape | None = None, replace: bool = True, + mode: str | None = None, ) -> Array: """Sample random values from categorical distributions. @@ -1615,6 +1629,12 @@ def categorical( The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. replace: If True (default), perform sampling with replacement. If False, perform sampling without replacement. + mode: optional, "high" or "low" for how many bits to use in the gumbel sampler. + The default is determined by the ``use_high_dynamic_range_gumbel`` config, + which defaults to "low". With mode="low", in float32 sampling will be biased + for events with probability less than about 1E-7; with mode="high" this limit + is pushed down to about 1E-14. mode="high" approximately doubles the cost of + sampling. 
  Returns:
    A random array with int dtype and shape given by ``shape`` if ``shape``
@@ -1644,11 +1664,11 @@ def categorical(
       logits_shape = list(shape[len(shape) - len(batch_shape):])
       logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
       return jnp.argmax(
-          gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
+          gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype, mode=mode) +
           lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
           axis=axis)
   else:
-    logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
+    logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype, mode=mode)
     k = math.prod(shape_prefix)
     if k > logits_arr.shape[axis]:
       raise ValueError(
diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py
index f87b079b759c..9fe4d2ecbda3 100644
--- a/tests/random_lax_test.py
+++ b/tests/random_lax_test.py
@@ -286,8 +286,9 @@ def testTruncatedNormal(self, dtype):
       ],
     dtype=jtu.dtypes.floating + jtu.dtypes.integer,
     weighted=[True, False],
+    mode=[None, 'low', 'high']
   )
-  def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
+  def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis, mode):
     # This is the function API that we test against (note that self.rng().choice differs)
     np_choice = np.random.default_rng(0).choice
     p_dtype = dtypes.to_inexact_dtype(dtype)
@@ -303,7 +304,7 @@ def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
       p /= p.sum()
     else:
       p = None
-    rand = lambda key, x: random.choice(key, x, shape, replace, p, axis)
+    rand = lambda key, x: random.choice(key, x, shape, replace, p, axis, mode=mode)
     sample = rand(key(), x)
     if not is_range:
       self.assertEqual(dtype, sample.dtype)
@@ -397,15 +398,16 @@ def testBernoulli(self, p, dtype, mode):
       ]
     ],
     sample_shape=[(10000,), (5000, 2)],
+    mode=[None, 'low', 'high'],
     dtype=jtu.dtypes.floating,
  )
- def testCategorical(self, p, axis, dtype, sample_shape):
+ def testCategorical(self, p, axis, dtype, sample_shape, mode):
    key = lambda: self.make_key(0)
    p = np.array(p, dtype=dtype)
    logits = np.log(p) - 42 # test unnormalized
    out_shape = tuple(np.delete(logits.shape, axis))
    shape = sample_shape + out_shape
-   rand = partial(random.categorical, shape=shape, axis=axis)
+   rand = partial(random.categorical, shape=shape, axis=axis, mode=mode)
    crand = jax.jit(rand)

    uncompiled_samples = rand(key(), logits)

From 210b5fc8674e4993254c804720144c570992984e Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 22 May 2025 10:23:23 -0700
Subject: [PATCH 1319/1769] Error out if wsc(x, P()) is called in an Explicit
 mesh context

PiperOrigin-RevId: 762021159
---
 jax/_src/pjit.py   | 8 +++++++-
 tests/pjit_test.py | 7 +++++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 0624dad88a2b..ecdcf3e17332 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2703,7 +2703,13 @@ def check_shardings_are_auto(shardings_flat):
       raise ValueError(
           'The spec of NamedSharding passed to with_sharding_constraint can'
           f' only refer to Auto axes of the mesh. Got spec={s.spec} and'
-          f' mesh={mesh}')
+          f' mesh={mesh}. You probably meant to use the `reshard` API?')
+
+  cur_mesh = mesh_lib.get_abstract_mesh()
+  if cur_mesh._are_all_axes_explicit:
+    raise ValueError(
+        'with_sharding_constraint cannot be used when all axes of the mesh are'
+        ' of type `Explicit`. 
Please use the `reshard` API.') def with_sharding_constraint(x, shardings): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 024901b746a8..48339bb2a519 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7069,8 +7069,11 @@ def test_wsc_error(self, mesh): "The spec of NamedSharding passed to with_sharding_constraint"): jax.lax.with_sharding_constraint(np.arange(8).reshape(4, 2), s) - s = NamedSharding(mesh, P()) - jax.lax.with_sharding_constraint(np.arange(8), s) + with self.assertRaisesRegex( + ValueError, + 'with_sharding_constraint cannot be used when all axes of the mesh are' + ' of type `Explicit`'): + jax.lax.with_sharding_constraint(np.arange(8), NamedSharding(mesh, P())) s = NamedSharding(Mesh(mesh.devices, mesh.axis_names, axis_types=(AxisType.Explicit, AxisType.Auto)), From a827a274baf18ece651d37b7933bee3cb2f8760e Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 22 May 2025 10:57:02 -0700 Subject: [PATCH 1320/1769] [Mosaic GPU] Add support for loops, debug_print, and unary ops to Warp semantics. PiperOrigin-RevId: 762036132 --- jax/_src/pallas/mosaic_gpu/lowering.py | 21 +++ jax/experimental/mosaic/gpu/utils.py | 8 +- tests/pallas/mosaic_gpu_test.py | 170 ++++++++++++++++--------- 3 files changed, 139 insertions(+), 60 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 2f81ecb969ac..00716bb1c675 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1698,6 +1698,19 @@ def convert(ty, x): lax.not_p: lambda ctx, x: ~x, }) +def _unary_warp_lowering_rule(impl): + def _lowering_rule(ctx: LoweringRuleContext, x): + if not all(aval_in.shape == () for aval_in in ctx.avals_in): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") + return impl(x) + return _lowering_rule + +mosaic_lowering_rules[gpu_core.LANExWARP_SEMANTICS].update({ + lax.neg_p: _unary_warp_lowering_rule(lambda x: -x), + lax.not_p: _unary_warp_lowering_rule(lambda x: ~x) +}) + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( @@ -2163,6 +2176,8 @@ def _axis_index_warp_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): @register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2171,6 +2186,9 @@ def _debug_print_lowering_rule( ): del has_placeholders # Unused. 
primitives.check_debug_print_format(fmt, *args) + scope = mgpu.ThreadSubset.WARPGROUP + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + scope = mgpu.ThreadSubset.WARP if not any(aval.shape for aval in ctx.avals_in): mgpu.debug_print( fmt, @@ -2178,6 +2196,7 @@ def _debug_print_lowering_rule( _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) ), + scope=scope ) elif len(ctx.avals_in) == 1: [arg] = args @@ -2461,6 +2480,8 @@ def loop(loop_index, body_args): @register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 1e20675f7909..4aeb3358b97a 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -144,7 +144,11 @@ def _debug_scalar_ty_format(arg): return "%f", arg raise NotImplementedError(f"Can't print the type {arg.type}") -def debug_print(fmt, *args, uniform=True): +def debug_print(fmt, *args, uniform=True, scope=None): + if not uniform and scope is not None: + raise ValueError("Cannot specify scope to a non-uniform debug_print.") + if scope is None: + scope = ThreadSubset.WARPGROUP type_formats = [] new_args = [] for arg in args: @@ -168,7 +172,7 @@ def debug_print(fmt, *args, uniform=True): raise NotImplementedError(arg.type) type_formats.append(ty_format) ctx = ( - functools.partial(single_thread, scope=ThreadSubset.WARPGROUP) + functools.partial(single_thread, scope=scope) if uniform else contextlib.nullcontext ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7173639b879f..4ef2fa8096ee 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1569,64 +1569,6 @@ def kernel(x_ref, y_ref, o_ref): y = jax.lax.iota(jnp.float32, 128) * 3 np.testing.assert_array_equal(kernel(x, y), x + y) - def test_warp_specialization_axis_index(self): - if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: - self.skipTest("Test only works on Lane semantics") - warp_mesh = plgpu.WarpMesh(axis_name="warp") - @functools.partial(plgpu.kernel, - out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32)) - def kernel(y_ref): - def scope(ones_smem_ref, threes_smem_ref): - # Prepare data to copy. - ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32) - threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3 - plgpu.commit_smem() - @pl.core_map(warp_mesh) - def _(): - warp_id = lax.axis_index("warp") - # We cannot load/store inside of core_map, so we issue async - # copies instead to produce a testable result. 
- @pl.when(warp_id == 1) - def _(): - plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1]) - @pl.when(warp_id == 3) - def _(): - plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2]) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped(scope, - plgpu.SMEM((1, 128), jnp.int32), - plgpu.SMEM((1, 128), jnp.int32) - ) - result = kernel() - expected = jnp.stack((jnp.ones((128,), jnp.int32), - jnp.ones((128,), jnp.int32) * 3), axis=0) - np.testing.assert_array_equal(result, expected) - - def test_warp_mesh_errors_when_closing_over_array(self): - if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: - self.skipTest("Test only works on Lane semantics") - # We currently do not allow closing over arrays when mapping over - # a mesh, since we would need to present a view of the array local - # to each warp. - warp_mesh = plgpu.WarpMesh(axis_name="warp") - @functools.partial(plgpu.kernel, - out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32), - scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)]) - def kernel(out_ref, smem_ref): - arr = jnp.ones((32, 32), dtype=jnp.float32) - @pl.core_map(warp_mesh) - def _(): - smem_ref[...] = arr + 1 - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(smem_ref, out_ref) - plgpu.wait_smem_to_gmem(0) - with self.assertRaisesRegex( - mgpu_lowering.LoweringError, - "Can only close over scalars and Refs when using core_map with " - "WarpMesh", - ): - kernel() - def test_smem_aliasing_works(self): self.skip_if_wg_semantics() @@ -1825,6 +1767,118 @@ def body(idx, _): ) +class PallasCallWarpPrimitiveSemanticsTest(PallasTest): + def setUp(self): + super().setUp() + if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: + self.skipTest("Test only works on Lane semantics") + + def test_axis_index(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32)) + def kernel(y_ref): + def scope(ones_smem_ref, threes_smem_ref): + # Prepare data to copy. + ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32) + threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3 + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + # We cannot load/store inside of core_map, so we issue async + # copies instead to produce a testable result. + @pl.when(warp_id == 1) + def _(): + plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1]) + @pl.when(warp_id == 3) + def _(): + plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2]) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + plgpu.SMEM((1, 128), jnp.int32), + plgpu.SMEM((1, 128), jnp.int32) + ) + result = kernel() + expected = jnp.stack((jnp.ones((128,), jnp.int32), + jnp.ones((128,), jnp.int32) * 3), axis=0) + np.testing.assert_array_equal(result, expected) + + def test_errors_when_closing_over_array(self): + # We currently do not allow closing over arrays when mapping over + # a mesh, since we would need to present a view of the array local + # to each warp. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32), + scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)]) + def kernel(out_ref, smem_ref): + arr = jnp.ones((32, 32), dtype=jnp.float32) + @pl.core_map(warp_mesh) + def _(): + smem_ref[...] 
= arr + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, out_ref) + plgpu.wait_smem_to_gmem(0) + with self.assertRaisesRegex( + mgpu_lowering.LoweringError, + "Can only close over scalars and Refs when using core_map with " + "WarpMesh", + ): + kernel() + + def test_single_warp_scan(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((10, 128), jnp.int32)) + def kernel(y_ref): + def scope(smem_ref): + # Prepare data to copy. + for i in range(10): + smem_ref[i, :] = jnp.ones_like(smem_ref.at[i]) * i + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + def loop_body(i, _): + _slice = pl.ds(i, 1) + plgpu.copy_smem_to_gmem(smem_ref.at[_slice], y_ref.at[_slice]) + lax.fori_loop(0, 10, loop_body, None) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, plgpu.SMEM((10, 128), jnp.int32)) + result = kernel() + expected = jnp.stack( + [jnp.ones((128,), jnp.int32) * i for i in range(10)], axis=0) + np.testing.assert_array_equal(result, expected) + + def test_debug_print(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros(128, np.int32), + ) + def kernel(ref): + ref[...] = ref[...] # Prevent kernel from being DCE'd + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + pl.debug_print("warp: {}", warp_id) + + with self.capture_stdout() as output: + jax.block_until_ready(kernel()) + self.assertEqual( + set(output().splitlines()), + { + "warp: 0", + "warp: 1", + "warp: 2", + "warp: 3", + }, + ) + + class PallasCallWGTest( PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): From 1aaec81f22a0dde3f7da56bc54bdd71f212076c9 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 22 May 2025 11:35:56 -0700 Subject: [PATCH 1321/1769] Add support for non-power-of-2 head size in flash attention Introduce checks on sequences being divisible by block sizes to address https://github.com/jax-ml/jax/issues/27224 PiperOrigin-RevId: 762051831 --- jax/experimental/pallas/ops/gpu/attention.py | 184 ++++++++++--------- tests/pallas/gpu_ops_test.py | 4 +- 2 files changed, 104 insertions(+), 84 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 8b83d24ea199..ccb3ae8fd3b7 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -86,28 +86,29 @@ def mha_forward_kernel( segment_ids_ref: jax.Array | None, # segment_id arrays o_ref: Any, # Output *residual_refs: Any, # Residual outputs - num_heads: int, sm_scale: float, causal: bool, block_q: int, - block_d: int, block_k: int, + head_dim: int, ): seq_len = k_ref.shape[0] start_q = pl.program_id(0) + head_dim_padded = q_ref.shape[-1] # o is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') l_i = jnp.zeros(block_q, dtype=jnp.float32) # acc is the buffer where we accumulate the output on sram. - o = jnp.zeros((block_q, block_d), dtype=jnp.float32) + o = jnp.zeros((block_q, head_dim_padded), dtype=jnp.float32) # Load q: it will stay in L1 throughout. Indices form a matrix because we # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. - # q tile has shape [block_q, block_d], block_d == head_dim. 
+ # q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim. curr_q_slice = pl.dslice(start_q * block_q, block_q) - q = q_ref[...] + head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :] + q = pl.load(q_ref, (slice(None), slice(None)), mask=head_mask, other=0.0) q_segment_ids = ( None if segment_ids_ref is None @@ -121,7 +122,7 @@ def body(start_k, carry): o_prev, m_prev, l_prev = carry curr_k_slice = pl.dslice(start_k * block_k, block_k) - k = pl.load(k_ref, (curr_k_slice, slice(None))) + k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) qk = pl.dot(q, k.T) # [block_q, block_k] # Scale logits to convert from base-2 to the natural log domain. @@ -161,7 +162,7 @@ def body(start_k, carry): l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr o_prev_corr = correction[:, None] * o_prev - v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d))) + v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask) o_curr = pl.dot(s_curr.astype(v.dtype), v) o_next = o_prev_corr + o_curr @@ -182,7 +183,8 @@ def body(start_k, carry): lse_ref = residual_refs[0] lse_ref[...] = m_i + jnp.log2(l_i) # Write output to dram. - o_ref[...] = o.astype(o_ref.dtype) + pl.store(o_ref, (slice(None), slice(o.shape[-1])), o.astype(o_ref.dtype), + mask=head_mask) def segment_mask( q_segment_ids: jax.Array, @@ -235,6 +237,17 @@ def mha( kv_seq_len = k.shape[1] block_q = min(block_sizes.block_q, q_seq_len) block_k = min(block_sizes.block_k, kv_seq_len) + head_dim_padded = pl.next_power_of_2(head_dim) + if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]): + raise ValueError( + f"This kernel expects q, k, and v to have the same head dimension, but" + f" found {q.shape=}, {k.shape=}, {v.shape=}." + ) + if q_seq_len % block_q != 0: + raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}") + if kv_seq_len % block_k != 0: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}") + # Heuristics. 
grid_ = grid if grid_ is None: @@ -243,21 +256,17 @@ def mha( num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 if head_dim <= 64 else 8 - kernel = functools.partial(mha_forward_kernel, num_heads=num_heads, - sm_scale=sm_scale, block_q=block_q, - block_k=block_k, block_d=head_dim, - causal=causal) + kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale, + block_q=block_q, block_k=block_k, + head_dim=head_dim, causal=causal) in_specs = [ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), ] in_specs.append( None # type: ignore[arg-type] @@ -270,7 +279,7 @@ def mha( grid=grid_, in_specs=in_specs, out_specs=pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) + (None, block_q, None, head_dim_padded), lambda i, j, k: (j, i, k, 0) ), compiler_params=plgpu.TritonCompilerParams( num_warps=num_warps_, num_stages=num_stages), @@ -301,6 +310,17 @@ def _mha_forward( kv_seq_len = k.shape[1] block_q = min(block_sizes.block_q, q_seq_len) block_k = min(block_sizes.block_k, kv_seq_len) + if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]): + raise ValueError( + f"This kernel expects q, k, and v to have the same head dimension, but" + f" found {q.shape=}, {k.shape=}, {v.shape=}." + ) + if q_seq_len % block_q != 0: + raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}") + if kv_seq_len % block_k != 0: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}") + head_dim_padded = pl.next_power_of_2(head_dim) + # Heuristics. 
grid_ = grid if grid_ is None: @@ -309,9 +329,9 @@ def _mha_forward( num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 if head_dim <= 64 else 8 - kernel = functools.partial(mha_forward_kernel, num_heads=num_heads, - sm_scale=sm_scale, causal=causal, block_q=block_q, - block_k=block_k, block_d=head_dim) + kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale, + causal=causal, block_q=block_q, block_k=block_k, + head_dim=head_dim) out_shape = [ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out jax.ShapeDtypeStruct( @@ -319,15 +339,12 @@ def _mha_forward( ), ] in_specs = [ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), ] in_specs.append( None # type: ignore[arg-type] @@ -339,9 +356,8 @@ def _mha_forward( grid=grid_, in_specs=in_specs, out_specs=[ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), ], compiler_params=plgpu.TritonCompilerParams( @@ -355,10 +371,11 @@ def _mha_forward( return out, (q, k, v, segment_ids, out, lse) -def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): +def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int): # load - o = out_ref[...].astype(jnp.float32) - do = dout_ref[...].astype(jnp.float32) + head_mask = (jnp.arange(out_ref.shape[-1]) < head_dim)[None, :] + o = pl.load(out_ref, (slice(None), slice(None)), mask=head_mask, other=0.0) + do = pl.load(dout_ref, (slice(None), slice(None)), mask=head_mask, other=0.0) # compute delta = jnp.sum(o * do, axis=1) # write-back @@ -368,17 +385,16 @@ def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): def _preprocess_backward(out, do, lse, block_q: int, debug: bool, interpret: bool): batch_size, seq_len, num_heads, head_dim = out.shape + head_dim_padded = pl.next_power_of_2(head_dim) out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype) delta = pl.pallas_call( - _preprocess_backward_kernel, + functools.partial(_preprocess_backward_kernel, head_dim=head_dim), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), ], out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3), @@ -414,7 +430,7 @@ def mha_backward_kernel( block_kv_dkv: int, block_q_dq: int, block_kv_dq: int, - block_d: int, + head_dim: int, ): del out_ref # Not needed q_seq_len = q_ref.shape[0] @@ -427,11 +443,13 @@ def mha_backward_kernel( start_k = pl.program_id(2) curr_k_slice = pl.dslice(start_k * block_kv_dkv, block_kv_dkv) - dv = jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32) - dk = 
jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32) + head_dim_padded = q_ref.shape[-1] + dv = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32) + dk = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32) - v = pl.load(v_ref, (curr_k_slice, slice(None))) - k = pl.load(k_ref, (curr_k_slice, slice(None))) + head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :] + v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) + k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) span_k = start_k * block_kv_dkv + jnp.arange(block_kv_dkv) kv_segment_ids = ( None @@ -443,7 +461,7 @@ def inner_loop_dkdv(start_q, carry): dv, dk = carry curr_q_slice = pl.dslice(start_q * block_q_dkv, block_q_dkv) - q = pl.load(q_ref, (curr_q_slice, slice(None))) + q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0) qk = pl.dot(q, k.T) qk_scale = math.log2(math.e) if sm_scale != 1.: @@ -466,7 +484,8 @@ def inner_loop_dkdv(start_q, carry): lse = pl.load(lse_ref, (curr_q_slice,)) di = pl.load(delta_ref, (curr_q_slice,)) - do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask, + other=0.0) p = jnp.exp2(qk - lse[:, None]) dv = dv + pl.dot(p.astype(do.dtype).T, do) @@ -483,8 +502,10 @@ def inner_loop_dkdv(start_q, carry): dv, dk = lax.fori_loop( lower_bound, pl.cdiv(q_seq_len, block_q_dkv), inner_loop_dkdv, (dv, dk) ) - dv_ref[...] = dv.astype(dv_ref.dtype) - dk_ref[...] = dk.astype(dk_ref.dtype) + pl.store(dv_ref, (slice(None), slice(dv.shape[-1])), dv.astype(dv_ref.dtype), + mask=head_mask) + pl.store(dk_ref, (slice(None), slice(dk.shape[-1])), dk.astype(dk_ref.dtype), + mask=head_mask) # Scan #2: dQ # 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM. @@ -493,22 +514,23 @@ def inner_loop_dkdv(start_q, carry): start_q = pl.program_id(2) curr_q_slice = pl.ds(start_q * block_q_dq, block_q_dq) span_q = start_q * block_q_dq + jnp.arange(block_q_dq) - dq = jnp.zeros([block_q_dq, block_d], dtype=jnp.float32) + dq = jnp.zeros([block_q_dq, head_dim_padded], dtype=jnp.float32) - q = pl.load(q_ref, (curr_q_slice, slice(None))) + q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0) q_segment_ids = ( None if segment_ids_ref is None else pl.load(segment_ids_ref, (curr_q_slice,)) ) lse = pl.load(lse_ref, (curr_q_slice,)) - do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask, + other=0.0) di = pl.load(delta_ref, (curr_q_slice,)) def inner_loop_dq(start_k, dq): curr_k_slice = pl.dslice(start_k * block_kv_dq, block_kv_dq) - k = pl.load(k_ref, (curr_k_slice, slice(None))) - v = pl.load(v_ref, (curr_k_slice, slice(None))) + k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) + v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) qk = pl.dot(q, k.T) qk_scale = math.log2(math.e) @@ -547,7 +569,8 @@ def inner_loop_dq(start_k, dq): upper_bound = pl.cdiv(kv_seq_len, block_kv_dq) dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) - dq_ref[...] 
= dq.astype(dq_ref.dtype) + pl.store(dq_ref, (slice(None), slice(dq.shape[-1])), dq.astype(dq_ref.dtype), + mask=head_mask) def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, @@ -576,6 +599,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, block_kv_dkv = min(block_sizes.block_kv_dkv, kv_seq_len) block_q_dq = min(block_sizes.block_q_dq, q_seq_len) block_kv_dq = min(block_sizes.block_kv_dq, kv_seq_len) + head_dim_padded = pl.next_power_of_2(head_dim) if q_seq_len // block_q_dq != kv_seq_len // block_kv_dkv: raise ValueError( @@ -591,28 +615,24 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, ] in_specs = [ - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] else: - in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0))) + in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), + lambda i, j, _: (i, 0))) grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_kv_dkv)) num_warps_ = num_warps @@ -635,22 +655,22 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, block_kv_dkv=block_kv_dkv, block_q_dq=block_q_dq, block_kv_dq=block_kv_dq, - block_d=head_dim, + head_dim=head_dim, ), out_shape=out_shapes, in_specs=in_specs, grid=grid, out_specs=[ pl.BlockSpec( - (None, block_q_dq, None, head_dim), + (None, block_q_dq, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dq ), pl.BlockSpec( - (None, block_kv_dkv, None, head_dim), + (None, block_kv_dkv, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dk ), pl.BlockSpec( - (None, block_kv_dkv, None, head_dim), + (None, block_kv_dkv, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dv ), ], diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index 1b758cdd0a58..1637686365e1 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -153,7 +153,7 @@ def setUp(self): batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2, 8), - head_dim=(32, 64, 128), + head_dim=(32, 64, 72, 128), block_sizes=( (("block_q", 128), ("block_k", 128)), (("block_q", 64), ("block_k", 64)), @@ -226,7 +226,7 @@ def impl(q, k, v): batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2), - head_dim=(32, 64, 128,), + head_dim=(32, 64, 72, 128,), block_sizes=( ( ("block_q", 128), From 7cf4f35442743b01f4685ca906d282586428b3d0 Mon Sep 17 00:00:00 2001 From: Jen Ha Date: Wed, 21 May 2025 10:49:51 -0700 Subject: [PATCH 1322/1769] Use the default python logging instead of absl log. 
Update use cases in export / compiler as well as the documentation. --- docs/export/export.md | 5 +---- jax/_src/compiler.py | 6 +++--- jax/_src/export/_export.py | 23 ++++++++++++----------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/docs/export/export.md b/docs/export/export.md index 63c0db14f905..95e47385997c 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -710,10 +710,7 @@ total 32 -rw-rw-r--@ 1 necula wheel 2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir ``` -Inside Google, you can turn on logging by using the `--vmodule` argument to -specify the logging levels for different modules, -e.g., `--vmodule=_export=3`. - +Set [`JAX_DEBUG_LOG_MODULES=jax._src.export`](https://docs.jax.dev/en/latest/config_options.html#jax_debug_log_modules) to enable extra debugging logging. (export_ensuring_compat)= ### Ensuring forward and backward compatibility diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 04f993fed799..cbde6fdb3366 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -242,7 +242,7 @@ def get_compile_options( else: compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE if backend is None: - logging.info("get_compile_options: no backend supplied; " + logger.info("get_compile_options: no backend supplied; " "disabling XLA-AutoFDO profile") else: fdo_profile_version = get_latest_profile_version(backend) @@ -376,7 +376,7 @@ def compile_or_get_cached( module_name = ir.StringAttr(sym_name).value if dumped_to := mlir.dump_module_to_file(computation, "compile"): - logging.info("Dumped the module to %s.", dumped_to) + logger.info("Dumped the module to %s.", dumped_to) is_multi_process = ( len({device.process_index for device in devices.flatten()}) > 1 @@ -521,7 +521,7 @@ def _resolve_compilation_strategy( # The compilation cache is enabled and AutoPGLE is enabled/expected if _is_executable_in_cache(backend, pgle_optimized_cache_key): if config.compilation_cache_expect_pgle.value: - logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") + logger.info(f"PGLE-optimized {module_name} loaded from compilation cache") # No need to record N profiles in this case if pgle_profiler is not None: pgle_profiler.disable() diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index c0ca1e108590..189818541a2c 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -25,7 +25,7 @@ import re from typing import Any, Protocol, TypeVar, Union, cast -from absl import logging +import logging import numpy as np import jax @@ -55,6 +55,8 @@ from jax._src.export import shape_poly +logger = logging.getLogger(__name__) + map = util.safe_map zip = util.safe_zip @@ -704,16 +706,15 @@ def _export_lowered( out_avals_flat = lowered.compile_args["out_avals"] # type: ignore # Log and then check the module. 
-  if logging.vlog_is_on(3):
-    logmsg = (f"fun_name={fun_name} version={version} "
-              f"lowering_platforms={lowering._platforms} "  # type: ignore[unused-ignore,attribute-error]
-              f"disabled_checks={disabled_checks}")
-    logging.info("Exported JAX function: %s\n", logmsg)
-    logging.info(mlir.dump_module_message(mlir_module, "export"))
-    logging.info(
-        "Size of mlir_module_serialized: %d byte",
-        len(mlir_module_serialized),
-    )
+  logmsg = (f"fun_name={fun_name} version={version} "
+            f"lowering_platforms={lowering._platforms} "  # type: ignore[unused-ignore,attribute-error]
+            f"disabled_checks={disabled_checks}")
+  logger.debug("Exported JAX function: %s\n", logmsg)
+  logger.debug(mlir.dump_module_message(mlir_module, "export"))
+  logger.debug(
+      "Size of mlir_module_serialized: %d byte",
+      len(mlir_module_serialized),
+  )

   _check_module(mlir_module,
                 disabled_checks=disabled_checks,

From 80b4f801bf87c8bdf8708bad84a73bcc406514c1 Mon Sep 17 00:00:00 2001
From: Sizhi Tan
Date: Thu, 22 May 2025 14:02:39 -0700
Subject: [PATCH 1323/1769] Remove block_until_ready from testAutodiffCache

PiperOrigin-RevId: 762116819
---
 tests/pjit_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 48339bb2a519..92190dc6bf3c 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -671,7 +671,7 @@ def testAutodiffCache(self):
     jax.grad(f)(x)  # Warm up the cache.

     with jtu.count_pjit_cpp_cache_miss() as count:
-      jax.block_until_ready(jax.grad(f)(x))
+      jax.grad(f)(x)
     self.assertEqual(count(), 0)  # no cache miss i.e. cache hit

   @jtu.with_mesh([('x', 2), ('y', 1)])

From 0e24e98032fef09e31a8e9219967aa92cf370b86 Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Thu, 22 May 2025 14:25:14 -0700
Subject: [PATCH 1324/1769] Add CUDA pytest job strategy to test CUDA packages
 downloaded from PyPI

Continuous and Nightly/Release workflows will now run the CUDA 12.8 tests
using the Nvidia CUDA packages from PyPI instead of those on the system.

PiperOrigin-RevId: 762125356
---
 .github/workflows/pytest_cuda.yml                 | 33 +++++++++++++++----
 .github/workflows/pytest_tpu.yml                  |  6 ++--
 .github/workflows/wheel_tests_continuous.yml      | 16 ++++++---
 .../workflows/wheel_tests_nightly_release.yml     | 10 ++++--
 ci/envs/README.md                                 |  2 +-
 ci/envs/default.env                               |  6 ++--
 ci/utilities/install_wheels_locally.sh            | 14 +++-----
 7 files changed, 56 insertions(+), 31 deletions(-)

diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml
index a20be5b1dbcf..2f22901e661a 100644
--- a/.github/workflows/pytest_cuda.yml
+++ b/.github/workflows/pytest_cuda.yml
@@ -24,11 +24,16 @@ on:
         type: string
         required: true
         default: "3.12"
-      cuda:
+      cuda-version:
         description: "Which CUDA version to test?"
         type: string
         required: true
-        default: "12.3"
+        default: "12.8"
+      use-nvidia-pip-wheels:
+        description: "Whether to download CUDA packages from PyPI?"
+        type: boolean
+        required: false
+        default: false
       enable-x64:
         description: "Should x64 mode be enabled?"
         type: string
@@ -58,8 +63,11 @@ jobs:
       shell: bash
     runs-on: ${{ inputs.runner }}
     # Test the oldest and newest supported CUDA versions. 
-    container: ${{ (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') ||
-                   (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
+    # If testing the CUDA packages from PyPI, then use the ml-build image which does not have any
+    # CUDA packages installed on the system.
+    container: ${{ !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') ||
+                   !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') ||
+                   inputs.use-nvidia-pip-wheels && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'}}

     name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
     env:
@@ -100,13 +108,24 @@ jobs:

           if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
             echo "JAX only release. Only downloading the jax wheel from the release bucket."
-            # Set the env var to install the CUDA plugin and PJRT packages from PyPI. jaxlib is
-            # required dependency of jax so that gets installed automatically.
-            echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=jax_cuda_pypi">> $GITHUB_ENV
+            if [[ "${{ inputs.use-nvidia-pip-wheels }}" == false ]]; then
+              # Install only the PJRT and JAX CUDA Plugin packages from PyPI. Nvidia CUDA packages
+              # are used from the system.
+              echo "JAXCI_JAX_PYPI_EXTRAS=cuda12-local">> $GITHUB_ENV
+            else
+              # Install the PJRT, JAX CUDA Plugin, and Nvidia CUDA packages from PyPI.
+              echo "JAXCI_JAX_PYPI_EXTRAS=cuda12">> $GITHUB_ENV
+            fi
           else
             gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
             gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
             gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
+
+            if [[ "${{ inputs.use-nvidia-pip-wheels }}" == true ]]; then
+              # Install the Nvidia CUDA packages from PyPI. The wheels downloaded in the previous
+              # step will be used for the PJRT and JAX CUDA Plugin packages.
+              echo "JAXCI_JAX_PYPI_EXTRAS=cuda12">> $GITHUB_ENV
+            fi
           fi
       - name: Skip the test run if the wheel artifacts were not downloaded successfully
         if: steps.download-wheel-artifacts.outcome == 'failure'
diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml
index d1af90283001..ae0250884831 100644
--- a/.github/workflows/pytest_tpu.yml
+++ b/.github/workflows/pytest_tpu.yml
@@ -137,9 +137,9 @@ jobs:
             $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
           elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then
             echo "Using latest libtpu from PyPI"
-            # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
-            # script will install the latest libtpu wheel from PyPI.
-            echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV
+            # Set JAXCI_JAX_PYPI_EXTRAS to "tpu". The `run_pytest_tpu.sh` script will install the
+            # latest libtpu wheel from PyPI. 
+ echo "JAXCI_JAX_PYPI_EXTRAS=tpu" >> $GITHUB_ENV elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then echo "Using oldest supported libtpu" $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 207075fd0340..91662ff51f3e 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -117,25 +117,31 @@ jobs: # See exlusions for what is fully tested runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] python: ["3.10",] - cuda: ["12.1", "12.8"] + cuda: [ + {version: "12.1", use-nvidia-pip-wheels: false}, + {version: "12.8", use-nvidia-pip-wheels: true}, + ] enable-x64: [1, 0] exclude: # H100 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.1" + cuda: + version: "12.1" - runner: "linux-x86-a3-8g-h100-8gpu" enable-x64: "0" # B200 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.1" + cuda: + version: "12.1" - runner: "linux-x86-a4-224-b200-1gpu" enable-x64: "0" - name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})" + name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} - cuda: ${{ matrix.cuda }} + cuda-version: ${{ matrix.cuda.version }} + use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }} enable-x64: ${{ matrix.enable-x64 }} # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 3e616a894d13..7bad41647e6b 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -66,13 +66,17 @@ jobs: # that build the wheels. runner: ["linux-x86-g2-48-l4-4gpu"] python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] - cuda: ["12.1", "12.8"] + cuda: [ + {cuda-version: "12.1", use-nvidia-pip-wheels: false}, + {cuda-version: "12.8", use-nvidia-pip-wheels: true} + ] enable-x64: [0] - name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} - cuda: ${{ matrix.cuda }} + cuda-version: ${{ matrix.cuda.cuda-version }} + use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }} enable-x64: ${{ matrix.enable-x64 }} download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} diff --git a/ci/envs/README.md b/ci/envs/README.md index 6b5dc554d824..cf7a0c12fc9f 100644 --- a/ci/envs/README.md +++ b/ci/envs/README.md @@ -21,7 +21,7 @@ Name | Default Value `JAXCI_ENABLE_X64` | 0 | By default, JAX enforces single-precision numbers to mitigate the Numpy API’s tendency to aggressively promote operands to `double`. When set to 1, the tests will use double-precision numbers. 
| [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ENABLE_X64&type=code) `JAXCI_TPU_CORES` | Unset | Sets the number of TPU cores for the TPU machine type. Values are set in the workflow files. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_TPU_CORES&type=code) `JAXCI_RUN_FULL_TPU_TEST_SUITE` | 0 | When set to 1, the full TPU test suite is run. Otherwise, a subset of tests is run. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_RUN_FULL_TPU_TEST_SUITE&type=code) -`JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI` | Unset | Used to control the installation of JAX [extras](https://github.com/jax-ml/jax/blob/7e42539653d33ec995487b683794c0bc86f7199b/setup.py#L64) from PyPI. See [ci/utilities/install_wheels_locally.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/install_wheels_locally.sh) for the list of valid values and their behavior. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI&type=code) +`JAXCI_JAX_PYPI_EXTRAS` | Unset | Used to control the installation of JAX extras from PyPI. See JAX's [setup.py](https://github.com/jax-ml/jax/blob/c9934912885bb7c4b72c5a9271598235a6789a81/setup.py#L71) for the list of valid values. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_PYPI_EXTRAS&type=code) ## Docker Specific Environment Variables diff --git a/ci/envs/default.env b/ci/envs/default.env index 774464724646..09594af89cbe 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -58,7 +58,7 @@ export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} # Sets the number of TPU cores for the TPU machine type. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} -# JAXCI_PYTHON points to the Python binary on the system that should be used +# JAXCI_PYTHON points to the Python binary on the system that should be used # for installing the JAX wheels on the system and running Pytest scripts. export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} @@ -66,5 +66,5 @@ export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} # is run. export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0} -# Controls which additional extras to install from PyPI. -export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""} \ No newline at end of file +# Controls which additional extras for JAX to install from PyPI. +export JAXCI_JAX_PYPI_EXTRAS=${JAXCI_JAX_PYPI_EXTRAS:-""} \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 53f070d1e0e6..b1472d765c08 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -22,15 +22,11 @@ WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o - for i in "${!WHEELS[@]}"; do if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then - if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then - # Append [tpu] to the jax wheel name to download the latest libtpu wheel - # from PyPI. - WHEELS[$i]="${WHEELS[$i]}[tpu]" - elif [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "jax_cuda_pypi" ]]; then - # Append [cuda12-local] to the jax wheel name to download the latest - # release of JAX's CUDA plugin and PJRT packages from PyPI. This is used - # when running CUDA tests for a "jax" only release. - WHEELS[$i]="${WHEELS[$i]}[cuda12-local]" + # Apppend an extra to the end of the JAX wheel path to install those + # packages as well from PyPI. E.g. 
jax[tpu] will install the libtpu package + # from PyPI. See ci/envs/README.md for more details. + if [[ -n "$JAXCI_JAX_PYPI_EXTRAS" ]]; then + WHEELS[$i]="${WHEELS[$i]}[$JAXCI_JAX_PYPI_EXTRAS]" fi fi done From e48080f46c66fe97da1e5a769f1bd24bd6cbca7d Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 22 May 2025 14:28:28 -0700 Subject: [PATCH 1325/1769] Add pypi NVIDIA NVCC wheel dependency to the tests that run with pre-built wheels. NCCC wheel is used in `[with-cuda]` requirements as stated [here](https://github.com/jax-ml/jax/blob/0dc70b93f2e13fae5b097837760bd621e746dae7/jax_plugins/cuda/plugin_setup.py#L58). PiperOrigin-RevId: 762126511 --- jaxlib/tools/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 433747e2bb8d..7f8a85a5d9ab 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -461,6 +461,7 @@ filegroup( srcs = [ "@pypi_nvidia_cublas_cu12//:whl", "@pypi_nvidia_cuda_cupti_cu12//:whl", + "@pypi_nvidia_cuda_nvcc_cu12//:whl", "@pypi_nvidia_cuda_runtime_cu12//:whl", "@pypi_nvidia_cudnn_cu12//:whl", "@pypi_nvidia_cufft_cu12//:whl", From fc683368fa457bbffd3a693fc80ce880c1796e18 Mon Sep 17 00:00:00 2001 From: Keith Rush Date: Thu, 22 May 2025 14:50:25 -0700 Subject: [PATCH 1326/1769] Raises an explicit error in reshard. Hit this while working with sharding in types -- passing a sharding that had an empty mesh. (I think this was in a test). This failed trying to acces with `with_spec` attribute on None -- so just catching this case early. PiperOrigin-RevId: 762135310 --- jax/_src/pjit.py | 5 +++++ tests/pjit_test.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ecdcf3e17332..0c55f3fe30ab 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2987,6 +2987,11 @@ def reshard(xs, out_shardings): out_flat = [] for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat): ds = canonicalize_sharding(s, 'reshard') + if ds is None: + raise ValueError( + 'Reshard should only be used with out_shardings which are non-None ' + 'and have a nonempty mesh. Got sharding {s}.' + ) ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error out_flat.append(reshard_p.bind(x, dst_sharding=ds)) return tree_unflatten(treedef, out_flat) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 92190dc6bf3c..d37b21bd2460 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -8781,6 +8781,16 @@ def f(x): "to Shardy"): jax.jit(f)(x) + def test_reshard_empty_mesh_error(self): + arr = jax.device_put(np.arange(8), jax.devices()[0]) + with self.assertRaisesRegex(ValueError, "nonempty mesh"): + reshard(arr, NamedSharding(mesh_lib.empty_abstract_mesh, P(None))) + + def test_reshard_none_sharding_error(self): + arr = jax.device_put(np.arange(8), jax.devices()[0]) + with self.assertRaisesRegex(ValueError, "non-None"): + reshard(arr, None) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 0e690c1a88049787f8371d462d674a94f9b85c73 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Thu, 22 May 2025 14:52:45 -0700 Subject: [PATCH 1327/1769] Use additional semaphores to avoid data races in TPU paged_attention_kernel. Also prevents an out-of-bounds read of SMEM. And re-enables tests for the TPU paged_attention_kernel. @apaszke confirmed the presence of data races using the race detector in the new TPU interpret mode. 
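The shape of the fix is easiest to see in a stripped-down analogue. The sketch below is
illustrative only -- it is not the paged-attention kernel itself -- and assumes the public
Pallas TPU async-copy API (`pltpu.make_async_copy`, `pltpu.SemaphoreType.DMA`); the kernel
and variable names are invented for the example. It shows why two copies in flight need two
semaphores: a wait on a shared semaphore can be satisfied by whichever copy happens to
complete first.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def copy_kernel(x_hbm, y_hbm, x_out, y_out, x_vmem, y_vmem, x_sem, y_sem):
      # Two DMAs are in flight at once. Each signals its *own* semaphore, so
      # x_copy.wait() cannot return early just because the y copy finished.
      x_copy = pltpu.make_async_copy(x_hbm, x_vmem, x_sem)
      y_copy = pltpu.make_async_copy(y_hbm, y_vmem, y_sem)
      x_copy.start()
      y_copy.start()
      x_copy.wait()             # x_vmem is now fully written
      x_out[...] = x_vmem[...]
      y_copy.wait()             # y_vmem is now fully written
      y_out[...] = y_vmem[...]

    @jax.jit
    def double_copy(x, y):
      out = jax.ShapeDtypeStruct(x.shape, x.dtype)
      return pl.pallas_call(
          copy_kernel,
          in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] * 2,
          out_shape=(out, out),
          scratch_shapes=[
              pltpu.VMEM(x.shape, x.dtype),
              pltpu.VMEM(x.shape, x.dtype),
              pltpu.SemaphoreType.DMA,  # one semaphore per copy stream
              pltpu.SemaphoreType.DMA,
          ],
      )(x, y)

The fix below applies the same principle with double buffering: the `k` and `v` copy streams
each get a two-slot semaphore array (`pltpu.SemaphoreType.DMA((2,))`), indexed by buffer, so
a wait is only ever satisfied by the copy it was paired with.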
With the additional semaphores, the race detector no longer detects any races in the this kernel and I no longer see any test failures in 20+ test runs on a TPU. Details on the data races: - In each iteration, the kernel: (a) Starts copying data for `k` and `v` for the next iteration. (b) Waits for the copy of `k` for the current iteration to finish. (c) Waits for the copy of `v` for the current iteration to finish. - It is possible for these copies to happen out of order -- that is: (a) The copies for the next iteration can finish before the copies for the current iteration. (b) And the copies for `v` for the current iteration can finish before the copies for `k` for the current iteration. - If the same DMA semaphore is used for everything, then out-of-order copies can lead to: (a) `k = async_copy_k.wait_and_get_loaded()` returns but the data isn't all available because the underlying semaphore was signaled by the completion of copies of `v` for the current iteration or copies of `k` or `v` for the next iteration. (a) `v = async_copy_v.wait_and_get_loaded()` returns but the data isn't all available because the underlying semaphore was signaled by the completion of copies of `k` or `v` for the next iteration. PiperOrigin-RevId: 762136079 --- .../paged_attention/paged_attention_kernel.py | 23 ++++++++++++------- .../pallas/tpu_paged_attention_kernel_test.py | 2 -- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 6280064f29d3..9c02679c45ea 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -127,7 +127,8 @@ def paged_flash_attention_kernel( k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, *, batch_size: int, pages_per_compute_block: int, @@ -176,7 +177,9 @@ def advance_to_next_non_zero_length(): return ( lax.cond( - jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0), + jnp.logical_and( + next_b < batch_size, + lengths_ref[lax.clamp(0, next_b, batch_size - 1)] == 0), advance_to_next_non_zero_length, lambda: next_b, ), @@ -200,7 +203,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): k_scales_vmem_buffer.at[buffer_index] if k_scales_vmem_buffer is not None else None, - sem, + k_sems.at[buffer_index], page_indices_ref, page_offset, pages_to_load, @@ -213,7 +216,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): v_scales_vmem_buffer.at[buffer_index] if v_scales_vmem_buffer is not None else None, - sem, + v_sems.at[buffer_index], page_indices_ref, page_offset, pages_to_load, @@ -301,7 +304,8 @@ def paged_flash_attention_kernel_inline_seq_dim( k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, *, batch_size: int, pages_per_compute_block: int, @@ -336,7 +340,8 @@ def body(i, _): k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, batch_size=batch_size, pages_per_compute_block=pages_per_compute_block, pages_per_sequence=pages_per_sequence, @@ -584,7 +589,8 @@ def paged_attention( ), v_scales_pages.dtype, # pytype: disable=attribute-error ), # v_scales_pages buffer - pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA((2,)), + pltpu.SemaphoreType.DMA((2,)), ) else: in_specs = [ @@ -615,7 +621,8 @@ def paged_attention( v_pages.dtype, ), # v_pages buffer None, - 
pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA((2,)), + pltpu.SemaphoreType.DMA((2,)), ) out, _, _ = pl.pallas_call( diff --git a/tests/pallas/tpu_paged_attention_kernel_test.py b/tests/pallas/tpu_paged_attention_kernel_test.py index 9886e7943f6f..ac24fea1b45a 100644 --- a/tests/pallas/tpu_paged_attention_kernel_test.py +++ b/tests/pallas/tpu_paged_attention_kernel_test.py @@ -265,8 +265,6 @@ def test_paged_attention( attn_logits_soft_cap, are_kv_quantized, ): - # TODO(mvoz, skyewm): Re-enable this test once the data race is fixed. - self.skipTest("This kernel has data races that need to be fixed.") if not jtu.is_device_tpu_at_least(4): self.skipTest("Only supports TPU generation 4 or above") if jtu.is_device_tpu(version=4) and are_kv_quantized: From 859e120fdd79575da26da4ea561c4bb135492b3c Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 22 May 2025 13:35:14 -0700 Subject: [PATCH 1328/1769] Avoid doing DCE of effectful ops and reordering in partial eval. --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/debugging.py | 2 +- jax/_src/interpreters/ad.py | 2 +- jax/_src/interpreters/partial_eval.py | 46 +++++++++++++++++------ jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/for_loop.py | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/lax/lax.py | 2 +- jax/_src/pjit.py | 16 ++------ jax/_src/shard_map.py | 2 +- jax/_src/state/discharge.py | 2 +- tests/mutable_array_test.py | 8 +++- 12 files changed, 52 insertions(+), 36 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 2d743bf06c6b..2a056d5c94f0 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -578,7 +578,7 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None) for x in jaxpr_unknown.outvars] new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True) - recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p, + recipe = pe.new_eqn_recipe(trace, in_jaxpr_tracers, out_jaxpr_tracers, remat_p, new_params, jaxpr_unknown.effects, source_info_util.current()) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index e931a6edb9b3..3490de5118e1 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -127,7 +127,7 @@ def debug_callback_jvp_rule(primals, tangents, **params): ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any], - effect: DebugEffect): + effect: DebugEffect, partitioned): del flat_args, callback, effect raise ValueError("Transpose doesn't support debugging callbacks.") ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 9366b91f8022..7cbdfff01462 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -885,7 +885,7 @@ def make_zero(aval): out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] - jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) + jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, [], jvp.debug_info) jaxpr, used_consts, _ = pe.dce_jaxpr_consts( jaxpr, [True] * len(jaxpr.outvars), [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 
f77db5443a86..6ea16ec8e8ba 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -16,6 +16,7 @@ from collections import namedtuple from collections.abc import Callable, Sequence, Hashable import contextlib +from dataclasses import dataclass from functools import partial import itertools as it import operator as op @@ -42,7 +43,7 @@ mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.source_info_util import SourceInfo -from jax._src.state.types import AbstractRef, ReadEffect +from jax._src.state.types import AbstractRef, ReadEffect, RefEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure, register_static) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, @@ -147,6 +148,10 @@ def get_aval(self) -> AbstractValue: else: return self[0] +@dataclass(frozen=True) +class EffectHandle: + parents : list[Tracer] + recipe : JaxprEqnRecipe class JaxprTrace(Trace['JaxprTracer']): @@ -156,6 +161,8 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t self.tag = tag self.parent_trace = parent_trace self.requires_low = False + self.effect_handles : list[EffectHandle] = [] + self.counter = it.count() def to_jaxpr_tracer(self, x): if isinstance(x, JaxprTracer) and x._trace.tag is self.tag: @@ -239,14 +246,19 @@ def default_process_primitive(self, primitive, tracers, params): if primitive.multiple_results: out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) for aval in out_aval] - eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, + eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effects, source) + if any(isinstance(e, RefEffect) for e in effects): + self.effect_handles.append(EffectHandle(tracers, eqn)) for t in out_tracers: t.recipe = eqn return out_tracers else: out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None) - out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, - params, effects, source) + eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive, + params, effects, source) + if any(isinstance(e, RefEffect) for e in effects): + self.effect_handles.append(EffectHandle(tracers, eqn)) + out_tracer.recipe = eqn return out_tracer def process_call(self, primitive, f: lu.WrappedFun, tracers, params): @@ -321,7 +333,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): for a in out_type] name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers), + eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers), out_tracers, primitive, staged_params, jaxpr.effects, source) for t in out_tracers: t.recipe = eqn @@ -390,7 +402,7 @@ def const_out_axes_thunk(): for a in out_avals] effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']}) src_info = source_info_util.current() - eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), + eqn = new_eqn_recipe(self, (*const_tracers, *env_tracers, *unknown_arg_tracers), out_tracers, primitive, staged_params, effs, src_info) for t in out_tracers: t.recipe = eqn @@ -425,7 +437,7 @@ def process_custom_transpose(self, prim, call, tracers, **params): for aval in params['out_types']] in_tracers = map(self.instantiate_const, tracers) new_params = dict(params, call=call) - eqn = 
new_eqn_recipe(in_tracers, out_tracers, prim, new_params, + eqn = new_eqn_recipe(self, in_tracers, out_tracers, prim, new_params, core.no_effects, source_info_util.current()) for t in out_tracers: t.recipe = eqn return out_tracers @@ -470,7 +482,7 @@ def fwd_jaxpr_thunk(*zeros): out_trees=out_trees, symbolic_zeros=symbolic_zeros ) - eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers), + eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *tracers), out_tracers, prim, params, jaxpr.effects, source) for t in out_tracers: t.recipe = eqn return out_tracers @@ -657,7 +669,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace, out_tracers = [trace.instantiate_const(t) if inst else t for inst, t in zip(instantiate, out_tracers)] out_tracers_ = [t for t in out_tracers if not t.is_known()] - jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info) + jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, trace.effect_handles, debug_info) return out_tracers, jaxpr, out_consts, env # The below variant implements an optimization where residuals which are also @@ -739,7 +751,8 @@ class JaxprEqnRecipe(NamedTuple): source_info: source_info_util.SourceInfo ctx: JaxprEqnContext -def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], +def new_eqn_recipe(trace: JaxprTrace, + in_tracers: Sequence[JaxprTracer], out_tracers: Sequence[JaxprTracer], primitive: Primitive, params: dict[str, Any], @@ -762,7 +775,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], config.threefry_partitionable.value, xla_metadata_lib.current_xla_metadata(), ) - return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers), + return JaxprEqnRecipe(next(trace.counter), tuple(in_tracers), map(ref, out_tracers), out_avals, primitive, params, effects, source_info, ctx) @@ -780,6 +793,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom], def tracers_to_jaxpr( in_tracers: Sequence[JaxprTracer], out_tracers: Sequence[JaxprTracer], + effect_handles: Sequence[Any], debug_info: core.DebugInfo, ) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]: """Constructs Jaxpr given tracers for inputs and outputs. 
@@ -821,7 +835,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue: processed_eqn_ids = set() eqns: list[core.JaxprEqn] = [] - for t in toposort((*in_tracers, *out_tracers)): + + reachable = toposort + tracers = reachable((*in_tracers, *out_tracers, *effect_handles)) + def sort_key(t): + r = t.recipe + return r.eqn_id if isinstance(r, JaxprEqnRecipe) else -1 + tracers = sorted(tracers, key=sort_key) + + for t in tracers: r = t.recipe if isinstance(r, JaxprEqnRecipe): # TODO broadcast_in_dim can create a new tracer, not present in parents diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 741636c47e31..4e8368341d9f 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -617,7 +617,7 @@ def _cond_partial_eval(trace, *tracers, branches, **params): name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) eqn = pe.new_eqn_recipe( - [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, + trace, [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, core.join_effects(*(j.effects for j in branches_unknown)), source) for t in out_tracers: t.recipe = eqn return util.merge_lists(out_uks, out_consts, out_tracers) diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index fc7ebde4cbea..90b81ae367aa 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -498,7 +498,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, assert len(unknown_inputs) == len(res_ref_unknown_outputs) assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1 - eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, + eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs, for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps, reverse=reverse, which_linear=which_linear_unknown, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 7efe3294fdca..83c31928d7cb 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -920,7 +920,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) assert len(out_tracers) == len(jaxpr_unknown.out_avals) - eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res], + eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res], out_tracers, scan_p, dict(reverse=reverse, length=length, unroll=unroll, jaxpr=jaxpr_unknown, linear=linear_unknown, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a49c27d06eee..0e3695ba4506 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6550,7 +6550,7 @@ def _broadcast_in_dim_partial_eval( out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type) out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) eqn = pe.new_eqn_recipe( - [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, + trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, dict(shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None), core.no_effects, source_info_util.current()) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0c55f3fe30ab..d5286be8e0c9 100644 --- 
a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2324,18 +2324,8 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect)) - for e in jaxpr.effects): - known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \ - pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins, - False, False, None) - if num_res_ref: raise NotImplementedError - known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts) - unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts) - res_avals = unknown_jaxpr.in_avals[:num_res_val] - else: - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ - pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) + known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ + pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) # type: ignore[assignment] known_outs = tuple(not uk for uk in unknown_outs) num_residuals = len(res_avals) @@ -2431,7 +2421,7 @@ def keep_where(l, should_keep): pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in unknown_out_avals ] - eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers), + eqn = pe.new_eqn_recipe(trace, (*unknown_tracers_in, *residual_tracers), unknown_tracers_out, pjit_p, unknown_params, diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 66df2505100c..4f60a833429a 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1369,7 +1369,7 @@ def known_out_specs(): out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in out_avals] effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), + eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers), out_tracers, shard_map_p, unk_params, effs, source_info_util.current()) for t in out_tracers: t.recipe = eqn diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index bc6a20a0a76e..100447f12d18 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -828,7 +828,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, is_initialized=(True,) * len(jaxpr_unknown.invars)) _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs], **uk_params) - eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, + eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs, run_state_p, uk_params, eqn_effects, source) for t in res_ref_unknown_outputs: t.recipe = eqn diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 865d4f8520f1..0da335e2fac5 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -192,14 +192,18 @@ def f(): x = f() self.assertArraysEqual(x, jnp.zeros(8)) - def test_grad_mutable_array(self): - @jax.jit + @parameterized.parameters([False, True]) + def test_grad_mutable_array(self, jit): + def f(x): x_ = core.mutable_array(x) x_[()] = x_[()] + x_[()] y = core.freeze(x_) return y + if jit: + f = jax.jit(f) + ans = jax.grad(f)(1.) expected = 2.0 self.assertAllClose(ans, expected, check_dtypes=False) From 7cc9053f9d09aff6df2917a0804d861a04be0260 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 22 May 2025 16:38:57 -0700 Subject: [PATCH 1329/1769] Adds explicit-axis handling to shard_map batching rule. 
Without this handling, in explicit sharding mode vmap of a function with an internal shmap can introduce unnecessary replication. PiperOrigin-RevId: 762175189 --- jax/_src/shard_map.py | 37 ++++++++++++++++++++++++++++++------- tests/shard_map_test.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 66df2505100c..0bae2da15272 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1228,6 +1228,15 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): # Batching +def _modify_specs_axis_data(trace, name, mesh, in_specs, in_dims): + new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, name) + for sp, d in zip(in_specs, in_dims)] + new_size = trace.axis_data.size // prod(mesh.shape[n] for n in name) + new_axis_data = batching.AxisData( + trace.axis_data.name, new_size, trace.axis_data.spmd_name, + trace.axis_data.explicit_mesh_axis) + return new_in_specs, new_axis_data + def _shard_map_batch( trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, @@ -1237,15 +1246,20 @@ def _shard_map_batch( if any(isinstance(d, batching.RaggedAxis) for d in in_dims): raise NotImplementedError spmd_axis_name = trace.axis_data.spmd_name + explicit_mesh_axis = trace.axis_data.explicit_mesh_axis if spmd_axis_name is not None: used = {n for spec in in_specs for n in _spec_to_vma(spec)} if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") - new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_axis_name) - for sp, d in zip(in_specs, in_dims)] - new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) - new_axis_data = batching.AxisData(trace.axis_data.name, new_size, - trace.axis_data.spmd_name, None) + new_in_specs, new_axis_data = _modify_specs_axis_data( + trace, spmd_axis_name, mesh, in_specs, in_dims) + elif explicit_mesh_axis is not None: + used = {n for spec in in_specs for n in _spec_to_vma(spec)} + if set(explicit_mesh_axis) & used: + raise ValueError("vmapped away explicit mesh axis cannot appear in " + "shard_map in_specs") + new_in_specs, new_axis_data = _modify_specs_axis_data( + trace, explicit_mesh_axis, mesh, in_specs, in_dims) else: new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) for sp, d in zip(in_specs, in_dims)] @@ -1254,7 +1268,8 @@ def _shard_map_batch( @as_hashable_function(closure=out_specs_thunk) def new_out_specs_thunk(): - return _batch_out_specs(spmd_axis_name, out_dims(), out_specs_thunk()) + return _batch_out_specs(spmd_axis_name, explicit_mesh_axis, out_dims(), + out_specs_thunk()) new_params = dict(mesh=mesh, in_specs=new_in_specs, out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, @@ -1266,13 +1281,21 @@ def new_out_specs_thunk(): return map(make_tracer, out_vals, out_dims()) batching.BatchTrace.process_shard_map = _shard_map_batch -def _batch_out_specs(spmd_name, dims, out_specs): +def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs): if spmd_name is not None: used = {n for spec in out_specs for n in _spec_to_vma(spec)} if not config.disable_vmap_shmap_error.value and set(spmd_name) & used: raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_name) for sp, 
d in zip(out_specs, dims)] + elif explicit_mesh_axis is not None: + used = {n for spec in out_specs for n in _spec_to_vma(spec)} + if set(explicit_mesh_axis) & used: + raise ValueError("vmapped away explicit mesh axis cannot appear in " + "shard_map out_specs") + return [sp if d is batching.not_mapped else + pxla.batch_spec(sp, d, explicit_mesh_axis) + for sp, d in zip(out_specs, dims)] else: return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) for sp, d in zip(out_specs, dims)] diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 5fbace3c98e1..90360989f13f 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -767,6 +767,46 @@ def f(x): self.assertIn('out_specs', e.params) self.assertEqual(e.params['out_specs'], (P('y', 'x'),)) + def test_vmap_explicit_mesh_axis(self): + mesh = jtu.create_mesh( + (1, 2, 2), ('z', 'x', 'y'), axis_types=(AxisType.Explicit,) * 3) + + @shard_map(mesh=mesh, in_specs=P('y'), out_specs=P('y')) + def f(x): + return x + + x = jnp.arange(4 * 4).reshape(4, 4) + s = NamedSharding(mesh, P(('z', 'x'), 'y')) + x = jax.device_put(x, s) + + f = jax.jit(jax.vmap(f)) + out = f(x) + self.assertEqual(out.sharding, s) + + def test_vmap_explicit_mesh_axis_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def f(x): + return x + + x = jnp.arange(4 * 4).reshape(4, 4) + s = NamedSharding(mesh, P('x', 'y')) + x = jax.device_put(x, s) + + f = jax.jit(jax.vmap(f)) + with self.assertRaisesRegex( + ValueError, "vmapped away explicit mesh axis cannot appear"): + f(x) + + f = jax.jit(jax.vmap(f, spmd_axis_name='y')) + with self.assertRaisesRegex( + ValueError, + 'Only one of spmd_axis_name or arrays sharded on `Explicit` mesh axis' + ' type is allowed'): + f(x) + def test_vmap_of_grad_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) From 326361828d20d9cc295edf2546c97696f446f1df Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 22 May 2025 18:54:37 -0700 Subject: [PATCH 1330/1769] Introduce `get_axis_sizes` API for `HloSharding` in JAX. PiperOrigin-RevId: 762213931 --- jaxlib/BUILD | 2 ++ jaxlib/_jax/__init__.pyi | 1 + jaxlib/xla_client.py | 2 +- jaxlib/xla_compiler.cc | 11 ++++++++++- tests/array_test.py | 13 +++++++++++++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 363218f4a9f3..dd96b7d23a8e 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1212,6 +1212,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Support", "@nanobind", "@xla//xla:array", "@xla//xla:debug_options_flags", @@ -1242,6 +1243,7 @@ cc_library( "@xla//xla/service:hlo_module_config", "@xla//xla/service:hlo_proto_cc", "@xla//xla/service:name_uniquer", + "@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import", "@xla//xla/tsl/lib/strings:proto_serialization", "@xla//xla/tsl/platform:env", "@xla//xla/tsl/platform:errors", diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 1d7f3042e8a3..898a4d5f2d22 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -422,6 +422,7 @@ class HloSharding: def subgroup_types(self) -> Sequence[OpSharding_Type]: ... def replicate_on_last_tile_dim(self) -> bool: ... def to_proto(self) -> OpSharding: ... + def get_axis_sizes(self) -> list[int]: ... 
# === END xla_compiler.cc diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 8f8c829ee6c7..c97ebf9d7c0a 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 342 +_version = 343 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index 73007530c27b..1066c8137d32 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "mlir/Support/LLVM.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -71,6 +72,7 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_module_config.h" +#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" @@ -1447,6 +1449,13 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { .def("subgroup_types", &xla::HloSharding::subgroup_types) .def("__repr__", [](const xla::HloSharding& self) { return self.ToString(); }) - .def("to_proto", &xla::HloSharding::ToProto); + .def("to_proto", &xla::HloSharding::ToProto) + .def("get_axis_sizes", [](const xla::HloSharding& self) { + // If returning the SmallVector, we encounter the error "unable to + // convert function return value to a Python type!". + mlir::SmallVector mesh_shape = + xla::sdy::getAxisSizes(self.tile_assignment()); + return std::vector(mesh_shape.begin(), mesh_shape.end()); + }); } // NOLINT(readability/fn_size) } // namespace xla diff --git a/tests/array_test.py b/tests/array_test.py index 1691c3acc749..0a04e7bd4bc9 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1466,6 +1466,19 @@ def test_named_sharding_unreduced_error(self): ValueError, "unreduced cannot contain None.*"): NamedSharding(mesh, P('x', unreduced=('y', None))) + def test_hlo_sharding_get_axis_sizes(self): + if jax._src.lib.jaxlib_extension_version < 343: + self.skipTest('Requires jaxlib_extension_version >= 343') + + op = xc.OpSharding() + op.type = xc.OpSharding.Type.OTHER + op.tile_assignment_dimensions = [6, 35] + op.iota_reshape_dims = [7, 10, 3] + op.iota_transpose_perm = [2, 1, 0] + s = GSPMDSharding(jax.devices(), op) + self.assertIn('{devices=[6,35]<=[7,10,3]T(2,1,0)}', repr(s)) + self.assertEqual(s._to_xla_hlo_sharding(2).get_axis_sizes(), [7, 2, 5, 3]) + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): From 82d76099b448a1ea9da1ae475d426abed2d57f90 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 22 May 2025 19:22:00 -0700 Subject: [PATCH 1331/1769] Implement `_to_sdy_sharding` for GSPMDSharding. 
PiperOrigin-RevId: 762220186 --- jax/_src/sharding_impls.py | 25 +++++++++++++++++++--- jaxlib/_jax/__init__.pyi | 1 + jaxlib/xla_client.py | 2 +- jaxlib/xla_compiler.cc | 4 ++++ tests/array_test.py | 43 +++++++++++++++++++++++++++++++++++++- 5 files changed, 70 insertions(+), 5 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 982af82c5c4d..4703e6403079 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -638,8 +638,25 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return self._hlo_sharding def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: - raise NotImplementedError( - "GSPMDSharding can't be converted to SdyArray.") + if self._hlo_sharding.tuple_elements(): + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') + elif self._hlo_sharding.is_replicated(): + empty_mesh = mesh_lib.AbstractMesh((), ()) + return NamedSharding(empty_mesh, PartitionSpec())._to_sdy_sharding( + num_dimensions) + elif self._hlo_sharding.is_tiled(): + if not self._hlo_sharding.is_tile_assignment_iota(): + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') + axis_sizes = tuple(self._hlo_sharding.get_axis_sizes()) + axis_names = tuple(f'_axis_{i}' for i in range(len(axis_sizes))) + mesh = mesh_lib.AbstractMesh(axis_sizes, axis_names) + return _gspmd_to_named_sharding_via_mesh(self, mesh)._to_sdy_sharding( + num_dimensions) + else: + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') @functools.cached_property def is_fully_replicated(self) -> bool: @@ -1241,11 +1258,13 @@ def create_mesh_pspec_sharding( def _gspmd_to_named_sharding_via_mesh( - out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding: + out_s: GSPMDSharding, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh +) -> NamedSharding: spec = parse_flatten_op_sharding(out_s._hlo_sharding, mesh)[0] return create_mesh_pspec_sharding( mesh, spec, memory_kind=out_s.memory_kind) + def flatten_spec(spec): out = [] for s in spec: diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 898a4d5f2d22..ed0089a3dd88 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -417,6 +417,7 @@ class HloSharding: def tuple_elements(self) -> list[HloSharding]: ... def num_devices(self) -> int: ... def num_dimensions(self) -> int: ... + def is_tile_assignment_iota(self) -> bool: ... def tile_assignment_dimensions(self) -> Sequence[int]: ... def tile_assignment_devices(self) -> Sequence[int]: ... def subgroup_types(self) -> Sequence[OpSharding_Type]: ... diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index c97ebf9d7c0a..ac816e72bebe 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 343 +_version = 344 # An internal increasing version number for protecting jaxlib code against # ifrt changes. 
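Conceptually, the new `_to_sdy_sharding` path recovers a named sharding from a raw HLO
sharding by synthesizing an anonymous mesh from `get_axis_sizes()`. The round trip can be
exercised directly; the snippet below is adapted from the test added later in this patch,
relies on private helpers (`_device_assignment`, `_to_xla_hlo_sharding`, `_to_sdy_sharding`),
and needs a jaxlib new enough to expose `get_axis_sizes` -- treat it as a sketch rather than
stable API.

    import jax
    from jax.sharding import NamedSharding, PartitionSpec as P
    from jax._src.sharding_impls import GSPMDSharding

    mesh = jax.make_mesh((jax.device_count(),), ('x',))
    ns = NamedSharding(mesh, P('x', None))

    # Strip the axis names away, keeping only the device list and HLO sharding.
    gs = GSPMDSharding(ns._device_assignment, ns._to_xla_hlo_sharding(2))

    # The GSPMD path rebuilds a mesh with synthesized names ('_axis_0', ...)
    # from the iota tile assignment, so both conversions agree on how each
    # array dimension is sharded, up to axis naming.
    print(gs._to_sdy_sharding(2))
    print(ns._to_sdy_sharding(2))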
diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index 1066c8137d32..f9ec134793ed 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -1425,6 +1425,10 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { return self.tile_assignment().num_dimensions(); }, nb::lock_self()) + .def("is_tile_assignment_iota", + [](const xla::HloSharding& self) { + return self.tile_assignment().iota().has_value(); + }) .def( "tile_assignment_dimensions", [](const xla::HloSharding& self) { diff --git a/tests/array_test.py b/tests/array_test.py index 0a04e7bd4bc9..10a17b557dea 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -28,6 +28,7 @@ from jax._src import op_shardings from jax._src import test_util as jtu from jax._src import xla_bridge as xb +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip @@ -1467,7 +1468,7 @@ def test_named_sharding_unreduced_error(self): NamedSharding(mesh, P('x', unreduced=('y', None))) def test_hlo_sharding_get_axis_sizes(self): - if jax._src.lib.jaxlib_extension_version < 343: + if jaxlib_extension_version < 343: self.skipTest('Requires jaxlib_extension_version >= 343') op = xc.OpSharding() @@ -1479,6 +1480,46 @@ def test_hlo_sharding_get_axis_sizes(self): self.assertIn('{devices=[6,35]<=[7,10,3]T(2,1,0)}', repr(s)) self.assertEqual(s._to_xla_hlo_sharding(2).get_axis_sizes(), [7, 2, 5, 3]) + @parameterized.named_parameters( + ('2d_mesh_x_y', (4, 2), P('x', 'y')), + ('2d_mesh_x', (4, 2), P('x')), + ('2d_mesh_y', (4, 2), P('y')), + ('2d_mesh_none_y', (4, 2), P(None, 'y')), + ('2d_mesh_none_x', (4, 2), P(None, 'x')), + ('2d_mesh_xy', (4, 2), P(('x', 'y'))), + ('2d_mesh_none_xy', (4, 2), P(None, ('x', 'y'))), + ('2d_mesh_fully_replicated', (4, 2), P()), + ('2d_mesh_x_none', (2, 1), P(('x',), None)), + ('3d_mesh_none_none_z', (2, 2, 2), P(None, None, 'z')), + ('3d_mesh_none_y_none', (2, 2, 2), P(None, 'y', None)), + ('3d_mesh_x_y_none', (2, 2, 2), P('x', 'y', None)), + ('3d_mesh_none_yz', (2, 2, 2), P(None, ('y', 'z'))), + ('3d_mesh_x_none_yz', (2, 2, 2), P('x', None, ('y', 'z'))), + ('3d_mesh_none_x_yz', (2, 2, 2), P(None, 'x', ('y', 'z'))), + ('3d_mesh_xy_z', (2, 2, 2), P(('x', 'y'), 'z')), + ('3d_mesh_xy_none_z', (2, 2, 2), P(('x', 'y'), None, 'z')), + ('3d_mesh_x_y_z', (2, 2, 2), P('x', 'y', 'z')), + ('3d_mesh_xz_y', (2, 2, 2), P(('x', 'z'), 'y')), + ('3d_mesh_xz_none_y', (2, 2, 2), P(('x', 'z'), None, 'y')), + ('3d_mesh_y_none_xz', (2, 2, 2), P('y', None, ('x', 'z'))), + ('3d_mesh_none_y_xz', (2, 2, 2), P(None, 'y', ('x', 'z'))), + ('3d_mesh2_none_none_z', (1, 2, 4), P(None, None, 'z')), + ('3d_mesh2_x_none_none', (1, 2, 4), P('x', None, None)), + ('3d_mesh_x_none_none', (2, 1, 1), P('x', None, None)), + ) + def test_gspmd_sharding_shardy_lowering(self, mesh_shape, pspec): + if jaxlib_extension_version < 344: + self.skipTest('Requires jaxlib_extension_version >= 344') + + ndim = len(mesh_shape) + mesh = jtu.create_mesh( + mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z') + ) + ns = jax.sharding.NamedSharding(mesh, pspec) + gs = GSPMDSharding(ns._device_assignment, ns._to_xla_hlo_sharding(ndim)) + out_sdy_sharding = gs._to_sdy_sharding(ndim) + self.assertTrue(out_sdy_sharding, ns._to_sdy_sharding(ndim)) + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): From 199d9f73cba51e940a69c3e90f666e869c13f413 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 22 May 2025 22:14:25 
-0700 Subject: [PATCH 1332/1769] Fix psum_invariant transpose rule where we were binding `pvary` with `ad_util.Zero` which lead to errors like this: `TypeError: Argument 'Zero(float32[1,1,512])' of type '' is not a valid JAX type` PiperOrigin-RevId: 762264425 --- jax/_src/lax/parallel.py | 22 +++++++++++++++++----- tests/shard_map_test.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index bf27261a2c8e..c5f8d3988144 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -2056,13 +2056,25 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a batching.skippable_batchers[psum_invariant_p] = partial(_names_in_param, 'axes') def _psum_invariant_transpose_rule(cts, *args, axes, axis_index_groups): - del args - return core.pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) + def f(ct, arg): + assert ad.is_undefined_primal(arg) + return ad.Zero(arg.aval) if type(ct) is ad.Zero else ct + cts = map(f, cts, args) + nonzero_out_cts, treedef = tree_util.tree_flatten(cts) + nonzero_in_cts = core.pvary_p.bind(*nonzero_out_cts, axes=axes, + axis_index_groups=axis_index_groups) + return tree_util.tree_unflatten(treedef, nonzero_in_cts) ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule) ########################### pvary ################################## -def _pvary_transpose_rule(cts, *_, axes, axis_index_groups): - return psum_invariant_p.bind( - *cts, axes=axes, axis_index_groups=axis_index_groups) +def _pvary_transpose_rule(cts, *args, axes, axis_index_groups): + def f(ct, arg): + assert ad.is_undefined_primal(arg) + return ad.Zero(arg.aval) if type(ct) is ad.Zero else ct + cts = map(f, cts, args) + nonzero_out_cts, treedef = tree_util.tree_flatten(cts) + nonzero_in_cts = psum_invariant_p.bind(*nonzero_out_cts, axes=axes, + axis_index_groups=axis_index_groups) + return tree_util.tree_unflatten(treedef, nonzero_in_cts) ad.deflinear2(core.pvary_p, _pvary_transpose_rule) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 90360989f13f..df69db7c9462 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -989,6 +989,22 @@ def f(x): for i in range(len(jax.devices())): self.assertIn(f'instance {i} has value', output()) + def test_psum_transpose_non_zero_cts(self): + mesh = jtu.create_mesh((8,), 'x') + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=(P('x'), P())) + def f1(x_block): + return x_block, jax.lax.psum(x_block, axis_name='x') + + x1 = jnp.arange(16.) + f1(x1) # doesn't crash + + def f2(x_block): + y, _ = f1(x_block) + return y.sum() + + jax.jit(jax.grad(f2))(x1) # doesn't crash + jax.grad(f2)(x1) # doesn't crash + @jtu.run_on_devices('cpu', 'gpu', 'tpu') @jtu.thread_unsafe_test() def test_debug_print_jit_partial_auto(self): From dc0cdf720bd3e3747d702a29cb1e63ef529eba80 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 23 May 2025 04:27:21 -0700 Subject: [PATCH 1333/1769] [Mosaic GPU] Properly handle single-element outputs in inline assembly. This issue was discovered when enabling the ragged dot example kernel to run using warpgroup semantics. The new test requires `-UNDEBUG`. 
PiperOrigin-RevId: 762365352 --- .../mosaic/gpu/fragmented_array.py | 33 ++++++++++++------- tests/mosaic/gpu_test.py | 15 +++++++++ 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index f69d3f33fe7c..77584b5f0dd4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2591,17 +2591,28 @@ def _repack(regs_it, reg_ty): all_reg_constraints = ",".join( [*("=" + c for c in reg_constraints), *reg_constraints] ) - struct_ty = ir.Type.parse( - f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" - ) - result_struct = llvm.inline_asm( - struct_ty, regs, ptx, all_reg_constraints, - asm_dialect=0, has_side_effects=True, - ) - regs = [ - llvm.extractvalue(dtype, result_struct, [i]) - for i, dtype in enumerate(reg_dtypes) - ] + + if len(reg_dtypes) == 1: + # The InlineAsm::verify() function doesn't allow a struct output when there + # is only one element (even though that seems to work for the case below). + result_elem = llvm.inline_asm( + reg_dtypes[0], regs, ptx, all_reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [result_elem] + else: + struct_ty = ir.Type.parse( + f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" + ) + result_struct = llvm.inline_asm( + struct_ty, regs, ptx, all_reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [ + llvm.extractvalue(dtype, result_struct, [i]) + for i, dtype in enumerate(reg_dtypes) + ] + i32 = ir.IntegerType.get_signless(32) results = [] regs_it = iter(regs) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e7fc9723347b..6ea27eb42878 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2451,6 +2451,21 @@ def kernel(ctx, inp, out, smem): f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None) np.testing.assert_array_equal(f(x), x * 3) + def test_optimization_barrier_with_single_value(self): + shape = (64, 64) + value = 5.0 + dtype = jnp.float32 + def kernel(ctx, out, smem): + del ctx, smem + mlir_type = utils.dtype_to_ir_type(dtype) + arr = mgpu.FragmentedArray.splat(c(value, mlir_type), shape) + arr = mgpu.optimization_barrier(arr) + arr.store_untiled(out) + + out_shape = jax.ShapeDtypeStruct(shape, dtype) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()) + np.testing.assert_array_equal(f(), jnp.full(shape, value, dtype=dtype)) + def test_convert_bool_to_u8(self): m, n = 128, 128 def kernel(ctx, dst, _): From 5a448b867cf9c91d99472b51f78c66749ce98e62 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 23 May 2025 04:46:46 -0700 Subject: [PATCH 1334/1769] [Mosaic GPU] Add support for async copies to peer devices PiperOrigin-RevId: 762370447 --- jax/experimental/mosaic/gpu/launch_context.py | 55 +++++++++- jaxlib/mosaic/gpu/BUILD | 2 + tests/mosaic/BUILD | 21 ++++ tests/mosaic/gpu_test_distributed.py | 100 ++++++++++++++++++ 4 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 tests/mosaic/gpu_test_distributed.py diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index aaae007a67f0..e4f0c4efa22c 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -400,6 +400,7 @@ def _get_tma_desc( self, gmem_ref, gmem_transform: tuple[MemRefTransform, ...], + gmem_peer_id: int | ir.Value | None, transformed_slice_shape: tuple[int, ...], swizzle: int | 
None, reduction_op: Literal[ @@ -408,6 +409,7 @@ def _get_tma_desc( ): tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) ptr_ty = ir.Type.parse("!llvm.ptr") def init_tma_desc(host_ptr): @@ -432,6 +434,25 @@ def init_tma_desc(host_ptr): base_ptr = llvm.getelementptr( ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, llvm.GEPNoWrapFlags.none, ) + if gmem_peer_id is not None: + if not isinstance(gmem_peer_id, ir.Value): + peer_id = c(gmem_peer_id, i32) + else: + try: + # We try to reproduce the gmem_peer_id computation on the host. + peer_id = _recompute_peer_id(gmem_peer_id) + except ReplicationError as e: + raise ValueError( + "Failed to recompute the async_copy peer id on the host" + ) from e + self._ensure_nvshmem_decls() + base_ptr = llvm.call( + base_ptr.type, + [base_ptr, peer_id], + [], + [], + callee="nvshmem_ptr", + ) rank = ref_ty.rank assert rank * 2 == len(sizes_and_strides) swizzle_arg = ( @@ -507,6 +528,7 @@ def async_copy( dst_ref, gmem_slice: Any = (), gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), + gmem_peer_id: int | ir.Value | None = None, barrier: utils.BarrierRef | None = None, swizzle: int | None = None, arrive: bool | None = None, @@ -750,7 +772,8 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): multicast_mask = None tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op, + gmem_ref, gmem_transform, gmem_peer_id, + tuple(slice_shape), swizzle, reduction_op, ) # We constuct TMA descriptors in column-major order. @@ -893,3 +916,33 @@ def device_id(self) -> ir.Value: self._ensure_nvshmem_decls() i32 = ir.IntegerType.get_signless(32) return llvm.call(i32, [], [], [], callee="nvshmem_my_pe") + + +class ReplicationError(Exception): + pass + +def _recompute_peer_id(peer_id: ir.Value, fuel=8) -> ir.Value: + if fuel == 0: + raise ReplicationError( + "gmem_peer_id computation is too complicated to recompute on the host" + ) + if isinstance(peer_id, ir.BlockArgument): + raise ReplicationError("Can't recompute a value that's a block argument") + op = peer_id.owner.opview + # We accept all arith ops + if op.OPERATION_NAME.startswith("arith."): + new_operands = [_recompute_peer_id(x, fuel - 1) for x in op.operands] + result_types = [r.type for r in op.results] + new_attributes = {na.name: na.attr for na in op.attributes} + new_op = ir.Operation.create( + op.OPERATION_NAME, result_types, new_operands, new_attributes + ) + return new_op.results if len(new_op.results) > 1 else new_op.result + # nvshmem_my_pe queries the device id of the current process and works on both + # the host and the device. + if isinstance(op, llvm.CallOp) and op.callee.value == "nvshmem_my_pe": + i32 = ir.IntegerType.get_signless(32) + return llvm.call(i32, [], [], [], callee="nvshmem_my_pe") + raise ReplicationError( + f"Unrecognized op can't be recomputed on the host: {op}" + ) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 66f13bdac7f5..fc1abb9397d5 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -122,6 +122,8 @@ cc_library( # Linker may prune these symbols if they are not explicitly exported. 
linkopts = [ "-Wl,--export-dynamic-symbol='mosaic_gpu_*'", + "-Wl,--export-dynamic-symbol='nvshmem_my_pe'", + "-Wl,--export-dynamic-symbol='nvshmem_ptr'", "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 75e1df335f6f..9b6a7f79d099 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -63,6 +63,27 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "gpu_test_distributed", + srcs = ["gpu_test_distributed.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = ["gpu_h100x2"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0 --xla_gpu_experimental_enable_nvshmem=true"}, + tags = ["multiaccelerator"], + deps = [ + "//jax:experimental", + "//jax:mosaic_gpu", + "//jax:test_multiprocess", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + jax_py_test( name = "gpu_dialect_test", srcs = ["gpu_dialect_test.py"], diff --git a/tests/mosaic/gpu_test_distributed.py b/tests/mosaic/gpu_test_distributed.py new file mode 100644 index 000000000000..fee2ce5b03a6 --- /dev/null +++ b/tests/mosaic/gpu_test_distributed.py @@ -0,0 +1,100 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src import test_multiprocess as jt_multiprocess +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +from jax.experimental import shard +from jax.experimental import multihost_utils +import jax.numpy as jnp +import numpy as np +try: + import jax._src.lib.mosaic_gpu # noqa: F401 + HAS_MOSAIC_GPU = True +except ImportError: + HAS_MOSAIC_GPU = False +else: + import jax.experimental.mosaic.gpu as mgpu + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension +P = jax.sharding.PartitionSpec + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if not HAS_MOSAIC_GPU: + self.skipTest("jaxlib built without Mosaic GPU") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + if not mgpu.supports_cross_device_collectives(): + self.skipTest("NVSHMEM library unavailable.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + if jax.device_count() != jax.process_count(): + self.skipTest("Need 1 device per process") + super().setUp() + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + if mgpu_dialect is not None: + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + + +class ProfilerTest(TestCase): + + def test_remote_async_copy(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, src, dst, scratch): + tmp, barrier = scratch + other_device = arith.subi(arith.constant(i32, 1), ctx.device_id()) + ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier) + barrier.wait() + ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=other_device) + ctx.await_async_copy(0) + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.sharding.use_mesh(mesh): + x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + x = shard.reshard(x_np, P("x")) + y = jax.jit( + jax.shard_map( + lambda x: mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (x, mgpu.TMABarrier()) + )(x), + out_specs=P("x"), + check_vma=False, + ) + )(x) + y_np = multihost_utils.process_allgather(y, tiled=True) + np.testing.assert_array_equal( + y_np, np.concatenate(np.split(x_np, 2)[::-1], axis=0) + ) + + +if __name__ == "__main__": + jt_multiprocess.main() From a0af34f0b166f36d2c69efb1bba066dcf1d8f219 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 23 May 2025 05:09:36 -0700 Subject: [PATCH 1335/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c361fc2992e8d674636e7870992e95658b1be792. 
PiperOrigin-RevId: 762376295 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index dda1d3c36cf8..f4af3bc02961 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f040466fc5fa5052b1f4e89fb2b166bb2c9d656a" -XLA_SHA256 = "52ecb43ba06e2d5e8cf84fad5c43104c8f5be92de4bb5e5c8a936e6351093c6e" +XLA_COMMIT = "c361fc2992e8d674636e7870992e95658b1be792" +XLA_SHA256 = "990dd1c54128015235bc28286005704209f52abaf3e39e8f96299e1f5af4f7f1" def repo(): tf_http_archive( From 6cc627a4dc194d91e2ea3920ddd51e310b88920e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 23 May 2025 06:19:05 -0700 Subject: [PATCH 1336/1769] [Mosaic GPU] Add checks for argument shapes and types Apparently we never checked it and it's been quite easy to get this wrong. PiperOrigin-RevId: 762394139 --- jax/experimental/mosaic/gpu/core.py | 16 +++++++++++++++- tests/mosaic/gpu_test.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 5e1ed6b88412..9464bb587c71 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -677,7 +677,7 @@ def as_gpu_kernel( if launch_ctx.is_device_collective and not supports_cross_device_collectives(): raise RuntimeError("Kernel is a cross-device collective but no support is available.") - expected_arg_treedef = jax.tree.structure(in_shape) + expected_arg_tys, expected_arg_treedef = jax.tree.flatten(in_shape) def _check_args(*args): arg_treedef = jax.tree.structure(args) if arg_treedef != expected_arg_treedef: @@ -685,6 +685,20 @@ def _check_args(*args): f"Invalid argument structure: expected {expected_arg_treedef}, got" f" {arg_treedef}, ({args=})" ) + for arg, expected_ty in zip(args, expected_arg_tys): + if arg.shape != expected_ty.shape: + raise ValueError( + f"Argument shape mismatch: expected {expected_ty.shape}, got" + f" {arg.shape}" + ) + if arg.dtype != expected_ty.dtype: + hint = "" + if not arg.shape: + hint = f". Hint: cast the scalar to {expected_ty.dtype} explicitly." + raise ValueError( + f"Argument dtype mismatch: expected {expected_ty.dtype}, got" + f" {arg.dtype}{hint}" + ) def bind(*args) -> Any: return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 6ea27eb42878..42a3f0fc83c1 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -460,7 +460,7 @@ def test_scalar_argument(self, dtype): " values read from the 32-bit input buffer to sometimes" " (nondeterministically) contain garbage.") - scalar = 42 + scalar = dtype(42) expected = np.full((128, 128), scalar, dtype=dtype) def kernel(ctx, inp, out, _): From c1c0c0fa27f52de19ddd5112580c8da2f711e1e8 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 23 May 2025 06:20:38 -0700 Subject: [PATCH 1337/1769] [Mosaic GPU] Pass the right number of immediate values in the wgmma inline asm. Before this change, in the int32 case, we pass two extra immediate args compared to the number of parameters in the ASM string. Running tests with `-UNDEBUG` detects the error. I think this likely broke with the special-casing of `int32` in cl/761489756. 
I've added an assert that should prevent mismatches in the future. PiperOrigin-RevId: 762394530 --- jax/experimental/mosaic/gpu/wgmma.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index abbd517fb37d..2fe826e173e5 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -226,11 +226,18 @@ def lc(x): return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result use_out = scale_a = scale_b = lc(1) - imms = [use_out, scale_a, scale_b] + if out_ty == i32: + imms = [use_out] + else: + imms = [use_out, scale_a, scale_b] + if supports_transpose and a_transpose is not None: imms += [lc(int(a_transpose)), lc(int(b_transpose))] elif supports_transpose: imms += [lc(int(b_transpose))] + + assert len(imms) == num_imm_regs + 1 # +1 for the use_out_reg in setp.ne.b32 + if acc.ndim != 10 or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != 2: raise ValueError(acc.shape) acc_struct_type = ir.Type.parse( From 12966b5f578e78f5f80fda3bc877441767738584 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Fri, 23 May 2025 07:25:14 -0700 Subject: [PATCH 1338/1769] Fix sharding rule in jax test `ValueError: Sharding rule has 1 operands, but the operation has 2 operands` PiperOrigin-RevId: 762412744 --- tests/cache_key_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index fd3e7706260a..35ac03011a97 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -181,7 +181,7 @@ def _cp_add(x, y): _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, partition=_partition, - sharding_rule='i i -> i') + sharding_rule='..., ... -> ...') devices = np.asarray(jax.devices()) with Mesh(devices, ('x',)) as m: From 0a970e43eceeb5aa076d1b981c49dfcf796b041c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 23 May 2025 07:31:26 -0700 Subject: [PATCH 1339/1769] Move _src/custom_transpose.py into its own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. 
PiperOrigin-RevId: 762414305 --- jax/BUILD | 21 ++++++++++++++++++++- tests/BUILD | 1 + 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index e820bd06fe89..e7f1fad3121d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -307,7 +307,6 @@ py_library_providing_imports_info( "_src/custom_derivatives.py", "_src/custom_partitioning.py", "_src/custom_partitioning_sharding_rule.py", - "_src/custom_transpose.py", "_src/debugging.py", "_src/dispatch.py", "_src/dlpack.py", @@ -391,6 +390,7 @@ py_library_providing_imports_info( ":config", ":core", ":custom_api_util", + ":custom_transpose", ":deprecations", ":dtypes", ":effects", @@ -595,6 +595,25 @@ pytype_strict_library( srcs = ["_src/custom_api_util.py"], ) +pytype_strict_library( + name = "custom_transpose", + srcs = ["_src/custom_transpose.py"], + deps = [ + ":ad", + ":ad_util", + ":api_util", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ":xla", + ], +) + pytype_strict_library( name = "deprecations", srcs = ["_src/deprecations.py"], diff --git a/tests/BUILD b/tests/BUILD index 2418c8224869..6c9f3f74b56a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -61,6 +61,7 @@ jax_multiplatform_test( srcs = ["debug_info_test.py"], enable_configs = ["tpu_v3_x4"], deps = [ + "//jax:custom_transpose", "//jax:experimental", "//jax:pallas", "//jax:pallas_gpu", From 9cbf4934936043d472bb7b81cf02c49af77a97b1 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 23 May 2025 07:44:27 -0700 Subject: [PATCH 1340/1769] pytest: use importlib mode by default. This is an attempt to re-land https://github.com/jax-ml/jax/pull/28650, fixing build failures. **Motivation** https://github.com/jax-ml/jax/pull/28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do. The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX. **Solutions** The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest. This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me. **Alternatives** One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. 
In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in https://github.com/jax-ml/jax/pull/28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution! A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter. PiperOrigin-RevId: 762419160 --- .github/workflows/ci-build.yaml | 2 -- BUILD.bazel | 7 +++++++ .../export_back_compat_test_data/__init__.py | 14 ++++++++++++++ .../pallas/__init__.py | 14 ++++++++++++++ jax/experimental/mosaic/gpu/examples/__init__.py | 14 ++++++++++++++ pyproject.toml | 2 +- setup.py | 2 +- 7 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/__init__.py create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py create mode 100644 jax/experimental/mosaic/gpu/examples/__init__.py diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 09f169548796..0769c698d5fe 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -88,7 +88,6 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - uv pip install --system -e . echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG" @@ -185,7 +184,6 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - uv pip install --system -e . echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" diff --git a/BUILD.bazel b/BUILD.bazel index 887f28d4583e..44885124797f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -42,12 +42,19 @@ wheel_sources( "//jax:pallas_triton", "//jax:source_mapper", "//jax:sparse_test_util", + "//jax:test_multiprocess", "//jax:test_util", + "//jax:internal_export_back_compat_test_util", + "//jax:internal_export_back_compat_test_data", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", "//jax/_src/lib", "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", "//jax/experimental/array_serialization:serialization", "//jax/experimental/jax2tf", + "//jax/experimental/mosaic/gpu/examples:flash_attention", + "//jax/experimental/mosaic/gpu/examples:matmul", "//jax/extend", "//jax/extend:ifrt_programs", "//jax/extend/mlir", diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py b/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/experimental/mosaic/gpu/examples/__init__.py b/jax/experimental/mosaic/gpu/examples/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== diff --git a/pyproject.toml b/pyproject.toml index 83b85b0271f5..d48351197b54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ doctest_optionflags = [ "NUMBER", "NORMALIZE_WHITESPACE" ] -addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'" +addopts = "--doctest-glob='*.rst' --ignore='examples/ffi' --import-mode=importlib" [tool.ruff] preview = true diff --git a/setup.py b/setup.py index ef78b8f6e7ff..2b50b041008d 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ def load_version_module(pkg_path): long_description_content_type='text/markdown', author='JAX team', author_email='jax-dev@google.com', - packages=find_packages(exclude=["*examples*", "*internal_test_util*"]), + packages=find_packages(exclude=["examples"]), package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, python_requires='>=3.10', install_requires=[ From e989e23d6bb5895215864371e25724910a05fb0b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 23 May 2025 08:27:31 -0700 Subject: [PATCH 1341/1769] [pallas] Slightly revamped how checkify is exposed in Pallas * We now re-export a restricted version of `debug_check` under `pl`. Unlike the original, the `pl` version only allows a static message, i.e. string interpolation is not supported. * Only debug checks are supported, which means that by default no checking is done -- `debug_check` is lowered to a noop. * The context manager enabling debug checks is called `enable_debug_checks`. I would very much like to drop the `enable_` prefix, but without it the context manager reads too similar to `debug_check`. PiperOrigin-RevId: 762433258 --- docs/pallas/CHANGELOG.md | 5 ++++ jax/_src/pallas/core.py | 16 ------------ jax/_src/pallas/helpers.py | 30 +++++++++++++++++++++- jax/_src/pallas/mosaic/lowering.py | 24 +++++++++-------- jax/_src/pallas/mosaic_gpu/lowering.py | 23 ++++++++--------- jax/experimental/pallas/__init__.py | 5 ++-- jax/experimental/pallas/g3doc/debugging.md | 17 +++++------- jax/experimental/pallas/tpu.py | 2 -- tests/pallas/mosaic_gpu_test.py | 4 +-- tests/pallas/pallas_test.py | 7 +++-- 10 files changed, 72 insertions(+), 61 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 476cc54673a1..2d8a83c897f1 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -26,6 +26,11 @@ Remember to align the itemized text with the first line of an item within a list `block_shape` for each entry that needs unblocked indexing. * {func}`jax.experimental.pallas.pallas_call` now requires `compiler_params` to be a backend-specific dataclass instead of a param to value mapping. + * {func}`jax.experimental.pallas.debug_check` is now supported both on + TPU and Mosaic GPU. Previously, this functionality was only supported + on TPU and required using the APIs from {mod}`jax.experimental.checkify`. + Note that debug checks are not executed unless + {data}`jax.experimental.pallas.enable_debug_checks` is set. ## Released with jax 0.5.0 diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 13c634eb395f..7950f90bc377 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -44,22 +44,6 @@ from jax._src.state.types import TransformedRef import jax.numpy as jnp -# TODO(slebedev): Rename to --jax_pallas_debug_assertions. -_ENABLE_RUNTIME_ASSERT = config.bool_state( - "jax_pallas_enable_runtime_assert", - default=False, - help=( - "If set, enables runtime assertions in the kernel via checkify.check." 
- " Otherwise, runtime asserts will be ignored unless functionalized"
-        " using checkify.checkify."
-    ),
-)
-
-
-def runtime_assert_enabled() -> bool:
-  """Returns whether runtime asserts are enabled."""
-  return _ENABLE_RUNTIME_ASSERT.value
-
 
 class DynamicGridDim:
   def __repr__(self):
diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py
index 6b274c0b6cce..5c77d0a04f09 100644
--- a/jax/_src/pallas/helpers.py
+++ b/jax/_src/pallas/helpers.py
@@ -14,8 +14,10 @@
 """Pallas helper functions."""
 
 import jax
-from jax._src.pallas import pallas_call
+from jax._src import checkify
+from jax._src import config
 from jax._src.pallas import core as pl_core
+from jax._src.pallas import pallas_call
 
 
 @jax.named_call
@@ -65,3 +67,29 @@ def _wrapped(f):
     else:
       jax.lax.cond(condition, f, lambda: None)
   return _wrapped
+
+
+_ENABLE_DEBUG_CHECKS = config.bool_state(
+    "jax_pallas_enable_debug_checks",
+    default=False,
+    help=(
+        "If set, ``pl.debug_check`` calls are checked at runtime. Otherwise,"
+        " they are a noop."
+    ),
+)
+
+
+enable_debug_checks = _ENABLE_DEBUG_CHECKS
+
+
+def debug_checks_enabled() -> bool:
+  """Returns whether debug checks are enabled."""
+  return _ENABLE_DEBUG_CHECKS.value
+
+
+def debug_check(condition, message):
+  """Checks ``condition`` if
+  :func:`~jax.experimental.pallas.enable_debug_checks` is set; otherwise
+  does nothing.
+  """
+  return checkify.debug_check(condition, message)
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index d1222c16ac96..4e8827401e0c 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -61,6 +61,7 @@
 from jax._src.pallas import pallas_call
 from jax._src.pallas import primitives
 from jax._src.pallas import utils as pallas_utils
+from jax._src.pallas import helpers as pallas_helpers
 from jax._src.pallas.mosaic import core as tpu_core
 from jax._src.pallas.mosaic import error_handling
 from jax._src.pallas.mosaic import primitives as tpu_primitives
@@ -3766,17 +3767,18 @@ def _join_key_lowering_rule(ctx: LoweringRuleContext, *scalars, impl):
 
 
 @register_lowering_rule(checkify.check_p)
-def _checkify_lowering_rule(
-    ctx: LoweringRuleContext, *err_args, err_tree, debug):
-  if not pallas_core.runtime_assert_enabled():
-    if debug:
-      return []
-    else:
-      raise LoweringException(
-          "Non-debug check must be functionalized. Enable runtime asserts via"
-          " ``pl.enable_runtime_assert`` or --jax_pallas_enable_runtime_assert"
-          " or, alternatively, functionalize with ``checkify.check``."
-      )
+def _check_lowering_rule(
+    ctx: LoweringRuleContext, *err_args, err_tree, debug
+):
+  del ctx  # Unused.
+
+  if not debug:
+    raise NotImplementedError(
+        "Non-debug checks are not supported by the Mosaic backend."
+        " Functionalize them via `jax.experimental.checkify`."
+    )
+  if not pallas_helpers.debug_checks_enabled():
+    return []
 
   if cf is None:
     # TODO(slebedev): Remove once the minimal jaxlib version is 0.6.1.
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 00716bb1c675..9e396167f610 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -53,6 +53,7 @@ from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import discharge from jax._src.state import indexing @@ -3097,18 +3098,16 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): @register_lowering_rule(checkify.check_p, mgpu.LoweringSemantics.Lane) -def _checkify_lowering_rule( - ctx: LoweringRuleContext, *err_args, err_tree, debug -): - if not pallas_core.runtime_assert_enabled(): - if debug: - return [] - else: - raise LoweringError( - "Non-debug check must be functionalized. Enable runtime asserts via" - " ``pl.enable_runtime_assert`` or --jax_pallas_enable_runtime_assert" - " or, alternatively, functionalize with ``checkify.check``." - ) +def _check_lowering_rule(ctx: LoweringRuleContext, *err_args, err_tree, debug): + del ctx # Unused. + + if not debug: + raise NotImplementedError( + "Non-debug checks are not supported by the Mosaic GPU backend." + " Functionalize them via `jax.experimental.checkify`." + ) + if not pallas_helpers.debug_checks_enabled(): + return [] if cf_dialect is None: # TODO(slebedev): Remove once the minimal jaxlib version is 0.6.1. diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 406d6e965322..caf77a3c4fce 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -18,7 +18,6 @@ https://docs.jax.dev/en/latest/pallas.html. """ -from jax._src.pallas.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 from jax._src.pallas.core import BlockDim as BlockDim from jax._src.pallas.core import Blocked as Blocked from jax._src.pallas.core import BlockSpec as BlockSpec @@ -33,7 +32,6 @@ from jax._src.pallas.core import MemoryRef as MemoryRef from jax._src.pallas.core import MemorySpace as MemorySpace from jax._src.pallas.core import no_block_spec as no_block_spec -from jax._src.pallas.core import runtime_assert_enabled as runtime_assert_enabled from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.core import Squeezed as Squeezed from jax._src.pallas.core import squeezed as squeezed @@ -41,6 +39,9 @@ from jax._src.pallas.helpers import empty as empty from jax._src.pallas.helpers import empty_like as empty_like from jax._src.pallas.helpers import when as when +from jax._src.pallas.helpers import debug_check as debug_check +from jax._src.pallas.helpers import debug_checks_enabled as debug_checks_enabled +from jax._src.pallas.helpers import enable_debug_checks as enable_debug_checks from jax._src.pallas.pallas_call import pallas_call as pallas_call from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p from jax._src.pallas.primitives import atomic_add as atomic_add diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md index 6dfa95eb16fa..791705d00d30 100644 --- a/jax/experimental/pallas/g3doc/debugging.md +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -3,7 +3,7 @@ [TOC] @@ -45,16 +45,14 @@ as a Python error after the kernel has successfully executed. 
#### Hard assertion -Hard assertions can be inserted with `checkify.check` -and running your program with the `--jax_pallas_enable_runtime_assert` flag. +Hard assertions can be inserted with `pl.debug_check` +and running your program with the `--jax_pallas_enable_debug_checks` flag. Your code will look like the following: ```python -from jax.experimental import checkify - def kernel(...): - checkify.check(x > y, "Check x > y failed") # Will halt if x <= y + pl.debug_check(x > y, "Check x > y failed") # Will halt if x <= y ``` This will print a relatively lengthy dump which resembles the following: @@ -76,11 +74,10 @@ Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op from jax.experimental import checkify def kernel(...): - checkify.check(x > y, "Check x > y failed") # Will throw an error if x <= y + pl.debug_check(x > y, "Check x > y failed") # Will throw an error if x <= y kernel = pl.pallas_call(...) -checkified_kernel = checkify.checkify(kernel, - errors=checkify.all_checks) +checkified_kernel = checkify.checkify(kernel, errors=checkify.all_checks) error, result = checkified_kernel(x) error.throw() ``` @@ -203,5 +200,3 @@ In most cases the error message should hint at what is wrong. For specific errors: * `Mixed dtype operands in cmp` when using `jnp.mod`: Use lax.rem instead of jnp.mod - - diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index c8e2ba131a9b..401b2fe66c45 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -51,8 +51,6 @@ # Those primitives got moved to Pallas core. Keeping the updated imports # here for backward compatibility. from jax._src.pallas.core import semaphore as semaphore -from jax._src.pallas.core import runtime_assert_enabled as runtime_assert_enabled -from jax._src.pallas.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import semaphore_read as semaphore_read from jax._src.pallas.primitives import semaphore_signal as semaphore_signal diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 4ef2fa8096ee..b3c2f11ee43f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1031,14 +1031,14 @@ def kernel(x_ref, o_ref, scratch_ref, barrier_ref): def test_check(self): self.skip_if_wg_semantics() - self.enter_context(pallas_core._ENABLE_RUNTIME_ASSERT(True)) + self.enter_context(pl.enable_debug_checks(True)) @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): - checkify.check(_sum_same_dtype(x_ref[...]) > 0, "x.sum() is negative") + pl.debug_check(_sum_same_dtype(x_ref[...]) > 0, "x.sum() is negative") o_ref[...] = x_ref[...] 
    x = jnp.arange(256, dtype=jnp.int32)
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 1114153b16c2..03399e12b609 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -2276,7 +2276,7 @@ def kernel(x_ref, y_ref):
       checkify.check(False, "second check failed")
     input_ = jnp.arange(4, dtype=jnp.int32)
     out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype)
-    with pltpu.enable_runtime_assert(True):
+    with pl.enable_debug_checks(True):
       pallas_call = pl.pallas_call(kernel, out_shape=out_shape)
       pallas_call(input_)  # This should log "second check failed"
@@ -2286,11 +2286,10 @@ def test_runtime_assert_is_noop_when_not_enabled(self):
       self.skipTest("Runtime check only implemented on TPU.")
     def kernel(x_ref, y_ref):
       y_ref[...] = x_ref[...]
-      checkify.check(False, "failed check",
-                     debug=True)  # This check always fails.
+      pl.debug_check(False, "failed check")  # This check always fails.
     input_ = jnp.arange(4, dtype=jnp.int32)
     out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype)
-    with pltpu.enable_runtime_assert(False):
+    with pl.enable_debug_checks(False):
       pallas_call = pl.pallas_call(kernel, out_shape=out_shape)
       result = pallas_call(input_)
       np.testing.assert_allclose(result, input_)

From 7d13c56570072565ce244bf6ff77c2a55f4d5e66 Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Fri, 23 May 2025 08:48:11 -0700
Subject: [PATCH 1342/1769] [jaxlib] Add CompileOnlyPyClient to xla_client.

We have users of CompileOnlyPyClient that use `backend.compile` as we
eventually intend it (i.e., returning `ExecutableRef`, and possibly
`PyExecutable` down the line, instead of `PyLoadedExecutable`).
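
As a rough illustration of the compile-only flow this serves, here is a
sketch using the public ahead-of-time API. The `jax.jit(...).lower(...).compile()`
chain below is the real public entry point; the claim that a compile-only
backend routes it through `CompileOnlyPyClient.compile` is an assumption based
on this patch, not something the example demonstrates directly:

```python
# Sketch: ahead-of-time lowering and compilation against abstract shapes.
# On a regular backend, compile() today returns a *loaded* executable --
# exactly the compile/load fusion this refactor wants to unbundle.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

# Lower against abstract shapes only -- no concrete device buffers needed.
lowered = jax.jit(f).lower(jax.ShapeDtypeStruct((8,), jnp.float32))
compiled = lowered.compile()  # today: compile + load in one step
print(compiled.as_text()[:80])  # inspect the compiled module
```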
From 9d6553815f12d6903d4ccfd810519e31e2a97810 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Fri, 23 May 2025 08:50:24 -0700 Subject: [PATCH 1343/1769] [CI] Pin the ML Connect action to a specific sha PiperOrigin-RevId: 762441171 --- .github/workflows/bazel_cpu_py_import_rbe.yml | 2 +- .github/workflows/bazel_cpu_rbe.yml | 2 +- .github/workflows/bazel_cuda_non_rbe.yml | 2 +- .github/workflows/bazel_cuda_rbe.yml | 2 +- .github/workflows/bazel_optional_h100_b200.yml | 4 ++-- .github/workflows/build_artifacts.yml | 2 +- .github/workflows/numpy_nightly.yml | 2 +- .github/workflows/oldest_supported_numpy.yml | 2 +- .github/workflows/pytest_cpu.yml | 2 +- .github/workflows/pytest_cuda.yml | 2 +- .github/workflows/pytest_tpu.yml | 2 +- 11 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml index 14d6b95b4347..c98bcee980e8 100644 --- a/.github/workflows/bazel_cpu_py_import_rbe.yml +++ b/.github/workflows/bazel_cpu_py_import_rbe.yml @@ -55,7 +55,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CPU tests with py_import (RBE) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index ef5084960b30..a8b40c260260 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -54,7 +54,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} # Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64. 
Instead, we diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 677d8d869a22..3e68034dfbf4 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -106,7 +106,7 @@ jobs: exit 1 # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CUDA tests (Non-RBE) diff --git a/.github/workflows/bazel_cuda_rbe.yml b/.github/workflows/bazel_cuda_rbe.yml index 5a2c94c4db47..83f651c0ef95 100644 --- a/.github/workflows/bazel_cuda_rbe.yml +++ b/.github/workflows/bazel_cuda_rbe.yml @@ -50,7 +50,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CUDA Tests with RBE diff --git a/.github/workflows/bazel_optional_h100_b200.yml b/.github/workflows/bazel_optional_h100_b200.yml index 7381ce6d80bf..ec907280938e 100644 --- a/.github/workflows/bazel_optional_h100_b200.yml +++ b/.github/workflows/bazel_optional_h100_b200.yml @@ -33,7 +33,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel single B200 CUDA Tests @@ -75,7 +75,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel multiple H100 CUDA Tests diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 1b534ee3b6fc..d5fc35a99cd5 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -127,7 +127,7 @@ jobs: run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build ${{ inputs.artifact }} diff --git a/.github/workflows/numpy_nightly.yml b/.github/workflows/numpy_nightly.yml index 51876a7eb71d..17357e9f1dd8 100644 --- a/.github/workflows/numpy_nightly.yml +++ b/.github/workflows/numpy_nightly.yml @@ -54,7 +54,7 @@ jobs: path: ml_dtypes # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Install numpy & scipy development versions diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml index 06e7cf6230df..a63cb0b1c614 100644 --- a/.github/workflows/oldest_supported_numpy.yml +++ b/.github/workflows/oldest_supported_numpy.yml @@ -51,7 +51,7 
@@ jobs: $JAXCI_PYTHON -m uv pip install -e .[minimum-jaxlib] # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CPU tests diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index fc4633110667..3af06fe8037e 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -140,7 +140,7 @@ jobs: fi # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CPU tests diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 2f22901e661a..78f32cda672d 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -138,7 +138,7 @@ jobs: run: $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CUDA tests diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index ae0250884831..22cd64977dc5 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -152,7 +152,7 @@ jobs: fi # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest TPU tests From c2c55aef522abcff80fbcf9b11dd8faaa28924be Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 23 May 2025 09:08:08 -0700 Subject: [PATCH 1344/1769] Move jax/_src/interpreters/batching.py into its own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. Unfortunately this is not a clean build refactor, because batching depends on jax.lax, which in turn depends on batching. However, the problematic functions are only called within contexts where jax.lax is available for import. We have a few options here: 1. Continue to bundle the batching.py source with the main build. 2. Build separately, but do the local import workaround in this CL (a pattern we use elsewhere). 3. Build this separately, but move some batching definitions into jax.lax for a more strict dependency graph. Or pass the `lax` namespace explicitly to the function at the call site. I opted for (2) here because I judged the benefits of a refactored build to be worth the cost of localized impure dependencies, and the kind of refactoring in (3) would affect some downstream users. 
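
A condensed sketch of the workaround in (2), mirroring the
`_handle_scalar_broadcasting` change in the diff below (the function body here
is illustrative, not the actual code):

```python
# A module-scope `import jax` would close a dependency cycle, since jax.lax
# itself (transitively) depends on batching. Deferring the import to call
# time is safe because callers only invoke this once the package has
# finished initializing.
def handle_scalar_broadcasting(nd, x, d):
  from jax import lax  # deferred import breaks the batching <-> lax cycle
  import numpy as np
  if d is None or nd == np.ndim(x):
    return x
  return lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
```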
PiperOrigin-RevId: 762447323 --- jax/BUILD | 20 +++++++++++++++++++- jax/_src/interpreters/batching.py | 28 +++++++++++++++++++--------- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index e7f1fad3121d..7c8847e2e94b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -315,7 +315,6 @@ py_library_providing_imports_info( "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", - "_src/interpreters/batching.py", "_src/interpreters/pxla.py", "_src/pjit.py", "_src/prng.py", @@ -383,6 +382,7 @@ py_library_providing_imports_info( ":ad_util", ":api_util", ":basearray", + ":batching", ":cloud_tpu_init", ":compilation_cache_internal", ":compiler", @@ -707,6 +707,24 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "batching", + srcs = ["_src/interpreters/batching.py"], + deps = [ + ":ad_util", + ":config", + ":core", + ":mesh", + ":partial_eval", + ":partition_spec", + ":sharding_impls", + ":source_info_util", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "mlir", srcs = ["_src/interpreters/mlir.py"], diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 0fbe54a30672..55769aa307fc 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -21,7 +21,6 @@ import numpy as np -import jax from jax._src import config from jax._src import core from jax._src import source_info_util @@ -301,11 +300,14 @@ def _cont(axis_size, elt, axis): from_elt_handlers: dict[type, FromEltHandler] = {} def make_iota(axis_size: AxisSize) -> Array: + # Callers of this utility, via batch() or vtile(), must be in a context + # where lax is importable. + from jax import lax # pytype: disable=import-error handler = make_iota_handlers.get(type(axis_size)) if handler: return handler(axis_size) else: - return jax.lax.iota('int32', int(axis_size)) + return lax.iota('int32', int(axis_size)) make_iota_handlers: dict[type, MakeIotaHandler] = {} def register_vmappable(data_type: type, spec_type: type, axis_size_type: type, @@ -1019,10 +1021,13 @@ def broadcast_batcher(prim, args, dims, **params): return (out, (0,) * len(out)) if prim.multiple_results else (out, 0) def _handle_scalar_broadcasting(nd, x, d): + # Callers of this utility, via broadcast_batcher() or defbroadcasting(), + # must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error if d is not_mapped or nd == np.ndim(x): return x else: - return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd))) + return lax.expand_dims(x, tuple(range(np.ndim(x), nd))) def defreducer(prim, ident): primitive_batchers[prim] = partial(reducer_batcher, prim, ident) @@ -1078,17 +1083,20 @@ def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array: def _mask_one_ragged_axis( operand: Array, ident, axis_spec: RaggedAxis) -> Array: + # Callers of this utility, via reducer_batcher() or defreducer(), + # must be in a context where lax is importable. 
+  from jax import lax  # pytype: disable=import-error
   assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time"
   ragged_axis, segment_lengths = axis_spec.ragged_axes[0]
   value = ident(operand.dtype)
-  positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis)
+  positions = lax.broadcasted_iota('int32', operand.shape, ragged_axis)
   # TODO(mattjj, axch) can't get ._data, need to convert it
-  # lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32')
-  lengths = jax.lax.convert_element_type(segment_lengths, 'int32')
-  limits = jax.lax.broadcast_in_dim(
+  # lengths = lax.convert_element_type(segment_lengths._data, 'int32')
+  lengths = lax.convert_element_type(segment_lengths, 'int32')
+  limits = lax.broadcast_in_dim(
       lengths, operand.shape, [axis_spec.stacked_axis])
   mask = positions < limits
-  return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape))
+  return lax.select(mask, operand, lax.broadcast(value, operand.shape))
 
 def move_stacked_axis(operand, bdim, dst):
   dst = canonicalize_axis(dst, operand.ndim)
@@ -1103,6 +1111,8 @@ def move_stacked_axis(operand, bdim, dst):
 ### general utilities for manipulating axes on jaxpr types (not vmappables)
 
 def broadcast(x, sz, axis, mesh_axis=None):
+  # Callers of this utility must be in a context where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   shape = list(np.shape(x))
   shape.insert(axis, sz)
   broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
@@ -1114,7 +1124,7 @@ def broadcast(x, sz, axis, mesh_axis=None):
   # TODO(dougalm, yashkatariya): Delete this context manager once we figure
   # out how to ensure jaxpr arguments always have the context mesh.
   with mesh_lib.use_abstract_mesh(sharding.mesh):
-    x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
+    x = lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
   if config._check_vma.value:
     # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026
     spmd_names = core.get_axis_env().spmd_axis_names

From 5d24628ce70722de361bf874240cc0bc171a6267 Mon Sep 17 00:00:00 2001
From: Jacob Burnim
Date: Fri, 23 May 2025 09:32:43 -0700
Subject: [PATCH 1345/1769] Add TODO to run TPU interpret mode tests in
 parallel.

PiperOrigin-RevId: 762456274
---
 tests/pallas/tpu_pallas_interpret_distributed_test.py | 2 ++
 tests/pallas/tpu_pallas_interpret_test.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py
index a029b8094aa1..4e4776736cf1 100644
--- a/tests/pallas/tpu_pallas_interpret_distributed_test.py
+++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py
@@ -38,6 +38,8 @@
 P = jax.sharding.PartitionSpec
 
 
+# TODO(jburnim): Figure out how to safely run different instances of TPU
+# interpret mode in parallel, and then remove this decorator.
 @jtu.thread_unsafe_test_class()
 class InterpretDistributedTest(jtu.JaxTestCase):
   def setUp(self):
diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py
index 9d6188cbc0cc..cfbf5d70e212 100644
--- a/tests/pallas/tpu_pallas_interpret_test.py
+++ b/tests/pallas/tpu_pallas_interpret_test.py
@@ -82,6 +82,8 @@ def grid_points(self):
     return self._grid_points
 
 
+# TODO(jburnim): Figure out how to safely run different instances of TPU
+# interpret mode in parallel, and then remove this decorator.
@jtu.thread_unsafe_test_class()
 class InterpretTest(jtu.JaxTestCase):
 
From 9928409798fbdf4b9a0b811e78a7bb1698caeda3 Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Fri, 23 May 2025 09:36:47 -0700
Subject: [PATCH 1346/1769] Rename backend.compile to backend.compile_and_load.

Part of a larger refactor. Today, `compile` returns a loaded executable,
i.e., it fuses the compile and load functions. Eventually, `compile` should
return an unloaded executable and `load` should return a loaded
executable; the default jit path will still return a loaded executable.

PiperOrigin-RevId: 762457830
---
 jax/_src/compiler.py | 59 +++++++++++++++++++++++++++++++++++---------
 1 file changed, 48 insertions(+), 11 deletions(-)

diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py
index 343f747efbd7..4f805034e99c 100644
--- a/jax/_src/compiler.py
+++ b/jax/_src/compiler.py
@@ -34,7 +34,9 @@
 from jax._src import profiler
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
+from jax._src.lib import jaxlib_extension_version
 from jax._src.lib import xla_client as xc
+from jax._src.lib import _jax
 from jax._src.lib.mlir import ir
 import numpy as np
 
@@ -291,6 +293,19 @@ def backend_compile(
     executable_devices: xc.DeviceList,
     options: xc.CompileOptions,
     host_callbacks: Sequence[Any],
+) -> xc.LoadedExecutable:
+  return backend_compile_and_load(
+      backend, module, executable_devices, options, host_callbacks
+  )
+
+
+@profiler.annotate_function
+def backend_compile_and_load(
+    backend: xc.Client,
+    module: ir.Module,
+    executable_devices: xc.DeviceList,
+    options: xc.CompileOptions,
+    host_callbacks: Sequence[Any],
 ) -> xc.LoadedExecutable:
   sym_name = module.operation.attributes['sym_name']
   module_name = ir.StringAttr(sym_name).value
@@ -315,18 +330,40 @@ def backend_compile(
   try:
     # we use a separate function call to ensure that XLA compilation appears
    # separately in Python profiling results
-    if host_callbacks:
+    # TODO(dsuo): Simplify this logic once backend_compile actually returns an
+    # unloaded executable.
+ if jaxlib_extension_version < 345 or ( + jaxlib_extension_version >= 345 + and isinstance(backend, _jax.CompileOnlyPyClient) + ): + if host_callbacks: + return backend.compile( + built_c, + executable_devices=executable_devices, # type: ignore + compile_options=options, + host_callbacks=host_callbacks, + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` return backend.compile( + built_c, executable_devices=executable_devices, compile_options=options) # type: ignore + else: + if host_callbacks: + return backend.compile_and_load( + built_c, + executable_devices=executable_devices, + compile_options=options, + host_callbacks=host_callbacks, + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile_and_load( built_c, - executable_devices=executable_devices, # type: ignore + executable_devices=executable_devices, compile_options=options, - host_callbacks=host_callbacks, ) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` - return backend.compile( - built_c, executable_devices=executable_devices, compile_options=options) # type: ignore except xc.XlaRuntimeError as e: for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: handler_result = error_handler(e) @@ -391,7 +428,7 @@ def compile_or_get_cached( ) if cache_key is None: - return backend_compile( + return backend_compile_and_load( backend, computation, executable_devices, compile_options, host_callbacks) @@ -419,7 +456,7 @@ def compile_or_get_cached( config.share_binary_between_hosts.value and is_multi_process and distributed.global_state.client is not None - # Host callbacks are currently baked into the HLO module so we cant share + # Host callbacks are currently baked into the HLO module so we can't share # them. 
and len(host_callbacks) == 0 ): @@ -705,7 +742,7 @@ def _compile_and_write_cache( cache_key: str, ) -> xc.LoadedExecutable: start_time = time.monotonic() - executable = backend_compile( + executable = backend_compile_and_load( backend, computation, executable_devices, compile_options, host_callbacks ) compile_time = time.monotonic() - start_time From 715ab618429ef1da6e613aaa5e8c5ca2daa6aada Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 23 May 2025 11:59:48 -0500 Subject: [PATCH 1347/1769] Fix Gotchas link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a5deecef6c37..14a0a06ae700 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-exa ### Contents * [Transformations](#transformations) * [Scaling](#scaling) -* [Current gotchas](#current-gotchas) +* [Current gotchas](#gotchas-and-sharp-bits) * [Installation](#installation) * [Neural net libraries](#neural-network-libraries) * [Citing JAX](#citing-jax) From aa63a159e5a2ef6145f8b87538e16c6bc42970b8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 23 May 2025 10:11:39 -0700 Subject: [PATCH 1348/1769] [tree_util] raise more informative error when pytree equality check fails --- jaxlib/pytree.cc | 12 ++++++++++-- jaxlib/xla_client.py | 2 +- tests/tree_util_test.py | 21 +++++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/jaxlib/pytree.cc b/jaxlib/pytree.cc index 2700ac9e6c9a..bd845c47ec1e 100644 --- a/jaxlib/pytree.cc +++ b/jaxlib/pytree.cc @@ -281,8 +281,16 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { a.custom != b.custom) { return false; } - if (a.node_data && a.node_data.not_equal(b.node_data)) { - return false; + try { + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Exception raised while checking equality of metadata " + "fields of pytree. Make sure that metadata fields are " + "hashable and have simple equality semantics. (Note: " + "arrays cannot be passed as metadata fields!)"); } if (!IsSortedPyDictKeysEqual(a.sorted_dict_keys, b.sorted_dict_keys)) { return false; diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 24861bad81de..b9497b71dcb1 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 345 +_version = 346 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 8d4cd5854e7d..0d92156b2530 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -1050,6 +1050,27 @@ def testPickle(self): unpickled = pickle.loads(pickle.dumps(key)) self.assertEqual(key, unpickled) + def testEqualityErrorWithArrayAsStaticArg(self): + # Regression test for https://github.com/jax-ml/jax/issues/28659 + @tree_util.register_dataclass + @dataclasses.dataclass + class Tree: + x : jnp.ndarray = dataclasses.field(metadata={'static': True}) + + f = jax.jit(lambda x: x) + + if jax._src.lib.jaxlib_extension_version < 346: + msg = "The truth value of an array with more than one element is ambiguous." + else: + msg = "Exception raised while checking equality of metadata fields of pytree." 
+
+    # First call succeeds, because there is no equality check.
+    f(Tree(jnp.arange(4)))
+
+    # Second call fails, because arrays are marked static and compared for equality.
+    with self.assertRaisesRegex(ValueError, msg):
+      f(Tree(jnp.arange(4)))
+
 
 class StaticTest(parameterized.TestCase):
 
From d4ab82637a9200752d30726d46281e97e5987427 Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Fri, 23 May 2025 10:16:13 -0700
Subject: [PATCH 1349/1769] [Mosaic GPU] Add support for copy_gmem_to_smem in
 Warp semantics.

PiperOrigin-RevId: 762475094
---
 jax/_src/pallas/mosaic_gpu/primitives.py | 31 ++++++++++++++++--------
 jax/experimental/mosaic/gpu/utils.py     | 11 +++++++--
 tests/pallas/mosaic_gpu_test.py          | 26 ++++++++++++++++++++
 3 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 53c890932e38..379a972be9b0 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -49,6 +49,7 @@
 import jax.numpy as jnp
 
 
+WARP_SIZE = 32
 WARPGROUP_SIZE = 128
 
 
@@ -464,7 +465,7 @@ def _copy_gmem_to_smem_lowering(
     dst_transforms_treedef,
     barrier_transforms_treedef,
     collective_axes,
-    warpgroup_sync: bool = True,
+    for_warpgroup: bool = True,
 ):
   flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
       util.split_list(
@@ -505,15 +506,23 @@ def _copy_gmem_to_smem_lowering(
   if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
     if bytes % WARPGROUP_SIZE:
       raise NotImplementedError("Only aligned copies are supported")
-    # We arrive uniformly from each thread in the WG, so we need to divide the
-    # number of bytes by the number of threads in the WG.
-    # TODO: apaszke - Relax this. We can just select the WG leader and have it
-    # arrive with the whole transfer size, while everyone else arrives with 0.
-    # But we should continue using this scheme as it's likely to be faster.
-    bytes //= WARPGROUP_SIZE
-    if warpgroup_sync:
+    if for_warpgroup:
+      # We arrive uniformly from each thread in the WG, so we need to divide the
+      # number of bytes by the number of threads in the WG.
+      # TODO: apaszke - Relax this. We can just select the WG leader and have it
+      # arrive with the whole transfer size, while everyone else arrives with 0.
+      # But we should continue using this scheme as it's likely to be faster.
+      bytes //= WARPGROUP_SIZE
       mgpu.warpgroup_barrier()  # Make sure all reads have completed.
-    barrier.arrive_expect_tx(bytes)
+      barrier.arrive_expect_tx(bytes)
+    else:
+      # In Warp-level lowering, we arrive on each CUDA thread in a warp, but
+      # the barrier still expects a full 128 arrivals so we arrive 4 times
+      # on each CUDA thread instead.
+      bytes //= WARP_SIZE
+      barrier.arrive(arrival_count=3, can_complete=False)
+      barrier.arrive_expect_tx(bytes)
+
   ctx.launch_ctx.async_copy(
       src_ref=src,
       dst_ref=dst,
+ bytes //= WARP_SIZE + barrier.arrive(arrival_count=3, can_complete=False) + barrier.arrive_expect_tx(bytes) + ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, @@ -549,7 +558,7 @@ def _copy_gmem_to_smem_lowering( copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane, primitive_semantics=gpu_core.PrimitiveSemantics.Warp, -)(functools.partial(_copy_gmem_to_smem_lowering, warpgroup_sync=False)) +)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False)) def copy_gmem_to_smem( @@ -713,6 +722,8 @@ def _barrier_wait_pp_eqn( @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 4aeb3358b97a..1915b0b45f11 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -816,9 +816,16 @@ def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: ) return parity, arith.xori(parities, bitmask) - def arrive(self): + def arrive(self, arrival_count: int = 1, can_complete: bool = True): i64 = ir.IntegerType.get_signless(64) - nvvm.mbarrier_arrive_shared(i64, self.get_ptr()) + if can_complete: + if arrival_count > 1: + count = c(arrival_count - 1, ir.IntegerType.get_signless(32)) + nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count) + nvvm.mbarrier_arrive_shared(i64, self.get_ptr()) + else: + count = c(arrival_count, ir.IntegerType.get_signless(32)) + nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count) def arrive_expect_tx( self, bytes: int | ir.Value, predicate: ir.Value | None = None diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b3c2f11ee43f..029445143a0f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1878,6 +1878,32 @@ def _(): }, ) + def test_copy_gmem_to_smem_from_different_warps(self): + # In this test, we issue a copy from from warp 0 and await it in warp 1. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32)) + def kernel(x_ref, y_ref): + def scope(smem_ref, tma_barrier): + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + plgpu.copy_gmem_to_smem(x_ref.at[32:64], smem_ref, tma_barrier) + + @pl.when(warp_id == 1) + def _(): + plgpu.barrier_wait(tma_barrier) + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + smem_ref=plgpu.SMEM((32, 32), jnp.float32), + tma_barrier=plgpu.Barrier(num_arrivals=1)) + x = jax.random.uniform(jax.random.key(42), (64, 32), jnp.float32) + result = kernel(x) + np.testing.assert_array_equal(result, x[32:64]) + class PallasCallWGTest( PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From c4a90c193473a686c08a31650f23f4dc436c801e Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 23 May 2025 11:41:52 -0700 Subject: [PATCH 1350/1769] [Mosaic GPU] Add barrier transformation support to tcgen05_mma. Also fix accumulator argument when it's dynamic. 
PiperOrigin-RevId: 762509416 --- jax/_src/pallas/mosaic_gpu/primitives.py | 76 +++++++++++++++++++----- tests/pallas/mosaic_gpu_test.py | 46 ++++++++++++++ 2 files changed, 106 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 379a972be9b0..1ec22bff3f6d 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -554,6 +554,7 @@ def _copy_gmem_to_smem_lowering( ) return () + lowering.register_lowering_rule( copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane, @@ -722,9 +723,14 @@ def _barrier_wait_pp_eqn( @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane, - gpu_core.PrimitiveSemantics.Warp) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup) +@lowering.register_lowering_rule( + barrier_wait_p, + mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp, +) +@lowering.register_lowering_rule( + barrier_wait_p, mgpu.LoweringSemantics.Warpgroup +) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -1198,18 +1204,31 @@ def tcgen05_mma(acc: _Ref, else: b_transforms_leaves, b_transforms_tree = [], None + if isinstance(barrier, pallas_core.TransformedRef): + barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten( + barrier.transforms + ) + barrier = barrier.ref + else: + barrier_transforms_leaves, barrier_transforms_tree = [], None + tcgen05_mma_p.bind(acc, a, b, barrier, accumulate, *a_transforms_leaves, *b_transforms_leaves, + *barrier_transforms_leaves, a_transforms_tree=a_transforms_tree, b_transforms_tree=b_transforms_tree, + barrier_transforms_tree=barrier_transforms_tree, collective_axis=collective_axis) + @tcgen05_mma_p.def_abstract_eval def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, *transforms_leaves, a_transforms_tree, b_transforms_tree, + barrier_transforms_tree, collective_axis): - del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree) + del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree, + barrier_transforms_tree) if acc.memory_space != gpu_core.TMEM: raise ValueError("Accumulator must be a TMEM Ref.") @@ -1233,6 +1252,7 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, return [] + @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS) @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS) def _tcgen05_mma_lowering( @@ -1245,16 +1265,26 @@ def _tcgen05_mma_lowering( *transforms_leaves, a_transforms_tree, b_transforms_tree, + barrier_transforms_tree, collective_axis, ): _, a_aval, b_aval, *_ = ctx.avals_in lhs_swizzle: int | None = None lhs_transpose: bool = False - if a_transforms_tree is not None: - a_transforms_leaves, b_transforms_leaves = util.split_list( - transforms_leaves, [a_transforms_tree.num_leaves] - ) + transforms_trees = ( + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + ) + (a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = ( + util.split_list( + transforms_leaves, + [getattr(tree, "num_leaves", 0) for tree in transforms_trees], + ) + ) + + if a_transforms_tree is not None: a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) a_ref, a_transforms = lowering._handle_transforms( ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True @@ -1276,9 +1306,8 @@ def _tcgen05_mma_lowering( 
if lhs_tiling != (8, swizzle_elems): raise ValueError("MMA lhs tiling does not fit swizzle. " f"{lhs_tiling=} expected={(8, swizzle_elems)}") - else: - b_transforms_leaves = transforms_leaves # type: ignore + assert b_transforms_tree is not None b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) b_ref, b_transforms = lowering._handle_transforms( ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True @@ -1296,16 +1325,28 @@ def _tcgen05_mma_lowering( raise NotImplementedError( f"Unsupported transforms: {b_transforms}." ) - swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize + if rhs_tiling != (8, swizzle_elems): + raise ValueError( + "MMA rhs tiling does not fit swizzle" + f" {rhs_tiling=} expected={(8, swizzle_elems)}" + ) + + if barrier_transforms_tree is not None: + barrier_transforms = barrier_transforms_tree.unflatten( + barrier_transforms_leaves + ) + indexer = _extract_barrier_indexer(barrier_transforms) + if indexer is not None: + barrier_ref = barrier_ref.__getitem__( + *map(lowering._as_index, indexer.indices) + ) + if lhs_swizzle is None: lhs_swizzle = rhs_swizzle elif rhs_swizzle != lhs_swizzle: raise ValueError("MMA rhs swizzle must match lhs swizzle." f" {lhs_swizzle=} {rhs_swizzle=}") - if rhs_tiling != (8, swizzle_elems): - raise ValueError("MMA rhs tiling does not fit swizzle" - f" {rhs_tiling=} expected={(8, swizzle_elems)}") if lhs_transpose: if isinstance(a_ref, tcgen05.TMEMRef): raise ValueError("TMEM transpose not allowed.") @@ -1314,6 +1355,9 @@ def _tcgen05_mma_lowering( b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2)) if isinstance(accumulate, bool): accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1)) + elif isinstance(accumulate, mgpu.FragmentedArray): + accumulate = accumulate.registers.item() + assert isinstance(accumulate, ir.Value) predicate = ctx.module_ctx.single_lane_predicate collective = False @@ -1341,8 +1385,8 @@ def _tcgen05_mma_lowering( acc, a_ref, b_ref, - a_swizzle=lhs_swizzle, - b_swizzle=rhs_swizzle, + a_swizzle=int(lhs_swizzle), + b_swizzle=int(rhs_swizzle), accumulate=accumulate, collective=collective, ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 029445143a0f..3c0b463ba1c7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2496,6 +2496,52 @@ def _scoped(a_smem, b_smem, expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) + @parameterized.parameters((0,), (1,)) + def test_mma_barrier_indexing( + self, barrier_index, shape=(128, 128), swizzle=128, dtype=jnp.float16 + ): + self.skip_if_wg_semantics() + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): + plgpu.tcgen05_mma( + acc_tmem, + a_smem, + b_smem, + barrier_ref.at[barrier_index], + accumulate=False, + ) + plgpu.barrier_wait(barrier_ref.at[barrier_index]) + scratch_smem[...] 
= acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_ref) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.TMEM(shape, jnp.float32, packed=False), + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.Barrier(num_arrivals=1, num_barriers=2, for_tensor_core=True), + ] + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + class PallasCallSm100AWGTest( PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From f5a9d460723b417ca2032057f8b2ef5ad9e6fdf9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 23 May 2025 12:31:09 -0700 Subject: [PATCH 1351/1769] Move jax/_src/custom_dce.py to its own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This was unblocked by moving batching & ad to their own rules in prior changes. PiperOrigin-RevId: 762527517 --- jax/BUILD | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 7c8847e2e94b..586e9c2dd6e6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -303,7 +303,6 @@ py_library_providing_imports_info( "_src/callback.py", "_src/checkify.py", "_src/custom_batching.py", - "_src/custom_dce.py", "_src/custom_derivatives.py", "_src/custom_partitioning.py", "_src/custom_partitioning_sharding_rule.py", @@ -390,6 +389,7 @@ py_library_providing_imports_info( ":config", ":core", ":custom_api_util", + ":custom_dce", ":custom_transpose", ":deprecations", ":dtypes", @@ -595,6 +595,24 @@ pytype_strict_library( srcs = ["_src/custom_api_util.py"], ) +pytype_strict_library( + name = "custom_dce", + srcs = ["_src/custom_dce.py"], + deps = [ + ":ad", + ":api_util", + ":batching", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ], +) + pytype_strict_library( name = "custom_transpose", srcs = ["_src/custom_transpose.py"], From 9153ab760bc146945f572a79d86fc345286d5f46 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 23 May 2025 13:30:31 -0700 Subject: [PATCH 1352/1769] Transfer library: poison outstanding buffer fetches upon connection failure. 
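
A minimal usage sketch of the new test-only hook (`conn` stands in for an
existing TransferConnection; only the binding name below comes from the
diff):

    # Simulate a connection failure: fetches still outstanding on `conn`
    # should then fail fast instead of hanging.
    conn._testonly_inject_failure()
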
PiperOrigin-RevId: 762546985
---
 jaxlib/py_socket_transfer.cc | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc
index 8086196b9df8..c7cc7c496b0b 100644
--- a/jaxlib/py_socket_transfer.cc
+++ b/jaxlib/py_socket_transfer.cc
@@ -163,6 +163,8 @@ class PyTransferServerConnection {
     }
   }
 
+  SocketServer::Connection& conn() { return *conn_; }
+
  private:
   tsl::RCReference<SocketServer::Connection> conn_;
 };
@@ -257,6 +259,11 @@ struct CopyDests {
 
 void RegisterTransferServerTypes(nanobind::module_& m) {
   nb::class_<PyTransferServerConnection>(m, "TransferConnection")
+#if JAX_IFRT_VERSION_NUMBER > 9
+      .def(
+          "_testonly_inject_failure",
+          [](PyTransferServerConnection& self) { self.conn().InjectFailure(); })
+#endif
       .def("_pull_flat",
            [](PyTransferServerConnection& self, uint64_t uuid,
               xla::nb_class_ptr<xla::PyClient> py_client,
               std::vector py_avals) {

From 57d07e195b4643fd134abc175f46e108ea667875 Mon Sep 17 00:00:00 2001
From: Zac Cranko
Date: Fri, 23 May 2025 13:30:39 -0700
Subject: [PATCH 1353/1769] This is a change to patch some internal Google
 builds while we complete a refactor.

PiperOrigin-RevId: 762547026
---
 jax/profiler.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/jax/profiler.py b/jax/profiler.py
index 31f3ea186d79..d776791e9200 100644
--- a/jax/profiler.py
+++ b/jax/profiler.py
@@ -14,6 +14,7 @@
 
 # Note: import <name> as <name> is required for names to be exported.
 # See PEP 484 & https://github.com/jax-ml/jax/issues/7570
+from typing import Any
 
 from jax._src.profiler import (
     ProfileOptions as ProfileOptions,
@@ -28,3 +29,9 @@
     stop_trace as stop_trace,
     trace as trace,
 )
+
+# this is a temporary shim to please pytype in the meantime before the migration
+# is complete for cl/760646494
+ProfileData: Any = None
+ProfileEvent: Any = None
+ProfilePlane: Any = None

From ae2f943b54bd353ad3731fc07b6ef3d723f83446 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Fri, 23 May 2025 12:15:47 +0000
Subject: [PATCH 1354/1769] TSAN CI, make jax build/test step fail if missing
 deps wheels

---
 .github/workflows/tsan.yaml | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml
index ce4130c31a30..6cd502050344 100644
--- a/.github/workflows/tsan.yaml
+++ b/.github/workflows/tsan.yaml
@@ -116,6 +116,7 @@ jobs:
       - name: Build TSAN Numpy wheel
         if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true'
         run: |
+          set -eux
           cd numpy
 
           # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz
@@ -131,7 +132,6 @@
           export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH
 
           python3 -m pip install uv~=0.5.30
-          python3 -m uv pip install -r requirements/build_requirements.txt
 
           CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v .
--no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized @@ -268,11 +268,15 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - mkdir -p dist + # Check whether we have numpy wheel or exit with error + ls ${GITHUB_WORKSPACE}/wheelhouse/numpy/*.whl || exit 1 cp -v ${GITHUB_WORKSPACE}/wheelhouse/numpy/*.whl dist/ - cp -v ${GITHUB_WORKSPACE}/wheelhouse/scipy/*.whl dist/ if [ "${{ matrix.python-version }}" == "3.14" ]; then + # Check whether we have scipy wheel or exit with error + ls ${GITHUB_WORKSPACE}/wheelhouse/scipy/*.whl || exit 1 + cp -v ${GITHUB_WORKSPACE}/wheelhouse/scipy/*.whl dist/ + # Patch build/requirements_lock_3_14_ft.txt to use TSAN instrumented NumPy and Scipy sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt From 292dea67fa3f550b8add78ea1840b566d95069b1 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Fri, 23 May 2025 14:24:53 -0700 Subject: [PATCH 1355/1769] Simplify attention VJP definition PiperOrigin-RevId: 762567722 --- jax/experimental/pallas/ops/gpu/attention.py | 103 ++++++------------- tests/pallas/gpu_ops_test.py | 24 +++++ 2 files changed, 54 insertions(+), 73 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index ccb3ae8fd3b7..2442ed14f351 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -152,7 +152,7 @@ def body(start_k, carry): # Apply mask to qk. qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) - m_curr = qk.max(axis=-1) + m_curr = jnp.max(qk, axis=-1) m_next = jnp.maximum(m_prev, m_curr) correction = jnp.exp2(m_prev - m_next) l_prev_corr = correction * l_prev @@ -201,7 +201,7 @@ def segment_mask( @functools.partial( - jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12] + jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13] ) @functools.partial( jax.jit, @@ -215,6 +215,7 @@ def segment_mask( "grid", "interpret", "debug", + "return_residuals", ], ) def mha( @@ -231,6 +232,7 @@ def mha( grid: tuple[int, ...] 
| None = None, interpret: bool = False, debug: bool = False, + return_residuals: bool = False, ): del backward_pass_impl batch_size, q_seq_len, num_heads, head_dim = q.shape @@ -273,14 +275,19 @@ def mha( if segment_ids is None else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) - out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) - return pl.pallas_call( + out_shape = [q] + out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0))] + if return_residuals: + out_shape.append(jax.ShapeDtypeStruct( + shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32)) # lse + out_specs.append( + pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i))) # lse + out = pl.pallas_call( kernel, grid=grid_, in_specs=in_specs, - out_specs=pl.BlockSpec( - (None, block_q, None, head_dim_padded), lambda i, j, k: (j, i, k, 0) - ), + out_specs=out_specs, compiler_params=plgpu.TritonCompilerParams( num_warps=num_warps_, num_stages=num_stages), out_shape=out_shape, @@ -288,6 +295,7 @@ def mha( interpret=interpret, name="mha_forward", )(q, k, v, segment_ids) + return out if return_residuals else out[0] def _mha_forward( @@ -304,71 +312,17 @@ def _mha_forward( grid: Any, interpret: bool, debug: bool, + return_residuals: bool, ): - del backward_pass_impl - batch_size, q_seq_len, num_heads, head_dim = q.shape - kv_seq_len = k.shape[1] - block_q = min(block_sizes.block_q, q_seq_len) - block_k = min(block_sizes.block_k, kv_seq_len) - if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]): - raise ValueError( - f"This kernel expects q, k, and v to have the same head dimension, but" - f" found {q.shape=}, {k.shape=}, {v.shape=}." - ) - if q_seq_len % block_q != 0: - raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}") - if kv_seq_len % block_k != 0: - raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}") - head_dim_padded = pl.next_power_of_2(head_dim) - - # Heuristics. 
- grid_ = grid - if grid_ is None: - grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) - - num_warps_ = num_warps - if num_warps_ is None: - num_warps_ = 4 if head_dim <= 64 else 8 - kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale, - causal=causal, block_q=block_q, block_k=block_k, - head_dim=head_dim) - out_shape = [ - jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out - jax.ShapeDtypeStruct( - shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32 # lse - ), - ] - in_specs = [ - pl.BlockSpec((None, block_q, None, head_dim_padded), - lambda i, j, k: (j, i, k, 0)), - pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), - lambda _, j, k: (j, 0, k, 0)), - pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), - lambda _, j, k: (j, 0, k, 0)), - ] - in_specs.append( - None # type: ignore[arg-type] - if segment_ids is None - else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) - ) - out, lse = pl.pallas_call( - kernel, - grid=grid_, - in_specs=in_specs, - out_specs=[ - pl.BlockSpec((None, block_q, None, head_dim_padded), - lambda i, j, k: (j, i, k, 0)), - pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), - ], - compiler_params=plgpu.TritonCompilerParams( - num_warps=num_warps_, num_stages=num_stages - ), - out_shape=out_shape, - debug=debug, - interpret=interpret, - name="mha_forward", - )(q, k, v, segment_ids) - return out, (q, k, v, segment_ids, out, lse) + out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale, + causal=causal, block_sizes=block_sizes, + backward_pass_impl=backward_pass_impl, + num_warps=num_warps, num_stages=num_stages, + grid=grid, interpret=interpret, debug=debug, + return_residuals=True) + residuals = (q, k, v, segment_ids, out, lse) + ret = (out, lse) if return_residuals else out + return ret, residuals def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int): @@ -576,9 +530,12 @@ def inner_loop_dq(start_k, dq): def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, backward_pass_impl: str, num_warps: int | None, num_stages: int, grid: Any, interpret: bool, - debug: bool, res, do): - del num_stages, grid + debug: bool, return_residuals: bool, res, do): + if return_residuals: + raise ValueError( + "Kernel differentiation is not supported if return_residuals is True.") q, k, v, segment_ids, out, lse = res + del num_stages, grid, return_residuals if backward_pass_impl == "xla": return jax.vjp( diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index 1637686365e1..cc2d15a8fdee 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -313,6 +313,30 @@ def f_ref(q, k, v): self.assertAllClose(dk, dk_ref, atol=5e-2) self.assertAllClose(dv, dv_ref, atol=5e-2) + def test_return_residuals_not_differentiable(self): + batch_size, seq_len, num_heads, head_dim = 2, 128, 2, 128 + causal = False + k1, k2, k3 = random.split(random.key(0), 3) + q = random.normal( + k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + k = random.normal( + k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + v = random.normal( + k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + segment_ids = None + + def f(q, k, v): + return attention.mha(q, k, v, causal=causal, segment_ids=segment_ids, + interpret=self.INTERPRET, + return_residuals=True)[0].sum() + + with self.assertRaisesRegex(ValueError, "Kernel differentiation is not" + " supported if return_residuals is True."): + _ = jax.grad(f, 
argnums=(0, 1, 2))(q, k, v) + class FusedAttentionInterpretTest(FusedAttentionTest): INTERPRET = True From 966bcb932ea962133d77ea2ad29d65a85127b402 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 23 May 2025 14:48:36 -0700 Subject: [PATCH 1356/1769] [ragged-paged-attn] Implement static kv cache quantization. (The scale of kv cache is a scalar float value) PiperOrigin-RevId: 762576286 --- jax/BUILD | 1 + .../ops/tpu/ragged_paged_attention/kernel.py | 136 +++++++++++++----- .../pallas/tpu_ragged_paged_attention_test.py | 83 +++++++++-- 3 files changed, 171 insertions(+), 49 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 586e9c2dd6e6..5fb96d34d91e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -908,6 +908,7 @@ pytype_strict_library( ":pallas_tpu_users", ], deps = [ + ":dtypes", ":jax", ":pallas", ":pallas_tpu", diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index d9d952d5a378..67c0b376ecc6 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -22,11 +22,13 @@ import functools import jax from jax import lax +from jax._src import dtypes from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes import jax.numpy as jnp + DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) @@ -80,6 +82,8 @@ def ref_ragged_paged_attention( sliding_window: int | None = None, soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, ): static_validate_inputs( queries, @@ -89,6 +93,8 @@ def ref_ragged_paged_attention( cu_q_lens, num_seqs, sm_scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, sliding_window=sliding_window, soft_cap=soft_cap, mask_value=mask_value, @@ -115,6 +121,12 @@ def ref_ragged_paged_attention( v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ :kv_len ] + if k_scale is not None: + k = k.astype(jnp.float32) * k_scale + k = k.astype(q.dtype) + if v_scale is not None: + v = v.astype(jnp.float32) * v_scale + v = v.astype(q.dtype) k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) @@ -150,7 +162,9 @@ def dynamic_validate_inputs( sliding_window: int | None = None, soft_cap: float | None = None, mask_value: float | None = None, - # Kernel specific params. + k_scale: float | None = None, + v_scale: float | None = None, + # Kernel tuning params. num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, @@ -166,6 +180,8 @@ def dynamic_validate_inputs( sliding_window=sliding_window, soft_cap=soft_cap, mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, @@ -210,7 +226,9 @@ def static_validate_inputs( sliding_window: int | None = None, soft_cap: float | None = None, mask_value: float | None = None, - # Kernel specific params. + k_scale: float | None = None, + v_scale: float | None = None, + # Kernel tuning params. 
num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, @@ -218,6 +236,8 @@ def static_validate_inputs( _, num_q_heads, head_dim = q.shape _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape assert num_combined_kv_heads % 2 == 0 + assert isinstance(k_scale, float) or k_scale is None + assert isinstance(v_scale, float) or v_scale is None num_kv_heads = num_combined_kv_heads // 2 max_num_seqs, pages_per_seq = page_indices.shape if num_seqs.shape != (1,): @@ -291,6 +311,8 @@ def ragged_paged_attention_kernel( sliding_window: int | None = None, soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, ): if mask_value is None: mask_value = DEFAULT_MASK_VALUE @@ -334,23 +356,41 @@ def create_kv_async_copy_descriptors( return async_copy_kv # TODO(jevinjiang): Add these to Mosaic: - # 1. Support arbitrary strided load/store for any dtype. + # 1. Support arbitrary strided load/store for int4 and int8 dtype. # 2. Support arbitrary strided load/store for any last dimension. def strided_load_kv(ref, start, step): - if ref.dtype == jnp.float32: - return ref[start::step, :], ref[start + 1 :: step, :] packing = get_dtype_packing(ref.dtype) - assert ref.dtype == jnp.bfloat16 + if packing == 1: + return [ref[start::step, :]], [ref[start + 1 :: step, :]] + assert packing in (2, 4, 8) assert step % packing == 0 + k_list, v_list = [], [] b_start = start // packing b_step = step // packing b_ref = ref.bitcast(jnp.uint32) b = b_ref[b_start::b_step, :] - bk = b << 16 - bv = b & jnp.uint32(0xffff0000) - k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16) - v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16) - return k, v + + # TODO(chengjiyao): use the general strided loading logic for bf16 after + # fixing the issue in mosaic's infer vector layout pass + if ref.dtype == jnp.bfloat16: + bk = b << 16 + bv = b & jnp.uint32(0xFFFF0000) + k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16) + v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16) + k_list.append(k) + v_list.append(v) + else: + bitwidth = 32 // packing + bitcast_dst_dtype = jnp.dtype(f"uint{bitwidth}") + for i in range(0, packing, 2): + bk = b >> (i * bitwidth) + k = pltpu.bitcast(bk.astype(bitcast_dst_dtype), ref.dtype) + k_list.append(k) + bv = b >> ((i + 1) * bitwidth) + v = pltpu.bitcast(bv.astype(bitcast_dst_dtype), ref.dtype) + v_list.append(v) + + return k_list, v_list def fold_on_2nd_minor(vec): assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 @@ -578,25 +618,42 @@ def prefetch_next_kv_blk(): num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, head_dim, ) - for kv_head_idx in range(num_kv_heads_per_blk): - q_head_idx = kv_head_idx * num_q_heads_per_kv_head - # TODO(jevinjiang): extra handlig for packed type that can start at - # unaligned position! - q = fold_on_2nd_minor( - q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] - ) - k, v = strided_load_kv( - kv_ref, kv_head_idx * 2, num_combined_kv_heads_per_blk - ) - flash_attention( - q, - k, - v, - l_ref.at[kv_head_idx], - m_ref.at[kv_head_idx], - acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], - kv_blk_idx=kv_blk_idx, + kv_packing = get_dtype_packing(kv_ref.dtype) + # NOTE: kv_packing is divided by 2 because k and v are packed together. 
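+    # E.g. for a float8 KV cache kv_packing is 4, so each strided 32-bit load
+    # yields two (k, v) head pairs and the loop below advances two KV heads
+    # per step.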
+ kv_load_step = max(1, kv_packing // 2) + for kv_head_chunk_idx in range(0, num_kv_heads_per_blk, kv_load_step): + k_list, v_list = strided_load_kv( + kv_ref, kv_head_chunk_idx * 2, num_combined_kv_heads_per_blk ) + for step_idx in range(kv_load_step): + k = k_list[step_idx] + v = v_list[step_idx] + if k_scale is not None: + # NOTE: Conversion between arbitrary data types is not supported. + # That's why it is converted to float32 first. + k = k.astype(jnp.float32) * k_scale + k = k.astype(q_ref.dtype) + if v_scale is not None: + v = v.astype(jnp.float32) * v_scale + v = v.astype(q_ref.dtype) + kv_head_idx = kv_head_chunk_idx + step_idx + q_head_idx = kv_head_idx * num_q_heads_per_kv_head + # TODO(jevinjiang): extra handlig for packed type that can start at + # unaligned position! + q = fold_on_2nd_minor( + q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] + ) + flash_attention( + q, + k, + v, + l_ref.at[kv_head_idx], + m_ref.at[kv_head_idx], + acc_ref.at[ + :, q_head_idx : q_head_idx + num_q_heads_per_kv_head, : + ], + kv_blk_idx=kv_blk_idx, + ) return kv_blk_idx + 1, next_buf_idx _, next_buf_idx = lax.while_loop( @@ -625,15 +682,8 @@ def cdiv(a, b): def get_dtype_packing(dtype): - if dtype == jnp.float32: - return 1 - if dtype == jnp.bfloat16: - return 2 - if dtype == jnp.int8: - return 4 - if dtype == jnp.int4: - return 8 - raise ValueError(f"Not implemented: unsupported {dtype=}") + bits = dtypes.bit_width(dtype) + return 32 // bits def get_min_heads_per_blk( @@ -681,6 +731,8 @@ def can_be_xla_fully_tiled(x, packing): "vmem_limit_bytes", "sliding_window", "soft_cap", + "k_scale", + "v_scale", ], ) def ragged_paged_attention( @@ -696,6 +748,8 @@ def ragged_paged_attention( sliding_window: int | None = None, soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, @@ -715,6 +769,8 @@ def ragged_paged_attention( sliding_window: the sliding window size for the attention. soft_cap: the logit soft cap for the attention. mask_value: mask value for causal mask. + k_scale: the scale for the key cache. + v_scale: the scale for the value cache. num_kv_pages_per_block: number of kv pages to be processed in one flash attention block in the pallas kernel. 
num_queries_per_block: number of kv pages to be processed in one flash @@ -735,6 +791,8 @@ def ragged_paged_attention( sliding_window=sliding_window, soft_cap=soft_cap, mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, @@ -823,6 +881,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): sliding_window=sliding_window, soft_cap=soft_cap, mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=len(scalar_prefetches), diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index 4265445c69c7..eebc292ce3ab 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized import jax +from jax._src import dtypes from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( cdiv, @@ -39,7 +40,8 @@ def _test_ragged_paged_attention( num_heads, # [num_q_heads, num_kv_heads] head_dim, page_size, - dtype, + q_dtype, + kv_dtype, num_pages, *, num_kv_pages_per_block=8, @@ -49,6 +51,8 @@ def _test_ragged_paged_attention( max_num_seq=8, sliding_window: int | None = None, soft_cap: float | None = None, + k_scale: float | None = None, + v_scale: float | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -70,17 +74,27 @@ def _test_ragged_paged_attention( q = jax.random.normal( k0, (max_num_batched_tokens, num_q_heads, head_dim), - dtype=dtype, + dtype=q_dtype, ) page_cnt = 0 page_indices_list = [] kv_pages_list = [] for kv_len in kv_lens: - kv = jax.random.normal( - k1, - (kv_len, num_kv_heads * 2, head_dim), - dtype=dtype, - ) + if jnp.issubdtype(kv_dtype, jnp.integer): + # random.randint doesn't support int4, so we use jnp.int32 here and then + # convert to the desired dtype. 
+ kv = jax.random.normal( + k1, + (kv_len, num_kv_heads * 2, head_dim), + dtype=jnp.int32, + ) + kv = kv.astype(kv_dtype) + else: + kv = jax.random.normal( + k1, + (kv_len, num_kv_heads * 2, head_dim), + dtype=kv_dtype, + ) kv = jnp.pad( kv, ((0, cdiv(kv_len, page_size) * page_size - kv_len), (0, 0), (0, 0)), @@ -138,7 +152,9 @@ def _test_ragged_paged_attention( vmem_limit_bytes=vmem_limit_bytes, sliding_window=sliding_window, soft_cap=soft_cap, - )[: actual_num_q_tokens] + k_scale=k_scale, + v_scale=v_scale, + )[:actual_num_q_tokens] expected = ref_ragged_paged_attention( q, @@ -149,12 +165,17 @@ def _test_ragged_paged_attention( num_seqs=num_seqs, sliding_window=sliding_window, soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale, ) + dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype)) tols = { - "float32": 0.15, - "bfloat16": 0.2, + 32: 0.15, + 16: 0.2, + 8: 0.2, + 4: 0.2, } - tol = tols[jnp.dtype(dtype).name] + tol = tols[dtype_bits] self.assertAllClose(output, expected, atol=tol, rtol=tol) @parameterized.product( @@ -173,9 +194,40 @@ def test_ragged_paged_attention_basic(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) + # TODO: support int4 and int8 + @parameterized.product( + q_dtype=[jnp.bfloat16], + kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn], + kv_scales=[(0.5, 0.5), (None, None)], + ) + def test_ragged_paged_attention_quantized_kv_cache( + self, q_dtype, kv_dtype, kv_scales + ): + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Expect TPUv5+") + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + k_scale, v_scale = kv_scales + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + q_dtype, + kv_dtype, + num_pages, + k_scale=k_scale, + v_scale=v_scale, + ) + @parameterized.product( dtype=[jnp.float32, jnp.bfloat16], ) @@ -209,6 +261,7 @@ def test_ragged_paged_attention_decode_only(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @@ -245,6 +298,7 @@ def test_ragged_paged_attention_prefill_only(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @@ -281,6 +335,7 @@ def test_ragged_paged_attention_mixed(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @@ -316,6 +371,7 @@ def test_ragged_paged_attention_complex( head_dim, page_size, dtype, + dtype, num_pages, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, @@ -351,6 +407,7 @@ def test_ragged_paged_attention_sliding_window( head_dim, page_size, dtype, + dtype, num_pages, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, @@ -386,6 +443,7 @@ def test_ragged_paged_attention_logit_soft_capping( head_dim, page_size, dtype, + dtype, num_pages, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, @@ -407,6 +465,7 @@ def test_ragged_paged_attention_sliding_window_should_be_positive(self): head_dim, page_size, dtype, + dtype, num_pages, sliding_window=0, ) @@ -418,6 +477,7 @@ def test_ragged_paged_attention_sliding_window_should_be_positive(self): head_dim, page_size, dtype, + dtype, num_pages, sliding_window=-1, ) @@ -437,6 +497,7 @@ def test_ragged_paged_attention_soft_cap_cannot_be_zero(self): head_dim, page_size, dtype, + dtype, num_pages, soft_cap=0.0, ) From 2b9d7c80a84401c39b2ae7b9083b1db9d0bbc22a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 23 May 2025 15:27:07 -0700 Subject: [PATCH 1357/1769] Move jax/_src/tree.py to its 
own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. PiperOrigin-RevId: 762589488 --- jax/BUILD | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 5fb96d34d91e..de187b4ce597 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -322,7 +322,6 @@ py_library_providing_imports_info( "_src/shard_alike.py", "_src/shard_map.py", "_src/sourcemap.py", - "_src/tree.py", ] + glob( [ "*.py", @@ -416,6 +415,7 @@ py_library_providing_imports_info( ":source_info_util", ":stages", ":traceback_util", + ":tree", ":tree_util", ":typing", ":util", @@ -1207,6 +1207,14 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "tree", + srcs = ["_src/tree.py"], + deps = [ + ":tree_util", + ], +) + pytype_strict_library( name = "tree_util", srcs = ["_src/tree_util.py"], From 4eb32206d2a6923cc68c933bf48e3c0649496abd Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 23 May 2025 16:14:59 -0700 Subject: [PATCH 1358/1769] Update index.rst add marin to ecosystem --- docs/index.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 5a43be427041..07739c01c2fb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -122,6 +122,7 @@ numerical computing tools; the following is just a small sample of what is out t - AXLearn_ - Levanter_ - EasyLM_ + - Marin_ Many more JAX-based libraries have been developed; the community-run `Awesome JAX`_ page @@ -189,6 +190,7 @@ maintains an up-to-date list. .. _JAX AI Stack Examples: https://docs.jaxstack.ai/en/latest/examples.html .. _Keras: https://keras.io/ .. _Levanter: https://github.com/stanford-crfm/levanter +.. _Marin: https://github.com/marin-community/marin .. _Lineax: https://github.com/patrick-kidger/lineax .. _MaxText: https://github.com/google/maxtext/ .. _Numpyro: https://num.pyro.ai/en/latest/index.html From d0195f2240fdf081e011e193585593df2e060b26 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 23 May 2025 17:06:16 -0700 Subject: [PATCH 1359/1769] Move jax/_src/sourcemap to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, discourages private imports downstream, and leads to improved build and iteration times. 
PiperOrigin-RevId: 762621491 --- jax/BUILD | 8 +++++++- tests/BUILD | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index de187b4ce597..a236cd206a55 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -321,7 +321,6 @@ py_library_providing_imports_info( "_src/random.py", "_src/shard_alike.py", "_src/shard_map.py", - "_src/sourcemap.py", ] + glob( [ "*.py", @@ -413,6 +412,7 @@ py_library_providing_imports_info( ":sharding_impls", ":sharding_specs", ":source_info_util", + ":sourcemap", ":stages", ":traceback_util", ":tree", @@ -795,6 +795,11 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "sourcemap", + srcs = ["_src/sourcemap.py"], +) + pytype_strict_library( name = "source_mapper", srcs = glob(include = ["experimental/source_mapper/**/*.py"]), @@ -806,6 +811,7 @@ pytype_strict_library( ":core", ":jax", ":source_info_util", + ":sourcemap", ] + py_deps("absl/flags"), ) diff --git a/tests/BUILD b/tests/BUILD index 6c9f3f74b56a..6fac61933d59 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -2109,6 +2109,7 @@ jax_py_test( srcs = ["sourcemap_test.py"], deps = [ "//jax", + "//jax:sourcemap", "//jax:test_util", ] + py_deps([ "absl/testing", From 0833cc2c3e9fcd80b870a49578a29d410d35554b Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Fri, 23 May 2025 17:06:33 -0700 Subject: [PATCH 1360/1769] Use block_until_ready to fix races in TPU interpret mode tests PiperOrigin-RevId: 762621608 --- tests/pallas/tpu_pallas_interpret_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index cfbf5d70e212..014667d5f948 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -545,7 +545,7 @@ def kernel(x_ref, o_ref, vmem_ref): compiler_params=pltpu.TPUCompilerParams( dimension_semantics=('parallel',), ), - )(x) + )(x).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, 2.0 * x) @@ -579,7 +579,7 @@ def kernel(x_ref, o_ref, vmem_ref): compiler_params=pltpu.TPUCompilerParams( dimension_semantics=('parallel',) ), - )(x) + )(x).block_until_ready() self.assertFalse(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, 2.0 * x) From f28565d3d5bef3021de05edd50fb39c36ca475b5 Mon Sep 17 00:00:00 2001 From: Matthias Guenther Date: Fri, 23 May 2025 17:51:52 -0700 Subject: [PATCH 1361/1769] Update XlaCallModule so tests are compatible with DCE. Move the logic verifying that `shape_assertion` custom calls have side effects to run before MLIR optimizations are applied instead of after. Any `shape_assertion` custom call violating this condition (i.e. declared as pure) is likely to be removed by dead-code elimination, making it undetectable after optimizations. (Until recently, the test passed because DCE wasn't correctly applied to `custom_call` ops.) 
PiperOrigin-RevId: 762634648 --- jax/_src/ffi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index b25306c66b42..3bfe8130ccda 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -16,6 +16,7 @@ from collections.abc import Callable, Mapping, Sequence import ctypes +import dataclasses import functools import os from typing import Any, overload @@ -593,6 +594,7 @@ def __eq__(self, other): return isinstance(other, HashableDict) and self.val == other.val +@dataclasses.dataclass(frozen=True) class FfiEffect(effects.Effect): def __str__(self): return "FFI" From 8b54a6da58482fd26d0d445238a2823bb34289d2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 24 May 2025 05:15:46 -0700 Subject: [PATCH 1362/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/bc63ed41661e939ced1b3d3bfd7d3083eaabd747. PiperOrigin-RevId: 762804332 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f4af3bc02961..9975382d661f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "c361fc2992e8d674636e7870992e95658b1be792" -XLA_SHA256 = "990dd1c54128015235bc28286005704209f52abaf3e39e8f96299e1f5af4f7f1" +XLA_COMMIT = "bc63ed41661e939ced1b3d3bfd7d3083eaabd747" +XLA_SHA256 = "bef55711fafa8f39a47506c91d375b667219cb0975d7be29903645e462558508" def repo(): tf_http_archive( From f6b8cb61cb4eddb663d2ae6de246fa25419d11ca Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 24 May 2025 18:11:24 -0700 Subject: [PATCH 1363/1769] Enter into the right mesh context during shmap DCE PiperOrigin-RevId: 762956847 --- jax/_src/shard_map.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 0bae2da15272..bc6bda9c16de 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1740,7 +1740,8 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn mesh = eqn.params["mesh"] manual_axes = eqn.params["manual_axes"] check_vma = eqn.params["check_vma"] - with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): + with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: return used_inputs, None From 37068b6cfe2baa39850e89b40081c33d24569117 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 25 May 2025 04:21:00 -0700 Subject: [PATCH 1364/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/2c35b7a03388e9ee81b3a71037c372595232ff84. PiperOrigin-RevId: 763076180 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 9975382d661f..9f6dadf35a01 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "bc63ed41661e939ced1b3d3bfd7d3083eaabd747" -XLA_SHA256 = "bef55711fafa8f39a47506c91d375b667219cb0975d7be29903645e462558508" +XLA_COMMIT = "2c35b7a03388e9ee81b3a71037c372595232ff84" +XLA_SHA256 = "0926c4b75d694d699deef1ac8f9c0a9ded787c6cf8d66c6ae00b53fa90f913f3" def repo(): tf_http_archive( From c22bba2f9c237feb8743caf5378ee3b7966209bd Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Sun, 25 May 2025 19:07:35 +0200 Subject: [PATCH 1365/1769] Clarify that upper bound takes precedence in jnp.clip where bounds are incongruent --- jax/_src/numpy/lax_numpy.py | 1 + tests/lax_numpy_test.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0bd287dadd51..ad2b3ad6aa75 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3410,6 +3410,7 @@ def clip( Returns: An array containing values from ``arr``, with values smaller than ``min`` set to ``min``, and values larger than ``max`` set to ``max``. + Wherever ``min`` is larger than ``max``, the value of ``max`` is returned. See also: - :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays. diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 875024617b5f..29e6586ffa18 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1065,6 +1065,14 @@ def testClipDeprecatedArgs(self): "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"): jnp.clip(jnp.arange(4), a_min=2, a_max=3) + def testClipUpperPrecedence(self): + a_min = 3 * np.ones(1) + a_max = 2 * np.ones(1) + x = 4 * np.ones(1) + y = jnp.clip(x, min=a_min, max=a_max) + assert y == a_max, f"Expected {y} to equal {a_max} when a_min > a_max." + assert y == jnp.asarray(np.clip(x, a_min=a_min, a_max=a_max)) + def testHypotComplexInputError(self): rng = jtu.rand_default(self.rng()) x = rng((5,), dtype=jnp.complex64) From f9c7a1421571ceb441011775af41ea75410ceeeb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 26 May 2025 03:04:00 -0700 Subject: [PATCH 1366/1769] [Mosaic GPU] Add missing allocator config and skips in one of our distributed tests I added them in all other files, but forgot about this one. PiperOrigin-RevId: 763352483 --- tests/mosaic/gpu_test_distributed.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/mosaic/gpu_test_distributed.py b/tests/mosaic/gpu_test_distributed.py index fee2ce5b03a6..cf3913771983 100644 --- a/tests/mosaic/gpu_test_distributed.py +++ b/tests/mosaic/gpu_test_distributed.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +import os + from absl.testing import parameterized import jax from jax._src import config @@ -50,6 +52,8 @@ def setUp(self): self.skipTest("Only works on GPU with capability >= sm90") if not mgpu.supports_cross_device_collectives(): self.skipTest("NVSHMEM library unavailable.") + if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": + self.skipTest("NVSHMEM doesn't work with the platform allocator.") if jax.process_count() == 1: self.skipTest("Test requires multiple processes.") if jax.device_count() != jax.process_count(): @@ -97,4 +101,9 @@ def kernel(ctx, src, dst, scratch): if __name__ == "__main__": + # This test doesn't work with the platform allocator, so we override it + # if it's ran alone. If it's part of a larger test suite and the platform + # allocator is used, setUp will skip the test. 
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.01' + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'default' jt_multiprocess.main() From fae05bd8592b3508143179602b25dcd609a630b9 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 26 May 2025 03:07:08 -0700 Subject: [PATCH 1367/1769] [Pallas:MGPU] Support remote async copies and use them in the collective matmul PiperOrigin-RevId: 763353415 --- jax/_src/pallas/mosaic_gpu/primitives.py | 25 +++++++++++++++++-- jax/experimental/mosaic/gpu/launch_context.py | 5 +++- .../pallas/ops/gpu/collective_matmul_mgpu.py | 10 ++++++-- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1ec22bff3f6d..61be6e35cc55 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -37,6 +37,7 @@ from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering from jax._src.pallas.mosaic_gpu.core import state_types @@ -282,6 +283,10 @@ def _copy_smem_to_gmem_lowering( else: indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) assert copy_params.get("swizzle") is None + if copy_params.get("gmem_peer_id", None) is not None: + raise NotImplementedError( + "GMEM refs with peer ids are not supported in warpgroup lowering." + ) assert not copy_params.get("gmem_transform") mgpu.dialect.async_store( src, @@ -317,13 +322,25 @@ def _split_gmem_slice(gmem_slice): def _extract_gmem_copy_params(transforms): if not transforms: return {} + peer_id = None + indexers = [] for transform in transforms: - if not isinstance(transform, indexing.NDIndexer): + if isinstance(transform, gpu_core.PeerMemRef): + if transform.device_id_type != pallas_primitives.DeviceIdType.LOGICAL: + raise NotImplementedError( + "Only logical device ids are supported for GMEM refs." + ) + peer_id = lowering._ensure_ir_value(transform.device_id, jnp.int32) + continue + elif isinstance(transform, indexing.NDIndexer): + indexers.append(transform) + else: raise NotImplementedError( "Non-indexing transforms on GMEM refs are not implemented.") - indexer = lowering.merge_indexers(transforms) + indexer = lowering.merge_indexers(indexers) return dict( gmem_slice=lowering._ndindexer_indices(indexer), + gmem_peer_id=peer_id, ) @@ -542,6 +559,10 @@ def _copy_gmem_to_smem_lowering( indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) assert copy_params.get("swizzle") is None assert not copy_params.get("gmem_transform") + if copy_params.get("gmem_peer_id", None) is not None: + raise NotImplementedError( + "GMEM refs with peer ids are not supported in warpgroup lowering." 
+ ) barrier_ref = barrier.as_barrier_memref() mgpu.dialect.arrive_expect_tx(barrier_ref, bytes) mgpu.dialect.async_load( diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index e4f0c4efa22c..2a5bb96f4708 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -407,7 +407,10 @@ def _get_tma_desc( "add","min","max","inc","dec","and","or","xor" ] | None, ): - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) + # Using ir.Values in cache keys is a little sketchy, but I think it should + # be fine. Having it in the key will keep it alive, and if comparison and + # hashing is by identity then it should work out. + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) diff --git a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py index a6c372f2cee7..854d75dbf6a3 100644 --- a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py @@ -140,9 +140,15 @@ def k_loop(idxs, lhs_smem, rhs_smem): plgpu.wgmma(acc_ref, lhs_smem, rhs_smem) k_slice = pl.ds(ki * block_k, block_k) # TODO(apaszke): No need to send on the last step - # TODO(apaszke): Use an async copy. This is uncoalesced. - send_scratch_ref[next_scratch_slot, :, k_slice] = lhs_smem[...] + plgpu.copy_smem_to_gmem( + lhs_smem, send_scratch_ref.at[next_scratch_slot, :, k_slice] + ) + # We only delay release by 1 step, so we need to wait for the + # previous copies. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) k_loop(scratch_ref.at[scratch_slot], rhs_ref) + # Make sure the copy is fully done. + plgpu.wait_smem_to_gmem(0, wait_read_only=False) # TODO(apaszke): Both of those semaphores perform a .sys release. # This is very expensive and we should only do a single .sys fence. pl.semaphore_signal(capacity_sem, device_id=recv_dev_id, device_id_type=pl.DeviceIdType.LOGICAL) From 4bfd163772d640e22fcd7696fa920e64f4103013 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 26 May 2025 04:15:54 -0700 Subject: [PATCH 1368/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/5a5e232f7bb9a2fa0d79f461f86a3cfa2c78f2cf. PiperOrigin-RevId: 763372229 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 9f6dadf35a01..7f70f1fa01a6 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "2c35b7a03388e9ee81b3a71037c372595232ff84" -XLA_SHA256 = "0926c4b75d694d699deef1ac8f9c0a9ded787c6cf8d66c6ae00b53fa90f913f3" +XLA_COMMIT = "5a5e232f7bb9a2fa0d79f461f86a3cfa2c78f2cf" +XLA_SHA256 = "2e23c48918d56aac8ec1986c0deacdbf3bdc5740c6eb796a98bd36329bcd3af0" def repo(): tf_http_archive( From 444e9528277026684ba2dead0c170b21506b7e16 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Mon, 26 May 2025 05:55:49 -0700 Subject: [PATCH 1369/1769] Fix a test which blocks the openxla change. 
The C128 matmuls will be routed to cuBLAS rather than to be handled by the loop emitter, causing a very slight numerical difference. Therefore, don't be very strict in the comparison. PiperOrigin-RevId: 763397887 --- tests/sparse_bcoo_bcsr_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index feac4882c9a3..f489d4551465 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -603,7 +603,7 @@ def test_bcoo_batched_matmat_default_lowering( # with self.gpu_matmul_warning_context( # "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"): matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) - self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback) + self.assertArraysAllClose(matmat_expected, matmat_default_lowering_fallback) @jtu.run_on_devices("gpu") def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): From c1e8f250b53e2c9636fd095ead21b64f13c1d422 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Wed, 7 May 2025 15:32:08 +0000 Subject: [PATCH 1370/1769] [Mosaic GPU] Use PTX ISA version = min(ptxas, LLVM) --- jaxlib/mosaic/gpu/custom_call.cc | 243 ++++++++++++++++++++++--------- jaxlib/mosaic/gpu/target.cc | 45 ++++-- jaxlib/mosaic/gpu/target.h | 4 +- 3 files changed, 206 insertions(+), 86 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 27175c3773e6..d109520582c9 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -42,6 +42,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "llvm/ADT/SmallVector.h" @@ -109,6 +110,106 @@ namespace ffi = xla::ffi; using MosaicInitFunc = void(void****); using MosaicHostFunc = void(void**); +class TemporaryDirectory { + private: + TemporaryDirectory(std::string path) : path(std::move(path)) {} + // TODO(apaszke): Unlink in destructor. 
+ + public: + static absl::StatusOr Create() { + std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; + if (mkdtemp(pattern.data()) == NULL) { + return absl::InternalError("Failed to create temporary directory"); + } + return TemporaryDirectory(std::move(pattern)); + } + + std::string_view GetPath() { return path; } + + private: + std::string path; +}; + +absl::StatusOr RunCUDATool(const char* tool, + const std::vector& args, + bool stderr_to_stdout = true) { + CHECK(!args.empty() && args.back() == nullptr); + const char * cuda_path_ptr = getenv("CUDA_ROOT"); + if (!cuda_path_ptr) return absl::InternalError("Failed to get CUDA_ROOT"); + std::string tool_path(cuda_path_ptr); + tool_path += "/bin/"; + tool_path += tool; + int stdout_pipe[2] = {-1, -1}; + pid_t child_pid; + posix_spawn_file_actions_t file_actions; + if (posix_spawn_file_actions_init(&file_actions)) { + return absl::InternalError("Failed to initialize spawn file actions"); + } + absl::Cleanup file_actions_destroyer = [&file_actions] { + posix_spawn_file_actions_destroy(&file_actions); + }; + if (pipe(stdout_pipe) == -1) { + return absl::InternalError("Failed to set up pipe"); + } + absl::Cleanup pipe_closer = [&stdout_pipe] { + if (stdout_pipe[0] != -1) close(stdout_pipe[0]); + if (stdout_pipe[1] != -1) close(stdout_pipe[1]); + }; + // close read end in child + if (posix_spawn_file_actions_addclose(&file_actions, stdout_pipe[0])) { + return absl::InternalError("Failed to close read end of the pipe in child"); + } + if (posix_spawn_file_actions_adddup2(&file_actions, stdout_pipe[1], + STDOUT_FILENO)) { + return absl::InternalError("Failed to redirect stdout to pipe"); + } + if (stderr_to_stdout && posix_spawn_file_actions_adddup2( + &file_actions, STDOUT_FILENO, STDERR_FILENO)) { + return absl::InternalError("Failed to redirect stderr to stdout"); + } + // execv is guaranteed by POSIX to not modify the args (other than + // replacing the whole process image), so the const_cast is valid. + if (int status = + posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, + const_cast(args.data()), environ)) { + return absl::InternalError( + absl::StrCat("Process spawn failed: ", strerror(status))); + } + // Proactively close write end in parent. If we don't do this, read + // will block since the pipe will have an open write end in the + // parent process. + if (close(stdout_pipe[1]) == -1) { + return absl::InternalError( + absl::StrCat("Failed to close write end of pipe in parent process: ", + strerror(errno))); + } + // Mark the write end as successfully closed, so it doesn't get + // closed a second time by the deferred pipe_closer. 
+ stdout_pipe[1] = -1; + std::string stdout; + char buf[1024]; + while (int bytes_read = read(stdout_pipe[0], buf, sizeof buf)) { + if (bytes_read == -1) { + return absl::InternalError( + absl::StrCat("Failed to read from pipe: ", strerror(errno))); + } + stdout.append(buf, bytes_read); + } + int status; + if (waitpid(child_pid, &status, 0) == -1) { + return absl::InternalError("Failed to wait for CUDA tool invocation"); + } + if (status != 0) { + std::string error_message = "CUDA tool failed"; + if (!stdout.empty()) { + error_message += ": "; + error_message += stdout; + } + return absl::InternalError(error_message); + } + return stdout; +} + void EnsureLLVMNVPTXTargetIsRegistered() { static absl::once_flag register_nvptx_target_flag; absl::call_once(register_nvptx_target_flag, []() { @@ -119,7 +220,65 @@ void EnsureLLVMNVPTXTargetIsRegistered() { }); } -absl::StatusOr> GetSmAndPtxIsaVersion() { +absl::StatusOr GetLatestPtxasPtxIsaVersion() { + std::vector ptxas_args = {"ptxas", "--input-as-string", + ".version 99.99", nullptr}; + auto status = RunCUDATool("ptxas", ptxas_args).status(); + if (status.ok()) { + return absl::InternalError("ptxas succeeded where it was expected to fail"); + } + // Output message is of the form: + // ptxas application ptx input, line 1; fatal : Unsupported .version 99.99; current version is '8.8' + std::vector chunks = + absl::StrSplit(status.message(), '\''); + if (chunks.size() != 3) { + return absl::InternalError( + "Failed to locate PTX ISA version in ptxas error message"); + } + std::vector major_minor = absl::StrSplit(chunks[1], '.'); + if (major_minor.size() != 2) { + return absl::InternalError( + absl::StrFormat("Expected PTX ISA version to be formatted as " + "MAJOR.MINOR, instead got: %s", + chunks[1])); + } + int major; + if (!absl::SimpleAtoi(major_minor[0], &major)) { + return absl::InternalError( + absl::StrFormat("Failed to parse PTX ISA major version, expected a " + "parsable integer, instead got: %s", + major_minor[0])); + } + int minor; + if (!absl::SimpleAtoi(major_minor[1], &minor)) { + return absl::InternalError( + absl::StrFormat("Failed to parse PTX ISA minor version, expected a " + "parsable integer, instead got: %s", + major_minor[1])); + } + if (minor >= 10) { + return absl::InternalError( + absl::StrFormat("PTX ISA minor version %d is not less than or equal to " + "9, which is assumed for version comparison", + minor)); + } + return major * 10 + minor; +} + +absl::StatusOr GetPtxIsaVersion() { + TF_ASSIGN_OR_RETURN(int ptxas_latest_version, GetLatestPtxasPtxIsaVersion()); + // We'd like to target the latest PTX ISA version supported by + // ptxas. However, it doesn't make sense to ask LLVM to target a PTX + // ISA that it isn't aware of yet. Find the latest version supported + // by LLVM and return the minimum of the two versions, one from + // ptxas and the other from LLVM. + TF_ASSIGN_OR_RETURN(int llvm_latest_version, + mosaic::gpu::GetLatestLlvmPtxIsaVersion()); + int final_version = std::min(ptxas_latest_version, llvm_latest_version); + return absl::StrFormat("ptx%d", final_version); +} + +absl::StatusOr GetSmVersion() { // Assumes driver has been initialized and a context exists. XLA already has // some utilities to query this, but we try to stay runtime-agnostic, so we // build our own here. 
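// An illustrative aside (not part of this change): the probe above is easy to
// reproduce by hand with the same flag the code passes to ptxas,
//
//   ptxas --input-as-string ".version 99.99"
//
// which fails with a message of the form
//
//   ptxas application ptx input, line 1; fatal : Unsupported .version 99.99; current version is '8.8'
//
// and the single-quoted version is exactly what GetLatestPtxasPtxIsaVersion
// extracts and encodes as major * 10 + minor (88 here).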
@@ -138,10 +297,9 @@ absl::StatusOr> GetSmAndPtxIsaVersion() { return absl::InternalError("Failed to get minor compute capability"); } EnsureLLVMNVPTXTargetIsRegistered(); - return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor); + return mosaic::gpu::GetSmVersion(major, minor); } - mlir::FailureOr GetPassPipeline( mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) { @@ -272,61 +430,6 @@ void InitContext(mlir::MLIRContext* context) { context->loadAllAvailableDialects(); } -absl::Status RunCUDATool(const char* tool, - const std::vector& args, - bool stderr_to_stdout = false) { - CHECK(!args.empty() && args.back() == nullptr); - const char * cuda_path_ptr = getenv("CUDA_ROOT"); - if (!cuda_path_ptr) return absl::InternalError("Failed to get CUDA_ROOT"); - std::string tool_path(cuda_path_ptr); - tool_path += "/bin/"; - tool_path += tool; - pid_t child_pid; - posix_spawn_file_actions_t file_actions; - if (posix_spawn_file_actions_init(&file_actions)) { - return absl::InternalError("Failed to initialize spawn file actions"); - } - if (posix_spawn_file_actions_adddup2(&file_actions, STDOUT_FILENO, - STDERR_FILENO)) { - return absl::InternalError("Failed to set up spawn file actions"); - } - // execv is guaranteed by POSIX to not modify the args (other than - // replacing the whole process image), so the const_cast is valid. - if (posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, - const_cast(args.data()), environ)) { - return absl::InternalError("Process spawn failed"); - } - int status; - if (waitpid(child_pid, &status, 0) == -1) { - return absl::InternalError("Failed to wait for CUDA tool invocation"); - } - if (status != 0) return absl::InternalError("CUDA tool failed"); - if (posix_spawn_file_actions_destroy(&file_actions) != 0) { - return absl::InternalError("Failed to clean up after posix_spawn"); - } - return absl::OkStatus(); -} - -class TemporaryDirectory { - private: - TemporaryDirectory(std::string path) : path(std::move(path)) {} - // TODO(apaszke): Unlink in destructor. - - public: - static absl::StatusOr Create() { - std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; - if (mkdtemp(pattern.data()) == NULL) { - return absl::InternalError("Failed to create temporary directory"); - } - return TemporaryDirectory(std::move(pattern)); - } - - std::string_view GetPath() { return path; } - - private: - std::string path; -}; - void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) { bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; @@ -382,19 +485,23 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, ptxas_args.push_back("-v"); } ptxas_args.push_back(nullptr); - if (auto status = RunCUDATool("ptxas", ptxas_args); !status.ok()) { - std::cerr << "ptxas invocation failed: " << status.message() << std::endl; + if (auto result = RunCUDATool("ptxas", ptxas_args); !result.ok()) { + std::cerr << "ptxas invocation failed: " << result.status() << std::endl; continue; + } else if (dump_ptxas) { + std::cout << *result << std::endl; } if (!dump_sass) { continue; } // We're done. // Call nvdisasm to pretty-print SASS. 
- if (auto status = RunCUDATool( - "nvdisasm", {"nvdisasm", "-ndf", "-c", elf_path.c_str(), nullptr}); - !status.ok()) { - std::cerr << "nvdisasm invocation failed: " << status.message() + auto result = RunCUDATool( + "nvdisasm", {"nvdisasm", "-ndf", "-c", elf_path.c_str(), nullptr}); + if (!result.ok()) { + std::cerr << "nvdisasm invocation failed: " << result.status() << std::endl; continue; } + // Dump SASS. + std::cout << *result << std::endl; } } @@ -424,12 +531,8 @@ absl::StatusOr get_nvshmem_llvm_lib_path() { absl::StatusOr, bool>> Compile( mlir::ModuleOp module) { tsl::profiler::TraceMe trace("Compile"); - auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); - if (!sm_and_ptx_isa.ok()) { - return sm_and_ptx_isa.status(); - } - const std::string sm = sm_and_ptx_isa.value().first; - const std::string ptx_isa = sm_and_ptx_isa.value().second; + TF_ASSIGN_OR_RETURN(std::string sm, GetSmVersion()); + TF_ASSIGN_OR_RETURN(std::string ptx_isa, GetPtxIsaVersion()); bool is_comm_used = is_nvshmem_used(module); std::string nvshmem_path = ""; if (is_comm_used) { diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc index a259b3dead7b..4c1866fdaea9 100644 --- a/jaxlib/mosaic/gpu/target.cc +++ b/jaxlib/mosaic/gpu/target.cc @@ -21,6 +21,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/strip.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "llvm/MC/MCSubtargetInfo.h" @@ -28,8 +30,7 @@ limitations under the License. namespace mosaic::gpu { -absl::StatusOr> GetSmAndPtxIsaVersion( - int major, int minor) { +absl::StatusOr GetSmVersion(int major, int minor) { // "base" compute capability as reported by the driver. // For example for a Hopper H200 GPU this would return sm_90, and never // sm_90a. @@ -64,25 +65,41 @@ absl::StatusOr> GetSmAndPtxIsaVersion( } } } + return sm_arch_specific ? sm_arch_specific : sm_base; +} - const std::string sm = sm_arch_specific ? sm_arch_specific : sm_base; - +absl::StatusOr GetLatestLlvmPtxIsaVersion() { + const std::string triple = "nvptx64-nvidia-cuda"; + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(triple, error); + if (target == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to lookup LLVM target based on triple %s: %s", triple, error)); + } + // generic subtarget std::unique_ptr subtarget_info{ - target->createMCSubtargetInfo(triple, sm, "")}; + target->createMCSubtargetInfo(triple, "", "")}; if (subtarget_info == nullptr) { - return absl::InternalError( - absl::StrFormat("Failed to get LLVM subtarget info for sm %s", sm)); + return absl::InternalError(absl::StrFormat( + "Failed to get generic LLVM subtarget info for triple %s", triple)); } - + int llvm_latest_version = 0; for (const llvm::SubtargetFeatureKV& feature : - subtarget_info->getEnabledProcessorFeatures()) { - if (absl::StartsWith(feature.Key, "ptx")) { - std::string ptx_isa = feature.Key; - return std::make_pair(sm, ptx_isa); + subtarget_info->getAllProcessorFeatures()) { + absl::string_view version_string = feature.Key; + if (absl::ConsumePrefix(&version_string, "ptx")) { + int version; + if (!absl::SimpleAtoi(version_string, &version)) { + return absl::InternalError( + absl::StrFormat("Failed to convert PTX ISA version to integer: %s", + version_string)); + } + llvm_latest_version = + version > llvm_latest_version ? 
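+      // Equivalent to llvm_latest_version = std::max(llvm_latest_version,
+      // version); scanning every ptxNN feature avoids assuming anything
+      // about the order in which getAllProcessorFeatures() returns them.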
version : llvm_latest_version; } } - return absl::InternalError(absl::StrFormat( - "Failed to find a PTX ISA LLVM subtarget feature for %s", sm)); + return llvm_latest_version; } } // namespace mosaic::gpu diff --git a/jaxlib/mosaic/gpu/target.h b/jaxlib/mosaic/gpu/target.h index 070ecedebd01..5a2a240d8db1 100644 --- a/jaxlib/mosaic/gpu/target.h +++ b/jaxlib/mosaic/gpu/target.h @@ -22,8 +22,8 @@ limitations under the License. namespace mosaic::gpu { -absl::StatusOr> GetSmAndPtxIsaVersion( - int major, int minor); +absl::StatusOr GetSmVersion(int major, int minor); +absl::StatusOr GetLatestLlvmPtxIsaVersion(); } // namespace mosaic::gpu From f35d708503c7c0b401fe55e3716ad2f3ce8396cc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 27 May 2025 02:32:53 -0700 Subject: [PATCH 1371/1769] [pallas] The `cf` dialect is now always available PiperOrigin-RevId: 763697379 --- jax/_src/lib/mlir/dialects/__init__.py | 3 ++- jax/_src/pallas/mosaic/lowering.py | 6 ------ jax/_src/pallas/mosaic_gpu/lowering.py | 6 ------ 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index b49154e7936a..eccd40104dc1 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from jaxlib.mlir.dialects import arith as arith from jaxlib.mlir.dialects import builtin as builtin + from jaxlib.mlir.dialects import cf as cf from jaxlib.mlir.dialects import chlo as chlo from jaxlib.mlir.dialects import func as func from jaxlib.mlir.dialects import gpu as gpu @@ -36,6 +37,7 @@ __getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [ "arith", "builtin", + "cf", "chlo", "func", "gpu", @@ -57,4 +59,3 @@ from jaxlib.mlir.dialects import stablehlo as hlo from jax._src import lib -from jaxlib.mlir.dialects import cf diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 4e8827401e0c..635c473620c3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3780,12 +3780,6 @@ def _check_lowering_rule( if not pallas_helpers.debug_checks_enabled(): return [] - if cf is None: - # TODO(slebedev): Remove once the minimal jaxlib version is 0.6.1. - raise ValueError( - "cf dialect is not available. Make sure you have jaxlib 0.6.1 or later." - ) - error = jax.tree.unflatten(err_tree, err_args) [pred] = error._pred.values() [exception_tree] = error._metadata.values() diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 9e396167f610..b9a3aa17c39d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -3109,12 +3109,6 @@ def _check_lowering_rule(ctx: LoweringRuleContext, *err_args, err_tree, debug): if not pallas_helpers.debug_checks_enabled(): return [] - if cf_dialect is None: - # TODO(slebedev): Remove once the minimal jaxlib version is 0.6.1. - raise ValueError( - "cf dialect is not available. Make sure you have jaxlib 0.6.1 or later." - ) - error = jax.tree.unflatten(err_tree, err_args) [pred] = error._pred.values() [exception_tree] = error._metadata.values() From 3aa4e3688de9d3e3212833a4fbf6b6a7b6204635 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 27 May 2025 03:17:50 -0700 Subject: [PATCH 1372/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/cb67f2f7ce4787f63f5fc80dc5c30cd3dee8f4e3. 
PiperOrigin-RevId: 763710186
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 7f70f1fa01a6..1bce754dbd3a 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "5a5e232f7bb9a2fa0d79f461f86a3cfa2c78f2cf"
-XLA_SHA256 = "2e23c48918d56aac8ec1986c0deacdbf3bdc5740c6eb796a98bd36329bcd3af0"
+XLA_COMMIT = "cb67f2f7ce4787f63f5fc80dc5c30cd3dee8f4e3"
+XLA_SHA256 = "483000398e9c8dc090e5ed493286f91b8f8160793bb290dbe736440eb55e0382"

 def repo():
     tf_http_archive(

From f68aab1213bfef7ed3f4d8df8521094e8f1155b4 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 27 May 2025 03:23:01 -0700
Subject: [PATCH 1373/1769] [Mosaic GPU] Work around MLIR recognizing
 strided<[1]> as identity layout in some ops

I can't explain it, but if we don't do it then the verifier sometimes
fails... I'm not even sure how to properly trigger this in a test right now,
but worst case it would result in more verifier failures to fix, so I think
it's fine to merge as is.

PiperOrigin-RevId: 763711454
---
 jax/experimental/mosaic/gpu/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 1915b0b45f11..224f5a09cfe5 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -597,7 +597,8 @@ def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value:
     )
   new_shape[dim : dim + 1] = factors
   identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
-  if ref_ty.layout == identity:
+  contig_strided_1d = ir.Attribute.parse("strided<[1]>")
+  if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d:
     new_layout = ir.AffineMapAttr.get(
         ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1)
     )

From b44b9634ec4c1613b87dcf1278b9437ce9ab8004 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 27 May 2025 03:45:23 -0700
Subject: [PATCH 1374/1769] [Pallas:MGPU] Make sure that lowering errors
 mention the offending line

I thought this doesn't work, but it does! Still, adding a test to make sure
we don't regress it.

PiperOrigin-RevId: 763717665
---
 tests/pallas/mosaic_gpu_test.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 3c0b463ba1c7..7430fcc1ca53 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -21,6 +21,7 @@
 import re
 import sys
 import tempfile
+import traceback
 from typing import ClassVar

 from absl.testing import absltest
@@ -1766,6 +1767,27 @@ def body(idx, _):
         jnp.tile((132 * sm_step + jnp.arange(132))[:, None], 128),
     )

+  def test_lowering_error_context(self):
+    def body(x_ref, y_ref, barrier):
+      plgpu.copy_gmem_to_smem(x_ref, y_ref, barrier)
+      plgpu.barrier_wait(barrier)
+
+    x = jnp.arange(127, dtype=jnp.int4)  # Total size is not a whole number of bytes.
+    offending_line = "plgpu.copy_gmem_to_smem(x_ref, y_ref, barrier)"
+    try:
+      pl.pallas_call(
+          body,
+          in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+          out_specs=pl.BlockSpec(memory_space=plgpu.SMEM),
+          out_shape=x,
+          scratch_shapes=[plgpu.Barrier(1)],
+      )(x)
+    except:
+      # assertRaisesRegex does not let us match the traceback.
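+      # Instead we catch the exception ourselves and search the formatted
+      # traceback; the else: branch below fails the test if no exception was
+      # raised at all.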
+ self.assertIn(offending_line, traceback.format_exc()) + else: + self.fail("Should have raised an exception") + class PallasCallWarpPrimitiveSemanticsTest(PallasTest): def setUp(self): From 9a7f9f13efdb5129a861b44263ab3024fa95ceb6 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 27 May 2025 04:00:02 -0700 Subject: [PATCH 1375/1769] [Pallas:MGPU] Add a missing warpgroup barrier before warp core_map If we don't synchronize the warps, some of them can go on and schedule e.g. async copies without waiting for the memory transactions of other warps in the warpgroup to complete. PiperOrigin-RevId: 763721411 --- jax/_src/pallas/mosaic_gpu/lowering.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b9a3aa17c39d..cf867e55f4c9 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2793,6 +2793,10 @@ def _core_map_lowering_rule( "Can only close over scalars and Refs when using core_map with " f"WarpMesh. Found array of shape {aval_in}." ) + # We allow the warps to schedule async copies without synchronizing with + # other warps, so we need to add a barrier here to make sure all reads and + # writes have completed. + mgpu.warpgroup_barrier() _ = lower_jaxpr_to_mosaic_gpu( module_ctx, ctx.launch_ctx, @@ -2800,6 +2804,7 @@ def _core_map_lowering_rule( args=(), consts=args, ) + # TODO(apaszke,justinfu): Do we really need this barrier? mgpu.warpgroup_barrier() return [] raise ValueError(f"Unsupported mesh: {mesh}") From 4f717d31c5b94bcfb742d1a9aaf0024cca7ccd6c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 27 May 2025 04:29:34 -0700 Subject: [PATCH 1376/1769] [pallas:mosaic_gpu] `Barrier` and `ClusterBarrier` are now `kw_only=True` PiperOrigin-RevId: 763730217 --- jax/_src/pallas/mosaic_gpu/core.py | 8 ++-- jax/_src/pallas/mosaic_gpu/pipeline.py | 6 +-- .../pallas/ops/gpu/attention_mgpu.py | 12 ++--- tests/pallas/mosaic_gpu_test.py | 44 +++++++++---------- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 08e47cec4b2e..2fca1464ee0b 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -850,7 +850,7 @@ def __str__(self): return self.name -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class Barrier: """Describes a barrier Ref. @@ -862,9 +862,9 @@ class Barrier: the tensor core. This should be set to True when waiting on Blackwell (TC Gen 5) asynchoronous matmul instructions. """ - num_arrivals: int + num_arrivals: int = 1 num_barriers: int = 1 - for_tensor_core: bool = dataclasses.field(default=False, kw_only=True) + for_tensor_core: bool = False def get_ref_aval(self) -> AbstractMemoryRef: aval = jax_core.ShapedArray( @@ -879,7 +879,7 @@ def __post_init__(self): f"Num arrivals must be at least 1, but got {self.num_arrivals}" ) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class ClusterBarrier: collective_axes: tuple[str | tuple[str, ...], ...] 
num_barriers: int = 1 diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 4966eec7fa6d..a7f8d32677b0 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -230,7 +230,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): ], [len(in_specs)], ) - arrival_count = sum(map(_in_smem, in_specs)) + num_arrivals = sum(map(_in_smem, in_specs)) return pl.run_scoped( functools.partial( scoped_pipeline, @@ -240,10 +240,10 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_smem_refs=in_smem_refs, out_smem_refs=out_smem_refs, barrier_ref=None - if arrival_count == 0 + if num_arrivals == 0 else gpu_core.Barrier( # TODO(slebedev): Change this to arrive only once. - arrival_count, + num_arrivals=num_arrivals, num_barriers=max_concurrent_steps, ), ) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index c7e9f95e3f99..447e3affd7c1 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -278,9 +278,9 @@ def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), scratch, ( - plgpu.Barrier(1, num_barriers=max_concurrent_steps), - plgpu.Barrier(1, num_barriers=max_concurrent_steps), - plgpu.Barrier(1, num_barriers=compute_wgs), + plgpu.Barrier(num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=compute_wgs), ), (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2, plgpu.Barrier(num_arrivals=compute_wgs), @@ -587,7 +587,7 @@ def compute_dk(acc_ref): out_shape=q, scratch_shapes=[ (q_scratch, do_scratch, lse_scratch, delta_scratch), # type: ignore - (plgpu.Barrier(1, num_barriers=compute_wgs),) * 4 # type: ignore + (plgpu.Barrier(num_barriers=compute_wgs),) * 4 # type: ignore ], compiler_params=plgpu.CompilerParams(approx_math=True), grid=(batch_size, num_q_tiles, num_q_heads), @@ -608,7 +608,7 @@ def compute_dk(acc_ref): out_shape=[out_shape_kv, out_shape_kv], scratch_shapes=[ (k_scratch, v_scratch), # type: ignore - (plgpu.Barrier(1, num_barriers=compute_wgs),) * 2 # type: ignore + (plgpu.Barrier(num_barriers=compute_wgs),) * 2 # type: ignore ], compiler_params=plgpu.CompilerParams(approx_math=True), grid=(batch_size, num_kv_tiles, num_q_heads), @@ -776,7 +776,7 @@ def compute_pv(acc_ref): out_shape=out_shape, scratch_shapes=( tuple(smem_scratch), # type: ignore - plgpu.Barrier(1, num_barriers=compute_wgs), # type: ignore + plgpu.Barrier(num_barriers=compute_wgs), # type: ignore plgpu.Barrier(num_arrivals=compute_wgs),), # type: ignore compiler_params=plgpu.CompilerParams( approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7430fcc1ca53..608cfcba2465 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -408,7 +408,7 @@ def test_inline_mgpu(self): dtype, transforms=transforms, ), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), ) @@ -540,7 +540,7 @@ def test_copy_gmem_to_smem(self, indexer): in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM((256,), jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -586,7 +586,7 @@ def 
test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers): in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM(shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], grid=(1,), ) @@ -617,7 +617,7 @@ def test_gmem_to_smem_with_multiple_smem_indexers(self): in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM(x.shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], ) def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -672,7 +672,7 @@ def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM((128,), jnp.float32), - plgpu.Barrier(num_arrivals=1, num_barriers=4), + plgpu.Barrier(num_barriers=4), ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -713,7 +713,7 @@ def kernel(x_ref, o_ref, barrier_ref): out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x) @@ -736,7 +736,7 @@ def body(tmp_ref): out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x * 2) @@ -756,7 +756,7 @@ def body(tmp_ref): kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x * 2) @@ -783,7 +783,7 @@ def kernel(x_ref, o_ref, barrier_ref): out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0)) @@ -827,7 +827,7 @@ def kernel(x_ref, o_ref, barrier_ref): out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128) xt = x.transpose((1, 0, 2)) @@ -847,7 +847,7 @@ def inner_body(scratch_ref): plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] 
+ 1 pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32)) - pl.run_scoped(body, plgpu.Barrier(num_arrivals=1)) + pl.run_scoped(body, plgpu.Barrier()) x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) @@ -1016,7 +1016,7 @@ def test_get_swap_with_transforms(self, *transforms): out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), scratch_shapes=[ plgpu.SMEM(shape, jnp.int32, transforms=tuple(transforms)), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ] ) def kernel(x_ref, o_ref, scratch_ref, barrier_ref): @@ -1089,7 +1089,7 @@ def scoped_kernel(barrier_ref): plgpu.barrier_wait(barrier_ref) def branch(): - pl.run_scoped(scoped_kernel, plgpu.Barrier(num_arrivals=1)) + pl.run_scoped(scoped_kernel, plgpu.Barrier()) jax.lax.cond(x_ref_gmem[0] % 2 == 0, branch, branch) @@ -1698,7 +1698,7 @@ def kernel(x_gmem, o_gmem): plgpu.SMEM(shape, large_ty, transforms=(tiling, large_swizzle)), plgpu.SMEM(shape, small_ty, transforms=(tiling, small_swizzle)) ), - plgpu.Barrier(1, num_barriers=1), + plgpu.Barrier(num_barriers=1), ) def scoped_kernel(x_gmem, o_gmem, aliased_ref, barrier): @@ -1780,7 +1780,7 @@ def body(x_ref, y_ref, barrier): in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.SMEM), out_shape=x, - scratch_shapes=[plgpu.Barrier(1)], + scratch_shapes=[plgpu.Barrier()], )(x) except: # assertRaisesRegex raises does not let us match the traceback. @@ -1921,7 +1921,7 @@ def _(): plgpu.wait_smem_to_gmem(0) pl.run_scoped(scope, smem_ref=plgpu.SMEM((32, 32), jnp.float32), - tma_barrier=plgpu.Barrier(num_arrivals=1)) + tma_barrier=plgpu.Barrier()) x = jax.random.uniform(jax.random.key(42), (64, 32), jnp.float32) result = kernel(x) np.testing.assert_array_equal(result, x[32:64]) @@ -2345,7 +2345,7 @@ def test_tmem(self): plgpu.TMEM((128, 128), jnp.float32), plgpu.TMEM((128, 128), jnp.float32), plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], num_threads=1, thread_name="x", @@ -2416,7 +2416,7 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, scratch_shapes = [ plgpu.TMEM(shape, jnp.float32, packed=False), plgpu.SMEM(shape, dtype, transforms=transforms), - plgpu.Barrier(num_arrivals=1, for_tensor_core=True), + plgpu.Barrier(for_tensor_core=True), ] if lhs_tmem: scratch_shapes.append(plgpu.TMEM(shape, dtype, packed=True)) @@ -2478,8 +2478,8 @@ def kernel(a_gmem, b_gmem, out_gmem): b_smem=plgpu.SMEM(_rhs_shape, dtype, transforms=transforms), acc_tmem=plgpu.TMEM(_acc_shape, jnp.float32, collective=True), scratch_smem=plgpu.SMEM(_acc_shape, dtype, transforms=transforms), - tma_barrier=plgpu.Barrier(num_arrivals=1), - mma_barrier=plgpu.Barrier(num_arrivals=1, for_tensor_core=True), + tma_barrier=plgpu.Barrier(), + mma_barrier=plgpu.Barrier(for_tensor_core=True), cluster_barrier=plgpu.ClusterBarrier(collective_axes=("x",)), ) def _scoped(a_smem, b_smem, @@ -2546,7 +2546,7 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): scratch_shapes = [ plgpu.TMEM(shape, jnp.float32, packed=False), plgpu.SMEM(shape, dtype, transforms=transforms), - plgpu.Barrier(num_arrivals=1, num_barriers=2, for_tensor_core=True), + plgpu.Barrier(num_barriers=2, for_tensor_core=True), ] f = self.pallas_call( kernel, @@ -2612,7 +2612,7 @@ def kernel(x_gmem, o_gmem): functools.partial(scoped_kernel, x_gmem, o_gmem), plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), - 
plgpu.Barrier(1, num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=max_concurrent_steps), ) def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier): From c13de5cb83c33c7c4d9a3bbbdfd35478e5bc91eb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 27 May 2025 04:49:58 -0700 Subject: [PATCH 1377/1769] Move jax/_src/custom_derivatives.py to its own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This was unblocked by moving ad, batching, and custom_transpose to their own rules in prior changes. It required one small code refactoring: moving an effects registration to the location where the effect is defined. PiperOrigin-RevId: 763736189 --- jax/BUILD | 26 +++++++++++++++++++++++++- jax/_src/custom_derivatives.py | 3 --- jax/_src/lax/lax.py | 1 + jax/_src/pallas/fuser/BUILD | 1 + jax/_src/pallas/mosaic/BUILD | 1 + jax/_src/pallas/triton/BUILD | 1 + jax/extend/BUILD | 1 + tests/BUILD | 1 + 8 files changed, 31 insertions(+), 4 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index a236cd206a55..396e6fdf6ed4 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -303,7 +303,6 @@ py_library_providing_imports_info( "_src/callback.py", "_src/checkify.py", "_src/custom_batching.py", - "_src/custom_derivatives.py", "_src/custom_partitioning.py", "_src/custom_partitioning_sharding_rule.py", "_src/debugging.py", @@ -388,6 +387,7 @@ py_library_providing_imports_info( ":core", ":custom_api_util", ":custom_dce", + ":custom_derivatives", ":custom_transpose", ":deprecations", ":dtypes", @@ -613,6 +613,30 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "custom_derivatives", + srcs = ["_src/custom_derivatives.py"], + deps = [ + ":ad", + ":ad_util", + ":api_util", + ":batching", + ":config", + ":core", + ":custom_api_util", + ":custom_transpose", + ":dtypes", + ":effects", + ":mlir", + ":partial_eval", + ":state_types", + ":traceback_util", + ":tree_util", + ":util", + ":xla", + ], +) + pytype_strict_library( name = "custom_transpose", srcs = ["_src/custom_transpose.py"], diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index d76d145fd0a6..2a09665f6285 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -40,7 +40,6 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.interpreters.batching import not_mapped -from jax._src.lax import lax from jax._src.tree_util import ( tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves, tree_flatten_with_path, @@ -410,8 +409,6 @@ def jvp(*xs): return [*out_primals, *out_tangents] return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info) -effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) - custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a49c27d06eee..a9d81c684297 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8188,6 +8188,7 @@ class InOutFeedEffect(effects.Effect): infeed_effect = InOutFeedEffect() outfeed_effect = InOutFeedEffect() +effects.custom_derivatives_allowed_effects.add_type(InOutFeedEffect) def infeed(token, shape=None, partitions=None): """Consumes an infeed value of `shape` from the host. Experimental. 
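# A condensed view of the code refactoring this commit describes (recapping
# the two hunks above, not introducing new code):
#
#   Before, custom_derivatives.py imported lax solely to register the effect:
#     from jax._src.lax import lax
#     effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)
#
#   After, lax.py registers it right where InOutFeedEffect is defined, so the
#   new custom_derivatives BUILD rule carries no dependency on lax.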
diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index a62a9937d91d..a4c3402f5309 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -50,6 +50,7 @@ pytype_strict_library( "//jax:ad_util", "//jax:api_util", "//jax:core", + "//jax:custom_derivatives", "//jax:partial_eval", "//jax:tree_util", "//jax:util", diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index fdd3a56ac7c8..83525f11d3cf 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -103,6 +103,7 @@ py_library( "//jax", "//jax:ad_util", "//jax:core", + "//jax:custom_derivatives", "//jax:dtypes", "//jax:mesh", "//jax:mlir", diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 2b8ee4eaa8f2..acbc11a60039 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -63,6 +63,7 @@ pytype_strict_library( "//jax:api_util", "//jax:config", "//jax:core", + "//jax:custom_derivatives", "//jax:mlir", "//jax:partial_eval", "//jax:source_info_util", diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 06fb8e671120..6dc5d7d76311 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -46,6 +46,7 @@ py_library_providing_imports_info( "//jax:ad", "//jax:ad_util", "//jax:core", + "//jax:custom_derivatives", ], ) diff --git a/tests/BUILD b/tests/BUILD index 6fac61933d59..e3672eb73f48 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -49,6 +49,7 @@ jax_multiplatform_test( srcs = ["custom_api_test.py"], shard_count = 10, deps = [ + "//jax:custom_derivatives", "//jax:experimental", ] + py_deps([ "absl/testing", From 8124cb64dc7329fa348da82a0e92f0731dccae0a Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 27 May 2025 04:55:58 -0700 Subject: [PATCH 1378/1769] [Pallas] Require parallel dimensions to form a prefix of the grid in TPU interpret mode. Since dimensions with parallel semantics must now appear as the leading dimensions of the grid, this CL also makes the sequential iteration over cores in the simulation never re-visit a core after the simulation has moved on to the next core. This enables the simulation to correctly omit loads and stores of kernel buffers if the same (slice of a) buffer is processed by multiple kernel invocations on the same core. PiperOrigin-RevId: 763737647 --- jax/_src/pallas/mosaic/interpret.py | 116 ++++--- tests/pallas/tpu_pallas_interpret_test.py | 378 ++++++++++++++++++---- 2 files changed, 385 insertions(+), 109 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 4258e52e6541..f3a165105640 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -86,18 +86,23 @@ class TPUInterpretParams: replaced with arrays all of `jnp.inf`. Additionaly any floating point operands to any operation will be replaced with (arrays of) `jnp.inf`. Default: False. - uninitialized_memory: If "nan", allocated buffers are initialized to - contain all NaNs (or to their maximum possible value for integers). If - "zero", allocated buffers are initialized to all zeros. + uninitialized_memory: If "nan", allocated buffers are initialized to contain + all NaNs (or to their maximum possible value for integers). If "zero", + allocated buffers are initialized to all zeros. Default: "nan". random_seed: Seed for random number generator used during interpretation. Currently random numbers are used to randomize the grid coordinates along dimensions with 'parallel' semantics. Default: None. 
grid_point_recorder: Callback that is invoked by the interpreter for each - grid point in the order in which the grid points are traversed. This is - intended for inspecting the randomization of coordinates along grid - dimensions with 'parallel' semantics. + grid point in the order in which the grid points are traversed. The + callback is invoked with two arguments: + - A tuple of grid coordinates. + - The local core ID of the core that is processing the grid point. + This callback is intended for inspecting + - the randomization of coordinates along grid dimensions with 'parallel' + semantics and + - the mapping of grid points to local (i.e. per-device) cores. Default: None. num_cores_per_device: The number of cores per device. Default: 1. @@ -107,7 +112,9 @@ class TPUInterpretParams: skip_floating_point_ops: bool = False uninitialized_memory: Literal["nan", "zero"] = "nan" random_seed: int | None = None - grid_point_recorder: Callable[[tuple[jnp.int32, ...]], None] | None = None + grid_point_recorder: ( + Callable[[tuple[np.int32, ...], np.int32], None] | None + ) = None num_cores_per_device: int = 1 @@ -1752,11 +1759,45 @@ def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) - def _get_parallel_dim_semantics( compiler_params: dict[str, Any], num_dimensions_in_grid: int, ) -> tuple[bool, ...]: - """Returns a tuple of booleans indicating whether the corresponding dimension in the grid is parallel.""" + """Returns a tuple indicating which grid dimensions have parallel semantics. + + Args: + compiler_params: Representation of a `mosaic_core.TPUCompilerParams` object + as a dictionary. + num_dimensions_in_grid: The number of dimensions in the grid. + + Returns: + A tuple of booleans where the entry at index `i` is `True` precisely if the + `i`-th dimension in the grid has parallel semantics. + + Raises: + ValueError: If the dimensions with parallel semantics do not form a prefix + of the grid. + """ mosaic_params = _get_mosaic_params(compiler_params) if mosaic_params.dimension_semantics is None: return (False,) * num_dimensions_in_grid - return tuple(ds == 'parallel' for ds in mosaic_params.dimension_semantics) + result = tuple(ds == 'parallel' for ds in mosaic_params.dimension_semantics) + for ds0, ds1 in zip(result[:-1], result[1:]): + if ds1 and not ds0: + raise ValueError( + 'Dimensions with parallel semantics must form a prefix of the grid.' + ) + return result + + +def _get_parallel_subgrid_size( + parallel_semantics_per_dim: tuple[bool, ...], grid: tuple[int, ...] +) -> int: + """Returns the size of the subgrid along the parallel dimensions.""" + return functools.reduce( + lambda x, y: x * y, + ( + dim_size if parallel_dim else 1 + for dim_size, parallel_dim in zip(grid, parallel_semantics_per_dim) + ), + 1, + ) _GridPointCoordinatesPerDim = tuple[Array, ...] 
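To make the accounting concrete, here is a sketch of how these helpers split
the iteration space across cores (illustrative numbers only, chosen to match
the tests further below; variable names abbreviated):

# grid=(4, 4) with dimension_semantics=('parallel', 'arbitrary') and
# num_cores_per_device=2:
parallel_subgrid_size = 4                                   # parallel points
num_iterations = 16                                         # whole grid
points_per_core = (parallel_subgrid_size + 2 - 1) // 2      # = 2, rounded up
iters_per_point = num_iterations // parallel_subgrid_size   # = 4
iters_per_core = points_per_core * iters_per_point          # = 8
# Core 0 executes iterations 0..7 and core 1 executes 8..15; the core for a
# given iteration is iteration_idx // iters_per_core, so a core is never
# revisited once the simulation moves on to the next one.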
@@ -1836,24 +1877,6 @@ def _get_grid_point( grid_point.append(li if jnp.size(coords) == 0 else coords[li]) return jnp.array(grid_point, dtype=np.int32) - -def _get_next_local_core_id( - local_core_id: int, - parallel_semantics_per_dim: tuple[bool, ...], - grid_point: Array, - next_grid_point: Array, - interpret_params: TPUInterpretParams, -) -> int: - delta = next_grid_point - grid_point - assert delta.shape == (len(parallel_semantics_per_dim),) - parallel_semantics_per_dim = jnp.array(parallel_semantics_per_dim) - deltas_along_parallel_dims = jnp.where(parallel_semantics_per_dim, delta, 0) - return jax.lax.cond( - jnp.any(deltas_along_parallel_dims), - lambda: (local_core_id + 1) % interpret_params.num_cores_per_device, - lambda: local_core_id, - ) - def _uninitialized_value(shape, dtype, interpret_params): if interpret_params.uninitialized_memory == 'nan': if jnp.issubdtype(dtype, jnp.floating): @@ -2078,13 +2101,28 @@ def interpret_pallas_call( # Base case is always one iteration when grid is () num_iterations = 1 - parallel_semantics_per_dim = _get_parallel_dim_semantics( - compiler_params, len(grid) - ) randomized_grid_coordinates = _get_randomized_grid_coordinates( grid, compiler_params, interpret_params.random_seed # type: ignore[arg-type] ) + parallel_dim_semantics = _get_parallel_dim_semantics( + compiler_params, len(grid) + ) + parallel_subgrid_size = _get_parallel_subgrid_size( + parallel_dim_semantics, grid # type: ignore[arg-type] + ) + num_points_in_parallel_subgrid_per_core = ( + parallel_subgrid_size + interpret_params.num_cores_per_device - 1 + ) // interpret_params.num_cores_per_device # We round up here. + num_iterations_per_point_in_parallel_subgrid = ( + # This is evenly divisible. + num_iterations // parallel_subgrid_size # type: ignore[operator] + ) + num_iterations_per_core = ( + num_points_in_parallel_subgrid_per_core + * num_iterations_per_point_in_parallel_subgrid + ) + def _get_local_grid_env(loop_idx): if grid_mapping.local_grid_env is not None: return grid_mapping.local_grid_env(loop_idx, grid) @@ -2153,20 +2191,20 @@ def body( cur_start_indices, ) = carry if interpret_params.grid_point_recorder is not None: - callback.io_callback(interpret_params.grid_point_recorder, (), grid_point) + callback.io_callback( + interpret_params.grid_point_recorder, + (), + grid_point, + cur_local_core_id, + ) + + next_local_core_id = (iteration_idx + 1) // num_iterations_per_core with pallas_core.grid_env(_get_local_grid_env(loop_idx)): next_loop_idx = _get_next_indices(grid, loop_idx) next_grid_point = _get_grid_point( next_loop_idx, randomized_grid_coordinates ) - next_local_core_id = _get_next_local_core_id( - cur_local_core_id, - parallel_semantics_per_dim, - grid_point, - next_grid_point, - interpret_params, - ) next_start_indices = [ _compute_start_indices( bm, @@ -2178,8 +2216,8 @@ def body( ) for bm in grid_mapping.block_mappings ] - # Copy slices of the input to the kernel buffers. + # Copy slices of the input to the kernel buffers. def _store_slice_to_kernel_input(index, input_var): # Copy from the HBM buffer for the pallas_call input to the kernel # input buffer. diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 014667d5f948..47d4ba3e1acf 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -18,6 +18,8 @@ contains only tests that do not use shard_map. 
""" +from collections.abc import Callable +import dataclasses import functools from absl.testing import absltest @@ -59,11 +61,18 @@ def num_stores(self): return self._num_stores +@dataclasses.dataclass(frozen=True) +class ProcessedGridPoint(): + """Represents a grid point and the ID of the core that has processed it.""" + grid_point: tuple[int, ...] + core_id: int + + class GridPointRecorderContext(object): - """Records grid points in the order in which they are traversed.""" + """Records grid points in the order in which they are procsessed.""" def __init__(self): - self._grid_points = [] + self._grid_points: list[ProcessedGridPoint] = [] def __enter__(self): return self @@ -71,14 +80,17 @@ def __enter__(self): def __exit__(self, ty, value, traceback): ... - def get_recorder(self): - def _recorder(grid_point): - self._grid_points.append(grid_point) + def get_recorder(self) -> Callable[[tuple[np.int32, ...], np.int32], None]: + def _recorder(grid_point, core_id): + processed_grid_point = ProcessedGridPoint( + tuple(int(coord) for coord in grid_point), int(core_id) + ) + self._grid_points.append(processed_grid_point) return _recorder @property - def grid_points(self): + def grid_points(self) -> list[ProcessedGridPoint]: return self._grid_points @@ -359,7 +371,7 @@ def kernel(s_ref, o_ref): s_ref[0] = s + 1 o_ref[:] = jax.lax.full_like(o_ref, s) - def kernel_call_dimensions_arbitrary_parallel(s, grid_point_recorder): + def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): return pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), @@ -370,13 +382,13 @@ def kernel_call_dimensions_arbitrary_parallel(s, grid_point_recorder): random_seed=12345, grid_point_recorder=grid_point_recorder ), compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=('arbitrary', 'parallel') + dimension_semantics=('parallel', 'arbitrary') ), )(s) with GridPointRecorderContext() as grid_point_recorder: result = jax.jit( - kernel_call_dimensions_arbitrary_parallel, static_argnums=1 + kernel_call_dimensions_parallel_arbitrary, static_argnums=1 )( jnp.zeros((1,), jnp.int32), grid_point_recorder.get_recorder(), @@ -384,85 +396,55 @@ def kernel_call_dimensions_arbitrary_parallel(s, grid_point_recorder): np.testing.assert_allclose( result[::8, ::128], [ - [ 2.0, 3.0, 0.0, 1.0], - [ 6.0, 7.0, 4.0, 5.0], - [10.0, 11.0, 8.0, 9.0], - [14.0, 15.0, 12.0, 13.0], + [ 8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [ 0.0, 1.0, 2.0, 3.0], + [ 4.0, 5.0, 6.0, 7.0], ], ) - np.testing.assert_array_equal( + self.assertListEqual( grid_point_recorder.grid_points, [ - [0, 2], - [0, 3], - [0, 0], - [0, 1], - [1, 2], - [1, 3], - [1, 0], - [1, 1], - [2, 2], - [2, 3], - [2, 0], - [2, 1], - [3, 2], - [3, 3], - [3, 0], - [3, 1], + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 0), + ProcessedGridPoint((0, 1), 0), + ProcessedGridPoint((0, 2), 0), + ProcessedGridPoint((0, 3), 0), + ProcessedGridPoint((1, 0), 0), + ProcessedGridPoint((1, 1), 0), + ProcessedGridPoint((1, 2), 0), + ProcessedGridPoint((1, 3), 0), ], ) - def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): + def test_dimensions_arbitrary_parallel_raises(self): + def kernel_call(s): + def kernel(s_ref, o_ref): + s = s_ref[0] + o_ref[0] = s + return pl.pallas_call( kernel, 
out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), grid=(4, 4), in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), - interpret=mosaic_interpret.TPUInterpretParams( - random_seed=12345, grid_point_recorder=grid_point_recorder - ), + interpret=mosaic_interpret.TPUInterpretParams(random_seed=12345), compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=('parallel', 'arbitrary') + dimension_semantics=('arbitrary', 'parallel') ), )(s) - with GridPointRecorderContext() as grid_point_recorder: - result = jax.jit( - kernel_call_dimensions_parallel_arbitrary, static_argnums=1 - )( + with self.assertRaises(ValueError): + jax.jit(kernel_call)( jnp.zeros((1,), jnp.int32), - grid_point_recorder.get_recorder(), - ) - np.testing.assert_allclose( - result[::8, ::128], - [ - [ 8.0, 9.0, 10.0, 11.0], - [12.0, 13.0, 14.0, 15.0], - [ 0.0, 1.0, 2.0, 3.0], - [ 4.0, 5.0, 6.0, 7.0], - ], - ) - np.testing.assert_array_equal( - grid_point_recorder.grid_points, - [ - [2, 0], - [2, 1], - [2, 2], - [2, 3], - [3, 0], - [3, 1], - [3, 2], - [3, 3], - [0, 0], - [0, 1], - [0, 2], - [0, 3], - [1, 0], - [1, 1], - [1, 2], - [1, 3], - ], ) def test_dynamic_parallel_dimension_raises(self): @@ -583,6 +565,262 @@ def kernel(x_ref, o_ref, vmem_ref): self.assertFalse(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, 2.0 * x) + def test_parallel_dimension_and_multiple_cores(self): + def kernel(s_ref, o_ref): + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = jax.lax.full_like(o_ref, s) + + def kernel_call(s, num_cores_per_device, grid_point_recorder): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=mosaic_interpret.TPUInterpretParams( + random_seed=12345, + num_cores_per_device=num_cores_per_device, + grid_point_recorder=grid_point_recorder, + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel', 'arbitrary') + ), + )(s) + + with self.subTest('num_cores_per_device=1'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.int32), 1, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 1 - 1) // 1 = 4 + # num_iterations_per_core = 4 * (16 // 4) = 16 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 0), + ProcessedGridPoint((0, 1), 0), + ProcessedGridPoint((0, 2), 0), + ProcessedGridPoint((0, 3), 0), + ProcessedGridPoint((1, 0), 0), + ProcessedGridPoint((1, 1), 0), + ProcessedGridPoint((1, 2), 0), + ProcessedGridPoint((1, 3), 0), + ], + ) + + with self.subTest('num_cores_per_device=2'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.int32), 2, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + 
[4.0, 5.0, 6.0, 7.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 2 - 1) // 2 = 2 + # num_iterations_per_core = 2 * (16 // 4) = 8 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 1), + ProcessedGridPoint((0, 1), 1), + ProcessedGridPoint((0, 2), 1), + ProcessedGridPoint((0, 3), 1), + ProcessedGridPoint((1, 0), 1), + ProcessedGridPoint((1, 1), 1), + ProcessedGridPoint((1, 2), 1), + ProcessedGridPoint((1, 3), 1), + ], + ) + + with self.subTest('num_cores_per_device=3'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.int32), 3, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 3 - 1) // 3 = 2 + # num_iterations_per_core = 2 * (16 // 4) = 8 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 1), + ProcessedGridPoint((0, 1), 1), + ProcessedGridPoint((0, 2), 1), + ProcessedGridPoint((0, 3), 1), + ProcessedGridPoint((1, 0), 1), + ProcessedGridPoint((1, 1), 1), + ProcessedGridPoint((1, 2), 1), + ProcessedGridPoint((1, 3), 1), + ], + ) + + with self.subTest('num_cores_per_device=4'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.int32), 4, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 4 - 1) // 4 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + with self.subTest('num_cores_per_device=5'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.int32), 5, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size 
= 4 + # num_parallel_points_per_core = (4 + 5 - 1) // 5 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + with self.subTest('num_cores_per_device=6'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.int32), 6, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 6 - 1) // 6 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 71edce424c8680825e8e60d741c7e48c184742a7 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Tue, 27 May 2025 06:01:08 -0700 Subject: [PATCH 1379/1769] Skip //third_party/py/jax/tests/pallas:mgpu_ragged_dot_test_gpu_h100 on ASAN. PiperOrigin-RevId: 763756072 --- tests/pallas/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index e4a308c2f10b..8be899123de0 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -832,6 +832,9 @@ jax_multiplatform_test( "gpu_h100", ], shard_count = 12, + tags = [ + "noasan", # Times out. + ], deps = [ "//jax:pallas", "//jax:pallas_experimental_gpu_ops", From a57b4a1583e9b67f2520b76f7e5c466f3c5ff99b Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Tue, 27 May 2025 06:17:31 -0700 Subject: [PATCH 1380/1769] #sdy remove redundant call to sdy-round-trip-export in JAX export. We already call `xla::sdy::addSdyRoundTripExportPipeline` in `xla::SerializeUsingVersionedStablehlo` so no need for this anymore. 
PiperOrigin-RevId: 763762358 --- jax/_src/export/_export.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 189818541a2c..b390574c0a79 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -694,7 +694,7 @@ def _export_lowered( shardy_enabled = _jax.sdy.lowered_with_shardy( mlir.module_to_bytecode(mlir_module)) - mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled) + mlir_module_serialized = _module_to_bytecode(mlir_module) # Figure out the result types and shapes if "global_out_avals" in lowering.compile_args: @@ -808,12 +808,8 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: calling_convention_version=version, _get_vjp=_get_exported_vjp) -def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: - if shardy_enabled: - mlir_str = _jax.sdy.sdy_round_trip_export_pipeline( - mlir.module_to_bytecode(module)) - else: - mlir_str = mlir.module_to_bytecode(module) +def _module_to_bytecode(module: ir.Module) -> bytes: + mlir_str = mlir.module_to_bytecode(module) # `target_version` is used to manage situations when a StableHLO producer # and a StableHLO consumer were built using different versions of StableHLO. # From 487eeb4c0fa518b4055f2d154ffd49153809e845 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 27 May 2025 06:35:27 -0700 Subject: [PATCH 1381/1769] [Mosaic GPU] Add tests for the Blackwell matmul kernel Just to give us extra confidence while we make changes. PiperOrigin-RevId: 763767275 --- jax/experimental/mosaic/gpu/examples/BUILD | 9 +++ .../mosaic/gpu/examples/matmul_blackwell.py | 44 +++++----- tests/mosaic/BUILD | 1 + tests/mosaic/gpu_test.py | 3 +- tests/mosaic/matmul_test.py | 80 ++++++++++++++++++- 5 files changed, 113 insertions(+), 24 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index fe1a7e9180ac..b24c38b34235 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -39,6 +39,15 @@ py_library( ], ) +py_library( + name = "matmul_blackwell", + srcs = ["matmul_blackwell.py"], + deps = [ + "//jax", + "//jax:mosaic_gpu", + ], +) + py_library( name = "flash_attention", srcs = ["flash_attention.py"], diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index f771c8bc1ef1..3653d9be8d8d 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -41,7 +41,8 @@ def bytecount(shape, dtype): def build_kernel( - m, n, k, + m, k, n, + dtype: jnp.dtype, tile_m: int = 128, tile_n: int = 128, grid_tile_m: int = 1, @@ -51,12 +52,15 @@ def build_kernel( i1 = ir.IntegerType.get_signless(1) i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() + if jnp.dtype(dtype).itemsize != 2: + raise NotImplementedError(f"Only tested with 16-bit dtypes, but got {dtype}") + if tile_m != 128: + raise NotImplementedError(f"Only tile_m=128 supported, but got {tile_m}") swizzle = 128 - swizzle_elems = tile_k = swizzle // 2 + swizzle_elems = tile_k = 8 * swizzle // jnp.finfo(dtype).bits tiling = (8, swizzle_elems) - in_dtype = jnp.float16 k_loop_iter = k // tile_k max_concurrent_steps = min(max_concurrent_steps, k_loop_iter) @@ -74,7 +78,7 @@ def build_kernel( raise ValueError(f"{n=} must be divisible by {tile_n=}") if k % tile_k != 0: raise ValueError(f"{k=} must be 
divisible by {tile_k=}") - if (m // tile_m) % grid_tile_m: + if (m // block_tile_m) % grid_tile_m: raise ValueError(f"{m=} // {tile_m=} must be divisible by {grid_tile_m=}") def kernel(ctx, a, b, d, smem): @@ -83,8 +87,12 @@ def kernel(ctx, a, b, d, smem): warp_idx = mgpu.warp_idx(sync=True) is_warp_leader = nvvm.elect_sync(i1) - is_leader_of = lambda i: arith.andi(arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader) - is_leader_block = arith.cmpi(arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index)) + is_leader_of = lambda i: arith.andi( + arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader + ) + is_leader_block = arith.cmpi( + arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index) + ) m_idx = arith.addi( gpu.block_id(gpu.Dimension.x), @@ -96,7 +104,6 @@ def kernel(ctx, a, b, d, smem): m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) n_start = arith.muli(n_idx, c(tile_n,index)) - with mgpu.when(is_leader_of(TMA_WARP)): @mgpu.fori(c(k_loop_iter, index), None) def _tma_body(ki, _): @@ -107,7 +114,7 @@ def _tma_body(ki, _): full_barrier = ab_full_barriers[slot] with mgpu.when(is_leader_block): full_barrier.arrive_expect_tx( - bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype) + bytecount((tile_m, tile_k), dtype) + bytecount((tile_n, tile_k), dtype) ) k_start = arith.muli(ki, c(tile_k, index)) common_args = dict( @@ -162,7 +169,8 @@ def _mma_body(ki, accumulate): gpu.barrier() mma_done_barrier.wait(for_tensor_core=True) - acc.load().astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128) + final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) + final_acc.store_tiled(d_smem, swizzle=128) mgpu.commit_shared() ctx.async_copy( src_ref=d_smem, @@ -176,14 +184,14 @@ def _mma_body(ki, accumulate): compute_buffers = ( jax.ShapeDtypeStruct( mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), tiling), - jnp.float16), + dtype), jax.ShapeDtypeStruct( - mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), - jnp.float16), + mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), + dtype), ) epilogue_buffer = jax.ShapeDtypeStruct( mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)), - jnp.float16) + dtype) smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer]) smem = ( smem_buffers, @@ -196,10 +204,10 @@ def _mma_body(ki, accumulate): (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)), (128, 1, 1), ( - jax.ShapeDtypeStruct((m, k), jnp.float16), - jax.ShapeDtypeStruct((n, k), jnp.float16), + jax.ShapeDtypeStruct((m, k), dtype), + jax.ShapeDtypeStruct((n, k), dtype), ), - jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((m, n), dtype), smem, cluster=(2 if collective else 1, 1, 1), ) @@ -236,7 +244,7 @@ def main(unused_argv): continue try: with mlir.make_ir_context(), ir.Location.unknown(): - f = build_kernel(m, n, k, **kwargs) + f = build_kernel(m, k, n, jnp.float16, **kwargs) _, runtime = profiler.measure(f)(a, b) except ValueError as e: if "Mosaic GPU kernel exceeds available shared memory" not in str(e): @@ -251,7 +259,7 @@ def main(unused_argv): raise ValueError("No valid configuration found") with mlir.make_ir_context(), ir.Location.unknown(): - d, runtime = profiler.measure(build_kernel(m, n, k, **best_kwargs))(a, b) + d, runtime = profiler.measure(build_kernel(m, k, n, jnp.float16, **best_kwargs))(a, b) d_ref, ref_runtime = profiler.measure(jax.jit(lambda a, b: a @ 
b.T))(a, b) tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12 diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 9b6a7f79d099..ffaa0c3c843f 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -126,6 +126,7 @@ jax_multiplatform_test( deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", + "//jax/experimental/mosaic/gpu/examples:matmul_blackwell", ] + py_deps([ "absl/testing", "numpy", diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 42a3f0fc83c1..0c79d26782c7 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -228,8 +228,7 @@ def setUp(self): super().setUp() self.prng = np.random.default_rng(1234) self.context = mlir.make_ir_context() - if mgpu_dialect is not None: - mgpu_dialect.register_dialect(self.context) + mgpu_dialect.register_dialect(self.context) self.enter_context(config.traceback_filtering("off")) self.enter_context(self.context) self.enter_context(ir.Location.unknown()) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index 9634718d2d44..680e699c8972 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -19,7 +19,11 @@ from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member import jax.numpy as jnp +import numpy as np import hypothesis as hp import hypothesis.strategies as hps @@ -31,6 +35,7 @@ matmul = None else: from jax.experimental.mosaic.gpu.examples import matmul + from jax.experimental.mosaic.gpu.examples import matmul_blackwell config.parse_flags_with_absl() @@ -53,9 +58,13 @@ def setUp(self): super().setUp() if matmul is None: self.skipTest("Mosaic GPU not available.") - if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_equal("9.0")): - self.skipTest("Only works on GPU with capability sm90a") + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Test needs a GPU device") + self.context = mlir.make_ir_context() + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) @parameterized.named_parameters( (f"_shard{i}", i) for i in range(5) @@ -63,7 +72,10 @@ def setUp(self): @seed_hypothesis @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug @hp.given(hps.data()) - def test_matmul(self, data): + def test_matmul_sm90(self, data): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Only works on GPU with capability sm90a") + in_dtype = data.draw( hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), label="in_dtype", @@ -122,6 +134,66 @@ def test_matmul(self, data): hp.assume(False) raise e + @parameterized.named_parameters( + # TODO(apaszke): Increase shard count once we have more B200s in CI. 
+ (f"_shard{i}", i) for i in range(1) + ) + @seed_hypothesis + @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug + @hp.given(hps.data()) + def test_matmul_sm100(self, data): + if not jtu.is_cuda_compute_capability_equal("10.0"): + self.skipTest("Only works on GPU with capability sm100a") + + dtype = data.draw( + hps.sampled_from([jnp.float16, jnp.bfloat16]), + label="dtype", + ) + m, n, k = ( + data.draw(hps.sampled_from([128, 256, 512, 2048, 8192]), label=d) for d in "mnk" + ) + max_concurrent_steps = data.draw( + hps.integers(2, 5), label="max_concurrent_steps" + ) + collective = data.draw(hps.booleans(), label="collective") + num_ctas = 2 if collective else 1 + hp.assume(not (m == 128 and collective)) # Too small for collective MMA. + tile_m = data.draw( + hps.sampled_from([t for t in [128] if t * num_ctas <= m]), label="tile_m" + ) + tile_n = data.draw( + hps.sampled_from([t for t in [64, 128, 256] if t * num_ctas <= n]), label="tile_n" + ) + grid_m = m // (num_ctas * tile_m) + grid_tile_m = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_m") + hp.assume(grid_m % grid_tile_m == 0) + + try: + kernel = matmul_blackwell.build_kernel( + m, + k, + n, + dtype=dtype, + tile_m=tile_m, + tile_n=tile_n, + grid_tile_m=grid_tile_m, + max_concurrent_steps=max_concurrent_steps, + collective=collective, + ) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" in str(e): + hp.assume(False) + raise + + ka, kb = jax.random.split(jax.random.key(0), 2) + a = jax.random.normal(key=ka, shape=(m, k), dtype=dtype) + b = jax.random.normal(key=kb, shape=(n, k), dtype=dtype) + out = kernel(a, b) + out_ref = jnp.dot(a, b.T) + np.testing.assert_allclose( + out, out_ref, atol=1e-3, rtol=1e-3 if k < 512 else 1e-2 + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From f5ffd7fc7c717b9e1aa91d3887bfb1c93189c73e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 27 May 2025 07:09:43 -0700 Subject: [PATCH 1382/1769] [Mosaic GPU] Fix missing symbol errors in OSS collective kernels We sometimes access NVSHMEM functions from the host code too, which means we should include the NVSHMEM host library in the context of the ExecutionEngine. PiperOrigin-RevId: 763777731 --- jaxlib/mosaic/gpu/custom_call.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 7df185cffaf1..54fef13a8521 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -555,9 +555,12 @@ absl::StatusOr, bool>> Compile( return absl::InternalError("Pass pipeline failed"); } - llvm::SmallVector runtime_lib; - if (const char* lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { - runtime_lib.emplace_back(lib_path); + llvm::SmallVector runtime_libs; + if (const char* runtime_lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { + runtime_libs.emplace_back(runtime_lib_path); + } + if (const char* nvshmem_path = getenv("MOSAIC_GPU_NVSHMEM_SO_PATH")) { + runtime_libs.emplace_back(nvshmem_path); } // Create a transformer to run all LLVM optimization passes at the // specified optimization level. 
@@ -566,7 +569,7 @@ absl::StatusOr, bool>> Compile( mlir::ExecutionEngineOptions options; options.transformer = transformer; options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; - options.sharedLibPaths = runtime_lib; + options.sharedLibPaths = runtime_libs; auto maybe_execution_engine = mlir::ExecutionEngine::create(module, options); if (!maybe_execution_engine) { return absl::InternalError("Failed to compile kernel"); From ee727f98746e77a0cd7f8ad5b63debfd0ba00053 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 27 May 2025 07:26:06 -0700 Subject: [PATCH 1383/1769] [Mosaic GPU][NFC] Refactor the body of the matmul kernel This will make it much simpler to make the kernel persistent. PiperOrigin-RevId: 763782577 --- .../mosaic/gpu/examples/matmul_blackwell.py | 163 +++++++++--------- 1 file changed, 84 insertions(+), 79 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 3653d9be8d8d..929c7c498986 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -94,92 +94,97 @@ def kernel(ctx, a, b, d, smem): arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index) ) + # This function executes the kernel for a single output tile. + def compute_output(block_m_start, n_start): + """Compute and store a single output tile.""" + # All blocks in the cluster share the same m_start -- align it! + m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) + with mgpu.when(is_leader_of(TMA_WARP)): + @mgpu.fori(c(k_loop_iter, index), None) + def _tma_body(ki, _): + slot = arith.remui(ki, c(max_concurrent_steps, index)) + # TODO(apaszke): Use a predicate instead of a conditional. + with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))): + ab_empty_barriers[slot].wait() + full_barrier = ab_full_barriers[slot] + with mgpu.when(is_leader_block): + full_barrier.arrive_expect_tx( + bytecount((tile_m, tile_k), dtype) + bytecount((tile_n, tile_k), dtype) + ) + k_start = arith.muli(ki, c(tile_k, index)) + common_args = dict( + swizzle=swizzle, + barrier=full_barrier, + arrive=False, + predicate=None, + collective=gpu.Dimension.x, + partitioned=0, # Non-contracting dim is always 0. 
+ ) + ctx.async_copy( + src_ref=a, + dst_ref=mgpu.memref_slice(a_smem, slot), + gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), + gmem_transform=mgpu.TileTransform(tiling), + **common_args, + ) + ctx.async_copy( + src_ref=b, + dst_ref=mgpu.memref_slice(b_smem, slot), + gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), + gmem_transform=mgpu.TileTransform(tiling), + **common_args, + ) + + with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): + @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) + def _mma_body(ki, accumulate): + slot = arith.remui(ki, c(max_concurrent_steps, index)) + ab_full_barriers[slot].wait() + tcgen05.mma( + acc, + mgpu.memref_slice(a_smem, slot), + mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), + a_swizzle=swizzle, + b_swizzle=swizzle, + accumulate=accumulate, + collective=collective, + ) + accumulate = arith.constant(i1, 1) + is_last_iter = arith.cmpi( + arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) + ) + barrier_ptr = arith.select( + is_last_iter, + mma_done_barrier.get_ptr(), + ab_empty_barriers[slot].get_ptr(), + ) + tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx) + return accumulate + + gpu.barrier() + mma_done_barrier.wait(for_tensor_core=True) + + final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) + final_acc.store_tiled(d_smem, swizzle=128) + mgpu.commit_shared() + ctx.async_copy( + src_ref=d_smem, + dst_ref=d, + gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), + gmem_transform=mgpu.TileTransform((128, swizzle_elems)), + swizzle=swizzle, + ) + ctx.await_async_copy(0) + m_idx = arith.addi( gpu.block_id(gpu.Dimension.x), arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_m, index)), ) n_idx = gpu.block_id(gpu.Dimension.y) block_m_start = arith.muli(m_idx, c(block_tile_m, index)) - # All blocks in the cluster share the same m_start -- align it! - m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) n_start = arith.muli(n_idx, c(tile_n,index)) - - with mgpu.when(is_leader_of(TMA_WARP)): - @mgpu.fori(c(k_loop_iter, index), None) - def _tma_body(ki, _): - slot = arith.remui(ki, c(max_concurrent_steps, index)) - # TODO(apaszke): Use a predicate instead of a conditional. - with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))): - ab_empty_barriers[slot].wait() - full_barrier = ab_full_barriers[slot] - with mgpu.when(is_leader_block): - full_barrier.arrive_expect_tx( - bytecount((tile_m, tile_k), dtype) + bytecount((tile_n, tile_k), dtype) - ) - k_start = arith.muli(ki, c(tile_k, index)) - common_args = dict( - swizzle=swizzle, - barrier=full_barrier, - arrive=False, - predicate=None, - collective=gpu.Dimension.x, - partitioned=0, # Non-contracting dim is always 0. 
- ) - ctx.async_copy( - src_ref=a, - dst_ref=mgpu.memref_slice(a_smem, slot), - gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform(tiling), - **common_args, - ) - ctx.async_copy( - src_ref=b, - dst_ref=mgpu.memref_slice(b_smem, slot), - gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform(tiling), - **common_args, - ) - - with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): - @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) - def _mma_body(ki, accumulate): - slot = arith.remui(ki, c(max_concurrent_steps, index)) - ab_full_barriers[slot].wait() - tcgen05.mma( - acc, - mgpu.memref_slice(a_smem, slot), - mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), - a_swizzle=swizzle, - b_swizzle=swizzle, - accumulate=accumulate, - collective=collective, - ) - accumulate = arith.constant(i1, 1) - is_last_iter = arith.cmpi( - arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) - ) - barrier_ptr = arith.select( - is_last_iter, - mma_done_barrier.get_ptr(), - ab_empty_barriers[slot].get_ptr(), - ) - tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx) - return accumulate - - gpu.barrier() - mma_done_barrier.wait(for_tensor_core=True) - - final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) - final_acc.store_tiled(d_smem, swizzle=128) - mgpu.commit_shared() - ctx.async_copy( - src_ref=d_smem, - dst_ref=d, - gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), - gmem_transform=mgpu.TileTransform((128, swizzle_elems)), - swizzle=swizzle, - ) - ctx.await_async_copy(0) + # This is not a persistent kernel, so we only process one tile. + compute_output(block_m_start, n_start) compute_buffers = ( jax.ShapeDtypeStruct( From 10cdbb715efda6b30d75c0aa85d1752717beafa8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 27 May 2025 07:29:42 -0700 Subject: [PATCH 1384/1769] Block until ready for PGLE test Before this fix, the test would finish before execution was done, and profiling would thus yield nothing. PiperOrigin-RevId: 763783695 --- tests/pgle_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 8814250ea066..fd55f0a392f0 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -85,7 +85,7 @@ def f(x, y): pgle_profiler = profiler.PGLEProfiler(1, 90) with config.enable_pgle(False): with profiler.PGLEProfiler.trace(pgle_profiler): - compiled(x, y) + jax.block_until_ready(compiled(x, y)) fdo_profile = pgle_profiler.consume_fdo_profile() self.assertIsNotNone(fdo_profile) From fce93d2f829ae20b4be49451fb3511df91627679 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 27 May 2025 10:56:25 -0400 Subject: [PATCH 1385/1769] Fix handling of input None in custom_transpose. --- jax/_src/custom_transpose.py | 16 +++++++++++----- tests/custom_api_test.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 21e607b5bff2..fb125e174122 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -217,7 +217,6 @@ def custom_transpose_transpose_rule( # Consider passing this information to the custom transpose rule? 
res_arg, lin_arg = tree_unflatten(call_in_tree, args) - del lin_arg assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg)) cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct @@ -225,10 +224,17 @@ def custom_transpose_transpose_rule( ct_out = tree_unflatten(out_tree, cts) ct_lin = transpose.call_wrapped(res_arg, ct_out) check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin)) - ct_lin_flat, _ = tree_flatten( - tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None), - is_leaf=lambda x: x is None) - return [None] * len(tree_leaves(res_arg)) + ct_lin_flat + ct_lin = tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None) + + # When the transpose returns None, we treat that as a Zero, except when the + # input is also None. In that case, the cotangent corresponding to that input + # should be dropped. + zero = object() + ct_lin = tree_map(lambda l, ct: zero if ct is None and l is not None else ct, + lin_arg, ct_lin, is_leaf=ad.is_undefined_primal) + + ct_lin_flat, _ = tree_flatten(ct_lin) + return [None] * res_tree.num_leaves + [None if ct is zero else ct for ct in ct_lin_flat] def custom_transpose_lowering(*args, call_jaxpr, **params): diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 9d10b40c6030..bfe391797920 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -3722,6 +3722,20 @@ def gt(x, t): with config.use_direct_linearize(True): self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) + def test_input_none(self): + # ref: https://github.com/jax-ml/jax/issues/29009 + @jax.custom_jvp + def f(x, y): return y + @f.defjvp + def f_jvp(p, t): return f(*p), g(p, t) + + @custom_transpose(jnp.float32(0)) + def g(r, x): return x[1] + @g.def_transpose + def gt(r, t): return None, jnp.zeros_like(r[1]) + + jax.grad(f, argnums=(1,))(None, jnp.float32(2)) # doesn't crash + class CustomDceTest(jtu.JaxTestCase): From 3c926a233d94e8cefe6e802655c195cb91c6f1b4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 27 May 2025 09:18:28 -0700 Subject: [PATCH 1386/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a566a66e53c489f947eb6c04fe44205013250922. PiperOrigin-RevId: 763822788 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 1bce754dbd3a..0032e491f329 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "cb67f2f7ce4787f63f5fc80dc5c30cd3dee8f4e3" -XLA_SHA256 = "483000398e9c8dc090e5ed493286f91b8f8160793bb290dbe736440eb55e0382" +XLA_COMMIT = "a566a66e53c489f947eb6c04fe44205013250922" +XLA_SHA256 = "e9d265946403dc94c39a6f0d1e4b823cab9aa6056a01465e1be4d4a3b4eb43da" def repo(): tf_http_archive( From 3b3c3385e8c51c81fea111707ff107a5a40edca3 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Tue, 27 May 2025 09:57:13 -0700 Subject: [PATCH 1387/1769] #sdy Remove redundant sdy export since it's now done as part of `MlirToXlaComputation`. 
PiperOrigin-RevId: 763837933 --- jaxlib/mlir.cc | 2 -- jaxlib/py_client.cc | 10 ---------- jaxlib/py_compile_only_client.cc | 5 ----- 3 files changed, 17 deletions(-) diff --git a/jaxlib/mlir.cc b/jaxlib/mlir.cc index a632cac71d10..4c8188b04a7f 100644 --- a/jaxlib/mlir.cc +++ b/jaxlib/mlir.cc @@ -106,8 +106,6 @@ absl::StatusOr PyMlirModuleToXlaComputation( TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); XlaComputation computation; - // SDY dialect may be part of the module which XLA doesn't know about. - TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, return_tuple, /*use_shardy=*/false)); diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index 842bdfecad3d..98bde8c27396 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -459,11 +459,6 @@ PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); - if (options.executable_build_options.use_shardy_partitioner()) { - // Since Shardy is located in the middle of the XLA pipeline, we need to - // export it before going to HLO while preserving Shardy ops and attrs. - TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); - } return CompileAndLoadIfrtProgram( client, std::make_unique(module.get()), MakeIfrtCompileOptions(std::move(options), std::move(executable_devices), @@ -478,11 +473,6 @@ PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); - if (options.executable_build_options.use_shardy_partitioner()) { - // Since Shardy is located in the middle of the XLA pipeline, we need to - // export it before going to HLO while preserving Shardy ops and attrs. - TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); - } std::vector> ifrt_loaded_host_callbacks; diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index 2de896d80bef..274f57acba00 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -82,11 +82,6 @@ class CompileOnlyPyClient : public PyClient { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); - if (options.executable_build_options.use_shardy_partitioner()) { - // Since Shardy is located in the middle of the XLA pipeline, we need to - // export it before going to HLO while preserving Shardy ops and attrs. - TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); - } auto* ifrt_client = llvm::dyn_cast_or_null(this->ifrt_client()); CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " From 6f0b99356d4b95d5547c11c8ab12aa4f224e4181 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 27 May 2025 10:32:00 -0700 Subject: [PATCH 1388/1769] [Pallas:MGPU] Add an unsafe flag that disables automatic WG-barrier insertion Enabling this flag can introduce races into certain kernels, which is why it's False by default. Still, there's plenty of kernels where it's unnecessary and a few of those suffer performance regressions when it is on. So it makes sense to at least allow users to opt out. 
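A minimal opt-out sketch (not part of this change, and assuming the params
class is exported as `plgpu.CompilerParams`): the kernel below has each
thread read only `x_ref` and write only `o_ref`, so no memory region is both
read and written by the same thread, and skipping the automatic barriers is
safe under the conditions documented on the flag.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def scale_kernel(x_ref, o_ref):
      # Reads and writes touch disjoint refs, satisfying the safety rules
      # spelled out in the unsafe_no_auto_barriers docstring.
      o_ref[...] = x_ref[...] * 2.0

    x = jnp.arange(256, dtype=jnp.float32)
    y = pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        compiler_params=plgpu.CompilerParams(unsafe_no_auto_barriers=True),
    )(x)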
PiperOrigin-RevId: 763853668 --- jax/_src/pallas/mosaic_gpu/core.py | 11 +++++++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 22 ++++++++++++++++------ jax/_src/pallas/mosaic_gpu/primitives.py | 3 ++- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 2fca1464ee0b..7fb933f5623d 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -87,6 +87,16 @@ class CompilerParams(pallas_core.CompilerParams): references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. + unsafe_no_auto_barriers: If True, Pallas will never automatically insert + barrier instructions that ensure synchronous semantics of loads and stores. + At the moment, the insertion is done conservatively and might regress + performance. There are (at least) two conditions that must be satisfied + for the use of this flag to be safe. First, no memory region is ever read + *and* written to by the same thread (async copies are performed by + background threads and do not count towards this rule). Secondly, no + thread ever calls commit_smem(), reads from the committed SMEM and then + issues an async copy overwriting that region (this is a very artificial + and highly unlikely scenario). profile_space: The number of profiler events that can be collected in a single invocation. It is undefined behavior if a thread collects more events than this. @@ -97,6 +107,7 @@ class CompilerParams(pallas_core.CompilerParams): dimension_semantics: Sequence[DimensionSemantics] | None = None max_concurrent_steps: int = 1 delay_release: int = 0 + unsafe_no_auto_barriers: bool = False profile_space: int = 0 profile_dir: str = "" lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index cf867e55f4c9..6d54e153a9a2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -327,6 +327,8 @@ class ModuleContext: lowering_semantics: mgpu.LoweringSemantics primitive_semantics: gpu_core.PrimitiveSemantics mesh: mesh_lib.Mesh | None + # See the documentation of unsafe_no_auto_barriers in CompilerParams. + auto_barriers: bool warp_axis_name: str | None = None @property @@ -822,6 +824,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): lowering_semantics=lowering_semantics, primitive_semantics=gpu_core.PrimitiveSemantics.Warpgroup, mesh=jax_mesh, + auto_barriers=not params.unsafe_no_auto_barriers, ) del runtime_smem, grouped_barriers, runtime_barriers _ = lower_jaxpr_to_mosaic_gpu( @@ -1389,7 +1392,8 @@ def _swap_lowering_rule( ctx, x_ref, transforms, handle_transposes=not transposed_value, allow_peer_refs=True ) - mgpu.warpgroup_barrier() # Make sure reads have completed before we write. + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure reads have completed before we write. match transforms: case ( gpu_core.UnswizzleRef(swizzle), @@ -1443,7 +1447,8 @@ def _swap_lowering_rule( value.store_untiled(x_smem) case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") - mgpu.warpgroup_barrier() # Make sure the writes have completed. + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure the writes have completed. 
return old_value @@ -2796,7 +2801,8 @@ def _core_map_lowering_rule( # We allow the warps to schedule async copies without synchronizing with # other warps, so we need to add a barrier here to make sure all reads and # writes have completed. - mgpu.warpgroup_barrier() + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() _ = lower_jaxpr_to_mosaic_gpu( module_ctx, ctx.launch_ctx, @@ -2804,8 +2810,9 @@ def _core_map_lowering_rule( args=(), consts=args, ) - # TODO(apaszke,justinfu): Do we really need this barrier? - mgpu.warpgroup_barrier() + if ctx.module_ctx.auto_barriers: + # TODO(apaszke,justinfu): Do we really need this barrier? + mgpu.warpgroup_barrier() return [] raise ValueError(f"Unsupported mesh: {mesh}") @@ -3052,7 +3059,8 @@ def _semaphore_signal_lowering_rule( # anything about the state of the other three warps in the warpgroup (they # might still be e.g. reading memory that someone will overwrite once they # receive a signal). - mgpu.utils.warpgroup_barrier() + if ctx.module_ctx.auto_barriers: + mgpu.utils.warpgroup_barrier() pred = ctx.module_ctx.single_wg_lane_predicate llvm_dialect.inline_asm( i32, @@ -3098,6 +3106,8 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): after_block = while_op.after.blocks.append(i32_ty) with ir.InsertionPoint.at_block_begin(after_block): scf_dialect.yield_(after_block.arguments) + # NOTE: This barrier is necessary for a correct lowering of this op and can't + # be removed even if auto_barriers is False. mgpu_utils.warpgroup_barrier() return () diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 61be6e35cc55..9e40d046af13 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -530,7 +530,8 @@ def _copy_gmem_to_smem_lowering( # arrive with the whole transfer size, while everyone else arrives with 0. # But we should continue using this scheme as it's likely to be faster. bytes //= WARPGROUP_SIZE - mgpu.warpgroup_barrier() # Make sure all reads have completed. + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure all reads have completed. barrier.arrive_expect_tx(bytes) else: # In Warp-level lowering, we arrive on each CUDA thread in a warp, but From e258708fc74da6b5757b0996f0fe4bdab07bc526 Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Tue, 27 May 2025 10:52:27 -0700 Subject: [PATCH 1389/1769] Fix sempahore typo in JAX PiperOrigin-RevId: 763862020 --- docs/pallas/tpu/distributed.ipynb | 4 ++-- docs/pallas/tpu/distributed.md | 4 ++-- jax/_src/pallas/mosaic/interpret.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index f2b9562c0db2..75aeeb92ca43 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -178,7 +178,7 @@ "\n", "`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.\n", "\n", - "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. 
Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", + "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", "\n", "### Routing\n", "\n", @@ -531,7 +531,7 @@ "\n", "Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen:\n", " - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program.\n", - " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted.\n", + " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted.\n", "\n", "#### Barrier Semaphores\n", "\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index 36528bfbddec..7b1f26bccf89 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -163,7 +163,7 @@ def example_kernel(input_ref, output_ref, send_sem, recv_sem): `send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`. -Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). 
The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). +Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). ### Routing @@ -453,7 +453,7 @@ In order to use regular semaphores, they can be allocated in the same way as a D Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen: - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program. - - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted. + - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted. #### Barrier Semaphores diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index f3a165105640..401ed02288bc 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -746,7 +746,7 @@ def _allocate_semaphores( ): """Allocates semaphores on the device with id `device_id` and core with id `local_core_id`. - The number of sempahores allocated is given by the product of the entries in + The number of semaphores allocated is given by the product of the entries in `shape`. 
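The guides touched above describe semaphore balancing only in prose, so here
is a minimal runnable sketch of the success condition they spell out
(assuming the `pltpu.semaphore_signal`/`pltpu.semaphore_wait` API and the
`REGULAR` scratch semaphores shown in those guides): signals and waits must
cancel out so every semaphore is zero again when the kernel exits.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def kernel(x_ref, o_ref, sem):
      # Signal and wait by the same amount: over-signaling would crash on
      # completion, over-waiting would hang, as the guide describes.
      pltpu.semaphore_signal(sem, 2)
      pltpu.semaphore_wait(sem, 2)
      o_ref[...] = x_ref[...]

    x = jnp.ones((8, 128), jnp.float32)
    y = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        scratch_shapes=[pltpu.SemaphoreType.REGULAR],
    )(x)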
Since for each semaphore id there is really only one global `Semaphore` From c09b1bb763d846a694f919e5a5adda9575ce66d6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 27 May 2025 20:44:50 +0000 Subject: [PATCH 1390/1769] Update lock files for jaxlib 0.6.1 --- build/requirements.in | 6 +-- build/requirements_lock_3_10.txt | 70 +++++++++++++++-------------- build/requirements_lock_3_11.txt | 70 +++++++++++++++-------------- build/requirements_lock_3_12.txt | 70 +++++++++++++++-------------- build/requirements_lock_3_13.txt | 70 +++++++++++++++-------------- build/requirements_lock_3_13_ft.txt | 70 +++++++++++++++-------------- 6 files changed, 183 insertions(+), 173 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index c5ce2ea279bd..c1be7a250bff 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -16,11 +16,11 @@ wheel # JAX's own libraries. We include these in the requirements so you can # bazel test without building jaxlib and without manually updating the # the requirements files. -jaxlib +jaxlib==0.6.1 # The with-cuda extra also includes NVIDIA's pip packages. -jax-cuda12-plugin[with-cuda] ; sys_platform == "linux" -jax-cuda12-pjrt ; sys_platform == "linux" +jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" +jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" # TPU dependencies libtpu ; sys_platform == "linux" and platform_machine == "x86_64" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index a4c6b1bf2b77..832c801ced63 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -160,43 +160,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ - --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 +jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ + --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ - --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ - --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ - --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ - --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ - --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ - --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ - --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ - --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ - --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a +jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ + --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ + --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ + 
--hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ + --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ + --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ + --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ + --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ + --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ + --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 # via -r build/requirements.in -jaxlib==0.6.0 \ - --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ - --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ - --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ - --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ - --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ - --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ - --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ - --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ - --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ - --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ - --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ - --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ - --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ - --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ - --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ - --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ - --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ - --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 +jaxlib==0.6.1 \ + --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ + --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ + --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ + --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ + --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ + --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ + --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ + --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ + --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ + --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ + --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ + --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ + --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ + --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ + --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ + --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ + --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ + 
--hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ @@ -494,7 +494,9 @@ nvidia-nvjitlink-cu12==12.8.61 \ nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 0633e733414b..de3c35ed3c02 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -154,43 +154,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ - --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 +jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ + --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ - --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ - --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ - --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ - --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ - --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ - --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ - --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ - --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ - --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a +jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ + --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ + --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ + --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ + --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ + --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ + --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ + --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ + --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ + --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 # via -r build/requirements.in -jaxlib==0.6.0 \ - 
--hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ - --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ - --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ - --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ - --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ - --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ - --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ - --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ - --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ - --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ - --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ - --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ - --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ - --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ - --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ - --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ - --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ - --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 +jaxlib==0.6.1 \ + --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ + --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ + --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ + --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ + --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ + --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ + --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ + --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ + --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ + --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ + --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ + --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ + --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ + --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ + --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ + --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ + --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ + --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ @@ -489,7 +489,9 @@ nvidia-nvjitlink-cu12==12.8.61 \ nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.3.0 \ 
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 1ab77a6ec36e..04c6990da696 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -154,43 +154,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ - --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 +jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ + --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ - --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ - --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ - --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ - --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ - --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ - --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ - --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ - --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ - --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a +jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ + --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ + --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ + --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ + --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ + --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ + --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ + --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ + --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ + --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 # via -r build/requirements.in -jaxlib==0.6.0 \ - --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ - --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ - --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ - --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ - --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ - --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ - --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ - --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 
\ - --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ - --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ - --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ - --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ - --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ - --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ - --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ - --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ - --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ - --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 +jaxlib==0.6.1 \ + --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ + --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ + --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ + --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ + --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ + --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ + --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ + --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ + --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ + --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ + --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ + --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ + --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ + --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ + --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ + --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ + --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ + --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ @@ -489,7 +489,9 @@ nvidia-nvjitlink-cu12==12.8.61 \ nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index c20068b732e6..965cb3bc9672 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -181,43 +181,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ - 
--hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ - --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 +jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ + --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ - --hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ - --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ - --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ - --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ - --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ - --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ - --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ - --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ - --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a +jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ + --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ + --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ + --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ + --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ + --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ + --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ + --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ + --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ + --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 # via -r build/requirements.in -jaxlib==0.6.0 \ - --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ - --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ - --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ - --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ - --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ - --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ - --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ - --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ - --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ - --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ - --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ - --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ - --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ - --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ - --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ - 
--hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ - --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ - --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 +jaxlib==0.6.1 \ + --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ + --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ + --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ + --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ + --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ + --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ + --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ + --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ + --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ + --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ + --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ + --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ + --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ + --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ + --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ + --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ + --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ + --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d # via -r build/requirements.in kiwisolver==1.4.7 \ --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ @@ -544,7 +544,9 @@ nvidia-nvjitlink-cu12==12.8.61 \ nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 3795343df0cb..e7d111c3b3e9 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -172,43 +172,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:68371bd9c135244b89663039be208255698a75bec9854d419ea3c3f957ca4646 \ - --hash=sha256:9bfebb06a39614cb6899f7730ea8561f11156ac81cbb3ec6884a62afb3b15ff3 +jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ + --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.0 ; sys_platform == "linux" \ - --hash=sha256:0d9ecede66c40258702a42261e868cdb56a103551a7c3c884b35f531c9acd48e \ - 
--hash=sha256:28ae6cb1a09b1824d4baeb68386bc615976e89f7a65d403a93822b76dcd1e508 \ - --hash=sha256:530ad851ca462991ce82db26ad47f02b08cebe483c9c8d0c0037e9e27a7b529f \ - --hash=sha256:581f9468c6394f572a9ef0b25cf28b4a8d099abc26ee5da981dd5b680d0a00df \ - --hash=sha256:7cd1b488a54a3089e89588ccaf677089952c82529e7d0403e0b050199e525418 \ - --hash=sha256:a2a3af5f98880d86f8d246abb46a552e5a2ef49d767bfc4a74c8c357752007c6 \ - --hash=sha256:a342f2ce7c4b1f59d403f665a35a86b8650253bb25de34647fb225c45ceb0a04 \ - --hash=sha256:a700e171823ce255102002e40c94788fa868f216257b7d3f0568d09fe75c107b \ - --hash=sha256:e70eb4f084696c3e3be12b5e909ef1205c9f56efe3dcecf2621bd9b5ab5954d5 \ - --hash=sha256:e96f3dd4a942516ae878c9f697e6aefed78e148f09018ca73ee28b23426a7d8a +jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ + --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ + --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ + --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ + --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ + --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ + --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ + --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ + --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ + --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ + --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 # via -r build/requirements.in -jaxlib==0.6.0 \ - --hash=sha256:1597e972ff0e99abbb5bd376167b0b1d565554da54de94f12a5f5c574082f9c6 \ - --hash=sha256:189729639762050c1780b050e98ff620480b1ea32bf167533e000a5cf4c5738e \ - --hash=sha256:2536fa93ec148d5016da8b2077ba66325b0d86aae2289a61c126877f042b3d1c \ - --hash=sha256:541a418b98b28df5bd3a1e93c62b2d3f64d44b0c70b7b608f7fe2b4aa452b2af \ - --hash=sha256:554512c1445ee69c566ef097c3dbdd09e9d9908523eef222c589a559f4220370 \ - --hash=sha256:63106d4e38aec5e4285c8de85e8cddcbb40084c077d07ac03778d3a2bcfa3aae \ - --hash=sha256:64a82f8eb40fdb7ba1d46ef907300d42e4f98cbda9602a2ed8e70db1a9ac4a60 \ - --hash=sha256:7e3ce2ef0edc9b48b36e2704c36181f1ece7a12ac114df753db4286ea2c6e8b8 \ - --hash=sha256:9494cf32c5894669d785c9e2311d2ac0794b29a1a8e9822593211ab43517e657 \ - --hash=sha256:a4d4254c713388887a321379d3c5b1a20213a8dcdc903faf15139ba81e3ecd61 \ - --hash=sha256:b6d85b8d1fd79248b04503517201e72fcbcd3980cf791d37e814709ea50a3c82 \ - --hash=sha256:bed45525e3bb5ec08630bfd207c09af9d62e9ff13f5f07c2ee2cfd8ed8411ba1 \ - --hash=sha256:c0ae959899802e1329cc8ec5a2b4d4be9a076b5beb2052eb49ba37514e623ebc \ - --hash=sha256:c4e97934cbaf5172343aa5ae8ef0c58462ce26154dfda754202b3034160cac7b \ - --hash=sha256:d0fb122dc7830ca2a5ca3c874a087363a00532b644509c219c3bfd1d54515e8d \ - --hash=sha256:d7ab9eaa6e4db3dc6bfba8a061b660147bcd5a1b9d777fde3d729c794f274ab9 \ - --hash=sha256:ec61ca368d0708e1a7543eae620823025bfd405fa9ab331302f209833e970107 \ - --hash=sha256:ef163cf07de00bc5690169e97fafaadc378f1c381f0287e8a473e78ab5bab1b5 +jaxlib==0.6.1 \ + --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ + --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ + --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ + --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ + 
--hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ + --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ + --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ + --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ + --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ + --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ + --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ + --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ + --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ + --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ + --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ + --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ + --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ + --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d # via -r build/requirements.in kiwisolver==1.4.8 \ --hash=sha256:01c3d31902c7db5fb6182832713d3b4122ad9317c2c5877d0539227d96bb2e50 \ @@ -495,7 +495,9 @@ nvidia-nvjitlink-cu12==12.8.61 \ nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac From 669f08a8276bda81fb851a2242158802fcbc5f47 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 27 May 2025 14:53:57 -0700 Subject: [PATCH 1391/1769] Reshape ragged_all_to_all to correct shape before concatenating Previously the result of vmapped RA2A was concatenating a flattened result. 
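In other words, each vmapped batch element went through ragged_all_to_all as a flattened 1D array, and the per-element results were concatenated as-is instead of first being reshaped back to the sliced output shape. A minimal sketch of the corrected pattern (illustrative only, with names abbreviated from the actual batching rule):

    # Remember the sliced output shape before flattening, and restore it
    # before expanding/concatenating along the batched output dimension.
    sliced_output = slicing.slice_in_dim(output, i, i + 1, axis=output_dim)
    sliced_output_shape = sliced_output.shape
    sliced_result = ragged_all_to_all(sliced_operand.flatten(),
                                      sliced_output.flatten(), ...)
    sliced_result = lax.expand_dims(sliced_result.reshape(sliced_output_shape),
                                    dimensions=(output_dim,))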
PiperOrigin-RevId: 763958632 --- jax/_src/lax/parallel.py | 6 ++-- tests/ragged_collective_test.py | 62 +++++++++++++++++++++++++++++++-- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c5f8d3988144..a5bb7222143d 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1448,13 +1448,15 @@ def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, sliced_results = [] for i in range(operand.shape[operand_dim]): sliced_operand = slicing.slice_in_dim(operand, start_index=i, limit_index=i+1, axis=operand_dim).flatten() - sliced_output = slicing.slice_in_dim(output, start_index=i, limit_index=i+1, axis=output_dim).flatten() + sliced_output = slicing.slice_in_dim(output, start_index=i, limit_index=i+1, axis=output_dim) + sliced_output_shape = sliced_output.shape + sliced_output = sliced_output.flatten() sliced_input_offsets = slicing.slice_in_dim(input_offsets, start_index=i, limit_index=i+1, axis=input_offsets_dim).flatten() sliced_send_sizes = slicing.slice_in_dim(send_sizes, start_index=i, limit_index=i+1, axis=send_sizes_dim).flatten() sliced_output_offsets = slicing.slice_in_dim(output_offsets, start_index=i, limit_index=i+1, axis=output_offsets_dim).flatten() sliced_recv_sizes = slicing.slice_in_dim(recv_sizes, start_index=i, limit_index=i+1, axis=recv_sizes_dim).flatten() sliced_result = ragged_all_to_all(sliced_operand, sliced_output, sliced_input_offsets, sliced_send_sizes, sliced_output_offsets, sliced_recv_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) - sliced_result = lax.expand_dims(sliced_result, dimensions=(output_dim,)) + sliced_result = lax.expand_dims(sliced_result.reshape(sliced_output_shape), dimensions=(output_dim,)) sliced_results.append(sliced_result) concat_result = lax.concatenate(sliced_results, dimension=output_dim) diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 1734f67ff063..8b94b862419c 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -382,6 +382,66 @@ def fwd( c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32) ) + def test_ragged_all_to_all_vmap_multi_dim_operand(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh_axes = dict(x=2) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) 
+ + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_vma=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap( + fwd, in_axes=0, out_axes=0, axis_name='x' + )( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ) + self.assertEqual(res.shape, (2, 2, 4)) + @parameterized.named_parameters( dict( testcase_name='_batch_0_data_shard_axis_0_input_0', @@ -510,8 +570,6 @@ def fwd( fwd, in_axes=vmap_batch_axis, out_axes=0, axis_name=vmap_axis_name )( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes - ).reshape( - (2, 2, 4) ) expected_res = jnp.array([[[1, 4, 0, 0], [2, 3, 5, 0]], [[1, 4, 0, 0], [2, 3, 5, 0]]], dtype=jnp.int32) From 69c431759123b4db3de9578837d4dbe0b58db74b Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 27 May 2025 17:52:28 -0700 Subject: [PATCH 1392/1769] [pallas] Fix `broadcast_in_dim` fuser eval rule. PiperOrigin-RevId: 764019664 --- jax/_src/pallas/fuser/block_spec.py | 28 +++++---- tests/pallas/fuser_block_spec_test.py | 83 +++++++++++++++++++-------- 2 files changed, 77 insertions(+), 34 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 3d4df549949c..3e9ff497bf1e 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -1364,14 +1364,15 @@ def _broadcast_in_dim_usage_rule(ctx, used_out: set[Usage], **params): def _broadcast_in_dim_eval_rule( eval_ctx: KernelEvalContext, x, broadcast_dimensions, **params ): - if not eval_ctx.avals_in[0].shape: # pytype: disable=attribute-error - # Scalar -> Array broadcast - block_spec = eval_ctx.out_block_specs[0] - shape = tuple( - _block_size(s) for s in block_spec.block_shape if s is not None - ) - return jax.lax.broadcast_in_dim(x, broadcast_dimensions=(), shape=shape) - return x + del params # Unused. 
+ shape = tuple(map(_block_size, eval_ctx.out_block_specs[0].block_shape)) + dims = tuple( + d - sum(s is None for s in shape[:d]) + for d in broadcast_dimensions + if shape[d] is not None + ) + shape = tuple(s for s in shape if s is not None) + return jax.lax.broadcast_in_dim(x, broadcast_dimensions=dims, shape=shape) @register_pull_block_spec_rule(lax.broadcast_in_dim_p) @@ -1385,15 +1386,20 @@ def _broadcast_in_dim_pull_rule( ): del shape, sharding - if not ctx.avals_in[0].shape: # pytype: disable=attribute-error + shape = ctx.avals_in[0].shape # pytype: disable=attribute-error + if not shape: return [pallas_core.no_block_spec] def new_index_map(*args): idx = block_spec.index_map(*args) - return tuple(idx[i] for i in broadcast_dimensions) + return tuple( + 0 if (d == 1) else idx[i] + for i, d in zip(broadcast_dimensions, shape, strict=True) + ) new_block_shape = tuple( - block_spec.block_shape[i] for i in broadcast_dimensions + b if ((b := block_spec.block_shape[i]) is None) or (d != 1) else 1 + for i, d in zip(broadcast_dimensions, shape, strict=True) ) return [pallas_core.BlockSpec(new_block_shape, new_index_map)] diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index f7e70ec1d708..5c0ef0352b1c 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -653,9 +653,12 @@ def f(): kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x ) - def test_broadcast_array(self): + @parameterized.parameters( + (False, False), (False, True), (True, False), (True, True) + ) + def test_broadcast_array(self, bcast0, bcast1): - x = jnp.ones((512, 512)) + x = jnp.ones((1 if bcast0 else 512, 1 if bcast1 else 512)) def f(): return jax.lax.broadcast_in_dim(x, (2, 2, 512, 512), (2, 3)) @@ -664,9 +667,8 @@ def f(): self.assertLen(new_values, 1) self.assertEmpty(scalar_prefetch_values) - block_spec = pl.BlockSpec( - (None, 1, 128, 128), lambda i, j, k, l: (i, j, k, l) - ) + block_shape = (None, 1, 128, 128) + block_spec = pl.BlockSpec(block_shape, lambda i, j, k, l: (i, j, k, l)) kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( f2, block_spec, @@ -674,27 +676,62 @@ def f(): scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), )(new_values) self.assertLen(value_block_specs, 1) - x_block_spec = value_block_specs[0] - self.assertEqual(x_block_spec.index_map(0, 0, 1, 2), (1, 2)) - self.assertEqual(x_block_spec.index_map(1, 2, 3, 3), (3, 3)) - - x = jnp.full((128, 128), fill_value=1.2345, dtype=jnp.float32) - np.testing.assert_array_equal( - kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (x,)), x - ) - np.testing.assert_array_equal( - kernel_fn((1, 1, 0, 0), scalar_prefetch_values, (x,)), x - ) - np.testing.assert_array_equal( - kernel_fn((0, 0, 0, 1), scalar_prefetch_values, (x,)), x - ) - np.testing.assert_array_equal( - kernel_fn((0, 0, 1, 0), scalar_prefetch_values, (x,)), x + x_index_map = value_block_specs[0].index_map + self.assertEqual( + x_index_map(0, 0, 1, 2), (0 if bcast0 else 1, 0 if bcast1 else 2) ) - np.testing.assert_array_equal( - kernel_fn((0, 0, 3, 0), scalar_prefetch_values, (x,)), x + self.assertEqual( + x_index_map(1, 2, 3, 3), (0 if bcast0 else 3, 0 if bcast1 else 3) ) + block_shape = (1 if bcast0 else 128, 1 if bcast1 else 128) + self.assertEqual(block_shape, value_block_specs[0].block_shape) + x = jnp.full(block_shape, fill_value=1.2345, dtype=jnp.float32) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), (1, 2)) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 0), (), 
(x,)), y) + np.testing.assert_array_equal(kernel_fn((1, 1, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 1), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 1, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 3, 0), (), (x,)), y) + + @parameterized.parameters(0, 1, 2, 3) + def test_broadcast_1d_array(self, bcast_dim): + full_shape = (2, 2, 512, 512) + x = jnp.ones((full_shape[bcast_dim],)) + + def f(): + return jax.lax.broadcast_in_dim(x, full_shape, (bcast_dim,)) + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertEmpty(scalar_prefetch_values) + + block_shape = (None, 1, 128, 128) + block_spec = pl.BlockSpec(block_shape, lambda i, j, k, l: (i, j, k, l)) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + x_index_map = value_block_specs[0].index_map + self.assertEqual(x_index_map(0, 0, 1, 2), ((0, 0, 1, 2)[bcast_dim],)) + self.assertEqual(x_index_map(1, 2, 3, 3), ((1, 2, 3, 3)[bcast_dim],)) + + if block_shape[bcast_dim] is None: + x = jnp.ones(()) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), ()) + else: + x = jnp.arange(block_shape[bcast_dim] or 1, dtype=jnp.float32) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), (bcast_dim - 1,)) + + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((1, 1, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 1), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 1, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 3, 0), (), (x,)), y) + def test_element_indexing(self): x = np.zeros((512, 512), dtype=np.float32) From b07aa270626a262fd85266063de9e5a147e8d0e2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 28 May 2025 01:31:39 -0700 Subject: [PATCH 1393/1769] Automated Code Change PiperOrigin-RevId: 764148812 --- jaxlib/py_socket_transfer.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc index c7cc7c496b0b..69321aa788d5 100644 --- a/jaxlib/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "jaxlib/py_socket_transfer.h" -#include #include #include #include @@ -28,7 +27,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "absl/synchronization/mutex.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/array.h" // IWYU pragma: keep From 6004c7b6daf68650bec2595ab86bc1efc2eb8845 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 28 May 2025 02:33:37 -0700 Subject: [PATCH 1394/1769] [Mosaic GPU] Make the Blackwell matmul kernel persistent This helps with performance a bit (we only allocate and deallocate TMEM once in each SM), and opens up the opportunity for better overlapping of the epilogue. 
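The assignment of tiles to SMs follows the usual persistent-kernel pattern: we launch exactly one block per SM, and each block strides through the logical tile grid. Roughly (illustrative pseudocode; steps_for_this_sm and delinearize stand in for the index arithmetic done in the kernel):

    sm_id = gpu.block_id(gpu.Dimension.x)
    for step in range(steps_for_this_sm):   # tiles owned by this SM
        tile = sm_id + step * num_sms       # linear index into the tile grid
        m_idx, n_idx = delinearize(tile)    # recover per-dimension coordinates
        compute_output(m_idx * block_tile_m, n_idx * tile_n, step)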
PiperOrigin-RevId: 764168230 --- .../mosaic/gpu/examples/matmul_blackwell.py | 68 +++++++++++++------ 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 929c7c498986..8bf8ca557496 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -15,6 +15,7 @@ """Matmul kernel for Blackwell.""" import itertools +import math import jax from jax._src.interpreters import mlir @@ -81,6 +82,9 @@ def build_kernel( if (m // block_tile_m) % grid_tile_m: raise ValueError(f"{m=} // {tile_m=} must be divisible by {grid_tile_m=}") + # We intend this to be iterated in column-major order. + logical_grid = (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)) + def kernel(ctx, a, b, d, smem): ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem (ab_full_barriers, ab_empty_barriers) = barriers @@ -94,17 +98,25 @@ def kernel(ctx, a, b, d, smem): arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index) ) - # This function executes the kernel for a single output tile. - def compute_output(block_m_start, n_start): - """Compute and store a single output tile.""" + def compute_output(block_m_start, n_start, call_counter): + """Compute and store a single output tile. + + call_counter should be 0 the first time this function is called and + incremented by 1 before each subsequent call. + """ # All blocks in the cluster share the same m_start -- align it! m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) with mgpu.when(is_leader_of(TMA_WARP)): @mgpu.fori(c(k_loop_iter, index), None) def _tma_body(ki, _): slot = arith.remui(ki, c(max_concurrent_steps, index)) - # TODO(apaszke): Use a predicate instead of a conditional. - with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))): + isnt_warmup = arith.cmpi( + arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index) + ) + isnt_first_call = arith.cmpi( + arith.CmpIPredicate.ne, call_counter, c(0, index) + ) + with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)): ab_empty_barriers[slot].wait() full_barrier = ab_full_barriers[slot] with mgpu.when(is_leader_block): @@ -150,15 +162,12 @@ def _mma_body(ki, accumulate): collective=collective, ) accumulate = arith.constant(i1, 1) + tcgen05.commit_arrive(ab_empty_barriers[slot], collective=collective, ctx=ctx) is_last_iter = arith.cmpi( arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) ) - barrier_ptr = arith.select( - is_last_iter, - mma_done_barrier.get_ptr(), - ab_empty_barriers[slot].get_ptr(), - ) - tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx) + with mgpu.when(is_last_iter): + tcgen05.commit_arrive(mma_done_barrier, collective=collective, ctx=ctx) return accumulate gpu.barrier() @@ -176,15 +185,33 @@ def _mma_body(ki, accumulate): ) ctx.await_async_copy(0) - m_idx = arith.addi( - gpu.block_id(gpu.Dimension.x), - arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_m, index)), + # We statically assign the tiles to SMs. + logical_grid_size = math.prod(logical_grid) + sm_id = gpu.block_id(gpu.Dimension.x) + extra_step = arith.cmpi( + arith.CmpIPredicate.slt, sm_id, c(logical_grid_size % num_sms, index) + ) # Some SMs do an extra step when grid size isn't divisible by SM count. 
+ mn_steps = arith.addi( + mgpu.c(logical_grid_size // num_sms, index), + arith.index_castui(index, extra_step), ) - n_idx = gpu.block_id(gpu.Dimension.y) - block_m_start = arith.muli(m_idx, c(block_tile_m, index)) - n_start = arith.muli(n_idx, c(tile_n,index)) - # This is not a persistent kernel, so we only process one tile. - compute_output(block_m_start, n_start) + + @mgpu.fori(mn_steps, None) + def _mn_loop(local_mn_step, _): + global_mn_step = arith.addi( + sm_id, arith.muli(local_mn_step, mgpu.c(num_sms, index)) + ) + logical_idxs = [] + for dim_size in logical_grid: + logical_idxs.append(arith.remui(global_mn_step, mgpu.c(dim_size, index))) + global_mn_step = arith.divui(global_mn_step, mgpu.c(dim_size, index)) + lx, ly, lz = logical_idxs + m_idx = arith.addi(lx, arith.muli(lz, c(grid_tile_m, index))) + n_idx = ly + + block_m_start = arith.muli(m_idx, c(block_tile_m, index)) + n_start = arith.muli(n_idx, c(tile_n,index)) + compute_output(block_m_start, n_start, local_mn_step) compute_buffers = ( jax.ShapeDtypeStruct( @@ -204,9 +231,10 @@ def _mma_body(ki, accumulate): mgpu.Barrier(arrival_count=1), mgpu.TMEM((128, tile_n), jnp.float32, collective=collective), ) + num_sms = 148 return mgpu.as_gpu_kernel( kernel, - (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)), + (num_sms, 1, 1), # This is a persistent kernel. (128, 1, 1), ( jax.ShapeDtypeStruct((m, k), dtype), From 27e4a7486246879a16946ed18a9ba964d72d97c2 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 28 May 2025 02:41:04 -0700 Subject: [PATCH 1395/1769] Move jax/_src/attrs.py to its own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This required moving a couple `jax.numpy` imports into local functions. These could probably be addressed by moving the registrations elsewhere. 
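For now the workaround is the usual deferred-import trick, along these lines (a sketch of the pattern used in attrs.py):

    def _extendattr_impl(_, obj, attr, val):
        # Imported locally so that the new jax:attrs rule doesn't need to
        # depend on jax.numpy at module import time.
        import jax.numpy as jnp
        ...
        new = jnp.concatenate([cur, val])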
PiperOrigin-RevId: 764170653 --- jax/BUILD | 18 +++++++++++++++++- jax/_src/attrs.py | 10 ++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 396e6fdf6ed4..2e2d7902577d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -297,7 +297,6 @@ py_library_providing_imports_info( "_src/ad_checkpoint.py", "_src/api.py", "_src/array.py", - "_src/attrs.py", "_src/blocked_sampler.py", "_src/buffer_callback.py", "_src/callback.py", @@ -377,6 +376,7 @@ py_library_providing_imports_info( ":ad", ":ad_util", ":api_util", + ":attrs", ":basearray", ":batching", ":cloud_tpu_init", @@ -465,6 +465,22 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "attrs", + srcs = ["_src/attrs.py"], + deps = [ + ":ad", + ":ad_util", + ":api_util", + ":core", + ":dtypes", + ":partial_eval", + ":source_info_util", + ":tree_util", + ":util", + ], +) + pytype_strict_library( name = "basearray", srcs = ["_src/basearray.py"], diff --git a/jax/_src/attrs.py b/jax/_src/attrs.py index db738ee6368d..7ad6f0e52d32 100644 --- a/jax/_src/attrs.py +++ b/jax/_src/attrs.py @@ -16,7 +16,6 @@ from typing import Any, Callable -import jax from jax._src import core from jax._src import source_info_util from jax._src import api_util @@ -53,7 +52,8 @@ def jax_setattr(obj: Any, attr: str, val: PyTree) -> None: return t.process_setattr(obj, attr, val) def jax_appendattr(obj: Any, attr: str, val: Array) -> None: - return jax_extendattr(obj, attr, jax.numpy.expand_dims(val, 0)) + import jax.numpy as jnp # pytype: disable=import-error + return jax_extendattr(obj, attr, jnp.expand_dims(val, 0)) def jax_extendattr(obj: Any, attr: str, val: Array) -> None: with core.take_current_trace() as t: @@ -68,12 +68,13 @@ def _setattr_impl(_, obj, attr, val): core.EvalTrace.process_setattr = _setattr_impl def _extendattr_impl(_, obj, attr, val): + import jax.numpy as jnp # pytype: disable=import-error cur = getattr(obj, attr, dne_sentinel) if cur is dne_sentinel: new = val else: _check_append_type_agreement(obj, attr, core.typeof(cur), core.typeof(val)) - new = jax.numpy.concatenate([cur, val]) + new = jnp.concatenate([cur, val]) setattr(obj, attr, new) core.EvalTrace.process_extendattr = _extendattr_impl @@ -122,6 +123,7 @@ def _setattr_staging(trace, obj, attr, val): pe.DynamicJaxprTrace.process_setattr = _setattr_staging def _extendattr_staging(trace, obj, attr, val): + import jax.numpy as jnp # pytype: disable=import-error frame = trace.frame if (obj, attr, ReadWrite) in frame.attrs_tracked: @@ -138,7 +140,7 @@ def _extendattr_staging(trace, obj, attr, val): else: assert init_val is not dne_sentinel with core.set_current_trace(trace): - tracer = jax.numpy.concatenate([init_val, val]) + tracer = jnp.concatenate([init_val, val]) setattr(obj, attr, tracer) pe.DynamicJaxprTrace.process_extendattr = _extendattr_staging From 1ff8a65bba4bc43c9aa4ca4b430cc11087f63b50 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 28 May 2025 02:52:15 -0700 Subject: [PATCH 1396/1769] [Mosaic GPU] Perform a cluster barrier before deallocating collective TMEM Otherwise one block can begin the deallocation process before the other is done using it. 
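The fix orders the shutdown as follows (a simplified sketch of the change below; the real NVVM calls also take an `aligned` attribute):

    gpu.barrier()                  # all warps in this block are done with TMEM
    nvvm.cluster_arrive_relaxed()  # signal the peer block in the cluster
    nvvm.cluster_wait()            # wait until the peer has signaled us too
    alloc.dealloc()                # only now is it safe to release TMEM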
PiperOrigin-RevId: 764173760 --- jax/experimental/mosaic/gpu/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 9464bb587c71..73403fccd595 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -513,6 +513,9 @@ def _launch( if tmem_allocs: gpu.barrier() # Make sure everyone is done before we release TMEM. + if any(alloc.collective for alloc in tmem_allocs): + nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) + nvvm.cluster_wait(aligned=ir.UnitAttr.get()) with utils.when(is_init_warp): for alloc in tmem_allocs: alloc.dealloc() From f7adde5227514094beb725766fc4648f115a0aa5 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 28 May 2025 03:24:22 -0700 Subject: [PATCH 1397/1769] [Mosaic GPU] Improve the error message when PTX version inference fails PiperOrigin-RevId: 764182705 --- jaxlib/mosaic/gpu/custom_call.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 54fef13a8521..bf6a04783be7 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -232,8 +232,9 @@ absl::StatusOr GetLatestPtxasPtxIsaVersion() { // Unsupported .version 99.99; current version is '8.8' std::vector chunks = absl::StrSplit(status.message(), '\''); if (chunks.size() != 3) { - return absl::InternalError( - "Failed to locate PTX ISA version in ptxas error message"); + return absl::InternalError(absl::StrCat( + "Failed to locate PTX ISA version in ptxas error message: ", + status.message())); } std::vector major_minor = absl::StrSplit(chunks[1], '.'); if (major_minor.size() != 2) { From 5635717307baa6684c389517c4d86a7553847a6e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 28 May 2025 03:27:33 -0700 Subject: [PATCH 1398/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/acc83e32f93d83280d3672aa9194847a3d416b06. PiperOrigin-RevId: 764183483 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 0032e491f329..6b12b8a028a4 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a566a66e53c489f947eb6c04fe44205013250922" -XLA_SHA256 = "e9d265946403dc94c39a6f0d1e4b823cab9aa6056a01465e1be4d4a3b4eb43da" +XLA_COMMIT = "acc83e32f93d83280d3672aa9194847a3d416b06" +XLA_SHA256 = "d751dbe8cd7baa04c3def33761cb2e0194f8b1923b591a0cb91479acbf3778ab" def repo(): tf_http_archive( From 0b17f6ce59e893940b16cb8d2aa7b39661e587df Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 28 May 2025 03:48:42 -0700 Subject: [PATCH 1399/1769] [Mosaic GPU] Implement FragmentedArray.__getitem__ for arbitrary tiled layouts Any tile-aligned slicing is easy to handle. 
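For example, with a layout whose base tile is (64, 8) (as for the WGMMA layout), a slice like the one below works, while non-tile-aligned bounds, dynamic indices, and squeezing integer indices still raise (shapes and names here are illustrative):

    arr = ...                     # FragmentedArray of shape (256, 256)
    tile = arr[128:256, 16:48]    # bounds are multiples of the (64, 8) tile
    tile.store_untiled(dst_ref, optimized=False)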
PiperOrigin-RevId: 764189366 --- .../mosaic/gpu/fragmented_array.py | 60 ++++++++++----- tests/mosaic/gpu_test.py | 75 ++++++++++++++++--- 2 files changed, 108 insertions(+), 27 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 77584b5f0dd4..04dd30023293 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1376,26 +1376,26 @@ def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): ) def __getitem__(self, idx): - if self.layout != WGMMA_LAYOUT: - raise NotImplementedError("Only WGMMA layouts support slicing") + if not isinstance(self.layout, TiledLayout): + raise NotImplementedError("Only arrays with tiled layouts can be sliced") base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) + if any(isinstance(idx, ir.Value) for idx in base_idx): + raise ValueError("Only static slicing allowed") if any(is_squeezed): raise NotImplementedError("Only slicing implemented") - if ( - base_idx[0] % 64 - or slice_shape[0] % 64 - or base_idx[1] % 8 - or slice_shape[1] % 8 + base_tile_shape = self.layout.base_tile_shape + if len(base_tile_shape) != len(self.shape): + raise NotImplementedError("Tiling has different rank than array") + if any( + b % t or l % t + for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True) ): raise NotImplementedError("Only tile aligned slicing supported") - base_idx[0] //= 64 - slice_shape[0] //= 64 - base_idx[1] //= 8 - slice_shape[1] //= 8 - new_regs = self.registers[ - base_idx[0] : base_idx[0] + slice_shape[0], - base_idx[1] : base_idx[1] + slice_shape[1], - ] + register_slices = tuple( + slice(b // t, (b + l) // t) + for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True) + ) + new_regs = self.registers[register_slices] return FragmentedArray( _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed ) @@ -1882,6 +1882,21 @@ def select(self, on_true, on_false): lambda t, p, f: arith.select(p, t, f), self, on_false, ) + @classmethod + def build( + cls, + shape: tuple[int, ...], + layout: FragmentedLayout, + fn: Callable[..., ir.Value], # ir.Value varargs, one for each dim + *, + is_signed: bool | None = None, + ): + undef = llvm.mlir_undef(ir.IntegerType.get_signless(32)) + dummy = cls.splat(undef, shape, layout, is_signed=False) + return dummy.foreach( + lambda _, idx: fn(*idx), create_array=True, is_signed=is_signed + ) + def foreach( self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None], @@ -1892,8 +1907,19 @@ def foreach( """Call a function for each value and index.""" index = ir.IndexType.get() new_regs = None - if create_array: - new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type)) + orig_fn = fn + def fn(*args): + nonlocal new_regs + result = orig_fn(*args) + old_reg_type = self.registers.flat[0].type + # Lazily create new_regs once we know the desired output type. 
+ if create_array and new_regs is None: + if ir.VectorType.isinstance(old_reg_type): + new_reg_type = ir.VectorType.get(old_reg_type.shape, result.type) + else: + new_reg_type = result.type + new_regs = np.full_like(self.registers, llvm.mlir_undef(new_reg_type)) + return result for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): reg = self.registers[reg_idx] assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 0c79d26782c7..232cb703d06a 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3544,7 +3544,9 @@ def test_pass_is_registered(self): if hp is not None: @hps.composite - def tiled_layouts(draw, initial_tile, vector_transfer: bool = False): + def tiled_layouts( + draw, initial_tile, vector_transfer: bool = False + ) -> fa.TiledLayout: assert all(t.bit_count() == 1 for t in initial_tile) assert math.prod(initial_tile) >= 128 tiles = [initial_tile] @@ -3605,20 +3607,28 @@ def tiled_layouts(draw, initial_tile, vector_transfer: bool = False): vector_dim=vector_dim, ) + @hps.composite + def shape_and_tiled_layout( + draw, vector_transfer: bool = False + ) -> tuple[tuple[int, ...], fa.TiledLayout]: + rank = draw(hps.integers(2, 3)) + initial_tile = tuple( + draw(hps.sampled_from([1, 2, 4, 8, 16, 32, 64, 128])) + for _ in range(rank) + ) + hp.assume(128 <= math.prod(initial_tile) < 128 * 32) + shape = tuple(t * draw(hps.integers(1, 5)) for t in initial_tile) + hp.assume(math.prod(shape) <= 128 * 128) + layout = draw(tiled_layouts(initial_tile, vector_transfer=vector_transfer)) + return shape, layout + class HypothesisTest(TestCase): def test_reduce(self): @hps.composite def strategy(draw): - rank = draw(hps.integers(2, 3)) - initial_tile = tuple( - draw(hps.sampled_from([1, 2, 4, 8, 16, 32, 64, 128])) - for _ in range(rank) - ) - hp.assume(128 <= math.prod(initial_tile) < 128 * 32) - shape = tuple(t * draw(hps.integers(1, 5)) for t in initial_tile) - hp.assume(math.prod(shape) <= 128 * 128) - layout = draw(tiled_layouts(initial_tile, vector_transfer=True)) + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + rank = len(shape) reduced_dims = draw(hps.sets(hps.integers(0, rank - 1), min_size=1)) return shape, layout, tuple(reduced_dims) @@ -3645,6 +3655,51 @@ def kernel(ctx, src, dst, scratch): np.testing.assert_array_equal(result, x.max(reduced_dims)) run() + def test_slice(self): + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + + @hps.composite + def strategy(draw): + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + tiling = layout.base_tile_shape + tiled_shape = mgpu.tile_shape(shape, tiling)[:len(shape)] + def draw_slice(size, tile): + start = draw(hps.integers(0, size - 1)) + length = draw(hps.integers(1, size - start)) + return slice(start * tile, (start + length) * tile) + slices = tuple(map(draw_slice, tiled_shape, tiling)) + return shape, layout, slices + + basic_slices = (slice(128, 256), slice(16, 16 + 32)) + @hp.given(strategy()) + @hp.example(((256, 256), fa.WGMMA_LAYOUT, basic_slices)) + @hp.example(((256, 256), tcgen05.LAYOUT, basic_slices)) + @hp.example(((256, 256), tcgen05.TMEM_NATIVE_LAYOUT, basic_slices)) + def run(args): + shape, layout, slices = args + def kernel(ctx, dst, _): + def linear_index(*idxs): + total = arith.constant(index, 0) + stride = 1 + for i, size in zip(idxs[::-1], shape[::-1]): + total = arith.addi(total, arith.muli(i, c(stride, 
index))) + stride *= size + return arith.index_cast(i32, total) + x = mgpu.FragmentedArray.build( + shape, layout, linear_index, is_signed=True + ) + x[slices].store_untiled(dst, optimized=False) + + slice_shape = tuple(len(range(size)[s]) for s, size in zip(slices, shape)) + out_shape = jax.ShapeDtypeStruct(shape=slice_shape, dtype=jnp.int32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(np.prod(shape), dtype=jnp.int32).reshape(*shape) + np.testing.assert_array_equal(result, iota[slices]) + run() + if __name__ == "__main__": absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) From 39f09066e35205335f4c9dd1013316200e77dc31 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 28 May 2025 05:41:45 -0700 Subject: [PATCH 1400/1769] [Mosaic GPU] Use a second warpgroup to store the MMA outputs This allows us to prime the GMEM->SMEM pipeline for the next tile while storing the SMEM->GMEM tile for the current one. However, this implies that we can no longer share the same SMEM region for the MMA pipeline and the epilogue, which pushes the SMEM pressure so high that we can't fetch too many steps into the future. Overall the performance is slightly worse than for the baseline kernel, but it recovers and improves upon it in the follow up. PiperOrigin-RevId: 764220403 --- .../mosaic/gpu/examples/matmul_blackwell.py | 47 +++++++++++-------- jax/experimental/mosaic/gpu/utils.py | 13 ++++- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 8bf8ca557496..bf50ca702063 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -86,7 +86,7 @@ def build_kernel( logical_grid = (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)) def kernel(ctx, a, b, d, smem): - ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem + ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, tmem_done_barrier, acc = smem (ab_full_barriers, ab_empty_barriers) = barriers warp_idx = mgpu.warp_idx(sync=True) @@ -97,6 +97,9 @@ def kernel(ctx, a, b, d, smem): is_leader_block = arith.cmpi( arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index) ) + is_store_warpgroup = arith.cmpi( + arith.CmpIPredicate.eq, mgpu.warpgroup_idx(sync=True), c(1, i32) + ) def compute_output(block_m_start, n_start, call_counter): """Compute and store a single output tile. @@ -104,6 +107,9 @@ def compute_output(block_m_start, n_start, call_counter): call_counter should be 0 the first time this function is called and incremented by 1 before each subsequent call. """ + isnt_first_call = arith.cmpi( + arith.CmpIPredicate.ne, call_counter, c(0, index) + ) # All blocks in the cluster share the same m_start -- align it! m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) with mgpu.when(is_leader_of(TMA_WARP)): @@ -113,9 +119,6 @@ def _tma_body(ki, _): isnt_warmup = arith.cmpi( arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index) ) - isnt_first_call = arith.cmpi( - arith.CmpIPredicate.ne, call_counter, c(0, index) - ) with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)): ab_empty_barriers[slot].wait() full_barrier = ab_full_barriers[slot] @@ -147,6 +150,9 @@ def _tma_body(ki, _): **common_args, ) + # We wait in all blocks in the cluster to avoid double arrival errors. 
+ with mgpu.when(arith.andi(is_leader_of(MMA_WARP), isnt_first_call)): + tmem_done_barrier.wait(for_tensor_core=True) with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) def _mma_body(ki, accumulate): @@ -170,20 +176,20 @@ def _mma_body(ki, accumulate): tcgen05.commit_arrive(mma_done_barrier, collective=collective, ctx=ctx) return accumulate - gpu.barrier() - mma_done_barrier.wait(for_tensor_core=True) - - final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) - final_acc.store_tiled(d_smem, swizzle=128) - mgpu.commit_shared() - ctx.async_copy( - src_ref=d_smem, - dst_ref=d, - gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), - gmem_transform=mgpu.TileTransform((128, swizzle_elems)), - swizzle=swizzle, - ) - ctx.await_async_copy(0) + with mgpu.when(is_store_warpgroup): + mma_done_barrier.wait(for_tensor_core=True) + final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) + final_acc.store_tiled(d_smem, swizzle=128) + mgpu.commit_shared() + tmem_done_barrier.arrive() + ctx.async_copy( + src_ref=d_smem, + dst_ref=d, + gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), + gmem_transform=mgpu.TileTransform((128, swizzle_elems)), + swizzle=128, + ) + ctx.await_async_copy(0) # We statically assign the tiles to SMs. logical_grid_size = math.prod(logical_grid) @@ -224,18 +230,19 @@ def _mn_loop(local_mn_step, _): epilogue_buffer = jax.ShapeDtypeStruct( mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)), dtype) - smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer]) + smem_buffers = [compute_buffers, epilogue_buffer] smem = ( smem_buffers, [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2, mgpu.Barrier(arrival_count=1), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=1), mgpu.TMEM((128, tile_n), jnp.float32, collective=collective), ) num_sms = 148 return mgpu.as_gpu_kernel( kernel, (num_sms, 1, 1), # This is a persistent kernel. - (128, 1, 1), + (2 * 128, 1, 1), ( jax.ShapeDtypeStruct((m, k), dtype), jax.ShapeDtypeStruct((n, k), dtype), diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 224f5a09cfe5..e55a21442db0 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -817,8 +817,19 @@ def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: ) return parity, arith.xori(parities, bitmask) - def arrive(self, arrival_count: int = 1, can_complete: bool = True): + def arrive( + self, + arrival_count: int = 1, + can_complete: bool = True, + for_tensor_core: bool = False, + ): i64 = ir.IntegerType.get_signless(64) + if for_tensor_core: + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], "tcgen05.fence::before_thread_sync;", "", + has_side_effects=True, + ) if can_complete: if arrival_count > 1: count = c(arrival_count - 1, ir.IntegerType.get_signless(32)) From 360799e6405004a9d9a59e044122168314dff970 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 28 May 2025 07:18:58 -0700 Subject: [PATCH 1401/1769] [Mosaic GPU] Reduce SMEM pressure of the GMEM store This reworks the previous scheme by transferring all of TMEM to registers at once, and then doing RMEM->SMEM->GMEM in multiple phases, allowing us to use a smaller SMEM buffer. This, in turn, lets us bump max_concurrent_steps for the MMA pipeline which increases performance considerably. 
The only downside of this scheme is that even though it should be technically feasible to perform the epilogue with 255 registers per thread, ptxas generates a number of spills that might be lowering our performance. Either way, it's still better than the previous alternatives. PiperOrigin-RevId: 764249234 --- .../mosaic/gpu/examples/matmul_blackwell.py | 36 ++++++++++++------- jax/experimental/mosaic/gpu/tcgen05.py | 13 ++++++- jax/experimental/mosaic/gpu/utils.py | 8 ++++- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index bf50ca702063..52909f6a6a2e 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -179,17 +179,28 @@ def _mma_body(ki, accumulate): with mgpu.when(is_store_warpgroup): mma_done_barrier.wait(for_tensor_core=True) final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) - final_acc.store_tiled(d_smem, swizzle=128) - mgpu.commit_shared() - tmem_done_barrier.arrive() - ctx.async_copy( - src_ref=d_smem, - dst_ref=d, - gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), - gmem_transform=mgpu.TileTransform((128, swizzle_elems)), - swizzle=128, - ) - ctx.await_async_copy(0) + assert tile_n % epilogue_tile_n == 0 + for ni in range(tile_n // epilogue_tile_n): + n_slice = ds(ni * epilogue_tile_n, epilogue_tile_n) + final_acc[:, n_slice].store_tiled(d_smem, swizzle=128) + # We store the first tile before arriving to reduce register pressure. + if ni == 0: + # Make sure we've loaded all of TMEM before we arrive. + tcgen05.wait_tmem_load() + tmem_done_barrier.arrive(for_tensor_core=True) + mgpu.commit_shared() + store_n_start = arith.addi(n_start, c(ni * epilogue_tile_n, index)) + ctx.async_copy( + src_ref=d_smem, + dst_ref=d, + gmem_slice=( + ds(block_m_start, block_tile_m), + ds(store_n_start, epilogue_tile_n), + ), + gmem_transform=mgpu.TileTransform((128, swizzle_elems)), + swizzle=128, + ) + ctx.await_async_copy(0) # We statically assign the tiles to SMs. logical_grid_size = math.prod(logical_grid) @@ -227,8 +238,9 @@ def _mn_loop(local_mn_step, _): mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), dtype), ) + epilogue_tile_n = 64 epilogue_buffer = jax.ShapeDtypeStruct( - mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)), + mgpu.tile_shape((block_tile_m, epilogue_tile_n), (128, swizzle_elems)), dtype) smem_buffers = [compute_buffers, epilogue_buffer] smem = ( diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index c4f670527a0c..86fbd31ed56d 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -477,6 +477,17 @@ def tmem_load(tmem_addr, shape, num, pack: bool): return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)] +def wait_tmem_load(): + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], + "tcgen05.wait::ld.sync.aligned;", + "", + has_side_effects=True, + ) + utils.warpgroup_barrier() + + def tmem_store(tmem_addr, shape, num, regs, unpack: bool): num_out_regs, regs_vector = _tmem_access_helper(shape, num) pack_mod = ".unpack::16b" if unpack else "" @@ -832,7 +843,7 @@ def _transfer_32xcols( regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing) # We artificially lower the instr_num compared to its limits, because higher # values can lead to register spills.. 
- instr_num = min(total_num, 64 // regs_per_instr) + instr_num = min(total_num, 32 // regs_per_instr) assert 32 % atom_rows == 0 num_row_steps = 32 // atom_rows for lane_step in range(num_row_steps): diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index e55a21442db0..51b6ed4612ca 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -993,11 +993,17 @@ def __iter__(self): def __getitem__(self, offset): return CollectiveBarrierRef(self.barrier[offset], self.cluster_mask) - def arrive(self): + def arrive(self, for_tensor_core: bool = False): """Arrives on a barrier in all blocks that share at least one of the coordinates along the collective dimensions. Note that unlike in arrive, each warpgroup arrives once. """ + if for_tensor_core: + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], "tcgen05.fence::before_thread_sync;", "", + has_side_effects=True, + ) if self.barrier.num_barriers != 1: raise ValueError("Can only arrive on a single barrier") if self.cluster_mask is None: From 98e6041a214af302e9945f066d1db56084584e6b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 28 May 2025 07:43:10 -0700 Subject: [PATCH 1402/1769] [Mosaic GPU] Implement a new MMA/TMEM read pipelined matmul kernel This replaces the old scheme that still included a bit of a bubble at the end of each tile with a new scheme that should be entirely bubble-free, for as long as the MMA loop is long enough to hide the store latency (i.e. for big enough K dimensions). This also removes the problems with spills we had in the previous version since the register footprint is relatively small now. PiperOrigin-RevId: 764256446 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 +-- jax/experimental/mosaic/gpu/core.py | 2 +- .../mosaic/gpu/examples/matmul_blackwell.py | 38 +++++++++---------- jax/experimental/mosaic/gpu/tcgen05.py | 27 +++++++------ tests/mosaic/gpu_test.py | 3 +- tests/mosaic/matmul_test.py | 11 +++++- 6 files changed, 49 insertions(+), 38 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6d54e153a9a2..82a0a47c4a0d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -258,8 +258,7 @@ def _run_scoped_resource_estimator( packing = 4 // aval.dtype.itemsize else: packing = 1 - layout = tcgen05._infer_tmem_layout( - aval.shape, collective=aval.collective, packing=packing) + layout = tcgen05._infer_tmem_layout(aval.shape, packing=packing) cols_used = layout.cols_in_shape(aval.shape) cols_used = tcgen05._alloc_ncols(cols_used, exact=False) rs += Resources(tmem_scratch_cols=cols_used) @@ -391,8 +390,7 @@ def alloc_tmem( else: packing = 1 if layout is None: - layout = tcgen05._infer_tmem_layout( - struct.shape, collective, packing=packing) + layout = tcgen05._infer_tmem_layout(struct.shape, packing=packing) unpadded_cols_used = layout.cols_in_shape(struct.shape) cols_used = tcgen05._alloc_ncols(unpadded_cols_used, exact_cols) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 73403fccd595..4ed551654a0e 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -351,7 +351,7 @@ def ref(member_thunks=member_thunks): ) if layout is None: layout = tcgen05._infer_tmem_layout( - shape, collective, 1 if packing is None else packing + shape, 1 if packing is None else packing ) num_cols = layout.cols_in_shape(shape) tmem_allocs.append(_TMEMAlloc(addr_ref, num_cols, 
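Concretely, the accumulator is now double-buffered in TMEM, with per-slot barriers handing it off between the MMA warp and the store warpgroup (simplified sketch; in the kernel these indices are MLIR values):

    acc_slot = call_counter % 2     # alternate between two TMEM accumulators
    acc_slice = acc.slice(slice(None), ds(acc_slot * tile_n, tile_n))
    tmem_done_barrier[acc_slot].wait()  # slot fully stored by the other WG?
    # ... run the whole K loop of MMAs into acc_slice ...
    tcgen05.commit_arrive(mma_done_barrier[acc_slot])  # hand off for the store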
collective)) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 52909f6a6a2e..ac5a8985ebff 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -107,9 +107,8 @@ def compute_output(block_m_start, n_start, call_counter): call_counter should be 0 the first time this function is called and incremented by 1 before each subsequent call. """ - isnt_first_call = arith.cmpi( - arith.CmpIPredicate.ne, call_counter, c(0, index) - ) + acc_slot = arith.remui(call_counter, c(2, index)) + acc_slice = acc.slice(slice(None), mgpu.ds(arith.muli(acc_slot, c(tile_n, index)), tile_n)) # All blocks in the cluster share the same m_start -- align it! m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) with mgpu.when(is_leader_of(TMA_WARP)): @@ -119,6 +118,9 @@ def _tma_body(ki, _): isnt_warmup = arith.cmpi( arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index) ) + isnt_first_call = arith.cmpi( + arith.CmpIPredicate.ne, call_counter, c(0, index) + ) with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)): ab_empty_barriers[slot].wait() full_barrier = ab_full_barriers[slot] @@ -151,15 +153,16 @@ def _tma_body(ki, _): ) # We wait in all blocks in the cluster to avoid double arrival errors. - with mgpu.when(arith.andi(is_leader_of(MMA_WARP), isnt_first_call)): - tmem_done_barrier.wait(for_tensor_core=True) + reuses_tmem = arith.cmpi(arith.CmpIPredicate.uge, call_counter, c(2, index)) + with mgpu.when(arith.andi(is_leader_of(MMA_WARP), reuses_tmem)): + tmem_done_barrier[acc_slot].wait(for_tensor_core=True) with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) def _mma_body(ki, accumulate): slot = arith.remui(ki, c(max_concurrent_steps, index)) ab_full_barriers[slot].wait() tcgen05.mma( - acc, + acc_slice, mgpu.memref_slice(a_smem, slot), mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), a_swizzle=swizzle, @@ -173,21 +176,17 @@ def _mma_body(ki, accumulate): arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) ) with mgpu.when(is_last_iter): - tcgen05.commit_arrive(mma_done_barrier, collective=collective, ctx=ctx) + tcgen05.commit_arrive(mma_done_barrier[acc_slot], collective=collective, ctx=ctx) return accumulate with mgpu.when(is_store_warpgroup): - mma_done_barrier.wait(for_tensor_core=True) - final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) + mma_done_barrier[acc_slot].wait(for_tensor_core=True) + final_acc = acc_slice.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) assert tile_n % epilogue_tile_n == 0 for ni in range(tile_n // epilogue_tile_n): n_slice = ds(ni * epilogue_tile_n, epilogue_tile_n) final_acc[:, n_slice].store_tiled(d_smem, swizzle=128) # We store the first tile before arriving to reduce register pressure. - if ni == 0: - # Make sure we've loaded all of TMEM before we arrive. - tcgen05.wait_tmem_load() - tmem_done_barrier.arrive(for_tensor_core=True) mgpu.commit_shared() store_n_start = arith.addi(n_start, c(ni * epilogue_tile_n, index)) ctx.async_copy( @@ -200,7 +199,8 @@ def _mma_body(ki, accumulate): gmem_transform=mgpu.TileTransform((128, swizzle_elems)), swizzle=128, ) - ctx.await_async_copy(0) + ctx.await_async_copy(0, await_read_only=True) + tmem_done_barrier[acc_slot].arrive(for_tensor_core=True) # We statically assign the tiles to SMs. 
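   # (Illustrative sketch of the static schedule, assuming a persistent
   # kernel over num_sms blocks: block b would handle logical tiles b,
   # b + num_sms, b + 2 * num_sms, ... via the _mn_loop below. Editorial
   # annotation, not part of the original change.)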
logical_grid_size = math.prod(logical_grid) @@ -246,9 +246,9 @@ def _mn_loop(local_mn_step, _): smem = ( smem_buffers, [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2, - mgpu.Barrier(arrival_count=1), - mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=1), - mgpu.TMEM((128, tile_n), jnp.float32, collective=collective), + mgpu.Barrier(arrival_count=1, num_barriers=2), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=2), + mgpu.TMEM((128, 2 * tile_n), jnp.float32, collective=collective), ) num_sms = 148 return mgpu.as_gpu_kernel( @@ -273,7 +273,7 @@ def main(unused_argv): b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16) tile_m = (128,) - tile_n = (128, 256, 512) + tile_n = (128, 256) max_concurrent_steps = (2, 4, 5, 6) grid_tile_m = (1, 2, 4, 8, 16) collective = (False, True) @@ -290,7 +290,7 @@ def main(unused_argv): tile_n *= 2 if m < tile_m or n < tile_n: continue - if tile_n > 512: + if 2 * tile_n > 512: continue if (m // tile_m) % kwargs["grid_tile_m"]: continue diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 86fbd31ed56d..13d945249b69 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -173,9 +173,14 @@ def mma( raise ValueError( f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" ) - if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective, packing=1)): + expected_d_layout = ( + TMEM_COLLECTIVE_N512_LAYOUT + if collective and n * num_cta == 512 + else TMEM_DEFAULT_LAYOUT + ) + if d.layout != expected_d_layout: raise ValueError( - f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}" + f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}" ) f32 = ir.F32Type.get() f16 = ir.F16Type.get() @@ -570,9 +575,7 @@ def cols_in_shape(self, shape: tuple[int, int]): return num_tiles // tiles_in_row * cols_in_tile -def _infer_tmem_layout( - shape: tuple[int, int], collective: bool, packing: int = 1 -) -> TMEMLayout: +def _infer_tmem_layout(shape: tuple[int, int], packing: int = 1) -> TMEMLayout: if shape[0] > TMEM_ROWS: raise ValueError( "Can only infer TMEM layout for shapes with at most 128 rows, got:" @@ -593,14 +596,14 @@ def _infer_tmem_layout( "Can only infer TMEM layout for shapes with column count that's a" f" multiple of 8, got: {shape[1]}" ) - if collective and shape[1] == 512: - return TMEMLayout( - elements_in_tile=(shape[0], 128), column_tile_stride=2, packing=packing - ) - else: - return TMEMLayout(elements_in_tile=(shape[0], 8), packing=packing) + return TMEMLayout(elements_in_tile=(shape[0], 8), packing=packing) +TMEM_DEFAULT_LAYOUT = TMEMLayout(elements_in_tile=(TMEM_ROWS, 8), packing=1) +TMEM_COLLECTIVE_N512_LAYOUT = TMEMLayout( + elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2, packing=1 +) + @dataclasses.dataclass(frozen=True) class TMEMRef: address: ir.Value @@ -669,6 +672,8 @@ def slice(self, *idxs): col_idx = base_idx[1] if not isinstance(col_idx, ir.Value): col_idx = arith.constant(i32, col_idx) + if col_idx.type == ir.IndexType.get(): + col_idx = arith.index_cast(i32, col_idx) if packing != 1: col_idx = arith.divui(col_idx, arith.constant(i32, packing)) return TMEMRef( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 232cb703d06a..99b7d67cd691 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1398,11 +1398,12 @@ def quantize(x): y_block_shape = (n_block_tile, k) if 
rhs_transpose else (k, n_block_tile) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + tmem_layout = tcgen05.TMEM_COLLECTIVE_N512_LAYOUT if n == 512 else None scratch_shape = [ jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), mgpu.TMABarrier(3), - mgpu.TMEM((128, n), out_jax_dtype, collective=True), + mgpu.TMEM((128, n), out_jax_dtype, collective=True, layout=tmem_layout), ] z = mgpu.as_gpu_kernel( kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index 680e699c8972..13082885710e 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -161,8 +161,15 @@ def test_matmul_sm100(self, data): tile_m = data.draw( hps.sampled_from([t for t in [128] if t * num_ctas <= m]), label="tile_m" ) + tmem_cols = 512 tile_n = data.draw( - hps.sampled_from([t for t in [64, 128, 256] if t * num_ctas <= n]), label="tile_n" + hps.sampled_from([ + t + for t in [64, 128, 256] + # We're double buffering TMEM in the kernel, hence the 2x. + if t * num_ctas <= n and 2 * t * num_ctas <= tmem_cols + ]), + label="tile_n", ) grid_m = m // (num_ctas * tile_m) grid_tile_m = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_m") @@ -196,4 +203,4 @@ def test_matmul_sm100(self, data): if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) From 0d0393fd39b44c0627616752df71bc7b97904b80 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 28 May 2025 08:04:02 -0700 Subject: [PATCH 1403/1769] Set the mesh in SPMDAxisContext to be a concrete mesh so that pallas/mosaic:GPU can get access to the device ids in the mesh PiperOrigin-RevId: 764263324 --- jax/_src/shard_map.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 1ed831d8577f..c38312868163 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -783,6 +783,13 @@ def _shardy_shard_map_token_sharding( return ns._to_sdy_sharding(0) +def _get_spmdaxis_ctx_mesh(mesh): + if isinstance(mesh, AbstractMesh): + concrete_mesh = get_concrete_mesh() + return concrete_mesh if concrete_mesh is not None else mesh + return mesh + + def _shard_map_lowering_shardy( ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma): axis_ctx = ctx.module_context.axis_context @@ -793,7 +800,8 @@ def _shard_map_lowering_shardy( shardy_manual_axes = frozenset(mesh.axis_names) - axis_ctx.manual_axes else: shardy_manual_axes = manual_axes - new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) + new_axis_context = sharding_impls.SPMDAxisContext( + _get_spmdaxis_ctx_mesh(mesh), manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()] @@ -868,7 +876,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_specs, out_specs, out_avals_ = [x.aval for x in jaxpr.outvars] in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs, ctx.avals_in, in_avals_, in_nodes) - new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) + new_axis_context = sharding_impls.SPMDAxisContext( + _get_spmdaxis_ctx_mesh(mesh), manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) 
   with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
     out_nodes_, tokens_out = mlir.call_lowering(

From fd28b2f45953d5dfd01a5bc9e7795fe38ba22b72 Mon Sep 17 00:00:00 2001
From: Frederik Gossen
Date: Wed, 28 May 2025 08:32:55 -0700
Subject: [PATCH 1404/1769] Fix JAX PGLE test

XLA dumps one more HLO file by default, which leads to one more PGLE profile
file.

PiperOrigin-RevId: 764274080
---
 tests/pgle_test.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/pgle_test.py b/tests/pgle_test.py
index fd55f0a392f0..e136c3ab8a5a 100644
--- a/tests/pgle_test.py
+++ b/tests/pgle_test.py
@@ -167,8 +167,9 @@ def f(x):
     self.assertArraysEqual(f(x), expected)
     self.assertEqual(cache_miss_count(), 2)
     fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
-    # One for before and one for after optimization.
-    self.assertLen(fdo_profiles_before_pgle, 2)
+    # One for before optimization, one after SPMD partitioning, and one
+    # after optimization.
+    self.assertLen(fdo_profiles_before_pgle, 3)
     # The FDO profile file should be empty.
     self.assertEqual(
       os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0)
@@ -178,8 +179,9 @@ def f(x):
     self.assertArraysEqual(f(x), expected)
     self.assertEqual(cache_miss_count(), 2)
     fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir)
-    # One for before and one for after optimization.
-    self.assertLen(fdo_profiles_after_pgle, 4)
+    # One more before optimization, one more after SPMD partitioning, and
+    # one more after optimization.
+    self.assertLen(fdo_profiles_after_pgle, 6)

     for fdo_profile in fdo_profiles_after_pgle:
       if fdo_profile not in fdo_profiles_before_pgle:

From 30eecf68052c2ee485c40a04f07b3fe2097a7f8a Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Wed, 28 May 2025 08:38:45 -0700
Subject: [PATCH 1405/1769] [pallas:triton] Removed the `Triton` prefix from
 `TritonCompilerParams`

All Triton-specific APIs are always used qualified, e.g.
`plgpu.TritonCompilerParams`, so the prefix is redundant.

PiperOrigin-RevId: 764276165
---
 docs/jax.experimental.pallas.triton.rst       |  4 ++--
 docs/pallas/CHANGELOG.md                      |  8 ++++++++
 jax/_src/pallas/pallas_call.py                |  8 ++++----
 jax/_src/pallas/triton/core.py                |  2 +-
 .../pallas/triton/pallas_call_registration.py |  4 ++--
 jax/experimental/pallas/ops/gpu/attention.py  |  6 +++---
 .../pallas/ops/gpu/decode_attention.py        |  2 +-
 jax/experimental/pallas/ops/gpu/layer_norm.py |  8 ++++----
 .../pallas/ops/gpu/paged_attention.py         |  2 +-
 jax/experimental/pallas/ops/gpu/rms_norm.py   |  8 ++++----
 jax/experimental/pallas/ops/gpu/softmax.py    |  2 +-
 jax/experimental/pallas/triton.py             | 18 +++++++++++++++++-
 tests/pallas/ops_test.py                      |  8 ++++----
 13 files changed, 52 insertions(+), 28 deletions(-)

diff --git a/docs/jax.experimental.pallas.triton.rst b/docs/jax.experimental.pallas.triton.rst
index 76b0896ccf17..023a33bb0909 100644
--- a/docs/jax.experimental.pallas.triton.rst
+++ b/docs/jax.experimental.pallas.triton.rst
@@ -9,7 +9,7 @@ Classes

 .. 
autosummary:: :toctree: _autosummary - TritonCompilerParams + CompilerParams Functions --------- @@ -19,4 +19,4 @@ Functions approx_tanh debug_barrier - elementwise_inline_asm \ No newline at end of file + elementwise_inline_asm diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 2d8a83c897f1..40a30057354d 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -13,6 +13,14 @@ Remember to align the itemized text with the first line of an item within a list ## Unreleased +* Deprecations + + * {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been + renamed to {class}`jax.experimental.pallas.triton.CompilerParams`. The + old name is deprecated and will be removed in a future release. + +## Released with jax 0.6.1 + * Removals * Removed previously deprecated {mod}`jax.experimental.pallas.gpu`. To use diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 964709b4c915..6f8c96a4591c 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1499,7 +1499,7 @@ def pallas_call( interpret: Any = False, name: str | None = None, compiler_params: ( - Mapping[Backend, CompilerParams] | CompilerParams | None + Mapping[Backend, "CompilerParams"] | "CompilerParams" | None ) = None, cost_estimate: CostEstimate | None = None, backend: Backend | None = None, @@ -1550,7 +1550,7 @@ def pallas_call( compiler_params: Optional compiler parameters. The value should either be a backend-specific dataclass (:class:`jax.experimental.pallas.tpu.TPUCompilerParams`, - :class:`jax.experimental.pallas.triton.TritonCompilerParams`, + :class:`jax.experimental.pallas.triton.CompilerParams`, :class:`jax.experimental.pallas.mosaic_gpu.CompilerParams`) or a dict mapping backend name to the corresponding platform-specific dataclass. backend: Optional string literal one of ``"mosaic_tpu"``, ``"triton"`` or @@ -1600,13 +1600,13 @@ def _normalize_compiler_params( ) -> Mapping[Backend, CompilerParams]: if compiler_params is None: return {} - if isinstance(compiler_params, pallas_core.CompilerParams): + if isinstance(compiler_params, CompilerParams): compiler_params = {compiler_params.BACKEND: compiler_params} assert isinstance(compiler_params, Mapping) for backend, params in compiler_params.items(): if backend not in ["mosaic_tpu", "mosaic_gpu", "triton"]: raise ValueError(f"Unknown backend in compiler_params: {backend}") - if not isinstance(params, pallas_core.CompilerParams): + if not isinstance(params, CompilerParams): raise ValueError( f"Unexpected compiler_params for backend {backend}: {params}" ) diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py index 6b3e10f2b018..7b6e69dc8dd8 100644 --- a/jax/_src/pallas/triton/core.py +++ b/jax/_src/pallas/triton/core.py @@ -21,7 +21,7 @@ from jax._src.pallas import core as pallas_core @dataclasses.dataclass(frozen=True) -class TritonCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Compiler parameters for Triton. 
Attributes: diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index e111cef0f924..9bb5c8f21628 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -72,9 +72,9 @@ def pallas_call_lowering( [lowering_platform] = ctx.platforms or ctx.module_context.platforms if "triton" in compiler_params: - params = cast(triton_core.TritonCompilerParams, compiler_params["triton"]) + params = cast(triton_core.CompilerParams, compiler_params["triton"]) else: - params = triton_core.TritonCompilerParams() + params = triton_core.CompilerParams() num_warps = 4 if params.num_warps is None else params.num_warps num_stages = params.num_stages if num_stages is None: diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 2442ed14f351..ae429be5d73a 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -288,7 +288,7 @@ def mha( grid=grid_, in_specs=in_specs, out_specs=out_specs, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=num_stages), out_shape=out_shape, debug=debug, @@ -351,7 +351,7 @@ def _preprocess_backward(out, do, lse, block_q: int, lambda i, j, k: (j, i, k, 0)), ], out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), - compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3), + compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=3), out_shape=out_shape, debug=debug, interpret=interpret, @@ -634,7 +634,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, name="mha_backward", debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=2 ), )(q, k, v, segment_ids, out, do, lse, delta) diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index e2c19b3eaf2d..ee8c22d1b3a4 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -193,7 +193,7 @@ def decode_attn_unbatched( pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m ], - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=num_stages ), out_shape=[ diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index 187d74ee1fd9..b838885a9136 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -94,7 +94,7 @@ def layer_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -215,7 +215,7 @@ def layer_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, @@ -247,7 +247,7 @@ def layer_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - 
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -283,7 +283,7 @@ def layer_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=num_stages), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/gpu/paged_attention.py b/jax/experimental/pallas/ops/gpu/paged_attention.py index b30ef554fe12..fbf861f92412 100644 --- a/jax/experimental/pallas/ops/gpu/paged_attention.py +++ b/jax/experimental/pallas/ops/gpu/paged_attention.py @@ -222,7 +222,7 @@ def paged_attention_unbatched( ], debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=num_stages ), name=f"paged_attention_{block_h=}_{pages_per_compute_block=}", diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index baeaeb8a57b3..a1b2b582f7bb 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -82,7 +82,7 @@ def rms_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -196,7 +196,7 @@ def rms_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, @@ -228,7 +228,7 @@ def rms_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -264,7 +264,7 @@ def rms_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=num_stages ), grid=(), diff --git a/jax/experimental/pallas/ops/gpu/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py index 7fc6a0f50cb4..68960081288e 100644 --- a/jax/experimental/pallas/ops/gpu/softmax.py +++ b/jax/experimental/pallas/ops/gpu/softmax.py @@ -80,7 +80,7 @@ def softmax( kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row) f = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=1), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/triton.py b/jax/experimental/pallas/triton.py index 06adb9e6da7e..1c512540adf2 100644 --- a/jax/experimental/pallas/triton.py +++ b/jax/experimental/pallas/triton.py @@ -14,7 +14,23 @@ """Triton-specific Pallas APIs.""" -from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams +from jax._src.pallas.triton.core import CompilerParams as CompilerParams from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier from jax._src.pallas.triton.primitives import 
elementwise_inline_asm as elementwise_inline_asm + +import typing as _typing # pylint: disable=g-import-not-at-top +if _typing.TYPE_CHECKING: + TritonCompilerParams = CompilerParams +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + _deprecations = { + # Deprecated on May 27th 2025. + "TritonCompilerParams": ( + "TritonCompilerParams is deprecated, use CompilerParams instead.", + CompilerParams, + ), + } + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 3e777ac7ea2c..c819d050c8a5 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1646,7 +1646,7 @@ def kernel(x_ref, o_ref): @unittest.skipIf( sys.platform == "win32", - "plgpu_triton.TritonCompilerParams unavailable on Windows", + "plgpu_triton.CompilerParams unavailable on Windows", ) def test_debug_print(self): self.skip_if_mosaic_gpu() @@ -1661,7 +1661,7 @@ def test_debug_print(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu_triton.TritonCompilerParams( + compiler_params=plgpu_triton.CompilerParams( num_warps=1, num_stages=1 ), ) @@ -1677,7 +1677,7 @@ def kernel(x_ref, o_ref): @unittest.skipIf( sys.platform == "win32", - "plgpu_triton.TritonCompilerParams unavailable on Windows", + "plgpu_triton.CompilerParams unavailable on Windows", ) def test_debug_print_with_values(self): if jtu.test_device_matches(["tpu"]): @@ -1690,7 +1690,7 @@ def test_debug_print_with_values(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu_triton.TritonCompilerParams( + compiler_params=plgpu_triton.CompilerParams( num_warps=1, num_stages=1 ), ) From de491b96c5df15464fb6410cf264c4d830bfe0f9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 28 May 2025 08:40:22 -0700 Subject: [PATCH 1406/1769] [pallas:mosaic_gpu] Added the missing resource estimation rule for `pl.run_state` PiperOrigin-RevId: 764276682 --- jax/_src/pallas/mosaic_gpu/lowering.py | 62 +++++++++++++++++++++----- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 82a0a47c4a0d..a56733a89f60 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -176,11 +176,18 @@ def _estimate_resources( rs = Resources(smem_scratch_bytes=0) for eqn in jaxpr.eqns: # TODO(slebedev): Add support for other primitives, notably control flow. - rule = _resource_estimators.get(eqn.primitive) - if rule is None: - # Assume that unsupported primitives are neutral wrt resource usage. + if rule := _resource_estimators.get(eqn.primitive): + rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params) continue - rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params) + # Assume that unsupported primitives are neutral wrt resource usage, + # unless they have a jaxpr in their params. + if any( + isinstance(v, (jax_core.Jaxpr, jax_core.ClosedJaxpr)) + for v in eqn.params.values() + ): + raise NotImplementedError( + f"Resource estimation does not support {eqn.primitive}" + ) return rs @@ -188,7 +195,7 @@ def _estimate_resources( @_register_resource_estimator(lax.cond_p) def _cond_resource_estimator( ctx: ResourceEstimatorContext, *args, branches -) -> int: +) -> Resources: del args # Unused. 
return functools.reduce( lambda a, b: a | b, @@ -199,7 +206,7 @@ def _cond_resource_estimator( @_register_resource_estimator(lax.scan_p) def _scan_resource_estimator( ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.ClosedJaxpr, **params -) -> int: +) -> Resources: del args, params # Unused. return _estimate_resources(ctx, jaxpr) @@ -211,17 +218,52 @@ def _while_resource_estimator( cond_jaxpr: jax_core.ClosedJaxpr, body_jaxpr: jax_core.ClosedJaxpr, **params, -) -> int: +) -> Resources: del args, params # Unused. return _estimate_resources(ctx, cond_jaxpr) | _estimate_resources( ctx, body_jaxpr ) +@_register_resource_estimator(pjit.pjit_p) +def _pjit_resource_estimator( + ctx: ResourceEstimatorContext, + *args, + jaxpr: jax_core.ClosedJaxpr, + **params, +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + +@_register_resource_estimator(pallas_core.core_map_p) +def _core_map_resource_estimator( + ctx: ResourceEstimatorContext, + *args, + jaxpr: jax_core.ClosedJaxpr, + **params, +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + +@_register_resource_estimator(discharge.run_state_p) +def _run_state_resource_estimator( + ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.Jaxpr, **params +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + @_register_resource_estimator(primitives.run_scoped_p) def _run_scoped_resource_estimator( - ctx: ResourceEstimatorContext, *consts, jaxpr: jax_core.Jaxpr, collective_axes -) -> int: + ctx: ResourceEstimatorContext, + *consts, + jaxpr: jax_core.Jaxpr, + collective_axes, +) -> Resources: + del collective_axes # Unused. + # NOTE: This rule assumes that the allocation happens collectively, although # it can't be checked here due to limited context. We check this in the actual # lowering rule. @@ -280,7 +322,7 @@ def _run_scoped_resource_estimator( @_register_resource_estimator(lax.reduce_sum_p) def _reduce_sum_resource_estimator( ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes -) -> int: +) -> Resources: del ctx, axes # Unused. # We don't need shmem for some reductons, but it depends on the layout, so we # conservatively request some scratch space. 
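
The estimator rules above all follow one pattern: find every sub-jaxpr held
in eqn.params and fold its estimate into the running total. A minimal
standalone sketch of that traversal, independent of the Mosaic GPU Resources
type (the count_eqns helper and the fori_loop example are illustrative, not
part of this change):

    import jax
    import jax.numpy as jnp
    from jax._src import core as jax_core

    def count_eqns(jaxpr: jax_core.Jaxpr) -> int:
      """Counts equations, recursing into sub-jaxprs like the rules above."""
      total = 0
      for eqn in jaxpr.eqns:
        total += 1
        for v in eqn.params.values():
          if isinstance(v, jax_core.ClosedJaxpr):
            total += count_eqns(v.jaxpr)
          elif isinstance(v, jax_core.Jaxpr):
            total += count_eqns(v)
      return total

    closed = jax.make_jaxpr(
        lambda x: jax.lax.fori_loop(0, 3, lambda i, c: c + x, x)
    )(jnp.float32(1.0))
    # The loop primitive counts as one eqn plus the eqns of its sub-jaxpr(s).
    print(count_eqns(closed.jaxpr))
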
From ba64c02fb3baa283c5475a2b74b7ceed584c25c7 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Tue, 27 May 2025 22:53:00 +0000
Subject: [PATCH 1407/1769] [better-errors] if a non-jaxtype is returned, say
 it's a return problem

---
 jax/_src/interpreters/partial_eval.py | 16 +++++++++++++++-
 tests/api_test.py                     | 23 +++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 6ea16ec8e8ba..3c499429a663 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -2243,7 +2243,7 @@ def trace_to_jaxpr_dynamic(
   try:
     with core.set_current_trace(trace):
       ans = fun.call_wrapped(*in_tracers)
-
+    _check_returned_jaxtypes(fun.debug_info, ans)
     out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans)
     _check_no_returned_refs(fun.debug_info, out_tracers)
     jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
@@ -2255,6 +2255,20 @@
   config.enable_checks.value and core.check_jaxpr(jaxpr)
   return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked

+def _check_returned_jaxtypes(dbg, out_tracers):
+  for i, x in enumerate(out_tracers):
+    try:
+      core.typeof(x)
+    except TypeError:
+      if (dbg and len(paths := dbg.result_paths()) > i and
+          (p := paths[i].removeprefix('result'))):
+        extra = f' at output component {p}'
+      else:
+        extra = ''
+      raise TypeError(
+          f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a "
+          f"value of type {type(x)}{extra}, which is not a valid JAX type") from None
+
 def _check_no_returned_refs(
     dbg: core.DebugInfo,
     out_tracers: Sequence[DynamicJaxprTracer]
diff --git a/tests/api_test.py b/tests/api_test.py
index f5b74e1e10d6..584eb0eda496 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -5058,6 +5058,29 @@ def test_ensure_compile_time_eval_no_leaks(self):
     with jax.ensure_compile_time_eval():
       jnp.linalg.solve(jnp.eye(3), jnp.ones(3))  # doesn't crash

+  def test_returned_non_jaxtype(self):
+
+    class TestEnum(enum.Enum):
+      A = enum.auto()
+
+    @jax.tree_util.register_dataclass
+    @dataclasses.dataclass
+    class TestClass3:
+      test_enum_field: TestEnum = dataclasses.field(metadata=dict(static=True))
+      test_data_field: int
+
+    def test_jax_function(test_class: TestClass3) -> TestEnum:
+      return test_class.test_enum_field
+
+    jitted_test_function = jax.jit(test_jax_function)
+    with self.assertRaisesRegex(TypeError, "returned a value of type"):
+      jitted_test_function(
+          TestClass3(
+              test_data_field=1,
+              test_enum_field=TestEnum.A,
+          )
+      )
+

 class RematTest(jtu.JaxTestCase):

From 68fcf154bdcce970ef8fae50b1876404726855bf Mon Sep 17 00:00:00 2001
From: Jen Ha
Date: Wed, 28 May 2025 11:17:57 -0700
Subject: [PATCH 1408/1769] Skip TPU metadata server query when not using TPU.

Resolve an issue where `jax.devices()` hangs due to an unwanted TPU metadata
query when using LibTPU with a device other than TPU (e.g. CPUs). This
feature can be useful in cross-platform
[AOT](https://docs.jax.dev/en/latest/aot.html) compilation.
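
A hypothetical repro, on a host that has libtpu installed but no TPU chips
attached:

    import jax
    print(jax.devices("cpu"))  # previously could hang on the TPU metadata query

With this change, the zero-chip case sets TPU_SKIP_MDS_QUERY=1 before
initialization, so the metadata query is skipped and device enumeration
returns immediately.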
--- jax/_src/cloud_tpu_init.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 0539e4253063..f42794db7696 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -15,11 +15,14 @@ import datetime import os import re +import logging import warnings from jax import version from jax._src import config from jax._src import hardware_utils +logger = logging.getLogger(__name__) + running_in_cloud_tpu_vm: bool = False @@ -74,6 +77,9 @@ def cloud_tpu_init() -> None: # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. libtpu_path = get_tpu_library_path() num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id() + if num_tpu_chips == 0: + logger.info('Using LibTPU with a device other than TPU. Skipping TPU metadata query.') + os.environ['TPU_SKIP_MDS_QUERY'] = '1' if ( tpu_id is not None and tpu_id >= hardware_utils.TpuVersion.v5e From b9407380d60c831cccce9f24c69d4eddf2ae76ba Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 28 May 2025 11:49:13 -0700 Subject: [PATCH 1409/1769] Add visibility registration for `jax._src.sharding_impls` PiperOrigin-RevId: 764354256 --- jax/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/BUILD b/jax/BUILD index 2e2d7902577d..91a6e5926e42 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1172,6 +1172,7 @@ pytype_strict_library( pytype_strict_library( name = "sharding_impls", srcs = ["_src/sharding_impls.py"], + visibility = [":internal"] + jax_visibility("sharding_impls"), deps = [ ":config", ":core", From 1994074023349c6f6c144a474f5b9c0fea8f9e16 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 28 May 2025 13:24:17 -0700 Subject: [PATCH 1410/1769] Make CI job names to be shorter This strips away the redundant terms in job names to keep them shorter and easy to read. Actions displays job names that reuse workflows in the following format: `caller workflow name / called workflow name`. The changes here are done in the called workflow names as changing the caller workflow names seem to make the summary page hard to parse (see https://github.com/jax-ml/jax/actions/runs/15217612585). 
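
For example (illustrative rendering, based on the renames below), a TPU
presubmit job that previously displayed as

    CI - Pytest TPU / Pytest TPU (v5e-8, Python 3.10, libtpu=head)

now displays as

    CI - Pytest TPU / v5e-8, py 3.10, libtpu=head
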
Here's how the continuous workflow's summary page looks like with this change: https://github.com/jax-ml/jax/actions/runs/15286609214/job/42998511666 PiperOrigin-RevId: 764390866 --- .github/workflows/bazel_cpu_py_import_rbe.yml | 3 ++- .github/workflows/bazel_cuda_non_rbe.yml | 5 ++++- .github/workflows/build_artifacts.yml | 5 ++++- .github/workflows/cloud-tpu-ci-presubmit.yml | 2 +- .github/workflows/pytest_cpu.yml | 4 +++- .github/workflows/pytest_cuda.yml | 4 +++- .github/workflows/pytest_tpu.yml | 2 +- 7 files changed, 18 insertions(+), 7 deletions(-) diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml index c98bcee980e8..7eb6d2ed27b4 100644 --- a/.github/workflows/bazel_cpu_py_import_rbe.yml +++ b/.github/workflows/bazel_cpu_py_import_rbe.yml @@ -49,7 +49,8 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} - name: "Bazel CPU tests with py_import (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 3e68034dfbf4..5a3ceaa8a4e8 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -60,7 +60,10 @@ jobs: # Enable writing to the Bazel remote cache bucket. JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1" - name: "Bazel single accelerator and multi-accelerator CUDA tests (jaxlib version=${{ inputs.jaxlib-version }}, ${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + name: "jaxlib=${{ inputs.jaxlib-version }}, + ${{ (contains(inputs.runner, 'h100') && 'h100') || + (contains(inputs.runner, 'b200') && 'b200') || + (contains(inputs.runner, 'l4') && 'l4') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index d5fc35a99cd5..7bca28c3190d 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -111,7 +111,10 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" - name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) + name: "${{ inputs.artifact }}, + ${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, py ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}" # Map the job outputs to step outputs outputs: diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index 40c99735c2de..090259c0f849 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -54,7 +54,7 @@ jobs: needs: [build-jax-artifacts] uses: ./.github/workflows/pytest_tpu.yml # Begin Presubmit Naming Check - name modification requires internal check to be updated - name: "TPU test (jaxlib=head, v5e-8)" + name: "TPU test (jaxlib=head)" with: runner: 
"linux-x86-ct5lp-224-8tpu" cores: "8" diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 3af06fe8037e..263bfd7ec9a9 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -56,7 +56,9 @@ jobs: (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} - name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 78f32cda672d..5f8888526aad 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -68,7 +68,9 @@ jobs: container: ${{ !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') || !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') || inputs.use-nvidia-pip-wheels && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'}} - name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + name: "${{ (contains(inputs.runner, 'h100') && 'h100') || + (contains(inputs.runner, 'b200') && 'b200') || + (contains(inputs.runner, 'l4') && 'l4') }}, CUDA ${{ inputs.cuda-version }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 22cd64977dc5..8c2457208e12 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -78,7 +78,7 @@ jobs: runs-on: ${{ inputs.runner }} container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" # Begin Presubmit Naming Check - name modification requires internal check to be updated - name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})" + name: "${{ inputs.tpu-type }}, py ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }}" # End Presubmit Naming Check github-tpu-presubmits env: From 36eeceb5ef4263559ede1bac18b63306c2e2ddd6 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Wed, 28 May 2025 20:50:03 +0000 Subject: [PATCH 1411/1769] Update actions to adhere to best practices --- .github/actionlint.yaml | 20 +++++++++++++++++++ .github/workflows/asan.yaml | 4 +++- .github/workflows/bazel_cpu_py_import_rbe.yml | 6 +++--- .github/workflows/bazel_cpu_rbe.yml | 4 +++- .github/workflows/bazel_cuda_non_rbe.yml | 3 ++- .github/workflows/bazel_cuda_rbe.yml | 4 +++- .../workflows/bazel_optional_h100_b200.yml | 4 ++++ .github/workflows/build_artifacts.yml | 7 +++---- .github/workflows/ci-build.yaml | 17 ++++++++++++---- .github/workflows/cloud-tpu-ci-nightly.yml | 3 +++ .github/workflows/cloud-tpu-ci-presubmit.yml | 4 +--- .github/workflows/jax-array-api.yml | 8 ++++---- .github/workflows/k8s.yaml | 11 +++------- .github/workflows/metal_plugin_ci.yml | 3 ++- 
.github/workflows/numpy_nightly.yml | 7 ++++--- .github/workflows/oldest_supported_numpy.yml | 13 ++++-------- .github/workflows/pytest_cpu.yml | 2 ++ .github/workflows/pytest_cuda.yml | 2 ++ .github/workflows/pytest_tpu.yml | 3 ++- .github/workflows/release-notification.yml | 11 ++++++++-- .github/workflows/rocm-ci.yml | 5 ++--- .github/workflows/tsan.yaml | 8 ++++++-- .github/workflows/upstream-nightly.yml | 4 +++- .../workflows/wheel_tests_nightly_release.yml | 2 +- 24 files changed, 102 insertions(+), 53 deletions(-) create mode 100644 .github/actionlint.yaml diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml new file mode 100644 index 000000000000..e7ee1a086558 --- /dev/null +++ b/.github/actionlint.yaml @@ -0,0 +1,20 @@ +# Configuration related to self-hosted runner. +self-hosted-runner: + labels: + - "linux-x86-n2-32" # Linux X86 runner using the 32 vcpu n2-standard-32 machine. + - "linux-x86-n2-64" # Linux X86 runner using the 64 vcpu n2-standard-64 machine. + - "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached. + - "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached. + - "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology. + - "linux-arm64-c4a-16" # Linux ARM64 CPU Runner using the 16 vcpu c4a-standard-16 machine. + - "linux-arm64-c4a-64" # Linux ARM64 CPU Runner using the 64 vcpu c4a-standard-64 machine. + - "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine. + - "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine. + - "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of a a4-highgpu-8g machine + - "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached. + - "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology. + - "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology. + - "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology. + - "linux-x86-n2-128" # Linux X86 runner using the 128 vcpu n2-standard-128 machine. + - "linux-x86-n2-16" # Linux X86 runner using the 16 vcpu n2-standard-16 machine. 
+ - "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index ea69d92e552e..533d4381f474 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -13,7 +13,7 @@ on: - main paths: - '**/workflows/asan.yaml' - +permissions: {} jobs: asan: # Don't execute in fork due to runner type @@ -41,11 +41,13 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax + persist-credentials: false - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: python/cpython path: cpython ref: v3.13.0 + persist-credentials: false - name: Build CPython with ASAN enabled env: ASAN_OPTIONS: detect_leaks=0 diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml index 7eb6d2ed27b4..d6173a809500 100644 --- a/.github/workflows/bazel_cpu_py_import_rbe.yml +++ b/.github/workflows/bazel_cpu_py_import_rbe.yml @@ -9,9 +9,7 @@ # - Executes the `run_bazel_test_cpu_py_import_rbe.sh` script, which performs the following actions: # - Runs the Bazel CPU tests with py_import dependency. name: CI - Bazel CPU tests with py_import (RBE) -permissions: - contents: read - +permissions: {} on: workflow_call: inputs: @@ -54,6 +52,8 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index a8b40c260260..2f8eb2c33cee 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -18,7 +18,7 @@ on: branches: - main - 'release/**' - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. @@ -53,6 +53,8 @@ jobs: # End Presubmit Naming Check github-cpu-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 5a3ceaa8a4e8..348d19763989 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -44,7 +44,6 @@ on: type: string required: false default: 'no' - jobs: run-tests: defaults: @@ -67,6 +66,8 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set env vars for use in artifact download URL run: | os=$(uname -s | awk '{print tolower($0)}') diff --git a/.github/workflows/bazel_cuda_rbe.yml b/.github/workflows/bazel_cuda_rbe.yml index 83f651c0ef95..cd4e9a021cfc 100644 --- a/.github/workflows/bazel_cuda_rbe.yml +++ b/.github/workflows/bazel_cuda_rbe.yml @@ -23,7 +23,7 @@ concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. 
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} - +permissions: {} jobs: run_tests: if: github.event.repository.fork == false @@ -49,6 +49,8 @@ jobs: # End Presubmit Naming Check github-cuda-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: diff --git a/.github/workflows/bazel_optional_h100_b200.yml b/.github/workflows/bazel_optional_h100_b200.yml index ec907280938e..68fea50857a5 100644 --- a/.github/workflows/bazel_optional_h100_b200.yml +++ b/.github/workflows/bazel_optional_h100_b200.yml @@ -32,6 +32,8 @@ jobs: # End Presubmit Naming Check github-cuda-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: @@ -74,6 +76,8 @@ jobs: name: "Bazel multiple H100 CUDA tests" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 7bca28c3190d..90c888471b73 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -90,10 +90,7 @@ on: gcs_upload_uri: description: "GCS location prefix to where the artifacts were uploaded" value: ${{ jobs.build-artifacts.outputs.gcs_upload_uri }} - -permissions: - contents: read - +permissions: {} jobs: build-artifacts: defaults: @@ -122,6 +119,8 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Enable RBE if building on Linux x86 or Windows x86 if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 0769c698d5fe..ada470526ef8 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -16,10 +16,7 @@ on: branches: - main -permissions: - contents: read # to fetch code - actions: write # to cancel previous workflows - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true @@ -30,6 +27,8 @@ jobs: timeout-minutes: 5 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python 3.11 uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: @@ -65,6 +64,8 @@ jobs: num_generated_cases: 1 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Image Setup run: | apt update @@ -106,6 +107,8 @@ jobs: python-version: ['3.10'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: @@ -136,6 +139,8 @@ jobs: python-version: ['3.10'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + 
persist-credentials: false - name: Image Setup run: | apt update @@ -166,6 +171,8 @@ jobs: num_generated_cases: 10 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: @@ -198,6 +205,8 @@ jobs: timeout-minutes: 30 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index c7394a498dd6..3ed560b04a88 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -57,6 +57,8 @@ jobs: # mandates using a specific commit for non-Google actions. We use # https://github.com/sethvargo/ratchet to pin specific versions. - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false # Checkout XLA at head, if we're building jaxlib at head. - name: Checkout XLA at head uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -64,6 +66,7 @@ jobs: with: repository: openxla/xla path: xla + persist-credentials: false # We need to mark the GitHub workspace as safe as otherwise git commands will fail. - name: Mark GitHub workspace as safe run: | diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index 090259c0f849..fe1f2820b338 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -25,9 +25,7 @@ on: # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. 
diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 6419cb730b71..16a72fe34714 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -11,22 +11,21 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true - +permissions: {} jobs: build: - runs-on: linux-x86-n2-16 container: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest strategy: matrix: python-version: [3.11] - env: PYTHON: "python${{ matrix.python-version }}" - steps: - name: Checkout jax uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Checkout array-api-tests uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: @@ -35,6 +34,7 @@ jobs: ref: 'c847143beb8d769bde5dbcc063fe19ed7acc2f9b' # Latest commit as of 2025-05-12 submodules: 'true' path: 'array-api-tests' + persist-credentials: false - name: Install dependencies run: | $PYTHON -m uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index 5756b1afbbd2..81552f9bb43b 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -1,5 +1,4 @@ name: Multi-process run using K8s - on: push: branches: @@ -16,19 +15,14 @@ on: - 'jax/_src/distributed.py' - 'jax/_src/clusters/**' -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true - defaults: run: shell: bash -ex -o pipefail {0} - jobs: - distributed-initialize: runs-on: ubuntu-22.04 strategy: @@ -40,6 +34,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4 with: path: jax + persist-credentials: false - name: Start Minikube cluster uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 @@ -105,7 +100,7 @@ jobs: done - name: Examine individual pod outputs - if: "!cancelled()" + if: ${{ !cancelled() }} run: | set +x kubectl get pods --no-headers | awk '{print $1}' | while read -s pod; do diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 2135e473d6be..c76153d48f10 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -14,7 +14,7 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true - +permissions: {} jobs: jax-metal-plugin-test: @@ -30,6 +30,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax + persist-credentials: false - name: Setup build and test enviroment run: | rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv diff --git a/.github/workflows/numpy_nightly.yml b/.github/workflows/numpy_nightly.yml index 17357e9f1dd8..d9a858216857 100644 --- a/.github/workflows/numpy_nightly.yml +++ b/.github/workflows/numpy_nightly.yml @@ -18,9 +18,7 @@ on: schedule: - cron: "0 */3 * * *" # Run once every 3 hours -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. 
@@ -46,12 +44,15 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Checkout ml_dtypes uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: repository: jax-ml/ml_dtypes ref: main path: ml_dtypes + persist-credentials: false # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml index a63cb0b1c614..c7f1f1e38a26 100644 --- a/.github/workflows/oldest_supported_numpy.yml +++ b/.github/workflows/oldest_supported_numpy.yml @@ -10,12 +10,7 @@ on: push: branches: - main - -# This should also be set to read-only in the project settings, but it's nice to -# document and enforce the permissions here. -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. @@ -23,7 +18,7 @@ concurrency: jobs: test-oldest-supported-numpy: - if: "github.event.repository.fork == false && !startsWith(github.head_ref, 'release/')" + if: ${{ github.event.repository.fork == false && !startsWith(github.head_ref, 'release/') }} defaults: run: shell: bash @@ -40,6 +35,8 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Install Python dependencies run: | $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt @@ -52,8 +49,6 @@ jobs: # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CPU tests timeout-minutes: 30 run: ./ci/run_pytest_cpu.sh \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 263bfd7ec9a9..d08cb520eab4 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -67,6 +67,8 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set env vars for use in artifact download URL run: | os=$(uname -s | awk '{print tolower($0)}') diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 5f8888526aad..d095dbfb9e80 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -79,6 +79,8 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set env vars for use in artifact download URL run: | os=$(uname -s | awk '{print tolower($0)}') diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 8c2457208e12..ce5dcf0c9fc8 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -11,7 +11,6 @@ # - Installs the downloaded jaxlib wheel. # - Runs the TPU tests with Pytest. 
 name: CI - Pytest TPU
-
 on:
   workflow_call:
     inputs:
@@ -90,6 +89,8 @@ jobs:
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          persist-credentials: false
       - name: Set env vars for use in artifact download URL
         run: |
           os=$(uname -s | awk '{print tolower($0)}')
diff --git a/.github/workflows/release-notification.yml b/.github/workflows/release-notification.yml
index a4a342ef6de7..6d68bf922655 100644
--- a/.github/workflows/release-notification.yml
+++ b/.github/workflows/release-notification.yml
@@ -2,14 +2,21 @@ name: Google Chat Release Notification
 on:
   release:
     types: [published]
+permissions: {}
 jobs:
   build:
+    env:
+      WEBHOOK_URL: ${{ secrets.RELEASES_WEBHOOK }}
+      RELEASE_NAME: ${{github.event.release.name}}
+      PUBLISHED_AT: ${{github.event.release.published_at}}
+      AUTHOR_LOGIN: ${{github.event.release.author.login}}
+      RELEASE_URL: ${{github.event.release.url}}
     runs-on: ubuntu-latest
     steps:
       - name: Google Chat Notification
         run: |
-          curl --location --request POST '${{ secrets.RELEASES_WEBHOOK }}' \
+          curl --location --request POST "${WEBHOOK_URL}" \
           --header 'Content-Type: application/json' \
           --data-raw '{
-              "text": "Release ${{github.event.release.name}} at ${{github.event.release.published_at}} by ${{github.event.release.author.login}}. <${{github.event.release.url}}|[github]>"
+              "text": "Release '"$RELEASE_NAME"' at '"$PUBLISHED_AT"' by '"$AUTHOR_LOGIN"'. <'"$RELEASE_URL"'|[github]>"
           }'
diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml
index 0ce20726ce63..4bfb8cb50a5e 100644
--- a/.github/workflows/rocm-ci.yml
+++ b/.github/workflows/rocm-ci.yml
@@ -6,9 +6,7 @@ on:
     branches:
       - main
-permissions:
-  contents: read
-
+permissions: {}
 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -36,6 +34,7 @@ jobs:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           path: ${{ env.WORKSPACE_DIR }}
+          persist-credentials: false
       - name: Build JAX
         run: |
           pushd $WORKSPACE_DIR
diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml
index ce4130c31a30..67ff8dd93e3d 100644
--- a/.github/workflows/tsan.yaml
+++ b/.github/workflows/tsan.yaml
@@ -3,7 +3,6 @@ name: CI - Free-threading and Thread Sanitizer (nightly)
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
   cancel-in-progress: true
-
 on:
   schedule:
     - cron: "0 5 * * *" # Daily at 05:00 UTC == 00:00 EST == 21:00 PST
@@ -14,7 +13,7 @@ on:
     paths:
       - '**/workflows/tsan.yaml'
       - '**/workflows/tsan-suppressions*.txt'
-
+permissions: {}
 jobs:
   tsan:
     runs-on: linux-x86-n2-64
@@ -50,17 +49,20 @@
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           path: jax
+          persist-credentials: false
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           repository: numpy/numpy
           path: numpy
           submodules: true
+          persist-credentials: false
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         if: ${{ matrix.python-version == '3.14' }}
         with:
           repository: scipy/scipy
           path: scipy
           submodules: true
+          persist-credentials: false

       - name: Get year & week number
         id: get-date
@@ -81,6 +83,8 @@
           repository: python/cpython
           path: cpython
           ref: ${{ matrix.github_branch }}
+          persist-credentials: false
+
       - name: Build TSAN CPython ${{ matrix.python-version }}
         if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true'
diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml
index 349ddf0d96a3..23b8ac32d844 100644
---
a/.github/workflows/upstream-nightly.yml
+++ b/.github/workflows/upstream-nightly.yml
@@ -19,7 +19,7 @@ on:
       - main
     paths:
       - '**workflows/upstream-nightly.yml'
-
+permissions: {}
 jobs:
   upstream-dev:
     runs-on: linux-x86-n2-64
@@ -33,6 +33,8 @@
         python-version: ["3.13"]
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          persist-credentials: false
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
         with:
diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml
index 7bad41647e6b..c536466c7dcb 100644
--- a/.github/workflows/wheel_tests_nightly_release.yml
+++ b/.github/workflows/wheel_tests_nightly_release.yml
@@ -33,7 +33,7 @@ on:
 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
   cancel-in-progress: true
-
+permissions: {}
 jobs:
   run-pytest-cpu:
     uses: ./.github/workflows/pytest_cpu.yml

From 5aa339561871ccf037f1216237cf4e5db937376c Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Wed, 28 May 2025 14:31:28 -0700
Subject: [PATCH 1412/1769] [Mosaic GPU] Rework CUDA_ROOT logic a bit

PiperOrigin-RevId: 764419062
---
 jaxlib/mosaic/gpu/custom_call.cc | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index bf6a04783be7..524d0ffe23ab 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -45,6 +45,7 @@ limitations under the License.
 #include "absl/strings/str_replace.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
+// Leave this comment here. Internal Google business.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/TargetSelect.h"
@@ -130,12 +131,17 @@ class TemporaryDirectory {
   std::string path;
 };

+const char *GetCUDARoot() {
+  return getenv("CUDA_ROOT");
+}
+
 absl::StatusOr<std::string> RunCUDATool(const char* tool,
                                         const std::vector<const char*>& args,
                                         bool stderr_to_stdout = true) {
   CHECK(!args.empty() && args.back() == nullptr);
-  const char* cuda_path_ptr = getenv("CUDA_ROOT");
-  if (!cuda_path_ptr) return absl::InternalError("Failed to get CUDA_ROOT");
+  const char* cuda_path_ptr = GetCUDARoot();
+  if (!cuda_path_ptr)
+    return absl::InternalError("Failed to get the CUDA toolkit path");
   std::string tool_path(cuda_path_ptr);
   tool_path += "/bin/";
   tool_path += tool;
@@ -338,6 +344,10 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
     return true;
   });
   bool emit_line_info = getenv("MOSAIC_GPU_LINE_INFO") != nullptr;
+  const char *cuda_root = GetCUDARoot();
+  if (!cuda_root) {
+    return mlir::failure();
+  }
   return mlir::parsePassPipeline(absl::StrCat(
       R"(
       builtin.module(
@@ -374,11 +384,12 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
       // TODO(slebedev): Switch to the ensure-debug-info-scope-on-llvm-func
       // pass in MLIR once Triton upstreams its changes.
       emit_line_info ? "enable-line-info," : "",
-      R"(
-        gpu-module-to-binary{format=)" +
-          mlir::gpu::stringifyCompilationTarget(target).str() +
-          (!nvshmem_path.empty() ? R"( l=)" + nvshmem_path : "") +
-          (emit_line_info ? " opts=-lineinfo" : "") + R"(},
+      "gpu-module-to-binary{format=",
+      mlir::gpu::stringifyCompilationTarget(target).str(),
+      (!nvshmem_path.empty() ? " l=" + nvshmem_path : ""),
+      (emit_line_info ? " opts=-lineinfo" : ""),
+      " toolkit=", cuda_root,
+      R"(},
       convert-math-to-llvm{approximate-log1p=true},
       canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},
       cse,

From c5b908ceca710900e0e099b451b7c138ffcdf0ec Mon Sep 17 00:00:00 2001
From: Quoc Truong
Date: Wed, 28 May 2025 15:13:51 -0700
Subject: [PATCH 1413/1769] Change all
 us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build containers to
 us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build.

These containers are the same (same build script), but they are just in
different repositories.

PiperOrigin-RevId: 764435895
---
 .github/workflows/bazel_cpu_py_import_rbe.yml     | 2 +-
 .github/workflows/bazel_cpu_rbe.yml               | 2 +-
 .github/workflows/bazel_cuda_non_rbe.yml          | 2 +-
 .github/workflows/bazel_cuda_rbe.yml              | 2 +-
 .github/workflows/bazel_optional_h100_b200.yml    | 4 ++--
 .github/workflows/build_artifacts.yml             | 2 +-
 .github/workflows/cloud-tpu-ci-nightly.yml        | 2 +-
 .github/workflows/jax-array-api.yml               | 2 +-
 .github/workflows/numpy_nightly.yml               | 2 +-
 .github/workflows/oldest_supported_numpy.yml      | 2 +-
 .github/workflows/pytest_cpu.yml                  | 2 +-
 .github/workflows/pytest_cuda.yml                 | 6 +++---
 .github/workflows/pytest_tpu.yml                  | 2 +-
 .github/workflows/wheel_tests_nightly_release.yml | 2 +-
 ci/envs/docker.env                                | 2 +-
 15 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml
index d6173a809500..09c9d173e0d0 100644
--- a/.github/workflows/bazel_cpu_py_import_rbe.yml
+++ b/.github/workflows/bazel_cpu_py_import_rbe.yml
@@ -41,7 +41,7 @@ jobs:
       # Explicitly set the shell to bash
       shell: bash
     runs-on: ${{ inputs.runner }}
-    container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
+    container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') ||
                    (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
     env:
       JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml
index 2f8eb2c33cee..3eff0932adcb 100644
--- a/.github/workflows/bazel_cpu_rbe.yml
+++ b/.github/workflows/bazel_cpu_rbe.yml
@@ -28,7 +28,7 @@ jobs:
   run_tests:
     if: github.event.repository.fork == false
     runs-on: ${{ matrix.runner }}
-    container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
+    container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') ||
                    (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
     env:
       JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml
index 348d19763989..458589199c53 100644
--- a/.github/workflows/bazel_cuda_non_rbe.yml
+++ b/.github/workflows/bazel_cuda_non_rbe.yml
@@ -51,7 +51,7 @@ jobs:
       # Explicitly set the shell to bash
       shell: bash
     runs-on: ${{ inputs.runner }}
-    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest"
+    container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest"
     env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} diff --git a/.github/workflows/bazel_cuda_rbe.yml b/.github/workflows/bazel_cuda_rbe.yml index cd4e9a021cfc..2c57b35587fa 100644 --- a/.github/workflows/bazel_cuda_rbe.yml +++ b/.github/workflows/bazel_cuda_rbe.yml @@ -28,7 +28,7 @@ jobs: run_tests: if: github.event.repository.fork == false runs-on: ${{ matrix.runner }} - container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest' env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} diff --git a/.github/workflows/bazel_optional_h100_b200.yml b/.github/workflows/bazel_optional_h100_b200.yml index 68fea50857a5..16c7bb95c16b 100644 --- a/.github/workflows/bazel_optional_h100_b200.yml +++ b/.github/workflows/bazel_optional_h100_b200.yml @@ -27,7 +27,7 @@ jobs: run_tests: if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} runs-on: linux-x86-a4-224-b200-1gpu - container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest' name: "Bazel single B200 CUDA tests" # End Presubmit Naming Check github-cuda-presubmits steps: @@ -72,7 +72,7 @@ jobs: run_multiaccelerator_tests: if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} runs-on: linux-x86-a3-8g-h100-8gpu - container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest' name: "Bazel multiple H100 CUDA tests" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 90c888471b73..72d554aa5d1b 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -100,7 +100,7 @@ jobs: runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 3ed560b04a88..1f096ce48e2d 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -47,7 +47,7 @@ jobs: LIBTPU_OLDEST_VERSION_DATE: 20250228 PYTHON: python${{ matrix.python-version }} runs-on: ${{ matrix.tpu.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" timeout-minutes: 180 defaults: run: diff --git a/.github/workflows/jax-array-api.yml 
b/.github/workflows/jax-array-api.yml index 16a72fe34714..41879a6f2e9f 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -15,7 +15,7 @@ permissions: {} jobs: build: runs-on: linux-x86-n2-16 - container: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest strategy: matrix: python-version: [3.11] diff --git a/.github/workflows/numpy_nightly.yml b/.github/workflows/numpy_nightly.yml index d9a858216857..c0036ccf8f7f 100644 --- a/.github/workflows/numpy_nightly.yml +++ b/.github/workflows/numpy_nightly.yml @@ -33,7 +33,7 @@ jobs: strategy: matrix: python: ["3.13",] - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" name: "CI - jaxlib head with NumPy nightly" env: diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml index c7f1f1e38a26..67fc9f10e5ce 100644 --- a/.github/workflows/oldest_supported_numpy.yml +++ b/.github/workflows/oldest_supported_numpy.yml @@ -23,7 +23,7 @@ jobs: run: shell: bash runs-on: "linux-x86-n2-64" - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" # Begin Presubmit Naming Check - name modification requires internal check to be updated name: "CI - Oldest Supported NumPy (Python 3.10, x64=0)" # End Presubmit Naming Check github-oldest-supported-numpy-presubmit diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index d08cb520eab4..a92f2d96dc89 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -52,7 +52,7 @@ jobs: # Explicitly set the shell to bash to override Windows's default (cmd) shell: bash runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index d095dbfb9e80..6fa4e14f8b85 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -65,9 +65,9 @@ jobs: # Test the oldest and newest supported CUDA versions. # If testing the CUDA packages from PyPI, then use the ml-build image which does not have any # CUDA pckages installed on the system. 
- container: ${{ !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') || - !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') || - inputs.use-nvidia-pip-wheels && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'}} + container: ${{ !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.1') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.1-cudnn9.8:latest') || + !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.8') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest') || + inputs.use-nvidia-pip-wheels && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest'}} name: "${{ (contains(inputs.runner, 'h100') && 'h100') || (contains(inputs.runner, 'b200') && 'b200') || (contains(inputs.runner, 'l4') && 'l4') }}, CUDA ${{ inputs.cuda-version }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index ce5dcf0c9fc8..5f56b165c295 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -75,7 +75,7 @@ jobs: run: shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" # Begin Presubmit Naming Check - name modification requires internal check to be updated name: "${{ inputs.tpu-type }}, py ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }}" # End Presubmit Naming Check github-tpu-presubmits diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index c536466c7dcb..6d25ee281c7b 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -154,7 +154,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: python: ["3.10", "3.13", "3.13-nogil"] - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" # Verifies that JAX's release wheels can be installed name: "Verify release wheels install (Python ${{ matrix.python }})" diff --git a/ci/envs/docker.env b/ci/envs/docker.env index a0f558520d45..5135b61ac45b 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -29,7 +29,7 @@ export JAXCI_DOCKER_ARGS="" # Linux x86 image for building JAX artifacts, running Pytests CPU/TPU tests, and # Bazel tests if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" fi # Linux Aarch64 image for building JAX artifacts, running Pytests CPU tests, and From 2dc69daec8ed513668e155bc3c9973f2d5d32b05 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 28 May 2025 15:23:43 -0700 Subject: [PATCH 1414/1769] Integrate LLVM at llvm/llvm-project@2b8bff6f66fd Updates LLVM usage to match [2b8bff6f66fd](https://github.com/llvm/llvm-project/commit/2b8bff6f66fd) PiperOrigin-RevId: 764439621 
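For orientation before the diff: the user-visible entry point to the `tpu.sitofp` lowering path this integrate adapts is an ordinary integer-to-float conversion. A minimal JAX sketch of that entry point (an illustration added here, not part of the patch; the rounding comment reflects the `kToNearestEven` mode the lowering below now passes explicitly):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def to_float(x):
      # int32 -> float32. On TPU this conversion is canonicalized to
      # tpu.sitofp, which after this change carries an explicit rounding
      # mode and may also take scalar (non-vector) operands.
      return x.astype(jnp.float32)

    print(to_float(jnp.arange(8, dtype=jnp.int32)))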
---
 jaxlib/mosaic/dialect/tpu/tpu.td                           | 4 ++--
 .../mosaic/dialect/tpu/transforms/apply_vector_layout.cc   | 8 ++++----
 .../mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc   | 8 +++++---
 .../mosaic/dialect/tpu/transforms/infer_vector_layout.cc   | 6 ++++--
 4 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 505478b9ad72..766900cd07e4 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -513,8 +513,8 @@ def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> {
 // Internal operation. All arith.sitofp operations that change the bitwidth
 // must be canonicalized to this operation.
 def TPU_SIToFPOp : TPU_Op<"sitofp", [Pure, ElementwiseMappable]> {
-  let arguments = (ins AnyVectorOfAnyRank:$in, TPU_RoundingModeEnum:$rounding_mode);
-  let results = (outs AnyVectorOfAnyRank:$output);
+  let arguments = (ins AnyType:$in, TPU_RoundingModeEnum:$rounding_mode);
+  let results = (outs AnyType:$output);
   let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($output) }];
 }
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 6502a9c6682e..53d8712d5274 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -1161,10 +1161,10 @@ LogicalResult tpu_sitofp_rule(RewriteContext &ctx, Operation &op,
   FAILUREOR_ASSIGN_OR_RETURN(
       xla::Array<Value> vregs,
       ext_op_rule_impl(ctx, builder, sitofp_op, layout_in, layout_out));
-  sitofp_op.replaceAllUsesWith(assemble(builder, sitofp_op.getType(),
-                                        layout_out, std::move(vregs),
-                                        ctx.target_shape)
-                                   .getResult());
+  sitofp_op.replaceAllUsesWith(
+      assemble(builder, cast<VectorType>(sitofp_op.getType()), layout_out,
+               std::move(vregs), ctx.target_shape)
+          .getResult());
   sitofp_op.erase();
   return success();
 }
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 1d8ea1299f04..c963cff0be50 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -715,10 +715,12 @@ LogicalResult canonicalize_sitofp(const CanonicalizeContext &ctx,
     }
   }
   if (is_vector) {
-    x = builder.create<arith::SIToFPOp>(
-        VectorType::get(src_vty.getShape(), builder.getF32Type()), x);
+    x = builder.create<tpu::SIToFPOp>(
+        VectorType::get(src_vty.getShape(), builder.getF32Type()), x,
+        tpu::RoundingMode::kToNearestEven);
   } else {
-    x = builder.create<arith::SIToFPOp>(builder.getF32Type(), x);
+    x = builder.create<tpu::SIToFPOp>(builder.getF32Type(), x,
+                                      tpu::RoundingMode::kToNearestEven);
   }
   if (dst_bitwidth < 32) {
     x = builder.create<arith::TruncFOp>(op.getType(), x);
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index 976e31cb55f4..14d1fb2104fa 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -155,8 +155,10 @@ class VectorLayoutInferer {
         return failure();
       }
     } else if (auto op = dyn_cast<tpu::SIToFPOp>(any_op);
-               op && op.getIn().getType().getElementTypeBitWidth() <
-                         op.getType().getElementTypeBitWidth()) {
+               op &&
+               cast<VectorType>(op.getIn().getType())
+                       .getElementTypeBitWidth() <
+                   cast<VectorType>(op.getType()).getElementTypeBitWidth()) {
       if (inferExt(&any_op).failed()) {
         return failure();
       }

From 37a9ac23681d85e5e5663d21fac4b88224715ba3 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date:
Wed, 28 May 2025 17:45:44 -0700 Subject: [PATCH 1415/1769] [Pallas Fuser] Add support for basic PRNG op fusion PiperOrigin-RevId: 764490044 --- jax/_src/pallas/core.py | 1 + jax/_src/pallas/fuser/block_spec.py | 94 +++++++++++++++++++--- jax/_src/pallas/mosaic/lowering.py | 30 +++++--- tests/pallas/fuser_block_spec_test.py | 107 +++++++++++++++++++------- 4 files changed, 186 insertions(+), 46 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 7950f90bc377..a05f97eb122f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -272,6 +272,7 @@ class MemorySpace(enum.Enum): ANY = "any" # Unrestricted memory space (usually HBM) ERROR = "error" # Memory space for checkify errors. INDEX = "index" # Memory space for scalar prefetch arguments. + KEY = "key" # Memory space for PRNG keys. def __str__(self) -> str: return self.value diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 3e9ff497bf1e..4f9d1c344429 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -29,6 +29,7 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import pjit +from jax._src import prng from jax._src import state from jax._src import tree_util from jax._src import util @@ -215,7 +216,7 @@ def _wrap_block_spec_scalar_prefetch( block_spec: pallas_core.BlockSpec, num_grid_args: int, ) -> pallas_core.BlockSpec: - if block_spec is pallas_core.no_block_spec: + if block_spec is pallas_core.no_block_spec or block_spec.index_map is None: return block_spec def new_index_map(*args_and_scalar_prefetch): @@ -272,11 +273,12 @@ def wrapped(*args, **kwargs): ) assert all(used_invars) assert all(used_consts) + read_usage_env = compute_usage(jaxpr, jaxpr_out_usages) in_block_specs, env, read_usage_env = _pull_block_spec( jaxpr, tuple(flat_block_specs), - jaxpr_out_usages, scalar_prefetch_handler=scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=grid, ) kernel_fn = make_kernel_function( @@ -307,8 +309,8 @@ def wrapped(*args, **kwargs): def _pull_block_spec( jaxpr: core.Jaxpr, out_block_specs: tuple[pallas_core.BlockSpec, ...], - out_usages, *, + read_usage_env: Callable[[core.Var], set[Usage]], scalar_prefetch_handler: Any | None = None, grid: tuple[int | jax.Array, ...], ) -> tuple[ @@ -316,7 +318,6 @@ def _pull_block_spec( tuple[dict[core.Var, pallas_core.BlockSpec], dict[int, Any]], Any, ]: - read_usage_env = compute_usage(jaxpr, out_usages) jaxpr_invar_usages = util.safe_map(read_usage_env, jaxpr.invars) env: dict[core.Var, pallas_core.BlockSpec] = {} scalar_prefetch_fn_env = {} @@ -456,6 +457,8 @@ def _get_block_aval(bs, aval): return aval if bs is pallas_core.no_block_spec or bs is None: return _no_aval + if bs.block_shape is None: + return aval return aval.update(shape=_remove_nones(bs.block_shape)) # pytype: disable=attribute-error in_block_avals = [ @@ -830,7 +833,10 @@ def register_binop_rule(prim: core.Primitive): register_binop_rule(lax.eq_p) register_binop_rule(lax.gt_p) register_binop_rule(lax.ge_p) +register_binop_rule(lax.or_p) +register_binop_rule(lax.xor_p) register_binop_rule(lax.and_p) +register_binop_rule(lax.shift_right_logical_p) register_binop_rule(ad_util.add_any_p) @@ -1473,6 +1479,68 @@ def _convert_element_type_pull_rule( return [block_spec] +@register_eval_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_eval_rule(eval_ctx: KernelEvalContext, x, new_dtype): + return jax.lax.bitcast_convert_type(x, new_dtype) + + 
+@register_pull_block_spec_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + new_dtype: jnp.dtype, +): + old_dtype = ctx.avals_in[0].dtype # pytype: disable=attribute-error + if old_dtype.itemsize != new_dtype.itemsize: + raise NotImplementedError( + 'bitcast_convert_type with different bitwidths not supported yet:' + f' {old_dtype=}, {new_dtype=}' + ) + return [block_spec] + + +@register_eval_rule(prng.random_bits_p) +def _random_bits_eval_rule(eval_ctx: KernelEvalContext, key, bit_width, shape): + del shape + block_spec = eval_ctx.out_block_specs[0] + indices = eval_ctx.get_out_block_indices()[0] + block_shape = block_spec.block_shape + # This is the important part here: we fold in block indices into the key so + # each block gets different random numbers. + for idx in indices: + key = jax.random.fold_in(key, idx) + return prng.random_bits(key, bit_width=bit_width, shape=block_shape) + + +@register_pull_block_spec_rule(prng.random_bits_p) +def _random_bits_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **_, +): + del ctx, block_spec + key_block_spec = pallas_core.BlockSpec( + block_shape=None, memory_space=pallas_core.MemorySpace.KEY + ) + return [key_block_spec] + +@register_eval_rule(prng.random_wrap_p) +def _random_wrap_eval_rule(eval_ctx: KernelEvalContext, arr, *, impl): + del eval_ctx + return jax.random.wrap_key_data(arr, impl=impl) + +@register_pull_block_spec_rule(prng.random_wrap_p) +def _random_wrap_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + impl +): + del ctx, block_spec, impl + return [pallas_core.BlockSpec(block_shape=None)] + + @register_eval_rule(lax.iota_p) def _iota_eval_rule( eval_ctx: KernelEvalContext, *, dimension, shape, dtype, sharding @@ -1599,12 +1667,13 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs): raise NotImplementedError('pjit with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) - read_usage_env = compute_usage(jaxpr, ctx.out_usages) + def read_usage_env(_: core.Var): + return {Usage.REGULAR} _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=ctx.grid, ) kernel_fn = make_kernel_function( @@ -1628,11 +1697,13 @@ def _jit_pull_block_spec_rule( jaxpr, consts = jaxpr.jaxpr, jaxpr.consts if consts: raise NotImplementedError('pjit with consts not supported yet') + def read_usage_env(_: core.Var): + return {Usage.REGULAR} in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=ctx.grid, ) return in_block_specs @@ -1657,13 +1728,14 @@ def _custom_jvp_call_eval_rule( raise NotImplementedError('custom_jvp_call with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) - read_usage_env = compute_usage(jaxpr, ctx.out_usages) + def read_usage_env(_: core.Var): + return {Usage.REGULAR} _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, grid=ctx.grid, + read_usage_env=read_usage_env, ) kernel_fn = make_kernel_function( jaxpr, @@ -1686,12 +1758,14 @@ def _custom_jvp_call_pull_block_spec_rule( jaxpr, consts 
= call_jaxpr.jaxpr, call_jaxpr.consts if consts: raise NotImplementedError('custom_jvp_call with consts not supported yet') + def read_usage_env(_: core.Var): + return {Usage.REGULAR} in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, grid=ctx.grid, + read_usage_env=read_usage_env, ) return in_block_specs diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 635c473620c3..c6aaf77199b5 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -222,7 +222,11 @@ def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None case pallas_core.MemorySpace.ANY: # Map the general ANY memory space to TPU ANY memory space return TPUMemorySpace.ANY - case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX: + case ( + pallas_core.MemorySpace.ERROR + | pallas_core.MemorySpace.INDEX + | pallas_core.MemorySpace.KEY + ): return TPUMemorySpace.SMEM case TPUMemorySpace(): # Leave the memory space unchanged @@ -365,7 +369,7 @@ def _get_arg_type( ): memory_space = None if isinstance(aval, pallas_core.AbstractMemoryRef): - memory_space = aval.memory_space + memory_space = _memory_space_to_tpu_memory_space(aval.memory_space) # We assume unannotated memory refs are in VMEM if memory_space is None: memory_space = TPUMemorySpace.VMEM @@ -595,10 +599,10 @@ def _check_block_mappings( rank = len(bm.block_shape) # TODO(necula): add tests for SMEM blocks with trivial windowing # We support scalars too - if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and - bm.has_trivial_window()): + memory_space = _memory_space_to_tpu_memory_space(bm.block_aval.memory_space) + if memory_space == tpu_core.TPUMemorySpace.SMEM and bm.has_trivial_window(): continue - if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE: + if memory_space == tpu_core.TPUMemorySpace.SEMAPHORE: continue def err_details(): @@ -614,8 +618,10 @@ def err_details(): "The Pallas TPU lowering currently supports only blocks of " "rank >= 1. 
" + err_details()) - if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and - not bm.has_trivial_window()): + if ( + memory_space == tpu_core.TPUMemorySpace.ANY + and not bm.has_trivial_window() + ): raise ValueError( "The Pallas TPU lowering currently supports in memory space ANY " "only blocks having the same block shape as the array shape " @@ -3723,10 +3729,16 @@ def new_lowering(key, bit_width, shape): @register_lowering_rule(prng.random_fold_in_p) def random_fold_in_lowering(ctx, keys, msgs): - keys_aval, _ = ctx.avals_in + keys_aval, msgs_aval = ctx.avals_in impl = keys_aval.dtype._impl fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False) - return fold_in_lowering(ctx, keys, msgs) + if pl_random.is_pallas_impl(impl): + return fold_in_lowering(ctx, keys, msgs) + else: + ctx = dataclasses.replace(ctx, + avals_in=[jax_core.physical_aval(keys_aval), msgs_aval], + avals_out=map(jax_core.physical_aval, ctx.avals_out)) + return fold_in_lowering(ctx, keys, msgs) @register_lowering_rule(prng.random_unwrap_p) diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index 5c0ef0352b1c..b348ba971c38 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -755,7 +755,9 @@ def f(): )(new_values) self.assertLen(value_block_specs, 1) self.assertEmpty(scalar_prefetch_values) - self.assertEqual(value_block_specs[0].block_shape, (pl.Element(128, (0, 16)), 128)) + self.assertEqual( + value_block_specs[0].block_shape, (pl.Element(128, (0, 16)), 128) + ) self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (16, 1)) self.assertEqual(value_block_specs[0].index_map(1, 1, 2), (128 + 16, 1)) @@ -801,10 +803,13 @@ def f(x): def test_basic_swap(self): value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 x = jnp.zeros((256, 512), dtype=jnp.int32) + def outer(refs): ref, y_ref = refs + def f(x): return ref.swap(x) + in_type = jax.ShapeDtypeStruct((512, 1024), jnp.int32) f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( f, in_type @@ -826,73 +831,121 @@ def f(x): self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1)) y_ref[...] = kernel_fn((0, 1, 1), scalar_prefetch_values, (ref,), x) + y = jnp.zeros((256, 512), jnp.int32) _, y = pl.run_state(outer)((value, y)) np.testing.assert_array_equal(y, value[:256, 512:1024]) def test_basic_get(self): value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 + def outer(refs): ref, y_ref = refs + def f(): return ref.get() block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) - kernel_fn, (), _ = ( - block_spec_lib.pull_block_spec( - f, - block_spec, - grid=(2, 3, 4), - scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), - )() - ) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() y_ref[...] 
= kernel_fn((0, 1, 1), ()) + y = jnp.zeros((256, 512), jnp.int32) _, y = pl.run_state(outer)((value, y)) np.testing.assert_array_equal(y, value[:256, 512:1024]) def test_get_with_squeezed_block_spec(self): - value = jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) * 2 + value = ( + jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) + * 2 + ) + def outer(refs): ref, y_ref = refs + def f(): return ref.get() - block_spec = pl.BlockSpec((pl.Squeezed(), 256, 512), lambda i, j, k: (j, i, k)) - kernel_fn, (), _ = ( - block_spec_lib.pull_block_spec( - f, - block_spec, - grid=(2, 3, 4), - scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), - )() + block_spec = pl.BlockSpec( + (pl.Squeezed(), 256, 512), lambda i, j, k: (j, i, k) ) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() y_ref[...] = kernel_fn((0, 3, 1), ()) + y = jnp.zeros((256, 512), jnp.int32) _, y = pl.run_state(outer)((value, y)) np.testing.assert_array_equal(y, value[3, :256, 512:1024]) def test_get_with_squeezed_indexer(self): - value = jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) * 2 + value = ( + jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) + * 2 + ) + def outer(refs): ref, y_ref = refs + def f(): return ref[3] block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) - kernel_fn, (), _ = ( - block_spec_lib.pull_block_spec( - f, - block_spec, - grid=(2, 3, 4), - scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), - )() - ) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() y_ref[...] 
= kernel_fn((0, 2, 1), ()) + y = jnp.zeros((256, 512), jnp.int32) _, y = pl.run_state(outer)((value, y)) np.testing.assert_array_equal(y, value[3, :256, 512:1024]) + def test_random_noise(self): + key = jax.random.key(0, impl='threefry2x32') + + def f(key): + return jax.random.uniform(key, (512, 512), dtype=jnp.float32) + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, key + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((128, 256), lambda i, j: (i, j)) + kernel_fn, (value_block_specs, key_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(4, 2), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, key) + ) + self.assertEmpty(value_block_specs) + self.assertEqual(key_block_spec.memory_space, pl.MemorySpace.KEY) + self.assertIsNone(key_block_spec.block_shape) + @jax.jit + def gen(idx): + k = key + for i in idx: + k = jax.random.fold_in(k, i) + return jax.random.uniform(k, (128, 256), dtype=jnp.float32) + for i in range(4): + for j in range(2): + out = kernel_fn((i, j), scalar_prefetch_values, (), key) + out_ref = gen((i, j)) + np.testing.assert_array_equal(out, out_ref) + class PullBlockSpecHOPTest(jtu.JaxTestCase): From a4a31ecd8476a4d10dedcf226d5a116af2d056b3 Mon Sep 17 00:00:00 2001 From: DanisNone Date: Thu, 29 May 2025 11:02:14 +0500 Subject: [PATCH 1416/1769] A more numerically stable implementation of logaddexp2 --- jax/_src/lax/other.py | 32 ++++++++++++++++++++++++++++++++ jax/_src/numpy/ufuncs.py | 3 +-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 00e15ef6a91d..6da39b0c2405 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -287,3 +287,35 @@ def _logaddexp_jvp(primals, tangents): tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) return primal_out, tangent_out + + +@custom_jvp +def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute log2(exp2(x1) + exp2(x2)) avoiding overflow.""" + x1_arr = lax.asarray(x1) + x2_arr = lax.asarray(x2) + assert x1_arr.dtype == x2_arr.dtype + + amax = lax.max(x1_arr, x2_arr) + invln2 = lax._const(amax, 1/np.log(2)) + if dtypes.isdtype(x1_arr.dtype, "real floating"): + delta = lax.sub(x1_arr, x2_arr) + return lax.select(lax._isnan(delta), + lax.add(x1_arr, x2_arr), # NaNs or infinities of the same sign. 
+ lax.add(amax, lax.mul(invln2, lax.log1p(lax.exp2(lax.neg(lax.abs(delta))))))) + elif dtypes.isdtype(x1_arr.dtype, "complex floating"): + delta = lax.sub(lax.add(x1_arr, x2_arr), lax.mul(amax, lax._const(amax, 2))) + out = lax.add(amax, lax.mul(invln2, lax.log1p(lax.exp2(delta)))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) + else: + raise ValueError(f"logaddexp2 requires floating-point or complex inputs; got {x1_arr.dtype}") + + +@logaddexp2.defjvp +def _logaddexp2_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + primal_out = logaddexp2(x1, x2) + tangent_out = lax.add(lax.mul(t1, lax.exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, lax.exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 77b1220214ed..d722534e3136 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2782,8 +2782,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array(True, dtype=bool) """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) - ln2 = float(np.log(2)) - return logaddexp(x1 * ln2, x2 * ln2) / ln2 + return lax_other.logaddexp2(x1, x2) @export From 89828819a383911db365661e8c7b6203299c35fb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 29 May 2025 03:22:22 -0700 Subject: [PATCH 1417/1769] [mosaic_gpu] Use `DIScopeForLLVMFuncOpPass` from MLIR instead of its Triton fork PiperOrigin-RevId: 764652343 --- jaxlib/mosaic/gpu/BUILD | 10 ++++------ jaxlib/mosaic/gpu/custom_call.cc | 13 ++++--------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index fc1abb9397d5..d2abea0048d6 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -151,6 +151,8 @@ cc_library( ":nvshmem", ":passes", ":target", + "//jaxlib/cuda:cuda_vendor", + "//jaxlib/mosaic/dialect/gpu:mosaic_gpu", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", @@ -181,6 +183,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:IndexToLLVM", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", @@ -200,16 +203,11 @@ cc_library( "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", - "//jaxlib/cuda:cuda_vendor", - "//jaxlib/mosaic/dialect/gpu:mosaic_gpu", + "@tsl//tsl/profiler/lib:traceme", "@xla//xla/ffi", "@xla//xla/ffi:ffi_api", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", - "@tsl//tsl/profiler/lib:traceme", - # TODO(slebedev): Remove once enable-line-info is merged into the upstream - # ensure-debug-info-scope-on-llvm-func pass in MLIR. - "@triton//:TritonLLVMIR", ], alwayslink = True, ) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 524d0ffe23ab..5253d4590658 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -68,6 +68,7 @@ limitations under the License. 
#include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" @@ -102,7 +103,6 @@ limitations under the License. #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" #include "tsl/profiler/lib/traceme.h" -#include "triton/Target/LLVMIR/Passes.h" namespace { @@ -340,7 +340,7 @@ mlir::FailureOr GetPassPipeline( mosaic::gpu::registerConvertGpuToLLVMPass(); mosaic::gpu::registerByvalInsertionPass(); mlir::arith::registerArithExpandOpsPass(); - mlir::registerLLVMDIScopePass(); + mlir::LLVM::registerDIScopeForLLVMFuncOpPass(); return true; }); bool emit_line_info = getenv("MOSAIC_GPU_LINE_INFO") != nullptr; @@ -360,10 +360,7 @@ mlir::FailureOr GetPassPipeline( convert-scf-to-cf, convert-nvvm-to-llvm, expand-strided-metadata, - nvvm-attach-target{)", - // TODO(slebedev): Always use O=3 once - // https://github.com/llvm/llvm-project/pull/140146 is merged. - emit_line_info ? "O=0" : "O=3", " chip=", sm, " fast=false features=+", + nvvm-attach-target{O=3 chip=)", sm, " fast=false features=+", ptx_isa, R"( ftz=false module= triple=nvptx64-nvidia-cuda}, lower-affine, @@ -381,9 +378,7 @@ mlir::FailureOr GetPassPipeline( gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, )", - // TODO(slebedev): Switch to the ensure-debug-info-scope-on-llvm-func - // pass in MLIR once Triton upstreams its changes. - emit_line_info ? "enable-line-info," : "", + emit_line_info ? "ensure-debug-info-scope-on-llvm-func{emission-kind=DebugDirectivesOnly}," : "", "gpu-module-to-binary{format=", mlir::gpu::stringifyCompilationTarget(target).str(), (!nvshmem_path.empty() ? " l=" + nvshmem_path : ""), From bc33d0eefd65aa7d49c20e89776ffd834663c54d Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 29 May 2025 03:28:18 -0700 Subject: [PATCH 1418/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f9ea70486e5625484325b3f451a32242507493ad. PiperOrigin-RevId: 764653997 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6b12b8a028a4..d2a0a2e8b29a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "acc83e32f93d83280d3672aa9194847a3d416b06" -XLA_SHA256 = "d751dbe8cd7baa04c3def33761cb2e0194f8b1923b591a0cb91479acbf3778ab" +XLA_COMMIT = "f9ea70486e5625484325b3f451a32242507493ad" +XLA_SHA256 = "48c1a6e87b580becb8dbc018028be2835c5f0bd941ae36f450102f3f16a79398" def repo(): tf_http_archive( From 50253f1acfe5434a7a50507fd940641984d9e7a7 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 29 May 2025 04:34:20 -0700 Subject: [PATCH 1419/1769] [pallas:mosaic_gpu] `emit_pipeline` now allows specifying a carry PiperOrigin-RevId: 764672556 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 59 +++++++++++++++++--------- tests/pallas/mosaic_gpu_test.py | 25 +++++++++++ 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index a7f8d32677b0..f85b73b6b946 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -16,13 +16,12 @@ from __future__ import annotations -from typing import Protocol, TypeVar from collections.abc import Callable, Sequence import dataclasses import functools import itertools as it import math -from typing import Any +from typing import Any, Protocol, TypeVar import jax from jax import api_util @@ -176,28 +175,44 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore def emit_pipeline( - body: Callable[..., None], + body: Callable[..., T], *, grid: pallas_core.TupleGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, delay_release: int = 0, + init_carry: T | None = None, ): - """Creates a function to emit a manual pipeline within a Pallas kernel. + r"""Creates a function to emit a manual pipeline within a Pallas kernel. Args: - body: The pipeline body, called with the indices for the current step, the - input refs, followed by the output refs. - grid: The grid to use for the pipeline. - in_specs: The block specs for the inputs. - out_specs: The block specs for the outputs. - max_concurrent_steps: The maximum number of sequential stages that are - active concurrently. Defaults to 1. - delay_release: The number of steps to wait before reusing the input/output - references. Defaults to 0, and must be strictly smaller than - ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you - don't await the WGMMA in the body. + body: The pipeline body function, which is called with + + - ``indices``: Tuple of current loop indices. + - ``*input_refs``: SMEM refs for inputs. + - ``*output_refs``: SMEM refs for outputs. + + If ``init_carry`` is provided, ``body`` receives an additional argument + ``carry`` -- the carry from the previous iteration. It must then return + the next carry value. + grid: The grid dimensions for the pipeline. + in_specs: A sequence of :class:`~jax.experimental.pallas.BlockSpec`\s + for inputs. + out_specs: A sequence of :class:`~jax.experimental.pallas.BlockSpec`\s + for outputs. + max_concurrent_steps: Maximum concurrently active pipeline stages. + delay_release: Number of steps to delay before reusing input/output + references. Must be ``< max_concurrent_steps``. Useful for hiding WGMMA + latency (typically set to 1). + init_carry: Optional initial carry. If provided, ``body`` handles + carry-over state between iterations, and the pipeline returns the + final carry. 
+ + Returns: + A function that, when called with GMEM input and output refs, executes the + pipeline and returns the final carry value (if ``init_carry`` was used), + otherwise it returns None. """ if max_concurrent_steps <= delay_release: raise ValueError( @@ -278,7 +293,7 @@ def scoped_pipeline( def loop_body(step, carry): slot = lax.rem(step, max_concurrent_steps) - indices, fetch_indices, last_store_slices = carry + indices, fetch_indices, last_store_slices, prev_body_carry = carry if barrier_ref is not None: # Wait for the current GMEM->SMEM copy to complete, if any. @@ -289,12 +304,13 @@ def loop_body(step, carry): max_concurrent_steps - (1 + delay_release), wait_read_only=True ) - body( + next_body_carry = body( indices, *( bref.get_ref_for_slot(slot) for bref in it.chain(in_brefs, out_brefs) ), + *(prev_body_carry,) if init_carry is not None else (), ) if copies_out_in_loop: @@ -346,6 +362,7 @@ def do_fetch(): _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid), new_store_slices, + next_body_carry if init_carry is not None else None, ) # Invariant: ``indices`` and ``fetch_indices`` are always @@ -360,8 +377,11 @@ def do_fetch(): else (_Slice(-1, -1),) * len(bref.spec.block_shape) for bref in out_brefs ] - last_indices, _, _ = lax.fori_loop( - 0, num_steps, loop_body, (indices, fetch_indices, last_store_slices) + last_indices, _, _, final_carry = lax.fori_loop( + 0, + num_steps, + loop_body, + (indices, fetch_indices, last_store_slices, init_carry), ) # Outputs invariant to the sequential axis are never written from inside the @@ -378,6 +398,7 @@ def do_fetch(): # Finalize the pipeline. gpu_primitives.wait_smem_to_gmem(0) + return final_carry if init_carry is not None else None return pipeline diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 608cfcba2465..f9a23e7be9d0 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2841,6 +2841,31 @@ def kernel_body(_, x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit_with_carry(self): + num_steps = 4 + + def kernel(o_gmem): + plgpu.emit_pipeline( + kernel_body, + out_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=2, + init_carry=0, + )(o_gmem) + + def kernel_body(_, o_smem, carry): + o_smem[...] = lax.broadcast(carry, o_smem.shape) + return carry + 1 + + kernel_fn = self.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((64, num_steps * 64), jnp.int32), + ) + np.testing.assert_array_equal( + kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1)) + ) + class PipelineWGTest( PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From 770eff03dfbc225011535cb32ab92ae40cff679b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 23 May 2025 15:09:34 -0400 Subject: [PATCH 1420/1769] Apply extensive input to extensive output forwarding in scan. --- jax/_src/lax/control_flow/loops.py | 22 ++++++++++- tests/lax_control_flow_test.py | 61 ++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 83c31928d7cb..b9ce8ae09380 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -341,7 +341,7 @@ def _create_jaxpr(init): # If the body forwards an input carry to an output carry, that input is # read-only and can be moved to be a const. 
Doing so can lead to efficiency # wins, e.g. if the scan is inside a cond with a batched predicate. - carry_fwd, _ = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) + carry_fwd, ext_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)] if any(move_to_const): jaxpr = pe.prune_closed_jaxpr_outputs( @@ -352,12 +352,32 @@ def _create_jaxpr(init): consts = [*new_consts, *consts] num_carry -= len(new_consts) + # When an extensive output is forwarded from an extensive input, we can + # avoid copying it by pruning it from the jaxpr and forwarding manually. We + # don't need to update the indexing based on the optimization above since it + # doesn't change the total number of consts and carries combined, and + # `ext_fwd` already only includes the extensive outputs. But, we do remove + # the number of consts from the index since we're going to use it to index + # into `in_flat`, which doesn't include consts. + ext_to_ext_fwd = [ + in_idx - len(consts) if in_idx is not None and + in_idx >= num_carry + len(consts) else None for in_idx in ext_fwd] + jaxpr = pe.prune_closed_jaxpr_outputs( + jaxpr, [True] * num_carry + [i is None for i in ext_to_ext_fwd]) + out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, linear=(False,) * (len(consts) + len(in_flat)), unroll=unroll, _split_transpose=_split_transpose) + # Apply input to output forwarding that was computed above. + carry_out, out = split_list(out, [num_carry]) + out_ = iter(out) + out = [next(out_) if f is None else _maybe_put(in_flat[f]) for f in ext_to_ext_fwd] + assert next(out_, None) is None + out = [*carry_out, *out] + if any(move_to_const): out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 54dff47fea32..2f1e154627f4 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3299,6 +3299,59 @@ def body_fun(c, _): outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] self.assertAllClose(outs, outs_ref, check_dtypes=False) + @parameterized.parameters(itertools.product(range(3), repeat=4)) + @jtu.run_on_devices("cpu") + def test_scan_forwarding_correctness( + self, + seed, + num_body_consts, + num_const_fwds, + num_input_fwds): + + num_carry = num_const_fwds + 4 + num_xs = num_input_fwds + 2 + num_ys = num_xs + 1 + + rng = np.random.RandomState(seed) + carry_perm = rng.permutation(num_carry) + carry_iperm = np.argsort(carry_perm) + + xs_perm = rng.permutation(num_xs) + ys_perm = rng.permutation(num_ys) + f = np.arange(num_xs) + f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)] + f += [None] + in_fwd = [f[i] for i in ys_perm] + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def body_fun(c, x): + c = [c[i] for i in carry_iperm] + carry_fwds, carry_dont_fwd = split_list(c, [num_const_fwds]) + carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in carry_dont_fwd] + new_c_perm = [*carry_fwds, *carry_dont_fwd] + new_c = [new_c_perm[i] for i in carry_perm] + + x = [x[i] for i in xs_perm] + x_fwd, x_dont_fwd = split_list(x, [num_input_fwds]) + x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts) + for x in x_dont_fwd] + y = [*x_fwd, *x_dont_fwd, 0] + y = [y[i] for i in ys_perm] + + return new_c, y + + 
xs = list(rng.uniform(size=(num_xs, 2))) + final, outs = jax.lax.scan(body_fun, init_vals, xs) + for f, y in zip(in_fwd, outs): + if f is not None: + self.assertAllClose(y, xs[f]) + + final_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] + self.assertAllClose(final, final_ref, check_dtypes=False) + def test_scan_diff_of_print(self): # ref: https://github.com/jax-ml/jax/issues/28738 def f(c, _): @@ -3311,6 +3364,14 @@ def g(x): eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"] self.assertIn("debug_callback", [e.primitive.name for e in eqn_jaxpr.eqns]) + def test_scan_input_to_output_forwarding(self): + def f(c, x): + return c + 1, x + def g(x): + return jax.lax.scan(f, 0, x) + jaxpr = jax.make_jaxpr(g)(jnp.arange(3.)) + self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 38ecd13a6c8fbd61308741bc63d1f07ed019806f Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Thu, 29 May 2025 15:24:44 +0000 Subject: [PATCH 1421/1769] [CI] Move k8s actions test files out of .github directory --- .github/workflows/k8s.yaml | 6 +++++- .pre-commit-config.yaml | 2 +- {.github/workflows => ci}/k8s/indexed-job.yaml | 0 {.github/workflows => ci}/k8s/jobset.yaml | 0 4 files changed, 6 insertions(+), 2 deletions(-) rename {.github/workflows => ci}/k8s/indexed-job.yaml (100%) rename {.github/workflows => ci}/k8s/jobset.yaml (100%) diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index 81552f9bb43b..86bc5e6c168b 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -4,6 +4,8 @@ on: branches: - main paths: + - '.github/workflows/k8s.yaml' + - 'ci/k8s/**' - 'jax/distributed.py' - 'jax/_src/distributed.py' - 'jax/_src/clusters/**' @@ -11,6 +13,8 @@ on: branches: - main paths: + - '.github/workflows/k8s.yaml' + - 'ci/k8s/**' - 'jax/distributed.py' - 'jax/_src/distributed.py' - 'jax/_src/clusters/**' @@ -61,7 +65,7 @@ jobs: run: kubectl apply -f jax/examples/k8s/svc-acct.yaml - name: Submit test job - run: kubectl apply -f jax/.github/workflows/k8s/${{ matrix.controller }}.yaml + run: kubectl apply -f jax/ci/k8s/${{ matrix.controller }}.yaml - name: Check job status shell: bash -e -o pipefail {0} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71b3d51caaa7..8cc28c9fe4ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: exclude: | (?x)^( examples/k8s/svc-acct\.yaml | - \.github/workflows/k8s/indexed-job\.yaml + ci/k8s/indexed-job\.yaml )$ - id: end-of-file-fixer # only include python files diff --git a/.github/workflows/k8s/indexed-job.yaml b/ci/k8s/indexed-job.yaml similarity index 100% rename from .github/workflows/k8s/indexed-job.yaml rename to ci/k8s/indexed-job.yaml diff --git a/.github/workflows/k8s/jobset.yaml b/ci/k8s/jobset.yaml similarity index 100% rename from .github/workflows/k8s/jobset.yaml rename to ci/k8s/jobset.yaml From 67c5e2803f5857dfb92d9c83781d5e8abba2cf6e Mon Sep 17 00:00:00 2001 From: Jen Ha Date: Thu, 29 May 2025 09:31:36 -0700 Subject: [PATCH 1422/1769] cloud_tpu_init: Remove verbose logging. This log may be verbose for most use cases, removing. 
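For context: the behavior the removed log line used to announce survives
unchanged. A minimal sketch, using only the two names visible in the hunk
below (num_available_tpu_chips_and_device_id and TPU_SKIP_MDS_QUERY); the
wrapper function itself is hypothetical:

    import os
    from jax._src import hardware_utils

    def maybe_skip_mds_query() -> None:
      # Mirrors the post-patch control flow in cloud_tpu_init(): when libtpu
      # is usable but no TPU chips are attached, the TPU metadata-server
      # query is skipped -- now silently.
      num_tpu_chips, _ = hardware_utils.num_available_tpu_chips_and_device_id()
      if num_tpu_chips == 0:
        os.environ['TPU_SKIP_MDS_QUERY'] = '1'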
--- jax/_src/cloud_tpu_init.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index f42794db7696..0d4f37203fbc 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -15,14 +15,11 @@ import datetime import os import re -import logging import warnings from jax import version from jax._src import config from jax._src import hardware_utils -logger = logging.getLogger(__name__) - running_in_cloud_tpu_vm: bool = False @@ -78,7 +75,6 @@ def cloud_tpu_init() -> None: libtpu_path = get_tpu_library_path() num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id() if num_tpu_chips == 0: - logger.info('Using LibTPU with a device other than TPU. Skipping TPU metadata query.') os.environ['TPU_SKIP_MDS_QUERY'] = '1' if ( tpu_id is not None From 42977e51816b9eb42c7360abe05f56cad70e894a Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 29 May 2025 09:32:09 -0700 Subject: [PATCH 1423/1769] [Pallas/Fuser] Add custom_vjp_call rule for physicalize PiperOrigin-RevId: 764763254 --- jax/_src/pallas/fuser/BUILD | 1 + jax/_src/pallas/fuser/fusible_dtype.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index a4c3402f5309..d5b5d128241d 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -115,6 +115,7 @@ pytype_strict_library( "//jax", "//jax:api_util", "//jax:core", + "//jax:custom_derivatives", "//jax:dtypes", "//jax:partial_eval", "//jax:source_info_util", diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 7d9c2ca67855..09cd8f57dbc1 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -22,6 +22,7 @@ import jax from jax._src import api_util from jax._src import core +from jax._src import custom_derivatives from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util @@ -312,6 +313,25 @@ def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule +def _custom_vjp_call_physicalize_rule( + ctx: Context, *args, call_jaxpr, num_consts, fwd_jaxpr_thunk, bwd, **kwargs +): + _assert_no_fusion_types(ctx.avals_out) + new_jaxpr = physicalize_closed_jaxpr(call_jaxpr) + fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk) + new_fwd = lu.wrap_init(physicalize(fwd.f_transformed), debug_info=fwd.debug_info) + const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts]) + bwd = custom_derivatives._handle_consts_in_bwd(bwd, const_avals) + new_bwd = lu.wrap_init(physicalize(bwd.f_transformed), debug_info=bwd.debug_info) + return custom_derivatives.custom_vjp_call_p.bind( + fun, new_fwd, new_bwd, *args, **kwargs + ) + +_physicalize_rules[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call_physicalize_rule + + def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized): _assert_no_fusion_types(ctx.avals_in) _assert_no_fusion_types(ctx.avals_out) From 605b8c0cc4216032e1aa9644fb7da39f04bafed5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 29 May 2025 10:38:20 -0700 Subject: [PATCH 1424/1769] Expose `GSPMDSharding` via `jex` as a temporary measure. 
PiperOrigin-RevId: 764791015 --- jax/extend/BUILD | 6 ++++++ jax/extend/sharding.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 jax/extend/sharding.py diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 6dc5d7d76311..c2a5c48bd2b0 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -71,6 +71,12 @@ pytype_strict_library( deps = ["//jax"], ) +pytype_strict_library( + name = "sharding", + srcs = ["sharding.py"], + deps = ["//jax:sharding_impls"], +) + pytype_strict_library( name = "source_info_util", srcs = ["source_info_util.py"], diff --git a/jax/extend/sharding.py b/jax/extend/sharding.py new file mode 100644 index 000000000000..8af2bf397249 --- /dev/null +++ b/jax/extend/sharding.py @@ -0,0 +1,17 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(yashkatariya): Remove this after NamedSharding supports more complicated +# shardings like sub-axes, strided shardings, etc. +from jax._src.sharding_impls import GSPMDSharding as GSPMDSharding From 64ef37a6fe33ba4c264750bbbb0bdb086406818c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 29 May 2025 10:48:02 -0700 Subject: [PATCH 1425/1769] [pallas:mosaic] Enabled more lowering rules for all kernel types PiperOrigin-RevId: 764795007 --- jax/_src/pallas/mosaic/lowering.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index c6aaf77199b5..e2dfe526ea14 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3212,15 +3212,16 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): return [] -@register_lowering_rule(primitives.program_id_p) +@register_lowering_rule( + primitives.program_id_p, kernel_types=[*tpu_core.KernelType] +) def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." 
     )
   length = len(ctx.lowering_context.user_grid_indices)
-  if not (0 <= axis < length):
+  if axis not in range(length):
     raise ValueError(
         f"user passed in program id with axis: {axis}, but grid only has"
         f" length: {length}"
@@ -3228,7 +3229,9 @@ def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
   return ctx.lowering_context.user_grid_indices[axis]
 
 
-@register_lowering_rule(primitives.num_programs_p)
+@register_lowering_rule(
+    primitives.num_programs_p, kernel_types=[*tpu_core.KernelType]
+)
 def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
   mapped_axes = set(ctx.lowering_context.mapped_dims)
   seen_user_axes = 0

From 1e334cfdd27b82f4af98e0a744b5af0e2a3634ec Mon Sep 17 00:00:00 2001
From: Zhonglin Han
Date: Wed, 28 May 2025 14:07:14 -0700
Subject: [PATCH 1426/1769] Add dtype arg to collective_matmul_mgpu.py to
 support bfloat16

---
 .../pallas/ops/gpu/collective_matmul_mgpu.py  | 9 ++++++---
 tests/pallas/mgpu_collective_matmul_test.py   | 4 +++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
index 854d75dbf6a3..36d29cf082d8 100644
--- a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
@@ -42,6 +42,7 @@ def all_gather_lhs_matmul(
     block_n: int,
     block_k: int,
     max_concurrent_steps: int,
+    dtype: jnp.dtype = jnp.float16,
 ) -> jax.Array:
   if (num_devices := jax.device_count()) != jax.process_count():
     raise ValueError("The kernel only supports one device per process")
@@ -49,6 +50,8 @@
     raise ValueError("The kernel can only work over all devices in a Mesh.")
   if max_concurrent_steps < 2:
     raise ValueError("max_concurrent_steps must be >= 2")
+  if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
+    raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")
 
   num_sms = 132  # There are 132 SMs on an H100 SXM GPU.
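The membership check added above normalizes everything through jnp.dtype, so
scalar types and dtype objects are all accepted uniformly. A standalone
sketch of the same validation (the helper name is ours, not part of the
patch):

    import jax.numpy as jnp

    def check_matmul_dtype(dtype) -> None:
      # Same test as in all_gather_lhs_matmul: only f16/bf16 are accepted.
      if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
        raise NotImplementedError(
            f"Only f16 and bf16 are supported, got dtype: {dtype}")

    check_matmul_dtype(jnp.bfloat16)  # accepted
    check_matmul_dtype(jnp.float16)   # accepted
    # check_matmul_dtype(jnp.float32) would raise NotImplementedError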
@@ -121,7 +124,7 @@ def device_loop(device_offset, _): @functools.partial( pl.run_scoped, acc_ref=plgpu.ACC((block_m, block_n)), - out_smem=plgpu.SMEM((block_m, block_n), jnp.float16, transforms=transforms), + out_smem=plgpu.SMEM((block_m, block_n), dtype, transforms=transforms), ) def _(acc_ref, out_smem): pl.semaphore_wait(capacity_sem) @@ -173,8 +176,8 @@ def k_loop(idxs, lhs_smem, rhs_smem): result, _ = plgpu.kernel( kernel_body, - out_shape=[jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), jnp.float16), - jax.ShapeDtypeStruct((num_sms, 2, block_m, k), jnp.float16)], + out_shape=[jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype), + jax.ShapeDtypeStruct((num_sms, 2, block_m, k), dtype)], scratch_shapes=[ plgpu.SemaphoreType.REGULAR, plgpu.SemaphoreType.REGULAR, ], diff --git a/tests/pallas/mgpu_collective_matmul_test.py b/tests/pallas/mgpu_collective_matmul_test.py index 386162b1992c..3760c7ccddb7 100644 --- a/tests/pallas/mgpu_collective_matmul_test.py +++ b/tests/pallas/mgpu_collective_matmul_test.py @@ -68,6 +68,7 @@ def setUp(self): block_n=(64, 128, 192), block_k=(64, 128), max_concurrent_steps=(2, 4), + dtype=(jnp.float16, jnp.bfloat16), ) def test_all_gather_lhs_matmul( self, @@ -78,9 +79,9 @@ def test_all_gather_lhs_matmul( block_n, block_k, max_concurrent_steps, + dtype, ): num_devices = jax.device_count() - dtype = jnp.float16 lhs_smem_size = block_m * block_k * max_concurrent_steps * 2 rhs_smem_size = block_k * block_n * max_concurrent_steps * 2 # H100 SMEM limit is 228kB. @@ -118,6 +119,7 @@ def run(body): block_n=block_n, block_k=block_k, max_concurrent_steps=max_concurrent_steps, + dtype=dtype, ) ) ref_out = run(lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y) From f823fafc86a7d92b8750814618217611b9e9b3c0 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Thu, 29 May 2025 11:03:28 -0700 Subject: [PATCH 1427/1769] Update `compile_options_proto_cc` deps to new proto dir. Follow up to https://github.com/openxla/xla/pull/24690. PiperOrigin-RevId: 764802248 --- jaxlib/BUILD | 1 - jaxlib/xla_compiler.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index dd96b7d23a8e..7be24ce5e825 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1228,7 +1228,6 @@ cc_library( "@xla//xla/hlo/builder:xla_computation", "@xla//xla/hlo/ir:hlo", "@xla//xla/hlo/parser:hlo_parser", - "@xla//xla/pjrt:compile_options_proto_cc", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:pjrt_executable", "@xla//xla/pjrt:status_casters", diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index f9ec134793ed..1b9c8c43b126 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -59,7 +59,6 @@ limitations under the License. 
#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/proto/compile_options.pb.h" From 6ddbdd989ee181fd8f1e2b1ed32b03fceee6b72b Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 29 May 2025 11:29:30 -0700 Subject: [PATCH 1428/1769] [Mosaic GPU] Fix collective argument to infer_tmem_layout PiperOrigin-RevId: 764813418 --- jax/experimental/mosaic/gpu/tcgen05.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 13d945249b69..0438400f6310 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -151,7 +151,7 @@ def mma( element_type2 = a.dtype if collective: raise NotImplementedError("Collective not supported for TMEMRef") - if a.layout != (expected_layout := _infer_tmem_layout(a.shape, collective, packing=2)): + if a.layout != (expected_layout := _infer_tmem_layout(a.shape, packing=2)): raise ValueError( f"A layout mismatch: expected {expected_layout}, got {a.layout}" ) From 57fe3f2aa60239e792ff46607372b4826440033b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 29 May 2025 11:43:52 -0700 Subject: [PATCH 1429/1769] [Mosaic GPU] Check that the device order in the mesh follows logical_ids As the comment in the code explains, we expect that the mesh ordering follows device ids, which should always equal the NVSHMEM PE ids that Mosaic uses for its collective implementations. Any divergence would have to be resolved through an extra translation layer at runtime. PiperOrigin-RevId: 764819378 --- jax/experimental/mosaic/gpu/core.py | 31 +++++++++++++++++++++ tests/pallas/gpu_pallas_distributed_test.py | 26 +++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 4ed551654a0e..193fd1bd3589 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -27,6 +27,7 @@ import weakref import jax +from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.lib import mosaic_gpu_dialect as dialect from jaxlib.mlir import ir @@ -127,6 +128,15 @@ def _mosaic_gpu_abstract_eval(*_, module, out_types): del module # Unused. return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + +def _has_communication(module, **_): + empty_str_attr = ir.StringAttr.get("") + for op in module.body: + if "nvshmem" in getattr(op, "sym_name", empty_str_attr).value: + return True + return False + + # TODO(apaszke): Implement a proper system for managing kernel lifetimes KNOWN_KERNELS = {} @@ -139,6 +149,27 @@ def _mosaic_gpu_lowering_rule( input_output_aliases: tuple[tuple[int, int], ...] = (), use_custom_barrier: bool = False, ): + axis_context = ctx.module_context.axis_context + if _has_communication(module): + # Those checks are trying to ensure that the logical device ids are + # consistent with the NVSHMEM PE ids that Mosaic will be using for + # communication. Any divergence here would require us to implement a logical + # to physical translation, which is currently not implemented. 
+ if isinstance(axis_context, sharding_impls.SPMDAxisContext): + mesh = axis_context.mesh + if not np.array_equal(mesh.device_ids.ravel(), np.arange(mesh.size)): + raise NotImplementedError( + "Mosaic GPU only supports meshes with device ordering that follows" + " row-major device ids." + ) + elif isinstance(axis_context, sharding_impls.ShardingContext): + if axis_context.num_devices != 1: + raise NotImplementedError( + "Mosaic GPU only supports single-device meshes in ShardingContext." + ) + else: + raise NotImplementedError(f"Unsupported sharding context: {axis_context}") + assert len(args) == len(ctx.avals_in) assert len(out_types) == len(ctx.avals_out) module = _run_serde_pass( diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index 3aeee352ff6d..163adc385b23 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -116,6 +116,32 @@ def kernel(y_ref, sem): )() np.testing.assert_allclose(y, jnp.ones_like(y)) + def test_permuted_mesh(self): + def kernel(y_ref, sem): + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 1, device_id=other_dev_id, + device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_wait(sem) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + mesh = jax.sharding.Mesh(jax.devices()[::-1], ['x']) # Reverse the devices. + f = jax.jit( + shard_map.shard_map( + kernel_call, mesh, in_specs=(), out_specs=P(None), check_rep=False, + ) + ) + msg = ( + 'Mosaic GPU only supports meshes with device ordering that follows' + ' row-major device ids.' + ) + with self.assertRaisesRegex(NotImplementedError, msg): + f() + if __name__ == '__main__': # This test doesn't work with the platform allocator, so we override it From da845deb30955cfb32d265457903ad90fc3e2eb7 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Thu, 29 May 2025 18:53:29 +0000 Subject: [PATCH 1430/1769] Lock down more permissions and update default usage for some workflows --- .github/workflows/bazel_cpu_py_import_rbe.yml | 4 ---- .github/workflows/bazel_cuda_non_rbe.yml | 7 +------ .github/workflows/build_artifacts.yml | 11 ----------- .github/workflows/pytest_cpu.yml | 8 +------- .github/workflows/pytest_cuda.yml | 10 +--------- .github/workflows/pytest_tpu.yml | 11 +---------- 6 files changed, 4 insertions(+), 47 deletions(-) diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml index 09c9d173e0d0..cc3ae89d97f9 100644 --- a/.github/workflows/bazel_cpu_py_import_rbe.yml +++ b/.github/workflows/bazel_cpu_py_import_rbe.yml @@ -16,22 +16,18 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-n2-16" python: description: "Which python version to test?" type: string - required: true default: "3.12" enable-x64: description: "Should x64 mode be enabled?" type: string - required: true default: "0" halt-for-connection: description: 'Should this workflow run wait for a remote connection?' 
type: string - required: false default: 'no' jobs: diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 458589199c53..d30e1b56dab8 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -17,33 +17,28 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-n2-16" python: description: "Which python version to test?" type: string - required: true default: "3.12" enable-x64: description: "Should x64 mode be enabled?" type: string - required: true default: "0" jaxlib-version: description: "Which jaxlib version to test? (head/pypi_latest)" type: string - required: true default: "head" gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: string - required: false default: 'no' +permissions: {} jobs: run-tests: defaults: diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 72d554aa5d1b..95ab90412494 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -12,7 +12,6 @@ on: runner: description: "Which runner should the workflow run on?" type: choice - required: true default: "linux-x86-n2-16" options: - "linux-x86-n2-16" @@ -21,7 +20,6 @@ on: artifact: description: "Which JAX artifact to build?" type: choice - required: true default: "jaxlib" options: - "jax" @@ -31,7 +29,6 @@ on: python: description: "Which python version should the artifact be built for?" type: choice - required: false default: "3.12" options: - "3.10" @@ -41,7 +38,6 @@ on: clone_main_xla: description: "Should latest XLA be used?" type: choice - required: false default: "0" options: - "1" @@ -49,7 +45,6 @@ on: halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: choice - required: false default: 'no' options: - 'yes' @@ -59,31 +54,25 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-n2-16" artifact: description: "Which JAX artifact to build?" type: string - required: true default: "jaxlib" python: description: "Which python version should the artifact be built for?" type: string - required: false default: "3.12" clone_main_xla: description: "Should latest XLA be used?" type: string - required: false default: "0" upload_artifacts_to_gcs: description: "Should the artifacts be uploaded to a GCS bucket?" - required: true default: true type: boolean gcs_upload_uri: description: "GCS location prefix to where the artifacts should be uploaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string outputs: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index a92f2d96dc89..95086257c62b 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -17,34 +17,28 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-n2-16" python: description: "Which python version should the artifact be built for?" 
type: string - required: true default: "3.12" enable-x64: description: "Should x64 mode be enabled?" type: string - required: true default: "0" download-jax-only-from-gcs: description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" - required: false default: '0' type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: string - required: false default: 'no' - +permissions: {} jobs: run-tests: defaults: diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 6fa4e14f8b85..d576370bb772 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -17,44 +17,36 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-n2-16" python: description: "Which python version to test?" type: string - required: true default: "3.12" cuda-version: description: "Which CUDA version to test?" type: string - required: true default: "12.8" use-nvidia-pip-wheels: description: "Whether to download CUDA packages from PyPI?" type: boolean - required: false default: false enable-x64: description: "Should x64 mode be enabled?" type: string - required: true default: "0" download-jax-only-from-gcs: description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" - required: false default: '0' type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: string - required: false default: 'no' - +permissions: {} jobs: run-tests: defaults: diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 5f56b165c295..313bbede52f5 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -22,32 +22,26 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-ct5lp-224-8tpu" cores: description: "How many TPU cores should the test use?" type: string - required: true default: "8" tpu-type: description: "Which TPU type is used for testing?" type: string - required: true default: "v5e-8" python: description: "Which Python version should be used for testing?" type: string - required: true default: "3.12" run-full-tpu-test-suite: description: "Should the full TPU test suite be run?" type: string - required: false default: "0" libtpu-version-type: description: "Which libtpu version should be used for testing?" type: string - required: false # Choices are: # - "nightly": Use the nightly libtpu wheel. # - "pypi_latest": Use the latest libtpu wheel from PyPI. 
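The recurring edit in this commit is the empty top-level permissions block.
Its effect, sketched on a minimal workflow (the job and step here are
illustrative, not from the patch): `permissions: {}` at the workflow level
drops every default `GITHUB_TOKEN` scope, so each job must re-grant exactly
what it needs.

    permissions: {}
    jobs:
      example-job:
        runs-on: ubuntu-latest
        permissions:
          contents: read  # re-grant only the scope this job needs
        steps:
          - uses: actions/checkout@v4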
@@ -55,20 +49,17 @@ on: default: "nightly" download-jax-only-from-gcs: description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" - required: false default: '0' type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: string - required: false default: 'no' - +permissions: {} jobs: run-tests: defaults: From 448c07d006e5cbc0cdd95ee7e477b9bdc606e1e9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 29 May 2025 12:15:25 -0700 Subject: [PATCH 1431/1769] Reverts 42977e51816b9eb42c7360abe05f56cad70e894a PiperOrigin-RevId: 764832745 --- jax/_src/pallas/fuser/BUILD | 1 - jax/_src/pallas/fuser/fusible_dtype.py | 20 -------------------- 2 files changed, 21 deletions(-) diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index d5b5d128241d..a4c3402f5309 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -115,7 +115,6 @@ pytype_strict_library( "//jax", "//jax:api_util", "//jax:core", - "//jax:custom_derivatives", "//jax:dtypes", "//jax:partial_eval", "//jax:source_info_util", diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 09cd8f57dbc1..7d9c2ca67855 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -22,7 +22,6 @@ import jax from jax._src import api_util from jax._src import core -from jax._src import custom_derivatives from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util @@ -313,25 +312,6 @@ def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule -def _custom_vjp_call_physicalize_rule( - ctx: Context, *args, call_jaxpr, num_consts, fwd_jaxpr_thunk, bwd, **kwargs -): - _assert_no_fusion_types(ctx.avals_out) - new_jaxpr = physicalize_closed_jaxpr(call_jaxpr) - fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr), - debug_info=call_jaxpr.jaxpr.debug_info) - fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk) - new_fwd = lu.wrap_init(physicalize(fwd.f_transformed), debug_info=fwd.debug_info) - const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts]) - bwd = custom_derivatives._handle_consts_in_bwd(bwd, const_avals) - new_bwd = lu.wrap_init(physicalize(bwd.f_transformed), debug_info=bwd.debug_info) - return custom_derivatives.custom_vjp_call_p.bind( - fun, new_fwd, new_bwd, *args, **kwargs - ) - -_physicalize_rules[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call_physicalize_rule - - def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized): _assert_no_fusion_types(ctx.avals_in) _assert_no_fusion_types(ctx.avals_out) From 3abdf560cc912320e0c0b69bae9851d7d1b93d6b Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 29 May 2025 13:52:48 -0700 Subject: [PATCH 1432/1769] [Pallas Fuser] Use lu transformation to physicalize fwd/bwd functions in custom_vjp rule PiperOrigin-RevId: 764871024 --- jax/_src/pallas/fuser/BUILD | 1 + jax/_src/pallas/fuser/fusible_dtype.py | 37 ++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index a4c3402f5309..d5b5d128241d 100644 
--- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -115,6 +115,7 @@ pytype_strict_library( "//jax", "//jax:api_util", "//jax:core", + "//jax:custom_derivatives", "//jax:dtypes", "//jax:partial_eval", "//jax:source_info_util", diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 7d9c2ca67855..152b20ff66ea 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -17,11 +17,13 @@ import abc import dataclasses import functools +import itertools as it from typing import Any, Sequence, TypeVar import jax from jax._src import api_util from jax._src import core +from jax._src import custom_derivatives from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util @@ -312,6 +314,41 @@ def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule +@lu.transformation2 +def _physicalize_transform(f, *args): + vals, zeros = args[::2], args[1::2] + assert len(vals) == len(zeros) + wrapper = lambda *inner_vals: f( + *it.chain.from_iterable(zip(inner_vals, zeros)) + ) + return physicalize(wrapper)(*vals) + + +@lu.transformation2 +def _physicalize_transform_bwd(f, const_avals, *args): + return [custom_derivatives.Zero(a) for a in const_avals] + list( + physicalize(f)(*args) + ) + + +def _custom_vjp_call_physicalize_rule( + ctx: Context, *args, call_jaxpr, num_consts, fwd_jaxpr_thunk, bwd, **kwargs +): + _assert_no_fusion_types(ctx.avals_out) + new_jaxpr = physicalize_closed_jaxpr(call_jaxpr) + fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk) + fwd_physicalized = _physicalize_transform(fwd) + const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts]) + bwd_physicalized = _physicalize_transform_bwd(bwd, const_avals) + return custom_derivatives.custom_vjp_call_p.bind( + fun, fwd_physicalized, bwd_physicalized, *args, **kwargs + ) + +_physicalize_rules[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call_physicalize_rule + + def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized): _assert_no_fusion_types(ctx.avals_in) _assert_no_fusion_types(ctx.avals_out) From 976aa7ac31f49957803547ebff80720e39aba04e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 29 May 2025 13:53:08 -0700 Subject: [PATCH 1433/1769] [pallas] Added `pl.loop` -- a decorator for writing stateless loops PiperOrigin-RevId: 764871167 --- jax/_src/pallas/helpers.py | 16 ++++++++++++++++ jax/_src/pallas/mosaic_gpu/pipeline.py | 6 ++---- jax/experimental/pallas/__init__.py | 1 + .../pallas/ops/gpu/attention_mgpu.py | 4 ++-- .../pallas/ops/gpu/collective_matmul_mgpu.py | 9 +++++---- .../pallas/ops/tpu/flash_attention.py | 6 ++---- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 5c77d0a04f09..71004cd405a3 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -13,6 +13,8 @@ # limitations under the License. 
"""Pallas helper functions.""" +from collections.abc import Callable + import jax from jax._src import checkify from jax._src import config @@ -69,6 +71,20 @@ def _wrapped(f): return _wrapped +def loop( + lower: jax.typing.ArrayLike, + upper: jax.typing.ArrayLike, + *, + unroll: int | bool | None = None, +) -> Callable[[Callable[[jax.Array], None]], None]: + def decorator(body): + jax.lax.fori_loop( + lower, upper, lambda idx, _: body(idx), init_val=None, unroll=unroll + ) + + return decorator + + _ENABLE_DEBUG_CHECKS = config.bool_state( "jax_pallas_enable_debug_checks", default=False, diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index f85b73b6b946..be9f663a42b7 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -764,12 +764,10 @@ def memory_loop_body(step, carry): memory_loop_body, (indices,)) # Await all the arrivals to not leave barriers in a bad state. # We only need to account for the prologue steps. - def _epi_step(step, _): + @pl.loop(0, prologue_steps, unroll=not has_dynamic_grid) + def _epi_step(step): for barrier in consumed_barrier_refs: gpu_primitives.barrier_wait(barrier.at[step]) - jax.lax.fori_loop( - 0, prologue_steps, _epi_step, None, unroll=not has_dynamic_grid - ) wg_idx = lax.axis_index(wg_axis) lax.cond( diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index caf77a3c4fce..da2bc9119dd0 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -38,6 +38,7 @@ from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost from jax._src.pallas.helpers import empty as empty from jax._src.pallas.helpers import empty_like as empty_like +from jax._src.pallas.helpers import loop as loop from jax._src.pallas.helpers import when as when from jax._src.pallas.helpers import debug_check as debug_check from jax._src.pallas.helpers import debug_checks_enabled as debug_checks_enabled diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 447e3affd7c1..650668daf67a 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -245,7 +245,8 @@ def _memory_wg(): plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) - def kv_loop(kv_step, _): + @pl.loop(0, block_max_kv_steps - max_concurrent_steps) + def _kv_loop(kv_step): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) @@ -253,7 +254,6 @@ def kv_loop(kv_step, _): plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) - lax.fori_loop(0, block_max_kv_steps - max_concurrent_steps, kv_loop, None) def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): compute_wgs = 2 diff --git a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py index 36d29cf082d8..5e4dda4494ba 100644 --- a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py @@ -107,12 +107,13 @@ def m_loop(mi, _): # For some reason ptxas spills if we unroll the loop over k copy_block = 32 - 
def k_copy_loop(ki, _): + @pl.loop(0, k // copy_block) + def _k_copy_loop(ki): k_slice = pl.ds(ki * copy_block, copy_block) scratch_ref[0, :, k_slice] = lhs_ref[m_tile_slice, k_slice] - jax.lax.fori_loop(0, k // copy_block, k_copy_loop, None) - def device_loop(device_offset, _): + @pl.loop(0, num_devices) + def _device_loop(device_offset): # Loop invariant: scratch_ref.at[scratch_slot] is ready to be used # We're double buffering the scratch space. At each step, we read from # scratch_ref.at[scratch_slot] and write to scratch_ref.at[next_scratch_slot] @@ -168,7 +169,7 @@ def k_loop(idxs, lhs_smem, rhs_smem): ) # Wait for the next scratch to arrive --- see the loop invariant. pl.semaphore_wait(received_sem) - jax.lax.fori_loop(0, num_devices, device_loop, None) + grid_size = m_shard // block_m m_steps = grid_size // num_sms + jnp.int32(sm_id < grid_size % num_sms) # TODO(apaszke): Use the ND-loop helper. diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index ef8dd61abacb..06746986a15e 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -383,10 +383,8 @@ def start_new_sequence(): @pl.when(should_run) def run(): - @functools.partial( - lax.fori_loop, 0, block_k_major // block_k, init_val=None, unroll=True - ) - def body(i, _): + @pl.loop(0, block_k_major // block_k, unroll=True) + def _body(i): m_prev = m_scratch_ref[batch_idx] l_prev = l_scratch_ref[batch_idx] q = q_tile_ref[batch_idx] # [block_q, head_dim] From a808fe89efcd4929c8c84bb81bd04156b9199e9e Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 29 May 2025 14:37:29 -0700 Subject: [PATCH 1434/1769] [Mosaic GPU] Add non-collective blackwell matmul example PiperOrigin-RevId: 764889122 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 + .../pallas/ops/gpu/blackwell_matmul_mgpu.py | 230 ++++++++++++++++++ tests/pallas/BUILD | 17 ++ tests/pallas/mgpu_matmul_test.py | 88 +++++++ 4 files changed, 341 insertions(+) create mode 100644 jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py create mode 100644 tests/pallas/mgpu_matmul_test.py diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a56733a89f60..84caff69a090 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1543,6 +1543,8 @@ def _slice_lowering_rule( @register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) @register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Warpgroup) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: @@ -1551,6 +1553,10 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): f" {len(cases)}" ) pred_aval, *cases_avals = ctx.avals_in + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if not all(aval.shape == () for aval in ctx.avals_in): + raise NotImplementedError( + "Can only select on scalars in warp-level lowering.") [out_aval] = ctx.avals_out if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: pred = _ensure_fa(pred, pred_aval.dtype) diff --git a/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py new file mode 100644 index 000000000000..df8365f843a8 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py @@ -0,0 +1,230 @@ 
+# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Matrix Multiplication kernel for Blackwell GPUs.""" +import dataclasses +import functools +import itertools +import jax +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +import jax.numpy as jnp +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + block_m: int + block_n: int + block_k: int + max_concurrent_steps: int + collective: bool + + +def _find_swizzle(dim_size_bits: int): + """Finds the largest swizzle that fits the dimension size.""" + for swizzle_bytes in (128, 64, 32, 16): + if dim_size_bits % (swizzle_bytes * 8) == 0: + return swizzle_bytes + raise ValueError( + f"Dimension size has {dim_size_bits} bits, which is not a multiple of 128" + ) + + +def matmul_kernel(a, b, config: TuningConfig): + dtype = a.dtype + if a.dtype != b.dtype: + raise ValueError( + f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}" + ) + m, k = a.shape + k2, n = b.shape + if k != k2: + raise ValueError( + f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}" + ) + collective = config.collective + if collective: + raise ValueError("Collective matmul is not supported yet.") + block_m, block_n, block_k = (config.block_m, config.block_n, config.block_k) + swizzle = _find_swizzle(block_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + block_lhs = (block_m, block_k) + block_rhs = (block_k, block_n) + block_out = (block_m, block_n) + if m % block_m != 0: + raise ValueError(f"{m=} must be divisible by {block_m=}") + if n % block_n != 0: + raise ValueError(f"{n=} must be divisible by {block_n=}") + if k % block_k != 0: + raise ValueError(f"{k=} must be divisible by {block_k=}") + m_iters = m // block_m + n_iters = n // block_n + k_iters = k // block_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + a_tma_barrier, b_tma_barrier, consumed_barrier): + m_index = lax.axis_index("m") + n_index = lax.axis_index("n") + slice_m = pl.ds(m_index * block_m, block_m) + slice_n = pl.ds(n_index * block_n, block_n) + acc_slice_m = pl.ds(m_index * block_m, block_m) + acc_slice_n = pl.ds(n_index * block_n, block_n) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + + @pl.when(ki >= max_concurrent_steps) + def _(): + plgpu.barrier_wait(consumed_barrier.at[slot]) + + slice_k = pl.ds(ki * block_k, block_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[slice_m, slice_k], + a_smem.at[slot], + a_tma_barrier.at[slot], + ) + plgpu.copy_gmem_to_smem( + 
b_gmem.at[slice_k, slice_n], + b_smem.at[slot], + b_tma_barrier.at[slot], + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(warp_id == 1) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(a_tma_barrier.at[slot]) + plgpu.barrier_wait(b_tma_barrier.at[slot]) + is_last_iter = ki >= k_iters - 1 + barrier_slot = lax.select_n(is_last_iter, + slot, max_concurrent_steps) + plgpu.tcgen05_mma( + acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barrier.at[barrier_slot], + accumulate=(ki > 0), + ) + lax.fori_loop(0, k_iters, _loop_body, None) + + plgpu.barrier_wait(consumed_barrier.at[max_concurrent_steps]) + acc_smem[...] = acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + acc_smem, out_gmem.at[acc_slice_m, acc_slice_n] + ) + plgpu.wait_smem_to_gmem(0) + + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(m_iters, n_iters), + grid_names=("m", "n"), + # TODO(justinfu): Add collective support. + cluster_names=(), + cluster=(), + scratch_shapes=( # type: ignore + plgpu.SMEM( + (max_concurrent_steps, *block_lhs), dtype, transforms=transforms + ), + plgpu.SMEM( + (max_concurrent_steps, *block_rhs), dtype, transforms=transforms + ), + plgpu.TMEM(block_out, jnp.float32, collective=collective), + plgpu.SMEM(block_out, dtype, transforms=transforms), + plgpu.Barrier( + num_arrivals=1, num_barriers=max_concurrent_steps + ), + plgpu.Barrier( + num_arrivals=1, num_barriers=max_concurrent_steps + ), + plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps + 1, + for_tensor_core=True, + ), + ) + ) + return f(a, b) + + +def main(_) -> None: + problem_it = itertools.product( + (1024, 4096, 8192), (1024, 4096, 8192), (1024, 8192) + ) + for M, N, K in problem_it: + print(f"==== {M=} {N=} {K=} ====") + matmul_flops = 2 * M * N * K + peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS + a = jax.random.uniform(jax.random.key(0), (M, K), jnp.bfloat16) + b = jax.random.uniform(jax.random.key(1), (K, N), jnp.bfloat16) + tuning_it = itertools.product( + (128,), (128, 256), (64, 128), (2, 3, 4), (False,) + ) + best_util = -float("inf") + for (block_m, block_n, block_k, + max_concurrent_steps, collective) in tuning_it: + config = TuningConfig( + block_m=block_m, + block_n=block_n, + block_k=block_k, + max_concurrent_steps=max_concurrent_steps, + collective=collective, + ) + try: + out, runtime_ms = profiler.measure( + functools.partial(matmul_kernel, config=config) + )(a, b) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + continue + raise + if M * N * K <= 1024 * 1024 * 1024: + expected = a @ b + np.testing.assert_allclose(out, expected) + runtime_us = runtime_ms * 1e3 # type: ignore + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + if achieved_tc_util > best_util: + best_util = achieved_tc_util + print( + f"{block_m=} {block_n=} {block_k=} {max_concurrent_steps=}: " + f"{runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) + print(f"\tBest utilization: {best_util:4.1f}%") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 8be899123de0..48eddae69a60 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -805,6 +805,23 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "mgpu_matmul_test", + srcs = ["mgpu_matmul_test.py"], + 
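+    # Backends/configs below are intentionally empty until B200 runners are
+    # available in CI (see the TODO on enable_configs), so this target is
+    # effectively opt-in.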
enable_backends = [], + enable_configs = [], # TODO(justinfu): Enable B200 when available. + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 8, + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + jax_multiplatform_test( name = "mgpu_ragged_dot_run", srcs = ["//jax/experimental/pallas/ops/gpu:ragged_dot_mgpu.py"], diff --git a/tests/pallas/mgpu_matmul_test.py b/tests/pallas/mgpu_matmul_test.py new file mode 100644 index 000000000000..4013db78f6a2 --- /dev/null +++ b/tests/pallas/mgpu_matmul_test.py @@ -0,0 +1,88 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of matrix multiplication.""" + +import contextlib +import os + +from absl.testing import absltest +from absl.testing import parameterized +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +import jax.numpy as jnp +import numpy as np + + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. + import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + blackwell_matmul_mgpu = None +else: + from jax.experimental.pallas.ops.gpu import blackwell_matmul_mgpu + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +@jtu.with_config(jax_traceback_filtering="off") +class MatrixMultiplicationSm100ATest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if blackwell_matmul_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("10.0")): + self.skipTest("Only works on GPU with capability sm100a") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) + + @parameterized.product( + m=(1024, 4096), + k=(1024, 4096), + n=(1024, 4096), + dtype=(jnp.float16,), + ) + def test_matmul( + self, + m, + n, + k, + dtype, + ): + k1, k2, = jax.random.split(jax.random.key(42), 2) + a = jax.random.normal(k1, (m, k), dtype) + b = jax.random.normal(k2, (k, n), dtype) + + out = blackwell_matmul_mgpu.matmul_kernel( + a, + b, + blackwell_matmul_mgpu.TuningConfig( + block_m=128, block_n=128, block_k=128, + max_concurrent_steps=2, + collective=False, + ), + ) + out_ref = a @ b + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 81769209c08fe6be844ab703c35e7eb31bc1c54b Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 29 May 2025 14:38:55 -0700 Subject: [PATCH 1435/1769] [Pallas] Add a base class for custom BufferedRef implementations. 
PiperOrigin-RevId: 764889905 --- jax/_src/pallas/mosaic/pipeline.py | 248 +++++++++++++++++------------ jax/experimental/pallas/tpu.py | 1 + 2 files changed, 149 insertions(+), 100 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 4ef22179260b..f4dab313fb6f 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -242,9 +242,134 @@ def _get_dim_size(bd): block_shape_nones = tuple(_get_dim_size(x) for x in spec.block_shape) return tuple(x for x in block_shape_nones if x is not None) + +class BufferedRefBase: + """Abstract interface for BufferedRefs.""" + + @property + def spec(self) -> pl.BlockSpec: + raise NotImplementedError() + + @property + def buffer_type(self) -> BufferType: + raise NotImplementedError() + + @property + def is_input(self): + return self.buffer_type in [ + BufferType.INPUT, + BufferType.ACCUMULATOR, + BufferType.INPUT_OUTPUT, + ] + + @property + def is_output(self): + return self.buffer_type in [ + BufferType.OUTPUT, + BufferType.ACCUMULATOR, + BufferType.INPUT_OUTPUT, + ] + + @property + def is_accumulator(self): + return self.buffer_type == BufferType.ACCUMULATOR + + @property + def is_input_output(self): + return self.buffer_type == BufferType.INPUT_OUTPUT + + @property + def is_manual(self): + return self.buffer_type == BufferType.MANUAL + + def init_slots(self): + """Initialize slot indices.""" + raise NotImplementedError() + + def swap_slots(self): + """Switch to the next slot.""" + raise NotImplementedError() + + @property + def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None: + return self.spec.block_shape + + @property + def compute_index(self): + return self.spec.index_map + + def get_dma_slice(self, src_shape, src_dtype, grid_indices): + # We need to handle blocks that might go OOB in the src array. An in bounds + # block looks like this (for array shape (600, 600) and block shape + # (256, 256)): + # + # +--------------+------------------| + # | Block (0,0) | | + # | (256, 256) | | + # +--------------+ | + # | A (600, 600) | + # | | + # +---------------------------------+ + # + # For in-bounds blocks, we don't need to do anything special. + # An out-of-bounds block looks like this: + # + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # | XXXXXXXXXX | + # +--------------+ + # where the X's indicate where the block is out of bounds. + # + # When we have an out of bounds block like this, we need to truncate it to + # a tile boundary (tiles are (8, 128) along the two minormost dimensions). + # In this case, we'll have a block that is indexing the + # 512:768 elements of A along the first dimension. We need to convert 768 + # into 600 (600 % 8 == 0), so our indexing will look like this: + + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # where it is now a (88, 256) sized block. + # + # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block + # for the last iteration on that dimension, we will pick the next highest + # tile multiple, i.e. (96, 256). 
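+    # In short: the final block slice along a dimension is rounded up to a
+    # tile multiple (e.g. 88 -> 96 above) rather than truncated to the raw
+    # array bound.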
+ if len(src_shape) < 2: + raise NotImplementedError("Must use >1D values.") + + tiling = _make_tiling(src_shape, src_dtype) + block_indices = self.compute_index(*grid_indices) + return tuple( + _make_block_slice(bi, bs, ss, t) + for bi, bs, ss, t in zip( + block_indices, self.block_shape, src_shape, tiling, strict=True + ) + ) + + def bind_existing_ref(self, window_ref, indices): + """For handling VMEM references, the pipeline aliases the existing ref.""" + del window_ref, indices + return self + + def with_spec(self, spec: pl.BlockSpec) -> 'BufferedRefBase': + """Returns a new BufferedRefBase with the given block spec.""" + raise NotImplementedError() + + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) -class BufferedRef: +class BufferedRef(BufferedRefBase): """A helper class to automate VMEM double buffering in pallas pipelines. Attributes: @@ -257,7 +382,6 @@ class BufferedRef: reference, this simply points to the existing ref. accum_ref: accumulating buffer used by accumulator BufferedRefs. current_slot: current slot index to the working buffer. - next_slot: slot that will point to the working buffer in the next iteration. sem_recvs: Double buffered semaphores for input DMAs. sem_sends: Double buffered semaphores for output DMAs. block_shape: passthrough property for the BlockSpec's block_shape. @@ -272,33 +396,37 @@ class BufferedRef: swap: Tracks whether the BufferedRef slots need to be swapped before next copy. """ - spec: pl.BlockSpec # static metadata + _spec: pl.BlockSpec # static metadata dtype: Any # static metadata - buffer_type: BufferType # static metadata + _buffer_type: BufferType # static metadata window_ref: ArrayRef | None accum_ref: ArrayRef | None current_slot: ArrayRef | None - # TODO(ramiroleal): Unused by class. Remove argument from - # BufferedRef instantiations. - next_slot: ArrayRef | None sem_recvs: SemaphoreTuple | None sem_sends: SemaphoreTuple | None # TODO(ramiroleal): Improve prefetch/postyeet interface to avoid # using this ref. swap: ArrayRef | None + @property + def spec(self): + return self._spec + + @property + def buffer_type(self): + return self._buffer_type + def tree_flatten(self): return ( ( self.window_ref, self.accum_ref, self.current_slot, - self.next_slot, self.sem_recvs, self.sem_sends, self.swap, ), - (self.spec, self.dtype, self.buffer_type), + (self._spec, self.dtype, self._buffer_type), ) @classmethod @@ -334,13 +462,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True # reference is already in VMEM, we just need allocate the accumulation # buffer and we will refer to the original reference slices directly. 
return cls( - spec=spec, + _spec=spec, dtype=dtype, - buffer_type=buffer_type, + _buffer_type=buffer_type, window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, current_slot=None, - next_slot=None, sem_recvs=None, sem_sends=None, swap=None, @@ -348,13 +475,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True else: memory_space = SMEM if spec.memory_space == SMEM else VMEM return cls( - spec=spec, + _spec=spec, dtype=dtype, - buffer_type=buffer_type, + _buffer_type=buffer_type, window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), - next_slot=None, sem_recvs=( None if buffer_type is BufferType.OUTPUT @@ -396,6 +522,10 @@ def compute_index(self): def memory_space(self): return self.spec.memory_space + def with_spec(self, spec: pl.BlockSpec) -> 'BufferedRef': + """Returns a new BufferedRef with the given block spec.""" + return dataclasses.replace(self, _spec=spec) + @property def current_ref(self): buffer_slice = tuple( @@ -409,30 +539,6 @@ def current_ref(self): else: return self.window_ref.at[(self.current_slot_index, *buffer_slice)] - @property - def is_input(self): - return self.buffer_type in [ - BufferType.INPUT, - BufferType.ACCUMULATOR, - BufferType.INPUT_OUTPUT, - ] - - @property - def is_output(self): - return self.buffer_type in [ - BufferType.OUTPUT, - BufferType.ACCUMULATOR, - BufferType.INPUT_OUTPUT, - ] - - @property - def is_accumulator(self): - return self.buffer_type == BufferType.ACCUMULATOR - - @property - def is_input_output(self): - return self.buffer_type == BufferType.INPUT_OUTPUT - @property def current_slot_index(self): """Index in double buffer corresponding to the current slot.""" @@ -491,65 +597,6 @@ def swap_slots(self): if self.swap is not None: self.swap[0] = False - def get_dma_slice(self, src_shape, src_dtype, grid_indices): - # We need to handle blocks that might go OOB in the src array. An in bounds - # block looks like this (for array shape (600, 600) and block shape - # (256, 256)): - # - # +--------------+------------------| - # | Block (0,0) | | - # | (256, 256) | | - # +--------------+ | - # | A (600, 600) | - # | | - # +---------------------------------+ - # - # For in-bounds blocks, we don't need to do anything special. - # An out-of-bounds block looks like this: - # - # +--------------+------------------| - # | | - # | | - # + | - # | A (600, 600) | - # +--------------+ | - # | Block (2,0) | | - # + --------------------------------| - # | XXXXXXXXXX | - # +--------------+ - # where the X's indicate where the block is out of bounds. - # - # When we have an out of bounds block like this, we need to truncate it to - # a tile boundary (tiles are (8, 128) along the two minormost dimensions). - # In this case, we'll have a block that is indexing the - # 512:768 elements of A along the first dimension. We need to convert 768 - # into 600 (600 % 8 == 0), so our indexing will look like this: - - # +--------------+------------------| - # | | - # | | - # + | - # | A (600, 600) | - # +--------------+ | - # | Block (2,0) | | - # + --------------------------------| - # where it is now a (88, 256) sized block. - # - # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block - # for the last iteration on that dimension, we will pick the next highest - # tile multiple, i.e. (96, 256). 
- if len(src_shape) < 2: - raise NotImplementedError("Must use >1D values.") - - tiling = _make_tiling(src_shape, src_dtype) - block_indices = self.compute_index(*grid_indices) - return tuple( - _make_block_slice(bi, bs, ss, t) - for bi, bs, ss, t in zip( - block_indices, self.block_shape, src_shape, tiling, strict=True - ) - ) - def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" assert self.is_input @@ -674,7 +721,8 @@ def accumulate(self): # Helper to tree map over BufferedRefs as leaves. map_brefs = functools.partial( jax.tree.map, - is_leaf=lambda x: isinstance(x, BufferedRef)) + is_leaf=lambda x: isinstance(x, BufferedRefBase) +) def _filter_indices( @@ -922,7 +970,7 @@ def _end(): buffered_ref.wait_out(dst_ref, self.indices) def swap_slots(self, buffered_ref, hbm_ref, schedule=None): - if buffered_ref.swap is not None: + if isinstance(buffered_ref, BufferedRef) and buffered_ref.swap is not None: swap = buffered_ref.swap[0] else: # If we are not using an SMEM `swap` tensor to keep track of diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 401b2fe66c45..e27fdaaadd8f 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -30,6 +30,7 @@ from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef +from jax._src.pallas.mosaic.pipeline import BufferedRefBase as BufferedRefBase from jax._src.pallas.mosaic.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations as emit_pipeline_with_allocations from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule as get_pipeline_schedule From 5b1372998c1c20f1e938ee4b715fd96d0cd54feb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 29 May 2025 14:59:46 -0700 Subject: [PATCH 1436/1769] [Mosaic GPU] Add support for inout arguments They are passed in between inputs and outputs to the kernel body and returned after outputs from the kernel. 
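A minimal usage sketch (mirroring the new test added in this change;
`mgpu` is the Mosaic GPU core module, and `memref`/`arith`/`gpu`/`ir`
are the usual MLIR dialect bindings):

    def kernel(ctx, src, inout, dst, smem):
      # The inout ref arrives between the input (src) and output (dst) refs.
      val = memref.load(inout, [])
      gpu.barrier()  # every thread reads the old value before it is replaced
      new_val = arith.constant(ir.IntegerType.get_signless(32), 42)
      memref.store(new_val, inout, [])
      x = mgpu.FragmentedArray.load_strided(src, is_signed=True)
      (x + val).store_untiled(dst)

    x = jnp.arange(128, dtype=jnp.int32)
    y = jnp.asarray(2, dtype=jnp.int32)
    f = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), x, x, (), inout_shape=y)
    xo, yo = f(x, y)  # the final inout value is returned after the outputs
                      # xo == x + 2, yo == 42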
PiperOrigin-RevId: 764898480 --- jax/_src/pallas/mosaic_gpu/lowering.py | 1 + .../mosaic_gpu/pallas_call_registration.py | 1 + jax/experimental/mosaic/gpu/core.py | 51 ++++++++++++++----- tests/mosaic/gpu_test.py | 20 ++++++++ 4 files changed, 59 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 84caff69a090..003fa0419f63 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -907,6 +907,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): block=block, in_shapes=(*in_shapes, *semaphores_shape), out_shape=(*out_shapes, *semaphores_shape), + inout_shape=(), smem_scratch_shape=scratch_buffers, lowering_semantics=lowering_semantics, module_name=mlir.sanitize_name(debug_info.func_name), diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index a14ccbb7daa9..ccbe4d36edc9 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -105,6 +105,7 @@ def zero_init_gmem_scratch(): *args, *scratch_args, module=module, out_types=lowering_result.new_out_shapes, + inout_types=(), input_output_aliases=input_output_aliases, use_custom_barrier=False, # False until we add get_barrier_semaphore() feature ) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 193fd1bd3589..79a0cd56328b 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -26,6 +26,7 @@ from typing import Any, Callable, Generic, TypeVar import weakref +import itertools import jax from jax._src import sharding_impls from jax._src.interpreters import mlir @@ -124,9 +125,12 @@ def supports_cross_device_collectives(): @mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types): +def _mosaic_gpu_abstract_eval(*_, module, out_types, inout_types): del module # Unused. - return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + return [ + jax._src.core.ShapedArray(t.shape, t.dtype) + for t in itertools.chain(out_types, inout_types) + ] def _has_communication(module, **_): @@ -146,6 +150,7 @@ def _mosaic_gpu_lowering_rule( *args, module, out_types, + inout_types, input_output_aliases: tuple[tuple[int, int], ...] 
= (), use_custom_barrier: bool = False, ): @@ -170,8 +175,19 @@ def _mosaic_gpu_lowering_rule( else: raise NotImplementedError(f"Unsupported sharding context: {axis_context}") - assert len(args) == len(ctx.avals_in) - assert len(out_types) == len(ctx.avals_out) + if inout_types: + if input_output_aliases: + raise ValueError( + "input_output_aliases and inout_types are mutually exclusive" + ) + num_inputs = len(ctx.avals_in) + num_outputs = len(ctx.avals_out) + input_output_aliases = tuple( + (num_inputs - 1 - i, num_outputs - 1 - i) + for i in range(len(inout_types)) + ) + assert len(ctx.avals_in) == len(args) + assert len(ctx.avals_out) == len(out_types) + len(inout_types) module = _run_serde_pass( module, serialize=True, @@ -562,6 +578,7 @@ def _lower_as_gpu_kernel( block: tuple[int, int, int], in_shapes: tuple[Any, ...], out_shape, + inout_shape, smem_scratch_shape: ShapeTree | Union[ShapeTree], lowering_semantics: LoweringSemantics, module_name: str, @@ -576,13 +593,14 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: return ir.MemRefType.get(shape.shape, utils.dtype_to_ir_type(shape.dtype)) in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] + inout_ref_tys = [_shape_to_ref_ty(t) for t in inout_shape] unwrap_output_tuple = False if isinstance(out_shape, list): out_shape = tuple(out_shape) elif not isinstance(out_shape, tuple): out_shape = (out_shape,) - unwrap_output_tuple = True + unwrap_output_tuple = not inout_shape out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] if prof_spec is not None: out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) @@ -610,19 +628,18 @@ def main(token_ptr, buffers): nonlocal launch_ctx token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] - for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): + # XLA will pass in inout refs again as outputs, but we ignore them. 
+ for i, ref_ty in enumerate([*in_ref_tys, *inout_ref_tys, *out_ref_tys]): ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty, llvm.GEPNoWrapFlags.none)) arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) - in_refs = arg_refs[:len(in_ref_tys)] - out_refs = arg_refs[len(in_ref_tys):] - prof_buffer = out_refs.pop() if prof_spec is not None else None + prof_buffer = arg_refs.pop() if prof_spec is not None else None with _launch( token, grid, cluster, block, smem_scratch_shape, lowering_semantics, module, prof_spec, prof_buffer ) as (_launch_ctx, smem_refs): nonlocal launch_ctx launch_ctx = _launch_ctx - body(launch_ctx, *in_refs, *out_refs, smem_refs) + body(launch_ctx, *arg_refs, smem_refs) main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() sym_tab = ir.SymbolTable(module.operation) sym_tab.insert(main.func_op) @@ -680,16 +697,22 @@ def as_gpu_kernel( kernel_name: str | None = None, ir_version: int | None = None, thread_semantics: LoweringSemantics = LoweringSemantics.Lane, + inout_shape = (), ): if isinstance(in_shape, list): in_shape = tuple(in_shape) elif not isinstance(in_shape, tuple): in_shape = (in_shape,) + if isinstance(inout_shape, list): + inout_shape = tuple(inout_shape) + elif not isinstance(inout_shape, tuple): + inout_shape = (inout_shape,) module, out_shape, unwrap_output_tuple, launch_ctx = ( _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - thread_semantics, module_name, kernel_name, prof_spec + body, grid, cluster, block, in_shape, out_shape, inout_shape, + smem_scratch_shape, thread_semantics, module_name, kernel_name, + prof_spec ) ) @@ -711,7 +734,7 @@ def as_gpu_kernel( if launch_ctx.is_device_collective and not supports_cross_device_collectives(): raise RuntimeError("Kernel is a cross-device collective but no support is available.") - expected_arg_tys, expected_arg_treedef = jax.tree.flatten(in_shape) + expected_arg_tys, expected_arg_treedef = jax.tree.flatten((*in_shape, *inout_shape)) def _check_args(*args): arg_treedef = jax.tree.structure(args) if arg_treedef != expected_arg_treedef: @@ -735,7 +758,7 @@ def _check_args(*args): ) def bind(*args) -> Any: - return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape) + return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape, inout_types=inout_shape) if prof_spec is not None: @jax.jit diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 99b7d67cd691..a56aa04f6f60 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3543,6 +3543,26 @@ def test_pass_is_registered(self): pipeline.run(module.operation) +class ApiTest(TestCase): + + def test_inout(self): + def kernel(ctx, src, inout, dst, smem): + val = memref.load(inout, []) + gpu.barrier() + new_val = arith.constant(ir.IntegerType.get_signless(32), 42) + memref.store(new_val, inout, []) + x = mgpu.FragmentedArray.load_strided(src, is_signed=True) + (x + val).store_untiled(dst) + x = jnp.arange(128, dtype=jnp.int32) + y = jnp.asarray(2.0, dtype=jnp.int32) + kernel = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (), inout_shape=y, + ) + xo, yo = kernel(x, y) + np.testing.assert_array_equal(xo, x + 2.0) + np.testing.assert_array_equal(yo, jnp.asarray(42, dtype=jnp.int32)) + + if hp is not None: @hps.composite def tiled_layouts( From 7eec8e1b6a6ed30fd8ce6a5d42134e0d2e8492aa Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 17 May 2025 03:23:03 +0000 Subject: [PATCH 1437/1769] [hijax] all pre-existing Box 
tests passing, still using typechange env Co-authored-by: Dougal Maclaurin --- jax/_src/core.py | 44 ++- jax/_src/interpreters/ad.py | 20 +- jax/_src/interpreters/partial_eval.py | 87 ++-- jax/_src/interpreters/pxla.py | 4 +- jax/_src/lax/control_flow/loops.py | 78 +++- jax/_src/pjit.py | 85 ++-- tests/attrs_test.py | 3 +- tests/hijax_test.py | 544 +++++++++++++++++++------- 8 files changed, 665 insertions(+), 200 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index b20b85a43b6e..24150aba6584 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -88,7 +88,8 @@ class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', - '_effects', '_debug_info', '_is_high', '_final_typechange_env'] + '_effects', '_debug_info', '_is_high', + '_initial_typechange_env', '_final_typechange_env'] _constvars: list[Var] _invars: list[Var] @@ -97,6 +98,7 @@ class Jaxpr: _effects: Effects _debug_info: DebugInfo _is_high: bool + _initial_typechange_env: dict[Var, Any] _final_typechange_env: dict[Var, Any] @property @@ -127,6 +129,10 @@ def debug_info(self) -> DebugInfo: def is_high(self) -> bool: return self._is_high + @property + def initial_typechange_env(self) -> dict[Var, Any]: + return self._initial_typechange_env + @property def final_typechange_env(self) -> dict[Var, Any]: return self._final_typechange_env @@ -139,6 +145,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # is missing. debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment] is_high: bool = False, + initial_typechange_env: dict | None = None, final_typechange_env: dict | None = None, ): """ @@ -165,6 +172,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) self._is_high = is_high + self._initial_typechange_env = initial_typechange_env or {} self._final_typechange_env = final_typechange_env or {} def __str__(self): @@ -193,6 +201,8 @@ def replace(self, **kwargs): effects=kwargs.pop("effects", self.effects), debug_info=kwargs.pop("debug_info", self.debug_info), is_high=kwargs.pop("is_high", self.is_high), + initial_typechange_env=kwargs.pop("initial_typechange_env", + self.initial_typechange_env), final_typechange_env=kwargs.pop("final_typechange_env", self.final_typechange_env), ) @@ -222,6 +232,22 @@ def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]: yield from jaxprs_in_params(eqn.params) +@dataclass(frozen=True) +class TypeChange: + aval: AbstractValue + initial_type_state: Any + final_type_state: Any + + def to_tangent_aval(self): + return TypeChange(self.aval.to_tangent_aval(), + self.initial_type_state.to_tangent_aval(), + self.final_type_state.to_tangent_aval()) + + def normalize(self): + return TypeChange(self.aval.normalize(), + self.initial_type_state.normalize(), + self.final_type_state.normalize()) + class ClosedJaxpr: __slots__ = ['__weakref__', '_jaxpr', '_consts'] @@ -241,6 +267,13 @@ def __init__(self, jaxpr: Jaxpr, consts: Sequence): def in_avals(self): return [v.aval for v in self.jaxpr.invars] + @property + def in_avals_aug(self): + ienv = self.jaxpr.initial_typechange_env + fenv = self.jaxpr.final_typechange_env + return [TypeChange(v.aval, ienv[v], fenv[v]) if v.aval.mutable else v.aval + for v in self.jaxpr.invars] + @property def out_avals(self): return [v.aval for v in self.jaxpr.outvars] @@ -542,10 +575,6 @@ def _true_bind(self, *args, **params): # is called 
frequently and it's slightly faster to avoid using a context # manager object. prev_trace = trace_ctx.trace - - if self.is_high(**params) and prev_trace.requires_low: - return self.to_lojax(*args, **params) # type: ignore - trace_ctx.set_trace(eval_trace) try: return self.bind_with_trace(prev_trace, args, params) @@ -553,6 +582,11 @@ def _true_bind(self, *args, **params): trace_ctx.set_trace(prev_trace) def bind_with_trace(self, trace, args, params): + # TODO(mattjj,dougalm): remove this block? + if self.is_high(**params) and trace.requires_low: + with set_current_trace(trace): + return self.to_lojax(*args, **params) # type: ignore + return trace.process_primitive(self, args, params) def def_impl(self, impl): diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 7cbdfff01462..0cd99a197f66 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -641,6 +641,9 @@ def to_concrete_value(self): def get_referent(self): return core.get_referent(self.primal) + def type_state(self): + return self.primal.type_state() + def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = get_aval(primal).strip_weak_type() @@ -1166,8 +1169,9 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, debug_info=jaxpr.jaxpr.debug_info) f_jvp, out_nonzeros = f_jvp_traceable( jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) - tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] - avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) + tangent_avals = [aval.to_tangent_aval() + for aval, nz in zip(jaxpr.in_avals_aug, nonzeros) if nz] + avals_in = list(it.chain(jaxpr.in_avals_aug, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic( f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() @@ -1189,14 +1193,12 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_ new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars) new_debug_info = jaxpr.jaxpr.debug_info - new_arg_names = tuple(_perm(primals_in, tangents_in, - jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars)))) - new_result_paths = tuple(_perm(primals_out, tangents_out, - jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars)))) + arg_names = jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.in_avals)) + result_paths = jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.out_avals)) + new_arg_names = tuple(_perm(primals_in, tangents_in, arg_names)) + new_result_paths = tuple(_perm(primals_out, tangents_out, result_paths)) new_debug_info = new_debug_info._replace( - arg_names=new_arg_names, - result_paths=new_result_paths, - ) + arg_names=new_arg_names, result_paths=new_result_paths) constvars = jaxpr.jaxpr.constvars new_effects = pe._renumber_effects( (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars), diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 3c499429a663..444b60f15fa5 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -896,10 +896,8 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: config.enable_checks.value and core.check_jaxpr(jaxpr) dbg = jaxpr.debug_info._replace( arg_names=("",) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) - lifted_jaxpr = Jaxpr(constvars=(), - invars=jaxpr.constvars + jaxpr.invars, - outvars=jaxpr.outvars, 
eqns=jaxpr.eqns, - effects=jaxpr.effects, debug_info=dbg) + lifted_jaxpr = jaxpr.replace( + constvars=(), invars=jaxpr.constvars + jaxpr.invars, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(lifted_jaxpr) return lifted_jaxpr @@ -1014,10 +1012,9 @@ def fun(*known_vals_in): known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] - known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk] + known_avals = [a for a, uk in zip(jaxpr.in_avals_aug, in_unknowns) if not uk] jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic( - lu.wrap_init(fun, debug_info=f.debug_info), - known_avals) + lu.wrap_init(fun, debug_info=f.debug_info), known_avals) (out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking # check jaxpr_known and jaxpr_unknown in isolation @@ -1579,6 +1576,20 @@ def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr: return ClosedJaxpr(jaxpr, ()) +def move_invars_right(jaxpr: ClosedJaxpr, to_move: Sequence[bool]): + return _move_invars_right(jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_invars_right(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]): + invars, rest = split_list(jaxpr.jaxpr.invars, [len(to_move)]) + left_invars, right_invars = partition_list(to_move, invars) + new_invars = [*left_invars, *right_invars, *rest] + new_effs = _renumber_effects( + (*jaxpr.jaxpr.constvars, *new_invars), + (*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs)) + def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] ) -> ClosedJaxpr: """Reorder `invars` by moving those indicated in `to_move` to the front.""" @@ -1640,6 +1651,10 @@ def full_lower(self): if val is None: return self return core.full_lower(val) + def type_state(self): + var = self._trace.frame.tracer_to_var.get(id(self)) + return self._trace.frame.current_typechange_env[var] + def _contents(self): return () @@ -1735,7 +1750,8 @@ class JaxprStackFrame: attrs_vars: list[Var] debug_info: core.DebugInfo is_high: bool - final_typechange_env: dict + initial_typechange_env: dict + current_typechange_env: dict def __init__(self, debug_info: core.DebugInfo): self.gensym = core.gensym() @@ -1751,7 +1767,8 @@ def __init__(self, debug_info: core.DebugInfo): self.attrs_vars = [] self.debug_info = debug_info self.is_high = False - self.final_typechange_env = {} + self.initial_typechange_env = {} + self.current_typechange_env = {} def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) @@ -1777,8 +1794,11 @@ def to_jaxpr( outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) + final_typechange_env = {v: s for v, s in self.current_typechange_env.items() + if v in self.initial_typechange_env} jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, - debug_info, self.is_high, self.final_typechange_env) + debug_info, self.is_high, self.initial_typechange_env, + final_typechange_env) jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) @@ -1895,8 +1915,6 @@ def new_arg(self, aval, source_info: SourceInfo): self.frame.tracers.append(tracer) 
self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.invars.append(var) - if aval.mutable: - self.frame.final_typechange_env[var] = aval return tracer def new_const(self, c, source_info: SourceInfo): @@ -1921,6 +1939,8 @@ def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.constid_to_tracer[id(c)] = tracer self.frame.constvar_to_val[var] = c + if aval.mutable: + self.frame.initial_typechange_env[var] = c.type_state() return tracer def get_const(self, tracer) -> Any: @@ -2235,18 +2255,24 @@ def trace_to_jaxpr_dynamic( list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs trace = DynamicJaxprTrace(fun.debug_info, lower=lower) + in_avals_ = [a.aval if isinstance(a, core.TypeChange) else a for a in in_avals] with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): source_info = source_info_util.current() in_tracers = _input_type_to_tracers( - partial(trace.new_arg, source_info=source_info), in_avals) + partial(trace.new_arg, source_info=source_info), in_avals_) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + trace.frame.initial_typechange_env = initial_typechange_env = { + v: a.initial_type_state for v, a in zip(trace.frame.invars, in_avals) + if isinstance(a, core.TypeChange)} + trace.frame.current_typechange_env = dict(initial_typechange_env) + try: with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) _check_returned_jaxtypes(fun.debug_info, ans) out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) _check_no_returned_refs(fun.debug_info, out_tracers) - jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) + jaxpr, consts, attrs_tracked = trace.frame.to_jaxpr(trace, out_tracers, fun.debug_info) del fun, in_tracers, out_tracers, ans finally: trace.frame.reset_states(trace) @@ -2718,21 +2744,38 @@ def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, lis @weakref_lru_cache def lower_jaxpr(hi_jaxpr): - in_avals = [lo_ty for t in hi_jaxpr.in_avals for lo_ty in t.lo_ty()] + initial_env = hi_jaxpr.jaxpr.initial_typechange_env + lo_avals = [lo_ty for v in hi_jaxpr.jaxpr.invars + for lo_ty in (v.aval.lo_ty_(initial_env[v]) if v.aval.mutable + else v.aval.lo_ty())] f = lu.wrap_init(partial(lower_traceable, hi_jaxpr), debug_info=hi_jaxpr.jaxpr.debug_info) - lo_jaxpr, _, consts, () = trace_to_jaxpr_dynamic(f, in_avals, lower=True) - return core.ClosedJaxpr(lo_jaxpr, consts) + lo_jaxpr, _, lo_consts, () = trace_to_jaxpr_dynamic(f, lo_avals, lower=True) + return core.ClosedJaxpr(lo_jaxpr, lo_consts) def lower_traceable(jaxpr, *lo_args): + env = jaxpr.jaxpr.initial_typechange_env lo_args_ = iter(lo_args) - hi_args = [t.raise_val(*it.islice(lo_args_, len(t.lo_ty()))) - for t in jaxpr.in_avals] + hi_args = [v.aval.raise_val(*it.islice(lo_args_, len(v.aval.lo_ty()))) + if not v.aval.mutable else + v.aval.new_from_loval(env[v], *it.islice(lo_args_, len(v.aval.lo_ty_(env[v])))) + for v in jaxpr.jaxpr.invars] assert (problem := next(lo_args_, None)) is None hi_outs = core.jaxpr_as_fun(jaxpr)(*hi_args) in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} mut_outs = [lo_val for v, ty in jaxpr.jaxpr.final_typechange_env.items() - for lo_val in ty.get(hi_args[in_idx[v]])] - lo_outs = [lo_val for t, hi_val in zip(jaxpr.out_avals, hi_outs) - for lo_val in 
t.lower_val(hi_val)] + for lo_val in v.aval.read_loval(ty, hi_args[in_idx[v]])] + lo_outs = [lo_val for v, hi_val in zip(jaxpr.jaxpr.outvars, hi_outs) + for lo_val in v.aval.lower_val(hi_val)] return mut_outs + lo_outs + +def convert_const_himutables(jaxpr): + move = [core.typeof(c).mutable for c in jaxpr.consts] + constvals, in_mutables = partition_list(move, jaxpr.consts) + constvars, boxvars = partition_list(move, jaxpr.jaxpr.constvars) + invars = *boxvars, *jaxpr.jaxpr.invars + effects = make_jaxpr_effects(constvars, invars, jaxpr.jaxpr.outvars, + jaxpr.jaxpr.eqns) + new_jaxpr = jaxpr.jaxpr.replace(constvars=constvars, invars=invars, + effects=effects) + return jaxpr.replace(jaxpr=new_jaxpr, consts=constvals), in_mutables diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index d0a22cd784b4..276e2b18cd40 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1775,6 +1775,7 @@ def _move_mutable_consts( constvars, mutvars = partition_list(hoist, jaxpr.constvars) invars = (*jaxpr.invars, *mutvars) effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns) + # TODO(mattjj): debug_info must be updated... jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns, effects, closed_jaxpr.jaxpr.debug_info) return core.ClosedJaxpr(jaxpr, consts), in_mut @@ -2181,8 +2182,7 @@ def lower_sharding_computation( The caller of this code can pass in a singleton UNSPECIFIED because the number of out_avals might not be known at that time and lower_sharding_computation calculates the number of out_avals so it can apply - the singleton UNSPECIFIED to all out_avals. - """ + the singleton UNSPECIFIED to all out_avals.""" auto_spmd_lowering = check_if_any_auto( it.chain.from_iterable([in_shardings, out_shardings])) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index b9ce8ae09380..4df4c517090b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -17,7 +17,7 @@ from collections.abc import Callable, Sequence from functools import partial import inspect -import itertools +import itertools as it import operator from typing import Any, TypeVar import weakref @@ -438,11 +438,11 @@ def _merge_attrs_out(attrs_tracked, out_state, out_append): out_attrs = [] for _, out_tree, (_, _, k) in attrs_tracked: if k in (pe.ReadWrite, pe.BoxAttr): - out_attrs.extend(itertools.islice(out_state_, out_tree.num_leaves)) + out_attrs.extend(it.islice(out_state_, out_tree.num_leaves)) elif k is pe.Append: out_attrs.append(next(out_append_)) elif k is pe.ListAttr: - out_attrs.extend(itertools.islice(out_append_, out_tree.num_leaves)) + out_attrs.extend(it.islice(out_append_, out_tree.num_leaves)) else: assert False assert next(out_state_, None) is next(out_append_, None) is None @@ -931,7 +931,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, ys_avals = [core.unmapped_aval(length, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in itertools.chain(carry_avals, ys_avals)] + for a in it.chain(carry_avals, ys_avals)] del carry_avals, y_avals # Create equation. 
                      linear_unknown = tuple([False] * len(intensive_res) +
@@ -1500,6 +1500,17 @@ def arrange_jaxpr_args_for_wrapped(args):
   assert len(refs_out_matching_in_avals) == len(in_avals)
   return refs_out_matching_in_avals, [*carry_out, *ys]
 
+def _scan_staging(trace, *args, **params):
+  outs = trace.default_process_primitive(scan_p, args, params)
+  jaxpr = params['jaxpr']
+  trace.frame.is_high = jaxpr.jaxpr.is_high
+  invars = [trace.frame.tracer_to_var[id(t)] for t in args]
+  var_map = dict(zip(jaxpr.jaxpr.invars, invars))
+  final_env = {var_map[v]: ty for v, ty in
+               jaxpr.jaxpr.final_typechange_env.items()}
+  trace.frame.current_typechange_env.update(final_env)
+  return outs
+
 scan_p = core.Primitive("scan")
 scan_p.multiple_results = True
 scan_p.skip_canonicalization = True
@@ -1518,6 +1529,65 @@ def arrange_jaxpr_args_for_wrapped(args):
 pe.padding_rules[scan_p] = _scan_padding_rule
 pe.dce_rules[scan_p] = _scan_dce_rule
 state_discharge.register_partial_discharge_rule(scan_p)(_scan_state_partial_discharge_rule)
+pe.custom_staging_rules[scan_p] = _scan_staging
+
+def _is_high(jaxpr, **_) -> bool:
+  return jaxpr.jaxpr.is_high
+scan_p.is_high = _is_high  # type: ignore
+
+def _to_lojax(*hi_args, jaxpr, num_carry, num_consts, linear, **params):
+  ienv, fenv = jaxpr.jaxpr.initial_typechange_env, jaxpr.jaxpr.final_typechange_env
+
+  # move box binders and hi_args from consts slots to carry slots
+  to_move = [t.mutable for t in jaxpr.in_avals[:num_consts]]
+  jaxpr = pe.move_invars_right(jaxpr, to_move)
+  hi_args = _move_right(hi_args, to_move)
+  num_consts -= sum(to_move)
+  num_carry += sum(to_move)
+
+  # expand num_consts, num_carry, linear according to lo types
+  const_invars, carry_invars, _ = split_list(jaxpr.jaxpr.invars, [num_consts, num_carry])
+  num_consts = sum(len(v.aval.lo_ty() if not v.aval.mutable
+                       else v.aval.lo_ty_(ienv[v])) for v in const_invars)
+  num_carry = sum(len(v.aval.lo_ty() if not v.aval.mutable
+                      else v.aval.lo_ty_(ienv[v])) for v in carry_invars)
+  linear = [l for v, l_ in zip(jaxpr.jaxpr.invars, linear)
+            for l in (l_,) * len(v.aval.lo_ty() if not v.aval.mutable
+                                 else v.aval.lo_ty_(ienv[v]))]
+  lo_muts_out = sum(len(m.leaf_avals) for m in fenv.values())  # TODO hardcoded
+
+  # collect lo input values
+  lo_args = [lo_val for v, x in zip(jaxpr.jaxpr.invars, hi_args)
+             for lo_val in (v.aval.read_loval(ienv[v], x) if v.aval.mutable
+                            else v.aval.lower_val(x))]
+
+  # lower the jaxpr and bind it using lo input values
+  lo_jaxpr = pe.lower_jaxpr(jaxpr)
+  all_outs = scan_p.bind(*lo_args, jaxpr=lo_jaxpr, num_consts=num_consts,
+                         num_carry=num_carry, linear=tuple(linear), **params)
+  out_mut, lo_outs = split_list(all_outs, [lo_muts_out])
+
+  # collect and apply mutations
+  out_mut_ = iter(out_mut)
+  in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)}
+  for var, ty in jaxpr.jaxpr.final_typechange_env.items():
+    lo_vals = it.islice(out_mut_, len(var.aval.lo_ty_(ty)))
+    var.aval.update_from_loval(ty, hi_args[in_idx[var]], *lo_vals)
+  assert next(out_mut_, None) is None
+
+  # collect output values into hi types
+  lo_outs_ = iter(lo_outs)
+  hi_outs = [t.raise_val(*it.islice(lo_outs_, len(t.lo_ty())))
+             for t in jaxpr.out_avals]
+  assert next(lo_outs_, None) is None
+
+  return hi_outs
+scan_p.to_lojax = _to_lojax
+
+def _move_right(lst, to_move):
+  lst, rest = split_list(lst, [len(to_move)])
+  left, right = partition_list(to_move, lst)
+  return [*left, *right, *rest]
 
 def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry,
                              jaxpr, linear, unroll, _split_transpose):
diff --git
a/jax/_src/pjit.py b/jax/_src/pjit.py index d5286be8e0c9..c767647195b8 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -590,6 +590,8 @@ def _infer_params_impl( in_type = in_avals = tuple(core.shaped_abstractify(x) for x in explicit_args) # type: ignore else: in_type = in_avals # type: ignore + in_type = tuple(core.TypeChange(a, x.type_state(), None) if a.mutable # type: ignore + else a for a, x in zip(in_type, explicit_args)) assert in_avals is not None in_shardings_flat, in_layouts_flat = _process_in_axis_resources( @@ -705,7 +707,7 @@ def _infer_params_internal( if entry.pjit_params is None: p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) - if p.attrs_tracked or p.box_data: # if attrs/boxes, don't populate cache + if p.attrs_tracked or p.box_data or p.params['jaxpr'].jaxpr.is_high: return p, p.consts + args_flat entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs @@ -1407,16 +1409,14 @@ def _create_pjit_jaxpr( lu.annotate(fun, cast(core.InputType, in_type))) attrs_tracked = [] else: - jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - fun, in_type) - # assert attr_data is sentinel or attr_data matches attrs_tracked + jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(fun, in_type) if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import check_key_reuse_jaxpr check_key_reuse_jaxpr(jaxpr) - if any(isinstance(c, core.Tracer) for c in consts): + if any(isinstance(c, core.Tracer) or core.typeof(c).mutable for c in consts): closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) final_consts = consts else: @@ -1561,21 +1561,41 @@ def _is_high(jaxpr, **_) -> bool: return jaxpr.jaxpr.is_high pjit_p.is_high = _is_high # type: ignore -def _to_lojax( *hi_args, jaxpr, **params): - params, num_mutants = _lojax_expand_params(jaxpr, **params) +def _to_lojax(*hi_args, jaxpr, **params): + ienv, fenv = jaxpr.jaxpr.initial_typechange_env, jaxpr.jaxpr.final_typechange_env - lo_args = [lo_val for t, hi_val in zip(jaxpr.in_avals, hi_args) - for lo_val in t.lower_val(hi_val)] + # convert closed-over boxes to explicit args + jaxpr, closed_over_himutables = pe.convert_const_himutables(jaxpr) + hi_args = [*closed_over_himutables, *hi_args] + params = _converted_mutables_add_params(len(closed_over_himutables), **params) + + # expand pjit params that must match number of lo inputs/outputs + lo_nums_in = [len(v.aval.lo_ty() if not v.aval.mutable + else v.aval.lo_ty_(ienv[v])) + for v in jaxpr.jaxpr.invars] + lo_nums_out = [len(t.lo_ty()) for t in jaxpr.out_avals] + lo_muts_out = sum(len(m.leaf_avals) for m in fenv.values()) # TODO hardcoded + params = _lojax_expand_params(lo_nums_in, lo_nums_out, lo_muts_out, **params) + + # collect lo input values + lo_args = [lo_val for v, x in zip(jaxpr.jaxpr.invars, hi_args) + for lo_val in (v.aval.read_loval(ienv[v], x) if v.aval.mutable + else v.aval.lower_val(x))] + + # lower the jaxpr and bind it using lo input values lo_jaxpr = pe.lower_jaxpr(jaxpr) all_outs = pjit_p.bind(*lo_args, jaxpr=lo_jaxpr, **params) - out_mut, lo_outs = split_list(all_outs, [num_mutants]) + out_mut, lo_outs = split_list(all_outs, [lo_muts_out]) + # collect and apply mutations out_mut_ = iter(out_mut) in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} for var, ty in jaxpr.jaxpr.final_typechange_env.items(): - ty.set(hi_args[in_idx[var]], *it.islice(out_mut_, len(ty.lo_ty()))) + lo_vals = 
it.islice(out_mut_, len(var.aval.lo_ty_(ty))) + var.aval.update_from_loval(ty, hi_args[in_idx[var]], *lo_vals) assert next(out_mut_, None) is None + # collect output values into hi types lo_outs_ = iter(lo_outs) hi_outs = [t.raise_val(*it.islice(lo_outs_, len(t.lo_ty()))) for t in jaxpr.out_avals] @@ -1584,29 +1604,35 @@ def _to_lojax( *hi_args, jaxpr, **params): return hi_outs pjit_p.to_lojax = _to_lojax +def _converted_mutables_add_params( + n, *, donated_invars, in_shardings, in_layouts, **params): + donated_invars = (False,) * n + donated_invars + in_shardings = (UNSPECIFIED,) * n + in_shardings + in_layouts = (None,) * n + in_layouts + return dict(params, donated_invars=donated_invars, in_shardings=in_shardings, + in_layouts=in_layouts) + def _lojax_expand_params( - hi_jaxpr, *, donated_invars, in_shardings, in_layouts, out_shardings, - out_layouts, **params): + nums_in, nums_out, muts_out, *, donated_invars, in_shardings, in_layouts, + out_shardings, out_layouts, **params): # some pjit params match the length of hi_jaxpr.invars/outvars, so when # lowering we must expand them to match their number of lojax types - def expand(hi_tys, xs): - return tuple(y for hi, x in zip(hi_tys, xs) for y in (x,) * len(hi.lo_ty())) - donated_invars = expand(hi_jaxpr.in_avals , donated_invars) - in_shardings = expand(hi_jaxpr.in_avals , in_shardings ) - in_layouts = expand(hi_jaxpr.in_avals , in_layouts ) - out_shardings = expand(hi_jaxpr.out_avals, out_shardings ) - out_layouts = expand(hi_jaxpr.out_avals, out_layouts ) + def expand(ns, xs): + return tuple(y for n, x in zip(ns, xs) for y in (x,) * n) + donated_invars = expand(nums_in , donated_invars) + in_shardings = expand(nums_in , in_shardings ) + in_layouts = expand(nums_in , in_layouts ) + out_shardings = expand(nums_out, out_shardings ) + out_layouts = expand(nums_out, out_layouts ) # also, the lo_jaxpr has pure outputs corresponding to mutable hi_jaxpr types - num_mutants = sum(len(hi_ty.lo_ty()) for hi_ty in - hi_jaxpr.jaxpr.final_typechange_env.values()) - out_shardings = (UNSPECIFIED,) * num_mutants + out_shardings - out_layouts = (None,) * num_mutants + out_layouts + out_shardings = (UNSPECIFIED,) * muts_out + out_shardings + out_layouts = (None,) * muts_out + out_layouts new_params = dict(params, donated_invars=donated_invars, in_shardings=in_shardings, in_layouts=in_layouts, out_shardings=out_shardings, out_layouts=out_layouts) - return new_params, num_mutants + return new_params def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): @@ -1948,6 +1974,7 @@ def pjit_staging_rule(trace, *args, **params): jaxpr = params['jaxpr'] source_info = source_info_util.current() + consts = [] if config.dynamic_shapes.value: jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( jaxpr, params['out_shardings'], params['out_layouts']) @@ -1981,6 +2008,14 @@ def pjit_staging_rule(trace, *args, **params): pjit_p, (*args, *consts), new_params) else: out_tracers = trace.default_process_primitive(pjit_p, args, params) + + trace.frame.is_high = jaxpr.jaxpr.is_high + invars = [trace.frame.tracer_to_var[id(t)] for t in it.chain(args, consts)] + var_map = dict(zip(jaxpr.jaxpr.invars, invars)) + final_env = {var_map[v]: ty for v, ty in + jaxpr.jaxpr.final_typechange_env.items()} + trace.frame.current_typechange_env.update(final_env) + return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 60a3753a7ba5..90083626fb8e 100644 --- a/tests/attrs_test.py +++ 
b/tests/attrs_test.py @@ -1056,7 +1056,7 @@ def f(x): self.assertAllClose(box.get(), 2.0) @parameterized.parameters([False, True]) - def test_grad_closrue_stop_gradient(self, jit): + def test_grad_closure_stop_gradient(self, jit): box = Box(0.0) def f(x): @@ -1124,7 +1124,6 @@ def f(lst, x): lst.append(2.0) lst.append({'c': x + 3.0}) - tracing_ok = True lst1 = List() f(lst1, 0) diff --git a/tests/hijax_test.py b/tests/hijax_test.py index 21034d164d28..8b7d045c6c5a 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -17,18 +17,19 @@ from dataclasses import dataclass from functools import partial import itertools as it +from typing import Any import unittest -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax import jax.numpy as jnp from jax._src import config from jax._src import core -from jax._src import dtypes from jax._src.interpreters import ad from jax._src.interpreters import partial_eval as pe +from jax._src import ad_util from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map @@ -37,6 +38,8 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +PyTreeDef = Any + # TODO(mattjj,dougalm): move HiPrimitive, Box, etc out of tests and into library class HiPrimitive(core.Primitive): @@ -65,124 +68,6 @@ def jvp(self, primals, tangents, **params): def transpose(self, *args, **params): assert False # TODO - -class BoxTy(core.AbstractValue): - mutable = True - - def __init__(self, leaf_avals, treedef): - self._leaf_avals = leaf_avals # hijax avals - self._treedef = treedef - - # aval interface: hashability and str_short - def __hash__(self): - return hash((self._leaf_avals, self._treedef)) - - def __eq__(self, other): - return (isinstance(other, BoxTy) and self._leaf_avals == other._leaf_avals - and self._treedef == other._treedef) - - def str_short(self, short_dtypes=False): - return 'BoxTy' - - # hijax interface: lower val, raise val, and low type - def lo_ty(self): - return [lo_aval for hi_aval in self._leaf_avals for lo_aval in hi_aval.lo_ty()] - - def lower_val(self, box): - leaf_vals, treedef = jax.tree.flatten(box._val) - assert treedef == self._treedef - return [lo_val for hi_aval, hi_val in zip(self._leaf_avals, leaf_vals) - for lo_val in hi_aval.lower_val(hi_val)] - - def raise_val(self, *lo_vals): - lo_vals_ = iter(lo_vals) - hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) - for hi_ty in self._leaf_avals] - assert next(lo_vals_, None) is None - return Box(jax.tree.unflatten(self._treedef, hi_vals)) # will be mutated - - # mutable interface: get/set - def get(self, box): - leaf_vals, treedef = jax.tree.flatten(box._val) - assert treedef == self._treedef - return [lo_val for hi_ty, hi_val in zip(self._leaf_avals, leaf_vals) - for lo_val in hi_ty.lower_val(hi_val)] - - def set(self, box, *lo_vals): - lo_vals_ = iter(lo_vals) - hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) - for hi_ty in self._leaf_avals] - assert next(lo_vals_, None) is None - box._val = jax.tree.unflatten(self._treedef, hi_vals) - - # TODO placeholder thing - def to_tangent_aval(self): - return core.ShapedArray((), dtypes.float0) # TODO revise placeholder - -class Box: # noqa: F811 - def __init__(self, val): - self._val = val - - @property - def ty(self): - leaves, treedef = jax.tree.flatten(self._val) - leaf_avals = tuple(map(core.typeof, leaves)) - return BoxTy(leaf_avals, treedef) -core.pytype_aval_mappings[Box] = lambda b: b.ty - - -class BoxSet(HiPrimitive): - 
multiple_results = True - - def is_high(self, *, treedef) -> bool: return True - - def staging(self, trace, box, *leaves, treedef): - super().staging(trace, box, *leaves, treedef=treedef) - avals = tuple(t.aval for t in leaves) - trace.frame.final_typechange_env[trace.getvar(box)] = BoxTy(avals, treedef) - - def abstract_eval(self, box_ty, *leaf_avals, treedef): - return [], set() # TODO better typechecking... - - def to_lojax(_, box, *leaves, treedef): - box._val = jax.tree.unflatten(treedef, leaves) - return [] - - def jvp(_, primals, tangents, *, treedef): - assert False # TODO - - def transpose(_, *args, treedef): - assert False # TODO -box_set_p = BoxSet('box_set') - -def box_set(box, val): - leaves, treedef = jax.tree.flatten(val) - box_set_p.bind(box, *leaves, treedef=treedef) - - -class BoxGet(HiPrimitive): - multiple_results = True - - def is_high(self) -> bool: return True - - def abstract_eval(self, box_ty): - return box_ty._leaf_avals, set() - - def to_lojax(_, box): - return jax.tree.leaves(box._val) - - def jvp(_, primals, tangents): - assert False # TODO - - def transpose(_, *args): - assert False # TODO -box_get_p = BoxGet('box_get') - -def box_get(box): - leaf_vals = box_get_p.bind(box) - return jax.tree.unflatten(core.typeof(box)._treedef, leaf_vals) - - class HijaxTest(jtu.JaxTestCase): def test_custom_types_and_primitive(self): @@ -194,8 +79,6 @@ class MyArray: @dataclass(frozen=True) class MyTy(core.AbstractValue): - mutable = False - def to_tangent_aval(self): return MyTy() def str_short(self, short_dtypes=False): @@ -324,6 +207,392 @@ def f(x): self.assertIsInstance(a_grad, MyArray) self.assertAllClose(a_grad.arr, 2.0, check_dtypes=False) + +def new_box(): + (), treedef = jax.tree.flatten(None) + return new_box_p.bind(treedef=treedef) + +def box_get(box): + tys = box.type_state() + leaf_vals = box_get_p.bind(box, avals=tys.leaf_avals) + return jax.tree.unflatten(tys.treedef, leaf_vals) + +def box_set(box, val): + leaves, treedef = jax.tree.flatten(val) + box_set_p.bind(box, *leaves, treedef=treedef) + +@dataclass(frozen=True) +class BoxTypeState: + leaf_avals: tuple[core.AbstractValue, ...] 
+ treedef: PyTreeDef + + def to_tangent_aval(self): + return BoxTypeState(tuple(a.to_tangent_aval() for a in self.leaf_avals), + self.treedef) + + def normalize(self): + return BoxTypeState(tuple(a.normalize() for a in self.leaf_avals), + self.treedef) + +class BoxTy(core.AbstractValue): + mutable = True + + # forwarded to value + get = core.aval_method(box_get) + set = core.aval_method(box_set) + + # aval interface: hashability and str_short + def __hash__(self): return hash(BoxTy) + def __eq__(self, other): return isinstance(other, BoxTy) + + def str_short(self, short_dtypes=False): + return 'BoxTy' + + # mutable interface + def lo_ty_(self, box_state): + return [lo_ty for t in box_state.leaf_avals for lo_ty in t.lo_ty()] + + def new_from_loval(self, box_state: BoxTypeState, *lo_vals): + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) + for hi_ty in box_state.leaf_avals] + assert next(lo_vals_, None) is None + return Box(jax.tree.unflatten(box_state.treedef, hi_vals)) # will be mutated + + def read_loval(self, box_state: BoxTypeState, box): + leaf_vals, treedef = jax.tree.flatten(box_get(box)) + assert treedef == box_state.treedef + return [lo_val for hi_ty, hi_val in zip(box_state.leaf_avals, leaf_vals) + for lo_val in hi_ty.lower_val(hi_val)] + + def update_from_loval(self, box_state: BoxTypeState, box, *lo_vals): + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) + for hi_ty in box_state.leaf_avals] + assert next(lo_vals_, None) is None + box_set(box, jax.tree.unflatten(box_state.treedef, hi_vals)) + + def to_tangent_aval(self): + return BoxTy() + +class Box: # noqa: F811 + def __init__(self, val): + self._val = val + + def get(self): + return box_get(self) + + def set(self, val): + box_set(self, val) + + @property + def ty(self): + return BoxTy() + + def type_state(self): + leaves, treedef = jax.tree.flatten(self._val) + leaf_avals = tuple(map(core.typeof, leaves)) + return BoxTypeState(leaf_avals, treedef) +core.pytype_aval_mappings[Box] = lambda b: b.ty + + +class NewBox(HiPrimitive): + def is_high(self, *, treedef) -> bool: return True + + def staging(self, trace, *, treedef): + tracer = super().staging(trace, treedef=treedef) + var = trace.frame.tracer_to_var[id(tracer)] + leaves, treedef = jax.tree.flatten(None) + trace.frame.current_typechange_env[var] = BoxTypeState(leaves, treedef) + return tracer + + def abstract_eval(self, *, treedef): + return BoxTy(), set() + + def to_lojax(_, *, treedef): + return Box(None) + + def jvp(_, primals, tangents, *, treedef): + assert False # TODO + + def transpose(_, *args, treedef): + assert False # TODO +new_box_p = NewBox('new_box') + + +class BoxSet(HiPrimitive): + multiple_results = True + + def is_high(self, *, treedef) -> bool: return True + + def staging(self, trace, box_tracer, *leaves, treedef): + super().staging(trace, box_tracer, *leaves, treedef=treedef) + var = trace.getvar(box_tracer) + avals = tuple(t.aval for t in leaves) + trace.frame.current_typechange_env[var] = BoxTypeState(avals, treedef) + return [] + + def abstract_eval(self, box_ty, *leaf_avals, treedef): + return [], set() # TODO better typechecking... 
+
+  def to_lojax(_, box, *leaves, treedef):
+    box._val = jax.tree.unflatten(treedef, leaves)
+    return []
+
+  def jvp(_, primals, tangents, *, treedef):
+    box, *vals = primals
+    box_dot, *val_dots = tangents
+    if type(box_dot) is ad_util.Zero:
+      raise Exception("box_set jvp requires the Box to have a Box tangent, "
+                      "but got a symbolic Zero for the box argument")
+    box_set_p.bind(box, *vals, treedef=treedef)
+    box_set_p.bind(box_dot, *val_dots, treedef=treedef)
+    return [], []
+
+  def transpose(_, *args, treedef):
+    assert False  # TODO
+box_set_p = BoxSet('box_set')
+
+
+class BoxGet(HiPrimitive):
+  multiple_results = True
+
+  def abstract_eval(self, box_ty, *, avals):
+    return avals, set()
+
+  def to_lojax(_, box, *, avals):
+    return jax.tree.leaves(box._val)
+
+  def jvp(_, primals, tangents, *, avals):
+    (box,), (box_dot,) = primals, tangents
+    return (box_get_p.bind(box, avals=avals),
+            box_get_p.bind(box_dot, avals=[a.to_tangent_aval() for a in avals]))
+
+  def transpose(_, *args):
+    assert False  # TODO
+box_get_p = BoxGet('box_get')
+
+
+
+class BoxTest(jtu.JaxTestCase):
+
+  def test_jit_arg(self):
+    @jax.jit
+    def f(box, x):
+      assert tracing_ok
+      box.set(box.get() + x)
+
+    tracing_ok = True
+    box1 = Box(1.0)
+    f(box1, 1.)
+    self.assertAllClose(box1.get(), 2.0)
+
+    tracing_ok = False
+    box2 = Box(2.0)
+    f(box2, 2.)
+    self.assertAllClose(box2.get(), 4.0)
+
+  def test_jit_arg2(self):
+    # set without get
+
+    @jax.jit
+    def f(box, x):
+      box_set(box, x)
+
+    box = Box(0.0)
+    f(box, 1.)
+    self.assertAllClose(box_get(box), 1.0, check_dtypes=False)
+
+  def test_jit_arg_in_pytree(self):
+    @jax.jit
+    def f(dct, x):
+      assert tracing_ok
+      box = dct['box']
+      box.set(box.get() + x)
+
+    tracing_ok = True
+    box1 = Box(1.0)
+    f({'box': box1, 'a': 1.0}, 1.)
+    self.assertAllClose(box1.get(), 2.0)
+
+    tracing_ok = False
+    box2 = Box(2.0)
+    f({'box': box2, 'a': 2.0}, 2.)
+    self.assertAllClose(box2.get(), 4.0)
+
+    tracing_ok = True
+    box3 = Box(3)  # int, dtype changed
+    f({'box': box3, 'a': 2.0}, 2.)
+ self.assertAllClose(box3.get(), 5.0) + + def test_jit_closure(self): + box = Box(1.0) + + @jax.jit + def f(x): + assert tracing_ok + box.set(box.get() + x) + + tracing_ok = True + f(2.0) + self.assertAllClose(box.get(), 3.0) + tracing_ok = False + f(5.0) + self.assertAllClose(box.get(), 8.0) + + def test_jit_closure_nested(self): + box = Box(5.0) + + @jax.jit + def f(x): + box.set(box.get() + x) + + @jax.jit + def g(x): + f(x) + + g(3.0) + self.assertAllClose(box.get(), 8.0) + + def test_jit_closure_nested2(self): + @jax.jit + def h(x): + box = new_box() + box.set(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + + @parameterized.parameters([False, True]) + def test_jvp_closure_stop_gradient(self, jit): + box = Box(1.0) + + def f(x): + y = 2 * x + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + y, y_dot = jax.jvp(f, (1.0,), (1.0,)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 2.0) + self.assertAllClose(box.get(), 3.0) + + @parameterized.parameters([False, True]) + def test_jvp_arg(self, jit): + def f(box, x): + box.set(box.get() + x) + return x + + if jit: + f = jax.jit(f) + + box = Box(5.0) + box_dot = Box(1.0) + y, y_dot = jax.jvp(f, (box, 2.), (box_dot, 1.)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 1.0) + self.assertAllClose(box.get(), 7.0) + self.assertAllClose(box_dot.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing(self, jit): + box = Box(0.0) + + @jax.custom_vjp + def foo(x): + return x + def foo_fwd(x): + return foo(x), None + def foo_bwd(_, g): + box.set(g) + return g, + foo.defvjp(foo_fwd, foo_bwd) + + def f(x): + x = 2 * x + x = foo(x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(f)(1.0) + self.assertAllClose(box.get(), 2.0) + + # TODO(mattjj,dougalm): make this work... 
+  # @parameterized.parameters([False, True])
+  # def test_custom_vjp_plumbing_abstracted(self, jit):
+  #   box = Box(0.0)
+
+  #   @jax.custom_vjp
+  #   def foo(box, x):
+  #     return x
+  #   def foo_fwd(box, x):
+  #     return x, box
+  #   def foo_bwd(box, g):
+  #     box.set(g)
+  #     return None, g
+  #   foo.defvjp(foo_fwd, foo_bwd)
+
+  #   def f(box, x):
+  #     x = 2 * x
+  #     x = foo(box, x)
+  #     x = 2 * x
+  #     return x
+
+  #   if jit:
+  #     f = jax.jit(f)
+
+  #   jax.grad(partial(f, box))(1.0)
+  #   self.assertAllClose(box.get(), 2.0)
+
+  @parameterized.parameters([False, True])
+  def test_grad_closure_stop_gradient(self, jit):
+    box = Box(0.0)
+
+    def f(x):
+      y = x * 2
+      box.set(box.get() + jax.lax.stop_gradient(y))
+      return y
+
+    if jit:
+      f = jax.jit(f)
+
+    g = jax.grad(f)(1.0)
+    self.assertAllClose(g, 2.0)
+    self.assertAllClose(box.get(), 2.0)
+
+  @parameterized.parameters([False, True])
+  def test_scan_basic(self, jit):
+    box = Box(1.0)
+
+    def double_it_10():
+      def body(_, __):
+        box.set(box.get() * 2)
+        return None, None
+      _, _ = jax.lax.scan(body, None, None, length=10)
+
+    if jit:
+      double_it_10 = jax.jit(double_it_10)
+
+    double_it_10()
+    self.assertAllClose(box.get(), 1024., check_dtypes=False)
+
+  # TODO error-checking tests from attrs_test.py
+
+  ###
+
   def test_box_autodiff(self):
     if config.enable_x64.value:
       raise unittest.SkipTest("no x64")
@@ -336,7 +605,7 @@ def abstract_eval(_, box_aval, x_aval):
     return x_aval, set()
 
   def to_lojax(_, box, x):
-    assert False  # TODO
+    return x
 
   def jvp(_, primals, tangents):
     box, x = primals
@@ -351,14 +620,6 @@ def transpose(self, *args):
   def stash_tangents(box, x):
     return stash_tangents_p.bind(box, x)
 
-  @jax.jit
-  def f(box, x):
-    box_set(box, x)
-
-  box = Box(0.0)
-  f(box, 1.)
-  self.assertAllClose(box_get(box), 1.0, check_dtypes=False)
-
   @jax.jit
   def f(box, x):
     x = stash_tangents(box, x)
@@ -449,5 +710,26 @@ def f(box):
     self.assertAllClose(b_.arr, 2, check_dtypes=False)
 
 
+class ListTy(core.AbstractValue):
+  mutable = True
+
+  # forwarded to value
+  get = core.aval_method(box_get)
+  set = core.aval_method(box_set)
+
+  # aval interface: hashability and str_short
+  def __hash__(self): return hash(ListTy)
+  def __eq__(self, other): return isinstance(other, ListTy)
+
+  def str_short(self, short_dtypes=False):
+    return 'ListTy'
+
+  # TODO
+
+class ListTest(jtu.JaxTestCase):
+  ...
+
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())

From 2c838d4ae3fff34242d5a2993e6892c8214c9177 Mon Sep 17 00:00:00 2001
From: Roy Frostig
Date: Thu, 29 May 2025 16:05:10 -0700
Subject: [PATCH 1438/1769] rename layout to format, part 1

We want to rename layouts to formats, since "layout" is overloaded: it
refers both to our own device-local layout and to XLA's layout (to
which the device-local layout corresponds).

This change specifically focuses on renaming the `Layout` type to
`Format` and using the new type constructor throughout the codebase. It
should have minimal external effect, since it sets up `Layout` as a
public alias of the newly renamed `Format`.

This change does not yet change most variable and attribute names, so
it leaves around various names like `layout` that now have type
`Format`. Next up, we should rename these for clarity, among other
things.
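An illustrative sketch of the new spelling (assuming the alias is
importable from jax.experimental.layout, which this change touches;
`dll` and `sharding` stand for a DeviceLocalLayout and a Sharding):

    from jax.experimental.layout import Format, Layout

    fmt = Format(dll, sharding)  # the constructor arguments are unchanged
    assert Layout is Format      # `Layout` remains as a public alias

    sds = jax.ShapeDtypeStruct((8, 8), jnp.float32, sharding=fmt)
    sds.format  # new spelling; `sds.layout` keeps working as an alias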
Co-authored-by: Yash Katariya PiperOrigin-RevId: 764922667 --- jax/_src/api.py | 37 ++++--- jax/_src/array.py | 18 ++-- jax/_src/dispatch.py | 14 +-- jax/_src/interpreters/pxla.py | 4 +- jax/_src/layout.py | 6 +- jax/_src/pjit.py | 12 +-- jax/_src/stages.py | 51 ++++++--- .../array_serialization/serialization.py | 4 +- .../array_serialization/serialization_test.py | 4 +- .../array_serialization/tensorstore_impl.py | 12 +-- jax/experimental/layout.py | 4 +- tests/layout_test.py | 100 +++++++++--------- tests/memories_test.py | 18 ++-- tests/pjit_test.py | 6 +- 14 files changed, 159 insertions(+), 131 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 3ff103997dc7..2630fc7ae1be 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -74,7 +74,7 @@ from jax._src.mesh import get_concrete_mesh from jax._src.sharding_impls import ( PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding) -from jax._src.layout import Layout, AutoLayout +from jax._src.layout import Format, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list @@ -2501,10 +2501,10 @@ def _check_string_compatible_sharding(s): @lru_cache(maxsize=2048) def _check_sharding(aval, s): if (s is not None and - not isinstance(s, (xc.Device, Sharding, Layout, TransferToMemoryKind))): + not isinstance(s, (xc.Device, Sharding, Format, TransferToMemoryKind))): raise ValueError( "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," - " `jax.Device`, `Layout` or a pytree of these values. Received" + " `jax.Device`, `Format` or a pytree of these values. Received" f" invalid value: {s}") if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype): @@ -2530,8 +2530,8 @@ def pspec_to_sharding(val): def device_put( x, - device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, - *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, + device: None | xc.Device | Sharding | P | Format | Any | TransferToMemoryKind = None, + *, src: None | xc.Device | Sharding | P | Format | Any | TransferToMemoryKind = None, donate: bool | Any = False, may_alias: bool | None | Any = None): """Transfers ``x`` to ``device``. @@ -2827,18 +2827,18 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False): if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) - if sharding is not None and not isinstance(sharding, (Sharding, Layout, P)): + if sharding is not None and not isinstance(sharding, (Sharding, Format, P)): raise ValueError( "sharding should be an instance of `jax.sharding.Sharding`, " "`jax.sharding.PartitionSpec` or" - f" `jax.experimental.layout.Layout`. Got {sharding} of type" + f" `jax.experimental.layout.Format`. Got {sharding} of type" f" {type(sharding)}.") - if (isinstance(sharding, Layout) and + if (isinstance(sharding, Format) and isinstance(sharding.device_local_layout, AutoLayout)): raise TypeError( "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" f" layout in a `ShapeDtypeStruct`. Got {sharding}") - if isinstance(sharding, Layout): + if isinstance(sharding, Format): self.sharding = sharding.sharding elif isinstance(sharding, P): # TODO(yashkatariya): Should this be abstract mesh? 
@@ -2851,15 +2851,18 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False): self.sharding = NamedSharding(cur_mesh, sharding) else: self.sharding = sharding - self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + self._dll = (sharding.device_local_layout if isinstance(sharding, Format) + else None) self.weak_type = weak_type size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) @property - def layout(self): - return Layout(self._dll, self.sharding) + def format(self): + return Format(self._dll, self.sharding) + + layout = format def __len__(self): try: @@ -2869,7 +2872,7 @@ def __len__(self): def __repr__(self): sh = f", sharding={self.sharding}" if self.sharding is not None else "" - l = f", layout={self.layout}" if self._dll is not None else "" + l = f", format={self._dll}" if self._dll is not None else "" wt = f", weak_type={self.weak_type}" if self.weak_type else "" return (f"{type(self).__name__}(shape={self.shape}, " f"dtype={self.dtype.name}{sh}{l}{wt})") @@ -2880,11 +2883,13 @@ def __eq__(self, other): if not isinstance(other, ShapeDtypeStruct): return False else: - return ((self.shape, self.dtype, self.sharding, self.layout, self.weak_type) == - (other.shape, other.dtype, other.sharding, other.layout, other.weak_type)) + return ((self.shape, self.dtype, self.sharding, self._dll, self.weak_type) == + (other.shape, other.dtype, other.sharding, other._dll, other.weak_type)) def __hash__(self): - return hash((self.shape, self.dtype, self.sharding, self.layout, + # TODO(frostig): avoid the conversion from dict by addressing + # https://github.com/jax-ml/jax/issues/8182 + return hash((self.shape, self.dtype, self.sharding, self._dll, self.weak_type)) def __setattr__(self, name, value): diff --git a/jax/_src/array.py b/jax/_src/array.py index 422fa5086e62..29c7a17b07f1 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -36,7 +36,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla -from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout +from jax._src.layout import AutoLayout, DeviceLocalLayout, Format from jax._src.lib import xla_client as xc from jax._src.lib import _jax from jax._src.sharding import Sharding @@ -550,14 +550,14 @@ def addressable_shards(self) -> Sequence[Shard]: def layout(self): # TODO(yashkatariya): Remove the deleted check from here. if self.is_deleted(): - return Layout(None, self.sharding) + return Format(None, self.sharding) try: - return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), + return Format(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), self.sharding) except _jax.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - return Layout(None, self.sharding) + return Format(None, self.sharding) else: raise @@ -711,7 +711,7 @@ def _get_and_check_dtype(arrays: Sequence[basearray.Array | np.ndarray], # TODO(yashkatariya): Remove None from callback input type. 
def make_array_from_callback( - shape: Shape, sharding: Sharding | Layout, + shape: Shape, sharding: Sharding | Format, data_callback: Callable[[Index | None], ArrayLike], dtype: DTypeLike | None = None) -> ArrayImpl: # pyformat: disable @@ -756,12 +756,12 @@ def make_array_from_callback( (4, 2) """ # pyformat: enable - dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + dll = sharding.device_local_layout if isinstance(sharding, Format) else None if isinstance(dll, AutoLayout): raise TypeError( "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" f" layout when calling `jax.make_array_from_callback`. Got {sharding}") - sharding = sharding.sharding if isinstance(sharding, Layout) else sharding + sharding = sharding.sharding if isinstance(sharding, Format) else sharding if not isinstance(sharding, Sharding): raise TypeError( f"sharding should be an instance of `jax.sharding`. Got {sharding} of" @@ -823,7 +823,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: ) if dll is not None: - devices = [Layout(dll, SingleDeviceSharding(d)) for d in devices] + devices = [Format(dll, SingleDeviceSharding(d)) for d in devices] # pxla.batched_device_put doesn't support Layout... Take the slow route arrays = api.device_put(per_device_values, devices) return ArrayImpl(aval, sharding, arrays, committed=True) @@ -1218,7 +1218,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): batch_cs.append(cs) # Resharding starts here: elif not same_layout: - results.append(api.device_put(x, Layout(layout, sharding))) + results.append(api.device_put(x, Format(layout, sharding))) elif dispatch.is_single_device_sharding(x.sharding): results.append(shard_device_array(x, devices, indices, sharding)) else: diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index d1ea7439cb0c..a7c8d4ea7380 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -43,7 +43,7 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.api_util import InternalFloatingPointError -from jax._src.layout import DeviceLocalLayout, Layout +from jax._src.layout import DeviceLocalLayout, Format from jax._src.lib import xla_client as xc from jax._src.mesh import AbstractMesh, Mesh from jax._src.monitoring import record_scalar, record_event_duration_secs, record_event_time_span @@ -479,8 +479,8 @@ def _device_put_sharding_impl(x, aval, device, copy): def _device_put_impl( - x, *, device: Device | Sharding | Layout | None, - src: Device | Sharding | Layout | None, copy: CopySemantics): + x, *, device: Device | Sharding | Format | None, + src: Device | Sharding | Format | None, copy: CopySemantics): if (isinstance(device, TransferToMemoryKind) or isinstance(src, TransferToMemoryKind)): raise ValueError( @@ -494,7 +494,7 @@ def _device_put_impl( raise TypeError( f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err - if isinstance(device, Layout): + if isinstance(device, Format): l = device dll = l.device_local_layout x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None @@ -519,8 +519,8 @@ def _device_put_impl( def _batched_device_put_impl( *xs, - devices: Sequence[Device | Sharding | Layout | None], - srcs: Sequence[Device | Sharding | Layout | None], + devices: Sequence[Device | Sharding | Format | None], + srcs: Sequence[Device | Sharding | Format | None], copy_semantics: Sequence[CopySemantics]): ys = [] dsa_indices, dsa_xs, dsa_shardings, dsa_copy_semantics = [], [], [], [] @@ -536,7 +536,7 @@ def 
_batched_device_put_impl( if dsa_xs: # Batch shard_arg calls. Helps improve efficiency for backends that support # efficient batch transfer. - # device_put handles `Layout` via a different path, so just pass `None` as + # device_put handles `Format` via a different path, so just pass `None` as # the layout here. shard_arg_results = pxla.shard_args(dsa_shardings, [None] * len(dsa_xs), dsa_copy_semantics, dsa_xs) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index d0a22cd784b4..8ebf4133a5fd 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -56,7 +56,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla -from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout +from jax._src.layout import DeviceLocalLayout, AutoLayout, Format from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -208,7 +208,7 @@ def _shard_np_array(xs, shardings, layouts, copy_semantics): x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = core.shaped_abstractify(x) if layout is not None: - results.append(api.device_put(x, Layout(layout, sharding))) + results.append(api.device_put(x, Format(layout, sharding))) else: if sharding.is_fully_replicated: shards = [x] * len(devices) diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 3675433c43d8..c50c1787b94e 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -94,7 +94,7 @@ def check_compatible_aval(self, aval_shape: Shape): ShardingOptions = Union[Sharding, None, AutoSharding] -class Layout: +class Format: __slots__ = ['device_local_layout', 'sharding'] def __init__(self, device_local_layout: LayoutOptions = None, @@ -139,7 +139,9 @@ def __hash__(self): return hash((self.device_local_layout, self.sharding)) def __eq__(self, other): - if not isinstance(other, Layout): + if not isinstance(other, Format): return False return (self.device_local_layout == other.device_local_layout and self.sharding == other.sharding) + +Layout = Format # TODO(frostig, yashkatariya): remove this alias diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d5286be8e0c9..94a754a1a597 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -69,7 +69,7 @@ SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding, flatten_spec, _internal_use_concrete_mesh) -from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout +from jax._src.layout import Format, DeviceLocalLayout, AutoLayout from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( @@ -374,13 +374,13 @@ def _split_layout_and_sharding(entries): layouts, shardings = [], [] for e in entries_flat: - if isinstance(e, Layout): + if isinstance(e, Format): layouts.append(e.device_local_layout) shardings.append(e.sharding) elif isinstance(e, (DeviceLocalLayout, AutoLayout)): raise ValueError( '`jax.jit` does not accept device-local layouts directly. Create ' - 'a `Layout` instance wrapping this device-local layout and pass ' + 'a `Format` instance wrapping this device-local layout and pass ' f'that to `jit` instead. 
Got {e}') else: layouts.append(None) @@ -1645,7 +1645,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): else: # arg_layout can be None because some backends don't implement the # required layout methods. Hence `arr.layout` can return - # `Layout(None, sharding)` + # `Format(None, sharding)` if (committed and not is_pmap_sharding and arg_layout is not None @@ -2813,7 +2813,7 @@ def _sharding_constraint_impl(x, sharding, layout, context_mesh, if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and x.sharding.is_equivalent_to(sharding, x.ndim)): return x - return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x) + return api.jit(_identity_fn, out_shardings=Format(layout, sharding))(x) sharding_constraint_p = core.Primitive("sharding_constraint") @@ -3160,7 +3160,7 @@ def _layout_constraint_impl(x, *, layout): f' jax.Arrays. Got {type(x)}') if x.layout.device_local_layout == layout: # type: ignore return x - return api.jit(_identity_fn, out_shardings=Layout(layout, x.sharding))(x) + return api.jit(_identity_fn, out_shardings=Format(layout, x.sharding))(x) layout_constraint_p.def_impl(_layout_constraint_impl) def _layout_constraint_hlo_lowering(ctx, x_node, *, layout): diff --git a/jax/_src/stages.py b/jax/_src/stages.py index d92d1ccb2aa3..17649aae3081 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -46,7 +46,7 @@ from jax._src import typing from jax._src import util from jax._src.sharding_impls import UnspecifiedValue, AUTO -from jax._src.layout import Layout +from jax._src.layout import Format, DeviceLocalLayout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib import _jax @@ -105,7 +105,7 @@ def input_layouts(self): def output_layouts(self): raise NotImplementedError( - "compiled executable carries no input layout information") + "compiled executable carries no output layout information") def as_text(self) -> str: """A human-readable text representation of this executable. 
@@ -438,39 +438,58 @@ def runtime_executable(self) -> Any | None: """ return self._executable.runtime_executable() - @property - def input_shardings(self): # PyTree[sharding.Sharding] + def _input_shardings_flat(self): shardings_flat = self._executable._in_shardings # Some input shardings got DCE'd if self.in_tree.num_leaves > len(shardings_flat): iter_shardings_flat = iter(shardings_flat) shardings_flat = [next(iter_shardings_flat) if i in self._executable._kept_var_idx else None for i in range(self.in_tree.num_leaves)] + return shardings_flat + + @property + def input_shardings(self): # -> PyTree[sharding.Sharding] + shardings_flat = self._input_shardings_flat() return tree_util.tree_unflatten(self.in_tree, shardings_flat) # pytype: disable=attribute-error @property - def output_shardings(self): # PyTree[sharding.Sharding] + def output_shardings(self): # -> PyTree[sharding.Sharding] shardings_flat = self._executable._out_shardings return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error - @property - def input_layouts(self): - dll_flat = self._executable._xla_in_layouts - layouts_flat = [Layout(l, s) - for l, s in zip(dll_flat, self._executable._in_shardings)] + def _input_layouts_flat(self): + layouts_flat = self._executable._xla_in_layouts # Some input layouts got DCE'd if self.in_tree.num_leaves > len(layouts_flat): iter_layouts_flat = iter(layouts_flat) layouts_flat = [next(iter_layouts_flat) if i in self._executable._kept_var_idx - else Layout() for i in range(self.in_tree.num_leaves)] - return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error + else None for i in range(self.in_tree.num_leaves)] + return layouts_flat + + @property + def input_formats(self): + layouts_flat = self._input_layouts_flat() + shardings_flat = self._input_shardings_flat() + formats_flat = [Format(l, s) for l, s in zip(layouts_flat, shardings_flat)] + return tree_util.tree_unflatten(self.in_tree, formats_flat) # pytype: disable=attribute-error + + @property + def output_formats(self): + layouts_flat = self._executable._xla_out_layouts + shardings_flat = self._executable._out_shardings + assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat) + formats_flat = [Format(l, s) for l, s in zip(layouts_flat, shardings_flat)] + return tree_util.tree_unflatten(self.out_tree, formats_flat) # pytype: disable=attribute-error + + # TODO(frostig, yashkatariya): remove + @property + def input_layouts(self): + return self.input_formats + # TODO(frostig, yashkatariya): remove @property def output_layouts(self): - dll_flat = self._executable._xla_out_layouts - layouts_flat = [Layout(l, s) - for l, s in zip(dll_flat, self._executable._out_shardings)] - return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error + return self.output_formats @staticmethod def call(*args, **kwargs): diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 82e9e3dc938b..44b2eb9ccd03 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -32,7 +32,7 @@ from jax._src import sharding from jax._src import typing from jax._src import util -from jax._src.layout import Layout +from jax._src.layout import Format from jax._src.lib import _jax from jax.experimental.array_serialization import tensorstore_impl as ts_impl # ruff: noqa: F401 @@ -352,7 +352,7 @@ def serialize_with_paths( 
transaction=transaction, ) - def deserialize(self, shardings: Sequence[sharding.Sharding | Layout], + def deserialize(self, shardings: Sequence[sharding.Sharding | Format], tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Sequence[array.Shape] | None = None, dtypes: Sequence[typing.DTypeLike] | None = None, diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 9a6b91d04c9a..3bee72967101 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -27,7 +27,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.layout import DeviceLocalLayout as DLL -from jax._src.layout import Layout +from jax._src.layout import Format from jax.experimental.array_serialization import serialization from jax.experimental.array_serialization import tensorstore_impl as ts_impl import jax.numpy as jnp @@ -593,7 +593,7 @@ def test_load_with_layout(self): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( + out_layout = jax.jit(lambda x: x.T, out_shardings=Format(DLL.AUTO)).lower( arr).compile().output_layouts self.assertEqual(arr.layout.device_local_layout.major_to_minor, out_layout.device_local_layout.major_to_minor[::-1]) diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py index 873cc82da95e..7578bbb831e0 100644 --- a/jax/experimental/array_serialization/tensorstore_impl.py +++ b/jax/experimental/array_serialization/tensorstore_impl.py @@ -25,7 +25,7 @@ import jax from jax import numpy as jnp from jax._src import array -from jax._src.layout import Layout +from jax._src.layout import Format from jax._src import typing import numpy as np import tensorstore as ts @@ -424,7 +424,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore, async def async_deserialize( - user_in_sharding: jax.sharding.Sharding | Layout, + user_in_sharding: jax.sharding.Sharding | Format, tensorstore_spec: ts.Spec | dict[str, Any], global_shape: Sequence[int] | None = None, dtype=None, @@ -435,13 +435,13 @@ async def async_deserialize( ): """Main performant deserialization routine for arrays using tensorstore.""" in_sharding = (user_in_sharding.sharding - if isinstance(user_in_sharding, Layout) else user_in_sharding) + if isinstance(user_in_sharding, Format) else user_in_sharding) if not isinstance(in_sharding, jax.sharding.Sharding): raise ValueError( 'sharding passed to deserialization should be specified, concrete and' f' an instance of `jax.sharding.Sharding`. Got {in_sharding}') dll = (user_in_sharding.device_local_layout - if isinstance(user_in_sharding, Layout) else None) + if isinstance(user_in_sharding, Format) else None) t = await ts.open( tensorstore_spec, open=True, @@ -476,7 +476,7 @@ async def cb(index: array.Index, device: jax.Device): if out.dtype == jnp.int4: out = jnp.asarray(out) # type: ignore result = jax.device_put( - out, Layout(dll, jax.sharding.SingleDeviceSharding(device))) + out, Format(dll, jax.sharding.SingleDeviceSharding(device))) if byte_limiter is not None: # NB: `out` actually might not be ready for garbage collection by the # time we call release_bytes . Thus peak memory usage still might grow @@ -495,7 +495,7 @@ async def cb(index: array.Index, device: jax.Device): # TODO(rdyro): Remove this function. 
-def _run_deserialization(shardings: Sequence[jax.sharding.Sharding | Layout], +def _run_deserialization(shardings: Sequence[jax.sharding.Sharding | Format], tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Sequence[array.Shape] | None = None, dtypes: Sequence[typing.DTypeLike] | None = None, diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index e98cfbc68104..1c243541d99b 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -14,8 +14,10 @@ from jax._src.layout import ( DeviceLocalLayout as DeviceLocalLayout, - Layout as Layout, + Format as Format, ) from jax._src.pjit import ( with_layout_constraint as with_layout_constraint, ) + +Layout = Format diff --git a/tests/layout_test.py b/tests/layout_test.py index c15816d7794a..cfec2253dfc8 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -23,7 +23,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax.experimental.layout import (with_layout_constraint, Layout, +from jax.experimental.layout import (with_layout_constraint, Format, DeviceLocalLayout as DLL) from jax.experimental.compute_on import compute_on @@ -51,8 +51,8 @@ def init(x, y): sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1) sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2) - lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2) + lowered_apply = jax.jit(apply, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(sds1, sds2) compiled_apply = lowered_apply.compile() arg_layouts, kw_layouts = compiled_apply.input_layouts @@ -122,8 +122,8 @@ def f(x): self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + compiled_auto = jax.jit(f, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(sds).compile() self.assertTupleEqual( compiled_auto.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) @@ -146,8 +146,8 @@ def test_in_layouts_out_layouts(self): def f(x): return x.T - compiled = jax.jit(f, in_shardings=Layout(), - out_shardings=Layout(DLL.AUTO)).lower(arr).compile() + compiled = jax.jit(f, in_shardings=Format(), + out_shardings=Format(DLL.AUTO)).lower(arr).compile() self.assertTupleEqual( compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) @@ -166,8 +166,8 @@ def test_sharding_and_layouts(self): np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) - compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s), - out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() + compiled = jax.jit(lambda x: x.T, in_shardings=Format(DLL.AUTO, s), + out_shardings=Format(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], @@ -185,8 +185,8 @@ def f(x, y, z, a, b, c): shape = (8, 2) inps = [np.arange(math.prod(shape)).reshape(shape)] * 6 - compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(*inps).compile() + compiled = jax.jit(f, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(*inps).compile() arg_layouts, _ = compiled.input_layouts out1, out2 = compiled(*inps) @@ -216,8 +216,8 @@ def test_no_error_dced_args(self): 
def f(x, y): return x * 2 - jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s), - out_shardings=Layout(DLL.AUTO, s)) + jf = jax.jit(f, in_shardings=Format(DLL.AUTO, s), + out_shardings=Format(DLL.AUTO, s)) compiled = jf.lower(np_inp, np_inp).compile() arg_layouts, _ = compiled.input_layouts arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)] @@ -244,10 +244,10 @@ def f(x): with self.assertRaisesRegex( ValueError, 'Layout passed to jit does not match the layout on the respective arg'): - jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr) + jax.jit(f, in_shardings=Format(DLL.AUTO)).lower(arr) - compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + compiled = jax.jit(f, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(sds).compile() with self.assertRaisesRegex( ValueError, @@ -273,7 +273,7 @@ def test_device_put_concrete_layout(self): arr = jax.device_put(np_inp, s) compiled = jax.jit( - lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile() + lambda x: x * 2, out_shardings=Format(DLL.AUTO)).lower(arr).compile() col = compiled.output_layouts out = jax.device_put(np_inp, col) @@ -286,17 +286,17 @@ def test_device_put_concrete_layout(self): def test_device_put_non_concrete_layout_error(self): np_inp = np.arange(16).reshape(8, 2) - l1 = Layout(DLL.AUTO, SingleDeviceSharding(jax.devices()[0])) + l1 = Format(DLL.AUTO, SingleDeviceSharding(jax.devices()[0])) with self.assertRaisesRegex( ValueError, 'sharding and device_local_layout.*should be concrete'): jax.device_put(np_inp, l1) - l2 = Layout(DLL.AUTO) + l2 = Format(DLL.AUTO) with self.assertRaisesRegex( ValueError, 'sharding and device_local_layout.*should be concrete'): jax.device_put(np_inp, l2) - l3 = Layout(None, SingleDeviceSharding(jax.devices()[0])) + l3 = Format(None, SingleDeviceSharding(jax.devices()[0])) out = jax.device_put(np_inp, l3) self.assertArraysEqual(out, np_inp) self.assertTrue(out._committed) @@ -306,7 +306,7 @@ def invalid_layout_spec(self): compiled = jax.jit(lambda x: x).lower(x).compile() with self.assertRaisesRegex( ValueError, 'Sharding has to be concrete when layout.*'): - Layout(compiled.output_layouts[0], None) + Format(compiled.output_layouts[0], None) def test_layout_on_sds(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) @@ -314,7 +314,7 @@ def test_layout_on_sds(self): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(jnp.sin, out_shardings=Layout(DLL.AUTO)).lower( + out_layout = jax.jit(jnp.sin, out_shardings=Format(DLL.AUTO)).lower( arr).compile().output_layouts sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout) @@ -325,7 +325,7 @@ def test_layout_on_sds(self): TypeError, 'DeviceLocalLayout.AUTO` cannot be used in place of a device-local' ' layout in a `ShapeDtypeStruct`'): - jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO)) + jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Format(DLL.AUTO)) def test_make_array_from_callback(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) @@ -344,13 +344,13 @@ def test_make_array_from_callback(self): TypeError, '`DeviceLocalLayout.AUTO` cannot be used in place of a device-local' ' layout'): - jax.make_array_from_callback(np_inp.shape, Layout(DLL.AUTO, s), + jax.make_array_from_callback(np_inp.shape, Format(DLL.AUTO, s), lambda idx: np_inp[idx]) with self.assertRaisesRegex( TypeError, 'sharding should be an instance of `jax.sharding`'): jax.make_array_from_callback( - np_inp.shape, 
Layout(None, None), lambda idx: np_inp[idx]) + np_inp.shape, Format(None, None), lambda idx: np_inp[idx]) def test_wsc_concrete_layout(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -367,7 +367,7 @@ def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, # the layout of `y` would be the transpose of `arr`. - return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) self.assertEqual(out.layout.device_local_layout.major_to_minor, @@ -390,7 +390,7 @@ def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, # the layout of `y` would be the transpose of `arr`. - return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) self.assertEqual(out.layout.device_local_layout.major_to_minor, @@ -404,7 +404,7 @@ def test_device_put_user_concrete_layout(self): dll = DLL(major_to_minor=(1, 0)) s = SingleDeviceSharding(jax.devices()[0]) - out = jax.device_put(np_inp, Layout(dll, s)) + out = jax.device_put(np_inp, Format(dll, s)) self.assertEqual(out.layout.device_local_layout.major_to_minor, dll.major_to_minor) self.assertArraysEqual(out, np_inp) @@ -417,7 +417,7 @@ def test_device_put_user_concrete_layout_multi_device(self): jnp_inp = jnp.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - custom_layout = Layout(DLL(major_to_minor=(0, 1)), s) + custom_layout = Format(DLL(major_to_minor=(0, 1)), s) out1 = jax.device_put(arr, custom_layout) with jax.sharding.use_mesh(mesh): @@ -441,7 +441,7 @@ def f(x): return x.T custom_dll = DLL(major_to_minor=(0, 1)) - f = jax.jit(f, out_shardings=Layout(custom_dll, s)) + f = jax.jit(f, out_shardings=Format(custom_dll, s)) out = f(arr) self.assertArraysEqual(out, np_inp.T) @@ -450,7 +450,7 @@ def f(x): def test_compatible_aval_error(self): custom_dll = DLL(major_to_minor=(0, 1, 2)) - l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) + l = Format(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) @partial(jax.jit, in_shardings=l) @@ -464,7 +464,7 @@ def f(x): def test_incompatible_aval_error_device_put(self): custom_dll = DLL(major_to_minor=(0, 1, 2)) - l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) + l = Format(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) with self.assertRaisesRegex( @@ -482,8 +482,8 @@ def test_concrete_layout_in_shardings(self): custom_dll = DLL(major_to_minor=(0, 1)) @partial(jax.jit, - in_shardings=Layout(custom_dll, s), - out_shardings=Layout(DLL.AUTO)) + in_shardings=Format(custom_dll, s), + out_shardings=Format(DLL.AUTO)) def f(x): return x.T @@ -494,7 +494,7 @@ def f(x): custom_dll2 = DLL(major_to_minor=(1, 0)) - @partial(jax.jit, in_shardings=Layout(custom_dll2, s)) + @partial(jax.jit, in_shardings=Format(custom_dll2, s)) def g(x): return x.T @@ -508,7 +508,7 @@ def test_in_layouts_jit_jnp_input(self): sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) f = jax.jit(lambda x: x + 1, - in_shardings=Layout(major_last_layout, sharding)) + in_shardings=Format(major_last_layout, sharding)) arr = jnp.arange(8 * 128).reshape(8, 128) out = f(arr) @@ -533,9 +533,9 @@ def test_layout_donation(self): np_inp = np.arange(math.prod(shape)).reshape(shape) custom_dll = DLL(major_to_minor=(0, 1)) - arr = jax.device_put(np_inp, Layout(custom_dll, s)) + arr = jax.device_put(np_inp, Format(custom_dll, s)) - 
@partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) + @partial(jax.jit, in_shardings=Format(custom_dll, s), donate_argnums=0) def f(x): return x @@ -550,7 +550,7 @@ def test_layout_donation_auto(self): arr = jax.device_put(np_inp, s) - @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) + @partial(jax.jit, out_shardings=Format(DLL.AUTO), donate_argnums=0) def f(x): return x * x @@ -564,7 +564,7 @@ def test_layout_donation_matching_in_and_out(self): np_inp = np.arange(math.prod(shape)).reshape(shape) custom_dll = DLL(major_to_minor=(0, 1)) - l = Layout(custom_dll, s) + l = Format(custom_dll, s) arr = jax.device_put(np_inp, l) @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) @@ -582,7 +582,7 @@ def test_layout_donation_mismatching_in_and_out_fails(self): np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1))) - l1 = Layout(custom_dll1, s) + l1 = Format(custom_dll1, s) arr = jax.device_put(np_inp, s) @partial(jax.jit, out_shardings=l1, donate_argnums=0) @@ -594,7 +594,7 @@ def f(x): self.assertFalse(arr.is_deleted()) def test_donation_error_on_auto(self): - @partial(jax.jit, donate_argnums=0, in_shardings=Layout(DLL.AUTO)) + @partial(jax.jit, donate_argnums=0, in_shardings=Format(DLL.AUTO)) def f(x): return x * 2 @@ -602,7 +602,7 @@ def f(x): ValueError, ".*Did you mean to set the.*output layout.*AUTO.*"): f(jnp.arange(8)) - @partial(jax.jit, donate_argnums=0, out_shardings=Layout(DLL.AUTO)) + @partial(jax.jit, donate_argnums=0, out_shardings=Format(DLL.AUTO)) def g(x): return x * 2 @@ -619,9 +619,9 @@ def test_sparsecore_compute(self): dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Layout(dll, s) + sparse_layout = Format(dll, s) sparecore_arr = jax.device_put(inp, sparse_layout) - dense_layout = Layout(DLL(major_to_minor=(0, 1)), s) + dense_layout = Format(DLL(major_to_minor=(0, 1)), s) @compute_on('tpu_sparsecore') @jax.jit @@ -645,7 +645,7 @@ def test_sparsecore_compute_twice(self): dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Layout(dll, s) + sparse_layout = Format(dll, s) sparecore_arr = jax.device_put(inp, sparse_layout) @compute_on('tpu_sparsecore') @@ -675,11 +675,11 @@ def test_sparsecore_and_host_compute(self): s = SingleDeviceSharding(jax.devices()[0]) sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) - sparse_layout = Layout(sparse_dll, s) + sparse_layout = Format(sparse_dll, s) sparecore_arr = jax.device_put(inp, sparse_layout) host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - host_layout = Layout(host_dll, s) + host_layout = Format(host_dll, s) host_arr = jax.device_put(inp, host_layout) @compute_on('tpu_sparsecore') @@ -710,7 +710,7 @@ def test_cpp_layout_cache_miss(self): arr = jax.device_put(np_inp, s) arr_m2m = arr.layout.device_local_layout.major_to_minor - custom_layout = Layout(DLL(major_to_minor=arr_m2m[::-1]), s) + custom_layout = Format(DLL(major_to_minor=arr_m2m[::-1]), s) arr2 = jax.device_put(np_inp, custom_layout) @jax.jit @@ -731,7 +731,7 @@ def test_layout_donation_with_default_layout(self): shape = (16, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - out_layout = Layout(arr.layout.device_local_layout, s) + out_layout = Format(arr.layout.device_local_layout, s) @partial(jax.jit, out_shardings=out_layout, donate_argnums=0) def f(x): diff --git 
a/tests/memories_test.py b/tests/memories_test.py index e0a42d4a0146..fd40330f8db8 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -25,7 +25,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.layout import DeviceLocalLayout as DLL, Layout +from jax._src.layout import DeviceLocalLayout as DLL, Format from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp @@ -1574,8 +1574,8 @@ def test_fn(x_in, y_in): y = jnp.reshape(y, (16, 64)) custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),)) custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - x = jax.device_put(x, Layout(custom_dll, sharding)) - y = jax.device_put(y, Layout(custom_dll_linear, p_sharding)) + x = jax.device_put(x, Format(custom_dll, sharding)) + y = jax.device_put(y, Format(custom_dll_linear, p_sharding)) x1 = jnp.arange(0, 1024, dtype=jnp.float32) x1 = jnp.reshape(x1, (16, 64)) @@ -1585,8 +1585,8 @@ def test_fn(x_in, y_in): jit_fn = jax.jit( test_fn, out_shardings=( - Layout(custom_dll, sharding), - Layout(custom_dll_linear, p_sharding), + Format(custom_dll, sharding), + Format(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) @@ -1613,8 +1613,8 @@ def test_fn(x_in, y_in): y = jnp.reshape(y, (32, 64)) custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),)) custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - x = jax.device_put(x, Layout(custom_dll, sharding)) - y = jax.device_put(y, Layout(custom_dll_linear, p_sharding)) + x = jax.device_put(x, Format(custom_dll, sharding)) + y = jax.device_put(y, Format(custom_dll_linear, p_sharding)) x1 = jnp.arange(0, 2048, dtype=jnp.float32) x1 = jnp.reshape(x1, (32, 64)) @@ -1624,8 +1624,8 @@ def test_fn(x_in, y_in): jit_fn = jax.jit( test_fn, out_shardings=( - Layout(custom_dll, sharding), - Layout(custom_dll_linear, p_sharding), + Format(custom_dll, sharding), + Format(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d37b21bd2460..751fd63823e3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -55,7 +55,7 @@ SingleDeviceSharding, parse_flatten_op_sharding) from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes, use_auto_axes, use_explicit_axes, reshard) -from jax._src.layout import Layout, DeviceLocalLayout as DLL +from jax._src.layout import Format, DeviceLocalLayout as DLL from jax._src.named_sharding import DuplicateSpecError from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType @@ -4997,8 +4997,8 @@ def test_sds_update(self): sh = NamedSharding(mesh, P()) s4 = jax.ShapeDtypeStruct((2, 2), jnp.int32, - sharding=Layout(DLL((0, 1)), sh)) - new_layout = Layout(DLL((1, 0)), NamedSharding(mesh, P('x'))) + sharding=Format(DLL((0, 1)), sh)) + new_layout = Format(DLL((1, 0)), NamedSharding(mesh, P('x'))) s4_u = s4.update(sharding=new_layout) self.assertEqual(s4_u.sharding, new_layout.sharding) self.assertEqual(s4_u.layout, new_layout) From 663e50f72cc6cf420e78f865837f82112d2425b6 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Thu, 29 May 2025 16:46:43 -0700 Subject: [PATCH 1439/1769] [Mosaic] Make 1D tiling agnostic to large 2nd minor flags. 
PiperOrigin-RevId: 764937293 --- .../tpu/transforms/infer_memref_layout.cc | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index bfb9be87dfd0..b772c5c8a114 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -58,10 +58,12 @@ namespace mlir::tpu { // enabled by XLA for memrefs. // bitwidth: The bitwidth of the element type of the operand. // is_kernel_argument: Whether the operand is a kernel argument. +// is_1d: Whether the operand is 1D. int getTilingFactor(const int src_sublane, const int hardware_generation, const int64_t target_sublane_count, const TpuTilingFlags &tpu_tiling_flags, - const int8_t bitwidth, const bool is_kernel_argument) { + const int8_t bitwidth, const bool is_kernel_argument, + const bool is_1d) { CHECK(llvm::isPowerOf2_32(bitwidth)); CHECK_LE(2, bitwidth); CHECK_LE(bitwidth, 32); @@ -76,6 +78,10 @@ int getTilingFactor(const int src_sublane, const int hardware_generation, const int max_normal_tiling = tiling_sublane; int large_tiling = [&] { + if (is_1d) { + // 1D tiling is always compact. + return tiling_sublane; + } if (bitwidth == 2) { return target_sublane_count * 16; } @@ -151,9 +157,9 @@ FailureOr inferLayout(MemRefType memref_ty, auto src_sublane = llvm::divideCeil(memref_ty.getShape().back(), lane_count); const int64_t leading_tile = - getTilingFactor(src_sublane, hardware_generation, - sublane_count, tpu_tiling_flags, bitwidth, - is_kernel_argument) * + getTilingFactor(src_sublane, hardware_generation, sublane_count, + tpu_tiling_flags, bitwidth, is_kernel_argument, + /*is_1d=*/true) * lane_count; SmallVector tiles{xla::Tile({leading_tile})}; if (bitwidth != 32) { @@ -173,8 +179,8 @@ FailureOr inferLayout(MemRefType memref_ty, const int64_t src_sublane = shape[shape.size() - 2]; if (leading_tile_rows == 0) { leading_tile_rows = getTilingFactor( - src_sublane, hardware_generation, sublane_count, - tpu_tiling_flags, bitwidth, is_kernel_argument); + src_sublane, hardware_generation, sublane_count, tpu_tiling_flags, + bitwidth, is_kernel_argument, /*is_1d=*/false); } SmallVector tiles{xla::Tile({leading_tile_rows, lane_count})}; if (bitwidth != 32) { From 7ff6f0d01a7e324787b238f6998028a4a2686625 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 29 May 2025 18:34:52 -0700 Subject: [PATCH 1440/1769] rename `Array.layout` to `Array.format` This change renames the attribute and updates the codebase to refer to the new name. It should have minimal external effect, since it keeps a `layout` alias for the attribute. Co-authored-by: Yash Katariya PiperOrigin-RevId: 764967359 --- jax/_src/api.py | 4 +- jax/_src/array.py | 9 ++-- jax/_src/dispatch.py | 2 +- jax/_src/interpreters/pxla.py | 6 +-- jax/_src/pjit.py | 10 ++-- .../array_serialization/serialization_test.py | 4 +- tests/layout_test.py | 52 +++++++++---------- tests/pjit_test.py | 2 +- 8 files changed, 46 insertions(+), 43 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 2630fc7ae1be..c21e8248d52d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2909,12 +2909,12 @@ def update(self, **kwargs): if self._dll is not None and isinstance(s, Sharding): raise ValueError( f"You are updating ShapeDtypeStruct with a {type(s)} when the" - f" original ShapeDtypeStruct had a concrete layout {self.layout}." 
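As a quick illustration of the compatibility alias (a sketch, not part of the
diff below; `x` is invented, and the example assumes any default backend, since
`.format` falls back to `Format(None, sharding)` when the runtime cannot report
a layout, in which case the assertions still hold):

    import jax.numpy as jnp

    x = jnp.arange(8.0)
    # The renamed attribute and its backwards-compatibility alias return
    # equal Format values:
    assert x.format == x.layout
    # The sharding component is the array's own sharding:
    assert x.format.sharding == x.sharding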
+ f" original ShapeDtypeStruct had a concrete layout {self.format}." " This might lead to bugs. If you want to do this, create a new" " ShapeDtypeStruct via the constructor.") sharding = s else: - sharding = self.layout + sharding = self.format return ShapeDtypeStruct( shape=kwargs.pop('shape', self.shape), dtype=kwargs.pop('dtype', self.dtype), diff --git a/jax/_src/array.py b/jax/_src/array.py index 29c7a17b07f1..9a71b12ed1a8 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -547,7 +547,7 @@ def addressable_shards(self) -> Sequence[Shard]: return out @property - def layout(self): + def format(self): # TODO(yashkatariya): Remove the deleted check from here. if self.is_deleted(): return Format(None, self.sharding) @@ -561,6 +561,9 @@ def layout(self): else: raise + # TODO(frostig, yashkatariya): remove + layout = format + @property def global_shards(self) -> Sequence[Shard]: """Returns list of all `Shard`s of the Array across all devices. @@ -812,7 +815,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: and sharding.is_fully_replicated and first_value.is_fully_replicated and first_value.sharding._device_assignment == tuple(devices) - and first_value.layout.device_local_layout == dll): + and first_value.format.device_local_layout == dll): return first_value if dtypes.issubdtype(aval.dtype, dtypes.extended): @@ -1197,7 +1200,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): x._check_if_deleted() indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) same_layout = (True if layout is None else - x.layout.device_local_layout == layout) + x.format.device_local_layout == layout) if not x.is_fully_addressable: if same_indices and same_layout: diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index a7c8d4ea7380..028c2cfa125e 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -497,7 +497,7 @@ def _device_put_impl( if isinstance(device, Format): l = device dll = l.device_local_layout - x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None + x_dll = x.format.device_local_layout if hasattr(x, 'format') else None if dll is None and l.sharding is None: return _device_put_sharding_impl(x, aval, l.sharding, copy) if (not isinstance(l.sharding, Sharding) or diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index af1f8217951c..2072aaf44b5a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3255,11 +3255,11 @@ def check_array_xla_sharding_layout_match( 'sharding')) if (not db_xs and arg._committed and - arg.layout.device_local_layout is not None and xl is not None and - arg.layout.device_local_layout != xl): + arg.format.device_local_layout is not None and xl is not None and + arg.format.device_local_layout != xl): errors.append( ("Got input layout(s) that compiled object was called with: " - f"{arg.layout.device_local_layout} and layout(s) the computation was " + f"{arg.format.device_local_layout} and layout(s) the computation was " f"compiled with: {xl} for arg {name} with " f"shape: {arg.aval.str_short()}", 'layout')) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f012459296a1..4113f764e888 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1651,8 +1651,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): # below. We cannot replace default layout with None to raise nicer errors. # `dispatch_arg_layout` replaces default layouts with `None` to simplify # dispatch and lowering logic downstream. 
- if hasattr(arg, 'layout'): - arg_layout = arg.layout.device_local_layout + if hasattr(arg, 'format'): + arg_layout = arg.format.device_local_layout dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval) else arg_layout) else: @@ -1670,7 +1670,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): resolved_in_layouts.append(None) else: # arg_layout can be None because some backends don't implement the - # required layout methods. Hence `arr.layout` can return + # required layout methods. Hence `arr.format` can return # `Format(None, sharding)` if (committed and not is_pmap_sharding @@ -2845,7 +2845,7 @@ def _sharding_constraint_impl(x, sharding, layout, context_mesh, # Run a jit here to raise good errors when device assignment don't match. return api.jit(_identity_fn, out_shardings=sharding)(x) else: - if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and + if (hasattr(x, 'format') and x.format.device_local_layout == layout and x.sharding.is_equivalent_to(sharding, x.ndim)): return x return api.jit(_identity_fn, out_shardings=Format(layout, sharding))(x) @@ -3193,7 +3193,7 @@ def _layout_constraint_impl(x, *, layout): raise ValueError( 'with_layout_constraint in eager mode can only be applied to' f' jax.Arrays. Got {type(x)}') - if x.layout.device_local_layout == layout: # type: ignore + if x.format.device_local_layout == layout: # type: ignore return x return api.jit(_identity_fn, out_shardings=Format(layout, x.sharding))(x) layout_constraint_p.def_impl(_layout_constraint_impl) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 3bee72967101..0611388a1d80 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -595,7 +595,7 @@ def test_load_with_layout(self): out_layout = jax.jit(lambda x: x.T, out_shardings=Format(DLL.AUTO)).lower( arr).compile().output_layouts - self.assertEqual(arr.layout.device_local_layout.major_to_minor, + self.assertEqual(arr.format.device_local_layout.major_to_minor, out_layout.device_local_layout.major_to_minor[::-1]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) @@ -611,7 +611,7 @@ def test_load_with_layout(self): out, = serialization.run_deserialization([out_layout], tspecs) - self.assertEqual(out.layout, out_layout) + self.assertEqual(out.format, out_layout) self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, np_inp) for s in out.addressable_shards: diff --git a/tests/layout_test.py b/tests/layout_test.py index cfec2253dfc8..d7b23c75b313 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -77,21 +77,21 @@ def init(x, y): init_compiled(arr1, arr2) self.assertEqual(init_count(), 1) - self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0]) - self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1]) + self.assertEqual(init_out[0].format, init_compiled.output_layouts[0]) + self.assertEqual(init_out[1].format, init_compiled.output_layouts[1]) with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count(), 1) - self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0]) - self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1]) + self.assertEqual(apply_out[0].format, compiled_apply.output_layouts[0]) + self.assertEqual(apply_out[1].format, 
compiled_apply.output_layouts[1]) - self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, - init_out[0].layout.device_local_layout.major_to_minor[::-1]) - self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor, - init_out[1].layout.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[0].format.device_local_layout.major_to_minor, + init_out[0].format.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[1].format.device_local_layout.major_to_minor, + init_out[1].format.device_local_layout.major_to_minor[::-1]) self.assertArraysEqual(init_out[0], np_inp1 * 2) self.assertArraysEqual(init_out[1], np_inp2 * 2) @@ -157,7 +157,7 @@ def f(x): out = compiled(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout, compiled.output_layouts) + self.assertEqual(out.format, compiled.output_layouts) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): @@ -277,11 +277,11 @@ def test_device_put_concrete_layout(self): col = compiled.output_layouts out = jax.device_put(np_inp, col) - self.assertEqual(out.layout, col) + self.assertEqual(out.format, col) self.assertArraysEqual(out, np_inp) for s in out.addressable_shards: - self.assertEqual(out.layout.device_local_layout, - s.data.layout.device_local_layout) + self.assertEqual(out.format.device_local_layout, + s.data.format.device_local_layout) def test_device_put_non_concrete_layout_error(self): np_inp = np.arange(16).reshape(8, 2) @@ -338,7 +338,7 @@ def test_make_array_from_callback(self): out = jax.make_array_from_callback(np_inp.shape, layout, lambda idx: np_inp[idx]) self.assertArraysEqual(out, np_inp) - self.assertEqual(out.layout, layout) + self.assertEqual(out.format, layout) with self.assertRaisesRegex( TypeError, @@ -370,9 +370,9 @@ def f(x): return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor) - self.assertEqual(out.layout, arr.layout) + self.assertEqual(out.format, arr.format) self.assertArraysEqual(out, np_inp.T) def test_wsc_bfloat16_concrete_layout(self): @@ -393,9 +393,9 @@ def f(x): return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor) - self.assertEqual(out.layout, arr.layout) + self.assertEqual(out.format, arr.format) self.assertArraysEqual(out, inp.T) def test_device_put_user_concrete_layout(self): @@ -405,7 +405,7 @@ def test_device_put_user_concrete_layout(self): s = SingleDeviceSharding(jax.devices()[0]) out = jax.device_put(np_inp, Format(dll, s)) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, dll.major_to_minor) self.assertArraysEqual(out, np_inp) @@ -427,7 +427,7 @@ def test_device_put_user_concrete_layout_multi_device(self): for o in [out1, out2, out3, out4]: self.assertArraysEqual(o, np_inp) - self.assertEqual(o.layout.device_local_layout.major_to_minor, + self.assertEqual(o.format.device_local_layout.major_to_minor, custom_layout.device_local_layout.major_to_minor) def test_concrete_layout_jit(self): @@ -445,7 +445,7 @@ def f(x): out = f(arr) self.assertArraysEqual(out, np_inp.T) - 
self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor) def test_compatible_aval_error(self): @@ -489,7 +489,7 @@ def f(x): out = f(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor[::-1]) custom_dll2 = DLL(major_to_minor=(1, 0)) @@ -709,7 +709,7 @@ def test_cpp_layout_cache_miss(self): np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - arr_m2m = arr.layout.device_local_layout.major_to_minor + arr_m2m = arr.format.device_local_layout.major_to_minor custom_layout = Format(DLL(major_to_minor=arr_m2m[::-1]), s) arr2 = jax.device_put(np_inp, custom_layout) @@ -731,7 +731,7 @@ def test_layout_donation_with_default_layout(self): shape = (16, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - out_layout = Format(arr.layout.device_local_layout, s) + out_layout = Format(arr.format.device_local_layout, s) @partial(jax.jit, out_shardings=out_layout, donate_argnums=0) def f(x): @@ -743,7 +743,7 @@ def f(x): out = f(arr) self.assertArraysEqual(out, np_inp * 2) - self.assertEqual(out.layout, out_layout) + self.assertEqual(out.format, out_layout) def test_with_layout_constraint(self): if not jtu.test_device_matches(['tpu']): @@ -755,7 +755,7 @@ def test_with_layout_constraint(self): arr = jax.device_put(np_inp, s) # Create a custom layout instead of using `arr.layout` to test the API. - custom_dll = DLL(major_to_minor=arr.layout.dll.major_to_minor[::-1]) + custom_dll = DLL(major_to_minor=arr.format.dll.major_to_minor[::-1]) def f(x): y = x.T @@ -768,7 +768,7 @@ def f(x): f = jax.jit(f) out = f(arr) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor) self.assertArraysEqual(out, np_inp.T * 2) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 751fd63823e3..dd5c5d46e62f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5001,7 +5001,7 @@ def test_sds_update(self): new_layout = Format(DLL((1, 0)), NamedSharding(mesh, P('x'))) s4_u = s4.update(sharding=new_layout) self.assertEqual(s4_u.sharding, new_layout.sharding) - self.assertEqual(s4_u.layout, new_layout) + self.assertEqual(s4_u.format, new_layout) with self.assertRaisesRegex(ValueError, "updating ShapeDtypeStruct"): s4.update(sharding=NamedSharding(mesh, P('x'))) From 75b2c7e553e7ad9a141e0d94ff45af31eacfebd3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 30 May 2025 02:03:24 -0700 Subject: [PATCH 1441/1769] [Mosaic GPU] Move the semaphore implementation to Mosaic Pallas lowering should not be doing any heavy lifting here. The implementation is quite low level and should ideally live closer to where other synchronization primitives are implemented. 
PiperOrigin-RevId: 765092823 --- jax/_src/pallas/mosaic_gpu/lowering.py | 41 ++-------------- jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/utils.py | 64 +++++++++++++++++++++++++ tests/mosaic/gpu_test_distributed.py | 55 +++++++++++++++++++++ 4 files changed, 125 insertions(+), 36 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 003fa0419f63..ce74a5ba7c05 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -3108,13 +3108,8 @@ def _semaphore_signal_lowering_rule( # receive a signal). if ctx.module_ctx.auto_barriers: mgpu.utils.warpgroup_barrier() - pred = ctx.module_ctx.single_wg_lane_predicate - llvm_dialect.inline_asm( - i32, - [sem_ptr, val, pred], - "@$3 atom.add.release.sys.global.u32 $0, [$1], $2;", - "=r,l,r,b", - has_side_effects=True, + mgpu_utils.SemaphoreRef(sem_ptr).signal( + val, predicate=ctx.module_ctx.single_wg_lane_predicate ) return () @@ -3127,35 +3122,9 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): raise NotImplementedError( f"Unhandled transforms for semaphore_wait: {transforms}" ) - - sem_ptr = mgpu.utils.memref_ptr(sem) - i32_ty = ir.IntegerType.get_signless(32) - ne_pred = arith_dialect.CmpIPredicate.ne - val = _ir_constant(value, i32_ty) - - with mgpu.single_thread(scope=mgpu.ThreadSubset.WARPGROUP): - # Create the while loop for busy waiting - while_op = scf_dialect.WhileOp([i32_ty], [val]) - before_block = while_op.before.blocks.append(i32_ty) - with ir.InsertionPoint.at_block_begin(before_block): - [expected_in_memory] = before_block.arguments - new_val = arith_dialect.subi(expected_in_memory, val) - in_memory = llvm_dialect.inline_asm( - i32_ty, - [sem_ptr, expected_in_memory, new_val], - "atom.acquire.sys.global.cas.b32 $0, [$1], $2, $3;", - "=r,l,r,r", - has_side_effects=True, - ) - comparison = arith_dialect.cmpi(ne_pred, in_memory, expected_in_memory) - new_expected_in_memory = arith_dialect.maxui(in_memory, val) - scf_dialect.condition(comparison, [new_expected_in_memory]) - after_block = while_op.after.blocks.append(i32_ty) - with ir.InsertionPoint.at_block_begin(after_block): - scf_dialect.yield_(after_block.arguments) - # NOTE: This barrier is necessary for a correct lowering of this op and can't - # be removed even if auto_barriers is False. 
- mgpu_utils.warpgroup_barrier() + i32 = ir.IntegerType.get_signless(32) + val = _ir_constant(value, i32) + mgpu_utils.SemaphoreRef(mgpu.utils.memref_ptr(sem)).wait(val) return () diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 82155f86d9ea..cd207c2b2519 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -74,6 +74,7 @@ DynamicSlice as DynamicSlice, Partition as Partition, Partition1D as Partition1D, + SemaphoreRef as SemaphoreRef, ThreadSubset as ThreadSubset, bitwidth as bitwidth, bytewidth as bytewidth, diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 51b6ed4612ca..a76e077ff463 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -744,6 +744,9 @@ def warpgroup_barrier(): has_side_effects=True, ) +def warp_barrier(): + nvvm.bar_warp_sync(c(0xffffffff, ir.IntegerType.get_signless(32))) + @dataclasses.dataclass(frozen=True) class BarrierRef: @@ -1046,6 +1049,67 @@ def wait_parity(self, *args, **kwargs): self.barrier.wait_parity(*args, **kwargs) +@dataclasses.dataclass(frozen=True) +class SemaphoreRef: + ptr: ir.Value + + def signal(self, value: ir.Value | int, predicate: ir.Value | None = None): + i32 = ir.IntegerType.get_signless(32) + if not isinstance(value, ir.Value): + value = c(value, i32) + elif value.type != i32: + raise ValueError(f"Expected a i32 value, got {value.type}") + if predicate is None: + predicate = single_thread_predicate(ThreadSubset.WARPGROUP) + llvm.inline_asm( + i32, + [self.ptr, value, predicate], + "@$3 atom.add.release.sys.global.u32 $0, [$1], $2;", + "=r,l,r,b", + has_side_effects=True, + ) + + def wait( + self, + value: ir.Value | int = 1, + scope: ThreadSubset = ThreadSubset.WARPGROUP, + ): + i32 = ir.IntegerType.get_signless(32) + if not isinstance(value, ir.Value): + value = c(value, i32) + elif value.type != i32: + raise ValueError(f"Expected a i32 value, got {value.type}") + + ne_pred = arith.CmpIPredicate.ne + + with single_thread(scope=scope): + # Create the while loop for busy waiting + while_op = scf.WhileOp([i32], [value]) + before_block = while_op.before.blocks.append(i32) + with ir.InsertionPoint.at_block_begin(before_block): + [expected_in_memory] = before_block.arguments + new_val = arith.subi(expected_in_memory, value) + in_memory = llvm.inline_asm( + i32, + [self.ptr, expected_in_memory, new_val], + "atom.acquire.sys.global.cas.b32 $0, [$1], $2, $3;", + "=r,l,r,r", + has_side_effects=True, + ) + comparison = arith.cmpi(ne_pred, in_memory, expected_in_memory) + new_expected_in_memory = arith.maxui(in_memory, value) + scf.condition(comparison, [new_expected_in_memory]) + after_block = while_op.after.blocks.append(i32) + with ir.InsertionPoint.at_block_begin(after_block): + scf.yield_(after_block.arguments) + if scope == ThreadSubset.WARPGROUP: + warpgroup_barrier() + elif scope == ThreadSubset.WARP: + warp_barrier() + else: + raise ValueError(f"Unsupported scope: {scope}") + + class Partition: source_bounds: tuple[int, ...] target_bounds: tuple[int, ...] 
diff --git a/tests/mosaic/gpu_test_distributed.py b/tests/mosaic/gpu_test_distributed.py index cf3913771983..c289b27c0be1 100644 --- a/tests/mosaic/gpu_test_distributed.py +++ b/tests/mosaic/gpu_test_distributed.py @@ -23,6 +23,7 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import memref from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member from jax.experimental import shard from jax.experimental import multihost_utils @@ -70,6 +71,28 @@ def setUp(self): class ProfilerTest(TestCase): + def test_get_device_id(self): + index = ir.IndexType.get() + def kernel(ctx, dst, _): + device_id = ctx.device_id() + memref.store(device_id, dst, [arith.constant(index, 0)]) + mesh = jax.make_mesh( + (jax.device_count(),), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.sharding.use_mesh(mesh): + out_shape = jax.ShapeDtypeStruct((1,), jnp.int32) + y = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + ), + out_specs=P("x"), + check_vma=False, + ) + )() + y_np = multihost_utils.process_allgather(y, tiled=True) + np.testing.assert_array_equal(y_np, np.arange(jax.device_count())) + def test_remote_async_copy(self): i32 = ir.IntegerType.get_signless(32) def kernel(ctx, src, dst, scratch): @@ -99,6 +122,38 @@ def kernel(ctx, src, dst, scratch): y_np, np.concatenate(np.split(x_np, 2)[::-1], axis=0) ) + def test_remote_semaphore(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, sem, _): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + # We signal and wait a different amount on each device to make sure we're + # really communicating here. + other_sem.signal(arith.addi(arith.constant(i32, 1), other_device)) + @mgpu.fori(arith.addi(arith.constant(i32, 1), my_device), None) + def wait_loop(i, _): + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.sharding.use_mesh(mesh): + sem = shard.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + out_sem = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), (), (), inout_shape=sem + ), + out_specs=P("x"), + check_vma=False, + ) + )(sem) + out_sems = multihost_utils.process_allgather(out_sem, tiled=True) + np.testing.assert_array_equal(out_sems, np.zeros_like(out_sems)) + if __name__ == "__main__": # This test doesn't work with the platform allocator, so we override it From 70f5aa4dfec771b195268d416be346802e8c608e Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 30 May 2025 02:43:30 -0700 Subject: [PATCH 1442/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/adf78529a5536db17f147c6abbce0e4ece83ba42. PiperOrigin-RevId: 765104016 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d2a0a2e8b29a..ac418308de83 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "f9ea70486e5625484325b3f451a32242507493ad" -XLA_SHA256 = "48c1a6e87b580becb8dbc018028be2835c5f0bd941ae36f450102f3f16a79398" +XLA_COMMIT = "adf78529a5536db17f147c6abbce0e4ece83ba42" +XLA_SHA256 = "559406848b34c82856f45c8371ea8c2d7a92e2d13db768fc401fda8f4c22c70a" def repo(): tf_http_archive( From a9407763d9abf80f8636631065f9256cf0238e6d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 30 May 2025 03:12:44 -0700 Subject: [PATCH 1443/1769] [pallas:mosaic_gpu] Unconditionally emit line info for Mosaic GPU kernels I also changed the lowering to override --jax_include_full_tracebacks_in_locations so that we get a single location per emitted op, since the ensure-debug-info-scope-on-llvm-func pass in MLIR does not correctly handle nested CallSiteLocs. PiperOrigin-RevId: 765112273 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + .../pallas/mosaic_gpu/pallas_call_registration.py | 10 +++++++--- jaxlib/mosaic/gpu/custom_call.cc | 14 +++++--------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 74b44fb8f991..8a5d087125b7 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -45,6 +45,7 @@ pytype_strict_library( ":core", ":lowering", "//jax", + "//jax:config", "//jax:core", "//jax:mlir", "//jax:mosaic_gpu", diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index ccbe4d36edc9..1d55a6e862a0 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -24,6 +24,7 @@ import jax from jax import lax +from jax._src import config from jax._src import core as jax_core from jax._src import sharding_impls from jax._src.interpreters import mlir @@ -73,9 +74,12 @@ def pallas_call_lowering( if isinstance(axis_context, sharding_impls.SPMDAxisContext): jax_mesh = axis_context.mesh - lowering_result = lowering.lower_pipelined_jaxpr_to_module( - grid_mapping, mesh, jax_mesh, jaxpr, params, cost_estimate - ) + # TODO(slebedev): Remove this once the ensure-debug-info-scope-on-llvm-func + # pass correctly handles full tracebacks. 
+ with config.include_full_tracebacks_in_locations(False): + lowering_result = lowering.lower_pipelined_jaxpr_to_module( + grid_mapping, mesh, jax_mesh, jaxpr, params, cost_estimate + ) if debug: print(f"\nThe Mosaic GPU module for pallas_call {debug_info.func_src_info}:") print(lowering_result.module.operation) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 5253d4590658..214521ce3764 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -343,7 +343,6 @@ mlir::FailureOr GetPassPipeline( mlir::LLVM::registerDIScopeForLLVMFuncOpPass(); return true; }); - bool emit_line_info = getenv("MOSAIC_GPU_LINE_INFO") != nullptr; const char *cuda_root = GetCUDARoot(); if (!cuda_root) { return mlir::failure(); @@ -360,8 +359,8 @@ mlir::FailureOr GetPassPipeline( convert-scf-to-cf, convert-nvvm-to-llvm, expand-strided-metadata, - nvvm-attach-target{O=3 chip=)", sm, " fast=false features=+", - ptx_isa, + nvvm-attach-target{O=3 chip=)", + sm, " fast=false features=+", ptx_isa, R"( ftz=false module= triple=nvptx64-nvidia-cuda}, lower-affine, convert-arith-to-llvm{index-bitwidth=0}, @@ -369,7 +368,6 @@ mlir::FailureOr GetPassPipeline( canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, )", - emit_line_info ? "" : "gpu.module(strip-debuginfo),", R"( gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}), gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}), @@ -377,13 +375,11 @@ mlir::FailureOr GetPassPipeline( gpu.module(mosaic-byval-insertion), gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, - )", - emit_line_info ? "ensure-debug-info-scope-on-llvm-func{emission-kind=DebugDirectivesOnly}," : "", - "gpu-module-to-binary{format=", + ensure-debug-info-scope-on-llvm-func{emission-kind=DebugDirectivesOnly}, + gpu-module-to-binary{format=)", mlir::gpu::stringifyCompilationTarget(target).str(), (!nvshmem_path.empty() ? " l=" + nvshmem_path : ""), - (emit_line_info ? " opts=-lineinfo" : ""), - " toolkit=", cuda_root, + " opts=-lineinfo toolkit=", cuda_root, R"(}, convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, From 6564a4bb5f9eac51e449f4403997b53c873ffa75 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 30 May 2025 09:10:51 -0400 Subject: [PATCH 1444/1769] Remove Mac x86 from the installation instructions. We have not been shipping Mac x86 for some time. --- README.md | 16 ++++++++-------- docs/installation.md | 17 ++++++++--------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 14a0a06ae700..e6af1b344f24 100644 --- a/README.md +++ b/README.md @@ -225,14 +225,14 @@ Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). 
### Supported platforms -| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | -|------------|--------------|---------------|--------------|--------------|----------------|---------------------| -| CPU | yes | yes | yes | yes | yes | yes | -| NVIDIA GPU | yes | yes | no | n/a | no | experimental | -| Google TPU | yes | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | yes | no | experimental | n/a | no | no | -| Apple GPU | n/a | no | n/a | experimental | n/a | n/a | -| Intel GPU | experimental | n/a | n/a | n/a | no | no | +| | Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | +|------------|--------------|---------------|--------------|----------------|---------------------| +| CPU | yes | yes | yes | yes | yes | +| NVIDIA GPU | yes | yes | n/a | no | experimental | +| Google TPU | yes | n/a | n/a | n/a | n/a | +| AMD GPU | yes | no | n/a | no | no | +| Apple GPU | n/a | no | experimental | n/a | n/a | +| Intel GPU | experimental | n/a | n/a | no | no | ### Instructions diff --git a/docs/installation.md b/docs/installation.md index 1314a2efa0a8..4019f6461473 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -28,14 +28,14 @@ different builds for different operating systems and accelerators. The table below shows all supported platforms and installation options. Check if your setup is supported; and if it says _"yes"_ or _"experimental"_, then click on the corresponding link to learn how to install JAX in greater detail. -| | Linux, x86_64 | Linux, aarch64 | Mac, x86_64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 | -|------------------|---------------------------------------|---------------------------------|---------------------------------------|---------------------------------------|--------------------------|------------------------------------------| -| CPU | {ref}`yes ` | {ref}`yes ` | {ref}`jax≤0.4.38 only ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | -| NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | no | n/a | no | {ref}`experimental ` | -| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | {ref}`yes ` | no | {ref}`experimental ` | n/a | no | no | -| Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a | -| Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no | +| | Linux, x86_64 | Linux, aarch64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 | +|------------------|---------------------------------------|---------------------------------|---------------------------------------|--------------------------|------------------------------------------| +| CPU | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | +| NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | n/a | no | {ref}`experimental ` | +| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | +| AMD GPU | {ref}`yes ` | no | n/a | no | no | +| Apple GPU | n/a | no | {ref}`experimental ` | n/a | n/a | +| Intel GPU | {ref}`experimental `| n/a | n/a | no | no | (install-cpu)= @@ -48,7 +48,6 @@ operating systems and architectures: - Linux, x86_64 - Linux, aarch64 -- macOS, Intel - macOS, Apple ARM-based - Windows, x86_64 (*experimental*) From 7b01f6d94e5abe20676be85dcbfa29acf59950c3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 30 May 2025 08:19:54 -0700 Subject: [PATCH 1445/1769] [typing] adjust axis annotation for ufunc.reduce --- jax/_src/numpy/ufunc_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/ufunc_api.py 
b/jax/_src/numpy/ufunc_api.py
index c488855b70fa..da55212bae1f 100644
--- a/jax/_src/numpy/ufunc_api.py
+++ b/jax/_src/numpy/ufunc_api.py
@@ -184,7 +184,7 @@ def _call_vectorized(self, *args):
     return vectorize(self._func)(*args)

   @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
-  def reduce(self, a: ArrayLike, axis: int = 0,
+  def reduce(self, a: ArrayLike, axis: int | None = 0,
              dtype: DTypeLike | None = None, out: None = None,
              keepdims: bool = False, initial: ArrayLike | None = None,
              where: ArrayLike | None = None) -> Array:

From 73c016a534af51614741d70d36c2c75ca59f2dcc Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 30 May 2025 08:22:22 -0700
Subject: [PATCH 1446/1769] Don't sort replicated and unreduced axes wrt mesh
 axis names, as they are not sets and their order actually matters for
 all-reduce.

PiperOrigin-RevId: 765199626
---
 jax/_src/named_sharding.py | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py
index faf0b2a9f2b2..0b4efdb41d25 100644
--- a/jax/_src/named_sharding.py
+++ b/jax/_src/named_sharding.py
@@ -288,13 +288,6 @@ def _custom_repr(self):
     priority_repr = '' if self.priority is None else f'p{self.priority}'
     return f'{{{axes_repr}{open_repr}}}{priority_repr}'

-def _get_axes(axes, mesh_shape):
-  if not axes:
-    return ()
-  assert mesh_shape is not None
-  # Sort wrt mesh axis names so order is deterministic and doesn't hang in
-  # McJAX.
-  return tuple(n for n, _ in mesh_shape if n in axes)

 @dataclasses.dataclass(kw_only=True)
 class SdyArray:
@@ -314,13 +307,11 @@ def build(self) -> sdy.TensorShardingAttr:
         [sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape],
         ldi)

-    replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape)
-    unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape)
     return sdy.TensorShardingAttr.get(
         mesh_attr,
         [dim_sharding.build() for dim_sharding in self.dim_shardings],
-        replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes],
-        unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes])
+        replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes],
+        unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in self.unreduced_axes])

   def __repr__(self):
     dim_sharding_repr = ', '.join(
@@ -342,7 +333,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh):
     dim_shardings.append(SdyDim(axes=[], is_open=True)
                          if not d.axes and not d.is_open else d)
     used_axes.extend(d.axes)
-  remaining_axes = set(mesh.axis_names) - set(used_axes)
+  remaining_axes = tuple(n for n in mesh.axis_names if n not in used_axes)
   replicated_axes = tuple(r for r in remaining_axes
                           if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit)
   return SdyArray(mesh_shape=sdy_sharding.mesh_shape,

From 213985aa8d20d0b01113e1f5a337a3649ece0a7c Mon Sep 17 00:00:00 2001
From: Roy Frostig
Date: Fri, 30 May 2025 08:41:38 -0700
Subject: [PATCH 1447/1769] replace mentions of `Compiled.input_layouts` with
 `Compiled.input_formats`

This is part of a broader renaming of "layout" to "format".
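
Concretely, call sites move from the old spellings to the new ones, e.g. (a
sketch mirroring the test updates below):

    compiled = jax.jit(f, out_shardings=Format(DLL.AUTO)).lower(x).compile()
    compiled.input_formats    # was: compiled.input_layouts
    compiled.output_formats   # was: compiled.output_layouts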
PiperOrigin-RevId: 765205967 --- jax/_src/interpreters/pxla.py | 2 +- .../array_serialization/serialization_test.py | 10 +- tests/layout_test.py | 116 +++++++++--------- 3 files changed, 64 insertions(+), 64 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 2072aaf44b5a..74f2e028c555 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2976,7 +2976,7 @@ def from_hlo(name: str, xla_executable.local_devices(), len(in_shardings), len(out_shardings)) # xla_in_layouts are all either None or DeviceLocalLayout. Even default - # layout are concrete layouts and they are used in `compiled.input_layouts` + # layout are concrete layouts and they are used in `compiled.input_formats` # to return concrete layouts to users. # `dispatch_in_layouts` replaces default layouts with `None` to simplify # dispatch logic downstream. diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 0611388a1d80..0d7e0a48b6c1 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -593,10 +593,10 @@ def test_load_with_layout(self): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(lambda x: x.T, out_shardings=Format(DLL.AUTO)).lower( - arr).compile().output_layouts + out_format = jax.jit(lambda x: x.T, out_shardings=Format(DLL.AUTO)).lower( + arr).compile().output_formats self.assertEqual(arr.format.device_local_layout.major_to_minor, - out_layout.device_local_layout.major_to_minor[::-1]) + out_format.device_local_layout.major_to_minor[::-1]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) @@ -609,9 +609,9 @@ def test_load_with_layout(self): self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - out, = serialization.run_deserialization([out_layout], tspecs) + out, = serialization.run_deserialization([out_format], tspecs) - self.assertEqual(out.format, out_layout) + self.assertEqual(out.format, out_format) self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, np_inp) for s in out.addressable_shards: diff --git a/tests/layout_test.py b/tests/layout_test.py index d7b23c75b313..ce0ca17b05de 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -55,18 +55,18 @@ def init(x, y): out_shardings=Format(DLL.AUTO)).lower(sds1, sds2) compiled_apply = lowered_apply.compile() - arg_layouts, kw_layouts = compiled_apply.input_layouts + arg_formats, kw_layouts = compiled_apply.input_formats self.assertEmpty(kw_layouts) - for i, o in zip(arg_layouts, compiled_apply.output_layouts): + for i, o in zip(arg_formats, compiled_apply.output_formats): self.assertEqual(i.device_local_layout.major_to_minor, o.device_local_layout.major_to_minor[::-1]) init_compiled = jax.jit( - init, out_shardings=arg_layouts).lower(sds1, sds2).compile() + init, out_shardings=arg_formats).lower(sds1, sds2).compile() - for i, o in zip(init_compiled.input_layouts[0], - init_compiled.output_layouts): + for i, o in zip(init_compiled.input_formats[0], + init_compiled.output_formats): self.assertEqual(i, o) arr1 = jax.device_put(np_inp1, s1) @@ -77,16 +77,16 @@ def init(x, y): init_compiled(arr1, arr2) self.assertEqual(init_count(), 1) - self.assertEqual(init_out[0].format, init_compiled.output_layouts[0]) - self.assertEqual(init_out[1].format, 
init_compiled.output_layouts[1]) + self.assertEqual(init_out[0].format, init_compiled.output_formats[0]) + self.assertEqual(init_out[1].format, init_compiled.output_formats[1]) with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count(), 1) - self.assertEqual(apply_out[0].format, compiled_apply.output_layouts[0]) - self.assertEqual(apply_out[1].format, compiled_apply.output_layouts[1]) + self.assertEqual(apply_out[0].format, compiled_apply.output_formats[0]) + self.assertEqual(apply_out[1].format, compiled_apply.output_formats[1]) self.assertTupleEqual(apply_out[0].format.device_local_layout.major_to_minor, init_out[0].format.device_local_layout.major_to_minor[::-1]) @@ -114,10 +114,10 @@ def f(x): out = compiled(arr) self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) @@ -125,10 +125,10 @@ def f(x): compiled_auto = jax.jit(f, in_shardings=Format(DLL.AUTO), out_shardings=Format(DLL.AUTO)).lower(sds).compile() self.assertTupleEqual( - compiled_auto.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled_auto.input_formats[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled_auto.output_layouts.device_local_layout.major_to_minor[::-1], + compiled_auto.output_formats.device_local_layout.major_to_minor[::-1], (0, 1, 2)) with self.assertRaisesRegex( @@ -149,15 +149,15 @@ def f(x): compiled = jax.jit(f, in_shardings=Format(), out_shardings=Format(DLL.AUTO)).lower(arr).compile() self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.device_local_layout.major_to_minor[::-1], (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.format, compiled.output_layouts) + self.assertEqual(out.format, compiled.output_formats) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): @@ -170,11 +170,11 @@ def test_sharding_and_layouts(self): out_shardings=Format(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) if not jtu.test_device_matches(['cpu']): self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.device_local_layout.major_to_minor[::-1], (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -187,19 +187,19 @@ def f(x, y, z, a, b, c): inps = [np.arange(math.prod(shape)).reshape(shape)] * 6 compiled = jax.jit(f, in_shardings=Format(DLL.AUTO), out_shardings=Format(DLL.AUTO)).lower(*inps).compile() - arg_layouts, _ = compiled.input_layouts + arg_formats, _ = compiled.input_formats out1, out2 = compiled(*inps) - compiled2 = jax.jit(f, 
in_shardings=arg_layouts).lower(*inps).compile() + compiled2 = jax.jit(f, in_shardings=arg_formats).lower(*inps).compile() out3, out4 = compiled2(*inps) - for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts[0]): + for l1, l2 in safe_zip(arg_formats, compiled2.input_formats[0]): self.assertEqual(l1, l2) self.assertArraysEqual(out1, out3) self.assertArraysEqual(out2, out4) - arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)] + arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_formats)] out5, out6 = jax.jit(f)(*arrs) self.assertArraysEqual(out1, out5) self.assertArraysEqual(out2, out6) @@ -219,8 +219,8 @@ def f(x, y): jf = jax.jit(f, in_shardings=Format(DLL.AUTO, s), out_shardings=Format(DLL.AUTO, s)) compiled = jf.lower(np_inp, np_inp).compile() - arg_layouts, _ = compiled.input_layouts - arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)] + arg_formats, _ = compiled.input_formats + arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_formats)] compiled(*arrs) def test_aot_layout_mismatch(self): @@ -274,7 +274,7 @@ def test_device_put_concrete_layout(self): compiled = jax.jit( lambda x: x * 2, out_shardings=Format(DLL.AUTO)).lower(arr).compile() - col = compiled.output_layouts + col = compiled.output_formats out = jax.device_put(np_inp, col) self.assertEqual(out.format, col) @@ -306,7 +306,7 @@ def invalid_layout_spec(self): compiled = jax.jit(lambda x: x).lower(x).compile() with self.assertRaisesRegex( ValueError, 'Sharding has to be concrete when layout.*'): - Format(compiled.output_layouts[0], None) + Format(compiled.output_formats[0], None) def test_layout_on_sds(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) @@ -314,12 +314,12 @@ def test_layout_on_sds(self): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(jnp.sin, out_shardings=Format(DLL.AUTO)).lower( - arr).compile().output_layouts + out_format = jax.jit(jnp.sin, out_shardings=Format(DLL.AUTO)).lower( + arr).compile().output_formats - sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout) - arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts - self.assertEqual(arg_layout[0], out_layout) + sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_format) + arg_format, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_formats + self.assertEqual(arg_format[0], out_format) with self.assertRaisesRegex( TypeError, @@ -333,12 +333,12 @@ def test_make_array_from_callback(self): np_inp = np.arange(16).reshape(8, 2) sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) - layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts + format = jax.jit(lambda x: x * 2).lower(sds).compile().output_formats - out = jax.make_array_from_callback(np_inp.shape, layout, + out = jax.make_array_from_callback(np_inp.shape, format, lambda idx: np_inp[idx]) self.assertArraysEqual(out, np_inp) - self.assertEqual(out.format, layout) + self.assertEqual(out.format, format) with self.assertRaisesRegex( TypeError, @@ -417,18 +417,18 @@ def test_device_put_user_concrete_layout_multi_device(self): jnp_inp = jnp.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - custom_layout = Format(DLL(major_to_minor=(0, 1)), s) - out1 = jax.device_put(arr, custom_layout) + custom_format = Format(DLL(major_to_minor=(0, 1)), s) + out1 = jax.device_put(arr, custom_format) with jax.sharding.use_mesh(mesh): - out2 = jax.device_put(arr, custom_layout) - out3 = jax.device_put(jnp_inp, custom_layout) - out4 = 
jax.device_put(np_inp, custom_layout) + out2 = jax.device_put(arr, custom_format) + out3 = jax.device_put(jnp_inp, custom_format) + out4 = jax.device_put(np_inp, custom_format) for o in [out1, out2, out3, out4]: self.assertArraysEqual(o, np_inp) self.assertEqual(o.format.device_local_layout.major_to_minor, - custom_layout.device_local_layout.major_to_minor) + custom_format.device_local_layout.major_to_minor) def test_concrete_layout_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -619,16 +619,16 @@ def test_sparsecore_compute(self): dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Format(dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) - dense_layout = Format(DLL(major_to_minor=(0, 1)), s) + sparse_format = Format(dll, s) + sparecore_arr = jax.device_put(inp, sparse_format) + dense_format = Format(DLL(major_to_minor=(0, 1)), s) @compute_on('tpu_sparsecore') @jax.jit def sparsecore_compute(x): return x * x - @partial(jax.jit, out_shardings=(dense_layout, sparse_layout)) + @partial(jax.jit, out_shardings=(dense_format, sparse_format)) def f(x, y): return x * 2, sparsecore_compute(y) @@ -645,8 +645,8 @@ def test_sparsecore_compute_twice(self): dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Format(dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) + sparse_format = Format(dll, s) + sparecore_arr = jax.device_put(inp, sparse_format) @compute_on('tpu_sparsecore') @jax.jit @@ -658,7 +658,7 @@ def sparsecore_multiply(x, y): def sparsecore_add(x, y): return x + y - @partial(jax.jit, donate_argnums=0, out_shardings=sparse_layout) + @partial(jax.jit, donate_argnums=0, out_shardings=sparse_format) def f(x): return sparsecore_multiply(sparsecore_add(x, x) + 1, x) @@ -675,12 +675,12 @@ def test_sparsecore_and_host_compute(self): s = SingleDeviceSharding(jax.devices()[0]) sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) - sparse_layout = Format(sparse_dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) + sparse_format = Format(sparse_dll, s) + sparecore_arr = jax.device_put(inp, sparse_format) host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - host_layout = Format(host_dll, s) - host_arr = jax.device_put(inp, host_layout) + host_format = Format(host_dll, s) + host_arr = jax.device_put(inp, host_format) @compute_on('tpu_sparsecore') @jax.jit @@ -694,8 +694,8 @@ def host_compute(x): @partial( jax.jit, - in_shardings=(sparse_layout, host_layout), - out_shardings=(sparse_layout, host_layout), + in_shardings=(sparse_format, host_format), + out_shardings=(sparse_format, host_format), ) def f(x, y): return sparsecore_compute(x), host_compute(y) @@ -710,8 +710,8 @@ def test_cpp_layout_cache_miss(self): arr = jax.device_put(np_inp, s) arr_m2m = arr.format.device_local_layout.major_to_minor - custom_layout = Format(DLL(major_to_minor=arr_m2m[::-1]), s) - arr2 = jax.device_put(np_inp, custom_layout) + custom_format = Format(DLL(major_to_minor=arr_m2m[::-1]), s) + arr2 = jax.device_put(np_inp, custom_format) @jax.jit def f(x): @@ -731,9 +731,9 @@ def test_layout_donation_with_default_layout(self): shape = (16, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - out_layout = Format(arr.format.device_local_layout, s) + out_format = Format(arr.format.device_local_layout, s) - @partial(jax.jit, out_shardings=out_layout, donate_argnums=0) + @partial(jax.jit, out_shardings=out_format, donate_argnums=0) def 
f(x): return x * 2 @@ -743,7 +743,7 @@ def f(x): out = f(arr) self.assertArraysEqual(out, np_inp * 2) - self.assertEqual(out.format, out_layout) + self.assertEqual(out.format, out_format) def test_with_layout_constraint(self): if not jtu.test_device_matches(['tpu']): From d15253e7f5e71b18dad93c2a0e3c10234be37550 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 30 May 2025 10:47:26 -0700 Subject: [PATCH 1448/1769] [Mosaic] Support interleaved packing on TPUv4-. This enables row broadcast for int8 and int4 on TPUv4. PiperOrigin-RevId: 765252479 --- .../tpu/transforms/apply_vector_layout.cc | 3 +-- jaxlib/mosaic/dialect/tpu/vreg_util.cc | 17 ++++------------- jaxlib/mosaic/dialect/tpu/vreg_util.h | 3 +-- tests/pallas/tpu_ops_test.py | 14 ++++++++++++-- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 53d8712d5274..1669d1bf1586 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3655,8 +3655,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, if (packing != 1) { if (auto new_dst_vreg = broadcastSubelements( builder, cast>(dst_vreg), - subelement_offset, ctx.target_shape, - ctx.hardware_generation); + subelement_offset, ctx.target_shape); succeeded(new_dst_vreg)) { dst_vreg = *new_dst_vreg; } else { diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc index 237bbe5cc722..90efacf0c676 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc @@ -224,8 +224,7 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, FailureOr> broadcastSubelements( ImplicitLocOpBuilder &builder, TypedValue vec, - int subelement_idx, std::array target_shape, - int hardware_generation) { + int subelement_idx, std::array target_shape) { int bitwidth = vec.getType().getElementTypeBitWidth(); int packing = 32 / bitwidth; if (subelement_idx < 0 || subelement_idx >= packing) { @@ -247,17 +246,9 @@ FailureOr> broadcastSubelements( src_vreg_int, getFullVector(builder, vreg_native_int_ty, builder.getI32IntegerAttr(subelement_idx * bitwidth))); - Value vreg_result_int; - if (hardware_generation >= 5) { - SmallVector packed_vregs(packing, vreg_subelement_low); - vreg_result_int = builder.create( - vreg_packed_int_ty, packed_vregs, tpu::PackFormat::kInterleaved); - } else { - // This can be virtualized as a tree of shifts and ORs. - return builder.emitError() - << "broadcastSubelements not implemented for hardware generation " - << hardware_generation; - } + SmallVector packed_vregs(packing, vreg_subelement_low); + Value vreg_result_int = builder.create( + vreg_packed_int_ty, packed_vregs, tpu::PackFormat::kInterleaved); return cast>( builder.create(vec.getType(), vreg_result_int) .getResult()); diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.h b/jaxlib/mosaic/dialect/tpu/vreg_util.h index 8833390ef87b..90e802fcb8fc 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.h +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.h @@ -90,8 +90,7 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, // subelement_idx must be between 0 and packing. 
FailureOr> broadcastSubelements( ImplicitLocOpBuilder &builder, TypedValue vec, - int subelement_idx, std::array target_shape, - int hardware_generation); + int subelement_idx, std::array target_shape); } // namespace mlir::tpu diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index de87126ebd3f..3f6dc593e333 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -197,8 +197,18 @@ def kernel(x_ref, y_ref, out_ref): def test_row_broadcast(self, dtype): if not jtu.if_cloud_tpu_at_least(2025, 1, 10): self.skipTest("Requires libtpu built after 2025-01-10") - if not self.INTERPRET and jtu.get_tpu_version() < 5: - self.skipTest("Requires TPUv5+") + bitwidth = pallas_utils.dtype_bitwidth(dtype) + if not self.INTERPRET and jtu.get_tpu_version() < 4 and bitwidth < 8: + self.skipTest("Requires TPUv4+ for sub-byte types") + if ( + not self.INTERPRET + and jtu.get_tpu_version() == 4 + and bitwidth < 16 + and not jtu.if_cloud_tpu_at_least(2025, 6, 2) + ): + self.skipTest( + "Requires libtpu built after 2025-06-02 for bitwidth < 16 on TPUv4" + ) def kernel(x_ref, y_ref): y_ref[...] = jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape).astype(y_ref.dtype) m, n = 4, 1152 From c2e7d61323b17481d213190bb779a4b74e7d5356 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Fri, 30 May 2025 11:23:05 -0700 Subject: [PATCH 1449/1769] [pallas] Expose TPUInterpretParams in jax.experimental.pallas.tpu PiperOrigin-RevId: 765266754 --- jax/experimental/pallas/tpu.py | 1 + .../tpu_pallas_interpret_distributed_test.py | 12 +++---- tests/pallas/tpu_pallas_interpret_test.py | 32 +++++++++---------- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index e27fdaaadd8f..c4d21023a6e6 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -28,6 +28,7 @@ from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core +from jax._src.pallas.mosaic.interpret import TPUInterpretParams as TPUInterpretParams from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef from jax._src.pallas.mosaic.pipeline import BufferedRefBase as BufferedRefBase diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 4e4776736cf1..c5f1b29fd6bc 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -107,7 +107,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape=out_shape, grid_spec=grid_spec, compiler_params=pltpu.TPUCompilerParams(collective_id=13), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), ) # Wrap the kernel within a shard_map to call. 
@@ -228,7 +228,7 @@ def _(): all_gather_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) @@ -388,7 +388,7 @@ def _(): all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) @@ -672,7 +672,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=True), compiler_params=pltpu.TPUCompilerParams(collective_id=7), )(input_arr)[0] @@ -976,7 +976,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=19), )(input_arr)[0] @@ -1064,7 +1064,7 @@ def run(src_dst_ids): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( dma_execution_mode='eager', detect_races=True, ), diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 47d4ba3e1acf..871f66d71c53 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -124,7 +124,7 @@ def matmul(x: jax.Array, y: jax.Array): (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), ), - interpret=mosaic_interpret.TPUInterpretParams(), + interpret=pltpu.TPUInterpretParams(), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) @@ -155,7 +155,7 @@ def block_dynamic_slice(x, starts, sizes): dynamic_slice_kernel, grid_spec=grid_spec, out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), - interpret=mosaic_interpret.TPUInterpretParams(), + interpret=pltpu.TPUInterpretParams(), ) block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) return kernel(block_idx, x) @@ -189,7 +189,7 @@ def f(s, x): ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), input_output_aliases={1: 0}, - interpret=mosaic_interpret.TPUInterpretParams(), + interpret=pltpu.TPUInterpretParams(), )(s, x) s = jnp.array([1], dtype=jnp.int32) @@ -224,7 +224,7 @@ def _(): ), scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), input_output_aliases={0: 0}, - interpret=mosaic_interpret.TPUInterpretParams(), + interpret=pltpu.TPUInterpretParams(), )(x) expected = np.zeros((4, 4)) @@ -264,7 +264,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( detect_races=True, dma_execution_mode=dma_execution_mode ), )(x).block_until_ready() @@ -279,7 +279,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( detect_races=True, dma_execution_mode=dma_execution_mode ), )(x).block_until_ready() @@ -293,7 +293,7 
@@ def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( matmul_kernel, out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( skip_floating_point_ops=True ), )(x, y) @@ -325,7 +325,7 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): pltpu.VMEM((8, 128), jnp.bfloat16), pltpu.VMEM((8, 128), jnp.int16), ], - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( uninitialized_memory=uninitialized_memory ), )() @@ -355,7 +355,7 @@ def kernel_call(x, s): pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec((8, 256), lambda i, j: (i, 0)), - interpret=mosaic_interpret.TPUInterpretParams(), + interpret=pltpu.TPUInterpretParams(), )(x, s) with CountStoreCallbacksContext() as store_callbacks_counter: @@ -378,7 +378,7 @@ def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): grid=(4, 4), in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( random_seed=12345, grid_point_recorder=grid_point_recorder ), compiler_params=pltpu.TPUCompilerParams( @@ -436,7 +436,7 @@ def kernel(s_ref, o_ref): grid=(4, 4), in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), - interpret=mosaic_interpret.TPUInterpretParams(random_seed=12345), + interpret=pltpu.TPUInterpretParams(random_seed=12345), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=('arbitrary', 'parallel') ), @@ -462,7 +462,7 @@ def kernel_call_dynamic_parallel_dimension(): grid=(dim_size,), in_specs=[], out_specs=pl.BlockSpec((1,), lambda _: (0,)), - interpret=mosaic_interpret.TPUInterpretParams(), + interpret=pltpu.TPUInterpretParams(), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=('parallel',) ), @@ -479,7 +479,7 @@ def f(x): y = jnp.zeros_like(x) def inner(refs): x_ref, y_ref = refs - @pl.core_map(mesh, interpret=mosaic_interpret.TPUInterpretParams()) + @pl.core_map(mesh, interpret=pltpu.TPUInterpretParams()) def _(): num_cores = jax.lax.psum(1, "x") slc_size = 16 // num_cores @@ -520,7 +520,7 @@ def kernel(x_ref, o_ref, vmem_ref): scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), ], - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( num_cores_per_device=2, detect_races=True, ), @@ -554,7 +554,7 @@ def kernel(x_ref, o_ref, vmem_ref): scratch_shapes=[ pltpu.VMEM((8, 128), x.dtype), ], - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( num_cores_per_device=2, detect_races=True, ), @@ -578,7 +578,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): grid=(4, 4), in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.TPUInterpretParams( random_seed=12345, num_cores_per_device=num_cores_per_device, grid_point_recorder=grid_point_recorder, From 6ba11c181afdab2ef4c642aa5d141a2a6de4c0b7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 30 May 2025 11:50:36 -0700 Subject: [PATCH 1450/1769] Pass list rather than generator to donate_argnums --- jax/_src/shard_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index c38312868163..69c1b7d264f6 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ 
-1788,7 +1788,7 @@ def infer_params(*args, **kwargs):
       p.flat_args, mesh, list(in_specs))
   jitted_f = jax.jit(
       _pmapped,
-      donate_argnums=(i for i, val in enumerate(p.donated_invars) if val))
+      donate_argnums=[i for i, val in enumerate(p.donated_invars) if val])
   return jitted_f, flat_global_args, p.out_tree, mesh, out_specs

 def wrapped(*args, **kwargs):

From f13a560925f6a1672d7e4336a8f32078202af5f1 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 30 May 2025 12:20:03 -0700
Subject: [PATCH 1451/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/1c4bad881e922d783e382982f0e3d175d1c5e707.

PiperOrigin-RevId: 765289156
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index ac418308de83..adeb0cc1bfc5 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "adf78529a5536db17f147c6abbce0e4ece83ba42"
-XLA_SHA256 = "559406848b34c82856f45c8371ea8c2d7a92e2d13db768fc401fda8f4c22c70a"
+XLA_COMMIT = "1c4bad881e922d783e382982f0e3d175d1c5e707"
+XLA_SHA256 = "46af35dac4c699badd7a230a1cbbaabfea1cee81394bbe1b9c6cbc33046651c4"

 def repo():
     tf_http_archive(

From 69bcb0d88e24cb736c4fc1224b57075a824447aa Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 30 May 2025 18:30:39 +0000
Subject: [PATCH 1452/1769] [mutable-arrays] don't let scan AD hoist mutable
 operations

We do that by marking all mutable consts as unknown when we do the
hoisting-via-partial-eval. That is a bit conservative, in that pure math on
reads would be safe to hoist; what we really don't want to hoist is writing.

**The remat path `_scan_partial_eval_custom` is not tested** because we
currently error out when there are mutable arrays under remat. I did one
simple manual test, but there's no checked-in test for it.
---
 jax/_src/lax/control_flow/loops.py | 20 +++++++++++++++-----
 tests/mutable_array_test.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 4df4c517090b..47808ee3c423 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -854,6 +854,8 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
   # want to broadcast the matrix!). So, outside the loop we perform a partial
   # evaluation with known 'const' inputs (but all other inputs unknown).
   const_pvals = [pe.PartialVal.known(t.pval.get_known())
+                 if not isinstance(t.aval, state.AbstractRef)
+                 else pe.PartialVal.unknown(t.aval)
                  for t in tracers[:num_consts] if t.pval.is_known()]
   other_pvals = [pe.PartialVal.unknown(aval)
                  for aval in jaxpr_known.in_avals[len(const_pvals):]]
@@ -898,7 +900,9 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
   # We use `fwds_known` below when forming the output of scanning jaxpr_known.

   # Run the known part of the scan (if it has any outputs or effects).
- known_inputs = (list(jaxpr_known_consts) + + known_mutable_consts = [t.pval.get_known() for t in tracers[:num_consts] + if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)] + known_inputs = (list(jaxpr_known_consts) + known_mutable_consts + [t.pval.get_known() for t in tracers[num_consts:] if t.pval.is_known()]) if not jaxpr_known.out_avals and not jaxpr_known.effects: @@ -907,7 +911,8 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, linear_known = [False] * len(known_inputs) # conservative! out_known = scan_p.bind( *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, - num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), + num_consts=len(jaxpr_known_consts) + len(known_mutable_consts), + num_carry=num_carry - sum(carry_uk), linear=tuple(linear_known), unroll=unroll, _split_transpose=_split_transpose) del linear_known @@ -1292,10 +1297,12 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): num_const_known = len(const_uk) - sum(const_uk) num_carry_known = len(carry_uk) - sum(carry_uk) num_xs_known = len( xs_uk) - sum( xs_uk) + const_donthoist = [isinstance(a, state.AbstractRef) + for a in jaxpr_known.in_avals[:num_const_known]] jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \ pe.partial_eval_jaxpr_nounits( jaxpr_known, - [False] * num_const_known + [True] * (num_carry_known + num_xs_known), + const_donthoist + [True] * (num_carry_known + num_xs_known), [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res) # jaxpr_known_hoist produces intensive residuals followed by the constants for # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts. @@ -1328,10 +1335,13 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): linear=tuple(linear_known)) def known(*ins_known): - consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known]) + consts_known_maybehoist, ins_known_lp = split_list(ins_known, [num_const_known]) + consts_known_hoist, consts_known_donthoist = \ + partition_list(const_donthoist, consts_known_maybehoist) out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist) intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res]) - out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known) + out_loop = scan_p.bind(*consts_known_lp, *consts_known_donthoist, + *ins_known_lp, **params_known) return [*intensive_res, *out_loop] call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(known, debug_info=jaxpr_known_hoist.jaxpr.debug_info), diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 0da335e2fac5..8d80499c0e26 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -268,6 +268,34 @@ def test_rng_key(self): # test read/write key[...] = jax.random.fold_in(key[...], 1) # don't crash + def test_scan_grad_doesnt_hoist_mutable_stuff(self): + x_ref = core.mutable_array(0) + + def f(x): + def body(c, _): + x_ref[...] += 1 + return c, () + x, () = jax.lax.scan(body, x, (), length=3) + return x + + jax.grad(f)(1.0) + self.assertAllClose(x_ref[...], 3, check_dtypes=False) + + def test_scan_grad_doesnt_hoist_mutable_stuff2(self): + x_ref = core.mutable_array(0) + const = jnp.arange(3) + const2 = jnp.zeros(()) + + def f(x): + def body(c, _): + x_ref[...] 
+= const.sum() + return c + const2, () + x, () = jax.lax.scan(body, x, (), length=4) + return x + + jax.grad(f)(1.0) + self.assertAllClose(x_ref[...], 12, check_dtypes=False) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From 581cb628a5f6516cfb63df201e18361dd0af6e96 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 30 May 2025 20:18:09 +0000 Subject: [PATCH 1453/1769] [mutable-arrays] add basic tests for vmap + mutable array --- jax/_src/state/primitives.py | 7 +++++++ tests/mutable_array_test.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index dbcc67df18cb..3a54644dbd37 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -618,6 +618,13 @@ def _swap_vmap(batched_args, batched_dims, *, tree): val_is_batched = val_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) + + if not ref_is_batched: + raise Exception("performing a set/swap operation with vmapped value on " + "an unbatched mutable array reference " + f"of type {core.typeof(ref)}. Move the mutable array to be " + "an argument to the vmapped function?") + if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 0da335e2fac5..95ebdd818fa6 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -253,6 +253,27 @@ def f(x): ys = f(xs) self.assertAllClose(ys, xs ** 2, check_dtypes=False) + def test_vmap_extensive_inputs(self): + def f(x_ref, val): + x_ref[...] += val + x_ref[...] += val + + xs_ref = core.mutable_array(jnp.array([0, 0, 0])) + vals = jnp.arange(3) + jax.vmap(f)(xs_ref, vals) + self.assertAllClose(xs_ref[...], 2 * vals, check_dtypes=False) + + def test_vmap_closed_over_read_only(self): + y_ref = core.mutable_array(1) + + def f(x_ref): + x_ref[...] += y_ref[...] + x_ref[...] += y_ref[...] + + xs_ref = core.mutable_array(jnp.array([0, 0, 0])) + jax.vmap(f)(xs_ref) + self.assertAllClose(xs_ref[...], jnp.array([2, 2, 2]), check_dtypes=False) + def test_implicit_bitcast_regression(self): # https://github.com/jax-ml/jax/issues/27683 v = core.mutable_array(jnp.array([0, 0, 0])) @@ -417,6 +438,16 @@ def false_fun(): out_false = f(False) self.assertAllClose(x_ref[...], 2.) + def test_vmap_closed_over_ref_write(self): + x_ref = core.mutable_array(jnp.zeros((), 'int32')) + + def f(val): + x_ref[...] += val + + vals = jnp.arange(3, dtype='int32') + with self.assertRaisesRegex(Exception, "unbatched mutable array"): + jax.vmap(f)(vals) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 753ae5707b3570af4af540c6b23bb96b492e9dbd Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 30 May 2025 16:34:18 -0400 Subject: [PATCH 1454/1769] Fix a rare numerical flake in svd_test seen on TPU v6e. Relax the test tolerance a bit. --- tests/svd_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/svd_test.py b/tests/svd_test.py index 4225db038d72..dfc3de7a764f 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -166,7 +166,7 @@ def testSvdWithOnRankDeficientInputZeroColumns(self, m, r): np.testing.assert_almost_equal(diff, 1e-4, decimal=2) # Check that u and v are orthogonal. 
self.assertAllClose(u.T.conj() @ u, np.eye(m), atol=10 * _SVD_TEST_EPS) - self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=11 * _SVD_TEST_EPS) + self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=30 * _SVD_TEST_EPS) @jtu.sample_product( [dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])], From 22f04d92fc1c6c9fa85dd486862f88fce36964f9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 30 May 2025 14:36:56 -0700 Subject: [PATCH 1455/1769] Refactor jax/_src/api.py and associated files in preparation for moving them to their own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This change stops short of actually making the main jax package build rule depend on the new api build rule, because some downstream targets need to be migrated and pytype errors need to be fixed before we can land the final change. PiperOrigin-RevId: 765341918 --- jax/BUILD | 58 +++++++++++++++++++++++++++++--- jax/_src/api.py | 8 +++-- jax/_src/array.py | 8 ++--- jax/_src/dispatch.py | 13 +++---- jax/_src/interpreters/pxla.py | 52 ++++++++++++++-------------- jax/_src/pallas/fuser/BUILD | 1 + jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/triton/BUILD | 1 + jax/_src/pjit.py | 36 ++------------------ jax/_src/state/discharge.py | 34 +++++++++++++++++++ jax/extend/BUILD | 2 ++ 11 files changed, 135 insertions(+), 79 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 91a6e5926e42..67fc208f7841 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -295,8 +295,8 @@ py_library_providing_imports_info( srcs = [ "_src/__init__.py", "_src/ad_checkpoint.py", - "_src/api.py", - "_src/array.py", + "_src/api.py", # TODO(vanderplas): remove this and depend on :api instead + "_src/array.py", # TODO(vanderplas): remove this and depend on :api instead "_src/blocked_sampler.py", "_src/buffer_callback.py", "_src/callback.py", @@ -305,14 +305,14 @@ py_library_providing_imports_info( "_src/custom_partitioning.py", "_src/custom_partitioning_sharding_rule.py", "_src/debugging.py", - "_src/dispatch.py", + "_src/dispatch.py", # TODO(vanderplas): remove this and depend on :api instead "_src/dlpack.py", "_src/earray.py", "_src/error_check.py", "_src/ffi.py", "_src/flatten_util.py", - "_src/interpreters/__init__.py", - "_src/interpreters/pxla.py", + "_src/interpreters/__init__.py", # TODO(vanderplas): remove this and depend on :api instead + "_src/interpreters/pxla.py", # TODO(vanderplas): remove this and depend on :api instead "_src/pjit.py", "_src/prng.py", "_src/public_test_util.py", @@ -375,6 +375,7 @@ py_library_providing_imports_info( ":abstract_arrays", ":ad", ":ad_util", + # ":api", # TODO(vanderplas): add this dependency once downstream targets are fixed ":api_util", ":attrs", ":basearray", @@ -450,6 +451,53 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "api", + srcs = [ + "_src/api.py", + "_src/array.py", + "_src/dispatch.py", + "_src/interpreters/pxla.py", + "_src/pjit.py", + ], + visibility = [":internal"] + jax_visibility("api"), + deps = [ + ":abstract_arrays", + ":ad", + ":api_util", + ":attrs", + ":basearray", + ":batching", + ":compiler", + ":config", + ":core", + ":deprecations", + ":dtypes", + ":effects", + ":layout", + ":mesh", + ":mlir", + ":monitoring", + ":op_shardings", + ":partial_eval", + ":partition_spec", + ":profiler", + ":sharding", + ":sharding_impls", + ":sharding_specs", + ":source_info_util", + ":stages", + ":state_types", + 
":traceback_util", + ":tree_util", + ":typing", + ":util", + ":xla", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "api_util", srcs = ["_src/api_util.py"], diff --git a/jax/_src/api.py b/jax/_src/api.py index c21e8248d52d..229dee979d06 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -66,7 +66,6 @@ rebase_donate_argnums, _ensure_index, _ensure_index_tuple, apply_flat_fun_nokwargs, check_callable, debug_info, flat_out_axes) -from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib @@ -477,6 +476,8 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, shapes and types as the corresponding arguments. If ``has_aux`` is True then a tuple of ((value, auxiliary_data), gradient) is returned. """ + from jax._src.lax import lax as lax_internal # pytype: disable=import-error + if reduce_axes: raise NotImplementedError("reduce_axes argument to grad is deprecated") del reduce_axes @@ -889,7 +890,7 @@ def hessian(fun: Callable, argnums: int | Sequence[int] = 0, argnums, has_aux=has_aux, holomorphic=holomorphic) def _std_basis(pytree): - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error leaves, _ = tree_flatten(pytree) ndim = sum(map(np.size, leaves)) dtype = dtypes.result_type(*leaves) @@ -905,6 +906,7 @@ def _jacrev_unravel(output_pytree, input_pytree_leaf, arr): output_pytree, 0, input_pytree_leaf, arr) def _possible_downcast(x, example): + from jax._src.lax import lax as lax_internal # pytype: disable=import-error if (dtypes.issubdtype(x.dtype, np.complexfloating) and not dtypes.issubdtype(_dtype(example), np.complexfloating)): x = x.real @@ -1483,7 +1485,7 @@ def pmap( " from pmap.") if config.pmap_shmap_merge.value: - from jax._src.shard_map import pmap + from jax._src.shard_map import pmap # pytype: disable=import-error return pmap(fun, axis_name, in_axes=in_axes, out_axes=out_axes, static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, diff --git a/jax/_src/array.py b/jax/_src/array.py index 9a71b12ed1a8..2514502c27d0 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -343,8 +343,8 @@ def __format__(self, format_spec): return format(self._value, format_spec) def __getitem__(self, idx): - from jax._src.lax import lax - from jax._src.numpy import indexing + from jax._src.lax import lax # pytype: disable=import-error + from jax._src.numpy import indexing # pytype: disable=import-error self._check_if_deleted() if isinstance(self.sharding, PmapSharding): @@ -444,7 +444,7 @@ def __dlpack__(self, *, stream: int | Any | None = None, max_version: tuple[int, int] | None = None, dl_device: tuple[DLDeviceType, int] | None = None, copy: bool | None = None): - from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top + from jax._src.dlpack import to_dlpack # pytype: disable=import-error # pylint: disable=g-import-not-at-top device_set = self.sharding.device_set if len(device_set) > 1: @@ -464,7 +464,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: if len(self._arrays) != 1: raise BufferError("__dlpack__ only supported for unsharded arrays.") - from jax._src.dlpack import DLDeviceType # pylint: disable=g-import-not-at-top + from jax._src.dlpack import DLDeviceType # pytype: disable=import-error # pylint: disable=g-import-not-at-top if self.platform() == "cpu": return DLDeviceType.kDLCPU, 0 diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py 
index 028c2cfa125e..b5e588cbc10e 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -26,7 +26,6 @@ import time from typing import Any -import jax from jax._src import api from jax._src import array from jax._src import basearray @@ -34,8 +33,11 @@ from jax._src import core from jax._src import dtypes from jax._src import lib +from jax._src import pjit from jax._src import traceback_util from jax._src import util + +from jax._src import xla_bridge from jax._src.abstract_arrays import array_types from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -133,7 +135,7 @@ def get_token_input( # TODO(yueshengys): This might still be buggy in a multi-process SPMD # scenario. Revise the logic later. A distributed shutdown barrier inside # the XLA program may be needed. - return jax.device_put( + return api.device_put( tok, NamedSharding(Mesh(devices, 'x'), PartitionSpec('x'))) # We only use replicated sharding for the first time when the token for the @@ -244,8 +246,7 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: @util.weakref_lru_cache def get_intermediate_shardings( jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]: - from jax._src import pjit - from jax._src import shard_map + from jax._src import shard_map # pytype: disable=import-error out = [] for eqn in jaxpr.eqns: @@ -409,7 +410,7 @@ def result_handler(self, shard_arg_result): def _device_put_sharding_impl(x, aval, device, copy): - from jax.experimental import multihost_utils + from jax.experimental import multihost_utils # pytype: disable=import-error if isinstance(device, Sharding): s = device @@ -440,7 +441,7 @@ def _device_put_sharding_impl(x, aval, device, copy): # sharding do not transfer data) or (2) the sharding contains a # different subset of devices on each host. For (1), the input should be # the same on all hosts, but for (2) it need not be. - if jax.process_count() == len(s._internal_device_list.process_indices): # pytype: disable=attribute-error + if xla_bridge.process_count() == len(s._internal_device_list.process_indices): # pytype: disable=attribute-error multihost_utils.assert_equal( x, fail_message=( f"{type(x)} passed to device_put is not the same on each" diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 74f2e028c555..0530e313f310 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -29,9 +29,8 @@ import numpy as np -import jax - from jax._src import api +from jax._src import array from jax._src import compiler from jax._src import config from jax._src import core @@ -41,11 +40,13 @@ from jax._src import linear_util as lu from jax._src import op_shardings from jax._src import sharding_specs +from jax._src import pjit from jax._src import profiler from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src import xla_bridge as xb from jax._src.abstract_arrays import array_types @@ -154,7 +155,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics, # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. 
- results: list[jax.Array | None] = [None] * len(args) + results: list[typing.Array | None] = [None] * len(args) for t, (indices, a, s, l, cs) in batches.items(): outs = shard_arg_handlers[t](a, s, l, cs) for i, out in safe_zip(indices, outs): @@ -230,11 +231,9 @@ def _shard_mutable_array(xs, shardings, layouts, copy_semantics): def batched_device_put(aval: core.ShapedArray, sharding: JSharding, xs: Sequence[Any], - devices: Sequence[jax.Device], committed: bool = True): + devices: Sequence[xc.Device], committed: bool = True): util.test_event("batched_device_put_start") try: - from jax._src import array - bufs = [x for x, d in safe_zip(xs, devices) if (isinstance(x, array.ArrayImpl) and dispatch.is_single_device_sharding(x.sharding) and @@ -385,7 +384,6 @@ def _emap_impl(fun: lu.WrappedFun, *args, donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, ): - from jax._src import array # TODO(sharadmv,mattjj): implement these cases if any(d for d in donated_invars): raise NotImplementedError("Buffer donation not supported in eager pmap.") @@ -410,12 +408,12 @@ def _emap_impl(fun: lu.WrappedFun, *args, donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else () new_outvals = [] for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals): - with jax.disable_jit(False): + with api.disable_jit(False): donate_argnums_ = donate_argnums if isinstance(outval, array.ArrayImpl): # We don't want to donate if it's already sharded. donate_argnums_ = () - out = jax.pmap( + out = api.pmap( lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)), out_axes=out_axis, @@ -448,7 +446,7 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName], for i, name in reversed(list(enumerate(names))): in_axes = tuple(arg_axis[i] for arg_axis in all_axes) if any(in_axis is not None for in_axis in in_axes): - f = jax.pmap( + f = api.pmap( f, in_axes=in_axes, axis_name=name, @@ -476,11 +474,12 @@ def to_map_tracer(self, val): return MapTracer(self, val, {}) def process_primitive(self, primitive, tracers, params): - if primitive is jax._src.lax.parallel.axis_index_p: - return self.process_axis_index(**params) - if primitive is jax._src.lax.parallel.psum_p: + from jax._src.lax import parallel # pytype: disable=import-error + if primitive is parallel.axis_index_p: + return self.process_axis_index(**params) # pytype: disable=missing-parameter + if primitive is parallel.psum_p: f = HashableFunction( - lambda *xs: jax._src.lax.parallel.psum( + lambda *xs: parallel.psum( xs, axis_name=params['axes'], axis_index_groups=params['axis_index_groups']), (primitive, tuple(params.items()))) else: @@ -492,7 +491,7 @@ def process_primitive(self, primitive, tracers, params): names = core.get_axis_env().axis_names() all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes) # pytype: disable=wrong-arg-types # always-use-return-annotations f_mapped, out_shard_axes = _multi_pmap(f, self.emap_info, names, all_axes) - with core.eval_context(), jax.disable_jit(False): + with core.eval_context(), api.disable_jit(False): outvals = f_mapped(*vals) if primitive.multiple_results: return [MapTracer(self, val, out_shard_axes) for val in outvals] @@ -546,11 +545,12 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, return fun.call_wrapped(*tracers) def process_axis_index(self, axis_name): + from jax._src.lax import lax, parallel # pytype: disable=import-error bind = HashableFunction( - lambda _: jax.lax.axis_index(axis_name), - (jax.lax.axis_index, axis_name)) + lambda 
_: parallel.axis_index(axis_name), + (parallel.axis_index, axis_name)) fake_primitive = FakePrimitive(multiple_results=False, bind=bind) - range = jax.lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) + range = lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) dummy_tracer = MapTracer(self, range, {axis_name: 0}) return self.process_primitive(fake_primitive, (dummy_tracer,), {}) @@ -695,14 +695,15 @@ def find_replicas( @lu.transformation2 def _change_argument_ranks(f, in_axes, out_axes_thunk, *args): + from jax._src.lax import lax # pytype: disable=import-error args = tuple( - arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,)) + arg if in_axis is None else lax.squeeze(arg, dimensions=(in_axis,)) for in_axis, arg in zip(in_axes, args) ) results = f(*args) out_axes = out_axes_thunk() return tuple( - x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,)) + x if axis is None else lax.expand_dims(x, dimensions=(axis,)) for x, axis in zip(results, out_axes) ) @@ -1276,7 +1277,7 @@ def _handle_token_bufs(self, token_bufs, sharded_token): assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding) token_devices.append(token.sharding._device_assignment[0]) s = NamedSharding(Mesh(token_devices, 'x'), P('x')) - global_token_array = jax.make_array_from_single_device_arrays( + global_token_array = array.make_array_from_single_device_arrays( (0,), s, token_buf ) dispatch.runtime_tokens.set_token_result( @@ -1754,7 +1755,7 @@ class MutationData(NamedTuple): def _discharge_refs( jaxpr: core.ClosedJaxpr ) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]: - from jax._src.state.discharge import discharge_state + from jax._src.state.discharge import discharge_state # pytype: disable=import-error jaxpr, in_mut = _move_mutable_consts(jaxpr) new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts)) count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end @@ -1782,7 +1783,7 @@ def _move_mutable_consts( @weakref_lru_cache def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: - from jax._src.state.discharge import discharge_state + from jax._src.state.discharge import discharge_state # pytype: disable=import-error jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts) jaxpr_._debug_info = jaxpr.jaxpr._debug_info return core.ClosedJaxpr(jaxpr_, consts) @@ -2016,8 +2017,6 @@ def _default_rule(prim, num_outvars, *_, **__): @weakref_lru_cache def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr ) -> tuple[None | DeviceLocalLayout]: - from jax._src import pjit - env = {} # type: ignore jaxpr = closed_jaxpr.jaxpr @@ -3229,7 +3228,6 @@ def check_array_xla_sharding_layout_match( in_xla_layouts: Sequence[DeviceLocalLayout], jaxpr_debug_info: core.DebugInfo, kept_var_idx: set[int]) -> None: - from jax._src.array import ArrayImpl # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. 
arg_names = ( [a for i, a in enumerate(jaxpr_debug_info.arg_names) @@ -3239,7 +3237,7 @@ def check_array_xla_sharding_layout_match( num_errors = 5 for arg, xs, xl, name in safe_zip( args_after_dce, in_xla_shardings, in_xla_layouts, arg_names): - if not isinstance(arg, ArrayImpl): + if not isinstance(arg, array.ArrayImpl): continue if isinstance(xs, (UnspecifiedValue, AUTO)): continue diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index d5b5d128241d..951c08d8f4fa 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -48,6 +48,7 @@ pytype_strict_library( ":fuser_utils", "//jax", "//jax:ad_util", + "//jax:api", "//jax:api_util", "//jax:core", "//jax:custom_derivatives", diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 8a5d087125b7..78a1bd4f0011 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -60,6 +60,7 @@ pytype_strict_library( deps = [ ":core", "//jax", + "//jax:api", "//jax:core", "//jax:mesh", "//jax:mlir", diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index acbc11a60039..b13967d5b61c 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -60,6 +60,7 @@ pytype_strict_library( deps = [ "//jax", "//jax:ad_util", + "//jax:api", "//jax:api_util", "//jax:config", "//jax:core", diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 4113f764e888..f2446e9a4939 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -70,7 +70,7 @@ prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding, flatten_spec, _internal_use_concrete_mesh) from jax._src.layout import Format, DeviceLocalLayout, AutoLayout -from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef +from jax._src.state.types import RefEffect from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, @@ -1413,7 +1413,7 @@ def _create_pjit_jaxpr( if config.debug_key_reuse.value: # Import here to avoid circular imports - from jax.experimental.key_reuse._core import check_key_reuse_jaxpr + from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) if any(isinstance(c, core.Tracer) or core.typeof(c).mutable for c in consts): @@ -2682,38 +2682,6 @@ def _pjit_pp_rule(eqn: core.JaxprEqn, core.pp_eqn_rules[pjit_p] = _pjit_pp_rule -def _pjit_state_discharge_rule( - in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, **params): - if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)): - raise NotImplementedError - - if not (all(l is None for l in in_layouts) and - all(l is None for l in out_layouts)): - raise NotImplementedError - - jaxpr, consts = jaxpr.jaxpr, jaxpr.consts - num_outs = len(jaxpr.outvars) - discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts) - discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) - new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars) - new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars) - new_in_layouts = (None,) * len(discharged_jaxpr.invars) - new_out_layouts = (None,) * len(discharged_jaxpr.outvars) - out_and_ref_vals = pjit_p.bind( - *args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings, - out_shardings=new_out_shardings, in_layouts=new_in_layouts, - 
out_layouts=new_out_layouts, **params) - out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) - ref_vals_iter = iter(ref_vals) - new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) - else None for aval in in_avals) - sentinel = object() - assert next(ref_vals_iter, sentinel) is sentinel - return new_invals, out_vals -state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule) - - # -------------------- with_sharding_constraint -------------------- def check_shardings_are_auto(shardings_flat): diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 100447f12d18..9dce3297b947 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -25,6 +25,8 @@ from jax._src import api_util from jax._src import core from jax._src import linear_util as lu +from jax._src import pjit +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import tree_util from jax._src.interpreters import ad @@ -1145,3 +1147,35 @@ def wrapped(args): _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped + + +@register_discharge_rule(pjit.pjit_p) +def _pjit_state_discharge_rule( + in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, + in_layouts, out_layouts, **params): + if not all(isinstance(s, sharding_impls.UnspecifiedValue) for s in (*in_shardings, *out_shardings)): + raise NotImplementedError + + if not (all(l is None for l in in_layouts) and + all(l is None for l in out_layouts)): + raise NotImplementedError + + jaxpr, consts = jaxpr.jaxpr, jaxpr.consts + num_outs = len(jaxpr.outvars) + discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) + discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) + new_in_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.invars) + new_out_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.outvars) + new_in_layouts = (None,) * len(discharged_jaxpr.invars) + new_out_layouts = (None,) * len(discharged_jaxpr.outvars) + out_and_ref_vals = pjit.pjit_p.bind( + *args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings, + out_shardings=new_out_shardings, in_layouts=new_in_layouts, + out_layouts=new_out_layouts, **params) + out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) + ref_vals_iter = iter(ref_vals) + new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) + else None for aval in in_avals) + sentinel = object() + assert next(ref_vals_iter, sentinel) is sentinel + return new_invals, out_vals diff --git a/jax/extend/BUILD b/jax/extend/BUILD index c2a5c48bd2b0..f466f1748654 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -45,6 +45,7 @@ py_library_providing_imports_info( "//jax:abstract_arrays", "//jax:ad", "//jax:ad_util", + "//jax:api", "//jax:core", "//jax:custom_derivatives", ], @@ -61,6 +62,7 @@ pytype_strict_library( srcs = ["backend.py"], deps = [ "//jax", + "//jax:api", "//jax:xla_bridge", ], ) From 67bf8f9d50bc3cefe54258de39a00dfcc8e04394 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Fri, 30 May 2025 14:47:13 -0700 Subject: [PATCH 1456/1769] Add experimental array serialization for nested pytrees Why this change? * JAX is missing a simple data serialization functionality that is compatible with pytrees. * Serialization of list of arrays is already supported, leading users to implement one-off data serialization solutions. 
New API: ``` def save(data: PyTreeT, directory: str | PathLike[str], overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: ... def load(directory: str | PathLike[str], shardings: PyTreeT, mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None ) -> PyTreeT: ... def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT: ... ``` PiperOrigin-RevId: 765345616 --- jax/experimental/array_serialization/BUILD | 27 +- .../pytree_serialization.py | 506 ++++++++++++++++++ .../pytree_serialization_utils.py | 85 +++ .../array_serialization/serialization_test.py | 419 ++++++++++++++- .../array_serialization/tensorstore_impl.py | 74 ++- 5 files changed, 1103 insertions(+), 8 deletions(-) create mode 100644 jax/experimental/array_serialization/pytree_serialization.py create mode 100644 jax/experimental/array_serialization/pytree_serialization_utils.py diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD index d9f7e21e73f6..ebd78decf6a3 100644 --- a/jax/experimental/array_serialization/BUILD +++ b/jax/experimental/array_serialization/BUILD @@ -44,6 +44,29 @@ pytype_library( ]), ) +pytype_library( + name = "pytree_serialization", + srcs = ["pytree_serialization.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/experimental/array_serialization:pytree_serialization_utils", + "//jax/experimental/array_serialization:tensorstore_impl", + "//third_party/py/absl/logging", + "//third_party/py/numpy", + ], +) + +pytype_library( + name = "pytree_serialization_utils", + srcs = ["pytree_serialization_utils.py"], + deps = [ + "//jax", + "//third_party/py/absl/logging", + "//third_party/py/numpy", + ], +) + jax_multiplatform_test( name = "serialization_test", srcs = ["serialization_test.py"], @@ -51,8 +74,8 @@ jax_multiplatform_test( "tpu_v3_x4", ], deps = [ - ":serialization", - "//jax:experimental", + "//jax/experimental/array_serialization:pytree_serialization", + "//jax/experimental/array_serialization:serialization", ], ) diff --git a/jax/experimental/array_serialization/pytree_serialization.py b/jax/experimental/array_serialization/pytree_serialization.py new file mode 100644 index 000000000000..639d36a7c806 --- /dev/null +++ b/jax/experimental/array_serialization/pytree_serialization.py @@ -0,0 +1,506 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Serializations routines for pytrees including array and non-array serialization. 
+""" + +from __future__ import annotations + +from os import PathLike +import os +import re +from typing import Any +from uuid import uuid4, UUID +import json +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +import shutil +import logging + +import jax +from jax._src import distributed +from jax._src.api_util import flatten_axes + +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +import jax.experimental.array_serialization.pytree_serialization_utils as utils +from jax._src import path as pathlib +import numpy as np + +logger = logging.getLogger(__name__) + +_THREADING_SAVE_LOCK = threading.Lock() + +_REMOTE_URL_PREFIXES = ['gs://', 's3://'] +_PYTREEDEF_FILE = "pytreedef.json" +_ARCHIVE_NAME = "archive.zip" +_USE_OCDBT = True # a lot of the code relies on this being True +_MAX_PATH_LENGTH = 4096 +_ARRAY_STORE_DIRNAME = "array_store" +_ARRAY_TYPE_FORMAT = "Array({dtype}[{shape}])" +_ARRAY_TYPE_REGEX = r"Array\(([a-zA-Z0-9_]+)\[([0-9, ]*)\]\)" +_MAX_CONCURRENCY = 32 +_TIMEOUT_SEC = 30 + +PyTreeT = Any + +__all__ = ["save", "load", "load_pytreedef", + "nonblocking_load", "nonblocking_save"] + + +def _get_unique_sync_key() -> str | None: + """Generate a thread-local key for ensuring all host finish (de)serializing""" + if jax.process_count() == 1: + return None + # broadcast a thread-local unique barrier name + sync_key_unique = multihost_utils.broadcast_one_to_all( + np.frombuffer(uuid4().bytes, dtype=np.int32)) + sync_key_id = UUID(bytes=np.array(sync_key_unique).tobytes()) + return f"jax_sync_key_{str(sync_key_id)}" + + +def _is_str_same_on_all_hosts(path: str | PathLike[str]) -> bool: + """All-gather the location of the checkpoint and check if it's the same.""" + if jax.process_count() <= 1: + return False + path_b = str(path).encode("utf-8") + if len(path_b) > _MAX_PATH_LENGTH: + raise ValueError(f"Path exceeds maximum length of {_MAX_PATH_LENGTH} in" + " multiprocess case.") + path_array = np.concatenate([ + np.frombuffer(path_b, dtype=np.uint8), np.zeros( + _MAX_PATH_LENGTH - len(path_b), dtype=np.uint8)]) + path_array = multihost_utils.process_allgather(path_array) + return bool(np.all(path_array[0] == path_array[1:])) + + +def _sync_on_key(key: str | None, extra_tag: str = "") -> None: + if key is None: + return + full_key = f"{key}-{extra_tag}" if extra_tag else key + if (client := distributed.global_state.client) is not None: + client.wait_at_barrier(full_key, timeout_in_ms=_TIMEOUT_SEC * 1000) + + +def _is_array_like(x): + return isinstance(x, (jax.Array, np.ndarray)) + + +def _leaf_to_desc(leaf) -> str: + if leaf is None: + return "null" + elif _is_array_like(leaf): + return _ARRAY_TYPE_FORMAT.format( + dtype=leaf.dtype.name, shape=", ".join(map(str, leaf.shape))) + else: + return type(leaf).__name__ + + +def _desc_to_leaf(leaf_desc: str | None) -> str | None | jax.ShapeDtypeStruct: + if leaf_desc is None: + return None + if not re.match(_ARRAY_TYPE_REGEX, leaf_desc): + return leaf_desc + shape_dtype_match = re.match(_ARRAY_TYPE_REGEX, leaf_desc) + assert shape_dtype_match is not None + dtype_str, shape_str = shape_dtype_match.groups() + shape = [int(x.strip()) for x in shape_str.strip("]").strip().split(",") + if len(x.strip()) > 0] + return jax.ShapeDtypeStruct(shape, jax.numpy.dtype(dtype_str)) + + +def _is_remote_path(path: str | PathLike[str]): + """Check whether a path is remote by examining the prefix.""" + # we need to truncate e.g., gs:// to gs:/ because 
pathlib.Path collapses // + return any(str(path).startswith(prefix[:-1]) + for prefix in _REMOTE_URL_PREFIXES) + + +def _norm_path(path: str | PathLike[str]) -> Any: + if _is_remote_path(path): + return pathlib.Path(path) + return pathlib.Path(path).expanduser().resolve() + + +def _rm_dir(root: Any) -> None: + if _is_remote_path(root): + root.rmtree() # pytype: disable=attribute-error + else: + shutil.rmtree(root) + + +def _set_up_destination(root: str | PathLike[str], overwrite: bool, + pytree_repr: dict[str, Any], distinct_locations: bool, + sync_key: str | None) -> dict[str, Any]: + """Inspect the destination, set it up for writing, potentially read existing data.""" + root = _norm_path(root) + if overwrite: + if root.exists() and len(list(root.iterdir())) > 0: + # check that we're only deleting things that come from JAX + # refuse to rm directories containing additional entries + extra_member_paths = [ + path for path in list(root.iterdir()) if path.name not in + (_PYTREEDEF_FILE, _ARCHIVE_NAME, _ARRAY_STORE_DIRNAME)] + + if len(extra_member_paths) != 0: + raise RuntimeError( + "Refusing to work on a directory that is not a previous checkpoint." + f" Unrecognized paths: {extra_member_paths}. Remove them manually" + f" if you're sure you want to use {root} as the checkpoint" + " directory.") + + if (jax.process_index() == 0 or distinct_locations) and root.exists(): + _rm_dir(root) + _sync_on_key(sync_key, "overwrite") + return pytree_repr + else: + if (root.exists() and len(list(root.iterdir())) > 0): # not empty + raise ValueError(f"Files already exist at path: `{root}`, but you" + f" specified `{overwrite=}`") + return pytree_repr + + +def _prepare_directory(root: str | PathLike[str], overwrite: bool, + pytreedef_repr: dict[str, Any], distinct_locations: bool, + sync_key: str | None): + """Prepare the directory: check destination, potentially read existing data + and overwrite. + + Raises: + RuntimeError: If the destination directory cannot be created. 
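+
+  Returns:
+    The pytreedef representation dict, as returned by `_set_up_destination`,
+    to be written alongside the arrays.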
+ """ + root = _norm_path(root) + # prepare the destination directory, overwrite destination directory or error + pytreedef_repr = _set_up_destination( + root, overwrite, pytreedef_repr, distinct_locations, sync_key) + + if not _is_remote_path(root) and (distinct_locations + or jax.process_index() == 0): + root.mkdir(exist_ok=True) # do not make parents, that's too much + if not root.exists() or not root.is_dir(): + raise RuntimeError(f"Could not create destination directory at {root}") + _sync_on_key(sync_key, "mkdir") + return pytreedef_repr + + +def _write_arrays(array_store_path: Any, arrs: list[Any], + arr_leaf_ids: list[int], ts_specs: list[Any | None], + distinct_locations: bool): + paths = [array_store_path / str(leaf_id) for leaf_id in arr_leaf_ids] + process_idx = None + if not distinct_locations and jax.process_count() > 1: + process_idx = jax.process_index() + default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT, + process_idx=process_idx, + arr=arr) + for (path, arr) in zip(paths, arrs)] + ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec) + for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)] + + # sanity check the ts specs + if len(ts_specs) > 0: # verify the base path is shared for all arrays + expected_path = ts_specs[0]["kvstore"]["base"]["path"] # shared base path + for ts_spec, arr in zip(ts_specs, arrs): + ts_impl.verify_tensorstore_spec(ts_spec, arr, expected_path, + ocdbt=_USE_OCDBT, check_metadata=True) + + async def _serialize_arrays(): + await asyncio.gather(*[ + ts_impl.async_serialize(arr, ts_spec, primary_host=None) + for (arr, ts_spec) in zip(arrs, ts_specs)]) + + asyncio.run(_serialize_arrays()) + + +def _finalize_array_store(kvstore_path, distinct_locations: bool): + """When multiple processes are writing, they must write to a per-process + location followed by combining them via no-copy links to the final location. + """ + # only in multiprocess case and only process 0 + if distinct_locations or jax.process_count() == 1 or jax.process_index() != 0: + return + dummy_key_path = os.path.join(kvstore_path, "dummy_key") + combined_kvstore = ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=True, process_idx=None)["kvstore"] + children_kvstores = [ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=True, process_idx=i)["kvstore"] + for i in range(jax.process_count())] + _ = combined_kvstore.pop("path") + _ = [kvstore.pop("path") for kvstore in children_kvstores] + asyncio.run(ts_impl.combine_kvstores(combined_kvstore, children_kvstores)) + + +def _write_pytreedef(directory: Any, pytree_repr: dict[str, Any], + distinct_locations: bool): + """Write the pytreedef to the destination directory and aux data to the archive.""" + if not (jax.process_index() == 0 or distinct_locations): + return + root = _norm_path(directory) + (root / _PYTREEDEF_FILE).write_text(json.dumps(pytree_repr, indent=2)) + + +def _tree_broadcast(a, b, is_leaf=lambda x: x is None): + """Broadcast the prefix tree `a` to the full tree `b` + + Uses `flatten_axes` for better error messages on mismatched arity but allowing + for custom is_leaf in the `a` and `b` trees. 
+ """ + a_leaves, a_struct = jax.tree.flatten(a, is_leaf=is_leaf) + a_idx2leaf_map = dict(enumerate(a_leaves)) + a_idx = jax.tree.unflatten(a_struct, a_idx2leaf_map.keys()) + a_idx_broadcast = flatten_axes("tree_broadcast", + jax.tree.structure(b, is_leaf=is_leaf), a_idx) + return jax.tree.map(lambda i: a_idx2leaf_map[i], a_idx_broadcast) + + +_serialization_executor = ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY) + + +def save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: + """Saves the given data structure to the provided directory path. + + This function provides functionality to serialize and save a data structure + comprising JAX arrays, along with its structure to a given directory. It + leverages `PyTree` for flattening and reconstructing the data structure. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + data: The data structure to be saved. Arbitrary composition of JAX arrays, + including nested structures. + directory: The directory path where the data will be saved. A local path or + a remote URL (e.g., gs://, s3://). For remote URLs, `etils` is required. + overwrite: If True, any existing directory with the same name will be + overwritten. + ts_specs: Optional tensorstore specs to use for serialization. If None, + defaults to using the default tensorstore specs. + + Example: + >>> data = {"a": jnp.array([1, 2]), "b": None} + >>> save(data, directory) + """ + with _THREADING_SAVE_LOCK: + return _save(data, directory, overwrite=overwrite, ts_specs=ts_specs) + + +def _save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: + sync_key = _get_unique_sync_key() # get a synchronization key for multi-host + + if _is_remote_path(directory) and not pathlib.epath_installed: + raise RuntimeError("For saving to remote URLs (e.g., gs, s3) you need the" + " `etils` module installed. You can install it using" + " `pip install etils`.") + ts_specs = _tree_broadcast(ts_specs, data, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + data_flat, pytreedef = jax.tree.flatten(data, is_leaf=lambda x: x is None) + if not all(x is None or _is_array_like(x) for x in data_flat): + raise ValueError("For serialization, all leaves must be either None or" + " jax.Array-like objects.") + distinct_locations = not _is_str_same_on_all_hosts(directory) + if jax.process_count() > 1 and distinct_locations: + raise ValueError( + "Saving to different locations on different hosts is not supported," + " because it is extremely fragile. Consider using a single location.") + root = _norm_path(directory) + + # 1. serialize the pytree ################################# + pytreedef_repr = utils.serialize_pytreedef(pytreedef) + pytreedef_repr[utils._LEAF_IDS_KEY] = jax.tree.map(_leaf_to_desc, data_flat) + + pytreedef_repr = _prepare_directory( + root, overwrite, pytreedef_repr, distinct_locations, sync_key) + futures = [] + futures.append(_serialization_executor.submit( + _write_pytreedef, root, pytreedef_repr, distinct_locations)) + + # 2. 
serialize arrays ##################################### + array_store_path = root / _ARRAY_STORE_DIRNAME + arrs = [data for data in data_flat if _is_array_like(data)] + arr_leaf_ids = [i for i, data in enumerate(data_flat) if _is_array_like(data)] + ts_specs_flat = jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids] + futures.append(_serialization_executor.submit( + _write_arrays, array_store_path, arrs, arr_leaf_ids, ts_specs_flat, + distinct_locations)) + + # 3. wait for all futures to complete ##################### + _ = [fut.result() for fut in futures] + _sync_on_key(sync_key, "array_serialization") + + # 4. finalize the array writing ########################### + if len(arr_leaf_ids) > 0 and _USE_OCDBT: + _finalize_array_store(array_store_path, distinct_locations) + # we are done with all async ops here, we can block #### + _sync_on_key(sync_key, "end") + + +def _read_arrays(array_store_path: str | PathLike[str], arr_leaf_ids: list[int], + ts_specs: list[Any], shardings: list[Any]): + # array_store_path = root / _LEAF_DATA_DIR / _ARRAY_STORE_DIRNAME + arr_store_path = _norm_path(array_store_path) + arr_paths = [arr_store_path / str(leaf_id) for leaf_id in arr_leaf_ids] + + # byte limiter to limit number of parallel reads, resizes to largest read + byte_limiter = ts_impl._LimitInFlightBytes(10 * 1024 ** 3) # 10 GB + + default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT, + process_idx=None) + for path in arr_paths] + ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec) + for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)] + + if len(ts_specs) > 0: # verify the base path is shared for all arrays + expected_path = ts_specs[0]["kvstore"]["base"]["path"] # shared base path + for ts_spec in ts_specs: + ts_impl.verify_tensorstore_spec(ts_spec, arr=None, path=expected_path, + ocdbt=_USE_OCDBT, check_metadata=False) + + async def _deserialize_arrays(): + return await asyncio.gather(*[ + ts_impl.async_deserialize(sharding, ts_spec, byte_limiter=byte_limiter) + for (sharding, ts_spec) in zip(shardings, ts_specs)]) + + return dict(zip(arr_leaf_ids, asyncio.run(_deserialize_arrays()))) + + +def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT: + """Loads a pytree from the given directory. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + directory: Directory path to load from. + Returns: + The loaded pytree with arrays represented as jax.ShapeDtypeStruct's. + """ + assert not _is_remote_path(directory) or pathlib.epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + json_content = (_norm_path(directory) / _PYTREEDEF_FILE).read_text() + raw_tree = json.loads(json_content) + leaves = map(_desc_to_leaf, raw_tree[utils._LEAF_IDS_KEY]) + return jax.tree.unflatten(utils.deserialize_pytreedef(raw_tree), leaves) + + +def load(directory: str | PathLike[str], shardings: PyTreeT, *, + mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None + ) -> PyTreeT: + """Loads and reconstructs a data structure from a directory. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + directory: Directory path where the data is stored. 
+ shardings: Sharding strategy for array objects. If None, defaults to + single device sharding on the default device. + mask: boolean prefix tree for partial loading, will return None for False + leaves. + ts_specs: Optional tensorstore specs to use for deserialization. If None, + defaults to using the default tensorstore specs. + + Returns: + Reconstructed data. + + Example: + >>> save(data, directory) + >>> restored_data = load(directory, SingleDeviceSharding(jax.devices()[0])) + """ + assert not _is_remote_path(directory) or pathlib.epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + + root = _norm_path(directory) + assert root.is_dir(), f"Checkpoint directory {root} does not exist" + is_leaf = lambda x: x is None + + # deserialize PyTreeDef + pytree = load_pytreedef(directory) + # broadcast the (prefix) shardings and tensorstore specs to the full pytree + shardings = _tree_broadcast(shardings, pytree) + ts_specs = _tree_broadcast(ts_specs, pytree, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + if mask is not None: + _prefix_mask = lambda m, x: jax.tree.map(lambda _: None, x) if not m else x + pytree = jax.tree.map(_prefix_mask, mask, pytree) + pytreedef = jax.tree.structure(pytree, is_leaf=is_leaf) + leaf_ids_flat = jax.tree.leaves(pytree, is_leaf=is_leaf) + shardings_flat = jax.tree.leaves(shardings, is_leaf=is_leaf) + ts_specs_flat = jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + + # deserialize array objects + arr_leaf_ids = [i for i, leaf_id in enumerate(leaf_ids_flat) + if leaf_id is not None] + shardings_flat = [shardings_flat[i] for i in arr_leaf_ids] + ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids] + + arrs_fut = _serialization_executor.submit( + _read_arrays, root / _ARRAY_STORE_DIRNAME, arr_leaf_ids, ts_specs_flat, + shardings_flat) + + arrs = arrs_fut.result() + filled_values = [arrs.get(i, None) for i, _ in enumerate(leaf_ids_flat)] + return jax.tree.unflatten(pytreedef, filled_values) + + +def nonblocking_save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None + ) -> utils.PyTreeFuture: + """Nonblocking alias of save, return an awaitable future with a pytree stub. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Examples: + >>> fut = nonblocking_save(data, directory) + >>> print(fut.pytree) # a pytree of jax.ShapeDtypeStruct's + >>> print(fut.result()) # None, blocking until the serialization is done + """ + # start serialization immediately + fut = utils.PyTreeFuture(_serialization_executor.submit( + save, data, directory, overwrite=overwrite, ts_specs=ts_specs)) + # construct a nice looking pytree representing the nodes being read + fut.pytree = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) + if _is_array_like(x) else x, data) + return fut + + +def nonblocking_load(directory: str | PathLike[str], shardings: PyTreeT, *, + mask: PyTreeT | None = None, + ts_specs: PyTreeT | None = None) -> utils.PyTreeFuture: + """Nonblocking alias of load, return an awaitable future with a pytree stub. 
+
+  This is a simple experimental array serialization API, for anything more
+  complex and for all checkpointing prefer: https://github.com/google/orbax
+
+  Examples:
+  >>> fut = nonblocking_load(directory)
+  >>> print(fut.pytree)  # a pytree of jax.ShapeDtypeStruct
+  >>> print(fut.result())  # the fully populated pytree
+  """
+  # TODO(rdyro): the awaitable future output is a workaround
+  # it should return the fully populated pytree instead of just
+  # jax.ShapeDtypeStruct for arrays by constructing them asynchronously
+  fut = utils.PyTreeFuture(_serialization_executor.submit(
+      load, directory, shardings, mask=mask, ts_specs=ts_specs))
+  fut.pytree = load_pytreedef(directory)
+  return fut
diff --git a/jax/experimental/array_serialization/pytree_serialization_utils.py b/jax/experimental/array_serialization/pytree_serialization_utils.py
new file mode 100644
index 000000000000..a7d37eeab5f8
--- /dev/null
+++ b/jax/experimental/array_serialization/pytree_serialization_utils.py
@@ -0,0 +1,85 @@
+# Copyright 2021 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utilities for representing pytreedefs in a serializable format.
+"""
+
+import base64
+import logging
+from types import ModuleType
+from concurrent.futures import Future
+from typing import Any, TypeVar
+
+import jax
+from jax._src.export.serialization import (flatbuffers, _serialize_pytreedef,
+                                           _deserialize_pytreedef_to_pytree,
+                                           ser_flatbuf)
+from jax.export import register_pytree_node_serialization  # pylint: disable=unused-import
+
+T = TypeVar("T")
+PickleModule = ModuleType
+logger = logging.getLogger(__name__)
+
+_READABLE_PYTREE_SERIALIZATION = True
+_TREE_REPR_KEY = "__jax_pytreedef_repr"
+_LEAF_IDS_KEY = "__jax_leaf_ids"
+
+_NOT_REGISTERED_MESSAGE = (
+    "  * If you want to register a custom leaf, register it via"
+    " `register_pytree_leaf_serialization` first.\n"
+    "  * If you want to register a custom node, register it via"
+    " `register_pytree_node_serialization`")
+
+__all__ = ["serialize_pytreedef", "deserialize_pytreedef",
+           "register_pytree_node_serialization"]
+
+class PyTreeFuture(Future[Any]):
+  """A wrapper around a Future that makes it look like an async function."""
+  def __init__(self, future: Future[Any]):
+    self._future, self.pytree = future, None
+
+  def done(self):
+    return self._future.done()
+
+  def result(self, *args, **kw):
+    return self._future.result(*args, **kw)
+
+  def __await__(self):
+    while not self.done():
+      yield
+    return self.result()
+
+
+def _cls2typerepr(cls):
+  return f"{cls.__module__}.{cls.__name__}"
+
+
+def serialize_pytreedef(node) -> dict[str, Any]:
+  builder = flatbuffers.Builder(65536)
+  exported = _serialize_pytreedef(builder, node)
+  builder.Finish(exported)
+  root_repr = base64.b64encode(builder.Output()).decode("utf-8")
+  leaf_count = node.num_leaves
+  pytree_repr = {_TREE_REPR_KEY: root_repr,
+                 _LEAF_IDS_KEY: list(range(leaf_count))}
+  return pytree_repr
+
+
+def deserialize_pytreedef(pytreedef_repr: dict[str, Any]):
+  buf =
base64.b64decode(pytreedef_repr[_TREE_REPR_KEY]) + exp = ser_flatbuf.PyTreeDef.GetRootAs(buf) + treestruct = jax.tree.structure(_deserialize_pytreedef_to_pytree(exp)) + return treestruct diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 0d7e0a48b6c1..eab23443f545 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -14,11 +14,19 @@ # pylint: disable=g-importing-member import asyncio +from dataclasses import dataclass from functools import partial +import json +import logging import math import os import pathlib -import tracemalloc as tm +import pickle +import tempfile +import threading +import time +import tracemalloc as tm +from typing import Any from absl.testing import absltest from absl.testing import parameterized @@ -26,10 +34,19 @@ from jax._src import array from jax._src import config from jax._src import test_util as jtu +from jax._src.export._export import ( + deserialization_registry as node_deserialization_registry) +from jax._src.export._export import ( + serialization_registry as node_serialization_registry) from jax._src.layout import DeviceLocalLayout as DLL from jax._src.layout import Format +from jax.experimental.array_serialization import pytree_serialization from jax.experimental.array_serialization import serialization from jax.experimental.array_serialization import tensorstore_impl as ts_impl + +from jax.experimental.array_serialization.pytree_serialization_utils import ( + register_pytree_node_serialization) + import jax.numpy as jnp from jax.sharding import NamedSharding @@ -43,6 +60,16 @@ jtu.request_cpu_devices(8) +_default_sharding = None + + +def tree_load(*args, **kw): + return pytree_serialization.load(*args, shardings=_default_sharding, **kw) + +tree_save = pytree_serialization.save +tree_load_pytreedef = pytree_serialization.load_pytreedef + + def _get_replicated_sharding(devices): return NamedSharding( jax.make_mesh(np.shape(devices), P('x'), devices=devices), P()) @@ -98,7 +125,8 @@ def test_memory_consumption(self): inp = array.make_array_from_callback( inp_shape, sharding, lambda idx: src[idx]) - ckpt_dir = pathlib.Path(self.create_tempdir('memprof').full_path) + ckpt_dir = pathlib.Path(self.create_tempdir( + 'memprof-deserialize').full_path) tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) manager = serialization.GlobalAsyncCheckpointManager() @@ -134,6 +162,7 @@ async def deserialize_with_byte_limit(): self.assertGreater(peak, 30_000_000) tm.stop() + @jtu.thread_unsafe_test() def test_memory_consumption_for_save(self): global_mesh = jtu.create_mesh((1, 1), ('x', 'y')) inp_shape = (16 * 1024, 16 * 1024) @@ -144,7 +173,8 @@ def test_memory_consumption_for_save(self): inp = array.make_array_from_callback( inp_shape, sharding, lambda idx: src[idx] ) - ckpt_dir = pathlib.Path(self.create_tempdir('memprofsave').full_path) + ckpt_dir = pathlib.Path(self.create_tempdir( + 'memprofsave-serialize').full_path) tspec = ts_impl.get_tensorstore_spec(str(ckpt_dir), ocdbt=False, driver='zarr3') tspec['metadata'] = { @@ -663,5 +693,388 @@ def test_transfer_shard_to_host(self): self.assertArraysEqual(np_out, np_inp) +def _remove_from_serialization_registry(t: Any): + if t in node_serialization_registry: + serialized_name = node_serialization_registry[t][0] + del node_serialization_registry[t] + del node_deserialization_registry[serialized_name] + + +class 
UserAPITestCase(jtu.JaxTestCase): + name: str | None + path: pathlib.Path | None + + def setUp(self): + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.name = tmpdir.name + self.path = pathlib.Path(self.name) + + def tearDown(self): + self.path = None + self.name = None + super().tearDown() + + def generate_random_fp32(self, shape, dtype=jnp.float32): + seed = round(time.time() * 1e6) % (2 ** 31) + key = jax.random.key(seed) + return jax.random.normal(key, shape=shape).astype(dtype) + + def generate_clean_tree(self, dtype=jnp.float32): + r1 = self.generate_random_fp32((), dtype=dtype) + r2 = self.generate_random_fp32((4,), dtype=dtype) + r3 = self.generate_random_fp32((2, 3), dtype=dtype) + return (r1, {'a': r2, 'rs': [r1, r2, r3], 'c': {'d': {'e': (r2,)}}}) + + def _is_equal(self, el1, el2): + if not isinstance(el1, type(el2)) or not isinstance(el2, type(el1)): + return False + if isinstance(el1, (np.ndarray, jax.Array)): + return (el1.dtype == el2.dtype and el1.shape == el2.shape + and jnp.allclose(el1, el2)) + else: + return el1 == el2 + + def assertPyTreeEqual(self, p1, p2, is_leaf=None): + leaves1, struct1 = jax.tree.flatten(p1, is_leaf=is_leaf) + leaves2, struct2 = jax.tree.flatten(p2, is_leaf=is_leaf) + self.assertEqual(struct1, struct2) + self.assertTrue(all(self._is_equal(el1, el2) + for (el1, el2) in zip(leaves1, leaves2))) + +_DTYPES_LIST = [ + jnp.uint8, + jnp.uint16, + jnp.uint32, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.float8_e4m3fn, + jnp.float8_e4m3fnuz, + jnp.float8_e5m2, + jnp.float8_e5m2fnuz, + jnp.float8_e4m3b11fnuz, + jnp.bfloat16, + jnp.float16, + jnp.float32, + jnp.complex64, +] + +_X64_DTYPES_LIST = [ + jnp.uint64, + jnp.int64, + jnp.float64, + jnp.complex128, +] + +if jax.config.x64_enabled: + _DTYPES_LIST.extend(_X64_DTYPES_LIST) + + +@jax.tree_util.register_pytree_node_class +class CustomNode: + def __init__(self, a): + self.a = a + + def tree_flatten(self): + return (self.a,), None + + @classmethod + def tree_unflatten(cls, aux_data, children): + del aux_data + return cls(*children) + + +@partial(jax.tree_util.register_dataclass, data_fields=['a', 'd'], + meta_fields=['c']) +@dataclass +class CustomDataclass: + a: int + c: str + d: int + + +@jax.tree_util.register_static +class CustomStatic: + def __init__(self, a): + self.a = a + +# we're testing custom type registration which modifies the global registry +# so need to ensure we're not running multiple custom types tests in parallel +custom_types_threading_lock = threading.Lock() + + +class UserPytreeAPITest(UserAPITestCase): + def setUp(self): + super().setUp() + global _default_sharding + _default_sharding = SingleDeviceSharding(jax.devices()[0]) + self.tempdirs = [] + + def tearDown(self): + for tempdir in self.tempdirs: + tempdir.cleanup() + super().tearDown() + + def create_tempdir(self): + tempdir = tempfile.TemporaryDirectory() + self.tempdirs.append(tempdir) + return pathlib.Path(tempdir.name).resolve() + + @parameterized.product(tree=[{'a': 1}, [1, 2, 3], (1, 2, 3), 1, 2, 3]) + def test_save_then_load(self, tree): # pylint: disable=redefined-outer-name + path = self.create_tempdir() + tree = jax.tree.map(jnp.array, tree) + tree_save(tree, path) + tree2 = tree_load(path) + self.assertPyTreeEqual(tree, tree2) + + @parameterized.product(dtype=_DTYPES_LIST) + def test_saving_dtype(self, dtype): + if dtype in _X64_DTYPES_LIST and jtu.test_device_matches(['tpu']): + self.skipTest('Don\'t test x64 dtypes on TPUs') + path = self.create_tempdir() + test_tree = 
self.generate_clean_tree(dtype=dtype) + tree_save(test_tree, path) + new_tree = tree_load(path) + self.assertPyTreeEqual(test_tree, new_tree) + + def test_do_not_overwrite_noncheckpoint_directories(self): + path = self.create_tempdir() + path.mkdir(exist_ok=True) + (path / 'hello.txt').write_text('Hello World') + with self.assertRaisesRegex(RuntimeError, 'Refusing to work on a directory' + ' that is not a previous checkpoint.'): + tree_save({'a': jnp.ones(1)}, path) + + def test_checkpoint_exists(self): + path = self.create_tempdir() + tree_save({'a': jnp.ones(1)}, path) + with self.assertRaises(ValueError): + tree_save({'a': jnp.ones(1)}, path, overwrite=False) + + @parameterized.product(test_load_fail=[True, False]) + def test_custom_types(self, test_load_fail): + path = self.create_tempdir() + with custom_types_threading_lock: + magic_value = jnp.ones(()) * 37 + n = CustomNode(magic_value) + d = CustomDataclass(magic_value, 'hello', magic_value + 1) + s = CustomStatic(magic_value - 1) + tree_to_save = [n, (d, s)] + + register_pytree_node_serialization(CustomNode, + serialized_name='CustomNode', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + register_pytree_node_serialization(CustomStatic, + serialized_name='CustomStatic', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + register_pytree_node_serialization(CustomDataclass, + serialized_name='CustomDataclass', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + tree_save(tree_to_save, path) + if test_load_fail: + _ = [_remove_from_serialization_registry(cls) + for cls in [CustomStatic, CustomNode, CustomDataclass]] + with self.assertRaises(ValueError): + _ = tree_load(path) + else: + tree2 = tree_load(path) + self.assertEqual(tree2[0].a, magic_value) + self.assertEqual(tree2[1][0].a, magic_value) + self.assertEqual(tree2[1][0].c, 'hello') + self.assertEqual(tree2[1][0].d, magic_value + 1) + self.assertEqual(tree2[1][1].a, magic_value - 1) + _ = [_remove_from_serialization_registry(cls) + for cls in [CustomStatic, CustomNode, CustomDataclass]] + + def test_flax_frozen_dict(self): + path = self.create_tempdir() + try: + # pylint: disable=g-import-not-at-top + # pylint: disable=g-importing-member + from flax.core.frozen_dict import FrozenDict + # pylint: enable=g-importing-member + # pylint: enable=g-import-not-at-top + except ImportError: + logging.warning('Skipping Flax FrozenDict tests as flax is not installed') + return + + try: + register_pytree_node_serialization(FrozenDict, + serialized_name='FrozenDict', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + tree_save(FrozenDict(a=1, b=self.generate_clean_tree()), path) + tree_load(path) + finally: + _remove_from_serialization_registry(FrozenDict) + + def test_register_as_decorator(self): + @partial(register_pytree_node_serialization, + serialized_name='CustomDNode', + serialize_auxdata=json.dumps, + deserialize_auxdata=json.loads) + @partial(jax.tree_util.register_dataclass, data_fields=['a', 'b'], + meta_fields=[]) + @dataclass + class CustomDNode: + a: int + b: int + + # test whether the object can be created (is visible in this scope) + _ = CustomDNode(1, 2) + + def test_custom_node_registration(self): + path = self.create_tempdir() + + @jax.tree_util.register_static + @dataclass + class P: + a: int = 2 + + @partial(jax.tree_util.register_dataclass, data_fields=['a', 'b'], + meta_fields=['op']) + @dataclass + class D: + a: Any + b: Any + op: str + + def serialize_D(data): + return 
json.dumps(jax.tree.map(lambda x: np.array(x).tolist(), data) + ).encode('utf-8') + + def deserialize_D(data): + return jnp.array(json.loads(data)) + + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(()), P()] + + serialize_fn = lambda p: json.dumps(int(p.a)).encode('utf-8') + deserialize_fn = lambda data: P(json.loads(data)) + + with self.assertRaises(ValueError): + tree_save(data, path) + + register_pytree_node_serialization(P, + serialized_name='P', + serialize_auxdata=serialize_fn, + deserialize_auxdata=deserialize_fn) + magic_value = -171 + data[-1].a = jnp.array(magic_value) + tree_save(data, path) + ret = tree_load(path) + self.assertLen(ret, len(data)) + self.assertEqual(ret[-1].a, magic_value) + + magic_val = 17 * jnp.ones(2) + data.append(D(jnp.ones(1), jax.numpy.zeros(2), magic_val)) + with self.assertRaises(ValueError): + tree_save(data, path) + + register_pytree_node_serialization(D, + serialized_name='D', + serialize_auxdata=serialize_D, + deserialize_auxdata=deserialize_D) + tree_save(data, path) + ret = tree_load(path) + self.assertLen(ret, len(data)) + self.assertLess(jnp.linalg.norm(ret[-1].op - magic_val), 1e-5) + + jax.tree.flatten(data) + + def test_masked_reading(self): + path = self.create_tempdir() + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(())] + tree_save(data, path) + for mask in [False, True]: + ret = tree_load(path, mask=mask) + expected = jax.tree.map(lambda x: None if not mask else x, data) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, False, False] + expected = data[:1] + jax.tree.map(lambda x: None, data[1:]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, True, False] + expected = data[:2] + jax.tree.map(lambda x: None, data[2:]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, {'world': [True, (False, True)]}, False] + data[1]['world'][1] = (None, data[1]['world'][1][1]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + # TODO(rdyro): Remove when serialization supports non-arrays + @parameterized.product(obj=[b'hello', 'hello', 1, 1.0, 1j]) + def test_serialization_works_for_arrays_only(self, obj): + path = self.create_tempdir() + data = [{'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, obj] + msg = ('For serialization, all leaves must be either None or' + ' jax.Array-like objects.') + with self.assertRaisesRegex(ValueError, msg): + tree_save(data, path) + + def test_load_pytreedef(self): + path = self.create_tempdir() + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(())] + tree_save(data, path) + pytreedef = tree_load_pytreedef(path) + expected_pytreedef = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), data) + self.assertPyTreeEqual(pytreedef, expected_pytreedef) + + @parameterized.product(data=[ + None, [None], [None, np.ones(())], + [None, {'world': [None, (np.ones(1), np.ones(2))]}, np.ones(())], + [None, {'world': [np.zeros(3), (None, np.ones(2))]}, None]]) + def test_save_and_load_null_leaves(self, data): + path = self.create_tempdir() + # TPUs might not have X64 enabled, so we need to convert to float32 + data = jax.tree.map(lambda x: jnp.array(x, dtype=jnp.float32), data) + tree_save(data, path) + pytreedef = tree_load_pytreedef(path) + is_leaf = 
lambda x: x is None + expected_pytreedef = jax.tree.map(lambda x: jax.ShapeDtypeStruct( + x.shape, x.dtype) if x is not None else x, data, is_leaf=is_leaf) + self.assertPyTreeEqual(pytreedef, expected_pytreedef) + load_data = tree_load(path) + load_leaves, load_struct = jax.tree.flatten(load_data, is_leaf=is_leaf) + expected_leaves, expected_struct = jax.tree.flatten(data, is_leaf=is_leaf) + self.assertEqual(load_struct, expected_struct) + self.assertLen(load_leaves, len(expected_leaves)) + for (l1, l2) in zip(load_leaves, expected_leaves): + if l1 is None: + self.assertIsNone(l2) + else: + self.assertArraysEqual(l1, l2) + + @parameterized.product(manually_broadcast_ts_specs=[True, False]) + def test_custom_ts_specs(self, manually_broadcast_ts_specs): + if ts_impl._TS_ARRAY_DRIVER == 'zarr': + self.skipTest('Skipping since this test assumes zarr is NOT the default') + path = self.create_tempdir() + data = [jnp.ones(()), (jnp.zeros(()), jnp.ones(())), None] + ts_spec = {'driver': 'zarr', 'metadata': {'shape': ()}} + if manually_broadcast_ts_specs: + ts_specs = [ts_spec, (ts_spec, None), None] # None ts_spec allowed + else: + ts_specs = ts_spec + tree_save(data, path, ts_specs=ts_specs) + load_data = tree_load(path, ts_specs=ts_specs) + self.assertPyTreeEqual(data, load_data) + with self.assertRaisesRegex(ValueError, + 'NOT_FOUND: Error opening "zarr3" driver:'): + _ = tree_load(path) # default attempts to open with zarr3 and fails + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py index 7578bbb831e0..81b4a5177029 100644 --- a/jax/experimental/array_serialization/tensorstore_impl.py +++ b/jax/experimental/array_serialization/tensorstore_impl.py @@ -77,6 +77,11 @@ async def release_bytes(self, requested_bytes): assert self._available_bytes <= self._max_bytes self._cv.notify_all() +def is_tensorstore_spec_leaf(leaf: Any): + # TODO(rdyro): think of a better way to detect which leaf is a ts config + return leaf is None or (isinstance(leaf, dict) + and ("driver" in leaf or "kvstore" in leaf)) + def _prime_factors(x: int) -> list[int]: # find prime factors of axis sizes to help efficiently find divisor chunks factors = [] @@ -163,6 +168,54 @@ def _get_tensorstore_metadata_cached( else: raise ValueError(f"Unsupported driver: {driver}") +_divides = lambda x, y: np.all((np.array(x) % np.array(y)) == 0) + +def merge_nested_ts_specs(dict1: dict[Any, Any], dict2: dict[Any, Any] | None): + """Merge two ts specs, dict2 takes precedence.""" + if dict2 is None: # nothing to do + return dict1 + # TODO(rdyro): this is an opinionated merge, we should get user feedback + # merge kvstore explicitly + kvstore = dict1.get("kvstore", {}) | dict2.get("kvstore", {}) + return dict1 | dict(dict2, kvstore=kvstore) # merge with dict2 preferred + +def verify_tensorstore_spec(spec: dict[str, Any], arr: jax.Array | None, + path: str | os.PathLike[str], ocdbt: bool, + check_metadata: bool = True) -> None: + """Verify the minimum requirements for a tensorstore spec.""" + if ocdbt: + if spec.get("kvstore", {}).get("driver", "") != "ocdbt": + raise ValueError(f"Expected ocdbt driver, got {spec=}") + if check_metadata: + if arr is None: + raise ValueError("Array is required for metadata verification.") + metadata = spec['metadata'] + if spec.get("driver", "") == "zarr3": + if metadata['data_type'] != jnp.dtype(arr.dtype).name: + raise ValueError(f"Provided dtype 
({metadata['data_type']=}) doesn't" + f" match ({arr.dtype=})") + if 'shape' in metadata: + if metadata['shape'] != arr.shape: + raise ValueError(f"Provided shape ({metadata['shape']=}) doesn't match" + f" ({arr.shape=})") + if hasattr(arr, 'addressable_data'): + local_shape = arr.addressable_data(0).shape + else: # np.ndarray + local_shape = arr.shape + if spec.get("driver", "") == "zarr3": + chunk_shape = metadata['chunk_grid']['configuration']['chunk_shape'] + if not _divides(local_shape, chunk_shape): + raise ValueError(f"Provided chunk shape {chunk_shape} does not divide" + f" the local shape of the array {local_shape}") + # check path is still the same one we expect + if ocdbt: + found_path = spec["kvstore"]['base']['path'] + else: + found_path = spec["kvstore"]['path'] + if str(found_path) != str(path): + raise ValueError(f"Provided {path=} does not match the spec path:" + f" {spec['kvstore']}") + def _spec_has_metadata(tree): if not isinstance(tree, dict): return False @@ -189,7 +242,7 @@ def _get_kvstore_for_s3(ckpt_path: str): def get_tensorstore_spec( ckpt_path: str | PathLike[str], ocdbt: bool = True, - process_num: int | None = None, arr: jax.Array | None = None, + process_idx: int | None = None, arr: jax.Array | None = None, driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: # Normalize path to exclude trailing '/'. In GCS path case, normpath will @@ -201,9 +254,9 @@ def get_tensorstore_spec( # in cases of multi-process writes, we need to write to a different location # for each process and finally created a combined symlink to the final # location, tensorstore can do this via ts.KvStore.experimental_copy_range_to - if process_num is not None: + if process_idx is not None: _parent, _name = os.path.split(ckpt_path) - ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_num), + ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_idx), _name) is_gcs_path = ckpt_path.startswith('gs://') @@ -278,6 +331,21 @@ async def _transfer_shard_to_host(shard: array.Shard) -> np.ndarray: # silently copy host-to-host. return np.array(data, copy=False) +async def combine_kvstores(combined_kvstore: dict[str, Any], + kvstores: list[dict[str, Any]], + context: ts.Context | dict[str, Any] = _TS_CONTEXT + ) -> None: + """Merge a list of kvstores into a single kvstore. NOT multi-process safe.""" + combined_fut = ts.KvStore.open(combined_kvstore, context=context) + kvstores_futs = [ts.KvStore.open(kvstore, context=context) + for kvstore in kvstores] + combined, kvstores = await asyncio.gather(combined_fut, + asyncio.gather(*kvstores_futs)) + tx = ts.Transaction() + await asyncio.gather(*[kvstore.experimental_copy_range_to( + combined.with_transaction(tx)) for kvstore in kvstores]) + await tx.commit_async() + async def async_serialize( arr_inp, tensorstore_spec, From 26228f5a0c5dc45147805bd5535443f2f2213dbe Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 30 May 2025 15:23:34 -0700 Subject: [PATCH 1457/1769] Allow setting non-string TPU runtime flags. For example: jax.config.update("jax_pjrt_client_create_options", {"max_inflight_computations": 64}) PiperOrigin-RevId: 765357524 --- jax/_src/config.py | 10 ++++++++-- jax/_src/xla_bridge.py | 27 +++++++++++++++------------ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index e79993958349..9d19ceb8b261 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -940,11 +940,17 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'otherwise.' 
)) -jax_pjrt_client_create_options = optional_string_state( +def _validate_jax_pjrt_client_create_options(new_val): + if new_val is not None and not isinstance(new_val, (str, dict)): + raise ValueError('new string config value must be None or of type dict' + f' | str, got {new_val} of type {type(new_val)}.') + +jax_pjrt_client_create_options = string_or_object_state( name='jax_pjrt_client_create_options', default=None, help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings ' - 'provided to a device platform pjrt client as extra arguments.')) + 'provided to a device platform pjrt client as extra arguments.'), + validator=_validate_jax_pjrt_client_create_options) enable_checks = bool_state( name='jax_enable_checks', diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index ce0c36fdcca4..22e71d7d9ce6 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -449,18 +449,21 @@ def _options_from_jax_configs(plugin_name): options = {} pjrt_client_options = config.jax_pjrt_client_create_options.value - pjrt_client_option_list = [] - if pjrt_client_options: - pjrt_client_option_list = pjrt_client_options.split(";") - - for option in pjrt_client_option_list: - option_list = option.split(":") - if (len(option_list) != 2): - raise RuntimeError( - "Multiple ':' separators for option in " - f"jax_pjrt_client_create_options: '{option}'. " - "Should be in format 'key:value'") - options[option_list[0]] = option_list[1] + if isinstance(pjrt_client_options, str): + pjrt_client_option_list = [] + if pjrt_client_options: + pjrt_client_option_list = pjrt_client_options.split(";") + + for option in pjrt_client_option_list: + option_list = option.split(":") + if (len(option_list) != 2): + raise RuntimeError( + "Multiple ':' separators for option in " + f"jax_pjrt_client_create_options: '{option}'. " + "Should be in format 'key:value'") + options[option_list[0]] = option_list[1] + elif isinstance(pjrt_client_options, dict): + options.update(pjrt_client_options) if plugin_name in ("cuda", "rocm"): visible_devices = (CUDA_VISIBLE_DEVICES.value if plugin_name == "cuda" From 6f0f0ad23c50ee84c53a06d7295fc25ecb1fd306 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 30 May 2025 16:05:38 -0700 Subject: [PATCH 1458/1769] fix incorrect TODO PiperOrigin-RevId: 765371666 --- jax/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 67fc208f7841..aaa25db88beb 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -311,9 +311,9 @@ py_library_providing_imports_info( "_src/error_check.py", "_src/ffi.py", "_src/flatten_util.py", - "_src/interpreters/__init__.py", # TODO(vanderplas): remove this and depend on :api instead + "_src/interpreters/__init__.py", "_src/interpreters/pxla.py", # TODO(vanderplas): remove this and depend on :api instead - "_src/pjit.py", + "_src/pjit.py", # TODO(vanderplas): remove this and depend on :api instead "_src/prng.py", "_src/public_test_util.py", "_src/random.py", From 6c18aa8a468e35b8c11b101dceaa43d05b497177 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 30 May 2025 16:32:34 -0700 Subject: [PATCH 1459/1769] [Mosaic] Move i1 broadcast lowering logic to Mosaic. And relax the test skip conditions. Somehow we skipped everything before. Also, this should fix https://github.com/jax-ml/jax/issues/29092. 
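For illustration, this is the kind of bool-broadcasting Pallas TPU kernel the
change supports directly -- a minimal hypothetical sketch (assumes a recent
libtpu/Mosaic; not part of this change):

  import functools
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl

  @functools.partial(
      pl.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.bool_))
  def kernel(x_ref, o_ref):
    # Row-broadcast a (1, 128) bool block to (8, 128); Mosaic now rewrites
    # the i1 broadcast itself (i1 -> i32, broadcast, compare back to i1).
    o_ref[...] = jax.lax.broadcast_in_dim(x_ref[...], (8, 128), (0, 1))

  mask = (jnp.arange(128, dtype=jnp.int32) % 2 == 0).reshape(1, 128)
  out = kernel(mask)  # (8, 128) bool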
PiperOrigin-RevId: 765380392 --- jax/_src/pallas/mosaic/lowering.py | 14 ------- .../tpu/transforms/canonicalize_mosaic.cc | 39 ++++++++++++++++++ tests/pallas/ops_test.py | 41 ++++++++++++++----- 3 files changed, 69 insertions(+), 25 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index e2dfe526ea14..873bf587093a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1928,20 +1928,6 @@ def _broadcast_in_dim_lowering_rule( if aval_in.shape == shape: return val - if jnp.issubdtype(aval_in.dtype, jnp.bool_): - # Direct broadcasts for bools are not supported in Mosaic due to booleans - # living in mask registers and broadcast operating on vregs. Broadcast as an - # integer instead and cast back to a bool. - # TODO(b/351019164): Implement this logic in Mosaic BroadcastOp instead. - def _proxy_fun(val, *, shape, broadcast_dimensions): - int_val = jnp.where(val, 1, 0) - bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions) - return bcast_val == 1 - proxy_lowering = lower_fun( - _proxy_fun, multiple_results=False) - return proxy_lowering( - ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions) - if broadcast_dimensions: out_shape_list = [1] * len(shape) for i, s in zip(broadcast_dimensions, aval_in.shape): diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index c963cff0be50..733863546935 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -566,6 +566,44 @@ LogicalResult canonicalize_extract(const CanonicalizeContext &ctx, return success(); } +LogicalResult canonicalize_broadcast(const CanonicalizeContext &ctx, + Operation &raw_op) { + auto op = dyn_cast(raw_op); + auto src_ty = op.getSource().getType(); + auto src_vty = dyn_cast(src_ty); + if ((src_vty && src_vty.getElementType().isSignlessInteger(1)) || + op.getSource().getType().isSignlessInteger(1)) { + // Canonicalize i1 broadcast. + // i1 represents vmsk in Mosaic and TPU doesn't support vmsk replication + // directly. + // Instead, convert i1 to i32 vector, broadcast i32, and then convert it + // back to i1. 
+ ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + Value i32_src; + if (src_vty) { + i32_src = builder.create( + VectorType::get(src_vty.getShape(), builder.getI32Type()), + op.getSource()); + } else { + i32_src = + builder.create(builder.getI32Type(), op.getSource()); + } + auto i32_res_vty = + VectorType::get(op.getType().getShape(), builder.getI32Type()); + auto bcast = builder.create(i32_res_vty, i32_src); + auto ones = builder.create( + i32_res_vty, + SplatElementsAttr::get(i32_res_vty, + builder.getOneAttr(builder.getI32Type()))); + auto cmp = + builder.create(arith::CmpIPredicate::eq, bcast, ones); + op.replaceAllUsesWith(cmp.getResult()); + op.erase(); + return success(); + } + return success(); +} + LogicalResult canonicalize_select(const CanonicalizeContext &ctx, Operation &raw_op) { auto op = dyn_cast(raw_op); @@ -920,6 +958,7 @@ const llvm::StringMap &rules() { canonicalize_multi_dim_reduction}, {vector::TransposeOp::getOperationName(), canonicalize_vector_transpose}, {vector::ShapeCastOp::getOperationName(), canonicalize_reshape}, + {vector::BroadcastOp::getOperationName(), canonicalize_broadcast}, {arith::SelectOp::getOperationName(), canonicalize_select}, {arith::FPToSIOp::getOperationName(), canonicalize_fptosi}, {arith::SIToFPOp::getOperationName(), canonicalize_sitofp}, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index c819d050c8a5..e951a9fda827 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1822,30 +1822,49 @@ def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): np.testing.assert_allclose(out[oi], x[ii]) np.testing.assert_allclose(out[oi + 1 :], jnp.zeros_like(out[oi + 1 :])) - @parameterized.parameters( - ((), (2,), ()), - ((1,), (2,), (0,)), - ((1, 1), (2, 2), (0, 1)), - ((), (2, 2), ()), + @parameterized.product( + shape_spec=[ + ((), (2,), ()), + ((1,), (2,), (0,)), + ((1, 128), (8, 128), (0, 1)), # row broadcasting + ((), (2, 2), ()), + ], + dtype=[jnp.int32, jnp.int16, jnp.int8, jnp.bool_], ) - def test_broadcast_in_dim(self, in_shape, out_shape, dims): + def test_broadcast_in_dim(self, shape_spec, dtype): self.skip_if_mosaic_gpu() - # The Pallas TPU lowering currently supports only blocks of rank >= 1 + in_shape, out_shape, dims = shape_spec if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") + if not in_shape: + self.skipTest( + "The Pallas TPU lowering currently supports only blocks of rank" + " >= 1" + ) + if dtype is jnp.bool_ and not jtu.if_cloud_tpu_at_least(2025, 6, 5): + self.skipTest("Requires libtpu built after 2025-06-05") + if ( + len(in_shape) == 1 + and len(out_shape) == 1 + and dtype not in {jnp.int32, jnp.bool_} + ): + self.skipTest("Unsupported tiling") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), ) def f(x_ref, o_ref): x = x_ref[...] o_ref[...] 
= jax.lax.broadcast_in_dim(x, out_shape, dims) - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + x = ( + jnp.arange(math.prod(in_shape), dtype=jnp.int32) + .reshape(in_shape) + .astype(dtype) + ) expected = jax.lax.broadcast_in_dim(x, out_shape, dims) - np.testing.assert_allclose(f(x), expected) + np.testing.assert_array_equal(f(x), expected) @parameterized.product( lhs_and_rhs_shape=[ From 3c04713bba08078ec3f0a990623a718cf6925aba Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 31 May 2025 00:26:35 -0700 Subject: [PATCH 1460/1769] Automated Code Change PiperOrigin-RevId: 765494419 --- jaxlib/to_ifrt_sharding.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxlib/to_ifrt_sharding.cc b/jaxlib/to_ifrt_sharding.cc index 2bb6e121893f..220c54e7a1e5 100644 --- a/jaxlib/to_ifrt_sharding.cc +++ b/jaxlib/to_ifrt_sharding.cc @@ -16,7 +16,6 @@ limitations under the License. #include "jaxlib/to_ifrt_sharding.h" #include -#include #include #include #include From ff6892b41541202ebcec2a791862dfc05910baca Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 31 May 2025 02:59:59 -0700 Subject: [PATCH 1461/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/2b5d61bf82739017eb0338936c31418dca171780. PiperOrigin-RevId: 765527826 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index adeb0cc1bfc5..180643fad58f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "1c4bad881e922d783e382982f0e3d175d1c5e707" -XLA_SHA256 = "46af35dac4c699badd7a230a1cbbaabfea1cee81394bbe1b9c6cbc33046651c4" +XLA_COMMIT = "2b5d61bf82739017eb0338936c31418dca171780" +XLA_SHA256 = "4d7e61c55de1264b9cd8e24d50fc0c6c77209b28767735250c943591f74e9e17" def repo(): tf_http_archive( From 5cca31fa2b411f332c19b972ed966983917d55e7 Mon Sep 17 00:00:00 2001 From: DanisNone Date: Sat, 31 May 2025 02:40:31 +0500 Subject: [PATCH 1462/1769] test_binary_ufunc_reduce now also tests behavior with the initial and where parameters. test_binary_ufunc_reduce_where has been removed. add uint32 to tests correct _logsumexp and _reduce_bitwise_and implement _logsumexp2 without _logsumexp Now jnp.minimum and jnp.maximum are ufuncs. 
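For example, the reduce method of the new ufuncs accepts NumPy-style `initial`
and `where` arguments -- a small usage sketch, not part of this change:

  import jax.numpy as jnp

  x = jnp.array([[1., 5., 3.], [4., 2., 6.]])
  mask = jnp.array([[True, False, True], [True, True, False]])

  # jnp.minimum is now a jnp.ufunc, so it exposes .reduce().
  jnp.minimum.reduce(x, axis=1, initial=10., where=mask)  # [1., 2.]
  jnp.maximum.reduce(x, axis=0)                           # [4., 5., 6.]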
Also includes follow-up fixes for ruff and mypy.
---
 jax/_src/numpy/lax_numpy.py    |  2 +-
 jax/_src/numpy/reductions.py   | 24 ++++++-------
 jax/_src/numpy/ufuncs.py       |  6 ++--
 jax/numpy/__init__.pyi         |  4 +--
 tests/lax_numpy_ufuncs_test.py | 61 +++++++++++++++++++++++++++++++++-
 5 files changed, 77 insertions(+), 20 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index ad2b3ad6aa75..b926662fd777 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3445,7 +3445,7 @@ def clip(
   if min is not None:
     arr = ufuncs.maximum(min, arr)
   if max is not None:
-    arr = ufuncs.minimum(max, arr)
+    arr = ufuncs.minimum(max, arr)  # type: ignore
   return asarray(arr)
 
 
diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py
index 9cb543d5d869..cbfda25eafcf 100644
--- a/jax/_src/numpy/reductions.py
+++ b/jax/_src/numpy/reductions.py
@@ -33,6 +33,7 @@
     _broadcast_to, ensure_arraylike, promote_dtypes_inexact,
     promote_dtypes_numeric, _where)
 from jax._src.lax import lax as lax_internal
+from jax._src.lax import other as lax_other
 from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
 from jax._src.util import (
     canonicalize_axis as _canonicalize_axis, maybe_named_axis,
@@ -398,11 +399,11 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
 
 
 @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
-def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
-                keepdims: bool = False, initial: ArrayLike | None = None,
-                where: ArrayLike | None = None) -> Array:
+def _reduce_max(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+                out: None = None, keepdims: bool = False,
+                initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
   return _reduction(a, "max", lax.max, -np.inf, has_identity=False,
-                    axis=axis, out=out, keepdims=keepdims,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.pmax)
@@ -480,12 +481,12 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None,
   return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
                      keepdims=keepdims, initial=initial, where=where)
 
-@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
-def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
-                keepdims: bool = False, initial: ArrayLike | None = None,
-                where: ArrayLike | None = None) -> Array:
+@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
+def _reduce_min(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+                out: None = None, keepdims: bool = False,
+                initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
   return _reduction(a, "min", lax.min, np.inf, has_identity=False,
-                    axis=axis, out=out, keepdims=keepdims,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.pmin)
@@ -682,7 +683,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
                         out: None = None, keepdims: bool = False,
                         initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
   arr = lax_internal.asarray(a)
-  init_val = np.array(-1, dtype=dtype or arr.dtype)
+  init_val = np.array(-1).astype(dtype or arr.dtype)
   return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and,
                     init_val=init_val, preproc=_require_integer,
                     axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
                     keepdims=keepdims, initial=initial, where_=where)
@@ -750,7 +751,7 @@ def _logsumexp(a: 
ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype)) - return result if initial is None else lax.logaddexp(initial, result) + return result if initial is None else lax_other.logaddexp(initial, result) def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, @@ -768,7 +769,6 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims, where=where, initial=initial) / ln2 - @export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index d722534e3136..486d3f15e17c 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -1626,8 +1626,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) -@export -@partial(jit, inline=True) +@binary_ufunc(identity=None, reduce=reductions._reduce_min) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise minimum of the input arrays. @@ -1687,8 +1686,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.min(*promote_args("minimum", x, y)) -@export -@partial(jit, inline=True) +@binary_ufunc(identity=None, reduce=reductions._reduce_max) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise maximum of the input arrays. diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 4db407861f34..e81d97765121 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -653,7 +653,7 @@ def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: ... def max(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ... +maximum: BinaryUfunc def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... @@ -666,7 +666,7 @@ mgrid: _Mgrid def min(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: ... +minimum: BinaryUfunc def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def moveaxis(a: ArrayLike, source: int | Sequence[int], diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index fd5050a5829b..905d7eed1acd 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -56,7 +56,7 @@ def _jnp_ufunc_props(name): jnp_func = getattr(jnp, name) assert isinstance(jnp_func, jnp.ufunc) np_func = getattr(np, name) - dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] + dtypes = [np.dtype(c) for c in "FfIi?" 
if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] return [dict(name=name, dtype=dtype) for dtype in dtypes] @@ -242,6 +242,7 @@ def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} @@ -324,6 +325,64 @@ def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_binary_ufunc_reduce_initial(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + if jnp_fun.identity is None and axis is None and len(shape) > 1: + self.skipTest("Multiple-axis reduction over non-reorderable ufunc.") + + jnp_fun_reduce = lambda a, initial: jnp_fun.reduce(a, axis=axis, initial=initial) + np_fun_reduce = lambda a, initial: np_fun.reduce(a, axis=axis, initial=initial) + + rng = jtu.rand_default(self.rng()) + rng_initial = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng_initial((), dtype)] + + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_binary_ufunc_reduce_where_initial(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + # Skip if the ufunc doesn't have an identity and we're doing a multi-axis reduction + if jnp_fun.identity is None and axis is None and len(shape) > 1: + self.skipTest("Multiple-axis reduction over non-reorderable ufunc.") + + jnp_fun_reduce = lambda a, where, initial: jnp_fun.reduce( + a, axis=axis, where=where, initial=initial) + np_fun_reduce = lambda a, where, initial: np_fun.reduce( + a, axis=axis, where=where, initial=initial) + + rng = jtu.rand_default(self.rng()) + rng_where = jtu.rand_bool(self.rng()) + rng_initial = jtu.rand_default(self.rng()) + args_maker = lambda: [ + rng(shape, dtype), + rng_where(shape, bool), + rng_initial((), dtype) + ] + + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} From 0a1ada83ec979dfeba713d5f81f9d7684c52afa6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 31 May 2025 21:20:23 -0700 Subject: [PATCH 1463/1769] Allow specifying non-differentiable arguments by name (`nondiff_argnames`) in addition to by index (`nondiff_argnums`). The implementation normalizes `nondiff_argnames` to indices in the constructor and merges them with `nondiff_argnums`, allowing the rest of the custom derivative logic to continue using a unified list of indices. 
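For example -- a minimal usage sketch mirroring the new tests:

  from functools import partial
  import jax

  @partial(jax.custom_jvp, nondiff_argnames=('f',))
  def app(f, x):
    return f(x)

  @app.defjvp
  def app_jvp(f, primals, tangents):
    (x,), (t,) = primals, tangents
    return app(f, x), 3. * t

  jax.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,))  # (2.0, 3.0)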
PiperOrigin-RevId: 765730837 --- jax/_src/custom_derivatives.py | 46 +++++++++++++++++++++------ tests/custom_api_test.py | 57 ++++++++++++++++++++++++++++++++-- 2 files changed, 92 insertions(+), 11 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 2a09665f6285..87407efebd3a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,7 +31,8 @@ stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs, - prepend_static_args, debug_info) + prepend_static_args, debug_info, fun_signature, + infer_argnums_and_argnames) from jax._src.errors import UnexpectedTracerError from jax._src.state.types import AbstractRef from jax._src.interpreters import ad @@ -133,16 +134,31 @@ def f_jvp(primals, tangents): """ fun: Callable[..., ReturnValue] nondiff_argnums: Sequence[int] + nondiff_argnames: Sequence[str] jvp: Callable[..., tuple[ReturnValue, ReturnValue]] | None = None symbolic_zeros: bool = False def __init__(self, fun: Callable[..., ReturnValue], nondiff_argnums: Sequence[int] = (), + nondiff_argnames: Sequence[str] = (), ): update_wrapper(self, fun) self.fun = fun - self.nondiff_argnums = nondiff_argnums + + nondiff_argnums_: set[int] = set() + if nondiff_argnames: + sig = fun_signature(self.fun) + assert sig is not None + inferred_nondiff_argnums, _ = infer_argnums_and_argnames( + sig, None, nondiff_argnames + ) + nondiff_argnums_.update(inferred_nondiff_argnums) + + if nondiff_argnums: + nondiff_argnums_.update(nondiff_argnums) + + self.nondiff_argnums = tuple(sorted(nondiff_argnums_)) __getattr__ = custom_api_util.forward_attr @@ -259,10 +275,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable ) from e if self.nondiff_argnums: - nondiff_argnums = set(self.nondiff_argnums) - args = tuple(_stop_gradient(x) if i in nondiff_argnums else x + args = tuple(_stop_gradient(x) if i in self.nondiff_argnums else x for i, x in enumerate(args)) - diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] + diff_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug), diff_argnums, args, require_static_args_hashable=False) @@ -536,10 +551,24 @@ def f_bwd(res, g): def __init__(self, fun: Callable[..., ReturnValue], - nondiff_argnums: Sequence[int] = ()): + nondiff_argnums: Sequence[int] = (), + nondiff_argnames: Sequence[str] = ()): update_wrapper(self, fun) self.fun = fun - self.nondiff_argnums = nondiff_argnums + + nondiff_argnums_: set[int] = set() + if nondiff_argnames: + sig = fun_signature(self.fun) + assert sig is not None + inferred_nondiff_argnums, _ = infer_argnums_and_argnames( + sig, None, nondiff_argnames + ) + nondiff_argnums_.update(inferred_nondiff_argnums) + + if nondiff_argnums: + nondiff_argnums_.update(nondiff_argnums) + + self.nondiff_argnums = tuple(sorted(nondiff_argnums_)) self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None self.bwd: Callable[..., tuple[Any, ...]] | None = None self.symbolic_zeros = False @@ -681,8 +710,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable else: if self.nondiff_argnums: for i in self.nondiff_argnums: _check_for_tracers(args[i]) - nondiff_argnums = set(self.nondiff_argnums) - dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] + dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] 
f_, dyn_args = argnums_partial( lu.wrap_init(self.fun, debug_info=debug_fun), dyn_argnums, args, require_static_args_hashable=False) diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index bfe391797920..45434923d543 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -330,7 +330,7 @@ def g_jvp(primals, tangents): self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) - def test_nondiff_arg(self): + def test_nondiff_argnums(self): @partial(jax.custom_jvp, nondiff_argnums=(0,)) def app(f, x): return f(x) @@ -347,6 +347,21 @@ def app_jvp(f, primals, tangents): expected = (2., 3.) self.assertAllClose(ans, expected, check_dtypes=False) + def test_nondiff_argnames(self): + @partial(jax.custom_jvp, nondiff_argnames=('f',)) + def app(f, x): + return f(x) + + def app_jvp(f, primals, tangents): + (x,), (t,) = primals, tangents + return app(f, x), 3 * t + + app.defjvp(app_jvp) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + def test_nondiff_arg_jit_tracer(self): # This test would pass with "final-style" JIT tracing, but that was # misleading: it doesn't work with "initial-style" staging, i.e. control @@ -1655,7 +1670,7 @@ def foo(x): expected = 2. * jnp.cos(jnp.arange(3.)) self.assertAllClose(ans, expected, check_dtypes=False) - def test_nondiff_arg(self): + def test_nondiff_argnums(self): @partial(jax.custom_vjp, nondiff_argnums=(0,)) def app(f, x): return f(x) @@ -1673,6 +1688,44 @@ def app_rev(f, cos_x, g): expected = (2., jnp.cos(1.)) self.assertAllClose(ans, expected, check_dtypes=False) + def test_nondiff_argnames(self): + @partial(jax.custom_vjp, nondiff_argnames=('f',)) + def app(f, x): + return f(x) + def app_fwd(f, x): + return app(f, x), jnp.cos(x) + def app_rev(f, cos_x, g): + return (cos_x * g,) + app.defvjp(app_fwd, app_rev) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) + expected = (2., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnums_argnames(self): + @partial(jax.custom_vjp, nondiff_argnums=(0,), nondiff_argnames=('g',)) + def app(f, g, x): + return f(x) + g(x) + def app_fwd(f, g, x): + return app(f, g, x), jnp.cos(x) + def app_rev(f, g, cos_x, v): + return (cos_x * v,) + app.defvjp(app_fwd, app_rev) + + f = lambda x: 2 * x + g = lambda x: 2 * x + ans = app(f, g, 1) + expected = 4 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(f, g, x))(1.) + expected = (4., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + def test_closed_over_jit_tracer(self): # See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer. raise unittest.SkipTest("behavior no longer supported") From 88dbf60fb065dc19940b25eeb7d0eb56983484fa Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 1 Jun 2025 02:13:42 -0700 Subject: [PATCH 1464/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/aec061968d8ba167ce6ffe08bfd51c7a4508ece4. 
PiperOrigin-RevId: 765795019 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 180643fad58f..0fc5aeb944f6 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "2b5d61bf82739017eb0338936c31418dca171780" -XLA_SHA256 = "4d7e61c55de1264b9cd8e24d50fc0c6c77209b28767735250c943591f74e9e17" +XLA_COMMIT = "aec061968d8ba167ce6ffe08bfd51c7a4508ece4" +XLA_SHA256 = "020328ef5f098a4ff19fe420aad441b803c5cb4d3dd515e7d07c9ac949f33af5" def repo(): tf_http_archive( From 107efde069779d5b66821c867b725503fe48340e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sun, 1 Jun 2025 07:25:53 -0700 Subject: [PATCH 1465/1769] Reverts 73c016a534af51614741d70d36c2c75ca59f2dcc PiperOrigin-RevId: 765852528 --- jax/_src/named_sharding.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 0b4efdb41d25..faf0b2a9f2b2 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -288,6 +288,13 @@ def _custom_repr(self): priority_repr = '' if self.priority is None else f'p{self.priority}' return f'{{{axes_repr}{open_repr}}}{priority_repr}' +def _get_axes(axes, mesh_shape): + if not axes: + return () + assert mesh_shape is not None + # Sort wrt mesh axis names so order is deterministic and doesn't hang in + # McJAX. + return tuple(n for n, _ in mesh_shape if n in axes) @dataclasses.dataclass(kw_only=True) class SdyArray: @@ -307,11 +314,13 @@ def build(self) -> sdy.TensorShardingAttr: [sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape], ldi) + replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape) + unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape) return sdy.TensorShardingAttr.get( mesh_attr, [dim_sharding.build() for dim_sharding in self.dim_shardings], - replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes], - unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in self.unreduced_axes]) + replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], + unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) def __repr__(self): dim_sharding_repr = ', '.join( @@ -333,7 +342,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh): dim_shardings.append(SdyDim(axes=[], is_open=True) if not d.axes and not d.is_open else d) used_axes.extend(d.axes) - remaining_axes = tuple(n for n in mesh.axis_names if n not in used_axes) + remaining_axes = set(mesh.axis_names) - set(used_axes) replicated_axes = tuple(r for r in remaining_axes if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) return SdyArray(mesh_shape=sdy_sharding.mesh_shape, From 52e5a87d785db75eadbf172d813495f8bfea7101 Mon Sep 17 00:00:00 2001 From: Sannidhya Chauhan Date: Mon, 2 Jun 2025 00:03:17 -0700 Subject: [PATCH 1466/1769] Introduce profiler_options in the documentation. 
PiperOrigin-RevId: 766058161 --- docs/profiling.md | 97 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 4 deletions(-) diff --git a/docs/profiling.md b/docs/profiling.md index c33e79c1dc0c..c4a340684cd3 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -100,10 +100,10 @@ pip install tb-nightly tbp-nightly ### Programmatic capture You can instrument your code to capture a profiler trace via the -{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace` -methods. Call {func}`~jax.profiler.start_trace` with the directory to write -trace files to. This should be the same `--logdir` directory used to start -TensorBoard. Then, you can use TensorBoard to view the traces. +{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace` methods. +Call {func}`~jax.profiler.start_trace` with the directory to write trace files +to. This should be the same `--logdir` directory used to start TensorBoard. +Then, you can use TensorBoard to view the traces. For example, to take a profiler trace: @@ -229,6 +229,95 @@ functions. You can add your own events and functions by using {class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in your code. +### Configuring profiler options + +The `start_trace` method accepts an optional `profiler_options` parameter, which +allows for fine-grained control over the profiler's behavior. This parameter +should be an instance of `jax.profiler.ProfileOptions`. + + +For example, to disable all python and host traces: + +```python +import jax + +options = jax.profiler.ProfileOptions() +options.python_tracer_level = 0 +options.host_tracer_level = 0 +jax.profiler.start_trace("/tmp/tensorboard", profiler_options=options) + +# Run the operations to be profiled +key = jax.random.key(0) +x = jax.random.normal(key, (5000, 5000)) +y = x @ x +y.block_until_ready() + +jax.profiler.stop_trace() +``` + +#### General options + +1. `host_tracer_level`: Sets the trace level for host-side activities. + + Supported Values: + + `0`: Disables host (CPU) tracing entirely. + + `1`: Enables tracing of only user-instrumented TraceMe events (this is the + default). + + `2`: Includes level 1 traces plus high-level program execution details like + expensive TensorFlow or XLA operations. + + `3`: Includes level 2 traces plus more verbose, low-level program execution + details such as cheap TensorFlow operations. + +2. `python_tracer_level`: Controls whether Python tracing is enabled. + + Supported Values: + + `0`: Disables Python function call tracing. + + `> 0`: Enables Python tracing (this is the default). + +#### Advanced configuration options + +1. `tpu_trace_mode`: Specifies the mode for TPU tracing. + + Supported Values: + + `TRACE_ONLY_HOST`: This means only host-side (CPU) activities are traced, + and no device (TPU/GPU) traces are collected. + + `TRACE_ONLY_XLA`: This means only XLA-level operations on the device are + traced. + + `TRACE_COMPUTE`: This traces compute operations on the device. + + `TRACE_COMPUTE_AND_SYNC`: This traces both compute operations and + synchronization events on the device. + + If "tpu_trace_mode" is not provided the trace_mode defaults to + TRACE_ONLY_XLA. + +2. `tpu_num_sparse_cores_to_trace`: Specifies the number of sparse cores to + trace on the TPU. +3. `tpu_num_sparse_core_tiles_to_trace`: Specifies the number of tiles within + each sparse core to trace on the TPU. +4. `tpu_num_chips_to_profile_per_task`: Specifies the number of TPU chips to + profile per task. 
+
+```
+
+An `InvalidArgumentError` is raised if any unrecognized keys or option values
+are found.
+
 ### Troubleshooting
 
 #### GPU profiling

From 1914815d763fc802f4efbbd90b059245aaa57bd5 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 2 Jun 2025 02:18:11 -0700
Subject: [PATCH 1467/1769] Automated Code Change

PiperOrigin-RevId: 766096798
---
 jaxlib/mosaic/gpu/custom_call.cc | 5 ++++-
 jaxlib/mosaic/gpu/target.cc | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index 214521ce3764..01b7e015e461 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -20,9 +20,11 @@ limitations under the License.
 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -40,9 +42,10 @@ limitations under the License.
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "absl/strings/str_replace.h"
+#include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 // Leave this comment here. Internal Google business.
diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc
index dfb119b410af..d26b1f1ccbf7 100644
--- a/jaxlib/mosaic/gpu/target.cc
+++ b/jaxlib/mosaic/gpu/target.cc
@@ -16,7 +16,6 @@ limitations under the License.
 
 #include 
 #include 
-#include 
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -24,6 +23,7 @@ limitations under the License.
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
 #include "absl/strings/strip.h"
 #include "llvm/MC/MCSubtargetInfo.h"
 #include "llvm/MC/TargetRegistry.h"

From b782b4672fbb394a1c4863e7a425578bcdf24fe6 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 2 Jun 2025 02:38:41 -0700
Subject: [PATCH 1468/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/89872387d699d4622c96a4572e7772b3d1f387cd.

PiperOrigin-RevId: 766102823
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 0fc5aeb944f6..fb27433402ba 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "aec061968d8ba167ce6ffe08bfd51c7a4508ece4"
-XLA_SHA256 = "020328ef5f098a4ff19fe420aad441b803c5cb4d3dd515e7d07c9ac949f33af5"
+XLA_COMMIT = "89872387d699d4622c96a4572e7772b3d1f387cd"
+XLA_SHA256 = "ae1e6d82acc1ec54e1cde1162246a09724ea3c0868054a8a0721c18417c54be5"
 
 def repo():
     tf_http_archive(

From 27e454ddb81e2a382aafb4d5c89266df0be1bf2a Mon Sep 17 00:00:00 2001
From: Zachary Garrett
Date: Mon, 2 Jun 2025 05:36:46 -0700
Subject: [PATCH 1469/1769] [JAX] Use `util.fun_name` to determine
 `WrappedFun.__name__` instead of trying to get the `f.__name__` attribute,
 which won't always exist.
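A small illustration of the failure mode (hypothetical, not part of this
change): functools.partial objects carry no `__name__`, so a getattr-based
lookup always fell back to the placeholder for them.

  from functools import partial

  def my_function():
    return

  p = partial(my_function)
  hasattr(p, '__name__')  # False
  # util.fun_name unwraps partials instead, so WrappedFun.__name__ now
  # resolves to 'my_function' for p.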
Co-authored-by: Keith Rush PiperOrigin-RevId: 766152892 --- jax/_src/linear_util.py | 12 ++------ jax/_src/util.py | 6 ++-- tests/util_test.py | 61 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index bfe87430554e..41af7644d361 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -75,8 +75,8 @@ def trans1(static_arg, *dynamic_args, **kwargs): from jax._src import config from jax._src import core from jax._src import traceback_util -from jax._src.tree_util import keystr, KeyPath, generate_key_paths -from jax._src.util import curry, cache_clearing_funs, HashableFunction +from jax._src.tree_util import KeyPath, generate_key_paths, keystr +from jax._src.util import HashableFunction, cache_clearing_funs, curry, fun_name traceback_util.register_exclusion(__file__) @@ -186,7 +186,7 @@ def __init__(self, f: Callable, @property def __name__(self): - return getattr(self.f, '__name__', '') + return fun_name(self.f, "") def wrap(self, gen, gen_static_args, out_store: Store | EqualStore | None) -> WrappedFun: @@ -266,12 +266,6 @@ def transformation_with_aux2( out_thunk = lambda: out_store.val return fun.wrap(gen, gen_static_args, out_store), out_thunk -def fun_name(f): - try: - return f.__name__ - except: - return str(f) - class DebugInfo(NamedTuple): """Debugging info about a func, its arguments, and results.""" diff --git a/jax/_src/util.py b/jax/_src/util.py index 34f748544d6d..dbdc746713fb 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -367,14 +367,16 @@ def __eq__(self, other): def wrap_name(name: str, transform_name: str) -> str: return transform_name + '(' + name + ')' -def fun_name(fun: Callable) -> str: + +def fun_name(fun: Callable, default_name: str = "") -> str: name = getattr(fun, "__name__", None) if name is not None: return name if isinstance(fun, partial): return fun_name(fun.func) else: - return "" + return default_name + def fun_qual_name(fun: Callable) -> str: qual_name = getattr(fun, "__qualname__", None) diff --git a/tests/util_test.py b/tests/util_test.py index 90506117af8f..923240b69242 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import partial import operator from absl.testing import absltest - import jax from jax import api_util from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util - from jax._src.util import weakref_lru_cache jax.config.parse_flags_with_absl() @@ -74,6 +73,64 @@ def kw_to_positional(f, store, factor, *args, **kwargs): self.assertEqual(dict(three=6, four=8), scaled_kwargs) self.assertEqual(2, out_thunk()) + def test_wrapped_fun_name(self): + def my_function(): + return + + with self.subTest("function"): + wrapped = lu.wrap_init( + my_function, + debug_info=api_util.debug_info("test", my_function, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("default_partial"): + my_partial = partial(my_function) + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("nested_default_partial"): + my_partial = partial(partial(my_function)) + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("named_partial"): + my_partial = partial(my_function) + my_partial.__name__ = "my_partial" + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_partial.__name__) + + with self.subTest("lambda"): + l = lambda: my_function() + wrapped = lu.wrap_init( + l, + debug_info=api_util.debug_info("test", l, (), {}), + ) + self.assertEqual(wrapped.__name__, "") + + with self.subTest("unnamed_callable"): + + class MyCallable: + + def __call__(self): + return + + my_callable = MyCallable() + wrapped = lu.wrap_init( + my_callable, + debug_info=api_util.debug_info("test", my_callable, (), {}), + ) + self.assertEqual(wrapped.__name__, "") + def test_weakref_lru_cache(self): @weakref_lru_cache def example_cached_fn(key): From 8f5dae48a30545ec631c4ea9c0c68869ad6856ab Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 2 Jun 2025 06:20:13 -0700 Subject: [PATCH 1470/1769] [jaxlib] Use SafeStaticInit in more places. Fixes a deadlock in free threading mode. PiperOrigin-RevId: 766165343 --- jaxlib/BUILD | 1 + jaxlib/pmap_lib.cc | 8 +++++--- jaxlib/py_array.cc | 13 +++++++------ jaxlib/py_values.cc | 29 ++++++++++++++++------------- 4 files changed, 29 insertions(+), 22 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 7be24ce5e825..834103063aae 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -801,6 +801,7 @@ cc_library( "@xla//xla/pjrt:status_casters", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", + "@xla//xla/python:safe_static_init", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/tsl/concurrency:ref_count", diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc index e18dd8b4637a..4a4e20f8f55b 100644 --- a/jaxlib/pmap_lib.cc +++ b/jaxlib/pmap_lib.cc @@ -67,6 +67,7 @@ limitations under the License. 
#include "xla/python/ifrt/sharding.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" +#include "xla/python/safe_static_init.h" #include "xla/python/types.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" @@ -289,9 +290,10 @@ class PmapFunction { size_t nargs, PyObject* kwnames); nb::object PythonSignature() { - static const auto* inspect = - new nb::module_(nb::module_::import_("inspect")); - return inspect->attr("signature")(fun_); + const nb::module_& inspect = xla::SafeStaticInit([]() { + return std::make_unique(nb::module_::import_("inspect")); + }); + return inspect.attr("signature")(fun_); } int cache_size() { diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 37932a2aed45..8659cba49dea 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -94,6 +94,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/safe_static_init.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -393,16 +394,16 @@ nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { static auto* lru_list = new CacheT::LRUList(4096); static auto* cache = new CacheT(lru_list); - static const nb::object* shaped_array = []() -> nb::object* { + const nb::object& shaped_array = SafeStaticInit([]() { nb::object jax_core; try { jax_core = nb::module_::import_("jax.core"); } catch (nb::python_error& e) { - return nullptr; + return std::make_unique(); } - return new nb::object(jax_core.attr("ShapedArray")); - }(); - if (!shaped_array) { + return std::make_unique(jax_core.attr("ShapedArray")); + }); + if (!shaped_array.ptr()) { return nb::none(); } @@ -415,7 +416,7 @@ nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { if (!value->has_value()) { nb_dtype dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); - nb::object aval = (*shaped_array)( + nb::object aval = shaped_array( SpanToNbTuple(absl::Span( key.dtype.kind() == ifrt::DType::kToken ? 
std::vector{0} : key.dims)), diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc index 6ea5c272eea3..987a51eb67cf 100644 --- a/jaxlib/py_values.cc +++ b/jaxlib/py_values.cc @@ -712,11 +712,12 @@ absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, } // namespace bool IsFloat0(xla::nb_numpy_ndarray arg) { - static const auto* dtypes_module = - new nb::module_(nb::module_::import_("jax.dtypes")); - static const auto* float0_dtype = - new nb::handle(dtypes_module->attr("float0")); - return float0_dtype->is(arg.attr("dtype")); + const nb::object& float0_dtype = SafeStaticInit([] { + nb::module_ dtypes_module = nb::module_::import_("jax.dtypes"); + nb::object float0_dtype = dtypes_module.attr("float0"); + return std::make_unique(float0_dtype); + }); + return float0_dtype.is(arg.attr("dtype")); } std::string PyArgSignature::DebugString() const { @@ -734,9 +735,11 @@ using ToPyArgSignatureHandler = absl::StatusOr PyArgSignatureOfValue(nb::handle arg, bool jax_enable_x64) { - static const absl::flat_hash_map* const - handlers = [] { - auto p = new absl::flat_hash_map(); + const absl::flat_hash_map& handlers = + SafeStaticInit< + absl::flat_hash_map>([] { + auto p = std::make_unique< + absl::flat_hash_map>(); const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); @@ -881,7 +884,7 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_intc.ptr()] = numpy_array_handler; return p; - }(); + }); if (arg.type().ptr() == PyArray::type().ptr()) { auto array = nb::borrow(arg); @@ -894,12 +897,12 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, return PyArgSignature(primitive_type, array.shape(), array.weak_type()); } - auto res = handlers->find(arg.type().ptr()); - if (res == handlers->end()) { + auto res = handlers.find(arg.type().ptr()); + if (res == handlers.end()) { // We attempt to look at the MRO classes for (auto base_class : arg.type().attr("__mro__")) { - res = handlers->find(base_class.ptr()); - if (res != handlers->end()) { + res = handlers.find(base_class.ptr()); + if (res != handlers.end()) { return res->second(arg, jax_enable_x64); } } From a964f54dd1a2ff615771d23298eb1e8e6bdc82ef Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 2 Jun 2025 10:09:02 -0400 Subject: [PATCH 1471/1769] Update partial eval to avoid DCEing a specific set of effects. --- jax/_src/debugging.py | 8 +++++--- jax/_src/effects.py | 2 ++ jax/_src/interpreters/partial_eval.py | 12 ++++++------ jax/_src/state/types.py | 1 + tests/debugging_primitives_test.py | 2 -- tests/lax_control_flow_test.py | 4 +--- 6 files changed, 15 insertions(+), 14 deletions(-) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 3490de5118e1..e587d48cda68 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -69,6 +69,8 @@ class OrderedDebugEffect(effects.Effect): effects.remat_allowed_effects.add_type(OrderedDebugEffect) effects.custom_derivatives_allowed_effects.add_type(DebugEffect) effects.custom_derivatives_allowed_effects.add_type(OrderedDebugEffect) +effects.partial_eval_kept_effects.add_type(DebugEffect) +effects.partial_eval_kept_effects.add_type(OrderedDebugEffect) # `debug_callback_p` is the main primitive for staging out Python callbacks. 
debug_callback_p = core.Primitive('debug_callback') @@ -126,10 +128,10 @@ def debug_callback_jvp_rule(primals, tangents, **params): return debug_callback_p.bind(*primals, **params), [] ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule -def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any], +def debug_callback_transpose_rule(_, *flat_args, callback: Callable[..., Any], effect: DebugEffect, partitioned): - del flat_args, callback, effect - raise ValueError("Transpose doesn't support debugging callbacks.") + del callback, effect, partitioned + return [None for _ in flat_args] ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule def _debug_callback_partial_auto(axis_context, *args, **params): diff --git a/jax/_src/effects.py b/jax/_src/effects.py index d55333540355..fb79c542e78b 100644 --- a/jax/_src/effects.py +++ b/jax/_src/effects.py @@ -118,3 +118,5 @@ def filter_not_in(self, effects: Iterable[Effect]) -> list[Effect]: control_flow_allowed_effects: EffectTypeSet = EffectTypeSet() custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet() remat_allowed_effects: EffectTypeSet = EffectTypeSet() + +partial_eval_kept_effects: EffectTypeSet = EffectTypeSet() diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 444b60f15fa5..1bcd3f00321c 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -43,7 +43,7 @@ mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.source_info_util import SourceInfo -from jax._src.state.types import AbstractRef, ReadEffect, RefEffect +from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure, register_static) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, @@ -240,23 +240,23 @@ def default_process_primitive(self, primitive, tracers, params): return primitive.bind_with_trace(self.parent_trace, consts, params) tracers = map(self.instantiate_const, tracers) avals = [t.aval for t in tracers] - out_aval, effects = primitive.abstract_eval(*avals, **params) + out_aval, effs = primitive.abstract_eval(*avals, **params) name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) if primitive.multiple_results: out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) for aval in out_aval] - eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effects, + eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effs, source) - if any(isinstance(e, RefEffect) for e in effects): + if effects.partial_eval_kept_effects.filter_in(effs): self.effect_handles.append(EffectHandle(tracers, eqn)) for t in out_tracers: t.recipe = eqn return out_tracers else: out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None) eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive, - params, effects, source) - if any(isinstance(e, RefEffect) for e in effects): + params, effs, source) + if effects.partial_eval_kept_effects.filter_in(effs): self.effect_handles.append(EffectHandle(tracers, eqn)) out_tracer.recipe = eqn return out_tracer diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index af5bec49a6d5..e3a86e241bf2 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -75,6 +75,7 @@ class AccumEffect(RefEffect): name: str = "Accum" 
effects.control_flow_allowed_effects.add_type(RefEffect) +effects.partial_eval_kept_effects.add_type(RefEffect) StateEffect = Union[ReadEffect, WriteEffect, AccumEffect] diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 7985cf841248..9c23f136b825 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -442,8 +442,6 @@ def f(x): with jtu.capture_stdout() as output: jax.linear_transpose(f, 1.)(1.) jax.effects_barrier() - # `debug_print` should be dropped by `partial_eval` because of no - # output data-dependence. self.assertEqual(output(), "") @jtu.sample_product(ordered=[False, True]) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 2f1e154627f4..3f950f865735 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -28,7 +28,6 @@ import jax from jax._src import core -from jax._src import config from jax import dtypes from jax import lax from jax import random @@ -3359,8 +3358,7 @@ def f(c, _): return c + 1, None def g(x): return jax.lax.scan(f, x, length=2)[0] - with config.use_direct_linearize(True): - jaxpr = jax.make_jaxpr(jax.value_and_grad(g))(1.0) + jaxpr = jax.make_jaxpr(jax.value_and_grad(g))(1.0) eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"] self.assertIn("debug_callback", [e.primitive.name for e in eqn_jaxpr.eqns]) From 73aabb46c51251199ee1059c50c5c5ae3ea133d1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 2 Jun 2025 10:54:32 -0400 Subject: [PATCH 1472/1769] Bump the minimum NumPy and SciPy versions. Per Spec 0, we can drop NumPy 1.25 and SciPy 1.11 in June 2025. --- .github/workflows/oldest_supported_numpy.yml | 2 +- CHANGELOG.md | 3 ++ jax/_src/util.py | 7 +--- jaxlib/setup.py | 6 ++-- setup.py | 5 ++- tests/lax_numpy_test.py | 3 -- tests/scipy_spatial_test.py | 6 ---- tests/scipy_stats_test.py | 36 +++----------------- 8 files changed, 14 insertions(+), 54 deletions(-) diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml index 67fc9f10e5ce..fbf881a84a9c 100644 --- a/.github/workflows/oldest_supported_numpy.yml +++ b/.github/workflows/oldest_supported_numpy.yml @@ -42,7 +42,7 @@ jobs: $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt # Install NumPy and SciPy with the oldest supported versions - $JAXCI_PYTHON -m uv pip install numpy==1.25.2 scipy==1.11.1 + $JAXCI_PYTHON -m uv pip install numpy==1.26.4 scipy==1.12.0 # Install JAX using the changes in the PR $JAXCI_PYTHON -m uv pip install -e .[minimum-jaxlib] diff --git a/CHANGELOG.md b/CHANGELOG.md index b34bf36997af..afd15a357b48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * New features: * Added {func}`jax.tree.broadcast` which implements a pytree prefix broadcasting helper. +* Changes + * The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12. 
+ ## JAX 0.6.1 (May 21, 2025) * New features: diff --git a/jax/_src/util.py b/jax/_src/util.py index dbdc746713fb..4100ac21dc00 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -642,12 +642,7 @@ def decorator(f): return decorator -try: - # numpy 1.25.0 or newer - NumpyComplexWarning: type[Warning] = np.exceptions.ComplexWarning -except AttributeError: - # legacy numpy - NumpyComplexWarning = np.ComplexWarning +NumpyComplexWarning: type[Warning] = np.exceptions.ComplexWarning class StrictABCMeta(abc.ABCMeta): diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 30e81c9ad671..ef0fcb205fb1 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -61,9 +61,9 @@ def has_ext_modules(self): packages=['jaxlib'], python_requires='>=3.10', install_requires=[ - 'scipy>=1.11.1', - 'numpy>=1.25', - 'ml_dtypes>=0.2.0', + 'scipy>=1.12', + 'numpy>=1.26', + 'ml_dtypes>=0.5.0', ], url='https://github.com/jax-ml/jax', license='Apache-2.0', diff --git a/setup.py b/setup.py index 2b50b041008d..6f552b1cf2f4 100644 --- a/setup.py +++ b/setup.py @@ -63,10 +63,9 @@ def load_version_module(pkg_path): install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', 'ml_dtypes>=0.5.0', - 'numpy>=1.25', - "numpy>=1.26.0; python_version>='3.12'", + 'numpy>=1.26', 'opt_einsum', - 'scipy>=1.11.1', + 'scipy>=1.12', ], extras_require={ # Minimum jaxlib version; used in testing. diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 29e6586ffa18..16234463d795 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1938,9 +1938,6 @@ def testDeleteMaskArray(self, shape, dtype, axis): rng = jtu.rand_default(self.rng()) mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - if numpy_version == (1, 23, 0) and mask.shape == (1,): - # https://github.com/numpy/numpy/issues/21840 - self.skipTest("test fails for numpy v1.23.0") args_maker = lambda: [rng(shape, dtype)] np_fun = lambda arg: np.delete(arg, mask, axis=axis) jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index 3da98efce884..6b1c042b049e 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -123,8 +123,6 @@ def testRotationAsQuat(self, shape, dtype): shape=[(4,), (num_samples, 4)], ) def testRotationAsQuatCanonical(self, shape, dtype): - if scipy_version < (1, 11, 0): - self.skipTest("Scipy 1.11.0 added the `canonical` arg.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat(canonical=True) @@ -152,8 +150,6 @@ def testRotationAsQuatScalarFirst(self, shape, dtype): other_shape=[(num_samples, 4)], ) def testRotationConcatenate(self, shape, other_shape, dtype): - if scipy_version < (1, 8, 0): - self.skipTest("Scipy 1.8.0 needed for concatenate.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),) jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_rotvec() @@ -297,8 +293,6 @@ def testRotationInv(self, shape, dtype): shape=[(4,), (num_samples, 4)], ) def testRotationInvConjugate(self, shape, dtype): - if scipy_version < (1, 11, 0): - self.skipTest("Scipy prior to 1.11.0 used a negative conjugate.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat() diff --git 
a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 796d4490daea..e9021b86bb7a 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -20,18 +20,15 @@ import numpy as np import scipy.stats as osp_stats -import scipy.version import jax import jax.numpy as jnp -from jax._src import dtypes, test_util as jtu +from jax._src import test_util as jtu from jax.scipy import stats as lsp_stats from jax.scipy.special import expit jax.config.parse_flags_with_absl() -scipy_version = jtu.parse_version(scipy.version.version) - all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] @@ -217,9 +214,6 @@ def testBernoulliPpf(self, shapes, dtypes): scipy_fun = osp_stats.bernoulli.ppf lax_fun = lsp_stats.bernoulli.ppf - if scipy_version < (1, 9, 2): - self.skipTest("Scipy 1.9.2 needed for fix https://github.com/scipy/scipy/pull/17166.") - def args_maker(): q, p = map(rng, shapes, dtypes) q = expit(q) @@ -1664,9 +1658,6 @@ def evaluate_kde(kde, x): message="All axis-slices of one or more sample arguments are too small", ) def testMode(self, shape, dtype, axis, contains_nans, keepdims): - if scipy_version < (1, 9, 0) and keepdims != True: - self.skipTest("scipy < 1.9.0 only support keepdims == True") - if contains_nans: rng = jtu.rand_some_nan(self.rng()) else: @@ -1675,25 +1666,7 @@ def testMode(self, shape, dtype, axis, contains_nans, keepdims): def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None): """Wrapper to manage the shape discrepancies between scipy and jax""" - if scipy_version < (1, 11, 0) and a.size == 0: - if keepdims: - if axis == None: - output_shape = tuple(1 for _ in a.shape) - else: - output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape)) - else: - if axis == None: - output_shape = () - else: - output_shape = np.delete(np.array(a.shape, dtype=np.int64), axis) - t = dtypes.canonicalize_dtype(jax.numpy.float_) - return (np.full(output_shape, np.nan, dtype=t), - np.zeros(output_shape, dtype=t)) - - if scipy_version < (1, 9, 0): - result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy) - else: - result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims) + result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims) if a.size != 0 and axis == None and keepdims == True: output_shape = tuple(1 for _ in a.shape) @@ -1748,11 +1721,10 @@ def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - kwds = {} if scipy_version < (1, 11) else {'keepdims': keepdims} scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, - **kwds) + keepdims=keepdims) lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, - **kwds) + keepdims=keepdims) tol_spec = {np.float32: 2e-4, np.float64: 5e-6} tol = jtu.tolerance(dtype, tol_spec) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, From 8625207fc634b171e68ac03e62dc22c562e60e2d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 2 Jun 2025 08:16:38 -0700 Subject: [PATCH 1473/1769] Raise a better error when inputs sharded on explicit mesh axes are closed over in a shard_map instead of a crash. 
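As a rough sketch of the behavior this change enables (illustrative only; it assumes a host with at least two devices, and uses `jax.sharding.use_mesh` to set the context mesh the way the test decorator below does):

    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import AxisType, PartitionSpec as P

    # Assumes at least two devices are available.
    mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))

    def f(w, x):
      return jnp.sum(w * x, axis=-1)

    with jax.sharding.use_mesh(mesh):
      w = jnp.ones((2, 4), jnp.float32)  # sharded on the explicit context mesh
      x = jnp.ones((4, 4), jnp.float32)

      # Fine: the explicitly sharded arrays are passed as arguments.
      shard_map(f, in_specs=(P(), P('x')), out_specs=P('x'))(w, x)

      # Now raises NotImplementedError with a workaround suggestion,
      # instead of crashing: `w` is closed over rather than passed in.
      shard_map(lambda xi: f(w, xi), in_specs=P('x'), out_specs=P('x'))(x)
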
Fixes https://github.com/jax-ml/jax/issues/29162 PiperOrigin-RevId: 766199302 --- jax/_src/core.py | 18 ++++++++++++------ tests/shard_map_test.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 24150aba6584..1355fc10472f 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1883,14 +1883,20 @@ def canonicalize_value(val): cur_mesh = mesh_lib.get_abstract_mesh() if cur_mesh == aval.sharding.mesh: return val - # Atleast 1 mesh axis should be Manual and all other axes should be - # Manual or Auto to allow casting. # TODO(yashkatariy): Casting to Explicit is not yet allowed. Maybe we need # cast_and_slice_p for it since shape might change? - if (cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual and - aval.sharding.mesh._are_all_axes_auto): - from jax._src.pjit import mesh_cast # pytype: disable=import-error - return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) + # Atleast 1 mesh axis should be Manual and all other axes should be + # Manual or Auto to allow casting. + if cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual: + if aval.sharding.mesh._are_all_axes_auto: + from jax._src.pjit import mesh_cast # pytype: disable=import-error + return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) + elif aval.sharding.mesh._any_axis_explicit: + raise NotImplementedError( + "Closing over inputs to shard_map where the input is sharded on" + " `Explicit` axes is not implemented. As a workaround, please pass" + " those inputs as an argument to shard_map. Got input with shape" + f" {aval.str_short(True, True)}") return val diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f473a4dc0547..f3f5641be1b6 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2259,6 +2259,36 @@ def grad_fn(batch): params = jnp.copy(arr_sharded) update_fn(params, arr_sharded) # doesn't crash + @jtu.with_explicit_mesh((2,), ('x',)) + def test_close_over_explicit_sharded_input_error(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jnp.ones((2, 4), dtype=np.float32) + x = jnp.ones((4, 4), dtype=np.float32) + + shard_map(simple_func, in_specs=(P(), P('x')), out_specs=P('x'))(w, x) + + with self.assertRaisesRegex( + NotImplementedError, + 'Closing over inputs to shard_map where the input is sharded on' + ' `Explicit` axes is not implemented'): + shard_map(lambda xi: simple_func(w, xi), + in_specs=P('x'), out_specs=P('x'))(x) + + def test_close_over_input_explict_ctx_mesh(self): + mesh = jtu.create_mesh((2,), 'x', axis_types=(AxisType.Explicit,)) + w = jnp.ones((2, 4), dtype=np.float32) + x = jnp.ones((4, 4), dtype=np.float32) + + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + shard_map(simple_func, mesh=mesh, in_specs=(P(), P('x')), + out_specs=P('x'))(w, x) + shard_map(lambda xi: simple_func(w, xi), mesh=mesh, + in_specs=P('x'), out_specs=P('x'))(x) + def test_shmap_close_over_unused_params_vmap(self): mesh = jtu.create_mesh((2,), ("data",)) From d62d94cb8578c37303e6a071bbe30b3395f84847 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 2 Jun 2025 12:32:26 -0400 Subject: [PATCH 1474/1769] Add a pretty printing rule for custom_lin_p. 
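A minimal sketch of where the new rule shows up, mirroring the test added below (the printed jaxpr is abbreviated in the comments):

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x):
      return x + 1

    def f_fwd(x):
      return f(x), ()

    def f_bwd(_, g):
      return (g,)

    f.defvjp(f_fwd, f_bwd)

    x = jnp.array([4.2], jnp.float32)
    jaxpr = jax.make_jaxpr(lambda x: jax.jvp(f, (x,), (x,)))(x)
    # The custom_lin equation now prints `bwd=f_bwd` (the function's name)
    # and omits the noisy `out_avals` parameter.
    print(jaxpr)
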
--- jax/_src/interpreters/ad.py | 8 ++++++++ tests/custom_api_test.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 0cd99a197f66..a77e93bb0696 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -1243,6 +1243,14 @@ def _custom_lin_transpose(cts_out, *invals, num_res, return [None] * num_res + nz_cts_in primitive_transposes[custom_lin_p] = _custom_lin_transpose +def _custom_lin_pp_rule(eqn: core.JaxprEqn, context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + params.pop("out_avals") + params["bwd"] = params.pop("bwd").debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings) +core.pp_eqn_rules[custom_lin_p] = _custom_lin_pp_rule + class CustomJVPException(Exception): def __init__(self): diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 45434923d543..61b0129aca3e 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -3199,6 +3199,35 @@ def f_bwd(_, g): """).strip() self.assertEqual(actual, expected) + def test_custom_lin_pretty_print(self): + @jax.custom_vjp + def f(x): + return x + 1 + + def f_fwd(x): + return f(x), () + + def f_bwd(_, g): + return g + f.defvjp(f_fwd, f_bwd) + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(lambda x: jax.jvp(f, (x,), (x,)))(x) + jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, [False, True]) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_lin[ + bwd=f_bwd + in_zeros=[False] + num_res=0 + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + def transpose_unary(f, x_example): def transposed(y): From 8eaa9bf19ae0245bb518a3ef9479842cf761be67 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Jun 2025 09:49:21 -0700 Subject: [PATCH 1475/1769] [cleanup] inline uses of NumpyComplexWarning --- jax/_src/lax/lax.py | 4 ++-- jax/_src/numpy/lax_numpy.py | 4 ++-- jax/_src/util.py | 3 --- tests/lax_autodiff_test.py | 3 +-- tests/lax_metal_test.py | 10 +++++----- tests/lax_numpy_indexing_test.py | 3 +-- tests/lax_numpy_reducers_test.py | 33 ++++++++++++++++---------------- tests/lax_numpy_test.py | 10 +++++----- tests/lax_test.py | 4 ++-- 9 files changed, 34 insertions(+), 40 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 68363d10bc04..e03951eb4730 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -70,7 +70,7 @@ ShardingContext, SPMDAxisContext, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape -from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis, +from jax._src.util import (cache, canonicalize_axis, safe_map, safe_zip, split_list, weakref_lru_cache, foreach) @@ -1706,7 +1706,7 @@ def _convert_element_type( dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" - warnings.warn(msg, NumpyComplexWarning, stacklevel=2) + warnings.warn(msg, np.exceptions.ComplexWarning, stacklevel=2) # Python has big integers, but convert_element_type(2 ** 100, np.float32) need # not be an error since the target dtype fits the value. 
Handle this case by diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ad2b3ad6aa75..b21fcfeb5772 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -62,7 +62,7 @@ Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, SupportsShape ) from jax._src.util import ( - NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, + canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding from jax._src.sharding_impls import NamedSharding, PartitionSpec as P @@ -171,7 +171,7 @@ def _dtype(x: Any) -> DType: can_cast = dtypes.can_cast promote_types = dtypes.promote_types -ComplexWarning = NumpyComplexWarning +ComplexWarning = np.exceptions.ComplexWarning _lax_const = lax_internal._const diff --git a/jax/_src/util.py b/jax/_src/util.py index 4100ac21dc00..71d8f8bfa6a1 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -642,9 +642,6 @@ def decorator(f): return decorator -NumpyComplexWarning: type[Warning] = np.exceptions.ComplexWarning - - class StrictABCMeta(abc.ABCMeta): """A variant of `abc.ABCMeta` which does not allow virtual subclasses. diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index aea9d2ad3dff..a6398e402df9 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -28,7 +28,6 @@ from jax import dtypes from jax import lax from jax._src import test_util as jtu -from jax._src.util import NumpyComplexWarning from jax.test_util import check_grads jax.config.parse_flags_with_absl() @@ -244,7 +243,7 @@ def testConvertElementTypeGrad(self, from_dtype, to_dtype): jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) args = (rng((2, 3), from_dtype),) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) - convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)( + convert_element_type = jtu.ignore_warning(category=np.exceptions.ComplexWarning)( convert_element_type) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) 
diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 5f1781c3be06..e44ff9ebc930 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -48,7 +48,7 @@ from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning +from jax._src.util import safe_zip try: from jax_plugins import metal_plugin @@ -2099,11 +2099,11 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_default(self.rng()) np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -2127,11 +2127,11 @@ def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_some_nan(self.rng()) np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index ca9ba9c88806..745cab59cf1b 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -35,7 +35,6 @@ from jax._src import test_util as jtu from jax._src import util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning config.parse_flags_with_absl() @@ -1186,7 +1185,7 @@ def _check(x_type, y_type): out = x.at[0].set(y) self.assertEqual(x.dtype, out.dtype) - @jtu.ignore_warning(category=NumpyComplexWarning, + @jtu.ignore_warning(category=np.exceptions.ComplexWarning, message="Casting complex values to real") def _check_warns(x_type, y_type, msg): with self.assertWarnsRegex(FutureWarning, msg): diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index aa5e08e96a3e..93aff25c6f8e 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -29,7 +29,6 @@ from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu -from jax._src.util import NumpyComplexWarning config.parse_flags_with_absl() @@ -209,7 +208,7 @@ def testReducer(self, name, rng_factory, shape, dtype, out_dtype, np_op = getattr(np, name) jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jtu.ignore_warning(category=RuntimeWarning, message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, @@ -225,7 +224,7 @@ def np_fun(x): return np_op(x_cast, axis, dtype=t, keepdims=keepdims) jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) - jnp_fun = 
jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol_spec = {np.float16: 1e-2, np.int16: 2e-7, np.int32: 1E-3, np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3, @@ -313,7 +312,7 @@ def testReducerInitial(self, name, rng_factory, shape, dtype, axis, is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -324,7 +323,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol, atol=tol) @@ -353,7 +352,7 @@ def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis, rng_factory.__name__ == 'rand_some_nan') @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -364,7 +363,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype, promote_integers)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2, jnp.float16: 5e-3} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) @@ -390,7 +389,7 @@ def testReducerNoInitialZeroDims(self, name, rng_factory, shape, dtype, axis, is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -401,7 +400,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) @@ -436,7 +435,7 @@ def testReducerWhere(self, name, rng_factory, shape, dtype, axis, where = jtu.rand_bool(self.rng())(whereshape, np.bool_) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -447,7 +446,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where) - jnp_fun = 
jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) @@ -499,7 +498,7 @@ def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis, message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -510,7 +509,7 @@ def np_fun(x): return res jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) @@ -574,7 +573,7 @@ def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): # setup ddof and correction kwargs excluding case when correction is not specified ddof_correction_kwargs = {"ddof": ddof} @@ -625,7 +624,7 @@ def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): # Numpy fails with bfloat16 inputs out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), @@ -834,7 +833,7 @@ def test_f16_mean(self, dtype): ], include_initial=[False, True], ) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): rng = jtu.rand_some_zero(self.rng()) @@ -902,7 +901,7 @@ def testCumulativeSumBool(self): ], include_initial=[False, True], ) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): if jtu.is_device_tpu_at_least(6): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 16234463d795..80d1d4161cc5 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -50,7 +50,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning, tuple_update +from jax._src.util import safe_zip, tuple_update config.parse_flags_with_absl() @@ -2354,11 +2354,11 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_default(self.rng()) np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = 
jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -2382,11 +2382,11 @@ def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_some_nan(self.rng()) np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] diff --git a/tests/lax_test.py b/tests/lax_test.py index 6792b08c37fa..a11c989fc9c5 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -48,7 +48,7 @@ from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning, safe_zip +from jax._src.util import safe_zip from jax._src.tree_util import tree_map config.parse_flags_with_absl() @@ -3744,7 +3744,7 @@ def testConvertElementReturnType(self, input_type, dtype, value, jit): @jtu.sample_product( dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out): x = jax.device_put(np.zeros(5, dtype_in)) self.assertEqual(x.dtype, dtype_in) From 980f5dc49a09a3b4dc8e4870dfd838854f356b85 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 2 Jun 2025 09:27:21 -0700 Subject: [PATCH 1476/1769] always compile Pallas calls, enabling `pallas_call` under `disable_jit` How should `pallas_call` behave under `disable_jit`? We could: a) Error because we don't know what's expected. b) Run `pallas_call` as we would in eager mode. c) Run `pallas_call` in interpret mode. Today we do (a). This change implements (b) instead, where by "eager mode" we mean the behavior under no `jit` (and no `disable_jit` under it). Choice (c) seems to take things too far. On the one hand, it would execute the Pallas kernel "op-by-op," which may seem desirable. On the other hand, it would execute each individual such op on the host (CPU) rather than the device (accelerator). If we could do the first "half" of (c) alone somehow -- executing op-by-op, dispatching each op on the device, that may be ideal. Until then, (b) seems closer to expectations. To implement (b), this change simply re-enables jit only for the duration of the pallas call. 
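A small sketch of choice (b) in action, adapted from the test added below (assumes a backend with Pallas lowering support, e.g. TPU or GPU):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import pallas as pl

    def add_one(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.

    kernel = pl.pallas_call(
        add_one, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32))

    x = jnp.arange(8, dtype=jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.)

    with jax.disable_jit():
      # Previously a NotImplementedError; now the call is compiled anyway,
      # with jit re-enabled just for the duration of the pallas_call.
      np.testing.assert_array_equal(kernel(x), x + 1.)
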
--- jax/_src/pallas/pallas_call.py | 10 +++------ tests/pallas/pallas_test.py | 41 +++++++++++++--------------------- 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 6f8c96a4591c..016bac96424e 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -77,16 +77,12 @@ def _pallas_call_impl(*args, **params): # Call the lowering path - if config.disable_jit.value: - raise NotImplementedError( - "pallas_call not supported with disable_jit. Consider invoking under a" - " local context of `jax.disable_jit(False)`." - ) - @partial(jax.jit, inline=True) def _jit_run(*args): return pallas_call_p.bind(*args, **params) - return _jit_run(*args) + + with config.disable_jit(False): + return _jit_run(*args) pallas_call_p.def_impl(_pallas_call_impl) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 03399e12b609..cb61d5648912 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -692,6 +692,22 @@ def f(x): self.assertEqual(f(x), 2.) self.assertEqual(trace_count, 1) + def test_pallas_call_under_disable_jit(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1. + + x = jnp.arange(8, dtype=jnp.float32) + + result = add_one(x) + np.testing.assert_array_equal(result, x + 1.) + + with jax.disable_jit(): + result = add_one(x) + np.testing.assert_array_equal(result, x + 1.) + @parameterized.parameters( ("float32", None), ("float32", jax.lax.Precision.DEFAULT), @@ -1261,31 +1277,6 @@ def dot_general_kernel(x_ref, y_ref, o_ref): ): dot_general_kernel(x, y) - def test_jax_disable_jit(self): - def add_vectors_kernel(x_ref, y_ref, o_ref): - x, y = x_ref[...], y_ref[...] - o_ref[...] = x + y - - @jax.jit - def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: - return self.pallas_call( - add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) - - # Prove kernel works fine without disable_jit. - add_vectors(jnp.arange(8), jnp.arange(8)) - - with self.assertRaisesRegex( - NotImplementedError, "pallas_call not supported with disable_jit." - ): - with jax.disable_jit(): - add_vectors(jnp.arange(8.0), jnp.arange(8.0)) - - with jax.disable_jit(): - # We instructed the user to do this, so this should not raise an error. 
- with jax.disable_jit(False): - add_vectors(jnp.arange(8.0), jnp.arange(8.0)) - class ApiErrorInterpretTest(ApiErrorTest): INTERPRET = True From a43ccbb9df2a8ed6acb5b9837116cbe3b6a298df Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 2 Jun 2025 10:47:49 -0700 Subject: [PATCH 1477/1769] =?UTF-8?q?Fix=20native=20tiling=20logic=20in=20?= =?UTF-8?q?infer=5Fvector=5Flayout.=20For=20the=20pattern=20`arith.trunci`?= =?UTF-8?q?=20->=C2=A0`tpu.bitcast`=20->=20`tpu.matmul`,=20there=20will=20?= =?UTF-8?q?be=20a=20`tpu.relayout`=20op=20after=20`arith.trunci`=20before?= =?UTF-8?q?=20the=20fix,=20which=20has=20a=20negative=20impact=20on=20the?= =?UTF-8?q?=20performance=20e2e.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PiperOrigin-RevId: 766256767 --- .../tpu/transforms/infer_vector_layout.cc | 11 ++++----- jaxlib/mosaic/dialect/tpu/util.cc | 24 +++++++++++++++++++ jaxlib/mosaic/dialect/tpu/util.h | 6 +++++ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 14d1fb2104fa..17575183bd81 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1952,12 +1952,11 @@ class VectorLayoutInferer { } bool allUsersRequireNativeTiling(Value x) { - for (OpOperand &operand : x.getUses()) { - if (isa(operand.getOwner())) { + for (Operation *user : getNontrivialTransitiveUsers(x)) { + if (isa(user)) { continue; } - if (auto reduce = - dyn_cast(operand.getOwner())) { + if (auto reduce = dyn_cast(user)) { bool reduces_tiled_dims = false; for (int64_t dim : reduce.getReductionDims()) { if (dim >= reduce.getSourceVectorType().getRank() - 2) { @@ -1969,7 +1968,7 @@ class VectorLayoutInferer { continue; } } - if (auto transpose = dyn_cast(operand.getOwner())) { + if (auto transpose = dyn_cast(user)) { auto perm = transpose.getPermutation(); auto rank = perm.size(); // Only permutations that actually swap the last two dims need it. @@ -1979,7 +1978,7 @@ class VectorLayoutInferer { } // Fall through. } - if (auto store = dyn_cast(operand.getOwner())) { + if (auto store = dyn_cast(user)) { auto maybe_tiling = verifyMemoryTiling( store, getMemRefLayout(store.getBase()).getTiles(), store.getMemRefType().getRank(), diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 02598bd16f9a..0e67b4299f7e 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" @@ -313,4 +314,27 @@ bool canFoldMinorDimsToSize(ArrayRef shape, int64_t target_size) { return product == target_size; } +SmallVector getNontrivialTransitiveUsers(Value v) { + auto isUnaryElementwise = [](Operation *op) { + if (!op->hasTrait()) { + return false; + } + return op->getNumOperands() == 1 && op->getNumResults() == 1; + }; + SmallVector users; + SmallVector candidates; + candidates.push_back(v); + while (!candidates.empty()) { + Value candidate = candidates.back(); + candidates.pop_back(); + for (const auto &user : candidate.getUsers()) { + if (isa(user) || isUnaryElementwise(user)) + candidates.push_back(user->getResult(0)); + else + users.push_back(user); + } + } + return users; +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 2a7325ee7b24..af590f45f619 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -290,6 +290,12 @@ std::optional getIntConst(Value v); // Precondition: `shape` has at least 2 dimensions. bool canFoldMinorDimsToSize(ArrayRef shape, int64_t target_size); +// Recursively finds all non-trivial users of a given value, including those +// accessed via `tpu.bitcast` or unary elementwise operations. However, +// `tpu.bitcast` and unary element-wise operations are excluded from the +// results. +SmallVector getNontrivialTransitiveUsers(Value v); + } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ From 62ab725780bd4f3bf0af54bd60a7360de00205aa Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Mon, 2 Jun 2025 11:20:00 -0700 Subject: [PATCH 1478/1769] Update workflow files to use new ml-build containers. 
PiperOrigin-RevId: 766270515 --- .github/workflows/bazel_cpu_py_import_rbe.yml | 2 +- .github/workflows/bazel_cpu_rbe.yml | 2 +- .github/workflows/build_artifacts.yml | 2 +- .github/workflows/pytest_cpu.yml | 2 +- ci/envs/docker.env | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml index cc3ae89d97f9..65a7b7b6a01f 100644 --- a/.github/workflows/bazel_cpu_py_import_rbe.yml +++ b/.github/workflows/bazel_cpu_py_import_rbe.yml @@ -38,7 +38,7 @@ jobs: shell: bash runs-on: ${{ inputs.runner }} container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || - (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} + (contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }} env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 3eff0932adcb..71c140464454 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -29,7 +29,7 @@ jobs: if: github.event.repository.fork == false runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} + (contains(matrix.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }} env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 95ab90412494..ece2237eeead 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -90,7 +90,7 @@ jobs: runs-on: ${{ inputs.runner }} container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || - (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} env: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 95086257c62b..d23c1f543827 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -47,7 +47,7 @@ jobs: shell: bash runs-on: ${{ inputs.runner }} container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || - (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 5135b61ac45b..d556cb82d74d 
100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -35,11 +35,11 @@ fi # Linux Aarch64 image for building JAX artifacts, running Pytests CPU tests, and # Bazel tests if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest" fi # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel # tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/tf-test-windows:latest" fi \ No newline at end of file From e9925ee03d5b5d442830444a0b10ac6f3486e1be Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Mon, 2 Jun 2025 12:05:05 -0700 Subject: [PATCH 1479/1769] Enable profiler_test for TPU's PiperOrigin-RevId: 766289665 --- tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/BUILD b/tests/BUILD index e3672eb73f48..75946713a49e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1136,6 +1136,7 @@ jax_multiplatform_test( enable_backends = [ "cpu", "gpu", + "tpu", ], deps = [ "//jax:profiler", From 3e52872cb6fcb73455553cb5ebfe51c268f4b19c Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 2 Jun 2025 12:29:10 -0700 Subject: [PATCH 1480/1769] Clean up some unused GPU linear algebra kernels. This change removes the legacy `csrlsvqr` and `sytrd` custom calls from jaxlib. These were never covered by the export compatibility policy, and their FFI counterparts have been targeted by JAX for several releases. PiperOrigin-RevId: 766298494 --- jaxlib/cuda/BUILD | 26 ---- jaxlib/gpu/BUILD | 2 - jaxlib/gpu/gpu_kernels.cc | 3 - jaxlib/gpu/solver.cc | 87 ------------ jaxlib/gpu/solver_kernels.cc | 255 ----------------------------------- jaxlib/gpu/solver_kernels.h | 65 --------- jaxlib/rocm/BUILD | 26 ---- 7 files changed, 464 deletions(-) delete mode 100644 jaxlib/gpu/solver_kernels.cc delete mode 100644 jaxlib/gpu/solver_kernels.h diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index eabb3157ecca..7bcb526e6e38 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -155,24 +155,6 @@ cc_library( ], ) -cc_library( - name = "cusolver_kernels", - srcs = ["//jaxlib/gpu:solver_kernels.cc"], - hdrs = ["//jaxlib/gpu:solver_kernels.h"], - deps = [ - ":cuda_gpu_kernel_helpers", - ":cuda_solver_handle_pool", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", - ], -) - cc_library( name = "cusolver_interface", srcs = ["//jaxlib/gpu:solver_interface.cc"], @@ -223,21 +205,14 @@ nanobind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":cuda_gpu_kernel_helpers", - ":cuda_solver_handle_pool", ":cuda_vendor", - ":cusolver_kernels", ":cusolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", "@nanobind", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -472,7 
+447,6 @@ cc_library( ":cuda_prng_kernels", ":cuda_vendor", ":cudnn_rnn_kernels", - ":cusolver_kernels", ":cusolver_kernels_ffi", ":cusparse_kernels", ":triton_kernels", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index e153e0588cf6..98f0f6cfe624 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -59,8 +59,6 @@ exports_files(srcs = [ "solver_handle_pool.h", "solver_interface.cc", "solver_interface.h", - "solver_kernels.cc", - "solver_kernels.h", "solver_kernels_ffi.cc", "solver_kernels_ffi.h", "sparse.cc", diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index c59cc7d8076b..3204053b8822 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -19,7 +19,6 @@ limitations under the License. #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" -#include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/solver_kernels_ffi.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/triton_kernels.h" @@ -40,14 +39,12 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", SyrkFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", CsrlsvqrFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", SyevdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA", SytrdFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA", diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index e4d6b5d4dedf..08d25948d893 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -13,21 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/solver_handle_pool.h" -#include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/solver_kernels_ffi.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -35,79 +24,8 @@ namespace { namespace nb = nanobind; -// Converts a NumPy dtype to a Type. -SolverType DtypeToSolverType(const dtype& np_type) { - static auto* types = - new absl::flat_hash_map, SolverType>({ - {{'f', 4}, SolverType::F32}, - {{'f', 8}, SolverType::F64}, - {{'c', 8}, SolverType::C64}, - {{'c', 16}, SolverType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -#ifdef JAX_GPU_CUDA - -// csrlsvqr: Linear system solve via Sparse QR - -// Returns a descriptor for a csrlsvqr operation. 
-nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, - int reorder, double tol) { - SolverType type = DtypeToSolverType(dtype); - return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol}); -} - -#endif // JAX_GPU_CUDA - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, - int b, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(SytrdDescriptor{type, uplo, b, n, n, lwork})}; -} - nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_sytrd"] = EncapsulateFunction(Sytrd); - -#ifdef JAX_GPU_CUDA - dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); -#endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); @@ -127,12 +45,7 @@ nb::dict Registrations() { } NB_MODULE(_solver, m) { - tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_sytrd_descriptor", &BuildSytrdDescriptor); -#ifdef JAX_GPU_CUDA - m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor); -#endif // JAX_GPU_CUDA } } // namespace diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc deleted file mode 100644 index d054e77d2102..000000000000 --- a/jaxlib/gpu/solver_kernels.cc +++ /dev/null @@ -1,255 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "jaxlib/gpu/solver_kernels.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/solver_handle_pool.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -#ifdef JAX_GPU_CUDA -#include "third_party/gpus/cuda/include/cusolverSp.h" -#endif // JAX_GPU_CUDA - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -static int SizeOfSolverType(SolverType type) { - switch (type) { - case SolverType::F32: - return sizeof(float); - case SolverType::F64: - return sizeof(double); - case SolverType::C64: - return sizeof(gpuComplex); - case SolverType::C128: - return sizeof(gpuDoubleComplex); - } -} - -#ifdef JAX_GPU_CUDA - -// csrlsvqr: Linear system solve via Sparse QR - -static absl::Status Csrlsvqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - int& singularity) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const CsrlsvqrDescriptor& d = **s; - - // This is the handle to the CUDA session. Gets a cusolverSp handle. - auto h = SpSolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - cusparseMatDescr_t matdesc = nullptr; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateMatDescr(&matdesc))); - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO))); - - switch (d.type) { - case SolverType::F32: { - float* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - float* b = static_cast(buffers[3]); - float* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpScsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::F64: { - double* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - double* b = static_cast(buffers[3]); - double* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpDcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::C64: { - gpuComplex* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - gpuComplex* b = static_cast(buffers[3]); - gpuComplex* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::C128: { - gpuDoubleComplex* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - gpuDoubleComplex* b = static_cast(buffers[3]); - gpuDoubleComplex* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpZcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - } - - cusparseDestroyMatDescr(matdesc); - return absl::OkStatus(); -} - -void 
Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - // Is >= 0 if A is singular. - int singularity = -1; - - auto s = Csrlsvqr_(stream, buffers, opaque, opaque_len, singularity); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } - - if (singularity >= 0) { - auto s = std::string("Singular matrix in linear solve."); - XlaCustomCallStatusSetFailure(status, s.c_str(), s.length()); - } -} - -#endif // JAX_GPU_CUDA - -// sytrd/hetrd: symmetric (Hermitian) tridiagonal reduction - -static absl::Status Sytrd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SytrdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.n) * static_cast(d.lda), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[5]); - void* workspace = buffers[6]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* d_out = static_cast(buffers[2]); - float* e_out = static_cast(buffers[3]); - float* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* d_out = static_cast(buffers[2]); - double* e_out = static_cast(buffers[3]); - double* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* d_out = static_cast(buffers[2]); - float* e_out = static_cast(buffers[3]); - gpuComplex* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* d_out = static_cast(buffers[2]); - double* e_out = static_cast(buffers[3]); - gpuDoubleComplex* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Sytrd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Sytrd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, 
std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h deleted file mode 100644 index c325e746b709..000000000000 --- a/jaxlib/gpu/solver_kernels.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_CUSOLVER_KERNELS_H_ -#define JAXLIB_CUSOLVER_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. -enum class SolverType { - F32, - F64, - C64, - C128, -}; - -#ifdef JAX_GPU_CUDA - -// csrlsvpr: Linear system solve via Sparse QR - -struct CsrlsvqrDescriptor { - SolverType type; - int n, nnz, reorder; - double tol; -}; - -void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#endif // JAX_GPU_CUDA - -// sytrd/hetrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. -struct SytrdDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n, lda, lwork; -}; - -void Sytrd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_CUSOLVER_KERNELS_H_ diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index d0468d72d1b3..76e3ef01563c 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -143,24 +143,6 @@ cc_library( ], ) -cc_library( - name = "hipsolver_kernels", - srcs = ["//jaxlib/gpu:solver_kernels.cc"], - hdrs = ["//jaxlib/gpu:solver_kernels.h"], - deps = [ - ":hip_gpu_kernel_helpers", - ":hip_solver_handle_pool", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipsolver", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", - ], -) - cc_library( name = "hipsolver_interface", srcs = ["//jaxlib/gpu:solver_interface.cc"], @@ -195,7 +177,6 @@ cc_library( "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -209,20 +190,13 @@ nanobind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":hip_gpu_kernel_helpers", - ":hip_solver_handle_pool", ":hip_vendor", - ":hipsolver_kernels", ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", ], ) From 
674fb5b57797f2686f1502b4e41cd2ffa3adbb26 Mon Sep 17 00:00:00 2001 From: sora <210at85@gmail.com> Date: Sun, 1 Jun 2025 19:37:11 +0200 Subject: [PATCH 1481/1769] Simplify `jnp.isclose` - Clean up implementation of `isclose` to match NumPy 2.* behavior - Add tests for corner cases, kindly provided by @jakevdp --- jax/_src/numpy/lax_numpy.py | 28 +++++----------------------- tests/lax_numpy_test.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ad2b3ad6aa75..f42f11844783 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2602,31 +2602,13 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike dtype = np.array(0, dtype).real.dtype rtol = lax.convert_element_type(rtol, dtype) atol = lax.convert_element_type(atol, dtype) - out = lax.le( + both_nan = ufuncs.logical_and(ufuncs.isnan(a), ufuncs.isnan(b)) + check_fin = ufuncs.isfinite(b) + in_range = lax.le( lax.abs(lax.sub(a, b)), lax.add(atol, lax.mul(rtol, lax.abs(b)))) - # This corrects the comparisons for infinite and nan values - a_inf = ufuncs.isinf(a) - b_inf = ufuncs.isinf(b) - any_inf = ufuncs.logical_or(a_inf, b_inf) - both_inf = ufuncs.logical_and(a_inf, b_inf) - # Make all elements where either a or b are infinite to False - out = ufuncs.logical_and(out, ufuncs.logical_not(any_inf)) - # Make all elements where both a or b are the same inf to True - same_value = lax.eq(a, b) - same_inf = ufuncs.logical_and(both_inf, same_value) - out = ufuncs.logical_or(out, same_inf) - - # Make all elements where either a or b is NaN to False - a_nan = ufuncs.isnan(a) - b_nan = ufuncs.isnan(b) - any_nan = ufuncs.logical_or(a_nan, b_nan) - out = ufuncs.logical_and(out, ufuncs.logical_not(any_nan)) - if equal_nan: - # Make all elements where both a and b is NaN to True - both_nan = ufuncs.logical_and(a_nan, b_nan) - out = ufuncs.logical_or(out, both_nan) - return out + out = ufuncs.logical_or(lax.eq(a, b), ufuncs.logical_and(check_fin, in_range)) + return ufuncs.logical_or(out, both_nan) if equal_nan else out def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 16234463d795..60f3bbb09edf 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3963,6 +3963,22 @@ def testIsClose(self): key = jax.random.key(0) self.assertTrue(jnp.isclose(key, key)) + @jtu.sample_product( + atol=[0.0, 1E-4, np.inf], + rtol=[0.0, 1E-4, np.inf], + equal_nan=[True, False] + ) + def testIsCloseCornerCases(self, atol, rtol, equal_nan): + if jtu.numpy_version() < (2, 0, 0) and (np.isinf(atol) or np.isinf(rtol)): + self.skipTest("fails on older NumPy") + vals = np.array([-np.nan, -np.inf, -1.00001, -1.0, -0.00001, -0.0, + 0.0, 0.00001, 1.0, 1.00001, np.inf, np.nan]) + x, y = np.meshgrid(vals, vals) + self.assertArraysEqual( + np.isclose(x, y, atol=atol, rtol=rtol, equal_nan=equal_nan), + jnp.isclose(x, y, atol=atol, rtol=rtol, equal_nan=equal_nan) + ) + @jtu.sample_product( x=[1, [1], [1, 1 + 1E-4], [1, np.nan]], y=[1, [1], [1, 1 + 1E-4], [1, np.nan]], From 6f0c2a8d644a4ae4ee3cbade66ddaee2ac6b32a5 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 2 Jun 2025 13:05:20 -0700 Subject: [PATCH 1482/1769] Clean up some unused GPU sparse kernels. These kernels aren't covered by the export compatibility policy, and their FFI counterparts have been targeted by JAX for several releases. 
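For context, every sparse entry point that survives this change is an FFI
handler. A quick way to confirm from Python (a minimal sketch; the
`jaxlib.cuda._sparse` package path is an assumption, though the `_sparse`
module name and its `registrations()` binding match the code in this diff):

    # Sketch: every cuSPARSE target that remains registered should now
    # carry the `_ffi` suffix, e.g. "cusparse_csr_todense_ffi".
    from jaxlib.cuda import _sparse  # assumed package path

    for name in _sparse.registrations():
        assert name.endswith("_ffi"), name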
PiperOrigin-RevId: 766310963 --- jaxlib/cuda/BUILD | 6 -- jaxlib/gpu/gpu_kernels.cc | 40 +++++----- jaxlib/gpu/sparse.cc | 127 ------------------------------ jaxlib/gpu/sparse_kernels.cc | 148 ----------------------------------- jaxlib/gpu/sparse_kernels.h | 57 +------------- jaxlib/rocm/BUILD | 6 -- 6 files changed, 22 insertions(+), 362 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 7bcb526e6e38..33ed84b08753 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -244,7 +244,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", ], @@ -264,7 +263,6 @@ nanobind_extension( ":cuda_vendor", ":cusparse_kernels", "//jaxlib:absl_status_casters", - "//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -272,13 +270,11 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@nanobind", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", "@xla//xla/tsl/python/lib/core:numpy", @@ -354,7 +350,6 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -370,7 +365,6 @@ cuda_library( "//jaxlib:kernel_helpers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 3204053b8822..8428562e3248 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -59,28 +59,26 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_lu_pivots_to_permutation", XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi", "CUDA", ThreeFry2x32Ffi); -#if JAX_CUSPARSE_11300 -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matvec", CsrMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matmat", CsrMatmat, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_todense", CooToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_fromdense", CooFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matvec", CooMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matmat", CooMatmat, - "CUDA"); +#if JAX_GPU_HAVE_SPARSE +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_todense_ffi", "CUDA", + CsrToDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_fromdense_ffi", "CUDA", + CsrFromDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_matvec_ffi", "CUDA", + CsrMatvecFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_matmat_ffi", "CUDA", + CsrMatmatFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_todense_ffi", "CUDA", + CooToDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_fromdense_ffi", "CUDA", + CooFromDenseFfi); 
+XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_matvec_ffi", "CUDA", + CooMatvecFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_matmat_ffi", "CUDA", + CooMatmatFfi); #endif -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f32", gtsv2_f32, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64, - "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_gtsv2_ffi", "CUDA", + kGtsv2); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("triton_kernel_call", TritonKernelCall, "CUDA"); diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 0190ba776de5..21f567e79f92 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -15,11 +15,9 @@ limitations under the License. #include #include -#include #include #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" #include "absl/strings/str_format.h" #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" // IWYU pragma: keep @@ -27,9 +25,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/service/custom_call_status.h" #include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; @@ -146,45 +142,6 @@ std::pair BuildCsrToDenseDescriptor(const dtype& data_dtype, return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - gpusparseSpMatDescr_t mat_a = 0; - gpusparseDnMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[2], - /*csrColInd=*/buffers[1], - /*csrValues=*/buffers[0], d.index_type, d.index_type, - GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( - &mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseSparseToDense(handle.get(), mat_a, mat_b, - GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); - return absl::OkStatus(); -} - -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrFromDense: Convert dense matrix to CSR matrix // Returns the descriptor for a CsrFromDense operation. 
@@ -221,46 +178,6 @@ std::pair BuildCsrFromDenseDescriptor( return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - gpusparseDnMatDescr_t mat_a = 0; - gpusparseSpMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( - &mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[3], - /*csrColInd=*/buffers[2], - /*csrValues=*/buffers[1], d.index_type, d.index_type, - GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis( - handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert( - handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b))); - return absl::OkStatus(); -} - -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatvec: Product of CSR matrix and dense vector. // Returns the descriptor for a CsrMatvec operation. 
@@ -553,44 +470,9 @@ std::pair BuildCooMatmatDescriptor( #endif // if JAX_GPU_HAVE_SPARSE -nb::bytes BuildGtsv2Descriptor(int b, int m, int n, int ldb) { - return PackDescriptor(Gtsv2Descriptor{b, m, n, ldb}); -} - -template -size_t Gtsv2BufferSize(F f, int m, int n, int ldb) { - auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - size_t size; - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr, - /*du=*/nullptr, /*B=*/nullptr, ldb, &size))); - return size; -} - -size_t Gtsv2BufferSizeF32(int m, int n, int ldb) { - return Gtsv2BufferSize(gpusparseSgtsv2_bufferSizeExt, m, n, ldb); -} - -size_t Gtsv2BufferSizeF64(int m, int n, int ldb) { - return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb); -} - nb::dict Registrations() { nb::dict dict; #if JAX_GPU_HAVE_SPARSE - dict[JAX_GPU_PREFIX "sparse_csr_todense"] = EncapsulateFunction(CsrToDense); - dict[JAX_GPU_PREFIX "sparse_csr_fromdense"] = - EncapsulateFunction(CsrFromDense); - dict[JAX_GPU_PREFIX "sparse_csr_matvec"] = EncapsulateFunction(CsrMatvec); - dict[JAX_GPU_PREFIX "sparse_csr_matmat"] = EncapsulateFunction(CsrMatmat); - dict[JAX_GPU_PREFIX "sparse_coo_todense"] = EncapsulateFunction(CooToDense); - dict[JAX_GPU_PREFIX "sparse_coo_fromdense"] = - EncapsulateFunction(CooFromDense); - dict[JAX_GPU_PREFIX "sparse_coo_matvec"] = EncapsulateFunction(CooMatvec); - dict[JAX_GPU_PREFIX "sparse_coo_matmat"] = EncapsulateFunction(CooMatmat); - dict[JAX_GPU_PREFIX "sparse_csr_todense_ffi"] = EncapsulateFfiHandler(CsrToDenseFfi); dict[JAX_GPU_PREFIX "sparse_csr_fromdense_ffi"] = @@ -608,12 +490,6 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "sparse_coo_matmat_ffi"] = EncapsulateFfiHandler(CooMatmatFfi); #endif - dict[JAX_GPU_PREFIX "sparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f32_ffi"] = - EncapsulateFfiHandler(gtsv2_f32_ffi); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f64_ffi"] = - EncapsulateFfiHandler(gtsv2_f64_ffi); dict[JAX_GPU_PREFIX "sparse_gtsv2_ffi"] = EncapsulateFfiHandler(kGtsv2); // TODO(tomhennigan): Add support for gtsv2 complex 32/64. @@ -634,9 +510,6 @@ NB_MODULE(_sparse, m) { m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor); m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor); #endif - m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32); - m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64); - m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor); } } // namespace diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 363321e3ca8b..139fbc73f8ce 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -32,7 +32,6 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" #define JAX_FFI_RETURN_IF_GPU_ERROR(...) 
\ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) @@ -189,15 +188,6 @@ static absl::Status CsrToDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrToDenseFfi, CsrToDense_); -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrFromDense: Convert dense matrix to CSR matrix static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, @@ -233,15 +223,6 @@ static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrFromDenseFfi, CsrFromDense_); -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatvec: Product of CSR matrix and dense vector. static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, @@ -292,15 +273,6 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrMatvecFfi, CsrMatvec_); -void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatmat: Product of CSR matrix and dense matrix. static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, @@ -352,15 +324,6 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrMatmatFfi, CsrMatmat_); -void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooToDense: Convert COO matrix to dense matrix static absl::Status CooToDense_(gpuStream_t stream, void** buffers, @@ -395,15 +358,6 @@ static absl::Status CooToDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooToDenseFfi, CooToDense_); -void CooToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooFromDense: Convert dense matrix to COO matrix static absl::Status CooFromDense_(gpuStream_t stream, void** buffers, @@ -439,15 +393,6 @@ static absl::Status CooFromDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooFromDenseFfi, CooFromDense_); -void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooMatvec: Product of COO matrix and dense vector. 
static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, @@ -497,15 +442,6 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooMatvecFfi, CooMatvec_); -void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooMatmat: Product of COO matrix and dense matrix. static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, @@ -564,92 +500,8 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, } JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooMatmatFfi, CooMatmat_); - -void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} #endif // if JAX_GPU_HAVE_SPARSE -template -static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const Gtsv2Descriptor& descriptor = **s; - int batch = descriptor.batch; - int m = descriptor.m; - int n = descriptor.n; - int ldb = descriptor.ldb; - - T* dl = static_cast(buffers[0]); - T* d = static_cast(buffers[1]); - T* du = static_cast(buffers[2]); - T* B = static_cast(buffers[3]); - T* X = static_cast(buffers[4]); - void* buffer = static_cast(buffers[5]); - - // The solution X is written in place to B. We need to therefore copy the - // contents of B into the output buffer X and pass that into the kernel as B. - // Once copy insertion is supported for custom call aliasing, we could alias B - // with X and avoid the copy, the code below is written defensively assuming B - // and X might alias, but today we know they will not. - // TODO(b/182906199): Update the comment here once copy insertion is WAI. 
- if (X != B) { - size_t B_bytes = ldb * n * sizeof(T) * batch; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream))); - } - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer))); - dl += m; - d += m; - du += m; - X += m * n; - } - return absl::OkStatus(); -} - -JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL( - gtsv2_f32_ffi, [](gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - return gtsv2(gpusparseSgtsv2, stream, buffers, opaque, opaque_len); - }); - -JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL( - gtsv2_f64_ffi, [](gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - return gtsv2(gpusparseDgtsv2, stream, buffers, opaque, - opaque_len); - }); - -void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(gpusparseSgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(gpusparseDgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - template ffi::Error Gtsv2Impl(BufferSizeF getBufferSize, KernelF kernel, int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 3b365872f591..75f83752be15 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -16,14 +16,12 @@ limitations under the License. #ifndef JAXLIB_GPU_SPARSE_KERNELS_H_ #define JAXLIB_GPU_SPARSE_KERNELS_H_ -#include #include #include "absl/status/statusor.h" #include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { @@ -72,17 +70,6 @@ struct DenseVecDescriptor { }; #if JAX_GPU_HAVE_SPARSE -// CsrToDense: Convert CSR matrix to dense matrix - -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrFromDense: Convert dense matrix to CSR matrix - -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatvec: Product of CSR matrix and dense vector. struct CsrMatvecDescriptor { SparseMatDescriptor A; @@ -90,63 +77,24 @@ struct CsrMatvecDescriptor { gpusparseOperation_t op; }; -void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatmat: Product of CSR matrix and dense matrix. 
- struct CsrMatmatDescriptor { SparseMatDescriptor A; DenseMatDescriptor B, C; gpusparseOperation_t op_A; }; -void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooToDense: Convert COO matrix to dense matrix - -void CooToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooFromDense: Convert dense matrix to COO matrix - -void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatvec: Product of COO matrix and dense vector. - struct CooMatvecDescriptor { SparseMatDescriptor A; DenseVecDescriptor x, y; gpusparseOperation_t op; }; -void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatmat: Product of COO matrix and dense matrix. - struct CooMatmatDescriptor { SparseMatDescriptor A; DenseMatDescriptor B, C; gpusparseOperation_t op_A; }; -void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); -#endif // JAX_GPU_HAVE_SPARSE - -struct Gtsv2Descriptor { - int batch, m, n, ldb; -}; - -void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - -void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrToDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrFromDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrMatvecFfi); @@ -155,8 +103,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CooToDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooFromDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatvecFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatmatFfi); -XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f32_ffi); -XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f64_ffi); + +#endif // JAX_GPU_HAVE_SPARSE + XLA_FFI_DECLARE_HANDLER_SYMBOL(kGtsv2); } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 76e3ef01563c..a24a1617d309 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -229,7 +229,6 @@ cc_library( "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -247,7 +246,6 @@ nanobind_extension( ":hip_vendor", ":hipsparse_kernels", "//jaxlib:absl_status_casters", - "//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -255,14 +253,12 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -330,7 +326,6 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -344,7 +339,6 @@ rocm_library( "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) From 94037a8186a956deb7d46b20ee6554d359bc9ef0 Mon Sep 17 00:00:00 2001 
From: Yash Katariya Date: Mon, 2 Jun 2025 13:22:50 -0700 Subject: [PATCH 1483/1769] Maintain the dtype of the input on the output in `broadcast_one_to_all`. PiperOrigin-RevId: 766317087 --- jax/experimental/multihost_utils.py | 4 ++-- tests/array_test.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 7be349f0fc8f..3a83ff16d612 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -39,8 +39,8 @@ import numpy as np -def _psum(x: Any) -> Any: - return jax.tree.map(partial(jnp.sum, axis=0), x) +def _psum(xs: Any) -> Any: + return jax.tree.map(lambda x: jnp.sum(x, dtype=x.dtype, axis=0), xs) def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any: diff --git a/tests/array_test.py b/tests/array_test.py index 10a17b557dea..44734e64a995 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -712,6 +712,13 @@ def test_process_allgather_single_host(self): self.assertEqual(out.shape, (1, x.shape[0])) self.assertArraysEqual(out, np.expand_dims(x, axis=0)) + def test_broadcast_one_to_all_single_host(self): + x = jnp.arange(8, dtype=jnp.uint8) + out = multihost_utils.broadcast_one_to_all(x) + self.assertEqual(out.shape, x.shape) + self.assertEqual(out.dtype, x.dtype) + self.assertArraysEqual(out, x) + @jtu.sample_product( dtype=jtu.dtypes.all, shape=[(), (10), (2, 3)], From 9e4ff925bb9f562a2e5be9d82499ac276892734a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 2 Jun 2025 13:38:46 -0700 Subject: [PATCH 1484/1769] [pallas] Added a note on `pl.loop` to the changelog PiperOrigin-RevId: 766324349 --- docs/pallas/CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 40a30057354d..5c916c66ed86 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list ## Unreleased +* New functionality + + * Added a new decorator {func}`jax.experimental.pallas.loop` which allows + to write stateless loops as functions. + * Deprecations * {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been From 81de911a86655551f43b352b8106530dab763407 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 2 Jun 2025 13:54:57 -0700 Subject: [PATCH 1485/1769] [pallas:mosaic_gpu] `plgpu.nd_loop` is now a decorator similar to `pl.loop` The name was too similar to `pl.loop`, so having a different calling convention was confusing. PiperOrigin-RevId: 766330213 --- jax/_src/pallas/mosaic_gpu/helpers.py | 43 ++++++++++--------- .../pallas/ops/gpu/ragged_dot_mgpu.py | 6 +-- tests/pallas/mosaic_gpu_test.py | 5 +-- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/helpers.py b/jax/_src/pallas/mosaic_gpu/helpers.py index 54c4910059d5..939f3d0382e7 100644 --- a/jax/_src/pallas/mosaic_gpu/helpers.py +++ b/jax/_src/pallas/mosaic_gpu/helpers.py @@ -26,11 +26,9 @@ def nd_loop( grid: Sequence[int], - body: Callable[[Sequence[jax.Array], _T], _T], - init_val: _T, *, collective_axes: Sequence[Hashable] | Hashable, -) -> _T: +) -> Callable[[Callable[[Sequence[jax.Array]], None]], None]: """A loop over a multi-dimensional grid partitioned along the given axes. 
For example, if ``collective_axes`` is ``"x"`` with :func:`lax.axis_size` @@ -61,26 +59,29 @@ def nd_loop( 3 (1, 0) See also: - - :func:`jax.lax.fori_loop`: A single-dimensional indexed loop. + - :func:`jax.experimental.pallas.loop`: A loop over a single dimension. """ axis_index = lax.axis_index(collective_axes) axis_size = lax.axis_size(collective_axes) grid_size = math.prod(grid) - def wrapper(step, carry): - step = step * axis_size + axis_index - # The loop below is conceptually ``jnp.unravel_index``, but it uses - # ``lax`` APIs instead of ``jax.numpy`` to minimize the number of - # primitives used. - index = [] - for grid_dim in reversed(grid): - grid_dim = lax.convert_element_type(grid_dim, step.dtype) - index.append(lax.rem(step, grid_dim)) - step = lax.div(step, grid_dim) - index.reverse() - return body(tuple(index), carry) - - upper = lax.div(grid_size, axis_size) + lax.convert_element_type( - axis_index < grid_size % axis_size, axis_index.dtype - ) - return lax.fori_loop(0, upper, wrapper, init_val) + def decorator(body): + def wrapper(step, _): + step = step * axis_size + axis_index + # The loop below is conceptually ``jnp.unravel_index``, but it uses + # ``lax`` APIs instead of ``jax.numpy`` to minimize the number of + # primitives used. + index = [] + for grid_dim in reversed(grid): + grid_dim = lax.convert_element_type(grid_dim, step.dtype) + index.append(lax.rem(step, grid_dim)) + step = lax.div(step, grid_dim) + index.reverse() + return body(tuple(index)) + + upper = lax.div(grid_size, axis_size) + lax.convert_element_type( + axis_index < grid_size % axis_size, axis_index.dtype + ) + return lax.fori_loop(0, upper, wrapper, None) + + return decorator diff --git a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py index 6d295a36f435..9a1514b9827c 100644 --- a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py @@ -140,10 +140,8 @@ def body(rows_per_expert_gmem, lhs_gmem, rhs_gmem, o_gmem): pl.cdiv(n, grid_block_n * block_n), ) - @functools.partial( - plgpu.nd_loop, grid, init_val=None, collective_axes="sm" - ) - def mn_loop(idx, _): # pylint: disable=unused-variable + @plgpu.nd_loop(grid, collective_axes="sm") + def mn_loop(idx): # pylint: disable=unused-variable block_ni, mi, remainder_ni = idx ni = block_ni * pl.cdiv(n, block_n * grid_block_n) + remainder_ni group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f9a23e7be9d0..f7bc83fc460d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1747,7 +1747,8 @@ def test_nd_loop(self, sm_steps): grid_names=("sm",), ) def kernel(o_ref): - def body(idx, _): + @plgpu.nd_loop((sm_steps, 4, 33), collective_axes="sm") + def _(idx): assert len(idx) == 3 # We need to use `mode="clip"`, because the indices are not static. flat_idx = jnp.ravel_multi_index(idx, (sm_steps, 4, 33), mode="clip") @@ -1758,8 +1759,6 @@ def body(idx, _): flat_idx, o_ref.shape[-1:] ) - plgpu.nd_loop((sm_steps, 4, 33), body, None, collective_axes="sm") - result = kernel() for sm_step in range(sm_steps): np.testing.assert_array_equal( From 3ede957e587695dd0f5a1e7feff70ab8f4f5a8a1 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 2 Jun 2025 14:44:05 -0700 Subject: [PATCH 1486/1769] [Mosaic GPU] Add reduction support for TCGEN05 layout. 
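In kernel terms, a value loaded with the TCGEN05 register layout can now be
reduced along its last axis, mirroring the test added below (a minimal
sketch; `smem_ref`, `smem_reduced_ref`, and the surrounding kernel setup are
assumed to be as in `test_reduce_with_tcgen05_layout`):

    # Sketch: last-axis reduction over a TCGEN05-layout value inside a
    # Pallas Mosaic GPU kernel body.
    x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05)
    smem_reduced_ref[...] = jnp.max(x_val, axis=-1)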
PiperOrigin-RevId: 766350008 --- jax/_src/pallas/mosaic_gpu/lowering.py | 4 +-- tests/pallas/mosaic_gpu_test.py | 36 ++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ce74a5ba7c05..5695da4cc8b1 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2035,7 +2035,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype) with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]: return x.reduce("add", axes, scratch) - case mgpu.WGMMA_LAYOUT: + case mgpu.TiledLayout(): if axes != (x_aval.ndim - 1,): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): @@ -2049,7 +2049,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: - case mgpu.WGMMA_LAYOUT: + case mgpu.TiledLayout(): if axes != (x_aval.ndim - 1,): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f7bc83fc460d..61a9f18ef26d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2368,6 +2368,42 @@ def kernel(x_ref, y_ref, tmem_ref, tmem_ref2, smem_ref, barrier_ref): x_result = jax.block_until_ready(kernel(x)) np.testing.assert_array_equal(x_result, x + 1) + @parameterized.parameters( + (jnp.sum,), + (jnp.max,) + ) + def test_reduce_with_tcgen05_layout(self, op): + axis = -1 + swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128,), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, smem_ref, smem_reduced_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + smem_reduced_ref[...] = op(x_val, axis=axis) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_reduced_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_allclose(x_result, op(x, axis=axis), atol=1e-5) + @parameterized.product(shape=[(128, 128)], swizzle=[128, 64, 32], dtype=[jnp.float16, jnp.bfloat16], From 2f32a794717ee220dcf1dea0679c6cea1c30dd38 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 2 Jun 2025 14:48:35 -0700 Subject: [PATCH 1487/1769] Clean up unused GPU RNN kernels. These kernels aren't covered by the export compatibility policy, and their FFI counterparts have been targeted by JAX for several releases. 
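Note that the `cudnn_rnn` and `cudnn_rnn_bwd` custom-call target names are
kept but re-registered as FFI handlers, so freshly lowered programs reach the
FFI path under the old names; only the legacy `XlaCustomCallStatus` entry
points go away. A minimal sketch of the surviving Python-visible
registrations (the `jaxlib.cuda._rnn` package path is an assumption; the
dictionary keys match the bindings in this diff):

    # Sketch: only the FFI entry points remain in the RNN bindings.
    from jaxlib.cuda import _rnn  # assumed package path

    regs = _rnn.registrations()
    assert "cudnn_rnn_ffi" in regs and "cudnn_rnn_bwd_ffi" in regs
    assert "cudnn_rnn" not in regs  # legacy custom-call entry removed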
PiperOrigin-RevId: 766351572 --- jaxlib/cuda/BUILD | 1 - jaxlib/gpu/gpu_kernels.cc | 5 +++-- jaxlib/gpu/rnn.cc | 2 -- jaxlib/gpu/rnn_kernels.cc | 19 ------------------- jaxlib/gpu/rnn_kernels.h | 7 ------- jaxlib/rocm/BUILD | 1 - 6 files changed, 3 insertions(+), 32 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 33ed84b08753..5cc401e14eb0 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -113,7 +113,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", ], diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 8428562e3248..1f6e5f75315d 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -31,8 +31,9 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cudnn_rnn", "CUDA", RNNForwardFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cudnn_rnn_bwd", "CUDA", + RNNBackwardFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index 32e0842e3038..c235aa9fecfb 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -39,8 +39,6 @@ nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward); - dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward); dict[JAX_GPU_PREFIX "dnn_rnn_ffi"] = EncapsulateFfiHandler(RNNForwardFfi); dict[JAX_GPU_PREFIX "dnn_rnn_bwd_ffi"] = EncapsulateFfiHandler(RNNBackwardFfi); diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index d06535a668ac..44864d6a2663 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -30,7 +30,6 @@ limitations under the License. #include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" namespace jax { @@ -541,24 +540,6 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, return absl::OkStatus(); } -void RNNForward(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = DnnRNNForward_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = DnnRNNBackward_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNForwardFfi, DnnRNNForward_); JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNBackwardFfi, DnnRNNBackward_); diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index 36d8c25c6a9f..c1d6712a9eac 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -22,7 +22,6 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -47,12 +46,6 @@ absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32); -void RNNForward(gpuStream_t stream, void **buffers, const char *opaque, - size_t opaque_len, XlaCustomCallStatus *status); - -void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque, - size_t opaque_len, XlaCustomCallStatus *status); - XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNForwardFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNBackwardFfi); diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index a24a1617d309..f265e6714c8e 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -104,7 +104,6 @@ cc_library( "@local_config_rocm//rocm:miopen", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) From 31017c559b2bde21e1e4198befaf6c23dee1eb3b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 2 Jun 2025 15:32:07 -0700 Subject: [PATCH 1488/1769] When the size of the remainder array is 0, don't append it to the remainder_leaves list. This fixes usage of lax.map in sharding-in-types mode. PiperOrigin-RevId: 766367547 --- jax/_src/lax/control_flow/loops.py | 41 +++++++++++++++++------------- tests/pjit_test.py | 14 ++++++++++ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 47808ee3c423..146f27e5d2e7 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2509,19 +2509,23 @@ def fori_loop(lower, upper, body_fun, init_val): def _batch_and_remainder(x, batch_size: int): leaves, treedef = tree_flatten(x) - - scan_leaves = [] - remainder_leaves = [] - - for leaf in leaves: - num_batches, _ = divmod(leaf.shape[0], batch_size) - total_batch_elems = num_batches * batch_size - scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:])) - remainder_leaves.append(leaf[total_batch_elems:]) - - scan_tree = treedef.unflatten(scan_leaves) - remainder_tree = treedef.unflatten(remainder_leaves) - return scan_tree, remainder_tree + if not leaves: + return x, None + num_batches, remainder = divmod(leaves[0].shape[0], batch_size) + total_batch_elems = num_batches * batch_size + if remainder: + scan_leaves, remainder_leaves = [], [] + for leaf in leaves: + scan_leaves.append(leaf[:total_batch_elems].reshape( + num_batches, batch_size, *leaf.shape[1:])) + remainder_leaves.append(leaf[total_batch_elems:]) + return treedef.unflatten(scan_leaves), treedef.unflatten(remainder_leaves) + else: + scan_leaves = [ + leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]) + for leaf in leaves + ] + return treedef.unflatten(scan_leaves), None @api_boundary def map(f, xs, *, batch_size: int | None = None): @@ -2576,11 +2580,14 @@ def map(f, xs): scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size) g = lambda _, x: ((), api.vmap(f)(x)) _, scan_ys = scan(g, (), scan_xs) - remainder_ys = api.vmap(f)(remainder_xs) flatten = lambda x: x.reshape(-1, *x.shape[2:]) - ys = tree_map( - lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, remainder_ys, - ) + if remainder_xs is not None: + remainder_ys = api.vmap(f)(remainder_xs) + ys = tree_map( + lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, + remainder_ys) + 
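For downstream code the rename is mechanical (a sketch; this assumes the new
names simply drop the prefix and that the old TPU-prefixed names remain as
deprecated aliases for a transition period):

    from jax.experimental.pallas import tpu as pltpu

    params = pltpu.CompilerParams()  # was: pltpu.TPUCompilerParams()
    space = pltpu.MemorySpace.ANY    # was: pltpu.TPUMemorySpace.ANY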
else: + ys = tree_map(flatten, scan_ys) else: g = lambda _, x: ((), f(x)) _, ys = scan(g, (), xs) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index dd5c5d46e62f..c4d36ab78d10 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7893,6 +7893,20 @@ def test_nn_constant(self, mesh): self.assertArraysEqual(out, jnp.full((8, 2), -7, dtype=jnp.float32)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_lax_map(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jax.device_put(np.arange(4, dtype=np.float32), P('x')) + x = jax.device_put(np.ones((4, 2, 4), dtype=np.float32), + P(None, 'y', None)) + + jax.lax.map(lambda _x: simple_func(w, _x), x) # doesn't crash + + jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 4367d7c7e94fe54cb704ed95ba2495c2233bd07f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Jun 2025 17:53:19 -0700 Subject: [PATCH 1489/1769] Move jax/_src/extend/* to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. PiperOrigin-RevId: 766414704 --- jax/BUILD | 7 ++++++- jax/extend/BUILD | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index aaa25db88beb..1ec0ddd655f2 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -324,7 +324,6 @@ py_library_providing_imports_info( "*.py", "_src/cudnn/**/*.py", "_src/debugger/**/*.py", - "_src/extend/**/*.py", "_src/image/**/*.py", "_src/export/**/*.py", "_src/lax/**/*.py", @@ -1544,6 +1543,12 @@ pytype_library( ], ) +pytype_strict_library( + name = "extend_src", + srcs = glob(include = ["_src/extend/**/*.py"]), + deps = [":jax"], +) + # TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/... pytype_strict_library( name = "extend", diff --git a/jax/extend/BUILD b/jax/extend/BUILD index f466f1748654..1147e6bf502f 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -70,7 +70,10 @@ pytype_strict_library( pytype_strict_library( name = "random", srcs = ["random.py"], - deps = ["//jax"], + deps = [ + "//jax", + "//jax:extend_src", + ], ) pytype_strict_library( From 41fd7a70dc48fca18dc84665d6d1c8f39fa6880f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Jun 2025 20:03:00 -0700 Subject: [PATCH 1490/1769] Move jax/_src/custom_partitioning_sharding_rule.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. 
PiperOrigin-RevId: 766447138 --- jax/BUILD | 10 +++++++++- tests/BUILD | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 1ec0ddd655f2..b577abcabf5f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -303,7 +303,6 @@ py_library_providing_imports_info( "_src/checkify.py", "_src/custom_batching.py", "_src/custom_partitioning.py", - "_src/custom_partitioning_sharding_rule.py", "_src/debugging.py", "_src/dispatch.py", # TODO(vanderplas): remove this and depend on :api instead "_src/dlpack.py", @@ -388,6 +387,7 @@ py_library_providing_imports_info( ":custom_api_util", ":custom_dce", ":custom_derivatives", + ":custom_partitioning_sharding_rule", ":custom_transpose", ":deprecations", ":dtypes", @@ -700,6 +700,14 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "custom_partitioning_sharding_rule", + srcs = ["_src/custom_partitioning_sharding_rule.py"], + deps = [ + "//jax/_src/lib", + ], +) + pytype_strict_library( name = "custom_transpose", srcs = ["_src/custom_transpose.py"], diff --git a/tests/BUILD b/tests/BUILD index 75946713a49e..faa3367aecba 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -2158,6 +2158,7 @@ jax_py_test( srcs = ["custom_partitioning_sharding_rule_test.py"], deps = [ "//jax", + "//jax:custom_partitioning_sharding_rule", "//jax:experimental", "//jax:test_util", ] + py_deps("absl/testing"), From 2193c59fb06436a6d9d2388bd91f192c4cc1054e Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 3 Jun 2025 02:01:04 -0700 Subject: [PATCH 1491/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9b20d33306b4f15bc17f0235a786dddac96d046e. PiperOrigin-RevId: 766554226 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fb27433402ba..5ee95ca302a3 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "89872387d699d4622c96a4572e7772b3d1f387cd" -XLA_SHA256 = "ae1e6d82acc1ec54e1cde1162246a09724ea3c0868054a8a0721c18417c54be5" +XLA_COMMIT = "9b20d33306b4f15bc17f0235a786dddac96d046e" +XLA_SHA256 = "eba6c387448b05fee0f26e7a28ead5b4ad17342b45f451f6a8748a87c0141b1c" def repo(): tf_http_archive( From d30b176c8620ac8e9081f38e549bafa03484c53e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 3 Jun 2025 02:54:26 -0700 Subject: [PATCH 1492/1769] [Mosaic GPU] Add support for tiled loads and stores of `f8` data types. 
PiperOrigin-RevId: 766570588 --- .../mosaic/gpu/fragmented_array.py | 27 ++++++++++++--- tests/mosaic/gpu_test.py | 34 +++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 04dd30023293..7278af5d7a91 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2063,10 +2063,18 @@ def load_tiled( ), ) registers = np.full(layout.registers_shape(shape), zero, dtype=object) - reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) + is_f8 = ir.FloatType.isinstance(dtype) and utils.bitwidth(dtype) == 8 + i8 = ir.IntegerType.get_signless(8) + reg_ty = ir.VectorType.get((layout.vector_length,), dtype) + # f8 data types are not handled by the LLVM dialect, so we need to + # transfer them as i8 and bitcast them back to f8. + transfer_ty = ir.VectorType.get((layout.vector_length,), i8 if is_f8 else dtype) loads = cls.transfer_tiled2(ref, swizzle, layout, shape, optimized) for _, update, ptr in loads: - update(registers, llvm.load(reg_ty, ptr)) + loaded_reg = llvm.load(transfer_ty, ptr) + if is_f8: + loaded_reg = vector.bitcast(reg_ty, loaded_reg) + update(registers, loaded_reg) case _: raise NotImplementedError(layout) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) @@ -2259,7 +2267,12 @@ def transfer_tiled2( # Technically we should keep the vector_dim set to 1, but its shape is 1 # so it does not matter. transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides] - transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) + is_f8 = ir.FloatType.isinstance(dtype) and element_bits == 8 + i8 = ir.IntegerType.get_signless(8) + if is_f8: + transfer_dtype = ir.VectorType.get((layout.vector_length,), i8) + else: + transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) if ref_ty.memory_space is None: llvm_memory_space = None @@ -2327,7 +2340,13 @@ def mem_idx_to_reg_idx(idx): return (*reg_tiled_idx, *idx[base_idx:]) reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()] def get_register(regs, reg_idxs=reg_idxs): - return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) + def cast_if_f8(x): + if is_f8: + return vector.bitcast(transfer_dtype, x) + return x + # f8 data types are not handled by the LLVM dialect, so we need to + # transfer them as i8 and bitcast them back to f8. + return plan.select([cast_if_f8(regs[reg_idx]) for reg_idx in reg_idxs]) def update_registers(regs, new, reg_idxs=reg_idxs): # TODO(apaszke): If the staggering forms a permutation with a small # cycle length, then instead of blending at each step we could construct diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index a56aa04f6f60..314dc8f8f41d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -507,6 +507,40 @@ def kernel(ctx, out, _): )() np.testing.assert_array_equal(iota, expected) + @parameterized.product( + dtype=[jnp.float8_e5m2fnuz, jnp.float8_e5m2, jnp.float8_e4m3b11fnuz, + jnp.float8_e4m3fn, jnp.float8_e4m3fnuz], + swizzle=(32, 64, 128), + num_col_tiles=(1, 2, 3), + ) + def test_load_and_store_tiled_f8(self, dtype, swizzle, num_col_tiles): + # We use a different test than `test_store_tiled` because converting + # `iota` to `f8` type requires additional specialized logic that is not + # yet available. 
+ col_tiling = swizzle + m = 128 + n = col_tiling * num_col_tiles + tiling = (64, col_tiling) + def kernel(ctx, inp, out, smem): + del ctx + smem_inp, smem_out = smem + copy(inp, smem_inp, swizzle=swizzle) + arr = mgpu.FragmentedArray.load_tiled(smem_inp, swizzle=swizzle) + arr.store_tiled(smem_out, swizzle=swizzle) + copy(smem_out, out, swizzle=swizzle) + expected = ( + jax.random.randint( + jax.random.key(42), (m * n,), -16, 15, dtype=jnp.int8 + ) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .astype(dtype) + .transpose(0, 2, 1, 3) + ) + res = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2 + )(expected) + np.testing.assert_array_equal(res, expected) + @parameterized.product( dtype=[jnp.float32, jnp.float16, jnp.int8], swizzle=(32, 64, 128), From 0a5924c81d303460ec76d4b249eba04dd5d734e0 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Jun 2025 03:12:22 -0700 Subject: [PATCH 1493/1769] [pallas:mosaic_gpu] Dropped the `GPU` prefix from `GPUShapeDtypeStruct` I also slightly tweaked the docstring of `inline_mgpu`, since it references `GPUShapeDtypeStruct` PiperOrigin-RevId: 766575998 --- jax/_src/pallas/mosaic_gpu/primitives.py | 67 ++++++++++-------------- jax/experimental/pallas/mosaic_gpu.py | 2 +- tests/pallas/mosaic_gpu_test.py | 2 +- 3 files changed, 31 insertions(+), 40 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 9e40d046af13..2bd191b859ba 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1796,7 +1796,7 @@ def jaxpr_call( @dataclasses.dataclass(frozen=True) -class GPUShapeDtypeStruct: +class ShapeDtypeStruct: shape: tuple[int, ...] dtype: jnp.dtype layout: ParameterizedLayout | Layout @@ -1821,52 +1821,43 @@ def _undo_transforms( return tmp_ref.transforms -def inline_mgpu(arg_types=(), return_type=None): - """Decorate a function that inlines mgpu code. +def inline_mgpu(*, arg_types=(), return_type=None): + r"""Returns a decorator that inlines Mosaic GPU code. - Arguments provided to the decorated function may be Pallas - references or array values. The body will accept the corresponding - mgpu values. + This allows using lower-level Mosaic GPU abstractions and operations, which + are otherwise not directly exposed in Pallas. - The decorated function may return a tree of `FragmentedArray`s. + Example:: - ``` - layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) - @plgpu.inline_mgpu( - arg_types=(plgpu.RefType(),), - return_type=plgpu.GPUShapeDtypeStruct( - (128, 128), dtype, layout=layout - ), - ) - def foo(ctx, smem_ref): - del ctx - x = mgpu.FragmentedArray.load_tiled(smem_ref, ) - y = mgpu.FragmentedArray.splat( - mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout - ) - return (x + y) + layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) - arr = foo(smem_ref) - ``` + @plgpu.inline_mgpu( + arg_types=(plgpu.RefType(),), + return_type=plgpu.ShapeDtypeStruct( + (128, 128), dtype, layout=layout + ), + ) + def add_one(ctx, smem_ref): + x = mgpu.FragmentedArray.load_tiled(smem_ref) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout + ) + return x + y Args: - - arg_types: a sequence of pytrees where the leaves are `RefType` or - `Layout` for references or arrays respectively as the return - type. - - return_type: A pytree where the leaves are `GPUShapeDtypeStruct` - represeinting the arrays returned by the decorated function. 
- - Returns: - A decorator that creates a function that inlines mgpu code. - + arg_types: A sequence of pytrees where the leaves are + {class}`~jax.experimental.pallas.mosaic_gpu.RefType`\s or + {class}`~jax.experimental.pallas.mosaic_gpu.Layout`\s for reference or + array arguments respectively. + return_type: A pytree where the leaves are + {class}`~jax.experimental.pallas.mosaic_gpu.ShapeDtypeStruct`\s + representing the arrays returned by the decorated function. """ flat_arg_types, treedef_ty = jax.tree.flatten(tuple(arg_types)) flat_ret_ty, pytree_ret_ty = jax.tree.flatten(return_type) - if return_type and not all(isinstance(r, GPUShapeDtypeStruct) for r in flat_ret_ty): + if return_type and not all(isinstance(r, ShapeDtypeStruct) for r in flat_ret_ty): raise ValueError( - "inline_mgpu_p only supports GPUShapeDtypeStructx return types." + "inline_mgpu_p only supports plgpu.ShapeDtypeStruct return types." ) if not all(isinstance(r, (Layout, ParameterizedLayout, RefType)) for r in flat_arg_types): raise ValueError( @@ -1951,7 +1942,7 @@ def _type_check_mgpu(v, ty): match (ty, v): case (RefType(), ir.Value()) if ir.MemRefType.isinstance(v.type): pass - case (GPUShapeDtypeStruct(), mgpu.FragmentedArray()): + case (ShapeDtypeStruct(), mgpu.FragmentedArray()): mlir_dtype = mgpu_utils.dtype_to_ir_type(ty.dtype) if v.mlir_dtype != mlir_dtype: raise ValueError( diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 8c7870412403..1c47d391aa65 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -46,7 +46,7 @@ from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group -from jax._src.pallas.mosaic_gpu.primitives import GPUShapeDtypeStruct as GPUShapeDtypeStruct +from jax._src.pallas.mosaic_gpu.primitives import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 61a9f18ef26d..6cedbc6ae14c 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -423,7 +423,7 @@ def kernel(x_ref, o_ref, smem_ref, barrier): plgpu.TransposeTransform((1, 0, 2, 3)), plgpu.SwizzleTransform(128), )),), - return_type=plgpu.GPUShapeDtypeStruct( + return_type=plgpu.ShapeDtypeStruct( shape, dtype, layout=plgpu.Layout.WGMMA ), ) From d0d081564953d90ead6121356b6835f868d9afba Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Jun 2025 03:14:46 -0700 Subject: [PATCH 1494/1769] [pallas:mosaic] Removed the `TPU` prefix from `TPUCompilerParams` and `TPUMemorySpace` All TPU-specific APIs are always used qualified, e.g. `pltpu.TPUCompilerParams`, so the prefix is redundant. 
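The old names remain importable for now through a module-level
`__getattr__` that emits a `DeprecationWarning`; they will be removed in
a future release. A hypothetical before/after, not part of the patch:

```python
from jax.experimental.pallas import tpu as pltpu

# New spelling:
params = pltpu.CompilerParams(dimension_semantics=("parallel", "arbitrary"))
space = pltpu.MemorySpace.VMEM

# Old spelling: still resolves via the deprecation shim, with a warning.
params = pltpu.TPUCompilerParams(dimension_semantics=("parallel", "arbitrary"))
```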
PiperOrigin-RevId: 766576735 --- docs/pallas/CHANGELOG.md | 5 +++ docs/pallas/quickstart.ipynb | 4 +- docs/pallas/quickstart.md | 4 +- docs/pallas/tpu/details.rst | 2 +- docs/pallas/tpu/distributed.ipynb | 42 +++++++++--------- docs/pallas/tpu/distributed.md | 42 +++++++++--------- docs/pallas/tpu/matmul.ipynb | 8 ++-- docs/pallas/tpu/matmul.md | 8 ++-- docs/pallas/tpu/pipelining.ipynb | 22 +++++----- docs/pallas/tpu/pipelining.md | 22 +++++----- jax/_src/pallas/mosaic/core.py | 14 +++--- jax/_src/pallas/mosaic/interpret.py | 44 +++++++++---------- jax/_src/pallas/mosaic/lowering.py | 32 +++++++------- .../pallas/mosaic/pallas_call_registration.py | 12 ++--- jax/_src/pallas/mosaic/pipeline.py | 4 +- jax/_src/pallas/mosaic/primitives.py | 2 +- jax/_src/pallas/pallas_call.py | 2 +- jax/experimental/pallas/ops/tpu/all_gather.py | 4 +- .../pallas/ops/tpu/flash_attention.py | 6 +-- jax/experimental/pallas/ops/tpu/matmul.py | 2 +- .../pallas/ops/tpu/megablox/gmm.py | 4 +- .../paged_attention/paged_attention_kernel.py | 2 +- .../ops/tpu/ragged_paged_attention/kernel.py | 2 +- .../splash_attention_kernel.py | 6 +-- jax/experimental/pallas/tpu.py | 35 ++++++++++++--- tests/pallas/tpu_fusible_matmul_test.py | 2 +- tests/pallas/tpu_ops_test.py | 2 +- tests/pallas/tpu_pallas_async_test.py | 4 +- tests/pallas/tpu_pallas_distributed_test.py | 2 +- .../tpu_pallas_interpret_distributed_test.py | 14 +++--- tests/pallas/tpu_pallas_interpret_test.py | 14 +++--- tests/pallas/tpu_pallas_pipeline_test.py | 14 +++--- tests/pallas/tpu_pallas_test.py | 10 ++--- 33 files changed, 209 insertions(+), 183 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 5c916c66ed86..e3589b87b720 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -23,6 +23,11 @@ Remember to align the itemized text with the first line of an item within a list * {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been renamed to {class}`jax.experimental.pallas.triton.CompilerParams`. The old name is deprecated and will be removed in a future release. + * {class}`jax.experimental.pallas.tpu.TPUCompilerParams` + and {class}`jax.experimental.pallas.tpu.TPUMemorySpace` have been + renamed to {class}`jax.experimental.pallas.tpu.CompilerParams` + and {class}`jax.experimental.pallas.tpu.MemorySpace`. The + old names are deprecated and will be removed in a future release. ## Released with jax 0.6.1 diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 6460c1d5e739..ffdf715e984a 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -280,7 +280,7 @@ "metadata": {}, "source": [ "TPUs distinguish between vector and scalar memory spaces and in this case the\n", - "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is\n", "a scalar. 
For more details read {ref}`tpu_and_its_memory_spaces`.\n", "To call the above kernel on TPU, run:" ] @@ -297,7 +297,7 @@ "\n", "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),\n", " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", " grid=(size,))()\n", "iota(8)" diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index d4865488a15b..5f1832f2a2f0 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -186,7 +186,7 @@ iota(8) ``` TPUs distinguish between vector and scalar memory spaces and in this case the -output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is a scalar. For more details read {ref}`tpu_and_its_memory_spaces`. To call the above kernel on TPU, run: @@ -196,7 +196,7 @@ from jax.experimental.pallas import tpu as pltpu def iota(size: int): return pl.pallas_call(iota_kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), grid=(size,))() iota(8) diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 0575806e6037..a961c376f5bc 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -170,7 +170,7 @@ grid axes over cores. This is an opt-in procedure. To allow that, .. pallas_call( ..., - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=["parallel", "parallel", "arbitrary"] ), ) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 75aeeb92ca43..ae82b7a80ac6 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -271,11 +271,11 @@ "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n", "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " scratch_shapes=(\n", " # We allocate DMA semaphores in scratch memory.\n", " [pltpu.SemaphoreType.DMA] * 2\n", @@ -420,10 +420,10 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " scratch_shapes=(\n", " # DMA semaphores are allocated in scratch memory.\n", " # We allocated one semaphore for a local HBM-VMEM copy,\n", @@ -569,7 +569,7 @@ "kernel = pl.pallas_call(\n", " example_kernel,\n", " ...,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", ")\n", "```" ] @@ -809,13 +809,13 @@ " num_scalar_prefetch=0,\n", 
" in_specs=[\n", " # Our input lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", " ],\n", " out_specs=[\n", " # Our output lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", " # Our double-buffer lives in HBM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " grid=(num_devices,),\n", " scratch_shapes=(\n", @@ -829,7 +829,7 @@ " all_reduce_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", ")\n", "\n", "pallas_result = jax.jit(\n", @@ -1146,11 +1146,11 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", @@ -1169,7 +1169,7 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", @@ -1307,7 +1307,7 @@ "\n", "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter.\n", "\n", - "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", + "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. 
Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", "\n", "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n", "\n", @@ -1408,7 +1408,7 @@ "inner_block_spec = pl.BlockSpec(\n", " index_map=lambda i, j: (i, j),\n", " block_shape=inner_block_size,\n", - " memory_space=pltpu.TPUMemorySpace.ANY,\n", + " memory_space=pltpu.MemorySpace.ANY,\n", ")\n", "\n", "\n", @@ -1590,11 +1590,11 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", @@ -1612,7 +1612,7 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index 7b1f26bccf89..b16116549972 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -233,11 +233,11 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + # MemorySpace.ANY will (usually) place the tensor in HBM. in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -356,10 +356,10 @@ out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + # MemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. 
# We allocated one semaphore for a local HBM-VMEM copy, @@ -491,7 +491,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa kernel = pl.pallas_call( example_kernel, ..., - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) ``` @@ -703,13 +703,13 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ # Our input lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), ], out_specs=[ # Our output lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -723,7 +723,7 @@ kernel = pl.pallas_call( all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) pallas_result = jax.jit( @@ -1019,11 +1019,11 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1042,7 +1042,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), )(input_arr)[0] @@ -1148,7 +1148,7 @@ pl.pallas_call( In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter. -We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. +We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. 
In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM: @@ -1242,7 +1242,7 @@ inner_grid = ( inner_block_spec = pl.BlockSpec( index_map=lambda i, j: (i, j), block_shape=inner_block_size, - memory_space=pltpu.TPUMemorySpace.ANY, + memory_space=pltpu.MemorySpace.ANY, ) @@ -1424,11 +1424,11 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1446,7 +1446,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), )(input_arr)[0] diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb index 9c90add16ab0..3ae5f95c204a 100644 --- a/docs/pallas/tpu/matmul.ipynb +++ b/docs/pallas/tpu/matmul.ipynb @@ -210,7 +210,7 @@ " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " grid=(m // bm, n // bn, k // bk),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -466,7 +466,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -741,7 +741,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -929,7 +929,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index 42084f12d5f5..7ac157b4a2e9 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -167,7 +167,7 @@ def matmul( pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))], out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), grid=(m // bm, n // bn, k // bk), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -321,7 +321,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -489,7 +489,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( 
dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -613,7 +613,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 68932f4d1e40..829cda000e5d 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -123,15 +123,15 @@ "\n", "| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n", "| --- | --- | --- |\n", - "| `pltpu.TPUMemorySpace.ANY` | HBM (usually) or VMEM | DRAM |\n", - "| `pltpu.TPUMemorySpace.VMEM` | VMEM | SRAM |\n", - "| `pltpu.TPUMemorySpace.SMEM` | SMEM | SRAM |\n", - "| `pltpu.TPUMemorySpace.SEMAPHORE` | Semaphore | SRAM |\n", + "| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM |\n", + "| `pltpu.MemorySpace.VMEM` | VMEM | SRAM |\n", + "| `pltpu.MemorySpace.SMEM` | SMEM | SRAM |\n", + "| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM |\n", "\n", - "- `TPUMemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n", - "- `TPUMemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n", - "- `TPUMemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`.\n", - "- `TPUMemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details.\n", + "- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n", + "- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n", + "- `MemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`.\n", + "- `MemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details.\n", "\n", "Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). 
The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM.\n", "\n", @@ -164,9 +164,9 @@ "\n", "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n", "out = pl.pallas_call(hbm_vmem_kernel,\n", - " in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],\n", + " in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],\n", " out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n", - " scratch_shapes=(pltpu.TPUMemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", + " scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", ")(x)\n", "\n", "np.testing.assert_allclose(out, x[0:1] + 1)" @@ -259,7 +259,7 @@ " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\",))\n", " )(x, y)\n", "\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index b9ed41f937c8..44a252410151 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -94,15 +94,15 @@ Pallas exposes all levels of the TPU memory hierarchy to users. The following ta | Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) | | --- | --- | --- | -| `pltpu.TPUMemorySpace.ANY` | HBM (usually) or VMEM | DRAM | -| `pltpu.TPUMemorySpace.VMEM` | VMEM | SRAM | -| `pltpu.TPUMemorySpace.SMEM` | SMEM | SRAM | -| `pltpu.TPUMemorySpace.SEMAPHORE` | Semaphore | SRAM | +| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM | +| `pltpu.MemorySpace.VMEM` | VMEM | SRAM | +| `pltpu.MemorySpace.SMEM` | SMEM | SRAM | +| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM | -- `TPUMemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified. -- `TPUMemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM. -- `TPUMemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`. -- `TPUMemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details. +- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified. +- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM. +- `MemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`. +- `MemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details. 
Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM. @@ -128,9 +128,9 @@ def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref): x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) out = pl.pallas_call(hbm_vmem_kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)], out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32), - scratch_shapes=(pltpu.TPUMemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),) + scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),) )(x) np.testing.assert_allclose(out, x[0:1] + 1) @@ -190,7 +190,7 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel",)) )(x, y) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 49ff632f5c14..b864a1df2ec5 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -66,7 +66,7 @@ class GridDimensionSemantics(enum.Enum): @dataclasses.dataclass(frozen=True) -class TPUCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. Attributes: @@ -102,7 +102,7 @@ class TPUCompilerParams(pallas_core.CompilerParams): # Replace is a method, not a field. replace = dataclasses.replace -class TPUMemorySpace(enum.Enum): +class MemorySpace(enum.Enum): ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. 
VMEM = "vmem" SMEM = "smem" @@ -135,7 +135,7 @@ def __call__(self, shape: tuple[int, ...]): dtype = pallas_core.BarrierSemaphore() else: dtype = pallas_core.Semaphore() - return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) + return pallas_core.MemoryRef(shape, dtype, MemorySpace.SEMAPHORE) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: return self(()).get_array_aval() @@ -166,7 +166,7 @@ def __init__( def _make_scalar_ref_aval(self, aval): return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), - TPUMemorySpace.SMEM) + MemorySpace.SMEM) @dataclasses.dataclass(frozen=True) @@ -223,12 +223,12 @@ def _tensorcore_mesh_discharge_rule( name: str, ): assert isinstance(mesh, TensorCoreMesh) - if compiler_params and not isinstance(compiler_params, TPUCompilerParams): + if compiler_params and not isinstance(compiler_params, CompilerParams): raise ValueError( - "compiler_params must be a pltpu.TPUCompilerParams" + "compiler_params must be a pltpu.CompilerParams" ) if not compiler_params: - compiler_params = TPUCompilerParams() + compiler_params = CompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") if compiler_params.dimension_semantics is not None: diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 401ed02288bc..03c99c794ac7 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -817,16 +817,16 @@ def _allocate_semaphores( ).reshape(shape) -TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | pallas_core.MemorySpace | None, int] = { - v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)} +TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.MemorySpace | pallas_core.MemorySpace | None, int] = { + v: i for i, v in enumerate(mosaic_core.MemorySpace)} TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = ( - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY]) + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY]) TPU_MEMORY_SPACE_NAMES = { - i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)} + i: v.value for i, v in enumerate(mosaic_core.MemorySpace)} # Default to VMEM when no memory space is specified. TPU_MEMORY_SPACE_IDXS[None] = ( - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM]) + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.VMEM]) def get_barrier_semaphore(device_id, collective_id): del device_id @@ -1340,7 +1340,7 @@ def _to_jaxpr(flat_fun, in_avals): return new_jaxpr def _is_any(memory_space): - return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or + return ((memory_space == mosaic_core.MemorySpace.ANY) or (memory_space == pallas_core.MemorySpace.ANY)) def _is_float(dtype): @@ -1521,7 +1521,7 @@ def f(*args, jaxpr): # runs the same sequence of `run_scoped`s. allocs = [] for v in eqn.params['jaxpr'].invars: - if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + if v.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: allocs.append(callback.io_callback( _allocate_semaphores, jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), @@ -1543,7 +1543,7 @@ def f(*args, jaxpr): out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) for a, v in zip(allocs, eqn.params['jaxpr'].invars): - if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + if v.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: # TODO(jburnim): De-allocate semaphores. 
# callback.io_callback( # _deallocate_semaphores, @@ -1609,9 +1609,9 @@ def f(*args, jaxpr): (), device_id, local_core_id, - TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.MemorySpace.ANY)], src, src_transforms, - TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.MemorySpace.ANY)], dst, dst_transforms, state_discharge.transform_array(dst_sem, dst_sem_transforms), state_discharge.transform_array(src_sem, src_sem_transforms), @@ -1749,11 +1749,11 @@ def _get_next_indices(grid, indices): return tuple(reversed(next_indices)) -def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) -> tpu_core.TPUCompilerParams: +def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) -> tpu_core.CompilerParams: try: - return cast(tpu_core.TPUCompilerParams, compiler_params['mosaic_tpu']) + return cast(tpu_core.CompilerParams, compiler_params['mosaic_tpu']) except KeyError: - return tpu_core.TPUCompilerParams() + return tpu_core.CompilerParams() def _get_parallel_dim_semantics( @@ -1762,7 +1762,7 @@ def _get_parallel_dim_semantics( """Returns a tuple indicating which grid dimensions have parallel semantics. Args: - compiler_params: Representation of a `mosaic_core.TPUCompilerParams` object + compiler_params: Representation of a `mosaic_core.CompilerParams` object as a dictionary. num_dimensions_in_grid: The number of dimensions in the grid. @@ -1818,7 +1818,7 @@ def _get_randomized_grid_coordinates( Args: grid: Tuple of sizes of the dimensions in the grid. - compiler_params: Representation of a `mosaic_core.TPUCompilerParams` object + compiler_params: Representation of a `mosaic_core.CompilerParams` object as a dictionary. parallel_semantics_per_dim: A tuple of booleans indicating whether the corresponding dimension in the grid has parallel semantics. 
@@ -1986,7 +1986,7 @@ def interpret_pallas_call( jax.ShapeDtypeStruct((), jnp.int16), device_id, None, # local_core_id - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], input_args[i], ordered=True)) @@ -2016,7 +2016,7 @@ def interpret_pallas_call( jax.ShapeDtypeStruct((), jnp.int16), device_id, None, # local_core_id - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], padded_val, ordered=True, ) @@ -2036,7 +2036,7 @@ def interpret_pallas_call( jax.ShapeDtypeStruct((), jnp.int16), device_id, None, # local_core_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.SMEM], val, ordered=True, ) @@ -2047,7 +2047,7 @@ def interpret_pallas_call( output_idx = i - grid_mapping.num_inputs is_input = i < grid_mapping.num_inputs is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) - if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + if var.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: kernel_buffer_ids.append( callback.io_callback( _allocate_semaphores, @@ -2241,7 +2241,7 @@ def _store_slice_to_kernel_input(index, input_var): jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), device_id, cur_local_core_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], input_buffer_ids[index], (transform,), ordered=True, @@ -2318,7 +2318,7 @@ def _store_to_output_buffer(index, output_var): (), device_id, cur_local_core_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], output_buffer_ids[index], (transform,), kernel_output_val, @@ -2398,7 +2398,7 @@ def _store_to_output_buffer(index, output_var): val, device_id, 0, # local_core_id - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], output_buffer_id, (indexing.NDIndexer.from_indices_shape( tuple(indexing.ds(0, s) for s in val.shape), diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 873bf587093a..b1a2c186e2c9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -15,12 +15,12 @@ """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" from __future__ import annotations -from collections.abc import Callable, Collection, Sequence +from collections.abc import Callable, Collection, Hashable, Sequence import contextlib import dataclasses import functools import string -from typing import Any, Hashable, TypeVar +from typing import Any, TypeVar import jax from jax import api_util @@ -86,10 +86,10 @@ # mypy: ignore-errors NDIndexer = indexing.NDIndexer -TPUMemorySpace = tpu_core.TPUMemorySpace -MemorySpace = pallas_core.MemorySpace | TPUMemorySpace -VMEM = tpu_core.TPUMemorySpace.VMEM -SMEM = tpu_core.TPUMemorySpace.SMEM +TPUMemorySpace = tpu_core.MemorySpace +AnyMemorySpace = pallas_core.MemorySpace | TPUMemorySpace +VMEM = TPUMemorySpace.VMEM +SMEM = TPUMemorySpace.SMEM # Booleans are stored as the following type in memrefs. 
BOOL_MEMREF_TYPE = np.dtype('int32') @@ -212,7 +212,7 @@ def forward_compatible(self): return self.lowering_context.forward_compatible -def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None +def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None ) -> TPUMemorySpace: match memory_space: case None: @@ -235,7 +235,7 @@ def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None raise ValueError(f"Invalid memory space: {memory_space}") -def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None +def _memory_space_to_mosaic_attribute(memory_space: AnyMemorySpace | None ) -> ir.Attribute: tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space) return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>") @@ -266,7 +266,7 @@ def aval_to_ir_type( dynamic_shape_replacement_fn, aval, shape=None, - memory_space: MemorySpace | None = None, + memory_space: AnyMemorySpace | None = None, is_kernel_boundary: bool = False, ): if isinstance(aval, tpu_core.AbstractSemaphore): @@ -600,9 +600,9 @@ def _check_block_mappings( # TODO(necula): add tests for SMEM blocks with trivial windowing # We support scalars too memory_space = _memory_space_to_tpu_memory_space(bm.block_aval.memory_space) - if memory_space == tpu_core.TPUMemorySpace.SMEM and bm.has_trivial_window(): + if memory_space == tpu_core.MemorySpace.SMEM and bm.has_trivial_window(): continue - if memory_space == tpu_core.TPUMemorySpace.SEMAPHORE: + if memory_space == tpu_core.MemorySpace.SEMAPHORE: continue def err_details(): @@ -619,7 +619,7 @@ def err_details(): "rank >= 1. " + err_details()) if ( - memory_space == tpu_core.TPUMemorySpace.ANY + memory_space == tpu_core.MemorySpace.ANY and not bm.has_trivial_window() ): raise ValueError( @@ -761,8 +761,8 @@ def dynamic_shape_replacement_fn( tpu_memory_space = _memory_space_to_tpu_memory_space( bm.block_aval.memory_space) if ( - tpu_memory_space == tpu_core.TPUMemorySpace.ANY - or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE + tpu_memory_space == tpu_core.MemorySpace.ANY + or tpu_memory_space == tpu_core.MemorySpace.SEMAPHORE ): # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) @@ -784,7 +784,7 @@ def dynamic_shape_replacement_fn( # Force single-buffering pipelining for trivial windowing in VMEM. pipeline_mode = bm.pipeline_mode if ( - tpu_memory_space == tpu_core.TPUMemorySpace.VMEM + tpu_memory_space == tpu_core.MemorySpace.VMEM and bm.has_trivial_window() ): pipeline_mode = pallas_core.Buffered(1) @@ -1520,7 +1520,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ): if not is_smem_load: raise ValueError("PRNG keys must be loaded from SMEM. 
Did you set " - "the memory space to TPUMemorySpace.SMEM in the " + "the memory space to MemorySpace.SMEM in the " "BlockSpec for the PRNG key input?") return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree) if not is_smem_load and not ref_block_shape: diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 74253e809a35..528f897edf74 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -80,13 +80,13 @@ def _get_memory_space_from_aval( match out_aval.memory_space: case None: return None - case tpu_core.TPUMemorySpace.ANY: + case tpu_core.MemorySpace.ANY: return None - case tpu_core.TPUMemorySpace.VMEM: + case tpu_core.MemorySpace.VMEM: return tpu_custom_call.MemorySpace.VMEM - case tpu_core.TPUMemorySpace.SMEM: + case tpu_core.MemorySpace.SMEM: return tpu_custom_call.MemorySpace.SMEM - case tpu_core.TPUMemorySpace.SEMAPHORE: + case tpu_core.MemorySpace.SEMAPHORE: return tpu_custom_call.MemorySpace.SEMAPHORE_MEM return None @@ -126,10 +126,10 @@ def pallas_call_tpu_lowering_rule( if "mosaic_tpu" in compiler_params: mosaic_params = cast( - tpu_core.TPUCompilerParams, compiler_params["mosaic_tpu"] + tpu_core.CompilerParams, compiler_params["mosaic_tpu"] ) else: - mosaic_params = tpu_core.TPUCompilerParams() + mosaic_params = tpu_core.CompilerParams() jax_mesh = None axis_context = ctx.module_context.axis_context diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index f4dab313fb6f..659146a4ad7e 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -36,8 +36,8 @@ import numpy as np -SMEM = tpu_core.TPUMemorySpace.SMEM -VMEM = tpu_core.TPUMemorySpace.VMEM +SMEM = tpu_core.MemorySpace.SMEM +VMEM = tpu_core.MemorySpace.VMEM REF = pallas_core.MemoryRef GridDimensionSemantics = tpu_core.GridDimensionSemantics PARALLEL = tpu_core.PARALLEL diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 33a1de12ebde..c9cdcbf56f85 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -642,7 +642,7 @@ def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, def _get_barrier_semaphore_abstract_eval(): return pl_core.AbstractMemoryRef( jax_core.ShapedArray((), pl_core.BarrierSemaphore()), - tpu_core.TPUMemorySpace.SEMAPHORE, + tpu_core.MemorySpace.SEMAPHORE, ) def get_barrier_semaphore(): diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 016bac96424e..93b929d219c9 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1545,7 +1545,7 @@ def pallas_call( {file}:{line}`. compiler_params: Optional compiler parameters. The value should either be a backend-specific dataclass - (:class:`jax.experimental.pallas.tpu.TPUCompilerParams`, + (:class:`jax.experimental.pallas.tpu.CompilerParams`, :class:`jax.experimental.pallas.triton.CompilerParams`, :class:`jax.experimental.pallas.mosaic_gpu.CompilerParams`) or a dict mapping backend name to the corresponding platform-specific dataclass. 
diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index a0eb07f719ec..ce80a443547e 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -120,7 +120,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, jax.jit, static_argnames=["mesh", "axis_name", "memory_space"] ) def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str], - memory_space: pltpu.TPUMemorySpace = pltpu.VMEM): + memory_space: pltpu.MemorySpace = pltpu.VMEM): if isinstance(axis_name, str): axis_name = (axis_name,) # TODO(sharadmv): enable all gather over multiple axes @@ -136,7 +136,7 @@ def ag_local(x_shard): out = pl.pallas_call( functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), out_shape=out_shape, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, scratch_shapes=( diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 06746986a15e..27f66d34e354 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -767,7 +767,7 @@ def kv_segment_ids_index_map( ), out_shape=out_shape, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", @@ -1130,7 +1130,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): ), out_shape=out_shapes, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", @@ -1465,7 +1465,7 @@ def kv_segment_ids_index_map( ), out_shape=out_shapes, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py index 06d868168f9e..341aa93fa258 100644 --- a/jax/experimental/pallas/ops/tpu/matmul.py +++ b/jax/experimental/pallas/ops/tpu/matmul.py @@ -78,7 +78,7 @@ def matmul( grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k), scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)], ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), debug=debug, )(x, y) diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py index 5c2f938597e7..cb185fc45f1d 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -538,7 +538,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, cost_estimate=cost_estimate, @@ -777,7 +777,7 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, 
cost_estimate=cost_estimate, diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 9c02679c45ea..309858368896 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -648,7 +648,7 @@ def paged_attention( grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=dimension_semantics ), out_shape=[ diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index 67c0b376ecc6..e7bc599b2b2b 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -891,7 +891,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "arbitrary", "arbitrary", diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index b69b0e36f177..34d8847e6193 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -1118,7 +1118,7 @@ def logsumexp_index_map(h, i, *_): out_specs=out_specs, grid=grid, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary"), ), out_shape=out_shapes, @@ -1577,7 +1577,7 @@ def logsumexp_index_map(h, i, *_): grid=grid, ), out_shape=out_shapes, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), ), name=kernel_name, @@ -2126,7 +2126,7 @@ def logsumexp_index_map( # megacore # 2) for heads, we are reducing over heads # 3) for q_seq_len, we are reducing over it to compute dkv - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), ), name=kernel_name, diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index c4d21023a6e6..5ac79558b11d 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -23,8 +23,8 @@ from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType -from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace -from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams +from jax._src.pallas.mosaic.core import MemorySpace as MemorySpace +from jax._src.pallas.mosaic.core import CompilerParams as CompilerParams from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core @@ -68,8 +68,29 @@ ) del types, assume, pretend, skip, define_model # Clean up. 
-ANY = TPUMemorySpace.ANY -CMEM = TPUMemorySpace.CMEM -SMEM = TPUMemorySpace.SMEM -VMEM = TPUMemorySpace.VMEM -SEMAPHORE = TPUMemorySpace.SEMAPHORE +ANY = MemorySpace.ANY +CMEM = MemorySpace.CMEM +SMEM = MemorySpace.SMEM +VMEM = MemorySpace.VMEM +SEMAPHORE = MemorySpace.SEMAPHORE + +import typing as _typing # pylint: disable=g-import-not-at-top +if _typing.TYPE_CHECKING: + TPUCompilerParams = CompilerParams + TPUMemorySpace = MemorySpace +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + _deprecations = { + # Deprecated on May 30th 2025. + "TPUCompilerParams": ( + "TPUCompilerParams is deprecated, use CompilerParams instead.", + CompilerParams, + ), + "TPUMemorySpace": ( + "TPUMemorySpace is deprecated, use MemorySpace instead.", + MemorySpace, + ), + } + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/tests/pallas/tpu_fusible_matmul_test.py b/tests/pallas/tpu_fusible_matmul_test.py index 2382c09f26ac..4bde9b95483b 100644 --- a/tests/pallas/tpu_fusible_matmul_test.py +++ b/tests/pallas/tpu_fusible_matmul_test.py @@ -177,7 +177,7 @@ def z_index_map(i, j, k, *_): ], out_specs=[z_out_block_spec], ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=dimension_semantics, ), out_shape=[z_out_type], diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 3f6dc593e333..a67e74d617b6 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -534,7 +534,7 @@ def kernel(src, tgt): run = pl.pallas_call( kernel, jax.ShapeDtypeStruct(tgt_shape, jnp.float32), - compiler_params=pltpu.TPUCompilerParams(disable_bounds_checks=True), + compiler_params=pltpu.CompilerParams(disable_bounds_checks=True), ) output = run(x) np.testing.assert_array_equal( diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index e464214928e4..c70fb6ea2ff5 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -436,7 +436,7 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), input_output_aliases={0: 0}, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, has_side_effects=True ), )(x) @@ -537,7 +537,7 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, ), input_output_aliases={0: 0}, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, has_side_effects=False ), )(x) diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index aa4488b778a8..11b159dbec4c 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -235,7 +235,7 @@ def body(x): in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), )(x) device_mesh = mesh_utils.create_device_mesh( diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index c5f1b29fd6bc..70ed3dc576e5 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -92,7 +92,7 @@ def 
right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + # MemorySpace.ANY will (usually) place the tensor in HBM. in_specs=[ pl.BlockSpec(memory_space=pltpu.ANY), ], @@ -106,7 +106,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): right_permute_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=13), + compiler_params=pltpu.CompilerParams(collective_id=13), interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), ) @@ -206,7 +206,7 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + # MemorySpace.ANY will (usually) place the tensor in HBM. pl.BlockSpec(memory_space=pltpu.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), @@ -230,7 +230,7 @@ def _(): grid_spec=grid_spec, interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) # Wrap the kernel within a shard_map to call. @@ -390,7 +390,7 @@ def _(): grid_spec=grid_spec, interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) pallas_result = jax.jit( @@ -674,7 +674,7 @@ def pallas_reduce_scatter(input_arr): grid_spec=grid_spec, interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=True), - compiler_params=pltpu.TPUCompilerParams(collective_id=7), + compiler_params=pltpu.CompilerParams(collective_id=7), )(input_arr)[0] pallas_result = jax.jit( @@ -978,7 +978,7 @@ def pallas_reduce_scatter(input_arr): grid_spec=grid_spec, interpret=pltpu.TPUInterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=19), + compiler_params=pltpu.CompilerParams(collective_id=19), )(input_arr)[0] pallas_result = jax.jit( diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 871f66d71c53..9cf98e6c1dd4 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -381,7 +381,7 @@ def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): interpret=pltpu.TPUInterpretParams( random_seed=12345, grid_point_recorder=grid_point_recorder ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel', 'arbitrary') ), )(s) @@ -437,7 +437,7 @@ def kernel(s_ref, o_ref): in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), interpret=pltpu.TPUInterpretParams(random_seed=12345), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('arbitrary', 'parallel') ), )(s) @@ -463,7 +463,7 @@ def kernel_call_dynamic_parallel_dimension(): in_specs=[], out_specs=pl.BlockSpec((1,), lambda _: (0,)), interpret=pltpu.TPUInterpretParams(), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) ), )() @@ -516,7 +516,7 @@ 
def kernel(x_ref, o_ref, vmem_ref): kernel, grid=(2,), out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), ], @@ -524,7 +524,7 @@ def kernel(x_ref, o_ref, vmem_ref): num_cores_per_device=2, detect_races=True, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',), ), )(x).block_until_ready() @@ -558,7 +558,7 @@ def kernel(x_ref, o_ref, vmem_ref): num_cores_per_device=2, detect_races=True, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) ), )(x).block_until_ready() @@ -583,7 +583,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): num_cores_per_device=num_cores_per_device, grid_point_recorder=grid_point_recorder, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel', 'arbitrary') ), )(s) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 59ac680d3ac3..1c10fa4b73e5 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -481,7 +481,7 @@ def _wait_on_prev_dma(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, # must set scoped vmem flag *larger* than below! e.g.: # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 @@ -724,7 +724,7 @@ def _wait_on_prev_dma(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, # must set scoped vmem flag *larger* than below! e.g.: # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 @@ -1007,7 +1007,7 @@ def _loop_epilogue(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, # must set scoped vmem flag *larger* than below! # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 @@ -1268,7 +1268,7 @@ def _prefetch_accumulator(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, # must set scoped vmem flag *larger* than below! # e.g. 
flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 @@ -1353,7 +1353,7 @@ def mul_kernel(iters_ref, x_ref, y_ref): out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) ), ) @@ -1389,7 +1389,7 @@ def matmul_kernel(x_ref, y_ref): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) ), ) @@ -1440,7 +1440,7 @@ def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) ), ) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index aac249251e2b..c232ebeedb38 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -477,7 +477,7 @@ def kernel(s, x): ), grid=8, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( allow_input_fusion=[False, True] ), )(s, x) @@ -1913,12 +1913,12 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=x, - compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256), + compiler_params=pltpu.CompilerParams(vmem_limit_bytes=256), )(x) self.pallas_call( kernel, out_shape=x, - compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)), + compiler_params=pltpu.CompilerParams(vmem_limit_bytes=int(2**18)), )(x) def test_allow_input_fusion(self): @@ -1935,7 +1935,7 @@ def f(x, y): in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))], out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)), out_shape=x, - compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]), + compiler_params=pltpu.CompilerParams(allow_input_fusion=[True]), )(z) x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) @@ -1963,7 +1963,7 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( internal_scratch_in_bytes=requested_bytes, ), )(x) From 87641ccb80bad1845bf8abb35a6d71ef79b1fae2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Jun 2025 06:38:44 -0700 Subject: [PATCH 1495/1769] [pallas:mosaic] Dropped the `TPU` prefix from the recently added `TPUInterpreterParams` We no longer use platform prefixes in all other Pallas APIs. PiperOrigin-RevId: 766639632 --- jax/_src/pallas/core.py | 4 +-- jax/_src/pallas/mosaic/interpret.py | 6 ++-- jax/_src/pallas/pallas_call.py | 6 ++-- jax/experimental/pallas/tpu.py | 2 +- .../tpu_pallas_interpret_distributed_test.py | 12 +++---- tests/pallas/tpu_pallas_interpret_test.py | 32 +++++++++---------- 6 files changed, 31 insertions(+), 31 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index a05f97eb122f..047edd2b8435 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1250,7 +1250,7 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **kwargs): if interpret: try: from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency. 
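The `TPUCompilerParams` call-site updates above are mechanical renames. For reference, a minimal `pallas_call` sketch in the new spelling; the kernel and shapes are illustrative, and the `vmem_limit_bytes` value simply mirrors the test above:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def add_one_kernel(x_ref, o_ref):
        # With no grid or BlockSpecs, the refs cover the whole arrays.
        o_ref[...] = x_ref[...] + 1.0

    x = jnp.zeros((8, 128), jnp.float32)
    y = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        compiler_params=pltpu.CompilerParams(  # formerly pltpu.TPUCompilerParams
            vmem_limit_bytes=int(2**18),
        ),
    )(x)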
- if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): effs = mosaic_tpu_interpret.get_interpret_effects() except ImportError: pass @@ -1353,7 +1353,7 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs): if interpret: try: from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency. - if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): effs = mosaic_tpu_interpret.get_interpret_effects() except ImportError: pass diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 03c99c794ac7..3be718aa0aa6 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -67,7 +67,7 @@ @dataclasses.dataclass(frozen=True) -class TPUInterpretParams: +class InterpretParams: """Parameters for Mosaic TPU interpret mode. Attributes: @@ -524,7 +524,7 @@ def check_write(device_id, local_core_id, clock, buffer_key, rnge, source_info=N @dataclasses.dataclass class SharedMemory: - interpret_params: TPUInterpretParams + interpret_params: InterpretParams num_devices: int num_cores_per_device: int clocks: list[VectorClock] @@ -1926,7 +1926,7 @@ def interpret_pallas_call( compiler_params: dict[str, Any], cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], - interpret_params: TPUInterpretParams, + interpret_params: InterpretParams, ): del debug, cost_estimate, out_avals diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 93b929d219c9..b14259556faf 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -96,7 +96,7 @@ def _pallas_call_abstract_eval( ): del avals - if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): # Report effects that will be introduced when running/lowering # mosaic_tpu_interpret.mosaic_tpu_interpret.interpret_pallas_call . 
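`InterpretParams` keeps all of its fields; only the class name changes. A small sketch of the renamed spelling (the copy kernel is illustrative); this variant runs on CPU, since passing `InterpretParams` routes the call through the Mosaic TPU interpreter:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def copy_kernel(x_ref, o_ref):
        o_ref[...] = x_ref[...]

    x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
    y = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Formerly interpret=pltpu.TPUInterpretParams(...); race detection
        # and DMA execution order stay configurable as before.
        interpret=pltpu.InterpretParams(detect_races=True),
    )(x)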
effs = mosaic_tpu_interpret.get_interpret_effects() @@ -1261,7 +1261,7 @@ def _pallas_call_lowering( if params['jaxpr'].constvars: raise ValueError('Cannot lower a pallas_call with constants.') if interpret: - if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): impl = partial(mosaic_tpu_interpret.interpret_pallas_call, interpret_params=interpret, **params) @@ -1774,5 +1774,5 @@ def in_path_to_input_origin( from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret except ImportError: mosaic_tpu_interpret = types.SimpleNamespace( # type: ignore - TPUInterpretParams=types.new_class('_NoInstances', (enum.Enum,)), + InterpretParams=types.new_class('_NoInstances', (enum.Enum,)), ) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 5ac79558b11d..eceb2e4f0383 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -28,7 +28,7 @@ from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core -from jax._src.pallas.mosaic.interpret import TPUInterpretParams as TPUInterpretParams +from jax._src.pallas.mosaic.interpret import InterpretParams as InterpretParams from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef from jax._src.pallas.mosaic.pipeline import BufferedRefBase as BufferedRefBase diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 70ed3dc576e5..ddfe8bcde4f4 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -107,7 +107,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape=out_shape, grid_spec=grid_spec, compiler_params=pltpu.CompilerParams(collective_id=13), - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), ) # Wrap the kernel within a shard_map to call. 
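A subtle detail in the `pallas_call.py` hunk above: when the Mosaic interpret module cannot be imported, the code swaps in a stub whose `InterpretParams` is a member-less `Enum` subclass. An enum with no members has no possible instances, so every `isinstance` check against the stub is `False` and the interpret path is skipped without further guards. The trick in isolation:

    import enum
    import types

    stub = types.SimpleNamespace(
        InterpretParams=types.new_class("_NoInstances", (enum.Enum,)),
    )

    # No value can be an instance of a member-less Enum, so this holds
    # for any object:
    assert not isinstance(object(), stub.InterpretParams)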
@@ -228,7 +228,7 @@ def _(): all_gather_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.CompilerParams(collective_id=0), ) @@ -388,7 +388,7 @@ def _(): all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.CompilerParams(collective_id=0), ) @@ -672,7 +672,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=True), compiler_params=pltpu.CompilerParams(collective_id=7), )(input_arr)[0] @@ -976,7 +976,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.CompilerParams(collective_id=19), )(input_arr)[0] @@ -1064,7 +1064,7 @@ def run(src_dst_ids): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode='eager', detect_races=True, ), diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 9cf98e6c1dd4..5bfca2270aa1 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -124,7 +124,7 @@ def matmul(x: jax.Array, y: jax.Array): (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), ), - interpret=pltpu.TPUInterpretParams(), + interpret=pltpu.InterpretParams(), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) @@ -155,7 +155,7 @@ def block_dynamic_slice(x, starts, sizes): dynamic_slice_kernel, grid_spec=grid_spec, out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), - interpret=pltpu.TPUInterpretParams(), + interpret=pltpu.InterpretParams(), ) block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) return kernel(block_idx, x) @@ -189,7 +189,7 @@ def f(s, x): ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), input_output_aliases={1: 0}, - interpret=pltpu.TPUInterpretParams(), + interpret=pltpu.InterpretParams(), )(s, x) s = jnp.array([1], dtype=jnp.int32) @@ -224,7 +224,7 @@ def _(): ), scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), input_output_aliases={0: 0}, - interpret=pltpu.TPUInterpretParams(), + interpret=pltpu.InterpretParams(), )(x) expected = np.zeros((4, 4)) @@ -264,7 +264,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( detect_races=True, dma_execution_mode=dma_execution_mode ), )(x).block_until_ready() @@ -279,7 +279,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( detect_races=True, dma_execution_mode=dma_execution_mode ), )(x).block_until_ready() @@ -293,7 +293,7 @@ def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( matmul_kernel, out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - 
interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( skip_floating_point_ops=True ), )(x, y) @@ -325,7 +325,7 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): pltpu.VMEM((8, 128), jnp.bfloat16), pltpu.VMEM((8, 128), jnp.int16), ], - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( uninitialized_memory=uninitialized_memory ), )() @@ -355,7 +355,7 @@ def kernel_call(x, s): pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec((8, 256), lambda i, j: (i, 0)), - interpret=pltpu.TPUInterpretParams(), + interpret=pltpu.InterpretParams(), )(x, s) with CountStoreCallbacksContext() as store_callbacks_counter: @@ -378,7 +378,7 @@ def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): grid=(4, 4), in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( random_seed=12345, grid_point_recorder=grid_point_recorder ), compiler_params=pltpu.CompilerParams( @@ -436,7 +436,7 @@ def kernel(s_ref, o_ref): grid=(4, 4), in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), - interpret=pltpu.TPUInterpretParams(random_seed=12345), + interpret=pltpu.InterpretParams(random_seed=12345), compiler_params=pltpu.CompilerParams( dimension_semantics=('arbitrary', 'parallel') ), @@ -462,7 +462,7 @@ def kernel_call_dynamic_parallel_dimension(): grid=(dim_size,), in_specs=[], out_specs=pl.BlockSpec((1,), lambda _: (0,)), - interpret=pltpu.TPUInterpretParams(), + interpret=pltpu.InterpretParams(), compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) ), @@ -479,7 +479,7 @@ def f(x): y = jnp.zeros_like(x) def inner(refs): x_ref, y_ref = refs - @pl.core_map(mesh, interpret=pltpu.TPUInterpretParams()) + @pl.core_map(mesh, interpret=pltpu.InterpretParams()) def _(): num_cores = jax.lax.psum(1, "x") slc_size = 16 // num_cores @@ -520,7 +520,7 @@ def kernel(x_ref, o_ref, vmem_ref): scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), ], - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( num_cores_per_device=2, detect_races=True, ), @@ -554,7 +554,7 @@ def kernel(x_ref, o_ref, vmem_ref): scratch_shapes=[ pltpu.VMEM((8, 128), x.dtype), ], - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( num_cores_per_device=2, detect_races=True, ), @@ -578,7 +578,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): grid=(4, 4), in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), - interpret=pltpu.TPUInterpretParams( + interpret=pltpu.InterpretParams( random_seed=12345, num_cores_per_device=num_cores_per_device, grid_point_recorder=grid_point_recorder, From 6241a2aa9836ef380fd793193721e196990fedcf Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Jun 2025 09:39:55 -0700 Subject: [PATCH 1496/1769] Propagate layouts correctly via mutable arrays PiperOrigin-RevId: 766702441 --- jax/_src/core.py | 1 + jax/_src/interpreters/pxla.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 1355fc10472f..2452d12dad5e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2316,6 +2316,7 @@ def __init__(self, aval, buf): shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) sharding = property(lambda self: self._buf.sharding) + format = property(lambda self: self._buf.format) def 
__getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 0530e313f310..55f765ba981a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2095,7 +2095,7 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr) in_shardings = (*in_shardings, *(c.sharding for c in mut.in_mut)) - in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut) # TODO(mattjj) + in_layouts = (*in_layouts, *(c.format.dll for c in mut.in_mut)) donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut) out_layouts_ = iter(zip(out_shardings, out_layouts)) out_shardings, out_layouts = unzip2( From d17b29230605673f44c72e0617c31a198d6a809f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Jun 2025 10:07:52 -0700 Subject: [PATCH 1497/1769] Don't canonicalize in `__eq__` if `other` is a PartitionSpec since it is already canonicalized PiperOrigin-RevId: 766714598 --- jax/_src/partition_spec.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 040db35ccb2b..2c833e6544e4 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -111,18 +111,18 @@ def __len__(self): return len(self._partitions) def __eq__(self, other): - if not isinstance(other, (PartitionSpec, tuple)): - return False - other_p = tuple(_canonicalize_partition(o) for o in other) if isinstance(other, PartitionSpec): - return (self._partitions == other_p and + return (self._partitions == other._partitions and self._unreduced == other._unreduced) - else: + elif isinstance(other, tuple): if self._unreduced: raise TypeError( f"other {other} cannot be of instance `tuple` when self {self} has" " unreduced in `__eq__` of PartitionSpec.") + other_p = tuple(_canonicalize_partition(o) for o in other) return self._partitions == other_p + else: + return False def __hash__(self): return hash((self._partitions, self._unreduced)) From cecf2f6bede94324c4fc159c428a3d024ce431fe Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Jun 2025 10:13:39 -0700 Subject: [PATCH 1498/1769] [imports] avoid top-level imports in jax.numpy sources --- jax/_src/lax/fft.py | 3 +- jax/_src/numpy/array_api_metadata.py | 4 +- jax/_src/numpy/array_creation.py | 30 +++++++------ jax/_src/numpy/array_methods.py | 29 ++++++------ jax/_src/numpy/error.py | 18 ++++---- jax/_src/numpy/fft.py | 54 +++++++++++----------- jax/_src/numpy/index_tricks.py | 10 ++--- jax/_src/numpy/indexing.py | 8 ++-- jax/_src/numpy/lax_numpy.py | 40 ++++++++--------- jax/_src/numpy/linalg.py | 27 +++++------ jax/_src/numpy/polynomial.py | 2 +- jax/_src/numpy/reductions.py | 8 ++-- jax/_src/numpy/scalar_types.py | 7 +-- jax/_src/numpy/setops.py | 7 +-- jax/_src/numpy/sorting.py | 18 ++++---- jax/_src/numpy/tensor_contractions.py | 16 ++++--- jax/_src/numpy/ufunc_api.py | 64 ++++++++++++++------------- jax/_src/numpy/vectorize.py | 2 +- jax/_src/numpy/window_functions.py | 2 +- 19 files changed, 182 insertions(+), 167 deletions(-) diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 2eebe6d91f22..08e06287b784 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -21,8 +21,6 @@ import numpy as np -from jax 
import lax - from jax._src import dispatch from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct @@ -30,6 +28,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.lax import lax from jax._src.lib.mlir.dialects import hlo __all__ = [ diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index d634a2856a1b..af4e27cd5f8d 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -21,7 +21,6 @@ from types import ModuleType -import jax from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc from jax._src import config @@ -40,6 +39,7 @@ def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: if api_version is not None and api_version != __array_api_version__: raise ValueError(f"{api_version=!r} is not available; " f"available versions are: {[__array_api_version__]}") + import jax.numpy # pytype: disable=import-error return jax.numpy @@ -77,7 +77,7 @@ def default_device(self): def devices(self): out = [None] # None indicates "uncommitted" for backend in xb.backends(): - out.extend(jax.devices(backend)) + out.extend(xb.devices(backend)) return out def capabilities(self): diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py index 86bcfb2c02f6..b14c2fe73faa 100644 --- a/jax/_src/numpy/array_creation.py +++ b/jax/_src/numpy/array_creation.py @@ -19,18 +19,16 @@ import numpy as np -import jax -from jax import lax -from jax._src.api import jit +from jax._src.api import device_put, jit from jax._src import core from jax._src import dtypes -from jax._src.lax import lax as lax_internal +from jax._src.lax import lax from jax._src.lib import xla_client as xc from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.sharding import Sharding from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike from jax._src.util import canonicalize_axis, set_module -from jax.sharding import Sharding export = set_module('jax.numpy') @@ -205,6 +203,8 @@ def full(shape: Any, fill_value: ArrayLike, Array([[0, 1, 2], [0, 1, 2]], dtype=int32) """ + from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error + dtypes.check_user_dtype_supported(dtype, "full") util.check_arraylike("full", fill_value) @@ -212,8 +212,8 @@ def full(shape: Any, fill_value: ArrayLike, shape = canonicalize_shape(shape) return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device)) else: - return jax.device_put( - util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + return device_put( + util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device) @export @@ -394,6 +394,8 @@ def full_like(a: ArrayLike | DuckTypedArray, Array([[1, 1, 1], [2, 2, 2]], dtype=int32) """ + from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error + if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing util.check_arraylike("full_like", 0, fill_value) else: @@ -408,8 +410,8 @@ def full_like(a: ArrayLike | DuckTypedArray, else: shape = np.shape(a) if shape is None else shape # type: ignore[arg-type] dtype = dtypes.result_type(a) if dtype is None else dtype - return jax.device_put( - util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + return device_put( + util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device) @overload 
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, @@ -510,6 +512,8 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: """Implementation of linspace differentiable in start and stop args.""" + from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error + dtypes.check_user_dtype_supported(dtype, "linspace") if num < 0: raise ValueError(f"Number of samples, {num}, must be non-negative.") @@ -529,13 +533,13 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, bounds_shape.insert(axis, 1) div = (num - 1) if endpoint else num if num > 1: - delta: Array = lax.convert_element_type(stop - start, computation_dtype) / jax.numpy.array(div, dtype=computation_dtype) + delta: Array = lax.convert_element_type(stop - start, computation_dtype) / asarray(div, dtype=computation_dtype) iota_shape = [1,] * len(bounds_shape) iota_shape[axis] = div # This approach recovers the endpoints with float32 arithmetic, # but can lead to rounding errors for integer outputs. real_dtype = dtypes.finfo(computation_dtype).dtype - step = lax.iota(real_dtype, div).reshape(iota_shape) / jax.numpy.array(div, real_dtype) + step = lax.iota(real_dtype, div).reshape(iota_shape) / asarray(div, real_dtype) step = step.astype(computation_dtype) out = (broadcast_start.reshape(bounds_shape) * (1 - step) + broadcast_stop.reshape(bounds_shape) * step) @@ -545,7 +549,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, canonicalize_axis(axis, out.ndim)) elif num == 1: - delta = jax.numpy.asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) + delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) out = broadcast_start.reshape(bounds_shape) else: # num == 0 degenerate case, match numpy behavior empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) @@ -557,7 +561,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, out = lax.floor(out) sharding = util.canonicalize_device_to_sharding(device) - result = lax_internal._convert_element_type(out, dtype, sharding=sharding) + result = lax._convert_element_type(out, dtype, sharding=sharding) return (result, delta) if retstep else result diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index b29b95219325..958085e19a53 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -29,9 +29,8 @@ from typing import Any, Callable, Sequence import numpy as np -import jax + from jax import lax -from jax.sharding import Sharding from jax._src import api from jax._src import core from jax._src import dtypes @@ -44,6 +43,7 @@ from jax._src.numpy import lax_numpy from jax._src.numpy import tensor_contractions from jax._src.pjit import PartitionSpec +from jax._src.sharding import Sharding from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -612,12 +612,13 @@ def _deepcopy(self: Array, memo: Any) -> Array: def __array_module__(self, types): if all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): + import jax.numpy # pytype: disable=import-error return jax.numpy else: return NotImplemented -@partial(jax.jit, static_argnums=(1,2,3)) +@partial(api.jit, static_argnums=(1,2,3)) def _multi_slice(self: Array, start_indices: tuple[tuple[int, ...]], limit_indices: tuple[tuple[int, ...]], @@ -637,7 +638,7 @@ def _multi_slice(self: 
Array, # The next two functions are related to iter(array), implemented here to # avoid circular imports. -@jax.jit +@api.jit def _unstack(x: Array) -> list[Array]: dims = (0,) return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] @@ -776,7 +777,7 @@ def __repr__(self) -> str: return f"_IndexUpdateRef({self.array!r}, {self.index!r})" def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None, + mode: str | lax.GatherScatterMode | None = None, fill_value: ArrayLike | None = None, out_sharding: Sharding | None = None): """Equivalent to ``x[idx]``. @@ -798,7 +799,7 @@ def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, def set(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> None: + mode: str | lax.GatherScatterMode | None = None) -> None: """Pure equivalent of ``x[idx] = y``. Returns the value of ``x`` that would result from the NumPy-style @@ -816,7 +817,7 @@ def set(self, values: ArrayLike, *, indices_are_sorted: bool = False, def apply(self, func: Callable[[ArrayLike], Array], *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. Returns the value of ``x`` that would result from applying the unary @@ -840,7 +841,7 @@ def _scatter_apply(x, indices, y, dims, **kwargs): def add(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] += y``. Returns the value of ``x`` that would result from the NumPy-style @@ -855,7 +856,7 @@ def add(self, values: ArrayLike, *, def subtract(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] -= y``. Returns the value of ``x`` that would result from the NumPy-style @@ -870,7 +871,7 @@ def subtract(self, values: ArrayLike, *, def multiply(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] *= y``. Returns the value of ``x`` that would result from the NumPy-style @@ -887,7 +888,7 @@ def multiply(self, values: ArrayLike, *, def divide(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] /= y``. Returns the value of ``x`` that would result from the NumPy-style @@ -904,7 +905,7 @@ def divide(self, values: ArrayLike, *, def power(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] **= y``. 
Returns the value of ``x`` that would result from the NumPy-style @@ -921,7 +922,7 @@ def power(self, values: ArrayLike, *, def min(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style @@ -937,7 +938,7 @@ def min(self, values: ArrayLike, *, def max(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | jax.lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None) -> Array: """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py index e2c23b43bdf8..8af0c52566b5 100644 --- a/jax/_src/numpy/error.py +++ b/jax/_src/numpy/error.py @@ -15,9 +15,11 @@ import contextlib from typing import Literal, Sequence -import jax +import numpy as np + from jax._src import config -from jax._src.typing import ArrayLike +from jax._src import dtypes +from jax._src.typing import Array, ArrayLike Category = Literal["nan", "divide", "oob"] @@ -40,7 +42,7 @@ def _is_category_disabled( def _set_error_if_with_category( - pred: jax.Array, + pred: Array, /, msg: str, category: Category | None = None, @@ -65,7 +67,7 @@ def _set_error_if_with_category( error_check_lib.set_error_if(pred, msg) -def _set_error_if_nan(pred: jax.Array, /): +def _set_error_if_nan(pred: Array, /): """Set the internal error state if any element of `pred` is `NaN`. This function is disabled if the `jax_error_checking_behavior_nan` flag is @@ -74,17 +76,17 @@ def _set_error_if_nan(pred: jax.Array, /): if config.error_checking_behavior_nan.value == "ignore": return - # TODO(mattjj): fix the circular import issue. - import jax.numpy as jnp - if not jnp.issubdtype(pred.dtype, jnp.floating): # only check floats + if not dtypes.issubdtype(pred.dtype, np.floating): # only check floats return # TODO(mattjj): fix the circular import issue. from jax._src import error_check as error_check_lib + import jax.numpy as jnp + error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered") -def _set_error_if_divide_by_zero(pred: jax.Array, /): +def _set_error_if_divide_by_zero(pred: Array, /): """Set the internal error state if any element of `pred` is zero. 
This function is intended for checking if the denominator of a division is diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 2316ad73ffeb..21da91ce613f 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -18,8 +18,8 @@ import operator import numpy as np -from jax import lax from jax._src import dtypes +from jax._src.lax import fft as lax_fft from jax._src.lib import xla_client from jax._src.util import safe_zip from jax._src.numpy.util import ensure_arraylike, promote_dtypes_inexact @@ -45,7 +45,7 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: '"ortho" or "forward".') -def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, +def _fft_core(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int] | None, norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -80,14 +80,14 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, in_s = list(arr.shape) for axis, x in safe_zip(axes, s): in_s[axis] = x - if fft_type == lax.FftType.IRFFT: + if fft_type == lax_fft.FftType.IRFFT: in_s[-1] = (in_s[-1] // 2 + 1) # Cropping arr = arr[tuple(map(slice, in_s))] # Padding arr = jnp.pad(arr, [(0, x-y) for x, y in zip(in_s, arr.shape)]) else: - if fft_type == lax.FftType.IRFFT: + if fft_type == lax_fft.FftType.IRFFT: s = [arr.shape[axis] for axis in axes[:-1]] if axes: s += [max(0, 2 * (arr.shape[axes[-1]] - 1))] @@ -103,10 +103,10 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, return transformed -def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array: +def _fft_core_nd(arr: Array, fft_type: lax_fft.FftType, s: Shape) -> Array: # XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly. if len(s) <= 3: - return lax.fft(arr, fft_type, tuple(s)) + return lax_fft.fft(arr, fft_type, tuple(s)) # For larger N, we repeatedly apply N<=3 transforms until we reach the # requested dimension. We special case N=4 to use two 2-D transforms instead @@ -115,16 +115,16 @@ def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array: n = 2 if len(s) == 4 else 3 src = tuple(range(arr.ndim - len(s), arr.ndim - n)) dst = tuple(range(arr.ndim - len(s) + n, arr.ndim)) - if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}: - arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + if fft_type in {lax_fft.FftType.RFFT, lax_fft.FftType.FFT}: + arr = lax_fft.fft(arr, fft_type, tuple(s)[-n:]) arr = jnp.moveaxis(arr, src, dst) - arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n]) + arr = _fft_core_nd(arr, lax_fft.FftType.FFT, s[:-n]) arr = jnp.moveaxis(arr, dst, src) else: arr = jnp.moveaxis(arr, src, dst) - arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n]) + arr = _fft_core_nd(arr, lax_fft.FftType.IFFT, s[:-n]) arr = jnp.moveaxis(arr, dst, src) - arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + arr = lax_fft.fft(arr, fft_type, tuple(s)[-n:]) return arr @@ -199,7 +199,7 @@ def fftn(a: ArrayLike, s: Shape | None = None, >>> jnp.allclose(x, jnp.fft.ifftn(x_fftn)) Array(True, dtype=bool) """ - return _fft_core('fftn', lax.FftType.FFT, a, s, axes, norm) + return _fft_core('fftn', lax_fft.FftType.FFT, a, s, axes, norm) def ifftn(a: ArrayLike, s: Shape | None = None, @@ -267,7 +267,7 @@ def ifftn(a: ArrayLike, s: Shape | None = None, [[ 2.5 +0.j 0. -0.58j 0. 
+0.58j] [ 0.17+0.j -0.83-0.29j -0.83+0.29j]] """ - return _fft_core('ifftn', lax.FftType.IFFT, a, s, axes, norm) + return _fft_core('ifftn', lax_fft.FftType.IFFT, a, s, axes, norm) def rfftn(a: ArrayLike, s: Shape | None = None, @@ -358,7 +358,7 @@ def rfftn(a: ArrayLike, s: Shape | None = None, >>> jnp.fft.rfftn(x1) Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64) """ - return _fft_core('rfftn', lax.FftType.RFFT, a, s, axes, norm) + return _fft_core('rfftn', lax_fft.FftType.RFFT, a, s, axes, norm) def irfftn(a: ArrayLike, s: Shape | None = None, @@ -435,7 +435,7 @@ def irfftn(a: ArrayLike, s: Shape | None = None, [[-2., -2., -2.], [-2., -2., -2.]]], dtype=float32) """ - return _fft_core('irfftn', lax.FftType.IRFFT, a, s, axes, norm) + return _fft_core('irfftn', lax_fft.FftType.IRFFT, a, s, axes, norm) def _axis_check_1d(func_name: str, axis: int | None): @@ -446,7 +446,7 @@ def _axis_check_1d(func_name: str, axis: int | None): "Got axis = %r." % (full_name, full_name, axis) ) -def _fft_core_1d(func_name: str, fft_type: lax.FftType, +def _fft_core_1d(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, n: int | None, axis: int | None, norm: str | None) -> Array: _axis_check_1d(func_name, axis) @@ -514,7 +514,7 @@ def fft(a: ArrayLike, n: int | None = None, >>> jnp.allclose(x, jnp.fft.ifft(x_fft)) Array(True, dtype=bool) """ - return _fft_core_1d('fft', lax.FftType.FFT, a, n=n, axis=axis, + return _fft_core_1d('fft', lax_fft.FftType.FFT, a, n=n, axis=axis, norm=norm) @@ -570,7 +570,7 @@ def ifft(a: ArrayLike, n: int | None = None, [ 0.67+0.58j -0.5 +1.44j 0.17+2.02j 1.83+0.29j] [ 0.67-0.58j -0.5 -1.44j 0.17-2.02j 1.83-0.29j]] """ - return _fft_core_1d('ifft', lax.FftType.IFFT, a, n=n, axis=axis, + return _fft_core_1d('ifft', lax_fft.FftType.IFFT, a, n=n, axis=axis, norm=norm) @@ -631,7 +631,7 @@ def rfft(a: ArrayLike, n: int | None = None, [ 1.-2.j, 3.-4.j, 5.-6.j], [-1.+0.j, -1.+0.j, -1.+0.j]], dtype=complex64) """ - return _fft_core_1d('rfft', lax.FftType.RFFT, a, n=n, axis=axis, + return _fft_core_1d('rfft', lax_fft.FftType.RFFT, a, n=n, axis=axis, norm=norm) @@ -691,7 +691,7 @@ def irfft(a: ArrayLike, n: int | None = None, [-0.75, -1.25, -1.75], [ 0.25, 0.75, 1.25]], dtype=float32) """ - return _fft_core_1d('irfft', lax.FftType.IRFFT, a, n=n, axis=axis, + return _fft_core_1d('irfft', lax_fft.FftType.IRFFT, a, n=n, axis=axis, norm=norm) @@ -781,7 +781,7 @@ def hfft(a: ArrayLike, n: int | None = None, conj_a = ufuncs.conj(a) _axis_check_1d('hfft', axis) nn = (conj_a.shape[axis] - 1) * 2 if n is None else n - return _fft_core_1d('hfft', lax.FftType.IRFFT, conj_a, n=n, axis=axis, + return _fft_core_1d('hfft', lax_fft.FftType.IRFFT, conj_a, n=n, axis=axis, norm=norm) * nn @@ -831,12 +831,12 @@ def ihfft(a: ArrayLike, n: int | None = None, _axis_check_1d('ihfft', axis) arr = jnp.asarray(a) nn = arr.shape[axis] if n is None else n - output = _fft_core_1d('ihfft', lax.FftType.RFFT, arr, n=n, axis=axis, + output = _fft_core_1d('ihfft', lax_fft.FftType.RFFT, arr, n=n, axis=axis, norm=norm) return ufuncs.conj(output) * (1 / nn) -def _fft_core_2d(func_name: str, fft_type: lax.FftType, a: ArrayLike, +def _fft_core_2d(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int], norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -923,7 +923,7 @@ def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), >>> jnp.allclose(x, jnp.fft.ifft2(x_fft2)) Array(True, dtype=bool) """ - return _fft_core_2d('fft2', lax.FftType.FFT, a, 
s=s, axes=axes, + return _fft_core_2d('fft2', lax_fft.FftType.FFT, a, s=s, axes=axes, norm=norm) @@ -995,7 +995,7 @@ def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [-0.33-0.58j, -0.33-0.58j], [-0.33+0.58j, -0.33+0.58j]]], dtype=complex64) """ - return _fft_core_2d('ifft2', lax.FftType.IFFT, a, s=s, axes=axes, + return _fft_core_2d('ifft2', lax_fft.FftType.IFFT, a, s=s, axes=axes, norm=norm) @@ -1074,7 +1074,7 @@ def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) """ - return _fft_core_2d('rfft2', lax.FftType.RFFT, a, s=s, axes=axes, + return _fft_core_2d('rfft2', lax_fft.FftType.RFFT, a, s=s, axes=axes, norm=norm) @@ -1149,7 +1149,7 @@ def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]]], dtype=float32) """ - return _fft_core_2d('irfft2', lax.FftType.IRFFT, a, s=s, axes=axes, + return _fft_core_2d('irfft2', lax_fft.FftType.IRFFT, a, s=s, axes=axes, norm=norm) diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index ec67d7489f30..ab07ecad0cf5 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -17,7 +17,9 @@ from collections.abc import Iterable from typing import Any, Union -import jax +import numpy as np + +from jax._src import config from jax._src import core from jax._src.numpy.util import promote_dtypes from jax._src.numpy.lax_numpy import ( @@ -26,8 +28,6 @@ from jax._src.typing import Array, ArrayLike from jax._src.util import set_module -import numpy as np - export = set_module('jax.numpy') @@ -83,7 +83,7 @@ def __getitem__(self, key: slice | tuple[slice, ...]) -> Array: if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="mgrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="mgrid") for k in key) - with jax.numpy_dtype_promotion('standard'): + with config.numpy_dtype_promotion('standard'): output = promote_dtypes(*output) output_arr = meshgrid(*output, indexing='ij', sparse=False) if len(output_arr) == 0: @@ -128,7 +128,7 @@ def __getitem__( if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="ogrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key) - with jax.numpy_dtype_promotion('standard'): + with config.numpy_dtype_promotion('standard'): output = promote_dtypes(*output) return meshgrid(*output, indexing='ij', sparse=True) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 6aa5d6b87ef4..573352135806 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -20,7 +20,8 @@ import string from typing import Any, NamedTuple, Sequence -import jax +import numpy as np + from jax import lax from jax._src import array from jax._src import config @@ -39,7 +40,6 @@ from jax._src.tree_util import tree_flatten from jax._src.typing import Array, ArrayLike, StaticScalar from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update -import numpy as np export = set_module('jax.numpy') @@ -314,8 +314,10 @@ def replace(tup, val): return lax.full(out_shape, 0, a.dtype) if mode == "one_hot": + from jax import nn # pytype: disable=import-error + indices = _normalize_index(indices, axis_size) - hot = jax.nn.one_hot(indices, axis_size, dtype=np.bool_) + hot = nn.one_hot(indices, axis_size, dtype=np.bool_) if a.ndim == 1: return einsum.einsum("...b,b->...", hot, a, 
preferred_element_type=a.dtype) if axis_int > len(string.ascii_letters) - 2: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 171a64a758ad..f323bc64718b 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -35,9 +35,9 @@ from typing import (Any, IO, Literal, Protocol, TypeVar, Union, overload) import warnings -import jax -from jax import jit from jax import lax +from jax._src.api import jit +from jax._src import api from jax._src import config from jax._src import core from jax._src import deprecations @@ -508,7 +508,7 @@ def isscalar(element: Any) -> bool: """ if np.isscalar(element): return True - elif isinstance(element, (np.ndarray, jax.Array)): + elif isinstance(element, (np.ndarray, Array)): return element.ndim == 0 elif hasattr(element, '__jax_array__'): return asarray(element).ndim == 0 @@ -3418,7 +3418,7 @@ def clip( ) util.check_arraylike("clip", arr) - if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)): + if any(iscomplexobj(t) for t in (arr, min, max)): raise ValueError( "Clip received a complex value either through the input or the min/max " "keywords. Complex values have no ordering and cannot be clipped. " @@ -4676,7 +4676,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: [1., 1., 1., 0.]], dtype=float32) """ util.check_arraylike("concat", *arrays) - return jax.numpy.concatenate(arrays, axis=axis) + return concatenate(arrays, axis=axis) @export @@ -4732,7 +4732,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_2d)(tup) + arrs = api.vmap(atleast_2d)(tup) else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("vstack", *tup, emit_warning=True) @@ -4791,7 +4791,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_1d)(tup) + arrs = api.vmap(atleast_1d)(tup) arr0_ndim = arrs.ndim - 1 else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. @@ -4854,7 +4854,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_3d)(tup) + arrs = api.vmap(atleast_3d)(tup) else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("dstack", *tup, emit_warning=True) @@ -4916,7 +4916,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: """ arrs: Array | list[Array] | np.ndarray if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup + arrs = api.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("column_stack", *tup, emit_warning=True) @@ -5354,7 +5354,7 @@ def _make_string_array( ) # Just do a device_put since XLA does not support string as a data type. - return jax.device_put(x=object, device=device) + return api.device_put(x=object, device=device) @export @@ -5447,7 +5447,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and device is None): # Keep the output uncommitted. 
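The jax.numpy changes in this patch repeat two moves: route top-level `jax.*` uses through `jax._src.api` (`api.jit`, `api.vmap`, `api.device_put`), and push genuinely circular imports down into function bodies. A schematic of the deferred-import move; the function name here is hypothetical:

    # Schematic: this module is imported by lax_numpy, so importing
    # lax_numpy at module scope here would create a cycle.
    def full_like_sketch(fill_value, dtype=None):
        # Deferred to call time, when lax_numpy is fully initialized.
        from jax._src.numpy.lax_numpy import asarray  # pytype: disable=import-error
        return asarray(fill_value, dtype=dtype)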
- return jax.device_put(object) + return api.device_put(object) # String arrays need separate handling because XLA does not support string # as a data type. @@ -5551,7 +5551,7 @@ def _get_platform( return device_or_sharding elif device_or_sharding is None: if config.default_device.value is None: - return jax.default_backend() + return xla_bridge.default_backend() else: return _get_platform(config.default_device.value) else: @@ -6077,7 +6077,7 @@ def fromfunction(function: Callable[..., Array], shape: Any, shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()") for i in range(len(shape)): in_axes = [0 if i == j else None for j in range(len(shape))] - function = jax.vmap(function, in_axes=tuple(in_axes[::-1])) + function = api.vmap(function, in_axes=tuple(in_axes[::-1])) return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) @@ -6166,7 +6166,7 @@ def eye(N: DimSize, M: DimSize | None = None, # instead of putting it on default device and then on the specific device output = _eye(N, M=M, k=k, dtype=dtype) if device is not None: - return jax.device_put(output, device=device) + return api.device_put(output, device=device) return output @@ -6299,7 +6299,7 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, # instead of putting it on default device and then on the specific device output = _arange(start, stop=stop, step=step, dtype=dtype) if device is not None: - return jax.device_put(output, device=device) + return api.device_put(output, device=device) return output @@ -6496,7 +6496,7 @@ def _i0(x): @_i0.defjvp def _i0_jvp(primals, tangents): - primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) + primal_out, tangent_out = api.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) @export @@ -7792,7 +7792,7 @@ def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: util.check_arraylike("trim_zeros", filt, emit_warning=True) core.concrete_or_error(None, filt, "Error arose in the `filt` argument of trim_zeros()") - filt_arr = jax.numpy.asarray(filt) + filt_arr = asarray(filt) del filt if filt_arr.ndim != 1: # Added on 2024-09-11 @@ -8173,9 +8173,9 @@ def apply_along_axis( axis = _canonicalize_axis(axis, num_dims) func = lambda arr: func1d(arr, *args, **kwargs) for i in range(1, num_dims - axis): - func = jax.vmap(func, in_axes=i, out_axes=-1) + func = api.vmap(func, in_axes=i, out_axes=-1) for i in range(axis): - func = jax.vmap(func, in_axes=0, out_axes=0) + func = api.vmap(func, in_axes=0, out_axes=0) return func(arr) @@ -9623,7 +9623,7 @@ def _rank(x): def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: op = _sort_lt_comparator if side == 'left' else _sort_le_comparator - comparisons = jax.vmap(op, in_axes=(0, None))(sorted_arr, query) + comparisons = api.vmap(op, in_axes=(0, None))(sorted_arr, query) return comparisons.sum(dtype=dtype, axis=0) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 0e20e5b2a416..f2deddd52f05 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -23,10 +23,11 @@ import operator from typing import Literal, NamedTuple, overload -import jax -from jax import jit, custom_jvp from jax import lax +from jax._src.api import jit +from jax._src import config +from jax._src.custom_derivatives import custom_jvp from jax._src import deprecations from jax._src.lax import lax as lax_internal from jax._src.lax.lax import PrecisionLike @@ -44,24 +45,24 @@ class 
EighResult(NamedTuple): - eigenvalues: jax.Array - eigenvectors: jax.Array + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: jax.Array - R: jax.Array + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: jax.Array - logabsdet: jax.Array + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: jax.Array - S: jax.Array - Vh: jax.Array + U: Array + S: Array + Vh: Array def _H(x: ArrayLike) -> Array: @@ -995,7 +996,7 @@ def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) @_pinv.defjvp -@jax.default_matmul_precision("float32") +@config.default_matmul_precision("float32") def _pinv_jvp(rtol, hermitian, primals, tangents): # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM @@ -1617,7 +1618,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: ndim = x_arr.ndim if ndim < 2: raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {ndim=}") - return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) + return lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) @export diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 2b2923ba93ce..2f7a32c3f52d 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -19,8 +19,8 @@ import numpy as np -from jax import jit from jax import lax +from jax._src.api import jit from jax._src import dtypes from jax._src import core from jax._src.lax import lax as lax_internal diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index cbfda25eafcf..e1f499ccc530 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -23,9 +23,9 @@ import numpy as np -import jax from jax import lax from jax._src import api +from jax._src import config from jax._src import core from jax._src import deprecations from jax._src import dtypes @@ -793,7 +793,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): size = 1 a_shape = np.shape(a) for a in axis_seq: - size *= maybe_named_axis(a, lambda i: a_shape[i], jax.lax.axis_size) + size *= maybe_named_axis(a, lambda i: a_shape[i], lax.axis_size) return size @@ -1136,7 +1136,7 @@ def _var(a: Array, axis: Axis = None, dtype: DTypeLike | None = None, normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype)) result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where) result = lax.div(result, normalizer).astype(dtype) - with jax.debug_nans(False): + with config.debug_nans(False): result = _where(normalizer > 0, result, np.nan) return result @@ -2513,7 +2513,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - with jax.debug_nans(False): + with config.debug_nans(False): a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 2b0e04adc997..1abe7cf66c15 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -22,11 +22,11 @@ from typing import Any -import jax +import numpy as np + from jax._src.typing import Array from jax._src import core from jax._src import dtypes -import numpy as np # Some objects below rewrite their __module__ attribute to this name. 
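The linalg result containers above only retype their fields from `jax.Array` to the internal `Array` alias; runtime behavior is unchanged. A quick reminder of how the named results surface to users:

    import jax.numpy as jnp

    a = jnp.array([[2.0, 1.0],
                   [1.0, 2.0]])
    res = jnp.linalg.eigh(a)   # returns an EighResult NamedTuple
    evals, evecs = res         # still unpacks like a plain tuple
    assert res.eigenvalues.shape == (2,)
    assert res.eigenvectors.shape == (2, 2)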
@@ -46,7 +46,8 @@ def __ne__(self, other: Any) -> bool: return not (self == other) def __call__(self, x: Any) -> Array: - return jax.numpy.asarray(x, dtype=self.dtype) + from jax._src.numpy.lax_numpy import asarray + return asarray(x, dtype=self.dtype) def __instancecheck__(self, instance: Any) -> bool: return isinstance(instance, self.dtype.type) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index d4a8e41dd317..ef1d44ae01b1 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -21,10 +21,9 @@ import numpy as np -import jax -from jax import jit from jax import lax +from jax._src.api import jit from jax._src import core from jax._src import dtypes from jax._src.lax import lax as lax_internal @@ -59,8 +58,10 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool, else: return (arr1[:, None] == arr2[None, :]).any(-1) elif method == 'binary_search': + from jax._src.numpy.lax_numpy import searchsorted + arr2 = lax.sort(arr2) - ind = jax.numpy.searchsorted(arr2, arr1) + ind = searchsorted(arr2, arr1) if invert: return arr1 != arr2[ind] else: diff --git a/jax/_src/numpy/sorting.py b/jax/_src/numpy/sorting.py index a0f368e2ef07..be8f42ce6145 100644 --- a/jax/_src/numpy/sorting.py +++ b/jax/_src/numpy/sorting.py @@ -17,14 +17,14 @@ import numpy as np -import jax +from jax import lax + from jax._src import api from jax._src import core from jax._src import dtypes from jax._src.numpy import util from jax._src.util import canonicalize_axis, set_module from jax._src.typing import Array, ArrayLike -from jax import lax export = set_module('jax.numpy') @@ -226,7 +226,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: axis = canonicalize_axis(axis, arr.ndim) kth = canonicalize_axis(kth, arr.shape[axis]) - arr = jax.numpy.swapaxes(arr, axis, -1) + arr = arr.swapaxes(axis, -1) if dtypes.isdtype(arr.dtype, "unsigned integer"): # Here, we apply a trick to handle correctly 0 values for unsigned integers bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1 @@ -234,7 +234,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: bottom = -lax.top_k(-arr, kth + 1)[0] top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0] out = lax.concatenate([bottom, top], dimension=arr.ndim - 1) - return jax.numpy.swapaxes(out, -1, axis) + return out.swapaxes(-1, axis) @export @@ -297,7 +297,7 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: axis = canonicalize_axis(axis, arr.ndim) kth = canonicalize_axis(kth, arr.shape[axis]) - arr = jax.numpy.swapaxes(arr, axis, -1) + arr = arr.swapaxes(axis, -1) if dtypes.isdtype(arr.dtype, "unsigned integer"): # Here, we apply a trick to handle correctly 0 values for unsigned integers bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1] @@ -307,11 +307,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: # To avoid issues with duplicate values, we compute the top indices via a proxy set_to_zero = lambda a, i: a.at[i].set(0) for _ in range(arr.ndim - 1): - set_to_zero = jax.vmap(set_to_zero) - proxy = set_to_zero(jax.numpy.ones(arr.shape), bottom_ind) + set_to_zero = api.vmap(set_to_zero) + proxy = set_to_zero(lax.full(arr.shape, 1.0), bottom_ind) top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1] out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1) - return jax.numpy.swapaxes(out, -1, axis) + return out.swapaxes(-1, axis) @export @@ -421,7 +421,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A if len({np.shape(key) for key in key_arrays}) > 1: 
raise ValueError("all keys need to be the same shape") if np.ndim(key_arrays[0]) == 0: - return jax.numpy.array(0, dtype=dtypes.canonicalize_dtype(dtypes.int_)) + return lax.full((), 0, dtypes.canonicalize_dtype(dtypes.int_)) axis = canonicalize_axis(axis, np.ndim(key_arrays[0])) use_64bit_index = key_arrays[0].shape[axis] >= (1 << 31) iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_, diff --git a/jax/_src/numpy/tensor_contractions.py b/jax/_src/numpy/tensor_contractions.py index 979f68e28f6d..255da08e1816 100644 --- a/jax/_src/numpy/tensor_contractions.py +++ b/jax/_src/numpy/tensor_contractions.py @@ -20,7 +20,6 @@ import numpy as np -import jax from jax import lax from jax._src import core from jax._src import dtypes @@ -378,7 +377,7 @@ def vdot( a, b = util.ensure_arraylike("vdot", a, b) if dtypes.issubdtype(dtypes.dtype(a, canonicalize=True), np.complexfloating): a = ufuncs.conj(a) - return dot(jax.numpy.ravel(a), jax.numpy.ravel(b), precision=precision, + return dot(a.ravel(), b.ravel(), precision=precision, preferred_element_type=preferred_element_type) @@ -429,11 +428,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, >>> jnp.linalg.vecdot(a, b, axis=-1) Array([20, 47], dtype=int32) """ + from jax._src.numpy.lax_numpy import moveaxis + x1_arr, x2_arr = util.ensure_arraylike("jnp.vecdot", x1, x2) if x1_arr.shape[axis] != x2_arr.shape[axis]: raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}") - x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1) - x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1) + x1_arr = moveaxis(x1_arr, axis, -1) + x2_arr = moveaxis(x2_arr, axis, -1) return vectorize(partial(vdot, precision=precision, preferred_element_type=preferred_element_type), signature="(n),(n)->()")(x1_arr, x2_arr) @@ -604,8 +605,9 @@ def inner( """ a, b = util.ensure_arraylike("inner", a, b) if np.ndim(a) == 0 or np.ndim(b) == 0: - a = jax.numpy.asarray(a, dtype=preferred_element_type) - b = jax.numpy.asarray(b, dtype=preferred_element_type) + if preferred_element_type is not None: + a = a.astype(preferred_element_type) + b = b.astype(preferred_element_type) return a * b return tensordot(a, b, (-1, -1), precision=precision, preferred_element_type=preferred_element_type) @@ -643,4 +645,4 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") a, b = util.ensure_arraylike("outer", a, b) a, b = util.promote_dtypes(a, b) - return jax.numpy.ravel(a)[:, None] * jax.numpy.ravel(b)[None, :] + return a.ravel()[:, None] * b.ravel()[None, :] diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index da55212bae1f..243ab9aa0878 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -22,9 +22,11 @@ import operator from typing import Any -import jax +from jax._src import api from jax._src.typing import Array, ArrayLike, DTypeLike -from jax._src.lax import lax as lax_internal +from jax._src.lax import control_flow +from jax._src.lax import slicing +from jax._src.lax import lax from jax._src.numpy import indexing import jax._src.numpy.lax_numpy as jnp from jax._src.numpy.reductions import _moveaxis @@ -179,11 +181,11 @@ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> An call = self.__static_props['call'] or self._call_vectorized return call(*args) - @partial(jax.jit, static_argnames=['self']) + @partial(api.jit, static_argnames=['self']) def 
_call_vectorized(self, *args): return vectorize(self._func)(*args) - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) + @partial(api.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) def reduce(self, a: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -249,8 +251,8 @@ def reduce(self, a: ArrayLike, axis: int | None = 0, if self.identity is None and initial is None: raise ValueError(f"reduction operation {self.__name__!r} does not have an identity, " "so to use a where mask one has to specify 'initial'.") - if lax_internal._dtype(where) != bool: - raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") + if lax._dtype(where) != bool: + raise ValueError(f"where argument must have dtype=bool; got dtype={lax._dtype(where)}") reduce = self.__static_props['reduce'] or self._reduce_via_scan return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) @@ -258,11 +260,11 @@ def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLik keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 - arr = lax_internal.asarray(arr) + arr = lax.asarray(arr) if initial is None: initial = self.identity if dtype is None: - dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = api.eval_shape(self._func, lax._one(arr), lax._one(arr)).dtype if where is not None: where = _broadcast_to(where, arr.shape) if isinstance(axis, tuple): @@ -306,15 +308,15 @@ def body_fun(i, val): else: start_index = 0 start_value = initial - start_value = _broadcast_to(lax_internal.asarray(start_value).astype(dtype), arr.shape[1:]) + start_value = _broadcast_to(lax.asarray(start_value).astype(dtype), arr.shape[1:]) - result = jax.lax.fori_loop(start_index, arr.shape[0], body_fun, start_value) + result = control_flow.fori_loop(start_index, arr.shape[0], body_fun, start_value) if keepdims: result = result.reshape(final_shape) return result - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) + @partial(api.jit, static_argnames=['self', 'axis', 'dtype']) def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: """Accumulate operation derived from binary ufunc. 
@@ -376,10 +378,10 @@ def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 check_arraylike(f"{self.__name__}.accumulate", arr) - arr = lax_internal.asarray(arr) + arr = lax.asarray(arr) if dtype is None: - dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = api.eval_shape(self._func, lax._one(arr), lax._one(arr)).dtype if axis is None or isinstance(axis, tuple): raise ValueError("accumulate does not allow multiple axes") @@ -390,10 +392,10 @@ def scan_fun(carry, _): i, x = carry y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype))) return (i + 1, y), y - _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) + _, result = control_flow.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) + @partial(api.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: """Update elements of an array via the specified unary or binary ufunc. @@ -440,15 +442,15 @@ def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: assert len(args) in {0, 1} check_arraylike(f"{self.__name__}.at", a, *args) - dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype - a = lax_internal.asarray(a).astype(dtype) - args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) + dtype = api.eval_shape(self._func, lax._one(a), *(lax._one(arg) for arg in args)).dtype + a = lax.asarray(a).astype(dtype) + args = tuple(lax.asarray(arg).astype(dtype) for arg in args) indices = indexing.eliminate_deprecated_list_indexing(indices) if not indices: return a shapes = [np.shape(i) for i in indices if not isinstance(i, slice)] - shape = shapes and jax.lax.broadcast_shapes(*shapes) + shape = shapes and lax.broadcast_shapes(*shapes) if not shape: return a.at[indices].set(self(a.at[indices].get(), *args)) @@ -462,10 +464,10 @@ def scan_fun(carry, x): idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices) a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args))) return (i + 1, a), x - carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) # type: ignore[arg-type] + carry, _ = control_flow.scan(scan_fun, (0, a), None, len(indices[0])) # type: ignore[arg-type] return carry[1] - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) + @partial(api.jit, static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: """Reduce an array between specified indices via a binary ufunc. 
@@ -517,7 +519,7 @@ def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) - a = lax_internal.asarray(a) + a = lax.asarray(a) idx_tuple = indexing.eliminate_deprecated_list_indexing(indices) assert len(idx_tuple) == 1 indices = idx_tuple[0] @@ -531,17 +533,17 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, raise ValueError("reduceat requires a single integer axis.") axis = canonicalize_axis(axis, a.ndim) out = indexing.take(a, indices, axis=axis) - ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]), - list(np.delete(np.arange(out.ndim), axis))) - ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) - ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) + ind = lax.expand_dims(jnp.append(indices, a.shape[axis]), + list(np.delete(np.arange(out.ndim), axis))) + ind_start = slicing.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) + ind_end = slicing.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self(out, indexing.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), + self(out, indexing.take(a, lax.expand_dims(i, (0,)), axis=axis)), out) - return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) + return control_flow.fori_loop(0, a.shape[axis], loop_body, out) - @partial(jax.jit, static_argnums=[0]) + @partial(api.jit, static_argnums=[0]) def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: """Apply the function to all pairs of values in ``A`` and ``B``. @@ -584,8 +586,8 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: if self.nout != 1: raise ValueError("outer only supported for functions returning a single value") check_arraylike(f"{self.__name__}.outer", A, B) - _ravel = lambda A: jax.lax.reshape(A, (np.size(A),)) - result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) + _ravel = lambda A: lax.reshape(A, (np.size(A),)) + result = api.vmap(api.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) return result.reshape(*np.shape(A), *np.shape(B)) diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index 5ea9d697d27d..f166a96a4693 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -23,7 +23,7 @@ from jax._src import api from jax._src import config -from jax import lax +from jax._src.lax import lax from jax._src.numpy import lax_numpy as jnp from jax._src.util import set_module, safe_map as map, safe_zip as zip diff --git a/jax/_src/numpy/window_functions.py b/jax/_src/numpy/window_functions.py index 96a15db777a8..6d1bfb245272 100644 --- a/jax/_src/numpy/window_functions.py +++ b/jax/_src/numpy/window_functions.py @@ -16,11 +16,11 @@ from jax._src import core from jax._src import dtypes +from jax._src.lax import lax from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.typing import Array, ArrayLike from jax._src.util import set_module -from jax import lax export = set_module('jax.numpy') From e24f7807650d93c70ef7be3d779f64a287a113e8 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 3 Jun 2025 10:56:39 -0700 Subject: [PATCH 1499/1769] [Mosaic GPU] Add lowering for `2xf32 -> 2xf8e4m3fn` conversions. The conversion uses the `cvt.rn.satfinite.e4m3x2.f32` intrinsics, which means that the saturation behaviour is different from XLA's default. 
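A rough illustration of the difference (a minimal sketch, not part of this
change; it assumes XLA's default f32 -> f8e4m3fn conversion, which overflows
to NaN since e4m3fn has no infinity):

    import jax.numpy as jnp

    x = jnp.float32(1e6)  # far above 448, the largest finite e4m3fn value
    jnp.asarray(x).astype(jnp.float8_e4m3fn)  # XLA default: nan
    # cvt.rn.satfinite.e4m3x2.f32 saturates instead, producing 448 here.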
This does ask the question of which numerical behaviour we expect Mosaic GPU to uphold---but we probably don't want to propagate NaNs in this case anyway. PiperOrigin-RevId: 766737083 --- .../mosaic/gpu/fragmented_array.py | 31 ++++++++++++ tests/mosaic/gpu_test.py | 47 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 7278af5d7a91..925aa1575e2d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1402,11 +1402,14 @@ def __getitem__(self, idx): # TODO(apaszke): Support JAX dtypes here as well? def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): + index = ir.IndexType.get() i4 = ir.IntegerType.get_signless(4) i8 = ir.IntegerType.get_signless(8) i16 = ir.IntegerType.get_signless(16) i32 = ir.IntegerType.get_signless(32) bf16 = ir.BF16Type.get() + f32 = ir.F32Type.get() + f8e4m3fn = ir.Float8E4M3FNType.get() cur_dtype = self.mlir_dtype if cur_dtype == new_dtype: @@ -1540,6 +1543,34 @@ def upcast_to_bf16(reg, high): return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) + # TODO(bchetioui): handle conversions to/from other float8 types. + if cur_dtype == f32 and new_dtype == f8e4m3fn: + if vector_len != 2: + raise NotImplementedError(vector_len) + new_registers = np.empty_like(self.registers) + empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32)) + empty_result_vec = llvm.mlir_undef(ir.VectorType.get((2,), i8)) + for idx, reg in np.ndenumerate(self.registers): + e0 = vector.extractelement(reg, position=c(0, index)) + e1 = vector.extractelement(reg, position=c(1, index)) + new_reg_32 = llvm.inline_asm( + i32, + [e1, e0], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + ) + new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32)) + new_vec_f8 = vector.bitcast(ir.VectorType.get((4,), i8), new_vec_32) + res = llvm.insertelement( + empty_result_vec, + vector.extractelement(new_vec_f8, position=c(0, i32)), c(0, i32)) + res = llvm.insertelement( + res, + vector.extractelement(new_vec_f8, position=c(1, i32)), c(1, i32)) + new_registers[idx] = vector.bitcast(ir.VectorType.get((2,), f8e4m3fn), res) + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) # Generic path. from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 314dc8f8f41d..3bb0c1b9fe78 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -587,6 +587,53 @@ def kernel(ctx, inp, out, smem): f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y)) np.testing.assert_array_equal(f(x), y) + def test_f8_conversions(self): + jax_dtype_from, jax_dtype_to = jnp.float32, jnp.float8_e4m3fn + mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) + def kernel(ctx, inp, out, smem): + del ctx + smem_from, smem_to = smem + copy(inp, smem_from, swizzle=128) + t = mgpu.FragmentedArray.load_tiled( + smem_from, + swizzle=128, + is_signed=None, + layout=fa.WGMMA_LAYOUT, + ) + t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to)) + t.store_tiled(smem_to, swizzle=128) + copy(smem_to, out, swizzle=128) + + # These generative shenanigans are to ensure that we don't generate values + # that are too large for the target type. 
That is because the saturation + # behavior of the conversion is different between XLA and Mosaic GPU here + # (to use the NVIDIA internal, we allow Mosaic GPU to use the .satfinite + # modifier, which saturates to the largest finite value---while XLA would + # give us NaNs in this case). + max_finite_val = 0b111_1110 + + expected = jax.lax.bitcast_convert_type( + jax.random.randint( + jax.random.key(42), + (1, 1, 64, 128), + -max_finite_val, + max_finite_val + 1, + dtype=jnp.uint8, + ), + jax_dtype_to, + ) + x = expected.astype(jax_dtype_from) + + res = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + x, + expected, + (x, expected), + )(x) + np.testing.assert_array_equal(res, expected) + @parameterized.product( jax_dtype_from_to=( (jnp.int8, jnp.bfloat16), From cda50f5dbd4cd35b26bc1489847b37a977509f40 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Jun 2025 11:20:31 -0700 Subject: [PATCH 1500/1769] [JAX] Remove the redundant pjit BUILD target. PiperOrigin-RevId: 766747093 --- jax/BUILD | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index b577abcabf5f..80add9c096bd 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1501,17 +1501,6 @@ pytype_library( deps = [":jax"], ) -# TODO(apaszke): Remove this target -pytype_library( - name = "pjit", - srcs = ["experimental/pjit.py"], - visibility = ["//visibility:public"], - deps = [ - ":experimental", - ":jax", - ], -) - pytype_library( name = "jet", srcs = ["experimental/jet.py"], From 554cc01f76970bf3ed18e04a488577bedcf912ba Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 3 Jun 2025 11:23:59 -0700 Subject: [PATCH 1501/1769] [Mosaic GPU] Add BUILD rules for blackwell matmul kernel PiperOrigin-RevId: 766748530 --- tests/pallas/BUILD | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 48eddae69a60..5580cb2abd73 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -809,9 +809,13 @@ jax_multiplatform_test( name = "mgpu_matmul_test", srcs = ["mgpu_matmul_test.py"], enable_backends = [], - enable_configs = [], # TODO(justinfu): Enable B200 when available. + enable_configs = ["gpu_b200"], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, shard_count = 8, + tags = [ + # TODO(b/330364373): Remove when B200 is fully supported. + "notap", + ], deps = [ "//jax:pallas", "//jax:pallas_experimental_gpu_ops", @@ -822,6 +826,25 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "blackwell_matmul_mgpu_run", + srcs = ["//jax/experimental/pallas/ops/gpu:blackwell_matmul_mgpu.py"], + enable_backends = [], + enable_configs = ["gpu_b200"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + jax_multiplatform_test( name = "mgpu_ragged_dot_run", srcs = ["//jax/experimental/pallas/ops/gpu:ragged_dot_mgpu.py"], From 7dd0344f9db12ef676280acc51c30d6496ffc2ac Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 3 Jun 2025 11:48:03 -0700 Subject: [PATCH 1502/1769] [jaxlib] Add `PyClient::Compile` method that returns an unloaded `PyExecutable`. - Introduce `xla::PyExecutable` so we have a public constructor for returning an `xla::nb_class_ptr` from `PyClient::Compile`. There might be other acceptable ways of accomplishing this, but we have a `PyLoadedExecutable` object, so going for consistency. 
- Migrate uses of `ifrt::Executable` to `ifrt::ExecutableRef` (an alias
  for `std::shared_ptr<xla::ifrt::Executable>`). There might be undesirable
  consequences of doing this (i.e., there may be a reason why this wasn't
  migrated before).

PiperOrigin-RevId: 766757937
---
 jaxlib/py_client.cc              | 24 ++++++++++++++
 jaxlib/py_client.h               |  5 +++
 jaxlib/py_compile_only_client.cc | 39 +++++++++++++----------
 jaxlib/py_executable.h           | 54 ++++++++++++++++++++++++++++++--
 jaxlib/xla.cc                    | 20 ++++++------
 jaxlib/xla_client.py             |  2 +-
 6 files changed, 113 insertions(+), 31 deletions(-)

diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc
index 98bde8c27396..8e78f024e1ae 100644
--- a/jaxlib/py_client.cc
+++ b/jaxlib/py_client.cc
@@ -83,6 +83,7 @@ limitations under the License.
 #include "xla/python/nb_numpy.h"
 #include "xla/python/pjrt_ifrt/pjrt_array.h"
 #include "xla/python/pjrt_ifrt/pjrt_client.h"
+#include "xla/python/pjrt_ifrt/pjrt_executable.h"
 #include "xla/python/pjrt_ifrt/xla_compiler.h"
 #include "xla/python/pprof_profile_builder.h"
 #include "xla/python/types.h"
@@ -451,6 +452,29 @@ PyClient::CompileAndLoadIfrtProgram(
       std::move(traceback), std::move(fingerprint));
 }
 
+/* static */ absl::StatusOr<nb_class_ptr<PyExecutable>> PyClient::Compile(
+    nb_class_ptr<PyClient> client, std::string mlir_module,
+    ifrt::DeviceListRef executable_devices, CompileOptions options) {
+  ifrt::ExecutableRef executable_ref;
+  {
+    mlir::MLIRContext context;
+    nb::gil_scoped_release gil_release;
+    TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
+                        ParseMlirModuleString(mlir_module, context));
+    TF_ASSIGN_OR_RETURN(
+        auto topology,
+        client->ifrt_client()->GetTopologyForDevices(executable_devices));
+    auto xla_options = std::make_unique<ifrt::XlaCompileOptions>(
+        options, std::move(executable_devices));
+    TF_ASSIGN_OR_RETURN(auto pjrt_executable,
+                        PjRtCompile(std::move(options), module.get(),
+                                    *topology->description()));
+    TF_ASSIGN_OR_RETURN(executable_ref, ifrt::PjRtExecutable::Create(
+                                            std::move(pjrt_executable)));
+  }
+  return make_nb_class<PyExecutable>(executable_ref);
+}
+
 /* static */ absl::StatusOr<nb_class_ptr<PyLoadedExecutable>>
 PyClient::CompileAndLoad(nb_class_ptr<PyClient> client, std::string mlir_module,
                          ifrt::DeviceListRef executable_devices,
diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h
index 7f70fa4f111b..520fbf8b1e59 100644
--- a/jaxlib/py_client.h
+++ b/jaxlib/py_client.h
@@ -50,6 +50,7 @@ namespace xla {
 
 class PyClient;
 class PyLoadedExecutable;
+class PyExecutable;
 class PyArray;
 class PyDevice;
 class PyMemorySpace;
@@ -167,6 +168,10 @@ class PyClient {
       std::unique_ptr<ifrt::Program> ifrt_program,
       std::unique_ptr<ifrt::CompileOptions> ifrt_options);
 
+  static absl::StatusOr<nb_class_ptr<PyExecutable>> Compile(
+      nb_class_ptr<PyClient> client, std::string mlir_module,
+      ifrt::DeviceListRef executable_devices, CompileOptions options);
+
   static absl::StatusOr<nb_class_ptr<PyLoadedExecutable>> CompileAndLoad(
       nb_class_ptr<PyClient> client, std::string mlir_module,
       ifrt::DeviceListRef executable_devices, CompileOptions options,
diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc
index 274f57acba00..bcae15cd6438 100644
--- a/jaxlib/py_compile_only_client.cc
+++ b/jaxlib/py_compile_only_client.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "jaxlib/nb_class_ptr.h"
 #include "jaxlib/py_client.h"
 #include "jaxlib/py_device_list.h"
+#include "jaxlib/py_executable.h"
 #include "xla/pjrt/mlir_to_hlo.h"
 #include "xla/pjrt/pjrt_compiler.h"
 #include "xla/pjrt/pjrt_executable.h"
@@ -70,7 +71,7 @@ class CompileOnlyPyClient : public PyClient {
     return client;
   }
 
-  absl::StatusOr<ifrt::ExecutableRef> CompileUnloaded(
+  absl::StatusOr<nb_class_ptr<PyExecutable>> CompileUnloaded(
      absl::string_view mlir_module, ifrt::DeviceListRef executable_devices,
      CompileOptions options, std::vector<nb::capsule> host_callbacks) {
     if (!host_callbacks.empty()) {
@@ -78,22 +79,26 @@ class CompileOnlyPyClient : public PyClient {
           "Compiling with host_callbacks not available with compile-only "
           "client.");
     }
-    nb::gil_scoped_release gil_release;
-    mlir::MLIRContext context;
-    TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
-                        ParseMlirModuleString(mlir_module, context));
-    auto* ifrt_client =
-        llvm::dyn_cast_or_null<CompileOnlyIfRtClient>(this->ifrt_client());
-    CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a "
-                          "CompileOnlyIfRtClient";
-    auto xla_options = std::make_unique<ifrt::XlaCompileOptions>(
-        options, std::move(executable_devices));
-    TF_ASSIGN_OR_RETURN(auto executable,
-                        PjRtCompile(std::move(options), module.get(),
-                                    *ifrt_client->topology().description()));
-    TF_ASSIGN_OR_RETURN(auto ifrt_executable,
-                        ifrt::PjRtExecutable::Create(std::move(executable)));
-    return ifrt::ExecutableRef(std::move(ifrt_executable));
+    ifrt::ExecutableRef ifrt_executable;
+    {
+      nb::gil_scoped_release gil_release;
+      mlir::MLIRContext context;
+      TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
+                          ParseMlirModuleString(mlir_module, context));
+      auto* ifrt_client =
+          llvm::dyn_cast_or_null<CompileOnlyIfRtClient>(this->ifrt_client());
+      CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a "
+                            "CompileOnlyIfRtClient";
+
+      auto xla_options = std::make_unique<ifrt::XlaCompileOptions>(
+          options, std::move(executable_devices));
+      TF_ASSIGN_OR_RETURN(auto executable,
+                          PjRtCompile(std::move(options), module.get(),
+                                      *ifrt_client->topology().description()));
+      TF_ASSIGN_OR_RETURN(ifrt_executable,
+                          ifrt::PjRtExecutable::Create(std::move(executable)));
+    }
+    return make_nb_class<PyExecutable>(ifrt_executable);
   }
 
  private:
diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h
index 6354edaf9a3e..fed6552a9eb5 100644
--- a/jaxlib/py_executable.h
+++ b/jaxlib/py_executable.h
@@ -126,7 +126,55 @@ class PyExecuteResults {
 
 using ExecuteShardedArg = std::variant<PyArray, std::vector<PyArray>>;
 
-// Python wrapper around PjRtExecutable. We use a wrapper class:
+// Thin Python wrapper around ifrt::ExecutableRef. We use a wrapper class:
+// a) Standardize around ifrt::ExecutableRef, which is
+//    std::shared_ptr<xla::ifrt::Executable>.
+// b) Concrete subclasses of ifrt::Executable have protected constructors.
+class PyExecutable {
+ public:
+  PyExecutable(ifrt::ExecutableRef ifrt_executable)
+      : ifrt_executable_(std::move(ifrt_executable)) {};
+  ~PyExecutable() = default;
+
+  // NOTE(dsuo): For now, we only expose the ifrt::Executable members required
+  // by the Python bindings.
+  absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
+      const {
+    return ifrt_executable_->GetHloModules();
+  }
+  absl::StatusOr<std::vector<std::vector<absl::string_view>>>
+  GetOutputMemoryKinds() const {
+    return ifrt_executable_->GetOutputMemoryKinds();
+  }
+  std::optional<std::vector<OpSharding>> GetOutputShardings() const {
+    return ifrt_executable_->GetOutputShardings();
+  }
+  absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
+  GetParameterLayouts() const {
+    return ifrt_executable_->GetParameterLayouts();
+  }
+  absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
+  GetOutputLayouts() const {
+    return ifrt_executable_->GetOutputLayouts();
+  }
+  std::optional<std::vector<OpSharding>> GetParameterShardings() const {
+    return ifrt_executable_->GetParameterShardings();
+  }
+  absl::StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const {
+    return ifrt_executable_->GetCompiledMemoryStats();
+  }
+  absl::StatusOr<std::string> Serialize() const {
+    return ifrt_executable_->Serialize();
+  }
+  absl::StatusOr<ifrt::AttributeMap> GetCostAnalysis() const {
+    return ifrt_executable_->GetCostAnalysis();
+  }
+
+ private:
+  ifrt::ExecutableRef ifrt_executable_;
+};
+
+// Python wrapper around ifrt::LoadedExecutableRef. We use a wrapper class:
 // a) to keep the PyClient alive via a std::shared_ptr<>
 // b) to add Python-specific functionality.
 class PyLoadedExecutable {
@@ -162,8 +210,8 @@ class PyLoadedExecutable {
   }
 
   // Takes args indexed by argid then deviceid, transposes them, and passes to
-  // PjRtExecutable::Execute. The result is similarly transposed back into the
-  // argid,deviceid format.
+  // ifrt::LoadedExecutable::Execute. The result is similarly transposed back
+  // into the argid,deviceid format.
   // args is [num_args x num_devices].
   absl::StatusOr<PyExecuteResults> ExecuteSharded(
      std::vector<ExecuteShardedArg> args, bool with_tokens);
diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc
index d97c6868a04b..186bcaf2efd3 100644
--- a/jaxlib/xla.cc
+++ b/jaxlib/xla.cc
@@ -894,24 +894,24 @@ NB_MODULE(_jax, m) {
             absl::StrCat("Unknown attribute ", name).c_str());
       });
 
-  nb::class_<ifrt::Executable>(m, "Executable")
-      .def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules))
+  nb::class_<PyExecutable>(m, "Executable")
+      .def("hlo_modules", ValueOrThrowWrapper(&PyExecutable::GetHloModules))
       .def("get_output_memory_kinds",
-           xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds))
-      .def("get_output_shardings", &ifrt::Executable::GetOutputShardings)
+           xla::ValueOrThrowWrapper(&PyExecutable::GetOutputMemoryKinds))
+      .def("get_output_shardings", &PyExecutable::GetOutputShardings)
       .def("get_parameter_layouts",
-           ValueOrThrowWrapper(&ifrt::Executable::GetParameterLayouts))
+           ValueOrThrowWrapper(&PyExecutable::GetParameterLayouts))
       .def("get_output_layouts",
-           xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputLayouts))
-      .def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings)
+           xla::ValueOrThrowWrapper(&PyExecutable::GetOutputLayouts))
+      .def("get_parameter_shardings", &PyExecutable::GetParameterShardings)
       .def("get_compiled_memory_stats",
-           xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats))
+           xla::ValueOrThrowWrapper(&PyExecutable::GetCompiledMemoryStats))
       .def("serialize",
-           [](const ifrt::Executable& exec) -> nb::bytes {
+           [](const PyExecutable& exec) -> nb::bytes {
             std::string serialized = ValueOrThrow(exec.Serialize());
             return nb::bytes(serialized.data(), serialized.size());
           })
-      .def("cost_analysis", [](const ifrt::Executable& exec) {
+      .def("cost_analysis", [](const PyExecutable& exec) {
         auto attrs = ValueOrThrow(exec.GetCostAnalysis());
         return ifrt::ToPjRtAttributeMap(std::move(attrs));
       });
diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py
index b9497b71dcb1..85de5f947e49 100644
--- a/jaxlib/xla_client.py
+++ 
b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 346 +_version = 347 # An internal increasing version number for protecting jaxlib code against # ifrt changes. From 6cd196a5db22b8db0ed4000e4cf67ad748bf52f3 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Tue, 3 Jun 2025 12:07:46 -0700 Subject: [PATCH 1503/1769] Prototype of cross-host device transfers in IFRT-PJRT. For now it only works with the TFRT TPU runtime, because other PjRt plugins don't implement the necessary APIs. The per-shard indices of the source and destination shardings must be the same, and all shards must require cross-host transfers (support for a mixture of cross-host and host-local transfers is forthcoming). Transfers take place via the xla::ifrt::PjRtClient::CopyArrays API, which copies the buffers from a set of arrays to a new device list. The distributed KV store from the coordination service is used to store metadata for cross-host transfers. The receiving process populates the store with a descriptor, and the sending process reads it and completes the send. PiperOrigin-RevId: 766765989 --- jax/_src/dispatch.py | 57 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b5e588cbc10e..b9ef8f49f801 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -356,16 +356,6 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(x) - if inp_sharding.device_set != target_sharding.device_set: - inp_ids = [d.id for d in inp_sharding._device_assignment] - inp_plat = inp_sharding._device_assignment[0].platform.upper() - target_ids = [d.id for d in target_sharding._device_assignment] - target_plat = target_sharding._device_assignment[0].platform.upper() - raise ValueError("Input and target sharding should have the same set of " - f"devices. Got input's device set ids: {inp_ids} on " - f"platform {inp_plat} and target sharding's device set " - f"ids: {target_ids} on platform {target_plat}") - if inp_sharding.is_fully_replicated: permute_order = None else: @@ -389,6 +379,25 @@ def _reorder_shards(x, new_s, copy_semantics: CopySemantics): return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore +@util.cache() +def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding): + """Returns True if src->dst is a supported cross-host transfer.""" + backend = xla_bridge.get_backend() + # There is experimental support for cross-host device transfers on TFRT TPU + # backends only. + if (xla_bridge.process_count() == 1 or backend.platform != "tpu" or + "TFRT TPU" not in backend.platform_version): + return False + if (src_sharding._to_xla_hlo_sharding(ndim) != + dst_sharding._to_xla_hlo_sharding(ndim)): + return False + # This check excludes the case where the source and destination shardings + # have the same process index sets but there are shards that require + # cross-host transfers. This case is supportable but expensive to check for. + return (src_sharding._internal_device_list.process_indices != + dst_sharding._internal_device_list.process_indices) + + @dataclasses.dataclass(frozen=True) class _DeferredShardArg: """Deferred call to `pxla.shard_args`. 
@@ -419,7 +428,8 @@ def _device_put_sharding_impl(x, aval, device, copy): return x if (not s.is_fully_addressable and - isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): + isinstance(x, array.ArrayImpl) and not x.is_fully_addressable and + s.device_set == x.sharding.device_set): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) @@ -430,7 +440,32 @@ def _device_put_sharding_impl(x, aval, device, copy): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) + # There is experimental support for cross-host device transfers on TFRT TPU. + if (isinstance(x, array.ArrayImpl) and x._committed + and _is_supported_cross_host_transfer(x.ndim, x.sharding, s)): + return xc.batched_copy_array_to_devices_with_sharding( + [x], [s._internal_device_list], [s], # pytype: disable=attribute-error + pxla.to_xc_copy_semantics([copy]))[0] + if not s.is_fully_addressable: + # If both the source and target shardings are not fully addressable and + # one of the above conditions has not been met, then assume that the user + # is attempting a different device order reshard. + if (isinstance(x, array.ArrayImpl) and not x.is_fully_addressable + and s.device_set != x.sharding.device_set): + inp_ids = [d.id for d in x.sharding._device_assignment] + inp_plat = x.sharding._device_assignment[0].platform.upper() + target_ids = [d.id for d in s._device_assignment] + target_plat = s._device_assignment[0].platform.upper() + raise ValueError( + "For a cross-host reshard in multi-controller JAX, input and target" + " sharding should have the same set of devices. Got input's device" + f" set ids: {inp_ids} on platform {inp_plat} and target sharding's" + f" device set ids: {target_ids} on platform {target_plat}.\n\n" + "There is experimental support for cross-host transfers with " + "different device sets, when input/output shardings have the same " + "indices and layouts, in the TFRT TPU runtime only.") + if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types or type(x) in dtypes.python_scalar_dtypes): # If all hosts participate in the sharding, assert that the input is the From 7e0913f8149e4542e020256640c6ed3fd185f759 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 3 Jun 2025 12:31:29 -0700 Subject: [PATCH 1504/1769] [Mosaic GPU] Fix `bitcast` logic in `shfl_bfly`. We were not testing the logic for non-32-bit-wide dtypes, and as a result missed that one of the `bitcast`s was converting between two types with different bitwidths. 
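For a bf16 input, for example, the 32-bit shuffle result must be
reinterpreted as a vector of integers of the matching element width before
lane 0 is extracted. A minimal sketch of the fixed path, using the same
helpers as utils.py:

    bits_ty = ir.IntegerType.get_signless(16)  # same bitwidth as bf16
    y_vec = bitcast(y, ir.VectorType.get((2,), bits_ty))  # 2 x i16 == 32 bits
    y = vector.extractelement(y_vec, position=c(0, index))
    y = bitcast(y, result_type)  # reinterpret the low 16 bits as bf16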
PiperOrigin-RevId: 766775563 --- jax/experimental/mosaic/gpu/utils.py | 2 +- tests/mosaic/gpu_test.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index a76e077ff463..ac002aa8bffe 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1398,7 +1398,7 @@ def shfl_bfly(x: ir.Value, distance: int | ir.Value): ) if (x_bitwidth := bitwidth(result_type)) < 32: bits_ty = ir.IntegerType.get_signless(x_bitwidth) - y_vec = bitcast(y, ir.VectorType.get((32 // x_bitwidth,), x.type)) + y_vec = bitcast(y, ir.VectorType.get((32 // x_bitwidth,), bits_ty)) y = vector.extractelement(y_vec, position=c(0, index)) return bitcast(y, result_type) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 3bb0c1b9fe78..31262dcae44d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3732,20 +3732,22 @@ def strategy(draw): shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) rank = len(shape) reduced_dims = draw(hps.sets(hps.integers(0, rank - 1), min_size=1)) - return shape, layout, tuple(reduced_dims) + dtype = draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + return shape, layout, tuple(reduced_dims), dtype @hp.given(strategy()) def run(args): - shape, layout, reduced_dims = args + shape, layout, reduced_dims, dtype = args out_shape = list(shape) for d in sorted(reduced_dims, reverse=True): del out_shape[d] def kernel(ctx, src, dst, scratch): + del ctx arr = fa.FragmentedArray.load_untiled(src, layout=layout, optimized=False) arr.reduce("max", reduced_dims, scratch).store_untiled(dst, optimized=False) - x = jax.random.normal(jax.random.key(1234), shape, jnp.float32) - out_type = jax.ShapeDtypeStruct(out_shape, jnp.float32) - scratch_type = jax.ShapeDtypeStruct((2048,), jnp.float32) + x = jax.random.normal(jax.random.key(1234), shape, dtype) + out_type = jax.ShapeDtypeStruct(out_shape, dtype) + scratch_type = jax.ShapeDtypeStruct((2048,), dtype) hp.assume(layout.vector_length <= 16) # Otherwise we run out of scratch try: result = mgpu.as_gpu_kernel( From b7adddff1793e5434c1e7c2189d34ec728d49736 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 3 Jun 2025 13:27:01 -0700 Subject: [PATCH 1505/1769] Reduce block sizes in attention to prevent running out of shared memory on L4. New values (though small) show very good performance on ampere. 
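Rough shared-memory arithmetic behind the new defaults (illustrative only;
assumes head_dim=64 and float32 tiles):

    128 * 64 * 4 bytes = 32 KiB per tile at the old backward block size
     32 * 64 * 4 bytes =  8 KiB per tile at the new backward block size

The backward passes keep several such q/k/v/do tiles resident at once, which
can exceed the roughly 100 KiB of shared memory per SM on an L4; 32-sized
blocks leave comfortable headroom.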
PiperOrigin-RevId: 766797437 --- jax/experimental/pallas/ops/gpu/attention.py | 8 ++++---- tests/pallas/gpu_ops_test.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index ae429be5d73a..4782fc31226e 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -57,10 +57,10 @@ def get_default(cls): return BlockSizes( block_q=128, block_k=128, - block_q_dkv=128, - block_kv_dkv=128, - block_q_dq=128, - block_kv_dq=128, + block_q_dkv=32, + block_kv_dkv=32, + block_q_dq=32, + block_kv_dq=32, ) @property diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index cc2d15a8fdee..edda5cf686db 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -231,9 +231,9 @@ def impl(q, k, v): ( ("block_q", 128), ("block_k", 128), - ("block_q_dkv", 128), - ("block_kv_dkv", 128), - ("block_q_dq", 128), + ("block_q_dkv", 32), + ("block_kv_dkv", 32), + ("block_q_dq", 32), ("block_kv_dq", 128), ), ( @@ -248,8 +248,8 @@ def impl(q, k, v): ("block_q", 64), ("block_k", 128), ("block_q_dkv", 64), - ("block_kv_dkv", 128), - ("block_q_dq", 128), + ("block_kv_dkv", 32), + ("block_q_dq", 32), ("block_kv_dq", 64), ), ), From e20b3a4e34f211a3a7d2905b730cc282dbe2c170 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Jun 2025 13:38:53 -0700 Subject: [PATCH 1506/1769] Fix sharding-in-types + lax.map usage when batch_size usage has a remainder left. Fixes https://github.com/jax-ml/jax/issues/29195 PiperOrigin-RevId: 766801968 --- jax/_src/core.py | 8 +++--- jax/_src/lax/control_flow/loops.py | 39 ++++++++++++++++++++++-------- tests/pjit_test.py | 12 +++++++++ 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 2452d12dad5e..d8481afe872a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1533,7 +1533,7 @@ def normalize(self) -> AbstractValue: def update(self, **kwargs): raise NotImplementedError("must override") - def str_short(self, short_dtypes=False): + def str_short(self, short_dtypes=False, mesh_axis_types=False): return str(self) # For type signatures involving dynamic shapes, we use lists of abstract values @@ -1790,7 +1790,7 @@ def __str__(self): _oct = concretization_function_error(oct) _index = concretization_function_error(operator.index) - def str_short(self, short_dtypes=False) -> str: + def str_short(self, short_dtypes=False, mesh_axis_types=False) -> str: return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name def update_weak_type(self, weak_type): @@ -2191,7 +2191,7 @@ def __init__(self, shape, dtype, weak_type=False): 0 if any(type(d) is int and d == 0 for d in self.shape) else math.prod(self.shape)) - def str_short(self, short_dtypes=False) -> str: + def str_short(self, short_dtypes=False, mesh_axis_types=False) -> str: del short_dtypes # ignored shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else '' dtype = dtypes.short_dtype_name(self.dtype) @@ -2358,7 +2358,7 @@ def _freeze_impl(ref): return ref[()] class AbstractToken(AbstractValue): - def str_short(self, short_dtypes=False): return 'Tok' + def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok' def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 146f27e5d2e7..985c5ba52294 100644 --- 
a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -53,6 +53,8 @@ _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, _typecheck_param) from jax._src.lax.other import logaddexp +from jax._src.pjit import auto_axes, PartitionSpec as P +from jax._src.mesh import get_abstract_mesh from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.state import discharge as state_discharge @@ -2507,24 +2509,41 @@ def fori_loop(lower, upper, body_fun, init_val): ### map and miscellaneous rules +def _scan_leaf(leaf, batch_elems, num_batches, batch_size): + def f(l): + return l[:batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]) + + aval = core.typeof(leaf) + if aval.sharding.spec[0] is not None: + raise ValueError( + '0th dimension of leaf passed to `jax.lax.map` should be replicated.' + f' Got {aval.str_short(True, True)}') + if get_abstract_mesh()._are_all_axes_explicit: + out_s = aval.sharding.with_spec(P(None, None, *aval.sharding.spec[1:])) + return auto_axes(f, out_sharding=out_s)(leaf) + return f(leaf) + +def _remainder_leaf(leaf, batch_elems): + def f(l): + return l[batch_elems:] + if get_abstract_mesh()._are_all_axes_explicit: + return auto_axes(f, out_sharding=core.typeof(leaf).sharding)(leaf) + return f(leaf) + def _batch_and_remainder(x, batch_size: int): leaves, treedef = tree_flatten(x) if not leaves: return x, None num_batches, remainder = divmod(leaves[0].shape[0], batch_size) - total_batch_elems = num_batches * batch_size + batch_elems = num_batches * batch_size if remainder: - scan_leaves, remainder_leaves = [], [] - for leaf in leaves: - scan_leaves.append(leaf[:total_batch_elems].reshape( - num_batches, batch_size, *leaf.shape[1:])) - remainder_leaves.append(leaf[total_batch_elems:]) + scan_leaves, remainder_leaves = unzip2( + [(_scan_leaf(leaf, batch_elems, num_batches, batch_size), + _remainder_leaf(leaf, batch_elems)) for leaf in leaves]) return treedef.unflatten(scan_leaves), treedef.unflatten(remainder_leaves) else: - scan_leaves = [ - leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]) - for leaf in leaves - ] + scan_leaves = tuple(_scan_leaf(leaf, batch_elems, num_batches, batch_size) + for leaf in leaves) return treedef.unflatten(scan_leaves), None @api_boundary diff --git a/tests/pjit_test.py b/tests/pjit_test.py index c4d36ab78d10..b8f81438c104 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7907,6 +7907,18 @@ def simple_func(w, x): jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2) # doesn't crash + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2,), ('x',)) + def test_lax_map_remainder(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jax.device_put(np.arange(4, dtype=np.float32), P()) + x = jax.device_put(np.ones((5, 2, 4), dtype=np.float32), + P(None, 'x', None)) + + jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 1216dacabdb938251cfa03f786aac56116727a36 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Jun 2025 14:01:33 -0700 Subject: [PATCH 1507/1769] Resurrect _pjit_lower's cache because it's important for python dispatch performance. 
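The restored pattern, as a minimal sketch (the real cache key also covers
shardings, layouts, donated_invars, and lowering parameters):

    from jax._src.util import weakref_lru_cache

    @weakref_lru_cache
    def _pjit_lower_cached(jaxpr, *hashable_args):
      # Keyed on a weak reference to the jaxpr plus the remaining hashable
      # arguments: repeated dispatches of the same jit'd computation reuse
      # the lowering, and dead jaxprs are not kept alive by the cache.
      ...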
PiperOrigin-RevId: 766811714 --- jax/_src/api.py | 1 + jax/_src/pjit.py | 9 +++++++-- tests/pjit_test.py | 9 +++++---- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 229dee979d06..1f7cc19206a9 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3164,6 +3164,7 @@ def clear_backends(): dispatch.xla_primitive_callable.cache_clear() util.clear_all_caches() pjit._infer_params_cached.cache_clear() + pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache_fun_only.clear() pjit._cpp_pjit_cache_explicit_attributes.clear() diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f2446e9a4939..572c2225af74 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1923,7 +1923,13 @@ def call_impl_cache_miss(*args_, **kwargs_): pjit_p.def_impl(_pjit_call_impl) -def _pjit_lower( +def _pjit_lower(*args, **kwargs): + util.test_event("pjit_lower") + return _pjit_lower_cached(*args, **kwargs) + +# This cache is important for python dispatch performance. +@weakref_lru_cache +def _pjit_lower_cached( jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, @@ -1939,7 +1945,6 @@ def _pjit_lower( lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): - util.test_event("pjit_lower") return pxla.lower_sharding_computation( jaxpr, 'jit', name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b8f81438c104..2fe2c4696cbe 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -54,7 +54,8 @@ AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, SingleDeviceSharding, parse_flatten_op_sharding) from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes, - use_auto_axes, use_explicit_axes, reshard) + use_auto_axes, use_explicit_axes, reshard, + _pjit_lower_cached) from jax._src.layout import Format, DeviceLocalLayout as DLL from jax._src.named_sharding import DuplicateSpecError from jax._src import mesh as mesh_lib @@ -2306,13 +2307,13 @@ def add(x, y): return x + y out = add(a, b) - cache_info1 = pxla._cached_lowering_to_hlo.cache_info() + cache_info1 = _pjit_lower_cached.cache_info() self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, a + b) self.assertFalse(out._committed) out2 = add(out, out) - cache_info2 = pxla._cached_lowering_to_hlo.cache_info() + cache_info2 = _pjit_lower_cached.cache_info() self.assertIsInstance(out2, array.ArrayImpl) self.assertArraysEqual(out2, 2 * (a + b)) self.assertFalse(out2._committed) @@ -2322,7 +2323,7 @@ def add(x, y): c = jax.device_put(a, jax.devices()[0]) out3 = add(c, c) - cache_info3 = pxla._cached_lowering_to_hlo.cache_info() + cache_info3 = _pjit_lower_cached.cache_info() self.assertArraysEqual(out3, 2 * c) self.assertTrue(out3._committed) From b6a1575a45e2ba6f6840a527e0d84652cdc13989 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 3 Jun 2025 15:47:04 -0700 Subject: [PATCH 1508/1769] [pallas] In TPU interpret mode, run kernels in parallel over Megacore cores. Internally, TPU interpret mode uses a new io_callback which spawns multiple threads to simulate multiple Megacore cores. Also updates some comments / code / variable names to better distinguish between internal indices used in interpret mode vs. indices into the Pallas grid. 
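The core of the simulation is a thread-per-core pattern, sketched minimally
below (simplified from the new _thread_map_callback; the real version runs a
compiled jaxpr on each thread and is invoked through an ordered io_callback):

    import threading

    def thread_map(body, num_cores):
      # One OS thread simulates each Megacore core; every core runs the same
      # kernel body and is distinguished only by its core index.
      threads = [threading.Thread(target=body, args=(i,))
                 for i in range(num_cores)]
      for t in threads:
        t.start()
      for t in threads:
        t.join()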
PiperOrigin-RevId: 766851983 --- jax/_src/pallas/mosaic/interpret.py | 588 ++++++++++++---------- tests/pallas/BUILD | 1 + tests/pallas/tpu_pallas_interpret_test.py | 88 +++- 3 files changed, 394 insertions(+), 283 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 3be718aa0aa6..e278168d999a 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -35,7 +35,6 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src import pjit -from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives @@ -532,7 +531,6 @@ class SharedMemory: clean_up_barrier: threading.Barrier # (memory_space, buffer_id, device_id, local_core_id) -> NumPy array - # TODO(jburnim): Handle Megacore. mem: dict[tuple[str, int, int, int], np.ndarray] = dataclasses.field( default_factory=dict) @@ -691,16 +689,17 @@ def _allocate_buffer( local_core_ids = (local_core_id_int,) del local_core_id - local_core_id_to_buffer_id = {} + local_core_id_to_buffer_id : dict[int, int] = {} with shared_memory.lock: for lci in local_core_ids: buffer_id = shared_memory.next_buffer_id[(device_id, lci)] shared_memory.next_buffer_id[(device_id, lci)] = buffer_id + 1 + # If allocating in HBM, only actually allocate a buffer for core 0. if lci == 0 or memory_space_str != 'any': - # If allocating in HBM, only actually allocate a buffer for local core - # id 0. - # TODO(jburnim): Add options for initializing memory (e.g., with NaNs, - # with zeros, or with the buffer ID). + # If we are allocating more than one buffer, we must make additional + # copies of `val` so that each buffer is a distinct ndarray. + if len(local_core_id_to_buffer_id) > 0: + val = val.copy() shared_memory.mem[(memory_space_str, buffer_id, device_id, lci)] = val local_core_id_to_buffer_id[lci] = buffer_id @@ -1040,7 +1039,6 @@ def swap( buffer_id = int(buffer_id) try: transforms = jax.tree.map(int, transforms) - # jax.debug.print(f'swap: {transforms}') except: raise ValueError('Advanced indexers are not supported on TPU') val = np.array(val) @@ -1387,7 +1385,6 @@ def write(var, value): # TODO(jburnim): Clean up and finish this evaluation loop. For example: # - Replace the big if-statement with a dictionary of rules. # - Handle other higher-order primitives? - # - Megacore. _interpret = functools.partial( _interpret_jaxpr, mesh=mesh, @@ -1458,9 +1455,9 @@ def write(var, value): elif ((prim is lax.axis_index_p) and (mesh is not None) and (eqn.params['axis_name'] in mesh.shape)): - # For now, there can only be one core. - # TODO(jburnim): Support two Megacore cores. - out = jnp.int32(0) + # We are interpreting a core_map, and this lax.axis_index call is + # querying our index along the core axis, so return our core ID. 
+ out = local_core_id elif prim is lax.cond_p: def _make_branch(jaxpr): @@ -1708,7 +1705,8 @@ def f(*args, jaxpr): return jax._src.util.safe_map(read, jaxpr.outvars) def _compute_start_indices( - block_mapping, loop_idx, *args, mesh, compiler_params, interpret_params): + block_mapping, loop_idx, local_core_id, + *args, mesh, compiler_params, interpret_params): jaxpr = block_mapping.index_map_jaxpr block_indices = _interpret_jaxpr( jaxpr.jaxpr, @@ -1716,7 +1714,7 @@ def _compute_start_indices( *loop_idx, *args, mesh=mesh, - local_core_id=0, + local_core_id=local_core_id, compiler_params=compiler_params, interpret_params=interpret_params, ) @@ -1748,12 +1746,19 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) +def _get_indices(grid, loop_index): + indices = [] + for dim_size in reversed(grid): + i = loop_index % dim_size + loop_index = loop_index // dim_size + indices.append(i) + return tuple(reversed(indices)) -def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) -> tpu_core.CompilerParams: +def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) -> mosaic_core.CompilerParams: try: - return cast(tpu_core.CompilerParams, compiler_params['mosaic_tpu']) + return cast(mosaic_core.CompilerParams, compiler_params['mosaic_tpu']) except KeyError: - return tpu_core.CompilerParams() + return mosaic_core.CompilerParams() def _get_parallel_dim_semantics( @@ -1777,7 +1782,8 @@ def _get_parallel_dim_semantics( mosaic_params = _get_mosaic_params(compiler_params) if mosaic_params.dimension_semantics is None: return (False,) * num_dimensions_in_grid - result = tuple(ds == 'parallel' for ds in mosaic_params.dimension_semantics) + result = tuple(ds in ('parallel', mosaic_core.PARALLEL) + for ds in mosaic_params.dimension_semantics) for ds0, ds1 in zip(result[:-1], result[1:]): if ds1 and not ds0: raise ValueError( @@ -1916,6 +1922,52 @@ def _pad_to_block_dimension(value, block_shape, interpret_params): def get_interpret_effects(): return {callback._OrderedIOEffect} +def _thread_map(f, num_threads): + if num_threads == 1: + f(jnp.int32(0)) + return + + def _f(core_index): + f(core_index) + return () + jaxpr = jax.make_jaxpr(_f)(jnp.int32(0)) + + _call_threadmap_callback(jaxpr.jaxpr, num_threads, *jaxpr.consts) + +def _run_jaxpr(jaxpr, consts, *args): + def _run(jaxpr, consts, *args): + jax_core.eval_jaxpr(jaxpr, consts, *args) + traced = jax.jit(_run, static_argnums=(0,)).trace(jaxpr, consts, *args) + traced.lower().compile()(consts, *args) + return + +def _thread_map_callback(jaxpr, num_threads, consts): + num_threads = int(num_threads) + threads = [] + for i in range(num_threads): + threads.append( + threading.Thread(target=_run_jaxpr, args=(jaxpr, consts, jnp.int32(i)))) + for i in range(num_threads): + threads[i].start() + for i in range(num_threads): + threads[i].join() + +def _call_threadmap_callback(jaxpr, num_threads, *consts): + # NOTE: At runtime, _thread_map_callback will lower and compile the + # given jaxpr. (JAX's caches should ensure the jaxpr is only lowered and + # compiled once.) + # + # TODO(jburnim): Would it be worth trying to lower/compile the jaxpr at + # lowering/compilation time? E.g., by using a custom primitive here, could + # we lower/compile jaxpr at lowering time, and then pass the compiled + # function to the callback? 
+ return callback.io_callback( + functools.partial(_thread_map_callback, jaxpr), + (), + num_threads, + consts, + ordered=True) + def interpret_pallas_call( *args, jaxpr: jax_core.Jaxpr, @@ -1930,6 +1982,14 @@ def interpret_pallas_call( ): del debug, cost_estimate, out_avals + if isinstance(mesh, mosaic_core.TensorCoreMesh): + # As a convenience for users, if we are interpreting a pl.core_map over a + # TensorCoreMesh, we automatically set the number of cores per device so + # that users don't have to specify it in the InterpretParams. + assert len(mesh.shape) == 1 + interpret_params = dataclasses.replace( + interpret_params, num_cores_per_device=mesh.devices.shape[0]) + # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) dynamic_grid_args, scalars, input_args = split_list( args, @@ -2101,9 +2161,14 @@ def interpret_pallas_call( # Base case is always one iteration when grid is () num_iterations = 1 - randomized_grid_coordinates = _get_randomized_grid_coordinates( - grid, compiler_params, interpret_params.random_seed # type: ignore[arg-type] - ) + if isinstance(mesh, mosaic_core.TensorCoreMesh): + # We are interpreting a pl.core_map over a TensorCoreMesh, so we use a + # fixed division of the grid between cores, instead of a random division. + randomized_grid_coordinates = (jnp.array((), dtype=jnp.int32),) * len(grid) + else: + randomized_grid_coordinates = _get_randomized_grid_coordinates( + grid, compiler_params, interpret_params.random_seed # type: ignore[arg-type] + ) parallel_dim_semantics = _get_parallel_dim_semantics( compiler_params, len(grid) @@ -2122,272 +2187,271 @@ def interpret_pallas_call( num_points_in_parallel_subgrid_per_core * num_iterations_per_point_in_parallel_subgrid ) - - def _get_local_grid_env(loop_idx): + def _get_local_grid_env(grid_point): if grid_mapping.local_grid_env is not None: - return grid_mapping.local_grid_env(loop_idx, grid) + return grid_mapping.local_grid_env(grid_point, grid) else: return tuple( pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + for dim, (idx, b) in enumerate(zip(grid_point, grid)) if dim not in grid_mapping.vmapped_dims ) - def body( - carry: tuple[ - jnp.int32, - tuple[jnp.int32, ...], - jnp.ndarray, - jnp.int32, - jnp.int32, - list[jnp.ndarray], - list[jnp.ndarray], - ], - ) -> tuple[ - jnp.int32, - tuple[jnp.int32, ...], - jnp.ndarray, - jnp.int32, - jnp.int32, - list[jnp.ndarray], - list[jnp.ndarray], - ]: - """Performs a single iteration of `jaxpr` in the device grid. - - Execution of `jaxpr` is preceded by reading kernel input buffers and - followed by writing kernel output buffers. - - Args: - carry: (iteration_idx, loop_idx, grid_point, prev_local_core_id, - cur_local_core_id, prev_start_indices, cur_start_indices). - - iteration_idx is the interation index. - - loop_idx are the program ids for each grid axis. - - grid_point is the grid point for the current loop iteration. - - prev_local_core_id is the (device-local) core id from the previous - loop iteration. - - cur_local_core_id is the (device-local) core id for the current loop - iteration. - - prev_start_indices is a rank-1 array that contains the start indices - for the slices of inputs and outputs processed in the previous loop - iteration. - - cur_start_indices is a rank-1 array that contains the start indices - for the slices of inputs and outputs processed in the current loop - iteration. 
- - Note that by carrying the previous *and* current start indices between - loop iterations, it suffices to compute only one list of start indices, - i.e. `next_start_indices` (see below), per iteration. - - Returns: - The carry for the next iteration. - """ - ( - iteration_idx, - loop_idx, - grid_point, - prev_local_core_id, - cur_local_core_id, - prev_start_indices, - cur_start_indices, - ) = carry - if interpret_params.grid_point_recorder is not None: - callback.io_callback( - interpret_params.grid_point_recorder, - (), + def _execute_grid_for_core(core_index): + # NOTE: We assume here that all parallel dimensions appear before all + # arbitrary dimensions in the grid. (We will have raised an error earlier + # if this is not the case.) + # + # TODO(jburnim): Are we overusing nested local functions here? + initial_iteration_idx = core_index * num_iterations_per_core + loop_bound = jnp.minimum( + (core_index + 1) * num_iterations_per_core, num_iterations) + + def _body( + carry: tuple[ + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + list[jnp.ndarray], + list[jnp.ndarray], + ], + ) -> tuple[ + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + list[jnp.ndarray], + list[jnp.ndarray], + ]: + """Performs one execution of the kernel body. + + Execution of `jaxpr` is preceded by reading kernel input buffers and + followed by writing kernel output buffers. + + Args: + carry: (iteration_idx, loop_idx, grid_point, prev_start_indices, + cur_start_indices). + - iteration_idx: the interation index. + - loop_idx: internal indices for looping over the grid. + - grid_point: the current positions along all axes of the grid. + - prev_start_indices: a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the previous loop + iteration. + - cur_start_indices: a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the current loop + iteration. + + Note that by carrying the previous *and* current start indices between + loop iterations, it suffices to compute only one list of start indices, + i.e. `next_start_indices` (see below), per iteration. + + Returns: + The carry for the next iteration. + """ + ( + iteration_idx, + loop_idx, grid_point, - cur_local_core_id, - ) - - next_local_core_id = (iteration_idx + 1) // num_iterations_per_core - - with pallas_core.grid_env(_get_local_grid_env(loop_idx)): - next_loop_idx = _get_next_indices(grid, loop_idx) - next_grid_point = _get_grid_point( - next_loop_idx, randomized_grid_coordinates - ) - next_start_indices = [ - _compute_start_indices( - bm, - next_grid_point, - *scalar_buffer_ids, - mesh=mesh, - compiler_params=compiler_params, - interpret_params=interpret_params, - ) - for bm in grid_mapping.block_mappings - ] - - # Copy slices of the input to the kernel buffers. - def _store_slice_to_kernel_input(index, input_var): - # Copy from the HBM buffer for the pallas_call input to the kernel - # input buffer. - # TODO(jburnim): Just use input_args[j] when the input is not aliased? - transform = indexing.NDIndexer( - indices=tuple( - indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip( - cur_start_indices[index], - block_shapes[index], - is_squeeze_dim[index], - ) - ), - shape=input_args[index].shape, - int_indexer_shape=(), - ) - sliced_val = callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # read is involved in a data race. 
- get, - jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), - device_id, - cur_local_core_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], - input_buffer_ids[index], - (transform,), - ordered=True, - ) + prev_start_indices, + cur_start_indices, + ) = carry + if interpret_params.grid_point_recorder is not None: callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # store is involved in a data race. - store, - (), - device_id, - cur_local_core_id, - TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], - input_ids[index], + interpret_params.grid_point_recorder, (), - sliced_val, - ordered=True, + grid_point, + core_index, ) - for j, var in enumerate(input_vars): - if _is_any(var.aval.memory_space): - continue - assert len(cur_start_indices[j].shape) == 1 - assert len(prev_start_indices[j].shape) == 1 - jax.lax.cond( - (iteration_idx == 0) - | (cur_local_core_id != prev_local_core_id) - | jax.lax.reduce_or( - cur_start_indices[j] != prev_start_indices[j], axes=(0,) - ), - functools.partial(_store_slice_to_kernel_input, j, var), - lambda: None, + with pallas_core.grid_env(_get_local_grid_env(grid_point)): + next_loop_idx = _get_next_indices(grid, loop_idx) + next_grid_point = _get_grid_point( + next_loop_idx, randomized_grid_coordinates ) + next_start_indices = [ + _compute_start_indices( + bm, + next_grid_point, + core_index, + *scalar_buffer_ids, + mesh=mesh, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] + + # Copy slices of the input to the kernel buffers. + def _store_slice_to_kernel_input(index, input_var): + # Copy from the HBM buffer for the pallas_call input to the kernel + # input buffer. + # TODO(jburnim): Just use input_args[j] when the input is not aliased? + transform = indexing.NDIndexer( + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[index], + block_shapes[index], + is_squeeze_dim[index], + ) + ), + shape=input_args[index].shape, + int_indexer_shape=(), + ) + sliced_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # read is involved in a data race. + get, + jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + input_buffer_ids[index], + (transform,), + ordered=True, + ) + callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. + store, + (), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], + input_ids[index], + (), + sliced_val, + ordered=True, + ) - # Invoke the kernel. - _interpret_jaxpr( - jaxpr, - *kernel_buffer_ids, - mesh=mesh, - local_core_id=cur_local_core_id, - compiler_params=compiler_params, - interpret_params=interpret_params, - ) + for j, var in enumerate(input_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[j].shape) == 1 + assert len(prev_start_indices[j].shape) == 1 + jax.lax.cond( + (iteration_idx == initial_iteration_idx) + | jax.lax.reduce_or( + cur_start_indices[j] != prev_start_indices[j], axes=(0,) + ), + functools.partial(_store_slice_to_kernel_input, j, var), + lambda: None, + ) - # Copy from the kernel buffers to slices of the output in HBM. 
- def _store_to_output_buffer(index, output_var): - kernel_output_val = callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # get is involved in a data race. - get, - output_var.aval, - device_id, - cur_local_core_id, - TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], - kernel_output_ids[j], - (), - ordered=True, - ) - transform = indexing.NDIndexer( - indices=tuple( - indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip( - cur_start_indices[num_inputs + index], - block_shapes[num_inputs + index], - is_squeeze_dim[num_inputs + index], - ) - ), - shape=output_vals[index].shape, - int_indexer_shape=(index), - ) - callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # store is involved in a data race. - store, - (), - device_id, - cur_local_core_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], - output_buffer_ids[index], - (transform,), - kernel_output_val, - ordered=True, + # Invoke the kernel. + _interpret_jaxpr( + jaxpr, + *kernel_buffer_ids, + mesh=mesh, + local_core_id=core_index, + compiler_params=compiler_params, + interpret_params=interpret_params, ) - for j, var in enumerate(output_vars): - if _is_any(var.aval.memory_space): - continue - assert len(cur_start_indices[num_inputs + j].shape) == 1 - assert len(next_start_indices[num_inputs + j].shape) == 1 - jax.lax.cond( - (iteration_idx + 1 == num_iterations) - | (cur_local_core_id != next_local_core_id) - | jax.lax.reduce_or( - cur_start_indices[num_inputs + j] - != next_start_indices[num_inputs + j], - axes=(0,), - ), - functools.partial(_store_to_output_buffer, j, var), - lambda: None, + # Copy from the kernel buffers to slices of the output in HBM. + def _store_to_output_buffer(index, output_var): + kernel_output_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. + get, + output_var.aval, + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], + kernel_output_ids[j], + (), + ordered=True, + ) + transform = indexing.NDIndexer( + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[num_inputs + index], + block_shapes[num_inputs + index], + is_squeeze_dim[num_inputs + index], + ) + ), + shape=output_vals[index].shape, + int_indexer_shape=(index), + ) + callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. 
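+              # This ordered callback writes kernel_output_val into the
+              # pallas_call's HBM output buffer at this iteration's start
+              # indices.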
+ store, + (), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + output_buffer_ids[index], + (transform,), + kernel_output_val, + ordered=True, + ) + + for j, var in enumerate(output_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[num_inputs + j].shape) == 1 + assert len(next_start_indices[num_inputs + j].shape) == 1 + jax.lax.cond( + (iteration_idx + 1 == loop_bound) + | jax.lax.reduce_or( + cur_start_indices[num_inputs + j] + != next_start_indices[num_inputs + j], + axes=(0,), + ), + functools.partial(_store_to_output_buffer, j, var), + lambda: None, + ) + + return ( + iteration_idx + 1, + next_loop_idx, + next_grid_point, + cur_start_indices, + next_start_indices, ) - return ( - iteration_idx + 1, - next_loop_idx, - next_grid_point, - cur_local_core_id, - next_local_core_id, - cur_start_indices, - next_start_indices, - ) + initial_loop_idx = _get_indices(grid, initial_iteration_idx) + initial_grid_point = _get_grid_point( + initial_loop_idx, randomized_grid_coordinates) + with pallas_core.grid_env(_get_local_grid_env(initial_grid_point)): + initial_start_indices = [ + _compute_start_indices( + bm, + initial_grid_point, + core_index, + *scalar_buffer_ids, + mesh=mesh, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] - initial_loop_idx = (jnp.int32(0),) * len(grid) - initial_grid_point = _get_grid_point( - initial_loop_idx, randomized_grid_coordinates - ) - with pallas_core.grid_env(_get_local_grid_env(initial_loop_idx)): - initial_start_indices = [ - _compute_start_indices( - bm, + _ = lax.while_loop( + lambda carry: carry[0] < loop_bound, + _body, + ( + initial_iteration_idx, + initial_loop_idx, initial_grid_point, - *scalar_buffer_ids, - mesh=mesh, - compiler_params=compiler_params, - interpret_params=interpret_params, - ) - for bm in grid_mapping.block_mappings - ] - # TODO(jburnim): Handle parallel grid dimensions + megacore. + initial_start_indices, # Previous start indices are ignored on the first iteration. + initial_start_indices, + ), + ) + + # TODO(jburnim): Should we only create happens-before here from core 0 to + # the other cores? callback.io_callback( _update_clocks_for_device_barrier, (), device_id, ordered=True ) - _ = lax.while_loop( - lambda carry: carry[0] < num_iterations, - body, - ( - jnp.int32(0), - initial_loop_idx, - initial_grid_point, - jnp.int32(0), # Previous core id is ignored on the first iteration. - jnp.int32(0), # Current core id is set to 0 for the first iteration. - initial_start_indices, # Previous start indices are ignored on the first iteration. - initial_start_indices, - ), - ) + + _thread_map(_execute_grid_for_core, interpret_params.num_cores_per_device) + + # TODO(jburnim): Should we only create happens-before here from the other + # # cores to core 0? callback.io_callback( - _update_clocks_for_device_barrier, (), device_id, ordered=True - ) + _update_clocks_for_device_barrier, (), device_id, ordered=True) # Read the output from the allocated output buffers. 
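+  # All per-core threads have been joined inside _thread_map by this point,
+  # so the reads below observe every core's writes.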
ret = [ diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 5580cb2abd73..eb0eb05ac6a7 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -558,6 +558,7 @@ jax_multiplatform_test( disable_configs = ["cpu_shardy"], enable_backends = ["cpu"], deps = [ + "//jax:experimental", "//jax:pallas", "//jax:pallas_tpu", ] + py_deps([ diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 5bfca2270aa1..5fafb3007993 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -21,6 +21,7 @@ from collections.abc import Callable import dataclasses import functools +import threading from absl.testing import absltest from absl.testing import parameterized @@ -91,7 +92,7 @@ def _recorder(grid_point, core_id): @property def grid_points(self) -> list[ProcessedGridPoint]: - return self._grid_points + return sorted(self._grid_points, key=lambda x: x.core_id) # TODO(jburnim): Figure out how to safely run different instance of TPU @@ -471,40 +472,52 @@ def kernel_call_dynamic_parallel_dimension(): with self.assertRaises(jax.errors.ConcretizationTypeError): kernel_call_dynamic_parallel_dimension() - def test_core_map_over_one_core(self): - mesh = pltpu.create_tensorcore_mesh("x", num_cores=1) + @parameterized.parameters(1, 2, 4) + def test_core_map(self, num_cores): + mesh = pltpu.create_tensorcore_mesh('x', num_cores=num_cores) + interpret = pltpu.InterpretParams() @jax.jit def f(x): y = jnp.zeros_like(x) def inner(refs): x_ref, y_ref = refs - @pl.core_map(mesh, interpret=pltpu.InterpretParams()) + @pl.core_map(mesh, interpret=interpret) def _(): - num_cores = jax.lax.psum(1, "x") + num_cores = jax.lax.axis_size('x') slc_size = 16 // num_cores - def alloc(x_vmem_ref, y_vmem_ref, sem): - core_index = jax.lax.axis_index("x") + def alloc(x_vmem_ref, y_vmem_ref, dma_sem, sem): + # Barrier so we deadlock unless the core_map is actually parallel. + for i in range(num_cores): + pl.semaphore_signal(sem, 1, core_index=i) + pl.semaphore_wait(sem, num_cores) + + core_index = jax.lax.axis_index('x') slc = pl.ds(core_index * slc_size, slc_size) pltpu.async_copy( x_ref.at[slc], x_vmem_ref, - sem, + dma_sem, ).wait() - y = x_vmem_ref[...] + 1 + jax.lax.axis_index("x") + y = x_vmem_ref[...] + jax.lax.axis_index('x') + 1 y_vmem_ref[...] = y - pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], dma_sem).wait() pl.run_scoped( alloc, pltpu.VMEM((slc_size, 128), x_ref.dtype), pltpu.VMEM((slc_size, 128), y_ref.dtype), pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, ) _, y = pl.run_state(inner)((x, y)) return y x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) + expected_out = ( + x.reshape((num_cores, -1, 128)) + 1 + + jnp.arange(num_cores, dtype=jnp.int32)[..., None, None] + ).reshape(x.shape) y = f(x) - np.testing.assert_array_equal(y, x + 1) + np.testing.assert_array_equal(y, expected_out) def test_two_cores_along_parallel_dimension_with_race(self): def kernel(x_ref, o_ref, vmem_ref): @@ -566,32 +579,43 @@ def kernel(x_ref, o_ref, vmem_ref): np.testing.assert_allclose(y, 2.0 * x) def test_parallel_dimension_and_multiple_cores(self): - def kernel(s_ref, o_ref): + def kernel(s_ref, in_ref, o_ref): + # NOTE: diff should be 0. + diff = in_ref[...] 
- jnp.float32(4 * pl.program_id(0) + pl.program_id(1)) + s = s_ref[0] s_ref[0] = s + 1 - o_ref[:] = jax.lax.full_like(o_ref, s) + o_ref[:] = jax.lax.full_like(o_ref, s) + diff def kernel_call(s, num_cores_per_device, grid_point_recorder): + block_input = jnp.repeat( + jnp.repeat( + jnp.arange(16, dtype=jnp.float32).reshape((4, 4)), 128, axis=1), + 8, axis=0) return pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), grid=(4, 4), - in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec((8, 128), lambda i, j: (i, j)), + ], out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), interpret=pltpu.InterpretParams( random_seed=12345, num_cores_per_device=num_cores_per_device, grid_point_recorder=grid_point_recorder, + detect_races=True, ), compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel', 'arbitrary') ), - )(s) + )(s, block_input) with self.subTest('num_cores_per_device=1'): with GridPointRecorderContext() as grid_point_recorder: result = jax.jit(kernel_call, static_argnums=(1, 2))( - jnp.zeros((1,), jnp.int32), 1, grid_point_recorder.get_recorder() + jnp.zeros((1,), jnp.float32), 1, grid_point_recorder.get_recorder() ) np.testing.assert_allclose( result[::8, ::128], @@ -630,7 +654,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): with self.subTest('num_cores_per_device=2'): with GridPointRecorderContext() as grid_point_recorder: result = jax.jit(kernel_call, static_argnums=(1, 2))( - jnp.zeros((1,), jnp.int32), 2, grid_point_recorder.get_recorder() + jnp.zeros((1,), jnp.float32), 2, grid_point_recorder.get_recorder() ) np.testing.assert_allclose( result[::8, ::128], @@ -669,7 +693,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): with self.subTest('num_cores_per_device=3'): with GridPointRecorderContext() as grid_point_recorder: result = jax.jit(kernel_call, static_argnums=(1, 2))( - jnp.zeros((1,), jnp.int32), 3, grid_point_recorder.get_recorder() + jnp.zeros((1,), jnp.float32), 3, grid_point_recorder.get_recorder() ) np.testing.assert_allclose( result[::8, ::128], @@ -708,7 +732,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): with self.subTest('num_cores_per_device=4'): with GridPointRecorderContext() as grid_point_recorder: result = jax.jit(kernel_call, static_argnums=(1, 2))( - jnp.zeros((1,), jnp.int32), 4, grid_point_recorder.get_recorder() + jnp.zeros((1,), jnp.float32), 4, grid_point_recorder.get_recorder() ) np.testing.assert_allclose( result[::8, ::128], @@ -747,7 +771,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): with self.subTest('num_cores_per_device=5'): with GridPointRecorderContext() as grid_point_recorder: result = jax.jit(kernel_call, static_argnums=(1, 2))( - jnp.zeros((1,), jnp.int32), 5, grid_point_recorder.get_recorder() + jnp.zeros((1,), jnp.float32), 5, grid_point_recorder.get_recorder() ) np.testing.assert_allclose( result[::8, ::128], @@ -786,7 +810,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): with self.subTest('num_cores_per_device=6'): with GridPointRecorderContext() as grid_point_recorder: result = jax.jit(kernel_call, static_argnums=(1, 2))( - jnp.zeros((1,), jnp.int32), 6, grid_point_recorder.get_recorder() + jnp.zeros((1,), jnp.float32), 6, grid_point_recorder.get_recorder() ) np.testing.assert_allclose( result[::8, ::128], @@ -822,5 +846,27 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): ], ) + def test_thread_map(self): + barrier = 
threading.Barrier(8) + lock = threading.Lock() + concurrent_calls = [0] + max_concurrent_calls = [0] + + def _barrier(): + with lock: + concurrent_calls[0] += 1 + max_concurrent_calls[0] = max( + max_concurrent_calls[0], concurrent_calls[0]) + barrier.wait() + with lock: + concurrent_calls[0] -= 1 + + def f(core_index): + del core_index + jax.experimental.io_callback(_barrier, (), ordered=True) + + mosaic_interpret._thread_map(f, 8) + self.assertEqual(max_concurrent_calls[0], 8) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 3c1df032564c7346f0e26ee858b8fc4b9f7eb8a7 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 3 Jun 2025 16:46:01 -0700 Subject: [PATCH 1509/1769] [Pallas] Add forward-compatible i1 broadcast. Missed this in https://github.com/jax-ml/jax/commit/6c18aa8a468e35b8c11b101dceaa43d05b497177 PiperOrigin-RevId: 766873399 --- jax/_src/pallas/mosaic/lowering.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b1a2c186e2c9..73becf5e03c6 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1928,6 +1928,22 @@ def _broadcast_in_dim_lowering_rule( if aval_in.shape == shape: return val + if jnp.issubdtype(aval_in.dtype, jnp.bool_) and ( + ctx.forward_compatible or is_cloud_tpu_older_than(2025, 6, 3) + ): + # Direct broadcasts for bools are not supported in Mosaic due to booleans + # living in mask registers and broadcast operating on vregs. Broadcast as an + # integer instead and cast back to a bool. + def _proxy_fun(val, *, shape, broadcast_dimensions): + int_val = jnp.where(val, 1, 0) + bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions) + return bcast_val == 1 + + proxy_lowering = lower_fun(_proxy_fun, multiple_results=False) + return proxy_lowering( + ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions + ) + if broadcast_dimensions: out_shape_list = [1] * len(shape) for i, s in zip(broadcast_dimensions, aval_in.shape): From ab84dde2301ca57edea443d64cadf003bd126626 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 3 Jun 2025 16:47:38 -0700 Subject: [PATCH 1510/1769] Update more uses of `backend.compile` to `backend.compile_and_load`. 
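The pattern applied in the autodidax files below is, for example:

    backend = xb.get_backend(None)
    compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1])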
PiperOrigin-RevId: 766874092 --- docs/autodidax.ipynb | 2 +- docs/autodidax.md | 2 +- docs/autodidax.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 16d4da37b3f2..f57ce09e0bf6 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2020,7 +2020,7 @@ " output = io.StringIO()\n", " c.module.operation.print(file=output)\n", " backend = xb.get_backend(None)\n", - " compiled = backend.compile(output.getvalue(), backend.devices()[:1])\n", + " compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1])\n", " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", "\n", "def _mlir_dtype(dtype: np.dtype) -> ir.Type:\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 870ee20f0f9a..5bf0e8f78e12 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1590,7 +1590,7 @@ def xla_callable(hashable_jaxpr: IDHashable, output = io.StringIO() c.module.operation.print(file=output) backend = xb.get_backend(None) - compiled = backend.compile(output.getvalue(), backend.devices()[:1]) + compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: diff --git a/docs/autodidax.py b/docs/autodidax.py index b0dbf9f73d9f..695fc9993df5 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1582,7 +1582,7 @@ def main(*params): output = io.StringIO() c.module.operation.print(file=output) backend = xb.get_backend(None) - compiled = backend.compile(output.getvalue(), backend.devices()[:1]) + compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: From 31fde29e415a8915a8cf23c97c331e9be44e2e08 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Jun 2025 17:52:37 -0700 Subject: [PATCH 1511/1769] Fix pgle test breakage PiperOrigin-RevId: 766895212 --- tests/pgle_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index e136c3ab8a5a..7087bcad58bf 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -158,12 +158,12 @@ def f(x): with config.pgle_profiling_runs(2), config.enable_pgle(True): # Run 1: Module should be compiled without FDO. Two modules are expected # One is the funtion f, the other one is multi slice module - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) # Run 2: Second PGLE run. Profile should be empty. 
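+      # As in run 1, two cache misses are expected while profiling is still
+      # in progress.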
- with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) @@ -175,7 +175,7 @@ def f(x): os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0) # Run 3: The module should be recompiled with FDO profiles - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir) @@ -190,7 +190,7 @@ def f(x): ) # Run 4: Fast-path should be used after PGLE is done - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertLess(cache_miss_count(), 2) From 8519fd2e70b2ef548f29eeb24c83486bffc65048 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 3 Jun 2025 20:27:45 -0700 Subject: [PATCH 1512/1769] [Pallas] Fix missing sub lowering rule for sparsecore. PiperOrigin-RevId: 766941330 --- jax/_src/pallas/mosaic/lowering.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 73becf5e03c6..edfce4770f10 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2504,7 +2504,9 @@ def _min_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -@register_lowering_rule(lax.sub_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.sub_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out From 002078b3184add3c06fd80aca8799a09eeb70c12 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Jun 2025 21:41:16 -0700 Subject: [PATCH 1513/1769] Only infer sharding from input in full_like (in eager mode) if the input's sharding is concrete i.e. does not contain an AbstractMesh PiperOrigin-RevId: 766962130 --- jax/_src/lax/lax.py | 3 ++- tests/pjit_test.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e03951eb4730..43dffc7bef9c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3564,12 +3564,13 @@ def full_like(x: ArrayLike | DuckTypedArray, # This bypasses the check. and not isinstance(x, core.Tracer) and hasattr(x, 'sharding') + and x.sharding is not None + and x.sharding._is_concrete and getattr(x, '_committed', True) and not weak_type and fill_shape == np.shape(x) # type: ignore[arg-type] ) if use_x_sharding: - # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. 
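+    # x.sharding is known to be concrete here (checked via _is_concrete
+    # above), so it is safe to reuse for the filled output.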
sharding = x.sharding # type: ignore val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type), sharding=sharding) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2fe2c4696cbe..fa68a441c4c0 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5912,6 +5912,14 @@ def f(x, y): self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_full_like_eager_non_concrete_sharding(self): + s = NamedSharding(mesh_lib.AbstractMesh((2,), ('x',)), P('x')) + arr = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s) + out = jax.lax.full_like(arr, 0) + # The sharding is single device because the sharding of input `arr`` to + # full_like is not concrete. + self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0])) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_slice(self, mesh): np_inp = np.arange(16.).reshape(4, 4) From 7d93eee9e79e72e79d45c721f08b72a53e902a07 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 3 Jun 2025 21:43:36 -0700 Subject: [PATCH 1514/1769] Make experimental pytree_serialization visible in OSS jax build PiperOrigin-RevId: 766962750 --- BUILD.bazel | 1 + jax/experimental/array_serialization/BUILD | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 44885124797f..d51b9f8c9cef 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -52,6 +52,7 @@ wheel_sources( "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", "//jax/experimental/array_serialization:serialization", + "//jax/experimental/array_serialization:pytree_serialization", "//jax/experimental/jax2tf", "//jax/experimental/mosaic/gpu/examples:flash_attention", "//jax/experimental/mosaic/gpu/examples:matmul", diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD index ebd78decf6a3..559d8eb16269 100644 --- a/jax/experimental/array_serialization/BUILD +++ b/jax/experimental/array_serialization/BUILD @@ -52,9 +52,10 @@ pytype_library( "//jax", "//jax/experimental/array_serialization:pytree_serialization_utils", "//jax/experimental/array_serialization:tensorstore_impl", - "//third_party/py/absl/logging", - "//third_party/py/numpy", - ], + ] + py_deps([ + "absl/logging", + "numpy", + ]), ) pytype_library( @@ -62,9 +63,10 @@ pytype_library( srcs = ["pytree_serialization_utils.py"], deps = [ "//jax", - "//third_party/py/absl/logging", - "//third_party/py/numpy", - ], + ] + py_deps([ + "absl/logging", + "numpy", + ]), ) jax_multiplatform_test( From c222fb6c12abc5a8ac4be4998587c7ca5225e84a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Jun 2025 01:12:51 -0700 Subject: [PATCH 1515/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a3cb8a0de31a1984a56802981ed3987f63879ce5. PiperOrigin-RevId: 767027087 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 5ee95ca302a3..1fc6ccae44dc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "9b20d33306b4f15bc17f0235a786dddac96d046e" -XLA_SHA256 = "eba6c387448b05fee0f26e7a28ead5b4ad17342b45f451f6a8748a87c0141b1c" +XLA_COMMIT = "a3cb8a0de31a1984a56802981ed3987f63879ce5" +XLA_SHA256 = "26d7d752d46e0c753525feb536af5259fb2bef04432b93853cba26d330ea0dde" def repo(): tf_http_archive( From b7833e94c1940ed475dae1f5e83e2a984cda5cea Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Wed, 4 Jun 2025 03:14:52 -0700 Subject: [PATCH 1516/1769] #sdy Fallback to GSPMD in JAX export if the loaded module was lowered for GSPMD. The final module that will be created by JAX export will contain a bit of Shardy and GSPMD ops. What we then do during compilation is detect whether there is a mix of these ops. If there is, we override the build option and instead use GSPMD for propagation (we have well tested code to export Shardy->GSPMD, but not vice versa). PiperOrigin-RevId: 767064075 --- jax/_src/export/_export.py | 38 ++++++++++++++++------ jax/_src/interpreters/mlir.py | 4 +-- jaxlib/BUILD | 5 ++- jaxlib/py_client.cc | 59 +++++++++++++++++++++++++++++++++++ tests/export_test.py | 44 +++++++++++++++++++++++++- 5 files changed, 136 insertions(+), 14 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index b390574c0a79..5e3c4cf0f209 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1431,9 +1431,16 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, submodule_bc = mlir.module_to_bytecode(submodule) shardy_enabled = _jax.sdy.lowered_with_shardy(submodule_bc) if shardy_enabled: + if not config.use_shardy_partitioner.value: + raise ValueError( + "The function was exported with shardy enabled but you are calling " + "it with Shardy disabled. Please enable Shardy using " + "`--jax_use_shardy_partitioner=True`.") submodule = ir.Module.parse( _jax.sdy.sdy_round_trip_import_shardings(submodule_bc) ) + elif config.use_shardy_partitioner.value: + shardy_enabled = True with submodule.context: pipeline = passmanager.PassManager.parse( @@ -1444,7 +1451,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, if shardy_enabled: sdy_mesh_axes = _jax.sdy.get_mesh(mlir.module_to_bytecode(submodule)) mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1]) - if sdy_mesh_axes else mesh_lib.empty_abstract_mesh) + if sdy_mesh_axes else None) axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.ShardingContext): @@ -1473,15 +1480,19 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ) # Apply in_shardings - if shardy_enabled: + if mesh: + # A mesh only exists if Shardy is enabled. args = tuple( wrap_with_sharding( ctx, x, x_aval, - _hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type] + _hlo_sharding_to_named_sharding(x_sharding, mesh), use_shardy=True) # type: ignore[arg-type] for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) else: + # Since there is no mesh - either due to shardy being disabled or the loaded + # function being lowered for GSPMD (so no shardy mesh) - need to create a + # GSPMD sharding from the HLO sharding (can't use shardy lowering). 
args = tuple( - wrap_with_sharding(ctx, x, x_aval, x_sharding) + wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False) for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) symtab = ir.SymbolTable(submodule.operation) @@ -1570,14 +1581,19 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):], exported.out_avals, ctx.avals_out)) # Apply out_shardings - if shardy_enabled: + if mesh: + # A mesh only exists if Shardy is enabled. results = tuple( wrap_with_sharding( - ctx, x, x_aval, _hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type] + ctx, x, x_aval, _hlo_sharding_to_named_sharding(x_sharding, mesh), + use_shardy=True) # type: ignore[arg-type] for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo)) else: + # Since there is no mesh - either due to shardy being disabled or the loaded + # function being lowered for GSPMD (so no shardy mesh) - need to create a + # GSPMD sharding from the HLO sharding (can't use shardy lowering). results = tuple( - wrap_with_sharding(ctx, x, x_aval, x_sharding) + wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False) for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo)) return results @@ -1588,12 +1604,14 @@ def wrap_with_sharding( ctx: mlir.LoweringRuleContext, x: ir.Value, x_aval: core.AbstractValue, - x_sharding: sharding_impls.NamedSharding | HloSharding | None, + x_sharding: sharding_impls.NamedSharding | sharding_impls.GSPMDSharding | HloSharding | None, + use_shardy: bool, ) -> ir.Value: if x_sharding is None: return x - if config.use_shardy_partitioner.value: + if use_shardy: x_sharding = x_sharding._to_sdy_sharding(x_aval.ndim) # type: ignore else: x_sharding = x_sharding.to_proto() # type: ignore - return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding) # type: ignore[arg-type] + return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding, # type: ignore[arg-type] + allow_shardy_lowering=use_shardy) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 0256057b8b09..6bad0ee2a018 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2731,7 +2731,7 @@ def lower_with_sharding_in_types(ctx, op, aval, sharding_proto=None): def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList): - if config.use_shardy_partitioner.value: + if isinstance(sharding, (SdyArray, SdyArrayList)): op.attributes["sdy.sharding"] = get_sharding_attr(sharding) else: op.attributes["mhlo.sharding"] = get_sharding_attr(sharding) @@ -2740,7 +2740,7 @@ def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList): def get_sharding_attr( sharding: xc.OpSharding | SdyArray | SdyArrayList ) -> ir.Attribute: - if config.use_shardy_partitioner.value: + if isinstance(sharding, (SdyArray, SdyArrayList)): return sharding.build() # type: ignore else: # If there are very large numbers of devices, use the proto representation. 
diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 834103063aae..6fd606966b53 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -870,12 +870,14 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", "@tsl//tsl/platform:fingerprint", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/profiler/lib:traceme", @@ -915,6 +917,7 @@ cc_library( "@xla//xla/python/pjrt_ifrt:pjrt_dtype", "@xla//xla/python/pjrt_ifrt:xla_ifrt", "@xla//xla/service:platform_util", + "@xla//xla/service/spmd/shardy:constants", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/framework:allocator", "@xla//xla/tsl/platform:env", diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index 8e78f024e1ae..b663c00da35c 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -35,9 +35,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Visitors.h" #include "mlir/Pass/PassManager.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -48,6 +50,7 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" #include "jaxlib/guard_lib.h" #include "jaxlib/nb_class_ptr.h" #include "jaxlib/py_array.h" @@ -60,6 +63,7 @@ limitations under the License. #include "jaxlib/python_ref_manager.h" #include "jaxlib/sharding.h" #include "jaxlib/traceback.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/literal.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -89,6 +93,7 @@ limitations under the License. #include "xla/python/types.h" #include "xla/python/version.h" #include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/service/spmd/shardy/constants.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" @@ -399,6 +404,47 @@ MakeIfrtDeserializeExecutableOptions(std::optional options, std::move(ifrt_loaded_host_callbacks)); } +// Returns true if the module has at least one GSPMD attribute or op, like an +// `mhlo.sharding` attribute or `Sharding` custom call. +// TODO(b/420837831): delete this once we don't fall back to GSPMD. +bool HasGspmdAttrsOrOps(mlir::ModuleOp module) { + for (auto func : module.getOps()) { + for (int64_t arg_index = 0; arg_index < func.getNumArguments(); + ++arg_index) { + if (func.getArgAttr(arg_index, sdy::kXlaShardingAttr)) { + return true; + } + } + for (int64_t result_index = 0; result_index < func.getNumResults(); + ++result_index) { + if (func.getResultAttr(result_index, sdy::kXlaShardingAttr)) { + return true; + } + } + } + // Check the module for a `Sharding` custom call. 
+ bool has_gspmd = false; + module->walk([&has_gspmd](mlir::stablehlo::CustomCallOp custom_call) { + if (custom_call.getCallTargetName() == + sdy::kShardingCustomCallTargetName && + custom_call->hasAttr(sdy::kXlaShardingAttr)) { + has_gspmd = true; + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + return has_gspmd; +} + +// Check if the module has any sort of Shardy mesh: +// - `mesh` +// - `maximal_mesh_{X}` +// - `empty_mesh` +// TODO(b/420837831): delete this once we don't fall back to GSPMD. +bool HasShardyMesh(mlir::ModuleOp module) { + return !module.getOps().empty(); +} + } // namespace /* static */ absl::StatusOr> @@ -483,6 +529,19 @@ PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); + // TODO(b/420837831): Remove this once we don't need to fall back to GSPMD. + if (options.executable_build_options.use_shardy_partitioner() && + HasGspmdAttrsOrOps(module.get())) { + LOG(WARNING) + << "Module has GSPMD attrs or ops, but Shardy is enabled. Disabling " + "Shardy and falling back to using GSPMD propagation."; + options.executable_build_options.set_use_shardy_partitioner(false); + if (HasShardyMesh(module.get())) { + // Shardy is not enabled, but the module has shardy ops. Likely due to + // export loading a GSPMD checkpoint. Fall back to GSPMD. + TF_RETURN_IF_ERROR(ExportShardyForGSPMD(*module)); + } + } return CompileAndLoadIfrtProgram( client, std::make_unique(module.get()), MakeIfrtCompileOptions(std::move(options), std::move(executable_devices), diff --git a/tests/export_test.py b/tests/export_test.py index 598a6634e1e3..0dfebdcec054 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -2008,7 +2008,7 @@ def f(x, y): r = jax.jit(exp.call, out_shardings=NamedSharding(old_mesh_0, P("old_b")))(a, b) self.assertAllClose(a + b, r) - def test_lower_wth_different_meshes_axis_names(self): + def test_lower_with_different_meshes_axis_names(self): mesh1 = jtu.create_mesh((4, 2), ("a", "b")) mesh2 = jtu.create_mesh((4, 2), ("x", "y")) @jax.jit @@ -2033,6 +2033,48 @@ def f(tree): else: get_exported(f)(args) + @jtu.parameterized_filterable( + kwargs=[ + {"use_shardy_on_save": True, "error_msg": "Please enable Shardy"}, + {"use_shardy_on_save": False, "error_msg": ""}, + ]) + def test_lower_load_with_different_partitioners(self, use_shardy_on_save, + error_msg): + old_shardy = config.use_shardy_partitioner.value + try: + jax.config.update("jax_use_shardy_partitioner", use_shardy_on_save) + mesh = jtu.create_mesh((8,), ("a",)) + @jax.jit + def f(x, y): + z = x + y + return jax.lax.with_sharding_constraint( + z, NamedSharding(mesh, P("a"))) + + args = ( + jax.ShapeDtypeStruct( + (32, 32), dtype=np.float32, + sharding=NamedSharding(mesh, P(None, "a"))), + jax.ShapeDtypeStruct( + (32, 32), dtype=np.float32, + sharding=NamedSharding(mesh, P("a")))) + + exp = get_exported(f)(*args) + + jax.config.update("jax_use_shardy_partitioner", not use_shardy_on_save) + + a = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32)) + a = jax.device_put(a, NamedSharding(mesh, P(None, "a"))) + b = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32)) + b = jax.device_put(b, NamedSharding(mesh, P("a"))) + + if use_shardy_on_save: + with self.assertRaisesRegex(ValueError, error_msg): + jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b) + else: + jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b) + 
finally: + jax.config.update("jax_use_shardy_partitioner", old_shardy) + if __name__ == "__main__": From d34f1ddbaf7a706fe3abda7e527af5f13cf341e8 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 4 Jun 2025 03:32:58 -0700 Subject: [PATCH 1517/1769] [Pallas/Mosaic GPU] Expose the new `TCGEN05_ROW` layout. `TCGEN05_ROW` is to `TCGEN05` what `WGMMA_ROW` is to `WGMMA`. PiperOrigin-RevId: 767068597 --- jax/_src/pallas/mosaic_gpu/primitives.py | 4 ++++ jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/tcgen05.py | 7 +++++++ 3 files changed, 12 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 2bd191b859ba..616f7e501cd8 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1451,6 +1451,7 @@ class Layout(enum.Enum): WG_STRIDED = enum.auto() TCGEN05 = enum.auto() + TCGEN05_ROW = enum.auto() def __call__(self, *args, **kwargs) -> ParameterizedLayout: return ParameterizedLayout(self, args, kwargs) @@ -1480,6 +1481,9 @@ def check_no_args(): case Layout.TCGEN05: check_no_args() return mgpu.TCGEN05_LAYOUT + case Layout.TCGEN05_ROW: + check_no_args() + return mgpu.TCGEN05_ROW_LAYOUT @dataclasses.dataclass(frozen=True) class ParameterizedLayout: diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index cd207c2b2519..37bfa227fe61 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -105,4 +105,5 @@ from .tcgen05 import ( LAYOUT as TCGEN05_LAYOUT, # noqa: F401 + ROW_LAYOUT as TCGEN05_ROW_LAYOUT, # noqa: F401 ) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 0438400f6310..d72994f45a87 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -43,6 +43,13 @@ lane_dims=(-4, -3), vector_dim=-1, ) +# ROW_LAYOUT is to LAYOUT as WGMMA_ROW_LAYOUT is to WGMMA_LAYOUT. +ROW_LAYOUT = fa.TiledLayout( + fa.Tiling(tiles=((128,), (32,), (8,), (1,), (1,))), + warp_dim=-5, + lane_dims=(-3, fa.Replicated(times=4)), + vector_dim=-1 +) # A layout resembling the logical organization of TMEM. The 128 rows in a tile # are assigned to 128 lanes in the warpgroup. Useful when the result needs to be # processed in registers and then stored back into TMEM. Should not be used if From d2d6211ff095c86eaaa02aba2dc82ef00f5e16a4 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Wed, 4 Jun 2025 03:52:03 -0700 Subject: [PATCH 1518/1769] #sdy Have JAX export compat tests also run on Shardy. The lowering b/w Shardy and GSPMD is slightly different with the custom calls, so I needed to choose different test data based on whether or not Shardy was enabled. 
PiperOrigin-RevId: 767074094 --- .../annotate_data_placement.py | 67 ++++++++++++++++++- .../tpu_Sharding.py | 41 +++++++++++- tests/BUILD | 7 +- tests/export_back_compat_test.py | 17 ++++- 4 files changed, 121 insertions(+), 11 deletions(-) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py index bf70df2cdb3a..d3aab292c9fe 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py @@ -15,11 +15,12 @@ # ruff: noqa import datetime -from numpy import array, float32, int32 +from numpy import array, float32 +data_2025_04_07_tpu = {} # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2025_04_07_tpu = dict( +data_2025_04_07_tpu['gspmd'] = dict( testdata_version=1, platform='tpu', custom_call_targets=['annotate_device_placement'], @@ -46,7 +47,38 @@ ) # End paste # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2025_04_07_cuda = dict( +data_2025_04_07_tpu['shardy'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement', 'xla.sdy.FuncResultSharding'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=1]>}"}, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "device", mhlo.sharding = "{devices=[1]<=[1]}"} loc("x"), %arg1: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}]>]>"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %2 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":801:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85g\x0b\x01-\x07\x0b\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x13\x03;\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x13#\x0b\x0b#\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x13\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\x9a\x02\x1f\x05\x0f\x11\x03\x05\x03\t\t\x0b\x03\r\x13\x05\x15\x05\x05\x11\x11\x01\x00\x03\x03\x0f\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f\x01\x05\x1f\x1d#%\x05!\x17\'\x86\x0c\x1b\x05#\x03\x03\x03[\x03\x03\x03a\x03\x01\x1d%\x1d\'\x1d)\x1d+\x1d\x0f\r\x03;G\x1d-\x0b\x03\x1d/\x05\x03\x03\x05EK\r\x0779/I13\x1d1\x1d3\r\x0779/513#\x07\x03\x03Q\r\x07SU/513\x1d5\x1d7\x1d9\x1d;\r\x03]5\x1d=\x1d?\r\x03;c\x1dA\x1dC\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04w\x05\x01Q\x01\x07\x01\x07\x04e\x03\x01\x05\x05P\x01\x03\x07\x04Q\x03\x0b\x13\x05\x0b\x19\x0b\x1d\x00\x07\x06!\x03\x05\x05\x01\x03\x03G\x01)\x05\x03\x05\x03\x05\x03G\x01+\x07\x03\x05\x03\x07\t\x04\x01\x03\t\x06\x03\x01\x05\x01\x006\tE7Y5-\x0f\x0b\x0f!\x0f=\x03#\x19\'\x1d#i1\x05\x05\x13%)9\x1f93\x15\x0f\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00add_v1\x00return_v1\x00mhlo.frontend_attributes\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=1]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.memory_kind\x00mhlo.sharding\x00{devices=[1]<=[1]}\x00pinned_host\x00xla.sdy.sharding\x00\x00#sdy.sharding<@mesh, [{"a"}]>\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00annotate_device_placement\x00#sdy.sharding_per_value<[<@mesh, [{"a"}]>]>\x00xla.sdy.FuncResultSharding\x00\x089\t\x05/\x01\x0bCMOWY\x11=?_-A---\x11=?e-A---', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +data_2025_04_07_cuda = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda['gspmd'] = dict( testdata_version=1, platform='cuda', custom_call_targets=['annotate_device_placement'], @@ -71,3 +103,32 @@ xla_call_module_version=9, nr_devices=1, ) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda['shardy'] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement', 'xla.sdy.FuncResultSharding'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=1]>}"}, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "device", mhlo.sharding = "{devices=[1]<=[1]}"} loc("x"), %arg1: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = 
stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}]>]>"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %2 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":806:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85g\x0b\x01-\x07\x0b\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x13\x03;\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x13#\x0b\x0b#\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x13\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\x9a\x02\x1f\x05\x0f\x11\x03\x05\x03\t\t\x0b\x03\r\x13\x05\x15\x05\x05\x11\x11\x01\x00\x03\x03\x0f\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f\x01\x05\x1f\x1d#%\x05!\x17\'\x9a\x0c\x1b\x05#\x03\x03\x03[\x03\x03\x03a\x03\x01\x1d%\x1d\'\x1d)\x1d+\x1d\x0f\r\x03;G\x1d-\x0b\x03\x1d/\x05\x03\x03\x05EK\r\x0779/I13\x1d1\x1d3\r\x0779/513#\x07\x03\x03Q\r\x07SU/513\x1d5\x1d7\x1d9\x1d;\r\x03]5\x1d=\x1d?\r\x03;c\x1dA\x1dC\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04w\x05\x01Q\x01\x07\x01\x07\x04e\x03\x01\x05\x05P\x01\x03\x07\x04Q\x03\x0b\x13\x05\x0b\x19\x0b\x1d\x00\x07\x06!\x03\x05\x05\x01\x03\x03G\x01)\x05\x03\x05\x03\x05\x03G\x01+\x07\x03\x05\x03\x07\t\x04\x01\x03\t\x06\x03\x01\x05\x01\x006\tE7Y5-\x0f\x0b\x0f!\x0f=\x03#\x19\'\x1d#i1\x05\x05\x13%)9\x1f93\x15\x0f\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00add_v1\x00return_v1\x00mhlo.frontend_attributes\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=1]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.memory_kind\x00mhlo.sharding\x00{devices=[1]<=[1]}\x00pinned_host\x00xla.sdy.sharding\x00\x00#sdy.sharding<@mesh, [{"a"}]>\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00annotate_device_placement\x00#sdy.sharding_per_value<[<@mesh, [{"a"}]>]>\x00xla.sdy.FuncResultSharding\x00\x089\t\x05/\x01\x0bCMOWY\x11=?_-A---\x11=?e-A---', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py index f2d8be3b958a..1caac00a4680 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py @@ -15,8 +15,10 @@ import datetime from numpy import array, float32 +data_2023_03_16 = {} + # Pasted from the test output (see module docstring) -data_2023_03_16 = dict( +data_2023_03_16['gspmd'] = dict( testdata_version=1, platform='tpu', custom_call_targets=['SPMDFullToShardShape', 'SPMDShardToFullShape', 'Sharding'], @@ -47,3 +49,40 @@ xla_call_module_version=4, nr_devices=2, ) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2023_03_16['shardy'] = dict( + testdata_version=1, + platform='tpu', + 
custom_call_targets=['xla.sdy.FuncResultSharding', 'xla.sdy.GlobalToLocalShape', 'xla.sdy.LocalToGlobalShape'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([[0., 1., 2., 3.], + [4., 5., 6., 7.]], dtype=float32),), + expected_outputs=(array([[4., 5., 6., 7.], + [0., 1., 2., 3.]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":783:6) +#loc4 = loc("jit(func)/jit(main)/shard_map"(#loc2)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=2]>}"}, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}, {}]>"}, mhlo.sharding = "{devices=[2,1]<=[2]}"} loc("x")) -> (tensor<2x4xf32> {jax.result_info = "result", mhlo.sharding = "{devices=[2,1]<=[2]}"}) { + %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>", xla.sdy.manual_axes = "#sdy"}} : (tensor<2x4xf32>) -> tensor<1x4xf32> loc(#loc4) + %1 = call @xla.sdy.manual_computation_body(%0) : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc4) + %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) {mhlo.frontend_attributes = {xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : (tensor<1x4xf32>) -> tensor<2x4xf32> loc(#loc4) + %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc4) + return %3 : tensor<2x4xf32> loc(#loc) + } loc(#loc) + func.func @xla.sdy.manual_computation_body(%arg0: tensor<1x4xf32> loc("jit(func)/jit(main)/shard_map"(#loc2))) -> tensor<1x4xf32> { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc5) + return %0 : tensor<1x4xf32> loc(#loc4) + } loc(#loc4) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":779:13) +#loc5 = loc("jit(func)/jit(main)/ppermute"(#loc3)) +""", + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x9fy\x13\x013\x0f\x0b\x07\x0b+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0f\x0b\x17\x0f\x0b\x13\x13\x13\x03G\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1b\x0b\x13\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f\x8f\x1b\x0b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02\xae\x03\x1d\x1f!\x05\x11\x1f\x05\x13\x03\t\x0b\r\x03\x0f\x15\x17\x19\x1b\x05\x15\x11\x03\x00\x03\x03\x11\x13\x05\x17\x05\x19\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x05!\x17\x07>\x0c\r\x1d%\'\x05#\x17\x07.\x0c\x1b\x1d+\x05\x05%\x03\x03\x03g\x03\x03\x03m\x03\x03\x03u\x03\x01\x1d\'\x1d)\x0b\x03\x1d+\x1d-\x1d/\x1d1\x1d3\x1d5\x05\x03\x03\x03K\r\x05MO=?\x1d\x11\r\x03;Q\x1d7#\r\x03\x03W\r\x05Y[=?\x1d9\x1d;\x1d=\x1d?#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05i7CE\x1dA\x1dC\r\x05CEo7\x1dE\x1dG\x05\x01\r\x03;7\x1dI\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xbb\x05\x01Q\x05\t\x01\x07\x04\xa9\x03\x01\t\x05P\x05\x03\x07\x04_\x03\x0b\x17\x03\x0f)\x00\x03G\x01-\x05\x03\x05\x03\x01\x0bF\x01\x07\x03\x05\x03\x03\x03G\x01/\t\x03\x07\x03\x05\x03G\x011\x0b\x03\x07\x03\x07\x07\x04\x05\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tF#\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00\xaa\x0bK77-7+\x0f\x0b\x0f!E/)A+\x1d#a\x03\x05;=\x13%)9\x1f9i3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00jit(func)/jit(main)/ppermute\x00x\x00\x00#sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>\x00xla.sdy.sharding\x00mhlo.sharding\x00{devices=[2,1]<=[2]}\x00xla.sdy.manual_computation_body\x00xla.sdy.manual_axes\x00#sdy\x00#sdy.sharding<@mesh, [{"a"}, {}]>\x00jax.result_info\x00result\x00main\x00public\x00xla.sdy.in_shardings\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.out_shardings\x00xla.sdy.LocalToGlobalShape\x00xla.sdy.FuncResultSharding\x00\x08a\x11\x05;\x01\x0bISU]_\x1195k3G333\x03A\x1195q3s333\x1195w3G333\x0b3a3A5\x05ce', + xla_call_module_version=9, + nr_devices=2, +) # End paste diff --git a/tests/BUILD b/tests/BUILD index faa3367aecba..a38f5e91192e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -2039,10 +2039,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_back_compat_test", srcs = ["export_back_compat_test.py"], - # TODO(b/415285434): enable once we have backwards compatibility support with GSPMD checkpoints. 
- # enable_configs = [ - # "tpu_v3_x4_shardy", - # ], + enable_configs = [ + "tpu_v3_x4_shardy", + ], tags = [], deps = [ "//jax:internal_export_back_compat_test_data", diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index be87b4e3e5b3..0f443ae47929 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -778,7 +778,12 @@ def func(x): # b: f32[2, 4] perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(x, 'a', perm=perm) - data = self.load_testdata(tpu_Sharding.data_2023_03_16) + data = tpu_Sharding.data_2023_03_16 + if jax.config.jax_use_shardy_partitioner: + data = data["shardy"] + else: + data = data["gspmd"] + data = self.load_testdata(data) with mesh: self.run_one_test(func, data) @@ -801,9 +806,15 @@ def func(x, y): return x + y if platform == "tpu": - data = self.load_testdata(annotate_data_placement.data_2025_04_07_tpu) + data = annotate_data_placement.data_2025_04_07_tpu + else: + data = annotate_data_placement.data_2025_04_07_cuda + + if jax.config.jax_use_shardy_partitioner: + data = data["shardy"] else: - data = self.load_testdata(annotate_data_placement.data_2025_04_07_cuda) + data = data["gspmd"] + data = self.load_testdata(data) self.run_one_test(func, data) From b21861712f6853a5ee1bf8eeaae1336104a4eee8 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 4 Jun 2025 06:54:21 -0700 Subject: [PATCH 1519/1769] [Mosaic GPU] Use the `mosaic_gpu.sliceSMEM` MLIR op when using WG semantics. PiperOrigin-RevId: 767125470 --- jax/experimental/mosaic/gpu/core.py | 38 +++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 79a0cd56328b..435e6976a5b1 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -316,6 +316,19 @@ def dealloc(self): ) +def _slice_smem( + result: ir.Type, + smem_base: ir.Value, + offset: ir.Value, # This should be an ir.IndexType. 
+ lowering_semantics: LoweringSemantics, +) -> ir.Value: + if lowering_semantics == LoweringSemantics.Warpgroup: + offset = arith.index_cast(ir.IntegerType.get_signless(32), offset) + return dialect.slice_smem(result, offset) + else: + return memref.view(result, smem_base, offset, []) + + def _construct_smem_reftree( cluster_shape: tuple[int, int, int], dynamic_smem: ir.Value, @@ -392,9 +405,11 @@ def ref(member_thunks=member_thunks): cluster_shape, ) case TMEM(shape, dtype, layout=layout, collective=collective, packing=packing): - addr_ref = memref.view( + addr_ref = _slice_smem( ir.MemRefType.get([], i32, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, ) if layout is None: layout = tcgen05._infer_tmem_layout( @@ -410,9 +425,11 @@ def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout): dynamic_smem_offset += 4 # i32 takes up 4 bytes case _: mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype) - tile_smem = memref.view( + tile_smem = _slice_smem( ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, ) dynamic_smem_offset += _count_buffer_bytes(ref_ty) ref = tile_smem @@ -510,18 +527,19 @@ def _launch( smem = ir.Attribute.parse("#gpu.address_space") with ir.InsertionPoint(launch_op.body.blocks[0]): dynamic_smem = gpu.dynamic_shared_memory( - ir.MemRefType.get( - (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem - ) + ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) ) if profiler_spec: - prof_smem = memref.view( + prof_smem = _slice_smem( ir.MemRefType.get( (profiler_spec.smem_i32_elements(block=block),), - i32, memory_space=smem, + i32, + memory_space=smem, ), - dynamic_smem, c(profiler_start, index), [], + dynamic_smem, + c(profiler_start, index), + lowering_semantics, ) prof = profiler.OnDeviceProfiler( profiler_spec, prof_smem, maybe_prof_buffer From 6e75a04eafe67b3982a7bfa63045a491bede1a42 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Wed, 4 Jun 2025 07:12:33 -0700 Subject: [PATCH 1520/1769] Raise `NotImplementedError` instead of `ValueError` when using Shardy without sharding rule. PiperOrigin-RevId: 767131346 --- jax/_src/custom_partitioning.py | 8 +++++--- tests/pjit_test.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index feb1e0c39cc6..0d0411f5177a 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -503,9 +503,11 @@ def __call__(self, *args, **kwargs): if (self.sharding_rule is None and (self.propagate_user_sharding is not None or self.infer_sharding_from_operands is not None)): - raise ValueError("Shardy is used, but sharding propagation callbacks " - "instead of sharding_rule are provided. Need to " - "provide sharding_rule to migrate to Shardy.") + raise NotImplementedError( + "Shardy is used, but sharding propagation callbacks instead of " + "sharding_rule are provided. Need to provide sharding_rule to " + "migrate to Shardy." 
+ ) sharding_rule = self.sharding_rule else: propagate_user_sharding = self.propagate_user_sharding diff --git a/tests/pjit_test.py b/tests/pjit_test.py index fa68a441c4c0..04df138e68e2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -8812,8 +8812,8 @@ def f(x): mesh = jtu.create_mesh((4, 2), ('x', 'y')) x = jax.device_put(np.arange(32 * 16).reshape(32, 16), NamedSharding(mesh, P(None, 'x'))) - with self.assertRaisesRegex(ValueError, "provide sharding_rule to migrate " - "to Shardy"): + with self.assertRaisesRegex( + NotImplementedError, 'provide sharding_rule to migrate to Shardy'): jax.jit(f)(x) def test_reshard_empty_mesh_error(self): From b9658ed8af4f239661448172ba3e0afe38849595 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 4 Jun 2025 07:26:39 -0700 Subject: [PATCH 1521/1769] [jaxlib] Bind 'compile' to `xla::PyClient::Compile` rather than `xla::PyClient::CompileAndLoad`. - Remove redundant `xla::PyClient` `compile` bindings. - Remove host_callback arguments to `compile`. PiperOrigin-RevId: 767135320 --- jaxlib/_jax/__init__.pyi | 6 +-- jaxlib/py_client.cc | 82 +++----------------------------- jaxlib/py_compile_only_client.cc | 13 ++--- jaxlib/xla_client.py | 3 +- jaxlib/xla_client.pyi | 1 + 5 files changed, 15 insertions(+), 90 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 67dc9ffc6001..26125d7edcbc 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -511,8 +511,7 @@ class Client: computation: str | bytes, executable_devices: DeviceList | Sequence[Device], compile_options: CompileOptions = ..., - host_callbacks: Sequence[Any] = ..., - ) -> LoadedExecutable: ... + ) -> Executable: ... def compile_and_load( self, computation: str | bytes, @@ -560,8 +559,7 @@ class CompileOnlyPyClient(Client): computation: str | bytes, executable_devices: DeviceList | Sequence[Device], compile_options: CompileOptions = ..., - host_callbacks: Sequence[Any] = ..., - ) -> LoadedExecutable: ... + ) -> Executable: ... class CpuCollectives: ... 
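A note on the stub change above: after this commit, `Client.compile` returns an
unloaded `Executable`, while `Client.compile_and_load` remains the way to get a
runnable `LoadedExecutable`. A minimal sketch of the distinction, assuming a
`client` backend object, an MLIR module string `mlir_module`, and a `devices`
device list obtained elsewhere (these names are illustrative, not part of the
patch):

    def compile_both_ways(client, mlir_module, devices):
        # Per the new binding, `compile` produces an unloaded Executable: it
        # is compiled for `devices` but not yet loaded onto them, so it can
        # be serialized or inspected but not executed directly.
        unloaded = client.compile(mlir_module, devices)
        # `compile_and_load` keeps the previous behavior and returns a
        # LoadedExecutable that is ready to execute on `devices`.
        loaded = client.compile_and_load(mlir_module, devices)
        return unloaded, loaded
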
diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index b663c00da35c..0ddadbc3038a 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -818,93 +818,25 @@ PyType_Slot PyClient::slots_[] = { .def( "compile", [](nb_class_ptr client, nb::bytes mlir_module, - jax::PyDeviceList& py_executable_devices, CompileOptions options, - std::vector host_callbacks) { - ifrt::DeviceListRef executable_devices = - ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::CompileAndLoad( - std::move(client), - std::string(mlir_module.c_str(), mlir_module.size()), - std::move(executable_devices), std::move(options), - std::move(host_callbacks))); - }, - nb::arg("computation"), nb::arg("executable_devices"), - nb::arg("compile_options") = CompileOptions(), - nb::arg("host_callbacks") = std::vector()) - .def( - "compile", - [](nb_class_ptr client, nb::bytes mlir_module, - jax::PyDeviceList& py_executable_devices, CompileOptions options, - std::vector host_callbacks) { - ifrt::DeviceListRef executable_devices = - ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::CompileAndLoad( - std::move(client), - std::string(mlir_module.c_str(), mlir_module.size()), - std::move(executable_devices), std::move(options), - std::move(host_callbacks))); - }, - nb::arg("computation"), nb::arg("executable_devices"), - nb::arg("compile_options") = CompileOptions(), - nb::arg("host_callbacks") = std::vector()) - .def( - "compile", - [](nb_class_ptr client, std::string mlir_module, - jax::PyDeviceList& py_executable_devices, CompileOptions options, - std::vector host_callbacks) { - ifrt::DeviceListRef executable_devices = - ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::CompileAndLoad( - std::move(client), std::move(mlir_module), - std::move(executable_devices), std::move(options), - std::move(host_callbacks))); - }, - nb::arg("computation"), nb::arg("executable_devices"), - nb::arg("compile_options") = CompileOptions(), - nb::arg("host_callbacks") = std::vector()) - .def( - "compile", - [](nb_class_ptr client, std::string mlir_module, - jax::PyDeviceList& py_executable_devices, CompileOptions options, - std::vector host_callbacks) { + jax::PyDeviceList& py_executable_devices, CompileOptions options) { ifrt::DeviceListRef executable_devices = ValueOrThrow(py_executable_devices.ifrt_device_list()); - return ValueOrThrow(PyClient::CompileAndLoad( - std::move(client), std::move(mlir_module), - std::move(executable_devices), std::move(options), - std::move(host_callbacks))); - }, - nb::arg("computation"), nb::arg("executable_devices"), - nb::arg("compile_options") = CompileOptions(), - nb::arg("host_callbacks") = std::vector()) - // The following two overloads are for users of deprecated APIs who call - // `backend.compile` but do not have visibility to `DeviceList`. 
- .def( - "compile", - [](nb_class_ptr client, nb::bytes mlir_module, - nb::sequence& py_executable_devices, CompileOptions options) { - ifrt::DeviceListRef executable_devices = - ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) - .ifrt_device_list()); - return ValueOrThrow(PyClient::CompileAndLoad( + return ValueOrThrow(PyClient::Compile( std::move(client), std::string(mlir_module.c_str(), mlir_module.size()), - std::move(executable_devices), std::move(options), - std::vector())); + std::move(executable_devices), std::move(options))); }, nb::arg("computation"), nb::arg("executable_devices"), nb::arg("compile_options") = CompileOptions()) .def( "compile", [](nb_class_ptr client, std::string mlir_module, - nb::sequence& py_executable_devices, CompileOptions options) { + jax::PyDeviceList& py_executable_devices, CompileOptions options) { ifrt::DeviceListRef executable_devices = - ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) - .ifrt_device_list()); - return ValueOrThrow(PyClient::CompileAndLoad( + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::Compile( std::move(client), std::move(mlir_module), - std::move(executable_devices), std::move(options), - std::vector())); + std::move(executable_devices), std::move(options))); }, nb::arg("computation"), nb::arg("executable_devices"), nb::arg("compile_options") = CompileOptions()) diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index bcae15cd6438..49f68ec1e24f 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -73,12 +73,7 @@ class CompileOnlyPyClient : public PyClient { absl::StatusOr> CompileUnloaded( absl::string_view mlir_module, ifrt::DeviceListRef executable_devices, - CompileOptions options, std::vector host_callbacks) { - if (!host_callbacks.empty()) { - return Unimplemented( - "Compiling with host_callbacks not available with compile-only " - "client."); - } + CompileOptions options) { ifrt::ExecutableRef ifrt_executable; { nb::gil_scoped_release gil_release; @@ -125,8 +120,7 @@ void RegisterCompileOnlyClient(nb::module_& m) { ValueOrThrow(py_executable_devices.ifrt_device_list()); return ValueOrThrow(self.CompileUnloaded( absl::string_view(mlir_module.c_str(), mlir_module.size()), - std::move(executable_devices), std::move(options), - std::move(host_callbacks))); + std::move(executable_devices), std::move(options))); }, nb::arg("computation"), nb::arg("executable_devices"), nb::arg("compile_options") = CompileOptions(), @@ -134,8 +128,7 @@ void RegisterCompileOnlyClient(nb::module_& m) { .def("compile", ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), nb::arg("computation"), nb::arg("executable_devices"), - nb::arg("compile_options") = CompileOptions(), - nb::arg("host_callbacks") = std::vector()); + nb::arg("compile_options") = CompileOptions()); } } // namespace xla diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 85de5f947e49..a098bd4baf3e 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 347 +_version = 348 # An internal increasing version number for protecting jaxlib code against # ifrt changes. 
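The `_version` bump above is the hook that downstream JAX code keys on: as the
comment says, it is surfaced as `jax._src.lib.jaxlib_extension_version`. A
sketch of the usual gating pattern (the specific feature check here is
illustrative, not taken from this patch):

    from jax._src.lib import jaxlib_extension_version

    if jaxlib_extension_version >= 348:
        # This jaxlib has the new `compile` binding that returns an
        # unloaded Executable.
        ...
    else:
        # Older jaxlib: only the compile-and-load behavior is available.
        ...
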
@@ -308,6 +308,7 @@ def computation_count(): Array = _xla.Array ArrayImpl = _xla.ArrayImpl LoadedExecutable = _xla.LoadedExecutable +Executable = _xla.Executable DeviceList = _xla.DeviceList OpSharding = _xla.OpSharding HloSharding = _xla.HloSharding diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index 80599e86676b..72e85500d1fe 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -37,6 +37,7 @@ from jaxlib._jax import HostBufferSemantics as HostBufferSemantics from jaxlib._jax import ifrt_programs as ifrt_programs from jaxlib._jax import Layout as Layout from jaxlib._jax import LoadedExecutable as LoadedExecutable +from jaxlib._jax import Executable as Executable from jaxlib._jax import Memory as Memory from jaxlib._jax import NamedSharding as NamedSharding from jaxlib._jax import OpSharding as OpSharding From bf635d87b69d9e27df34289cea2571ad4a67deda Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 4 Jun 2025 07:49:24 -0700 Subject: [PATCH 1522/1769] [jax::compiler] Bind `compiler.backend_compile` to `xla::PyClient::Compile`. Currently, we just forward any calls to `compiler.backend_compile_and_load`, which returns an `xla::PyLoadedExecutable` whereas we'd like `compiler.backend_compile` to return an unloaded `xla::PyExecutable`. PiperOrigin-RevId: 767142396 --- jax/_src/compiler.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 4f805034e99c..f501454b73be 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -292,11 +292,35 @@ def backend_compile( module: ir.Module, executable_devices: xc.DeviceList, options: xc.CompileOptions, - host_callbacks: Sequence[Any], -) -> xc.LoadedExecutable: - return backend_compile_and_load( - backend, module, executable_devices, options, host_callbacks - ) +) -> xc.Executable: + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value + # Convert ir.Module to a string representation, unless the backend + # explicitly flags the ability to handle a module directly (avoiding the + # overhead of back and forth conversions). + # TODO(slebedev): Change the backend.compile() to accept ir.Module. + built_c: Any + if getattr(backend, "needs_str_ir", True): + built_c = mlir.module_to_bytecode(module) + else: + built_c = module + + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) + + try: + return backend.compile(built_c, executable_devices, options) + except xc.XlaRuntimeError as e: + for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: + handler_result = error_handler(e) + if handler_result is not None: + raise handler_result from e + raise e @profiler.annotate_function @@ -330,8 +354,7 @@ def backend_compile_and_load( try: # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results - # TODO(dsuo): Simplify this logic once backend_compile actually returns an - # unloaded executable. + # TODO(dsuo): Simplify this logic once we delete _jax.CompileOnlyPyClient. 
if jaxlib_extension_version < 345 or ( jaxlib_extension_version >= 345 and isinstance(backend, _jax.CompileOnlyPyClient) @@ -341,7 +364,7 @@ def backend_compile_and_load( built_c, executable_devices=executable_devices, # type: ignore compile_options=options, - host_callbacks=host_callbacks, + host_callbacks=host_callbacks, # type: ignore ) # Some backends don't have `host_callbacks` option yet # TODO(sharadmv): remove this fallback when all backends allow `compile` From 2226be4569d6572139bac32ac0e0ed1410af5d0e Mon Sep 17 00:00:00 2001 From: Christian Clauss Date: Sat, 31 May 2025 11:07:06 -0400 Subject: [PATCH 1523/1769] Fix typos discovered by codespell --- CHANGELOG.md | 8 ++++---- benchmarks/mosaic/matmul_bench.py | 2 +- build/build.py | 2 +- build/rocm/setup.rocm.sh | 2 +- build/rocm/tools/build_wheels.py | 4 ++-- ci/envs/README.md | 4 ++-- ci/envs/docker.env | 2 +- ci/utilities/install_wheels_locally.sh | 2 +- ci/utilities/setup_build_environment.sh | 2 +- docs/api_compatibility.md | 2 +- docs/autodidax2_part1.ipynb | 4 ++-- docs/autodidax2_part1.md | 4 ++-- docs/autodidax2_part1.py | 4 ++-- docs/developer.md | 2 +- docs/export/shape_poly.md | 2 +- docs/gpu_performance_tips.md | 4 ++-- docs/index.rst | 4 ++-- docs/notebooks/explicit-sharding.ipynb | 2 +- docs/notebooks/explicit-sharding.md | 2 +- docs/notebooks/host-offloading.ipynb | 2 +- docs/notebooks/host-offloading.md | 2 +- docs/pallas/design/async_note.md | 2 +- docs/pallas/design/design.md | 4 ++-- docs/pallas/gpu/reference.md | 6 +++--- docs/pallas/pipelining.md | 2 +- docs/pallas/tpu/distributed.ipynb | 12 ++++++------ docs/pallas/tpu/distributed.md | 12 ++++++------ docs/pallas/tpu/sparse.ipynb | 2 +- docs/pallas/tpu/sparse.md | 2 +- docs/persistent_compilation_cache.md | 2 +- docs/random-numbers.md | 2 +- examples/ffi/src/jax_ffi_example/rms_norm.py | 2 +- jax/_src/api_util.py | 2 +- jax/_src/cache_key.py | 2 +- jax/_src/clusters/cluster.py | 2 +- jax/_src/clusters/k8s_cluster.py | 2 +- jax/_src/cudnn/fused_attention_stablehlo.py | 4 ++-- jax/_src/custom_batching.py | 2 +- jax/_src/custom_dce.py | 8 ++++---- jax/_src/custom_partitioning_sharding_rule.py | 6 +++--- jax/_src/dlpack.py | 2 +- jax/_src/errors.py | 2 +- jax/_src/export/shape_poly.py | 2 +- jax/_src/ffi.py | 4 ++-- .../export_back_compat_test_util.py | 2 +- jax/_src/interpreters/mlir.py | 2 +- jax/_src/jaxpr_util.py | 2 +- jax/_src/lax/control_flow/loops.py | 4 ++-- jax/_src/lax/lax.py | 8 ++++---- jax/_src/lax/linalg.py | 8 ++++---- jax/_src/nn/functions.py | 2 +- jax/_src/numpy/fft.py | 2 +- jax/_src/numpy/lax_numpy.py | 10 +++++----- jax/_src/numpy/linalg.py | 2 +- jax/_src/numpy/ufunc_api.py | 6 +++--- jax/_src/numpy/ufuncs.py | 16 ++++++++-------- jax/_src/pallas/fuser/jaxpr_fusion.py | 2 +- jax/_src/pallas/hlo_interpreter.py | 4 ++-- jax/_src/pallas/mosaic/interpret.py | 10 +++++----- jax/_src/pallas/mosaic/primitives.py | 2 +- jax/_src/pallas/mosaic/random.py | 2 +- jax/_src/pallas/mosaic_gpu/core.py | 6 +++--- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 4 ++-- jax/_src/pallas/pallas_call.py | 14 +++++++------- jax/_src/pallas/primitives.py | 2 +- jax/_src/profiler.py | 2 +- jax/_src/scipy/linalg.py | 2 +- jax/_src/scipy/signal.py | 10 +++++----- jax/_src/scipy/stats/_core.py | 2 +- jax/_src/state/types.py | 2 +- jax/_src/xla_bridge.py | 6 +++--- .../colocated_python/serialization.py | 2 +- jax/experimental/jax2tf/README.md | 2 +- jax/experimental/jax2tf/impl_no_xla.py | 2 +- jax/experimental/key_reuse/_core.py 
| 4 ++-- jax/experimental/mosaic/gpu/dialect_lowering.py | 2 +- .../mosaic/gpu/examples/flash_attention.py | 4 ++-- jax/experimental/mosaic/gpu/fragmented_array.py | 6 +++--- jax/experimental/mosaic/gpu/launch_context.py | 2 +- jax/experimental/mosaic/gpu/tcgen05.py | 2 +- jax/experimental/mosaic/gpu/utils.py | 4 ++-- jax/experimental/mosaic/gpu/wgmma.py | 2 +- jax/experimental/multihost_utils.py | 4 ++-- .../pallas/ops/gpu/attention_mgpu.py | 2 +- .../pallas/ops/tpu/paged_attention/util.py | 2 +- .../ops/tpu/ragged_paged_attention/kernel.py | 2 +- .../splash_attention/splash_attention_mask.py | 4 ++-- .../splash_attention_mask_info.py | 6 +++--- jax/experimental/shard_map.py | 2 +- jax_plugins/cuda/__init__.py | 2 +- jaxlib/config.cc | 6 +++--- jaxlib/jax.bzl | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 4 ++-- jaxlib/mosaic/dialect/tpu/layout.h | 2 +- .../tpu/transforms/apply_vector_layout.cc | 4 ++-- .../tpu/transforms/infer_vector_layout.cc | 4 ++-- jaxlib/pjit.cc | 2 +- jaxlib/py_client.h | 2 +- jaxlib/xla_compiler.cc | 2 +- tests/checkify_test.py | 4 ++-- tests/debug_info_test.py | 2 +- tests/error_check_test.py | 2 +- tests/export_test.py | 6 +++--- tests/generated_fun_test.py | 2 +- tests/gpu_memory_flags_test.py | 2 +- tests/lax_metal_test.py | 2 +- tests/lax_numpy_test.py | 2 +- tests/lax_numpy_ufuncs_test.py | 2 +- tests/lax_scipy_test.py | 2 +- tests/lax_test.py | 2 +- tests/mosaic/gpu_dialect_test.py | 2 +- tests/pallas/mgpu_collective_matmul_test.py | 2 +- tests/pallas/tpu_fusible_matmul_test.py | 2 +- tests/pgle_test.py | 2 +- tests/scaled_matmul_stablehlo_test.py | 6 +++--- tests/shape_poly_test.py | 4 ++-- third_party/repo.bzl | 2 +- 119 files changed, 209 insertions(+), 209 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index afd15a357b48..c43cc7fb9b4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes * Additional checking for the versions of CUDA package dependencies was - reenabled, having been accidentally disabled in a previous release. + re-enabled, having been accidentally disabled in a previous release. * JAX nightly packages are now published to artifact registry. To install these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation). * `jax.sharding.PartitionSpec` no longer inherits from a tuple. @@ -232,7 +232,7 @@ to signify this. developers at this point. So it is difficult for us to fix this kind of problem even if we wanted to. - We are open to readding support for Mac x86 if the community is willing + We are open to re-adding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again. @@ -457,7 +457,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. * `jax_pmap_no_rank_reduction` flag is set to `True` by default. * array[0] on a pmap result now introduces a reshape (use array[0:1] instead). - * The per-shard shape (accessable via jax_array.addressable_shards or + * The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading (1, ...). Update code that directly accesses shards accordingly. The rank of the per-shard-shape now matches that of the global shape which is the same behavior as jit. 
@@ -1513,7 +1513,7 @@ See the 0.4.33 release notes for more details. dict of string stat names with int values, e.g. `"bytes_in_use"`, or None if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU. - * Readded support for the Python buffer protocol (`memoryview`) on CPU + * Re-added support for the Python buffer protocol (`memoryview`) on CPU devices. ## jax 0.4.10 (May 11, 2023) diff --git a/benchmarks/mosaic/matmul_bench.py b/benchmarks/mosaic/matmul_bench.py index 32c147916407..fd3fcd6da315 100644 --- a/benchmarks/mosaic/matmul_bench.py +++ b/benchmarks/mosaic/matmul_bench.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Microbenchmarks for mosaic gpu matrix mutliplication.""" +"""Microbenchmarks for mosaic gpu matrix multiplication.""" import functools import sys diff --git a/build/build.py b/build/build.py index d059251552eb..40e02a100d98 100755 --- a/build/build.py +++ b/build/build.py @@ -612,7 +612,7 @@ async def main(): wheel_build_command_base.append("--config=rocm") wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") if args.rocm_path: - logging.debug("ROCm tookit path: %s", args.rocm_path) + logging.debug("ROCm toolkit path: %s", args.rocm_path) wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") if args.rocm_amdgpu_targets: logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) diff --git a/build/rocm/setup.rocm.sh b/build/rocm/setup.rocm.sh index 3893d817e3a8..faa79d2ce1fd 100755 --- a/build/rocm/setup.rocm.sh +++ b/build/rocm/setup.rocm.sh @@ -13,7 +13,7 @@ ROCM_BUILD_NAME=ubuntu ROCM_BUILD_NUM=main # Adjust the ROCM repo location -# Intial release don't have the trialing '.0' +# Initial release don't have the trialing '.0' # For example ROCM 5.7.0 is at https://repo.radeon.com/rocm/apt/5.7/ if [ ${ROCM_VERSION##*[^0-9]} -eq '0' ]; then ROCM_VERS=${ROCM_VERSION%.*} diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index a7ebdf86f916..3b3d697addc9 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -227,7 +227,7 @@ def fix_wheel(path, jax_path): env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed - # the fuction to ldd and also changed its behavior + # the function to ldd and also changed its behavior # constrain range to 6.0 to 6.2.x cmd = ["pip", "install", "auditwheel>=6,<6.3"] subprocess.run(cmd, check=True, env=env) @@ -325,7 +325,7 @@ def main(): shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info")) shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__")) - # Make the wheels deleteable by the runner + # Make the wheels deletable by the runner whl_house = os.path.join(args.jax_path, "wheelhouse") logging.info("Changing permissions for %s" % whl_house) mode = 0o664 diff --git a/ci/envs/README.md b/ci/envs/README.md index cf7a0c12fc9f..2a81d0f3240d 100644 --- a/ci/envs/README.md +++ b/ci/envs/README.md @@ -9,7 +9,7 @@ Name | Default Value ------------------------------------------- | ---------------------------------------- | 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- `JAXCI_JAX_GIT_DIR` | Present working directory: `$(pwd)` | Path to the JAX's Git directory. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_GIT_DIR&type=code) `JAXCI_HERMETIC_PYTHON_VERSION` | System default | Controls the version of hermetic Python to use. This affects the Bazel commands only such as when building artifacts or when running the Bazel test scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_HERMETIC_PYTHON_VERSION&type=code) -`JAXCI_XLA_GIT_DIR` | Unset | When using a local copy of XLA, this points to the root of the XLA git repoistory. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_GIT_DIR&type=code) +`JAXCI_XLA_GIT_DIR` | Unset | When using a local copy of XLA, this points to the root of the XLA git repository. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_GIT_DIR&type=code) `JAXCI_CLONE_MAIN_XLA` | 0 | If set to 1, the XLA repository is cloned at HEAD and its path is set in `JAXCI_XLA_GIT_DIR` | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_CLONE_MAIN_XLA&type=code) `JAXCI_XLA_COMMIT` | Unset | Allows overriding the XLA commit that is used when using a local copy of XLA. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_COMMIT&type=code) `JAXCI_OUTPUT_DIR` | `$(pwd)/dist` | Controls the location where the artifacts are written to. The directory will be automatically created if it does not exist. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_OUTPUT_DIR&type=code) @@ -37,5 +37,5 @@ Name | Default Value Name | Default Value | Behavior | Usage ----------------------- | ------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------- | ----- `JAXCI_DOCKER_WORK_DIR` | "/jax" | The path on the container where the JAX Git repository is mounted to. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_WORK_DIR&type=code) -`JAXCI_DOCKER_ARGS` | Empty String | Space seprated string of additional arguments that will be passed when starting the Docker container | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_ARGS&type=code) +`JAXCI_DOCKER_ARGS` | Empty String | Space separated string of additional arguments that will be passed when starting the Docker container | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_ARGS&type=code) `JAXCI_DOCKER_IMAGE` | Depends on the system (see [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env)) | Docker image to pull | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_IMAGE&type=code) diff --git a/ci/envs/docker.env b/ci/envs/docker.env index d556cb82d74d..cef2cda27bf4 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -# This file contains all the docker specifc envs that are needed by the +# This file contains all the docker specific envs that are needed by the # ci/utilities/run_docker_container.sh script. os=$(uname -s | awk '{print tolower($0)}') diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index b1472d765c08..d66e1fea967b 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -22,7 +22,7 @@ WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o - for i in "${!WHEELS[@]}"; do if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then - # Apppend an extra to the end of the JAX wheel path to install those + # Append an extra to the end of the JAX wheel path to install those # packages as well from PyPI. E.g. jax[tpu] will install the libtpu package # from PyPI. See ci/envs/README.md for more details. if [[ -n "$JAXCI_JAX_PYPI_EXTRAS" ]]; then diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 114acf2479ff..246665cd2f9f 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -16,7 +16,7 @@ # Set up the build environment for JAX CI jobs. This script depends on the # "JAXCI_" environment variables set or sourced in the build script. -# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI +# Preemptively mark the JAX git directory as safe. This is necessary for JAX CI # jobs running on Linux runners in GitHub Actions. Without this, git complains # that the directory has dubious ownership and refuses to run any commands. # Avoid running on Windows runners as git runs into issues with not being able diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index dda86e2e5d31..985b2145c5c4 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -96,7 +96,7 @@ guarantees of the main JAX package. If you have code that uses `jax.extend`, we would strongly recommend CI tests against JAX's nightly releases, so as to catch potential changes before they are released. -For details on `jax.extend`, see the [`jax.extend` module docuementation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. +For details on `jax.extend`, see the [`jax.extend` module documentation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. ## Numerics and randomness diff --git a/docs/autodidax2_part1.ipynb b/docs/autodidax2_part1.ipynb index 0a5a89c8ed98..7a58f54b16c8 100644 --- a/docs/autodidax2_part1.ipynb +++ b/docs/autodidax2_part1.ipynb @@ -674,7 +674,7 @@ "something is constant with respect to differentiation? It's tempting to say\n", "\"it's a constant if and only if it's not a dual number\". But actually dual\n", "numbers created by a *different* JVPInterpreter also need to be considered\n", - "constants with resepect to the JVPInterpreter we're currently handling. That's\n", + "constants with respect to the JVPInterpreter we're currently handling. That's\n", "why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This\n", "comes up in higher order differentiation when there are multiple JVPInterprers\n", "in scope. The sort of bug where you accidentally interpret a dual number from\n", @@ -1046,7 +1046,7 @@ "That's it for part one of this tutorial. 
We've done two primitives, three\n", "interpreters and the tracing mechanism that weaves them together. In the next\n", "part we'll add types other than floats, error handling, compilation,\n", - "reverse-mode AD and higher-order primtives. Note that the second part is\n", + "reverse-mode AD and higher-order primitives. Note that the second part is\n", "structured differently. Rather than trying to have a top-to-bottom order that\n", "obeys both code dependencies (e.g. data structures need to be defined before\n", "they're used) and pedagogical dependencies (concepts need to be introduced\n", diff --git a/docs/autodidax2_part1.md b/docs/autodidax2_part1.md index 70dd0e4b696b..a4af594fb253 100644 --- a/docs/autodidax2_part1.md +++ b/docs/autodidax2_part1.md @@ -348,7 +348,7 @@ There are some subtleties worth discussing. First, how do you tell if something is constant with respect to differentiation? It's tempting to say "it's a constant if and only if it's not a dual number". But actually dual numbers created by a *different* JVPInterpreter also need to be considered -constants with resepect to the JVPInterpreter we're currently handling. That's +constants with respect to the JVPInterpreter we're currently handling. That's why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This comes up in higher order differentiation when there are multiple JVPInterprers in scope. The sort of bug where you accidentally interpret a dual number from @@ -539,7 +539,7 @@ print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0)) That's it for part one of this tutorial. We've done two primitives, three interpreters and the tracing mechanism that weaves them together. In the next part we'll add types other than floats, error handling, compilation, -reverse-mode AD and higher-order primtives. Note that the second part is +reverse-mode AD and higher-order primitives. Note that the second part is structured differently. Rather than trying to have a top-to-bottom order that obeys both code dependencies (e.g. data structures need to be defined before they're used) and pedagogical dependencies (concepts need to be introduced diff --git a/docs/autodidax2_part1.py b/docs/autodidax2_part1.py index bfe59df359d3..44bf843c91b3 100644 --- a/docs/autodidax2_part1.py +++ b/docs/autodidax2_part1.py @@ -307,7 +307,7 @@ def nth_order_derivative(n, f, x): # something is constant with respect to differentiation? It's tempting to say # "it's a constant if and only if it's not a dual number". But actually dual # numbers created by a *different* JVPInterpreter also need to be considered -# constants with resepect to the JVPInterpreter we're currently handling. That's +# constants with respect to the JVPInterpreter we're currently handling. That's # why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This # comes up in higher order differentiation when there are multiple JVPInterprers # in scope. The sort of bug where you accidentally interpret a dual number from @@ -483,7 +483,7 @@ def eval_atom(x): return env[x] if isinstance(x, Var) else x # That's it for part one of this tutorial. We've done two primitives, three # interpreters and the tracing mechanism that weaves them together. In the next # part we'll add types other than floats, error handling, compilation, -# reverse-mode AD and higher-order primtives. Note that the second part is +# reverse-mode AD and higher-order primitives. Note that the second part is # structured differently. 
Rather than trying to have a top-to-bottom order that # obeys both code dependencies (e.g. data structures need to be defined before # they're used) and pedagogical dependencies (concepts need to be introduced diff --git a/docs/developer.md b/docs/developer.md index cfb3f16cf649..1b50a9b65bc0 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -374,7 +374,7 @@ in terms of files, not installations): --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" --repo_env=HERMETIC_PYTHON_SHA256= - # We assume that top-level folder in the tarbal is called "python", if it is + # We assume that top-level folder in the tarball is called "python", if it is # something different just pass additional HERMETIC_PYTHON_PREFIX parameter --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" --repo_env=HERMETIC_PYTHON_SHA256= diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 6b63a536ab48..68da231c4a68 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -441,7 +441,7 @@ to {func}`jax.export.symbolic_shape` share a scope and can be mixed up in arithmetic operations. The result would also share the same scope. -You can re-use scopes: +You can reuse scopes: ```python >>> a, = export.symbolic_shape("a,", constraints=("a >= 8",)) diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index bade464d22a1..c9034a515501 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -93,7 +93,7 @@ export JAX_PGLE_AGGREGATION_PERCENTILE=85 # Right now the auto PGLE profile collection doesn't work with command buffer. # If the command buffer is enabled, Auto PGLE will disable it during profile -# colletion and enable it back after the recompilation. If you need to have a +# collection and enable it back after the recompilation. If you need to have a # consistent command buffer logic with and with PGLE profile you can disable it # manually: export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''" @@ -371,7 +371,7 @@ def while_body(carry, i): (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), ) - # Colelctive permute on the "back edge" passes data to the first device. + # Collective permute on the "back edge" passes data to the first device. bwd_edge_data = cycle_back(bwd_edge_data) # Update output buffer. We do this after reading from it to avoid the data diff --git a/docs/index.rst b/docs/index.rst index 07739c01c2fb..93fc6c284685 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -108,7 +108,7 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling** - - `TensorFlow Probabilty`_ + - `TensorFlow Probability`_ - Distrax_ .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation** @@ -199,4 +199,4 @@ maintains an up-to-date list. .. _Orbax: https://orbax.readthedocs.io/ .. _PyMC: https://www.pymc.io/ .. _TensorFlow Datasets: https://www.tensorflow.org/datasets -.. _TensorFlow Probabilty: https://www.tensorflow.org/probability +.. _TensorFlow Probability: https://www.tensorflow.org/probability diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index e1bee4b99fb5..37010b4ab3d3 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -414,7 +414,7 @@ "wherever types need to match. For example, the two sides of a `lax.cond` need to\n", "have results with matching shardings. 
And the carry of `lax.scan` needs to have the\n", "same sharding at the input and the output of the scan body. And when you\n", - "contruct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", + "construct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", "provide shardings too. Certain JAX transformations perform type-level\n", "operations. Automatic differentation constructs a tangent type for each primal\n", "type in the original computation (e.g. `TangentOf(float) == float`,\n", diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index 1402bca2415f..8989d426ffbc 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -251,7 +251,7 @@ sharding is part of that type. This means that shardings need to match wherever types need to match. For example, the two sides of a `lax.cond` need to have results with matching shardings. And the carry of `lax.scan` needs to have the same sharding at the input and the output of the scan body. And when you -contruct a jaxpr without concrete arguments using `make_jaxpr` you need to +construct a jaxpr without concrete arguments using `make_jaxpr` you need to provide shardings too. Certain JAX transformations perform type-level operations. Automatic differentation constructs a tangent type for each primal type in the original computation (e.g. `TangentOf(float) == float`, diff --git a/docs/notebooks/host-offloading.ipynb b/docs/notebooks/host-offloading.ipynb index 9c806c2d56e7..f56cb90ff77e 100644 --- a/docs/notebooks/host-offloading.ipynb +++ b/docs/notebooks/host-offloading.ipynb @@ -240,7 +240,7 @@ ], "source": [ "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", - "out_host = f(arr_host) # Input arrays in hte device memory while output arrays in the host memory\n", + "out_host = f(arr_host) # Input arrays in the device memory while output arrays in the host memory\n", "print(\"Result value of D2H: \\n\", out_host)" ] }, diff --git a/docs/notebooks/host-offloading.md b/docs/notebooks/host-offloading.md index 7e113d40a4b3..cffe8b4340fe 100644 --- a/docs/notebooks/host-offloading.md +++ b/docs/notebooks/host-offloading.md @@ -154,7 +154,7 @@ id: FjZzkxI8ky4r outputId: 2a1b6e7a-1c29-4347-c020-7b47c27a5cc3 --- f = jax.jit(lambda x: x, out_shardings=s_dev) -out_host = f(arr_host) # Input arrays in hte device memory while output arrays in the host memory +out_host = f(arr_host) # Input arrays in the device memory while output arrays in the host memory print("Result value of D2H: \n", out_host) ``` diff --git a/docs/pallas/design/async_note.md b/docs/pallas/design/async_note.md index 0fda9fe0a4e2..b255a91d3ec8 100644 --- a/docs/pallas/design/async_note.md +++ b/docs/pallas/design/async_note.md @@ -464,7 +464,7 @@ def f(x): return fori_loop(0, 8, body, x) ``` -If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer re-use and defensively insert a copy. +If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! 
Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer reuse and defensively insert a copy. ```py def f(x): diff --git a/docs/pallas/design/design.md b/docs/pallas/design/design.md index 17c7a6dbdc0f..53a5eb209510 100644 --- a/docs/pallas/design/design.md +++ b/docs/pallas/design/design.md @@ -71,7 +71,7 @@ A JAX-based kernel language offers several advantages: * JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, - we can re-use JAX’s tracing infrastructure and provide a + we can reuse JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users. * JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex @@ -551,7 +551,7 @@ along that dimension. `grad` of `pallas_call` enables automatic differentiation of kernels. `jax.grad` breaks down into applications of three distinct transforms: `jvp`, `partial_eval` and `transpose`. -In principle, we can re-use most of JAX’s infrastructure when +In principle, we can reuse most of JAX’s infrastructure when implementing these rules for `pallas_call` (since it behaves much like existing JAX higher order primitives). diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md index 7b4a1e6e9c7d..1a4f39dff5f2 100644 --- a/docs/pallas/gpu/reference.md +++ b/docs/pallas/gpu/reference.md @@ -30,7 +30,7 @@ the next instruction.
[figure: A diagram of one NVIDIA SM]
Going further, recent CUDA versions also outline the concept of a _warpgroup_, which are -4 consecutive warps. Knowing how the hardware looks like, we can see where this is comming +4 consecutive warps. Knowing how the hardware looks like, we can see where this is coming from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions that utilize the whole SM. @@ -49,7 +49,7 @@ warps always run in lockstep (modulo the jitter from hardware scheduling) and ne different paths through control flow (with the small exception of `core_map` that we will discuss later). One notable addition here is that we still allow you to co-schedule multiple of those Pallas-level threads on the same SM so that they can cooperate and communicate -through shared memory (we relize that by putting them in the same CUDA block). +through shared memory (we realize that by putting them in the same CUDA block). ```{note} From now on, whenever we say "thread", we refer to the Pallas thread, not a CUDA thread/lane. @@ -329,7 +329,7 @@ transforms specified upon their allocation. For all currently supported generati the TensorCore requires the data to be laid out into row-major 2D tiles of shape `(8, swizzle_elems)`, where `swizzle_elems` is derived by dividing the swizzle by the element type bytewidth. The currently supported swizzles are: 128, 64, and 32. Larger -swizzles are preferrable as they improve the performance of GMEM-to-SMEM copies. +swizzles are preferable as they improve the performance of GMEM-to-SMEM copies. ```python def mma_transforms(shape_dtype: jax.ShapeDtypeStruct): diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md index a79876a0ca97..0ff9eaf5a24b 100644 --- a/docs/pallas/pipelining.md +++ b/docs/pallas/pipelining.md @@ -34,7 +34,7 @@ import numpy as np ## Memory Hierarchies -The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capicity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication: +The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capacity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication: - **Registers** are the the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them. - **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers. SRAM on modern ML accelerators typically range in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2). diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index ae82b7a80ac6..3ac1206bd14a 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -71,7 +71,7 @@ "\n", "![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)\n", "\n", - "Flattened as a graph, the torus can be visualized as follows. 
Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n", + "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device topologies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n", "\n", "![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png)" ] @@ -477,7 +477,7 @@ "id": "KgU7HI2pS4om" }, "source": [ - "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." + "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of reuse. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." ] }, { @@ -529,7 +529,7 @@ "\n", "In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.\n", "\n", - "Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen:\n", + "Semaphores must be zero at the end of a Pallas program to complete successfully. There are two error cases where this may happen:\n", " - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program.\n", " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. 
In this case the device or program will need to be restarted.\n",
    "\n",
@@ -644,7 +644,7 @@
    "\n",
    "The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n",
    "\n",
-    "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device.\n",
+    "A subtle race condition can occur if one device runs one loop ahead of its right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but it can be explicitly triggered if, for example, using a `pltpu.delay` instruction to artificially hang a device.\n",
    "\n",
    "Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections."
   ]
  },
@@ -691,7 +691,7 @@
    " \"\"\"Performs a barrier with neighbors on the global barrier semaphore.\n",
    "\n",
    " Optionally performs a second barrier, which prevents a potential race\n",
-    " when re-using the same collective_id across kernel invocations.\n",
+    " when reusing the same collective_id across kernel invocations.\n",
    " \"\"\"\n",
    " barrier_sem = pltpu.get_barrier_semaphore()\n",
    " for neighbor in [left_neighbor, right_neighbor]:\n",
@@ -1701,7 +1701,7 @@
    "\n",
    "### Next Steps\n",
    "\n",
-    "Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead."
+    "Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead."
   ]
  }
 ],
diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md
index b16116549972..19b336005c28 100644
--- a/docs/pallas/tpu/distributed.md
+++ b/docs/pallas/tpu/distributed.md
@@ -61,7 +61,7 @@ TPUs pods are typically arranged in an ND torus topology. The following graphic

![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)

-Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. 
You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. +Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device topologies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. ![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png) @@ -409,7 +409,7 @@ print('Difference |Pallas - lax.all_gather| = ', +++ {"id": "KgU7HI2pS4om"} -A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. +A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of reuse. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. +++ {"id": "EDCmAaHVtY7x"} @@ -451,7 +451,7 @@ def semaphore_read( In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`. -Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen: +Semaphores must be zero at the end of a Pallas program to complete successfully. There are two error cases where this may happen: - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program. - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted. 
@@ -556,7 +556,7 @@ The prologue (executed when `outer_step==0`) first initiates a barrier with both

The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).

-A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device.
+A subtle race condition can occur if one device runs one loop ahead of its right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but it can be explicitly triggered if, for example, using a `pltpu.delay` instruction to artificially hang a device.

Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections.

@@ -585,7 +585,7 @@ def local_barrier(left_neighbor, right_neighbor, double_barrier=True):
   """Performs a barrier with neighbors on the global barrier semaphore.

   Optionally performs a second barrier, which prevents a potential race
-  when re-using the same collective_id across kernel invocations.
+  when reusing the same collective_id across kernel invocations.
   """
   barrier_sem = pltpu.get_barrier_semaphore()
   for neighbor in [left_neighbor, right_neighbor]:
@@ -1514,4 +1514,4 @@ In this tutorial we covered several kernel examples which replicate the function

### Next Steps

-Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead.
+Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead.
diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb
index ac3a0dad2404..31cfa8eeb328 100644
--- a/docs/pallas/tpu/sparse.ipynb
+++ b/docs/pallas/tpu/sparse.ipynb
@@ -491,7 +491,7 @@
   "source": [
    "def sparsify_mask(mask: jax.Array,\n",
    "                  block_shape: tuple[int, int]):\n",
-    "  \"\"\"Preprocesses a mask into a sparse reprentation.\n",
+    "  \"\"\"Preprocesses a mask into a sparse representation.\n",
    "\n",
    "  Args:\n",
    "    mask: A boolean array of shape [M, N]\n",
diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md
index 113f31d8bab2..35613acdb2c9 100644
--- a/docs/pallas/tpu/sparse.md
+++ b/docs/pallas/tpu/sparse.md
@@ -391,7 +391,7 @@ As we will be working with a sparse mask, we will begin by implementing a functi

def sparsify_mask(mask: jax.Array,
                   block_shape: tuple[int, int]):
-  """Preprocesses a mask into a sparse reprentation.
+  """Preprocesses a mask into a sparse representation.

  Args:
    mask: A boolean array of shape [M, N]
diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md
index e241e76e3c5f..d795a054bc87 100644
--- a/docs/persistent_compilation_cache.md
+++ b/docs/persistent_compilation_cache.md
@@ -260,7 +260,7 @@ If we were to merely compile this function without shard_map, the cache key for
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta)
```

-However, if we were to wrap the layernorm primitive in shard_map and define a function G that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same everytime despite `LayerNorm` being implementing `custom_partitioning`:
+However, if we were to wrap the layernorm primitive in shard_map and define a function G that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same every time despite `LayerNorm` implementing `custom_partitioning`:

```python
import jax
diff --git a/docs/random-numbers.md b/docs/random-numbers.md
index 134b690839e0..5562dc3f43d5 100644
--- a/docs/random-numbers.md
+++ b/docs/random-numbers.md
@@ -150,7 +150,7 @@ print(random.normal(key))
print(random.normal(key))
```

-Re-using the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable.
+Reusing the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable.

**The rule of thumb is: never reuse keys (unless you want identical outputs). Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__.**

diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py
index 5ba97f48ebad..996eb9e5d935 100644
--- a/examples/ffi/src/jax_ffi_example/rms_norm.py
+++ b/examples/ffi/src/jax_ffi_example/rms_norm.py
@@ -16,7 +16,7 @@
This example is exactly the same as the one in the `FFI tutorial
`, so more details can be found on that page. But, the high
level summary is that we implement our custom
-extension in ``rms_norm.cc``, then call it usin ``jax.ffi.ffi_call`` in
+extension in ``rms_norm.cc``, then call it using ``jax.ffi.ffi_call`` in
this module. The behavior under autodiff is implemented using
``jax.custom_vjp``. 
""" diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 2e7ba551c624..5261764d0bf8 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -606,7 +606,7 @@ def debug_info( """Constructd core.DebugInfo for a function given example args and kwargs. `args` and `kwargs` are example positional and keyword arguments, users with - `inspect.Signature` to get the names of argments. The arguments that are + `inspect.Signature` to get the names of arguments. The arguments that are considered static for tracing purposes should be included, and designated using `static_argnums` and `static_argnames`. diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 6fe3d8819d3c..906e686727ef 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -56,7 +56,7 @@ def get_flag_prefixes() -> list[str]: def custom_hook() -> str: """Custom hook for any addition to the cache key. - The custom hook will be called everytime get() is called and can be + The custom hook will be called every time get() is called and can be defined to return a string that will be hashed into the cache key. """ return "" diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 69ef77a6421d..1c0a6fca9df6 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -23,7 +23,7 @@ class ClusterEnv: """Interface for defining a cluster environment. - To enable auto bootrapping (aka :func:`jax.distributed.initialize()`), + To enable auto bootstrapping (aka :func:`jax.distributed.initialize()`), cluster environments need to derive from :class:`ClusterEnv` and implement :func:`is_env_present`, :func:`get_coordinator_address`, :func:`get_process_count`, and :func:`get_process_id`. diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index af1b7c020eed..fb312038bf2c 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -78,7 +78,7 @@ def is_env_present(cls) -> bool: textwrap.fill( "Kubernetes environment detected, but the `kubernetes` package " "is not installed to enable automatic bootstrapping in this " - "environment. To enable automatic boostrapping, please install " + "environment. To enable automatic bootstrapping, please install " "jax with the [k8s] extra. For example:"), " pip install jax[k8s]", " pip install jax[k8s,]", diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 46df84e08e0f..84ca9226e82c 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -357,7 +357,7 @@ def check_is_flash_attention( H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128 if not (H <= H_max and H % 8 == 0): raise NotImplementedError( - f"The head dim must be <= {H_max} and a mutiple of 8, " + f"The head dim must be <= {H_max} and a multiple of 8, " f"but got {H}." 
) @@ -1844,7 +1844,7 @@ def dot_product_attention( # should be broadcast to same shape bias = bias + mask - # check if input shape and data type is compatiable + # check if input shape and data type is compatible check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout) has_bias = bias is not None has_dbias = has_bias and \ diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 338074837ea5..83c9ffb5ee36 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -103,7 +103,7 @@ class custom_vmap: >>> jax.grad(f)(jnp.zeros(()), jnp.ones(())) Array(1., dtype=float32) - Note that the :py:class:`jax.custom_vjp` must be on the ouside, wrapping the + Note that the :py:class:`jax.custom_vjp` must be on the outside, wrapping the ``custom_vmap``-decorated function. """ diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index d336c969a3c4..25fe604085fd 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -251,9 +251,9 @@ def flatten_dce_rule( # For error checking purposes, we need to reformat the pytree structure # of the output of the DCE rule to match the original output. The catch is # that the DCE rule can return a None to indicated an unused subtree, so we - # need to rebuild those subtrees with a sentinal value at the leaves. This + # need to rebuild those subtrees with a sentinel value at the leaves. This # logic is very similar to what is used in custom_dervatives._flatten_bwd. - sentinal = object() + sentinel = object() dummy = tree_util.tree_unflatten(out_tree, [object()] * out_tree.num_leaves) keypaths, _ = util.unzip2(tree_util.tree_flatten_with_path(dummy)[0]) out_flat = [] @@ -261,7 +261,7 @@ def flatten_dce_rule( def append(x, d): num_leaves = len(tree_util.tree_flatten(d)[0]) if x is None and d is not None: - out_flat.extend([sentinal] * num_leaves) + out_flat.extend([sentinel] * num_leaves) elif x is not None: out_flat.extend([x] * num_leaves) return x @@ -281,7 +281,7 @@ def append(x, d): for kp, used, aval, val in zip(keypaths, used_outs, out_avals, out_flat): if not used: continue - if val is sentinal: + if val is sentinel: raise ValueError( f"Custom DCE rule {rule_name} for function {fun_name} must produce " "values for all of the required outputs (as specified by the " diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py index d17399beda5b..bc27f34b3bfb 100644 --- a/jax/_src/custom_partitioning_sharding_rule.py +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -138,12 +138,12 @@ def __init__(self, operand_mappings: tuple[ArrayMapping, ...], # Check that factors that are used for a whole dimension aren't in # factor_sizes and factors that are never used for a whole dimension are # in factor_sizes. 
- for factor, inferrable in factors_inferrable.items(): - if factor not in factor_sizes and not inferrable: + for factor, inferable in factors_inferrable.items(): + if factor not in factor_sizes and not inferable: raise ValueError( f"Factor {factor} is only used in compound factors; must specify" " its size") - if factor in factor_sizes and inferrable: + if factor in factor_sizes and inferable: raise ValueError( f"Factor {factor} represents a whole dimension; do not specify its" " size") diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 40a69d1e0390..1f19ac0f45c0 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -130,7 +130,7 @@ def to_dlpack(x: Array, stream: int | Any | None = None, ) from None # As new versions are adopted over time, we can maintain some legacy paths - # for compatability mediated through the max_version parameter. + # for compatibility mediated through the max_version parameter. # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0). diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 20b82f629f6f..a548714869ab 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -503,7 +503,7 @@ class TracerBoolConversionError(ConcretizationTypeError): In this case, the error occurs because Python's built-in ``min`` function is not compatible with JAX transforms. This can be fixed by replacing it with - ``jnp.minumum``: + ``jnp.minimum``: >>> @jit ... def func(x): diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 31371cf345a1..bb8a159ee54b 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -978,7 +978,7 @@ def cmp_sequence(s1, s2, elem_cmp) -> int: class SymbolicScope: - """Indentifies a scope for symbolic expressions. + """Identifies a scope for symbolic expressions. All symbolic expressions that interact (e.g., appear in the argument shapes for one JAX function invocation, or are involved in arithmetic operations) diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 3bfe8130ccda..db943d675b80 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -56,7 +56,7 @@ def register_ffi_target( name: the name of the target. fn: a ``PyCapsule`` object containing the function pointer, or a ``dict`` where the keys are FFI stage names (e.g. `"execute"`) and the values are - ``PyCapsule`` objects continaing a pointer to the handler for that stage. + ``PyCapsule`` objects containing a pointer to the handler for that stage. platform: the target platform. api_version: the XLA custom call API version to use. Supported versions are: 1 (default) for the typed FFI or 0 for the earlier "custom call" API. @@ -369,7 +369,7 @@ def ffi_call( Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under :func:`~jax.vmap` depends on the value of ``vmap_method``. See the - :func:`~jax.pure_callback` documenation for more details about the allowed + :func:`~jax.pure_callback` documentation for more details about the allowed values and examples of their behavior. 
The current default behavior is to use ``vmap_method="sequential"`` when diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index b86b24e2b4fc..7b4af36e5dc4 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -321,7 +321,7 @@ def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs) # TODO: we ought to ensure that out_avals are polymorphic if need be. We # could either save the in/out_avals (but we need to first implement that - # support in export), or we can just re-use them from the current + # support in export), or we can just reuse them from the current # exported. out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs) # in_tree must be for (args, kwargs) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 6bad0ee2a018..0864ec8646c9 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -706,7 +706,7 @@ class ModuleContext: # Cached primitive lowerings. cached_primitive_lowerings: dict[Any, func_dialect.FuncOp] - # Cached traceback infromation. + # Cached traceback information. traceback_caches: TracebackCaches lowering_parameters: LoweringParameters diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index a6c93c8c120c..cb9eef0b9ea2 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -233,7 +233,7 @@ def jaxpr_and_binder_in_params(params, index: int) -> Iterator[tuple[core.Jaxpr, def eqns_using_var(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[core.JaxprEqn]: """Find the leaf equations using a variable""" - # The complexity of this call is becauase the invar might originate from a nested jaxpr + # The complexity of this call is because the invar might originate from a nested jaxpr for eqn, invar_index in eqns_using_var_with_invar_index(jaxpr, invar): if (child_jaxprs_and_vars := tuple(jaxpr_and_binder_in_params(eqn.params, invar_index))): for (jaxpr, invar) in child_jaxprs_and_vars: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 985c5ba52294..ad09292731cf 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -198,7 +198,7 @@ def scan(f, init, xs, length=None): a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it will determine if the loop is - competely unrolled (i.e. `unroll=True`) or left completely rolled (i.e. + completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e. `unroll=False`). _split_transpose: experimental optional bool specifying whether to further split the transpose into a scan (computing activation gradients), and a @@ -2427,7 +2427,7 @@ def fori_loop(lower, upper, body_fun, init_val): unroll: An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a - boolean is provided, it will determine if the loop is competely unrolled + boolean is provided, it will determine if the loop is completely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`). This argument is only applicable if the loop bounds are statically known. 
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 43dffc7bef9c..922515926f2f 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -935,7 +935,7 @@ def integer_pow(x: ArrayLike, y: int) -> Array:
    An array of the same shape and dtype as ``x`` containing the elementwise power.

  See also:
-    :func:`jax.lax.pow`: Elementwise pwoer where ``y`` is an array.
+    :func:`jax.lax.pow`: Elementwise power where ``y`` is an array.

  .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
  """
@@ -2102,7 +2102,7 @@ class DotAlgorithm(NamedTuple):

  The `StableHLO spec
  `_ for the dot operation doesn't require that the precision types be the same as the
-  storage types for the inputs or outputs, but some plaforms may require that
+  storage types for the inputs or outputs, but some platforms may require that
  these types match. Furthermore, the return type of
  :func:`~jax.lax.dot_general` is always defined by the ``accumulation_type``
  parameter of the input algorithm, if specified.
@@ -7923,7 +7923,7 @@ def _sort_abstract_eval(*args, **kwargs):


def _canonicalize_float_for_sort(x):
-  # In the sort comparator, we are going to use a comparision operator where -0
+  # In the sort comparator, we are going to use a comparison operator where -0
  # would be before 0, and -NaN and NaN appear at the beginning and end of the
  # ordering. In this scheme, -0 would be before 0, and -NaN and NaN appear at
  # the beginning and end of the ordering. This causes issues for stable
@@ -8164,7 +8164,7 @@ def _create_token_lowering(ctx, *operands):
def after_all(*operands):
  """Merges one or more XLA token values. Experimental.

-  Wraps the XLA AfterAll operator."""
+  Wraps the XLA ``AfterAll`` operator."""
  operands = core.standard_insert_pvary(*operands)
  return after_all_p.bind(*operands)

diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 2fda4a90369d..3ee7cc2a6807 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -2148,7 +2148,7 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv,
  # default QR algorithm, but users can (in principle) override this behavior
  # by passing `use_jacobi=True`.
  #
-  # TODO(danfm): Since this was originally implemented, hipSolver appers to
+  # TODO(danfm): Since this was originally implemented, hipSolver appears to
  # have added support for the Jacobi algorithm, so we should investigate
  # removing this condition.
  if algorithm is None or algorithm == SvdAlgorithm.DEFAULT:
@@ -2339,7 +2339,7 @@ def a_inverse(rhs):
                           transpose_a=transpose_a,
                           conjugate_a=conjugate_a,
                           unit_diagonal=unit_diagonal)
-  # triangular_solve is about the same cost as matrix multplication (~n^2 FLOPs
+  # triangular_solve is about the same cost as matrix multiplication (~n^2 FLOPs
  # for matrix/vector inputs). Order these operations in whichever order is
  # cheaper.
  if left_side:
@@ -2776,8 +2776,8 @@ def _column_major_matrix_layout(dim: int) -> tuple[int, ...]:

def _sdy_rule_for_aval(letters, num_batch_dims, aval):
  d = len(aval.shape) - num_batch_dims
-  preffix = "... " if num_batch_dims and d >= 0 else ""
-  return preffix + " ".join(next(letters) for _ in range(d))
+  prefix = "... 
" if num_batch_dims and d >= 0 else "" + return prefix + " ".join(next(letters) for _ in range(d)) def _build_sdy_sharding_rule(num_batch_dims, avals_in, avals_out): letters = iter(string.ascii_letters) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 3f7647758003..f01c4fa52804 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1077,7 +1077,7 @@ def dot_product_attention( token's local window. If set, this specifies the (left_window_size, right_window_size) for each token. E.g., if local_window_size == (3, 2) and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend - to [3, 4, 5, c, 7, 8]. If a single int is given, it will be intepreted as + to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as a symmetric window (window_size, window_size). implementation: A string to control which implementation backend to use. Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 21da91ce613f..970847532e46 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -712,7 +712,7 @@ def hfft(a: ArrayLike, n: int | None = None, are supported. Default is "backward". Returns: - A real-valued array containing the one-dimensional discret Fourier transform + A real-valued array containing the one-dimensional discrete Fourier transform of ``a`` by exploiting its inherent Hermitian-symmetry, having a dimension of ``n`` along ``axis``. diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f323bc64718b..a35bcbb23213 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -272,7 +272,7 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise minimum of the input arrays. - JAX implemtentation of :func:`numpy.fmin`. + JAX implementation of :func:`numpy.fmin`. Args: x1: input array or scalar. @@ -2251,7 +2251,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: Returns: A resized array with specified shape. The elements of ``a`` are repeated in - the resized array, if the resized array is larger than the original aray. + the resized array, if the resized array is larger than the original array. See also: - :func:`jax.numpy.reshape`: Returns a reshaped copy of an array. @@ -5575,7 +5575,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, device: xc.Device | Sharding | None = None) -> Array: """Convert an array to a specified dtype. - JAX imlementation of :func:`numpy.astype`. + JAX implementation of :func:`numpy.astype`. This is implemented via :func:`jax.lax.convert_element_type`, which may have slightly different behavior than :func:`numpy.astype` in some cases. @@ -5957,7 +5957,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, if needed for a device transfer. Returns: - A JAX array of the imput buffer. + A JAX array of the input buffer. Note: While JAX arrays are always immutable, dlpack buffers cannot be marked as @@ -8419,7 +8419,7 @@ def vander( [3, 1], [4, 1]], dtype=int32) - Generates the Vandermonde matrix in increaing order of powers, when + Generates the Vandermonde matrix in increasing order of powers, when ``increasing=True``. 
>>> jnp.vander(x, increasing=True) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index f2deddd52f05..2351b0ccb075 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -1617,7 +1617,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: x_arr = ensure_arraylike('jnp.linalg.matrix_transpose', x) ndim = x_arr.ndim if ndim < 2: - raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {ndim=}") + raise ValueError(f"matrix_transpose requires at least 2 dimensions; got {ndim=}") return lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 243ab9aa0878..c85621d6cdba 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -92,7 +92,7 @@ class ufunc: [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32) - The :meth:`ufunc.reduce` method perfoms a reduction over the array. + The :meth:`ufunc.reduce` method performs a reduction over the array. For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``: >>> jnp.add.reduce(x) @@ -112,7 +112,7 @@ class ufunc: Array([101, 2, 3, 4, 5], dtype=int32) And the :meth:`ufunc.reduceat` method performs a number of ``reduce`` - operations bewteen specified indices of an array; for ``jnp.add`` the + operations between specified indices of an array; for ``jnp.add`` the operation is similar to :func:`jax.ops.segment_sum`: >>> jnp.add.reduceat(x, jnp.array([0, 2])) @@ -574,7 +574,7 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: [ 10 20 30 40 50 60 70 80 90 100]] For input arrays with ``N`` and ``M`` dimensions respectively, the output - will have dimesion ``N + M``: + will have dimension ``N + M``: >>> x = jnp.ones((1, 3, 5)) >>> y = jnp.ones((2, 4)) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 486d3f15e17c..b0ff3cb9747a 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -297,7 +297,7 @@ def sign(x: ArrayLike, /) -> Array: -1, & x < 0 \end{cases} - For complex valued input, ``jnp.sign`` returns a unit vector repesenting the + For complex valued input, ``jnp.sign`` returns a unit vector representing the phase. For generalized case, the sign of ``x`` is given by: .. math:: @@ -347,8 +347,8 @@ def floor(x: ArrayLike, /) -> Array: the nearest integer that is less than or equal to the value itself. See also: - - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. - - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. @@ -386,8 +386,8 @@ def ceil(x: ArrayLike, /) -> Array: the nearest integer that is greater than or equal to the value itself. See also: - - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. - - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. @@ -1621,7 +1621,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: The results match the input ``theta``, except at the endpoints where :math:`+\pi` and :math:`-\pi` represent indistinguishable points on the unit circle. 
By convention,
-  :func:`arctan2` alwasy returns values between :math:`-\pi` and :math:`+\pi` inclusive.
+  :func:`arctan2` always returns values between :math:`-\pi` and :math:`+\pi` inclusive.
  """
  return lax.atan2(*promote_args_inexact("arctan2", x1, x2))

@@ -1710,7 +1710,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array:
      arrays.
    - :func:`jax.numpy.fmax`: Returns element-wise maximum of the input arrays,
      ignoring NaNs.
-    - :func:`jax.numpy.amax`: Retruns the maximum of array elements along a given
+    - :func:`jax.numpy.amax`: Returns the maximum of array elements along a given
      axis.
    - :func:`jax.numpy.nanmax`: Returns the maximum of the array elements along
      a given axis, ignoring NaNs.
@@ -1774,7 +1774,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array:
    >>> jnp.float_power(x, y)
    Array([ 9. ,  1. , -0.2], dtype=float32)

-    Inputs with broacast compatibility:
+    Inputs with broadcast compatibility:

    >>> x1 = jnp.array([[2, -4, 1],
    ...                 [-1, 2, 3]])
diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py
index d1e375e33ef1..8e12b5db483d 100644
--- a/jax/_src/pallas/fuser/jaxpr_fusion.py
+++ b/jax/_src/pallas/fuser/jaxpr_fusion.py
@@ -176,7 +176,7 @@ def _construct_output_fusions(
      unflat_fusible_outvars
  )

-  # 3. Calculate dependencies and check disjointness
+  # 3. Calculate dependencies and check that the output groups are disjoint
  downstream_outputs_used_masks = []  # List of bool tuples, one per group
  already_used_final_outputs = set()  # Indices of final outputs already claimed
  for outvars_group in partial_flat:
diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py
index 755df2cd8ceb..fac798fe9dc1 100644
--- a/jax/_src/pallas/hlo_interpreter.py
+++ b/jax/_src/pallas/hlo_interpreter.py
@@ -189,7 +189,7 @@ def eval_jaxpr_recursive(
    consts: Consts that ``jaxpr`` closes over.
    *args: Input arguments to the ``jaxpr``.
    recurse_hop_rule: A Jaxpr interpreter to call on sub-jaxprs of
-      higher-order primtives.
+      higher-order primitives.
    propagate_source_info: Whether to propagate source info.
  """
  def read(v: jax_core.Atom) -> Any:
@@ -419,7 +419,7 @@ def pallas_call_hlo_interpret(
    num_iterations = 1

  # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
-  # i:int32 is the interation index
+  # i:int32 is the iteration index
  # loop_idx: tuple[int32] are the program ids for each grid axis
  def cond(carry):
    i, *_ = carry
diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py
index e278168d999a..7a6c18d43bb7 100644
--- a/jax/_src/pallas/mosaic/interpret.py
+++ b/jax/_src/pallas/mosaic/interpret.py
@@ -82,7 +82,7 @@ class InterpretParams:
    Default: False.
  skip_floating_point_ops: If True, operations that produce only floating point
    values will not be interpreted; instead, their results will be
-    replaced with arrays all of `jnp.inf`. Additionaly any floating point
+    replaced with arrays all of `jnp.inf`. Additionally, any floating point
    operands to any operation will be replaced with (arrays of) `jnp.inf`.
    Default: False.
  uninitialized_memory: If "nan", allocated buffers are initialized to contain
@@ -937,7 +937,7 @@ def get(
      raise ValueError(
          'Out-of-bounds read of'
          f' ({device_id} {local_core_id} {memory_space} {buffer_id}):'
-          f' reading [{read_range}] but bufer has shape {buffer.shape} .'
+          f' reading [{read_range}] but buffer has shape {buffer.shape} .'
) if shared_memory.interpret_params.detect_races: @@ -1817,7 +1817,7 @@ def _get_randomized_grid_coordinates( For a dimension with 'parallel' semantics at position `d` in the grid, the returned tuple contains a random permutation of the sequence `[0,..., grid[d] - 1]` at index `d`. For each dimension with 'arbitrary' semantics, - the resulting tuple contains an empty array. (Inserting an empty arry for an + the resulting tuple contains an empty array. (Inserting an empty array for an 'arbitrary' dimension at position `d` in the grid, instead of the sequence `[0,..., grid[d] - 1]`, allows `grid[d]` to be a dynamic value, i.e. a value not known at Jax trace time.) @@ -2059,7 +2059,7 @@ def interpret_pallas_call( output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] for i, bm in enumerate(grid_mapping.block_mappings_output): if i in oi_alias_map: - # Re-use the HBM buffer for the aliased pallas_call input. + # Reuse the HBM buffer for the aliased pallas_call input. output_buffer_ids.append(input_buffer_ids[oi_alias_map[i]]) output_buffer_shapes.append(input_args[oi_alias_map[i]].shape) output_vals.append(input_args[oi_alias_map[i]]) @@ -2230,7 +2230,7 @@ def _body( Args: carry: (iteration_idx, loop_idx, grid_point, prev_start_indices, cur_start_indices). - - iteration_idx: the interation index. + - iteration_idx: the iteration index. - loop_idx: internal indices for looping over the grid. - grid_point: the current positions along all axes of the grid. - prev_start_indices: a rank-1 array that contains the start indices diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index c9cdcbf56f85..af50773eec20 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -663,7 +663,7 @@ def get_barrier_semaphore(): to share a collective_id. However, if in doubt, prefer not sharing collective_ids, as doing so incorrectly can lead to silent data corruption or crashes. - Note that re-using the same collective_id doesn't guarantee that the same + Note that reusing the same collective_id doesn't guarantee that the same semaphore is provided by XLA. """ return get_barrier_semaphore_p.bind() diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index 6a2c557fd55d..8d29f857afb2 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -177,7 +177,7 @@ def sample_block(sampler_fn: SampleFnType, `tile_size` should be chosen such that it is a divisor to all block sizes one needs to be invariant to. The larger the `tile_size`, the more - efficient the sampling process wil be and therefore the best choice is + efficient the sampling process will be and therefore the best choice is typically the greatest common divisor between all possible block sizes. Args: diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 7fb933f5623d..3b28ebdd5d20 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -368,7 +368,7 @@ class RefUnion(GPUMemoryRef): """A sequence of trees of refs that are allowed to reuse the same memory. One should not make assumptions as to how each ref will map to the underlying - memory region, since arbitrary padding may be applied inbetween different + memory region, since arbitrary padding may be applied in between different refs. As such, ref unions are only safe to use when the groups of refs that we @@ -459,7 +459,7 @@ def untransform_transpose( self, perm: tuple[int, ...] 
) -> tuple[tuple[int, ...], state_types.Transform]:
    # The transpose in question is applied to the utiled ref so we
-    # need to translate it by duplicating and offseting the last part.
+    # need to translate it by duplicating and offsetting the last part.
    off = len(perm)
    new_suffix = [i + off for i in perm[-len(self.tiling) :]]
    if set(new_suffix) != set(range(off, off + len(self.tiling))):
@@ -871,7 +871,7 @@ class Barrier:
    barriers can be accessed by indexing into the barrier Ref.
    for_tensor_core: Whether this barrier is used for synchronizing with the
      tensor core. This should be set to True when waiting on Blackwell
-      (TC Gen 5) asynchoronous matmul instructions.
+      (TC Gen 5) asynchronous matmul instructions.
  """
  num_arrivals: int = 1
  num_barriers: int = 1
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 5695da4cc8b1..87bb85cfcd70 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -2377,7 +2377,7 @@ def _run_scoped_lowering_rule(
  if any(should_discharge):
    # We convert consts to args, because we only have ir.Values and
    # not JAX values during lowering. discharge_state() produces JAX
-    # valiues for the aguments but expects them to be provided for the
+    # values for the arguments but expects them to be provided for the
    # consts. We also don't want to wrap the values in refs.
    no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
    should_discharge = [False] * len(consts) + should_discharge
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 616f7e501cd8..f37a003f4401 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -651,7 +651,7 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None:
    case []:
      return None
    case _:
-      raise ValueError("Barrier does not support arbirary transforms")
+      raise ValueError("Barrier does not support arbitrary transforms")


barrier_arrive_p = jax_core.Primitive("barrier_arrive")
@@ -835,7 +835,7 @@ def _commit_group_lowering(ctx: lowering.LoweringRuleContext):


def commit_smem_to_gmem_group() -> None:
-  """Commits all issued but uncommited SMEM->GMEM copies to a group."""
+  """Commits all issued but uncommitted SMEM->GMEM copies to a group."""
  commit_group_p.bind()


diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index b14259556faf..52360b997743 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -304,7 +304,7 @@ def _broadcast_input_output_aliases(

  When we have input/output aliasing, since the output will be mapped, we need
  to make sure to broadcast the input across that dimension if it is not
-  mapped. If the input is mapped, but on a different axis, we tranpose the input
+  mapped. If the input is mapped, but on a different axis, we transpose the input
  to match the output.
  """

@@ -370,7 +370,7 @@ def _batch_with_explicit_loop(
        axis_size=axis_size,
    )

-  # The output arrays are completelly overwritten, so we can just initialize
+  # The output arrays are completely overwritten, so we can just initialize
  # empty arrays. 
initial_state = [
      jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size),
@@ -801,7 +801,7 @@ def index_rewrite_kernel(*indexer_args):
    ragged_axis_dim = per_input_ragged_axis_dim[arg_pos]

    # the problem here seems to be that we are rnning this for all inputs, per input, because they each have an indexer - which means
-    # that the indexer for output isnt getting written - before, it always was
+    # that the indexer for output isn't getting written - before, it always was

    lengths_ref = indexer_args[-1]
    rest_indexer_args = indexer_args[:-1]
@@ -896,7 +896,7 @@ def index_rewrite_kernel(*indexer_args):
    raise NotImplementedError("consts not supported in pallas_call")

  # We need to rewrite the input_output_aliases here, the initial call
-  # to broadcast is done, and we have inseted a new input (lengths), so
+  # to broadcast is done, and we have inserted a new input (lengths), so
  # there's an off-by-one here now.
  new_input_output_aliases = []
  for k, v in input_output_aliases:
@@ -987,7 +987,7 @@ def pallas_call_checkify_oob_grid(error: checkify.Error,
                 for bm in grid_mapping.block_mappings
  ]
  # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
-  # i:int32 is the interation index
+  # i:int32 is the iteration index
  # loop_idx: tuple[int32] are the program ids for each grid axis
  def cond(carry):
    i, *_ = carry
@@ -1144,7 +1144,7 @@ def _ensure_2d_error_shape(arg):
  # for the new error inputs and outputs.
  error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals)
  error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
-  error_origins = tuple(f"errrors[{tree_util.keystr(p)}" for p in error_paths)
+  error_origins = tuple(f"errors[{tree_util.keystr(p)}" for p in error_paths)
  error_block_mappings = map(
      partial(
          pallas_core._convert_block_spec_to_block_mapping,
@@ -1762,7 +1762,7 @@ def in_path_to_input_origin(


# We import the TPU backend at the top level because it defines flags. Note that
-# we can only do that at the bottom of this file, beacuse it also depends on
+# we can only do that at the bottom of this file, because it also depends on
# this module already being initialized.

try:
diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index 5038ac6e5171..95ae15e5bf4e 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -490,7 +490,7 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
    scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
-    # fixes an inconstency with lax.dynamic_slice where if the slice goes out
+    # fixes an inconsistency with lax.dynamic_slice where if the slice goes out
    # of bounds, it will instead move the start_index backwards so the slice
    # will fit in memory.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py
index 424e2b81035f..efdd7bd1e2a1 100644
--- a/jax/_src/profiler.py
+++ b/jax/_src/profiler.py
@@ -402,7 +402,7 @@ def save_device_memory_profile(filename, backend: str | None = None) -> None:


# Allows to run model with profiler given amount of times. After required amount
-# of retries achived client can collect FDO data.
+# of retries is achieved, the client can collect FDO data.
class PGLEProfiler: def __init__(self, retries: int, percentile: int): diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 55961607b252..7b0bb06f044c 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -2189,7 +2189,7 @@ def pascal(n: int, kind: str | None = None) -> Array: JAX implementation of :func:`scipy.linalg.pascal`. - The elements of the Pascal matrix approximate the binomial coefficents. This + The elements of the Pascal matrix approximate the binomial coefficients. This implementation is not exact as JAX does not support exact factorials. Args: diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index f8c2563027f5..d4ca7c2c6147 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -148,7 +148,7 @@ def _fftconvolve_unbatched(in1: Array, in2: Array, mode: str) -> Array: return lax.dynamic_slice(conv, start_indices, out_shape) -# Note: we do not re-use the code from jax.numpy.convolve here, because the handling +# Note: we do not reuse the code from jax.numpy.convolve here, because the handling # of padding differs slightly between the two implementations (particularly for # mode='same'). def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike) -> Array: @@ -1030,16 +1030,16 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: x = x.reshape((flat_batchsize, nframes, nstep_per_segment, step_size)) # For obtaining shifted signals, this routine reinterprets flattened array - # with a shrinked axis. With appropriate truncation/ padding, this operation + # with a shrunken axis. With appropriate truncation/ padding, this operation # pushes the last padded elements of the previous row to the head of the # current row. # See implementation of `overlap_and_add` in Tensorflow for details. x = x.transpose((0, 2, 1, 3)) # x: (B, S, N, T) x = jnp.pad(x, ((0, 0), (0, 0), (0, nframes), (0, 0))) # x: (B, S, N*2, T) - shrinked = x.shape[2] - 1 + shrunken = x.shape[2] - 1 x = x.reshape((flat_batchsize, -1)) - x = x[:, :(nstep_per_segment * shrinked * step_size)] - x = x.reshape((flat_batchsize, nstep_per_segment, shrinked * step_size)) + x = x[:, :(nstep_per_segment * shrunken * step_size)] + x = x.reshape((flat_batchsize, nstep_per_segment, shrunken * step_size)) # Finally, sum shifted segments, and truncate results to the output_size. x = x.sum(axis=1)[:, :output_size] diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 65c457f79cc8..ae93dd793844 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -285,7 +285,7 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32) If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error - for the remainging values along the specified axis. + for the remaining values along the specified axis. >>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x2, nan_policy='omit') diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e3a86e241bf2..7ca1d8e48f9e 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -255,7 +255,7 @@ def shape(self) -> tuple[int | Array, ...]: if not unprocessed: return shape # If there are any unprocessed transforms left, we apply them to the shape - # we've found previuously. + # we've found previously. 
for t in self.transforms[-unprocessed:]:
      shape = t.transform_shape(shape)
    assert shape is not None
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 22e71d7d9ce6..1b95806c37c6 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -381,7 +381,7 @@ def discover_pjrt_plugins() -> None:
  """Discovers plugins in the namespace package `jax_plugins` and import them.

  There are two methods used to discover plugin modules. They are intended
-  to be used together by implementors in order to cover all packaging and
+  to be used together by implementers in order to cover all packaging and
  development cases:

  1. Define a globally unique module under the `jax_plugins` namespace
@@ -964,7 +964,7 @@ def backend_xla_version(platform=None) -> int | None:
  """Returns the XLA version of the backend.

  Returns None if the backend does not use PJRT C API or does not have
-  xla_version in the plugin attributes. This methon can be used to skip features
+  xla_version in the plugin attributes. This method can be used to skip features
  that are not available before certain xla_version if the backend is a plugin
  and uses xla_version.
  """
@@ -975,7 +975,7 @@ def backend_stablehlo_version(platform=None) -> Sequence[int] | None:
  """Returns the StableHLO version of the backend.

  Returns None if the backend does not use PJRT C API or does not have
-  stablehlo_current_version in the plugin attributes. This methon can be used to
+  stablehlo_current_version in the plugin attributes. This method can be used to
  skip features that are not available before certain stablehlo_current_version
  if the backend is a plugin and uses stablehlo_current_version.
  """
diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py
index a8a62d78359f..1f1b96487fab 100644
--- a/jax/experimental/colocated_python/serialization.py
+++ b/jax/experimental/colocated_python/serialization.py
@@ -201,7 +201,7 @@ def _serialize_specs(
  if not hasattr(np.dtypes, "StringDType"):
    raise TypeError(
        "Serializing Colocated Python requires StringDType. Please use"
-        " numpy to 2.0.0 or later, or explicityly provide an output spec"
+        " numpy 2.0.0 or later, or explicitly provide an output spec"
        " function."
    )

diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md
index 06cc5c86a109..ac9829d69006 100644
--- a/jax/experimental/jax2tf/README.md
+++ b/jax/experimental/jax2tf/README.md
@@ -840,7 +840,7 @@ to `export.symbolic_shape` share a scope and can be mixed up in arithmetic
operations. The result would also share the same scope.

-You can re-use scopes:
+You can reuse scopes:

```python
a, = export.symbolic_shape("a,", constraints=("a >= 8",))
diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py
index 644c3324b4e2..70a6dccf8915 100644
--- a/jax/experimental/jax2tf/impl_no_xla.py
+++ b/jax/experimental/jax2tf/impl_no_xla.py
@@ -659,7 +659,7 @@ def tf_pool(inputs, pooling_type):
    raise NotImplementedError(
        f"TODO: use tf.nn.pool with dynamic shapes¨{window_dimensions=} "
        f" {window_strides=} {dilations=}")
-  # tf.nn.pool() currently does not suport tf.int32 and so we cast back and
+  # tf.nn.pool() currently does not support tf.int32 and so we cast back and
  # forth in order to be able to convert. 
if (inputs.dtype in [tf.int16, tf.int32]) and computation_name == "add": original_dtype = inputs.dtype diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 6f604f1195a0..a2ffc6582fff 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -212,7 +212,7 @@ def key_reuse_signature_from_eqn(eqn: core.JaxprEqn) -> KeyReuseSignature: return sig.signature(eqn) else: raise TypeError( - f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + f"Unrecognized key reuse signature of type {type(sig)}: {sig}") else: return unknown_signature(eqn) @@ -231,7 +231,7 @@ def key_reuse_signature_from_primitive(prim, *args, **params): return jaxpr_type_signature(jaxpr) else: raise TypeError( - f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + f"Unrecognized key reuse signature of type {type(sig)}: {sig}") consume_p = core.Primitive("consume") diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 20138bbe6fd4..e9293d9ffe08 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -957,7 +957,7 @@ def _mgpu_wgmma_op_lowering_rule( raise ValueError("Layout mismatch") wgmma_layout = fa_layouts[0] - # TODO(dasenov): Move the value -> accumulator conversion outisde of wgmma. + # TODO(dasenov): Move the value -> accumulator conversion outside of wgmma. # The associated fence could be a little expensive and is not needed if the # result a wgmma feeds into another wgmma (even in another loop step). acc_in = _fragmented_array_from_ir(wgmma_op.accumulator, wgmma_layout) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 78ef1faddc59..280efd513187 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -243,8 +243,8 @@ def kv_loop(kv_step, carry): perform_schedule_barrier() - # This is quite suprising, but it seems like warp shuffles cannot - # run simutaneously with the WGMMA. For that reason we include it as + # This is quite surprising, but it seems like warp shuffles cannot + # run simultaneously with the WGMMA. For that reason we include it as # part of the TensorCore critical section and not the ALU section. with ctx.named_region("Softmax reduction"): l_i += p.reduce(arith.addf, axis=1) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 925aa1575e2d..3554ed95844c 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -669,7 +669,7 @@ def linear_thread_idxs(self): # ... # # You can see that we have taken 2x2 submatrices from the above layout and -# transposed them. The assigment of lanes to elements is such that in both +# transposed them. The assignment of lanes to elements is such that in both # layouts the same two lanes map to a single 2x2 submatrix, making the transpose # very cheap (one shuffle and permute suffices to change between those layouts). WGMMA_TRANSPOSED_LAYOUT = TiledLayout( @@ -1743,7 +1743,7 @@ def reduce( out_reg = vector.splat( ir.VectorType.get((1,), out_reg.type.element_type), scalar_out_reg ) - # Reduce accross warp lanes, if necessary (using warp shuffles). + # Reduce across warp lanes, if necessary (using warp shuffles). 
if any(reduced_dims[d] for d in layout.partitioned_lane_dims): if utils.bitwidth(out_reg.type) > 32: raise NotImplementedError # Need to implement wide shfl_bfly. @@ -1762,7 +1762,7 @@ def reduce( lane_stride *= 2 reduction_size //= 2 assert lane_stride == WARP_SIZE, lane_stride - # Reduce accross warps in the warpgroup, if necessary. + # Reduce across warps in the warpgroup, if necessary. if ( not isinstance(layout.warp_dim, Replicated) and reduced_dims[layout.warp_dim] diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 2a5bb96f4708..852ac90c0d73 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -779,7 +779,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): tuple(slice_shape), swizzle, reduction_op, ) - # We constuct TMA descriptors in column-major order. + # We construct TMA descriptors in column-major order. rev_dyn_base_indices = [ arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) ] diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index d72994f45a87..91797fe65d1f 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -465,7 +465,7 @@ def _tmem_access_helper(shape, num): num_regs *= num if num_regs > 255: raise ValueError( - f"TMEM transation too big : {shape=} and {num=} involve" + f"TMEM translation too big : {shape=} and {num=} involve" f" {num_regs} registers per-thread, which exceeds the limit of 255" ) regs_vector = ",".join(f"${i}" for i in range(num_regs)) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index ac002aa8bffe..b5dbfb62c88f 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -275,7 +275,7 @@ class ThreadSubset(enum.IntEnum): BLOCK = enum.auto() -# True withon `once()` contexts. +# True within `once()` contexts. _ONCE_PER: ThreadSubset | None = None @@ -468,7 +468,7 @@ def fold_until(shape, off , target) -> tuple[int, int]: # TODO(cperivol): Implement dependent fold-unfolds for subsections # of the shape eg (..., 4,5,5, ...) -> (..., 10,10, ...) could be # supported without touching any other dimensions. - raise NotImplementedError(f"Can't reshape {sh0} to {sh1} bu composing independent folds/unfolds.") + raise NotImplementedError(f"Can't reshape {sh0} to {sh1} by composing independent folds/unfolds.") raise AssertionError(f"Unreachable: number of elements don't match in each shape ({sh0} ans {sh1})") diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 2fe826e173e5..9b4fc7678538 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -113,7 +113,7 @@ def wgmma_m64( ): out_ty = ir.VectorType(acc.flat[0].type).element_type if not _supported_wgmma_types(out_ty, element_type): - raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}") + raise ValueError(f"Unsupported wgmma types {(out_ty, element_type)=}") if n % 8: raise ValueError diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 3a83ff16d612..07a0f747443c 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -203,7 +203,7 @@ def should_save(step_id: int) -> bool: after some hosts are preempted. Raises: - RuntimeError: if preemption sync manager has not been inititialized. 
+ RuntimeError: if preemption sync manager has not been initialized. """ if distributed.global_state.client is None: return False @@ -328,7 +328,7 @@ def host_local_array_to_global_array( >>> >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP - Please note ths function requires global mesh to be a continuous mesh, meaning + Please note this function requires global mesh to be a continuous mesh, meaning that devices that belong to each host should form a subcube in this mesh. To move local data to global array with non-continuous mesh use jax.make_array_from_callback or jax.make_array_from_single_device_arrays diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 650668daf67a..d9d62afe93c5 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -524,7 +524,7 @@ def _compute_sT(acc_ref): def _compute(refs): # Combining two WGMMA calls in one block to avoid the unnecessary - # sychronization from two `wgmma.wait_group` calls. + # synchronization from two `wgmma.wait_group` calls. dv_acc_ref, dpT_acc_ref = refs plgpu.wgmma(dv_acc_ref, pT.astype(dtype), do_smem) # dV plgpu.wgmma(dpT_acc_ref, v_smem, plgpu.transpose_ref(do_smem, (1, 0))) # dpT diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/util.py b/jax/experimental/pallas/ops/tpu/paged_attention/util.py index 6d6ceca3733f..92aa3a7a1b2c 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/util.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/util.py @@ -64,7 +64,7 @@ def grouped_query_attention_reference( if debug: jax.debug.print("qk: {qk}", qk=qk) - # Enfore causal mask (adding dimensions when necessary) + # Enforce causal mask (adding dimensions when necessary) mask = jnp.arange(max_seq_len)[None] < seq_lens[:, None] qk += jnp.where(mask, 0.0, MASK_VALUE)[:, None, None, :] if debug: diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index e7bc599b2b2b..3f12448f2a9c 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -638,7 +638,7 @@ def prefetch_next_kv_blk(): v = v.astype(q_ref.dtype) kv_head_idx = kv_head_chunk_idx + step_idx q_head_idx = kv_head_idx * num_q_heads_per_kv_head - # TODO(jevinjiang): extra handlig for packed type that can start at + # TODO(jevinjiang): extra handling for packed type that can start at # unaligned position! q = fold_on_2nd_minor( q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index 3f7a0d863188..354fdb24f9df 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -352,7 +352,7 @@ class ChunkedCausalMask(_ComputableMask): """Lazy chunked causal mask. Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ... tokens - attend to each other but not accross chunks. + attend to each other but not across chunks. Llama4 models use interleaved chunk attention along with global attention. 
@@ -412,7 +412,7 @@ class LocalMask(_ComputableMask): """Lazy local mask, prevents model from attending to tokens outside window. Attributes: - window_size: Size of the two sides of the local window (None identifes no + window_size: Size of the two sides of the local window (None identifies no limit for the given side). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 9c79fbbf7e09..37ef92c2d33d 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -418,7 +418,7 @@ def _process_dynamic_mask( # tensors of this shape: mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count) - # Collect mask_info shards along the head dimension, concatentate (or + # Collect mask_info shards along the head dimension, concatenate (or # broadcast) them after the loop. data_next_per_head_list, mask_next_per_head_list = [], [] for head_shard in range(head_shards): @@ -633,7 +633,7 @@ def assign_unique_ids(objects): ] # TODO(amagni): checking the validity of the masks is slow for large masks. - # Disable it for now, reevalute in the future. + # Disable it for now, reevaluate in the future. partial_mask_block_ids: dict[_HashableNDArray, int] = collections.defaultdict( lambda: len(partial_mask_block_ids) @@ -747,7 +747,7 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int): q_sequence_axis = 1 head_axis = 0 - # Collect mask_info shards along the head dimension, concatentate (or + # Collect mask_info shards along the head dimension, concatenate (or # broadcast) them after the loop. data_next_per_head_list, mask_next_per_head_list = [], [] for head_shard in range(shards_to_process): diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 8f9548bdce1b..027cbd36feea 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -53,7 +53,7 @@ def shard_map( out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves, with a tree structure that is a tree prefix of the output of ``f``. Each ``PartitionSpec`` represents how the corresponding output shards should be - concatenated. In each ``PartitionSpec``, metioning a ``mesh`` axis name at + concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses concatenation of that mesh axis's shards along the corresponding positional axis. Not mentioning a ``mesh`` axis name expresses a promise that the output values are equal along that mesh axis, diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 02bcbcf16dbc..f916cf385a66 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -182,7 +182,7 @@ def _version_check(name: str, scale_for_comparison=100) # TODO(phawkins): for some reason this check fails with a cusolver internal # error when fetching the version. This may be a path error from our stubs. - # Figure out what's happening here and reenable. + # Figure out what's happening here and re-enable. # _version_check("cuSOLVER", cuda_versions.cusolver_get_version, # cuda_versions.cusolver_build_version, # # Ignore patch versions. 
diff --git a/jaxlib/config.cc b/jaxlib/config.cc index 3d701c516990..625a5aa5a319 100644 --- a/jaxlib/config.cc +++ b/jaxlib/config.cc @@ -34,7 +34,7 @@ namespace jax { namespace nb = nanobind; -// Singleton object used to represet "value not set" in thread-local configs. +// Singleton object used to represent "value not set" in thread-local configs. nb::object UnsetObject() { return nb::steal(PyObject_CallObject( reinterpret_cast(&PyBaseObject_Type), nullptr)); @@ -71,7 +71,7 @@ class ThreadLocalConfigState { // These values are accessed in one of two ways: // * The owning thread reads or writes them, while holding the GIL, or, under // free-threading, while the owning thread is in ATTACHED gc state. - // * Other threads may read or clear values while performing a garbarge + // * Other threads may read or clear values while performing a garbage // collection. // No locking is needed because a GC thread cannot run concurrently with other // Python threads; even under free-threading Python uses a stop-the-world GC. @@ -117,7 +117,7 @@ class GlobalConfigState { private: friend class Config; - // The set of thread-local states. This is used during garbarge collection to + // The set of thread-local states. This is used during garbage collection to // visit thread-local values. absl::Mutex mu_; absl::flat_hash_set thread_local_states_ diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 678d92bc434a..00e26756ded9 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -173,7 +173,7 @@ def if_building_jaxlib( }) def _cpu_test_deps(): - """Returns the test depencies needed for a CPU-only JAX test.""" + """Returns the test dependencies needed for a CPU-only JAX test.""" return select({ "//jax:config_build_jaxlib_true": [], "//jax:config_build_jaxlib_false": ["@pypi//jaxlib"], diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 9d6085397493..2f3bfb808981 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -441,7 +441,7 @@ llvm::LogicalResult BroadcastInDimOp::verify() { } if (i > 0 && dims[i] <= dims[i - 1]) { return error( - "The values in the `broadcast_dimensions` attribute must be stricly " + "The values in the `broadcast_dimensions` attribute must be strictly " "increasing."); } } diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 1465f76aa7bf..217bf1a3593b 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -405,7 +405,7 @@ def MosaicGPU_SliceSMEMOp : Op { } def MosaicGPU_WGMMAOp : Op { - let summary = "Multiply two matrices asyncronously using warpgroup level matrix multiply operations."; + let summary = "Multiply two matrices asynchronously using warpgroup level matrix multiply operations."; let description = [{ Schedules WGMMA operations that perform the following matrix multiply and accumulate: @@ -434,7 +434,7 @@ def MosaicGPU_WGMMAOp : Op { registers need to be synchronized with a memory fence. Usually `a` is read from shared memory if it is used directly in the WGMMA - operation. If `a` needs to be transfromed before it is used in the WGMMA + operation. If `a` needs to be transformed before it is used in the WGMMA operation, it may be more convenient to read it directly form registers. This avoids the need to store the data and wait for a fence. 
}]; diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 8261d09697e3..dceee9cf41a8 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -168,7 +168,7 @@ class RectangularVregBounds : public VRegDataBounds { // --- // // The tiling attribute makes it possible to subdivide a single vector register -// into multiple subtiles that traverse the last dimension of a value. For +// into multiple sub-tiles that traverse the last dimension of a value. For // example, consider vregs of shape (4, 5) on (2, 10) array: // // a b c d e f g h i j diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 1669d1bf1586..390ce9d3db32 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3292,7 +3292,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, // a bunch of loads! } else { return op.emitOpError( - "Not implemented: dismatch in memref tiling and vector tiling in " + "Not implemented: mismatch in memref tiling and vector tiling in " "load"); } } @@ -4772,7 +4772,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, // us a bunch of stores! } else { return op.emitOpError( - "Not implemented: dismatch in memref tiling and vector tiling in " + "Not implemented: mismatch in memref tiling and vector tiling in " "store"); } } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 17575183bd81..7d279c5cb307 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -469,7 +469,7 @@ class VectorLayoutInferer { TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(), "scf if results and else branch yield operands do not match"); auto else_yield_in_layouts = getLayoutFromOperands(else_yield); - // Find a compatible layout from then and else branches for each reuslt. For + // Find a compatible layout from then and else branches for each result. For // example, if we yield offset (*, *) in then branch and offset (*, 0) in // else branch, the result offset should be (*, 0). SmallVector out_layouts; @@ -649,7 +649,7 @@ class VectorLayoutInferer { auto yield_in_layouts = getLayoutFromOperands(yield_op); // Find a compatible layout from condition body and loop body for each - // reuslt. For example, if we yield offset (*, *) in condition body and + // result. For example, if we yield offset (*, *) in condition body and // offset (*, 0) in loop body, the result offset should be (*, 0). SmallVector out_layouts; out_layouts.reserve(op->getNumResults()); diff --git a/jaxlib/pjit.cc b/jaxlib/pjit.cc index 804352161597..1e9d53547b1d 100644 --- a/jaxlib/pjit.cc +++ b/jaxlib/pjit.cc @@ -653,7 +653,7 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, // development. // // TODO(chky): Consider support uncommitted PyArray in cpp when the python - // side stablizes. + // side stabilizes. 
if (!py_array.committed() && jax::Sharding::SafeNumDevices(py_array.sharding()) > 1) { VLOG(2) << "PyArray argument is not committed and number of global " diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h index 520fbf8b1e59..772dba864684 100644 --- a/jaxlib/py_client.h +++ b/jaxlib/py_client.h @@ -92,7 +92,7 @@ class PyClient { return pjrt_client->shared_ptr_pjrt_client(); } - // Legacy alises. + // Legacy aliases. std::shared_ptr shared_pjrt_client() { return shared_ptr_pjrt_client(); } diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc index 1b9c8c43b126..57de57b26aee 100644 --- a/jaxlib/xla_compiler.cc +++ b/jaxlib/xla_compiler.cc @@ -196,7 +196,7 @@ absl::StatusOr MakeShapeWithDenseLayout( // `subgroup_types`: indicates the subgroups of the last `subgroup_types.size()` // dimensions in `dims`. // -// In practice, `reshape_dims` often maps to the axises of user defined device +// In practice, `reshape_dims` often maps to the axes of user defined device // mesh, and `transpose_perm` often maps to the user specification of how a // tensor is partitioned based on the axes defined in the mesh, e.g. for a mesh // of size 4x2x2 as AxBxC: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index e7ae4d0468fd..050ac5314da3 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -1228,7 +1228,7 @@ def while_body(s): with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.])) - # TODO(lenamartens): reenable assertions below. + # TODO(lenamartens): re-enable assertions below. # self.assertIsNotNone(err.get()) # self.assertStartsWith(err.get(), "division by zero") @@ -1257,7 +1257,7 @@ def fun(x): with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): checked_f(jnp.arange(5)) - # TODO(lenamartens): reenable assertions below. + # TODO(lenamartens): re-enable assertions below. # self.assertIsNone(err.get()) def test_assert_cond_no_data_dependence(self): diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index b5e875c03676..5d0f747a85ad 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -131,7 +131,7 @@ def _check_tracers_and_jaxprs(self, traceable: Any, mode. The debug infos in the nested Jaxprs are first converted to strings using `_debug_info_to_string` and then compared against `expected_jaxpr_debug_infos`. During this conversion, - we strip occurences of this test file name and a line number + we strip occurrences of this test file name and a line number (e.g., .*/debug_info_test.py:56) An element of `expected_jaxpr_debug_infos` can be a string, in which case it is compared by equality, or a `re.Pattern` (the result of `re.compile`) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index a7eeb4dbf86b..e20017a39a9b 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -35,7 +35,7 @@ # TODO: AOT tests fails with the tracer leak checker. -# Reenable once https://github.com/jax-ml/jax/issues/27315 is fixed. +# Re-enable once https://github.com/jax-ml/jax/issues/27315 is fixed. 
# @jtu.with_config(jax_check_tracer_leaks=True) class ErrorCheckTests(jtu.JaxTestCase): diff --git a/tests/export_test.py b/tests/export_test.py index 0dfebdcec054..2be5313b0b3a 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1548,7 +1548,7 @@ def test_multi_platform(self): self.assertIn("jax.uses_shape_polymorphism = true", module_str) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) res_exp = exp.call(x_device) @@ -1573,7 +1573,7 @@ def test_multi_platform_nested(self): count_sine = len(re.findall("stablehlo.sine", exp2_module_str)) self.assertEqual(1, count_sine) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) @@ -1716,7 +1716,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] res_native = f_jax(a) exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: run_devices = jax.devices(platform)[0:len(export_devices)] if len(run_devices) != len(export_devices): diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index cdfeeba6275b..67c19179bb8b 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -218,7 +218,7 @@ def check_all_close(xs, ys, tol=1e-3): def check_close(x, y, tol=1e-3): assert jnp.shape(x) == jnp.shape(y) - # TODO(dougalm): re-enable once we've tackled the less pendantic bugs + # TODO(dougalm): re-enable once we've tackled the less pedantic bugs # assert x.dtype == y.dtype assert jnp.allclose(x, y, rtol=tol, atol=tol), \ f"Value mismatch:\n{x}\n vs\n{y}\n" diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py index 87f60dd86f20..bada2bebc74e 100644 --- a/tests/gpu_memory_flags_test.py +++ b/tests/gpu_memory_flags_test.py @@ -29,7 +29,7 @@ class GpuMemoryAllocationTest(absltest.TestCase): @jtu.skip_under_pytest("Test must run in an isolated process") @unittest.skipIf( "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ, - "Test does not work if the python client allocator has been overriden", + "Test does not work if the python client allocator has been overridden", ) def test_gpu_memory_allocation(self): falsey_values = ("0", "False", "false") diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index e44ff9ebc930..ecbf908d2f09 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -3867,7 +3867,7 @@ def testItem(self, shape, dtype, num_args, use_tuple): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) @jtu.sample_product( - # Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs. + # Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs. shape=[(0,), (32,), (2, 16)], a_dtype=all_dtypes, dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e98ac6986cae..09a081761d0e 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6214,7 +6214,7 @@ def test_isdtype(self, dtype, kind): ], dtype=float_dtypes + int_dtypes, ) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and re-enable this test. 
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def test_trapezoid(self, yshape, xshape, dtype, dx, axis): rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 905d7eed1acd..f2155afb841d 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -543,7 +543,7 @@ def test_binary_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): if (jnp_fun.nin, jnp_fun.nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") if name in ['add', 'multiply'] and dtype == bool: - # TODO(jakevdp): figure out how to fix thest cases. + # TODO(jakevdp): figure out how to fix test cases. self.skipTest(f"known failure for {name}.reduceat with {dtype=}") rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index e0d3528dfa41..610bf5fabefd 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -646,7 +646,7 @@ def test_spence(self, shape, dtype): ], dtype=float_dtypes + int_dtypes, ) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and re-enable this test. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis): rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_test.py b/tests/lax_test.py index a11c989fc9c5..4d30cd70ea70 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4395,7 +4395,7 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): # # In addition, the 1/3 middle parts of regions q1, q2, q3, q4, # neg, pos are tested separately as these don't contain extremely - # small or extremelly large values and functions on these regions + # small or extremely large values and functions on these regions # ought not to possess any incorrectness issues. 
s0, s1 = size_re, size_im diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index af8b296fe536..5fc6bc8c703c 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -711,7 +711,7 @@ def test_broadcast_in_dim_dim_transpose(self): with self.assertRaisesRegex( ir.MLIRError, - r"`broadcast_dimensions` attribute must be stricly increasing", + r"`broadcast_dimensions` attribute must be strictly increasing", ): self.module.operation.verify() diff --git a/tests/pallas/mgpu_collective_matmul_test.py b/tests/pallas/mgpu_collective_matmul_test.py index 3760c7ccddb7..b1b7e0ffd118 100644 --- a/tests/pallas/mgpu_collective_matmul_test.py +++ b/tests/pallas/mgpu_collective_matmul_test.py @@ -90,7 +90,7 @@ def test_all_gather_lhs_matmul( if n_shard != block_n: self.skipTest("n_shard must be equal to block_n for now.") if n_shard % block_n: - self.skipTest("n_shard must be divisble by block_n for now.") + self.skipTest("n_shard must be divisible by block_n for now.") if m_shard % block_m: self.skipTest("m_shard must be divisible by block_m for now.") diff --git a/tests/pallas/tpu_fusible_matmul_test.py b/tests/pallas/tpu_fusible_matmul_test.py index 4bde9b95483b..ae56d3db2f3a 100644 --- a/tests/pallas/tpu_fusible_matmul_test.py +++ b/tests/pallas/tpu_fusible_matmul_test.py @@ -416,7 +416,7 @@ def matmul_slice_ref(x, y, b, i, j, k): @parameterized.parameters('float32', 'bfloat16') def test_matmul_input_concat_output(self, dtype): - self.skipTest('select_n doesnt support more than 3 elements') + self.skipTest('select_n does not support more than 3 elements') # TODO(sharadmv): fix this test k0, k1, k2, k3 = jax.random.split(jax.random.key(0), 4) x = jax.random.normal(k0, (128, 128), dtype) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7087bcad58bf..e03e5127d023 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -157,7 +157,7 @@ def f(x): with config.pgle_profiling_runs(2), config.enable_pgle(True): # Run 1: Module should be compiled without FDO. 
Two modules are expected - # One is the funtion f, the other one is multi slice module + # One is the function f, the other one is multi slice module with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index d2483966c984..fb5a7560d947 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -280,9 +280,9 @@ def setUp(self): self.skipTest(str(e)) return if _dtypes.float8_e8m0fnu is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu") + self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") if _dtypes.float4_e2m1fn is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float4_e2m1fn") + self.skipTest("Requires >= ml_dtypes 0.5.0 to support float4_e2m1fn") if cudnn_version < 90700: self.skipTest("Requires >= cuDNN 9.7.0") if not jtu.is_cuda_compute_capability_at_least("10.0"): @@ -473,7 +473,7 @@ def setUp(self): self.skipTest(str(e)) return if _dtypes.float8_e8m0fnu is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu") + self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") if cudnn_version < 90700: self.skipTest("Requires >= cuDNN 9.7.0") if not jtu.is_cuda_compute_capability_at_least("10.0"): diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 6d1ffe744ed9..68aaf4e29553 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -961,7 +961,7 @@ def test_constraints_ge_complex_gen(self, self.assertEqual(bounds, _bounds(exp)) def test_constraints_ge_override(self): - # Some constaints override other + # Some constraints override other a, b = shape_poly.symbolic_shape("a, b", constraints=("a >= 5", "b <= 16", "a >= 10", "b <= 10")) @@ -979,7 +979,7 @@ def test_constraint_eq_0(self): self.assertIs(d, 5) def test_constraints_eq_1(self): - # Some constaints override other + # Some constraints override other a, b, c = shape_poly.symbolic_shape("a, b, c", constraints=("max(a, b) == c",)) self.assertEqual(_bounds(core.max_dim(a, b) - c + 3), (3, 3)) diff --git a/third_party/repo.bzl b/third_party/repo.bzl index 17e0bbb03542..185c5a4294dc 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -129,7 +129,7 @@ def tf_http_archive(name, sha256, urls, **kwargs): "storage.googleapis.com", )]): fail("The first entry of tf_http_archive(urls) must be a mirror " + - "URL, preferrably mirror.tensorflow.org. Even if you don't have " + + "URL, preferably mirror.tensorflow.org. 
Even if you don't have " + "permission to mirror the file, please put the correctly " + "formatted mirror URL there anyway, because someone will come " + "along shortly thereafter and mirror the file.") From 8d8cc2bca67fc75718b73337c9ce19d6b77065e9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 4 Jun 2025 08:08:49 -0700 Subject: [PATCH 1524/1769] Reverts 6cd196a5db22b8db0ed4000e4cf67ad748bf52f3 PiperOrigin-RevId: 767149635 --- jax/_src/dispatch.py | 57 +++++++++----------------------------------- 1 file changed, 11 insertions(+), 46 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b9ef8f49f801..b5e588cbc10e 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -356,6 +356,16 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(x) + if inp_sharding.device_set != target_sharding.device_set: + inp_ids = [d.id for d in inp_sharding._device_assignment] + inp_plat = inp_sharding._device_assignment[0].platform.upper() + target_ids = [d.id for d in target_sharding._device_assignment] + target_plat = target_sharding._device_assignment[0].platform.upper() + raise ValueError("Input and target sharding should have the same set of " + f"devices. Got input's device set ids: {inp_ids} on " + f"platform {inp_plat} and target sharding's device set " + f"ids: {target_ids} on platform {target_plat}") + if inp_sharding.is_fully_replicated: permute_order = None else: @@ -379,25 +389,6 @@ def _reorder_shards(x, new_s, copy_semantics: CopySemantics): return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore -@util.cache() -def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding): - """Returns True if src->dst is a supported cross-host transfer.""" - backend = xla_bridge.get_backend() - # There is experimental support for cross-host device transfers on TFRT TPU - # backends only. - if (xla_bridge.process_count() == 1 or backend.platform != "tpu" or - "TFRT TPU" not in backend.platform_version): - return False - if (src_sharding._to_xla_hlo_sharding(ndim) != - dst_sharding._to_xla_hlo_sharding(ndim)): - return False - # This check excludes the case where the source and destination shardings - # have the same process index sets but there are shards that require - # cross-host transfers. This case is supportable but expensive to check for. - return (src_sharding._internal_device_list.process_indices != - dst_sharding._internal_device_list.process_indices) - - @dataclasses.dataclass(frozen=True) class _DeferredShardArg: """Deferred call to `pxla.shard_args`. @@ -428,8 +419,7 @@ def _device_put_sharding_impl(x, aval, device, copy): return x if (not s.is_fully_addressable and - isinstance(x, array.ArrayImpl) and not x.is_fully_addressable and - s.device_set == x.sharding.device_set): + isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) @@ -440,32 +430,7 @@ def _device_put_sharding_impl(x, aval, device, copy): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) - # There is experimental support for cross-host device transfers on TFRT TPU. 
- if (isinstance(x, array.ArrayImpl) and x._committed - and _is_supported_cross_host_transfer(x.ndim, x.sharding, s)): - return xc.batched_copy_array_to_devices_with_sharding( - [x], [s._internal_device_list], [s], # pytype: disable=attribute-error - pxla.to_xc_copy_semantics([copy]))[0] - if not s.is_fully_addressable: - # If both the source and target shardings are not fully addressable and - # one of the above conditions has not been met, then assume that the user - # is attempting a different device order reshard. - if (isinstance(x, array.ArrayImpl) and not x.is_fully_addressable - and s.device_set != x.sharding.device_set): - inp_ids = [d.id for d in x.sharding._device_assignment] - inp_plat = x.sharding._device_assignment[0].platform.upper() - target_ids = [d.id for d in s._device_assignment] - target_plat = s._device_assignment[0].platform.upper() - raise ValueError( - "For a cross-host reshard in multi-controller JAX, input and target" - " sharding should have the same set of devices. Got input's device" - f" set ids: {inp_ids} on platform {inp_plat} and target sharding's" - f" device set ids: {target_ids} on platform {target_plat}.\n\n" - "There is experimental support for cross-host transfers with " - "different device sets, when input/output shardings have the same " - "indices and layouts, in the TFRT TPU runtime only.") - if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types or type(x) in dtypes.python_scalar_dtypes): # If all hosts participate in the sharding, assert that the input is the From e19e18d4d7a873661c48f774f4e282505264fed7 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Wed, 4 Jun 2025 08:59:06 -0700 Subject: [PATCH 1525/1769] Add not-implemented sharding rule in `third_party/py/jax/_src/cudnn/fused_attention_stablehlo.py`. PiperOrigin-RevId: 767166345 --- jax/_src/cudnn/fused_attention_stablehlo.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 46df84e08e0f..e15d9ddbce09 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -1078,16 +1078,21 @@ def sharded_impl(*args): _dot_product_attention_bwd_p_wrapper ] = _dot_product_attention_bwd_batcher +def not_implemented_sharding_rule(*args, **kwargs): + return NotImplementedError("Sharding rule not implemented.") + _dot_product_attention_fwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands, - partition=_dot_product_attention_fwd_partition) + partition=_dot_product_attention_fwd_partition, + sharding_rule=not_implemented_sharding_rule) mlir.register_lowering(_dot_product_attention_fwd_p_wrapper, mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True)) _dot_product_attention_bwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands, - partition=_dot_product_attention_bwd_partition) + partition=_dot_product_attention_bwd_partition, + sharding_rule=not_implemented_sharding_rule) mlir.register_lowering(_dot_product_attention_bwd_p_wrapper, mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True)) From b3db37426a967936c7ec95f66ac6ff018ae23bdf Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 4 Jun 2025 09:33:56 -0700 Subject: [PATCH 1526/1769] Make `unreduced` argument in `PartitionSpec` a `set | frozenset` instead of a `tuple`. 
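Call sites move from a tuple to a set. A minimal sketch of the new
convention (hypothetical values, mirroring the test updates below):

    from jax.sharding import PartitionSpec as P

    # before: P('x', unreduced=('y',))
    spec = P('x', unreduced={'y'})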
This is because the order of axes in `unreduced` doesn't matter. While lowering, `unreduced` is sorted wrt the mesh axis names so in McJAX all hosts lower to the same thing. PiperOrigin-RevId: 767178835 --- jax/_src/core.py | 4 +-- jax/_src/lax/lax.py | 2 +- jax/_src/named_sharding.py | 15 +----------- jax/_src/partition_spec.py | 50 ++++++++++++++++++-------------------- tests/array_test.py | 46 ++++++++++++++++------------------- tests/pjit_test.py | 22 ++++++++--------- 6 files changed, 60 insertions(+), 79 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index d8481afe872a..91ddd77e3d49 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1924,8 +1924,8 @@ def modify_spec_for_auto_manual(spec, mesh) -> P: temp_s = s[0] if isinstance(s, tuple) else s new_spec.append(s if mesh._name_to_type[temp_s] == AxisType.Explicit else None) - new_unreduced = tuple(u for u in spec.unreduced - if mesh._name_to_type[u] == AxisType.Explicit) + new_unreduced = {u for u in spec.unreduced + if mesh._name_to_type[u] == AxisType.Explicit} return P(*new_spec, unreduced=new_unreduced) def _maybe_modify_sharding(sharding, ndim): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 43dffc7bef9c..389960a7c94d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5228,7 +5228,7 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, ' out_sharding provided to dot_general mentions unreduced_axes.' f' Got {out_sharding=}, {lhs_contracting_spec=},' f' {rhs_contracting_spec=}') - if out_sharding.spec.unreduced != lhs_contracting_spec: + if out_sharding.spec.unreduced != frozenset(lhs_contracting_spec): raise core.ShardingTypeError( "out_sharding's unreduced axes should be equal to the contracting" f' specs. Got unreduced axes={out_sharding.spec.unreduced} and' diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index faf0b2a9f2b2..13124d4a36aa 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -302,7 +302,7 @@ class SdyArray: dim_shardings: Sequence[SdyDim] logical_device_ids: tuple[int, ...] | None = None replicated_axes: tuple[str, ...] = () - unreduced_axes: tuple[str, ...] = () + unreduced_axes: frozenset[str] = frozenset() def build(self) -> sdy.TensorShardingAttr: if self.mesh_shape is None: @@ -503,24 +503,11 @@ def _check_mesh_resource_axis(mesh, pspec): f' axis_types: {mesh._axis_types_dict}') def _check_mesh_unreduced(mesh, pspec): - counts = {} - duplicate = False for u in pspec.unreduced: if u not in mesh.axis_names: raise ValueError( f'Unreduced axes {u} is not found in {mesh.axis_names=}. ' f'Got {pspec=}') - count = counts.get(u, 0) - if count > 0: - duplicate = True - counts[u] = count + 1 - if duplicate: - multiple_uses = [r for r, c in counts.items() if c > 1] - raise ValueError( - f'Unreduced axes in {pspec} has duplicate entries which is not allowed.' - f' Got {mesh_lib.show_axes(multiple_uses)}') - - for u in pspec.unreduced: if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual): raise ValueError( 'Unreduced axes can only refer to mesh axes that is of type' diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 2c833e6544e4..629af80ed38b 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -13,6 +13,7 @@ # limitations under the License. 
from __future__ import annotations +from collections.abc import Set from typing import Any class UnconstrainedSingleton: @@ -44,11 +45,10 @@ def _canonicalize_partition(partition): return partition def _check(partitions, unreduced): - us = set(unreduced) for p in partitions: p = p if isinstance(p, tuple) else (p,) for r in p: - if r in us: + if r in unreduced: raise ValueError( "partitions cannot overlap with unreduced axes passed to" f" PartitionSpec. Got partitions: {partitions} and unreduced axes:" @@ -58,7 +58,7 @@ def _check(partitions, unreduced): "unreduced cannot contain None. All elements in unreduced should refer" " to the mesh axes.") -def unpicke_pspec(partitions, unreduced): +def unpickle_pspec(partitions, unreduced): return PartitionSpec(*partitions, unreduced=unreduced) AxisName = Any @@ -72,34 +72,32 @@ class PartitionSpec: This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ - __slots__ = ("_partitions", "_unreduced") + __slots__ = ("_partitions", "unreduced") __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. UNCONSTRAINED = _UNCONSTRAINED_PARTITION def __init__(self, *partitions, - unreduced: tuple[AxisName, ...] | AxisName | None = None): + unreduced: Set[AxisName] | None = None): self._partitions = tuple(_canonicalize_partition(p) for p in partitions) - self._unreduced = ( - () if unreduced is None else tuple(unreduced) - if isinstance(unreduced, (list, tuple)) else (unreduced,)) - _check(self._partitions, self._unreduced) - - @property - def unreduced(self): - return self._unreduced + if unreduced is not None and not isinstance(unreduced, (set, frozenset)): + raise TypeError( + "`unreduced` argument of PartitionSpec should be `None` or of type" + f" `frozenset` or `set`. 
Got type {type(unreduced)}") + self.unreduced = frozenset() if unreduced is None else frozenset(unreduced) + _check(self._partitions, self.unreduced) def __repr__(self): pr = repr(self._partitions)[1:-1] - if not self._unreduced: + if not self.unreduced: return f"PartitionSpec({pr})" - ur_str = f"unreduced={self._unreduced!r}" + ur_str = f"unreduced={set(self.unreduced)!r}" pr = '' if not pr else f"{pr} " if pr.endswith(',') else f"{pr}, " return (f"PartitionSpec({pr}{ur_str})") def __reduce__(self): - return (unpicke_pspec, (self._partitions, self._unreduced)) + return (unpickle_pspec, (self._partitions, self.unreduced)) def __getitem__(self, i): return self._partitions[i] @@ -113,9 +111,9 @@ def __len__(self): def __eq__(self, other): if isinstance(other, PartitionSpec): return (self._partitions == other._partitions and - self._unreduced == other._unreduced) + self.unreduced == other.unreduced) elif isinstance(other, tuple): - if self._unreduced: + if self.unreduced: raise TypeError( f"other {other} cannot be of instance `tuple` when self {self} has" " unreduced in `__eq__` of PartitionSpec.") @@ -125,27 +123,27 @@ def __eq__(self, other): return False def __hash__(self): - return hash((self._partitions, self._unreduced)) + return hash((self._partitions, self.unreduced)) def __add__(self, other): - if not isinstance(other, (tuple, PartitionSpec)): - raise NotImplementedError if isinstance(other, PartitionSpec): return PartitionSpec( *self, *other, - unreduced=(*self._unreduced, *other._unreduced)) - else: - if self._unreduced: + unreduced={*self.unreduced, *other.unreduced}) + elif isinstance(other, tuple): + if self.unreduced: raise TypeError( f"other {other} cannot be of instance `tuple` when self {self} has" " unreduced in `__add__` of PartitionSpec.") return PartitionSpec(*self, *other) + else: + raise NotImplementedError def __radd__(self, other): if not isinstance(other, tuple): raise NotImplementedError # other will always be a tuple. 
- if self._unreduced: + if self.unreduced: raise TypeError( f"other {other} cannot be of instance `tuple` when self {self} has" " unreduced in `__radd__` of PartitionSpec.") @@ -158,7 +156,7 @@ def count(self, value): return self._partitions.count(_canonicalize_partition(value)) def with_partitions(self, new_partitions): - return PartitionSpec(*new_partitions, unreduced=self._unreduced) + return PartitionSpec(*new_partitions, unreduced=self.unreduced) def with_unreduced(self, new_unreduced): return PartitionSpec(*self._partitions, unreduced=new_unreduced) diff --git a/tests/array_test.py b/tests/array_test.py index 44734e64a995..4403755a1dff 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1412,67 +1412,63 @@ def test_memory_kind_with_abstract_mesh(self): NamedSharding(abstract_mesh, P(), memory_kind='weird_device') def test_pspec_unreduced(self): - pspec1 = P('a', 'b', None, unreduced=('c',)) + pspec1 = P('a', 'b', None, unreduced={'c'}) self.assertEqual(repr(pspec1), - "PartitionSpec('a', 'b', None, unreduced=('c',))") + "PartitionSpec('a', 'b', None, unreduced={'c'})") - pspec2 = P('a', 'b', None, unreduced=('c',)) + pspec2 = P('a', 'b', None, unreduced={'c'}) self.assertEqual(pspec1, pspec2) - pspec3 = P('a', 'b', None, unreduced=('d',)) + pspec3 = P('a', 'b', None, unreduced={'d'}) self.assertNotEqual(pspec1, pspec3) - out = P('x', unreduced=('z',)) + P('a', unreduced='b') - self.assertEqual(out, P('x', 'a', unreduced=('z', 'b'))) + out = P('x', unreduced={'z'}) + P('a', unreduced={'b'}) + self.assertEqual(out, P('x', 'a', unreduced={'z', 'b'})) - pspec4 = P('x', unreduced='y') + pspec4 = P('x', unreduced={'y'}) self.assertEqual(repr(pspec4), - "PartitionSpec('x', unreduced=('y',))") + "PartitionSpec('x', unreduced={'y'})") - pspec5 = P(None, None, unreduced='x') + pspec5 = P(None, None, unreduced={'x'}) self.assertEqual(repr(pspec5), - "PartitionSpec(None, None, unreduced=('x',))") + "PartitionSpec(None, None, unreduced={'x'})") - pspec6 = P(None, unreduced='x') - self.assertEqual(repr(pspec6), "PartitionSpec(None, unreduced=('x',))") + pspec6 = P(None, unreduced={'x'}) + self.assertEqual(repr(pspec6), "PartitionSpec(None, unreduced={'x'})") - pspec7 = P(unreduced='x') - self.assertEqual(repr(pspec7), "PartitionSpec(unreduced=('x',))") + pspec7 = P(unreduced={'x'}) + self.assertEqual(repr(pspec7), "PartitionSpec(unreduced={'x'})") with self.assertRaisesRegex( TypeError, 'unreduced in `__add__` of PartitionSpec'): - P('x', unreduced=('z',)) + (None,) * 2 + P('x', unreduced={'z'}) + (None,) * 2 with self.assertRaisesRegex( TypeError, "unreduced in `__radd__` of PartitionSpec"): - (None,) * 2 + P('x', unreduced='y') + (None,) * 2 + P('x', unreduced={'y'}) with self.assertRaisesRegex( ValueError, "partitions cannot overlap with unreduced"): - P('x', 'y', unreduced='x') + P('x', 'y', unreduced={'x'}) with self.assertRaisesRegex( ValueError, "partitions cannot overlap with unreduced"): - P('x', None, 'y', unreduced=('z', 'y')) + P('x', None, 'y', unreduced={'z', 'y'}) def test_named_sharding_unreduced_error(self): mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z')) with self.assertRaisesRegex( ValueError, "Unreduced axes.*not found in mesh.*"): - NamedSharding(mesh, P('x', unreduced='a')) - - with self.assertRaisesRegex( - ValueError, "Unreduced.*has duplicate entries"): - NamedSharding(mesh, P('x', unreduced=('y', 'y'))) + NamedSharding(mesh, P('x', unreduced={'a'})) with self.assertRaisesRegex( ValueError, "Unreduced axes can only refer to mesh axes.*Explicit"): - 
NamedSharding(mesh, P('x', unreduced=('y', 'z')))
+      NamedSharding(mesh, P('x', unreduced={'y', 'z'}))
 
     with self.assertRaisesRegex(
         ValueError, "unreduced cannot contain None.*"):
-      NamedSharding(mesh, P('x', unreduced=('y', None)))
+      NamedSharding(mesh, P('x', unreduced={'y', None}))
 
   def test_hlo_sharding_get_axis_sizes(self):
     if jaxlib_extension_version < 343:
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 04df138e68e2..b08d462b3485 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -7781,14 +7781,14 @@ def test_unreduced_basic(self, mesh):
 
     @jax.jit
     def f(x, y, a, b):
-      m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))
-      self.assertEqual(m1.aval.sharding.spec, P('x', None, unreduced='y'))
+      m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'}))
+      self.assertEqual(m1.aval.sharding.spec, P('x', None, unreduced={'y'}))
 
-      m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced='y'))
-      self.assertEqual(m2.aval.sharding.spec, P('x', None, unreduced='y'))
+      m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced={'y'}))
+      self.assertEqual(m2.aval.sharding.spec, P('x', None, unreduced={'y'}))
 
       s = m1 + m2  # unreduced
-      self.assertEqual(s.aval.sharding.spec, P('x', None, unreduced='y'))
+      self.assertEqual(s.aval.sharding.spec, P('x', None, unreduced={'y'}))
 
       out = reshard(s, P('x'))  # reduce
       self.assertEqual(out.aval.sharding.spec, P('x', None))
@@ -7808,7 +7808,7 @@ def test_dot_general_unreduced_error(self, mesh):
 
     @jax.jit
     def f(x, y):
-      return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='z'))
+      return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'z'}))
     with self.assertRaisesRegex(
         core.ShardingTypeError,
         "unreduced axes should be equal to the contracting specs"):
@@ -7819,7 +7819,7 @@ def f(x, y):
     y = jax.device_put(np_inp.T, P(None, None))
     @jax.jit
     def g(x, y):
-      return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))
+      return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'}))
     with self.assertRaisesRegex(
         core.ShardingTypeError,
         "lhs and rhs contracting dims should be sharded identically"):
@@ -7831,7 +7831,7 @@ def g(x, y):
 
     @jax.jit
     def h(x, y):
-      return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))
+      return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'}))
     with self.assertRaisesRegex(
         core.ShardingTypeError,
         "unreduced axes should be equal to the contracting specs"):
@@ -7847,8 +7847,8 @@ def test_add_unreduced_error(self, mesh):
 
     @jax.jit
     def f(x, y, a, b):
-      m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))
-      m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced='z'))
+      m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'}))
+      m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced={'z'}))
       return m1 + m2
 
     with self.assertRaisesRegex(
@@ -7858,7 +7858,7 @@
 
     @jax.jit
    def g(x, y):
-      m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))
+      m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'}))
       m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x'))
       return m1 + m2
 
From 8c348652d687429a2f2de30cca82bf905f57e622 Mon Sep 17 00:00:00 2001
From: Gleb Pobudzey
Date: Wed, 4 Jun 2025 09:52:40 -0700
Subject: [PATCH 1527/1769] [Mosaic GPU] Add a test for TMA multicasts in
 pallas.

This also effectively tests `lax.axis_index` and `lax.axis_size` on
cluster axes.
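A condensed sketch of the collective copy the test exercises (names taken
from the test added below):

    plgpu.copy_gmem_to_smem(
        x_gmem.at[noncollective_idx], smem, tma_barrier,
        collective_axes=collective_dims)  # one TMA load, multicast to the cluster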
PiperOrigin-RevId: 767185068 --- tests/pallas/mosaic_gpu_test.py | 87 +++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6cedbc6ae14c..1170b2ac5cdb 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -15,6 +15,7 @@ import contextlib import dataclasses import functools +import itertools import math import operator import os @@ -1787,6 +1788,92 @@ def body(x_ref, y_ref, barrier): else: self.fail("Should have raised an exception") + @parameterized.named_parameters( + ( + f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", + collective_dims, + noncollective_dims, + collective_size, + ) + for collective_dims in itertools.chain.from_iterable( + itertools.combinations("xyz", n) for n in range(1, 4) + ) + for noncollective_dims in itertools.chain.from_iterable( + itertools.combinations("xyz", n) for n in range(3) + ) + for collective_size in (1, 2, 4) + if all(d not in noncollective_dims for d in collective_dims) + ) + def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size): + """ + 1. Broadcast a GMEM slice to SMEM across collective CTAs. + 2. Send a SMEM slice from each collective CTA to reconstruct the GMEM slice. + It's not strictly necessary to use every collective CTA, but we use them + to test that the cluster axes are used correctly. + """ + + dtype = jnp.float16 + cluster = [1, 1, 1] + for d in collective_dims: + cluster["xyz".index(d)] = collective_dim_size + for d in noncollective_dims: + cluster["xyz".index(d)] = 2 + if math.prod(cluster) > 16: + self.skipTest("Cluster is too big.") + + collective_size = math.prod(cluster["xyz".index(d)] for d in collective_dims) + noncollective_size = math.prod(cluster) // collective_size + + swizzle = 128 + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + shape = (noncollective_size, collective_size * 8, swizzle_elems) + + def body(x_gmem, out_gmem, smem, tma_barrier): + # Compute the index in a subset of the cluster. 
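+      # Mixed-radix flattening: for sorted axes (a, b) with sizes (Sa, Sb),
+      # this computes idx_a + idx_b * Sa.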
+ def cluster_id(axes): + idx, stride = 0, 1 + for d in sorted(axes): + idx += lax.axis_index(d) * stride + stride *= lax.axis_size(d) + return idx + + noncollective_idx = cluster_id(noncollective_dims) + collective_idx = cluster_id(collective_dims) + + plgpu.copy_gmem_to_smem( + x_gmem.at[noncollective_idx], + smem, + tma_barrier, + collective_axes=collective_dims) + plgpu.barrier_wait(tma_barrier) + + plgpu.commit_smem() + collective_slice = pl.ds(8 * collective_idx, 8) + plgpu.copy_smem_to_gmem( + smem.at[collective_slice], + out_gmem.at[noncollective_idx, collective_slice, :], + ) + plgpu.wait_smem_to_gmem(0) + + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + kernel = plgpu.kernel( + body, + grid=cluster, + grid_names=("grid_x", "grid_y", "grid_z"), + cluster=cluster, + cluster_names=("x", "y", "z"), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=( + plgpu.SMEM(shape[1:], dtype, transforms=transforms), + plgpu.Barrier(), + ) + ) + np.testing.assert_array_equal(kernel(x), x) + class PallasCallWarpPrimitiveSemanticsTest(PallasTest): def setUp(self): From 704eb71d60669204e291c78694e3eed39e9dcc25 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Jun 2025 16:11:16 -0700 Subject: [PATCH 1528/1769] jnp.array: avoid call to stack This is in preparation for moving this definition into its own source file. --- jax/_src/numpy/lax_numpy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f323bc64718b..e95eb34a2380 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5524,7 +5524,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, out = _array_copy(object) if copy else object elif isinstance(object, (list, tuple)): if object: - out = stack([asarray(elt, dtype=dtype) for elt in object]) + arrs = (array(elt, dtype=dtype, copy=False) for elt in object) + out = lax.concatenate([lax.expand_dims(arr, [0]) for arr in arrs], 0) else: out = np.array([], dtype=dtype) elif _supports_buffer_protocol(object): From 2acbbcc255c543867319d548803a97a5cc7cbf6c Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 30 May 2025 09:35:13 -0700 Subject: [PATCH 1529/1769] Add a general system for keeping track of quasi-dynamic data (QDD). 
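Quasi-dynamic data is data such as liveness bits or the content type of a
type-changeable box: it can change over the course of a program, but at any
given point it has a single statically known value. A rough sketch of the
pieces added in core.py (see the diff below):

    qdd = ...                           # immutable QuasiDynamicData payload
    aval_qdd = AvalQDD(aval, qdd)       # aval paired with a QDD snapshot
    mut = MutableQuasiDynamicData(qdd)  # tracked value, e.g. in typechecking
    mut.update(new_qdd)                 # mut.cur_val now reflects new_qdd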
--- jax/_src/core.py | 156 ++++++++++++++++---------- jax/_src/interpreters/ad.py | 12 +- jax/_src/interpreters/partial_eval.py | 118 +++++++++++-------- jax/_src/lax/control_flow/loops.py | 50 ++++----- jax/_src/pjit.py | 35 +++--- tests/hijax_test.py | 87 ++++++++++---- 6 files changed, 277 insertions(+), 181 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 91ddd77e3d49..7260557ff4cc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -88,8 +88,7 @@ class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', - '_effects', '_debug_info', '_is_high', - '_initial_typechange_env', '_final_typechange_env'] + '_effects', '_debug_info', '_is_high'] _constvars: list[Var] _invars: list[Var] @@ -98,8 +97,6 @@ class Jaxpr: _effects: Effects _debug_info: DebugInfo _is_high: bool - _initial_typechange_env: dict[Var, Any] - _final_typechange_env: dict[Var, Any] @property def constvars(self) -> list[Var]: @@ -129,14 +126,6 @@ def debug_info(self) -> DebugInfo: def is_high(self) -> bool: return self._is_high - @property - def initial_typechange_env(self) -> dict[Var, Any]: - return self._initial_typechange_env - - @property - def final_typechange_env(self) -> dict[Var, Any]: - return self._final_typechange_env - def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], outvars: Sequence[Atom], eqns: Sequence[JaxprEqn], effects: Effects = no_effects, @@ -145,8 +134,6 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # is missing. debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment] is_high: bool = False, - initial_typechange_env: dict | None = None, - final_typechange_env: dict | None = None, ): """ Args: @@ -172,8 +159,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) self._is_high = is_high - self._initial_typechange_env = initial_typechange_env or {} - self._final_typechange_env = final_typechange_env or {} + num_vars = len(constvars) + len(invars) def __str__(self): return str(self.pretty_print()) @@ -201,10 +187,6 @@ def replace(self, **kwargs): effects=kwargs.pop("effects", self.effects), debug_info=kwargs.pop("debug_info", self.debug_info), is_high=kwargs.pop("is_high", self.is_high), - initial_typechange_env=kwargs.pop("initial_typechange_env", - self.initial_typechange_env), - final_typechange_env=kwargs.pop("final_typechange_env", - self.final_typechange_env), ) if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") @@ -232,22 +214,6 @@ def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]: yield from jaxprs_in_params(eqn.params) -@dataclass(frozen=True) -class TypeChange: - aval: AbstractValue - initial_type_state: Any - final_type_state: Any - - def to_tangent_aval(self): - return TypeChange(self.aval.to_tangent_aval(), - self.initial_type_state.to_tangent_aval(), - self.final_type_state.to_tangent_aval()) - - def normalize(self): - return TypeChange(self.aval.normalize(), - self.initial_type_state.normalize(), - self.final_type_state.normalize()) - class ClosedJaxpr: __slots__ = ['__weakref__', '_jaxpr', '_consts'] @@ -268,10 +234,13 @@ def in_avals(self): return [v.aval for v in self.jaxpr.invars] @property - def in_avals_aug(self): - ienv = self.jaxpr.initial_typechange_env - fenv = self.jaxpr.final_typechange_env - return [TypeChange(v.aval, ienv[v], fenv[v]) if v.aval.mutable else v.aval + def 
in_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.initial_qdd is None else AvalQDD(v.aval, v.initial_qdd) + for v in self.jaxpr.invars] + + @property + def final_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.final_qdd is None else AvalQDD(v.aval, v.final_qdd) for v in self.jaxpr.invars] @property @@ -464,16 +433,22 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() class Var: - __slots__ = ["count", "suffix", "aval"] + __slots__ = ["count", "suffix", "aval", "initial_qdd", "final_qdd"] count: int suffix: str aval: AbstractValue + # these are only useful for jaxpr binders but rather than create a separate + # type for those, breaking existing interpreters, we add fields here. + initial_qdd : QuasiDynamicData | None + final_qdd : QuasiDynamicData | None - def __init__(self, suffix: str, aval: AbstractValue): + def __init__(self, suffix: str, aval: AbstractValue, initial_qdd = None, final_qdd = None): self.count = next(_var_counter) self.suffix = suffix self.aval = aval + self.initial_qdd = initial_qdd + self.final_qdd = final_qdd def __repr__(self): return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}' @@ -483,7 +458,7 @@ def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): return f"{context.var_names[self]}{self.suffix}" -def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]: +def gensym(suffix: str = '') -> Callable: """Produce distinct variables, printed with the optional suffix.""" return partial(Var, suffix) @@ -1114,6 +1089,8 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # py del primitive, fwd, bwd, _ # Unused. return fun.call_wrapped(*tracers) + def cur_qdd(self, x): + return x.cur_qdd() class TraceTag: # TODO: this works for surprisingly subtle reasons. Function transformations @@ -1505,7 +1482,7 @@ def definitely_equal(x, y): class AbstractValue: __slots__: list[str] = [] is_high = False - mutable = False + has_qdd = False def to_tangent_aval(self): raise NotImplementedError("must override") @@ -1533,6 +1510,12 @@ def normalize(self) -> AbstractValue: def update(self, **kwargs): raise NotImplementedError("must override") + def lo_ty(self): + raise NotImplementedError("must override") + + def lo_ty_qdd(self, qdd): + raise NotImplementedError("avals with qdd must override") + def str_short(self, short_dtypes=False, mesh_axis_types=False): return str(self) @@ -1704,6 +1687,54 @@ def concrete_dim_or_error(val: Any, context=""): else: return concrete_or_error(operator.index, val, context=context) +### Quasi-dynamic data + +# Quasi-dynamic data includes things like liveness bits and the content type of +# a type-changeable box. These change throughout the program but at a given +# point in the program they have a single statically known value. 
+ +class MutableQuasiDynamicData: + def __init__(self, val : QuasiDynamicData | None): + self.init_val = val + self.cur_val = val # immutable payload + + def update(self, val): + self.cur_val = val + +class QuasiDynamicData: + pass + +@dataclass(frozen=True) +class AvalQDD: + aval: AbstractValue + qdd: QuasiDynamicData | None # immutable + + has_qdd = True + def lo_ty(self): + return self.aval.lo_ty_qdd(self.qdd) # type: ignore + + def read_loval(self, val): + return self.aval.read_loval(self.qdd, val) # type: ignore + + def new_from_loval(self, *lovals): + return self.aval.new_from_loval(self.qdd, *lovals) # type: ignore + + def to_tangent_aval(self): + return AvalQDD(self.aval.to_tangent_aval(), self.qdd.to_tangent_qdd()) + +@dataclass(frozen=True) +class AvalMutableQDD: + aval: AbstractValue + mutable_qdd: MutableQuasiDynamicData + +def cur_qdd(x): + prev_trace = trace_ctx.trace + trace_ctx.set_trace(eval_trace) + try: + return prev_trace.cur_qdd(x) + finally: + trace_ctx.set_trace(prev_trace) + ### Extended dtypes # # Extended dtypes are JAX-specific dtypes that allow us to represent logical @@ -2917,15 +2948,19 @@ def ctx_factory(): from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) +# A place to track the quasi-dynamic data associated with a variable during typechecking +@dataclass(frozen=True) +class MutableTypecheckVal: + aval : AbstractValue + mutable_qdd : MutableQuasiDynamicData def _check_jaxpr( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], jaxpr: Jaxpr ) -> None: - # Use set of variables to types to check that variables are in scope. - env: set[Var] = set() + env: dict[Var, Atom | MutableTypecheckVal] = {} - def read(x: Atom) -> Atom: + def read(x: Atom) -> Atom | MutableTypecheckVal: # Check the type annotation is itself well-typed. check_type(ctx_factory, env, x.aval) if isinstance(x, Var): @@ -2933,7 +2968,7 @@ def read(x: Atom) -> Atom: if x not in env: ctx, _ = ctx_factory() raise JaxprTypeError(f"Variable '{pp_var(x, ctx)}' not defined") - return x + return env[x] elif isinstance(x, Literal): # Check that the literal matches its type annotation. if not typecheck(x.aval, x.val): @@ -2945,7 +2980,8 @@ def read(x: Atom) -> Atom: else: assert False, "syntactically invalid jaxpr" - def write(v: Var, a: AbstractValue) -> None: + def write(v: Var, a: AvalQDD) -> None: + aval, qdd = a.aval, a.qdd assert isinstance(v, Var), "syntactically invalid jaxpr" # Check the type annotation of the binder is itself well-typed. check_type(ctx_factory, env, v.aval) @@ -2954,19 +2990,23 @@ def write(v: Var, a: AbstractValue) -> None: ctx, _ = ctx_factory() raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound") # Check that the computed type is consistent with the binder annotation. - if not typematch(v.aval, a): + if not typematch(v.aval, aval): ctx, _ = ctx_factory() raise JaxprTypeError( f"Value for variable '{pp_var(v, ctx)}' inconsistently typed " - f"as {pp_aval(a, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}") + f"as {pp_aval(aval, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}") + # If the variable is not a DropVar, add it to the environment. if not isinstance(v, DropVar): - env.add(v) + if qdd is None: + env[v] = v + else: + env[v] = MutableTypecheckVal(aval, MutableQuasiDynamicData(qdd)) # Check type annotations on lambda binders. 
for v in it.chain(jaxpr.constvars, jaxpr.invars): check_type(ctx_factory, env, v.aval) - write(v, v.aval) + write(v, AvalQDD(v.aval, v.initial_qdd)) # Check each eqn. sentinel = object() @@ -2976,7 +3016,8 @@ def write(v: Var, a: AbstractValue) -> None: prim = eqn.primitive try: in_atoms = map(read, eqn.invars) - in_avals = [x.aval for x in in_atoms] # use in_atoms for dyn shapes + in_avals = [AvalMutableQDD(x.aval, x.mutable_qdd) if isinstance(x, MutableTypecheckVal) + else x.aval for x in in_atoms] # use in_atoms for dyn shapes # Compute the type of the primitive application. with eqn.ctx.manager: @@ -3026,6 +3067,7 @@ def write(v: Var, a: AbstractValue) -> None: # Check out_type matches the let-binders' annotation (after substitution). out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars) + out_type = [t if isinstance(t, AvalQDD) else AvalQDD(t, None) for t in out_type] foreach(write, eqn.outvars, out_type) except JaxprTypeError as e: @@ -3041,7 +3083,7 @@ def write(v: Var, a: AbstractValue) -> None: def check_type( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], - env: set[Var], + env: dict[Var, Atom | MutableTypecheckVal], ty: AbstractValue, ) -> None: if isinstance(ty, DShapedArray): @@ -3111,7 +3153,7 @@ def _check_call(ctx_factory, prim, in_atoms, params): f"{len(call_jaxpr.invars)} inputs") # Check `call_jaxpr` can be applied to in_atoms. - env: dict[Var, Atom] = {} + env: dict[Var, Atom | MutableTypecheckVal] = {} def substitute(aval: AbstractValue): if isinstance(aval, DShapedArray): aval = aval.update(shape=tuple(env.get(d, d) for d in aval.shape)) # type: ignore @@ -3122,7 +3164,7 @@ def substitute(aval: AbstractValue): raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " f"{x.aval} to jaxpr expecting type " f"{substitute(v.aval)}") - env[v] = x if type(x) is Var else x.val + env[v] = x.val if type(x) is Literal else x _check_jaxpr(ctx_factory, call_jaxpr) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index a77e93bb0696..69a123b12e23 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -501,6 +501,11 @@ def process_primitive(self, primitive, tracers, params): else: return maybe_jvp_tracer(self, primal_out, tangent_out) + def cur_qdd(self, x): + p, _ = self.to_primal_tangent_pair(x) + with core.set_current_trace(self.parent_trace): + return core.cur_qdd(p) + def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) @@ -629,6 +634,9 @@ def __init__(self, trace, primal, tangent): def aval(self): return get_aval(self.primal) + def cur_qdd(self): + return core.cur_qdd(self.primal) + def full_lower(self): if type(self.tangent) is Zero: return core.full_lower(self.primal) @@ -1170,8 +1178,8 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, f_jvp, out_nonzeros = f_jvp_traceable( jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) tangent_avals = [aval.to_tangent_aval() - for aval, nz in zip(jaxpr.in_avals_aug, nonzeros) if nz] - avals_in = list(it.chain(jaxpr.in_avals_aug, tangent_avals)) + for aval, nz in zip(jaxpr.in_aval_qdds, nonzeros) if nz] + avals_in = list(it.chain(jaxpr.in_aval_qdds, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic( f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 
1bcd3f00321c..a4f4fcd12429 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -222,6 +222,14 @@ def instantiate_const_abstracted(self, tracer) -> JaxprTracer: aval = get_aval(const).update_weak_type(np.isscalar(const)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) + def cur_qdd(self, x): + const = self.to_jaxpr_tracer(x).pval.get_known() + if const is None: + assert False # TODO: track tangent QDDs + else: + with core.set_current_trace(self.parent_trace): + return core.cur_qdd(const) + def process_primitive(self, primitive, tracers, params): with core.set_current_trace(self.parent_trace): if primitive in custom_partial_eval_rules: @@ -1012,7 +1020,7 @@ def fun(*known_vals_in): known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] - known_avals = [a for a, uk in zip(jaxpr.in_avals_aug, in_unknowns) if not uk] + known_avals = [a for a, uk in zip(jaxpr.in_aval_qdds, in_unknowns) if not uk] jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic( lu.wrap_init(fun, debug_info=f.debug_info), known_avals) (out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking @@ -1201,14 +1209,11 @@ def has_effects(effects) -> bool: known_outvars = [*outs_known, *residuals] known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res, known_outvars, known_eqns) - known_mut, staged_mut, ins_known_ = {}, {}, set(ins_known) # type: ignore - for v, t in jaxpr.final_typechange_env.items(): - [staged_mut, known_mut][v in ins_known_][v] = t # TODO(mattjj,necula): debug info should be updated here jaxpr_known = jaxpr.replace( invars=ins_known_and_ref_res, outvars=known_outvars, - eqns=known_eqns, effects=known_effects, final_typechange_env=known_mut) + eqns=known_eqns, effects=known_effects) config.enable_checks.value and core.check_jaxpr(jaxpr_known) _, ins_staged = partition_list(in_inst, jaxpr.invars) @@ -1219,7 +1224,7 @@ def has_effects(effects) -> bool: # TODO(mattjj,necula): debug info should be updated here jaxpr_staged = jaxpr.replace( invars=staged_invars, outvars=outs_staged, eqns=staged_eqns, - effects=staged_effects, final_typechange_env=staged_mut) + effects=staged_effects) config.enable_checks.value and core.check_jaxpr(jaxpr_staged) return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals), @@ -1634,15 +1639,30 @@ def _move_outvars_to_back(jaxpr, to_move): class DynamicJaxprTracer(core.Tracer): - __slots__ = ['aval', '_debug_info'] + __slots__ = ['aval', 'mutable_qdd', '_debug_info'] def __init__(self, trace: DynamicJaxprTrace, - aval: core.AbstractValue, + aval: core.AbstractValue | core.AvalQDD, line_info: source_info_util.SourceInfo | None = None): + if isinstance(aval, core.AvalQDD): + assert aval.qdd is not None + aval, qdd = aval.aval, aval.qdd + else: + assert not aval.has_qdd + qdd = None self._trace = trace self._line_info = line_info self._debug_info = self._trace.frame.debug_info # for UnexpectedTracerError self.aval = aval # type: ignore[misc] + self.mutable_qdd = core.MutableQuasiDynamicData(qdd) + + @property + def aval_mutable_qdd(self): + aval = self.aval + if aval.has_qdd: + return core.AvalMutableQDD(aval, self.mutable_qdd) + else: + return aval def full_lower(self): var = self._trace.frame.tracer_to_var.get(id(self)) @@ -1651,10 +1671,6 @@ def full_lower(self): if val is None: return self return core.full_lower(val) - def type_state(self): - var = self._trace.frame.tracer_to_var.get(id(self)) - return 
self._trace.frame.current_typechange_env[var] - def _contents(self): return () @@ -1750,8 +1766,8 @@ class JaxprStackFrame: attrs_vars: list[Var] debug_info: core.DebugInfo is_high: bool - initial_typechange_env: dict - current_typechange_env: dict + mutable_qdds: list[tuple[Var, core.MutableQuasiDynamicData]] + def __init__(self, debug_info: core.DebugInfo): self.gensym = core.gensym() @@ -1767,8 +1783,7 @@ def __init__(self, debug_info: core.DebugInfo): self.attrs_vars = [] self.debug_info = debug_info self.is_high = False - self.initial_typechange_env = {} - self.current_typechange_env = {} + self.mutable_qdds = [] def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) @@ -1794,11 +1809,13 @@ def to_jaxpr( outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) - final_typechange_env = {v: s for v, s in self.current_typechange_env.items() - if v in self.initial_typechange_env} + + # TODO(dougalm): handle qdd for consts + for v, qdd in self.mutable_qdds: + v.final_qdd = qdd.cur_val + jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, - debug_info, self.is_high, self.initial_typechange_env, - final_typechange_env) + debug_info, self.is_high) jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) @@ -1827,7 +1844,10 @@ def newvar(self, aval): for d in aval.shape] new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape] aval = aval.update(shape=tuple(new_shape)) - return self.gensym(aval) + if isinstance(aval, core.AvalQDD): + return self.gensym(aval.aval, initial_qdd=aval.qdd) + else: + return self.gensym(aval) def find_progenitors(self, tracer): var = self.tracer_to_var.get(id(tracer)) @@ -1883,12 +1903,13 @@ def vars(atom: Atom) -> list[Var]: class DynamicJaxprTrace(core.Trace): - __slots__ = ("frame", "tag") + __slots__ = ("frame", "tag", "parent_trace") - def __init__(self, debug_info: core.DebugInfo, lower=False): + def __init__(self, debug_info: core.DebugInfo, parent_trace=None, lower=False): super().__init__() self.requires_low = lower self.frame = JaxprStackFrame(debug_info) + self.parent_trace = parent_trace def invalidate(self): # avoid cyclic refs @@ -1915,6 +1936,7 @@ def new_arg(self, aval, source_info: SourceInfo): self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.invars.append(var) + self.frame.mutable_qdds.append((var, tracer.mutable_qdd)) return tracer def new_const(self, c, source_info: SourceInfo): @@ -1922,6 +1944,9 @@ def new_const(self, c, source_info: SourceInfo): tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: aval = get_aval(c) + if aval.has_qdd: + with core.set_current_trace(self.parent_trace): + aval = core.AvalQDD(aval, core.cur_qdd(c)) if hasattr(aval, "weak_type"): aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) aval = self._lift_tracers_in_aval(aval, source_info) @@ -1938,9 +1963,9 @@ def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: else: self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.constid_to_tracer[id(c)] = tracer + if isinstance(aval, core.AvalQDD): + self.frame.mutable_qdds.append((var, tracer.mutable_qdd)) self.frame.constvar_to_val[var] = c - if aval.mutable: - 
self.frame.initial_typechange_env[var] = c.type_state() return tracer def get_const(self, tracer) -> Any: @@ -1971,7 +1996,12 @@ def makevar(self, tracer): var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) return var + def cur_qdd(self, x): + source_info = source_info_util.current() + return self.to_jaxpr_tracer(x, source_info=source_info).mutable_qdd.cur_val + def process_primitive(self, primitive, tracers, params): + self.frame.is_high |= primitive.is_high(**params) if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) source_info = source_info_util.current() @@ -1982,8 +2012,8 @@ def process_primitive(self, primitive, tracers, params): return self.default_process_primitive(primitive, jaxpr_tracers, params) def default_process_primitive(self, primitive, tracers, params): - avals = [t.aval for t in tracers] - out_avals, effs = primitive.abstract_eval(*avals, **params) + aval_qdds = [t.aval_mutable_qdd for t in tracers] + out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: raise ValueError(f"{primitive}.abstract_eval() method should return " f"a tuple or a list iff {primitive}.multiple_results.") @@ -2254,17 +2284,13 @@ def trace_to_jaxpr_dynamic( ) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - trace = DynamicJaxprTrace(fun.debug_info, lower=lower) - in_avals_ = [a.aval if isinstance(a, core.TypeChange) else a for a in in_avals] + parent_trace = core.trace_ctx.trace + trace = DynamicJaxprTrace(fun.debug_info, parent_trace=parent_trace, lower=lower) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): source_info = source_info_util.current() in_tracers = _input_type_to_tracers( - partial(trace.new_arg, source_info=source_info), in_avals_) + partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - trace.frame.initial_typechange_env = initial_typechange_env = { - v: a.initial_type_state for v, a in zip(trace.frame.invars, in_avals) - if isinstance(a, core.TypeChange)} - trace.frame.current_typechange_env = dict(initial_typechange_env) try: with core.set_current_trace(trace): @@ -2329,7 +2355,8 @@ def trace_to_jaxpr_dynamic2( ) -> tuple[Jaxpr, OutputType, list[Any]]: assert fun.in_type is not None, "fun must be annotated with lu.annotate()" - trace = DynamicJaxprTrace(fun.debug_info) + parent_trace = core.trace_ctx.trace + trace = DynamicJaxprTrace(fun.debug_info, parent_trace=parent_trace) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): source_info = source_info_util.current() in_avals, keep_inputs = unzip2(fun.in_type) @@ -2744,33 +2771,28 @@ def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, lis @weakref_lru_cache def lower_jaxpr(hi_jaxpr): - initial_env = hi_jaxpr.jaxpr.initial_typechange_env - lo_avals = [lo_ty for v in hi_jaxpr.jaxpr.invars - for lo_ty in (v.aval.lo_ty_(initial_env[v]) if v.aval.mutable - else v.aval.lo_ty())] + lo_avals = [lo_ty for aval in hi_jaxpr.in_aval_qdds for lo_ty in aval.lo_ty()] f = lu.wrap_init(partial(lower_traceable, hi_jaxpr), debug_info=hi_jaxpr.jaxpr.debug_info) lo_jaxpr, _, lo_consts, () = trace_to_jaxpr_dynamic(f, lo_avals, lower=True) return 
core.ClosedJaxpr(lo_jaxpr, lo_consts) def lower_traceable(jaxpr, *lo_args): - env = jaxpr.jaxpr.initial_typechange_env lo_args_ = iter(lo_args) - hi_args = [v.aval.raise_val(*it.islice(lo_args_, len(v.aval.lo_ty()))) - if not v.aval.mutable else - v.aval.new_from_loval(env[v], *it.islice(lo_args_, len(v.aval.lo_ty_(env[v])))) - for v in jaxpr.jaxpr.invars] + hi_args = [aval.raise_val(*it.islice(lo_args_, len(aval.lo_ty()))) + if not aval.has_qdd else + aval.new_from_loval(*it.islice(lo_args_, len(aval.lo_ty()))) + for aval in jaxpr.in_aval_qdds] assert (problem := next(lo_args_, None)) is None hi_outs = core.jaxpr_as_fun(jaxpr)(*hi_args) - in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} - mut_outs = [lo_val for v, ty in jaxpr.jaxpr.final_typechange_env.items() - for lo_val in v.aval.read_loval(ty, hi_args[in_idx[v]])] + mut_outs = [lo_val for aval, hi_arg in zip(jaxpr.final_aval_qdds, hi_args) if aval.has_qdd + for lo_val in aval.read_loval(hi_arg)] lo_outs = [lo_val for v, hi_val in zip(jaxpr.jaxpr.outvars, hi_outs) for lo_val in v.aval.lower_val(hi_val)] return mut_outs + lo_outs def convert_const_himutables(jaxpr): - move = [core.typeof(c).mutable for c in jaxpr.consts] + move = [core.typeof(c).has_qdd for c in jaxpr.consts] constvals, in_mutables = partition_list(move, jaxpr.consts) constvars, boxvars = partition_list(move, jaxpr.jaxpr.constvars) invars = *boxvars, *jaxpr.jaxpr.invars diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index ad09292731cf..3fefab78f7f7 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1512,17 +1512,6 @@ def arrange_jaxpr_args_for_wrapped(args): assert len(refs_out_matching_in_avals) == len(in_avals) return refs_out_matching_in_avals, [*carry_out, *ys] -def _scan_staging(trace, *args, **params): - outs = trace.default_process_primitive(scan_p, args, params) - jaxpr = params['jaxpr'] - trace.frame.is_high = jaxpr.jaxpr.is_high - invars = [trace.frame.tracer_to_var[id(t)] for t in args] - var_map = dict(zip(jaxpr.jaxpr.invars, invars)) - final_env = {var_map[v]: ty for v, ty in - jaxpr.jaxpr.final_typechange_env.items()} - trace.frame.current_typechange_env.update(final_env) - return outs - scan_p = core.Primitive("scan") scan_p.multiple_results = True scan_p.skip_canonicalization = True @@ -1541,37 +1530,32 @@ def _scan_staging(trace, *args, **params): pe.padding_rules[scan_p] = _scan_padding_rule pe.dce_rules[scan_p] = _scan_dce_rule state_discharge.register_partial_discharge_rule(scan_p)(_scan_state_partial_discharge_rule) -pe.custom_staging_rules[scan_p] = _scan_staging def _is_high(jaxpr, **_) -> bool: return jaxpr.jaxpr.is_high scan_p.is_high = _is_high # type: ignore def _to_lojax(*hi_args, jaxpr, num_carry, num_consts, linear, **params): - ienv, fenv = jaxpr.jaxpr.initial_typechange_env, jaxpr.jaxpr.final_typechange_env # move box binders and hi_args from consts slots to carry slots - to_move = [t.mutable for t in jaxpr.in_avals[:num_consts]] + to_move = [t.has_qdd for t in jaxpr.in_aval_qdds[:num_consts]] jaxpr = pe.move_invars_right(jaxpr, to_move) hi_args = _move_right(hi_args, to_move) num_consts -= sum(to_move) num_carry += sum(to_move) # expand num_consts, num_carry, linear according to lo types - const_invars, carry_invars, _ = split_list(jaxpr.jaxpr.invars, [num_consts, num_carry]) - num_consts = sum(len(v.aval.lo_ty() if not v.aval.mutable - else v.aval.lo_ty_(ienv[v])) for v in const_invars) - num_carry = sum(len(v.aval.lo_ty() if not v.aval.mutable - 
else v.aval.lo_ty_(ienv[v])) for v in carry_invars) - linear = [l for v, l_ in zip(jaxpr.jaxpr.invars, linear) - for l in (l_,) * len(v.aval.lo_ty() if not v.aval.mutable - else v.aval.lo_ty_(ienv[v]))] - lo_muts_out = sum(len(m.leaf_avals) for m in fenv.values()) # TODO hardcoded - - # collect lo inputs values - lo_args = [lo_val for v, x in zip(jaxpr.jaxpr.invars, hi_args) - for lo_val in (v.aval.read_loval(ienv[v], x) if v.aval.mutable - else v.aval.lower_val(x))] + const_in_avals, carry_in_avals, _ = split_list(jaxpr.in_aval_qdds, [num_consts, num_carry]) + num_consts = sum(len(aval.lo_ty()) for aval in const_in_avals) + num_carry = sum(len(aval.lo_ty()) for aval in carry_in_avals) + linear = [l for aval, l_ in zip(jaxpr.in_aval_qdds, linear) + for l in (l_,) * len(aval.lo_ty())] + lo_muts_out = sum(len(aval.lo_ty()) for aval in jaxpr.final_aval_qdds if aval.has_qdd) + + # collect lo input values + lo_args = [lo_val for aval, x in zip(jaxpr.in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] # lower the jaxpr and bind it using lo input values lo_jaxpr = pe.lower_jaxpr(jaxpr) @@ -1582,9 +1566,13 @@ def _to_lojax(*hi_args, jaxpr, num_carry, num_consts, linear, **params): # collect and apply mutations out_mut_ = iter(out_mut) in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} - for var, ty in jaxpr.jaxpr.final_typechange_env.items(): - lo_vals = it.islice(out_mut_, len(var.aval.lo_ty_(ty))) - var.aval.update_from_loval(ty, hi_args[in_idx[var]], *lo_vals) + + for v in jaxpr.jaxpr.invars: + if v.final_qdd is not None: + qdd = v.final_qdd + lo_vals = it.islice(out_mut_, len(v.aval.lo_ty_qdd(qdd))) + v.aval.update_from_loval(qdd, hi_args[in_idx[v]], *lo_vals) + assert next(out_mut_, None) is None # collect output values into hi types diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 572c2225af74..985d1a46a053 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -590,7 +590,7 @@ def _infer_params_impl( in_type = in_avals = tuple(core.shaped_abstractify(x) for x in explicit_args) # type: ignore else: in_type = in_avals # type: ignore - in_type = tuple(core.TypeChange(a, x.type_state(), None) if a.mutable # type: ignore + in_type = tuple(core.AvalQDD(a, core.cur_qdd(x)) if a.has_qdd # type: ignore else a for a, x in zip(in_type, explicit_args)) assert in_avals is not None @@ -1416,7 +1416,7 @@ def _create_pjit_jaxpr( from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) - if any(isinstance(c, core.Tracer) or core.typeof(c).mutable for c in consts): + if any(isinstance(c, core.Tracer) or core.typeof(c).has_qdd for c in consts): closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) final_consts = consts else: @@ -1562,25 +1562,22 @@ def _is_high(jaxpr, **_) -> bool: pjit_p.is_high = _is_high # type: ignore def _to_lojax(*hi_args, jaxpr, **params): - ienv, fenv = jaxpr.jaxpr.initial_typechange_env, jaxpr.jaxpr.final_typechange_env - # convert closed-over boxes to explicit args jaxpr, closed_over_himutables = pe.convert_const_himutables(jaxpr) hi_args = [*closed_over_himutables, *hi_args] params = _converted_mutables_add_params(len(closed_over_himutables), **params) + # expand pjit params that must match number of lo inputs/outputs - lo_nums_in = [len(v.aval.lo_ty() if not v.aval.mutable - else v.aval.lo_ty_(ienv[v])) - for v in jaxpr.jaxpr.invars] + lo_nums_in = [len(aval.lo_ty()) for aval in jaxpr.in_aval_qdds] lo_nums_out = [len(t.lo_ty()) for t in 
jaxpr.out_avals] - lo_muts_out = sum(len(m.leaf_avals) for m in fenv.values()) # TODO hardcoded + lo_muts_out = sum(len(aval.lo_ty()) for aval in jaxpr.final_aval_qdds if aval.has_qdd) params = _lojax_expand_params(lo_nums_in, lo_nums_out, lo_muts_out, **params) # collect lo input values - lo_args = [lo_val for v, x in zip(jaxpr.jaxpr.invars, hi_args) - for lo_val in (v.aval.read_loval(ienv[v], x) if v.aval.mutable - else v.aval.lower_val(x))] + lo_args = [lo_val for aval, x in zip(jaxpr.in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] # lower the jaxpr and bind it using lo input values lo_jaxpr = pe.lower_jaxpr(jaxpr) @@ -1590,9 +1587,11 @@ def _to_lojax(*hi_args, jaxpr, **params): # collect and apply mutations out_mut_ = iter(out_mut) in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} - for var, ty in jaxpr.jaxpr.final_typechange_env.items(): - lo_vals = it.islice(out_mut_, len(var.aval.lo_ty_(ty))) - var.aval.update_from_loval(ty, hi_args[in_idx[var]], *lo_vals) + for v in jaxpr.jaxpr.invars: + if v.final_qdd is not None: + qdd = v.final_qdd + lo_vals = it.islice(out_mut_, len(v.aval.lo_ty_qdd(qdd))) + v.aval.update_from_loval(qdd, hi_args[in_idx[v]], *lo_vals) assert next(out_mut_, None) is None # collect output values into hi types @@ -1612,6 +1611,7 @@ def _converted_mutables_add_params( return dict(params, donated_invars=donated_invars, in_shardings=in_shardings, in_layouts=in_layouts) + def _lojax_expand_params( nums_in, nums_out, muts_out, *, donated_invars, in_shardings, in_layouts, out_shardings, out_layouts, **params): @@ -2014,13 +2014,6 @@ def pjit_staging_rule(trace, *args, **params): else: out_tracers = trace.default_process_primitive(pjit_p, args, params) - trace.frame.is_high = jaxpr.jaxpr.is_high - invars = [trace.frame.tracer_to_var[id(t)] for t in it.chain(args, consts)] - var_map = dict(zip(jaxpr.jaxpr.invars, invars)) - final_env = {var_map[v]: ty for v, ty in - jaxpr.jaxpr.final_typechange_env.items()} - trace.frame.current_typechange_env.update(final_env) - return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule diff --git a/tests/hijax_test.py b/tests/hijax_test.py index 8b7d045c6c5a..b272e0aa8986 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -213,7 +213,7 @@ def new_box(): return new_box_p.bind(treedef=treedef) def box_get(box): - tys = box.type_state() + tys = core.cur_qdd(box) leaf_vals = box_get_p.bind(box, avals=tys.leaf_avals) return jax.tree.unflatten(tys.treedef, leaf_vals) @@ -222,11 +222,11 @@ def box_set(box, val): box_set_p.bind(box, *leaves, treedef=treedef) @dataclass(frozen=True) -class BoxTypeState: +class BoxTypeState(core.QuasiDynamicData): leaf_avals: tuple[core.AbstractValue, ...] 
treedef: PyTreeDef - def to_tangent_aval(self): + def to_tangent_qdd(self): return BoxTypeState(tuple(a.to_tangent_aval() for a in self.leaf_avals), self.treedef) @@ -235,11 +235,12 @@ def normalize(self): self.treedef) class BoxTy(core.AbstractValue): - mutable = True + has_qdd = True # forwarded to value get = core.aval_method(box_get) set = core.aval_method(box_set) + type_state = core.aval_method(core.cur_qdd) # aval interface: hashability and str_short def __hash__(self): return hash(BoxTy) @@ -249,7 +250,7 @@ def str_short(self, short_dtypes=False): return 'BoxTy' # mutable interface - def lo_ty_(self, box_state): + def lo_ty_qdd(self, box_state): return [lo_ty for t in box_state.leaf_avals for lo_ty in t.lo_ty()] def new_from_loval(self, box_state: BoxTypeState, *lo_vals): @@ -285,6 +286,9 @@ def get(self): def set(self, val): box_set(self, val) + def cur_qdd(self): + return self.type_state() + @property def ty(self): return BoxTy() @@ -299,15 +303,10 @@ def type_state(self): class NewBox(HiPrimitive): def is_high(self, *, treedef) -> bool: return True - def staging(self, trace, *, treedef): - tracer = super().staging(trace, treedef=treedef) - var = trace.frame.tracer_to_var[id(tracer)] - leaves, treedef = jax.tree.flatten(None) - trace.frame.current_typechange_env[var] = BoxTypeState(leaves, treedef) - return tracer - def abstract_eval(self, *, treedef): - return BoxTy(), set() + leaves, treedef = jax.tree.flatten(None) + qdd = BoxTypeState(leaves, treedef) + return core.AvalQDD(BoxTy(), qdd), set() def to_lojax(_, *, treedef): return Box(None) @@ -325,14 +324,8 @@ class BoxSet(HiPrimitive): def is_high(self, *, treedef) -> bool: return True - def staging(self, trace, box_tracer, *leaves, treedef): - super().staging(trace, box_tracer, *leaves, treedef=treedef) - var = trace.getvar(box_tracer) - avals = tuple(t.aval for t in leaves) - trace.frame.current_typechange_env[var] = BoxTypeState(avals, treedef) - return [] - def abstract_eval(self, box_ty, *leaf_avals, treedef): + box_ty.mutable_qdd.update(BoxTypeState(leaf_avals, treedef)) return [], set() # TODO better typechecking... def to_lojax(_, box, *leaves, treedef): @@ -375,6 +368,36 @@ def transpose(_, *args): class BoxTest(jtu.JaxTestCase): + @parameterized.parameters([False, True]) + def test_qdd(self, jit): + + val1 = 1.0 + val2 = jnp.arange(3) + + box1 = Box(val1) + + def f(box2): + assert core.cur_qdd(box2).leaf_avals == (core.typeof(val1),) + box2.set(val2) + assert core.cur_qdd(box2).leaf_avals == (core.typeof(val2),) + + box3 = new_box() + box3.set(val2) + assert core.cur_qdd(box3).leaf_avals == (core.typeof(val2),) + box3.set(val1) + assert core.cur_qdd(box3).leaf_avals == (core.typeof(val1),) + + assert core.cur_qdd(box1).leaf_avals == (core.typeof(val1),) + box1.set(val2) + assert core.cur_qdd(box1).leaf_avals == (core.typeof(val2),) + + return + + if jit: + f = jax.jit(f) + + f(Box(val1)) + def test_jit_arg(self): @jax.jit def f(box, x): @@ -470,6 +493,24 @@ def k(x): ans = h(2.0) self.assertAllClose(ans, 4.0) + def test_jit_closure_nested3(self): + box = new_box() + + @jax.jit + def h(x): + box.set(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + @parameterized.parameters([False, True]) def test_jvp_closure_stop_gradient(self, jit): box = Box(1.0) @@ -528,6 +569,7 @@ def f(x): f = jax.jit(f) jax.grad(f)(1.0) + self.assertAllClose(box.get(), 2.0) # TODO(mattjj,dougalm): make this work... 
@@ -587,6 +629,7 @@ def body(_, __): double_it_10 = jax.jit(double_it_10) double_it_10() + self.assertAllClose(box.get(), 1024., check_dtypes=False) # TODO error-checking tests from attrs_test.py @@ -674,7 +717,7 @@ class MyArray: @dataclass(frozen=True) class MyTy(core.AbstractValue): - mutable = False + has_qdd = False def to_tangent_aval(self): return MyTy() @@ -711,7 +754,7 @@ def f(box): class ListTy(core.AbstractValue): - mutable = True + has_qdd = True # forwarded to value get = core.aval_method(box_get) From 1554de5ce5637be89f43aca04b2e2b62308f9b2e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Jun 2025 14:02:22 -0700 Subject: [PATCH 1530/1769] Fix documentation for the CLI `up` command in the debugger. PiperOrigin-RevId: 767277572 --- jax/_src/debugger/cli_debugger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/debugger/cli_debugger.py b/jax/_src/debugger/cli_debugger.py index bf4b38765026..eb1eca3bec48 100644 --- a/jax/_src/debugger/cli_debugger.py +++ b/jax/_src/debugger/cli_debugger.py @@ -105,7 +105,7 @@ def do_pp(self, arg): def do_up(self, _): """u(p) - Move down a stack frame. + Move up a stack frame. """ if self.frame_index == len(self.frames) - 1: print('At topmost frame.', file=self.stdout) From 08530bc14f0b1c08332c7a9d873ac9818abdd4a3 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 4 Jun 2025 16:36:27 -0700 Subject: [PATCH 1531/1769] Link c-api raw buffer support into jaxlib. PiperOrigin-RevId: 767333363 --- jaxlib/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 6fd606966b53..4eff447c1b68 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -351,6 +351,7 @@ nanobind_pywrap_extension( "@xla//xla/pjrt:pjrt_layout", "@xla//xla/pjrt:status_casters", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_raw_buffer_external", "@xla//xla/pjrt/distributed", "@xla//xla/pjrt/distributed:client", "@xla//xla/pjrt/distributed:key_value_store_interface", From 1ccc387528d69bcdbd8152a87771376566744375 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 4 Jun 2025 17:01:25 -0700 Subject: [PATCH 1532/1769] [Pallas Fuser] Add basic reshape push rule PiperOrigin-RevId: 767341856 --- jax/_src/pallas/fuser/block_spec.py | 161 +++++++++++++++++++++----- tests/pallas/fuser_block_spec_test.py | 68 ++++++++++- 2 files changed, 197 insertions(+), 32 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 4f9d1c344429..56f75699735c 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -814,13 +814,16 @@ def register_default_eval_rule(prim: core.Primitive): def default_rule(ctx, *args, **params): assert all(bs is pallas_core.no_block_spec for bs in ctx.out_block_specs) return prim.bind(*args, **params) + register_eval_rule(prim)(default_rule) + def register_binop_rule(prim: core.Primitive): register_pull_block_spec_rule(prim)(functools.partial(_binop_pull_rule, prim)) register_usage_rule(prim)(functools.partial(_binop_usage_rule, prim)) register_eval_rule(prim)(functools.partial(_binop_eval_rule, prim)) + register_default_eval_rule(state_primitives.get_p) register_binop_rule(lax.mul_p) @@ -1074,6 +1077,7 @@ def new_index_map(*args): len(ctx.avals_in) - 1 ) + @register_pull_block_spec_rule(state_primitives.swap_p) def _swap_pull_rule( ctx: PullRuleContext, @@ -1084,14 +1088,9 @@ def _swap_pull_rule( # The output and val block spec are the same. 
return [block_spec, block_spec] + @register_eval_rule(state_primitives.swap_p) -def _swap_eval_rule( - ctx: KernelEvalContext, - ref, - val, - *idx, - tree -): +def _swap_eval_rule(ctx: KernelEvalContext, ref, val, *idx, tree): indexers = tree_util.tree_unflatten(tree, idx) ref_aval, _ = ctx.avals_in[:2] indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:]) @@ -1123,17 +1122,15 @@ def _slice(i, b): return i if b is None else indexing.ds(i * b, b) indexer = tuple( - _slice(i, b) for i, b in zip(block_idx, block_spec.block_shape, - strict=True) + _slice(i, b) + for i, b in zip(block_idx, block_spec.block_shape, strict=True) ) return ref.swap(val, idx=indexer) + @register_pull_block_spec_rule(state_primitives.get_p) def _get_pull_rule( - ctx: PullRuleContext, - block_spec: pallas_core.BlockSpec, - *, - tree + ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, *, tree ): ref_aval = ctx.avals_in[0] assert hasattr(ref_aval, 'shape') @@ -1166,6 +1163,7 @@ def _get_pull_rule( bd = next(block_shape_iter) block_shape.append(_block_size(bd)) assert next(block_shape_iter, None) is None + def new_index_map(*args): idx = block_spec.index_map(*args) idx_iter = iter(idx) @@ -1177,16 +1175,13 @@ def new_index_map(*args): ) assert next(idx_iter, None) is None return indices + block_spec = pallas_core.BlockSpec(block_shape, new_index_map) return [block_spec] + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1) + @register_eval_rule(state_primitives.get_p) -def _get_eval_rule( - ctx: KernelEvalContext, - ref, - *idx, - tree -): +def _get_eval_rule(ctx: KernelEvalContext, ref, *idx, tree): indexers = tree_util.tree_unflatten(tree, idx) ref_aval = ctx.avals_in[0] indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:]) @@ -1240,6 +1235,7 @@ def _slice(i, b): assert next(block_idx_iter, None) is None return ref.get(idx=tuple(block_indexer)) + @register_eval_rule(lax.concatenate_p) def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension): # We now handle the case where each of the concatenated array dimensions @@ -1525,17 +1521,16 @@ def _random_bits_pull_rule( ) return [key_block_spec] + @register_eval_rule(prng.random_wrap_p) def _random_wrap_eval_rule(eval_ctx: KernelEvalContext, arr, *, impl): del eval_ctx return jax.random.wrap_key_data(arr, impl=impl) + @register_pull_block_spec_rule(prng.random_wrap_p) def _random_wrap_pull_rule( - ctx: PullRuleContext, - block_spec: pallas_core.BlockSpec, - *, - impl + ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, *, impl ): del ctx, block_spec, impl return [pallas_core.BlockSpec(block_shape=None)] @@ -1578,13 +1573,32 @@ def _pattern_match_sublanes_to_lanes_reshape( aval_out: core.ShapedArray, ) -> bool: # Pattern matches a reshape of the form (..., n/l, l) -> (..., n * l) - # where l is a multiple of 128 n/l is a multiple of packing. + # where l is a multiple of 128. *leading_in, second_to_last_dim, last_dim = aval_in.shape *leading_out, last_dim_out = aval_out.shape if leading_in != leading_out: return False - assert last_dim_out == second_to_last_dim * last_dim + if second_to_last_dim * last_dim != last_dim_out: + return False + if last_dim % 128 != 0: + return False + return True + + +def _pattern_match_lanes_to_sublanes_reshape( + aval_in: core.ShapedArray, + aval_out: core.ShapedArray, +) -> bool: + # Pattern matches a reshape of the form (..., n * l) -> (..., n, l) + # where l is a multiple of 128. 
+ + *leading_out, last_dim_in = aval_in.shape + *leading_in, second_to_last_dim_out, last_dim = aval_out.shape + if leading_in != leading_out: + return False + if second_to_last_dim_out * last_dim != last_dim_in: + return False if last_dim % 128 != 0: return False return True @@ -1606,6 +1620,8 @@ def _reshape_pull_rule( assert isinstance(aval_in, core.ShapedArray) aval_out = ctx.avals_out[0] assert isinstance(aval_out, core.ShapedArray) + + # Handle the case where we reshape from (..., n/l, l) -> (..., n * l) if _pattern_match_sublanes_to_lanes_reshape(aval_in, aval_out): block_shape = tuple(block_spec.block_shape) if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): @@ -1625,6 +1641,44 @@ def new_index_map(*args): return *idx, 0 return [pallas_core.BlockSpec(new_block_shape, new_index_map)] + + # Handle the case where we reshape from (..., n * l) -> (..., n, l) + if _pattern_match_lanes_to_sublanes_reshape(aval_in, aval_out): + block_shape = tuple(block_spec.block_shape) + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + if not isinstance(block_shape[-2], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on sublanes: {block_shape}' + ) + last_dim = aval_out.shape[-1] + block_sublane_dim, block_lane_dim = ( + _block_size(block_shape[-2]), + _block_size(block_shape[-1]), + ) + total_block_size = block_sublane_dim * block_lane_dim + if total_block_size % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + if block_lane_dim != last_dim: + raise NotImplementedError( + 'reshape with non-matching block size on lanes not supported yet:' + f' {block_shape}' + ) + new_block_shape = block_shape[:-2] + (total_block_size,) + def new_index_map(*args): # pylint: disable=function-redefined + *idx, second_to_last, last = block_spec.index_map(*args) + # last should always be 0 + if not isinstance(last, int) and last != 0: + raise NotImplementedError( + 'Must select entire block on last dimension for reshape' + ) + return *idx, second_to_last + return [pallas_core.BlockSpec(new_block_shape, new_index_map)] + raise NotImplementedError(f'reshape not supported yet: {aval_in}, {aval_out}') @@ -1639,13 +1693,8 @@ def _reshape_eval_rule( out_shape = tuple(s for s in out_shape_nones if s is not None) # Because we have restricted the pull block spec rule, we can just apply a # basic reshape here. 
- orig_dtype = x.dtype - if jnp.issubdtype(orig_dtype, jnp.integer): - x = x.astype(jnp.int32) - elif jnp.issubdtype(orig_dtype, jnp.floating): - x = x.astype(jnp.float32) x = x.reshape(out_shape) - return x.astype(orig_dtype) + return x # Higher order primitives @@ -1667,8 +1716,10 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs): raise NotImplementedError('pjit with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) + def read_usage_env(_: core.Var): return {Usage.REGULAR} + _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, @@ -1697,8 +1748,10 @@ def _jit_pull_block_spec_rule( jaxpr, consts = jaxpr.jaxpr, jaxpr.consts if consts: raise NotImplementedError('pjit with consts not supported yet') + def read_usage_env(_: core.Var): return {Usage.REGULAR} + in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, @@ -1728,8 +1781,10 @@ def _custom_jvp_call_eval_rule( raise NotImplementedError('custom_jvp_call with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) + def read_usage_env(_: core.Var): return {Usage.REGULAR} + _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, @@ -1758,8 +1813,10 @@ def _custom_jvp_call_pull_block_spec_rule( jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts if consts: raise NotImplementedError('custom_jvp_call with consts not supported yet') + def read_usage_env(_: core.Var): return {Usage.REGULAR} + in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, @@ -2009,3 +2066,45 @@ def register_eltwise_rule(prim: core.Primitive): register_eltwise_rule(lax.rsqrt_p) register_eltwise_rule(lax.log_p) register_eltwise_rule(lax.integer_pow_p) + +@register_push_block_spec_rule(lax.reshape_p) +def _reshape_push_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + dimensions: tuple[int, ...] 
| None, + new_sizes: tuple[int, ...], + sharding: jax.sharding.Sharding, +): + del sharding, new_sizes + if dimensions is not None: + raise NotImplementedError('reshape with None dimensions not supported yet') + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + aval_out = ctx.avals_out[0] + assert isinstance(aval_out, core.ShapedArray) + if _pattern_match_lanes_to_sublanes_reshape(aval_in, aval_out): + block_shape = tuple(block_spec.block_shape) + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + last_dim = aval_out.shape[-1] + last_block_dim = _block_size(block_shape[-1]) + if last_block_dim % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + if last_block_dim % last_dim != 0: + raise NotImplementedError( + 'reshape with non-divisible block size on lanes not supported yet' + ) + num_last_dim_blocks = last_block_dim // last_dim + new_block_shape = block_shape[:1] + (num_last_dim_blocks, last_dim) + + def new_index_map(*args): + *idx, last = block_spec.index_map(*args) + return *idx, last, 0 + + return pallas_core.BlockSpec(new_block_shape, new_index_map) + raise NotImplementedError(f'reshape not supported yet: {aval_in}, {aval_out}') diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index b348ba971c38..db242fb1e400 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -771,7 +771,7 @@ def f(): x_block, ) - def test_basic_reshape(self): + def test_basic_reshape_sublanes_to_lanes(self): def f(x): return x.reshape((512, 2048)) @@ -800,6 +800,44 @@ def f(x): y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) np.testing.assert_array_equal(y, x.reshape((256, 1024))) + def test_basic_reshape_lanes_to_sublanes(self): + + def f(x): + return x.reshape((512, 32, 128)) + + in_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 8, 128), lambda i, j, k: (i, k, 0)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertEmpty(value_block_specs) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1)) + + x = jnp.arange((256 * 1024), dtype=jnp.float32).reshape((256, 1024)) + y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) + np.testing.assert_array_equal(y, x.reshape((256, 8, 128))) + + block_spec = pl.BlockSpec((256, 4, 256), lambda i, j, k: (i, j, k)) + with self.assertRaises(NotImplementedError): + _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + def test_basic_swap(self): value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 x = jnp.zeros((256, 512), dtype=jnp.int32) @@ -934,12 +972,14 @@ def f(key): self.assertEmpty(value_block_specs) self.assertEqual(key_block_spec.memory_space, pl.MemorySpace.KEY) self.assertIsNone(key_block_spec.block_shape) + @jax.jit def gen(idx): k = key for i in idx: k = jax.random.fold_in(k, 
i) return jax.random.uniform(k, (128, 256), dtype=jnp.float32) + for i in range(4): for j in range(2): out = kernel_fn((i, j), scalar_prefetch_values, (), key) @@ -1086,6 +1126,32 @@ def f(x): out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) self.assertEqual(out_block_spec.block_shape, block_spec.block_shape) + def test_push_reshape_lanes_to_sublanes(self): + def f(x): + return x.reshape((512, 32, 128)) + + x_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + block_spec = pl.BlockSpec( + (256, 1024), lambda i, j, k: (i, k) + ) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (256, 8, 128)) + self.assertTupleEqual(out_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(out_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + def f(x): + return x.reshape((512, 16, 256)) + + x_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + block_spec = pl.BlockSpec( + (256, 1024), lambda i, j, k: (i, k) + ) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (256, 4, 256)) + self.assertTupleEqual(out_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(out_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 9f7e802980fdaed415f0513e1f0ab331cb646b91 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Wed, 4 Jun 2025 17:20:25 -0700 Subject: [PATCH 1533/1769] [Rollback] Roll-forward with fix and test: prototype of cross-host device transfers for TFRT TPU. Reverts 8d8cc2bca67fc75718b73337c9ce19d6b77065e9 PiperOrigin-RevId: 767348122 --- jax/_src/dispatch.py | 57 +++++++++++++++++++++++++++++++++++--------- jax/_src/pjit.py | 6 +++-- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b5e588cbc10e..b9ef8f49f801 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -356,16 +356,6 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(x) - if inp_sharding.device_set != target_sharding.device_set: - inp_ids = [d.id for d in inp_sharding._device_assignment] - inp_plat = inp_sharding._device_assignment[0].platform.upper() - target_ids = [d.id for d in target_sharding._device_assignment] - target_plat = target_sharding._device_assignment[0].platform.upper() - raise ValueError("Input and target sharding should have the same set of " - f"devices. Got input's device set ids: {inp_ids} on " - f"platform {inp_plat} and target sharding's device set " - f"ids: {target_ids} on platform {target_plat}") - if inp_sharding.is_fully_replicated: permute_order = None else: @@ -389,6 +379,25 @@ def _reorder_shards(x, new_s, copy_semantics: CopySemantics): return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore +@util.cache() +def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding): + """Returns True if src->dst is a supported cross-host transfer.""" + backend = xla_bridge.get_backend() + # There is experimental support for cross-host device transfers on TFRT TPU + # backends only. 
+ if (xla_bridge.process_count() == 1 or backend.platform != "tpu" or + "TFRT TPU" not in backend.platform_version): + return False + if (src_sharding._to_xla_hlo_sharding(ndim) != + dst_sharding._to_xla_hlo_sharding(ndim)): + return False + # This check excludes the case where the source and destination shardings + # have the same process index sets but there are shards that require + # cross-host transfers. This case is supportable but expensive to check for. + return (src_sharding._internal_device_list.process_indices != + dst_sharding._internal_device_list.process_indices) + + @dataclasses.dataclass(frozen=True) class _DeferredShardArg: """Deferred call to `pxla.shard_args`. @@ -419,7 +428,8 @@ def _device_put_sharding_impl(x, aval, device, copy): return x if (not s.is_fully_addressable and - isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): + isinstance(x, array.ArrayImpl) and not x.is_fully_addressable and + s.device_set == x.sharding.device_set): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) @@ -430,7 +440,32 @@ def _device_put_sharding_impl(x, aval, device, copy): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) + # There is experimental support for cross-host device transfers on TFRT TPU. + if (isinstance(x, array.ArrayImpl) and x._committed + and _is_supported_cross_host_transfer(x.ndim, x.sharding, s)): + return xc.batched_copy_array_to_devices_with_sharding( + [x], [s._internal_device_list], [s], # pytype: disable=attribute-error + pxla.to_xc_copy_semantics([copy]))[0] + if not s.is_fully_addressable: + # If both the source and target shardings are not fully addressable and + # one of the above conditions has not been met, then assume that the user + # is attempting a different device order reshard. + if (isinstance(x, array.ArrayImpl) and not x.is_fully_addressable + and s.device_set != x.sharding.device_set): + inp_ids = [d.id for d in x.sharding._device_assignment] + inp_plat = x.sharding._device_assignment[0].platform.upper() + target_ids = [d.id for d in s._device_assignment] + target_plat = s._device_assignment[0].platform.upper() + raise ValueError( + "For a cross-host reshard in multi-controller JAX, input and target" + " sharding should have the same set of devices. Got input's device" + f" set ids: {inp_ids} on platform {inp_plat} and target sharding's" + f" device set ids: {target_ids} on platform {target_plat}.\n\n" + "There is experimental support for cross-host transfers with " + "different device sets, when input/output shardings have the same " + "indices and layouts, in the TFRT TPU runtime only.") + if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types or type(x) in dtypes.python_scalar_dtypes): # If all hosts participate in the sharding, assert that the input is the diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 572c2225af74..52e876400ed9 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -414,8 +414,10 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, if backend is not None or device is not None: warnings.warn( 'backend and device argument on jit is deprecated. 
You can use' - ' `jax.device_put(..., jax.local_devices("cpu")[0])` on the inputs to' - ' the jitted function to get the same behavior.', DeprecationWarning) + ' `jax.device_put(..., jax.local_devices(backend="cpu")[0])` on the' + ' inputs to the jitted function to get the same behavior.', + DeprecationWarning, + ) if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") From 7d9ba049bc78e8e08c9eb495fdb7fcce7b788592 Mon Sep 17 00:00:00 2001 From: Hui Peng Date: Wed, 4 Jun 2025 17:49:05 -0700 Subject: [PATCH 1534/1769] fix type annotation for _IndexUpdateRef.get in _src/basearray.pyi This misannotation fails type checkers Signed-off-by: Hui Peng --- jax/_src/basearray.pyi | 2 +- jax/_src/numpy/array_methods.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index a98cc012031e..d92bd61f23e2 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -286,7 +286,7 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None, - out_spec: Sharding | P | None = None) -> Array: ... + out_sharding: Sharding | P | None = None) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 958085e19a53..e635bdbbd56d 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -778,7 +778,8 @@ def __repr__(self) -> str: def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | lax.GatherScatterMode | None = None, - fill_value: ArrayLike | None = None, out_sharding: Sharding | None = None): + fill_value: ArrayLike | None = None, + out_sharding: Sharding | PartitionSpec | None = None): """Equivalent to ``x[idx]``. Returns the value of ``x`` that would result from the NumPy-style From c14907890d18ca678e904a2a4acd6c513f72d43a Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 5 Jun 2025 00:55:37 +0000 Subject: [PATCH 1535/1769] fix sharding-in-types + from_edtype Co-authored-by: Yash Katariya --- jax/_src/lax/lax.py | 34 ++++++++++++++++++++++++++++++---- tests/pjit_test.py | 15 +++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b644bdefecc2..9c9a1dd8c1ba 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4951,6 +4951,9 @@ def _to_edtype_abstract_eval(x, *, edtype): not isinstance(x.dtype, dtypes.ExtendedDType)) # For backward compatibility, if the edtype rules have a `convert_to` method, # use that rather than looking for an `allow_conversion: bool` attribute. 
+  if not isinstance(x, (ShapedArray, core.DShapedArray)):
+    raise TypeError("can only convert to an extended dtype on an array type, "
+                    f"but got {type(x)}")
   if convert_to := getattr(edtype._rules, 'convert_to', None):
     allow_conversion = convert_to(x.dtype, edtype)
   else:
@@ -4960,6 +4963,7 @@ def _to_edtype_abstract_eval(x, *, edtype):
         f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
         f"to {dtype_to_string(edtype)}")
   rep_aval = core.physical_element_aval(edtype)
+  assert tuple(rep_aval.sharding.spec) == (None,) * rep_aval.ndim
   if x.dtype != rep_aval.dtype:
     raise ValueError(
         "can only convert to extended dtype from its representation dtype, "
@@ -4982,7 +4986,20 @@ def _to_edtype_abstract_eval(x, *, edtype):
           f" has a representation shape {rep_aval.shape} while the given "
           f"representation array has shape {x.shape}, so the shape suffix "
           f"does not match: given {shape_suffix} but required {rep_aval.shape}.")
-  return x.update(shape=shape_prefix, dtype=edtype)
+  if isinstance(x, ShapedArray):
+    spec_prefix, spec_suffix = x.sharding.spec[:n], x.sharding.spec[n:]
+    if tuple(spec_suffix) != (None,) * len(spec_suffix):
+      raise ValueError(
+          "can only convert to extended dtype from an array with trailing "
+          "axes that are not explicitly sharded, but tried to convert from "
+          f"{x.str_short(short_dtypes=True)} to an extended dtype with element "
+          f"shape {rep_aval.shape}")
+    return x.update(shape=shape_prefix, dtype=edtype,
+                    sharding=x.sharding.with_spec(spec_prefix))
+  elif isinstance(x, core.DShapedArray):
+    return x.update(shape=shape_prefix, dtype=edtype)
+  else:
+    assert False  # unreachable, see isinstance check above

 to_edtype_p = Primitive('to_edtype')
 to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p))
@@ -4999,6 +5016,9 @@ def _to_edtype_abstract_eval(x, *, edtype):
 def _from_edtype_abstract_eval(x, *, dtype):
   assert (isinstance(x.dtype, dtypes.ExtendedDType) and
           not isinstance(dtype, dtypes.ExtendedDType))
+  if not isinstance(x, (ShapedArray, core.DShapedArray)):
+    raise TypeError("can only convert from an extended dtype on an array type, "
+                    f"but got {type(x)}")
   if convert_from := getattr(x.dtype._rules, 'convert_from', None):
     allow_conversion = convert_from(x.dtype, dtype)
   else:
@@ -5008,16 +5028,22 @@ def _from_edtype_abstract_eval(x, *, dtype):
         f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
         f"to {dtype_to_string(dtype)}")
   rep_aval = core.physical_element_aval(x.dtype)
+  assert tuple(rep_aval.sharding.spec) == (None,) * rep_aval.ndim
   if rep_aval.dtype != dtype:
     raise ValueError(
         "can only convert from extended dtype to its representation dtype, "
         f"but tried to convert from {dtype_to_string(x.dtype)} to "
         f"{dtype_to_string(dtype)} which doesn't match the representation type "
         f"{dtype_to_string(rep_aval.dtype)}.")
-  if all(isinstance(d, int) for d in x.shape):
-    return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype)
+  if isinstance(x, ShapedArray):
+    return x.update(shape=(*x.shape, *rep_aval.shape), dtype=dtype)
+  elif isinstance(x, core.DShapedArray):
+    if all(isinstance(d, int) for d in x.shape):
+      return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype)
+    else:
+      raise NotImplementedError
   else:
-    raise NotImplementedError
+    assert False  # unreachable, see isinstance check above

 from_edtype_p = Primitive('from_edtype')
 from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p))
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index b08d462b3485..8c9dcc05d406 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -46,6 +46,7 @@
 from jax._src.compilation_cache import is_persistent_cache_enabled
 from jax.experimental.custom_partitioning import (
     custom_partitioning, SdyShardingRule, BATCHING)
+from jax.experimental import primal_tangent_dtype
 from jax._src import array
 from jax._src.sharding import Sharding, common_devices_indices_map
 from jax._src import op_shardings
@@ -7928,6 +7929,20 @@ def simple_func(w, x):
     jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2)  # doesn't crash

+  @jtu.with_explicit_mesh((2,), ('x',))
+  def test_extended_dtypes(self, mesh):
+    dtype = primal_tangent_dtype(jnp.dtype('int8'), jnp.dtype('bfloat16'))
+
+    @jax.jit
+    def f(x):
+      x = jax.lax.convert_element_type(x, dtype)
+      self.assertEqual(x.aval.sharding.spec, P('x'))
+      x = jax.lax.convert_element_type(x, 'int8')
+      self.assertEqual(x.aval.sharding.spec, P('x'))
+
+    x = jax.device_put(jnp.arange(8, dtype='int8'), P('x',))
+    f(x)  # doesn't crash
+
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):

From c5ef4e5f18859ead5b8d2023e975dcc4659db0d7 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 5 Jun 2025 01:45:34 -0700
Subject: [PATCH 1536/1769] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/da63f2971676b6cd97a72ba52883717cf48e13d8.

PiperOrigin-RevId: 767494983
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 1fc6ccae44dc..858bd93987a1 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "a3cb8a0de31a1984a56802981ed3987f63879ce5"
-XLA_SHA256 = "26d7d752d46e0c753525feb536af5259fb2bef04432b93853cba26d330ea0dde"
+XLA_COMMIT = "da63f2971676b6cd97a72ba52883717cf48e13d8"
+XLA_SHA256 = "dcd368eba23a9ace0e8b950f4d9693abdcfb6de5ef109ee6773c450ab1c75a51"

 def repo():
     tf_http_archive(

From 986c41132710281bd3154299db62c30e25f5b6b8 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Thu, 5 Jun 2025 02:02:25 -0700
Subject: [PATCH 1537/1769] [Pallas/Mosaic GPU] Expose the new `TCGEN05_COL`
 layout.

`TCGEN05_COL` is to `TCGEN05` what `WGMMA_COL` is to `WGMMA`.
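As an illustrative sketch only (not part of this change), a Pallas Mosaic
GPU kernel could request the new layout roughly as follows, assuming the
`layout_cast` helper used for the other layouts; reducing a `TCGEN05`-laid-out
tile over rows is the natural producer of a `TCGEN05_COL` value:

    import jax.numpy as jnp
    import jax.experimental.pallas.mosaic_gpu as plgpu

    def kernel(x_ref, out_ref):
      # x_ref holds a (128, N) tile in the TCGEN05 layout; summing over
      # axis 0 leaves a 1D value laid out across columns.
      col = jnp.sum(x_ref[...], axis=0)
      # Hypothetical cast pinning the register layout to TCGEN05_COL.
      out_ref[...] = plgpu.layout_cast(col, plgpu.Layout.TCGEN05_COL)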
PiperOrigin-RevId: 767500979 --- jax/_src/pallas/mosaic_gpu/primitives.py | 4 ++++ jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/fragmented_array.py | 10 +++++++--- jax/experimental/mosaic/gpu/tcgen05.py | 7 +++++++ tests/mosaic/gpu_test.py | 14 ++++++++++++++ 5 files changed, 33 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index f37a003f4401..41c8a673b528 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1452,6 +1452,7 @@ class Layout(enum.Enum): TCGEN05 = enum.auto() TCGEN05_ROW = enum.auto() + TCGEN05_COL = enum.auto() def __call__(self, *args, **kwargs) -> ParameterizedLayout: return ParameterizedLayout(self, args, kwargs) @@ -1484,6 +1485,9 @@ def check_no_args(): case Layout.TCGEN05_ROW: check_no_args() return mgpu.TCGEN05_ROW_LAYOUT + case Layout.TCGEN05_COL: + check_no_args() + return mgpu.TCGEN05_COL_LAYOUT @dataclasses.dataclass(frozen=True) class ParameterizedLayout: diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 37bfa227fe61..f4b65c976ef5 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -106,4 +106,5 @@ from .tcgen05 import ( LAYOUT as TCGEN05_LAYOUT, # noqa: F401 ROW_LAYOUT as TCGEN05_ROW_LAYOUT, # noqa: F401 + COL_LAYOUT as TCGEN05_COL_LAYOUT, # noqa: F401 ) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 3554ed95844c..40d95f65f210 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -461,6 +461,12 @@ def replace_tiled_dim(d: int | Replicated, size: int): new_vector_dim, ) + def reduce(self, axes: Sequence[int]) -> TiledLayout: + reduced_layout = self + for a in sorted(axes, reverse=True): + reduced_layout = reduced_layout.remove_dimension(a) + return reduced_layout + def _tiled_wgmma_layout(shape: tuple[int, ...]): """Returns the tiled layout relevant for WGMMA operations. @@ -1825,9 +1831,7 @@ def reduce( out_reg = vector.extractelement(out_reg, position=c(0, index)) out_regs = np.asarray(out_reg, dtype=object) else: - reduced_layout = layout - for a in sorted(axis, reverse=True): - reduced_layout = reduced_layout.remove_dimension(a) + reduced_layout = layout.reduce(axis) out_regs = out_regs.reshape(reduced_layout.registers_shape(reduced_logical_shape)) return FragmentedArray( _registers=out_regs, _layout=reduced_layout, _is_signed=self.is_signed diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 91797fe65d1f..32f0453148f4 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -50,6 +50,13 @@ lane_dims=(-3, fa.Replicated(times=4)), vector_dim=-1 ) +# COL_LAYOUT is to LAYOUT as WGMMA_COL_LAYOUT is to WGMMA_LAYOUT. +COL_LAYOUT = fa.TiledLayout( + fa.Tiling(tiles=((8,), (8,), (8,), (2,))), + warp_dim=fa.Replicated(times=4), + lane_dims=(fa.Replicated(times=8), -2), + vector_dim=-1 +) # A layout resembling the logical organization of TMEM. The 128 rows in a tile # are assigned to 128 lanes in the warpgroup. Useful when the result needs to be # processed in registers and then stored back into TMEM. 
Should not be used if diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 31262dcae44d..a6828a31dd9f 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2289,6 +2289,20 @@ def kernel(ctx, dst, _): result = m * n if reduce_both else n np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), result, dtype)) + @parameterized.named_parameters( + ("wgmma_row", fa.WGMMA_LAYOUT, fa.WGMMA_ROW_LAYOUT, 1), + ("wgmma_col", fa.WGMMA_LAYOUT, fa.WGMMA_COL_LAYOUT, 0), + ("tcgen05_row", tcgen05.LAYOUT, tcgen05.ROW_LAYOUT, 1), + ("tcgen05_col", tcgen05.LAYOUT, tcgen05.COL_LAYOUT, 0), + ) + def test_layout_reduction_definition(self, layout, expected_reduced_layout, axis): + def squeeze_shape(shape): + return tuple(s for s in shape if s != 1) + reduced_layout = layout.reduce((axis,)) + tiled_shape = squeeze_shape(reduced_layout.tiled_tiling_shape) + expected_tiled_shape = squeeze_shape(expected_reduced_layout.tiled_tiling_shape) + self.assertEqual(tiled_shape, expected_tiled_shape) + @parameterized.product( op=(arith.addf, arith.maximumf), m=(64, 128), From af8e2e3cacf52211aed42ac195bf2d01f1cf5a51 Mon Sep 17 00:00:00 2001 From: YousefElbrolosy Date: Thu, 29 May 2025 17:15:25 +0200 Subject: [PATCH 1538/1769] doc: clarified lack of gpu support for schur and sqrtm --- jax/_src/scipy/linalg.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 7b0bb06f044c..d4971fb745cb 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -486,6 +486,8 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]: def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: """Compute the Schur decomposition + Only implemented on CPU. + JAX implementation of :func:`scipy.linalg.schur`. The Schur form `T` of a matrix `A` satisfies: @@ -1832,6 +1834,9 @@ def _sqrtm(A: ArrayLike) -> Array: def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: """Compute the matrix square root + This function is implemented using :func:`scipy.linalg.schur`, which is only + supported on CPU. + JAX implementation of :func:`scipy.linalg.sqrtm`. Args: From b7a3250f9df442f6134bc76263f5c9bd28f8e336 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 5 Jun 2025 07:27:20 -0700 Subject: [PATCH 1539/1769] Don't repeatedly recompute a tuple of axis names for a membership test. PiperOrigin-RevId: 767596548 --- jax/_src/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7260557ff4cc..4bf06336697e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2023,9 +2023,9 @@ def str_short_aval(shape, dtype, mesh, spec, vma, def get_vma(vma, mesh): if mesh.empty: return vma - axis_env_names = get_axis_env().axis_names() + axis_env = get_axis_env() for i in vma: - if i in axis_env_names and i not in mesh._name_to_type: + if axis_env.axis_exists(i) and i not in mesh._name_to_type: continue if mesh._name_to_type[i] != AxisType.Manual: raise ValueError( From 59c0171c6a2a0c927f9248f071895082c693dffe Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 5 Jun 2025 08:06:24 -0700 Subject: [PATCH 1540/1769] [Mosaic GPU] Move `should_have_transforms` to `inference_utils`. 
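After the move, passes outside transform_inference.py can reuse the same
guard; a minimal sketch (the surrounding pass is hypothetical, the guard
mirrors the one in `infer_transforms` below):

    from jax.experimental.mosaic.gpu import inference_utils

    def annotate_pass(op):
      # Skip ops whose operands/results include no transformable SMEM memrefs.
      if not inference_utils.should_have_transforms(op):
        return
      # (rule lookup and annotation would go here)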
PiperOrigin-RevId: 767610707 --- jax/experimental/mosaic/gpu/inference_utils.py | 9 +++++++++ jax/experimental/mosaic/gpu/transform_inference.py | 13 +------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py index ff3dbf665681..2641b76da8fc 100644 --- a/jax/experimental/mosaic/gpu/inference_utils.py +++ b/jax/experimental/mosaic/gpu/inference_utils.py @@ -135,6 +135,15 @@ def _in_attr_for_operand( _in_attr_for_operand, attr_name="in_transforms" ) +def should_have_transforms(op: ir.OpView) -> bool: + """Returns 'True' if the operation should be assigned in/out transforms.""" + return any( + map( + is_transformable_smem_memref, + itertools.chain(op.operands, op.results), + ) + ) + def is_transformable_smem_memref(v: ir.Value) -> bool: """Whether the value is a memref in SMEM on which transforms should be applied.""" barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 1d97b3f0fa63..e6a4e5bd1cff 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -20,7 +20,6 @@ from collections.abc import Callable from functools import partial -import itertools from typing import cast from jax._src.lib import mosaic_gpu_dialect as mgpu @@ -375,16 +374,6 @@ def _infer_memref_cast_transforms( return [transforms], [transforms] -def _should_have_transforms(op: ir.OpView) -> bool: - """Returns 'True' if the operation should be assigned in/out transforms.""" - return any( - map( - inference_utils.is_transformable_smem_memref, - itertools.chain(op.operands, op.results), - ) - ) - - def infer_transforms(module: ir.Module): """Infers transforms for the given module. @@ -398,7 +387,7 @@ def infer_transforms(module: ir.Module): annotate the same memref. """ def inference_step(op: ir.Operation): - if not _should_have_transforms(op): + if not inference_utils.should_have_transforms(op): return elif inference_rule := _transform_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error pass From f8ab209c251b73c6ee518c2eec37154f682efa44 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 5 Jun 2025 08:08:41 -0700 Subject: [PATCH 1541/1769] [Pallas][Mosaic GPU] Use separate allocations for collective TMEM. The distinction between collective/non-collective TMEM is done at allocation time so we need to allocate two separate blocks. PiperOrigin-RevId: 767611353 --- jax/_src/pallas/mosaic_gpu/lowering.py | 80 ++++++++++++++++++++++---- tests/pallas/mosaic_gpu_test.py | 12 +++- 2 files changed, 77 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 87bb85cfcd70..f34b1e649c2d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -103,6 +103,7 @@ def arrival_multiplier(self) -> int: class Resources: smem_scratch_bytes: int = 0 tmem_scratch_cols: int = 0 + tmem_collective_scratch_cols: int = 0 barrier_counts: collections.Counter[AnyBarrier] = dataclasses.field( default_factory=collections.Counter ) @@ -114,12 +115,18 @@ def __post_init__(self): "smem_scratch_bytes", gpu_core.align_to(self.smem_scratch_bytes, gpu_core.SMEM_ALIGNMENT), ) + + # TMEM must be allocated in 128x8 chunks. object.__setattr__( self, "tmem_scratch_cols", - # TMEM must be allocated in 128x8 chunks. 
gpu_core.align_to(self.tmem_scratch_cols, 8), ) + object.__setattr__( + self, + "tmem_collective_scratch_cols", + gpu_core.align_to(self.tmem_collective_scratch_cols, 8), + ) @property def barriers(self) -> Sequence[AnyBarrier]: @@ -133,6 +140,8 @@ def __add__(self, other: Resources) -> Resources: return Resources( smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols, + tmem_collective_scratch_cols=self.tmem_collective_scratch_cols + + other.tmem_collective_scratch_cols, barrier_counts=self.barrier_counts + other.barrier_counts, gmem_semaphores=self.gmem_semaphores + other.gmem_semaphores, ) @@ -142,8 +151,10 @@ def __or__(self, other: Resources) -> Resources: smem_scratch_bytes=max( self.smem_scratch_bytes, other.smem_scratch_bytes ), - tmem_scratch_cols=max( - self.tmem_scratch_cols, other.tmem_scratch_cols + tmem_scratch_cols=max(self.tmem_scratch_cols, other.tmem_scratch_cols), + tmem_collective_scratch_cols=max( + self.tmem_collective_scratch_cols, + other.tmem_collective_scratch_cols, ), barrier_counts=self.barrier_counts | other.barrier_counts, gmem_semaphores=max(self.gmem_semaphores, other.gmem_semaphores), @@ -303,7 +314,10 @@ def _run_scoped_resource_estimator( layout = tcgen05._infer_tmem_layout(aval.shape, packing=packing) cols_used = layout.cols_in_shape(aval.shape) cols_used = tcgen05._alloc_ncols(cols_used, exact=False) - rs += Resources(tmem_scratch_cols=cols_used) + if aval.collective: + rs += Resources(tmem_collective_scratch_cols=cols_used) + else: + rs += Resources(tmem_scratch_cols=cols_used) elif aval.memory_space == gpu_core.SMEM: rs += Resources( smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize @@ -359,6 +373,9 @@ class ModuleContext: tmem_requested_cols: int tmem_used_cols: int tmem_base_ptr: ir.Value + tmem_collective_requested_cols: int + tmem_collective_used_cols: int + tmem_collective_base_ptr: ir.Value gmem_used_semaphores: int gmem_semaphore_base_ptr: ir.Value | None runtime_barriers: MutableMapping[AnyBarrier, MutableSequence[AnyBarrierRef]] @@ -435,17 +452,28 @@ def alloc_tmem( layout = tcgen05._infer_tmem_layout(struct.shape, packing=packing) unpadded_cols_used = layout.cols_in_shape(struct.shape) cols_used = tcgen05._alloc_ncols(unpadded_cols_used, exact_cols) - - off = arith_dialect.addi(self.tmem_base_ptr, - _i32_constant(self.tmem_used_cols)) + if collective: + off = arith_dialect.addi( + self.tmem_collective_base_ptr, + _i32_constant(self.tmem_collective_used_cols), + ) + else: + off = arith_dialect.addi( + self.tmem_base_ptr, _i32_constant(self.tmem_used_cols) + ) tmem_ref = tcgen05.TMEMRef( address=off, shape=struct.shape, dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), layout=layout) - self.tmem_used_cols += cols_used - yield tmem_ref - self.tmem_used_cols -= cols_used + if collective: + self.tmem_collective_used_cols += cols_used + yield tmem_ref + self.tmem_collective_used_cols -= cols_used + else: + self.tmem_used_cols += cols_used + yield tmem_ref + self.tmem_used_cols -= cols_used # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. 
@contextlib.contextmanager @@ -808,7 +836,12 @@ def lower_jaxpr_to_module( ) def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): - *buffers_gmem, (runtime_smem, runtime_barriers, runtime_tmem) = buffers + *buffers_gmem, ( + runtime_smem, + runtime_barriers, + runtime_tmem, + runtime_tmem_collective, + ) = buffers gmem_semaphores = None if rs.gmem_semaphores: # Extract the semaphores local to the current block. @@ -833,6 +866,12 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS else: tmem_cols = 0 + if runtime_tmem_collective is not None: + tmem_collective_cols = ( + math.prod(runtime_tmem_collective.shape) // tcgen05.TMEM_ROWS + ) + else: + tmem_collective_cols = 0 if lowering_semantics == mgpu.LoweringSemantics.Lane: single_wg_lane_predicate = mgpu.single_thread_predicate( @@ -855,6 +894,11 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): tmem_requested_cols=tmem_cols, tmem_used_cols=0, tmem_base_ptr=runtime_tmem.address if runtime_tmem else None, + tmem_collective_requested_cols=tmem_collective_cols, + tmem_collective_used_cols=0, + tmem_collective_base_ptr=runtime_tmem_collective.address + if runtime_tmem_collective + else None, gmem_used_semaphores=0, gmem_semaphore_base_ptr=gmem_semaphores, runtime_barriers=grouped_barriers, @@ -878,7 +922,19 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): if rs.tmem_scratch_cols > 0: scratch_buffers.append( mgpu.TMEM( - shape=[tcgen05.TMEM_ROWS, rs.tmem_scratch_cols], dtype=np.int32 + shape=[tcgen05.TMEM_ROWS, rs.tmem_scratch_cols], + dtype=np.int32, + collective=False, + ), + ) + else: + scratch_buffers.append(None) + if rs.tmem_collective_scratch_cols > 0: + scratch_buffers.append( + mgpu.TMEM( + shape=[tcgen05.TMEM_ROWS, rs.tmem_collective_scratch_cols], + dtype=np.int32, + collective=True, ), ) else: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 1170b2ac5cdb..04c3db0b30c9 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2417,7 +2417,11 @@ class PallasCallSm90AWGTest( class PallasCallSm100ATest(PallasSm100ATest): - def test_tmem(self): + @parameterized.parameters( + (False,), + (True,), + ) + def test_tmem(self, collective): self.skip_if_wg_semantics() # TMEM read not wired up in the WG get rule. swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize transforms = ( @@ -2428,13 +2432,15 @@ def test_tmem(self): self.kernel, out_shape=jnp.zeros((128, 128), jnp.float32), scratch_shapes=[ - plgpu.TMEM((128, 128), jnp.float32), - plgpu.TMEM((128, 128), jnp.float32), + plgpu.TMEM((128, 128), jnp.float32, collective=collective), + plgpu.TMEM((128, 128), jnp.float32, collective=collective), plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), plgpu.Barrier(), ], num_threads=1, thread_name="x", + cluster=(2,) if collective else (), + cluster_names=("x",) if collective else (), ) def kernel(x_ref, y_ref, tmem_ref, tmem_ref2, smem_ref, barrier_ref): plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) From 1c4bb503bbf01e9bf4cd7605565dabecf2cb7974 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 5 Jun 2025 08:11:16 -0700 Subject: [PATCH 1542/1769] Cache get_vma because it's the same thing we do for `get_sharding` and there's no need to do the same calculation if the keys don't change. 
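The effect is ordinary memoization on hashable arguments; a rough standalone
sketch of the idea (illustrative only: the real decorator is
jax._src.util.cache with trace_context_in_key=False, and `vma` is assumed
here to be a frozenset of axis names with `mesh` hashable):

    import functools

    @functools.lru_cache(maxsize=4096)
    def cached_get_vma(vma: frozenset, mesh) -> frozenset:
      # Stands in for get_vma's validation loop; it runs only for
      # (vma, mesh) argument pairs that have not been seen before.
      return frozenset(vma)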
PiperOrigin-RevId: 767612102 --- jax/_src/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index 4bf06336697e..49579bc3de10 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2020,6 +2020,7 @@ def str_short_aval(shape, dtype, mesh, spec, vma, vma_ur = _vma_ur_str(vma, spec.unreduced) return f'{dt_str}[{shapestr}]{vma_ur}{mesh_axes}' +@cache(max_size=4096, trace_context_in_key=False) def get_vma(vma, mesh): if mesh.empty: return vma From 0b899a990e865e54954cbae09ba2b963afcc2e08 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 5 Jun 2025 08:13:55 -0700 Subject: [PATCH 1543/1769] lax_numpy: move array and asarray to their own submodule This leads to a better dependency graph and avoids local imports. --- jax/_src/numpy/array.py | 375 +++++++++++++++++++++++++++++++ jax/_src/numpy/array_creation.py | 7 +- jax/_src/numpy/index_tricks.py | 3 +- jax/_src/numpy/lax_numpy.py | 347 +--------------------------- jax/_src/numpy/scalar_types.py | 2 +- jax/numpy/__init__.py | 7 +- 6 files changed, 386 insertions(+), 355 deletions(-) create mode 100644 jax/_src/numpy/array.py diff --git a/jax/_src/numpy/array.py b/jax/_src/numpy/array.py new file mode 100644 index 000000000000..42d6132efb49 --- /dev/null +++ b/jax/_src/numpy/array.py @@ -0,0 +1,375 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import Any + +import numpy as np + +from jax._src import api +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import tree_util +from jax._src import xla_bridge +from jax._src.lax import lax +from jax._src.lib import xla_client as xc +from jax._src.numpy import util +from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.sharding import Sharding + + +export = util.set_module('jax.numpy') + +for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: + try: + cuda_plugin_extension = importlib.import_module( + f'{pkg_name}.cuda_plugin_extension' + ) + except ImportError: + cuda_plugin_extension = None # type: ignore + else: + break + + +def _supports_buffer_protocol(obj): + try: + view = memoryview(obj) + except TypeError: + return False + else: + return True + + +def _make_string_array( + object: np.ndarray, + dtype: DTypeLike | None = None, + ndmin: int = 0, + device: xc.Device | Sharding | None = None, +) -> Array: + if not isinstance(object, np.ndarray): + raise TypeError( + "Currently, string arrays can only be made from NumPy" + f" arrays. Got: {type(object)}." + ) + if dtype is not None and ( + dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype) + ): + raise TypeError( + f"Cannot make an array with dtype {dtype} from an object with dtype" + f" {object.dtype}." + ) + if ndmin > object.ndim: + raise TypeError( + f"ndmin {ndmin} cannot be greater than object's ndims" + f" {object.ndim} for string arrays." + ) + + # Just do a device_put since XLA does not support string as a data type. 
+ return api.device_put(x=object, device=device) + + +@export +def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, + order: str | None = "K", ndmin: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.array`. + + Args: + object: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + copy: specify whether to force a copy of the input. Default: True. + order: not implemented in JAX + ndmin: integer specifying the minimum number of dimensions in the + output array. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.asarray`: like `array`, but by default only copies + when necessary. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.array(True) + Array(True, dtype=bool) + >>> jnp.array(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.array(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.array(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.array([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.array(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.array(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.array(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ + if order is not None and order != "K": + raise NotImplementedError("Only implemented for order='K'") + + # check if the given dtype is compatible with JAX + dtypes.check_user_dtype_supported(dtype, "array") + + # Here we make a judgment call: we only return a weakly-typed array when the + # input object itself is weakly typed. That ensures asarray(x) is a no-op + # whenever x is weak, but avoids introducing weak types with something like + # array([1, 2, 3]) + weak_type = dtype is None and dtypes.is_weakly_typed(object) + if device is None and isinstance(object, core.Tracer): + sharding = object.aval.sharding + sharding = None if sharding.mesh.empty else sharding + else: + sharding = util.canonicalize_device_to_sharding(device) + + # Use device_put to avoid a copy for ndarray inputs. + if (not copy and isinstance(object, np.ndarray) and + (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and + device is None): + # Keep the output uncommitted. + return api.device_put(object) + + # String arrays need separate handling because XLA does not support string + # as a data type. 
+ if dtypes.is_string_dtype(dtype) or ( + hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype) + ): + return _make_string_array( + object=object, dtype=dtype, ndmin=ndmin, device=device + ) + + # For Python scalar literals, call coerce_to_array to catch any overflow + # errors. We don't use dtypes.is_python_scalar because we don't want this + # triggering for traced values. We do this here because it matters whether or + # not dtype is None. We don't assign the result because we want the raw object + # to be used for type inference below. + if isinstance(object, (bool, int, float, complex)): + _ = dtypes.coerce_to_array(object, dtype) + elif not isinstance(object, Array): + # Check if object supports any of the data exchange protocols + # (except dlpack, see data-apis/array-api#301). If it does, + # consume the object as jax array and continue (but not return) so + # that other array() arguments get processed against the input + # object. + # + # Notice that data exchange protocols define dtype in the + # corresponding data structures and it may not be available as + # object.dtype. So, we'll resolve the protocols here before + # evaluating object.dtype. + if hasattr(object, '__jax_array__'): + object = object.__jax_array__() + elif hasattr(object, '__cuda_array_interface__'): + cai = object.__cuda_array_interface__ + backend = xla_bridge.get_backend("cuda") + if cuda_plugin_extension is None: + device_id = None + else: + device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0]) + object = xc._xla.cuda_array_interface_to_buffer( + cai=cai, gpu_backend=backend, device_id=device_id) + + leaves, treedef = tree_util.tree_flatten(object, is_leaf=lambda x: x is None) + if any(leaf is None for leaf in leaves): + raise ValueError("None is not a valid value for jnp.array") + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] + if dtype is None: + # Use lattice_result_type rather than result_type to avoid canonicalization. + # Otherwise, weakly-typed inputs would have their dtypes canonicalized. + try: + dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_ + except TypeError: + # This happens if, e.g. one of the entries is a memoryview object. + # This is rare, so we only handle it if the normal path fails. + leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves] + dtype = dtypes._lattice_result_type(*leaves)[0] + + if not weak_type: + dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + + object = treedef.unflatten(leaves) + out: ArrayLike + if all(not isinstance(leaf, Array) for leaf in leaves): + # TODO(jakevdp): falling back to numpy here fails to overflow for lists + # containing large integers; see discussion in + # https://github.com/jax-ml/jax/pull/6047. More correct would be to call + # coerce_to_array on each leaf, but this may have performance implications. + out = np.asarray(object, dtype=dtype) + elif isinstance(object, Array): + assert object.aval is not None + out = lax._array_copy(object) if copy else object + elif isinstance(object, (list, tuple)): + if object: + arrs = (array(elt, dtype=dtype, copy=False) for elt in object) + out = lax.concatenate([lax.expand_dims(arr, [0]) for arr in arrs], 0) + else: + out = np.array([], dtype=dtype) + elif _supports_buffer_protocol(object): + object = memoryview(object) + # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg. 
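+    # Note: np.array always copies the memoryview's contents, while
+    # np.asarray wraps the buffer without copying when possible.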
+ out = np.array(object) if copy else np.asarray(object) + else: + raise TypeError(f"Unexpected input type for array: {type(object)}") + out_array: Array = lax._convert_element_type( + out, dtype, weak_type=weak_type, sharding=sharding) + if ndmin > np.ndim(out_array): + out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array))) + return out_array + + +def _get_platform( + device_or_sharding: xc.Device | Sharding | None | str) -> str: + """Get device_or_sharding platform or look up config.default_device.value.""" + if isinstance(device_or_sharding, xc.Device): + return device_or_sharding.platform + elif isinstance(device_or_sharding, Sharding): + return list(device_or_sharding.device_set)[0].platform + elif isinstance(device_or_sharding, str): + return device_or_sharding + elif device_or_sharding is None: + if config.default_device.value is None: + return xla_bridge.default_backend() + else: + return _get_platform(config.default_device.value) + else: + raise ValueError(f"`{device_or_sharding = }` was passed to" + "`canonicalize_or_get_default_platform`, only xc.Device," + " Sharding, None or str values are supported.") + + +def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: + try: + dtypes.dtype(x) + except TypeError: + return np.asarray(x) + else: + return x + + +@export +def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, + *, copy: bool | None = None, + device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.asarray`. + + Args: + a: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + order: not implemented in JAX + copy: optional boolean specifying the copy mode. If True, then always + return a copy. If False, then error if a copy is necessary. Default is + None, which will only copy when necessary. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.asarray(True) + Array(True, dtype=bool) + >>> jnp.asarray(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.asarray(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.asarray(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.asarray(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.asarray(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. 
+ + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.asarray(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ + # For copy=False, the array API specifies that we raise a ValueError if the input supports + # the buffer protocol but a copy is required. Since array() supports the buffer protocol + # via numpy, this is only the case when the default device is not 'cpu' + if (copy is False and not isinstance(a, Array) + and _get_platform(device) != "cpu" + and _supports_buffer_protocol(a)): + raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array " + f"on platform={_get_platform(device)} with " + "copy=False. Consider using copy=None or copy=True instead.") + dtypes.check_user_dtype_supported(dtype, "asarray") + if dtype is not None: + dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py index b14c2fe73faa..63ef76c01b69 100644 --- a/jax/_src/numpy/array_creation.py +++ b/jax/_src/numpy/array_creation.py @@ -24,6 +24,7 @@ from jax._src import dtypes from jax._src.lax import lax from jax._src.lib import xla_client as xc +from jax._src.numpy.array import asarray from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.sharding import Sharding @@ -203,8 +204,6 @@ def full(shape: Any, fill_value: ArrayLike, Array([[0, 1, 2], [0, 1, 2]], dtype=int32) """ - from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error - dtypes.check_user_dtype_supported(dtype, "full") util.check_arraylike("full", fill_value) @@ -394,8 +393,6 @@ def full_like(a: ArrayLike | DuckTypedArray, Array([[1, 1, 1], [2, 2, 2]], dtype=int32) """ - from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error - if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing util.check_arraylike("full_like", 0, fill_value) else: @@ -512,8 +509,6 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: """Implementation of linspace differentiable in start and stop args.""" - from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error - dtypes.check_user_dtype_supported(dtype, "linspace") if num < 0: raise ValueError(f"Number of samples, {num}, must be non-negative.") diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index ab07ecad0cf5..8b70c37192c2 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -21,9 +21,10 @@ from jax._src import config from jax._src import core +from jax._src.numpy.array import array from jax._src.numpy.util import promote_dtypes from jax._src.numpy.lax_numpy import ( - arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose + arange, concatenate, expand_dims, linspace, meshgrid, stack, transpose ) from jax._src.typing import Array, ArrayLike from jax._src.util import set_module diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 86fb20f76052..d823688ee674 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -28,7 +28,6 @@ import builtins from collections.abc import Callable, Sequence from functools import partial -import importlib import math import operator import os @@ -42,13 +41,13 @@ from jax._src import core from jax._src import deprecations 
from jax._src import dtypes -from jax._src import xla_bridge from jax._src.api_util import _ensure_index_tuple from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax as lax_internal from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc +from jax._src.numpy.array import array, asarray from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions @@ -68,21 +67,11 @@ from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.mesh import get_abstract_mesh from jax._src.pjit import auto_axes -from jax.tree_util import tree_flatten, tree_map +from jax.tree_util import tree_map import numpy as np export = set_module('jax.numpy') -for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: - try: - cuda_plugin_extension = importlib.import_module( - f'{pkg_name}.cuda_plugin_extension' - ) - except ImportError: - cuda_plugin_extension = None # type: ignore - else: - break - T = TypeVar('T') # Wrappers for NumPy printoptions @@ -5320,256 +5309,6 @@ def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: return [atleast_3d(arr) for arr in arys] -def _supports_buffer_protocol(obj): - try: - view = memoryview(obj) - except TypeError: - return False - else: - return True - - -def _make_string_array( - object: np.ndarray, - dtype: DTypeLike | None = None, - ndmin: int = 0, - device: xc.Device | Sharding | None = None, -) -> Array: - if not isinstance(object, np.ndarray): - raise TypeError( - "Currently, string arrays can only be made from NumPy" - f" arrays. Got: {type(object)}." - ) - if dtype is not None and ( - dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype) - ): - raise TypeError( - f"Cannot make an array with dtype {dtype} from an object with dtype" - f" {object.dtype}." - ) - if ndmin > object.ndim: - raise TypeError( - f"ndmin {ndmin} cannot be greater than object's ndims" - f" {object.ndim} for string arrays." - ) - - # Just do a device_put since XLA does not support string as a data type. - return api.device_put(x=object, device=device) - - -@export -def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, - order: str | None = "K", ndmin: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array: - """Convert an object to a JAX array. - - JAX implementation of :func:`numpy.array`. - - Args: - object: an object that is convertible to an array. This includes JAX - arrays, NumPy arrays, Python scalars, Python collections like lists - and tuples, objects with an ``__array__`` method, and objects - supporting the Python buffer protocol. - dtype: optionally specify the dtype of the output array. If not - specified it will be inferred from the input. - copy: specify whether to force a copy of the input. Default: True. - order: not implemented in JAX - ndmin: integer specifying the minimum number of dimensions in the - output array. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - A JAX array constructed from the input. - - See also: - - :func:`jax.numpy.asarray`: like `array`, but by default only copies - when necessary. - - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object - that implements the dlpack interface. - - :func:`jax.numpy.frombuffer`: construct a JAX array from an object - that implements the buffer interface. 
- - Examples: - Constructing JAX arrays from Python scalars: - - >>> jnp.array(True) - Array(True, dtype=bool) - >>> jnp.array(42) - Array(42, dtype=int32, weak_type=True) - >>> jnp.array(3.5) - Array(3.5, dtype=float32, weak_type=True) - >>> jnp.array(1 + 1j) - Array(1.+1.j, dtype=complex64, weak_type=True) - - Constructing JAX arrays from Python collections: - - >>> jnp.array([1, 2, 3]) # list of ints -> 1D array - Array([1, 2, 3], dtype=int32) - >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array - Array([[1, 2, 3], - [4, 5, 6]], dtype=int32) - >>> jnp.array(range(5)) - Array([0, 1, 2, 3, 4], dtype=int32) - - Constructing JAX arrays from NumPy arrays: - - >>> jnp.array(np.linspace(0, 2, 5)) - Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) - - Constructing a JAX array via the Python buffer interface, using Python's - built-in :mod:`array` module. - - >>> from array import array - >>> pybuffer = array('i', [2, 3, 5, 7]) - >>> jnp.array(pybuffer) - Array([2, 3, 5, 7], dtype=int32) - """ - if order is not None and order != "K": - raise NotImplementedError("Only implemented for order='K'") - - # check if the given dtype is compatible with JAX - dtypes.check_user_dtype_supported(dtype, "array") - - # Here we make a judgment call: we only return a weakly-typed array when the - # input object itself is weakly typed. That ensures asarray(x) is a no-op - # whenever x is weak, but avoids introducing weak types with something like - # array([1, 2, 3]) - weak_type = dtype is None and dtypes.is_weakly_typed(object) - if device is None and isinstance(object, core.Tracer): - sharding = object.aval.sharding - sharding = None if sharding.mesh.empty else sharding - else: - sharding = util.canonicalize_device_to_sharding(device) - - # Use device_put to avoid a copy for ndarray inputs. - if (not copy and isinstance(object, np.ndarray) and - (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and - device is None): - # Keep the output uncommitted. - return api.device_put(object) - - # String arrays need separate handling because XLA does not support string - # as a data type. - if dtypes.is_string_dtype(dtype) or ( - hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype) - ): - return _make_string_array( - object=object, dtype=dtype, ndmin=ndmin, device=device - ) - - # For Python scalar literals, call coerce_to_array to catch any overflow - # errors. We don't use dtypes.is_python_scalar because we don't want this - # triggering for traced values. We do this here because it matters whether or - # not dtype is None. We don't assign the result because we want the raw object - # to be used for type inference below. - if isinstance(object, (bool, int, float, complex)): - _ = dtypes.coerce_to_array(object, dtype) - elif not isinstance(object, Array): - # Check if object supports any of the data exchange protocols - # (except dlpack, see data-apis/array-api#301). If it does, - # consume the object as jax array and continue (but not return) so - # that other array() arguments get processed against the input - # object. - # - # Notice that data exchange protocols define dtype in the - # corresponding data structures and it may not be available as - # object.dtype. So, we'll resolve the protocols here before - # evaluating object.dtype. 
- if hasattr(object, '__jax_array__'): - object = object.__jax_array__() - elif hasattr(object, '__cuda_array_interface__'): - cai = object.__cuda_array_interface__ - backend = xla_bridge.get_backend("cuda") - if cuda_plugin_extension is None: - device_id = None - else: - device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0]) - object = xc._xla.cuda_array_interface_to_buffer( - cai=cai, gpu_backend=backend, device_id=device_id) - - leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) - if any(leaf is None for leaf in leaves): - raise ValueError("None is not a valid value for jnp.array") - leaves = [ - leaf - if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None - else leaf_jax_array() - for leaf in leaves - ] - if dtype is None: - # Use lattice_result_type rather than result_type to avoid canonicalization. - # Otherwise, weakly-typed inputs would have their dtypes canonicalized. - try: - dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_ - except TypeError: - # This happens if, e.g. one of the entries is a memoryview object. - # This is rare, so we only handle it if the normal path fails. - leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves] - dtype = dtypes._lattice_result_type(*leaves)[0] - - if not weak_type: - dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - - object = treedef.unflatten(leaves) - out: ArrayLike - if all(not isinstance(leaf, Array) for leaf in leaves): - # TODO(jakevdp): falling back to numpy here fails to overflow for lists - # containing large integers; see discussion in - # https://github.com/jax-ml/jax/pull/6047. More correct would be to call - # coerce_to_array on each leaf, but this may have performance implications. - out = np.asarray(object, dtype=dtype) - elif isinstance(object, Array): - assert object.aval is not None - out = _array_copy(object) if copy else object - elif isinstance(object, (list, tuple)): - if object: - arrs = (array(elt, dtype=dtype, copy=False) for elt in object) - out = lax.concatenate([lax.expand_dims(arr, [0]) for arr in arrs], 0) - else: - out = np.array([], dtype=dtype) - elif _supports_buffer_protocol(object): - object = memoryview(object) - # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg. 
- out = np.array(object) if copy else np.asarray(object) - else: - raise TypeError(f"Unexpected input type for array: {type(object)}") - out_array: Array = lax_internal._convert_element_type( - out, dtype, weak_type=weak_type, sharding=sharding) - if ndmin > np.ndim(out_array): - out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array))) - return out_array - - -def _get_platform( - device_or_sharding: xc.Device | Sharding | None | str) -> str: - """Get device_or_sharding platform or look up config.default_device.value.""" - if isinstance(device_or_sharding, xc.Device): - return device_or_sharding.platform - elif isinstance(device_or_sharding, Sharding): - return list(device_or_sharding.device_set)[0].platform - elif isinstance(device_or_sharding, str): - return device_or_sharding - elif device_or_sharding is None: - if config.default_device.value is None: - return xla_bridge.default_backend() - else: - return _get_platform(config.default_device.value) - else: - raise ValueError(f"`{device_or_sharding = }` was passed to" - "`canonicalize_or_get_default_platform`, only xc.Device," - " Sharding, None or str values are supported.") - - -def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: - try: - dtypes.dtype(x) - except TypeError: - return np.asarray(x) - else: - return x - - @export def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, @@ -5634,88 +5373,6 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result -@export -def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, - *, copy: bool | None = None, - device: xc.Device | Sharding | None = None) -> Array: - """Convert an object to a JAX array. - - JAX implementation of :func:`numpy.asarray`. - - Args: - a: an object that is convertible to an array. This includes JAX - arrays, NumPy arrays, Python scalars, Python collections like lists - and tuples, objects with an ``__array__`` method, and objects - supporting the Python buffer protocol. - dtype: optionally specify the dtype of the output array. If not - specified it will be inferred from the input. - order: not implemented in JAX - copy: optional boolean specifying the copy mode. If True, then always - return a copy. If False, then error if a copy is necessary. Default is - None, which will only copy when necessary. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - A JAX array constructed from the input. - - See also: - - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. - - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object - that implements the dlpack interface. - - :func:`jax.numpy.frombuffer`: construct a JAX array from an object - that implements the buffer interface. 
- - Examples: - Constructing JAX arrays from Python scalars: - - >>> jnp.asarray(True) - Array(True, dtype=bool) - >>> jnp.asarray(42) - Array(42, dtype=int32, weak_type=True) - >>> jnp.asarray(3.5) - Array(3.5, dtype=float32, weak_type=True) - >>> jnp.asarray(1 + 1j) - Array(1.+1.j, dtype=complex64, weak_type=True) - - Constructing JAX arrays from Python collections: - - >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array - Array([1, 2, 3], dtype=int32) - >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array - Array([[1, 2, 3], - [4, 5, 6]], dtype=int32) - >>> jnp.asarray(range(5)) - Array([0, 1, 2, 3, 4], dtype=int32) - - Constructing JAX arrays from NumPy arrays: - - >>> jnp.asarray(np.linspace(0, 2, 5)) - Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) - - Constructing a JAX array via the Python buffer interface, using Python's - built-in :mod:`array` module. - - >>> from array import array - >>> pybuffer = array('i', [2, 3, 5, 7]) - >>> jnp.asarray(pybuffer) - Array([2, 3, 5, 7], dtype=int32) - """ - # For copy=False, the array API specifies that we raise a ValueError if the input supports - # the buffer protocol but a copy is required. Since array() supports the buffer protocol - # via numpy, this is only the case when the default device is not 'cpu' - if (copy is False and not isinstance(a, Array) - and _get_platform(device) != "cpu" - and _supports_buffer_protocol(a)): - raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array " - f"on platform={_get_platform(device)} with " - "copy=False. Consider using copy=None or copy=True instead.") - dtypes.check_user_dtype_supported(dtype, "asarray") - if dtype is not None: - dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) - - @export def copy(a: ArrayLike, order: str | None = None) -> Array: """Return a copy of the array. diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 1abe7cf66c15..442df38a9641 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -27,6 +27,7 @@ from jax._src.typing import Array from jax._src import core from jax._src import dtypes +from jax._src.numpy.array import asarray # Some objects below rewrite their __module__ attribute to this name. 
@@ -46,7 +47,6 @@ def __ne__(self, other: Any) -> bool: return not (self == other) def __call__(self, x: Any) -> Array: - from jax._src.numpy.lax_numpy import asarray return asarray(x, dtype=self.dtype) def __instancecheck__(self, instance: Any) -> bool: diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 935fbcaa708c..24a0ca907567 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -24,6 +24,11 @@ isdtype as isdtype, ) +from jax._src.numpy.array import ( + array as array, + asarray as asarray, +) + from jax._src.numpy.lax_numpy import ( ComplexWarning as ComplexWarning, allclose as allclose, @@ -36,12 +41,10 @@ argmin as argmin, argwhere as argwhere, around as around, - array as array, array_equal as array_equal, array_equiv as array_equiv, array_split as array_split, astype as astype, - asarray as asarray, atleast_1d as atleast_1d, atleast_2d as atleast_2d, atleast_3d as atleast_3d, From ef9d3f880d99c83cdb1ab707c830e97da4a21c16 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 5 Jun 2025 11:14:54 -0700 Subject: [PATCH 1544/1769] skip pytype on slow file PiperOrigin-RevId: 767687785 --- jax/_src/interpreters/partial_eval.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a4f4fcd12429..8c165df46e12 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# pytype: skip-file from __future__ import annotations from collections import namedtuple From 9fc670ec7451ebbe13973235dc76e2dbbe03f5c3 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 5 Jun 2025 17:20:52 +0000 Subject: [PATCH 1545/1769] [cleanup] remove core.gensym, and Var.suffix Co-authored-by: Dougal Maclaurin --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/core.py | 17 +++++++---------- jax/_src/interpreters/batching.py | 6 +++--- jax/_src/interpreters/partial_eval.py | 13 +++++-------- jax/_src/lax/control_flow/common.py | 5 ++--- jax/_src/lax/control_flow/conditionals.py | 6 ++---- jax/_src/pallas/hlo_interpreter.py | 3 +-- tests/core_test.py | 10 ---------- 8 files changed, 21 insertions(+), 41 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 2a056d5c94f0..b11afd4a86de 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -621,7 +621,7 @@ def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr: if v not in used_vars: continue assert isinstance(v, core.Var) - newvar = core.Var(v.suffix, v.aval) + newvar = core.Var(v.aval) finfo = dtypes.finfo(v.aval.dtype) params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant) if v in constvars or v in invars: diff --git a/jax/_src/core.py b/jax/_src/core.py index 49579bc3de10..dcd0d5dd6591 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -433,34 +433,31 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() class Var: - __slots__ = ["count", "suffix", "aval", "initial_qdd", "final_qdd"] + __slots__ = ["count", "aval", "initial_qdd", "final_qdd"] count: int - suffix: str aval: AbstractValue # these are only useful for jaxpr binders but rather than create a separate # type for those, breaking existing interpreters, we add fields here. 
initial_qdd : QuasiDynamicData | None final_qdd : QuasiDynamicData | None - def __init__(self, suffix: str, aval: AbstractValue, initial_qdd = None, final_qdd = None): + def __init__(self, aval: AbstractValue, initial_qdd = None, final_qdd = None): + assert isinstance(aval, AbstractValue) self.count = next(_var_counter) - self.suffix = suffix self.aval = aval self.initial_qdd = initial_qdd self.final_qdd = final_qdd def __repr__(self): - return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}' + return f'Var(id={id(self)}):{self.aval.str_short()}' def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): del print_dtype # unused - return f"{context.var_names[self]}{self.suffix}" + return f"{context.var_names[self]}" -def gensym(suffix: str = '') -> Callable: - """Produce distinct variables, printed with the optional suffix.""" - return partial(Var, suffix) +gensym = lambda: Var # In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that # the assignment is dropped, i.e. that an expression's output value will never @@ -468,7 +465,7 @@ def gensym(suffix: str = '') -> Callable: # treat it as a special case of one. Its `aval` is similarly inexact. class DropVar(Var): def __init__(self, aval: AbstractValue): - super().__init__('', aval) + super().__init__(aval) def __repr__(self): return '_' def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): del context, print_dtype # unused diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 55769aa307fc..ad5b0b4f408b 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -111,7 +111,7 @@ def _jumble_unflatten(aval, x): register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten) def _jumble_result(axis_size, stacked_axis, ragged_axes, x): - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) + binder = core.Var(core.ShapedArray((), np.dtype('int32'))) if stacked_axis != 0: raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0 shape = list(x.shape) @@ -175,7 +175,7 @@ def bdim_as_shape( bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape: if isinstance(bdim, RaggedAxis): result = list(data_shape) - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) + binder = core.Var(core.ShapedArray((), np.dtype('int32'))) for ragged_axis, segment_lens in bdim.ragged_axes: result[ragged_axis] = IndexedAxisSize(binder, segment_lens) return tuple(result) @@ -1138,7 +1138,7 @@ def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): if dst == jumble_axis: x = bdim_at_front(x, src, sz) elt_ty = x.aval.update(shape=x.shape[1:]) - aval = JumbleTy(core.Var('', core.ShapedArray((), np.dtype('int32'))), + aval = JumbleTy(core.Var(core.ShapedArray((), np.dtype('int32'))), x.shape[0], elt_ty) return Jumble(aval, x) try: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 8c165df46e12..95ab769d4e27 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# + # pytype: skip-file from __future__ import annotations @@ -1135,7 +1135,6 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom: def has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) - newvar = core.gensym(suffix='_offload') known_eqns, staged_eqns = [], [] foreach(write, in_unknowns, in_inst, jaxpr.invars) foreach(partial(write, False, True), jaxpr.constvars) @@ -1165,7 +1164,7 @@ def has_effects(effects) -> bool: elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error - resvars = [newvar(v.aval) for v in eqn.outvars] + resvars = [Var(v.aval) for v in eqn.outvars] outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, @@ -1301,13 +1300,12 @@ def call_partial_eval_custom_rule( out_binders_known, _ = partition_list(unks_out, eqn.outvars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, params_staged) - residuals = [newvar(res_aval(params_known, var.aval)) + residuals = [Var(res_aval(params_known, var.aval)) for var in jaxpr_staged.invars[:num_res]] eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1340,14 +1338,13 @@ def closed_call_partial_eval_custom_rule( ins_known, _ = partition_list(unks_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, sum(f is None for f in out_fwd), num_res, params_known, params_staged) res_val_binders, res_ref_binders = split_list( - [newvar(res_aval(params_known, v)) + [Var(res_aval(params_known, v)) for v in jaxpr_staged.in_avals[:num_res]], [num_res_val]) res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None] res_val_vars = subs_list(out_fwd, out_binders_known, res_val_binders) @@ -2720,7 +2717,7 @@ def inline_jaxpr_into_trace( for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] orig_outvars = eqn.outvars - outvars = [Var('', v.aval) for v in orig_outvars] + outvars = [Var(v.aval) for v in orig_outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) eqn = eqn.replace(invars, outvars, source_info=src_) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index b75cbf6ac708..cb80df76326b 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -184,9 +184,8 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, canonical_non_ref_avals, canonical_non_ref_indices): is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) - newvar = 
core.gensym(suffix='_') - padded_ref_constvars = map(newvar, canonical_ref_avals) - padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) + padded_ref_constvars = map(core.Var, canonical_ref_avals) + padded_non_ref_constvars = map(core.Var, canonical_non_ref_avals) for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars): padded_ref_constvars[canonical_id] = ref_var for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 4e8368341d9f..4360c4a6df3b 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -677,8 +677,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): for j in branches_known[1:]) # Create residual variables. - newvar = core.gensym() - res_binders = map(newvar, all_res_avals) + res_binders = map(core.Var, all_res_avals) # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar @@ -763,8 +762,7 @@ def f_aug(*args): def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr], all_res_avals, res_aval_indices_per_jaxpr): - newvar = core.gensym(suffix='_') - all_res_vars = map(newvar, all_res_avals) + all_res_vars = map(core.Var, all_res_avals) def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr: num_res = len(res_indices) diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index fac798fe9dc1..2568ea8b74a1 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -235,8 +235,7 @@ def pad_jaxpr_constvars(jaxpr: jax_core.Jaxpr, to pad each Jaxpr with all consts from all branches so the signatures match, but only use the consts for this branch. """ - newvar = jax_core.gensym(suffix='_') - unused_const_vars = [tuple(map(newvar, const_avals)) + unused_const_vars = [tuple(map(jax_core.Var, const_avals)) for const_avals in all_const_avals] const_prefix = util.concatenate(unused_const_vars[:i]) const_suffix = util.concatenate(unused_const_vars[i + 1:]) diff --git a/tests/core_test.py b/tests/core_test.py index 646705ebf281..334df2222b0c 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -551,16 +551,6 @@ def f(x): assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar) core.check_jaxpr(jaxpr) - @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr - def test_jaxpr_undefined_eqn_invar(self): - jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr - cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos') - cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval) - self.assertRaisesRegex( - core.JaxprTypeError, - r"Variable '.+_test' not defined\n\nin equation:", - lambda: core.check_jaxpr(jaxpr)) - @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): From 18d0da9657aac476e407f9b0ecd3d0d5d394df79 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Thu, 5 Jun 2025 11:36:06 -0700 Subject: [PATCH 1546/1769] Make sure unsupported transfers between multi-process CPU arrays and TPU/GPU arrays raise a helpful error and don't take the new experimental cross-host device transfer path. 
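As a rough standalone sketch (names here are invented; the real check is
the private _is_supported_cross_host_transfer helper in the diff below),
the new guard amounts to:

    def sketch_is_supported_cross_host_transfer(src_device_kind: str,
                                                dst_device_kind: str,
                                                shardings_match: bool) -> bool:
      # New in this change: transfers between shardings that live on
      # different device kinds (e.g. a multi-process CPU array and a TPU
      # array) must take the error-raising path, not the experimental
      # cross-host transfer path.
      if src_device_kind != dst_device_kind:
        return False
      # Pre-existing requirement: the HLO shardings must be identical.
      return shardings_match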
PiperOrigin-RevId: 767696126
---
 jax/_src/dispatch.py     |  3 +++
 jaxlib/_jax/__init__.pyi |  2 ++
 jaxlib/py_device_list.cc | 17 ++++++++++++++++-
 jaxlib/py_device_list.h  |  5 +++++
 4 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index b9ef8f49f801..e58e6941278d 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -388,6 +388,9 @@ def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding):
   if (xla_bridge.process_count() == 1 or backend.platform != "tpu" or
       "TFRT TPU" not in backend.platform_version):
     return False
+  if (src_sharding._internal_device_list.device_kind !=
+      dst_sharding._internal_device_list.device_kind):
+    return False
   if (src_sharding._to_xla_hlo_sharding(ndim) !=
       dst_sharding._to_xla_hlo_sharding(ndim)):
     return False

diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi
index 26125d7edcbc..bbdc6bbdf896 100644
--- a/jaxlib/_jax/__init__.pyi
+++ b/jaxlib/_jax/__init__.pyi
@@ -887,6 +887,8 @@ class DeviceList:
   def default_memory_kind(self) -> str | None: ...
   @property
   def memory_kinds(self) -> tuple[str, ...]: ...
+  @property
+  def device_kind(self) -> str: ...

 class Sharding: ...

diff --git a/jaxlib/py_device_list.cc b/jaxlib/py_device_list.cc
index c5004dc57330..71f1125c749b 100644
--- a/jaxlib/py_device_list.cc
+++ b/jaxlib/py_device_list.cc
@@ -356,6 +356,20 @@ const std::set<int>& PyDeviceList::ProcessIndices() {
   return *process_indices_;
 }

+const std::string& PyDeviceList::DeviceKind() {
+  if (!device_kind_.has_value()) {
+    auto device_list = ifrt_device_list();
+    if (!device_list.ok()) {
+      throw nb::value_error(device_list.status().ToString().c_str());
+    }
+    if (Len() == 0) {
+      throw nb::value_error("DeviceList is empty");
+    }
+    device_kind_ = (*device_list)->devices()[0]->Kind();
+  }
+  return *device_kind_;
+}
+
 void PyDeviceList::PopulateMemoryKindInfo() {
   if (device_list_.index() == 1) {
     // Handle Python duck-type devices in a separate function for readability.
@@ -476,7 +490,8 @@ void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() {
             throw nb::value_error(kinds.status().ToString().c_str());
           }
           return *kinds;
-        });
+        })
+      .def_prop_ro("device_kind", &PyDeviceList::DeviceKind, nb::lock_self());
 }

 }  // namespace jax

diff --git a/jaxlib/py_device_list.h b/jaxlib/py_device_list.h
index 19c646dfc99b..8cc44206e734 100644
--- a/jaxlib/py_device_list.h
+++ b/jaxlib/py_device_list.h
@@ -107,6 +107,9 @@ class PyDeviceList {
   // Requires the self lock or GIL.
   const std::set<int>& ProcessIndices();

+  // Requires the self lock or GIL.
+  const std::string& DeviceKind();
+
   // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and
   // non-empty.
   xla::nb_class_ptr<PyClient> py_client_;
@@ -128,6 +131,8 @@ class PyDeviceList {
   std::optional<xla::nb_class_ptr<PyDeviceList>> addressable_device_list_;
   // Populated on demand. Guarded by the object's self lock.
   std::optional<std::set<int>> process_indices_;
+  // Populated on demand. Guarded by the object's self lock.
+  std::optional<std::string> device_kind_;

   struct MemoryKindInfo {
     nanobind::object default_memory_kind;

From 960f9c5454958228db6a161bd581c71835269cd4 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 5 Jun 2025 13:54:50 -0700
Subject: [PATCH 1547/1769] Don't recompute source_info.current() in
 DynamicJaxprTracer.
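A self-contained sketch of the pattern (the helper name
_current_source_info is invented; it stands in for
source_info_util.current()): the caller computes the source info once and
threads it through, and the default path only falls back to the lookup
when nothing was passed:

    def _current_source_info():
      return "stack-walk-result"  # stand-in for source_info_util.current()

    def default_process_primitive(primitive, params, source_info=None):
      # Fall back to the lookup only when no value was threaded through,
      # which keeps existing external callers working unchanged.
      source_info = source_info or _current_source_info()
      return primitive, params, source_info

    def process_primitive(primitive, params):
      source_info = _current_source_info()  # computed exactly once
      return default_process_primitive(primitive, params,
                                       source_info=source_info)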
PiperOrigin-RevId: 767751945 --- jax/_src/interpreters/partial_eval.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 95ab769d4e27..4360a25238d4 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2008,16 +2008,18 @@ def process_primitive(self, primitive, tracers, params): jaxpr_tracers = map(to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) - return self.default_process_primitive(primitive, jaxpr_tracers, params) + return self.default_process_primitive( + primitive, jaxpr_tracers, params, source_info) - def default_process_primitive(self, primitive, tracers, params): + def default_process_primitive(self, primitive, tracers, params, + source_info=None): aval_qdds = [t.aval_mutable_qdd for t in tracers] out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: raise ValueError(f"{primitive}.abstract_eval() method should return " f"a tuple or a list iff {primitive}.multiple_results.") out_avals = [out_avals] if not primitive.multiple_results else out_avals - source_info = source_info_util.current() + source_info = source_info or source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) outvars = map(self.makevar, out_tracers) From ff50b5fd5fdbb60085473067228b2afbb144ee9f Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 5 Jun 2025 14:04:13 -0700 Subject: [PATCH 1548/1769] [Pallas][Mosaic GPU] Support column slicing on TMEM. PiperOrigin-RevId: 767755766 --- jax/_src/pallas/mosaic_gpu/lowering.py | 61 +++++++++++++++----------- jax/experimental/mosaic/gpu/tcgen05.py | 3 +- tests/pallas/mosaic_gpu_test.py | 35 +++++++++++++++ 3 files changed, 72 insertions(+), 27 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f34b1e649c2d..e4529acfeed2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -24,7 +24,7 @@ import itertools import math import operator -from typing import Any, Protocol, cast +from typing import Any, Protocol, cast, TypeVar, Union import jax from jax import api_util @@ -80,6 +80,7 @@ partial = functools.partial SMEM = gpu_core.SMEM WARPGROUP_SIZE = 128 +RefOrTmemType = TypeVar("RefOrTmemType", bound=Union[ir.Value, tcgen05.TMEMRef]) @dataclasses.dataclass(frozen=True, kw_only=True) @@ -1300,18 +1301,21 @@ def _extract_aliased_ref( def _handle_transforms( ctx: LoweringRuleContext, - ref: ir.Value, + ref: RefOrTmemType, transforms: Sequence[gpu_core.Transform], *, handle_transposes=True, handle_reshapes=True, allow_peer_refs=False, -) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - # Before we handle other transforms, we resolve any possible leading aliasing - # transform. - ref, transforms = _extract_aliased_ref(ref, transforms) +) -> tuple[RefOrTmemType, Sequence[gpu_core.Transform]]: + if isinstance(ref, tcgen05.TMEMRef): + mlir_dtype = ref.dtype + else: + # Before we handle other transforms, we resolve any possible leading + # aliasing transform. 
+ ref, transforms = _extract_aliased_ref(ref, transforms) + mlir_dtype = ir.MemRefType(ref.type).element_type transformed_ref = ref - mlir_dtype = ir.MemRefType(ref.type).element_type new_transforms = [] def _bubble_up(untransform_fn, data): nonlocal new_transforms @@ -1334,15 +1338,22 @@ def _bubble_up(untransform_fn, data): indices = _bubble_up( lambda t, idxs: t.untransform_index(mlir_dtype, idxs), indices ) - transformed_ref = mgpu.memref_slice(transformed_ref, indices) + if isinstance(transformed_ref, tcgen05.TMEMRef): + transformed_ref = transformed_ref.slice(*indices) + else: + transformed_ref = mgpu.memref_slice(transformed_ref, indices) case gpu_core.TransposeRef(perm) if handle_transposes: perm = _bubble_up(lambda t, p: t.untransform_transpose(p), perm) + if isinstance(transformed_ref, tcgen05.TMEMRef): + raise ValueError("TMEM transpose not allowed.") transformed_ref = mgpu.memref_transpose(transformed_ref, perm) case RefReshaper(dtype=dtype, shape=shape) if handle_reshapes: shape = _bubble_up( lambda t, p: t.untransform_reshape(dtype, p), # pylint: disable=cell-var-from-loop shape) + if isinstance(transformed_ref, tcgen05.TMEMRef): + raise ValueError("TMEM reshape not allowed.") transformed_ref = mgpu.memref_reshape(transformed_ref, shape) case gpu_core.PeerMemRef(device_id, device_id_type): if device_id_type != primitives.DeviceIdType.LOGICAL: @@ -1387,15 +1398,14 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): if isinstance(x_ref, tcgen05.TMEMRef): transforms = jax.tree.unflatten(tree, leaves) - if len(transforms) != 1 or not isinstance( - transforms[0], indexing.NDIndexer): - raise NotImplementedError( - "Only a single indexing transform is supported for TMEM refs.") - indexer = cast(indexing.NDIndexer, transforms[0]) - if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape): + x_tmem, transforms = _handle_transforms( + ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False, + ) + if transforms: raise NotImplementedError( - "Only trivial indexing is supported for TMEM refs.") - return x_ref.load() + f"Unimplemented transforms for TMEM refs. {transforms=}" + ) + return x_tmem.load() if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): raise TypeError(f"Can only load from references (got {x_ref}).") @@ -1468,16 +1478,15 @@ def _swap_lowering_rule( if isinstance(x_ref, tcgen05.TMEMRef): transforms = jax.tree.unflatten(tree, leaves) - match transforms: - case (indexer,) if isinstance(indexer, indexing.NDIndexer): - if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape): - raise NotImplementedError( - "Only trivial indexing is supported for TMEM refs.") - case _: - raise NotImplementedError( - "Only a single indexing transform is supported for TMEM refs.") - old_value = x_ref.load(layout=value.layout) - x_ref.store(value) + x_tmem, transforms = _handle_transforms( + ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False, + ) + if transforms: + raise NotImplementedError( + f"Unimplemented transforms for TMEM refs. 
{transforms=}" + ) + old_value = x_tmem.load(layout=value.layout) + x_tmem.store(value) return old_value if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 32f0453148f4..96bce2a1994f 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -681,7 +681,8 @@ def slice(self, *idxs): raise NotImplementedError("TMEM cannot be sliced along rows") if slice_shape[1] % 8: raise NotImplementedError( - "TMEM column slice length must be a multiple of 8" + "TMEM column slice length must be a multiple of 8. " + f"Got {slice_shape[1]}." ) col_idx = base_idx[1] if not isinstance(col_idx, ir.Value): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 04c3db0b30c9..cd0346091fce 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2461,6 +2461,41 @@ def kernel(x_ref, y_ref, tmem_ref, tmem_ref2, smem_ref, barrier_ref): x_result = jax.block_until_ready(kernel(x)) np.testing.assert_array_equal(x_result, x + 1) + def test_tmem_column_slicing(self): + self.skip_if_wg_semantics() + swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 256), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, tmem_ref, smem_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + tmem_slice = tmem_ref.at[:, 8:208].at[:, 0:128] + tmem_slice[...] = x_val + 1 + plgpu.commit_tmem() + smem_ref[...] = tmem_ref[:, 8:136] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, (x + 1)[:, 0:128]) + @parameterized.parameters( (jnp.sum,), (jnp.max,) From 013b1b1d3ac822f589aacd2d2c9564de33f59bfb Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 5 Jun 2025 14:04:59 -0700 Subject: [PATCH 1549/1769] [Mosaic GPU] Fix `2xf32 -> 2xf8e4m3fn` conversion. I messed up the first time, the output register is naturally 16-bit wide. The PTX looks much cleaner now. 
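The packing can be sanity-checked numerically; a rough numpy illustration
(assuming the ml_dtypes package is available) of why two f8e4m3fn results
naturally fit a 16-bit register, matching the "=h" output constraint of
cvt.rn.satfinite.e4m3x2.f32:

    import numpy as np
    import ml_dtypes

    # Two f32 inputs become two one-byte f8e4m3fn values: 16 bits total.
    pair = np.array([1.5, -0.25], dtype=np.float32)
    packed = pair.astype(ml_dtypes.float8_e4m3fn)
    assert packed.nbytes == 2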
PiperOrigin-RevId: 767756118 --- .../mosaic/gpu/fragmented_array.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 40d95f65f210..2f574ac738d4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1554,26 +1554,19 @@ def upcast_to_bf16(reg, high): if vector_len != 2: raise NotImplementedError(vector_len) new_registers = np.empty_like(self.registers) - empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32)) - empty_result_vec = llvm.mlir_undef(ir.VectorType.get((2,), i8)) + empty_vec_16 = llvm.mlir_undef(ir.VectorType.get((1,), i16)) for idx, reg in np.ndenumerate(self.registers): e0 = vector.extractelement(reg, position=c(0, index)) e1 = vector.extractelement(reg, position=c(1, index)) - new_reg_32 = llvm.inline_asm( - i32, + new_reg_16 = llvm.inline_asm( + i16, [e1, e0], "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", "=h,f,f", ) - new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32)) - new_vec_f8 = vector.bitcast(ir.VectorType.get((4,), i8), new_vec_32) - res = llvm.insertelement( - empty_result_vec, - vector.extractelement(new_vec_f8, position=c(0, i32)), c(0, i32)) - res = llvm.insertelement( - res, - vector.extractelement(new_vec_f8, position=c(1, i32)), c(1, i32)) - new_registers[idx] = vector.bitcast(ir.VectorType.get((2,), f8e4m3fn), res) + new_registers[idx] = vector.bitcast( + ir.VectorType.get((2,), f8e4m3fn), + llvm.insertelement(empty_vec_16, new_reg_16, c(0, i32))) return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) From fc431221ed7b95bf55610c8188f86596d6a615dc Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 5 Jun 2025 14:32:24 -0700 Subject: [PATCH 1550/1769] [Pallas][Mosaic GPU] Skip tcgen05 reduce test on WG semantics. PiperOrigin-RevId: 767767487 --- tests/pallas/mosaic_gpu_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index cd0346091fce..2508d6dec09d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2501,6 +2501,7 @@ def kernel(x_ref, y_ref, tmem_ref, smem_ref, barrier_ref): (jnp.max,) ) def test_reduce_with_tcgen05_layout(self, op): + self.skip_if_wg_semantics() axis = -1 swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize transforms = ( From c30ed91ae21a91ca4faceea851ed4538161e084f Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 5 Jun 2025 14:43:56 -0700 Subject: [PATCH 1551/1769] [Pallas][Mosaic GPU] Add support for load/broadcast using TCGEN05 ROW/COL layouts. 
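Condensed from the test added in this patch, the kernel-body pattern this
enables looks roughly like (shapes and imports as in the test):

    from jax import lax
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def body_fragment(smem_ref, smem_out_ref):
      # Load a 1D value in the TCGEN05 row layout...
      reduced = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05_ROW)
      # ...then broadcast it back out to the full 2D TCGEN05 layout.
      smem_out_ref[...] = lax.broadcast_in_dim(reduced, (128, 128), [0])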
PiperOrigin-RevId: 767772068
---
 jax/_src/pallas/mosaic_gpu/lowering.py   | 12 +++--
 jax/_src/pallas/mosaic_gpu/primitives.py |  7 ++-
 jax/experimental/mosaic/gpu/__init__.py  |  9 ++--
 .../mosaic/gpu/fragmented_array.py        | 50 +++++++++++++++----
 jax/experimental/mosaic/gpu/tcgen05.py   | 26 ++--------
 tests/pallas/mosaic_gpu_test.py          | 34 +++++++++++++
 6 files changed, 95 insertions(+), 43 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index e4529acfeed2..24f219df49a5 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1658,17 +1658,19 @@ def _broadcast_in_dim_lowering_rule(
   if (
       broadcast_dimensions == tuple(range(x_aval.ndim))
       and y_aval.ndim == x_aval.ndim + 1
-      and x.layout == mgpu.WGMMA_ROW_LAYOUT
+      and x.layout in (mgpu.WGMMA_ROW_LAYOUT, mgpu.TCGEN05_ROW_LAYOUT)
   ):
     return x.broadcast_minor(y_aval.shape[-1])
   if (
-      broadcast_dimensions == (1,)
-      and y_aval.ndim == x_aval.ndim + 1
-      and x.layout == mgpu.WGMMA_COL_LAYOUT
+    broadcast_dimensions == (1,)
+    and y_aval.ndim == x_aval.ndim + 1
+    and x.layout in (mgpu.WGMMA_COL_LAYOUT, mgpu.TCGEN05_COL_LAYOUT)
   ):
     return x.broadcast_major(y_aval.shape[-2])
   if broadcast_dimensions:
-    raise NotImplementedError
+    raise NotImplementedError(
+        f"Unsupported broadcast {broadcast_dimensions} for layout: {x.layout}"
+    )
   return x.broadcast(shape)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 41c8a673b528..3fe9b62f3f1e 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -115,7 +115,12 @@ def _load_p_lowering_rule(
         val, shape=(), layout=layout, is_signed=is_signed
     )
   match layout:
-    case mgpu.WGMMA_ROW_LAYOUT | mgpu.WGMMA_COL_LAYOUT:
+    case (
+        mgpu.WGMMA_ROW_LAYOUT
+        | mgpu.WGMMA_COL_LAYOUT
+        | mgpu.TCGEN05_ROW_LAYOUT
+        | mgpu.TCGEN05_COL_LAYOUT
+    ):
       return mgpu.FragmentedArray.load_untiled(
           x_ref,
           is_signed=is_signed,

diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index f4b65c976ef5..dbee817df20c 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -58,6 +58,9 @@
 from .fragmented_array import (
     FragmentedArray as FragmentedArray,
     FragmentedLayout as FragmentedLayout,
+    TCGEN05_LAYOUT as TCGEN05_LAYOUT,
+    TCGEN05_ROW_LAYOUT as TCGEN05_ROW_LAYOUT,
+    TCGEN05_COL_LAYOUT as TCGEN05_COL_LAYOUT,
     TiledLayout as TiledLayout,
     WGMMA_LAYOUT as WGMMA_LAYOUT,
     WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT,
@@ -102,9 +105,3 @@
     WGMMAAccumulator as WGMMAAccumulator,
     wgmma as wgmma,
 )
-
-from .tcgen05 import (
-    LAYOUT as TCGEN05_LAYOUT,  # noqa: F401
-    ROW_LAYOUT as TCGEN05_ROW_LAYOUT,  # noqa: F401
-    COL_LAYOUT as TCGEN05_COL_LAYOUT,  # noqa: F401
-)

diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index 2f574ac738d4..01e449316703 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -685,6 +685,30 @@ def linear_thread_idxs(self):
     vector_dim=-2,
 )

+# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN.
+TCGEN05_LAYOUT = TiledLayout(
+    Tiling(((128, 8), (32, 8), (8, 8), (1, 2))),
+    warp_dim=-8,
+    lane_dims=(-4, -3),
+    vector_dim=-1,
+)
+# TCGEN05_ROW_LAYOUT is to TCGEN05_LAYOUT as WGMMA_ROW_LAYOUT is to
+# WGMMA_LAYOUT.
+TCGEN05_ROW_LAYOUT = TiledLayout( + Tiling(tiles=((128,), (32,), (8,), (1,), (1,))), + warp_dim=-5, + lane_dims=(-3, Replicated(times=4)), + vector_dim=-1, +) +# TCGEN05_COL_LAYOUT is to TCGEN05_LAYOUT as WGMMA_COL_LAYOUT is to +# WGMMA_LAYOUT. +TCGEN05_COL_LAYOUT = TiledLayout( + Tiling(tiles=((8,), (8,), (8,), (2,))), + warp_dim=Replicated(times=4), + lane_dims=(Replicated(times=8), -2), + vector_dim=-1, +) + @jax.tree_util.register_pytree_node_class @dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) class FragmentedArray: @@ -1863,11 +1887,15 @@ def reshape(self, shape): ) def broadcast_minor(self, n): - if self.layout != WGMMA_ROW_LAYOUT: - raise NotImplementedError + if self.layout == WGMMA_ROW_LAYOUT: + output_layout = WGMMA_LAYOUT + elif self.layout == TCGEN05_ROW_LAYOUT: + output_layout = TCGEN05_LAYOUT + else: + raise NotImplementedError(self.layout) if n % 8: raise ValueError("Number of columns must be divisible by 8") - reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n)) + reg_shape = output_layout.registers_shape((self.shape[0], n)) new_regs = np.empty(reg_shape, dtype=object) dtype = self.mlir_dtype i0 = arith.constant(ir.IndexType.get(), 0) @@ -1876,26 +1904,30 @@ def broadcast_minor(self, n): tile[0] = row_tile tile[4] = row_subtile new_regs[tuple(tile)] = vector.splat( - ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), + ir.VectorType.get((output_layout.vector_length,), dtype), vector.extractelement(reg, position=i0), ) return FragmentedArray( - _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + _registers=new_regs, _layout=output_layout, _is_signed=self.is_signed ) def broadcast_major(self, m): - if self.layout != WGMMA_COL_LAYOUT: - raise NotImplementedError + if self.layout == WGMMA_COL_LAYOUT: + output_layout = WGMMA_LAYOUT + elif self.layout == TCGEN05_COL_LAYOUT: + output_layout = TCGEN05_LAYOUT + else: + raise NotImplementedError(self.layout) if m % 64: raise ValueError("Number of rows must be divisible by 64") - reg_shape = WGMMA_LAYOUT.registers_shape((m, self.shape[0])) + reg_shape = output_layout.registers_shape((m, self.shape[0])) new_regs = np.empty(reg_shape, dtype=object) for (col_tile, *_), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[1] = col_tile new_regs[tuple(tile)] = reg return FragmentedArray( - _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + _registers=new_regs, _layout=output_layout, _is_signed=self.is_signed ) def select(self, on_true, on_false): diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 96bce2a1994f..c0415db6b581 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -35,28 +35,10 @@ TMEM_ROWS = 128 TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46 -# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. -# The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT) -LAYOUT = fa.TiledLayout( - fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), - vector_dim=-1, -) -# ROW_LAYOUT is to LAYOUT as WGMMA_ROW_LAYOUT is to WGMMA_LAYOUT. -ROW_LAYOUT = fa.TiledLayout( - fa.Tiling(tiles=((128,), (32,), (8,), (1,), (1,))), - warp_dim=-5, - lane_dims=(-3, fa.Replicated(times=4)), - vector_dim=-1 -) -# COL_LAYOUT is to LAYOUT as WGMMA_COL_LAYOUT is to WGMMA_LAYOUT. 
-COL_LAYOUT = fa.TiledLayout( - fa.Tiling(tiles=((8,), (8,), (8,), (2,))), - warp_dim=fa.Replicated(times=4), - lane_dims=(fa.Replicated(times=8), -2), - vector_dim=-1 -) +LAYOUT = fa.TCGEN05_LAYOUT +ROW_LAYOUT = fa.TCGEN05_ROW_LAYOUT +COL_LAYOUT = fa.TCGEN05_COL_LAYOUT + # A layout resembling the logical organization of TMEM. The 128 rows in a tile # are assigned to 128 lanes in the warpgroup. Useful when the result needs to be # processed in registers and then stored back into TMEM. Should not be used if diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 2508d6dec09d..ff4f29aefb59 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2533,6 +2533,40 @@ def kernel(x_ref, y_ref, smem_ref, smem_reduced_ref, barrier_ref): x_result = jax.block_until_ready(kernel(x)) np.testing.assert_allclose(x_result, op(x, axis=axis), atol=1e-5) + @parameterized.parameters((0,), (1,)) + def test_broadcast_in_dim_tcgen05_layout(self, axis): + self.skip_if_wg_semantics() + + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, smem_ref, smem_out_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + if axis == 0: + reduced = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05_COL) + else: + reduced = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05_ROW) + broadcasted = lax.broadcast_in_dim(reduced, (128, 128), [1 - axis]) + smem_out_ref[...] = broadcasted + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_out_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + expected = jnp.expand_dims(x, axis=axis) + expected = jnp.broadcast_to(expected, (128, 128)) + np.testing.assert_array_equal(x_result, expected) + @parameterized.product(shape=[(128, 128)], swizzle=[128, 64, 32], dtype=[jnp.float16, jnp.bfloat16], From a5ce3adc87a915aa4bf04876e2bd83be6bb5556d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 4 Jun 2025 15:48:38 -0700 Subject: [PATCH 1552/1769] lax.top_k: raise error if indices will overflow --- jax/_src/lax/lax.py | 9 +++++++++ tests/lax_test.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9c9a1dd8c1ba..59363113d7b3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8094,6 +8094,15 @@ def _top_k_abstract_eval(operand, *, k): if shape[-1] < k: msg = "k argument to top_k must be no larger than minor dimension; {} vs {}" raise ValueError(msg.format(k, shape)) + int32_max = dtypes.iinfo('int32').max + try: + too_large = (shape[-1] > int32_max + 1) + except core.InconclusiveDimensionOperation: + pass + else: + if too_large: + raise ValueError("top_k returns int32 indices, which will overflow for array dimensions " + f"larger than the maximum int32 ({int32_max}). 
Got {operand.shape=}")
  shape[-1] = k
  return (operand.update(shape=shape, dtype=operand.dtype,
                         weak_type=operand.weak_type),

diff --git a/tests/lax_test.py b/tests/lax_test.py
index 4d30cd70ea70..2be76913a59f 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2631,6 +2631,11 @@ def reference_top_k(x):
     self._CheckAgainstNumpy(op, reference_top_k, args_maker)
     self._CompileAndCheck(op, args_maker)

+  def testTopKOverflow(self):
+    x = jax.ShapeDtypeStruct((2 ** 31 + 1,), np.dtype('bfloat16'))
+    with self.assertRaisesRegex(ValueError, "top_k returns int32 indices, which will overflow"):
+      jax.eval_shape(lambda x: jax.lax.top_k(x, 100), x)
+
   @jtu.sample_product(
     [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
      for lhs_shape, rhs_shape in [((3, 2), (2, 4)),

From 50d93eebc8219ec2b1f02fe0aa1d61626458ee55 Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Thu, 5 Jun 2025 21:43:56 -0700
Subject: [PATCH 1553/1769] Bring back tree concat optimization for
 np.array(...)

PiperOrigin-RevId: 767902551
---
 jax/_src/numpy/array.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/jax/_src/numpy/array.py b/jax/_src/numpy/array.py
index 7e1c169630ce..73bbd7d09554 100644
--- a/jax/_src/numpy/array.py
+++ b/jax/_src/numpy/array.py
@@ -248,7 +248,15 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
   elif isinstance(object, (list, tuple)):
     if object:
       arrs = (array(elt, dtype=dtype, copy=False) for elt in object)
-      out = lax.concatenate([lax.expand_dims(arr, [0]) for arr in arrs], 0)
+      arrays_out = [lax.expand_dims(arr, [0]) for arr in arrs]
+      # lax.concatenate can be slow to compile for wide concatenations, so form a
+      # tree of concatenations as a workaround especially for op-by-op mode.
+      # (https://github.com/jax-ml/jax/issues/653).
+      k = 16
+      while len(arrays_out) > k:
+        arrays_out = [lax.concatenate(arrays_out[i:i+k], 0)
+                      for i in range(0, len(arrays_out), k)]
+      out = lax.concatenate(arrays_out, 0)
     else:
       out = np.array([], dtype=dtype)
   elif _supports_buffer_protocol(object):

From fd436509c49c0304bec96173fe1d0430f6478cf7 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 5 Jun 2025 22:30:19 -0700
Subject: [PATCH 1554/1769] [Mosaic] Add direct support for int8 transpose
 where the hardware can do it, and a canonicalization for where it cannot: the
 input is extended to i32, converted to bf16, transposed in bf16, and
 truncated back to int8.
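A plain numpy + ml_dtypes sketch (not the MLIR pass itself) of the rewrite
chain; the round trip is lossless because every int8 value is exactly
representable in bf16:

    import numpy as np
    import ml_dtypes

    x = np.arange(-8, 8, dtype=np.int8).reshape(4, 4)
    # Upcast chain used when the hardware cannot transpose int8 directly:
    # int8 -> i32 -> f32 -> bf16.
    as_bf16 = x.astype(np.int32).astype(np.float32).astype(ml_dtypes.bfloat16)
    transposed = as_bf16.T
    # Downcast chain back to int8 after the bf16 transpose.
    back = transposed.astype(np.float32).astype(np.int32).astype(np.int8)
    assert (back == x.T).all()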
PiperOrigin-RevId: 767915759
---
 .../tpu/transforms/canonicalize_mosaic.cc | 129 +++++++++++++++++-
 1 file changed, 122 insertions(+), 7 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 733863546935..a88b70d302af 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -71,6 +71,10 @@ struct CanonicalizeContext {
   std::array<int64_t, 2> target_shape;
 };

+Value create_transpose_op(const CanonicalizeContext &ctx,
+                          ImplicitLocOpBuilder &builder, VectorType input_ty,
+                          Value input, ArrayRef<int64_t> permutation);
+
 bool need_elementwise_canonicalization(const CanonicalizeContext &ctx,
                                        Operation &op);

@@ -239,7 +243,7 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx,
     }
   }

-  auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) {
+  auto dot_dim_matmul = [&](Value lhs, auto rhs, auto acc) {
    auto precision_attr = op.getPrecisionAttr();

    // If we are transposing the lhs, we need to transpose the lhs before
@@ -258,13 +262,12 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx,
      std::vector<int64_t> shape(lhs_ty.getShape());
      std::swap(shape[rank - 2], shape[rank - 1]);
-     auto lhs_ty_transposed = VectorType::get(shape, lhs_ty.getElementType());
+     VectorType lhs_ty_transposed =
+         VectorType::get(shape, lhs_ty.getElementType());

      const SmallVector<int64_t> perm_vec =
          SmallVector<int64_t>(perm.begin(), perm.end());
-     lhs = builder.create<tpu::TransposeOp>(
-         lhs_ty_transposed, lhs,
-         DenseI64ArrayAttr::get(builder.getContext(), perm_vec));
+     lhs = create_transpose_op(ctx, builder, lhs_ty_transposed, lhs, perm_vec);
    }
    auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false,
                                       transpose_rhs);
@@ -946,6 +949,117 @@ LogicalResult canonicalize_reshape(const CanonicalizeContext &ctx,
   return success();
 }

+namespace {
+// TODO(mvoz): We can refactor a lot of other canonicalization rules to use
+// these functions.
+// TODO(mvoz): I think we can eventually do direct conversion to bf16
+// without going through f32?
+Value upcastInt8ToBf16(ImplicitLocOpBuilder &builder, Value input) {
+  auto vty = cast<VectorType>(input.getType());
+  auto shape = vty.getShape();
+  auto int_ty = cast<IntegerType>(vty.getElementType());
+
+  auto i32_vty = VectorType::get(shape, builder.getI32Type());
+  auto val_i32 = int_ty.isUnsigned()
+                     ? builder.create<arith::ExtUIOp>(i32_vty, input)
+                     : builder.create<arith::ExtSIOp>(i32_vty, input);
+
+  auto f32_vty = VectorType::get(shape, builder.getF32Type());
+  auto val_f32 = builder.create<tpu::SIToFPOp>(
+      f32_vty, val_i32->getResult(0), tpu::RoundingMode::kToNearestEven);
+
+  auto bf16_vty = VectorType::get(shape, builder.getBF16Type());
+  return builder.create<arith::TruncFOp>(bf16_vty, val_f32);
+}
+
+Value downcastBf16ToInt8(ImplicitLocOpBuilder &builder, Value input_bf16,
+                         Type target_vty) {
+  auto shape = cast<VectorType>(input_bf16.getType()).getShape();
+
+  auto f32_vty = VectorType::get(shape, builder.getF32Type());
+  auto val_f32 = builder.create<arith::ExtFOp>(f32_vty, input_bf16);
+
+  auto i32_vty = VectorType::get(shape, builder.getI32Type());
+  auto val_i32 = builder.create<arith::FPToSIOp>(i32_vty, val_f32);
+
+  return builder.create<arith::TruncIOp>(target_vty, val_i32);
+}
+
+Value upcastFp8ToBf16(ImplicitLocOpBuilder &builder, Value input) {
+  auto shape = cast<VectorType>(input.getType()).getShape();
+  auto f32_vty = VectorType::get(shape, builder.getF32Type());
+  auto val_f32 = builder.create<arith::ExtFOp>(f32_vty, input);
+  auto bf16_vty = VectorType::get(shape, builder.getBF16Type());
+  return builder.create<arith::TruncFOp>(bf16_vty, val_f32);
+}
+
+Value downcastBf16ToFp8(ImplicitLocOpBuilder &builder, Value input_bf16,
+                        Type target_vty) {
+  auto shape = cast<VectorType>(input_bf16.getType()).getShape();
+  auto f32_vty = VectorType::get(shape, builder.getF32Type());
+  auto val_f32 = builder.create<arith::ExtFOp>(f32_vty, input_bf16);
+  return builder.create<arith::TruncFOp>(target_vty, val_f32);
+}
+}  // namespace
+
+// Note(mvoz): Returns optional to signal no replacement, simplifying
+// downstream .replace() and .erase() calls.
+std::optional<Value> canonicalize_transpose_impl(const CanonicalizeContext &ctx,
+                                                 ImplicitLocOpBuilder &builder,
+                                                 tpu::TransposeOp op) {
+  auto input_ty = dyn_cast<VectorType>(op.getOperand().getType());
+  auto element_type = input_ty.getElementType();
+  // TODO(mvoz): Even gen 7 support is spotty on all test targets.
+  if (element_type.getIntOrFloatBitWidth() == 8 && ctx.compatibility_mode &&
+      ctx.hardware_generation > 3) {
+    Value val_bf16;
+    if (isa<IntegerType>(element_type)) {
+      val_bf16 = upcastInt8ToBf16(builder, op.getOperand());
+    } else {
+      val_bf16 = upcastFp8ToBf16(builder, op.getOperand());
+    }
+
+    auto original_output_ty = cast<VectorType>(op.getType());
+    auto post_transpose_bf16_vty =
+        VectorType::get(original_output_ty.getShape(), builder.getBF16Type());
+
+    auto new_t = builder.create<tpu::TransposeOp>(
+        post_transpose_bf16_vty, val_bf16, op.getPermutation());
+
+    Value final_val;
+    if (isa<IntegerType>(element_type)) {
+      final_val = downcastBf16ToInt8(builder, new_t.getResult(), op.getType());
+    } else {
+      final_val = downcastBf16ToFp8(builder, new_t.getResult(), op.getType());
+    }
+    return final_val;
+  }
+  return std::nullopt;
+}
+
+Value create_transpose_op(const CanonicalizeContext &ctx,
+                          ImplicitLocOpBuilder &builder, VectorType input_ty,
+                          Value input, ArrayRef<int64_t> permutation) {
+  auto t = builder.create<tpu::TransposeOp>(input_ty, input, permutation);
+  auto new_op_opt = canonicalize_transpose_impl(ctx, builder, t);
+  if (new_op_opt.has_value()) {
+    return new_op_opt.value();
+  }
+  return t;
+}
+
+LogicalResult canonicalize_transpose(const CanonicalizeContext &ctx,
+                                     Operation &raw_op) {
+  auto op = cast<tpu::TransposeOp>(raw_op);
+  auto builder = ImplicitLocOpBuilder(op->getLoc(), op.getOperation());
+  auto new_op_opt = canonicalize_transpose_impl(ctx, builder, op);
+  if (new_op_opt.has_value()) {
+    op.replaceAllUsesWith(new_op_opt.value());
+    op.erase();
+  }
+  return success();
+}
+
 using canonicalize_rule_type = std::function<LogicalResult(
     const CanonicalizeContext &ctx, Operation &op)>;

@@ -962,6 +1076,7 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
       {arith::SelectOp::getOperationName(), canonicalize_select},
       {arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
       {arith::SIToFPOp::getOperationName(), canonicalize_sitofp},
+      {tpu::TransposeOp::getOperationName(), canonicalize_transpose},
       {tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
   return *rules;
 }
@@ -1035,8 +1150,8 @@ class MosaicCanonicalizer {
     CanonicalizeContext ctx(
         {compatibility_mode_, hardware_generation_, target_shape_});
     // We must iterate over the op first, because canonicalization can cause
-    // us to .erase() an op, and accessing getRegions on it after is not sound.
-    // Invariant - top level ops with regions may never be invalidated.
+    // us to .erase() an op, and accessing getRegions on it after is not
+    // sound. Invariant - top level ops with regions may never be invalidated.
     for (Region &region : any_op.getRegions()) {
       for (Block &block : region) {
         if (canonicalizeBlock(block).failed()) {

From 20d641f60d0660a38d6663aca39f60ad40086564 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Thu, 5 Jun 2025 22:43:19 -0700
Subject: [PATCH 1555/1769] [Mosaic GPU] Add support for lowering `2xbf16 ->
 2xf8e4m3fn` converts.

PiperOrigin-RevId: 767920329
---
 jax/experimental/mosaic/gpu/fragmented_array.py | 6 +++++-
 tests/mosaic/gpu_test.py                        | 7 +++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index 01e449316703..816b6d6cbe74 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1574,7 +1574,7 @@ def upcast_to_bf16(reg, high):
         _registers=new_registers, _layout=self.layout, _is_signed=is_signed
     )
     # TODO(bchetioui): handle conversions to/from other float8 types.
- if cur_dtype == f32 and new_dtype == f8e4m3fn: + if cur_dtype in {bf16, f32} and new_dtype == f8e4m3fn: if vector_len != 2: raise NotImplementedError(vector_len) new_registers = np.empty_like(self.registers) @@ -1582,6 +1582,10 @@ def upcast_to_bf16(reg, high): for idx, reg in np.ndenumerate(self.registers): e0 = vector.extractelement(reg, position=c(0, index)) e1 = vector.extractelement(reg, position=c(1, index)) + # TODO(bchetioui): can we do faster than this? + if cur_dtype == bf16: + e0 = arith.extf(f32, e0) + e1 = arith.extf(f32, e1) new_reg_16 = llvm.inline_asm( i16, [e1, e0], diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index a6828a31dd9f..272d2e7b5c2b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -587,8 +587,11 @@ def kernel(ctx, inp, out, smem): f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y)) np.testing.assert_array_equal(f(x), y) - def test_f8_conversions(self): - jax_dtype_from, jax_dtype_to = jnp.float32, jnp.float8_e4m3fn + @parameterized.parameters( + (jnp.float32, jnp.float8_e4m3fn), + (jnp.bfloat16, jnp.float8_e4m3fn) + ) + def test_f8_conversions(self, jax_dtype_from, jax_dtype_to): mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) def kernel(ctx, inp, out, smem): del ctx From e66745dead0d23d704f1b9d3003d4d2733039105 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Thu, 5 Jun 2025 22:47:19 -0700 Subject: [PATCH 1556/1769] Fix logic for checking supported cross-host device transfers, since the (unsupported) TPU PjRt C API client also contains the string "TFRT TPU". PiperOrigin-RevId: 767921346 --- jax/_src/dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e58e6941278d..289c9e0fb89b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -386,7 +386,7 @@ def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding): # There is experimental support for cross-host device transfers on TFRT TPU # backends only. if (xla_bridge.process_count() == 1 or backend.platform != "tpu" or - "TFRT TPU" not in backend.platform_version): + not backend.platform_version.startswith("TFRT TPU")): return False if (src_sharding._internal_device_list.device_kind != dst_sharding._internal_device_list.device_kind): From 6440b93bb373ebab6f981227f744643eb1fad05f Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Jun 2025 00:47:36 -0700 Subject: [PATCH 1557/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/7312a4a7cdae69c292d56c1da6cd289ede4c797e. PiperOrigin-RevId: 767955213 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 858bd93987a1..3a6a0f5096b1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "da63f2971676b6cd97a72ba52883717cf48e13d8" -XLA_SHA256 = "dcd368eba23a9ace0e8b950f4d9693abdcfb6de5ef109ee6773c450ab1c75a51" +XLA_COMMIT = "7312a4a7cdae69c292d56c1da6cd289ede4c797e" +XLA_SHA256 = "a951125a1e0c60b9e9b36b38ed37515609415817d98134e02cf41db9c7cf8db3" def repo(): tf_http_archive( From 2620f9d1a2cf07d546fa233850cdc9433fac4feb Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 6 Jun 2025 02:12:24 -0700 Subject: [PATCH 1558/1769] [Mosaic GPU] Ensure all ops that need transforms have them at the end of the transform inference pass. PiperOrigin-RevId: 767979248 --- .../mosaic/gpu/transform_inference.py | 26 +++++++++++++++++ tests/mosaic/gpu_transform_inference_test.py | 29 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index e6a4e5bd1cff..7b31bcaefb27 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -411,3 +411,29 @@ def inference_step(op: ir.Operation): inference_utils.traverse_op( op, inference_step, inference_utils.TraversalOrder.FORWARD ) + + # All ops that should have transforms but have no transforms inferred so far + # are assigned an empty sets of transforms. E.g., this happens in kernels with + # only pointwise operations. + def set_empty_transforms(op: ir.Operation): + if ( + inference_utils.should_have_transforms(op) + and not inference_utils.has_in_transforms_set(op) + and not inference_utils.has_out_transforms_set(op) + ): + ins = [ + ir.ArrayAttr.get([]) + for o in op.operands + if inference_utils.is_transformable_smem_memref(o) + ] + outs = [ + ir.ArrayAttr.get([]) + for r in op.results + if inference_utils.is_transformable_smem_memref(r) + ] + _set_transform_attributes(op, ins, outs) + + for op in module.body: + inference_utils.traverse_op( + op, set_empty_transforms, inference_utils.TraversalOrder.FORWARD + ) diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index 983efebc4f86..3fdbfd650cb1 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -470,6 +470,35 @@ def body(in_ref): inference_utils.out_transforms(subview_op), [transforms] ) + def test_infer_transforms_sets_default_emptry_transforms(self): + async_load_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(gmem_ref, smem_ref, barrier): + nonlocal async_load_op + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + async_load_op = mgpu.dialect.AsyncLoadOp( + source=gmem_ref, + destination=smem_ref, + barrier=barrier, + indices=[zero, zero], + slice_lengths=shape, + collective=ir.ArrayAttr.get([]), + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + gmem_ty = ir.MemRefType.get(shape, elt_ty) + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") + func.FuncOp.from_py_func(gmem_ty, smem_ty, barrier_ty)(body).func_op + + mgpu.infer_transforms(self.module) + [in_transform] = inference_utils.in_transforms(async_load_op) + self.assertSequenceEqual(in_transform, ir.ArrayAttr.get([])) + self.assertEmpty(inference_utils.out_transforms(async_load_op)) + @parameterized.parameters([False, True]) def test_infer_transforms_for_subview_op_raises_on_disturbed_transforms( self, annotate_input From 8d903a986ea099d4cb90baceef00f8c09809620b Mon Sep 17 
00:00:00 2001 From: Olli Lupton Date: Fri, 6 Jun 2025 02:54:50 -0700 Subject: [PATCH 1559/1769] //tests:scaled_matmul_stablehlo_test: fix for xla#27096 --- tests/scaled_matmul_stablehlo_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index fb5a7560d947..9830d6fefff7 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -47,10 +47,10 @@ c_name = "__cudnn$blockScaledDot" expected_hlos = [ (c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), - ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), - ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), (c_name,), - ("all-gather", "f8e4m3fn[1,256,1024]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]", c_name), (c_name, "reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}"), ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name), ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name), From 795bbfc6a32b06f337bd258db23225517c146880 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 6 Jun 2025 04:53:56 -0700 Subject: [PATCH 1560/1769] Add fast-path for non-concrete Tracers in is_constant_dim to lower trace-time overhead from Exception creation PiperOrigin-RevId: 768022818 --- jax/_src/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index dcd0d5dd6591..e99909f24ea7 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2422,6 +2422,9 @@ def is_symbolic_dim(v: Any) -> bool: def is_constant_dim(d: DimSize) -> bool: # Whether the dimension is a static integer constant. + # Try using a fast path for non-concrete Tracers. 
+ if isinstance(d, Tracer) and not is_concrete(d): + return False try: operator.index(d) return True From c1bb095c5ce5b0286dc5052abf3b597b6f23cea5 Mon Sep 17 00:00:00 2001 From: Joshua Bambrick Date: Fri, 6 Jun 2025 05:46:17 -0700 Subject: [PATCH 1561/1769] Reverts 5c33588b30edbae51d5b63b0bd7cc8d9058d7ccb PiperOrigin-RevId: 768037136 --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/core.py | 20 +++++++++++++------- jax/_src/interpreters/batching.py | 6 +++--- jax/_src/interpreters/partial_eval.py | 13 ++++++++----- jax/_src/lax/control_flow/common.py | 5 +++-- jax/_src/lax/control_flow/conditionals.py | 6 ++++-- jax/_src/pallas/hlo_interpreter.py | 3 ++- tests/core_test.py | 10 ++++++++++ 8 files changed, 44 insertions(+), 21 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index b11afd4a86de..2a056d5c94f0 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -621,7 +621,7 @@ def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr: if v not in used_vars: continue assert isinstance(v, core.Var) - newvar = core.Var(v.aval) + newvar = core.Var(v.suffix, v.aval) finfo = dtypes.finfo(v.aval.dtype) params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant) if v in constvars or v in invars: diff --git a/jax/_src/core.py b/jax/_src/core.py index e99909f24ea7..ec6fc47235f1 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -433,31 +433,37 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() class Var: - __slots__ = ["count", "aval", "initial_qdd", "final_qdd"] + __slots__ = ["count", "suffix", "aval", "initial_qdd", "final_qdd"] count: int + suffix: str aval: AbstractValue # these are only useful for jaxpr binders but rather than create a separate # type for those, breaking existing interpreters, we add fields here. initial_qdd : QuasiDynamicData | None final_qdd : QuasiDynamicData | None - def __init__(self, aval: AbstractValue, initial_qdd = None, final_qdd = None): - assert isinstance(aval, AbstractValue) + def __init__( + self, suffix: str, aval: AbstractValue, initial_qdd=None, final_qdd=None + ): self.count = next(_var_counter) + self.suffix = suffix self.aval = aval self.initial_qdd = initial_qdd self.final_qdd = final_qdd def __repr__(self): - return f'Var(id={id(self)}):{self.aval.str_short()}' + return f"Var(id={id(self)}){self.suffix}:{self.aval.str_short()}" def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): del print_dtype # unused - return f"{context.var_names[self]}" + return f"{context.var_names[self]}{self.suffix}" + +def gensym(suffix: str = "") -> Callable: + """Produce distinct variables, printed with the optional suffix.""" + return partial(Var, suffix) -gensym = lambda: Var # In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that # the assignment is dropped, i.e. that an expression's output value will never @@ -465,7 +471,7 @@ def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): # treat it as a special case of one. Its `aval` is similarly inexact. 
class DropVar(Var): def __init__(self, aval: AbstractValue): - super().__init__(aval) + super().__init__("", aval) def __repr__(self): return '_' def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): del context, print_dtype # unused diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index ad5b0b4f408b..55769aa307fc 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -111,7 +111,7 @@ def _jumble_unflatten(aval, x): register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten) def _jumble_result(axis_size, stacked_axis, ragged_axes, x): - binder = core.Var(core.ShapedArray((), np.dtype('int32'))) + binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) if stacked_axis != 0: raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0 shape = list(x.shape) @@ -175,7 +175,7 @@ def bdim_as_shape( bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape: if isinstance(bdim, RaggedAxis): result = list(data_shape) - binder = core.Var(core.ShapedArray((), np.dtype('int32'))) + binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) for ragged_axis, segment_lens in bdim.ragged_axes: result[ragged_axis] = IndexedAxisSize(binder, segment_lens) return tuple(result) @@ -1138,7 +1138,7 @@ def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): if dst == jumble_axis: x = bdim_at_front(x, src, sz) elt_ty = x.aval.update(shape=x.shape[1:]) - aval = JumbleTy(core.Var(core.ShapedArray((), np.dtype('int32'))), + aval = JumbleTy(core.Var('', core.ShapedArray((), np.dtype('int32'))), x.shape[0], elt_ty) return Jumble(aval, x) try: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 4360a25238d4..fe029ba3b47a 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# # pytype: skip-file from __future__ import annotations @@ -1135,6 +1135,7 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom: def has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) + newvar = core.gensym(suffix='_offload') known_eqns, staged_eqns = [], [] foreach(write, in_unknowns, in_inst, jaxpr.invars) foreach(partial(write, False, True), jaxpr.constvars) @@ -1164,7 +1165,7 @@ def has_effects(effects) -> bool: elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. 
from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error - resvars = [Var(v.aval) for v in eqn.outvars] + resvars = [newvar(v.aval) for v in eqn.outvars] outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, @@ -1300,12 +1301,13 @@ def call_partial_eval_custom_rule( out_binders_known, _ = partition_list(unks_out, eqn.outvars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) + newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, params_staged) - residuals = [Var(res_aval(params_known, var.aval)) + residuals = [newvar(res_aval(params_known, var.aval)) for var in jaxpr_staged.invars[:num_res]] eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1338,13 +1340,14 @@ def closed_call_partial_eval_custom_rule( ins_known, _ = partition_list(unks_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) + newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, sum(f is None for f in out_fwd), num_res, params_known, params_staged) res_val_binders, res_ref_binders = split_list( - [Var(res_aval(params_known, v)) + [newvar(res_aval(params_known, v)) for v in jaxpr_staged.in_avals[:num_res]], [num_res_val]) res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None] res_val_vars = subs_list(out_fwd, out_binders_known, res_val_binders) @@ -2719,7 +2722,7 @@ def inline_jaxpr_into_trace( for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] orig_outvars = eqn.outvars - outvars = [Var(v.aval) for v in orig_outvars] + outvars = [Var('', v.aval) for v in orig_outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) eqn = eqn.replace(invars, outvars, source_info=src_) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index cb80df76326b..b75cbf6ac708 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -184,8 +184,9 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, canonical_non_ref_avals, canonical_non_ref_indices): is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) - padded_ref_constvars = map(core.Var, canonical_ref_avals) - padded_non_ref_constvars = map(core.Var, canonical_non_ref_avals) + newvar = core.gensym(suffix='_') + padded_ref_constvars = map(newvar, canonical_ref_avals) + padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars): padded_ref_constvars[canonical_id] = ref_var for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 
4360c4a6df3b..4e8368341d9f 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -677,7 +677,8 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): for j in branches_known[1:]) # Create residual variables. - res_binders = map(core.Var, all_res_avals) + newvar = core.gensym() + res_binders = map(newvar, all_res_avals) # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar @@ -762,7 +763,8 @@ def f_aug(*args): def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr], all_res_avals, res_aval_indices_per_jaxpr): - all_res_vars = map(core.Var, all_res_avals) + newvar = core.gensym(suffix='_') + all_res_vars = map(newvar, all_res_avals) def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr: num_res = len(res_indices) diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 2568ea8b74a1..fac798fe9dc1 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -235,7 +235,8 @@ def pad_jaxpr_constvars(jaxpr: jax_core.Jaxpr, to pad each Jaxpr with all consts from all branches so the signatures match, but only use the consts for this branch. """ - unused_const_vars = [tuple(map(jax_core.Var, const_avals)) + newvar = jax_core.gensym(suffix='_') + unused_const_vars = [tuple(map(newvar, const_avals)) for const_avals in all_const_avals] const_prefix = util.concatenate(unused_const_vars[:i]) const_suffix = util.concatenate(unused_const_vars[i + 1:]) diff --git a/tests/core_test.py b/tests/core_test.py index 334df2222b0c..646705ebf281 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -551,6 +551,16 @@ def f(x): assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar) core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr + def test_jaxpr_undefined_eqn_invar(self): + jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr + cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos') + cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval) + self.assertRaisesRegex( + core.JaxprTypeError, + r"Variable '.+_test' not defined\n\nin equation:", + lambda: core.check_jaxpr(jaxpr)) + @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): From 64ba9bc454231344f17ef8e5568067292857e017 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 6 Jun 2025 06:15:14 -0700 Subject: [PATCH 1562/1769] [Mosaic GPU] Extract the type-related logic out of `reinterpret_smem_ref`. The new function `_transformed_smem_ref_type` will be used in a follow up change. PiperOrigin-RevId: 768045658 --- .../mosaic/gpu/dialect_lowering.py | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index e9293d9ffe08..4029131baafb 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -614,24 +614,18 @@ def _is_memref_transposed(mem_ref_type: ir.MemRefType) -> bool: return False -def reinterpret_smem_ref( - ref: ir.Value, +def _transformed_smem_ref_type( + ref_ty: ir.MemRefType, transforms: tuple[launch_context.MemRefTransform, ...], -) -> ir.Value: - """Applies transforms on the ref, and makes sure that their effect is - propagated appropriately on the strides. 
- - This function is used any time we lower from a dialect SMEM ref (2D for wgmma) - with given transforms to a "physical" SMEM ref (4D for wgmma) that is fully - transformed and transposed as needed. +) -> ir.MemRefType: + """Returns the transformed ref type for the given logical ref and transforms. """ - ref_ty = ir.MemRefType(ref.type) transposed = _is_memref_transposed(ref_ty) if not transforms and not transposed: - return ref + return ref_ty if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError(f"Only workgroup memory is supported but got {ref}.") + raise ValueError(f"Only workgroup memory is supported but got {ref_ty}.") shape = ref_ty.shape if transposed: @@ -659,23 +653,43 @@ def reinterpret_smem_ref( raise NotImplementedError( f"Expected a 2D or 4D shape after transforms, but got {shape}" ) - strides = [1]*len(shape) - for i in minor_to_major_stride_order[1:]: - strides[i] = strides[i-1] * shape[i-1] - layout = ir.StridedLayoutAttr.get(0, strides) else: - layout = None + minor_to_major_stride_order = tuple(reversed(range(len(shape)))) + + new_strides = [1] * len(shape) + for i in range(1, len(shape)): + dim = minor_to_major_stride_order[i] + prev_dim = minor_to_major_stride_order[i-1] + new_strides[dim] = new_strides[prev_dim] * shape[prev_dim] new_ref_ty = ir.MemRefType.get( shape, ref_ty.element_type, memory_space=ref_ty.memory_space, - layout=layout, + layout=ir.StridedLayoutAttr.get(0, new_strides), ) + return new_ref_ty + + +def reinterpret_smem_ref( + ref: ir.Value, + transforms: tuple[launch_context.MemRefTransform, ...], +) -> ir.Value: + """Applies transforms on the ref, and makes sure that their effect is + propagated appropriately on the strides. + + This function is used any time we lower from a dialect SMEM ref (2D for wgmma) + with given transforms to a "physical" SMEM ref (4D for wgmma) that is fully + transformed and transposed as needed. + """ + ref_ty = ir.MemRefType(ref.type) + new_ref_ty = _transformed_smem_ref_type(ref_ty, transforms) + if ref_ty == new_ref_ty: + return ref ms = utils.WORKGROUP_NVPTX_ADDRESS_SPACE ptr = utils.memref_ptr(ref, memory_space=ms) - ref = utils.ptr_as_memref(ptr, new_ref_ty, ptr_memory_space=ms) - return ref + new_ref = utils.ptr_as_memref(ptr, new_ref_ty, ptr_memory_space=ms) + return new_ref @_register_lowering(mgpu.AsyncLoadOp) From 8478fe6e7b43e1410c79f89d603ea6272e29f115 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 6 Jun 2025 07:37:12 -0700 Subject: [PATCH 1563/1769] Port PartitionSpec to C++. We port: * the class itself * `__init__` * `__hash__` * `__eq__` which is enough to get a small speedup. We also do not change the representation of the data as two Python tuples for the moment. 
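For reviewers, the behavior being ported is small and must stay identical
between the two implementations. A minimal Python sketch of the shared
canonicalization rule (illustrative only; the shipped fallback below also
special-cases the UNCONSTRAINED sentinel):

    # Sketch of the canonicalization both implementations must agree on.
    def canonicalize(partition):
        if not partition:                # falsy values (None, (), []) -> None
            return None
        if isinstance(partition, (tuple, list)):
            if len(partition) == 1:      # a singleton collapses to its element
                return partition[0]
            return tuple(partition)      # lists normalize to tuples
        return partition                 # a bare axis name passes through

    assert canonicalize(()) is None
    assert canonicalize(["x"]) == "x"
    assert canonicalize(["x", "y"]) == ("x", "y")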
PiperOrigin-RevId: 768069350 --- jax/BUILD | 4 + jax/_src/partition_spec.py | 96 +++++++++++------- jaxlib/BUILD | 2 + jaxlib/_jax/__init__.pyi | 19 +++- jaxlib/partition_spec.cc | 201 +++++++++++++++++++++++++++++++++++++ jaxlib/partition_spec.h | 62 ++++++++++++ jaxlib/xla.cc | 2 + jaxlib/xla_client.py | 2 +- 8 files changed, 347 insertions(+), 41 deletions(-) create mode 100644 jaxlib/partition_spec.cc create mode 100644 jaxlib/partition_spec.h diff --git a/jax/BUILD b/jax/BUILD index 80add9c096bd..b54c4afce18e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1113,6 +1113,10 @@ pytype_strict_library( pytype_strict_library( name = "partition_spec", srcs = ["_src/partition_spec.py"], + deps = [ + ":util", + "//jax/_src/lib", + ], ) pytype_strict_library( diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 629af80ed38b..435a1cc50669 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -14,49 +14,61 @@ from __future__ import annotations from collections.abc import Set -from typing import Any +from typing import Any, TYPE_CHECKING -class UnconstrainedSingleton: +from jax._src.lib import jaxlib_extension_version +from jax._src.lib import _jax +from jax._src.util import use_cpp_class, use_cpp_method, set_module - def __repr__(self): - return "UNCONSTRAINED" +export = set_module('jax.sharding') - def __reduce__(self): - return (_get_default_unconstrained, ()) +# TODO(phawkins): the union confuses pytype. Just use the Python branch for now +# until the C++ version is the minimum version. +if not TYPE_CHECKING and jaxlib_extension_version >= 349: + _UNCONSTRAINED_PARTITION = _jax.UNCONSTRAINED_PARTITION + _canonicalize_partition = _jax.canonicalize_partition +else: + class UnconstrainedSingleton: + + def __repr__(self): + return "UNCONSTRAINED" + def __reduce__(self): + return (_get_default_unconstrained, ()) -# Unconstrained sentinel value for PartitionSpec, representing a dimension for -# which the user wants XLA to assign the best partitioning. -# TODO(yashkatariya): May rename to AUTO. -_UNCONSTRAINED_PARTITION = UnconstrainedSingleton() -def _get_default_unconstrained(): - return _UNCONSTRAINED_PARTITION + # Unconstrained sentinel value for PartitionSpec, representing a dimension for + # which the user wants XLA to assign the best partitioning. + # TODO(yashkatariya): May rename to AUTO. + _UNCONSTRAINED_PARTITION = UnconstrainedSingleton() -def _canonicalize_partition(partition): - if not partition: - return None - if partition is _UNCONSTRAINED_PARTITION: + def _get_default_unconstrained(): return _UNCONSTRAINED_PARTITION - if isinstance(partition, (tuple, list)): - if len(partition) == 1: - return partition[0] - return tuple(partition) - return partition - -def _check(partitions, unreduced): - for p in partitions: - p = p if isinstance(p, tuple) else (p,) - for r in p: - if r in unreduced: - raise ValueError( - "partitions cannot overlap with unreduced axes passed to" - f" PartitionSpec. Got partitions: {partitions} and unreduced axes:" - f" {unreduced}") - if None in unreduced: - raise ValueError( - "unreduced cannot contain None. 
All elements in unreduced should refer" - " to the mesh axes.") + + def _canonicalize_partition(partition): + if not partition: + return None + if partition is _UNCONSTRAINED_PARTITION: + return _UNCONSTRAINED_PARTITION + if isinstance(partition, (tuple, list)): + if len(partition) == 1: + return partition[0] + return tuple(partition) + return partition + + def _check(partitions, unreduced): + for p in partitions: + p = p if isinstance(p, tuple) else (p,) + for r in p: + if r in unreduced: + raise ValueError( + "partitions cannot overlap with unreduced axes passed to" + f" PartitionSpec. Got partitions: {partitions} and unreduced axes:" + f" {unreduced}") + if None in unreduced: + raise ValueError( + "unreduced cannot contain None. All elements in unreduced should refer" + " to the mesh axes.") def unpickle_pspec(partitions, unreduced): return PartitionSpec(*partitions, unreduced=unreduced) @@ -72,12 +84,14 @@ class PartitionSpec: This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ - __slots__ = ("_partitions", "unreduced") + if jaxlib_extension_version < 349: + __slots__ = ("_partitions", "unreduced") __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. UNCONSTRAINED = _UNCONSTRAINED_PARTITION + @use_cpp_method() def __init__(self, *partitions, unreduced: Set[AxisName] | None = None): self._partitions = tuple(_canonicalize_partition(p) for p in partitions) @@ -108,6 +122,7 @@ def __iter__(self): def __len__(self): return len(self._partitions) + @use_cpp_method() def __eq__(self, other): if isinstance(other, PartitionSpec): return (self._partitions == other._partitions and @@ -122,6 +137,7 @@ def __eq__(self, other): else: return False + @use_cpp_method() def __hash__(self): return hash((self._partitions, self.unreduced)) @@ -167,3 +183,11 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: if len(out) < ndim: out.extend([None] * (ndim - len(out))) return self.with_partitions(out) + +# TODO(phawkins): make this a decorator after the next jaxlib release. +if not TYPE_CHECKING and jaxlib_extension_version >= 349: + PartitionSpec = use_cpp_class(_jax.PartitionSpec)(PartitionSpec) + +# TODO(phawkins): make this a decorator after the next jaxlib release. +if not TYPE_CHECKING: + PartitionSpec = export(PartitionSpec) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 4eff447c1b68..da5fb4952743 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -816,6 +816,7 @@ cc_library( cc_library( name = "py_client", srcs = [ + "partition_spec.cc", "py_array.cc", "py_client.cc", "py_compile_only_client.cc", @@ -829,6 +830,7 @@ cc_library( "to_ifrt_sharding.cc", ], hdrs = [ + "partition_spec.h", "py_array.h", "py_client.h", "py_compile_only_client.h", diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index bbdc6bbdf896..7bab0fc35547 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -16,7 +16,7 @@ from __future__ import annotations import builtins -from collections.abc import Callable, Iterator, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence, Set import enum import inspect import types @@ -552,7 +552,6 @@ class Client: ) -> PjRtLayout: ... def __getattr__(self, name: str) -> Any: ... - class CompileOnlyPyClient(Client): def compile( self, @@ -561,7 +560,6 @@ class CompileOnlyPyClient(Client): compile_options: CompileOptions = ..., ) -> Executable: ... - class CpuCollectives: ... 
def make_gloo_tcp_collectives( @@ -1004,5 +1002,18 @@ def approx_top_k_reduction_output_size( aggregate_to_topk: bool | None = ..., input_size_override: int | None = ..., ) -> tuple[int, int]: ... - def get_internal_device_put_info() -> dict[str, int]: ... + +class UnconstrainedSingleton: + def __repr__(self) -> str: ... + def __reduce__(self) -> Any: ... + +UNCONSTRAINED_PARTITION: UnconstrainedSingleton + +class PartitionSpec: + def __init__(self, *partitions, unreduced: Set[Any] | None = None): ... + def __hash__(self): ... + def __eq__(self, other): ... + _HAS_DYNAMIC_ATTRIBUTES: bool = ... + +def canonicalize_partition(partition: Any) -> Any: ... diff --git a/jaxlib/partition_spec.cc b/jaxlib/partition_spec.cc new file mode 100644 index 000000000000..6af8b92e1b56 --- /dev/null +++ b/jaxlib/partition_spec.cc @@ -0,0 +1,201 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/partition_spec.h" + +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep + +namespace nb = nanobind; + +namespace jax { + +/*static*/ PyObject* nb_frozenset::nb_frozenset_from_obj(PyObject* o) { + PyObject* result = PyFrozenSet_New(o); + if (!result) { + throw nb::python_error(); + } + return result; +} + +template +bool nb_frozenset::contains(T&& key) const { + object o = nanobind::cast((nb::detail::forward_t)key); + int rv = PySet_Contains(m_ptr, o.ptr()); + if (rv == -1) { + throw nb::python_error(); + } + return rv == 1; +} + +namespace { + +bool IsTrue(nb::handle x) { + int ret = PyObject_IsTrue(x.ptr()); + if (ret == -1) { + throw nb::python_error(); + } + return static_cast(ret); +} + +nb::object CanonicalizePartition(nb::object unconstrained_singleton, + nb::object partition) { + if (!IsTrue(partition)) { + return nb::none(); + } + if (partition.is(unconstrained_singleton)) { + return unconstrained_singleton; + } + bool is_tuple = nb::isinstance(partition); + if (is_tuple || nb::isinstance(partition)) { + if (nb::len(partition) == 1) { + return partition[0]; + } + if (!is_tuple) { + return nb::tuple(partition); + } + return partition; + } + return partition; +} + +void CheckPartitionSpec(nb::tuple partitions, nb_frozenset unreduced) { + if (unreduced.contains(nb::none())) { + throw nb::value_error( + "unreduced cannot contain None. All elements in unreduced should " + "refer to the mesh axes."); + } + auto check_overlap = [&](nb::handle partition) { + if (unreduced.contains(partition)) { + throw nb::value_error( + absl::StrFormat( + "partitions cannot overlap with unreduced axes passed to " + "PartitionSpec. 
Got partitions: %s and unreduced axes: %s", + nb::cast(nb::str(partitions)), + nb::cast(nb::str(unreduced))) + .c_str()); + } + }; + for (nb::handle partition : partitions) { + if (nb::isinstance(partition)) { + for (nb::handle p : partition) { + check_overlap(p); + } + } else { + check_overlap(partition); + } + } +} + +} // namespace + +PartitionSpec::PartitionSpec(nb::tuple partitions, nb_frozenset unreduced) + : partitions_(std::move(partitions)), unreduced_(std::move(unreduced)) {} + +Py_ssize_t PartitionSpec::Hash() const { + size_t h = absl::HashOf(nb::hash(partitions_), nb::hash(unreduced_)); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +bool PartitionSpec::Eq(const nb::object& other) const { + if (!other.ptr()) { + return false; + } + PartitionSpec* other_spec; + if (nb::try_cast(other, other_spec)) { + return partitions().equal(other_spec->partitions()) && + unreduced().equal(other_spec->unreduced()); + } + nb::tuple other_tuple; + if (nb::try_cast(other, other_tuple)) { + if (unreduced().size() > 0 || partitions().size() != other_tuple.size()) { + return false; + } + for (size_t i = 0; i < partitions().size(); ++i) { + if (!partitions()[i].equal(CanonicalizePartition( + *unconstrained_singleton_, other_tuple[i]))) { + return false; + } + } + return true; + } + return false; +} + +nb::object* PartitionSpec::unconstrained_singleton_ = nullptr; + +void PartitionSpec::Register(nb::module_& m) { + nb::class_(m, "UnconstrainedSingleton") + .def("__repr__", [](nb::handle self) { return nb::str("UNCONSTRAINED"); }) + .def("__reduce__", + [](nb::handle self) { return nb::str("UNCONSTRAINED_PARTITION"); }); + + unconstrained_singleton_ = new nb::object(nb::cast(UnconstrainedSingleton())); + m.attr("UNCONSTRAINED_PARTITION") = *unconstrained_singleton_; + + m.def("canonicalize_partition", [](nb::object partition) { + return CanonicalizePartition(*unconstrained_singleton_, partition); + }); + + nb::class_(m, "PartitionSpec") + .def( + "__init__", + [](PartitionSpec* self, nb::args partition_args, + nb::object unreduced_arg) { + nb::tuple partitions = + nb::steal(PyTuple_New(partition_args.size())); + for (size_t i = 0; i < partition_args.size(); ++i) { + PyTuple_SET_ITEM(partitions.ptr(), i, + CanonicalizePartition( + *PartitionSpec::unconstrained_singleton_, + partition_args[i]) + .release() + .ptr()); + } + nb_frozenset unreduced; + if (unreduced_arg.is_none()) { + unreduced = nb_frozenset(); + } else { + if (!PyAnySet_Check(unreduced_arg.ptr())) { + throw nb::type_error( + absl::StrFormat( + "unreduced argument of PartitionSpec should be `None` " + "or of type `frozenset` or `set`. 
Got type %s", + nb::cast(nb::repr(unreduced_arg.type()))) + .c_str()); + } + unreduced = nb_frozenset(unreduced_arg); + } + CheckPartitionSpec(partitions, unreduced); + new (self) + PartitionSpec(std::move(partitions), std::move(unreduced)); + }, + nb::arg("partitions"), nb::arg("unreduced").none() = nb::none()) + .def_prop_ro("_partitions", &PartitionSpec::partitions) + .def_prop_ro("unreduced", &PartitionSpec::unreduced) + .def("__eq__", &PartitionSpec::Eq, nb::arg().none()) + .def("__hash__", &PartitionSpec::Hash); +} + +} // namespace jax diff --git a/jaxlib/partition_spec.h b/jaxlib/partition_spec.h new file mode 100644 index 000000000000..45f577557926 --- /dev/null +++ b/jaxlib/partition_spec.h @@ -0,0 +1,62 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_PARTITION_SPEC_H_ +#define JAX_JAXLIB_PARTITION_SPEC_H_ + +#include + +#include "nanobind/nanobind.h" + +namespace jax { + +struct UnconstrainedSingleton {}; + +class nb_frozenset : public nanobind::object { + NB_OBJECT(nb_frozenset, object, "frozenset", PyFrozenSet_Check) + nb_frozenset() + : object(PyFrozenSet_New(nullptr), nanobind::detail::steal_t()) {} + explicit nb_frozenset(handle h) + : object(nb_frozenset_from_obj(h.ptr()), nanobind::detail::steal_t{}) {} + size_t size() const { return (size_t)NB_SET_GET_SIZE(m_ptr); } + template + bool contains(T&& key) const; + + private: + static PyObject* nb_frozenset_from_obj(PyObject* o); +}; + +class PartitionSpec { + public: + PartitionSpec(nanobind::tuple partitions, nb_frozenset unreduced); + + nanobind::tuple partitions() const { return partitions_; } + nb_frozenset unreduced() const { return unreduced_; } + + bool Eq(const nanobind::object& other) const; + Py_ssize_t Hash() const; + + static void Register(nanobind::module_& m); + + private: + nanobind::tuple partitions_; + nb_frozenset unreduced_; + + static nanobind::object* unconstrained_singleton_; +}; + +} // namespace jax + +#endif // JAX_JAXLIB_PARTITION_SPEC_H_ diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 186bcaf2efd3..69acd43b804a 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -95,6 +95,7 @@ limitations under the License. #include "jaxlib/jax_jit.h" #include "jaxlib/mlir.h" #include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" #include "jaxlib/pjit.h" #include "jaxlib/pmap_lib.h" #include "jaxlib/py_array.h" @@ -963,6 +964,7 @@ NB_MODULE(_jax, m) { m.def("get_internal_device_put_info", []() { return DevicePutInfo::GetInfo(); }); + jax::PartitionSpec::Register(m); } // NOLINT(readability/fn_size) } // namespace xla diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index a098bd4baf3e..e6104a9958c2 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. 
-_version = 348 +_version = 349 # An internal increasing version number for protecting jaxlib code against # ifrt changes. From e4c8da1ee3a191518cf75140a61c0302aee3222e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 6 Jun 2025 09:37:16 -0700 Subject: [PATCH 1564/1769] Fix segfault if None is passed to PartitionSpec.__eq__. PiperOrigin-RevId: 768107991 --- jaxlib/partition_spec.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/partition_spec.cc b/jaxlib/partition_spec.cc index 6af8b92e1b56..3cb58f110c4d 100644 --- a/jaxlib/partition_spec.cc +++ b/jaxlib/partition_spec.cc @@ -119,7 +119,7 @@ Py_ssize_t PartitionSpec::Hash() const { } bool PartitionSpec::Eq(const nb::object& other) const { - if (!other.ptr()) { + if (!other.ptr() || other.is_none()) { return false; } PartitionSpec* other_spec; From 09111c8e1b9c3b01b254c12e37454a684fe514b2 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 6 Jun 2025 10:30:33 -0700 Subject: [PATCH 1565/1769] [Pallas][Mosaic GPU] Expose partitioned collective loads to copy_gmem_to_smem. PiperOrigin-RevId: 768128819 --- jax/_src/pallas/mosaic_gpu/primitives.py | 60 ++++++++++++++++++++-- tests/pallas/mosaic_gpu_test.py | 64 ++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 3fe9b62f3f1e..253bf7180ddc 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -487,6 +487,7 @@ def _copy_gmem_to_smem_lowering( dst_transforms_treedef, barrier_transforms_treedef, collective_axes, + partitioned_axis, for_warpgroup: bool = True, ): flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( @@ -535,6 +536,10 @@ def _copy_gmem_to_smem_lowering( # arrive with the whole transfer size, while everyone else arrives with 0. # But we should continue using this scheme as it's likely to be faster. bytes //= WARPGROUP_SIZE + if collective and partitioned_axis is not None: + raise NotImplementedError( + "Collective partitioned copies not implemented." + ) if ctx.module_ctx.auto_barriers: mgpu.warpgroup_barrier() # Make sure all reads have completed. barrier.arrive_expect_tx(bytes) @@ -542,9 +547,32 @@ def _copy_gmem_to_smem_lowering( # In Warp-level lowering, we arrive on each CUDA thread in a warp, but # the barrier still expects a full 128 arrivals so we arrive 4 times # on each CUDA thread instead. - bytes //= WARP_SIZE - barrier.arrive(arrival_count=3, can_complete=False) - barrier.arrive_expect_tx(bytes) + # TODO(justinfu): The arrival counts are wrong if called outside of a + # single warp. Figure out how to guard against this in user code. + bytes = bytes // WARP_SIZE + if collective and partitioned_axis is not None: + if len(collective) != 1: + raise ValueError( + f"Expected exactly one collective axis, got {collective_axes=}" + ) + if math.prod(ctx.launch_ctx.cluster_size) != 2: + raise NotImplementedError( + "Partitioned loads only supported for clusters of size 2" + ) + # Bytes is the destination size, which is only half of the total + # size of the partitioned transfer so we need to double it. 
+ bytes *= 2 + first_block = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(collective[0]), + mgpu.c(0, ir.IndexType.get()), + ) + with mgpu.when(first_block): + barrier.arrive(arrival_count=3, can_complete=False) + barrier.arrive_expect_tx(bytes) + else: + barrier.arrive(arrival_count=3, can_complete=False) + barrier.arrive_expect_tx(bytes) ctx.launch_ctx.async_copy( src_ref=src, @@ -553,6 +581,7 @@ def _copy_gmem_to_smem_lowering( arrive=False, predicate=ctx.module_ctx.single_lane_predicate, collective=collective, + partitioned=partitioned_axis, **copy_params, ) return () @@ -595,9 +624,33 @@ def copy_gmem_to_smem( barrier: _Ref, *, collective_axes: str | tuple[str, ...] | None = None, + partitioned_axis: int | None = None, ) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. + If collective_axes is specified, this performs a multicast copy where + all CUDA blocks that share the same index along the collective axis + receive a copy of the same block of data loaded from `dst` to `src`. + + If both collective_axes and partitioned_axis are specified, this will perform + a partitioned collective copy where each block in the cluster will receive + a tile of `transfer_size // cluster_size` data from the `src` Ref. + For example, if `src` has a shape of (256, 256) and a partitioned + copy is performed along axis 0 with cluster size 2, then the first block will + receive `src[0:128, :]` and the second will receive `src[128:256, :]`. + NOTE: Only the first block in the cluster will arrive on the barrier, + and an additional cluster barrier is necessary to ensure that all blocks in + the cluster have finished the copy. + + Args: + src: The source Ref. Must be in GMEM. + dst: The destination Ref. Must be in SMEM. + barrier: The barrier to use for tracking completion of the copy. + collective_axes: The collective axes to use for the copy. + partitioned_axis: Indicates which array axis along the src/dst Refs to + partition across during a partitioned collective copy. Requires + collective_axes to also be specified. 
+ See also: :func:`jax.experimental.mosaic.gpu.barrier_arrive` :func:`jax.experimental.mosaic.gpu.barrier_wait` @@ -633,6 +686,7 @@ def copy_gmem_to_smem( dst_transforms_treedef=dst_transforms_treedef, barrier_transforms_treedef=barrier_transforms_treedef, collective_axes=collective_axes, + partitioned_axis=partitioned_axis, ) return None diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index ff4f29aefb59..6817a80c1005 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2762,6 +2762,70 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) + def test_collective_partitioned_copy(self): + self.skip_if_wg_semantics() + block_size = (128, 128) + partitioned_block_size = (block_size[0] // 2, block_size[1]) + a = jax.random.uniform( + jax.random.key(0), shape=block_size, dtype=jnp.float32) + b = jax.random.uniform( + jax.random.key(1), shape=block_size, dtype=jnp.float32) + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, out_smem, + a_tma_barrier, b_tma_barrier, cluster_barrier): + cluster_idx = lax.axis_index("x") + out_slice = pl.ds(cluster_idx * partitioned_block_size[0], + partitioned_block_size[0]) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + plgpu.copy_gmem_to_smem( + a_gmem, + a_smem, + a_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + plgpu.copy_gmem_to_smem( + b_gmem, + b_smem, + b_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + # TODO(justinfu): Clean up this API where we need to explicitly wait + # only on the first block. + @pl.when(cluster_idx == 0) + def _(): + plgpu.barrier_wait(a_tma_barrier) + plgpu.barrier_wait(b_tma_barrier) + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) + out_smem[...] = a_smem[...] + b_smem[...] + plgpu.copy_smem_to_gmem(out_smem, out_gmem.at[out_slice]) + plgpu.wait_smem_to_gmem(0) + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(block_size, jnp.float32), + grid=(1,), + grid_names=("_"), + cluster_names=("x",), + cluster=(2,), + scratch_shapes=( # type: ignore + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(num_arrivals=1), + plgpu.ClusterBarrier(collective_axes=("x",)), + ), + ) + result = f(a, b) + np.testing.assert_array_equal(result, a + b) + class PallasCallSm100AWGTest( PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From 7dd721a89efb6a1fff4046535a392ab5493ffb9f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 6 Jun 2025 11:07:58 -0700 Subject: [PATCH 1566/1769] Optimize jaxpr equation pretty-printing. * Don't spend time on annotation formatting if there are no annotations. * Remove an assertion that the children of a ConcatDoc are Docs. * Fuse _align_annotations with the code that produces strings, which saves allocating a NamedTuple per line. * Don't call .format() to test whether the LHS of a jaxpr equation prints as empty. 
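The first bullet depends on every Doc node knowing up front how many
annotations it contains, so the check is cheap for annotation-free documents.
A minimal sketch of that caching scheme (simplified from the _TextDoc and
_ConcatDoc classes in the diff below):

    # Sketch: precompute annotation counts at construction time so that
    # _sparse() can return early when a document has no annotations.
    class Text:
        def __init__(self, text, annotation=None):
            self.text, self.annotation = text, annotation
        def num_annotations(self):
            return 1 if self.annotation is not None else 0

    class Concat:
        def __init__(self, children):
            self.children = list(children)
            # Doc trees are immutable, so the sum can be cached once.
            self._num = sum(c.num_annotations() for c in children)
        def num_annotations(self):
            return self._num

    doc = Concat([Text("a"), Text("b", annotation="f32[8]")])
    assert doc.num_annotations() == 1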
PiperOrigin-RevId: 768144149 --- jax/_src/core.py | 2 +- jax/_src/pretty_printer.py | 55 ++++++++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ec6fc47235f1..db4366bf735b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3404,7 +3404,7 @@ def _pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings, rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation), pp_kv_pairs([(p, eqn.params[p]) for p in params], context, settings), pp.text(" ") + pp_vars(eqn.invars, context)] - if lhs.format(): + if eqn.outvars: return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs]) else: return pp.concat(rhs) diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index d02b6d9962e0..56a99161fb63 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -93,9 +93,15 @@ def __str__(self): def __add__(self, other: Doc) -> Doc: return concat([self, other]) + def num_annotations(self) -> int: + raise NotImplementedError() + class _NilDoc(Doc): def __repr__(self): return "nil" + def num_annotations(self) -> int: + return 0 + _nil = _NilDoc() class _TextDoc(Doc): @@ -115,16 +121,23 @@ def __repr__(self): else: return f"text(\"{self.text}\")" + def num_annotations(self) -> int: + return 1 if self.annotation is not None else 0 + class _ConcatDoc(Doc): - __slots__ = ("children",) + __slots__ = ("children", "_num_annotations") children: list[Doc] + _num_annotations: int def __init__(self, children: Sequence[Doc]): self.children = list(children) - assert all(isinstance(doc, Doc) for doc in self.children), self.children + self._num_annotations = sum(child.num_annotations() for child in children) def __repr__(self): return f"concat({self.children})" + def num_annotations(self) -> int: + return self._num_annotations + class _BreakDoc(Doc): __slots__ = ("text",) text: str @@ -135,6 +148,9 @@ def __init__(self, text: str): def __repr__(self): return f"break({self.text})" + def num_annotations(self) -> int: + return 0 + class _GroupDoc(Doc): __slots__ = ("child",) child: Doc @@ -145,6 +161,9 @@ def __init__(self, child: Doc): def __repr__(self): return f"group({self.child})" + def num_annotations(self) -> int: + return self.child.num_annotations() + class _NestDoc(Doc): __slots__ = ("n", "child",) n: int @@ -157,6 +176,8 @@ def __init__(self, n: int, child: Doc): def __repr__(self): return f"nest({self.n, self.child})" + def num_annotations(self) -> int: + return self.child.num_annotations() _NO_SOURCE = object() @@ -172,6 +193,8 @@ def __init__(self, child: Doc, source: Any): def __repr__(self): return f"source({self.child}, {self.source})" + def num_annotations(self) -> int: + return self.child.num_annotations() Color = enum.Enum("Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", "MAGENTA", "CYAN", "WHITE", "RESET"]) @@ -193,6 +216,8 @@ def __init__(self, child: Doc, *, foreground: Color | None = None, self.background = background self.intensity = intensity + def num_annotations(self) -> int: + return self.child.num_annotations() _BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"]) @@ -225,6 +250,8 @@ def _fits(doc: Doc, width: int) -> bool: # annotations. 
def _sparse(doc: Doc) -> bool: agenda = [doc] + if doc.num_annotations() == 0: + return True num_annotations = 0 seen_break = False while len(agenda) > 0: @@ -266,7 +293,7 @@ class _State(NamedTuple): class _Line(NamedTuple): text: str width: int - annotations: str | None | list[str] + annotations: list[str] def _update_color(use_color: bool, state: _ColorState, update: _ColorState @@ -284,22 +311,19 @@ def _update_color(use_color: bool, state: _ColorState, update: _ColorState return update, color_str -def _align_annotations(lines): +def _align_annotations(lines: list[_Line], annotation_prefix: str) -> list[str]: # TODO: Hafiz also implements a local alignment mode, where groups of lines # with annotations are aligned together. maxlen = max(l.width for l in lines) out = [] for l in lines: if len(l.annotations) == 0: - out.append(l._replace(annotations=None)) - elif len(l.annotations) == 1: - out.append(l._replace(text=l.text + " " * (maxlen - l.width), - annotations=l.annotations[0])) + out.append(l.text) else: - out.append(l._replace(text=l.text + " " * (maxlen - l.width), - annotations=l.annotations[0])) + out.append(f"{l.text}{' ' * (maxlen - l.width)}" + f"{annotation_prefix}{l.annotations[0]}") for a in l.annotations[1:]: - out.append(_Line(text=" " * maxlen, width=l.width, annotations=a)) + out.append(f"{' ' * maxlen}{annotation_prefix}{a}") return out @@ -366,7 +390,7 @@ def _format( elif isinstance(doc, _GroupDoc): # In Lindig's paper, _fits is passed the remainder of the document. # I'm pretty sure that's a bug and we care only if the current group fits! - if (_sparse(doc) and _fits(doc, width - k)): + if (_fits(doc, width - k) and _sparse(doc)): agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) else: agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) @@ -390,11 +414,8 @@ def _format( line_source_map.append((source_start, pos, source)) source_map.append(line_source_map) lines.append(_Line(line_text, k, line_annotations)) - lines = _align_annotations(lines) - out = "\n".join( - l.text if l.annotations is None - else f"{l.text}{annotation_prefix}{l.annotations}" for l in lines) - color_state, color_str = _update_color(use_color, color_state, + out = "\n".join(_align_annotations(lines, annotation_prefix)) + _, color_str = _update_color(use_color, color_state, default_colors) return out + color_str From 58a193728585a82a622141699025b7376ababdb3 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 6 Jun 2025 12:23:50 -0700 Subject: [PATCH 1567/1769] [JAX] Allow registering callbacks to be called when backends are cleared `jax.extend.backend` allows the user to register a callback that will be called when JAX backends are cleared. The primary purpose is to let the user clear any caches that hold a reference to JAX backends transitively (via JAX `Sharding`/`Mesh`/`Device`) so that it can help destroy cleared backends. 
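A minimal usage sketch, assuming a hypothetical user-side cache that holds
jax.Device references (the cache and its name are illustrative; the hook
itself is the add_clear_backends_callback export added below):

    import jax.extend.backend as jeb

    _device_cache = {}  # hypothetical cache holding jax.Device references

    def _clear_device_cache():
        # Drop cached Device references so a cleared backend can be destroyed.
        _device_cache.clear()

    jeb.add_clear_backends_callback(_clear_device_cache)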
PiperOrigin-RevId: 768173340 --- jax/extend/BUILD | 2 +- jax/extend/backend.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 1147e6bf502f..0615a25d1512 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -61,8 +61,8 @@ pytype_strict_library( name = "backend", srcs = ["backend.py"], deps = [ - "//jax", "//jax:api", + "//jax:util", "//jax:xla_bridge", ], ) diff --git a/jax/extend/backend.py b/jax/extend/backend.py index 8d5488baba16..12c84ecd1f20 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -27,3 +27,8 @@ from jax._src.interpreters.pxla import ( get_default_device as get_default_device ) +from jax._src import ( + util as _util +) +add_clear_backends_callback = _util.cache_clearing_funs.add # type: ignore +del _util From ec61161de1f7a71419653798ca99ba9dcdb2faea Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 6 Jun 2025 13:00:14 -0700 Subject: [PATCH 1568/1769] PRNGKeyArray doesn't have a format field so assign the layout to be None when we are discharging the ref. The reason PRNGKeyArray doesn't have a `format` field is we don't know how to create a logical format.dll for it. PiperOrigin-RevId: 768186349 --- jax/_src/interpreters/pxla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 55f765ba981a..503eec676acb 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2095,7 +2095,8 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr) in_shardings = (*in_shardings, *(c.sharding for c in mut.in_mut)) - in_layouts = (*in_layouts, *(c.format.dll for c in mut.in_mut)) + in_layouts = (*in_layouts, *(c.format.dll if hasattr(c, 'format') else None + for c in mut.in_mut)) donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut) out_layouts_ = iter(zip(out_shardings, out_layouts)) out_shardings, out_layouts = unzip2( From c9289ae10354ad4425e73dcfabb80ebd8b50333c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 6 Jun 2025 13:01:17 -0700 Subject: [PATCH 1569/1769] Reverts c1bb095c5ce5b0286dc5052abf3b597b6f23cea5 PiperOrigin-RevId: 768186737 --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/core.py | 20 +++++++------------- jax/_src/interpreters/batching.py | 6 +++--- jax/_src/interpreters/partial_eval.py | 13 +++++-------- jax/_src/lax/control_flow/common.py | 5 ++--- jax/_src/lax/control_flow/conditionals.py | 6 ++---- jax/_src/pallas/hlo_interpreter.py | 3 +-- tests/core_test.py | 10 ---------- 8 files changed, 21 insertions(+), 44 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 2a056d5c94f0..b11afd4a86de 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -621,7 +621,7 @@ def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr: if v not in used_vars: continue assert isinstance(v, core.Var) - newvar = core.Var(v.suffix, v.aval) + newvar = core.Var(v.aval) finfo = dtypes.finfo(v.aval.dtype) params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant) if v in constvars or v in invars: diff --git a/jax/_src/core.py b/jax/_src/core.py index db4366bf735b..b0ebee9d6082 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -433,37 +433,31 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() class Var: 
- __slots__ = ["count", "suffix", "aval", "initial_qdd", "final_qdd"] + __slots__ = ["count", "aval", "initial_qdd", "final_qdd"] count: int - suffix: str aval: AbstractValue # these are only useful for jaxpr binders but rather than create a separate # type for those, breaking existing interpreters, we add fields here. initial_qdd : QuasiDynamicData | None final_qdd : QuasiDynamicData | None - def __init__( - self, suffix: str, aval: AbstractValue, initial_qdd=None, final_qdd=None - ): + def __init__(self, aval: AbstractValue, initial_qdd = None, final_qdd = None): + assert isinstance(aval, AbstractValue) self.count = next(_var_counter) - self.suffix = suffix self.aval = aval self.initial_qdd = initial_qdd self.final_qdd = final_qdd def __repr__(self): - return f"Var(id={id(self)}){self.suffix}:{self.aval.str_short()}" + return f'Var(id={id(self)}):{self.aval.str_short()}' def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): del print_dtype # unused - return f"{context.var_names[self]}{self.suffix}" - + return f"{context.var_names[self]}" -def gensym(suffix: str = "") -> Callable: - """Produce distinct variables, printed with the optional suffix.""" - return partial(Var, suffix) +gensym = lambda: Var # In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that # the assignment is dropped, i.e. that an expression's output value will never @@ -471,7 +465,7 @@ def gensym(suffix: str = "") -> Callable: # treat it as a special case of one. Its `aval` is similarly inexact. class DropVar(Var): def __init__(self, aval: AbstractValue): - super().__init__("", aval) + super().__init__(aval) def __repr__(self): return '_' def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): del context, print_dtype # unused diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 55769aa307fc..ad5b0b4f408b 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -111,7 +111,7 @@ def _jumble_unflatten(aval, x): register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten) def _jumble_result(axis_size, stacked_axis, ragged_axes, x): - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) + binder = core.Var(core.ShapedArray((), np.dtype('int32'))) if stacked_axis != 0: raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0 shape = list(x.shape) @@ -175,7 +175,7 @@ def bdim_as_shape( bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape: if isinstance(bdim, RaggedAxis): result = list(data_shape) - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) + binder = core.Var(core.ShapedArray((), np.dtype('int32'))) for ragged_axis, segment_lens in bdim.ragged_axes: result[ragged_axis] = IndexedAxisSize(binder, segment_lens) return tuple(result) @@ -1138,7 +1138,7 @@ def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): if dst == jumble_axis: x = bdim_at_front(x, src, sz) elt_ty = x.aval.update(shape=x.shape[1:]) - aval = JumbleTy(core.Var('', core.ShapedArray((), np.dtype('int32'))), + aval = JumbleTy(core.Var(core.ShapedArray((), np.dtype('int32'))), x.shape[0], elt_ty) return Jumble(aval, x) try: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index fe029ba3b47a..4360a25238d4 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -# + # pytype: skip-file from __future__ import annotations @@ -1135,7 +1135,6 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom: def has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) - newvar = core.gensym(suffix='_offload') known_eqns, staged_eqns = [], [] foreach(write, in_unknowns, in_inst, jaxpr.invars) foreach(partial(write, False, True), jaxpr.constvars) @@ -1165,7 +1164,7 @@ def has_effects(effects) -> bool: elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error - resvars = [newvar(v.aval) for v in eqn.outvars] + resvars = [Var(v.aval) for v in eqn.outvars] outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, @@ -1301,13 +1300,12 @@ def call_partial_eval_custom_rule( out_binders_known, _ = partition_list(unks_out, eqn.outvars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, params_staged) - residuals = [newvar(res_aval(params_known, var.aval)) + residuals = [Var(res_aval(params_known, var.aval)) for var in jaxpr_staged.invars[:num_res]] eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1340,14 +1338,13 @@ def closed_call_partial_eval_custom_rule( ins_known, _ = partition_list(unks_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, sum(f is None for f in out_fwd), num_res, params_known, params_staged) res_val_binders, res_ref_binders = split_list( - [newvar(res_aval(params_known, v)) + [Var(res_aval(params_known, v)) for v in jaxpr_staged.in_avals[:num_res]], [num_res_val]) res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None] res_val_vars = subs_list(out_fwd, out_binders_known, res_val_binders) @@ -2722,7 +2719,7 @@ def inline_jaxpr_into_trace( for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] orig_outvars = eqn.outvars - outvars = [Var('', v.aval) for v in orig_outvars] + outvars = [Var(v.aval) for v in orig_outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) eqn = eqn.replace(invars, outvars, source_info=src_) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index b75cbf6ac708..cb80df76326b 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -184,9 +184,8 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, canonical_non_ref_avals, canonical_non_ref_indices): is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] 
nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) - newvar = core.gensym(suffix='_') - padded_ref_constvars = map(newvar, canonical_ref_avals) - padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) + padded_ref_constvars = map(core.Var, canonical_ref_avals) + padded_non_ref_constvars = map(core.Var, canonical_non_ref_avals) for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars): padded_ref_constvars[canonical_id] = ref_var for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 4e8368341d9f..4360c4a6df3b 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -677,8 +677,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): for j in branches_known[1:]) # Create residual variables. - newvar = core.gensym() - res_binders = map(newvar, all_res_avals) + res_binders = map(core.Var, all_res_avals) # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar @@ -763,8 +762,7 @@ def f_aug(*args): def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr], all_res_avals, res_aval_indices_per_jaxpr): - newvar = core.gensym(suffix='_') - all_res_vars = map(newvar, all_res_avals) + all_res_vars = map(core.Var, all_res_avals) def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr: num_res = len(res_indices) diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index fac798fe9dc1..2568ea8b74a1 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -235,8 +235,7 @@ def pad_jaxpr_constvars(jaxpr: jax_core.Jaxpr, to pad each Jaxpr with all consts from all branches so the signatures match, but only use the consts for this branch. 
""" - newvar = jax_core.gensym(suffix='_') - unused_const_vars = [tuple(map(newvar, const_avals)) + unused_const_vars = [tuple(map(jax_core.Var, const_avals)) for const_avals in all_const_avals] const_prefix = util.concatenate(unused_const_vars[:i]) const_suffix = util.concatenate(unused_const_vars[i + 1:]) diff --git a/tests/core_test.py b/tests/core_test.py index 646705ebf281..334df2222b0c 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -551,16 +551,6 @@ def f(x): assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar) core.check_jaxpr(jaxpr) - @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr - def test_jaxpr_undefined_eqn_invar(self): - jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr - cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos') - cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval) - self.assertRaisesRegex( - core.JaxprTypeError, - r"Variable '.+_test' not defined\n\nin equation:", - lambda: core.check_jaxpr(jaxpr)) - @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): From 567d61e2c3e86a582f72fc7f928601ab48c1e542 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 6 Jun 2025 13:04:08 -0700 Subject: [PATCH 1570/1769] Add committed property to MutableArray PiperOrigin-RevId: 768187911 --- jax/_src/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index b0ebee9d6082..a13f2ffd5633 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2346,6 +2346,7 @@ def __init__(self, aval, buf): dtype = property(lambda self: self._aval.dtype) sharding = property(lambda self: self._buf.sharding) format = property(lambda self: self._buf.format) + committed = _committed = property(lambda self: self._buf._committed) def __getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) From 3248d55af58f09dc875f0bea11c1da7648f103cc Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 6 Jun 2025 13:08:58 -0700 Subject: [PATCH 1571/1769] [Pallas TPU] Add custom_vjp_call lowering rule PiperOrigin-RevId: 768189968 --- jax/_src/pallas/mosaic/lowering.py | 16 +++++++++++++++ tests/pallas/tpu_pallas_test.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index edfce4770f10..db17a25a9e5c 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3209,6 +3209,22 @@ def _custom_jvp_call_lowering_rule( return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args) +@register_lowering_rule(custom_derivatives.custom_vjp_call_p) +def _custom_vjp_call_lowering_rule( + ctx: LoweringRuleContext, + *args, + call_jaxpr, + fwd_jaxpr_thunk, + out_trees, + symbolic_zeros, + bwd, + num_consts, +): + if num_consts: raise NotImplementedError + lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) + return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args) + + @register_lowering_rule(debugging.debug_callback_p) def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): del ctx, args, kwargs diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index c232ebeedb38..1ddd0a0bc176 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2085,6 +2085,39 @@ def body(scalar_ref, x_ref, o_ref): 
expected = expected.at[slices].set(x[slices]) np.testing.assert_array_equal(out, expected) + def test_custom_vjp(self): + + @jax.custom_vjp + def f(x): + return jnp.tanh(x) + def f_fwd(x): + return jnp.tanh(x) * 2, () + def f_bwd(_, g): + return (g * 2,) + + f.defvjp(f_fwd, f_bwd) + + def kernel(x_ref, dy_ref, y_ref, y_p_ref, dx_ref): + x = x_ref[...] + y_ref[...] = f(x) + y_p, f_vjp = jax.vjp(f, x) + y_p_ref[...] = y_p + dx_ref[...] = f_vjp(dy_ref[...])[0] + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + dy = jax.random.normal(jax.random.key(1), (8, 128), dtype=jnp.float32) + y, y_p, dx = pl.pallas_call( + kernel, + out_shape=( + jax.ShapeDtypeStruct((8, 128), jnp.float32), + jax.ShapeDtypeStruct((8, 128), jnp.float32), + jax.ShapeDtypeStruct((8, 128), jnp.float32), + ), + )(x, dy) + np.testing.assert_array_equal(y, f(x)) + np.testing.assert_array_equal(y_p, f(x) * 2) + np.testing.assert_array_equal(dx, dy * 2) + class PallasUXTest(PallasBaseTest): From 051386f01290043870785f65d73259aead888005 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 6 Jun 2025 13:24:24 -0700 Subject: [PATCH 1572/1769] Small speedups to pretty-printing. * Cache the dtype to short name conversion. * Use pp.concat when concatenating more than 2 things. This builds a slightly flatter tree of pretty-printer documents. PiperOrigin-RevId: 768195993 --- jax/_src/core.py | 12 ++++++------ jax/_src/dtypes.py | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index a13f2ffd5633..24c9f7dcd261 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3372,13 +3372,13 @@ def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings def pp_kv_pairs(kv_pairs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: if not kv_pairs: return pp.nil() - return pp.group( + return pp.group(pp.concat([ pp.nest(2, pp.concat([ pp.text("["), pp.brk(""), pp.join(pp.brk(), [pp_kv_pair(k, v, context, settings) for k, v in kv_pairs]) - ])) - + pp.brk("") + pp.text("]") - ) + ])), + pp.brk(""), pp.text("]") + ])) def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings ) -> pp.Doc: @@ -3494,10 +3494,10 @@ def pp_jaxpr( def pp_jaxprs(jaxprs: Sequence[ClosedJaxpr | Jaxpr], context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs] - return pp.group(pp.nest(2, pp.concat([ pp.text('('), pp.brk(""), pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context, settings), jaxprs))] - )) + pp.brk("") + pp.text(')') + )), pp.brk(""), pp.text(')')]) ) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index ae3516ea671c..0276f08e7ef4 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -983,6 +983,7 @@ class PrimalTangentDType(ExtendedDType): return PrimalTangentDType() +@functools.cache def short_dtype_name(dtype) -> str: if isinstance(dtype, ExtendedDType): return str(dtype) From 0d1edcce4cfc9dba7f943d0d17502117dd278155 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 6 Jun 2025 14:28:22 -0700 Subject: [PATCH 1573/1769] Set `in_sharding` to UNSPECIFIED if an uncommitted MutableArray is closed over PiperOrigin-RevId: 768220283 --- jax/_src/interpreters/pxla.py | 3 ++- jax/_src/pjit.py | 31 +++++++++++++++---------------- tests/pjit_test.py | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/jax/_src/interpreters/pxla.py 
b/jax/_src/interpreters/pxla.py index 503eec676acb..3f6ee973554a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2094,7 +2094,8 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings, out_layouts): if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr) - in_shardings = (*in_shardings, *(c.sharding for c in mut.in_mut)) + in_shardings = (*in_shardings, *( + pjit.finalize_arg_sharding(c.sharding, c.committed) for c in mut.in_mut)) in_layouts = (*in_layouts, *(c.format.dll if hasattr(c, 'format') else None for c in mut.in_mut)) donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index deb1d92f853b..308c3e884973 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1710,6 +1710,20 @@ def _resolve_out_layouts(out_layouts, out_shardings, out_avals): new_out_layouts.append(out_l) return tuple(new_out_layouts) +def finalize_arg_sharding(arg_s, committed): + if isinstance(arg_s, UnspecifiedValue): + return arg_s + else: + if committed: + # If the arg has a PmapSharding, then reshard it unconditionally. + return UNSPECIFIED if isinstance(arg_s, PmapSharding) else arg_s + else: + assert isinstance(arg_s, Sharding) + if dispatch.is_single_device_sharding(arg_s): + return UNSPECIFIED + raise NotImplementedError('Having uncommitted Array sharded on ' + 'multiple devices is not supported.') + def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] ) -> Sequence[PjitSharding]: # If True, means that device or backend is set by the user on pjit and it @@ -1744,22 +1758,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if isinstance(arg_s, NamedSharding) and arg_s.mesh.empty: arg_s, committed = UNSPECIFIED, False if isinstance(pjit_in_s, UnspecifiedValue): - if isinstance(arg_s, UnspecifiedValue): - resolved_in_shardings.append(arg_s) - else: - if committed: - # If the arg has a PmapSharding, then reshard it unconditionally. - if isinstance(arg_s, PmapSharding): - resolved_in_shardings.append(UNSPECIFIED) - else: - resolved_in_shardings.append(arg_s) - else: - assert isinstance(arg_s, Sharding) - if dispatch.is_single_device_sharding(arg_s): - resolved_in_shardings.append(UNSPECIFIED) - else: - raise NotImplementedError('Having uncommitted Array sharded on ' - 'multiple devices is not supported.') + resolved_in_shardings.append(finalize_arg_sharding(arg_s, committed)) else: if (isinstance(arg, np.ndarray) and not pjit_in_s.is_fully_replicated and # type: ignore[union-attr] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8c9dcc05d406..f2344f7b3ce1 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3772,6 +3772,21 @@ def test_sharding_on_output_with_vmap(self): self.assertIsInstance(out3.sharding, NamedSharding) self.assertEqual(count(), 1) + @config.numpy_dtype_promotion('standard') + def test_mutable_array_closed_over_multi_device(self): + mesh = jtu.create_mesh((2,), ('x',)) + key_data = jax.random.key_data(jax.random.key(42)) + key_data_ref = core.mutable_array(key_data) + output_sharding = NamedSharding(mesh, P('x')) + + @partial(jax.jit, out_shardings=output_sharding) + def generate_random_numbers(): + key_val = key_data_ref[...] 
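+      # key_data_ref is closed over rather than passed as an argument; it is
+      # uncommitted and single-device, so finalize_arg_sharding above resolves
+      # its in_sharding to UNSPECIFIED.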
+ outputs = jnp.arange(8, dtype=jnp.float32) + key_val[0] + return outputs + + generate_random_numbers() # doesn't crash + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_jit_mul_sum_sharding_preserved(self): if config.use_shardy_partitioner.value: From 6ae261419b7c6b8955db7ddfd6e3e911fa5436d8 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 6 Jun 2025 17:07:04 -0700 Subject: [PATCH 1574/1769] [Pallas][Mosaic GPU] Add collective (CTA-pair) MMAs to blackwell matmul kernel. PiperOrigin-RevId: 768275894 --- jax/_src/pallas/mosaic_gpu/primitives.py | 9 +- .../pallas/ops/gpu/blackwell_matmul_mgpu.py | 149 ++++++++++-------- 2 files changed, 89 insertions(+), 69 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 253bf7180ddc..862f3affda14 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1319,12 +1319,15 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, raise ValueError("RHS must be an SMEM Ref.") if collective_axis is not None: - if not acc.collective: + # TODO(justinfu): If under a core_map, the avals for acc/a + # become normal MemRefs so we cannot check if they are collective. + # Figure out a way to fix this. + if isinstance(acc, gpu_core.AbstractTMEMRef) and not acc.collective: raise ValueError( "Accumulator Ref must be collective if collective_axis is set.") - if a.memory_space == gpu_core.TMEM and not a.collective: + if isinstance(a, gpu_core.AbstractTMEMRef) and not a.collective: raise ValueError( - "LHS TMEM Ref must be collective if collective_axis is set.") + "LHS Ref must be collective if collective_axis is set.") for_tensor_core = getattr( barrier.inner_aval.dtype, "for_tensor_core", False) diff --git a/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py index df8365f843a8..ad210066c5e0 100644 --- a/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py @@ -27,9 +27,9 @@ @dataclasses.dataclass(frozen=True) class TuningConfig: - block_m: int - block_n: int - block_k: int + tile_m: int + tile_n: int + tile_k: int max_concurrent_steps: int collective: bool @@ -57,72 +57,86 @@ def matmul_kernel(a, b, config: TuningConfig): f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}" ) collective = config.collective + tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k) + block_tile_m = tile_m + block_tile_n = tile_n if collective: - raise ValueError("Collective matmul is not supported yet.") - block_m, block_n, block_k = (config.block_m, config.block_n, config.block_k) - swizzle = _find_swizzle(block_k * jnp.dtype(dtype).itemsize * 8) + tile_m *= 2 + tile_n *= 2 + swizzle = _find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) swizzle_elems = swizzle // jnp.dtype(dtype).itemsize transforms = ( plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle), ) - block_lhs = (block_m, block_k) - block_rhs = (block_k, block_n) - block_out = (block_m, block_n) - if m % block_m != 0: - raise ValueError(f"{m=} must be divisible by {block_m=}") - if n % block_n != 0: - raise ValueError(f"{n=} must be divisible by {block_n=}") - if k % block_k != 0: - raise ValueError(f"{k=} must be divisible by {block_k=}") - m_iters = m // block_m - n_iters = n // block_n - k_iters = k // block_k + block_lhs = (block_tile_m, tile_k) + block_rhs = (tile_k, block_tile_n) + block_out = 
(block_tile_m, tile_n) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // tile_m + n_iters = n // tile_n + k_iters = k // tile_k max_concurrent_steps = config.max_concurrent_steps + TMA_WARP = 0 + MMA_WARP = 1 + def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, acc_tmem, acc_smem, a_tma_barrier, b_tma_barrier, consumed_barrier): m_index = lax.axis_index("m") n_index = lax.axis_index("n") - slice_m = pl.ds(m_index * block_m, block_m) - slice_n = pl.ds(n_index * block_n, block_n) - acc_slice_m = pl.ds(m_index * block_m, block_m) - acc_slice_n = pl.ds(n_index * block_n, block_n) + if collective: + cluster_idx = lax.axis_index("x") + block_m_index = m_index * 2 + cluster_idx + is_lead_block = cluster_idx == 0 + else: + block_m_index = m_index + is_lead_block = True + block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m) + slice_m = pl.ds(m_index * tile_m, tile_m) + slice_n = pl.ds(n_index * tile_n, tile_n) @pl.core_map(plgpu.WarpMesh(axis_name="warp")) def _per_warp(): warp_id = lax.axis_index("warp") - - @pl.when(warp_id == 0) + @pl.when(warp_id == TMA_WARP) def _memory(): def _loop_body(ki, _): + slice_k = pl.ds(ki * tile_k, tile_k) slot = lax.rem(ki, max_concurrent_steps) - @pl.when(ki >= max_concurrent_steps) def _(): plgpu.barrier_wait(consumed_barrier.at[slot]) - - slice_k = pl.ds(ki * block_k, block_k) plgpu.copy_gmem_to_smem( a_gmem.at[slice_m, slice_k], a_smem.at[slot], a_tma_barrier.at[slot], + partitioned_axis=0 if collective else None, + collective_axes="x" if collective else None, ) plgpu.copy_gmem_to_smem( b_gmem.at[slice_k, slice_n], b_smem.at[slot], b_tma_barrier.at[slot], + partitioned_axis=1 if collective else None, + collective_axes="x" if collective else None, ) lax.fori_loop(0, k_iters, _loop_body, None) - @pl.when(warp_id == 1) + @pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block)) def _compute(): def _loop_body(ki, _): slot = lax.rem(ki, max_concurrent_steps) plgpu.barrier_wait(a_tma_barrier.at[slot]) plgpu.barrier_wait(b_tma_barrier.at[slot]) + is_last_iter = ki >= k_iters - 1 barrier_slot = lax.select_n(is_last_iter, slot, max_concurrent_steps) @@ -132,46 +146,42 @@ def _loop_body(ki, _): b_smem.at[slot], consumed_barrier.at[barrier_slot], accumulate=(ki > 0), + collective_axis="x" if collective else None, ) + lax.fori_loop(0, k_iters, _loop_body, None) plgpu.barrier_wait(consumed_barrier.at[max_concurrent_steps]) acc_smem[...] = acc_tmem[...].astype(dtype) plgpu.commit_smem() - plgpu.copy_smem_to_gmem( - acc_smem, out_gmem.at[acc_slice_m, acc_slice_n] - ) + plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[block_slice_m, slice_n]) plgpu.wait_smem_to_gmem(0) f = plgpu.kernel( kernel, out_shape=jax.ShapeDtypeStruct((m, n), dtype), - grid=(m_iters, n_iters), - grid_names=("m", "n"), - # TODO(justinfu): Add collective support. 
- cluster_names=(), - cluster=(), - scratch_shapes=( # type: ignore - plgpu.SMEM( - (max_concurrent_steps, *block_lhs), dtype, transforms=transforms - ), - plgpu.SMEM( - (max_concurrent_steps, *block_rhs), dtype, transforms=transforms - ), - plgpu.TMEM(block_out, jnp.float32, collective=collective), - plgpu.SMEM(block_out, dtype, transforms=transforms), - plgpu.Barrier( - num_arrivals=1, num_barriers=max_concurrent_steps - ), - plgpu.Barrier( - num_arrivals=1, num_barriers=max_concurrent_steps - ), - plgpu.Barrier( - num_arrivals=1, - num_barriers=max_concurrent_steps + 1, - for_tensor_core=True, - ), - ) + # n, m generally works better for most shapes. + grid=(n_iters, m_iters), + grid_names=("n", "m"), + cluster_names=("x",) if collective else (), + cluster=(2,) if collective else (), + scratch_shapes=( # type: ignore + plgpu.SMEM( + (max_concurrent_steps, *block_lhs), dtype, transforms=transforms + ), + plgpu.SMEM( + (max_concurrent_steps, *block_rhs), dtype, transforms=transforms + ), + plgpu.TMEM(block_out, jnp.float32, collective=collective), + plgpu.SMEM(block_out, dtype, transforms=transforms), + plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps), + plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps), + plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps + 1, + for_tensor_core=True, + ), + ), ) return f(a, b) @@ -184,18 +194,22 @@ def main(_) -> None: print(f"==== {M=} {N=} {K=} ====") matmul_flops = 2 * M * N * K peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS - a = jax.random.uniform(jax.random.key(0), (M, K), jnp.bfloat16) - b = jax.random.uniform(jax.random.key(1), (K, N), jnp.bfloat16) + a = jax.random.uniform(jax.random.key(0), (M, K), jnp.float16) + b = jax.random.uniform(jax.random.key(1), (K, N), jnp.float16) tuning_it = itertools.product( - (128,), (128, 256), (64, 128), (2, 3, 4), (False,) + (128,), # tile_m + (128, 256), # tile_n + (64, 128), # tile_k + (2, 3, 4, 6), # max_concurrent_steps + (False, True), # collective ) best_util = -float("inf") - for (block_m, block_n, block_k, + for (tile_m, tile_n, tile_k, max_concurrent_steps, collective) in tuning_it: config = TuningConfig( - block_m=block_m, - block_n=block_n, - block_k=block_k, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, max_concurrent_steps=max_concurrent_steps, collective=collective, ) @@ -204,7 +218,9 @@ def main(_) -> None: functools.partial(matmul_kernel, config=config) )(a, b) except ValueError as e: - if "exceeds available shared memory" in e.args[0]: + if ("exceeds available shared memory" in e.args[0] or + "Accumulator layout mismatch:" in e.args[0]): + # Accumulator layout mismatch triggers for tile_n=256 on some configs. continue raise if M * N * K <= 1024 * 1024 * 1024: @@ -216,7 +232,8 @@ def main(_) -> None: if achieved_tc_util > best_util: best_util = achieved_tc_util print( - f"{block_m=} {block_n=} {block_k=} {max_concurrent_steps=}: " + f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} " + f"{collective=} : " f"{runtime_us:<7.1f}us" f" = {achieved_tc_util:4.1f}% TC utilization" ) From 3699aa74ef3e95cb267349ce4c6fc54032f4ea86 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 30 May 2025 21:56:43 +0000 Subject: [PATCH 1575/1769] [mutable-arrays] make partial_eval_jaxpr forward input-residuals See the PR message on https://github.com/jax-ml/jax/pull/29311 for more. 
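The gist, as a minimal sketch (the two standalone helpers below are illustrative stand-ins, not the exact functions this patch adds): a residual that is, by object identity, one of the primal inputs does not need to be stored as an output of the known/forward jaxpr; it can be replaced by an index into the inputs and re-fetched when the unknown/backward side runs.

    def filter_forwarded_inputs(residuals, inputs):
        # Map each input's object id to its position.
        idxs = {id(x): i for i, x in enumerate(inputs)}
        # fwds[i] is an input index if residual i is a forwarded input, else None.
        fwds = [idxs.get(id(r)) for r in residuals]
        # Only opaque (non-input) residuals need to be stored.
        stored = [r for r in residuals if id(r) not in idxs]
        return stored, fwds

    def reconstruct_residuals(stored, fwds, inputs):
        # Re-interleave stored residuals with forwarded inputs, in order.
        stored_ = iter(stored)
        res = [next(stored_) if f is None else inputs[f] for f in fwds]
        assert next(stored_, None) is None
        return res

The change below threads exactly this bookkeeping through partial evaluation and custom_vjp (see _filter_forwarded_inputs and partial_eval_jaxpr_nounits_fwd); among other things, it lets a custom_vjp fwd rule return an argument's mutable array reference as a residual without materializing it as an extra output.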
--- jax/_src/custom_derivatives.py | 88 +++++++++++++---------- jax/_src/interpreters/ad.py | 19 +++-- jax/_src/interpreters/batching.py | 13 +++- jax/_src/interpreters/partial_eval.py | 97 +++++++++++++------------- jax/_src/pallas/fuser/fusible_dtype.py | 2 +- jax/_src/pjit.py | 53 ++++++++------ tests/api_test.py | 14 ++-- tests/custom_api_test.py | 8 +-- tests/debug_info_test.py | 5 +- tests/mutable_array_test.py | 63 ++++++++++++++++- 10 files changed, 232 insertions(+), 130 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 87407efebd3a..af1db5ac9ebc 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -87,7 +87,7 @@ def _flatten_fun_nokwargs(f: Callable, ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) ans_avals = [core.get_aval(x) for x in ans_flat] - store.store((ans_tree, ans_avals)) + store.store((ans_tree, ans_avals, ())) return ans_flat @@ -301,7 +301,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable in_tree, out_type1) out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat, symbolic_zeros=self.symbolic_zeros) - _, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2) + _, (out_tree, _, _) = lu.merge_linear_aux(out_type1, out_type2) return tree_unflatten(out_tree, out_flat) @partial(lu.transformation_with_aux2, use_eq_store=True) @@ -328,7 +328,7 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None if out_type_ is not None: - out_tree_, primal_avals_ = out_type_ + out_tree_, primal_avals_, () = out_type_ ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) if out_tree_ != out_tree: @@ -380,7 +380,7 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args if av_et != av_t) raise TypeError(msg.format('\n'.join(disagreements))) - store.store((out_tree, primal_avals)) + store.store((out_tree, primal_avals, ())) return primals_out + tangents_out class CustomJVPCallPrimitive(core.Primitive): @@ -736,7 +736,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees, symbolic_zeros=self.symbolic_zeros) - _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees) + _, (out_tree, _, _) = lu.merge_linear_aux(out_type, out_trees) return tree_unflatten(out_tree, out_flat) @lu.transformation2 @@ -744,7 +744,7 @@ def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], debug_info: core.DebugInfo, *args): _check_for_aliased_refs(f, nondiff_argnums, debug_info, args) out = f(*args) - _check_for_returned_refs(f, out, 'primal') + _check_for_returned_refs(f, out, 'primal', [], 0) return out def _check_for_aliased_refs(f: Callable, @@ -763,14 +763,20 @@ def _check_for_aliased_refs(f: Callable, f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and" f" {arg_names[i]}.") -def _check_for_returned_refs(f, out, kind): +def _check_for_returned_refs(f, out, kind, args, after_idx): + ids = {id(x) for x in args if isinstance(core.get_aval(x), AbstractRef)} leaves = tree_leaves_with_path(out) - for path, leaf in leaves: + for i, (path, leaf) in enumerate(leaves): if isinstance((a := core.get_aval(leaf)), AbstractRef): loc = f' at output tree path {keystr(path)}' if path else '' - raise 
ValueError(f"custom_vjp {kind} function {f} returned a mutable " - f"a array reference of type {a.str_short()}{loc}, " - "but mutable array references cannot be returned.") + if i < after_idx: + raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " + f"array reference of type {a.str_short()}{loc}, " + "but mutable array references cannot be returned there.") + if id(leaf) not in ids: + raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " + f"array reference of type {a.str_short()}{loc} " + "that was not an argument.") @dataclasses.dataclass class CustomVJPPrimal: @@ -825,8 +831,6 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, if config.mutable_array_checks.value: _check_for_aliased_refs(f, nondiff_argnums, debug_primal, py_args) pair_out = f(*py_args) - if config.mutable_array_checks.value: - _check_for_returned_refs(f, pair_out, "fwd") if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -839,12 +843,14 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, py_primals_out, res = pair_out primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) + if config.mutable_array_checks.value: + _check_for_returned_refs(f, pair_out, "fwd", args, out_tree.num_leaves) primal_avals = [core.get_aval(x) for x in primals_out] # If the primal function already ran, check out_tree agreement. try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None if out_type_ is not None: - out_tree_, primal_avals_ = out_type_ + out_tree_, primal_avals_, () = out_type_ ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) if out_tree_ != out_tree: @@ -876,15 +882,21 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - store.store((out_tree, res_tree)) - return (*res, *primals_out) + pruned_res, input_forwards = _filter_forwarded_inputs(res, args) # prune + store.store((out_tree, res_tree, input_forwards)) + return (*pruned_res, *primals_out) + +def _filter_forwarded_inputs(outs, ins): + idxs: dict[int, int] = {id(x): i for i, x in enumerate(ins)} + return [o for o in outs if id(o) not in idxs], [idxs.get(id(o)) for o in outs] @lu.transformation2 def _flatten_bwd(f: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - out_trees: Callable[[], Sequence[PyTreeDef]], *args): - out_tree, res_tree = out_trees() + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], + *args): + out_tree, res_tree, _ = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) py_res = tree_unflatten(res_tree, res) @@ -980,20 +992,19 @@ def get_bind_params(self, params): fwd_jaxpr_thunk = new_params.pop('fwd_jaxpr_thunk') fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), debug_info=call_jaxpr.jaxpr.debug_info) - fwd = lift_fwd(num_consts, fwd_jaxpr_thunk) + fwd = lift_fwd(num_consts, new_params['out_trees'], fwd_jaxpr_thunk) const_avals, _ = split_list(call_jaxpr.in_avals, [num_consts]) bwd = _handle_consts_in_bwd(new_params.pop('bwd'), const_avals) return [fun, fwd, bwd], new_params -def lift_fwd(num_consts: int, fwd_jaxpr_thunk: lu.WrappedFun) -> lu.WrappedFun: +def lift_fwd(num_consts: int, out_trees: Callable, 
fwd_jaxpr_thunk: lu.WrappedFun) -> lu.WrappedFun: def fwd(*args): - vals, zeros = args[::2], args[1::2] - assert len(vals) == len(zeros) + vals, nonzeros = args[::2], args[1::2] + assert len(vals) == len(nonzeros) _, primals = split_list(vals, [num_consts]) - const_zeros, in_zeros = split_list(zeros, [num_consts]) - if any(const_zeros): - raise ad.CustomVJPException() - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*in_zeros) + const_nonzeros, in_nonzeros = split_list(nonzeros, [num_consts]) + if any(const_nonzeros): raise ad.CustomVJPException() + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*in_nonzeros) return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *primals) return lu.wrap_init(fwd, debug_info=fwd_jaxpr_thunk.debug_info) @@ -1022,7 +1033,7 @@ def _custom_vjp_call_dce( call_jaxpr: core.ClosedJaxpr = eqn.params["call_jaxpr"] fwd_jaxpr_thunk = eqn.params["fwd_jaxpr_thunk"] bwd: lu.WrappedFun = eqn.params["bwd"] - out_trees: Callable[[], Sequence[PyTreeDef]] = eqn.params["out_trees"] + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]] = eqn.params["out_trees"] symbolic_zeros: bool = eqn.params["symbolic_zeros"] dce_call_jaxpr: core.ClosedJaxpr used_ins: Sequence[bool] @@ -1034,14 +1045,14 @@ def _custom_vjp_call_dce( @pe._memoize def dce_fwd_jaxpr_thunk(*zeros): fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk.call_wrapped(*zeros)) - _, res_tree = out_trees() - num_res = res_tree.num_leaves + _, res_tree, fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in fwds) dce_fwd_jaxpr, _ = _cached_closed_call_dce_instantiate( - fwd_jaxpr, (True,) * num_res + tuple(used_outs)) + fwd_jaxpr, (True,) * num_res_out + tuple(used_outs)) return dce_fwd_jaxpr.jaxpr, dce_fwd_jaxpr.consts def dce_bwd(*args): - _, res_tree = out_trees() + _, res_tree, _ = out_trees() res, cts = split_list(args, [res_tree.num_leaves]) cts_ = iter(cts) all_cts = [] @@ -1293,8 +1304,8 @@ def _maybe_perturbed(x: Any) -> bool: @cache() def _closure_convert_for_avals(fun, in_tree, in_avals, debug_info: core.DebugInfo): - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun, debug_info=debug_info), - in_tree) + wrapped_fun, out_tree = flatten_fun_nokwargs( + lu.wrap_init(fun, debug_info=debug_info), in_tree) jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) out_tree = out_tree() @@ -1641,8 +1652,8 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: in_avals = [core.get_aval(x) for x in args_flat] fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) - prim_tree, res_tree = out_trees() - num_res = res_tree.num_leaves + prim_tree, res_tree, fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in fwds) disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fwd_jaxpr.effects) if disallowed_effects: @@ -1656,9 +1667,12 @@ def fun_jaxpr_thunk(): return jaxpr, consts out_flat = remat_opt_p.bind(*consts, *args_flat, num_consts=len(consts), - num_res=num_res, fwd_jaxpr=fwd_jaxpr, + num_res=num_res_out, fwd_jaxpr=fwd_jaxpr, fun_jaxpr_thunk=fun_jaxpr_thunk) - res, out_flat = split_list(out_flat, [num_res]) + res, out_flat = split_list(out_flat, [num_res_out]) + res_ = iter(res) + res = [next(res_) if f is None else args_flat[f] for f in fwds] + assert next(res_, None) is None out_tree = treedef_tuple((prim_tree, res_tree)) return tree_unflatten(out_tree, (*out_flat, *res)) diff --git 
a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 69a123b12e23..51f007f4e576 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -569,8 +569,13 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, with core.set_current_trace(self.parent_trace): res_and_primals_out = fwd.call_wrapped(*fwd_in) - _, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds) + res_out, primals_out = split_list(res_and_primals_out, [num_res_out]) + res_out_ = iter(res_out) + res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds] + assert next(res_out_, None) is None + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] in_zeros = [type(t) is Zero for t in tangents_in] nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z] @@ -734,7 +739,7 @@ def _f_jvp(primals, tangents): def process_custom_vjp_call(self, prim, fun, fwd, bwd: lu.WrappedFun, tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): @@ -746,8 +751,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, with core.set_current_trace(self.parent_trace): res_and_primals_out = fwd.call_wrapped(*fwd_in_flat) - _, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds) + res_out, primals_out = split_list(res_and_primals_out, [num_res_out]) + res_out_ = iter(res_out) + res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds] + assert next(res_out_, None) is None avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] in_zeros = [type(t) is Zero for t in tangents_in] diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index ad5b0b4f408b..d97dd124a2fe 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -585,14 +585,21 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims) - bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, out_dims2, in_dims) + def bwd_in_dims(): + _, _, input_fwds = out_trees() + pruned_dims = iter(out_dims2()) + full_dims = [next(pruned_dims) if f is None else in_dims[f] for f in input_fwds] + return [*full_dims, *pruned_dims] + + bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, bwd_in_dims, in_dims) out_vals = prim.bind_with_trace(self.parent_trace, (fun, fwd, bwd) + tuple(in_vals), dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: - _, res_tree = out_trees() - _, out_dims = split_list(out_dims, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res = res_tree.num_leaves - sum(f is not None for f in input_fwds) + _, out_dims = split_list(out_dims, [num_res]) src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] diff --git a/jax/_src/interpreters/partial_eval.py 
b/jax/_src/interpreters/partial_eval.py index 4360a25238d4..1503889b04bc 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -997,74 +997,69 @@ def partial_eval_jaxpr_nounits( passed to jaxpr_unknown (as leading inputs). """ instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate) + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, False)[:-1] + +def partial_eval_jaxpr_nounits_fwd( + jaxpr: ClosedJaxpr, unknowns: Sequence[bool], + instantiate: bool | Sequence[bool], +) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]: + instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, True) @weakref_lru_cache -def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr, - in_unknowns: Sequence[bool], - instantiate: bool | Sequence[bool]): - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), - debug_info=jaxpr.jaxpr.debug_info) +def _partial_eval_jaxpr_nounits( + jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool], + instantiate: bool | Sequence[bool], fwd: bool): + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) cell = [] def fun(*known_vals_in): - known_vals_in = iter(known_vals_in) + known_vals_in_ = iter(known_vals_in) unknown_avals = (a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk) in_pvals = [PartialVal.unknown(next(unknown_avals)) if uk - else PartialVal.known(next(known_vals_in)) for uk in in_unknowns] - assert next(known_vals_in, None) is next(unknown_avals, None) is None - jaxpr_unknown_, out_pvals, residuals = trace_to_jaxpr_nounits( - f, in_pvals, instantiate=instantiate) + else PartialVal.known(next(known_vals_in_)) for uk in in_unknowns] + assert next(known_vals_in_, None) is next(unknown_avals, None) is None + jaxpr_unknown_, (fwds, out_pvals, residuals, ()) = trace_to_subjaxpr_nounits_fwd( + f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] + if not fwd: + residuals_ = iter(residuals) + residuals = [next(residuals_) if f is None else known_vals_in[f] + for f in fwds] + assert next(residuals_, None) is None + fwds = [None] * len(fwds) + else: + fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals) res_avals = [core.get_aval(r) for r in residuals] - cell.append((out_unknowns, jaxpr_unknown, res_avals)) + cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds)) known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] known_avals = [a for a, uk in zip(jaxpr.in_aval_qdds, in_unknowns) if not uk] jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic( lu.wrap_init(fun, debug_info=f.debug_info), known_avals) - (out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking + (out_unknowns, jaxpr_unknown, res_avals, fwds), = cell # pytype: disable=bad-unpacking - # check jaxpr_known and jaxpr_unknown in isolation - # TODO(mattjj): enable weak type checking here if config.enable_checks.value: core.check_jaxpr(jaxpr_known) core.check_jaxpr(jaxpr_unknown) - def check(first, second): - for f, s in zip(first, second): - if (not isinstance(f, core.ShapedArray) and - not isinstance(s, core.ShapedArray)): - assert f == s - elif f.sharding.mesh.empty 
or s.sharding.mesh.empty: - assert (f.shape, f.dtype) == (s.shape, s.dtype) - else: - assert f == s, (f, s) - - # check jaxpr_known has input type corresponding to known inputs of jaxpr - assert ([v.aval for v in jaxpr_known.invars] == - [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]) - # check jaxpr_known has out type corresponding to known outs of jaxpr plus res - # Change this to `assert ... == ...` and remove the check function. - # See https://github.com/jax-ml/jax/issues/26474 - check([v.aval.strip_weak_type() for v in jaxpr_known.outvars], - [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) - if not uk] + [a.strip_weak_type() for a in res_avals]) - # check jaxpr_unknown has input type corresponding to res plus unknown inputs - assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.invars] == - [a.strip_weak_type() for a in res_avals] + - [a.strip_weak_type() for a, uk in zip(jaxpr.in_avals, in_unknowns) - if uk]) - # check jaxpr_unknown has output type corresponding to unknown outputs - check([v.aval.strip_weak_type() for v in jaxpr_unknown.outvars], - [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) - if uk]) - closed_jaxpr_known = ClosedJaxpr(jaxpr_known, consts_known) closed_jaxpr_unknown = ClosedJaxpr(jaxpr_unknown, ()) - return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals + return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals, fwds + +def _include_consts_in_fwds(consts, fwds, residuals): + if all(f is None for f in fwds): + return fwds, residuals + dummys = [object() for _ in range(max(f for f in fwds if f is not None) + 1)] + residuals_ = iter(residuals) + residuals = [next(residuals_) if f is None else dummys[f] for f in fwds] + assert next(residuals_, None) is None + idxs = {id(x): i for i, x in enumerate((*consts, *dummys))} + fwds = [idxs.get(id(r)) for r in residuals] + residuals = [r for r in residuals if id(r) not in idxs] + return fwds, residuals def partial_eval_jaxpr_custom( @@ -2161,13 +2156,14 @@ def jvp_jaxpr_thunk(*in_zeros): def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) + num_consts = len(consts) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) @partial(lu.wrap_init, debug_info=fwd.debug_info) @@ -2179,6 +2175,11 @@ def fwd_jaxpr_from_zeros(*zeros): if attrs: raise NotImplementedError return jaxpr, consts + def out_trees_(): + out_tree, res_tree, input_fwds = out_trees() + input_fwds = [f if f is None else f + num_consts for f in input_fwds] + return out_tree, res_tree, input_fwds + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) @@ -2186,8 +2187,8 @@ def fwd_jaxpr_from_zeros(*zeros): eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, - num_consts=len(consts), - bwd=bwd, out_trees=out_trees, + num_consts=num_consts, + bwd=bwd, out_trees=out_trees_, 
symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, source_info) diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 152b20ff66ea..53358231dcf8 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -338,7 +338,7 @@ def _custom_vjp_call_physicalize_rule( new_jaxpr = physicalize_closed_jaxpr(call_jaxpr) fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr), debug_info=call_jaxpr.jaxpr.debug_info) - fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk) + fwd = custom_derivatives.lift_fwd(num_consts, kwargs['out_trees'](), fwd_jaxpr_thunk) fwd_physicalized = _physicalize_transform(fwd) const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts]) bwd_physicalized = _physicalize_transform_bwd(bwd, const_avals) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index deb1d92f853b..e002635ac35c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1418,6 +1418,8 @@ def _create_pjit_jaxpr( from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) + # TODO(mattjj,yashkatariya): if we take the 'true' path then we *must* fall + # off the C++ dispatch fast path for correctness. Ensure that happens. if any(isinstance(c, core.Tracer) or core.typeof(c).has_qdd for c in consts): closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) final_consts = consts @@ -2359,32 +2361,32 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ - pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) + known_jaxpr, unknown_jaxpr, unknown_outs, res_out_avals, in_fwd_res = \ + pe.partial_eval_jaxpr_nounits_fwd(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) # type: ignore[assignment] known_outs = tuple(not uk for uk in unknown_outs) - num_residuals = len(res_avals) - res_shardings = (UNSPECIFIED,) * num_residuals - res_layouts = (None,) * num_residuals + # out_shardings and out_layouts for residual values output by known_jaxpr def keep_where(l, should_keep): return tuple(x for x, keep in zip(l, should_keep) if keep) - known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings - known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts + known_out_shardings = (keep_where(out_shardings, known_outs) + + (UNSPECIFIED,) * len(res_out_avals)) + known_out_layouts = (keep_where(out_layouts, known_outs) + + (None,) * len(res_out_avals)) # Input-to-output forwarding: compute which outputs are just forwarded inputs. - num_out_primals = len(known_jaxpr.out_avals) - num_residuals + num_out_primals = len(known_jaxpr.out_avals) - len(res_out_avals) in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) - # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. - in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) + in_fwd_primal, in_fwd_res_ = split_list(in_fwd, [num_out_primals]) + assert all(f is None for f in in_fwd_res_) in_fwd = [ fwd if isinstance(os, UnspecifiedValue) and ol is None else None for os, ol, fwd in zip( keep_where(out_shardings, known_outs), keep_where(out_layouts, known_outs), in_fwd_primal) - ] + in_fwd_res - del in_fwd_primal, in_fwd_res + ] + in_fwd_res_ + del in_fwd_primal, in_fwd_res_ # Prune jaxpr outputs and out_shardings by removing the input-forwards. 
keep = [f is None for f in in_fwd] known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) @@ -2427,7 +2429,11 @@ def keep_where(l, should_keep): all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs) known_out_vals, residual_vals = \ - split_list(all_known_outs, [len(all_known_outs) - num_residuals]) + split_list(all_known_outs, [len(all_known_outs) - len(res_out_avals)]) + residual_vals_ = iter(residual_vals) + residual_vals = [next(residual_vals_) if f is None + else [*jaxpr.consts, *known_inputs][f] for f in in_fwd_res] + assert next(residual_vals_, None) is None residual_tracers = map(trace.new_instantiated_const, residual_vals) # The convention of partial_eval_jaxpr_nounits is to place residual binders at @@ -2435,16 +2441,22 @@ def keep_where(l, should_keep): # jaxpr equation built below and the pjit transpose rule assume a # residual-inputs-last convention. unknown_jaxpr = pe.move_binders_to_back( - unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins)) - # Prepare unknown tracers + unknown_jaxpr, [True] * len(residual_vals) + [False] * sum(unknown_ins)) + + # Set up staged-out 'unknown' eqn + unknown_in_shardings = (keep_where(in_shardings, unknown_ins) + + (UNSPECIFIED,) * len(residual_tracers)) + unknown_in_layouts = (keep_where(in_layouts, unknown_ins) + + (None,) * len(residual_tracers)) + unknown_donated_invars = (keep_where(donated_invars, unknown_ins) + + (False,) * len(residual_tracers)) unknown_params = dict( jaxpr=unknown_jaxpr, - in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings), + in_shardings=unknown_in_shardings, + in_layouts=unknown_in_layouts, out_shardings=keep_where(out_shardings, unknown_outs), - in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts), out_layouts=keep_where(out_layouts, unknown_outs), - donated_invars=(keep_where(donated_invars, unknown_ins) + - (False,) * num_residuals), + donated_invars=unknown_donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, @@ -2536,8 +2548,7 @@ def _pjit_transpose(cts_in, *primals_in, def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) - body = lu.wrap_init(ad.closed_backward_pass, - debug_info=jaxpr.jaxpr._debug_info) + body = lu.wrap_init(ad.closed_backward_pass, debug_info=jaxpr.jaxpr._debug_info) body = lu.hashable_partial(body, jaxpr, False) primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in)) body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef) diff --git a/tests/api_test.py b/tests/api_test.py index 584eb0eda496..74022d2207b5 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6104,17 +6104,17 @@ def f(x, y): res = saved_residuals(f, (2., 3.), y=4.) 
self.assertLen(res, 6) - self.assertEqual(res[0][0].shape, ()) - self.assertEqual(res[0][1], "from the argument x[0]") + self.assertEqual(res[0][0].shape, (1,)) + self.assertEqual(res[0][1], "from a constant") self.assertEqual(res[1][0].shape, ()) - self.assertEqual(res[1][1], "from the argument x[1]") + self.assertEqual(res[1][1], "from the argument x[0]") self.assertEqual(res[2][0].shape, ()) - self.assertEqual(res[2][1], "from the argument y") + self.assertEqual(res[2][1], "from the argument x[1]") self.assertEqual(res[3][0].shape, ()) - self.assertStartsWith(res[3][1], "output of jitted function 'f'") + self.assertEqual(res[3][1], "from the argument y") self.assertEqual(res[4][0].shape, ()) - self.assertEqual(res[5][0].shape, (1,)) - self.assertStartsWith(res[5][1], "output of jitted function 'f'") + self.assertStartsWith(res[4][1], "output of jitted function 'f'") + self.assertEqual(res[5][0].shape, ()) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 61b0129aca3e..595da899b5c2 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -2954,7 +2954,7 @@ def fun(x): return np.array([1.0])*x def fwd(x): - return np.array([2.0])*x*x/np.array([1.0]), (x,) + return np.array([2.0])*x*x/np.array([1.0]), (2 * x,) x = jnp.linspace(0, 5.0, 10) fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( @@ -2968,7 +2968,7 @@ def test_optimize_remat_vmap(self): def fun(x): return (np.array([1.0])*x)[0] def fwd(x): - return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) + return (np.array([2.0])*x*x/np.array([1.0]))[0], (2 * x,) x = jnp.linspace(0, 5.0, 10) fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), @@ -2980,7 +2980,7 @@ def test_optimize_remat_cond(self): def fun(x): return x def fwd(x): - return x*x, (x,) + return x*x, (2 * x,) x = jnp.linspace(0, 5.0, 10) fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( @@ -2997,7 +2997,7 @@ def test_optimize_remat_jvp(self): def fun(x): return x**2 def fwd_(x): - return x*x, (x,) + return x*x, (2 * x,) fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 5d0f747a85ad..41207f903b2a 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -965,8 +965,7 @@ def my_g(u, v): else: expected_jaxpr_debug_infos = [ "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", - # TODO(necula): result_paths - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", # TODO(necula): arg_names "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", ] @@ -1403,7 +1402,7 @@ def the_grad(c, as_): else: expected_jaxpr_debug_infos = [ "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", - "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", + "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=result[0],result[1]", "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=result[0],result[1]", diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index e350a242548a..f7dc493ab1fa 100644 --- 
a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -317,6 +317,53 @@ def body(c, _): jax.grad(f)(1.0) self.assertAllClose(x_ref[...], 12, check_dtypes=False) + @parameterized.parameters([False, True]) + def test_custom_vjp_grad_stats_plumbing(self, jit): + + @jax.custom_vjp + def gradient_history_calculator(x, ref): + del ref + return x + + def gradient_history_calculator_fwd(x, ref): + return x, ref + + def gradient_history_calculator_bwd(amax_history, grad_output): + amax_update = jnp.max(jnp.abs(grad_output)) + shifted = jnp.roll(amax_history[:], 1) + shifted = shifted.at[0].set(amax_update) + amax_history[:] = shifted + amax_from_history = jnp.max(amax_history[:]) + grad_output = grad_output / amax_from_history + return grad_output, None + + gradient_history_calculator.defvjp( + gradient_history_calculator_fwd, + gradient_history_calculator_bwd) + + class DotOp: + def __init__(self): + self.amax_history = core.mutable_array(jnp.zeros(5,)) + + def forward(self, x, y): + out = jnp.dot(x, y) + out = gradient_history_calculator(out, self.amax_history) + return out + + dot_op = DotOp() + x_top = jnp.ones((5,)) + y_top = jnp.ones((5,)) + + def loss(x, y): + return dot_op.forward(x, y).sum() + + if jit: + loss = jax.jit(loss) + + for i in range(3): + jax.grad(loss, (0,1))(x_top, y_top) + self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): @@ -418,9 +465,23 @@ def f(x, ref): if jit: f = jax.jit(f) x_ref = core.mutable_array(0.) + + jax.vjp(f, 3., x_ref) # returning input ref, okay + + @jax.custom_vjp + def g(x, ref): + return x + def g_fwd(x, _): + y_ref = core.mutable_array(0) + return x, y_ref + g.defvjp(g_fwd, lambda ref, g: g) + if jit: + g = jax.jit(g) + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( ValueError, "custom_vjp fwd function"): - jax.vjp(f, 3., x_ref) + jax.vjp(g, 3., x_ref) @parameterized.parameters([False, True]) def test_argument_aliases_custom_vjp_primal(self, jit): From 58faffd0d21513b3491aaa8e294db87ae5d4af7b Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Fri, 6 Jun 2025 21:02:58 -0400 Subject: [PATCH 1576/1769] Clarify argument order for lax.associative_scan when reverse=True. --- jax/_src/lax/control_flow/loops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 3fefab78f7f7..6be65485bc7a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2656,6 +2656,9 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``, the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``. + If ``elems = [..., x, y, z]`` and ``reverse`` is true, the result is + ``[..., f(f(z, y), x), f(z, y), z]``. + Example 1: partial sums of an array of numbers: >>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) From 1cb18ec2f1509e0df65a55caaccb242226c031d2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 7 Jun 2025 01:01:14 -0700 Subject: [PATCH 1577/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f71547227e74ce57cbb387018247462fbfacb4cc. 
PiperOrigin-RevId: 768400613 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3a6a0f5096b1..b12b6cdb4640 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "7312a4a7cdae69c292d56c1da6cd289ede4c797e" -XLA_SHA256 = "a951125a1e0c60b9e9b36b38ed37515609415817d98134e02cf41db9c7cf8db3" +XLA_COMMIT = "f71547227e74ce57cbb387018247462fbfacb4cc" +XLA_SHA256 = "f7d0507e47f45ffe899df96022bd85c7f883cedd0fd7ebf098857fe6369a3456" def repo(): tf_http_archive( From cb2b217982936e63b7170f0a79fc38fd28748608 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 7 Jun 2025 07:45:22 -0700 Subject: [PATCH 1578/1769] fix vestigial change that caused breakage PiperOrigin-RevId: 768485986 --- jax/_src/custom_derivatives.py | 4 ++-- jax/_src/pallas/fuser/fusible_dtype.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index af1db5ac9ebc..e2d1eb8a6097 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -992,12 +992,12 @@ def get_bind_params(self, params): fwd_jaxpr_thunk = new_params.pop('fwd_jaxpr_thunk') fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), debug_info=call_jaxpr.jaxpr.debug_info) - fwd = lift_fwd(num_consts, new_params['out_trees'], fwd_jaxpr_thunk) + fwd = lift_fwd(num_consts, fwd_jaxpr_thunk) const_avals, _ = split_list(call_jaxpr.in_avals, [num_consts]) bwd = _handle_consts_in_bwd(new_params.pop('bwd'), const_avals) return [fun, fwd, bwd], new_params -def lift_fwd(num_consts: int, out_trees: Callable, fwd_jaxpr_thunk: lu.WrappedFun) -> lu.WrappedFun: +def lift_fwd(num_consts: int, fwd_jaxpr_thunk: lu.WrappedFun) -> lu.WrappedFun: def fwd(*args): vals, nonzeros = args[::2], args[1::2] assert len(vals) == len(nonzeros) diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 53358231dcf8..152b20ff66ea 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -338,7 +338,7 @@ def _custom_vjp_call_physicalize_rule( new_jaxpr = physicalize_closed_jaxpr(call_jaxpr) fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr), debug_info=call_jaxpr.jaxpr.debug_info) - fwd = custom_derivatives.lift_fwd(num_consts, kwargs['out_trees'](), fwd_jaxpr_thunk) + fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk) fwd_physicalized = _physicalize_transform(fwd) const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts]) bwd_physicalized = _physicalize_transform_bwd(bwd, const_avals) From 0d1b1ef7c332133b336600c9aba7ef126560fbf9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 8 Jun 2025 00:24:47 -0700 Subject: [PATCH 1579/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d556913b23808bb93b13b576eb4b74e901fd52a5. 
PiperOrigin-RevId: 768699442 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b12b6cdb4640..ebc920cbd625 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f71547227e74ce57cbb387018247462fbfacb4cc" -XLA_SHA256 = "f7d0507e47f45ffe899df96022bd85c7f883cedd0fd7ebf098857fe6369a3456" +XLA_COMMIT = "d556913b23808bb93b13b576eb4b74e901fd52a5" +XLA_SHA256 = "8193f25819af97211a0cac09606fe46d28758105f16eb252b7c04adaac7306df" def repo(): tf_http_archive( From c2a569096c1a28c2b9348737172aee381ad0268e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 8 Jun 2025 08:46:52 -0700 Subject: [PATCH 1580/1769] Port pretty-printer to C++. This yields a significant speedup (3x) when printing large jaxprs. PiperOrigin-RevId: 768803336 --- build/requirements_lock_3_10.txt | 4 - build/requirements_lock_3_11.txt | 4 - build/requirements_lock_3_12.txt | 4 - build/requirements_lock_3_13.txt | 4 - build/requirements_lock_3_13_ft.txt | 4 - build/requirements_lock_3_14.txt | 2 - build/test-requirements.txt | 1 - jax/BUILD | 3 +- jax/_src/lib/__init__.py | 6 + jax/_src/pretty_printer.py | 785 +++++++++++++++------------- jaxlib/BUILD | 15 + jaxlib/_pretty_printer.cc | 755 ++++++++++++++++++++++++++ jaxlib/jax.bzl | 1 - jaxlib/nb_class_ptr.h | 2 +- jaxlib/tools/build_wheel.py | 1 + jaxlib/xla_client.py | 2 +- pyproject.toml | 1 + tests/pretty_printer_test.py | 76 ++- 18 files changed, 1265 insertions(+), 405 deletions(-) create mode 100644 jaxlib/_pretty_printer.cc diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 832c801ced63..b80980489fc0 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -24,10 +24,6 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index de3c35ed3c02..ecc5d85b2f2e 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -24,10 +24,6 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ diff --git a/build/requirements_lock_3_12.txt 
b/build/requirements_lock_3_12.txt index 04c6990da696..bce2a45a3984 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -24,10 +24,6 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 965cb3bc9672..3cc09776606f 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -24,10 +24,6 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt contourpy==1.3.0 \ --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7d111c3b3e9..efc6fcf45814 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -24,10 +24,6 @@ cloudpickle==3.1.0 \ --hash=sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b \ --hash=sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt contourpy==1.3.1 \ --hash=sha256:041b640d4ec01922083645a94bb3b2e777e6b626788f4095cf21abbe266413c1 \ --hash=sha256:05e806338bfeaa006acbdeba0ad681a10be63b26e1b17317bfac3c5d98f36cda \ diff --git a/build/requirements_lock_3_14.txt b/build/requirements_lock_3_14.txt index 157dca5adbab..6a91caa9bbad 100644 --- a/build/requirements_lock_3_14.txt +++ b/build/requirements_lock_3_14.txt @@ -10,8 +10,6 @@ build==1.2.2.post1 # via -r build/test-requirements.txt cloudpickle==3.1.1 # via -r build/test-requirements.txt -colorama==0.4.6 - # via -r build/test-requirements.txt contourpy==1.3.2 # via matplotlib cycler==0.12.1 diff --git a/build/test-requirements.txt b/build/test-requirements.txt index ef23b10ddf88..50311faebde6 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,6 +1,5 @@ absl-py cloudpickle -colorama>=0.4.4 filelock flatbuffers hypothesis diff --git a/jax/BUILD b/jax/BUILD index b54c4afce18e..588ce697b711 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1155,7 +1155,8 @@ pytype_strict_library( deps = [ ":config", ":util", - ] + py_deps("colorama"), + "//jax/_src/lib", + ], ) pytype_strict_library( diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 8de05061ec99..b1f39a5ca93d 100644 --- 
a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -108,6 +108,12 @@ def _parse_version(v: str) -> tuple[int, ...]: import jaxlib.weakref_lru_cache as weakref_lru_cache # noqa: F401 +if jaxlib_extension_version >= 350: + import jaxlib._pretty_printer as _pretty_printer # noqa: F401 +else: + _pretty_printer = None + + # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): xla_client._xla.collect_garbage() diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index 56a99161fb63..d2850c814bb6 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -31,15 +31,11 @@ import enum from functools import partial import sys -from typing import Any, NamedTuple +from typing import Any, NamedTuple, TYPE_CHECKING from jax._src import config from jax._src import util - -try: - import colorama # pytype: disable=import-error -except ImportError: - colorama = None +from jax._src.lib import _pretty_printer as _pretty_printer _PPRINT_USE_COLOR = config.bool_state( @@ -66,423 +62,464 @@ def _can_use_color() -> bool: CAN_USE_COLOR = _can_use_color() -class Doc(util.StrictABC): - __slots__ = () +# TODO(phawkins): remove this condition after the jaxlib 0.6.3 release. +if TYPE_CHECKING or _pretty_printer is None: + try: + import colorama # pytype: disable=import-error + except ImportError: + colorama = None + + class Doc(util.StrictABC): + __slots__ = () + + def format( + self, width: int = 80, *, use_color: bool | None = None, + annotation_prefix: str = " # ", + source_map: list[list[tuple[int, int, Any]]] | None = None + ) -> str: + """ + Formats a pretty-printer document as a string. + + Args: + source_map: for each line in the output, contains a list of + (start column, end column, source) tuples. Each tuple associates a + region of output text with a source. + """ + if use_color is None: + use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value + return _format(self, width, use_color=use_color, + annotation_prefix=annotation_prefix, source_map=source_map) + + def __str__(self): + return self.format() + + def __add__(self, other: Doc) -> Doc: + return concat([self, other]) + + def num_annotations(self) -> int: + raise NotImplementedError() + + class _NilDoc(Doc): + def __repr__(self): return "nil" + + def num_annotations(self) -> int: + return 0 + + _nil = _NilDoc() + + class _TextDoc(Doc): + __slots__ = ("text", "annotation") + text: str + annotation: str | None + + def __init__(self, text: str, annotation: str | None = None): + assert isinstance(text, str), text + assert annotation is None or isinstance(annotation, str), annotation + self.text = text + self.annotation = annotation + + def __repr__(self): + if self.annotation is not None: + return f"text(\"{self.text}\", annotation=\"{self.annotation}\")" + else: + return f"text(\"{self.text}\")" - def format( - self, width: int = 80, *, use_color: bool | None = None, - annotation_prefix: str = " # ", - source_map: list[list[tuple[int, int, Any]]] | None = None - ) -> str: - """ - Formats a pretty-printer document as a string. + def num_annotations(self) -> int: + return 1 if self.annotation is not None else 0 - Args: - source_map: for each line in the output, contains a list of - (start column, end column, source) tuples. Each tuple associates a - region of output text with a source. 
- """ - if use_color is None: - use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value - return _format(self, width, use_color=use_color, - annotation_prefix=annotation_prefix, source_map=source_map) + class _ConcatDoc(Doc): + __slots__ = ("children", "_num_annotations") + children: list[Doc] + _num_annotations: int - def __str__(self): - return self.format() + def __init__(self, children: Sequence[Doc]): + self.children = list(children) + self._num_annotations = sum(child.num_annotations() for child in children) - def __add__(self, other: Doc) -> Doc: - return concat([self, other]) + def __repr__(self): return f"concat({self.children})" - def num_annotations(self) -> int: - raise NotImplementedError() + def num_annotations(self) -> int: + return self._num_annotations -class _NilDoc(Doc): - def __repr__(self): return "nil" + class _BreakDoc(Doc): + __slots__ = ("text",) + text: str - def num_annotations(self) -> int: - return 0 + def __init__(self, text: str): + assert isinstance(text, str), text + self.text = text -_nil = _NilDoc() + def __repr__(self): return f"break({self.text})" -class _TextDoc(Doc): - __slots__ = ("text", "annotation") - text: str - annotation: str | None + def num_annotations(self) -> int: + return 0 - def __init__(self, text: str, annotation: str | None = None): - assert isinstance(text, str), text - assert annotation is None or isinstance(annotation, str), annotation - self.text = text - self.annotation = annotation + class _GroupDoc(Doc): + __slots__ = ("child",) + child: Doc - def __repr__(self): - if self.annotation is not None: - return f"text(\"{self.text}\", annotation=\"{self.annotation}\")" - else: - return f"text(\"{self.text}\")" + def __init__(self, child: Doc): + assert isinstance(child, Doc), child + self.child = child - def num_annotations(self) -> int: - return 1 if self.annotation is not None else 0 + def __repr__(self): return f"group({self.child})" + + def num_annotations(self) -> int: + return self.child.num_annotations() -class _ConcatDoc(Doc): - __slots__ = ("children", "_num_annotations") - children: list[Doc] - _num_annotations: int + class _NestDoc(Doc): + __slots__ = ("n", "child",) + n: int + child: Doc - def __init__(self, children: Sequence[Doc]): - self.children = list(children) - self._num_annotations = sum(child.num_annotations() for child in children) + def __init__(self, n: int, child: Doc): + assert isinstance(child, Doc), child + self.n = n + self.child = child - def __repr__(self): return f"concat({self.children})" + def __repr__(self): return f"nest({self.n, self.child})" - def num_annotations(self) -> int: - return self._num_annotations + def num_annotations(self) -> int: + return self.child.num_annotations() -class _BreakDoc(Doc): - __slots__ = ("text",) - text: str + _NO_SOURCE = object() - def __init__(self, text: str): - assert isinstance(text, str), text - self.text = text + class _SourceMapDoc(Doc): + __slots__ = ("child", "source") + child: Doc + source: Any - def __repr__(self): return f"break({self.text})" + def __init__(self, child: Doc, source: Any): + assert isinstance(child, Doc), child + self.child = child + self.source = source - def num_annotations(self) -> int: - return 0 + def __repr__(self): return f"source({self.child}, {self.source})" -class _GroupDoc(Doc): - __slots__ = ("child",) - child: Doc + def num_annotations(self) -> int: + return self.child.num_annotations() - def __init__(self, child: Doc): - assert isinstance(child, Doc), child - self.child = child + Color = enum.Enum("Color", ["BLACK", "RED", 
"GREEN", "YELLOW", "BLUE", + "MAGENTA", "CYAN", "WHITE", "RESET"]) + Intensity = enum.Enum("Intensity", ["DIM", "NORMAL", "BRIGHT"]) - def __repr__(self): return f"group({self.child})" + class _ColorDoc(Doc): + __slots__ = ("foreground", "background", "intensity", "child") + foreground: Color | None + background: Color | None + intensity: Intensity | None + child: Doc - def num_annotations(self) -> int: - return self.child.num_annotations() + def __init__(self, child: Doc, *, foreground: Color | None = None, + background: Color | None = None, + intensity: Intensity | None = None): + assert isinstance(child, Doc), child + self.child = child + self.foreground = foreground + self.background = background + self.intensity = intensity -class _NestDoc(Doc): - __slots__ = ("n", "child",) - n: int - child: Doc + def num_annotations(self) -> int: + return self.child.num_annotations() - def __init__(self, n: int, child: Doc): - assert isinstance(child, Doc), child - self.n = n - self.child = child + _BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"]) + + + # In Lindig's paper fits() and format() are defined recursively. This is a + # non-recursive formulation using an explicit stack, necessary because Python + # doesn't have a tail recursion optimization. + + def _fits(doc: Doc, width: int) -> bool: + agenda = [doc] + while width >= 0 and len(agenda) > 0: + doc = agenda.pop() + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): + width -= len(doc.text) + elif isinstance(doc, _ConcatDoc): + agenda.extend(reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + width -= len(doc.text) + elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)): + agenda.append(doc.child) + else: + raise ValueError("Invalid document ", doc) - def __repr__(self): return f"nest({self.n, self.child})" + return width >= 0 - def num_annotations(self) -> int: - return self.child.num_annotations() -_NO_SOURCE = object() + # Annotation layout: A flat group is sparse if there are no breaks between + # annotations. 
+ def _sparse(doc: Doc) -> bool: + agenda = [doc] + if doc.num_annotations() == 0: + return True + num_annotations = 0 + seen_break = False + while len(agenda) > 0: + doc = agenda.pop() + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): + if doc.annotation is not None: + if num_annotations >= 1 and seen_break: + return False + num_annotations += 1 + elif isinstance(doc, _ConcatDoc): + agenda.extend(reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + seen_break = True + elif isinstance(doc, _NestDoc): + agenda.append(doc.child) + elif isinstance(doc, _GroupDoc): + agenda.append(doc.child) + elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): + agenda.append(doc.child) + else: + raise ValueError("Invalid document ", doc) -class _SourceMapDoc(Doc): - __slots__ = ("child", "source") - child: Doc - source: Any + return True - def __init__(self, child: Doc, source: Any): - assert isinstance(child, Doc), child - self.child = child - self.source = source + class _ColorState(NamedTuple): + foreground: Color + background: Color + intensity: Intensity + + class _State(NamedTuple): + indent: int + mode: _BreakMode + doc: Doc + color: _ColorState + source_map: Any + + class _Line(NamedTuple): + text: str + width: int + annotations: list[str] + + + def _update_color(use_color: bool, state: _ColorState, update: _ColorState + ) -> tuple[_ColorState, str]: + if not use_color or colorama is None: + return update, "" + color_str = "" + if state.foreground != update.foreground: + color_str += getattr(colorama.Fore, str(update.foreground.name)) + if state.background != update.background: + color_str += getattr(colorama.Back, str(update.background.name)) + if state.intensity != update.intensity: + color_str += colorama.Style.NORMAL # pytype: disable=unsupported-operands + color_str += getattr(colorama.Style, str(update.intensity.name)) + return update, color_str + + + def _align_annotations(lines: list[_Line], annotation_prefix: str) -> list[str]: + # TODO: Hafiz also implements a local alignment mode, where groups of lines + # with annotations are aligned together. + maxlen = max(l.width for l in lines) + out = [] + for l in lines: + if len(l.annotations) == 0: + out.append(l.text) + else: + out.append(f"{l.text}{' ' * (maxlen - l.width)}" + f"{annotation_prefix}{l.annotations[0]}") + for a in l.annotations[1:]: + out.append(f"{' ' * maxlen}{annotation_prefix}{a}") + return out - def __repr__(self): return f"source({self.child}, {self.source})" - def num_annotations(self) -> int: - return self.child.num_annotations() -Color = enum.Enum("Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", - "MAGENTA", "CYAN", "WHITE", "RESET"]) -Intensity = enum.Enum("Intensity", ["DIM", "NORMAL", "BRIGHT"]) - -class _ColorDoc(Doc): - __slots__ = ("foreground", "background", "intensity", "child") - foreground: Color | None - background: Color | None - intensity: Intensity | None - child: Doc - - def __init__(self, child: Doc, *, foreground: Color | None = None, - background: Color | None = None, - intensity: Intensity | None = None): - assert isinstance(child, Doc), child - self.child = child - self.foreground = foreground - self.background = background - self.intensity = intensity - - def num_annotations(self) -> int: - return self.child.num_annotations() - -_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"]) - - -# In Lindig's paper fits() and format() are defined recursively. 
This is a -# non-recursive formulation using an explicit stack, necessary because Python -# doesn't have a tail recursion optimization. - -def _fits(doc: Doc, width: int) -> bool: - agenda = [doc] - while width >= 0 and len(agenda) > 0: - doc = agenda.pop() - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - width -= len(doc.text) - elif isinstance(doc, _ConcatDoc): - agenda.extend(reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - width -= len(doc.text) - elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)): - agenda.append(doc.child) - else: - raise ValueError("Invalid document ", doc) - - return width >= 0 - - -# Annotation layout: A flat group is sparse if there are no breaks between -# annotations. -def _sparse(doc: Doc) -> bool: - agenda = [doc] - if doc.num_annotations() == 0: - return True - num_annotations = 0 - seen_break = False - while len(agenda) > 0: - doc = agenda.pop() - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - if doc.annotation is not None: - if num_annotations >= 1 and seen_break: - return False - num_annotations += 1 - elif isinstance(doc, _ConcatDoc): - agenda.extend(reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - seen_break = True - elif isinstance(doc, _NestDoc): - agenda.append(doc.child) - elif isinstance(doc, _GroupDoc): - agenda.append(doc.child) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append(doc.child) - else: - raise ValueError("Invalid document ", doc) - - return True - -class _ColorState(NamedTuple): - foreground: Color - background: Color - intensity: Intensity - -class _State(NamedTuple): - indent: int - mode: _BreakMode - doc: Doc - color: _ColorState - source_map: Any - -class _Line(NamedTuple): - text: str - width: int - annotations: list[str] - - -def _update_color(use_color: bool, state: _ColorState, update: _ColorState - ) -> tuple[_ColorState, str]: - if not use_color or colorama is None: - return update, "" - color_str = "" - if state.foreground != update.foreground: - color_str += getattr(colorama.Fore, str(update.foreground.name)) - if state.background != update.background: - color_str += getattr(colorama.Back, str(update.background.name)) - if state.intensity != update.intensity: - color_str += colorama.Style.NORMAL # pytype: disable=unsupported-operands - color_str += getattr(colorama.Style, str(update.intensity.name)) - return update, color_str - - -def _align_annotations(lines: list[_Line], annotation_prefix: str) -> list[str]: - # TODO: Hafiz also implements a local alignment mode, where groups of lines - # with annotations are aligned together. - maxlen = max(l.width for l in lines) - out = [] - for l in lines: - if len(l.annotations) == 0: - out.append(l.text) - else: - out.append(f"{l.text}{' ' * (maxlen - l.width)}" - f"{annotation_prefix}{l.annotations[0]}") - for a in l.annotations[1:]: - out.append(f"{' ' * maxlen}{annotation_prefix}{a}") - return out - - - -def _format( - doc: Doc, width: int, *, use_color: bool, annotation_prefix: str, - source_map: list[list[tuple[int, int, Any]]] | None -) -> str: - lines = [] - default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) - annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) - color_state = default_colors - source_start = 0 # The column at which the current source region starts. - source = _NO_SOURCE # The currently active source region. - line_source_map = [] # Source maps for the current line of text. 
- agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)] - k = 0 - line_text = "" - line_annotations = [] - while len(agenda) > 0: - i, m, doc, color, agenda_source = agenda.pop() - if source_map is not None and agenda_source != source: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source = agenda_source - source_start = pos - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - color_state, color_str = _update_color(use_color, color_state, color) - line_text += color_str - line_text += doc.text - if doc.annotation is not None: - line_annotations.append(doc.annotation) - k += len(doc.text) - elif isinstance(doc, _ConcatDoc): - agenda.extend(_State(i, m, d, color, source) - for d in reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - if len(line_annotations) > 0: - color_state, color_str = _update_color(use_color, color_state, - annotation_colors) - line_text += color_str - lines.append(_Line(line_text, k, line_annotations)) - if source_map is not None: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source_map.append(line_source_map) - line_source_map = [] - source_start = i - line_text = " " * i - line_annotations = [] - k = i - else: + def _format( + doc: Doc, width: int, *, use_color: bool, annotation_prefix: str, + source_map: list[list[tuple[int, int, Any]]] | None + ) -> str: + lines = [] + default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) + annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) + color_state = default_colors + source_start = 0 # The column at which the current source region starts. + source = _NO_SOURCE # The currently active source region. + line_source_map = [] # Source maps for the current line of text. + agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)] + k = 0 + line_text = "" + line_annotations = [] + while len(agenda) > 0: + i, m, doc, color, agenda_source = agenda.pop() + if source_map is not None and agenda_source != source: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source = agenda_source + source_start = pos + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): color_state, color_str = _update_color(use_color, color_state, color) line_text += color_str line_text += doc.text + if doc.annotation is not None: + line_annotations.append(doc.annotation) k += len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append(_State(i + doc.n, m, doc.child, color, source)) - elif isinstance(doc, _GroupDoc): - # In Lindig's paper, _fits is passed the remainder of the document. - # I'm pretty sure that's a bug and we care only if the current group fits! 
- if (_fits(doc, width - k) and _sparse(doc)): - agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) + elif isinstance(doc, _ConcatDoc): + agenda.extend(_State(i, m, d, color, source) + for d in reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + if m == _BreakMode.BREAK: + if len(line_annotations) > 0: + color_state, color_str = _update_color(use_color, color_state, + annotation_colors) + line_text += color_str + lines.append(_Line(line_text, k, line_annotations)) + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) + line_source_map = [] + source_start = i + line_text = " " * i + line_annotations = [] + k = i + else: + color_state, color_str = _update_color(use_color, color_state, color) + line_text += color_str + line_text += doc.text + k += len(doc.text) + elif isinstance(doc, _NestDoc): + agenda.append(_State(i + doc.n, m, doc.child, color, source)) + elif isinstance(doc, _GroupDoc): + # In Lindig's paper, _fits is passed the remainder of the document. + # I'm pretty sure that's a bug and we care only if the current group fits! + if (_fits(doc, width - k) and _sparse(doc)): + agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) + else: + agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) + elif isinstance(doc, _ColorDoc): + color = _ColorState(doc.foreground or color.foreground, + doc.background or color.background, + doc.intensity or color.intensity) + agenda.append(_State(i, m, doc.child, color, source)) + elif isinstance(doc, _SourceMapDoc): + agenda.append(_State(i, m, doc.child, color, doc.source)) else: - agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) - elif isinstance(doc, _ColorDoc): - color = _ColorState(doc.foreground or color.foreground, - doc.background or color.background, - doc.intensity or color.intensity) - agenda.append(_State(i, m, doc.child, color, source)) - elif isinstance(doc, _SourceMapDoc): - agenda.append(_State(i, m, doc.child, color, doc.source)) - else: - raise ValueError("Invalid document ", doc) - - if len(line_annotations) > 0: - color_state, color_str = _update_color(use_color, color_state, - annotation_colors) - line_text += color_str - if source_map is not None: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source_map.append(line_source_map) - lines.append(_Line(line_text, k, line_annotations)) - out = "\n".join(_align_annotations(lines, annotation_prefix)) - _, color_str = _update_color(use_color, color_state, - default_colors) - return out + color_str - - - - -# Public API. 
- -def nil() -> Doc: - """An empty document.""" - return _nil - -def text(s: str, annotation: str | None = None) -> Doc: - """Literal text.""" - return _TextDoc(s, annotation) - -def concat(docs: Sequence[Doc]) -> Doc: - """Concatenation of documents.""" - docs = list(docs) - if len(docs) == 1: - return docs[0] - return _ConcatDoc(docs) + raise ValueError("Invalid document ", doc) + + if len(line_annotations) > 0: + color_state, color_str = _update_color(use_color, color_state, + annotation_colors) + line_text += color_str + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) + lines.append(_Line(line_text, k, line_annotations)) + out = "\n".join(_align_annotations(lines, annotation_prefix)) + _, color_str = _update_color(use_color, color_state, + default_colors) + return out + color_str + + + + + # Public API. + + def nil() -> Doc: + """An empty document.""" + return _nil + + def text(s: str, annotation: str | None = None) -> Doc: + """Literal text.""" + return _TextDoc(s, annotation) + + def concat(docs: Sequence[Doc]) -> Doc: + """Concatenation of documents.""" + docs = list(docs) + if len(docs) == 1: + return docs[0] + return _ConcatDoc(docs) + + def brk(text: str = " ") -> Doc: + """A break. + + Prints either as a newline or as `text`, depending on the enclosing group. + """ + return _BreakDoc(text) + + def group(doc: Doc) -> Doc: + """Layout alternative groups. -def brk(text: str = " ") -> Doc: - """A break. + Prints the group with its breaks as their text (typically spaces) if the + entire group would fit on the line when printed that way. Otherwise, breaks + inside the group as printed as newlines. + """ + return _GroupDoc(doc) - Prints either as a newline or as `text`, depending on the enclosing group. - """ - return _BreakDoc(text) + def nest(n: int, doc: Doc) -> Doc: + """Increases the indentation level by `n`.""" + return _NestDoc(n, doc) -def group(doc: Doc) -> Doc: - """Layout alternative groups. - Prints the group with its breaks as their text (typically spaces) if the - entire group would fit on the line when printed that way. Otherwise, breaks - inside the group as printed as newlines. - """ - return _GroupDoc(doc) + def color(doc: Doc, *, foreground: Color | None = None, + background: Color | None = None, + intensity: Intensity | None = None): + """ANSI colors. -def nest(n: int, doc: Doc) -> Doc: - """Increases the indentation level by `n`.""" - return _NestDoc(n, doc) + Overrides the foreground/background/intensity of the text for the child doc. + Requires use_colors=True to be set when printing and the `colorama` package + to be installed; otherwise does nothing. + """ + return _ColorDoc(doc, foreground=foreground, background=background, + intensity=intensity) -def color(doc: Doc, *, foreground: Color | None = None, - background: Color | None = None, - intensity: Intensity | None = None): - """ANSI colors. + def source_map(doc: Doc, source: Any): + """Source mapping. - Overrides the foreground/background/intensity of the text for the child doc. - Requires use_colors=True to be set when printing and the `colorama` package - to be installed; otherwise does nothing. - """ - return _ColorDoc(doc, foreground=foreground, background=background, - intensity=intensity) + A source map associates a region of the pretty-printer's text output with a + source location that produced it. 
For the purposes of the pretty printer a + ``source`` may be any object: we require only that we can compare sources for + equality. A text region to source object mapping can be populated as a side + output of the ``format`` method. + """ + return _SourceMapDoc(doc, source) +else: + Color = _pretty_printer.Color + Intensity = _pretty_printer.Intensity + Doc = _pretty_printer.Doc + def _format( + self, width: int = 80, *, use_color: bool | None = None, + annotation_prefix: str = " # ", + source_map: list[list[tuple[int, int, Any]]] | None = None + ) -> str: + """ + Formats a pretty-printer document as a string. -def source_map(doc: Doc, source: Any): - """Source mapping. + Args: + source_map: for each line in the output, contains a list of + (start column, end column, source) tuples. Each tuple associates a + region of output text with a source. + """ + if use_color is None: + use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value + return self._format( + width, use_color=use_color, annotation_prefix=annotation_prefix, + source_map=source_map) + Doc.format = _format + Doc.__str__ = lambda self: self.format() + nil = _pretty_printer.nil + text = _pretty_printer.text + concat = _pretty_printer.concat + brk = _pretty_printer.brk + group = _pretty_printer.group + nest = _pretty_printer.nest + color = _pretty_printer.color + source_map = _pretty_printer.source_map - A source map associates a region of the pretty-printer's text output with a - source location that produced it. For the purposes of the pretty printer a - ``source`` may be any object: we require only that we can compare sources for - equality. A text region to source object mapping can be populated as a side - output of the ``format`` method. - """ - return _SourceMapDoc(doc, source) type_annotation = partial(color, intensity=Intensity.NORMAL, foreground=Color.MAGENTA) @@ -494,6 +531,8 @@ def join(sep: Doc, docs: Sequence[Doc]) -> Doc: docs = list(docs) if len(docs) == 0: return nil() + if len(docs) == 1: + return docs[0] xs = [docs[0]] for doc in docs[1:]: xs.append(sep) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index da5fb4952743..14a84bf1f19f 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -53,6 +53,7 @@ pytype_strict_library( data = [":ffi_headers"], deps = [ ":_jax", + ":_pretty_printer", ":cpu_feature_guard", ":jax", ":jaxlib_files", @@ -150,6 +151,7 @@ pywrap_library( }, deps = [ ":_jax", + ":_pretty_printer", ":utils", ":weakref_lru_cache", "//jaxlib/mlir/_mlir_libs:_chlo", @@ -248,6 +250,19 @@ nanobind_extension( ], ) +nanobind_pywrap_extension( + name = "_pretty_printer", + srcs = ["_pretty_printer.cc"], + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + ], +) + nanobind_pywrap_extension( name = "weakref_lru_cache", srcs = ["weakref_lru_cache.cc"], diff --git a/jaxlib/_pretty_printer.cc b/jaxlib/_pretty_printer.cc new file mode 100644 index 000000000000..1bf6f8d2f541 --- /dev/null +++ b/jaxlib/_pretty_printer.cc @@ -0,0 +1,755 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <optional>
+#include <stack>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/optional.h"  // IWYU pragma: keep
+#include "nanobind/stl/string.h"  // IWYU pragma: keep
+#include "nanobind/stl/vector.h"  // IWYU pragma: keep
+#include "jaxlib/nb_class_ptr.h"
+
+namespace nb = nanobind;
+
+namespace jax {
+
+enum class Color {
+  kBlack = 30,
+  kRed = 31,
+  kGreen = 32,
+  kYellow = 33,
+  kBlue = 34,
+  kMagenta = 35,
+  kCyan = 36,
+  kWhite = 37,
+  kReset = 39,
+};
+
+std::string ColorToString(Color color) {
+  switch (color) {
+    case Color::kBlack:
+      return "black";
+    case Color::kRed:
+      return "red";
+    case Color::kGreen:
+      return "green";
+    case Color::kYellow:
+      return "yellow";
+    case Color::kBlue:
+      return "blue";
+    case Color::kMagenta:
+      return "magenta";
+    case Color::kCyan:
+      return "cyan";
+    case Color::kWhite:
+      return "white";
+    case Color::kReset:
+      return "reset";
+  }
+}
+
+enum class Intensity {
+  kNormal = 22,
+  kDim = 2,
+  kBright = 1,
+};
+
+std::string IntensityToString(Intensity intensity) {
+  switch (intensity) {
+    case Intensity::kNormal:
+      return "normal";
+    case Intensity::kDim:
+      return "dim";
+    case Intensity::kBright:
+      return "bright";
+  }
+}
+
+struct FormatState;
+struct FormatAgendum;
+
+class Doc {
+ public:
+  Doc(int num_annotations) : num_annotations_(num_annotations) {}
+  virtual ~Doc() = default;
+  virtual std::string Repr() const = 0;
+
+  int num_annotations() const { return num_annotations_; }
+
+  virtual void Fits(std::stack<const Doc*>& agenda, int& width) const = 0;
+
+  // Returns true if the doc may be sparse, i.e. there are no breaks between
+  // annotations. Returns false if the doc is known not to be sparse.
+  virtual bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                      bool& seen_break) const = 0;
+
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const = 0;
+
+ private:
+  int num_annotations_;
+};
+
+class NilDoc final : public Doc {
+ public:
+  NilDoc() : Doc(/*num_annotations=*/0) {}
+  std::string Repr() const override;
+
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+};
+
+class TextDoc final : public Doc {
+ public:
+  TextDoc(std::string text, std::optional<std::string> annotation)
+      : Doc(annotation.has_value() ? 1 : 0),
+        text_(std::move(text)),
+        annotation_(std::move(annotation)) {}
+  std::string Repr() const override;
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+
+ private:
+  std::string text_;
+  std::optional<std::string> annotation_;
+};
+
+class ConcatDoc final : public Doc {
+ public:
+  explicit ConcatDoc(std::vector<xla::nb_class_ptr<Doc>> children)
+      : Doc(TotalNumAnnotations(children)), children_(std::move(children)) {}
+  std::string Repr() const override;
+
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+
+ private:
+  static int TotalNumAnnotations(
+      absl::Span<const xla::nb_class_ptr<Doc>> children) {
+    int total = 0;
+    for (const auto& child : children) {
+      total += child->num_annotations();
+    }
+    return total;
+  }
+  std::vector<xla::nb_class_ptr<Doc>> children_;
+};
+
+class BreakDoc final : public Doc {
+ public:
+  explicit BreakDoc(std::string text)
+      : Doc(/*num_annotations=*/0), text_(std::move(text)) {}
+  std::string Repr() const override;
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+
+ private:
+  std::string text_;
+};
+
+class GroupDoc final : public Doc {
+ public:
+  explicit GroupDoc(xla::nb_class_ptr<Doc> child)
+      : Doc(/*num_annotations=*/child->num_annotations()),
+        child_(std::move(child)) {}
+  std::string Repr() const override;
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+
+ private:
+  xla::nb_class_ptr<Doc> child_;
+};
+
+class NestDoc final : public Doc {
+ public:
+  explicit NestDoc(int n, xla::nb_class_ptr<Doc> child)
+      : Doc(child->num_annotations()), n_(n), child_(std::move(child)) {}
+  std::string Repr() const override;
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+
+ private:
+  int n_;
+  xla::nb_class_ptr<Doc> child_;
+};
+
+class SourceMapDoc final : public Doc {
+ public:
+  explicit SourceMapDoc(xla::nb_class_ptr<Doc> child, nb::object source)
+      : Doc(child->num_annotations()),
+        child_(std::move(child)),
+        source_(std::move(source)) {}
+  std::string Repr() const override;
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+
+ private:
+  xla::nb_class_ptr<Doc> child_;
+  nb::object source_;
+};
+
+class ColorDoc final : public Doc {
+ public:
+  explicit ColorDoc(xla::nb_class_ptr<Doc> child,
+                    std::optional<Color> foreground,
+                    std::optional<Color> background,
+                    std::optional<Intensity> intensity)
+      : Doc(child->num_annotations()),
+        child_(std::move(child)),
+        foreground_(foreground),
+        background_(background),
+        intensity_(intensity) {}
+
+  std::string Repr() const override;
+  void Fits(std::stack<const Doc*>& agenda, int& width) const override;
+  bool Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+              bool& seen_break) const override;
+  virtual void Format(const FormatAgendum& agendum,
+                      FormatState& state) const override;
+
+ private:
+  xla::nb_class_ptr<Doc> child_;
+  std::optional<Color> foreground_;
+  std::optional<Color> background_;
+  std::optional<Intensity> intensity_;
+};
+
+std::string NilDoc::Repr() const { return "nil"; }
+
+std::string TextDoc::Repr() const {
+  if (annotation_.has_value()) {
+    return absl::StrFormat("text(\"%s\", annotation=\"%s\")", text_,
+                           *annotation_);
+  } else {
+    return absl::StrFormat("text(\"%s\")", text_);
+  }
+}
+
+std::string ConcatDoc::Repr() const {
+  return absl::StrFormat(
+      "concat(%s)",
+      absl::StrJoin(children_, ", ", [](std::string* out, const auto& child) {
+        absl::StrAppend(out, child->Repr());
+      }));
+}
+
+std::string BreakDoc::Repr() const {
+  return absl::StrFormat("break(\"%s\")", text_);
+}
+
+std::string GroupDoc::Repr() const {
+  return absl::StrFormat("group(%s)", child_->Repr());
+}
+
+std::string NestDoc::Repr() const {
+  return absl::StrFormat("nest(%d, %s)", n_, child_->Repr());
+}
+
+std::string SourceMapDoc::Repr() const {
+  return absl::StrFormat("source(%s, %s)", child_->Repr(),
+                         nb::cast<std::string>(nb::repr(source_)));
+}
+
+std::string ColorDoc::Repr() const {
+  std::string foreground_str =
+      foreground_.has_value() ? ColorToString(*foreground_) : "None";
+  std::string background_str =
+      background_.has_value() ? ColorToString(*background_) : "None";
+  std::string intensity_str =
+      intensity_.has_value() ? IntensityToString(*intensity_) : "None";
+  return absl::StrFormat("color(%s, %s, %s, %s)", child_->Repr(),
+                         foreground_str, background_str, intensity_str);
+}
+
+// Fits method implementations
+
+void NilDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {}
+
+void TextDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {
+  width -= text_.size();
+}
+
+void ConcatDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {
+  for (auto it = children_.rbegin(); it != children_.rend(); ++it) {
+    agenda.push(it->get());
+  }
+}
+
+void BreakDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {
+  width -= static_cast<int>(text_.size());
+}
+
+void GroupDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {
+  agenda.push(child_.get());
+}
+
+void NestDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {
+  agenda.push(child_.get());
+}
+
+void SourceMapDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {
+  agenda.push(child_.get());
+}
+
+void ColorDoc::Fits(std::stack<const Doc*>& agenda, int& width) const {
+  agenda.push(child_.get());
+}
+
+bool Fits(const Doc* doc, int width) {
+  std::stack<const Doc*> agenda;
+  agenda.push(doc);
+  while (width >= 0 && !agenda.empty()) {
+    const Doc* doc = agenda.top();
+    agenda.pop();
+    doc->Fits(agenda, width);
+  }
+  return width >= 0;
+}
+
+// Sparse method implementations
+
+bool NilDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                    bool& seen_break) const {
+  return true;
+}
+
+bool TextDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                     bool& seen_break) const {
+  if (annotation_.has_value()) {
+    if (num_annotations >= 1 && seen_break) {
+      return false;
+    }
+    num_annotations += 1;
+  }
+  return true;
+}
+
+bool ConcatDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                       bool& seen_break) const {
+  for (auto it = children_.rbegin(); it != children_.rend(); ++it) {
+    agenda.push(it->get());
+  }
+  return true;
+}
+
+bool BreakDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                      bool& seen_break) const {
+  seen_break = true;
+  return true;
+}
+
+bool GroupDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                      bool& seen_break) const {
+  agenda.push(child_.get());
+  return true;
+}
+
+bool NestDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                     bool& seen_break) const {
+  agenda.push(child_.get());
+  return true;
+}
+
+bool SourceMapDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                          bool& seen_break) const {
+  agenda.push(child_.get());
+  return true;
+}
+
+bool ColorDoc::Sparse(std::stack<const Doc*>& agenda, int& num_annotations,
+                      bool& seen_break) const {
+  agenda.push(child_.get());
+  return true;
+}
+
+// Returns true if the doc is sparse, i.e. there are no breaks between
+// annotations.
+bool Sparse(const Doc* doc) {
+  if (doc->num_annotations() == 0) {
+    return true;
+  }
+  std::stack<const Doc*> agenda;
+  agenda.push(doc);
+  int num_annotations = 0;
+  bool seen_break = false;
+  while (!agenda.empty()) {
+    const Doc* doc = agenda.top();
+    agenda.pop();
+    if (!doc->Sparse(agenda, num_annotations, seen_break)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+struct ColorState {
+  Color foreground;
+  Color background;
+  Intensity intensity;
+
+  bool operator==(const ColorState& other) const {
+    return foreground == other.foreground && background == other.background &&
+           intensity == other.intensity;
+  }
+  bool operator!=(const ColorState& other) const { return !operator==(other); }
+};
+
+constexpr ColorState kDefaultColors =
+    ColorState{Color::kReset, Color::kReset, Intensity::kNormal};
+constexpr ColorState kAnnotationColors =
+    ColorState{Color::kReset, Color::kReset, Intensity::kDim};
+
+enum class BreakMode { kFlat, kBreak };
+
+struct FormatAgendum {
+  int indent;
+  BreakMode mode;
+  const Doc* doc;
+  ColorState color;
+  nb::object source;
+};
+
+struct Line {
+  std::string text;
+  int width;
+  std::vector<std::string> annotations;
+};
+
+// Format method implementations
+
+struct FormatState {
+  int width;
+  std::stack<FormatAgendum> agenda;
+  std::string line_text;
+  int k;
+  std::vector<std::string> line_annotations;
+  std::optional<ColorState> color;
+  std::optional<nb::list> source_map;
+  nb::list line_source_map;
+  int source_start;
+  nb::object source;
+  std::vector<Line> lines;
+};
+
+std::string UpdateColor(std::optional<ColorState>& state,
+                        const ColorState& update) {
+  if (!state.has_value() || *state == update) {
+    return "";
+  }
+  std::string result = "\033[";
+  absl::InlinedVector<std::string, 3> codes;
+  if (state->foreground != update.foreground) {
+    codes.push_back(absl::StrCat(static_cast<int>(update.foreground)));
+  }
+  if (state->background != update.background) {
+    codes.push_back(absl::StrCat(static_cast<int>(update.background) + 10));
+  }
+  if (state->intensity != update.intensity) {
+    codes.push_back(absl::StrCat(static_cast<int>(update.intensity)));
+  }
+  absl::StrAppend(&result, absl::StrJoin(codes, ";"), "m");
+  state = update;
+  return result;
+}
+
+void NilDoc::Format(const FormatAgendum& agendum, FormatState& state) const {}
+
+void TextDoc::Format(const FormatAgendum& agendum, FormatState& state) const {
+  absl::StrAppend(&state.line_text, UpdateColor(state.color, agendum.color),
+                  text_);
+  if (annotation_.has_value()) {
+    state.line_annotations.push_back(*annotation_);
+  }
+  state.k += text_.size();
+}
+
+void ConcatDoc::Format(const FormatAgendum& agendum, FormatState& state) const {
+  for (auto it = children_.rbegin(); it != children_.rend(); ++it) {
+    state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, it->get(),
+                                    agendum.color, state.source});
+  }
+}
+
+void BreakDoc::Format(const FormatAgendum& agendum, FormatState& state) const {
+  if (agendum.mode == BreakMode::kBreak) {
+    if (!state.line_annotations.empty()) {
+      absl::StrAppend(&state.line_text,
+                      UpdateColor(state.color, kAnnotationColors));
+    }
+    if (state.source_map.has_value()) {
+      int pos = state.line_text.size();
+      if (state.source_start != pos && state.source.ptr() != nullptr) {
+        state.line_source_map.append(
+            nb::make_tuple(state.source_start, pos, state.source));
+      }
+      state.source_map->append(state.line_source_map);
+      state.line_source_map = nb::list();
+      state.source_start = agendum.indent;
+    }
+    state.lines.push_back(Line{std::move(state.line_text), state.k,
+                               std::move(state.line_annotations)});
+    state.line_text = std::string(agendum.indent, ' ');
+    state.line_annotations.clear();
+    state.k = agendum.indent;
+  } else {
+    absl::StrAppend(&state.line_text, UpdateColor(state.color, agendum.color),
+                    text_);
+    state.k += text_.size();
+  }
+}
+
+void GroupDoc::Format(const FormatAgendum& agendum, FormatState& state) const {
+  // In Lindig's paper, _fits is passed the remainder of the document.
+  // I'm pretty sure that's a bug and we care only if the current group fits!
+  bool fits = ::jax::Fits(agendum.doc, state.width - state.k) &&
+              ::jax::Sparse(agendum.doc);
+  state.agenda.push(FormatAgendum{agendum.indent,
+                                  fits ? BreakMode::kFlat : BreakMode::kBreak,
+                                  child_.get(), agendum.color, state.source});
+}
+
+void NestDoc::Format(const FormatAgendum& agendum, FormatState& state) const {
+  state.agenda.push(FormatAgendum{agendum.indent + n_, agendum.mode,
+                                  child_.get(), agendum.color, state.source});
+}
+
+void SourceMapDoc::Format(const FormatAgendum& agendum,
+                          FormatState& state) const {
+  state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, child_.get(),
+                                  agendum.color, source_});
+}
+
+void ColorDoc::Format(const FormatAgendum& agendum, FormatState& state) const {
+  ColorState color = agendum.color;
+  if (foreground_.has_value()) {
+    color.foreground = *foreground_;
+  }
+  if (background_.has_value()) {
+    color.background = *background_;
+  }
+  if (intensity_.has_value()) {
+    color.intensity = *intensity_;
+  }
+  state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, child_.get(),
+                                  color, state.source});
+}
+
+std::string Format(const Doc* doc, int width, bool use_color,
+                   std::string annotation_prefix,
+                   std::optional<nb::list> source_map) {
+  FormatState state;
+  if (use_color) {
+    state.color = kDefaultColors;
+  }
+  state.width = width;
+  state.source_start = 0;
+  state.source_map = source_map;
+  state.agenda.push(
+      FormatAgendum{0, BreakMode::kBreak, doc, kDefaultColors, nb::object()});
+  state.k = 0;
+  while (!state.agenda.empty()) {
+    FormatAgendum agendum = state.agenda.top();
+    state.agenda.pop();
+    if (source_map.has_value() && agendum.source.ptr() != state.source.ptr()) {
+      int pos = state.line_text.size();
+      if (state.source_start != pos && state.source.ptr() != nullptr) {
+        state.line_source_map.append(
+            nb::make_tuple(state.source_start, pos, state.source));
+      }
+      state.source = agendum.source;
+      state.source_start = pos;
+    }
+    agendum.doc->Format(agendum, state);
+  }
+  if (!state.line_annotations.empty()) {
+    absl::StrAppend(&state.line_text,
+                    UpdateColor(state.color, kAnnotationColors));
+  }
+  if (state.source_map.has_value()) {
+    int pos = state.line_text.size();
+    if (state.source_start != pos && state.source.ptr() != nullptr) {
+      state.line_source_map.append(
+          nb::make_tuple(state.source_start, pos, state.source));
+    }
+    state.source_map->append(state.line_source_map);
+  }
+  state.lines.push_back(Line{std::move(state.line_text), state.k,
+                             std::move(state.line_annotations)});
+
+  int max_width = 0;
+  for (const auto& line : state.lines) {
+    max_width = std::max(max_width, line.width);
+  }
+  std::string out =
+      absl::StrJoin(state.lines, "\n", [&](std::string* out, const Line& line) {
+        if (line.annotations.empty()) {
+          absl::StrAppend(out, line.text);
+        } else {
+          absl::StrAppend(out, line.text,
+                          std::string(max_width - line.width, ' '),
+                          annotation_prefix, line.annotations[0]);
+          for (int i = 1; i < line.annotations.size(); ++i) {
+            absl::StrAppend(out, std::string(max_width, ' '), annotation_prefix,
+                            line.annotations[i]);
+          }
+        }
+      });
+  absl::StrAppend(&out, UpdateColor(state.color, kDefaultColors));
+  return out;
+}
+
+NB_MODULE(_pretty_printer, m) {
+  nb::enum_<Color>(m, "Color")
+      .value("BLACK", Color::kBlack)
+      .value("RED", Color::kRed)
+      .value("GREEN", Color::kGreen)
+      .value("YELLOW", Color::kYellow)
+      .value("BLUE", Color::kBlue)
+      .value("MAGENTA", Color::kMagenta)
+      .value("CYAN", Color::kCyan)
+      .value("WHITE", Color::kWhite)
+      .value("RESET", Color::kReset);
+
+  nb::enum_<Intensity>(m, "Intensity")
+      .value("DIM", Intensity::kDim)
+      .value("NORMAL", Intensity::kNormal)
+      .value("BRIGHT", Intensity::kBright);
+
+  nb::class_<Doc>(m, "Doc")
+      .def("__repr__", &Doc::Repr)
+      .def("__add__",
+           [](xla::nb_class_ptr<Doc> self, xla::nb_class_ptr<Doc> other) {
+             return xla::make_nb_class<ConcatDoc>(
+                 std::vector<xla::nb_class_ptr<Doc>>{std::move(self),
+                                                     std::move(other)});
+           })
+      .def("_format", &Format, nb::arg("width"), nb::arg("use_color"),
+           nb::arg("annotation_prefix"), nb::arg("source_map").none());
+
+  nb::class_<NilDoc, Doc>(m, "NilDoc");
+  nb::class_<TextDoc, Doc>(m, "TextDoc");
+  nb::class_<ConcatDoc, Doc>(m, "ConcatDoc");
+  nb::class_<BreakDoc, Doc>(m, "BreakDoc");
+  nb::class_<GroupDoc, Doc>(m, "GroupDoc");
+  nb::class_<NestDoc, Doc>(m, "NestDoc");
+  nb::class_<ColorDoc, Doc>(m, "ColorDoc");
+  nb::class_<SourceMapDoc, Doc>(m, "SourceMapDoc");
+
+  m.def(
+      "nil", []() { return xla::make_nb_class<NilDoc>(); },
+      "An empty document.");
+  m.def(
+      "text",
+      [](std::string text, std::optional<std::string> annotation) {
+        return xla::make_nb_class<TextDoc>(std::move(text),
+                                           std::move(annotation));
+      },
+      nb::arg("text"), nb::arg("annotation").none() = std::nullopt,
+      "Literal text.");
+  m.def(
+      "concat",
+      [](std::vector<xla::nb_class_ptr<Doc>> children) {
+        return xla::make_nb_class<ConcatDoc>(std::move(children));
+      },
+      nb::arg("children"), "Concatenation of documents.");
+  m.def(
+      "brk",
+      [](std::string text) { return xla::make_nb_class<BreakDoc>(text); },
+      nb::arg("text") = std::string(" "),
+      R"(A break.
+
+Prints either as a newline or as `text`, depending on the enclosing group.
+)");
+  m.def(
+      "group",
+      [](xla::nb_class_ptr<Doc> child) {
+        return xla::make_nb_class<GroupDoc>(std::move(child));
+      },
+      R"(Layout alternative groups.
+
+Prints the group with its breaks as their text (typically spaces) if the
+entire group would fit on the line when printed that way. Otherwise, breaks
+inside the group are printed as newlines.
+)");
+  m.def(
+      "nest",
+      [](int n, xla::nb_class_ptr<Doc> child) {
+        return xla::make_nb_class<NestDoc>(n, std::move(child));
+      },
+      "Increases the indentation level by `n`.");
+  m.def(
+      "color",
+      [](xla::nb_class_ptr<Doc> child, std::optional<Color> foreground,
+         std::optional<Color> background, std::optional<Intensity> intensity) {
+        return xla::make_nb_class<ColorDoc>(std::move(child), foreground,
+                                            background, intensity);
+      },
+      nb::arg("child"), nb::arg("foreground").none() = std::nullopt,
+      nb::arg("background").none() = std::nullopt,
+      nb::arg("intensity").none() = std::nullopt,
+      R"(ANSI colors.
+
+Overrides the foreground/background/intensity of the text for the child doc.
+Requires use_colors=True to be set when printing; otherwise does nothing.
+)"); + m.def( + "source_map", + [](xla::nb_class_ptr child, nb::object source) { + return xla::make_nb_class(std::move(child), + std::move(source)); + }, + nb::arg("doc"), nb::arg("source"), + R"(Source mapping. + +A source map associates a region of the pretty-printer's text output with a +source location that produced it. For the purposes of the pretty printer a +``source`` may be any object: we require only that we can compare sources for +equality. A text region to source object mapping can be populated as a side +output of the ``format`` method. +)"); +} + +} // namespace jax diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 00e26756ded9..835045c25ef1 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -82,7 +82,6 @@ _py_deps = { "absl/testing": ["@pypi//absl_py"], "absl/flags": ["@pypi//absl_py"], "cloudpickle": get_optional_dep("@pypi//cloudpickle"), - "colorama": get_optional_dep("@pypi//colorama"), "epath": get_optional_dep("@pypi//etils"), # etils.epath "filelock": get_optional_dep("@pypi//filelock"), "flatbuffers": ["@pypi//flatbuffers"], diff --git a/jaxlib/nb_class_ptr.h b/jaxlib/nb_class_ptr.h index 381c77e812b9..f1214c19369e 100644 --- a/jaxlib/nb_class_ptr.h +++ b/jaxlib/nb_class_ptr.h @@ -34,7 +34,7 @@ class nb_class_ptr : public nanobind::object { : nanobind::object(h, ::nanobind::detail::steal_t{}) {} inline static bool check_(nanobind::handle h) { nanobind::handle type = nanobind::type(); - return h.type().is(type); + return nanobind::isinstance(h, type); }; T* operator->() const { return nanobind::inst_ptr(ptr()); } diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 0c29a7ae6ea3..cf1be5e5a8ed 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -205,6 +205,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): f"{source_file_prefix}jaxlib/gpu_solver.py", f"{source_file_prefix}jaxlib/gpu_sparse.py", f"{source_file_prefix}jaxlib/plugin_support.py", + f"{source_file_prefix}jaxlib/_pretty_printer.{pyext}", f"{source_file_prefix}jaxlib/version.py", f"{source_file_prefix}jaxlib/xla_client.py", f"{source_file_prefix}jaxlib/weakref_lru_cache.{pyext}", diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index e6104a9958c2..2c73947fa684 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 349 +_version = 350 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/pyproject.toml b/pyproject.toml index d48351197b54..f1524f84c480 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ module = [ "jaxlib.utils", "jaxlib.version", "jaxlib._jax.utils", + "jaxlib._pretty_printer", "jraph.*", "libtpu.*", "matplotlib.*", diff --git a/tests/pretty_printer_test.py b/tests/pretty_printer_test.py index d87708c9d91c..b4363be1c965 100644 --- a/tests/pretty_printer_test.py +++ b/tests/pretty_printer_test.py @@ -13,24 +13,90 @@ # limitations under the License. 
from absl.testing import absltest - -from jax._src import test_util as jtu from jax._src import pretty_printer as pp +from jax._src import test_util as jtu class PrettyPrinterTest(jtu.JaxTestCase): def testSourceMap(self): doc = pp.concat([ - pp.text("abc"), pp.source_map(pp.text("def"), 101), - pp.source_map(pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77), - pp.text("mn"), + pp.text("abc"), + pp.source_map(pp.text("def"), 101), + pp.source_map( + pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77 + ), + pp.text("mn"), ]) source_map = [] out = doc.format(width=8, source_map=source_map) self.assertEqual(out, "abcdefgh\nijklmn") self.assertEqual(source_map, [[(3, 6, 101), (6, 8, 77)], [(0, 4, 77)]]) + def testBasics(self): + self.assertEqual(pp.nil().format(), "") + self.assertEqual(pp.text("").format(), "") + self.assertEqual(pp.text("testing").format(), "testing") + self.assertEqual(pp.text("\n").format(), "\n") + self.assertEqual(pp.brk().format(), "\n") + # Group that fits will use the space from brk() + self.assertEqual(pp.group(pp.brk()).format(), " ") + # Group that doesn't fit (due to width=0) will use newline + self.assertEqual(pp.group(pp.brk()).format(width=0), "\n") + + # Custom break text + self.assertEqual(pp.group(pp.brk("-")).format(), "-") + self.assertEqual(pp.group(pp.brk("-")).format(width=0), "\n") + + # Concatenation + self.assertEqual((pp.text("a") + pp.text("b")).format(), "ab") + self.assertEqual(pp.concat([pp.text("a"), pp.text("b c")]).format(), "ab c") + + x = pp.text("x") + y = pp.text("y") + z = pp.text("z") + + # Join + # Join with a break that becomes a space when fitting + join_doc_space = pp.join( + pp.text(",") + pp.brk(), [pp.text("a"), pp.text("b"), pp.text("c")] + ) + self.assertEqual(pp.group(join_doc_space).format(), "a, b, c") + self.assertEqual(pp.group(join_doc_space).format(width=5), "a,\nb,\nc") + self.assertEqual(pp.join(pp.text(","), [x, y, z]).format(), "x,y,z") + + j = pp.join( + pp.brk(), [pp.text("xx"), pp.text("yy"), pp.text("zz"), pp.text("ww")] + ) + self.assertEqual(pp.group(j).format(width=3), "xx\nyy\nzz\nww") + self.assertEqual(pp.group(j).format(width=80), "xx yy zz ww") + + bx = pp.brk() + x + bxbx = bx + bx + bx4 = bxbx + bxbx + + # Horizontal-like (fits) + self.assertEqual(pp.group(bx).format(), " x") + self.assertEqual(pp.group(bxbx).format(), " x x") + self.assertEqual(pp.group(bx4).format(), " x x x x") + + # Vertical-like (forced by width) + self.assertEqual(pp.group(bx).format(width=0), "\nx") + self.assertEqual(pp.group(bxbx).format(width=0), "\nx\nx") + self.assertEqual(pp.group(bx4).format(width=0), "\nx\nx\nx\nx") + self.assertEqual(pp.group(bxbx).format(width=3), "\nx\nx") + + # Nesting + xbybz = x + pp.brk() + y + pp.brk() + z + self.assertEqual(pp.nest(2, pp.group(bx)).format(), " x") # Stays flat + self.assertEqual(pp.nest(2, pp.group(bxbx)).format(), " x x") # Stays flat + self.assertEqual(pp.nest(2, pp.group(bx)).format(width=0), "\n x") + self.assertEqual( + pp.nest(2, pp.nest(2, pp.group(bx))).format(width=0), "\n x" + ) + self.assertEqual(pp.nest(2, pp.group(xbybz)).format(width=0), "x\n y\n z") + self.assertEqual(pp.nest(2, pp.group(bxbx)).format(width=0), "\n x\n x") + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From d69086db3529ce75edf355ce94e8e7028d7c629f Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 8 Jun 2025 09:47:18 -0700 Subject: [PATCH 1581/1769] Fix typo in error message. 
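The message string in jax/_src/pjit.py contained an interpolation placeholder
but was missing the f prefix, so the placeholder was printed verbatim. A
minimal sketch of the failure mode, with a hypothetical sharding value:

    s = "NamedSharding(mesh=..., spec=...)"  # hypothetical sharding repr
    print('Got sharding {s}.')   # before: Got sharding {s}.
    print(f'Got sharding {s}.')  # after:  Got sharding NamedSharding(mesh=..., spec=...).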
PiperOrigin-RevId: 768815050 --- jax/_src/pjit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 9b2ddf9c54b2..32e1f6299931 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2993,7 +2993,7 @@ def reshard(xs, out_shardings): if ds is None: raise ValueError( 'Reshard should only be used with out_shardings which are non-None ' - 'and have a nonempty mesh. Got sharding {s}.' + f'and have a nonempty mesh. Got sharding {s}.' ) ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error out_flat.append(reshard_p.bind(x, dst_sharding=ds)) From 880dd13175f12c4e27f32e81e51d349d33703f5e Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Sun, 8 Jun 2025 19:01:44 -0400 Subject: [PATCH 1582/1769] Minor fix to doc for random.orthogonal. --- jax/_src/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 60dad3a82021..d44361ebb3a6 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2175,7 +2175,7 @@ def orthogonal( m: an integer indicating the number of columns. Defaults to `n`. Returns: - A random array of shape `(*shape, n, n)` and specified dtype. + A random array of shape `(*shape, n, m)` and specified dtype. References: .. [1] Mezzadri, Francesco. (2007). "How to generate random matrices from From d8317b538db5f143d3a0fd69f80291f74c1bf47a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 8 Jun 2025 19:20:08 -0700 Subject: [PATCH 1583/1769] [Pallas][Easy] Terser printing of GridMapping unless debug is set. PiperOrigin-RevId: 768932310 --- jax/_src/pallas/core.py | 74 ++++++++++++++++++++++++++---- jax/_src/pallas/mosaic_gpu/core.py | 2 + jax/_src/pallas/pallas_call.py | 37 ++++++++++----- 3 files changed, 92 insertions(+), 21 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 047edd2b8435..07631699ea61 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -499,6 +499,7 @@ def to_block_mapping( index_map_tree: tree_util.PyTreeDef, grid: GridMappingGrid, mapped_dims: tuple[int, ...], + debug: bool = False, ) -> BlockMapping: if self.index_map is None: index_map_func = default_index_map(len(array_aval.shape)) @@ -539,11 +540,15 @@ def to_block_mapping( fake_index_map_args, fake_index_map_kwargs = \ index_map_tree.unflatten([False] * index_map_tree.num_leaves) - debug = api_util.debug_info("pallas_call index_map", - index_map_func, fake_index_map_args, - fake_index_map_kwargs) + debug_info = api_util.debug_info( + "pallas_call index_map", + index_map_func, + fake_index_map_args, + fake_index_map_kwargs, + ) flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(index_map_func, debug_info=debug), index_map_tree) + lu.wrap_init(index_map_func, debug_info=debug_info), index_map_tree + ) with tracing_grid_env(grid, mapped_dims): jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( flat_index_map_fun, index_map_avals @@ -553,7 +558,7 @@ def to_block_mapping( if len(unflat_avals) != len(block_shape): raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must return " f"{len(block_shape)} values to match {block_shape=}. 
" f"Currently returning {len(unflat_avals)} values:" @@ -581,14 +586,14 @@ def to_block_mapping( for i, ov in enumerate(out_avals): if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}." ) if consts: raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must not capture constants: {consts}" ) @@ -604,6 +609,7 @@ def to_block_mapping( ), origin=origin, pipeline_mode=self.pipeline_mode, + debug=debug, ) mapping.check_invariants() return mapping @@ -645,6 +651,7 @@ class BlockMapping: origin: OriginStr transforms: Sequence[MemoryRefTransform] = () pipeline_mode: Buffered | None = None + debug: bool = False def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -716,6 +723,24 @@ def has_trivial_window(self): return False return True + def __repr__(self): + if self.debug: + return ( + f"BlockMapping(block_shape={self.block_shape}, " + f"transformed_block_aval={self.transformed_block_aval}, " + f"index_map_jaxpr={self.index_map_jaxpr}, " + f"index_map_out_tree={self.index_map_out_tree}, " + f"array_shape_dtype={self.array_shape_dtype}, " + f"origin={self.origin}, " + f"transforms={self.transforms}, " + f"pipeline_mode={self.pipeline_mode}, " + f"debug={self.debug})" + ) + return f"BlockMapping(block_shape={self.block_shape})" + + def __str__(self): + return self.__repr__() + @contextlib.contextmanager def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]): @@ -780,6 +805,8 @@ class GridMapping: num_scratch_operands: int get_grid_indices: Callable | None = None local_grid_env: Callable | None = None + # Primarily dictates how much debugging information is printed. 
+ debug: bool = False def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -903,6 +930,29 @@ def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: return tuple( bm.array_shape_dtype for bm in self.block_mappings_output) + def __repr__(self): + if self.debug: + return ( + f"GridMapping(grid={self.grid}, grid_names={self.grid_names}, " + f"block_mappings={self.block_mappings}, " + f"index_map_tree={self.index_map_tree}, " + f"index_map_avals={self.index_map_avals}, " + f"vmapped_dims={self.vmapped_dims}, " + f"num_index_operands={self.num_index_operands}, " + f"num_inputs={self.num_inputs}, " + f"num_outputs={self.num_outputs}, " + f"num_scratch_operands={self.num_scratch_operands}, " + f"get_grid_indices={self.get_grid_indices}, " + f"local_grid_env={self.local_grid_env}, " + f"debug={self.debug})" + ) + return ( + f"GridMapping(grid={self.grid}, block_mappings={self.block_mappings})" + ) + + def __str__(self): + return self.__repr__() + def _is_valid_grid_dim(dim: int | jax.Array) -> bool: if isinstance(dim, jax.Array): @@ -938,6 +988,7 @@ def _convert_block_spec_to_block_mapping( index_map_tree: tree_util.PyTreeDef, grid: GridMappingGrid, mapped_dims: tuple[int, ...], + debug: bool = False, ) -> BlockMapping: if block_spec is no_block_spec: block_spec = BlockSpec(None, None) @@ -948,8 +999,10 @@ def _convert_block_spec_to_block_mapping( index_map_tree=index_map_tree, grid=grid, mapped_dims=mapped_dims, + debug=debug, ) + index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) @@ -1023,8 +1076,8 @@ def get_grid_mapping( out_avals: Sequence[jax_core.AbstractValue], out_tree: tree_util.PyTreeDef, out_origins: Sequence[OriginStr], -) -> tuple[tuple[jax_core.AbstractValue, ...], - GridMapping]: + debug: bool = False, +) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: if dynamic_shapes_export_enabled(): dim_check : Any = jax_core.is_dim else: @@ -1090,6 +1143,7 @@ def get_grid_mapping( index_map_tree=index_map_tree, grid=grid_mapping_grid, # type: ignore[arg-type] mapped_dims=(), + debug=debug, ), flat_in_specs, in_origins[num_flat_scalar_prefetch:], @@ -1112,6 +1166,7 @@ def get_grid_mapping( index_map_tree=index_map_tree, grid=grid_mapping_grid, # type: ignore[arg-type] mapped_dims=(), + debug=debug, ), flat_out_specs, out_origins, @@ -1128,6 +1183,7 @@ def get_grid_mapping( num_inputs=len(flat_in_specs), num_outputs=len(flat_out_specs), num_scratch_operands=num_flat_scratch_operands, + debug=debug, ) grid_mapping.check_invariants() in_ref_avals = [bm.ref_aval for bm in in_block_mappings] diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 3b28ebdd5d20..163e250dadf0 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -808,6 +808,7 @@ def to_block_mapping( index_map_tree: tree_util.PyTreeDef, grid: pallas_core.GridMappingGrid, mapped_dims: tuple[int, ...], + debug: bool = False, ) -> pallas_core.BlockMapping: bm = super().to_block_mapping( origin, @@ -816,6 +817,7 @@ def to_block_mapping( index_map_tree=index_map_tree, grid=grid, mapped_dims=mapped_dims, + debug=debug, ) block_inner_aval = bm.block_aval.inner_aval for t in self.transforms: diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 52360b997743..0165cbad6079 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1131,10 +1131,10 @@ def _ensure_2d_error_shape(arg): retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval, 
*error_memref_aval, *output_aval, *scratch_aval] jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) - debug = api_util.debug_info("checkify_pallas", checked_kernel_fn, + debug_info = api_util.debug_info("checkify_pallas", checked_kernel_fn, retrace_in_avals, {}) wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree) + lu.wrap_init(checked_kernel_fn, debug_info=debug_info), jaxpr_in_tree) with pallas_core.tracing_grid_env(grid_mapping.grid, ()): final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( @@ -1146,13 +1146,18 @@ def _ensure_2d_error_shape(arg): error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0]) error_origins = tuple(f"errors[{tree_util.keystr(p)}" for p in error_paths) error_block_mappings = map( - partial( - pallas_core._convert_block_spec_to_block_mapping, - index_map_avals=grid_mapping.index_map_avals, - index_map_tree=grid_mapping.index_map_tree, - grid=grid_mapping.grid, - mapped_dims=grid_mapping.vmapped_dims), - error_block_specs, error_origins, shaped_err_avals) + partial( + pallas_core._convert_block_spec_to_block_mapping, + index_map_avals=grid_mapping.index_map_avals, + index_map_tree=grid_mapping.index_map_tree, + grid=grid_mapping.grid, + mapped_dims=grid_mapping.vmapped_dims, + debug=True, + ), + error_block_specs, + error_origins, + shaped_err_avals, + ) input_block_mappings, output_block_mappings = split_list( grid_mapping.block_mappings, [num_kernel_inputs,]) grid_mapping_with_error = grid_mapping.replace( @@ -1396,7 +1401,9 @@ def _pallas_call_state_discharge_rule( index_map_tree=grid_mapping.index_map_tree, grid=grid_mapping.grid, mapped_dims=grid_mapping.mapped_dims, - ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs) + debug=debug, + ) + for ref_aval, block_spec in zip(ref_avals, ref_block_specs) ] in_block_mappings, out_block_mappings = split_list( grid_mapping.block_mappings, [grid_mapping.num_inputs] @@ -1665,8 +1672,14 @@ def wrapped(*args): # TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc. kernel_args, grid_mapping = pallas_core.get_grid_mapping( grid_spec, - flat_in_avals, in_tree, in_origins, - flat_out_avals, out_tree, out_origins) + flat_in_avals, + in_tree, + in_origins, + flat_out_avals, + out_tree, + out_origins, + debug, + ) flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args) flat_kernel_avals = tuple( x.ref if isinstance(x, state_types.TransformedRef) else x From 2281455a4937f8fd95094a5db849833602eb6982 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 9 Jun 2025 07:00:48 -0400 Subject: [PATCH 1584/1769] Don't trigger debug_infs in ndtri unless an inf is returned. 
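
ndtri legitimately returns -inf at p == 0 and +inf at p == 1, but the old
boundary handling materialized those infinities even for interior inputs,
tripping the jax.debug_infs checker on finite results. The fix computes the
boundary values with the check locally disabled and raises only when an inf
actually escapes. A sketch of the intended semantics (matching the new test):

    import jax
    from jax.scipy.special import ndtri

    with jax.debug_infs(True):
      ndtri(0.5)  # finite result, does not raise
      ndtri(1.0)  # +inf, raises FloatingPointError: invalid value (inf)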
--- jax/_src/scipy/special.py | 18 ++++++++++++++---- tests/lax_scipy_special_functions_test.py | 11 +++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index d5c99b6e3b4f..ba0ea75b21d9 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -26,9 +26,12 @@ from jax import vmap from jax import lax +from jax._src import api_util +from jax._src import config from jax._src import core from jax._src import custom_derivatives from jax._src import deprecations +from jax._src import dispatch from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact @@ -1048,10 +1051,17 @@ def _create_polynomial(var, coeffs): jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise)) x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x) - infinity = jnp.full(shape, dtype(np.inf)) - x_fix_boundaries = jnp.where( - p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) - return x_fix_boundaries + with config.debug_infs(False): + infinity = jnp.full(shape, dtype(np.inf)) + x = jnp.where( + p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) + if not isinstance(x, core.Tracer): + try: + dispatch.check_special("ndtri", [x]) + except api_util.InternalFloatingPointError as e: + raise FloatingPointError( + f"invalid value ({e.ty}) encountered in ndtri.") from None + return x @partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,)) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 995854dae348..9fc7619f7145 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -260,6 +260,17 @@ def testNdtriExtremeValues(self): self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol) + @parameterized.parameters([True, False]) + def testNdtriDebugInfs(self, with_jit): + # ref: https://github.com/jax-ml/jax/issues/29328 + f = jax.jit(lsp_special.ndtri) if with_jit else lsp_special.ndtri + with jax.debug_infs(True): + f(0.5) # Doesn't crash + with self.assertRaisesRegex(FloatingPointError, "invalid value \\(inf\\)"): + f(1.0) + with self.assertRaisesRegex(FloatingPointError, "invalid value \\(inf\\)"): + f(0.0) + def testRelEntrExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). dtype = jnp.zeros(0).dtype # default float dtype. From a6d95ee3a71c4ef355e600b9d9364101d40c65a2 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 21 May 2025 16:58:11 -0700 Subject: [PATCH 1585/1769] [jax2tf] Refine the disabling of jax2tf_test, for versions <= 2.19.1 Previously we disabled the jax2tf_test for older versions of TF. Re-enable for 2.19.1 and higher. 
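
The gate is now a runtime check in setUp that compares the dot-separated
components of tf.version.VERSION as strings. That ordering is adequate for
the single-digit components involved here; if multi-digit minor versions ever
matter, a numeric comparison would be the safer variant (sketch, assuming
purely numeric version components):

    tuple(int(p) for p in tf.version.VERSION.split(".")[:3]) < (2, 19, 1)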
---
 .github/workflows/ci-build.yaml              |  2 +-
 jax/experimental/jax2tf/tests/jax2tf_test.py | 16 ++++++++++++++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml
index ada470526ef8..86aaeffaa5fd 100644
--- a/.github/workflows/ci-build.yaml
+++ b/.github/workflows/ci-build.yaml
@@ -181,7 +181,7 @@ jobs:
         run: |
           pip install uv~=0.5.30
           uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
-          uv pip install --system --pre tensorflow==2.19.0rc0
+          uv pip install --system --pre tensorflow==2.19.0
 
       - name: Run tests
         env:
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index ece88841fdc5..bde148cb514e 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -48,11 +48,17 @@
 
 config.parse_flags_with_absl()
 
 
-@unittest.skip("Failing after jax 0.6.1 release")
 class Jax2TfTest(tf_test_util.JaxToTfTestCase):
 
   def setUp(self):
     super().setUp()
+    versions = tf.version.VERSION.split(".")
+    if versions < ["2", "19", "1"]:
+      # StableHLO changed on March 18th, 2025, to version 1.10.0, and this
+      # introduces ops like vhlo_sine_v2. These ops require a TF version
+      # released after this date.
+      self.skipTest("Need version of TensorFlow at least 2.19.1")
+
     # One TF device of each device_type
     self.tf_devices = []
     for tf_device in (tf.config.list_logical_devices("TPU") +
@@ -1783,11 +1789,17 @@ def func():
     jax_result = func()
     self.assertEqual(tf_result, jax_result)
 
-@unittest.skip("Failing after jax 0.6.1 release")
+
 class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
   # Use a separate test case with the default jax_serialization_version
   def setUp(self):
     self.use_max_serialization_version = False
+    versions = tf.version.VERSION.split(".")
+    if versions < ["2", "19", "1"]:
+      # StableHLO changed on March 18th, 2025, to version 1.10.0, and this
+      # introduces ops like vhlo_sine_v2. These ops require a TF version
+      # released after this date.
+      self.skipTest("Need version of TensorFlow at least 2.19.1")
     super().setUp()
 
   @jtu.ignore_warning(
From 6d729fe718e4a8abe2b4b859521b846492d9e384 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 9 Jun 2025 09:51:01 -0700
Subject: [PATCH 1586/1769] Move jax/_src/api.py and associated files to their
 own BUILD rule

Creating smaller build rules enforces better organized dependency graphs
in the JAX project, helps pytype propagate annotations correctly, and
leads to improved build and iteration times.

This required a few local imports and refactors.
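
The "local imports" are function-scoped imports that break would-be
dependency cycles between the finer-grained rules. The same pattern recurs
in the follow-up moves in this series, e.g. (sketch):

    def sequential_vmap(f):
      # Deferred import: avoids a build-level cycle with jax._src.lax.
      from jax._src.lax import control_flow
      ...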
PiperOrigin-RevId: 769184594 --- jax/BUILD | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 588ce697b711..45a9489af2dd 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -295,8 +295,6 @@ py_library_providing_imports_info( srcs = [ "_src/__init__.py", "_src/ad_checkpoint.py", - "_src/api.py", # TODO(vanderplas): remove this and depend on :api instead - "_src/array.py", # TODO(vanderplas): remove this and depend on :api instead "_src/blocked_sampler.py", "_src/buffer_callback.py", "_src/callback.py", @@ -304,15 +302,12 @@ py_library_providing_imports_info( "_src/custom_batching.py", "_src/custom_partitioning.py", "_src/debugging.py", - "_src/dispatch.py", # TODO(vanderplas): remove this and depend on :api instead "_src/dlpack.py", "_src/earray.py", "_src/error_check.py", "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", - "_src/interpreters/pxla.py", # TODO(vanderplas): remove this and depend on :api instead - "_src/pjit.py", # TODO(vanderplas): remove this and depend on :api instead "_src/prng.py", "_src/public_test_util.py", "_src/random.py", @@ -373,7 +368,7 @@ py_library_providing_imports_info( ":abstract_arrays", ":ad", ":ad_util", - # ":api", # TODO(vanderplas): add this dependency once downstream targets are fixed + ":api", ":api_util", ":attrs", ":basearray", From 92d5fe88d0a2cd9a8c84bdc2fab85f2995f644b6 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 9 Jun 2025 10:07:11 -0700 Subject: [PATCH 1587/1769] Reverts b7833e94c1940ed475dae1f5e83e2a984cda5cea PiperOrigin-RevId: 769190882 --- jax/_src/export/_export.py | 38 ++++++---------------- jax/_src/interpreters/mlir.py | 4 +-- jaxlib/BUILD | 5 +-- jaxlib/py_client.cc | 59 ----------------------------------- tests/export_test.py | 44 +------------------------- 5 files changed, 14 insertions(+), 136 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 5e3c4cf0f209..b390574c0a79 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1431,16 +1431,9 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, submodule_bc = mlir.module_to_bytecode(submodule) shardy_enabled = _jax.sdy.lowered_with_shardy(submodule_bc) if shardy_enabled: - if not config.use_shardy_partitioner.value: - raise ValueError( - "The function was exported with shardy enabled but you are calling " - "it with Shardy disabled. Please enable Shardy using " - "`--jax_use_shardy_partitioner=True`.") submodule = ir.Module.parse( _jax.sdy.sdy_round_trip_import_shardings(submodule_bc) ) - elif config.use_shardy_partitioner.value: - shardy_enabled = True with submodule.context: pipeline = passmanager.PassManager.parse( @@ -1451,7 +1444,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, if shardy_enabled: sdy_mesh_axes = _jax.sdy.get_mesh(mlir.module_to_bytecode(submodule)) mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1]) - if sdy_mesh_axes else None) + if sdy_mesh_axes else mesh_lib.empty_abstract_mesh) axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.ShardingContext): @@ -1480,19 +1473,15 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ) # Apply in_shardings - if mesh: - # A mesh only exists if Shardy is enabled. 
+ if shardy_enabled: args = tuple( wrap_with_sharding( ctx, x, x_aval, - _hlo_sharding_to_named_sharding(x_sharding, mesh), use_shardy=True) # type: ignore[arg-type] + _hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type] for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) else: - # Since there is no mesh - either due to shardy being disabled or the loaded - # function being lowered for GSPMD (so no shardy mesh) - need to create a - # GSPMD sharding from the HLO sharding (can't use shardy lowering). args = tuple( - wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False) + wrap_with_sharding(ctx, x, x_aval, x_sharding) for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) symtab = ir.SymbolTable(submodule.operation) @@ -1581,19 +1570,14 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):], exported.out_avals, ctx.avals_out)) # Apply out_shardings - if mesh: - # A mesh only exists if Shardy is enabled. + if shardy_enabled: results = tuple( wrap_with_sharding( - ctx, x, x_aval, _hlo_sharding_to_named_sharding(x_sharding, mesh), - use_shardy=True) # type: ignore[arg-type] + ctx, x, x_aval, _hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type] for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo)) else: - # Since there is no mesh - either due to shardy being disabled or the loaded - # function being lowered for GSPMD (so no shardy mesh) - need to create a - # GSPMD sharding from the HLO sharding (can't use shardy lowering). results = tuple( - wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False) + wrap_with_sharding(ctx, x, x_aval, x_sharding) for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo)) return results @@ -1604,14 +1588,12 @@ def wrap_with_sharding( ctx: mlir.LoweringRuleContext, x: ir.Value, x_aval: core.AbstractValue, - x_sharding: sharding_impls.NamedSharding | sharding_impls.GSPMDSharding | HloSharding | None, - use_shardy: bool, + x_sharding: sharding_impls.NamedSharding | HloSharding | None, ) -> ir.Value: if x_sharding is None: return x - if use_shardy: + if config.use_shardy_partitioner.value: x_sharding = x_sharding._to_sdy_sharding(x_aval.ndim) # type: ignore else: x_sharding = x_sharding.to_proto() # type: ignore - return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding, # type: ignore[arg-type] - allow_shardy_lowering=use_shardy) + return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding) # type: ignore[arg-type] diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 0864ec8646c9..c11a68d7c45f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2731,7 +2731,7 @@ def lower_with_sharding_in_types(ctx, op, aval, sharding_proto=None): def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList): - if isinstance(sharding, (SdyArray, SdyArrayList)): + if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) else: op.attributes["mhlo.sharding"] = get_sharding_attr(sharding) @@ -2740,7 +2740,7 @@ def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList): def get_sharding_attr( sharding: xc.OpSharding | SdyArray | SdyArrayList ) -> ir.Attribute: - if isinstance(sharding, (SdyArray, SdyArrayList)): + if config.use_shardy_partitioner.value: return sharding.build() # 
type: ignore
   else:
     # If there are very large numbers of devices, use the proto representation.
diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index 14a84bf1f19f..9f495087ccb0 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -888,14 +888,12 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:cord",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@nanobind",
-        "@shardy//shardy/dialect/sdy/ir:dialect",
-        "@stablehlo//:stablehlo_ops",
         "@tsl//tsl/platform:fingerprint",
         "@tsl//tsl/platform:ml_dtypes",
         "@tsl//tsl/profiler/lib:traceme",
@@ -935,7 +933,6 @@ cc_library(
         "@xla//xla/python/pjrt_ifrt:pjrt_dtype",
         "@xla//xla/python/pjrt_ifrt:xla_ifrt",
         "@xla//xla/service:platform_util",
-        "@xla//xla/service/spmd/shardy:constants",
         "@xla//xla/tsl/concurrency:ref_count",
         "@xla//xla/tsl/framework:allocator",
         "@xla//xla/tsl/platform:env",
diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc
index 0ddadbc3038a..f478040e8622 100644
--- a/jaxlib/py_client.cc
+++ b/jaxlib/py_client.cc
@@ -35,11 +35,9 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OwningOpRef.h"
-#include "mlir/IR/Visitors.h"
 #include "mlir/Pass/PassManager.h"
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/optional.h"  // IWYU pragma: keep
@@ -50,7 +48,6 @@ limitations under the License.
 #include "nanobind/stl/unique_ptr.h"  // IWYU pragma: keep
 #include "nanobind/stl/variant.h"  // IWYU pragma: keep
 #include "nanobind/stl/vector.h"  // IWYU pragma: keep
-#include "shardy/dialect/sdy/ir/dialect.h"
 #include "jaxlib/guard_lib.h"
 #include "jaxlib/nb_class_ptr.h"
 #include "jaxlib/py_array.h"
@@ -63,7 +60,6 @@ limitations under the License.
 #include "jaxlib/python_ref_manager.h"
 #include "jaxlib/sharding.h"
 #include "jaxlib/traceback.h"
-#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/literal.h"
 #include "xla/pjrt/exceptions.h"
 #include "xla/pjrt/mlir_to_hlo.h"
@@ -93,7 +89,6 @@ limitations under the License.
 #include "xla/python/types.h"
 #include "xla/python/version.h"
 #include "xla/service/platform_util.h"  // IWYU pragma: keep
-#include "xla/service/spmd/shardy/constants.h"
 #include "xla/shape.h"
 #include "xla/status_macros.h"
 #include "xla/tsl/concurrency/ref_count.h"
@@ -404,47 +399,6 @@ MakeIfrtDeserializeExecutableOptions(std::optional<CompileOptions> options,
                                      std::move(ifrt_loaded_host_callbacks));
 }
 
-// Returns true if the module has at least one GSPMD attribute or op, like an
-// `mhlo.sharding` attribute or `Sharding` custom call.
-// TODO(b/420837831): delete this once we don't fall back to GSPMD.
-bool HasGspmdAttrsOrOps(mlir::ModuleOp module) {
-  for (auto func : module.getOps<mlir::func::FuncOp>()) {
-    for (int64_t arg_index = 0; arg_index < func.getNumArguments();
-         ++arg_index) {
-      if (func.getArgAttr(arg_index, sdy::kXlaShardingAttr)) {
-        return true;
-      }
-    }
-    for (int64_t result_index = 0; result_index < func.getNumResults();
-         ++result_index) {
-      if (func.getResultAttr(result_index, sdy::kXlaShardingAttr)) {
-        return true;
-      }
-    }
-  }
-  // Check the module for a `Sharding` custom call.
-  bool has_gspmd = false;
-  module->walk([&has_gspmd](mlir::stablehlo::CustomCallOp custom_call) {
-    if (custom_call.getCallTargetName() ==
-            sdy::kShardingCustomCallTargetName &&
-        custom_call->hasAttr(sdy::kXlaShardingAttr)) {
-      has_gspmd = true;
-      return mlir::WalkResult::interrupt();
-    }
-    return mlir::WalkResult::advance();
-  });
-  return has_gspmd;
-}
-
-// Check if the module has any sort of Shardy mesh:
-// - `mesh`
-// - `maximal_mesh_{X}`
-// - `empty_mesh`
-// TODO(b/420837831): delete this once we don't fall back to GSPMD.
-bool HasShardyMesh(mlir::ModuleOp module) {
-  return !module.getOps<sdy::MeshOp>().empty();
-}
-
 }  // namespace
 
 /* static */ absl::StatusOr<nb_class_ptr<PyLoadedExecutable>>
 PyClient::CompileAndLoad(nb_class_ptr<PyClient> client, std::string mlir_module,
@@ -529,19 +483,6 @@ PyClient::CompileAndLoad(nb_class_ptr<PyClient> client, std::string mlir_module,
   mlir::MLIRContext context;
   TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                       ParseMlirModuleString(mlir_module, context));
-  // TODO(b/420837831): Remove this once we don't need to fall back to GSPMD.
-  if (options.executable_build_options.use_shardy_partitioner() &&
-      HasGspmdAttrsOrOps(module.get())) {
-    LOG(WARNING)
-        << "Module has GSPMD attrs or ops, but Shardy is enabled. Disabling "
-           "Shardy and falling back to using GSPMD propagation.";
-    options.executable_build_options.set_use_shardy_partitioner(false);
-    if (HasShardyMesh(module.get())) {
-      // Shardy is not enabled, but the module has shardy ops. Likely due to
-      // export loading a GSPMD checkpoint. Fall back to GSPMD.
-      TF_RETURN_IF_ERROR(ExportShardyForGSPMD(*module));
-    }
-  }
   return CompileAndLoadIfrtProgram(
       client, std::make_unique<xla::ifrt::HloProgram>(module.get()),
       MakeIfrtCompileOptions(std::move(options), std::move(executable_devices),
diff --git a/tests/export_test.py b/tests/export_test.py
index 2be5313b0b3a..829576e95c7e 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -2008,7 +2008,7 @@ def f(x, y):
     r = jax.jit(exp.call, out_shardings=NamedSharding(old_mesh_0, P("old_b")))(a, b)
     self.assertAllClose(a + b, r)
 
-  def test_lower_with_different_meshes_axis_names(self):
+  def test_lower_wth_different_meshes_axis_names(self):
     mesh1 = jtu.create_mesh((4, 2), ("a", "b"))
     mesh2 = jtu.create_mesh((4, 2), ("x", "y"))
     @jax.jit
@@ -2033,48 +2033,6 @@ def f(tree):
     else:
       get_exported(f)(args)
 
-  @jtu.parameterized_filterable(
-      kwargs=[
-          {"use_shardy_on_save": True, "error_msg": "Please enable Shardy"},
-          {"use_shardy_on_save": False, "error_msg": ""},
-      ])
-  def test_lower_load_with_different_partitioners(self, use_shardy_on_save,
-                                                  error_msg):
-    old_shardy = config.use_shardy_partitioner.value
-    try:
-      jax.config.update("jax_use_shardy_partitioner", use_shardy_on_save)
-      mesh = jtu.create_mesh((8,), ("a",))
-      @jax.jit
-      def f(x, y):
-        z = x + y
-        return jax.lax.with_sharding_constraint(
-            z, NamedSharding(mesh, P("a")))
-
-      args = (
-          jax.ShapeDtypeStruct(
-              (32, 32), dtype=np.float32,
-              sharding=NamedSharding(mesh, P(None, "a"))),
-          jax.ShapeDtypeStruct(
-              (32, 32), dtype=np.float32,
-              sharding=NamedSharding(mesh, P("a"))))
-
-      exp = get_exported(f)(*args)
-
-      jax.config.update("jax_use_shardy_partitioner", not use_shardy_on_save)
-
-      a = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32))
-      a = jax.device_put(a, NamedSharding(mesh, P(None, "a")))
-      b = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32))
-      b = jax.device_put(b, NamedSharding(mesh, P("a")))
-
-      if use_shardy_on_save:
-        with self.assertRaisesRegex(ValueError, error_msg):
-          jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b)
-      else:
-        jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b)
-
finally: - jax.config.update("jax_use_shardy_partitioner", old_shardy) - if __name__ == "__main__": From 88de1e615861f11737797a873c5f3a03153baa67 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 10:12:15 -0700 Subject: [PATCH 1588/1769] jax.nn.standardize: improve documentation --- jax/_src/nn/functions.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index f01c4fa52804..21ea7ac615a9 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -633,12 +633,38 @@ def _softmax_deprecated( @partial(jax.jit, static_argnames=("axis",)) def standardize(x: ArrayLike, - axis: int | tuple[int, ...] | None = -1, - mean: ArrayLike | None = None, - variance: ArrayLike | None = None, - epsilon: ArrayLike = 1e-5, - where: ArrayLike | None = None) -> Array: - r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.""" + axis: int | tuple[int, ...] | None = -1, + mean: ArrayLike | None = None, + variance: ArrayLike | None = None, + epsilon: ArrayLike = 1e-5, + where: ArrayLike | None = None) -> Array: + r"""Standardizes input to zero mean and unit variance. + + The standardization is given by: + + .. math:: + + x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}} + + where :math:`\langle x\rangle` indicates the mean of :math:`x`, and :math:`\epsilon` is + a small correction factor introduced to avoid division by zero. + + Args: + x: input array to be standardized. + axis: integer or tuple of integers representing the axes along which + to standardize. Defaults to the last axis (``-1``). + mean: optionally specify the mean used for standardization. If not specified, + then ``x.mean(axis, where=where)`` will be used. + variance: optionally specify the variance used for standardization. If not + specified, then ``x.var(axis, where=where)`` will be used. + epsilon: correction factor added to variance to avoid division by zero; defaults + to ``1E-5``. + where: optional boolean mask specifying which elements to use when computing + the mean and variance. + + Returns: + An array of the same shape as ``x`` containing the standardized input. + """ numpy_util.check_arraylike("standardize", x) numpy_util.check_arraylike_or_none("standardize", mean, variance, where) if mean is None: From 42ea2ac3e40a15afb6569671353f189ca334a879 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 11:55:03 -0700 Subject: [PATCH 1589/1769] Move jax/_src/custom_batching.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. 
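
The only source change needed was replacing the module-level jax.lax import
with a function-local import inside sequential_vmap (see the diff below);
public behavior is unchanged. For reference, a small usage sketch:

    import jax
    import jax.numpy as jnp
    from jax.custom_batching import sequential_vmap

    @sequential_vmap
    def f(x):
      return x + 1.0

    jax.vmap(f)(jnp.arange(3.0))  # maps f over the batch via lax.map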
PiperOrigin-RevId: 769236580 --- jax/BUILD | 22 +++++++++++++++++++++- jax/_src/custom_batching.py | 5 +++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 45a9489af2dd..8e30fe146d05 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -299,7 +299,6 @@ py_library_providing_imports_info( "_src/buffer_callback.py", "_src/callback.py", "_src/checkify.py", - "_src/custom_batching.py", "_src/custom_partitioning.py", "_src/debugging.py", "_src/dlpack.py", @@ -380,6 +379,7 @@ py_library_providing_imports_info( ":config", ":core", ":custom_api_util", + ":custom_batching", ":custom_dce", ":custom_derivatives", ":custom_partitioning_sharding_rule", @@ -653,6 +653,26 @@ pytype_strict_library( srcs = ["_src/custom_api_util.py"], ) +pytype_strict_library( + name = "custom_batching", + srcs = ["_src/custom_batching.py"], + deps = [ + ":ad", + ":api", + ":api_util", + ":batching", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ":xla", + ], +) + pytype_strict_library( name = "custom_dce", srcs = ["_src/custom_dce.py"], diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 83c9ffb5ee36..a8876cd9c86c 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -19,7 +19,6 @@ import functools import operator -from jax import lax from jax._src import api from jax._src import core from jax._src import custom_api_util @@ -394,6 +393,8 @@ def sequential_vmap(f): See the documentation for :py:class:`~jax.custom_batching.custom_vmap` for more details. """ + from jax._src.lax import control_flow # pytype: disable=import-error + f = custom_vmap(f) @f.def_vmap @@ -405,7 +406,7 @@ def to_map(mapped_args): return f(*args) mapped_args, bcast_args = tree_split(in_batched, list(args)) - out = lax.map(to_map, mapped_args) + out = control_flow.map(to_map, mapped_args) out_batched = tree_map(lambda _: True, out) return out, out_batched From 89e0c7ecad79ac0cbdd75e3a9f6664fe7afe85f9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 12:31:27 -0700 Subject: [PATCH 1590/1769] Move jax/_src/earray.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. 
PiperOrigin-RevId: 769249808 --- jax/BUILD | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 8e30fe146d05..83c86d1e308d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -302,7 +302,6 @@ py_library_providing_imports_info( "_src/custom_partitioning.py", "_src/debugging.py", "_src/dlpack.py", - "_src/earray.py", "_src/error_check.py", "_src/ffi.py", "_src/flatten_util.py", @@ -386,6 +385,7 @@ py_library_providing_imports_info( ":custom_transpose", ":deprecations", ":dtypes", + ":earray", ":effects", ":environment_info", ":internal_mesh_utils", @@ -761,6 +761,21 @@ pytype_strict_library( ] + py_deps("ml_dtypes") + py_deps("numpy"), ) +pytype_strict_library( + name = "earray", + srcs = ["_src/earray.py"], + deps = [ + ":api", + ":basearray", + ":core", + ":sharding_impls", + ":tree_util", + ":util", + ":xla", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "effects", srcs = ["_src/effects.py"], From 9651e6089e09587a7fdcfa0fe96aa6d5f6e75dce Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 12:45:04 -0700 Subject: [PATCH 1591/1769] [array-api] pin array-api-tests to 2025.05.23 --- .github/workflows/jax-array-api.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 41879a6f2e9f..eaabc54368de 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -30,8 +30,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: data-apis/array-api-tests - # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'c847143beb8d769bde5dbcc063fe19ed7acc2f9b' # Latest commit as of 2025-05-12 + ref: '2025.05.23' submodules: 'true' path: 'array-api-tests' persist-credentials: false From 7c35300e252aebf357a44d6a8dca98a3f52832eb Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 9 Jun 2025 12:51:11 -0700 Subject: [PATCH 1592/1769] Fix spelling error in the name of the input variable. 'exectuable' should be 'executable'. PiperOrigin-RevId: 769256903 --- jax/_src/compilation_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 058670642f41..c8a0f0715a2d 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -341,7 +341,7 @@ def combine_executable_and_time( def extract_executable_and_time( - exectuable_and_time: bytes + executable_and_time: bytes ) -> tuple[bytes, int]: """Given the cache entry in the format shown below, extract the serialized executable and the compilation time. @@ -351,5 +351,5 @@ def extract_executable_and_time( Content: compilation time serialized executable (big-endian int) """ - return exectuable_and_time[4:], int.from_bytes( - exectuable_and_time[:4], byteorder='big') + return executable_and_time[4:], int.from_bytes( + executable_and_time[:4], byteorder='big') From 39de7159630e0fdfc6a91ad08bab97bac7623955 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 9 Jun 2025 12:52:56 -0700 Subject: [PATCH 1593/1769] [Mosaic GPU] Error when causal masking is used on cuda versions known to result in a ptxas miscompilation (between 12.8.0 and 12.9.1). 
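
The guard reads the installed toolkit through
jax._src.lib.cuda_versions.cuda_runtime_get_version(), which encodes 12.8.0
as 12080, and rejects causal=True configurations up front instead of letting
ptxas miscompile the kernel. A rough sketch of the new failure mode
(TuningConfig fields elided):

    cfg = TuningConfig(..., causal=True)
    attention(q, k, v, cfg)
    # ValueError on CUDA versions in [12.8.0, 12.9.1)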
PiperOrigin-RevId: 769257583 --- jax/BUILD | 1 + jax/experimental/pallas/ops/gpu/attention_mgpu.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/jax/BUILD b/jax/BUILD index 83c86d1e308d..5f90745b5359 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1029,6 +1029,7 @@ pytype_strict_library( ":pallas", ":pallas_mosaic_gpu", ":test_util", # This is only to make them runnable as jax_multiplatform_test... + "//jax/_src/lib", ] + py_deps("numpy"), ) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index d9d62afe93c5..90b8eb702db4 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -20,6 +20,7 @@ import jax from jax import lax from jax._src import test_util as jtu # noqa: F401 +from jax._src.lib import cuda_versions # noqa: F401 from jax.experimental.mosaic.gpu import profiler import jax.experimental.pallas as pl import jax.experimental.pallas.mosaic_gpu as plgpu @@ -62,6 +63,13 @@ def has_backward_blocks(self) -> bool: return self.block_q_dkv is not None def _attention_forward(q, k, v, config: TuningConfig, save_residuals: bool = False): + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. + if config.causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091: + raise ValueError( + "Causal masking not supported with cuda versions between 12.8.0 and" + " 12.9.1 due to a ptxas miscompilation." + ) if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -834,6 +842,11 @@ def main(unused_argv): problem_it = itertools.product( (1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts, (False, True)) for batch_size, seq_len, head_dim, use_schedule_barrier, causal in problem_it: + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. + if causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091: + continue + if causal and use_pipeline_emitter: continue q_seq_len = kv_seq_len = seq_len From 9a5916245a13fd6125bba2b2dae9e47afdbf6a8a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 13:13:25 -0700 Subject: [PATCH 1594/1769] Move jax/_src/ffi.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. 
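
The refactors here replace module-level uses of jax.lax in the batching rule
with function-local imports of jax._src.lax, the same cycle-breaking pattern
as the earlier moves. The rule stays reachable through the public FFI API; a
hedged sketch (target name hypothetical):

    import jax
    import numpy as np

    call = jax.ffi.ffi_call(
        "my_target",                             # assumed registered target
        jax.ShapeDtypeStruct((4,), np.float32),
        vmap_method="sequential",                # batches via lax.scan
    )
    jax.vmap(call)(xs)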
PiperOrigin-RevId: 769264747 --- jax/BUILD | 25 +++++++++++++++++++++++-- jax/_src/ffi.py | 8 +++++--- jax/extend/BUILD | 5 ++++- tests/BUILD | 5 ++++- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 5f90745b5359..128f03334846 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -303,7 +303,6 @@ py_library_providing_imports_info( "_src/debugging.py", "_src/dlpack.py", "_src/error_check.py", - "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", "_src/prng.py", @@ -388,6 +387,7 @@ py_library_providing_imports_info( ":earray", ":effects", ":environment_info", + ":ffi", ":internal_mesh_utils", ":jaxpr_util", ":layout", @@ -791,6 +791,24 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "ffi", + srcs = ["_src/ffi.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":core", + ":effects", + ":layout", + ":mlir", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "hardware_utils", srcs = ["_src/hardware_utils.py"], @@ -1505,7 +1523,10 @@ pytype_library( exclude = ["experimental/sparse/test_util.py"], ), visibility = ["//visibility:public"], - deps = [":jax"], + deps = [ + ":ffi", + ":jax", + ], ) pytype_library( diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index db943d675b80..8bb6b368d61a 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -23,7 +23,6 @@ import numpy as np -import jax from jax._src import core from jax._src import dispatch from jax._src import effects @@ -662,6 +661,9 @@ def ffi_batching_rule( result_avals: Sequence[core.ShapedArray], **kwargs: Any, ): + from jax._src.lax import control_flow # pytype: disable=import-error + from jax._src.lax import lax # pytype: disable=import-error + axis_size, = {a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped} new_args = [arg if dim is batching.not_mapped else @@ -696,7 +698,7 @@ def ffi_batching_rule( elif vmap_method == "expand_dims" or vmap_method == "broadcast_all": size = axis_size if vmap_method == "broadcast_all" else 1 bcast_args = [ - jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x + lax.broadcast(x, (size,)) if d is batching.not_mapped else x for x, d in zip(new_args, dims)] if kwargs.get("input_layouts") is not None: kwargs["input_layouts"] = tuple( @@ -721,7 +723,7 @@ def _batch_fun(batched_args): ) unroll = vmap_method == "sequential_unrolled" g = lambda _, x: ((), _batch_fun(x)) - _, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll) + _, outvals = control_flow.scan(g, (), batched_args, unroll=unroll) else: raise NotImplementedError( f"vmap is only supported for the {prim.name} primitive when vmap_method " diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 0615a25d1512..61a058ff9189 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -91,7 +91,10 @@ pytype_strict_library( pytype_strict_library( name = "ffi", srcs = ["ffi.py"], - deps = ["//jax"], + deps = [ + "//jax", + "//jax:ffi", + ], ) pytype_strict_library( diff --git a/tests/BUILD b/tests/BUILD index a38f5e91192e..22b733c8210b 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -272,7 +272,10 @@ jax_multiplatform_test( "gpu_h100x2", ], # TODO(dfm): Remove after removal of jex.ffi imports. 
- deps = ["//jax:extend"] + py_deps([ + deps = [ + "//jax:extend", + "//jax:ffi", + ] + py_deps([ "absl/testing", "numpy", ]), From bfd0744f95c5148f6198e092da966b4764a539fd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 13:52:22 -0700 Subject: [PATCH 1595/1769] Move jax/_src/custom_partitioning.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. PiperOrigin-RevId: 769280698 --- jax/BUILD | 23 ++++++++++++++++++++++- jax/_src/custom_partitioning.py | 20 +++++++++++--------- tests/BUILD | 1 + 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 128f03334846..6943685a67e3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -299,7 +299,6 @@ py_library_providing_imports_info( "_src/buffer_callback.py", "_src/callback.py", "_src/checkify.py", - "_src/custom_partitioning.py", "_src/debugging.py", "_src/dlpack.py", "_src/error_check.py", @@ -380,6 +379,7 @@ py_library_providing_imports_info( ":custom_batching", ":custom_dce", ":custom_derivatives", + ":custom_partitioning", ":custom_partitioning_sharding_rule", ":custom_transpose", ":deprecations", @@ -715,6 +715,27 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "custom_partitioning", + srcs = ["_src/custom_partitioning.py"], + deps = [ + ":api", + ":api_util", + ":config", + ":core", + ":custom_api_util", + ":custom_partitioning_sharding_rule", + ":mesh", + ":mlir", + ":partial_eval", + ":sharding", + ":sharding_impls", + ":tree_util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "custom_partitioning_sharding_rule", srcs = ["_src/custom_partitioning_sharding_rule.py"], diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 0d0411f5177a..322aa33d6d30 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -25,16 +25,18 @@ import weakref import numpy as np -import jax -from jax import tree_util + +from jax._src import api from jax._src import api_util from jax._src import config from jax._src import core from jax._src import custom_api_util from jax._src import dispatch +from jax._src import errors from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import sharding_impls +from jax._src import tree_util from jax._src import xla_bridge as xb from jax._src.custom_partitioning_sharding_rule import sdy_sharding_rule_to_mlir, SdyShardingRule, str_to_sdy_sharding_rule from jax._src.interpreters import mlir @@ -42,7 +44,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax.errors import UnexpectedTracerError +from jax._src.sharding import Sharding def _resolve_kwargs(fun, args, kwargs): @@ -93,7 +95,7 @@ def _to_jax_shape(s): def _to_jax_sharded_shape(s, sharding): - return jax.ShapeDtypeStruct( + return api.ShapeDtypeStruct( s.dimensions(), s.numpy_dtype(), sharding=sharding ) @@ -140,7 +142,7 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape, def _to_hlo_sharding(sharding, num_dimensions): - if not isinstance(sharding, jax.sharding.Sharding): + if not isinstance(sharding, Sharding): raise ValueError("Custom Partitioning rules must return Sharding.") return sharding._to_xla_hlo_sharding(num_dimensions) @@ -178,7 +180,7 @@ def 
_custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, _to_jax_shape(sharding.tile(s)) for sharding, s in zip(result_shardings, result_shapes) ] - closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( + closed_jaxpr = api.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( *info.in_tree.unflatten(tiled_args) ) if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] != @@ -251,7 +253,7 @@ def _custom_partitioning_impl(*args, call, in_tree, out_tree, def _check_for_tracers(x): if any(isinstance(leaf, core.Tracer) for leaf in tree_util.tree_leaves(x)): - raise UnexpectedTracerError( + raise errors.UnexpectedTracerError( "Found a JAX Tracer object passed as an argument to a" "custom_partitioning function in a position indicated as static by" "static_argnums. " @@ -568,8 +570,8 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): return sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices) pspec = sharding_impls.parse_flatten_op_sharding( hlo_sharding, mesh)[0] - pspec = jax.sharding.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) - return jax.sharding.NamedSharding(mesh, pspec) + pspec = sharding_impls.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) + return sharding_impls.NamedSharding(mesh, pspec) sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition, to_mesh_pspec_sharding, in_tree, out_tree, diff --git a/tests/BUILD b/tests/BUILD index 22b733c8210b..15cf4330d28b 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1608,6 +1608,7 @@ jax_multiplatform_test( deps = [ "//jax:cache_key", "//jax:compiler", + "//jax:custom_partitioning", ] + py_deps("absl/testing"), ) From df50cd7ef05eea0b1538375818323d6eb8bc8158 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 15:35:49 -0700 Subject: [PATCH 1596/1769] Move jax/_src/buffer_callback.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. 
PiperOrigin-RevId: 769320414 --- jax/BUILD | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 6943685a67e3..be9e81706e41 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -296,7 +296,6 @@ py_library_providing_imports_info( "_src/__init__.py", "_src/ad_checkpoint.py", "_src/blocked_sampler.py", - "_src/buffer_callback.py", "_src/callback.py", "_src/checkify.py", "_src/debugging.py", @@ -369,6 +368,7 @@ py_library_providing_imports_info( ":attrs", ":basearray", ":batching", + ":buffer_callback", ":cloud_tpu_init", ":compilation_cache_internal", ":compiler", @@ -536,6 +536,23 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "buffer_callback", + srcs = ["_src/buffer_callback.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":core", + ":effects", + ":ffi", + ":mlir", + ":tree_util", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "cloud_tpu_init", srcs = ["_src/cloud_tpu_init.py"], @@ -1516,6 +1533,7 @@ py_library_providing_imports_info( ), visibility = ["//visibility:public"], deps = [ + ":buffer_callback", ":jax", ] + py_deps("absl/logging") + py_deps("numpy"), ) From 31e9998708d421b3bf2f4746f87e2b743dbcfe2b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 9 Jun 2025 23:10:01 +0000 Subject: [PATCH 1597/1769] [mutable-arrays] make custom_api_test.py pass with JAX_MUTABLE_ARRAY_CHECKS=1 --- jax/_src/custom_derivatives.py | 22 +++++++++++++--------- tests/mutable_array_test.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e2d1eb8a6097..e4fc919e1bb1 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -44,7 +44,7 @@ from jax._src.tree_util import ( tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves, tree_flatten_with_path, - tree_leaves_with_path, keystr, treedef_children, PyTreeDef) + tree_leaves_with_path, keystr, treedef_children, tree_structure, PyTreeDef) from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2, weakref_lru_cache) @@ -740,20 +740,23 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable return tree_unflatten(out_tree, out_flat) @lu.transformation2 -def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], - debug_info: core.DebugInfo, *args): - _check_for_aliased_refs(f, nondiff_argnums, debug_info, args) +def _check_primal_refs( + f: Callable, nondiff_argnums: Sequence[int], debug: core.DebugInfo, *args): + _check_for_aliased_refs(f, nondiff_argnums, debug, args) out = f(*args) _check_for_returned_refs(f, out, 'primal', [], 0) return out -def _check_for_aliased_refs(f: Callable, - nondiff_argnums: Sequence[int], - debug: core.DebugInfo, - args): +def _check_for_aliased_refs( + f: Callable, nondiff_argnums: Sequence[int], debug: core.DebugInfo, args): + nondiff_argnums_ = set(nondiff_argnums) + argnums = [x for i, arg in enumerate(args) + for x in [i] * tree_structure(arg).num_leaves] leaves = tree_leaves(args) refs: dict[int, int] = {} - for i, x in enumerate(leaves): + for i, (argnum, x) in enumerate(zip(argnums, leaves)): + if argnum in nondiff_argnums: continue + x = x.value if isinstance(x, CustomVJPPrimal) else x if (isinstance((a := core.get_aval(x)), AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): arg_names = debug.safe_arg_names(len(leaves)) 
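# Note on the hunk above (illustrative): the `argnums` list maps each
# flattened leaf back to the positional argument it came from, e.g. for
# args=(x, (y, z)) it is [0, 1, 1]; that is what lets leaves belonging to
# nondiff_argnums be skipped by the alias check.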
@@ -764,6 +767,7 @@ def _check_for_aliased_refs(f: Callable, f" {arg_names[i]}.") def _check_for_returned_refs(f, out, kind, args, after_idx): + args = [x.value if isinstance(x, CustomVJPPrimal) else x for x in args] ids = {id(x) for x in args if isinstance(core.get_aval(x), AbstractRef)} leaves = tree_leaves_with_path(out) for i, (path, leaf) in enumerate(leaves): diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index f7dc493ab1fa..0f88ec4c95b5 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized +from functools import partial import numpy as np import jax from jax._src import core @@ -456,6 +457,19 @@ def f(ref): ValueError, "custom_vjp primal function"): f(x_ref) + @parameterized.parameters([False, True]) + def test_return_from_custom_vjp_primal_nondiff_argnum(self, jit): + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def f(_, ref): + return ref + f.defvjp(lambda _, ref: ..., lambda *_: ...) + if jit: + f = jax.jit(f, static_argnums=0) + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( + ValueError, "custom_vjp primal function"): + f('hi', x_ref) + @parameterized.parameters([False, True]) def test_return_from_custom_vjp_fwd(self, jit): @jax.custom_vjp From a58b27c62e80ad2657217b03f87070e0327af1ca Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Jun 2025 17:15:03 -0700 Subject: [PATCH 1598/1769] Move jax/_src/shard_alike.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. PiperOrigin-RevId: 769356578 --- jax/BUILD | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index be9e81706e41..eec0a0c0eafb 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -306,7 +306,6 @@ py_library_providing_imports_info( "_src/prng.py", "_src/public_test_util.py", "_src/random.py", - "_src/shard_alike.py", "_src/shard_map.py", ] + glob( [ @@ -403,6 +402,7 @@ py_library_providing_imports_info( ":pickle_util", ":pretty_printer", ":profiler", + ":shard_alike", ":sharding", ":sharding_impls", ":sharding_specs", @@ -1267,6 +1267,24 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "shard_alike", + srcs = [ + "_src/shard_alike.py", + ], + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":mlir", + ":tree_util", + ":util", + "//jax/_src/lib", + ], +) + pytype_strict_library( name = "stages", srcs = ["_src/stages.py"], From 14aaa45f598b5c666d814c09ee551dd87f3ca01b Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 9 Jun 2025 18:20:01 -0700 Subject: [PATCH 1599/1769] [Pallas TPU][NFC] Use register to track buffer slots in pipeline loop PiperOrigin-RevId: 769376932 --- jax/_src/pallas/mosaic/pipeline.py | 124 +++++++++++++++++++---------- 1 file changed, 82 insertions(+), 42 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 659146a4ad7e..367458d55e4a 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -286,10 +286,18 @@ def init_slots(self): """Initialize slot indices.""" raise NotImplementedError() - def swap_slots(self): + def swap_slots(self, predicate: bool = True) -> "BufferedRefBase": """Switch to the next slot.""" raise NotImplementedError() + def load_slots(self) -> "BufferedRefBase": + 
"""Load slot information into registers.""" + raise NotImplementedError() + + def save_slots(self): + """Save slot information from registers.""" + raise NotImplementedError() + @property def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None: return self.spec.block_shape @@ -399,6 +407,7 @@ class BufferedRef(BufferedRefBase): _spec: pl.BlockSpec # static metadata dtype: Any # static metadata _buffer_type: BufferType # static metadata + _current_slot_reg: int | jax.Array | None window_ref: ArrayRef | None accum_ref: ArrayRef | None current_slot: ArrayRef | None @@ -419,6 +428,7 @@ def buffer_type(self): def tree_flatten(self): return ( ( + self._current_slot_reg, self.window_ref, self.accum_ref, self.current_slot, @@ -465,6 +475,7 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True _spec=spec, dtype=dtype, _buffer_type=buffer_type, + _current_slot_reg=None, window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, current_slot=None, @@ -478,6 +489,7 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True _spec=spec, dtype=dtype, _buffer_type=buffer_type, + _current_slot_reg=None, window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), @@ -526,6 +538,12 @@ def with_spec(self, spec: pl.BlockSpec) -> 'BufferedRef': """Returns a new BufferedRef with the given block spec.""" return dataclasses.replace(self, _spec=spec) + def with_slot_index( + self, slot_index: int | jax.Array | None + ) -> "BufferedRef": + """Returns a new BufferedRef with the given slot index.""" + return dataclasses.replace(self, _current_slot_reg=slot_index) + @property def current_ref(self): buffer_slice = tuple( @@ -542,6 +560,7 @@ def current_ref(self): @property def current_slot_index(self): """Index in double buffer corresponding to the current slot.""" + # TODO(ramiroleal): Fix race condition when returning register value for current_slot. return self.current_slot[0] @property @@ -590,12 +609,35 @@ def init_slots(self): if self.swap is not None: self.swap[0] = False - def swap_slots(self): - """Switch to the next slot.""" - if self.memory_space == VMEM: return - self.current_slot[0] = self.next_slot_index + def swap_slots(self, predicate: bool | jax.Array = True) -> "BufferedRef": + if self.memory_space == VMEM: + return self if self.swap is not None: + assert isinstance(self.swap, jax.Array) + predicate = self.swap[0] self.swap[0] = False + new_current_slot = lax.select( + predicate, self.next_slot_index, self.current_slot_index + ) + result = self.with_slot_index(new_current_slot) + # TODO(ramiroleal): Fix race condition when using register value for current_slot. 
+ result.save_slots() + return result + + def load_slots(self) -> "BufferedRef": + """Load slot information into registers.""" + if self.memory_space == VMEM: + return self + assert isinstance(self.current_slot, jax.Array) + return self.with_slot_index(self.current_slot[0]) + + def save_slots(self): + """Save slot information from registers.""" + if self.memory_space == VMEM: + return + assert isinstance(self.current_slot, jax.Array) + assert self._current_slot_reg is not None + self.current_slot[0] = self._current_slot_reg def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" @@ -850,18 +892,20 @@ def alias_local_refs(self, buffered_ref, ref): def initialize(self, buffered_ref, src_ref, schedule=None): if schedule is None: schedule = _default_schedule - pred = schedule["prologue_copy_in"](self, buffered_ref, src_ref) + do_copy = schedule["prologue_copy_in"](self, buffered_ref, src_ref) with self._named_scope("ep_initialize"): @pl.when(self.first_step_ever) def _init_slots(): buffered_ref.init_slots() - @pl.when(pred) - def _start(): - if buffered_ref.is_input: - buffered_ref.copy_in(src_ref, self.indices) - buffered_ref.swap_slots() + buffered_ref = buffered_ref.load_slots() + + @pl.when(do_copy & buffered_ref.is_input) + def _copy_in(): + buffered_ref.copy_in(src_ref, self.indices) + + return buffered_ref.swap_slots(do_copy & buffered_ref.is_input) def wait_in(self, buffered_ref, src_ref, schedule=None): if schedule is None: @@ -968,30 +1012,23 @@ def finalize(self, buffered_ref, dst_ref, schedule=None): def _end(): if buffered_ref.is_output: buffered_ref.wait_out(dst_ref, self.indices) + buffered_ref.save_slots() + + def swap_slots( + self, buffered_ref, hbm_ref, schedule=None + ) -> "BufferedRefBase": + # All the copies into and out of BufferedRefs are done by direct + # calls to the `copy_in` and `copy_out` methods in the pipeline + # loop. To determine if the BufferedRef needs a swap of slots, we + # recalculate the copy-in/copy-out conditions. + if schedule is None: + schedule = _default_schedule + pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref) + pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref) - def swap_slots(self, buffered_ref, hbm_ref, schedule=None): - if isinstance(buffered_ref, BufferedRef) and buffered_ref.swap is not None: - swap = buffered_ref.swap[0] - else: - # If we are not using an SMEM `swap` tensor to keep track of - # swaps needed, then all the copies into and out of BufferedRefs - # are done by direct calls to the `copy_in` and `copy_out` - # methods in the pipeline loop. To determine if the BufferedRef - # needs a swap of slots, we recalculate the copy-in/copy-out - # conditions. 
- if schedule is None: - schedule = _default_schedule - pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref) - pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref) - - copied_in = pred_in & buffered_ref.is_input & ~self.last_step - copied_out = pred_out & buffered_ref.is_output - swap = copied_in | copied_out - - @pl.when(swap) - @self._named_scope("ep_swap") - def _swap(): - buffered_ref.swap_slots() + copied_in = pred_in & buffered_ref.is_input & ~self.last_step + copied_out = pred_out & buffered_ref.is_output + return buffered_ref.swap_slots(copied_in | copied_out) # END SCHEDULE -------------------------------------------------------------- @@ -1373,12 +1410,13 @@ def make_scheduler(step, indices): trace_scopes=trace_scopes, ) - def loop_body(step, indices): + def loop_body(step, carry): + unaliased_brefs, indices = carry scheduler = make_scheduler(step, indices) with scheduler.grid_env(): # prepare any local VMEM aliases - brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + brefs = map_brefs(scheduler.alias_local_refs, unaliased_brefs, refs) # loop input handling phase map_brefs(scheduler.copy_in, brefs, refs, schedule) @@ -1408,25 +1446,27 @@ def loop_body(step, indices): lambda: postyeet(*brefs, scheduler), lambda: None) - map_brefs(scheduler.swap_slots, brefs, refs, schedule) - return _next_index(indices, grid) + next_brefs = map_brefs( + scheduler.swap_slots, unaliased_brefs, refs, schedule + ) + return next_brefs, _next_index(indices, grid) @pl.when(num_steps > 0) def _(): # pipeline prologue initial_indices = (0,) * len(grid) scheduler = make_scheduler(0, initial_indices) - brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) with scheduler.grid_env(): - map_brefs(scheduler.initialize, brefs, refs, schedule) + brefs = map_brefs(scheduler.initialize, allocations, refs, schedule) # pipeline loop - next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices) + brefs, next_indices = lax.fori_loop( + 0, num_steps, loop_body, (brefs, initial_indices) + ) # pipeline epilogue final_indices = _prev_index(next_indices, grid) scheduler = make_scheduler(num_steps - 1, final_indices) - brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) with scheduler.grid_env(): map_brefs(scheduler.finalize, brefs, refs, schedule) From 34cee968ac4e0a2fa9dc5c04d2fa5ae090ff3d94 Mon Sep 17 00:00:00 2001 From: Jen Ha Date: Mon, 9 Jun 2025 18:31:20 -0700 Subject: [PATCH 1600/1769] add reference to pr-checklist --- docs/contributing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contributing.md b/docs/contributing.md index 53a863fdcd8c..635f5e899161 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -30,7 +30,7 @@ We do all of our development using git, so basic knowledge is assumed. Follow these steps to contribute code: 1. Sign the [Google Contributor License Agreement (CLA)](https://cla.developers.google.com/). - For more information, see the Pull Request Checklist below. + For more information, see the {ref}`pr-checklist` below. 2. Fork the JAX repository by clicking the **Fork** button on the [repository page](http://www.github.com/jax-ml/jax). This creates From b22be8666b5344cea0a804f1b5cee4bd7b9f4520 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 9 Jun 2025 18:52:15 -0700 Subject: [PATCH 1601/1769] Remove forward_compat check for alpha as it is past the support date. 
PiperOrigin-RevId: 769385812 --- jax/_src/lax/linalg.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3ee7cc2a6807..10755ccb2ec1 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -48,7 +48,6 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2428,15 +2427,7 @@ def _triangular_solve_cpu_lower( conjugate_a = False if np.dtype(a_aval.dtype) in _cpu_lapack_types: target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype) - # TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18. - if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1): - alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)), - alpha_aval = ShapedArray((), a_aval.dtype), - batch_partitionable = False - else: - alpha = () - alpha_aval = () - batch_partitionable = True + alpha, alpha_aval, batch_partitionable = (), (), True rule = _linalg_ffi_lowering(target_name, [a_aval, b_aval, *alpha_aval], operand_output_aliases={1: 0}, From 28c31b8c39a583b1e4740432871df782dc976c05 Mon Sep 17 00:00:00 2001 From: Toli Yevtushenko Date: Mon, 9 Jun 2025 20:40:48 -0700 Subject: [PATCH 1602/1769] Update JAX test to not rely on ToString and instead check the Device Assignment values. PiperOrigin-RevId: 769417940 --- tests/xla_bridge_test.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 5c7472492fd5..e87ad52ed89f 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -17,7 +17,6 @@ from absl import logging from absl.testing import absltest - from jax import version from jax._src import compiler from jax._src import config @@ -36,18 +35,14 @@ class XlaBridgeTest(jtu.JaxTestCase): def test_set_device_assignment_no_partition(self): compile_options = compiler.get_compile_options( num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3]) - expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: " - "0 1 2 3 \n") - self.assertEqual(compile_options.device_assignment.__repr__(), - expected_device_assignment) + self.assertEqual(compile_options.device_assignment.replica_count(), 4) + self.assertEqual(compile_options.device_assignment.computation_count(), 1) def test_set_device_assignment_with_partition(self): compile_options = compiler.get_compile_options( num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]]) - expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: " - "0 2 \nComputation 1: 1 3 \n") - self.assertEqual(compile_options.device_assignment.__repr__(), - expected_device_assignment) + self.assertEqual(compile_options.device_assignment.replica_count(), 2) + self.assertEqual(compile_options.device_assignment.computation_count(), 2) def test_set_fdo_profile(self): compile_options = compiler.get_compile_options( From 65c14a2cb0f17df3e9119313fe7a3b7112ed7a30 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 9 Jun 2025 13:39:38 -0700 Subject: [PATCH 1603/1769] [pallas] Fix shard_map + Megacore in TPU interpret mode. The fix involves threading the axis sizes and device ID through the interpretation code, so that we only query the axis sizes and axis indices once at the start of a kernel on a device. 
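For context, the logical device ID is now derived once from the per-axis indices. A plausible sketch of that flattening (assuming row-major axis ordering; `device_coords_to_logical_id` stands in for the internal helper):

    import numpy as np

    def device_coords_to_logical_id(coords, axis_sizes):
        # Flatten per-axis indices into one logical device id, computed once
        # at kernel start and then threaded through the interpreter.
        return int(np.ravel_multi_index(coords, tuple(axis_sizes.values())))

    assert device_coords_to_logical_id((1, 2), {'x': 2, 'y': 4}) == 6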
Also fixes a bug in the TPU interpret mode validation code that runs at the end of each kernel invocation. --- jax/_src/pallas/mosaic/interpret.py | 83 ++++++++++++------- .../tpu_pallas_interpret_distributed_test.py | 61 ++++++++++++++ 2 files changed, 116 insertions(+), 28 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 7a6c18d43bb7..6d68a2a7e931 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -642,15 +642,18 @@ def _validate(device_id): device_id = int(device_id) shared_memory = _get_shared_memory() + local_core_ids = tuple(range(shared_memory.num_cores_per_device)) with shared_memory.lock: for sem in shared_memory.sem.values(): with sem.cv: - if sem.counts[device_id] != 0: - # TODO(jburnim): Make this raise an error, but in a way that doesn't - # cause other devices to hang later in `_clean_up_shared_memory`. - print( - f'Semaphore {sem.id} has non-zero count for {device_id} at ' - f'kernel exit: {sem.counts[device_id]}') + for lci in local_core_ids: + global_core_id = _get_global_core_id(device_id, lci) + if sem.counts[global_core_id] != 0: + # TODO(jburnim): Make this raise an error, but in a way that doesn't + # cause other devices to hang later in `_clean_up_shared_memory`. + print( + f'Semaphore {sem.id} has non-zero count for {device_id} ' + f' (core {lci}) at kernel exit: {sem.counts[global_core_id]}') def _allocate_buffer( device_id: Array, @@ -1354,7 +1357,15 @@ class Placeholder: def _interpret_jaxpr( - jaxpr, *args, mesh, local_core_id, compiler_params, interpret_params + jaxpr, + *args, + axis_sizes, + mesh, + axis_indices, + device_id, + local_core_id, + compiler_params, + interpret_params ): env = {} @@ -1374,20 +1385,15 @@ def write(var, value): jax._src.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) - # Get the device ID. - axis_sizes = jax_core.get_axis_env().axis_sizes - device_id = _device_coords_to_logical_id( - tuple(lax.axis_index(s) for s in axis_sizes.keys()), - axis_sizes) - # TODO(jburnim): Pass the device ID around, instead of re-fetching/computing - # it for each sub-jaxpr. - # TODO(jburnim): Clean up and finish this evaluation loop. For example: # - Replace the big if-statement with a dictionary of rules. # - Handle other higher-order primitives? _interpret = functools.partial( _interpret_jaxpr, + axis_sizes=axis_sizes, mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, local_core_id=local_core_id, compiler_params=compiler_params, interpret_params=interpret_params, @@ -1459,6 +1465,13 @@ def write(var, value): # querying our index along the core axis, so return our core ID. out = local_core_id + elif ((prim is lax.axis_index_p) + and (eqn.params['axis_name'] in axis_indices)): + # We replace lax.axis_index calls in the kernel body, so that the + # kernel body jaxpr can be run on other threads (via an io_callback) + # without having to recreate the axis environment in those threads. 
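+      # For example, inside shard_map over mesh axes ('x', 'y'), a kernel's
+      # lax.axis_index('y') is answered from this precomputed dict instead of
+      # by binding axis_index_p on the interpreter thread.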
+ out = axis_indices[eqn.params['axis_name']] + elif prim is lax.cond_p: def _make_branch(jaxpr): return lambda *args: _interpret(jaxpr, *args) @@ -1705,15 +1718,19 @@ def f(*args, jaxpr): return jax._src.util.safe_map(read, jaxpr.outvars) def _compute_start_indices( - block_mapping, loop_idx, local_core_id, - *args, mesh, compiler_params, interpret_params): + block_mapping, loop_idx, *args, + axis_sizes, mesh, axis_indices, device_id, local_core_id, + compiler_params, interpret_params): jaxpr = block_mapping.index_map_jaxpr block_indices = _interpret_jaxpr( jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, + axis_sizes=axis_sizes, mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, local_core_id=local_core_id, compiler_params=compiler_params, interpret_params=interpret_params, @@ -1941,16 +1958,17 @@ def _run(jaxpr, consts, *args): traced.lower().compile()(consts, *args) return +import concurrent.futures + def _thread_map_callback(jaxpr, num_threads, consts): num_threads = int(num_threads) threads = [] - for i in range(num_threads): - threads.append( - threading.Thread(target=_run_jaxpr, args=(jaxpr, consts, jnp.int32(i)))) - for i in range(num_threads): - threads[i].start() - for i in range(num_threads): - threads[i].join() + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for i in range(num_threads): + threads.append( + executor.submit(_run_jaxpr, jaxpr, consts, jnp.int32(i))) + for i in range(num_threads): + threads[i].result() def _call_threadmap_callback(jaxpr, num_threads, *consts): # NOTE: At runtime, _thread_map_callback will lower and compile the @@ -2006,9 +2024,9 @@ def interpret_pallas_call( axis_sizes = jax_core.get_axis_env().axis_sizes num_devices = functools.reduce( jnp.multiply, axis_sizes.values(), jnp.int32(1)) + axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()} device_id = _device_coords_to_logical_id( - tuple(lax.axis_index(s) for s in axis_sizes.keys()), - axis_sizes) + tuple(axis_indices.values()), axis_sizes) callback.io_callback( functools.partial( _initialize_shared_memory, interpret_params=interpret_params), @@ -2271,9 +2289,12 @@ def _body( _compute_start_indices( bm, next_grid_point, - core_index, *scalar_buffer_ids, + axis_sizes=axis_sizes, mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, compiler_params=compiler_params, interpret_params=interpret_params, ) @@ -2341,7 +2362,10 @@ def _store_slice_to_kernel_input(index, input_var): _interpret_jaxpr( jaxpr, *kernel_buffer_ids, + axis_sizes=axis_sizes, mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, local_core_id=core_index, compiler_params=compiler_params, interpret_params=interpret_params, @@ -2419,9 +2443,12 @@ def _store_to_output_buffer(index, output_var): _compute_start_indices( bm, initial_grid_point, - core_index, *scalar_buffer_ids, + axis_sizes=axis_sizes, mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, compiler_params=compiler_params, interpret_params=interpret_params, ) diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index ddfe8bcde4f4..62772bd7e298 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -1082,6 +1082,67 @@ def run(src_dst_ids): run(jnp.array([[0, 1], [1, 2], [3, 2], [3, 0]], jnp.int32)).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) + 
@parameterized.parameters(1, 2, 4) + def test_shard_map_of_core_map(self, num_cores): + num_devices = jax.device_count() + partition = P('x', None) + mesh = jax.make_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + core_mesh = pltpu.create_tensorcore_mesh('core', num_cores=num_cores) + interpret = pltpu.InterpretParams(detect_races=True) + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + @pl.core_map(core_mesh, interpret=interpret) + def _(): + num_cores = jax.lax.axis_size('core') + slc_size = 16 // num_cores + def alloc(x_vmem_ref, y_vmem_ref, dma_sem, sem): + # Barrier so we deadlock unless the core_map is actually parallel. + for i in range(num_cores): + pl.semaphore_signal(sem, 1, core_index=i) + pl.semaphore_wait(sem, num_cores) + + core_index = jax.lax.axis_index('core') + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + dma_sem, + ).wait() + y = (x_vmem_ref[...] + num_cores * jax.lax.axis_index('x') + + core_index + 1) + y_vmem_ref[...] = y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], dma_sem).wait() + pl.run_scoped( + alloc, + pltpu.VMEM((slc_size, 128), x_ref.dtype), + pltpu.VMEM((slc_size, 128), y_ref.dtype), + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, + ) + _, y = pl.run_state(inner)((x, y)) + return y + + x = jnp.arange(num_devices * 16 * 128, dtype=jnp.int32).reshape((-1, 128)) + y = jax.jit( + shard_map.shard_map(f, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_vma=False, + ) + )(x).block_until_ready() + expected_out = ( + x.reshape((num_devices, num_cores, -1, 128)) + 1 + + jnp.arange(num_devices, dtype=jnp.int32)[..., None, None, None] * num_cores + + jnp.arange(num_cores, dtype=jnp.int32)[None, ..., None, None] + ).reshape(x.shape) + np.testing.assert_array_equal(y, expected_out) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From c126ee3d8a56077107fd009a183508cd48d9a75d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Jun 2025 22:32:29 -0700 Subject: [PATCH 1604/1769] Don't revisit shared subjaxprs in jaxpr_util.pprof_equation_profile. It is probably a more useful default behavior not to implicitly inline everything. 
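The mechanism is an ordinary visited-set walk over the jaxpr graph; a minimal sketch of the same idea on a generic node type (`Node` here is hypothetical):

    from dataclasses import dataclass, field

    @dataclass(eq=False)
    class Node:
        children: list = field(default_factory=list)

    def iter_nodes(node, visited=None):
        # Yield each shared child once instead of once per path to it.
        visited = set() if visited is None else visited
        yield node
        for child in node.children:
            if id(child) not in visited:
                visited.add(id(child))
                yield from iter_nodes(child, visited)

    shared = Node()
    root = Node(children=[shared, shared])
    assert len(list(iter_nodes(root))) == 2  # the shared node is visited once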
PiperOrigin-RevId: 769452443 --- jax/_src/jaxpr_util.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index cb9eef0b9ea2..b26c51483773 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -33,11 +33,23 @@ zip, unsafe_zip = util.safe_zip, zip -def all_eqns(jaxpr: core.Jaxpr) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: +def _all_eqns( + jaxpr: core.Jaxpr, visited: set[core.Jaxpr] | None, +) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: for eqn in jaxpr.eqns: yield (jaxpr, eqn) for subjaxpr in core.subjaxprs(jaxpr): - yield from all_eqns(subjaxpr) + if visited is None: + yield from _all_eqns(subjaxpr, visited) + elif subjaxpr not in visited: + visited.add(subjaxpr) + yield from _all_eqns(subjaxpr, visited) + +def all_eqns( + jaxpr: core.Jaxpr, revisit_inner_jaxprs: bool = True +) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: + yield from _all_eqns(jaxpr, None if revisit_inner_jaxprs else set()) + def collect_eqns(jaxpr: core.Jaxpr, key: Callable): d = defaultdict(list) @@ -206,7 +218,7 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes: """ d = Counter( (eqn.source_info.traceback, eqn.primitive) - for _, eqn in all_eqns(jaxpr) + for _, eqn in all_eqns(jaxpr, revisit_inner_jaxprs=False) ) return _pprof_profile(d) From cc971e354fcec6a62a557aab95176486a621c31f Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 9 Jun 2025 23:20:11 -0700 Subject: [PATCH 1605/1769] Automated Code Change PiperOrigin-RevId: 769466833 --- jaxlib/py_compile_only_client.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index 49f68ec1e24f..f23f09c265a1 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinOps.h" From b59a97f0f04e41f3daf7101885ed317120c06236 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 9 Jun 2025 23:29:10 -0700 Subject: [PATCH 1606/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/28c9c6fbcc5a63f9d05fd4d65ee1e509bb313e85. PiperOrigin-RevId: 769469509 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ebc920cbd625..a71ca8c6ff32 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
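A Python equivalent of that recipe, shown for the commit pinned below (illustrative only):

    import hashlib, urllib.request

    COMMIT = "28c9c6fbcc5a63f9d05fd4d65ee1e509bb313e85"
    url = f"https://github.com/openxla/xla/archive/{COMMIT}.tar.gz"
    print(hashlib.sha256(urllib.request.urlopen(url).read()).hexdigest())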
-XLA_COMMIT = "d556913b23808bb93b13b576eb4b74e901fd52a5" -XLA_SHA256 = "8193f25819af97211a0cac09606fe46d28758105f16eb252b7c04adaac7306df" +XLA_COMMIT = "28c9c6fbcc5a63f9d05fd4d65ee1e509bb313e85" +XLA_SHA256 = "b7638d60823c51700ab28451e32e440e8d9856f81b7c6838cc72f8a575fdba83" def repo(): tf_http_archive( From 31ef2cf243776f47cc5586be5d3c10badf08f2c2 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 10 Jun 2025 09:35:26 +0300 Subject: [PATCH 1607/1769] Improve batching for lax.platform_dependent Fixed: #29329 --- jax/_src/interpreters/ad.py | 2 +- jax/_src/lax/control_flow/conditionals.py | 10 ++++++---- tests/lax_control_flow_test.py | 20 ++++++++++++++++++++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 51f007f4e576..f3694cf4cbe1 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -277,7 +277,7 @@ def linearize(traceable: lu.WrappedFun, raise ValueError( "Linearization failed to produce known values for all output primals. " "This is typically caused by attempting to differentiate a function " - "uses an operation that does not support reverse-mode autodiff.") + "using an operation that does not support reverse-mode autodiff.") out_primals_consts = [pval.get_known() for pval in out_primals_pvals] if not has_aux: return out_primals_consts, out_tangents_pvals, jaxpr, consts diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 4360c4a6df3b..c270b54f8713 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -490,11 +490,13 @@ def _cond_batching_rule(axis_data, args, dims, *, branches, **params): raise NotImplementedError( "IO effect not supported in vmap-of-cond.") + if "branches_platforms" in params and (index_dim is not batching.not_mapped): + # If we end up with a mapped index for a platform_dependent cond, we can + # replace the index with a fresh call to platform_index. See #29329. + index = platform_index_p.bind(platforms=params["branches_platforms"]) + index_dim = batching.not_mapped if index_dim is not batching.not_mapped: - assert "branches_platforms" not in params, ( - "The index of a cond with branches_platforms should be a " - "platform_index and should never be mapped") # Convert to a lax.select. 
While we could get away with not broadcasting # some operands yet, because all outputs must be broadcast together anyway # for the select we broadcast the input operands for simplicity and leave @@ -563,7 +565,7 @@ def _cond_jvp(primals, tangents, *, branches, **params): return out_primals, out_tangents def _cond_partial_eval(trace, *tracers, branches, **params): - in_unknowns = [t.pval[0] is not None for t in tracers] + in_unknowns = [not t.pval.is_known() for t in tracers] index_uk, *ops_uk = in_unknowns if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3f950f865735..954890a973cc 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3072,6 +3072,26 @@ def f(x): self.assertEqual(expect_a_dot, " dot(" in hlo) self.assertEqual(not expect_a_dot, " while(" in hlo) + def test_issue_29329(self): + + def outer_fn(x): + def inner_fn(x): + return jax.jit( + lambda x: lax.platform_dependent(x, + default=jnp.sin, + other=jnp.cos))(x) + + _, lin_fn = jax.linearize(inner_fn, x) + + def with_transpose(x): + grad = jax.linear_transpose(lin_fn, x)(x) + del grad + return x + + return jax.lax.cond(x[0][0] > 0., with_transpose, lambda x: x, x) + + jax.vmap(outer_fn)(jnp.ones((5, 10, 10))) + def test_scan_lowering_doesnt_introduce_singleton(self): b = 4 i = 2 From be23dcf228a68e116e3e568a4a970e08caaf45e2 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 10 Jun 2025 04:21:55 -0700 Subject: [PATCH 1608/1769] Move jax/_src/public_test_util.py to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. PiperOrigin-RevId: 769565164 --- jax/BUILD | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index eec0a0c0eafb..3ec1567e2dc9 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -210,6 +210,7 @@ py_library( deps = [ ":compilation_cache_internal", ":jax", + ":public_test_util", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -304,7 +305,6 @@ py_library_providing_imports_info( "_src/flatten_util.py", "_src/interpreters/__init__.py", "_src/prng.py", - "_src/public_test_util.py", "_src/random.py", "_src/shard_map.py", ] + glob( @@ -402,6 +402,7 @@ py_library_providing_imports_info( ":pickle_util", ":pretty_printer", ":profiler", + ":public_test_util", ":shard_alike", ":sharding", ":sharding_impls", @@ -1256,6 +1257,19 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "public_test_util", + srcs = [ + "_src/public_test_util.py", + ], + deps = [ + ":api", + ":config", + ":dtypes", + ":tree_util", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "sharding", srcs = ["_src/sharding.py"], From f053dfed070045f6e26892984a2b33749fdc8503 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 10 Jun 2025 04:52:13 -0700 Subject: [PATCH 1609/1769] [Pallas/Mosaic GPU] Fix the abstract eval rule for `load_p` in the presence of `RefUnion`s. 
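Background: transforms such as bitcasts can change a ref's element type, so the dtype of a load's result has to be computed by folding each transform's dtype hook over the source dtype. A minimal sketch of that fold, mirroring the `_transform_dtype` helper added below (assuming each transform exposes a `transform_dtype` method):

    from functools import reduce

    def transform_dtype(dtype, transforms):
        # A bitcast transform may map float32 -> int32; indexing transforms
        # leave the dtype unchanged.
        return reduce(lambda d, t: t.transform_dtype(d), transforms, dtype)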
PiperOrigin-RevId: 769574205 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/lowering.py | 15 ++++++++++++-- jax/_src/pallas/mosaic_gpu/primitives.py | 3 ++- tests/pallas/mosaic_gpu_test.py | 25 ++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 78a1bd4f0011..6c5320e9e84a 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -62,6 +62,7 @@ pytype_strict_library( "//jax", "//jax:api", "//jax:core", + "//jax:dtypes", "//jax:mesh", "//jax:mlir", "//jax:mosaic_gpu", diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 24f219df49a5..d17d8aff6801 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -31,6 +31,7 @@ from jax import lax from jax._src import checkify from jax._src import core as jax_core +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import pjit @@ -1299,15 +1300,25 @@ def _extract_aliased_ref( return ref, transforms +def _transform_dtype( + dtype: dtypes.DType, + transforms: Sequence[state_types.Transform], +) -> dtypes.DType: + """Applies `t.transform_dtype` for `t` in `transforms` sequentially on `dtype`.""" + for transform in transforms: + dtype = transform.transform_dtype(dtype) + return dtype + + def _handle_transforms( ctx: LoweringRuleContext, ref: RefOrTmemType, - transforms: Sequence[gpu_core.Transform], + transforms: Sequence[state_types.Transform], *, handle_transposes=True, handle_reshapes=True, allow_peer_refs=False, -) -> tuple[RefOrTmemType, Sequence[gpu_core.Transform]]: +) -> tuple[RefOrTmemType, Sequence[state_types.Transform]]: if isinstance(ref, tcgen05.TMEMRef): mlir_dtype = ref.dtype else: diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 862f3affda14..ffc9d0623fff 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -75,8 +75,9 @@ def _check_ref( def _load_abstract_eval(src, *avals_flat, args_tree, layout, optimized): del layout, optimized # Unused. transforms = args_tree.unflatten(avals_flat) + dtype = lowering._transform_dtype(src.dtype, transforms) return ( - jax_core.ShapedArray(transforms[-1].get_indexer_shape(), src.dtype), + jax_core.ShapedArray(transforms[-1].get_indexer_shape(), dtype), {state.ReadEffect(0)}, ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6817a80c1005..1857068787ba 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1739,6 +1739,31 @@ def kernel(x_ref, o_ref128, aliased_ref): with self.assertRaisesRegex(ValueError, "can't be assigned to"): kernel(jnp.arange(128).astype(jnp.float32)) + def test_loading_from_ref_union_works(self): + # `load_p` does not have a defined lowering for warpgroup semantics. + self.skip_if_wg_semantics() + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec((128,))] * 2, + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.RefUnion(plgpu.SMEM((128,), jnp.float32)), + plgpu.SMEM((128,), jnp.float32)], + ) + def kernel(x_ref, y_ref, o_ref128, ref_union, o_smem): + [aliased_ref] = ref_union + aliased_ref[...] = x_ref[...] 
+ plgpu.commit_smem() + load_ref = lambda r: plgpu.load(r, (), layout=plgpu.Layout.TCGEN05_ROW) + # This is a regression test for b/423697560, where we used to fail to + # transform the dtype correctly when processing an aliased ref. + o_smem[...] = load_ref(aliased_ref) + load_ref(y_ref) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref128) + + x, y = [jnp.arange(128).astype(jnp.float32) for _ in range(2)] + np.testing.assert_array_equal(kernel(x, y), x + y) + @parameterized.parameters(1, 2, 3) def test_nd_loop(self, sm_steps): @functools.partial( From 4846ed2f77eeafac1cd2ab6599df770f4e5066e8 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 10 Jun 2025 07:58:58 -0700 Subject: [PATCH 1610/1769] Pass source_info to custom_staging_rules and into jaxpr inlining. This avoids the need to collect two redundant Python tracebacks when staging an inner jit, which is a very common thing to do. PiperOrigin-RevId: 769632609 --- jax/_src/interpreters/partial_eval.py | 8 ++++---- jax/_src/lax/control_flow/loops.py | 4 +++- jax/_src/lax/lax.py | 21 ++++++++++++--------- jax/_src/lax/slicing.py | 7 ++++--- jax/_src/pjit.py | 7 ++----- tests/hijax_test.py | 3 ++- 6 files changed, 27 insertions(+), 23 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 1503889b04bc..7a1fba94bb3d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2002,7 +2002,8 @@ def process_primitive(self, primitive, tracers, params): to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) jaxpr_tracers = map(to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: - return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) + return custom_staging_rules[primitive](self, source_info, *jaxpr_tracers, + **params) return self.default_process_primitive( primitive, jaxpr_tracers, params, source_info) @@ -2705,10 +2706,9 @@ def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): return tracer def inline_jaxpr_into_trace( - trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], - *arg_tracers: DynamicJaxprTracer) -> list[Any]: + trace: DynamicJaxprTrace, src: SourceInfo, jaxpr: Jaxpr, + consts: Sequence[Any], *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - src = source_info_util.current() const_tracers = map(partial(trace.new_const, source_info=src), consts) constvars = map(trace.getvar, const_tracers) argvars = map(trace.getvar, arg_tracers) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 6be65485bc7a..8857ffad7a5f 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -622,7 +622,9 @@ def _empty_array(prefix, length_spec, aval): eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True -def _stage_jaxpr(trace: pe.JaxprTrace, *tracers, jaxpr: core.ClosedJaxpr): +def _stage_jaxpr(trace: pe.JaxprTrace, source_info, *tracers, + jaxpr: core.ClosedJaxpr): + del source_info params = dict(call_jaxpr=jaxpr) return trace.default_process_primitive(core.closed_call_p, tracers, params) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 59363113d7b3..b0f17129bc77 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -266,8 +266,8 @@ def _merge_dyn_shape( assert next(dyn_shape_it, None) is None return shape -def 
_dyn_shape_staging_rule(trace, prim, out_aval, *args, **params): - source_info = source_info_util.current() +def _dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, + **params): out_tracer = pe.DynamicJaxprTracer(trace, out_aval, source_info) eqn = pe.new_jaxpr_eqn([trace.getvar(x) for x in args], [trace.makevar(out_tracer)], @@ -6520,14 +6520,14 @@ def _broadcast_in_dim_fwd_rule(eqn): return [None], eqn def _broadcast_in_dim_staging_rule( - trace, x, *dyn, shape, broadcast_dimensions, sharding): + trace, source_info, x, *dyn, shape, broadcast_dimensions, sharding): params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) if not dyn: return trace.default_process_primitive(broadcast_in_dim_p, (x,), params) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, broadcast_in_dim_p, aval, x, *dyn, - **params) + return _dyn_shape_staging_rule(trace, source_info, broadcast_in_dim_p, aval, + x, *dyn, **params) def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape, shape, broadcast_dimensions): @@ -7241,12 +7241,13 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] def _reshape_staging_rule( - trace, x, *dyn, new_sizes, dimensions, sharding): + trace, source_info, x, *dyn, new_sizes, dimensions, sharding): params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding) if not dyn: return trace.default_process_primitive(reshape_p, (x,), params) av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) + return _dyn_shape_staging_rule(trace, source_info, reshape_p, av, x, *dyn, + **params) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, 'reshape', sharding_rule=_reshape_sharding_rule, @@ -8594,13 +8595,15 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): iota_p.def_abstract_eval(_iota_abstract_eval) batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule -def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding): +def _iota_staging_rule(trace, source_info, *dyn_shape, dtype, shape, dimension, + sharding): params = dict(dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) if not dyn_shape: return trace.default_process_primitive(iota_p, (), params) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) - return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params) + return _dyn_shape_staging_rule(trace, source_info, iota_p, aval, *dyn_shape, + **params) pe.custom_staging_rules[iota_p] = _iota_staging_rule def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension, sharding): diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index ad8a2cf0b315..cacc1d14e42b 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1542,15 +1542,16 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True, mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None) -def _dynamic_slice_staging_rule(trace, x, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_staging_rule(trace, source_info, x, *starts_and_dyn_sizes, + slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim]) if not dyn: return 
trace.default_process_primitive(dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes)) shape = lax._merge_dyn_shape(slice_sizes, dyn) aval = core.DShapedArray(shape, x.dtype, False) - return lax._dyn_shape_staging_rule(trace, dynamic_slice_p, aval, x, - *starts_and_dyn_sizes, + return lax._dyn_shape_staging_rule(trace, source_info, dynamic_slice_p, aval, + x, *starts_and_dyn_sizes, slice_sizes=slice_sizes) def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 32e1f6299931..4b085bf124cb 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1958,7 +1958,7 @@ def _pjit_lower_cached( pgle_profiler=pgle_profiler) -def pjit_staging_rule(trace, *args, **params): +def pjit_staging_rule(trace, source_info, *args, **params): # If we're inlining, no need to compute forwarding information; the inlined # computation will in effect forward things. if (params["inline"] and @@ -1976,13 +1976,10 @@ def pjit_staging_rule(trace, *args, **params): propagate_source_info=False) else: out = pe.inline_jaxpr_into_trace( - trace, jaxpr.jaxpr, jaxpr.consts, *args) - source_info = source_info_util.current() + trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args) return [trace.to_jaxpr_tracer(x, source_info) for x in out] jaxpr = params['jaxpr'] - source_info = source_info_util.current() - consts = [] if config.dynamic_shapes.value: jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( jaxpr, params['out_shardings'], params['out_layouts']) diff --git a/tests/hijax_test.py b/tests/hijax_test.py index b272e0aa8986..d033404879ff 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -49,7 +49,8 @@ def __init__(self, name): ad.primitive_transposes[self] = self.transpose pe.custom_staging_rules[self] = self.staging - def staging(self, trace, *args, **kwargs): + def staging(self, trace, source_info, *args, **kwargs): + del source_info trace.frame.is_high = True return trace.default_process_primitive(self, args, kwargs) From 5d64c3952283fccd193a2e26777167e38091ec50 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 7 Jun 2025 22:44:21 +0000 Subject: [PATCH 1611/1769] [mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd In particular, make this work: `JAX_MUTABLE_ARRAY_CHECKS=1 python tests/mutable_array_test.py test_custom_vjp_grad_stats_plumbing_basic1` As in #29311, the issue was that `partial_eval_jaxpr` would build a forward pass jaxpr that returned all the nonlinear inputs needed by the backward pass, including (borrowed) mutable arrays. We don't want to build such jaxprs in the first place, even though our forwarding optimization might clean them up later. We instead want to infer as early as possible the common arguments between the forward and backward passes. We now achieve that for `scan` by using `partial_eval_jaxpr_fwd` in `_scan_partial_eval`. I took the opportunity to clean up the scan partial eval logic. It's now slightly more comprehensible. 
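For intuition, the loop-invariance concern is the one described in the comment this change deletes: when scanning multiplication by a constant matrix, the matrix should be saved once for the backward pass rather than broadcast along the scan length. A small example of that pattern (shapes arbitrary):

    import jax
    import jax.numpy as jnp

    def loss(A):
        # A is loop-invariant: good partial evaluation saves A once as a
        # residual instead of stacking `length` copies for the backward pass.
        y, _ = jax.lax.scan(lambda c, _: (c @ A, ()), jnp.ones(3), None,
                            length=10)
        return y.sum()

    print(jax.grad(loss)(jnp.eye(3)))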
--- jax/_src/interpreters/partial_eval.py | 14 +- jax/_src/lax/control_flow/loops.py | 182 ++++++++++++-------------- tests/mutable_array_test.py | 50 +++++++ 3 files changed, 143 insertions(+), 103 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 1503889b04bc..75a31c4bb4db 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1002,14 +1002,16 @@ def partial_eval_jaxpr_nounits( def partial_eval_jaxpr_nounits_fwd( jaxpr: ClosedJaxpr, unknowns: Sequence[bool], instantiate: bool | Sequence[bool], + fwd: bool | Sequence[bool] = True, ) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]: instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, True) + fwd = tuple(fwd) if isinstance(fwd, list) else fwd + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, fwd) @weakref_lru_cache def _partial_eval_jaxpr_nounits( jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool], - instantiate: bool | Sequence[bool], fwd: bool): + instantiate: bool | Sequence[bool], fwd: bool | Sequence[bool]): f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) cell = [] @@ -1023,13 +1025,19 @@ def fun(*known_vals_in): f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] - if not fwd: + if type(fwd) is bool and not fwd: residuals_ = iter(residuals) residuals = [next(residuals_) if f is None else known_vals_in[f] for f in fwds] assert next(residuals_, None) is None fwds = [None] * len(fwds) else: + if type(fwd) is tuple: + fwd_ = [f for f, uk in zip(fwd, in_unknowns) if not uk] + residuals_, residuals = iter(residuals), [] + fwds = [residuals.append(next(residuals_)) if f is None else + residuals.append(known_vals_in[f]) if not fwd_[f] else + f for f in fwds] fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals) res_avals = [core.get_aval(r) for r in residuals] cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds)) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 6be65485bc7a..3f97ac322284 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -65,7 +65,6 @@ from jax._src.util import ( merge_lists, partition_list, safe_map, safe_zip, split_list, split_list_checked, unzip2, weakref_lru_cache,) -from jax._src import xla_bridge as xb from jax.tree_util import ( keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten, treedef_is_leaf) @@ -807,10 +806,34 @@ def _const_to_intensive_res_forwarding( tangent_jaxpr, [False] * num_nz + [i is not None for i in const_to_res]) return primal_jaxpr, tangent_jaxpr, intensive_res, new_in_fwd +def _scan_known_hoisting(jaxpr_known, const_tracers, num_res): + # To disable: + # return jaxpr_known, [], [False] * num_res, [] + consts = [pe.PartialVal.known(t.pval.get_known()) + if not isinstance(t.aval, state.AbstractRef) + else pe.PartialVal.unknown(t.aval) + for t in const_tracers if t.pval.is_known()] + others = _map(pe.PartialVal.unknown, jaxpr_known.in_avals[len(consts):]) + num_known_outs = len(jaxpr_known.out_avals) - num_res + dbg = jaxpr_known.jaxpr.debug_info + with source_info_util.reset_name_stack(): + jaxpr_known_, invar_pvals_out, known_consts = 
pe.trace_to_jaxpr_nounits( + lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), debug_info=dbg), + consts + others, instantiate=[True] * num_known_outs + [False] * num_res) + jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) + res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] + which_hoisted = [pval.is_known() for pval in res_pvals] + hoisted_res = [pval.get_known() for pval in res_pvals if pval.is_known()] + mut_consts = [t.pval.get_known() for t in const_tracers + if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)] + return jaxpr_known, [*known_consts, *mut_consts], which_hoisted, hoisted_res + + def _scan_partial_eval(trace, *tracers, reverse: bool, length: int, num_consts: int, num_carry: int, jaxpr: core.ClosedJaxpr, linear: Sequence[bool], unroll: int, _split_transpose: bool): + num_xs = len(jaxpr.in_avals) - num_consts - num_carry num_ys = len(jaxpr.out_avals) - num_carry unknowns = [not t.pval.is_known() for t in tracers] const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) @@ -822,8 +845,10 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, carry_uk = init_uk for _ in range(1 + len(carry_uk)): unknowns = const_uk + carry_uk + xs_uk - jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits( - jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) + jaxpr_known, jaxpr_unknown, out_uk, res_avals, in_fwd_res = \ + pe.partial_eval_jaxpr_nounits_fwd( + jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys, + fwd=[True] * num_consts + [False] * num_carry + [True] * num_xs) carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) if carry_uk_out == carry_uk: break @@ -831,108 +856,65 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" - num_res = len(res_avals) + num_res_out, num_res_in = len(res_avals), len(in_fwd_res) + num_knowns_out = len(jaxpr_known.out_avals) - num_res_out + num_consts_known = num_consts - sum(const_uk) + num_carry_known = num_carry - sum(carry_uk) del res_avals, carry_uk_out # Instantiate those inputs which must be treated as unknown from the fixpoint. - tracers = tuple(trace.instantiate_const(t) if uk else t - for t, uk in zip(tracers, unknowns)) - - # The residual inputs and outputs of the jaxprs produced haven't yet been - # adapted to the scan calling convention; in particular, jaxpr_known has its - # residual outputs all at the end, meaning they're extensive outputs (which is - # fully general but may be wasteful for residuals which are loop-invariant) - # while jaxpr_unknown has its corresponding residual inputs at the front (just - # as a convention with partial_eval_jaxpr_nounits), making them constant - # inputs. To make them consistent, we move the residual inputs on - # jaxpr_unknown to the end, even though we may move some back in the sequel. - jaxpr_unknown = pe.move_binders_to_back( - jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) - - # At this point, all residuals are treated as extensive outputs of jaxpr_known - # (and extensive inputs to jaxpr_unknown). But residuals that are loop- - # invariant can be hoisted out of the scan, rather than letting them get - # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't - # want to broadcast the matrix!). So, outside the loop we perform a partial - # evaluation with known 'const' inputs (but all other inputs unknown). 
- const_pvals = [pe.PartialVal.known(t.pval.get_known()) - if not isinstance(t.aval, state.AbstractRef) - else pe.PartialVal.unknown(t.aval) - for t in tracers[:num_consts] if t.pval.is_known()] - other_pvals = [pe.PartialVal.unknown(aval) - for aval in jaxpr_known.in_avals[len(const_pvals):]] - with source_info_util.reset_name_stack(): - jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( - lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), - debug_info=jaxpr_known.jaxpr.debug_info), - const_pvals + other_pvals, - instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) - jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) - # The above trace_to_jaxpr_nounits call computed loop-invariant residuals - # (known values in invar_pvals_out) and also computed loop-invariant values - # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the - # previous consts). We need to collect the computed intensive residuals, and - # move corresponding intensive residual binders in jaxpr_unknown to the front. - res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] - intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] - jaxpr_unknown = pe.move_binders_to_front( - jaxpr_unknown, - [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) - del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals - # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and - # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. - - # As another optimization, for any extensive inputs that are just forwarded to - # extensive outputs, to avoid a copy (which would be looping over - # dynamic-update-slice) we'd rather forward the input tracer/value. That means - # pruning some outputs from jaxpr_known here, and updating `out_flat` below. - fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) - # Prune fwds_known to include only extensive input to extensive output. - fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and - in_idx is not None and - in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) - else None for out_idx, in_idx in enumerate(fwds_known)] - # Drop any extensive output we can instead get by forwarding an input. - # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint - jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts - jaxpr_known_ = jaxpr_known_.replace( - outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None]) - jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) - del jaxpr_known_ - # We use `fwds_known` below when forming the output of scanning jaxpr_known. + tracers = [trace.instantiate_const(t) if uk else t + for t, uk in zip(tracers, unknowns)] + # Keep original known inputs, since in_fwd_res indexes into them. + orig_inputs = [*jaxpr_known.consts, + *[t.pval.get_known() for t in tracers if t.pval.is_known()]] + + # At this point all non-forwarded residuals are treated as extensive outputs + # of jaxpr_known. Hoist out those that only depend on consts. 
+ # Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res] + # After: jaxpr_known: [*known_consts, *known_ins] -> [*known_outs, *ext_res] + # where, modulo hoisted res not being broadcast, we have + # non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) + jaxpr_known, known_consts, which_hoisted, hoisted_res = \ + _scan_known_hoisting(jaxpr_known, tracers[:num_consts], num_res_out) + + # To make jaxpr_unknown match the scan calling convention, move to the back + # binders that don't correspond to hoisted or carry-forwarded residuals. + # Before: jaxpr_unknown: [*res, *unknown_ins] -> [*unkown_outs] + # After: jaxpr_unkonwn: [*int_res, *unknown_ins, *ext_res] -> [*unknown_outs] + num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in + which_hoisted_ = iter(which_hoisted) + res_to_move = [not next(which_hoisted_) if f is None else + f >= num_consts_known + num_carry_known for f in in_fwd_res] + jaxpr_unknown = pe.move_binders_to_back(jaxpr_unknown, + res_to_move + [False] * num_unk_in) # Run the known part of the scan (if it has any outputs or effects). - known_mutable_consts = [t.pval.get_known() for t in tracers[:num_consts] - if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)] - known_inputs = (list(jaxpr_known_consts) + known_mutable_consts + - [t.pval.get_known() for t in tracers[num_consts:] - if t.pval.is_known()]) + known_ins = [t.pval.get_known() for t in tracers[num_consts:] if t.pval.is_known()] if not jaxpr_known.out_avals and not jaxpr_known.effects: - out_known = [] + known_outs_ext_res = [] else: - linear_known = [False] * len(known_inputs) # conservative! - out_known = scan_p.bind( - *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, - num_consts=len(jaxpr_known_consts) + len(known_mutable_consts), - num_carry=num_carry - sum(carry_uk), - linear=tuple(linear_known), unroll=unroll, + linear_known = ([False] * len(known_consts) + + [l for l, uk in zip(linear, unknowns)[num_consts:] if not uk]) + known_outs_ext_res = scan_p.bind( + *known_consts, *known_ins, jaxpr=jaxpr_known, reverse=reverse, + length=length, num_consts=len(known_consts), + num_carry=num_carry_known, linear=tuple(linear_known), unroll=unroll, _split_transpose=_split_transpose) - del linear_known - # Complete the known output by filling in forwarded values using fwds_known. - out_known_iter = iter(out_known) - out_known = [next(out_known_iter) if f is None - else _maybe_put(known_inputs[f]) for f in fwds_known] - assert next(out_known_iter, None) is None - del known_inputs, out_known_iter - - # Split known outputs from residuals. - out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) - assert len(intensive_res) + len(extensive_res) == num_res + known_outs, ext_res = split_list(known_outs_ext_res, [num_knowns_out]) + + # Complete non_fwd_res and then res, then split to match binders. + non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) + non_fwd_res_ = iter(non_fwd_res) + res = [next(non_fwd_res_) if f is None else _maybe_put(orig_inputs[f]) + for f in in_fwd_res] + assert next(non_fwd_res_, None) is None + int_res, ext_res = partition_list(res_to_move, res) # Create input tracers for jaxpr_unknown bind. 
unknown_inputs = [t for t in tracers if not t.pval.is_known()] - intensive_res = _map(trace.new_instantiated_const, intensive_res) - extensive_res = _map(trace.new_instantiated_const, extensive_res) + int_res = _map(trace.new_instantiated_const, int_res) + ext_res = _map(trace.new_instantiated_const, ext_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) ys_avals = [core.unmapped_aval(length, 0, y_aval) @@ -941,29 +923,29 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, for a in it.chain(carry_avals, ys_avals)] del carry_avals, y_avals # Create equation. - linear_unknown = tuple([False] * len(intensive_res) + + linear_unknown = tuple([False] * len(int_res) + [l for l, uk in zip(linear, unknowns) if uk] + - [False] * len(extensive_res)) + [False] * len(ext_res)) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) assert len(out_tracers) == len(jaxpr_unknown.out_avals) - eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res], + eqn = pe.new_eqn_recipe(trace, [*int_res, *unknown_inputs, *ext_res], out_tracers, scan_p, dict(reverse=reverse, length=length, unroll=unroll, jaxpr=jaxpr_unknown, linear=linear_unknown, - num_consts=len(intensive_res) + sum(const_uk), + num_consts=len(int_res) + sum(const_uk), num_carry=sum(carry_uk), _split_transpose=_split_transpose), jaxpr_unknown.effects, source) for t in out_tracers: t.recipe = eqn # Merge known and unknown outputs into final result. - return util.merge_lists(out_uk, out_known, out_tracers) + return util.merge_lists(out_uk, known_outs, out_tracers) def _maybe_put(x): if isinstance(x, np.ndarray): aval = core.shaped_abstractify(x) - s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0]) + s = sharding.SingleDeviceSharding(pxla.get_default_device()) result_handler = pxla.global_aval_to_result_handler(aval, s, False) return result_handler(pxla.shard_args([s], [None], [None], [x])) else: diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 0f88ec4c95b5..12242698d6c5 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -365,6 +365,56 @@ def loss(x, y): jax.grad(loss, (0,1))(x_top, y_top) self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) + @parameterized.parameters([False, True]) + def test_custom_vjp_grad_stats_plumbing_basic(self, jit): + @jax.jit + def primal(grads_ref, x): # note: jit-abstracted! + x = jnp.sin(x) + x = stash_grads(grads_ref, x) + x = jnp.sin(x) + x = stash_grads(grads_ref, x) # ignored, order-preserved + return x + + @jax.custom_vjp + def stash_grads(grads_ref, x): + return x + def stash_grads_fwd(grads_ref, x): + return x, grads_ref + def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + grads_ref = core.mutable_array(jnp.float32(0.)) + jax.grad(primal, 1)(grads_ref, jnp.float32(1.0)) + self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_custom_vjp_grad_stats_plumbing_scan(self, jit): + @jax.jit + def primal(grads_ref, x): # note: jit-abstracted! 
+ def body(x, _): + x = jnp.sin(x) + x = stash_grads(grads_ref, x) + x = jnp.sin(x) + return x, () + x, () = jax.lax.scan(body, x, None, length=1) + return x + + @jax.custom_vjp + def stash_grads(grads_ref, x): + return x + def stash_grads_fwd(grads_ref, x): + return x, grads_ref + def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + grads_ref = core.mutable_array(jnp.float32(0.)) + jax.grad(primal, argnums=1)(grads_ref, jnp.float32(1.0)) + self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From b999fab1daef44d6a105973297e1b11cdde6c4dd Mon Sep 17 00:00:00 2001 From: Kostiantyn Liepieshov Date: Tue, 10 Jun 2025 11:20:15 -0700 Subject: [PATCH 1612/1769] [jax2tf] fix jax2tf sharding tests for shardy The regex doesn't need to match the MHLO "FuncResultSharding". PiperOrigin-RevId: 769715562 --- .../jax2tf/tests/sharding_test.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index fa15522cbe90..20193a931b63 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -233,10 +233,10 @@ def f_tf(x): jax2tf.convert(f_jax), [x], checks=[ # The argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", count_in_P), # The result - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) # TODO(b/326476605): Change the condition below if required.
@@ -244,11 +244,11 @@ def f_tf(x): self.check_sharding( jax2tf.convert(f_jax), [x], checks=[ - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_in_replicated), - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_out_replicated), - (r"custom_call_target.*Sharding", + (r"custom_call_target.*\"Sharding", count_in_P + count_in_replicated + count_out_P + count_out_replicated), ]) @@ -278,13 +278,13 @@ def f_jax(x, y): # f32[10,20] , f32[20,30] -> f32[10,30] f_tf, [y], checks=[ # The variable argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", 1), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", 1), # The y argument - (r"f32\[20,30\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), + (r"f32\[20,30\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", 1), # The output sharding - (r"f32\[10,30\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,30\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # No other annotations - (r"custom_call_target.*Sharding", 3) + (r"custom_call_target.*\"Sharding", 3) ]) @jtu.with_mesh([("x", 2)]) @@ -312,10 +312,10 @@ def f_tf(x): jax2tf.convert(f_jax), [x], checks=[ # x - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", 1), # The result - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*replicated", self.GEQ(1)), ]) @@ -359,16 +359,16 @@ def f_jax(x): # x: f32[10, 20], optionally some axes as polymorphic f_tf, [x], checks=[ # The input argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # The y argument - (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[10,40\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_inner_sharding), - (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[10,40\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_inner_replicated), # The output sharding - (r"f32\[10,80\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,80\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # No other annotations - (r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated) + (r"custom_call_target.*\"Sharding", 2 + count_inner_sharding + count_inner_replicated) ]) @jtu.parameterized_filterable( @@ -429,17 +429,17 @@ def f_grad_tf(x_v, res_ct): self.check_sharding(f_grad_tf, [x, x.T], checks=[ # The input primal argument, and the output grad - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", count_in_P), # The primal result, and the input cotangent - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) # TODO(b/326476605): Change the condition below if required. 
if out_shardings not in [None, "missing"] and in_shardings not in [None, "missing"]: self.check_sharding(f_grad_tf, [x, x.T], checks=[ - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_in_replicated), # The primal result, and the input cotangent - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) def test_grad_sharding_different_mesh(self): From 9e0472c1643e72b16ba5fdbb50729e70cd24fdd4 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Tue, 10 Jun 2025 12:56:25 -0700 Subject: [PATCH 1613/1769] [Mosaic GPU] Fix test after a previous PR changed the config params. PiperOrigin-RevId: 769758348 --- tests/pallas/mgpu_matmul_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/mgpu_matmul_test.py b/tests/pallas/mgpu_matmul_test.py index 4013db78f6a2..5c52b0c77296 100644 --- a/tests/pallas/mgpu_matmul_test.py +++ b/tests/pallas/mgpu_matmul_test.py @@ -76,7 +76,7 @@ def test_matmul( a, b, blackwell_matmul_mgpu.TuningConfig( - block_m=128, block_n=128, block_k=128, + tile_m=128, tile_n=128, tile_k=128, max_concurrent_steps=2, collective=False, ), From c0541354cb7904fc494f928459b4a7e43bb36d93 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Tue, 10 Jun 2025 13:14:11 -0700 Subject: [PATCH 1614/1769] Skip NumPy's `isClose` test for NumPy 2.3.0 NumPy 2.3.0 now issues a RuntimeWarning for this case. PiperOrigin-RevId: 769766342 --- tests/lax_numpy_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 09a081761d0e..60ad6a83701b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3971,6 +3971,8 @@ def testIsClose(self): def testIsCloseCornerCases(self, atol, rtol, equal_nan): if jtu.numpy_version() < (2, 0, 0) and (np.isinf(atol) or np.isinf(rtol)): self.skipTest("fails on older NumPy") + if jtu.numpy_version() >= (2, 3, 0) and (np.isinf(atol) or np.isinf(rtol)): + self.skipTest("NumPy 2.3.0 now throws warnings for inf atol/rtol") vals = np.array([-np.nan, -np.inf, -1.00001, -1.0, -0.00001, -0.0, 0.0, 0.00001, 1.0, 1.00001, np.inf, np.nan]) x, y = np.meshgrid(vals, vals) From 03b015261805c2b99b9c49775aeabcd27107ac4f Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 10 Jun 2025 13:23:13 -0700 Subject: [PATCH 1615/1769] fix for a downstream breakage from #29353 PiperOrigin-RevId: 769770645 --- jax/_src/lax/control_flow/loops.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 1530acafc64f..04100d1dec09 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -65,6 +65,7 @@ from jax._src.util import ( merge_lists, partition_list, safe_map, safe_zip, split_list, split_list_checked, unzip2, weakref_lru_cache,) +from jax._src import xla_bridge as xb from jax.tree_util import ( keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten, treedef_is_leaf) @@ -845,12 +846,14 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # iterations, but we need one last iteration to prepare the jaxpr based on the # final carry_uk. 
carry_uk = init_uk + fwd = [(i < num_consts or i >= num_consts + num_carry) and + (not t.pval.is_known() or isinstance(t.pval.get_known(), Array)) + for i, t in enumerate(tracers)] for _ in range(1 + len(carry_uk)): unknowns = const_uk + carry_uk + xs_uk jaxpr_known, jaxpr_unknown, out_uk, res_avals, in_fwd_res = \ pe.partial_eval_jaxpr_nounits_fwd( - jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys, - fwd=[True] * num_consts + [False] * num_carry + [True] * num_xs) + jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys, fwd=fwd) carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) if carry_uk_out == carry_uk: break @@ -908,8 +911,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # Complete non_fwd_res and then res, then split to match binders. non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) non_fwd_res_ = iter(non_fwd_res) - res = [next(non_fwd_res_) if f is None else _maybe_put(orig_inputs[f]) - for f in in_fwd_res] + res = [next(non_fwd_res_) if f is None else orig_inputs[f] for f in in_fwd_res] assert next(non_fwd_res_, None) is None int_res, ext_res = partition_list(res_to_move, res) @@ -947,7 +949,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, def _maybe_put(x): if isinstance(x, np.ndarray): aval = core.shaped_abstractify(x) - s = sharding.SingleDeviceSharding(pxla.get_default_device()) + s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) return result_handler(pxla.shard_args([s], [None], [None], [x])) else: From 89fca529891bcacbb625453781d798b2c42dccb7 Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Tue, 10 Jun 2025 13:26:08 -0700 Subject: [PATCH 1616/1769] * Add support for output and input memory space colors in tpu custom calls via CustomCallConfig. * Deprecate the field output_memory_colors in xla in favor of output_memory_space_colors: * Update existing code to use output_memory_space_colors. * Update the tests to use output_memory_space_colors. PiperOrigin-RevId: 769771931 --- .../pallas/mosaic/pallas_call_registration.py | 36 ++++++++--- jax/_src/tpu_custom_call.py | 61 +++++++++++++++++-- tests/pallas/tpu_pallas_async_test.py | 6 ++ tests/pallas/tpu_pallas_test.py | 4 ++ 4 files changed, 94 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 528f897edf74..95d07d7ddb16 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -16,9 +16,10 @@ from __future__ import annotations +from collections.abc import Sequence import os import tempfile -from typing import cast +from typing import List, cast import jax from jax import dtypes @@ -92,15 +93,14 @@ def _get_memory_space_from_aval( def _get_memory_spaces_from_avals( - out_avals: tuple[jax_core.AbstractValue, ...], + avals: Sequence[jax_core.AbstractValue], ) -> tuple[tpu_custom_call.MemorySpace | None, ...]
| None: - output_memory_spaces = None + memory_spaces = None if any( - isinstance(out_aval, pallas_core.ShapedArrayWithMemorySpace) - for out_aval in out_avals + isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) for aval in avals ): - output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals)) - return output_memory_spaces + memory_spaces = tuple(map(_get_memory_space_from_aval, avals)) + return memory_spaces def pallas_call_tpu_lowering_rule( @@ -217,6 +217,18 @@ def _maybe_cast_inputs(*args): dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:] kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) output_memory_spaces = _get_memory_spaces_from_avals(out_avals) + input_memory_spaces = None + if any( + isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) + for aval in ctx.avals_in + ): + # TODO(sharadmv): Support dynamic grid bounds and extra args. + if num_dyn_bounds != 0 or len(extra_args) > 0: + raise NotImplementedError( + "Dynamic grid bounds and extra args are not supported when" + " specifying memory spaces for inputs." + ) + input_memory_spaces = _get_memory_spaces_from_avals(ctx.avals_in) if cost_estimate is not None: mosaic_cost_estimate = tpu_custom_call.CostEstimate( flops=cost_estimate.flops, @@ -225,6 +237,15 @@ def _maybe_cast_inputs(*args): ) else: mosaic_cost_estimate = None + if input_memory_spaces is None and output_memory_spaces is not None: + input_memory_spaces_list: List[tpu_custom_call.MemorySpace | None] = [ + None, + ] * len(ctx.avals_in) + for input_output_alias in input_output_aliases: + input_memory_spaces_list[input_output_alias[0]] = output_memory_spaces[ + input_output_alias[1] + ] + input_memory_spaces = tuple(input_memory_spaces_list) out_nodes = mosaic.lower_module_to_custom_call( kernel_ctx, *dynamic_grid_args, @@ -245,6 +266,7 @@ def _maybe_cast_inputs(*args): has_side_effects=mosaic_params.has_side_effects, output_memory_spaces=output_memory_spaces, disable_bounds_checks=mosaic_params.disable_bounds_checks, + input_memory_spaces=input_memory_spaces, ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 0f099ed45cac..7f308541b49f 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -126,6 +126,7 @@ class CustomCallBackendConfig: output_memory_spaces: tuple[MemorySpace | None, ...] | None disable_bounds_checks: bool active_core_count: int | None + input_memory_spaces: tuple[MemorySpace | None, ...] | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. 
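For reference, the `to_json` change in the next hunk serializes these fields into the custom-call backend config as JSON fragments of roughly the following shape (a sketch only; the operand index, shape index, and color values here are hypothetical, taken from the corresponding `MemorySpace.color` attributes):

    , "output_memory_space_colors": [{"shape_index":[1],"color":1}]
    , "input_memory_space_colors": [{"operand_index":0,"color":1}]

When there is a single output, its entry is written without a "shape_index", as {"color":1}.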
@@ -169,13 +170,53 @@ def to_json(self) -> bytes: config.write(b', "internal_scratch_in_bytes": ') config.write(str(self.internal_scratch_in_bytes).encode("ascii")) if self.output_memory_spaces is not None: - config.write(b', "output_memory_colors": [') - for i, memory_space in enumerate(self.output_memory_spaces): - if i: + if len(self.output_memory_spaces) == 1: + output_memory_space = self.output_memory_spaces[0] + if output_memory_space is not None: + config.write(b', "output_memory_space_colors": [') + config.write( + f'{{"color":{output_memory_space.color}}}'.encode("ascii") + ) + config.write(b"]") + else: + comma = False + for i, output_memory_space in enumerate(self.output_memory_spaces): + if output_memory_space is None: + continue + if comma: + config.write(b",") + else: + config.write(b', "output_memory_space_colors": [') + config.write( + f'{{"shape_index":[{i}],"color":{output_memory_space.color}}}' + .encode("ascii") + ) + comma = True + if comma: + config.write(b"]") + if self.input_memory_spaces is not None: + comma = False + for i, input_memory_space in enumerate(self.input_memory_spaces): + if input_memory_space is None: + continue + if input_memory_space not in ( + MemorySpace.HBM, + MemorySpace.VMEM, + ): + raise NotImplementedError( + "input_memory_space_colors only supports HBM and VMEM" + ) + if comma: config.write(b",") - color = memory_space.color if memory_space is not None else -1 - config.write(str(color).encode("ascii")) - config.write(b"]") + else: + config.write(b', "input_memory_space_colors": [') + config.write( + f'{{"operand_index":{i},"color":{input_memory_space.color}}}' + .encode("ascii") + ) + comma = True + if comma: + config.write(b"]") if self.disable_bounds_checks: config.write(b', "disable_bounds_checks": ') config.write(str(self.disable_bounds_checks).lower().encode("ascii")) @@ -456,6 +497,7 @@ def _lower_to_custom_call_config( kernel_name: str | None = None, ir_version: int | None = None, disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> CustomCallBackendConfig: device_type = _get_device_type(module) lowered_module_asm, ( @@ -488,6 +530,7 @@ def _lower_to_custom_call_config( output_memory_spaces=output_memory_spaces, disable_bounds_checks=disable_bounds_checks, active_core_count=active_core_count, + input_memory_spaces=input_memory_spaces, ) @@ -509,6 +552,7 @@ def _lowered_to_custom_call_config( output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, disable_bounds_checks: bool = False, active_core_count: int | None = None, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ): if has_custom_barrier: if collective_id is None: @@ -541,6 +585,7 @@ def _lowered_to_custom_call_config( output_memory_spaces, disable_bounds_checks, active_core_count=active_core_count, + input_memory_spaces=input_memory_spaces, ) return config @@ -563,6 +608,7 @@ def lower_module_to_custom_call( serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None, disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] 
| None, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( module, @@ -578,6 +624,7 @@ def lower_module_to_custom_call( kernel_name=kernel_name, ir_version=get_ir_version(ctx), disable_bounds_checks=disable_bounds_checks, + input_memory_spaces=input_memory_spaces, ) return _tpu_custom_call_lowering( ctx, @@ -607,6 +654,7 @@ def as_tpu_kernel( serialization_format: int | None = 1, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" config = _lower_to_custom_call_config( @@ -622,6 +670,7 @@ def as_tpu_kernel( output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, disable_bounds_checks=disable_bounds_checks, + input_memory_spaces=input_memory_spaces, ) return _as_jax_callable( config, diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index c70fb6ea2ff5..f4f06c178d8d 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -204,6 +204,9 @@ def setUp(self): super().setUp() if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work ou TPU v4+') + # TODO(subhankarshah): Remove after all required changes are in. + if not jtu.if_cloud_tpu_at_least(2025, 6, 15): + self.skipTest('Requires libtpu built after 2025-06-15') def test_basic_async_copy(self): @jax.jit @@ -830,6 +833,9 @@ def setUp(self): super().setUp() if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work ou TPU v4+') + # TODO(subhankarshah): Remove after all required changes are in. + if not jtu.if_cloud_tpu_at_least(2025, 6, 15): + self.skipTest('Requires libtpu built after 2025-06-15') def test_basic_stateful_async_copy(self): @jax.jit diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 1ddd0a0bc176..ec61ca61f1f2 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1224,6 +1224,10 @@ def test_output_dma_semaphore_ref(self): if self.INTERPRET: self.skipTest('TODO(sharadmv, justinfu): Add interpret support for DMA.') + # TODO(subhankarshah): Remove after all required changes are in. + if not jtu.if_cloud_tpu_at_least(2025, 6, 15): + self.skipTest('Requires libtpu built after 2025-06-15') + def kernel(x_hbm_ref, y_hbm_ref, sem_out): pltpu.make_async_copy( x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem_out From bdd635a65b3557fdf66d42f4d116815fee942072 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 10 Jun 2025 13:34:10 -0700 Subject: [PATCH 1617/1769] [JAX] Add `vma` to `ShapeDtypeStruct` constructor arguments. This is so that pallas can annotate the output ShapeDtypeStruct with `vma` which will allow not setting `check_vma=False` on shard_map when pallas kernels are present. 
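For example, a sketch of the intended use (distilled from the test added below; the kernel, the shapes, and the mesh axis name 'x' are illustrative):

    # Inside a shard_map over mesh axis 'x', with check_vma left enabled,
    # a pallas_call output can now declare that it varies over 'x':
    out = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)],
        out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32,
                                       vma=frozenset('x')),
    )(x)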
PiperOrigin-RevId: 769775331 --- jax/_src/api.py | 26 +++++++--- jax/_src/core.py | 11 ++++- jax/_src/pallas/core.py | 41 +++++++-------- jax/_src/pallas/mosaic_gpu/core.py | 7 +++ jax/_src/pallas/pallas_call.py | 16 +++++- jax/experimental/colocated_python/func.py | 2 +- tests/pallas/tpu_pallas_distributed_test.py | 55 +++++++++++++++++++-- 7 files changed, 123 insertions(+), 35 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 1f7cc19206a9..1505274ca398 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2822,9 +2822,10 @@ class ShapeDtypeStruct: dtype: a dtype-like object sharding: (optional) a :class:`jax.Sharding` object """ - __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"] + __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type", "vma"] - def __init__(self, shape, dtype, *, sharding=None, weak_type=False): + def __init__(self, shape, dtype, *, sharding=None, weak_type=False, + vma=None): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") @@ -2856,6 +2857,11 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False): self._dll = (sharding.device_local_layout if isinstance(sharding, Format) else None) self.weak_type = weak_type + if vma is not None and not isinstance(vma, (set, frozenset)): + raise TypeError( + "`vma` argument passed to ShapeDtypeStruct should be of type `set`" + f" or `frozenset`. Got type {type(vma)}") + self.vma = None if vma is None else frozenset(vma) size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) @@ -2876,8 +2882,9 @@ def __repr__(self): sh = f", sharding={self.sharding}" if self.sharding is not None else "" l = f", format={self._dll}" if self._dll is not None else "" wt = f", weak_type={self.weak_type}" if self.weak_type else "" + vma = f", vma={self.vma}" if self.vma else "" return (f"{type(self).__name__}(shape={self.shape}, " - f"dtype={self.dtype.name}{sh}{l}{wt})") + f"dtype={self.dtype.name}{sh}{l}{wt}{vma})") __str__ = __repr__ @@ -2885,14 +2892,16 @@ def __eq__(self, other): if not isinstance(other, ShapeDtypeStruct): return False else: - return ((self.shape, self.dtype, self.sharding, self._dll, self.weak_type) == - (other.shape, other.dtype, other.sharding, other._dll, other.weak_type)) + return ((self.shape, self.dtype, self.sharding, self._dll, + self.weak_type, self.vma) == + (other.shape, other.dtype, other.sharding, other._dll, + other.weak_type, other.vma)) def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing # https://github.com/jax-ml/jax/issues/8182 return hash((self.shape, self.dtype, self.sharding, self._dll, - self.weak_type)) + self.weak_type, self.vma)) def __setattr__(self, name, value): if hasattr(self, name): @@ -2921,10 +2930,13 @@ def update(self, **kwargs): shape=kwargs.pop('shape', self.shape), dtype=kwargs.pop('dtype', self.dtype), sharding=sharding, - weak_type=kwargs.pop('weak_type', self.weak_type)) + weak_type=kwargs.pop('weak_type', self.weak_type), + vma=kwargs.pop('vma', self.vma)) def _sds_aval_mapping(x): + # TODO(yashkatariya): Propagate vma to ShapedArray? This is only used for + # pallas right now and pallas doesn't use pytype_aval_mappings. 
aval = ShapedArray( x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=x.weak_type) diff --git a/jax/_src/core.py b/jax/_src/core.py index 24c9f7dcd261..96bab24e1258 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1498,6 +1498,9 @@ def __repr__(self): def update_weak_type(self, weak_type): return self + def update_vma(self, vma): + return self + def strip_weak_type(self) -> AbstractValue: return self.update_weak_type(False) @@ -2101,6 +2104,9 @@ def _len(self, ignored_tracer): except IndexError as err: raise TypeError("len() of unsized object") from err # same as numpy error + def update_vma(self, vma): + return self.update(vma=vma) + def _get_shape_sharding_str(shape, spec): out = [] @@ -2137,9 +2143,9 @@ def primal_dtype_to_tangent_dtype(primal_dtype): def pvary(x, axis_name): - axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name if not axis_name: return x + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name xs, treedef = tree_flatten(x) ys = pvary_p.bind(*xs, axes=axes, axis_index_groups=None) return tree_unflatten(treedef, ys) @@ -2259,6 +2265,9 @@ def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) + def update_vma(self, vma): + return self + class DArray: _aval: DShapedArray diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 07631699ea61..22e6201ac961 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -143,34 +143,27 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray): __slots__ = ["memory_space"] def __init__(self, shape, dtype, weak_type=False, sharding=None, - memory_space=None): - super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding) + vma=frozenset(), memory_space=None): + super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding, + vma=vma) self.memory_space = memory_space def __eq__(self, other): return super().__eq__(other) and self.memory_space == other.memory_space def __hash__(self): - return hash(( - self.shape, - self.dtype, - self.weak_type, - getattr(self, "sharding", None), - self.memory_space, - )) + return hash((self.shape, self.dtype, self.weak_type, self.sharding, + self.vma, self.memory_space)) def str_short(self, short_dtypes=False): - dt_str = \ - dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else + self.dtype.name) dt_str = dt_str.replace("void", "float0") shapestr = ",".join(map(str, self.shape)) - if hasattr(self, "sharding"): - sharding_str = f"{dt_str}[{shapestr}]({self.sharding})" - else: - sharding_str = "" - memoryspace_str = ( - "" if self.memory_space is None else f"<{self.memory_space}>" - ) + sharding_str = (f"{dt_str}[{shapestr}]({self.sharding})" + if self.sharding else "") + memoryspace_str = ("" if self.memory_space is None + else f"<{self.memory_space}>") return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}" def update( @@ -179,6 +172,7 @@ def update( dtype=None, weak_type=None, sharding=None, + vma=None, memory_space=None, ): if shape is None: @@ -188,11 +182,14 @@ def update( if weak_type is None: weak_type = self.weak_type if sharding is None: - sharding = getattr(self, "sharding", None) + sharding = self.sharding + if vma is None: + vma = self.vma if memory_space is None: memory_space = self.memory_space return ShapedArrayWithMemorySpace( - shape, dtype, weak_type, sharding=sharding, memory_space=memory_space + shape, dtype, 
weak_type, sharding=sharding, vma=vma, + memory_space=memory_space ) mlir.ir_type_handlers[ShapedArrayWithMemorySpace] = mlir._array_ir_types @@ -242,6 +239,10 @@ def update_weak_type(self, weak_type): return AbstractMemoryRef( self.inner_aval.update_weak_type(weak_type), self.memory_space) + def update_vma(self, vma): + return AbstractMemoryRef( + self.inner_aval.update_vma(vma), self.memory_space) + def update(self, inner_aval=None, memory_space=None): inner_aval = self.inner_aval if inner_aval is None else inner_aval memory_space = self.memory_space if memory_space is None else memory_space diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 163e250dadf0..11f3f1eb4592 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -362,6 +362,10 @@ def _setitem(self, tracer, index, value): del tracer, index, value # Unused. raise ValueError("Ref unions can't be assigned to.") + def update_vma(self, vma): + return AbstractRefUnion(self.inner_aval.update_vma(vma), self.refs, + self.memory_space) + @dataclasses.dataclass(init=False, frozen=True) class RefUnion(GPUMemoryRef): @@ -941,6 +945,9 @@ def __repr__(self) -> str: def update_weak_type(self, weak_type): return _as_accum(super().update_weak_type(weak_type)) + def update_vma(self, vma): + return _as_accum(super().update_vma(vma)) + def update(self, inner_aval=None, memory_space=None): return _as_accum(super().update(inner_aval=None, memory_space=None)) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 0165cbad6079..5c6cf10dec12 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1205,7 +1205,7 @@ def _trace_kernel_to_jaxpr( wrapped_kernel_fun = primitives.wrap_with_transforms( wrapped_kernel_fun, kernel_in_transforms ) - with grid_mapping.trace_env(): + with grid_mapping.trace_env(), config._check_vma(False): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, kernel_avals) if consts: @@ -1350,6 +1350,16 @@ def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params): def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: match out_shape: case jax.ShapeDtypeStruct(): + if config._check_vma.value: + if out_shape.vma is None: + raise ValueError( + "When `check_vma=True` on `jax.shard_map`, `vma` on" + " `jax.ShapeDtypeStruct` must not be `None`. Please specify how the" + " output should be varying across mesh axes using the `vma`" + " argument of `jax.ShapeDtypeStruct` or set `check_vma=False` on" + " `jax.shard_map`.") + return jax_core.ShapedArray( + shape=out_shape.shape, dtype=out_shape.dtype, vma=out_shape.vma) return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) case pallas_core.MemoryRef(): return out_shape.get_array_aval() @@ -1685,6 +1695,8 @@ def wrapped(*args): x.ref if isinstance(x, state_types.TransformedRef) else x for x in flat_kernel_args ) + flat_kernel_avals = tuple(a.update_vma(frozenset()) + for a in flat_kernel_avals) # Note that only a subset of all transforms can be found here, and they are # never expected to contain any arrays. 
kernel_arg_transforms = tuple( @@ -1696,7 +1708,7 @@ def wrapped(*args): if name is not None: kernel_dbg = kernel_dbg.replace_func_name(mlir.sanitize_name(name)) jaxpr, consts = _trace_kernel_to_jaxpr( - kernel, kernel_dbg, grid_mapping, tuple(flat_kernel_avals), + kernel, kernel_dbg, grid_mapping, flat_kernel_avals, kernel_in_tree, kernel_arg_transforms) for i_idx, o_idx in input_output_aliases.items(): if i_idx not in range(len(flat_in_avals)): diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index b7188d9da7ad..65464479b5ca 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -169,7 +169,7 @@ def _compile_to_executable( program, compile_options ) out_handlers = pxla.global_avals_to_results_handler( - out_sdss, out_shardings, committed=True + out_sdss, out_shardings, committed=True # type: ignore ).handlers def call(*args, **kwargs): diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index 11b159dbec4c..4b7bc06463bd 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -76,19 +76,66 @@ def body(x): kernel, in_specs=[pl.BlockSpec(memory_space=mem)], out_specs=pl.BlockSpec(memory_space=mem), - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32, vma=frozenset('x')), )(x) devices = jax.devices()[:2] mesh = jax.sharding.Mesh(devices, ['x']) - y = jax.jit( + f = jax.jit( shard_map.shard_map( - body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), ) - )(x) + ) + jaxpr = f.trace(x).jaxpr + self.assertNotIn('pvary', str(jaxpr)) + y = f(x) expected = jnp.concatenate([x[8:], x[:8]]) np.testing.assert_allclose(y, expected) + def test_vma_error(self): + def kernel(x_ref, y_ref): + def body(ready_sem, send_sem, recv_sem): + other_dev_id = 1 - lax.axis_index('x') + pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, + device_id_type=pltpu.DeviceIdType.LOGICAL) + pltpu.semaphore_wait(ready_sem) + copy_done = pltpu.async_remote_copy( + x_ref, y_ref, send_sem, recv_sem, other_dev_id, + device_id_type=pltpu.DeviceIdType.LOGICAL, + ) + copy_done.wait_send() + copy_done.wait_recv() + + pl.run_scoped( + body, + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ) + + x = jnp.arange(2 * 8 * 128.0).reshape((2 * 8, 128)) + + def body(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + )(x) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + f = jax.jit( + shard_map.shard_map( + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), + ) + ) + with self.assertRaisesRegex( + ValueError, + 'When `check_vma=True` on `jax.shard_map`, `vma` on' + ' `jax.ShapeDtypeStruct` must not be `None`'): + f(x) + @parameterized.named_parameters( ('left', 'left'), ('right', 'right') From 160e59f10c2b5bc38470d84aab165d640098b6a0 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Tue, 10 Jun 2025 15:35:29 -0700 Subject: [PATCH 1618/1769] Add is_leaf_with_path predicate. This is an alternative of is_leaf that allows the predicate to look at the key path. Only works on .*_with_path APIs. 
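Example usage (adapted from the new tests below; the depth cutoff is illustrative):

    x = ((1, 2), [3, {4: 4, 5: 5}])
    # With is_leaf_takes_path=True, the predicate also receives the key path.
    check_max_depth = lambda kp, _: len(kp) >= 2
    flattened, treedef = jax.tree.flatten_with_path(
        x, is_leaf=check_max_depth, is_leaf_takes_path=True)
    # {4: 4, 5: 5} stays a leaf because its key path already has length 2.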
Fixes https://github.com/jax-ml/jax/issues/27996 PiperOrigin-RevId: 769828402 --- jax/BUILD | 1 + jax/_src/tree.py | 17 +++++--- jax/_src/tree_util.py | 39 ++++++++++-------- jaxlib/_jax/pytree.pyi | 2 +- jaxlib/pytree.cc | 54 +++++++++++++----------- jaxlib/pytree.h | 7 ++-- jaxlib/xla_client.py | 2 +- tests/tree_util_test.py | 91 ++++++++++++++++++++++++++++++++++++----- 8 files changed, 153 insertions(+), 60 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 3ec1567e2dc9..a2d59c98e810 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1438,6 +1438,7 @@ pytype_strict_library( srcs = ["_src/tree.py"], deps = [ ":tree_util", + "//jax/_src/lib", ], ) diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 9a3e001d902b..5aaa0bf2b006 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -287,7 +287,8 @@ def unflatten(treedef: tree_util.PyTreeDef, def flatten_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]: """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path. @@ -313,11 +314,12 @@ def flatten_with_path( - :func:`jax.tree.map_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_flatten_with_path(tree, is_leaf) + return tree_util.tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path) def leaves_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> list[tuple[tree_util.KeyPath, Any]]: """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. @@ -338,14 +340,15 @@ def leaves_with_path( - :func:`jax.tree.flatten_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_leaves_with_path(tree, is_leaf) + return tree_util.tree_leaves_with_path(tree, is_leaf, is_leaf_takes_path) def map_with_path( f: Callable[..., Any], tree: Any, *rest: Any, - is_leaf: Callable[[Any], bool] | None = None, + is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> Any: """Maps a multi-input function over pytree key path and args to produce a new pytree. 
@@ -377,7 +380,9 @@ def map_with_path( - :func:`jax.tree.leaves_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf) + return tree_util.tree_map_with_path( + f, tree, *rest, is_leaf=is_leaf, is_leaf_takes_path=is_leaf_takes_path + ) def broadcast(prefix_tree: Any, full_tree: Any, diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 6edbbfd62d12..6054b4711e02 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -27,6 +27,7 @@ from jax._src.lib import pytree from jax._src.util import safe_zip, set_module from jax._src.util import unzip2 +from jax._src.lib import jaxlib_extension_version export = set_module('jax.tree_util') @@ -1126,34 +1127,40 @@ def register_static(cls: type[H]) -> type[H]: @export def tree_flatten_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]: """Alias of :func:`jax.tree.flatten_with_path`.""" - return default_registry.flatten_with_path(tree, is_leaf) + if jaxlib_extension_version < 351: + return default_registry.flatten_with_path(tree, is_leaf) + is_leaf_with_kp: Callable[[Any, Any], bool] | None = is_leaf + if not is_leaf_takes_path and is_leaf is not None: + is_leaf_with_kp = lambda _, x: is_leaf(x) + return default_registry.flatten_with_path(tree, is_leaf_with_kp) @export def tree_leaves_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> list[tuple[KeyPath, Any]]: """Alias of :func:`jax.tree.leaves_with_path`.""" - return tree_flatten_with_path(tree, is_leaf)[0] - - -# generate_key_paths is not exported. -def generate_key_paths( - tree: Any, is_leaf: Callable[[Any], bool] | None = None -) -> list[tuple[KeyPath, Any]]: - return tree_leaves_with_path(tree, is_leaf) -_generate_key_paths = generate_key_paths # alias for backward compat + return tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path)[0] +generate_key_paths = tree_leaves_with_path @export -def tree_map_with_path(f: Callable[..., Any], - tree: Any, *rest: Any, - is_leaf: Callable[[Any], bool] | None = None) -> Any: +def tree_map_with_path( + f: Callable[..., Any], + tree: Any, + *rest: Any, + is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, +) -> Any: """Alias of :func:`jax.tree.map_with_path`.""" - keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) + keypath_leaves, treedef = tree_flatten_with_path( + tree, is_leaf, is_leaf_takes_path + ) keypath_leaves = list(zip(*keypath_leaves)) all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest] return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves)) diff --git a/jaxlib/_jax/pytree.pyi b/jaxlib/_jax/pytree.pyi index ac5298c77964..0cec90db631c 100644 --- a/jaxlib/_jax/pytree.pyi +++ b/jaxlib/_jax/pytree.pyi @@ -45,7 +45,7 @@ class PyTreeRegistry: def flatten_with_path( self, tree: Any, - leaf_predicate: Callable[[Any], bool] | None = ..., + leaf_predicate: Callable[[Any, Any], bool] | None = ..., ) -> Tuple[list[Tuple[_KeyPath, Any]], PyTreeDef]: ... 
def register_node( self, diff --git a/jaxlib/pytree.cc b/jaxlib/pytree.cc index bd845c47ec1e..6e2f7af98629 100644 --- a/jaxlib/pytree.cc +++ b/jaxlib/pytree.cc @@ -545,16 +545,33 @@ nanobind::tuple FlattenedIndexKey::MatchArgs(nanobind::handle unused) { return nanobind::make_tuple("key"); }; +/* static */ nb::object MakeKeyPathTuple(std::vector<nb::object>& keypath) { + const std::vector<nb::object>& frozen_keypath = keypath; + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + return kp_tuple; +} + template <typename T> -void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, - const std::optional<nb::callable>& leaf_predicate, - std::optional<std::vector<nb::object>>& keypath) { +void PyTreeDef::FlattenImpl( + nb::handle handle, T& leaves, + std::optional<std::vector<nb::object>>& keypath, + const std::optional<nb::callable>& leaf_predicate) { Node node; const int start_num_nodes = traversal_.size(); const int start_num_leaves = leaves.size(); bool is_known_leaf = false; if (leaf_predicate) { - nb::object o = (*leaf_predicate)(handle); + nb::object o; + if (keypath.has_value()) { + auto kp_tuple = MakeKeyPathTuple(keypath.value()); + o = (*leaf_predicate)(kp_tuple, handle); + } else { + o = (*leaf_predicate)(handle); + } // Historically we accepted "truthy" values from leaf predicates. Accept // None here to keep existing clients happy. if (o.is_none()) { @@ -568,12 +585,7 @@ void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, if (is_known_leaf) { nb::object value = nb::borrow(handle); if (keypath.has_value()) { - const std::vector<nb::object>& frozen_keypath = keypath.value(); - nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); - for (int i = 0; i < frozen_keypath.size(); ++i) { - PyTuple_SET_ITEM(kp_tuple.ptr(), i, - nb::object(frozen_keypath[i]).release().ptr()); - } + auto kp_tuple = MakeKeyPathTuple(keypath.value()); value = nb::make_tuple(std::move(kp_tuple), std::move(value)); } if constexpr (std::is_same_v<T, nb::list>) { @@ -590,7 +602,7 @@ void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, " in flatten; PyTree may have cyclical node references.")) { return; } - FlattenImpl(child, leaves, leaf_predicate, keypath); + FlattenImpl(child, leaves, keypath, leaf_predicate); Py_LeaveRecursiveCall(); }; switch (node.kind) { @@ -718,12 +730,7 @@ void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, DCHECK(node.kind == PyTreeKind::kLeaf); auto value = nb::borrow(handle); if (keypath.has_value()) { - const std::vector<nb::object>& frozen_keypath = keypath.value(); - nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); - for (int i = 0; i < frozen_keypath.size(); ++i) { - PyTuple_SET_ITEM(kp_tuple.ptr(), i, - nb::object(frozen_keypath[i]).release().ptr()); - } + auto kp_tuple = MakeKeyPathTuple(keypath.value()); value = nb::make_tuple(std::move(kp_tuple), std::move(value)); } if constexpr (std::is_same_v<T, nb::list>) { @@ -742,19 +749,19 @@ void PyTreeDef::Flatten(nb::handle handle, absl::InlinedVector<nb::object, 2>& leaves, std::optional<nb::callable> leaf_predicate) { std::optional<std::vector<nb::object>> keypath = std::nullopt; - FlattenImpl(handle, leaves, leaf_predicate, keypath); + FlattenImpl(handle, leaves, keypath, leaf_predicate); } void PyTreeDef::Flatten(nb::handle handle, std::vector<nb::object>& leaves, std::optional<nb::callable> leaf_predicate) {
std::optional<std::vector<nb::object>> keypath = std::nullopt; - FlattenImpl(handle, leaves, leaf_predicate, keypath); + FlattenImpl(handle, leaves, keypath, leaf_predicate); } /*static*/ std::pair<std::vector<nb::object>, nb_class_ptr<PyTreeDef>> PyTreeDef::Flatten(nb::handle x, nb_class_ptr<PyTreeRegistry> registry, @@ -766,10 +773,11 @@ PyTreeDef::Flatten(nb::handle x, nb_class_ptr<PyTreeRegistry> registry, return std::make_pair(std::move(leaves), std::move(def)); } -void PyTreeDef::FlattenWithPath(nb::handle handle, nanobind::list& leaves, - std::optional<nb::callable> leaf_predicate) { +void PyTreeDef::FlattenWithPath( + nb::handle handle, nanobind::list& leaves, + std::optional<nb::callable> leaf_predicate) { std::optional<std::vector<nb::object>> keypath = std::vector<nb::object>(); - FlattenImpl(handle, leaves, leaf_predicate, keypath); + FlattenImpl(handle, leaves, keypath, leaf_predicate); } /*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry* registry, diff --git a/jaxlib/pytree.h b/jaxlib/pytree.h index 0a012d933c70..8f13a26135a5 100644 --- a/jaxlib/pytree.h +++ b/jaxlib/pytree.h @@ -367,9 +367,10 @@ class PyTreeDef { const; template <typename T> - void FlattenImpl(nanobind::handle handle, T& leaves, - const std::optional<nanobind::callable>& leaf_predicate, - std::optional<std::vector<nanobind::object>>& keypath); + void FlattenImpl( + nanobind::handle handle, T& leaves, + std::optional<std::vector<nanobind::object>>& keypath, + const std::optional<nanobind::callable>& leaf_predicate); template <typename T> nanobind::object UnflattenImpl(T leaves) const; diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 2c73947fa684..92657f9f277d 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 350 +_version = 351 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 0d92156b2530..49318e41fad3 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -26,6 +26,7 @@ from jax import tree_util from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors +from jax._src.lib import jaxlib_extension_version import jax.numpy as jnp # Easier to read.
@@ -789,9 +790,14 @@ def testKeyStr(self): def testTreeMapWithPathWithIsLeafArgument(self): x = ((1, 2), [3, 4, 5]) y = (([3], jnp.array(0)), ([0], 7, [5, 6])) - out = tree_util.tree_map_with_path( - lambda kp, *xs: (kp[0].idx, *xs), x, y, - is_leaf=lambda n: isinstance(n, list)) + if jaxlib_extension_version < 351: + out = tree_util.tree_map_with_path( + lambda kp, *xs: (kp[0].idx, *xs), x, y, + is_leaf=lambda n: isinstance(n, list)) + else: + out = tree_util.tree_map_with_path( + lambda kp, *xs: (kp[0].idx, *xs), x, y, + is_leaf=lambda _, n: isinstance(n, list), is_leaf_takes_path=True) self.assertEqual(out, (((0, 1, [3]), (0, 2, jnp.array(0))), (1, [3, 4, 5], ([0], 7, [5, 6])))) @@ -808,7 +814,13 @@ def is_empty(x): tree1 = {'a': 1, 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], 'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')} - flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty) + if jaxlib_extension_version < 351: + flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty) + else: + is_empty_new = lambda kp, x: is_empty(x) + flattened, _ = tree_util.tree_flatten_with_path( + tree1, is_empty_new, is_leaf_takes_path=True + ) strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] self.assertEqual( strs, @@ -822,6 +834,36 @@ def is_empty(x): ], ) + def testTreeFlattenWithPathWithIsLeafWithPathArgument(self): + if jaxlib_extension_version < 351: + self.skipTest("Requires jaxlib version >= 351") + x = ((1, 2), [3, {4: 4, 5: 5}]) + check_max_depth = lambda kp, _: len(kp) >= 2 + flattened, _ = tree_util.tree_flatten_with_path( + x, is_leaf=check_max_depth, is_leaf_takes_path=True + ) + self.assertEqual( + flattened, + [ + ((SequenceKey(0), SequenceKey(0),), 1), + ((SequenceKey(0), SequenceKey(1),), 2), + ((SequenceKey(1), SequenceKey(0),), 3), + ((SequenceKey(1), SequenceKey(1)), {4: 4, 5: 5}), + ], + ) + + def testTreeMapWithPathWithIsLeafWithPathArgument(self): + if jaxlib_extension_version < 351: + self.skipTest("Requires jaxlib version >= 351") + x = ((1, 2), [3, 4, 5]) + y = (([3], jnp.array(0)), ([0], 7, [5, 6])) + out = tree_util.tree_map_with_path( + lambda kp, *xs: (kp[0].idx, *xs), x, y, + is_leaf=lambda kp, n: isinstance(n, list), is_leaf_takes_path=True) + self.assertEqual(out, (((0, 1, [3]), + (0, 2, jnp.array(0))), + (1, [3, 4, 5], ([0], 7, [5, 6])))) + def testTreeFlattenWithPathBuiltin(self): x = (1, {"a": 2, "b": 3}) flattened = tree_util.tree_flatten_with_path(x) @@ -1522,9 +1564,16 @@ def test_tree_flatten_with_path(self): def test_tree_flatten_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) + if jaxlib_extension_version < 351: + self.assertEqual( + jax.tree.flatten_with_path(obj, is_leaf=is_leaf), + tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf), + ) + return + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.flatten_with_path(obj, is_leaf=is_leaf), - tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf), + jax.tree.flatten_with_path(obj, is_leaf, is_leaf_takes_path=True), + tree_util.tree_flatten_with_path(obj, is_leaf, is_leaf_takes_path=True), ) def test_tree_leaves_with_path(self): @@ -1537,9 +1586,20 @@ def test_tree_leaves_with_path(self): def test_tree_leaves_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) + if jaxlib_extension_version < 351: + self.assertEqual( + jax.tree.leaves_with_path(obj, is_leaf=is_leaf), + tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf), + ) + return + is_leaf = lambda kp, x: 
isinstance(x, tuple) self.assertEqual( - jax.tree.leaves_with_path(obj, is_leaf=is_leaf), - tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf), + jax.tree.leaves_with_path( + obj, is_leaf=is_leaf, is_leaf_takes_path=True + ), + tree_util.tree_leaves_with_path( + obj, is_leaf=is_leaf, is_leaf_takes_path=True + ), ) def test_tree_map_with_path(self): @@ -1556,9 +1616,20 @@ def test_tree_map_with_path_is_leaf(self): obj = [1, 2, (3, 4)] obj2 = [5, 6, (7, 8)] is_leaf = lambda x: isinstance(x, tuple) + if jaxlib_extension_version < 351: + self.assertEqual( + jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf), + tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf), + ) + return + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf), - tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf), + jax.tree.map_with_path( + func, obj, obj2, is_leaf=is_leaf, is_leaf_takes_path=True + ), + tree_util.tree_map_with_path( + func, obj, obj2, is_leaf=is_leaf, is_leaf_takes_path=True + ), ) From 423aafe705a7cdb6346d8d3aee1108ed13d1533a Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 10 Jun 2025 15:38:21 -0700 Subject: [PATCH 1619/1769] Pallas documentation fixes. PiperOrigin-RevId: 769829646 --- docs/pallas/pipelining.md | 2 +- docs/pallas/tpu/details.rst | 9 +++++---- docs/pallas/tpu/matmul.ipynb | 9 ++++++++- docs/pallas/tpu/matmul.md | 9 ++++++++- docs/pallas/tpu/sparse.ipynb | 12 +++++------- docs/pallas/tpu/sparse.md | 12 +++++------- 6 files changed, 32 insertions(+), 21 deletions(-) diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md index 0ff9eaf5a24b..a9fb770c9b53 100644 --- a/docs/pallas/pipelining.md +++ b/docs/pallas/pipelining.md @@ -516,7 +516,7 @@ print(result) This result is completely wrong! -There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` is initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation. +There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation. After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`. diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index a961c376f5bc..91aefd52d2e8 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -99,8 +99,8 @@ for exceptions). This unlocks some interesting capabilities: output, without any risk of race conditions. However, we do require that all invocations that write to a particular slice are consecutive. -The "consecutive" restriction on the output usually means that the some prefix -of the grid dimensions always vary the slice of the output an invocation needs +The "consecutive" restriction on the output usually means that some prefix +of the grid dimensions always varies the slice of the output an invocation needs to access, while the output window remains constant for the remaining suffix. 
For example, when implementing a Pallas TPU kernel for matrix multiplication, @@ -128,7 +128,7 @@ has no impact on performance, as the compiler is free to rearrange them. However, as Pallas is meant to expose lower-level capabilities, the dimension order can have great impact on the quality of generated code. -TPUs perform bulk of the computation on 2D vector registers, which are typically of +TPUs perform the bulk of the computation on 2D vector registers, which are typically of size 8x128 for 32-bit values (as of TPU v6). When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``), the last two dimensions of the array will be tiled into the registers. @@ -167,7 +167,8 @@ sequential grid execution guarantees, and will need to parallelize one of the grid axes over cores. This is an opt-in procedure. To allow that, ``pallas_call`` requires an extra parameter named ``dimension_semantics``: -.. +.. code:: python + pallas_call( ..., compiler_params=pltpu.CompilerParams( diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb index 3ae5f95c204a..dbe9747c4884 100644 --- a/docs/pallas/tpu/matmul.ipynb +++ b/docs/pallas/tpu/matmul.ipynb @@ -496,7 +496,14 @@ "\n", "Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e. when we are looking at the size of the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.\n", - "This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n", + "This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage.\n", + "\n", + "In addition, when tiling the matmul operation, the same values could be read multiple times from memory.\n", + "Specifically, the memory bandwidth for the first operand of the kernel is `(bm * bk)`, which needs to be multiplied by the grid dimensions, that is `(bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn`.\n", + "Similarly for the second operand, yielding a total bandwidth usage of `(m * k * n // bn + k * n * m // bm + m * n) * element_size`.\n", + "\n", + "Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance.\n", + " Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n", "\n", "The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints:\n", "\n", diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index 7ac157b4a2e9..509d47093af7 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -342,7 +342,14 @@ np.testing.assert_array_equal(x @ y, matmul(x, y)) Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e.
when we are looking at the size of the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks. -This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background. +This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. + +In addition, when tiling the matmul operation, the same values could be read multiple times from memory. +Specifically, the memory bandwidth for the first operand of the kernel is `(bm * bk)`, which needs to be multiplied by the grid dimensions, that is `(bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn`. +Similarly for the second operand, yielding a total bandwidth usage of `(m * k * n // bn + k * n * m // bm + m * n) * element_size`. + +Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. + Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background. The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints: diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index 31cfa8eeb328..6834f2d7d930 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -62,7 +62,7 @@ "source": [ "## Dynamic Block Indexing with Scalar Prefetch\n", "\n", - "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", + "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations.
The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n",
    "\n",
    "To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n",
    "\n",
@@ -253,13 +253,13 @@
   "source": [
    "## Example: Sparse @ Dense Matrix Multiplication\n",
    "\n",
-    "In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n",
+    "In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n",
    "\n",
    "We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n",
    "\n",
    "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n",
    "\n",
-    "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct."
+    "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change the output block, Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct."
   ]
  },
 {
@@ -437,7 +437,7 @@
    "\n",
    "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n",
    "\n",
-    "A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n",
+    "A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n",
    "\n",
    "The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n",
    "\n",
@@ -511,7 +511,6 @@
    " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n",
    " mask_types_finder = []\n",
    " mask_data = []\n",
-    " mask_type_idxs = []\n",
    "\n",
    " next_mask_type_idx = 0\n",
    " prefetch_mask = jnp.zeros_like(block_mask)\n",
@@ -536,7 +535,6 @@
    " next_j = j\n",
    " else:\n",
    " type_index = -1\n",
-    " mask_type_idxs.append(type_index)\n",
    " block_mask = block_mask.at[i, j].set(is_nonzero)\n",
    " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n",
    " prefetch_i = prefetch_i.at[i, j].set(next_i)\n",
@@ -665,7 +663,7 @@
    "\n",
    "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n",
    "- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n",
-    "- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger."
+    "- The pipeline bubble also accounts for a smaller percentage of the overall runtime as inputs become larger."
] }, { diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 35613acdb2c9..e9a4bb143a2f 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -51,7 +51,7 @@ print("Running on", jax.devices()[0].device_kind) ## Dynamic Block Indexing with Scalar Prefetch -We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature. +We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature. To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`: @@ -208,13 +208,13 @@ def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32) ## Example: Sparse @ Dense Matrix Multiplication -In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output. +In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output. We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram: ![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg) -It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct. +It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. 
First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change the output block, Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct.

 ```{code-cell}
 ---
@@ -353,7 +353,7 @@ print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
 
 In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).
 
-A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).
+A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).
 
 The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.
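To make the prefetch-map idea concrete, here is a minimal NumPy sketch of one way to build such maps from a block mask; the function name and the fallback behavior are illustrative assumptions, not the tutorial's actual helper. Each grid cell records the coordinates of the nearest non-skipped block at or after it in row-major order, which is what an `index_map` needs in order to fetch ahead.

```python
import numpy as np

def make_prefetch_map(block_mask):
    # For each (i, j), store the grid coordinates of the nearest nonzero
    # block at or after (i, j) in row-major order. Cells after the last
    # nonzero block fall back to (0, 0); redundantly re-fetching a block
    # is harmless since the grid ends right after them anyway.
    rows, cols = block_mask.shape
    prefetch_i = np.zeros((rows, cols), dtype=np.int32)
    prefetch_j = np.zeros((rows, cols), dtype=np.int32)
    next_i = next_j = 0
    # Walk the grid backwards so each cell already knows its successor.
    for i in reversed(range(rows)):
        for j in reversed(range(cols)):
            if block_mask[i, j]:
                next_i, next_j = i, j
            prefetch_i[i, j] = next_i
            prefetch_j[i, j] = next_j
    return prefetch_i, prefetch_j
```

An `index_map` can then return `(prefetch_i[i, j], prefetch_j[i, j])` instead of `(i, j)`, so the pipeline fetches the next block that will actually be used.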
@@ -411,7 +411,6 @@ def sparsify_mask(mask: jax.Array,
     block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
     mask_types_finder = []
     mask_data = []
-    mask_type_idxs = []
 
     next_mask_type_idx = 0
     prefetch_mask = jnp.zeros_like(block_mask)
@@ -436,7 +435,6 @@ def sparsify_mask(mask: jax.Array,
         next_j = j
       else:
         type_index = -1
-      mask_type_idxs.append(type_index)
       block_mask = block_mask.at[i, j].set(is_nonzero)
       prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
       prefetch_i = prefetch_i.at[i, j].set(next_i)
@@ -542,7 +540,7 @@ Now let's compare performance versus a naive dense implementation. On TPU v5e, w
 
 We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:
 - We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.
-- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger.
+- The pipeline bubble also accounts for a smaller percentage of the overall runtime as inputs become larger.
 
 ```{code-cell}
 ---
From ed03f383a3c00bcac39a348f7f971ec4b4062f9e Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Tue, 10 Jun 2025 16:07:48 -0700
Subject: [PATCH 1620/1769] Rollback of #29353 due to downstream failures

Reverts 56f3293cf7c6193a305d19ea74c3d44341d00351

PiperOrigin-RevId: 769842249
---
 jax/_src/interpreters/partial_eval.py |  14 +-
 jax/_src/lax/control_flow/loops.py    | 180 ++++++++++++++------------
 tests/mutable_array_test.py           |  50 -------
 3 files changed, 101 insertions(+), 143 deletions(-)

diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 83d6edc71ce0..7a1fba94bb3d 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1002,16 +1002,14 @@ def partial_eval_jaxpr_nounits(
 def partial_eval_jaxpr_nounits_fwd(
     jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
     instantiate: bool | Sequence[bool],
-    fwd: bool | Sequence[bool] = True,
 ) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]:
   instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
-  fwd = tuple(fwd) if isinstance(fwd, list) else fwd
-  return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, fwd)
+  return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, True)
 
 @weakref_lru_cache
 def _partial_eval_jaxpr_nounits(
     jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool],
-    instantiate: bool | Sequence[bool], fwd: bool | Sequence[bool]):
+    instantiate: bool | Sequence[bool], fwd: bool):
   f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)
 
   cell = []
@@ -1025,19 +1023,13 @@ def fun(*known_vals_in):
         f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals)
     jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
     out_unknowns = [not pval.is_known() for pval in out_pvals]
-    if type(fwd) is bool and not fwd:
+    if not fwd:
       residuals_ = iter(residuals)
       residuals = [next(residuals_) if f is None else known_vals_in[f]
                    for f in fwds]
       assert next(residuals_, None) is None
       fwds = [None] * len(fwds)
     else:
-      if type(fwd) is tuple:
-        fwd_ = [f for f, uk in zip(fwd, in_unknowns) if not uk]
-        residuals_, residuals = iter(residuals), []
-        fwds = 
[residuals.append(next(residuals_)) if f is None else - residuals.append(known_vals_in[f]) if not fwd_[f] else - f for f in fwds] fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals) res_avals = [core.get_aval(r) for r in residuals] cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds)) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 04100d1dec09..8857ffad7a5f 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -809,34 +809,10 @@ def _const_to_intensive_res_forwarding( tangent_jaxpr, [False] * num_nz + [i is not None for i in const_to_res]) return primal_jaxpr, tangent_jaxpr, intensive_res, new_in_fwd -def _scan_known_hoisting(jaxpr_known, const_tracers, num_res): - # To disable: - # return jaxpr_known, [], [False] * num_res, [] - consts = [pe.PartialVal.known(t.pval.get_known()) - if not isinstance(t.aval, state.AbstractRef) - else pe.PartialVal.unknown(t.aval) - for t in const_tracers if t.pval.is_known()] - others = _map(pe.PartialVal.unknown, jaxpr_known.in_avals[len(consts):]) - num_known_outs = len(jaxpr_known.out_avals) - num_res - dbg = jaxpr_known.jaxpr.debug_info - with source_info_util.reset_name_stack(): - jaxpr_known_, invar_pvals_out, known_consts = pe.trace_to_jaxpr_nounits( - lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), debug_info=dbg), - consts + others, instantiate=[True] * num_known_outs + [False] * num_res) - jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) - res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] - which_hoisted = [pval.is_known() for pval in res_pvals] - hoisted_res = [pval.get_known() for pval in res_pvals if pval.is_known()] - mut_consts = [t.pval.get_known() for t in const_tracers - if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)] - return jaxpr_known, [*known_consts, *mut_consts], which_hoisted, hoisted_res - - def _scan_partial_eval(trace, *tracers, reverse: bool, length: int, num_consts: int, num_carry: int, jaxpr: core.ClosedJaxpr, linear: Sequence[bool], unroll: int, _split_transpose: bool): - num_xs = len(jaxpr.in_avals) - num_consts - num_carry num_ys = len(jaxpr.out_avals) - num_carry unknowns = [not t.pval.is_known() for t in tracers] const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) @@ -846,14 +822,10 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # iterations, but we need one last iteration to prepare the jaxpr based on the # final carry_uk. 
carry_uk = init_uk - fwd = [(i < num_consts or i >= num_consts + num_carry) and - (not t.pval.is_known() or isinstance(t.pval.get_known(), Array)) - for i, t in enumerate(tracers)] for _ in range(1 + len(carry_uk)): unknowns = const_uk + carry_uk + xs_uk - jaxpr_known, jaxpr_unknown, out_uk, res_avals, in_fwd_res = \ - pe.partial_eval_jaxpr_nounits_fwd( - jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys, fwd=fwd) + jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits( + jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) if carry_uk_out == carry_uk: break @@ -861,64 +833,108 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" - num_res_out, num_res_in = len(res_avals), len(in_fwd_res) - num_knowns_out = len(jaxpr_known.out_avals) - num_res_out - num_consts_known = num_consts - sum(const_uk) - num_carry_known = num_carry - sum(carry_uk) + num_res = len(res_avals) del res_avals, carry_uk_out # Instantiate those inputs which must be treated as unknown from the fixpoint. - tracers = [trace.instantiate_const(t) if uk else t - for t, uk in zip(tracers, unknowns)] - # Keep original known inputs, since in_fwd_res indexes into them. - orig_inputs = [*jaxpr_known.consts, - *[t.pval.get_known() for t in tracers if t.pval.is_known()]] - - # At this point all non-forwarded residuals are treated as extensive outputs - # of jaxpr_known. Hoist out those that only depend on consts. - # Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res] - # After: jaxpr_known: [*known_consts, *known_ins] -> [*known_outs, *ext_res] - # where, modulo hoisted res not being broadcast, we have - # non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) - jaxpr_known, known_consts, which_hoisted, hoisted_res = \ - _scan_known_hoisting(jaxpr_known, tracers[:num_consts], num_res_out) - - # To make jaxpr_unknown match the scan calling convention, move to the back - # binders that don't correspond to hoisted or carry-forwarded residuals. - # Before: jaxpr_unknown: [*res, *unknown_ins] -> [*unkown_outs] - # After: jaxpr_unkonwn: [*int_res, *unknown_ins, *ext_res] -> [*unknown_outs] - num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in - which_hoisted_ = iter(which_hoisted) - res_to_move = [not next(which_hoisted_) if f is None else - f >= num_consts_known + num_carry_known for f in in_fwd_res] - jaxpr_unknown = pe.move_binders_to_back(jaxpr_unknown, - res_to_move + [False] * num_unk_in) + tracers = tuple(trace.instantiate_const(t) if uk else t + for t, uk in zip(tracers, unknowns)) + + # The residual inputs and outputs of the jaxprs produced haven't yet been + # adapted to the scan calling convention; in particular, jaxpr_known has its + # residual outputs all at the end, meaning they're extensive outputs (which is + # fully general but may be wasteful for residuals which are loop-invariant) + # while jaxpr_unknown has its corresponding residual inputs at the front (just + # as a convention with partial_eval_jaxpr_nounits), making them constant + # inputs. To make them consistent, we move the residual inputs on + # jaxpr_unknown to the end, even though we may move some back in the sequel. + jaxpr_unknown = pe.move_binders_to_back( + jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) + + # At this point, all residuals are treated as extensive outputs of jaxpr_known + # (and extensive inputs to jaxpr_unknown). 
But residuals that are loop- + # invariant can be hoisted out of the scan, rather than letting them get + # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't + # want to broadcast the matrix!). So, outside the loop we perform a partial + # evaluation with known 'const' inputs (but all other inputs unknown). + const_pvals = [pe.PartialVal.known(t.pval.get_known()) + if not isinstance(t.aval, state.AbstractRef) + else pe.PartialVal.unknown(t.aval) + for t in tracers[:num_consts] if t.pval.is_known()] + other_pvals = [pe.PartialVal.unknown(aval) + for aval in jaxpr_known.in_avals[len(const_pvals):]] + with source_info_util.reset_name_stack(): + jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( + lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), + debug_info=jaxpr_known.jaxpr.debug_info), + const_pvals + other_pvals, + instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) + jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) + # The above trace_to_jaxpr_nounits call computed loop-invariant residuals + # (known values in invar_pvals_out) and also computed loop-invariant values + # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the + # previous consts). We need to collect the computed intensive residuals, and + # move corresponding intensive residual binders in jaxpr_unknown to the front. + res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] + intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] + jaxpr_unknown = pe.move_binders_to_front( + jaxpr_unknown, + [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) + del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals + # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and + # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. + + # As another optimization, for any extensive inputs that are just forwarded to + # extensive outputs, to avoid a copy (which would be looping over + # dynamic-update-slice) we'd rather forward the input tracer/value. That means + # pruning some outputs from jaxpr_known here, and updating `out_flat` below. + fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) + # Prune fwds_known to include only extensive input to extensive output. + fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and + in_idx is not None and + in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) + else None for out_idx, in_idx in enumerate(fwds_known)] + # Drop any extensive output we can instead get by forwarding an input. + # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint + jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts + jaxpr_known_ = jaxpr_known_.replace( + outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None]) + jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) + del jaxpr_known_ + # We use `fwds_known` below when forming the output of scanning jaxpr_known. # Run the known part of the scan (if it has any outputs or effects). 
- known_ins = [t.pval.get_known() for t in tracers[num_consts:] if t.pval.is_known()] + known_mutable_consts = [t.pval.get_known() for t in tracers[:num_consts] + if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)] + known_inputs = (list(jaxpr_known_consts) + known_mutable_consts + + [t.pval.get_known() for t in tracers[num_consts:] + if t.pval.is_known()]) if not jaxpr_known.out_avals and not jaxpr_known.effects: - known_outs_ext_res = [] + out_known = [] else: - linear_known = ([False] * len(known_consts) + - [l for l, uk in zip(linear, unknowns)[num_consts:] if not uk]) - known_outs_ext_res = scan_p.bind( - *known_consts, *known_ins, jaxpr=jaxpr_known, reverse=reverse, - length=length, num_consts=len(known_consts), - num_carry=num_carry_known, linear=tuple(linear_known), unroll=unroll, + linear_known = [False] * len(known_inputs) # conservative! + out_known = scan_p.bind( + *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, + num_consts=len(jaxpr_known_consts) + len(known_mutable_consts), + num_carry=num_carry - sum(carry_uk), + linear=tuple(linear_known), unroll=unroll, _split_transpose=_split_transpose) - known_outs, ext_res = split_list(known_outs_ext_res, [num_knowns_out]) - - # Complete non_fwd_res and then res, then split to match binders. - non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) - non_fwd_res_ = iter(non_fwd_res) - res = [next(non_fwd_res_) if f is None else orig_inputs[f] for f in in_fwd_res] - assert next(non_fwd_res_, None) is None - int_res, ext_res = partition_list(res_to_move, res) + del linear_known + # Complete the known output by filling in forwarded values using fwds_known. + out_known_iter = iter(out_known) + out_known = [next(out_known_iter) if f is None + else _maybe_put(known_inputs[f]) for f in fwds_known] + assert next(out_known_iter, None) is None + del known_inputs, out_known_iter + + # Split known outputs from residuals. + out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) + assert len(intensive_res) + len(extensive_res) == num_res # Create input tracers for jaxpr_unknown bind. unknown_inputs = [t for t in tracers if not t.pval.is_known()] - int_res = _map(trace.new_instantiated_const, int_res) - ext_res = _map(trace.new_instantiated_const, ext_res) + intensive_res = _map(trace.new_instantiated_const, intensive_res) + extensive_res = _map(trace.new_instantiated_const, extensive_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) ys_avals = [core.unmapped_aval(length, 0, y_aval) @@ -927,24 +943,24 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, for a in it.chain(carry_avals, ys_avals)] del carry_avals, y_avals # Create equation. 
- linear_unknown = tuple([False] * len(int_res) + + linear_unknown = tuple([False] * len(intensive_res) + [l for l, uk in zip(linear, unknowns) if uk] + - [False] * len(ext_res)) + [False] * len(extensive_res)) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) assert len(out_tracers) == len(jaxpr_unknown.out_avals) - eqn = pe.new_eqn_recipe(trace, [*int_res, *unknown_inputs, *ext_res], + eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res], out_tracers, scan_p, dict(reverse=reverse, length=length, unroll=unroll, jaxpr=jaxpr_unknown, linear=linear_unknown, - num_consts=len(int_res) + sum(const_uk), + num_consts=len(intensive_res) + sum(const_uk), num_carry=sum(carry_uk), _split_transpose=_split_transpose), jaxpr_unknown.effects, source) for t in out_tracers: t.recipe = eqn # Merge known and unknown outputs into final result. - return util.merge_lists(out_uk, known_outs, out_tracers) + return util.merge_lists(out_uk, out_known, out_tracers) def _maybe_put(x): if isinstance(x, np.ndarray): diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 12242698d6c5..0f88ec4c95b5 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -365,56 +365,6 @@ def loss(x, y): jax.grad(loss, (0,1))(x_top, y_top) self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) - @parameterized.parameters([False, True]) - def test_custom_vjp_grad_stats_plumbing_basic(self, jit): - @jax.jit - def primal(grads_ref, x): # note: jit-abstracted! - x = jnp.sin(x) - x = stash_grads(grads_ref, x) - x = jnp.sin(x) - x = stash_grads(grads_ref, x) # ignored, order-preserved - return x - - @jax.custom_vjp - def stash_grads(grads_ref, x): - return x - def stash_grads_fwd(grads_ref, x): - return x, grads_ref - def stash_grads_bwd(grads_ref, g): - grads_ref[...] = g - return None, g - stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) - - grads_ref = core.mutable_array(jnp.float32(0.)) - jax.grad(primal, 1)(grads_ref, jnp.float32(1.0)) - self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) - - @parameterized.parameters([False, True]) - def test_custom_vjp_grad_stats_plumbing_scan(self, jit): - @jax.jit - def primal(grads_ref, x): # note: jit-abstracted! - def body(x, _): - x = jnp.sin(x) - x = stash_grads(grads_ref, x) - x = jnp.sin(x) - return x, () - x, () = jax.lax.scan(body, x, None, length=1) - return x - - @jax.custom_vjp - def stash_grads(grads_ref, x): - return x - def stash_grads_fwd(grads_ref, x): - return x, grads_ref - def stash_grads_bwd(grads_ref, g): - grads_ref[...] = g - return None, g - stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) - - grads_ref = core.mutable_array(jnp.float32(0.)) - jax.grad(primal, argnums=1)(grads_ref, jnp.float32(1.0)) - self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) - @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From 22ceb687751213138afc7cfe9537aa51413e892b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 10 Jun 2025 16:32:25 -0700 Subject: [PATCH 1621/1769] Initial commit to make unreduced + AD work. 
**TLDR Design**

Consider this example (fwd pass):

```
a: f32[8@x, 4@y]
b: f32[4@y, 6]
c: f32[8@x, 4@y]
d: f32[4@y, 6]

e: f32[8@x, 6]{U: y} = a @ b
f: f32[8@x, 6]{U: y} = c @ d
g: f32[8@x, 6]{U: y} = e + f
h: f32[8@x, 6] = reshard(g, P(x, None))
i: f32[] = jnp.sum(h)
```

The usage of unreduced minimizes the communication that we incur in the fwd pass. Instead of doing reduction twice during the 2 matmuls, we keep the result as unreduced and do reduction **only once** after addition.

The bwd pass should also minimize communication! We have 3 choices:

1) cotangent type of unreduced should be unreduced
2) cotangent type of unreduced should be replicated
3) cotangent type of unreduced should be reduced (reduced is just replicated; the only difference is that it is the cotangent type of unreduced and vice-versa) and cotangent type of replicated is replicated (so that we don't break existing use cases where replicated's cotangent type is replicated)

(2) and (3) are the only ones that don't introduce extra communication on the bwd pass. But since we can't do (2) (explained below), this change implements (3).

**But why can't we do (1)? Where does the extra communication come from?**

Let's write the above example's bwd pass, assuming the cotangent type of unreduced is unreduced:

```
# since `h_bar` is `f32[8@x, 6]{U: y}` on the fwd pass, it will become the
# dst_sharding on the bwd pass
g_bar: f32[8@x, 6]{U: y} = reshard(h_bar, P(x, None, unreduced={'y'}))

# To do this matmul, we will have to all-reduce `h_bar` first over `y` so that it's
# replicated.
a_bar: f32[8@x, 4@y] = g_bar @ b.T
```

As you can see from the above example, we would have to all-reduce 4 times, i.e. once per matmul, if we take a grad against all the inputs `a, b, c, d`. These extra all-reductions can be avoided if the cotangent type of unreduced is `reduced`! In that case we will only do reduction once during the `reshard` op.

**Show me how the bwd pass is efficient if the cotangent type of unreduced is reduced**

```
ones: f32[] = lax_internal._one(ans)

# replicated remains as replicated since they are cotangent types of each other.
h_bar: f32[8@x, 6] = broadcast_in_dim(ones, ...)

# since `x_tangent` is `f32[8@x, 6]{U: y}` on the fwd pass, it will become the
# dst_sharding on the bwd pass but since the cotangent type of unreduced is reduced,
# let's do that here.
xt_s = x_tangent.sharding.spec
g_bar: f32[8@x, 6] = reshard(h_bar, P(x, None, reduced=xt_s.unreduced, unreduced=xt_s.reduced))

# Since `reduced` is just `replicated`, we don't need to all-reduce here.
a_bar: f32[8@x, 4@y] = g_bar @ b.T
```

As you can see, we don't incur any extra communication this way and the backward pass is optimal!

**Why can't we do (2)?**

Because it leads to inefficient communication on the backward pass. Let's write the backward pass for the above example assuming the cotangent type of unreduced is replicated and vice-versa:

```
# Replicated will transpose to unreduced here!
ones: f32[]{U:(x,y)} = lax_internal._ones()

# i_bar.aval: f32[8@x, 6]{R:y} which will become `f32[8@x, 6]{U: y}`.
# This will result in a reduce-scatter from `ones` to `h_bar`.
h_bar: f32[8@x, 6]{U: y} = broadcast_in_dim(ones, ...)

# So f32[8@x, 6]{U: y} in the forward pass becomes f32[8@x, 6]{R: y}.
# This will result in an all-reduce.
g_bar: f32[8@x, 6] = reshard(h_bar, P(x, None))

a_bar: f32[8@x, 4@y] = g_bar @ b.T
```

As you can see, we just incurred 2 comms (reduce-scatter + all-reduce) on a constant! These were completely unnecessary! The sketch below makes the option (3) cotangent rule concrete.
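The option (3) rule is just a swap of the spec's `unreduced` and `reduced` sets, leaving the partitioned dims untouched. A minimal sketch, assuming the `reduced=` keyword that this patch adds to `PartitionSpec` (the real helper added below is `primal_sharding_to_cotangent_sharding` in `jax/_src/core.py`):

```python
from jax.sharding import PartitionSpec as P

def cotangent_spec(spec):
    # Swap unreduced <-> reduced; partitioned dims are unchanged. A fully
    # replicated spec (both sets empty) maps to itself, which is what keeps
    # existing code that expects replicated cotangents working.
    return P(*spec, unreduced=spec.reduced, reduced=spec.unreduced)

# f32[8@x, 6]{U: y} gets a cotangent of type f32[8@x, 6]{R: y}, and vice versa:
assert cotangent_spec(P('x', None, unreduced={'y'})) == P('x', None, reduced={'y'})
```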
We can try to optimize this pattern in the backward pass by introducing symbolic ones but that makes the backward pass more complicated and fragile. Option (3) removes the extra communication without incurring implementation complexity in the backward pass. Some TODOs: * more testing for unreduced * Propagate reduced through sharding rules properly in fwd and bwd pass. For example, the transpose rule of `dot_general` should do `x.aval.to_cotangent_aval().sharding`. * more testing for reduced in fwd which becomes unreduced in bwd * shard_map + unreduced + reduced support Co-authored-by: Matthew Johnson PiperOrigin-RevId: 769852286 --- jax/_src/api.py | 3 +- jax/_src/core.py | 10 +++++ jax/_src/lax/lax.py | 6 +-- jax/_src/named_sharding.py | 11 +++++ jax/_src/partition_spec.py | 85 +++++++++++++++++++++++++++---------- jax/_src/pjit.py | 2 +- jaxlib/partition_spec.cc | 86 ++++++++++++++++++++++++++++---------- jaxlib/partition_spec.h | 5 ++- jaxlib/xla_client.py | 2 +- tests/array_test.py | 5 +++ tests/pjit_test.py | 13 +++++- 11 files changed, 172 insertions(+), 56 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 1505274ca398..a93d93c57e70 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -509,8 +509,7 @@ def value_and_grad_f(*args, **kwargs): if not has_aux: ans, vjp_py = _vjp(f_partial, *dyn_args) else: - ans, vjp_py, aux = _vjp( - f_partial, *dyn_args, has_aux=True) + ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True) _check_scalar(ans) tree_map(partial(_check_output_dtype_grad, holomorphic), ans) g = vjp_py(lax_internal._one(ans)) diff --git a/jax/_src/core.py b/jax/_src/core.py index 96bab24e1258..15c431599fed 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2093,6 +2093,12 @@ def to_tangent_aval(self): self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type, sharding=self.sharding, vma=self.vma) + def to_cotangent_aval(self): + dtype = primal_dtype_to_tangent_dtype(self.dtype) + sharding = primal_sharding_to_cotangent_sharding(self.sharding) + return ShapedArray( + self.shape, dtype, self.weak_type, sharding=sharding, vma=self.vma) + def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, @@ -2141,6 +2147,10 @@ def primal_dtype_to_tangent_dtype(primal_dtype): else: return primal_dtype +def primal_sharding_to_cotangent_sharding(sharding): + new_spec = P(*sharding.spec, unreduced=sharding.spec.reduced, + reduced=sharding.spec.unreduced) + return sharding.with_spec(new_spec) def pvary(x, axis_name): if not axis_name: diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b0f17129bc77..8c271ebe3c15 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3358,9 +3358,7 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) out = broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) - out = core.pvary(out, tuple(aval.vma)) - return out - + return core.pvary(out, tuple(aval.vma)) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray: @@ -4594,7 +4592,7 @@ def _add_unreduced(out_sharding, x, y): ' not allow this because there will be implicit communication. 
Please' f' reduce {lhs_str} via `reshard` before calling `add`.') else: - res_unreduced = None + res_unreduced = frozenset() return out_sharding.with_spec(out_sharding.spec.with_unreduced(res_unreduced)) add_p: Primitive = naryop(_input_dtype, [_num, _num], 'add', diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 13124d4a36aa..b97a073cd2a7 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -513,3 +513,14 @@ def _check_mesh_unreduced(mesh, pspec): 'Unreduced axes can only refer to mesh axes that is of type' f' `Explicit`. Got unreduced axes: {pspec.unreduced} and' f' mesh: {mesh}') + + for u in pspec.reduced: + if u not in mesh.axis_names: + raise ValueError( + f'Reduced axes {u} is not found in {mesh.axis_names=}. ' + f'Got {pspec=}') + if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual): + raise ValueError( + 'Reduced axes can only refer to mesh axes that is of type' + f' `Explicit`. Got reduced axes: {pspec.reduced} and' + f' mesh: {mesh}') diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 435a1cc50669..b5bd2aecc20a 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Set from typing import Any, TYPE_CHECKING from jax._src.lib import jaxlib_extension_version @@ -24,7 +23,7 @@ # TODO(phawkins): the union confuses pytype. Just use the Python branch for now # until the C++ version is the minimum version. -if not TYPE_CHECKING and jaxlib_extension_version >= 349: +if not TYPE_CHECKING and jaxlib_extension_version >= 352: _UNCONSTRAINED_PARTITION = _jax.UNCONSTRAINED_PARTITION _canonicalize_partition = _jax.canonicalize_partition else: @@ -56,7 +55,7 @@ def _canonicalize_partition(partition): return tuple(partition) return partition - def _check(partitions, unreduced): + def _check(partitions, unreduced, reduced): for p in partitions: p = p if isinstance(p, tuple) else (p,) for r in p: @@ -65,13 +64,35 @@ def _check(partitions, unreduced): "partitions cannot overlap with unreduced axes passed to" f" PartitionSpec. Got partitions: {partitions} and unreduced axes:" f" {unreduced}") + if r in reduced: + raise ValueError( + "partitions cannot overlap with reduced axes passed to" + f" PartitionSpec. Got partitions: {partitions} and reduced axes:" + f" {reduced}") + if unreduced & reduced: + raise ValueError( + "`unreduced` and `reduced` argument to PartitionSpec cannot overlap. " + f"Got {unreduced=}, {reduced=}") if None in unreduced: raise ValueError( "unreduced cannot contain None. All elements in unreduced should refer" " to the mesh axes.") + if None in reduced: + raise ValueError( + "reduced cannot contain None. 
All elements in reduced should refer" + " to the mesh axes.") + +def unpickle_pspec(partitions, unreduced, reduced): + return PartitionSpec(*partitions, unreduced=unreduced, reduced=reduced) -def unpickle_pspec(partitions, unreduced): - return PartitionSpec(*partitions, unreduced=unreduced) +def _get_ur_str(unreduced, reduced): + if unreduced and reduced: + return f"unreduced={set(unreduced)!r}, reduced={set(reduced)!r}" + elif unreduced and not reduced: + return f"unreduced={set(unreduced)!r}" + elif not unreduced and reduced: + return f"reduced={set(reduced)!r}" + assert False # unreachable AxisName = Any @@ -84,34 +105,38 @@ class PartitionSpec: This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ - if jaxlib_extension_version < 349: - __slots__ = ("_partitions", "unreduced") + if jaxlib_extension_version < 352: + __slots__ = ("_partitions", "unreduced", "reduced") __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. UNCONSTRAINED = _UNCONSTRAINED_PARTITION @use_cpp_method() - def __init__(self, *partitions, - unreduced: Set[AxisName] | None = None): + def __init__(self, *partitions, unreduced=frozenset(), reduced=frozenset()): self._partitions = tuple(_canonicalize_partition(p) for p in partitions) - if unreduced is not None and not isinstance(unreduced, (set, frozenset)): + if not isinstance(unreduced, (set, frozenset)): raise TypeError( - "`unreduced` argument of PartitionSpec should be `None` or of type" + "`unreduced` argument of PartitionSpec should be of type" f" `frozenset` or `set`. Got type {type(unreduced)}") - self.unreduced = frozenset() if unreduced is None else frozenset(unreduced) - _check(self._partitions, self.unreduced) + if not isinstance(reduced, (set, frozenset)): + raise TypeError( + "`reduced` argument of PartitionSpec should be of type" + f" `frozenset` or `set`. 
Got type {type(reduced)}") + self.unreduced = frozenset(unreduced) + self.reduced = frozenset(reduced) + _check(self._partitions, self.unreduced, self.reduced) def __repr__(self): pr = repr(self._partitions)[1:-1] - if not self.unreduced: + if not self.unreduced and not self.reduced: return f"PartitionSpec({pr})" - ur_str = f"unreduced={set(self.unreduced)!r}" + ur_str = _get_ur_str(self.unreduced, self.reduced) pr = '' if not pr else f"{pr} " if pr.endswith(',') else f"{pr}, " return (f"PartitionSpec({pr}{ur_str})") def __reduce__(self): - return (unpickle_pspec, (self._partitions, self.unreduced)) + return (unpickle_pspec, (self._partitions, self.unreduced, self.reduced)) def __getitem__(self, i): return self._partitions[i] @@ -126,12 +151,17 @@ def __len__(self): def __eq__(self, other): if isinstance(other, PartitionSpec): return (self._partitions == other._partitions and - self.unreduced == other.unreduced) + self.unreduced == other.unreduced and + self.reduced == other.reduced) elif isinstance(other, tuple): if self.unreduced: raise TypeError( f"other {other} cannot be of instance `tuple` when self {self} has" " unreduced in `__eq__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__eq__` of PartitionSpec.") other_p = tuple(_canonicalize_partition(o) for o in other) return self._partitions == other_p else: @@ -139,18 +169,23 @@ def __eq__(self, other): @use_cpp_method() def __hash__(self): - return hash((self._partitions, self.unreduced)) + return hash((self._partitions, self.unreduced, self.reduced)) def __add__(self, other): if isinstance(other, PartitionSpec): return PartitionSpec( *self, *other, - unreduced={*self.unreduced, *other.unreduced}) + unreduced={*self.unreduced, *other.unreduced}, + reduced={*self.reduced, *other.reduced}) elif isinstance(other, tuple): if self.unreduced: raise TypeError( f"other {other} cannot be of instance `tuple` when self {self} has" " unreduced in `__add__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__add__` of PartitionSpec.") return PartitionSpec(*self, *other) else: raise NotImplementedError @@ -163,6 +198,10 @@ def __radd__(self, other): raise TypeError( f"other {other} cannot be of instance `tuple` when self {self} has" " unreduced in `__radd__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__radd__` of PartitionSpec.") return PartitionSpec(*other, *self) def index(self, value): @@ -172,10 +211,12 @@ def count(self, value): return self._partitions.count(_canonicalize_partition(value)) def with_partitions(self, new_partitions): - return PartitionSpec(*new_partitions, unreduced=self.unreduced) + return PartitionSpec(*new_partitions, unreduced=self.unreduced, + reduced=self.reduced) def with_unreduced(self, new_unreduced): - return PartitionSpec(*self._partitions, unreduced=new_unreduced) + return PartitionSpec(*self._partitions, unreduced=new_unreduced, + reduced=self.reduced) def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: out = [None if p is _UNCONSTRAINED_PARTITION else p @@ -185,7 +226,7 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: return self.with_partitions(out) # TODO(phawkins): make this a decorator after the next jaxlib release. 
-if not TYPE_CHECKING and jaxlib_extension_version >= 349:
+if not TYPE_CHECKING and jaxlib_extension_version >= 352:
   PartitionSpec = use_cpp_class(_jax.PartitionSpec)(PartitionSpec)
 
 # TODO(phawkins): make this a decorator after the next jaxlib release.
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 4b085bf124cb..56252298cd01 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -3014,7 +3014,7 @@ def _reshard_impl(x, dst_sharding):
 reshard_p.def_impl(_reshard_impl)
 
 def _reshard_transpose_rule(ct, x, dst_sharding):
-  return [reshard_p.bind(ct, dst_sharding=x.aval.sharding)]
+  return [reshard_p.bind(ct, dst_sharding=x.aval.to_cotangent_aval().sharding)]
 ad.deflinear2(reshard_p, _reshard_transpose_rule)
 
 def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
diff --git a/jaxlib/partition_spec.cc b/jaxlib/partition_spec.cc
index 3cb58f110c4d..e43ca4b6108c 100644
--- a/jaxlib/partition_spec.cc
+++ b/jaxlib/partition_spec.cc
@@ -79,12 +79,18 @@ nb::object CanonicalizePartition(nb::object unconstrained_singleton,
   return partition;
 }
 
-void CheckPartitionSpec(nb::tuple partitions, nb_frozenset unreduced) {
+void CheckPartitionSpec(nb::tuple partitions, nb_frozenset unreduced,
+                        nb_frozenset reduced) {
   if (unreduced.contains(nb::none())) {
     throw nb::value_error(
         "unreduced cannot contain None. All elements in unreduced should "
         "refer to the mesh axes.");
   }
+  if (reduced.contains(nb::none())) {
+    throw nb::value_error(
+        "reduced cannot contain None. All elements in reduced should "
+        "refer to the mesh axes.");
+  }
   auto check_overlap = [&](nb::handle partition) {
     if (unreduced.contains(partition)) {
       throw nb::value_error(
           absl::StrFormat(
               "partitions cannot overlap with unreduced axes passed to "
               "PartitionSpec. Got partitions: %s and unreduced axes: %s",
               nb::cast<std::string_view>(nb::str(partitions)),
               nb::cast<std::string_view>(nb::str(unreduced)))
               .c_str());
     }
+    if (reduced.contains(partition)) {
+      throw nb::value_error(
+          absl::StrFormat(
+              "partitions cannot overlap with reduced axes passed to "
+              "PartitionSpec. Got partitions: %s and reduced axes: %s",
+              nb::cast<std::string_view>(nb::str(partitions)),
+              nb::cast<std::string_view>(nb::str(reduced)))
+              .c_str());
+    }
   };
   for (nb::handle partition : partitions) {
     if (nb::isinstance<nb::tuple>(partition)) {
@@ -105,15 +120,30 @@
       check_overlap(partition);
     }
   }
+  // TODO(yashkatariya, phawkins): Update this to `!(unreduced &
+  // reduced).empty()` after nanobind's version > 2.7.0
+  if (nb::len((unreduced & reduced)) != 0) {
+    throw nb::value_error(
+        absl::StrFormat("`unreduced` and `reduced` argument to PartitionSpec "
+                        "cannot overlap. "
+                        "Got unreduced: %s and reduced: %s",
+                        nb::cast<std::string_view>(nb::str(unreduced)),
+                        nb::cast<std::string_view>(nb::str(reduced)))
+            .c_str());
+  }
 }
 
 }  // namespace
 
-PartitionSpec::PartitionSpec(nb::tuple partitions, nb_frozenset unreduced)
-    : partitions_(std::move(partitions)), unreduced_(std::move(unreduced)) {}
+PartitionSpec::PartitionSpec(nb::tuple partitions, nb_frozenset unreduced,
+                             nb_frozenset reduced)
+    : partitions_(std::move(partitions)),
+      unreduced_(std::move(unreduced)),
+      reduced_(std::move(reduced)) {}
 
 Py_ssize_t PartitionSpec::Hash() const {
-  size_t h = absl::HashOf(nb::hash(partitions_), nb::hash(unreduced_));
+  size_t h = absl::HashOf(nb::hash(partitions_), nb::hash(unreduced_),
+                          nb::hash(reduced_));
   Py_hash_t s = absl::bit_cast<Py_hash_t>(h);  // Python hashes are signed.
   return s == -1 ? -2 : s;  // -1 must not be used as a Python hash value.
}

@@ -125,11 +155,13 @@ bool PartitionSpec::Eq(const nb::object& other) const {
   PartitionSpec* other_spec;
   if (nb::try_cast<PartitionSpec*>(other, other_spec)) {
     return partitions().equal(other_spec->partitions()) &&
-           unreduced().equal(other_spec->unreduced());
+           unreduced().equal(other_spec->unreduced()) &&
+           reduced().equal(other_spec->reduced());
   }
   nb::tuple other_tuple;
   if (nb::try_cast<nb::tuple>(other, other_tuple)) {
-    if (unreduced().size() > 0 || partitions().size() != other_tuple.size()) {
+    if (unreduced().size() > 0 || reduced().size() > 0 ||
+        partitions().size() != other_tuple.size()) {
       return false;
     }
     for (size_t i = 0; i < partitions().size(); ++i) {
@@ -162,7 +194,7 @@ void PartitionSpec::Register(nb::module_& m) {
       .def(
           "__init__",
           [](PartitionSpec* self, nb::args partition_args,
-             nb::object unreduced_arg) {
+             nb::object unreduced_arg, nb::object reduced_arg) {
            nb::tuple partitions =
                nb::steal<nb::tuple>(PyTuple_New(partition_args.size()));
            for (size_t i = 0; i < partition_args.size(); ++i) {
@@ -174,26 +206,34 @@
                      .ptr());
            }
            nb_frozenset unreduced;
-            if (unreduced_arg.is_none()) {
-              unreduced = nb_frozenset();
-            } else {
-              if (!PyAnySet_Check(unreduced_arg.ptr())) {
-                throw nb::type_error(
-                    absl::StrFormat(
-                        "unreduced argument of PartitionSpec should be `None` "
-                        "or of type `frozenset` or `set`. Got type %s",
-                        nb::cast<std::string_view>(nb::repr(unreduced_arg.type())))
-                        .c_str());
-              }
-              unreduced = nb_frozenset(unreduced_arg);
+            nb_frozenset reduced;
+            if (!PyAnySet_Check(unreduced_arg.ptr())) {
+              throw nb::type_error(
+                  absl::StrFormat(
+                      "unreduced argument of PartitionSpec should be "
+                      "of type `frozenset` or `set`. Got type %s",
+                      nb::cast<std::string_view>(nb::repr(unreduced_arg.type())))
+                      .c_str());
+            }
+            if (!PyAnySet_Check(reduced_arg.ptr())) {
+              throw nb::type_error(
+                  absl::StrFormat(
+                      "reduced argument of PartitionSpec should be "
+                      "of type `frozenset` or `set`. Got type %s",
                      nb::cast<std::string_view>(nb::repr(reduced_arg.type())))
                      .c_str());
            }
-            CheckPartitionSpec(partitions, unreduced);
-            new (self)
-                PartitionSpec(std::move(partitions), std::move(unreduced));
+            unreduced = nb_frozenset(unreduced_arg);
+            reduced = nb_frozenset(reduced_arg);
+            CheckPartitionSpec(partitions, unreduced, reduced);
+            new (self) PartitionSpec(std::move(partitions),
+                                     std::move(unreduced), std::move(reduced));
          },
-          nb::arg("partitions"), nb::arg("unreduced").none() = nb::none())
+          nb::arg("partitions"), nb::arg("unreduced") = nb_frozenset(),
+          nb::arg("reduced") = nb_frozenset())
       .def_prop_ro("_partitions", &PartitionSpec::partitions)
       .def_prop_ro("unreduced", &PartitionSpec::unreduced)
+      .def_prop_ro("reduced", &PartitionSpec::reduced)
       .def("__eq__", &PartitionSpec::Eq, nb::arg().none())
       .def("__hash__", &PartitionSpec::Hash);
 }
diff --git a/jaxlib/partition_spec.h b/jaxlib/partition_spec.h
index 45f577557926..62c292a0c966 100644
--- a/jaxlib/partition_spec.h
+++ b/jaxlib/partition_spec.h
@@ -40,10 +40,12 @@ class nb_frozenset : public nanobind::object {
 
 class PartitionSpec {
  public:
-  PartitionSpec(nanobind::tuple partitions, nb_frozenset unreduced);
+  PartitionSpec(nanobind::tuple partitions, nb_frozenset unreduced,
+                nb_frozenset reduced);
 
   nanobind::tuple partitions() const { return partitions_; }
   nb_frozenset unreduced() const { return unreduced_; }
+  nb_frozenset reduced() const { return reduced_; }
 
   bool Eq(const nanobind::object& other) const;
   Py_ssize_t Hash() const;
@@ -53,6 +55,7 @@ class PartitionSpec {
  private:
   nanobind::tuple partitions_;
   nb_frozenset unreduced_;
+  nb_frozenset reduced_;
 
   static nanobind::object* unconstrained_singleton_;
 };
diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py
index 92657f9f277d..911385f398bc 100644
--- a/jaxlib/xla_client.py
+++ b/jaxlib/xla_client.py
@@ -43,7 +43,7 @@
 
 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
-_version = 351
+_version = 352
 
 # An internal increasing version number for protecting jaxlib code against
 # ifrt changes.
diff --git a/tests/array_test.py b/tests/array_test.py
index 4403755a1dff..0814d5888016 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -1412,6 +1412,11 @@ def test_memory_kind_with_abstract_mesh(self):
       NamedSharding(abstract_mesh, P(), memory_kind='weird_device')
 
   def test_pspec_unreduced(self):
+    pspec = P('a', 'b', None, unreduced={'c'}, reduced={'d'})
+    self.assertEqual(
+        repr(pspec),
+        "PartitionSpec('a', 'b', None, unreduced={'c'}, reduced={'d'})")
+
     pspec1 = P('a', 'b', None, unreduced={'c'})
     self.assertEqual(repr(pspec1),
                      "PartitionSpec('a', 'b', None, unreduced={'c'})")
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index f2344f7b3ce1..35f9b6cda7e5 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -7789,7 +7789,7 @@ def f(x):
   @config.use_shardy_partitioner(True)
   @jtu.with_explicit_mesh((2, 2), ('x', 'y'))
   def test_unreduced_basic(self, mesh):
-    np_inp = np.arange(16).reshape(8, 2)
+    np_inp = np.arange(16.).reshape(8, 2)
     x = jax.device_put(np_inp, P('x', 'y'))
     y = jax.device_put(np_inp.T, P('y', None))
     a = jax.device_put(np_inp, P('x', 'y'))
@@ -7813,7 +7813,16 @@ def f(x, y, a, b):
     traced = f.trace(x, y, a, b)
     lowered_text = traced.lower().as_text()
     self.assertIn('unreduced={"y"}', lowered_text)
-    self.assertTrue(lowered_text.count('unreduced={"y"}') == 3)
+    self.assertEqual(lowered_text.count('unreduced={"y"}'), 3)
+
+    # TODO(yashkatariya): Execute this too
+    grad_jaxpr = jax.jit(jax.grad(lambda x, y, a, b: f(x, y, a, b).sum(),
+                                  argnums=(0, 1, 2, 3))).trace(x, y, a, b).jaxpr
+    reshard_eqn = grad_jaxpr.eqns[4].params['jaxpr'].eqns[0]
+    self.assertEqual(reshard_eqn.params['dst_sharding'].spec.reduced,
+                     frozenset('y'))
+    self.assertEqual(reshard_eqn.params['dst_sharding'].spec.unreduced,
+                     frozenset())
 
   @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z'))
   def test_dot_general_unreduced_error(self, mesh):
From d9e0244c23fc42e3a26e1d1dff613153c8ad4506 Mon Sep 17 00:00:00 2001
From: Hyeontaek Lim
Date: Tue, 10 Jun 2025 16:37:48 -0700
Subject: [PATCH 1622/1769] [JAX] Move the fallback of `colocated_cpu_devices`
 logic from the colocated Python test to the API

The colocated Python test defines fallback logic for `colocated_cpu_devices`,
which finds CPU devices from a local CPU backend instead of the default
backend. This works for typical setups of McJAX, and is intended to be
temporary until McJAX backends define their own CPU devices.

This change moves the fallback logic from the test into the colocated Python
API and makes the `colocated_cpu_devices` API invoke it when necessary. The
fallback logic was inaccessible from the API, which made it difficult to write
another high-level API for colocated Python such as `colocated_cpu_mesh`
(which will follow in a subsequent change) and have it tested on McJAX.
PiperOrigin-RevId: 769854488 --- jax/experimental/colocated_python/api.py | 24 ++++++++++- tests/colocated_python_test.py | 52 ++++++++---------------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index 81db9b965e7c..45cd9e47e15a 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -20,6 +20,7 @@ import jax from jax._src import api_util +from jax._src import util from jax.experimental.colocated_python.func import make_callable from jax.experimental.colocated_python.obj import wrap_class @@ -30,10 +31,13 @@ def colocated_cpu_devices( """Finds CPU devices colocated with the given devices.""" if not isinstance(devices, tuple): devices = tuple(devices) - return _colocated_cpu_devices_cached(devices) + try: + return _colocated_cpu_devices_cached(devices) + except (ValueError, AttributeError): + return _colocated_cpu_devices_cached_fallback_to_cpu_backend(devices) -@jax._src.util.cache(max_size=1024, trace_context_in_key=False) +@util.cache(max_size=1024, trace_context_in_key=False) def _colocated_cpu_devices_cached( devices: tuple[jax.Device, ...], ) -> Sequence[jax.Device]: @@ -58,6 +62,22 @@ def _colocated_cpu_devices_cached( return colocated_cpu_devices +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_devices_cached_fallback_to_cpu_backend( + devices: tuple[jax.Device, ...], +) -> Sequence[jax.Device]: + # PjRt-IFRT currently defines CPU devices by using a CPU backend. + # TODO(hyeontaek): Remove this fallback path once a PjRt-IFRT backend defines + # CPU devices by its own instead of using a separate CPU backend. + cpu_backend_devices = jax.local_devices(backend="cpu") + device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + + available_devices = devices[: min(len(cpu_backend_devices), len(devices))] + return [ + cpu_backend_devices[device_index_map[d.id]] for d in available_devices + ] + + def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]: """Executes the given Python function on the same devices as the arguments.""" return make_callable( diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index ada17bc61c82..bff745eeba9d 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -17,7 +17,6 @@ import tempfile import threading import time -from typing import Sequence from absl.testing import absltest from absl.testing import parameterized @@ -40,25 +39,6 @@ HAS_CLOUDPICKLE = False -def _colocated_cpu_devices( - devices: Sequence[jax.Device], -) -> Sequence[jax.Device]: - """Returns CPU devices colocated with the given devices.""" - try: - return colocated_python.colocated_cpu_devices(devices) - except (ValueError, AttributeError): - # PjRt-IFRT prepares CPU devices by its own. - # TODO(hyeontaek): Remove this fallback path once PjRt-IFRT prepares CPU - # devices by its own. 
- cpu_backend_devices = jax.local_devices(backend="cpu") - device_index_map = {device.id: i for i, device in enumerate(jax.devices())} - - available_devices = devices[: min(len(cpu_backend_devices), len(devices))] - return [ - cpu_backend_devices[device_index_map[d.id]] for d in available_devices - ] - - _count_colocated_python_specialization_cache_miss = jtu.count_events( "colocated_python_func._get_specialized_func" ) @@ -82,7 +62,7 @@ def testMakeColocatedPythonProgram(self): def add_one(x): return x + 1 - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) sds = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) @@ -97,7 +77,7 @@ def testSimpleFunction(self): def add_one(x): return x + 1 - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -117,7 +97,7 @@ def testSimpleFunctionWithTree(self): def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = [np.array(1), (np.array(2), {"v": np.array(3)})] x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) @@ -149,7 +129,7 @@ def testEmptyInputWithDevicesSpecialization(self): def make_zero(): return jnp.array(0) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) with _count_colocated_python_specialization_cache_miss() as count: make_zero = make_zero.specialize(devices=cpu_devices[:1]) @@ -168,7 +148,7 @@ def testInputPolymorphismWithoutOutSpecsFn(self): def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -202,7 +182,7 @@ def testInputPolymorphismAllowedWithOutSpecsFn(self): def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -237,7 +217,7 @@ def add_one(x): ("on_non_main_thread", False), ) def testSequentialExecution(self, on_main_thread: bool): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) # Make sure that this input array is ready for use by the colocated Python @@ -274,7 +254,7 @@ def sleep_twice_and_wait(x: jax.Array) -> None: self.assertGreaterEqual(elapsed_time, 10) def testConcurrentExecution(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) # Make sure that this input array is ready for use by the colocated Python @@ -311,7 +291,7 @@ def sleep_and_wait(x: jax.Array) -> None: self.assertLess(elapsed_time, 10) def testInputsWithDifferentDeviceOrders(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices())[:2] + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())[:2] if len(cpu_devices) < 2: self.skipTest("Not 
enough CPU devices") @@ -376,7 +356,7 @@ def get_global_state(x: jax.Array) -> jax.Array: del x return colocated_python._testing_global_state - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) y = np.array(2) @@ -394,7 +374,7 @@ def get_global_state(x: jax.Array) -> jax.Array: del colocated_python._testing_global_state def testStringProcessing(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") @@ -435,7 +415,7 @@ def f(x): ) def testBinaryDataProcessing(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 1: self.skipTest("Need at least one CPU devices") @@ -477,7 +457,7 @@ def f(x): self.assertEqual(out_ints[1], 1003) def testDetectInvalidMeshDevice(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if jax.local_devices()[0].id == cpu_devices[0].id: self.skipTest( "This test only works in a setup where accelerator and CPU devices" @@ -498,7 +478,7 @@ def make_zero() -> jax.Array: jax.block_until_ready(make_zero()) def testObjectLifecycle(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) @colocated_python.colocated_python_class @@ -570,7 +550,7 @@ def cleanup(): cleanup() def testStatefulObject(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) @colocated_python.colocated_python_class class Value: @@ -602,7 +582,7 @@ def fetch(self, x: jax.Array) -> jax.Array: self.assertEqual(out, np.array(7)) def testObjectWithCapturedSharding(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") From fd6d90a7a6cbd30d5b68f8c485f6164d7f8dbec5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 10 Jun 2025 17:25:16 -0700 Subject: [PATCH 1623/1769] Add basic mutable array tests with AOT PiperOrigin-RevId: 769870948 --- tests/mutable_array_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 0f88ec4c95b5..ce35424cb418 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -29,6 +29,9 @@ config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + + class MutableArrayTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) @@ -50,6 +53,35 @@ def f(x_mut): jaxpr = jax.make_jaxpr(f)(x_mut) self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) + def test_basic_aot(self): + @jax.jit + def f(x_mut): + x_mut[...] += 1. 
+ x_mut[0] += 1 + x_mut[1] += 5 + + x_mut = core.mutable_array(jnp.zeros(3)) + f.lower(x_mut).compile()(x_mut) + self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), + check_dtypes=False) + + def test_basic_sharded_aot(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x_mut): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + x_mut = core.mutable_array(arr) + f.lower(x_mut).compile()(x_mut) + expected = np.arange(8.) + 1 + expected[0] += 1 + expected[1] += 5 + self.assertAllClose(x_mut[...], expected) + @parameterized.parameters([True, False]) def test_multiple_inputs_and_outputs(self, jit): def f(x_mut, y, z_mut, w): From f211c6b970fbf94192ac273c82a82a0b77d44baf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 10 Jun 2025 18:17:06 -0700 Subject: [PATCH 1624/1769] Save a jaxpr equation in pl.cdiv if the rhs is an int. PiperOrigin-RevId: 769888714 --- jax/_src/pallas/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index a78c5487a4d6..15844da927e3 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -44,7 +44,7 @@ def cdiv(a: jax.Array, b: jax.Array) -> jax.Array: def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array: if isinstance(a, int) and isinstance(b, int): return (a + b - 1) // b - return lax.div(a + b - 1, b) + return lax.div(a + (b - 1), b) def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: From cd4a0c6e32c84892662ef7003983fffc5c3d140b Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 10 Jun 2025 23:58:22 -0700 Subject: [PATCH 1625/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/870d90fd098c480fb8a426126bd02047adb2bc20. PiperOrigin-RevId: 769995600 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a71ca8c6ff32..ca58bb6f0250 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "28c9c6fbcc5a63f9d05fd4d65ee1e509bb313e85" -XLA_SHA256 = "b7638d60823c51700ab28451e32e440e8d9856f81b7c6838cc72f8a575fdba83" +XLA_COMMIT = "870d90fd098c480fb8a426126bd02047adb2bc20" +XLA_SHA256 = "963b285bbc6f40a198833a14effc4f38f75b9c5d1813ccef4f09c287d0cb9ae4" def repo(): tf_http_archive( From 02cdc7ba9670d2296e3da626f3ae3deb23be0811 Mon Sep 17 00:00:00 2001 From: Axel Stjerngren Date: Wed, 11 Jun 2025 01:00:52 -0700 Subject: [PATCH 1626/1769] Ensure that all attributes are restored after pickling in `NamedSharding`. Colocated Python in Pathways relies on the `memory_kind` being dropped, so explicitly update the serialization behavior to not include the `memory_kind`.
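Background sketch of the `__reduce__`-via-module-function pattern the diff below adopts (`Thing` and `_unpickle_thing` are illustrative names, not JAX classes): the args tuple returned from `__reduce__` is positional-only, so keyword-only constructor arguments such as a memory kind need a top-level helper to re-apply them through `__init__`.

import pickle

def _unpickle_thing(value, kind):
  # Re-route the positional pickle args back into keyword-only parameters.
  return Thing(value, kind=kind)

class Thing:
  def __init__(self, value, *, kind=None):
    self.value = value
    self.kind = kind

  def __reduce__(self):
    # A picklable module-level callable plus positional args; bound methods
    # and keyword arguments cannot be expressed directly here.
    return (_unpickle_thing, (self.value, self.kind))

t = pickle.loads(pickle.dumps(Thing(3, kind="device")))
assert t.kind == "device"  # the keyword-only attribute survives the round trip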
PiperOrigin-RevId: 770016243 --- jax/_src/named_sharding.py | 10 +++++++--- .../colocated_python/serialization.py | 15 +++++++++++++++ tests/pickle_test.py | 17 ++++++++++++++++- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index b97a073cd2a7..bae999e1e83d 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -74,6 +74,11 @@ def __repr__(self): ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue] +def _unpickle_named_sharding(mesh, spec, memory_kind, logical_device_ids): + return NamedSharding(mesh, spec, memory_kind=memory_kind, + _logical_device_ids=logical_device_ids) + + @use_cpp_class(xc.NamedSharding) class NamedSharding(JSharding.Sharding): r"""A :class:`NamedSharding` expresses sharding using named axes. @@ -133,9 +138,8 @@ def __repr__(self): return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})' def __reduce__(self): - return (type(self), (self.mesh, self.spec), - {'memory_kind': self.memory_kind, - '_logical_device_ids': self._logical_device_ids}) + return (_unpickle_named_sharding, + (self.mesh, self.spec, self.memory_kind, self._logical_device_ids)) @property def memory_kind(self) -> str | None: diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index 1f1b96487fab..74c34a495920 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -99,6 +99,20 @@ def make_mesh( return make_mesh, (mesh_device_ids, mesh.axis_names) +def _reduce_named_sharding( + sharding: jax.sharding.NamedSharding, +) -> tuple[Callable[..., jax.sharding.NamedSharding], Any]: + # TODO(hyeontaek): Use `legacy_memory_space_behavior=false` for the + # CPU backend's `xla::CpuClientOptions`, and preserve the memory + # kind across serialization. + # Colocated Python implicitly relies on the default memory kind + # being reset to the default memory space when deserializing. 
+ def _make_named_sharding(mesh, spec): + return jax.sharding.NamedSharding(mesh, spec) + + return _make_named_sharding, (sharding.mesh, sharding.spec) + + def _reduce_device_list( device_list: DeviceList, ) -> tuple[Callable[..., DeviceList], Any]: @@ -149,6 +163,7 @@ def _serialize(obj: Any) -> bytes: class _CustomPickler(cloudpickle.Pickler): dispatch_table = collections.ChainMap( {jax.sharding.Mesh: _reduce_mesh}, + {jax.sharding.NamedSharding: _reduce_named_sharding}, {DeviceList: _reduce_device_list}, {jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding}, cloudpickle.CloudPickler.dispatch_table, # pylint: disable=attribute-error diff --git a/tests/pickle_test.py b/tests/pickle_test.py index a3dc5be6e11c..c9bfb0723d94 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -190,9 +190,24 @@ def test_pickle_gspmd_sharding(self): def test_pickle_named_sharding(self): s = jax.sharding.NamedSharding( mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'), - spec=jax.sharding.PartitionSpec('d')) + spec=jax.sharding.PartitionSpec('d'), + ) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + @unittest.skipIf(cloudpickle is None, 'Requires cloudpickle') + def test_pickle_named_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = jax.sharding.NamedSharding( + mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'), + spec=jax.sharding.PartitionSpec('d'), + memory_kind=memory_kind, + ) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 9da048ec7eddd9334e8a2dd4105adc9e376f0da3 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 11 Jun 2025 02:29:02 -0700 Subject: [PATCH 1627/1769] [Mosaic GPU] Use _slice_smem also for barriers. With this, `gpu.dynamic_shared_memory` ops are not present in the input to the dialect lowering passes and are only inserted during lowering.
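A plain-Python sketch (not Mosaic GPU code) of the bookkeeping that `_construct_smem_reftree` performs below: every buffer, now including barrier arrays, is carved out of one dynamic SMEM region by advancing a running offset. The 8-byte mbarrier size mirrors `utils.MBARRIER_BYTES` and is an assumption of this sketch.

MBARRIER_BYTES = 8

class SmemOffsets:
  def __init__(self):
    self.offset = 0

  def slice(self, nbytes):
    # Stands in for _slice_smem producing a memref view at this offset.
    start, self.offset = self.offset, self.offset + nbytes
    return (start, nbytes)

smem = SmemOffsets()
tile = smem.slice(128 * 128 * 2)           # e.g. one bf16 128x128 buffer
barriers = smem.slice(4 * MBARRIER_BYTES)  # four mbarriers as a single memref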
PiperOrigin-RevId: 770046134 --- jax/experimental/mosaic/gpu/core.py | 41 ++++++++----------- .../mosaic/gpu/dialect_lowering.py | 12 +++++- jax/experimental/mosaic/gpu/utils.py | 21 ++++++---- 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 435e6976a5b1..652b22050c7d 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -338,8 +338,8 @@ def _construct_smem_reftree( dynamic_smem_offset: int = 0, ) -> Callable[[], RefTree]: index = ir.IndexType.get() - i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) smem = ir.Attribute.parse("#gpu.address_space") flat_ref_tys, smem_buffer_tree = jax.tree.flatten( smem_buffers, is_leaf=lambda x: isinstance(x, Union) @@ -347,21 +347,23 @@ def _construct_smem_reftree( smem_refs = [] for ref_ty in flat_ref_tys: - def get_barrier_ptr(num_barriers: int) -> ir.Value: + def barrier_memref(num_barriers: int) -> ir.Value: nonlocal dynamic_smem_offset - workgroup_nvptx_address_space = ( - utils.gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup) - ) - smem_base_ptr = utils.memref_ptr( - dynamic_smem, memory_space=workgroup_nvptx_address_space - ) - smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") - barrier_base_ptr = llvm.getelementptr( - smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8, - llvm.GEPNoWrapFlags.none + barrier_ty = ir.MemRefType.get( + (num_barriers,), + ir.Type.parse("!mosaic_gpu.barrier") + if lowering_semantics == LoweringSemantics.Warpgroup + else i64, + memory_space=smem, ) + barrier_memref = _slice_smem( + barrier_ty, + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, + ) dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES - return barrier_base_ptr + return barrier_memref match ref_ty: case Union(members): member_thunks = [ @@ -385,22 +387,15 @@ def ref(member_thunks=member_thunks): init_fn = utils.DialectBarrierRef.initialize if ( lowering_semantics == LoweringSemantics.Warpgroup ) else utils.BarrierRef.initialize - ref = init_fn( - get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 - ) + ref = init_fn(barrier_memref(num_barriers), arrival_count=1) case Barrier(arrival_count, num_barriers): init_fn = utils.DialectBarrierRef.initialize if ( lowering_semantics == LoweringSemantics.Warpgroup ) else utils.BarrierRef.initialize - ref = init_fn( - get_barrier_ptr(num_barriers), - num_barriers, - arrival_count=arrival_count, - ) + ref = init_fn(barrier_memref(num_barriers), arrival_count=arrival_count) case ClusterBarrier(collective_dims, num_barriers): ref = utils.CollectiveBarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, + barrier_memref(num_barriers), collective_dims, cluster_shape, ) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 4029131baafb..ee62dfd751ef 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -1072,7 +1072,17 @@ def _slice_smem(result: ir.Type, offset: ir.Value): ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) ) offset = arith.index_cast(ir.IndexType.get(), offset) - return memref.view(result, smem_base, offset, []) + lowered_result_type = result + if ir.MemRefType.isinstance(result): + memref_ty = ir.MemRefType(result) + if memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier"): + 
lowered_result_type = ir.MemRefType.get( + memref_ty.shape, _lowered_barrier_type(), memory_space=smem + ) + view = memref.view(lowered_result_type, smem_base, offset, []) + if result == lowered_result_type: + return view + return builtin.unrealized_conversion_cast([result], [view]) # The metadata needed to recostruct a vector from its flattened representation. diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index b5dbfb62c88f..9ea964c98b86 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -756,12 +756,17 @@ class BarrierRef: num_barriers: int @staticmethod - def initialize(address: ir.Value, num_barriers: int, arrival_count: int = 1) -> "BarrierRef": + def initialize(barrier_memref: ir.Value, arrival_count: int = 1) -> "BarrierRef": + barrier_ty = ir.MemRefType(barrier_memref.type) + [num_barriers] = barrier_ty.shape if num_barriers > 32: raise NotImplementedError("Only up to 32 barriers per group supported") i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>") + address = memref_ptr( + barrier_memref, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE + ) phases = memref.alloca(ir.MemRefType.get((), i32), [], []) memref.store(c(0, i32), phases, []) with single_thread(scope=ThreadSubset.BLOCK): @@ -867,15 +872,16 @@ class DialectBarrierRef: @staticmethod def initialize( - address: ir.Value, - num_barriers: int, + barrier_memref: ir.Value, arrival_count: int = 1, ) -> "DialectBarrierRef": + barrier_ty = ir.MemRefType(barrier_memref.type) + [num_barriers] = barrier_ty.shape if num_barriers > 32: raise NotImplementedError("Only up to 32 barriers per group supported") - barrier_ty = ir.MemRefType.get( - (num_barriers,), ir.Type.parse("!mosaic_gpu.barrier") + address = memref_ptr( + barrier_memref, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE ) dialect.InitializeBarrierOp( barrier_ty, base_pointer=address, arrival_count=arrival_count @@ -957,8 +963,7 @@ class CollectiveBarrierRef: @staticmethod def initialize( - address: ir.Value, - num_barriers: int, + barrier_memref: ir.Value, dims: Sequence[gpu.Dimension | Sequence[gpu.Dimension]], cluster_shape: tuple[int, int, int], ) -> "CollectiveBarrierRef": @@ -986,7 +991,7 @@ def initialize( cluster_mask = arith.ori( cluster_mask, cluster_collective_mask(cluster_shape, d) ) - barrier = BarrierRef.initialize(address, num_barriers, arrival_count=arrival_count) + barrier = BarrierRef.initialize(barrier_memref, arrival_count=arrival_count) return CollectiveBarrierRef(barrier, cluster_mask) def __iter__(self): From de82f9f47d3cb862d5627eee0060d956468418af Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 11 Jun 2025 05:03:55 -0700 Subject: [PATCH 1628/1769] [pallas:mosaic] A few more primitives now have lowerings for all kernel types PiperOrigin-RevId: 770095244 --- jax/_src/pallas/mosaic/lowering.py | 50 +++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index db17a25a9e5c..22a07537e508 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2250,7 +2250,7 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ) -@register_lowering_rule(lax.squeeze_p) +@register_lowering_rule(lax.squeeze_p, kernel_types=[*tpu_core.KernelType]) def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): del 
dimensions # Unused. (aval_in,) = ctx.avals_in @@ -2517,7 +2517,9 @@ def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -@register_lowering_rule(lax.mul_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.mul_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2528,7 +2530,9 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -@register_lowering_rule(lax.div_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.div_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2541,7 +2545,9 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -@register_lowering_rule(lax.rem_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.rem_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2574,7 +2580,7 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x): return _sub_lowering_rule(new_ctx, np.array(0, dtype=x_aval.dtype), x) -@register_lowering_rule(lax.sign_p) +@register_lowering_rule(lax.sign_p, kernel_types=[*tpu_core.KernelType]) def _sign_lowering_rule(ctx: LoweringRuleContext, x): return lower_fun( pallas_utils.sign_lowering_helper, multiple_results=False, @@ -2852,7 +2858,9 @@ def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y): ) -@register_lowering_rule(lax.and_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.and_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _and_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.andi(x, y) @@ -2867,7 +2875,9 @@ def _is_finite_lowering_rule(ctx: LoweringRuleContext, x): return _not_lowering_rule(ctx, tpu.weird(out_type, x)) -@register_lowering_rule(lax.or_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.or_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _or_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.ori(x, y) @@ -2898,7 +2908,7 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): return arith.xori(x, minus_one) -@register_lowering_rule(lax.select_n_p) +@register_lowering_rule(lax.select_n_p, kernel_types=[*tpu_core.KernelType]) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): if len(args) > 1: raise NotImplementedError("select_n only supported with <= 2 arguments") @@ -3007,7 +3017,9 @@ def _run_body(i, args): return for_op.results -@register_lowering_rule(lax.scan_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False) +@register_lowering_rule( + lax.scan_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3181,7 +3193,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, **params): return if_op.results -@register_lowering_rule(pjit.pjit_p) +@register_lowering_rule(pjit.pjit_p, kernel_types=[*tpu_core.KernelType]) def 
_pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args) @@ -3297,7 +3309,7 @@ def _roll_lowering_rule( ) -@register_lowering_rule(lax.slice_p) +@register_lowering_rule(lax.slice_p, kernel_types=[*tpu_core.KernelType]) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -3314,13 +3326,19 @@ def _slice_lowering_rule( ) -@register_lowering_rule(lax.xor_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.xor_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _xor_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.xori(x, y) -@register_lowering_rule(lax.shift_left_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.shift_left_p, + kernel_types=[*tpu_core.KernelType], + ensure_mlir_values=False, +) def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shli(x, d) @@ -3332,7 +3350,11 @@ def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d): return arith.shrsi(x, d) -@register_lowering_rule(lax.shift_right_logical_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.shift_right_logical_p, + kernel_types=[*tpu_core.KernelType], + ensure_mlir_values=False, +) def _shift_right_logical_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shrui(x, d) From 827b855b8758baa8bad2ab57d8f6d8843dcd3d04 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 11 Jun 2025 05:55:06 -0700 Subject: [PATCH 1629/1769] [Mosaic GPU] Remove unneeded code. PiperOrigin-RevId: 770111101 --- .../mosaic/gpu/transform_inference.py | 21 ------------------- tests/mosaic/gpu_transform_inference_test.py | 9 ++------ 2 files changed, 2 insertions(+), 28 deletions(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 7b31bcaefb27..46d04026d588 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -25,7 +25,6 @@ from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector @@ -223,17 +222,6 @@ def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: return None if transforms is None else ([], [transforms]) -# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use -# the dialect in all cases. -# The rule is necessary in order to handle the lowering of `utils.memref_ptr` -# which is used in `_construct_smem_reftree`. 
-@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp) -def _infer_unrealized_conversion_cast_transforms( - _: builtin.UnrealizedConversionCastOp, -) -> OptionalTransforms: - return None - - @partial(_add_transform_inference_rule, memref.ViewOp) def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp): @@ -253,15 +241,6 @@ def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: return None if transforms is None else ([], [transforms]) -# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use -# the dialect in all cases. -@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp) -def _infer_dynamic_smem_transforms( - _: gpu.DynamicSharedMemoryOp, -) -> OptionalTransforms: - return None - - def _get_tile_and_swizzle_transforms( transforms: ir.ArrayAttr | None, ) -> tuple[ir.Attribute, ir.Attribute]: diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index 3fdbfd650cb1..420d735fef54 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -24,7 +24,6 @@ from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector @@ -444,9 +443,7 @@ def body(in_ref): static_sizes=[1, 64, 64], static_strides=[1, 1, 1], ) - user_op = builtin.UnrealizedConversionCastOp( - [out_ref_ty], [subview_op.result] - ) + user_op = memref.CastOp(out_ref_ty, subview_op.result) with ir.InsertionPoint(self.module.body): f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op @@ -523,9 +520,7 @@ def body(in_ref): static_sizes = [2, 64, 32], static_strides = [1, 1, 1] ) - user_op = builtin.UnrealizedConversionCastOp( - [out_ref_ty], [subview_op.result] - ) + user_op = memref.CastOp(out_ref_ty, subview_op.result) with ir.InsertionPoint(self.module.body): f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op From 0edfb44f0fc9822d9771fd87fa5960071821f420 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Jun 2025 06:09:29 -0700 Subject: [PATCH 1630/1769] Propagate source_info in more places: * in staging rules to default_process_primitive, notably in pjit_staging_rule. * from trace_to_jaxpr_dynamic to to_jaxpr. END_PUBLIC Saves maybe 500ms in :transformer_benchmark, which is minor, but faster is faster. 
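An illustrative model of this change, not JAX internals: `current()` below stands in for `source_info_util.current()`, which walks the Python stack, so the win comes from computing it once per staging step and threading it down rather than re-deriving it inside every `default_process_primitive` call.

import inspect

def current():
  # Expensive: inspects the interpreter stack on every call.
  return inspect.stack()[1]

def default_process_primitive(prim, args, source_info):
  # Provenance is recorded from the threaded value; no extra stack walk.
  return (prim, args, source_info)

def staging_rule(prims):
  source_info = current()  # computed once...
  return [default_process_primitive(p, (), source_info) for p in prims]  # ...reused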
PiperOrigin-RevId: 770115245 --- jax/_src/interpreters/ad.py | 11 +++++++---- jax/_src/interpreters/partial_eval.py | 9 +++++---- jax/_src/lax/control_flow/loops.py | 6 +++--- jax/_src/lax/lax.py | 9 ++++++--- jax/_src/lax/slicing.py | 5 +++-- jax/_src/pjit.py | 5 +++-- tests/hijax_test.py | 4 ++-- 7 files changed, 29 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 51f007f4e576..996b0d4258cf 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -107,7 +107,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] - jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) + jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info, source_info) if attrs_tracked: raise NotImplementedError("TODO: attrs") which_env = [(isinstance(c, pe.DynamicJaxprTracer) and @@ -194,7 +194,8 @@ def new_arg(trace, primal_aval, nz, source_info): nzs_out = [type(t) is not Zero for t in out_tangents] out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t, source_info) for (nz, t) in zip(nzs_out, out_tangents) if nz) - tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) + tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr( + out_tangents, debug_info, source_info) tangent_trace.invalidate() if attrs_tracked: raise NotImplementedError("TODO: attrs") @@ -205,7 +206,8 @@ def new_arg(trace, primal_aval, nz, source_info): residuals_and_primals = (*tangent_consts, *out_primals) residuals_and_primals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info), residuals_and_primals) - primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) + primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr( + residuals_and_primals, debug_info, source_info) primal_trace.invalidate() num_residuals = len(tangent_consts) tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr)) @@ -240,7 +242,8 @@ def direct_linearize(traceable: lu.WrappedFun, out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_nz_tangents) - jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) + jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr( + out_nz_tangents, traceable.debug_info, source_info) tangent_trace.invalidate() jaxpr, used_consts, _ = pe.dce_jaxpr_consts( jaxpr, [True] * len(jaxpr.outvars), diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 7a1fba94bb3d..78d3e75a4664 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1789,6 +1789,7 @@ def to_jaxpr( self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer], debug_info: core.DebugInfo, + source_info: SourceInfo, ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] @@ 
-1796,7 +1797,6 @@ def to_jaxpr( invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) - source_info = source_info_util.current() state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x, source_info))] for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] @@ -2239,8 +2239,8 @@ def transpose_jaxpr_thunk(): return out_tracers def to_jaxpr(self, out_tracers: Sequence[Tracer], - debug_info: core.DebugInfo): - return self.frame.to_jaxpr(self, out_tracers, debug_info) + debug_info: core.DebugInfo, source_info: SourceInfo): + return self.frame.to_jaxpr(self, out_tracers, debug_info, source_info) custom_staging_rules: dict[Primitive, Callable] = {} @@ -2301,7 +2301,8 @@ def trace_to_jaxpr_dynamic( _check_returned_jaxtypes(fun.debug_info, ans) out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) _check_no_returned_refs(fun.debug_info, out_tracers) - jaxpr, consts, attrs_tracked = trace.frame.to_jaxpr(trace, out_tracers, fun.debug_info) + jaxpr, consts, attrs_tracked = trace.frame.to_jaxpr( + trace, out_tracers, fun.debug_info, source_info) del fun, in_tracers, out_tracers, ans finally: trace.frame.reset_states(trace) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 8857ffad7a5f..35fd26b21783 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -622,11 +622,11 @@ def _empty_array(prefix, length_spec, aval): eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True -def _stage_jaxpr(trace: pe.JaxprTrace, source_info, *tracers, +def _stage_jaxpr(trace: pe.DynamicJaxprTrace, source_info, *tracers, jaxpr: core.ClosedJaxpr): - del source_info params = dict(call_jaxpr=jaxpr) - return trace.default_process_primitive(core.closed_call_p, tracers, params) + return trace.default_process_primitive(core.closed_call_p, tracers, params, + source_info=source_info) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr @eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8c271ebe3c15..9af301c26c0e 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6522,7 +6522,8 @@ def _broadcast_in_dim_staging_rule( params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) if not dyn: - return trace.default_process_primitive(broadcast_in_dim_p, (x,), params) + return trace.default_process_primitive(broadcast_in_dim_p, (x,), params, + source_info=source_info) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) return _dyn_shape_staging_rule(trace, source_info, broadcast_in_dim_p, aval, x, *dyn, **params) @@ -7242,7 +7243,8 @@ def _reshape_staging_rule( trace, source_info, x, *dyn, new_sizes, dimensions, sharding): params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding) if not dyn: - return trace.default_process_primitive(reshape_p, (x,), params) + return trace.default_process_primitive(reshape_p, (x,), params, + source_info=source_info) av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) return _dyn_shape_staging_rule(trace, source_info, reshape_p, av, x, *dyn, **params) @@ -8598,7 +8600,8 @@ def _iota_staging_rule(trace, source_info, *dyn_shape, dtype, shape, dimension, params = dict(dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) if not dyn_shape: - return 
trace.default_process_primitive(iota_p, (), params) + return trace.default_process_primitive(iota_p, (), params, + source_info=source_info) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) return _dyn_shape_staging_rule(trace, source_info, iota_p, aval, *dyn_shape, **params) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index cacc1d14e42b..d70110c7301a 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1546,8 +1546,9 @@ def _dynamic_slice_staging_rule(trace, source_info, x, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim]) if not dyn: - return trace.default_process_primitive(dynamic_slice_p, (x, *start_indices), - dict(slice_sizes=slice_sizes)) + return trace.default_process_primitive( + dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes), + source_info=source_info) shape = lax._merge_dyn_shape(slice_sizes, dyn) aval = core.DShapedArray(shape, x.dtype, False) return lax._dyn_shape_staging_rule(trace, source_info, dynamic_slice_p, aval, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 56252298cd01..9076763fb7ea 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2010,9 +2010,10 @@ def pjit_staging_rule(trace, source_info, *args, **params): new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings, in_layouts=in_layouts, donated_invars=donated_invars) out_tracers = trace.default_process_primitive( - pjit_p, (*args, *consts), new_params) + pjit_p, (*args, *consts), new_params, source_info=source_info) else: - out_tracers = trace.default_process_primitive(pjit_p, args, params) + out_tracers = trace.default_process_primitive( + pjit_p, args, params, source_info=source_info) return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule diff --git a/tests/hijax_test.py b/tests/hijax_test.py index d033404879ff..de0862b7821c 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -50,9 +50,9 @@ def __init__(self, name): pe.custom_staging_rules[self] = self.staging def staging(self, trace, source_info, *args, **kwargs): - del source_info trace.frame.is_high = True - return trace.default_process_primitive(self, args, kwargs) + return trace.default_process_primitive(self, args, kwargs, + source_info=source_info) def is_high(self, **params): return True From 1cd49920f26f7dbbaaa59ee07171cc0f4740c42e Mon Sep 17 00:00:00 2001 From: Axel Stjerngren Date: Wed, 11 Jun 2025 07:21:48 -0700 Subject: [PATCH 1631/1769] Ensure that memory_kind is restored after pickling in SingleDeviceSharding and GSPMDSharding. Removed from `PmapSharding` as it always tries to return the default memory sharding. PiperOrigin-RevId: 770138541 --- jax/_src/sharding_impls.py | 16 +++++++++++----- tests/pickle_test.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 4703e6403079..d406faceabb7 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -99,6 +99,10 @@ def build(self) -> sdy.TensorShardingPerValueAttr: replicated_hlo_sharding = xc.HloSharding.replicate() +def _unpickle_single_device_sharding(device, memory_kind): + return SingleDeviceSharding(device, memory_kind=memory_kind) + + @use_cpp_class(xc.SingleDeviceSharding) class SingleDeviceSharding(jsharding.Sharding): """A :class:`Sharding` that places its data on a single device. 
@@ -121,7 +125,7 @@ def __init__(self, device: Device, *, memory_kind: str | None = None): self._memory_kind = memory_kind def __reduce__(self): - return type(self), (self._device,), {'memory_kind': self._memory_kind} + return (_unpickle_single_device_sharding, (self._device, self._memory_kind)) def __repr__(self): mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' @@ -205,8 +209,7 @@ def __init__(self, devices: Sequence[Device] | np.ndarray, self.sharding_spec = sharding_spec def __reduce__(self): - return (type(self), (self.devices, self.sharding_spec), - {'memory_kind': self.memory_kind}) + return (type(self), (self.devices, self.sharding_spec)) def __eq__(self, other): if not isinstance(other, PmapSharding): @@ -558,6 +561,9 @@ def __eq__(self, other) -> bool: self._ids == other._ids) +def _unpickle_gspmd_sharding(devices, op_sharding, memory_kind): + return GSPMDSharding(devices, op_sharding, memory_kind=memory_kind) + @use_cpp_class(xc.GSPMDSharding) class GSPMDSharding(jsharding.Sharding): _devices: tuple[Device, ...] @@ -579,8 +585,8 @@ def __init__(self, devices: Sequence[Device], self._memory_kind = memory_kind def __reduce__(self): - return (type(self), (self._devices, self._hlo_sharding.to_proto()), - {'memory_kind': self._memory_kind}) + return (_unpickle_gspmd_sharding, + (self._devices, self._hlo_sharding.to_proto(), self._memory_kind)) @functools.cached_property def _hlo_sharding_hash(self): diff --git a/tests/pickle_test.py b/tests/pickle_test.py index c9bfb0723d94..feaeb8db01c7 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -175,6 +175,17 @@ def test_pickle_single_device_sharding(self): s = jax.sharding.SingleDeviceSharding(jax.devices()[0]) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_single_device_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = jax.sharding.SingleDeviceSharding( + jax.devices()[0], memory_kind=memory_kind + ) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_pmap_sharding(self): ss = pxla.ShardingSpec( sharding=(pxla.Unstacked(8),), @@ -186,6 +197,15 @@ def test_pickle_gspmd_sharding(self): s = GSPMDSharding.get_replicated(jax.devices()) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_gspmd_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = GSPMDSharding.get_replicated(jax.devices(), memory_kind=memory_kind) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + @unittest.skipIf(cloudpickle is None, "Requires cloudpickle") def test_pickle_named_sharding(self): s = jax.sharding.NamedSharding( From cfebc4986624f8b6780d8da7aef9e7b6e2db4acd Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Jun 2025 07:35:12 -0700 Subject: [PATCH 1632/1769] Don't recompute np.iinfo in _scalar_type_to_dtype. Small optimization noticed in passing. 
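A quick way to see why the hoist matters (an illustrative aside with machine-dependent timings, not part of the patch): `np.iinfo` constructs a fresh object per call, so binding it to a local halves the work in the two-sided range check.

import timeit
import numpy as np

dtype = np.dtype('int32')
recompute = timeit.timeit(
    lambda: np.iinfo(dtype).min <= 7 <= np.iinfo(dtype).max, number=100_000)
iinfo = np.iinfo(dtype)  # built once, reused for both bounds
hoisted = timeit.timeit(
    lambda: iinfo.min <= 7 <= iinfo.max, number=100_000)
print(f'recompute: {recompute:.3f}s  hoisted: {hoisted:.3f}s')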
PiperOrigin-RevId: 770142792 --- jax/_src/dtypes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 0276f08e7ef4..3be98da4a84b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -373,7 +373,8 @@ def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType: """ dtype = canonicalize_dtype(python_scalar_dtypes[typ]) if typ is int and value is not None: - if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max: + iinfo = np.iinfo(dtype) + if value < iinfo.min or value > iinfo.max: raise OverflowError(f"Python int {value} too large to convert to {dtype}") return dtype From 5f0e7e47741ea48d422e6624dcd1a12cf9e562da Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Wed, 11 Jun 2025 07:46:54 -0700 Subject: [PATCH 1633/1769] Set explicit dot precision in the sparse solver test. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In https://github.com/openxla/xla/pull/26679, XLA GPU is going to restrict the number of cases where dot is rewritten as broadcast+multiply+reduction. Currently, the dots originating from the test get rewritten (because of small size) and therefore use a higher-precision accumulator (e.g. F32 for F32×F32 dots). After they stop being rewritten, they use the default TF32 accumulator. The fix explicitly sets a higher precision for computing the golden values. PiperOrigin-RevId: 770146649 --- tests/sparse_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 97a156f9f6f5..1eeeae7c2749 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -1206,7 +1206,12 @@ def sparse_solve(data, indices, indptr, b): return sparse.linalg.spsolve(data, indices, indptr, b, tol, reorder) x = sparse_solve(data, indices, indptr, b) - self.assertAllClose(a @ x, b, rtol=1e-2, atol=1e-3) + self.assertAllClose( + jnp.matmul(a, x, precision=jax.lax.Precision.HIGHEST), + b, + rtol=1e-2, + atol=1e-3, + ) self._CompileAndCheck(sparse_solve, args_maker) @jtu.sample_product( From 225f0c621658bdecdadf9b53ea2827661144966b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 11 Jun 2025 15:50:00 +0100 Subject: [PATCH 1634/1769] Migrated to mypy 1.16.0 --- .pre-commit-config.yaml | 2 +- jax/_src/lax/lax.py | 2 -- jax/_src/pallas/mosaic_gpu/pipeline.py | 6 +++--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8cc28c9fe4ac..2312c88579d6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: ruff - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'f40886d54c729f533f864ed6ce584e920feb0af7' # frozen: v1.15.0 + rev: '7010b10a09f65cd60a23c207349b539aa36dbec1' # frozen: v1.16.0 hooks: - id: mypy files: (jax/|tests/typing_test\.py) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9af301c26c0e..8fbcf21d574b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1690,8 +1690,6 @@ def _convert_element_type( return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) - new_dtype = type_cast(DTypeLike | None, new_dtype) - old_weak_type = dtypes.is_weakly_typed(operand) if new_dtype is None: new_dtype = old_dtype diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index be9f663a42b7..3ed17a085b8e 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -21,7
+21,7 @@ import functools import itertools as it import math -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, TypeVar, cast import jax from jax import api_util @@ -226,7 +226,7 @@ def emit_pipeline( # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. if not has_dynamic_grid and max_concurrent_steps > num_steps: - max_concurrent_steps = num_steps + max_concurrent_steps = cast(int, num_steps) def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -523,7 +523,7 @@ def _get_slot(step, has_seq_dim): # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. if not has_dynamic_grid and max_concurrent_steps > num_steps: - max_concurrent_steps = num_steps + max_concurrent_steps = cast(int, num_steps) def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) From f7f2ce5c5bed282198f9ebaee24715e01b53e880 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Jun 2025 08:02:12 -0700 Subject: [PATCH 1635/1769] Delete instantiate_const_abstracted. This function appears completely unused. PiperOrigin-RevId: 770152077 --- jax/_src/interpreters/partial_eval.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 78d3e75a4664..ded799b83d57 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -25,8 +25,6 @@ from typing import Any, NamedTuple, Union from weakref import ref -import numpy as np - from jax._src import ad_util from jax._src import api_util from jax._src import config @@ -216,14 +214,6 @@ def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: else: return self.new_instantiated_const(const) - def instantiate_const_abstracted(self, tracer) -> JaxprTracer: - const = tracer.pval.get_known() - if const is None: - return tracer - else: - aval = get_aval(const).update_weak_type(np.isscalar(const)) - return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) - def cur_qdd(self, x): const = self.to_jaxpr_tracer(x).pval.get_known() if const is None: From 6c9bcfc6a1a90153c1096d259f0076fddb49bb75 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 11 Jun 2025 08:16:17 -0700 Subject: [PATCH 1636/1769] add doc comment to vma in ShapedArray PiperOrigin-RevId: 770157603 --- jax/_src/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index 15c431599fed..4e5f720f680c 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2045,6 +2045,8 @@ def __init__(self, shape, dtype, weak_type=False, *, sharding=None, self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) + # short for varying_manual_axes. See docs at + # https://docs.jax.dev/en/latest/notebooks/shard_map.html#tracking-how-values-vary-over-manual-mesh-axes-and-check-vma-true self.vma = get_vma(vma, self.sharding.mesh) def lower_val(self, val): return [val] From b87ea1c5e2add3f181fc01bb50e83622bca11f88 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Jun 2025 08:34:23 -0700 Subject: [PATCH 1637/1769] Do not call update_weak_type on the result of get_aval(). get_aval() already annotates the aval as weak if the value being abstractified should be treated as weak. 
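Context for the change, shown as a sketch of standard JAX weak-type semantics rather than code from this patch: `get_aval` already flags values derived from Python scalars as weakly typed, which is why re-applying `update_weak_type` was redundant.

import jax.numpy as jnp
from jax._src import core, dtypes

aval = core.get_aval(1.0)  # Python scalars abstractify as weakly typed
assert aval.weak_type and dtypes.is_weakly_typed(1.0)

x = jnp.ones(3, jnp.bfloat16)
assert (x + 1.0).dtype == jnp.bfloat16  # the weak scalar defers to x's dtype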
PiperOrigin-RevId: 770163902 --- jax/_src/interpreters/partial_eval.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ded799b83d57..a24c6aad4ecf 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1931,8 +1931,6 @@ def new_const(self, c, source_info: SourceInfo): if aval.has_qdd: with core.set_current_trace(self.parent_trace): aval = core.AvalQDD(aval, core.cur_qdd(c)) - if hasattr(aval, "weak_type"): - aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) aval = self._lift_tracers_in_aval(aval, source_info) tracer = self._new_const(aval, c, source_info) return tracer From 772cde814b35af4307b65ec1ffcdf7123f2c5adc Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 11 Jun 2025 09:04:43 -0700 Subject: [PATCH 1638/1769] Move jax/_src/export to its own build rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. PiperOrigin-RevId: 770174928 --- jax/BUILD | 30 +++++++++++++++++++++++++++++- jax/_src/export/_export.py | 32 +++++++++++++++++--------------- jax/_src/export/serialization.py | 12 ++++++------ jax/_src/export/shape_poly.py | 30 +++++++++++++++--------------- jax/_src/numpy/einsum.py | 3 +++ jax/_src/prng.py | 17 +++++++++++++++++ 6 files changed, 87 insertions(+), 37 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index a2d59c98e810..8a83e1c5df37 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -313,7 +313,6 @@ py_library_providing_imports_info( "_src/cudnn/**/*.py", "_src/debugger/**/*.py", "_src/image/**/*.py", - "_src/export/**/*.py", "_src/lax/**/*.py", "_src/nn/**/*.py", "_src/numpy/**/*.py", @@ -386,6 +385,7 @@ py_library_providing_imports_info( ":earray", ":effects", ":environment_info", + ":export", ":ffi", ":internal_mesh_utils", ":jaxpr_util", @@ -830,6 +830,34 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "export", + srcs = glob([ + "_src/export/**/*.py", + ]), + visibility = [":internal"] + jax_visibility("export"), + deps = [ + ":ad_util", + ":api", + ":config", + ":core", + ":custom_derivatives", + ":dtypes", + ":effects", + ":mesh", + ":mlir", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":stages", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("flatbuffers") + py_deps("numpy") + py_deps("opt_einsum"), +) + pytype_strict_library( name = "ffi", srcs = ["_src/ffi.py"], diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index b390574c0a79..5d88f530bb55 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -28,12 +28,11 @@ import logging import numpy as np -import jax -from jax import sharding - from jax._src import ad_util +from jax._src import api from jax._src import config from jax._src import core +from jax._src import custom_derivatives from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -45,11 +44,14 @@ from jax._src.lib.mlir import ir, passmanager from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src import mesh from jax._src import pjit +from jax._src import sharding from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages from jax._src import tree_util +from jax._src import typing from 
jax._src import util from jax._src import xla_bridge as xb @@ -215,7 +217,7 @@ def __str__(self): def in_shardings_jax( self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + mesh: mesh.Mesh) -> Sequence[sharding.Sharding | None]: """Creates Shardings corresponding to self.in_shardings_hlo. The Exported object stores `in_shardings_hlo` as HloShardings, which are @@ -225,7 +227,7 @@ def in_shardings_jax( Example usage: - >>> from jax import export + >>> from jax import export, sharding >>> # Prepare the exported object: >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), @@ -255,7 +257,7 @@ def in_shardings_jax( def out_shardings_jax( self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + mesh: mesh.Mesh) -> Sequence[sharding.Sharding | None]: """Creates Shardings corresponding to `self.out_shardings_hlo`. See documentation for in_shardings_jax. @@ -512,13 +514,13 @@ def default_export_platform() -> str: One of: `tpu`, `cpu`, `cuda`, `rocm`. """ # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' - return xb.canonicalize_platform(jax.default_backend()) + return xb.canonicalize_platform(xb.default_backend()) default_lowering_platform = default_export_platform def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" - if isinstance(a, jax.ShapeDtypeStruct): + if isinstance(a, api.ShapeDtypeStruct): return a.shape, a.dtype aval = core.get_aval(a) return aval.shape, aval.dtype @@ -747,14 +749,14 @@ def export_sharding(s: LoweringSharding, cur_mesh = cur_arg = cur_k_path = None # lowered.args_info is a tree of the args, but we need the out avals too to # get the key paths for. - out_avals_tree = jax.tree_util.tree_unflatten(lowered.out_tree, out_avals_flat) + out_avals_tree = tree_util.tree_unflatten(lowered.out_tree, out_avals_flat) if config.use_shardy_partitioner.value: for sharding, (k_path, arg) in zip( itertools.chain.from_iterable([ all_in_shardings, lowering.compile_args["out_shardings"]]), itertools.chain.from_iterable([ - jax.tree.flatten_with_path(lowered.args_info)[0], - jax.tree.flatten_with_path(out_avals_tree)[0]])): + tree_util.tree_flatten_with_path(lowered.args_info)[0], + tree_util.tree_flatten_with_path(out_avals_tree)[0]])): if isinstance(sharding, sharding_impls.NamedSharding): if cur_mesh is None: cur_mesh, cur_arg, cur_k_path = sharding.mesh, arg, k_path @@ -1214,7 +1216,7 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding], def _hlo_sharding_to_gspmd_sharding( hlo_sharding: HloSharding | None, - device_assignment: Sequence[jax.Device] + device_assignment: Sequence[_jax.Device] ) -> sharding_impls.GSPMDSharding | None: if hlo_sharding is None: return None @@ -1259,7 +1261,7 @@ def flattened_primal_fun_jax(*args_flat): args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(in_avals)]) - _, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, + _, pullback_jax = api.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, *args_flat_jax) return pullback_jax(out_cts_flat_jax) @@ -1291,12 +1293,12 @@ def flattened_primal_fun_jax(*args_flat): ### Calling the exported function -def call(exported: Exported) -> Callable[..., jax.Array]: +def call(exported: Exported) -> Callable[..., typing.Array]: if not isinstance(exported, Exported): raise ValueError( "The exported argument must be an export.Exported. 
" f"Found {exported}.") - @jax.custom_vjp + @custom_derivatives.custom_vjp def f_flat(*args_flat): return call_exported_p.bind(*args_flat, exported=exported) diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 3d878cccc701..dd33de846c42 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -19,7 +19,7 @@ import types from collections.abc import Callable, Sequence from functools import partial -from typing import TypeVar +from typing import Any, TypeVar try: import flatbuffers @@ -31,7 +31,6 @@ from jax._src import core from jax._src import dtypes from jax._src import effects -from jax._src import prng from jax._src import tree_util from jax._src.export import serialization_generated as ser_flatbuf from jax._src.export import _export @@ -364,10 +363,6 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, - - prng.KeyTy(prng.prngs["threefry2x32"]): ser_flatbuf.DType.key_fry, - prng.KeyTy(prng.prngs["rbg"]): ser_flatbuf.DType.key_rbg, - prng.KeyTy(prng.prngs["unsafe_rbg"]): ser_flatbuf.DType.key_unsafe_rbg, } _dtype_kind_to_dtype = { @@ -375,6 +370,11 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): } +def register_dtype_kind(dtype: Any, kind: int): + _dtype_to_dtype_kind[dtype] = kind + _dtype_kind_to_dtype[kind] = dtype + + def _serialize_aval( builder: flatbuffers.Builder, aval: core.ShapedArray ) -> int: diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index bb8a159ee54b..c89ae2cc04ca 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -34,23 +34,21 @@ import numpy as np import opt_einsum -import jax - +from jax._src import api from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import effects -from jax._src.lax import lax from jax._src.interpreters import mlir -from jax._src.numpy import einsum as jnp_einsum from jax._src import source_info_util from jax._src import tree_util +from jax._src import typing from jax._src import util DimSize = Union["_DimExpr", int] TfVal = Any -DimVarEnv = dict[str, jax.Array] +DimVarEnv = dict[str, typing.Array] DType = Any # Tuples of terms and their coefficients, sorted with the largest term first. @@ -214,6 +212,8 @@ def __ge__(self, other: _DimFactor): return self._syntactic_cmp(other) >= 0 def evaluate(self, env: DimVarEnv, scope: SymbolicScope): + from jax._src.lax import lax # pytype: disable=import-error + if self.var is not None: try: return env[self.var] @@ -1255,7 +1255,7 @@ def fake_dim(d): # here some errors due to non-equal dimensions, but we catch them # later. return 8 - fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), + fake_ops.append(api.ShapeDtypeStruct(tuple(map(fake_dim, shape)), operand.dtype)) contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, @@ -1267,8 +1267,6 @@ def fake_dim(d): contract_operands.append(operands[idx[0]]) return contract_operands, contractions -jnp_einsum._poly_einsum_handlers[_DimExpr] = _einsum_contract_path - # To implement shape-constraint checking we use a shape assertion primitive. 
# shape_assertion_p.bind(assert_what: bool, *error_message_inputs, # error_message="...{0}...{1}") @@ -1303,8 +1301,8 @@ class ShapeAssertionEffect(effects.Effect): effects.remat_allowed_effects.add_type(ShapeAssertionEffect) effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) -def shape_assertion(assert_what: jax.Array, - *error_message_inputs: jax.Array, +def shape_assertion(assert_what: typing.Array, + *error_message_inputs: typing.Array, error_message: str) -> None: """Adds a shape assertion in the code. @@ -1485,14 +1483,14 @@ def symbolic_args_specs( elif constraints: raise ValueError("Cannot use both `scope` and `constraints`") args_specs_flat = ( - jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) + api.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat)) return args_tree.unflatten(args_specs_flat) def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" - if isinstance(a, jax.ShapeDtypeStruct): + if isinstance(a, api.ShapeDtypeStruct): return a.shape, a.dtype aval = core.get_aval(a) return aval.shape, aval.dtype @@ -1785,7 +1783,7 @@ def check_statically(self, eval: ShapeEvaluator) -> None: if not ok: raise self.make_error(eval) - def compute(self, eval: ShapeEvaluator) -> jax.Array | None: + def compute(self, eval: ShapeEvaluator) -> typing.Array | None: """Computes if the constraint is satisfied. If the constraint can be resolved statically returns None @@ -1793,6 +1791,8 @@ def compute(self, eval: ShapeEvaluator) -> jax.Array | None: resolved statically, returns a value representing if the constraint is satisfied. """ + from jax._src.lax import lax # pytype: disable=import-error + left, right = eval.evaluate(self.left), eval.evaluate(self.right) # Try to evaluate the constraint statically. if core.is_constant_shape((left, right)): @@ -1997,8 +1997,8 @@ def solve_dim_vars( def compute_dim_vars_from_arg_shapes( args_avals: Sequence[core.ShapedArray], - *actual_args: jax.Array, - args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: + *actual_args: typing.Array, + args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[typing.Array]: """Computes values of dimension variables to unify args_avals with actual arguments. Like `solve_dim_vars` except that here we express the solution as diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 3f657082e1d4..64198d4f6fa0 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -22,6 +22,7 @@ from jax._src import core from jax._src import dtypes from jax._src.api import jit, named_call +from jax._src.export import shape_poly from jax._src.lax import lax from jax._src.lax.lax import PrecisionLike from jax._src.numpy import util @@ -581,3 +582,5 @@ def filter_singleton_dims(operand, names, other_shape, other_names): return lax._convert_element_type(operands[0], preferred_element_type, output_weak_type) + +_poly_einsum_handlers[shape_poly._DimExpr] = shape_poly._einsum_contract_path diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 51211e62afc2..08e81a76c1c0 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1307,3 +1307,20 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: tag='urbg') register_prng(unsafe_rbg_prng_impl) + + +# Register export serialization for PRNG key types. 
+try: + from jax._src.export import serialization # pytype: disable=import-error + from jax._src.export import serialization_generated as ser_flatbuf # pytype: disable=import-error +except ImportError: + # This can happen if flatbuffers is not installed, in which case export + # serialization is not supported and it is safe to skip the registration. + pass +else: + serialization.register_dtype_kind( + KeyTy(prngs["threefry2x32"]), ser_flatbuf.DType.key_fry) + serialization.register_dtype_kind( + KeyTy(prngs["rbg"]), ser_flatbuf.DType.key_rbg) + serialization.register_dtype_kind( + KeyTy(prngs["unsafe_rbg"]), ser_flatbuf.DType.key_unsafe_rbg) From 2618d9b10f765abc741b9aec500edb4ec01bb823 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Wed, 11 Jun 2025 09:52:22 -0700 Subject: [PATCH 1639/1769] Move materialization of NDIndexer out of draw() Making an NDIndexer is slow enough to trigger draw() timeouts when built with ASAN. PiperOrigin-RevId: 770195998 --- tests/pallas/indexing_test.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index cb862b406603..9e96252a843b 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -99,12 +99,12 @@ def indexer_strategy(draw, dim, int_indexer_shape @hps.composite -def nd_indexer_strategy(draw, shape) -> NDIndexer: +def nd_indices_strategy(draw, shape) -> tuple[int | Slice | jax.Array, ...]: num_indices = draw(hps.integers(min_value=0, max_value=len(shape))) int_indexer_shape = draw(hnp.array_shapes()) indices = tuple(draw(indexer_strategy(dim, int_indexer_shape)) for dim in shape[:num_indices]) - return NDIndexer.from_indices_shape(indices, shape) + return indices class PallasBaseTest(jtu.JaxTestCase): @@ -218,7 +218,8 @@ def test_indexer_with_all_types(self): @hp.given(hps.data()) def test_ndindexer(self, data): shape = data.draw(hnp.array_shapes()) - indexer = data.draw(nd_indexer_strategy(shape)) + indices = data.draw(nd_indices_strategy(shape)) + indexer = NDIndexer.from_indices_shape(indices, shape) is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices] rest_indexers, int_indexers = util.partition_list( @@ -371,7 +372,9 @@ def test_vmap_nd_indexing(self, data): el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape") # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs # hp.assume(len(el_shape) >= 2) - nd_indexer = data.draw(nd_indexer_strategy(el_shape), label="nd_indexer") + nd_indexer = NDIndexer.from_indices_shape( + data.draw(nd_indices_strategy(el_shape), label="nd_indexer"), + el_shape) expected_shape = jax.eval_shape(lambda x: x[nd_indexer], jax.ShapeDtypeStruct(el_shape, jnp.float32)) From 5c2a32011adabe7136ae93e693fd1bc5c7fe33fe Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 11 Jun 2025 10:50:44 -0700 Subject: [PATCH 1640/1769] [JAX] Extend `colocated_cpu_devices` to accept a `Mesh` in addition to devices This change extends the existing `colocated_cpu_devices` to take a JAX `Mesh` and return a new `Mesh` that uses colocated CPU devices. This conversion is a common operation when using colocated Python, because users copy arrays between two meshes while preserving the other sharding properties of the array.
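For illustration, here is a minimal usage sketch of the new overload (an editorial example, not part of this patch; it assumes a process with at least two local accelerator devices and an available colocated CPU backend):

```python
import jax
import numpy as np
from jax.experimental import colocated_python

# A 1x2 accelerator mesh built from local devices.
accel_mesh = jax.sharding.Mesh(
    np.array(jax.local_devices()[:2]).reshape((1, 2)), ("x", "y"))

# New overload: pass a Mesh, get back a Mesh of the same shape and axis
# names whose devices are the colocated CPU devices.
cpu_mesh = colocated_python.colocated_cpu_devices(accel_mesh)
assert cpu_mesh.axis_names == ("x", "y")

# The original device-sequence overload still works as before.
cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()[:2])
```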
PiperOrigin-RevId: 770223744 --- jax/experimental/colocated_python/api.py | 46 ++++++++++++++++++++---- tests/colocated_python_test.py | 14 ++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index 45cd9e47e15a..b0dc3a46fc5f 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -16,25 +16,43 @@ from __future__ import annotations import collections -from typing import Any, Callable, Sequence, Type +from typing import Any, Callable, Sequence, Type, overload import jax from jax._src import api_util from jax._src import util from jax.experimental.colocated_python.func import make_callable from jax.experimental.colocated_python.obj import wrap_class +import numpy as np +@overload def colocated_cpu_devices( - devices: Sequence[jax.Device], + devices_or_mesh: Sequence[jax.Device], ) -> Sequence[jax.Device]: - """Finds CPU devices colocated with the given devices.""" - if not isinstance(devices, tuple): - devices = tuple(devices) + ... + + +@overload +def colocated_cpu_devices( + devices_or_mesh: jax.sharding.Mesh, +) -> jax.sharding.Mesh: + ... + + +def colocated_cpu_devices(devices_or_mesh): + """Finds devices or a mesh that has CPU devices colocated with the given devices or mesh.""" + if isinstance(devices_or_mesh, jax.sharding.Mesh): + return _colocated_cpu_mesh_cached(devices_or_mesh) + + if not isinstance(devices_or_mesh, tuple): + devices_or_mesh = tuple(devices_or_mesh) try: - return _colocated_cpu_devices_cached(devices) + return _colocated_cpu_devices_cached(devices_or_mesh) except (ValueError, AttributeError): - return _colocated_cpu_devices_cached_fallback_to_cpu_backend(devices) + return _colocated_cpu_devices_cached_fallback_to_cpu_backend( + devices_or_mesh + ) @util.cache(max_size=1024, trace_context_in_key=False) @@ -78,6 +96,20 @@ def _colocated_cpu_devices_cached_fallback_to_cpu_backend( ] +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh: + """Returns a CPU mesh that is similar to the given mesh but has colocated CPU devices.""" + # Finding colocated CPU devices reuses the cache of `colocated_cpu_devices` + # called with devices. `_colocated_cpu_mesh` itself is also cached to avoid + # creating a new `Mesh` object repeatedly. 
+ flat_cpu_devices = colocated_cpu_devices(tuple(mesh.devices.flat)) + return jax.sharding.Mesh( + np.array(flat_cpu_devices).reshape(mesh.axis_sizes), + mesh.axis_names, + axis_types=mesh.axis_types, + ) + + def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]: """Executes the given Python function on the same devices as the arguments.""" return make_callable( diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index bff745eeba9d..892414ee1366 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -58,6 +58,20 @@ def setUp(self): " requires NumPy 2.0.0 or later" ) + def testColocatedCpuDevices(self): + mesh = jax.sharding.Mesh( + np.array(jax.local_devices()[:1]).reshape((1, 1)), ("x", "y") + ) + cpu_mesh1 = colocated_python.colocated_cpu_devices(mesh) + + cpu_devices = colocated_python.colocated_cpu_devices( + jax.local_devices()[:1] + ) + cpu_mesh2 = jax.sharding.Mesh( + np.array(cpu_devices).reshape((1, 1)), ("x", "y") + ) + self.assertEqual(cpu_mesh1, cpu_mesh2) + def testMakeColocatedPythonProgram(self): def add_one(x): return x + 1 From 9f8be25fc28ab063fdaf53976197838b87e9036c Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Mon, 9 Jun 2025 16:50:13 -0700 Subject: [PATCH 1641/1769] extend pallas paged_attention with kv scales --- .../pallas/ops/gpu/paged_attention.py | 46 +++++++++- tests/pallas/gpu_paged_attention_test.py | 89 ++++++++++++++++++- 2 files changed, 132 insertions(+), 3 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/paged_attention.py b/jax/experimental/pallas/ops/gpu/paged_attention.py index fbf861f92412..ca21761cf3ed 100644 --- a/jax/experimental/pallas/ops/gpu/paged_attention.py +++ b/jax/experimental/pallas/ops/gpu/paged_attention.py @@ -33,7 +33,9 @@ def paged_attention_kernel( # inputs q_ref, # [block_h, head_dim] k_pages_ref, # [total_num_pages, page_size, head_dim] + k_scales_pages_ref, # [total_num_pages, page_size] v_pages_ref, # [total_num_pages, page_size, head_dim] + v_scales_pages_ref, # [total_num_pages, page_size] block_tables_ref, # [pages_per_partition] lengths_ref, # [1] # outputs @@ -65,7 +67,16 @@ def body(start_k, carry): block_tables = pl.load(block_tables_ref, block_tables_slice) k = k_pages_ref[block_tables].reshape(block_k, head_dim) v = v_pages_ref[block_tables].reshape(block_k, head_dim) + if k_scales_pages_ref is not None: + # dynamic lhs quantized dot is not currently implemented + # so we cast rhs to the lhs dtype + k = k.astype(q.dtype) uncapped_logits = pl.dot(q, k.T) # [block_h, block_k] + if k_scales_pages_ref is not None: + # k_scales_pages_ref are one per head + # they're laid out across the output dimension, so scale output + k_scale = k_scales_pages_ref[block_tables].reshape((1, block_k)) + uncapped_logits *= k_scale.astype(uncapped_logits.dtype) if attn_logits_soft_cap is not None: logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap) logits = logits * attn_logits_soft_cap @@ -92,6 +103,14 @@ def body(start_k, carry): l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr o_prev_corr = correction[:, None] * o_prev + if v_scales_pages_ref is not None: + # v_scales are 1 per head + # they're laid out across the reduction dimension, so scale lhs + v_scale = v_scales_pages_ref[block_tables].reshape((1, block_k)) + s_curr *= v_scale.astype(s_curr.dtype) + # dynamic lhs quantized dot is not currently implemented + # so we cast rhs to the lhs dtype + v = v.astype(s_curr.dtype) o_curr = pl.dot(s_curr.astype(v.dtype), v) o_next = o_prev_corr + o_curr 
@@ -134,6 +153,8 @@ def paged_attention_unbatched( v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] block_tables: jax.Array, # [pages_per_sequence] lengths: jax.Array | None, # [1] + k_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size] + v_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size] *, block_h: int, pages_per_compute_block: int, @@ -179,6 +200,19 @@ def paged_attention_unbatched( mask_value=mask_value, attn_logits_soft_cap=attn_logits_soft_cap, ) + # set up quantization scales + if k_scales_pages is not None: + assert k_scales_pages.shape == (num_kv_heads, total_num_pages, page_size) + k_scales_spec = pl.BlockSpec((None, total_num_pages, page_size), + lambda h, i, k: (h, 0, 0)) + else: + k_scales_spec = None + if v_scales_pages is not None: + assert v_scales_pages.shape == (num_kv_heads, total_num_pages, page_size) + v_scales_spec = pl.BlockSpec((None, total_num_pages, page_size), + lambda h, i, k: (h, 0, 0)) + else: + v_scales_spec = None o, l, m = pl.pallas_call( kernel, @@ -191,10 +225,12 @@ def paged_attention_unbatched( (None, total_num_pages, page_size, head_dim), lambda h, i, k: (h, 0, 0, 0), ), # k_pages + k_scales_spec, # k_pages_scale pl.BlockSpec( (None, total_num_pages, page_size, head_dim), lambda h, i, k: (h, 0, 0, 0), ), # v_pages + v_scales_spec, # v_pages_scale pl.BlockSpec( (None, pages_per_partition), lambda h, i, k: (k, 0) ), # block_tables @@ -226,7 +262,7 @@ def paged_attention_unbatched( num_warps=num_warps, num_stages=num_stages ), name=f"paged_attention_{block_h=}_{pages_per_compute_block=}", - )(q_reshaped, k_pages, v_pages, block_tables, lengths) + )(q_reshaped, k_pages, k_scales_pages, v_pages, v_scales_pages, block_tables, lengths) if q_heads_per_kv_head % block_h: o = o[..., :q_heads_per_kv_head, :] @@ -265,6 +301,8 @@ def paged_attention( v_pages: jax.Array, block_tables: jax.Array, lengths: jax.Array | None, + k_scales_pages: jax.Array | None = None, + v_scales_pages: jax.Array | None = None, *, block_h: int = 16, pages_per_compute_block: int = 8, @@ -286,6 +324,8 @@ def paged_attention( should be in the range of [0, total_num_pages), indicating where to locate the page in `k_pages` or `v_pages`. lengths: A i32[batch_size] jax.Array the length of each example. + k_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array. + v_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array. block_h: int The block size that partitions the number of head groups. pages_per_compute_block: int The maximum number of blocks per compute block. 
k_splits: int Number of partitions used to parallelize key-value sequence @@ -342,12 +382,14 @@ def paged_attention( attn_logits_soft_cap=attn_logits_soft_cap, ) - o = jax.vmap(impl, (0, None, None, 0, 0), 0)( + o = jax.vmap(impl, (0, None, None, 0, 0, None, None), 0)( q, k_pages, v_pages, block_tables, lengths[..., None] if lengths is not None else None, + k_scales_pages, + v_scales_pages, ) return o diff --git a/tests/pallas/gpu_paged_attention_test.py b/tests/pallas/gpu_paged_attention_test.py index 081051f15dae..7a605ca4677b 100644 --- a/tests/pallas/gpu_paged_attention_test.py +++ b/tests/pallas/gpu_paged_attention_test.py @@ -44,9 +44,11 @@ def _generate_qkv( k_pages = jax.random.normal( k1, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype ) + k_pages = k_pages / jnp.linalg.norm(k_pages, axis=-1)[..., None] v_pages = jax.random.normal( k2, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype ) + v_pages = v_pages / jnp.linalg.norm(v_pages, axis=-1)[..., None] block_tables = jnp.arange( batch_size * max_num_blocks_per_seq, dtype=jnp.int32 @@ -54,6 +56,7 @@ def _generate_qkv( block_tables = jax.random.permutation(k3, block_tables, independent=True) block_tables = block_tables.reshape(batch_size, max_num_blocks_per_seq) q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = q / jnp.linalg.norm(q, axis=-1)[..., None] return q, k_pages, v_pages, block_tables @@ -72,6 +75,17 @@ def fn(_block_tables, _pages): return out +def _quantize(x: jax.Array, dtype=jnp.int8): + if jnp.issubdtype(dtype, jnp.floating): + max_val = jnp.astype(jnp.finfo(dtype).max, x.dtype) + else: + max_val = 127 + x_scale = jnp.max(jnp.abs(x), axis=-1) / (0.95 * max_val) + x_quant = (x / x_scale[..., None]) + if not jnp.issubdtype(dtype, jnp.floating): + x_quant = jnp.rint(x_quant) + return x_quant.astype(dtype), x_scale.astype(x.dtype) + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): @@ -93,7 +107,6 @@ def setUp(self): super().setUp() - class PagedAttentionKernelTest(PallasBaseTest): def setUp(self): @@ -154,6 +167,80 @@ def test_paged_attention( self.assertArraysAllClose(o, o_ref, rtol=5e-2, atol=5e-2) + @jtu.sample_product( + dtype=(jnp.float16,), + page_size=(8, 16, 32), + num_kv_heads=(1, 2), + q_kv_head_ratio=(2, 16, 20), + head_dim=(32, 64), + block_h=(16, 32), + pages_per_compute_block=(4, 8), + k_splits=(4, 16), + attn_logits_soft_cap=(None,), + quantize_k=(True, False), + quantize_v=(True, False), + quant_dtype=(jnp.float8_e4m3fn, jnp.int8), + ) + def test_quantized_paged_attention( + self, + dtype, + page_size, + num_kv_heads, + q_kv_head_ratio, + head_dim, + block_h, + pages_per_compute_block, + k_splits, + attn_logits_soft_cap, + quantize_k, + quantize_v, + quant_dtype, + ): + if not quantize_k and not quantize_v: + self.skipTest("Skipping since neither (k, v) quantization requested.") + max_kv_len = 2048 + seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32) + q, k_pages, v_pages, block_tables = _generate_qkv( + seq_lens.shape[0], + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + jax.random.key(0), + dtype, + ) + k = _reconstruct_kv(block_tables, k_pages) + v = _reconstruct_kv(block_tables, v_pages) + + k_, k_scales = (_quantize(k_pages, quant_dtype) + if quantize_k else (k_pages, None)) + v_, v_scales = (_quantize(v_pages, quant_dtype) + if quantize_v else (v_pages, None)) + + o = paged_attention.paged_attention( + q, + k_, + v_, + block_tables, + seq_lens,
k_scales_pages=k_scales, + v_scales_pages=v_scales, + block_h=block_h, + pages_per_compute_block=pages_per_compute_block, + k_splits=k_splits, + attn_logits_soft_cap=attn_logits_soft_cap, + interpret=self.INTERPRET, + ) + + o_ref = paged_attention.paged_attention_reference(q, k, v, lengths=seq_lens) + + error = (jnp.linalg.norm((o - o_ref).astype(jnp.float32), axis=-1) + / jnp.linalg.norm(o_ref.astype(jnp.float32))) + + admissible_error = 3e-1 + self.assertLessEqual(jnp.mean(error), admissible_error) + class PagedAttentionInterpretTest(PagedAttentionKernelTest): INTERPRET = True From e61ee7e226535c8775a2c2c11af2f38979a02a24 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 11 Jun 2025 12:25:06 -0700 Subject: [PATCH 1642/1769] Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client. Allows extending the time for distributed CPU jobs to re-connect during setup. PiperOrigin-RevId: 770267880 --- jaxlib/_jax/__init__.pyi | 2 ++ jaxlib/xla.cc | 17 +++++++++++++++-- jaxlib/xla_client.py | 4 ++++ jaxlib/xla_client.pyi | 2 ++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 7bab0fc35547..91fc49197054 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -580,6 +580,8 @@ def get_tfrt_cpu_client( num_nodes: int = ..., collectives: CpuCollectives | None = ..., num_devices: int | None = ..., + get_local_topology_timeout_minutes: int | None = ..., + get_global_topology_timeout_minutes: int | None = ..., ) -> Client: ... def get_mock_gpu_client( asynchronous: bool = ..., diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 69acd43b804a..969c2449f463 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -335,7 +335,10 @@ NB_MODULE(_jax, m) { std::shared_ptr distributed_client, int node_id, int num_nodes, std::shared_ptr collectives, - std::optional num_devices) -> nb_class_ptr { + std::optional num_devices, + std::optional get_local_topology_timeout_minutes, + std::optional get_global_topology_timeout_minutes) + -> nb_class_ptr { std::unique_ptr ifrt_client; { nb::gil_scoped_release gil_release; @@ -357,6 +360,14 @@ NB_MODULE(_jax, m) { ifrt_options.process_id = node_id; ifrt_options.num_processes = num_nodes; } + if (get_local_topology_timeout_minutes.has_value()) { + ifrt_options.get_local_topology_timeout = + absl::Minutes(*get_local_topology_timeout_minutes); + } + if (get_global_topology_timeout_minutes.has_value()) { + ifrt_options.get_global_topology_timeout = + absl::Minutes(*get_global_topology_timeout_minutes); + } ifrt_client = ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options))); } @@ -366,7 +377,9 @@ NB_MODULE(_jax, m) { nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, nb::arg("collectives").none() = std::shared_ptr(), - nb::arg("num_devices").none() = std::nullopt); + nb::arg("num_devices").none() = std::nullopt, + nb::arg("get_local_topology_timeout_minutes").none() = std::nullopt, + nb::arg("get_global_topology_timeout_minutes").none() = std::nullopt); m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 911385f398bc..3766ffce62d8 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -68,6 +68,8 @@ def make_cpu_client( num_nodes=1, collectives=None, num_devices=None, + get_local_topology_timeout_minutes=None, + get_global_topology_timeout_minutes=None, ) -> Client: register_custom_call_handler('cpu', 
_xla.register_custom_call_target) register_custom_type_id_handler('cpu', _xla.register_custom_type_id) @@ -78,6 +80,8 @@ def make_cpu_client( num_nodes=num_nodes, collectives=collectives, num_devices=num_devices, + get_local_topology_timeout_minutes=get_local_topology_timeout_minutes, + get_global_topology_timeout_minutes=get_global_topology_timeout_minutes, ) diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi index 72e85500d1fe..ce9a2b815809 100644 --- a/jaxlib/xla_client.pyi +++ b/jaxlib/xla_client.pyi @@ -64,6 +64,8 @@ def make_cpu_client( num_nodes: int = ..., collectives: _xla.CpuCollectives | None = ..., num_devices: int | None = ..., + get_local_topology_timeout_minutes: int | None = ..., + get_global_topology_timeout_minutes: int | None = ..., ) -> Client: ... def make_gpu_client( distributed_client: DistributedRuntimeClient | None = ..., From 97ab9e0b57c92bca7a45378e922fb87a38bab003 Mon Sep 17 00:00:00 2001 From: Matt-Hurd Date: Fri, 30 May 2025 19:50:45 +0000 Subject: [PATCH 1643/1769] [XProf] Change tensorboard-plugin-profile to new xprof package --- build/collect-profile-requirements.txt | 4 +- docs/profiling.md | 87 +++++++++++++++++--------- jax/collect_profile.py | 22 +++---- pyproject.toml | 2 +- setup.py | 5 ++ tests/profiler_test.py | 10 +-- 6 files changed, 78 insertions(+), 52 deletions(-) diff --git a/build/collect-profile-requirements.txt b/build/collect-profile-requirements.txt index e58558fd29a6..a334f408e271 100644 --- a/build/collect-profile-requirements.txt +++ b/build/collect-profile-requirements.txt @@ -1,5 +1,5 @@ # TF hasn't released 3.13 wheels yet (b/402590302) tensorflow; python_version<"3.13" -tensorboard-plugin-profile<=2.19.0 -# Needed for the profile plugin to work without error +xprof>=2.19.0 +# Needed for XProf to work without error protobuf diff --git a/docs/profiling.md b/docs/profiling.md index c33e79c1dc0c..282eac9080db 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -8,7 +8,7 @@ We can use the JAX profiler to generate traces of a JAX program that can be visualized using the [Perfetto visualizer](https://ui.perfetto.dev). Currently, this method blocks the program until a link is clicked and the Perfetto UI loads the trace. If you wish to get profiling information without any interaction, -check out the Tensorboard profiler below. +check out the XProf profiler below. ```python with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): @@ -64,37 +64,37 @@ Also, by default, the program will prompt you to open a link to file and open a visualizer. This feature is disabled by passing in `--no_perfetto_link` into the command. Alternatively, you can also point Tensorboard to the `log_dir` to analyze the trace (see the -"Tensorboard Profiling" section below). +"XProf (Tensorboard Profiling)" section below). (tensorboard-profiling)= -## TensorBoard profiling +## XProf (TensorBoard profiling) -[TensorBoard's -profiler](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) -can be used to profile JAX programs. Tensorboard is a great way to acquire and +[XProf](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) +can be used to profile JAX programs. XProf is a great way to acquire and visualize performance traces and profiles of your program, including activity on GPU and TPU. 
The end result looks something like this: -![TensorBoard profiler example](_static/tensorboard_profiler.png) +![XProf example](_static/tensorboard_profiler.png) ### Installation -The TensorBoard profiler is available as a plugin to TensorBoard +XProf is available as a plugin to TensorBoard, as well as an independently +run program. ```shell -pip install tensorboard tensorboard-plugin-profile +pip install xprof ``` -If you already have TensorBoard installed, you only need to install the -`tensorboard-plugin-profile` pip package. Be careful to only install one version -of TensorFlow or TensorBoard, otherwise you may encounter the "duplicate -plugins" error described {ref}`below `. See +If you have TensorBoard installed, the `xprof` pip package will also install +the TensorBoard Profiler plugin. Be careful to only install one version of +TensorFlow or TensorBoard, otherwise you may encounter the "duplicate plugins" +error described {ref}`below `. See for more information on installing TensorBoard. Profiling with the nightly version of TensorBoard requires the nightly -tensorboard profiler plugin +XProf. ```shell -pip install tb-nightly tbp-nightly +pip install tb-nightly xprof-nightly ``` ### Programmatic capture @@ -138,29 +138,54 @@ with jax.profiler.trace("/tmp/tensorboard"): y.block_until_ready() ``` +### Viewing the trace + +After capturing a trace, you can view it using either the standalone XProf +tool or the TensorBoard UI. The profiler interface is the same in both cases. + +#### Using Standalone XProf + +You can launch the profiler UI directly using the standalone XProf command by +pointing it to your log directory: + +``` +$ xprof --port 8791 /tmp/tensorboard +Attempting to start XProf server: + Log Directory: /tmp/tensorboard + Port: 8791 +XProf at http://localhost:8791/ (Press CTRL+C to quit) +``` + +Navigate to the provided URL (e.g., http://localhost:8791/) in your browser +to view the profile. + +Available traces appear in the "Runs" dropdown menu on the left. Select the +run you're interested in, and then under the "Tools" dropdown, select +trace_viewer. You should now see a timeline of the execution. You can use the +WASD keys to navigate the trace, and click or drag to select events for more +details. See +[these TensorFlow docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) +for more details on using the trace viewer. + +#### With TensorBoard + To view the trace, first start TensorBoard if you haven't already: ```shell $ tensorboard --logdir=/tmp/tensorboard [...] Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all -TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit) +TensorBoard 2.20.0 at http://localhost:6006/ (Press CTRL+C to quit) ``` -You should be able to load TensorBoard at <http://localhost:6006/> in this -example. You can specify a different port with the `--port` flag. See -{ref}`remote_profiling` below if running JAX on a remote server. -Then, either select "Profile" in the upper-right dropdown menu, or go directly -to <http://localhost:6006/#profile>. Available traces appear in the "Runs" -dropdown menu on the left. Select the run you're interested in, and then under -"Tools", select `trace_viewer`. You should now see a timeline of the -execution. You can use the WASD keys to navigate the trace, and click or drag to -select events to see more details at the bottom.
See [these TensorFlow -docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) -for more details on using the trace viewer. +You should be able to load TensorBoard at http://localhost:6006/ in this +example. Then, select "Profile" from the dropdown menu in the upper-right, +or navigate directly to http://localhost:6006/#profile. -You can also use the `memory_viewer`, `op_profile`, and `graph_viewer` tools. +From there, the experience is the same as the standalone tool: available +traces appear in the "Runs" dropdown menu on the left. Select the run +you're interested in, and then under "Tools", select trace_viewer to see the +timeline. ### Manual capture via TensorBoard @@ -306,8 +331,8 @@ replace, so it may be necessary to uninstall everything and reinstall a single version: ```shell -pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile tbp-nightly -pip install tensorboard tensorboard-plugin-profile +pip uninstall tensorflow tf-nightly tensorboard tb-nightly xprof xprof-nightly tensorboard-plugin-profile tbp-nightly +pip install tensorboard xprof ``` ## Nsight diff --git a/jax/collect_profile.py b/jax/collect_profile.py index b355816772a1..2c725ce8e9e2 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -23,15 +23,11 @@ # pytype: disable=import-error from jax._src import profiler as jax_profiler try: - from tensorflow.python.profiler import profiler_v2 as profiler - from tensorflow.python.profiler import profiler_client -except ImportError: - raise ImportError("This script requires `tensorflow` to be installed.") -try: - from tensorboard_plugin_profile.convert import raw_to_tool_data as convert + from xprof.convert import _pywrap_profiler_plugin + from xprof.convert import raw_to_tool_data as convert except ImportError: raise ImportError( - "This script requires `tensorboard_plugin_profile` to be installed.") + "This script requires `xprof` to be installed.") # pytype: enable=import-error @@ -69,13 +65,13 @@ def collect_profile(port: int, duration_in_ms: int, host: str, log_dir: os.PathLike | str | None, host_tracer_level: int, device_tracer_level: int, python_tracer_level: int, no_perfetto_link: bool): - options = profiler.ProfilerOptions( - host_tracer_level=host_tracer_level, - device_tracer_level=device_tracer_level, - python_tracer_level=python_tracer_level, - ) + options = { + "host_tracer_level": host_tracer_level, + "device_tracer_level": device_tracer_level, + "python_tracer_level": python_tracer_level, + } log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp()) - profiler_client.trace( + _pywrap_profiler_plugin.trace( f"{host}:{port}", str(log_dir_), duration_in_ms, diff --git a/pyproject.toml b/pyproject.toml index d48351197b54..232831774bf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ module = [ "rich.*", "scipy.*", "setuptools.*", - "tensorboard_plugin_profile.convert.*", + "xprof.convert.*", "tensorflow.*", "tensorflow.io.*", "tensorflowjs.*", diff --git a/setup.py b/setup.py index 2b50b041008d..244c069ed714 100644 --- a/setup.py +++ b/setup.py @@ -114,6 +114,11 @@ def load_version_module(pkg_path): 'k8s': [ 'kubernetes', ], + + # For including XProf server + 'xprof': [ + 'xprof', + ], }, url='https://github.com/jax-ml/jax', license='Apache-2.0', diff --git a/tests/profiler_test.py b/tests/profiler_test.py index d577f1c24c49..9784298405d2 100644 --- a/tests/profiler_test.py +++ 
b/tests/profiler_test.py @@ -44,11 +44,11 @@ profiler_client = None tf_profiler = None -TBP_ENABLED = False +XPROF_ENABLED = False try: - import tensorboard_plugin_profile - del tensorboard_plugin_profile - TBP_ENABLED = True + import xprof + del xprof + XPROF_ENABLED = True except ImportError: pass @@ -296,7 +296,7 @@ def on_profile(port, logdir, worker_start): self._check_xspace_pb_exist(logdir) @unittest.skipIf( - not (portpicker and profiler_client and tf_profiler and TBP_ENABLED), + not (portpicker and profiler_client and tf_profiler and XPROF_ENABLED), "Test requires tensorflow.profiler, portpicker and " "tensorboard_profile_plugin") def test_remote_profiler(self): From 004b54807f2c4d303515f81c5c483d864790d018 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 11 Jun 2025 13:40:36 -0700 Subject: [PATCH 1644/1769] Add alternative location of `CUDA_ROOT` for Bazel build/tests with hermetic CUDA. PiperOrigin-RevId: 770300064 --- jax/BUILD | 5 +++++ jax/_src/lib/__init__.py | 12 ++++++++++++ jax/experimental/mosaic/gpu/core.py | 15 ++++++++++----- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 8a83e1c5df37..8562c83cbad3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1179,6 +1179,11 @@ pytype_strict_library( py_library_providing_imports_info( name = "mosaic_gpu", srcs = glob(["experimental/mosaic/gpu/*.py"]), + data = [ + "@cuda_nvcc//:nvdisasm", + "@cuda_nvcc//:nvvm", + "@cuda_nvcc//:ptxas", + ], visibility = [ ":mosaic_gpu_users", ], diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index b1f39a5ca93d..fde926094e8b 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -172,10 +172,22 @@ def _try_cuda_nvcc_import() -> str | None: return str(cuda_nvcc_path) + def _try_bazel_runfiles() -> str | None: + """Try to get the path to the cuda installation in bazel runfiles.""" + python_runfiles = os.environ.get('PYTHON_RUNFILES') + if not python_runfiles: + return None + cuda_nvcc_root = os.path.join(python_runfiles, 'cuda_nvcc') + if os.path.exists(cuda_nvcc_root): + return cuda_nvcc_root + return None + if (path := _try_cuda_root_environment_variable()) is not None: return path elif (path := _try_cuda_nvcc_import()) is not None: return path + elif (path := _try_bazel_runfiles()) is not None: + return path return None diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 652b22050c7d..21914aa6c8a3 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -55,14 +55,19 @@ # MLIR can't find libdevice unless we point it to the CUDA path # TODO(apaszke): Unify with jax._src.lib.cuda_path -CUDA_ROOT = "/usr/local/cuda" +cuda_root = "/usr/local/cuda" +PYTHON_RUNFILES = os.environ.get("PYTHON_RUNFILES") if os.environ.get("CUDA_ROOT") is None: - os.environ["CUDA_ROOT"] = CUDA_ROOT + if PYTHON_RUNFILES: + cuda_nvcc_root = os.path.join(PYTHON_RUNFILES, "cuda_nvcc") + if os.path.exists(cuda_nvcc_root): + cuda_root = cuda_nvcc_root + os.environ["CUDA_ROOT"] = cuda_root else: - CUDA_ROOT = os.environ["CUDA_ROOT"] + cuda_root = os.environ["CUDA_ROOT"] -PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") -NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") +PTXAS_PATH = os.path.join(cuda_root, "bin/ptxas") +NVDISASM_PATH = os.path.join(cuda_root, "bin/nvdisasm") # This tracks the latest Mosaic GPU IR version with a monthly delay. 
FWD_COMPAT_IR_VERSION = 1 From 3ef4db4d7f16e7a339b246c75a0d24b69fc7f5a5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Jun 2025 14:56:20 -0700 Subject: [PATCH 1645/1769] Add a pytype disable around zstandard. Fail gracefully if zstandard is not present. PiperOrigin-RevId: 770333025 --- jax/_src/compilation_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index c8a0f0715a2d..db2730bb22bc 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -23,7 +23,7 @@ # If zstandard is installed, we use zstd compression, otherwise we use zlib. try: - import zstandard + import zstandard # pytype: disable=import-error except ImportError: zstandard = None From cc976cb015437bd02324cce2cb26ccbcf60f59eb Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 11 Jun 2025 15:24:37 -0700 Subject: [PATCH 1646/1769] Add execution to unreduced tests now that it works end-to-end PiperOrigin-RevId: 770343728 --- tests/pjit_test.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 35f9b6cda7e5..acd55d5fa704 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7810,14 +7810,19 @@ def f(x, y, a, b): self.assertEqual(out.aval.sharding.spec, P('x', None)) return out + out = f(x, y, a, b) + self.assertArraysEqual(out, (np_inp @ np_inp.T) + (np_inp @ np_inp.T)) + traced = f.trace(x, y, a, b) lowered_text = traced.lower().as_text() self.assertIn('unreduced={"y"}', lowered_text) self.assertEqual(lowered_text.count('unreduced={"y"}'), 3) - # TODO(yashkatariya): Execute this too - grad_jaxpr = jax.jit(jax.grad(lambda x, y, a, b: f(x, y, a, b).sum(), - argnums=(0, 1, 2, 3))).trace(x, y, a, b).jaxpr + f_bar = jax.jit(jax.grad(lambda x, y, a, b: f(x, y, a, b).sum(), + argnums=(0, 1, 2, 3))) + f_bar(x, y, a, b) # doesn't crash + + grad_jaxpr = f_bar.trace(x, y, a, b).jaxpr reshard_eqn = grad_jaxpr.eqns[4].params['jaxpr'].eqns[0] self.assertEqual(reshard_eqn.params['dst_sharding'].spec.reduced, frozenset('y')) From 45e61d8c9cb1df0d15273ad546aaeaa83b366ee2 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Wed, 11 Jun 2025 15:58:16 -0700 Subject: [PATCH 1647/1769] Add nightly linux jax wheel tests for python 3.14.0b1 PiperOrigin-RevId: 770357177 --- .github/workflows/pytest_cpu.yml | 17 ++++++++++++++++- .github/workflows/pytest_cuda.yml | 17 ++++++++++++++++- .../workflows/wheel_tests_nightly_release.yml | 6 ++++-- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index d23c1f543827..71b7d49f8049 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -130,10 +130,25 @@ jobs: - name: Install Python dependencies run: | $JAXCI_PYTHON -m pip install uv~=0.5.30 + + # For prerelease python 3.14, some pre-built dependency wheels aren't available, + # so we need to download their deps or build them from source. + if [[ $JAXCI_PYTHON == "python3.14" ]]; then + # Build numpy from source + # Need to include fixes for https://github.com/numpy/numpy/issues/28681. 
+ $JAXCI_PYTHON -m uv pip install "git+https://github.com/numpy/numpy@v2.3.0" + + # Install build requirements for scipy + apt update && apt upgrade -y && apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + + # Install build requirements for pillow + apt install -q -y libjpeg-dev --no-install-recommends + fi + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt # CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632 - if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then + if [[ $OS == "linux" && $ARCH == "aarch64" && $JAXCI_HERMETIC_PYTHON_VERSION != "3.14" ]]; then $JAXCI_PYTHON -m uv pip install numpy~=2.1.0 fi # Halt for testing diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index d576370bb772..644aceaa803e 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -131,7 +131,22 @@ jobs: echo "Skipping the test run." exit 1 - name: Install Python dependencies - run: $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + run: | + # For prerelease python 3.14, some pre-built dependency wheels aren't available, + # so we need to download their deps or build them from source. + if [[ $JAXCI_PYTHON == "python3.14" ]]; then + # Build numpy from source + # Need to include fixes for https://github.com/numpy/numpy/issues/28681. + $JAXCI_PYTHON -m uv pip install "git+https://github.com/numpy/numpy@v2.3.0" + + # Install build requirements for scipy + apt update && apt upgrade -y && apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + + # Install build requirements for pillow + apt install -q -y libjpeg-dev --no-install-recommends + fi + + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 6d25ee281c7b..d90d05b64b28 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -43,11 +43,13 @@ jobs: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] + python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil", "3.14"] enable-x64: [0] exclude: - runner: "windows-x86-n2-64" python: "3.13-nogil" + - runner: "windows-x86-n2-64" + python: "3.14" name: "Pytest CPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.runner }} @@ -65,7 +67,7 @@ jobs: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. 
runner: ["linux-x86-g2-48-l4-4gpu"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] + python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil", "3.14"] cuda: [ {cuda-version: "12.1", use-nvidia-pip-wheels: false}, {cuda-version: "12.8", use-nvidia-pip-wheels: true} From e7d252dc448ab6226ce78191d785a49e0efdceb7 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Wed, 11 Jun 2025 18:57:59 -0700 Subject: [PATCH 1648/1769] Fix GPU quantized paged attention tests for < sm89 --- tests/pallas/gpu_paged_attention_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/pallas/gpu_paged_attention_test.py b/tests/pallas/gpu_paged_attention_test.py index 7a605ca4677b..1b778c787a6d 100644 --- a/tests/pallas/gpu_paged_attention_test.py +++ b/tests/pallas/gpu_paged_attention_test.py @@ -179,7 +179,7 @@ def test_paged_attention( attn_logits_soft_cap=(None,), quantize_k=(True, False), quantize_v=(True, False), - quant_dtype=(jnp.float8_e4m3fn, jnp.int8), + quant_dtype=(jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.int8), ) def test_quantized_paged_attention( self, @@ -198,6 +198,9 @@ def test_quantized_paged_attention( ): if not quantize_k and not quantize_v: self.skipTest("Skipping since neither (k, v) quantization requested.") + if (quant_dtype == jnp.float8_e4m3fn + and not jtu.is_cuda_compute_capability_at_least("8.9")): + self.skipTest("Skipping since float8_e4m3fn is not supported on < sm89") max_kv_len = 2048 seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32) q, k_pages, v_pages, block_tables = _generate_qkv( From 4ad4d4a8855e6a30e560fd3950e6b8698771f6bf Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 12 Jun 2025 00:10:25 -0700 Subject: [PATCH 1649/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4d1cd8faa246a936bb70790fc7d21d6c236d2163. PiperOrigin-RevId: 770508877 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ca58bb6f0250..bf92aff487a4 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "870d90fd098c480fb8a426126bd02047adb2bc20" -XLA_SHA256 = "963b285bbc6f40a198833a14effc4f38f75b9c5d1813ccef4f09c287d0cb9ae4" +XLA_COMMIT = "4d1cd8faa246a936bb70790fc7d21d6c236d2163" +XLA_SHA256 = "fd9aee891ef0a38507d59cffc1540bf0f6653911c55c80c20a648e83d974bfbe" def repo(): tf_http_archive( From 87ce7d1bee969419aece3dae76c48a99cedcd036 Mon Sep 17 00:00:00 2001 From: DanisNone Date: Wed, 11 Jun 2025 23:27:26 +0500 Subject: [PATCH 1650/1769] add jax.nn module type hints (__init__.pyi) --- .pre-commit-config.yaml | 2 +- jax/BUILD | 1 + jax/nn/__init__.pyi | 92 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 jax/nn/__init__.pyi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8cc28c9fe4ac..4185623f6e1f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: hooks: - id: mypy files: (jax/|tests/typing_test\.py) - exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jaxlib/_jax/.* # Use pyi instead + exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jax/nn/__init__.py|jaxlib/_jax/.* # Use pyi instead additional_dependencies: [types-requests==2.31.0, numpy>=2.2.0] args: [--config=pyproject.toml] diff --git a/jax/BUILD b/jax/BUILD index 8a83e1c5df37..58eccaa2e0a2 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -349,6 +349,7 @@ py_library_providing_imports_info( lib_rule = pytype_library, pytype_srcs = glob( [ + "nn/*.pyi", "numpy/*.pyi", "_src/**/*.pyi", ], diff --git a/jax/nn/__init__.pyi b/jax/nn/__init__.pyi new file mode 100644 index 000000000000..89a3298e96b2 --- /dev/null +++ b/jax/nn/__init__.pyi @@ -0,0 +1,92 @@ +from typing import Any, List, Literal, Union, overload, Sequence + +from jax._src.typing import ( + Array, ArrayLike, DTypeLike +) +from jax._src.core import AxisName +from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig +from jax._src.lax.lax import DotDimensionNumbers + +from jax.nn import initializers as initializers + +_Axis = Union[None, int, Sequence[int]] + + +def celu(x: ArrayLike, alpha: ArrayLike = ...) -> Array: ... +def dot_product_attention( + query: ArrayLike, + key: ArrayLike, + value: ArrayLike, + bias: ArrayLike | None = ..., + mask: ArrayLike | None = ..., + *, + scale: float | None = ..., + is_causal: bool = ..., + query_seq_lengths: ArrayLike | None = ..., + key_value_seq_lengths: ArrayLike | None = ..., + local_window_size: int | tuple[int, int] | None = ..., + implementation: Literal['xla', 'cudnn'] | None = ...) -> Array: ... +def elu(x: ArrayLike, alpha: ArrayLike = ...) -> Array: ... +def gelu(x: ArrayLike, approximate: bool = ...) -> Array: ... +def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'], + global_scale: Array | None = ...) -> BlockScaleConfig: ... +def glu(x: ArrayLike, axis: int = ...) -> Array: ... +def hard_sigmoid(x: ArrayLike) -> Array: ... +def hard_silu(x: ArrayLike) -> Array: ... +def hard_swish(x: ArrayLike) -> Array: ... +def hard_tanh(x: ArrayLike) -> Array: ... +def identity(x: ArrayLike) -> Array: ... +def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = ...) -> Array: ... +def log_sigmoid(x: ArrayLike) -> Array: ... +def log_softmax(x: ArrayLike, + axis: int | tuple[int, ...] | None = ..., + where: ArrayLike | None = ...) -> Array: ...
+@overload +def logsumexp(a: ArrayLike, axis: _Axis = ..., b: ArrayLike | None = ..., + keepdims: bool = ..., return_sign: Literal[False] = ..., where: ArrayLike | None = ...) -> Array: ... + +@overload +def logsumexp(a: ArrayLike, axis: _Axis = ..., b: ArrayLike | None = ..., + keepdims: bool = ..., *, return_sign: Literal[True], where: ArrayLike | None = ...) -> tuple[Array, Array]: ... + +@overload +def logsumexp(a: ArrayLike, axis: _Axis = ..., b: ArrayLike | None = ..., + keepdims: bool = ..., return_sign: bool = ..., where: ArrayLike | None = ...) -> Array | tuple[Array, Array]: ... +def mish(x: ArrayLike) -> Array: ... +def one_hot(x: Any, num_classes: int, *, + dtype: Any = ..., axis: int | AxisName = ...) -> Array: ... +def relu(x: ArrayLike) -> Array: ... +def relu6(x: ArrayLike) -> Array: ... +def scaled_dot_general( + lhs: ArrayLike, rhs: ArrayLike, + dimension_numbers: DotDimensionNumbers, + preferred_element_type: DTypeLike = ..., + configs: List[BlockScaleConfig] | None = ..., + implementation: Literal['cudnn'] | None = ..., + ) -> Array: ... +def scaled_matmul( + lhs: Array, + rhs: Array, + lhs_scales: Array, + rhs_scales: Array, + preferred_element_type: DTypeLike = ..., +) -> Array: ... +def selu(x: ArrayLike) -> Array: ... +def sigmoid(x: ArrayLike) -> Array: ... +def silu(x: ArrayLike) -> Array: ... +def soft_sign(x: ArrayLike) -> Array: ... +def softmax(x: ArrayLike, + axis: int | tuple[int, ...] | None = ..., + where: ArrayLike | None = ...) -> Array: ... +def softplus(x: ArrayLike) -> Array: ... +def sparse_plus(x: ArrayLike) -> Array: ... +def sparse_sigmoid(x: ArrayLike) -> Array: ... +def squareplus(x: ArrayLike, b: ArrayLike = ...) -> Array: ... +def standardize(x: ArrayLike, + axis: int | tuple[int, ...] | None = ..., + mean: ArrayLike | None = ..., + variance: ArrayLike | None = ..., + epsilon: ArrayLike = ..., + where: ArrayLike | None = ...) -> Array: ... +def swish(x: ArrayLike) -> Array: ... +def tanh(x: ArrayLike, /) -> Array: ... From a28803636255cd9639f5406f8bc3ad516c61dcba Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Jun 2025 06:57:42 -0700 Subject: [PATCH 1651/1769] Move NamedSharding.__eq__ and NamedSharding.__hash__ into C++. * change NamedSharding::spec_ to use a more precise type, now PartitionSpec is a type known to C++. * update PartitionSpec::Eq and PartitionSpec::Hash to be more usable from other C++ classes. * add a thread-safe cache for the hash value of a NamedSharding. * move the existing ShardingHash and ShardingEqual functions into jax_jit and rename them, since they are implementations that only really make sense for jit. 
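For illustration, a short behavioral sketch from the Python side (an editorial example, not part of this patch): equality and hashing semantics are unchanged; only their implementation moves into C++, with the hash cached after its first computation.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("x",))
s1 = NamedSharding(mesh, P("x"))
s2 = NamedSharding(mesh, P("x"))

# Equal shardings still compare equal and hash identically; repeated
# hash() calls now return the value cached on first use.
assert s1 == s2
assert hash(s1) == hash(s2)
```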
PiperOrigin-RevId: 770635344 --- jax/_src/named_sharding.py | 3 + jaxlib/BUILD | 15 +++++ jaxlib/cached_py_object.h | 61 +++++++++++++++++++ jaxlib/jax_jit.cc | 66 ++++++++++++++++++++- jaxlib/jax_jit.h | 10 +++- jaxlib/partition_spec.cc | 12 ++-- jaxlib/partition_spec.h | 6 +- jaxlib/py_array.cc | 2 +- jaxlib/sharding.cc | 118 ++++++++++++++----------------------- jaxlib/sharding.h | 21 +++---- jaxlib/xla_client.py | 2 +- tests/pjit_test.py | 3 +- 12 files changed, 224 insertions(+), 95 deletions(-) create mode 100644 jaxlib/cached_py_object.h diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index bae999e1e83d..60336a0822e6 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -22,6 +22,7 @@ from jax._src import config from jax._src.util import use_cpp_class, cache, use_cpp_method +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib @@ -145,12 +146,14 @@ def __reduce__(self): def memory_kind(self) -> str | None: return self._memory_kind + @use_cpp_method(jaxlib_extension_version >= 353) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( (self.mesh, self.memory_kind, self.spec, self._logical_device_ids)) return self._hash + @use_cpp_method(jaxlib_extension_version >= 353) def __eq__(self, other): if not isinstance(other, NamedSharding): return False diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 9f495087ccb0..c95e29c2f6b1 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -828,6 +828,20 @@ cc_library( ], ) +cc_library( + name = "cached_py_object", + hdrs = ["cached_py_object.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/functional:function_ref", + "@nanobind", + ], +) + cc_library( name = "py_client", srcs = [ @@ -867,6 +881,7 @@ cc_library( features = ["-use_header_modules"], visibility = jax_visibility("jaxlib/py_client"), deps = [ + ":cached_py_object", ":guard_lib", ":nb_class_ptr", ":py_client_cpu", diff --git a/jaxlib/cached_py_object.h b/jaxlib/cached_py_object.h new file mode 100644 index 000000000000..b934fa203a44 --- /dev/null +++ b/jaxlib/cached_py_object.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_CACHED_PY_OBJECT_H_ +#define JAX_JAXLIB_CACHED_PY_OBJECT_H_ + +#include + +#include "absl/functional/function_ref.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// A lock-free thread-safe cache for a single Python object. +// Example use case: caching a hash value in an object. +class CachedPyObject { + public: + CachedPyObject() = default; + ~CachedPyObject() { + PyObject* value = value_.load(); + Py_XDECREF(value); + } + + // Returns the cached value of the object. 
If the object is not present, + // factory() will be called to create it and the cache will be populated. + // Note: factory() may be called multiple times if used concurrently. The + // returned value will be one of the returned values of factory(). + // Thread-safe. + nanobind::object Get(absl::FunctionRef factory) { + PyObject* v = value_.load(); + if (v) { + return nanobind::borrow(v); + } + nanobind::object new_value = factory(); + if (value_.compare_exchange_strong(v, new_value.inc_ref().ptr())) { + return new_value; + } else { + new_value.dec_ref(); + return nanobind::borrow(v); + } + } + + private: + std::atomic value_ = nullptr; +}; + +} // namespace jax + +#endif // JAX_JAXLIB_CACHED_PY_OBJECT_H_ diff --git a/jaxlib/jax_jit.cc b/jaxlib/jax_jit.cc index c48aa6ab7d19..53263949e72e 100644 --- a/jaxlib/jax_jit.cc +++ b/jaxlib/jax_jit.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -232,6 +233,69 @@ std::string CallSignature::DebugString() const { absl::StrJoin(configs, ", ", py_object_formatter)); } + +size_t HashShardingForJit(nb::handle sharding) { + auto type = sharding.type(); + + if (type.is(NamedSharding::type())) { + const auto* named_sharding = nb::inst_ptr(sharding); + return absl::Hash()(named_sharding->mesh().ptr()); + } + + if (type.is(GSPMDSharding::type())) { + auto* gspmd_sharding = nb::inst_ptr(sharding); + return gspmd_sharding->Hash(); + } + + if (type.is(SingleDeviceSharding::type())) { + auto* single_device_sharding = nb::inst_ptr(sharding); + return absl::Hash()(single_device_sharding->device().ptr()); + } + + return nb::hash(sharding); +} + +bool EqualShardingsForJit(nb::handle a, nb::handle b) { + if (a.ptr() == b.ptr()) return true; + + auto a_type = a.type(); + auto b_type = b.type(); + + if (!a_type.is(b_type)) return false; + + if (a_type.is(NamedSharding::type())) { + auto* a_named_sharding = nb::inst_ptr(a); + auto* b_named_sharding = nb::inst_ptr(b); + return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && + *a_named_sharding->spec() == *b_named_sharding->spec() && + a_named_sharding->memory_kind().equal( + b_named_sharding->memory_kind()) && + a_named_sharding->logical_device_ids().equal( + b_named_sharding->logical_device_ids()); + } + + if (a_type.is(GSPMDSharding::type())) { + auto* a_gspmd_sharding = nb::inst_ptr(a); + auto* b_gspmd_sharding = nb::inst_ptr(b); + + return a_gspmd_sharding == b_gspmd_sharding; + } + + if (a_type.is(SingleDeviceSharding::type())) { + auto* a_single_device_sharding = + nb::inst_ptr(a); + auto* b_single_device_sharding = + nb::inst_ptr(b); + + return a_single_device_sharding->device().ptr() == + b_single_device_sharding->device().ptr() && + a_single_device_sharding->memory_kind().equal( + b_single_device_sharding->memory_kind()); + } + + return a.equal(b); +} + bool CallSignature::operator==(const CallSignature& other) const { if (arg_signature != other.arg_signature) { return false; @@ -251,7 +315,7 @@ bool CallSignature::operator==(const CallSignature& other) const { return // `==` on py:objects is the Python `is`. We need equal. 
          absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings,
-                       ShardingEqual) &&
+                       EqualShardingsForJit) &&
          absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts,
                        [](const std::shared_ptr& a,
                           const std::shared_ptr& b) {
diff --git a/jaxlib/jax_jit.h b/jaxlib/jax_jit.h
index dc025e63f1de..0061514e3cfb 100644
--- a/jaxlib/jax_jit.h
+++ b/jaxlib/jax_jit.h
@@ -227,6 +227,14 @@ struct CallSignature {
   std::string DebugString() const;
 };

+// A hash and equality for shardings that may sometimes return different hashes
+// for equal values, and may sometimes return "not equal" for equal values.
+// These are not correct implementations of `__hash__` and `__eq__` in python,
+// but they are fine for jit/pjit dispatch since they only cause spurious cache
+// misses.
+size_t HashShardingForJit(nanobind::handle sharding);
+bool EqualShardingsForJit(nanobind::handle a, nanobind::handle b);
+
 template
 H AbslHashValue(H h, const CallSignature& s) {
   h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures);
@@ -241,7 +249,7 @@ H AbslHashValue(H h, const CallSignature& s) {
   // slow python hashing function. Consider implementing hashing function and
   // equality checks in C++ in jax::Sharding and use those here.
   for (const auto& sharding : s.dynamic_arg_shardings) {
-    h = H::combine(std::move(h), ShardingHash(sharding));
+    h = H::combine(std::move(h), HashShardingForJit(sharding));
   }

   for (const auto& layout : s.dynamic_arg_layouts) {
diff --git a/jaxlib/partition_spec.cc b/jaxlib/partition_spec.cc
index e43ca4b6108c..2535c38b977b 100644
--- a/jaxlib/partition_spec.cc
+++ b/jaxlib/partition_spec.cc
@@ -141,22 +141,26 @@ PartitionSpec::PartitionSpec(nb::tuple partitions, nb_frozenset unreduced,
       unreduced_(std::move(unreduced)),
       reduced_(std::move(reduced)) {}

-Py_ssize_t PartitionSpec::Hash() const {
+Py_hash_t PartitionSpec::Hash() const {
   size_t h = absl::HashOf(nb::hash(partitions_), nb::hash(unreduced_),
                           nb::hash(reduced_));
   Py_hash_t s = absl::bit_cast(h);  // Python hashes are signed.
   return s == -1 ? -2 : s;  // -1 must not be used as a Python hash value.
} +bool PartitionSpec::operator==(const PartitionSpec& other) const { + return partitions().equal(other.partitions()) && + unreduced().equal(other.unreduced()) && + reduced().equal(other.reduced()); +} + bool PartitionSpec::Eq(const nb::object& other) const { if (!other.ptr() || other.is_none()) { return false; } PartitionSpec* other_spec; if (nb::try_cast(other, other_spec)) { - return partitions().equal(other_spec->partitions()) && - unreduced().equal(other_spec->unreduced()) && - reduced().equal(other_spec->reduced()); + return *this == *other_spec; } nb::tuple other_tuple; if (nb::try_cast(other, other_tuple)) { diff --git a/jaxlib/partition_spec.h b/jaxlib/partition_spec.h index 62c292a0c966..fc207cfe7a28 100644 --- a/jaxlib/partition_spec.h +++ b/jaxlib/partition_spec.h @@ -47,8 +47,10 @@ class PartitionSpec { nb_frozenset unreduced() const { return unreduced_; } nb_frozenset reduced() const { return reduced_; } - bool Eq(const nanobind::object& other) const; - Py_ssize_t Hash() const; + bool operator==(const PartitionSpec& other) const; + + bool Eq(const nanobind::object& other) const; // Python __eq__ + Py_hash_t Hash() const; // Python __hash__ static void Register(nanobind::module_& m); diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 8659cba49dea..b79236f1306a 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -1173,7 +1173,7 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && array_cs == ifrt::ArrayCopySemantics::kReuseInput) { - if (jax::ShardingEqual(py_array.sharding(), dst_sharding)) { + if (py_array.sharding().equal(dst_sharding)) { results[i] = py_arrays[i]; } else { absl::Span shape_span = py_array.shape(); diff --git a/jaxlib/sharding.cc b/jaxlib/sharding.cc index 0514946a729b..77e97c3654bc 100644 --- a/jaxlib/sharding.cc +++ b/jaxlib/sharding.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/base/casts.h" #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -34,6 +35,7 @@ limitations under the License. 
#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" #include "jaxlib/py_client.h" #include "jaxlib/py_device_list.h" #include "jaxlib/sharded_device_array.h" @@ -130,69 +132,6 @@ int Sharding::SafeNumDevices(nb::handle sharding) { return device_set.size(); } -size_t ShardingHash(nb::handle sharding) { - auto type = sharding.type(); - - if (type.is(NamedSharding::type())) { - const auto* named_sharding = nb::inst_ptr(sharding); - return absl::Hash()(named_sharding->mesh().ptr()); - } - - if (type.is(GSPMDSharding::type())) { - auto* gspmd_sharding = nb::inst_ptr(sharding); - return gspmd_sharding->Hash(); - } - - if (type.is(SingleDeviceSharding::type())) { - auto* single_device_sharding = nb::inst_ptr(sharding); - return absl::Hash()(single_device_sharding->device().ptr()); - } - - return nb::hash(sharding); -} - -bool ShardingEqual(nb::handle a, nb::handle b) { - if (a.ptr() == b.ptr()) return true; - - auto a_type = a.type(); - auto b_type = b.type(); - - if (!a_type.is(b_type)) return false; - - if (a_type.is(NamedSharding::type())) { - auto* a_named_sharding = nb::inst_ptr(a); - auto* b_named_sharding = nb::inst_ptr(b); - - return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && - a_named_sharding->spec().equal(b_named_sharding->spec()) && - a_named_sharding->memory_kind().equal( - b_named_sharding->memory_kind()) && - a_named_sharding->logical_device_ids().equal( - b_named_sharding->logical_device_ids()); - } - - if (a_type.is(GSPMDSharding::type())) { - auto* a_gspmd_sharding = nb::inst_ptr(a); - auto* b_gspmd_sharding = nb::inst_ptr(b); - - return a_gspmd_sharding == b_gspmd_sharding; - } - - if (a_type.is(SingleDeviceSharding::type())) { - auto* a_single_device_sharding = - nb::inst_ptr(a); - auto* b_single_device_sharding = - nb::inst_ptr(b); - - return a_single_device_sharding->device().ptr() == - b_single_device_sharding->device().ptr() && - a_single_device_sharding->memory_kind().equal( - b_single_device_sharding->memory_kind()); - } - - return a.equal(b); -} - // This list is to check for valid memory kinds when an AbstractMesh is passed // to NamedSharding. static const std::array valid_memory_kinds = { @@ -201,7 +140,8 @@ static const std::array valid_memory_kinds = { "unpinned_host", }; -NamedSharding::NamedSharding(nb::object mesh, nb::object spec, +NamedSharding::NamedSharding(nb::object mesh, + xla::nb_class_ptr spec, nb::object memory_kind, nb::object logical_device_ids) : Sharding(/*num_devices=*/[&mesh]() { @@ -211,10 +151,6 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, spec_(std::move(spec)), memory_kind_(std::move(memory_kind)), logical_device_ids_(std::move(logical_device_ids)) { - if (spec_.is_none()) { - throw nb::type_error( - "Unexpected None passed as spec for NamedSharding. Did you mean P()?"); - } nb::object idl = nb::object(mesh_.attr("_internal_device_list")); if (idl.is_none()) { internal_device_list_ = std::nullopt; @@ -241,7 +177,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, // TODO(phawkins): this leaks a reference to the check_pspec function. // A better way to fix this would be to move PartitionSpec and this check into // C++. 
- auto init_fn = [](){ + auto init_fn = []() { nb::module_ si = nb::module_::import_("jax._src.named_sharding"); return std::make_unique(si.attr("check_pspec")); }; @@ -256,6 +192,36 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, type_ = nanobind::type().inc_ref().ptr(); } +bool NamedSharding::operator==(const NamedSharding& other) const { + // Caution: you may need to update EqualShardingsForJit in jax_jit.cc as well. + return mesh().equal(other.mesh()) && *spec() == *other.spec() && + memory_kind().equal(other.memory_kind()) && + logical_device_ids().equal(other.logical_device_ids()); +} + +bool NamedSharding::Eq(const nanobind::object& other) const { + if (!other.ptr() || other.is_none()) { + return false; + } + const NamedSharding* other_sharding; + if (!nb::try_cast(other, other_sharding)) { + return false; + } + return this == other_sharding || *this == *other_sharding; +} + +nb::object NamedSharding::Hash() const { + // Caution: you may need to update HashShardingForJit in jax_jit.cc as well. + return hash_.Get([&]() { + size_t h = + absl::HashOf(nb::hash(mesh_), spec_->Hash(), nb::hash(memory_kind_), + nb::hash(logical_device_ids_)); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return nb::cast( + s == -1 ? -2 : s); // -1 must not be used as a Python hash value. + }); +} + SingleDeviceSharding::SingleDeviceSharding(nb::object device, nb::object memory_kind) : Sharding(/*num_devices=*/1), @@ -337,17 +303,21 @@ void RegisterSharding(nb::module_& m) { nb::class_(m, "Sharding").def(nb::init<>()); nb::class_(m, "NamedSharding", nb::dynamic_attr()) - .def(nb::init(), - nb::arg("mesh"), nb::arg("spec").none(), + .def(nb::init, nb::object, + nb::object>(), + nb::arg("mesh"), nb::arg("spec"), nb::arg("memory_kind").none() = nb::none(), nb::arg("_logical_device_ids").none() = nb::none()) .def_prop_ro("mesh", &NamedSharding::mesh) .def_prop_ro("spec", &NamedSharding::spec) .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids) - .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { - return xla::ValueOrThrow(s.internal_device_list()); - }); + .def_prop_ro("_internal_device_list", + [](const NamedSharding& s) { + return xla::ValueOrThrow(s.internal_device_list()); + }) + .def("__eq__", &NamedSharding::Eq, nb::arg().none()) + .def("__hash__", &NamedSharding::Hash); NamedSharding::InitializeType(); nb::class_(m, "SingleDeviceSharding", diff --git a/jaxlib/sharding.h b/jaxlib/sharding.h index cb7c1b471a63..083fb2b5d3ce 100644 --- a/jaxlib/sharding.h +++ b/jaxlib/sharding.h @@ -26,7 +26,9 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "nanobind/nanobind.h" +#include "jaxlib/cached_py_object.h" #include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" #include "jaxlib/py_client.h" #include "jaxlib/py_device_list.h" #include "jaxlib/sharded_device_array.h" @@ -65,21 +67,14 @@ nanobind::object CheckAndCanonicalizeMemoryKind( nanobind::object memory_kind, const xla::nb_class_ptr& device_list); -// Returns a hash that may sometimes return different hashes for equal values. -// It is not a correct implementation of `__hash__` in python, but it's fine -// for jit/pjit dispatch since it only causes spurious cache misses. 
-size_t ShardingHash(nanobind::handle sharding); - -bool ShardingEqual(nanobind::handle a, nanobind::handle b); - class NamedSharding : public Sharding { public: - NamedSharding(nanobind::object mesh, nanobind::object spec, + NamedSharding(nanobind::object mesh, xla::nb_class_ptr spec, nanobind::object memory_kind, nanobind::object logical_device_ids); const nanobind::object& mesh() const { return mesh_; } - const nanobind::object& spec() const { return spec_; } + const xla::nb_class_ptr& spec() const { return spec_; } const nanobind::object& memory_kind() const { return memory_kind_; } const nanobind::object& logical_device_ids() const { return logical_device_ids_; @@ -97,12 +92,18 @@ class NamedSharding : public Sharding { "`jax.sharding.AbstractMesh`"); } + bool operator==(const NamedSharding& other) const; + + bool Eq(const nanobind::object& other) const; // Python __eq__ + nanobind::object Hash() const; // Python __hash__ + private: nanobind::object mesh_; - nanobind::object spec_; + xla::nb_class_ptr spec_; nanobind::object memory_kind_; nanobind::object logical_device_ids_; std::optional> internal_device_list_; + mutable CachedPyObject hash_; static PyObject* type_; }; diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 3766ffce62d8..4b8c1133e090 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 352 +_version = 353 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/pjit_test.py b/tests/pjit_test.py index acd55d5fa704..b2661b39ef4d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -8245,7 +8245,8 @@ def f(a, b, c): def test_named_sharding_of_none(self): mesh = jtu.create_mesh((2,), ('x',)) - with self.assertRaisesRegex(TypeError, 'Unexpected None'): + with self.assertRaisesRegex( + TypeError, '(Unexpected None|incompatible function arguments)'): jax.NamedSharding(mesh, None) From 83c292bd6fc457d128b47003b4961fc3290fbd8e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 12 Jun 2025 07:15:42 -0700 Subject: [PATCH 1652/1769] [Mosaic GPU] Add conversion logic for `i4 -> f8e4m3fn`. The inline PTX is able to upcast 4 values at a time, so we use a generator to pack several registers together when our registers don't hold enough packed values. This makes the generated PTX smaller, since the conversion routine needs to be called less often. 
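As a sanity check on the LUT constants used in the conversion below, here is a
pure-numpy emulation of one lane of the lookup (a sketch only, not the kernel
code; it assumes the `ml_dtypes` package for the f8e4m3fn dtype):

    import numpy as np
    import ml_dtypes

    # The four 32-bit LUT words from the kernel; little-endian byte k of each
    # word is the f8e4m3fn encoding of one int4 value.
    pos = np.array([0x44403800, 0x4E4C4A48], dtype="<u4").view(np.uint8)  # 0..7
    neg = np.array([0xCACCCED0, 0xB8C0C4C8], dtype="<u4").view(np.uint8)  # -8..-1

    def upcast_nibble(nibble: int) -> np.uint8:
        magnitude = nibble & 0x7    # low three bits index into a table
        sign = (nibble >> 3) & 0x1  # sign bit selects the negative table
        return (neg if sign else pos)[magnitude]

    encoded = np.array([upcast_nibble(v & 0xF) for v in range(-8, 8)], np.uint8)
    decoded = encoded.view(ml_dtypes.float8_e4m3fn).astype(np.float32)
    assert (decoded == np.arange(-8, 8, dtype=np.float32)).all()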
PiperOrigin-RevId: 770640667 --- .../mosaic/gpu/fragmented_array.py | 106 +++++++++++++++++- tests/mosaic/gpu_test.py | 30 ++++- 2 files changed, 125 insertions(+), 11 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 816b6d6cbe74..e82f71be8ae4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -16,14 +16,15 @@ from __future__ import annotations +from collections.abc import Callable import dataclasses import functools +import itertools import math -from collections.abc import Callable -from typing import Iterable, Protocol, Sequence, TypeVar +from typing import Generator, Iterable, Protocol, Sequence, TypeVar -import itertools import jax +import jax.experimental.mosaic.gpu as mgpu from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu @@ -33,7 +34,6 @@ from jaxlib.mlir.dialects import vector import numpy as np -import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors @@ -1457,6 +1457,98 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): "Register bitwidth in target type must be divisible by 8, got" f" {new_reg_bitwidth}" ) + if cur_dtype == i4 and new_dtype == f8e4m3fn: + # The algorithm here is taken from CUTLASS's `NumericArrayConverter` + # specialization for int4 -> f8e4m3, available at + # https://github.com/NVIDIA/cutlass/blob/5c6bca04414e06ce74458ab0a2018e2b8272701c/include/cutlass/numeric_conversion.h#L4982. + # Each call to the function below will upcast 4 contiguous nibbles of + # the input 32-bit register, and whether to select the 4 low nibbles or + # the 4 high nibbles is determined by the `part` argument. + def upcast_to_f8e4m3fn(reg: ir.Value, part: int): + lut = [ + 0x44403800, # [0, 1, 2, 3] encoded as f8e4m3fn + 0x4E4C4A48, # [4, 5, 6, 7] encoded as f8e4m3fn + 0xCACCCED0, # [-8, -7, -6, -5] encoded as f8e4m3fn + 0xB8C0C4C8, # [-4, -3, -2, -1] encoded as f8e4m3fn + ] + + sign = arith.shrui(arith.andi(reg, c(0x88888888, i32)), c(1, i32)) + # Ignore the sign when indexing into the LUT. 
+ lut_idx = arith.andi(reg, c(0x77777777, i32)) + + assert 0 <= part < 2 + if part == 1: + lut_idx = arith.shrui(lut_idx, c(16, i32)) + sign = arith.shrui(sign, c(16, i32)) + + prmt_sign_pattern = arith.ori(sign, c(0x32103210, i32)) + return llvm.inline_asm( + i32, + [lut_idx, prmt_sign_pattern], + f""" + {{ + .reg .b32 pos_f8s, neg_f8s; + prmt.b32 pos_f8s, {lut[0]}, {lut[1]}, $1; + prmt.b32 neg_f8s, {lut[2]}, {lut[3]}, $1; + prmt.b32 $0, pos_f8s, neg_f8s, $2; + }} + """, + "=r,r,r", + ) + new_registers = np.empty_like(self.registers) + + def packed_registers() -> Generator[tuple[list[index], ir.Value]]: + """Tries to pack registers into groups of 16 bits if vector_len < 4.""" + generator = np.ndenumerate(self.registers) + indices = [] + regs = [] + while True: + try: + for _ in range(max(4 // vector_len, 1)): + idx, reg = next(generator) + indices.append(idx) + regs.append(reg) + yield indices, utils.vector_concat(regs) + regs.clear() + indices.clear() + except StopIteration: + break + if regs: + yield indices, utils.vector_concat(regs) + + for indices, reg in packed_registers(): + group_size = ir.VectorType(reg.type).shape[0] + assert group_size % vector_len == 0 + int_ty = ir.IntegerType.get_signless(group_size * 4) + reg_as_i32 = utils.bitcast(reg, int_ty) + if int_ty != i32: + reg_as_i32 = arith.extsi(i32, reg_as_i32) + out_i32_regs = [ + upcast_to_f8e4m3fn(reg_as_i32, part=part) + for part in range(max(group_size // 4, 1)) + ] + out_vec_int = utils.vector_concat([ + vector.splat(ir.VectorType.get((1,), i32), out_i32_reg) + for out_i32_reg in out_i32_regs + ]) + out_vector_len = len(out_i32_regs) * 4 + # Bitcast to i8 first to allow slicing as necessary, since LLVM chokes + # on f8 types. + out_vec = utils.bitcast( + out_vec_int, ir.VectorType.get((out_vector_len,), i8) + ) + offset = 0 + for idx in indices: + sliced_out_vec = utils.vector_slice( + out_vec, slice(offset, offset + vector_len) + ) + new_registers[idx] = utils.bitcast( + sliced_out_vec, ir.VectorType.get((vector_len,), f8e4m3fn) + ) + offset += vector_len + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=None + ) if cur_dtype == i4 and self.is_signed and new_dtype == bf16: new_registers = np.empty_like(self.registers) out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) @@ -2132,7 +2224,9 @@ def load_tiled( reg_ty = ir.VectorType.get((layout.vector_length,), dtype) # f8 data types are not handled by the LLVM dialect, so we need to # transfer them as i8 and bitcast them back to f8. - transfer_ty = ir.VectorType.get((layout.vector_length,), i8 if is_f8 else dtype) + transfer_ty = ir.VectorType.get( + (layout.vector_length,), i8 if is_f8 else dtype + ) loads = cls.transfer_tiled2(ref, swizzle, layout, shape, optimized) for _, update, ptr in loads: loaded_reg = llvm.load(transfer_ty, ptr) @@ -2529,7 +2623,7 @@ def plan_tiled_transfer( raise ValueError( "Failed to prove that vector transfers don't cross swizzle tile" " boundaries. This check is incomplete, and does not guarantee that" - " this is a user error, but it might be." + str(transfer_alignment) + f" this is a user error, but it might be. {transfer_alignment=}" ) # 2. The transfer pattern does not cause bank conflicts. 
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 272d2e7b5c2b..bbaf4689b696 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -568,22 +568,42 @@ def kernel(ctx, out, smem):
     )()
     np.testing.assert_array_equal(iota, expected)

-  @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32)
-  def test_sub_byte_conversion(self, jax_dtype_to):
+  @parameterized.product(
+      jax_dtype_to=(
+          jnp.int8, jnp.int16, jnp.int32, jnp.bfloat16, jnp.float8_e4m3fn,
+      ),
+      # Use different layouts to vary the size of the vector dimension.
+      layout=(
+          fa.WGMMA_LAYOUT,
+          fa.WGMMA_LAYOUT_UPCAST_2X,
+          fa.WGMMA_LAYOUT_UPCAST_4X,
+      ),
+  )
+  def test_sub_byte_conversion(self, jax_dtype_to, layout: fa.TiledLayout):
+    if jax_dtype_to == jnp.int32 and layout.vector_length == 8:
+      self.skipTest(
+          "Raises: failed to prove that vector transfers don't cross swizzle"
+          " tile boundaries.")
     jax_dtype_from = jnp.int4
+    if jnp.issubdtype(jax_dtype_to, jnp.integer):
+      is_signed = jnp.issubdtype(jax_dtype_to, jnp.signedinteger)
+    else:
+      is_signed = None
     def kernel(ctx, inp, out, smem):
       del ctx  # Unused.
       smem_inp, smem_out = smem
       copy(inp, smem_inp, swizzle=16)
-      t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16)
-      t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True)
+      t = mgpu.FragmentedArray.load_tiled(
+          smem_inp, is_signed=True, swizzle=16, layout=layout
+      )
+      t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=is_signed)
       t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
       copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
     x = self.prng.integers(
         low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32
     ).astype(jax_dtype_from)
-    y = x.astype(jax_dtype_to)
+    y = jax.lax.convert_element_type(x, jax_dtype_to)
     f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
     np.testing.assert_array_equal(f(x), y)

From ef106037062b318be09616500307a0a1b893531f Mon Sep 17 00:00:00 2001
From: DanisNone
Date: Thu, 12 Jun 2025 13:05:33 +0500
Subject: [PATCH 1653/1769] add missing dtypes to jax.numpy.__init__.pyi

---
 jax/numpy/__init__.pyi | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index e81d97765121..0ff96a4394ce 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -456,12 +456,16 @@ def fliplr(m: ArrayLike) -> Array: ...
 def flipud(m: ArrayLike) -> Array: ...
 float16: Any
 float32: Any
+float4_e2m1fn: Any
 float64: Any
+float8_e3m4: Any
+float8_e4m3: Any
 float8_e4m3b11fnuz: Any
 float8_e4m3fn: Any
 float8_e4m3fnuz: Any
 float8_e5m2: Any
 float8_e5m2fnuz: Any
+float8_e8m0fnu: Any
 float_: Any
 def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 floating = _np.floating
@@ -562,6 +566,7 @@ def inner(
 def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
            axis: int | None = ...) -> Array: ...
 int16: Any
+int2: Any
 int32: Any
 int4: Any
 int64: Any
@@ -944,6 +949,7 @@ def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def trunc(x: ArrayLike, /) -> Array: ...
 uint: Any
 uint16: Any
+uint2: Any
 uint32: Any
 uint4: Any
 uint64: Any

From 4e3bf29b49e93ee6bef4ecc2e069d0fff01f359b Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 12 Jun 2025 07:44:11 -0700
Subject: [PATCH 1654/1769] Temporarily disable AVX512 in linalg_test_cpu.
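The BUILD change below caps XLA:CPU code generation via XLA_FLAGS. Outside of
Bazel, the same cap can be applied with an environment variable set before jax
is imported (a sketch; only the flag itself comes from this patch):

    import os

    os.environ["XLA_FLAGS"] = (
        os.environ.get("XLA_FLAGS", "") + " --xla_cpu_max_isa=AVX2"
    )
    import jax  # noqa: E402  (the flag must be set before this import)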
PiperOrigin-RevId: 770650220
---
 tests/BUILD | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/BUILD b/tests/BUILD
index 15cf4330d28b..95fecf89c7dc 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -932,6 +932,10 @@ jax_multiplatform_test(
         "notsan",  # Times out.
     ],
     },
+    env = {
+        # TODO(b/424430576): something is going wrong with AVX512 code generation.
+        "XLA_FLAGS": "--xla_cpu_max_isa=AVX2",
+    },
     shard_count = {
         "cpu": 40,
         "gpu": 40,

From 294d86b21a555d2176abb74639945010a58f7c3f Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 12 Jun 2025 08:20:54 -0700
Subject: [PATCH 1655/1769] Add hermetic `nvshmem` dependencies to JAX targets.

The JAX wheel build rule implementation is also updated to exclude accidental
dependencies on NVSHMEM libraries in the wheel content. If the wheel needs to
be built with these dependencies, provide
`--@local_config_nvshmem//:override_include_nvshmem_libs=True` in Bazel
options.

NVSHMEM binaries are included in the dependencies if CUDA binary dependencies
are added as well, e.g. `--@local_config_cuda//:enable_cuda`.

NVSHMEM libraries are included in the dependencies if
`--@local_config_nvshmem//:include_nvshmem_libs=True` (the default flag value
is `False`).

Please note that this is a temporary solution, and it should be removed after
GLIBC is updated on RBE runners. At the moment `libnvshmem.so` files can't be
linked to the targets because they are built with a GLIBC version higher than
the one on RBE runners. In the future
`--@local_config_cuda//cuda:include_cuda_libs=True` should be used.

PiperOrigin-RevId: 770663391
---
 .bazelrc                            |  7 +++++--
 WORKSPACE                           | 27 +++++++++++++++++++++++++++
 jax/BUILD                           |  1 +
 jax/experimental/mosaic/gpu/core.py | 16 +++++++++++++++-
 jaxlib/jax.bzl                      | 12 ++++++++++--
 jaxlib/tools/BUILD.bazel            |  1 +
 6 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/.bazelrc b/.bazelrc
index 8906234c9061..0006b953cfe7 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -146,13 +146,15 @@ build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_8
 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
 build:cuda --@local_config_cuda//:enable_cuda

-# Default hermetic CUDA and CUDNN versions.
+# Default hermetic CUDA, CUDNN and NVSHMEM versions.
 build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0"
 build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
+build:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
 build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

-# This config is used for building targets with CUDA libraries from stubs.
+# This config is used for building targets with CUDA/NVSHMEM libraries from stubs.
 build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false
+build:cuda_libraries_from_stubs --@local_config_nvshmem//:include_nvshmem_libs=false

 # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
 # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
@@ -332,6 +334,7 @@ build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda
 build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1
 # Speed up CUDA repos creation by downloading ".tar" dists from the mirror.
build:rbe_linux_x86_64_cuda --repo_env=USE_CUDA_TAR_ARCHIVE_FILES=1 +build:rbe_linux_x86_64_cuda --repo_env=USE_NVSHMEM_TAR_ARCHIVE_FILES=1 # RBE configs for Windows # Set the remote worker pool diff --git a/WORKSPACE b/WORKSPACE index f389afe2263f..6a7df6d9c8bc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -145,3 +145,30 @@ load( ) nccl_configure(name = "local_config_nccl") + +load( + "@xla//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl", + "nvshmem_json_init_repository", +) + +nvshmem_json_init_repository() + +load( + "@nvshmem_redist_json//:distributions.bzl", + "NVSHMEM_REDISTRIBUTIONS", +) +load( + "@xla//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl", + "nvshmem_redist_init_repository", +) + +nvshmem_redist_init_repository( + nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS, +) + +load( + "@xla//third_party/nvshmem/hermetic:nvshmem_configure.bzl", + "nvshmem_configure", +) + +nvshmem_configure(name = "local_config_nvshmem") diff --git a/jax/BUILD b/jax/BUILD index 8562c83cbad3..a40065fed695 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1183,6 +1183,7 @@ py_library_providing_imports_info( "@cuda_nvcc//:nvdisasm", "@cuda_nvcc//:nvvm", "@cuda_nvcc//:ptxas", + "@nvidia_nvshmem//:libnvshmem_device", ], visibility = [ ":mosaic_gpu_users", diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 21914aa6c8a3..2d17d4f1857a 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -94,7 +94,21 @@ try: from nvidia import nvshmem except ImportError: - pass + # Try to find the nvshmem library in Bazel runfiles. + if PYTHON_RUNFILES: + libdevice_path = os.path.join( + PYTHON_RUNFILES, "nvidia_nvshmem", "lib", "libnvshmem_device.bc" + ) + if os.path.exists(libdevice_path): + os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] = libdevice_path + for root, _, files in os.walk(os.path.join(os.getcwd(), "_solib_local")): + if "libnvshmem_host.so.3" in files: + os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] = os.path.join( + root, "libnvshmem_host.so.3" + ) + break + else: + pass else: if os.environ.get("MOSAIC_GPU_NVSHMEM_BC_PATH") is None: os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] = os.path.join( diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 835045c25ef1..c6dd9b1bdb3f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -186,17 +186,16 @@ def _gpu_test_deps(): "//jaxlib/cuda:gpu_only_test_deps", "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", + # TODO(ybaturina): Remove this once we can add NVSHMEM libraries in the dependencies. 
"@pypi//nvidia_nvshmem_cu12", ], "//jax:config_build_jaxlib_false": [ "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", - "@pypi//nvidia_nvshmem_cu12", ], "//jax:config_build_jaxlib_wheel": [ "//jaxlib/tools:jax_cuda_plugin_py_import", "//jaxlib/tools:jax_cuda_pjrt_py_import", - "@pypi//nvidia_nvshmem_cu12", ], }) @@ -350,7 +349,9 @@ def _get_source_package_name(package_name, wheel_version): def _jax_wheel_impl(ctx): include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value + include_nvshmem_libs = ctx.attr.include_nvshmem_libs[BuildSettingInfo].value override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value + override_include_nvshmem_libs = ctx.attr.override_include_nvshmem_libs[BuildSettingInfo].value output_path = ctx.attr.output_path[BuildSettingInfo].value git_hash = ctx.attr.git_hash[BuildSettingInfo].value py_freethreaded = ctx.attr.py_freethreaded[BuildSettingInfo].value @@ -361,6 +362,11 @@ def _jax_wheel_impl(ctx): " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." + " If you absolutely need to build links directly against the CUDA libraries, provide" + " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") + if include_nvshmem_libs and not override_include_nvshmem_libs: + fail("JAX wheel shouldn't be built directly against the NVSHMEM libraries." + + " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." + + " If you absolutely need to build links directly against the NVSHMEM libraries," + + " `provide --@local_config_nvshmem//:override_include_nvshmem_libs=true`.") env = {} args = ctx.actions.args() @@ -476,6 +482,8 @@ _jax_wheel = rule( "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), + "include_nvshmem_libs": attr.label(default = Label("@local_config_nvshmem//:include_nvshmem_libs")), + "override_include_nvshmem_libs": attr.label(default = Label("@local_config_nvshmem//:override_include_nvshmem_libs")), "py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")), }, implementation = _jax_wheel_impl, diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 7f8a85a5d9ab..03bd75144186 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -469,6 +469,7 @@ filegroup( "@pypi_nvidia_cusparse_cu12//:whl", "@pypi_nvidia_nccl_cu12//:whl", "@pypi_nvidia_nvjitlink_cu12//:whl", + "@pypi_nvidia_nvshmem_cu12//:whl", ], ) From f81f2589cb3c0ae76dfe2d982f70c1769c3448a8 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 12 Jun 2025 09:47:03 -0700 Subject: [PATCH 1656/1769] [Pallas TPU] Support memory space constraints on pallas_call inputs. This CL adds: - a new function `with_memory_space_constraint` that allow to add a memory space constraint on a Pallas call input. - `HBM` enum value for Pallas TPU memory spaces This CL only supports HBM and VMEM at the moment. In general these annotations only work on TPU because the mechanism that enforces them is only present in XLA TPU. 
PiperOrigin-RevId: 770694413
---
 jax/_src/pallas/core.py                      |   8 +-
 jax/_src/pallas/mosaic/core.py               |   1 +
 jax/_src/pallas/mosaic/primitives.py         |  59 ++++++++++
 jax/experimental/jax2tf/jax2tf.py            |   1 +
 jax/experimental/pallas/__init__.py          |   6 +-
 jax/experimental/pallas/tpu.py               |   3 +
 tests/pallas/BUILD                           |  15 +++
 tests/pallas/tpu_pallas_memory_space_test.py | 108 +++++++++++++++++++
 8 files changed, 197 insertions(+), 4 deletions(-)
 create mode 100644 tests/pallas/tpu_pallas_memory_space_test.py

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 22e6201ac961..bf92a6cc6a5a 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -519,14 +519,20 @@ def to_block_mapping(
       )

     ref_block_shape = _get_ref_block_shape(block_shape)
-    block_array_aval = array_aval.update(shape=ref_block_shape)
     if isinstance(array_aval, jax_core.DShapedArray):
       # Get the "max" shape for the ragged array.
+      block_array_aval = array_aval.update(shape=ref_block_shape)
       block_array_aval = jax_core.ShapedArray(
           block_array_aval.shape,
           block_array_aval.dtype,
           block_array_aval.weak_type,
       )
+    elif isinstance(array_aval, ShapedArrayWithMemorySpace):
+      block_array_aval = jax_core.ShapedArray(
+          ref_block_shape, array_aval.dtype, array_aval.weak_type
+      )
+    else:
+      block_array_aval = array_aval.update(shape=ref_block_shape)
     block_aval = AbstractMemoryRef(block_array_aval, self.memory_space)

     if (
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index b864a1df2ec5..2a6bdf9fa8a4 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -108,6 +108,7 @@ class MemorySpace(enum.Enum):
   SMEM = "smem"
   CMEM = "cmem"
   SEMAPHORE = "semaphore_mem"
+  HBM = "hbm"

   def __str__(self) -> str:
     return self.value
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index af50773eec20..30debf8643a2 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -755,3 +755,62 @@ def wrap_pallas_seed(*seeds, impl):
   """Joins scalar into a single PRNG key."""
   impl = jax_random.resolve_prng_impl(impl)
   return join_key_p.bind(*seeds, impl=impl)
+
+
+with_memory_space_constraint_p = jax_core.Primitive(
+    'with_memory_space_constraint')
+
+@with_memory_space_constraint_p.def_impl
+def with_memory_space_constraint_impl(x, *, memory_space):
+  del x, memory_space
+  raise ValueError("Cannot eagerly run with_memory_space_constraint.")
+
+
+@with_memory_space_constraint_p.def_abstract_eval
+def with_memory_space_constraint_abstract_eval(x, *, memory_space):
+  if not isinstance(x, jax_core.ShapedArray):
+    raise NotImplementedError("with_memory_space_constraint only supports "
+                              "arrays.")
+  return pl_core.ShapedArrayWithMemorySpace(
+      x.shape, x.dtype, memory_space=memory_space
+  )
+
+def with_memory_space_constraint_lowering_rule(ctx, x, *, memory_space):
+  del ctx, memory_space
+  return [x]
+mlir.register_lowering(
+    with_memory_space_constraint_p, with_memory_space_constraint_lowering_rule
+)
+
+def with_memory_space_constraint(
+    x: jax.Array, memory_space: Any
+) -> jax.Array:
+  """Constrains the memory space of an array.
+
+  This primitive does not change the value of `x`, but it constrains the
+  memory space where it should be allocated. This is useful to force
+  Pallas to allocate an array in a specific memory space.
+
+  As of now, this only operates on the inputs of pallas_calls: you can apply
+  this to the arguments of a pallas_call and it will constrain them, but
+  other operations will not respect this constraint.
+ + Args: + x: The array to constrain. + memory_space: The memory space to constrain to. + + Returns: + The array `x` with the memory space constraint. + """ + if memory_space not in {tpu_core.HBM, tpu_core.VMEM}: + raise NotImplementedError( + "with_memory_space_constraint only supports HBM and VMEM." + ) + return with_memory_space_constraint_p.bind(x, memory_space=memory_space) + +def get_memory_space(x: jax.Array) -> Any: + """Queries the memory space of an array.""" + aval = jax_core.get_aval(x) + if isinstance(aval, pl_core.ShapedArrayWithMemorySpace): + return aval.memory_space + return None diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 3c34a26af982..950adcb75e55 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1556,6 +1556,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "bitcast", "repeat", "roll", + "with_memory_space_constraint", # temporary pending cudnn fix, see https://github.com/jax-ml/jax/pull/23740 "bias_fwd", "bias_bwd", diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index da2bc9119dd0..5c0ef332454c 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -36,13 +36,13 @@ from jax._src.pallas.core import Squeezed as Squeezed from jax._src.pallas.core import squeezed as squeezed from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost +from jax._src.pallas.helpers import debug_check as debug_check +from jax._src.pallas.helpers import debug_checks_enabled as debug_checks_enabled from jax._src.pallas.helpers import empty as empty from jax._src.pallas.helpers import empty_like as empty_like +from jax._src.pallas.helpers import enable_debug_checks as enable_debug_checks from jax._src.pallas.helpers import loop as loop from jax._src.pallas.helpers import when as when -from jax._src.pallas.helpers import debug_check as debug_check -from jax._src.pallas.helpers import debug_checks_enabled as debug_checks_enabled -from jax._src.pallas.helpers import enable_debug_checks as enable_debug_checks from jax._src.pallas.pallas_call import pallas_call as pallas_call from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p from jax._src.pallas.primitives import atomic_add as atomic_add diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index eceb2e4f0383..c96bc8291c4d 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -41,12 +41,14 @@ from jax._src.pallas.mosaic.primitives import bitcast as bitcast from jax._src.pallas.mosaic.primitives import delay as delay from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore +from jax._src.pallas.mosaic.primitives import get_memory_space as get_memory_space from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy from jax._src.pallas.mosaic.primitives import prng_random_bits as prng_random_bits from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll +from jax._src.pallas.mosaic.primitives import with_memory_space_constraint as with_memory_space_constraint from jax._src.pallas.mosaic.random import sample_block as sample_block from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key @@ 
-72,6 +74,7 @@ CMEM = MemorySpace.CMEM SMEM = MemorySpace.SMEM VMEM = MemorySpace.VMEM +HBM = MemorySpace.HBM SEMAPHORE = MemorySpace.SEMAPHORE import typing as _typing # pylint: disable=g-import-not-at-top diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index eb0eb05ac6a7..6e80a96d21fa 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -512,6 +512,21 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "tpu_pallas_memory_space_test", + srcs = ["tpu_pallas_memory_space_test.py"], + enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5p", + ], + deps = [ + "//jax:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + jax_multiplatform_test( name = "tpu_pallas_state_test", srcs = ["tpu_pallas_state_test.py"], diff --git a/tests/pallas/tpu_pallas_memory_space_test.py b/tests/pallas/tpu_pallas_memory_space_test.py new file mode 100644 index 000000000000..f7536833ce57 --- /dev/null +++ b/tests/pallas/tpu_pallas_memory_space_test.py @@ -0,0 +1,108 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TPU-specific uses of Pallas memory space APIs.""" + +import functools +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() +P = jax.sharding.PartitionSpec +partial = functools.partial + + +class TPUPallasMemorySpaceTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.if_cloud_tpu_at_least(2025, 6, 10): + self.skipTest('Needs a newer libTPU') + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Needs a newer TPU') + + @parameterized.parameters( + (pltpu.VMEM, 1), + (pltpu.HBM, 0), + (pltpu.ANY, None), + ) + def test_basic_input_memory_space_constraint(self, memory_space, color): + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + + def g(x): + return pl.pallas_call(kernel, out_shape=x)(x) + + @jax.jit + def f(x): + x = pltpu.with_memory_space_constraint(x, memory_space=memory_space) + self.assertEqual(pltpu.get_memory_space(x), memory_space) + x = g(x) + return x + + x = jnp.ones((8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + hlo = jax.jit(f).lower(x).compile().as_text() + if color is None: + self.assertIn('"input_memory_space_colors":[]', hlo) + else: + self.assertIn( + f'"input_memory_space_colors":[{{"operand_index":"0","color":"{color}","shape_index":[]}}]', + hlo, + ) + + @parameterized.parameters( + (pltpu.VMEM, 1), + (pltpu.HBM, 0), + (pltpu.ANY, None), + ) + def test_basic_output_memory_space_constraint(self, memory_space, color): + if color is None: + memory_space = jax.ShapeDtypeStruct + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] 
+ + def g(x): + return pl.pallas_call(kernel, out_shape=memory_space(x.shape, x.dtype))(x) + + @jax.jit + def f(x): + x = g(x) + return x + + x = jnp.ones((8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + hlo = jax.jit(f).lower(x).compile().as_text() + if color is None: + self.assertIn('"output_memory_space_colors":[]', hlo) + else: + self.assertIn( + f'"output_memory_space_colors":[{{"color":"{color}","shape_index":[]}}]', + hlo, + ) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) From 1228053f11926a112dd9299601a53909f5c42295 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 12 Jun 2025 10:18:49 -0700 Subject: [PATCH 1657/1769] [JAX] Update the example to use jax.numpy rather than numpy. PiperOrigin-RevId: 770707628 --- docs/rank_promotion_warning.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst index 5e4e7ec65cbc..6ec0000e2ffc 100644 --- a/docs/rank_promotion_warning.rst +++ b/docs/rank_promotion_warning.rst @@ -9,14 +9,14 @@ surprising bugs where a silent rank promotion masks an underlying shape error. Here's an example of rank promotion: ->>> import numpy as np ->>> x = np.arange(12).reshape(4, 3) ->>> y = np.array([0, 1, 0]) +>>> from jax import numpy as jnp +>>> x = jnp.arange(12).reshape(4, 3) +>>> y = jnp.array([0, 1, 0]) >>> x + y -array([[ 0, 2, 2], +Array([[ 0, 2, 2], [ 3, 5, 5], [ 6, 8, 8], - [ 9, 11, 11]]) + [ 9, 11, 11]], dtype=int32) To avoid potential surprises, :code:`jax.numpy` is configurable so that expressions requiring rank promotion can lead to a warning, error, or can be From 07193467a0f0d601f863a9793ef0227a73a4e43d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Jun 2025 10:29:38 -0700 Subject: [PATCH 1658/1769] Reland the C++ safe_zip implementation. It turns out to benchmark slightly better still. PiperOrigin-RevId: 770712636 --- jax/_src/util.py | 55 ++++++++++++++------------ jaxlib/utils.cc | 94 ++++++++++++++++++++++++++++++++++++++++++++ jaxlib/xla_client.py | 2 +- tests/util_test.py | 18 ++++----- 4 files changed, 134 insertions(+), 35 deletions(-) diff --git a/jax/_src/util.py b/jax/_src/util.py index 71d8f8bfa6a1..595f73dbc466 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -28,6 +28,7 @@ import numpy as np from jax._src import config +from jax._src.lib import jaxlib_extension_version from jax._src.lib import weakref_lru_cache as _weakref_lru_cache from jax._src.lib import utils as jaxlib_utils @@ -43,32 +44,36 @@ T2 = TypeVar("T2") T3 = TypeVar("T3") -# safe_zip cannot yet be fully annotated, so we use a strategy similar -# to that used for builtins.zip in python/typeshed. This supports -# return types matching input types for up to three arguments. -@overload -def safe_zip(__arg1: Iterable[T1], /) -> list[tuple[T1]]: ... -@overload -def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], /) -> list[tuple[T1, T2]]: ... -@overload -def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> list[tuple[T1, T2, T3]]: ... -@overload -def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> list[tuple[Any, ...]]: ... - -def safe_zip(*args): - """ - Like builtin :func:`zip`, but with additional safety checks. - - The differences from :func:`zip` are: - - :func:`safe_zip` checks that at least one argument is provided. - - :func:`safe_zip` checks that all arguments have the same length. 
- - :func:`safe_zip` returns an eagerly-evaluated list instead of a - lazily-evaluated iterator. - """ - if not args: - raise TypeError("safe_zip requires at least 1 argument.") - return list(zip(*args, strict=True)) +if TYPE_CHECKING or jaxlib_extension_version < 354: + # safe_zip cannot yet be fully annotated, so we use a strategy similar + # to that used for builtins.zip in python/typeshed. This supports + # return types matching input types for up to three arguments. + @overload + def safe_zip(__arg1: Iterable[T1]) -> list[tuple[T1]]: ... + @overload + def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[tuple[T1, T2]]: ... + @overload + def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[tuple[T1, T2, T3]]: ... + @overload + def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[tuple[Any, ...]]: ... + + def safe_zip(*args): + """ + Like builtin :func:`zip`, but with additional safety checks. + + The differences from :func:`zip` are: + + - :func:`safe_zip` checks that at least one argument is provided. + - :func:`safe_zip` checks that all arguments have the same length. + - :func:`safe_zip` returns an eagerly-evaluated list instead of a + lazily-evaluated iterator. + """ + if not args: + raise TypeError("safe_zip requires at least 1 argument.") + return list(zip(*args, strict=True)) +else: + safe_zip = jaxlib_utils.safe_zip if TYPE_CHECKING: diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index 1cf6798010ed..e5bb45e999da 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -204,6 +204,98 @@ PyMethodDef foreach_def = { "ignoring the return values and returns None. The iterables must all have " "the same lengths."}; +// A variant of zip(...) that: +// a) returns a list instead of an iterator, and +// b) checks that the input iterables are of equal length. +// TODO(phawkins): consider replacing this function with +// list(zip(..., strict=True)) once TensorFlow 2.13 is released, which should +// resolve an incompatibility with strict=True and jax2tf. +PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { + if (nargs < 1) { + PyErr_SetString(PyExc_TypeError, "safe_zip requires at least 1 argument"); + return nullptr; + } + absl::InlinedVector iterators; + iterators.reserve(nargs); + for (Py_ssize_t i = 0; i < nargs; ++i) { + PyObject* it = PyObject_GetIter(args[i]); + if (!it) return nullptr; + iterators.push_back(nb::steal(it)); + } + + // Try to use a length hint to estimate how large a list to allocate. + Py_ssize_t length_hint = PyObject_LengthHint(args[0], 2); + if (PyErr_Occurred()) { + PyErr_Clear(); + } + if (length_hint < 0) { + length_hint = 2; + } + + nb::list list = nb::steal(PyList_New(length_hint)); + int n = 0; // Current true size of the list + + while (true) { + nb::object tuple; + nb::object v = nb::steal(PyIter_Next(iterators[0].ptr())); + if (PyErr_Occurred()) return nullptr; + + if (v.ptr()) { + tuple = nb::steal(PyTuple_New(nargs)); + if (!tuple.ptr()) return nullptr; + + PyTuple_SET_ITEM(tuple.ptr(), 0, v.release().ptr()); + for (size_t i = 1; i < iterators.size(); ++i) { + v = nb::steal(PyIter_Next(iterators[i].ptr())); + if (PyErr_Occurred()) return nullptr; + if (!v.ptr()) { + PyErr_Format(PyExc_ValueError, + "safe_zip() argument %u is shorter than argument 1", + i + 1); + return nullptr; + } + PyTuple_SET_ITEM(tuple.ptr(), i, v.release().ptr()); + } + } else { + // No more elements should be left. 
Checks the other iterators are + // exhausted. + for (size_t i = 1; i < iterators.size(); ++i) { + v = nb::steal(PyIter_Next(iterators[i].ptr())); + if (PyErr_Occurred()) return nullptr; + if (v.ptr()) { + PyErr_Format(PyExc_ValueError, + "safe_zip() argument %u is longer than argument 1", + i + 1); + return nullptr; + } + } + + // If the length hint was too large, truncate the list to the true size. + if (n < length_hint) { + if (PyList_SetSlice(list.ptr(), n, length_hint, nullptr) < 0) { + return nullptr; + } + } + return list.release().ptr(); + } + + if (n < length_hint) { + PyList_SET_ITEM(list.ptr(), n, tuple.release().ptr()); + } else { + if (PyList_Append(list.ptr(), tuple.ptr()) < 0) { + return nullptr; + } + tuple = nb::object(); + } + ++n; + } +} + +PyMethodDef safe_zip_def = { + "safe_zip", + reinterpret_cast(SafeZip), + METH_FASTCALL, +}; nb::list TopologicalSort(nb::str parents_attr, nb::iterable end_nodes_iterable) { @@ -276,6 +368,8 @@ NB_MODULE(utils, m) { PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr())); m.attr("foreach") = nb::steal( PyCFunction_NewEx(&foreach_def, /*self=*/nullptr, module_name.ptr())); + m.attr("safe_zip") = nb::steal( + PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr())); m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"), nb::arg("end_nodes"), diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 4b8c1133e090..f8f7dc373aa6 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 353 +_version = 354 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/util_test.py b/tests/util_test.py index 923240b69242..544858fa089a 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -232,28 +232,28 @@ def test_safe_zip(self): ) def test_safe_zip_errors(self): - with self.assertRaisesWithLiteralMatch( - TypeError, "safe_zip requires at least 1 argument." + with self.assertRaisesRegex( + TypeError, "safe_zip requires at least 1 argument" ): util.safe_zip() - with self.assertRaisesWithLiteralMatch( + with self.assertRaisesRegex( TypeError, "'function' object is not iterable" ): util.safe_zip(lambda x: x) - with self.assertRaisesWithLiteralMatch( - ValueError, "zip() argument 2 is longer than argument 1" + with self.assertRaisesRegex( + ValueError, r"zip\(\) argument 2 is longer than argument 1" ): util.safe_zip(range(3), range(4)) - with self.assertRaisesWithLiteralMatch( - ValueError, "zip() argument 2 is shorter than argument 1" + with self.assertRaisesRegex( + ValueError, r"zip\(\) argument 2 is shorter than argument 1" ): util.safe_zip(range(7), range(2)) - with self.assertRaisesWithLiteralMatch( - ValueError, "zip() argument 2 is longer than argument 1" + with self.assertRaisesRegex( + ValueError, r"zip\(\) argument 2 is longer than argument 1" ): util.safe_zip((), range(3)) From 6e2977d8b9abcc72dfd1b179dd2437c5c009d6b5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 12 Jun 2025 12:01:04 -0700 Subject: [PATCH 1659/1769] Add `all_gather_invariant` to lax. `all_gather_invariant` differs from `all_gather` in the following ways: * `all_gather_invariant` is `Varying -> Invariant`. For example: `out: f32[8] = all_gather_invariant(inp: f32[4]{V: x}, 'x')` where the size of mesh axis `x` is `2`. While `all_gather` is `Varying -> Varying`. 
* `all_gather_invariant` transposes to `dynamic_slice` which is `Invariant -> Varying`. While `all_gather` transposes to `reduce_scatter` which is `Varying -> Varying`. PiperOrigin-RevId: 770746712 --- jax/_src/lax/parallel.py | 124 ++++++++++++++++++++++++++---- jax/experimental/jax2tf/jax2tf.py | 1 + jax/lax/__init__.py | 1 + tests/shard_map_test.py | 74 +++++++++++++++--- 4 files changed, 175 insertions(+), 25 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index a5bb7222143d..a38ad81733be 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1630,37 +1630,41 @@ def _all_gather_effectful_abstract_eval( return (x_aval.update(shape=new_shape, vma=out_vma), {*map(core.NamedAxisEffect, axis_name)}) -def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): +def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, + axis_index_groups, axis_size, tiled): return (psum_scatter(cts, axis_name=axis_name, scatter_dimension=all_gather_dimension, axis_index_groups=axis_index_groups, tiled=tiled),) - # TODO(sharadmv,apaszke): re-enable this when we can properly detect replication. - # return (lax.dynamic_index_in_dim(cts, idx, axis=all_gather_dimension, keepdims=False) * axis_size,) -def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): +def _all_gather_batcher(prim, vals_in, dims_in, *, all_gather_dimension, axis_name, + axis_index_groups, axis_size, tiled): (x,), (d,) = vals_in, dims_in if d is not batching.not_mapped: if d <= all_gather_dimension: all_gather_dimension += 1 elif not tiled: # Tiled all-gather doesn't modify the set of dimensions d += 1 - result = all_gather_p.bind( - x, - all_gather_dimension=all_gather_dimension, - axis_name=axis_name, - axis_index_groups=axis_index_groups, - axis_size=axis_size, - tiled=tiled) - return result, d + if prim is all_gather_p: + result = all_gather_p.bind( + x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_index_groups=axis_index_groups, axis_size=axis_size, + tiled=tiled) + return result, d + else: + assert prim is all_gather_invariant_p + result = all_gather_invariant_p.bind( + x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_size=axis_size, tiled=tiled) + return result, d -def _all_gather_batched_collective(axis_data, vals_in, dims_in, +def _all_gather_batched_collective(prim, axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): frame_size, frame_name = axis_data.size, axis_data.name if frame_name not in axis_name: return _all_gather_batcher( - vals_in, dims_in, all_gather_dimension=all_gather_dimension, + prim, vals_in, dims_in, all_gather_dimension=all_gather_dimension, axis_name=axis_name, axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled) if axis_index_groups is not None: @@ -1692,10 +1696,100 @@ def _all_gather_batched_collective(axis_data, vals_in, dims_in, partial(_all_gather_lowering, platform=p), platform=p) ad.deflinear2(all_gather_p, _all_gather_transpose_rule) -batching.fancy_primitive_batchers[all_gather_p] = _all_gather_batched_collective +batching.fancy_primitive_batchers[all_gather_p] = partial( + _all_gather_batched_collective, all_gather_p) batching.skippable_batchers[all_gather_p] = partial(_names_in_param, 'axis_name') +def all_gather_invariant(x, axis_name, *, axis: int = 0, tiled: bool = False): + """Gather values of x across all 
replicas. + + If ``x`` is a pytree then the result is equivalent to mapping this function to + each leaf in the tree. + + all_gather_invariant differs from all_gather in the following ways: + + * all_gather_invariant is Varying -> Invariant. + For example: `out: f32[8] = all_gather_invariant(inp: f32[4]{V: x}, 'x')` + where the size of mesh axis `x` is 2. + While all_gather is Varying -> Varying. + + * all_gather_invariant transposes to dynamic_slice which is + Invariant -> Varying. While all_gather transposes to reduce_scatter + which is Varying -> Varying. + """ + if not isinstance(axis_name, tuple): + axis_name = axis_name, + axis_size = _axis_size(axis_name, None) + axes_ = frozenset(axis_name) + def bind(leaf): + in_vma = core.typeof(leaf).vma + if vary_names := axes_ - in_vma: + leaf = pvary(leaf, tuple(vary_names)) + return all_gather_invariant_p.bind( + leaf, + all_gather_dimension=canonicalize_axis(axis, np.ndim(leaf) if tiled else + np.ndim(leaf) + 1), + axis_name=axis_name, axis_size=axis_size, tiled=tiled) + return tree_util.tree_map(bind, x) + +all_gather_invariant_p = core.Primitive('all_gather_invariant') + +def _all_gather_invariant_effectful_abstract_eval( + x_aval, *, all_gather_dimension, axis_name, axis_size, tiled +): + _check_axis_names(axis_name) + new_shape = list(x_aval.shape) + if tiled: + new_shape[all_gather_dimension] *= axis_size + else: + new_shape.insert(all_gather_dimension, axis_size) + out_vma = frozenset(v for v in x_aval.vma if v not in axis_name) + return (x_aval.update(shape=new_shape, vma=out_vma), + {*map(core.NamedAxisEffect, axis_name)}) + +all_gather_invariant_p.def_effectful_abstract_eval( + _all_gather_invariant_effectful_abstract_eval) + +def _all_gather_invariant_impl(x, *, all_gather_dimension, axis_name, axis_size, + tiled): + raise NotImplementedError +all_gather_invariant_p.def_impl(_all_gather_invariant_impl) + + +def _all_gather_invariant_lowering( + ctx, x, *, all_gather_dimension, axis_name, axis_size, tiled, platform=None): + return _all_gather_lowering( + ctx, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_index_groups=None, axis_size=axis_size, tiled=tiled, + platform=platform) + +mlir.register_lowering(all_gather_invariant_p, _all_gather_invariant_lowering) +for p in ("cuda", "rocm", "tpu"): + mlir.register_lowering(all_gather_invariant_p, + partial(_all_gather_invariant_lowering, platform=p), + platform=p) + +def _all_gather_invariant_transpose_rule( + cts, x, *, all_gather_dimension, axis_name, axis_size, tiled): + slice_size, rem = divmod(cts.shape[all_gather_dimension], axis_size) + assert not rem + idx = axis_index(axis_name) * slice_size + out = slicing.dynamic_slice_in_dim( + cts, idx, slice_size=slice_size, axis=all_gather_dimension) + return (out,) if tiled else (lax.squeeze(out, [all_gather_dimension]),) +ad.deflinear2(all_gather_invariant_p, _all_gather_invariant_transpose_rule) + +def _all_gather_invariant_batched_collective( + axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_size, + tiled): + return _all_gather_batched_collective( + all_gather_invariant_p, axis_data, vals_in, dims_in, all_gather_dimension, + axis_name, None, axis_size, tiled) +batching.fancy_primitive_batchers[all_gather_invariant_p] = _all_gather_invariant_batched_collective +batching.skippable_batchers[all_gather_invariant_p] = partial(_names_in_param, 'axis_name') + + def _reduce_scatter_lowering( prim, ctx, x, *, scatter_dimension, axis_name, diff --git a/jax/experimental/jax2tf/jax2tf.py 
b/jax/experimental/jax2tf/jax2tf.py index 950adcb75e55..6a118e4a9b80 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1534,6 +1534,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "reduce_scatter", "axis_index", "all_gather", + "all_gather_invariant", "lu_pivots_to_permutation", "xla_pmap", "geqrf", diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index c6df458ba91d..57019c7ed3fb 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -356,6 +356,7 @@ ) from jax._src.lax.parallel import ( all_gather as all_gather, + all_gather_invariant as all_gather_invariant, all_gather_p as all_gather_p, all_to_all as all_to_all, all_to_all_p as all_to_all_p, diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f3f5641be1b6..a6ebd1c96c04 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -58,14 +58,14 @@ zip, unsafe_zip = safe_zip, zip # Helper for some tests. -def create_inputs(a_sharding, b_sharding): +def create_inputs(a_sharding, b_sharding, dtype=None): mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) b, e, f = 8, 8, 8 # pylint: disable=invalid-name m1 = jax.device_put( - jnp.arange(b * e).reshape((b, e)), + jnp.arange(b * e, dtype=dtype).reshape((b, e)), jax.sharding.NamedSharding(mesh, a_sharding)) m2 = jax.device_put( - jnp.arange(e * f).reshape((e, f)), + jnp.arange(e * f, dtype=dtype).reshape((e, f)), jax.sharding.NamedSharding(mesh, b_sharding)) return mesh, m1, m2 @@ -95,17 +95,13 @@ def test_all_gather(self): mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None)) assert a.addressable_data(0).shape == (4, 2) - # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use - # all_gather_invariant primitive, which differs in its output replication - # type compared to all_gather. @jax.jit @partial(shard_map, mesh=mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y'))) def fwd(a): - return ( - lax.all_gather(a, 'z', axis=0, tiled=True), - lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True), - ) + return (lax.all_gather(a, 'z', axis=0, tiled=True), + lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True)) + c, d = fwd(a) self.assertEqual(c.addressable_data(0).shape, (8, 2)) for i, a_shard in enumerate(np.split(a, 4, axis=1)): @@ -114,6 +110,64 @@ def fwd(a): for i, a_shard in enumerate(np.split(a, 2, axis=0)): self.assertAllClose(d.addressable_data(i), a_shard) + def test_all_gather_invariant_basic(self): + mesh = jtu.create_mesh((4,), 'x') + arr = jnp.arange(8.) 
+ + @jax.jit + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=P()) + def f(a): + out = lax.all_gather_invariant(a, 'x', tiled=True) + self.assertEqual(out.aval.vma, set()) + return out + + out = f(arr) + self.assertArraysEqual(out, arr) + + jtu.check_grads(f, (arr,), order=2) + + def g(x): + return f(x).sum() + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.shape, (8,)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_all_gather_invariant_complex(self): + mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None), + dtype=np.float32) + assert a.addressable_data(0).shape == (4, 2) + + @jax.jit + @shard_map(mesh=mesh, in_specs=(P('z', ('x', 'y')),), + out_specs=(P(None, ('x', 'y')), P('z'))) + def f(a): + c = lax.all_gather_invariant(a, 'z', axis=0, tiled=True) + self.assertEqual(jax.typeof(c).vma, {'x', 'y'}) + d = lax.all_gather_invariant(a, ('x', 'y'), axis=-1, tiled=True) + self.assertEqual(jax.typeof(d).vma, {'z'}) + return c, d + + c, d = f(a) + + self.assertEqual(c.addressable_data(0).shape, (8, 2)) + for i, a_shard in enumerate(np.split(a, 4, axis=1)): + self.assertAllClose(c.addressable_data(2 * i), a_shard) + + self.assertEqual(d.addressable_data(0).shape, (4, 8)) + for i, a_shard in enumerate(np.split(a, 2, axis=0)): + self.assertAllClose(d.addressable_data(i), a_shard) + + def g(x): + return f(x)[0].sum() + + out1 = jax.jit(jax.grad(g))(a) + self.assertEqual(out1.shape, (8, 8)) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('z', ('x', 'y')))) + + out2 = jax.grad(g)(a) + self.assertEqual(out2.shape, (8, 8)) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('z', ('x', 'y')))) + def test_all_gather_with_axis_index_groups(self): mesh, a, _ = create_inputs(P('x', ('y', 'z')), P(None, None)) From a9fdb768ed38e42f9a290bb0d8b1efc7e58c53f6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 12 Jun 2025 14:07:30 -0700 Subject: [PATCH 1660/1769] [Pallas TPU] Small fix to memory space constraints on pallas_call inputs. PiperOrigin-RevId: 770788317 --- jax/_src/pallas/mosaic/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 30debf8643a2..7fb1df19d747 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -802,7 +802,7 @@ def with_memory_space_constraint( Returns: The array `x` with the memory space constraint. """ - if memory_space not in {tpu_core.HBM, tpu_core.VMEM}: + if memory_space not in {tpu_core.MemorySpace.HBM, tpu_core.MemorySpace.VMEM}: raise NotImplementedError( "with_memory_space_constraint only supports HBM and VMEM." ) From 094b66fa4e7fd0e684f6aff539feae9e38946c33 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 12 Jun 2025 14:18:24 -0700 Subject: [PATCH 1661/1769] [Mosaic GPU] Enable transpose tests in mosaic_gpu. Not sure why these were disabled, but they work now. 
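A hedged back-of-the-envelope for the 128-byte constraint above, mirroring
the `swizzle_elems` computation in the diff below:

    # bf16 is 2 bytes per element, so a 128-byte swizzle spans
    swizzle_elems = 128 // np.dtype(jnp.bfloat16).itemsize  # == 64
    # giving (tiling_m, tiling_n, tiling_k) == (64, 64, 64) in that case.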
PiperOrigin-RevId: 770792116
---
 tests/mosaic/gpu_test.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index bbaf4689b696..c8d2a5e3a291 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -3452,8 +3452,13 @@ def test_wgmma_kernel_with_tma(
     if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
       self.skipTest("No swizzle is not supported by wgmma")

-    if transpose_lhs or transpose_rhs:
-      self.skipTest("Transposes are not supported by transform inference yet.")
+    # TODO(dasenov): This condition is wrong, remove it after we support
+    # reconciling the swizzle of the A and B operands of wgmma.
+    if transpose_lhs and swizzle != mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
+      self.skipTest("If A is transposed, its swizzle must be 128 bytes.")
+
+    if transpose_lhs and load_a_in_registers:
+      self.skipTest("The A operand can only be transposed if it is in SMEM.")

     swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize
     tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems

From efbedc61e136f14a9945a86293bc2ac2236d20de Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Thu, 12 Jun 2025 14:20:16 -0700
Subject: [PATCH 1662/1769] [doc] fix some inaccuracies in jnp.bincount docs

---
 jax/_src/numpy/lax_numpy.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index d823688ee674..abbdf6d0411d 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -2903,7 +2903,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None,

   JAX implementation of :func:`numpy.bincount`.

-  For an array of positive integers ``x``, this function returns an array ``counts``
+  For an array of non-negative integers ``x``, this function returns an array ``counts``
   of size ``x.max() + 1``, such that ``counts[i]`` contains the number of occurrences
   of the value ``i`` in ``x``.

@@ -2916,7 +2916,7 @@
   like :func:`jax.jit`. In this case, items larger than `length + 1` will be
   dropped.

   Args:
-    x : N-dimensional array of positive integers
+    x : 1-dimensional array of non-negative integers
     weights: optional array of weights associated with ``x``. If not specified, the
       weight for each entry will be ``1``.
     minlength: the minimum length of the output counts array.
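A small hedged illustration of the behavior documented above (expected
outputs shown as comments):

    import jax.numpy as jnp

    x = jnp.array([0, 1, 1, 3])
    jnp.bincount(x)             # [1, 2, 0, 1]  (size is x.max() + 1)
    jnp.bincount(x, length=3)   # [1, 2, 0]     (static size under jit;
                                #  counts for values >= 3 are dropped)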
From abb756d41af09a034a4f9e748ddab4118869467b Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Thu, 12 Jun 2025 14:40:01 -0700 Subject: [PATCH 1663/1769] Add colorama back into test-requirements It's still being used on Windows PiperOrigin-RevId: 770800299 --- build/requirements.in | 3 +++ build/requirements_lock_3_10.txt | 4 ++++ build/requirements_lock_3_11.txt | 4 ++++ build/requirements_lock_3_12.txt | 4 ++++ build/requirements_lock_3_13.txt | 4 ++++ build/requirements_lock_3_13_ft.txt | 4 ++++ 6 files changed, 23 insertions(+) diff --git a/build/requirements.in b/build/requirements.in index c1be7a250bff..96c27739a5e7 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -27,3 +27,6 @@ libtpu ; sys_platform == "linux" and platform_machine == "x86_64" # For Mosaic GPU collectives nvidia-nvshmem-cu12>=3.2.5 ; sys_platform == "linux" + +# Platform-specific dependencies that are being ignored by pip-compile +colorama>=0.4.4 diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index b80980489fc0..5caa99692756 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -24,6 +24,10 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/requirements.in contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index ecc5d85b2f2e..de3c35ed3c02 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -24,6 +24,10 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/test-requirements.txt contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index bce2a45a3984..04c6990da696 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -24,6 +24,10 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/test-requirements.txt contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 
3cc09776606f..965cb3bc9672 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -24,6 +24,10 @@ cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/test-requirements.txt contourpy==1.3.0 \ --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index efc6fcf45814..e7d111c3b3e9 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -24,6 +24,10 @@ cloudpickle==3.1.0 \ --hash=sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b \ --hash=sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/test-requirements.txt contourpy==1.3.1 \ --hash=sha256:041b640d4ec01922083645a94bb3b2e777e6b626788f4095cf21abbe266413c1 \ --hash=sha256:05e806338bfeaa006acbdeba0ad681a10be63b26e1b17317bfac3c5d98f36cda \ From 688c3d3341217c0fa443c050a915ab6375eecccf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Jun 2025 17:12:21 -0700 Subject: [PATCH 1664/1769] Use a frozenset for unconstrained_dims in sharding_constraint_p. Makes the unconstrained dims hashable. PiperOrigin-RevId: 770853732 --- jax/_src/pjit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 9076763fb7ea..f9e751bb1e40 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2766,7 +2766,7 @@ def with_sharding_constraint(x, shardings): # TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's # already part of the shardings. unconstrained_dims = [get_unconstrained_dims(s) - if isinstance(s, NamedSharding) else {} + if isinstance(s, NamedSharding) else frozenset() for s in shardings_flat] pjit_check_aval_sharding( @@ -3190,8 +3190,8 @@ def _layout_constraint_batcher(axis_data, vals_in, dims_in, layout): def get_unconstrained_dims(sharding: NamedSharding): assert sharding.spec is not None - return {i for i, axes in enumerate(sharding.spec) - if axes is PartitionSpec.UNCONSTRAINED} + return frozenset(i for i, axes in enumerate(sharding.spec) + if axes is PartitionSpec.UNCONSTRAINED) # -------------------- attrs etc -------------------- From b72be575cf32dc17357effa112888b0cb353e8e5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Jun 2025 17:44:46 -0700 Subject: [PATCH 1665/1769] [jaxlib] Change Traceback to be a raw CPython class rather than a nanobind class. I was hoping this change would have a positive impact on performance but it is close to neutral in my benchmarks. Still, I think this version is slightly better because it: * allocates the frames inline into the traceback object, saving at least one allocation per traceback collection * fixes a data race on the enabled flag, admittedly one that is probably troubling no-one in practice. 
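A hedged sketch of the user-visible surface after this change, based on the
`xla_client.py` and `_jax/__init__.pyi` updates in the diff below:

    from jaxlib import xla_client

    # Module-level functions replace the old Traceback.enabled class property:
    with xla_client.tracebacks(enabled=True):
        ...  # stack traces are collected for values created in this block
    # the previous setting is restored on exit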
PiperOrigin-RevId: 770862697 --- jaxlib/BUILD | 1 + jaxlib/_jax/__init__.pyi | 2 + jaxlib/py_array.cc | 19 +- jaxlib/py_array.h | 18 +- jaxlib/py_client.cc | 23 +- jaxlib/py_executable.cc | 3 +- jaxlib/py_executable.h | 6 +- jaxlib/traceback.cc | 443 ++++++++++++++++++++++++--------------- jaxlib/traceback.h | 70 ++----- jaxlib/xla_client.py | 6 +- 10 files changed, 325 insertions(+), 266 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index c95e29c2f6b1..3d9d3c709415 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1196,6 +1196,7 @@ cc_library( "@tsl//tsl/platform", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla/pjrt:exceptions", + "@xla//xla/python:nb_helpers", "@xla//xla/tsl/platform:logging", ], ) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 91fc49197054..08eca19da5b1 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -792,6 +792,8 @@ class Traceback: code: types.CodeType, lasti: int ) -> tuple[int, int, int, int]: ... +def tracebacks_enabled() -> bool: ... +def set_tracebacks_enabled(enabled: bool) -> None: ... def replace_thread_exc_traceback(traceback: Any): ... # === END py_traceback.cc diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index b79236f1306a..6d6f76b27d88 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -333,7 +333,7 @@ extern "C" int PyArray_tp_clear(PyObject* self) { if (obj->initialized) { auto traceback = GetPyArrayStorageFromObject(obj)->traceback; if (traceback.has_value()) { - traceback_str = traceback.value()->ToString(); + traceback_str = traceback.value().ToString(); } } auto error_msg = absl::StrCat( @@ -458,7 +458,7 @@ struct BatchedCopyToDeviceWithShardingKey { PyArray_Storage::PyArray_Storage( nb::object aval, bool weak_type, xla::nb_dtype dtype, std::vector shape, nb::object sharding, bool committed, - nb_class_ptr py_client, std::optional traceback, + nb_class_ptr py_client, std::optional traceback, ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status) : aval(std::move(aval)), weak_type(weak_type), @@ -514,10 +514,11 @@ void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, } } -PyArray PyArray::MakeFromSingleDeviceArray( - nb_class_ptr py_client, std::optional traceback, - ifrt::ArrayRef ifrt_array, bool weak_type, bool committed, - xla::PjRtFuture<> result_status) { +PyArray PyArray::MakeFromSingleDeviceArray(nb_class_ptr py_client, + std::optional traceback, + ifrt::ArrayRef ifrt_array, + bool weak_type, bool committed, + xla::PjRtFuture<> result_status) { if (!llvm::isa(ifrt_array->sharding())) { throw XlaRuntimeError( InvalidArgument("Constructing single device jax.Array from non-single " @@ -547,7 +548,7 @@ PyArray PyArray::MakeFromSingleDeviceArray( } PyArray PyArray::MakeFromIfrtArrayAndSharding( - nb_class_ptr py_client, std::optional traceback, + nb_class_ptr py_client, std::optional traceback, ifrt::ArrayRef ifrt_array, nb::object sharding, bool weak_type, bool committed, bool skip_checks) { auto shape_span = ifrt_array->shape().dims(); @@ -606,8 +607,8 @@ PyArray PyArrayResultHandler::Call(PyArray py_array) const { PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, std::vector shape, nb::object sharding, nb_class_ptr py_client, - std::optional traceback, - ifrt::ArrayRef ifrt_array, bool committed, bool skip_checks, + std::optional traceback, ifrt::ArrayRef ifrt_array, + bool committed, bool skip_checks, xla::PjRtFuture<> result_status) { auto* self = PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); 
diff --git a/jaxlib/py_array.h b/jaxlib/py_array.h index bf1208c11da5..5b496be091f0 100644 --- a/jaxlib/py_array.h +++ b/jaxlib/py_array.h @@ -47,7 +47,6 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/shape.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" namespace xla { @@ -93,8 +92,8 @@ struct PyArray_Storage { PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype, std::vector shape, nanobind::object sharding, bool committed, nb_class_ptr py_client, - std::optional traceback, - ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status); + std::optional traceback, ifrt::ArrayRef ifrt_array, + xla::PjRtFuture<> result_status); ~PyArray_Storage(); nanobind::handle AsHandle(); @@ -109,7 +108,7 @@ struct PyArray_Storage { bool committed = false; nb_class_ptr py_client; - std::optional traceback; + std::optional traceback; ifrt::ArrayRef ifrt_array; nanobind::object fully_replicated_array = nanobind::none(); @@ -151,18 +150,17 @@ class PyArray : public nanobind::object { // checked. PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, std::vector shape, nanobind::object sharding, - nb_class_ptr py_client, - std::optional traceback, ifrt::ArrayRef ifrt_array, - bool committed, bool skip_checks, + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, bool committed, bool skip_checks, xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); static PyArray MakeFromSingleDeviceArray( - nb_class_ptr py_client, std::optional traceback, + nb_class_ptr py_client, std::optional traceback, ifrt::ArrayRef ifrt_array, bool weak_type, bool committed, xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); static PyArray MakeFromIfrtArrayAndSharding( - nb_class_ptr py_client, std::optional traceback, + nb_class_ptr py_client, std::optional traceback, ifrt::ArrayRef ifrt_array, nanobind::object sharding, bool weak_type, bool committed, bool skip_checks); @@ -199,7 +197,7 @@ class PyArray : public nanobind::object { return GetStorage().py_client; } - const std::optional& traceback() const { + const std::optional& traceback() const { return GetStorage().traceback; } diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index f478040e8622..fbbb803607ee 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -552,7 +552,7 @@ PyClient::DeserializeExecutable(nb_class_ptr client, namespace { struct HeapProfileKey { - Traceback* traceback; + std::optional traceback; int64_t size; xla::PjRtDevice* device; bool operator==(const HeapProfileKey& other) const; @@ -562,10 +562,10 @@ bool HeapProfileKey::operator==(const HeapProfileKey& other) const { if (size != other.size || device != other.device) { return false; } - if ((traceback == nullptr) != (other.traceback == nullptr)) { + if ((traceback.has_value()) != (other.traceback.has_value())) { return false; } - if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) { + if (traceback.has_value() && traceback->not_equal(*other.traceback)) { return false; } return true; @@ -574,7 +574,7 @@ bool HeapProfileKey::operator==(const HeapProfileKey& other) const { template H AbslHashValue(H h, const HeapProfileKey& key) { if (key.traceback) { - h = H::combine(std::move(h), key.traceback->raw_frames()); + h = H::combine(std::move(h), nb::hash(*key.traceback)); } h = H::combine(std::move(h), key.size, key.device); return h; @@ -587,7 +587,8 @@ absl::StatusOr PyClient::HeapProfile() { absl::flat_hash_set 
buffer_set; absl::flat_hash_map entries; - auto add_buffer_to_profile = [&](PjRtBuffer* buffer, Traceback* traceback) { + auto add_buffer_to_profile = [&](PjRtBuffer* buffer, + std::optional traceback) { // We only wish to count each PjRtBuffer once, even though they may be // shared by multiple PyArrays. if (!buffer->IsDeleted() && buffer_set.insert(buffer).second) { @@ -613,17 +614,15 @@ absl::StatusOr PyClient::HeapProfile() { "only."); } for (const auto& buffer : arr->pjrt_buffers()) { - TF_RETURN_IF_ERROR(add_buffer_to_profile( - buffer.get(), - array.traceback() ? array.traceback()->get() : nullptr)); + TF_RETURN_IF_ERROR( + add_buffer_to_profile(buffer.get(), array.traceback())); } } for (PyLoadedExecutable* executable = executables_; executable; executable = executable->next_) { - HeapProfileKey key{ - executable->traceback() ? executable->traceback()->get() : nullptr, - executable->SizeOfGeneratedCodeInBytes(), nullptr}; + HeapProfileKey key{executable->traceback(), + executable->SizeOfGeneratedCodeInBytes(), nullptr}; ++entries[key]; } @@ -642,7 +641,7 @@ absl::StatusOr PyClient::HeapProfile() { for (const auto& entry : entries) { auto* sample = builder.profile().add_sample(); if (entry.first.traceback) { - for (const auto& frame : entry.first.traceback->raw_frames()) { + for (const auto& frame : entry.first.traceback->RawFrames()) { sample->add_location_id(builder.LocationId(frame.first, frame.second)); } } diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index a9a7ebea2d0b..c934c83adebc 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -85,8 +85,7 @@ absl::Status PyShardedToken::Await() { PyLoadedExecutable::PyLoadedExecutable( nb_class_ptr client, ifrt::LoadedExecutableRef ifrt_loaded_executable, - std::optional traceback, - std::optional fingerprint) + std::optional traceback, std::optional fingerprint) : client_(std::move(client)), ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), traceback_(std::move(traceback)), diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h index fed6552a9eb5..ee68e8388627 100644 --- a/jaxlib/py_executable.h +++ b/jaxlib/py_executable.h @@ -181,7 +181,7 @@ class PyLoadedExecutable { public: PyLoadedExecutable(nb_class_ptr client, ifrt::LoadedExecutableRef ifrt_loaded_executable, - std::optional traceback, + std::optional traceback, std::optional fingerprint); ~PyLoadedExecutable(); @@ -231,7 +231,7 @@ class PyLoadedExecutable { std::optional> GetOutputShardings() const; - const std::optional& traceback() { return traceback_; } + const std::optional& traceback() { return traceback_; } ifrt::LoadedExecutable* ifrt_executable() const { return ifrt_loaded_executable_.get(); @@ -267,7 +267,7 @@ class PyLoadedExecutable { nb_class_ptr client_; ifrt::LoadedExecutableRef ifrt_loaded_executable_; - std::optional traceback_; + std::optional traceback_; // Identical executables (i.e. representing the same program) will have the // same fingerprint. nullopt on platforms or executables where fingerprints diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc index 48edc584c94f..4ae9cf62130a 100644 --- a/jaxlib/traceback.cc +++ b/jaxlib/traceback.cc @@ -17,15 +17,20 @@ limitations under the License. 
#include +#include +#include #include +#include #include #include +#include #include #include #include "absl/base/casts.h" #include "absl/hash/hash.h" #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -34,8 +39,8 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/nb_class_ptr.h" #include "xla/pjrt/exceptions.h" +#include "xla/python/nb_helpers.h" #include "tsl/platform/platform.h" #ifdef PLATFORM_GOOGLE @@ -44,125 +49,137 @@ limitations under the License. #undef Py_BUILD_CORE #endif // PLATFORM_GOOGLE +namespace nb = nanobind; + namespace xla { -namespace nb = nanobind; +namespace { -bool Traceback::enabled_ = true; +std::atomic traceback_enabled_ = true; -Traceback::Traceback() { - DCHECK(PyGILState_Check()); - PyThreadState* thread_state = PyThreadState_GET(); +static constexpr int kMaxFrames = 512; -#if PY_VERSION_HEX < 0x030b0000 - // The representation of frame->f_lasti changed from bytes to words in Python - // 3.10, see https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api - // This should match sizeof(_Py_CODEUNIT) which is unfortunately private. - constexpr int kLastiWordBytes = 2; +PyTypeObject* traceback_type_ = nullptr; - for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr; - py_frame = py_frame->f_back) { - Py_INCREF(py_frame->f_code); - frames_.emplace_back(py_frame->f_code, py_frame->f_lasti * kLastiWordBytes); - } -#else // PY_VERSION_HEX < 0x030b0000 +// Entry in a traceback. Must be POD. +struct TracebackEntry { + TracebackEntry() = default; + TracebackEntry(PyCodeObject* code, int lasti) : code(code), lasti(lasti) {} + PyCodeObject* code; + int lasti; -#ifdef PLATFORM_GOOGLE -// This code is equivalent to the version using public APIs, but it saves us -// an allocation of one object per stack frame. However, this is definitely -// violating the API contract of CPython, so we only use this where we can be -// confident we know exactly which CPython we are using (internal to Google). -// Feel free to turn this on if you like, but it might break at any time! 
-#if PY_VERSION_HEX < 0x030d0000 - for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; - f != nullptr; f = f->previous) { - if (_PyFrame_IsIncomplete(f)) continue; - Py_INCREF(f->f_code); - frames_.emplace_back(f->f_code, - _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)); + bool operator==(const TracebackEntry& other) const { + return code == other.code && lasti == other.lasti; } -#else // PY_VERSION_HEX < 0x030d0000 - for (_PyInterpreterFrame* f = thread_state->current_frame; f != nullptr; - f = f->previous) { - if (_PyFrame_IsIncomplete(f)) continue; - Py_INCREF(f->f_executable); - frames_.emplace_back(reinterpret_cast(f->f_executable), - _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)); + bool operator!=(const TracebackEntry& other) const { + return !operator==(other); } -#endif // PY_VERSION_HEX < 0x030d0000 - -#else // PLATFORM_GOOGLE - PyFrameObject* next; - for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); - py_frame != nullptr; py_frame = next) { - frames_.emplace_back(PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)); - next = PyFrame_GetBack(py_frame); - Py_XDECREF(py_frame); - } -#endif // PLATFORM_GOOGLE +}; +static_assert(std::is_trivial_v == true); -#endif // PY_VERSION_HEX < 0x030b0000 +template +H AbslHashValue(H h, const TracebackEntry& entry) { + h = H::combine(std::move(h), entry.code, entry.lasti); + return h; } -Traceback::~Traceback() { - for (auto& frame : frames_) { - DCHECK(PyGILState_Check()); - Py_DECREF(frame.first); - } +struct TracebackObject { + PyObject_VAR_HEAD; + TracebackEntry frames[]; +}; + +template +H AbslHashValue(H h, const TracebackObject& tb) { + h = H::combine_contiguous(std::move(h), &tb.frames[0], Py_SIZE(&tb)); + return h; } -Traceback::Traceback(Traceback&& other) noexcept - : frames_(std::move(other.frames_)) { - // absl::InlinedVector does not always clear itself if moved. Since we rely on - // its empty() method to destroy Traceback differently, we explicitly clear - // here. - other.frames_.clear(); +static_assert(sizeof(TracebackObject) % alignof(PyObject) == 0); +static_assert(sizeof(TracebackEntry) % alignof(void*) == 0); + +bool traceback_check(nb::handle o) { + return Py_TYPE(o.ptr()) == traceback_type_; } -std::string Traceback::Frame::ToString() const { - return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), - line_num, nb::cast(function_name)); +Py_hash_t traceback_tp_hash(PyObject* o) { + TracebackObject* tb = reinterpret_cast(o); + size_t h = absl::HashOf(*tb); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. } -std::string Traceback::ToString() const { - std::vector frame_strs; - frame_strs.reserve(frames_.size()); - for (const Frame& frame : Frames()) { - frame_strs.push_back(frame.ToString()); +PyObject* traceback_tp_richcompare(PyObject* self, PyObject* other, int op) { + if (op != Py_EQ && op != Py_NE) { + return Py_NewRef(Py_NotImplemented); } - return absl::StrJoin(frame_strs, "\n"); + + if (!traceback_check(other)) { + return Py_NewRef(Py_False); + } + TracebackObject* tb_self = reinterpret_cast(self); + TracebackObject* tb_other = reinterpret_cast(other); + if (Py_SIZE(tb_self) != Py_SIZE(tb_other)) { + return Py_NewRef(op == Py_EQ ? Py_False : Py_True); + } + for (Py_ssize_t i = 0; i < Py_SIZE(tb_self); ++i) { + if ((tb_self->frames[i] != tb_other->frames[i])) { + return Py_NewRef(op == Py_EQ ? Py_False : Py_True); + } + } + return Py_NewRef(op == Py_EQ ? 
Py_True : Py_False); } -std::vector Traceback::Frames() const { - // We require the GIL because we manipulate Python strings. - CHECK(PyGILState_Check()); - std::vector frames; - frames.reserve(frames_.size()); - for (const auto& frame : frames_) { - frames.push_back(Frame{nb::borrow(frame.first->co_filename), - nb::borrow(frame.first->co_name), - frame.first->co_firstlineno, - PyCode_Addr2Line(frame.first, frame.second)}); +static void traceback_tp_dealloc(PyObject* self) { + TracebackObject* tb = reinterpret_cast(self); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + Py_XDECREF(tb->frames[i].code); } - return frames; + PyTypeObject* tp = Py_TYPE(self); + tp->tp_free((PyObject*)self); + Py_DECREF(tp); } -std::optional> Traceback::Get() { - DCHECK(PyGILState_Check()); - if (!enabled_) { - return std::nullopt; +Traceback::Frame DecodeFrame(const TracebackEntry& frame) { + return Traceback::Frame{ + .file_name = nb::borrow(frame.code->co_filename), + .function_name = nb::borrow(frame.code->co_name), + .function_start_line = frame.code->co_firstlineno, + .line_num = PyCode_Addr2Line(frame.code, frame.lasti), + }; +} + +std::string traceback_to_string(const TracebackObject* tb) { + std::vector frame_strs; + frame_strs.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + frame_strs.push_back(DecodeFrame(tb->frames[i]).ToString()); } - return make_nb_class(); + return absl::StrJoin(frame_strs, "\n"); } -void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; } +PyObject* traceback_tp_str(PyObject* self) { + TracebackObject* tb = reinterpret_cast(self); + return nb::cast(traceback_to_string(tb)).release().ptr(); +} + +// It turns out to be slightly faster to define a tp_hash slot rather than +// defining __hash__ and __eq__ on the class. +PyType_Slot traceback_slots_[] = { + {Py_tp_hash, reinterpret_cast(traceback_tp_hash)}, + {Py_tp_richcompare, reinterpret_cast(traceback_tp_richcompare)}, + {Py_tp_dealloc, reinterpret_cast(traceback_tp_dealloc)}, + {Py_tp_str, reinterpret_cast(traceback_tp_str)}, + {0, nullptr}, +}; -nb::object Traceback::AsPythonTraceback() const { +nb::object AsPythonTraceback(const Traceback& tb) { nb::object traceback = nb::none(); nb::dict globals; nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); - for (const std::pair& frame : frames_) { - int lineno = PyCode_Addr2Line(frame.first, frame.second); + TracebackObject* tb_obj = reinterpret_cast(tb.ptr()); + for (Py_ssize_t i = 0; i < Py_SIZE(tb_obj); ++i) { + const TracebackEntry& frame = tb_obj->frames[i]; + int lineno = PyCode_Addr2Line(frame.code, frame.lasti); // Under Python 3.11 we observed crashes when using a fake PyFrameObject // with a real PyCodeObject (https://github.com/google/jax/issues/16027). // because the frame does not have fields necessary to compute the locals, @@ -172,8 +189,8 @@ nb::object Traceback::AsPythonTraceback() const { // We therefore always build a fake code object to go along with our fake // frame. 
PyCodeObject* py_code = - PyCode_NewEmpty(PyUnicode_AsUTF8(frame.first->co_filename), - PyUnicode_AsUTF8(frame.first->co_name), lineno); + PyCode_NewEmpty(PyUnicode_AsUTF8(frame.code->co_filename), + PyUnicode_AsUTF8(frame.code->co_name), lineno); PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, globals.ptr(), /*locals=*/nullptr); Py_DECREF(py_code); @@ -183,55 +200,123 @@ nb::object Traceback::AsPythonTraceback() const { /*tb_frame=*/ nb::steal(reinterpret_cast(py_frame)), /*tb_lasti=*/0, - /*tb_lineno=*/ - PyCode_Addr2Line(frame.first, frame.second)); + /*tb_lineno=*/lineno); } return traceback; } -namespace { +} // namespace -Py_hash_t traceback_tp_hash(PyObject* o) { - Traceback* tb; - if (!nb::try_cast(nb::handle(o), tb)) { - PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); - return -1; +std::vector Traceback::Frames() const { + // We require the GIL because we manipulate Python strings. + CHECK(PyGILState_Check()); + std::vector frames; + TracebackObject* tb = reinterpret_cast(ptr()); + frames.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + const TracebackEntry& frame = tb->frames[i]; + frames.push_back(Frame{nb::borrow(frame.code->co_filename), + nb::borrow(frame.code->co_name), + frame.code->co_firstlineno, + PyCode_Addr2Line(frame.code, frame.lasti)}); } - size_t h = absl::HashOf(*tb); - Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. - return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. + return frames; } -PyObject* traceback_tp_richcompare(PyObject* self, PyObject* other, int op) { - if (op != Py_EQ && op != Py_NE) { - return Py_NewRef(Py_NotImplemented); +std::string Traceback::Frame::ToString() const { + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); +} + +std::string Traceback::ToString() const { + return traceback_to_string(reinterpret_cast(ptr())); +} + +std::vector> Traceback::RawFrames() const { + const TracebackObject* tb = reinterpret_cast(ptr()); + std::vector> frames; + frames.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + frames.push_back(std::make_pair(tb->frames[i].code, tb->frames[i].lasti)); } + return frames; +} + +/*static*/ bool Traceback::Check(PyObject* o) { return traceback_check(o); } + +/*static*/ std::optional Traceback::Get() { + // We use a thread_local here mostly to avoid requiring a large amount of + // space. + thread_local std::array frames; + int count = 0; + + DCHECK(PyGILState_Check()); - Traceback* x; - if (!nb::try_cast(nb::handle(self), x)) { - PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); - return nullptr; + if (!traceback_enabled_.load()) { + return std::nullopt; } - bool result; - Traceback* y; - if (nb::try_cast(nb::handle(other), y)) { - result = ((*x == *y) == (op == Py_EQ)); - } else { - result = (op == Py_NE); + PyThreadState* thread_state = PyThreadState_GET(); + +#if PY_VERSION_HEX < 0x030b0000 + // The representation of frame->f_lasti changed from bytes to words in Python + // 3.10, see https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api + // This should match sizeof(_Py_CODEUNIT) which is unfortunately private. + constexpr int kLastiWordBytes = 2; + + for (PyFrameObject* py_frame = thread_state->frame; + py_frame != nullptr && count < kMaxFrames; py_frame = py_frame->f_back) { + Py_INCREF(py_frame->f_code); + frames[count] = {py_frame->f_code, py_frame->f_lasti * kLastiWordBytes}; + ++count; } - return Py_NewRef(result ? 
Py_True : Py_False); -} +#else // PY_VERSION_HEX < 0x030b0000 -// It turns out to be slightly faster to define a tp_hash slot rather than -// defining __hash__ and __eq__ on the class. -PyType_Slot traceback_slots_[] = { - {Py_tp_hash, (void*)traceback_tp_hash}, - {Py_tp_richcompare, (void*)traceback_tp_richcompare}, - {0, nullptr}, -}; +#ifdef PLATFORM_GOOGLE +// This code is equivalent to the version using public APIs, but it saves us +// an allocation of one object per stack frame. However, this is definitely +// violating the API contract of CPython, so we only use this where we can be +// confident we know exactly which CPython we are using (internal to Google). +// Feel free to turn this on if you like, but it might break at any time! +#if PY_VERSION_HEX < 0x030d0000 + for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; + f != nullptr; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames[count] = {f->f_code, static_cast(_PyInterpreterFrame_LASTI(f) * + sizeof(_Py_CODEUNIT))}; + ++count; + } +#else // PY_VERSION_HEX < 0x030d0000 + for (_PyInterpreterFrame* f = thread_state->current_frame; f != nullptr; + f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_executable); + frames[count] = {reinterpret_cast(f->f_executable), + _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)}; + ++count; + } +#endif // PY_VERSION_HEX < 0x030d0000 -} // namespace +#else // PLATFORM_GOOGLE + PyFrameObject* next; + for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); + py_frame != nullptr; py_frame = next) { + frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)}; + ++count; + next = PyFrame_GetBack(py_frame); + Py_XDECREF(py_frame); + } +#endif // PLATFORM_GOOGLE + +#endif // PY_VERSION_HEX < 0x030b0000 + + Traceback traceback = + nb::steal(PyObject_NewVar(PyObject, traceback_type_, count)); + TracebackObject* tb = reinterpret_cast(traceback.ptr()); + std::memcpy(tb->frames, frames.data(), sizeof(TracebackEntry) * count); + return traceback; +} void BuildTracebackSubmodule(nb::module_& m) { nb::class_(m, "Frame") @@ -246,47 +331,69 @@ void BuildTracebackSubmodule(nb::module_& m) { nb::cast(frame.file_name), frame.line_num); }); - nb::class_ traceback(m, "Traceback", - nb::type_slots(traceback_slots_), - "Represents a Python stack trace."); - traceback.def_prop_rw_static( - "enabled", [](nb::object /* cls */) { return Traceback::enabled(); }, - [](nb::object /* cls */, bool enabled) { - return Traceback::SetEnabled(enabled); - }); - traceback.def_static( - "get_traceback", []() { return Traceback::Get(); }, - R"doc( - Returns a :class:`Traceback` for the current thread. - - If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object - that describes the Python stack of the calling thread. Stack trace - collection has a small overhead, so it is disabled by default. If traceback - collection is disabled, returns ``None``. - )doc"); - traceback.def_prop_ro("frames", &Traceback::Frames); - traceback.def("raw_frames", [](const Traceback& tb) -> nb::tuple { - // We return a tuple of lists, rather than a list of tuples, because it - // is cheaper to allocate only three Python objects for everything rather - // than one per frame. 
- nb::list out_code = nb::steal(PyList_New(tb.raw_frames().size())); - nb::list out_lasti = - nb::steal(PyList_New(tb.raw_frames().size())); - for (size_t i = 0; i < tb.raw_frames().size(); ++i) { - const auto& frame = tb.raw_frames()[i]; - PyObject* code = reinterpret_cast(frame.first); - Py_INCREF(code); - PyList_SET_ITEM(out_code.ptr(), i, code); - PyList_SET_ITEM(out_lasti.ptr(), i, - nb::int_(frame.second).release().ptr()); - } - return nb::make_tuple(out_code, out_lasti); - }); - traceback.def("__str__", &Traceback::ToString); - traceback.def("as_python_traceback", &Traceback::AsPythonTraceback); + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Traceback"); + + PyType_Spec traceback_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(TracebackObject)), + /*.itemsize=*/static_cast(sizeof(TracebackEntry)), + /*.flags=*/Py_TPFLAGS_DEFAULT, + /*.slots=*/traceback_slots_, + }; + + traceback_type_ = + reinterpret_cast(PyType_FromSpec(&traceback_spec)); + if (!traceback_type_) { + throw nb::python_error(); + } + + auto type = nb::borrow(traceback_type_); + m.attr("Traceback") = type; + + m.def("tracebacks_enabled", []() { return traceback_enabled_.load(); }); + m.def("set_tracebacks_enabled", + [](bool value) { traceback_enabled_.store(value); }); + + type.attr("get_traceback") = nb::cpp_function(Traceback::Get, + R"doc( + Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` + object that describes the Python stack of the calling thread. Stack + trace collection has a small overhead, so it is disabled by default. If + traceback collection is disabled, returns ``None``. )doc"); + type.attr("frames") = nb_property_readonly(&Traceback::Frames); + type.attr("raw_frames") = nb::cpp_function( + [](const Traceback& tb) -> nb::tuple { + // We return a tuple of lists, rather than a list of tuples, because it + // is cheaper to allocate only three Python objects for everything + // rather than one per frame. 
+ std::vector> frames = tb.RawFrames(); + nb::list out_code = nb::steal(PyList_New(frames.size())); + nb::list out_lasti = nb::steal(PyList_New(frames.size())); + for (size_t i = 0; i < frames.size(); ++i) { + const auto& frame = frames[i]; + PyObject* code = reinterpret_cast(frame.first); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.second).release().ptr()); + } + return nb::make_tuple(out_code, out_lasti); + }, + nb::is_method()); + type.attr("as_python_traceback") = + nb::cpp_function(AsPythonTraceback, nb::is_method()); - traceback.def_static( - "traceback_from_frames", + type.attr("traceback_from_frames") = nb::cpp_function( [](std::vector frames) { nb::object traceback = nb::none(); nb::dict globals; @@ -312,8 +419,7 @@ void BuildTracebackSubmodule(nb::module_& m) { }, "Creates a traceback from a list of frames."); - traceback.def_static( - "code_addr2line", + type.attr("code_addr2line") = nb::cpp_function( [](nb::handle code, int lasti) { if (!PyCode_Check(code.ptr())) { throw xla::XlaRuntimeError("code argument must be a code object"); @@ -324,8 +430,7 @@ void BuildTracebackSubmodule(nb::module_& m) { "Python wrapper around the Python C API function PyCode_Addr2Line"); #if PY_VERSION_HEX >= 0x030b0000 - traceback.def_static( - "code_addr2location", + type.attr("code_addr2location") = nb::cpp_function( [](nb::handle code, int lasti) { if (!PyCode_Check(code.ptr())) { throw xla::XlaRuntimeError("code argument must be a code object"); diff --git a/jaxlib/traceback.h b/jaxlib/traceback.h index 97699a7b3de9..9ae7e9e0836f 100644 --- a/jaxlib/traceback.h +++ b/jaxlib/traceback.h @@ -24,38 +24,23 @@ limitations under the License. #include // placeholder for index annotation headers -#include "absl/container/inlined_vector.h" #include "nanobind/nanobind.h" -#include "jaxlib/nb_class_ptr.h" namespace xla { -// Represents a Python traceback. This object is designed to be allocated on -// the Python heap; creating or destroying a traceback requires the GIL. -class Traceback { +class Traceback : public nanobind::object { public: - // Requires GIL. Creates a Traceback object that requires destructor to be - // invoked with GIL held as well. - static std::optional> Get(); - - // Requires GIL. - static bool enabled() { return enabled_; } - // Requires GIL. - static void SetEnabled(bool enabled); - - // Requires GIL. Don't call this directly, you're looking for Get(). - Traceback(); - // Requires GIL. - ~Traceback(); - - Traceback(const Traceback&) = delete; - Traceback(Traceback&& other) noexcept; - Traceback& operator=(const Traceback&) = delete; - Traceback& operator=(Traceback&&) = delete; - - // Requires the GIL be held. + NB_OBJECT(Traceback, nanobind::object, "Traceback", Traceback::Check); + + // Returns a traceback if it is enabled, otherwise returns nullopt. + static std::optional Get(); + + // Returns a string representation of the traceback. std::string ToString() const; + // Returns a list of (code, lasti) pairs for each frame in the traceback. + std::vector> RawFrames() const; + struct Frame { nanobind::str file_name; nanobind::str function_name; @@ -64,44 +49,13 @@ class Traceback { std::string ToString() const; }; + // Returns a list of Frames for the traceback. std::vector Frames() const; - const absl::InlinedVector, 32>& raw_frames() - const { - return frames_; - } - - // Returns the traceback as a fake Python Traceback object, suitable for - // using as an exception traceback. 
-  nanobind::object AsPythonTraceback() const;
-
-  bool operator==(const Traceback& other) const {
-    return frames_ == other.frames_;
-  }
-  bool operator!=(const Traceback& other) const {
-    return frames_ != other.frames_;
-  }
-
- private:
-  // Each frame is a pair of a code object and a "lasti" instruction location
-  // in bytes. The size of _Py_CODEUNIT has changed across different Python
-  // versions; the lasti value here has already been multiplied by
-  // sizeof(_Py_CODEUNIT) if needed and is suitable for passing to functions
-  // like PyCode_Addr2Line().
-  absl::InlinedVector<std::pair<PyCodeObject*, int>, 32> frames_;
-
-  // Protected by GIL.
-  static bool enabled_;
+ private:
+  static bool Check(PyObject* o);
 };

-using nb_traceback = nb_class_ptr<Traceback>;
-
-template <typename H>
-H AbslHashValue(H h, const Traceback& traceback) {
-  h = H::combine(std::move(h), traceback.raw_frames());
-  return h;
-}
-
 void BuildTracebackSubmodule(nanobind::module_& m);

 }  // namespace xla
diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py
index f8f7dc373aa6..90c9478abf93 100644
--- a/jaxlib/xla_client.py
+++ b/jaxlib/xla_client.py
@@ -511,12 +511,12 @@ def register_custom_type_id_handler(
 @contextlib.contextmanager
 def tracebacks(enabled=True):
   """Context manager that enables or disables traceback collection."""
-  saved = Traceback.enabled
-  Traceback.enabled = enabled
+  saved = _xla.tracebacks_enabled()
+  _xla.set_tracebacks_enabled(enabled)
   try:
     yield
   finally:
-    Traceback.enabled = saved
+    _xla.set_tracebacks_enabled(saved)

 XlaRuntimeError = _xla.XlaRuntimeError

From e04cc283d84a2df3ab0baa9c37f19f90600e11c1 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 12 Jun 2025 17:52:14 -0700
Subject: [PATCH 1666/1769] Make the params of more jaxpr primitives hashable.

PiperOrigin-RevId: 770864351
---
 jax/_src/api.py | 4 ++--
 jax/_src/custom_partitioning.py | 4 ++--
 jax/_src/dispatch.py | 2 +-
 jax/_src/internal_test_util/test_harnesses.py | 4 ++--
 jax/_src/pjit.py | 2 +-
 tests/hijax_test.py | 8 +++++---
 tests/jaxpr_effects_test.py | 6 +++---
 7 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index a93d93c57e70..db1e62910140 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2611,8 +2611,8 @@ def device_put(
   for xf, d in zip(x_flat, device_flat):
     _check_sharding(shaped_abstractify(xf), d)
   out_flat = dispatch.device_put_p.bind(
-      *x_flat, devices=device_flat, srcs=src_flat,
-      copy_semantics=copy_semantics)
+      *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat),
+      copy_semantics=tuple(copy_semantics))
   return tree_unflatten(treedef, out_flat)

diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py
index 322aa33d6d30..9d57c6b0c038 100644
--- a/jax/_src/custom_partitioning.py
+++ b/jax/_src/custom_partitioning.py
@@ -484,10 +484,10 @@ def __call__(self, *args, **kwargs):
           args,
           require_static_args_hashable=False,
       )
-      static_args = [args[i] for i in self.static_argnums]
+      static_args = tuple(args[i] for i in self.static_argnums)
       _check_for_tracers(static_args)
     else:
-      static_args = []
+      static_args = ()
     f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
     args_flat, in_tree = tree_util.tree_flatten(dyn_args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 289c9e0fb89b..972c7f6ffb23 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -614,7 +614,7 @@ def _device_put_transpose(cts, *_, devices, srcs, copy_semantics):
         assert cp == CopySemantics.COPY
new_copy_semantics.append(CopySemantics.COPY) ys = device_put_p.bind(*args, devices=srcs, srcs=devices, - copy_semantics=new_copy_semantics) + copy_semantics=tuple(new_copy_semantics)) for i, y in zip(indices, ys): results[i] = y return results diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 7445b9cfcb6f..a3e873c43c92 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -654,8 +654,8 @@ def _make_device_put_harness(name, "device_put", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}", lambda x: dispatch.device_put_p.bind( - x, devices=[_device_fn()], srcs=[None], - copy_semantics=[dispatch.CopySemantics.ALIAS])[0], + x, devices=(_device_fn(),), srcs=(None,), + copy_semantics=(dispatch.CopySemantics.ALIAS,))[0], [RandArg(shape, dtype)], shape=shape, dtype=dtype, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f9e751bb1e40..f1ba166230b0 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2892,7 +2892,7 @@ def _sharding_constraint_batcher( sharding=vmapped_sharding, layout=layout, context_mesh=context_mesh, - unconstrained_dims=unconstrained_dims) + unconstrained_dims=frozenset(unconstrained_dims)) return y, d batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher batching.skippable_batchers[sharding_constraint_p] = lambda _: () diff --git a/tests/hijax_test.py b/tests/hijax_test.py index de0862b7821c..ccdbe7371b69 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -215,7 +215,7 @@ def new_box(): def box_get(box): tys = core.cur_qdd(box) - leaf_vals = box_get_p.bind(box, avals=tys.leaf_avals) + leaf_vals = box_get_p.bind(box, avals=tuple(tys.leaf_avals)) return jax.tree.unflatten(tys.treedef, leaf_vals) def box_set(box, val): @@ -358,8 +358,10 @@ def to_lojax(_, box, *, avals): def jvp(_, primals, tangents, *, avals): (box,), (box_dot,) = primals, tangents - return (box_get_p.bind(box, avals=avals), - box_get_p.bind(box_dot, avals=[a.to_tangent_aval() for a in avals])) + return ( + box_get_p.bind(box, avals=avals), + box_get_p.bind(box_dot, avals=tuple(a.to_tangent_aval() for a in avals)) + ) def transpose(_, *args): assert False # TODO diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index d5574f8a9a1d..783d62dcc47e 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -527,7 +527,7 @@ def log_value(x): @jax.jit def f(x): - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) f(2.) 
jax.effects_barrier() @@ -552,11 +552,11 @@ def f(x): # Expensive computation x = x.dot(x) x = jnp.log(x.sum()) - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) @jax.jit def g(x): - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) x = jax.device_put(jnp.ones((500, 500)), jax.devices()[0]) y = jax.device_put(3., jax.devices()[1]) From d840447a57aa396fa2bf35f24a266b25e8e823c8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 12 Jun 2025 21:58:23 -0700 Subject: [PATCH 1667/1769] Remove unused internal optimization_barrier alias PiperOrigin-RevId: 770930697 --- jax/_src/ad_checkpoint.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index b11afd4a86de..c49614521a1c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -879,6 +879,3 @@ def checkpoint_wrapper( raise NotImplementedError(msg) return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) - -# TODO(phawkins): update users to refer to the public name. -_optimization_barrier = lax_internal.optimization_barrier From 382b3e026355d447ce46c7426397d8c8ea796d66 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 12 Jun 2025 22:22:22 -0700 Subject: [PATCH 1668/1769] fix-forward for pallas tpu memory spaces test PiperOrigin-RevId: 770937094 --- jax/_src/pallas/mosaic/pallas_call_registration.py | 2 ++ jax/_src/pallas/mosaic/primitives.py | 2 ++ tests/pallas/tpu_pallas_memory_space_test.py | 5 ++++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 95d07d7ddb16..fc01b696dcd8 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -83,6 +83,8 @@ def _get_memory_space_from_aval( return None case tpu_core.MemorySpace.ANY: return None + case tpu_core.MemorySpace.HBM: + return tpu_custom_call.MemorySpace.HBM case tpu_core.MemorySpace.VMEM: return tpu_custom_call.MemorySpace.VMEM case tpu_core.MemorySpace.SMEM: diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 7fb1df19d747..aac5af84bd19 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -802,6 +802,8 @@ def with_memory_space_constraint( Returns: The array `x` with the memory space constraint. """ + if memory_space in {tpu_core.MemorySpace.ANY, pl_core.MemorySpace.ANY}: + return x if memory_space not in {tpu_core.MemorySpace.HBM, tpu_core.MemorySpace.VMEM}: raise NotImplementedError( "with_memory_space_constraint only supports HBM and VMEM." 
diff --git a/tests/pallas/tpu_pallas_memory_space_test.py b/tests/pallas/tpu_pallas_memory_space_test.py index f7536833ce57..d3c62e329047 100644 --- a/tests/pallas/tpu_pallas_memory_space_test.py +++ b/tests/pallas/tpu_pallas_memory_space_test.py @@ -55,7 +55,10 @@ def g(x): @jax.jit def f(x): x = pltpu.with_memory_space_constraint(x, memory_space=memory_space) - self.assertEqual(pltpu.get_memory_space(x), memory_space) + if color is None: + self.assertIsNone(pltpu.get_memory_space(x)) + else: + self.assertEqual(pltpu.get_memory_space(x), memory_space) x = g(x) return x From c86fefb5319ca88093299bc9bfe37e7c0aff8f22 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 13 Jun 2025 00:40:24 -0700 Subject: [PATCH 1669/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f67ae6dc96dacaffe4d3a9b50de3dbe3ca89fffd. PiperOrigin-RevId: 770974030 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index bf92aff487a4..122f6c0c3d47 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4d1cd8faa246a936bb70790fc7d21d6c236d2163" -XLA_SHA256 = "fd9aee891ef0a38507d59cffc1540bf0f6653911c55c80c20a648e83d974bfbe" +XLA_COMMIT = "f67ae6dc96dacaffe4d3a9b50de3dbe3ca89fffd" +XLA_SHA256 = "af2cfc63a5be306b95c8f7f55dd74d1fcfbd589b69c68d060d641f4d233fc715" def repo(): tf_http_archive( From 142ace2cfb344e6cd959c01ac1cce7e6ff219e47 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 13 Jun 2025 01:03:32 -0700 Subject: [PATCH 1670/1769] Move jax._src.callback to its own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times. This will unblock moving `jax/_src/lax` to its own BUILD rule. 
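One mechanical wrinkle in a split like this is a module-level import that would
drag a large dependency into the new, smaller rule (see the `lax_map` import
moved into `io_callback_batching_rule` below). A minimal, self-contained sketch
of that deferred-import pattern, with toy module names rather than the real JAX
targets:

    import sys
    import types

    # Stand-in for the heavy dependency (jax/_src/lax in the real change).
    lax_mod = types.ModuleType("toy_lax")
    lax_mod.loop_map = lambda f, xs: [f(x) for x in xs]
    sys.modules["toy_lax"] = lax_mod

    def io_callback_batching_rule(xs):
      # Deferred import: no module-scope dependency on toy_lax, so the two
      # modules (and their BUILD targets) share no load-time dependency edge.
      from toy_lax import loop_map
      return loop_map(abs, xs)

    print(io_callback_batching_rule([-1, 2, -3]))  # [1, 2, 3]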
PiperOrigin-RevId: 770980397 --- jax/BUILD | 27 ++++++++++++++++++++++++++- jax/_src/callback.py | 20 ++++++++++---------- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index a40065fed695..651754d89ff8 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -297,7 +297,6 @@ py_library_providing_imports_info( "_src/__init__.py", "_src/ad_checkpoint.py", "_src/blocked_sampler.py", - "_src/callback.py", "_src/checkify.py", "_src/debugging.py", "_src/dlpack.py", @@ -367,6 +366,7 @@ py_library_providing_imports_info( ":basearray", ":batching", ":buffer_callback", + ":callback", ":cloud_tpu_init", ":compilation_cache_internal", ":compiler", @@ -554,6 +554,31 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "callback", + srcs = ["_src/callback.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":dtypes", + ":effects", + ":ffi", + ":mlir", + ":pickle_util", + ":sharding", + ":sharding_impls", + ":tree_util", + ":typing", + ":util", + ":xla", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "cloud_tpu_init", srcs = ["_src/cloud_tpu_init.py"], diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 5b5ec593a550..8c0bc8f3c6ec 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -20,7 +20,7 @@ import logging from typing import Any -import jax +from jax._src import api from jax._src import config from jax._src import core from jax._src import dispatch @@ -36,12 +36,11 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import xla -from jax._src.lax.control_flow.loops import map as lax_map from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.sharding_impls import SdyArray, SdyArrayList, SdyDim, SingleDeviceSharding -from jax._src.typing import DeprecatedArg +from jax._src.typing import Array, DeprecatedArg import numpy as np logger = logging.getLogger(__name__) @@ -67,7 +66,7 @@ class _FlatCallback: callback_func: Callable[..., Any] in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`. - def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]: + def __call__(self, *flat_args: Array) -> Sequence[Array]: args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args) return tree_util.tree_leaves(self.callback_func(*args, **kwargs)) @@ -81,15 +80,15 @@ def pure_callback_impl( ): del sharding, vmap_method, result_avals try: - cpu_device, *_ = jax.local_devices(backend="cpu") + cpu_device, *_ = xb.local_devices(backend="cpu") except RuntimeError as e: raise RuntimeError( "jax.pure_callback failed to find a local CPU device to place the" " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" " JAX_PLATFORMS environment variable." ) from e - args = jax.device_put(args, cpu_device) - with jax.default_device(cpu_device): + args = api.device_put(args, cpu_device) + with config.default_device(cpu_device): try: return tree_util.tree_map(np.asarray, callback(*args)) except BaseException: @@ -424,15 +423,15 @@ def io_callback_impl( ): del result_avals, sharding, ordered try: - cpu_device, *_ = jax.local_devices(backend="cpu") + cpu_device, *_ = xb.local_devices(backend="cpu") except RuntimeError as e: raise RuntimeError( "jax.io_callback failed to find a local CPU device to place the" " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" " JAX_PLATFORMS environment variable." 
      ) from e
-  args = jax.device_put(args, cpu_device)
-  with jax.default_device(cpu_device):
+  args = api.device_put(args, cpu_device)
+  with config.default_device(cpu_device):
     try:
       return tree_util.tree_map(np.asarray, callback(*args))
     except BaseException:
@@ -472,6 +471,7 @@ def io_callback_transpose_rule(*args, **kwargs):
 def io_callback_batching_rule(
     args, dims, callback, result_avals, sharding, ordered
 ):
+  from jax._src.lax.control_flow.loops import map as lax_map  # pytype: disable=import-error
   if ordered:
     raise ValueError("Cannot `vmap` ordered IO callback.")
   is_batched = [d is not batching.not_mapped for d in dims]

From 604b6048f1d6c544441a571b22c4969b0b688904 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov" 
Date: Fri, 13 Jun 2025 01:50:00 -0700
Subject: [PATCH 1671/1769] [Mosaic GPU] Convert all memrefs with transforms to
 unrealized casts and check them.

PiperOrigin-RevId: 770993915
---
 .../mosaic/gpu/dialect_lowering.py            | 399 ++++++++++++++++--
 .../mosaic/gpu/transform_inference.py         |  11 +-
 tests/mosaic/gpu_dialect_test.py              |   5 +-
 tests/mosaic/gpu_test.py                      |  95 +++++
 4 files changed, 463 insertions(+), 47 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index ee62dfd751ef..296dbc118477 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -155,6 +155,47 @@ def _fragmented_array_from_ir(
   ).to_layout(layouts.from_layout_attr(layout))


+def wrap_transformed_memref(
+    transformed_memref: ir.Value,
+    logical_type: ir.Type,
+    transforms: ir.ArrayAttr,
+) -> ir.Value:
+  """Wraps a transformed memref in an unrealized cast with transforms.
+
+  The return type of the cast is the untransformed logical type.
+  """
+  conversion_cast = builtin.UnrealizedConversionCastOp(
+      [logical_type], [transformed_memref]
+  )
+  conversion_cast.attributes["transforms"] = transforms
+  return conversion_cast.result
+
+
+def unwrap_transformed_memref(
+    ref: ir.Value, expected_transforms: ir.ArrayAttr
+) -> ir.Value:
+  """Unwraps a memref from an unrealized cast and verifies its transforms."""
+
+  conversion_cast = cast(
+      builtin.UnrealizedConversionCastOp, ref.owner.opview  # pytype: disable=attribute-error
+  )
+
+  if not isinstance(conversion_cast, builtin.UnrealizedConversionCastOp):
+    raise ValueError(f"{conversion_cast} is not a conversion_cast")
+
+  # Check that the actual transforms match the expected ones.
+  if expected_transforms != conversion_cast.attributes["transforms"]:
+    raise ValueError(
+        f"Expected transforms {expected_transforms} do not match actual"
+        f" transforms {conversion_cast.attributes['transforms']}"
+    )
+
+  result = builtin.unrealized_conversion_cast(
+      [conversion_cast.operands[0].type], [conversion_cast]
+  )
+  return result
+
+
 def _register_lowering(
     op: str | Type[ir.OpView] | None
 ) -> Callable[[MlirLoweringRule], MlirLoweringRule]:
@@ -360,12 +401,13 @@ def _vector_load_op_lowering_rule(
         vec_size=strided_layout.vec_size,
     )
   elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
+    transforms_attr = inference_utils.in_transforms(vector_load_op)[0]
     swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
-        inference_utils.in_transforms(vector_load_op)[0]
+        transforms_attr
     )
     ref_ty = ir.MemRefType(vector_load_op.base.type)
     _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
-    transformed_ref = reinterpret_smem_ref(vector_load_op.base, transforms)
+    transformed_ref = unwrap_transformed_memref(vector_load_op.base, transforms_attr)
     fragmented_array = fa.FragmentedArray.load_tiled(
         transformed_ref,
         swizzle=swizzle,
@@ -401,20 +443,27 @@ def _vector_store_op_lowering_rule(
     )

   mgpu_utils.warpgroup_barrier()  # Make sure the reads have completed.
-  if fragmented_array.layout == fa.WGMMA_LAYOUT:
+
+  unwrapped_ref = vector_store_op.base
+  swizzle = None
+  if inference_utils.should_have_transforms(vector_store_op):
+    # Not all vector stores have transforms. E.g. if the store is directly to
+    # gmem, it won't have any transforms.
+    transforms_attr = inference_utils.in_transforms(vector_store_op)[0]
     swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
-        inference_utils.in_transforms(vector_store_op)[0]
+        transforms_attr
     )
     ref_ty = ir.MemRefType(vector_store_op.base.type)
     _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
-    fragmented_array.store_tiled(
-        reinterpret_smem_ref(vector_store_op.base, transforms), swizzle
-    )
+    unwrapped_ref = unwrap_transformed_memref(vector_store_op.base, transforms_attr)
+
+  if fragmented_array.layout == fa.WGMMA_LAYOUT:
+    fragmented_array.store_tiled(unwrapped_ref, swizzle)
   elif (fragmented_array.layout == fa.WGMMA_ROW_LAYOUT
         or fragmented_array.layout == fa.WGMMA_COL_LAYOUT
         or isinstance(fragmented_array.layout, fa.WGStridedFragLayout)
        or isinstance(fragmented_array.layout, fa.WGSplatFragLayout)):
-    fragmented_array.store_untiled(vector_store_op.base)
+    fragmented_array.store_untiled(unwrapped_ref)
   else:
     raise ValueError(
         f"{vector_store_op} has an unsupported layout: {to_store_layout}"
@@ -628,12 +677,12 @@ def _transformed_smem_ref_type(
     raise ValueError(f"Only workgroup memory is supported but got {ref_ty}.")

   shape = ref_ty.shape
+  strides, offset = ref_ty.get_strides_and_offset()
   if transposed:
     if len(shape) != 2:
       raise NotImplementedError(
           f"Only 2D shapes can be transposed, but got {shape}"
       )
-    strides, _ = ref_ty.get_strides_and_offset()
     if strides[0] != 1 or strides[1] != shape[0]:
       raise NotImplementedError(
           f"Only contiguous 2D memrefs can be transposed, but got {ref_ty}"
@@ -666,7 +715,7 @@ def _transformed_smem_ref_type(
       shape,
       ref_ty.element_type,
       memory_space=ref_ty.memory_space,
-      layout=ir.StridedLayoutAttr.get(0, new_strides),
+      layout=ir.StridedLayoutAttr.get(offset, new_strides),
   )
   return new_ref_ty

@@ -699,14 +748,13 @@ def _mgpu_async_load_op_lowering_rule(
 ) -> Sequence[ir.Value]:
   assert ctx.launch_context is not None
   barrier = 
utils.DialectBarrierRef.from_barrier_memref(load_op.barrier) - if inference_utils.has_in_transforms_set(load_op): - [transforms] = inference_utils.in_transforms(load_op) - swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - transforms - ) - else: - swizzle = mgpu.SwizzlingMode.kNoSwizzle - transforms = () + [transforms_attr] = inference_utils.in_transforms(load_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms_attr + ) + unwrapped_destination = unwrap_transformed_memref( + load_op.destination, transforms_attr + ) gmem_slice = [] for idx_i32, size in zip(load_op.indices, load_op.slice_lengths): @@ -723,7 +771,7 @@ def _mgpu_async_load_op_lowering_rule( # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=load_op.source, - dst_ref=reinterpret_smem_ref(load_op.destination, transforms), + dst_ref=unwrapped_destination, gmem_slice=tuple(gmem_slice), barrier=barrier.barrier_ref, arrive=False, @@ -740,14 +788,11 @@ def _mgpu_async_store_op_lowering_rule( ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - if inference_utils.has_in_transforms_set(store_op): - [transforms] = inference_utils.in_transforms(store_op) - swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - transforms - ) - else: - swizzle = mgpu.SwizzlingMode.kNoSwizzle - transforms = () + [transforms_attr] = inference_utils.in_transforms(store_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms_attr + ) + unwrapped_source = unwrap_transformed_memref(store_op.source, transforms_attr) gmem_slice = [] for idx_i32, size in zip(store_op.indices, store_op.slice_lengths): @@ -763,7 +808,7 @@ def _mgpu_async_store_op_lowering_rule( # TODO(dasenov): Add support for the remaining op properties. 
ctx.launch_context.async_copy( - src_ref=reinterpret_smem_ref(store_op.source, transforms), + src_ref=unwrapped_source, dst_ref=store_op.destination, gmem_slice=tuple(gmem_slice), swizzle=swizzle, @@ -981,18 +1026,20 @@ def _mgpu_wgmma_op_lowering_rule( if ir.VectorType.isinstance(wgmma_op.a.type): a_transforms = None b_transforms = inference_utils.in_transforms(wgmma_op)[0] + unwrapped_a_ref = None + unwrapped_b_ref = unwrap_transformed_memref(wgmma_op.b, b_transforms) else: a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op) + unwrapped_a_ref = unwrap_transformed_memref(wgmma_op.a, a_transforms) + unwrapped_b_ref = unwrap_transformed_memref(wgmma_op.b, b_transforms) b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr( b_transforms ) minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle - ref_ty = ir.MemRefType(wgmma_op.b.type) _check_transforms_and_swizzle_are_supported( - ref_ty, b_transforms, b_swizzle, minimum_swizzle + ir.MemRefType(wgmma_op.b.type), b_transforms, b_swizzle, minimum_swizzle ) - b_operand = reinterpret_smem_ref(wgmma_op.b, b_transforms) if ir.VectorType.isinstance(wgmma_op.a.type): a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) @@ -1000,18 +1047,17 @@ def _mgpu_wgmma_op_lowering_rule( a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr( a_transforms ) - ref_ty = ir.MemRefType(wgmma_op.a.type) _check_transforms_and_swizzle_are_supported( - ref_ty, a_transforms, a_swizzle, minimum_swizzle + ir.MemRefType(wgmma_op.a.type), a_transforms, a_swizzle, minimum_swizzle ) if a_swizzle != b_swizzle: raise ValueError( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) - a_operand = reinterpret_smem_ref(wgmma_op.a, a_transforms) + a_operand = unwrapped_a_ref - new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) + new_acc = wgmma.wgmma(acc, a_operand, unwrapped_b_ref, swizzle=b_swizzle) return [ _fragmented_array_to_ir( @@ -1062,7 +1108,19 @@ def _mgpu_slice_smem_op_lowering_rule( ctx: LoweringContext, op: mgpu.SliceSMEMOp ) -> Sequence[ir.Value]: del ctx - return [_slice_smem(op.result.type, op.offset)] + sliced_ref = _slice_smem(op.result.type, op.offset) + + memref_ty = ir.MemRefType(sliced_ref.type) + if memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier"): + # Barrier memrefs are not transformed and must not be wrapped. + assert not inference_utils.has_out_transforms_set(op) + return [sliced_ref] + + out_transforms = inference_utils.out_transforms(op)[0] + _, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + transformed_ref = reinterpret_smem_ref(sliced_ref, transforms) + wrapped_ref = wrap_transformed_memref(transformed_ref, op.result.type, out_transforms) + return [wrapped_ref] def _slice_smem(result: ir.Type, offset: ir.Value): @@ -1085,6 +1143,274 @@ def _slice_smem(result: ir.Type, offset: ir.Value): return builtin.unrealized_conversion_cast([result], [view]) +def _tile_transform_offsets( + tiling: Sequence[int], + static_offsets: Sequence[int], + dynamic_offsets: Sequence[ir.Value], +) -> tuple[Sequence[int], Sequence[ir.Value]]: + """Computes the static and dynamic offsets after the given tiling is applied. + + Conceptually, this function is analogous to + tile.transform_shape(static_offsets), except that it also handles dynamic offsets. + """ + dynamic_offset_index = 0 + new_static_offsets = [] + new_dynamic_offsets = [] + + # Preserve all offsets in non-tiled dimensions. 
+  for offset in static_offsets[: -len(tiling)]:
+    new_static_offsets.append(offset)
+    if offset == ir.ShapedType.get_dynamic_stride_or_offset():
+      new_dynamic_offsets.append(dynamic_offsets[dynamic_offset_index])
+      dynamic_offset_index += 1
+
+  # Compute static and dynamic offsets of tiled dimensions.
+  for tile_size, offset in zip(
+      tiling, static_offsets[-len(tiling) :], strict=True
+  ):
+    if offset == ir.ShapedType.get_dynamic_stride_or_offset():
+      # Here we assume that the offset is divisible by the tile size, but we
+      # don't check it. This has been established at the time the tiling was
+      # inferred.
+      dyn_offset = arith.divui(
+          dynamic_offsets[dynamic_offset_index],
+          utils.c(tile_size, ir.IndexType.get()),
+      )
+      new_dynamic_offsets.append(dyn_offset)
+      new_static_offsets.append(ir.ShapedType.get_dynamic_stride_or_offset())
+      dynamic_offset_index += 1
+    else:
+      assert offset % tile_size == 0
+      new_static_offsets.append(offset // tile_size)
+
+  # Add 0 offsets for the newly created dimension of the tile.
+  new_static_offsets += [0] * len(tiling)
+
+  return new_static_offsets, new_dynamic_offsets
+
+
+@_register_lowering(memref.SubViewOp)
+def _memref_subview_op_lowering_rule(
+    ctx: LoweringContext, op: memref.SubViewOp
+) -> Sequence[ir.Value]:
+  del ctx
+
+  in_transforms = inference_utils.in_transforms(op)[0]
+  out_transforms = inference_utils.out_transforms(op)[0]
+
+  if in_transforms != out_transforms:
+    raise NotImplementedError(
+        "SubViewOp transforms for the input and output refs must be identical."
+    )
+
+  if any(s != 1 for s in op.static_strides):
+    raise NotImplementedError(
+        "SubViewOp only supports static strides of 1."
+    )
+
+  if _is_memref_transposed(op.source.type):
+    raise NotImplementedError(
+        "SubViewOp does not support transposed memrefs."
+    )
+
+  unwrapped_source_ref = unwrap_transformed_memref(op.source, in_transforms)
+  swizzle, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms)
+  if swizzle != mgpu.SwizzlingMode.kNoSwizzle:
+    source_ty = ir.MemRefType(op.source.type)
+    source_strides, _ = source_ty.get_strides_and_offset()
+    for stride, slice, size in zip(source_strides, op.static_sizes, source_ty.shape, strict=True):
+      if stride != 1:
+        continue
+      # A dimension with stride 1 is a minor dimension and is swizzled.
+      if slice != size:
+        raise NotImplementedError("Slicing a swizzled dimension is unsupported.")
+
+  match transforms:
+    case ():
+      new_subview_op = memref.SubViewOp(
+          op.result.type,
+          unwrapped_source_ref,
+          op.offsets,
+          None,
+          None,
+          static_offsets=op.static_offsets,
+          static_sizes=op.static_sizes,
+          static_strides=op.static_strides,
+      )
+    case (tile_transform, ) if isinstance(tile_transform, launch_context.TileTransform):
+      in_transformed_ty = ir.MemRefType(unwrapped_source_ref.type)
+      tiling = tile_transform.tiling
+      if any(
+          ir.ShapedType.is_dynamic_size(s)
+          for s in list(op.static_sizes)[-len(tiling) :]
+      ):
+        raise NotImplementedError(
+            "SubViewOp only supports static sizes for the tiled dimensions."
+ ) + new_sizes = tile_transform.transform_shape(list(op.static_sizes)) + new_static_offsets, new_dynamic_offsets = _tile_transform_offsets( + tiling, list(op.static_offsets), list(op.offsets) + ) + + new_subview_op = memref.SubViewOp( + _transformed_smem_ref_type(op.result.type, transforms), + unwrapped_source_ref, + new_dynamic_offsets, + None, + None, + static_offsets=new_static_offsets, + static_sizes=new_sizes, + static_strides=[1] * len(in_transformed_ty.shape), + ) + case _: + raise NotImplementedError( + "SubViewOp only supports a single tile transform." + ) + + wrapped_ref = wrap_transformed_memref( + new_subview_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +@_register_lowering(memref.CastOp) +def _memref_cast_op_lowering_rule( + ctx: LoweringContext, op: memref.CastOp +) -> Sequence[ir.Value]: + """Lowering rule for memref.CastOp. + Only casts that add a dynamic offset are supported. + """ + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + out_transforms = inference_utils.out_transforms(op)[0] + if in_transforms != out_transforms: + raise NotImplementedError( + "CastOp transforms for the input and output refs must be identical." + ) + + in_ty = ir.MemRefType(op.source.type) + out_ty = ir.MemRefType(op.result.type) + if in_ty.element_type != out_ty.element_type: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same element type." + ) + if in_ty.shape != out_ty.shape: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same shape." + ) + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, out_offset = out_ty.get_strides_and_offset() + if in_strides != out_strides: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same strides." + ) + + unwrapped_source_ref = unwrap_transformed_memref(op.source, in_transforms) + in_transformed_ty = ir.MemRefType(unwrapped_source_ref.type) + transformed_strides, _ = in_transformed_ty.get_strides_and_offset() + out_layout = ir.StridedLayoutAttr.get(out_offset, transformed_strides) + out_transformed_ty = ir.MemRefType.get( + in_transformed_ty.shape, + in_transformed_ty.element_type, + memory_space=in_transformed_ty.memory_space, + layout=out_layout, + ) + new_cast_op = memref.CastOp(out_transformed_ty, unwrapped_source_ref) + wrapped_ref = wrap_transformed_memref( + new_cast_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +def _permutation_to_affine_map_attr( + permutation: Sequence[int], +) -> ir.AffineMapAttr: + return ir.AffineMapAttr.get(ir.AffineMap.get_permutation(permutation)) + + +@_register_lowering(memref.TransposeOp) +def _memref_transpose_op_lowering_rule( + ctx: LoweringContext, op: memref.TransposeOp +) -> Sequence[ir.Value]: + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + unwrapped_in_ref = unwrap_transformed_memref(op.in_, in_transforms) + in_transformed_ty = ir.MemRefType(unwrapped_in_ref.type) + if len(in_transformed_ty.shape) == 2: + new_permutation = op.permutation + elif len(in_transformed_ty.shape) == 4: + if op.permutation == _permutation_to_affine_map_attr([0, 1]): + new_permutation = _permutation_to_affine_map_attr([0, 1, 2, 3]) + elif op.permutation == _permutation_to_affine_map_attr([1, 0]): + new_permutation = _permutation_to_affine_map_attr([1, 0, 3, 2]) + else: + raise NotImplementedError("Unsupported permutation.") + else: + raise NotImplementedError( + "TransposeOp only supports transposing 2D and 4D memrefs." 
+    )
+
+  out_transforms = inference_utils.out_transforms(op)[0]
+  _, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms)
+  new_transpose_op = memref.TransposeOp(
+      _transformed_smem_ref_type(op.result.type, transforms),
+      unwrapped_in_ref,
+      new_permutation,
+  )
+
+  wrapped_ref = wrap_transformed_memref(
+      new_transpose_op.result, op.result.type, out_transforms
+  )
+  return [wrapped_ref]
+
+
+@_register_lowering(memref.LoadOp)
+def _memref_load_op_lowering_rule(
+    ctx: LoweringContext, op: memref.LoadOp
+) -> Sequence[ir.Value]:
+  """Lowering rule for memref.LoadOp.
+
+  Loads are never transformed so this rule is mostly just a pass-through.
+  """
+  del ctx
+
+  in_transforms = inference_utils.in_transforms(op)[0]
+  if in_transforms:
+    raise NotImplementedError(f"memref.LoadOp does not support transforms: {op}")
+
+  new_load_op = memref.LoadOp(
+      memref=unwrap_transformed_memref(op.memref, in_transforms),
+      indices=op.indices,
+      nontemporal=op.nontemporal,
+  )
+  return [new_load_op.result]
+
+
+@_register_lowering(memref.StoreOp)
+def _memref_store_op_lowering_rule(
+    ctx: LoweringContext, op: memref.StoreOp
+) -> Sequence[ir.Value]:
+  """Lowering rule for memref.StoreOp.
+
+  Stores are never transformed so this rule is mostly just a pass-through.
+  """
+  del ctx
+
+  in_transforms = inference_utils.in_transforms(op)[0]
+  if in_transforms:
+    raise NotImplementedError(f"memref.StoreOp does not support transforms: {op}")
+
+  memref.StoreOp(
+      value=op.value,
+      memref=unwrap_transformed_memref(op.memref, in_transforms),
+      indices=op.indices,
+      nontemporal=op.nontemporal,
+  )
+  return []
+
+
 # The metadata needed to reconstruct a vector from its flattened representation.
 _VectorTemplate = tuple[Sequence[int], fa.FragmentedLayout, ir.VectorType]
@@ -1392,6 +1718,7 @@ def _should_lower(op: ir.OpView) -> bool:
   return (
       op.OPERATION_NAME.startswith("mosaic_gpu.")  # pytype: disable=attribute-error
       or inference_utils.should_have_layout(op)
+      or inference_utils.should_have_transforms(op)
       or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
   )
diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py
index 46d04026d588..19cef27305d2 100644
--- a/jax/experimental/mosaic/gpu/transform_inference.py
+++ b/jax/experimental/mosaic/gpu/transform_inference.py
@@ -28,7 +28,6 @@
 from jax._src.lib.mlir.dialects import gpu
 from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import vector
-from jax._src.util import safe_zip

 from . import fragmented_array as fa
 from . import inference_utils
@@ -292,15 +291,7 @@ def _infer_memref_subview_transforms(
     )

   # Check tile transform propagation.
-  num_tiled_axes = len(mgpu.TileTransformAttr(tile_transform).tiling)
-  last_n_dims = op.source.type.shape[-num_tiled_axes:]
-  last_n_sizes = list(op.static_sizes)[-num_tiled_axes:]
-  for slice_size, dim_size in safe_zip(last_n_sizes, last_n_dims):
-    if slice_size != dim_size:
-      raise NotImplementedError(
-          "Tile transforms are only propagated if the tiled axes are not "
-          "sliced."
-      )
+  # TODO(dasenov): implement more precise checks.
   return [transforms], [transforms]
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
index 5fc6bc8c703c..c297cb676f2d 100644
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -936,8 +936,11 @@ def test_lowering_slice_smem_op(self):
     def body():
       nonlocal offset
       i32 = ir.IntegerType.get_signless(32)
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      memref_ty = ir.MemRefType.get((4, 32), i32, memory_space=smem)
       offset = arith.constant(i32, shift)
-      mgpu.dialect.slice_smem(i32, offset)
+      op = mgpu.dialect.SliceSMEMOp(memref_ty, offset)
+      op.attributes["out_transforms"] = ir.ArrayAttr.get([ir.ArrayAttr.get([])])

     with ir.InsertionPoint(self.module.body):
       func.FuncOp.from_py_func()(body)
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index c8d2a5e3a291..c2a3d9a19370 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -3430,6 +3430,101 @@ def body(ctx, result_gmem_ref, smem):
     x = np.full(shape, element_value, dtype=dtype)
     self.assertArraysEqual(jax.jit(kernel)(), x)

+  def test_subview(self):
+    full_shape = (2, 3, 128, 64)
+    offsets = [1, 0, 96, 0]
+    sub_shape = (32, 64)
+
+    def body(
+        ctx: launch_context.LaunchContext,
+        full_gmem_ref: ir.Value,
+        sub_gmem_ref: ir.Value,
+        smem: list[ir.Value],
+    ):
+      # TODO(dasenov): Add a parametrization to also test subview of transformed
+      # refs.
+
+      del ctx
+      full_smem_ref, tma_barrier = smem
+      dialect_barrier = tma_barrier.as_barrier_memref()
+
+      operand_elt_type = ir.MemRefType(full_gmem_ref.type).element_type
+      mgpu_dialect.arrive_expect_tx(
+          barrier=dialect_barrier,
+          expect_tx=utils.bytewidth(operand_elt_type) * math.prod(full_shape),
+      )
+
+      zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
+      # GMEM -> SMEM
+      mgpu_dialect.async_load(
+          source=full_gmem_ref,
+          destination=full_smem_ref,
+          barrier=dialect_barrier,
+          indices=[zero_i32] * len(full_shape),
+          slice_lengths=full_shape,
+          collective=ir.ArrayAttr.get([]),
+      )
+
+      parities = memref.load(tma_barrier.barrier_ref.phases, [])
+      parity, _ = tma_barrier.update_parities(parities)
+      mgpu_dialect.wait(dialect_barrier, parity)
+
+      # SubView
+      dynamic_offsets = [
+          arith.constant(ir.IndexType.get(), offsets[0]),
+          # offsets[1] is a static offset.
+ arith.constant(ir.IndexType.get(), offsets[2]), + ] + + full_ref_type = ir.MemRefType(full_smem_ref.type) + dynamic = ir.ShapedType.get_dynamic_size() + rhs_subview_ref_type = ir.MemRefType.get( + shape=sub_shape, + element_type=full_ref_type.element_type, + layout=ir.StridedLayoutAttr.get(dynamic, [full_shape[-1], 1]), + memory_space=full_ref_type.memory_space, + ) + sub_smem_ref = memref.SubViewOp( + result=rhs_subview_ref_type, + source=full_smem_ref, + offsets=dynamic_offsets, + sizes=None, + strides=None, + static_offsets=[dynamic, offsets[1], dynamic, offsets[3]], + static_sizes=[1, 1] + list(sub_shape), + static_strides=[1, 1, 1, 1], + ).result + + # SMEM -> GMEM + mgpu_dialect.async_store( + source=sub_smem_ref, + destination=sub_gmem_ref, + indices=[zero_i32, zero_i32], + slice_lengths=sub_shape, + ) + nvvm.cp_async_bulk_wait_group(0) + + el_type = jnp.bfloat16 + full_jax_shape = jax.ShapeDtypeStruct(full_shape, el_type) + result_jax_shape = jax.ShapeDtypeStruct(sub_shape, el_type) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(full_jax_shape), + out_shape=result_jax_shape, + smem_scratch_shape=[full_jax_shape, core.TMABarrier(1)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + prng_key = jax.random.key(1234) + x = jax.random.randint(prng_key, full_shape, 0, 10).astype(el_type) + + self.assertArraysEqual( + jax.jit(kernel)(x), + x[offsets[0], offsets[1], offsets[2] :, offsets[3] :], + ) + class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): From fc8192cacb146fe0be524ab04d2b755f4444eb78 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 13 Jun 2025 03:45:12 -0700 Subject: [PATCH 1672/1769] [Mosaic GPU] Add a Mosaic GPU op `with_transforms` for manually setting memref transforms. PiperOrigin-RevId: 771026094 --- .../mosaic/gpu/dialect_lowering.py | 21 +++++++++++++++++++ .../mosaic/gpu/transform_inference.py | 11 ++++++++++ jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 16 ++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 296dbc118477..c01038d64088 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -22,6 +22,7 @@ import operator from typing import Any, Sequence, Type, cast +from jax._src import lib as jaxlib from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir @@ -1143,6 +1144,26 @@ def _slice_smem(result: ir.Type, offset: ir.Value): return builtin.unrealized_conversion_cast([result], [view]) +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2. +if jaxlib.version >= (0, 6, 2): + @_register_lowering(mgpu.WithTransformsOp) + def _mgpu_with_transforms_op_lowering_rule( + ctx: LoweringContext, op: mgpu.WithTransformsOp + ) -> Sequence[ir.Value]: + """Lowering rule for mgpu.WithTransformsOp. + This is a noop that simply returns its input. 
+    """
+    del ctx
+
+    [in_transforms] = inference_utils.in_transforms(op)
+    unwrapped_source_ref = unwrap_transformed_memref(op.ref, in_transforms)
+    out_transforms = inference_utils.out_transforms(op)[0]
+    wrapped_ref = wrap_transformed_memref(
+        unwrapped_source_ref, op.result.type, out_transforms
+    )
+    return [wrapped_ref]
+
+
 def _tile_transform_offsets(
     tiling: Sequence[int],
     static_offsets: Sequence[int],
diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py
index 19cef27305d2..c8a4afd618c2 100644
--- a/jax/experimental/mosaic/gpu/transform_inference.py
+++ b/jax/experimental/mosaic/gpu/transform_inference.py
@@ -22,6 +22,7 @@
 from functools import partial
 from typing import cast

+from jax._src import lib as jaxlib
 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
@@ -344,6 +345,16 @@ def _infer_memref_cast_transforms(
   return [transforms], [transforms]


+# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2.
+if jaxlib.version >= (0, 6, 2):
+  @partial(_add_transform_inference_rule, mgpu.WithTransformsOp)
+  def _infer_mgpu_with_transforms_transforms(
+      op: mgpu.WithTransformsOp,
+  ) -> OptionalTransforms:
+    # Do not change the manually provided transforms.
+    return [op.transforms], [op.transforms]
+
+
 def infer_transforms(module: ir.Module):
   """Infers transforms for the given module.

diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 217bf1a3593b..1cf8ec11ae66 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -532,4 +532,20 @@ def MosaicGPU_CustomPrimitiveOp : Op
   let hasVerifier = 1;
 }

+def MosaicGPU_WithTransformsOp : Op<MosaicGPU_Dialect, "with_transforms", [Pure]> {
+  let summary = "A noop that allows manually setting transforms on a memref.";
+  let description = [{
+    This op enforces the provided transforms on the parameter memref.
+  }];
+
+  let arguments = (
+    ins MemRefOf<[AnyType]>:$ref,
+    // Attributes
+    ArrayAttr:$transforms
+  );
+
+  let results = (outs MemRefOf<[AnyType]>);
+}
+
 #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_

From a4f0e402edbee402bbe8a5db18bcad2d2485ab9e Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov" 
Date: Fri, 13 Jun 2025 04:05:24 -0700
Subject: [PATCH 1673/1769] [Mosaic GPU] Resolve different tile transforms
 using the largest common divisor.

Before this change, transform inference would fail if different tilings
needed to be used together (e.g. `(4, 64)` and `(8, 32)`). After this change
such tilings are resolved to use a tiling compatible with both, by taking the
largest common divisor of each dimension. In the case above, the final tiling
would be `(4, 32)`.

This change also adds two additional traversals in the transform inference,
as this was necessary in a real-world example of tile resolution (taken from
a ragged dot kernel).

The new test exercises both changes.

This CL also moves `is_known_divisible` from pallas to mgpu and extends it.
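A minimal sketch of the resolution rule in plain Python (the helper name is
illustrative; the real logic lives in `_resolve_transforms` below and operates
on `TileTransformAttr`s):

    import math

    def resolve_tilings(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
      # Conflicting tilings of equal rank are reconciled dimension by
      # dimension with the greatest common divisor, so the result divides both.
      if len(a) != len(b):
        raise ValueError(f"Conflicting tile ranks: {a} vs {b}")
      return tuple(math.gcd(x, y) for x, y in zip(a, b))

    assert resolve_tilings((4, 64), (8, 32)) == (4, 32)  # the example above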
PiperOrigin-RevId: 771032093 --- jax/_src/pallas/mosaic_gpu/core.py | 26 +--- jax/experimental/mosaic/gpu/__init__.py | 1 + .../mosaic/gpu/transform_inference.py | 84 +++++++++--- jax/experimental/mosaic/gpu/utils.py | 40 ++++++ tests/mosaic/gpu_transform_inference_test.py | 122 ++++++++++++++++++ 5 files changed, 232 insertions(+), 41 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 11f3f1eb4592..7c27ae429114 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -212,30 +212,6 @@ def cmap_body(): return wrapper -def _is_known_divisible(value, divisor, fuel=10) -> bool: - """Returns True if the value is statically known to be divisible by the divisor.""" - if divisor == 1: - return True - if fuel < 0: - return False - if not isinstance(value.owner, ir.Operation): - return False - def_op = value.owner.opview - match def_op: - case arith_dialect.IndexCastOp(): - return _is_known_divisible(value.owner.operands[0], divisor, fuel - 1) - case arith_dialect.ConstantOp(): - return ir.IntegerAttr(def_op.value).value % divisor == 0 - case arith_dialect.MulIOp(): - return (_is_known_divisible(value.owner.operands[0], divisor, fuel // 2) or - _is_known_divisible(value.owner.operands[1], divisor, (fuel + 1)// 2)) - case arith_dialect.SelectOp(): - return (_is_known_divisible(value.owner.operands[1], divisor, fuel // 2) and - _is_known_divisible(value.owner.operands[2], divisor, (fuel + 1)// 2)) - - return False - - @dataclasses.dataclass(frozen=True) class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () @@ -503,7 +479,7 @@ def untransform_index( f" tiling ({tile})" ) if isinstance(idx.base, ir.Value): - if not _is_known_divisible(idx.base, tile): + if not mgpu_utils.is_known_divisible(idx.base, tile): raise ValueError( "Dynamic slice base index (which is a dynamic value) cannot be" f" statically proven to be divisible by the tiling ({tile})" diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index dbee817df20c..7094fc7352d3 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -86,6 +86,7 @@ debug_print as debug_print, ds as ds, fori as fori, + is_known_divisible as is_known_divisible, memref_fold as memref_fold, memref_slice as memref_slice, memref_reshape as memref_reshape, diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index c8a4afd618c2..a39ba28995ee 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -20,6 +20,7 @@ from collections.abc import Callable from functools import partial +import math from typing import cast from jax._src import lib as jaxlib @@ -82,12 +83,28 @@ def _resolve_transforms( if other_transforms is None: return transforms - if transforms != other_transforms: + if len(transforms) != len(other_transforms): raise NotImplementedError( f"Conflicting transforms {transforms} != {other_transforms}." 
) - return transforms + new_transforms = [] + for a, b in zip(transforms, other_transforms, strict=True): + if a == b: + new_transforms.append(a) + elif mgpu.TileTransformAttr.isinstance(a) and mgpu.TileTransformAttr.isinstance(b): + a = mgpu.TileTransformAttr(a) + b = mgpu.TileTransformAttr(b) + if len(a.tiling) != len(b.tiling): + raise ValueError(f"Conflicting tile transforms {a} != {b}.") + new_tiling = [] + for tile_a, tile_b in zip(a.tiling, b.tiling): + new_tiling.append(math.gcd(tile_a, tile_b)) + new_transforms.append(mgpu.TileTransformAttr.get(new_tiling)) + else: + raise NotImplementedError(f"Unsupported transforms {a} and {b}") + + return ir.ArrayAttr.get(new_transforms) def _transforms_from_uses(op: ir.OpView) -> ir.Attribute | None: @@ -280,7 +297,7 @@ def _infer_memref_subview_transforms( # - We only propagate transforms if they consist of a single tile transform # and a single swizzle transform. # TODO(bchetioui): implement more complex propagation rules. - tile_transform, _ = _get_tile_and_swizzle_transforms(transforms) + tile_transform, swizzle_transform = _get_tile_and_swizzle_transforms(transforms) # Check swizzle transform propagation. strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type) @@ -292,9 +309,41 @@ def _infer_memref_subview_transforms( ) # Check tile transform propagation. - # TODO(dasenov): implement more precise checks. + old_tiling = mgpu.TileTransformAttr(tile_transform).tiling + num_tiled_axes = len(old_tiling) + last_n_dims = op.source.type.shape[-num_tiled_axes:] + last_n_sizes = list(op.static_sizes)[-num_tiled_axes:] + last_n_offsets = list(op.static_offsets)[-num_tiled_axes:] - return [transforms], [transforms] + if any(ir.ShapedType.is_dynamic_size(x) for x in last_n_sizes): + raise NotImplementedError( + "Subview transforms with dynamic sizes are not supported." + ) + + dynamic_index = 0 + for i in range(len(last_n_offsets)): + if ir.ShapedType.is_dynamic_size(last_n_offsets[i]): + if utils.is_known_divisible( + op.offsets[dynamic_index], last_n_sizes[i] + ): + last_n_offsets[i] = last_n_sizes[i] + else: + # This will force a tiling of 1 along this axis. This is a safe choice + # (since we couldn't infer a better one) but might not be optimal. + last_n_offsets[i] = 1 + dynamic_index += 1 + + new_tiling = [ + math.gcd(*xs) + for xs in zip( + last_n_sizes, last_n_dims, last_n_offsets, old_tiling, strict=True + ) + ] + + new_transforms = ir.ArrayAttr.get( + [mgpu.TileTransformAttr.get(new_tiling), swizzle_transform] + ) + return [new_transforms], [new_transforms] @partial(_add_transform_inference_rule, memref.TransposeOp) @@ -381,17 +430,20 @@ def inference_step(op: ir.Operation): _set_transform_attributes(op, *maybe_transforms) - # It's enough to do a single backwards propagation (starting from vector - # users), and then a single forward propagation (to feed into the async loads - # and stores). - for op in module.body: - inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.BACKWARDS - ) - for op in module.body: - inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.FORWARD - ) + # We alternate a few backwards propagation (starting from vector users), and + # forward propagation (to feed into the async loads and stores) passes in + # order to enable more complex inference situations. + # + # TODO(bchetioui): Replace this with a more generic inference. 
+      [mgpu.TileTransformAttr.get(new_tiling), swizzle_transform]
+  )
+  return [new_transforms], [new_transforms]


 @partial(_add_transform_inference_rule, memref.TransposeOp)
@@ -381,17 +430,20 @@ def inference_step(op: ir.Operation):

     _set_transform_attributes(op, *maybe_transforms)

-  # It's enough to do a single backwards propagation (starting from vector
-  # users), and then a single forward propagation (to feed into the async loads
-  # and stores).
-  for op in module.body:
-    inference_utils.traverse_op(
-        op, inference_step, inference_utils.TraversalOrder.BACKWARDS
-    )
-  for op in module.body:
-    inference_utils.traverse_op(
-        op, inference_step, inference_utils.TraversalOrder.FORWARD
-    )
+  # We alternate a few backwards propagation (starting from vector users), and
+  # forward propagation (to feed into the async loads and stores) passes in
+  # order to enable more complex inference situations.
+  #
+  # TODO(bchetioui): Replace this with a more generic inference.
+  inference_passes = [
+      inference_utils.TraversalOrder.BACKWARDS,
+      inference_utils.TraversalOrder.FORWARD,
+      inference_utils.TraversalOrder.BACKWARDS,
+      inference_utils.TraversalOrder.FORWARD,
+  ]
+  for traversal_order in inference_passes:
+    for op in module.body:
+      inference_utils.traverse_op(op, inference_step, traversal_order)

   # All ops that should have transforms but have no transforms inferred so far
   # are assigned an empty sets of transforms. E.g., this happens in kernels with
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 9ea964c98b86..678aad0c91c9 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -1498,3 +1498,43 @@ def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
       result = vector.insertelement(elem, result, position=c(offset + i, index))
     offset += vty.shape[0]
   return result
+
+
+def is_known_divisible(value, divisor, max_depth=10) -> bool:
+  """Returns True if the value is statically known to be divisible by the divisor."""
+  if divisor == 1:
+    return True
+  if max_depth < 0 or not isinstance(value.owner, ir.Operation):
+    return False
+
+  new_depth = max_depth - 1
+  def_op = value.owner.opview
+
+  match def_op:
+    case arith.IndexCastOp():
+      return is_known_divisible(value.owner.operands[0], divisor, max_depth - 1)
+    case arith.ConstantOp():
+      return ir.IntegerAttr(def_op.value).value % divisor == 0
+    case arith.MulIOp():
+      # Only cover the case where one operand is divisible. It's still possible
+      # that the final product is divisible, but we don't check that here.
+      return (is_known_divisible(value.owner.operands[0], divisor, new_depth) or
+              is_known_divisible(value.owner.operands[1], divisor, new_depth))
+    case arith.SelectOp():
+      return (is_known_divisible(value.owner.operands[1], divisor, new_depth) and
+              is_known_divisible(value.owner.operands[2], divisor, new_depth))
+    case arith.MaxSIOp() | arith.MinSIOp() | arith.MaxUIOp() | arith.MinUIOp():
+      return (is_known_divisible(value.owner.operands[0], divisor, new_depth) and
+              is_known_divisible(value.owner.operands[1], divisor, new_depth))
+    case arith.AddIOp() | arith.SubIOp():
+      # Only cover the common case where both operands are divisible.
+      return (is_known_divisible(value.owner.operands[0], divisor, new_depth) and
+              is_known_divisible(value.owner.operands[1], divisor, new_depth))
+    case arith.AndIOp():
+      # Only cover the specific case where the divisor is a power of two.
+      return divisor.bit_count() == 1 and (
+          is_known_divisible(value.owner.operands[0], divisor, new_depth)
+          or is_known_divisible(value.owner.operands[1], divisor, new_depth)
+      )
+
+  return False
diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py
index 420d735fef54..b4a41a141bf6 100644
--- a/tests/mosaic/gpu_transform_inference_test.py
+++ b/tests/mosaic/gpu_transform_inference_test.py
@@ -19,6 +19,7 @@
 from absl.testing import parameterized
 import jax
 from jax import numpy as jnp
+from jax._src import lib as jaxlib
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir as mlir_interpreter
@@ -543,6 +544,127 @@ def body(in_ref):
     with self.assertRaises(NotImplementedError):
       mgpu.infer_transforms(self.module)

+  @parameterized.parameters([False, True])
+  def test_infer_transforms_for_sibling_subviews_and_distant_op(
+      self, even_offsets
+  ):
+    # This test uses the following op tree extracted from this ragged dot
+    # kernel:
+    # https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py
+    #
+    # subview_op0 (slice = 64, 64)
+    # - subview_op1 (slice = 2, 64)
+    # - subview_op2 (slice = 4, 64, either at an even or odd offset)
+    # - subview_op3 (slice = 8, 64)
+    # - user_op0 (in_transforms = [tile(64, 64), swizzle(32)])
+    #
+    # First the in_transforms of user_op0 have to be propagated up to
+    # subview_op0. Then they have to be propagated down and resolved. Finally
+    # all subview ops need to have the same transforms.

+    # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2.
+    if jaxlib.version < (0, 6, 2):
+      self.skipTest("Test requires jaxlib version >= 0.6.2")
+
+    subview_op0, subview_op1, subview_op2, subview_op3 = None, None, None, None
+    user_op0 = None
+
+    source_shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    source_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=smem)
+
+    slice1_shape = (2, 64)
+    slice2_shape = (4, 64)
+    slice3_shape = (8, 64)
+
+    slice0_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=smem)
+    slice1_ref_ty = ir.MemRefType.get(slice1_shape, elt_ty, memory_space=smem)
+    slice2_ref_ty = ir.MemRefType.get(slice2_shape, elt_ty, memory_space=smem)
+    slice3_ref_ty = ir.MemRefType.get(slice3_shape, elt_ty, memory_space=smem)
+
+    def body(source_ref):
+      nonlocal subview_op0, subview_op1, subview_op2, subview_op3, user_op0
+
+      subview_op0 = memref.SubViewOp(
+          slice0_ref_ty,
+          source_ref,
+          [],  # dynamic offsets
+          [],  # dynamic sizes
+          [],  # dynamic strides
+          static_offsets=[0, 0],
+          static_sizes=source_shape,
+          static_strides=[1, 1],
+      )
+
+      transforms_0 = ir.ArrayAttr.get([
+          mgpu.dialect.TileTransformAttr.get((64, 64)),
+          mgpu.dialect.SwizzleTransformAttr.get(32),
+      ])
+      user_op0 = mgpu.dialect.WithTransformsOp(subview_op0.result, transforms_0)
+
+      subview_op1 = memref.SubViewOp(
+          slice1_ref_ty,
+          subview_op0,
+          [],  # dynamic offsets
+          [],  # dynamic sizes
+          [],  # dynamic strides
+          static_offsets=[0, 0],
+          static_sizes=slice1_shape,
+          static_strides=[1, 1],
+      )
+
+      subview_op2 = memref.SubViewOp(
+          slice2_ref_ty,
+          subview_op0,
+          [],  # dynamic offsets
+          [],  # dynamic sizes
+          [],  # dynamic strides
+          static_offsets=[16 if even_offsets else 15, 0],
+          static_sizes=slice2_shape,
+          static_strides=[1, 1],
+      )
+
+      # The following ops are just to test the dynamic offsets support.
+ c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + c64 = c(64) + c32 = c(32) + c16 = c(16) + subi = arith.subi(c64, c32) + maxsi = arith.maxsi(c16, subi) + addi = arith.addi(maxsi, subi) + andi = arith.andi(addi, maxsi) + idx = arith.index_cast(ir.IndexType.get(), andi) + subview_op3 = memref.SubViewOp( + slice3_ref_ty, + subview_op0, + [idx], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[ir.ShapedType.get_dynamic_size(), 0], + static_sizes=slice3_shape, + static_strides=[1, 1], + ) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func(source_ref_ty)(body) + + mgpu.infer_transforms(self.module) + + want = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((2 if even_offsets else 1, 64)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + self.assertSequenceEqual(inference_utils.in_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op3), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op3), [want]) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 70c90a95cd616cacfefddeb9c902e49360caeeea Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 13 Jun 2025 04:43:59 -0700 Subject: [PATCH 1674/1769] [Mosaic GPU] Use warpgroup semantics for the ragged dot example kernel. I verified that the performance doesn't change. 
PiperOrigin-RevId: 771042037 --- .../pallas/ops/gpu/ragged_dot_mgpu.py | 32 ++++--------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py index 9a1514b9827c..ed23f5eb764d 100644 --- a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py @@ -114,11 +114,6 @@ def ragged_dot( raise NotImplementedError( f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}" ) - - elem_bits = jnp.finfo(lhs.dtype).bits - swizzle = _find_swizzle(elem_bits * block_k, "lhs") - swizzle_elems = swizzle * 8 // elem_bits - m, k = lhs.shape g, k2, n = rhs.shape @@ -147,22 +142,12 @@ def mn_loop(idx): # pylint: disable=unused-variable group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi) def acc_scope(acc_ref): - transforms = ( - plgpu.TilingTransform((8, swizzle_elems)), - plgpu.SwizzleTransform(swizzle), - ) plgpu.emit_pipeline( lambda _, lhs_smem, rhs_smem: plgpu.wgmma(acc_ref, lhs_smem, rhs_smem), grid=(k // block_k,), in_specs=[ - plgpu.BlockSpec( - (block_m, block_k), - lambda k: (group_info.block, k), - transforms=transforms, - ), - plgpu.BlockSpec( - (block_k, block_n), lambda k: (k, ni), transforms=transforms - ), + plgpu.BlockSpec((block_m, block_k), lambda k: (group_info.block, k)), + plgpu.BlockSpec((block_k, block_n), lambda k: (k, ni)), ], max_concurrent_steps=max_concurrent_steps, delay_release=1, @@ -171,17 +156,9 @@ def acc_scope(acc_ref): acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n))) - store_transforms = ( - plgpu.TilingTransform((1, swizzle_elems)), - plgpu.SwizzleTransform(swizzle) - ) @functools.partial( pl.run_scoped, - o_smem=plgpu.SMEM( - (block_m, block_n), - dtype=o_gmem.dtype, - transforms=store_transforms, - ) + o_smem=plgpu.SMEM((block_m, block_n), dtype=o_gmem.dtype) ) def store_scope(o_smem): # pylint: disable=unused-variable o_smem[...] = acc.astype(o_smem.dtype) @@ -249,6 +226,9 @@ def _(): out_shape=jax.ShapeDtypeStruct((m, n), lhs.dtype), grid=(num_sms,), grid_names=("sm",), + compiler_params=plgpu.CompilerParams( + lowering_semantics=plgpu.LoweringSemantics.Warpgroup, + ), ) return kernel(group_sizes, lhs, rhs) From 9b45aacb572816193124fadd4d82ed8f6247ea75 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Fri, 13 Jun 2025 05:01:30 -0700 Subject: [PATCH 1675/1769] Disable `too_slow` in data.draw() for test_ndindexer The draws are still slow under ASAN. PiperOrigin-RevId: 771046954 --- tests/pallas/indexing_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 9e96252a843b..932076645c1e 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -216,6 +216,7 @@ def test_indexer_with_all_types(self): self.assertTupleEqual(indexer.get_indexer_shape(), (2, 5, 4)) @hp.given(hps.data()) + @hp.settings(suppress_health_check=[hp.HealthCheck.too_slow]) # ASAN is slow def test_ndindexer(self, data): shape = data.draw(hnp.array_shapes()) indices = data.draw(nd_indices_strategy(shape)) From a3542f882bf1a6525ee3d992d4451c36498f5777 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 13 Jun 2025 05:04:25 -0700 Subject: [PATCH 1676/1769] [Mosaic GPU] Reconcile the swizzle of the a and b operands for wgmma in the Mosaic GPU dialect. These need to match but we didn't ensure this previously. 
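A rough sketch of the reconciliation scheme, in plain Python (swizzles are byte
widths; the helper names are illustrative stand-ins for the real inference code
below):

    def infer_swizzle(minor_dim_bytes: int, max_swizzle: int = 128) -> int:
      # Pick the widest swizzle that divides the minor dimension without
      # exceeding the cap.
      for swizzle in (128, 64, 32):
        if swizzle <= max_swizzle and minor_dim_bytes % swizzle == 0:
          return swizzle
      raise ValueError("no valid swizzle")

    def reconcile_wgmma_swizzles(a_minor_bytes: int, b_minor_bytes: int):
      b = infer_swizzle(b_minor_bytes)
      a = infer_swizzle(a_minor_bytes, max_swizzle=b)  # a must not exceed b
      if a != b:
        # a came out smaller: re-infer b under a's cap so the two match.
        b = infer_swizzle(b_minor_bytes, max_swizzle=a)
      return a, b

    assert reconcile_wgmma_swizzles(64, 128) == (64, 64)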
PiperOrigin-RevId: 771047725 --- .../mosaic/gpu/transform_inference.py | 37 ++++++++++++++----- tests/mosaic/gpu_test.py | 5 --- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index a39ba28995ee..f08027506334 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -119,7 +119,10 @@ def _transforms_from_uses(op: ir.OpView) -> ir.Attribute | None: transforms = _resolve_transforms(transforms, user_transforms) return transforms -def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: + +def _infer_transforms_for_wgmma_ref( + ref_ty: ir.MemRefType, max_swizzle: mgpu.SwizzlingMode +) -> tuple[ir.ArrayAttr, mgpu.SwizzlingMode]: if len(ref_ty.shape) != 2: raise ValueError(f"Expected a 2D memref, got {ref_ty}") @@ -136,9 +139,12 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: mgpu.SwizzlingMode.k32ByteSwizzle, mgpu.SwizzlingMode.kNoSwizzle, ]: + if swizzle > max_swizzle: + continue swizzle_elems = swizzle // element_bytewidth if minor_dim % swizzle_elems == 0: minor_tiling = swizzle_elems + inferred_swizzle = swizzle break else: # No valid tile transform can be inferred. @@ -148,19 +154,30 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: tiling = (minor_tiling, major_tiling) else: tiling = (major_tiling, minor_tiling) - return ir.ArrayAttr.get([ - mgpu.TileTransformAttr.get(tiling), - mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), - ]) + return ( + ir.ArrayAttr.get([ + mgpu.TileTransformAttr.get(tiling), + mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), + ]), + inferred_swizzle, + ) @partial(_add_transform_inference_rule, mgpu.WGMMAOp) def infer_wgmma_transforms(op: mgpu.WGMMAOp) -> OptionalTransforms: - b_transforms = infer_transforms_for_wgmma_ref(ir.MemRefType(op.b.type)) + b_transforms, b_swizzle = _infer_transforms_for_wgmma_ref( + ir.MemRefType(op.b.type), max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle + ) if ir.MemRefType.isinstance(op.a.type): - a_transforms = infer_transforms_for_wgmma_ref( - cast(ir.MemRefType, op.a.type) + a_transforms, a_swizzle = _infer_transforms_for_wgmma_ref( + cast(ir.MemRefType, op.a.type), max_swizzle=b_swizzle ) + if a_swizzle != b_swizzle: + # The swizzle for a and b has to match. + b_transforms, b_swizzle = _infer_transforms_for_wgmma_ref( + ir.MemRefType(op.b.type), max_swizzle=a_swizzle + ) + assert a_swizzle == b_swizzle return [a_transforms, b_transforms], [] return [b_transforms], [] @@ -203,8 +220,8 @@ def _infer_vector_load_store_transforms( transforms = inference_utils.value_transforms(op.base) if layout == fa.WGMMA_LAYOUT: - layout_transforms = infer_transforms_for_wgmma_ref( - ir.MemRefType(op.base.type) + layout_transforms, _ = _infer_transforms_for_wgmma_ref( + ir.MemRefType(op.base.type), max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle ) elif ( layout == fa.WGMMA_ROW_LAYOUT diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index c2a3d9a19370..2e5075453e68 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3547,11 +3547,6 @@ def test_wgmma_kernel_with_tma( if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: self.skipTest("No swizzle is not supported by wgmma") - # TODO(dasenov): This condition is wrong, remove it after we support - # reconciling the siwzzle of the a and b operands of wgmma. 
- if transpose_lhs and swizzle != mgpu_dialect.SwizzlingMode.k128ByteSwizzle: - self.skipTest("If A is transposed, its swizzle must be 128 bytes.") - if transpose_lhs and load_a_in_registers: self.skipTest("The A operand can only be transposed if it is in SMEM.") From 7b7a5d8c9a486bebb23cca7c781d4254841aa869 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Fri, 13 Jun 2025 14:34:01 +0200 Subject: [PATCH 1677/1769] Add pjit_p to extend.core.primitives --- jax/extend/core/primitives.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index 515dd3e11dcf..5b790656271c 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -223,7 +223,10 @@ schur_p as schur_p, ) -from jax._src.pjit import sharding_constraint_p as sharding_constraint_p +from jax._src.pjit import ( + pjit_p as pjit_p, + sharding_constraint_p as sharding_constraint_p, +) from jax._src.prng import ( random_bits_p as random_bits_p, From e66a6dd2f68476daed2b6827aa06e2bdb44705ff Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Fri, 13 Jun 2025 05:34:16 -0700 Subject: [PATCH 1678/1769] Fix return type annotation for tree_util.tree_broadcast. PiperOrigin-RevId: 771055164 --- jax/_src/tree.py | 2 +- jax/_src/tree_util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 5aaa0bf2b006..d1d3be41b917 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -387,7 +387,7 @@ def map_with_path( def broadcast(prefix_tree: Any, full_tree: Any, is_leaf: Callable[[Any], bool] | None = None - ) -> list[Any]: + ) -> Any: """Broadcasts a tree prefix into the full structure of a given tree. Args: diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 6054b4711e02..f0341f72c029 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -564,7 +564,7 @@ def __new__(klass, func, *args, **kw): @export def tree_broadcast(prefix_tree: Any, full_tree: Any, is_leaf: Callable[[Any], bool] | None = None - ) -> list[Any]: + ) -> Any: """Alias of :func:`jax.tree.broadcast`.""" broadcast_leaves = broadcast_prefix(prefix_tree, full_tree, is_leaf=is_leaf) return tree_structure(full_tree).unflatten(broadcast_leaves) From 193f11db4bd06ddf6b2f969605d4cd35c8f10dd0 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 13 Jun 2025 05:45:10 -0700 Subject: [PATCH 1679/1769] [Mosaic GPU] Parametrize the `test_subview` test. PiperOrigin-RevId: 771057705 --- tests/mosaic/gpu_test.py | 115 ++++++++++++++++++++++++++++----------- 1 file changed, 84 insertions(+), 31 deletions(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 2e5075453e68..7dae5c5f37b8 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -27,6 +27,7 @@ from absl.testing import absltest, parameterized import jax from jax._src import config +from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -3430,10 +3431,38 @@ def body(ctx, result_gmem_ref, smem): x = np.full(shape, element_value, dtype=dtype) self.assertArraysEqual(jax.jit(kernel)(), x) - def test_subview(self): - full_shape = (2, 3, 128, 64) - offsets = [1, 0, 96, 0] - sub_shape = (32, 64) + @parameterized.parameters( + # Positive offsets will be passsed as static offsets. + # Negative offsets will be converted to positive dynamic offsets. 
+ ((2, 3, 128, 64), (32, 64), [-1, 0, -96, 0], None, None, None), + ( + (3, 128, 64), + (32, 64), + [-2, -96, 0], + [32, 64], + mgpu_dialect.SwizzlingMode.k128ByteSwizzle, + None, + ), + ( + (128, 128), + (64,), + [-1, 64], + [64], + mgpu_dialect.SwizzlingMode.k128ByteSwizzle, + "Swizzle transforms .* if the minor dimension is unchanged.", + ), + ) + def test_subview( + self, + full_shape, + sub_shape, + offsets, + tiling, + swizzle, + error_regex, + ): + assert len(sub_shape) <= 2 + sizes = [1] * (len(full_shape) - len(sub_shape)) + list(sub_shape) def body( ctx: launch_context.LaunchContext, @@ -3441,9 +3470,6 @@ def body( sub_gmem_ref: ir.Value, smem: list[ir.Value], ): - # TODO(dasenov): Add a parametrization to also test subview of transformed - # refs. - del ctx full_smem_ref, tma_barrier = smem dialect_barrier = tma_barrier.as_barrier_memref() @@ -3471,17 +3497,17 @@ def body( # SubView dynamic_offsets = [ - arith.constant(ir.IndexType.get(), offsets[0]), - # offsets[1] is a static offset. - arith.constant(ir.IndexType.get(), offsets[2]), + arith.constant(ir.IndexType.get(), -o) for o in offsets if o < 0 ] full_ref_type = ir.MemRefType(full_smem_ref.type) - dynamic = ir.ShapedType.get_dynamic_size() + dynamic = ir.ShapedType.get_dynamic_stride_or_offset() rhs_subview_ref_type = ir.MemRefType.get( shape=sub_shape, element_type=full_ref_type.element_type, - layout=ir.StridedLayoutAttr.get(dynamic, [full_shape[-1], 1]), + layout=ir.StridedLayoutAttr.get( + dynamic, [full_shape[-1], 1] if len(sub_shape) == 2 else [1] + ), memory_space=full_ref_type.memory_space, ) sub_smem_ref = memref.SubViewOp( @@ -3490,16 +3516,32 @@ def body( offsets=dynamic_offsets, sizes=None, strides=None, - static_offsets=[dynamic, offsets[1], dynamic, offsets[3]], - static_sizes=[1, 1] + list(sub_shape), - static_strides=[1, 1, 1, 1], + static_offsets=[(dynamic if o < 0 else o) for o in offsets], + static_sizes=sizes, + static_strides=[1] * len(sizes), ).result + transforms = [] + if tiling is not None: + transforms.append(mgpu_dialect.TileTransformAttr.get(tiling)) + if swizzle is not None: + transforms.append(mgpu_dialect.SwizzleTransformAttr.get(swizzle)) + + if transforms: + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2. 
+ if jaxlib.version < (0, 6, 2): + self.skipTest("Test requires jaxlib version >= 0.6.2") + + sub_smem_ref = mgpu_dialect.with_transforms( + sub_smem_ref, + transforms=ir.ArrayAttr.get(transforms), + ) + # SMEM -> GMEM mgpu_dialect.async_store( source=sub_smem_ref, destination=sub_gmem_ref, - indices=[zero_i32, zero_i32], + indices=[zero_i32] * len(sub_shape), slice_lengths=sub_shape, ) nvvm.cp_async_bulk_wait_group(0) @@ -3507,23 +3549,34 @@ def body( el_type = jnp.bfloat16 full_jax_shape = jax.ShapeDtypeStruct(full_shape, el_type) result_jax_shape = jax.ShapeDtypeStruct(sub_shape, el_type) - kernel = mgpu.as_gpu_kernel( - body, - grid=(1, 1, 1), - block=(128, 1, 1), - in_shape=(full_jax_shape), - out_shape=result_jax_shape, - smem_scratch_shape=[full_jax_shape, core.TMABarrier(1)], - thread_semantics=mgpu.LoweringSemantics.Warpgroup, - ) - prng_key = jax.random.key(1234) - x = jax.random.randint(prng_key, full_shape, 0, 10).astype(el_type) + def create_kernel(): + return mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(full_jax_shape), + out_shape=result_jax_shape, + smem_scratch_shape=[full_jax_shape, core.TMABarrier(1)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + if error_regex: + with self.assertRaisesRegex(NotImplementedError, error_regex): + # While we expect NotImplementedError here, the test is actually + # checking a restricted behaviour that should be a ValueError. However, + # our code cannot yet figure out the difference and raise the correct + # type. + create_kernel() + else: + prng_key = jax.random.key(1234) + x = jax.random.randint(prng_key, full_shape, 0, 10).astype(el_type) - self.assertArraysEqual( - jax.jit(kernel)(x), - x[offsets[0], offsets[1], offsets[2] :, offsets[3] :], - ) + slicing = tuple(slice(abs(o), abs(o) + s) for o, s in zip(offsets, sizes)) + self.assertArraysEqual( + jax.jit(create_kernel())(x), + x[slicing].reshape(sub_shape), + ) class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): From 670ae13dacf175e4e240a2f81b5ec672a63d3a25 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 13 Jun 2025 06:07:32 -0700 Subject: [PATCH 1680/1769] Make params of several pallas primitives hashable. Params of jaxpr equations are already supposed to be hashable. The idea is that this property will become load bearing shortly via a cache on abstract eval rules that includes params as part of the cache key. PiperOrigin-RevId: 771063718 --- jax/BUILD | 5 +++ jax/_src/frozen_dict.py | 48 +++++++++++++++++++++ jax/_src/pallas/BUILD | 1 + jax/_src/pallas/mosaic/core.py | 55 ++++++++++++++++++++++-- jax/_src/pallas/mosaic_gpu/primitives.py | 4 +- jax/_src/pallas/pallas_call.py | 8 +++- jax/_src/pallas/primitives.py | 13 +++--- jax/_src/pallas/triton/primitives.py | 2 +- 8 files changed, 121 insertions(+), 15 deletions(-) create mode 100644 jax/_src/frozen_dict.py diff --git a/jax/BUILD b/jax/BUILD index 651754d89ff8..1c2d72276baa 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -901,6 +901,11 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "frozen_dict", + srcs = ["_src/frozen_dict.py"], +) + pytype_strict_library( name = "hardware_utils", srcs = ["_src/hardware_utils.py"], diff --git a/jax/_src/frozen_dict.py b/jax/_src/frozen_dict.py new file mode 100644 index 000000000000..fd01a95145f2 --- /dev/null +++ b/jax/_src/frozen_dict.py @@ -0,0 +1,48 @@ +# Copyright 2025 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterator, Mapping, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +class FrozenDict(Mapping[K, V]): + + def __init__(self, d: Mapping[K, V]): + self._d = dict(d.items()) + + def __repr__(self) -> str: + return f"FrozenDict({self._d!r})" + + def __str__(self) -> str: + return f"FrozenDict({self._d})" + + def __getitem__(self, key: K) -> V: + return self._d[key] + + def __hash__(self) -> int: + # This assumes that the values are hashable. + return hash(frozenset(self._d.items())) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FrozenDict): + return False + return self._d == other._d + + def __iter__(self) -> Iterator[K]: + return iter(self._d) + + def __len__(self) -> int: + return len(self._d) diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 91987167512c..e080c601836f 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -45,6 +45,7 @@ py_library( "//jax:core", "//jax:dtypes", "//jax:effects", + "//jax:frozen_dict", "//jax:mlir", "//jax:partial_eval", "//jax:pretty_printer", diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 2a6bdf9fa8a4..835b1cf68244 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -20,11 +20,12 @@ import dataclasses import enum import functools -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, Mapping import jax from jax._src import core as jax_core from jax._src import util +from jax._src.frozen_dict import FrozenDict from jax._src.pallas import core as pallas_core import jax.numpy as jnp import numpy as np @@ -88,8 +89,8 @@ class CompilerParams(pallas_core.CompilerParams): disable_bounds_checks: Disable bounds checks in the kernel. """ BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu" - dimension_semantics: Sequence[DimensionSemantics] | None = None - allow_input_fusion: Sequence[bool] | None = None + dimension_semantics: tuple[DimensionSemantics, ...] | None = None + allow_input_fusion: tuple[bool, ...] 
| None = None vmem_limit_bytes: int | None = None collective_id: int | None = None has_side_effects: bool = False @@ -99,9 +100,46 @@ class CompilerParams(pallas_core.CompilerParams): kernel_type: KernelType = KernelType.TC disable_bounds_checks: bool = False + def __init__( + self, + dimension_semantics: Sequence[DimensionSemantics] | None = None, + allow_input_fusion: Sequence[bool] | None = None, + vmem_limit_bytes: int | None = None, + collective_id: int | None = None, + has_side_effects: bool = False, + flags: Mapping[str, Any] | None = None, + internal_scratch_in_bytes: int | None = None, + serialization_format: int = 1, + kernel_type: KernelType = KernelType.TC, + disable_bounds_checks: bool = False, + ): + object.__setattr__( + self, + "dimension_semantics", + None if dimension_semantics is None else tuple(dimension_semantics), + ) + object.__setattr__( + self, + "allow_input_fusion", + None if allow_input_fusion is None else tuple(allow_input_fusion), + ) + object.__setattr__(self, "vmem_limit_bytes", vmem_limit_bytes) + object.__setattr__(self, "collective_id", collective_id) + object.__setattr__(self, "has_side_effects", has_side_effects) + object.__setattr__( + self, "flags", None if flags is None else FrozenDict(flags) + ) + object.__setattr__( + self, "internal_scratch_in_bytes", internal_scratch_in_bytes + ) + object.__setattr__(self, "serialization_format", serialization_format) + object.__setattr__(self, "kernel_type", kernel_type) + object.__setattr__(self, "disable_bounds_checks", disable_bounds_checks) + # Replace is a method, not a field. replace = dataclasses.replace + class MemorySpace(enum.Enum): ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. VMEM = "vmem" @@ -181,6 +219,17 @@ class TensorCoreMesh: devices: np.ndarray axis_names: Sequence[str] + def __init__(self, devices: np.ndarray, axis_names: Sequence[str]): + devices = np.copy(devices) + devices.setflags(write=False) + object.__setattr__(self, "devices", devices) + object.__setattr__(self, "axis_names", tuple(axis_names)) + + def __hash__(self) -> int: + return hash( + (self.devices.shape, tuple(np.ravel(self.devices)), self.axis_names) + ) + @property def backend(self) -> str: return "mosaic_tpu" diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index ffc9d0623fff..0c55eaefc345 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1808,7 +1808,7 @@ def _jaxpr_call_discharge( outs = jaxpr_call_p.bind( *flat_args, jaxpr=discharged_jaxpr, - ref_treedefs=ref_treedefs, + ref_treedefs=tuple(ref_treedefs), program_ids_treedef=program_ids_treedef, ) discharged_outs_it = iter(outs[len(jaxpr.outvars) :]) @@ -1861,7 +1861,7 @@ def jaxpr_call( *flat_refs, *flat_program_ids, jaxpr=jaxpr, - ref_treedefs=ref_treedefs, + ref_treedefs=tuple(ref_treedefs), program_ids_treedef=program_ids_treedef, ) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 5c6cf10dec12..2375e1c30578 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -32,6 +32,7 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import tree_util +from jax._src.frozen_dict import FrozenDict from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -1483,7 +1484,7 @@ def _rewritten_body(*args): *ref_args, *rest_args, jaxpr=new_jaxpr, - input_output_aliases=new_input_output_aliases, + 
input_output_aliases=tuple(new_input_output_aliases), grid_mapping=new_grid_mapping, mesh=mesh, debug=debug, @@ -1608,11 +1609,12 @@ def pallas_call( ) + def _normalize_compiler_params( compiler_params: Mapping[Backend, CompilerParams] | CompilerParams | None, ) -> Mapping[Backend, CompilerParams]: if compiler_params is None: - return {} + return FrozenDict({}) if isinstance(compiler_params, CompilerParams): compiler_params = {compiler_params.BACKEND: compiler_params} assert isinstance(compiler_params, Mapping) @@ -1628,6 +1630,8 @@ def _normalize_compiler_params( f"Inconsistent backend in compiler_params: {params.BACKEND} !=" f" {backend}" ) + if not isinstance(compiler_params, FrozenDict): + compiler_params = FrozenDict(compiler_params) return compiler_params diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 95ae15e5bf4e..bff0a55e272c 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -20,7 +20,7 @@ import functools import string from collections.abc import Hashable -from typing import Any, Callable +from typing import Any, Callable, Sequence import jax from jax import lax @@ -346,9 +346,9 @@ def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val): mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x]) def max_contiguous(x, values): - if not isinstance(values, list): - values = [values] - return max_contiguous_p.bind(x, values=values) + if not isinstance(values, (list, tuple)): + values = (values,) + return max_contiguous_p.bind(x, values=tuple(values)) @max_contiguous_p.def_abstract_eval def _max_contiguous_abstract_eval(aval, **_): @@ -359,9 +359,8 @@ def _max_contiguous_abstract_eval(aval, **_): multiple_of_p.def_impl(lambda x, **_: x) mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x]) -def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array: - if not isinstance(values, list): - values = [values] +def multiple_of(x: jax.Array, values: Sequence[int] | int) -> jax.Array: + values = (values,) if isinstance(values, int) else tuple(values) return multiple_of_p.bind(x, values=values) @multiple_of_p.def_abstract_eval diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index b845a4079ff4..2a15b3dbd47d 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -83,7 +83,7 @@ def elementwise_inline_asm( asm=asm, constraints=constraints, pack=pack, - result_shape_dtypes=result_shape_dtypes, + result_shape_dtypes=tuple(result_shape_dtypes), ) From 92aedb286d49cb4b21c936adf2e9ce1fbb70b3b1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 13 Jun 2025 09:07:14 -0700 Subject: [PATCH 1681/1769] Improve reshape not supported error message Before: ``` jax._src.core.ShardingTypeError: This reshape is not supported. Please specify the sharding of the output via the `sharding` argument of jax.lax.reshape. Got operand shape: (4, 2, 6, 8), new sizes: (4, 12, 8) and operand spec: PartitionSpec(None, None, 'y', 'x') ``` After: ``` jax._src.core.ShardingTypeError: This reshape is not supported. Please specify the sharding of the output via the `sharding` argument of jax.lax.reshape. 
Got operand shape: float32[4,2,6@y,8@x], new sizes: (4, 12, 8) ``` PiperOrigin-RevId: 771114653 --- jax/_src/lax/lax.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8fbcf21d574b..7df3a93cabdc 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -7114,8 +7114,7 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of' ' the output via the `out_sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + f' operand shape: {operand}, new sizes: {new_sizes}') def _split_merge_singleton_dim_sharding_rule(operand, new_sizes): filtered_spec = [sp for sh, sp in zip(operand.shape, operand.sharding.spec) @@ -7147,8 +7146,7 @@ def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions): raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + f' operand shape: {operand}, new sizes: {new_sizes}') else: new_spec.append(sp) assert len(new_spec) == len(new_sizes), (new_spec, new_sizes) @@ -7172,8 +7170,7 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + f' operand shape: {operand}, new sizes: {new_sizes}') else: new_spec.append(next(op_spec)) assert next(op_spec, None) is None From 1b1e9f715dcfa83f546430af1987dc709c39869d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 13 Jun 2025 09:43:30 -0700 Subject: [PATCH 1682/1769] Internal refactor: move TPU lowering rules out of jax/_src/lax --- jax/BUILD | 1 + jax/_src/lax/linalg.py | 102 -------------------------- jax/_src/scipy/linalg.py | 2 +- jax/_src/tpu/__init__.py | 13 ++++ jax/_src/tpu/linalg/__init__.py | 24 ++++++ jax/_src/{lax => tpu/linalg}/eigh.py | 69 ++++++++++++++++- jax/_src/{lax => tpu/linalg}/qdwh.py | 0 jax/_src/{lax => tpu/linalg}/stack.py | 0 jax/_src/{lax => tpu/linalg}/svd.py | 52 +++++++++++++ jax/lax/linalg.py | 2 +- tests/lax_scipy_spectral_dac_test.py | 2 +- tests/qdwh_test.py | 2 +- tests/stack_test.py | 2 +- tests/svd_test.py | 2 +- 14 files changed, 163 insertions(+), 110 deletions(-) create mode 100644 jax/_src/tpu/__init__.py create mode 100644 jax/_src/tpu/linalg/__init__.py rename jax/_src/{lax => tpu/linalg}/eigh.py (90%) rename jax/_src/{lax => tpu/linalg}/qdwh.py (100%) rename jax/_src/{lax => tpu/linalg}/stack.py (100%) rename jax/_src/{lax => tpu/linalg}/svd.py (86%) diff --git a/jax/BUILD b/jax/BUILD index 1c2d72276baa..f9d0b264c756 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -319,6 +319,7 @@ py_library_providing_imports_info( "_src/scipy/**/*.py", "_src/state/**/*.py", "_src/third_party/**/*.py", + "_src/tpu/**/*.py", "experimental/key_reuse/**/*.py", "experimental/roofline/**/*.py", "image/**/*.py", diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 10755ccb2ec1..9e4d188579bb 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -38,10 +38,7 @@ from 
jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lax import control_flow -from jax._src.lax import eigh as lax_eigh from jax._src.lax import lax as lax_internal -from jax._src.lax import svd as lax_svd from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int from jax._src.lib import gpu_linalg @@ -1160,57 +1157,6 @@ def _eigh_cpu_gpu_lowering( return [v, w] -def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index): - *_, m, n = x.shape - assert m == n, (m, n) - - termination_size = 256 - if not is_constant_dim(m): - # TODO: maybe we can relax the check below for shape polymorphism? - raise NotImplementedError( - "Shape polymorphism for native lowering for eigh is implemented " - f"only for the batch dimensions: {x.shape}") - if m <= termination_size and ( - subset_by_index is None or subset_by_index == (0, n) - ): - eig_vals, eig_vecs = eigh_jacobi(x, lower=lower, - sort_eigenvalues=sort_eigenvalues) - return eig_vecs, eig_vals - - def eigh_qdwh(x): - if len(x.shape) > 2: - return control_flow.map(eigh_qdwh, x) - - # We should only look at elements from the lower/upper triangle. Reflects - # that triangle into the other triangle to form a Hermitian matrix. - if lower: - mask = lax_internal._tri(bool, (n, n), 0) - else: - mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) - if dtypes.issubdtype(x.dtype, np.complexfloating): - re = lax.select(mask, lax.real(x), _T(lax.real(x))) - if lower: - im_mask = lax_internal._tri(bool, (n, n), -1) - else: - im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) - im = lax.imag(x) - im = lax.select(im_mask, im, lax.full_like(im, 0)) - im = lax.select(mask, im, -_T(im)) - x = lax.complex(re, im) - else: - x = lax.select(mask, x, _T(x)) - - return lax_eigh.eigh( - x, - sort_eigenvalues=sort_eigenvalues, - termination_size=termination_size, - subset_by_index=subset_by_index, - ) - - eig_vals, eig_vecs = eigh_qdwh(x) - return eig_vecs, eig_vals - - def _eigh_jvp_rule( primals, tangents, *, lower, sort_eigenvalues, subset_by_index ): @@ -1256,9 +1202,6 @@ def _eigh_jvp_rule( _eigh_dtype_rule, (_float | _complex,), (2,), _eigh_shape_rule, "eigh", multiple_results=True) ad.primitive_jvps[eigh_p] = _eigh_jvp_rule -mlir.register_lowering( - eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True), - platform='tpu') register_cpu_gpu_lowering(eigh_p, _eigh_cpu_gpu_lowering) @@ -2212,57 +2155,12 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv, else: return s, u, vt, info -def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None): - if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT: - raise NotImplementedError( - "The SVD algorithm parameter is not implemented on TPU.") - - batch_dims = a.shape[:-2] - fn = partial( - lax_svd.svd, - full_matrices=full_matrices, - compute_uv=compute_uv, - subset_by_index=subset_by_index, - ) - for _ in range(len(batch_dims)): - fn = api.vmap(fn) - - if compute_uv: - u, s, vh = fn(a) - return [s, u, vh] - else: - s = fn(a) - return [s] - -def _svd_tpu_lowering_rule( - ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None -): - del algorithm # unused - operand_aval, = ctx.avals_in - m, n = operand_aval.shape[-2:] - - if m == 0 or n == 0: - return mlir.lower_fun(_empty_svd, multiple_results=True)( - ctx, - operand, - full_matrices=full_matrices, - compute_uv=compute_uv, - ) - - return mlir.lower_fun(_svd_tpu, 
multiple_results=True)( - ctx, - operand, - full_matrices=full_matrices, - compute_uv=compute_uv, - subset_by_index=subset_by_index, - ) svd_p = linalg_primitive( _svd_dtype_rule, (_float | _complex,), (2,), _svd_shape_rule, "svd", multiple_results=True) ad.primitive_jvps[svd_p] = _svd_jvp_rule register_cpu_gpu_lowering(svd_p, _svd_cpu_gpu_lowering) -mlir.register_lowering(svd_p, _svd_tpu_lowering_rule) # Symmetric product diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index d4971fb745cb..5a1f6d988740 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -26,10 +26,10 @@ from jax import lax from jax._src import dtypes from jax._src.lax import linalg as lax_linalg -from jax._src.lax import qdwh from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, promote_dtypes_complex) +from jax._src.tpu.linalg import qdwh from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/tpu/__init__.py b/jax/_src/tpu/__init__.py new file mode 100644 index 000000000000..1337256a5074 --- /dev/null +++ b/jax/_src/tpu/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/_src/tpu/linalg/__init__.py b/jax/_src/tpu/linalg/__init__.py new file mode 100644 index 000000000000..8c09b25d1e08 --- /dev/null +++ b/jax/_src/tpu/linalg/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from jax._src.tpu.linalg import ( + eigh as eigh, + qdwh as qdwh, + svd as svd, +) + +from jax._src import traceback_util +traceback_util.register_exclusion(os.path.dirname(__file__)) diff --git a/jax/_src/lax/eigh.py b/jax/_src/tpu/linalg/eigh.py similarity index 90% rename from jax/_src/lax/eigh.py rename to jax/_src/tpu/linalg/eigh.py index 99711dc6bf0e..dda254579459 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/tpu/linalg/eigh.py @@ -33,15 +33,20 @@ import numpy as np import jax +from jax._src import core +from jax._src import dtypes import jax._src.numpy.lax_numpy as jnp import jax._src.numpy.linalg as jnp_linalg +from jax._src.interpreters import mlir from jax._src.numpy import tensor_contractions from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax import lax -from jax._src.lax import qdwh +from jax._src.lax import control_flow +from jax._src.lax import lax as lax_internal from jax._src.lax import linalg as lax_linalg -from jax._src.lax.stack import Stack +from jax._src.tpu.linalg import qdwh +from jax._src.tpu.linalg.stack import Stack # QDWH-eigh is a recursive algorithm where the structure of the recursion @@ -573,3 +578,63 @@ def eigh( eig_vecs = eig_vecs[:, sort_idxs] return eig_vals, eig_vecs + + +def _T(x: jax.Array) -> jax.Array: + return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) + + +def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index): + *_, m, n = x.shape + assert m == n, (m, n) + + termination_size = 256 + if not core.is_constant_dim(m): + # TODO: maybe we can relax the check below for shape polymorphism? + raise NotImplementedError( + "Shape polymorphism for native lowering for eigh is implemented " + f"only for the batch dimensions: {x.shape}") + if m <= termination_size and ( + subset_by_index is None or subset_by_index == (0, n) + ): + eig_vals, eig_vecs = lax_linalg.eigh_jacobi(x, lower=lower, + sort_eigenvalues=sort_eigenvalues) + return eig_vecs, eig_vals + + def eigh_qdwh(x): + if len(x.shape) > 2: + return control_flow.map(eigh_qdwh, x) + + # We should only look at elements from the lower/upper triangle. Reflects + # that triangle into the other triangle to form a Hermitian matrix. 
+ if lower: + mask = lax_internal._tri(bool, (n, n), 0) + else: + mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) + if dtypes.issubdtype(x.dtype, np.complexfloating): + re = lax.select(mask, lax.real(x), _T(lax.real(x))) + if lower: + im_mask = lax_internal._tri(bool, (n, n), -1) + else: + im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) + im = lax.imag(x) + im = lax.select(im_mask, im, lax.full_like(im, 0)) + im = lax.select(mask, im, -_T(im)) + x = lax.complex(re, im) + else: + x = lax.select(mask, x, _T(x)) + + return eigh( + x, + sort_eigenvalues=sort_eigenvalues, + termination_size=termination_size, + subset_by_index=subset_by_index, + ) + + eig_vals, eig_vecs = eigh_qdwh(x) + return eig_vecs, eig_vals + + +mlir.register_lowering( + lax_linalg.eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True), + platform='tpu') diff --git a/jax/_src/lax/qdwh.py b/jax/_src/tpu/linalg/qdwh.py similarity index 100% rename from jax/_src/lax/qdwh.py rename to jax/_src/tpu/linalg/qdwh.py diff --git a/jax/_src/lax/stack.py b/jax/_src/tpu/linalg/stack.py similarity index 100% rename from jax/_src/lax/stack.py rename to jax/_src/tpu/linalg/stack.py diff --git a/jax/_src/lax/svd.py b/jax/_src/tpu/linalg/svd.py similarity index 86% rename from jax/_src/lax/svd.py rename to jax/_src/tpu/linalg/svd.py index 9f22f130cbb2..298d6650b618 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/tpu/linalg/svd.py @@ -43,6 +43,8 @@ import jax from jax import lax from jax._src import core +from jax._src.interpreters import mlir +from jax._src.lax import linalg as lax_linalg import jax.numpy as jnp @@ -110,6 +112,7 @@ def correct_rank_deficiency(u_out): u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction)) return (u_out, s_out, v_out) + @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) def svd( a: Any, @@ -241,3 +244,52 @@ def svd( return (v_out, s_out, u_out.T.conj()) return (u_out, s_out, v_out.T.conj()) + + +def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None): + if algorithm is not None and algorithm != lax_linalg.SvdAlgorithm.DEFAULT: + raise NotImplementedError( + "The SVD algorithm parameter is not implemented on TPU.") + + batch_dims = a.shape[:-2] + fn = functools.partial( + svd, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) + for _ in range(len(batch_dims)): + fn = jax.vmap(fn) + + if compute_uv: + u, s, vh = fn(a) + return [s, u, vh] + else: + s = fn(a) + return [s] + + +def _svd_tpu_lowering_rule( + ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None +): + del algorithm # unused + operand_aval, = ctx.avals_in + m, n = operand_aval.shape[-2:] + + if m == 0 or n == 0: + return mlir.lower_fun(lax_linalg._empty_svd, multiple_results=True)( + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + ) + + return mlir.lower_fun(_svd_tpu, multiple_results=True)( + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) + +mlir.register_lowering(lax_linalg.svd_p, _svd_tpu_lowering_rule) diff --git a/jax/lax/linalg.py b/jax/lax/linalg.py index 343073ca56d0..984592534656 100644 --- a/jax/lax/linalg.py +++ b/jax/lax/linalg.py @@ -46,6 +46,6 @@ tridiagonal_solve_p as tridiagonal_solve_p, ) -from jax._src.lax.qdwh import ( +from jax._src.tpu.linalg.qdwh import ( qdwh as qdwh ) diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py index a09dcac5371c..4359318a7997 100644 --- 
a/tests/lax_scipy_spectral_dac_test.py +++ b/tests/lax_scipy_spectral_dac_test.py @@ -18,7 +18,7 @@ from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu -from jax._src.lax import eigh as lax_eigh +from jax._src.tpu.linalg import eigh as lax_eigh from absl.testing import absltest diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index 91cc3a51f876..955e23374fee 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -18,7 +18,7 @@ import jax from jax._src import config from jax._src import test_util as jtu -from jax._src.lax import qdwh +from jax._src.tpu.linalg import qdwh import jax.numpy as jnp import numpy as np diff --git a/tests/stack_test.py b/tests/stack_test.py index aa1a02793b1a..8ebfc3489ff5 100644 --- a/tests/stack_test.py +++ b/tests/stack_test.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp -from jax._src.lax.stack import Stack +from jax._src.tpu.linalg.stack import Stack from jax._src import test_util as jtu diff --git a/tests/svd_test.py b/tests/svd_test.py index dfc3de7a764f..d95a22e2f93c 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -20,7 +20,7 @@ import scipy.linalg as osp_linalg from jax._src import config from jax._src import test_util as jtu -from jax._src.lax import svd +from jax._src.tpu.linalg import svd from absl.testing import absltest From 64c9574490335e7aa837e9fe1fbed635eb493904 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 13 Jun 2025 09:43:32 -0700 Subject: [PATCH 1683/1769] Make params of assert_consumed_value_p hashable. Move the HashableArray utility out of ffi.py and share it. Change in preparation for adding caching for abstract_eval rules. PiperOrigin-RevId: 771125857 --- jax/BUILD | 8 +++++++ jax/_src/ffi.py | 19 +-------------- jax/_src/hashable_array.py | 37 +++++++++++++++++++++++++++++ jax/experimental/key_reuse/_core.py | 9 +++---- 4 files changed, 51 insertions(+), 22 deletions(-) create mode 100644 jax/_src/hashable_array.py diff --git a/jax/BUILD b/jax/BUILD index 1c2d72276baa..de1169f99227 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -387,6 +387,7 @@ py_library_providing_imports_info( ":environment_info", ":export", ":ffi", + ":hashable_array", ":internal_mesh_utils", ":jaxpr_util", ":layout", @@ -892,6 +893,7 @@ pytype_strict_library( ":batching", ":core", ":effects", + ":hashable_array", ":layout", ":mlir", ":typing", @@ -911,6 +913,12 @@ pytype_strict_library( srcs = ["_src/hardware_utils.py"], ) +pytype_strict_library( + name = "hashable_array", + srcs = ["_src/hashable_array.py"], + deps = py_deps("numpy"), +) + pytype_library( name = "lax_reference", srcs = ["_src/lax_reference.py"], diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 8bb6b368d61a..e774f98aaa78 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -28,6 +28,7 @@ from jax._src import effects from jax._src import util from jax._src import xla_bridge +from jax._src.hashable_array import HashableArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -558,24 +559,6 @@ def _unwrap_kwargs_hashable(kwargs: Sequence[tuple[str, Any]]) -> dict[str, Any] return unwrapped_kwargs -class HashableArray: - __slots__ = ["val"] - - def __init__(self, val): - assert isinstance(val, np.ndarray) - self.val = np.copy(val) - self.val.setflags(write=False) - - def __repr__(self): - return f"HashableArray({self.val})" - - def __hash__(self): - return hash((self.val.shape, self.val.dtype, self.val.tobytes())) - - def __eq__(self, other): - return 
isinstance(other, HashableArray) and np.array_equal(self.val, other.val) - - class HashableDict: __slots__ = ["val"] diff --git a/jax/_src/hashable_array.py b/jax/_src/hashable_array.py new file mode 100644 index 000000000000..4757a9c5eb24 --- /dev/null +++ b/jax/_src/hashable_array.py @@ -0,0 +1,37 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the + +import numpy as np + + +class HashableArray: + __slots__ = ["val"] + val: np.ndarray + + def __init__(self, val): + self.val = np.array(val, copy=True) + self.val.setflags(write=False) + + def __repr__(self): + return f"HashableArray({self.val!r})" + + def __str__(self): + return f"HashableArray({self.val})" + + def __hash__(self): + return hash((self.val.shape, self.val.dtype, self.val.tobytes())) + + def __eq__(self, other): + return isinstance(other, HashableArray) and np.array_equal( + self.val, other.val + ) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index a2ffc6582fff..7c7ffd17a56c 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -35,6 +35,7 @@ from jax._src import util from jax._src.ad_checkpoint import remat_p from jax._src.debugging import debug_callback_p +from jax._src.hashable_array import HashableArray from jax._src.interpreters import partial_eval as pe from jax._src.util import weakref_lru_cache @@ -257,16 +258,16 @@ def consume(key): def assert_unconsumed(key): """Assert that a key is unconsumed""" - assert_consumed_value_p.bind(key, value=False) + assert_consumed_value_p.bind(key, value=HashableArray(False)) def assert_consumed(key, value=True): """Assert that a key is consumed""" - assert_consumed_value_p.bind(key, value=value) + assert_consumed_value_p.bind(key, value=HashableArray(value)) def _check_consumed_value(eqn, consumed): """Extra check for use with assert_consumed_value_p""" - expected = eqn.params['value'] + expected = eqn.params['value'].val if not np.all(consumed == expected): if np.all(expected): raise AssertionError(f"Expected key to be consumed in {eqn}") @@ -415,7 +416,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None: function_type_signature(fun, *args) -#---------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------- # key reuse rules for particular primitives: @dynamic_key_reuse_signature From 790540c7d29fe286495287a3ee431586faa22984 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 13 Jun 2025 12:43:21 -0700 Subject: [PATCH 1684/1769] Make some remaining jaxpr equation params hashable. * Don't hash the .shape of a DShapedArray. * Use HashableArray in lax.composite to wrap np.ndarrays. * Use a tuple instead of an empty list in jaxpr_effects_test. 
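The common thread in these fixes is replacing unhashable param values (lists, mutable dicts, bare ndarrays) with hashable stand-ins. A minimal sketch of the pattern, mirroring the `lax.composite` change in the diff below:

```
# Jaxpr equation params must be hashable. Array-valued composite
# attributes get wrapped in HashableArray, which hashes by content
# (shape, dtype, raw bytes) rather than by object identity.
attributes = tuple(
    (k, HashableArray(v) if isinstance(v, np.ndarray) else v)
    for k, v in kwargs.items()
)
```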
PiperOrigin-RevId: 771190386 --- jax/BUILD | 1 + jax/_src/core.py | 3 ++- jax/_src/interpreters/mlir.py | 3 +++ jax/_src/lax/lax.py | 6 +++++- tests/jaxpr_effects_test.py | 4 ++-- 5 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 5ab9d861085b..6b1c5dc88499 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1002,6 +1002,7 @@ pytype_strict_library( ":core", ":dtypes", ":effects", + ":hashable_array", ":jaxpr_util", ":layout", ":mesh", diff --git a/jax/_src/core.py b/jax/_src/core.py index 4e5f720f680c..ff3417a2d883 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2271,7 +2271,8 @@ def __eq__(self, other): and self.weak_type == other.weak_type) def __hash__(self): - return hash((self.shape, self.dtype, self.weak_type)) + # We don't hash the contents of the shape because it may contain tracers. + return hash((len(self.shape), self.dtype, self.weak_type)) def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c11a68d7c45f..43e8047071b9 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -38,6 +38,7 @@ from jax._src import core from jax._src import dtypes from jax._src import effects as effects_lib +from jax._src import hashable_array from jax._src import jaxpr_util from jax._src import linear_util as lu from jax._src import path @@ -383,6 +384,8 @@ def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute return _numpy_array_attribute(val) register_attribute_handler(np.ndarray, _numpy_array_attribute_handler) +register_attribute_handler(hashable_array.HashableArray, + lambda x: _numpy_array_attribute_handler(x.val)) for _scalar_type in [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 7df3a93cabdc..78da3ab47fe3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -51,6 +51,7 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, abstract_token, canonicalize_shape) from jax._src.errors import UnexpectedTracerError +from jax._src.hashable_array import HashableArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -1880,7 +1881,10 @@ def _decorator(*args, **kwargs): out_flat = composite_p.bind( *flat_args, name=name, - attributes=tuple((k, v) for k, v in kwargs.items()), + attributes=tuple( + (k, HashableArray(v) if isinstance(v, np.ndarray) else v) + for k, v in kwargs.items() + ), version=version, jaxpr=closed_jaxpr, ) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 783d62dcc47e..420aa642a1d6 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -579,7 +579,7 @@ def f(x): # Runs in a thread. 
res = jax.jit(
 lambda x: callback_p.bind(
- x, callback=_noop, effect=log_effect, out_avals=[])
+ x, callback=_noop, effect=log_effect, out_avals=())
 )(x)
 tokens.append(dispatch.runtime_tokens.current_tokens[log_effect])
 return res
@@ -635,7 +635,7 @@ def log_value(x):
 @jax.pmap
 def f(x):
 callback_p.bind(
- x, callback=log_value, effect=unordered_log_effect, out_avals=[])
+ x, callback=log_value, effect=unordered_log_effect, out_avals=())
 return x + 1
 f(jnp.arange(2)).block_until_ready()
 jax.effects_barrier()

From b9cf0af59fc4af4f37880ac9dd053fdb74eb47ac Mon Sep 17 00:00:00 2001
From: Michael Whittaker
Date: Fri, 13 Jun 2025 13:04:09 -0700
Subject: [PATCH 1685/1769] Implemented cross-host memory transfer on GPU.

# Background

Emily is currently working on extending `jax.device_put` to allow for
cross-host memory transfers. Previously
(https://github.com/jax-ml/jax/pull/28867), she got the cross-host transfers
working on TPU using the `MakeCrossHostReceiveBuffers` and
`CopyToRemoteDevice` APIs. This CL implements these two APIs on GPU.

# Future Work

This CL introduces a very basic, very limited implementation that should be
improved in the future. For now, `MakeCrossHostReceiveBuffers` creates a
`CliqueId` (think `ncclUniqueId`) and a communicator (think `ncclComm_t`)
using this `CliqueId`. The `CliqueId` is sent to the sending process, and the
sending process creates the corresponding communicator. The data is then sent
using the communicators' `Send` and `Recv` APIs.

This design is suboptimal because it creates a new pair of communicators for
every transfer. It doesn't use the communicator caching code path that other
collectives use.

I was also a bit unclear on how memory transfers are ordered. For example, if
two different buffers need to be transferred between two devices, don't we
need to be careful that the senders and receivers agree on the order in which
these buffers will be sent?

PiperOrigin-RevId: 771197322
---
 jax/_src/dispatch.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 972c7f6ffb23..f51213e8f5ad 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -385,8 +385,11 @@ def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding):
 backend = xla_bridge.get_backend()
 # There is experimental support for cross-host device transfers on TFRT TPU
 # backends only.
- if (xla_bridge.process_count() == 1 or backend.platform != "tpu" or
- not backend.platform_version.startswith("TFRT TPU")):
+ # TODO: https://github.com/jax-ml/jax/issues/26645 - Allow backends to be
+ # queried for their cross-host transfer support.
+ if (xla_bridge.process_count() == 1 or backend.platform not in {"gpu", "tpu"}
+ or (backend.platform == "gpu" and not backend.platform_version.startswith("cuda"))
+ or (backend.platform == "tpu" and not backend.platform_version.startswith("TFRT TPU"))):
 return False
 if (src_sharding._internal_device_list.device_kind !=
 dst_sharding._internal_device_list.device_kind):

From 080294cfa8db276c17c3adcc9284825ce90d2332 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 13 Jun 2025 19:07:53 +0000
Subject: [PATCH 1686/1769] Load CUDA libraries up front with cdll.LoadLibrary().

Works around a bug in CUDNN where it does not declare or locate its own NVRTC
dependency. However, this approach should also improve our error messages in
general when we load NVIDIA's libraries, so do it across the board for all
libraries we use.
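The heart of the change is a small loader that prefers the pip-installed NVIDIA packages over the system search path. A simplified sketch (the real helper below also collects per-library OSErrors so they can be reported together):

```
import ctypes
import importlib
import pathlib

def _load(module, libraries):
  # Prefer the pip package, e.g. nvidia.cudnn; if it isn't installed,
  # leave resolution to LD_LIBRARY_PATH and the system loader.
  try:
    m = importlib.import_module(f"nvidia.{module}")
  except ImportError:
    return
  for lib in libraries:
    path = pathlib.Path(m.__path__[0]) / "lib" / lib
    ctypes.cdll.LoadLibrary(str(path))  # later dlopen()s reuse this copy

_load("cuda_nvrtc", ["libnvrtc.so.12"])  # undeclared dependency of cudnn
_load("cudnn", ["libcudnn.so.9"])
```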
--- build/requirements.in | 1 + build/requirements_lock_3_10.txt | 5 ++ build/requirements_lock_3_11.txt | 7 ++- build/requirements_lock_3_12.txt | 7 ++- build/requirements_lock_3_13.txt | 7 ++- build/requirements_lock_3_13_ft.txt | 7 ++- jax_plugins/cuda/__init__.py | 92 ++++++++++++++++++++++++----- jax_plugins/cuda/plugin_setup.py | 2 + jaxlib/tools/BUILD.bazel | 1 + 9 files changed, 110 insertions(+), 19 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index 96c27739a5e7..e3aae21d9b4d 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -26,6 +26,7 @@ jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" libtpu ; sys_platform == "linux" and platform_machine == "x86_64" # For Mosaic GPU collectives +nvidia-cuda-nvrtc-cu12>=12.1.55 ; sys_platform == "linux" nvidia-nvshmem-cu12>=3.2.5 ; sys_platform == "linux" # Platform-specific dependencies that are being ignored by pip-compile diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 5caa99692756..523d3bbfe6e2 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -451,6 +451,11 @@ nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/requirements.in nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index de3c35ed3c02..9574be1a972a 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -27,7 +27,7 @@ cloudpickle==3.0.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ @@ -446,6 +446,11 @@ nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/requirements.in nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 04c6990da696..13c4269186ba 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -27,7 
+27,7 @@ cloudpickle==3.0.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ @@ -446,6 +446,11 @@ nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/requirements.in nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 965cb3bc9672..2aaf276d94df 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -27,7 +27,7 @@ cloudpickle==3.0.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.3.0 \ --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ @@ -501,6 +501,11 @@ nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/requirements.in nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7d111c3b3e9..aa6f2daa569b 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -27,7 +27,7 @@ cloudpickle==3.1.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.3.1 \ --hash=sha256:041b640d4ec01922083645a94bb3b2e777e6b626788f4095cf21abbe266413c1 \ --hash=sha256:05e806338bfeaa006acbdeba0ad681a10be63b26e1b17317bfac3c5d98f36cda \ @@ -452,6 +452,11 @@ nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ 
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/requirements.in nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index b92f4d229395..10609460d814 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ctypes import functools import importlib import logging @@ -24,21 +25,28 @@ from jax._src.lib import xla_client import jax._src.xla_bridge as xb -# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without -# preinstalled jax cuda plugin packages. -for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: - try: - cuda_plugin_extension = importlib.import_module( - f'{pkg_name}.cuda_plugin_extension' - ) - cuda_versions = importlib.import_module( - f'{pkg_name}._versions' - ) - except ImportError: - cuda_plugin_extension = None - cuda_versions = None - else: - break +cuda_plugin_extension = None +cuda_versions = None + +def _import_extensions(): + global cuda_plugin_extension + global cuda_versions + + # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without + # preinstalled jax cuda plugin packages. + for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: + try: + cuda_plugin_extension = importlib.import_module( + f'{pkg_name}.cuda_plugin_extension' + ) + cuda_versions = importlib.import_module( + f'{pkg_name}._versions' + ) + except ImportError: + cuda_plugin_extension = None + cuda_versions = None + else: + break logger = logging.getLogger(__name__) @@ -82,6 +90,57 @@ def _get_library_path(): return None +def _load(module, libraries): + try: + m = importlib.import_module(f"nvidia.{module}") + except ImportError: + m = None + + for lib in libraries: + excs = [] + if m is not None: + path = pathlib.Path(m.__path__[0]) / "lib" / lib + try: + ctypes.cdll.LoadLibrary(path) + continue + except OSError as e: + excs.append(e) + + # TODO(phawkins): check the non-Python path here and error if not found. + # # Try again, without the Python module path. + # try: + # ctypes.cdll.LoadLibrary(lib) + # continue + # except OSError as e: + # excs.append(e) + # + # if sys.version_info >= (3, 11): + # raise ExceptionGroup(f"Unable to load CUDA library {lib}", excs) # noqa: F821 + # else: + # raise RuntimeError(f"Unable to load CUDA library {lib}") from excs[-1] + + +def _load_nvidia_libraries(): + """Attempts to load NVIDIA's libraries. + + We prefer the Python packages, if present. If not, we fall back to loading + them from LD_LIBRARY_PATH. By loading the libraries here, later lookups will + find these copies.""" + _load("cuda_runtime", ["libcudart.so.12"]) + # cuda_nvrtc isn't directly a dependency of JAX, but CUDNN appears to need it + # and at least in CUDA 12.9 has RUNPATHs misconfigured to refer to + # nvidia/nvrtc instead of nvidia/cuda_nvrtc. 
+ _load("cuda_nvrtc", ["libnvrtc.so.12"]) + _load("cublas", ["libcublas.so.12", "libcublasLt.so.12"]) + _load("nccl", ["libnccl.so.2"]) + _load("cuda_cupti", ["libcupti.so.12"]) + _load("cusparse", ["libcusparse.so.12"]) + _load("cusolver", ["libcusolver.so.11"]) + _load("cufft", ["libcufft.so.11"]) + _load("nvshmem", ["libnvshmem_host.so.3"]) + _load("cudnn", ["libcudnn.so.9"]) + + def _check_cuda_versions(raise_on_first_error: bool = False, debug: bool = False): assert cuda_versions is not None @@ -110,6 +169,7 @@ def _make_msg(name: str, f"{req_str}") return msg + def _version_check(name: str, get_version, get_build_version, @@ -254,6 +314,8 @@ def _version_check(name: str, def initialize(): + _load_nvidia_libraries() + _import_extensions() path = _get_library_path() if path is None: return diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index fc467824fe5f..412f5b6b814a 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -70,6 +70,8 @@ def has_ext_modules(self): # Until NVIDIA add version constraints, add a version constraint # here. "nvidia-nvjitlink-cu12>=12.1.105", + # nvrtc is a transitive and undeclared dep of cudnn. + "nvidia-cuda-nvrtc-cu12>=12.1.55", # NVSHMEM is used by Mosaic GPU collectives and can be used by XLA to # speed up collectives too. "nvidia-nvshmem-cu12>=3.2.5", diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 03bd75144186..515b7be04f64 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -462,6 +462,7 @@ filegroup( "@pypi_nvidia_cublas_cu12//:whl", "@pypi_nvidia_cuda_cupti_cu12//:whl", "@pypi_nvidia_cuda_nvcc_cu12//:whl", + "@pypi_nvidia_cuda_nvrtc_cu12//:whl", "@pypi_nvidia_cuda_runtime_cu12//:whl", "@pypi_nvidia_cudnn_cu12//:whl", "@pypi_nvidia_cufft_cu12//:whl", From dfc905297a0187ac2fa34566d1da475eb3fc953f Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 13 Jun 2025 14:12:16 -0700 Subject: [PATCH 1687/1769] [Pallas] Add no_pipelining debugging option to emit_pipeline. The intended use-case is to flip this flag on to see if it fixes any bugs in a kernel. If a bug is fixed when no_pipelining=True, then there is likely an issue with illegal buffer usage in the user kernel. 
PiperOrigin-RevId: 771220632 --- jax/_src/pallas/mosaic/pipeline.py | 113 ++++++++++++++++++++--- tests/pallas/tpu_pallas_pipeline_test.py | 20 ++-- 2 files changed, 106 insertions(+), 27 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 367458d55e4a..de7895879f57 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -29,6 +29,7 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as primitives from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import helpers as tpu_helpers from jax._src.pallas.mosaic import primitives as tpu_primitives from jax.experimental import pallas as pl from jax.extend.backend import get_default_device @@ -766,6 +767,22 @@ def accumulate(self): is_leaf=lambda x: isinstance(x, BufferedRefBase) ) +def map_inputs(f, *args): + """Maps over all input BufferedRefs.""" + def fmap(bref, *f_args): + if bref.is_input: + return f(bref, *f_args) + return bref + return map_brefs(fmap, *args) + +def map_outputs(f, *args): + """Maps over all output BufferedRefs.""" + def fmap(bref, *f_args): + if bref.is_output: + return f(bref, *f_args) + return bref + return map_brefs(fmap, *args) + def _filter_indices( indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...] @@ -1259,6 +1276,36 @@ def _partition_grid( return new_grid, offsets # type: ignore[return-value] +def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices): + """Perform a synchronous copy from src to dst.""" + bref: BufferedRef + hbm_ref: REF + if isinstance(src, BufferedRef): + bref = src + if isinstance(dst, BufferedRef): + raise ValueError("Only one of src or dst can be a BufferedRef.") + hbm_ref = dst + copy_in = False + else: + if not isinstance(dst, BufferedRef): + raise ValueError("One of src or dst must be a BufferedRef.") + bref = dst + hbm_ref = src + copy_in = True + hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices) + bref_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(hbm_slice, bref.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) + if copy_in: + tpu_helpers.sync_copy(hbm_ref.at[hbm_slice], + bref.current_ref.at[bref_slice]) # type: ignore[union-attr] + else: + tpu_helpers.sync_copy(bref.current_ref.at[bref_slice], # type: ignore[union-attr] + hbm_ref.at[hbm_slice]) + + def emit_pipeline( body, *, @@ -1270,6 +1317,7 @@ def emit_pipeline( core_axis_name: str | None = None, dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None, trace_scopes: bool = True, + no_pipelining: bool = False, ): """Creates a function to emit a manual pallas pipeline. @@ -1296,6 +1344,9 @@ def emit_pipeline( or ARBITRARY). trace_scopes: optional bool, indicates whether to annotate each region in the pipeline using named_scope. + no_pipelining: If True, turns off pipelining and all copies will be + made synchronous. This is useful for debugging multiple-buffering + related bugs. """ if any(not isinstance(d, (int, jax.Array)) for d in grid): grid_types = tuple(type(d) for d in grid) @@ -1451,24 +1502,56 @@ def loop_body(step, carry): ) return next_brefs, _next_index(indices, grid) - @pl.when(num_steps > 0) - def _(): - # pipeline prologue + + if no_pipelining: + # Debugging mode where all copies are synchronous. 
initial_indices = (0,) * len(grid) scheduler = make_scheduler(0, initial_indices) - with scheduler.grid_env(): - brefs = map_brefs(scheduler.initialize, allocations, refs, schedule) - - # pipeline loop - brefs, next_indices = lax.fori_loop( - 0, num_steps, loop_body, (brefs, initial_indices) - ) + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + map_brefs(lambda bref: bref.init_slots(), brefs) + if postyeet is not None or prefetch is not None: + raise NotImplementedError("Prefetch/Postyeet not supported") + if any(bref.is_accumulator for bref in brefs): + raise NotImplementedError("Accumulators not supported") + @functools.partial(jax.lax.fori_loop, 0, num_steps, + init_val=initial_indices) + def _loop_body(step, indices): + scheduler = make_scheduler(step, indices) + with scheduler.grid_env(): + # prepare any local VMEM aliases + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + # loop input handling phase + copy_in = lambda bref, ref: sync_copy(ref, bref, indices) + map_inputs(copy_in, brefs, refs) + # run the kernel! + if body_prologue is not None: + body_prologue() + current_refs = map_brefs(lambda x: x.current_ref, brefs) + with scheduler._named_scope("ep_run_kernel"): + body(*current_refs, *scratches) + # loop output handling phase + copy_out = lambda bref, ref: sync_copy(bref, ref, indices) + map_outputs(copy_out, brefs, refs) + return _next_index(indices, grid) + else: + @pl.when(num_steps > 0) + def _(): + # pipeline prologue + initial_indices = (0,) * len(grid) + scheduler = make_scheduler(0, initial_indices) + with scheduler.grid_env(): + brefs = map_brefs(scheduler.initialize, allocations, refs, schedule) + + # pipeline loop + brefs, next_indices = lax.fori_loop( + 0, num_steps, loop_body, (brefs, initial_indices) + ) - # pipeline epilogue - final_indices = _prev_index(next_indices, grid) - scheduler = make_scheduler(num_steps - 1, final_indices) - with scheduler.grid_env(): - map_brefs(scheduler.finalize, brefs, refs, schedule) + # pipeline epilogue + final_indices = _prev_index(next_indices, grid) + scheduler = make_scheduler(num_steps - 1, final_indices) + with scheduler.grid_env(): + map_brefs(scheduler.finalize, brefs, refs, schedule) return pipeline diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 1c10fa4b73e5..94f8359dbaed 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -122,20 +122,15 @@ def _reduce_out(): class PallasCallPipelineTest(parameterized.TestCase): def setUp(self): - if jax.device_count() < 2: - self.skipTest('Only >=2 devices are supported.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') super().setUp() - @parameterized.named_parameters( - ('vmem', pltpu.VMEM), - ('hbm', pltpu.ANY), + @parameterized.product( + no_pipelining=[False, True], ) - def test_pipeline_matmul(self, memory_space): - # TODO(b/358121809): Re-enable this test once the bug is fixed. 
- self.skipTest('Broken test.') + def test_pipeline_matmul(self, no_pipelining): k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -156,16 +151,17 @@ def matmul_kernel(x_ref, y_ref, z_ref): pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), ], out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), + no_pipelining=no_pipelining, )(x_ref, y_ref, z_ref) z = pl.pallas_call( matmul_kernel, out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=memory_space), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), ) jax.block_until_ready(z(x, y)) @@ -174,7 +170,7 @@ def matmul_kernel(x_ref, y_ref, z_ref): out = jax.block_until_ready(z(x, y)) expected_out = jax.block_until_ready(jnp.dot(x, y)) - np.testing.assert_allclose(out, expected_out) + np.testing.assert_allclose(out, expected_out, atol=5e-5) @parameterized.named_parameters( ('vmem', pltpu.VMEM), From 8e346ddb5325a37ce95bd9b3e3e3a2267944a20f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 13 Jun 2025 15:04:37 -0700 Subject: [PATCH 1688/1769] Replace `with_partitions` and `with_unreduced` with `.update` on Partitions PiperOrigin-RevId: 771238444 --- jax/_src/interpreters/batching.py | 2 +- jax/_src/lax/lax.py | 2 +- jax/_src/numpy/einsum.py | 2 +- jax/_src/partition_spec.py | 13 +++++-------- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index d97dd124a2fe..11f1102bd792 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -461,7 +461,7 @@ class AxisData: def get_sharding_for_vmap(axis_data, orig_sharding, axis): val = axis_data.explicit_mesh_axis # TODO(yashkatariya): Preserve unreduced here using - # `orig_sharding.spec.with_partitions` + # `orig_sharding.spec.update` new_spec = P(*tuple_insert(orig_sharding.spec, axis, val)) return NamedSharding(orig_sharding.mesh, new_spec) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 78da3ab47fe3..526a631a4121 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4595,7 +4595,7 @@ def _add_unreduced(out_sharding, x, y): f' reduce {lhs_str} via `reshard` before calling `add`.') else: res_unreduced = frozenset() - return out_sharding.with_spec(out_sharding.spec.with_unreduced(res_unreduced)) + return out_sharding.with_spec(out_sharding.spec.update(unreduced=res_unreduced)) add_p: Primitive = naryop(_input_dtype, [_num, _num], 'add', unreduced_rule=_add_unreduced) diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 64198d4f6fa0..0934cddc8a15 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -559,7 +559,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): spec = out_sharding.spec inverse_spec = tuple(spec[result_names.index(name)] for name in names) dot_general_out_sharding = NamedSharding( - out_sharding.mesh, spec.with_partitions(inverse_spec)) + out_sharding.mesh, spec.update(partitions=inverse_spec)) else: dot_general_out_sharding = out_sharding # type: ignore dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch)) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index b5bd2aecc20a..722253a8531c 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ 
-210,20 +210,17 @@ def index(self, value): def count(self, value): return self._partitions.count(_canonicalize_partition(value)) - def with_partitions(self, new_partitions): - return PartitionSpec(*new_partitions, unreduced=self.unreduced, - reduced=self.reduced) - - def with_unreduced(self, new_unreduced): - return PartitionSpec(*self._partitions, unreduced=new_unreduced, - reduced=self.reduced) + def update(self, **kwargs): + return PartitionSpec(*kwargs.pop("partitions", self._partitions), + unreduced=kwargs.pop("unreduced", self.unreduced), + reduced=kwargs.pop("reduced", self.reduced)) def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: out = [None if p is _UNCONSTRAINED_PARTITION else p for p in self._partitions] if len(out) < ndim: out.extend([None] * (ndim - len(out))) - return self.with_partitions(out) + return self.update(partitions=out) # TODO(phawkins): make this a decorator after the next jaxlib release. if not TYPE_CHECKING and jaxlib_extension_version >= 352: From 7904b86bbacdb72f7769a14b5e5d07bba05bd9e8 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 13 Jun 2025 15:24:16 -0700 Subject: [PATCH 1689/1769] Remove `with_spec` from NamedSharding and replace with `.update` PiperOrigin-RevId: 771244288 --- jax/_src/core.py | 12 ++++----- jax/_src/interpreters/batching.py | 2 +- jax/_src/lax/control_flow/loops.py | 8 +++--- jax/_src/lax/lax.py | 42 ++++++++++++++--------------- jax/_src/lax/linalg.py | 4 +-- jax/_src/lax/windowed_reductions.py | 4 +-- jax/_src/named_sharding.py | 12 ++++++--- jax/_src/numpy/einsum.py | 2 +- jax/_src/pallas/pallas_call.py | 2 +- jax/_src/pjit.py | 2 +- jax/_src/prng.py | 2 +- jax/_src/sharding_impls.py | 4 +-- jax/experimental/multihost_utils.py | 2 +- 13 files changed, 52 insertions(+), 46 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ff3417a2d883..9c66866a0876 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1938,10 +1938,10 @@ def get_cur_mesh_sharding(spec=None): def _make_lengths_same(sharding, ndim): pspec = sharding.spec if ndim > len(pspec): - return sharding.with_spec(pspec._normalized_spec_for_aval(ndim)) + return sharding.update(spec=pspec._normalized_spec_for_aval(ndim)) if ndim < len(pspec): assert all(s is None for s in pspec[ndim:]), (ndim, pspec) - return sharding.with_spec(P(*pspec[:ndim], unreduced=pspec.unreduced)) + return sharding.update(spec=P(*pspec[:ndim], unreduced=pspec.unreduced)) assert False, "unreachable" # TODO(yashkatariya): Only works with User/Auto. 
Generalize it to work with @@ -1965,7 +1965,7 @@ def _maybe_modify_sharding(sharding, ndim): elif sharding.mesh._are_all_axes_explicit: out = sharding else: - out = sharding.with_spec(modify_spec_for_auto_manual( + out = sharding.update(spec=modify_spec_for_auto_manual( sharding.spec, sharding.mesh)) if len(out.spec) != ndim: out = _make_lengths_same(out, ndim) @@ -2152,7 +2152,7 @@ def primal_dtype_to_tangent_dtype(primal_dtype): def primal_sharding_to_cotangent_sharding(sharding): new_spec = P(*sharding.spec, unreduced=sharding.spec.reduced, reduced=sharding.spec.unreduced) - return sharding.with_spec(new_spec) + return sharding.update(spec=new_spec) def pvary(x, axis_name): if not axis_name: @@ -2774,7 +2774,7 @@ def _map_shaped_array( # assert axis is None or aval.shape[axis] == size if axis is None: return aval - sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) + sharding = aval.sharding.update(spec=tuple_delete(aval.sharding.spec, axis)) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, weak_type=aval.weak_type, sharding=sharding, vma=aval.vma) @@ -2783,7 +2783,7 @@ def _unmap_shaped_array( ) -> ShapedArray: if axis is None: return aval elif type(axis) is int: - sharding = aval.sharding.with_spec(tuple_insert( + sharding = aval.sharding.update(spec=tuple_insert( aval.sharding.spec, axis, explicit_mesh_axis)) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, weak_type=aval.weak_type, sharding=sharding, diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 11f1102bd792..ddf0f6aa92fd 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -1127,7 +1127,7 @@ def broadcast(x, sz, axis, mesh_axis=None): if x_aval.sharding.mesh.empty: mesh_axis = None new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis)) - sharding = x_aval.sharding.with_spec(new_spec) + sharding = x_aval.sharding.update(spec=new_spec) # TODO(dougalm, yashkatariya): Delete this context manager once we figure # out how to ensure jaxpr arguments always have the context mesh. with mesh_lib.use_abstract_mesh(sharding.mesh): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 35fd26b21783..9416711d478b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -616,7 +616,7 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) def _empty_array(prefix, length_spec, aval): - sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec)) + sharding = aval.sharding.update(spec=(*length_spec, *aval.sharding.spec)) empty = core.pvary(lax.empty(aval.dtype), tuple(aval.vma)) return lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding) @@ -2165,7 +2165,7 @@ def fun(*args): primitive=None, avals_in=[pred_aval], avals_out=[pred_aval.update( - shape=(), sharding=pred_aval.sharding.with_spec(()))], + shape=(), sharding=pred_aval.sharding.update(spec=()))], tokens_in=mlir.TokenSet(), tokens_out=None) pred, = lax._unary_reduce_lower( @@ -2509,7 +2509,7 @@ def f(l): '0th dimension of leaf passed to `jax.lax.map` should be replicated.' 
f' Got {aval.str_short(True, True)}') if get_abstract_mesh()._are_all_axes_explicit: - out_s = aval.sharding.with_spec(P(None, None, *aval.sharding.spec[1:])) + out_s = aval.sharding.update(spec=P(None, None, *aval.sharding.spec[1:])) return auto_axes(f, out_sharding=out_s)(leaf) return f(leaf) @@ -2612,7 +2612,7 @@ def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, out_sharding=out_sharding), (None, None) keys = batching.moveaxis(keys, bd, 0) batch_size = keys.shape[0] - out_s = (out_sharding.with_spec((keys.aval.sharding.spec[0], *out_sharding.spec)) + out_s = (out_sharding.update(spec=(keys.aval.sharding.spec[0], *out_sharding.spec)) if out_sharding is not None else None) key = keys[0] new_key, bits = lax.rng_bit_generator_p.bind( diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 526a631a4121..d49d7cbd0f61 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -239,7 +239,7 @@ def broadcast_shardings(*avals): new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec) new_shape = (1,) * (ndim - a.ndim) + a.shape aval_list.append(a.update(shape=new_shape, - sharding=a.sharding.with_spec(new_spec))) + sharding=a.sharding.update(spec=new_spec))) return broadcasting_sharding_rule('broadcast_shardings', *aval_list) def _identity(x, **_): return x @@ -4595,7 +4595,7 @@ def _add_unreduced(out_sharding, x, y): f' reduce {lhs_str} via `reshard` before calling `add`.') else: res_unreduced = frozenset() - return out_sharding.with_spec(out_sharding.spec.update(unreduced=res_unreduced)) + return out_sharding.update(spec=out_sharding.spec.update(unreduced=res_unreduced)) add_p: Primitive = naryop(_input_dtype, [_num, _num], 'add', unreduced_rule=_add_unreduced) @@ -4995,7 +4995,7 @@ def _to_edtype_abstract_eval(x, *, edtype): f"{x.str_short(short_dtypes=True)} to an extended dtype with element " f"shape {rep_aval.shape}") return x.update(shape=shape_prefix, dtype=edtype, - sharding=x.sharding.with_spec(spec_prefix)) + sharding=x.sharding.update(spec=spec_prefix)) elif isinstance(x, core.DShapedArray): return x.update(shape=shape_prefix, dtype=edtype) else: @@ -5088,9 +5088,9 @@ def _bitcast_convert_type_sharding_rule(operand, *, new_dtype): if old_nbits == new_nbits: return operand.sharding elif old_nbits > new_nbits: - return operand.sharding.with_spec((*operand.sharding.spec, None)) + return operand.sharding.update(spec=(*operand.sharding.spec, None)) else: - return operand.sharding.with_spec(operand.sharding.spec[:-1]) + return operand.sharding.update(spec=operand.sharding.spec[:-1]) def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): old_dtype = dtypes.canonicalize_dtype(operand.dtype) @@ -5364,7 +5364,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, out_axes = np.argsort(unsorted_axes) xs = x.aval.sharding inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) - ds = xs.with_spec(inverse_spec) + ds = xs.update(spec=inverse_spec) dot_general_out = dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type, out_sharding=ds) @@ -6426,7 +6426,7 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, orig_spec = iter(operand.sharding.spec) new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))] assert next(orig_spec, None) is None - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _broadcast_in_dim_typecheck_rule( _, operand, *dyn_shape, shape, broadcast_dimensions, sharding): @@ -6764,7 
+6764,7 @@ def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): for op, bdim in zip(batched_args, batch_dims) if bdim is not None) operands = [batching.moveaxis(op, bdim, 0) if bdim is not None else broadcast( - op, (size,), out_sharding=core.get_aval(op).sharding.with_spec( + op, (size,), out_sharding=core.get_aval(op).sharding.update(spec= (spec, *core.get_aval(op).sharding.spec))) for op, bdim in zip(batched_args, batch_dims)] return concatenate(operands, dimension + 1), 0 @@ -6978,7 +6978,7 @@ def _squeeze_sharding_rule(operand, *, dimensions): dims_set = set(dimensions) new_spec = tuple(s for i, s in enumerate(operand.sharding.spec) if i not in dims_set) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) @@ -7131,7 +7131,7 @@ def _split_merge_singleton_dim_sharding_rule(operand, new_sizes): else: sp = next(fs) new_spec.append(sp) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _get_spec_size(sp, mesh): tup_sp = sp if isinstance(sp, tuple) else (sp,) @@ -7154,7 +7154,7 @@ def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions): else: new_spec.append(sp) assert len(new_spec) == len(new_sizes), (new_spec, new_sizes) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): @@ -7179,7 +7179,7 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): new_spec.append(next(op_spec)) assert next(op_spec, None) is None assert len(new_spec) == len(new_sizes), (new_spec, new_sizes) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions, @@ -7206,7 +7206,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): if dimensions is None: return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)] else: - t_s = operand.aval.sharding.with_spec( + t_s = operand.aval.sharding.update(spec= tuple(map(lambda s: s if s is None else str(s), np.take(operand.aval.sharding.spec, dimensions)))) return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), @@ -7307,7 +7307,7 @@ def _transpose_shape_rule(operand, *, permutation): def _transpose_sharding_rule(operand, *, permutation): o_spec = operand.sharding.spec new_spec = [o_spec[old_idx] for old_idx in permutation] - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _transpose_batch_rule(batched_args, batch_dims, *, permutation): operand, = batched_args @@ -7534,7 +7534,7 @@ def _reduce_shape_rule(*avals, computation, jaxpr, dimensions): def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions): operand_avals, _ = split_list(avals, [len(avals) // 2]) - return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions)) + return [op.sharding.update(spec=tuple_delete(op.sharding.spec, dimensions)) for op in operand_avals] def _reduce_vma_rule(*avals, computation, jaxpr, dimensions): @@ -7699,7 +7699,7 @@ def _reduce_op_sharding_rule(operand, *, axes): axes = frozenset(axes) new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec) if i not in axes)) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) reduce_sum_p = standard_primitive( _reduce_op_shape_rule, 
partial(_reduce_number_dtype_rule, 'reduce_sum'), @@ -7770,7 +7770,7 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype): def _argminmax_sharding_rule(operand, *, axes, index_dtype): axis, = axes - return operand.sharding.with_spec( + return operand.sharding.update(spec= util.tuple_delete(operand.sharding.spec, axis)) def _argminmax_dtype_rule(operand, *, axes, index_dtype): @@ -7846,7 +7846,7 @@ def _reduce_logical_shape_rule(operand, *, axes): return tuple(np.delete(operand.shape, axes)) def _reduce_logical_sharding_rule(operand, *, axes): - return operand.sharding.with_spec(tuple_delete(operand.sharding.spec, axes)) + return operand.sharding.update(spec=tuple_delete(operand.sharding.spec, axes)) reduce_or_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_or', @@ -8062,7 +8062,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): mlir.flatten_ir_values(operands), dimension=mlir.i64_attr(dimension), is_stable=ir.BoolAttr.get(is_stable)) - scalar_s = lambda a: a.sharding.with_spec(P()) + scalar_s = lambda a: a.sharding.update(spec=P()) scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval)) for aval in ctx.avals_in] scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals) @@ -8834,7 +8834,7 @@ def _const(example, val): def _zero(x): x_aval = core.get_aval(x) out = full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) + sharding=x_aval.sharding.update(spec=P())) out = core.pvary(out, tuple(x_aval.vma)) return out @@ -8843,7 +8843,7 @@ def _zero(x): def _one(x): x_aval = core.get_aval(x) out = full_like(x, shape=(), fill_value=1, - sharding=x_aval.sharding.with_spec(P())) + sharding=x_aval.sharding.update(spec=P())) out = core.pvary(out, tuple(x_aval.vma)) return out diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 9e4d188579bb..7f7b1056dbee 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -732,14 +732,14 @@ def linalg_sharding_rule( sharding = avals[0].sharding if multiple_results: return [ - sharding.with_spec( + sharding.update(spec= P(*(tuple(batch_spec) + (None,) * (len(s) - len(batch_spec)))) ) for s in output_shapes ] else: ndim = len(output_shapes) - len(batch_spec) - return sharding.with_spec(P(*(tuple(batch_spec) + (None,) * ndim))) + return sharding.update(spec=P(*(tuple(batch_spec) + (None,) * ndim))) def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): output_shapes = shape_rule(*avals, **kwargs) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 41ea90804d7b..e28677c99a72 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -630,7 +630,7 @@ def _reduce_window_lower( operand_aval, = ctx.avals_in scalar_aval = operand_aval.update( - shape=(), sharding=operand_aval.sharding.with_spec(())) + shape=(), sharding=operand_aval.sharding.update(spec=())) return mlir.reduce_window( ctx, @@ -687,7 +687,7 @@ def _select_and_scatter_lower( operand_aval, source_aval, init_value_aval = ctx.avals_in aval_out, = ctx.avals_out scalar_aval = operand_aval.update( - shape=(), sharding=operand_aval.sharding.with_spec(())) + shape=(), sharding=operand_aval.sharding.update(spec=())) scalar_type = mlir.aval_to_ir_type(scalar_aval) op = hlo.SelectAndScatterOp( mlir.aval_to_ir_type(aval_out), diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 60336a0822e6..bddf211cdc43 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -230,12 +230,18 
@@ def is_fully_replicated(self) -> bool: return num_partitions == 1 def with_memory_kind(self, kind: str) -> NamedSharding: - return NamedSharding(self.mesh, self.spec, memory_kind=kind) + return self.update(memory_kind=kind) - def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: + def update(self, **kwargs) -> NamedSharding: + spec = kwargs.pop("spec", self.spec) if not isinstance(spec, PartitionSpec): spec = PartitionSpec(*spec) - return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind) + return NamedSharding( + mesh=kwargs.pop("mesh", self.mesh), + spec=spec, + memory_kind=kwargs.pop("memory_kind", self.memory_kind), + _logical_device_ids=kwargs.pop("_logical_device_ids", + self._logical_device_ids)) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 0934cddc8a15..544e29020bf5 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -554,7 +554,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): dot_general_out_sharding = None elif out_sharding is not None and names != result_names: if len(result_names) > len(out_sharding.spec): - out_sharding = out_sharding.with_spec( + out_sharding = out_sharding.update(spec= out_sharding.spec._normalized_spec_for_aval(len(result_names))) spec = out_sharding.spec inverse_spec = tuple(spec[result_names.index(name)] for name in names) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 2375e1c30578..14fab37f01a5 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -910,7 +910,7 @@ def index_rewrite_kernel(*indexer_args): batched_out_avals = [] for aval in out_avals: - sharding = aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None)) + sharding = aval.sharding.update(spec=tuple_insert(aval.sharding.spec, 0, None)) shape = tuple_insert(aval.shape, 0, axis_size) batched_out_avals.append(aval.update(shape=shape, sharding=sharding)) batched_out_avals = tuple(batched_out_avals) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f1ba166230b0..1aa6a4d73daf 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2993,7 +2993,7 @@ def reshard(xs, out_shardings): 'Reshard should only be used with out_shardings which are non-None ' f'and have a nonempty mesh. Got sharding {s}.' ) - ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error + ds = ds.update(spec=ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error out_flat.append(reshard_p.bind(x, dst_sharding=ds)) return tree_unflatten(treedef, out_flat) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 08e81a76c1c0..47ce6f6a7ed9 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -597,7 +597,7 @@ def random_split_abstract_eval(keys_aval, *, shape): # don't choose None here? 
new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec), keys_aval.vma) + keys_aval.sharding.update(spec=new_spec), keys_aval.vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index d406faceabb7..7be94adc2665 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1185,7 +1185,7 @@ def make_key_array_phys_sharding(aval, sharding): elif isinstance(sharding, NamedSharding): elt_aval = core.physical_element_aval(aval.dtype) trailing_spec = [None] * elt_aval.ndim - return sharding.with_spec(PartitionSpec(*sharding.spec, *trailing_spec)) + return sharding.update(spec=PartitionSpec(*sharding.spec, *trailing_spec)) else: hlos = sharding._to_xla_hlo_sharding(aval.ndim) return GSPMDSharding( @@ -1248,7 +1248,7 @@ def logical_sharding(logical_shape, dtype, phys_sharding) -> jsharding.Sharding: *[None] * (len(phys_shape) - len(phys_sharding.spec))) else: phys_spec = phys_sharding.spec # type: ignore - return phys_sharding.with_spec(phys_spec[:-elt_aval.ndim]) + return phys_sharding.update(spec=phys_spec[:-elt_aval.ndim]) else: return get_logical_gspmd_sharding(logical_shape, dtype, phys_sharding) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 07a0f747443c..ee7c4a8f9592 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -100,7 +100,7 @@ def _identity_fn(x): def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: if isinstance(inp.sharding, sharding_impls.NamedSharding): - reps = inp.sharding.with_spec(P()) + reps = inp.sharding.update(spec=P()) else: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind) From f06888f9b3cf2adaae91bb644b7e09f241cce402 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 13 Jun 2025 18:42:25 -0700 Subject: [PATCH 1690/1769] [JAX] Fix the test names in colocated_python_test.py to follow the standard snake case PiperOrigin-RevId: 771296764 --- tests/colocated_python_test.py | 36 +++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 892414ee1366..c326aed1c0b7 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -58,7 +58,7 @@ def setUp(self): " requires NumPy 2.0.0 or later" ) - def testColocatedCpuDevices(self): + def test_colocated_cpu_devices(self): mesh = jax.sharding.Mesh( np.array(jax.local_devices()[:1]).reshape((1, 1)), ("x", "y") ) @@ -72,7 +72,7 @@ def testColocatedCpuDevices(self): ) self.assertEqual(cpu_mesh1, cpu_mesh2) - def testMakeColocatedPythonProgram(self): + def test_make_colocated_python_program(self): def add_one(x): return x + 1 @@ -86,7 +86,7 @@ def add_one(x): ) del program - def testSimpleFunction(self): + def test_simple_function(self): @colocated_python.colocated_python def add_one(x): return x + 1 @@ -106,7 +106,7 @@ def add_one(x): self.assertEqual(out, np.array(2)) self.assertEqual(count(), 1) - def testSimpleFunctionWithTree(self): + def test_simple_function_with_tree(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) @@ -126,7 +126,7 @@ def add_one(x): self.assertEqual(out, [np.array(2), (np.array(3),
{"v": np.array(4)})]) self.assertEqual(count(), 1) - def testEmptyInputFailsWithoutSpecialization(self): + def test_empty_input_fails_without_specialization(self): @colocated_python.colocated_python def make_zero(): return jnp.array(0) @@ -138,7 +138,7 @@ def make_zero(): ): _ = make_zero() - def testEmptyInputWithDevicesSpecialization(self): + def test_empty_input_with_devices_specialization(self): @colocated_python.colocated_python def make_zero(): return jnp.array(0) @@ -157,7 +157,7 @@ def make_zero(): self.assertEqual(out, np.array(0)) self.assertEqual(count(), 1) - def testInputPolymorphismWithoutOutSpecsFn(self): + def test_input_polymorphism_without_out_specs_fn(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) @@ -191,7 +191,7 @@ def add_one(x): self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count(), 2) - def testInputPolymorphismAllowedWithOutSpecsFn(self): + def test_input_polymorphism_allowed_with_out_specs_fn(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) @@ -230,7 +230,7 @@ def add_one(x): ("on_main_thread", True), ("on_non_main_thread", False), ) - def testSequentialExecution(self, on_main_thread: bool): + def test_sequential_execution(self, on_main_thread: bool): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -267,7 +267,7 @@ def sleep_twice_and_wait(x: jax.Array) -> None: # around 5 seconds. self.assertGreaterEqual(elapsed_time, 10) - def testConcurrentExecution(self): + def test_concurrent_execution(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -304,7 +304,7 @@ def sleep_and_wait(x: jax.Array) -> None: # around 15 seconds. self.assertLess(elapsed_time, 10) - def testInputsWithDifferentDeviceOrders(self): + def test_inputs_with_different_device_orders(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())[:2] if len(cpu_devices) < 2: self.skipTest("Not enough CPU devices") @@ -346,7 +346,7 @@ def add(x: jax.Array, y: jax.Array) -> jax.Array: out = jax.device_get(out) np.testing.assert_equal(out, np.array([2 + 4, 0 + 8])) - def testModuleVariableAccess(self): + def test_module_variable_access(self): try: # The following pattern of storing and accessing non-serialized state in # the Python module is discouraged for storing user-defined state. 
@@ -387,7 +387,7 @@ def get_global_state(x: jax.Array) -> jax.Array: if "_testing_global_state" in colocated_python.__dict__: del colocated_python._testing_global_state - def testStringProcessing(self): + def test_string_processing(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") @@ -428,7 +428,7 @@ def f(x): ), ) - def testBinaryDataProcessing(self): + def test_binary_data_processing(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 1: self.skipTest("Need at least one CPU devices") @@ -470,7 +470,7 @@ def f(x): self.assertEqual(out_ints[0], 1002) self.assertEqual(out_ints[1], 1003) - def testDetectInvalidMeshDevice(self): + def test_detect_invalid_mesh_device(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if jax.local_devices()[0].id == cpu_devices[0].id: self.skipTest( @@ -491,7 +491,7 @@ def make_zero() -> jax.Array: make_zero = make_zero.specialize(devices=cpu_devices) jax.block_until_ready(make_zero()) - def testObjectLifecycle(self): + def test_object_lifecycle(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) @@ -563,7 +563,7 @@ def cleanup(): finally: cleanup() - def testStatefulObject(self): + def test_stateful_object(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) @colocated_python.colocated_python_class @@ -595,7 +595,7 @@ def fetch(self, x: jax.Array) -> jax.Array: out = jax.device_get(value.fetch(x)) self.assertEqual(out, np.array(7)) - def testObjectWithCapturedSharding(self): + def test_object_with_captured_sharding(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") From e28d6ed99c7fbff50f8f2d31f0d0993e8ce6f0c1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 14 Jun 2025 00:30:50 -0700 Subject: [PATCH 1691/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b92dc74a1de63a445456cfdeb4a9c8d4d0ba63e5. PiperOrigin-RevId: 771371907 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 122f6c0c3d47..6855092970e0 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f67ae6dc96dacaffe4d3a9b50de3dbe3ca89fffd" -XLA_SHA256 = "af2cfc63a5be306b95c8f7f55dd74d1fcfbd589b69c68d060d641f4d233fc715" +XLA_COMMIT = "b92dc74a1de63a445456cfdeb4a9c8d4d0ba63e5" +XLA_SHA256 = "73afaf69613b184b93b676751495f7774dd5f99e6bbedd4a084df495ec221ae3" def repo(): tf_http_archive( From d065d2ace718ad0e89101ddc2c3ee40fa29fea2e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 14 Jun 2025 06:12:12 -0700 Subject: [PATCH 1692/1769] Make mosaic_gpu equation params hashable. 
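The fix follows one pattern: frozen dataclasses hash by field value, so any sequence- or dict-typed field is canonicalized to a hashable container (tuple or FrozenDict) in __post_init__. A minimal illustration of the idea, using a toy class rather than code from this change:

    import dataclasses
    from collections.abc import Sequence

    @dataclasses.dataclass(frozen=True)
    class EqnParams:
      sizes: Sequence[int]

      def __post_init__(self):
        # Frozen dataclasses forbid plain attribute assignment, hence
        # object.__setattr__; converting lists to tuples makes hash() work.
        object.__setattr__(self, "sizes", tuple(self.sizes))

    assert hash(EqnParams([1, 2])) == hash(EqnParams((1, 2)))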
PiperOrigin-RevId: 771434581 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/core.py | 8 ++++++++ jax/_src/pallas/mosaic_gpu/primitives.py | 9 +++++++-- jax/experimental/mosaic/gpu/core.py | 4 ++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 6c5320e9e84a..bcb0015d71b2 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -102,6 +102,7 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", + "//jax:frozen_dict", "//jax:mosaic_gpu", "//jax:pretty_printer", "//jax:tree_util", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 7c27ae429114..18fc924e8117 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -113,6 +113,10 @@ class CompilerParams(pallas_core.CompilerParams): lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane def __post_init__(self): + if self.dimension_semantics is not None: + object.__setattr__( + self, "dimension_semantics", tuple(self.dimension_semantics) + ) if bool(self.profile_space) ^ bool(self.profile_dir): raise ValueError( "Either both profile_space and profile_dir must be set, or neither." @@ -988,6 +992,10 @@ def __post_init__(self): "Requested too many CUDA threads per block. Each Mosaic thread" " corresponds to 128 CUDA threads." ) + object.__setattr__(self, "grid", tuple(self.grid)) + object.__setattr__(self, "grid_names", tuple(self.grid_names)) + object.__setattr__(self, "cluster", tuple(self.cluster)) + object.__setattr__(self, "cluster_names", tuple(self.cluster_names)) @property def backend(self) -> str: diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 0c55eaefc345..d31c7fb85bd8 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -26,6 +26,7 @@ import jax from jax._src import core as jax_core +from jax._src import frozen_dict from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util @@ -1558,6 +1559,10 @@ class ParameterizedLayout: args: Sequence[Any] kwargs: Any + def __post_init__(self): + object.__setattr__(self, "args", tuple(self.args)) + object.__setattr__(self, "kwargs", frozen_dict.FrozenDict(self.kwargs)) + def to_mgpu(self) -> mgpu.FragmentedLayout: return self.layout_cls.to_mgpu(*self.args, **self.kwargs) @@ -1963,8 +1968,8 @@ def wrapper(*args): flat_ret = inline_mgpu_p.bind( *raw_flat_args, *flat_ref_transforms, - flat_arg_types=flat_arg_types, - flat_ret_ty=flat_ret_ty, + flat_arg_types=tuple(flat_arg_types), + flat_ret_ty=tuple(flat_ret_ty), pytree_ret_ty=pytree_ret_ty, pytree_args=treedef, pytree_ref_transforms=pytree_ref_transforms, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 2d17d4f1857a..6e664f626e4d 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -740,6 +740,10 @@ def as_gpu_kernel( elif not isinstance(inout_shape, tuple): inout_shape = (inout_shape,) + inout_shape = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), + inout_shape) + out_shape = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), + out_shape) module, out_shape, unwrap_output_tuple, launch_ctx = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, inout_shape, From 9678a764e83b5fab3a877348c6ccc503d9808145 Mon Sep 17 00:00:00 2001 
From: jax authors Date: Sat, 14 Jun 2025 23:56:01 -0700 Subject: [PATCH 1693/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/59d2fd81f3f7da30aee522686999d94912c75ec0. PiperOrigin-RevId: 771620249 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6855092970e0..2dc51fff9913 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b92dc74a1de63a445456cfdeb4a9c8d4d0ba63e5" -XLA_SHA256 = "73afaf69613b184b93b676751495f7774dd5f99e6bbedd4a084df495ec221ae3" +XLA_COMMIT = "59d2fd81f3f7da30aee522686999d94912c75ec0" +XLA_SHA256 = "9d82399916399ce7de40ab278a6cd5b77d6e74604d3222c9ab582dc0bedc7558" def repo(): tf_http_archive( From b25655cfe477b43e4c28426e8f05e5e0cb435485 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 15 Jun 2025 23:32:19 -0700 Subject: [PATCH 1694/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/71e3a0d40d7f368b239dc664e3bb0a455a80541b. PiperOrigin-RevId: 771882677 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 2dc51fff9913..ba4d13901b39 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "59d2fd81f3f7da30aee522686999d94912c75ec0" -XLA_SHA256 = "9d82399916399ce7de40ab278a6cd5b77d6e74604d3222c9ab582dc0bedc7558" +XLA_COMMIT = "71e3a0d40d7f368b239dc664e3bb0a455a80541b" +XLA_SHA256 = "5729c77cc32c5513f18802dd46dcf4c5b0457a0fb5ad29d1c543734d23afdaba" def repo(): tf_http_archive( From 639216fa37ee483cd648717125b6a3c2cbed02ce Mon Sep 17 00:00:00 2001 From: Shanbin Ke Date: Mon, 16 Jun 2025 05:17:36 -0700 Subject: [PATCH 1695/1769] PR #28102: Add cudnn paged attention support in JAX cuDNN SDPA API Imported from GitHub PR https://github.com/jax-ml/jax/pull/28102 * add cudnn support for paged attention described in https://arxiv.org/pdf/2309.06180. * add new arguments `page_table_k` and `page_table_v`. * create a new interface `paged_attention` for paged attention. 
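To make the page-table mechanics concrete, here is a simplified sketch of how block indices map logical KV positions onto a paged cache (a hypothetical helper using a 2-D table for clarity; the API added here takes 4-D `page_table_k`/`page_table_v` tensors):

    import numpy as np

    def gather_kv(paged_kv, page_table):
      # paged_kv:   [num_blocks, block_size, num_heads, head_dim]
      # page_table: [batch, blocks_per_batch], entries index into paged_kv
      batch, blocks_per_batch = page_table.shape
      blocks = paged_kv[page_table.reshape(-1)]
      return blocks.reshape(
          batch, blocks_per_batch * paged_kv.shape[1], *paged_kv.shape[2:])

    kv = gather_kv(np.zeros((8, 64, 4, 128)), np.array([[0, 2], [1, 3]]))
    assert kv.shape == (2, 128, 4, 128)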
Copybara import of the project: -- 003d8b119383311e6cf623316e9eb786641e2c65 by cjkkkk : add paged attn Merging this change closes #28102 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/28102 from Cjkkkk:page_attention 003d8b119383311e6cf623316e9eb786641e2c65 PiperOrigin-RevId: 771980902 --- jax/_src/cudnn/fused_attention_stablehlo.py | 266 +++++++++++++++----- tests/fused_attention_stablehlo_test.py | 65 +++++ 2 files changed, 264 insertions(+), 67 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index c7ae830e92fd..417198e575d2 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -121,6 +121,9 @@ def default_layouts(*shapes): def get_max_seg_per_batch(q_offsets): return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1 +def check_is_paged_attention(page_table_k): + return len(page_table_k.shape) == 4 + def create_dot_product_attention_backend_config_base( batch, num_heads, seq_q, seq_kv, dtype, fmha_scale, mask_type, layout, is_bwd ): @@ -228,6 +231,7 @@ def create_dot_product_attention_backend_config( layout, sliding_window_length, max_seg_per_batch, + is_paged_attention, is_bwd ): backend_config = create_dot_product_attention_backend_config_base( @@ -240,6 +244,7 @@ def create_dot_product_attention_backend_config( backend_config['cudnn_fmha_backend_config']["seed"] = seed backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length backend_config['cudnn_fmha_backend_config']["max_seg_per_batch"] = max_seg_per_batch + backend_config['cudnn_fmha_backend_config']["is_paged_attention"] = is_paged_attention return json.dumps(backend_config) def create_dot_product_attention_fp8_backend_config( @@ -272,7 +277,7 @@ def get_custom_call_name(has_bias, has_dropout, is_bwd, is_fp8=False): ) def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, - q_offsets, kv_offsets, layout): + q_offsets, kv_offsets, page_table_k, page_table_v, layout): def check_eq(a, b, c, msg): if not (a == b == c): raise ValueError(f"{msg} must be same, got {a}, {b}, {b}") @@ -297,6 +302,22 @@ def check_eq(a, b, c, msg): kB, kS, kN, kH = key.shape vB, vS, vN, vH = value.shape + if page_table_k is not None and page_table_v is not None: + k_blocks, k_block_size = kB, kS + v_blocks, v_block_size = vB, vS + kB, _, k_blocks_per_batch, _ = page_table_k.shape + vB, _, v_blocks_per_batch, _ = page_table_v.shape + kS = k_blocks_per_batch * k_block_size + vS = v_blocks_per_batch * v_block_size + if kB * k_blocks_per_batch != k_blocks: + raise ValueError( + f"Key and page_table_k must have same number of blocks, " + f"got {k_blocks} vs {kB * k_blocks_per_batch}") + if vB * v_blocks_per_batch != v_blocks: + raise ValueError( + f"Value and page_table_v must have same number of blocks, " + f"got {v_blocks} vs {vB * v_blocks_per_batch}") + check_eq(qB, kB, vB, "QKV batch") check_eq(qH, kH, vH, "QKV dim_per_head") if kN != vN: @@ -333,7 +354,7 @@ def check_seqlen_offsets(tensor, name): def check_is_flash_attention( query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False, - is_fp8=False): + is_paged_attention=False, is_fp8=False): # Extract sequence length (T) and head dim (H) based on layout if layout == AttentionLayout.BNTH.value: _, _, T, H = query.shape @@ -370,6 +391,8 @@ def check_is_flash_attention( if is_packed and (cudnn_version < 90600 or not check_compute_capability("9.0")): raise NotImplementedError( "Packed layout requires 
cudnn version >= 9.6 and at least hopper arch.") + if is_paged_attention and cudnn_version < 90500: + raise NotImplementedError("Paged attention requires cudnn version >= 9.5.") def check_cudnn_version(): # check if cuDNN is installed @@ -395,15 +418,16 @@ def is_cuda_compute_capability_equal(capability): def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, False, - get_max_seg_per_batch(q_offsets) > 1) + get_max_seg_per_batch(q_offsets) > 1, check_is_paged_attention(page_table_k)) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=False or return_residual) if return_residual: @@ -413,19 +437,20 @@ def _dot_product_attention_fwd( def _dot_product_attention_fwd_rule( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version, return_residual): + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, cudnn_version, + return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, True, get_max_seg_per_batch(q_offsets) > 1) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=True) res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, - kv_offsets, outputs[1], outputs[0]) + kv_offsets, page_table_k, page_table_v, outputs[1], outputs[0]) if return_residual: return tuple(outputs), res else: @@ -435,17 +460,17 @@ def _dot_product_attention_bwd_rule( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, is_training, return_residual, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output) = res + page_table_k, page_table_v, activation, fwd_output) = res if return_residual: grad_output = grad_output[0] grads = _dot_product_attention_bwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, - dropout_rate=dropout_rate, variadic_args=variadic_args, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length ) - grads = (*grads,) + (None,) * (8 - len(grads)) + grads = (*grads,) + (None,) * (10 - len(grads)) return grads def _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key): @@
-508,27 +533,28 @@ def _cu_offset(offsets, max_seq): def _dot_product_attention_fwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training): + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, is_training): # args: {Q, K, V, mask*, bias*} q_seqlen, kv_seqlen, q_offsets, kv_offsets = \ _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key) outputs = _dot_product_attention_fwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=is_training) return outputs def _dot_product_attention_bwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, scale, + seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): q_seqlen, kv_seqlen, q_offsets, kv_offsets = \ _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key) grads = _dot_product_attention_bwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length) @@ -536,8 +562,8 @@ def _dot_product_attention_bwd_impl( def _dot_product_attention_fwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - *, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training): + page_table_k, page_table_v, *, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, is_training): query_dtype = dtypes.canonicalize_dtype(query.dtype) if layout == AttentionLayout.BNTH.value: B, N, T, _ = query.shape @@ -562,8 +588,8 @@ def _dot_product_attention_fwd_abstract( def _dot_product_attention_bwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, *, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, *, + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): query_dtype = dtypes.canonicalize_dtype(query.dtype) key_dtype = dtypes.canonicalize_dtype(key.dtype) value_dtype = dtypes.canonicalize_dtype(value.dtype) @@ -601,8 +627,8 @@ def _dot_product_attention_bwd_abstract( def _dot_product_attention_fwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, - kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, - layout, sliding_window_length, is_training): + kv_offsets, page_table_k, page_table_v, scale, seed, dropout_rate, + variadic_args, mask_type, layout, sliding_window_length, is_training): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) @@ -620,6 +646,8 
@@ def _dot_product_attention_fwd_cuda_lowering( output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type)) + is_paged_attention = check_is_paged_attention(ir.RankedTensorType(page_table_k.type)) + output_shape = (B, N, T, H) softmax_stat_shape = (B * max_seg_per_batch, N, T) workspace_shape = (0,) @@ -629,19 +657,22 @@ def _dot_product_attention_fwd_cuda_lowering( backend_config = create_dot_product_attention_backend_config( B, N, T, S, query_type.element_type, scale, seed, dropout_rate, mask_type, layout, sliding_window_length, max_seg_per_batch, - is_bwd=False) + is_paged_attention, is_bwd=False) # {Q, K, V, bias*, q_seqlen*, kv_seqlen*, q_offsets*, kv_offsets*}} # {output, activation*, workspace} has_dropout = dropout_rate > 0 operands = [query, key, value] if has_bias: operands.append(bias) - if has_padding(mask_type) or max_seg_per_batch > 1: + if has_padding(mask_type) or max_seg_per_batch > 1 or is_paged_attention: operands.append(q_seqlen) operands.append(kv_seqlen) if max_seg_per_batch > 1: operands.append(q_offsets) operands.append(kv_offsets) + if is_paged_attention: + operands.append(page_table_k) + operands.append(page_table_v) custom_call_name = get_custom_call_name(has_bias, has_dropout, False) @@ -677,8 +708,8 @@ def _dot_product_attention_fwd_cuda_lowering( def _dot_product_attention_bwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) @@ -708,7 +739,7 @@ def _dot_product_attention_bwd_cuda_lowering( backend_config = create_dot_product_attention_backend_config( B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate, mask_type, layout, sliding_window_length, max_seg_per_batch, - is_bwd=True) + False, is_bwd=True) # {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*, # q_offsets*, kv_offsets*} # {dQ, dK, dV, dbias*, workspace} @@ -776,7 +807,7 @@ def _dot_product_attention_fwd_batcher( mask_type, layout, sliding_window_length, is_training): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen, \ - q_offsets, kv_offsets = batched_args + q_offsets, kv_offsets, page_table_k, page_table_v = batched_args query_bdim = batch_dims[0] if is_training: out_bdims = query_bdim, query_bdim @@ -804,7 +835,7 @@ def _dot_product_attention_fwd_batcher( outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=is_training) @@ -823,7 +854,7 @@ def _dot_product_attention_bwd_batcher( mask_type, layout, sliding_window_length): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, \ - activation, fwd_output, grad_output = batched_args + page_table_k, page_table_v, activation, fwd_output, grad_output = batched_args query_bdim = batch_dims[0] out_bdims = query_bdim, query_bdim, 
query_bdim @@ -860,8 +891,8 @@ def _dot_product_attention_bwd_batcher( grads = _dot_product_attention_bwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, - dropout_rate=dropout_rate, variadic_args=variadic_args, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, ) @@ -936,7 +967,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layo return [out_sharding] _dot_product_attention_fwd_lower = custom_partitioning( - _dot_product_attention_fwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) + _dot_product_attention_fwd_impl, static_argnums=(10, 11, 12, 13, 14, 15, 16, 17)) def _dot_product_attention_fwd_infer_sharding_from_operands( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, @@ -985,7 +1016,7 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args): return out_shardings _dot_product_attention_bwd_lower = custom_partitioning( - _dot_product_attention_bwd_impl, static_argnums=(11, 12, 13, 14, 15, 16, 17) + _dot_product_attention_bwd_impl, static_argnums=(13, 14, 15, 16, 17, 18, 19) ) def _dot_product_attention_bwd_infer_sharding_from_operands( @@ -1110,7 +1141,7 @@ def not_implemented_sharding_rule(*args, **kwargs): _dot_product_attention_bwd_p_wrapper ) -@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16)) +@functools.partial(jax.custom_vjp, nondiff_argnums=(10, 11, 12, 13, 14, 15, 16, 17, 18)) def _dot_product_attention(query: Array, key: Array, value: Array, @@ -1119,6 +1150,8 @@ def _dot_product_attention(query: Array, kv_seqlen: Array, q_offsets: Array, kv_offsets: Array, + page_table_k: Array, + page_table_v: Array, scale: float, seed: int, dropout_rate: float, @@ -1130,7 +1163,7 @@ def _dot_product_attention(query: Array, return_residual: bool): output = _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, cudnn_version=cudnn_version, return_residual=return_residual) @@ -1714,7 +1747,119 @@ def _dot_product_attention_fp8(query: Array, _dot_product_attention_fp8.defvjp(_dot_product_attention_fp8_fwd_rule, _dot_product_attention_fp8_bwd_rule) +def combine_bias_and_mask(bias, mask, dtype): + if bias is not None: + # reshape bias to have 4D shape + bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) + + if mask is not None: + if mask.dtype == jnp.bool: + large_negative_number = get_large_negative_number(dtype) + mask = jnp.where(mask, jnp.asarray(0, dtype), large_negative_number) + # reshape mask to have 4D shape + mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] + + # combine bias and mask + if bias is None: + bias = mask + else: + if mask is not None: + # should be broadcast to same shape + bias = bias + mask + return bias + # User interface +def paged_attention( + query: Array, + key: Array, + value: Array, + q_seqlen: Array, + kv_seqlen: Array, + page_table_k: Array, + page_table_v: Array, + bias: Array | None = None, + mask: Array | None = None, + 
fp8_params: FP8Params | None = None, + *, + scale: float = 1.0, + mask_type: MaskType = MaskType.NO_MASK, + seed: int = 42, + dropout_rate: float = 0., + qkv_layout: str = "BTNH", + sliding_window_length: int | None = None, + use_fp8: bool = False, + return_residual: bool = False +): + """Computes paged attention described in https://arxiv.org/pdf/2309.06180. + + B = batch size + S = length of the key/value (source) + T = length of the query (target) + N = number of attention heads + H = dimensions of each attention head. + + Args: + query: Queries for attention calculation with a shape of BTNH or BNTH. + key: Keys for attention calculation with a shape of + [num_blocks, block_size, N, H] or [num_blocks, N, block_size, H] where + num_blocks = B * Ceil(S / block_size). + value: Values to be used in attention with a shape of + [num_blocks, block_size, N, H] or [num_blocks, N, block_size, H] where + num_blocks = B * Ceil(S / block_size). + q_seqlen: Non padded sequence length of query with a shape of B. + kv_seqlen: Non padded sequence length of key and value with a shape of B. + page_table_k: page table for key of shape [B, 1, num_blocks_per_batch, 1] + where num_blocks_per_batch = Ceil(S / block_size). + page_table_v: page table for value of shape [B, 1, num_blocks_per_batch, 1] + where num_blocks_per_batch = Ceil(S / block_size). + bias: Bias to be added to logits with a shape of BNTS. + mask: Mask used to filter out logits with a shape of BNTS. + scale: Scale for the query. + qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH, + BNSH. + sliding_window_length: Window size to make attention only attend to each + token's left local window (pos - sliding_window_length, pos] where `pos` + is the index of each token. E.g., if sliding_window_length == 3 and the + sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. + use_fp8: Whether to use FP8 attention mechanism. + return_residual: Whether to return the logsumexp tensor of shape BTN + or BNT to users. See section 3.1.1 in the FlashAttention-2 paper: + https://arxiv.org/pdf/2307.08691 to find the definition of logsumexp. + Returns: + output: the same shape as the query. + residual: the logsumexp tensor if return_residual=True. 
(non fp8) + """ + cudnn_version = check_cudnn_version() + layout = _normalize_layout(qkv_layout) + if use_fp8: + raise ValueError("Paged attention doesn't support fp8 for now.") + if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): + raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask.") + if sliding_window_length is not None and sliding_window_length <= 0: + raise ValueError( + f"Require sliding_window_length > 0, got {sliding_window_length}.") + + bias = combine_bias_and_mask(bias, mask, query.dtype) + # check if input shape and data type is compatible + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, None, None, + page_table_k, page_table_v, layout) + has_bias = bias is not None + has_dbias = has_bias and \ + should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + variadic_args = (has_bias, has_dbias) + + _not_used = jnp.zeros(0, dtype=query.dtype) + if bias is None: + bias = _not_used + + output = _dot_product_attention( + query, key, value, bias, q_seqlen, kv_seqlen, _not_used, _not_used, + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout.value, sliding_window_length, cudnn_version, + return_residual) + return output + + def dot_product_attention( query: Array, key: Array, @@ -1815,7 +1960,8 @@ def dot_product_attention( f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}" ) check_fp8_params(fp8_params) - check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout) + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + None, None, layout) output, amax_s, amax_o = _dot_product_attention_fp8( query, key, value, fp8_params, scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version ) @@ -1830,44 +1976,30 @@ def dot_product_attention( if q_offsets is not None and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to use packed layout") - if bias is not None: - # reshape bias to have 4D shape - bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) - - if mask is not None: - if mask.dtype == jnp.bool: - large_negative_number = get_large_negative_number(query.dtype) - mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) - # reshape mask to have 4D shape - mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] - - # combine bias and mask - if bias is None: - bias = mask - else: - if mask is not None: - # should be broadcast to same shape - bias = bias + mask - - # check if input shape and data type is compatible - check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout) + bias = combine_bias_and_mask(bias, mask, query.dtype) + # check if input shape and data type is compatible + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + None, None, layout) has_bias = bias is not None has_dbias = has_bias and \ should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] variadic_args = (has_bias, has_dbias) + _not_used = jnp.zeros(0, dtype=query.dtype) if bias is None: - bias = jnp.zeros(0, dtype=query.dtype) + bias = _not_used if q_seqlen is None: - q_seqlen = jnp.zeros(0, dtype=query.dtype) + q_seqlen = _not_used if kv_seqlen is None: - kv_seqlen = jnp.zeros(0, dtype=query.dtype) + kv_seqlen = _not_used if q_offsets is None: - q_offsets = jnp.zeros(0, dtype=query.dtype) + q_offsets = _not_used + if 
kv_offsets is None: - kv_offsets = jnp.zeros(0, dtype=query.dtype) + kv_offsets = _not_used + output = _dot_product_attention( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout.value, - sliding_window_length, cudnn_version, return_residual) + _not_used, _not_used, scale, seed, dropout_rate, variadic_args, + mask_type, layout.value, sliding_window_length, cudnn_version, + return_residual) return output diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 64e0f4377462..7c3bd145dab5 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -24,6 +24,7 @@ from jax._src import test_util as jtu from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention, + paged_attention, check_is_flash_attention, check_cudnn_version, MaskType, @@ -788,6 +789,70 @@ def train(query, key, value, grads): outs = jitted_sdpa_train(query, key, value, (grad, grad_stat)) assert len(outs) == 2 + @jtu.sample_product( + batch_size=[4], + q_seq_len=[1, 1024], + kv_seq_len=[1024], + num_heads=[8], + head_dim=[64, 128], + block_size=[64, 128], + dtype=[jnp.float16, jnp.bfloat16] + ) + @jtu.run_on_devices("cuda") + def test_sdpa_paged_attention(self, batch_size, q_seq_len, kv_seq_len, + num_heads, head_dim, block_size, dtype): + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 90500: + self.skipTest("Requires >= cuDNN 9.5.0") + + keys = jax.random.split(jax.random.key(0), 5) + blocks_per_batch = kv_seq_len // block_size + num_blocks = batch_size * blocks_per_batch + + # different q_seq_len for prefill and decode + q = jax.random.normal( + keys[0], (batch_size, q_seq_len, num_heads, head_dim), dtype=dtype) + k_container = jax.random.normal( + keys[1], (num_blocks, block_size, num_heads, head_dim), dtype=dtype) + v_container = jax.random.normal( + keys[2], (num_blocks, block_size, num_heads, head_dim), dtype=dtype) + page_table_k = jax.random.randint( + keys[3], (batch_size, 1, blocks_per_batch, 1), 0, num_blocks-1, dtype=jnp.int32) + page_table_v = jax.random.randint( + keys[4], (batch_size, 1, blocks_per_batch, 1), 0, num_blocks-1, dtype=jnp.int32) + # full page table + q_seqlen = jnp.full((batch_size,), q_seq_len, jnp.int32) + kv_seqlen = jnp.full((batch_size,), kv_seq_len, jnp.int32) + + def unpaged(paged, page_table): + output = jnp.zeros((batch_size, kv_seq_len, num_heads, head_dim), dtype=dtype) + for b in range(batch_size): + for block in range(blocks_per_batch): + block_idx = page_table[b, 0, block, 0] + output = output.at[ + b, block * block_size : (block + 1) * block_size, :, : + ].set(paged[block_idx, :, :, :]) + return output + + k = unpaged(k_container, page_table_k) + v = unpaged(v_container, page_table_v) + + sdpa_infer = jax.jit(partial( + paged_attention, scale=1.0, mask_type=MaskType.NO_MASK) + ) + sdpa_infer_ref = jax.jit(partial( + sdpa_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0) + ) + + out = sdpa_infer(q, k_container, v_container, q_seqlen=q_seqlen, + kv_seqlen=kv_seqlen, page_table_k=page_table_k, page_table_v=page_table_v) + out_ref = sdpa_infer_ref(q, k, v) + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + @jtu.run_on_devices("cuda") def test_layouts(self): if jax.device_count() < 4: From 53849a56b20f38644999d68d218ef06070cb75be Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 16 Jun 2025 07:56:40 
-0700 Subject: [PATCH 1696/1769] Add `update_vma` and `update_weak_type` overrides on AbstractTMEMRef so that they return an instance of AbstractTMEMRef and not AbstractMemoryRef PiperOrigin-RevId: 772026526 --- jax/_src/pallas/mosaic_gpu/core.py | 10 ++++++++++ jax/_src/pallas/pallas_call.py | 5 +++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 18fc924e8117..9718a8dd749d 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -958,6 +958,16 @@ def __init__(self, inner_aval, memory_space, packed, collective): def __repr__(self) -> str: return f'TMEM({self.inner_aval.str_short()},packed={self.packed})' + def update_vma(self, vma): + return AbstractTMEMRef( + self.inner_aval.update_vma(vma), self.memory_space, self.packed, + self.collective) + + def update_weak_type(self, weak_type): + return AbstractTMEMRef( + self.inner_aval.update_weak_type(weak_type), self.memory_space, + self.packed, self.collective) + _WARPGROUP_AXIS_NAME = object() diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 14fab37f01a5..d95b0ebd0ddb 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1699,8 +1699,9 @@ def wrapped(*args): x.ref if isinstance(x, state_types.TransformedRef) else x for x in flat_kernel_args ) - flat_kernel_avals = tuple(a.update_vma(frozenset()) - for a in flat_kernel_avals) + if config._check_vma.value: + flat_kernel_avals = tuple(a.update_vma(frozenset()) + for a in flat_kernel_avals) # Note that only a subset of all transforms can be found here, and they are # never expected to contain any arrays. kernel_arg_transforms = tuple( From 1362f7f517c52a9eee963f62097c23fde11fdf35 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 16 Jun 2025 07:56:51 -0700 Subject: [PATCH 1697/1769] Add version guards to testAutoPgle PiperOrigin-RevId: 772026590 --- tests/pgle_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index e03e5127d023..35764daa7793 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -30,6 +30,7 @@ from jax._src import pjit from jax._src import profiler from jax._src import test_util as jtu +from jax._src.lib import jaxlib_extension_version from jax.experimental import profiler as exp_profiler from jax.experimental.serialize_executable import ( deserialize_and_load, @@ -133,6 +134,8 @@ def get_fdo_profiles(self, dump_dir): return jit_f_fdo_profiles def testAutoPgle(self): + if jaxlib_extension_version < 354: + self.skipTest('Requires jaxlib_extension_version >= 354') mesh = jtu.create_mesh((2,), ('x',)) with tempfile.TemporaryDirectory() as dump_dir: From 173574a9cd1cfe31652a53ffed0fa95c72ad0f25 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Mon, 16 Jun 2025 09:03:38 -0700 Subject: [PATCH 1698/1769] Bump the libtpu check to 6/20 PiperOrigin-RevId: 772049204 --- tests/pallas/tpu_pallas_async_test.py | 8 ++++---- tests/pallas/tpu_pallas_test.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index f4f06c178d8d..f224987fe2a7 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -205,8 +205,8 @@ def setUp(self): if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work on TPU v4+') # TODO(subhankarshah): Remove after all required changes are in. 
- if not jtu.if_cloud_tpu_at_least(2025, 6, 15): - self.skipTest('Requires libtpu built after 2025-06-15') + if not jtu.if_cloud_tpu_at_least(2025, 6, 20): + self.skipTest('Requires libtpu built after 2025-06-20') def test_basic_async_copy(self): @jax.jit @@ -834,8 +834,8 @@ def setUp(self): if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work on TPU v4+') # TODO(subhankarshah): Remove after all required changes are in. - if not jtu.if_cloud_tpu_at_least(2025, 6, 15): - self.skipTest('Requires libtpu built after 2025-06-15') + if not jtu.if_cloud_tpu_at_least(2025, 6, 20): + self.skipTest('Requires libtpu built after 2025-06-20') def test_basic_stateful_async_copy(self): @jax.jit diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index ec61ca61f1f2..9a85c869a190 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1225,8 +1225,8 @@ def test_output_dma_semaphore_ref(self): self.skipTest('TODO(sharadmv, justinfu): Add interpret support for DMA.') # TODO(subhankarshah): Remove after all required changes are in. - if not jtu.if_cloud_tpu_at_least(2025, 6, 15): - self.skipTest('Requires libtpu built after 2025-06-15') + if not jtu.if_cloud_tpu_at_least(2025, 6, 20): + self.skipTest('Requires libtpu built after 2025-06-20') def kernel(x_hbm_ref, y_hbm_ref, sem_out): pltpu.make_async_copy( From 0e7c96a54a98d0f2b7e1c29bd7cc61b1b9bcbf59 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 16 Jun 2025 09:19:27 -0700 Subject: [PATCH 1699/1769] Removed unused `PyTreeDef::MakeFromNodeDataAndChildren` and its Python binding PiperOrigin-RevId: 772055206 --- jaxlib/_jax/pytree.pyi | 6 ----- jaxlib/pytree.cc | 59 ------------------------------------ jaxlib/pytree.h | 5 ---- jaxlib/pytree_test.py | 36 +++----------------- 4 files changed, 4 insertions(+), 102 deletions(-) diff --git a/jaxlib/_jax/pytree.pyi b/jaxlib/_jax/pytree.pyi index 0cec90db631c..2a33203abb1d 100644 --- a/jaxlib/_jax/pytree.pyi +++ b/jaxlib/_jax/pytree.pyi @@ -121,12 +121,6 @@ class PyTreeDef: def from_iterable_tree(self, __xs: Any): ... def node_data(self) -> Tuple[type, Any] | None: ... def children(self) -> list[PyTreeDef]: ... - @staticmethod - def make_from_node_data_and_children( - registry: PyTreeRegistry, - node_data: Tuple[type, Any] | None, - children: Iterable[PyTreeDef], - ) -> PyTreeDef: ... 
num_leaves: int num_nodes: int diff --git a/jaxlib/pytree.cc b/jaxlib/pytree.cc index 6e2f7af98629..272ac0c82859 100644 --- a/jaxlib/pytree.cc +++ b/jaxlib/pytree.cc @@ -1551,60 +1551,6 @@ std::optional> PyTreeDef::GetNodeData() } } -nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( - nb_class_ptr registry, - std::optional> node_data, - nb::iterable children) { - nb_class_ptr result = - make_nb_class(std::move(registry)); - int num_leaves = 0; - int arity = 0; - for (nb::handle pchild : children) { - const PyTreeDef& child = nb::cast(pchild); - absl::c_copy(child.traversal_, std::back_inserter(result->traversal_)); - num_leaves += child.num_leaves(); - ++arity; - } - result->traversal_.emplace_back(); - auto& node = result->traversal_.back(); - node.arity = arity; - node.custom = nullptr; - node.num_leaves = num_leaves; - node.num_nodes = result->traversal_.size(); - if (node_data == std::nullopt) { - node.kind = PyTreeKind::kLeaf; - ++node.num_leaves; - return result; - } - int is_nt = PyObject_IsSubclass(node_data->first.ptr(), - reinterpret_cast(&PyTuple_Type)); - if (is_nt == -1) { - throw nb::python_error(); - } - if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) { - node.kind = PyTreeKind::kNamedTuple; - node.node_data = node_data->first; - return result; - } - auto* registration = result->registry()->Lookup(node_data->first); - if (registration == nullptr) { - throw std::logic_error(absl::StrFormat( - "Could not find type: %s.", - nb::cast(nb::repr(node_data->first)))); - } - node.kind = registration->kind; - if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { - node.custom = registration; - node.node_data = node_data->second; - } else if (node.kind == PyTreeKind::kNamedTuple) { - node.node_data = node_data->first; - } else if (node.kind == PyTreeKind::kDict) { - node.sorted_dict_keys = - nb::cast>(node_data->second); - } - return result; -} - int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const { Py_VISIT(node_data.ptr()); for (const auto& key : sorted_dict_keys) { @@ -1743,11 +1689,6 @@ void BuildPytreeSubmodule(nb::module_& m) { nb::arg("registry"), nb::arg("data")); treedef.def("node_data", &PyTreeDef::GetNodeData, "Returns None if a leaf-pytree, else (type, node_data)"); - treedef.def_static( - "make_from_node_data_and_children", - &PyTreeDef::MakeFromNodeDataAndChildren, nb::arg("registry"), - nb::arg("node_data").none(), nb::arg("children"), - "Reconstructs a pytree from `node_data()` and `children()`."); treedef.def("__getstate__", &PyTreeDef::ToPickle); treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) { nb::tuple pickle = nb::cast(o); diff --git a/jaxlib/pytree.h b/jaxlib/pytree.h index 8f13a26135a5..f36d6999c887 100644 --- a/jaxlib/pytree.h +++ b/jaxlib/pytree.h @@ -311,11 +311,6 @@ class PyTreeDef { std::optional> GetNodeData() const; - static nb_class_ptr MakeFromNodeDataAndChildren( - nb_class_ptr registry, - std::optional> node_data, - nanobind::iterable children); - static PyType_Slot slots_[]; private: diff --git a/jaxlib/pytree_test.py b/jaxlib/pytree_test.py index a8846a91ea2b..0e5ccf69bdbe 100644 --- a/jaxlib/pytree_test.py +++ b/jaxlib/pytree_test.py @@ -37,7 +37,6 @@ def __init__(self, field0, field1): def to_iterable(self): return [self.field0, self.field1], (None,) - def from_iterable(state, values): del state return ExampleType2(field0=values[0], field1=values[1]) @@ -57,7 +56,7 @@ class Custom: class PyTreeTest(absltest.TestCase): - def roundtrip(self, example): + def roundtrip_proto(self, 
example): original = registry.flatten(example)[1] self.assertEqual( pytree.PyTreeDef.deserialize_using_proto( @@ -68,50 +67,23 @@ def roundtrip(self, example): def testSerializeDeserializeNoPickle(self): o = object() - self.roundtrip(({"a": o, "b": o}, [o, (o, o), None])) + self.roundtrip_proto(({"a": o, "b": o}, [o, (o, o), None])) def testSerializeWithFallback(self): o = object() with self.assertRaises(ValueError): - self.roundtrip({"a": ExampleType(field0=o, field1=o)}) + self.roundtrip_proto({"a": ExampleType(field0=o, field1=o)}) def testRegisteredType(self): o = object() with self.assertRaises(ValueError): - self.roundtrip({"a": ExampleType2(field0=o, field1=o)}) - - def roundtrip_node_data(self, example): - original = registry.flatten(example)[1] - restored = pytree.PyTreeDef.make_from_node_data_and_children( - registry, original.node_data(), original.children() - ) - self.assertEqual(restored, original) - - def testRoundtripNodeData(self): - o = object() - self.roundtrip_node_data([o, o, o]) - self.roundtrip_node_data((o, o, o)) - self.roundtrip_node_data({"a": o, "b": o}) - self.roundtrip_node_data({22: o, 88: o}) - self.roundtrip_node_data(None) - self.roundtrip_node_data(o) - self.roundtrip_node_data(ExampleType(field0=o, field1=o)) - self.roundtrip_node_data(ExampleType2(field0=o, field1=o)) + self.roundtrip_proto({"a": ExampleType2(field0=o, field1=o)}) def testCompose(self): x = registry.flatten(0)[1] y = registry.flatten((0, 0))[1] self.assertEqual((x.compose(y)).num_leaves, 2) - def testDataclassMakeFromNodeData(self): - c = Custom(1, "a") - c_leafs, c_tree = registry.flatten(c) - c_tree2 = c_tree.make_from_node_data_and_children( - registry, c_tree.node_data(), c_tree.children() - ) - self.assertEqual(c_tree2.unflatten(c_leafs), c) - self.assertEqual(str(c_tree2), str(c_tree)) - def testTpTraverse(self): self.assertContainsSubset( [ From dbde6c4430ce282a73118927fcc70ee5db41a983 Mon Sep 17 00:00:00 2001 From: Karlo Basioli Date: Mon, 16 Jun 2025 09:25:12 -0700 Subject: [PATCH 1700/1769] [jax] Increase absolute test tolerance for lax_control_flow test PiperOrigin-RevId: 772057312 --- tests/lax_control_flow_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 954890a973cc..42bc953a8236 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -1802,15 +1802,20 @@ def f(c, a): c = rng.randn(4) if scan is scan_with_new_checkpoint2: + atol = {} rtol = {np.float64: 1e-12, np.float32: 1e-4} elif scan is scan_with_for: + atol = {} rtol = {np.float64: 1e-12, np.float32: 1e-4} else: + atol = {np.float64: 1e-14} rtol = {np.float64: 1e-14, np.float32: 1e-4} ans = jax.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_) expected = jax.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_) - self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol) + self.assertAllClose( + ans, expected, check_dtypes=False, atol=atol, rtol=rtol + ) @parameterized.named_parameters( {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}", From 49e52c08f3bbb2b59d30ea10a546d86c300f6b78 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 Jun 2025 10:52:07 -0700 Subject: [PATCH 1701/1769] Fix some more instances of unhashable jaxpr equation arguments. 
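(For context, a minimal plain-Python sketch of the failure mode this fixes, not JAX internals: anything that caches on a jaxpr equation's params must be able to hash them, so list-valued params fail where tuple-valued params work. The "devices" key mirrors the device_put params converted to tuples in the diff below; the values are made up.)

    list_params = {"devices": ["dev0", "dev1"]}    # list value: unhashable
    tuple_params = {"devices": ("dev0", "dev1")}   # tuple value: hashable

    try:
        hash(frozenset(list_params.items()))       # raises TypeError
    except TypeError as err:
        print("unhashable:", err)                  # "unhashable type: 'list'"

    print(hash(frozenset(tuple_params.items())))   # fine once values are tuples
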
PiperOrigin-RevId: 772091145 --- jax/_src/interpreters/partial_eval.py | 16 ++++++++++------ jax/_src/lax/lax.py | 24 ++++++++++++++---------- jax/_src/lax/windowed_reductions.py | 2 +- jax/_src/tpu_custom_call.py | 5 +++++ 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a24c6aad4ecf..8e4ea16af0de 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1153,9 +1153,11 @@ def has_effects(effects) -> bool: outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.dst) - ] * len(outvars_copy), srcs=[None], - copy_semantics=[CopySemantics.COPY]), + dict( + devices=(TransferToMemoryKind(policy.dst),) * len(outvars_copy), + srcs=(None,), + copy_semantics=(CopySemantics.COPY,), + ), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) known_eqns.append(offload_eqn) @@ -1164,9 +1166,11 @@ def has_effects(effects) -> bool: residuals.update(resvars) reload_eqn = core.JaxprEqn( resvars, eqn.outvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.src) - ] * len(resvars), srcs=[None], - copy_semantics=[CopySemantics.COPY]), + dict( + devices=(TransferToMemoryKind(policy.src),) * len(resvars), + srcs=(None,), + copy_semantics=(CopySemantics.COPY,) + ), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) staged_eqns.append(reload_eqn) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index d49d7cbd0f61..34be5b0774a8 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1877,14 +1877,18 @@ def _decorator(*args, **kwargs): closed_jaxpr, out_tree = _trace_composite_to_jaxpr( partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) + attributes = [] + for k, v in kwargs.items(): + leaves, treedef = tree_util.tree_flatten(v) + leaves = tuple( + HashableArray(v) if isinstance(v, np.ndarray) else v for v in leaves + ) + attributes.append((k, leaves, treedef)) flat_args = core.standard_insert_pvary(*flat_args) out_flat = composite_p.bind( *flat_args, name=name, - attributes=tuple( - (k, HashableArray(v) if isinstance(v, np.ndarray) else v) - for k, v in kwargs.items() - ), + attributes=tuple(attributes), version=version, jaxpr=closed_jaxpr, ) @@ -1897,7 +1901,7 @@ def _composite_lowering( ctx: mlir.LoweringRuleContext, *args: Any, name: str, - attributes: Sequence[tuple[str, Any]], + attributes: Sequence[tuple[str, tuple[Any, ...], tree_util.PyTreeDef]], version: int, jaxpr: core.ClosedJaxpr, ): @@ -1924,11 +1928,11 @@ def _composite_lowering( ctx.avals_out, ctx.tokens_in, ) - composite_attrs = { - k : mlir.ir_attribute(v) - for k, v in attributes - if v is not None - } + composite_attrs = {} + for k, leaves, treedef in attributes: + v = treedef.unflatten(leaves) + if v is not None: + composite_attrs[k] = mlir.ir_attribute(v) symbol_name = func_op.name.value composite = hlo.CompositeOp( func_op.type.results, diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index e28677c99a72..7a939f1cd55f 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -79,7 +79,7 @@ def _reduce_window( padding = tuple(lax.padtype_to_pads( flat_operands[0].shape, dilated_window_dims, window_strides, padding)) else: - padding = tuple(padding) + padding = tuple((x, y) for x, y in padding) if base_dilation is None: base_dilation = (1,) * 
len(window_dimensions) if window_dilation is None: diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 7f308541b49f..8140a57f6cce 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -128,6 +128,11 @@ class CustomCallBackendConfig: active_core_count: int | None input_memory_spaces: tuple[MemorySpace | None, ...] | None + def __post_init__(self): + if self.allow_input_fusion is not None: + object.__setattr__(self, "allow_input_fusion", + tuple(self.allow_input_fusion)) + # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. def __repr__(self): From e5af088039181a89edc992c000ee97aeb3f05f32 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 16 Jun 2025 10:55:01 -0700 Subject: [PATCH 1702/1769] [export] Add back-compat test for tridiagonal solve on GPU PiperOrigin-RevId: 772092105 --- jax/_src/export/_export.py | 2 + .../cuda_tridiagonal_solve.py | 84 +++++++++++++++++++ tests/export_back_compat_test.py | 24 +++++- tests/export_harnesses_multi_platform_test.py | 17 ---- 4 files changed, 109 insertions(+), 18 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 5d88f530bb55..7475a28c74f8 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1088,6 +1088,8 @@ def _check_lowering(lowering) -> None: "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi", # tridiagonal on GPU "cusolver_sytrd_ffi", + # tridiagonal_solve on GPU + "cusparse_gtsv2_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py new file mode 100644 index 000000000000..c81d4d4a139d --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py @@ -0,0 +1,84 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa + +import datetime +from numpy import array, float32 + +data_2025_06_16 = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_16["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusparse_gtsv2_ffi'], + serialized_date=datetime.date(2025, 6, 16), + inputs=(array([0., 2., 3.], dtype=float32), array([1., 1., 1.], dtype=float32), array([1., 2., 0.], dtype=float32), array([[1.], + [1.], + [1.]], dtype=float32)), + expected_outputs=(array([[ 0.57142854], + [ 0.42857146], + [-0.2857143 ]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("dl") +#loc2 = loc("d") +#loc3 = loc("du") +#loc4 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3xf32> loc("dl"), %arg1: tensor<3xf32> loc("d"), %arg2: tensor<3xf32> loc("du"), %arg3: tensor<3x1xf32> loc("b")) -> (tensor<3x1xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @cusparse_gtsv2_ffi(%arg0, %arg1, %arg2, %arg3) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> loc(#loc6) + return %0 : tensor<3x1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":760:13) +#loc6 = loc("jit(func)/jit(main)/tridiagonal_solve"(#loc5)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.4\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03\x83]\x13\x01/\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x1b\x0b\x0b\x0f\x0b\x17\x0b\x03/\x0b/O\x1b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x0f\x13\x17\x07\x07#\x13\x13\x02\xea\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x01\x05\x17\x1d\x1b\x01\x05\x19\x1d\x1f\x01\x05\x1b\x03\x05#/%E\x05\x1d\x05\x1f\x1d)+\x05!\x17-\xe2\x0b\x1b\x05#\r\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\t////#\r\x03\x03;\r\x03=?\x1d%\x1d'\x1d)\x1d+\r\x03GI\x1d-\x1d/\x0b\x03\x1d1\x1d3\x03\x01\x05\x01\x03\t1113\x03\x03Y\x15\x01\r\x01\x03\x033\x01\t\x01\x02\x02)\x03\r\t)\x05\r\x05\t\t\x13\x11\t\x05\x05\x05\x07\x03\x07)\x03\x05\x0b)\x03\t\x0b\x04c\x05\x01Q\x01\x05\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x0b\x0b\t\x0b\x11\x0b\x15\x0b\x19\x0f\x1d\x00\x05G'!\x05\x03\x07\t\x01\x03\x05\x07\x07\x04\x01\x03\t\x06\x03\x01\x05\x01\x00\xd2\x055'\x03\x05\x1f\x0f\x0b\x0f!iM3)\x05\x07\x05\x07\x13%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00dl\x00d\x00du\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit(func)/jit(main)/tridiagonal_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x000\x00\x00cusparse_gtsv2_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11KMOQSUW[", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted 
from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_16["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusparse_gtsv2_ffi'], + serialized_date=datetime.date(2025, 6, 16), + inputs=(array([0., 2., 3.]), array([1., 1., 1.]), array([1., 2., 0.]), array([[1.], + [1.], + [1.]])), + expected_outputs=(array([[ 0.5714285714285714 ], + [ 0.42857142857142855], + [-0.2857142857142857 ]]),), + mlir_module_text=r""" +#loc1 = loc("dl") +#loc2 = loc("d") +#loc3 = loc("du") +#loc4 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3xf64> loc("dl"), %arg1: tensor<3xf64> loc("d"), %arg2: tensor<3xf64> loc("du"), %arg3: tensor<3x1xf64> loc("b")) -> (tensor<3x1xf64> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @cusparse_gtsv2_ffi(%arg0, %arg1, %arg2, %arg3) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3xf64>, tensor<3xf64>, tensor<3xf64>, tensor<3x1xf64>) -> tensor<3x1xf64> loc(#loc6) + return %0 : tensor<3x1xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":760:13) +#loc6 = loc("jit(func)/jit(main)/tridiagonal_solve"(#loc5)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.4\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03\x83]\x13\x01/\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x1b\x0b\x0b\x0f\x0b\x17\x0b\x03/\x0b/O\x1b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x0f\x13\x17\x07\x07#\x13\x13\x02\xea\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x01\x05\x17\x1d\x1b\x01\x05\x19\x1d\x1f\x01\x05\x1b\x03\x05#/%E\x05\x1d\x05\x1f\x1d)+\x05!\x17-\xe2\x0b\x1b\x05#\r\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\t////#\r\x03\x03;\r\x03=?\x1d%\x1d'\x1d)\x1d+\r\x03GI\x1d-\x1d/\x0b\x03\x1d1\x1d3\x03\x01\x05\x01\x03\t1113\x03\x03Y\x15\x01\r\x01\x03\x033\x01\t\x01\x02\x02)\x03\r\t)\x05\r\x05\t\x0b\x13\x11\t\x05\x05\x05\x07\x03\x07)\x03\x05\x0b)\x03\t\x0b\x04c\x05\x01Q\x01\x05\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x0b\x0b\t\x0b\x11\x0b\x15\x0b\x19\x0f\x1d\x00\x05G'!\x05\x03\x07\t\x01\x03\x05\x07\x07\x04\x01\x03\t\x06\x03\x01\x05\x01\x00\xd2\x055'\x03\x05\x1f\x0f\x0b\x0f!iM3)\x05\x07\x05\x07\x13%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00dl\x00d\x00du\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit(func)/jit(main)/tridiagonal_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x000\x00\x00cusparse_gtsv2_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11KMOQSUW[", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 0f443ae47929..258164cdc615 100644 --- a/tests/export_back_compat_test.py 
+++ b/tests/export_back_compat_test.py @@ -51,6 +51,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_cusolver_sytrd +from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_solve from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK @@ -142,6 +143,7 @@ def test_custom_call_coverage(self): cuda_svd_cusolver_gesvd.data_2024_10_08, cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09, cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, + cuda_tridiagonal_solve.data_2025_06_16, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, cpu_triangular_solve_blas_trsm.data_2023_07_16, @@ -728,7 +730,7 @@ def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name): dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @jax.default_matmul_precision("float32") - def test_gpu_tridiagonal_solver_sytrd(self, dtype_name): + def test_gpu_tridiagonal_sytrd(self, dtype_name): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") @@ -743,6 +745,26 @@ def func(x): ) self.run_one_test(func, data, rtol=rtol, atol=atol) + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64")) + @jax.default_matmul_precision("float32") + def test_gpu_tridiagonal_solve(self, dtype_name): + if not config.enable_x64.value and dtype_name == "f64": + self.skipTest("Test disabled for x32 mode") + + dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] + def func(dl, d, du, b): + return lax.linalg.tridiagonal_solve(dl, d, du, b) + + rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] + atol = dict(f32=1e-4, f64=1e-12)[dtype_name] + + data = self.load_testdata( + cuda_tridiagonal_solve.data_2025_06_16[dtype_name] + ) + self.run_one_test(func, data, atol=atol, rtol=rtol) + def test_tpu_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index ef9d1e04c796..b91a1fc550bd 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -23,7 +23,6 @@ from collections.abc import Callable import math -import re from absl import logging from absl.testing import absltest @@ -39,13 +38,6 @@ from jax import random -def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: - if not parts: - return re.compile("matches_no_test") - else: - return re.compile("(" + "|".join(parts) + ")") - - class PrimitiveTest(jtu.JaxTestCase): def setUp(self): @@ -84,10 +76,6 @@ def test_prim(self, harness: test_harnesses.Harness): self.skipTest("Eigenvalues are sorted and it is not correct to compare " "decompositions for equality.") - if (jtu.device_under_test() == "gpu" - and "tridiagonal_solve_" in harness.fullname): - self.skipTest("tridiagonal_solve_ is not yet guaranteed stable.") - if harness.params.get("enable_xla", False): self.skipTest("enable_xla=False is not relevant") @@ -98,11 +86,6 @@ def test_prim(self, harness: 
test_harnesses.Harness): for l in harness.jax_unimplemented: if l.filter(dtype=harness.dtype): unimplemented_platforms = unimplemented_platforms.union(l.devices) - # Some primitive lowering rules need the GPU backend to be able to create - # CUDA lowering. - if ("tridiagonal_solve_" in harness.fullname - and all(d.platform != "gpu" for d in self.devices)): - unimplemented_platforms.add("gpu") if unimplemented_platforms: logging.info("Harness is not implemented on %s", unimplemented_platforms) From a78c6a70c1153c71b644dda7c08272216b74b734 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 16 Jun 2025 11:06:18 -0700 Subject: [PATCH 1703/1769] Set heartbeat_timeout argument and flag. PiperOrigin-RevId: 772096777 --- jaxlib/xla.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 969c2449f463..8206316bcb73 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -792,6 +792,9 @@ NB_MODULE(_jax, m) { -> std::unique_ptr { CoordinationServiceImpl::Options options; options.num_nodes = num_nodes; + options.heartbeat_timeout = + max_missing_heartbeats.value_or(10) * + absl::Seconds(heartbeat_interval.value_or(10)); if (heartbeat_interval.has_value()) { options.heartbeat_interval = absl::Seconds(*heartbeat_interval); } @@ -838,6 +841,9 @@ NB_MODULE(_jax, m) { if (shutdown_timeout.has_value()) { options.shutdown_timeout = absl::Seconds(*shutdown_timeout); } + options.heartbeat_timeout = + max_missing_heartbeats.value_or(10) * + absl::Seconds(heartbeat_interval.value_or(10)); if (heartbeat_interval.has_value()) { options.heartbeat_interval = absl::Seconds(*heartbeat_interval); } From 09d903fc9a60726d5b4e57769a6bdbd424d44b08 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Mon, 16 Jun 2025 11:48:13 -0700 Subject: [PATCH 1704/1769] Install SciPy from its source (head) to test against Python 3.14.0b1 Current SciPy releases don't support Python 3.14, and building from source will resolve compatibility issues introduced by the new Python version: https://github.com/jax-ml/jax/actions/runs/15678323189/job/44163705784. 
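(A hedged local smoke check, outside CI: assuming SciPy was installed from the git main branch as in the workflow change below, the interpreter and SciPy versions can be confirmed with a couple of lines; the exact dev version string shown is only illustrative.)

    import sys
    import scipy

    print(sys.version_info[:2])   # expected (3, 14) in this CI job
    print(scipy.__version__)      # a development version, e.g. '1.16.0.dev0+...'
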
PiperOrigin-RevId: 772113604 --- .github/workflows/pytest_cpu.yml | 1 + .github/workflows/pytest_cuda.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 71b7d49f8049..e0be29ae3328 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -140,6 +140,7 @@ jobs: # Install build requirements for scipy apt update && apt upgrade -y && apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + $JAXCI_PYTHON -m uv pip install "git+https://github.com/scipy/scipy@main" # Install build requirements for pillow apt install -q -y libjpeg-dev --no-install-recommends diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 644aceaa803e..ed021a970ecd 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -141,6 +141,7 @@ jobs: # Install build requirements for scipy apt update && apt upgrade -y && apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + $JAXCI_PYTHON -m uv pip install "git+https://github.com/scipy/scipy@main" # Install build requirements for pillow apt install -q -y libjpeg-dev --no-install-recommends From 7aec14f1367607b79a39362edfe647943e76e7cd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 Jun 2025 12:14:46 -0700 Subject: [PATCH 1705/1769] [doc] add missing axis_types documentation --- jax/_src/mesh.py | 14 ++++++++++++++ jax/_src/sharding_impls.py | 7 ++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index d9183f8805d7..442ca3f18d83 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -230,6 +230,8 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): axis_names: A sequence of resource axis names to be assigned to the dimensions of the ``devices`` argument. Its length should match the rank of ``devices``. + axis_types: an optional tuple of :class:`jax.sharding.AxisType` entries corresponding to + the ``axis_names``. See `Explicit Sharding`_ for more information. Examples: @@ -244,6 +246,8 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) >>> out = jax.jit(lambda x: x * 2)(arr) >>> assert out.sharding == NamedSharding(mesh, P('x', 'y')) + + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ devices: np.ndarray @@ -440,6 +444,16 @@ class AbstractMesh(_BaseMesh): your mesh shape and axis names stay the same but the devices change. See the description of https://github.com/jax-ml/jax/pull/23022 for more details. + + Args: + axis_sizes: A tuple of integers specifying the size of each resource axis. + axis_names: A tuple of resource axis names to be assigned to the + dimensions of the ``devices`` argument. Its length should match the + rank of ``devices``. + axis_types: an optional tuple of :class:`jax.sharding.AxisType` entries corresponding to + the ``axis_names``. See `Explicit Sharding`_ for more information. + + .. 
_Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 7be94adc2665..b5592ef46a66 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1366,9 +1366,14 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], axis_names: Names of the mesh axes. For example, axis_names=('x', 'y') devices: Optional keyword only argument, that allows you to specify the devices you want to create a mesh with. + axis_types: an optional tuple of :class:`jax.sharding.AxisType` entries + corresponding to the ``axis_names``. See `Explicit Sharding`_ for more + information. Returns: - A `jax.sharding.Mesh` object. + A :class:`jax.sharding.Mesh` object. + + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ if devices is None: devices = xb.devices() From 8865ee6628feedf04c53d6b12cfcfb47b67a667a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 16 Jun 2025 12:21:21 -0700 Subject: [PATCH 1706/1769] Remove legacy CPU custom calls. JAX v0.4.38 (released Dec 17, 2024) no longer lowered to any legacy CPU custom calls. Following our export compatibility guide (https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility), the remaining legacy custom calls can be removed on June 15, 2025, 180 days after the 0.4.38 release. 
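(The 180-day window quoted above can be checked directly; a small sketch:)

    from datetime import date, timedelta

    release = date(2024, 12, 17)          # JAX v0.4.38 release date
    print(release + timedelta(days=180))  # 2025-06-15, the earliest removal date
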
dtype=float32),), - expected_outputs=(array([[ 3.2464233e+01, -1.3416403e+01, -1.5532076e-05, -4.3390692e-06], - [ 0.0000000e+00, -2.4642491e+00, -1.4625000e-06, -6.4478525e-07], - [ 0.0000000e+00, 0.0000000e+00, -8.1893580e-07, -2.5704816e-07], - [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.5155359e-07]], - dtype=float32), array([[-0.11417631 , 0.828833 , -0.546308 , -0.039330132], - [-0.33000442 , 0.4371459 , 0.69909686 , 0.45963493 ], - [-0.54583275 , 0.045459975, 0.24073309 , -0.80127877 ], - [-0.7616609 , -0.34622616 , -0.39352104 , 0.3809742 ]], - dtype=float32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf32> {jax.result_info = "[0]"}, tensor<4x4xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_sgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc2) - return %12, %17 : tensor<4x4xf32>, tensor<4x4xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd5\x97+\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b\x1fO\x01\x03\x0f\x03)\x17\x0f\x0f\x07\x07\x07\x0f\x13\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x0f\x13\x02\xbe\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f\x1f\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\t\x13\x01)\x01\t)\x03\x11\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x11\x11\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03'\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgees\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0., 1., 
2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.], - [12., 13., 14., 15.]]),), - expected_outputs=(array([[ 3.2464249196572958e+01, -1.3416407864998734e+01, - 1.4217165257496823e-15, 1.7257338996070338e-16], - [ 0.0000000000000000e+00, -2.4642491965729794e+00, - 4.0099214829607365e-16, 2.9384059908060751e-16], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - -1.5668631265126207e-15, 6.3403580326623540e-16], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 1.2369554016158485e-16]]), array([[-0.11417645138733855 , 0.8288327563197505 , - 0.4940336612834742 , -0.23649681080057947 ], - [-0.3300045986655475 , 0.4371463883638869 , - -0.8349858635153001 , -0.052901868866879136], - [-0.545832745943757 , 0.045460020408024784, - 0.18787074318017621 , 0.8152941701354965 ], - [-0.7616608932219662 , -0.3462263475478383 , - 0.1530814590516493 , -0.525895490468038 ]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf64> {jax.result_info = "[0]"}, tensor<4x4xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_dgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<4x4xf64>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc2) - return %12, %17 : tensor<4x4xf64>, tensor<4x4xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd5\x97+\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b/O\x01\x03\x0f\x03)\x17\x0f\x0f\x07\x07\x07\x0f\x13\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x0f\x13\x02\xce\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f\x1f\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x0b\x13\x01)\x01\t)\x03\x11\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x11\x11\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03'\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgees\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgees'], - serialized_date=datetime.date(2023, 7, 16), - 
inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], - [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], - [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]], dtype=complex64),), - expected_outputs=(array([[ 3.2464264e+01+0.j, -1.3416414e+01+0.j, -3.3649465e-06+0.j, - 3.5482326e-06+0.j], - [ 0.0000000e+00+0.j, -2.4642489e+00+0.j, -7.4810049e-07+0.j, - 6.1193055e-07+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, -5.7737759e-07+0.j, - 2.5704813e-07+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 0.0000000e+00+0.j, - 1.4719124e-07+0.j]], dtype=complex64), array([[ 0.11417647 +0.j, -0.8288329 +0.j, 0.5452458 +0.j, - -0.05202686 +0.j], - [ 0.3300045 +0.j, -0.43714625 +0.j, -0.68821627 +0.j, - 0.47577178 +0.j], - [ 0.54583293 +0.j, -0.045460097-0.j, -0.25930598 +0.j, - -0.79546237 +0.j], - [ 0.76166105 +0.j, 0.3462263 +0.j, 0.40227604 +0.j, - 0.37171766 +0.j]], dtype=complex64)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_cgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf32>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - return %12, %17 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True 
sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b/O\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\xe6\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgees\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c128"] = dict( - testdata_version=1, - platform='cpu', - 
custom_call_targets=['lapack_zgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], - [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], - [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]]),), - expected_outputs=(array([[ 3.2464249196572965e+01+0.j, -1.3416407864998730e+01+0.j, - 4.3084836728703156e-15+0.j, 2.8665351303736084e-15+0.j], - [ 0.0000000000000000e+00+0.j, -2.4642491965729802e+00+0.j, - -2.3716026934523430e-16+0.j, 3.7279396143672773e-16+0.j], - [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, - -1.6035677295293287e-15+0.j, -6.3403580326623540e-16+0.j], - [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, - 0.0000000000000000e+00+0.j, 1.2218554396786608e-16+0.j]]), array([[ 0.11417645138733863+0.j, -0.8288327563197504 +0.j, - 0.4960613110079619 +0.j, 0.2322136424094458 +0.j], - [ 0.33000459866554754+0.j, -0.43714638836388703+0.j, - -0.8344969112540657 +0.j, 0.06012408092789509+0.j], - [ 0.5458327459437572 +0.j, -0.04546002040802478-0.j, - 0.18080988948424495+0.j, -0.8168890890841272 +0.j], - [ 0.7616608932219662 +0.j, 0.34622634754783854+0.j, - 0.15762571076185886+0.j, 0.5245513657467864 +0.j]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_zgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf64>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select 
%16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - return %12, %17 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0bOO\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\x06\x06\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\x0b\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False 
select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgees\x00", - xla_call_module_version=6, -) # End paste +from numpy import array, float32, complex64 data_2024_11_29 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py index c401ca041bfb..f3640c6114c7 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py @@ -17,196 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_07_16 = {} -# Pasted from the test output (see back_compat_test_util.py module docstring) -data_2023_07_16["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_strsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5., 0., 0., 0.], - [ 4., 10., 0., 0.], - [ 8., 9., 15., 0.], - [12., 13., 14., 20.]], dtype=float32), array([[ 0., 1., 2., 3., 4.], - [ 5., 6., 7., 8., 9.], - [10., 11., 12., 13., 14.], - [15., 16., 17., 18., 19.]], dtype=float32)), - expected_outputs=(array([[ 0. , 0.2 , 0.4 , 0.6 , - 0.8 ], - [ 0.5 , 0.52 , 0.54 , 0.56 , - 0.58000004 ], - [ 0.36666667 , 0.31466666 , 0.26266667 , 0.21066667 , - 0.15866666 ], - [ 0.16833334 , 0.12173338 , 0.0751333 , 0.02853328 , - -0.018066704]], dtype=float32),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xf32> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xf32> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_strsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<4x4xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> loc(#loc2) - return %8 : tensor<4x5xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":508:0) -#loc2 = 
loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa5{\x17\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x15\x0f\x17\x07\x17\x0f\x07\x1b\x07\x13\x13\x02J\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xf2\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x13\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\t\x00\x00\x80?\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x11)\x05\x11\x15\x07\t)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_strsm\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_dtrsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5., 0., 0., 0.], - [ 4., 10., 0., 0.], - [ 8., 9., 15., 0.], - [12., 13., 14., 20.]]), array([[ 0., 1., 2., 3., 4.], - [ 5., 6., 7., 8., 9.], - [10., 11., 12., 13., 14.], - [15., 16., 17., 18., 19.]])), - expected_outputs=(array([[ 0. 
, 0.2 , - 0.4 , 0.6000000000000001 , - 0.8 ], - [ 0.5 , 0.52 , - 0.54 , 0.5599999999999999 , - 0.58 ], - [ 0.36666666666666664 , 0.3146666666666667 , - 0.2626666666666667 , 0.21066666666666667 , - 0.15866666666666665 ], - [ 0.16833333333333336 , 0.1217333333333333 , - 0.07513333333333323 , 0.0285333333333333 , - -0.018066666666666675]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xf64> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xf64> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_dtrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<4x4xf64>, tensor<4x5xf64>) -> tensor<4x5xf64> loc(#loc2) - return %8 : tensor<4x5xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":511:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa5{\x17\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b/\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x15\x0f\x17\x07\x17\x0f\x07\x1b\x07\x13\x13\x02Z\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfe\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x13\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x11)\x05\x11\x15\x07\x0b)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_dtrsm\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_ctrsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 4.+0.j, 10.+0.j, 0.+0.j, 0.+0.j], - [ 8.+0.j, 9.+0.j, 15.+0.j, 0.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 20.+0.j]], dtype=complex64), array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j], - [ 5.+0.j, 6.+0.j, 7.+0.j, 8.+0.j, 9.+0.j], - [10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j, 14.+0.j], - [15.+0.j, 16.+0.j, 17.+0.j, 18.+0.j, 19.+0.j]], dtype=complex64)), - expected_outputs=(array([[ 0. 
+0.j, 0.2 +0.j, 0.4 +0.j, - 0.6 +0.j, 0.8 +0.j], - [ 0.5 +0.j, 0.52 +0.j, 0.54 +0.j, - 0.56 +0.j, 0.58000004 +0.j], - [ 0.36666667 +0.j, 0.31466666 +0.j, 0.26266667 +0.j, - 0.21066667 +0.j, 0.15866666 +0.j], - [ 0.16833334 +0.j, 0.12173338 +0.j, 0.0751333 +0.j, - 0.02853328 +0.j, -0.018066704+0.j]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xcomplex> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_ctrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor>, tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc2) - return %8 : tensor<4x5xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":510:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b/\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02b\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\x11\x00\x00\x80?\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\t\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ctrsm\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_ztrsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 4.+0.j, 10.+0.j, 0.+0.j, 0.+0.j], - [ 8.+0.j, 9.+0.j, 15.+0.j, 0.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 20.+0.j]]), array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j], - [ 5.+0.j, 6.+0.j, 7.+0.j, 8.+0.j, 9.+0.j], - [10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j, 14.+0.j], - [15.+0.j, 16.+0.j, 17.+0.j, 18.+0.j, 19.+0.j]])), - expected_outputs=(array([[ 0. 
+0.j, 0.2 +0.j, - 0.4 +0.j, 0.6000000000000001 +0.j, - 0.8 +0.j], - [ 0.5 +0.j, 0.52 +0.j, - 0.54 +0.j, 0.5599999999999999 +0.j, - 0.58 +0.j], - [ 0.36666666666666664 +0.j, 0.3146666666666667 +0.j, - 0.2626666666666667 +0.j, 0.21066666666666667 +0.j, - 0.15866666666666665 +0.j], - [ 0.16833333333333336 +0.j, 0.1217333333333333 +0.j, - 0.07513333333333323 +0.j, 0.0285333333333333 +0.j, - -0.018066666666666675+0.j]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xcomplex> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_ztrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor>, tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc2) - return %8 : tensor<4x5xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":510:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0bO\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02\x82\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x0b\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ztrsm\x00", - xla_call_module_version=6, -) # End paste - data_2024_12_02 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_12_02['c128'] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py index 9e245052e03a..c986e4ffd115 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py @@ -17,432 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2024_09_03 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zhetrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - 
expected_outputs=(array([[[-1.6782909868280393 +0.j , - -0.44670237330570184+4.847000766107959j , - 2.05945450900321 -2.2848432268240106j , - -1.852046418980849 +1.672382006137275j ], - [ 8.516713699516982 +0.j , - -2.7881860505313174 +0.j , - 0.9238284715039695 -2.3790501284019947j , - 0.5005102262291599 -1.30066052934836j ], - [-0.12132810525381293-0.2963030371159077j , - -3.6374350042782893 +0.j , - 0.5605752523031344 +0.j , - -2.9865099107523174 +0.5492956557924651j ], - [-0.40379248092949666-0.7813328344426929j , - -0.07101654492399719-0.27208840961051617j, - -7.4654253782049285 +0.j , - -8.172380353916964 +0.j ]], - - [[-3.996403598623405 +0.j , - 0.59408630943699 +2.531609474375295j , - -1.789098034543644 -2.538389274566601j , - -1.291106590337488 +3.1576544511573843j ], - [10.8950662522622 +0.j , - -2.8151642043836693 +0.j , - 6.18998567202382 +1.1866537964613415j , - 3.1900218245393352 +2.7291222716752372j ], - [-0.3142889671188478 -0.37781876498252764j, - 3.049208563595754 +0.j , - -2.4383044880335487 +0.j , - 4.075435464493341 -0.6653616942280807j ], - [ 0.32757687545025194+0.565870910342534j , - 0.8177026465997795 -0.15906305615104555j, - 3.3415143060767125 +0.j , - 4.094619408678314 +0.j ]]]), array([[-1.6782909868280393, -2.7881860505313174, 0.5605752523031344, - -8.172380353916964 ], - [-3.996403598623405 , -2.8151642043836693, -2.4383044880335487, - 4.094619408678314 ]]), array([[ 8.516713699516982 , -3.6374350042782893, -7.4654253782049285], - [10.8950662522622 , 3.049208563595754 , 3.3415143060767125]]), array([[1.0626274644222748+0.06050271598884928j, - 1.834630852474663 +0.18575551495730305j, - 1.981584368497257 +0.19102912741736966j], - [1.0365789616521406-0.40942548304121656j, - 1.0872592163018966-0.3187050677167622j , - 1.0458498304770472-0.9989483435319496j ]])), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(-1.6782909868280393,-0.44303325034407437), (-0.44670237330570184,4.8470007661079588), (2.0594545090032099,-2.2848432268240106), (-1.852046418980849,1.6723820061372749)], [(-0.53338018421119981,-0.5152843101202178), (-8.6208093221459947,-1.4723511111926109), (0.92382847150396952,-2.3790501284019947), (0.50051022622915986,-1.30066052934836)], [(0.94535043721506584,2.744088772946665), (-5.9178492824175759,-4.3744650461123786), (1.8341291553102983,-4.8378584827626838), (-2.9865099107523174,0.54929565579246509)], [(3.2517513113853891,7.2792034361133062), (-0.09841002311276037,0.88008791818205689), (-0.035759860211603468,2.4677764344580244), (-3.6133109853094476,-2.2833696560058976)]], [[(-3.996403598623405,2.42308766118121), (0.59408630943699003,2.531609474375295), (-1.789098034543644,-2.538389274566601), (-1.2911065903374881,3.1576544511573843)], [(-0.39853021063902833,4.4607177630985086), (1.0742061295773189,-2.6002112528615386), (6.1899856720238198,1.1866537964613415), (3.1900218245393352,2.7291222716752372)], [(5.2347956435718022,2.8649782894514577), 
(2.3527586611916762,2.4688953673448575), (-2.317572140163894,4.3609023810820053), (4.0754354644933413,-0.66536169422808067)], [(-6.2237114632988675,-4.9294897244018943), (4.2994486027667103,-1.3300494261380422), (-0.51942958410141249,0.60038999428238982), (0.084516726847668963,-7.2944134049318752)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_zhetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) - %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, 
tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) - return %2 : tensor<2x4x4xcomplex> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) - return %2 : tensor<2x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) - return %2 : tensor<2x3xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) - return %2 : tensor<2x3xcomplex> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo/O/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x10\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\x12\x10\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x08d\x91Y\xa6G\xda\xfa\xbf$-Q"\xa8Z\xdc\xbfL0\x19\x8d\xc5\x96\xdc\xbf\x86{8+Tc\x13@\xf0%\x1eI\xc3y\x00@\xe4\x91\xbd\xe2[G\x02\xc0\x85%\x03m\xfb\xa1\xfd\xbf\x9atl\xa2\x13\xc2\xfa?\x9c\xb0\xf0Qs\x11\xe1\xbf\xd8v\x83\x855}\xe0\xbf\x84V/\xb8\xda=!\xc0\n\xd3\xec\t\xc0\x8e\xf7\xbf\x98$\x07\xba\x00\x90\xed?\xd5?\x08oK\x08\x03\xc0>\xf8\x9e\x05.\x04\xe0?\xf2\xfcKj\x81\xcf\xf4\xbf\xe4"c\x8fO@\xee?y\x03\x89\xd0\xe4\xf3\x05@\xee\x8f\xaa\xae\xe0\xab\x17\xc0\xf20\xda\xc3s\x7f\x11\xc0V*+\xd0\x97X\xfd?P\x91\xf8\x92\xf7Y\x13\xc0\x7f\xe3\xdeN_\xe4\x07\xc0\x14\xd5\xae{\xd4\x93\xe1?\xbc\x00\t1\x96\x03\n@`&l\x81\xe7\x1d\x1d@X/\xde6f1\xb9\xbf\x06KF#\xae)\xec?\xcd\x9a<\xcc\x1dO\xa2\xbf\x91\xb1>\x92\x01\xbe\x03@\xf2s\x01\x97\x0f\xe8\x0c\xc0\xf5\xcaiOWD\x02\xc0F\xa2-s\xa2\xf8\x0f\xc0X\xea\xa0\xc8{b\x03@\x0b\x10\xc1J\xc1\x02\xe3?2|\xd5w\xbc@\x04@\xca>\xbbB%\xa0\xfc\xbf\xe8>6\t\x9fN\x04\xc0\xafdRb_\xa8\xf4\xbf\x80Q>V\xe0B\t@UhJ\xdb\x84\x81\xd9\xbf\t\xc7\xb4e\xc6\xd7\x11@<(;\xc4\xf2/\xf1?\x1a\xda\xad\x8e;\xcd\x04\xc0\x1c4\xa0\x9a\x8b\xc2\x18@z\x9c\xf7\xb0\x88\xfc\xf2?\xaea\x8f)*\x85\t@\x00\x0b\xbd\x0e>\xd5\x05@b\x89\xe9Dn\xf0\x14@a\x8d\xc7\xbcy\xeb\x06@\x8a\x97\t"s\xd2\x02@\xc2\xef\xdf6L\xc0\x03@J\xff 
Cc\x8a\x02\xc0\xd7.\xcfd\x90q\x11@s\xd4S\xf4>M\x10@t\x10\x97\x9b\xa4J\xe5\xbf\x8eo*\x9e\x14\xe5\x18\xc0\xc5\x18\x81\'\xcc\xb7\x13\xc0\x19\xdd\x8e\xa7\xa22\x11@-95\xe8\xe1G\xf5\xbfZK\x89\xca*\x9f\xe0\xbfR;\xc9\x13e6\xe3?\x7f\x94\xc6a\xe3\xa2\xb5?\xe2\xbe&\xb5z-\x1d\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\x0b\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_zhetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see 
export_back_compat_test_util.py module docstring) -data_2024_09_03["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_chetrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[ 3.3228416 +0.j , -1.9756439 +4.593356j , - 7.367708 +0.88518727j , -8.659938 +1.6132793j ], - [-6.9206004 +0.j , -3.6362798 +0.j , - 3.3011198 -4.644362j , -4.8589935 -0.61439794j ], - [ 0.64957 +0.060723424j, 6.620491 +0.j , - 0.2882607 +0.j , -1.0288142 +1.8544064j ], - [-0.05458622 +0.10473086j , -0.15611424 +0.06925995j , - -4.431866 +0.j , 2.364208 +0.j ]], - - [[-4.1803885 +0.j , 0.5670845 +0.6913016j , - 2.675204 -0.23881845j , -0.41825035 -1.4060576j ], - [ 8.33625 +0.j , 2.6144838 +0.j , - -2.4941807 -1.9316154j , 0.6687787 -2.209776j ], - [ 0.019031923+0.17462212j , 2.7034955 +0.j , - -0.70924187 +0.j , 2.7962255 +1.5316825j ], - [-0.057821754+0.023692288j, -0.62805307 -0.0882424j , - 6.6364865 +0.j , -1.698973 +0.j ]]], - dtype=complex64), array([[ 3.3228416 , -3.6362798 , 0.2882607 , 2.364208 ], - [-4.1803885 , 2.6144838 , -0.70924187, -1.698973 ]], - dtype=float32), array([[-6.9206004, 6.620491 , -4.431866 ], - [ 8.33625 , 2.7034955, 6.6364865]], dtype=float32), array([[1.360567 +0.1977107j , 1.7586378-0.56989706j, - 1.5772758-0.8165493j ], - [1.9152443-0.1834492j , 1.1593437+0.55631363j, - 1.6889225-0.724835j ]], dtype=complex64)), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(3.32284164,1.14621949), (-1.97564387,4.59335613), (7.36770821,0.885187268), (-8.65993785,1.61327934)], [(2.495340e+00,1.36827672), (-3.96969199,-0.636681795), (3.3011198,-4.64436197), (-4.85899353,-0.614397943)], [(6.03322554,1.46055949), (-3.89591122,-4.1833396), (-1.46423841,-0.106284566), (-1.0288142,1.85440636)], [(-0.657281339,0.911450386), (3.18693113,-2.02812219), (-2.64483237,0.351429433), (4.45011663,-1.79112875)]], [[(-4.18038845,-3.65238023), (0.567084491,0.691301584), (2.67520404,-0.238818452), (-0.418250352,-1.4060576)], [(-7.62970591,1.5292784), (0.269325763,2.48722434), (-2.49418068,-1.93161535), (0.668778717,-2.20977592)], [(-0.570908666,-2.75890398), (-0.235837936,3.45861554), (-0.946199476,0.23120968), (2.79622555,1.53168249)], [(0.886947453,-0.466695577), (-3.194850e+00,-0.0176551137), (-4.37602425,-3.7703948), (0.883143305,-4.70016575)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_chetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], 
output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) - %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) - return %2 : tensor<2x4x4xcomplex> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> 
{mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) - return %2 : tensor<2x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) - return %2 : tensor<2x3xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) - return %2 : tensor<2x3xcomplex> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo//\x1f\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\xe2\x0b\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\xc0\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x04p\xa9T@R\xb7\x92?\xe6\xe1\xfc\xbf\xc6\xfc\x92@D\xc4\xeb@\xa2\x9bb?\x1b\x8f\n\xc1\xf0\x7f\xce?\xa7\xb3\x1f@\xb1#\xaf?o\x0f~\xc0\x94\xfd"\xbf\x8cES@\x9d\x9e\x94\xc0\xe0|\x9b\xc0/I\x1d\xbf/\x10\xc1@\x9d\xf3\xba?\x9cVy\xc0\xeb\xdd\x85\xc0*l\xbb\xbf\xb9\xab\xd9\xbd/\xb0\x83\xbf0]\xed?\x97C(\xbf\xd0Ti?\xae\xf6K@\xc1\xcc\x01\xc0\xefD)\xc0\x8f\xee\xb3>[g\x8e@\xb5C\xe5\xbf\xbe\xc5\x85\xc0\x99\xc0i\xc0s,\x11?$\xf90?\x8b6+@\xd3\x8ct\xbe\xe9$\xd6\xbe\xb2\xf9\xb3\xbf\x8d&\xf4\xc0e\xbf\xc3?\x11\xe5\x89>\xaf.\x1f@\xa8\xa0\x1f\xc0,?\xf7\xbf\x155+?\xf8l\r\xc0\x12\'\x12\xbf\xe2\x910\xc0\x80\x7fq\xbe\xf5Y]@!:r\xbf;\xc2l>\\\xf52@,\x0e\xc4?\xfd\x0ec?\xb9\xf2\xee\xbelxL\xc0u\xa1\x90\xbcd\x08\x8c\xc0&Nq\xc0\xae\x15b?\xc2g\x96\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\t\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x0
35\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_chetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ssytrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[-0.8395241 , 0.156272 , -1.6810869 , 0.23832119], - [-2.985257 , -5.571 , -0.22652794, -0.83806676], - [ 0.27237308, -1.6295947 , 2.0042834 , -1.148861 ], - [-0.17183593, 0.57464546, 0.5536146 , -4.206357 ]], - - [[ 1.7666914 , 2.569005 , -0.86576384, -0.1617768 ], - [-5.143918 , 5.0426254 , -3.7237067 , 4.383015 ], - [ 0.33311516, -1.5299042 , -8.854181 , -2.896776 ], - [ 0.3419102 , 0.2669245 , -2.8250606 , 5.752488 ]]], - dtype=float32), array([[-0.8395241, -5.571 , 2.0042834, -4.206357 ], - [ 1.7666914, 5.0426254, -8.854181 , 5.752488 ]], dtype=float32), array([[-2.985257 , -1.6295947, 0.5536146], - [-5.143918 , -1.5299042, -2.8250606]], dtype=float32), array([[1.8120625, 1.5035137, 0. ], - [1.6288393, 1.8669801, 0. 
]], dtype=float32)), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[-0.83952409, 1.562720e-01, -1.6810869, 0.238321185], [2.42421508, -5.17118931, -0.226527944, -0.838066756], [1.47339451, -1.32866347, -3.3505435, -1.14886105], [-0.929541587, -0.955984473, 2.71886253, 0.748659431]], [[1.76669145, 2.56900501, -0.865763843, -0.161776796], [3.23469758, -0.362713158, -3.72370672, 4.38301516], [2.79104376, 7.36582708, -3.04437494, -2.89677596], [2.86473417, 0.981746375, -2.13533139, 5.34802151]]]> : tensor<2x4x4xf32> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_ssytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf32>, tensor) -> tensor<2x4x4xf32> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> 
loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc7) - return %2 : tensor<2x4x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) - return %2 : tensor<2x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) - return %2 : tensor<2x3xf32> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b/\x1f\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x04\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02b\t\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x02\r\xebV\xbf\xc4\x05 >\xdb-\xd7\xbfx\nt>W&\x1b@bz\xa5\xc0\xf1\xf6g\xbe\x8b\x8bV\xbf1\x98\xbc?\xa5\x11\xaa\xbfNoV\xc0\xe1\r\x93\xbfp\xf6m\xbff\xbbt\xbf\xd8\x01.@%\xa8??\xf2"\xe2?\x94j$@\xb3\xa2]\xbf\xd1\xa8%\xbeI\x05O@\x8a\xb5\xb9\xbe6Qn\xc0\xa9A\x8c@v\xa02@\xdb\xb4\xeb@\n\xd7B\xc0\xc7d9\xc0\xceW7@\xbbS{?E\xa9\x08\xc0\xfe"\xab@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\t)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x
07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_ssytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dsytrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[ 0.8251247184208595 , -2.6963562039892532 , - 0.8082445002373937 , -1.551980329390836 ], - [-2.629505060186711 , 4.427374205796291 , - -2.2111093161901074 , 7.552489598405787 ], - [ 0.2269453213819231 , 0.3650586474106988 , - -3.5933639667756205 , 4.828829679372501 ], - [-0.6415372293575187 , -0.2519326897319508 , - -1.7607827845801751 , -3.381311711243865 ]], - - [[-4.000421911405985 , 3.6303350337601055 , - 2.8066821235532355 , 1.099224389184342 ], - [-4.141622408467332 , -5.276404169116551 , - -0.8496056221591237 , -2.275319346221659 ], - [ 0.5828958067901202 , 0.9351254869793256 , - 2.7765603683442177 , -4.339686212557215 ], - [-0.6391146585297987 , 0.3129920702652711 , - -0.25441692469349864, -1.4155240723557498 ]]]), array([[ 0.8251247184208595, 4.427374205796291 , -3.5933639667756205, - -3.381311711243865 ], - [-4.000421911405985 , -5.276404169116551 , 2.7765603683442177, - -1.4155240723557498]]), array([[-2.629505060186711 , 0.3650586474106988 , -1.7607827845801751 ], - [-4.141622408467332 , 0.9351254869793256 , -0.25441692469349864]]), array([[1.3669846724688552, 1.8806358893589366, 0. ], - [1.1440109149169537, 1.8215532880266878, 0. 
]])), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[0.82512471842085955, -2.6963562039892532, 0.80824450023739369, -1.5519803293908361], [0.96498805326781766, -4.1313349231964409, -2.2111093161901074, 7.5524895984057867], [0.81575339483804743, 1.0647235400727899, -1.0064296232364345, 4.8288296793725012], [-2.3060011529502993, -2.9182106402942192, -1.7781896154088577, 2.5904630742096817]], [[-4.0004219114059847, 3.6303350337601055, 2.8066821235532355, 1.0992243891843421], [0.59643883228393779, -1.5243235004961249, -0.84960562215912372, -2.275319346221659], [2.7617960295487092, -0.57538970930521982, 0.12559406141906576, -4.3396862125572149], [-3.0281643919760217, 0.38177997229319849, 3.860398204232184, -2.5166384340510231]]]> : tensor<2x4x4xf64> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_dsytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<128xf64>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf64>, tensor) -> tensor<2x4x4xf64> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim 
%c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc7) - return %2 : tensor<2x4x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) - return %2 : tensor<2x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) - return %2 : tensor<2x3xf64> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b//\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02r\x0b\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x04A\xa4\x17\xf4kg\xea?\x1f\x01\x943#\x92\x05\xc0\x86 \xf6\x91#\xdd\xe9?\x9dMlS\xe9\xd4\xf8\xbf\x88\x1c:\xa0.\xe1\xee?8\xce\x7f\xa9|\x86\x10\xc0\xe8V\xc7\x14Z\xb0\x01\xc0\xd2!R\xd5\xbf5\x1e@\xbf\xc5\r\xdd\xa6\x1a\xea?\xbcM\xfe\x8c\x1b\t\xf1?\xdbj\xd8\xf2U\x1a\xf0\xbf\xado;\xba\xb8P\x13@\xbb\xad\x83\xbb\xb0r\x02\xc0\x1f9\xf7\xd1~X\x07\xc0)ID\xf4vs\xfc\xbfD\xcfI\xb4D\xb9\x04@\x16\xc3\xfe\x99n\x00\x10\xc0\x82.\x1c\x18\xed\n\r@\x8cn\xd7\xc1\x15t\x06@|2(Pl\x96\xf1?\x88*\xd7\xe3\x06\x16\xe3?F{\xf2\t\xa1c\xf8\xbf8z5!\xf8/\xeb\xbf4\xd3\x1f\xa1\xda3\x02\xc0)\x13I\x84(\x18\x06@\xbcw\xfd\xad\x97i\xe2\xbf\x1e\xf0.Yw\x13\xc0?dW\xd7\xb3\xd6[\x11\xc0\x04\x97\xb3@\xae9\x08\xc0\xbc\x17\xd1C\x15o\xd8?\x02\xb7%t\x18\xe2\x0e@\xac\xd8\xd0T\x13"\x04\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\x0b)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05
\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_dsytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', - xla_call_module_version=9, - nr_devices=1, -) # End paste - data_2024_12_01 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_12_01["c128"] = dict( testdata_version=1, diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index cbcddd9713f0..349b64b4ce3b 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -42,7 +42,6 @@ cc_library( "@com_google_absl//absl/types:span", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -88,7 +87,6 @@ cc_library( ":sparse_kernels", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, ) diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 4361e42827ea..3d75a02f6ae3 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -22,7 +22,6 @@ limitations under the License. #include "jaxlib/cpu/sparse_kernels.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_target_registry.h" #define JAX_CPU_REGISTER_HANDLER(name) \ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name); @@ -30,30 +29,6 @@ limitations under the License. namespace jax { namespace { -// Old-style kernels -// TODO(b/344892332): To be removed after the 6M compatibility period is over. 
- -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ctrsm", - Trsm>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ztrsm", - Trsm>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgees", - RealGees::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgees", - RealGees::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgees", ComplexGees>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgees", ComplexGees>::Kernel, "Host"); - -// FFI Kernels - JAX_CPU_REGISTER_HANDLER(lapack_strsm_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dtrsm_ffi); JAX_CPU_REGISTER_HANDLER(lapack_ctrsm_ffi); diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 1bb3f1f13405..b9c92210f311 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "absl/base/call_once.h" #include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" @@ -45,10 +43,6 @@ void GetLapackKernelsFromScipy() { return nb::cast(blas_capi[name]).data(); }; - AssignKernelFn>(blas_ptr("strsm")); - AssignKernelFn>(blas_ptr("dtrsm")); - AssignKernelFn>>(blas_ptr("ctrsm")); - AssignKernelFn>>(blas_ptr("ztrsm")); AssignKernelFn>(blas_ptr("strsm")); AssignKernelFn>(blas_ptr("dtrsm")); AssignKernelFn>(blas_ptr("ctrsm")); @@ -112,10 +106,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgeev")); - AssignKernelFn>(lapack_ptr("sgees")); - AssignKernelFn>(lapack_ptr("dgees")); - AssignKernelFn>>(lapack_ptr("cgees")); - AssignKernelFn>>(lapack_ptr("zgees")); AssignKernelFn>(lapack_ptr("sgees")); AssignKernelFn>(lapack_ptr("dgees")); AssignKernelFn>( @@ -132,10 +122,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgehrd")); - AssignKernelFn>(lapack_ptr("ssytrd")); - AssignKernelFn>(lapack_ptr("dsytrd")); - AssignKernelFn>>(lapack_ptr("chetrd")); - AssignKernelFn>>(lapack_ptr("zhetrd")); AssignKernelFn>(lapack_ptr("ssytrd")); AssignKernelFn>(lapack_ptr("dsytrd")); AssignKernelFn>(lapack_ptr("chetrd")); @@ -150,23 +136,6 @@ void GetLapackKernelsFromScipy() { nb::dict Registrations() { nb::dict dict; - dict["blas_strsm"] = EncapsulateFunction(Trsm::Kernel); - dict["blas_dtrsm"] = EncapsulateFunction(Trsm::Kernel); - dict["blas_ctrsm"] = EncapsulateFunction(Trsm>::Kernel); - dict["blas_ztrsm"] = EncapsulateFunction(Trsm>::Kernel); - dict["lapack_sgees"] = EncapsulateFunction(RealGees::Kernel); - dict["lapack_dgees"] = EncapsulateFunction(RealGees::Kernel); - dict["lapack_cgees"] = - EncapsulateFunction(ComplexGees>::Kernel); - dict["lapack_zgees"] = - EncapsulateFunction(ComplexGees>::Kernel); - dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd::Kernel); - dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd::Kernel); - dict["lapack_chetrd"] = - EncapsulateFunction(Sytrd>::Kernel); - dict["lapack_zhetrd"] = - EncapsulateFunction(Sytrd>::Kernel); - dict["lapack_strsm_ffi"] = EncapsulateFunction(lapack_strsm_ffi); dict["lapack_dtrsm_ffi"] = EncapsulateFunction(lapack_dtrsm_ffi); dict["lapack_ctrsm_ffi"] = EncapsulateFunction(lapack_ctrsm_ffi); diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 2e91bcb34281..4ec8a73801a6 
--- a/jaxlib/cpu/lapack_kernels.cc
+++ b/jaxlib/cpu/lapack_kernels.cc
@@ -18,9 +18,7 @@ limitations under the License.
 #include
 #include
 #include
-#include
 #include
-#include
 #include
 #include
 #include
@@ -33,7 +31,6 @@ limitations under the License.
 #include "jaxlib/ffi_helpers.h"
 #include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
-#include "xla/service/custom_call_status.h"
 
 static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
               "Expected LAPACK integers to be 32-bit");
@@ -70,60 +67,6 @@ void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) {
 
 //== Triangular System Solver ==//
 
-// lapack trsm
-
-template <typename T>
-typename Trsm<T>::FnType* Trsm<T>::fn = nullptr;
-
-template <typename T>
-void Trsm<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) {
-  int32_t left_side = *reinterpret_cast<int32_t*>(data[0]);
-  int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
-  int32_t trans_a = *reinterpret_cast<int32_t*>(data[2]);
-  int32_t diag = *reinterpret_cast<int32_t*>(data[3]);
-  int m = *reinterpret_cast<int32_t*>(data[4]);
-  int n = *reinterpret_cast<int32_t*>(data[5]);
-  int batch = *reinterpret_cast<int32_t*>(data[6]);
-  T* alpha = reinterpret_cast<T*>(data[7]);
-  T* a = reinterpret_cast<T*>(data[8]);
-  T* b = reinterpret_cast<T*>(data[9]);
-
-  T* x = reinterpret_cast<T*>(out);
-  if (x != b) {
-    std::memcpy(x, b,
-                static_cast<int64_t>(batch) * static_cast<int64_t>(m) *
-                    static_cast<int64_t>(n) * sizeof(T));
-  }
-
-  char cside = left_side ? 'L' : 'R';
-  char cuplo = lower ? 'L' : 'U';
-  char ctransa = 'N';
-  if (trans_a == 1) {
-    ctransa = 'T';
-  } else if (trans_a == 2) {
-    ctransa = 'C';
-  }
-  char cdiag = diag ? 'U' : 'N';
-  int lda = left_side ? m : n;
-  int ldb = m;
-
-  int64_t x_plus = static_cast<int64_t>(m) * static_cast<int64_t>(n);
-  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(lda);
-
-  for (int i = 0; i < batch; ++i) {
-    fn(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb);
-    x += x_plus;
-    a += a_plus;
-  }
-}
-
-template struct Trsm<float>;
-template struct Trsm<double>;
-template struct Trsm<std::complex<float>>;
-template struct Trsm<std::complex<double>>;
-
-// FFI Kernel
-
 template
 ffi::Error TriMatrixEquationSolver::Kernel(
     ffi::Buffer x, ffi::Buffer y,
@@ -1064,138 +1007,6 @@ template struct EigenvalueDecompositionComplex;
 
 //== Schur Decomposition ==//
 
-// lapack gees
-
-template <typename T>
-typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;
-
-template <typename T>
-void RealGees<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
-  int b = *(reinterpret_cast<int32_t*>(data[0]));
-  int n_int = *(reinterpret_cast<int32_t*>(data[1]));
-  int64_t n = n_int;
-  char jobvs = *(reinterpret_cast<uint8_t*>(data[2]));
-  char sort = *(reinterpret_cast<uint8_t*>(data[3]));
-
-  const T* a_in = reinterpret_cast<T*>(data[4]);
-
-  // bool* select (T, T) = reinterpret_cast(data[5]);
-  bool (*select)(T, T) = nullptr;
-
-  void** out = reinterpret_cast<void**>(out_tuple);
-  T* a_out = reinterpret_cast<T*>(out[0]);
-
-  T* wr_out = reinterpret_cast<T*>(out[1]);
-  T* wi_out = reinterpret_cast<T*>(out[2]);
-  T* vs_out = reinterpret_cast<T*>(out[3]);
-  int* sdim_out = reinterpret_cast<int*>(out[4]);
-  int* info_out = reinterpret_cast<int*>(out[5]);
-
-  bool* b_work = (sort != 'N') ? (new bool[n]) : nullptr;
(new bool[n]) : nullptr; - - T work_query; - int lwork = -1; - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out, - vs_out, &n_int, &work_query, &lwork, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query); - T* work = new T[lwork]; - - size_t a_size = static_cast(n) * static_cast(n) * sizeof(T); - if (a_out != a_in) { - std::memcpy(a_out, a_in, static_cast(b) * a_size); - } - - for (int i = 0; i < b; ++i) { - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out, - vs_out, &n_int, work, &lwork, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_out, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - - a_in += n * n; - a_out += n * n; - wr_out += n; - wi_out += n; - vs_out += n * n; - ++sdim_out; - ++info_out; - } - delete[] work; - delete[] b_work; -} - -template -typename ComplexGees::FnType* ComplexGees::fn = nullptr; - -template -void ComplexGees::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvs = *(reinterpret_cast(data[2])); - char sort = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - // bool* select (T, T) = reinterpret_cast(data[5]); - bool (*select)(T) = nullptr; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* r_work = - reinterpret_cast(out[1]); - T* w_out = reinterpret_cast(out[2]); - T* vs_out = reinterpret_cast(out[3]); - int* sdim_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - bool* b_work = (sort != 'N') ? 
(new bool[n]) : nullptr; - - T work_query; - int lwork = -1; - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out, - &n_int, &work_query, &lwork, r_work, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query.real()); - T* work = new T[lwork]; - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out, - &n_int, work, &lwork, r_work, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); - - a_in += n * n; - a_out += n * n; - w_out += n; - vs_out += n * n; - ++info_out; - ++sdim_out; - } - delete[] work; - delete[] b_work; -} - -template struct RealGees; -template struct RealGees; -template struct ComplexGees>; -template struct ComplexGees>; - -// FFI Kernel - template ffi::Error SchurDecomposition::Kernel( ffi::Buffer x, schur::ComputationMode mode, schur::Sort sort, @@ -1423,67 +1234,6 @@ template struct HessenbergDecomposition; //== Tridiagonal Reduction ==// -// lapack sytrd/hetrd - -template -typename Sytrd::FnType* Sytrd::fn = nullptr; - -template -void Sytrd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t n = *reinterpret_cast(data[0]); - int32_t lower = *reinterpret_cast(data[1]); - int32_t lda = *reinterpret_cast(data[2]); - int32_t batch = *reinterpret_cast(data[3]); - int32_t lwork = *reinterpret_cast(data[4]); - T* a = reinterpret_cast(data[5]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typedef typename real_type::type Real; - Real* d = reinterpret_cast(out[1]); - Real* e = reinterpret_cast(out[2]); - T* tau = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - T* work = reinterpret_cast(out[5]); - - if (a_out != a) { - std::memcpy(a_out, a, - static_cast(batch) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char cuplo = lower ? 'L' : 'U'; - - int64_t a_plus = static_cast(lda) * static_cast(n); - - for (int i = 0; i < batch; ++i) { - fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info); - a_out += a_plus; - d += n; - e += n - 1; - tau += n - 1; - ++info; - } -} - -template -int64_t Sytrd::Workspace(lapack_int lda, lapack_int n) { - char cuplo = 'L'; - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&cuplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &work, &lwork, - &info); - return info == 0 ? static_cast(std::real(work)) : -1; -} - -template struct Sytrd; -template struct Sytrd; -template struct Sytrd>; -template struct Sytrd>; - -// FFI Kernel - template ffi::Error TridiagonalReduction::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index b3f1f1df758a..572f67b7744b 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -22,7 +22,6 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" // Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either // by the nanobind wrapper that links them to an existing SciPy lapack instance, @@ -100,20 +99,6 @@ static_assert( //== Triangular System Solver ==// -// lapack trsm - -template -struct Trsm { - using FnType = void(char* side, char* uplo, char* transa, char* diag, - lapack_int* m, lapack_int* n, T* alpha, T* a, - lapack_int* lda, T* b, lapack_int* ldb); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct TriMatrixEquationSolver { using ValueType = ::xla::ffi::NativeType; @@ -490,31 +475,6 @@ struct EigenvalueDecompositionComplex { //== Schur Decomposition ==// -// lapack gees - -template -struct RealGees { - using FnType = void(char* jobvs, char* sort, bool (*select)(T, T), - lapack_int* n, T* a, lapack_int* lda, lapack_int* sdim, - T* wr, T* wi, T* vs, lapack_int* ldvs, T* work, - lapack_int* lwork, bool* bwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -template -struct ComplexGees { - using FnType = void(char* jobvs, char* sort, bool (*select)(T), lapack_int* n, - T* a, lapack_int* lda, lapack_int* sdim, T* w, T* vs, - lapack_int* ldvs, T* work, lapack_int* lwork, - typename T::value_type* rwork, bool* bwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct SchurDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -596,32 +556,6 @@ struct HessenbergDecomposition { //== Tridiagonal Reduction ==// //== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// -template -struct real_type { - typedef T type; -}; -template -struct real_type> { - typedef T type; -}; - -// lapack sytrd/hetrd - -template -struct Sytrd { - using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, - typename real_type::type* d, - typename real_type::type* e, T* tau, T* work, - lapack_int* lwork, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int lda, lapack_int n); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct TridiagonalReduction { using ValueType = ::xla::ffi::NativeType; diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index e771aa0e37d1..c4b154f1ae70 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include - #include "jaxlib/cpu/lapack_kernels.h" // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but @@ -100,77 +97,7 @@ jax::TridiagonalSolver::FnType zgtsv_; namespace jax { -#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch" - -static_assert( - std::is_same_v::FnType, - jax::Trsm::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGees::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGees::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::ComplexGees>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::ComplexGees>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); - -#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG - static auto init = []() -> int { - AssignKernelFn>(strsm_); - AssignKernelFn>(dtrsm_); - AssignKernelFn>>(ctrsm_); - AssignKernelFn>>(ztrsm_); - - AssignKernelFn>(sgees_); - AssignKernelFn>(dgees_); - AssignKernelFn>>(cgees_); - AssignKernelFn>>(zgees_); - - AssignKernelFn>(ssytrd_); - AssignKernelFn>(dsytrd_); - AssignKernelFn>>(chetrd_); - AssignKernelFn>>(zhetrd_); - - // FFI Kernels - AssignKernelFn>(strsm_); AssignKernelFn>(dtrsm_); AssignKernelFn>(ctrsm_); diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 258164cdc615..e742654e6740 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -145,9 +145,6 @@ def test_custom_call_coverage(self): cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, cuda_tridiagonal_solve.data_2025_06_16, rocm_eigh_hipsolver_syev.data_2024_08_05, - cpu_schur_lapack_gees.data_2023_07_16, - cpu_triangular_solve_blas_trsm.data_2023_07_16, - cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16, @@ -567,10 +564,6 @@ def check_schur_results(res_run, res_expected, *, rtol, atol): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_schur_results) - data = self.load_testdata(cpu_schur_lapack_gees.data_2023_07_16[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_schur_results, - expect_current_custom_calls=info["custom_call_targets"]) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -654,11 +647,6 @@ def check_triangular_solve_results(res_run, res_expected, *, rtol, atol): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_triangular_solve_results) - data = 
self.load_testdata(cpu_triangular_solve_blas_trsm.data_2023_07_16[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_triangular_solve_results, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -705,12 +693,6 @@ def func(): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata( - cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) From cb1cc37919e07e2e85220aff8b2190776528c70d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 Jun 2025 13:19:38 -0700 Subject: [PATCH 1707/1769] Fix a missing bounds check in traceback code. PiperOrigin-RevId: 772146614 --- jaxlib/traceback.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc index 4ae9cf62130a..e4f3cb35ad9b 100644 --- a/jaxlib/traceback.cc +++ b/jaxlib/traceback.cc @@ -280,7 +280,7 @@ std::vector> Traceback::RawFrames() const { // Feel free to turn this on if you like, but it might break at any time! #if PY_VERSION_HEX < 0x030d0000 for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; - f != nullptr; f = f->previous) { + f != nullptr && count < kMaxFrames; f = f->previous) { if (_PyFrame_IsIncomplete(f)) continue; Py_INCREF(f->f_code); frames[count] = {f->f_code, static_cast(_PyInterpreterFrame_LASTI(f) * @@ -288,12 +288,13 @@ std::vector> Traceback::RawFrames() const { ++count; } #else // PY_VERSION_HEX < 0x030d0000 - for (_PyInterpreterFrame* f = thread_state->current_frame; f != nullptr; - f = f->previous) { + for (_PyInterpreterFrame* f = thread_state->current_frame; + f != nullptr && count < kMaxFrames; f = f->previous) { if (_PyFrame_IsIncomplete(f)) continue; Py_INCREF(f->f_executable); - frames[count] = {reinterpret_cast(f->f_executable), - _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)}; + frames[count] = { + reinterpret_cast(f->f_executable), + static_cast(_PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT))}; ++count; } #endif // PY_VERSION_HEX < 0x030d0000 @@ -301,7 +302,7 @@ std::vector> Traceback::RawFrames() const { #else // PLATFORM_GOOGLE PyFrameObject* next; for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); - py_frame != nullptr; py_frame = next) { + py_frame != nullptr && count < kMaxFrames; py_frame = next) { frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)}; ++count; next = PyFrame_GetBack(py_frame); From 97e580c835c674894b31a398b4077cb8adbc1709 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 16 Jun 2025 23:19:16 +0200 Subject: [PATCH 1708/1769] Removed fixed suppressions --- .github/workflows/tsan-suppressions_3.13.txt | 5 +---- .github/workflows/tsan-suppressions_3.14.txt | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tsan-suppressions_3.13.txt b/.github/workflows/tsan-suppressions_3.13.txt index 483e3f0b3c2a..3095eacf8060 100644 --- a/.github/workflows/tsan-suppressions_3.13.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -34,9 +34,6 @@ race:scal_k_ race:gemm_beta race:gemm_oncopy -# 
https://github.com/python/cpython/issues/132245 -race:split_keys_entry_added -race_top:dict_dict_merge - # https://github.com/python/cpython/issues/132214 +# Fixed in Python 3.15, but not backported to 3.13, 3.14. race:type_update_dict diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt index 008b61933a0b..d987879cab58 100644 --- a/.github/workflows/tsan-suppressions_3.14.txt +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -16,6 +16,6 @@ race:scal_k_ race:gemm_beta race:gemm_oncopy -# https://github.com/python/cpython/issues/132245 -race:split_keys_entry_added -race_top:dict_dict_merge +# https://github.com/python/cpython/issues/132214 +# Fixed in Python 3.15, but not backported to 3.13, 3.14. +race:type_update_dict From 748b39f497fd736c4e9f17d247399c73e996165e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 16 Jun 2025 15:06:34 -0700 Subject: [PATCH 1709/1769] [Mosaic:TPU][NFC] Delete unused variable PiperOrigin-RevId: 772187342 --- jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 390ce9d3db32..0c0103757ea1 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6393,8 +6393,6 @@ FailureOr>> changeOffsets( {dst_offsets[0], src.offsets()[1]}, src.tiling(), src.implicit_dim()); if (row_diff != 0) { - const SmallVector implicit_shape = - src.implicitShape(vty.getShape()); FAILUREOR_ASSIGN_OR_RETURN( vregs, doRowShiftRelayout(builder, loc, vty.getShape(), vregs, src, *dst_offsets[0], ctx.target_shape)); From b622512422a3bfcbfb11ead6f73f1c9a259c4b24 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Mon, 16 Jun 2025 16:02:03 -0700 Subject: [PATCH 1710/1769] [JAX] Relax the return type of `colocated_python` decorator `colocated_python` decorator wraps a function, and the returned function has a special method `specialize` that lets the user provide explicit information of the output spec or execution devices. This `specialize` method is in principle not a part of `Callable` protocol, so access to it would not be valid typing. This change relaxes the return type of `colocated_python` to `Any` so that `specialize` method access does not cause typing check failure. 
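As a minimal sketch of the access pattern this change keeps typing-clean
(assuming the decorator is re-exported from `jax.experimental.colocated_python`;
the function body and specs below are illustrative only):

```python
import jax
from jax.experimental import colocated_python

@colocated_python.colocated_python
def double(x):
  return x * 2

# `specialize` is not part of the `Callable` protocol, so with a
# `Callable[..., Any]` return type this attribute access is flagged by type
# checkers; relaxing the return type to `Any` avoids that.
double_specialized = double.specialize(
    out_specs_fn=lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
)
```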
PiperOrigin-RevId: 772206576 --- jax/experimental/colocated_python/api.py | 2 +- jax/experimental/colocated_python/func.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index b0dc3a46fc5f..a79bd464fb92 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -110,7 +110,7 @@ def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh: ) -def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]: +def colocated_python(fun: Callable[..., Any]): """Executes the given Python function on the same devices as the arguments.""" return make_callable( fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 65464479b5ca..d8fa003b775a 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -65,7 +65,7 @@ def update( out_specs_treedef: tree_util.PyTreeDef | None = None, out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, devices: Sequence[jax.Device] | xc.DeviceList | None = None, - ) -> Any: + ): """Creates a new specialization with overrides.""" if in_specs_treedef is None: in_specs_treedef = self.in_specs_treedef @@ -234,7 +234,7 @@ def _make_pop_result_fun( out_specs_treedef = specialization.out_specs_treedef - def lowered_fun() -> Any: + def lowered_fun(): result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid) return tree_util.tree_unflatten(out_specs_treedef, result_leaves) @@ -294,7 +294,7 @@ def _get_specialized_func( # Asynchronous execution function that has known output_specs. async_execution_func = None - def specialized_func(*args, **kwargs) -> Any: + def specialized_func(*args, **kwargs): """Specialized function to be executed with given args and kwargs.""" nonlocal specialization, async_execution_func with mutex: @@ -356,24 +356,21 @@ def make_callable( fun: Callable[..., Any], fun_sourceinfo: str | None, fun_signature: inspect.Signature | None, -) -> Callable[..., Any]: +): """Makes a colocated Python callable.""" return _make_callable( FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() ) -def _make_callable( - info: FunctionInfo, - specialization: Specialization, -) -> Callable[..., Any]: +def _make_callable(info: FunctionInfo, specialization: Specialization): """Internal implementation of make_callable.""" def specialize( in_specs: ShapeDtypeStructTree | None = None, out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, devices: Sequence[jax.Device] | None = None, - ) -> Callable[..., Any]: + ): """Returns a colocated Python callable with extra specialization. Args: @@ -410,7 +407,7 @@ def specialize( ) @api_boundary - def __call__(*args, **kwargs) -> Any: + def __call__(*args, **kwargs): """Executes the function. If the output specs are not known, the very first execution will be From 2dfacedaf880aa8736cd161dd7634bedbcb841e7 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Mon, 16 Jun 2025 18:40:46 -0700 Subject: [PATCH 1711/1769] Add custom-call ops to roofline. Recursively calculates the roofline result for the primitives from the custom function. 
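A hedged sketch of what this enables (the import path and shapes here are
assumptions; the tests below show the canonical usage):

```python
import jax
import jax.numpy as jnp
from jax.experimental import roofline

@jax.custom_jvp
def f(x):
  return jnp.sin(x) + x**2

@f.defjvp
def f_jvp(primals, tangents):
  # The JVP body has no effect on the roofline estimate.
  return f(primals[0]), tangents[0]

# The custom_jvp_call is unrolled via its call_jaxpr, so this now matches the
# roofline result of the undecorated function.
_, result = roofline.roofline(f)(jnp.ones((3, 8)))
```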
PiperOrigin-RevId: 772252912 --- jax/experimental/roofline/roofline.py | 10 ++++++ tests/roofline_test.py | 52 +++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py index fcfe3ff4b9ff..dbfe3c983cc0 100644 --- a/jax/experimental/roofline/roofline.py +++ b/jax/experimental/roofline/roofline.py @@ -188,6 +188,16 @@ def calculate_peak_hbm_bytes() -> int: pin_lhs_in_vmem=pin_lhs_in_vmem, pin_rhs_in_vmem=pin_rhs_in_vmem, ) + elif "call_jaxpr" in eqn.params: + # Used for custom_jvp_call_p. Recursively calculates roofline result for + # all primitives in the custom function. + result += _roofline_interpreter( + util.wrap_name(f_name, eqn.primitive.name), + eqn.params['call_jaxpr'], + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) else: if eqn.primitive not in _rooflines: msg = f"No roofline rule for {eqn.primitive}." diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 9ed85c506814..a234c0fe4fa8 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -46,6 +46,31 @@ def create_inputs( return mesh, tuple(arrays) +def example_function(x): + return jnp.sin(x) + x**2 + + +@jax.custom_jvp +def example_custom_function(x): + """Example custom function. + + Small wrapper around `example_function`. We define `example_custom_function` + separately since we add the `@jax.custom_jvp` decorator and want to compare + its behavior to `example_function`'s in tests. + """ + return example_function(x) + + +@example_custom_function.defjvp +def example_custom_function_jvp(primals, tangents): + """Example custom function jvp. + + Normally this function would define a mathematically correct JVP, but its + definition has 0 effect on the roofline result, so we keep it very simple. + """ + return example_custom_function(primals), tangents + + class RooflineTest(jtu.JaxTestCase): def setUp(self): @@ -802,6 +827,33 @@ def test_reduce_sum_with_axis(self): result.unfused_hbm_bytes, self._bytes_per_word * expected_memory ) + def test_custom_jvp_call_p_roofline(self): + dummy_input = jnp.ones((3, 8)) + + _, base_result = roofline.roofline(example_function)(dummy_input) + _, custom_result = roofline.roofline(example_custom_function)(dummy_input) + + self.assertEqual(custom_result.unfused_flops, base_result.unfused_flops) + self.assertEqual( + custom_result.unfused_hbm_bytes, base_result.unfused_hbm_bytes + ) + + def test_custom_jvp_call_p_roofline_with_neg(self): + dummy_input = jnp.ones((3, 8)) + + def with_neg(f): + return lambda x: jax.lax.neg(f(x)) + + _, base_result = roofline.roofline(with_neg(example_function))(dummy_input) + _, custom_result = roofline.roofline(with_neg(example_custom_function))( + dummy_input + ) + + self.assertEqual(custom_result.unfused_flops, base_result.unfused_flops) + self.assertEqual( + custom_result.unfused_hbm_bytes, base_result.unfused_hbm_bytes + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 0f5cdba039fddda9119698a463929aef82244ad0 Mon Sep 17 00:00:00 2001 From: Sannidhya Chauhan Date: Mon, 16 Jun 2025 19:36:23 -0700 Subject: [PATCH 1712/1769] Removing Tensorflow references from the document. PiperOrigin-RevId: 772266335 --- docs/profiling.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/profiling.md b/docs/profiling.md index a6cdf882f98e..3800a7ce140e 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -292,10 +292,10 @@ jax.profiler.stop_trace() default). 
`2`: Includes level 1 traces plus high-level program execution details like - expensive TensorFlow or XLA operations. + expensive XLA operations. `3`: Includes level 2 traces plus more verbose, low-level program execution - details such as cheap TensorFlow operations. + details such as cheap XLA operations. 2. `python_tracer_level`: Controls whether Python tracing is enabled. @@ -303,7 +303,7 @@ jax.profiler.stop_trace() `0`: Disables Python function call tracing. - `> 0`: Enables Python tracing (this is the default). + `1`: Enables Python tracing (this is the default). #### Advanced configuration options From acf99e5e6b1b77409f60c89f1b9f6c8b9c568a0e Mon Sep 17 00:00:00 2001 From: Sannidhya Chauhan Date: Mon, 16 Jun 2025 19:46:59 -0700 Subject: [PATCH 1713/1769] Add test for programmatic tracing with options. PiperOrigin-RevId: 772268506 --- tests/profiler_test.py | 50 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 9784298405d2..82d3ec8437d7 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -113,6 +113,31 @@ def testProgrammaticProfiling(self): self.assertIn(b"/device:TPU", proto) self.assertIn(b"pxla.py", proto) + def testProgrammaticProfilingWithOptions(self): + with tempfile.TemporaryDirectory() as tmpdir: + try: + options = jax.profiler.ProfileOptions() + options.python_tracer_level = 0 + jax.profiler.start_trace(tmpdir, profiler_options=options) + jax.pmap(lambda x: jax.lax.psum(x + 1, "i"), axis_name="i")( + jnp.ones(jax.local_device_count()) + ) + finally: + jax.profiler.stop_trace() + + proto_path = glob.glob( + os.path.join(tmpdir, "**/*.xplane.pb"), recursive=True + ) + self.assertEqual(len(proto_path), 1) + with open(proto_path[0], "rb") as f: + proto = f.read() + # Verify that the serialized proto contains host and device traces, and + # does not contain Python traces. + self.assertIn(b"/host:CPU", proto) + if jtu.test_device_matches(["tpu"]): + self.assertIn(b"/device:TPU", proto) + self.assertNotIn(b"pxla.py", proto) + def testProgrammaticProfilingPathlib(self): with tempfile.TemporaryDirectory() as tmpdir_string: tmpdir = pathlib.Path(tmpdir_string) @@ -133,6 +158,29 @@ def testProgrammaticProfilingPathlib(self): self.assertIn(b"/device:TPU", proto) self.assertIn(b"pxla.py", proto) + def testProgrammaticProfilingWithOptionsPathlib(self): + with tempfile.TemporaryDirectory() as tmpdir_string: + tmpdir = pathlib.Path(tmpdir_string) + try: + options = jax.profiler.ProfileOptions() + options.advanced_configuration = {"tpu_trace_mode": "TRACE_ONLY_HOST"} + jax.profiler.start_trace(tmpdir, profiler_options=options) + jax.pmap(lambda x: jax.lax.psum(x + 1, "i"), axis_name="i")( + jnp.ones(jax.local_device_count()) + ) + finally: + jax.profiler.stop_trace() + + proto_path = tuple(tmpdir.rglob("*.xplane.pb")) + self.assertEqual(len(proto_path), 1) + proto = proto_path[0].read_bytes() + # Verify that the serialized proto contains host traces and does not + # contain TPU device traces. + self.assertIn(b"/host:CPU", proto) + if jtu.test_device_matches(["tpu"]): + self.assertNotIn(b"/device:TPU", proto) + self.assertIn(b"pxla.py", proto) + def testProfilerGetFDOProfile(self): # Tests stop_and_get_fod_profile could run. 
try: @@ -182,7 +230,7 @@ def testProgrammaticProfilingContextManager(self): def testProgrammaticGpuCuptiTracing(self): @jit def xy_plus_z(x, y, z): - return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z k = jax.random.key(0) s = 1, 16, 16 jax.devices() From 4be1402b9915ea18f3700866db406afdd542470d Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 16 Jun 2025 22:39:48 -0700 Subject: [PATCH 1714/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/54d83f35e69f82321308f1c479823521ce5cc4ed. PiperOrigin-RevId: 772316816 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ba4d13901b39..a585c499286d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "71e3a0d40d7f368b239dc664e3bb0a455a80541b" -XLA_SHA256 = "5729c77cc32c5513f18802dd46dcf4c5b0457a0fb5ad29d1c543734d23afdaba" +XLA_COMMIT = "54d83f35e69f82321308f1c479823521ce5cc4ed" +XLA_SHA256 = "e0c71f7f7ae2862ad6f105a8732cd6c4f1fa04fd9e98c0ab90a3460c924aa8cf" def repo(): tf_http_archive( From 93453f54ed0051a0a0cee0409edcf41cbf736086 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Tue, 17 Jun 2025 03:26:23 -0700 Subject: [PATCH 1715/1769] Pass through the `use_shardy_partitioner` with `jax.config.jax_use_shardy_partitioner`. PiperOrigin-RevId: 772397443 --- jax/experimental/jax2tf/tests/sharding_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 20193a931b63..33e78da18021 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -109,6 +109,7 @@ def log_jax_hlo(self, f_jax, args: Sequence[Any], *, num_partitions=num_partitions, device_assignment=device_assignment, use_spmd_partitioning=use_spmd_partitioning, + use_shardy_partitioner=jax.config.jax_use_shardy_partitioner, ) executable = backend.compile_and_load( jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore From 91651f8078bc7c86bdc4adf76f9d67cc2529271b Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 17 Jun 2025 06:18:20 -0700 Subject: [PATCH 1716/1769] [Mosaic] Use BF16 ops for math::PowF on TPUv6+. 
PiperOrigin-RevId: 772444344 --- jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index a88b70d302af..b74a5ae15137 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1082,7 +1082,6 @@ const llvm::StringMap &rules() { } const llvm::StringMap &bf16_ops_min_supported_versions() { - constexpr int kAlwaysUpcast = std::numeric_limits::max(); static const auto m = new llvm::StringMap{ {arith::DivFOp::getOperationName(), 4}, {arith::SelectOp::getOperationName(), 5}, @@ -1092,7 +1091,7 @@ const llvm::StringMap &bf16_ops_min_supported_versions() { {arith::SubFOp::getOperationName(), 6}, {arith::MaximumFOp::getOperationName(), 6}, {arith::MinimumFOp::getOperationName(), 6}, - {math::PowFOp::getOperationName(), kAlwaysUpcast}, + {math::PowFOp::getOperationName(), 6}, {math::TanhOp::getOperationName(), 6}, {math::ExpOp::getOperationName(), 6}, {math::Exp2Op::getOperationName(), 6}, From 077f2b669a412d40ee1f813499f0b4936fa36f53 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 17 Jun 2025 08:25:20 -0700 Subject: [PATCH 1717/1769] Update Pallas debugging doc with TPU interpret mode + dynamic race detector. PiperOrigin-RevId: 772483044 --- jax/experimental/pallas/g3doc/debugging.md | 61 +++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md index 791705d00d30..f1f22999d3af 100644 --- a/jax/experimental/pallas/g3doc/debugging.md +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -16,10 +16,39 @@ a ticket on https://github.com/jax-ml/jax/issues. ### Interpret (HLO) Mode -Passing in `interpret=True` into `pl.pallas_call` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. +Passing in `interpret=True` into `pl.pallas_call` or `pl.core_map` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. Note that interpret mode will not be able to fully replicate the behavior or programs that use communication (DMAs) between devices. This is because low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations. 
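For reference, a minimal sketch of interpret (HLO) mode in use; the kernel body
and shapes are illustrative, not part of the patch:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,  # run via HLO instead of lowering to Mosaic/Triton
)(jnp.zeros((8, 128), jnp.float32))
```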
+### TPU Interpret Mode
+
+TPU interpret mode is similar to [interpret (HLO) mode](#interpret-hlo-mode),
+but TPU interpret mode explicitly simulates accesses to TPU memory (HBM, VMEM,
+SMEM, etc.), communication via remote DMAs, TPU synchronization operations
+(e.g., barriers and semaphores), and parallel execution of kernels distributed
+across
+[multiple TPUs](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) and
+[Megacore cores](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html#megacore).
+
+TPU interpret mode is slower than interpret (HLO) mode, but it can be useful for
+developing and debugging distributed TPU kernels with explicit communication and
+synchronization. With this mode, kernels can be run on CPU -- enabling local
+development (with no TPU), using a debugger and inspecting the state of
+simulated TPU buffers and semaphores, etc.
+
+To use TPU interpret mode, pass `interpret=pltpu.InterpretParams()` into
+`pl.pallas_call` or `pl.core_map`. For examples, see
+`test_matmul_example` in
+[tpu_pallas_interpret_test.py](https://github.com/jax-ml/jax/blob/main/tests/pallas/tpu_pallas_interpret_test.py#:~:text=test_matmul_example)
+and
+`test_right_permute_example` and the other tests in
+[tpu_pallas_interpret_distributed_test.py](https://github.com/jax-ml/jax/blob/main/tests/pallas/tpu_pallas_interpret_distributed_test.py#:~:text=test_right_permute_example).
+
+The behavior of TPU interpret mode can be configured via arguments to
+[`pltpu.InterpretParams`](https://github.com/jax-ml/jax/blob/main/jax/_src/pallas/mosaic/interpret.py#:~:text=class%20InterpretParams). For example, use `num_cores_per_device=2`
+to simulate Megacore or `uninitialized_memory='zero'` to initialize simulated
+TPU buffers with zeros instead of NaNs.
+
 ### debug_print
 
 The `pl.debug_print` function can be used to print runtime values inside of a kernel.
@@ -160,11 +189,39 @@
 spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan
 
 
 
+### Dynamic Race Detection
+
+[TPU Interpret Mode](#tpu-interpret-mode) includes a dynamic race detector.
+While running a kernel, it can detect and log data races -- pairs of accesses
+to shared memory (HBM, VMEM, SMEM, etc.) that are not properly synchronized.
+
+To enable the dynamic race detector, use the option `detect_races=True` in the
+`pltpu.InterpretParams` passed to `pl.pallas_call`:
+
+```python
+pl.pallas_call(
+    kernel,
+    ...,
+    interpret=pltpu.InterpretParams(..., detect_races=True),
+)
+```
+
+If any data races are detected while running the kernel, a message will be
+printed -- for example:
+
+```
+RACE DETECTED
+  write ... from ...jax/tests/pallas/tpu_pallas_interpret_distributed_test.py:1038:10 (InterpretDistributedTest.test_race_detection..kernel.._)
+  write ... from .../jax/tests/pallas/tpu_pallas_interpret_distributed_test.py:1038:10 (InterpretDistributedTest.test_race_detection..kernel.._)
+```
+
+
+
 ## Useful Command line flags
 
   * OOB Checks: `--xla_mosaic_on_device_checks=bounds`
   * Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true`
-
+
   * Dump Mosaic: `--xla_mosaic_dump_to=`
   * Enable trace markers in XProf: `--xla_enable_transpose_trace`

From 124c723a616e4cd7a17fccd2b317ce6f1cd84dbb Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 17 Jun 2025 08:42:04 -0700
Subject: [PATCH 1718/1769] Prefer binaries in NVIDIA `nvcc` wheel over system
 CUDA installation in Mosaic GPU implementation. 
PiperOrigin-RevId: 772489065 --- jax/experimental/mosaic/gpu/core.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 6e664f626e4d..f7a471e9b123 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -28,6 +28,7 @@ import itertools import jax +from jax._src import lib from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.lib import mosaic_gpu_dialect as dialect @@ -54,17 +55,9 @@ from . import utils # MLIR can't find libdevice unless we point it to the CUDA path -# TODO(apaszke): Unify with jax._src.lib.cuda_path -cuda_root = "/usr/local/cuda" +cuda_root = lib.cuda_path or "/usr/local/cuda" +os.environ["CUDA_ROOT"] = cuda_root PYTHON_RUNFILES = os.environ.get("PYTHON_RUNFILES") -if os.environ.get("CUDA_ROOT") is None: - if PYTHON_RUNFILES: - cuda_nvcc_root = os.path.join(PYTHON_RUNFILES, "cuda_nvcc") - if os.path.exists(cuda_nvcc_root): - cuda_root = cuda_nvcc_root - os.environ["CUDA_ROOT"] = cuda_root -else: - cuda_root = os.environ["CUDA_ROOT"] PTXAS_PATH = os.path.join(cuda_root, "bin/ptxas") NVDISASM_PATH = os.path.join(cuda_root, "bin/nvdisasm") From 332aa354d42a8eb2380f5ad40125ccef6e3b5298 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Tue, 17 Jun 2025 09:16:13 -0700 Subject: [PATCH 1719/1769] Add an API to overwrite the current execution_stream_id and respect it in XLA CPU dispatch. PiperOrigin-RevId: 772500636 --- jaxlib/_jax/__init__.pyi | 4 ++++ jaxlib/pjit.cc | 8 ++++++-- jaxlib/pmap_lib.cc | 7 +++++-- jaxlib/py_client.h | 6 ++++++ jaxlib/py_executable.cc | 5 ++++- jaxlib/xla.cc | 4 ++++ jaxlib/xla_client.py | 13 ++++++++++++- 7 files changed, 41 insertions(+), 6 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 08eca19da5b1..c1a1ecdab28c 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -688,6 +688,10 @@ class ExecuteResults: def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]: ... def consume_token(self) -> ShardedToken: ... +def get_execution_stream_id() -> int: ... + +def set_execution_stream_id(new_id: int): ... + class LoadedExecutable: client: Client def local_devices(self) -> list[Device]: ... diff --git a/jaxlib/pjit.cc b/jaxlib/pjit.cc index 1e9d53547b1d..6196c945a63a 100644 --- a/jaxlib/pjit.cc +++ b/jaxlib/pjit.cc @@ -55,6 +55,7 @@ limitations under the License. #include "jaxlib/jax_jit.h" #include "jaxlib/nb_class_ptr.h" #include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" #include "jaxlib/py_executable.h" #include "jaxlib/py_values.h" #include "jaxlib/python_ref_manager.h" @@ -754,8 +755,11 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, xla::ifrt::ExecuteOptions execute_options = cache_entry->executable->options(); execute_options.launch_id = cache_entry->executable->GetNextLaunchId(); - execute_options.execution_stream_id = - tsl::Env::Default()->GetCurrentThreadId(); + execute_options.execution_stream_id = xla::GetExecutionStreamId(); + if (execute_options.execution_stream_id == 0) { + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + } // A vector of [num_outputs]. 
std::vector output_arrays; diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc index 4a4e20f8f55b..5f9020edc6bc 100644 --- a/jaxlib/pmap_lib.cc +++ b/jaxlib/pmap_lib.cc @@ -632,8 +632,11 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, xla::ifrt::ExecuteOptions execute_options = cache_entry.executable->options(); execute_options.launch_id = cache_entry.executable->GetNextLaunchId(); - execute_options.execution_stream_id = - tsl::Env::Default()->GetCurrentThreadId(); + execute_options.execution_stream_id = xla::GetExecutionStreamId(); + if (execute_options.execution_stream_id == 0) { + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + } // A vector of [num_outputs]. std::vector output_arrays; diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h index 772dba864684..da89b4718f76 100644 --- a/jaxlib/py_client.h +++ b/jaxlib/py_client.h @@ -256,6 +256,12 @@ class PyClient { memory_spaces_; }; +// Returns the execution stream id set for the current thread. +inline int64_t& GetExecutionStreamId() { + thread_local int64_t execution_stream_id = 0; + return execution_stream_id; +} + } // namespace xla #endif // JAXLIB_PY_CLIENT_H_ diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index c934c83adebc..f3acfa8f62e3 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -367,7 +367,10 @@ absl::StatusOr PyLoadedExecutable::ExecuteSharded( xla::ifrt::ExecuteOptions options = options_; options.launch_id = GetNextLaunchId(); options.fill_status = with_tokens; - options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + options.execution_stream_id = GetExecutionStreamId(); + if (options.execution_stream_id == 0) { + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + } std::optional>> returned_futures; if (with_tokens) { returned_futures.emplace(); diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc index 8206316bcb73..6c8e3fdb4ad1 100644 --- a/jaxlib/xla.cc +++ b/jaxlib/xla.cc @@ -518,6 +518,10 @@ NB_MODULE(_jax, m) { .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) .def("consume_token", &PyExecuteResults::ConsumeToken); + m.def("get_execution_stream_id", []() { return GetExecutionStreamId(); }); + m.def("set_execution_stream_id", + [](int64_t id) { GetExecutionStreamId() = id; }); + nb::class_(m, "LoadedExecutable") .def_prop_ro("client", &PyLoadedExecutable::client) .def("local_devices", &PyLoadedExecutable::AddressableDevices) diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 90c9478abf93..a5c85276ce2c 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -43,7 +43,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 354 +_version = 355 # An internal increasing version number for protecting jaxlib code against # ifrt changes. @@ -519,6 +519,17 @@ def tracebacks(enabled=True): _xla.set_tracebacks_enabled(saved) +@contextlib.contextmanager +def execution_stream_id(new_id: int): + """Context manager that overwrites and restores the current thread's execution_stream_id.""" + saved = _xla.get_execution_stream_id() + _xla.set_execution_stream_id(new_id) + try: + yield + finally: + _xla.set_execution_stream_id(saved) + + XlaRuntimeError = _xla.XlaRuntimeError # Perform one last garbage collection of deferred Python references. 
This is From f22896ac23928b05488bf4524818fb241fc3817b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 17 Jun 2025 09:32:04 -0700 Subject: [PATCH 1720/1769] jax.experimental.enable_x64: add warning to docstring --- jax/experimental/x64_context.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/jax/experimental/x64_context.py b/jax/experimental/x64_context.py index 1772d466b006..3ef5289df4f1 100644 --- a/jax/experimental/x64_context.py +++ b/jax/experimental/x64_context.py @@ -30,6 +30,13 @@ def enable_x64(new_val: bool = True): """Experimental context manager to temporarily enable X64 mode. + .. warning:: + + This context manager remains experimental because it is fundamentally broken + and can result in unexpected behavior, particularly when used in conjunction + with JAX transformations like :func:`jax.jit`, :func:`jax.vmap`, :func:`jax.grad`, + and others. See https://github.com/jax-ml/jax/issues/5982 for details. + Usage:: >>> x = np.arange(5, dtype='float64') @@ -40,7 +47,7 @@ def enable_x64(new_val: bool = True): See Also -------- - jax.experimental.enable_x64 : temporarily enable X64 mode. + jax.experimental.disable_x64 : temporarily disable X64 mode. """ with config.enable_x64(new_val): yield @@ -49,6 +56,13 @@ def enable_x64(new_val: bool = True): def disable_x64(): """Experimental context manager to temporarily disable X64 mode. + .. warning:: + + This context manager remains experimental because it is fundamentally broken + and can result in unexpected behavior, particularly when used in conjunction + with JAX transformations like :func:`jax.jit`, :func:`jax.vmap`, :func:`jax.grad`, + and others. See https://github.com/jax-ml/jax/issues/5982 for details. + Usage:: >>> x = np.arange(5, dtype='float64') From 0fd082136ea62123f0996b3dd5d286e3d0b02680 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 17 Jun 2025 10:27:11 -0700 Subject: [PATCH 1721/1769] [Pallas TPU] Add flag to enable using registers to keep track of slot info PiperOrigin-RevId: 772527011 --- jax/_src/pallas/mosaic/pipeline.py | 27 ++++++++++++++++++------ tests/pallas/tpu_pallas_pipeline_test.py | 4 +++- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index de7895879f57..afa1972128d0 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -561,7 +561,8 @@ def current_ref(self): @property def current_slot_index(self): """Index in double buffer corresponding to the current slot.""" - # TODO(ramiroleal): Fix race condition when returning register value for current_slot. + if self._current_slot_reg is not None: + return self._current_slot_reg return self.current_slot[0] @property @@ -620,10 +621,11 @@ def swap_slots(self, predicate: bool | jax.Array = True) -> "BufferedRef": new_current_slot = lax.select( predicate, self.next_slot_index, self.current_slot_index ) - result = self.with_slot_index(new_current_slot) - # TODO(ramiroleal): Fix race condition when using register value for current_slot. 
- result.save_slots() - return result + if self._current_slot_reg is not None: + return self.with_slot_index(new_current_slot) + assert isinstance(self.current_slot, jax.Array) + self.current_slot[0] = new_current_slot + return self def load_slots(self) -> "BufferedRef": """Load slot information into registers.""" @@ -830,6 +832,7 @@ def __init__( last_cycle=None, init_accumulators=None, trace_scopes=True, + use_sreg_for_state: bool = False, ): """Initializes scheduler. @@ -843,6 +846,8 @@ def __init__( init_accumulators: do we zero-initialize accumulator state for this invocation of the pipeline. trace_scopes: whether to use named_scope to trace blocks in the pipeline. + use_sreg_for_state: optional bool, indicates whether to use sregs for + current_slot state. """ self.step = step self.grid = grid @@ -850,6 +855,7 @@ def __init__( self.last_cycle = last_cycle self.init_accumulators = init_accumulators self.trace_scopes = trace_scopes + self.use_sreg_for_state = use_sreg_for_state # Total number of linear steps. self.num_steps = _grid_size(grid) @@ -916,7 +922,8 @@ def initialize(self, buffered_ref, src_ref, schedule=None): def _init_slots(): buffered_ref.init_slots() - buffered_ref = buffered_ref.load_slots() + if self.use_sreg_for_state: + buffered_ref = buffered_ref.load_slots() @pl.when(do_copy & buffered_ref.is_input) def _copy_in(): @@ -1029,7 +1036,9 @@ def finalize(self, buffered_ref, dst_ref, schedule=None): def _end(): if buffered_ref.is_output: buffered_ref.wait_out(dst_ref, self.indices) - buffered_ref.save_slots() + + if self.use_sreg_for_state: + buffered_ref.save_slots() def swap_slots( self, buffered_ref, hbm_ref, schedule=None @@ -1318,6 +1327,7 @@ def emit_pipeline( dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None, trace_scopes: bool = True, no_pipelining: bool = False, + use_sreg_for_state: bool = False, ): """Creates a function to emit a manual pallas pipeline. @@ -1347,6 +1357,8 @@ def emit_pipeline( no_pipelining: If True, turns off pipelining and all copies will be made synchronous. This is useful for debugging multiple-buffering related bugs. + use_sreg_for_state: optional bool, indicates whether to use sregs for + current_slot state. 
""" if any(not isinstance(d, (int, jax.Array)) for d in grid): grid_types = tuple(type(d) for d in grid) @@ -1459,6 +1471,7 @@ def make_scheduler(step, indices): last_cycle=last_cycle, init_accumulators=init_accumulators, trace_scopes=trace_scopes, + use_sreg_for_state=use_sreg_for_state, ) def loop_body(step, carry): diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 94f8359dbaed..515e4a3c26ea 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -129,8 +129,9 @@ def setUp(self): @parameterized.product( no_pipelining=[False, True], + use_sreg_for_state=[False, True], ) - def test_pipeline_matmul(self, no_pipelining): + def test_pipeline_matmul(self, no_pipelining, use_sreg_for_state): k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -152,6 +153,7 @@ def matmul_kernel(x_ref, y_ref, z_ref): ], out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), no_pipelining=no_pipelining, + use_sreg_for_state=use_sreg_for_state, )(x_ref, y_ref, z_ref) z = pl.pallas_call( From 784be1f928c1514e71150a15e2b12e6b22708e46 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Fri, 30 May 2025 14:20:05 -0700 Subject: [PATCH 1722/1769] add psend and precv to jax/lax/parallel --- docs/gpu_performance_tips.md | 7 +- jax/_src/core.py | 16 ++- jax/_src/lax/parallel.py | 179 +++++++++++++++++++++++++++++- jax/experimental/jax2tf/jax2tf.py | 2 + jax/lax/__init__.py | 2 + tests/shard_map_test.py | 91 +++++++++++++++ 6 files changed, 289 insertions(+), 8 deletions(-) diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index c9034a515501..f62523631872 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -256,9 +256,9 @@ Run the real workflow, if you found these loggings in the running log, it means ### Pipeline Parallelism on GPU -XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling technique -where the forward and backward pass are split into multiple pipeline stages. -Each device (or device group) processes the result of the previous +XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling +technique where the forward and backward pass are split into multiple pipeline +stages. Each device (or device group) processes the result of the previous pipeline stage (or the pipeline input) and sends its partial result to the next stage until the end of the pipeline is reached. This optimization works best when the latency of the computation is larger than communication. 
At compile @@ -475,6 +475,7 @@ def main(_): output_buffer = entry_computation(weights, input_buffer, mesh) print(f"output_buffer = \n{output_buffer}") ``` + ## NCCL flags These Nvidia NCCL flag values may be useful for single-host multi-device diff --git a/jax/_src/core.py b/jax/_src/core.py index dcd0d5dd6591..c9addb08b0cc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2032,6 +2032,14 @@ def get_vma(vma, mesh): assert isinstance(vma, frozenset) return vma + +class SingleSideCollectiveEffect(effects.Effect): + __str__ = lambda _: "one-sided communication" + + +single_side_collective_effect = SingleSideCollectiveEffect() +effects.control_flow_allowed_effects.add_type(SingleSideCollectiveEffect) + class ShapedArray(UnshapedArray): __slots__ = ['shape', 'sharding', 'vma'] # inherits slots from parent array_abstraction_level = 2 @@ -2175,8 +2183,12 @@ def standard_insert_pvary(*args): in_vma = [frozenset() if (aval := get_aval(a)) is abstract_token else aval.vma for a in args] # pytype: disable=attribute-error out_vma = frozenset.union(*in_vma) - return [pvary(arg, tuple(n for n in out_vma if n not in src)) - if out_vma - src else arg for arg, src in zip(args, in_vma)] + return [ + pvary(arg, tuple(n for n in out_vma if n not in src)) + if isinstance(get_aval(arg), ShapedArray) and out_vma - src + else arg + for arg, src in zip(args, in_vma) + ] def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: if not config._check_vma.value: diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index a5bb7222143d..4c1d74d484c7 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -36,11 +36,12 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.mesh import get_abstract_mesh -from jax._src.core import pvary +from jax._src.core import abstract_token, pvary from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo +from jax._src.lib import xla_client as xc from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, unzip2) import jax.numpy as jnp @@ -355,6 +356,80 @@ def bind(leaf): return ppermute_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm))) return tree_util.tree_map(bind, x) + +def psend(x, axis_name, perm): + """Perform a collective send according to the permutation ``perm``. + + If ``x`` is a pytree then the result is equivalent to mapping this function to + each leaf in the tree. + + This function is an analog of the Send HLO. + + Args: + x: array(s) with a mapped axis named ``axis_name``. + axis_name: hashable Python object used to name a pmapped axis (see the + :func:`jax.pmap` documentation for more details). + perm: list of pairs of ints, representing ``(source_index, + destination_index)`` pairs that encode how the mapped axis named + ``axis_name`` should be shuffled. The integer values are treated as + indices into the mapped axis ``axis_name``. Any two pairs should not have + the same source index or the same destination index. For each index of the + axis ``axis_name`` that does not correspond to a destination index in + ``perm``, the corresponding values in the result are filled with zeros of + the appropriate type. The semantics here are platform-specific, and for + GPU they correspond to NCCL send. + + Returns: + A compiler token that can be used by precv and lax.optimzation_barrier to + enforce ordering of collective ops. 
+ """ + axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,) + + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return psend_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm))) + + return tree_util.tree_map(bind, x) + + +def precv(token, out_shape, axis_name, perm): + """Perform a collective recv according to the permutation ``perm``. + + This function is an analog of the Recv HLO. + + Args: + token: a compiler token, either generated by a matching psend or + lax.create_token(). This is used to enforce control dependencies between + collectives. + out_shape: ShapeDtypeStruct(s) containing the dtype and shape + of the result. + axis_name: hashable Python object used to name a pmapped axis (see the + :func:`jax.pmap` documentation for more details). + perm: list of pairs of ints, representing ``(source_index, + destination_index)`` pairs that encode how the mapped axis named + ``axis_name`` should be shuffled. The integer values are treated as + indices into the mapped axis ``axis_name``. Any two pairs should not have + the same source index or the same destination index. For each index of the + axis ``axis_name`` that does not correspond to a destination index in + ``perm``, the corresponding values in the result are filled with zeros of + the appropriate type. The semantics here are platform-specific, and for + GPU they correspond to NCCL recv. + + Returns: + Array(s) with the same shape as ``out_shape``. + """ + axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,) + + return precv_p.bind( + token, + out_shape=core.ShapedArray( + out_shape.shape, out_shape.dtype + ), + axis_name=axis_name, + perm=tuple(map(tuple, perm)), + ) + + def pshuffle(x, axis_name, perm): """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding @@ -1027,12 +1102,12 @@ def broadcast_positional(ct, arg): batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes') -def _ppermute_lowering(ctx, x, *, axis_name, perm): +def _pcollectives_lowering_common(ctx, *, axis_name, perm, op_name): replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None) group_size = len(replica_groups[0]) srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm) if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))): - msg = "ppermute sources and destinations must be unique, got {}." + msg = f"{op_name} sources and destinations must be unique, got {{}}." 
raise ValueError(msg.format(perm)) full_perm = np.zeros((len(replica_groups), len(perm), 2), np.int64) @@ -1054,10 +1129,17 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)) else: other_args = {} + return full_perm, other_args + +def _ppermute_lowering(ctx, x, *, axis_name, perm): + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="ppermute" + ) return hlo.CollectivePermuteOp( x, mlir.dense_int_elements(full_perm), **other_args).results + def _ppermute_transpose_rule(t, x, perm, axis_name): srcs, dsts = unzip2(perm) inverse_perm = list(zip(dsts, srcs)) @@ -1094,6 +1176,97 @@ def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): batching.fancy_primitive_batchers[ppermute_p] = _ppermute_batcher batching.skippable_batchers[ppermute_p] = partial(_names_in_param, 'axis_name') + +def _psend_lowering_gpu(ctx, x, *, axis_name, perm): + if ("cuda" not in ctx.module_context.platforms): + raise NotImplementedError("psend is currently only implemented on GPUs") + + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="psend" + ) + token = hlo.create_token() + send_op = hlo.SendOp( + [x], + token, + source_target_pairs=mlir.dense_int_elements(full_perm), + **other_args, + ) + axis_ctx = ctx.module_context.axis_context + if not isinstance(axis_ctx, SPMDAxisContext): + raise NotImplementedError("psend currently only supports manual sharding") + + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MANUAL + mlir.set_sharding(send_op, sharding) + return [send_op.results] + + +mlir.lowerable_effects.add_type(core.SingleSideCollectiveEffect) + + +def _psend_abstract_eval(x, *, axis_name, **params): + _check_axis_names(axis_name) + return abstract_token, { + *map(core.NamedAxisEffect, axis_name), + core.SingleSideCollectiveEffect(), + } + + +psend_p = core.Primitive("psend") +psend_p.def_impl(partial(dispatch.apply_primitive, psend_p)) +psend_p.def_effectful_abstract_eval(_psend_abstract_eval) +mlir.register_lowering(psend_p, _psend_lowering_gpu, platform="gpu") + +def _psend_lowering(ctx, x, *, axis_name, perm): + raise NotImplementedError("psend is currently only implemented on GPU") +mlir.register_lowering(psend_p, _psend_lowering) + +batching.fancy_primitive_batchers[psend_p] = _ppermute_batcher +batching.skippable_batchers[psend_p] = partial(_names_in_param, "axis_name") + + +def _precv_lowering_gpu(ctx, token, *, out_shape, axis_name, perm): + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="precv" + ) + recv_op = hlo.RecvOp( + [mlir.aval_to_ir_type(out_shape), token.type], + token, + source_target_pairs=mlir.dense_int_elements(full_perm), + **other_args, + ) + axis_ctx = ctx.module_context.axis_context + if not isinstance(axis_ctx, SPMDAxisContext): + raise NotImplementedError("precv currently only supports manual sharding") + + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MANUAL + mlir.set_sharding(recv_op, sharding) + + # recv_op should return an array of [RankedTensorType, StableHlo.token]; we + # only need the tensor. 
+ results = recv_op.results + return [results[0]] + + +def _precv_abstract_eval( + token, *, out_shape, axis_name, **params +): + return out_shape, {*map(core.NamedAxisEffect, axis_name), + core.SingleSideCollectiveEffect()} + +precv_p = core.Primitive("precv") +precv_p.multiple_results = False +precv_p.def_effectful_abstract_eval(_precv_abstract_eval) +mlir.register_lowering(precv_p, _precv_lowering_gpu, platform='gpu') + +def _precv_lowering(ctx, token, *, out_shape, axis_name, perm): + raise NotImplementedError("precv is currently only implemented on GPU") +mlir.register_lowering(precv_p, _precv_lowering) + +batching.fancy_primitive_batchers[precv_p] = _ppermute_batcher +batching.skippable_batchers[precv_p] = partial(_names_in_param, "axis_name") + def _pbroadcast_transpose_rule(t, x, source, axis_name): is_source = axis_index(axis_name) == source tsum = psum(t, axis_name) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 3c34a26af982..9c4437f2bf1f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1526,6 +1526,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "pmax_p", "pmin", "ppermute", + "psend", + "precv", "psum", "psum2", "pbroadcast", diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index c6df458ba91d..2cec803172b6 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -370,6 +370,8 @@ pmin_p as pmin_p, ppermute as ppermute, ppermute_p as ppermute_p, + psend as psend, + precv as precv, pshuffle as pshuffle, psum as psum, psum_p as psum_p, diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f3f5641be1b6..405696f0793f 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -225,6 +225,97 @@ def fwd(a): c = fwd(a) self.assertAllClose(c[1, :], a[0, :]) + @jtu.run_on_devices("gpu") + def test_psend_precv_basic_with_no_deadlock_cycle(self): + mesh = jtu.create_mesh((8,), 'x') + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None))) + weights = jax.random.uniform( + key=jax.random.key(0), shape=(8, 1), dtype=jnp.float32) + + @jax.jit + @partial( + jax.shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + + # We define the "forward edge" to be the device-to-device communication + # originating from device 0 in increasing indices. + fwd_token = jax.lax.psend( + a, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + data = jax.lax.precv( + fwd_token, + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + # Here we use an optimization barrier to enforce an arbitrary ordering of + # collectives. This will make sure compute happens after recv on the forward + # edge, and by extension will make sure the send on the back edge happens + # after the recv on the forward edge. Without this optimization barrier, the + # send on the backward edge might slip before the forward edge recv ops are + # completed, and will cause a deadlock. 
+ weights_, _ = ( + jax.lax.optimization_barrier( + (weights, data) + ) + ) + res = jnp.dot(weights_, data) + + # send the compute result back to the first device + bwd_token = jax.lax.psend( + res, + axis_name="x", + perm=[(7, 0)], + ) + + bwd_data = jax.lax.precv( + bwd_token, + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(7, 0)] + ) + return bwd_data + + c = fwd(a) + self.assertEqual(c.shape, a.shape) + + @jtu.run_on_devices("gpu") + def test_psend_precv_reverse(self): + mesh = jtu.create_mesh((8,), 'x') + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None))) + @jax.jit + @partial( + jax.shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + dummy_data = jax.lax.precv( + jax.lax.create_token(), + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + _ = jax.lax.psend( + dummy_data, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + return dummy_data + + c = fwd(a) + self.assertAllClose(c, jnp.zeros_like(a)) + def test_collective_permute_with_multiple_axis_names(self): mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) a = jax.device_put( From c2cc9f9cc9fd559f85195bace3c511f435070709 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 17 Jun 2025 11:02:10 -0700 Subject: [PATCH 1723/1769] [pallas] `AbstractMemoryRef` now implements all functional update methods via `update` This frees subclasses from overridding these methods. PiperOrigin-RevId: 772542580 --- jax/_src/pallas/core.py | 9 +++----- jax/_src/pallas/mosaic_gpu/core.py | 36 ++++++++++-------------------- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index bf92a6cc6a5a..36378b8d1cc1 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -236,12 +236,10 @@ def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' def update_weak_type(self, weak_type): - return AbstractMemoryRef( - self.inner_aval.update_weak_type(weak_type), self.memory_space) + return self.update(inner_aval=self.inner_aval.update_weak_type(weak_type)) def update_vma(self, vma): - return AbstractMemoryRef( - self.inner_aval.update_vma(vma), self.memory_space) + return self.update(inner_aval=self.inner_aval.update_vma(vma)) def update(self, inner_aval=None, memory_space=None): inner_aval = self.inner_aval if inner_aval is None else inner_aval @@ -249,8 +247,7 @@ def update(self, inner_aval=None, memory_space=None): return AbstractMemoryRef(inner_aval, memory_space) def to_tangent_aval(self): - return AbstractMemoryRef( - self.inner_aval.to_tangent_aval(), self.memory_space) + return self.update(inner_aval=self.inner_aval.to_tangent_aval()) # TODO(dougalm, sharadmv): figure out how to avoid needing this def normalize(self): diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 9718a8dd749d..839d987a484a 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -342,9 +342,9 @@ def _setitem(self, tracer, index, value): del tracer, index, value # Unused. 
raise ValueError("Ref unions can't be assigned to.") - def update_vma(self, vma): - return AbstractRefUnion(self.inner_aval.update_vma(vma), self.refs, - self.memory_space) + def update(self, inner_aval=None, memory_space=None): + ref = super().update(inner_aval, memory_space) + return AbstractRefUnion(ref.inner_aval, self.refs, self.memory_space) @dataclasses.dataclass(init=False, frozen=True) @@ -922,14 +922,12 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): def __repr__(self) -> str: return f'Accumulator{{{self.inner_aval.str_short()}}}' - def update_weak_type(self, weak_type): - return _as_accum(super().update_weak_type(weak_type)) - - def update_vma(self, vma): - return _as_accum(super().update_vma(vma)) - def update(self, inner_aval=None, memory_space=None): - return _as_accum(super().update(inner_aval=None, memory_space=None)) + ref = super().update(inner_aval, memory_space) + return WGMMAAbstractAccumulatorRef( + inner_aval=ref.inner_aval, + memory_space=ref.memory_space, + ) def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error @@ -941,12 +939,6 @@ def _getitem(self, tracer, idx): return arr -def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: - return WGMMAAbstractAccumulatorRef( - inner_aval=ref.inner_aval, - memory_space=ref.memory_space, # pytype: disable=attribute-error - ) - class AbstractTMEMRef(AbstractMemoryRef): __slots__ = ["inner_aval", "memory_space", "packed", "collective"] @@ -958,15 +950,11 @@ def __init__(self, inner_aval, memory_space, packed, collective): def __repr__(self) -> str: return f'TMEM({self.inner_aval.str_short()},packed={self.packed})' - def update_vma(self, vma): - return AbstractTMEMRef( - self.inner_aval.update_vma(vma), self.memory_space, self.packed, - self.collective) - - def update_weak_type(self, weak_type): + def update(self, inner_aval=None, memory_space=None): + ref = super().update(inner_aval, memory_space) return AbstractTMEMRef( - self.inner_aval.update_weak_type(weak_type), self.memory_space, - self.packed, self.collective) + ref.inner_aval, ref.memory_space, self.packed, self.collective + ) _WARPGROUP_AXIS_NAME = object() From 02688e18fc48e1cb3f5572b2257b540dbbce28dd Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 17 Jun 2025 11:48:13 -0700 Subject: [PATCH 1724/1769] [Pallas][Mosaic GPU] Enable collective MMA from TMEM. 
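As a rough sketch, the newly-enabled pattern in Pallas looks like the
following (abridged from the test in this change; the refs, the barrier, and
the cluster axis name "x" are illustrative):

    lhs_tmem[...] = plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05)
    plgpu.commit_tmem()
    plgpu.tcgen05_mma(acc_tmem, lhs_tmem, b_smem, mma_barrier,
                      accumulate=False, collective_axis="x")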
PiperOrigin-RevId: 772560593 --- jax/experimental/mosaic/gpu/tcgen05.py | 4 +- tests/mosaic/gpu_test.py | 122 +++++++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 130 ++++++++++++++----------- 3 files changed, 196 insertions(+), 60 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index c0415db6b581..26872d89f99e 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -145,8 +145,8 @@ def mma( if isinstance(a, TMEMRef): m, k2 = a.shape element_type2 = a.dtype - if collective: - raise NotImplementedError("Collective not supported for TMEMRef") + if collective and n * num_cta == 512: + raise NotImplementedError("Collective MMA with N=512 is not supported") if a.layout != (expected_layout := _infer_tmem_layout(a.shape, packing=2)): raise ValueError( f"A layout mismatch: expected {expected_layout}, got {a.layout}" diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 7dae5c5f37b8..5c8f6bc3c558 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1518,6 +1518,128 @@ def quantize(x): atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 np.testing.assert_allclose(z, ref, atol=atol) + @parameterized.product( + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(256,), # TODO(apaszke): 64, 192, 256 + n=(128, 256,), # TODO(apaszke): 192, other non-power-of-2, 512 + k_steps=(2,), # Note: reducing to 1 can be useful for debugging. + swizzle=(32, 64, 128,), + ) + def test_mma_collective_lhs_tmem( + self, + m, + n, + k_steps, + swizzle, + in_jax_dtype, + out_jax_dtype, + ): + if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: + raise self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + m_block_tile = m // 2 + n_block_tile = n // 2 + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + index = ir.IndexType.get() + + tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, out, scratch): + lhs_smem, rhs_smem, barriers, cluster_barrier, acc, lhs_tmem = scratch + block_id = gpu.cluster_block_id(gpu.Dimension.x) + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barriers[0], + collective=gpu.Dimension.x, + partitioned=0, # Split non-contracting dim. + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barriers[1], + collective=gpu.Dimension.x, + partitioned=1, # Split non-contracting dim. + ) + + is_leader_thread = single_thread_predicate() + is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index)) + + with when(arith.andi(is_first_block, is_leader_thread)): + barriers[0].wait() + gpu.barrier() + # Because only block 1 waits on the TMA, we need a cluster barrier so + # that the SMEM updates are visible on block 2. + cluster_barrier.arrive() + cluster_barrier.wait() + lhs_tmem.store( + fa.FragmentedArray.load_tiled( + lhs_smem, swizzle, layout=tcgen05.LAYOUT + ) + ) + tcgen05.commit_tmem() + # Make sure TMEM has been loaded on both blocks. 
+ cluster_barrier.arrive() + cluster_barrier.wait() + with when(arith.andi(is_first_block, is_leader_thread)): + barriers[1].wait() + tcgen05.mma( + acc, + lhs_tmem, + rhs_smem, + a_swizzle=swizzle, + b_swizzle=swizzle, + accumulate=False, + collective=True, + ) + tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) + barriers[2].wait(for_tensor_core=True) + m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) + + in_finfo = jnp.finfo(in_jax_dtype) + exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant + + def quantize(x): + # Quantize the input to avoid rounding when feeding the TensorCore + return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + + x_shape = (m, k) + x_block_shape = (m_block_tile, k) + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + y_shape = (k, n) + y_block_shape = (k, n_block_tile) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), + mgpu.TMABarrier(3), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,)), + mgpu.TMEM((128, n), out_jax_dtype, collective=True), + mgpu.TMEM((128, k), in_jax_dtype, collective=True, packing=2), + ] + z = mgpu.as_gpu_kernel( + kernel, + (2, 1, 1), + (128, 1, 1), + (x, y), + out_shape, + scratch_shape, + cluster=(2, 1, 1), + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = x32 @ y32 + atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol) + class BarrierTest(TestCase): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 1857068787ba..0a9dc3b5499b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2666,19 +2666,21 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) - @parameterized.parameters( - ((256, 256), (256, 256), 128, jnp.float16), - # Test additional shape combinations. - ((256, 128), (128, 128), 128, jnp.float16), - ((256, 64), (64, 256), 128, jnp.float16), - # Test bfloat16. - ((256, 256), (256, 256), 128, jnp.bfloat16), - # Test additional swizzles. - ((256, 256), (256, 256), 64, jnp.float16), - ((256, 256), (256, 256), 32, jnp.float16), + @parameterized.product( + m_n_k=[(256, 256, 256), (256, 128, 128), (256, 256, 64)], + swizzle=[128, 64, 32], + dtype=[jnp.float16, jnp.bfloat16], + lhs_tmem=[False, True], ) - def test_simple_collective_matmul(self, lhs_shape, rhs_shape, swizzle, dtype): + def test_simple_collective_matmul(self, m_n_k, swizzle, dtype, lhs_tmem): self.skip_if_wg_semantics() + m, n, k = m_n_k + full_lhs_shape = (m, k) + full_rhs_shape = (k, n) + full_acc_shape = (m, n) + block_acc_shape = (m // 2, n) + block_lhs_shape = (m // 2, k) + block_rhs_shape = (k, n // 2) # Test a collective (paired CTA) matmul on a single block. 
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize transforms = ( @@ -2686,57 +2688,69 @@ def test_simple_collective_matmul(self, lhs_shape, rhs_shape, swizzle, dtype): plgpu.SwizzleTransform(swizzle), ) - acc_shape = (lhs_shape[0], rhs_shape[1]) - _acc_shape = (lhs_shape[0] // 2, rhs_shape[1]) - _lhs_shape = (lhs_shape[0] // 2, lhs_shape[1]) - _rhs_shape = (rhs_shape[0], rhs_shape[1] // 2) - - def kernel(a_gmem, b_gmem, out_gmem): + def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, + acc_tmem, scratch_smem, tma_barrier, mma_barrier, + cluster_barrier, lhs_tmem_ref): cluster_idx = lax.axis_index("x") - slice_lhs = pl.ds(cluster_idx * _lhs_shape[0], _lhs_shape[0]) - slice_rhs = pl.ds(cluster_idx * _rhs_shape[1], _rhs_shape[1]) - - @functools.partial(pl.run_scoped, - a_smem=plgpu.SMEM(_lhs_shape, dtype, transforms=transforms), - b_smem=plgpu.SMEM(_rhs_shape, dtype, transforms=transforms), - acc_tmem=plgpu.TMEM(_acc_shape, jnp.float32, collective=True), - scratch_smem=plgpu.SMEM(_acc_shape, dtype, transforms=transforms), - tma_barrier=plgpu.Barrier(), - mma_barrier=plgpu.Barrier(for_tensor_core=True), - cluster_barrier=plgpu.ClusterBarrier(collective_axes=("x",)), + slice_lhs = pl.ds(cluster_idx * block_lhs_shape[0], block_lhs_shape[0]) + slice_rhs = pl.ds(cluster_idx * block_rhs_shape[1], block_rhs_shape[1]) + + plgpu.copy_gmem_to_smem(a_gmem.at[slice_lhs, :], a_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + plgpu.copy_gmem_to_smem(b_gmem.at[:, slice_rhs], b_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + + if lhs_tmem: + lhs_ref = lhs_tmem_ref + lhs_ref[...] = plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05) + plgpu.commit_tmem() + else: + lhs_ref = a_smem + + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) + + plgpu.tcgen05_mma( + acc_tmem, + lhs_ref, + b_smem, + mma_barrier, + accumulate=False, + collective_axis="x", ) - def _scoped(a_smem, b_smem, - acc_tmem, scratch_smem, tma_barrier, mma_barrier, cluster_barrier): - plgpu.copy_gmem_to_smem(a_gmem.at[slice_lhs, :], a_smem, tma_barrier) - plgpu.barrier_wait(tma_barrier) - plgpu.copy_gmem_to_smem(b_gmem.at[:, slice_rhs], b_smem, tma_barrier) - plgpu.barrier_wait(tma_barrier) - - plgpu.barrier_arrive(cluster_barrier) - plgpu.barrier_wait(cluster_barrier) - - plgpu.tcgen05_mma(acc_tmem, - a_smem, - b_smem, - mma_barrier, - accumulate=False, - collective_axis="x") - plgpu.barrier_wait(mma_barrier) - scratch_smem[...] = acc_tmem[...].astype(dtype) - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(scratch_smem, out_gmem.at[slice_lhs, :]) - plgpu.wait_smem_to_gmem(0) + plgpu.barrier_wait(mma_barrier) + scratch_smem[...] 
= acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_gmem.at[slice_lhs, :]) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.SMEM(block_lhs_shape, dtype, transforms=transforms), + plgpu.SMEM(block_rhs_shape, dtype, transforms=transforms), + plgpu.SMEM(block_acc_shape, dtype, transforms=transforms), + plgpu.TMEM(block_acc_shape, jnp.float32, collective=True), + plgpu.Barrier(), + plgpu.Barrier(for_tensor_core=True), + plgpu.ClusterBarrier(collective_axes=("x",)), + ] + if lhs_tmem: + scratch_shapes.append( + plgpu.TMEM(block_lhs_shape, dtype, collective=True, packed=True) + ) + else: + scratch_shapes.append(None) f = self.kernel( - kernel, - out_shape=jax.ShapeDtypeStruct(acc_shape, dtype), - grid=(1,), - grid_names=("_",), - cluster=(2,), - cluster_names=("x",), - ) - x = jax.random.uniform(jax.random.key(0), shape=lhs_shape, dtype=dtype) - y = jax.random.uniform(jax.random.key(1), shape=rhs_shape, dtype=dtype) + kernel, + out_shape=jax.ShapeDtypeStruct(full_acc_shape, dtype), + grid=(1,), + grid_names=("_",), + cluster=(2,), + cluster_names=("x",), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=full_lhs_shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=full_rhs_shape, dtype=dtype) result = f(x, y) expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) From e4de90e6d0eb630db1f6b530b5c55d9fa6123317 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 17 Jun 2025 12:03:50 -0700 Subject: [PATCH 1725/1769] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3d5ece64321630dade7ff733ae1353fc3c83d9cc. PiperOrigin-RevId: 772567112 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a585c499286d..dccf8d47a6cb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "54d83f35e69f82321308f1c479823521ce5cc4ed" -XLA_SHA256 = "e0c71f7f7ae2862ad6f105a8732cd6c4f1fa04fd9e98c0ab90a3460c924aa8cf" +XLA_COMMIT = "3d5ece64321630dade7ff733ae1353fc3c83d9cc" +XLA_SHA256 = "fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a" def repo(): tf_http_archive( From 8f81490ad4ed60cf923e6f0ef7a3bc1d708c7636 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 17 Jun 2025 12:16:53 -0700 Subject: [PATCH 1726/1769] Prepare for JAX release 0.6.2 --- jax/version.py | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/version.py b/jax/version.py index e15af7ab50fc..224b90db4e1f 100644 --- a/jax/version.py +++ b/jax/version.py @@ -152,7 +152,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.6.1" +_minimum_jaxlib_version = "0.6.2" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 71e0561b036c..b098bf1ab3b7 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.6.1' +_current_jaxlib_version = '0.6.2' # The following should be updated after each new jaxlib release. 
_latest_jaxlib_version_on_pypi = '0.6.1' -_libtpu_version = '0.0.15.*' +_libtpu_version = '0.0.17.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From 7dd13d701802d5f74932dd5e48359f26b456e136 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 17 Jun 2025 13:30:25 -0700 Subject: [PATCH 1727/1769] Rollback https://github.com/jax-ml/jax/pull/29410 due to downstream pytype failures. Reverts dc9ef6145bba53afbabdc7a5748c4afa1cd16025 PiperOrigin-RevId: 772598508 --- .pre-commit-config.yaml | 2 +- jax/BUILD | 1 - jax/nn/__init__.pyi | 92 ----------------------------------------- 3 files changed, 1 insertion(+), 94 deletions(-) delete mode 100644 jax/nn/__init__.pyi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 44ad912cf579..2312c88579d6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: hooks: - id: mypy files: (jax/|tests/typing_test\.py) - exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jax/nn/__init__.py|jaxlib/_jax/.* # Use pyi instead + exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jaxlib/_jax/.* # Use pyi instead additional_dependencies: [types-requests==2.31.0, numpy>=2.2.0] args: [--config=pyproject.toml] diff --git a/jax/BUILD b/jax/BUILD index b0d521ad1f44..6b1c5dc88499 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -349,7 +349,6 @@ py_library_providing_imports_info( lib_rule = pytype_library, pytype_srcs = glob( [ - "nn/*.pyi", "numpy/*.pyi", "_src/**/*.pyi", ], diff --git a/jax/nn/__init__.pyi b/jax/nn/__init__.pyi deleted file mode 100644 index 89a3298e96b2..000000000000 --- a/jax/nn/__init__.pyi +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Any, List, Literal, Union, overload, Sequence - -from jax._src.typing import ( - Array, ArrayLike, DTypeLike -) -from jax._src.core import AxisName -from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig -from jax._src.lax.lax import DotDimensionNumbers - -from jax.nn import initializers as initializers - -_Axis = Union[None, int, Sequence[int]] - - -def celu(x: ArrayLike, alpha: ArrayLike = ...) -> Array: ... -def dot_product_attention( - query: ArrayLike, - key: ArrayLike, - value: ArrayLike, - bias: ArrayLike | None = ..., - mask: ArrayLike | None = ..., - *, - scale: float | None = ..., - is_causal: bool = ..., - query_seq_lengths: ArrayLike | None = ..., - key_value_seq_lengths: ArrayLike | None = ..., - local_window_size: int | tuple[int, int] | None = ..., - implementation: Literal['xla', 'cudnn'] | None = ...) -> Array: ... -def elu(x: ArrayLike, alpha: ArrayLike = ...) -> Array: ... -def gelu(x: ArrayLike, approximate: bool = ...) -> Array: ... -def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'], - global_scale: Array | None = ...) -> BlockScaleConfig: ... -def glu(x: ArrayLike, axis: int = ...) -> Array: ... -def hard_sigmoid(x: ArrayLike) -> Array: ... -def hard_silu(x: ArrayLike) -> Array: ... -def hard_swish(x: ArrayLike) -> Array: ... -def hard_tanh(x: ArrayLike) -> Array: ... -def identity(x: ArrayLike) -> Array: ... -def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = ...) -> Array: ... -def log_sigmoid(x: ArrayLike) -> Array: ... -def log_softmax(x: ArrayLike, - axis: int | tuple[int, ...] | None = ..., - where: ArrayLike | None = ...) -> Array: ... -@overload -def logsumexp(a: ArrayLike, axis: _Axis = ..., b: ArrayLike | None = ..., - keepdims: bool = ..., return_sign: Literal[False] = ..., where: ArrayLike | None = ...) -> Array: ... 
- -@overload -def logsumexp(a: ArrayLike, axis: _Axis = ..., b: ArrayLike | None = ..., - keepdims: bool = ..., *, return_sign: Literal[True], where: ArrayLike | None = ...) -> tuple[Array, Array]: ... - -@overload -def logsumexp(a: ArrayLike, axis: _Axis = ..., b: ArrayLike | None = ..., - keepdims: bool = ..., return_sign: bool = ..., where: ArrayLike | None = ...) -> Array | tuple[Array, Array]: ... -def mish(x: ArrayLike) -> Array: ... -def one_hot(x: Any, num_classes: int, *, - dtype: Any = ..., axis: int | AxisName = ...) -> Array: ... -def relu(x: ArrayLike) -> Array: ... -def relu6(x: ArrayLike) -> Array: ... -def scaled_dot_general( - lhs: ArrayLike, rhs: ArrayLike, - dimension_numbers: DotDimensionNumbers, - preferred_element_type: DTypeLike = ..., - configs: List[BlockScaleConfig] | None = ..., - implementation: Literal['cudnn'] | None = ..., - ) -> Array: ... -def scaled_matmul( - lhs: Array, - rhs: Array, - lhs_scales: Array, - rhs_scales: Array, - preferred_element_type: DTypeLike = ..., -) -> Array: ... -def selu(x: ArrayLike) -> Array: ... -def sigmoid(x: ArrayLike) -> Array: ... -def silu(x: ArrayLike) -> Array: ... -def soft_sign(x: ArrayLike) -> Array: ... -def softmax(x: ArrayLike, - axis: int | tuple[int, ...] | None = ..., - where: ArrayLike | None = ...) -> Array: ... -def softplus(x: ArrayLike) -> Array: ... -def sparse_plus(x: ArrayLike) -> Array: ... -def sparse_sigmoid(x: ArrayLike) -> Array: ... -def squareplus(x: ArrayLike, b: ArrayLike = ...) -> Array: ... -def standardize(x: ArrayLike, - axis: int | tuple[int, ...] | None = ..., - mean: ArrayLike | None = ..., - variance: ArrayLike | None = ..., - epsilon: ArrayLike = ..., - where: ArrayLike | None = ...) -> Array: ... -def swish(x: ArrayLike) -> Array: ... -def tanh(x: ArrayLike, /) -> Array: ... From c944c6565d757bb690677c95be11efb1f07711ed Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Tue, 17 Jun 2025 13:54:20 -0700 Subject: [PATCH 1728/1769] Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline. These rules are similar to a `unary` op except that they only compute flops for the given axis. `cumlogsumexp` takes twice as many ops given the complexity of that function. PiperOrigin-RevId: 772608005 --- jax/experimental/roofline/rooflines.py | 46 ++++++++++++++++++++++++++ tests/roofline_test.py | 32 +++++++++++++++++- 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 2f3ce62a5744..79e36a4098ad 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -25,6 +25,7 @@ from jax._src import shard_map from jax._src.lax import ( ann, + control_flow, convolution, fft, lax, @@ -42,6 +43,7 @@ for prim in it.chain( ad_util.__dict__.values(), ann.__dict__.values(), + control_flow.__dict__.values(), convolution.__dict__.values(), fft.__dict__.values(), lax.__dict__.values(), @@ -148,6 +150,50 @@ def _binary_p_roofline( roofline.register_roofline(lax.min_p)(_binary_p_roofline) roofline.register_roofline(lax.max_p)(_binary_p_roofline) +def _cumulative_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis: int, + **kw, +) -> roofline.RooflineResult: + (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + return roofline.RooflineResult( + # `cum{max, min, prod, sum}` only calculate values for one axis. 
+      unfused_flops=x.shape[axis],
+      unfused_hbm_bytes=(
+          x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
+      ),
+  )
+
+
+roofline.register_roofline(control_flow.cummax_p)(_cumulative_p_roofline)
+roofline.register_roofline(control_flow.cummin_p)(_cumulative_p_roofline)
+roofline.register_roofline(control_flow.cumprod_p)(_cumulative_p_roofline)
+roofline.register_roofline(control_flow.cumsum_p)(_cumulative_p_roofline)
+
+
+@roofline.register_roofline(control_flow.cumlogsumexp_p)
+def _cumlogsumexp_p_roofline(
+    ctx: roofline.RooflineRuleContext,
+    *args,
+    axis: int,
+    **kw,
+) -> roofline.RooflineResult:
+  (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
+  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
+  return roofline.RooflineResult(
+      # Similar to `cum{max, min, prod, sum}`, `cumlogsumexp` only calculates
+      # values for one axis. But for `x.shape[axis] = S`, it computes (for a
+      # naive implementation):
+      #   S `exp` ops.
+      #   S-1 `add` ops.
+      #   1 log op.
+      # Thus, the total number of flops is 2 * S.
+      unfused_flops=x.shape[axis] * 2,
+      unfused_hbm_bytes=(
+          x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
+      ),
+  )
+
 
 @roofline.register_roofline(lax.dot_general_p)
 def _dot_general_roofline(
diff --git a/tests/roofline_test.py b/tests/roofline_test.py
index a234c0fe4fa8..054d528f5047 100644
--- a/tests/roofline_test.py
+++ b/tests/roofline_test.py
@@ -567,6 +567,37 @@ def test_no_mesh_and_no_specs(self):
     )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int))
     self.assertEqual(result.unfused_flops, 3 * 8)
 
+  @jtu.parameterized.product(
+      cumulative_function=[lax.cummax, lax.cummin, lax.cumprod, lax.cumsum],
+      axis=[0, 1, 2],
+  )
+  def test_cumulative_ops(self, cumulative_function, axis: int):
+    f = lambda x: cumulative_function(operand=x, axis=axis)
+    x = jnp.zeros((3, 8, 15), dtype=int)
+
+    _, result = roofline.roofline(f)(x)
+
+    self.assertEqual(result.unfused_flops, x.shape[axis])
+    self.assertEqual(
+        result.unfused_hbm_bytes, 2 * self._bytes_per_word * 3 * 8 * 15
+    )
+
+  @jtu.parameterized.named_parameters(
+      dict(testcase_name="axis_0", axis=0),
+      dict(testcase_name="axis_1", axis=1),
+      dict(testcase_name="axis_2", axis=2),
+  )
+  def test_cumlogsumexp_p_roofline(self, axis: int):
+    f = lambda x: lax.cumlogsumexp(operand=x, axis=axis)
+    x = jnp.zeros((3, 8, 15), dtype=int)
+
+    _, result = roofline.roofline(f)(x)
+
+    self.assertEqual(result.unfused_flops, 2 * x.shape[axis])
+    self.assertEqual(
+        result.unfused_hbm_bytes, 2 * self._bytes_per_word * 3 * 8 * 15
+    )
+
   def test_dot_general(self):
     _, result = roofline.roofline(lambda a, b: a @ b)(
         jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int)
     )
@@ -734,7 +765,6 @@ def test_conv_general_dilated_padding_string_valid(self):
       result.unfused_flops, 2 * expected_output_size * 3 * 3
     )
 
-
   @jtu.parameterized.named_parameters(
       dict(
           testcase_name="padding",

From 19f34a06e54dbe10d140828112a373498533c0d0 Mon Sep 17 00:00:00 2001
From: Hyeontaek Lim
Date: Tue, 17 Jun 2025 16:08:21 -0700
Subject: [PATCH 1729/1769] [JAX] Remove sleeping from colocated Python
 execution tests

Using `time.sleep()` and elapsed time measurement for validating
sequential/concurrent executions seems unreliable in some test environments,
and they also lengthen the test time of successful cases.

This change replaces the time-based tests with synchronization-based tests so
that no time delay or elapsed time measurement is needed.
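Roughly, the new synchronization pattern looks like this (an abridged sketch
of the updated tests; `_testing_global_state` is a test-only attribute stashed
on the `colocated_python` module, here holding a `threading.Barrier`):

    @colocated_python.colocated_python
    def func(x: jax.Array) -> jax.Array:
      # Every concurrent execution must reach the barrier, or the wait times
      # out and the test fails instead of hanging.
      colocated_python._testing_global_state.wait(timeout=5)
      return x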
PiperOrigin-RevId: 772661473 --- tests/colocated_python_test.py | 121 ++++++++++++++++++++------------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index c326aed1c0b7..452824cb6d52 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -16,7 +16,6 @@ import struct import tempfile import threading -import time from absl.testing import absltest from absl.testing import parameterized @@ -234,75 +233,101 @@ def test_sequential_execution(self, on_main_thread: bool): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) - # Make sure that this input array is ready for use by the colocated Python - # function and does not disrupt elapsed time measurement. - jax.block_until_ready(x) @colocated_python.colocated_python - def sleep(x: jax.Array) -> jax.Array: - time.sleep(5) + def func0(x: jax.Array) -> jax.Array: + colocated_python._testing_global_state = 100 return x - # Specify out_specs_fn so that all executions are asynchronously dispatched. - sleep = sleep.specialize(out_specs_fn=lambda x: x) + @colocated_python.colocated_python + def func1(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + assert colocated_python._testing_global_state == 100 + colocated_python._testing_global_state += 1 + return x + + @colocated_python.colocated_python + def func2(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + assert colocated_python._testing_global_state == 101 + return x - def sleep_twice_and_wait(x: jax.Array) -> None: - _ = sleep(x) - jax.block_until_ready(sleep(x)) + @colocated_python.colocated_python + def cleanup(): + if "_testing_global_state" in colocated_python.__dict__: + del colocated_python._testing_global_state - start_time = time.time() + # Specify out_specs_fn so that their executions are asynchronously + # dispatched. + func0 = func0.specialize(out_specs_fn=lambda x: x) + func1 = func1.specialize(out_specs_fn=lambda x: x) + func2 = func2.specialize(out_specs_fn=lambda x: x) - # Two executions of `sleep` within `sleep_twice_and_wait` should run - # sequentially. - if on_main_thread: - sleep_twice_and_wait(x) - else: - t = threading.Thread(target=sleep_twice_and_wait, args=(x,)) - t.start() - t.join() + # cleanup needs specialization because they do not have input arguments. + cleanup = cleanup.specialize(devices=cpu_devices[:1]) - elapsed_time = time.time() - start_time + def calls(x: jax.Array) -> None: + # No explicit blocking before making the next call. + func0(x) + func1(x) + jax.block_until_ready(func2(x)) - # If sequential execution did not happen, elapsed time typically will be - # around 5 seconds. - self.assertGreaterEqual(elapsed_time, 10) + try: + # Executions in `calls` should run sequentially. + if on_main_thread: + calls(x) + else: + t = threading.Thread(target=calls, args=(x,)) + t.start() + t.join() + # Executions should succeed without an error. + finally: + cleanup() def test_concurrent_execution(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) - # Make sure that this input array is ready for use by the colocated Python - # function and does not disrupt elapsed time measurement. 
- jax.block_until_ready(x) @colocated_python.colocated_python - def sleep(x: jax.Array) -> jax.Array: - time.sleep(5) + def init(x: jax.Array) -> jax.Array: + colocated_python._testing_global_state = threading.Barrier(3) return x - # Specify out_specs_fn so that all executions are asynchronously dispatched. - sleep = sleep.specialize(out_specs_fn=lambda x: x) - - def sleep_and_wait(x: jax.Array) -> None: - jax.block_until_ready(sleep(x)) + @colocated_python.colocated_python + def func(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + colocated_python._testing_global_state.wait(timeout=5) + return x - start_time = time.time() + @colocated_python.colocated_python + def cleanup(): + if "_testing_global_state" in colocated_python.__dict__: + del colocated_python._testing_global_state - # All three executions of `sleep_and_wait` should run concurrently. - t1 = threading.Thread(target=sleep_and_wait, args=(x,)) - t2 = threading.Thread(target=sleep_and_wait, args=(x,)) - t1.start() - t2.start() - sleep_and_wait(x) - t1.join() - t2.join() + # Specify out_specs_fn so that their executions are asynchronously + # dispatched. + func = func.specialize(out_specs_fn=lambda x: x) - elapsed_time = time.time() - start_time + # cleanup needs specialization because they do not have input arguments. + cleanup = cleanup.specialize(devices=cpu_devices[:1]) - self.assertGreaterEqual(elapsed_time, 5) - # If concurrent execution did not happen, elapsed time typically will be - # around 15 seconds. - self.assertLess(elapsed_time, 10) + try: + jax.block_until_ready(init(x)) + + # All func calls should run concurrently and enter/exit the barrier. + t1 = threading.Thread(target=func, args=(x,)) + t2 = threading.Thread(target=func, args=(x,)) + t3 = threading.Thread(target=func, args=(x,)) + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() + # Executions should succeed without a deadlock. + finally: + cleanup() def test_inputs_with_different_device_orders(self): cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())[:2] From feab6f4b7b748b2e91753dd0fc1e23625cadb472 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 17 Jun 2025 13:00:52 -0700 Subject: [PATCH 1730/1769] Postrelease (0.6.2) changes --- CHANGELOG.md | 3 +++ jax/version.py | 2 +- setup.py | 2 +- third_party/xla/llvm_fix.patch | 32 ++++++++++++++++++++++++++++++++ third_party/xla/workspace.bzl | 1 + 5 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 third_party/xla/llvm_fix.patch diff --git a/CHANGELOG.md b/CHANGELOG.md index c43cc7fb9b4e..699439a82033 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased + +## JAX 0.6.2 (June 17, 2025) + * New features: * Added {func}`jax.tree.broadcast` which implements a pytree prefix broadcasting helper. diff --git a/jax/version.py b/jax/version.py index 224b90db4e1f..e2d70eccc54e 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.6.2" +_version = "0.6.3" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index b098bf1ab3b7..f377db415147 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ _current_jaxlib_version = '0.6.2' # The following should be updated after each new jaxlib release. 
-_latest_jaxlib_version_on_pypi = '0.6.1' +_latest_jaxlib_version_on_pypi = '0.6.2' _libtpu_version = '0.0.17.*' diff --git a/third_party/xla/llvm_fix.patch b/third_party/xla/llvm_fix.patch new file mode 100644 index 000000000000..4bf402095517 --- /dev/null +++ b/third_party/xla/llvm_fix.patch @@ -0,0 +1,32 @@ +diff --git a/third_party/llvm/llvm_jax_fix.patch b/third_party/llvm/llvm_jax_fix.patch +new file mode 100644 +index 0000000000..5a2a60205e +--- /dev/null ++++ b/third_party/llvm/llvm_jax_fix.patch +@@ -0,0 +1,14 @@ ++diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp ++index 96be91256915d..8bcd8670879a9 100644 ++--- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++++ b/llvm/lib/Target/X86/X86ISelLowering.cpp ++@@ -59383,7 +59383,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, ++ ++ // We can always convert per-lane vXf64 shuffles into VSHUFPD. ++ if (!IsSplat && ++- (VT == MVT::v4f64 || (VT == MVT::v8f64 && Subtarget.useAVX512Regs())) && +++ ((NumOps == 2 && VT == MVT::v4f64) || +++ (NumOps == 4 && VT == MVT::v8f64 && Subtarget.useAVX512Regs())) && ++ all_of(Ops, [](SDValue Op) { return Op.hasOneUse(); })) { ++ // Collect the individual per-lane v2f64/v4f64 shuffles. ++ MVT OpVT = Ops[0].getSimpleValueType(); +diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl +index ae0c1b550f..ce408f554a 100644 +--- a/third_party/llvm/workspace.bzl ++++ b/third_party/llvm/workspace.bzl +@@ -22,6 +22,7 @@ def repo(name): + "//third_party/llvm:mathextras.patch", + "//third_party/llvm:toolchains.patch", + "//third_party/llvm:zstd.patch", ++ "//third_party/llvm:llvm_jax_fix.patch", + ], + link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"}, + ) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index dccf8d47a6cb..76ea227200e2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -30,6 +30,7 @@ def repo(): sha256 = XLA_SHA256, strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), + patch_file = ["//third_party/xla:llvm_fix.patch"], ) # For development, one often wants to make changes to the TF repository as well From 34d88cc519b44176be8294628b271f141fc598b2 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Tue, 17 Jun 2025 18:30:25 -0700 Subject: [PATCH 1731/1769] Add `gather` to `roofline`. PiperOrigin-RevId: 772703382 --- jax/experimental/roofline/rooflines.py | 30 ++++++++++++-- tests/roofline_test.py | 56 ++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 79e36a4098ad..278c9d3d39ff 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -500,11 +500,33 @@ def _ring_collective_roofline( ) +@roofline.register_roofline(slicing.gather_p) +def _gather_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + _, indices = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + + # Gather doesn't read the whole input buffer, it's equivalent to a copy the + # size of the output shape and a read of the gather indices. 
+ bytes = ( + out.dtype.itemsize * out.size * 2 + indices.dtype.itemsize * indices.size + ) + + return roofline.RooflineResult( + # Gather does not issue any flops. + unfused_flops=0, + unfused_hbm_bytes=bytes, + ) + + def _scalar_collective_roofline( - ctx: roofline.RooflineRuleContext, - *args, - axes: tuple[str, ...], - **kw, + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, ) -> roofline.RooflineResult: shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in] ctx = replace(ctx, avals_in=[core.ShapedArray((1,), shape.dtype) for shape in shapes]) diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 054d528f5047..497b3f14958e 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -884,6 +884,62 @@ def with_neg(f): custom_result.unfused_hbm_bytes, base_result.unfused_hbm_bytes ) + def test_gather_roofline(self): + operand = jnp.zeros((3, 3), dtype=jnp.int32) + indices = jnp.zeros((2, 1), dtype=jnp.int32) + + dimension_numbers = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), + collapsed_slice_dims=(0,), + start_index_map=(0,), + ) + + f = lambda x, y: jax.lax.gather( + x, + y, + dimension_numbers=dimension_numbers, + slice_sizes=(1, 3), + ) + + _, result = roofline.roofline(f)(operand, indices) + + self.assertEqual(result.unfused_flops, 0) + # Expected bytes: + # operand: 2 * 3 * sizeof(int32) = 24 + # indices: 2 * 1 * sizeof(int32) = 8 + # output: 2 * 3 * sizeof(int32) = 24 + # total = 56 + self.assertEqual(result.unfused_hbm_bytes, 56) + + def test_gather_batching_dims_roofline(self): + operand = jnp.zeros((5, 3, 3), dtype=jnp.int32) + indices = jnp.zeros((5, 1), dtype=jnp.int32) + + dimension_numbers = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), + collapsed_slice_dims=(1,), + start_index_map=(1,), + operand_batching_dims=(0,), + start_indices_batching_dims=(0,), + ) + + f = lambda x, y: jax.lax.gather( + x, + y, + dimension_numbers=dimension_numbers, + slice_sizes=(1, 1, 3), + ) + + _, result = roofline.roofline(f)(operand, indices) + + self.assertEqual(result.unfused_flops, 0) + # Expected bytes: + # operand: 5 * 3 * sizeof(int32) = 60 + # indices: 5 * 1 * sizeof(int32) = 20 + # output: 5 * 3 * sizeof(int32) = 60 + # total = 140 + self.assertEqual(result.unfused_hbm_bytes, 140) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 1d0774719ea89203c5c4cff923b90830f5ef67ae Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 17 Jun 2025 18:48:28 -0700 Subject: [PATCH 1732/1769] jaxlib_extension_version == 355 after 0.6.2 release. So remove the conditionals. 
PiperOrigin-RevId: 772708659 --- jax/_src/compiler.py | 6 +-- jax/_src/named_sharding.py | 5 +-- jax/_src/partition_spec.py | 86 +++++--------------------------------- jax/_src/tree_util.py | 3 -- jax/_src/util.py | 3 +- tests/array_test.py | 7 ---- tests/pgle_test.py | 3 -- tests/tree_util_test.py | 51 ++++------------------ 8 files changed, 23 insertions(+), 141 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index f501454b73be..e9312f560190 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -34,7 +34,6 @@ from jax._src import profiler from jax._src import traceback_util from jax._src.interpreters import mlir -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.lib import _jax from jax._src.lib.mlir import ir @@ -355,10 +354,7 @@ def backend_compile_and_load( # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results # TODO(dsuo): Simplify this logic once we delete _jax.CompileOnlyPyClient. - if jaxlib_extension_version < 345 or ( - jaxlib_extension_version >= 345 - and isinstance(backend, _jax.CompileOnlyPyClient) - ): + if isinstance(backend, _jax.CompileOnlyPyClient): if host_callbacks: return backend.compile( built_c, diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index bddf211cdc43..1b0ae46a968b 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -22,7 +22,6 @@ from jax._src import config from jax._src.util import use_cpp_class, cache, use_cpp_method -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib @@ -146,14 +145,14 @@ def __reduce__(self): def memory_kind(self) -> str | None: return self._memory_kind - @use_cpp_method(jaxlib_extension_version >= 353) + @use_cpp_method() def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( (self.mesh, self.memory_kind, self.spec, self._logical_device_ids)) return self._hash - @use_cpp_method(jaxlib_extension_version >= 353) + @use_cpp_method() def __eq__(self, other): if not isinstance(other, NamedSharding): return False diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 722253a8531c..5f743c9c141b 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -13,74 +13,14 @@ # limitations under the License. from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any -from jax._src.lib import jaxlib_extension_version from jax._src.lib import _jax -from jax._src.util import use_cpp_class, use_cpp_method, set_module - -export = set_module('jax.sharding') - -# TODO(phawkins): the union confuses pytype. Just use the Python branch for now -# until the C++ version is the minimum version. -if not TYPE_CHECKING and jaxlib_extension_version >= 352: - _UNCONSTRAINED_PARTITION = _jax.UNCONSTRAINED_PARTITION - _canonicalize_partition = _jax.canonicalize_partition -else: - class UnconstrainedSingleton: - - def __repr__(self): - return "UNCONSTRAINED" - - def __reduce__(self): - return (_get_default_unconstrained, ()) - - - # Unconstrained sentinel value for PartitionSpec, representing a dimension for - # which the user wants XLA to assign the best partitioning. - # TODO(yashkatariya): May rename to AUTO. 
- _UNCONSTRAINED_PARTITION = UnconstrainedSingleton() - - def _get_default_unconstrained(): - return _UNCONSTRAINED_PARTITION - - def _canonicalize_partition(partition): - if not partition: - return None - if partition is _UNCONSTRAINED_PARTITION: - return _UNCONSTRAINED_PARTITION - if isinstance(partition, (tuple, list)): - if len(partition) == 1: - return partition[0] - return tuple(partition) - return partition - - def _check(partitions, unreduced, reduced): - for p in partitions: - p = p if isinstance(p, tuple) else (p,) - for r in p: - if r in unreduced: - raise ValueError( - "partitions cannot overlap with unreduced axes passed to" - f" PartitionSpec. Got partitions: {partitions} and unreduced axes:" - f" {unreduced}") - if r in reduced: - raise ValueError( - "partitions cannot overlap with reduced axes passed to" - f" PartitionSpec. Got partitions: {partitions} and reduced axes:" - f" {reduced}") - if unreduced & reduced: - raise ValueError( - "`unreduced` and `reduced` argument to PartitionSpec cannot overlap. " - f"Got {unreduced=}, {reduced=}") - if None in unreduced: - raise ValueError( - "unreduced cannot contain None. All elements in unreduced should refer" - " to the mesh axes.") - if None in reduced: - raise ValueError( - "reduced cannot contain None. All elements in reduced should refer" - " to the mesh axes.") +from jax._src.util import use_cpp_class, use_cpp_method + +_UNCONSTRAINED_PARTITION = _jax.UNCONSTRAINED_PARTITION +_canonicalize_partition = _jax.canonicalize_partition + def unpickle_pspec(partitions, unreduced, reduced): return PartitionSpec(*partitions, unreduced=unreduced, reduced=reduced) @@ -96,6 +36,7 @@ def _get_ur_str(unreduced, reduced): AxisName = Any +@use_cpp_class(_jax.PartitionSpec) class PartitionSpec: """Tuple describing how to partition an array across a mesh of devices. @@ -105,8 +46,6 @@ class PartitionSpec: This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ - if jaxlib_extension_version < 352: - __slots__ = ("_partitions", "unreduced", "reduced") __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. @@ -125,7 +64,8 @@ def __init__(self, *partitions, unreduced=frozenset(), reduced=frozenset()): f" `frozenset` or `set`. Got type {type(reduced)}") self.unreduced = frozenset(unreduced) self.reduced = frozenset(reduced) - _check(self._partitions, self.unreduced, self.reduced) + # `__init__` is implemented in C++ so this check happens in C++ + # _check(self._partitions, self.unreduced, self.reduced) def __repr__(self): pr = repr(self._partitions)[1:-1] @@ -222,10 +162,4 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: out.extend([None] * (ndim - len(out))) return self.update(partitions=out) -# TODO(phawkins): make this a decorator after the next jaxlib release. -if not TYPE_CHECKING and jaxlib_extension_version >= 352: - PartitionSpec = use_cpp_class(_jax.PartitionSpec)(PartitionSpec) - -# TODO(phawkins): make this a decorator after the next jaxlib release. 
-if not TYPE_CHECKING: - PartitionSpec = export(PartitionSpec) +PartitionSpec.__module__ = 'jax.sharding' diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index f0341f72c029..c57ab2109c56 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -27,7 +27,6 @@ from jax._src.lib import pytree from jax._src.util import safe_zip, set_module from jax._src.util import unzip2 -from jax._src.lib import jaxlib_extension_version export = set_module('jax.tree_util') @@ -1131,8 +1130,6 @@ def tree_flatten_with_path( is_leaf_takes_path: bool = False, ) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]: """Alias of :func:`jax.tree.flatten_with_path`.""" - if jaxlib_extension_version < 351: - return default_registry.flatten_with_path(tree, is_leaf) is_leaf_with_kp: Callable[[Any, Any], bool] | None = is_leaf if not is_leaf_takes_path and is_leaf is not None: is_leaf_with_kp = lambda _, x: is_leaf(x) diff --git a/jax/_src/util.py b/jax/_src/util.py index 595f73dbc466..1b9102af2d95 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -28,7 +28,6 @@ import numpy as np from jax._src import config -from jax._src.lib import jaxlib_extension_version from jax._src.lib import weakref_lru_cache as _weakref_lru_cache from jax._src.lib import utils as jaxlib_utils @@ -45,7 +44,7 @@ T3 = TypeVar("T3") -if TYPE_CHECKING or jaxlib_extension_version < 354: +if TYPE_CHECKING: # safe_zip cannot yet be fully annotated, so we use a strategy similar # to that used for builtins.zip in python/typeshed. This supports # return types matching input types for up to three arguments. diff --git a/tests/array_test.py b/tests/array_test.py index 0814d5888016..dd8ddb078443 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -28,7 +28,6 @@ from jax._src import op_shardings from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip @@ -1476,9 +1475,6 @@ def test_named_sharding_unreduced_error(self): NamedSharding(mesh, P('x', unreduced={'y', None})) def test_hlo_sharding_get_axis_sizes(self): - if jaxlib_extension_version < 343: - self.skipTest('Requires jaxlib_extension_version >= 343') - op = xc.OpSharding() op.type = xc.OpSharding.Type.OTHER op.tile_assignment_dimensions = [6, 35] @@ -1516,9 +1512,6 @@ def test_hlo_sharding_get_axis_sizes(self): ('3d_mesh_x_none_none', (2, 1, 1), P('x', None, None)), ) def test_gspmd_sharding_shardy_lowering(self, mesh_shape, pspec): - if jaxlib_extension_version < 344: - self.skipTest('Requires jaxlib_extension_version >= 344') - ndim = len(mesh_shape) mesh = jtu.create_mesh( mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z') diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 35764daa7793..e03e5127d023 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -30,7 +30,6 @@ from jax._src import pjit from jax._src import profiler from jax._src import test_util as jtu -from jax._src.lib import jaxlib_extension_version from jax.experimental import profiler as exp_profiler from jax.experimental.serialize_executable import ( deserialize_and_load, @@ -134,8 +133,6 @@ def get_fdo_profiles(self, dump_dir): return jit_f_fdo_profiles def testAutoPgle(self): - if jaxlib_extension_version < 354: - self.skipTest('Requires jaxlib_extension_version >= 354') mesh = jtu.create_mesh((2,), ('x',)) with tempfile.TemporaryDirectory() as dump_dir: diff --git a/tests/tree_util_test.py 
b/tests/tree_util_test.py index 49318e41fad3..a1e3ccbe265f 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -26,7 +26,6 @@ from jax import tree_util from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors -from jax._src.lib import jaxlib_extension_version import jax.numpy as jnp # Easier to read. @@ -790,14 +789,9 @@ def testKeyStr(self): def testTreeMapWithPathWithIsLeafArgument(self): x = ((1, 2), [3, 4, 5]) y = (([3], jnp.array(0)), ([0], 7, [5, 6])) - if jaxlib_extension_version < 351: - out = tree_util.tree_map_with_path( - lambda kp, *xs: (kp[0].idx, *xs), x, y, - is_leaf=lambda n: isinstance(n, list)) - else: - out = tree_util.tree_map_with_path( - lambda kp, *xs: (kp[0].idx, *xs), x, y, - is_leaf=lambda _, n: isinstance(n, list), is_leaf_takes_path=True) + out = tree_util.tree_map_with_path( + lambda kp, *xs: (kp[0].idx, *xs), x, y, + is_leaf=lambda _, n: isinstance(n, list), is_leaf_takes_path=True) self.assertEqual(out, (((0, 1, [3]), (0, 2, jnp.array(0))), (1, [3, 4, 5], ([0], 7, [5, 6])))) @@ -814,13 +808,11 @@ def is_empty(x): tree1 = {'a': 1, 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], 'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')} - if jaxlib_extension_version < 351: - flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty) - else: - is_empty_new = lambda kp, x: is_empty(x) - flattened, _ = tree_util.tree_flatten_with_path( - tree1, is_empty_new, is_leaf_takes_path=True - ) + + is_empty_new = lambda kp, x: is_empty(x) + flattened, _ = tree_util.tree_flatten_with_path( + tree1, is_empty_new, is_leaf_takes_path=True + ) strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] self.assertEqual( strs, @@ -835,8 +827,6 @@ def is_empty(x): ) def testTreeFlattenWithPathWithIsLeafWithPathArgument(self): - if jaxlib_extension_version < 351: - self.skipTest("Requires jaxlib version >= 351") x = ((1, 2), [3, {4: 4, 5: 5}]) check_max_depth = lambda kp, _: len(kp) >= 2 flattened, _ = tree_util.tree_flatten_with_path( @@ -853,8 +843,6 @@ def testTreeFlattenWithPathWithIsLeafWithPathArgument(self): ) def testTreeMapWithPathWithIsLeafWithPathArgument(self): - if jaxlib_extension_version < 351: - self.skipTest("Requires jaxlib version >= 351") x = ((1, 2), [3, 4, 5]) y = (([3], jnp.array(0)), ([0], 7, [5, 6])) out = tree_util.tree_map_with_path( @@ -1101,10 +1089,7 @@ class Tree: f = jax.jit(lambda x: x) - if jax._src.lib.jaxlib_extension_version < 346: - msg = "The truth value of an array with more than one element is ambiguous." - else: - msg = "Exception raised while checking equality of metadata fields of pytree." + msg = "Exception raised while checking equality of metadata fields of pytree." # First call succeeds, because there is no equality check. 
f(Tree(jnp.arange(4))) @@ -1564,12 +1549,6 @@ def test_tree_flatten_with_path(self): def test_tree_flatten_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) - if jaxlib_extension_version < 351: - self.assertEqual( - jax.tree.flatten_with_path(obj, is_leaf=is_leaf), - tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf), - ) - return is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( jax.tree.flatten_with_path(obj, is_leaf, is_leaf_takes_path=True), @@ -1586,12 +1565,6 @@ def test_tree_leaves_with_path(self): def test_tree_leaves_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) - if jaxlib_extension_version < 351: - self.assertEqual( - jax.tree.leaves_with_path(obj, is_leaf=is_leaf), - tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf), - ) - return is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( jax.tree.leaves_with_path( @@ -1616,12 +1589,6 @@ def test_tree_map_with_path_is_leaf(self): obj = [1, 2, (3, 4)] obj2 = [5, 6, (7, 8)] is_leaf = lambda x: isinstance(x, tuple) - if jaxlib_extension_version < 351: - self.assertEqual( - jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf), - tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf), - ) - return is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( jax.tree.map_with_path( From 366a7dfe5c5ec8f5f8b76460490dd13edf31c1e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 18 Jun 2025 02:08:33 -0700 Subject: [PATCH 1733/1769] [Mosaic:TPU] Byte-granularity dynamic gathers This expands tpu.dynamic_gather semantics to index over multiple dimensions that are collapsed into one. This allows us to express a quarter-sublane gather on a vreg shape of 8x128x4. 
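
For intuition, the collapsed-dimension semantics can be modeled in NumPy
roughly as follows (a reference sketch only, not part of this change; the
helper name and the shapes in the closing note are illustrative):

    import numpy as np

    def dynamic_gather_ref(source, indices, dimensions):
      # Collapse `dimensions` (in the given order) into a single trailing
      # axis; the remaining axes keep their relative order.
      kept = [d for d in range(source.ndim) if d not in dimensions]
      collapsed = np.transpose(source, kept + list(dimensions))
      collapsed = collapsed.reshape([source.shape[d] for d in kept] + [-1])
      m = collapsed.shape[-1]  # product of the collapsed dimension sizes
      out = np.empty(indices.shape, dtype=source.dtype)
      for idx in np.ndindex(*indices.shape):
        j = tuple(idx[k] for k in kept)  # output coordinates on kept axes
        out[idx] = collapsed[j + (int(indices[idx]) % m,)]  # OOB indices wrap
      return out

With dimensions=[2, 0] on an 8x128x4 source, the trailing axis has size 32
and an index decomposes as byte * 8 + sublane, which is the quarter-sublane
gather mentioned above.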
PiperOrigin-RevId: 772825343 --- jax/_src/pallas/mosaic/lowering.py | 12 ++- jax/_src/tpu_custom_call.py | 6 +- jaxlib/mosaic/dialect/tpu/tpu.td | 18 ++-- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 12 ++- .../tpu/transforms/apply_vector_layout.cc | 87 +++++++++++++++++-- .../tpu/transforms/infer_vector_layout.cc | 25 +++--- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 42 ++++++++- 7 files changed, 167 insertions(+), 35 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 22a07537e508..96fddba95341 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2378,7 +2378,11 @@ def _gather_lowering_rule( operand_batching_dims=(1,), start_indices_batching_dims=(1,), ): - return tpu.dynamic_gather(x, recovered_indices, 0) + if jaxlib_version < (0, 6, 3): + # TODO: b/423649694 - Remove on 2025-07-18 + return tpu.dynamic_gather(x, recovered_indices, 0) + else: + return tpu.dynamic_gather(x, recovered_indices, [0]) if dimension_numbers == lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(1,), @@ -2386,7 +2390,11 @@ def _gather_lowering_rule( operand_batching_dims=(0,), start_indices_batching_dims=(0,), ): - return tpu.dynamic_gather(x, recovered_indices, 1) + if jaxlib_version < (0, 6, 3): + # TODO: b/423649694 - Remove on 2025-07-18 + return tpu.dynamic_gather(x, recovered_indices, 1) + else: + return tpu.dynamic_gather(x, recovered_indices, [1]) raise NotImplementedError("Unsupported gather") diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 8140a57f6cce..f3f05c1b251f 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -63,9 +63,9 @@ # # We should also add a TODO to remove the conditional one month later. def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: - # TODO(jevinjiang): remove the forward compatibility check after 2025-05-05. - if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 4, 5): - return 3 + # TODO: b/423649694 - remove the forward compatibility check after 2025-07-18 + if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 6, 18): + return 4 return None diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 766900cd07e4..67a3abf7d2be 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -470,20 +470,26 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, SameOperandsAndResultS let description = [{ Gathers elements from `source` using `indices`. - Given a shape `N0 x N1 x ...`, `output[i0, i1, ...]` is given by - `input[j0, j1, ...]` where `jn = indices[i0, i1, ...] mod Ni` for - `n = dimension` and `jn = in` otherwise. + The specified `dimensions` of `source` are collapsed together and indexed by + `indices`. - Similar to `np.take_along_axis`, except that OOB indices wrap. + Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by + `collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where + - `collapsed_source` is the result of collapsing `dimensions` of `source` + into a new trailing dimension of size `M`. + - `jk` is the subsequence of `in` for `n` not in `dimensions`. + + When a single dimension is specified, this is similar to + `np.take_along_axis`, except that OOB indices wrap. 
  }];
  let arguments = (ins
    AnyVectorOfNonZeroRank:$source,
    VectorOfNonZeroRankOf<[AnyInteger]>:$indices,
-    I32Attr:$dimension
+    DenseI32ArrayAttr:$dimensions
  );
  let results = (outs AnyVectorOfNonZeroRank:$output);
  let assemblyFormat = [{
-    $source `[` $indices `]` `in` $dimension attr-dict
+    $source `[` $indices `]` `in` $dimensions attr-dict
    `:` type($source) `,` type($indices) `->` type($output)
  }];
  let hasVerifier = 1;
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index 3733bf5d4465..a6d10e68e509 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -1449,8 +1449,16 @@ LogicalResult DynamicGatherOp::verify() {
   if (getIndices().getType().getShape() != getOutput().getType().getShape()) {
     return emitOpError("Expected indices and result shapes to match");
   }
-  if (!getIndices().getType().getElementType().isInteger(32)) {
-    return emitOpError("Not implemented: Only i32 indices supported");
+  const int64_t rank = getSource().getType().getRank();
+  SmallVector<bool> seen(rank, false);
+  for (int32_t d : getDimensions()) {
+    if (d < 0 || d >= rank) {
+      return emitOpError("Dimensions must be in [0, rank), but got ") << d;
+    }
+    if (seen[d]) {
+      return emitOpError("Dimensions must be unique");
+    }
+    seen[d] = true;
   }
   return success();
 }
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 0c0103757ea1..c3b2e8cc7f38 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -3146,11 +3146,18 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
   OpBuilder builder(&op);
   auto dy_gather_op = cast<tpu::DynamicGatherOp>(op);
-  // TODO(jevinjiang): we need to think harder for general vector shape.
-  if (dy_gather_op.getType().getShape() !=
-      ArrayRef<int64_t>(ctx.target_shape)) {
+  // TODO: b/423658138 - we need to think harder for general vector shape.
+  const bool is_8bit_vreg =
+      dy_gather_op.getType().getElementTypeBitWidth() == 8 &&
+      dy_gather_op.getType().getShape() ==
+          ArrayRef<int64_t>{4 * ctx.target_shape[0], ctx.target_shape[1]};
+  const bool is_32bit_vreg =
+      dy_gather_op.getType().getElementTypeBitWidth() == 32 &&
+      dy_gather_op.getType().getShape() == ArrayRef<int64_t>(ctx.target_shape);
+  if (!is_32bit_vreg && !is_8bit_vreg) {
     return op.emitOpError(
-        "Not implemented: DynamicGatherOp only supports 32-bit VREG shape");
+        "Not implemented: DynamicGatherOp only supports 8- or 32-bit VREG "
+        "shape");
   }
 
   if (src_layout != out_layout || idx_layout != out_layout) {
@@ -3159,7 +3166,7 @@
         "result");
   }
 
-  if (!out_layout.hasNaturalTopology(ctx.target_shape)) {
+  if (!out_layout.hasNativeTiling(ctx.target_shape)) {
     return op.emitOpError(
         "Not implemented: unsupported layout for DynamicGatherOp");
   }
@@ -3177,11 +3184,75 @@
   TPU_ASSERT_EQ_OP(src_vregs.dimensions(), idx_vregs.dimensions());
   TPU_ASSERT_EQ_OP(src_vregs.num_elements(), 1);
 
+  Location loc = dy_gather_op.getLoc();
+  SmallVector<int32_t> dimensions(dy_gather_op.getDimensions());
+  if (dy_gather_op.getType().getElementTypeBitWidth() == 8) {
+    if (dy_gather_op.getDimensions() != ArrayRef<int32_t>{0}) {
+      return dy_gather_op.emitOpError(
+          "Not implemented: 8-bit dynamic gather only supported along "
+          "dimension 0");
+    }
+    // Vreg shape is 8x128x4, and lowering only supports dimensions == {2, 0},
+    // i.e. byte index is in the upper bits and sublane index in the lower
+    // bits. However, the input indices effectively have sublane index in the
+    // upper bits and byte index in the lower bits.
+    VectorType i32_vreg_ty =
+        getNativeVregType(builder.getI32Type(), ctx.target_shape);
+    VectorType i8_vreg_ty =
+        getNativeVregType(builder.getI8Type(), ctx.target_shape);
+    auto i8_const_vreg = [&](const int8_t value) {
+      return getFullVector(builder, loc, i8_vreg_ty,
+                           builder.getI8IntegerAttr(value));
+    };
+    idx_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
+      const int sublane_bits = llvm::Log2_64(ctx.target_shape[0]);
+      const int byte_bits = 2;
+      // This check ensures that e.g. when right shifting below, the bits from
+      // the higher bytes don't influence the indices of the lower bytes. Lets
+      // us mask just once.
+      const bool mask_once =
+          sublane_bits + byte_bits + std::max(byte_bits, sublane_bits) <= 8;
+      if (mask_once) {
+        // Zero out the high bits that specify neither the byte nor the
+        // sublane index (they might not be zero since op semantics allow
+        // wrapping).
+        Value mask = i8_const_vreg((1 << (byte_bits + sublane_bits)) - 1);
+        *v = builder.create<arith::AndIOp>(loc, mask, *v);
+      }
+      Value shifted_byte = *v;
+      if (!mask_once) {
+        Value mask = i8_const_vreg((1 << byte_bits) - 1);
+        shifted_byte = builder.create<arith::AndIOp>(loc, mask, shifted_byte);
+      }
+      shifted_byte =
+          builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, shifted_byte);
+      shifted_byte = builder.create<arith::ShLIOp>(
+          loc, shifted_byte,
+          getFullVector(builder, loc, i32_vreg_ty,
+                        builder.getI32IntegerAttr(sublane_bits)));
+      Value shifted_sublane = *v;
+      if (!mask_once) {
+        Value mask =
+            i8_const_vreg((1 << (byte_bits + sublane_bits)) - (1 << byte_bits));
+        shifted_sublane =
+            builder.create<arith::AndIOp>(loc, mask, shifted_sublane);
+      }
+      shifted_sublane =
+          builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, shifted_sublane);
+      shifted_sublane = builder.create<arith::ShRUIOp>(
+          loc, shifted_sublane,
+          getFullVector(builder, loc, i32_vreg_ty,
+                        builder.getI32IntegerAttr(byte_bits)));
+      *v = builder.create<arith::OrIOp>(loc, shifted_byte, shifted_sublane);
+      *v = builder.create<tpu::BitcastVregOp>(loc, i8_vreg_ty, *v);
+    });
+    dimensions = SmallVector<int32_t>{2, 0};
+  }
+
   xla::Array<Value> out_vregs(src_vregs.dimensions());
   out_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
-    *v = builder.create<tpu::DynamicGatherOp>(
-        op.getLoc(), src_vregs(idxs).getType(), src_vregs(idxs),
-        idx_vregs(idxs), dy_gather_op.getDimension());
+    *v = builder.create<tpu::DynamicGatherOp>(loc, src_vregs(idxs).getType(),
+                                              src_vregs(idxs), idx_vregs(idxs),
+                                              dimensions);
   });
 
   dy_gather_op.replaceAllUsesWith(
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index 7d279c5cb307..388adc421a09 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -967,22 +967,23 @@ class VectorLayoutInferer {
   }
 
   LogicalResult infer(tpu::DynamicGatherOp op) {
-    if (op.getType().getShape() != ArrayRef<int64_t>(target_shape_) &&
-        op.getType().getElementTypeBitWidth() != 32) {
-      return op.emitOpError(
-          "Not implemented: DynamicGatherOp only supports 32-bit VREG shape");
-    }
-    if (op.getDimension() != 0 && op.getDimension() != 1) {
-      return op.emitOpError(
-          "Not implemented: Only dimension 0 and 1 are supported");
-    }
     // TODO(jevinjiang): we could preserve some offsets such as replicated
     // offset but since we are forcing all operands and result to be the same
     // layout, we can set all offsets to zero for now. Also maybe we should
     // consider adding this to elementwise rule.
-    auto layout = VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_,
-                               ImplicitDim::kNone);
-    setLayout(op, {layout, layout}, layout);
+    if (op.getType().getShape() == ArrayRef<int64_t>(target_shape_) &&
+        op.getType().getElementTypeBitWidth() == 32) {
+      VectorLayout layout(kNativeBitwidth, {0, 0}, default_tiling_,
+                          ImplicitDim::kNone);
+      setLayout(op, {layout, layout}, layout);
+    } else if (op.getIndices().getType().getShape() ==
+                   ArrayRef<int64_t>{4 * target_shape_[0], target_shape_[1]} &&
+               op.getType().getElementTypeBitWidth() == 8) {
+      VectorLayout layout(8, {0, 0}, nativeTiling(8), ImplicitDim::kNone);
+      setLayout(op, {layout, layout}, layout);
+    } else {
+      return op.emitOpError("Not implemented");
+    }
     return success();
   }
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc
index e08149fe44fc..1f1b97e205d8 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc
@@ -40,10 +40,45 @@ constexpr StringRef kMangledDialect = "stable_mosaic.";
 constexpr StringRef kVersionAttrName = "stable_mosaic.version";
 // When this is bumped, we should file a TODO to update the forward-compatible
 // version in tpu_custom_call.py in a month!
-constexpr int kVersion = 4;
+constexpr int kVersion = 5;
 
 using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;
 
+LogicalResult dynamic_gather_upgrade(Operation* op, int version) {
+  if (version < 5) {
+    auto dimension_attr = op->getAttrOfType<IntegerAttr>("dimension");
+    if (!dimension_attr || dimension_attr.getValue().getBitWidth() != 32) {
+      return op->emitError("Missing or invalid dimension attribute");
+    }
+    const int32_t dimension = dimension_attr.getInt();
+    op->removeAttr("dimension");
+    op->setAttr("dimensions",
+                DenseI32ArrayAttr::get(op->getContext(), {dimension}));
+  }
+  return success();
+}
+
+LogicalResult dynamic_gather_downgrade(Operation* op, int version) {
+  if (version < 5) {
+    auto dimensions_attr = op->getAttrOfType<DenseI32ArrayAttr>("dimensions");
+    if (!dimensions_attr) {
+      return op->emitError("Missing or invalid dimensions attribute");
+    }
+    const ArrayRef<int32_t> dimensions = dimensions_attr.asArrayRef();
+    if (dimensions.size() != 1) {
+      return op->emitError(
+          "Can only downgrade below version 5 when a single dimension is "
+          "specified.");
+    }
+    const int32_t dimension = dimensions.front();
+    op->removeAttr("dimensions");
+    op->setAttr("dimension",
+                mlir::IntegerAttr::get(
+                    mlir::IntegerType::get(op->getContext(), 32), dimension));
+  }
+  return success();
+}
+
 LogicalResult enqueue_dma_upgrade(Operation* op, int version) {
   // Added AttrSizedOperandSegments and core_id in version 2.
  if (version < 2) {
@@ -154,15 +189,18 @@ LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) {
 
 const llvm::StringMap<SerdeRuleType>& upgrade_rules() {
   static auto rules = new llvm::StringMap<SerdeRuleType>{
       {EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade},
+      {DynamicGatherOp::getOperationName(), dynamic_gather_upgrade},
       {SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade},
       {vector::MultiDimReductionOp::getOperationName(),
-       vector_multi_dim_reduce_upgrade}};
+       vector_multi_dim_reduce_upgrade},
+  };
   return *rules;
 }
 
 const llvm::StringMap<SerdeRuleType>& downgrade_rules() {
   static auto rules = new llvm::StringMap<SerdeRuleType>{
       {EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade},
+      {DynamicGatherOp::getOperationName(), dynamic_gather_downgrade},
       {SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade},
       {vector::MultiDimReductionOp::getOperationName(),
        vector_multi_dim_reduce_downgrade}};

From 7c432e951ee4b42caf0283f0e1fe2f389f9ada3a Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Wed, 18 Jun 2025 04:57:02 -0700
Subject: [PATCH 1734/1769] [mosaic] Added a `k` prefix to `TPU_MemorySpace`
 members

PiperOrigin-RevId: 772870655
---
 jaxlib/mosaic/dialect/tpu/tpu.td                            | 5 ++---
 jaxlib/mosaic/dialect/tpu/transforms/communication.cc       | 2 +-
 jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 67a3abf7d2be..241f0a745928 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -157,9 +157,8 @@ def TPU_TiledLayoutAttr
 def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [
   I32EnumAttrCase<"kAny", 4294967295, "any">,
-  // TODO(apaszke): Rename to kXYZ in C++
-  I32EnumAttrCase<"vmem", 0, "vmem">,
-  I32EnumAttrCase<"smem", 1, "smem">,
+  I32EnumAttrCase<"kVmem", 0, "vmem">,
+  I32EnumAttrCase<"kSmem", 1, "smem">,
   I32EnumAttrCase<"kHbm", 2, "hbm">,
   I32EnumAttrCase<"kCmem", 3, "cmem">,
   I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc
index dfe42111916c..7798b0027369 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc
@@ -110,7 +110,7 @@ struct LogicalToPhysicalDeviceIdPass
     auto device_assignment_type = MemRefType::get(
         {total_devices}, IntegerType::get(func.getContext(), 32),
         TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}),
-        MemorySpaceAttr::get(func.getContext(), MemorySpace::smem));
+        MemorySpaceAttr::get(func.getContext(), MemorySpace::kSmem));
     if (failed(func.insertArgument(func.getNumArguments(),
                                    device_assignment_type, nullptr,
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
index b772c5c8a114..f96989c0fd95 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -245,7 +245,7 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
                                       semaphore_mem);
   }
   const Attribute vmem =
-      tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem);
+      tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::kVmem);
   const Attribute memory_space =
       memref.getMemorySpace() == nullptr ?
          vmem : memref.getMemorySpace();
   FAILUREOR_ASSIGN_OR_RETURN(

From 2ec9981233caf736a31c0e1c7ee9c165960c099c Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Wed, 18 Jun 2025 06:28:01 -0700
Subject: [PATCH 1735/1769] [Mosaic GPU] Rework the CUDA_ROOT detection once
 again

PiperOrigin-RevId: 772895063
---
 jax/experimental/mosaic/gpu/core.py |  3 ---
 jaxlib/mosaic/gpu/BUILD             |  6 ++++++
 jaxlib/mosaic/gpu/custom_call.cc    | 10 +++-------
 jaxlib/mosaic/gpu/library_paths.h   | 31 +++++++++++++++++++++++++++++
 4 files changed, 40 insertions(+), 10 deletions(-)
 create mode 100644 jaxlib/mosaic/gpu/library_paths.h

diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index f7a471e9b123..5c19d5ad0fb6 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -59,9 +59,6 @@
   os.environ["CUDA_ROOT"] = cuda_root

 PYTHON_RUNFILES = os.environ.get("PYTHON_RUNFILES")
-PTXAS_PATH = os.path.join(cuda_root, "bin/ptxas")
-NVDISASM_PATH = os.path.join(cuda_root, "bin/nvdisasm")
-
 # This tracks the latest Mosaic GPU IR version with a monthly delay.
 FWD_COMPAT_IR_VERSION = 1

diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index d2abea0048d6..e50ecfaa63ec 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -148,6 +148,7 @@ cc_library(
     name = "custom_call",
     srcs = ["custom_call.cc"],
     deps = [
+        ":library_paths",
         ":nvshmem",
         ":passes",
         ":target",
@@ -246,3 +247,8 @@ cc_binary(
         "@xla//xla/tsl/cuda:cudart",
     ],
 )
+
+cc_library(
+    name = "library_paths",
+    hdrs = ["library_paths.h"],
+)
diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index 01b7e015e461..39f9635b043b 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include
 #include
+#include "jaxlib/mosaic/gpu/library_paths.h"
 #include "absl/base/call_once.h"
 #include "absl/base/optimization.h"
 #include "absl/cleanup/cleanup.h"
@@ -48,7 +49,6 @@
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
-// Leave this comment here. Internal Google business.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/TargetSelect.h"
@@ -134,15 +134,11 @@ class TemporaryDirectory {
   std::string path;
 };

-const char *GetCUDARoot() {
-  return getenv("CUDA_ROOT");
-}
-
 absl::StatusOr<std::string> RunCUDATool(const char* tool,
                                         const std::vector<const char*>& args,
                                         bool stderr_to_stdout = true) {
   CHECK(!args.empty() && args.back() == nullptr);
-  const char* cuda_path_ptr = GetCUDARoot();
+  const char* cuda_path_ptr = mosaic::gpu::GetCUDARoot();
   if (!cuda_path_ptr)
     return absl::InternalError("Failed to get the CUDA toolkit path");
   std::string tool_path(cuda_path_ptr);
@@ -346,7 +342,7 @@ mlir::FailureOr GetPassPipeline(
       mlir::LLVM::registerDIScopeForLLVMFuncOpPass();
       return true;
     });
-  const char *cuda_root = GetCUDARoot();
+  const char *cuda_root = mosaic::gpu::GetCUDARoot();
   if (!cuda_root) {
     return mlir::failure();
   }
diff --git a/jaxlib/mosaic/gpu/library_paths.h b/jaxlib/mosaic/gpu/library_paths.h
new file mode 100644
index 000000000000..83d523ac3ccc
--- /dev/null
+++ b/jaxlib/mosaic/gpu/library_paths.h
@@ -0,0 +1,31 @@
+/* Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_
+#define JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_
+
+#include <cstdlib>
+
+namespace mosaic {
+namespace gpu {
+
+inline const char *GetCUDARoot() {
+  return getenv("CUDA_ROOT");
+}
+
+}  // namespace gpu
+}  // namespace mosaic
+
+#endif  // JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_

From b6575e19489ff209340efed957a64d02f25fc9f8 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Wed, 18 Jun 2025 06:33:56 -0700
Subject: [PATCH 1736/1769] [Mosaic GPU] Add support for s8 matmuls on
 Blackwell

This follows our WGMMA implementation.

PiperOrigin-RevId: 772896470
---
 jax/experimental/mosaic/gpu/tcgen05.py | 82 ++++++++++++++------------
 tests/mosaic/gpu_test.py               | 24 +++++++-
 2 files changed, 66 insertions(+), 40 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index 26872d89f99e..904aec493b3b 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -63,45 +63,40 @@ def create_instr_descriptor(
     transpose_a: bool = False,
     transpose_b: bool = False,
 ):
-  f32 = ir.F32Type.get()
-  bf16 = ir.BF16Type.get()
   f16 = ir.F16Type.get()
-  if acc_dtype not in {f32, f16}:
-    raise NotImplementedError("Only float32 and float16 accumulators supported")
-  if utils.bitwidth(input_dtype) == 16:
-    if input_dtype not in {f16, bf16}:
-      raise NotImplementedError(
-          "The only supported 16-bit input types are float16 and bfloat16, got"
-          f" {input_dtype}"
-      )
-    desc = 0
-    desc |= (acc_dtype == f32) << 4  # D dtype, bits 4-5
-    # Bit 6 is reserved
-    desc |= (input_dtype == bf16) << 7  # A dtype, bits 7-9
-    desc |= (input_dtype == bf16) << 10  # B dtype, bits 10-12
-    return _finish_instr_descriptor(desc, m, n, transpose_a, transpose_b)
-  elif utils.bitwidth(input_dtype) == 8:
-    desc = 0
-    desc |= (acc_dtype == f32) << 4  # D dtype, bits 4-5
-    # Bit 6 is reserved
-    if input_dtype == ir.Float8E4M3FNType.get():
-      input_dtype_enum = 0
-    elif input_dtype == ir.Float8E5M2Type.get():
-      input_dtype_enum = 1
-    else:
-      raise NotImplementedError(f"Unsupported input dtype: {input_dtype}")
-    desc |= input_dtype_enum << 7  # A dtype, bits 7-9
-    desc |= input_dtype_enum << 10  # B dtype, bits 10-12
-    return _finish_instr_descriptor(desc, m, n, transpose_a, transpose_b)
+  f32 = ir.F32Type.get()
+  i32 = ir.IntegerType.get_signless(32)
+
+  desc = 0
+  if acc_dtype == f16:
+    d_type_val = 0
+  elif acc_dtype == f32:
+    d_type_val = 1
+  elif acc_dtype == i32:
+    d_type_val = 2
+  else:
+    raise NotImplementedError(f"Unsupported accumulator dtype: {acc_dtype}")
+  desc |= (d_type_val << 4)  # D type, bits 4-5
+  # Bit 6 is reserved
+  if input_dtype == f16:
+    assert acc_dtype in {f16, f32}
+    ab_type_val = 0
+  elif input_dtype == ir.BF16Type.get():
+    assert acc_dtype == f32
+    ab_type_val = 1
+  elif input_dtype == ir.Float8E4M3FNType.get():
+    assert acc_dtype in {f16, f32}
+    ab_type_val = 0
+  elif input_dtype == ir.Float8E5M2Type.get():
+    assert acc_dtype in {f16, f32}
+    ab_type_val = 1
+  elif input_dtype == ir.IntegerType.get_signless(8):  # Only s8 for now.
+ assert acc_dtype == i32 + ab_type_val = 1 else: raise NotImplementedError(f"Unsupported input dtype: {input_dtype}") - - -def _finish_instr_descriptor( - desc: int, m: int, n: int, transpose_a: bool, transpose_b: bool, -): - # We ignore sparsity in bits 0-3 - # A, B and D types are set by the caller + desc |= (ab_type_val << 7) # A dtype, bits 7-9 + desc |= (ab_type_val << 10) # B dtype, bits 10-12 # We ignore negate bits 13-14 desc |= transpose_a << 15 # Transpose A desc |= transpose_b << 16 # Transpose B @@ -180,6 +175,7 @@ def mma( ) f32 = ir.F32Type.get() f16 = ir.F16Type.get() + s32 = ir.IntegerType.get_signless(32) if element_type == f32 or element_type == ir.BF16Type.get(): if d.dtype != f32: raise ValueError( @@ -195,6 +191,12 @@ def mma( f"MMA with element type {element_type} only supports accumulators of" f" type f32 or f16, but got: {d.dtype}" ) + elif element_type == ir.IntegerType.get_signless(8): + if d.dtype != s32: + raise ValueError( + "MMA with element type s8 only supports s32 accumulators, but got:" + f" {d.dtype}" + ) else: raise NotImplementedError(f"Unsupported element type: {element_type}") @@ -316,6 +318,8 @@ def _do_mma( kind = "f8f6f4" elif ir.Float8E4M3FNType.isinstance(element_type): kind = "f8f6f4" + elif ir.IntegerType.get_signless(8).isinstance(element_type): + kind = "i8" else: raise NotImplementedError(f"Unsupported input element type: {element_type}") @@ -680,7 +684,7 @@ def slice(self, *idxs): dtype=self.dtype, ) - def load(self, layout: fa.TiledLayout = LAYOUT): + def load(self, layout: fa.TiledLayout = LAYOUT, is_signed: bool | None = None): i32 = ir.IntegerType.get_signless(32) if self.shape[1] % 8: raise NotImplementedError @@ -740,7 +744,9 @@ def load(self, layout: fa.TiledLayout = LAYOUT): "TMEM loads can only produce results in the tcgen05 layouts" f" ({LAYOUT} and {TMEM_NATIVE_LAYOUT}), but got: {layout}" ) - return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) + return fa.FragmentedArray( + _registers=registers, _layout=layout, _is_signed=is_signed + ) def store(self, value): if self.shape[1] % 8: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 5c8f6bc3c558..8e217843409d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1225,7 +1225,26 @@ def kernel(ctx, input, output, scratch): n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 swizzle=(32, 64, 128,), ) - def test_mma_basic(self, **kwargs): + def test_mma_basic_float(self, **kwargs): + if kwargs["n"] * jnp.dtype(kwargs["in_jax_dtype"]).itemsize < kwargs["swizzle"]: + self.skipTest("swizzle too large for input") + self._basic_mma_test( + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. 
+ lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.int8,), + out_jax_dtype=(jnp.int32,), + m=(128,), # TODO(apaszke): 64, 192, 256 + n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 + swizzle=(32, 64, 128,), + ) + def test_mma_basic_int(self, **kwargs): if kwargs["n"] * jnp.dtype(kwargs["in_jax_dtype"]).itemsize < kwargs["swizzle"]: self.skipTest("swizzle too large for input") self._basic_mma_test( @@ -1315,7 +1334,8 @@ def kernel(ctx, lhs, rhs, out, scratch): ) tcgen05.commit_arrive(barriers[2]) barriers[2].wait(for_tensor_core=True) - acc.load().store_untiled(out, optimized=False) + is_signed = True if jnp.issubdtype(in_jax_dtype, jnp.integer) else None + acc.load(is_signed=is_signed).store_untiled(out, optimized=False) x_shape = (k, m) if lhs_transpose else (m, k) x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) From 1cd076fdbcf99cd4738d04b2c818d3b890fbec78 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Jun 2025 10:02:35 -0400 Subject: [PATCH 1737/1769] Drop Python 3.10 support. Per policy, we can drop this in July 2025. However, we've already made our release for June 2025 and we can drop it right now. --- .github/workflows/bazel_cpu_rbe.yml | 6 +- .github/workflows/bazel_cuda_non_rbe.yml | 2 +- .github/workflows/bazel_cuda_rbe.yml | 4 +- .github/workflows/build_artifacts.yml | 1 - .github/workflows/ci-build.yaml | 16 +- .github/workflows/cloud-tpu-ci-nightly.yml | 2 +- .github/workflows/cloud-tpu-ci-presubmit.yml | 4 +- .github/workflows/oldest_supported_numpy.yml | 4 +- .github/workflows/pytest_cuda.yml | 2 +- .github/workflows/pytest_tpu.yml | 2 +- .github/workflows/rocm-ci.yml | 2 +- .github/workflows/wheel_tests_continuous.yml | 14 +- .../workflows/wheel_tests_nightly_release.yml | 16 +- CHANGELOG.md | 4 + WORKSPACE | 1 - build/requirements_lock_3_10.txt | 787 ------------------ build/rocm/Dockerfile.ms | 2 +- build/rocm/ci_build.sh | 2 +- build/rocm/docker/Dockerfile.jax-ubu22 | 2 +- build/rocm/tools/build_wheels.py | 2 +- docs/contributing.md | 2 +- docs/developer.md | 1 - examples/ffi/CMakeLists.txt | 2 +- examples/ffi/pyproject.toml | 2 +- jax/_src/source_info_util.py | 25 +- jax/_src/traceback_util.py | 21 +- jax/errors.py | 17 +- jax_plugins/cuda/__init__.py | 5 +- jax_plugins/cuda/plugin_setup.py | 3 +- jax_plugins/rocm/plugin_setup.py | 5 +- jaxlib/_jax/__init__.pyi | 1 - jaxlib/pjit.cc | 7 - jaxlib/pmap_lib.cc | 7 - jaxlib/py_array.cc | 14 - jaxlib/setup.py | 3 +- jaxlib/traceback.cc | 49 -- setup.py | 3 +- tests/errors_test.py | 25 +- tests/typing_test.py | 6 +- 39 files changed, 82 insertions(+), 991 deletions(-) delete mode 100644 build/requirements_lock_3_10.txt diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 71c140464454..99071974bd00 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -36,18 +36,18 @@ jobs: # Begin Presubmit Naming Check - name modification requires internal check to be updated strategy: matrix: - python: ["3.10", "3.13"] + python: ["3.11", "3.13"] runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"] enable-x_64: [1, 0] exclude: # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have # coverage for one of each, we don't need to run both. 
- - python: "3.10" + - python: "3.11" enable-x_64: 1 - python: "3.13" enable-x_64: 0 # Only test a single Python version on Arm64 as we don't run the tests. - - python: "3.10" + - python: "3.11" runner: "linux-arm64-c4a-16" name: "Bazel CPU ${{ (contains(matrix.runner, 'linux-arm64') && 'build only' || 'tests') }} (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" # End Presubmit Naming Check github-cpu-presubmits diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index d30e1b56dab8..5168dc6d002e 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -69,7 +69,7 @@ jobs: arch=$(uname -m) # Get the major and minor version of Python. - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311 python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') echo "OS=${os}" >> $GITHUB_ENV diff --git a/.github/workflows/bazel_cuda_rbe.yml b/.github/workflows/bazel_cuda_rbe.yml index 2c57b35587fa..3aaf2a485e77 100644 --- a/.github/workflows/bazel_cuda_rbe.yml +++ b/.github/workflows/bazel_cuda_rbe.yml @@ -35,13 +35,13 @@ jobs: # Begin Presubmit Naming Check - name modification requires internal check to be updated strategy: matrix: - python: ["3.10", "3.13"] + python: ["3.11", "3.13"] runner: ["linux-x86-n2-16"] enable-x_64: [1, 0] exclude: # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have # coverage for one of each, we don't need to run both. - - python: "3.10" + - python: "3.11" enable-x_64: 1 - python: "3.13" enable-x_64: 0 diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index ece2237eeead..7459953c37a1 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -31,7 +31,6 @@ on: type: choice default: "3.12" options: - - "3.10" - "3.11" - "3.12" - "3.13" diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 86aaeffaa5fd..dbd51373a3ac 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,11 +1,5 @@ name: CI -# We test all supported Python versions as follows: -# - 3.10 : Documentation build -# - 3.10 : Part of Matrix with NumPy dispatch -# - 3.10 : Part of Matrix -# - 3.11 : Part of Matrix - on: # Trigger the workflow on push or pull request, # but only for the main branch @@ -52,8 +46,8 @@ jobs: matrix: # Test the oldest and newest supported Python versions here. include: - - name-prefix: "with 3.10" - python-version: "3.10" + - name-prefix: "with 3.11" + python-version: "3.11" enable-x64: 1 prng-upgrade: 1 num_generated_cases: 1 @@ -104,7 +98,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: ['3.10'] + python-version: ['3.12'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: @@ -136,7 +130,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: ['3.10'] + python-version: ['3.11'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: @@ -165,7 +159,7 @@ jobs: matrix: # Test the oldest supported Python version here. 
include: - - python-version: "3.10" + - python-version: "3.11" os: ubuntu-latest enable-x64: 0 num_generated_cases: 10 diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 1f096ce48e2d..5a97999c2b23 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -33,7 +33,7 @@ jobs: {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] - python-version: ["3.10"] + python-version: ["3.11"] # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints. exclude: - tpu: diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index fe1f2820b338..c6988f198675 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -42,7 +42,7 @@ jobs: with: runner: "linux-x86-n2-16" artifact: ${{ matrix.artifact }} - python: "3.10" + python: "3.11" clone_main_xla: 1 upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -57,7 +57,7 @@ jobs: runner: "linux-x86-ct5lp-224-8tpu" cores: "8" tpu-type: "v5e-8" - python: "3.10" + python: "3.11" libtpu-version-type: "nightly" gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }} halt-for-connection: ${{ inputs.halt-for-connection || false }} diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml index fbf881a84a9c..00bfd8100d27 100644 --- a/.github/workflows/oldest_supported_numpy.yml +++ b/.github/workflows/oldest_supported_numpy.yml @@ -25,11 +25,11 @@ jobs: runs-on: "linux-x86-n2-64" container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" # Begin Presubmit Naming Check - name modification requires internal check to be updated - name: "CI - Oldest Supported NumPy (Python 3.10, x64=0)" + name: "CI - Oldest Supported NumPy (Python 3.11, x64=0)" # End Presubmit Naming Check github-oldest-supported-numpy-presubmit env: - JAXCI_PYTHON: "python3.10" + JAXCI_PYTHON: "python3.11" JAXCI_ENABLE_X64: 0 JAX_NUM_GENERATED_CASES: 5 diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index ed021a970ecd..58336612041a 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -79,7 +79,7 @@ jobs: arch=$(uname -m) # Get the major and minor version of Python. - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311 # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 313bbede52f5..3bb88eef2e3b 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -88,7 +88,7 @@ jobs: arch=$(uname -m) # Get the major and minor version of Python. 
- # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311 # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 4bfb8cb50a5e..ab4016f301a8 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -16,7 +16,7 @@ jobs: env: BASE_IMAGE: "ubuntu:22.04" TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.11" ROCM_VERSION: "6.3.3" WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} steps: diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 91662ff51f3e..99caad6325a0 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -50,7 +50,7 @@ jobs: # Runner OS and Python values need to match the matrix stategy in the CPU tests job runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] artifact: ["jaxlib"] - python: ["3.10"] + python: ["3.11"] # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix # values to the name and creates a separate entry for each matrix combination. @@ -71,7 +71,7 @@ jobs: # Python values need to match the matrix stategy in the CUDA tests job below runner: ["linux-x86-n2-16"] artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] - python: ["3.10",] + python: ["3.11",] name: "Build ${{ format('{0}', 'CUDA') }} artifacts" with: runner: ${{ matrix.runner }} @@ -94,7 +94,7 @@ jobs: # Runner OS and Python values need to match the matrix stategy in the # build_jaxlib_artifact job above runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10",] + python: ["3.11",] enable-x64: [1, 0] name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})" with: @@ -116,7 +116,7 @@ jobs: # Python values need to match the matrix stategy in the artifact build jobs above # See exlusions for what is fully tested runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] - python: ["3.10",] + python: ["3.11",] cuda: [ {version: "12.1", use-nvidia-pip-wheels: false}, {version: "12.8", use-nvidia-pip-wheels: true}, @@ -152,7 +152,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: runner: ["linux-x86-n2-16", "linux-arm64-t2a-48"] - python: ["3.10",] + python: ["3.11",] enable-x64: [1, 0] name: "Bazel CPU tests with ${{ format('{0}', 'py_import') }}" with: @@ -172,7 +172,7 @@ jobs: matrix: # Python values need to match the matrix stategy in the build artifacts job above runner: ["linux-x86-g2-48-l4-4gpu",] - python: ["3.10",] + python: ["3.11",] jaxlib-version: ["head", "pypi_latest"] enable-x64: [1, 0] name: "Bazel CUDA Non-RBE (jax version = ${{ format('{0}', 'head') }})" @@ -194,7 +194,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - python: ["3.10",] + python: ["3.11",] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, diff --git a/.github/workflows/wheel_tests_nightly_release.yml 
b/.github/workflows/wheel_tests_nightly_release.yml index d90d05b64b28..f5622af0f1ef 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -43,7 +43,7 @@ jobs: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil", "3.14"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14"] enable-x64: [0] exclude: - runner: "windows-x86-n2-64" @@ -67,7 +67,7 @@ jobs: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. runner: ["linux-x86-g2-48-l4-4gpu"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil", "3.14"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14"] cuda: [ {cuda-version: "12.1", use-nvidia-pip-wheels: false}, {cuda-version: "12.8", use-nvidia-pip-wheels: true} @@ -89,7 +89,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - python: ["3.10", "3.11", "3.12", "3.13", "3.13-nogil"] + python: ["3.11", "3.12", "3.13", "3.13-nogil"] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, @@ -103,9 +103,6 @@ jobs: # Exclude pypi_latest for nightly releases - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} # Run a single Python version for v4-8 - - tpu-specs: - type: "v4-8" - python: "3.10" - tpu-specs: type: "v4-8" python: "3.11" @@ -116,9 +113,6 @@ jobs: type: "v4-8" python: "3.13-nogil" # Run Python versions in between min and max for v6e-8 - - tpu-specs: - type: "v6e-8" - python: "3.10" - tpu-specs: type: "v6e-8" python: "3.13" @@ -155,7 +149,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - python: ["3.10", "3.13", "3.13-nogil"] + python: ["3.11", "3.13", "3.13-nogil"] container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" # Verifies that JAX's release wheels can be installed @@ -171,7 +165,7 @@ jobs: final_gcs_download_uri=${{ inputs.gcs_download_uri }} # Get the major and minor version of Python. - # E.g if python=3.10, then python_major_minor=310 + # E.g if python=3.11, then python_major_minor=311 # E.g if python=3.13-nogil, then python_major_minor=313t python_major_minor=${{ matrix.python }} python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') diff --git a/CHANGELOG.md b/CHANGELOG.md index 699439a82033..adb1af2ce9a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Breaking changes: + * The minimum Python version is now 3.11. 3.11 will remain the minimum + supported version until July 2026. 
+ ## JAX 0.6.2 (June 17, 2025) diff --git a/WORKSPACE b/WORKSPACE index 6a7df6d9c8bc..33be3c6e0452 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -25,7 +25,6 @@ python_init_repositories( ], local_wheel_workspaces = ["//jaxlib:jax.bzl"], requirements = { - "3.10": "//build:requirements_lock_3_10.txt", "3.11": "//build:requirements_lock_3_11.txt", "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt deleted file mode 100644 index 523d3bbfe6e2..000000000000 --- a/build/requirements_lock_3_10.txt +++ /dev/null @@ -1,787 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# bazel run //build:requirements.update -# -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==23.2.0 \ - --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ - --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 - # via hypothesis -auditwheel==6.1.0 \ - --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ - --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 - # via -r build/test-requirements.txt -build==1.2.1 \ - --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ - --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/requirements.in -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 - # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/requirements.in -contourpy==1.2.1 \ - --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ - --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ - --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ - --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ - --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ - --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ - --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ - --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ - --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ - --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ - --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ - --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ - --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ - --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ - --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ - --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ - --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ - 
--hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ - --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ - --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ - --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ - --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ - --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ - --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ - --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ - --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ - --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ - --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ - --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ - --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ - --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ - --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ - --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ - --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ - --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ - --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ - --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ - --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ - --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ - --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ - --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ - --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ - --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ - --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 - # via matplotlib -cycler==0.12.1 \ - --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ - --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c - # via matplotlib -etils[epath,epy]==1.7.0 \ - --hash=sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60 \ - --hash=sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350 - # via -r build/requirements.in -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via - # hypothesis - # pytest -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 - # via pytest-xdist -filelock==3.14.0 \ - --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ - --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a - # via -r build/test-requirements.txt -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r 
build/test-requirements.txt -fonttools==4.51.0 \ - --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ - --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ - --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ - --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ - --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ - --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ - --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ - --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ - --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ - --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ - --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ - --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ - --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ - --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ - --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ - --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ - --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ - --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ - --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ - --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ - --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ - --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ - --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ - --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ - --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ - --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ - --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ - --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ - --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ - --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ - --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ - --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ - --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ - --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ - --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ - --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ - --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ - --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ - --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ - --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ - --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ - --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 - # via matplotlib 
-fsspec==2024.5.0 \ - --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ - --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c - # via etils -hypothesis==6.102.4 \ - --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ - --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 - # via -r build/test-requirements.txt -importlib-resources==6.4.0 \ - --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ - --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 - # via etils -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 - # via pytest -jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ - --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb - # via - # -r build/requirements.in - # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ - --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ - --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ - --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ - --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ - --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ - --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ - --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ - --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ - --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 - # via -r build/requirements.in -jaxlib==0.6.1 \ - --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ - --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ - --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ - --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ - --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ - --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ - --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ - --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ - --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ - --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ - --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ - --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ - --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ - --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ - --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ - --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ - --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ - --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d - # via -r 
build/requirements.in -kiwisolver==1.4.5 \ - --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ - --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ - --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ - --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ - --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ - --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ - --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ - --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ - --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ - --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ - --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ - --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ - --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ - --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ - --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ - --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ - --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ - --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ - --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ - --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ - --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ - --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ - --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ - --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ - --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ - --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ - --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ - --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ - --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ - --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ - --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ - --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ - --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ - --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ - --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ - --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ - --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ - --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ - --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ - --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ - --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ - --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ - 
--hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ - --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ - --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ - --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ - --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ - --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ - --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ - --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ - --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ - --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ - --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ - --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ - --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ - --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ - --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ - --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ - --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ - --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ - --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ - --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ - --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ - --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ - --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ - --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ - --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ - --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ - --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ - --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ - --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ - --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ - --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ - --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ - --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ - --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ - --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ - --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ - --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ - --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ - --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ - --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ - --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ - --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ - 
--hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ - --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ - --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ - --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ - --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ - --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ - --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ - --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ - --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ - --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ - --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ - --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ - --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ - --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ - --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ - --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ - --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ - --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ - --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ - --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f - # via matplotlib -libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ - --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 - # via -r build/requirements.in -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb - # via rich -matplotlib==3.8.4 \ - --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ - --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ - --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ - --hash=sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb \ - --hash=sha256:606e3b90897554c989b1e38a258c626d46c873523de432b1462f295db13de6f9 \ - --hash=sha256:6209e5c9aaccc056e63b547a8152661324404dd92340a6e479b3a7f24b42a5d0 \ - --hash=sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616 \ - --hash=sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa \ - --hash=sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661 \ - --hash=sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a \ - --hash=sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae \ - --hash=sha256:843cbde2f0946dadd8c5c11c6d91847abd18ec76859dc319362a0964493f0ba6 \ - --hash=sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea \ - --hash=sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106 \ - --hash=sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef \ - --hash=sha256:9bb0189011785ea794ee827b68777db3ca3f93f3e339ea4d920315a0e5a78d54 \ - --hash=sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f \ - 
--hash=sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014 \ - --hash=sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338 \ - --hash=sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25 \ - --hash=sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b \ - --hash=sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35 \ - --hash=sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732 \ - --hash=sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71 \ - --hash=sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10 \ - --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ - --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ - --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc - # via -r build/test-requirements.txt -mdurl==0.1.2 \ - --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ - --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba - # via markdown-it-py -ml-dtypes==0.5.1 \ - --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ - --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ - --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ - --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ - --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ - --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ - --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ - --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ - --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ - --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ - --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ - --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ - --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ - --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ - --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ - --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ - --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ - --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ - --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ - --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ - --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ - --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ - --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ - --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via - # -r build/requirements.in - # jaxlib - # tensorstore -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 - # via -r build/test-requirements.txt -numpy==2.0.0 ; python_version <= "3.12" \ - 
--hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ - --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ - --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ - --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ - --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ - --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ - --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ - --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ - --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ - --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ - --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ - --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ - --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ - --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ - --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ - --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ - --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ - --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ - --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ - --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ - --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ - --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ - --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ - --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ - --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ - --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ - --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ - --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ - --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ - --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ - --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ - --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ - --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ - --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ - --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ - --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ - --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ - --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ - --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ - --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ - --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ - --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ - 
--hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ - --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ - --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 - # via - # -r build/nonfreethreading-requirements.txt - # contourpy - # jaxlib - # matplotlib - # ml-dtypes - # opt-einsum - # scipy - # tensorstore -nvidia-cublas-cu12==12.8.3.14 \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ - --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 - # via - # jax-cuda12-plugin - # nvidia-cudnn-cu12 - # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 \ - --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ - --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ - --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via jax-cuda12-plugin -nvidia-cuda-nvcc-cu12==12.8.61 \ - --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ - --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ - --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via jax-cuda12-plugin -nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ - --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ - --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ - --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 - # via -r build/requirements.in -nvidia-cuda-runtime-cu12==12.8.57 \ - --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ - --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ - --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via jax-cuda12-plugin -nvidia-cudnn-cu12==9.8.0.87 \ - --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ - --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ - --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 - # via jax-cuda12-plugin -nvidia-cufft-cu12==11.3.3.41 \ - --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ - --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ - --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via jax-cuda12-plugin -nvidia-cusolver-cu12==11.7.2.55 \ - --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ - --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ - --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via jax-cuda12-plugin -nvidia-cusparse-cu12==12.5.7.53 \ - --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ - --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ - --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 - # via - # jax-cuda12-plugin - # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 \ - --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ - --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via jax-cuda12-plugin -nvidia-nvjitlink-cu12==12.8.61 \ - 
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ - --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ - --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 - # via - # jax-cuda12-plugin - # nvidia-cufft-cu12 - # nvidia-cusolver-cu12 - # nvidia-cusparse-cu12 -nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ - --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ - --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 - # via - # -r build/requirements.in - # jax-cuda12-plugin -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 - # via - # auditwheel - # build - # matplotlib - # pytest -pillow==11.0.0 \ - --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ - --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ - --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ - --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ - --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ - --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ - --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ - --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ - --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ - --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ - --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ - --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ - --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ - --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ - --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ - --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ - --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ - --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ - --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ - --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ - --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ - --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ - --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ - --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ - --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ - --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ - --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ - --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ - --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ - 
--hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ - --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ - --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ - --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ - --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ - --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ - --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ - --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ - --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ - --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ - --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ - --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ - --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ - --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ - --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ - --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ - --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ - --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ - --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ - --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ - --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ - --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ - --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ - --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ - --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ - --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ - --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ - --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ - --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ - --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ - --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ - --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ - --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ - --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ - --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ - --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ - --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ - --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ - --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ - --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ - --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ - --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ - 
--hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ - --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ - --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ - --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 - # via - # -r build/test-requirements.txt - # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 - # via pytest -portpicker==1.6.0 ; python_version < "3.13" \ - --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ - --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via - # -r build/nonfreethreading-requirements.txt - # -r build/test-requirements.txt -psutil==5.9.8 \ - --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ - --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ - --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ - --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ - --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ - --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ - --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ - --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ - --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ - --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ - --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ - --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ - --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ - --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ - --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ - --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 - # via portpicker -pyelftools==0.31 \ - --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ - --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 - # via auditwheel -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.2 \ - --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ - --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 - # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 - # via build -pytest==8.2.0 \ - --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ - --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d - # via -r build/test-requirements.txt -python-dateutil==2.9.0.post0 \ - 
--hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ - --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 - # via matplotlib -rich==13.7.1 \ - --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ - --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 - # via -r build/test-requirements.txt -scipy==1.13.1 ; python_version <= "3.12" \ - --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ - --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ - --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ - --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ - --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ - --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ - --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ - --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ - --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ - --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ - --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ - --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ - --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ - --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ - --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ - --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ - --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ - --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ - --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ - --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ - --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ - --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ - --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ - --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ - --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via - # -r build/requirements.in - # jaxlib -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil -sortedcontainers==2.4.0 \ - --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ - --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 - # via hypothesis -tensorstore==0.1.73 \ - --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ - --hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ - --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ - --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ - --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ - --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ - 
--hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ - --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ - --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ - --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ - --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ - --hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ - --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ - --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ - --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ - --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ - --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ - --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ - --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ - --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ - --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b - # via -r build/nonfreethreading-requirements.txt -tomli==2.0.1 \ - --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ - --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f - # via - # build - # pytest -typing-extensions==4.12.0rc1 \ - --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ - --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe - # via etils -wheel==0.43.0 \ - --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ - --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/requirements.in -zipp==3.18.2 \ - --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ - --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e - # via etils -zstandard==0.22.0 \ - --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ - --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ - --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ - --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ - --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ - --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ - --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ - --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ - --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ - --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ - --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ - --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ - --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ - --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ - --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ - --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ - --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ - 
--hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ - --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ - --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ - --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ - --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ - --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ - --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ - --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ - --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ - --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ - --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ - --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ - --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ - --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ - --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ - --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ - --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ - --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ - --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ - --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ - --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ - --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ - --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ - --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ - --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ - --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ - --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ - --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ - --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/nonfreethreading-requirements.txt - -# The following packages are considered to be unsafe in a requirements file: -setuptools==76.0.0 \ - --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ - --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via -r build/requirements.in
diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms
index a084045256de..40b4decaafb4 100644
--- a/build/rocm/Dockerfile.ms
+++ b/build/rocm/Dockerfile.ms
@@ -40,7 +40,7 @@ RUN --mount=type=cache,target=/var/cache/apt \
     liblzma-dev
 
 # Install pyenv with different python versions
-ARG PYTHON_VERSION=3.10.14
+ARG PYTHON_VERSION=3.11.13
 RUN git clone https://github.com/pyenv/pyenv.git /pyenv
 ENV PYENV_ROOT /pyenv
 ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh
index 386f70ee1a96..847d4e9b4b93 100755
--- a/build/rocm/ci_build.sh
+++ b/build/rocm/ci_build.sh
@@ -44,7 +44,7 @@ CONTAINER_TYPE="rocm"
 DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms"
 DOCKER_CONTEXT_PATH="${SCRIPT_DIR}"
 KEEP_IMAGE="--rm"
-PYTHON_VERSION="3.10"
+PYTHON_VERSION="3.11"
 ROCM_VERSION="6.1.3"
 ROCM_BUILD_JOB=""
 ROCM_BUILD_NUM=""
diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22
index 70b16f9e9677..b6e90f2183d2 100644
--- a/build/rocm/docker/Dockerfile.jax-ubu22
+++ b/build/rocm/docker/Dockerfile.jax-ubu22
@@ -60,7 +60,7 @@ ARG JAX_COMMIT
 ARG XLA_COMMIT
 
 LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
-      com.amdgpu.python_version="3.10" \
+      com.amdgpu.python_version="3.11" \
       com.amdgpu.jax_version="$JAX_VERSION" \
       com.amdgpu.jax_commit="$JAX_COMMIT" \
       com.amdgpu.xla_commit="$XLA_COMMIT"
diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py
index 3b3d697addc9..9fdffe6cfa03 100644
--- a/build/rocm/tools/build_wheels.py
+++ b/build/rocm/tools/build_wheels.py
@@ -251,7 +251,7 @@ def parse_args():
     )
     p.add_argument(
         "--python-versions",
-        default=["3.10.19,3.12"],
+        default=["3.11.13,3.12"],
        help="Comma separated CPython versions that wheels will be built and output for",
     )
     p.add_argument(
diff --git a/docs/contributing.md b/docs/contributing.md
index 635f5e899161..087432f1f771 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -36,7 +36,7 @@ Follow these steps to contribute code:
    [repository page](http://www.github.com/jax-ml/jax). This creates a copy
    of the JAX repository in your own account.
 
-3. Install Python >= 3.10 locally in order to run tests.
+3. Install Python >= 3.11 locally in order to run tests.
 
 4. `pip` installing your fork from source. This allows you to modify the code
    and immediately test it out:
diff --git a/docs/developer.md b/docs/developer.md
index 1b50a9b65bc0..e219e8517075 100644
--- a/docs/developer.md
+++ b/docs/developer.md
@@ -455,7 +455,6 @@ which one is selected by specifying `HERMETIC_PYTHON_VERSION`.
 For example in `WORKSPACE` file:
 ```
 requirements = {
-    "3.10": "//build:requirements_lock_3_10.txt",
     "3.11": "//build:requirements_lock_3_11.txt",
     "3.12": "//build:requirements_lock_3_12.txt",
     "3.13": "//build:requirements_lock_3_13.txt",
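As context for the `WORKSPACE` hunk above: once the `"3.10"` entry is gone, requesting that interpreter version has no lock file to resolve to. A hedged sketch of what that mapping implies, written in plain Python standing in for Starlark (`lock_file_for` is a hypothetical helper, not part of the JAX build):

```python
# Hypothetical sketch, not JAX build code: a version-keyed lock-file map like
# the WORKSPACE `requirements` dict, after the "3.10" entry is removed.
requirements = {
    "3.11": "//build:requirements_lock_3_11.txt",
    "3.12": "//build:requirements_lock_3_12.txt",
    "3.13": "//build:requirements_lock_3_13.txt",
}

def lock_file_for(hermetic_python_version: str) -> str:
    # HERMETIC_PYTHON_VERSION selects among these keys; 3.10 now fails fast.
    try:
        return requirements[hermetic_python_version]
    except KeyError:
        raise ValueError(
            f"Python {hermetic_python_version} is not supported; "
            f"available: {sorted(requirements)}") from None

print(lock_file_for("3.11"))  # //build:requirements_lock_3_11.txt
```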
+PYTHON_VERSION="3.11" ROCM_VERSION="6.1.3" ROCM_BUILD_JOB="" ROCM_BUILD_NUM="" diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 index 70b16f9e9677..b6e90f2183d2 100644 --- a/build/rocm/docker/Dockerfile.jax-ubu22 +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -60,7 +60,7 @@ ARG JAX_COMMIT ARG XLA_COMMIT LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ - com.amdgpu.python_version="3.10" \ + com.amdgpu.python_version="3.11" \ com.amdgpu.jax_version="$JAX_VERSION" \ com.amdgpu.jax_commit="$JAX_COMMIT" \ com.amdgpu.xla_commit="$XLA_COMMIT" diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index 3b3d697addc9..9fdffe6cfa03 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -251,7 +251,7 @@ def parse_args(): ) p.add_argument( "--python-versions", - default=["3.10.19,3.12"], + default=["3.11.13,3.12"], help="Comma separated CPython versions that wheels will be built and output for", ) p.add_argument( diff --git a/docs/contributing.md b/docs/contributing.md index 635f5e899161..087432f1f771 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -36,7 +36,7 @@ Follow these steps to contribute code: [repository page](http://www.github.com/jax-ml/jax). This creates a copy of the JAX repository in your own account. -3. Install Python >= 3.10 locally in order to run tests. +3. Install Python >= 3.11 locally in order to run tests. 4. `pip` installing your fork from source. This allows you to modify the code and immediately test it out: diff --git a/docs/developer.md b/docs/developer.md index 1b50a9b65bc0..e219e8517075 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -455,7 +455,6 @@ which one is selected by specifying `HERMETIC_PYTHON_VERSION`. 
diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py
index 60276a22b4cf..f1cdf86d9929 100644
--- a/jax/_src/traceback_util.py
+++ b/jax/_src/traceback_util.py
@@ -17,14 +17,12 @@
 from collections.abc import Callable
 import functools
 import os
-import sys
 import traceback
 import types
 from typing import Any, TypeVar, cast
 
 from jax._src import config
 from jax._src import util
-from jax._src.lib import _jax
 
 
 C = TypeVar("C", bound=Callable[..., Any])
@@ -193,25 +191,10 @@ def reraise_with_filtered_traceback(*args, **kwargs):
       tb = e.__traceback__
       filtered_tb = filter_traceback(tb)
       e.with_traceback(filtered_tb)
-      # In Python < 3.11, there seems to be no way to alter the currently
-      # raised exception traceback, except via the C API. The interpreter
-      # keeps a copy of the traceback (exc_traceback) that is separate to the
-      # __traceback__ of exc_value. Python 3.11 removes exc_traceback and
-      # just setting __traceback__ is enough. Since it is no longer needed,
-      # the XLA extension no longer defines a traceback-replacing method at
-      # Python 3.11 and onward.
-      if hasattr(_jax, "replace_thread_exc_traceback"):
-        # TODO(kidger): remove this line once Python 3.11 is the minimum supported
-        # version.
-        _jax.replace_thread_exc_traceback(filtered_tb)
-      if sys.version_info >= (3, 11) and mode == "quiet_remove_frames":
+      if mode == "quiet_remove_frames":
         e.add_note("--------------------\n" + _simplified_tb_msg)
       else:
-        if mode == "quiet_remove_frames":
-          # TODO(kidger): remove `SimplifiedTraceback` once Python 3.11 is the minimum
-          # supported version.
-          jax_error = SimplifiedTraceback()
-        elif mode == "remove_frames":
+        if mode == "remove_frames":
           msg = format_exception_only(e)
           msg = f'{msg}\n\n{_jax_message_append}'
           jax_error = UnfilteredStackTrace(msg)
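The surviving branch leans on PEP 678 exception notes, a 3.11-only feature, which is why both `SimplifiedTraceback` and the C-level traceback swap became unnecessary. A minimal self-contained sketch of the mechanism (illustrative message text, not the exact `_simplified_tb_msg`):

```python
# Minimal PEP 678 sketch: notes attached via add_note() are stored on
# e.__notes__ and rendered by the default traceback machinery.
try:
    raise ValueError("boom")
except ValueError as e:
    e.add_note("--------------------\nFor simplicity, internal frames were removed.")
    # Re-raising e would print the note beneath the traceback.
    assert any("For simplicity" in note for note in e.__notes__)
```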
diff --git a/jax/errors.py b/jax/errors.py
index 6da7b717cb5f..0dcf34bd4763 100644
--- a/jax/errors.py
+++ b/jax/errors.py
@@ -31,4 +31,19 @@
 JaxRuntimeError = _xc.XlaRuntimeError
 del _xc
 
-from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback
+import jax._src.traceback_util
+_deprecations = {
+    "SimplifiedTraceback": (
+        "jax.errors.SimplifiedTraceback is deprecated and will be removed in JAX v0.8.",
+        jax._src.traceback_util.SimplifiedTraceback
+    ),
+}
+
+import typing
+if typing.TYPE_CHECKING:
+  SimplifiedTraceback = jax._src.traceback_util.SimplifiedTraceback
+else:
+  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
+  __getattr__ = _deprecation_getattr(__name__, _deprecations)
+  del _deprecation_getattr
+del typing
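The shim above is the PEP 562 module-level `__getattr__` pattern. A self-contained sketch of the idea, with stand-in names (the real lookup lives in `jax._src.deprecations.deprecation_getattr`):

```python
# Standalone sketch of a deprecation shim: attribute access that misses the
# module dict is routed through __getattr__, which warns and then returns the
# still-working object.
import warnings

class SimplifiedTraceback(Exception):  # stand-in for the real class
    pass

_deprecations = {
    "SimplifiedTraceback": (
        "SimplifiedTraceback is deprecated and will be removed in JAX v0.8.",
        SimplifiedTraceback,
    ),
}

def __getattr__(name):
    if name in _deprecations:
        message, obj = _deprecations[name]
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return obj
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```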
author_email="Ruturaj.Vaidya@amd.com", packages=[package_name], - python_requires=">=3.9", + python_requires=">=3.11", install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={ package_name: [ diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index c1a1ecdab28c..afa9e633a391 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -798,7 +798,6 @@ class Traceback: def tracebacks_enabled() -> bool: ... def set_tracebacks_enabled(enabled: bool) -> None: ... -def replace_thread_exc_traceback(traceback: Any): ... # === END py_traceback.cc diff --git a/jaxlib/pjit.cc b/jaxlib/pjit.cc index 6196c945a63a..13b314d7c0c3 100644 --- a/jaxlib/pjit.cc +++ b/jaxlib/pjit.cc @@ -1277,14 +1277,7 @@ void BuildPjitSubmodule(nb::module_& m) { std::string name = absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); PyType_Spec PjitFunction_spec = { -#if PY_VERSION_HEX < 0x030B0000 - // Work around for https://github.com/python/cpython/issues/89478 - // CPython 3.10 and earlier assume that the .name value remains alive - // forever. - /*.name=*/strdup(name.c_str()), -#else /*.name=*/name.c_str(), -#endif // PY_VERSION_HEX < 0x030B0000 /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), /*.itemsize=*/0, #if PY_VERSION_HEX < 0x030C0000 diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc index 5f9020edc6bc..f49954e7df90 100644 --- a/jaxlib/pmap_lib.cc +++ b/jaxlib/pmap_lib.cc @@ -1024,14 +1024,7 @@ void BuildPmapSubmodule(nb::module_& m) { std::string name = absl::StrCat(nb::cast(m.attr("__name__")), ".PmapFunction"); PyType_Spec pmap_function_spec = { -#if PY_VERSION_HEX < 0x030B0000 - // Work around for https://github.com/python/cpython/issues/89478 - // CPython 3.10 and earlier assume that the .name value remains alive - // forever. - /*.name=*/strdup(name.c_str()), -#else /*.name=*/name.c_str(), -#endif // PY_VERSION_HEX < 0x030B0000 /*.basicsize=*/static_cast(sizeof(JaxPmapFunctionObject)), /*.itemsize=*/0, #if PY_VERSION_HEX < 0x030C0000 diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 6d6f76b27d88..a113da6e11f2 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -1959,14 +1959,7 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) { std::string base_name = absl::StrCat(nb::cast(m.attr("__name__")), ".Array"); PyType_Spec PyBaseArray_spec = { -#if PY_VERSION_HEX < 0x030B0000 - // Work around for https://github.com/python/cpython/issues/89478 - // CPython 3.10 and earlier assume that the .name value remains alive - // forever. - /*.name=*/strdup(base_name.c_str()), -#else /*.name=*/base_name.c_str(), -#endif // PY_VERSION_HEX < 0x030B0000 /*.basicsize=*/static_cast(sizeof(PyBaseArrayObject)), /*.itemsize=*/0, #if PY_VERSION_HEX < 0x030C0000 @@ -1986,14 +1979,7 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) { absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); PyType_Spec PyArray_spec = { -#if PY_VERSION_HEX < 0x030B0000 - // Work around for https://github.com/python/cpython/issues/89478 - // CPython 3.10 and earlier assume that the .name value remains alive - // forever. 
- /*.name=*/strdup(name.c_str()), -#else /*.name=*/name.c_str(), -#endif // PY_VERSION_HEX < 0x030B0000 /*.basicsize=*/static_cast(sizeof(PyArrayObject)), /*.itemsize=*/0, #if PY_VERSION_HEX < 0x030C0000 diff --git a/jaxlib/setup.py b/jaxlib/setup.py index ef0fcb205fb1..6a0c6520af2b 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -59,7 +59,7 @@ def has_ext_modules(self): author='JAX team', author_email='jax-dev@google.com', packages=['jaxlib'], - python_requires='>=3.10', + python_requires='>=3.11', install_requires=[ 'scipy>=1.12', 'numpy>=1.26', @@ -69,7 +69,6 @@ def has_ext_modules(self): license='Apache-2.0', classifiers=[ "Development Status :: 5 - Production/Stable", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc index e4f3cb35ad9b..8a309ebb6f8f 100644 --- a/jaxlib/traceback.cc +++ b/jaxlib/traceback.cc @@ -258,20 +258,6 @@ std::vector> Traceback::RawFrames() const { PyThreadState* thread_state = PyThreadState_GET(); -#if PY_VERSION_HEX < 0x030b0000 - // The representation of frame->f_lasti changed from bytes to words in Python - // 3.10, see https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api - // This should match sizeof(_Py_CODEUNIT) which is unfortunately private. - constexpr int kLastiWordBytes = 2; - - for (PyFrameObject* py_frame = thread_state->frame; - py_frame != nullptr && count < kMaxFrames; py_frame = py_frame->f_back) { - Py_INCREF(py_frame->f_code); - frames[count] = {py_frame->f_code, py_frame->f_lasti * kLastiWordBytes}; - ++count; - } -#else // PY_VERSION_HEX < 0x030b0000 - #ifdef PLATFORM_GOOGLE // This code is equivalent to the version using public APIs, but it saves us // an allocation of one object per stack frame. However, this is definitely @@ -310,8 +296,6 @@ std::vector> Traceback::RawFrames() const { } #endif // PLATFORM_GOOGLE -#endif // PY_VERSION_HEX < 0x030b0000 - Traceback traceback = nb::steal(PyObject_NewVar(PyObject, traceback_type_, count)); TracebackObject* tb = reinterpret_cast(traceback.ptr()); @@ -336,14 +320,7 @@ void BuildTracebackSubmodule(nb::module_& m) { absl::StrCat(nb::cast(m.attr("__name__")), ".Traceback"); PyType_Spec traceback_spec = { -#if PY_VERSION_HEX < 0x030B0000 - // Work around for https://github.com/python/cpython/issues/89478 - // CPython 3.10 and earlier assume that the .name value remains alive - // forever. - /*.name=*/strdup(name.c_str()), -#else /*.name=*/name.c_str(), -#endif // PY_VERSION_HEX < 0x030B0000 /*.basicsize=*/static_cast(sizeof(TracebackObject)), /*.itemsize=*/static_cast(sizeof(TracebackEntry)), /*.flags=*/Py_TPFLAGS_DEFAULT, @@ -430,7 +407,6 @@ void BuildTracebackSubmodule(nb::module_& m) { }, "Python wrapper around the Python C API function PyCode_Addr2Line"); -#if PY_VERSION_HEX >= 0x030b0000 type.attr("code_addr2location") = nb::cpp_function( [](nb::handle code, int lasti) { if (!PyCode_Check(code.ptr())) { @@ -445,30 +421,5 @@ void BuildTracebackSubmodule(nb::module_& m) { return nb::make_tuple(start_line, start_column, end_line, end_column); }, "Python wrapper around the Python C API function PyCode_Addr2Location"); -#endif // PY_VERSION_HEX >= 0x030b0000 - -#if PY_VERSION_HEX < 0x030b0000 - // This function replaces the exception traceback associated with the current - // Python thread. 
- m.def( - "replace_thread_exc_traceback", - [](nb::object tb) { - if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) { - throw xla::XlaRuntimeError( - "argument must be a traceback object or None"); - } - PyThreadState* thread_state = PyThreadState_Get(); - if (!thread_state->exc_info->exc_traceback) { - throw xla::XlaRuntimeError( - "Current thread does not have an active " - "exception traceback"); - } - PyObject* old_exc_traceback = thread_state->exc_info->exc_traceback; - PyObject* new_tb = tb.is_none() ? nullptr : tb.release().ptr(); - thread_state->exc_info->exc_traceback = new_tb; - Py_XDECREF(old_exc_traceback); - }, - nb::arg("traceback").none()); -#endif // PY_VERSION_HEX < 0x30b0000 } } // namespace xla diff --git a/setup.py b/setup.py index f377db415147..a8bcdee95091 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ def load_version_module(pkg_path): author_email='jax-dev@google.com', packages=find_packages(exclude=["examples"]), package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, - python_requires='>=3.10', + python_requires='>=3.11', install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', 'ml_dtypes>=0.5.0', @@ -123,7 +123,6 @@ def load_version_module(pkg_path): license='Apache-2.0', classifiers=[ "Development Status :: 5 - Production/Stable", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", diff --git a/tests/errors_test.py b/tests/errors_test.py index 63618e646dfd..356ca0713adf 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -13,7 +13,6 @@ # limitations under the License. import re -import sys import traceback from absl.testing import absltest @@ -46,10 +45,7 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(), e = get_exception(etype, f) c = e.__cause__ if filter_mode == "quiet_remove_frames": - if sys.version_info >= (3, 11): - assert any("For simplicity" in x for x in e.__notes__) - else: - test.assertIsInstance(c, jax.errors.SimplifiedTraceback) + assert any("For simplicity" in x for x in e.__notes__) elif filter_mode == "remove_frames": test.assertIsInstance(c, traceback_util.UnfilteredStackTrace) else: @@ -393,12 +389,8 @@ def outer(x): ('', 'f = lambda: outer'), ('outer', 'raise TypeError')], filter_mode=filter_mode) e = get_exception(TypeError, f) # Uses the default JAX_TRACEBACK_FILTERING=auto - if sys.version_info >= (3, 11): - assert any("For simplicity" in x for x in e.__notes__) - self.assertIsInstance(e.__cause__, ValueError) - else: - self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback) - self.assertIsInstance(e.__cause__.__cause__, ValueError) + assert any("For simplicity" in x for x in e.__notes__) + self.assertIsInstance(e.__cause__, ValueError) def test_null_traceback(self, filter_mode): class TestA: pass @@ -424,14 +416,9 @@ def test_grad_norm(self): e = exc self.assertIsNot(e, None) self.assertIn("invalid value", str(e)) - if sys.version_info >= (3, 11): - self.assertIsInstance( - e.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) - else: - self.assertIsInstance( - e.__cause__.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) + self.assertIsInstance( + e.__cause__, + source_info_util.JaxStackTraceBeforeTransformation) class CustomErrorsTest(jtu.JaxTestCase): diff --git a/tests/typing_test.py b/tests/typing_test.py index 562c6c56d2d9..6ebfc627efc6 100644 --- a/tests/typing_test.py +++ b/tests/typing_test.py @@ -143,11 +143,7 @@ 
def f(x: Any) -> typing.Array | None: # - Confirm that types from *.pyi files are correctly pulled-in # - Confirm that non-trivial overloads are behaving as expected. # - import sys - if sys.version_info >= (3, 11): - from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade - else: - from typing_extensions import assert_type # pytype: disable=not-supported-yet + from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade mat = jnp.zeros((2, 5)) vals = jnp.arange(5) From f562884c5697c67c85685b28a300d871123d7e6e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 18 Jun 2025 08:37:21 -0700 Subject: [PATCH 1738/1769] [Mosaic GPU] Implement canonicalization for `TiledLayout`s. Layout canonicalization gets rid of unnecessary tiles and dimensions in `Tiling`s---such that (1) any tiling operation partitions at least one dimension into more than `1` tile, and (2) the leading dimensions of each tile are not `1` (if canonicalizing a tile in this way leads to an empty tile, then the tile is given shape `(1,)`---which is still a meaningful (final) tile). Canonicalization allows simplifying the definition of a few of our predefined layouts---`WGMMA`, `WGMMA_ROW`, `TCGEN05`, `TCGEN05_ROW`, and `TCGEN05_COL`. PiperOrigin-RevId: 772933245 --- .../mosaic/gpu/fragmented_array.py | 130 ++++++++++++++++-- jax/experimental/mosaic/gpu/wgmma.py | 2 +- tests/mosaic/gpu_test.py | 12 +- 3 files changed, 120 insertions(+), 24 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index e82f71be8ae4..d8b18dfb2ab9 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -110,6 +110,43 @@ def fail(): shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) return shape + def canonicalize(self) -> Tiling: + """Returns a canonicalized version of the tiling. + + We define a tiling to be canonical if, at each step (except the first one, + which defines the base tile shape): + + 1. The tiling partitions at least one dimension in more than 1 tile. For + example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it + yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which + allows getting rid of the unnecessary `1` dimensions. + 2. The leading dimensions of each tile are not `1`. If canonicalizing a + tile in this way leads to an empty tile, then the tile is given shape + `(1,)`---which is still a meaningful (final) tile. For example, the + tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape + `(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows + getting rid of the unnecessary `1` dimension, and yields a shape + `(8, 2, 4)`. + """ + if len(self.tiles) <= 1: + return self + + shape = self.tiles[0] + new_tiling = [self.tiles[0]] + for tile in self.tiles[1:]: + for i, d in enumerate(tile): + if d != 1: + canonical_tile = tile[i:] + break + else: + canonical_tile = (1,) + tiled_dims = shape[-len(canonical_tile):] + if tiled_dims == canonical_tile: + continue + shape = canonical_tile + new_tiling.append(canonical_tile) + return Tiling(tuple(new_tiling)) + def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: """Computes the strides of an array after tiling.""" for tile in self.tiles: @@ -281,8 +318,13 @@ class TiledLayout: warp_dim: int | Replicated lane_dims: tuple[int | Replicated, ...] # major-to-minor vector_dim: int + # Whether to enforce that the layout is canonical. 
Users of `TiledLayout` + # should not set this to `False`, but it is helpful to be able to construct + # non-canonical layouts as an intermediate state when implementing layout + # transformations. + _check_canonical: dataclasses.InitVar[bool] = True - def __post_init__(self): + def __post_init__(self, _check_canonical: bool): if not self.tiling.tiles: raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] @@ -308,6 +350,10 @@ def __post_init__(self): ) if lane_dims_prod != WARP_SIZE: raise ValueError("The product of lane dims does not equal the warp size") + if _check_canonical: + canonical_layout = self.canonicalize() + if self != canonical_layout: + raise ValueError(f"{self} is not canonical.") @functools.cached_property def partitioned_lane_dims(self) -> tuple[int, ...]: @@ -459,7 +505,8 @@ def replace_tiled_dim(d: int | Replicated, size: int): for d in self.lane_dims ), new_vector_dim, - ) + _check_canonical=False, + ).canonicalize() def reduce(self, axes: Sequence[int]) -> TiledLayout: reduced_layout = self @@ -467,6 +514,59 @@ def reduce(self, axes: Sequence[int]) -> TiledLayout: reduced_layout = reduced_layout.remove_dimension(a) return reduced_layout + def canonicalize(self) -> TiledLayout: + """Returns a version of this layout where tiling is canonical.""" + canonical_tiling = self.tiling.canonicalize() + if canonical_tiling == self.tiling: + return self + + s = self.base_tile_shape + canonical_tiled_tiling_shape = canonical_tiling.tile_shape(s)[len(s):] + offset = len(canonical_tiled_tiling_shape) - 1 + + rev_removed_dims = [] + # Iterate starting from the end in order to eliminate leading dimensions, + # whenever possible. For instance, say we have + # + # shape=(4, 32, 1, 1, 1, 1, 1) + # warp_dim=-7, + # lane_dims=(-6,) + # vector_dim=-1 + # + # and we want to canonicalize this to + # + # shape=(4, 32, 1) + # warp_dim=-3, + # lane_dims=(-2,) + # vector_dim=-1. + # + # After the loop below, we end up with + # + # rev_removed_dims=[False, True, True, True, True, False, False] + # + # which will yield offsets `4` for `warp_dim`, `4` for `lane_dims[0]`, and + # `0` for `vector_dim`. + for d in reversed(self.tiled_tiling_shape): + if offset >= 0 and d == canonical_tiled_tiling_shape[offset]: + rev_removed_dims.append(False) + offset -= 1 + else: + rev_removed_dims.append(True) + assert offset == -1 + + dim_offsets = np.cumsum(rev_removed_dims)[::-1].tolist() + + def replace_tiled_dim(d: int | Replicated): + return d if isinstance(d, Replicated) else d + dim_offsets[d] + + return TiledLayout( + canonical_tiling, + replace_tiled_dim(self.warp_dim), + tuple(replace_tiled_dim(d) for d in self.lane_dims), + replace_tiled_dim(self.vector_dim), + _check_canonical=True + ) + def _tiled_wgmma_layout(shape: tuple[int, ...]): """Returns the tiled layout relevant for WGMMA operations. @@ -601,9 +701,9 @@ def linear_thread_idxs(self): vector_dim=-1, ) WGMMA_ROW_LAYOUT = TiledLayout( - Tiling(((64,), (16,), (8,), (1,), (1,))), - warp_dim=-5, - lane_dims=(-3, Replicated(4)), + Tiling(((64,), (16,), (8,), (1,))), + warp_dim=-4, + lane_dims=(-2, Replicated(4)), vector_dim=-1, ) @@ -621,9 +721,9 @@ def linear_thread_idxs(self): # 12 12 13 13 14 14 15 15 # ... 
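# Illustrative sketch (hand-traced from `Tiling.canonicalize` above; not part
# of the original patch): how canonicalization shrinks the tilings that the
# definitions below used in their previous revision.
#
#   # The trailing (1, 2) tile starts with a 1, so it canonicalizes to (2,):
#   Tiling(((64, 8), (16, 8), (8, 8), (1, 2))).canonicalize()
#   # == Tiling(((64, 8), (16, 8), (8, 8), (2,)))
#
#   # An all-ones tile collapses to (1,), and a repeated (1,) tile is dropped
#   # because it no longer partitions any dimension:
#   Tiling(((64,), (16,), (8,), (1,), (1,))).canonicalize()
#   # == Tiling(((64,), (16,), (8,), (1,)))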
WGMMA_LAYOUT = TiledLayout( - Tiling(((64, 8), (16, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), + Tiling(((64, 8), (16, 8), (8, 8), (2,))), + warp_dim=-7, + lane_dims=(-3, -2), vector_dim=-1, ) # This tiled layout is similar to the WGMMA layout, only the unit at which we @@ -687,23 +787,23 @@ def linear_thread_idxs(self): # Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. TCGEN05_LAYOUT = TiledLayout( - Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), + Tiling(((128, 8), (32, 8), (8, 8), (2,))), + warp_dim=-7, + lane_dims=(-3, -2), vector_dim=-1, ) # TCGEN05_ROW_LAYOUT is to TCGEN05_LAYOUT as WGMMA_ROW_LAYOUT is to # WGMMA_LAYOUT. TCGEN05_ROW_LAYOUT = TiledLayout( - Tiling(tiles=((128,), (32,), (8,), (1,), (1,))), - warp_dim=-5, - lane_dims=(-3, Replicated(times=4)), + Tiling(tiles=((128,), (32,), (8,), (1,))), + warp_dim=-4, + lane_dims=(-2, Replicated(times=4)), vector_dim=-1, ) # TCGEN05_COL_LAYOUT is to TCGEN05_LAYOUT as WGMMA_COL_LAYOUT is to # WGMMA_LAYOUT. TCGEN05_COL_LAYOUT = TiledLayout( - Tiling(tiles=((8,), (8,), (8,), (2,))), + Tiling(tiles=((8,), (2,))), warp_dim=Replicated(times=4), lane_dims=(Replicated(times=8), -2), vector_dim=-1, diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 9b4fc7678538..23d5174bb24b 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -238,7 +238,7 @@ def lc(x): assert len(imms) == num_imm_regs + 1 # +1 for the use_out_reg in setp.ne.b32 - if acc.ndim != 10 or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != 2: + if acc.ndim != 9 or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != 2: raise ValueError(acc.shape) acc_struct_type = ir.Type.parse( f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>" diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 8e217843409d..d7db2b85d27a 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2462,12 +2462,7 @@ def kernel(ctx, dst, _): ("tcgen05_col", tcgen05.LAYOUT, tcgen05.COL_LAYOUT, 0), ) def test_layout_reduction_definition(self, layout, expected_reduced_layout, axis): - def squeeze_shape(shape): - return tuple(s for s in shape if s != 1) - reduced_layout = layout.reduce((axis,)) - tiled_shape = squeeze_shape(reduced_layout.tiled_tiling_shape) - expected_tiled_shape = squeeze_shape(expected_reduced_layout.tiled_tiling_shape) - self.assertEqual(tiled_shape, expected_tiled_shape) + self.assertEqual(layout.reduce((axis,)), expected_reduced_layout) @parameterized.product( op=(arith.addf, arith.maximumf), @@ -2843,7 +2838,7 @@ def kernel(ctx, dst, _): # Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1) self.assertEqual( tiled.registers.shape, - (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1), + (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1), ) self.assertEqual(tiled.shape, shape) self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) @@ -4034,7 +4029,8 @@ def tiled_layouts( warp_dim=warp_dim, lane_dims=lane_dims, vector_dim=vector_dim, - ) + _check_canonical=False, + ).canonicalize() @hps.composite def shape_and_tiled_layout( From 5e2afe6e7987e130ddeeebae281c3427375882b4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 18 Jun 2025 09:00:27 -0700 Subject: [PATCH 1739/1769] Remove `_allow_deprecated_jit_signature` now that 0.6.2 is out and next release is 0.7 in July PiperOrigin-RevId: 772941028 --- CHANGELOG.md | 4 +++- jax/_src/api.py | 32 
-------------------------------- tests/custom_api_test.py | 14 -------------- 3 files changed, 3 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index adb1af2ce9a2..f8b905448c54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,10 +17,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased * Breaking changes: + * {func}`jax.jit` now requires `fun` to be passed by position, and additional + arguments to be passed by keyword. Doing otherwise will result in an error + starting in v0.7.x. This raised a DeprecationWarning in v0.6.x. * The minimum Python version is now 3.11. 3.11 will remain the minimum supported version until July 2026. - ## JAX 0.6.2 (June 17, 2025) * New features: diff --git a/jax/_src/api.py b/jax/_src/api.py index db1e62910140..45e40810c66a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -38,7 +38,6 @@ from contextlib import contextmanager from jax._src import api_util -from jax._src import deprecations from jax._src import linear_util as lu from jax._src import stages from jax._src.tree_util import ( @@ -148,37 +147,6 @@ def _update_debug_special_thread_local(_): float0 = dtypes.float0 -# TODO(jakevdp): remove this for v0.7.0 (~July 2025) -def _allow_deprecated_jit_signature(f: F) -> F: - """Temporary decorator for the jit signature deprecation.""" - @wraps(f) - def wrapped(*args, **kwargs): - if len(args) == 1 or deprecations.is_accelerated('jax-jit-positional-args'): - # Fast path for typical usage. - return f(*args, **kwargs) - if 'fun' in kwargs: - deprecations.warn( - 'jax-jit-positional-args', - ('jax.jit: passing fun by keyword is deprecated.' - ' Pass it by position to silence this warning.'), - stacklevel=2 - ) - return f(kwargs.pop('fun'), **kwargs) - if len(args) > 1: - deprecations.warn( - 'jax-jit-positional-args', - ('jax.jit: passing optional arguments by position is deprecated. 
' - ' Pass them by keyword to silence this warning.'), - stacklevel=2 - ) - sig = inspect.signature(f) - kwds = dict(unsafe_zip((p.name for p in sig.parameters.values()), args)) - return f(kwds.pop('fun'), **kwds, **kwargs) - return f(*args, **kwargs) - return cast(F, wrapped) - - -@_allow_deprecated_jit_signature def jit( fun: Callable, /, *, in_shardings: Any = sharding_impls.UNSPECIFIED, diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 595da899b5c2..a53995751f36 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -42,7 +42,6 @@ from jax._src import config from jax._src import core from jax._src import custom_derivatives -from jax._src import deprecations from jax._src import test_util as jtu from jax._src.interpreters import partial_eval as pe @@ -3721,19 +3720,6 @@ def tp(r, t): return 2 * fn(r, t) self.assertAllClose(f_(x), g_(x)) self.assertAllClose(f_t(x), g_t(x)) - def test_jit_signature_deprecation(self): - fun = lambda x: x - if deprecations.is_accelerated('jax-jit-positional-args'): - with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'): - jax.jit(fun=fun) - with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'): - jax.jit(fun, None) - else: - with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'): - jax.jit(fun=fun) - with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'): - jax.jit(fun, None) - def test_cond(self): def f(x, y): @custom_transpose(jnp.ones(2)) From 9cf81e4d12f0b31dd91b7f853147ae280f757eff Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Jun 2025 09:34:45 -0700 Subject: [PATCH 1740/1769] Add a cache around abstract_eval rules. PiperOrigin-RevId: 772952508 --- jax/_src/interpreters/partial_eval.py | 39 +++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 8e4ea16af0de..e5ac29504ee0 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -49,7 +49,7 @@ from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list, - HashableFunction, foreach) + HashableFunction, foreach, cache) map, unsafe_map = safe_map, map @@ -1890,6 +1890,23 @@ def vars(atom: Atom) -> list[Var]: return jaxpr, list(constvals) +@cache() +def _cached_abstract_eval(primitive: core.Primitive, *aval_qdds, **params): + return primitive.abstract_eval(*aval_qdds, **params) + + +def _verify_params_are_hashable( + primitive: core.Primitive, params: dict[str, Any]) -> None: + for k, v in params.items(): + try: + hash(v) + except TypeError as e: + raise TypeError( + "As of JAX v0.7, parameters to jaxpr equations must have __hash__ and " + f"__eq__ methods. In a call to primitive {primitive}, the value of " + f"parameter {k} was not hashable: {v}") from e + + class DynamicJaxprTrace(core.Trace): __slots__ = ("frame", "tag", "parent_trace") @@ -2002,7 +2019,25 @@ def process_primitive(self, primitive, tracers, params): def default_process_primitive(self, primitive, tracers, params, source_info=None): aval_qdds = [t.aval_mutable_qdd for t in tracers] - out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) + # TODO(mattjj): make custom_lin have hashable params. 
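    # Illustrative sketch of the hashability requirement (hypothetical param
    # values, not taken from this patch): the cache above keys on the
    # primitive, the input avals, and the params, so
    #   {'branches': (jaxpr_a, jaxpr_b)}   # tuple of hashables: cacheable
    # hits the fast path, while
    #   {'branches': [jaxpr_a, jaxpr_b]}   # list: unhashable
    # makes the lookup throw; the except branch below then reports it via the
    # descriptive TypeError from _verify_params_are_hashable.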
+ # TODO(dougalm): add an attribute to primitives to mark primitives with + # effectful abstract_eval rules. + if ( + primitive.name == "custom_lin" + or config.dynamic_shapes.value + or any( + isinstance(aval, core.MutableQuasiDynamicData) for aval in aval_qdds + ) + ): + out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) + else: + try: + out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params) + except Exception as e: + # TODO(phawkins): remove this 3 months after the release of JAX v0.7. + _verify_params_are_hashable(primitive, params) + raise + if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: raise ValueError(f"{primitive}.abstract_eval() method should return " f"a tuple or a list iff {primitive}.multiple_results.") From 36256960433d1eacc777c864df4b9245d8e47449 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 18 Jun 2025 09:51:04 -0700 Subject: [PATCH 1741/1769] Finalize a number of deprecations for JAX v0.7.0 This PR deliberately leaves type registrations untouched; I'll remove those in a followup after ensuring their removal doesn't lead to build breakages. --- jax/dlpack.py | 4 +- jax/lib/xla_client.py | 15 +++---- jax/lib/xla_extension.py | 83 +++++++++++++++++++++++--------------- jax/tree_util.py | 7 ++-- jax/util.py | 86 +++++++++++++++------------------------- 5 files changed, 94 insertions(+), 101 deletions(-) diff --git a/jax/dlpack.py b/jax/dlpack.py index d008608fc356..c4b993195030 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -24,13 +24,13 @@ _deprecations = { "to_dlpack": ( ( - "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and will be" + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and" " removed in JAX v0.7.0. Please use the newer DLPack API based on" " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" " a JAX array directly to the `from_dlpack` function of another" " framework without using `to_dlpack`." ), - jax._src.dlpack.to_dlpack, + None, ), } diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index faaaf4a425f4..15cd62d6e245 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -78,6 +78,14 @@ def _heap_profile(client): ), None, ), + # Finalized for JAX v0.7.0 + "heap_profile": ( + ( + "jax.lib.xla_client.heap_profile was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0" + ), + None, + ), # Added April 4 2025. "get_topology_for_devices": ( ( @@ -86,13 +94,6 @@ def _heap_profile(client): ), _xc.get_topology_for_devices, ), - "heap_profile": ( - ( - "jax.lib.xla_client.heap_profile was deprecated in JAX v0.6.0 and" - " will be removed in JAX v0.7.0" - ), - _heap_profile, - ), "mlir_api_version": ( ( "jax.lib.xla_client.mlir_api_version was deprecated in JAX v0.6.0" diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 6b58a72783c9..7e183eab5c2c 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -16,6 +16,7 @@ from jax._src.lib import _jax _deprecations = { + # Finalized for JAX v0.6.0 "ArrayImpl": ( ( "jax.lib.xla_extension.ArrayImpl has been removed; use jax.Array" @@ -30,54 +31,78 @@ ), None, ), - # Deprecated March 26 2025. + # Finalized for JAX v0.7.0 + "Device": ( + ( + "jax.lib.xla_extension.Device was deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0; use jax.Device instead." + ), + None, + ), "DistributedRuntimeClient": ( ( - "jax.lib.xla_extension.DistributedRuntimeClient is" - " deprecated; use jax.distributed instead." 
+ "jax.lib.xla_extension.DistributedRuntimeClient deprecated in JAX" + " v0.6.0 and removed in JAX v0.7.0; use jax.distributed instead." + ), + None, + ), + "HloModule": ( + ( + "jax.lib.xla_extension.HloModule deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0." + ), + None, + ), + "OpSharding": ( + ( + "jax.lib.xla_extension.OpSharding deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0." ), - _jax.DistributedRuntimeClient, + None, + ), + "PjitFunctionCache": ( + ( + "jax.lib.xla_extension.PjitFunctionCache was deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0." + ), + None, ), "get_distributed_runtime_client": ( ( - "jax.lib.xla_extension.get_distributed_runtime_client is" - " deprecated; use jax.distributed instead." + "jax.lib.xla_extension.get_distributed_runtime_client was deprecated" + " in JAX v0.6.0 and removed in JAX v0.7.0; use jax.distributed instead." ), - _jax.get_distributed_runtime_client, + None, ), "get_distributed_runtime_service": ( ( - "jax.lib.xla_extension.get_distributed_runtime_service is" - " deprecated; use jax.distributed instead." + "jax.lib.xla_extension.get_distributed_runtime_service was deprecated" + " in JAX v0.6.0 and removed in JAX v0.7.0; use jax.distributed instead." ), - _jax.get_distributed_runtime_service, + None, ), - "Device": ( - "jax.lib.xla_extension.Device is deprecated; use jax.Device instead.", - _jax.Device, + "jax_jit": ( + "jax.lib.xla_extension.jax_jit deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), - "PjitFunctionCache": ( - "jax.lib.xla_extension.PjitFunctionCache is deprecated.", - _jax.PjitFunctionCache, + "pmap_lib": ( + "jax.lib.xla_extension.pmap_lib deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None + ), + "pytree": ( + "jax.lib.xla_extension.pytree deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), + # Deprecated March 26 2025. "ifrt_proxy": ( "jax.lib.xla_extension.ifrt_proxy is deprecated.", _jax.ifrt_proxy, ), - "jax_jit": ( - "jax.lib.xla_extension.jax_jit is deprecated.", - _jax.jax_jit, - ), "mlir": ("jax.lib.xla_extension.mlir is deprecated.", _jax.mlir), - "pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _jax.pmap_lib), "profiler": ( "jax.lib.xla_extension.profiler is deprecated.", jax._src.lib._profiler, ), - "pytree": ( - "jax.lib.xla_extension.pytree is deprecated.", - _jax.pytree, - ), "hlo_module_cost_analysis": ( "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.", _jax.hlo_module_cost_analysis, @@ -86,18 +111,10 @@ "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.", _jax.hlo_module_to_dot_graph, ), - "HloModule": ( - "jax.lib.xla_extension.HloModule is deprecated.", - _jax.HloModule, - ), "HloPrintOptions": ( "jax.lib.xla_extension.HloPrintOptions is deprecated.", _jax.HloPrintOptions, ), - "OpSharding": ( - "jax.lib.xla_extension.OpSharding is deprecated.", - _jax.OpSharding, - ), "PjitFunction": ( "jax.lib.xla_extension.PjitFunction is deprecated.", _jax.PjitFunction, diff --git a/jax/tree_util.py b/jax/tree_util.py index b35890dfc887..ad864def3b44 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -48,7 +48,6 @@ PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, all_leaves as all_leaves, - build_tree as _deprecated_build_tree, default_registry as default_registry, keystr as keystr, register_dataclass as register_dataclass, @@ -78,10 +77,10 @@ # Added March 21, 2025: "build_tree": ( ( - "jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten" - " instead." 
+ "jax.tree_util.build_tree was deprecated in JAX v0.6.0 and removed in" + " JAX v0.7.0. Use jax.tree.unflatten instead." ), - _deprecated_build_tree, + None ), } diff --git a/jax/util.py b/jax/util.py index b2c9df205206..1931a0293c09 100644 --- a/jax/util.py +++ b/jax/util.py @@ -20,107 +20,83 @@ _deprecations = { + # Finalized in JAX v0.7.0; remove entries in JAX v0.8.0 "to_dlpack": ( ( - "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and will be" + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and" " removed in JAX v0.7.0. Please use the newer DLPack API based on" " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" " a JAX array directly to the `from_dlpack` function of another" " framework without using `to_dlpack`." ), - jax._src.dlpack.to_dlpack, + None, ), "HashableFunction": ( ( - "HashableFunction was deprecated in JAX v0.6.0 and will be removed" + "HashableFunction was deprecated in JAX v0.6.0 and removed" " in JAX v0.7.0." ), - jax._src.util.HashableFunction, + None, ), "as_hashable_function": ( ( - "as_hashable_function was deprecated in JAX v0.6.0 and will be" + "as_hashable_function was deprecated in JAX v0.6.0 and" " removed in JAX v0.7.0." ), - jax._src.util.as_hashable_function, + None, ), "cache": ( - "cache was deprecated in JAX v0.6.0 and will be removed in JAX v0.7.0.", - jax._src.util.cache, + "cache was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "safe_map": ( - ( - "safe_map was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." - ), - jax._src.util.safe_map, + "safe_map was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "safe_zip": ( ( - "safe_zip was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." + "safe_zip was deprecated in JAX v0.6.0 and removed in JAX v0.7.0." ), - jax._src.util.safe_zip, + None, ), "split_dict": ( - ( - "split_dict was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." - ), - jax._src.util.split_dict, + "split_dict was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "split_list": ( - ( - "split_list was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." - ), - jax._src.util.split_list, + "split_list was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "split_list_checked": ( ( - "split_list_checked was deprecated in JAX v0.6.0 and will be" + "split_list_checked was deprecated in JAX v0.6.0 and" " removed in JAX v0.7.0." ), - jax._src.util.split_list_checked, + None, ), "split_merge": ( - ( - "split_merge was deprecated in JAX v0.6.0 and will be removed in" - " JAX v0.7.0." - ), - jax._src.util.split_merge, + "split_merge was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "subvals": ( - ( - "subvals was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." - ), - jax._src.util.subvals, + "subvals was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "toposort": ( - ( - "toposort was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." - ), - jax._src.util.toposort, + "toposort was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "unzip2": ( - ( - "unzip2 was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." - ), - jax._src.util.unzip2, + "unzip2 was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "wrap_name": ( - ( - "wrap_name was deprecated in JAX v0.6.0 and will be removed in JAX" - " v0.7.0." 
- ), - jax._src.util.wrap_name, + "wrap_name was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), "wraps": ( - "wraps was deprecated in JAX v0.6.0 and will be removed in JAX v0.7.0.", - jax._src.util.wraps, + "wraps was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, ), } From ffabd3ea8f59b57136635ac4a414d8a53767404b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 18 Jun 2025 11:18:03 -0700 Subject: [PATCH 1742/1769] [Mosaic TPU] Make the backward-compatibility libtpu condition stricter A change was submitted today with today's date. We should always leave at least one day of margin. PiperOrigin-RevId: 772993176 --- jax/_src/tpu_custom_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index f3f05c1b251f..d0070f5a73ae 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -64,7 +64,7 @@ # We should also add a TODO to remove the conditional one month later. def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: # TODO: b/423649694 - remove the forward compatibility check after 2025-07-18 - if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 6, 18): + if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 6, 19): return 4 return None From 4f437a3e86f9510cbdb6ecba35006322f0fa76b2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Jun 2025 18:54:07 +0000 Subject: [PATCH 1743/1769] Remove some dangling references from the docs. --- docs/jax.dlpack.rst | 1 - docs/jax.tree_util.rst | 1 - 2 files changed, 2 deletions(-) diff --git a/docs/jax.dlpack.rst b/docs/jax.dlpack.rst index 4a679052775e..eba3ecf62954 100644 --- a/docs/jax.dlpack.rst +++ b/docs/jax.dlpack.rst @@ -9,4 +9,3 @@ :toctree: _autosummary from_dlpack - to_dlpack \ No newline at end of file diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index c89b777ca548..a17a947af320 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -13,7 +13,6 @@ List of Functions Partial all_leaves - build_tree register_dataclass register_pytree_node register_pytree_node_class From 364f004921bf3db7c6e13736e5a5817560bf11de Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 18 Jun 2025 11:59:02 -0700 Subject: [PATCH 1744/1769] [Mosaic GPU] Fix minor error in matmul test. PiperOrigin-RevId: 773008890 --- tests/pallas/mosaic_gpu_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0a9dc3b5499b..cfc71babef5a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2689,7 +2689,7 @@ def test_simple_collective_matmul(self, m_n_k, swizzle, dtype, lhs_tmem): ) def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, - acc_tmem, scratch_smem, tma_barrier, mma_barrier, + scratch_smem, acc_tmem, tma_barrier, mma_barrier, cluster_barrier, lhs_tmem_ref): cluster_idx = lax.axis_index("x") slice_lhs = pl.ds(cluster_idx * block_lhs_shape[0], block_lhs_shape[0]) From 59034e8eab3e4e061f92c78b7a88b465242e8092 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Jun 2025 12:40:54 -0400 Subject: [PATCH 1745/1769] Run pyupgrade --py311-plus. 
--- jax/_src/attrs.py | 3 ++- jax/_src/clusters/cloud_tpu_cluster.py | 5 ++--- jax/_src/compiler.py | 3 ++- jax/_src/core.py | 4 ++-- jax/_src/cudnn/scaled_matmul_stablehlo.py | 5 ++--- jax/_src/custom_partitioning.py | 3 ++- jax/_src/export/serialization_generated.py | 22 +++++++++---------- jax/_src/frozen_dict.py | 3 ++- jax/_src/interpreters/batching.py | 2 +- jax/_src/interpreters/mlir.py | 4 ++-- jax/_src/jaxpr_util.py | 3 ++- jax/_src/lax/control_flow/solves.py | 3 ++- jax/_src/linear_util.py | 3 ++- jax/_src/nn/functions.py | 4 ++-- jax/_src/numpy/array_methods.py | 3 ++- jax/_src/numpy/einsum.py | 3 ++- jax/_src/numpy/error.py | 3 ++- jax/_src/numpy/indexing.py | 3 ++- jax/_src/numpy/sorting.py | 2 +- jax/_src/pallas/core.py | 3 ++- jax/_src/pallas/cost_estimate.py | 3 ++- jax/_src/pallas/fuser/block_spec.py | 3 ++- jax/_src/pallas/fuser/fusible_dtype.py | 3 ++- jax/_src/pallas/fuser/fusion.py | 3 ++- jax/_src/pallas/hlo_interpreter.py | 3 ++- jax/_src/pallas/mosaic/core.py | 3 ++- jax/_src/pallas/mosaic/interpret.py | 3 ++- .../pallas/mosaic/pallas_call_registration.py | 4 ++-- jax/_src/pallas/mosaic/pipeline.py | 16 +++++++------- jax/_src/pallas/mosaic/verification.py | 3 ++- jax/_src/pallas/pallas_call.py | 2 +- jax/_src/pallas/primitives.py | 3 ++- jax/_src/pallas/triton/lowering.py | 3 ++- jax/_src/state/indexing.py | 3 ++- jax/_src/state/types.py | 3 ++- jax/_src/state/utils.py | 2 +- jax/_src/xla_bridge.py | 3 ++- .../_private_mm/examples/example_overlap.py | 3 ++- .../_private_mm/examples/example_pp.py | 3 ++- jax/experimental/_private_mm/mm.py | 2 +- .../array_serialization/tensorstore_impl.py | 3 ++- jax/experimental/colocated_python/api.py | 5 +++-- jax/experimental/colocated_python/func.py | 3 ++- .../colocated_python/func_backend.py | 2 +- jax/experimental/colocated_python/obj.py | 9 ++++---- .../colocated_python/obj_backend.py | 3 ++- .../colocated_python/serialization.py | 3 ++- jax/experimental/mosaic/gpu/core.py | 3 ++- .../mosaic/gpu/dialect_lowering.py | 5 +++-- .../mosaic/gpu/fragmented_array.py | 5 +++-- jax/experimental/mosaic/gpu/profiler.py | 3 ++- jax/experimental/ode.py | 2 +- .../pallas/ops/tpu/random/philox.py | 2 +- .../pallas/ops/tpu/random/prng_utils.py | 2 +- .../pallas/ops/tpu/random/threefry.py | 2 +- jax/experimental/roofline/roofline.py | 13 ++++++----- jax/experimental/roofline/rooflines.py | 2 +- jax/experimental/serialize_executable.py | 2 +- jax/experimental/source_mapper/common.py | 3 ++- .../source_mapper/generate_map.py | 3 ++- jax/extend/linear_util.py | 2 +- jax/tools/pgo_nsys_converter.py | 2 +- tests/array_extensibility_test.py | 3 ++- tests/export_test.py | 2 +- tests/linalg_test.py | 2 +- tests/memories_test.py | 6 ++--- tests/mosaic/gpu_dialect_test.py | 2 +- tests/pallas/mosaic_gpu_test.py | 4 ++-- tests/pallas/ops_test.py | 3 ++- tests/pallas/tpu_gmm_test.py | 8 +++---- tests/pallas/tpu_pallas_interpret_test.py | 4 ++-- tests/pallas/tpu_pallas_test.py | 2 +- tests/pgle_test.py | 2 +- tests/roofline_test.py | 4 ++-- tests/unary_ops_accuracy_test.py | 7 +++--- tests/xla_metadata_test.py | 8 +++---- 76 files changed, 166 insertions(+), 125 deletions(-) diff --git a/jax/_src/attrs.py b/jax/_src/attrs.py index 7ad6f0e52d32..6ace51a091e4 100644 --- a/jax/_src/attrs.py +++ b/jax/_src/attrs.py @@ -14,7 +14,8 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any +from collections.abc import Callable from jax._src import core from jax._src import source_info_util diff --git 
a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index 4807a7194c5b..f45af0e76dd1 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -14,7 +14,6 @@ from __future__ import annotations -from typing import Optional import logging import os import re @@ -55,7 +54,7 @@ def get_metadata(key): raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries") return api_resp.text, api_resp.status_code -def get_tpu_env_value_from_metadata(key) -> Optional[str]: +def get_tpu_env_value_from_metadata(key) -> str | None: metadata_value = None tpu_env_data = get_metadata('tpu-env')[0] key_value_pairs = tpu_env_data.split('\n') @@ -68,7 +67,7 @@ def get_tpu_env_value_from_metadata(key) -> Optional[str]: metadata_value = value.strip().strip("'") return metadata_value -def get_tpu_env_value(key) -> Optional[str]: +def get_tpu_env_value(key) -> str | None: # First try to get the value from the environment. value = os.environ.get(key, None) if value is None: diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index e9312f560190..2288278b41bd 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -21,7 +21,8 @@ from functools import partial import logging import time -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import warnings from jax._src import cache_key as cache_key_type diff --git a/jax/_src/core.py b/jax/_src/core.py index 9ec78c1972fa..f03fc96b76e6 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -675,7 +675,7 @@ def is_valid(self): return not self._invalidated def __repr__(self): - return '{}'.format(self.__class__.__name__) + return f'{self.__class__.__name__}' def process_call(self, call_primitive, f, tracers, params): msg = (f"{type(self)} must override process_call to handle call-like " @@ -2214,7 +2214,7 @@ def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: avals = tuple(a for a in avals if a is not abstract_token) if not avals: return frozenset() - vma, *vmas = [a.vma for a in avals] + vma, *vmas = (a.vma for a in avals) if not all(vma == vma_ for vma_ in vmas): raise ValueError( f'Primitive {prim_name} requires varying manual axes ' diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 8598ca5f8920..49238f09c46a 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -16,7 +16,6 @@ import json import operator from functools import partial, reduce -from typing import List # Third-party imports import jax @@ -591,7 +590,7 @@ def scaled_dot_general_transpose_lhs( def scaled_dot_general_transpose_rhs( g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike, - configs: List[BlockScaleConfig] + configs: list[BlockScaleConfig] ): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) @@ -686,7 +685,7 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers): def scaled_dot_general_wrapper( lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32, - configs: List[BlockScaleConfig] | None=None, + configs: list[BlockScaleConfig] | None=None, ): if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16): msg = ('Only support preferred_element_type in (f32, bf16, f16), but got ' diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 9d57c6b0c038..4ba07d3652b7 100644 --- 
a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -21,7 +21,8 @@ from functools import partial import inspect -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import weakref import numpy as np diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 34211c1ebe54..5a3ba8f72322 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -21,7 +21,7 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class PyTreeDefKind(object): +class PyTreeDefKind: leaf = 0 none = 1 tuple = 2 @@ -30,12 +30,12 @@ class PyTreeDefKind(object): custom = 5 -class AbstractValueKind(object): +class AbstractValueKind: shapedArray = 0 abstractToken = 1 -class DType(object): +class DType: bool = 0 i8 = 1 i16 = 2 @@ -68,18 +68,18 @@ class DType(object): key_unsafe_rbg = 29 -class ShardingKind(object): +class ShardingKind: unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind(object): +class DisabledSafetyCheckKind: platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef(object): +class PyTreeDef: __slots__ = ['_tab'] @classmethod @@ -214,7 +214,7 @@ def PyTreeDefEnd(builder): -class AbstractValue(object): +class AbstractValue: __slots__ = ['_tab'] @classmethod @@ -286,7 +286,7 @@ def AbstractValueEnd(builder): -class Sharding(object): +class Sharding: __slots__ = ['_tab'] @classmethod @@ -355,7 +355,7 @@ def ShardingEnd(builder): -class Effect(object): +class Effect: __slots__ = ['_tab'] @classmethod @@ -391,7 +391,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck(object): +class DisabledSafetyCheck: __slots__ = ['_tab'] @classmethod @@ -437,7 +437,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported(object): +class Exported: __slots__ = ['_tab'] @classmethod diff --git a/jax/_src/frozen_dict.py b/jax/_src/frozen_dict.py index fd01a95145f2..c110717d80b5 100644 --- a/jax/_src/frozen_dict.py +++ b/jax/_src/frozen_dict.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Iterator, Mapping, TypeVar +from typing import Any, TypeVar +from collections.abc import Iterator, Mapping K = TypeVar("K") V = TypeVar("V") diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index ddf0f6aa92fd..49c853e33040 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -504,7 +504,7 @@ def process_primitive(self, p, tracers, params): with core.set_current_trace(self.parent_trace): val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params) else: - raise NotImplementedError("Batching rule for '{}' not implemented".format(p)) + raise NotImplementedError(f"Batching rule for '{p}' not implemented") src = source_info_util.current() if p.multiple_results: with core.set_current_trace(self.parent_trace): # val_out may be lazy map diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 43e8047071b9..840ec336a330 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2085,8 +2085,8 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]: """The lowering platforms for the current eqn""" - return tuple((_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or - ctx.platforms or ctx.module_context.platforms)) + return tuple(_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or + ctx.platforms or ctx.module_context.platforms) def lower_per_platform(ctx: LoweringRuleContext, diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index b26c51483773..81ffb2730b1f 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -22,7 +22,8 @@ import itertools import json import types -from typing import Any, Iterator, Union +from typing import Any, Union +from collections.abc import Iterator from jax._src import core from jax._src import util diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index f34c98c6aaae..cf8769ec9e31 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -15,7 +15,8 @@ import collections from functools import partial import operator -from typing import Any, Callable +from typing import Any +from collections.abc import Callable from jax.tree_util import (tree_flatten, treedef_children, tree_leaves, tree_unflatten, treedef_tuple) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 41af7644d361..b6b0b3cce982 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -68,7 +68,8 @@ def trans1(static_arg, *dynamic_args, **kwargs): from functools import partial import re import time -from typing import Any, Hashable, NamedTuple +from typing import Any, NamedTuple +from collections.abc import Hashable import warnings import weakref diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 21ea7ac615a9..db960e842403 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -21,7 +21,7 @@ import operator import math import numpy as np -from typing import Any, List, Literal +from typing import Any, Literal import warnings import jax @@ -1364,7 +1364,7 @@ def scaled_dot_general( lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32, - configs: List[BlockScaleConfig] | None = None, + configs: list[BlockScaleConfig] | None = None, implementation: Literal['cudnn'] | None = None, ): r"""Scaled dot general operation. 
diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index e635bdbbd56d..2de9822206a2 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -26,7 +26,8 @@ import abc from functools import partial, wraps import math -from typing import Any, Callable, Sequence +from typing import Any +from collections.abc import Callable, Sequence import numpy as np diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 544e29020bf5..372b643fbc02 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -13,7 +13,8 @@ # limitations under the License. import collections -from typing import overload, Any, Callable, Sequence +from typing import overload, Any +from collections.abc import Callable, Sequence import numpy as np import opt_einsum diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py index 8af0c52566b5..cf69eb10b1a3 100644 --- a/jax/_src/numpy/error.py +++ b/jax/_src/numpy/error.py @@ -13,7 +13,8 @@ # limitations under the License. import contextlib -from typing import Literal, Sequence +from typing import Literal +from collections.abc import Sequence import numpy as np diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 573352135806..31ad0ba4ed86 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -18,7 +18,8 @@ from functools import partial import operator import string -from typing import Any, NamedTuple, Sequence +from typing import Any, NamedTuple +from collections.abc import Sequence import numpy as np diff --git a/jax/_src/numpy/sorting.py b/jax/_src/numpy/sorting.py index be8f42ce6145..d8d1f7751d67 100644 --- a/jax/_src/numpy/sorting.py +++ b/jax/_src/numpy/sorting.py @@ -13,7 +13,7 @@ # limitations under the License. 
 from functools import partial
-from typing import Sequence
+from collections.abc import Sequence

 import numpy as np
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 36378b8d1cc1..1aae4452d32f 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -24,7 +24,8 @@
 import functools
 import itertools
 import threading
-from typing import Any, ClassVar, Hashable, Literal, Protocol, TypeAlias, Union, runtime_checkable
+from typing import Any, ClassVar, Literal, Protocol, TypeAlias, Union, runtime_checkable
+from collections.abc import Hashable

 import jax
 from jax._src import api_util
diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py
index 93bcf5348b24..ad238bdf475d 100644
--- a/jax/_src/pallas/cost_estimate.py
+++ b/jax/_src/pallas/cost_estimate.py
@@ -15,7 +15,8 @@
 import dataclasses
 import functools
 import math
-from typing import Any, Sequence
+from typing import Any
+from collections.abc import Sequence

 import jax
 from jax._src import api_util
diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py
index 56f75699735c..e6ca4dddc61b 100644
--- a/jax/_src/pallas/fuser/block_spec.py
+++ b/jax/_src/pallas/fuser/block_spec.py
@@ -21,7 +21,8 @@
 import enum
 import functools
 import threading
-from typing import Any, Callable, Protocol, Sequence
+from typing import Any, Protocol
+from collections.abc import Callable, Sequence

 import jax
 from jax import lax
diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py
index 152b20ff66ea..2d2c8aac2967 100644
--- a/jax/_src/pallas/fuser/fusible_dtype.py
+++ b/jax/_src/pallas/fuser/fusible_dtype.py
@@ -18,7 +18,8 @@
 import dataclasses
 import functools
 import itertools as it
-from typing import Any, Sequence, TypeVar
+from typing import Any, TypeVar
+from collections.abc import Sequence

 import jax
 from jax._src import api_util
diff --git a/jax/_src/pallas/fuser/fusion.py b/jax/_src/pallas/fuser/fusion.py
index eff8c36ddb08..6319722a9823 100644
--- a/jax/_src/pallas/fuser/fusion.py
+++ b/jax/_src/pallas/fuser/fusion.py
@@ -17,7 +17,8 @@
 from __future__ import annotations

 import dataclasses
-from typing import Any, Callable, Generic, ParamSpec, TypeVar
+from typing import Any, Generic, ParamSpec, TypeVar
+from collections.abc import Callable

 import jax
 from jax._src import util
diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py
index 2568ea8b74a1..038d93d3f9e2 100644
--- a/jax/_src/pallas/hlo_interpreter.py
+++ b/jax/_src/pallas/hlo_interpreter.py
@@ -27,7 +27,8 @@
 from collections.abc import Iterable, Sequence
 from functools import reduce, partial
 import itertools
-from typing import Any, Callable
+from typing import Any
+from collections.abc import Callable

 import jax
 from jax import lax
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index 835b1cf68244..a63df1ca8b42 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -20,7 +20,8 @@
 import dataclasses
 import enum
 import functools
-from typing import Any, ClassVar, Literal, Mapping
+from typing import Any, ClassVar, Literal
+from collections.abc import Mapping

 import jax
 from jax._src import core as jax_core
diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py
index 6d68a2a7e931..de25d739c554 100644
--- a/jax/_src/pallas/mosaic/interpret.py
+++ b/jax/_src/pallas/mosaic/interpret.py
@@ -20,7 +20,8 @@
 import itertools
 import math
 import threading
-from typing import Any, Callable, Literal, cast
+from typing import Any, Literal, cast
+from collections.abc import Callable

 import jax
 from jax import lax
diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index fc01b696dcd8..8944e06443e9 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -19,7 +19,7 @@
 from collections.abc import Sequence
 import os
 import tempfile
-from typing import List, cast
+from typing import cast

 import jax
 from jax import dtypes
@@ -240,7 +240,7 @@ def _maybe_cast_inputs(*args):
   else:
     mosaic_cost_estimate = None
   if input_memory_spaces is None and output_memory_spaces is not None:
-    input_memory_spaces_list: List[tpu_custom_call.MemorySpace | None] = [
+    input_memory_spaces_list: list[tpu_custom_call.MemorySpace | None] = [
        None,
    ] * len(ctx.avals_in)
    for input_output_alias in input_output_aliases:
diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py
index afa1972128d0..c766d6ec16b5 100644
--- a/jax/_src/pallas/mosaic/pipeline.py
+++ b/jax/_src/pallas/mosaic/pipeline.py
@@ -287,11 +287,11 @@ def init_slots(self):
     """Initialize slot indices."""
     raise NotImplementedError()

-  def swap_slots(self, predicate: bool = True) -> "BufferedRefBase":
+  def swap_slots(self, predicate: bool = True) -> BufferedRefBase:
     """Switch to the next slot."""
     raise NotImplementedError()

-  def load_slots(self) -> "BufferedRefBase":
+  def load_slots(self) -> BufferedRefBase:
     """Load slot information into registers."""
     raise NotImplementedError()

@@ -371,7 +371,7 @@ def bind_existing_ref(self, window_ref, indices):
     del window_ref, indices
     return self

-  def with_spec(self, spec: pl.BlockSpec) -> 'BufferedRefBase':
+  def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
     """Returns a new BufferedRefBase with the given block spec."""
     raise NotImplementedError()

@@ -535,13 +535,13 @@ def compute_index(self):
   def memory_space(self):
     return self.spec.memory_space

-  def with_spec(self, spec: pl.BlockSpec) -> 'BufferedRef':
+  def with_spec(self, spec: pl.BlockSpec) -> BufferedRef:
     """Returns a new BufferedRef with the given block spec."""
     return dataclasses.replace(self, _spec=spec)

   def with_slot_index(
       self, slot_index: int | jax.Array | None
-  ) -> "BufferedRef":
+  ) -> BufferedRef:
     """Returns a new BufferedRef with the given slot index."""
     return dataclasses.replace(self, _current_slot_reg=slot_index)

@@ -611,7 +611,7 @@ def init_slots(self):
     if self.swap is not None:
       self.swap[0] = False

-  def swap_slots(self, predicate: bool | jax.Array = True) -> "BufferedRef":
+  def swap_slots(self, predicate: bool | jax.Array = True) -> BufferedRef:
     if self.memory_space == VMEM: return self
     if self.swap is not None:
@@ -627,7 +627,7 @@ def swap_slots(self, predicate: bool | jax.Array = True) -> "BufferedRef":
       self.current_slot[0] = new_current_slot
     return self

-  def load_slots(self) -> "BufferedRef":
+  def load_slots(self) -> BufferedRef:
     """Load slot information into registers."""
     if self.memory_space == VMEM: return self
@@ -1042,7 +1042,7 @@ def _end():

   def swap_slots(
       self, buffered_ref, hbm_ref, schedule=None
-  ) -> "BufferedRefBase":
+  ) -> BufferedRefBase:
     # All the copies into and out of BufferedRefs are done by direct
     # calls to the `copy_in` and `copy_out` methods in the pipeline
     # loop. To determine if the BufferedRef needs a swap of slots, we
diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py
index f45f36a473e9..d5266826f909 100644
--- a/jax/_src/pallas/mosaic/verification.py
+++ b/jax/_src/pallas/mosaic/verification.py
@@ -18,7 +18,8 @@
 import itertools
 import math
 import textwrap
-from typing import Any, Sequence
+from typing import Any
+from collections.abc import Sequence

 from jax import lax
 from jax._src import core as jax_core
 from jax._src import tree_util
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index d95b0ebd0ddb..2a47ce2f8cf0 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -1513,7 +1513,7 @@ def pallas_call(
     interpret: Any = False,
     name: str | None = None,
     compiler_params: (
-        Mapping[Backend, "CompilerParams"] | "CompilerParams" | None
+        Mapping[Backend, CompilerParams] | CompilerParams | None
     ) = None,
     cost_estimate: CostEstimate | None = None,
     backend: Backend | None = None,
diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index bff0a55e272c..a8fdf04d9d9f 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -20,7 +20,8 @@
 import functools
 import string
 from collections.abc import Hashable
-from typing import Any, Callable, Sequence
+from typing import Any
+from collections.abc import Callable, Sequence

 import jax
 from jax import lax
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index bd70dc8d470c..e2fb6705de4c 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -21,7 +21,8 @@
 import functools
 import math
 import operator
-from typing import Any, Hashable, TypeVar
+from typing import Any, TypeVar
+from collections.abc import Hashable

 import jax
 from jax import lax
diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
index adca41f82f7c..a0d1d85d09b4 100644
--- a/jax/_src/state/indexing.py
+++ b/jax/_src/state/indexing.py
@@ -17,7 +17,8 @@
 from __future__ import annotations

 import dataclasses
-from typing import Any, Sequence, Union
+from typing import Any, Union
+from collections.abc import Sequence

 from jax._src import core
 from jax._src import pretty_printer as pp
diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py
index 7ca1d8e48f9e..95e298b532e8 100644
--- a/jax/_src/state/types.py
+++ b/jax/_src/state/types.py
@@ -18,7 +18,8 @@
 from collections.abc import Sequence
 import dataclasses
 import math
-from typing import Any, Callable, Protocol, Union
+from typing import Any, Protocol, Union
+from collections.abc import Callable

 from jax._src import core
 from jax._src import dtypes
diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py
index 2dd57dcde0ca..a47d0c9b6f7d 100644
--- a/jax/_src/state/utils.py
+++ b/jax/_src/state/utils.py
@@ -14,7 +14,7 @@
 """Utilities for tracing stateful functions."""

 from functools import partial
-from typing import Callable
+from collections.abc import Callable

 import jax
 from jax._src import core
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 1b95806c37c6..bb77b4f1ff0f 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -31,7 +31,8 @@
 import pkgutil
 import platform as py_platform
 import threading
-from typing import Any, Sequence, Union
+from typing import Any, Union
+from collections.abc import Sequence
 import warnings

 from jax._src import config
diff --git a/jax/experimental/_private_mm/examples/example_overlap.py b/jax/experimental/_private_mm/examples/example_overlap.py
index 022eb3293dcc..f3c3726ec347 100644
--- a/jax/experimental/_private_mm/examples/example_overlap.py
+++ b/jax/experimental/_private_mm/examples/example_overlap.py
@@ -14,7 +14,8 @@
 """An example showcasing overlap on a (forward-only) PP-like workload."""

 from dataclasses import dataclass
-from typing import Any, Callable
+from typing import Any
+from collections.abc import Callable
 import time

 import numpy as np
diff --git a/jax/experimental/_private_mm/examples/example_pp.py b/jax/experimental/_private_mm/examples/example_pp.py
index b43d1c743c28..846d96cb34a9 100644
--- a/jax/experimental/_private_mm/examples/example_pp.py
+++ b/jax/experimental/_private_mm/examples/example_pp.py
@@ -15,7 +15,8 @@

 from dataclasses import dataclass
 from functools import cached_property, partial
-from typing import Any, Callable
+from typing import Any
+from collections.abc import Callable

 import numpy as np
diff --git a/jax/experimental/_private_mm/mm.py b/jax/experimental/_private_mm/mm.py
index f47724ce6ec4..b108fb3e2e35 100644
--- a/jax/experimental/_private_mm/mm.py
+++ b/jax/experimental/_private_mm/mm.py
@@ -16,7 +16,7 @@

 from dataclasses import dataclass
 from functools import cached_property, lru_cache, partial, wraps
-from typing import Callable
+from collections.abc import Callable

 import jax
 import jax.numpy as jnp
diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py
index 81b4a5177029..99f7a137f6f9 100644
--- a/jax/experimental/array_serialization/tensorstore_impl.py
+++ b/jax/experimental/array_serialization/tensorstore_impl.py
@@ -18,7 +18,8 @@
 import os
 from os import PathLike
 import re
-from typing import Any, Awaitable, Callable, Sequence
+from typing import Any
+from collections.abc import Awaitable, Callable, Sequence
 import math
 import logging
diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py
index a79bd464fb92..363a2987cb9c 100644
--- a/jax/experimental/colocated_python/api.py
+++ b/jax/experimental/colocated_python/api.py
@@ -16,7 +16,8 @@
 from __future__ import annotations

 import collections
-from typing import Any, Callable, Sequence, Type, overload
+from typing import Any, overload
+from collections.abc import Callable, Sequence

 import jax
 from jax._src import api_util
@@ -117,6 +118,6 @@ def colocated_python(fun: Callable[..., Any]):
   )


-def colocated_python_class(cls: Type[object]) -> Type[object]:
+def colocated_python_class(cls: type[object]) -> type[object]:
   """Executes the given Python class methods on the same devices as the arguments."""
   return wrap_class(cls, api_util.fun_sourceinfo(cls))
diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py
index d8fa003b775a..9ad84c7e06ad 100644
--- a/jax/experimental/colocated_python/func.py
+++ b/jax/experimental/colocated_python/func.py
@@ -19,7 +19,8 @@
 import inspect
 import random
 import threading
-from typing import Any, Callable, Sequence
+from typing import Any
+from collections.abc import Callable, Sequence

 import jax
 from jax._src import api
diff --git a/jax/experimental/colocated_python/func_backend.py b/jax/experimental/colocated_python/func_backend.py
index aa514015004d..4f1443da4b17 100644
--- a/jax/experimental/colocated_python/func_backend.py
+++ b/jax/experimental/colocated_python/func_backend.py
@@ -16,7 +16,7 @@
 from __future__ import annotations

 import threading
-from typing import Sequence
+from collections.abc import Sequence

 import jax
diff --git a/jax/experimental/colocated_python/obj.py b/jax/experimental/colocated_python/obj.py
index d7d40e88f925..b962b82525fd 100644
--- a/jax/experimental/colocated_python/obj.py
+++ b/jax/experimental/colocated_python/obj.py
@@ -18,7 +18,8 @@
 import inspect
 import random
 import threading
-from typing import Any, Callable, Type
+from typing import Any
+from collections.abc import Callable

 import jax
 from jax._src import api_util
@@ -70,7 +71,7 @@ def _update_instance_devices(


 def _make_method(
-    cls: Type[object],
+    cls: type[object],
     cls_sourceinfo: str | None,
     uid: int,
     init_args: tuple[Any, ...],
@@ -114,9 +115,9 @@ def method_wrapper(*args, **kwargs):


 def wrap_class(
-    cls: Type[object],
+    cls: type[object],
     cls_sourceinfo: str | None,
-) -> Type[object]:
+) -> type[object]:

   class WrappedClass:

     @wraps(cls.__init__)
diff --git a/jax/experimental/colocated_python/obj_backend.py b/jax/experimental/colocated_python/obj_backend.py
index ffa04a007818..eb3b2c4049d9 100644
--- a/jax/experimental/colocated_python/obj_backend.py
+++ b/jax/experimental/colocated_python/obj_backend.py
@@ -17,7 +17,8 @@

 import dataclasses
 import threading
-from typing import Any, Callable
+from typing import Any
+from collections.abc import Callable


 @dataclasses.dataclass(frozen=True)
diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py
index 74c34a495920..83a12277a872 100644
--- a/jax/experimental/colocated_python/serialization.py
+++ b/jax/experimental/colocated_python/serialization.py
@@ -19,7 +19,8 @@
 import collections
 import functools
 import io
-from typing import Any, Callable, Sequence
+from typing import Any
+from collections.abc import Callable, Sequence

 try:
   import cloudpickle  # type: ignore[import-not-found]
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index 5c19d5ad0fb6..38d030fa765c 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -23,7 +23,8 @@
 import os
 import pathlib
 import time
-from typing import Any, Callable, Generic, TypeVar
+from typing import Any, Generic, TypeVar
+from collections.abc import Callable
 import weakref
 import itertools
diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index c01038d64088..932ccc1a7980 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -20,7 +20,8 @@
 import itertools
 import math
 import operator
-from typing import Any, Sequence, Type, cast
+from typing import Any, cast
+from collections.abc import Sequence

 from jax._src import lib as jaxlib
 from jax._src.interpreters import mlir as mlir_interpreter
@@ -198,7 +199,7 @@ def unwrap_transformed_memref(


 def _register_lowering(
-    op: str | Type[ir.OpView] | None
+    op: str | type[ir.OpView] | None
 ) -> Callable[[MlirLoweringRule], MlirLoweringRule]:
   def wrapper(f):
     if op is not None:
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index d8b18dfb2ab9..1c61807c83ad 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -21,7 +21,8 @@
 import functools
 import itertools
 import math
-from typing import Generator, Iterable, Protocol, Sequence, TypeVar
+from typing import Protocol, TypeVar
+from collections.abc import Generator, Iterable, Sequence

 import jax
 import jax.experimental.mosaic.gpu as mgpu
@@ -163,7 +164,7 @@ def tile_dimension(self, dim: int) -> tuple[bool, ...]:
     strides[dim] = 0
     return tuple(s == 0 for s in self.tile_strides(tuple(strides)))

-  def remove_dimension(self, dim: int) -> "Tiling":
+  def remove_dimension(self, dim: int) -> Tiling:
     """Returns a tiling with the given dimension removed."""
     tiling_rank = len(self.tiles[0])
     if dim < 0 or dim >= tiling_rank:
diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py
index a048903428ae..b4d06aba1671 100644
--- a/jax/experimental/mosaic/gpu/profiler.py
+++ b/jax/experimental/mosaic/gpu/profiler.py
@@ -17,7 +17,8 @@
 import itertools
 import json
 import math
-from typing import Callable, ParamSpec, TypeAlias, TypeVar
+from typing import ParamSpec, TypeAlias, TypeVar
+from collections.abc import Callable
 import warnings

 import jax
diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py
index db7865124687..04dc380e79bc 100644
--- a/jax/experimental/ode.py
+++ b/jax/experimental/ode.py
@@ -28,7 +28,7 @@

 from functools import partial
 import operator as op
-from typing import Callable
+from collections.abc import Callable

 import jax
 from jax import api_util
diff --git a/jax/experimental/pallas/ops/tpu/random/philox.py b/jax/experimental/pallas/ops/tpu/random/philox.py
index 4c43f5c7c2ff..9c1c3a829510 100644
--- a/jax/experimental/pallas/ops/tpu/random/philox.py
+++ b/jax/experimental/pallas/ops/tpu/random/philox.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Implementation of the Philox PRNG as a Pallas kernel."""
-from typing import Sequence
+from collections.abc import Sequence
 import jax
 from jax import typing
 from jax._src import prng
diff --git a/jax/experimental/pallas/ops/tpu/random/prng_utils.py b/jax/experimental/pallas/ops/tpu/random/prng_utils.py
index e5a3ac155eea..3014c7748f22 100644
--- a/jax/experimental/pallas/ops/tpu/random/prng_utils.py
+++ b/jax/experimental/pallas/ops/tpu/random/prng_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Helper functions for PRNG kernels."""
-from typing import Sequence
+from collections.abc import Sequence
 from jax import lax
 import jax.numpy as jnp
diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py
index 5fdac5782349..71a314e09b2d 100644
--- a/jax/experimental/pallas/ops/tpu/random/threefry.py
+++ b/jax/experimental/pallas/ops/tpu/random/threefry.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
"""Implementation of the Threefry PRNG as a Pallas kernel.""" -from typing import Sequence +from collections.abc import Sequence import jax from jax._src import prng from jax.experimental import pallas as pl diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py index dbfe3c983cc0..b711fd62e069 100644 --- a/jax/experimental/roofline/roofline.py +++ b/jax/experimental/roofline/roofline.py @@ -14,7 +14,8 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Callable, Protocol, Sequence +from typing import Any, Protocol +from collections.abc import Callable, Sequence import numpy as np import jax.numpy as jnp @@ -56,7 +57,7 @@ class RooflineShape: dtype: np.dtype @classmethod - def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape": + def from_aval(cls, aval: core.AbstractValue) -> RooflineShape: if not isinstance(aval, core.ShapedArray): raise TypeError(f"Expected ShapedArray, got {type(aval)}.") if not isinstance(aval.dtype, np.dtype): @@ -87,10 +88,10 @@ class RooflineResult: unfused_hbm_bytes: int = 0 @classmethod - def zeros(cls) -> "RooflineResult": + def zeros(cls) -> RooflineResult: return cls() - def __add__(self, other: "RooflineResult") -> "RooflineResult": + def __add__(self, other: RooflineResult) -> RooflineResult: def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)} @@ -104,7 +105,7 @@ def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: unfused_hbm_bytes=self.unfused_hbm_bytes + other.unfused_hbm_bytes, ) - def __mul__(self, constant: int | float) -> "RooflineResult": + def __mul__(self, constant: int | float) -> RooflineResult: return RooflineResult( flops=int(self.flops * constant), unfused_flops=int(self.unfused_flops * constant), @@ -115,7 +116,7 @@ def __mul__(self, constant: int | float) -> "RooflineResult": unfused_hbm_bytes=int(self.unfused_hbm_bytes * constant), ) - def __rmul__(self, constant: int | float) -> "RooflineResult": + def __rmul__(self, constant: int | float) -> RooflineResult: return self.__mul__(constant) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 278c9d3d39ff..4941f95e8e1c 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -14,7 +14,7 @@ from collections import defaultdict from dataclasses import replace import itertools as it -from typing import Sequence +from collections.abc import Sequence import numpy as np from jax._src import ad_util diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 7c112f56ef42..e1d068ec789f 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -20,7 +20,7 @@ import jax from jax._src.lib import xla_client as xc -from typing import Sequence +from collections.abc import Sequence def serialize(compiled: jax.stages.Compiled): diff --git a/jax/experimental/source_mapper/common.py b/jax/experimental/source_mapper/common.py index f7d10bc88f10..471fc0a7a877 100644 --- a/jax/experimental/source_mapper/common.py +++ b/jax/experimental/source_mapper/common.py @@ -15,7 +15,8 @@ import contextlib import dataclasses import re -from typing import Any, Protocol, Sequence +from typing import Any, Protocol +from collections.abc import Sequence from absl import flags import jax diff --git 
a/jax/experimental/source_mapper/generate_map.py b/jax/experimental/source_mapper/generate_map.py index 76fd0f744463..0066e35285fb 100644 --- a/jax/experimental/source_mapper/generate_map.py +++ b/jax/experimental/source_mapper/generate_map.py @@ -14,7 +14,8 @@ """Generates source maps for JAX functions.""" import os import tempfile -from typing import Sequence, Protocol +from typing import Protocol +from collections.abc import Sequence from jax.experimental.source_mapper import common diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 0cf9a013a9e4..ad67f6ac8f73 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -15,7 +15,7 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from typing import Callable +from collections.abc import Callable from jax._src.linear_util import ( StoreException as StoreException, diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 10209c9a85ba..3e961733f435 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -64,7 +64,7 @@ m = thunk_re.search(name) if m is not None: if args.post_process: - cost_dictionary.setdefault(m.group(1), []).append((time_ns/1000.0)) + cost_dictionary.setdefault(m.group(1), []).append(time_ns/1000.0) else: protofile.write(f'costs {{ name: "{m.group(1)}" cost_us: {time_ns / 1000.0} }}\n') if args.post_process: diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 36726659c2f9..6461cb54d73f 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -13,7 +13,8 @@ # limitations under the License. import functools -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple +from collections.abc import Callable from absl.testing import absltest from absl.testing import parameterized diff --git a/tests/export_test.py b/tests/export_test.py index 829576e95c7e..ecdb470819c8 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -416,7 +416,7 @@ def f(x1, x2): exp = export.export(jax.jit(f))(x1, x2) res = exp.call(x1, x2) self.assertEqual(tree_util.tree_structure(res), - tree_util.tree_structure(((x1, x2, x1, x2)))) + tree_util.tree_structure((x1, x2, x1, x2))) self.assertEqual(type(res[0]), type(x1)) self.assertEqual(type(res[1]), type(x2)) self.assertEqual(type(res[2]), type(x1)) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 99cb66c92857..c75927b26fd8 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -14,7 +14,7 @@ from functools import partial import itertools -from typing import Iterator +from collections.abc import Iterator from unittest import skipIf import numpy as np diff --git a/tests/memories_test.py b/tests/memories_test.py index fd40330f8db8..1b56236c91c9 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -638,7 +638,7 @@ def f(): @jtu.run_on_devices('tpu') def test_ragged_copy_on_host(self): mesh = jtu.create_mesh((2,), ('x')) - sharding = jax.sharding.NamedSharding(mesh, P(('x'))) + sharding = jax.sharding.NamedSharding(mesh, P('x')) cpu_sharding = sharding.with_memory_kind('pinned_host') num_pages = 512 * 1024 @@ -648,7 +648,7 @@ def test_ragged_copy_on_host(self): def write(x): return x.at[16 * 1024:].set(0) - x = shard_map(write, mesh=mesh, in_specs=P(('x'),), out_specs=P(('x')))(x) + x = shard_map(write, mesh=mesh, in_specs=P(('x'),), out_specs=P('x'))(x) chunk_size = 8 def inner(state): @@ -670,7 +670,7 @@ def 
foo(x): return cpu_x fn = jax.jit(shard_map(foo, mesh=mesh, in_specs=P(('x'),), - out_specs=P(('x')), check_vma=False), + out_specs=P('x'), check_vma=False), out_shardings=cpu_sharding) y = fn(x) jax.block_until_ready(y) diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index c297cb676f2d..6a49181e10c8 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -14,7 +14,7 @@ # ============================================================================== """(Deviceless) tests for the Mosaic GPU MLIR dialect.""" -from typing import Callable +from collections.abc import Callable from absl.testing import parameterized import jax diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index cfc71babef5a..0fa38e34af21 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1514,7 +1514,7 @@ def kernel(x_ref, o_ref): jax.block_until_ready(y) jax.effects_barrier() [name] = os.listdir(tmpdir) - with open(os.path.join(tmpdir, name), "r") as f: + with open(os.path.join(tmpdir, name)) as f: data = f.read() self.assertEqual(data.count('"name": "add"'), 2) self.assertEqual(data.count('"name": "load"'), 2) @@ -1761,7 +1761,7 @@ def kernel(x_ref, y_ref, o_ref128, ref_union, o_smem): plgpu.commit_smem() plgpu.copy_smem_to_gmem(o_smem, o_ref128) - x, y = [jnp.arange(128).astype(jnp.float32) for _ in range(2)] + x, y = (jnp.arange(128).astype(jnp.float32) for _ in range(2)) np.testing.assert_array_equal(kernel(x, y), x + y) @parameterized.parameters(1, 2, 3) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e951a9fda827..162152dc2a3f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -17,7 +17,8 @@ import itertools import math import sys -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import unittest from absl.testing import absltest diff --git a/tests/pallas/tpu_gmm_test.py b/tests/pallas/tpu_gmm_test.py index 7bc698794f09..6d8b0a244edb 100644 --- a/tests/pallas/tpu_gmm_test.py +++ b/tests/pallas/tpu_gmm_test.py @@ -203,10 +203,10 @@ def gmm_test( ): seed = data.draw(seed_strategy()) num_groups, _ = data.draw(group_strategy(max_stride=1)) - lhs_dtype, rhs_dtype, out_dtype = [ + lhs_dtype, rhs_dtype, out_dtype = ( data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) for _ in range(3) - ] + ) transpose_rhs = data.draw(hps.booleans()) key = jax.random.key(seed) @@ -293,10 +293,10 @@ def test_gmm_sharded_groups( ): seed = data.draw(seed_strategy()) num_groups, group_stride = data.draw(group_strategy()) - lhs_dtype, rhs_dtype, out_dtype = [ + lhs_dtype, rhs_dtype, out_dtype = ( data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) for _ in range(3) - ] + ) key = jax.random.key(seed) k1, k2 = jax.random.split(key, 2) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 5fafb3007993..725fdaa49c3d 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -38,7 +38,7 @@ jax.config.update('jax_threefry_partitionable', True) -class CountStoreCallbacksContext(object): +class CountStoreCallbacksContext: """Wraps the I/O callback `store` into a callback that counts the number of calls to `store`.""" def __init__(self): @@ -69,7 +69,7 @@ class ProcessedGridPoint(): core_id: int -class GridPointRecorderContext(object): +class GridPointRecorderContext: """Records grid points in the order in which they are 
procsessed.""" def __init__(self): diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 9a85c869a190..cbb82efaee3f 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -22,7 +22,7 @@ import math import re import sys -from typing import Callable +from collections.abc import Callable from absl.testing import absltest from absl.testing import parameterized import jax diff --git a/tests/pgle_test.py b/tests/pgle_test.py index e03e5127d023..78679adc962a 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -515,7 +515,7 @@ def get_new_hlo(): get_new_hlo.seen_files |= additions new_hlos = list(filter(lambda f: f.endswith("_gpu_after_optimizations.txt"), additions)) assert len(new_hlos) == 1 - with open(os.path.join(dump_dir, new_hlos[0]), "r") as ifile: + with open(os.path.join(dump_dir, new_hlos[0])) as ifile: return ifile.read() get_new_hlo.seen_files = set() diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 497b3f14958e..4dd8ca6c4759 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -14,7 +14,7 @@ from __future__ import annotations from functools import partial -from typing import Sequence +from collections.abc import Sequence from absl.testing import absltest import jax @@ -680,7 +680,7 @@ def test_conv_general_dilated_unfused_hbm_bytes( expected_output_shape = jnp.array( (batch / batch_group_count, num_output_channels, ow, oh) ) - expected_output_size = jnp.prod((expected_output_shape)) + expected_output_size = jnp.prod(expected_output_shape) # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py index 23a8fa7f42a1..91e8fe2d1dbf 100644 --- a/tests/unary_ops_accuracy_test.py +++ b/tests/unary_ops_accuracy_test.py @@ -14,7 +14,8 @@ """Unit test for result accuracy for unary ops.""" -from typing import Any, Callable, NamedTuple, Union +from typing import Any, NamedTuple +from collections.abc import Callable from absl.testing import absltest from absl.testing import parameterized @@ -33,8 +34,8 @@ class TolerancePair(NamedTuple): - high: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT - low: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + high: lax.Tolerance | lax.AccuracyMode = lax.AccuracyMode.DEFAULT + low: lax.Tolerance | lax.AccuracyMode = lax.AccuracyMode.DEFAULT def make_unary_test_cases( diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py index ba2120fcb7b9..8ac54fd402d6 100644 --- a/tests/xla_metadata_test.py +++ b/tests/xla_metadata_test.py @@ -200,8 +200,8 @@ def f(x): return jax.lax.cond(x < 0., sin, cos, x) hlo_lines = f.lower(1.).as_text().split("\n") - sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line] - cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line] + sin_hlo, = (line for line in hlo_lines if "stablehlo.sine" in line) + cos_hlo, = (line for line in hlo_lines if "stablehlo.cosine" in line) self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo) self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) @@ -218,8 +218,8 @@ def f(x): return jax.lax.cond(x < 0., sin, cos, x) hlo_lines = f.lower(1.).as_text().split("\n") - sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line] - cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line] + sin_hlo, = (line for line 
in hlo_lines if "stablehlo.sine" in line) + cos_hlo, = (line for line in hlo_lines if "stablehlo.cosine" in line) self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', sin_hlo) self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) From 4aa1db4f8bfd06abaae14ec10aa0ec8b31591c03 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 18 Jun 2025 12:34:56 -0700 Subject: [PATCH 1746/1769] Remove PositionalSharding from JAX now that 0.6.2 is out and next release is 0.7.0. Also remove `GSPMDSharding` export from `jax.sharding` endpoint PiperOrigin-RevId: 773021473 --- jax/_src/custom_partitioning.py | 2 +- jax/_src/debugging.py | 2 +- jax/_src/interpreters/pxla.py | 11 +- jax/_src/sharding_impls.py | 208 +------------------------------- jax/sharding.py | 27 ++--- jaxlib/jax_jit.cc | 14 ++- tests/array_test.py | 76 +----------- tests/pjit_test.py | 52 +++----- 8 files changed, 38 insertions(+), 354 deletions(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 9d57c6b0c038..a16d91644469 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -567,7 +567,7 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): return hlo_sharding if mesh.empty or not decode_shardings: assert devices is not None - return sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices) + return sharding_impls.GSPMDSharding(devices, hlo_sharding) pspec = sharding_impls.parse_flatten_op_sharding( hlo_sharding, mesh)[0] pspec = sharding_impls.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index e587d48cda68..7c09ab998195 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -474,7 +474,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, def _hlo_sharding_callback(hlo_sharding: xc.HloSharding): if mesh.empty: return callback( - sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices)) + sharding_impls.GSPMDSharding(devices, hlo_sharding)) pspec = (P() if hlo_sharding.is_manual() else parse_flatten_op_sharding(hlo_sharding, mesh)[0]) return callback(NamedSharding(mesh, pspec)) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 3f6ee973554a..03336655ceba 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -68,7 +68,7 @@ from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources, - SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding, + SingleDeviceSharding, GSPMDSharding, NamedSharding, PartitionSpec as P) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, @@ -2550,15 +2550,6 @@ def _gspmd_to_named_sharding( return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, mesh) _orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding -def _gspmd_to_positional_sharding( - out_s: GSPMDSharding, out_aval, orig_in_s: PositionalSharding - ) -> PositionalSharding: - assert isinstance(out_s, GSPMDSharding) - assert isinstance(orig_in_s, PositionalSharding) - return sharding_impls._op_sharding_to_pos_sharding( - out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind) -_orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding # type: ignore - def _gspmd_to_single_device_sharding( out_s: GSPMDSharding, 
out_aval, orig_in_s: SingleDeviceSharding ) -> SingleDeviceSharding: diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index b5592ef46a66..1d77874c9420 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -43,7 +43,7 @@ from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec -from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method +from jax._src.util import safe_zip, use_cpp_class, use_cpp_method import numpy as np config_ext = xc._xla.config @@ -354,212 +354,6 @@ def shard_shape(self, global_shape: Shape) -> Shape: PmapSharding.__module__ = 'jax.sharding' -def _op_sharding_to_pos_sharding( - op_sharding: xc.OpSharding | xc.HloSharding, - device_assignment: Sequence[xc.Device], - memory_kind: str | None = None) -> PositionalSharding: - if isinstance(op_sharding, xc.OpSharding): - op_sharding = xc.HloSharding.from_proto(op_sharding) - - if op_sharding.is_replicated(): - return PositionalSharding( - device_assignment, memory_kind=memory_kind).replicate() - - if len(op_sharding.subgroup_types()) > 1: - raise NotImplementedError( - 'Unhandled HloSharding type. Please open a bug report!' - ) - - name = device_assignment[0].platform.upper() - ids = np.array( - [DeviceIdSet(name, i) for i in op_sharding.tile_assignment_devices()] - ) - p = PositionalSharding._remake(tuple(device_assignment), ids, - memory_kind=memory_kind) - p = p.reshape(op_sharding.tile_assignment_dimensions()) - if op_sharding.replicate_on_last_tile_dim(): - p = p.replicate(-1, keepdims=False) - return p - - -@util.cache(max_size=4096, trace_context_in_key=False) -def _positional_sharding_to_xla_hlo_sharding( - self, num_dimensions: int) -> xc.HloSharding: - if self.shape == (1,) * self.ndim: - return replicated_hlo_sharding - - pbuf = xc.OpSharding() - shape = self.shape[self.ndim - num_dimensions:] # 'rank promotion' of val - set_size, = {len(device_set) for device_set in self._ids.flat} - pbuf.type = xc.OpSharding.Type.OTHER - if set_size > 1: - pbuf.last_tile_dims = [xc.OpSharding.Type.REPLICATED] - pbuf.tile_assignment_dimensions = (*shape, set_size) - else: - pbuf.tile_assignment_dimensions = shape - pbuf.tile_assignment_devices = [i for ids in self._ids.flat for i in ids] - product_of_dims = math.prod(pbuf.tile_assignment_dimensions) - num_devices = len(pbuf.tile_assignment_devices) - assert product_of_dims == num_devices, (product_of_dims, num_devices) - return xc.HloSharding.from_proto(pbuf) - - -class PositionalSharding(jsharding.Sharding): - _devices: tuple[xc.Device, ...] 
-  _memory_kind: str | None
-  _ids: np.ndarray  # dtype DeviceIdSet
-
-  def __init__(self, devices: Sequence[xc.Device] | np.ndarray,
-               *, memory_kind: str | None = None):
-    super().__init__()
-    if not isinstance(devices, np.ndarray):
-      devices = np.array(devices, dtype='object')
-    if not devices.size:
-      raise ValueError(f"{self.__class__.__name__}.__init__ requires at least "
-                       f"one device, got {devices}")
-    self._devices = tuple(devices.flat)
-    self._memory_kind = memory_kind
-    name = self._devices[0].platform.upper()
-    self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
-                         dtype='object').reshape(devices.shape)
-    self._internal_device_list = xc.DeviceList(self._devices)
-    self._memory_kind = xc.check_and_canonicalize_memory_kind(
-        self._memory_kind, self._internal_device_list)
-
-  @property
-  def shape(self):
-    return self._ids.shape
-
-  @property
-  def ndim(self):
-    return self._ids.ndim
-
-  def __repr__(self) -> str:
-    cls_name = self.__class__.__name__
-    ids = self._ids.copy()
-    platform_name = self._devices[0].platform.upper()
-    for idx, x in np.ndenumerate(ids):
-      ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))
-    body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
-                           max_line_width=100)
-    mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
-    return f'{cls_name}({body}{mem}, shape={self.shape})'
-
-  def reshape(self, *shape) -> PositionalSharding:
-    return self._remake(self._devices, self._ids.reshape(*shape),
-                        memory_kind=self.memory_kind)
-
-  def transpose(self, *axes) -> PositionalSharding:
-    return self._remake(self._devices, self._ids.transpose(*axes),
-                        memory_kind=self.memory_kind)
-  T = property(transpose)
-
-  def replicate(self, axis=None, keepdims=True) -> PositionalSharding:
-    new_ids = self._ids.sum(axis=axis, keepdims=keepdims)  # union
-    return self._remake(self._devices, new_ids,
-                        memory_kind=self.memory_kind)
-
-  def check_compatible_aval(self, aval_shape: Shape) -> None:
-    if len(aval_shape) != len(self.shape) and not self.is_fully_replicated:
-      raise ValueError(
-          f"Sharding {self} is only valid for values of rank "
-          f"{len(self.shape)}, but was applied to a value of rank "
-          f"{len(aval_shape)}")
-
-  @classmethod
-  def _remake(
-      cls, devices: tuple[xc.Device, ...], ids: np.ndarray,
-      *, memory_kind: str | None = None) -> PositionalSharding:
-    sharding = cls(devices, memory_kind=memory_kind)
-    sharding._ids = ids
-    return sharding
-
-  # Hashable
-
-  def __hash__(self) -> int:
-    if not hasattr(self, '_hash'):
-      self._hash = hash((self._internal_device_list, self.memory_kind))
-    return self._hash
-
-  def __eq__(self, other) -> bool:
-    if not isinstance(other, PositionalSharding):
-      return False
-    if self is other:
-      return True
-    all_ids_equal = np.array_equal(self._ids, other._ids)
-    mem_kind_equal = self.memory_kind == other.memory_kind
-    if self._devices is other._devices and mem_kind_equal and all_ids_equal:
-      return True
-    return (mem_kind_equal and all_ids_equal and
-            self._internal_device_list == other._internal_device_list)
-
-  # Sharding interface
-
-  @property
-  def num_devices(self) -> int:
-    return len(self.device_set)
-
-  @functools.cached_property
-  def device_set(self) -> set[xc.Device]:
-    return set(self._devices)
-
-  @property
-  def memory_kind(self) -> str | None:
-    return self._memory_kind
-
-  def with_memory_kind(self, kind: str) -> PositionalSharding:
-    return PositionalSharding(self._devices, memory_kind=kind)
-
-  @functools.cached_property
-  def is_fully_replicated(self) -> bool:
-    return self.shape == (1,) * self.ndim
-
-  # jsharding.Sharding interface
-
-  @property
-  def _device_assignment(self) -> XLADeviceAssignment:
-    return self._devices
-
-  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
-    return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions)
-
-  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray:
-    raise NotImplementedError(
-        "PositionalSharding can't be converted to an SdyArray.")
-
-  @functools.cached_property
-  def is_fully_addressable(self) -> bool:
-    return self._internal_device_list.is_fully_addressable
-
-
-class DeviceIdSet:
-  _name: str
-  _ids: frozenset[int]
-
-  def __init__(self, name, *ids):
-    self._name = name
-    self._ids = frozenset(ids)
-
-  def __iter__(self):
-    return iter(sorted(self._ids))
-
-  def __add__(self, other) -> DeviceIdSet:
-    assert isinstance(other, DeviceIdSet)
-    return DeviceIdSet(self._name, *(self._ids | other._ids))
-
-  def __len__(self) -> int:
-    return len(self._ids)
-
-  def __repr__(self) -> str:
-    ids = ', '.join(safe_map(str, sorted(self._ids)))
-    return f'{{{self._name} {ids}}}'
-
-  def __hash__(self) -> int:
-    return hash((self._name, self._ids))
-
-  def __eq__(self, other) -> bool:
-    return (isinstance(other, DeviceIdSet) and self._name == other._name and
-            self._ids == other._ids)

 def _unpickle_gspmd_sharding(devices, op_sharding, memory_kind):
   return GSPMDSharding(devices, op_sharding, memory_kind=memory_kind)
diff --git a/jax/sharding.py b/jax/sharding.py
index 66692069d19b..ef963d6a0138 100644
--- a/jax/sharding.py
+++ b/jax/sharding.py
@@ -20,8 +20,6 @@
     NamedSharding as NamedSharding,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
-    GSPMDSharding as _deprecated_GSPMDSharding,
-    PositionalSharding as _deprecated_PositionalSharding,
     use_mesh as use_mesh,
     set_mesh as set_mesh,
 )
@@ -39,26 +37,21 @@
     # Added April 11, 2025.
     "PositionalSharding": (
         (
-            "jax.sharding.PositionalSharding is deprecated. Use"
-            " jax.NamedSharding instead."
+            "jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and"
+            " removed in JAX v0.7.0"
         ),
-        _deprecated_PositionalSharding,
+        None,
     ),
     "GSPMDSharding": (
         (
-            "jax.sharding.GSPMDSharding is deprecated. Use"
-            " jax.NamedSharding instead."
+ "jax.sharding.GSPMDSharding was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0" ), - _deprecated_GSPMDSharding, + None, ), } -import typing -if typing.TYPE_CHECKING: - PositionalSharding = _deprecated_PositionalSharding - GSPMDSharding = _deprecated_GSPMDSharding -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jaxlib/jax_jit.cc b/jaxlib/jax_jit.cc index 53263949e72e..e314213e055d 100644 --- a/jaxlib/jax_jit.cc +++ b/jaxlib/jax_jit.cc @@ -150,7 +150,7 @@ std::string ArgumentSignature::DebugString() const { return absl::StrFormat( "static args (positional + keyword): [%s], " "static arg keyword names: [%s], " - "dynamic arg signatures (positional + keyword): [%s]" + "dynamic arg signatures (positional + keyword): [%s], " "dynamic arg shardings: [%s]", absl::StrJoin(static_args, ",", py_object_formatter), absl::StrJoin(static_arg_names, ",", py_object_formatter), @@ -256,12 +256,16 @@ size_t HashShardingForJit(nb::handle sharding) { } bool EqualShardingsForJit(nb::handle a, nb::handle b) { - if (a.ptr() == b.ptr()) return true; + if (a.ptr() == b.ptr()){ + return true; + } auto a_type = a.type(); auto b_type = b.type(); - if (!a_type.is(b_type)) return false; + if (!a_type.is(b_type)) { + return false; + } if (a_type.is(NamedSharding::type())) { auto* a_named_sharding = nb::inst_ptr(a); @@ -277,8 +281,7 @@ bool EqualShardingsForJit(nb::handle a, nb::handle b) { if (a_type.is(GSPMDSharding::type())) { auto* a_gspmd_sharding = nb::inst_ptr(a); auto* b_gspmd_sharding = nb::inst_ptr(b); - - return a_gspmd_sharding == b_gspmd_sharding; + return *a_gspmd_sharding == *b_gspmd_sharding; } if (a_type.is(SingleDeviceSharding::type())) { @@ -286,7 +289,6 @@ bool EqualShardingsForJit(nb::handle a, nb::handle b) { nb::inst_ptr(a); auto* b_single_device_sharding = nb::inst_ptr(b); - return a_single_device_sharding->device().ptr() == b_single_device_sharding->device().ptr() && a_single_device_sharding->memory_kind().equal( diff --git a/tests/array_test.py b/tests/array_test.py index dd8ddb078443..b98f25abca7a 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -34,8 +34,7 @@ from jax._src.mesh import AxisType, AbstractMesh from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( - _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, - NamedSharding, GSPMDSharding, PositionalSharding, SdyDim, + pmap_sharding_devices_indices_map, NamedSharding, GSPMDSharding, SdyDim, SdyArray) from jax.experimental.pjit import pjit from jax.experimental import multihost_utils @@ -993,74 +992,6 @@ def test_gspmd_sharding_repr(self): # memory kind also appears in the repr but only for TPU. 
self.assertIn('GSPMDSharding({replicated}', repr(s2)) - @parameterized.named_parameters( - ("mesh_x_y", P("x", "y"), (4, 2), (), False), - ("mesh_x", P("x"), (4, 2), (1,), False), - ("mesh_y", P("y"), (4, 2), (0,), True), - ("mesh_none_y", P(None, "y"), (4, 2), (0,), False), - ("mesh_none_x", P(None, "x"), (4, 2), (1,), True), - ("mesh_xy", P(("x", "y")), (8, 1), (), False), - ("mesh_fully_replicated", P(), (4, 2), None, False), - ) - def test_positional_sharding_op_sharding_lowering( - self, pspec, shape, axes, transpose): - value_shape = (8, 4) - - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - mps = jax.sharding.NamedSharding(mesh, pspec) - devices = jax.local_devices()[:8] # Taking up to 8 devices - - devices_sharding = PositionalSharding(devices) - devices_sharding = devices_sharding.reshape(shape).replicate(axes) - if transpose: - devices_sharding = devices_sharding.T - - op1 = mps._to_xla_hlo_sharding(len(value_shape)) - op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) - - self.assertEqual(mps.shard_shape(value_shape), - devices_sharding.shard_shape(value_shape)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - - @parameterized.named_parameters( - ("2d_mesh_x_y", (4, 2), P("x", "y")), - ("2d_mesh_x", (4, 2), P("x")), - ("2d_mesh_y", (4, 2), P("y")), - ("2d_mesh_none_y", (4, 2), P(None, "y")), - ("2d_mesh_none_x", (4, 2), P(None, "x")), - ("2d_mesh_xy", (4, 2), P(("x", "y"))), - ("2d_mesh_none_xy", (4, 2), P(None, ("x", "y"))), - ("2d_mesh_x_none", (2, 1), P(('x',), None)), - ("2d_mesh_fully_replicated", (4, 2), P()), - ("3d_mesh_none_none_z", (2, 2, 2), P(None, None, 'z')), - ("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)), - ("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)), - ("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))), - ("3d_mesh_x_none_yz", (2, 2, 2), P('x', None, ('y', 'z'))), - ("3d_mesh_none_x_yz", (2, 2, 2), P(None, 'x', ('y', 'z'))), - ("3d_mesh_xy_z", (2, 2, 2), P(('x', 'y'), 'z')), - ("3d_mesh_xy_none_z", (2, 2, 2), P(('x', 'y'), None, 'z')), - ("3d_mesh_x_y_z", (2, 2, 2), P('x', 'y', 'z')), - ("3d_mesh_xz_y", (2, 2, 2), P(('x', 'z'), 'y')), - ("3d_mesh_xz_none_y", (2, 2, 2), P(('x', 'z'), None, 'y')), - ("3d_mesh_y_none_xz", (2, 2, 2), P('y', None, ('x', 'z'))), - ("3d_mesh_none_y_xz", (2, 2, 2), P(None, 'y', ('x', 'z'))), - ("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')), - ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)), - ("3d_mesh_x_none_none", (2, 1, 1), P('x', None, None)), - ) - def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec): - ndim = len(mesh_shape) - mesh = jtu.create_mesh( - mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z')) - mps = jax.sharding.NamedSharding(mesh, pspec) - original_op_sharding = mps._to_xla_hlo_sharding(ndim) - ps = _op_sharding_to_pos_sharding(original_op_sharding, - mps._device_assignment) - out_op_sharding = ps._to_xla_hlo_sharding(ndim) - self.assertTrue(op_shardings.are_op_shardings_equal( - original_op_sharding, out_op_sharding)) - @parameterized.named_parameters( ("2d_mesh_x", (1, 1), P("x", "y")), ("2d_mesh_x_y", (4, 2), P("x", "y")), @@ -1090,11 +1021,6 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding) self.assertEqual(mps.is_fully_replicated, ops_ifr) - ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment) - self.assertEqual(ps.is_fully_replicated, - op_shardings.is_op_sharding_replicated( - ps._to_xla_hlo_sharding(len(shape)))) - def 
test_pmap_sharding_repr(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b2661b39ef4d..a814d25ba655 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -52,7 +52,7 @@ from jax._src import op_shardings from jax._src import sharding_impls from jax._src.sharding_impls import ( - AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, + AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, SingleDeviceSharding, parse_flatten_op_sharding) from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes, use_auto_axes, use_explicit_axes, reshard, @@ -507,8 +507,6 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintWithArrayOpSharding(self): - if config.use_shardy_partitioner.value: - self.skipTest("Shardy doesn't support PositionalSharding") shape = (8, 8) mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @@ -3673,20 +3671,17 @@ def g(x): @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pjit_out_sharding_preserved(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + gs = GSPMDSharding(jax.devices()[:2], ns._to_xla_hlo_sharding(2)) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) def mul(x): return x * 2 f = pjit(mul, out_shardings=ns) - f2 = pjit(mul, out_shardings=ps) with jtu.count_pjit_cpp_cache_miss() as count: out = f(arr) @@ -3697,24 +3692,12 @@ def mul(x): self.assertIsInstance(out.sharding, NamedSharding) self.assertEqual(count(), 1) - with jtu.count_pjit_cpp_cache_miss() as count: - out2 = f2(arr) - cache_info2 = pxla._cached_compilation.cache_info() - self.assertIsInstance(out2.sharding, PositionalSharding) - - out2 = f2(arr) - self.assertIsInstance(out2.sharding, PositionalSharding) - self.assertEqual(count(), 1) - - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - self.assertEqual(cache_info2.misses, cache_info1.misses) - with jtu.count_jit_tracing_cache_miss() as tracing_count: out3 = jnp.squeeze(arr, axis=-1) self.assertIsInstance(out3.sharding, NamedSharding) out4 = jnp.squeeze(arr2, axis=-1) - self.assertIsInstance(out4.sharding, PositionalSharding) + self.assertIsInstance(out4.sharding, GSPMDSharding) self.assertEqual(tracing_count(), 2) @jtu.thread_unsafe_test() # cache_info isn't thread-safe @@ -3789,14 +3772,12 @@ def generate_random_numbers(): @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_jit_mul_sum_sharding_preserved(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + gs = GSPMDSharding(tuple(mesh.devices.flat), ns._to_xla_hlo_sharding(2)) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) f = jax.jit(lambda x: x * 2) @@ -3806,11 +3787,11 @@ def test_jit_mul_sum_sharding_preserved(self): with jtu.count_pjit_cpp_cache_miss() as cpp_count: out2 = f(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) + 
self.assertIsInstance(out2.sharding, GSPMDSharding) # This will hit the cpp cache. out3 = f(out2) - self.assertIsInstance(out3.sharding, PositionalSharding) + self.assertIsInstance(out3.sharding, GSPMDSharding) self.assertEqual(compilation_count(), 2) self.assertEqual(cpp_count(), 1) @@ -3858,8 +3839,6 @@ def test_none_out_sharding(self): self.assertEqual(out2.sharding.spec, P()) def test_sharding_preserved_apply_primitive(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3868,10 +3847,10 @@ def test_sharding_preserved_apply_primitive(self): out = jnp.copy(arr) self.assertIsInstance(out.sharding, NamedSharding) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + gs = GSPMDSharding(jax.devices()[:2], ns._to_xla_hlo_sharding(2)) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) out2 = jnp.copy(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) + self.assertIsInstance(out2.sharding, GSPMDSharding) arr3 = jnp.arange(8) out3 = jnp.copy(arr3) @@ -4348,11 +4327,10 @@ def f(*args): f(inps) # doesn't crash def test_spmd_preserves_input_sharding_vmap_grad(self): - if config.use_shardy_partitioner.value: - self.skipTest("Shardy doesn't support PositionalSharding") # https://github.com/jax-ml/jax/issues/20710 n_devices = jax.device_count() - sharding = PositionalSharding(jax.devices()) + mesh = Mesh(jax.devices(), 'x') + sharding = NamedSharding(mesh, P('x')) def model(params, x): return x @ params @@ -4365,8 +4343,8 @@ def model(params, x): params = jnp.ones(feature_dim) # Shard data, replicate params - x = jax.device_put(x, sharding.reshape(n_devices, 1)) - params = jax.device_put(params, sharding.replicate(axis=0)) + x = jax.device_put(x, sharding) + params = jax.device_put(params, NamedSharding(mesh, P())) model(params, x) # doesn't crash From 4f871b0e2cbf0148580d68589173facb02a9e1ac Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Wed, 18 Jun 2025 12:50:59 -0700 Subject: [PATCH 1747/1769] Skip pytest for Python 3.14 during the JAX release process This Python version hasn't been officially released or supported yet. PiperOrigin-RevId: 773026723 --- .github/workflows/wheel_tests_nightly_release.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index f5622af0f1ef..da6d87495b21 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -36,6 +36,7 @@ concurrency: permissions: {} jobs: run-pytest-cpu: + if: ! ( matrix.python == '3.14' && startsWith(github.ref_name, 'release/') ) uses: ./.github/workflows/pytest_cpu.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -60,6 +61,7 @@ jobs: halt-for-connection: ${{inputs.halt-for-connection}} run-pytest-cuda: + if: ! ( matrix.python == '3.14' && startsWith(github.ref_name, 'release/') ) uses: ./.github/workflows/pytest_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure From 7c9613a628e09cc902e46dd0274b84e0d116d4c8 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Wed, 18 Jun 2025 12:59:24 -0700 Subject: [PATCH 1748/1769] Fix rare error with Literal in DynamicJaxprTracer.full_lower. 
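The rare error presumably arises when `tracer_to_var` maps a tracer to a
jaxpr `Literal` rather than a `Var`, so the old code fell through to the
`constvar_to_val` lookup with a non-Var key. A sketch of the method as it
reads after this change, with the new guard commented (the names are as in
`partial_eval.py`; the exact two added lines are in the diff below):

    def full_lower(self):
      var = self._trace.frame.tracer_to_var.get(id(self))
      if var is None: return self
      # A Literal carries its constant value inline, so return that value
      # directly instead of looking it up in constvar_to_val.
      if isinstance(var, Literal):
        return var.val
      val = self._trace.frame.constvar_to_val.get(var)
      if val is None: return self
      return core.full_lower(val)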
PiperOrigin-RevId: 773029698
---
 jax/_src/interpreters/partial_eval.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index e5ac29504ee0..b7e02ee0fd18 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1655,6 +1655,8 @@ def aval_mutable_qdd(self):
   def full_lower(self):
     var = self._trace.frame.tracer_to_var.get(id(self))
     if var is None: return self
+    if isinstance(var, Literal):
+      return var.val
     val = self._trace.frame.constvar_to_val.get(var)
     if val is None: return self
     return core.full_lower(val)

From 511bf2fe08978a04233b3a8ad02af3a2eebaca4d Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 18 Jun 2025 13:00:07 -0700
Subject: [PATCH 1749/1769] Move jax._src.lax to its own BUILD rule

This also bundles in `ad_checkpoint.py` and `state/*.py` because these have
circular dependencies on lax source files.

Creating smaller build rules enforces better organized dependency graphs in
the JAX project, helps pytype propagate annotations correctly, prevents use
of internal APIs, and leads to improved build and iteration times.

PiperOrigin-RevId: 773030035
---
 jax/BUILD                             | 63 ++++++++++++++++++++++++--
 jax/_src/lax/control_flow/common.py   | 10 ++---
 jax/_src/lax/control_flow/for_loop.py |  8 ++--
 jax/_src/lax/control_flow/loops.py    |  2 +-
 jax/_src/lax/control_flow/solves.py   |  4 +-
 jax/_src/lax/lax.py                   |  8 ++--
 jax/_src/lax/linalg.py                | 64 +++++++++++++--------------
 jax/_src/lax/parallel.py              | 24 +++++-----
 jax/_src/lax/special.py               |  2 +-
 jax/_src/lax/windowed_reductions.py   | 12 +++--
 jax/_src/pallas/mosaic_gpu/BUILD      |  5 +++
 jax/_src/pallas/triton/BUILD          |  2 +
 jax/_src/state/primitives.py          |  4 +-
 jax/_src/state/utils.py               |  9 ++--
 jax/_src/tpu/linalg/stack.py          | 16 ++++---
 jax/experimental/ode.py               |  2 +-
 jax/extend/BUILD                      |  1 +
 17 files changed, 151 insertions(+), 85 deletions(-)

diff --git a/jax/BUILD b/jax/BUILD
index 6b1c5dc88499..44e6faf896c3 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -295,7 +295,7 @@ py_library_providing_imports_info(
     name = "jax",
     srcs = [
         "_src/__init__.py",
-        "_src/ad_checkpoint.py",
+        "_src/ad_checkpoint.py",  # TODO(vanderplas): remove once downstream users depend on :lax
         "_src/blocked_sampler.py",
         "_src/checkify.py",
         "_src/debugging.py",
@@ -312,12 +312,12 @@ py_library_providing_imports_info(
         "_src/cudnn/**/*.py",
         "_src/debugger/**/*.py",
         "_src/image/**/*.py",
-        "_src/lax/**/*.py",
+        "_src/lax/**/*.py",  # TODO(vanderplas): remove once downstream users depend on :lax
         "_src/nn/**/*.py",
         "_src/numpy/**/*.py",
         "_src/ops/**/*.py",
         "_src/scipy/**/*.py",
-        "_src/state/**/*.py",
+        "_src/state/**/*.py",  # TODO(vanderplas): remove once downstream users depend on :lax and :state_types
         "_src/third_party/**/*.py",
         "_src/tpu/**/*.py",
         "experimental/key_reuse/**/*.py",
@@ -391,6 +391,7 @@ py_library_providing_imports_info(
         ":hashable_array",
         ":internal_mesh_utils",
         ":jaxpr_util",
+        ":lax",
         ":layout",
         ":lazy_loader",
        ":mesh",
@@ -625,6 +626,59 @@ pytype_strict_library(
     ],
 )

+py_library_providing_imports_info(
+    name = "lax",
+    srcs = glob(
+        [
+            "_src/lax/**/*.py",
+            "_src/state/**/*.py",
+        ],
+        exclude = [
+            # These are included in :state_types.
+            "_src/state/__init__.py",
+            "_src/state/indexing.py",
+            "_src/state/types.py",
+        ],
+    ) + [
+        "_src/ad_checkpoint.py",
+    ],
+    visibility = [":internal"] + jax_visibility("lax"),
+    deps = [
+        ":abstract_arrays",
+        ":ad",
+        ":ad_util",
+        ":api",
+        ":api_util",
+        ":attrs",
+        ":batching",
+        ":callback",
+        ":config",
+        ":core",
+        ":custom_derivatives",
+        ":custom_partitioning_sharding_rule",
+        ":dtypes",
+        ":effects",
+        ":ffi",
+        ":mesh",
+        ":mlir",
+        ":named_sharding",
+        ":partial_eval",
+        ":partition_spec",
+        ":pretty_printer",
+        ":sharding",
+        ":sharding_impls",
+        ":source_info_util",
+        ":state_types",
+        ":traceback_util",
+        ":tree_util",
+        ":typing",
+        ":util",
+        ":xla",
+        ":xla_bridge",
+        "//jax/_src/lib",
+    ] + py_deps("numpy"),
+)
+
 pytype_strict_library(
     name = "lru_cache",
     srcs = ["_src/lru_cache.py"],
@@ -1086,7 +1140,9 @@ pytype_strict_library(
     deps = [
         ":deprecations",
         ":jax",
+        ":lax",
         ":source_info_util",
+        ":state_types",
         "//jax/_src/pallas",
     ] + py_deps("numpy"),
 )
@@ -1495,6 +1551,7 @@ pytype_strict_library(
         "_src/state/indexing.py",
         "_src/state/types.py",
     ],
+    visibility = [":internal"] + jax_visibility("state_types"),
     deps = [
         ":core",
         ":dtypes",
diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index cb80df76326b..b90eda4e765c 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -20,17 +20,17 @@
 from functools import partial
 from typing import Any

+from jax._src import ad_util
 from jax._src import api_util
 from jax._src import core
-from jax._src import linear_util as lu
-from jax._src.lax import lax
 from jax._src import effects
-from jax._src import ad_util
+from jax._src import linear_util as lu
 from jax._src import state
+from jax._src.lax import lax
 from jax._src.util import weakref_lru_cache, safe_map, partition_list
 from jax._src.interpreters import partial_eval as pe
-from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef
-from jax._src.tree_util import equality_errors_pytreedef
+from jax._src.tree_util import (equality_errors_pytreedef, tree_map,
+                                tree_unflatten, keystr, PyTreeDef)

 map, unsafe_map = safe_map, map
diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py
index 90b81ae367aa..773061c59bd4 100644
--- a/jax/_src/lax/control_flow/for_loop.py
+++ b/jax/_src/lax/control_flow/for_loop.py
@@ -20,14 +20,13 @@
 import operator
 from typing import Any, Generic, TypeVar

-from jax import lax
 from jax._src import api_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
-from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
-                           treedef_tuple, tree_map, tree_leaves, PyTreeDef)
+from jax._src.tree_util import (tree_flatten, tree_structure, tree_unflatten,
+                                treedef_tuple, tree_map, tree_leaves, PyTreeDef)

 from jax._src import ad_util
 from jax._src import core
@@ -35,6 +34,7 @@
 from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import source_info_util
+from jax._src.lax import lax
 from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect)
 from jax._src.state import discharge as state_discharge
 from jax._src.state import primitives as state_primitives
@@ -272,7 +272,7 @@ def while_body(carry):
     state = body(i, state)
     i = i + 1
     return i, state
-  _, state = lax.while_loop(cond, while_body, (i, state))
+  _, state = loops.while_loop(cond, while_body, (i, state))
   return state

 mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 9416711d478b..65162ea15305 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -66,7 +66,7 @@
     merge_lists, partition_list, safe_map, safe_zip, split_list,
     split_list_checked, unzip2, weakref_lru_cache,)
 from jax._src import xla_bridge as xb
-from jax.tree_util import (
+from jax._src.tree_util import (
     keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten,
     treedef_is_leaf)
 import numpy as np
diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py
index f34c98c6aaae..c8ce620ad464 100644
--- a/jax/_src/lax/control_flow/solves.py
+++ b/jax/_src/lax/control_flow/solves.py
@@ -17,8 +17,6 @@
 import operator
 from typing import Any, Callable

-from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
-                           tree_unflatten, treedef_tuple)
 from jax._src import ad_util
 from jax._src import api
 from jax._src import api_util
@@ -30,6 +28,8 @@
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.traceback_util import api_boundary
+from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves,
+                                tree_unflatten, treedef_tuple)
 from jax._src.util import split_list, safe_map
 import numpy as np
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 34be5b0774a8..fcbe86968e06 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -28,10 +28,6 @@
 import numpy as np

-from jax import tree_util
-from jax.sharding import Sharding
-from jax.tree_util import tree_map
-
 from jax._src import ad_util
 from jax._src import api
 from jax._src import api_util
@@ -46,6 +42,7 @@
 from jax._src import pretty_printer as pp
 from jax._src import source_info_util
 from jax._src import state
+from jax._src import tree_util
 from jax._src import util
 from jax._src.abstract_arrays import array_types
 from jax._src.core import (Primitive, UnshapedArray, ShapedArray,
@@ -67,6 +64,7 @@
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
+from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (PmapSharding, NamedSharding,
                                      ShardingContext, SPMDAxisContext,
                                      PartitionSpec as P, canonicalize_sharding)
@@ -3487,7 +3485,7 @@ def stop(x):
       return ad_util.stop_gradient_p.bind(x)
     else:
       return x
-  return tree_map(stop, x)
+  return tree_util.tree_map(stop, x)

 def reduce_precision(operand: float | ArrayLike,
                      exponent_bits: int,
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 7f7b1056dbee..81d23465ea34 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -23,8 +23,6 @@
 import numpy as np

-from jax import lax
-
 from jax._src import ad_util
 from jax._src import api
 from jax._src import config
@@ -38,7 +36,8 @@
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
-from jax._src.lax import lax as lax_internal
+from jax._src.lax import control_flow
+from jax._src.lax import lax
 from jax._src.lax import utils as lax_utils
 from jax._src.lax.lax import _float, _complex, _int
 from jax._src.lib import gpu_linalg
@@ -635,7 +634,7 @@ def tridiagonal(
     superdiagonal. ``taus`` contains the scalar factors of the elementary
     Householder reflectors.
   """
-  return tridiagonal_p.bind(lax_internal.asarray(a), lower=lower)
+  return tridiagonal_p.bind(lax.asarray(a), lower=lower)


 def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
@@ -753,7 +752,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name,
                      multiple_results=False, supports_batching=True,
                      require_same=True):
   dtype_rule = partial(
-      lax_internal.naryop_dtype_rule, result_dtype, accepted_dtypes, name,
+      lax.naryop_dtype_rule, result_dtype, accepted_dtypes, name,
       require_same=require_same)
   shape_rule = partial(
       linalg_shape_rule, multiple_results, supports_batching, ranks,
@@ -783,7 +782,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name,
         batching.expand_dims_batcher, prim)
   return prim

-standard_linalg_primitive = partial(linalg_primitive, lax_internal._input_dtype)
+standard_linalg_primitive = partial(linalg_primitive, lax._input_dtype)


 # Primitive implementations

@@ -806,7 +805,7 @@ def _cholesky_jvp_rule(primals, tangents):
   def phi(X):
     l = _tril(X)
     return l / lax.expand_dims(
-        lax_internal._const(X, 1) + lax_internal._eye(X.dtype, (X.shape[-1], X.shape[-1])),
+        lax._const(X, 1) + lax._eye(X.dtype, (X.shape[-1], X.shape[-1])),
         range(l.ndim - 2))

   tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
@@ -869,7 +868,8 @@ def _drotg_nonzero(x, y):
       np.array(1., dtype=x.dtype),
      np.array(0., dtype=x.dtype),
   )
-  return lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y)
+  return control_flow.cond(
+      y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y)

 def _drot(
     first_vector: Array, second_vector: Array,
@@ -1063,7 +1063,7 @@ def _eigh_jacobi_shape_rule(shape, **_):

 def _eigh_jacobi_dtype_rule(dtype, **_):
   dtype = dtypes.canonicalize_dtype(dtype)
-  return lax_internal._complex_basetype(dtype), dtype
+  return lax._complex_basetype(dtype), dtype

 def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
   operand_aval, = ctx.avals_in
@@ -1118,7 +1118,7 @@ def _eigh_shape_rule(shape, *, subset_by_index, **_):

 def _eigh_dtype_rule(dtype, **_):
   dtype = dtypes.canonicalize_dtype(dtype)
-  return dtype, lax_internal._complex_basetype(dtype)
+  return dtype, lax._complex_basetype(dtype)

 def _eigh_cpu_gpu_lowering(
     ctx, operand, *, lower, sort_eigenvalues, subset_by_index,
@@ -1185,7 +1185,7 @@ def _eigh_jvp_rule(
   # for complex numbers we need eigenvalues to be full dtype of v, a:
   w = w_real.astype(a.dtype)
-  eye_n = lax_internal._eye(a.dtype, (n, n))
+  eye_n = lax._eye(a.dtype, (n, n))
   # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
   with config.numpy_rank_promotion("allow"):
     Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n
@@ -1332,7 +1332,7 @@ def body(k, state):
     # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
     a_outer = a[:, k, None] * a[k, None]
     a = a - lax.select((m_idx[:, None] > k) & (n_idx[None, :] > k),
-                       a_outer, lax_internal._zeros(a_outer))
+                       a_outer, lax._zeros(a_outer))
     return pivot, perm, a

   pivot = lax.full((min(m, n),), 0, dtype=np.int32)

   # If the array is empty, the loop body never executes but tracing it to a
   # jaxpr fails because the indexing cannot succeed.
return (pivot, perm, a) - return lax.fori_loop(0, min(m, n), body, (pivot, perm, a)) + return control_flow.fori_loop(0, min(m, n), body, (pivot, perm, a)) def _lu_blocked(a, block_size=128): @@ -1405,10 +1405,10 @@ def _lu_jvp_inner(lu, a_dot, permutation): l_padding = [(0, 0, 0)] * 2 l_padding[-1] = (0, m - k, 0) - zero = lax_internal._const(lu, 0) + zero = lax._const(lu, 0) l = lax.pad(_tril(lu[:, :k], -1), zero, l_padding) - l = l + lax_internal._eye(dtype, (m, m)) - u_eye = lax.pad(lax_internal._eye(dtype, (n - k, n - k)), zero, + l = l + lax._eye(dtype, (m, m)) + u_eye = lax.pad(lax._eye(dtype, (n - k, n - k)), zero, ((k, 0, 0), (k, 0, 0))) u_padding = [(0, 0, 0)] * 2 u_padding[-2] = (0, n - k, 0) @@ -1602,8 +1602,8 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): return permutation upper = np.array(k, np.int32) if is_constant_dim(k) else k permutation, swaps = core.standard_insert_pvary(permutation, swaps) - result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, - (permutation, swaps)) + result, _ = control_flow.fori_loop(np.array(0, np.int32), upper, + _lu_pivots_body_fn, (permutation, swaps)) return result @@ -1776,7 +1776,7 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma): qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric # The following correction is necessary for complex inputs - I = lax.expand_dims(lax_internal._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) + I = lax.expand_dims(lax._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) dq = q @ (do - qt_dx_rinv) + dx_rinv dr = (qt_dx_rinv - do) @ r @@ -1789,7 +1789,7 @@ def _qr_lowering(a, *, pivoting, full_matrices, use_magma): *batch_dims, m, n = a.shape if m == 0 or n == 0: k = m if full_matrices else core.min_dim(m, n) - q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)), + q = lax.broadcast_in_dim(lax._eye(a.dtype, (m, k)), (*batch_dims, m, k), (len(batch_dims), len(batch_dims) + 1)) r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype) @@ -1809,7 +1809,7 @@ def _qr_lowering(a, *, pivoting, full_matrices, use_magma): q = householder_product(r[..., :m, :m], taus) elif full_matrices: pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)] - q = lax.pad(r, lax_internal._zero(r), pads) + q = lax.pad(r, lax._zero(r), pads) q = householder_product(q, taus) else: q = householder_product(r, taus) @@ -1909,7 +1909,7 @@ def _svd_shape_rule(shape, *, full_matrices, compute_uv, subset_by_index, **_): def _svd_dtype_rule(dtype, *, compute_uv, **_): dtype = dtypes.canonicalize_dtype(dtype) - real_dtype = lax_internal._complex_basetype(dtype) + real_dtype = lax._complex_basetype(dtype) if compute_uv: return real_dtype, dtype, dtype else: @@ -1941,7 +1941,7 @@ def _svd_jvp_rule( return (s,), (ds,) s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim)) - s_diffs_zeros = lax_internal._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else + s_diffs_zeros = lax._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. 
everywhere else s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2)) F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s)) @@ -1967,12 +1967,12 @@ def _svd_jvp_rule( def _empty_svd(a, *, full_matrices, compute_uv): batch_shape = a.shape[:-2] m, n = a.shape[-2:] - s = lax.full(batch_shape + (0,), 0, dtype=lax_internal._complex_basetype(a.dtype)) + s = lax.full(batch_shape + (0,), 0, dtype=lax._complex_basetype(a.dtype)) if not compute_uv: return (s,) if full_matrices: size = max(m, n) - u = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (size, size)), + u = lax.broadcast_in_dim(lax._eye(a.dtype, (size, size)), (*batch_shape, size, size), (len(batch_shape), len(batch_shape) + 1)) else: @@ -2371,7 +2371,7 @@ def _tridiagonal_shape_rule(shape, **_): def _tridiagonal_dtype_rule(dtype, **_): dtype = dtypes.canonicalize_dtype(dtype) - real_dtype = lax_internal._complex_basetype(dtype) + real_dtype = lax._complex_basetype(dtype) return dtype, real_dtype, real_dtype, dtype def _tridiagonal_cpu_gpu_lowering(ctx, a, *, lower, target_name_prefix): @@ -2499,7 +2499,7 @@ def fwd(carry, args): dp_next = (d - a * dp) / (b - a * cp) return (cp_next, dp_next), (cp, dp) - (_, final), (cp, dp) = lax.scan( + (_, final), (cp, dp) = control_flow.scan( fwd, (du[0] / d[0], b[0] / d[0]), (dl[1:], d[1:], du[1:], b[1:, :]), unroll=32) @@ -2508,7 +2508,7 @@ def bwd(xn, args): x = dp - cp * xn return x, xn - end, ans = lax.scan(bwd, final, (cp, dp), unroll=32, reverse=True) + end, ans = control_flow.scan(bwd, final, (cp, dp), unroll=32, reverse=True) return lax.concatenate((end[None], ans), 0) def _tridiagonal_solve_jax(dl, d, du, b, **_): @@ -2573,7 +2573,7 @@ def _solve(a: Array, b: Array) -> Array: # computing sensitivities. This is considerably faster. 
lu_, _, permutation = lu(lax.stop_gradient(a)) custom_solve = partial( - lax.custom_linear_solve, + control_flow.custom_linear_solve, lambda x: _broadcasted_matvec(a, x), solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0), transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1)) @@ -2594,12 +2594,12 @@ def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 def _tril(m: Array, k:int = 0) -> Array: *_, N, M = m.shape - mask = lax_internal._tri(bool, (N, M), k) + mask = lax._tri(bool, (N, M), k) return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m)) def _triu(m: Array, k:int = 0) -> Array: *_, N, M = m.shape - mask = lax_internal._tri(bool, (N, M), k - 1) + mask = lax._tri(bool, (N, M), k - 1) return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m) def _construct_diagonal(s: Array) -> Array: @@ -2624,7 +2624,7 @@ def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value: def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value: """Wrapper around XLA `Select` that broadcasts its arguments.""" - out_shapes = list(lax_internal.broadcast_shapes( + out_shapes = list(lax.broadcast_shapes( tuple(which_aval.shape), tuple(x_aval.shape), tuple(y_aval.shape))) which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y), (which_aval, x_aval, y_aval), diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 5baa5489149f..06d4ec2f4281 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -22,12 +22,11 @@ import itertools import math -import jax -from jax import tree_util from jax._src import core from jax._src import config from jax._src import dispatch from jax._src import dtypes +from jax._src import tree_util from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, NamedSharding, PartitionSpec as P) from jax._src.core import AxisName, ShapedArray @@ -37,14 +36,15 @@ from jax._src.interpreters import pxla from jax._src.mesh import get_abstract_mesh from jax._src.core import abstract_token, pvary +from jax._src.lax import control_flow from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client as xc +from jax._src.typing import Array from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, unzip2) -import jax.numpy as jnp import numpy as np unsafe_map, map = map, safe_map # type: ignore @@ -714,7 +714,7 @@ def ragged_all_to_all( axis_index_groups=axis_index_groups) -def axis_index(axis_name: AxisName) -> jax.Array: +def axis_index(axis_name: AxisName) -> Array: """Return the index along the mapped axis ``axis_name``. 
Args: @@ -755,7 +755,7 @@ def axis_index(axis_name: AxisName) -> jax.Array: return axis_index_p.bind(axis_name=axis_name) else: inner_size = 1 - index = jnp.asarray(0) + index = lax.asarray(0) for name in reversed(axis_name): index += axis_index(name) * inner_size inner_size *= axis_size(name) @@ -1600,11 +1600,11 @@ def _ragged_all_to_all_transpose( operand_t = ragged_all_to_all_p.bind( t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) - mask = jax.numpy.cumsum( - jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ + mask = control_flow.cumsum( + lax.full(t.shape[0], 0, dtype='int32').at[output_offsets_].set(1) .at[output_offsets_ + recv_sizes].add(-1)) - mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),)) - output_t = jax.numpy.where(mask, 0, t) + mask = lax.expand_dims(mask, (*range(1, t.ndim),)) + output_t = lax.select(mask, lax._zeros(t), t) return [operand_t, output_t] + [None] * 4 def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, @@ -2187,6 +2187,8 @@ def bind(leaf): def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): + from jax._src.shard_map import shard_map # pytype: disable=import-error + if isinstance(axis_name, tuple): assert axis_name, 'empty axis name' if len(axis_name) > 1: @@ -2207,8 +2209,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): def f(): return axis_index_p.bind(axis_name=axis_name) return mlir.lower_fun( - lambda: [jax.shard_map(f, check_vma=False, in_specs=(), - out_specs=P())()])(ctx)[0] + lambda: [shard_map(f, check_vma=False, in_specs=(), + out_specs=P())()])(ctx)[0] nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index a486bda28486..023fed34fdc9 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -30,7 +30,7 @@ standard_naryop, standard_unop, sub, _const, _dtype, _float, _nary_lower_hlo, _ones, _isnan, _reduce) -from jax._src.lax.control_flow import while_loop +from jax._src.lax.control_flow.loops import while_loop from jax._src import dtypes from jax._src.interpreters import ad diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 7a939f1cd55f..e322fc447e7c 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -18,13 +18,14 @@ from functools import partial import warnings -from jax import tree_util +from jax._src import ad_util from jax._src import api_util from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import tree_util from jax._src import util -from jax._src.core import ShapedArray +from jax._src.core import ClosedJaxpr, ShapedArray, jaxpr_as_fun from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -35,11 +36,8 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array + import numpy as np -from jax._src.core import ClosedJaxpr -from jax._src.core import jaxpr_as_fun -from jax._src.interpreters.ad import jvp_jaxpr -from jax._src import ad_util map = util.safe_map zip = util.safe_zip @@ -404,7 +402,7 @@ def reduce_window_jvp( init_value_tangent = map(ad_util.instantiate, init_value_tangent) c_reduction_jaxpr = ClosedJaxpr(reduction_jaxpr, consts) - jvp_reduction = jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * 
len(init_value_tangent))[0] + jvp_reduction = ad.jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * len(init_value_tangent))[0] def wrapper(left, right): pl, tl = util.split_list(left, [n]) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index bcb0015d71b2..07b6887ae5c1 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -63,12 +63,14 @@ pytype_strict_library( "//jax:api", "//jax:core", "//jax:dtypes", + "//jax:lax", "//jax:mesh", "//jax:mlir", "//jax:mosaic_gpu", "//jax:pallas", "//jax:partial_eval", "//jax:source_info_util", + "//jax:state_types", "//jax:tree_util", "//jax:util", "//jax/_src/lib", @@ -84,6 +86,7 @@ pytype_strict_library( "//jax:core", "//jax:dtypes", "//jax:effects", + "//jax:lax", "//jax:mosaic_gpu", "//jax:pretty_printer", "//jax:state_types", @@ -103,8 +106,10 @@ pytype_strict_library( "//jax", "//jax:core", "//jax:frozen_dict", + "//jax:lax", "//jax:mosaic_gpu", "//jax:pretty_printer", + "//jax:state_types", "//jax:tree_util", "//jax:util", "//jax/_src/lib", diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index b13967d5b61c..f7c4a05205d3 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -65,9 +65,11 @@ pytype_strict_library( "//jax:config", "//jax:core", "//jax:custom_derivatives", + "//jax:lax", "//jax:mlir", "//jax:partial_eval", "//jax:source_info_util", + "//jax:state_types", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 3a54644dbd37..5b83b6a3cb64 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -751,7 +751,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): broadcast_to_p = core.Primitive('broadcast_to') def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error a = jnp.asarray(a) if a.shape == shape: return a @@ -759,7 +759,7 @@ def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: @broadcast_to_p.def_impl def _broadcast_to_impl(a, *, shape): - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error return jnp.broadcast_to(a, shape) @broadcast_to_p.def_abstract_eval diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py index 2dd57dcde0ca..ec979715242b 100644 --- a/jax/_src/state/utils.py +++ b/jax/_src/state/utils.py @@ -16,13 +16,14 @@ from functools import partial from typing import Callable -import jax +from jax._src import api from jax._src import core from jax._src import dtypes from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax._src.state import AbstractRef +from jax._src.lax import lax from jax._src.state.primitives import ref_get +from jax._src.state.types import AbstractRef from jax._src.typing import DTypeLike from jax._src.util import safe_map, safe_zip, split_list @@ -112,7 +113,7 @@ def bitcast(x, dtype: DTypeLike): x = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1).swapaxes( -1, -2 ) - y = jax.lax.bitcast_convert_type(x, dtype) + y = lax.bitcast_convert_type(x, dtype) if x_bitwidth > y_bitwidth: y = y.swapaxes(-1, -2).reshape(shape) return y @@ -120,4 +121,4 @@ def bitcast(x, dtype: DTypeLike): def eval_bitcast_shape(x, dtype: DTypeLike): f = partial(bitcast, dtype=dtype) - return jax.eval_shape(f, jax.ShapeDtypeStruct(x.shape, x.dtype)).shape + return api.eval_shape(f, api.ShapeDtypeStruct(x.shape, 
x.dtype)).shape diff --git a/jax/_src/tpu/linalg/stack.py b/jax/_src/tpu/linalg/stack.py index 882195f17d51..0225e66f43d8 100644 --- a/jax/_src/tpu/linalg/stack.py +++ b/jax/_src/tpu/linalg/stack.py @@ -22,10 +22,12 @@ from typing import Any -import jax from jax import lax import jax.numpy as jnp +from jax._src import tree_util + + class Stack: """A bounded functional stack implementation. Elements may be pytrees.""" def __init__(self, size, data): @@ -45,7 +47,7 @@ def create(capacity: int, prototype: Any) -> Stack: """ return Stack( jnp.array(0, jnp.int32), - jax.tree_util.tree_map( + tree_util.tree_map( lambda x: jnp.zeros((capacity,) + tuple(x.shape), x.dtype), prototype)) def empty(self) -> Any: @@ -56,23 +58,23 @@ def push(self, elem: Any) -> Stack: """Pushes `elem` onto the stack, returning the updated stack.""" return Stack( self._size + 1, - jax.tree_util.tree_map( + tree_util.tree_map( lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0), self._data, elem)) def pop(self) -> tuple[Any, Stack]: """Pops from the stack, returning an (elem, updated stack) pair.""" - elem = jax.tree_util.tree_map( + elem = tree_util.tree_map( lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False), self._data) return elem, Stack(self._size - 1, self._data) def flatten(self): - leaves, treedef = jax.tree_util.tree_flatten(self._data) + leaves, treedef = tree_util.tree_flatten(self._data) return ([self._size] + leaves), treedef @staticmethod def unflatten(treedef, leaves): - return Stack(leaves[0], jax.tree_util.tree_unflatten(treedef, leaves[1:])) + return Stack(leaves[0], tree_util.tree_unflatten(treedef, leaves[1:])) -jax.tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) +tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index db7865124687..f4c11ea768b5 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -214,7 +214,7 @@ def body_fun(state): _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry) _, _, t, _, last_t, interp_coeff = carry relative_output_time = (target_t - last_t) / (t - last_t) - y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype)) + y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype)) # pytype: disable=attribute-error return carry, y_target f0 = func_(y0, ts[0]) diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 61a058ff9189..e6414305a51b 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -48,6 +48,7 @@ py_library_providing_imports_info( "//jax:api", "//jax:core", "//jax:custom_derivatives", + "//jax:lax", ], ) From 4c54c0220c7833cfdcbecb7412ebeb1b80c426b1 Mon Sep 17 00:00:00 2001 From: Buddh Prakash Date: Wed, 18 Jun 2025 11:30:54 -0700 Subject: [PATCH 1750/1769] Fix bugs in the double_buffered_pipeline example --- docs/pallas/pipelining.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md index a9fb770c9b53..0ef407b4ee27 100644 --- a/docs/pallas/pipelining.md +++ b/docs/pallas/pipelining.md @@ -211,7 +211,7 @@ Once the loop has been unrolled, the pipelining transformation simply involves i # Itr 4 - No copy-in copy_in_wait(X[1]) Y[1] = X[1] + 1 - copy_out_start(Y[1], A[2]) + copy_out_start(Y[1], A[3]) copy_out_wait(Y[1]) @@ -244,7 +244,7 @@ Next, we can push the `copy_out_wait` as late as possible, right before we need # Itr 4 - No copy-in copy_in_wait(X[1]) Y[1] = X[1] + 1 
- copy_out_start(Y[1], A[2]) + copy_out_start(Y[1], A[3]) copy_out_wait(Y[0]) # Epilogue @@ -297,18 +297,19 @@ def double_buffered_pipeline( for i in range(grid_size): cur_slot = i % 2 next_slot = (i + 1) % 2 - if i < grid_size: - copy_in_start(in_hbm[data_slices(i+1)], in_sram[next_slot]) + if (i + 1) < grid_size: + copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot]) copy_in_wait(in_sram[cur_slot]) - kernel(inputs, outputs) + kernel(in_sram[cur_slot], out_sram[cur_slot]) copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)]) if i > 0: copy_out_wait(out_sram[next_slot]) # Epilogue - copy_out_wait(out_sram[1]) + last_slot = (grid_size - 1) % 2 + copy_out_wait(out_sram[last_slot]) ``` From cb2315a89ac0d084406048a394a56b0c76544107 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Wed, 18 Jun 2025 20:32:49 +0000 Subject: [PATCH 1751/1769] Update jax requirements lock files after 0.6.2 release --- build/requirements.in | 6 +-- build/requirements_lock_3_11.txt | 70 +++++++++++++++-------------- build/requirements_lock_3_12.txt | 70 +++++++++++++++-------------- build/requirements_lock_3_13.txt | 70 +++++++++++++++-------------- build/requirements_lock_3_13_ft.txt | 70 +++++++++++++++-------------- 5 files changed, 147 insertions(+), 139 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index e3aae21d9b4d..a88c194f7b8e 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -16,11 +16,11 @@ wheel # JAX's own libraries. We include these in the requirements so you can # bazel test without building jaxlib and without manually updating the # the requirements files. -jaxlib==0.6.1 +jaxlib==0.6.2 # The with-cuda extra also includes NVIDIA's pip packages. -jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" -jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" # TPU dependencies libtpu ; sys_platform == "linux" and platform_machine == "x86_64" diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 9574be1a972a..3560da350aa5 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -154,43 +154,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \
--hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ - --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ - --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ - --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + --hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + --hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + --hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 # via -r build/requirements.in -jaxlib==0.6.1 \ - --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ - --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ - --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ - --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ - --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ - --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ - --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ - --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ - --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ - --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ - --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ - --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ - --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ - --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ - --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ - --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ - --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ - --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + --hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + --hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + 
--hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + --hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ + --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ @@ -450,7 +450,9 @@ nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 13c4269186ba..743fbbba325f 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -154,43 +154,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ - --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ + --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ - --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ - --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ - --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ - --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ - --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ - --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ - --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ - --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ - --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + 
--hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + --hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + --hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 # via -r build/requirements.in -jaxlib==0.6.1 \ - --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ - --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ - --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ - --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ - --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ - --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ - --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ - --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ - --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ - --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ - --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ - --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ - --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ - --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ - --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ - --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ - --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ - --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + --hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + --hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + --hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + 
--hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ + --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ @@ -450,7 +450,9 @@ nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 2aaf276d94df..aa45b473d9ae 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -181,43 +181,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ - --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ + --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ - --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ - --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ - --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ - --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ - --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ - --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ - --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ - --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ - --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + --hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + --hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + 
--hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 # via -r build/requirements.in -jaxlib==0.6.1 \ - --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ - --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ - --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ - --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ - --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ - --hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ - --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ - --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ - --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ - --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ - --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ - --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ - --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ - --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ - --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ - --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ - --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ - --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + --hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + --hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + --hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + --hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ + --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 # via -r build/requirements.in kiwisolver==1.4.7 \ --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ @@ -505,7 +505,9 @@ nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ 
--hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index aa6f2daa569b..348fa9628a76 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -172,43 +172,43 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:4c97d10a5a9ac09fa001568cac3b715014e8dbbc2cd86763753f58e5a730c333 \ - --hash=sha256:967076cfb6f2e33959e7376663599aa0c11cc0ede8f2f51a206da0a1d422c6bb +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ + --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux" \ - --hash=sha256:1885f15be38faecccfbf24b184ffdc1d0d363717eadd2534d5759c0d3d0af523 \ - --hash=sha256:1fbf8d4b42455443a089afd1a88fb106a51ba1075fc6884b339dc96571c5b617 \ - --hash=sha256:2a3578dc0b7d44cc1b0233b0fe7ad764265381095d7eac64c56bd01b34be76f2 \ - --hash=sha256:425ccf13cbdd4678b1109f843988157a59e4f4d9bc298205acb16df048a31c38 \ - --hash=sha256:b77804e0e4d923ad39909095ff7c1b723eac6f3ee5f9ffcb80597ba867b572b8 \ - --hash=sha256:b8bff7a5fc7a416717e1d59da9728a1f7aad07a8b65afa0f86962d43ed0e654f \ - --hash=sha256:ba09bad8d5c9c33326e6374b0669dc325e7a4fb0d57798df3dcd560693c877dc \ - --hash=sha256:bb64a0c801f93a718a654dfc69742f2fd60a26074312204ebdf4fe403d9e2bc4 \ - --hash=sha256:d9c2be8ebb4ef6ae11dd7345ae864ac49d00bd455d06fff925a5d1eb266b02f1 \ - --hash=sha256:da9f7dc9243ec28e03c0e3a39852b4246fa9cfc3dcd51e4286d82097f5c695c0 +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + --hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + --hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + --hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 # via -r build/requirements.in -jaxlib==0.6.1 \ - --hash=sha256:02bac5153389f01616516a9fd1dcd6038d23ee50681dac14e4ddbc43ccb3133a \ - --hash=sha256:11fcc4b1c741a1e0057f2ffa77d5a82bfe7ee97c3864ed88df67493e789b9173 \ - --hash=sha256:2168217ec37bf951ca33377d3e0953178ba5cade95f194211d9ab2d53dcd2201 \ - --hash=sha256:277cc7e9d657d0893a559261277b3eae916ad7fa73e300a629261fb537dca0f1 \ - --hash=sha256:3301addee156f55d1f8079f80b314d89b80094740b7d64e5ec6e7ef2e1febbd7 \ - 
--hash=sha256:5a90ee7c59b2c00773026fbf918269c7a8676a6a81a34a03af919f7d7bdce9a8 \ - --hash=sha256:5e4f49113a527bcbac70c9e7074e95d8abfa35c3d67c2fed01f77a7abfd317aa \ - --hash=sha256:76d6f65f3153ffb70e20a76b915d4431823cf70a786d86ba1b76a9c5bf66a0a4 \ - --hash=sha256:7ae5815ada71b69532ce443a11160a3ed25c67e82a294a0d89af9d4d27429434 \ - --hash=sha256:8106dc316eb440d07b9d4628a0c8e2acf76da5606742c9f5c33104aaa77b0ac2 \ - --hash=sha256:acfe91eb44c29dbbd1f1f65f9bd66e1aef4483f57ad5e3d645129f3ec9ecde2a \ - --hash=sha256:b12c8842b2dfc0770ca3785e183f7bed3fa1c2596c720591dbfbe29a05045108 \ - --hash=sha256:b58c29fe747622b70946ea87823ad39202cc83da3d93a5293b432173b738a868 \ - --hash=sha256:d039124468565bbf39363b1504c190e6719e6af89a7948dee256f1dee813bb94 \ - --hash=sha256:d0c343c51b1052593edb603ddf58cf7f98812b2951ae6c45bd6e93e3e1f2f621 \ - --hash=sha256:e14195c23eecd559a61c31027b4172e912e5a50f630320918ffdfae83090ca5a \ - --hash=sha256:e734be70fe3e1fa2a31415362721189d974d10a66b0f5396c84585587d101b15 \ - --hash=sha256:f4ca75d9d47a2e90099adfede0e9c926b83ef703d349b3289b8c88e861c09e5d +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + --hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + --hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + --hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + --hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ + --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 # via -r build/requirements.in kiwisolver==1.4.8 \ --hash=sha256:01c3d31902c7db5fb6182832713d3b4122ad9317c2c5877d0539227d96bb2e50 \ @@ -456,7 +456,9 @@ nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jax-cuda12-plugin nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ From 6567192aeec2d38b6bfddae46661e1d5bf2eaaa2 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 18 Jun 2025 14:05:17 -0700 Subject: [PATCH 1752/1769] [doc] fix build error The issue was that CompilerParams is ambiguous 
when generating docs from type annotations; we can fix this by specifying which CompilerParams is intended. --- jax/_src/pallas/pallas_call.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 2a47ce2f8cf0..295e60bb0b4d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1513,7 +1513,7 @@ def pallas_call( interpret: Any = False, name: str | None = None, compiler_params: ( - Mapping[Backend, CompilerParams] | CompilerParams | None + Mapping[Backend, pallas_core.CompilerParams] | pallas_core.CompilerParams | None ) = None, cost_estimate: CostEstimate | None = None, backend: Backend | None = None, @@ -1611,8 +1611,8 @@ def pallas_call( def _normalize_compiler_params( - compiler_params: Mapping[Backend, CompilerParams] | CompilerParams | None, -) -> Mapping[Backend, CompilerParams]: + compiler_params: Mapping[Backend, pallas_core.CompilerParams] | pallas_core.CompilerParams | None, +) -> Mapping[Backend, pallas_core.CompilerParams]: if compiler_params is None: return FrozenDict({}) if isinstance(compiler_params, CompilerParams): From 5e290ddd45529199fc3c4c725d776ddc948c661a Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 18 Apr 2025 17:51:22 -0700 Subject: [PATCH 1753/1769] [Pallas][Mosaic GPU] Add GPU pipelining docs --- docs/_static/pallas/gpu/pipeline_matmul.svg | 1 + .../_static/pallas/gpu/pipeline_matmul_ws.svg | 1 + .../pallas/gpu/warp_specialization.svg | 1 + docs/conf.py | 2 + docs/pallas/gpu/index.rst | 1 + docs/pallas/gpu/pipelining.ipynb | 428 ++++++++++++++++++ docs/pallas/gpu/pipelining.md | 332 ++++++++++++++ docs/pallas/pipelining.ipynb | 20 +- docs/pallas/pipelining.md | 11 +- 9 files changed, 783 insertions(+), 14 deletions(-) create mode 100644 docs/_static/pallas/gpu/pipeline_matmul.svg create mode 100644 docs/_static/pallas/gpu/pipeline_matmul_ws.svg create mode 100644 docs/_static/pallas/gpu/warp_specialization.svg create mode 100644 docs/pallas/gpu/pipelining.ipynb create mode 100644 docs/pallas/gpu/pipelining.md diff --git a/docs/_static/pallas/gpu/pipeline_matmul.svg b/docs/_static/pallas/gpu/pipeline_matmul.svg new file mode 100644 index 000000000000..7037695e33e9 --- /dev/null +++ b/docs/_static/pallas/gpu/pipeline_matmul.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/gpu/pipeline_matmul_ws.svg b/docs/_static/pallas/gpu/pipeline_matmul_ws.svg new file mode 100644 index 000000000000..3a07ba7e9ece --- /dev/null +++ b/docs/_static/pallas/gpu/pipeline_matmul_ws.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/gpu/warp_specialization.svg b/docs/_static/pallas/gpu/warp_specialization.svg new file mode 100644 index 000000000000..85fbce49fa0b --- /dev/null +++ b/docs/_static/pallas/gpu/warp_specialization.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index dd7533aecf83..a84ed24540a3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -135,6 +135,7 @@ def _do_not_evaluate_in_jax( 'notebooks/*.md', 'pallas/quickstart.md', 'pallas/pipelining.md', + 'pallas/gpu/pipelining.md', 'pallas/tpu/pipelining.md', 'pallas/tpu/distributed.md', 'pallas/tpu/sparse.md', @@ -231,6 +232,7 @@ def _do_not_evaluate_in_jax( # Requires accelerators 'pallas/quickstart.*', 'pallas/pipelining.*', + 'pallas/gpu/pipelining.*', 'pallas/tpu/pipelining.*', 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', diff --git a/docs/pallas/gpu/index.rst b/docs/pallas/gpu/index.rst index 
2d95d5c928c4..3fec14832337 100644 --- a/docs/pallas/gpu/index.rst +++ b/docs/pallas/gpu/index.rst @@ -7,6 +7,7 @@ Backend specific documentation for the Mosaic GPU backend. :maxdepth: 2 reference + pipelining .. toctree:: :caption: Guides diff --git a/docs/pallas/gpu/pipelining.ipynb b/docs/pallas/gpu/pipelining.ipynb new file mode 100644 index 000000000000..c1bcc27c2dbf --- /dev/null +++ b/docs/pallas/gpu/pipelining.ipynb @@ -0,0 +1,428 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9552ee76", + "lines_to_next_cell": 0 + }, + "source": [ + "(pallas_mgpu_pipelining)=" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bJ5yuIr-M0x0" + }, + "source": [ + "\n", + "## Mosaic GPU Pipelining\n", + "\n", + "This guide covers software pipelining using the Mosaic GPU backend for Pallas.\n", + "\n", + "For a general overview of the pipelining API in Pallas, we recommend that users first read {ref}`pallas_software_pipelining`. Pipelining in Pallas is programmed explicitly. For those who are familiar with Triton, this is a significant difference in programming model because in Triton, pipelining is an optimization that is done automatically by the compiler.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dGAa3iO5DoRT" + }, + "outputs": [], + "source": [ + "import jax\n", + "from jax import lax\n", + "from jax import numpy as jnp\n", + "from jax.experimental.pallas import mosaic_gpu as plgpu\n", + "from jax.experimental import pallas as pl\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pv9j90hVyswo" + }, + "source": [ + "\n", + "### Pipelining with Mosaic GPU\n", + "\n", + "The recommended approach to pipeline using Mosaic GPU is to use the `plgpu.emit_pipeline` function to pipeline over sequential loops (and to use `plgpu.kernel` to partition the problem in parallel over the CUDA grid). `emit_pipeline` follows a similar API to `pl.pallas_call`, except that it exposes a few additional GPU-specific options.\n", + "\n", + "- `body`, `grid` have similar semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially.\n", + "- `in_specs` and `out_specs` also work similarly to `pl.pallas_call`, except they also accept `plgpu.BlockSpec` instances that can be used to specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on available transformations.\n", + "- `max_concurrent_steps` controls the maximum number of concurrent memory transfers. Using additional concurrent steps will consume more SMEM to hold temporary buffers, but it can improve the utilization of the memory subsystem. We recommend autotuning this parameter. Low values (e.g. 2) can sometimes achieve higher occupancy (due to lower SMEM usage), which can improve throughput in ALU-heavy kernels, but will introduce more noise due to the hardware taking care of scheduling. Larger values (between 4 and 6) will work best for kernels that can't take advantage of extra occupancy.\n", + "- `delay_release` allows the user to specify an additional number of iterations to wait before the buffer is re-used by the pipeline.
For example, a buffer copied into SMEM on iteration 0 with `delay_release=1` and `max_concurrent_steps=2` will not be re-used until iteration 3, as opposed to iteration 2 for a standard double-buffered strategy. `delay_release=1` is necessary if you don't await a `plgpu.wgmma` operation on the pipeline operands, as otherwise the pipeline will begin overwriting the buffers while the WGMMA is still reading them. This is useful for certain optimizations such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled, but care must be taken when using such a strategy, as **omitting this parameter will lead to silent data races**, and it reduces the efficiency of `emit_pipeline` as we are overlapping fewer memory transfers.\n", + "\n", + "#### Compatibility API using `pl.pallas_call`\n", + "\n", + "As an alternative to `emit_pipeline` and to maintain compatibility with Pallas TPU, Mosaic GPU also implements the existing `pl.pallas_call` API. By default, `pl.pallas_call` on Mosaic GPU will partition your kernel in parallel over the CUDA grid. You can opt-in to pipelining by passing in a `plgpu.GPUCompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining:\n", + "- `dimension_semantics`: A tuple of `Literal['parallel', 'sequential']` that specifies iteration semantics for each grid dimension. `parallel` will partition the corresponding dimension over the CUDA grid, and `sequential` dimensions will be pipelined sequentially. **Note that if no dimensions are marked `sequential`, no pipelining will happen!**\n", + "- `max_concurrent_steps`: identical to the option in `plgpu.emit_pipeline`.\n", + "- `delay_release`: identical to the option in `plgpu.emit_pipeline`.\n", + "\n", + "Pipelining lets you re-use scratch buffers across the sequential iterations of the grid (e.g. for implementing reductions). Additionally, `pallas_call` supports using `plgpu.BlockSpec` objects in place of `pl.BlockSpec` objects when using the Mosaic GPU backend, allowing you to specify GPU-specific memory transformations.\n", + "\n", + "We recommend that users use `plgpu.kernel` rather than `pl.pallas_call`, as `plgpu.kernel` supports more features (such as specifying the number of warpgroups and warp specialization).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qp3X6wylJtoa" + }, + "source": [ + "### GPU Memory Spaces\n", + "\n", + "Refs exist primarily in one of two memory spaces, which can be explicitly specified by the `memory_space` argument of `BlockSpec`, i.e. `BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)`.\n", + "\n", + "- `plgpu.GPUMemorySpace.SMEM` allocates a Ref in Shared Memory (SMEM). SMEM Refs can be dereferenced using array indexing syntax to store values in registers for compute, i.e. `x = y_ref[...]`. This is the memory space used for a Ref when using `emit_pipeline`.\n", + "\n", + "- `plgpu.GPUMemorySpace.GMEM` allocates a Ref in Global Memory (GMEM/HBM). Any Refs allocated in GMEM are not pipelined, and values cannot be accessed directly using array indexing operations.
Instead, GMEM must be accessed via SMEM, using `plgpu.copy_gmem_to_smem` for reading or `plgpu.copy_smem_to_gmem` for writing, or pipelined into SMEM using `plgpu.emit_pipeline`.\n", + "\n", + "The primary purpose of `emit_pipeline` is to overlap TensorCore computation with data transfers between GMEM and SMEM, since asynchronous copies between GMEM/SMEM have a long latency, but all TensorCore computation must operate on registers (or SMEM Refs in the case of matrix multiplication)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0uzcrDCtKABQ" + }, + "source": [ + "### Example: Matmul Kernel on Hopper GPUs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vILVdlqEdoEK" + }, + "source": [ + "Let's begin with a matrix multiplication example designed to run on Hopper GPUs. This kernel utilizes the Hopper-specific `wgmma` (warpgroup matrix multiply accumulate) instruction. `wgmma` is issued by a single Mosaic GPU thread and runs asynchronously on the TensorCore.\n", + "\n", + "Our example kernel implements a blockwise matrix multiplication of two matrices of shape `[M, K] @ [K, N] = [M, N]`, where each output block is computed in parallel over the CUDA grid. This grid is specified as the `grid` argument to the outer `plgpu.kernel`, and parallelizes over the non-contracting dimensions M, N of the matrix multiplication." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KSvqVNdy726B" + }, + "source": [ + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "10ebHCQ571Fn" + }, + "source": [ + "\n", + "Within a program instance, we run a sequential pipeline using `plgpu.emit_pipeline` that reduces over the contracting dimension K of the matrix multiplication. On each iteration of the pipeline, we load one tile from each input matrix, multiply them, and then store the result in an accumulator Ref (`plgpu.ACC`). `plgpu.ACC` is a special type of Ref that lives in registers and holds the intermediate results of WGMMA. Once we have accumulated over the entire contracting dimension, we write out the result to the output Ref.\n", + "\n", + "To perform the actual matrix multiplication, we call `plgpu.wgmma` with the accumulator, LHS, and RHS Refs as arguments in order to push the arguments into the TensorCore pipeline. All WGMMA operations are executed in order, so this can be viewed as pushing operations into a queue. Since `wgmma` is an asynchronous instruction, `plgpu.wgmma_wait(N)` is used to wait until there are no more than N `wgmma` operations left in-flight. In this particular implementation we wait for 1 in-flight WGMMA, meaning that the WGMMA we queue on the current iteration will be waited for on the next iteration.\n", + "- `wgmma` wants it's arguments to be in a specific format, defined in the [CUDA documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/#register-fragments-and-shared-memory-matrix-layouts). These are implemented by the `TilingTransform` and `SwizzleTransform` transformations on the input BlockSpecs. Note that in the future transforms will be inferred automatically by Mosaic GPU and these will not need to be manually specified. See the [wgmma reference](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#hopper-wgmma) for full details on using this instruction.\n", + "- We use the `delay_release` parameter in conjunction with `plgpu.wgmma_wait(1)` to always allow one `WGMMA` operation to stay in-flight in order to ensure good TensorCore utilization. Without this, we would be flushing the TensorCore pipeline on every iteration of the kernel." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6Vf5_VA9iCD1" + }, + "outputs": [], + "source": [ + "def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):\n", + " dtype = jnp.float16\n", + " swizzle_elems = swizzle // jnp.dtype(dtype).itemsize\n", + " tile_k = swizzle_elems\n", + " grid_m = m // tile_m\n", + " grid_k = k // tile_k\n", + " grid_n = n // tile_n\n", + " assert tile_m % swizzle_elems == 0\n", + "\n", + " # Note: Transforms will be inferred automatically\n", + " # by Mosaic GPU in the future.\n", + " transforms = (\n", + " plgpu.TilingTransform((8, swizzle_elems)),\n", + " plgpu.SwizzleTransform(swizzle),\n", + " )\n", + "\n", + " def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):\n", + " def pipeline_step(_, a_smem, b_smem):\n", + " plgpu.wgmma(acc, a_smem, b_smem)\n", + " plgpu.wgmma_wait(1)\n", + "\n", + " # pl.program_id obtains the index into the grid.\n", + " pid_m = pl.program_id(0)\n", + " pid_n = pl.program_id(1)\n", + "\n", + " pipeline = plgpu.emit_pipeline(\n", + " pipeline_step,\n", + " in_specs=[\n", + " plgpu.BlockSpec(\n", + " (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms\n", + " ),\n", + " plgpu.BlockSpec(\n", + " (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms\n", + " ),\n", + " ],\n", + " grid=(grid_k,),\n", + " max_concurrent_steps=2,\n", + " delay_release=1,\n", + " )\n", + "\n", + " pipeline(a_gmem, b_gmem)\n", + " # Store WGMMA accumulator to SMEM and then to GMEM.\n", + " o_smem[...] = acc[...].astype(dtype)\n", + " plgpu.commit_smem()\n", + " m_slice = pl.ds(pid_m * tile_m, tile_m)\n", + " n_slice = pl.ds(pid_n * tile_n, tile_n)\n", + " plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])\n", + " plgpu.wait_smem_to_gmem(0)\n", + "\n", + " return plgpu.kernel(\n", + " kernel,\n", + " out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),\n", + " scratch_shapes=[\n", + " plgpu.SMEM((tile_m, tile_n), jnp.float16),\n", + " plgpu.ACC((tile_m, tile_n), jnp.float32)\n", + " ],\n", + " # grid specifies the CUDA grid.\n", + " # Instances of `kernel` will be executed in parallel over this grid.\n", + " grid=(grid_m, grid_n),\n", + " grid_names=(\"m\", \"n\"),\n", + " )(a, b)\n", + "\n", + "m = 132 * 128\n", + "n = 4 * 128\n", + "k = 10 * 64\n", + "key1, key2 = jax.random.split(jax.random.key(42), 2)\n", + "a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)\n", + "b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)\n", + "\n", + "result = matmul(a, b)\n", + "\n", + "np.testing.assert_allclose(result, a @ b)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lIYV7PN9J8Px" + }, + "source": [ + "### Warp Specialization\n", + "\n", + "Warp specialization is a technique where we program each warp/warpgroup to perform a single task in order to give the GPU hardware the flexibility to schedule them at runtime. Recall that each streaming multiprocessor (SM) in a GPU contains warp schedulers that can swap execution between warps, so for example when one warp is stalling it can begin executing a different warp. 
In practice, this can be more performant than programming a single instruction stream where the compiler must statically schedule the operations and attempt to overlap them optimally.\n", + "\n", + "In particular, we are interested in warpgroup specialization on Hopper+ GPUs, where it can be useful to have a separate warpgroup issuing TMAs (GMEM/SMEM copies) from the warpgroups performing arithmetic, since indexing calculations and issuing TMAs can take up a significant amount of time and potentially leave the TensorCore idle. The figure below depicts a standard, non-specialized kernel on the left where TMAs (async copies) and matrix multiplication are issued from a single instruction stream, and a warp-specialized version on the right where communication and arithmetic are handled on separate warpgroups. A *consumed barrier* is used to synchronize the specialized warpgroups: it signals to the memory warpgroup when it is safe to begin the next TMA.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n-y90IC7v7vL" + }, + "source": [ + "\n", + "
\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZH0Pui5kFSdD" + }, + "source": [ + "Warp specialization can be enabled in Pallas by using the `plgpu.emit_pipeline_warp_specialized` helper. This pipeline helper handles all of the logic in the memory thread, and the user only needs to specify the work done in the compute threads. It shares the a similar API as the standard `emit_pipeline`, and currently supports the following arguments:\n", + "\n", + "```python\n", + "plgpu.emit_pipeline_warp_specialized(\n", + " body: Callable,\n", + " *\n", + " grid: tuple[int, ...],\n", + " in_specs: Sequence[pallas_core.BlockSpec] = (),\n", + " out_specs: Sequence[pallas_core.BlockSpec] = (),\n", + " max_concurrent_steps: int,\n", + " compute_context: Callable\n", + " num_compute_wgs: int,\n", + " memory_registers: int\n", + " wg_axis: str,\n", + " memory_thread_idx: int | None = None,\n", + ")\n", + "```\n", + "\n", + "There are a few arguments specific to this pipeline emitter, which are:\n", + "- `num_compute_wgs` specifies how many compute threads/warpgroups to use. The pipeline emitter always uses a single memory thread, so in `plgpu.kernel` you should specify `num_threads=num_compute_wgs+1`.\n", + "- `memory_registers` controls how many registers to allocate to the memory thread. The remaining registers are partitioned evenly among the compute threads. The default value is 40 and should be adjusted up or down depending on whether register spills are encountered.\n", + "- `wg_axis` the name of the thread/warpgroup axis (as specified by the `thead_name` argument of `plgpu.kernel`).\n", + "- `memory_thread_idx` specifies which Pallas thread to designate as the memory thread. Defaults to the last thread.\n", + "- `compute_context` is a enables you to specify a prologue/epilogue to the pipeline that only runs in the compute thread. The function allows you to define the initialization and consumption of a loop carry through the pipeline. All compute thread specific arrays should be instantiated here so the memory thread does not materialize them in registers -- otherwise, you may experience slowdowns due to register spills.\n", + "\n", + "The pipeline body of the warp specialized pipeline is run in parallel by all compute threads, and SMEM is shared between compute threads since they are scheduled within the same CUDA block.`lax.axis_index` can be used inside the kernel to obtain the Pallas thread index in order to divide up work amongst compute threads.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZGbK5gIvFZKy" + }, + "source": [ + "### Example: Matrix Multiplication with Warp Specialization\n", + "\n", + "The following example extends the previous matrix multiplication example to use warp specialization. This particular kernel uses 2 compute threads, which operate on separate columns of the RHS matrix but share the same LHS. Each invocation of the pipeline therefore computes 2 adjacent blocks in the output matrix.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NYWBqa9-bp2p" + }, + "source": [ + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OkWmfqn7b53M" + }, + "source": [ + "We use the `compute_context` pattern to initialize the WGMMA accumulator, and copy the final accumulator from registers into SMEM. Here, the compute context is defined in the function `compute_thread`. It is critical that the accumulator be created inside of the `compute_thread` function to avoid allocating it in the memory thread which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in a `pl.run_state` in order to create an accumulator ref that is initialized to the carry value.\n", + "\n", + "Instead of using `pl.pallas_call` to call the kernel, we instead use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of threads to launch per CUDA block via the `num_threads` argument, and allows us to specify a `thread_name` we can use to query the Pallas thread index inside of the kernel.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EJhWnwJlFGaT" + }, + "outputs": [], + "source": [ + "def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,\n", + " compute_wgs=2):\n", + " dtype = jnp.float16\n", + " elems_128b = swizzle // jnp.dtype(dtype).itemsize\n", + " tile_k = elems_128b\n", + " grid_m = m // tile_m\n", + " grid_k = k // tile_k\n", + " grid_n = n // tile_n\n", + " assert tile_m % elems_128b == 0\n", + "\n", + " transforms = (\n", + " plgpu.TilingTransform((8, elems_128b)),\n", + " plgpu.SwizzleTransform(128),\n", + " )\n", + "\n", + " def kernel(a_gmem, b_gmem, o_gmem, o_smem):\n", + " wg_idx = lax.axis_index(\"wg\")\n", + " wg_slice = pl.ds(wg_idx * tile_n, tile_n)\n", + " # pl.program_id obtains the index into the pallas_call grid.\n", + " pid_m = pl.program_id(0)\n", + " pid_n = pl.program_id(1)\n", + "\n", + " def compute_thread(pipeline):\n", + " acc = plgpu.layout_cast(\n", + " jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,\n", + " )\n", + " # yield marks the place where the pipelined loop will be inserted.\n", + " # Its argument are the initial carry values, and its result is the carry\n", + " # value after the loop completes.\n", + " final_acc = pipeline(acc)\n", + " o_smem[:, wg_slice] = final_acc[...].astype(dtype)\n", + "\n", + " def kernel_body(_, a_smem, b_smem, carry):\n", + " acc = carry\n", + " b_smem_wg = b_smem.at[:, wg_slice]\n", + " def do_wgmma(acc_ref):\n", + " plgpu.wgmma(acc_ref, a_smem, b_smem_wg)\n", + " acc = pl.run_state(do_wgmma)(\n", + " plgpu.ACC.init(acc))\n", + " return acc\n", + "\n", + " pipeline = plgpu.emit_pipeline_warp_specialized(\n", + " kernel_body,\n", + " in_specs=[\n", + " plgpu.BlockSpec(\n", + " (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms\n", + " ),\n", + " plgpu.BlockSpec(\n", + " (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms\n", + " ),\n", + " ],\n", + " grid=(grid_k,),\n", + " compute_context=compute_thread,\n", + " max_concurrent_steps=2,\n", + " num_compute_wgs=compute_wgs,\n", + " memory_registers=40,\n", + " memory_thread_idx=2,\n", + " wg_axis=\"wg\",\n", + " )\n", + " # Call the pipeline\n", + " pipeline(a_gmem, b_gmem)\n", + " # Copy the output from SMEM to GMEM.\n", + " plgpu.commit_smem()\n", + " m_slice = pl.ds(pid_m * tile_m, tile_m)\n", + " n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)\n", + " plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])\n", + " plgpu.wait_smem_to_gmem(0)\n", + "\n", + " return plgpu.kernel(\n", + " 
kernel,\n", + "      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),\n", + "      scratch_shapes=[\n", + "          plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)\n", + "      ],\n", + "      grid=(grid_m, grid_n // 2),\n", + "      grid_names=(\"m\", \"n\"),\n", + "      num_threads=3,  # 2 compute, 1 memory.\n", + "      thread_name=\"wg\"\n", + "  )(a, b)\n", + "\n", + "m = 132 * 128\n", + "n = 4 * 128\n", + "k = 10 * 64\n", + "key1, key2 = jax.random.split(jax.random.key(42), 2)\n", + "a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)\n", + "b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)\n", + "\n", + "result = matmul_warp_specialized(a, b)\n", + "\n", + "np.testing.assert_allclose(result, a @ b)" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab_gpu", + "kind": "private" + }, + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/gpu/pipelining.md b/docs/pallas/gpu/pipelining.md new file mode 100644 index 000000000000..a2b361f181e1 --- /dev/null +++ b/docs/pallas/gpu/pipelining.md @@ -0,0 +1,332 @@ +--- +jupyter: + jupytext: + formats: ipynb,md + main_language: python + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.16.4 + kernelspec: + display_name: Python 3 + name: python3 +--- + + +(pallas_mgpu_pipelining)= + + + +## Mosaic GPU Pipelining + +This guide covers software pipelining using the Mosaic GPU backend for Pallas. + +For a general overview of the pipelining API in Pallas, we recommend that users first read {ref}`pallas_software_pipelining`. Pipelining in Pallas is programmed explicitly. For those who are familiar with Triton, this is a significant difference in programming model because in Triton, pipelining is an optimization that is done automatically by the compiler. + + + +```python id="dGAa3iO5DoRT" +import jax +from jax import lax +from jax import numpy as jnp +from jax.experimental.pallas import mosaic_gpu as plgpu +from jax.experimental import pallas as pl +import numpy as np +``` + + + +### Pipelining with Mosaic GPU + +The recommended approach to pipeline using Mosaic GPU is to use the `plgpu.emit_pipeline` function to pipeline over sequential loops (and to use `plgpu.kernel` to partition the problem in parallel over the CUDA grid). `emit_pipeline` follows a similar API to `pl.pallas_call`, except that it exposes a few additional GPU-specific options. + +- `body`, `grid` have similar semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially. +- `in_specs` and `out_specs` also work similarly to `pl.pallas_call`, except they also accept `plgpu.BlockSpec` instances that can be used to specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on available transformations. +- `max_concurrent_steps` controls the maximum number of concurrent memory transfers. Using additional concurrent steps will consume more SMEM to hold temporary buffers, but it can improve the utilization of the memory subsystem. We recommend autotuning this parameter. Low values (e.g.
2) can sometimes achieve higher occupancy (due to lower SMEM usage), which can improve throughput in ALU-heavy kernels, but will introduce more noise due to the hardware taking care of scheduling. Larger values (between 4 and 6) will work best for kernels that can't take advantage of extra occupancy. +- `delay_release` allows the user to specify an additional number of iterations to wait before the buffer is re-used by the pipeline. For example, a buffer copied into SMEM on iteration 0 with `delay_release=1` and `max_concurrent_steps=2` will not be re-used until iteration 3, as opposed to iteration 2 for a standard double-buffered strategy. `delay_release=1` is necessary if you don't await a `plgpu.wgmma` operation on the pipeline operands, as otherwise the pipeline will begin overwriting the buffers while the WGMMA is still reading them. This is useful for certain optimizations such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled, but care must be taken when using such a strategy, as **omitting this parameter will lead to silent data races**, and it reduces the efficiency of `emit_pipeline` as we are overlapping fewer memory transfers. + +#### Compatibility API using `pl.pallas_call` + +As an alternative to `emit_pipeline` and to maintain compatibility with Pallas TPU, Mosaic GPU also implements the existing `pl.pallas_call` API. By default, `pl.pallas_call` on Mosaic GPU will partition your kernel in parallel over the CUDA grid. You can opt-in to pipelining by passing in a `plgpu.GPUCompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining: +- `dimension_semantics`: A tuple of `Literal['parallel', 'sequential']` that specifies iteration semantics for each grid dimension. `parallel` will partition the corresponding dimension over the CUDA grid, and `sequential` dimensions will be pipelined sequentially. **Note that if no dimensions are marked `sequential`, no pipelining will happen!** +- `max_concurrent_steps`: identical to the option in `plgpu.emit_pipeline`. +- `delay_release`: identical to the option in `plgpu.emit_pipeline`. + +Pipelining lets you re-use scratch buffers across the sequential iterations of the grid (e.g. for implementing reductions). Additionally, `pallas_call` supports using `plgpu.BlockSpec` objects in place of `pl.BlockSpec` objects when using the Mosaic GPU backend, allowing you to specify GPU-specific memory transformations. + +We recommend that users use `plgpu.kernel` rather than `pl.pallas_call`, as `plgpu.kernel` supports more features (such as specifying the number of warpgroups and warp specialization). + + + +### GPU Memory Spaces + +Refs exist primarily in one of two memory spaces, which can be explicitly specified by the `memory_space` argument of `BlockSpec`, i.e. `BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)`. + +- `plgpu.GPUMemorySpace.SMEM` allocates a Ref in Shared Memory (SMEM). SMEM Refs can be dereferenced using array indexing syntax to store values in registers for compute, i.e. `x = y_ref[...]`. This is the memory space used for a Ref when using `emit_pipeline`. + +- `plgpu.GPUMemorySpace.GMEM` allocates a Ref in Global Memory (GMEM/HBM). Any Refs allocated in GMEM are not pipelined, and values cannot be accessed directly using array indexing operations. Instead, GMEM must be accessed via SMEM, using `plgpu.copy_gmem_to_smem` for reading or `plgpu.copy_smem_to_gmem` for writing, or pipelined into SMEM using `plgpu.emit_pipeline`; a rough sketch of the manual staging pattern follows below.
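+
+As a rough sketch (an illustration rather than part of the patch; the `plgpu.Barrier` scratch and `plgpu.barrier_wait` call are assumptions based on the copy helpers named above), a kernel might stage data through SMEM manually like this:
+
+```python
+# Hypothetical kernel body: stage a GMEM block through SMEM by hand.
+def add_one_kernel(x_gmem, o_gmem, x_smem, barrier):
+  plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier)  # async GMEM -> SMEM copy
+  plgpu.barrier_wait(barrier)                       # block until the copy lands
+  x_smem[...] = x_smem[...] + 1                     # compute on the SMEM block
+  plgpu.commit_smem()                               # make SMEM writes visible to the copy engine
+  plgpu.copy_smem_to_gmem(x_smem, o_gmem)           # async SMEM -> GMEM copy
+  plgpu.wait_smem_to_gmem(0)                        # drain all outstanding copies
+```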
+
+The primary purpose of `emit_pipeline` is to overlap TensorCore computation with data transfers between GMEM and SMEM, since asynchronous copies between GMEM/SMEM have a long latency, but all TensorCore computation must operate on registers (or SMEM Refs in the case of matrix multiplication). + + + +### Example: Matmul Kernel on Hopper GPUs + + + +Let's begin with a matrix multiplication example designed to run on Hopper GPUs. This kernel utilizes the Hopper-specific `wgmma` (warpgroup matrix multiply accumulate) instruction. `wgmma` is issued by a single Mosaic GPU thread and runs asynchronously on the TensorCore. + +Our example kernel implements a blockwise matrix multiplication of two matrices of shape `[M, K] @ [K, N] = [M, N]`, where each output block is computed in parallel over the CUDA grid. This grid is specified as the `grid` argument to the outer `plgpu.kernel`, and parallelizes over the non-contracting dimensions M, N of the matrix multiplication. + + + + +
+ + + + + +Within a program instance, we run a sequential pipeline using `plgpu.emit_pipeline` that reduces over the contracting dimension K of the matrix multiplication. On each iteration of the pipeline, we load one tile from each input matrix, multiply them, and then store the result in an accumulator Ref (`plgpu.ACC`). `plgpu.ACC` is a special type of Ref that lives in registers and holds the intermediate results of WGMMA. Once we have accumulated over the entire contracting dimension, we write out the result to the output Ref. + +To perform the actual matrix multiplication, we call `plgpu.wgmma` with the accumulator, LHS, and RHS Refs as arguments in order to push the arguments into the TensorCore pipeline. All WGMMA operations are executed in order, so this can be viewed as pushing operations into a queue. Since `wgmma` is an asynchronous instruction, `plgpu.wgmma_wait(N)` is used to wait until there are no more than N `wgmma` operations left in-flight. In this particular implementation we wait for 1 in-flight WGMMA, meaning that the WGMMA we queue on the current iteration will be waited for on the next iteration. +- `wgmma` wants it's arguments to be in a specific format, defined in the [CUDA documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/#register-fragments-and-shared-memory-matrix-layouts). These are implemented by the `TilingTransform` and `SwizzleTransform` transformations on the input BlockSpecs. Note that in the future transforms will be inferred automatically by Mosaic GPU and these will not need to be manually specified. See the [wgmma reference](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#hopper-wgmma) for full details on using this instruction. +- We use the `delay_release` parameter in conjunction with `plgpu.wgmma_wait(1)` to always allow one `WGMMA` operation to stay in-flight in order to ensure good TensorCore utilization. Without this, we would be flushing the TensorCore pipeline on every iteration of the kernel. + + +```python id="6Vf5_VA9iCD1" +def matmul(a, b, tile_m=128, tile_n=128, swizzle=128): + dtype = jnp.float16 + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + tile_k = swizzle_elems + grid_m = m // tile_m + grid_k = k // tile_k + grid_n = n // tile_n + assert tile_m % swizzle_elems == 0 + + # Note: Transforms will be inferred automatically + # by Mosaic GPU in the future. + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc): + def pipeline_step(_, a_smem, b_smem): + plgpu.wgmma(acc, a_smem, b_smem) + plgpu.wgmma_wait(1) + + # pl.program_id obtains the index into the grid. + pid_m = pl.program_id(0) + pid_n = pl.program_id(1) + + pipeline = plgpu.emit_pipeline( + pipeline_step, + in_specs=[ + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.BlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms + ), + ], + grid=(grid_k,), + max_concurrent_steps=2, + delay_release=1, + ) + + pipeline(a_gmem, b_gmem) + # Store WGMMA accumulator to SMEM and then to GMEM. + o_smem[...] 
= acc[...].astype(dtype) +    plgpu.commit_smem() +    m_slice = pl.ds(pid_m * tile_m, tile_m) +    n_slice = pl.ds(pid_n * tile_n, tile_n) +    plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) +    plgpu.wait_smem_to_gmem(0) + +  return plgpu.kernel( +      kernel, +      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), +      scratch_shapes=[ +          plgpu.SMEM((tile_m, tile_n), jnp.float16), +          plgpu.ACC((tile_m, tile_n), jnp.float32) +      ], +      # grid specifies the CUDA grid. +      # Instances of `kernel` will be executed in parallel over this grid. +      grid=(grid_m, grid_n), +      grid_names=("m", "n"), +  )(a, b) + +m = 132 * 128 +n = 4 * 128 +k = 10 * 64 +key1, key2 = jax.random.split(jax.random.key(42), 2) +a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16) +b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + +result = matmul(a, b) + +np.testing.assert_allclose(result, a @ b) +``` + + +### Warp Specialization + +Warp specialization is a technique where we program each warp/warpgroup to perform a single task in order to give the GPU hardware the flexibility to schedule them at runtime. Recall that each streaming multiprocessor (SM) in a GPU contains warp schedulers that can swap execution between warps, so for example when one warp is stalling it can begin executing a different warp. In practice, this can be more performant than programming a single instruction stream where the compiler must statically schedule the operations and attempt to overlap them optimally. + +In particular, we are interested in warpgroup specialization on Hopper+ GPUs, where it can be useful to have a separate warpgroup issuing TMAs (GMEM/SMEM copies) from the warpgroups performing arithmetic, since indexing calculations and issuing TMAs can take up a significant amount of time and potentially leave the TensorCore idle. The figure below depicts a standard, non-specialized kernel on the left where TMAs (async copies) and matrix multiplication are issued from a single instruction stream, and a warp-specialized version on the right where communication and arithmetic are handled on separate warpgroups. A *consumed barrier* is used to synchronize the specialized warpgroups: it signals to the memory warpgroup when it is safe to begin the next TMA. + + + + + + +
+ + + + + +Warp specialization can be enabled in Pallas by using the `plgpu.emit_pipeline_warp_specialized` helper. This pipeline helper handles all of the logic in the memory thread, and the user only needs to specify the work done in the compute threads. It shares the a similar API as the standard `emit_pipeline`, and currently supports the following arguments: + +```python +plgpu.emit_pipeline_warp_specialized( + body: Callable, + * + grid: tuple[int, ...], + in_specs: Sequence[pallas_core.BlockSpec] = (), + out_specs: Sequence[pallas_core.BlockSpec] = (), + max_concurrent_steps: int, + compute_context: Callable + num_compute_wgs: int, + memory_registers: int + wg_axis: str, + memory_thread_idx: int | None = None, +) +``` + +There are a few arguments specific to this pipeline emitter, which are: +- `num_compute_wgs` specifies how many compute threads/warpgroups to use. The pipeline emitter always uses a single memory thread, so in `plgpu.kernel` you should specify `num_threads=num_compute_wgs+1`. +- `memory_registers` controls how many registers to allocate to the memory thread. The remaining registers are partitioned evenly among the compute threads. The default value is 40 and should be adjusted up or down depending on whether register spills are encountered. +- `wg_axis` the name of the thread/warpgroup axis (as specified by the `thead_name` argument of `plgpu.kernel`). +- `memory_thread_idx` specifies which Pallas thread to designate as the memory thread. Defaults to the last thread. +- `compute_context` is a enables you to specify a prologue/epilogue to the pipeline that only runs in the compute thread. The function allows you to define the initialization and consumption of a loop carry through the pipeline. All compute thread specific arrays should be instantiated here so the memory thread does not materialize them in registers -- otherwise, you may experience slowdowns due to register spills. + +The pipeline body of the warp specialized pipeline is run in parallel by all compute threads, and SMEM is shared between compute threads since they are scheduled within the same CUDA block.`lax.axis_index` can be used inside the kernel to obtain the Pallas thread index in order to divide up work amongst compute threads. + + + + +### Example: Matrix Multiplication with Warp Specialization + +The following example extends the previous matrix multiplication example to use warp specialization. This particular kernel uses 2 compute threads, which operate on separate columns of the RHS matrix but share the same LHS. Each invocation of the pipeline therefore computes 2 adjacent blocks in the output matrix. + + + + + +
+ + + + +We use the `compute_context` pattern to initialize the WGMMA accumulator, and copy the final accumulator from registers into SMEM. Here, the compute context is defined in the function `compute_thread`. It is critical that the accumulator be created inside of the `compute_thread` function to avoid allocating it in the memory thread which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in a `pl.run_state` in order to create an accumulator ref that is initialized to the carry value. + +Instead of using `pl.pallas_call` to call the kernel, we instead use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of threads to launch per CUDA block via the `num_threads` argument, and allows us to specify a `thread_name` we can use to query the Pallas thread index inside of the kernel. + + + +```python id="EJhWnwJlFGaT" +def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128, + compute_wgs=2): + dtype = jnp.float16 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + tile_k = elems_128b + grid_m = m // tile_m + grid_k = k // tile_k + grid_n = n // tile_n + assert tile_m % elems_128b == 0 + + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + + def kernel(a_gmem, b_gmem, o_gmem, o_smem): + wg_idx = lax.axis_index("wg") + wg_slice = pl.ds(wg_idx * tile_n, tile_n) + # pl.program_id obtains the index into the pallas_call grid. + pid_m = pl.program_id(0) + pid_n = pl.program_id(1) + + def compute_thread(pipeline): + acc = plgpu.layout_cast( + jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + # yield marks the place where the pipelined loop will be inserted. + # Its argument are the initial carry values, and its result is the carry + # value after the loop completes. + final_acc = pipeline(acc) + o_smem[:, wg_slice] = final_acc[...].astype(dtype) + + def kernel_body(_, a_smem, b_smem, carry): + acc = carry + b_smem_wg = b_smem.at[:, wg_slice] + def do_wgmma(acc_ref): + plgpu.wgmma(acc_ref, a_smem, b_smem_wg) + acc = pl.run_state(do_wgmma)( + plgpu.ACC.init(acc)) + return acc + + pipeline = plgpu.emit_pipeline_warp_specialized( + kernel_body, + in_specs=[ + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.BlockSpec( + (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms + ), + ], + grid=(grid_k,), + compute_context=compute_thread, + max_concurrent_steps=2, + num_compute_wgs=compute_wgs, + memory_registers=40, + memory_thread_idx=2, + wg_axis="wg", + ) + # Call the pipeline + pipeline(a_gmem, b_gmem) + # Copy the output from SMEM to GMEM. + plgpu.commit_smem() + m_slice = pl.ds(pid_m * tile_m, tile_m) + n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2) + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + return plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + scratch_shapes=[ + plgpu.SMEM((tile_m, tile_n * 2), jnp.float16) + ], + grid=(grid_m, grid_n // 2), + grid_names=("m", "n"), + num_threads=3, # 2 compute, 1 memory. 
+ thread_name="wg" + )(a, b) + +m = 132 * 128 +n = 4 * 128 +k = 10 * 64 +key1, key2 = jax.random.split(jax.random.key(42), 2) +a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16) +b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + +result = matmul_warp_specialized(a, b) + +np.testing.assert_allclose(result, a @ b) +``` diff --git a/docs/pallas/pipelining.ipynb b/docs/pallas/pipelining.ipynb index 6770351d7760..6a4158001813 100644 --- a/docs/pallas/pipelining.ipynb +++ b/docs/pallas/pipelining.ipynb @@ -13,7 +13,7 @@ "\n", "Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n", "\n", - "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or GPU (coming soon!) specific pipelining references.\n" + "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`.\n" ] }, { @@ -38,7 +38,7 @@ "source": [ "## Memory Hierarchies\n", "\n", - "The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capicity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:\n", + "The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capacity vs latency/bandwidth. 
For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:\n", "- **Registers** are the the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.\n", "- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.\n", "SRAM on modern ML accelerators typically range in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).\n", @@ -276,7 +276,7 @@ "    # Itr 4 - No copy-in\n", "    copy_in_wait(X[1])\n", "    Y[1] = X[1] + 1\n", -    "    copy_out_start(Y[1], A[2])\n", +    "    copy_out_start(Y[1], A[3])\n", "    copy_out_wait(Y[1])\n", "\n", "\n", @@ -309,7 +309,7 @@ "    # Itr 4 - No copy-in\n", "    copy_in_wait(X[1])\n", "    Y[1] = X[1] + 1\n", -    "    copy_out_start(Y[1], A[2])\n", +    "    copy_out_start(Y[1], A[3])\n", "    copy_out_wait(Y[0])\n", "\n", "  # Epilogue\n", @@ -362,18 +362,19 @@ "  for i in range(grid_size):\n", "    cur_slot = i % 2\n", "    next_slot = (i + 1) % 2\n", -    "    if i < grid_size:\n", -    "      copy_in_start(in_hbm[data_slices(i+1)], in_sram[next_slot])\n", +    "    if (i + 1) < grid_size:\n", +    "      copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])\n", "    copy_in_wait(in_sram[cur_slot])\n", "\n", -    "    kernel(inputs, outputs)\n", +    "    kernel(in_sram[cur_slot], out_sram[cur_slot])\n", "\n", "    copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])\n", "    if i > 0:\n", "      copy_out_wait(out_sram[next_slot])\n", "\n", "  # Epilogue\n", -    "  copy_out_wait(out_sram[1])\n", +    "  last_slot = (grid_size - 1) % 2\n", +    "  copy_out_wait(out_sram[last_slot])\n", "```" ] }, @@ -700,7 +701,7 @@ "source": [ "This result is completely wrong!\n", "\n", -    "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` is initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.\n", +    "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.\n", "\n", "After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`." ] }, @@ -853,6 +854,7 @@ "provenance": [] }, "jupytext": { +   "formats": "ipynb,md", "main_language": "python" }, "kernelspec": { diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md index 0ef407b4ee27..2bf21f0d8c27 100644 --- a/docs/pallas/pipelining.md +++ b/docs/pallas/pipelining.md @@ -1,6 +1,7 @@ --- jupyter: jupytext: +    formats: ipynb,md main_language: python text_representation: extension: .md @@ -20,7 +21,7 @@ jupyter: Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive.
Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API. -This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or GPU (coming soon!) specific pipelining references. +This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`. @@ -63,7 +64,7 @@ In order to perform computation on values X and Y that live in HBM, we need to: Let’s implement a Pallas function that does just that! -```python id="IrPhDFnT3Nvw" executionInfo={"status": "ok", "timestamp": 1744764235906, "user_tz": 420, "elapsed": 108, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4" +```python executionInfo={"elapsed": 108, "status": "ok", "timestamp": 1744764235906, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="IrPhDFnT3Nvw" outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4" # Note: This is a TPU example. def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref): @@ -481,7 +482,7 @@ As a concrete example, let's consider performing the following computation for r -```python id="4qz1ET-_f9fJ" executionInfo={"status": "ok", "timestamp": 1744763773938, "user_tz": 420, "elapsed": 244, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="e43067ef-933a-45a5-912a-e224151cfa60" +```python executionInfo={"elapsed": 244, "status": "ok", "timestamp": 1744763773938, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="4qz1ET-_f9fJ" outputId="e43067ef-933a-45a5-912a-e224151cfa60" x = jnp.ones((8, 1024, 1024)) jnp.sum(x, axis=0) ``` @@ -490,7 +491,7 @@ jnp.sum(x, axis=0) To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first. -```python id="ZEi1_vQVf-81" executionInfo={"status": "ok", "timestamp": 1744763774254, "user_tz": 420, "elapsed": 79, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="581744b7-ddc1-4dc1-98ec-03c852772eda" +```python executionInfo={"elapsed": 79, "status": "ok", "timestamp": 1744763774254, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="ZEi1_vQVf-81" outputId="581744b7-ddc1-4dc1-98ec-03c852772eda" # Note: This is a TPU example. # Warning: this implementation is incorrect! @@ -522,7 +523,7 @@ There are two errors inside this kernel. First, we are accumulating along the fi After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`. 
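For reference, the initialization pattern described in this paragraph looks roughly as follows; this is a sketch under the assumption that the reduction dimension is the last axis of a 3D grid, not a copy of the corrected cell itself:

```python
# Rough sketch of the corrected reduction kernel described above.
# Assumes a 3D grid whose last axis (axis 2) is the reduction axis.
def correct_sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    # Zero the output block when beginning a new accumulation.
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...]
```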
-```python id="XtgD4nMa9_Bd" executionInfo={"status": "ok", "timestamp": 1744763774523, "user_tz": 420, "elapsed": 104, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="9ef07cdf-9e22-4dc8-c17f-c96172639801" +```python executionInfo={"elapsed": 104, "status": "ok", "timestamp": 1744763774523, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="XtgD4nMa9_Bd" outputId="9ef07cdf-9e22-4dc8-c17f-c96172639801" # Note: This is a TPU example. def correct_sum_kernel(x_ref, o_ref): From a47ae57356bb19ac127b0c3c674417b44cd70b6c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas <jakevdp@google.com> Date: Wed, 18 Jun 2025 15:50:38 -0700 Subject: [PATCH 1754/1769] Add wrap_negative_indices parameter to jnp.ndarray.at[] --- jax/_src/basearray.pyi | 32 ++++-- jax/_src/numpy/array_methods.py | 166 ++++++++++++++++++------------- jax/_src/numpy/indexing.py | 22 ++-- tests/lax_numpy_indexing_test.py | 124 ++++++++++++++++++++--- 4 files changed, 242 insertions(+), 102 deletions(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index d92bd61f23e2..54098a081f39 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -286,25 +286,35 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None, -           out_sharding: Sharding | P | None = None) -> Array: ... +           out_sharding: Sharding | P | None = None, wrap_negative_indices: bool = True) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, -          mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... +          mode: str | None = None, fill_value: StaticScalar | None = None, +          wrap_negative_indices: bool = True) -> Array: ... def add(self, values: Any, indices_are_sorted: bool = False, -          unique_indices: bool = False, mode: str | None = None) -> Array: ... +          unique_indices: bool = False, mode: str | None = None, +          wrap_negative_indices: bool = True) -> Array: ... def subtract(self, values: Any, *, indices_are_sorted: bool = False, -               unique_indices: bool = False, mode: str | None = None) -> Array: ... +               unique_indices: bool = False, mode: str | None = None, +               wrap_negative_indices: bool = True) -> Array: ... def mul(self, values: Any, indices_are_sorted: bool = False, -          unique_indices: bool = False, mode: str | None = None) -> Array: ... +          unique_indices: bool = False, mode: str | None = None, +          wrap_negative_indices: bool = True) -> Array: ... def multiply(self, values: Any, indices_are_sorted: bool = False, -               unique_indices: bool = False, mode: str | None = None) -> Array: ... +               unique_indices: bool = False, mode: str | None = None, +               wrap_negative_indices: bool = True) -> Array: ... def divide(self, values: Any, indices_are_sorted: bool = False, -             unique_indices: bool = False, mode: str | None = None) -> Array: ... +             unique_indices: bool = False, mode: str | None = None, +             wrap_negative_indices: bool = True) -> Array: ... def power(self, values: Any, indices_are_sorted: bool = False, -            unique_indices: bool = False, mode: str | None = None) -> Array: ... +            unique_indices: bool = False, mode: str | None = None, +            wrap_negative_indices: bool = True) -> Array: ... def min(self, values: Any, indices_are_sorted: bool = False, -          unique_indices: bool = False, mode: str | None = None) -> Array: ... +          unique_indices: bool = False, mode: str | None = None,
def max(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 2de9822206a2..27a1ddce7685 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -695,10 +695,8 @@ class _IndexUpdateHelper: By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the ``mode`` parameter (see below). - Arguments - --------- - mode : str - Specify out-of-bound indexing mode. Options are: + Args: + mode: string specifying out-of-bound indexing mode. Options are: - ``"promise_in_bounds"``: (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that @@ -709,40 +707,56 @@ class _IndexUpdateHelper: - ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value`` argument specifies the value that will be returned. - See :class:`jax.lax.GatherScatterMode` for more details. - - indices_are_sorted : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are sorted in ascending order, which can lead to more efficient execution - on some backends. - unique_indices : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are unique, which can result in more efficient execution on some backends. - fill_value : Any - Only applies to the ``get()`` method: the fill value to return for out-of-bounds - slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for - inexact types, the largest negative value for signed types, the largest positive - value for unsigned types, and ``True`` for booleans. - - Examples - -------- - >>> x = jnp.arange(5.0) - >>> x - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[2].add(10) - Array([ 0., 1., 12., 3., 4.], dtype=float32) - >>> x.at[10].add(10) # out-of-bounds indices are ignored - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[20].add(10, mode='clip') - Array([ 0., 1., 2., 3., 14.], dtype=float32) - >>> x.at[2].get() - Array(2., dtype=float32) - >>> x.at[20].get() # out-of-bounds indices clipped - Array(4., dtype=float32) - >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN - Array(nan, dtype=float32) - >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value - Array(-1., dtype=float32) + See :class:`jax.lax.GatherScatterMode` for more details. + wrap_negative_indices: If True (default) then negative indices indicate position + from the end of the array, similar to Python and NumPy indexing. If False, then + negative indices are considered out-of-bounds and behave according to the + ``mode`` parameter. + fill_value: Only applies to the ``get()`` method: the fill value to return for + out-of-bounds slices when ``mode`` is ``'fill'``. Ignored otherwise. Defaults + to ``NaN`` for inexact types, the largest negative value for signed types, the + largest positive value for unsigned types, and ``True`` for booleans. 
+ indices_are_sorted: If True, the implementation will assume that the (normalized) + indices passed to ``at[]`` are sorted in ascending order, which can lead to more + efficient execution on some backends. If True but the indices are not actually + sorted, the output is undefined. + unique_indices: If True, the implementation will assume that the (normalized) indices + passed to ``at[]`` are unique, which can result in more efficient execution on some + backends. If True but the indices are not actually unique, the output is undefined. + + Examples: + >>> x = jnp.arange(5.0) + >>> x + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[2].get() + Array(2., dtype=float32) + >>> x.at[2].add(10) + Array([ 0., 1., 12., 3., 4.], dtype=float32) + + By default, out-of-bound indices are ignored in updates, but this behavior + can be controlled with the ``mode`` parameter: + + >>> x.at[10].add(10) # dropped + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[20].add(10, mode='clip') # clipped + Array([ 0., 1., 2., 3., 14.], dtype=float32) + + For ``get()``, out-of-bound indices are clipped by default: + + >>> x.at[20].get() # out-of-bounds indices clipped + Array(4., dtype=float32) + >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN + Array(nan, dtype=float32) + >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value + Array(-1., dtype=float32) + + Negative indices count from the end of the array, but this behavior can + be disabled by setting ``wrap_negative_indices = False``: + + >>> x.at[-1].set(99) + Array([ 0., 1., 2., 3., 99.], dtype=float32) + >>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop') # dropped! + Array([0., 1., 2., 3., 4.], dtype=float32) """ __slots__ = ("array",) @@ -780,7 +794,8 @@ def __repr__(self) -> str: def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | lax.GatherScatterMode | None = None, fill_value: ArrayLike | None = None, - out_sharding: Sharding | PartitionSpec | None = None): + out_sharding: Sharding | PartitionSpec | None = None, + wrap_negative_indices: bool = True): """Equivalent to ``x[idx]``. Returns the value of ``x`` that would result from the NumPy-style @@ -788,7 +803,7 @@ def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, the usual array indexing syntax in that it allows additional keyword arguments ``indices_are_sorted`` and ``unique_indices`` to be passed. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ if out_sharding is not None: assert isinstance(out_sharding, (NamedSharding, PartitionSpec)) @@ -797,17 +812,19 @@ def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, fill_value=fill_value, + normalize_indices=wrap_negative_indices, out_sharding=out_sharding) def set(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> None: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> None: """Pure equivalent of ``x[idx] = y``. Returns the value of ``x`` that would result from the NumPy-style :mod:`indexed assignment ` ``x[idx] = y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. 
""" out_s = core.typeof(self.array).sharding if out_s.mesh.empty or out_s.mesh._are_all_axes_auto_or_manual: @@ -815,11 +832,12 @@ def set(self, values: ArrayLike, *, indices_are_sorted: bool = False, return scatter._scatter_update(self.array, self.index, values, lax.scatter, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, - out_sharding=out_s) + out_sharding=out_s, normalize_indices=wrap_negative_indices) def apply(self, func: Callable[[ArrayLike], Array], *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. Returns the value of ``x`` that would result from applying the unary @@ -831,7 +849,7 @@ def apply(self, func: Callable[[ArrayLike], Array], *, Note that in the current implementation, ``scatter_apply`` is not compatible with automatic differentiation. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ def _scatter_apply(x, indices, y, dims, **kwargs): return lax.scatter_apply(x, indices, func, dims, update_shape=y.shape, **kwargs) @@ -839,120 +857,134 @@ def _scatter_apply(x, indices, y, dims, **kwargs): lax_internal._zero(self.array), _scatter_apply, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) def add(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] += y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] += y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, lax.scatter_add, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) def subtract(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] -= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] -= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, lax.scatter_sub, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) def multiply(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] *= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] *= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. 
""" return scatter._scatter_update(self.array, self.index, values, lax.scatter_mul, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) + mode=mode, normalize_indices=wrap_negative_indices) mul = multiply def divide(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] /= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] /= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return ufuncs.divide( self.array, scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, lax.scatter_mul, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices)) def power(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] **= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] **= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return ufuncs.power( self.array, scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, lax.scatter_mul, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices)) def min(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = minimum(x[idx], y)``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, lax.scatter_min, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) def max(self, values: ArrayLike, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | lax.GatherScatterMode | None = None) -> Array: + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = maximum(x[idx], y)``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. 
""" return scatter._scatter_update(self.array, self.index, values, lax.scatter_max, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) _array_operators = { "getitem": _getitem, diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 31ad0ba4ed86..934246dc8cbd 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -620,17 +620,19 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None, def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, - mode=None, fill_value=None, out_sharding=None): + mode=None, fill_value=None, normalize_indices=True, + out_sharding=None): # Computes arr[idx]. # All supported cases of indexing can be implemented as an XLA gather, # followed by an optional reverse and broadcast_in_dim. - # For simplicity of generated primitives, we call lax.dynamic_slice in the - # simplest cases: i.e. non-dynamic arrays indexed with integers and slices. - - result = _attempt_rewriting_take_via_slice(arr, idx, mode, out_sharding) - if result is not None: - return result + # For simplicity of generated primitives, we call lax.slice or lax.dynamic_slice + # in the simplest cases: i.e. non-dynamic arrays indexed with integers and slices. + # TODO(jakevdp): lower to slice even when normalize_indices is False + if normalize_indices: + result = _attempt_rewriting_take_via_slice(arr, idx, mode, out_sharding) + if result is not None: + return result # TODO(mattjj,dougalm): expand dynamic shape indexing support if config.dynamic_shapes.value and arr.ndim > 0: @@ -647,7 +649,7 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, internal_gather = partial( _gather, treedef=treedef, static_idx=static_idx, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode, fill_value=fill_value) + mode=mode, fill_value=fill_value, normalize_indices=normalize_indices) if out_sharding is not None: return auto_axes(internal_gather, out_sharding=out_sharding )(arr, dynamic_idx) @@ -658,9 +660,9 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). 
# @partial(jit, static_argnums=(1, 2)) def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted, - unique_indices, mode, fill_value): + unique_indices, mode, fill_value, normalize_indices): idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update + indexer = index_to_gather(np.shape(arr), idx, normalize_indices=normalize_indices) # shared with _scatter_update jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 745cab59cf1b..6364137cc1c6 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1139,6 +1139,47 @@ def testStrIndexingError(self): with self.assertRaisesRegex(TypeError, msg): jnp.zeros((2, 3))[:, 'abc'] + @jtu.sample_product( + mode=["promise_in_bounds", "fill", "clip", "drop"], + wrap_negative_indices=[True, False], + shape=[(5,), (10,)], + idx_shape=[(5,)], + ) + def testWrapNegativeIndices1D(self, mode, wrap_negative_indices, shape, idx_shape): + """Test the behavior of the wrap_negative_indices parameter in array.at[...].get()""" + fill_value = 99 + + data_rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_uniform(self.rng(), low=-12, high=12) + + args_maker = lambda: [data_rng(shape, 'float32'), idx_rng(idx_shape, 'int32')] + + def jnp_fun(data, idx): + return jnp.array(data).at[idx].get( + mode=mode, + fill_value=fill_value, + wrap_negative_indices=wrap_negative_indices) + + def np_fun(data, idx): + if wrap_negative_indices: + idx = np.where(idx < 0, idx + len(data), idx) + out_of_bound = (idx < 0) | (idx >= len(data)) + safe_idx = np.where(out_of_bound, 0, idx) + result = data[safe_idx] + if mode in ["fill", "drop"]: + result = np.where(out_of_bound, fill_value, result) + elif mode in ["promise_in_bounds", "clip"]: + result = np.where(idx < 0, data[0], + np.where(idx >= len(data), data[-1], + result)) + else: + raise ValueError(f"Unrecognized mode {mode!r}") + return result + + tol = 1E-4 if jtu.test_device_matches(["tpu"]) else None + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) + def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10]) @@ -1291,22 +1332,30 @@ class UpdateOps(enum.Enum): def np_fn(op, indexer, x, y): x = x.copy() - x[indexer] = { - UpdateOps.UPDATE: lambda: y, - UpdateOps.ADD: lambda: x[indexer] + y, - UpdateOps.SUB: lambda: x[indexer] - y, - UpdateOps.MUL: lambda: x[indexer] * y, - UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] / y.astype(x.dtype)), - UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] ** y.astype(x.dtype)), - UpdateOps.MIN: lambda: np.minimum(x[indexer], y), - UpdateOps.MAX: lambda: np.maximum(x[indexer], y), - }[op]() + if op == UpdateOps.UPDATE: + x[indexer] = y + elif op == UpdateOps.ADD: + np.add.at(x, indexer, y) + elif op == UpdateOps.SUB: + np.subtract.at(x, indexer, y) + elif op == UpdateOps.MUL: + np.multiply.at(x, indexer, y) + elif op == UpdateOps.DIV: + with jtu.ignore_warning(category=RuntimeWarning): + np.divide.at(x, indexer, y) + elif op == UpdateOps.POW: + with jtu.ignore_warning(category=RuntimeWarning): + np.power.at(x, indexer, y) + elif op == UpdateOps.MIN: + np.minimum.at(x, indexer, y.astype(x.dtype)) + elif op == UpdateOps.MAX: + 
np.maximum.at(x, indexer, y.astype(x.dtype)) + else: + raise ValueError(f"{op=}") return x def jax_fn(op, indexer, x, y, indices_are_sorted=False, - unique_indices=False, mode=None): + unique_indices=False, mode=None, wrap_negative_indices=True): x = jnp.array(x) return { UpdateOps.UPDATE: x.at[indexer].set, @@ -1318,7 +1367,8 @@ def jax_fn(op, indexer, x, y, indices_are_sorted=False, UpdateOps.MIN: x.at[indexer].min, UpdateOps.MAX: x.at[indexer].max, }[op](y, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + wrap_negative_indices=wrap_negative_indices) def dtypes(op): if op == UpdateOps.UPDATE: @@ -1431,6 +1481,52 @@ def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape, self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) self._CompileAndCheck(jax_fn, args_maker) + @jtu.sample_product( + op=UpdateOps, + mode=["fill", "clip"], + wrap_negative_indices=[True, False], + shape=[(5,), (10,)], + update_shape=[(5,)], + ) + def testWrapNegativeIndices1D(self, op, mode, wrap_negative_indices, shape, update_shape): + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_unique_int(self.rng(), high=shape[0]) + + def args_maker(): + data = rng(shape, 'float32').round(1) + update = rng(update_shape, 'float32').round(1) + # we need indices to be unique, so we generate unique values in [0, N) + # and then subtract N from half of them. To test out-of-bound behavior + # we push the bottom and top index out-of-bounds + idx = idx_rng(update_shape, 'int32') + idx = np.where(rng(update_shape, bool), idx, idx - shape[0]) + idx[idx == shape[0] - 1] = shape[0] + 2 # out-of-bound positive + idx[idx == -shape[0]] = -(shape[0] + 2) # out-of-bound negative + return data, idx, update + + def jnp_fun(data, idx, values): + return UpdateOps.jax_fn(op, idx, data, values, + mode=mode, + wrap_negative_indices=wrap_negative_indices) + + def np_fun(data, idx, values): + if wrap_negative_indices: + idx = np.where(idx < 0, idx + len(data), idx) + if mode in ["fill", "drop", "promise_in_bounds"]: + ok = (idx >= 0) & (idx < len(data)) + idx = idx[ok] + values = values[ok] + elif mode == "clip": + idx = np.where(idx < 0, 0, idx) + idx = np.where(idx >= len(data), len(data) - 1, idx) + else: + raise ValueError(f"Unrecognized mode {mode!r}") + return UpdateOps.np_fn(op, idx, data, values) + + tol = 1E-4 if jtu.test_device_matches(["tpu"]) else None + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) + @jtu.sample_product( [dict(name=name, mode=mode, shape=shape, indexer=indexer, update_shape=update_shape) From 9d1b01e01fa5c0da0a8b7dbc7af1794a141be614 Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Wed, 18 Jun 2025 15:53:35 -0700 Subject: [PATCH 1755/1769] [JAX] Skip failing tpu tests until June 30th. PiperOrigin-RevId: 773094489 --- tests/pallas/tpu_pallas_async_test.py | 4 ++-- tests/pallas/tpu_pallas_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index f224987fe2a7..68c51b63c183 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -205,7 +205,7 @@ def setUp(self): if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work ou TPU v4+') # TODO(subhankarshah): Remove after all required changes are in. 
- if not jtu.if_cloud_tpu_at_least(2025, 6, 20): + if not jtu.if_cloud_tpu_at_least(2025, 6, 30): self.skipTest('Requires libtpu built after 2025-06-20') def test_basic_async_copy(self): @@ -834,7 +834,7 @@ def setUp(self): if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work ou TPU v4+') # TODO(subhankarshah): Remove after all required changes are in. - if not jtu.if_cloud_tpu_at_least(2025, 6, 20): + if not jtu.if_cloud_tpu_at_least(2025, 6, 30): self.skipTest('Requires libtpu built after 2025-06-20') def test_basic_stateful_async_copy(self): diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index cbb82efaee3f..f7d076965fd3 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1225,7 +1225,7 @@ def test_output_dma_semaphore_ref(self): self.skipTest('TODO(sharadmv, justinfu): Add interpret support for DMA.') # TODO(subhankarshah): Remove after all required changes are in. - if not jtu.if_cloud_tpu_at_least(2025, 6, 20): + if not jtu.if_cloud_tpu_at_least(2025, 6, 30): self.skipTest('Requires libtpu built after 2025-06-20') def kernel(x_hbm_ref, y_hbm_ref, sem_out): From d1883667df7b01ffe31282c26a4857e113a7e49f Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Wed, 18 Jun 2025 23:09:43 +0000 Subject: [PATCH 1756/1769] add cudnn sdpa mla support --- jax/_src/cudnn/fused_attention_stablehlo.py | 75 +++++++++++---------- tests/fused_attention_stablehlo_test.py | 61 +++++++++++++++-- 2 files changed, 98 insertions(+), 38 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 417198e575d2..f246ff4b7aa7 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -319,7 +319,8 @@ def check_eq(a, b, c, msg): f"got {v_blocks} vs {vB * v_blocks_per_batch}") check_eq(qB, kB, vB, "QKV batch") - check_eq(qH, kH, vH, "QKV dim_per_head") + if qH != kH: + raise ValueError(f"QK must have same head dim, got {qH} vs {kH}") if kN != vN: raise ValueError(f"KV must have same number of heads, got {kN} vs {vN}") if kS != vS: @@ -353,33 +354,35 @@ def check_seqlen_offsets(tensor, name): def check_is_flash_attention( - query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False, - is_paged_attention=False, is_fp8=False): + query, key, value, layout: int, cudnn_version, has_bias, is_training, + is_packed=False, is_paged_attention=False, is_fp8=False): # Extract sequence length (T) and head dim (H) based on layout if layout == AttentionLayout.BNTH.value: - _, _, T, H = query.shape - _, _, S, _ = key.shape + _, _, T, qH = query.shape + _, _, S, vH = value.shape else: - _, T, _, H = query.shape - _, S, _, _ = key.shape + _, T, _, qH = query.shape + _, S, _, vH = value.shape # Flash attention conditions if is_fp8: # FP8 specific conditions - if not ((is_training and H == 128 and T % 128 == 0 and S % 128 == 0) or - (not is_training and H <= 256 and H % 16 == 0)): + if not ((is_training and qH == 128 and T % 128 == 0 and S % 128 == 0) or + (not is_training and qH <= 256 and qH % 16 == 0)): raise NotImplementedError( - f"Unsupported sequence length Q {T}, KV {S} and head dim {H} for FP8." + f"Unsupported sequence length Q {T}, KV {S} and head dim {qH} for FP8." ) else: # bf16/fp16 attention conditions # Check the head dim. 
is_on_hopper = is_cuda_compute_capability_equal("9.0") H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128 - if not (H <= H_max and H % 8 == 0): + # check if multi-head latent attention is needed + is_mla = qH != vH + if not (qH <= H_max and qH % 8 == 0): raise NotImplementedError( f"The head dim must be <= {H_max} and a multiple of 8, " - f"but got {H}." + f"but got {qH}." ) # Check patterns with bias, seqlen should be divisible by 2 @@ -393,6 +396,9 @@ def check_is_flash_attention( "Packed layout requires cudnn version >= 9.6 and at least hopper arch.") if is_paged_attention and cudnn_version < 90500: raise NotImplementedError("Page attention requires cudnn version >= 9.5.") + if is_mla and (cudnn_version < 91000 or not check_compute_capability("9.0")): + raise NotImplementedError( + "mla requires cudnn version >= 9.10 and at least hopper arch.") def check_cudnn_version(): # check if cuDNN is installed @@ -423,7 +429,7 @@ def _dot_product_attention_fwd( sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( - query, key, layout, cudnn_version, bias is not None, False, + query, key, value, layout, cudnn_version, bias is not None, False, get_max_seg_per_batch(q_offsets) > 1, check_is_paged_attention(page_table_k)) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, @@ -442,7 +448,7 @@ def _dot_product_attention_fwd_rule( return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( - query, key, layout, cudnn_version, bias is not None, True, + query, key, value, layout, cudnn_version, bias is not None, True, get_max_seg_per_batch(q_offsets) > 1) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, @@ -567,11 +573,12 @@ def _dot_product_attention_fwd_abstract( query_dtype = dtypes.canonicalize_dtype(query.dtype) if layout == AttentionLayout.BNTH.value: B, N, T, _ = query.shape - _, _, S, _ = key.shape + _, _, S, H = value.shape + output_shape = (B, N, T, H) else: B, T, N, _ = query.shape - _, S, _, _ = key.shape - output_shape = query.shape + _, S, _, H = value.shape + output_shape = (B, T, N, H) max_seg_per_batch = get_max_seg_per_batch(q_offsets) softmax_stat_shape = (B * max_seg_per_batch, N, T) @@ -631,24 +638,24 @@ def _dot_product_attention_fwd_cuda_lowering( variadic_args, mask_type, layout, sliding_window_length, is_training): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape - key_type = ir.RankedTensorType(key.type) - key_shape = key_type.shape + value_type = ir.RankedTensorType(value.type) + value_shape = value_type.shape if layout == AttentionLayout.BNTH.value: - B, N, T, H = query_shape - _, _, S, _ = key_shape + B, N, T, qk_H = query_shape + _, _, S, v_H = value_shape output_layout = (3, 2, 1, 0) output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) else: - B, T, N, H = query_shape - _, S, _, _ = key_shape + B, T, N, qk_H = query_shape + _, S, _, v_H = value_shape output_layout = (3, 1, 2, 0) output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type)) is_paged_attention = check_is_paged_attention(ir.RankedTensorType(page_table_k.type)) - output_shape = (B, N, T, H) + output_shape = (B, N, T, v_H) softmax_stat_shape = (B * max_seg_per_batch, N, T) workspace_shape = (0,) 
 workspace_type = ir.IntegerType.get_unsigned(8)
@@ -713,26 +720,26 @@ def _dot_product_attention_bwd_cuda_lowering(
   query_type = ir.RankedTensorType(query.type)
   query_shape = query_type.shape
   key_type = ir.RankedTensorType(key.type)
-  key_shape = key_type.shape
   value_type = ir.RankedTensorType(value.type)
+  value_shape = value_type.shape

   if layout == AttentionLayout.BNTH.value:
-    B, q_N, T, H = query_shape
-    _, k_N, S, _ = key_shape
+    B, q_N, T, qk_H = query_shape
+    _, v_N, S, v_H = value_shape
     grad_layout = (3, 2, 1, 0)
     grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3))
   else:
-    B, T, q_N, H = query_shape
-    _, S, k_N, _ = key_shape
+    B, T, q_N, qk_H = query_shape
+    _, S, v_N, v_H = value_shape
     grad_layout = (3, 1, 2, 0)
     grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))

   workspace_shape = (0,)
   workspace_type = ir.IntegerType.get_unsigned(8)

-  grad_query_shape = (B, q_N, T, H)
-  grad_key_shape = (B, k_N, S, H)
-  grad_value_shape = (B, k_N, S, H)
+  grad_query_shape = (B, q_N, T, qk_H)
+  grad_key_shape = (B, v_N, S, qk_H)
+  grad_value_shape = (B, v_N, S, v_H)
   has_bias, has_dbias = variadic_args

   max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
@@ -1207,7 +1214,7 @@ def _dot_product_attention_fp8_fwd(
     fp8_params_fwd,
     scale, use_causal_mask, layout, cudnn_version):
   check_is_flash_attention_fp8(
-      query, key, layout, cudnn_version, is_training=False)
+      query, key, value, layout, cudnn_version, is_training=False)
   descale_q, descale_k, descale_v, descale_s, scale_s, scale_o = fp8_params_fwd
   outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind(
       query, key, value,
@@ -1221,7 +1228,7 @@ def _dot_product_attention_fp8_fwd_rule(
     fp8_params,
     scale, use_causal_mask, layout, cudnn_version):
   check_is_flash_attention_fp8(
-      query, key, layout, cudnn_version, is_training=True)
+      query, key, value, layout, cudnn_version, is_training=True)
   outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind(
       query, key, value,
       *params_from_keys(fp8_params, fp8_params_keys_fwd),
diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py
index 7c3bd145dab5..394e5b4b0e8f 100644
--- a/tests/fused_attention_stablehlo_test.py
+++ b/tests/fused_attention_stablehlo_test.py
@@ -853,6 +853,58 @@ def unpaged(paged, page_table):
     out_ref = sdpa_infer_ref(q, k, v)
     self.assertArraysAllClose(out_ref, out_ref, rtol=1e-2, atol=1e-2)

+  @jtu.run_on_devices("cuda")
+  def test_sdpa_mla(self):
+    if jax.device_count() < 4:
+      self.skipTest("Requires at least 4 devices.")
+    try:
+      cudnn_version = check_cudnn_version()
+    except RuntimeError as e:
+      self.skipTest(str(e))
+      return
+    if cudnn_version < 91000:
+      self.skipTest("Requires >= cuDNN 9.10.0")
+    if not jtu.is_cuda_compute_capability_at_least("9.0"):
+      self.skipTest("Requires at least Hopper arch")
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    query = jax.random.normal(
+        k1, (4, 1024, 4, 128), dtype=jnp.bfloat16)
+    key = jax.random.normal(
+        k2, (4, 1024, 4, 128), dtype=jnp.bfloat16)
+    value = jax.random.normal(
+        k3, (4, 1024, 4, 64), dtype=jnp.bfloat16)
+
+    devices = np.array(jax.local_devices()[:4])
+    devices = devices.reshape((2, 2))
+    with Mesh(devices, ("dp", "tp")) as mesh:
+      qkv_spec = PartitionSpec("dp", None, "tp", None)
+      qkv_sharding = NamedSharding(mesh, qkv_spec)
+      in_shardings = (
+          qkv_sharding, qkv_sharding, qkv_sharding)
+      out_shardings = qkv_sharding
+      query = jax.device_put(query, qkv_sharding)
+      key = jax.device_put(key, qkv_sharding)
+      value = jax.device_put(value, qkv_sharding)
+
+      jitted_sdpa_inference = jax.jit(
+          partial(
+              dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK,
+              dropout_rate=0),
+          in_shardings=in_shardings,
+          out_shardings=out_shardings
+      )
+
+      jitted_sdpa_inference_ref = jax.jit(
+          partial(
+              sdpa_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0),
+          in_shardings=in_shardings,
+          out_shardings=out_shardings
+      )
+
+      out = jitted_sdpa_inference(query, key, value)
+      out_ref = jitted_sdpa_inference_ref(query, key, value)
+      self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2)
+
   @jtu.run_on_devices("cuda")
   def test_layouts(self):
     if jax.device_count() < 4:
@@ -899,15 +951,16 @@ def test_sdpa_utils(self):
       expected_pass = k
       query = jnp.empty((4, sql_q, 4, head_dim))
       key = jnp.empty((4, sql_v, 4, head_dim))
+      value = jnp.empty((4, sql_v, 4, head_dim))
       if expected_pass:
         check_is_flash_attention(
-            query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
-            is_training)
+            query, key, value, AttentionLayout.BNTH.value, cudnn_version,
+            has_bias, is_training)
       else:
         with self.assertRaises(NotImplementedError):
           check_is_flash_attention(
-              query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
-              is_training)
+              query, key, value, AttentionLayout.BNTH.value, cudnn_version,
+              has_bias, is_training)

 @jtu.with_config(jax_numpy_dtype_promotion="standard")

From 1e4a0f766b5617b5e66a6f1ac43f204f5e38e563 Mon Sep 17 00:00:00 2001
From: Zixuan Jiang
Date: Wed, 18 Jun 2025 19:18:25 -0700
Subject: [PATCH 1757/1769] Pass shardy option through jax config.

Remove `use_shardy_partitioner` in `get_compile_options`. It can be read
from jax.config directly.

PiperOrigin-RevId: 773150657
---
 jax/_src/compiler.py                           | 7 +------
 jax/_src/interpreters/pxla.py                  | 2 --
 jax/experimental/jax2tf/tests/sharding_test.py | 1 -
 3 files changed, 1 insertion(+), 9 deletions(-)

diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py
index 2288278b41bd..2aef697a353d 100644
--- a/jax/_src/compiler.py
+++ b/jax/_src/compiler.py
@@ -117,7 +117,6 @@ def get_compile_options(
     num_partitions: int,
     device_assignment=None,
     use_spmd_partitioning: bool = True,
-    use_shardy_partitioner: bool = False,
     use_auto_spmd_partitioning: bool = False,
     auto_spmd_partitioning_mesh_shape: list[int] | None = None,
     auto_spmd_partitioning_mesh_ids: list[int] | None = None,
@@ -137,10 +136,6 @@ def get_compile_options(
       `num_partitions`.
     use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
       partitioning in XLA.
-    use_shardy_partitioner: boolean indicating whether to use the Shardy
-      partitioner in XLA. Shardy is a new open sourced propagation framework for
-      MLIR. Currently Shardy is experimental in JAX. See
-      www.github.com/openxla/shardy.
     use_auto_spmd_partitioning: boolean indicating whether to automatically
       generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create @@ -160,7 +155,7 @@ def get_compile_options( build_options = compile_options.executable_build_options build_options.use_spmd_partitioning = use_spmd_partitioning build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning - build_options.use_shardy_partitioner = use_shardy_partitioner + build_options.use_shardy_partitioner = config.use_shardy_partitioner.value if fdo_profile is not None: build_options.fdo_profile = fdo_profile if use_auto_spmd_partitioning: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 03336655ceba..c4585663b68d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1066,7 +1066,6 @@ def from_hlo(hlo: ir.Module, num_partitions=num_partitions, device_assignment=device_assignment, use_spmd_partitioning=False, - use_shardy_partitioner=config.use_shardy_partitioner.value, env_options_overrides=compiler_options, detailed_logging=compiler.use_detailed_logging(hlo), backend=pci.backend, @@ -2692,7 +2691,6 @@ def create_compile_options( num_partitions=num_partitions, device_assignment=xla_device_assignment, use_spmd_partitioning=spmd_lowering, - use_shardy_partitioner=config.use_shardy_partitioner.value, use_auto_spmd_partitioning=auto_spmd_lowering, env_options_overrides=compiler_options, fdo_profile=fdo_profile, diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 33e78da18021..20193a931b63 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -109,7 +109,6 @@ def log_jax_hlo(self, f_jax, args: Sequence[Any], *, num_partitions=num_partitions, device_assignment=device_assignment, use_spmd_partitioning=use_spmd_partitioning, - use_shardy_partitioner=jax.config.jax_use_shardy_partitioner, ) executable = backend.compile_and_load( jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore From 0b54a1e8eaed12f7c7f401d0efbec5574ccf099f Mon Sep 17 00:00:00 2001 From: Will Froom Date: Thu, 19 Jun 2025 00:31:05 -0700 Subject: [PATCH 1758/1769] Reenable AVX512 after LLVM fix upstream. PiperOrigin-RevId: 773232950 --- tests/BUILD | 4 ---- third_party/xla/llvm_fix.patch | 32 -------------------------------- third_party/xla/workspace.bzl | 1 - 3 files changed, 37 deletions(-) delete mode 100644 third_party/xla/llvm_fix.patch diff --git a/tests/BUILD b/tests/BUILD index 95fecf89c7dc..15cf4330d28b 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -932,10 +932,6 @@ jax_multiplatform_test( "notsan", # Times out. ], }, - env = { - # TODO(b/424430576): something is going wrong with AVX512 code generation. 
- "XLA_FLAGS": "--xla_cpu_max_isa=AVX2", - }, shard_count = { "cpu": 40, "gpu": 40, diff --git a/third_party/xla/llvm_fix.patch b/third_party/xla/llvm_fix.patch deleted file mode 100644 index 4bf402095517..000000000000 --- a/third_party/xla/llvm_fix.patch +++ /dev/null @@ -1,32 +0,0 @@ -diff --git a/third_party/llvm/llvm_jax_fix.patch b/third_party/llvm/llvm_jax_fix.patch -new file mode 100644 -index 0000000000..5a2a60205e ---- /dev/null -+++ b/third_party/llvm/llvm_jax_fix.patch -@@ -0,0 +1,14 @@ -+diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp -+index 96be91256915d..8bcd8670879a9 100644 -+--- a/llvm/lib/Target/X86/X86ISelLowering.cpp -++++ b/llvm/lib/Target/X86/X86ISelLowering.cpp -+@@ -59383,7 +59383,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, -+ -+ // We can always convert per-lane vXf64 shuffles into VSHUFPD. -+ if (!IsSplat && -+- (VT == MVT::v4f64 || (VT == MVT::v8f64 && Subtarget.useAVX512Regs())) && -++ ((NumOps == 2 && VT == MVT::v4f64) || -++ (NumOps == 4 && VT == MVT::v8f64 && Subtarget.useAVX512Regs())) && -+ all_of(Ops, [](SDValue Op) { return Op.hasOneUse(); })) { -+ // Collect the individual per-lane v2f64/v4f64 shuffles. -+ MVT OpVT = Ops[0].getSimpleValueType(); -diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index ae0c1b550f..ce408f554a 100644 ---- a/third_party/llvm/workspace.bzl -+++ b/third_party/llvm/workspace.bzl -@@ -22,6 +22,7 @@ def repo(name): - "//third_party/llvm:mathextras.patch", - "//third_party/llvm:toolchains.patch", - "//third_party/llvm:zstd.patch", -+ "//third_party/llvm:llvm_jax_fix.patch", - ], - link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"}, - ) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 76ea227200e2..dccf8d47a6cb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -30,7 +30,6 @@ def repo(): sha256 = XLA_SHA256, strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), - patch_file = ["//third_party/xla:llvm_fix.patch"], ) # For development, one often wants to make changes to the TF repository as well From e55f55fce653e108c7c152559eeb01d7706f59d3 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 19 Jun 2025 00:50:54 -0700 Subject: [PATCH 1759/1769] [Mosaic GPU] Delete dead code in `layout_inference.py`. Also do a couple of clean ups. 
PiperOrigin-RevId: 773238722 --- .../mosaic/gpu/layout_inference.py | 22 +++---------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index c010bf181bce..06932c7facec 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -25,7 +25,6 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import math as mlir_math -from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector import numpy as np @@ -54,8 +53,8 @@ def _set_layout_attributes( in_layouts: list[ir.Attribute], out_layouts: list[ir.Attribute], ): - op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts) - op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts) + op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts) + op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts) def _choose_representative_layout( @@ -665,21 +664,6 @@ def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir. raise ValueError("None of uses are in the given block") -def _insert_memref_layout_cast(layout: ir.Attribute, view_op: memref.ViewOp): - mem_ref_type = ir.MemRefType(view_op.result.type) - memref_new_type = ir.MemRefType.get( - mem_ref_type.shape, - mem_ref_type.element_type, - layout, - mem_ref_type.memory_space, - ) - uses = list(view_op.result.uses) - with ir.InsertionPoint(_earliest_use(view_op.parent.regions, uses)): - cast_op = memref.cast(memref_new_type, view_op.result) - for use in uses: - use.owner.operands[use.operand_number] = cast_op - - class TraversalOrder(enum.Enum): """Traversal orders with respect to the data flow for IR.""" @@ -755,7 +739,7 @@ def update_default_vector_size_from_vector(v: ir.Value): max_vec_size_for_v = ( np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE ) - desired_vec_size = 64 // utils.bitwidth(v.type.element_type) + desired_vec_size = 64 // utils.bitwidth(v.type.element_type) # pytype: disable=attribute-error default_vector_size = min( default_vector_size, max_vec_size_for_v, desired_vec_size ) From b99d004c6c64243e59cc81e00e641860fc0879a2 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 19 Jun 2025 00:56:45 -0700 Subject: [PATCH 1760/1769] Create a test suite for Pallas mosaic GPU tests. PiperOrigin-RevId: 773240293 --- tests/pallas/BUILD | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 6e80a96d21fa..678354a43b28 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -30,6 +30,11 @@ package( jax_generate_backend_suites() +test_suite( + name = "mosaic_gpu_tests", + tags = ["mosaic_gpu_test"], +) + jax_multiplatform_test( name = "pallas_test", srcs = [ @@ -165,6 +170,7 @@ jax_multiplatform_test( }, shard_count = 16, tags = [ + "mosaic_gpu_test", "noasan", # Times out. "nomsan", # Times out. "notsan", # Times out. 
@@ -238,6 +244,9 @@ jax_multiplatform_test(
         "gpu_h100_x32",
         "gpu_h100",
     ],
+    tags = [
+        "mosaic_gpu_test",
+    ],
     deps = [
         "//jax:pallas",
         "//jax:pallas_mosaic_gpu",  # build_cleaner: keep
@@ -418,7 +427,10 @@ jax_multiplatform_test(
         "JAX_PALLAS_USE_MOSAIC_GPU": "1",
         "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true",
     },
-    tags = ["multiaccelerator"],
+    tags = [
+        "mosaic_gpu_test",
+        "multiaccelerator",
+    ],
     deps = [
         "//jax:extend",
         "//jax:pallas_mosaic_gpu",
@@ -811,6 +823,9 @@ jax_multiplatform_test(
     ],
     env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     shard_count = 8,
+    tags = [
+        "mosaic_gpu_test",
+    ],
     deps = [
         "//jax:pallas",
         "//jax:pallas_experimental_gpu_ops",
@@ -829,6 +844,7 @@ jax_multiplatform_test(
     env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     shard_count = 8,
     tags = [
+        "mosaic_gpu_test",
         # TODO(b/330364373): Remove when B200 is fully supported.
         "notap",
     ],
@@ -889,6 +905,7 @@ jax_multiplatform_test(
     ],
     shard_count = 12,
     tags = [
+        "mosaic_gpu_test",
         "noasan",  # Times out.
     ],
     deps = [

From f99d2b4963ea2384413a591dfc4f780d0a6c557f Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Thu, 19 Jun 2025 02:55:41 -0700
Subject: [PATCH 1761/1769] [Pallas:MGPU] Add docs for pl.core_map and
 plgpu.kernel

PiperOrigin-RevId: 773272786
---
 docs/conf.py                 |   2 +
 docs/pallas/gpu/reference.md | 175 ++++++++++++++++++++++++++++++++++-
 2 files changed, 176 insertions(+), 1 deletion(-)

diff --git a/docs/conf.py b/docs/conf.py
index a84ed24540a3..3cd3b8ea8776 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -207,6 +207,8 @@ def _do_not_evaluate_in_jax(
 # -- Options for myst ----------------------------------------------
 myst_heading_anchors = 3  # auto-generate 3 levels of heading anchors
 myst_enable_extensions = ['dollarmath']
+myst_ref_domains = ["py"]
+myst_all_links_external = False
 nb_execution_mode = "force"
 nb_execution_allow_errors = False
 nb_merge_streams = True
diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md
index 1a4f39dff5f2..d68730619b06 100644
--- a/docs/pallas/gpu/reference.md
+++ b/docs/pallas/gpu/reference.md
@@ -458,7 +458,179 @@ is still work in progress. Stay tuned!
 
 ## Using `core_map`
 
-TODO
+`pl.pallas_call` is suitable for kernels where a single Pallas thread can
+perform the whole computation for an entire CUDA block. The `pl.core_map`
+function relaxes this restriction, allowing multiple threads to be used within a
+single block (e.g. for warp specialization) or across multiple blocks in a block
+cluster (e.g. to utilize multicast TMA).
+
+### Replacing `pl.pallas_call` with `pl.core_map` or `plgpu.kernel`
+
+Let us begin with a simple Pallas kernel that increments an array:
+
+```python
+@functools.partial(
+    pl.pallas_call,
+    grid=(2,),
+    in_specs=[pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))],
+    out_specs=pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,)),
+    out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),  # Total output shape
+)
+def run_kernel(x_ref, y_ref):
+  # x_ref and y_ref are in SMEM!
+  y_ref[...] = x_ref[...] + 1
+
+x = jnp.arange(256, dtype=jnp.float32)
+y = run_kernel(x)
+np.testing.assert_array_equal(y, x + 1)
+```
+
+We can write a similar kernel using `pl.core_map`. One big difference is that
+unlike `pl.pallas_call`, no GMEM<->SMEM copies will be inserted automatically.
+If you want them, you can either insert them yourself or use the
+{py:func}`plgpu.emit_pipeline <jax.experimental.pallas.mosaic_gpu.emit_pipeline>`
+helper, sketched after the next example.
+
+```python
+@pl.run_state
+def run_kernel(refs):
+  # Here, we're not in the kernel yet! pl.run_state simply changes the JAX
+  # immutable arrays into mutable GMEM (not SMEM!) references.
+  x_ref, y_ref = refs
+
+  # Define the mesh: 2 CUDA blocks over 1 axis called "x"
+  mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
+
+  @pl.core_map(mesh)  # core_map executes the body
+  def kernel_body():
+    # Once we enter the pl.core_map scope, we are in the body of the kernel.
+    block_slice = pl.ds(lax.axis_index("x") * 128, 128)
+    y_ref[block_slice] = x_ref[block_slice] + 1
+
+x = jnp.arange(256, dtype=jnp.float32)
+y_init = jnp.zeros_like(x)
+_, y = run_kernel((x, y_init))
+np.testing.assert_array_equal(y, x + 1)
+```
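+
+As a rough sketch of the pipelined alternative mentioned above (argument names
+and `plgpu.emit_pipeline` usage here are illustrative, not authoritative; see
+the pipelining guide for the exact API), the same computation could let the
+helper stage the GMEM<->SMEM copies:
+
+```python
+def body(idx, x_smem, y_smem):
+  # The pipeline helper hands the body SMEM references for each grid step.
+  y_smem[...] = x_smem[...] + 1
+
+@pl.run_state
+def run_kernel(refs):
+  x_ref, y_ref = refs  # GMEM references, as above.
+  block = pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))
+  @pl.core_map(plgpu.Mesh(grid=(1,), grid_names=("block",)))
+  def kernel_body():
+    # One CUDA block sequentially pipelines the two 128-element steps.
+    plgpu.emit_pipeline(body, grid=(2,), in_specs=[block], out_specs=[block])(
+        x_ref, y_ref)
+```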
+
+While `pl.core_map` is a powerful API, it is also quite low-level and is pretty
+much always used under `pl.run_state` (to make JAX arrays into refs) or
+`pl.run_scoped` (to allocate scratch refs). For that reason, we also
+provide a convenience API `plgpu.kernel`:
+
+```python
+mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
+
+@functools.partial(
+    plgpu.kernel,
+    out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
+    mesh=mesh
+)
+def run_kernel(x_ref, y_ref):
+  # x_ref and y_ref are in GMEM!
+  block_slice = pl.ds(lax.axis_index("x") * 128, 128)
+  y_ref[block_slice] = x_ref[block_slice] + 1
+
+x = jnp.arange(256, dtype=jnp.float32)
+y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
+np.testing.assert_array_equal(y, x + 1)
+```
+
+```{note}
+The `plgpu.Mesh` used with `pl.core_map` defines a topology for computation
+*within a single GPU*, specifying how work is distributed across CUDA blocks
+(the `grid`), Pallas threads within a block (`num_threads`), and potentially
+CUDA block clusters (`cluster`). This is analogous to how `jax.sharding.Mesh`
+defines a topology for distributed computation *across multiple devices* in JAX.
+Both involve SPMD programs executing across the defined topology. Furthermore,
+you can run "collectives" over the Pallas threads and cluster (e.g., using
+`plgpu.ClusterBarrier` or collective async copies), similar to how JAX
+collectives (`psum`, `all_gather`, etc.) operate across devices in a JAX `Mesh`.
+Both also use named axes, and `lax.axis_index(axis_name)` can be used to get a
+thread's or block's coordinate.
+```
+
+### Using multiple Pallas threads per CUDA block
+
+Below, you can find an example of two Pallas threads within a single block
+synchronizing through a barrier and even exchanging data through SMEM.
+
+```python
+mesh = plgpu.Mesh(num_threads=2, thread_name="pallas_thread")
+@functools.partial(
+    plgpu.kernel,
+    out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
+    mesh=mesh,
+    scratch_shapes=[plgpu.SMEM((128,), jnp.float32), plgpu.Barrier()],
+)
+def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
+  thread_id = jax.lax.axis_index("pallas_thread")
+
+  @pl.when(thread_id == 0)
+  def producer_thread():
+    smem_ref[...] = x_ref[...] + 1
+    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread
+
+  @pl.when(thread_id == 1)
+  def consumer_thread():
+    plgpu.barrier_wait(barrier_ref)  # Wait for the producer thread
+    y_ref[...] = smem_ref[...] + 1
+
+x = jnp.arange(128, dtype=jnp.float32)
+y = run_kernel(x)  # No need to preallocate the output here either.
+np.testing.assert_array_equal(y, x + 2)
+```
+
+While this example is simple, you can find a more complicated example in the
+[synchronization section](#cross-thread-synchronization).
+
+Multiple threads are frequently used in high-performance kernels such as the
+latest flash attention variants or ping-pong matrix multiplication. In both of
+those, there are 2 compute threads in the program that use the SM's ALU
+and TensorCore in an alternating fashion to ensure no execution conflicts.
+
+Another common technique is to allocate one Pallas thread and devote it entirely
+to scheduling asynchronous copies for data consumed by other threads. While
+implementing this scheme from scratch can be complicated, we provide a
+convenient helper API: `plgpu.emit_pipeline_warp_specialized`.
+
+### Using CUDA block clusters
+
+The kernel below launches a single cluster of 2 CUDA blocks and uses the TMA
+multicast feature to collectively perform a copy of GMEM into SMEM of both
+blocks. All blocks participating in the collective copy must schedule the exact
+same copy for the program to be valid.
+
+```python
+mesh = plgpu.Mesh(cluster=(2,), cluster_names=("cluster",))
+
+@functools.partial(
+    plgpu.kernel,
+    out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32),
+    mesh=mesh,
+    scratch_shapes=[plgpu.SMEM((128,), jnp.float32), plgpu.Barrier()]
+)
+def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
+  # Specifying collective_axes will enable TMA multicast automatically.
+  plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref, collective_axes="cluster")
+  plgpu.barrier_wait(barrier_ref)
+  plgpu.copy_smem_to_gmem(smem_ref, y_ref.at[lax.axis_index("cluster")])
+  plgpu.wait_smem_to_gmem(0)
+
+x = jnp.arange(128, dtype=jnp.float32)
+y = run_kernel(x)
+# Each block gets the same data and writes it out.
+np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0))
+```
+
+### Collective allocations in `pl.run_scoped`
+
+When using `pl.core_map` with multiple Pallas threads (i.e., `num_threads > 1`
+in `plgpu.Mesh`), allocations made via `pl.run_scoped` (for SMEM or Barriers)
+must be performed _collectively by all threads_. This is indicated by specifying
+a `collective_axes` argument to `pl.run_scoped`, which has two effects:
+1. it promises that all threads will request the same allocation, and
+2. all threads will receive the exact same allocation.
+
+If `collective_axes` is not specified or does not include the Pallas thread axis,
+each thread would get its own private copy of the scratch variable. This is
+usually undesired and not supported at the moment.
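+
+As a sketch (the scoped body and shapes here are illustrative), a collective
+SMEM allocation shared by all threads could look like this:
+
+```python
+def kernel_body():  # Runs inside pl.core_map with num_threads=2.
+  def scoped(smem_ref):
+    ...  # Every thread sees the same SMEM buffer here.
+  # All threads along "pallas_thread" must participate in this allocation.
+  pl.run_scoped(scoped, plgpu.SMEM((128,), jnp.float32),
+                collective_axes="pallas_thread")
+```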
In both of +those, there are 2 compute threads in the program that use the SM's ALU +and TensorCore in an alternating fashion to ensure no execution conflicts. + +Another common technique is to allocate one Pallas thread and devote it entirely +to scheduling asynchronous copies for data consumed by other threads. While +implementing this scheme from scratch can be complicated, we provide a +convenient helper API: `plgpu.emit_pipeline_warp_specialized`. + +### Using CUDA block clusters + +The kernel below launches a single cluster of 2 CUDA blocks and uses the TMA +multicast feature to collectively perform a copy of GMEM into SMEM of both +blocks. All blocks participating in the collective copy must schedule the exact +same copy for the program to be valid. + +```python +mesh = plgpu.Mesh(cluster=(2,), cluster_names=("cluster",)) + +@functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32), + mesh=mesh, + scratch_shapes=[plgpu.SMEM((128,), jnp.float32), plgpu.Barrier()] +) +def run_kernel(x_ref, y_ref, smem_ref, barrier_ref): + # Specifying collective_axes will enable TMA multicast automatically. + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref, collective_axes="cluster") + plgpu.barrier_wait(barrier_ref) + plgpu.copy_smem_to_gmem(smem_ref, o_ref.at[lax.axis_index("cluster")]) + plgpu.wait_smem_to_gmem(0) + +x = jnp.arange(128, jnp.float32) +y = run_kernel(x) +# Each block gets the same data and writes it out. +np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0)) +``` + +### Collective allocations in `pl.run_scoped` + +When using `pl.core_map` with multiple Pallas threads (i.e., `num_threads > 1` +in `plgpu.Mesh`), allocations made via `pl.run_scoped` (for SMEM or Barriers) +must be performed _collectively by all threads_. This is indicated by specifying +a `collective_axis` argument to the `run_scoped`, which has two effects: +1. it promises that all threads will call the same allocation, and +2. all threads will receive the exact same allocation. + +If collective_axes is not specified or does not include the Pallas thread axis, +each thread would get its own private copy of the scratch variable. This is +usually undesired and not supported at the moment. ## Synchronization structures and primitives @@ -545,6 +717,7 @@ When an asynchronous GMEM-to-SMEM copy is being executed by the TMA engine, it w post progress updates to the barrier given to `plgpu.copy_gmem_to_smem`. Once the copy is complete, the barrier will complete one arrival as well. +(cross-thread-synchronization)= #### Explicit arrival (cross-thread synchronization) Any thread can explicitly arrival on a barrier using the following function: From 84066b77dfc1b5c8f6b8e0d05dbf8d65a6152b55 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 19 Jun 2025 03:13:17 -0700 Subject: [PATCH 1762/1769] =?UTF-8?q?[Mosaic=20GPU]=C2=A0Change=20layout?= =?UTF-8?q?=20inference=20tests=20to=20rely=20on=20explicit=20`layout=5Fca?= =?UTF-8?q?st`s.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This uncovers a propagation bug, whereby opportunities to propagate replicated layouts that were not already explicitly annotated as attributes downwards would be missed, because layout propagation started with a backwards pass. We now changed the implementation to start with a forward pass. Some additional edits: 1. I changed the layout in `test_optimization_barrier_op_propagates_user_layouts`. 
Generally, propagating replicated layouts upwards is not a safe thing to do, and we should have properly caught that. The upcoming infrastructure will recognize such issues, so we don't bother attempting to fix the underlying problem here; 2. I got rid of `test_infer_layout_propagates_func_layouts_to_ops` since we no longer care about `FuncOp`s. This simplification will allow the new infrastructure to not concern itself with `FuncOp`s, on which we were putting inconsistent expectations, and which would add quite a bit of complexity. PiperOrigin-RevId: 773277630 --- .../mosaic/gpu/layout_inference.py | 22 +-- tests/mosaic/gpu_layout_inference_test.py | 162 +++++++----------- tests/pallas/mosaic_gpu_test.py | 8 +- 3 files changed, 75 insertions(+), 117 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 06932c7facec..29a92e2d5d43 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -136,7 +136,6 @@ def _choose_representative_layout( def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts: - def is_array(v: ir.Value) -> bool: return ir.VectorType.isinstance(v.type) @@ -637,11 +636,8 @@ def _infer_broadcast_in_dim_op_layout( ): out_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) return [in_layout], [out_layout] - else: - raise NotImplementedError( - f"Unsupported layout: {in_layout} for broadcast dimensions" - f" {broadcast_dims}" - ) + + return None @partial(_add_layout_inference_rule, mgpu.WGMMAOp) @@ -709,20 +705,20 @@ def inference_step(op: ir.Operation): # # We run two passes over the module, in order to make sure that layouts # defined in the middle of the computation are propagated wherever they need - # to be propagated. We start with a backwards (root-to-parameters) pass to - # propagate the information as far up as possible, and then a forward pass - # (parameters-to-root). + # to be propagated. We start with a forward (parameters-to-root) pass to + # preserve replicated layouts as far down as possible, and then do a + # backwards (root-to-parameters) pass. 
# - # Backwards pass + # Forward pass for op in module.body: inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.BACKWARDS + op, inference_step, inference_utils.TraversalOrder.FORWARD ) - # Forward pass + # Backwards pass for op in module.body: inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.FORWARD + op, inference_step, inference_utils.TraversalOrder.BACKWARDS ) # At this point, layouts have been propagated as far as they could be diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index cdc840b0a6f1..114302a96c17 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -40,6 +40,13 @@ def _make_ir_context(): return context +def layout_cast(x: ir.Value, layout: mgpu.FragmentedLayout | ir.Attribute) -> ir.Value: + """Convenience wrapper around `mgpu.dialect.layout_cast`.""" + if isinstance(layout, mgpu.FragmentedLayout): + layout = layouts.to_layout_attr(layout) + return mgpu.dialect.layout_cast(x, layout) + + class LayoutInferenceTest(parameterized.TestCase): def setUp(self): @@ -135,6 +142,9 @@ def test_infer_splat_layout_for_splat_constants(self): def test_infer_layout_from_consumer_for_non_splat_constant(self): shape = (16, 8) elt_type = ir.BF16Type.get() + layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout(shape=shape, vec_size=1) + ) with ir.InsertionPoint(self.module.body): ty = ir.VectorType.get(shape, elt_type) @@ -142,12 +152,7 @@ def test_infer_layout_from_consumer_for_non_splat_constant(self): ir.FloatAttr.get(elt_type, i) for i in range(shape[0] * shape[1]) ] c = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attr_list, ty)) - add = arith.AddFOp(c, c) - - layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout(shape=shape, vec_size=1) - ) - add.attributes["in_layouts"] = ir.ArrayAttr.get([layout, layout]) + layout_cast(c, layout) mgpu.infer_layout(self.module) @@ -157,31 +162,28 @@ def test_infer_layout_from_consumer_for_non_splat_constant(self): @parameterized.parameters(True, False) def test_infer_splat_layout_for_vector_splat(self, rhs_splat): add = splat = None + shape = (16, 8) + layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) def body(lhs, rhs): nonlocal add, splat + rhs = layout_cast(rhs, layout) if rhs_splat else rhs splat = vector.SplatOp(rhs.type, lhs) add = arith.AddFOp(splat.result, rhs) with ir.InsertionPoint(self.module.body): - shape = (16, 8) elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) - func_op = func.FuncOp.from_py_func(elt_type, ty)(body).func_op + func.FuncOp.from_py_func(elt_type, ty)(body) - layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) - if rhs_splat: - func_op.attributes["in_layouts"] = ir.ArrayAttr.get([layout]) mgpu.infer_layout(self.module) self.assertEmpty(splat.attributes["in_layouts"]) self.assertSequenceEqual(splat.attributes["out_layouts"], [layout]) - add_layout = layout - if not rhs_splat: - add_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(ty) - ) + add_layout = layout if rhs_splat else layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(ty) + ) self.assertSequenceEqual(add.attributes["in_layouts"], [add_layout, add_layout]) self.assertSequenceEqual(add.attributes["out_layouts"], [add_layout]) @@ -195,18 +197,17 @@ def test_pointwise_op_propagates_argument_layouts(self, layout): def body(lhs, rhs): nonlocal add + lhs = layout_cast(lhs, layout) + rhs = 
layout_cast(rhs, layout) add = arith.AddFOp(lhs, rhs) with ir.InsertionPoint(self.module.body): ty = ir.VectorType.get(layout.shape, ir.BF16Type.get()) func.FuncOp.from_py_func(ty, ty)(body) - [f] = self.module.body.operations - layout_attr = layouts.to_layout_attr(layout) - f.attributes["in_layouts"] = ir.ArrayAttr.get([layout_attr, layout_attr]) - mgpu.infer_layout(self.module) + layout_attr = layouts.to_layout_attr(layout) self.assertSequenceEqual( add.attributes["in_layouts"], [layout_attr, layout_attr] ) @@ -221,15 +222,15 @@ def test_infer_layout_cast_layout(self): def body(x): nonlocal add, cast + x = mgpu.dialect.layout_cast(x, splat_layout) add = arith.AddFOp(x, x) cast = mgpu.dialect.LayoutCastOp(add.result, wgmma_layout) with ir.InsertionPoint(self.module.body): elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) - func_op = func.FuncOp.from_py_func(ty)(body).func_op + func.FuncOp.from_py_func(ty)(body) - func_op.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout]) mgpu.infer_layout(self.module) self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout]) self.assertSequenceEqual(cast.attributes["in_layouts"], [wgmma_layout]) @@ -355,26 +356,23 @@ def body(a, b): def test_infer_layout_from_yield_op_in_layouts_for_for_op( self, shape, layout ): - add_op = for_op = yield_op = None + for_op = yield_op = None def body(lower_bound, upper_bound, step, a, b): nonlocal for_op for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b]) [loop_a, loop_b] = list(for_op.inner_iter_args) with ir.InsertionPoint(for_op.body): - nonlocal add_op, yield_op - add_op = arith.AddFOp(loop_a, loop_b) - yield_op = scf.YieldOp([add_op.result, add_op.result]) + nonlocal yield_op + add = arith.addf(loop_a, loop_b) + add = layout_cast(add, layout) + yield_op = scf.YieldOp([add, add]) with ir.InsertionPoint(self.module.body): ab_type = ir.VectorType.get(shape, ir.BF16Type.get()) i32 = ir.IntegerType.get_signless(32) func.FuncOp.from_py_func(i32, i32, i32, ab_type, ab_type)(body) - add_op.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts.to_layout_attr(layout)] - ) - mgpu.infer_layout(self.module) if isinstance(layout, mgpu.WGSplatFragLayout): @@ -432,20 +430,15 @@ def body(lower_bound, upper_bound, step, a, b, c): def test_infer_while_op_layouts( self, init_shape, init_layout, result_shape, result_layout ): - if init_shape: - in_type = ir.VectorType.get(init_shape, ir.F32Type.get()) - else: - in_type = ir.F32Type.get() - - if result_shape: - out_type = ir.VectorType.get(result_shape, ir.F32Type.get()) - else: - out_type = ir.F32Type.get() - + f32 = ir.F32Type.get() + in_type = ir.VectorType.get(init_shape, f32) if init_shape else f32 + out_type = ir.VectorType.get(result_shape, f32) if result_shape else f32 while_op = condition_op = yield_op = None def body(condition, init, result): nonlocal while_op, condition_op, yield_op + init = layout_cast(init, init_layout) if init_layout else init + result = layout_cast(result, result_layout) if result_layout else result while_op = scf.WhileOp([out_type], [init]) before_block = while_op.before.blocks.append(init.type) with ir.InsertionPoint(before_block): @@ -459,18 +452,9 @@ def body(condition, init, result): i1 = ir.IntegerType.get_signless(1) func.FuncOp.from_py_func(i1, in_type, out_type)(body) - [f] = self.module.body.operations - f_layouts = [] - if init_layout: - f_layouts.append(layouts.to_layout_attr(init_layout)) - if result_layout: - f_layouts.append(layouts.to_layout_attr(result_layout)) - if f_layouts: - 
f.attributes["in_layouts"] = ir.ArrayAttr.get(f_layouts) - mgpu.infer_layout(self.module) - if init_layout or result_layout: + if init_layout is not None or result_layout is not None: init_layouts = [layouts.to_layout_attr(init_layout)] if init_layout else [] result_layouts = [layouts.to_layout_attr(result_layout)] if result_layout else [] self.assertSequenceEqual(while_op.attributes["in_layouts"], init_layouts) @@ -510,53 +494,47 @@ def test_infer_layout_picks_non_splat_layout_over_splat_layout( self, layout ): add = None + shape = (32, 4) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + non_splat_layout = layouts.to_layout_attr(layout) def body(lhs, rhs): nonlocal add + lhs = layout_cast(lhs, non_splat_layout) + rhs = layout_cast(rhs, splat_layout) add = arith.AddFOp(lhs, rhs) with ir.InsertionPoint(self.module.body): - shape = (32, 4) elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) - - f = func.FuncOp.from_py_func(ty, ty)(body).func_op - - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - non_splat_layout = layouts.to_layout_attr(layout) - - f.attributes["in_layouts"] = ir.ArrayAttr.get( - [non_splat_layout, splat_layout] - ) + func.FuncOp.from_py_func(ty, ty)(body) mgpu.infer_layout(self.module) - self.assertSequenceEqual( - add.attributes["in_layouts"], - [non_splat_layout, non_splat_layout], - ) + self.assertSequenceEqual(add.attributes["in_layouts"], [non_splat_layout, non_splat_layout]) self.assertSequenceEqual(add.attributes["out_layouts"], [non_splat_layout]) def test_infer_layout_preserves_splat_layouts_in_producers(self): add0 = add1 = None + shape = (32, 4) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + strided_layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout(shape, vec_size=1) + ) def body(lhs, rhs): nonlocal add0, add1 + lhs = layout_cast(lhs, splat_layout) + rhs = layout_cast(rhs, splat_layout) add0 = arith.AddFOp(lhs, rhs) - add1 = arith.AddFOp(add0.result, add0) + cast = layout_cast(add0, strided_layout) + add1 = arith.AddFOp(cast, cast) with ir.InsertionPoint(self.module.body): - shape = (32, 4) elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op + func.FuncOp.from_py_func(ty, ty)(body) - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - strided_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout(shape, vec_size=1) - ) - f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout]) - add1.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout]) mgpu.infer_layout(self.module) self.assertSequenceEqual( @@ -569,26 +547,6 @@ def body(lhs, rhs): self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout]) self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout]) - def test_infer_layout_propagates_func_layouts_to_ops(self): - add = None - - def body(lhs, rhs): - nonlocal add - add = arith.AddFOp(lhs, rhs) - - with ir.InsertionPoint(self.module.body): - shape = (32, 4) - ty = ir.VectorType.get(shape, ir.BF16Type.get()) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op - - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout]) - mgpu.infer_layout(self.module) - - self.assertSequenceEqual( - add.attributes["in_layouts"], [splat_layout, splat_layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout]) - def 
test_infer_layout_does_not_assign_default_layouts_to_func(self): def body(lhs, rhs): @@ -605,46 +563,46 @@ def body(lhs, rhs): def test_optimization_barrier_op_propagates_user_layouts(self): add = optimization_barrier = None + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) def body(lhs, rhs): nonlocal add, optimization_barrier optimization_barrier = mgpu.dialect.OptimizationBarrierOp([lhs, rhs]) lhs, rhs = optimization_barrier.results add = arith.AddFOp(lhs, rhs) + add = layout_cast(add, wgmma_layout) with ir.InsertionPoint(self.module.body): - shape = (32, 4) - ty = ir.VectorType.get(shape, ir.BF16Type.get()) + ty = ir.VectorType.get((32, 4), ir.BF16Type.get()) func.FuncOp.from_py_func(ty, ty)(body) - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) mgpu.infer_layout(self.module) self.assertSequenceEqual( optimization_barrier.attributes["in_layouts"], - [splat_layout, splat_layout], + [wgmma_layout, wgmma_layout], ) self.assertSequenceEqual( optimization_barrier.attributes["out_layouts"], - [splat_layout, splat_layout], + [wgmma_layout, wgmma_layout], ) def test_optimization_barrier_op_propagates_producer_layouts(self): add = optimization_barrier = None + shape = (32, 4) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) def body(lhs, rhs): nonlocal add, optimization_barrier + lhs = layout_cast(lhs, splat_layout) + rhs = layout_cast(rhs, splat_layout) add = arith.AddFOp(lhs, rhs) optimization_barrier = mgpu.dialect.OptimizationBarrierOp([add]) with ir.InsertionPoint(self.module.body): - shape = (32, 4) ty = ir.VectorType.get(shape, ir.BF16Type.get()) func.FuncOp.from_py_func(ty, ty)(body) - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) mgpu.infer_layout(self.module) self.assertSequenceEqual( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0fa38e34af21..91cd8ba10b74 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1325,8 +1325,12 @@ def body(acc): return plgpu.layout_cast( jnp.zeros(o_ref.shape, o_ref.dtype), plgpu.Layout.WGMMA_ROW ) - - _ = jax.lax.while_loop(cond, body, o_ref[...]) + # Cast explicitly to cause the mismatch, otherwise layout inference will + # succeed at constructing a working program. + strided_input = plgpu.layout_cast( + o_ref[...], plgpu.Layout.WG_STRIDED(shape=(128,), vec_size=1) + ) + _ = jax.lax.while_loop(cond, body, strided_input) if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: with self.assertRaisesRegex( From e818940d5e3e4a85dd6aa9b74fd219194b26d178 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 19 Jun 2025 03:56:24 -0700 Subject: [PATCH 1763/1769] [Mosaic GPU][NFC] Add `checkInLayouts` and `checkOutLayouts` utils to `gpu_layout_inference_test.py`. 
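The new helpers are thin wrappers over `assertSequenceEqual` on the
`in_layouts`/`out_layouts` attributes of an op. As a sketch, condensed from
the diff below, a typical call site goes from

    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])

to

    self.checkInLayouts(add, [layout, layout])
    self.checkOutLayouts(add, [layout])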

This makes checking for layouts more concise :)

PiperOrigin-RevId: 773288423
---
 tests/mosaic/gpu_layout_inference_test.py | 122 +++++++++-------------
 1 file changed, 52 insertions(+), 70 deletions(-)

diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py
index 114302a96c17..53c0aae14dc2 100644
--- a/tests/mosaic/gpu_layout_inference_test.py
+++ b/tests/mosaic/gpu_layout_inference_test.py
@@ -57,6 +57,12 @@ def setUp(self):
     self.enter_context(ir.Location.unknown())
     self.module = ir.Module.create()
 
+  def checkInLayouts(self, op, in_layouts):
+    self.assertSequenceEqual(op.attributes["in_layouts"], in_layouts)
+
+  def checkOutLayouts(self, op, out_layouts):
+    self.assertSequenceEqual(op.attributes["out_layouts"], out_layouts)
+
   def test_infer_strided_layout_default(self):
     shape = (16, 8)
     elt_type = ir.BF16Type.get()
@@ -78,8 +84,8 @@ def body(a, b):
       mgpu.WGStridedFragLayout.from_shaped_type(ty)
     )
 
-    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
-    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
+    self.checkInLayouts(add, [layout, layout])
+    self.checkOutLayouts(add, [layout])
 
   def test_infer_strided_layout_from_shape_cast(self):
     shape = (16, 8)
@@ -104,13 +110,13 @@ def body(x):
       mgpu.WGStridedFragLayout.from_shaped_type(dst_type)
     )
 
-    self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])
-    self.assertSequenceEqual(op.attributes["out_layouts"], [out_layout])
+    self.checkInLayouts(op, [in_layout])
+    self.checkOutLayouts(op, [out_layout])
 
     # Ensure that we can recover the original layout.
     del op.attributes["in_layouts"]
     mgpu.infer_layout(self.module)
-    self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])
+    self.checkInLayouts(op, [in_layout])
 
   def test_infer_splat_layout_for_splat_constants(self):
     shape = (16, 8)
@@ -131,13 +137,13 @@ def test_infer_splat_layout_for_splat_constants(self):
     layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
 
     self.assertEmpty(splat0.attributes["in_layouts"])
-    self.assertSequenceEqual(splat0.attributes["out_layouts"], [layout])
+    self.checkOutLayouts(splat0, [layout])
 
     self.assertEmpty(splat1.attributes["in_layouts"])
-    self.assertSequenceEqual(splat1.attributes["out_layouts"], [layout])
+    self.checkOutLayouts(splat1, [layout])
 
-    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
-    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
+    self.checkInLayouts(add, [layout, layout])
+    self.checkOutLayouts(add, [layout])
 
   def test_infer_layout_from_consumer_for_non_splat_constant(self):
     shape = (16, 8)
@@ -157,7 +163,7 @@ def test_infer_layout_from_consumer_for_non_splat_constant(self):
     mgpu.infer_layout(self.module)
 
     self.assertEmpty(c.attributes["in_layouts"])
-    self.assertSequenceEqual(c.attributes["out_layouts"], [layout])
+    self.checkOutLayouts(c, [layout])
 
   @parameterized.parameters(True, False)
   def test_infer_splat_layout_for_vector_splat(self, rhs_splat):
@@ -179,14 +185,14 @@ def body(lhs, rhs):
     mgpu.infer_layout(self.module)
 
     self.assertEmpty(splat.attributes["in_layouts"])
-    self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])
+    self.checkOutLayouts(splat, [layout])
 
     add_layout = layout if rhs_splat else layouts.to_layout_attr(
         mgpu.WGStridedFragLayout.from_shaped_type(ty)
     )
-    self.assertSequenceEqual(add.attributes["in_layouts"], [add_layout, add_layout])
-    self.assertSequenceEqual(add.attributes["out_layouts"], [add_layout])
+    self.checkInLayouts(add, [add_layout, add_layout])
+    
self.checkOutLayouts(add, [add_layout]) @parameterized.parameters( mgpu.WGSplatFragLayout(shape=(32, 4)), @@ -208,10 +214,8 @@ def body(lhs, rhs): mgpu.infer_layout(self.module) layout_attr = layouts.to_layout_attr(layout) - self.assertSequenceEqual( - add.attributes["in_layouts"], [layout_attr, layout_attr] - ) - self.assertSequenceEqual(add.attributes["out_layouts"], [layout_attr]) + self.checkInLayouts(add, [layout_attr, layout_attr]) + self.checkOutLayouts(add, [layout_attr]) def test_infer_layout_cast_layout(self): add = cast = None @@ -232,9 +236,9 @@ def body(x): func.FuncOp.from_py_func(ty)(body) mgpu.infer_layout(self.module) - self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout]) - self.assertSequenceEqual(cast.attributes["in_layouts"], [wgmma_layout]) - self.assertSequenceEqual(cast.attributes["out_layouts"], [wgmma_layout]) + self.checkOutLayouts(add, [splat_layout]) + self.checkInLayouts(cast, [wgmma_layout]) + self.checkOutLayouts(cast, [wgmma_layout]) @parameterized.parameters( (0, mgpu.WGMMA_ROW_LAYOUT, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT), @@ -267,12 +271,8 @@ def body(x): func.FuncOp.from_py_func(ty)(body) mgpu.infer_layout(self.module) - self.assertSequenceEqual( - bcast.attributes["in_layouts"], [layouts.to_layout_attr(in_layout)] - ) - self.assertSequenceEqual( - bcast.attributes["out_layouts"], [layouts.to_layout_attr(out_layout)] - ) + self.checkInLayouts(bcast, [layouts.to_layout_attr(in_layout)]) + self.checkOutLayouts(bcast, [layouts.to_layout_attr(out_layout)]) @parameterized.parameters( (1, mgpu.WGMMA_LAYOUT, None, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), @@ -313,13 +313,10 @@ def body(x, acc): func.FuncOp.from_py_func(in_ty, acc_ty)(body) mgpu.infer_layout(self.module) - self.assertSequenceEqual( - red.attributes["in_layouts"], - [layouts.to_layout_attr(in_layout), layouts.to_layout_attr(out_layout)], - ) - self.assertSequenceEqual( - red.attributes["out_layouts"], [layouts.to_layout_attr(out_layout)] - ) + in_layout_attr = layouts.to_layout_attr(in_layout) + out_layout_attr = layouts.to_layout_attr(out_layout) + self.checkInLayouts(red, [in_layout_attr, out_layout_attr]) + self.checkOutLayouts(red, [out_layout_attr]) def test_infer_layout_traverses_ops_correctly(self): shape = (16, 8) @@ -385,14 +382,14 @@ def body(lower_bound, upper_bound, step, a, b): mgpu.WGStridedFragLayout.from_shaped_type(ab_type) ) carry_layouts = [strided_layout, strided_layout] - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts) - self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts) + self.checkOutLayouts(yield_op, []) + self.checkInLayouts(for_op, carry_layouts) + self.checkOutLayouts(for_op, carry_layouts) else: carry_layouts = [layouts.to_layout_attr(layout)] * 2 - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts) - self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts) + self.checkOutLayouts(yield_op, []) + self.checkInLayouts(for_op, carry_layouts) + self.checkOutLayouts(for_op, carry_layouts) def test_infer_layout_from_body_op_to_yield_op_to_for_op(self): for_op = yield_op = None @@ -416,10 +413,10 @@ def body(lower_bound, upper_bound, step, a, b, c): mgpu.infer_layout(self.module) wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) - self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout]) - 
self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout]) - self.assertSequenceEqual(for_op.attributes["out_layouts"], [wgmma_layout]) + self.checkInLayouts(yield_op, [wgmma_layout]) + self.checkOutLayouts(yield_op, []) + self.checkInLayouts(for_op, [wgmma_layout]) + self.checkOutLayouts(for_op, [wgmma_layout]) @parameterized.parameters( ((), None, (), None), @@ -457,8 +454,8 @@ def body(condition, init, result): if init_layout is not None or result_layout is not None: init_layouts = [layouts.to_layout_attr(init_layout)] if init_layout else [] result_layouts = [layouts.to_layout_attr(result_layout)] if result_layout else [] - self.assertSequenceEqual(while_op.attributes["in_layouts"], init_layouts) - self.assertSequenceEqual(while_op.attributes["out_layouts"], result_layouts) + self.checkInLayouts(while_op, init_layouts) + self.checkOutLayouts(while_op, result_layouts) def test_infer_layout_has_no_layout_for_non_vector_types(self): shape = (32, 4) @@ -511,8 +508,8 @@ def body(lhs, rhs): mgpu.infer_layout(self.module) - self.assertSequenceEqual(add.attributes["in_layouts"], [non_splat_layout, non_splat_layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [non_splat_layout]) + self.checkInLayouts(add, [non_splat_layout, non_splat_layout]) + self.checkOutLayouts(add, [non_splat_layout]) def test_infer_layout_preserves_splat_layouts_in_producers(self): add0 = add1 = None @@ -537,15 +534,10 @@ def body(lhs, rhs): mgpu.infer_layout(self.module) - self.assertSequenceEqual( - add0.attributes["in_layouts"], [splat_layout, splat_layout] - ) - self.assertSequenceEqual( - add1.attributes["in_layouts"], [strided_layout, strided_layout] - ) - - self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout]) - self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout]) + self.checkInLayouts(add0, [splat_layout, splat_layout]) + self.checkOutLayouts(add0, [splat_layout]) + self.checkInLayouts(add1, [strided_layout, strided_layout]) + self.checkOutLayouts(add1, [strided_layout]) def test_infer_layout_does_not_assign_default_layouts_to_func(self): @@ -578,14 +570,8 @@ def body(lhs, rhs): mgpu.infer_layout(self.module) - self.assertSequenceEqual( - optimization_barrier.attributes["in_layouts"], - [wgmma_layout, wgmma_layout], - ) - self.assertSequenceEqual( - optimization_barrier.attributes["out_layouts"], - [wgmma_layout, wgmma_layout], - ) + self.checkInLayouts(optimization_barrier, [wgmma_layout, wgmma_layout]) + self.checkOutLayouts(optimization_barrier, [wgmma_layout, wgmma_layout]) def test_optimization_barrier_op_propagates_producer_layouts(self): add = optimization_barrier = None @@ -605,12 +591,8 @@ def body(lhs, rhs): mgpu.infer_layout(self.module) - self.assertSequenceEqual( - optimization_barrier.attributes["in_layouts"], [splat_layout] - ) - self.assertSequenceEqual( - optimization_barrier.attributes["out_layouts"], [splat_layout] - ) + self.checkInLayouts(optimization_barrier, [splat_layout]) + self.checkOutLayouts(optimization_barrier, [splat_layout]) if __name__ == "__main__": From bfc07e2f10093ae169e2578872b04e8eb204a1d6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 19 Jun 2025 07:18:43 -0700 Subject: [PATCH 1764/1769] Remove `Layout`, `.layout`, `.input_layouts` and `.output_layouts` and replace it with `Format`, `.format`, `.input_formats` and `.output_formats` in JAX Co-authored-by: Roy Frostig PiperOrigin-RevId: 773337503 --- CHANGELOG.md | 2 ++ 
jax/_src/api.py | 2 -- jax/_src/array.py | 3 --- jax/_src/layout.py | 2 -- jax/_src/stages.py | 14 ++------------ jax/experimental/layout.py | 2 -- 6 files changed, 4 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8b905448c54..32730a2355cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. starting in v0.7.x. This raised a DeprecationWarning in v0.6.x. * The minimum Python version is now 3.11. 3.11 will remain the minimum supported version until July 2026. + * `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been + renamed to `Format`, `.format`, `.input_formats` and `.output_formats` in JAX ## JAX 0.6.2 (June 17, 2025) diff --git a/jax/_src/api.py b/jax/_src/api.py index 45e40810c66a..ff3414e82f53 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2837,8 +2837,6 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False, def format(self): return Format(self._dll, self.sharding) - layout = format - def __len__(self): try: return self.shape[0] diff --git a/jax/_src/array.py b/jax/_src/array.py index 2514502c27d0..61ad8a7f4405 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -561,9 +561,6 @@ def format(self): else: raise - # TODO(frostig, yashkatariya): remove - layout = format - @property def global_shards(self) -> Sequence[Shard]: """Returns list of all `Shard`s of the Array across all devices. diff --git a/jax/_src/layout.py b/jax/_src/layout.py index c50c1787b94e..824778df453b 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -143,5 +143,3 @@ def __eq__(self, other): return False return (self.device_local_layout == other.device_local_layout and self.sharding == other.sharding) - -Layout = Format # TODO(frostig, yashkatariya): remove this alias diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 17649aae3081..fcf3d5f6176d 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -99,11 +99,11 @@ def output_shardings(self) -> Sequence[sharding_lib.Sharding]: raise NotImplementedError( "compiled executable carries no output sharding information") - def input_layouts(self): + def input_formats(self): raise NotImplementedError( "compiled executable carries no input layout information") - def output_layouts(self): + def output_formats(self): raise NotImplementedError( "compiled executable carries no output layout information") @@ -481,16 +481,6 @@ def output_formats(self): formats_flat = [Format(l, s) for l, s in zip(layouts_flat, shardings_flat)] return tree_util.tree_unflatten(self.out_tree, formats_flat) # pytype: disable=attribute-error - # TODO(frostig, yashkatariya): remove - @property - def input_layouts(self): - return self.input_formats - - # TODO(frostig, yashkatariya): remove - @property - def output_layouts(self): - return self.output_formats - @staticmethod def call(*args, **kwargs): util.test_event("stages_compiled_call") diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index 1c243541d99b..daffedcd1739 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -19,5 +19,3 @@ from jax._src.pjit import ( with_layout_constraint as with_layout_constraint, ) - -Layout = Format From 3fdc97b8b67dc6badc6eb3c643f9b7b790472235 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 19 Jun 2025 10:10:36 -0700 Subject: [PATCH 1765/1769] [Pallas/Mosaic GPU] Propagate transforms on the accumulator in `tcgen05.mma`. 
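Before this change, only the LHS, RHS and barrier refs had their transforms
flattened and threaded through `tcgen05_mma_p`; a transformed accumulator ref
(e.g. a sliced TMEM ref) was not handled. A minimal sketch of what this
enables inside a kernel (shapes are illustrative and mirror the new
`test_matmul_with_sliced_accumulator` test below; `acc_tmem` is a (128, 256)
TMEM scratch ref):

    # Use the first 128 columns of the TMEM scratch as the accumulator.
    acc_slice = acc_tmem.at[:, pl.dslice(0, 128)]
    plgpu.tcgen05_mma(acc_slice, a_smem, b_smem, barrier_ref, accumulate=False)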
PiperOrigin-RevId: 773378246 --- jax/_src/pallas/mosaic_gpu/primitives.py | 39 ++++++++++++++----- tests/pallas/mosaic_gpu_test.py | 49 ++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index d31c7fb85bd8..dbe24bb299fb 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -1275,6 +1275,12 @@ def tcgen05_mma(acc: _Ref, raise ValueError( f"LHS and RHS have incompatible shapes. LHS: {a.shape}. RHS: {b.shape}.") + if isinstance(acc, pallas_core.TransformedRef): + acc_transforms_leaves, acc_transforms_tree = jax.tree.flatten(acc.transforms) + acc = acc.ref + else: + acc_transforms_leaves, acc_transforms_tree = [], None + if isinstance(a, pallas_core.TransformedRef): a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms) a = a.ref @@ -1296,22 +1302,25 @@ def tcgen05_mma(acc: _Ref, barrier_transforms_leaves, barrier_transforms_tree = [], None tcgen05_mma_p.bind(acc, a, b, barrier, accumulate, - *a_transforms_leaves, *b_transforms_leaves, - *barrier_transforms_leaves, - a_transforms_tree=a_transforms_tree, - b_transforms_tree=b_transforms_tree, - barrier_transforms_tree=barrier_transforms_tree, - collective_axis=collective_axis) + *acc_transforms_leaves, *a_transforms_leaves, + *b_transforms_leaves, + *barrier_transforms_leaves, + acc_transforms_tree=acc_transforms_tree, + a_transforms_tree=a_transforms_tree, + b_transforms_tree=b_transforms_tree, + barrier_transforms_tree=barrier_transforms_tree, + collective_axis=collective_axis) @tcgen05_mma_p.def_abstract_eval def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, *transforms_leaves, - a_transforms_tree, b_transforms_tree, + acc_transforms_tree, a_transforms_tree, + b_transforms_tree, barrier_transforms_tree, collective_axis): - del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree, - barrier_transforms_tree) + del (accumulate, transforms_leaves, acc_transforms_tree, + a_transforms_tree, b_transforms_tree, barrier_transforms_tree) if acc.memory_space != gpu_core.TMEM: raise ValueError("Accumulator must be a TMEM Ref.") @@ -1349,6 +1358,7 @@ def _tcgen05_mma_lowering( barrier_ref: mgpu.BarrierRef, accumulate: bool | ir.Value, *transforms_leaves, + acc_transforms_tree, a_transforms_tree, b_transforms_tree, barrier_transforms_tree, @@ -1359,17 +1369,26 @@ def _tcgen05_mma_lowering( lhs_transpose: bool = False transforms_trees = ( + acc_transforms_tree, a_transforms_tree, b_transforms_tree, barrier_transforms_tree, ) - (a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = ( + (acc_transforms_leaves, a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = ( util.split_list( transforms_leaves, [getattr(tree, "num_leaves", 0) for tree in transforms_trees], ) ) + if acc_transforms_tree is not None: + acc_transforms = acc_transforms_tree.unflatten(acc_transforms_leaves) + acc, acc_transforms = lowering._handle_transforms(ctx, acc, acc_transforms) + if acc_transforms: + raise NotImplementedError( + f"Unsupported transforms: {acc_transforms}." 
+ ) + if a_transforms_tree is not None: a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) a_ref, a_transforms = lowering._handle_transforms( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 91cd8ba10b74..d13623e4238f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2670,6 +2670,55 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) + def test_matmul_with_sliced_accumulator(self): + self.skip_if_wg_semantics() + dtype = jnp.bfloat16 + shape = (128, 128) + tmem_shape = (128, 2 * 128) + swizzle = 128 + + # Test a matmul with a single block. + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): + acc_tmem_slice = acc_tmem.at[slice(None), pl.dslice(0, 128)] + plgpu.tcgen05_mma(acc_tmem_slice, + a_smem, + b_smem, + barrier_ref, + accumulate=False) + plgpu.barrier_wait(barrier_ref) + scratch_smem[...] = acc_tmem_slice[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_ref) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.TMEM(tmem_shape, jnp.float32, packed=False), + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.Barrier(for_tensor_core=True), + ] + + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + @parameterized.product( m_n_k=[(256, 256, 256), (256, 128, 128), (256, 256, 64)], swizzle=[128, 64, 32], From 346ce85dc65e2141e5cdfb1723320a4309c9d55d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 19 Jun 2025 10:12:24 -0700 Subject: [PATCH 1766/1769] [pallas:mosaic] Fixed a typo in the distributed tutorial PiperOrigin-RevId: 773378655 --- docs/pallas/tpu/distributed.ipynb | 2 +- docs/pallas/tpu/distributed.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 3ac1206bd14a..31db839f8d0b 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -178,7 +178,7 @@ "\n", "`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.\n", "\n", - "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). 
The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", + "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", "\n", "### Routing\n", "\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index 19b336005c28..9d4efd3195f4 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -163,7 +163,7 @@ def example_kernel(input_ref, output_ref, send_sem, recv_sem): `send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`. -Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). +Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). 
### Routing From f3370cb5c5db377e5aebfa646b9c69f5d19681a6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 19 Jun 2025 12:44:49 -0700 Subject: [PATCH 1767/1769] [mosaic] `MemRef{Slice,Squeeze}` verifiers now support strided layouts PiperOrigin-RevId: 773410641 --- jax/_src/pallas/mosaic/lowering.py | 80 +++++++++++++++----- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 107 ++++++++++++++------------- jaxlib/mosaic/dialect/tpu/util.cc | 43 +++++++++++ jaxlib/mosaic/dialect/tpu/util.h | 7 ++ 4 files changed, 167 insertions(+), 70 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 96fddba95341..02e3e8930651 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -19,6 +19,7 @@ import contextlib import dataclasses import functools +import operator import string from typing import Any, TypeVar @@ -47,7 +48,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import control_flow from jax._src.lax import lax as lax_internal -from jax._src.lax.control_flow import for_loop, BranchesPlatforms +from jax._src.lax.control_flow import BranchesPlatforms, for_loop from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -58,10 +59,10 @@ from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import core as pallas_core +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils -from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import error_handling from jax._src.pallas.mosaic import primitives as tpu_primitives @@ -1352,7 +1353,6 @@ def _slice_memref( ref_block_shape: tuple[int | pallas_core.Squeezed, ...], ) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...]]: assert ref_block_shape is not None - target_shape = indexer.get_indexer_shape() starts, sizes, strides, squeeze_dims, ref_block_shape = ( _indexer_to_start_size_stride( indexer, @@ -1362,26 +1362,68 @@ def _slice_memref( ) if not all((s is None or s == 1) for s in strides): raise NotImplementedError("Strided slices of references are unsupported.") - dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value)) + ir_dynamic_size = ir.ShapedType.get_dynamic_size() - static_sizes = tuple(s if not isinstance(s, ir.Value) - else ir_dynamic_size for s in sizes) - target_ref_ty = ir.MemRefType.get( - static_sizes, - _dtype_to_ir_type(ref_dtype), - memory_space=ref.type.memory_space, + static_starts = [] + for s in starts: + if not isinstance(s, ir.Value): + static_starts.append(s) + elif (v := _fold_and_get_constant_value(s)) is not None: + static_starts.append(v) + else: + static_starts.append(ir_dynamic_size) + + static_sizes = [] + dynamic_sizes = [] + for s in sizes: + if not isinstance(s, ir.Value): + static_sizes.append(s) + elif (v := _fold_and_get_constant_value(s)) is not None: + static_sizes.append(v) + else: + static_sizes.append(ir_dynamic_size) + dynamic_sizes.append(s) + + ref_ty = ir.MemRefType(ref.type) + ref_strides, ref_offset = ref_ty.get_strides_and_offset() + if ref_offset == ir_dynamic_size or ir_dynamic_size in static_starts: + target_offset = ir_dynamic_size + else: + target_offset = sum( + map(operator.mul, static_starts, ref_strides), ref_offset + ) + out_layout = ( + 
ir.StridedLayoutAttr.get(target_offset, ref_strides) + if not is_cloud_tpu_older_than(2025, 6, 20) + else None + ) + out_ty = ir.MemRefType.get( + static_sizes, ref_ty.element_type, out_layout, ref_ty.memory_space ) - out = tpu.memref_slice(target_ref_ty, ref, starts, dynamic_sizes) + out = tpu.memref_slice(out_ty, ref, starts, dynamic_sizes) if any(squeeze_dims): - # We need to squeeze out some dimensions - static_sizes = tuple(s if not isinstance(s, ir.Value) - else ir_dynamic_size for s in target_shape) - squeezed_ref_ty = ir.MemRefType.get( - static_sizes, - _dtype_to_ir_type(ref_dtype), - memory_space=ref.type.memory_space, + # We need to squeeze out some dimensions. + ref_ty = out_ty + del out_ty + ref_strides, ref_offset = ref_ty.get_strides_and_offset() + target_strides = [] + target_sizes = [] + for i, dim in enumerate(ref_ty.shape): + if not squeeze_dims[i]: + target_sizes.append(dim) + target_strides.append(ref_strides[i]) + out_layout = ( + ir.StridedLayoutAttr.get(ref_offset, target_strides) + if not is_cloud_tpu_older_than(2025, 6, 20) + else None + ) + out_ty = ir.MemRefType.get( + target_sizes, + ref_ty.element_type, + out_layout, + ref_ty.memory_space, ) - out = tpu.memref_squeeze(squeezed_ref_ty, out) + out = tpu.memref_squeeze(out_ty, out) return out, ref_block_shape diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index a6d10e68e509..9449b2737918 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -25,10 +24,10 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -99,6 +98,7 @@ LogicalResult BitcastOp::verify() { LogicalResult MemRefSliceOp::verify() { auto source_type = getMemRefType(getMemRef()); auto target_type = getType(); + auto source_layout = source_type.getLayout(); auto target_layout = target_type.getLayout(); auto target_memory_space = target_type.getMemorySpace(); auto indices = getBaseIdx(); @@ -132,12 +132,38 @@ LogicalResult MemRefSliceOp::verify() { return emitOpError( "Memory spaces must match if the target memory space is provided."); } - bool is_target_layout_identity_map = - isa(target_layout) && target_layout.isIdentity(); - if (!is_target_layout_identity_map && - target_type.getLayout() != source_type.getLayout()) { - return emitOpError( - "Layouts must match if the target layout is not an identity map."); + if (isa(target_layout)) { + SmallVector source_strides; + int64_t source_offset; + if (failed( + source_type.getStridesAndOffset(source_strides, source_offset))) { + return failure(); + } + int64_t target_offset = source_offset; + if (target_offset != ShapedType::kDynamic) { + for (auto [base_idx, source_stride] : + llvm::zip(getBaseIdx(), source_strides)) { + if (auto idx = getConstantIntValue(base_idx)) { + target_offset += *idx * source_stride; + } else { + target_offset = ShapedType::kDynamic; + break; + } + } + } + auto expected_layout = + StridedLayoutAttr::get(getContext(), target_offset, source_strides); + if (target_layout != expected_layout) { + return emitOpError("Layout mismatch: got ") + << target_layout << ", expected " << expected_layout << "."; + } + } else { + bool is_target_layout_identity_map = + isa(target_layout) && target_layout.isIdentity(); + if (!is_target_layout_identity_map && target_layout != source_layout) { + return emitOpError( + "Layouts must match if the target layout is not an identity map."); + } } if (getDynamicSizes().size() != target_type.getNumDynamicDims()) { return emitOpError( @@ -167,49 +193,6 @@ LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op, return success(); } -// Computes the dimensions that were squeezed from the source shape to match the -// target shape. Returns the dimensions in increasing order. -FailureOr> computeSqueezedDimsChecked( - Operation *op, ArrayRef source_shape, - ArrayRef target_shape) { - SmallVector squeezed; - int source_index = source_shape.size() - 1; - int target_index = target_shape.size() - 1; - - while (source_index >= 0 || target_index >= 0) { - int64_t target_dim = (target_index >= 0) ? target_shape[target_index] : -1; - if (source_index < 0) { - op->emitError() << llvm::formatv( - "Target shape is not valid. Source: {0}, Target: {1}.", - shapeToString(source_shape), shapeToString(target_shape)); - return failure(); - } - int64_t source_dim = source_shape[source_index]; - if (source_dim == target_dim) { - source_index--; - target_index--; - } else { - if (source_dim != 1) { - op->emitError() << llvm::formatv( - "Target shape is not valid. 
Source: {0}, Target: {1}.", - shapeToString(source_shape), shapeToString(target_shape)); - return failure(); - } - squeezed.push_back(source_index); - source_index--; - } - } - - if (source_index != -1 || target_index != -1) { - op->emitError() << "Shape mismatch after traversal. Source shape: " - << shapeToString(source_shape) - << ", target shape: " << shapeToString(target_shape); - return failure(); - } - std::reverse(squeezed.begin(), squeezed.end()); - return squeezed; -} - LogicalResult MemRefSqueezeOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); @@ -231,6 +214,28 @@ LogicalResult MemRefSqueezeOp::verify() { return failure(); } + auto target_layout = target_type.getLayout(); + if (isa(target_layout)) { + SmallVector source_strides; + int64_t source_offset; + if (failed( + source_type.getStridesAndOffset(source_strides, source_offset))) { + return failure(); + } + SmallVector target_strides; + for (auto [i, stride] : llvm::enumerate(source_strides)) { + if (!llvm::is_contained(*squeezed_or, i)) { + target_strides.push_back(stride); + } + } + auto expected_layout = + StridedLayoutAttr::get(getContext(), source_offset, target_strides); + if (target_layout != expected_layout) { + return emitOpError("Layout mismatch: got ") + << target_layout << ", expected " << expected_layout << "."; + } + } + auto erase_layout_op = getInput().getDefiningOp(); if (!erase_layout_op) { return success(); diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 0e67b4299f7e..ace5a67a4a42 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/util.h" +#include #include #include #include @@ -25,6 +26,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -66,6 +68,47 @@ SmallVector ComputeTileStrides(absl::Span shape, return tile_strides; } +FailureOr> computeSqueezedDimsChecked( + Operation *op, ArrayRef source_shape, + ArrayRef target_shape) { + SmallVector squeezed; + int source_index = source_shape.size() - 1; + int target_index = target_shape.size() - 1; + + while (source_index >= 0 || target_index >= 0) { + int64_t target_dim = (target_index >= 0) ? target_shape[target_index] : -1; + if (source_index < 0) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + int64_t source_dim = source_shape[source_index]; + if (source_dim == target_dim) { + source_index--; + target_index--; + } else { + if (source_dim != 1) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + squeezed.push_back(source_index); + source_index--; + } + } + + if (source_index != -1 || target_index != -1) { + op->emitError() << "Shape mismatch after traversal. 
Source shape: " + << shapeToString(source_shape) + << ", target shape: " << shapeToString(target_shape); + return failure(); + } + std::reverse(squeezed.begin(), squeezed.end()); + return squeezed; +} + std::optional> isTransposedMatmul( DotDimensionNumbersAttr dim_numbers) { auto lhs_contracting_dims = dim_numbers.getLhsContractingDims(); diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index af590f45f619..3d7f6315b695 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -220,6 +220,13 @@ inline SmallVector ComputeTileStrides( memref_ty.getShape().size()); return ComputeTileStrides(shape, tiling); } + +// Computes the dimensions that were squeezed from the source shape to match the +// target shape. Returns the dimensions in increasing order. +FailureOr> computeSqueezedDimsChecked( + Operation *op, ArrayRef source_shape, + ArrayRef target_shape); + // Assuming MKN matmul - This function must only be called after // canonicalization passes. // From d46202e4fc44b610ef486e6850f4f9b61dae370f Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 19 Jun 2025 17:17:04 -0700 Subject: [PATCH 1768/1769] Add some tracemes in py_array to make slow device put debugging easier. PiperOrigin-RevId: 773466405 --- jaxlib/py_array.cc | 127 +++++++++++++++++++++++++-------------------- 1 file changed, 72 insertions(+), 55 deletions(-) diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index a113da6e11f2..d7ff7ee6e3f7 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -106,6 +106,7 @@ limitations under the License. #include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace { @@ -1156,56 +1157,61 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( absl::flat_hash_map batches; auto traceback = Traceback::Get(); - for (int i = 0; i < py_arrays.size(); ++i) { - const auto& py_array = py_arrays[i]; - const auto& dst_sharding = dst_shardings[i]; - const auto& array_cs = array_copy_semantics[i]; - - auto* ifrt_array_ptr = py_array.ifrt_array(); - const ifrt::DeviceListRef& src_devices = - ifrt_array_ptr->sharding().devices(); - const ifrt::DeviceListRef& dst_devices = dst_device_lists[i]; - - ifrt::MemoryKind src_memory_kind = - ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), - src_devices->devices().front()); - ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( - xla::GetMemoryKind(dst_sharding), dst_devices->devices().front()); - - if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && - array_cs == ifrt::ArrayCopySemantics::kReuseInput) { - if (py_array.sharding().equal(dst_sharding)) { - results[i] = py_arrays[i]; - } else { - absl::Span shape_span = py_array.shape(); - // We can reuse the input array despite the sharding being different. - // This is because this code expects no resharding is necessary, which - // has been verified by the code invoking this method. 
- results[i] = - PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), - std::vector(shape_span.begin(), shape_span.end()), - dst_sharding, py_array.py_client(), traceback, - tsl::FormRef(ifrt_array_ptr), py_array.committed(), - /*skip_checks=*/true, py_array.result_status()); + { + tsl::profiler::TraceMe results_traceme( + "BatchedCopyToDeviceWithSharding create batch"); + for (int i = 0; i < py_arrays.size(); ++i) { + const auto& py_array = py_arrays[i]; + const auto& dst_sharding = dst_shardings[i]; + const auto& array_cs = array_copy_semantics[i]; + + auto* ifrt_array_ptr = py_array.ifrt_array(); + const ifrt::DeviceListRef& src_devices = + ifrt_array_ptr->sharding().devices(); + const ifrt::DeviceListRef& dst_devices = dst_device_lists[i]; + + ifrt::MemoryKind src_memory_kind = + ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), + src_devices->devices().front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + xla::GetMemoryKind(dst_sharding), dst_devices->devices().front()); + + if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && + array_cs == ifrt::ArrayCopySemantics::kReuseInput) { + if (py_array.sharding().equal(dst_sharding)) { + results[i] = py_arrays[i]; + } else { + absl::Span shape_span = py_array.shape(); + // We can reuse the input array despite the sharding being different. + // This is because this code expects no resharding is necessary, which + // has been verified by the code invoking this method. + results[i] = PyArray( + py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_sharding, py_array.py_client(), traceback, + tsl::FormRef(ifrt_array_ptr), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + continue; } - continue; - } - auto transfer_guard_formatter = [&py_array, &dst_sharding] { - return absl::StrCat( - "aval=", nb::cast(nb::repr(py_array.aval())), - ", sharding=", - nb::cast(nb::repr(py_array.sharding())), - ", dst_sharding=", - nb::cast(nb::repr(dst_sharding))); - }; - TF_RETURN_IF_ERROR( - jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + auto transfer_guard_formatter = [&py_array, &dst_sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(py_array.aval())), + ", sharding=", + nb::cast(nb::repr(py_array.sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); - Batch& batch = batches[BatchedCopyToDeviceWithShardingKey{ - src_devices, src_memory_kind, dst_devices, dst_memory_kind, array_cs}]; - batch.indexes.push_back(i); - batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + Batch& batch = batches[BatchedCopyToDeviceWithShardingKey{ + src_devices, src_memory_kind, dst_devices, dst_memory_kind, + array_cs}]; + batch.indexes.push_back(i); + batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + } } std::vector> ifrt_arrays; @@ -1213,6 +1219,8 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( GlobalPyRefManager()->CollectGarbage(); nb::gil_scoped_release gil_release; + tsl::profiler::TraceMe copy_traceme( + "BatchedCopyToDeviceWithSharding: dispatch"); for (auto& [key, batch] : batches) { TF_ASSIGN_OR_RETURN( auto copied, @@ -1228,6 +1236,8 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( } } + tsl::profiler::TraceMe results_traceme( + "BatchedCopyToDeviceWithSharding create results"); for (auto& [i, ifrt_array] : 
ifrt_arrays) { const auto& py_array = py_arrays[i]; absl::Span shape_span = py_array.shape(); @@ -2088,16 +2098,23 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) { if (arrays.empty()) { return std::vector(); } - auto* client = arrays[0].ifrt_array()->client(); + tsl::profiler::TraceMe traceme( + "batched_copy_array_to_devices_with_sharding"); std::vector device_lists; - device_lists.reserve(dst_device_lists.size()); - for (const auto& dst_devices : dst_device_lists) { - absl::InlinedVector devices; - devices.reserve(dst_devices.size()); - for (auto& d : dst_devices) { - devices.push_back(d->device()); + { + tsl::profiler::TraceMe device_list_traceme( + "batched_copy_array_to_devices_with_sharding: assemble device " + "lists"); + auto* client = arrays[0].ifrt_array()->client(); + device_lists.reserve(dst_device_lists.size()); + for (const auto& dst_devices : dst_device_lists) { + absl::InlinedVector devices; + devices.reserve(dst_devices.size()); + for (auto& d : dst_devices) { + devices.push_back(d->device()); + } + device_lists.push_back(client->MakeDeviceList(devices)); } - device_lists.push_back(client->MakeDeviceList(devices)); } return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding( arrays, device_lists, shardings, array_copy_semantics)); From 71ea45bb4c7aa6c0c38907c781284cb35855062c Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Thu, 19 Jun 2025 17:18:15 -0700 Subject: [PATCH 1769/1769] Add traceme to `PythonRefManager::CollectGarbage` PiperOrigin-RevId: 773466751 --- jaxlib/BUILD | 1 + jaxlib/python_ref_manager.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 3d9d3c709415..dd4b06b34bcd 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1098,6 +1098,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@nanobind", + "@tsl//tsl/profiler/lib:traceme", "@xla//third_party/python_runtime:headers", # buildcleaner: keep ], ) diff --git a/jaxlib/python_ref_manager.cc b/jaxlib/python_ref_manager.cc index 64bd0041b625..6cc2714b75ad 100644 --- a/jaxlib/python_ref_manager.cc +++ b/jaxlib/python_ref_manager.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { @@ -87,6 +88,7 @@ void PythonRefManager::AddGarbage( void PythonRefManager::CollectGarbage() { // TODO(phawkins): we should CHECK(PyGILState_Check()); + tsl::profiler::TraceMe traceme("PythonRefManager::CollectGarbage"); std::deque garbage; { absl::MutexLock lock(&mu_);